sapiens2-normal / app.py
Rawal Khirodkar
Set examples_per_page=16 (all 16 thumbnails on one page)
1a44717
"""Sapiens2 surface-normal Gradio Space.
Image → per-pixel surface normals. Visualized by RGB-encoding the unit-length
(x, y, z) normal: r = (x + 1) / 2, g = (y + 1) / 2, b = (z + 1) / 2.
Optionally applies a v1 binary fg/bg mask so background pixels are blacked out.
"""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import tempfile
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from huggingface_hub import hf_hub_download
from sapiens.dense.models import NormalEstimator, init_model # registers NormalEstimator
_ = NormalEstimator
# -----------------------------------------------------------------------------
# Config
# Bundled assets (model config files under configs/, example images under
# images/) live in an "assets" directory next to this file.
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")
# Model-size label -> how to build that normal estimator:
#   "repo":     Hugging Face Hub repo id passed to hf_hub_download(repo_id=...)
#   "filename": weights file inside that repo
#   "config":   local config file passed to init_model() together with the weights
NORMAL_MODELS = {
    "0.4B": {
        "repo": "facebook/sapiens2-normal-0.4b",
        "filename": "sapiens2_0.4b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_normal_metasim_render_people-1024x768.py"),
    },
    "0.8B": {
        "repo": "facebook/sapiens2-normal-0.8b",
        "filename": "sapiens2_0.8b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_normal_metasim_render_people-1024x768.py"),
    },
    "1B": {
        "repo": "facebook/sapiens2-normal-1b",
        "filename": "sapiens2_1b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_normal_metasim_render_people-1024x768.py"),
    },
    "5B": {
        "repo": "facebook/sapiens2-normal-5b",
        "filename": "sapiens2_5b_normal.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_normal_metasim_render_people-1024x768.py"),
    },
}
# Size preselected in the "Model" radio button.
DEFAULT_SIZE = "1B"
# v1 binary fg/bg TorchScript model.
FG_REPO = "facebook/sapiens-seg-foreground-1b-torchscript"
FG_FILENAME = "sapiens_1b_seg_foreground_epoch_8_torchscript.pt2"
# Background handling choices; predict() masks out the background only for "fg-bg".
BG_OPTIONS = ["fg-bg", "no-bg-removal"]
DEFAULT_BG = "fg-bg"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Preprocessing for the fg/bg segmenter: resize to 1024 (H) x 768 (W), convert
# the PIL image to a float tensor in [0, 1], then normalize with per-channel
# mean/std that are given on the 0-255 scale (hence the "/ 255").
_fg_transform = transforms.Compose([
    transforms.Resize((1024, 768)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[123.5 / 255, 116.5 / 255, 103.5 / 255],
                         std=[58.5 / 255, 57.0 / 255, 57.5 / 255]),
])
# -----------------------------------------------------------------------------
# Model cache
# Built normal estimators, keyed by size label (a key of NORMAL_MODELS).
_normal_model_cache: dict = {}
# fg/bg segmenter singleton; stays None until _get_fg_model() first runs.
_fg_model = None


def _get_normal_model(size: str):
    """Return the normal estimator for ``size``, constructing it only once.

    On a cache miss, the weights are downloaded from the Hub repo listed in
    ``NORMAL_MODELS[size]``. They are combined with that entry's local config
    via ``init_model`` on ``DEVICE``, and the result is memoized in
    ``_normal_model_cache``.

    Raises:
        KeyError: if ``size`` is not a key of ``NORMAL_MODELS``.
    """
    if size in _normal_model_cache:
        return _normal_model_cache[size]

    spec = NORMAL_MODELS[size]
    weights_path = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
    built = init_model(spec["config"], weights_path, device=DEVICE)
    _normal_model_cache[size] = built
    return built
def _get_fg_model():
    """Return the v1 fg/bg TorchScript segmenter, loading it on the first call.

    The checkpoint is fetched with ``hf_hub_download`` and the loaded module
    is stored in the module-level ``_fg_model``, so later calls reuse the
    same instance.
    """
    global _fg_model
    if _fg_model is None:
        ckpt = hf_hub_download(repo_id=FG_REPO, filename=FG_FILENAME)
        # Without map_location, torch.jit.load moves tensors back to the device
        # they were saved from and raises if that device is unavailable (e.g. a
        # CUDA-saved archive on a CPU-only host). Deserialize onto the CPU
        # first, then let .to(DEVICE) perform the actual move.
        _fg_model = torch.jit.load(ckpt, map_location="cpu").eval().to(DEVICE)
    return _fg_model


# Eagerly build every normal model plus the fg/bg model at import time, so the
# first request does not pay the download + initialization cost.
print("[startup] pre-loading all normal sizes + fg/bg ...")
for _size in NORMAL_MODELS:
    _get_normal_model(_size)
_get_fg_model()
print("[startup] ready.")
# -----------------------------------------------------------------------------
# Inference (mirrors sapiens/dense/tools/vis/vis_normal.py)
def _estimate_normal(image_bgr: np.ndarray, model) -> np.ndarray:
    """Predict per-pixel surface normals for one BGR image.

    The model's own ``pipeline`` (resize + pad) and ``data_preprocessor``
    (normalize + batch) prepare the input. The prediction is rescaled to unit
    length per pixel. The padding reported in
    ``data_samples["meta"]["padding_size"]`` (left, right, top, bottom) is
    then cropped away.

    No upsampling happens here: the result stays at the model's NATIVE
    (post-unpad) resolution. The caller upsamples the RGB-encoded image
    instead, which gives sharper edges than bilinearly upsampling the raw
    float vectors. Assumes a batch of one (the leading axis is squeezed).

    Returns:
        float32 array of shape (H_native, W_native, 3).
    """
    prepared = model.data_preprocessor(model.pipeline(dict(img=image_bgr)))
    batch = prepared["inputs"]
    samples = prepared["data_samples"]

    with torch.no_grad():
        pred = model(batch)  # (1, 3, padded_H, padded_W)
        # Unit-normalize each pixel's vector; the clamp avoids dividing 0 by 0.
        pred = pred / pred.norm(dim=1, keepdim=True).clamp_min(1e-8)

    left, right, top, bottom = samples["meta"]["padding_size"]
    padded_h, padded_w = batch.shape[2], batch.shape[3]
    unpadded = pred[:, :, top : padded_h - bottom, left : padded_w - right]

    chw = unpadded.squeeze(0).cpu().float().numpy()
    return chw.transpose(1, 2, 0)  # (H_native, W_native, 3)
def _foreground_mask(image_pil: Image.Image, target_h: int, target_w: int) -> np.ndarray:
    """Segment foreground from background with the v1 fg/bg model.

    Args:
        image_pil: RGB input image.
        target_h: Height of the returned mask.
        target_w: Width of the returned mask.

    Returns:
        Boolean array of shape (target_h, target_w). It is True wherever the
        highest-scoring class index is non-zero, which is treated as
        foreground.
    """
    segmenter = _get_fg_model()
    batch = _fg_transform(image_pil)[None].to(DEVICE)
    with torch.no_grad():
        logits = segmenter(batch)  # (1, K, H, W) logits
        # Resize the class scores (not the labels), so the argmax is taken
        # directly on the requested grid.
        logits = F.interpolate(
            logits, size=(target_h, target_w), mode="bilinear", align_corners=False
        )
    labels = logits.argmax(dim=1)[0]
    return (labels != 0).cpu().numpy()
def _normal_to_rgb(normal_hwc: np.ndarray) -> np.ndarray:
    """Encode an (H, W, 3) normal map in [-1, 1] as an (H, W, 3) uint8 RGB image.

    Each component is mapped linearly via ``(c + 1) / 2 * 255``. Out-of-range
    results are clipped to [0, 255] before the uint8 cast, which truncates.
    Channel order is kept as (x, y, z) -> (R, G, B): the swap in
    vis_normal.py exists only because cv2.imwrite expects BGR.
    """
    unit_range = (normal_hwc + 1.0) / 2.0
    scaled = unit_range * 255.0
    return scaled.clip(0, 255).astype(np.uint8)
# -----------------------------------------------------------------------------
# Gradio handler
@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str, bg_mode: str):
    """Run surface-normal estimation on one image for the Gradio UI.

    Args:
        image: PIL image from the input component, or None when Run is
            clicked without an image.
        size: Key of ``NORMAL_MODELS`` selecting the checkpoint.
        bg_mode: One of ``BG_OPTIONS``. ``"fg-bg"`` blacks out background
            pixels in the visualization and stores NaN for them in the raw
            array.

    Returns:
        A 2-tuple matching the click handler's two outputs (image slider,
        file). The slider value is ``(input_pil, normal_rgb_pil)`` and the
        file is the path of a float32 ``.npy`` with the native-resolution
        normals. Returns ``(None, None)`` when no image was given.
    """
    if image is None:
        # One value per output component (slider + file), not three.
        return None, None

    image_pil = image.convert("RGB")
    image_rgb = np.array(image_pil)
    # The estimator is fed BGR, mirroring the cv2-based vis_normal.py script.
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    h0, w0 = image_rgb.shape[:2]

    model = _get_normal_model(size)
    normal_native = _estimate_normal(image_bgr, model)  # (H_native, W_native, 3) in [-1, 1]
    h_n, w_n = normal_native.shape[:2]

    mask_native = None
    if bg_mode == "fg-bg":
        # Mask is computed at native resolution to keep things fast and
        # consistent with the normal map's actual pixel grid.
        mask_native = _foreground_mask(image_pil, h_n, w_n)
        normal_native[~mask_native] = -1.0  # -> RGB(0, 0, 0) after encoding

    rgb_native = _normal_to_rgb(normal_native)  # (H_native, W_native, 3) uint8
    # Upsample the encoded IMAGE (not the float vectors) back to the input
    # size with Lanczos resampling.
    rgb_pil = Image.fromarray(rgb_native).resize((w0, h0), Image.LANCZOS)

    # Raw normals are saved at native resolution. Removed background becomes
    # NaN, so it is unambiguous in the saved array.
    raw = normal_native.astype(np.float32, copy=True)
    if mask_native is not None:
        raw[~mask_native] = np.nan

    # The context manager closes the handle deterministically; delete=False
    # keeps the file on disk so gr.File can serve it.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".npy") as fh:
        np.save(fh, raw)
        npy_path = fh.name

    # Slider value: (left_image, right_image) — left=input RGB, right=normal map.
    return (image_pil, rgb_pil), npy_path
# -----------------------------------------------------------------------------
# UI
# Example inputs: full paths of every .jpg/.jpeg/.png (case-insensitive) in
# assets/images, in sorted order.
EXAMPLES = sorted(
    entry.path
    for entry in os.scandir(os.path.join(ASSETS_DIR, "images"))
    if entry.name.lower().endswith((".jpg", ".jpeg", ".png"))
)
# Page styling passed to gr.Blocks(css=...). It sets a Helvetica-style sans
# font on everything except code/pre, styles the gradient #title, centered
# #subtitle, and #badges row, and defines the pill-shaped .pill links used
# by HEADER_HTML.
CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
#title { text-align: center; font-size: 44px; font-weight: 700;
letter-spacing: -0.01em; margin: 28px 0 4px;
background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
background-clip: text; }
#subtitle { text-align: center; font-size: 12px; color: #64748b;
letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
font-weight: 500; }
#badges { display: flex; justify-content: center; flex-wrap: wrap;
gap: 8px; margin: 0 0 32px; }
.pill { display: inline-flex; align-items: center; gap: 6px;
padding: 7px 14px; border-radius: 999px;
background: #f1f5f9; color: #0f172a !important;
font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
text-decoration: none !important; border: 1px solid #e2e8f0;
transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
.pill:hover { background: #0f172a; color: #f8fafc !important;
border-color: #0f172a; transform: translateY(-1px); }
.pill svg { width: 14px; height: 14px; }
"""
# Page header rendered with gr.HTML: the title, a venue subtitle, and a row of
# link badges (Code, Models, Paper, Project). The #title/#subtitle/#badges
# ids and the .pill class are styled by CUSTOM_CSS.
HEADER_HTML = """
<div id="title">Sapiens2: Normal</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
<a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
Code
</a>
<a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
🤗 Models
</a>
<a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
Paper
</a>
<a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
Project
</a>
</div>
"""
# Page layout, top to bottom:
#   1. The header.
#   2. The input image beside an input/normal comparison slider.
#   3. A row of controls (model size, background mode, Run).
#   4. Example thumbnails.
#   5. A collapsed section with the raw-normals download.
with gr.Blocks(title="Sapiens2 Normal", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)
    with gr.Row(equal_height=True):
        inp = gr.Image(label="Input", type="pil", height=640)
        # Before/after slider; predict() fills it with (input_pil, normal_rgb_pil).
        out_img = gr.ImageSlider(
            label="Input ↔ Normal",
            type="pil",
            height=640,
            max_height=640,
            slider_position=50,
        )
    with gr.Row():
        size = gr.Radio(
            choices=list(NORMAL_MODELS.keys()),
            value=DEFAULT_SIZE,
            label="Model",
            scale=2,
        )
        bg = gr.Radio(
            choices=BG_OPTIONS,
            value=DEFAULT_BG,
            label="Background",
            scale=2,
        )
        run = gr.Button("Run", variant="primary", size="lg", scale=1)
    # No fn is given, so picking an example only loads it into `inp`; the
    # user still clicks Run. Up to 16 thumbnails are shown per page.
    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
    with gr.Accordion("Raw Normals", open=False):
        out_npy = gr.File(label="Raw normals (.npy float32 [-1, 1]; NaN where bg removed)")
    # outputs order must match predict()'s return: (slider value, npy path).
    run.click(predict, inputs=[inp, size, bg], outputs=[out_img, out_npy])
if __name__ == "__main__":
    # When a GPU is present, allow TF32 in cuDNN convolutions and CUDA
    # matmuls. This is faster, at slightly reduced float32 precision.
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
    demo.launch(share=False)