barakplasma committed on
Commit
42e82e0
·
verified ·
1 Parent(s): 1b4a282

Rewrite example as browser ES module; remove Node.js/sharp dependency

Browse files
Files changed (1) hide show
  1. example_embeddings.js +107 -100
example_embeddings.js CHANGED
@@ -1,92 +1,139 @@
1
  /**
2
  * example_embeddings.js
3
  *
4
- * Demonstrates loading sapiens2-0.1b ONNX and generating image embeddings
5
- * in Node.js. Compares two images and finds the most similar image in a set.
 
 
 
 
6
  *
7
- * Requirements:
8
- * npm install onnxruntime-node sharp
9
  *
10
  * Usage:
11
- * node example_embeddings.js image_a.jpg image_b.jpg [image_c.jpg ...]
12
- *
13
- * First two images are compared directly.
14
- * If more images are supplied, the image most similar to image_a.jpg is found.
15
  */
16
 
17
- import * as ort from "onnxruntime-node";
18
- import sharp from "sharp";
19
- import path from "path";
20
- import { fileURLToPath } from "url";
21
 
22
- // ── Config ────────────────────────────────────────────────────────────────────
23
 
24
- const MODEL_PATH = path.join(
25
- path.dirname(fileURLToPath(import.meta.url)),
26
- "sapiens2_0.1b_int8.onnx"
27
- );
28
 
29
  const H = 1024;
30
  const W = 768;
31
  const MEAN = [0.485, 0.456, 0.406];
32
  const STD = [0.229, 0.224, 0.225];
33
 
34
- // ── Core functions ────────────────────────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  /**
37
- * Load the ONNX inference session. Reuse the returned session for all images —
38
- * loading takes ~1-2 s and should only happen once.
 
 
 
39
  */
40
- export async function loadModel(modelPath = MODEL_PATH) {
41
- return ort.InferenceSession.create(modelPath, {
42
- executionProviders: ["cpu"],
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  graphOptimizationLevel: "all",
44
  });
45
  }
46
 
47
  /**
48
- * Read an image from disk, resize to 1024×768, and convert to a float32
49
- * NCHW tensor with ImageNet normalization.
50
  *
51
- * @param {string} imagePath Path to any image format supported by sharp.
52
- * @returns {ort.Tensor} Shape (1, 3, 1024, 768).
53
  */
54
- export async function imageToTensor(imagePath) {
55
- const { data } = await sharp(imagePath)
56
- .resize(W, H) // sharp uses (width, height)
57
- .removeAlpha() // drop alpha if present
58
- .raw() // uncompressed RGB bytes
59
- .toBuffer({ resolveWithObject: true });
 
60
 
61
  const t = new Float32Array(3 * H * W);
62
  for (let i = 0; i < H * W; i++) {
63
- t[i] = (data[i * 3] / 255 - MEAN[0]) / STD[0]; // R
64
- t[H * W + i] = (data[i * 3 + 1] / 255 - MEAN[1]) / STD[1]; // G
65
- t[2 * H * W + i] = (data[i * 3 + 2] / 255 - MEAN[2]) / STD[2]; // B
66
  }
67
-
68
  return new ort.Tensor("float32", t, [1, 3, H, W]);
69
  }
70
 
71
  /**
72
- * Run the model on a single image and return its 768-dimensional embedding.
73
  *
74
- * @param {ort.InferenceSession} session
75
- * @param {string} imagePath
76
- * @returns {Float32Array} Length 768.
77
  */
78
- export async function embed(session, imagePath) {
79
- const tensor = await imageToTensor(imagePath);
80
- const { embedding } = await session.run({ pixel_values: tensor });
81
- return embedding.data; // Float32Array
82
  }
83
 
84
  /**
85
- * Cosine similarity between two equal-length Float32Arrays.
86
- * Returns a value in [-1, 1]: 1 = identical direction, 0 = orthogonal, -1 = opposite.
 
 
 
 
87
  */
88
  export function cosineSimilarity(a, b) {
89
- if (a.length !== b.length) throw new Error("Embedding length mismatch");
90
  let dot = 0, normA = 0, normB = 0;
91
  for (let i = 0; i < a.length; i++) {
92
  dot += a[i] * b[i];
@@ -97,8 +144,11 @@ export function cosineSimilarity(a, b) {
97
  }
98
 
99
  /**
100
- * L2-normalize an embedding in place. After normalization you can use a simple
101
- * dot product instead of cosine similarity, which is faster for large databases.
 
 
 
102
  */
103
  export function l2Normalize(v) {
104
  let norm = 0;
@@ -110,60 +160,17 @@ export function l2Normalize(v) {
110
  }
111
 
112
  /**
113
- * Given a query embedding and an array of candidate embeddings, return the
114
- * index of the most similar candidate and its similarity score.
 
 
 
115
  */
116
  export function findMostSimilar(query, candidates) {
117
- let bestIdx = -1;
118
- let bestScore = -Infinity;
119
  for (let i = 0; i < candidates.length; i++) {
120
  const score = cosineSimilarity(query, candidates[i]);
121
- if (score > bestScore) {
122
- bestScore = score;
123
- bestIdx = i;
124
- }
125
  }
126
  return { index: bestIdx, score: bestScore };
127
  }
128
-
129
- // ── Demo ──────────────────────────────────────────────────────────────────────
130
-
131
- async function main() {
132
- const args = process.argv.slice(2);
133
- if (args.length < 2) {
134
- console.error("Usage: node example_embeddings.js image_a.jpg image_b.jpg [more...]");
135
- process.exit(1);
136
- }
137
-
138
- console.log("Loading model...");
139
- const session = await loadModel();
140
- console.log("Model loaded.\n");
141
-
142
- // Embed all provided images
143
- const embeddings = [];
144
- for (const imgPath of args) {
145
- process.stdout.write(`Embedding ${path.basename(imgPath)}... `);
146
- const t0 = Date.now();
147
- const emb = await embed(session, imgPath);
148
- console.log(`done (${Date.now() - t0} ms, dim=${emb.length})`);
149
- embeddings.push(emb);
150
- }
151
-
152
- // Compare first two images
153
- const [a, b] = args;
154
- const score = cosineSimilarity(embeddings[0], embeddings[1]);
155
- console.log(`\nCosine similarity between ${path.basename(a)} and ${path.basename(b)}: ${score.toFixed(4)}`);
156
- console.log(score > 0.85 ? " β†’ Very similar" : score > 0.6 ? " β†’ Somewhat similar" : " β†’ Dissimilar");
157
-
158
- // If more than 2 images: find the most similar to the first
159
- if (args.length > 2) {
160
- const candidates = embeddings.slice(1);
161
- const { index, score: bestScore } = findMostSimilar(embeddings[0], candidates);
162
- console.log(
163
- `\nMost similar to ${path.basename(args[0])}: ${path.basename(args[index + 1])} ` +
164
- `(score=${bestScore.toFixed(4)})`
165
- );
166
- }
167
- }
168
-
169
- main().catch(err => { console.error(err); process.exit(1); });
 
1
  /**
2
  * example_embeddings.js
3
  *
4
+ * Drop-in ES module for browser use. Exports:
5
+ * loadModelCached(url?) — load and cache model in IndexedDB
6
+ * embed(session, source) — get 768-dim Float32Array from any image source
7
+ * cosineSimilarity(a, b) — similarity score in [-1, 1]
8
+ * l2Normalize(v) — normalize so dot product equals cosine similarity
9
+ * findMostSimilar(q, list) — nearest-neighbor in an embedding array
10
  *
11
+ * Requirements: onnxruntime-web (npm install onnxruntime-web)
 
12
  *
13
  * Usage:
14
+ * import { loadModelCached, embed, cosineSimilarity } from "./example_embeddings.js";
15
+ * const session = await loadModelCached();
16
+ * const emb = await embed(session, document.getElementById("myImage"));
 
17
  */
18
 
19
+ import * as ort from "onnxruntime-web";
20
+
21
+ // ── Config ─────────────────────────────────────────────────────────────────
 
22
 
23
+ ort.env.wasm.wasmPaths = "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/";
24
 
25
+ const MODEL_URL =
26
+ "https://huggingface.co/barakplasma/sapiens2-onnx/resolve/main/sapiens2_0.1b_int8.onnx";
 
 
27
 
28
  const H = 1024;
29
  const W = 768;
30
  const MEAN = [0.485, 0.456, 0.406];
31
  const STD = [0.229, 0.224, 0.225];
32
 
33
+ const DB_NAME = "sapiens2-onnx";
34
+ const DB_STORE = "models";
35
+
36
+ // ── IndexedDB helpers ──────────────────────────────────────────────────────
37
+
38
/**
 * Open the IndexedDB database that backs the model cache.
 *
 * The database is requested at version 1. When the upgrade handler runs, it
 * creates the single object store named by DB_STORE.
 *
 * @returns {Promise<IDBDatabase>} Resolves with the open connection; rejects
 *   with the open request's error when the request fails.
 */
function openDB() {
  return new Promise((resolve, reject) => {
    const openRequest = indexedDB.open(DB_NAME, 1);

    openRequest.onupgradeneeded = () => {
      openRequest.result.createObjectStore(DB_STORE);
    };
    openRequest.onsuccess = () => {
      resolve(openRequest.result);
    };
    openRequest.onerror = () => {
      reject(openRequest.error);
    };
  });
}
46
+
47
/**
 * Look up one value in the DB_STORE object store.
 *
 * Reads are best-effort: a failed read request is reported the same way as a
 * missing key, so the returned promise resolves to null instead of rejecting.
 *
 * @param {IDBDatabase} db Open database connection.
 * @param {IDBValidKey} key Key to look up.
 * @returns {Promise<*>} The stored value, or null when absent or unreadable.
 */
function idbGet(db, key) {
  return new Promise((resolve) => {
    const store = db.transaction(DB_STORE).objectStore(DB_STORE);
    const lookup = store.get(key);

    lookup.onsuccess = () => {
      const stored = lookup.result;
      resolve(stored ?? null);
    };
    lookup.onerror = () => {
      resolve(null);
    };
  });
}
54
+
55
/**
 * Store a value under `key` in the DB_STORE object store.
 *
 * The promise settles on the outcome of the *transaction*, not of the put
 * request: a request can succeed and the transaction still abort at commit
 * time (for example when the origin's storage quota is exceeded), in which
 * case nothing was actually written.
 *
 * @param {IDBDatabase} db Open database connection.
 * @param {IDBValidKey} key Key to store the value under.
 * @param {*} value Structured-cloneable value to store.
 * @returns {Promise<void>} Resolves once the write transaction has completed;
 *   rejects with the request or transaction error otherwise.
 */
function idbPut(db, key, value) {
  return new Promise((resolve, reject) => {
    const tx = db.transaction(DB_STORE, "readwrite");
    const req = tx.objectStore(DB_STORE).put(value, key);

    // Surface the most specific error first; later reject() calls are no-ops.
    req.onerror = () => reject(req.error);
    tx.onabort = () =>
      reject(tx.error ?? new Error("IndexedDB write transaction was aborted"));
    // Only a completed transaction means the data is actually stored.
    tx.oncomplete = () => resolve();
  });
}
62
+
63
+ // ── Public API ─────────────────────────────────────────────────────────────
64
 
65
  /**
66
+ * Load the ONNX model. On first call, fetches from HuggingFace and stores the
67
+ * ArrayBuffer in IndexedDB. Subsequent calls load from cache instantly.
68
+ *
69
+ * @param {string} [url] Override the default model URL.
70
+ * @returns {Promise<ort.InferenceSession>}
71
  */
72
/**
 * Load the ONNX model and create an inference session.
 *
 * The model bytes are looked up in IndexedDB under `url`; on a miss they are
 * fetched from `url` and written back to the cache. The cache is strictly
 * best-effort: if IndexedDB cannot be opened, read, or written (e.g. storage
 * is unavailable or the quota is exceeded), the model is still loaded from
 * the network and a warning is logged instead of failing the whole call.
 *
 * @param {string} [url] Override the default model URL.
 * @returns {Promise<ort.InferenceSession>}
 * @throws {Error} When the model has to be fetched and the HTTP response is
 *   not OK.
 */
export async function loadModelCached(url = MODEL_URL) {
  let db = null;
  let buf = null;

  try {
    try {
      db = await openDB();
      buf = await idbGet(db, url);
    } catch (err) {
      // A broken cache must not prevent loading the model from the network.
      console.warn("Model cache unavailable; fetching without cache.", err);
    }

    if (buf == null) {
      const response = await fetch(url);
      if (!response.ok) {
        throw new Error(`Failed to fetch model: ${response.status} ${response.statusText}`);
      }
      buf = await response.arrayBuffer();

      if (db) {
        try {
          await idbPut(db, url, buf);
        } catch (err) {
          // The download succeeded; failing to cache it only costs a re-download.
          console.warn("Could not cache model in IndexedDB.", err);
        }
      }
    }
  } finally {
    // Release the connection so it cannot block later upgrades or deletions.
    db?.close();
  }

  return ort.InferenceSession.create(buf, {
    executionProviders: ["webgpu", "wasm"],
    graphOptimizationLevel: "all",
  });
}
91
 
92
  /**
93
+ * Convert an image source to a float32 NCHW tensor with ImageNet normalization.
94
+ * Accepts anything drawImage() accepts: <img>, <canvas>, ImageBitmap, VideoFrame.
95
  *
96
+ * @param {HTMLImageElement|HTMLCanvasElement|ImageBitmap|VideoFrame} source
97
+ * @returns {ort.Tensor} Shape (1, 3, 1024, 768).
98
  */
99
/**
 * Turn an image source into a normalized float32 tensor in NCHW layout.
 *
 * The source is drawn onto a W×H scratch canvas (stretching it to that size),
 * the RGBA bytes are read back, and the R, G and B channels are each scaled to
 * [0, 1] and normalized with the per-channel MEAN and STD. Alpha is ignored.
 *
 * @param {HTMLImageElement|HTMLCanvasElement|ImageBitmap|VideoFrame} source
 *   Anything accepted by CanvasRenderingContext2D.drawImage().
 * @returns {ort.Tensor} Float32 tensor of shape (1, 3, H, W).
 */
export function imageToTensor(source) {
  const scratch = document.createElement("canvas");
  scratch.width = W;
  scratch.height = H;

  const context = scratch.getContext("2d");
  context.drawImage(source, 0, 0, W, H);
  const rgba = context.getImageData(0, 0, W, H).data; // 4 bytes per pixel

  const plane = H * W; // number of values in one channel plane
  const chw = new Float32Array(3 * plane);
  for (let px = 0; px < plane; px++) {
    const base = px * 4;
    // De-interleave RGBA into three consecutive planes: R, then G, then B.
    for (let c = 0; c < 3; c++) {
      chw[c * plane + px] = (rgba[base + c] / 255 - MEAN[c]) / STD[c];
    }
  }

  return new ort.Tensor("float32", chw, [1, 3, H, W]);
}
115
 
116
  /**
117
+ * Run the model on one image and return its 768-dim embedding.
118
  *
119
+ * @param {ort.InferenceSession} session
120
+ * @param {HTMLImageElement|HTMLCanvasElement|ImageBitmap|VideoFrame} source
121
+ * @returns {Promise<Float32Array>} Length 768.
122
  */
123
/**
 * Compute the embedding of a single image.
 *
 * The image is converted with imageToTensor(), fed to the session as the
 * `pixel_values` input, and the data of the `embedding` output is returned.
 *
 * @param {ort.InferenceSession} session
 * @param {HTMLImageElement|HTMLCanvasElement|ImageBitmap|VideoFrame} source
 * @returns {Promise<Float32Array>} The embedding vector (768 values according
 *   to this module's header).
 */
export async function embed(session, source) {
  const pixelValues = imageToTensor(source);
  const outputs = await session.run({ pixel_values: pixelValues });
  return outputs.embedding.data;
}
127
 
128
  /**
129
+ * Cosine similarity between two embeddings.
130
+ * Returns a value in [-1, 1]: 1 = identical direction, 0 = orthogonal.
131
+ *
132
+ * @param {Float32Array} a
133
+ * @param {Float32Array} b
134
+ * @returns {number}
135
  */
136
  export function cosineSimilarity(a, b) {
 
137
  let dot = 0, normA = 0, normB = 0;
138
  for (let i = 0; i < a.length; i++) {
139
  dot += a[i] * b[i];
 
144
  }
145
 
146
  /**
147
+ * L2-normalize an embedding. After normalizing all vectors in your database,
148
+ * you can use a plain dot product instead of cosine similarity (faster at scale).
149
+ *
150
+ * @param {Float32Array} v
151
+ * @returns {Float32Array}
152
  */
153
  export function l2Normalize(v) {
154
  let norm = 0;
 
160
  }
161
 
162
  /**
163
+ * Find the index and score of the most similar embedding in a list.
164
+ *
165
+ * @param {Float32Array} query
166
+ * @param {Float32Array[]} candidates
167
+ * @returns {{ index: number, score: number }}
168
  */
169
  export function findMostSimilar(query, candidates) {
170
+ let bestIdx = -1, bestScore = -Infinity;
 
171
  for (let i = 0; i < candidates.length; i++) {
172
  const score = cosineSimilarity(query, candidates[i]);
173
+ if (score > bestScore) { bestScore = score; bestIdx = i; }
 
 
 
174
  }
175
  return { index: bestIdx, score: bestScore };
176
  }