File size: 4,480 Bytes
c7b663e
 
 
 
 
 
 
 
773ae1a
c7b663e
 
 
 
773ae1a
 
 
 
 
 
c7b663e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# SAB3R gradio demo executable
# --------------------------------------------------------
import os
import sys
import argparse
import tempfile
from contextlib import nullcontext

# Allow `python demo/demo.py` from the repo root to find the sibling
# mast3r/ and dust3r/ packages.
_REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)

import torch
from huggingface_hub import hf_hub_download

from mast3r.demo import get_args_parser as sab3r_get_args_parser, main_demo
from mast3r.model import AsymmetricMASt3R  # noqa: F401  (referenced via eval() below)
from mast3r.utils.misc import hash_md5

import mast3r.utils.path_to_dust3r  # noqa: F401  (side-effect: puts vendored dust3r on sys.path)
from dust3r.demo import set_print_with_timestamp

import matplotlib.pyplot as pl
pl.ion()

torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12

# `inf` is bound at module level because the constructor expression returned
# by build_model_config() names it (`-inf, inf`) and main() eval()s that string.
inf = float("inf")

# Fallbacks for --model_repo / --ckpt_filename, used when the SAB3R_MODEL_REPO /
# SAB3R_CKPT_FILENAME environment variables are not set (see get_args_parser).
DEFAULT_MODEL_REPO = "uva-cv-lab/SAB3R"
DEFAULT_CKPT_FILENAME = "demo_ckpt/base/base.pt"


def get_args_parser():
    """Return the base demo parser extended with the SAB3R-specific options.

    Starts from the parser built by ``mast3r.demo.get_args_parser`` and adds
    ``--model_repo``, ``--ckpt_filename`` and ``--checkpoint_dir``. Each
    option's default is read from an environment variable, falling back to
    the value listed in the table below when the variable is unset.
    """
    parser = sab3r_get_args_parser()

    # (flag, environment variable, fallback default, help text)
    extra_options = (
        (
            "--model_repo",
            "SAB3R_MODEL_REPO",
            DEFAULT_MODEL_REPO,
            "Hugging Face Hub repo id hosting the SAB3R checkpoint "
            "(used only when --weights is not provided).",
        ),
        (
            "--ckpt_filename",
            "SAB3R_CKPT_FILENAME",
            DEFAULT_CKPT_FILENAME,
            "Checkpoint filename inside --model_repo.",
        ),
        (
            "--checkpoint_dir",
            "SAB3R_CHECKPOINT_DIR",
            None,
            "Optional local directory containing one sub-directory per "
            "checkpoint (each sub-dir must hold `<name>.pt`). When provided, "
            "the UI exposes a dropdown to switch between them. Useful for "
            "local dev; leave unset for single-checkpoint HF Spaces deployments.",
        ),
    )
    for flag, env_var, fallback, help_text in extra_options:
        parser.add_argument(
            flag,
            default=os.environ.get(env_var, fallback),
            help=help_text,
        )
    return parser


def load_weights(model, ckp_path, device):
    """Load a checkpoint file into ``model`` and move the model to ``device``.

    The checkpoint layout is chosen from the file extension:

    * ``.pth`` -- the state dict is read from ``ckp['model']`` and loaded
      with ``strict=False``.
    * ``.pt`` -- the state dict is read from ``ckp['module']`` and loaded
      with ``load_state_dict``'s default key matching.

    The extension is validated *before* the file is read, so an unsupported
    path is rejected immediately instead of after deserializing the whole
    checkpoint.

    Args:
        model: Object exposing ``load_state_dict`` and ``to``; it is updated
            in place.
        ckp_path: Checkpoint location, as a ``str`` or ``os.PathLike``.
        device: Target forwarded to ``model.to``.

    Raises:
        ValueError: If ``ckp_path`` ends with neither ``.pth`` nor ``.pt``.
    """
    # Normalize path-like objects so the suffix checks below work on them too.
    ckp_path = os.fspath(ckp_path)

    if ckp_path.endswith('.pth'):
        state_key, load_kwargs = 'model', {'strict': False}
    elif ckp_path.endswith('.pt'):
        state_key, load_kwargs = 'module', {}
    else:
        raise ValueError(f"Unknown checkpoint format: {ckp_path}")

    # NOTE(review): torch.load is pickle-based; only point this at checkpoints
    # from a trusted source.
    # Tensors are first mapped to CPU; the model is moved to `device` afterwards.
    ckp = torch.load(ckp_path, map_location='cpu')
    model.load_state_dict(ckp[state_key], **load_kwargs)
    model.to(device)


def build_model_config():
    """Return the ``AsymmetricMASt3R(...)`` constructor expression as a string.

    The result is source text, not a model instance: main() passes it to
    ``eval``. It refers to ``inf`` by name, so that name must be resolvable
    where the string is evaluated.
    """
    keyword_args = (
        "pos_embed='RoPE100'",
        "patch_embed_cls='PatchEmbedDust3R'",
        "img_size=(512, 512)",
        "head_type='catmlp+dpt'",
        "output_mode='pts3d+desc24'",
        "clip_head_type='dpt'",
        "dino_head_type='dpt'",
        "depth_mode=('exp', -inf, inf)",
        "conf_mode=('exp', 1, inf)",
        "enc_embed_dim=1024",
        "enc_depth=24",
        "enc_num_heads=16",
        "dec_embed_dim=768",
        "dec_depth=12",
        "dec_num_heads=12",
        "two_confs=True",
        "landscape_only=False",
    )
    return "AsymmetricMASt3R(" + ", ".join(keyword_args) + ")"


def resolve_weights_path(args: argparse.Namespace) -> str:
    """Return the checkpoint path to load.

    A truthy ``args.weights`` is returned unchanged. Otherwise the file
    ``args.ckpt_filename`` is fetched from the Hub repo ``args.model_repo``
    via ``hf_hub_download`` and that call's return value is used.
    """
    weights_path = args.weights
    if not weights_path:
        print(f"[sab3r] Downloading checkpoint from HF Hub: {args.model_repo}/{args.ckpt_filename}")
        weights_path = hf_hub_download(
            repo_id=args.model_repo,
            filename=args.ckpt_filename,
        )
    return weights_path


def main(argv=None):
    """Parse the command line, build and load the model, then run the demo.

    Args:
        argv: Argument list forwarded to ``parser.parse_args``; with ``None``
            argparse picks the arguments itself.
    """
    args = get_args_parser().parse_args(argv)
    set_print_with_timestamp()

    # An explicit args.server_name wins; otherwise the address depends on
    # args.local_network.
    server_name = args.server_name
    if server_name is None:
        if args.local_network:
            server_name = '0.0.0.0'
        else:
            server_name = '127.0.0.1'

    # The config is a constructor expression in text form; evaluating it here
    # resolves `AsymmetricMASt3R` and `inf` from this module's globals.
    model = eval(build_model_config())
    weights_file = resolve_weights_path(args)
    load_weights(model, weights_file, args.device)
    # The tag names the cache sub-directory created below.
    cache_tag = hash_md5(weights_file)

    # Without args.tmp_dir, work in a temporary directory that is removed on
    # exit; otherwise use the given directory as-is.
    if args.tmp_dir is None:
        workdir_cm = tempfile.TemporaryDirectory(suffix='_sab3r_gradio_demo')
    else:
        workdir_cm = nullcontext(args.tmp_dir)

    with workdir_cm as workdir:
        cache_dir = os.path.join(workdir, cache_tag)
        os.makedirs(cache_dir, exist_ok=True)
        main_demo(
            cache_dir,
            model,
            args.device,
            args.image_size,
            server_name,
            args.server_port,
            silent=args.silent,
            share=args.share,
            gradio_delete_cache=args.gradio_delete_cache,
            checkpoint_dir=args.checkpoint_dir,
        )


# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()