Reza2kn commited on
Commit
1320aec
·
verified ·
1 Parent(s): ed37d13

Await tensor.getData() for WebGPU outputs (audio_embeds + logits) so data is actually copied back to CPU

Browse files
Files changed (1) hide show
  1. mega-asr.js +6 -3
mega-asr.js CHANGED
@@ -322,7 +322,8 @@ async function transcribe({ mel, dims, T_mel }) {
322
  setStatus("audio encoder ...");
323
  const melTensor = new ort.Tensor("float32", mel, dims);
324
  const encOut = await state.encoder.run({ mel: melTensor });
325
- const audioEmbedsAll = encOut.audio_embeds.data; // Float32Array (1*390*2048,)
 
326
  const audioEmbedsDims = encOut.audio_embeds.dims; // [1, 390, 2048]
327
  const realChunks = Math.floor((T_mel + 99) / 100);
328
  const lastChunkMel = T_mel - (realChunks - 1) * 100;
@@ -375,7 +376,9 @@ async function transcribe({ mel, dims, T_mel }) {
375
 
376
  // 5. greedy decode
377
  setStatus("decoding ...");
378
- let logits = prefillOut.logits.data; // (1, L, VOCAB)
 
 
379
  const logitsDims = prefillOut.logits.dims;
380
  // Diagnostic: dump top-5 of last logit so we can see what the decoder predicted
381
  {
@@ -419,7 +422,7 @@ async function transcribe({ mel, dims, T_mel }) {
419
  feeds[`past.${i}.value`] = kvs[2 * i + 1];
420
  }
421
  const out = await state.step.run(feeds);
422
- logits = out.logits.data;
423
  nid = argmax(logits, 0, VOCAB);
424
  gen.push(nid);
425
  curLen += 1;
 
322
  setStatus("audio encoder ...");
323
  const melTensor = new ort.Tensor("float32", mel, dims);
324
  const encOut = await state.encoder.run({ mel: melTensor });
325
+ // For WebGPU outputs we must await getData() to bring values back to CPU.
326
+ const audioEmbedsAll = await encOut.audio_embeds.getData(true); // Float32Array (1*390*2048,)
327
  const audioEmbedsDims = encOut.audio_embeds.dims; // [1, 390, 2048]
328
  const realChunks = Math.floor((T_mel + 99) / 100);
329
  const lastChunkMel = T_mel - (realChunks - 1) * 100;
 
376
 
377
  // 5. greedy decode
378
  setStatus("decoding ...");
379
+ // WebGPU outputs live in GPU memory — must call getData() (async) to bring
380
+ // them back to CPU. CPU/WASM tensors return their data array synchronously.
381
+ let logits = await prefillOut.logits.getData(true); // (1, L, VOCAB)
382
  const logitsDims = prefillOut.logits.dims;
383
  // Diagnostic: dump top-5 of last logit so we can see what the decoder predicted
384
  {
 
422
  feeds[`past.${i}.value`] = kvs[2 * i + 1];
423
  }
424
  const out = await state.step.run(feeds);
425
+ logits = await out.logits.getData(true);
426
  nid = argmax(logits, 0, VOCAB);
427
  gen.push(nid);
428
  curLen += 1;