Rawal Khirodkar
Pointmap: ship-mode — preload all sizes, default to 1B
c507c2e
"""Sapiens2 pointmap Gradio Space.

Image → per-pixel 3D pointmap (camera frame, metric units). Right pane is an
interactive 3D point-cloud viewer rendering a `.glb` exported via trimesh
(MoGe-2's approach — much faster than Open3D's `.ply` for Three.js viewers).

All work happens at the model's NATIVE resolution. We additionally cap the
input image to height=1024 before processing so 4K uploads don't blow up
downstream sizes.
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import tempfile
import time as _t
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
import trimesh
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from sapiens.dense.models import PointmapEstimator, init_model # registers in registry
_ = PointmapEstimator
# -----------------------------------------------------------------------------
# Config
# Bundled assets (model configs, example images) live next to this file.
ASSETS_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "assets",
)
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# One entry per selectable model size. Every size follows the same naming
# scheme (lower-cased size label embedded in the repo id, checkpoint name and
# config file name), so the table is generated rather than spelled out.
POINTMAP_MODELS = {
    label: {
        "repo": f"facebook/sapiens2-pointmap-{label.lower()}",
        "filename": f"sapiens2_{label.lower()}_pointmap.safetensors",
        "config": os.path.join(
            CONFIGS_DIR,
            f"sapiens2_{label.lower()}_pointmap_render_people-1024x768.py",
        ),
    }
    for label in ("0.4B", "0.8B", "1B", "5B")
}
DEFAULT_SIZE = "1B"

# Foreground/background segmenter (TorchScript checkpoint).
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_HEIGHT = 1024  # cap input height before processing — keeps everything fast
# Preprocessing for the foreground segmenter: resize to a fixed 1024x768
# input, convert to a tensor, then normalise per channel. The mean/std values
# are written on a 0-255 scale and divided by 255 here.
_fg_transform = transforms.Compose([
    transforms.Resize((1024, 768)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[value / 255 for value in (123.5, 116.5, 103.5)],
        std=[value / 255 for value in (58.5, 57.0, 57.5)],
    ),
])
# -----------------------------------------------------------------------------
# Model cache
# Loaded pointmap models, keyed by the size label used in POINTMAP_MODELS.
_pointmap_model_cache: dict = {}
# Foreground segmenter; stays None until _get_fg_model() loads it.
_fg_model = None
def _get_pointmap_model(size: str):
    """Return the pointmap model for `size`, loading and caching it on first use.

    `size` must be a key of POINTMAP_MODELS (a KeyError is raised otherwise).
    On a cache miss the checkpoint is fetched with `hf_hub_download` and the
    model is built on DEVICE via `init_model`.
    """
    try:
        return _pointmap_model_cache[size]
    except KeyError:
        pass  # first request for this size: fall through and load it

    spec = POINTMAP_MODELS[size]
    weights_path = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
    _pointmap_model_cache[size] = init_model(spec["config"], weights_path, device=DEVICE)
    return _pointmap_model_cache[size]
def _get_fg_model():
    """Return the foreground segmenter, loading it on first use.

    The TorchScript checkpoint is downloaded with `hf_hub_download`, switched
    to eval mode, moved to DEVICE and kept in the module-level `_fg_model`.
    """
    global _fg_model
    if _fg_model is not None:
        return _fg_model

    weights_path = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
    scripted = torch.jit.load(weights_path)
    _fg_model = scripted.eval().to(DEVICE)
    return _fg_model
# Eagerly load every pointmap size plus the foreground model at import time,
# so requests find them already cached instead of loading on first use.
print("[startup] pre-loading all pointmap sizes + fg/bg ...")
for _size in POINTMAP_MODELS:
    _get_pointmap_model(_size)
_get_fg_model()
print("[startup] ready.")
# -----------------------------------------------------------------------------
# Helpers
def _cap_height(image_pil: Image.Image, max_h: int = MAX_HEIGHT) -> Image.Image:
    """Downscale `image_pil` so its height is at most `max_h`.

    Images that already fit are returned unchanged (same object). Taller
    images are resized with LANCZOS resampling to height `max_h`, keeping the
    aspect ratio.

    Args:
        image_pil: Input image.
        max_h: Maximum allowed height in pixels.

    Returns:
        The original image, or a resized copy of height `max_h`.
    """
    w, h = image_pil.size
    if h <= max_h:
        return image_pil
    # Clamp to 1 px: for a very tall, narrow image the scaled width would
    # otherwise round down to 0, which PIL's resize rejects.
    new_w = max(1, int(round(w * max_h / h)))
    return image_pil.resize((new_w, max_h), Image.LANCZOS)
def _estimate_pointmap(image_bgr: np.ndarray, model) -> np.ndarray:
    """Run `model` on one BGR image and return its pointmap as (H, W, 3).

    The image goes through the model's own `pipeline` and `data_preprocessor`.
    The raw prediction is divided by the predicted scale (→ metric), and the
    padding recorded in the sample meta is cropped off, so the result has the
    model's native, unpadded resolution.
    """
    batch = model.data_preprocessor(model.pipeline(dict(img=image_bgr)))
    inputs = batch["inputs"]
    samples = batch["data_samples"]
    if inputs.ndim == 3:
        # Single unbatched image: add the batch axis the model expects.
        inputs = inputs.unsqueeze(0)

    with torch.no_grad():
        raw_points, scale = model(inputs)
        metric_points = raw_points / scale  # → metric

    # Crop the padding back off the two spatial axes.
    pad_left, pad_right, pad_top, pad_bottom = samples["meta"]["padding_size"]
    row_stop = inputs.shape[2] - pad_bottom
    col_stop = inputs.shape[3] - pad_right
    cropped = metric_points[:, :, pad_top:row_stop, pad_left:col_stop]

    # Channels-first batch of one → (H_native, W_native, 3) float32 array.
    return cropped.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
    """Return a boolean (target_h, target_w) foreground mask for `image_pil`.

    The segmenter output is bilinearly resized to the target size; a pixel is
    foreground when its arg-max class index is greater than 0.
    """
    segmenter = _get_fg_model()
    batch = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = segmenter(batch)
        logits = F.interpolate(
            logits,
            size=(target_h, target_w),
            mode="bilinear",
            align_corners=False,
        )
    labels = logits.argmax(dim=1)[0]
    binary = (labels > 0).cpu().numpy().astype(np.uint8)
    # Morphological closing fills small false-negative holes inside the figure
    # (shadows, dark hair, between fingers) so the mesh doesn't show pinpricks.
    closed = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8))
    return closed.astype(bool)
def _depth_to_rgb(depth: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Colour a depth map with the turbo colormap applied to inverse depth.

    Only pixels that are finite, deeper than 1e-3 and inside `mask` are
    coloured; inverse depth is normalised between its 1st and 99th
    percentiles over those pixels. All other pixels are left at 0 — the
    caller overlays them. Returns a uint8 array of shape depth.shape + (3,).
    """
    valid = np.isfinite(depth) & (depth > 1e-3) & mask
    rgb = np.zeros(depth.shape + (3,), dtype=np.uint8)
    if valid.any():
        inverse = np.zeros_like(depth, dtype=np.float32)
        inverse[valid] = 1.0 / depth[valid]

        lo, hi = (float(q) for q in np.percentile(inverse[valid], [1, 99]))
        if hi <= lo:
            # Degenerate range (e.g. constant depth): avoid dividing by zero.
            hi = lo + 1e-3

        scaled = np.clip((inverse - lo) / (hi - lo), 0, 1)
        grey = (scaled * 255.0).astype(np.uint8)
        # OpenCV returns BGR; reverse the channel axis to get RGB.
        turbo_rgb = cv2.applyColorMap(grey, cv2.COLORMAP_TURBO)[:, :, ::-1]
        rgb[valid] = turbo_rgb[valid]
    return rgb
# -----------------------------------------------------------------------------
# Mesh export — MoGe-2's recipe (trimesh → .glb)
#
# We build a regular grid mesh from the (H, W) pointmap: each valid pixel is a
# vertex, adjacent valid pixels form quads → 2 triangles each. The trick that
# makes MoGe-2's meshes look clean (no stretched-skin facets at depth jumps,
# no ragged silhouette) is what they call `mask_cleaned`:
#
# mask_cleaned = mask & ~depth_edge(depth, rtol=0.04)
#
# i.e. drop pixels sitting on a depth discontinuity *before* triangulation, so
# no triangle ever spans one. We don't post-filter triangles by edge length.
def _depth_edge(depth: np.ndarray, rtol: float = 0.04, kernel_size: int = 3) -> np.ndarray:
    """NumPy port of `utils3d.numpy.depth_edge` (rtol-only).

    For each pixel, look at the kernel×kernel window around it; if
    (max − min)/depth > rtol, mark it as a depth-edge pixel. Returns a
    boolean array with the same shape as `depth`.
    """
    half = kernel_size // 2
    # NaN-pad + nanmax/nanmin = ignore out-of-bounds pixels at image borders
    # (matches upstream `utils3d.numpy.max_pool_1d`).
    bordered = np.pad(
        depth.astype(np.float32),
        half,
        mode="constant",
        constant_values=np.nan,
    )
    patches = np.lib.stride_tricks.sliding_window_view(
        bordered, (kernel_size, kernel_size)
    )
    local_max = np.nanmax(patches, axis=(-2, -1))
    local_min = np.nanmin(patches, axis=(-2, -1))
    with np.errstate(divide="ignore", invalid="ignore"):
        relative = (local_max - local_min) / depth
    # Non-finite ratios (zero or NaN depth) count as "not an edge".
    relative = np.nan_to_num(relative, nan=0.0, posinf=0.0, neginf=0.0)
    return relative > rtol
def _make_glb(image_pil_texture: Image.Image, pointmap_hwc: np.ndarray,
              mask_hw: np.ndarray, rtol: float = 0.04) -> str:
    """Build a UV-textured triangulated mesh and export to .glb (MoGe-2 recipe).

    Args:
        image_pil_texture: Image used as the mesh texture; it is resized to
            the pointmap's (W, H).
        pointmap_hwc: (H, W, 3) array of per-pixel XYZ points.
        mask_hw: (H, W) mask of pixels to keep (expected to be boolean).
        rtol: Relative tolerance passed to `_depth_edge`; pixels flagged as
            depth edges are dropped before triangulation.

    Returns:
        Path of the written `.glb` file. The file is created with
        `delete=False`, so it persists after this function returns.
    """
    H, W = pointmap_hwc.shape[:2]
    image_native = image_pil_texture.resize((W, H), Image.LANCZOS)

    # Keep finite points in a plausible depth range, minus depth-edge pixels.
    z = pointmap_hwc[:, :, 2]
    valid = mask_hw & np.isfinite(pointmap_hwc).all(axis=2) & (z > 0.05) & (z < 25.0)
    valid &= ~_depth_edge(z, rtol=rtol)

    # idx_map maps each pixel to its vertex index, or -1 when it is dropped.
    idx_map = np.full((H, W), -1, dtype=np.int64)
    yy, xx = np.where(valid)
    idx_map[yy, xx] = np.arange(len(yy))
    verts = pointmap_hwc[yy, xx].astype(np.float32)
    uvs = np.stack([xx / max(W - 1, 1), yy / max(H - 1, 1)], axis=1).astype(np.float32)

    # Quad order matches upstream `utils3d.numpy.image_mesh`: [TL, BL, BR, TR],
    # split into triangles as fan from TL → [TL, BL, BR] and [TL, BR, TR].
    tl = idx_map[:-1, :-1]
    tr = idx_map[:-1, 1:]
    bl = idx_map[1:, :-1]
    br = idx_map[1:, 1:]
    # A quad is emitted only when all four of its corners are valid vertices.
    quad_valid = (tl != -1) & (tr != -1) & (bl != -1) & (br != -1)
    tl_v, tr_v, bl_v, br_v = tl[quad_valid], tr[quad_valid], bl[quad_valid], br[quad_valid]
    tri1 = np.stack([tl_v, bl_v, br_v], axis=1)
    tri2 = np.stack([tl_v, br_v, tr_v], axis=1)
    faces = np.concatenate([tri1, tri2], axis=0).astype(np.int64)

    # MoGe-2: y/z flip on positions, v-flip on UVs.
    verts = verts * np.array([1.0, -1.0, -1.0], dtype=np.float32)
    # Re-centre on the centroid (guarding the empty-mesh case).
    centroid = verts.mean(axis=0).astype(np.float32) if len(verts) else np.zeros(3, np.float32)
    verts = verts - centroid
    uvs = uvs * np.array([1.0, -1.0], dtype=np.float32) + np.array([0.0, 1.0], dtype=np.float32)

    material = trimesh.visual.material.PBRMaterial(
        baseColorTexture=image_native,
        metallicFactor=0.5,
        roughnessFactor=1.0,
        doubleSided=True,
    )
    visual = trimesh.visual.texture.TextureVisuals(uv=uvs, material=material)
    mesh = trimesh.Trimesh(
        vertices=verts,
        faces=faces,
        vertex_normals=None,  # MoGe-2 leaves this to the GLB consumer
        visual=visual,
        process=False,
    )

    # Reserve a unique path and close the handle before exporting: leaving it
    # open leaks a file descriptor per call and keeps the file open while
    # trimesh writes to the same path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".glb") as tmp:
        out_path = tmp.name
    mesh.export(out_path)
    return out_path
# -----------------------------------------------------------------------------
# Gradio handler
@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str):
    """Gradio handler: image + model size → (depth heatmap, path to .glb mesh).

    Returns `(None, None)` when no image was supplied. Otherwise the image is
    height-capped for inference, the pointmap and foreground mask are
    computed, and two outputs are produced: a colour-mapped depth image at
    the capped input's size and a textured mesh that uses the original-
    resolution image as texture. Per-stage timings are printed.
    """
    if image is None:
        return None, None

    # Stage 1: colour conversion and height cap.
    t0 = _t.perf_counter()
    image_pil_full = image.convert("RGB")  # original-res (used as texture)
    image_pil = _cap_height(image_pil_full)  # capped (used for inference)
    bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
    print(f"[time] convert+cap {(_t.perf_counter()-t0)*1000:.0f} ms (input {image_pil.size}, texture {image_pil_full.size})")

    # Stage 2: pointmap at the model's native resolution.
    t = _t.perf_counter()
    points = _estimate_pointmap(bgr, _get_pointmap_model(size))
    h_n, w_n = points.shape[:2]
    print(f"[time] pointmap {(_t.perf_counter()-t)*1000:.0f} ms (native {w_n}x{h_n})")

    # Stage 3: foreground mask at the same native resolution.
    t = _t.perf_counter()
    fg_mask = _foreground_mask(image_pil, h_n, w_n)
    print(f"[time] fg mask {(_t.perf_counter()-t)*1000:.0f} ms")

    # Stage 4: depth heatmap, resized back to the capped input's size.
    t = _t.perf_counter()
    heatmap = _depth_to_rgb(points[:, :, 2], fg_mask)
    heatmap[~fg_mask] = 200  # solid grey background
    depth_pil = Image.fromarray(heatmap).resize(image_pil.size, Image.LANCZOS)
    print(f"[time] depth+resize {(_t.perf_counter()-t)*1000:.0f} ms")

    # Stage 5: textured mesh export.
    t = _t.perf_counter()
    glb_path = _make_glb(image_pil_full, points, fg_mask)
    print(f"[time] glb export {(_t.perf_counter()-t)*1000:.0f} ms")

    print(f"[time] TOTAL {(_t.perf_counter()-t0)*1000:.0f} ms")
    return depth_pil, glb_path
# -----------------------------------------------------------------------------
# UI
# Example images offered in the UI: every .jpg/.jpeg/.png file (extension
# matched case-insensitively) in assets/images, as sorted full paths.
EXAMPLES = sorted(
    [
        os.path.join(ASSETS_DIR, "images", filename)
        for filename in os.listdir(os.path.join(ASSETS_DIR, "images"))
        if filename.lower().endswith((".jpg", ".jpeg", ".png"))
    ]
)
# Page-level CSS passed to gr.Blocks: global font stack, the gradient title,
# the subtitle, and the pill-shaped link badges used by HEADER_HTML.
CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
#title { text-align: center; font-size: 44px; font-weight: 700;
letter-spacing: -0.01em; margin: 28px 0 4px;
background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
background-clip: text; }
#subtitle { text-align: center; font-size: 12px; color: #64748b;
letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
font-weight: 500; }
#badges { display: flex; justify-content: center; flex-wrap: wrap;
gap: 8px; margin: 0 0 32px; }
.pill { display: inline-flex; align-items: center; gap: 6px;
padding: 7px 14px; border-radius: 999px;
background: #f1f5f9; color: #0f172a !important;
font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
text-decoration: none !important; border: 1px solid #e2e8f0;
transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
.pill:hover { background: #0f172a; color: #f8fafc !important;
border-color: #0f172a; transform: translateY(-1px); }
.pill svg { width: 14px; height: 14px; }
"""
# Static page header: title, subtitle and four link badges (code, models,
# paper, project page). Styled by the #title/#subtitle/#badges/.pill rules in
# CUSTOM_CSS. The "Models" badge label uses the 🤗 emoji (U+1F917); it had
# been corrupted into mis-decoded bytes.
HEADER_HTML = """
<div id="title">Sapiens2: Pointmap</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
<a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
Code
</a>
<a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
🤗 Models
</a>
<a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
Paper
</a>
<a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
Project
</a>
</div>
"""
# Assemble the page. Components appear in the order they are created, so the
# statement order below defines the layout.
with gr.Blocks(title="Sapiens2 Pointmap", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)
    # Row 1: input ↔ 3D mesh, equal height
    with gr.Row(equal_height=True):
        inp = gr.Image(label="Input", type="pil", height=640, scale=2)
        out_glb = gr.Model3D(
            label="Pointmap",
            height=640,
            clear_color=[0.97, 0.97, 0.97, 1.0],  # cinematic studio white
            camera_position=(35, 70, 1.6),  # closer, since scene is centered on the human
            zoom_speed=0.7,
            pan_speed=0.5,
            scale=3,
        )
    # Row 2: controls (with examples below them) on the left | depth heatmap on the right.
    with gr.Row():
        with gr.Column(scale=2, min_width=320):
            size = gr.Radio(
                choices=list(POINTMAP_MODELS.keys()),
                value=DEFAULT_SIZE,
                label="Model",
                container=False,
            )
            run = gr.Button("Run", variant="primary")
            gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
        out_depth = gr.Image(label="Depth (Z)", type="pil", height=640, scale=3)
    # Clicking "Run" sends the image and selected size to predict(), which
    # fills the depth image and the 3D viewer.
    run.click(predict, inputs=[inp, size], outputs=[out_depth, out_glb])
# Script entry point: enable the TF32 flags when a CUDA device is available,
# then start the Gradio server without a public share link.
if __name__ == "__main__":
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    demo.launch(share=False)