# Reference: https://github.com/apple/ml-depth-pro
# Strictly follows official README Python API:
#   model, transform = depth_pro.create_model_and_transforms()
#   prediction = model.infer(image, f_px=f_px)
#   depth = prediction["depth"]  # in [m]
#   focallength_px = prediction["focallength_px"]
#
# Depth Pro outputs *metric* depth. Returns key `depth_metric` plus `intrinsics`
# when focal length is recovered, so MoGe's compute_metrics can use the metric path.
import os
import sys
from typing import *
from pathlib import Path

import click
import torch
import torch.nn.functional as F

from moge.test.baseline import MGEBaselineInterface


class Baseline(MGEBaselineInterface):
    """Depth Pro wrapper exposing the `load` / `infer` baseline interface.

    The wrapper imports `depth_pro` from a local checkout of the ml-depth-pro
    repository, builds the model from a checkpoint file, and returns metric
    depth plus normalized intrinsics from `infer`.
    """

    def __init__(self, repo_path: str, checkpoint_path: str, precision: str, device: Union[torch.device, str]):
        """Load the Depth Pro model from a local repository checkout.

        Args:
            repo_path: Path to the ml-depth-pro checkout; its `src` directory
                is put on `sys.path` so `depth_pro` can be imported.
            checkpoint_path: Checkpoint file. A relative path is resolved
                against `repo_path`.
            precision: Either "fp32" or "fp16".
            device: Device the model is created on.

        Raises:
            FileNotFoundError: If the repository or the checkpoint is missing.
            KeyError: If `precision` is not "fp32" or "fp16".
        """
        repo_path = os.path.abspath(repo_path)
        if not Path(repo_path).exists():
            raise FileNotFoundError(
                f"Cannot find Depth Pro repo at {repo_path}. Clone https://github.com/apple/ml-depth-pro "
                "and pass --repo ."
            )

        # depth_pro lives in the checkout's `src` directory, so it can only be
        # imported once that directory is on sys.path.
        src_path = os.path.join(repo_path, "src")
        if src_path not in sys.path:
            sys.path.insert(0, src_path)
        import depth_pro
        from depth_pro.depth_pro import DepthProConfig, DEFAULT_MONODEPTH_CONFIG_DICT

        if not os.path.isabs(checkpoint_path):
            checkpoint_path = os.path.join(repo_path, checkpoint_path)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"Cannot find Depth Pro checkpoint at {checkpoint_path}. "
                "Run `source get_pretrained_models.sh` inside the ml-depth-pro repo to download it."
            )

        device = torch.device(device)
        precision_dtype = {"fp32": torch.float32, "fp16": torch.float16}[precision]

        # Copy the default configuration and override only the checkpoint location.
        defaults = DEFAULT_MONODEPTH_CONFIG_DICT
        config = DepthProConfig(
            patch_encoder_preset=defaults.patch_encoder_preset,
            image_encoder_preset=defaults.image_encoder_preset,
            decoder_features=defaults.decoder_features,
            checkpoint_uri=checkpoint_path,
            use_fov_head=defaults.use_fov_head,
            fov_encoder_preset=defaults.fov_encoder_preset,
        )
        # The transform returned alongside the model is unused: `infer`
        # applies the equivalent normalization itself.
        model, _ = depth_pro.create_model_and_transforms(config=config, device=device, precision=precision_dtype)
        model.eval()

        self.model = model
        self.device = device
        self.precision_dtype = precision_dtype

    @click.command()
    @click.option('--repo', 'repo_path', type=click.Path(), default='../ml-depth-pro',
                  help='Path to the apple/ml-depth-pro repository.')
    @click.option('--checkpoint', 'checkpoint_path', type=click.Path(), default='checkpoints/depth_pro.pt',
                  help='Checkpoint path; relative paths are resolved against --repo.')
    @click.option('--precision', type=click.Choice(['fp32', 'fp16']), default='fp32')
    @click.option('--device', type=str, default='cuda')
    @staticmethod
    def load(repo_path: str, checkpoint_path: str, precision: str, device: str = 'cuda'):
        """Create a Depth Pro baseline from command-line options."""
        return Baseline(repo_path, checkpoint_path, precision, device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Predict metric depth and normalized intrinsics for a single image.

        Args:
            image: Tensor of shape [3, H, W] or [1, 3, H, W]. Values are
                assumed to lie in [0, 1]; they are mapped to [-1, 1] before
                being passed to the model.
            intrinsics: Optional [3, 3] or [B, 3, 3] camera matrix whose
                `[0, 0]` entry is the focal length divided by the image
                width. Only the first matrix of a batch is used. When given,
                the focal length is passed to Depth Pro instead of being
                estimated.

        Returns:
            A dict with
            - `depth_metric`: float32 depth as returned by Depth Pro
              (expected [H, W], in meters), with a leading batch dimension
              when the input was batched.
            - `intrinsics`: float32 [3, 3] matrix (or [1, 3, 3] for batched
              input) with fx normalized by width, fy normalized by height
              and the principal point at (0.5, 0.5).

        Raises:
            ValueError: If `image` is not 3-D/4-D or has a batch size other
                than 1.
        """
        omit_batch = image.ndim == 3
        if omit_batch:
            image = image.unsqueeze(0)
        # Explicit raises instead of `assert`, which is stripped under `python -O`.
        if image.ndim != 4:
            raise ValueError(
                f"Expected image of shape [3, H, W] or [1, 3, H, W], got {tuple(image.shape)}"
            )
        if image.shape[0] != 1:
            raise ValueError("Depth Pro baseline only supports batch size 1")
        _, _, H, W = image.shape

        # Depth Pro transform: torchvision.Normalize([0.5]*3, [0.5]*3) maps [0,1] -> [-1,1].
        x = (image.to(self.device, dtype=self.precision_dtype) - 0.5) / 0.5

        # Convert normalized intrinsics to a pixel focal length if provided.
        f_px = None
        if intrinsics is not None:
            intr = intrinsics.to(self.device)
            if intr.ndim == 3:
                intr = intr[0]
            # fx is expressed as a fraction of the image width (it may exceed 1).
            f_px = intr[0, 0] * W

        prediction = self.model.infer(x, f_px=f_px)

        # Return float32 regardless of the inference precision: with fp16 the
        # raw outputs would otherwise carry half precision into the metrics,
        # and the focal length would be truncated when written into K.
        depth = prediction["depth"].float()  # [H, W] in meters (squeezed by Depth Pro)
        focallength_px = prediction["focallength_px"].to(
            device=depth.device, dtype=torch.float32
        ).reshape(())  # scalar tensor (pixels)

        # Build normalized intrinsics (fx, fy in fraction of image width / height).
        K = torch.eye(3, device=depth.device, dtype=torch.float32)
        K[0, 0] = focallength_px / W
        K[1, 1] = focallength_px / H
        K[0, 2] = 0.5
        K[1, 2] = 0.5

        out: Dict[str, torch.Tensor] = {"depth_metric": depth, "intrinsics": K}
        if not omit_batch:
            # Restore the batch dimension the caller supplied.
            out = {key: value.unsqueeze(0) for key, value in out.items()}
        return out