Reza2kn commited on
Commit
af482b4
·
verified ·
1 Parent(s): a821180

Float16Array for fp16 tensors; per-session WebGPU fallback (decoders stay on webgpu)

Browse files
Files changed (1) hide show
  1. mega-asr.js +14 -6
mega-asr.js CHANGED
@@ -116,8 +116,8 @@ async function fetchWithCache(url, label, onProgress) {
116
  }
117
 
118
  // ---- ONNX session creation -------------------------------------------------
119
- // Build EP list: prefer requested device, always include WASM as fallback so
120
- // that "backend not found" / device init failures don't abort the whole load.
121
  function epList() {
122
  return state.device === "webgpu" ? ["webgpu", "wasm"] : ["wasm"];
123
  }
@@ -130,8 +130,7 @@ async function createSessionSimple(graphUrl, label, onProgress) {
130
  return sess;
131
  } catch (e) {
132
  if (state.device === "webgpu") {
133
- log(`webgpu failed for ${label} (${e.message}); retrying with wasm`);
134
- state.device = "wasm";
135
  const sess = await ort.InferenceSession.create(graph, { executionProviders: ["wasm"] });
136
  log(`session ready: ${label} (wasm fallback)`);
137
  return sess;
@@ -152,8 +151,7 @@ async function createSession(graphUrl, dataUrl, label, onProgress) {
152
  return sess;
153
  } catch (e) {
154
  if (state.device === "webgpu") {
155
- log(`webgpu failed for ${label} (${e.message}); retrying with wasm`);
156
- state.device = "wasm";
157
  const sess = await ort.InferenceSession.create(graph, {
158
  executionProviders: ["wasm"], externalData: externalFiles,
159
  });
@@ -451,7 +449,17 @@ function f32ToF16Bits(v) {
451
  return (sign << 15) | (newExp << 10) | (frac >> 13);
452
  }
453
 
 
 
 
 
 
454
  function floatArrayToFp16(arr) {
 
 
 
 
 
455
  const u16 = new Uint16Array(arr.length);
456
  for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
457
  return u16;
 
116
  }
117
 
118
  // ---- ONNX session creation -------------------------------------------------
119
+ // Always prefer the user-selected device; fall back to WASM only for the
120
+ // session that fails (per-session, not global). Don't mutate state.device.
121
  function epList() {
122
  return state.device === "webgpu" ? ["webgpu", "wasm"] : ["wasm"];
123
  }
 
130
  return sess;
131
  } catch (e) {
132
  if (state.device === "webgpu") {
133
+ log(`webgpu failed for ${label} (${e.message}); retrying this session with wasm`);
 
134
  const sess = await ort.InferenceSession.create(graph, { executionProviders: ["wasm"] });
135
  log(`session ready: ${label} (wasm fallback)`);
136
  return sess;
 
151
  return sess;
152
  } catch (e) {
153
  if (state.device === "webgpu") {
154
+ log(`webgpu failed for ${label} (${e.message}); retrying this session with wasm`);
 
155
  const sess = await ort.InferenceSession.create(graph, {
156
  executionProviders: ["wasm"], externalData: externalFiles,
157
  });
 
449
  return (sign << 15) | (newExp << 10) | (frac >> 13);
450
  }
451
 
452
+ // onnxruntime-web 1.20+ wants fp16 data as Float16Array (Chrome 134+); fall
453
+ // back to Uint16Array bit-pattern path on older engines (ORT will treat it
454
+ // as raw fp16 bytes).
455
+ const HAS_F16 = typeof Float16Array !== "undefined";
456
+
457
  function floatArrayToFp16(arr) {
458
+ if (HAS_F16) {
459
+ const out = new Float16Array(arr.length);
460
+ for (let i = 0; i < arr.length; i++) out[i] = arr[i];
461
+ return out;
462
+ }
463
  const u16 = new Uint16Array(arr.length);
464
  for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
465
  return u16;