File size: 12,424 Bytes
c139808
 
397e140
c139808
 
 
 
 
 
 
 
 
397e140
c139808
5f5f544
 
 
c139808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbdd74a
c139808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94742fe
c139808
dbdd74a
c139808
 
 
 
 
 
 
397e140
c139808
 
dbdd74a
c139808
 
 
 
 
 
 
 
 
 
 
 
dbdd74a
 
 
 
 
 
c139808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94742fe
c139808
94742fe
 
c139808
 
 
 
 
 
dbdd74a
 
 
 
 
 
397e140
dbdd74a
 
 
397e140
 
dbdd74a
397e140
dbdd74a
c139808
dbdd74a
c139808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397e140
c139808
 
 
 
397e140
c139808
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397e140
 
c139808
 
dbdd74a
c139808
 
 
 
1bdad69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397e140
c139808
 
 
 
 
 
 
 
 
 
 
 
 
7ca9191
c139808
 
 
 
 
 
 
 
 
 
 
397e140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38409d1
397e140
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7ca9191
397e140
c139808
397e140
94742fe
397e140
 
 
 
 
 
 
fe9554e
397e140
 
 
c139808
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
"""Sapiens2 pose-estimation Gradio Space.

Top-down 308-keypoint pose: DETR finds people, Sapiens2 estimates keypoints
on each crop, and we draw skeleton + keypoints with the GOLIATH palette.
"""

# Block mmpretrain: mmdet's reid modules try `import mmpretrain` inside
# try/except ImportError, but mmpretrain's BLIP language_model.py raises
# TypeError (transformers API drift) — escapes the except and kills the process.
import sys
sys.modules["mmpretrain"] = None

# Use the vendored sapiens2 source.
import os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

import json
import tempfile
from typing import List, Tuple

import cv2
import gradio as gr
import numpy as np
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image

from sapiens.pose.datasets import UDPHeatmap, parse_pose_metainfo
from sapiens.pose.evaluators import nms
from sapiens.pose.models import init_model

from transformers import DetrForObjectDetection, DetrImageProcessor

from pose_render_utils import visualize_keypoints


# -----------------------------------------------------------------------------
# Config

# Bundled assets (model configs, example images) are resolved relative to
# this file rather than the current working directory.
ASSETS_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
CONFIGS_DIR = os.path.join(ASSETS_DIR, "configs")

# One entry per selectable model size:
#   "repo"     - Hub repository id passed to hf_hub_download
#   "filename" - checkpoint file fetched from that repository
#   "config"   - local config file handed to init_model
POSE_MODELS = {
    "0.4B": {
        "repo": "facebook/sapiens2-pose-0.4b",
        "filename": "sapiens2_0.4b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.4b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
    "0.8B": {
        "repo": "facebook/sapiens2-pose-0.8b",
        "filename": "sapiens2_0.8b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_0.8b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
    "1B": {
        "repo": "facebook/sapiens2-pose-1b",
        "filename": "sapiens2_1b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_1b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
    "5B": {
        "repo": "facebook/sapiens2-pose-5b",
        "filename": "sapiens2_5b_pose.safetensors",
        "config": os.path.join(CONFIGS_DIR, "sapiens2_5b_keypoints308_shutterstock_goliath_3po-1024x768.py"),
    },
}
# Size pre-selected in the UI's model radio.
DEFAULT_SIZE = "1B"

DETECTOR_MODEL_ID = "facebook/detr-resnet-50"  # COCO person = label 1

# Models and input tensors are placed on the GPU when one is visible.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Score threshold given to the detector's post-processing step.
BBOX_THR = 0.3
# Threshold given to nms() when pruning person boxes.
NMS_THR = 0.3


# -----------------------------------------------------------------------------
# Model cache

# Loaded pose models, keyed by size label (a key of POSE_MODELS).
_pose_model_cache: dict = {}
# Holds the detector's image processor under "proc" and the model under "model".
_detector_cache: dict = {}
# Parsed keypoint metainfo; stays None until _get_metainfo() first runs.
_metainfo_cache = None


def _get_metainfo():
    """Return the parsed 308-keypoint metainfo, parsing it on the first call.

    The result of ``parse_pose_metainfo`` is memoised in the module-level
    ``_metainfo_cache`` so the config file is only read once.
    """
    global _metainfo_cache
    if _metainfo_cache is not None:
        return _metainfo_cache

    keypoints_cfg = os.path.join(CONFIGS_DIR, "_base_", "keypoints308.py")
    _metainfo_cache = parse_pose_metainfo({"from_file": keypoints_cfg})
    return _metainfo_cache


def _get_detector():
    """Return ``(processor, model)`` for the DETR person detector.

    Both objects are created from ``DETECTOR_MODEL_ID`` on the first call,
    with the model switched to eval mode and moved to ``DEVICE``; later calls
    reuse the instances stored in ``_detector_cache``.
    """
    if "model" in _detector_cache:
        return _detector_cache["proc"], _detector_cache["model"]

    processor = DetrImageProcessor.from_pretrained(DETECTOR_MODEL_ID)
    detector = DetrForObjectDetection.from_pretrained(DETECTOR_MODEL_ID)
    detector = detector.eval().to(DEVICE)

    # Only publish to the cache once both objects were built successfully.
    _detector_cache["proc"] = processor
    _detector_cache["model"] = detector
    return processor, detector


def _get_pose_model(size: str):
    """Return the pose model for ``size``, loading and caching it on first use.

    On a cache miss the checkpoint is fetched with ``hf_hub_download``, the
    model is built with ``init_model`` on ``DEVICE``, and two attributes are
    attached to it before it is stored in ``_pose_model_cache``:
    ``model.codec`` (a ``UDPHeatmap`` built from ``model.cfg.codec``) and
    ``model.pose_metainfo`` (from ``_get_metainfo``).

    Args:
        size: A key of ``POSE_MODELS``.

    Returns:
        The cached model instance for ``size``.

    Raises:
        KeyError: If ``size`` is not a key of ``POSE_MODELS``.
        ValueError: If the model config declares a codec type other than
            ``"UDPHeatmap"``.
    """
    if size in _pose_model_cache:
        return _pose_model_cache[size]

    spec = POSE_MODELS[size]
    ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"])
    model = init_model(spec["config"], ckpt, device=DEVICE)

    codec_cfg = dict(model.cfg.codec)
    # The pop must not live inside an `assert`: asserts are stripped under
    # `python -O`, which would leave "type" in the kwargs and make
    # UDPHeatmap(**codec_cfg) fail on an unexpected keyword.
    codec_type = codec_cfg.pop("type", None)
    if codec_type != "UDPHeatmap":
        raise ValueError(
            f"Unsupported codec type {codec_type!r} in {spec['config']}; "
            "expected 'UDPHeatmap'."
        )
    model.codec = UDPHeatmap(**codec_cfg)
    model.pose_metainfo = _get_metainfo()

    _pose_model_cache[size] = model
    return model


# Eagerly fill the detector and pose-model caches at import time, so the
# request handler finds every selectable size already loaded.
# NOTE(review): this loads all sizes in POSE_MODELS (including 5B) onto DEVICE
# at once — confirm the host has enough memory for that.
print("[startup] pre-loading detector + all pose sizes ...")
_get_detector()
for _size in POSE_MODELS:
    _get_pose_model(_size)
print("[startup] ready.")


# -----------------------------------------------------------------------------
# Inference

def _detect_persons(image_rgb: np.ndarray) -> np.ndarray:
    """Detect people in ``image_rgb`` and return their boxes.

    The image is run through the cached DETR detector; detections scoring at
    least ``BBOX_THR`` with label 1 (person) are kept and pruned with
    ``nms`` at ``NMS_THR``.

    Args:
        image_rgb: Image array; its first two dimensions are used as
            (height, width).

    Returns:
        Array of ``(x1, y1, x2, y2)`` rows, one per kept person. When nothing
        survives, a single float32 box covering the whole image is returned
        so downstream code always has at least one region to process.
    """
    height, width = image_rgb.shape[:2]
    processor, detector = _get_detector()

    batch = processor(images=Image.fromarray(image_rgb), return_tensors="pt").to(DEVICE)
    with torch.no_grad():
        raw_outputs = detector(**batch)

    detections = processor.post_process_object_detection(
        raw_outputs,
        target_sizes=torch.tensor([[height, width]], device=DEVICE),
        threshold=BBOX_THR,
    )[0]

    # Keep only the "person" class, then append the score as a fifth column
    # because nms() consumes boxes and scores together.
    is_person = detections["labels"] == 1
    person_boxes = detections["boxes"][is_person].cpu().numpy()
    person_scores = detections["scores"][is_person].cpu().numpy()
    candidates = np.hstack([person_boxes, person_scores[:, None]])

    kept = candidates[nms(candidates, NMS_THR), :4]
    if kept.shape[0] > 0:
        return kept
    return np.array([[0, 0, width - 1, height - 1]], dtype=np.float32)


def _estimate_pose(image_bgr: np.ndarray, bboxes: np.ndarray, model) -> Tuple[List[np.ndarray], List[np.ndarray]]:
    """Run the pose model on every box and return keypoints in image coordinates.

    Each box goes through ``model.pipeline`` and ``model.data_preprocessor``;
    the resulting inputs are concatenated and sent through the model in one
    forward pass under ``torch.no_grad``. Predictions are decoded with
    ``model.codec`` and mapped back onto the source image using the
    ``input_size``, ``bbox_scale`` and ``bbox_center`` stored in each
    sample's meta.

    Args:
        image_bgr: Source image, handed to the pipeline as ``img``.
        bboxes: Person boxes; one pipeline call is made per box.
        model: Pose model exposing ``pipeline``, ``data_preprocessor`` and
            ``codec`` and callable on the batched inputs.

    Returns:
        ``(keypoints, scores)``: two lists with one entry per box, in the
        order of ``bboxes``. Both lists are empty when ``bboxes`` is empty.
    """
    # torch.cat() rejects an empty list, so return before touching the model.
    if len(bboxes) == 0:
        return [], []

    inputs_list, samples_list = [], []
    for bbox in bboxes:
        # Every box is treated as fully confident (score 1).
        data_info = dict(img=image_bgr, bbox=bbox[None], bbox_score=np.ones(1, dtype=np.float32))
        data = model.pipeline(data_info)
        data = model.data_preprocessor(data)
        inputs_list.append(data["inputs"])
        samples_list.append(data["data_samples"])

    inputs = torch.cat(inputs_list, dim=0)
    with torch.no_grad():
        pred = model(inputs)

    pred = pred.cpu().numpy()
    keypoints, scores = [], []
    for i, sample in enumerate(samples_list):
        kpts_i, scr_i = model.codec.decode(pred[i])
        meta = sample["meta"]
        # Normalise by the network input size, stretch to the box size, then
        # shift so the origin moves from the box's top-left to the image's.
        kpts_i = kpts_i / meta["input_size"] * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"]
        # decode() output carries a leading axis; index 0 is the one used here.
        keypoints.append(kpts_i[0])
        scores.append(scr_i[0])
    return keypoints, scores


# -----------------------------------------------------------------------------
# Gradio handler

@spaces.GPU(duration=120)
def predict(image: Image.Image, size: str, kpt_thr: float):
    """Gradio handler: detect people, estimate poses, render and export them.

    Args:
        image: Input picture from the UI, or ``None`` when nothing was given.
        size: Key of ``POSE_MODELS`` selecting the pose model.
        kpt_thr: Keypoint threshold forwarded to ``visualize_keypoints``.

    Returns:
        ``(rendered, json_path)``: a PIL image with boxes, skeletons and
        keypoints drawn on it, and the path of a temporary ``.json`` file
        holding ``{"instances": [...]}`` with each person's bbox, keypoints
        and keypoint scores. ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None

    rgb = np.array(image.convert("RGB"))
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)

    person_boxes = _detect_persons(rgb)
    pose_model = _get_pose_model(size)
    all_kpts, all_scores = _estimate_pose(bgr, person_boxes, pose_model)

    palette = pose_model.pose_metainfo
    canvas = rgb.copy()
    records = []
    for box, kpts, kpt_scores in zip(person_boxes, all_kpts, all_scores):
        left, top, right, bottom = (int(v) for v in box[:4])

        # Stroke sizes follow the square root of the box area, so each person
        # is drawn proportionally to their own size (matching v1's recipe).
        box_scale = float(np.sqrt(max(1, right - left) * max(1, bottom - top)))
        stroke = max(1, int(box_scale * 0.006))
        box_stroke = max(1, stroke // 4)

        # The green box goes down first so the skeleton ends up on top of it.
        cv2.rectangle(canvas, (left, top), (right, bottom), (0, 255, 0), box_stroke)

        # Draw one person per call, each with their own stroke size.
        canvas = visualize_keypoints(
            image=canvas,
            keypoints=[kpts],
            keypoints_visible=[np.ones(len(kpt_scores), dtype=bool)],
            keypoint_scores=[kpt_scores],
            radius=stroke,
            thickness=stroke,
            kpt_thr=kpt_thr,
            skeleton=palette["skeleton_links"],
            kpt_color=palette["keypoint_colors"],
            link_color=palette["skeleton_link_colors"],
        )

        # Plain-Python copy of this person's result for the JSON export.
        records.append(
            {
                "bbox": list(map(float, np.asarray(box).reshape(-1)[:4])),
                "keypoints": np.asarray(kpts, dtype=float).tolist(),
                "keypoint_scores": np.asarray(kpt_scores, dtype=float).reshape(-1).tolist(),
            }
        )

    rendered = Image.fromarray(canvas)

    # delete=False: the file has to outlive this call so it can be returned
    # by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w") as handle:
        json.dump({"instances": records}, handle)
        export_path = handle.name

    return rendered, export_path


# -----------------------------------------------------------------------------
# UI

# Example gallery: every .jpg/.jpeg/.png in assets/images. All entries share
# the same directory prefix, so ordering the file names orders the paths.
EXAMPLES = [
    os.path.join(ASSETS_DIR, "images", fname)
    for fname in sorted(os.listdir(os.path.join(ASSETS_DIR, "images")))
    if fname.lower().endswith((".jpg", ".jpeg", ".png"))
]

# Stylesheet passed to gr.Blocks(css=...): a global sans-serif font stack plus
# rules for the #title, #subtitle, #badges and .pill elements of HEADER_HTML.
CUSTOM_CSS = """
:root, body, .gradio-container, button, input, select, textarea,
.gradio-container *:not(code):not(pre) {
    font-family: "Helvetica Neue", Helvetica, Arial, sans-serif !important;
    -webkit-font-smoothing: antialiased;
    -moz-osx-font-smoothing: grayscale;
}

#title { text-align: center; font-size: 44px; font-weight: 700;
         letter-spacing: -0.01em; margin: 28px 0 4px;
         background: linear-gradient(90deg, #1d4ed8 0%, #6d28d9 50%, #be185d 100%);
         -webkit-background-clip: text; -webkit-text-fill-color: transparent;
         background-clip: text; }
#subtitle { text-align: center; font-size: 12px; color: #64748b;
            letter-spacing: 0.18em; margin: 0 0 14px; text-transform: uppercase;
            font-weight: 500; }
#badges { display: flex; justify-content: center; flex-wrap: wrap;
          gap: 8px; margin: 0 0 32px; }
.pill { display: inline-flex; align-items: center; gap: 6px;
        padding: 7px 14px; border-radius: 999px;
        background: #f1f5f9; color: #0f172a !important;
        font-size: 13px; font-weight: 500; letter-spacing: 0.01em;
        text-decoration: none !important; border: 1px solid #e2e8f0;
        transition: background 150ms ease, transform 150ms ease, border-color 150ms ease; }
.pill:hover { background: #0f172a; color: #f8fafc !important;
              border-color: #0f172a; transform: translateY(-1px); }
.pill svg { width: 14px; height: 14px; }
"""

# Page header rendered through gr.HTML: title, subtitle and a row of link
# "pills" (Code, Models, Paper, Project). Styled by CUSTOM_CSS.
HEADER_HTML = """
<div id="title">Sapiens2: Pose</div>
<div id="subtitle">ICLR 2026</div>
<div id="badges">
  <a class="pill" href="https://github.com/facebookresearch/sapiens2" target="_blank" rel="noopener">
    <svg viewBox="0 0 24 24" fill="currentColor"><path d="M12 .3a12 12 0 0 0-3.8 23.4c.6.1.8-.3.8-.6v-2c-3.3.7-4-1.6-4-1.6-.6-1.4-1.4-1.8-1.4-1.8-1.1-.7.1-.7.1-.7 1.3.1 2 1.3 2 1.3 1.1 1.9 3 1.4 3.7 1 .1-.8.4-1.4.8-1.7-2.7-.3-5.5-1.3-5.5-5.9 0-1.3.5-2.4 1.3-3.2-.1-.4-.6-1.6.1-3.2 0 0 1-.3 3.3 1.2a11.5 11.5 0 0 1 6 0c2.3-1.5 3.3-1.2 3.3-1.2.7 1.6.2 2.8.1 3.2.8.8 1.3 1.9 1.3 3.2 0 4.6-2.8 5.6-5.5 5.9.4.4.8 1.1.8 2.2v3.3c0 .3.2.7.8.6A12 12 0 0 0 12 .3"/></svg>
    Code
  </a>
  <a class="pill" href="https://huggingface.co/facebook/sapiens2" target="_blank" rel="noopener">
    🤗 Models
  </a>
  <a class="pill" href="https://arxiv.org/pdf/2604.21681" target="_blank" rel="noopener">
    <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M14 2H6a2 2 0 0 0-2 2v16a2 2 0 0 0 2 2h12a2 2 0 0 0 2-2V8z"/><polyline points="14 2 14 8 20 8"/><line x1="9" y1="13" x2="15" y2="13"/><line x1="9" y1="17" x2="15" y2="17"/></svg>
    Paper
  </a>
  <a class="pill" href="https://rawalkhirodkar.github.io/sapiens2" target="_blank" rel="noopener">
    <svg viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><circle cx="12" cy="12" r="10"/><line x1="2" y1="12" x2="22" y2="12"/><path d="M12 2a15.3 15.3 0 0 1 4 10 15.3 15.3 0 0 1-4 10 15.3 15.3 0 0 1-4-10 15.3 15.3 0 0 1 4-10z"/></svg>
    Project
  </a>
</div>
"""

# Build the page. Components are declared in display order: header, image
# pair, controls, example gallery, then the collapsed raw-keypoints download.
with gr.Blocks(title="Sapiens2 Pose", theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HEADER_HTML)

    # Input image and rendered pose share one row.
    with gr.Row(equal_height=True):
        inp = gr.Image(label="Input", type="pil", height=640)
        out_img = gr.Image(label="Pose", type="pil", height=640)

    # Controls: model size, keypoint threshold and the run button.
    with gr.Row():
        size = gr.Radio(
            choices=list(POSE_MODELS.keys()),
            value=DEFAULT_SIZE,
            label="Model",
            scale=2,
        )
        thr = gr.Slider(0.0, 1.0, value=0.3, step=0.05, label="Keypoint threshold", scale=2)
        run = gr.Button("Run", variant="primary", size="lg", scale=1)

    # Example images only populate the input component.
    gr.Examples(examples=EXAMPLES, inputs=inp, examples_per_page=16)

    with gr.Accordion("Raw Keypoints", open=False):
        out_json = gr.File(label="Keypoints (.json — bboxes + 308 keypoints + scores per person)")

    # predict() returns (image, json path), matching the two outputs below.
    run.click(predict, inputs=[inp, size, thr], outputs=[out_img, out_json])


# Script entry point: applies only when this file is run directly, not when
# it is imported.
if __name__ == "__main__":
    if torch.cuda.is_available():
        # Allow TF32 in CUDA matmuls and in cuDNN.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
    demo.launch(share=False)