Spaces:
Running
Running
File size: 5,461 Bytes
814c07e | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 | import * as ort from 'onnxruntime-web/wasm';
import { Tokenizer } from './tokenizer';
import { runEncoder, stepDecoder, initialPastKv } from './runtime';
import type { NeedleSessions } from './runtime';
/** Knobs for {@link generate}'s encoder/decoder loop. */
export interface GenerateOpts {
  /** Token ID fed to the decoder on its first step. Cactus seeds with EOS (id=1), so pass the EOS ID here. */
  bosOrPrefixTokenId: number;
  /** End-of-sequence token ID (= 1 per tokenizer-specials.json); handed to the stop check. */
  eosTokenId: number;
  /** ID of the `<tools>` separator placed between query and tool tokens (= 5 per tokenizer-specials.json). */
  toolsTokenId: number;
  /** Upper bound on decoder steps; generate() uses 256 when this is omitted. */
  maxNewTokens?: number;
}
/**
 * Format (query, tools) into the encoder input token list, matching Cactus's
 * `_build_encoder_input`:
 *
 *   [query_tokens..., <tools>(id=5), tools_tokens...]
 *
 * `tools` is stringified with `JSON.stringify` before encoding. The result is
 * truncated to at most `maxEncLen` tokens (the Python side uses
 * max_enc_len=1024). The query keeps at most `maxEncLen - 2` tokens, which
 * leaves room for the `<tools>` marker and, when `maxEncLen >= 2`, at least one
 * tool token. The tool tokens then fill whatever budget remains.
 *
 * @param tokenizer - Tokenizer used for both the query and the serialized tools.
 * @param query - User query text.
 * @param tools - Tool definitions; serialized to JSON before encoding.
 * @param toolsTokenId - ID of the `<tools>` separator token.
 * @param maxEncLen - Maximum encoder input length in tokens (default 1024).
 * @returns Token IDs of length <= `maxEncLen`; the `<tools>` marker is always present.
 * @throws RangeError if `maxEncLen` is below 1 (or NaN), since the marker alone would not fit.
 */
export function buildEncoderInput(
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  toolsTokenId: number,
  maxEncLen = 1024,
): number[] {
  // The marker is emitted unconditionally, so a budget under one token cannot
  // be honoured. Writing `!(x >= 1)` also rejects NaN.
  if (!(maxEncLen >= 1)) {
    throw new RangeError(`buildEncoderInput: maxEncLen must be >= 1, got ${maxEncLen}`);
  }
  const qTokens = tokenizer.encode(query);
  const tTokens = tokenizer.encode(JSON.stringify(tools));
  // Reserve one slot for the marker and one for at least one tool token.
  // Clamp at 0: a negative end index would make slice() count from the end
  // and keep tokens instead of dropping them.
  const maxQuery = Math.max(0, maxEncLen - 2);
  const q = qTokens.slice(0, maxQuery);
  const remaining = Math.max(0, maxEncLen - q.length - 1);
  const t = tTokens.slice(0, remaining);
  return [...q, toolsTokenId, ...t];
}
/**
 * Run one encoder pass, then decode autoregressively until the stop check fires.
 *
 * The encoder sees `buildEncoderInput(query, tools)`. The decoder is first fed
 * `opts.bosOrPrefixTokenId`; after that it is fed each step's chosen token,
 * with that step's `presentSelfKv` passed back in as the next step's cache.
 * The loop ends when `shouldStop` returns true (that token is not kept) or
 * after `opts.maxNewTokens` steps (default 256).
 *
 * @param onToken - Optional streaming hook. It is called after each accepted
 *   token with that token and the decode of all accepted tokens so far. The
 *   leading `<tool_call>` marker is NOT stripped from this text.
 * @returns The accepted token IDs and their decoded text, with one leading
 *   `<tool_call>` marker removed if present.
 */
export async function generate(
  sessions: NeedleSessions,
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  opts: GenerateOpts,
  onToken?: (id: number, decodedSoFar: string) => void,
): Promise<{ ids: number[]; text: string }> {
  const toolCallMarker = '<tool_call>';

  // Single encoder pass; its output is handed to every decoder step below.
  const promptIds = buildEncoderInput(tokenizer, query, tools, opts.toolsTokenId);
  const encoderOut = await runEncoder(sessions.encoder, promptIds);

  // Decoder state: the KV cache carried between steps and the token to feed next.
  // Cactus convention: the first token fed is EOS (id=1), supplied via opts.
  let selfKv = initialPastKv();
  let inputId = opts.bosOrPrefixTokenId;

  const ids: number[] = [];
  const budget = opts.maxNewTokens ?? 256;

  let step = 0;
  while (step < budget) {
    step += 1;
    const out = await stepDecoder(sessions.decoder, inputId, encoderOut, selfKv);
    selfKv = out.presentSelfKv;
    inputId = sampleNextToken(out.logits);
    // The candidate is checked before it is recorded, so a stopping token never lands in `ids`.
    if (shouldStop(inputId, ids, opts.eosTokenId, tokenizer)) break;
    ids.push(inputId);
    if (onToken != null) onToken(inputId, tokenizer.decode(ids));
  }

  // Cactus's generate() drops a leading <tool_call> marker; do the same.
  const decoded = tokenizer.decode(ids);
  const text = decoded.startsWith(toolCallMarker)
    ? decoded.slice(toolCallMarker.length)
    : decoded;
  return { ids, text };
}
// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #1 — sampleNextToken
//
// Pick the next token from a (1, 1, vocab_size) float32 logits tensor.
//
// Choices:
// (a) argmax — deterministic, repeatable. Function calling has a narrow
// correct answer; argmax is what Cactus's native generate() uses.
// (b) temperature sampling — softmax(logits / T), then sample. T<1 = sharper,
// T>1 = more varied. Non-deterministic without a seed.
// (c) top-p (nucleus) — softmax, sort, keep tokens until cumulative ≥ p,
// sample. Most "natural" sampling but adds two hyperparams.
//
// Default (if no preference given): (a) argmax.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Greedy (argmax) token choice, i.e. option (a) from the menu above.
 *
 * Returns the index of the first largest entry in the flat logits buffer.
 * NaN entries never compare greater, so they are skipped. If no entry beats
 * -Infinity (empty buffer, or all -Infinity), index 0 is returned.
 *
 * USER: replace this body with (b) or (c) for stochastic decoding.
 */
function sampleNextToken(logits: ort.Tensor): number {
  // Shape (1, 1, vocab_size) → a flat buffer of vocab_size scores.
  const scores = logits.data as Float32Array;
  let winner = 0;
  let winningScore = -Infinity;
  for (let idx = 0, n = scores.length; idx < n; ++idx) {
    const score = scores[idx];
    if (score > winningScore) {
      winningScore = score;
      winner = idx;
    }
  }
  return winner;
}
// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #2 — shouldStop
//
// Decide whether to halt generation after sampling `nextId`, given the tokens
// already accepted in `soFar`.
// The decode loop's `maxNewTokens` cap also bounds the loop independently.
//
// Choices:
// (a) EOS-only — matches Cactus's native generate() exactly. Simplest, safest.
// (b) EOS OR balanced-JSON — if the decoded text since <tool_call> is a
// valid parseable JSON array (e.g. ']' at top level with brace balance
// at zero), stop. Crisper exit when the model trails into padding.
// (c) EOS OR balanced-JSON OR token == ']' — same as (b), but the ']' test is a
// cheap integer comparison (the tokenizer's ']' token ID is fixed) that can run before the JSON check.
//
// Default: (a) EOS-only.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Stop criterion for the decode loop: option (a), EOS-only, from the menu above.
 *
 * @param nextId - Token just sampled; generate() has not yet appended it to `soFar`.
 * @param soFar - Tokens accepted so far (unused by this criterion).
 * @param eosId - End-of-sequence token ID.
 * @param tokenizer - Available for text-based criteria (unused here).
 * @returns `true` exactly when `nextId` equals `eosId`.
 */
function shouldStop(
  nextId: number,
  soFar: number[],
  eosId: number,
  tokenizer: Tokenizer,
): boolean {
  // USER: options (b)/(c) would read `soFar`/`tokenizer`. Until then, these
  // `void`s only keep unused-parameter lint quiet.
  void tokenizer;
  void soFar;
  const hitEos = eosId === nextId;
  return hitEos;
}
|