"""Sapiens2 surface-normal Gradio Space. Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length (x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2. Optionally applies a v1 binary fg/bg mask so background pixels are blacked out. """ import sys import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import tempfile import cv2 import gradio as gr import numpy as np import spaces import torch import torch.nn.functional as F from PIL import Image from torchvision import transforms from huggingface_hub import hf_hub_download from sapiens.dense.models import NormalEstimator, init_model # registers NormalEstimator _ = NormalEstimator # ----------------------------------------------------------------------------- # Config ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs") NORMAL_MODELS = { "0.4B": { "repo": "facebook/sapiens2-normal-0.4b", "filename": "sapiens2_0.4b_normal.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_normal_metasim_render_people-1024x768.py"), }, "0.8B": { "repo": "facebook/sapiens2-normal-0.8b", "filename": "sapiens2_0.8b_normal.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_normal_metasim_render_people-1024x768.py"), }, "1B": { "repo": "facebook/sapiens2-normal-1b", "filename": "sapiens2_1b_normal.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_normal_metasim_render_people-1024x768.py"), }, "5B": { "repo": "facebook/sapiens2-normal-5b", "filename": "sapiens2_5b_normal.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"), }, } DEFAULT_SIZE = "1B" # v1 binary fg/bg TorchScript model. FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript" FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2" BG_OPTIONS = ["fg-bg", "no-bg-removal"] DEFAULT_BG = "fg-bg" DEVICE = "cuda" if torch.cuda.is_available() else "cpu" _fg_transform = transforms.Compose([ transforms.Resize((1024, 768)), transforms.ToTensor(), transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255], std=[58.5 / 255, 57.0 / 255, 57.5 / 255]), ]) # ----------------------------------------------------------------------------- # Model cache _normal_model_cache: dict = {} _fg_model = None def _get_normal_model(size: str): if size not in _normal_model_cache: spec = NORMAL_MODELS[size] ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"]) model = init_model(spec["config"], ckpt, device=DEVICE) _normal_model_cache[size] = model return _normal_model_cache[size] def _get_fg_model(): global _fg_model if _fg_model is None: ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME) _fg_model = torch.jit.load(ckpt).eval().to(DEVICE) return _fg_model print("[startup] pre-loading all normal sizes + fg/bg ...") for _size in NORMAL_MODELS: _get_normal_model(_size) _get_fg_model() print("[startup] ready.") # ----------------------------------------------------------------------------- # Inference (mirrors sapiens/dense/tools/vis/vis_normal.py) def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray: """Returns unit-length normals at the model's NATIVE (post-unpad) resolution — no upsampling here. 
# -----------------------------------------------------------------------------
# Config

ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

NORMAL_MODELS = {
    "0.4B": {
        "repo": "facebook/sapiens2-normal-0.4b",
        "filename": "sapiens2_0.4b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_normal_metasim_render_people-1024x768.py"),
    },
    "0.8B": {
        "repo": "facebook/sapiens2-normal-0.8b",
        "filename": "sapiens2_0.8b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_normal_metasim_render_people-1024x768.py"),
    },
    "1B": {
        "repo": "facebook/sapiens2-normal-1b",
        "filename": "sapiens2_1b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_normal_metasim_render_people-1024x768.py"),
    },
    "5B": {
        "repo": "facebook/sapiens2-normal-5b",
        "filename": "sapiens2_5b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"),
    },
}
DEFAULT_SIZE = "1B"

# v1 binary fg/bg TorchScript model.
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"

BG_OPTIONS = ["fg-bg", "no-bg-removal"]
DEFAULT_BG = "fg-bg"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

_fg_transform = transforms.Compose([
    transforms.Resize((1024, 768)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
        std=[58.5 / 255, 57.0 / 255, 57.5 / 255],
    ),
])

# -----------------------------------------------------------------------------
# Model cache

_normal_model_cache: dict = {}
_fg_model = None


def _get_normal_model(size: str):
    if size not in _normal_model_cache:
        spec = NORMAL_MODELS[size]
        ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
        model = init_model(spec["config"], ckpt, device=DEVICE)
        _normal_model_cache[size] = model
    return _normal_model_cache[size]


def _get_fg_model():
    global _fg_model
    if _fg_model is None:
        ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
        _fg_model = torch.jit.load(ckpt).eval().to(DEVICE)
    return _fg_model


print("[startup] pre-loading all normal sizes + fg/bg ...")
for _size in NORMAL_MODELS:
    _get_normal_model(_size)
_get_fg_model()
print("[startup] ready.")

# -----------------------------------------------------------------------------
# Inference (mirrors sapiens/dense/tools/vis/vis_normal.py)


def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
    """Return unit-length normals at the model's NATIVE (post-unpad) resolution;
    no upsampling here. The caller upsamples the encoded RGB image instead,
    which gives sharper edges than bilinearly upsampling the raw float vectors.
    """
    data = model.pipeline(dict(img=image_bgr))  # resize + pad
    data = model.data_preprocessor(data)        # normalize + batch
    inputs, data_samples = data["inputs"], data["data_samples"]
    with torch.no_grad():
        normal = model(inputs)  # (1, 3, padded_H, padded_W)
    normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
    pad_left, pad_right, pad_top, pad_bottom = data_samples["meta"]["padding_size"]
    normal = normal[
        :,
        :,
        pad_top : inputs.shape[2] - pad_bottom,
        pad_left : inputs.shape[3] - pad_right,
    ]
    return normal.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)  # (H_native, W_native, 3)


def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
    fg = _get_fg_model()
    inputs = _fg_transform(image_pil).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        out = fg(inputs)  # (1, K, H, W) logits
    out = F.interpolate(out, size=(target_h, target_w), mode="bilinear", align_corners=False)
    return (out.argmax(dim=1)[0] > 0).cpu().numpy()


def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
    """(H, W, 3) in [-1, 1] → (H, W, 3) uint8 RGB.

    No channel swap here: the swap in vis_normal.py is purely for cv2.imwrite's
    BGR convention.
    """
    return (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)

# -----------------------------------------------------------------------------
# Gradio handler


@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str, bg_mode: str):
    if image is None:
        return None, None
    image_pil = image.convert("RGB")
    image_rgb = np.array(image_pil)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    h0, w0 = image_rgb.shape[:2]

    model = _get_normal_model(size)
    normal_native = _estimate_normal(image_bgr, model)  # (H_native, W_native, 3) in [-1, 1]
    h_n, w_n = normal_native.shape[:2]

    if bg_mode == "fg-bg":
        # The mask is computed at native resolution to keep things fast and
        # consistent with the normal map's actual pixel grid.
        mask_native = _foreground_mask(image_pil, h_n, w_n)
        normal_native[~mask_native] = -1.0  # → RGB(0, 0, 0) after vis

    rgb_native = _normal_to_rgb(normal_native)  # (H_native, W_native, 3) uint8
    rgb_pil = Image.fromarray(rgb_native).resize((w0, h0), Image.LANCZOS)  # upsample the encoded IMAGE with Lanczos

    # Save raw normals at native resolution (NaN where bg was removed), for completeness.
    raw = normal_native.copy()
    if bg_mode == "fg-bg":
        raw[~mask_native] = np.nan
    npy_path = tempfile.NamedTemporaryFile(delete=False, suffix=".npy").name
    np.save(npy_path, raw.astype(np.float32))

    # Slider value: (left_image, right_image); left = input RGB, right = normal map.
    return (image_pil, rgb_pil), npy_path
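
# Illustrative sketch (not wired into the app): how a downstream user might
# load the .npy saved by `predict` back into a visualization.
# `decode_saved_normals` is a hypothetical helper; it reuses the NaN = background
# convention and the `_normal_to_rgb` encoder above.
def decode_saved_normals(npy_path: str) -> Image.Image:
    raw = np.load(npy_path)             # (H, W, 3) float32 in [-1, 1], NaN = bg
    raw = np.nan_to_num(raw, nan=-1.0)  # NaN background → -1 → black after encoding
    return Image.fromarray(_normal_to_rgb(raw))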
# -----------------------------------------------------------------------------
# UI

EXAMPLES = sorted(
    os.path.join(ASSETS_DIR, "images", n)
    for n in os.listdir(os.path.join(ASSETS_DIR, "images"))
    if n.lower().endswith((".jpg", ".jpeg", ".png"))
)

CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}
#title {
  text-align: center; font-size: 44px; font-weight: 700; letter-spacing: -0.01em;
  margin: 28px 0 4px;
  background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
  -webkit-background-clip: text; -webkit-text-fill-color: transparent; background-clip: text;
}
#subtitle {
  text-align: center; font-size: 12px; color: #64748b; letter-spacing: 0.18em;
  margin: 0 0 14px; text-transform: uppercase; font-weight: 500;
}
#badges {
  display: flex; justify-content: center; flex-wrap: wrap; gap: 8px; margin: 0 0 32px;
}
.pill {
  display: inline-flex; align-items: center; gap: 6px; padding: 7px 14px;
  border-radius: 999px; background: #f1f5f9; color: #0f172a !important;
  font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
  text-decoration: none !important; border: 1px solid #e2e8f0;
  transition: background 150ms ease, transform 150ms ease, border-color 150ms ease;
}
.pill:hover {
  background: #0f172a; color: #f8fafc !important; border-color: #0f172a;
  transform: translateY(-1px);
}
.pill svg { width: 14px; height: 14px; }
"""

HEADER_HTML = """
<div id="title">Sapiens2: Normal</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
  <a class="pill" href="#">Code</a>
  <a class="pill" href="#">🤗 Models</a>
  <a class="pill" href="#">Paper</a>
  <a class="pill" href="#">Project</a>
</div>
""" with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: gr.HTML(HEADER_HTML) with gr.Row(equal_height=True): inp = gr.Image(label="Input", type="pil", height=640) out_img = gr.ImageSlider( label="Input ↔ Normal", type="pil", height=640, max_height=640, slider_position=50, ) with gr.Row(): size = gr.Radio( choices=list(NORMAL_MODELS.keys()), value=DEFAULT_SIZE, label="Model", scale=2, ) bg = gr.Radio( choices=BG_OPTIONS, value=DEFAULT_BG, label="Background", scale=2, ) run = gr.Button("Run", variant="primary", size="lg", scale=1) gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=14) with gr.Accordion("Raw Normals", open=False): out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1]; NaN where bg removed)") run.click(predict, inputs=[inp, size, bg], outputs=[out_img, out_npy]) if __name__ == "__main__": if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True demo.launch(share=False)