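// Web Worker: runs in-browser text generation with Transformers.js on
// WebGPU, off the main thread, communicating with the UI via postMessage.
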
import {
  pipeline,
  TextStreamer,
  DynamicCache,
  InterruptableStoppingCriteria,
} from "@huggingface/transformers";
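
// Map short UI-facing keys to their Hugging Face model repositories.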
const MODEL_IDS = {
  "1.7b": "onnx-community/Bonsai-1.7B-ONNX",
};
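
// Verify WebGPU support by requesting an adapter; on failure, report the
// error to the main thread.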
async function check() {
  try {
    const adapter = await navigator.gpu?.requestAdapter();
    if (!adapter) throw new Error("WebGPU is not supported (no adapter found)");
  } catch (e) {
    self.postMessage({ status: "error", data: e.toString() });
  }
}
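
// Lazily construct one text-generation pipeline per model key and cache the
// (promised) instance. Note: `dtype: "q1"` is assumed here to select the
// 1-bit quantized weight files published with this ONNX export.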
class TextGenerationPipeline {
  static instances = new Map();

  static getInstance(modelKey, progress_callback = null) {
    const modelId = MODEL_IDS[modelKey];
    if (!modelId) throw new Error(`Unknown model: ${modelKey}`);
    if (!this.instances.has(modelKey)) {
      this.instances.set(
        modelKey,
        pipeline("text-generation", modelId, {
          device: "webgpu",
          dtype: "q1",
          progress_callback,
        }),
      );
    }
    return this.instances.get(modelKey);
  }
}
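
// Generation state shared across messages: a stopping criteria the UI can
// interrupt, the KV cache carried between chat turns, and the currently
// loaded model key.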
const stopping_criteria = new InterruptableStoppingCriteria();

let past_key_values_cache = null;
let current_model_key = null;
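
// Free any buffers held by the cache (when the cache type exposes dispose)
// and drop the reference so the next generation starts fresh.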
function disposePastKeyValues() {
  past_key_values_cache?.dispose?.();
  past_key_values_cache = null;
}
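
// Load (or retrieve) the pipeline for `modelKey`. A stale KV cache from a
// different model is disposed first, download progress is forwarded to the
// UI, and a one-token warm-up run finishes the job.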
async function load(modelKey) {
  if (current_model_key && current_model_key !== modelKey) {
    disposePastKeyValues();
  }
  current_model_key = modelKey;

  self.postMessage({ status: "loading", data: "Loading model..." });

  const generator = await TextGenerationPipeline.getInstance(
    modelKey,
    (info) => {
      if (info.status === "progress_total") {
        self.postMessage({
          status: "progress_total",
          progress: Number(info.progress ?? 0),
          loaded: Number(info.loaded ?? 0),
          total: Number(info.total ?? 0),
        });
      }
    },
  );
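
  // Warm-up: run a single-token generation on a dummy input so one-time
  // initialization (e.g. WebGPU shader/kernel compilation) happens during
  // loading rather than on the user's first request.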
  self.postMessage({
    status: "loading",
    data: "Optimizing model for 1-bit execution",
  });

  const inputs = generator.tokenizer("a");
  await generator.model.generate({ ...inputs, max_new_tokens: 1 });

  self.postMessage({ status: "ready" });
}
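
// Run chat generation for `messages`, streaming partial output and a
// tokens-per-second estimate to the UI as tokens arrive.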
async function generate(messages) {
  const generator = await TextGenerationPipeline.getInstance(current_model_key);

  let startTime;
  let numTokens = 0;
  let tps;
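
  // Stream decoded text back to the UI as it is generated. The token
  // callback starts the clock on the first token and updates the
  // tokens-per-second estimate on each one after it.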
  const streamer = new TextStreamer(generator.tokenizer, {
    skip_prompt: true,
    skip_special_tokens: true,
    callback_function: (output) => {
      self.postMessage({ status: "update", output, tps, numTokens });
    },
    token_callback_function: () => {
      startTime ??= performance.now();
      if (numTokens++ > 0) {
        tps = (numTokens / (performance.now() - startTime)) * 1000;
      }
    },
  });

  self.postMessage({ status: "start" });
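
  // Create the KV cache on first use and reuse it across turns, so earlier
  // context does not have to be re-encoded; the UI's "reset" message
  // disposes of it when the conversation is cleared.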
  past_key_values_cache ??= new DynamicCache();

  try {
    const output = await generator(messages, {
      max_new_tokens: 1024,
      do_sample: false,
      streamer,
      stopping_criteria,
      past_key_values: past_key_values_cache,
    });

    self.postMessage({
      status: "complete",
      output: output[0].generated_text.at(-1).content,
    });
  } catch (e) {
    self.postMessage({ status: "error", data: e.toString() });
  }
}
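
// Message protocol with the main thread:
//   check     — verify WebGPU support
//   load      — load and warm up the model (data = model key)
//   generate  — run generation (data = chat messages)
//   interrupt — stop an in-progress generation
//   reset     — clear the KV cache and stopping criteria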
self.addEventListener("message", async (e) => {
  const { type, data } = e.data;
  switch (type) {
    case "check":
      check();
      break;
    case "load":
      load(data);
      break;
    case "generate":
      stopping_criteria.reset();
      generate(data);
      break;
    case "interrupt":
      stopping_criteria.interrupt();
      break;
    case "reset":
      disposePastKeyValues();
      stopping_criteria.reset();
      break;
  }
});