File size: 3,377 Bytes
814c07e
 
 
 
 
 
6c01099
814c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c01099
 
814c07e
6c01099
814c07e
 
 
 
6c01099
 
 
814c07e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6c01099
 
 
 
 
 
814c07e
 
 
 
 
 
 
 
 
 
6c01099
814c07e
 
 
 
 
6c01099
 
 
 
814c07e
6c01099
814c07e
6c01099
814c07e
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import './style.css';
import { createTokenizer } from './tokenizer';
import type { Tokenizer } from './tokenizer';
import { loadSessions } from './runtime';
import type { NeedleSessions } from './runtime';
import { generate } from './generate';
import { mountUI, setStatus, renderResult, renderError, readTools, setInteractiveEnabled } from './ui';
import type { UI } from './ui';
import { TOKENIZER_URL, SPECIALS_URL } from './config';

/**
 * Special-token ids loaded from SPECIALS_URL.
 *
 * Only `eos` and `tools` are read in this file; the other fields are kept to
 * mirror the JSON payload.
 */
interface Specials {
  /** Presumably the padding token id (not read in this file). */
  pad: number;
  /** End-of-sequence token id; also used here to seed the decoder. */
  eos: number;
  /** Presumably the beginning-of-sequence token id (not read in this file). */
  bos: number;
  /** Presumably the `<tool_call>` marker token id (not read in this file). */
  tool_call: number;
  /** Token id handed to `generate` as `toolsTokenId`. */
  tools: number;
}

/**
 * Download `url` and return the raw response body as bytes.
 *
 * @param url - Resource to fetch.
 * @returns The full body wrapped in a `Uint8Array`.
 * @throws Error `fetch <url>: <status>` when the HTTP status is not 2xx.
 */
async function fetchBytes(url: string): Promise<Uint8Array> {
  const response = await fetch(url);
  if (!response.ok) {
    throw new Error(`fetch ${url}: ${response.status}`);
  }
  const buffer = await response.arrayBuffer();
  return new Uint8Array(buffer);
}

/**
 * Fetch `url` and parse the response body as JSON.
 *
 * The parsed value is cast to `T` without any runtime validation; callers that
 * rely on particular fields must check them themselves.
 *
 * @param url - Resource to fetch.
 * @returns The parsed JSON value, typed as `T`.
 * @throws Error `fetch <url>: <status>` when the HTTP status is not 2xx.
 * @throws Error `fetch <url>: invalid JSON (...)` when the body does not parse.
 *   The URL is included because a server answering a missing file with an HTML
 *   fallback page and a 200 status otherwise yields a context-free SyntaxError.
 */
async function fetchJson<T>(url: string): Promise<T> {
  const resp = await fetch(url);
  if (!resp.ok) throw new Error(`fetch ${url}: ${resp.status}`);
  const body = await resp.text();
  let parsed: unknown;
  try {
    parsed = JSON.parse(body);
  } catch (e: unknown) {
    const reason = e instanceof Error ? e.message : String(e);
    throw new Error(`fetch ${url}: invalid JSON (${reason})`);
  }
  // Unchecked cast: see the doc comment above.
  return parsed as T;
}

/**
 * Entry point: mount the UI, then load the model sessions, tokenizer bytes and
 * special-token table concurrently, build the tokenizer, and enable the UI.
 *
 * Any failure inside the load sequence is shown in the UI (status "failed" plus
 * an error message) and logged to the console with its stack; it is not rethrown.
 */
async function boot() {
  const ui = mountUI();
  try {
    setStatus(ui, 'loading model…', true);
    const t0 = performance.now();
    // The three downloads are independent, so run them in parallel.
    const [sessions, tokenizerBytes, specials] = await Promise.all([
      loadSessions(m => setStatus(ui, m, true)),
      fetchBytes(TOKENIZER_URL),
      fetchJson<Specials>(SPECIALS_URL),
    ]);
    // fetchJson does not validate its payload; check the ids wireRun relies on
    // so a malformed specials file fails here instead of mid-generation.
    if (typeof specials.eos !== 'number' || typeof specials.tools !== 'number') {
      throw new Error(`${SPECIALS_URL}: expected numeric "eos" and "tools" token ids`);
    }
    const tokenizer = await createTokenizer(tokenizerBytes);
    const loadSecs = ((performance.now() - t0) / 1000).toFixed(1);
    setStatus(ui, `ready · loaded in ${loadSecs}s`);
    setInteractiveEnabled(ui, true);
    wireRun(ui, sessions, tokenizer, specials);
  } catch (e: unknown) {
    // Keep the stack for debugging; the UI only shows the message.
    console.error(e);
    setStatus(ui, 'failed');
    const reason = e instanceof Error ? e.message : String(e);
    renderError(ui, `Failed to load model: ${reason}`);
  }
}

/**
 * Attach the generation handler to the query input's `change` event.
 *
 * Each accepted event:
 * - reads and validates the tools definition;
 * - runs `generate` on the trimmed query, streaming partial output into the result view;
 * - keeps the status line updated with elapsed time and token count;
 * - finally reports total time, token count and tokens/second.
 *
 * Events that arrive while a generation is already running are ignored, as are
 * empty queries. Generation errors are shown in the UI and logged to the console.
 */
function wireRun(ui: UI, sessions: NeedleSessions, tokenizer: Tokenizer, specials: Specials) {
  const TOOL_CALL_PREFIX = '<tool_call>';
  let running = false;
  ui.queryEl.addEventListener('change', async () => {
    if (running) return;
    const tools = readTools(ui);
    if (!tools.ok) { renderError(ui, tools.error); return; }
    const query = ui.queryEl.value.trim();
    if (!query) return;
    running = true;
    let tokensSoFar = 0;
    const t0 = performance.now();
    // Live progress in the status line; cleared in `finally` on every path.
    const tick = setInterval(() => {
      const elapsed = ((performance.now() - t0) / 1000).toFixed(1);
      setStatus(ui, `generating… ${elapsed}s · ${tokensSoFar} tok`, true);
    }, 100);
    try {
      const result = await generate(
        sessions, tokenizer, query, tools.tools,
        {
          eosTokenId: specials.eos,
          bosOrPrefixTokenId: specials.eos,   // Cactus seeds decoder with EOS, not BOS
          toolsTokenId: specials.tools,
          maxNewTokens: 256,
        },
        (_id, decodedSoFar) => {
          tokensSoFar += 1;
          // Hide a leading tool-call marker in the streamed preview.
          const display = decodedSoFar.startsWith(TOOL_CALL_PREFIX)
            ? decodedSoFar.slice(TOOL_CALL_PREFIX.length)
            : decodedSoFar;
          renderResult(ui, display);
        },
      );
      const elapsedMs = performance.now() - t0;
      const elapsed = (elapsedMs / 1000).toFixed(2);
      const tps = (result.ids.length / (elapsedMs / 1000)).toFixed(1);
      renderResult(ui, result.text);
      setStatus(ui, `ready · ${elapsed}s · ${result.ids.length} tok · ${tps} tok/s`);
    } catch (e: unknown) {
      console.error(e);
      const reason = e instanceof Error ? e.message : String(e);
      renderError(ui, `Generation failed: ${reason}`);
      setStatus(ui, 'ready');
    } finally {
      clearInterval(tick);
      running = false;
    }
  });
}

// boot() reports load failures in the UI itself; this catch only covers errors
// thrown before its try block (e.g. by mountUI), which would otherwise surface
// as an unhandled promise rejection.
boot().catch((e: unknown) => console.error('boot failed:', e));