// Mega-ASR — pure browser ASR
// Loads ONNX models from Reza2kn/mega-asr-onnx via onnxruntime-web,
// the tokenizer + Whisper mel features via @huggingface/transformers,
// and runs the encode/prefill/step pipeline on the user's device.
import { AutoTokenizer, AutoProcessor, RawAudio } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.5.2/dist/transformers.min.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/ort.webgpu.bundle.min.mjs";

ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
ort.env.wasm.simd = true;

const HF_ROOT = "https://huggingface.co/Reza2kn/mega-asr-onnx/resolve/main";
const NUM_LAYERS = 28;
const HIDDEN = 2048;
const VOCAB = 151936;

// Upper bound on decoder steps run after prefill (the prefill itself already
// yields the first generated token).
const MAX_DECODE_STEPS = 80;

// Marker that separates a leading prefix from the transcript in the decoded
// model output; normalize() keeps only the text that follows it.
// NOTE(review): the literal was an empty string in the reviewed source, which
// made normalize() reduce every string to its second character. "<asr_text>"
// is assumed here — confirm against the model's actual output format. If the
// marker never occurs, text is simply passed through unchanged.
const ASR_TEXT_MARKER = "<asr_text>";

// Reference sentences for the bundled example clips, keyed by example name.
const REFERENCES = {
  noise: "I usually take the quieter road home because the main street gets crowded after work.",
  far_field: "Please remind me to print the forms before we leave for the appointment tomorrow.",
  obstructed: "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
  distortion: "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
  recording: "Can you check whether the train still stops at the downtown station after eight tonight?",
  echo: "I need to return these shoes because the size feels fine standing up but terrible while walking.",
  dropout: "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
  mixed: "My sister is bringing dinner over later, so we do not need to cook tonight.",
};

// ---- state -----------------------------------------------------------------
const state = {
  loaded: false,
  loading: false,
  encoder: null,
  prefill: null,
  step: null,
  tokenizer: null,
  processor: null,
  embedI8: null, // Int8Array, indexed as tokenId * HIDDEN + i
  embedScales: null, // Float32Array (expanded from fp16), one scale per token
  manifest: null,
  device: "wasm",
};

/**
 * Append a timestamped line to the on-page log panel and mirror it to the console.
 * @param {string} msg - Message to record.
 */
const log = (msg) => {
  const el = document.getElementById("log");
  const line = document.createElement("div");
  line.textContent = `[${new Date().toLocaleTimeString()}] ${msg}`;
  el.appendChild(line);
  el.scrollTop = el.scrollHeight; // keep the newest entry in view
  console.log(msg);
};

/** Show `s` in the #status element. */
const setStatus = (s) => {
  document.getElementById("status").textContent = s;
};

/** Show `s` in the #loader-status element. */
const setLoaderStatus = (s) => {
  document.getElementById("loader-status").textContent = s;
};

/** Set the width of the #loader-bar element to `pct` percent. */
const setProgress = (pct) => {
  document.getElementById("loader-bar").style.width = pct + "%";
};

// ---- IndexedDB cache for big blobs ----------------------------------------
const DB_NAME = "mega-asr-cache-v2-gptq";
const DB_STORE = "blobs";

/**
 * Open the cache database, creating the object store on first use.
 * @returns {Promise<IDBDatabase>} Open connection; the caller must close it.
 */
function openDB() {
  return new Promise((resolve, reject) => {
    const req = indexedDB.open(DB_NAME, 1);
    req.onupgradeneeded = (e) => {
      const db = e.target.result;
      if (!db.objectStoreNames.contains(DB_STORE)) db.createObjectStore(DB_STORE);
    };
    req.onsuccess = (e) => resolve(e.target.result);
    req.onerror = (e) => reject(e.target.error);
  });
}

/**
 * Read a cached value.
 * @param {string} key - Cache key (the download URL).
 * @returns {Promise<*>} The stored value, or null when absent/falsy.
 */
async function cacheGet(key) {
  const db = await openDB();
  try {
    return await new Promise((resolve, reject) => {
      const tx = db.transaction(DB_STORE, "readonly");
      const req = tx.objectStore(DB_STORE).get(key);
      req.onsuccess = () => resolve(req.result || null);
      req.onerror = () => reject(req.error);
    });
  } finally {
    db.close();
  }
}

/**
 * Store a value in the cache.
 * @param {string} key - Cache key (the download URL).
 * @param {*} blob - Value to store (a Uint8Array in this file).
 * @returns {Promise<void>} Resolves once the write transaction has completed.
 */
async function cachePut(key, blob) {
  const db = await openDB();
  try {
    await new Promise((resolve, reject) => {
      const tx = db.transaction(DB_STORE, "readwrite");
      tx.objectStore(DB_STORE).put(blob, key);
      // Wait for the transaction, not just the request: a successful request
      // can still be followed by an abort (e.g. quota exceeded on commit).
      tx.oncomplete = () => resolve();
      tx.onerror = () => reject(tx.error);
      tx.onabort = () => reject(tx.error || new Error("IndexedDB transaction aborted"));
    });
  } finally {
    db.close();
  }
}

/**
 * Return the bytes at `url`, using the IndexedDB cache when possible.
 *
 * Cache failures (read or write) are logged and otherwise ignored so that a
 * full or unavailable cache never turns a successful download into an error.
 *
 * @param {string} url - Resource to fetch; also used as the cache key.
 * @param {string} label - Human-readable name used in log lines.
 * @param {(fraction: number) => void} [onProgress] - Called with the
 *   downloaded fraction (0..1) when the response declares a content-length.
 * @returns {Promise<Uint8Array>} The resource bytes.
 * @throws {Error} When the HTTP response is not ok.
 */
async function fetchWithCache(url, label, onProgress) {
  const key = url;
  let cached = null;
  try {
    cached = await cacheGet(key);
  } catch (e) {
    log(`cache read failed for ${label} (${e?.message ?? e}); downloading instead`);
  }
  if (cached) {
    log(`cached: ${label}`);
    return cached;
  }

  log(`downloading ${label} ...`);
  const res = await fetch(url);
  if (!res.ok) throw new Error(`${url}: ${res.status}`);
  const total = Number.parseInt(res.headers.get("content-length") || "0", 10);
  const reader = res.body.getReader();
  const chunks = [];
  let read = 0;
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    chunks.push(value);
    read += value.length;
    // Clamp: the declared length can be smaller than the decoded byte count.
    if (total && onProgress) onProgress(Math.min(1, read / total));
  }

  // Concatenate the chunks into one contiguous buffer.
  const buf = new Uint8Array(read);
  let off = 0;
  for (const c of chunks) {
    buf.set(c, off);
    off += c.length;
  }

  try {
    await cachePut(key, buf);
  } catch (e) {
    log(`could not cache ${label} (${e?.message ?? e}); continuing without cache`);
  }
  log(`downloaded ${label} (${(read / 1e6).toFixed(0)} MB)`);
  return buf;
}

// ---- ONNX session creation -------------------------------------------------
// Always prefer the user-selected device; fall back to WASM only for the
// session that fails (per-session, not global). Don't mutate state.device.

/** Execution providers to request, in preference order, for the current device. */
function epList() {
  return state.device === "webgpu" ? ["webgpu", "wasm"] : ["wasm"];
}

/**
 * Create an inference session, retrying once with WASM only when the
 * selected device is WebGPU and the first attempt fails.
 * @param {Uint8Array} graph - Serialized ONNX graph.
 * @param {object} extraOptions - Extra session options merged into both attempts.
 * @param {string} label - Name used in log lines.
 * @returns {Promise<ort.InferenceSession>}
 */
async function createSessionWithFallback(graph, extraOptions, label) {
  try {
    const sess = await ort.InferenceSession.create(graph, {
      executionProviders: epList(),
      ...extraOptions,
    });
    log(`session ready: ${label} (${state.device})`);
    return sess;
  } catch (e) {
    if (state.device !== "webgpu") {
      log(`session create failed for ${label}: ${e.message}`);
      throw e;
    }
    log(`webgpu failed for ${label} (${e.message}); retrying this session with wasm`);
    const sess = await ort.InferenceSession.create(graph, {
      executionProviders: ["wasm"],
      ...extraOptions,
    });
    log(`session ready: ${label} (wasm fallback)`);
    return sess;
  }
}

/**
 * Download a self-contained ONNX graph and create a session for it.
 * @param {string} graphUrl - URL of the .onnx file.
 * @param {string} label - Name used in log lines.
 * @param {(fraction: number) => void} [onProgress] - Download progress callback.
 * @returns {Promise<ort.InferenceSession>}
 */
async function createSessionSimple(graphUrl, label, onProgress) {
  const graph = await fetchWithCache(graphUrl, label, onProgress);
  return createSessionWithFallback(graph, {}, label);
}

/**
 * Download an ONNX graph plus its external weights file and create a session.
 * @param {string} graphUrl - URL of the .onnx graph.
 * @param {string} dataUrl - URL of the external data file; its basename is
 *   the path under which the weights are registered with the session.
 * @param {string} label - Name used in log lines.
 * @param {(fraction: number) => void} [onProgress] - Progress of the weights download.
 * @returns {Promise<ort.InferenceSession>}
 */
async function createSession(graphUrl, dataUrl, label, onProgress) {
  const graph = await fetchWithCache(graphUrl, label + " graph", () => {});
  const weights = await fetchWithCache(dataUrl, label + " weights", onProgress);
  const externalFiles = [{ path: dataUrl.split("/").pop(), data: weights }];
  return createSessionWithFallback(graph, { externalData: externalFiles }, label);
}

// ---- embedding lookup ------------------------------------------------------

/**
 * Decode one IEEE 754 half-precision bit pattern to a JS number.
 * Scalar and slow; used for the embedding scales and the example mel files.
 * @param {number} u16 - 16-bit pattern.
 * @returns {number}
 */
function fp16ToF32(u16) {
  const sign = (u16 >> 15) & 0x1;
  const exp = (u16 >> 10) & 0x1f;
  const frac = u16 & 0x3ff;
  let v;
  if (exp === 0) {
    // zero or subnormal
    v = frac === 0 ? 0 : Math.pow(2, -14) * (frac / 1024);
  } else if (exp === 31) {
    v = frac === 0 ? Infinity : NaN;
  } else {
    v = Math.pow(2, exp - 15) * (1 + frac / 1024);
  }
  return sign ? -v : v;
}

/**
 * Dequantize the embedding row of one token.
 * @param {number} tokenId - Row index into the int8 table.
 * @returns {Float32Array} HIDDEN values: int8 entry times the token's scale.
 */
function lookupEmbedding(tokenId) {
  const out = new Float32Array(HIDDEN);
  const scale = state.embedScales[tokenId];
  const base = tokenId * HIDDEN;
  for (let i = 0; i < HIDDEN; i++) {
    out[i] = state.embedI8[base + i] * scale;
  }
  return out;
}

// ---- model loader ----------------------------------------------------------

/**
 * Set state.device to "webgpu" when an adapter and device can be obtained,
 * otherwise to "wasm".
 */
async function pickDevice() {
  if ("gpu" in navigator) {
    try {
      const adapter = await navigator.gpu.requestAdapter();
      if (adapter) {
        const device = await adapter.requestDevice();
        if (device) {
          // The device was only a capability probe; nothing here keeps a
          // reference to it, so release it.
          device.destroy?.();
          state.device = "webgpu";
          return;
        }
      }
    } catch (e) {
      log("WebGPU unavailable: " + e.message);
    }
  }
  state.device = "wasm";
}

/**
 * Load tokenizer, processor, manifest, embedding table and the three ONNX
 * sessions into `state`, updating the loader UI along the way.
 * No-op when a load has finished or is in flight. On failure the error
 * propagates and `state.loading` is reset so the user can retry.
 */
async function loadAll() {
  if (state.loaded || state.loading) return;
  state.loading = true;
  try {
    setLoaderStatus("starting...");
    await pickDevice();
    log(`execution provider: ${state.device}`);

    // 1. manifest + tokenizer
    setLoaderStatus("tokenizer + manifest ...");
    state.tokenizer = await AutoTokenizer.from_pretrained("Reza2kn/mega-asr-onnx");
    log("tokenizer loaded");
    try {
      state.processor = await AutoProcessor.from_pretrained("Reza2kn/mega-asr-onnx");
    } catch (e) {
      // The processor is optional (examples ship precomputed mels), but
      // record why it is missing.
      log(`processor load error: ${e?.message ?? e}`);
      state.processor = null;
    }
    if (state.processor) log("processor (feature extractor) loaded");
    else log("processor unavailable -- live audio uploads will not work, examples still ok");

    const manifestUrl = `${HF_ROOT}/examples_mels/manifest.json`;
    const manifestRes = await fetch(manifestUrl);
    if (!manifestRes.ok) throw new Error(`${manifestUrl}: ${manifestRes.status}`);
    state.manifest = await manifestRes.json();

    // 2. embedding table + per-token scales
    setLoaderStatus("embedding table ...");
    const embedBlob = await fetchWithCache(
      `${HF_ROOT}/onnx/embed_int8.bin`,
      "embed (311 MB)",
      (p) => setProgress(p * 25),
    );
    state.embedI8 = new Int8Array(embedBlob.buffer);
    const scalesBlob = await fetchWithCache(
      `${HF_ROOT}/onnx/embed_int8_scales.bin`,
      "embed scales",
      () => {},
    );
    // scales are stored as fp16; expand to fp32
    const u16 = new Uint16Array(scalesBlob.buffer);
    state.embedScales = new Float32Array(u16.length);
    for (let i = 0; i < u16.length; i++) state.embedScales[i] = fp16ToF32(u16[i]);
    log(`embedding ready: ${u16.length} tokens × ${HIDDEN}`);
    setProgress(30);

    // 3. ONNX sessions
    // Audio encoder: INT4 (MatMulNBits) — well-supported on WebGPU and WASM.
    // Static INT8 (QLinearConv/QLinearMatMul) crashes onnxruntime-web on WebGPU.
    setLoaderStatus("audio encoder INT4 ...");
    state.encoder = await createSession(
      `${HF_ROOT}/onnx/audio_encoder_int4.onnx`,
      `${HF_ROOT}/onnx/audio_encoder_int4.onnx.data`,
      "audio_encoder INT4",
      (p) => setProgress(30 + p * 10),
    );
    setProgress(40);

    setLoaderStatus("decoder prefill (~970 MB)...");
    state.prefill = await createSession(
      `${HF_ROOT}/onnx/decoder_prefill_int4.onnx`,
      `${HF_ROOT}/onnx/decoder_prefill_int4.onnx.data`,
      "decoder_prefill",
      (p) => setProgress(40 + p * 30),
    );
    setProgress(70);

    setLoaderStatus("decoder step (~970 MB)...");
    state.step = await createSession(
      `${HF_ROOT}/onnx/decoder_step_int4.onnx`,
      `${HF_ROOT}/onnx/decoder_step_int4.onnx.data`,
      "decoder_step",
      (p) => setProgress(70 + p * 30),
    );
    setProgress(100);

    state.loaded = true;
    setLoaderStatus(`ready (${state.device})`);
    document.getElementById("load-btn").disabled = true;
    document.getElementById("transcribe-btn").disabled = false;
    log("all models loaded.");
  } finally {
    state.loading = false;
  }
}

// ---- mel features for arbitrary audio ---------------------------------------

/**
 * Decode an uploaded audio file and compute its mel features.
 * @param {Blob} file - Audio file chosen by the user.
 * @returns {Promise<{mel: *, dims: number[]}>} Data and dims of the
 *   processor's `input_features` tensor.
 * @throws {Error} When the processor failed to load.
 */
async function audioToMel(file) {
  if (!state.processor) throw new Error("Live audio uploads need the processor (not available)");
  const buf = await file.arrayBuffer();

  // Decode + resample to 16 kHz via an AudioContext created at that rate.
  const AudioCtx = window.AudioContext || window.webkitAudioContext;
  const audioCtx = new AudioCtx({ sampleRate: 16000 });
  let decoded;
  try {
    decoded = await audioCtx.decodeAudioData(buf);
  } finally {
    // The context is only needed for decoding; close it instead of leaving
    // one open per transcription.
    audioCtx.close().catch((e) => console.warn("AudioContext close failed", e));
  }

  // Average to mono if multi-channel
  let pcm = decoded.getChannelData(0);
  if (decoded.numberOfChannels > 1) {
    const mixed = new Float32Array(decoded.length);
    for (let c = 0; c < decoded.numberOfChannels; c++) {
      const ch = decoded.getChannelData(c);
      for (let i = 0; i < ch.length; i++) mixed[i] += ch[i] / decoded.numberOfChannels;
    }
    pcm = mixed;
  }

  // Run through the transformers.js feature extractor (via the loaded processor)
  const feat = await state.processor(new RawAudio(pcm, 16000));
  return { mel: feat.input_features.data, dims: feat.input_features.dims };
}

// ---- example mel loader ----------------------------------------------------

/**
 * Load the precomputed mel features of a bundled example.
 * @param {string} name - Example key (see REFERENCES).
 * @returns {Promise<{mel: Float32Array, dims: number[], T_mel: number}>}
 *   fp32 features with dims [1, 128, 3000] and the example's T_mel taken
 *   from the manifest.
 */
async function loadExampleMel(name) {
  const url = `${HF_ROOT}/examples_mels/${name}.mel.bin`;
  const buf = await fetchWithCache(url, `mel ${name}`, () => {});
  // stored as fp16 -> expand to fp32
  const u16 = new Uint16Array(buf.buffer);
  const f32 = new Float32Array(u16.length);
  for (let i = 0; i < u16.length; i++) f32[i] = fp16ToF32(u16[i]);
  return { mel: f32, dims: [1, 128, 3000], T_mel: state.manifest.examples[name].T_mel };
}

// ---- core inference --------------------------------------------------------

/**
 * Gather the per-layer key/value tensors from a decoder output map.
 * @param {object} outputs - Result of a prefill/step session run.
 * @returns {Array} [key0, value0, key1, value1, ...] for NUM_LAYERS layers.
 */
function collectKv(outputs) {
  const kvs = [];
  for (let i = 0; i < NUM_LAYERS; i++) {
    kvs.push(outputs[`present.${i}.key`]);
    kvs.push(outputs[`present.${i}.value`]);
  }
  return kvs;
}

/**
 * Run the full pipeline: audio encoder, prompt assembly, decoder prefill and
 * greedy decoding, then detokenize.
 * @param {{mel: *, dims: number[], T_mel: number}} input - Mel features,
 *   their dims, and the number of mel frames used to size the audio span.
 * @returns {Promise<string>} Decoded text with special tokens skipped.
 * @throws {Error} When the models have not been loaded.
 */
async function transcribe({ mel, dims, T_mel }) {
  if (!state.loaded) throw new Error("models not loaded");

  // 1. encode
  setStatus("audio encoder ...");
  const melTensor = new ort.Tensor("float32", mel, dims);
  const encOut = await state.encoder.run({ mel: melTensor });
  // WebGPU outputs live in GPU memory — getData(true) downloads to CPU.
  const audioEmbedsAll = await encOut.audio_embeds.getData(true);

  // Number of audio frames to place in the prompt: 13 per full 100-frame mel
  // chunk plus ceil(len / 8) for the last chunk.
  const realChunks = Math.floor((T_mel + 99) / 100);
  const lastChunkMel = T_mel - (realChunks - 1) * 100;
  const expectedAudioFrames = (realChunks - 1) * 13 + Math.floor((lastChunkMel + 7) / 8);
  // Never ask for more frames than the encoder returned; reading past the
  // end would fill the embeddings with NaN.
  const availableFrames = Math.floor(audioEmbedsAll.length / HIDDEN);
  const realAudioFrames = Math.min(expectedAudioFrames, availableFrames);
  if (realAudioFrames < expectedAudioFrames) {
    log(`warning: expected ${expectedAudioFrames} audio frames but encoder returned ${availableFrames}; clamping`);
  }

  // 2. build prompt + scatter audio embeds at <|audio_pad|>.
  // Default to the forced-English prompt; the model's auto language detection
  // can fail at INT4 quantization on borderline audio.
  setStatus("building prompt ...");
  const lang = document.getElementById("lang-select")?.value || "english";
  const promptIds = state.manifest.prompts?.[lang]?.ids || state.manifest.prompt_ids;
  const audioPadId = state.manifest.audio_pad_id;
  // Expand audio_pad in the prompt to realAudioFrames placeholder tokens
  const tokens = [];
  for (const t of promptIds) {
    if (t === audioPadId) {
      for (let i = 0; i < realAudioFrames; i++) tokens.push(audioPadId);
    } else {
      tokens.push(t);
    }
  }
  const L = tokens.length;

  // 3. embed text tokens, scatter audio embeds at placeholder positions
  const inputsEmbeds = new Float32Array(L * HIDDEN);
  let audioIdx = 0;
  for (let i = 0; i < L; i++) {
    const dst = i * HIDDEN;
    if (tokens[i] === audioPadId) {
      const src = audioIdx * HIDDEN;
      inputsEmbeds.set(audioEmbedsAll.subarray(src, src + HIDDEN), dst);
      audioIdx++;
    } else {
      inputsEmbeds.set(lookupEmbedding(tokens[i]), dst);
    }
  }
  // ONNX wants fp16 embeds: convert
  const inputsEmbedsF16 = floatArrayToFp16(inputsEmbeds);
  const attnMask = new BigInt64Array(L).fill(1n);
  const posIds = new BigInt64Array(L);
  for (let i = 0; i < L; i++) posIds[i] = BigInt(i);

  // 4. prefill
  setStatus("prefill ...");
  const t0 = performance.now();
  const prefillOut = await state.prefill.run({
    inputs_embeds: new ort.Tensor("float16", inputsEmbedsF16, [1, L, HIDDEN]),
    attention_mask: new ort.Tensor("int64", attnMask, [1, L]),
    position_ids: new ort.Tensor("int64", posIds, [1, L]),
  });
  log(`prefill: ${(performance.now() - t0).toFixed(0)} ms (L=${L})`);

  // 5. greedy decode
  setStatus("decoding ...");
  // WebGPU outputs live in GPU memory — must call getData() (async) to bring
  // them back to CPU.
  let logits = await prefillOut.logits.getData(true); // (1, L, VOCAB)
  const logitsDims = prefillOut.logits.dims;
  // argmax over the last position gives the first generated token
  let nid = argmax(logits, (logitsDims[1] - 1) * VOCAB, VOCAB);
  const gen = [nid];
  const eos = state.manifest.eos_token_id;
  let curLen = L;
  let kvs = collectKv(prefillOut);

  for (let step = 0; step < MAX_DECODE_STEPS && nid !== eos; step++) {
    setStatus(`step ${step + 1} / ${MAX_DECODE_STEPS} ...`);
    const newEmbF16 = floatArrayToFp16(lookupEmbedding(nid));
    // The mask covers all cached positions plus the new token.
    const newAttn = new BigInt64Array(curLen + 1).fill(1n);
    const newPos = new BigInt64Array([BigInt(curLen)]);
    const feeds = {
      inputs_embeds: new ort.Tensor("float16", newEmbF16, [1, 1, HIDDEN]),
      attention_mask: new ort.Tensor("int64", newAttn, [1, curLen + 1]),
      position_ids: new ort.Tensor("int64", newPos, [1, 1]),
    };
    for (let i = 0; i < NUM_LAYERS; i++) {
      feeds[`past.${i}.key`] = kvs[2 * i];
      feeds[`past.${i}.value`] = kvs[2 * i + 1];
    }
    const out = await state.step.run(feeds);
    logits = await out.logits.getData(true);
    nid = argmax(logits, 0, VOCAB);
    gen.push(nid);
    curLen += 1;
    kvs = collectKv(out);
  }

  // 6. detokenize (the terminating eos, if any, is dropped first)
  const filtered = gen.filter((t) => t !== eos);
  const text = await state.tokenizer.decode(filtered, { skip_special_tokens: true });
  setStatus("done");
  return text;
}

/**
 * Index (relative to `offset`) of the largest value in arr[offset, offset + len).
 * The first maximum wins on ties; returns 0 when no element compares greater
 * than -Infinity.
 * @param {ArrayLike<number>} arr
 * @param {number} offset
 * @param {number} len
 * @returns {number}
 */
function argmax(arr, offset, len) {
  let best = -Infinity;
  let bestIdx = 0;
  for (let i = 0; i < len; i++) {
    const v = arr[offset + i];
    if (v > best) {
      best = v;
      bestIdx = i;
    }
  }
  return bestIdx;
}

// Shared scratch buffers so f32ToF16Bits does not allocate per element.
const F32_SCRATCH = new Float32Array(1);
const U32_SCRATCH = new Uint32Array(F32_SCRATCH.buffer);

/**
 * Convert a number to its IEEE 754 half-precision bit pattern using
 * round-to-nearest, ties-to-even.
 * @param {number} v - Value (first rounded to fp32).
 * @returns {number} 16-bit pattern.
 */
function f32ToF16Bits(v) {
  F32_SCRATCH[0] = v;
  const bits = U32_SCRATCH[0];
  const sign = (bits >>> 16) & 0x8000;
  const exp = (bits >>> 23) & 0xff;
  const frac = bits & 0x7fffff;

  if (exp === 0xff) {
    // inf or nan (nan keeps a non-zero mantissa)
    return sign | 0x7c00 | (frac ? 0x200 : 0);
  }
  const newExp = exp - 127 + 15;
  if (newExp >= 31) return sign | 0x7c00; // overflow -> inf
  if (newExp <= 0) {
    // Result is subnormal (or zero) in fp16.
    if (newExp < -10) return sign; // below half the smallest subnormal
    const mant = frac | 0x800000; // restore the implicit leading 1
    const shift = 14 - newExp;
    let half = mant >>> shift;
    const rem = mant & ((1 << shift) - 1);
    const tie = 1 << (shift - 1);
    if (rem > tie || (rem === tie && (half & 1))) half += 1;
    return sign | half;
  }
  let half = (newExp << 10) | (frac >>> 13);
  const rem = frac & 0x1fff;
  // A carry out of the mantissa correctly bumps the exponent (up to inf).
  if (rem > 0x1000 || (rem === 0x1000 && (half & 1))) half += 1;
  return sign | half;
}

// ORT 1.20+ validates that fp16 tensor data is a Float16Array instance, so
// when the runtime has Float16Array we return one; otherwise a Uint16Array
// of bit patterns.
const HAS_F16 = typeof Float16Array !== "undefined";

/**
 * Convert fp32 values to fp16 storage.
 * @param {Float32Array} arr
 * @returns {Float16Array|Uint16Array} Float16Array when available, else the
 *   raw 16-bit patterns.
 */
function floatArrayToFp16(arr) {
  // Native typed-array conversion rounds each element to the nearest fp16.
  if (HAS_F16) return new Float16Array(arr);
  const u16 = new Uint16Array(arr.length);
  for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
  return u16;
}

// ---- agreement scoring -----------------------------------------------------

/**
 * Canonicalize text for scoring: keep only what follows ASR_TEXT_MARKER (when
 * present), lowercase, replace everything except a-z, 0-9 and whitespace with
 * spaces, and collapse whitespace.
 * @param {string} text
 * @returns {string}
 */
function normalize(text) {
  let t = text;
  if (t.includes(ASR_TEXT_MARKER)) t = t.split(ASR_TEXT_MARKER)[1];
  return t
    .toLowerCase()
    .replace(/[^a-z0-9\s]/g, " ")
    .replace(/\s+/g, " ")
    .trim();
}

/**
 * Word error rate via word-level edit distance.
 * @param {string} ref - Normalized reference.
 * @param {string} hyp - Normalized hypothesis.
 * @returns {[number, number, number]} [rate, edit distance, reference word
 *   count]. With an empty reference: [1 or 0, hypothesis word count, 0].
 */
function wer(ref, hyp) {
  const r = ref.split(" ").filter((x) => x);
  const h = hyp.split(" ").filter((x) => x);
  if (!r.length) return [h.length ? 1 : 0, h.length, 0];

  const d = Array.from({ length: r.length + 1 }, () => new Int32Array(h.length + 1));
  for (let i = 0; i <= r.length; i++) d[i][0] = i;
  for (let j = 0; j <= h.length; j++) d[0][j] = j;
  for (let i = 1; i <= r.length; i++) {
    for (let j = 1; j <= h.length; j++) {
      const sub = d[i - 1][j - 1] + (r[i - 1] === h[j - 1] ? 0 : 1);
      const ins = d[i][j - 1] + 1;
      const del = d[i - 1][j] + 1;
      d[i][j] = Math.min(sub, ins, del);
    }
  }
  const dist = d[r.length][h.length];
  return [dist / r.length, dist, r.length];
}

/**
 * Render the transcription (and, when a reference is given, the agreement
 * score) into #result.
 *
 * The text is written with textContent: `hyp` is model output and `extra`
 * may carry an error message, so neither may be parsed as HTML.
 * NOTE(review): the templates in the reviewed source contained no markup; if
 * the page expects child elements here, build them with DOM nodes rather
 * than returning to innerHTML.
 *
 * @param {string} hyp - Transcription text.
 * @param {string} ref - Reference text; empty/blank disables scoring.
 * @param {string} extra - Trailing info line.
 */
function renderResult(hyp, ref, extra) {
  const el = document.getElementById("result");
  el.className = "result";
  if (!ref || !ref.trim()) {
    el.className += " neutral";
    el.textContent = ["", `Transcription: ${hyp || "(empty)"}`, `${extra}`, ""].join("\n");
    return;
  }

  const rN = normalize(ref);
  const hN = normalize(hyp);
  const [w, err, nw] = wer(rN, hN);
  const pct = Math.max(0, 1 - w) * 100;

  let cls = "red";
  let emoji = "🔴";
  let label = "diverged";
  if (pct >= 70) {
    cls = "green";
    emoji = "✅";
    label = "match";
  } else if (pct >= 50) {
    cls = "orange";
    emoji = "🟠";
    label = "close";
  } else if (pct >= 25) {
    cls = "yellow";
    emoji = "🟡";
    label = "partial";
  }

  el.className = "result " + cls;
  el.textContent = [
    "",
    `${emoji} ${pct.toFixed(1)}% agreement · ${label}`,
    `Transcription: ${hN || "(empty)"}`,
    `Reference: ${rN}`,
    `${extra} · WER ${(w * 100).toFixed(1)}% (${err}/${nw})`,
    "",
  ].join("\n");
}

// ---- UI wiring -------------------------------------------------------------

// Object URL currently assigned to the audio player for an uploaded file.
let uploadedAudioUrl = null;

/** Revoke the uploaded file's object URL, if any, so its data can be freed. */
function releaseUploadedAudioUrl() {
  if (uploadedAudioUrl) {
    URL.revokeObjectURL(uploadedAudioUrl);
    uploadedAudioUrl = null;
  }
}

document.getElementById("load-btn").addEventListener("click", () => {
  loadAll().catch((e) => {
    log("LOAD FAILED: " + e.message);
    state.loading = false;
  });
});

document.getElementById("audio-file").addEventListener("change", (e) => {
  // Choosing a file (or clearing the chooser) deselects any example.
  document.body.dataset.example = "";
  const f = e.target.files[0];
  if (!f) return;
  releaseUploadedAudioUrl();
  uploadedAudioUrl = URL.createObjectURL(f);
  document.getElementById("audio-player").src = uploadedAudioUrl;
});

document.getElementById("transcribe-btn").addEventListener("click", async () => {
  const refText = document.getElementById("ref-text").value;
  const file = document.getElementById("audio-file").files[0];
  const example = document.body.dataset.example;
  if (!file && !example) {
    renderResult("", "", "Pick an audio file or example first.");
    return;
  }
  try {
    document.getElementById("transcribe-btn").disabled = true;
    let mel, dims, T_mel;
    const t0 = performance.now();
    if (example) {
      ({ mel, dims, T_mel } = await loadExampleMel(example));
    } else {
      ({ mel, dims } = await audioToMel(file));
      // NOTE(review): uses the full feature length; if the processor pads
      // its output, padding frames are counted as audio — confirm intended.
      T_mel = dims[2];
    }
    const text = await transcribe({ mel, dims, T_mel });
    const elapsed = (performance.now() - t0) / 1000;
    renderResult(text, refText, `INT4 enc + GPTQ-INT4 dec · ${state.device} · ${elapsed.toFixed(1)}s`);
  } catch (e) {
    const msg = (e && (e.message || e.toString())) || JSON.stringify(e) || "(no error info)";
    const stk = e && e.stack ? e.stack.split("\n").slice(0, 3).join(" | ") : "(no stack)";
    log("TRANSCRIBE FAILED: " + msg);
    log("stack: " + stk);
    console.error(e);
    renderResult("", refText, `error: ${msg}`);
  } finally {
    document.getElementById("transcribe-btn").disabled = false;
  }
});

// Build the 8 example buttons
const examplesEl = document.getElementById("examples");
const exampleEmojis = {
  noise: "🔊",
  far_field: "📡",
  obstructed: "🚧",
  distortion: "🎛️",
  recording: "🎙️",
  echo: "🏛️",
  dropout: "✂️",
  mixed: "🌪️",
};
for (const [name, ref] of Object.entries(REFERENCES)) {
  const b = document.createElement("button");
  b.textContent = `${exampleEmojis[name]} ${name}`;
  b.addEventListener("click", () => {
    document.body.dataset.example = name;
    document.getElementById("ref-text").value = ref;
    document.getElementById("audio-file").value = "";
    // The player switches to the hosted clip, so any uploaded file's object
    // URL is no longer needed.
    releaseUploadedAudioUrl();
    document.getElementById("audio-player").src = `${HF_ROOT}/examples/${name}.wav`;
  });
  examplesEl.appendChild(b);
}

log("page loaded; click 'Load model' to start.");