SAB3R / demo.py
Xuweiyi's picture
Reorganize: demo lives in demo/ in the GitHub repo; at Space root we keep app.py+demo.py next to mast3r/ and dust3r/
773ae1a verified
#!/usr/bin/env python3
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# SAB3R gradio demo executable
# --------------------------------------------------------
import os
import sys
import argparse
import tempfile
from contextlib import nullcontext
# Make the sibling mast3r/ and dust3r/ packages importable when this script is
# launched as `python demo/demo.py` from the repository root: the directory two
# levels above this file is prepended to sys.path, unless it is already listed.
# NOTE(review): when demo.py sits directly next to mast3r/ (the Space-root
# layout), this adds that directory's *parent* instead; imports then rely on
# the script's own directory already being on sys.path.
_REPO_ROOT = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)
if _REPO_ROOT not in sys.path:
    sys.path.insert(0, _REPO_ROOT)
import torch
from huggingface_hub import hf_hub_download
from mast3r.demo import get_args_parser as sab3r_get_args_parser, main_demo
from mast3r.model import AsymmetricMASt3R # noqa: F401 (referenced via eval() below)
from mast3r.utils.misc import hash_md5
import mast3r.utils.path_to_dust3r # noqa: F401 (side-effect: puts vendored dust3r on sys.path)
from dust3r.demo import set_print_with_timestamp
import matplotlib.pyplot as pl
# Checkpoint location used when --weights is not given and the SAB3R_MODEL_REPO /
# SAB3R_CKPT_FILENAME environment variables are unset (see get_args_parser).
DEFAULT_MODEL_REPO = "uva-cv-lab/SAB3R"
DEFAULT_CKPT_FILENAME = "demo_ckpt/base/base.pt"

# `inf` is deliberately a module global: the constructor expression returned by
# build_model_config() contains bare `-inf` / `inf`, and main() eval()s that
# string in this module's namespace.
inf = float("inf")

# Process-wide runtime switches applied at import time.
torch.backends.cuda.matmul.allow_tf32 = True  # for gpu >= Ampere and pytorch >= 1.12
pl.ion()  # matplotlib interactive mode
def get_args_parser():
    """Return the SAB3R demo CLI parser.

    Starts from ``mast3r.demo.get_args_parser`` and adds three options that
    select where the checkpoint comes from. Each option's default can be
    overridden through an environment variable:

    * ``--model_repo``     <- ``SAB3R_MODEL_REPO`` (else ``DEFAULT_MODEL_REPO``)
    * ``--ckpt_filename``  <- ``SAB3R_CKPT_FILENAME`` (else ``DEFAULT_CKPT_FILENAME``)
    * ``--checkpoint_dir`` <- ``SAB3R_CHECKPOINT_DIR`` (else ``None``)
    """
    parser = sab3r_get_args_parser()
    env = os.environ
    # (flag, default, help) triples, registered in this order.
    extra_options = (
        (
            "--model_repo",
            env.get("SAB3R_MODEL_REPO", DEFAULT_MODEL_REPO),
            "Hugging Face Hub repo id hosting the SAB3R checkpoint "
            "(used only when --weights is not provided).",
        ),
        (
            "--ckpt_filename",
            env.get("SAB3R_CKPT_FILENAME", DEFAULT_CKPT_FILENAME),
            "Checkpoint filename inside --model_repo.",
        ),
        (
            "--checkpoint_dir",
            env.get("SAB3R_CHECKPOINT_DIR"),
            "Optional local directory containing one sub-directory per "
            "checkpoint (each sub-dir must hold `<name>.pt`). When provided, "
            "the UI exposes a dropdown to switch between them. Useful for "
            "local dev; leave unset for single-checkpoint HF Spaces deployments.",
        ),
    )
    for flag, default_value, help_text in extra_options:
        parser.add_argument(flag, default=default_value, help=help_text)
    return parser
def load_weights(model, ckp_path, device):
    """Load a SAB3R checkpoint into ``model`` in place, then move it to ``device``.

    The checkpoint layout is selected from the file extension (case-insensitive):

    * ``.pth`` -- state dict stored under the ``'model'`` key, loaded with
      ``strict=False`` (missing/unexpected keys are tolerated and not reported).
    * ``.pt``  -- state dict stored under the ``'module'`` key, loaded with
      ``strict=True``.

    Args:
        model: ``torch.nn.Module`` whose parameters are overwritten.
        ckp_path: path to the checkpoint file (``str`` or path-like).
        device: anything accepted by ``model.to`` (e.g. ``'cuda'``, ``'cpu'``).

    Raises:
        ValueError: if the extension is neither ``.pth`` nor ``.pt``. This is
            checked before the file is opened, so unsupported files are never
            read or unpickled.
        KeyError: if the checkpoint lacks the key expected for its layout.
    """
    path = os.fspath(ckp_path)
    lowered = path.lower()
    # Decide the layout up front: rejecting an unknown format must not cost a
    # full (possibly multi-GB) read and unpickle of the file.
    if lowered.endswith('.pth'):
        state_key, strict = 'model', False
    elif lowered.endswith('.pt'):
        state_key, strict = 'module', True
    else:
        raise ValueError(f"Unknown checkpoint format: {ckp_path}")

    # map_location='cpu': tensors are materialised in host memory first.
    # NOTE(review): torch.load unpickles the file -- only point this at trusted
    # checkpoints (the repo can be overridden via --model_repo / env vars).
    # Indexing immediately lets any other top-level entries of the checkpoint
    # be garbage-collected.
    state_dict = torch.load(path, map_location='cpu')[state_key]
    model.load_state_dict(state_dict, strict=strict)
    # load_state_dict copied the values into the model's own parameters; drop
    # our reference to the loaded tensors before moving the model.
    del state_dict
    model.to(device)
def build_model_config():
    """Return the Python constructor expression for the SAB3R network.

    The result is a string meant to be passed to ``eval`` in this module, so
    it relies on ``AsymmetricMASt3R`` and ``inf`` being module-level names.
    """
    constructor_kwargs = (
        "pos_embed='RoPE100'",
        "patch_embed_cls='PatchEmbedDust3R'",
        "img_size=(512, 512)",
        "head_type='catmlp+dpt'",
        "output_mode='pts3d+desc24'",
        "clip_head_type='dpt'",
        "dino_head_type='dpt'",
        "depth_mode=('exp', -inf, inf)",
        "conf_mode=('exp', 1, inf)",
        "enc_embed_dim=1024",
        "enc_depth=24",
        "enc_num_heads=16",
        "dec_embed_dim=768",
        "dec_depth=12",
        "dec_num_heads=12",
        "two_confs=True",
        "landscape_only=False",
    )
    return "AsymmetricMASt3R(" + ", ".join(constructor_kwargs) + ")"
def resolve_weights_path(args: argparse.Namespace) -> str:
    """Return a local filesystem path to the checkpoint to load.

    A truthy ``args.weights`` is returned unchanged. Otherwise
    ``args.ckpt_filename`` is fetched from the Hugging Face Hub repo
    ``args.model_repo`` with ``hf_hub_download``, and the local path it
    returns is passed through.
    """
    local_weights = args.weights
    if not local_weights:
        print(f"[sab3r] Downloading checkpoint from HF Hub: {args.model_repo}/{args.ckpt_filename}")
        return hf_hub_download(repo_id=args.model_repo, filename=args.ckpt_filename)
    return local_weights
def _resolve_server_name(args):
    """Return the server name handed to ``main_demo``.

    An explicit ``--server_name`` wins; otherwise ``--local_network`` selects
    ``'0.0.0.0'`` (all interfaces) and the default is ``'127.0.0.1'``.
    """
    if args.server_name is not None:
        return args.server_name
    return '0.0.0.0' if args.local_network else '127.0.0.1'


def _cache_root_context(tmp_dir):
    """Return a context manager yielding the root directory for demo caches.

    With ``tmp_dir=None`` a fresh temporary directory is created and deleted
    on exit. Otherwise ``tmp_dir`` is yielded unchanged and left in place.
    """
    if tmp_dir is None:
        return tempfile.TemporaryDirectory(suffix='_sab3r_gradio_demo')
    return nullcontext(tmp_dir)


def main(argv=None):
    """Parse CLI arguments, load the SAB3R model and launch the Gradio demo.

    Args:
        argv: argument list to parse; ``None`` lets argparse read ``sys.argv``.
    """
    args = get_args_parser().parse_args(argv)
    set_print_with_timestamp()
    server_name = _resolve_server_name(args)

    # Resolve (and if needed download) the checkpoint before instantiating the
    # large network, so a failed download is reported without first paying
    # for model construction.
    ckp_path = resolve_weights_path(args)
    # build_model_config() returns a constructor expression; eval() resolves
    # AsymmetricMASt3R and `inf` from this module's globals.
    model = eval(build_model_config())
    load_weights(model, ckp_path, args.device)
    # The model is built via eval + load_state_dict rather than from_pretrained,
    # so nothing else switches it out of training mode. The demo only runs
    # inference, so put it in eval mode explicitly (dropout, norm layers, ...).
    model.eval()

    # Cache sub-directory keyed by the checkpoint path, so different checkpoints
    # do not share cached results.
    chkpt_tag = hash_md5(ckp_path)
    with _cache_root_context(args.tmp_dir) as tmpdirname:
        cache_path = os.path.join(tmpdirname, chkpt_tag)
        os.makedirs(cache_path, exist_ok=True)
        main_demo(
            cache_path, model, args.device, args.image_size, server_name, args.server_port,
            silent=args.silent, share=args.share, gradio_delete_cache=args.gradio_delete_cache,
            checkpoint_dir=args.checkpoint_dir,
        )
# Script entry point: launch the demo only when executed directly, not on import.
# main() returns None, so the process exits with status 0 on normal completion.
if __name__ == '__main__':
    sys.exit(main())