"""Sapiens2 pointmap Gradio Space. Image → per-pixel 3D pointmap (camera frame, metric units). Right pane is an interactive 3D point-cloud viewer rendering a `.glb` exported via trimesh (MoGe-2's approach — much faster than Open3D's `.ply` for Three.js viewers). All work happens at the model's NATIVE resolution. We additionally cap the input image to height=1024 before processing so 4K uploads don't blow up downstream sizes. """ import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import tempfile import time as _t import cv2 import gradio as gr import numpy as np import spaces import torch import torch.nn.functional as F import trimesh from PIL import Image from torchvision import transforms from huggingface_hub import hf_hub_download from sapiens.dense.models import PointmapEstimator, init_model # registers in registry _ = PointmapEstimator # ----------------------------------------------------------------------------- # Config ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs") POINTMAP_MODELS = { "0.4B": { "repo": "facebook/sapiens2-pointmap-0.4b", "filename": "sapiens2_0.4b_pointmap.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_pointmap_render_people-1024x768.py"), }, "0.8B": { "repo": "facebook/sapiens2-pointmap-0.8b", "filename": "sapiens2_0.8b_pointmap.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_pointmap_render_people-1024x768.py"), }, "1B": { "repo": "facebook/sapiens2-pointmap-1b", "filename": "sapiens2_1b_pointmap.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_pointmap_render_people-1024x768.py"), }, "5B": { "repo": "facebook/sapiens2-pointmap-5b", "filename": "sapiens2_5b_pointmap.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_pointmap_render_people-1024x768.py"), }, } DEFAULT_SIZE = "1B" FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript" FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MAX_HEIGHT = 1024 # cap input height before processing — keeps everything fast _fg_transform = transforms.Compose([ transforms.Resize((1024, 768)), transforms.ToTensor(), transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255], std=[58.5 / 255, 57.0 / 255, 57.5 / 255]), ]) # ----------------------------------------------------------------------------- # Model cache _pointmap_model_cache: dict = {} _fg_model = None def _get_pointmap_model(size: str): if size not in _pointmap_model_cache: spec = POINTMAP_MODELS[size] ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"]) model = init_model(spec["config"], ckpt, device=DEVICE) _pointmap_model_cache[size] = model return _pointmap_model_cache[size] def _get_fg_model(): global _fg_model if _fg_model is None: ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME) _fg_model = torch.jit.load(ckpt).eval().to(DEVICE) return _fg_model print("[startup] pre-loading all pointmap sizes + fg/bg ...") for _size in POINTMAP_MODELS: _get_pointmap_model(_size) _get_fg_model() print("[startup] ready.") # ----------------------------------------------------------------------------- # Helpers def _cap_height(image_pil: Image.Image, max_h: int = MAX_HEIGHT) -> Image.Image: w, h = image_pil.size if h <= max_h: return image_pil new_w = int(round(w * max_h / h)) return image_pil.resize((new_w, max_h), Image.LANCZOS) def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray: data = model.pipeline(dict(img=image_bgr)) data = model.data_preprocessor(data) inputs, data_samples = data["inputs"], data["data_samples"] if inputs.ndim == 3: inputs = inputs.unsqueeze(0) with torch.no_grad(): pointmap, scale = model(inputs) pointmap = pointmap / scale # → metric pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"] pointmap = pointmap[ :, :, pad_top : inputs.shape[2] - pad_bottom, pad_left : inputs.shape[3] - pad_right, ] return pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) # (H_native, W_native, 3) def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray: fg = _get_fg_model() inputs = _fg_transform(image_pil).unsqueeze(0).to(DEVICE) with torch.no_grad(): out = fg(inputs) out = F.interpolate(out, size=(target_h, target_w), mode="bilinear", align_corners=False) mask = (out.argmax(dim=1)[0] > 0).cpu().numpy().astype(np.uint8) # Morphological closing fills small false-negative holes inside the figure # (shadows, dark hair, between fingers) so the mesh doesn't show pinpricks. kernel = np.ones((5, 5), np.uint8) mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel) return mask.astype(bool) def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray) -> np.ndarray: """Inverse-depth turbo colormap. Background pixels are left at 0 — caller overlays them.""" valid = np.isfinite(depth) & (depth > 1e-3) & mask rgb = np.zeros((*depth.shape, 3), dtype=np.uint8) if not valid.any(): return rgb inv = np.zeros_like(depth, dtype=np.float32) inv[valid] = 1.0 / depth[valid] p1, p99 = np.percentile(inv[valid], [1, 99]) lo, hi = float(p1), float(p99) if hi <= lo: hi = lo + 1e-3 norm = ((inv - lo) / (hi - lo)).clip(0, 1) grey = (norm * 255.0).astype(np.uint8) color = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1] rgb[valid] = color[valid] return rgb # ----------------------------------------------------------------------------- # Mesh export — MoGe-2's recipe (trimesh → .glb) # # We build a regular grid mesh from the (H, W) pointmap: each valid pixel is a # vertex, adjacent valid pixels form quads → 2 triangles each. The trick that # makes MoGe-2's meshes look clean (no stretched-skin facets at depth jumps, # no ragged silhouette) is what they call `mask_cleaned`: # # mask_cleaned = mask & ~depth_edge(depth, rtol=0.04) # # i.e. drop pixels sitting on a depth discontinuity *before* triangulation, so # no triangle ever spans one. We don't post-filter triangles by edge length. def _depth_edge(depth: np.ndarray, rtol: float = 0.04, kernel_size: int = 3) -> np.ndarray: """NumPy port of `utils3d.numpy.depth_edge` (rtol-only). For each pixel, look at the kernel×kernel window around it; if (max − min)/depth > rtol, mark it as a depth-edge pixel. """ pad = kernel_size // 2 # NaN-pad + nanmax/nanmin = ignore out-of-bounds pixels at image borders # (matches upstream `utils3d.numpy.max_pool_1d`). padded = np.pad(depth.astype(np.float32), pad, mode="constant", constant_values=np.nan) windows = np.lib.stride_tricks.sliding_window_view(padded, (kernel_size, kernel_size)) d_max = np.nanmax(windows, axis=(-2, -1)) d_min = np.nanmin(windows, axis=(-2, -1)) with np.errstate(divide="ignore", invalid="ignore"): rel = (d_max - d_min) / depth return np.nan_to_num(rel, nan=0.0, posinf=0.0, neginf=0.0) > rtol def _make_glb(image_pil_texture: Image.Image, pointmap_hwc: np.ndarray, mask_hw: np.ndarray, rtol: float = 0.04) -> str: """Build a UV-textured triangulated mesh and export to .glb (MoGe-2 recipe).""" H, W = pointmap_hwc.shape[:2] image_native = image_pil_texture.resize((W, H), Image.LANCZOS) z = pointmap_hwc[:, :, 2] valid = mask_hw & np.isfinite(pointmap_hwc).all(axis=2) & (z > 0.05) & (z < 25.0) valid &= ~_depth_edge(z, rtol=rtol) idx_map = np.full((H, W), -1, dtype=np.int64) yy, xx = np.where(valid) idx_map[yy, xx] = np.arange(len(yy)) verts = pointmap_hwc[yy, xx].astype(np.float32) uvs = np.stack([xx / max(W - 1, 1), yy / max(H - 1, 1)], axis=1).astype(np.float32) # Quad order matches upstream `utils3d.numpy.image_mesh`: [TL, BL, BR, TR], # split into triangles as fan from TL → [TL, BL, BR] and [TL, BR, TR]. tl = idx_map[:-1, :-1]; tr = idx_map[:-1, 1:] bl = idx_map[1:, :-1]; br = idx_map[1:, 1:] quad_valid = (tl != -1) & (tr != -1) & (bl != -1) & (br != -1) tl_v, tr_v, bl_v, br_v = tl[quad_valid], tr[quad_valid], bl[quad_valid], br[quad_valid] tri1 = np.stack([tl_v, bl_v, br_v], axis=1) tri2 = np.stack([tl_v, br_v, tr_v], axis=1) faces = np.concatenate([tri1, tri2], axis=0).astype(np.int64) # MoGe-2: y/z flip on positions, v-flip on UVs. verts = verts * np.array([1.0, -1.0, -1.0], dtype=np.float32) centroid = verts.mean(axis=0).astype(np.float32) if len(verts) else np.zeros(3, np.float32) verts = verts - centroid uvs = uvs * np.array([1.0, -1.0], dtype=np.float32) + np.array([0.0, 1.0], dtype=np.float32) material = trimesh.visual.material.PBRMaterial( baseColorTexture=image_native, metallicFactor=0.5, roughnessFactor=1.0, doubleSided=True, ) visual = trimesh.visual.texture.TextureVisuals(uv=uvs, material=material) mesh = trimesh.Trimesh( vertices=verts, faces=faces, vertex_normals=None, # MoGe-2 leaves this to the GLB consumer visual=visual, process=False, ) out_path = tempfile.NamedTemporaryFile(delete=False, suffix=".glb").name mesh.export(out_path) return out_path # ----------------------------------------------------------------------------- # Gradio handler @spaces.GPU(duration=120) def predict(image: Image.Image, size: str): if image is None: return None, None t0 = _t.perf_counter() image_pil_full = image.convert("RGB") # original-res (used as texture) image_pil = _cap_height(image_pil_full) # capped (used for inference) image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) print(f"[time] convert+cap {(_t.perf_counter()-t0)*1000:.0f} ms (input {image_pil.size}, texture {image_pil_full.size})") t = _t.perf_counter() model = _get_pointmap_model(size) pointmap = _estimate_pointmap(image_bgr, model) h_n, w_n = pointmap.shape[:2] print(f"[time] pointmap {(_t.perf_counter()-t)*1000:.0f} ms (native {w_n}x{h_n})") t = _t.perf_counter() mask = _foreground_mask(image_pil, h_n, w_n) print(f"[time] fg mask {(_t.perf_counter()-t)*1000:.0f} ms") t = _t.perf_counter() depth = pointmap[:, :, 2] depth_rgb = _depth_to_rgb(depth, mask) depth_rgb[~mask] = 200 # solid grey background w0, h0 = image_pil.size depth_pil = Image.fromarray(depth_rgb).resize((w0, h0), Image.LANCZOS) print(f"[time] depth+resize {(_t.perf_counter()-t)*1000:.0f} ms") t = _t.perf_counter() glb_path = _make_glb(image_pil_full, pointmap, mask) print(f"[time] glb export {(_t.perf_counter()-t)*1000:.0f} ms") print(f"[time] TOTAL {(_t.perf_counter()-t0)*1000:.0f} ms") return depth_pil, glb_path # ----------------------------------------------------------------------------- # UI EXAMPLES = sorted( os.path.join(ASSETS_DIR, "images", n) for n in os.listdir(os.path.join(ASSETS_DIR, "images")) if n.lower().endswith((".jpg", ".jpeg", ".png")) ) CUSTOM_CSS = """ :root, body, .gradio-container, button, input, select, textarea, .gradio-container *:not(code):not(pre) { font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } #title { text-align: center; font-size: 44px; font-weight: 700; letter-spacing: -0.01em; margin: 28px 0 4px; background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text; } #subtitle { text-align: center; font-size: 12px; color: #64748b; letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase; font-weight: 500; } #badges { display: flex; justify-content: center; flex-wrap: wrap; gap: 8px; margin: 0 0 32px; } .pill { display: inline-flex; align-items: center; gap: 6px; padding: 7px 14px; border-radius: 999px; background: #f1f5f9; color: #0f172a !important; font-size: 13px; font-weight: 500; letter-spacing: 0.01em; text-decoration: none !important; border: 1px solid #e2e8f0; transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; } .pill:hover { background: #0f172a; color: #f8fafc !important; border-color: #0f172a; transform: translateY(-1px); } .pill svg { width: 14px; height: 14px; } """ HEADER_HTML = """