shreyask committed on
Commit
5b2a426
·
verified ·
1 Parent(s): 1267a61

Upload inspect_needle.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inspect_needle.py +488 -0
inspect_needle.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ inspect_needle.py — Introspect the Cactus/Needle Flax model, tokenizer, and prompt format.
3
+
4
+ Outputs three markdown files:
5
+ ../notes/needle-internals.md
6
+ ../notes/tokenizer-internals.md
7
+ ../notes/prompt-format.md
8
+ """
9
+
10
+ import json
11
+ import os
12
+ import sys
13
+ import textwrap
14
+ import inspect
15
+ import traceback
16
+
17
+ # Add the needle package to path
18
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'external', 'needle'))
19
+
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+
24
+ # ── 1. Import needle symbols ──────────────────────────────────────────────────
25
+ from needle.model.architecture import (
26
+ SimpleAttentionNetwork, TransformerConfig,
27
+ MultiHeadAttention, FeedForward, EncoderBlock, DecoderBlock,
28
+ Encoder, Decoder, ZCRMSNorm, make_causal_mask, make_padding_mask,
29
+ )
30
+ from needle.model.run import generate, _build_encoder_input, load_checkpoint
31
+ from needle.dataset.tokenizer import (
32
+ NeedleTokenizer, get_tokenizer,
33
+ PAD_ID, EOS_ID, BOS_ID, UNK_ID, TOOL_CALL_ID, TOOLS_ID,
34
+ DEFAULT_MAX_ENC_LEN, DEFAULT_MAX_DEC_LEN, DEFAULT_MAX_GEN_LEN,
35
+ TOKENIZER_PREFIX,
36
+ )
37
+
38
+ print("=== Needle/Cactus Inspector ===\n")
39
+
40
+ # ── 2. Instantiate model with default config ──────────────────────────────────
41
+ config = TransformerConfig()
42
+ print(f"TransformerConfig defaults:")
43
+ for field in config.__dataclass_fields__:
44
+ print(f" {field} = {getattr(config, field)}")
45
+
46
+ model = SimpleAttentionNetwork(config)
47
+
48
+ # Use init_all to initialize all parameters (both text + contrastive paths)
49
+ dummy_src = jnp.zeros((1, 16), dtype=jnp.int32)
50
+ dummy_tgt = jnp.zeros((1, 16), dtype=jnp.int32)
51
+ variables = model.init(jax.random.PRNGKey(0), dummy_src, dummy_tgt, method='init_all')
52
+ params = variables['params']
53
+
54
+ print("\n=== Param Tree (paths + shapes) ===")
55
+
56
def print_param_tree(tree, prefix='', indent=0):
    """Render a nested parameter tree as indented text lines.

    Args:
        tree: Mapping from names to either sub-mappings or leaves. Any
            ``collections.abc.Mapping`` is accepted, not only ``dict``, so
            immutable mapping wrappers are traversed instead of being
            silently rendered as an empty tree.
        prefix: Dotted path of the enclosing nodes. It is forwarded on
            recursion for backward compatibility but does not appear in
            the rendered output.
        indent: Nesting depth; each line is prefixed with ``' ' * indent``.

    Returns:
        A list of strings, one per node, with keys in sorted order. A
        sub-mapping yields ``"<key>:"`` followed by its children; a leaf
        yields ``"<key>: shape=..., dtype=..."``. A non-mapping ``tree``
        yields an empty list.
    """
    from collections.abc import Mapping

    lines = []
    if not isinstance(tree, Mapping):
        return lines

    pad = ' ' * indent
    for k, v in sorted(tree.items()):
        if isinstance(v, Mapping):
            lines.append(f'{pad}{k}:')
            lines.extend(
                print_param_tree(v, prefix=prefix + k + '.', indent=indent + 1)
            )
        else:
            # Leaves without array attributes fall back to repr() / '?'.
            shape = tuple(v.shape) if hasattr(v, 'shape') else repr(v)
            dtype = str(v.dtype) if hasattr(v, 'dtype') else '?'
            lines.append(f'{pad}{k}: shape={shape}, dtype={dtype}')
    return lines
69
+
70
+ param_lines = print_param_tree(params)
71
+ for line in param_lines:
72
+ print(line)
73
+
74
+ # Build shape dict for JSON serialization
75
def build_shape_dict(tree):
    """Convert a parameter tree into a JSON-serializable shape summary.

    Args:
        tree: Either a mapping of names to sub-trees, or a leaf exposing
            ``.shape`` and ``.dtype``. Any ``collections.abc.Mapping`` is
            treated as a sub-tree, so non-``dict`` mapping containers are
            recursed into rather than mistaken for a leaf.

    Returns:
        For a mapping: a plain ``dict`` with keys in sorted order whose
        values are the summaries of the children. For a leaf:
        ``{'shape': [...], 'dtype': '<dtype string>'}``.

    Raises:
        AttributeError: If a leaf has no ``shape`` or ``dtype`` attribute.
    """
    from collections.abc import Mapping

    if isinstance(tree, Mapping):
        return {k: build_shape_dict(v) for k, v in sorted(tree.items())}
    return {'shape': list(tree.shape), 'dtype': str(tree.dtype)}
80
+
81
+ shape_dict = build_shape_dict(params)
82
+
83
+ # Count parameters
84
+ total_params = sum(x.size for x in jax.tree.leaves(params))
85
+ print(f"\nTotal parameters: {total_params:,}")
86
+
87
+ # ── 3. Tokenizer info ─────────────────────────────────────────────────────────
88
+ print("\n=== Tokenizer ===")
89
+ tok = get_tokenizer()
90
+ print(f"Class: {type(tok).__name__}")
91
+ print(f"vocab_size: {tok.vocab_size}")
92
+ print(f"pad_token_id: {tok.pad_token_id}")
93
+ print(f"eos_token_id: {tok.eos_token_id}")
94
+ print(f"bos_token_id: {tok.bos_token_id}")
95
+ print(f"tool_call_token_id: {tok.tool_call_token_id}")
96
+ print(f"tools_token_id: {tok.tools_token_id}")
97
+ print(f"attrs: {[a for a in dir(tok) if not a.startswith('_')]}")
98
+
99
+ # Test encode/decode
100
+ sample_text = "set a 5 min timer"
101
+ encoded = tok.encode(sample_text)
102
+ decoded = tok.decode(encoded)
103
+ print(f"\nExample encode('{sample_text}'):")
104
+ print(f" -> {encoded}")
105
+ print(f" -> decode -> '{decoded}'")
106
+
107
+ # Get vocab size from SentencePiece directly
108
+ sp = tok.sp
109
+ print(f"\nSentencePiece GetPieceSize(): {sp.GetPieceSize()}")
110
+ print(f"Model type: BPE (confirmed in train_tokenizer source)")
111
+ print(f"byte_fallback: True (confirmed in train_tokenizer source)")
112
+ print(f"normalization_rule_name: identity (no Unicode normalization)")
113
+
114
+ # ── 4. Prompt format capture ──────────────────────────────────────────────────
115
+ print("\n=== Prompt Format Capture ===")
116
+
117
+ query = "set a 5 min timer"
118
+ tools = [{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration", "required": True}}}]
119
+ tools_json = json.dumps(tools)
120
+
121
+ captured_inputs = []
122
+ original_encode = tok.encode
123
+
124
def capture_encode(text, *a, **kw):
    """Log the text handed to the tokenizer, then defer to the real encoder.

    Appends an ``('encode', text)`` record to the module-level
    ``captured_inputs`` list and returns whatever ``original_encode``
    returns for the same arguments.
    """
    record = ('encode', text)
    captured_inputs.append(record)
    token_ids = original_encode(text, *a, **kw)
    return token_ids
127
+
128
+ tok.encode = capture_encode
129
+
130
+ # Call _build_encoder_input directly
131
+ enc_tokens = _build_encoder_input(tok, query, tools_json)
132
+ print(f"Captured encode calls: {captured_inputs}")
133
+ print(f"Encoder token IDs: {enc_tokens}")
134
+ print(f"Encoder token count: {len(enc_tokens)}")
135
+
136
+ # Decode the encoder input to see the full prompt
137
+ decoded_enc = tok.decode(enc_tokens)
138
+ print(f"Decoded encoder input: {repr(decoded_enc)}")
139
+
140
+ # Also decode each segment
141
+ tok.encode = original_encode # restore
142
+ q_toks = tok.encode(query)
143
+ t_toks = tok.encode(tools_json)
144
+ print(f"\nQuery tokens: {q_toks} -> '{tok.decode(q_toks)}'")
145
+ print(f"tools_token_id (separator): {TOOLS_ID}")
146
+ print(f"Tools tokens: {t_toks[:20]}... -> '{tok.decode(t_toks[:20])}'")
147
+
148
+ # ── 5. Write notes files ──────────────────────────────────────────────────────
149
+
150
+ NOTES_DIR = os.path.join(os.path.dirname(__file__), '..', 'notes')
151
+ os.makedirs(NOTES_DIR, exist_ok=True)
152
+
153
+ # ── 5a. needle-internals.md ───────────────────────────────────────────────────
154
+
155
def fmt_shape_dict(d, indent=0):
    """Format a shape summary dict as a nested markdown list.

    Args:
        d: Mapping of names to either sub-mappings (modules) or leaf
            records of the form ``{'shape': ..., 'dtype': ...}``.
        indent: Nesting depth; each line is prefixed with ``' ' * indent``.

    Returns:
        A list of markdown lines with keys in sorted order. Sub-mappings
        are rendered as a bold header followed by their children; leaves
        as a code-formatted name with ``shape=`` and ``dtype=``.
    """
    from collections.abc import Mapping

    lines = []
    prefix = ' ' * indent
    for k, v in sorted(d.items()):
        # A leaf record stores its shape as a non-mapping value. A mapping
        # whose 'shape' entry is itself a mapping is a module that merely
        # has a child named "shape", so it must be recursed into.
        is_subtree = isinstance(v, Mapping) and (
            'shape' not in v or isinstance(v['shape'], Mapping)
        )
        if is_subtree:
            lines.append(f"{prefix}- **{k}**")
            lines.extend(fmt_shape_dict(v, indent + 1))
        else:
            # Values that are not leaf records render as unknown, not crash.
            leaf = v if isinstance(v, Mapping) else {}
            shape = leaf.get('shape', '?')
            dtype = leaf.get('dtype', '?')
            lines.append(f"{prefix} - `{k}`: shape={shape}, dtype={dtype}")
    return lines
168
+
169
+ needle_internals = f"""# Needle (Cactus) Flax Model Internals
170
+
171
+ ## Model: `SimpleAttentionNetwork`
172
+
173
+ **Architecture:** Encoder–decoder transformer with shared embeddings, tied output projection, RoPE, bfloat16.
174
+
175
+ ### `TransformerConfig` Defaults
176
+
177
+ | Field | Value | Notes |
178
+ |---|---|---|
179
+ | vocab_size | {config.vocab_size} | SentencePiece BPE vocab |
180
+ | d_model | {config.d_model} | Hidden dimension |
181
+ | num_heads | {config.num_heads} | Attention heads |
182
+ | num_kv_heads | {config.num_kv_heads} | GQA key/value heads |
183
+ | num_encoder_layers | {config.num_encoder_layers} | Encoder depth |
184
+ | num_decoder_layers | {config.num_decoder_layers} | Decoder depth |
185
+ | d_ff | {config.d_ff} | FFN hidden dim (unused by default) |
186
+ | max_seq_len | {config.max_seq_len} | |
187
+ | pad_token_id | {config.pad_token_id} | |
188
+ | rope_theta | {config.rope_theta} | RoPE base frequency |
189
+ | dtype | {config.dtype} | |
190
+ | activation | {config.activation} | dual-ReLU gate (drelu) |
191
+ | num_memory_slots | {config.num_memory_slots} | |
192
+ | dropout_rate | {config.dropout_rate} | |
193
+ | contrastive_dim | {config.contrastive_dim} | Contrastive projection output dim |
194
+ | no_feedforward | {config.no_feedforward} | **True β†’ FFN blocks are SKIPPED** |
195
+
196
+ **Total parameters:** {total_params:,}
197
+
198
+ **Important:** `no_feedforward=True` by default β€” `FeedForward` / `gate_proj` / `up_proj` / `down_proj` are NOT present in the saved param tree. The `ffn_gate` scalar also absent. This is a pure attention-only transformer.
199
+
200
+ ---
201
+
202
+ ## Layer-by-Layer Architecture
203
+
204
+ ### Encoder (`Encoder` module)
205
+
206
+ Uses `nn.scan` over `num_encoder_layers={config.num_encoder_layers}` layers. Each `EncoderBlock` contains:
207
+
208
+ 1. **`attn_gate`** β€” scalar sigmoid gate `()` on attention residual
209
+ 2. **Pre-norm**: `ZCRMSNorm` (scale initialized to 0, applied as `(1+Ξ³) * x / RMS(x)`)
210
+ 3. **`self_attn`** (`MultiHeadAttention`):
211
+ - `q_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
212
+ - `k_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
213
+ - `v_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
214
+ - `out_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
215
+ - `q_norm`, `k_norm`: `ZCRMSNorm` on Q and K (pre-RoPE)
216
+ 4. Residual: `x = x_in + gate * attn_out`
217
+ 5. (FFN block **skipped** because `no_feedforward=True`)
218
+ 6. **`final_norm`**: `ZCRMSNorm` after all layers (on encoder output)
219
+
220
+ ### Decoder (`Decoder` module)
221
+
222
+ Uses `nn.scan` over `num_decoder_layers={config.num_decoder_layers}` layers. Each `DecoderBlock` contains:
223
+
224
+ 1. **`self_attn_gate`** β€” scalar on causal self-attention residual
225
+ 2. Pre-norm + **`self_attn`** (causal self-attention, same structure as encoder)
226
+ 3. Residual: `x = x_in + self_gate * self_attn_out`
227
+ 4. **`cross_attn_gate`** β€” scalar on cross-attention residual
228
+ 5. Pre-norm + **`cross_attn`** (cross-attention: Q from decoder, K/V from encoder)
229
+ 6. Residual: `x = x_in + cross_gate * cross_attn_out`
230
+ 7. (FFN block **skipped**)
231
+ 8. Final `ZCRMSNorm` after all layers
232
+
233
+ ### Top-level `SimpleAttentionNetwork` params
234
+
235
+ - **`embedding`**: `nn.Embed` kernel `(vocab_size, d_model)` = `({config.vocab_size}, {config.d_model})`
236
+ Also used as tied output: `logits = hidden @ embedding.T`
237
+ - **`log_temp`**: scalar `()` β€” contrastive temperature
238
+ - **`contrastive_hidden`**: Dense `(d_model, d_model//4)` = `({config.d_model}, {config.d_model//4})`
239
+ - **`contrastive_proj`**: Dense `(d_model//4, contrastive_dim)` = `({config.d_model//4}, {config.contrastive_dim})`, no bias
240
+
241
+ ---
242
+
243
+ ## `nn.scan` Parameter Layout
244
+
245
+ Because Flax `nn.scan` with `variable_axes={{"params": 0}}` is used, each scanned layer stack is stored as a **single** parameter array with an extra leading dimension of size `num_layers`. So for encoder: `layers.attn_gate` has shape `(num_encoder_layers, ...)` = `({config.num_encoder_layers}, ...)`.
246
+
247
+ **This is critical for Task 2C:** When mapping Flax β†’ PyTorch weights, you must index `params['encoder']['layers']['...'][layer_idx]` to get per-layer weights.
248
+
249
+ ---
250
+
251
+ ## Full Flax `params` Tree (shapes)
252
+
253
+ ```json
254
+ {json.dumps(shape_dict, indent=2)}
255
+ ```
256
+
257
+ ---
258
+
259
+ ## `SimpleAttentionNetwork.__call__` Source
260
+
261
+ ```python
262
+ {inspect.getsource(SimpleAttentionNetwork.__call__)}
263
+ ```
264
+
265
+ ## `EncoderBlock.__call__` Source
266
+
267
+ ```python
268
+ {inspect.getsource(EncoderBlock.__call__)}
269
+ ```
270
+
271
+ ## `DecoderBlock.__call__` Source
272
+
273
+ ```python
274
+ {inspect.getsource(DecoderBlock.__call__)}
275
+ ```
276
+
277
+ ## `MultiHeadAttention.__call__` Source
278
+
279
+ ```python
280
+ {inspect.getsource(MultiHeadAttention.__call__)}
281
+ ```
282
+ """
283
+
284
# Write as UTF-8 explicitly: the generated notes contain non-ASCII characters
# (dashes, arrows, Greek letters) that the platform default encoding may not
# be able to represent.
with open(os.path.join(NOTES_DIR, 'needle-internals.md'), 'w', encoding='utf-8') as f:
    f.write(needle_internals)
286
+ print(f"\nWrote needle-internals.md ({len(needle_internals.splitlines())} lines)")
287
+
288
+ # ── 5b. tokenizer-internals.md ────────────────────────────────────────────────
289
+
290
+ tokenizer_internals = f"""# Needle Tokenizer Internals
291
+
292
+ ## Tokenizer Class
293
+
294
+ **Class:** `NeedleTokenizer` (defined in `external/needle/needle/dataset/tokenizer.py`)
295
+
296
+ **Backend:** SentencePiece BPE β€” the model file is downloaded from HuggingFace on first use.
297
+
298
+ **HuggingFace repo:** `Cactus-Compute/needle-tokenizer` (dataset repo)
299
+
300
+ **Local model file path:** `{TOKENIZER_PREFIX}.model`
301
+ (relative to the external/needle repo root: `needle/tokenizer/needle.model`)
302
+
303
+ ---
304
+
305
+ ## Tokenizer Properties
306
+
307
+ | Property | Value | Source |
308
+ |---|---|---|
309
+ | `vocab_size` | {tok.vocab_size} | `sp.GetPieceSize()` |
310
+ | `pad_token_id` | {PAD_ID} | hardcoded = 0 |
311
+ | `eos_token_id` | {EOS_ID} | hardcoded = 1 |
312
+ | `bos_token_id` | {BOS_ID} | hardcoded = 2 |
313
+ | `unk_id` | {UNK_ID} | hardcoded = 3 |
314
+ | `tool_call_token_id` | {TOOL_CALL_ID} | hardcoded = 4, symbol `<tool_call>` |
315
+ | `tools_token_id` | {TOOLS_ID} | hardcoded = 5, symbol `<tools>` |
316
+
317
+ ---
318
+
319
+ ## Encode Algorithm
320
+
321
+ **Algorithm:** SentencePiece **BPE** (Byte-Pair Encoding)
322
+
323
+ **Configured with:**
324
+ - `model_type="bpe"`
325
+ - `byte_fallback=True` β€” unknown bytes represented as `<0xNN>` hex tokens
326
+ - `normalization_rule_name="identity"` β€” **no Unicode normalization applied**
327
+ - `user_defined_symbols=["<tool_call>", "<tools>"]` β€” fixed IDs 4 and 5
328
+
329
+ **Encode call in NeedleTokenizer:**
330
+ ```python
331
+ def encode(self, text):
332
+ return self.sp.Encode(text, out_type=int)
333
+ ```
334
+ Returns a Python `list[int]`. Does NOT add BOS/EOS automatically.
335
+
336
+ **Decode call:**
337
+ ```python
338
+ def decode(self, ids):
339
+ return self.sp.Decode(list(ids))
340
+ ```
341
+
342
+ ---
343
+
344
+ ## Available Attributes on `NeedleTokenizer`
345
+
346
+ ```
347
+ {[a for a in dir(tok) if not a.startswith('_')]}
348
+ ```
349
+
350
+ The underlying SentencePiece processor is accessible as `tok.sp`.
351
+
352
+ ---
353
+
354
+ ## SentencePiece Model File
355
+
356
+ The `.model` file is a serialized SentencePiece protobuf. For Task 5 (port tokenizer to JS/ONNX/etc.), you need either:
357
+ 1. The `.model` file directly (loadable by the `sentencepiece` library)
358
+ 2. Export the vocabulary + merge rules via `tok.sp.GetPieceSize()`, `tok.sp.IdToPiece(i)`, etc.
359
+
360
+ To export the full vocabulary:
361
+ ```python
362
+ vocab = [(i, tok.sp.IdToPiece(i), tok.sp.GetScore(i)) for i in range(tok.vocab_size)]
363
+ ```
364
+
365
+ ---
366
+
367
+ ## Example Encoding
368
+
369
+ ```
370
+ encode("{sample_text}")
371
+ -> {encoded}
372
+ -> decode -> "{decoded}"
373
+ ```
374
+ """
375
+
376
# Write as UTF-8 explicitly: the generated notes contain non-ASCII characters
# (e.g. dashes) that the platform default encoding may not be able to
# represent.
with open(os.path.join(NOTES_DIR, 'tokenizer-internals.md'), 'w', encoding='utf-8') as f:
    f.write(tokenizer_internals)
378
+ print(f"Wrote tokenizer-internals.md ({len(tokenizer_internals.splitlines())} lines)")
379
+
380
+ # ── 5c. prompt-format.md ─────────────────────────────────────────────────────
381
+
382
+ # Get the decoded full prompt
383
+ tok.encode = capture_encode
384
+ enc_tokens2 = _build_encoder_input(tok, query, tools_json)
385
+ tok.encode = original_encode # restore
386
+
387
+ # Reconstruct what each segment contains
388
+ q_toks_raw = tok.encode(query)
389
+ t_toks_raw = tok.encode(tools_json)
390
+
391
+ MAX_ENC = DEFAULT_MAX_ENC_LEN # 1024
392
+ max_query = MAX_ENC - 2
393
+ q_toks_trunc = q_toks_raw[:max_query]
394
+ remaining = MAX_ENC - len(q_toks_trunc) - 1
395
+ t_toks_trunc = t_toks_raw[:remaining]
396
+
397
+ final_tokens = q_toks_trunc + [TOOLS_ID] + t_toks_trunc
398
+ decoded_full = tok.decode(final_tokens)
399
+
400
+ # Decoder starts with EOS token (1), then model predicts <tool_call> (4) first
401
+ prompt_format = f"""# Needle Prompt Format
402
+
403
+ ## Overview
404
+
405
+ Needle uses an **encoder–decoder** architecture. The query + tools are encoded by the encoder; the decoder generates the tool-call JSON.
406
+
407
+ ## Encoder Input Format
408
+
409
+ ```
410
+ [query_tokens..., <tools>(id={TOOLS_ID}), tools_tokens...]
411
+ ```
412
+
413
+ **Built by `_build_encoder_input(tokenizer, query, tools, max_enc_len={DEFAULT_MAX_ENC_LEN})`:**
414
+
415
+ ```python
416
+ {inspect.getsource(_build_encoder_input)}
417
+ ```
418
+
419
+ ### Truncation Rules
420
+
421
+ 1. Query is truncated to `max_enc_len - 2 = {MAX_ENC - 2}` tokens
422
+ 2. Tools are truncated to `max_enc_len - len(q_toks) - 1` tokens (fills remaining space)
423
+ 3. The `<tools>` separator token (id={TOOLS_ID}) is always inserted between query and tools
424
+
425
+ ### Example: `query="set a 5 min timer"`, `tools=[{{"name": "set_timer", ...}}]`
426
+
427
+ **Encode calls made by `_build_encoder_input`:**
428
+ 1. `tokenizer.encode("{query}")` β†’ {q_toks_raw} ({len(q_toks_raw)} tokens)
429
+ 2. `tokenizer.encode('{tools_json[:80]}...')` β†’ {t_toks_raw[:20]}... ({len(t_toks_raw)} tokens total)
430
+
431
+ **Final encoder token sequence (len={len(final_tokens)}):**
432
+ ```
433
+ {final_tokens}
434
+ ```
435
+
436
+ **Decoded encoder input string:**
437
+ ```
438
+ {repr(decoded_full)}
439
+ ```
440
+
441
+ **Structural breakdown:**
442
+ - Query tokens ({len(q_toks_trunc)} tokens): `{q_toks_trunc}`
443
+ - Separator token (1 token): `[{TOOLS_ID}]` β†’ `<tools>`
444
+ - Tools tokens ({len(t_toks_trunc)} tokens): `{t_toks_trunc}`
445
+
446
+ ---
447
+
448
+ ## Decoder Input Format
449
+
450
+ The decoder is initialized with a single **EOS token** (id={EOS_ID}) as prefix:
451
+
452
+ ```python
453
+ dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32)
454
+ dec_buffer = dec_buffer.at[0, 0].set(eos_id) # prefix = [EOS]
455
+ ```
456
+
457
+ At each step, the model predicts the next token autoregressively.
458
+
459
+ **Expected output sequence:** `<tool_call>(id={TOOL_CALL_ID}) + answer_json_tokens + EOS(id={EOS_ID})`
460
+
461
+ The `<tool_call>` prefix is stripped from the decoded output string in `generate()`.
462
+
463
+ ---
464
+
465
+ ## Tool Name Normalization
466
+
467
+ Before building encoder input, tool names are converted to `snake_case` via `normalize_tools()`:
468
+ - `camelCase` β†’ `camel_case`
469
+ - `PascalCase` β†’ `pascal_case`
470
+ - `dot.notation` β†’ `dot_notation`
471
+ - Non-alphanumeric chars β†’ underscores
472
+
473
+ After decoding, `restore_tool_names()` maps back to original names.
474
+
475
+ ---
476
+
477
+ ## Full `generate()` Signature
478
+
479
+ ```python
480
+ {inspect.getsource(generate)}
481
+ ```
482
+ """
483
+
484
# Write as UTF-8 explicitly: the generated notes contain non-ASCII characters
# (dashes, arrows) that the platform default encoding may not be able to
# represent.
with open(os.path.join(NOTES_DIR, 'prompt-format.md'), 'w', encoding='utf-8') as f:
    f.write(prompt_format)
486
+ print(f"Wrote prompt-format.md ({len(prompt_format.splitlines())} lines)")
487
+
488
+ print("\n=== All notes written successfully ===")