Spaces:
Running
Running
Float16Array for fp16 tensors; per-session WebGPU fallback (decoders stay on webgpu)
Browse files- mega-asr.js +14 -6
mega-asr.js
CHANGED
|
@@ -116,8 +116,8 @@ async function fetchWithCache(url, label, onProgress) {
|
|
| 116 |
}
|
| 117 |
|
| 118 |
// ---- ONNX session creation -------------------------------------------------
|
| 119 |
-
//
|
| 120 |
-
// that
|
| 121 |
// Returns an ordered list of backend names derived from state.device:
// ["webgpu", "wasm"] when state.device is "webgpu", otherwise just ["wasm"].
// A fresh array is built on every call.
// NOTE(review): presumably passed as the executionProviders option to
// ort.InferenceSession.create — the call site is not visible here; confirm.
function epList() {
|
| 122 |
return state.device === "webgpu" ? ["webgpu", "wasm"] : ["wasm"];
|
| 123 |
}
|
|
@@ -130,8 +130,7 @@ async function createSessionSimple(graphUrl, label, onProgress) {
|
|
| 130 |
return sess;
|
| 131 |
} catch (e) {
|
| 132 |
if (state.device === "webgpu") {
|
| 133 |
-
log(`webgpu failed for ${label} (${e.message}); retrying with wasm`);
|
| 134 |
-
state.device = "wasm";
|
| 135 |
const sess = await ort.InferenceSession.create(graph, { executionProviders: ["wasm"] });
|
| 136 |
log(`session ready: ${label} (wasm fallback)`);
|
| 137 |
return sess;
|
|
@@ -152,8 +151,7 @@ async function createSession(graphUrl, dataUrl, label, onProgress) {
|
|
| 152 |
return sess;
|
| 153 |
} catch (e) {
|
| 154 |
if (state.device === "webgpu") {
|
| 155 |
-
log(`webgpu failed for ${label} (${e.message}); retrying with wasm`);
|
| 156 |
-
state.device = "wasm";
|
| 157 |
const sess = await ort.InferenceSession.create(graph, {
|
| 158 |
executionProviders: ["wasm"], externalData: externalFiles,
|
| 159 |
});
|
|
@@ -451,7 +449,17 @@ function f32ToF16Bits(v) {
|
|
| 451 |
return (sign << 15) | (newExp << 10) | (frac >> 13);
|
| 452 |
}
|
| 453 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 454 |
function floatArrayToFp16(arr) {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 455 |
const u16 = new Uint16Array(arr.length);
|
| 456 |
for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
|
| 457 |
return u16;
|
|
|
|
| 116 |
}
|
| 117 |
|
| 118 |
// ---- ONNX session creation -------------------------------------------------
|
| 119 |
+
// Always prefer the user-selected device; fall back to WASM only for the
|
| 120 |
+
// session that fails (per-session, not global). Don't mutate state.device.
|
| 121 |
// Returns an ordered list of backend names derived from state.device:
// ["webgpu", "wasm"] when state.device is "webgpu", otherwise just ["wasm"].
// Only reads state.device; a fresh array is built on every call.
// NOTE(review): presumably passed as the executionProviders option to
// ort.InferenceSession.create — the call site is not visible here; confirm.
function epList() {
|
| 122 |
return state.device === "webgpu" ? ["webgpu", "wasm"] : ["wasm"];
|
| 123 |
}
|
|
|
|
| 130 |
return sess;
|
| 131 |
} catch (e) {
|
| 132 |
if (state.device === "webgpu") {
|
| 133 |
+
log(`webgpu failed for ${label} (${e.message}); retrying this session with wasm`);
|
|
|
|
| 134 |
const sess = await ort.InferenceSession.create(graph, { executionProviders: ["wasm"] });
|
| 135 |
log(`session ready: ${label} (wasm fallback)`);
|
| 136 |
return sess;
|
|
|
|
| 151 |
return sess;
|
| 152 |
} catch (e) {
|
| 153 |
if (state.device === "webgpu") {
|
| 154 |
+
log(`webgpu failed for ${label} (${e.message}); retrying this session with wasm`);
|
|
|
|
| 155 |
const sess = await ort.InferenceSession.create(graph, {
|
| 156 |
executionProviders: ["wasm"], externalData: externalFiles,
|
| 157 |
});
|
|
|
|
| 449 |
return (sign << 15) | (newExp << 10) | (frac >> 13);
|
| 450 |
}
|
| 451 |
|
| 452 |
+
// onnxruntime-web 1.20+ wants fp16 data as Float16Array (Chrome 134+); fall
|
| 453 |
+
// back to Uint16Array bit-pattern path on older engines (ORT will treat it
|
| 454 |
+
// as raw fp16 bytes).
|
| 455 |
+
const HAS_F16 = typeof Float16Array !== "undefined";
|
| 456 |
+
|
| 457 |
function floatArrayToFp16(arr) {
|
| 458 |
+
if (HAS_F16) {
|
| 459 |
+
const out = new Float16Array(arr.length);
|
| 460 |
+
for (let i = 0; i < arr.length; i++) out[i] = arr[i];
|
| 461 |
+
return out;
|
| 462 |
+
}
|
| 463 |
const u16 = new Uint16Array(arr.length);
|
| 464 |
for (let i = 0; i < arr.length; i++) u16[i] = f32ToF16Bits(arr[i]);
|
| 465 |
return u16;
|