# Reference: https://github.com/apple/ml-depth-pro
# Strictly follows official README Python API:
# model, transform = depth_pro.create_model_and_transforms()
# prediction = model.infer(image, f_px=f_px)
# depth = prediction["depth"] # in [m]
# focallength_px = prediction["focallength_px"]
#
# Depth Pro outputs *metric* depth. This wrapper returns the key `depth_metric`
# plus normalized `intrinsics` built from the recovered focal length, so MoGe's
# compute_metrics can use the metric path.
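#
# Expected I/O of this wrapper (as implemented below): `infer` takes an RGB image
# tensor with values in [0, 1], shaped [3, H, W] or [1, 3, H, W] (batch size 1 only),
# and returns metric depth under `depth_metric` ([H, W], meters) together with a
# 3x3 normalized intrinsics matrix; both gain a leading batch dim for 4D inputs.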
import os
import sys
from pathlib import Path
from typing import Dict, Optional, Union

import click
import torch

from moge.test.baseline import MGEBaselineInterface

class Baseline(MGEBaselineInterface):

    def __init__(self, repo_path: str, checkpoint_path: str, precision: str, device: Union[torch.device, str]):
        repo_path = os.path.abspath(repo_path)
        if not Path(repo_path).exists():
            raise FileNotFoundError(
                f"Cannot find Depth Pro repo at {repo_path}. Clone https://github.com/apple/ml-depth-pro "
                f"and pass --repo <path>."
            )

        # Make the cloned repo importable (Depth Pro lives under src/depth_pro).
        src_path = os.path.join(repo_path, "src")
        if src_path not in sys.path:
            sys.path.insert(0, src_path)
        import depth_pro
        from depth_pro.depth_pro import DepthProConfig, DEFAULT_MONODEPTH_CONFIG_DICT

        # Resolve relative checkpoint paths against the repo, matching the official layout.
        if not os.path.isabs(checkpoint_path):
            checkpoint_path = os.path.join(repo_path, checkpoint_path)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Cannot find Depth Pro checkpoint at {checkpoint_path}. "
                f"Run `source get_pretrained_models.sh` inside the ml-depth-pro repo to download it."
            )

        device = torch.device(device)
        precision_dtype = {"fp32": torch.float32, "fp16": torch.float16}[precision]

        # Reuse the official default config, overriding only the checkpoint location.
        config = DepthProConfig(
            patch_encoder_preset=DEFAULT_MONODEPTH_CONFIG_DICT.patch_encoder_preset,
            image_encoder_preset=DEFAULT_MONODEPTH_CONFIG_DICT.image_encoder_preset,
            decoder_features=DEFAULT_MONODEPTH_CONFIG_DICT.decoder_features,
            checkpoint_uri=checkpoint_path,
            use_fov_head=DEFAULT_MONODEPTH_CONFIG_DICT.use_fov_head,
            fov_encoder_preset=DEFAULT_MONODEPTH_CONFIG_DICT.fov_encoder_preset,
        )
        model, _ = depth_pro.create_model_and_transforms(config=config, device=device, precision=precision_dtype)
        model.eval()

        self.model = model
        self.device = device
        self.precision_dtype = precision_dtype

    @click.command()
    @click.option('--repo', 'repo_path', type=click.Path(), default='../ml-depth-pro',
                  help='Path to the apple/ml-depth-pro repository.')
    @click.option('--checkpoint', 'checkpoint_path', type=click.Path(),
                  default='checkpoints/depth_pro.pt',
                  help='Checkpoint path; relative paths are resolved against --repo.')
    @click.option('--precision', type=click.Choice(['fp32', 'fp16']), default='fp32')
    @click.option('--device', type=str, default='cuda')
    @staticmethod
    def load(repo_path: str, checkpoint_path: str, precision: str, device: str = 'cuda'):
        return Baseline(repo_path, checkpoint_path, precision, device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        omit_batch = image.ndim == 3
        if omit_batch:
            image = image.unsqueeze(0)
        assert image.shape[0] == 1, "Depth Pro baseline only supports batch size 1"
        _, _, H, W = image.shape

        # Depth Pro transform: torchvision.Normalize([0.5]*3, [0.5]*3) maps [0,1] -> [-1,1].
        x = (image.to(self.device, dtype=self.precision_dtype) - 0.5) / 0.5

        # Convert normalized intrinsics (fx, fy in image-relative units) to pixel focal length if provided.
        f_px = None
        if intrinsics is not None:
            intr = intrinsics.to(self.device)
            if intr.ndim == 3:
                intr = intr[0]
            f_px = intr[0, 0] * W  # MoGe normalized intrinsics: fx in [0, 1] of width
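            # Worked example (hypothetical numbers): for W = 1536 and a normalized
            # fx of 0.9, the pixel focal length passed on is 0.9 * 1536 = 1382.4.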
        prediction = self.model.infer(x, f_px=f_px)
        depth = prediction["depth"]                    # [H, W] in meters (squeezed by Depth Pro)
        focallength_px = prediction["focallength_px"]  # scalar tensor (pixels)

        out: Dict[str, torch.Tensor] = {"depth_metric": depth}

        # Build normalized intrinsics (fx, fy in fractions of image width / height).
        fx_norm = (focallength_px / W).reshape(())
        fy_norm = (focallength_px / H).reshape(())
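        # The resulting matrix (the principal point is assumed to sit at the image
        # center, i.e. (0.5, 0.5) in normalized units, since Depth Pro only reports
        # a focal length):
        #     [[fx/W, 0,    0.5],
        #      [0,    fy/H, 0.5],
        #      [0,    0,    1.0]]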
        K = torch.eye(3, device=depth.device, dtype=depth.dtype)
        K[0, 0] = fx_norm
        K[1, 1] = fy_norm
        K[0, 2] = 0.5
        K[1, 2] = 0.5
        out["intrinsics"] = K

        # Depth Pro squeezes its outputs, so restore the leading batch dimension
        # when the caller passed a batched (4D) image.
        if not omit_batch:
            out["depth_metric"] = out["depth_metric"].unsqueeze(0)
            out["intrinsics"] = out["intrinsics"].unsqueeze(0)
        return out
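
# Minimal smoke test (a sketch, not part of MoGe's eval harness): it assumes the
# default repo/checkpoint locations used by `load` above are present.
if __name__ == "__main__":
    baseline = Baseline(
        repo_path="../ml-depth-pro",
        checkpoint_path="checkpoints/depth_pro.pt",
        precision="fp32",
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    dummy = torch.rand(3, 480, 640)  # random RGB image in [0, 1], CHW layout
    pred = baseline.infer(dummy)
    print("depth_metric:", tuple(pred["depth_metric"].shape))  # -> (480, 640)
    print("intrinsics:\n", pred["intrinsics"])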