Diffusers
Safetensors
File size: 4,494 Bytes
4165f20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# Reference: https://github.com/indu1ge/DepthMaster
# Strictly follows official `run.py`:
#   from depthmaster import DepthMasterPipeline
#   from depthmaster.modules.unet_2d_condition_s2 import UNet2DConditionModel
#   pipe = DepthMasterPipeline.from_pretrained(checkpoint_path, variant=variant, torch_dtype=dtype)
#   unet = UNet2DConditionModel.from_pretrained(os.path.join(checkpoint_path, 'unet'))
#   pipe.unet = unet
#   pipe = pipe.to(device)
#   pipe_out = pipe(input_pil_image, processing_res=..., match_input_res=...,
#                   batch_size=..., color_map=..., show_progress_bar=..., resample_method=...)
#   depth_pred = pipe_out.depth_np  # H x W float, affine-invariant depth

import os
import sys
from typing import *
from pathlib import Path

import click
import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image

from moge.test.baseline import MGEBaselineInterface


class Baseline(MGEBaselineInterface):
    """MoGe evaluation baseline wrapping the DepthMaster depth pipeline.

    Mirrors the official DepthMaster ``run.py``:
    - the pipeline is loaded from a checkpoint directory;
    - its UNet is replaced by the repository's own ``UNet2DConditionModel``
      loaded from ``<checkpoint>/unet``;
    - inference is run on a PIL image.

    The prediction is reported under the ``'depth_affine_invariant'`` key.
    """

    def __init__(self, repo_path: str, checkpoint: str, processing_res: Optional[int],
                 half_precision: bool, device: Union[torch.device, str]):
        """Import DepthMaster from a local clone and build the pipeline.

        Args:
            repo_path: Path to a clone of https://github.com/indu1ge/DepthMaster.
                It is prepended to ``sys.path`` so ``depthmaster`` can be imported.
            checkpoint: Checkpoint directory holding the pipeline files and a
                ``unet`` subdirectory.
            processing_res: Value forwarded as ``processing_res`` on every
                pipeline call.
            half_precision: Load weights in float16 (``variant="fp16"``)
                instead of float32.
            device: Device the pipeline is moved to; outputs are returned on it.

        Raises:
            FileNotFoundError: If ``repo_path`` does not exist.
        """
        repo_path = os.path.abspath(repo_path)
        if not Path(repo_path).exists():
            raise FileNotFoundError(
                f"Cannot find DepthMaster repo at {repo_path}. Clone https://github.com/indu1ge/DepthMaster."
            )
        if repo_path not in sys.path:
            sys.path.insert(0, repo_path)

        # Imported lazily: `depthmaster` is only importable once the repo is on sys.path.
        from depthmaster import DepthMasterPipeline
        from depthmaster.modules.unet_2d_condition_s2 import UNet2DConditionModel

        device = torch.device(device)
        dtype = torch.float16 if half_precision else torch.float32
        variant = "fp16" if half_precision else None

        pipe = DepthMasterPipeline.from_pretrained(checkpoint, variant=variant, torch_dtype=dtype)
        # run.py swaps in the repo's own UNet. Load it in the same dtype as the rest
        # of the pipeline: without `torch_dtype` it would stay float32 under --fp16
        # while the other components are float16. For float32 this is identical to run.py.
        unet_dir = os.path.join(checkpoint, "unet")
        unet = UNet2DConditionModel.from_pretrained(unet_dir, torch_dtype=dtype)
        pipe.unet = unet

        # xformers attention is only usable on CUDA. diffusers raises ValueError
        # (not ImportError) when CUDA is unavailable, so only attempt it for CUDA devices.
        if device.type == "cuda":
            try:
                pipe.enable_xformers_memory_efficient_attention()
            except ImportError:
                # xformers is not installed: keep the default attention implementation.
                pass
        pipe = pipe.to(device)

        self.pipe = pipe
        self.device = device
        self.processing_res = processing_res

    # Decorators apply bottom-up: `staticmethod` first, then the click options/command,
    # so `Baseline.load` is a click command whose callback builds a Baseline.
    @click.command()
    @click.option('--repo', 'repo_path', type=click.Path(), default='../DepthMaster',
                  help='Path to the indu1ge/DepthMaster repository.')
    @click.option('--checkpoint', type=click.Path(), required=True,
                  help='Local checkpoint directory containing pipeline files + unet subdir (HF: zysong212/DepthMaster).')
    @click.option('--processing_res', type=int, default=768,
                  help='Pipeline processing resolution (run.py default 768).')
    @click.option('--fp16', 'half_precision', is_flag=True, help='Run in half precision.')
    @click.option('--device', type=str, default='cuda')
    @staticmethod
    def load(repo_path: str, checkpoint: str, processing_res: Optional[int],
             half_precision: bool, device: str = 'cuda'):
        """Build a ``Baseline`` from the parsed command-line options."""
        return Baseline(repo_path, checkpoint, processing_res, half_precision, device)

    @torch.inference_mode()
    def infer(self, image: torch.Tensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
        """Predict affine-invariant depth for a single RGB image.

        Args:
            image: Image tensor of shape (C, H, W) or (1, C, H, W). Values are
                clamped to [0, 1] and quantized to 8-bit before being handed to
                the pipeline. Assumed to be RGB with C == 3 (TODO confirm with callers).
            intrinsics: Unused; accepted for interface compatibility.

        Returns:
            ``{'depth_affine_invariant': depth}``, where ``depth`` is a float32
            tensor on ``self.device`` of shape (H, W), or (1, H, W) when
            ``image`` was passed with a batch dimension.
        """
        omit_batch = image.ndim == 3
        if omit_batch:
            image = image.unsqueeze(0)
        assert image.shape[0] == 1, "DepthMaster baseline only supports batch size 1"
        _, _, H, W = image.shape

        # The pipeline takes a PIL.Image (per run.py). Round instead of truncating:
        # for an image stored as uint8 / 255, `x * 255` can land a hair below the
        # original integer, and truncation would then shift that pixel by -1.
        rgb = image[0].cpu().permute(1, 2, 0).clamp(0, 1).numpy()
        arr = np.round(rgb * 255).astype(np.uint8)
        pil = Image.fromarray(arr)

        # Arguments mirror run.py. batch_size=0 presumably lets the pipeline choose
        # its own batch size (Marigold-style convention; confirm in DepthMasterPipeline).
        out = self.pipe(
            pil,
            processing_res=self.processing_res,
            match_input_res=True,
            batch_size=0,
            color_map='Spectral',
            show_progress_bar=False,
            resample_method='bilinear',
        )

        depth = torch.from_numpy(np.ascontiguousarray(out.depth_np)).to(self.device).float()
        # match_input_res=True is requested, but resize defensively if the
        # returned map still differs from the input tensor's size.
        if depth.shape != (H, W):
            depth = F.interpolate(depth[None, None], size=(H, W), mode='bilinear', align_corners=False)[0, 0]

        # DepthMaster predicts affine-invariant depth. Emit only this physical key,
        # restoring the batch dimension if the caller supplied one.
        if not omit_batch:
            depth = depth.unsqueeze(0)
        return {'depth_affine_invariant': depth}