// bonsai-webgpu / src/worker.js
// Author: Xenova (HF Staff) — commit "Upload 468 files" (cbb6a01, verified)
import {
pipeline,
TextStreamer,
DynamicCache,
InterruptableStoppingCriteria,
} from "@huggingface/transformers";
// Lookup table from short UI model keys to Hugging Face model IDs.
// Frozen so the table cannot be mutated at runtime by accident.
const MODEL_IDS = Object.freeze({
  "1.7b": "onnx-community/Bonsai-1.7B-ONNX",
});
/**
 * Probes for WebGPU support by requesting an adapter.
 * On failure (no WebGPU, or no adapter), reports the error to the main
 * thread; on success it completes silently.
 */
async function check() {
  try {
    const gpu = navigator.gpu;
    const adapter = gpu ? await gpu.requestAdapter() : null;
    if (adapter) return;
    throw new Error("WebGPU is not supported (no adapter found)");
  } catch (e) {
    self.postMessage({ status: "error", data: e.toString() });
  }
}
/**
 * Lazily-constructed, per-model-key singleton for the text-generation
 * pipeline. Each key maps to the Promise returned by pipeline(), so
 * concurrent callers share a single in-flight load.
 */
class TextGenerationPipeline {
  static instances = new Map();

  static getInstance(modelKey, progress_callback = null) {
    const modelId = MODEL_IDS[modelKey];
    if (!modelId) throw new Error(`Unknown model: ${modelKey}`);

    let instance = this.instances.get(modelKey);
    if (instance === undefined) {
      // NOTE(review): dtype "q1" looks like a 1-bit quantization tag for this
      // Bonsai build — confirm it is a dtype this transformers.js version accepts.
      instance = pipeline("text-generation", modelId, {
        device: "webgpu",
        dtype: "q1",
        progress_callback,
      });
      this.instances.set(modelKey, instance);
    }
    return instance;
  }
}
// Lets the main thread abort an in-flight generation via the "interrupt" message.
const stopping_criteria = new InterruptableStoppingCriteria();
// KV cache reused across generate() calls; created lazily, cleared on "reset"
// and when switching models.
let past_key_values_cache = null;
// Key of the currently loaded model (e.g. "1.7b"); null until load() runs.
let current_model_key = null;
/**
 * Releases the cached key/value tensors (if any) and clears the reference,
 * so the next generate() call starts from a fresh cache.
 */
function disposePastKeyValues() {
  if (past_key_values_cache != null) {
    past_key_values_cache.dispose?.();
    past_key_values_cache = null;
  }
}
/**
 * Loads (or reuses) the pipeline for `modelKey`, forwards aggregate download
 * progress to the main thread, runs a one-token warm-up generation, then
 * posts { status: "ready" }. Failures are posted as { status: "error" }.
 * @param {string} modelKey - Key into MODEL_IDS (e.g. "1.7b").
 */
async function load(modelKey) {
  // Switching models invalidates the previous model's KV cache.
  if (current_model_key && current_model_key !== modelKey) {
    disposePastKeyValues();
  }
  current_model_key = modelKey;

  self.postMessage({ status: "loading", data: "Loading model..." });

  try {
    const generator = await TextGenerationPipeline.getInstance(
      modelKey,
      (info) => {
        // Forward only aggregate progress events; per-file events are ignored.
        if (info.status === "progress_total") {
          self.postMessage({
            status: "progress_total",
            progress: Number(info.progress ?? 0),
            loaded: Number(info.loaded ?? 0),
            total: Number(info.total ?? 0),
          });
        }
      },
    );

    self.postMessage({
      status: "loading",
      data: "Optimizing model for 1-bit execution",
    });

    // Warm-up: generating a single token forces shader compilation now,
    // so the first real request doesn't pay the compile latency.
    const inputs = generator.tokenizer("a");
    await generator.model.generate({ ...inputs, max_new_tokens: 1 });

    self.postMessage({ status: "ready" });
  } catch (e) {
    // Previously a failed load became an unhandled promise rejection and
    // the UI hung on "Loading model..." — report it explicitly instead.
    self.postMessage({ status: "error", data: e.toString() });
  }
}
/**
 * Runs text generation for the chat `messages`, streaming partial output
 * (with tokens/sec and token count) to the main thread, then posts the final
 * assistant message as { status: "complete" }. All failures — including
 * calling generate() before load() — are posted as { status: "error" }.
 * @param {Array<{role: string, content: string}>} messages - Chat history.
 */
async function generate(messages) {
  try {
    // Fix: this lookup was previously OUTSIDE the try/catch, so generating
    // before load() (current_model_key === null) crashed with an unhandled
    // promise rejection instead of reporting an error.
    const generator = await TextGenerationPipeline.getInstance(
      current_model_key,
    );

    let startTime;
    let numTokens = 0;
    let tps;
    const streamer = new TextStreamer(generator.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (output) => {
        self.postMessage({ status: "update", output, tps, numTokens });
      },
      token_callback_function: () => {
        startTime ??= performance.now();
        // Skip the first token so tps reflects steady-state decode speed.
        if (numTokens++ > 0) {
          tps = (numTokens / (performance.now() - startTime)) * 1000;
        }
      },
    });

    self.postMessage({ status: "start" });

    // Reuse the KV cache across turns; created lazily on first generation.
    past_key_values_cache ??= new DynamicCache();

    const output = await generator(messages, {
      max_new_tokens: 1024,
      do_sample: false,
      streamer,
      stopping_criteria,
      past_key_values: past_key_values_cache,
    });

    self.postMessage({
      status: "complete",
      output: output[0].generated_text.at(-1).content,
    });
  } catch (e) {
    self.postMessage({ status: "error", data: e.toString() });
  }
}
// Command router for messages from the main thread. Unknown types are ignored.
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;
  if (type === "check") {
    check();
  } else if (type === "load") {
    load(data);
  } else if (type === "generate") {
    // Clear any previous interrupt before starting a new generation.
    stopping_criteria.reset();
    generate(data);
  } else if (type === "interrupt") {
    stopping_criteria.interrupt();
  } else if (type === "reset") {
    disposePastKeyValues();
    stopping_criteria.reset();
  }
});