mega-asr-bench / mega-asr.js
Reza2kn's picture
Clean up debug diagnostics now that WebGPU works end-to-end
2cf4acc verified
// Mega-ASR — pure browser ASR
// Loads ONNX models from Reza2kn/mega-asr-onnx via onnxruntime-web,
// the tokenizer + Whisper mel features via @huggingface/transformers,
// and runs the encode/prefill/step pipeline on the user's device.
import { AutoTokenizer, AutoProcessor, RawAudio } from "https://cdn.jsdelivr.net/npm/@huggingface/transformers@3.5.2/dist/transformers.min.js";
import * as ort from "https://cdn.jsdelivr.net/npm/onnxruntime-web@1.23.0/dist/ort.webgpu.bundle.min.mjs";
// onnxruntime-web WASM backend: one thread per logical core, or 4 when
// hardwareConcurrency is unavailable.
// NOTE(review): multi-threaded WASM needs cross-origin isolation
// (SharedArrayBuffer); confirm the page is served with COOP/COEP headers.
ort.env.wasm.numThreads = navigator.hardwareConcurrency || 4;
// NOTE(review): may be a no-op in recent onnxruntime-web releases -- TODO confirm.
ort.env.wasm.simd = true;
// Base URL for model graphs/weights, the embedding table, example mels/WAVs
// and the example manifest.
const HF_ROOT = "https://huggingface.co/Reza2kn/mega-asr-onnx/resolve/main";
// Decoder layer count: number of past.{i}/present.{i} key+value tensors
// exchanged with the prefill and step sessions.
const NUM_LAYERS = 28;
// Decoder hidden size: width of one embedding row / one audio-embed frame.
const HIDDEN = 2048;
// Vocabulary size: length of one logits row (and rows in the embedding table).
const VOCAB = 151936;
// Reference transcripts for the built-in example clips, keyed by example name.
// The keys name the example buttons and the examples_mels/<name>.mel.bin and
// examples/<name>.wav files; the values pre-fill the reference text box.
const REFERENCES = {
noise: "I usually take the quieter road home because the main street gets crowded after work.",
far_field: "Please remind me to print the forms before we leave for the appointment tomorrow.",
obstructed: "I forgot my charger at home, so I need to find an outlet before the meeting starts.",
distortion: "The new coffee machine is simple, but everyone keeps forgetting where the filters are stored.",
recording: "Can you check whether the train still stops at the downtown station after eight tonight?",
echo: "I need to return these shoes because the size feels fine standing up but terrible while walking.",
dropout: "My aunt is learning video calls, and she gets excited whenever the picture actually works.",
mixed: "My sister is bringing dinner over later, so we do not need to cook tonight.",
};
// ---- state -----------------------------------------------------------------
// Mutable module-wide state shared by the loader, the inference pipeline and
// the UI handlers.
const state = {
loaded: false, // true once loadAll() has created all three sessions
loading: false, // guards loadAll() against concurrent runs
encoder: null, // ort.InferenceSession: audio encoder (mel -> audio_embeds)
prefill: null, // ort.InferenceSession: decoder prefill over the whole prompt
step: null, // ort.InferenceSession: one-token decoder step fed with the KV cache
tokenizer: null, // transformers.js tokenizer; decodes generated ids to text
processor: null, // transformers.js processor for uploaded audio; null if it failed to load
embedI8: null, // Int8Array, shape (VOCAB, HIDDEN)
embedScales: null, // per-token scales, stored as fp16 and expanded to a Float32Array
manifest: null, // examples_mels/manifest.json (prompts, prompt_ids, audio_pad_id, eos_token_id, examples)
device: "wasm", // "webgpu" or "wasm"; set by pickDevice()
};
// Append a timestamped line to the #log panel, keep the panel scrolled to the
// newest entry, and mirror the raw message to the console.
const log = (msg) => {
  const stamp = new Date().toLocaleTimeString();
  const panel = document.getElementById("log");
  const row = document.createElement("div");
  row.textContent = `[${stamp}] ${msg}`;
  panel.appendChild(row);
  panel.scrollTop = panel.scrollHeight;
  console.log(msg);
};
// Status line shown while transcribing.
const setStatus = (text) => {
  const target = document.getElementById("status");
  target.textContent = text;
};
// Status line shown while loading models.
const setLoaderStatus = (text) => {
  const target = document.getElementById("loader-status");
  target.textContent = text;
};
// Loader progress bar; `pct` is a percentage (0-100).
const setProgress = (pct) => {
  const bar = document.getElementById("loader-bar");
  bar.style.width = pct + "%";
};
// ---- IndexedDB cache for big blobs ----------------------------------------
const DB_NAME = "mega-asr-cache-v2-gptq";
const DB_STORE = "blobs";
// Memoized connection promise. Previously every cacheGet/cachePut opened a new
// IDBDatabase connection that was never closed.
let dbPromise = null;
/**
 * Open (or reuse) the IndexedDB database that caches downloaded blobs,
 * creating the object store on first use.
 * A failed open is not memoized, so a later call can retry.
 * @returns {Promise<IDBDatabase>}
 */
function openDB() {
  if (dbPromise) return dbPromise;
  const pending = new Promise((resolve, reject) => {
    const req = indexedDB.open(DB_NAME, 1);
    req.onupgradeneeded = () => {
      const db = req.result;
      if (!db.objectStoreNames.contains(DB_STORE)) db.createObjectStore(DB_STORE);
    };
    req.onsuccess = () => {
      const db = req.result;
      // Another tab wants to upgrade the schema: release our handle so it is
      // not blocked, and reopen lazily on next use.
      db.onversionchange = () => {
        db.close();
        dbPromise = null;
      };
      resolve(db);
    };
    req.onerror = () => reject(req.error);
  });
  dbPromise = pending;
  pending.catch(() => {
    if (dbPromise === pending) dbPromise = null;
  });
  return pending;
}
/**
 * Read a cached value from the blob store.
 * @param {string} key - cache key (the download URL)
 * @returns {Promise<*>} the stored value, or null when nothing (truthy) is stored
 */
async function cacheGet(key) {
  const db = await openDB();
  return new Promise((resolve, reject) => {
    const request = db
      .transaction(DB_STORE, "readonly")
      .objectStore(DB_STORE)
      .get(key);
    request.onerror = () => reject(request.error);
    request.onsuccess = () => resolve(request.result || null);
  });
}
/**
 * Store `blob` under `key` in the blob store.
 *
 * Resolves only once the readwrite transaction has committed. A put request can
 * report success while its transaction later aborts (e.g. a quota error on a
 * ~1 GB weights file), and then nothing was stored. Resolving on the request
 * alone would miss that, or hang when only the transaction reports the failure.
 *
 * @param {string} key - cache key (the download URL)
 * @param {*} blob - structured-cloneable value (here a Uint8Array)
 * @returns {Promise<void>}
 * @throws the transaction's error (e.g. QuotaExceededError) when it aborts
 */
async function cachePut(key, blob) {
  const db = await openDB();
  return new Promise((resolve, reject) => {
    const tx = db.transaction(DB_STORE, "readwrite");
    tx.oncomplete = () => resolve();
    // A failed put request bubbles up and aborts the transaction, so abort
    // covers both request-level and commit-level failures.
    tx.onabort = () => reject(tx.error || new Error(`IndexedDB write aborted for ${key}`));
    tx.objectStore(DB_STORE).put(blob, key);
  });
}
/**
 * Fetch `url` as bytes, going through the IndexedDB cache.
 *
 * The cache is best-effort. If IndexedDB is unavailable, or the write fails
 * (e.g. quota exceeded on a ~1 GB weights file), the problem is logged. The
 * downloaded bytes are still returned instead of failing the whole model load.
 *
 * @param {string} url - absolute URL; also the cache key
 * @param {string} label - human-readable name for log lines
 * @param {(fraction: number) => void} [onProgress] - called with 0..1 while
 *   downloading (only when the server sends Content-Length) and with 1 when
 *   the bytes are available (downloaded or cached)
 * @returns {Promise<Uint8Array>}
 * @throws {Error} on a non-2xx HTTP status or a network failure
 */
async function fetchWithCache(url, label, onProgress) {
  const key = url;
  let cached = null;
  try {
    cached = await cacheGet(key);
  } catch (e) {
    log(`cache read failed for ${label} (${e?.message ?? e}); downloading instead`);
  }
  if (cached) {
    log(`cached: ${label}`);
    onProgress?.(1);
    return cached;
  }

  log(`downloading ${label} ...`);
  const res = await fetch(url);
  if (!res.ok) throw new Error(`${url}: ${res.status}`);
  const total = parseInt(res.headers.get("content-length") || "0", 10);

  let bytes;
  if (res.body) {
    // Stream so progress can be reported for the large weight files.
    const reader = res.body.getReader();
    const chunks = [];
    let read = 0;
    for (;;) {
      const { done, value } = await reader.read();
      if (done) break;
      chunks.push(value);
      read += value.length;
      if (total && onProgress) onProgress(read / total);
    }
    bytes = new Uint8Array(read);
    let off = 0;
    for (const c of chunks) {
      bytes.set(c, off);
      off += c.length;
    }
  } else {
    bytes = new Uint8Array(await res.arrayBuffer());
  }

  try {
    await cachePut(key, bytes);
  } catch (e) {
    log(`cache write failed for ${label} (${e?.message ?? e}); continuing without cache`);
  }
  log(`downloaded ${label} (${(bytes.length / 1e6).toFixed(0)} MB)`);
  onProgress?.(1);
  return bytes;
}
// ---- ONNX session creation -------------------------------------------------
// Always prefer the device chosen by pickDevice() (state.device); fall back to
// WASM only for the session that fails (per-session, not global). Don't mutate
// state.device.
// Execution providers to try, in order: WebGPU then WASM on a WebGPU device,
// WASM alone otherwise.
function epList() {
  if (state.device !== "webgpu") return ["wasm"];
  return ["webgpu", "wasm"];
}
/**
 * Create an ONNX session from a graph file passed without external data.
 * Tries epList(); if the device is WebGPU and that attempt fails, retries this
 * one session with WASM. Failures are logged the same way createSession logs them.
 * @param {string} graphUrl - URL of the .onnx graph
 * @param {string} label - name used in log lines (and as the download label)
 * @param {(fraction: number) => void} onProgress - download progress callback
 * @returns {Promise<ort.InferenceSession>}
 * @throws the creation error when no fallback applies, or the WASM retry's error
 */
async function createSessionSimple(graphUrl, label, onProgress) {
  const graph = await fetchWithCache(graphUrl, label, onProgress);
  try {
    const sess = await ort.InferenceSession.create(graph, { executionProviders: epList() });
    log(`session ready: ${label} (${state.device})`);
    return sess;
  } catch (e) {
    if (state.device !== "webgpu") {
      log(`session create failed for ${label}: ${e.message}`);
      throw e;
    }
    log(`webgpu failed for ${label} (${e.message}); retrying this session with wasm`);
    try {
      const sess = await ort.InferenceSession.create(graph, { executionProviders: ["wasm"] });
      log(`session ready: ${label} (wasm fallback)`);
      return sess;
    } catch (fallbackErr) {
      log(`session create failed for ${label}: ${fallbackErr.message}`);
      throw fallbackErr;
    }
  }
}
/**
 * Create an ONNX session whose weights live in a separate external-data file.
 * Fetches (or reads from cache) the graph and the weights. The weights are
 * handed to ORT under the URL's basename, which is expected to match the
 * location recorded in the graph. Tries epList() first; on a WebGPU device
 * whose attempt fails, retries this one session with WASM.
 * @param {string} graphUrl - URL of the .onnx graph
 * @param {string} dataUrl - URL of the external weights (.onnx.data)
 * @param {string} label - name used in log lines
 * @param {(fraction: number) => void} onProgress - weights download progress
 * @returns {Promise<ort.InferenceSession>}
 * @throws the creation error when no WASM fallback applies
 */
async function createSession(graphUrl, dataUrl, label, onProgress) {
  const graph = await fetchWithCache(graphUrl, label + " graph", () => {});
  const weights = await fetchWithCache(dataUrl, label + " weights", onProgress);
  const externalData = [{ path: dataUrl.split("/").pop(), data: weights }];
  const open = (executionProviders) =>
    ort.InferenceSession.create(graph, { executionProviders, externalData });

  try {
    const session = await open(epList());
    log(`session ready: ${label} (${state.device})`);
    return session;
  } catch (err) {
    if (state.device !== "webgpu") {
      log(`session create failed for ${label}: ${err.message}`);
      throw err;
    }
    log(`webgpu failed for ${label} (${err.message}); retrying this session with wasm`);
    const session = await open(["wasm"]);
    log(`session ready: ${label} (wasm fallback)`);
    return session;
  }
}
// ---- embedding lookup ------------------------------------------------------
// Decode one IEEE 754 half-precision bit pattern (an unsigned 16-bit value,
// e.g. one element of a Uint16Array) into a JS number. Scalar and slow; used
// for the embedding scales and the example mels.
function fp16ToF32(u16) {
  const exponent = (u16 >> 10) & 0x1f;
  const mantissa = u16 & 0x3ff;
  const negative = ((u16 >> 15) & 0x1) === 1;
  let magnitude;
  switch (exponent) {
    case 0:
      // zero or subnormal: 0.mantissa × 2^-14
      magnitude = mantissa === 0 ? 0 : 2 ** -14 * (mantissa / 1024);
      break;
    case 31:
      // all-ones exponent: infinity or NaN
      magnitude = mantissa === 0 ? Infinity : NaN;
      break;
    default:
      // normal: 1.mantissa × 2^(exponent - 15)
      magnitude = 2 ** (exponent - 15) * (1 + mantissa / 1024);
  }
  return negative ? -magnitude : magnitude;
}
// Dequantize one row of the int8 embedding table: row `tokenId` (HIDDEN int8
// values starting at tokenId * HIDDEN) times that token's scale. Returns a new
// Float32Array of length HIDDEN.
function lookupEmbedding(tokenId) {
  const scale = state.embedScales[tokenId];
  const rowStart = tokenId * HIDDEN;
  const vec = new Float32Array(HIDDEN);
  for (let col = 0; col < HIDDEN; col += 1) {
    vec[col] = state.embedI8[rowStart + col] * scale;
  }
  return vec;
}
// ---- model loader ----------------------------------------------------------
/**
 * Choose the execution provider and store it in state.device: "webgpu" when a
 * WebGPU adapter and device can be obtained, otherwise "wasm".
 * The device requested here is only a capability probe. Nothing in this file
 * passes it to onnxruntime-web, so it is destroyed right away instead of
 * staying alive for the lifetime of the page.
 */
async function pickDevice() {
  if ("gpu" in navigator) {
    try {
      const adapter = await navigator.gpu.requestAdapter();
      if (adapter) {
        const device = await adapter.requestDevice();
        if (device) {
          device.destroy?.();
          state.device = "webgpu";
          return;
        }
      }
    } catch (e) {
      log("WebGPU unavailable: " + e.message);
    }
  }
  state.device = "wasm";
}
// View the bytes of a Uint8Array as another typed array, honoring the view's
// byteOffset/byteLength; `new Ctor(bytes.buffer)` would ignore both. Copies
// first when the offset is not aligned to the element size.
function bytesAs(Ctor, bytes) {
  const aligned = bytes.byteOffset % Ctor.BYTES_PER_ELEMENT === 0 ? bytes : bytes.slice();
  return new Ctor(
    aligned.buffer,
    aligned.byteOffset,
    Math.floor(aligned.byteLength / Ctor.BYTES_PER_ELEMENT),
  );
}

// Download the int8 embedding table (~311 MB) and its fp16 per-token scales
// into state.embedI8 / state.embedScales, reporting 0-25% progress.
async function loadEmbeddingTable() {
  const embedBytes = await fetchWithCache(
    `${HF_ROOT}/onnx/embed_int8.bin`, "embed (311 MB)", (p) => setProgress(p * 25),
  );
  state.embedI8 = bytesAs(Int8Array, embedBytes);
  const scaleBytes = await fetchWithCache(`${HF_ROOT}/onnx/embed_int8_scales.bin`, "embed scales", () => {});
  // scales are stored as fp16; expand to fp32
  const u16 = bytesAs(Uint16Array, scaleBytes);
  state.embedScales = Float32Array.from(u16, fp16ToF32);
  log(`embedding ready: ${u16.length} tokens × ${HIDDEN}`);
  // lookupEmbedding indexes rows as tokenId * HIDDEN; with a mismatched table
  // it would silently return NaN rows, so at least make the mismatch visible.
  if (u16.length !== VOCAB || state.embedI8.length !== VOCAB * HIDDEN) {
    log(`WARNING: embedding table has ${state.embedI8.length} values and ${u16.length} scales; ` +
      `expected ${VOCAB * HIDDEN} and ${VOCAB}`);
  }
}

/**
 * Load everything needed for inference, once:
 * - pick the execution provider,
 * - load the tokenizer, the (optional) processor and the example manifest,
 * - load the int8 embedding table,
 * - create the three ONNX sessions (audio encoder, decoder prefill, decoder step).
 * Reports progress via setLoaderStatus/setProgress and the log. Does nothing
 * when already loaded or while a load is in flight; `state.loading` is
 * cleared however the load ends.
 * @throws {Error} if the manifest, embeddings or any session fail to load
 */
async function loadAll() {
  if (state.loaded || state.loading) return;
  state.loading = true;
  try {
    setLoaderStatus("starting...");
    await pickDevice();
    log(`execution provider: ${state.device}`);

    // 1. manifest + tokenizer
    setLoaderStatus("tokenizer + manifest ...");
    state.tokenizer = await AutoTokenizer.from_pretrained("Reza2kn/mega-asr-onnx");
    log("tokenizer loaded");
    state.processor = await AutoProcessor.from_pretrained("Reza2kn/mega-asr-onnx").catch((e) => {
      log(`processor load error: ${e?.message ?? e}`);
      return null;
    });
    if (state.processor) log("processor (feature extractor) loaded");
    else log("processor unavailable -- live audio uploads will not work, examples still ok");
    const manifestUrl = `${HF_ROOT}/examples_mels/manifest.json`;
    const manifestRes = await fetch(manifestUrl);
    if (!manifestRes.ok) throw new Error(`${manifestUrl}: ${manifestRes.status}`);
    state.manifest = await manifestRes.json();

    // 2. embedding table + scales (~311 MB)
    setLoaderStatus("embedding table ...");
    await loadEmbeddingTable();
    setProgress(30);

    // 3. ONNX sessions
    // Audio encoder: INT4 (MatMulNBits) — well-supported on WebGPU and WASM.
    // Static INT8 (QLinearConv/QLinearMatMul) crashes onnxruntime-web on WebGPU.
    setLoaderStatus("audio encoder INT4 ...");
    state.encoder = await createSession(
      `${HF_ROOT}/onnx/audio_encoder_int4.onnx`,
      `${HF_ROOT}/onnx/audio_encoder_int4.onnx.data`,
      "audio_encoder INT4",
      (p) => setProgress(30 + p * 10),
    );
    setProgress(40);
    setLoaderStatus("decoder prefill (~970 MB)...");
    state.prefill = await createSession(
      `${HF_ROOT}/onnx/decoder_prefill_int4.onnx`,
      `${HF_ROOT}/onnx/decoder_prefill_int4.onnx.data`,
      "decoder_prefill",
      (p) => setProgress(40 + p * 30),
    );
    setProgress(70);
    setLoaderStatus("decoder step (~970 MB)...");
    state.step = await createSession(
      `${HF_ROOT}/onnx/decoder_step_int4.onnx`,
      `${HF_ROOT}/onnx/decoder_step_int4.onnx.data`,
      "decoder_step",
      (p) => setProgress(70 + p * 30),
    );
    setProgress(100);

    state.loaded = true;
    setLoaderStatus(`ready (${state.device})`);
    document.getElementById("load-btn").disabled = true;
    document.getElementById("transcribe-btn").disabled = false;
    log("all models loaded.");
  } finally {
    state.loading = false;
  }
}
// ---- mel features for arbitrary audio ---------------------------------------
/**
 * Compute mel features for an uploaded audio file.
 *
 * Decodes the file with an AudioContext created at 16 kHz (decodeAudioData
 * resamples to the context rate) and averages all channels to mono. It then
 * runs the transformers.js processor loaded in loadAll(). The processor's
 * feature extractor takes the raw Float32Array samples; a RawAudio wrapper is
 * not accepted as input.
 *
 * @param {Blob} file - the uploaded audio file
 * @returns {Promise<{mel: Float32Array, dims: number[], T_mel: number}>}
 *   `dims` is the input_features shape (expected [1, 128, T], padded).
 *   `T_mel` estimates how many of the T frames hold real audio: samples / hop,
 *   rounded up and clamped to [1, T]. The caller can then skip the padding.
 * @throws {Error} when the processor is unavailable or the audio can't be decoded
 */
async function audioToMel(file) {
  if (!state.processor) throw new Error("Live audio uploads need the processor (not available)");
  const SAMPLE_RATE = 16000;
  const buf = await file.arrayBuffer();
  const audioCtx = new (window.AudioContext || window.webkitAudioContext)({ sampleRate: SAMPLE_RATE });
  let decoded;
  try {
    decoded = await audioCtx.decodeAudioData(buf);
  } finally {
    // The context is only needed for decoding. It holds audio resources until
    // close(), and a new one is made per upload, so release it now. A close()
    // rejection (already closed) leaves nothing to release.
    audioCtx.close().catch(() => {});
  }

  // Average to mono if multi-channel
  const channels = decoded.numberOfChannels;
  let pcm = decoded.getChannelData(0);
  if (channels > 1) {
    const mix = new Float32Array(decoded.length);
    for (let c = 0; c < channels; c++) {
      const ch = decoded.getChannelData(c);
      for (let i = 0; i < ch.length; i++) mix[i] += ch[i] / channels;
    }
    pcm = mix;
  }

  const feat = await state.processor(pcm);
  const { data, dims } = feat.input_features;
  // NOTE(review): hop length read from the feature-extractor config when
  // exposed; 160 (10 ms at 16 kHz) is the Whisper default.
  const hop = state.processor.feature_extractor?.config?.hop_length ?? 160;
  const T_mel = Math.min(dims[2], Math.max(1, Math.ceil(pcm.length / hop)));
  return { mel: data, dims, T_mel };
}
// ---- example mel loader ----------------------------------------------------
/**
 * Load a precomputed example mel spectrogram (stored as fp16) as float32.
 * @param {string} name - example name; must have an entry in manifest.examples
 * @returns {Promise<{mel: Float32Array, dims: number[], T_mel: number}>}
 *   `dims` is [1, 128, T] with T derived from the file size (3000 for the
 *   published examples). `T_mel` is the example's real frame count from the
 *   manifest.
 * @throws {Error} for an unknown example or a file that is not a whole number
 *   of 128-bin fp16 frames
 */
async function loadExampleMel(name) {
  const N_MELS = 128;
  const info = state.manifest?.examples?.[name];
  if (!info) throw new Error(`unknown example: ${name}`);
  const url = `${HF_ROOT}/examples_mels/${name}.mel.bin`;
  const bytes = await fetchWithCache(url, `mel ${name}`, () => {});
  if (bytes.byteLength === 0 || bytes.byteLength % (2 * N_MELS) !== 0) {
    throw new Error(`mel ${name}: ${bytes.byteLength} bytes is not a whole number of ${N_MELS}-bin fp16 frames`);
  }
  // fp16 -> fp32. Respect the view's offset/length (cached values are views);
  // copy first if the offset is not 2-byte aligned.
  const aligned = bytes.byteOffset % 2 === 0 ? bytes : bytes.slice();
  const u16 = new Uint16Array(aligned.buffer, aligned.byteOffset, aligned.byteLength / 2);
  const mel = Float32Array.from(u16, fp16ToF32);
  return { mel, dims: [1, N_MELS, u16.length / N_MELS], T_mel: info.T_mel };
}
// ---- core inference --------------------------------------------------------
// Encoder frames that carry real audio for `T_mel` real mel frames. Each full
// 100-frame mel chunk yields 13 frames; the last, partial chunk yields
// ceil(len / 8). The result is clamped to `available` (the rows the encoder
// actually returned), so the scatter never reads past audio_embeds. When T_mel
// is missing or not a number, all available rows are used.
function countAudioFrames(T_mel, available) {
  if (!Number.isFinite(T_mel)) return available;
  const realChunks = Math.floor((T_mel + 99) / 100);
  const lastChunkMel = T_mel - (realChunks - 1) * 100;
  const frames = (realChunks - 1) * 13 + Math.floor((lastChunkMel + 7) / 8);
  return Math.max(0, Math.min(frames, available));
}

// Expand each audio-pad id in `promptIds` into `audioFrames` placeholders and
// build the float32 (tokens × HIDDEN) input embeddings. Text ids come from the
// int8 table; placeholders take consecutive rows of `audioEmbeds`.
function buildPromptEmbeds(promptIds, audioPadId, audioEmbeds, audioFrames) {
  const tokens = [];
  for (const t of promptIds) {
    if (t === audioPadId) {
      for (let i = 0; i < audioFrames; i++) tokens.push(audioPadId);
    } else {
      tokens.push(t);
    }
  }
  const inputsEmbeds = new Float32Array(tokens.length * HIDDEN);
  let audioIdx = 0;
  tokens.forEach((tok, i) => {
    const dst = i * HIDDEN;
    if (tok === audioPadId) {
      const src = audioIdx * HIDDEN;
      inputsEmbeds.set(audioEmbeds.subarray(src, src + HIDDEN), dst);
      audioIdx++;
    } else {
      inputsEmbeds.set(lookupEmbedding(tok), dst);
    }
  });
  return { tokens, inputsEmbeds };
}

// Gather a decoder run's present.{i}.key / present.{i}.value outputs as one
// [key, value] pair per layer.
function collectPresentKV(out) {
  return Array.from({ length: NUM_LAYERS }, (_, i) => [out[`present.${i}.key`], out[`present.${i}.value`]]);
}

/**
 * Run the ASR pipeline on one mel spectrogram.
 *
 * 1. Encode the mel with the audio encoder.
 * 2. Build the prompt (language from #lang-select, default English), expanding
 *    <|audio_pad|> into one placeholder per real audio frame.
 * 3. Embed text tokens, scatter audio embeddings into the placeholders, and
 *    convert to fp16.
 * 4. Prefill the decoder, then decode greedily with the KV cache until an EOS
 *    id or `maxNewTokens` steps.
 *
 * @param {object} args
 * @param {Float32Array} args.mel - mel features laid out per `dims`
 * @param {number[]} args.dims - mel tensor shape (e.g. [1, 128, 3000])
 * @param {number} args.T_mel - number of real (unpadded) mel frames
 * @param {number} [args.maxNewTokens=80] - maximum decode steps after prefill
 * @returns {Promise<string>} decoded text (special tokens skipped)
 * @throws {Error} if models are not loaded or a session run fails
 */
async function transcribe({ mel, dims, T_mel, maxNewTokens = 80 }) {
  if (!state.loaded) throw new Error("models not loaded");

  // 1. encode
  setStatus("audio encoder ...");
  const melTensor = new ort.Tensor("float32", mel, dims);
  const encOut = await state.encoder.run({ mel: melTensor });
  // WebGPU outputs live in GPU memory — getData(true) downloads to CPU.
  const audioEmbedsAll = await encOut.audio_embeds.getData(true);
  const audioFrames = countAudioFrames(T_mel, Math.floor(audioEmbedsAll.length / HIDDEN));

  // 2-3. prompt + embeddings. Default to the forced-English prompt; the
  // model's auto language detection can fail at INT4 quantization on
  // borderline audio.
  setStatus("building prompt ...");
  const lang = document.getElementById("lang-select")?.value || "english";
  const promptIds = state.manifest.prompts?.[lang]?.ids || state.manifest.prompt_ids;
  const { tokens, inputsEmbeds } = buildPromptEmbeds(
    promptIds, state.manifest.audio_pad_id, audioEmbedsAll, audioFrames,
  );
  const L = tokens.length;
  // ONNX wants fp16 embeds: convert
  const inputsEmbedsF16 = floatArrayToFp16(inputsEmbeds);
  const attnMask = new BigInt64Array(L).fill(1n);
  const posIds = BigInt64Array.from({ length: L }, (_, i) => BigInt(i));

  // 4. prefill
  setStatus("prefill ...");
  const t0 = performance.now();
  const prefillOut = await state.prefill.run({
    inputs_embeds: new ort.Tensor("float16", inputsEmbedsF16, [1, L, HIDDEN]),
    attention_mask: new ort.Tensor("int64", attnMask, [1, L]),
    position_ids: new ort.Tensor("int64", posIds, [1, L]),
  });
  log(`prefill: ${(performance.now() - t0).toFixed(0)} ms (L=${L})`);

  // 5. greedy decode. eos_token_id may be a single id or a list of ids.
  setStatus("decoding ...");
  const eosRaw = state.manifest.eos_token_id;
  const eosIds = new Set(Array.isArray(eosRaw) ? eosRaw : [eosRaw]);
  const prefillLogits = await prefillOut.logits.getData(true); // (1, rows, VOCAB)
  const lastRow = prefillOut.logits.dims[1] - 1;
  let nid = argmax(prefillLogits, lastRow * VOCAB, VOCAB);
  const gen = [nid];
  let curLen = L;
  let kvs = collectPresentKV(prefillOut);

  for (let step = 0; step < maxNewTokens && !eosIds.has(nid); step++) {
    setStatus(`step ${step + 1} / ${maxNewTokens} ...`);
    const feeds = {
      inputs_embeds: new ort.Tensor("float16", floatArrayToFp16(lookupEmbedding(nid)), [1, 1, HIDDEN]),
      attention_mask: new ort.Tensor("int64", new BigInt64Array(curLen + 1).fill(1n), [1, curLen + 1]),
      position_ids: new ort.Tensor("int64", new BigInt64Array([BigInt(curLen)]), [1, 1]),
    };
    kvs.forEach(([key, value], i) => {
      feeds[`past.${i}.key`] = key;
      feeds[`past.${i}.value`] = value;
    });
    const out = await state.step.run(feeds);
    const logits = await out.logits.getData(true);
    nid = argmax(logits, 0, VOCAB);
    gen.push(nid);
    curLen += 1;
    kvs = collectPresentKV(out);
  }

  // 6. detokenize
  const filtered = gen.filter((t) => !eosIds.has(t));
  const text = await state.tokenizer.decode(filtered, { skip_special_tokens: true });
  setStatus("done");
  return text;
}
// Index (relative to `offset`) of the largest value in arr[offset .. offset+len).
// Ties resolve to the first occurrence; NaNs never win; returns 0 when no
// value beats -Infinity (including len === 0).
function argmax(arr, offset, len) {
  let winner = 0;
  let top = -Infinity;
  const end = offset + len;
  for (let idx = offset; idx < end; idx++) {
    const score = arr[idx];
    if (score > top) {
      top = score;
      winner = idx - offset;
    }
  }
  return winner;
}
// Scratch buffers for reading a float32's bit pattern. They are reused because
// f32ToF16Bits runs once per embedding element (tokens × HIDDEN per prefill),
// and allocating a typed array per call was pure overhead.
const F16_SCRATCH_F32 = new Float32Array(1);
const F16_SCRATCH_U32 = new Uint32Array(F16_SCRATCH_F32.buffer);

/**
 * Encode a number as IEEE 754 half-precision bits.
 * The value is first rounded to float32, then to float16 with round to
 * nearest, ties to even. Earlier versions truncated while claiming RNE.
 * Overflow gives ±Infinity, underflow gives ±0, and NaN gives quiet NaN (0x7e00).
 * @param {number} v
 * @returns {number} the 16-bit pattern (0..0xffff)
 */
function f32ToF16Bits(v) {
  F16_SCRATCH_F32[0] = v;
  const bits = F16_SCRATCH_U32[0];
  const sign = (bits >>> 16) & 0x8000;
  const exp = (bits >>> 23) & 0xff;
  const mant = bits & 0x7fffff;

  if (exp === 0xff) return sign | 0x7c00 | (mant ? 0x200 : 0); // inf or nan

  const halfExp = exp - 127 + 15;
  if (halfExp >= 31) return sign | 0x7c00; // too large: overflow to infinity

  if (halfExp <= 0) {
    // fp16 subnormal (or zero): value = m × 2^-24 with m = 24-bit significand >> shift.
    if (halfExp < -10) return sign; // below half the smallest subnormal: ±0
    const sig = mant | 0x800000;
    const shift = 14 - halfExp;
    let half = sig >>> shift;
    const rem = sig & ((1 << shift) - 1);
    const mid = 1 << (shift - 1);
    if (rem > mid || (rem === mid && (half & 1))) half += 1; // may carry into the min normal
    return sign | half;
  }

  // Normal: keep the top 10 mantissa bits, round on the 13 dropped bits. A
  // carry out of the mantissa correctly bumps the exponent (up to infinity).
  let half = (halfExp << 10) | (mant >>> 13);
  const rem = mant & 0x1fff;
  if (rem > 0x1000 || (rem === 0x1000 && (half & 1))) half += 1;
  return sign | half;
}
// Convert an array of numbers to fp16 storage: one Uint16 bit pattern per
// element, produced by f32ToF16Bits. ORT 1.20+ validates that float16 tensor
// data is a Float16Array instance. So when the runtime has Float16Array, the
// result is a Float16Array view over the same buffer (no copy); otherwise it
// is the Uint16Array itself.
const HAS_F16 = typeof Float16Array !== "undefined";
function floatArrayToFp16(arr) {
  const bits = Uint16Array.from(arr, f32ToF16Bits);
  return HAS_F16 ? new Float16Array(bits.buffer, bits.byteOffset, bits.length) : bits;
}
// ---- agreement scoring -----------------------------------------------------
// Canonicalize text for WER scoring. If an "<asr_text>" marker is present,
// keep only the segment right after the first one. Then lowercase, turn every
// character outside [a-z0-9] and whitespace into a space, collapse whitespace
// runs and trim.
function normalize(text) {
  const marker = "<asr_text>";
  const body = text.includes(marker) ? text.split(marker)[1] : text;
  return body
    .toLowerCase()
    .replace(/[^a-z0-9\s]/g, " ")
    .replace(/\s+/g, " ")
    .trim();
}
/**
 * Word error rate between space-separated reference and hypothesis strings
 * (word-level Levenshtein distance; empty words ignored).
 * @returns {[number, number, number]} [rate, edits, referenceWordCount]. With
 *   an empty reference the rate is 1 if the hypothesis has words, else 0.
 */
function wer(ref, hyp) {
  const refWords = ref.split(" ").filter(Boolean);
  const hypWords = hyp.split(" ").filter(Boolean);
  const n = refWords.length;
  const m = hypWords.length;
  if (n === 0) return [m > 0 ? 1 : 0, m, 0];
  // Rolling two-row edit-distance table: `prev` is row i-1, `cur` is row i.
  let prev = Int32Array.from({ length: m + 1 }, (_, j) => j);
  for (let i = 1; i <= n; i++) {
    const cur = new Int32Array(m + 1);
    cur[0] = i;
    for (let j = 1; j <= m; j++) {
      const cost = refWords[i - 1] === hypWords[j - 1] ? 0 : 1;
      cur[j] = Math.min(prev[j - 1] + cost, cur[j - 1] + 1, prev[j] + 1);
    }
    prev = cur;
  }
  const edits = prev[m];
  return [edits / n, edits, n];
}
// Characters that must be escaped when interpolating text into innerHTML.
const HTML_ESCAPES = { "&": "&amp;", "<": "&lt;", ">": "&gt;", '"': "&quot;", "'": "&#39;" };

// Escape arbitrary text (model output, user-typed reference, error messages)
// so it renders literally instead of being parsed as markup.
function escapeHtml(value) {
  return String(value).replace(/[&<>"']/g, (ch) => HTML_ESCAPES[ch]);
}

/**
 * Render a transcription into #result.
 * - Without a reference: shows the raw transcription in a neutral box.
 * - With a reference: shows normalized hypothesis and reference, an agreement
 *   score (1 - WER, floored at 0) with a colour band, and the WER details.
 * All interpolated text is HTML-escaped. The raw transcription can contain
 * tag-like tokens such as "<asr_text>", and the reference is user input.
 * @param {string} hyp - model transcription
 * @param {string} ref - reference text (may be empty)
 * @param {string} extra - footer text (timing/device info or an error message)
 */
function renderResult(hyp, ref, extra) {
  const el = document.getElementById("result");
  el.className = "result";
  const extraHtml = escapeHtml(extra);
  if (!ref || !ref.trim()) {
    el.className += " neutral";
    el.innerHTML = `<div><b>Transcription:</b> ${hyp ? escapeHtml(hyp) : "<i>(empty)</i>"}</div>
<div class="muted" style="margin-top:6px;">${extraHtml}</div>`;
    return;
  }
  const rN = normalize(ref);
  const hN = normalize(hyp);
  const [w, err, nw] = wer(rN, hN);
  const pct = Math.max(0, 1 - w) * 100;
  let cls = "red";
  let emoji = "🔴";
  let label = "diverged";
  if (pct >= 70) { cls = "green"; emoji = "✅"; label = "match"; }
  else if (pct >= 50) { cls = "orange"; emoji = "🟠"; label = "close"; }
  else if (pct >= 25) { cls = "yellow"; emoji = "🟡"; label = "partial"; }
  el.className = "result " + cls;
  el.innerHTML = `
<div class="label"><b>${emoji} ${pct.toFixed(1)}% agreement</b> &middot; ${label}</div>
<div><b>Transcription:</b> ${hN ? escapeHtml(hN) : "<i>(empty)</i>"}</div>
<div class="ref-line"><b>Reference:</b> ${escapeHtml(rN)}</div>
<div class="muted" style="margin-top:6px;">${extraHtml} &middot; WER ${(w*100).toFixed(1)}% (${err}/${nw})</div>`;
}
// ---- UI wiring -------------------------------------------------------------
// "Load model": load everything once. On failure, report it in both the log
// and the loader status (which would otherwise keep showing the last stage),
// and allow another attempt.
document.getElementById("load-btn").addEventListener("click", () => {
  loadAll().catch((e) => {
    log("LOAD FAILED: " + (e?.message ?? e));
    setLoaderStatus("load failed -- see log");
    state.loading = false;
  });
});
// Preview the chosen upload in the audio player. The blob URL of the previous
// upload is revoked first; otherwise every upload keeps its file alive for the
// rest of the page's life.
document.getElementById("audio-file").addEventListener("change", (e) => {
  const f = e.target.files[0];
  if (!f) return;
  const player = document.getElementById("audio-player");
  if (player.src.startsWith("blob:")) URL.revokeObjectURL(player.src);
  player.src = URL.createObjectURL(f);
});
// "Transcribe": run the selected example (precomputed mel) or the uploaded
// file through transcribe(), then render the result against the reference
// text. The button is disabled while running.
document.getElementById("transcribe-btn").addEventListener("click", async () => {
  const refText = document.getElementById("ref-text").value;
  const file = document.getElementById("audio-file").files[0];
  const example = document.body.dataset.example;
  if (!file && !example) {
    renderResult("", "", "Pick an audio file or example first.");
    return;
  }
  const button = document.getElementById("transcribe-btn");
  try {
    button.disabled = true;
    let mel;
    let dims;
    let T_mel;
    const t0 = performance.now();
    if (example) {
      ({ mel, dims, T_mel } = await loadExampleMel(example));
    } else {
      // Use audioToMel's real-frame count when it provides one; otherwise treat
      // the whole (padded) feature length as audio.
      ({ mel, dims, T_mel } = await audioToMel(file));
      T_mel = T_mel ?? dims[2];
    }
    const text = await transcribe({ mel, dims, T_mel });
    const elapsed = (performance.now() - t0) / 1000;
    renderResult(text, refText, `INT4 enc + GPTQ-INT4 dec · ${state.device} · ${elapsed.toFixed(1)}s`);
  } catch (e) {
    const msg = (e && (e.message || e.toString())) || JSON.stringify(e) || "(no error info)";
    const stk = (e && e.stack) ? e.stack.split("\n").slice(0, 3).join(" | ") : "(no stack)";
    log("TRANSCRIBE FAILED: " + msg);
    log("stack: " + stk);
    console.error(e);
    renderResult("", refText, `error: ${msg}`);
  } finally {
    button.disabled = false;
  }
});
// One button per REFERENCES entry. Clicking one selects that example: it is
// recorded on <body data-example>, its reference transcript fills the text
// box, any chosen upload is cleared, and the player loads the example's WAV.
const examplesEl = document.getElementById("examples");
const exampleEmojis = {
  noise: "🔊", far_field: "📡", obstructed: "🚧", distortion: "🎛️",
  recording: "🎙️", echo: "🏛️", dropout: "✂️", mixed: "🌪️",
};
Object.entries(REFERENCES).forEach(([name, ref]) => {
  const button = document.createElement("button");
  button.textContent = `${exampleEmojis[name]} ${name}`;
  button.addEventListener("click", () => {
    document.body.dataset.example = name;
    document.getElementById("ref-text").value = ref;
    document.getElementById("audio-file").value = "";
    document.getElementById("audio-player").src = `${HF_ROOT}/examples/${name}.wav`;
  });
  examplesEl.appendChild(button);
});
// Choosing an upload deselects any example, so "Transcribe" uses the file.
const clearSelectedExample = () => {
  document.body.dataset.example = "";
};
document.getElementById("audio-file").addEventListener("change", clearSelectedExample);
log("page loaded; click 'Load model' to start.");