sapiens2-pose / app.py
Rawal Khirodkar
Set examples_per_page=16 (all 16 thumbnails on one page)
fe9554e
"""Sapiens2 pose-estimation Gradio Space.
Top-down 308-keypoint pose: DETR finds people, Sapiens2 estimates keypoints
on each crop, and we draw skeleton + keypoints with the GOLIATH palette.
"""
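# Block mmpretrain: mmdet's reid modules try `import mmpretrain` inside
# try/except ImportError, but mmpretrain's BLIP language_model.py raises a
# TypeError (transformers API drift), which escapes that except and kills the
# process. Mapping the name to None in sys.modules makes `import mmpretrain`
# raise ModuleNotFoundError (an ImportError subclass), which mmdet does catch.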
import sys
sys.modules["mmpretrain"] = None
# Prefer the vendored sapiens2 source: put this file's directory first on sys.path.
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import json
import tempfile
from typing import List, Tuple
import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from sapiens.pose.datasets import UDPHeatmap, parse_pose_metainfo
from sapiens.pose.evaluators import nms
from sapiens.pose.models import init_model
from transformers import DetrForObjectDetection, DetrImageProcessor
from pose_render_utils import visualize_keypoints
# -----------------------------------------------------------------------------
# Config
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# Pose checkpoints keyed by display size. Each row is
# (size, Hub repo, weights file, config file under CONFIGS_DIR); the row order
# is the order of the "Model" radio choices in the UI.
POSE_MODELS = {
    size_key: {
        "repo": repo_id,
        "filename": weights_name,
        "config": os.path.join(CONFIGS_DIR, config_name),
    }
    for size_key, repo_id, weights_name, config_name in (
        (
            "0.4B",
            "facebook/sapiens2-pose-0.4b",
            "sapiens2_0.4b_pose.safetensors",
            "sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py",
        ),
        (
            "0.8B",
            "facebook/sapiens2-pose-0.8b",
            "sapiens2_0.8b_pose.safetensors",
            "sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py",
        ),
        (
            "1B",
            "facebook/sapiens2-pose-1b",
            "sapiens2_1b_pose.safetensors",
            "sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py",
        ),
        (
            "5B",
            "facebook/sapiens2-pose-5b",
            "sapiens2_5b_pose.safetensors",
            "sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py",
        ),
    )
}
DEFAULT_SIZE = "1B"
DETECTOR_MODEL_ID = "facebook/detr-resnet-50"  # COCO person = label 1
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BBOX_THR = 0.3  # DETR score threshold for keeping a detection
NMS_THR = 0.3  # threshold passed to nms() when de-duplicating person boxes
# -----------------------------------------------------------------------------
# Model cache
_pose_model_cache: dict = {}  # size key -> built pose model
_detector_cache: dict = {}  # "proc"/"model" -> DETR processor / model
_metainfo_cache = None  # parsed keypoint metainfo, filled lazily


def _get_metainfo():
    """Return the parsed keypoint metainfo, parsing it on the first call only.

    The metainfo is read from ``configs/_base_/keypoints308.py`` under
    CONFIGS_DIR and memoised in the module-level ``_metainfo_cache``.
    """
    global _metainfo_cache
    if _metainfo_cache is not None:
        return _metainfo_cache
    _metainfo_cache = parse_pose_metainfo(
        dict(from_file=os.path.join(CONFIGS_DIR, "_base_", "keypoints308.py"))
    )
    return _metainfo_cache
def _get_detector():
    """Return the cached ``(processor, model)`` pair for the DETR person detector.

    On the first call both are loaded from ``DETECTOR_MODEL_ID``; the model is
    switched to eval mode and moved to ``DEVICE``. Nothing is cached unless
    both loads succeed.
    """
    if "model" not in _detector_cache:
        processor = DetrImageProcessor.from_pretrained(DETECTOR_MODEL_ID)
        detector = DetrForObjectDetection.from_pretrained(DETECTOR_MODEL_ID).eval().to(DEVICE)
        _detector_cache.update(proc=processor, model=detector)
    return _detector_cache["proc"], _detector_cache["model"]
def _get_pose_model(size: str):
    """Return the Sapiens2 pose model for ``size``, building it on first use.

    The checkpoint named in ``POSE_MODELS[size]`` is downloaded from the Hub,
    the model is built from its config on ``DEVICE``, a ``UDPHeatmap`` codec
    is constructed from ``model.cfg.codec``, and the shared keypoint metainfo
    is attached. The result is cached per size.

    Args:
        size: A key of ``POSE_MODELS`` (e.g. ``"1B"``).

    Raises:
        KeyError: if ``size`` is not a key of ``POSE_MODELS``.
        ValueError: if the config's codec type is not ``"UDPHeatmap"``.
    """
    if size in _pose_model_cache:
        return _pose_model_cache[size]

    spec = POSE_MODELS[size]
    ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
    model = init_model(spec["config"], ckpt, device=DEVICE)

    # Pop "type" outside of any assert: `python -O` strips asserts, which would
    # skip the pop and forward `type=` into UDPHeatmap(**codec_cfg).
    codec_cfg = dict(model.cfg.codec)
    codec_type = codec_cfg.pop("type", None)
    if codec_type != "UDPHeatmap":
        raise ValueError(
            f"Expected a UDPHeatmap codec in {spec['config']}, got {codec_type!r}"
        )
    model.codec = UDPHeatmap(**codec_cfg)
    model.pose_metainfo = _get_metainfo()

    _pose_model_cache[size] = model
    return model
# Warm every cache at import time: the detector plus every size listed in
# POSE_MODELS (not only DEFAULT_SIZE), so no request pays for Hub downloads
# or model construction. Any failure here aborts startup.
print("[startup] pre-loading detector + all pose sizes ...")
_get_detector()
for _size in POSE_MODELS:
    _get_pose_model(_size)
print("[startup] ready.")
# -----------------------------------------------------------------------------
# Inference
def _detect_persons(image_rgb: np.ndarray) -> np.ndarray:
    """Detect people with DETR and return their boxes, one row per person.

    Detections that pass ``post_process_object_detection``'s
    ``threshold=BBOX_THR`` filter and carry COCO label 1 (person) are
    de-duplicated with ``nms`` at ``NMS_THR``. Only the four box columns are
    returned (the score column is dropped after NMS).

    If no person survives, a single box covering the whole image,
    ``[0, 0, w - 1, h - 1]``, is returned so pose estimation always has input.
    """
    processor, detector = _get_detector()
    height, width = image_rgb.shape[:2]

    encoded = processor(images=Image.fromarray(image_rgb), return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        raw = detector(**encoded)
    det = processor.post_process_object_detection(
        raw,
        target_sizes=torch.tensor([image_rgb.shape[:2]], device=DEVICE),
        threshold=BBOX_THR,
    )[0]

    is_person = det["labels"] == 1  # COCO "person"
    dets = np.concatenate(
        [
            det["boxes"][is_person].cpu().numpy(),
            det["scores"][is_person].cpu().numpy().reshape(-1, 1),
        ],
        axis=1,
    )
    # nms() sees [x1, y1, x2, y2, score]; keep only the box columns afterwards.
    dets = dets[nms(dets, NMS_THR), :4]

    if len(dets) > 0:
        return dets
    return np.array([[0, 0, width - 1, height - 1]], dtype=np.float32)
def _estimate_pose(image_bgr: np.ndarray, bboxes: np.ndarray, model) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Run top-down pose estimation for every box on one image.

    Each box goes through ``model.pipeline`` and ``model.data_preprocessor``;
    the resulting inputs are concatenated into a single batch for one forward
    pass, and each prediction is decoded with ``model.codec``.

    Args:
        image_bgr: Full input image in BGR channel order.
        bboxes: Person boxes, one per row; each row is passed to the pipeline
            as ``bbox[None]`` (presumably ``x1, y1, x2, y2`` as produced by
            ``_detect_persons`` — TODO confirm against the pipeline).
        model: A model from ``_get_pose_model`` (uses ``pipeline``,
            ``data_preprocessor``, ``codec`` and calls it on the batch).

    Returns:
        ``(keypoints, scores)``: lists with one entry per box, holding the
        keypoints mapped back to image coordinates and their scores. Both
        lists are empty when ``bboxes`` is empty.
    """
    # Nothing to batch; torch.cat([]) would raise, so return early.
    if len(bboxes) == 0:
        return [], []

    inputs_list, samples_list = [], []
    for bbox in bboxes:
        data_info = dict(img=image_bgr, bbox=bbox[None], bbox_score=np.ones(1, dtype=np.float32))
        data = model.pipeline(data_info)
        data = model.data_preprocessor(data)
        inputs_list.append(data["inputs"])
        samples_list.append(data["data_samples"])

    # One batched forward pass over all person crops.
    inputs = torch.cat(inputs_list, dim=0)
    with torch.no_grad():
        pred = model(inputs)
    # numpy has no bfloat16 dtype, so reduced-precision output must be upcast
    # before .numpy(); float32 output is passed through unchanged.
    if pred.dtype in (torch.float16, torch.bfloat16):
        pred = pred.float()
    pred = pred.cpu().numpy()

    keypoints, scores = [], []
    for i, sample in enumerate(samples_list):
        kpts_i, scr_i = model.codec.decode(pred[i])
        meta = sample["meta"]
        # Model-input space -> image space: normalise by the input size, scale
        # to the crop region, then offset by its top-left corner
        # (bbox_center - 0.5 * bbox_scale).
        kpts_i = kpts_i / meta["input_size"] * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"]
        keypoints.append(kpts_i[0])
        scores.append(scr_i[0])
    return keypoints, scores
# -----------------------------------------------------------------------------
# Gradio handler
def _draw_instances(image_rgb, bboxes, keypoints, scores, meta, kpt_thr):
    """Return a copy of ``image_rgb`` with each person's bbox, skeleton and keypoints drawn."""
    vis_rgb = image_rgb.copy()
    for bbox, kpts, scr in zip(bboxes, keypoints, scores):
        x1, y1, x2, y2 = map(int, bbox[:4])
        # Scale render sizes with the person: sqrt(bbox area), i.e. the
        # geometric mean of width and height (matching v1's recipe).
        bbox_size = float(np.sqrt(max(1, x2 - x1) * max(1, y2 - y1)))
        radius = thickness = max(1, int(bbox_size * 0.006))
        bbox_thickness = max(1, thickness // 4)
        # Green bbox first (so skeleton draws on top).
        cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)
        # Skeleton + keypoints for this person only.
        vis_rgb = visualize_keypoints(
            image=vis_rgb,
            keypoints=[kpts],
            keypoints_visible=[np.ones(len(scr), dtype=bool)],
            keypoint_scores=[scr],
            radius=radius,
            thickness=thickness,
            kpt_thr=kpt_thr,
            skeleton=meta["skeleton_links"],
            kpt_color=meta["keypoint_colors"],
            link_color=meta["skeleton_link_colors"],
        )
    return vis_rgb


def _write_instances_json(bboxes, keypoints, scores) -> str:
    """Write per-person bbox/keypoints/scores to a new temp JSON file; return its path."""
    instances = [
        {
            "bbox": [float(v) for v in np.asarray(bbox).reshape(-1)[:4]],
            "keypoints": np.asarray(kpts, dtype=float).tolist(),
            "keypoint_scores": np.asarray(s, dtype=float).reshape(-1).tolist(),
        }
        for bbox, kpts, s in zip(bboxes, keypoints, scores)
    ]
    # delete=False: the file must outlive this call because its path is
    # returned to the gr.File output.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix=".json", mode="w", encoding="utf-8"
    ) as f:
        json.dump({"instances": instances}, f)
    return f.name


@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str, kpt_thr: float):
    """Gradio handler: detect people, estimate their pose, render and export.

    Args:
        image: Uploaded image (converted to RGB), or None.
        size: Key of ``POSE_MODELS`` selecting the pose model.
        kpt_thr: Keypoint score threshold forwarded to ``visualize_keypoints``.

    Returns:
        ``(rendered PIL image, path to a JSON file of per-person results)``,
        or ``(None, None)`` when no image is given.
    """
    if image is None:
        return None, None
    image_rgb = np.array(image.convert("RGB"))
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    bboxes = _detect_persons(image_rgb)
    model = _get_pose_model(size)
    keypoints, scores = _estimate_pose(image_bgr, bboxes, model)

    vis_rgb = _draw_instances(image_rgb, bboxes, keypoints, scores, model.pose_metainfo, kpt_thr)
    vis_pil = Image.fromarray(vis_rgb)
    json_path = _write_instances_json(bboxes, keypoints, scores)
    return vis_pil, json_path
# -----------------------------------------------------------------------------
# UI
# Example images: every .jpg/.jpeg/.png (case-insensitive) in assets/images,
# as full paths in sorted filename order.
EXAMPLES = [
    os.path.join(ASSETS_DIR, "images", fname)
    for fname in sorted(os.listdir(os.path.join(ASSETS_DIR, "images")))
    if fname.lower().endswith((".jpg", ".jpeg", ".png"))
]
# Page styling: Helvetica font stack everywhere except code/pre, a gradient
# title, an uppercase subtitle, and pill-shaped link badges. The #title,
# #subtitle, #badges ids and .pill class are the ones used in HEADER_HTML.
CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
-webkit-font-smoothing: antialiased;
-moz-osx-font-smoothing: grayscale;
}
#title { text-align: center; font-size: 44px; font-weight: 700;
letter-spacing: -0.01em; margin: 28px 0 4px;
background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
-webkit-background-clip: text; -webkit-text-fill-color: transparent;
background-clip: text; }
#subtitle { text-align: center; font-size: 12px; color: #64748b;
letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
font-weight: 500; }
#badges { display: flex; justify-content: center; flex-wrap: wrap;
gap: 8px; margin: 0 0 32px; }
.pill { display: inline-flex; align-items: center; gap: 6px;
padding: 7px 14px; border-radius: 999px;
background: #f1f5f9; color: #0f172a !important;
font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
text-decoration: none !important; border: 1px solid #e2e8f0;
transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
.pill:hover { background: #0f172a; color: #f8fafc !important;
border-color: #0f172a; transform: translateY(-1px); }
.pill svg { width: 14px; height: 14px; }
"""
# Page header rendered via gr.HTML: title, venue subtitle, and link badges
# (Code, Models, Paper, Project) that open in a new tab.
HEADER_HTML = """
<div id="title">Sapiens2: Pose</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
<a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
Code
</a>
<a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
🤗 Models
</a>
<a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
Paper
</a>
<a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
<svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
Project
</a>
</div>
"""
# UI layout (creation order inside the `with` blocks defines placement):
# header, input/output images side by side, a controls row (model size,
# keypoint threshold, Run), clickable examples, and a collapsed accordion
# holding the downloadable per-person JSON.
with gr.Blocks(title="Sapiens2 Pose", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)
    with gr.Row(equal_height=True):
        inp = gr.Image(label="Input", type="pil", height=640)
        out_img = gr.Image(label="Pose", type="pil", height=640)
    with gr.Row():
        # Choices follow POSE_MODELS insertion order.
        size = gr.Radio(
            choices=list(POSE_MODELS.keys()),
            value=DEFAULT_SIZE,
            label="Model",
            scale=2,
        )
        thr = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Keypoint threshold", scale=2)
        run = gr.Button("Run", variant="primary", size="lg", scale=1)
    # Up to 16 examples per page; clicking one only fills the input image.
    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)
    with gr.Accordion("Raw Keypoints", open=False):
        out_json = gr.File(label="Keypoints (.json — bboxes + 308 keypoints + scores per person)")
    # predict(image, size, kpt_thr) -> (rendered image, JSON file path).
    run.click(predict, inputs=[inp, size, thr], outputs=[out_img, out_json])
if __name__ == "__main__":
    # Allow TF32 math for cuDNN convolutions and CUDA matmuls when a GPU is
    # present; these are process-wide flags, so setting them after the models
    # were built at import time still affects every later forward pass.
    if torch.cuda.is_available():
        torch.backends.cudnn.allow_tf32 = True
        torch.backends.cuda.matmul.allow_tf32 = True
    demo.launch(share=False)