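/**
 * Conversation template helper.
 *
 * The chat history lives in `tvmjsGlobalEnv.workerHistoryMsg`, an array of
 * [role, message] pairs maintained by the host page; this class renders it
 * into prompt strings using the configured role names and separators.
 */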
class Conversation {
  constructor(config) {
    this.system = config.system;
    this.roles = config.roles;
    this.offset = config.offset;
    this.seps = config.seps;
    this.convId = null;
    this.contextWindowStart = 0;
  }
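  /**
   * Render the full prompt: the system prompt followed by every
   * [role, message] pair in the history. A round whose message is still
   * empty (the pending assistant reply) renders as "role:" with no
   * separator. With the default template below this yields, e.g.:
   *   ["<system> ", "user: Hello ", "assistant:"]
   *
   * @returns The array of prompt pieces, system prompt first.
   */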
  getPromptArray() {
    if (this.seps.length === 0) {
      throw Error("Conversation seps must be non-empty");
    }
    let ret = [this.system + this.seps[0]];
    for (let i = 0; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }
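  /**
   * Render only the latest round (the last user message and the pending
   * assistant slot), prefixed with the final separator that closed the
   * previous round. Used for follow-up messages, where the earlier context
   * is already held in the KV cache and must not be re-encoded.
   *
   * @returns The array of prompt pieces for the latest round only.
   */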
  getPromptArrayUnprocessed() {
    if (this.seps.length === 0) {
      throw Error("Conversation seps must be non-empty");
    }
    if (tvmjsGlobalEnv.workerHistoryMsg.length < 3) {
      throw Error("getLastPromptArray should be called for the first message");
    }
    let ret = [this.seps[this.seps.length - 1]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }
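  /**
   * Render the system prompt followed by only the latest round (the last
   * user message and the pending assistant slot).
   *
   * @returns The array of prompt pieces.
   */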
  getLastPromptArray() {
    if (this.seps.length === 0) {
      throw Error("Conversation seps must be non-empty");
    }
    let ret = [this.system + this.seps[0]];
    for (let i = tvmjsGlobalEnv.workerHistoryMsg.length - 2; i < tvmjsGlobalEnv.workerHistoryMsg.length; ++i) {
      const item = tvmjsGlobalEnv.workerHistoryMsg[i];
      const role = item[0];
      const message = item[1];
      if (message !== undefined && message != "") {
        ret.push(role + ": " + message + this.seps[i % this.seps.length]);
      } else {
        ret.push(role + ":");
      }
    }
    return ret;
  }
  // Clear the shared history and forget the current conversation id.
  reset() {
    tvmjsGlobalEnv.workerHistoryMsg = [];
    this.convId = null;
  }

  // The stop string is the final separator (e.g. "</s>").
  getStopStr() {
    return this.seps[this.seps.length - 1];
  }

  appendMessage(role, message) {
    tvmjsGlobalEnv.workerHistoryMsg.push([role, message]);
  }

  // Point this template at a different conversation's history. Note that
  // the host-side global is named "covId".
  switchConversation(message) {
    tvmjsGlobalEnv.workerHistoryMsg = message;
    this.convId = tvmjsGlobalEnv.covId;
  }
}
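/**
 * Build the default conversation template. With seps [" ", "</s>"], user
 * turns end with a space and assistant turns end with "</s>", which also
 * serves as the stop string. Concatenated, the pieces read like:
 *
 *   "<system> user: Hello assistant: Hi!</s>user: ..."
 */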
function defaultConversation(maxWindowLength = 2048) {
  return new Conversation({
    system: "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. Follow the user's instructions carefully. Respond using markdown.",
    roles: ["user", "assistant"],
    maxWindowLength: maxWindowLength,
    offset: 0,
    seps: [" ", "</s>"],
  });
}
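/**
 * The core chat pipeline: owns the TVM virtual machine, the tokenizer,
 * the KV cache, and the encoding/decoding functions compiled into the
 * model, and drives token-by-token generation on WebGPU.
 */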
class LLMChatPipeline {
  constructor(tvm, tokenizer, cacheMetadata, config) {
    if (cacheMetadata == undefined) {
      throw Error("Expect cacheMetadata");
    }
    this.tvm = tvm;
    this.logger = console.log;
    this.tokenizer = tokenizer;
    this.bosTokenId = 1;
    this.eosTokenId = 2;

    this.maxWindowLength = config.maxWindowLength;
    this.maxGenLength = config.maxGenLength;
    this.meanGenLength = config.meanGenLength;
    this.streamInterval = 1;

    this.decodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTime = 0;
    this.encodingTotalTokens = 0;

    this.conversation = defaultConversation(this.maxWindowLength);

    this.device = this.tvm.webgpu();
    this.vm = this.tvm.detachFromCurrentScope(
      this.tvm.createVirtualMachine(this.device)
    );
    this.encoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("encoding")
    );
    this.decoding = this.tvm.detachFromCurrentScope(
      this.vm.getFunction("decoding")
    );
    this.params = this.tvm.detachFromCurrentScope(
      this.tvm.getParamsFromCache("param", cacheMetadata.ParamSize)
    );
    const fcreateCache = this.vm.getFunction("create_kv_cache");
    this.fclearKVCaches = this.tvm.detachFromCurrentScope(
      this.tvm.getGlobalFunc("vm.builtin.attention_kv_cache_array_clear")
    );

    // The KV cache is created by a function compiled into the model.
    this.kvCache = this.tvm.detachFromCurrentScope(fcreateCache());
    // CPU copy of the logits, lazily allocated on first use.
    this.logitsOnCPU = undefined;

    this.kvCacheLength = 0;
    this.clearCache = true;
  }
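  /**
   * Release all GPU and runtime resources owned by the pipeline.
   */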
  dispose() {
    // Note: the tvm runtime instance itself is not owned by this class.
    this.params.dispose();
    this.decoding.dispose();
    this.encoding.dispose();
    this.vm.dispose();
    this.kvCache.dispose();
    this.fclearKVCaches.dispose();
    if (this.logitsOnCPU != undefined) {
      this.logitsOnCPU.dispose();
    }
  }

  // Clear every layer's KV cache and reset the cached-context length.
  #clearKVCache() {
    this.fclearKVCaches(this.kvCache);
    this.kvCacheLength = 0;
  }
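  /**
   * Run one forward pass. A multi-token input (shape [1, n]) is routed to
   * the "encoding" (prefill) function; a single token goes through
   * "decoding". `curPos` is the total sequence length after this step.
   *
   * @returns The logits NDArray on the GPU, attached to the caller's scope.
   */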
  #forward(inputs, curPos) {
    this.tvm.beginScope();
    let retValue;
    const seqLenShape = this.tvm.makeShapeTuple([curPos]);
    if (inputs.shape[1] > 1) {
      retValue = this.encoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    } else {
      retValue = this.decoding(
        inputs, seqLenShape, this.kvCache, this.params
      );
    }
    // Detach the logits so they survive endScope, then hand ownership
    // back to the caller's scope.
    const logits = this.tvm.detachFromCurrentScope(retValue.get(0));
    this.tvm.endScope();
    this.tvm.attachToCurrentScope(logits);
    return logits;
  }
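  // Copy the logits from the GPU into a lazily allocated CPU buffer so
  // they can be sampled from. Note: this is called inside a TVM scope,
  // since the first call allocates a new NDArray.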
  #updateLogitsOnCPU(logits) {
    if (this.logitsOnCPU == undefined) {
      this.logitsOnCPU = this.tvm.detachFromCurrentScope(
        this.tvm.empty(logits.shape, logits.dtype, this.tvm.cpu())
      );
    } else {
      if (logits.shape[0] != this.logitsOnCPU.shape[0]) {
        throw Error("We expect the size of logits to remain unchanged");
      }
    }
    this.logitsOnCPU.copyFrom(logits);
  }
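  /**
   * Sample the next token id from the logits with nucleus (top-p)
   * sampling at the given temperature, after syncing the GPU-to-CPU copy.
   */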
  async sampleTokenFromLogits(logits, temperature = 0.8, top_p = 0.95) {
    this.tvm.beginScope();
    this.#updateLogitsOnCPU(logits);
    this.tvm.endScope();
    await this.device.sync();
    return this.tvm.sampleTopPFromLogits(this.logitsOnCPU, temperature, top_p);
  }
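  /**
   * Tokenize the prompt for the next forward pass.
   *
   * For the first exchange the full prompt (system + history) is encoded;
   * afterwards only the latest round is encoded, since the rest already
   * sits in the KV cache. If cached context + new tokens + meanGenLength
   * would overflow maxWindowLength, the KV cache is scheduled for
   * clearing and the prompt is rebuilt from the system prompt plus as
   * many of the most recent rounds as fit in roughly 10% of the window.
   */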
  async getInputTokens() {
    let tokens = [this.bosTokenId];
    let prompts;
    if (tvmjsGlobalEnv.workerHistoryMsg.length <= 2) {
      prompts = this.conversation.getPromptArray();
    } else {
      tokens.pop();
      prompts = this.conversation.getPromptArrayUnprocessed();
    }
    tokens.push(...await this.tokenizer.encodeIds(prompts[0]));
    let ctxLength = tokens.length;
    let context = [];
    let need_shift_window = false;
    // Walk the rounds from newest to oldest until the window would overflow.
    for (let i = prompts.length - 1; i > 0; --i) {
      const encoded = await this.tokenizer.encodeIds(prompts[i]);
      ctxLength += encoded.length;
      if (this.kvCacheLength + ctxLength + this.meanGenLength >= this.maxWindowLength) {
        need_shift_window = true;
        break;
      }
      context.unshift(encoded);
    }
    if (!need_shift_window) {
      for (const ctx of context) {
        tokens.push(...ctx);
      }
      return tokens;
    }
    // Over the window: drop the cached context and re-encode from scratch.
    this.logger("need shift window");
    this.kvCacheLength = 0;
    this.clearCache = true;

    tokens = [this.bosTokenId];
    let all_prompts = this.conversation.getPromptArray();
    tokens.push(...await this.tokenizer.encodeIds(all_prompts[0]));
    context = [];
    ctxLength = tokens.length;
    // Keep the most recent rounds, up to about fill_factor of the window.
    const fill_factor = 0.1;
    for (let i = all_prompts.length - 1; i > 0; --i) {
      const encoded = await this.tokenizer.encodeIds(all_prompts[i]);
      ctxLength += encoded.length;
      if (ctxLength >= fill_factor * this.maxWindowLength && i + 2 < all_prompts.length) {
        break;
      }
      context.unshift(encoded);
    }
    for (const ctx of context) {
      tokens.push(...ctx);
    }
    if (tokens.length + this.meanGenLength >= this.maxWindowLength) {
      throw Error("Exceeds max window length, curr=" + tokens.length);
    }
    return tokens;
  }

  // Reset the conversation, the KV cache, and the perf counters.
  resetChat() {
    if (this.conversation) {
      this.conversation.reset();
    }
    this.#clearKVCache();
    this.decodingTotalTime = 0;
    this.encodingTotalTime = 0;
    this.decodingTotalTokens = 0;
    this.encodingTotalTokens = 0;
  }
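  /**
   * Generate a reply to `inputPrompt`, streaming partial output through
   * `callbackUpdateResponse(step, text)` every `streamInterval` steps.
   * Generation stops at the EOS token, at the conversation stop string,
   * or when the context window is exhausted.
   *
   * @returns The final output text, which is also written back into the
   * last history entry.
   */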
  async generate(inputPrompt, callbackUpdateResponse) {
    // If the host page switched to a different conversation, adopt its
    // history. (The original check had an empty body; re-binding the
    // shared history and recording the id is the assumed intent.)
    if (this.conversation.convId !== tvmjsGlobalEnv.covId) {
      this.conversation.switchConversation(tvmjsGlobalEnv.workerHistoryMsg);
    }
    this.conversation.appendMessage(this.conversation.roles[0], inputPrompt);
    this.conversation.appendMessage(this.conversation.roles[1], "");
    const stopStr = this.conversation.getStopStr();
    const tokens = await this.getInputTokens();
    const inputTokenLength = tokens.length;

    let outputPrompt = "";
    if (this.clearCache) {
      this.#clearKVCache();
      this.clearCache = false;
    }
    const maxGenLen = Math.min(this.maxGenLength, this.maxWindowLength - tokens.length);
    if (maxGenLen < this.meanGenLength) {
      throw Error("The window size config is too small");
    }
    let step = 0;
    for (; step < maxGenLen && this.kvCacheLength + inputTokenLength + step < this.maxWindowLength; ++step) {
      this.tvm.beginScope();
      let inputData;

      let tstart = performance.now();
      if (step == 0) {
        // Prefill: feed the whole prompt at once.
        inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
        inputData.copyFrom(tokens);
      } else {
        // Decode: feed only the token sampled in the previous step.
        inputData = this.tvm.empty([1, 1], "int32", this.device);
        inputData.copyFrom(tokens.slice(tokens.length - 1));
      }
      const logits = this.tvm.detachFromCurrentScope(
        this.#forward(inputData, this.kvCacheLength + inputTokenLength + step)
      );
      this.tvm.endScope();

      const nextToken = await this.sampleTokenFromLogits(logits);
      logits.dispose();

      tokens.push(nextToken);
      const outputTokens = tokens.slice(inputTokenLength);
      outputPrompt = this.tokenizer.decodeIds(outputTokens);

      if (nextToken == this.eosTokenId) break;

      const stopPos = outputPrompt.lastIndexOf(stopStr);
      if (stopPos != -1) {
        outputPrompt = outputPrompt.substring(0, stopPos);
        break;
      }
      let tend = performance.now();
      if (step != 0) {
        this.decodingTotalTokens += 1;
        this.decodingTotalTime += (tend - tstart) / 1000;
      } else {
        this.encodingTotalTime += (tend - tstart) / 1000;
        this.encodingTotalTokens += inputTokenLength;
      }

      if (step % this.streamInterval == 0) {
        callbackUpdateResponse(step, outputPrompt);
      }
    }
    this.kvCacheLength += tokens.length - 1;
    tvmjsGlobalEnv.workerHistoryMsg[tvmjsGlobalEnv.workerHistoryMsg.length - 1][1] = outputPrompt;
    return outputPrompt;
  }
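  /**
   * Debug helper: run one prefill and one decode step on a fixed prompt
   * and log the logits plus rough timings. Not used in the normal chat flow.
   */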
  async evaluate() {
    this.#clearKVCache();
    const testPrompt = "The capital of Canada is";
    const ids = await this.tokenizer.encodeIds(testPrompt);
    const inputPromptSize = ids.length;
    const tokens = Array.from(ids);
    tokens.unshift(this.bosTokenId);
    if (tokens.length == 0) {
      throw Error("empty token");
    }

    this.tvm.beginScope();
    const inputData = this.tvm.empty([1, tokens.length], "int32", this.device);
    inputData.copyFrom(tokens);
    const encodingStart = performance.now();
    this.#forward(inputData, tokens.length);
    this.tvm.endScope();
    await this.device.sync();

    const decodingStart = performance.now();

    this.tvm.beginScope();
    const firstSampleToken = this.tvm.empty([1, 1], "int32", this.device).copyFrom([6234]);
    this.#updateLogitsOnCPU(this.#forward(firstSampleToken, tokens.length + 1));
    await this.device.sync();
    this.tvm.endScope();

    const decodingEnd = performance.now();
    const msg = (
      `encoding-time=${((decodingStart - encodingStart) / 1000).toFixed(4)} sec, ` +
      `decoding-time=${((decodingEnd - decodingStart) / 1000).toFixed(4)} sec`
    );

    console.log("Logits:");
    console.log(this.logitsOnCPU.toArray());
    console.log(msg);
  }
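  /**
   * Compile the WebGPU compute pipelines for the loaded module. The
   * spelling ("Piplines") is kept as-is to match the tvmjs runtime
   * method this wraps.
   */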
  async asyncLoadWebGPUPiplines() {
    await this.tvm.asyncLoadWebGPUPiplines(this.vm.getInternalModule());
  }

  // Human-readable prefill/decode throughput counters.
  runtimeStatsText() {
    return (
      `encoding: ${(this.encodingTotalTokens / this.encodingTotalTime).toFixed(4)} tokens/sec, ` +
      `decoding: ${(this.decodingTotalTokens / this.decodingTotalTime).toFixed(4)} tokens/sec`
    );
  }
}
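/**
 * Worker-side chat driver: owns the tvmjs runtime and the
 * LLMChatPipeline, performs lazy initialization, and reports progress
 * and results to the host page via postMessage.
 */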
class LLMChatInstance {
  constructor() {
    this.requestInProgress = false;
    this.config = undefined;
    this.tvm = undefined;
    this.pipeline = undefined;
    this.logger = console.log;
    this.debugTest = false;
  }
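  /**
   * Initialize the tvmjs runtime: fetch and instantiate the model wasm,
   * detect and bind a WebGPU device, and prefetch the weight shards into
   * the NDArray cache.
   *
   * @param wasmUrl URL of the model library wasm.
   * @param cacheUrl URL prefix of the parameter shards.
   */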
  async #asyncInitTVM(wasmUrl, cacheUrl) {
    if (this.tvm !== undefined) {
      return;
    }
    this.logger = console.log;

    const wasmSource = await (
      await fetch(wasmUrl)
    ).arrayBuffer();
    const tvm = await tvmjs.instantiate(
      new Uint8Array(wasmSource),
      new EmccWASI(),
      this.logger
    );
    // Detect an available GPU device and bind it to the runtime.
    try {
      const output = await tvmjs.detectGPUDevice();
      if (output !== undefined) {
        let label = "WebGPU";
        if (output.adapterInfo.description.length != 0) {
          label += " - " + output.adapterInfo.description;
        } else {
          label += " - " + output.adapterInfo.vendor;
        }
        this.appendMessage("init", "Initialize GPU device: " + label);
        tvm.initWebGPU(output.device);
      } else {
        this.appendMessage("error", "This browser environment does not support WebGPU");
        this.reset();
        throw Error("This browser environment does not support WebGPU");
      }
    } catch (err) {
      this.appendMessage("error", "Error initializing the WebGPU device: " + err.toString());
      console.log(err);
      this.reset();
      throw Error("Error initializing WebGPU: " + err.toString());
    }
    this.tvm = tvm;
    const initProgressCallback = (report) => {
      this.updateLastMessage("initing", report.text);
    }
    tvm.registerInitProgressCallback(initProgressCallback);

    await tvm.fetchNDArrayCache(cacheUrl, tvm.webgpu());
  }
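  /**
   * Lazily initialize everything: config, tvm runtime, then the pipeline.
   */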
  async asyncInit() {
    if (this.pipeline !== undefined) return;
    await this.#asyncInitConfig();
    await this.#asyncInitTVM(this.config.wasmUrl, this.config.cacheUrl);
    await this.#asyncInitPipeline();
  }
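  /**
   * Fetch the model configuration (wasm/cache/tokenizer URLs and window
   * parameters) from the server.
   */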
  async #asyncInitConfig() {
    if (this.config !== undefined) return;
    this.config = await (await fetch("/lib/WebLLM/config.json")).json();
  }
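  /**
   * Load the tokenizer, construct the LLMChatPipeline inside a fresh TVM
   * scope, and compile its WebGPU pipelines.
   */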
  async #asyncInitPipeline() {
    if (this.pipeline !== undefined) return;
    // The tokenizer loader is provided by the host environment.
    const tokenizer = await tvmjsGlobalEnv.sentencePieceProcessor(this.config.tokenizer);
    this.pipeline = this.tvm.withNewScope(() => {
      return new LLMChatPipeline(this.tvm, tokenizer, this.tvm.cacheMetadata, this.config);
    });
    await this.pipeline.asyncLoadWebGPUPiplines();
    this.appendMessage("initing", "All initialization finished.", true);
  }

  // Post a new status/error message to the host page.
  appendMessage(kind, text, ifFinish) {
    if (kind == "initing") {
      text = "[System Initialize] " + text;
    }
    console.log(`[${kind}] ${text}`);
    globalThis.postMessage({
      type: 'initing',
      action: 'append',
      msg: text,
      ifError: kind == 'error',
      ifFinish: !!ifFinish
    });
  }

  // Replace the most recent message of the given type on the host page.
  updateLastMessage(type, text, ifFinish) {
    if (type == "initing") {
      text = `[System Initialize] ${text}`;
    }
    globalThis.postMessage({
      type,
      action: 'updateLast',
      msg: text,
      ifFinish: !!ifFinish
    });
  }

  // Debug helper: stream a canned reply token by token.
  async respondTestMessage(repeat) {
    const testMessage = "I am a friendly bot. Please ask questions.";
    const encodedResult = await this.pipeline.tokenizer.encodeIds(testMessage);

    const currentIds = [];
    for (let k = 0; k < repeat; ++k) {
      for (let i = 0; i < encodedResult.length; ++i) {
        currentIds.push(encodedResult[i]);
        const msg = this.pipeline.tokenizer.decodeIds(currentIds);
        this.updateLastMessage("chatting", msg);
        await new Promise(resolve => setTimeout(resolve, 50));
      }
    }
  }

  resetChat() {
    if (this.pipeline) {
      this.pipeline.resetChat();
    }
  }
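  /**
   * Run generation for the prompt currently in tvmjsGlobalEnv.message,
   * initializing on first use, streaming partial output to the host page,
   * and posting throughput stats when done. Re-entrant calls are ignored
   * while a request is in flight.
   */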
  async generate() {
    if (this.requestInProgress) {
      return;
    }

    this.requestInProgress = true;

    try {
      await this.asyncInit();
    } catch (err) {
      this.appendMessage("error", "Init error, " + err.toString());
      console.log(err);
      this.reset();
      this.requestInProgress = false;
      return;
    }

    if (this.debugTest) {
      await this.pipeline.evaluate();
      this.requestInProgress = false;
      return;
    }

    const prompt = tvmjsGlobalEnv.message;
    if (prompt == "") {
      this.requestInProgress = false;
      return;
    }

    const callbackUpdateResponse = (step, msg) => {
      // Trim trailing '#' characters from the streamed partial text.
      if (msg.endsWith("##")) {
        msg = msg.substring(0, msg.length - 2);
      } else if (msg.endsWith("#")) {
        msg = msg.substring(0, msg.length - 1);
      }
      this.updateLastMessage("chatting", msg);
    };
    try {
      const output = await this.pipeline.generate(prompt, callbackUpdateResponse);
      this.updateLastMessage("chatting", output, true);
      this.updateLastMessage("stats", this.pipeline.runtimeStatsText());
      console.log(this.pipeline.runtimeStatsText());
    } catch (err) {
      this.appendMessage("error", "Generate error, " + err.toString());
      console.log(err);
      this.reset();
    }
    this.requestInProgress = false;
  }
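  /**
   * Drop the runtime and dispose the pipeline so the next request
   * re-initializes from scratch.
   */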
  reset() {
    this.tvm = undefined;
    if (this.pipeline !== undefined) {
      this.pipeline.dispose();
    }
    this.pipeline = undefined;
  }
}

const localLLMChatInstance = new LLMChatInstance();

tvmjsGlobalEnv.asyncOnGenerate = async function () {
  await localLLMChatInstance.generate();
};

tvmjsGlobalEnv.asyncOnReset = async function () {
  await localLLMChatInstance.resetChat();
};
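// A minimal usage sketch (assumed host-side contract, not part of this
// worker): the host page is expected to populate tvmjsGlobalEnv with the
// message, the shared history, and a conversation id before invoking the
// hooks registered above, e.g.:
//
//   tvmjsGlobalEnv.workerHistoryMsg = [];
//   tvmjsGlobalEnv.covId = 1;
//   tvmjsGlobalEnv.message = "Hello!";
//   await tvmjsGlobalEnv.asyncOnGenerate();
//
// The exact host wiring lives outside this file; the snippet above only
// illustrates the interface implied by generate() and switchConversation().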