"""Sapiens2 pose-estimation Gradio Space. Top-down 308-keypoint pose: DETR finds people, Sapiens2 estimates keypoints on each crop, and we draw skeleton + keypoints with the GOLIATH palette. """ # Block mmpretrain: mmdet's reid modules try `import mmpretrain` inside # try/except ImportError, but mmpretrain's BLIP language_model.py raises # TypeError (transformers API drift) — escapes the except and kills the process. import sys sys.modules["mmpretrain"] = None # Use the vendored sapiens2 source. import os sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) import json import tempfile from typing import List, Tuple import cv2 import gradio as gr import numpy as np import spaces import torch from huggingface_hub import hf_hub_download from PIL import Image from sapiens.pose.datasets import UDPHeatmap, parse_pose_metainfo from sapiens.pose.evaluators import nms from sapiens.pose.models import init_model from transformers import DetrForObjectDetection, DetrImageProcessor from pose_render_utils import visualize_keypoints # ----------------------------------------------------------------------------- # Config ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets") CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs") POSE_MODELS = { "0.4B": { "repo": "facebook/sapiens2-pose-0.4b", "filename": "sapiens2_0.4b_pose.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py"), }, "0.8B": { "repo": "facebook/sapiens2-pose-0.8b", "filename": "sapiens2_0.8b_pose.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py"), }, "1B": { "repo": "facebook/sapiens2-pose-1b", "filename": "sapiens2_1b_pose.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py"), }, "5B": { "repo": "facebook/sapiens2-pose-5b", "filename": "sapiens2_5b_pose.safetensors", "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py"), }, } DEFAULT_SIZE = "1B" DETECTOR_MODEL_ID = "facebook/detr-resnet-50" # COCO person = label 1 DEVICE = "cuda" if torch.cuda.is_available() else "cpu" BBOX_THR = 0.3 NMS_THR = 0.3 # ----------------------------------------------------------------------------- # Model cache _pose_model_cache: dict = {} _detector_cache: dict = {} _metainfo_cache = None def _get_metainfo(): global _metainfo_cache if _metainfo_cache is None: meta_path = os.path.join(CONFIGS_DIR, "_base_", "keypoints308.py") _metainfo_cache = parse_pose_metainfo(dict(from_file=meta_path)) return _metainfo_cache def _get_detector(): if "model" not in _detector_cache: proc = DetrImageProcessor.from_pretrained(DETECTOR_MODEL_ID) model = DetrForObjectDetection.from_pretrained(DETECTOR_MODEL_ID).eval().to(DEVICE) _detector_cache["proc"] = proc _detector_cache["model"] = model return _detector_cache["proc"], _detector_cache["model"] def _get_pose_model(size: str): if size not in _pose_model_cache: spec = POSE_MODELS[size] ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"]) model = init_model(spec["config"], ckpt, device=DEVICE) codec_cfg = dict(model.cfg.codec) assert codec_cfg.pop("type") == "UDPHeatmap" model.codec = UDPHeatmap(**codec_cfg) model.pose_metainfo = _get_metainfo() _pose_model_cache[size] = model return _pose_model_cache[size] print("[startup] pre-loading detector + all pose sizes ...") _get_detector() for _size in POSE_MODELS: _get_pose_model(_size) print("[startup] 
ready.") # ----------------------------------------------------------------------------- # Inference def _detect_persons(image_rgb: np.ndarray) -> np.ndarray: proc, model = _get_detector() pil_img = Image.fromarray(image_rgb) inputs = proc(images=pil_img, return_tensors="pt").to(DEVICE) with torch.no_grad(): outputs = model(**inputs) target_sizes = torch.tensor([image_rgb.shape[:2]], device=DEVICE) results = proc.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=BBOX_THR )[0] person_mask = results["labels"] == 1 boxes = results["boxes"][person_mask].cpu().numpy() scores = results["scores"][person_mask].cpu().numpy().reshape(-1, 1) bboxes = np.concatenate([boxes, scores], axis=1) bboxes = bboxes[nms(bboxes, NMS_THR), :4] if len(bboxes) == 0: h, w = image_rgb.shape[:2] bboxes = np.array([[0, 0, w - 1, h - 1]], dtype=np.float32) return bboxes def _estimate_pose(image_bgr: np.ndarray, bboxes: np.ndarray, model) -> Tuple[List[np.ndarray], List[np.ndarray]]: inputs_list, samples_list = [], [] for bbox in bboxes: data_info = dict(img=image_bgr, bbox=bbox[None], bbox_score=np.ones(1, dtype=np.float32)) data = model.pipeline(data_info) data = model.data_preprocessor(data) inputs_list.append(data["inputs"]) samples_list.append(data["data_samples"]) inputs = torch.cat(inputs_list, dim=0) with torch.no_grad(): pred = model(inputs) pred = pred.cpu().numpy() keypoints, scores = [], [] for i, sample in enumerate(samples_list): kpts_i, scr_i = model.codec.decode(pred[i]) meta = sample["meta"] kpts_i = kpts_i / meta["input_size"] * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"] keypoints.append(kpts_i[0]) scores.append(scr_i[0]) return keypoints, scores # ----------------------------------------------------------------------------- # Gradio handler @spaces.GPU(duration=120) def predict(image: Image.Image, size: str, kpt_thr: float): if image is None: return None, None image_pil = image.convert("RGB") image_rgb = np.array(image_pil) image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR) bboxes = _detect_persons(image_rgb) model = _get_pose_model(size) keypoints, scores = _estimate_pose(image_bgr, bboxes, model) meta = model.pose_metainfo vis_rgb = image_rgb.copy() for bbox, kpts, scr in zip(bboxes, keypoints, scores): x1, y1, x2, y2 = map(int, bbox[:4]) # Scale render sizes with bbox area (per-person), matching v1's recipe. bbox_diag = float(np.sqrt(max(1, x2 - x1) * max(1, y2 - y1))) radius = max(1, int(bbox_diag * 0.006)) thickness = max(1, int(bbox_diag * 0.006)) bbox_thickness = max(1, thickness // 4) # Green bbox first (so skeleton draws on top). cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness) # Skeleton + keypoints for this person only. 

# -----------------------------------------------------------------------------
# Gradio handler


@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str, kpt_thr: float):
    if image is None:
        return None, None

    image_pil = image.convert("RGB")
    image_rgb = np.array(image_pil)
    image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)

    bboxes = _detect_persons(image_rgb)
    model = _get_pose_model(size)
    keypoints, scores = _estimate_pose(image_bgr, bboxes, model)

    meta = model.pose_metainfo
    vis_rgb = image_rgb.copy()
    for bbox, kpts, scr in zip(bboxes, keypoints, scores):
        x1, y1, x2, y2 = map(int, bbox[:4])
        # Scale render sizes with bbox area (per-person), matching v1's recipe.
        bbox_diag = float(np.sqrt(max(1, x2 - x1) * max(1, y2 - y1)))
        radius = max(1, int(bbox_diag * 0.006))
        thickness = max(1, int(bbox_diag * 0.006))
        bbox_thickness = max(1, thickness // 4)
        # Green bbox first (so skeleton draws on top).
        cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), bbox_thickness)
        # Skeleton + keypoints for this person only.
        vis_rgb = visualize_keypoints(
            image=vis_rgb,
            keypoints=[kpts],
            keypoints_visible=[np.ones(len(scr), dtype=bool)],
            keypoint_scores=[scr],
            radius=radius,
            thickness=thickness,
            kpt_thr=kpt_thr,
            skeleton=meta["skeleton_links"],
            kpt_color=meta["keypoint_colors"],
            link_color=meta["skeleton_link_colors"],
        )

    vis_pil = Image.fromarray(vis_rgb)

    instances = [
        {
            "bbox": [float(v) for v in np.asarray(bbox).reshape(-1)[:4]],
            "keypoints": np.asarray(kpts, dtype=float).tolist(),
            "keypoint_scores": np.asarray(s, dtype=float).reshape(-1).tolist(),
        }
        for bbox, kpts, s in zip(bboxes, keypoints, scores)
    ]
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as f:
        json.dump({"instances": instances}, f)
        json_path = f.name
    return vis_pil, json_path
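
# Shape of the downloadable JSON, one entry per detected person (illustrative,
# matching the `instances` construction in `predict` above):
#
#   {"instances": [{"bbox": [x1, y1, x2, y2],
#                   "keypoints": [[x, y], ...],       # 308 [x, y] pairs
#                   "keypoint_scores": [s, ...]}]}    # 308 confidences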

# -----------------------------------------------------------------------------
# UI

EXAMPLES = sorted(
    os.path.join(ASSETS_DIR, "images", n)
    for n in os.listdir(os.path.join(ASSETS_DIR, "images"))
    if n.lower().endswith((".jpg", ".jpeg", ".png"))
)

CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
  font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
  -webkit-font-smoothing: antialiased;
  -moz-osx-font-smoothing: grayscale;
}
#title {
  text-align: center;
  font-size: 44px;
  font-weight: 700;
  letter-spacing: -0.01em;
  margin: 28px 0 4px;
  background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
  -webkit-background-clip: text;
  -webkit-text-fill-color: transparent;
  background-clip: text;
}
#subtitle {
  text-align: center;
  font-size: 12px;
  color: #64748b;
  letter-spacing: 0.18em;
  margin: 0 0 14px;
  text-transform: uppercase;
  font-weight: 500;
}
#badges {
  display: flex;
  justify-content: center;
  flex-wrap: wrap;
  gap: 8px;
  margin: 0 0 32px;
}
.pill {
  display: inline-flex;
  align-items: center;
  gap: 6px;
  padding: 7px 14px;
  border-radius: 999px;
  background: #f1f5f9;
  color: #0f172a !important;
  font-size: 13px;
  font-weight: 500;
  letter-spacing: 0.01em;
  text-decoration: none !important;
  border: 1px solid #e2e8f0;
  transition: background 150ms ease, transform 150ms ease, border-color 150ms ease;
}
.pill:hover {
  background: #0f172a;
  color: #f8fafc !important;
  border-color: #0f172a;
  transform: translateY(-1px);
}
.pill svg { width: 14px; height: 14px; }
"""

HEADER_HTML = """
<div id="title">Sapiens2: Pose</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
  <a class="pill" href="#">Code</a>
  <a class="pill" href="#">🤗 Models</a>
  <a class="pill" href="#">Paper</a>
  <a class="pill" href="#">Project</a>
</div>
"""
""" with gr.Blocks(title="Sapiens2 Pose", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: gr.HTML(HEADER_HTML) with gr.Row(equal_height=True): inp = gr.Image(label="Input", type="pil", height=640) out_img = gr.Image(label="Pose", type="pil", height=640) with gr.Row(): size = gr.Radio( choices=list(POSE_MODELS.keys()), value=DEFAULT_SIZE, label="Model", scale=2, ) thr = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Keypoint threshold", scale=2) run = gr.Button("Run", variant="primary", size="lg", scale=1) gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16) with gr.Accordion("Raw Keypoints", open=False): out_json = gr.File(label="Keypoints (.json — bboxes + 308 keypoints + scores per person)") run.click(predict, inputs=[inp, size, thr], outputs=[out_img, out_json]) if __name__ == "__main__": if torch.cuda.is_available(): torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True demo.launch(share=False)