sapiens2-onnx / example_embeddings.js
barakplasma's picture
Rewrite example as browser ES module; remove Node.js/sharp dependency
42e82e0 verified
/**
* example_embeddings.js
*
* Drop-in ES module for browser use. Exports:
* loadModelCached(url?) — load the model, caching its bytes in IndexedDB
* embed(session, source) — get a 768-dim Float32Array from any drawImage()-compatible image source
* cosineSimilarity(a, b) — similarity score in [-1, 1]
* l2Normalize(v) — normalize so a plain dot product equals cosine similarity
* findMostSimilar(q, list) — nearest neighbor within an array of embeddings
*
* Requirements: onnxruntime-web (npm install onnxruntime-web)
*
* Usage:
* import { loadModelCached, embed, cosineSimilarity } from "./example_embeddings.js";
* const session = await loadModelCached();
* const emb = await embed(session, document.getElementById("myImage"));
*/
import * as ort from "onnxruntime-web";
// ── Config ─────────────────────────────────────────────────────────────────
// Serve the onnxruntime-web .wasm binaries from jsDelivr. They must come from
// the same release as the installed onnxruntime-web JS package: an unversioned
// URL resolves to whatever release is newest and can drift out of sync with the
// JS glue code. Pin to the running version when onnxruntime-web reports it via
// `env.versions.web`; otherwise fall back to the unversioned URL.
const ORT_WEB_VERSION = ort.env.versions?.web;
ort.env.wasm.wasmPaths = ORT_WEB_VERSION
  ? `https://cdn.jsdelivr.net/npm/onnxruntime-web@${ORT_WEB_VERSION}/dist/`
  : "https://cdn.jsdelivr.net/npm/onnxruntime-web/dist/";
// Default model location (the filename suggests the 0.1b int8-quantized export).
const MODEL_URL =
  "https://huggingface.co/barakplasma/sapiens2-onnx/resolve/main/sapiens2_0.1b_int8.onnx";
// Model input size: images are resized to H rows × W columns before inference.
const H = 1024;
const W = 768;
// Per-channel (R, G, B) ImageNet normalization, applied after scaling pixels to [0, 1].
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];
// IndexedDB database and object-store names used for the model cache.
const DB_NAME = "sapiens2-onnx";
const DB_STORE = "models";
// ── IndexedDB helpers ──────────────────────────────────────────────────────
/**
 * Open the model-cache database (version 1), creating its single object store
 * the first time the database is created.
 *
 * @returns {Promise<IDBDatabase>} Rejects with the request's error on failure.
 */
function openDB() {
  return new Promise((resolve, reject) => {
    const request = indexedDB.open(DB_NAME, 1);
    const handleUpgrade = () => {
      // Fires when the database does not yet exist at version 1.
      request.result.createObjectStore(DB_STORE);
    };
    request.onupgradeneeded = handleUpgrade;
    request.onerror = () => reject(request.error);
    request.onsuccess = () => resolve(request.result);
  });
}
/**
 * Read one value from the model-cache store.
 *
 * Best-effort: a missing key and a failed read both resolve to `null`, so
 * callers can treat either case as a cache miss. This promise never rejects
 * because of a request error.
 *
 * @param {IDBDatabase} db
 * @param {IDBValidKey} key
 * @returns {Promise<any|null>}
 */
function idbGet(db, key) {
  return new Promise((resolve) => {
    const store = db.transaction(DB_STORE).objectStore(DB_STORE);
    const request = store.get(key);
    request.onerror = () => resolve(null);
    request.onsuccess = () => {
      const { result } = request;
      resolve(result == null ? null : result);
    };
  });
}
/**
 * Write one value into the model-cache store.
 *
 * Resolves only after the readwrite transaction has committed. A `put`
 * request can report success while its transaction still aborts at commit
 * time (for example with a QuotaExceededError when storing a large model).
 * Resolving on the request alone would hide that failure.
 *
 * @param {IDBDatabase} db
 * @param {IDBValidKey} key
 * @param {any} value Any structured-cloneable value (e.g. an ArrayBuffer).
 * @returns {Promise<void>} Rejects with the transaction's error if it aborts.
 */
function idbPut(db, key, value) {
  return new Promise((resolve, reject) => {
    const tx = db.transaction(DB_STORE, "readwrite");
    tx.objectStore(DB_STORE).put(value, key);
    tx.oncomplete = () => resolve();
    // A failed request aborts the transaction, so every failure path ends here.
    tx.onabort = () => reject(tx.error ?? new Error("IndexedDB transaction aborted"));
  });
}
// ── Public API ─────────────────────────────────────────────────────────────
/**
 * Load the ONNX model and create an inference session for it.
 *
 * On the first call for a given `url`, the model is fetched and its
 * ArrayBuffer is stored in IndexedDB under that URL. Later calls with the same
 * URL read the bytes back from IndexedDB instead of downloading them again.
 *
 * Caching is best-effort. If IndexedDB cannot be opened or the cache write
 * fails (e.g. storage quota exceeded), a warning is logged and the freshly
 * downloaded model is still used.
 *
 * @param {string} [url] Override the default model URL.
 * @returns {Promise<ort.InferenceSession>}
 * @throws {Error} If the model download responds with a non-2xx status.
 */
export async function loadModelCached(url = MODEL_URL) {
  let db = null;
  try {
    db = await openDB();
  } catch (err) {
    console.warn("sapiens2-onnx: IndexedDB unavailable; model will not be cached.", err);
  }
  try {
    let buf = db ? await idbGet(db, url) : null;
    if (buf == null) {
      const res = await fetch(url);
      if (!res.ok) throw new Error(`Failed to fetch model: ${res.status} ${res.statusText}`);
      buf = await res.arrayBuffer();
      if (db) {
        try {
          await idbPut(db, url, buf);
        } catch (err) {
          // The bytes are already in memory; a failed cache write must not block loading.
          console.warn("sapiens2-onnx: failed to cache model in IndexedDB.", err);
        }
      }
    }
    return await ort.InferenceSession.create(buf, {
      executionProviders: ["webgpu", "wasm"],
      graphOptimizationLevel: "all",
    });
  } finally {
    // The connection is only needed for the cache lookup/write above.
    db?.close();
  }
}
/**
 * Convert an image source to a float32 NCHW tensor with ImageNet normalization.
 *
 * The source is drawn stretched to W×H (aspect ratio is not preserved). Each
 * R, G and B value is mapped to (value / 255 - MEAN[c]) / STD[c]; alpha is
 * discarded.
 *
 * Accepts anything drawImage() accepts: <img>, <canvas>, ImageBitmap,
 * VideoFrame. On the main thread a DOM <canvas> is used. Where `document` is
 * undefined (e.g. a Web Worker), OffscreenCanvas is used instead; pass an
 * ImageBitmap or VideoFrame there, since <img> elements do not exist in workers.
 *
 * @param {HTMLImageElement|HTMLCanvasElement|ImageBitmap|VideoFrame} source
 * @returns {ort.Tensor} Shape (1, 3, 1024, 768).
 * @throws {Error} If no canvas implementation or 2D context is available.
 */
export function imageToTensor(source) {
  const canvas = createScratchCanvas(W, H);
  const ctx = canvas.getContext("2d");
  if (!ctx) throw new Error("imageToTensor: 2D canvas context is unavailable");
  ctx.drawImage(source, 0, 0, W, H);
  const { data } = ctx.getImageData(0, 0, W, H); // RGBA uint8, row-major
  const plane = H * W;
  const t = new Float32Array(3 * plane);
  for (let i = 0; i < plane; i++) {
    const px = i * 4;
    t[i] = (data[px] / 255 - MEAN[0]) / STD[0]; // R
    t[plane + i] = (data[px + 1] / 255 - MEAN[1]) / STD[1]; // G
    t[2 * plane + i] = (data[px + 2] / 255 - MEAN[2]) / STD[2]; // B
  }
  return new ort.Tensor("float32", t, [1, 3, H, W]);
}

/** Create a width×height canvas: a DOM <canvas> when available, else an OffscreenCanvas. */
function createScratchCanvas(width, height) {
  if (typeof document !== "undefined") {
    const canvas = document.createElement("canvas");
    canvas.width = width;
    canvas.height = height;
    return canvas;
  }
  if (typeof OffscreenCanvas !== "undefined") return new OffscreenCanvas(width, height);
  throw new Error("imageToTensor: neither document nor OffscreenCanvas is available");
}
/**
 * Compute the embedding for a single image.
 *
 * Preprocesses `source` with imageToTensor(), feeds it to the session as the
 * `pixel_values` input, and returns the raw data of the `embedding` output.
 *
 * @param {ort.InferenceSession} session
 * @param {HTMLImageElement|HTMLCanvasElement|ImageBitmap|VideoFrame} source
 * @returns {Promise<Float32Array>} Length 768.
 */
export async function embed(session, source) {
  const feeds = { pixel_values: imageToTensor(source) };
  const outputs = await session.run(feeds);
  return outputs.embedding.data;
}
/**
 * Cosine similarity between two embeddings.
 * Returns a value in [-1, 1]: 1 = identical direction, 0 = orthogonal,
 * -1 = opposite direction.
 *
 * A zero vector has no direction, so any comparison involving one returns 0
 * instead of NaN. The result is clamped to [-1, 1] to absorb floating-point
 * rounding.
 *
 * @param {Float32Array|number[]} a
 * @param {Float32Array|number[]} b Must have the same length as `a`.
 * @returns {number}
 * @throws {RangeError} If `a` and `b` have different lengths.
 */
export function cosineSimilarity(a, b) {
  if (a.length !== b.length) {
    throw new RangeError(`Embedding length mismatch: ${a.length} vs ${b.length}`);
  }
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  const denom = Math.sqrt(normA) * Math.sqrt(normB);
  if (denom === 0) return 0;
  return Math.min(1, Math.max(-1, dot / denom));
}
/**
 * L2-normalize an embedding. After normalizing all vectors in your database,
 * you can use a plain dot product instead of cosine similarity (faster at scale).
 *
 * A vector whose norm is 0 (e.g. all zeros) cannot be scaled to unit length;
 * it is returned as an all-zero vector rather than a vector of NaNs.
 *
 * @param {Float32Array|number[]} v
 * @returns {Float32Array} A new array; `v` is not modified.
 */
export function l2Normalize(v) {
  let sumSq = 0;
  for (let i = 0; i < v.length; i++) sumSq += v[i] * v[i];
  const norm = Math.sqrt(sumSq);
  const out = new Float32Array(v.length);
  if (norm === 0) return out; // avoid 0 / 0 = NaN
  for (let i = 0; i < v.length; i++) out[i] = v[i] / norm;
  return out;
}
/**
 * Linear scan for the candidate most similar (by cosineSimilarity) to `query`.
 *
 * Ties keep the earliest index. A candidate whose score is NaN is never
 * selected. With no candidates (or only NaN scores), the result is
 * `{ index: -1, score: -Infinity }`.
 *
 * @param {Float32Array} query
 * @param {Float32Array[]} candidates
 * @returns {{ index: number, score: number }}
 */
export function findMostSimilar(query, candidates) {
  let best = { index: -1, score: -Infinity };
  for (let idx = 0; idx < candidates.length; idx += 1) {
    const similarity = cosineSimilarity(query, candidates[idx]);
    if (similarity > best.score) {
      best = { index: idx, score: similarity };
    }
  }
  return best;
}