Spaces:
Running
Running
| // js/engine.js — wraps transformers.js text encoder + hyparquet embedding | |
| // loader behind a small synchronous-looking API. | |
| // | |
| // References (canonical 2025/26 loaders): | |
| // transformers.js v4 release notes (Feb 2026): | |
| // https://huggingface.co/blog/transformersjs-v4 | |
| // CLIPTextModelWithProjection API (added in transformers.js PR #227, | |
| // unchanged in v4): | |
| // https://github.com/huggingface/transformers.js/pull/227 | |
| // hyparquet browser example: | |
| // https://github.com/hyparam/hyparquet#browser-example | |
| // | |
| // We deliberately skip the high-level `pipeline()` API because we want | |
| // just the text-side projection and direct control over the output tensor. | |
| import { | |
| MODELS, | |
| DEFAULT_MODEL_ID, | |
| DATASET_REPO, | |
| DATASET_REVISION, | |
| DEVICE, | |
| imagesUrl, | |
| TRANSFORMERS_JS_CDN, | |
| HYPARQUET_CDN, | |
| } from './config.js'; | |
// Lazily-created import promises. Each slot holds the in-flight or fulfilled
// `import()` promise so concurrent callers share a single module load.
let transformersJsPromise = null;
let hyparquetPromise = null;

/**
 * Dynamically import transformers.js from `TRANSFORMERS_JS_CDN`, memoising
 * the promise so the module is requested at most once per successful load.
 *
 * A failed import is NOT memoised: the slot is cleared on rejection so a
 * later call can try again instead of replaying the same rejection forever.
 * (Whether the retry actually re-fetches is up to the host's module loader.)
 *
 * @returns {Promise<any>} the transformers.js module namespace
 */
function loadTransformersJs() {
  if (!transformersJsPromise) {
    transformersJsPromise = import(TRANSFORMERS_JS_CDN).catch((err) => {
      transformersJsPromise = null;
      throw err;
    });
  }
  return transformersJsPromise;
}

/**
 * Dynamically import hyparquet from `HYPARQUET_CDN`, with the same
 * memoise-on-success / reset-on-failure behaviour as `loadTransformersJs`.
 *
 * @returns {Promise<any>} the hyparquet module namespace
 */
function loadHyparquet() {
  if (!hyparquetPromise) {
    hyparquetPromise = import(HYPARQUET_CDN).catch((err) => {
      hyparquetPromise = null;
      throw err;
    });
  }
  return hyparquetPromise;
}
/**
 * Resolve the absolute Hugging Face URL of a model's embeddings file.
 *
 * The model registry stores only a path relative to the dataset repo; the
 * repo and revision come from `DATASET_REPO` / `DATASET_REVISION`, so the
 * full URL is assembled in exactly one place.
 *
 * @param {{ embeddingPath: string }} model - registry entry
 * @returns {string} `resolve` URL for the file at the pinned revision
 */
function embeddingsUrlFor(model) {
  const datasetRoot = `https://huggingface.co/datasets/${DATASET_REPO}`;
  const { embeddingPath } = model;
  return `${datasetRoot}/resolve/${DATASET_REVISION}/${embeddingPath}`;
}
/**
 * Load a text encoder + tokenizer for the given model entry. The model
 * arg is one of the MODELS registry rows; we use its `repo`, `revision`,
 * `dtype`, `modelClass` and `outputKey` fields. Caller is expected to
 * filter MODELS to only `available: true` entries.
 *
 * @param {object} model — entry from MODELS in config.js
 * @param {(stage: string, progress: number | null) => void} [onProgress]
 *   stage ∈ {'tokenizer', 'model'}, progress in [0, 1] or null when unknown.
 * @returns {Promise<{ tokenizer: any, model: any, encode: (text: string) => Promise<Float32Array> }>}
 * @throws {Error} when `model.modelClass` does not name a loadable class
 *   exported by transformers.js, or (from `encode`) when the model output
 *   contains none of the known embedding keys.
 */
export async function loadEncoder(model, onProgress = () => {}) {
  const tfjs = await loadTransformersJs();
  const { AutoTokenizer, env } = tfjs;

  // Force remote loading from the Hub (no local proxy on a static Space).
  env.allowLocalModels = false;
  env.allowRemoteModels = true;

  const revision = model.revision || 'main';

  onProgress('tokenizer', null);
  const tokenizer = await AutoTokenizer.from_pretrained(model.repo, { revision });

  // Pick the right model loader. `AutoModel` works for encoder-only ONNX
  // exports (like IconClip) but fails on full multimodal CLIPs (Xenova's
  // ports) because they demand pixel_values. `CLIPTextModelWithProjection`
  // loads only the text tower of a CLIPModel and exposes `text_embeds`.
  const ModelClass = tfjs[model.modelClass || 'AutoModel'];
  // Check for the loader method rather than mere truthiness: a typo that
  // happens to match another export (e.g. `env`) would otherwise surface as
  // an opaque "from_pretrained is not a function" TypeError.
  if (typeof ModelClass?.from_pretrained !== 'function') {
    throw new Error(`Unknown modelClass '${model.modelClass}' for ${model.repo}`);
  }

  onProgress('model', null);
  const onnxModel = await ModelClass.from_pretrained(model.repo, {
    revision,
    dtype: model.dtype || 'q8',
    device: DEVICE,
    progress_callback: (p) => {
      // Progress events report a 0–100 percentage; events without a numeric
      // `progress` field are ignored.
      if (p && typeof p.progress === 'number') {
        onProgress('model', p.progress / 100);
      }
    },
  });

  /**
   * Tokenize `text` and run it through the loaded model.
   *
   * @param {string} text
   * @returns {Promise<Float32Array>} the `.data` of the selected output tensor
   */
  async function encode(text) {
    // CLIP text encoders use a fixed 77-token context window — that's
    // the size of the learned positional-embedding table the model was
    // trained against. Every CLIP-family ONNX export on HF requires the
    // input to be padded to 77; transformers.js's standard recipe for
    // these models is `padding: 'max_length', max_length: 77`.
    const inputs = tokenizer([text], {
      padding: 'max_length',
      max_length: 77,
      truncation: true,
    });
    const out = await onnxModel(inputs);

    // The output key varies per model (see config.modelClass): IconClip's
    // encoder-only ONNX exposes `embeddings`; CLIPTextModelWithProjection
    // exposes `text_embeds`. We prefer the model's declared `outputKey`
    // but fall back across known names so the engine is forgiving.
    // NOTE(review): the `last_hidden_state` fallback returns that tensor's
    // data as-is, with no pooling — presumably per-token states rather than
    // a single vector; confirm callers can handle that or set `outputKey`.
    const tensor =
      (model.outputKey && out[model.outputKey]) ||
      out.embeddings ||
      out.text_embeds ||
      out.last_hidden_state;
    if (!tensor || !tensor.data) {
      throw new Error(
        `Unexpected ONNX output for ${model.repo} — expected ` +
          `'${model.outputKey || 'embeddings'}', got keys: ` +
          Object.keys(out).join(', '),
      );
    }
    return tensor.data;
  }

  return { tokenizer, model: onnxModel, encode };
}
/**
 * Load + decode a model's embedding parquet from the benchmark dataset.
 * Schema expected per row: `_id: string, embedding: list<float32>`
 * (dim derived from the first row). Returns a contiguous Float32Array
 * matrix + parallel ids array for cosine.js's row-major linear access.
 *
 * The Cache API is used opportunistically: a cache hit skips the network,
 * and any cache failure (lookup or store) is ignored and falls back to a
 * plain fetch.
 *
 * @param {object} model — entry from MODELS (uses `embeddingPath`)
 * @param {(loaded: number, total: number) => void} [onProgress]
 *   `total` is 0 when the server sends no usable content-length.
 * @returns {Promise<{ ids: string[], matrix: Float32Array, dim: number }>}
 * @throws {Error} on a non-OK HTTP response, an empty parquet, or a row
 *   whose embedding length differs from the first row's.
 */
export async function loadEmbeddings(model, onProgress = () => {}) {
  const url = embeddingsUrlFor(model);

  // Caching is an optimisation only — a cache that cannot be opened or
  // queried must not prevent the embeddings from loading.
  let cached;
  try {
    const cache = await openCache();
    cached = await cache.match(url);
  } catch {
    cached = undefined;
  }

  let buf;
  if (cached) {
    buf = await cached.arrayBuffer();
    onProgress(buf.byteLength, buf.byteLength);
  } else {
    const resp = await fetch(url);
    if (!resp.ok) {
      throw new Error(`Embeddings fetch failed: ${resp.status} ${resp.statusText} (${url})`);
    }
    const total = Number(resp.headers.get('content-length')) || 0;

    if (resp.body) {
      // Stream the body so the caller gets incremental progress.
      const reader = resp.body.getReader();
      const chunks = [];
      let loaded = 0;
      for (;;) {
        const { done, value } = await reader.read();
        if (done) break;
        chunks.push(value);
        loaded += value.byteLength;
        onProgress(loaded, total);
      }
      // Stitch the chunks into one contiguous buffer.
      const bytes = new Uint8Array(loaded);
      let off = 0;
      for (const chunk of chunks) {
        bytes.set(chunk, off);
        off += chunk.byteLength;
      }
      buf = bytes.buffer;
    } else {
      // No readable stream exposed: read in one go, report once.
      buf = await resp.arrayBuffer();
      onProgress(buf.byteLength, total || buf.byteLength);
    }

    // Fire and forget; Cache.put expects a Response and may fail if the
    // browser disallows opaque caching — swallow that, it's an optimisation.
    // The buffer is copied so the cached Response never shares `buf`.
    openCache()
      .then((c) => c.put(url, new Response(buf.slice(0))))
      .catch(() => {});
  }

  const { parquetReadObjects } = await loadHyparquet();
  // hyparquet wants an AsyncBuffer — for an in-memory blob the simplest
  // adapter is to make `slice` resolve synchronously via Promise.resolve.
  const file = {
    byteLength: buf.byteLength,
    slice: (start, end) => Promise.resolve(buf.slice(start, end ?? buf.byteLength)),
  };
  const rows = await parquetReadObjects({ file, columns: ['_id', 'embedding'] });
  if (rows.length === 0) {
    throw new Error('Embeddings parquet is empty.');
  }

  const dim = rows[0].embedding.length;
  const matrix = new Float32Array(rows.length * dim);
  const ids = new Array(rows.length);
  for (let r = 0; r < rows.length; r++) {
    const { _id, embedding } = rows[r];
    // A ragged row would otherwise be packed as NaN (too short) or silently
    // truncated (too long), corrupting every similarity score for that id.
    if (embedding?.length !== dim) {
      throw new Error(
        `Embedding dimension mismatch at row ${r} (${_id}): ` +
          `expected ${dim}, got ${embedding?.length}`,
      );
    }
    ids[r] = _id;
    matrix.set(embedding, r * dim);
  }
  return { ids, matrix, dim };
}
// In-memory store of decoded libraries: slug → Map<_id, svg_text>.
const svgCache = new Map();

/**
 * Load (and cache) the SVG dict for a single library. Lazy — only called
 * when that library is enabled in the filter chips.
 *
 * Lookup order: the in-memory `svgCache`, then the Cache API, then the
 * network. Cache API failures (lookup or store) are ignored because that
 * layer is only an optimisation.
 *
 * @param {string} librarySlug
 * @returns {Promise<Map<string, string>>} Map<_id, svg_text>
 * @throws {Error} when the images parquet responds with a non-OK status
 */
export async function loadLibrarySvgs(librarySlug) {
  if (svgCache.has(librarySlug)) return svgCache.get(librarySlug);

  const url = imagesUrl(librarySlug);

  // A cache that cannot be opened or queried must not block the fetch.
  let cached;
  try {
    const cache = await openCache();
    cached = await cache.match(url);
  } catch {
    cached = undefined;
  }

  let buf;
  if (cached) {
    buf = await cached.arrayBuffer();
  } else {
    const resp = await fetch(url);
    if (!resp.ok) {
      throw new Error(`Images fetch failed for ${librarySlug}: ${resp.status}`);
    }
    buf = await resp.arrayBuffer();
    // Fire and forget — storing a copy for next visit is best-effort.
    openCache()
      .then((c) => c.put(url, new Response(buf.slice(0))))
      .catch(() => {});
  }

  const { parquetReadObjects } = await loadHyparquet();
  // Minimal AsyncBuffer adapter over the in-memory bytes.
  const file = {
    byteLength: buf.byteLength,
    slice: (start, end) => Promise.resolve(buf.slice(start, end ?? buf.byteLength)),
  };
  const rows = await parquetReadObjects({ file, columns: ['_id', 'svg_text'] });

  const map = new Map();
  for (const row of rows) map.set(row._id, row.svg_text);
  svgCache.set(librarySlug, map);
  return map;
}
/**
 * Synchronously look up an already-loaded SVG; never triggers a fetch.
 *
 * The library slug is the part of `id` before the first ':'. The lookup in
 * that library's map uses the full `id`.
 *
 * @param {string} id
 * @returns {string | null} the SVG text, or null when `id` has no ':',
 *   the library is not in `svgCache`, or the library has no such entry
 */
export function getCachedSvg(id) {
  const sep = id.indexOf(':');
  if (sep < 0) return null;

  const library = svgCache.get(id.slice(0, sep));
  if (!library) return null;

  return library.get(id) ?? null;
}
// Bump the cache version any time the upstream parquet schema or
// compression changes — v2 invalidates the v1 ZSTD-compressed copies that
// older visits cached before we switched to snappy.
const CACHE_NAME = 'iconclip-demo-v2';
// Shared prefix of every cache version this app has ever created.
const CACHE_NAME_PREFIX = 'iconclip-demo-';

// Stand-in used when the Cache API is missing or unusable: always misses,
// silently accepts writes.
const NOOP_CACHE = Object.freeze({
  match: () => Promise.resolve(undefined),
  put: () => Promise.resolve(),
});

// Set once the stale-version sweep has been kicked off, so it runs at most
// once per page load rather than on every `openCache()` call.
let staleCachesPurged = false;

/**
 * Open the app's versioned Cache API bucket.
 *
 * Never rejects: when `caches` is unavailable, or opening the bucket throws
 * or rejects, a no-op cache is returned instead, because caching is purely
 * an optimisation for callers.
 *
 * On the first call that sees a usable `caches`, older `iconclip-demo-*`
 * buckets are deleted in the background (best effort, errors ignored).
 *
 * @returns {Promise<{ match: Function, put: Function }>} a Cache, or the
 *   no-op stand-in
 */
function openCache() {
  try {
    if (typeof caches === 'undefined') {
      return Promise.resolve(NOOP_CACHE);
    }

    if (!staleCachesPurged) {
      staleCachesPurged = true;
      // Best-effort cleanup of older cache versions on first run.
      caches
        .keys()
        .then((keys) => {
          for (const k of keys) {
            if (k.startsWith(CACHE_NAME_PREFIX) && k !== CACHE_NAME) {
              caches.delete(k).catch(() => {});
            }
          }
        })
        .catch(() => {});
    }

    return caches.open(CACHE_NAME).catch(() => NOOP_CACHE);
  } catch {
    // Touching `caches` or calling into it threw synchronously.
    return Promise.resolve(NOOP_CACHE);
  }
}