barakplasma committed on
Commit
95c2945
·
verified ·
1 Parent(s): 8c56094

Add model card with embeddings guide and browser/Node.js/Python examples

  1. README.md +327 -0
README.md ADDED
@@ -0,0 +1,327 @@
+ ---
+ license: other
+ license_name: sapiens2-license
+ license_link: https://github.com/facebookresearch/sapiens2/blob/main/LICENSE.md
+ pipeline_tag: image-feature-extraction
+ library_name: transformers
+ base_model: facebook/sapiens2-pretrain-0.1b
+ tags:
+ - sapiens
+ - sapiens2
+ - vision-transformer
+ - human-centric
+ - feature-extraction
+ - onnx
+ - onnxruntime-web
+ ---
+
+ # Sapiens2-0.1B: ONNX Export
+
+ ONNX export of [facebook/sapiens2-pretrain-0.1b](https://huggingface.co/facebook/sapiens2-pretrain-0.1b), a vision transformer pretrained on **1 billion human images**. This repo provides ready-to-run weights for browser inference via `onnxruntime-web` and for server inference via `onnxruntime-node` or `onnxruntime`.
+
+ | File | Size | Use |
+ |---|---|---|
+ | `sapiens2_0.1b_int8.onnx` | 116 MB | Browser / mobile (recommended) |
+ | `sapiens2_0.1b_fp32.onnx` | 458 MB | Server-side / higher precision |
+ | `example_embeddings.js` | n/a | Fully worked Node.js example |
+
+ **Output:** one 768-dimensional float32 embedding per image (the CLS token), returned as a tensor of shape `(batch, 768)`.
+
+ ---
+
+ ## What are embeddings?
+
+ The model encodes an image into a 768-dimensional vector that captures high-level human-centric semantics such as pose, body shape, clothing, and identity. Two images of similar people in similar poses have embeddings that lie close together in this space. Common uses:
+
+ - **Similarity search**: find the most similar person or pose in a database
+ - **Clustering**: group images by body pose, clothing, or activity
+ - **Classification**: train a lightweight head on top of frozen embeddings
+ - **Retrieval-augmented generation**: image → embedding → nearest-neighbor lookup
+
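As a rough illustration of the similarity-search use case, the sketch below ranks stored embeddings by cosine similarity to a query. It is not part of this repo: the names `cosine_similarity`, `top_k`, and `database` are hypothetical, and toy 3-dimensional vectors stand in for real 768-dimensional embeddings.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query, database, k=3):
    """Return the k (name, score) pairs most similar to the query.

    `database` maps a label to its stored embedding (hypothetical layout).
    """
    scored = [(name, cosine_similarity(query, emb)) for name, emb in database.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy vectors standing in for real 768-dimensional embeddings.
database = {
    "standing": [1.0, 0.0, 0.0],
    "sitting": [0.0, 1.0, 0.0],
    "leaning": [0.9, 0.1, 0.0],
}
matches = top_k([0.8, 0.2, 0.0], database, k=2)
print([name for name, _ in matches])
```

A linear scan like this is usually adequate for a few thousand images; larger collections would typically call for an approximate nearest-neighbor index.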
+ ---
+
+ ## Generating embeddings: Browser
+
+ ```bash
+ npm install onnxruntime-web
+ ```
+
+ ```js
+ import * as ort from "onnxruntime-web";
+
+ // Point WASM binaries at the CDN build (avoids bundler complexity).
+ // Consider pinning the CDN version so it matches the installed package.
+ ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/";
+
+ const MODEL_URL =
+   "https://huggingface.co/barakplasma/sapiens2-onnx/resolve/main/sapiens2_0.1b_int8.onnx";
+
+ const H = 1024, W = 768;
+ const MEAN = [0.485, 0.456, 0.406];
+ const STD = [0.229, 0.224, 0.225];
+
+ // Load once; reuse for all images
+ export async function loadModel() {
+   return ort.InferenceSession.create(MODEL_URL, {
+     executionProviders: ["webgpu", "wasm"], // WebGPU ~1-3s, WASM ~20-60s
+     graphOptimizationLevel: "all",
+   });
+ }
+
+ // Resize an <img> or <canvas> to 1024×768 (H × W) and convert to a float32 NCHW tensor
+ function imageToTensor(source) {
+   const canvas = document.createElement("canvas");
+   canvas.width = W;
+   canvas.height = H;
+   const ctx = canvas.getContext("2d");
+   ctx.drawImage(source, 0, 0, W, H); // stretches to fit; aspect ratio is not preserved
+   const { data } = ctx.getImageData(0, 0, W, H); // RGBA uint8
+
+   const t = new Float32Array(3 * H * W);
+   for (let i = 0; i < H * W; i++) {
+     t[i] = (data[i * 4] / 255 - MEAN[0]) / STD[0]; // R plane
+     t[H * W + i] = (data[i * 4 + 1] / 255 - MEAN[1]) / STD[1]; // G plane
+     t[2 * H * W + i] = (data[i * 4 + 2] / 255 - MEAN[2]) / STD[2]; // B plane
+   }
+   return new ort.Tensor("float32", t, [1, 3, H, W]);
+ }
+
+ // Returns a Float32Array of length 768
+ export async function embed(session, imageElement) {
+   const feeds = { pixel_values: imageToTensor(imageElement) };
+   const { embedding } = await session.run(feeds);
+   return embedding.data;
+ }
+
+ // Cosine similarity between two embeddings (both Float32Array length 768)
+ export function cosineSimilarity(a, b) {
+   let dot = 0, normA = 0, normB = 0;
+   for (let i = 0; i < a.length; i++) {
+     dot += a[i] * b[i];
+     normA += a[i] * a[i];
+     normB += b[i] * b[i];
+   }
+   return dot / (Math.sqrt(normA) * Math.sqrt(normB));
+ }
+
+ // Example: compare two images
+ async function compareImages(imgA, imgB) {
+   const session = await loadModel();
+   // Run one image at a time: overlapping run() calls on a single session may not be supported
+   const embA = await embed(session, imgA);
+   const embB = await embed(session, imgB);
+   const score = cosineSimilarity(embA, embB); // -1 (opposite) to 1 (same direction)
+   console.log(`Similarity: ${score.toFixed(4)}`);
+   return score;
+ }
+ ```
+
+ ### Caching the model in IndexedDB
+
+ The INT8 model is 116 MB. Cache it in IndexedDB to skip the download on repeat visits:
+
+ ```js
+ const DB_NAME = "sapiens2-onnx";
+ const STORE = "models";
+
+ // Open (or create) the database that holds the model bytes
+ async function openDB() {
+   return new Promise((resolve, reject) => {
+     const req = indexedDB.open(DB_NAME, 1);
+     req.onupgradeneeded = () => req.result.createObjectStore(STORE);
+     req.onsuccess = () => resolve(req.result);
+     req.onerror = () => reject(req.error);
+   });
+ }
+
+ export async function loadModelCached(url = MODEL_URL) {
+   const db = await openDB();
+   // Look up cached bytes; a read error is treated as a cache miss
+   const hit = await new Promise(res => {
+     const req = db.transaction(STORE).objectStore(STORE).get(url);
+     req.onsuccess = () => res(req.result);
+     req.onerror = () => res(null);
+   });
+
+   // On a miss, download the model and store the bytes under its URL
+   const buf = hit ?? await fetch(url).then(r => {
+     if (!r.ok) throw new Error(`Model download failed: HTTP ${r.status}`);
+     return r.arrayBuffer();
+   }).then(bytes => {
+     db.transaction(STORE, "readwrite").objectStore(STORE).put(bytes, url);
+     return bytes;
+   });
+
+   return ort.InferenceSession.create(buf, {
+     executionProviders: ["webgpu", "wasm"],
+     graphOptimizationLevel: "all",
+   });
+ }
+ ```
+
+ ---
+
+ ## Generating embeddings: Node.js
+
+ See [`example_embeddings.js`](./example_embeddings.js) in this repo for a fully worked example. Short version:
+
+ ```bash
+ npm install onnxruntime-node sharp
+ ```
+
+ ```js
+ import * as ort from "onnxruntime-node";
+ import sharp from "sharp";
+
+ const H = 1024, W = 768;
+ const MEAN = [0.485, 0.456, 0.406];
+ const STD = [0.229, 0.224, 0.225];
+
+ async function embed(session, imagePath) {
+   const { data, info } = await sharp(imagePath)
+     .removeAlpha() // drop any alpha channel so pixels are 3 bytes each
+     .resize(W, H, { fit: "fill" }) // sharp takes (width, height); "fill" stretches like the browser and Python examples
+     .raw()
+     .toBuffer({ resolveWithObject: true });
+   if (info.channels !== 3) {
+     throw new Error(`Expected 3 channels (RGB), got ${info.channels}`);
+   }
+
+   // Interleaved RGB bytes -> normalized planar NCHW floats
+   const t = new Float32Array(3 * H * W);
+   for (let i = 0; i < H * W; i++) {
+     t[i] = (data[i * 3] / 255 - MEAN[0]) / STD[0];
+     t[H * W + i] = (data[i * 3 + 1] / 255 - MEAN[1]) / STD[1];
+     t[2 * H * W + i] = (data[i * 3 + 2] / 255 - MEAN[2]) / STD[2];
+   }
+
+   const { embedding } = await session.run({
+     pixel_values: new ort.Tensor("float32", t, [1, 3, H, W]),
+   });
+   return embedding.data; // Float32Array of length 768
+ }
+
+ const session = await ort.InferenceSession.create("sapiens2_0.1b_int8.onnx", {
+   executionProviders: ["cpu"],
+ });
+ const emb = await embed(session, "person.jpg");
+ console.log("Embedding length:", emb.length); // 768
+ ```
+
+ ---
+
+ ## Generating embeddings: Python (onnxruntime)
+
+ ```python
+ import onnxruntime as ort
+ import numpy as np
+ from PIL import Image
+
+ H, W = 1024, 768
+ MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
+ STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
+
+ sess = ort.InferenceSession(
+     "sapiens2_0.1b_int8.onnx",
+     providers=["CPUExecutionProvider"],
+ )
+
+ def embed(image_path):
+     # PIL's resize takes (width, height); the resulting array is (H, W, 3)
+     img = np.array(Image.open(image_path).convert("RGB").resize((W, H)), dtype=np.float32)
+     img = (img / 255.0 - MEAN) / STD  # normalize: (H, W, 3)
+     img = np.ascontiguousarray(img.transpose(2, 0, 1)[np.newaxis])  # NCHW: (1, 3, H, W)
+     return sess.run(["embedding"], {"pixel_values": img})[0]  # (1, 768)
+
+ # Compare two images
+ a = embed("person_a.jpg")[0]  # (768,)
+ b = embed("person_b.jpg")[0]
+ similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
+ print(f"Similarity: {similarity:.4f}")  # -1 to 1
+ ```
+
+ ### Batch inference
+
+ ```python
+ def embed_batch(image_paths, batch_size=4):
+     embeddings = []
+     for i in range(0, len(image_paths), batch_size):
+         batch_paths = image_paths[i : i + batch_size]
+         imgs = []
+         for p in batch_paths:
+             img = np.array(Image.open(p).convert("RGB").resize((W, H)), dtype=np.float32)
+             imgs.append((img / 255.0 - MEAN) / STD)
+         batch = np.stack(imgs).transpose(0, 3, 1, 2)  # (B, 3, H, W)
+         out = sess.run(["embedding"], {"pixel_values": batch})[0]  # (B, 768)
+         embeddings.append(out)
+     return np.concatenate(embeddings, axis=0)
+ ```
+
+ ---
+
+ ## Model details
+
+ | Property | Value |
+ |---|---|
+ | **Base model** | [facebook/sapiens2-pretrain-0.1b](https://huggingface.co/facebook/sapiens2-pretrain-0.1b) |
+ | **Architecture** | Vision Transformer (RoPE, GQA, SwiGLU, RMSNorm, QK-norm) |
+ | **Parameters** | 0.114 B |
+ | **FLOPs** | 0.342 T |
+ | **Embedding dim** | 768 |
+ | **Layers / heads** | 12 / 12 |
+ | **Input size** | 1024 × 768 (H × W), RGB, ImageNet-normalized |
+ | **Patch size** | 16 px → 3,072 patch tokens |
+ | **Output** | CLS token: `(batch, 768)` float32 |
+ | **Pretraining data** | 1 billion curated human images |
+ | **ONNX opset** | 18 |
+ | **Exporter** | `torch.onnx.export` (dynamo) |
+ | **Quantization** | `onnxruntime.quantization.quantize_dynamic`, `QInt8` weights |
+
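Two of the figures in the table can be sanity-checked with simple arithmetic. The sketch below derives the patch-token count from the input and patch sizes, and estimates the weight size from the parameter count, assuming 4 bytes per parameter for fp32 and 1 byte for int8. The size estimates are approximations only: they ignore graph metadata and any tensors that dynamic quantization leaves in float.

```python
H, W, PATCH = 1024, 768, 16   # input height, width, and patch size from the table
PARAMS = 0.114e9              # parameter count from the table

rows, cols = H // PATCH, W // PATCH   # 64 rows by 48 columns of patches
patch_tokens = rows * cols

fp32_mb = PARAMS * 4 / 1e6    # 4 bytes per fp32 weight
int8_mb = PARAMS * 1 / 1e6    # 1 byte per int8 weight

print(patch_tokens)           # 3072
print(fp32_mb, int8_mb)       # roughly 456 and 114 (MB)
```

These estimates land close to the 458 MB and 116 MB file sizes listed at the top of this README.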
+ ### Preprocessing spec
+
+ Images must be resized to exactly **1024 × 768 (H × W)** and normalized with ImageNet statistics:
+
+ ```
+ mean = [0.485, 0.456, 0.406] # per channel, RGB order
+ std = [0.229, 0.224, 0.225]
+ pixel_values = (pixel / 255 - mean) / std
+ layout: NCHW float32, shape (batch, 3, 1024, 768)
+ ```
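To make the layout concrete, here is a dependency-free sketch of the same normalization for a single image. The helper name `to_planar_chw` is hypothetical and not part of this repo. It converts interleaved RGB bytes (HWC) into a flat, normalized CHW list, which is the order the JavaScript examples fill their `Float32Array` in.

```python
MEAN = (0.485, 0.456, 0.406)  # per channel, RGB order
STD = (0.229, 0.224, 0.225)


def to_planar_chw(rgb_bytes, height, width):
    """Convert interleaved RGB bytes (HWC) into a flat, normalized CHW list."""
    plane = height * width
    if len(rgb_bytes) != 3 * plane:
        raise ValueError("expected height * width * 3 bytes")
    out = [0.0] * (3 * plane)
    for i in range(plane):          # pixel index in row-major order
        for c in range(3):          # channel: 0 = R, 1 = G, 2 = B
            out[c * plane + i] = (rgb_bytes[i * 3 + c] / 255 - MEAN[c]) / STD[c]
    return out


# A 1 x 2 image: one white pixel followed by one black pixel.
tensor = to_planar_chw(bytes([255, 255, 255, 0, 0, 0]), height=1, width=2)
print(len(tensor))  # 6 values: the R plane, then the G plane, then the B plane
```

For the real model, `height` and `width` would be 1024 and 768, and the result would be reshaped to `(1, 3, 1024, 768)`.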
+
+ ---
+
+ ## Browser requirements
+
+ | | Minimum | Recommended |
+ |---|---|---|
+ | **Browser** | Chrome/Edge 113+ | Chrome 120+ |
+ | **Execution provider** | WASM (CPU) | WebGPU |
+ | **RAM** | 4 GB free | 8 GB |
+ | **INT8 inference time** | ~20–60 s (WASM) | ~1–3 s (WebGPU) |
+
+ WebGPU is available in Chrome/Edge 113+ on desktop. Mobile inference is not recommended at this resolution.
+
+ ---
+
+ ## Export notes
+
+ Two non-obvious issues arose during export from the original safetensors checkpoint:
+
+ 1. **bfloat16 in RoPE**: `RopePositionEmbedding` defaults to `dtype=bfloat16` via `pos_embed_rope_dtype="bf16"`. This is stored as a plain Python attribute (`self.dtype`), not a tensor buffer, so calling `.float()` on the model does not change it. It must be overridden at construction time with `pos_embed_rope_dtype="fp32"`.
+
+ 2. **`aten::rms_norm` unsupported in legacy tracer**: the TorchScript-based exporter (`dynamo=False`) does not support `rms_norm`, so the dynamo-based exporter was used instead. By default this produces a sidecar `.onnx.data` file; the weights were inlined back into a single file via `onnx.save_model(..., save_as_external_data=False)`.
+
+ ---
+
+ ## Sapiens2 model family
+
+ | Model | Params | Embed dim | Layers |
+ |---|---|---|---|
+ | **Sapiens2-0.1B** *(this)* | 0.114 B | 768 | 12 |
+ | [Sapiens2-0.4B](https://huggingface.co/facebook/sapiens2-pretrain-0.4b) | 0.398 B | 1024 | 24 |
+ | [Sapiens2-0.8B](https://huggingface.co/facebook/sapiens2-pretrain-0.8b) | 0.818 B | 1280 | 32 |
+ | [Sapiens2-1B](https://huggingface.co/facebook/sapiens2-pretrain-1b) | 1.462 B | 1536 | 40 |
+ | [Sapiens2-5B](https://huggingface.co/facebook/sapiens2-pretrain-5b) | 5.071 B | 2432 | 56 |
+
+ Only the 0.1B model is practical for browser inference. Larger models require server-side deployment.
+
+ ---
+
+ ## License
+
+ The original weights are released under the [Sapiens2 License](https://github.com/facebookresearch/sapiens2/blob/main/LICENSE.md). This ONNX conversion inherits the same license terms.
+
+ ## Citation
+
+ ```bibtex
+ @article{khirodkarsapiens2,
+   title = {Sapiens2},
+   author = {Khirodkar, Rawal and Wen, He and Martinez, Julieta and Dong, Yuan and Su, Zhaoen and Saito, Shunsuke},
+   journal = {arXiv preprint arXiv:2604.21681},
+   year = {2026}
+ }
+ ```