// iconclip-demo / js/engine.js
// Author: NullSense
// Commit 32775f3 (verified): "Publish frontend-only IconClip search demo
// (transformers.js + parquet vectors)"
// js/engine.js — wraps the transformers.js text encoder and the hyparquet
// embedding loader behind a small, promise-based API.
//
// References (canonical 2025/26 loaders):
// transformers.js v4 release notes (Feb 2026):
// https://huggingface.co/blog/transformersjs-v4
// CLIPTextModelWithProjection API (added in transformers.js PR #227,
// unchanged in v4):
// https://github.com/huggingface/transformers.js/pull/227
// hyparquet browser example:
// https://github.com/hyparam/hyparquet#browser-example
//
// We deliberately skip the high-level `pipeline()` API because we want
// just the text-side projection and direct control over the output tensor.
import {
MODELS,
DEFAULT_MODEL_ID,
DATASET_REPO,
DATASET_REVISION,
DEVICE,
imagesUrl,
TRANSFORMERS_JS_CDN,
HYPARQUET_CDN,
} from './config.js';
// Lazy-loaded on first use and memoised. A failed import (e.g. a CDN
// hiccup) clears its memo so the next call retries, instead of replaying
// the same rejection for the rest of the session.
let transformersJsPromise = null;
let hyparquetPromise = null;

/** Dynamically import transformers.js from the configured CDN, once. */
function loadTransformersJs() {
  if (!transformersJsPromise) {
    transformersJsPromise = import(TRANSFORMERS_JS_CDN).catch((err) => {
      transformersJsPromise = null;
      throw err;
    });
  }
  return transformersJsPromise;
}

/** Dynamically import hyparquet from the configured CDN, once. */
function loadHyparquet() {
  if (!hyparquetPromise) {
    hyparquetPromise = import(HYPARQUET_CDN).catch((err) => {
      hyparquetPromise = null;
      throw err;
    });
  }
  return hyparquetPromise;
}
/**
 * Resolve the absolute download URL of a model's embedding parquet.
 * Registry rows carry only a path relative to the benchmark dataset; the
 * dataset repo and revision come from config.js and are joined here, in a
 * single place, so no registry entry needs to know the full URL.
 *
 * @param {object} model — entry from MODELS (uses `embeddingPath`)
 * @returns {string}
 */
function embeddingsUrlFor(model) {
  const { embeddingPath } = model;
  const datasetRoot = `https://huggingface.co/datasets/${DATASET_REPO}`;
  return `${datasetRoot}/resolve/${DATASET_REVISION}/${embeddingPath}`;
}
/**
 * Load a text encoder + tokenizer for the given model entry. The model
 * arg is one of the MODELS registry rows; we use its `repo`, `revision`,
 * `dtype`, `modelClass` and `outputKey` fields. Caller is expected to
 * filter MODELS to only `available: true` entries.
 *
 * @param {object} model — entry from MODELS in config.js
 * @param {(stage: string, progress: number | null) => void} [onProgress]
 *   stage ∈ {'tokenizer', 'model'}, progress in [0, 1] or null when unknown.
 *   NOTE: transformers.js may emit progress per downloaded file, so values
 *   can restart from 0 within one stage.
 * @returns {Promise<{ tokenizer: any, model: any, encode: (text: string) => Promise<Float32Array> }>}
 * @throws {Error} if `model.modelClass` names something transformers.js
 *   does not export.
 */
export async function loadEncoder(model, onProgress = () => {}) {
  const tfjs = await loadTransformersJs();
  const { AutoTokenizer, env } = tfjs;
  // Force remote loading from the Hub (no local proxy on a static Space).
  env.allowLocalModels = false;
  env.allowRemoteModels = true;

  const revision = model.revision || 'main';
  // transformers.js progress events carry a 0–100 `progress` number when
  // known; forward those to the caller as a [0, 1] fraction for `stage`.
  const progressFor = (stage) => (p) => {
    if (p && typeof p.progress === 'number') {
      onProgress(stage, p.progress / 100);
    }
  };

  onProgress('tokenizer', null);
  const tokenizer = await AutoTokenizer.from_pretrained(model.repo, {
    revision,
    progress_callback: progressFor('tokenizer'),
  });

  // Pick the right model loader. `AutoModel` works for encoder-only ONNX
  // exports (like IconClip) but fails on full multimodal CLIPs (Xenova's
  // ports) because they demand pixel_values. `CLIPTextModelWithProjection`
  // loads only the text tower of a CLIPModel and exposes `text_embeds`.
  const ModelClass = tfjs[model.modelClass || 'AutoModel'];
  if (!ModelClass) {
    throw new Error(`Unknown modelClass '${model.modelClass}' for ${model.repo}`);
  }
  onProgress('model', null);
  const onnxModel = await ModelClass.from_pretrained(model.repo, {
    revision,
    dtype: model.dtype || 'q8',
    device: DEVICE,
    progress_callback: progressFor('model'),
  });

  /**
   * Embed one query string.
   * @param {string} text
   * @returns {Promise<Float32Array>} the output tensor's flat data
   * @throws {Error} if no known output key is present, or the chosen output
   *   is not a pooled (rank ≤ 2) embedding.
   */
  async function encode(text) {
    // CLIP text encoders use a fixed 77-token context window — that's
    // the size of the learned positional-embedding table the model was
    // trained against, so the input is padded/truncated to exactly 77
    // tokens (transformers.js's usual recipe for these models).
    const inputs = tokenizer([text], {
      padding: 'max_length',
      max_length: 77,
      truncation: true,
    });
    const out = await onnxModel(inputs);
    // The output key varies per model (see config.modelClass): IconClip's
    // encoder-only ONNX exposes `embeddings`; CLIPTextModelWithProjection
    // exposes `text_embeds`. Prefer the model's declared `outputKey`, then
    // fall back across known names so the engine is forgiving.
    const candidates = [model.outputKey, 'embeddings', 'text_embeds', 'last_hidden_state'];
    const key = candidates.find((k) => k && out[k]);
    const tensor = key === undefined ? undefined : out[key];
    if (!tensor || !tensor.data) {
      throw new Error(
        `Unexpected ONNX output for ${model.repo} — expected ` +
          `'${model.outputKey || 'embeddings'}', got keys: ` +
          Object.keys(out).join(', '),
      );
    }
    // A pooled embedding for a batch of one is [1, dim] (or [dim]). A rank-3
    // output, e.g. an unpooled `last_hidden_state` ([1, seq, hidden]), would
    // flatten into seq×hidden floats and silently corrupt cosine scores.
    const { dims } = tensor;
    if (Array.isArray(dims) && dims.length > 2) {
      throw new Error(
        `ONNX output '${key}' for ${model.repo} has shape [${dims.join(', ')}]; ` +
          'expected a pooled [1, dim] embedding.',
      );
    }
    return tensor.data;
  }

  return { tokenizer, model: onnxModel, encode };
}
/**
 * Load + decode a model's embedding parquet from the benchmark dataset.
 * Schema expected per row: `_id: string, embedding: list<float32>`
 * (dim derived from the first row; every row must match it). Returns a
 * contiguous Float32Array matrix + parallel ids array for cosine.js's
 * row-major linear access.
 *
 * The raw parquet bytes are read from Cache Storage when present and
 * written back there (best-effort) after a network fetch.
 *
 * @param {object} model — entry from MODELS (uses `embeddingPath`)
 * @param {(loaded: number, total: number) => void} [onProgress]
 *   byte counts; `total` is 0 when the response has no Content-Length.
 * @returns {Promise<{ ids: string[], matrix: Float32Array, dim: number }>}
 * @throws {Error} on a non-OK HTTP response, an empty parquet, or a row
 *   whose embedding is missing or differs in length from the first row.
 */
export async function loadEmbeddings(model, onProgress = () => {}) {
  const url = embeddingsUrlFor(model);
  // The cache is only an optimisation: if it cannot be opened or queried,
  // go to the network instead of failing the whole load.
  const cached = await openCache()
    .then((c) => c.match(url))
    .catch(() => undefined);
  let buf;
  if (cached) {
    buf = await cached.arrayBuffer();
    onProgress(buf.byteLength, buf.byteLength);
  } else {
    const resp = await fetch(url);
    if (!resp.ok) {
      throw new Error(`Embeddings fetch failed: ${resp.status} ${resp.statusText} (${url})`);
    }
    const total = Number(resp.headers.get('content-length')) || 0;
    // Stream the body chunk by chunk so the caller can show byte progress.
    const reader = resp.body.getReader();
    const chunks = [];
    let loaded = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      loaded += value.byteLength;
      onProgress(loaded, total);
    }
    const bytes = new Uint8Array(loaded);
    let off = 0;
    for (const c of chunks) {
      bytes.set(c, off);
      off += c.byteLength;
    }
    buf = bytes.buffer;
    // Fire and forget; Cache.put expects a Response and may fail if the
    // browser disallows caching — swallow that, it's an optimisation.
    openCache()
      .then((c) => c.put(url, new Response(buf.slice(0))))
      .catch(() => {});
  }

  const { parquetReadObjects } = await loadHyparquet();
  // hyparquet reads through an AsyncBuffer ({ byteLength, slice }); for a
  // blob already in memory each slice can be an already-resolved promise.
  const file = {
    byteLength: buf.byteLength,
    slice: (start, end) => Promise.resolve(buf.slice(start, end ?? buf.byteLength)),
  };
  const rows = await parquetReadObjects({ file, columns: ['_id', 'embedding'] });
  if (rows.length === 0) {
    throw new Error('Embeddings parquet is empty.');
  }
  const dim = rows[0].embedding?.length ?? 0;
  if (dim === 0) {
    throw new Error(`Embeddings parquet has no vector in its first row (${url}).`);
  }
  const matrix = new Float32Array(rows.length * dim);
  const ids = new Array(rows.length);
  for (let r = 0; r < rows.length; r++) {
    const { _id, embedding } = rows[r];
    // A missing, short or long vector would otherwise leave NaNs in (or
    // drop values from) the matrix and silently skew every cosine score.
    if (!embedding || embedding.length !== dim) {
      throw new Error(
        `Embedding for '${_id}' (row ${r}) has length ${embedding?.length ?? 0}, expected ${dim}.`,
      );
    }
    ids[r] = _id;
    matrix.set(embedding, r * dim);
  }
  return { ids, matrix, dim };
}
// Parsed SVG dicts: library slug → Map<_id, svg_text>. Only fully loaded
// libraries are stored here, so synchronous readers never see a partial map.
const svgCache = new Map();
// In-flight loads: library slug → Promise<Map<_id, svg_text>>. Concurrent
// callers for the same slug share one fetch + parse.
const svgInflight = new Map();

/**
 * Load (and cache) the SVG dict for a single library. Lazy — only called
 * when that library is enabled in the filter chips. Concurrent calls for
 * the same slug share one download; a failed load is not memoised, so a
 * later call retries.
 *
 * @param {string} librarySlug
 * @returns {Promise<Map<string, string>>} Map<_id, svg_text>
 * @throws {Error} if the images parquet cannot be fetched.
 */
export async function loadLibrarySvgs(librarySlug) {
  const loaded = svgCache.get(librarySlug);
  if (loaded) return loaded;
  let pending = svgInflight.get(librarySlug);
  if (!pending) {
    // Drop the in-flight entry once settled so failures can be retried.
    pending = fetchLibrarySvgs(librarySlug).finally(() => svgInflight.delete(librarySlug));
    svgInflight.set(librarySlug, pending);
  }
  return pending;
}

/** Fetch (via Cache Storage when possible), parse and memoise one library's SVG parquet. */
async function fetchLibrarySvgs(librarySlug) {
  const url = imagesUrl(librarySlug);
  // The cache is only an optimisation: fall back to the network if it
  // cannot be opened or queried.
  const cached = await openCache()
    .then((c) => c.match(url))
    .catch(() => undefined);
  let buf;
  if (cached) {
    buf = await cached.arrayBuffer();
  } else {
    const resp = await fetch(url);
    if (!resp.ok) {
      throw new Error(`Images fetch failed for ${librarySlug}: ${resp.status}`);
    }
    buf = await resp.arrayBuffer();
    // Best-effort write-back; a failure here must not break the load.
    openCache()
      .then((c) => c.put(url, new Response(buf.slice(0))))
      .catch(() => {});
  }
  const { parquetReadObjects } = await loadHyparquet();
  // In-memory AsyncBuffer adapter for hyparquet.
  const file = {
    byteLength: buf.byteLength,
    slice: (start, end) => Promise.resolve(buf.slice(start, end ?? buf.byteLength)),
  };
  const rows = await parquetReadObjects({ file, columns: ['_id', 'svg_text'] });
  const map = new Map();
  for (const row of rows) map.set(row._id, row.svg_text);
  svgCache.set(librarySlug, map);
  return map;
}
/**
 * Synchronously look up an SVG whose library has already been loaded.
 * The library slug is the part of `id` before its first colon. Never
 * fetches: returns null when `id` has no colon, when its library is not
 * loaded yet, or when the id is absent from that library.
 *
 * @param {string} id
 * @returns {string | null}
 */
export function getCachedSvg(id) {
  const sep = id.indexOf(':');
  if (sep < 0) return null;
  const library = svgCache.get(id.slice(0, sep));
  if (!library) return null;
  return library.get(id) ?? null;
}
// Bump the cache version any time the upstream parquet schema or
// compression changes — v2 invalidates the v1 ZSTD-compressed copies that
// older visits cached before we switched to snappy.
const CACHE_NAME = 'iconclip-demo-v2';

// Stand-in exposing the same match/put surface, used when the Cache API is
// unavailable or refuses to open, so callers never need to special-case it.
const NO_OP_CACHE = Object.freeze({
  match: () => Promise.resolve(undefined),
  put: () => Promise.resolve(),
});

// Set once the stale-version sweep has been started, so it runs at most
// once per page load instead of on every openCache() call.
let staleCachesSwept = false;

/**
 * Open this demo's Cache Storage bucket, or a no-op stand-in when the
 * Cache API is missing or `caches.open` rejects. On the first call where
 * the Cache API exists, it also starts a best-effort deletion of older
 * `iconclip-demo-*` buckets.
 *
 * @returns {Promise<Cache | { match: () => Promise<undefined>, put: () => Promise<void> }>}
 */
function openCache() {
  if (typeof caches === 'undefined') {
    return Promise.resolve(NO_OP_CACHE);
  }
  if (!staleCachesSwept) {
    staleCachesSwept = true;
    caches
      .keys()
      .then((keys) => {
        for (const k of keys) {
          if (k.startsWith('iconclip-demo-') && k !== CACHE_NAME) {
            caches.delete(k).catch(() => {});
          }
        }
      })
      .catch(() => {});
  }
  // caches.open can reject (e.g. a SecurityError in some private-browsing
  // modes); the cache is an optimisation, so degrade to the no-op cache.
  return caches.open(CACHE_NAME).catch(() => NO_OP_CACHE);
}