File size: 5,461 Bytes
814c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import * as ort from 'onnxruntime-web/wasm';
import { Tokenizer } from './tokenizer';
import { runEncoder, stepDecoder, initialPastKv } from './runtime';
import type { NeedleSessions } from './runtime';

/** Options controlling {@link generate}. Special-token ids come from tokenizer-specials.json. */
export interface GenerateOpts {
  /** End-of-sequence token id (= 1 per tokenizer-specials.json); producing it ends decoding. */
  eosTokenId: number;
  /**
   * First token fed to the decoder. Cactus uses EOS (id=1) as the decoder
   * seed, so pass the EOS id here.
   */
  bosOrPrefixTokenId: number;
  /** Id of the `<tools>` separator token (= 5 per tokenizer-specials.json). */
  toolsTokenId: number;
  /** Upper bound on newly generated tokens; 256 when omitted. */
  maxNewTokens?: number;
}

/**
 * Format (query, tools) into the encoder input token list, mirroring the
 * layout of Cactus's `_build_encoder_input`:
 *
 *     [query_tokens..., <tools>(toolsTokenId), tools_tokens...]
 *
 * The result never exceeds `maxEncLen` tokens (the Python side truncates to
 * max_enc_len=1024; that is the default here). The query keeps at most
 * `maxEncLen - 2` tokens, so the `<tools>` marker and at least one tools token
 * still fit. The tools JSON then fills whatever budget remains. With
 * `maxEncLen === 1` only the `<tools>` marker is emitted.
 *
 * NOTE(review): tools are serialized with `JSON.stringify`, which emits no
 * whitespace after `,` / `:`. If the Python side uses `json.dumps` with its
 * default separators (`', '`, `': '`), the two token streams will differ.
 * Confirm against the reference implementation.
 *
 * @param tokenizer    Tokenizer used for both the query and the tools JSON.
 * @param query        Natural-language user request.
 * @param tools        Tool definitions; serialized with `JSON.stringify`.
 * @param toolsTokenId Id of the `<tools>` separator token.
 * @param maxEncLen    Maximum total length of the returned list (default 1024).
 * @returns Encoder input ids, at most `maxEncLen` long.
 * @throws RangeError if `maxEncLen` is below 1 or NaN, since the `<tools>`
 *   marker alone needs one slot.
 */
export function buildEncoderInput(
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  toolsTokenId: number,
  maxEncLen = 1024,
): number[] {
  if (!(maxEncLen >= 1)) {
    throw new RangeError(`maxEncLen must be >= 1, got ${maxEncLen}`);
  }
  const qTokens = tokenizer.encode(query);
  const tTokens = tokenizer.encode(JSON.stringify(tools));
  // Clamp both bounds at 0. A negative end index passed to slice() counts
  // from the END of the array and would silently overshoot the budget.
  const maxQuery = Math.max(0, maxEncLen - 2);
  const q = qTokens.slice(0, maxQuery);
  const remaining = Math.max(0, maxEncLen - q.length - 1);
  const t = tTokens.slice(0, remaining);
  return [...q, toolsTokenId, ...t];
}

/**
 * Encode (query, tools) once, then greedily decode one token per decoder step.
 *
 * The decoder is seeded with `opts.bosOrPrefixTokenId`, and its self-attention
 * KV cache is carried from one step to the next. Decoding ends when
 * `shouldStop` reports true (that token is not kept) or after
 * `opts.maxNewTokens` steps (256 by default).
 *
 * @param sessions  Encoder/decoder sessions to run.
 * @param tokenizer Tokenizer for the prompt and for decoding the output.
 * @param query     Natural-language user request.
 * @param tools     Tool definitions forwarded to `buildEncoderInput`.
 * @param opts      Special-token ids and the generation budget.
 * @param onToken   Optional streaming callback. It is invoked after each kept
 *                  token with that token's id and the text decoded so far.
 * @returns The kept token ids and their decoded text, with a single leading
 *          `<tool_call>` marker removed if present.
 */
export async function generate(
  sessions: NeedleSessions,
  tokenizer: Tokenizer,
  query: string,
  tools: unknown[],
  opts: GenerateOpts,
  onToken?: (id: number, decodedSoFar: string) => void,
): Promise<{ ids: number[]; text: string }> {
  const encoderOut = await runEncoder(
    sessions.encoder,
    buildEncoderInput(tokenizer, query, tools, opts.toolsTokenId),
  );

  let kvCache = initialPastKv();
  // Cactus convention: the decoder's first input is EOS (id=1).
  let token = opts.bosOrPrefixTokenId;
  const ids: number[] = [];
  const budget = opts.maxNewTokens ?? 256;

  let step = 0;
  while (step < budget) {
    step += 1;
    const { logits: stepLogits, presentSelfKv: nextKv } = await stepDecoder(
      sessions.decoder,
      token,
      encoderOut,
      kvCache,
    );
    kvCache = nextKv;

    token = sampleNextToken(stepLogits);
    if (shouldStop(token, ids, opts.eosTokenId, tokenizer)) {
      break;
    }

    ids.push(token);
    if (onToken != null) {
      onToken(token, tokenizer.decode(ids));
    }
  }

  // Cactus's generate() also drops this leading marker.
  const toolCallPrefix = '<tool_call>';
  const decoded = tokenizer.decode(ids);
  const text = decoded.startsWith(toolCallPrefix)
    ? decoded.slice(toolCallPrefix.length)
    : decoded;
  return { ids, text };
}

// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #1 — sampleNextToken
//
// Pick the next token from a (1, 1, vocab_size) float32 logits tensor.
//
// Choices:
//   (a) argmax — deterministic, repeatable. Function calling has a narrow
//       correct answer; argmax is what Cactus's native generate() uses.
//   (b) temperature sampling — softmax(logits / T), then sample. T<1 = sharper,
//       T>1 = more varied. Non-deterministic without a seed.
//   (c) top-p (nucleus) — softmax, sort, keep tokens until cumulative ≥ p,
//       sample. Most "natural" sampling but adds two hyperparams.
//
// Default (if no preference given): (a) argmax.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Greedy (argmax) choice of the next token id.
 *
 * The decoder is stepped one token at a time, so `logits` is expected to be
 * (1, 1, vocab_size). Only the last `vocab_size` values are scanned, where
 * vocab_size is the tensor's trailing dimension. A tensor carrying several
 * positions, e.g. (1, seq, vocab_size), therefore still yields an id in
 * [0, vocab_size) for the final position, not an out-of-vocabulary flat offset.
 *
 * Ties resolve to the lowest id. NaN logits never beat a finite one. If every
 * scanned logit is NaN or -Infinity, id 0 is returned.
 *
 * @param logits Decoder output logits.
 * @returns Token id in [0, vocab_size).
 * @throws RangeError if the tensor holds no logits.
 */
function sampleNextToken(logits: ort.Tensor): number {
  // Flat row-major buffer; presumably float32. TODO: confirm if the decoder is
  // ever exported in another precision.
  const data = logits.data as Float32Array;
  const vocabSize = logits.dims[logits.dims.length - 1] ?? data.length;
  if (vocabSize <= 0 || data.length === 0) {
    throw new RangeError(`sampleNextToken: empty logits (dims=[${logits.dims.join(', ')}])`);
  }
  // Offset of the last position's row; 0 for the expected (1, 1, vocab_size).
  const start = Math.max(0, data.length - vocabSize);

  // USER: replace this body with your choice from (a)/(b)/(c). The default is (a) argmax.
  let bestIdx = 0;
  let bestVal = -Infinity;
  for (let i = start; i < data.length; i++) {
    if (data[i] > bestVal) { bestVal = data[i]; bestIdx = i - start; }
  }
  return bestIdx;
}

// ────────────────────────────────────────────────────────────────────────────
// USER CONTRIBUTION POINT #2 — shouldStop
//
// Decide whether to halt generation now that the model has produced `nextId`
// on top of the already-kept tokens `soFar`. A token that triggers the stop is
// not appended to the output.
// The decode loop's `maxNewTokens` cap also bounds the loop independently.
//
// Choices:
//   (a) EOS-only — matches Cactus's native generate() exactly. Simplest, safest.
//   (b) EOS OR balanced-JSON — if the decoded text since <tool_call> is a
//       valid parseable JSON array (e.g. ']' at top level with brace balance
//       at zero), stop. Crisper exit when the model trails into padding.
//   (c) EOS OR balanced-JSON OR token == ']' — same as (b) but cheaper to
//       check, since the tokenizer's ']' token ID is fixed.
//
// Default: (a) EOS-only.
// ────────────────────────────────────────────────────────────────────────────
/**
 * Stop rule (a): halt exactly when the freshly produced token is EOS.
 *
 * @param nextId    Token just produced by the decoder, not yet kept.
 * @param soFar     Tokens kept so far; ignored by rule (a).
 * @param eosId     End-of-sequence token id.
 * @param tokenizer Available to text-based rules (b)/(c); ignored by rule (a).
 * @returns `true` iff `nextId` equals `eosId`.
 */
function shouldStop(
  nextId: number,
  soFar: number[],
  eosId: number,
  tokenizer: Tokenizer,
): boolean {
  // USER: replace this body with your choice from (a)/(b)/(c). Default is (a).
  // Reference the parameters rule (a) ignores so unused-parameter checks stay quiet.
  void soFar;
  void tokenizer;
  const isEos = eosId === nextId;
  return isEos;
}