Nekochu commited on
Commit
57e2037
·
0 Parent(s):

Initial commit

Browse files
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.jpg filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sapiens2 CPU
3
+ emoji: 🧍
4
+ colorFrom: indigo
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 6.14.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: other
11
+ ---
12
+
13
+ # Sapiens2 CPU
14
+
15
+ Meta's `facebook/sapiens2-*` running on free HF CPU. 15 variants exposed: seg, normal, pointmap, pose across 0.4b, 0.8b, 1b, plus seg-5b, normal-5b, pointmap-5b as INT8 ONNX. Curl-callable with a Bearer token.
16
+
17
+ ## Variants and inference time on the included 6000×4000 demo image
18
+
19
+ | Task | Notes | 0.4b | 0.8b | 1b | 5b (INT8 ONNX) |
20
+ |---|---|---|---|---|---|
21
+ | seg | DOME 29-class body parts | 54 s | 79 s | 142 s | downloads once, then runs |
22
+ | normal | per-pixel surface normals | 50 s | 121 s | 150 s | downloads once, then runs |
23
+ | pointmap | per-pixel XYZ in meters | 76 s | 145 s | 173 s | downloads once, then runs |
24
+ | pose | DETR detect, 308 keypoints | 72 s | 103 s | 174 s | not shipped |
25
+
26
+ 0.4b through 1b run as fp32 PyTorch. 5B runs as INT8 ONNX (5 to 6 GB on disk; fp32 5B would need ~20 GB RAM, more than the free tier provides). Dense tasks share an LRU(2) cache; pose has its own slot and DETR is loaded once. First call per variant downloads then loads (~30-90 s warmup).
27
+
28
+ The model fixes a 1024×768 input tensor (NCHW with H=1024, W=768, a portrait canvas in Meta's convention). Any input is aspect-preserve resized then padded to that.
29
+
30
+ ## CPU-friendly ONNX exports
31
+
32
+ Companion repo: [`WeReCooking/sapiens2-onnx`](https://huggingface.co/WeReCooking/sapiens2-onnx) (private). File naming `<task>_<size>_<precision>.onnx` plus a `.onnx.data` external sidecar. 15 ONNX artifacts shipped: 12 covering 0.4b/0.8b/1b (fp16 for seg-0.4b, fp32 for the rest), and 3 new 5B int8 files (seg, normal, pointmap). Cosine similarity vs PyTorch fp32 is 0.999 or better on all shipped variants.
33
+
34
+ Turnkey CLI built into `app.py` (no sapiens2 / PyTorch dep needed; install `requirements.txt`):
35
+
36
+ ```bash
37
+ export HF_TOKEN=hf_xxx
38
+ python app.py onnx seg 0.4b photo.jpg --output seg_overlay.png
39
+ python app.py onnx normal 1b photo.jpg --output normals.png
40
+ python app.py onnx pointmap 0.8b photo.jpg --output depth.png
41
+ python app.py onnx pose 0.4b photo.jpg --output pose.png
42
+ python app.py onnx seg 5b photo.jpg --output seg_5b.png
43
+ ```
44
+
45
+ ## Curl tests
46
+
47
+ ```bash
48
+ TOKEN="hf_xxx"
49
+ SPACE="https://werecooking-sapiens2-cpu.hf.space"
50
+ IMG="https://huggingface.co/spaces/facebook/sapiens2-seg/resolve/main/assets/images/pexels-alex-green-5699868.jpg"
51
+
52
+ EVT=$(curl -s -X POST "$SPACE/gradio_api/call/predict" \
53
+ -H "Authorization: Bearer $TOKEN" -H "Content-Type: application/json" \
54
+ -d "{\"data\":[{\"path\":\"$IMG\",\"meta\":{\"_type\":\"gradio.FileData\"}},\"seg\",\"0.4b\"]}" \
55
+ | python -c "import sys,json;print(json.load(sys.stdin)['event_id'])")
56
+ curl -sN "$SPACE/gradio_api/call/predict/$EVT" -H "Authorization: Bearer $TOKEN"
57
+ ```
58
+
59
+ ## Logs (SSE)
60
+
61
+ ```bash
62
+ curl -N -H "Authorization: Bearer $TOKEN" "https://huggingface.co/api/spaces/WeReCooking/sapiens2-cpu/logs/build"
63
+ curl -N -H "Authorization: Bearer $TOKEN" "https://huggingface.co/api/spaces/WeReCooking/sapiens2-cpu/logs/run"
64
+ ```
65
+
66
+ ## 5B INT8 ONNX conversion recipe
67
+
68
+ The dense 5B variants ship as INT8 ONNX. To re-run the pipeline:
69
+
70
+ 1. Export fp16 ONNX using lazy fp16 init. Call `torch.set_default_dtype(torch.float16)` before `init_model(cfg, None, device="cpu")`, then stream the safetensors file tensor by tensor into the empty fp16 model. This avoids the ~22 GB fp32 init that OOMs on a 32 GB box. Export with `opset_version=18` and no `dynamic_axes`. Force `sys.stdout.reconfigure(encoding="utf-8")` so torch.onnx's success print does not crash on Windows cp1252.
71
+ 2. Stream cast fp16 to fp32 on disk via `onnx.external_data_helper.load_external_data_for_model` plus per-tensor `numpy_helper`. Peak RAM stays close to a single tensor (~250 MB). Drop Cast(fp16 / fp32) nodes with transitive rename closure so consumers point at the original input.
72
+ 3. Run `quantize.shape_inference.quant_pre_process(skip_onnx_shape=True, skip_optimization=True)`. This routes through ORT symbolic shape inference which understands sapiens2 windowed attention. Vanilla `onnx.shape_inference` errors with `(6144) vs (512)` on the pointmap and normal heads.
73
+ 4. `quantize_dynamic(weight_type=QuantType.QInt8, per_channel=True, op_types_to_quantize=["MatMul"], use_external_data_format=True)`. This lowers to `MatMulIntegerToFloat`, which accepts fp32 input and has no 2D-only filter (unlike `MatMulNBitsQuantizer` which silently skips 3D packed-QKV weights).
74
+
75
+ Pose-5b is not shipped. It uses a different forward signature (single person bbox cropped tensor) and the int8 quantize attempt did not complete on the available hardware.
76
+
77
+ ## Files
78
+
79
+ * `app.py` everything: Gradio Space UI, PyTorch dispatch for 0.4b/0.8b/1b, ORT for 5B, inlined keypoint visualization, plus the `python app.py onnx ...` CLI
80
+ * `requirements.txt` Python deps including `sapiens @ git+https://github.com/facebookresearch/sapiens2.git`
81
+ * `packages.txt` apt deps (`libgl1`, `libglib2.0-0`) installed by the Gradio SDK at build time
app.py ADDED
@@ -0,0 +1,706 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Sapiens2 multi-task CPU: seg / normal / pointmap / pose at 0.4b/0.8b/1b plus 5B INT8 ONNX.
2
+
3
+ 5B (seg, normal, pointmap) runs via INT8 ONNX from WeReCooking/sapiens2-onnx; pose-5b not shipped.
4
+ Pose top-down: DETR finds people, sapiens2 estimates 308 keypoints per crop.
5
+ Lazy-load with LRU cache (keeps 2 dense models + 1 pose model resident).
6
+ Per-task API endpoint via Gradio's auto-API (curl-able with Bearer token).
7
+
8
+ Also exposes a standalone ONNX CLI mode that does not need PyTorch or sapiens2:
9
+ python app.py onnx seg 0.4b photo.jpg --output seg.png
10
+ python app.py onnx pointmap 5b photo.jpg --output depth.png
11
+ """
12
+ # Block mmpretrain: mmdet's reid modules try to import it via try/except ImportError,
13
+ # but mmpretrain raises TypeError on import (transformers API drift) which escapes
14
+ # the except and kills the process.
15
+ import sys
16
+ sys.modules["mmpretrain"] = None
17
+
18
+
19
+ # --- ONNX CLI (standalone, no PyTorch/sapiens2 import) ----------------------
20
+ def _onnx_cli():
21
+ """Run a published sapiens2 ONNX model on a local image. Only needs numpy,
22
+ onnxruntime, huggingface_hub, opencv-python-headless."""
23
+ import argparse
24
+ import os
25
+ import time
26
+ from pathlib import Path
27
+ import numpy as np
28
+ import cv2
29
+ import onnxruntime as ort
30
+ from huggingface_hub import hf_hub_download
31
+
32
+ DEFAULT_REPO = "WeReCooking/sapiens2-onnx"
33
+ PRECISIONS = {("seg", "0.4b"): "fp16"} # only seg-0.4b is fp16; rest fp32 or int8 for 5B
34
+ INPUT_HW = (1024, 768)
35
+
36
+ parser = argparse.ArgumentParser(prog="app.py onnx")
37
+ parser.add_argument("task", choices=["seg", "normal", "pointmap", "pose"])
38
+ parser.add_argument("size", choices=["0.4b", "0.8b", "1b", "5b"])
39
+ parser.add_argument("image", help="Local image path")
40
+ parser.add_argument("--cache-dir", default="./onnx_cache")
41
+ parser.add_argument("--token", default=os.environ.get("HF_TOKEN"))
42
+ parser.add_argument("--output", default=None, help="Save the visualization here")
43
+ parser.add_argument("--repo", default=DEFAULT_REPO)
44
+ args = parser.parse_args(sys.argv[2:])
45
+
46
+ precision = PRECISIONS.get((args.task, args.size), "int8" if args.size == "5b" else "fp32")
47
+ filename = f"{args.task}/{args.task}_{args.size}_{precision}.onnx"
48
+ print(f"[1/3] downloading {filename} from {args.repo}", flush=True)
49
+ t0 = time.time()
50
+ onnx_path = hf_hub_download(repo_id=args.repo, filename=filename, local_dir=args.cache_dir, token=args.token)
51
+ hf_hub_download(repo_id=args.repo, filename=f"{filename}.data", local_dir=args.cache_dir, token=args.token)
52
+ print(f" ready in {time.time()-t0:.1f}s", flush=True)
53
+
54
+ img = cv2.imread(args.image, cv2.IMREAD_COLOR)
55
+ if img is None:
56
+ raise FileNotFoundError(args.image)
57
+ H, W = INPUT_HW
58
+ h0, w0 = img.shape[:2]
59
+ scale = min(W / w0, H / h0)
60
+ new_w, new_h = int(round(w0 * scale)), int(round(h0 * scale))
61
+ resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
62
+ canvas = np.zeros((H, W, 3), dtype=np.uint8)
63
+ top = (H - new_h) // 2
64
+ left = (W - new_w) // 2
65
+ canvas[top:top + new_h, left:left + new_w] = resized
66
+ mean = (123.675, 116.28, 103.53)
67
+ std = (58.395, 57.12, 57.375)
68
+ x = canvas.astype(np.float32)
69
+ for c in range(3):
70
+ x[:, :, c] = (x[:, :, c] - mean[c]) / std[c]
71
+ x = x.transpose(2, 0, 1)[None]
72
+
73
+ print(f"[2/3] ORT forward (input {x.shape} {x.dtype})", flush=True)
74
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
75
+ t0 = time.time()
76
+ out = sess.run(None, {sess.get_inputs()[0].name: x})
77
+ print(f" forward {time.time()-t0:.1f}s, outputs={[o.shape for o in out]}", flush=True)
78
+
79
+ print(f"[3/3] postprocess + preview", flush=True)
80
+ if args.task == "pose":
81
+ heatmaps = out[0][0]
82
+ K, hH, hW = heatmaps.shape
83
+ flat = heatmaps.reshape(K, -1)
84
+ peak = flat.argmax(axis=1)
85
+ ys, xs = np.unravel_index(peak, (hH, hW))
86
+ scores = flat.max(axis=1)
87
+ inp_y = ys * (INPUT_HW[0] / hH)
88
+ inp_x = xs * (INPUT_HW[1] / hW)
89
+ scale_y = h0 / new_h
90
+ scale_x = w0 / new_w
91
+ img_y = (inp_y - top) * scale_y
92
+ img_x = (inp_x - left) * scale_x
93
+ n_visible = int((scores > 0.3).sum())
94
+ print(f" {n_visible}/{K} keypoints above 0.3 confidence (range {scores.min():.3f} to {scores.max():.3f})")
95
+ if args.output:
96
+ for i in range(K):
97
+ if scores[i] < 0.3:
98
+ continue
99
+ cv2.circle(img, (int(img_x[i]), int(img_y[i])), 4, (0, 255, 0), -1)
100
+ cv2.imwrite(args.output, img)
101
+ print(f" saved {args.output}")
102
+ return
103
+
104
+ if args.task == "seg":
105
+ logits = out[0][0]
106
+ class_map = logits.argmax(axis=0).astype(np.int32)
107
+ class_map_crop = class_map[top:top + new_h, left:left + new_w]
108
+ class_map_full = cv2.resize(class_map_crop, (w0, h0), interpolation=cv2.INTER_NEAREST)
109
+ classes = np.unique(class_map_full).tolist()
110
+ print(f" classes detected: {classes[:15]}")
111
+ if args.output:
112
+ palette = (np.random.RandomState(42).rand(29, 3) * 255).astype(np.uint8)
113
+ cv2.imwrite(args.output, palette[class_map_full])
114
+ print(f" saved {args.output}")
115
+ return
116
+
117
+ if args.task == "normal":
118
+ normal_raw = out[0][0].transpose(1, 2, 0)
119
+ norm = np.linalg.norm(normal_raw, axis=2, keepdims=True)
120
+ normal_unit = normal_raw / np.maximum(norm, 1e-8)
121
+ normal_crop = normal_unit[top:top + new_h, left:left + new_w]
122
+ normal_full = cv2.resize(normal_crop, (w0, h0), interpolation=cv2.INTER_LINEAR)
123
+ if args.output:
124
+ rgb = (((normal_full + 1.0) / 2.0) * 255).clip(0, 255).astype(np.uint8)
125
+ cv2.imwrite(args.output, rgb)
126
+ print(f" saved {args.output}")
127
+ return
128
+
129
+ # pointmap
130
+ pointmap_rel = out[0][0].transpose(1, 2, 0)
131
+ s = out[1][0, 0] if len(out) > 1 else 1.0
132
+ pointmap_metric = pointmap_rel / max(float(s), 1e-8)
133
+ z = pointmap_metric[..., 2]
134
+ z_crop = z[top:top + new_h, left:left + new_w]
135
+ z_full = cv2.resize(z_crop, (w0, h0), interpolation=cv2.INTER_LINEAR)
136
+ zmin, zmax = float(z_full.min()), float(z_full.max())
137
+ print(f" Z range: [{zmin:.2f}, {zmax:.2f}] meters")
138
+ if args.output:
139
+ z_norm = ((z_full - zmin) / max(zmax - zmin, 1e-8) * 255).astype(np.uint8)
140
+ cv2.imwrite(args.output, z_norm)
141
+ print(f" saved {args.output}")
142
+
143
+
144
+ if len(sys.argv) > 1 and sys.argv[1] == "onnx":
145
+ _onnx_cli()
146
+ sys.exit(0)
147
+
148
+
149
+ # --- Gradio path -----------------------------------------------------------
150
+ import glob
151
+ import os
152
+ import time
153
+ import traceback
154
+ from pathlib import Path
155
+
156
+ import gradio as gr
157
+ import numpy as np
158
+ from PIL import Image
159
+
160
+ # --- Catalog ----------------------------------------------------------------
161
+ VARIANTS = {
162
+ ("seg", "0.4b"): {"repo": "facebook/sapiens2-seg-0.4b", "filename": "sapiens2_0.4b_seg.safetensors", "config_glob": "**/sapiens2_0.4b_seg*shutterstock*1024x768*.py", "kind": "seg"},
163
+ ("seg", "0.8b"): {"repo": "facebook/sapiens2-seg-0.8b", "filename": "sapiens2_0.8b_seg.safetensors", "config_glob": "**/sapiens2_0.8b_seg*shutterstock*1024x768*.py", "kind": "seg"},
164
+ ("seg", "1b"): {"repo": "facebook/sapiens2-seg-1b", "filename": "sapiens2_1b_seg.safetensors", "config_glob": "**/sapiens2_1b_seg*shutterstock*1024x768*.py", "kind": "seg"},
165
+ ("normal", "0.4b"): {"repo": "facebook/sapiens2-normal-0.4b", "filename": "sapiens2_0.4b_normal.safetensors", "config_glob": "**/sapiens2_0.4b_normal*metasim*1024x768*.py", "kind": "normal"},
166
+ ("normal", "0.8b"): {"repo": "facebook/sapiens2-normal-0.8b", "filename": "sapiens2_0.8b_normal.safetensors", "config_glob": "**/sapiens2_0.8b_normal*metasim*1024x768*.py", "kind": "normal"},
167
+ ("normal", "1b"): {"repo": "facebook/sapiens2-normal-1b", "filename": "sapiens2_1b_normal.safetensors", "config_glob": "**/sapiens2_1b_normal*metasim*1024x768*.py", "kind": "normal"},
168
+ ("pointmap", "0.4b"): {"repo": "facebook/sapiens2-pointmap-0.4b", "filename": "sapiens2_0.4b_pointmap.safetensors", "config_glob": "**/sapiens2_0.4b_pointmap*render_people*1024x768*.py", "kind": "pointmap"},
169
+ ("pointmap", "0.8b"): {"repo": "facebook/sapiens2-pointmap-0.8b", "filename": "sapiens2_0.8b_pointmap.safetensors", "config_glob": "**/sapiens2_0.8b_pointmap*render_people*1024x768*.py", "kind": "pointmap"},
170
+ ("pointmap", "1b"): {"repo": "facebook/sapiens2-pointmap-1b", "filename": "sapiens2_1b_pointmap.safetensors", "config_glob": "**/sapiens2_1b_pointmap*render_people*1024x768*.py", "kind": "pointmap"},
171
+ ("pose", "0.4b"): {"repo": "facebook/sapiens2-pose-0.4b", "filename": "sapiens2_0.4b_pose.safetensors", "config_glob": "**/sapiens2_0.4b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"},
172
+ ("pose", "0.8b"): {"repo": "facebook/sapiens2-pose-0.8b", "filename": "sapiens2_0.8b_pose.safetensors", "config_glob": "**/sapiens2_0.8b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"},
173
+ ("pose", "1b"): {"repo": "facebook/sapiens2-pose-1b", "filename": "sapiens2_1b_pose.safetensors", "config_glob": "**/sapiens2_1b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"},
174
+ # 5B variants run via prebuilt INT8 ONNX from WeReCooking/sapiens2-onnx.
175
+ # fp32 5B PyTorch (~20 GB) won't fit in the free CPU Space's 16 GB; INT8 ONNX is ~5-6 GB.
176
+ # pose-5b is intentionally absent — INT8 wasn't successfully built for it.
177
+ ("seg", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "seg/seg_5b_int8.onnx", "kind": "seg"},
178
+ ("normal", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "normal/normal_5b_int8.onnx", "kind": "normal"},
179
+ ("pointmap", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "pointmap/pointmap_5b_int8.onnx", "kind": "pointmap"},
180
+ }
181
+
182
+ DENSE_KINDS = {"seg", "normal", "pointmap"}
183
+
184
+ _MODELS: dict = {} # (task, size) -> dense model (LRU)
185
+ _POSE_MODELS: dict = {} # (task, size) -> pose model (separate cache so DETR survives)
186
+ _DETECTOR = None # tuple(processor, model) — lazily loaded once
187
+ _POSE_METAINFO = None
188
+ _ORT_SESSIONS: dict = {} # (task, "5b") -> onnxruntime InferenceSession
189
+ _MAX_CACHED = 2
190
+ _DOME_CLASSES_29 = None
191
+
192
+
193
+ _SAPIENS_PKG_ROOT = None
194
+
195
+
196
+ def _sapiens_root() -> Path:
197
+ """Return the directory containing the installed sapiens package."""
198
+ global _SAPIENS_PKG_ROOT
199
+ if _SAPIENS_PKG_ROOT is None:
200
+ import sapiens # imported lazily because it has side effects (mmdet etc.)
201
+ _SAPIENS_PKG_ROOT = Path(sapiens.__file__).resolve().parent
202
+ return _SAPIENS_PKG_ROOT
203
+
204
+
205
+ def _find_config(pattern: str) -> str:
206
+ # cfg_glob comes in as "**/sapiens2_..._1024x768*.py"; rglob applies the leading ** implicitly
207
+ leaf = pattern.split("/")[-1]
208
+ root = _sapiens_root()
209
+ matches = list(root.rglob(leaf))
210
+ if not matches:
211
+ raise FileNotFoundError(f"No config matching {leaf} under {root}")
212
+ return str(matches[0])
213
+
214
+
215
+ def _get_dense_model(task: str, size: str):
216
+ """Lazy-load + LRU-cache for seg/normal/pointmap."""
217
+ key = (task, size)
218
+ if key in _MODELS:
219
+ _MODELS[key] = _MODELS.pop(key)
220
+ return _MODELS[key]
221
+
222
+ spec = VARIANTS[key]
223
+ from sapiens.dense.models import init_model
224
+ if spec["kind"] == "normal":
225
+ from sapiens.dense.models import NormalEstimator # noqa: F401
226
+ elif spec["kind"] == "pointmap":
227
+ from sapiens.dense.models import PointmapEstimator # noqa: F401
228
+
229
+ config = _find_config(spec["config_glob"])
230
+
231
+ from huggingface_hub import hf_hub_download
232
+ local_dir = f"/tmp/sapiens_models/{task}-{size}"
233
+ os.makedirs(local_dir, exist_ok=True)
234
+ ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir)
235
+
236
+ model = init_model(config, ckpt, device="cpu")
237
+
238
+ while len(_MODELS) >= _MAX_CACHED:
239
+ oldest = next(iter(_MODELS))
240
+ del _MODELS[oldest]
241
+ import gc
242
+ gc.collect()
243
+ _MODELS[key] = model
244
+ return model
245
+
246
+
247
+ def _get_pose_metainfo():
248
+ global _POSE_METAINFO
249
+ if _POSE_METAINFO is None:
250
+ from sapiens.pose.datasets import parse_pose_metainfo
251
+ meta_cfg = _find_config("**/pose/configs/**/keypoints308.py")
252
+ import importlib.util
253
+ spec_obj = importlib.util.spec_from_file_location("keypoints308_meta", meta_cfg)
254
+ mod = importlib.util.module_from_spec(spec_obj)
255
+ spec_obj.loader.exec_module(mod)
256
+ ds_info = getattr(mod, "dataset_info", None)
257
+ if ds_info is None:
258
+ raise RuntimeError(f"No dataset_info in {meta_cfg}")
259
+ _POSE_METAINFO = parse_pose_metainfo(ds_info)
260
+ return _POSE_METAINFO
261
+
262
+
263
+ def _get_pose_model(size: str):
264
+ key = ("pose", size)
265
+ if key in _POSE_MODELS:
266
+ return _POSE_MODELS[key]
267
+ spec = VARIANTS[key]
268
+ from sapiens.pose.models import init_model
269
+ from sapiens.pose.datasets import UDPHeatmap
270
+
271
+ config = _find_config(spec["config_glob"])
272
+ from huggingface_hub import hf_hub_download
273
+ local_dir = f"/tmp/sapiens_models/pose-{size}"
274
+ os.makedirs(local_dir, exist_ok=True)
275
+ ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir)
276
+ model = init_model(config, ckpt, device="cpu")
277
+
278
+ codec_cfg = dict(model.cfg.codec)
279
+ assert codec_cfg.pop("type") == "UDPHeatmap"
280
+ model.codec = UDPHeatmap(**codec_cfg)
281
+ model.pose_metainfo = _get_pose_metainfo()
282
+
283
+ # Free the largest cached dense model first if more than one pose model present.
284
+ while len(_POSE_MODELS) >= 1:
285
+ oldest = next(iter(_POSE_MODELS))
286
+ del _POSE_MODELS[oldest]
287
+ import gc
288
+ gc.collect()
289
+ _POSE_MODELS[key] = model
290
+ return model
291
+
292
+
293
+ def _get_detector():
294
+ global _DETECTOR
295
+ if _DETECTOR is None:
296
+ import torch # noqa: F401
297
+ from transformers import DetrImageProcessor, DetrForObjectDetection
298
+ proc = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
299
+ det = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval()
300
+ _DETECTOR = (proc, det)
301
+ return _DETECTOR
302
+
303
+
304
+ def _load_dome_classes():
305
+ global _DOME_CLASSES_29
306
+ if _DOME_CLASSES_29 is None:
307
+ from sapiens.dense.src.datasets.seg.seg_utils import DOME_CLASSES_29
308
+ _DOME_CLASSES_29 = DOME_CLASSES_29
309
+ return _DOME_CLASSES_29
310
+
311
+
312
+ def _get_padding(data_samples):
313
+ ds = data_samples[0] if isinstance(data_samples, list) and data_samples else data_samples
314
+ if hasattr(ds, "padding_size"):
315
+ return tuple(ds.padding_size)
316
+ if hasattr(ds, "metainfo") and isinstance(ds.metainfo, dict):
317
+ if "padding_size" in ds.metainfo:
318
+ return tuple(ds.metainfo["padding_size"])
319
+ if "pad_shape" in ds.metainfo and "img_shape" in ds.metainfo:
320
+ ph, pw = ds.metainfo["pad_shape"][:2]
321
+ ih, iw = ds.metainfo["img_shape"][:2]
322
+ return (0, pw - iw, 0, ph - ih)
323
+ if isinstance(ds, dict):
324
+ meta = ds.get("meta") or ds
325
+ if "padding_size" in meta:
326
+ return tuple(meta["padding_size"])
327
+ return (0, 0, 0, 0)
328
+
329
+
330
+ # --- Per-task inference -----------------------------------------------------
331
+ def _infer_seg(image_bgr, model):
332
+ import torch
333
+ import torch.nn.functional as F
334
+ import cv2
335
+ h0, w0 = image_bgr.shape[:2]
336
+ data = model.pipeline(dict(img=image_bgr))
337
+ data = model.data_preprocessor(data)
338
+ with torch.no_grad():
339
+ logits = model(data["inputs"])
340
+ logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False)
341
+ label_map = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int32)
342
+
343
+ classes = _load_dome_classes()
344
+ palette = np.zeros((256, 3), dtype=np.uint8)
345
+ for cid, meta in classes.items():
346
+ palette[cid] = meta["color"][::-1]
347
+ color_mask = palette[label_map]
348
+ overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0)
349
+ overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
350
+ uniq = sorted(int(c) for c in np.unique(label_map))
351
+ labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes]
352
+ return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}"
353
+
354
+
355
+ def _infer_normal(image_bgr, model):
356
+ import torch
357
+ data = model.pipeline(dict(img=image_bgr))
358
+ data = model.data_preprocessor(data)
359
+ inputs, data_samples = data["inputs"], data["data_samples"]
360
+ if inputs.ndim == 3:
361
+ inputs = inputs.unsqueeze(0)
362
+ with torch.no_grad():
363
+ normal = model(inputs)
364
+ normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
365
+ pl, pr, pt, pb = _get_padding(data_samples)
366
+ normal = normal[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
367
+ normal_hwc = normal.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
368
+ rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
369
+ return Image.fromarray(rgb), f"normal map {rgb.shape}"
370
+
371
+
372
+ def _infer_pointmap(image_bgr, model):
373
+ import torch
374
+ data = model.pipeline(dict(img=image_bgr))
375
+ data = model.data_preprocessor(data)
376
+ inputs, data_samples = data["inputs"], data["data_samples"]
377
+ if inputs.ndim == 3:
378
+ inputs = inputs.unsqueeze(0)
379
+ with torch.no_grad():
380
+ out = model(inputs)
381
+ if isinstance(out, tuple) and len(out) == 2:
382
+ pointmap, scale = out
383
+ pointmap = pointmap / scale.clamp_min(1e-8)
384
+ else:
385
+ pointmap = out
386
+ pl, pr, pt, pb = _get_padding(data_samples)
387
+ pointmap = pointmap[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
388
+ pmap_hwc = pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0)
389
+ z = pmap_hwc[..., 2]
390
+ z_min, z_max = float(z.min()), float(z.max())
391
+ z_norm = (z - z_min) / max(z_max - z_min, 1e-8)
392
+ z_rgb = (z_norm * 255).astype(np.uint8)
393
+ rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1)
394
+ return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]"
395
+
396
+
397
+ # --- 5B INT8 ONNX path -------------------------------------------------------
398
+ def _get_ort_session(task: str):
399
+ """Lazy-load + cache an ORT session for {task}_5b_int8.onnx."""
400
+ key = (task, "5b")
401
+ sess = _ORT_SESSIONS.get(key)
402
+ if sess is not None:
403
+ return sess
404
+ import onnxruntime as ort
405
+ from huggingface_hub import hf_hub_download
406
+ spec = VARIANTS[key]
407
+ cache_dir = os.environ.get("ONNX_5B_CACHE", "/app/onnx_5b")
408
+ os.makedirs(cache_dir, exist_ok=True)
409
+ # Download both the graph .onnx and its external-data sidecar.
410
+ fn = spec["onnx_filename"]
411
+ onnx_path = hf_hub_download(repo_id=spec["onnx_repo"], filename=fn, local_dir=cache_dir)
412
+ hf_hub_download(repo_id=spec["onnx_repo"], filename=fn + ".data", local_dir=cache_dir)
413
+ sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
414
+ _ORT_SESSIONS[key] = sess
415
+ return sess
416
+
417
+
418
+ def _infer_dense_5b(image_bgr, task: str):
419
+ """5B inference: preprocess via the 0.4b PyTorch pipeline (cached), forward via ORT INT8."""
420
+ import torch
421
+ import torch.nn.functional as F
422
+ import cv2
423
+
424
+ # Use the 0.4b model's pipeline+preprocessor for image prep — it's already in cache for warm calls.
425
+ proxy = _get_dense_model(task, "0.4b")
426
+ data = proxy.pipeline(dict(img=image_bgr))
427
+ data = proxy.data_preprocessor(data)
428
+ inputs, data_samples = data["inputs"], data["data_samples"]
429
+ if inputs.ndim == 3:
430
+ inputs = inputs.unsqueeze(0)
431
+
432
+ sess = _get_ort_session(task)
433
+ out = sess.run(None, {sess.get_inputs()[0].name: inputs.float().cpu().numpy()})
434
+
435
+ if task == "seg":
436
+ logits = torch.from_numpy(out[0])
437
+ h0, w0 = image_bgr.shape[:2]
438
+ logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False)
439
+ label_map = logits.argmax(dim=1).squeeze(0).numpy().astype(np.int32)
440
+ classes = _load_dome_classes()
441
+ palette = np.zeros((256, 3), dtype=np.uint8)
442
+ for cid, meta in classes.items():
443
+ palette[cid] = meta["color"][::-1]
444
+ color_mask = palette[label_map]
445
+ overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0)
446
+ overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB)
447
+ uniq = sorted(int(c) for c in np.unique(label_map))
448
+ labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes]
449
+ return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}"
450
+
451
+ if task == "normal":
452
+ normal = torch.from_numpy(out[0])
453
+ normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8)
454
+ pl, pr, pt, pb = _get_padding(data_samples)
455
+ normal = normal[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
456
+ normal_hwc = normal.squeeze(0).numpy().transpose(1, 2, 0)
457
+ rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8)
458
+ return Image.fromarray(rgb), f"normal map {rgb.shape}"
459
+
460
+ # pointmap — ONNX produces (pointmap [1,3,H,W], scale [1,1]); divide to recover metric depths.
461
+ pointmap = torch.from_numpy(out[0])
462
+ if len(out) > 1:
463
+ scale = torch.from_numpy(out[1])
464
+ pointmap = pointmap / scale.clamp_min(1e-8)
465
+ pl, pr, pt, pb = _get_padding(data_samples)
466
+ pointmap = pointmap[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr]
467
+ pmap_hwc = pointmap.squeeze(0).numpy().transpose(1, 2, 0)
468
+ z = pmap_hwc[..., 2]
469
+ z_min, z_max = float(z.min()), float(z.max())
470
+ z_norm = (z - z_min) / max(z_max - z_min, 1e-8)
471
+ z_rgb = (z_norm * 255).astype(np.uint8)
472
+ rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1)
473
+ return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]"
474
+
475
+
476
+ # Inlined from the upstream Meta sample (pose keypoint render).
477
+ # Draws skeleton links + colored keypoints; thickness/radius are picked by the caller.
478
+ def visualize_keypoints(
479
+ image: np.ndarray,
480
+ keypoints,
481
+ keypoints_visible,
482
+ keypoint_scores,
483
+ *,
484
+ radius: int = 4,
485
+ thickness: int = -1,
486
+ color=(255, 0, 0),
487
+ kpt_thr: float = 0.3,
488
+ skeleton: list | None = None,
489
+ kpt_color=None,
490
+ link_color=None,
491
+ show_kpt_idx: bool = False,
492
+ ) -> np.ndarray:
493
+ import cv2
494
+ img = image.copy()
495
+ H, W = img.shape[:2]
496
+ if skeleton is None:
497
+ skeleton = []
498
+ if kpt_color is None:
499
+ kpt_color = color
500
+ if link_color is None:
501
+ link_color = (0, 255, 0)
502
+
503
+ def _as_color_list(c, n):
504
+ if hasattr(c, "detach"):
505
+ c = c.detach().cpu().numpy()
506
+ if isinstance(c, np.ndarray):
507
+ if c.ndim == 2 and c.shape[1] == 3:
508
+ return [tuple(int(v) for v in row) for row in c.tolist()]
509
+ if c.size == 3:
510
+ return [tuple(int(v) for v in c.tolist())] * max(1, n)
511
+ if isinstance(c, (list, tuple)):
512
+ if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)):
513
+ out = []
514
+ for cc in c:
515
+ cc = np.asarray(cc).reshape(-1)
516
+ out.append(tuple(int(v) for v in cc.tolist()))
517
+ return out
518
+ c_arr = np.asarray(c).reshape(-1)
519
+ if c_arr.size == 3:
520
+ return [tuple(int(v) for v in c_arr.tolist())] * max(1, n)
521
+ return [(255, 0, 0)] * max(1, n)
522
+
523
+ J = keypoints[0].shape[0] if keypoints else 0
524
+ kpt_colors = _as_color_list(kpt_color, J)
525
+ link_colors = _as_color_list(link_color, len(skeleton))
526
+
527
+ def in_bounds(x, y):
528
+ return 0 <= x < W and 0 <= y < H
529
+
530
+ for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores):
531
+ kpts = np.asarray(kpts, float)
532
+ vis = np.asarray(vis).reshape(-1).astype(bool)
533
+ score = np.asarray(score).reshape(-1)
534
+ for lk, (i, j) in enumerate(skeleton):
535
+ if i >= len(kpts) or j >= len(kpts):
536
+ continue
537
+ if not (vis[i] and vis[j]):
538
+ continue
539
+ if score[i] < kpt_thr or score[j] < kpt_thr:
540
+ continue
541
+ x1, y1 = map(int, np.round(kpts[i]))
542
+ x2, y2 = map(int, np.round(kpts[j]))
543
+ if not (in_bounds(x1, y1) and in_bounds(x2, y2)):
544
+ continue
545
+ cv2.line(img, (x1, y1), (x2, y2), link_colors[lk % len(link_colors)],
546
+ thickness=max(1, thickness), lineType=cv2.LINE_AA)
547
+ for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)):
548
+ if not v or s < kpt_thr:
549
+ continue
550
+ x, y = map(int, np.round(xy))
551
+ if not in_bounds(x, y):
552
+ continue
553
+ c = kpt_colors[min(j_idx, len(kpt_colors) - 1)]
554
+ cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA)
555
+ if show_kpt_idx:
556
+ cv2.putText(img, str(j_idx), (x + radius, y - radius),
557
+ cv2.FONT_HERSHEY_SIMPLEX, 0.4, c, 1, cv2.LINE_AA)
558
+ return img
559
+
560
+
561
+ def _detect_persons(image_rgb: np.ndarray, threshold: float = 0.5):
562
+ import torch
563
+ proc, det = _get_detector()
564
+ pil_img = Image.fromarray(image_rgb)
565
+ inputs = proc(images=pil_img, return_tensors="pt")
566
+ with torch.no_grad():
567
+ outputs = det(**inputs)
568
+ target_sizes = torch.tensor([image_rgb.shape[:2]])
569
+ results = proc.post_process_object_detection(
570
+ outputs, target_sizes=target_sizes, threshold=threshold
571
+ )[0]
572
+ person_mask = results["labels"] == 1 # COCO class 1 = person
573
+ boxes = results["boxes"][person_mask].cpu().numpy()
574
+ scores = results["scores"][person_mask].cpu().numpy().reshape(-1, 1)
575
+ if len(boxes) == 0:
576
+ h, w = image_rgb.shape[:2]
577
+ return np.array([[0, 0, w - 1, h - 1, 1.0]], dtype=np.float32)
578
+ return np.concatenate([boxes, scores], axis=1).astype(np.float32)
579
+
580
+
581
+ def _infer_pose(image_bgr, model, kpt_thr: float = 0.3):
582
+ import torch
583
+ import cv2
584
+
585
+ image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
586
+ bboxes = _detect_persons(image_rgb)
587
+ inputs_list, samples_list = [], []
588
+ for bbox in bboxes:
589
+ data_info = dict(img=image_bgr, bbox=bbox[None, :4], bbox_score=np.ones(1, dtype=np.float32))
590
+ data = model.pipeline(data_info)
591
+ data = model.data_preprocessor(data)
592
+ inputs_list.append(data["inputs"])
593
+ samples_list.append(data["data_samples"])
594
+ inputs = torch.cat(inputs_list, dim=0)
595
+ with torch.no_grad():
596
+ pred = model(inputs).cpu().numpy()
597
+
598
+ keypoints, scores = [], []
599
+ for i, sample in enumerate(samples_list):
600
+ kpts_i, scr_i = model.codec.decode(pred[i])
601
+ meta = sample["meta"] if isinstance(sample, dict) else sample.metainfo
602
+ kpts_i = kpts_i / np.array(meta["input_size"]) * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"]
603
+ keypoints.append(kpts_i[0])
604
+ scores.append(scr_i[0])
605
+
606
+ pmeta = model.pose_metainfo
607
+ vis_rgb = image_rgb.copy()
608
+ # Scale render thickness so 308-keypoint dense pose stays visible on high-res input
609
+ short_side = min(vis_rgb.shape[:2])
610
+ radius_px = max(3, short_side // 200)
611
+ thick_px = max(2, short_side // 250)
612
+ box_thick = max(2, short_side // 300)
613
+ for bbox, kpts, scr in zip(bboxes, keypoints, scores):
614
+ x1, y1, x2, y2 = map(int, bbox[:4])
615
+ cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), box_thick)
616
+ vis_rgb = visualize_keypoints(
617
+ image=vis_rgb,
618
+ keypoints=[kpts],
619
+ keypoints_visible=[np.ones(len(scr), dtype=bool)],
620
+ keypoint_scores=[scr],
621
+ radius=radius_px, thickness=thick_px, kpt_thr=kpt_thr,
622
+ skeleton=pmeta["skeleton_links"],
623
+ kpt_color=pmeta["keypoint_colors"],
624
+ link_color=pmeta["skeleton_link_colors"],
625
+ )
626
+ return Image.fromarray(vis_rgb), f"persons={len(bboxes)} | kpts/person={len(keypoints[0]) if keypoints else 0}"
627
+
628
+
629
+ # --- Predict entry point ----------------------------------------------------
630
+ def predict(image: Image.Image, task: str, size: str):
631
+ if image is None:
632
+ return None, "No image provided"
633
+ key = (task, size)
634
+ if key not in VARIANTS:
635
+ return None, f"Unknown variant {task}-{size}. Allowed: {sorted(VARIANTS.keys())}"
636
+ t0 = time.time()
637
+ try:
638
+ import cv2
639
+ image_pil = image.convert("RGB")
640
+ in_w, in_h = image_pil.size
641
+ image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR)
642
+ kind = VARIANTS[key]["kind"]
643
+ if size == "5b":
644
+ out_img, info = _infer_dense_5b(image_bgr, task)
645
+ elif kind == "pose":
646
+ model = _get_pose_model(size)
647
+ out_img, info = _infer_pose(image_bgr, model)
648
+ else:
649
+ model = _get_dense_model(task, size)
650
+ if kind == "seg":
651
+ out_img, info = _infer_seg(image_bgr, model)
652
+ elif kind == "normal":
653
+ out_img, info = _infer_normal(image_bgr, model)
654
+ elif kind == "pointmap":
655
+ out_img, info = _infer_pointmap(image_bgr, model)
656
+ else:
657
+ return None, f"Unhandled kind: {kind}"
658
+ elapsed = time.time() - t0
659
+ out_w, out_h = out_img.size
660
+ return out_img, f"{task}-{size}: done in {elapsed:.1f}s | {in_w}×{in_h} → 1024×768 → {out_w}×{out_h} | {info}"
661
+ except Exception as e:
662
+ return None, f"{type(e).__name__}: {e}\n\n{traceback.format_exc()[:1500]}"
663
+
664
+
665
+ def health():
666
+ return (
667
+ f"Service up | dense cache: {list(_MODELS.keys())} | pose cache: {list(_POSE_MODELS.keys())} | "
668
+ f"detector_loaded={_DETECTOR is not None} | variants={len(VARIANTS)} "
669
+ f"({sorted(set(t for t, _ in VARIANTS))} × {sorted(set(s for _, s in VARIANTS))})"
670
+ )
671
+
672
+
673
+ DEMO_IMAGES = sorted(str(p) for p in Path("/app/assets/images").glob("*.jpg"))
674
+
675
+ with gr.Blocks(title="Sapiens2 CPU", css="""
676
+ #img-in,#img-out{max-height:220px}
677
+ #status-box textarea{max-height:60px!important;min-height:60px!important}
678
+ #status-box{flex-grow:0!important}
679
+ """) as demo:
680
+ with gr.Row(equal_height=False):
681
+ with gr.Column(scale=1):
682
+ img_in = gr.Image(type="pil", label="Input", height=200, elem_id="img-in")
683
+ with gr.Row():
684
+ task_in = gr.Dropdown(choices=["seg", "normal", "pointmap", "pose"], value="seg", label="Task", scale=1)
685
+ size_in = gr.Dropdown(choices=["0.4b", "0.8b", "1b", "5b"], value="0.4b", label="Size", scale=1)
686
+ run_btn = gr.Button("Predict - 1024×768 native", variant="primary")
687
+ gr.Examples(
688
+ examples=[[u] for u in DEMO_IMAGES],
689
+ inputs=[img_in],
690
+ examples_per_page=6,
691
+ cache_examples=False,
692
+ label="Meta demo images",
693
+ )
694
+ with gr.Column(scale=1):
695
+ img_out = gr.Image(type="pil", label="Output", height=200, elem_id="img-out")
696
+ status = gr.Textbox(show_label=False, lines=2, max_lines=2, interactive=False, container=False, placeholder="Status will show here after Predict", elem_id="status-box")
697
+ run_btn.click(
698
+ fn=predict, inputs=[img_in, task_in, size_in], outputs=[img_out, status], api_name="predict"
699
+ )
700
+ # Keep health endpoint accessible via API (no UI button — useless in browser)
701
+ gr.Button(visible=False).click(fn=health, outputs=[gr.Textbox(visible=False)], api_name="health")
702
+
703
+ demo.queue(default_concurrency_limit=1)
704
+
705
+ if __name__ == "__main__":
706
+ demo.launch()
assets/images/pexels-alex-green-5699868.jpg ADDED

Git LFS Details

  • SHA256: 6fcb9449ac2b09fd12f06e1ee615465d8b796fc745c1b76bf4413d84a4cb8a38
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
assets/images/pexels-anntarazevich-4928706.jpg ADDED

Git LFS Details

  • SHA256: 189304446300509d0e9e173c285749802fee8e04e1da0b31f2f6a87b1fd3a14e
  • Pointer size: 131 Bytes
  • Size of remote file: 258 kB
assets/images/pexels-blue-bird-7242832.jpg ADDED

Git LFS Details

  • SHA256: bf4fe2357cf23f52fcb0f26babd16ddf27c09e319a5c5a27a109042075949a38
  • Pointer size: 132 Bytes
  • Size of remote file: 1.9 MB
assets/images/pexels-cibelebergamim-29948140.jpg ADDED

Git LFS Details

  • SHA256: 3b4b1a5e416b9cae6ac2bcb091f57021a3a023fc26cd4536d6cc0035db96e5a1
  • Pointer size: 131 Bytes
  • Size of remote file: 912 kB
assets/images/pexels-cottonbro-4057693.jpg ADDED

Git LFS Details

  • SHA256: 01038cf1dc126c32ef962bfb5dccd7328549163acf611be25e851caf0de20038
  • Pointer size: 132 Bytes
  • Size of remote file: 2.31 MB
assets/images/pexels-cottonbro-6616678.jpg ADDED

Git LFS Details

  • SHA256: 323d49d6ce428c0cea89decbce13170a1a116f9e3a69985fe88cc8263997695b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.27 MB
assets/images/pexels-marcus-aurelius-6787390.jpg ADDED

Git LFS Details

  • SHA256: 97755a9e98ce0f505085064b4583c530f0bb3d89a9e7e2b3269726f0b9f86730
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB
assets/images/pexels-mikhail-nilov-8350316.jpg ADDED

Git LFS Details

  • SHA256: 8d20698000ea36f976da4d6267441d88c0f57f63081580e9f25b4eeaf9f4830b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
assets/images/pexels-shvetsa-5830936.jpg ADDED

Git LFS Details

  • SHA256: 40a233785b63f15ffeb0d4b468a6827f3f13ab13eee0581c23f682704fc79c91
  • Pointer size: 132 Bytes
  • Size of remote file: 2.86 MB
assets/images/pexels-vazhnik-7562218.jpg ADDED

Git LFS Details

  • SHA256: f33f583df710a12397c55c49c529e2ff17a98dde06f2c2e3d08c51ae03f0e49f
  • Pointer size: 131 Bytes
  • Size of remote file: 826 kB
packages.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ libgl1
2
+ libglib2.0-0
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cpu
2
+ torch>=2.7
3
+ torchvision
4
+ pillow
5
+ numpy
6
+ safetensors
7
+ huggingface_hub
8
+ opencv-python-headless
9
+ onnxruntime
10
+ transformers
11
+ sapiens @ git+https://github.com/facebookresearch/sapiens2.git