Reza2kn committed on
Commit
2cf4acc
·
verified ·
1 Parent(s): b8c2d24

Clean up debug diagnostics now that WebGPU works end-to-end

Browse files
Files changed (1) hide show
  1. mega-asr.js +16 -67
mega-asr.js CHANGED
@@ -309,19 +309,11 @@ async function transcribe({ mel, dims, T_mel }) {
309
  if (!state.loaded) throw new Error("models not loaded");
310
  // 1. encode
311
  setStatus("audio encoder ...");
312
- let melTensor;
313
- try { melTensor = new ort.Tensor("float32", mel, dims); }
314
- catch (e) { log(`[step] Tensor ctor failed: ${e.message || e}`); throw e; }
315
- log(`[step] running encoder ...`);
316
- let encOut;
317
- try { encOut = await state.encoder.run({ mel: melTensor }); }
318
- catch (e) { log(`[step] encoder.run failed: ${e.message || e}`); throw e; }
319
- log(`[step] encoder ok; reading audio_embeds ...`);
320
- let audioEmbedsAll;
321
- try { audioEmbedsAll = await encOut.audio_embeds.getData(true); }
322
- catch (e) { log(`[step] getData failed: ${e.message || e}`); throw e; }
323
- log(`[step] audio_embeds len=${audioEmbedsAll.length} dtype=${audioEmbedsAll.constructor.name}`);
324
- const audioEmbedsDims = encOut.audio_embeds.dims; // [1, 390, 2048]
325
  const realChunks = Math.floor((T_mel + 99) / 100);
326
  const lastChunkMel = T_mel - (realChunks - 1) * 100;
327
  const realAudioFrames = (realChunks - 1) * 13 + Math.floor((lastChunkMel + 7) / 8);
@@ -364,15 +356,11 @@ async function transcribe({ mel, dims, T_mel }) {
364
  // 4. prefill
365
  setStatus("prefill ...");
366
  const t0 = performance.now();
367
- log(`[step] running prefill (L=${L}) ...`);
368
- let prefillOut;
369
- try {
370
- prefillOut = await state.prefill.run({
371
- inputs_embeds: new ort.Tensor("float16", inputsEmbedsF16, [1, L, HIDDEN]),
372
- attention_mask: new ort.Tensor("int64", attnMask, [1, L]),
373
- position_ids: new ort.Tensor("int64", posIds, [1, L]),
374
- });
375
- } catch (e) { log(`[step] prefill.run failed: ${e.message || e}`); throw e; }
376
  log(`prefill: ${(performance.now() - t0).toFixed(0)} ms (L=${L})`);
377
 
378
  // 5. greedy decode
@@ -381,33 +369,6 @@ async function transcribe({ mel, dims, T_mel }) {
381
  // them back to CPU. CPU/WASM tensors return their data array synchronously.
382
  let logits = await prefillOut.logits.getData(true); // (1, L, VOCAB)
383
  const logitsDims = prefillOut.logits.dims;
384
- // Diagnostic
385
- {
386
- const off = (logitsDims[1] - 1) * VOCAB;
387
- log(`logits type=${logits.constructor.name} len=${logits.length} dims=[${logitsDims.join(",")}] expect=${1 * logitsDims[1] * VOCAB}`);
388
- const sample = [];
389
- for (let i = 0; i < 8; i++) sample.push(String(logits[off + i]));
390
- log(`first 8 logits at last pos: ${sample.join(", ")}`);
391
- let nNaN = 0, nInf = 0, nNorm = 0, lo = Infinity, hi = -Infinity;
392
- for (let i = 0; i < VOCAB; i++) {
393
- const v = Number(logits[off + i]);
394
- if (Number.isNaN(v)) nNaN++;
395
- else if (!Number.isFinite(v)) nInf++;
396
- else { nNorm++; if (v < lo) lo = v; if (v > hi) hi = v; }
397
- }
398
- log(`logit health: NaN=${nNaN} Inf=${nInf} finite=${nNorm} range=[${lo.toFixed(2)}, ${hi.toFixed(2)}]`);
399
- const idxs = [], vals = [];
400
- for (let k = 0; k < 5; k++) {
401
- let best = -Infinity, bi = -1;
402
- for (let i = 0; i < VOCAB; i++) {
403
- if (idxs.includes(i)) continue;
404
- const v = Number(logits[off + i]);
405
- if (Number.isFinite(v) && v > best) { best = v; bi = i; }
406
- }
407
- idxs.push(bi); vals.push(best);
408
- }
409
- log(`prefill top-5: ${idxs.map((i, k) => `${i}(${vals[k].toFixed(2)})`).join(" ")}`);
410
- }
411
  // get argmax of last token
412
  let nid = argmax(logits, (logitsDims[1] - 1) * VOCAB, VOCAB);
413
  const gen = [nid];
@@ -482,28 +443,16 @@ function f32ToF16Bits(v) {
482
  return (sign << 15) | (newExp << 10) | (frac >> 13);
483
  }
484
 
485
- // Build fp16 storage: if browser has Float16Array, use it directly (ORT 1.20
486
- // validates the constructor). Otherwise build a Uint16Array of bit patterns
487
- // and view it as a Float16Array if available. Diagnostics: also dump the
488
- // first few converted values once so we can spot conversion errors.
489
  const HAS_F16 = typeof Float16Array !== "undefined";
490
- let _f16_diag_count = 0;
491
 
492
  function floatArrayToFp16(arr) {
493
- // Build the u16 bit-pattern explicitly (canonical round-to-nearest-even)
494
  const u16 = new Uint16Array(arr.length);
495
  for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
496
- if (HAS_F16) {
497
- // View the same buffer as Float16Array so ORT's type validation passes.
498
- const f16 = new Float16Array(u16.buffer, u16.byteOffset, u16.length);
499
- if (_f16_diag_count === 0) {
500
- _f16_diag_count = 1;
501
- const sample = [];
502
- for (let i = 0; i < Math.min(5, arr.length); i++) sample.push(arr[i].toFixed(4) + "->" + f16[i].toFixed(4));
503
- log(`fp16 sanity: ${sample.join(" ")}`);
504
- }
505
- return f16;
506
- }
507
  return u16;
508
  }
509
 
@@ -589,7 +538,7 @@ document.getElementById("transcribe-btn").addEventListener("click", async () =>
589
  }
590
  const text = await transcribe({ mel, dims, T_mel });
591
  const elapsed = (performance.now() - t0) / 1000;
592
- renderResult(text, refText, `INT4 ONNX · ${state.device} · ${elapsed.toFixed(1)}s`);
593
  } catch (e) {
594
  const msg = (e && (e.message || e.toString())) || JSON.stringify(e) || "(no error info)";
595
  const stk = (e && e.stack) ? e.stack.split("\n").slice(0, 3).join(" | ") : "(no stack)";
 
309
  if (!state.loaded) throw new Error("models not loaded");
310
  // 1. encode
311
  setStatus("audio encoder ...");
312
+ const melTensor = new ort.Tensor("float32", mel, dims);
313
+ const encOut = await state.encoder.run({ mel: melTensor });
314
+ // WebGPU outputs live in GPU memory; getData(true) downloads them to CPU.
315
+ const audioEmbedsAll = await encOut.audio_embeds.getData(true);
316
+ const audioEmbedsDims = encOut.audio_embeds.dims;
 
 
 
 
 
 
 
 
317
  const realChunks = Math.floor((T_mel + 99) / 100);
318
  const lastChunkMel = T_mel - (realChunks - 1) * 100;
319
  const realAudioFrames = (realChunks - 1) * 13 + Math.floor((lastChunkMel + 7) / 8);
 
356
  // 4. prefill
357
  setStatus("prefill ...");
358
  const t0 = performance.now();
359
+ const prefillOut = await state.prefill.run({
360
+ inputs_embeds: new ort.Tensor("float16", inputsEmbedsF16, [1, L, HIDDEN]),
361
+ attention_mask: new ort.Tensor("int64", attnMask, [1, L]),
362
+ position_ids: new ort.Tensor("int64", posIds, [1, L]),
363
+ });
 
 
 
 
364
  log(`prefill: ${(performance.now() - t0).toFixed(0)} ms (L=${L})`);
365
 
366
  // 5. greedy decode
 
369
  // them back to CPU. CPU/WASM tensors return their data array synchronously.
370
  let logits = await prefillOut.logits.getData(true); // (1, L, VOCAB)
371
  const logitsDims = prefillOut.logits.dims;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  // get argmax of last token
373
  let nid = argmax(logits, (logitsDims[1] - 1) * VOCAB, VOCAB);
374
  const gen = [nid];
 
443
  return (sign << 15) | (newExp << 10) | (frac >> 13);
444
  }
445
 
446
+ // Build fp16 storage: explicit Uint16 bit-pattern conversion (canonical
447
+ // round-to-nearest-even). ORT 1.20+ validates that the data is a Float16Array
448
+ // instance, so when available we return a Float16Array view over the same
449
+ // buffer (no copy).
450
  const HAS_F16 = typeof Float16Array !== "undefined";
 
451
 
452
  function floatArrayToFp16(arr) {
 
453
  const u16 = new Uint16Array(arr.length);
454
  for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
455
+ if (HAS_F16) return new Float16Array(u16.buffer, u16.byteOffset, u16.length);
 
 
 
 
 
 
 
 
 
 
456
  return u16;
457
  }
458
 
 
538
  }
539
  const text = await transcribe({ mel, dims, T_mel });
540
  const elapsed = (performance.now() - t0) / 1000;
541
+ renderResult(text, refText, `INT4 enc + GPTQ-INT4 dec · ${state.device} · ${elapsed.toFixed(1)}s`);
542
  } catch (e) {
543
  const msg = (e && (e.message || e.toString())) || JSON.stringify(e) || "(no error info)";
544
  const stk = (e && e.stack) ? e.stack.split("\n").slice(0, 3).join(" | ") : "(no stack)";