Spaces:
Running
Running
Clean up debug diagnostics now that WebGPU works end-to-end
Browse files - mega-asr.js +16 -67
mega-asr.js
CHANGED
|
@@ -309,19 +309,11 @@ async function transcribe({ mel, dims, T_mel }) {
|
|
| 309 |
if (!state.loaded) throw new Error("models not loaded");
|
| 310 |
// 1. encode
|
| 311 |
setStatus("audio encoder ...");
|
| 312 |
-
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
try { encOut = await state.encoder.run({ mel: melTensor }); }
|
| 318 |
-
catch (e) { log(`[step] encoder.run failed: ${e.message || e}`); throw e; }
|
| 319 |
-
log(`[step] encoder ok; reading audio_embeds ...`);
|
| 320 |
-
let audioEmbedsAll;
|
| 321 |
-
try { audioEmbedsAll = await encOut.audio_embeds.getData(true); }
|
| 322 |
-
catch (e) { log(`[step] getData failed: ${e.message || e}`); throw e; }
|
| 323 |
-
log(`[step] audio_embeds len=${audioEmbedsAll.length} dtype=${audioEmbedsAll.constructor.name}`);
|
| 324 |
-
const audioEmbedsDims = encOut.audio_embeds.dims; // [1, 390, 2048]
|
| 325 |
const realChunks = Math.floor((T_mel + 99) / 100);
|
| 326 |
const lastChunkMel = T_mel - (realChunks - 1) * 100;
|
| 327 |
const realAudioFrames = (realChunks - 1) * 13 + Math.floor((lastChunkMel + 7) / 8);
|
|
@@ -364,15 +356,11 @@ async function transcribe({ mel, dims, T_mel }) {
|
|
| 364 |
// 4. prefill
|
| 365 |
setStatus("prefill ...");
|
| 366 |
const t0 = performance.now();
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
attention_mask: new ort.Tensor("int64", attnMask, [1, L]),
|
| 373 |
-
position_ids: new ort.Tensor("int64", posIds, [1, L]),
|
| 374 |
-
});
|
| 375 |
-
} catch (e) { log(`[step] prefill.run failed: ${e.message || e}`); throw e; }
|
| 376 |
log(`prefill: ${(performance.now() - t0).toFixed(0)} ms (L=${L})`);
|
| 377 |
|
| 378 |
// 5. greedy decode
|
|
@@ -381,33 +369,6 @@ async function transcribe({ mel, dims, T_mel }) {
|
|
| 381 |
// them back to CPU. CPU/WASM tensors return their data array synchronously.
|
| 382 |
let logits = await prefillOut.logits.getData(true); // (1, L, VOCAB)
|
| 383 |
const logitsDims = prefillOut.logits.dims;
|
| 384 |
-
// Diagnostic
|
| 385 |
-
{
|
| 386 |
-
const off = (logitsDims[1] - 1) * VOCAB;
|
| 387 |
-
log(`logits type=${logits.constructor.name} len=${logits.length} dims=[${logitsDims.join(",")}] expect=${1 * logitsDims[1] * VOCAB}`);
|
| 388 |
-
const sample = [];
|
| 389 |
-
for (let i = 0; i < 8; i++) sample.push(String(logits[off + i]));
|
| 390 |
-
log(`first 8 logits at last pos: ${sample.join(", ")}`);
|
| 391 |
-
let nNaN = 0, nInf = 0, nNorm = 0, lo = Infinity, hi = -Infinity;
|
| 392 |
-
for (let i = 0; i < VOCAB; i++) {
|
| 393 |
-
const v = Number(logits[off + i]);
|
| 394 |
-
if (Number.isNaN(v)) nNaN++;
|
| 395 |
-
else if (!Number.isFinite(v)) nInf++;
|
| 396 |
-
else { nNorm++; if (v < lo) lo = v; if (v > hi) hi = v; }
|
| 397 |
-
}
|
| 398 |
-
log(`logit health: NaN=${nNaN} Inf=${nInf} finite=${nNorm} range=[${lo.toFixed(2)}, ${hi.toFixed(2)}]`);
|
| 399 |
-
const idxs = [], vals = [];
|
| 400 |
-
for (let k = 0; k < 5; k++) {
|
| 401 |
-
let best = -Infinity, bi = -1;
|
| 402 |
-
for (let i = 0; i < VOCAB; i++) {
|
| 403 |
-
if (idxs.includes(i)) continue;
|
| 404 |
-
const v = Number(logits[off + i]);
|
| 405 |
-
if (Number.isFinite(v) && v > best) { best = v; bi = i; }
|
| 406 |
-
}
|
| 407 |
-
idxs.push(bi); vals.push(best);
|
| 408 |
-
}
|
| 409 |
-
log(`prefill top-5: ${idxs.map((i, k) => `${i}(${vals[k].toFixed(2)})`).join(" ")}`);
|
| 410 |
-
}
|
| 411 |
// get argmax of last token
|
| 412 |
let nid = argmax(logits, (logitsDims[1] - 1) * VOCAB, VOCAB);
|
| 413 |
const gen = [nid];
|
|
@@ -482,28 +443,16 @@ function f32ToF16Bits(v) {
|
|
| 482 |
return (sign << 15) | (newExp << 10) | (frac >> 13);
|
| 483 |
}
|
| 484 |
|
| 485 |
-
// Build fp16 storage:
|
| 486 |
-
//
|
| 487 |
-
//
|
| 488 |
-
//
|
| 489 |
const HAS_F16 = typeof Float16Array !== "undefined";
|
| 490 |
-
let _f16_diag_count = 0;
|
| 491 |
|
| 492 |
function floatArrayToFp16(arr) {
|
| 493 |
-
// Build the u16 bit-pattern explicitly (canonical round-to-nearest-even)
|
| 494 |
const u16 = new Uint16Array(arr.length);
|
| 495 |
for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
|
| 496 |
-
if (HAS_F16) {
|
| 497 |
-
// View the same buffer as Float16Array so ORT's type validation passes.
|
| 498 |
-
const f16 = new Float16Array(u16.buffer, u16.byteOffset, u16.length);
|
| 499 |
-
if (_f16_diag_count === 0) {
|
| 500 |
-
_f16_diag_count = 1;
|
| 501 |
-
const sample = [];
|
| 502 |
-
for (let i = 0; i < Math.min(5, arr.length); i++) sample.push(arr[i].toFixed(4) + "->" + f16[i].toFixed(4));
|
| 503 |
-
log(`fp16 sanity: ${sample.join(" ")}`);
|
| 504 |
-
}
|
| 505 |
-
return f16;
|
| 506 |
-
}
|
| 507 |
return u16;
|
| 508 |
}
|
| 509 |
|
|
@@ -589,7 +538,7 @@ document.getElementById("transcribe-btn").addEventListener("click", async () =>
|
|
| 589 |
}
|
| 590 |
const text = await transcribe({ mel, dims, T_mel });
|
| 591 |
const elapsed = (performance.now() - t0) / 1000;
|
| 592 |
-
renderResult(text, refText, `INT4
|
| 593 |
} catch (e) {
|
| 594 |
const msg = (e && (e.message || e.toString())) || JSON.stringify(e) || "(no error info)";
|
| 595 |
const stk = (e && e.stack) ? e.stack.split("\n").slice(0, 3).join(" | ") : "(no stack)";
|
|
|
|
| 309 |
if (!state.loaded) throw new Error("models not loaded");
|
| 310 |
// 1. encode
|
| 311 |
setStatus("audio encoder ...");
|
| 312 |
+
const melTensor = new ort.Tensor("float32", mel, dims);
|
| 313 |
+
const encOut = await state.encoder.run({ mel: melTensor });
|
| 314 |
+
// WebGPU outputs live in GPU memory — getData(true) downloads to CPU.
|
| 315 |
+
const audioEmbedsAll = await encOut.audio_embeds.getData(true);
|
| 316 |
+
const audioEmbedsDims = encOut.audio_embeds.dims;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 317 |
const realChunks = Math.floor((T_mel + 99) / 100);
|
| 318 |
const lastChunkMel = T_mel - (realChunks - 1) * 100;
|
| 319 |
const realAudioFrames = (realChunks - 1) * 13 + Math.floor((lastChunkMel + 7) / 8);
|
|
|
|
| 356 |
// 4. prefill
|
| 357 |
setStatus("prefill ...");
|
| 358 |
const t0 = performance.now();
|
| 359 |
+
const prefillOut = await state.prefill.run({
|
| 360 |
+
inputs_embeds: new ort.Tensor("float16", inputsEmbedsF16, [1, L, HIDDEN]),
|
| 361 |
+
attention_mask: new ort.Tensor("int64", attnMask, [1, L]),
|
| 362 |
+
position_ids: new ort.Tensor("int64", posIds, [1, L]),
|
| 363 |
+
});
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
log(`prefill: ${(performance.now() - t0).toFixed(0)} ms (L=${L})`);
|
| 365 |
|
| 366 |
// 5. greedy decode
|
|
|
|
| 369 |
// them back to CPU. CPU/WASM tensors return their data array synchronously.
|
| 370 |
let logits = await prefillOut.logits.getData(true); // (1, L, VOCAB)
|
| 371 |
const logitsDims = prefillOut.logits.dims;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
// get argmax of last token
|
| 373 |
let nid = argmax(logits, (logitsDims[1] - 1) * VOCAB, VOCAB);
|
| 374 |
const gen = [nid];
|
|
|
|
| 443 |
return (sign << 15) | (newExp << 10) | (frac >> 13);
|
| 444 |
}
|
| 445 |
|
| 446 |
+
// Build fp16 storage: explicit Uint16 bit-pattern conversion (canonical
|
| 447 |
+
// round-to-nearest-even). ORT 1.20+ validates that the data is a Float16Array
|
| 448 |
+
// instance, so when available we return a Float16Array view over the same
|
| 449 |
+
// buffer (no copy).
|
| 450 |
const HAS_F16 = typeof Float16Array !== "undefined";
|
|
|
|
| 451 |
|
| 452 |
function floatArrayToFp16(arr) {
|
|
|
|
| 453 |
const u16 = new Uint16Array(arr.length);
|
| 454 |
for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
|
| 455 |
+
if (HAS_F16) return new Float16Array(u16.buffer, u16.byteOffset, u16.length);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 456 |
return u16;
|
| 457 |
}
|
| 458 |
|
|
|
|
| 538 |
}
|
| 539 |
const text = await transcribe({ mel, dims, T_mel });
|
| 540 |
const elapsed = (performance.now() - t0) / 1000;
|
| 541 |
+
renderResult(text, refText, `INT4 enc + GPTQ-INT4 dec · ${state.device} · ${elapsed.toFixed(1)}s`);
|
| 542 |
} catch (e) {
|
| 543 |
const msg = (e && (e.message || e.toString())) || JSON.stringify(e) || "(no error info)";
|
| 544 |
const stk = (e && e.stack) ? e.stack.split("\n").slice(0, 3).join(" | ") : "(no stack)";
|