Spaces:
Running on Zero
Running on Zero
| """Sapiens2 pointmap Gradio Space. | |
| Image → per-pixel 3D pointmap (camera frame, metric units). Right pane is an | |
| interactive 3D point-cloud viewer rendering a `.glb` exported via trimesh | |
| (MoGe-2's approach — much faster than Open3D's `.ply` for Three.js viewers). | |
| All work happens at the model's NATIVE resolution. We additionally cap the | |
| input image to height=1024 before processing so 4K uploads don't blow up | |
| downstream sizes. | |
| """ | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| import json | |
| import tempfile | |
| import time as _t | |
| import zipfile | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| import spaces | |
| import torch | |
| import torch.nn.functional as F | |
| import trimesh | |
| from PIL import Image | |
| from torchvision import transforms | |
| from huggingface_hub import hf_hub_download | |
| from sapiens.dense.models import PointmapEstimator, init_model # registers in registry | |
| _ = PointmapEstimator | |
| # ----------------------------------------------------------------------------- | |
| # Config | |
# Static assets (model configs, example images) live next to this file.
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# One entry per selectable model size. Repo, checkpoint filename and config
# path all follow the same naming scheme keyed on the lower-cased size tag.
POINTMAP_MODELS = {
    label: {
        "repo": f"facebook/sapiens2-pointmap-{tag}",
        "filename": f"sapiens2_{tag}_pointmap.safetensors",
        "config": os.path.join(
            CONFIGS_DIR, f"sapiens2_{tag}_pointmap_render_people-1024x768.py"
        ),
    }
    for label in ("0.4B", "0.8B", "1B", "5B")
    for tag in (label.lower(),)
}
DEFAULT_SIZE = "0.4B"

# Foreground-segmentation checkpoint (TorchScript) used to mask out the background.
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

MAX_HEIGHT = 1024  # cap input height before processing — keeps everything fast
VIDEO_MAX_HEIGHT = 512
VIDEO_DEFAULT_FRAMES = 36
VIDEO_DEFAULT_STRIDE = 2

# Preprocessing for the foreground model: fixed 1024x768 resize, then
# normalisation with mean/std given on the 0-255 scale.
_fg_transform = transforms.Compose([
    transforms.Resize((1024, 768)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[channel / 255 for channel in (123.5, 116.5, 103.5)],
        std=[channel / 255 for channel in (58.5, 57.0, 57.5)],
    ),
])
| # ----------------------------------------------------------------------------- | |
| # Model cache | |
# Lazily populated caches: one pointmap model per size label, one shared
# foreground-segmentation model.
_pointmap_model_cache: dict = {}
_fg_model = None


def _get_pointmap_model(size: str):
    """Return the pointmap model for `size`, loading and caching it on first use.

    `size` must be a key of `POINTMAP_MODELS`; an unknown key raises `KeyError`.
    The checkpoint is fetched with `hf_hub_download` and built via `init_model`
    on `DEVICE`.
    """
    try:
        return _pointmap_model_cache[size]
    except KeyError:
        pass  # not loaded yet — fall through and build it

    spec = POINTMAP_MODELS[size]
    checkpoint_path = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
    _pointmap_model_cache[size] = init_model(spec["config"], checkpoint_path, device=DEVICE)
    return _pointmap_model_cache[size]
def _get_fg_model():
    """Return the shared TorchScript foreground model, loading it on first call."""
    global _fg_model
    if _fg_model is not None:
        return _fg_model

    weights_path = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
    scripted = torch.jit.load(weights_path)
    _fg_model = scripted.eval().to(DEVICE)
    return _fg_model


print("[startup] ready; models will load lazily on first request.")
| # ----------------------------------------------------------------------------- | |
| # Helpers | |
def _cap_height(image_pil: Image.Image, max_h: int = MAX_HEIGHT) -> Image.Image:
    """Downscale `image_pil` so its height is at most `max_h`, keeping aspect ratio.

    Images that are already short enough are returned unchanged (same object).
    Taller images are resized with Lanczos resampling to exactly `max_h` rows.
    """
    w, h = image_pil.size
    if h <= max_h:
        return image_pil
    # Clamp to one pixel: an extremely narrow, tall image would otherwise round
    # to a zero width, which `resize` rejects.
    new_w = max(1, int(round(w * max_h / h)))
    return image_pil.resize((new_w, max_h), Image.LANCZOS)
def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
    """Run the pointmap model on one image and return an (H_native, W_native, 3) array.

    The image goes through the model's own `pipeline` and `data_preprocessor`;
    the raw prediction is divided by the predicted scale (→ metric) and the
    padding reported in the sample metadata is cropped away.
    """
    sample = model.data_preprocessor(model.pipeline(dict(img=image_bgr)))
    batch, data_samples = sample["inputs"], sample["data_samples"]
    if batch.ndim == 3:
        # Single unbatched image: add the leading batch axis the model is called with.
        batch = batch.unsqueeze(0)

    with torch.no_grad():
        raw, scale = model(batch)
        metric = raw / scale  # → metric

    left, right, top, bottom = data_samples["meta"]["padding_size"]
    bottom_edge = batch.shape[2] - bottom
    right_edge = batch.shape[3] - right
    cropped = metric[:, :, top:bottom_edge, left:right_edge]

    # Drop the batch axis and move channels last.
    return cropped.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
    """Return a boolean (target_h, target_w) mask that is True where the argmax class is > 0.

    The segmentation output is bilinearly resized to the target size before the
    argmax, then cleaned with a 5x5 morphological closing.
    """
    segmenter = _get_fg_model()
    batch = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = segmenter(batch)
        logits = F.interpolate(
            logits, size=(target_h, target_w), mode="bilinear", align_corners=False
        )
    labels = logits.argmax(dim=1)[0]
    binary = (labels > 0).cpu().numpy().astype(np.uint8)

    # Closing fills small false-negative holes inside the figure (shadows, dark
    # hair, between fingers) so the mesh doesn't show pinpricks.
    structuring = np.ones((5, 5), np.uint8)
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, structuring)
    return closed.astype(bool)
| def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray) -> np.ndarray: | |
| """Inverse-depth turbo colormap. Background pixels are left at 0 — caller overlays them.""" | |
| valid = np.isfinite(depth) & (depth > 1e-3) & mask | |
| rgb = np.zeros((*depth.shape, 3), dtype=np.uint8) | |
| if not valid.any(): | |
| return rgb | |
| inv = np.zeros_like(depth, dtype=np.float32) | |
| inv[valid] = 1.0 / depth[valid] | |
| p1, p99 = np.percentile(inv[valid], [1, 99]) | |
| lo, hi = float(p1), float(p99) | |
| if hi <= lo: | |
| hi = lo + 1e-3 | |
| norm = ((inv - lo) / (hi - lo)).clip(0, 1) | |
| grey = (norm * 255.0).astype(np.uint8) | |
| color = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1] | |
| rgb[valid] = color[valid] | |
| return rgb | |
| # ----------------------------------------------------------------------------- | |
| # Mesh export — MoGe-2's recipe (trimesh → .glb) | |
| # | |
| # We build a regular grid mesh from the (H, W) pointmap: each valid pixel is a | |
| # vertex, adjacent valid pixels form quads → 2 triangles each. The trick that | |
| # makes MoGe-2's meshes look clean (no stretched-skin facets at depth jumps, | |
| # no ragged silhouette) is what they call `mask_cleaned`: | |
| # | |
| # mask_cleaned = mask & ~depth_edge(depth, rtol=0.04) | |
| # | |
| # i.e. drop pixels sitting on a depth discontinuity *before* triangulation, so | |
| # no triangle ever spans one. We don't post-filter triangles by edge length. | |
| def _depth_edge(depth: np.ndarray, rtol: float = 0.04, kernel_size: int = 3) -> np.ndarray: | |
| """NumPy port of `utils3d.numpy.depth_edge` (rtol-only). | |
| For each pixel, look at the kernel×kernel window around it; if | |
| (max − min)/depth > rtol, mark it as a depth-edge pixel. | |
| """ | |
| pad = kernel_size // 2 | |
| # NaN-pad + nanmax/nanmin = ignore out-of-bounds pixels at image borders | |
| # (matches upstream `utils3d.numpy.max_pool_1d`). | |
| padded = np.pad(depth.astype(np.float32), pad, mode="constant", constant_values=np.nan) | |
| windows = np.lib.stride_tricks.sliding_window_view(padded, (kernel_size, kernel_size)) | |
| d_max = np.nanmax(windows, axis=(-2, -1)) | |
| d_min = np.nanmin(windows, axis=(-2, -1)) | |
| with np.errstate(divide="ignore", invalid="ignore"): | |
| rel = (d_max - d_min) / depth | |
| return np.nan_to_num(rel, nan=0.0, posinf=0.0, neginf=0.0) > rtol | |
def _make_glb(image_pil_texture: Image.Image, pointmap_hwc: np.ndarray,
              mask_hw: np.ndarray, rtol: float = 0.04) -> str:
    """Build a UV-textured triangulated mesh and export to .glb (MoGe-2 recipe).

    Args:
        image_pil_texture: image used as the base-colour texture; it is resized
            to the pointmap's (W, H) before being attached.
        pointmap_hwc: (H, W, 3) array of points; channel 2 is used as depth.
        mask_hw: (H, W) boolean mask of pixels allowed to become vertices.
        rtol: relative tolerance forwarded to `_depth_edge`.

    Returns:
        Path of a temporary `.glb` file. It is created with `delete=False`, so
        it is not removed automatically.
    """
    H, W = pointmap_hwc.shape[:2]
    image_native = image_pil_texture.resize((W, H), Image.LANCZOS)

    # A pixel becomes a vertex only if it is foreground, finite, within the
    # accepted depth range, and not sitting on a depth discontinuity.
    z = pointmap_hwc[:, :, 2]
    valid = mask_hw & np.isfinite(pointmap_hwc).all(axis=2) & (z > 0.05) & (z < 25.0)
    valid &= ~_depth_edge(z, rtol=rtol)

    # idx_map[y, x] is the vertex index of pixel (y, x), or -1 if it was dropped.
    idx_map = np.full((H, W), -1, dtype=np.int64)
    yy, xx = np.where(valid)
    idx_map[yy, xx] = np.arange(len(yy))
    verts = pointmap_hwc[yy, xx].astype(np.float32)
    uvs = np.stack([xx / max(W - 1, 1), yy / max(H - 1, 1)], axis=1).astype(np.float32)

    # Quad order matches upstream `utils3d.numpy.image_mesh`: [TL, BL, BR, TR],
    # split into triangles as fan from TL → [TL, BL, BR] and [TL, BR, TR].
    tl = idx_map[:-1, :-1]
    tr = idx_map[:-1, 1:]
    bl = idx_map[1:, :-1]
    br = idx_map[1:, 1:]
    quad_valid = (tl != -1) & (tr != -1) & (bl != -1) & (br != -1)
    tl_v, tr_v, bl_v, br_v = tl[quad_valid], tr[quad_valid], bl[quad_valid], br[quad_valid]
    tri1 = np.stack([tl_v, bl_v, br_v], axis=1)
    tri2 = np.stack([tl_v, br_v, tr_v], axis=1)
    faces = np.concatenate([tri1, tri2], axis=0).astype(np.int64)

    # MoGe-2: y/z flip on positions, v-flip on UVs.
    verts = verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)
    # Centre the mesh on its centroid (origin when there are no vertices).
    centroid = verts.mean(axis=0).astype(np.float32) if len(verts) else np.zeros(3, np.float32)
    verts = verts - centroid
    uvs = uvs * np.array([1.0, -1.0], dtype=np.float32) + np.array([0.0, 1.0], dtype=np.float32)

    material = trimesh.visual.material.PBRMaterial(
        baseColorTexture=image_native,
        metallicFactor=0.5,
        roughnessFactor=1.0,
        doubleSided=True,
    )
    visual = trimesh.visual.texture.TextureVisuals(uv=uvs, material=material)
    mesh = trimesh.Trimesh(
        vertices=verts,
        faces=faces,
        vertex_normals=None,  # MoGe-2 leaves this to the GLB consumer
        visual=visual,
        process=False,
    )

    # Reserve a unique path and close the handle straight away: leaving it open
    # leaks a file descriptor per request and blocks re-opening the path on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp:
        out_path = tmp.name
    mesh.export(out_path)
    return out_path
def _sample_video_frames(video_path: str, max_frames: int) -> tuple[list[Image.Image], float, int, int]:
    """Read up to `max_frames` (clamped to 1..120) evenly spaced RGB frames from a video.

    When the container reports a frame count, frames are picked at evenly spaced
    indices by seeking; otherwise frames are read sequentially from the start.

    Returns:
        `(frames, fps, source_w, source_h)`; `fps` falls back to 24.0 when the
        container reports none.

    Raises:
        gr.Error: if the video cannot be opened or no frame can be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    # try/finally guarantees the capture is released on every exit path,
    # including the "could not open" error and any decode/convert failure.
    try:
        if not cap.isOpened():
            raise gr.Error("Could not open the uploaded video.")
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0) or 24.0
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT) or 0)
        frame_limit = int(max(1, min(max_frames, 120)))
        if total_frames > 0:
            indices = np.linspace(
                0, max(total_frames - 1, 0), min(frame_limit, total_frames), dtype=np.int32
            )
        else:
            # Unknown length: read sequentially until the limit or end of stream.
            indices = np.arange(frame_limit, dtype=np.int32)
        frames: list[Image.Image] = []
        source_w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH) or 0)
        source_h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT) or 0)
        for index in indices:
            if total_frames > 0:
                cap.set(cv2.CAP_PROP_POS_FRAMES, int(index))
            ok, frame_bgr = cap.read()
            if not ok:
                if total_frames <= 0:
                    break  # sequential read hit end of stream
                continue  # seek target unreadable — skip it, try the next index
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            frames.append(Image.fromarray(frame_rgb))
            if total_frames <= 0 and len(frames) >= frame_limit:
                break
    finally:
        cap.release()
    if not frames:
        raise gr.Error("No readable frames were found in the uploaded video.")
    return frames, fps, source_w, source_h
def _sample_grid(mask: np.ndarray, pointmap: np.ndarray, stride: int) -> tuple[np.ndarray, np.ndarray, np.ndarray, int, int]:
    """Lay a regular grid (every `stride` pixels, clamped to 1..16) over the image.

    Returns:
        `(ys, xs, valid, rows, cols)`: flattened int32 row/column coordinates of
        the grid points (row-major), a uint8 validity flag per grid point, and
        the grid dimensions.
    """
    step = int(max(1, min(stride, 16)))

    depth = pointmap[:, :, 2]
    finite = np.isfinite(pointmap).all(axis=2)
    in_range = (depth > 0.05) & (depth < 25.0)
    on_edge = _depth_edge(depth, rtol=0.04)
    full_valid = mask & finite & in_range & ~on_edge

    row_coords = np.arange(0, mask.shape[0], step, dtype=np.int32)
    col_coords = np.arange(0, mask.shape[1], step, dtype=np.int32)
    grid_y, grid_x = np.meshgrid(row_coords, col_coords, indexing="ij")

    flat_y = grid_y.reshape(-1).astype(np.int32)
    flat_x = grid_x.reshape(-1).astype(np.int32)
    flat_valid = full_valid[grid_y, grid_x].reshape(-1).astype(np.uint8)
    return flat_y, flat_x, flat_valid, int(len(row_coords)), int(len(col_coords))
def _point_sequence_zip(
    frames: list[Image.Image],
    size: str,
    max_frames: int,
    point_stride: int,
    fps: float,
    source_w: int,
    source_h: int,
) -> str:
    """Run pointmap inference on a frame sequence and pack the result into a zip.

    The sampling grid is fixed from the first frame and reused for every frame,
    so each frame contributes the same number of points. Per frame, points are
    centred on the centroid of their valid samples and y/z-flipped; invalid
    samples are zeroed in both positions and colours.

    The archive holds `metadata.json` plus raw C-order buffers
    `positions_f32.bin`, `colors_u8.bin` and `valid_u8.bin`.

    Args:
        frames: decoded frames; at most `max_frames` (minimum 1) are used.
        size: model size key understood by `_get_pointmap_model`.
        max_frames: upper bound on the number of frames processed.
        point_stride: grid stride forwarded to `_sample_grid`.
        fps, source_w, source_h: source video properties recorded in the metadata.

    Returns:
        Path of the temporary `.zip` (created with `delete=False`).

    Raises:
        gr.Error: if no frame was processed.
    """
    model = _get_pointmap_model(size)
    sampled = frames[: int(max(1, min(max_frames, len(frames))))]
    # Loop-invariant: x stays, y and z are negated.
    axis_flip = np.array([1.0, -1.0, -1.0], dtype=np.float32)
    positions_frames: list[np.ndarray] = []
    colors_frames: list[np.ndarray] = []
    valid_frames: list[np.ndarray] = []
    sample_y: np.ndarray | None = None
    sample_x: np.ndarray | None = None
    grid_rows = 0
    grid_cols = 0
    native_w = 0
    native_h = 0
    for frame_index, frame in enumerate(sampled):
        t = _t.perf_counter()
        image_pil = _cap_height(frame.convert("RGB"), VIDEO_MAX_HEIGHT)
        image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
        pointmap = _estimate_pointmap(image_bgr, model)
        native_h, native_w = pointmap.shape[:2]
        mask = _foreground_mask(image_pil, native_h, native_w)
        if sample_y is None or sample_x is None:
            # Grid coordinates come from the first frame only; its validity
            # flags are discarded because validity is recomputed per frame below.
            sample_y, sample_x, _, grid_rows, grid_cols = _sample_grid(mask, pointmap, point_stride)
            print(f"[video] sampled {grid_cols}x{grid_rows} grid at native {native_w}x{native_h}")
        image_native = np.array(image_pil.resize((native_w, native_h), Image.LANCZOS))
        points = pointmap[sample_y, sample_x].astype(np.float32)
        finite = np.isfinite(points).all(axis=1) & (points[:, 2] > 0.05) & (points[:, 2] < 25.0)
        depth_edges = _depth_edge(pointmap[:, :, 2], rtol=0.04)
        frame_valid = mask[sample_y, sample_x] & ~depth_edges[sample_y, sample_x]
        finite &= frame_valid.astype(bool)
        if finite.any():
            centroid = points[finite].mean(axis=0).astype(np.float32)
        else:
            centroid = np.zeros(3, dtype=np.float32)
        points = (points - centroid) * axis_flip
        points[~finite] = 0
        colors = image_native[sample_y, sample_x, :3].astype(np.uint8)
        colors[~finite] = 0
        positions_frames.append(points)
        colors_frames.append(colors)
        valid_frames.append(finite.astype(np.uint8))
        print(f"[video] frame {frame_index + 1}/{len(sampled)} {int(finite.sum())} valid points {(_t.perf_counter() - t) * 1000:.0f} ms")
    if not positions_frames:
        raise gr.Error("Pointmap inference did not produce any frames.")
    positions = np.stack(positions_frames, axis=0).astype(np.float32)
    colors = np.stack(colors_frames, axis=0).astype(np.uint8)
    valid = np.stack(valid_frames, axis=0).astype(np.uint8)
    metadata = {
        "format": "fpbox-sapiens-pointmap-sequence-v5",
        "model": f"sapiens2-pointmap-{size}",
        "frameCount": int(positions.shape[0]),
        "fps": float(min(fps, max(1, positions.shape[0]))),
        "pointCount": int(positions.shape[1]),
        "gridRows": int(grid_rows),
        "gridCols": int(grid_cols),
        "width": int(native_w),
        "height": int(native_h),
        "sourceWidth": int(source_w),
        "sourceHeight": int(source_h),
        "coordinateSystem": "x, -y, -z, centered per frame, regular sampled image grid, per-frame valid mask",
        "dtype": {"positions": "float32", "colors": "uint8", "valid": "uint8"},
    }
    # Reserve a unique path and close the handle before ZipFile reopens it:
    # the unclosed handle leaked a descriptor per request and blocks the
    # reopen on Windows.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp:
        out_path = tmp.name
    with zipfile.ZipFile(out_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        zf.writestr("metadata.json", json.dumps(metadata, indent=2))
        zf.writestr("positions_f32.bin", positions.tobytes(order="C"))
        zf.writestr("colors_u8.bin", colors.tobytes(order="C"))
        zf.writestr("valid_u8.bin", valid.tobytes(order="C"))
    return out_path
| # ----------------------------------------------------------------------------- | |
| # Gradio handler | |
def predict(image: Image.Image, size: str):
    """Gradio handler for the Image tab: image → (depth heatmap, path to .glb mesh).

    Inference runs on the height-capped image; the original-resolution image is
    kept for the mesh texture. Returns `(None, None)` when no image was given.
    """
    # NOTE(review): `spaces` is imported at the top of the file but no
    # `@spaces.GPU` decorator is visible on this handler — confirm whether
    # ZeroGPU scheduling is expected here.
    if image is None:
        return None, None

    def _ms_since(start: float) -> float:
        return (_t.perf_counter() - start) * 1000

    started = _t.perf_counter()
    texture_rgb = image.convert("RGB")  # original-res (used as texture)
    capped_rgb = _cap_height(texture_rgb)  # capped (used for inference)
    capped_bgr = cv2.cvtColor(np.array(capped_rgb), cv2.COLOR_RGB2BGR)
    print(f"[time] convert+cap {_ms_since(started):.0f} ms (input {capped_rgb.size}, texture {texture_rgb.size})")

    stage = _t.perf_counter()
    pointmap = _estimate_pointmap(capped_bgr, _get_pointmap_model(size))
    native_h, native_w = pointmap.shape[:2]
    print(f"[time] pointmap {_ms_since(stage):.0f} ms (native {native_w}x{native_h})")

    stage = _t.perf_counter()
    fg = _foreground_mask(capped_rgb, native_h, native_w)
    print(f"[time] fg mask {_ms_since(stage):.0f} ms")

    stage = _t.perf_counter()
    heat = _depth_to_rgb(pointmap[:, :, 2], fg)
    heat[~fg] = 200  # solid grey background
    # Present the heatmap at the capped input size rather than the native one.
    depth_pil = Image.fromarray(heat).resize(capped_rgb.size, Image.LANCZOS)
    print(f"[time] depth+resize {_ms_since(stage):.0f} ms")

    stage = _t.perf_counter()
    glb_path = _make_glb(texture_rgb, pointmap, fg)
    print(f"[time] glb export {_ms_since(stage):.0f} ms")

    print(f"[time] TOTAL {_ms_since(started):.0f} ms")
    return depth_pil, glb_path
def predict_video(video_path: str, size: str, max_frames: int, point_stride: int):
    """Gradio handler for the Video tab: video → path of a pointmap-sequence zip.

    Returns `None` when no video was supplied.
    """
    if video_path is None:
        return None

    started = _t.perf_counter()
    frames, fps, source_w, source_h = _sample_video_frames(video_path, max_frames)
    print(f"[video] sampled {len(frames)} frames from {source_w}x{source_h} video at {fps:.2f} fps")

    archive_path = _point_sequence_zip(
        frames, size, max_frames, point_stride, fps, source_w, source_h
    )
    elapsed_ms = (_t.perf_counter() - started) * 1000
    print(f"[video] TOTAL {elapsed_ms:.0f} ms")
    return archive_path
| # ----------------------------------------------------------------------------- | |
| # UI | |
# Example images shown under the controls: every .jpg/.jpeg/.png file in
# assets/images, in sorted order. All paths share one directory prefix, so
# sorting the file names gives the same order as sorting the joined paths.
EXAMPLES = [
    os.path.join(ASSETS_DIR, "images", entry)
    for entry in sorted(os.listdir(os.path.join(ASSETS_DIR, "images")))
    if entry.lower().endswith((".jpg", ".jpeg", ".png"))
]
# Stylesheet passed to `gr.Blocks(css=...)`: a global sans-serif font stack plus
# styling for the header elements (#title, #subtitle, #badges, .pill) that
# HEADER_HTML renders.
CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
#title { text-align: center; font-size: 44px; font-weight: 700;
letter-spacing: -0.01em; margin: 28px 0 4px;
background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
background-clip: text; }
#subtitle { text-align: center; font-size: 12px; color: #64748b;
letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
font-weight: 500; }
#badges { display: flex; justify-content: center; flex-wrap: wrap;
gap: 8px; margin: 0 0 32px; }
.pill { display: inline-flex; align-items: center; gap: 6px;
padding: 7px 14px; border-radius: 999px;
background: #f1f5f9; color: #0f172a !important;
font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
text-decoration: none !important; border: 1px solid #e2e8f0;
transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
.pill:hover { background: #0f172a; color: #f8fafc !important;
border-color: #0f172a; transform: translateY(-1px); }
.pill svg { width: 14px; height: 14px; }
"""
# Static page header rendered through `gr.HTML`: title, subtitle and a row of
# "pill" links (Code, Models, Paper, Project). The ids and classes used here
# are the ones styled by CUSTOM_CSS.
HEADER_HTML = """
<div id="title">Sapiens2: Pointmap</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
<a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
Code
</a>
<a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
🤗 Models
</a>
<a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
Paper
</a>
<a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
Project
</a>
</div>
"""
# Build the Gradio app. Components are laid out in the order they are created
# inside each context manager, so statement order here defines the page layout.
# NOTE(review): nesting below is reconstructed from the layout comments —
# confirm against the rendered UI.
with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)
    with gr.Tabs():
        with gr.Tab("Image"):
            # Row 1: input ↔ 3D mesh, equal height
            with gr.Row(equal_height=True):
                inp = gr.Image(label="Input", type="pil", height=640, scale=2)
                out_glb = gr.Model3D(
                    label="Pointmap",
                    height=640,
                    clear_color=[0.97, 0.97, 0.97, 1.0],  # cinematic studio white
                    camera_position=(35, 70, 1.6),  # closer, since scene is centered on the human
                    zoom_speed=0.7,
                    pan_speed=0.5,
                    scale=3,
                )
            # Row 2: controls (with examples below them) on the left | depth heatmap on the right.
            with gr.Row():
                with gr.Column(scale=2, min_width=320):
                    size = gr.Radio(
                        choices=list(POINTMAP_MODELS.keys()),
                        value=DEFAULT_SIZE,
                        label="Model",
                        container=False,
                    )
                    run = gr.Button("Run", variant="primary")
                    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
                out_depth = gr.Image(label="Depth (Z)", type="pil", height=640, scale=3)
            # `predict` returns (depth image, glb path), matching the output order below.
            run.click(predict, inputs=[inp, size], outputs=[out_depth, out_glb])
        with gr.Tab("Video"):
            # Row 1: uploaded video ↔ downloadable result archive.
            with gr.Row():
                video_inp = gr.Video(label="Input Video", height=420)
                sequence_zip = gr.File(label="Pointmap Sequence (.zip)")
            # Row 2: model choice and sampling controls. The slider bounds match
            # the clamps applied in `_sample_video_frames` (1..120 frames) and
            # `_sample_grid` (stride 1..16).
            with gr.Row():
                video_size = gr.Radio(
                    choices=list(POINTMAP_MODELS.keys()),
                    value=DEFAULT_SIZE,
                    label="Model",
                    container=False,
                )
                video_frames = gr.Slider(
                    minimum=1,
                    maximum=120,
                    step=1,
                    value=VIDEO_DEFAULT_FRAMES,
                    label="Sampled Frames",
                )
                video_stride = gr.Slider(
                    minimum=1,
                    maximum=16,
                    step=1,
                    value=VIDEO_DEFAULT_STRIDE,
                    label="Point Stride",
                )
            run_video = gr.Button("Run Video Pointmap", variant="primary")
            # Registered under an explicit API name ("video_pointmap").
            run_video.click(
                predict_video,
                inputs=[video_inp, video_size, video_frames, video_stride],
                outputs=[sequence_zip],
                api_name="video_pointmap",
            )
if __name__ == "__main__":
    # Allow TF32 in both CUDA matmul and cuDNN, but only when a CUDA device is present.
    if torch.cuda.is_available():
        for _backend in (torch.backends.cuda.matmul, torch.backends.cudnn):
            _backend.allow_tf32 = True
    demo.launch(share=False)