Spaces:
Running on Zero
Running on Zero
| import torch | |
| import mast3r.utils.path_to_dust3r # noqa | |
| from dust3r.model import AsymmetricCroCo3DStereo | |
| from dust3r.image_pairs import make_pairs | |
| from dust3r.inference import loss_of_one_batch | |
| from dust3r.utils.device import to_cpu, collate_with_cat | |
| from mast3r.model import AsymmetricMASt3R | |
# Module-level positive infinity. `Sab3r.__init__` passes its `model_config`
# string to `eval`, which resolves free names against this module's globals.
# NOTE(review): presumably kept so config strings containing `inf` / `-inf`
# can be evaluated; confirm before removing.
inf = float("inf")
class RES:
    """Read-only accessors over a two-view prediction dictionary.

    ``output`` maps ``'pred1'`` and ``'pred2'`` (presumably the predictions for
    the first and second input view; confirm against the producer) to dicts of
    tensors keyed by head name: ``'pts3d'``, ``'conf'``, ``'clip'``, ``'dino'``.
    Every getter detaches the tensor from autograd and returns a CPU NumPy array.
    """

    def __init__(self, output):
        self.output = output

    def _view_pair(self, head):
        """Return ``head`` from ``pred1`` and ``pred2`` as a 2-tuple of arrays.

        ``squeeze(0)`` drops dim 0 of each tensor when that dim has size 1.
        """
        arrays = []
        for view in ('pred1', 'pred2'):
            tensor = self.output[view][head].squeeze(0)
            arrays.append(tensor.detach().cpu().numpy())
        return tuple(arrays)

    def get_depth(self):
        """Return ``pred1['pts3d'][0, :, :, 2]`` as a NumPy array.

        NOTE(review): presumably the z (depth) channel of a
        (batch, H, W, 3) point map; confirm.
        """
        points = self.output['pred1']['pts3d']
        z_channel = points[0, :, :, 2]
        return z_channel.detach().cpu().numpy()

    def get_conf(self):
        """Return the ``'conf'`` maps of both views as a tuple of arrays."""
        return self._view_pair('conf')

    def get_clip(self):
        """Return the ``'clip'`` outputs of both views as a tuple of arrays."""
        return self._view_pair('clip')

    def get_dino(self):
        """Return the ``'dino'`` outputs of both views as a tuple of arrays."""
        return self._view_pair('dino')
class dust3r():
    """Inference wrapper around a pretrained ``AsymmetricCroCo3DStereo`` model.

    Args:
        model_name: identifier passed to
            ``AsymmetricCroCo3DStereo.from_pretrained``.
        device: device the model is moved to and batches are run on.
    """

    def __init__(self, model_name="naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt", device="cuda"):
        self.model = AsymmetricCroCo3DStereo.from_pretrained(model_name).to(device)
        # `predict` reads `self.device`; previously only the misspelled
        # `self.devide` was set, so every prediction raised AttributeError.
        self.device = device
        # Misspelled alias kept so existing external readers keep working.
        self.devide = device

    def predict(self, images):
        """Run the model on a single pair of views.

        Args:
            images: sequence of two views; they are packed into one tuple and
                collated into a single batch by ``collate_with_cat``.

        Returns:
            ``RES`` wrapping the CPU copy of the ``loss_of_one_batch`` output.
        """
        batch = collate_with_cat([tuple(images)])
        # Inference only: skip building an autograd graph. `RES` detaches every
        # tensor it returns, so gradients were never usable through it.
        with torch.no_grad():
            res = loss_of_one_batch(batch, self.model, None, self.device)
        return RES(to_cpu(res))
class mast3r():
    """Inference wrapper around a pretrained ``AsymmetricMASt3R`` model.

    Args:
        model_name: identifier passed to ``AsymmetricMASt3R.from_pretrained``.
        device: device the model is moved to and batches are run on.
    """

    def __init__(self, model_name="naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric", device="cuda"):
        self.model = AsymmetricMASt3R.from_pretrained(model_name).to(device)
        # `predict` reads `self.device`; previously only the misspelled
        # `self.devide` was set, so every prediction raised AttributeError.
        self.device = device
        # Misspelled alias kept so existing external readers keep working.
        self.devide = device

    def predict(self, images):
        """Run the model on a single pair of views.

        Args:
            images: sequence of two views; they are packed into one tuple and
                collated into a single batch by ``collate_with_cat``.

        Returns:
            ``RES`` wrapping the CPU copy of the ``loss_of_one_batch`` output.
        """
        batch = collate_with_cat([tuple(images)])
        # Inference only: skip building an autograd graph. `RES` detaches every
        # tensor it returns, so gradients were never usable through it.
        with torch.no_grad():
            res = loss_of_one_batch(batch, self.model, None, self.device)
        return RES(to_cpu(res))
class Sab3r():
    """Inference wrapper that builds a model from a constructor string and loads weights.

    Args:
        model_config: Python expression that constructs the network. It is
            rewritten by ``_enable_mast3r`` and then passed to ``eval``.
        model_path: checkpoint path (``str`` or path-like) ending in ``.pth``
            or ``.pt``.
        device: device the model is moved to and batches are run on.

    Raises:
        ValueError: if ``model_path`` has neither a ``.pth`` nor ``.pt`` suffix.
    """

    def __init__(self, model_config, model_path, device="cuda"):
        self.device = device
        model_config = self._enable_mast3r(model_config)
        # SECURITY: `eval` executes arbitrary code -- only pass trusted config
        # strings. It stays in this method so the string resolves names
        # against this module's globals (model classes, `inf`) as before.
        self.model = eval(model_config)
        self.model = self._load_model(self.model, model_path, device)

    @staticmethod
    def _enable_mast3r(args):
        """Rewrite a model-constructor string before it is evaluated.

        Replaces ``ManyAR_PatchEmbed`` with ``PatchEmbedDust3R`` and forces
        ``landscape_only=False``, appending it as a final keyword argument
        when the string does not mention ``landscape_only`` at all.
        """
        args = args.replace("ManyAR_PatchEmbed", "PatchEmbedDust3R")
        if 'landscape_only' not in args:
            # Strip trailing whitespace first so the slice removes the closing
            # ')' rather than e.g. a newline read from a config file.
            args = args.rstrip()
            args = args[:-1] + ', landscape_only=False)'
        else:
            # Spaces are removed so 'landscape_only = True' also matches.
            args = args.replace(" ", "").replace('landscape_only=True', 'landscape_only=False')
        return args

    @staticmethod
    def _load_model(model, ckpt_path, device):
        """Load checkpoint weights into ``model`` and move it to ``device``.

        ``.pth`` checkpoints load ``ckpt['model']`` with ``strict=False``
        (missing/unexpected keys are tolerated); ``.pt`` checkpoints load
        ``ckpt['module']`` strictly.

        Returns:
            The model after ``.to(device)``.

        Raises:
            ValueError: for any other file suffix (checked before reading).
        """
        path = str(ckpt_path)  # accept pathlib.Path as well as str
        if path.endswith('.pth'):
            state_key, strict = 'model', False
        elif path.endswith('.pt'):
            state_key, strict = 'module', True
        else:
            raise ValueError(f"Unknown checkpoint format: {ckpt_path}")
        # NOTE(review): torch.load unpickles the file -- load trusted
        # checkpoints only. Newer torch releases default to weights_only=True,
        # which may reject checkpoints that pickle non-tensor objects; confirm.
        ckpt = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(ckpt[state_key], strict=strict)
        return model.to(device)

    def predict(self, images):
        """Run the model on a single pair of views.

        Args:
            images: sequence of two views; they are packed into one tuple and
                collated into a single batch by ``collate_with_cat``.

        Returns:
            ``RES`` wrapping the CPU copy of the ``loss_of_one_batch`` output.
        """
        batch = collate_with_cat([tuple(images)])
        # Inference only: skip building an autograd graph (RES detaches anyway).
        with torch.no_grad():
            res = loss_of_one_batch(batch, self.model, None, self.device)
        return RES(to_cpu(res))