Upload inspect_needle.py with huggingface_hub
Browse files- inspect_needle.py +488 -0
inspect_needle.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inspect_needle.py — Introspect the Cactus/Needle Flax model, tokenizer, and prompt format.
|
| 3 |
+
|
| 4 |
+
Outputs three markdown files:
|
| 5 |
+
../notes/needle-internals.md
|
| 6 |
+
../notes/tokenizer-internals.md
|
| 7 |
+
../notes/prompt-format.md
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import os
|
| 12 |
+
import sys
|
| 13 |
+
import textwrap
|
| 14 |
+
import inspect
|
| 15 |
+
import traceback
|
| 16 |
+
|
| 17 |
+
# Add the needle package to path
|
| 18 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'external', 'needle'))
|
| 19 |
+
|
| 20 |
+
import jax
|
| 21 |
+
import jax.numpy as jnp
|
| 22 |
+
import numpy as np
|
| 23 |
+
|
| 24 |
+
# ── 1. Import needle symbols ─────────────────────────────────────────────────
|
| 25 |
+
from needle.model.architecture import (
|
| 26 |
+
SimpleAttentionNetwork, TransformerConfig,
|
| 27 |
+
MultiHeadAttention, FeedForward, EncoderBlock, DecoderBlock,
|
| 28 |
+
Encoder, Decoder, ZCRMSNorm, make_causal_mask, make_padding_mask,
|
| 29 |
+
)
|
| 30 |
+
from needle.model.run import generate, _build_encoder_input, load_checkpoint
|
| 31 |
+
from needle.dataset.tokenizer import (
|
| 32 |
+
NeedleTokenizer, get_tokenizer,
|
| 33 |
+
PAD_ID, EOS_ID, BOS_ID, UNK_ID, TOOL_CALL_ID, TOOLS_ID,
|
| 34 |
+
DEFAULT_MAX_ENC_LEN, DEFAULT_MAX_DEC_LEN, DEFAULT_MAX_GEN_LEN,
|
| 35 |
+
TOKENIZER_PREFIX,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
print("=== Needle/Cactus Inspector ===\n")
|
| 39 |
+
|
| 40 |
+
# ── 2. Instantiate model with default config ─────────────────────────────────
|
| 41 |
+
config = TransformerConfig()
|
| 42 |
+
print(f"TransformerConfig defaults:")
|
| 43 |
+
for field in config.__dataclass_fields__:
|
| 44 |
+
print(f" {field} = {getattr(config, field)}")
|
| 45 |
+
|
| 46 |
+
model = SimpleAttentionNetwork(config)
|
| 47 |
+
|
| 48 |
+
# Use init_all to initialize all parameters (both text + contrastive paths)
|
| 49 |
+
dummy_src = jnp.zeros((1, 16), dtype=jnp.int32)
|
| 50 |
+
dummy_tgt = jnp.zeros((1, 16), dtype=jnp.int32)
|
| 51 |
+
variables = model.init(jax.random.PRNGKey(0), dummy_src, dummy_tgt, method='init_all')
|
| 52 |
+
params = variables['params']
|
| 53 |
+
|
| 54 |
+
print("\n=== Param Tree (paths + shapes) ===")
|
| 55 |
+
|
| 56 |
+
def print_param_tree(tree, prefix='', indent=0):
    """Return the lines of an indented outline of a parameter tree.

    Despite the name, nothing is printed here; the caller prints or embeds
    the returned lines.

    Args:
        tree: Nested mapping of names to sub-mappings or leaves. Any
            ``collections.abc.Mapping`` is accepted (not only ``dict``), so
            immutable mapping containers are walked as well.
        prefix: Dotted path of the enclosing keys. It is threaded through
            the recursion but does not appear in the output; it is kept so
            existing callers keep working.
        indent: Nesting depth; each level indents by one space.

    Returns:
        list[str]: One line per key, sorted by key at every level. A
        non-mapping ``tree`` yields an empty list.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from collections.abc import Mapping

    lines = []
    if not isinstance(tree, Mapping):
        return lines

    pad = ' ' * indent
    for k, v in sorted(tree.items()):
        if isinstance(v, Mapping):
            # Sub-module: emit a header line, then its children one level deeper.
            lines.append(pad + f'{k}:')
            lines.extend(print_param_tree(v, prefix=prefix + k + '.', indent=indent + 1))
            continue
        # Leaf: describe array-like values by shape/dtype; for anything else
        # fall back to its repr and an unknown dtype marker.
        shape = tuple(v.shape) if hasattr(v, 'shape') else repr(v)
        dtype = str(v.dtype) if hasattr(v, 'dtype') else '?'
        lines.append(pad + f'{k}: shape={shape}, dtype={dtype}')
    return lines
|
| 69 |
+
|
| 70 |
+
param_lines = print_param_tree(params)
|
| 71 |
+
for line in param_lines:
|
| 72 |
+
print(line)
|
| 73 |
+
|
| 74 |
+
# Build shape dict for JSON serialization
|
| 75 |
+
def build_shape_dict(tree):
    """Convert a nested parameter tree into a JSON-serializable shape summary.

    Args:
        tree: Either a mapping of names to sub-trees, or a leaf exposing
            ``.shape`` and ``.dtype``. Any ``collections.abc.Mapping`` is
            treated as a branch (not only ``dict``), so immutable mapping
            containers are not mistaken for leaves.

    Returns:
        dict: For a branch, a plain ``dict`` with the same keys (inserted in
        sorted order) whose values are the converted sub-trees. For a leaf,
        ``{'shape': [...], 'dtype': '<dtype string>'}``.

    Raises:
        AttributeError: If a leaf has no ``shape`` or ``dtype`` attribute.
    """
    # Local import keeps this block self-contained without touching the
    # file-level import section.
    from collections.abc import Mapping

    if isinstance(tree, Mapping):
        return {k: build_shape_dict(v) for k, v in sorted(tree.items())}
    # Leaf: shape becomes a list so json.dumps can serialize it.
    return {'shape': list(tree.shape), 'dtype': str(tree.dtype)}
|
| 80 |
+
|
| 81 |
+
shape_dict = build_shape_dict(params)
|
| 82 |
+
|
| 83 |
+
# Count parameters
|
| 84 |
+
total_params = sum(x.size for x in jax.tree.leaves(params))
|
| 85 |
+
print(f"\nTotal parameters: {total_params:,}")
|
| 86 |
+
|
| 87 |
+
# ── 3. Tokenizer info ────────────────────────────────────────────────────────
|
| 88 |
+
print("\n=== Tokenizer ===")
|
| 89 |
+
tok = get_tokenizer()
|
| 90 |
+
print(f"Class: {type(tok).__name__}")
|
| 91 |
+
print(f"vocab_size: {tok.vocab_size}")
|
| 92 |
+
print(f"pad_token_id: {tok.pad_token_id}")
|
| 93 |
+
print(f"eos_token_id: {tok.eos_token_id}")
|
| 94 |
+
print(f"bos_token_id: {tok.bos_token_id}")
|
| 95 |
+
print(f"tool_call_token_id: {tok.tool_call_token_id}")
|
| 96 |
+
print(f"tools_token_id: {tok.tools_token_id}")
|
| 97 |
+
print(f"attrs: {[a for a in dir(tok) if not a.startswith('_')]}")
|
| 98 |
+
|
| 99 |
+
# Test encode/decode
|
| 100 |
+
sample_text = "set a 5 min timer"
|
| 101 |
+
encoded = tok.encode(sample_text)
|
| 102 |
+
decoded = tok.decode(encoded)
|
| 103 |
+
print(f"\nExample encode('{sample_text}'):")
|
| 104 |
+
print(f" -> {encoded}")
|
| 105 |
+
print(f" -> decode -> '{decoded}'")
|
| 106 |
+
|
| 107 |
+
# Get vocab size from SentencePiece directly
|
| 108 |
+
sp = tok.sp
|
| 109 |
+
print(f"\nSentencePiece GetPieceSize(): {sp.GetPieceSize()}")
|
| 110 |
+
print(f"Model type: BPE (confirmed in train_tokenizer source)")
|
| 111 |
+
print(f"byte_fallback: True (confirmed in train_tokenizer source)")
|
| 112 |
+
print(f"normalization_rule_name: identity (no Unicode normalization)")
|
| 113 |
+
|
| 114 |
+
# ── 4. Prompt format capture ─────────────────────────────────────────────────
|
| 115 |
+
print("\n=== Prompt Format Capture ===")
|
| 116 |
+
|
| 117 |
+
query = "set a 5 min timer"
|
| 118 |
+
tools = [{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration", "required": True}}}]
|
| 119 |
+
tools_json = json.dumps(tools)
|
| 120 |
+
|
| 121 |
+
captured_inputs = []
|
| 122 |
+
original_encode = tok.encode
|
| 123 |
+
|
| 124 |
+
def capture_encode(text, *args, **kwargs):
    """Record the text passed to the tokenizer, then encode it for real.

    Appends an ``('encode', text)`` tuple to the module-level
    ``captured_inputs`` list and delegates to ``original_encode`` with the
    same arguments, returning whatever it returns.
    """
    record = ('encode', text)
    captured_inputs.append(record)
    result = original_encode(text, *args, **kwargs)
    return result
|
| 127 |
+
|
| 128 |
+
tok.encode = capture_encode
|
| 129 |
+
|
| 130 |
+
# Call _build_encoder_input directly
|
| 131 |
+
enc_tokens = _build_encoder_input(tok, query, tools_json)
|
| 132 |
+
print(f"Captured encode calls: {captured_inputs}")
|
| 133 |
+
print(f"Encoder token IDs: {enc_tokens}")
|
| 134 |
+
print(f"Encoder token count: {len(enc_tokens)}")
|
| 135 |
+
|
| 136 |
+
# Decode the encoder input to see the full prompt
|
| 137 |
+
decoded_enc = tok.decode(enc_tokens)
|
| 138 |
+
print(f"Decoded encoder input: {repr(decoded_enc)}")
|
| 139 |
+
|
| 140 |
+
# Also decode each segment
|
| 141 |
+
tok.encode = original_encode # restore
|
| 142 |
+
q_toks = tok.encode(query)
|
| 143 |
+
t_toks = tok.encode(tools_json)
|
| 144 |
+
print(f"\nQuery tokens: {q_toks} -> '{tok.decode(q_toks)}'")
|
| 145 |
+
print(f"tools_token_id (separator): {TOOLS_ID}")
|
| 146 |
+
print(f"Tools tokens: {t_toks[:20]}... -> '{tok.decode(t_toks[:20])}'")
|
| 147 |
+
|
| 148 |
+
# ── 5. Write notes files ─────────────────────────────────────────────────────
|
| 149 |
+
|
| 150 |
+
NOTES_DIR = os.path.join(os.path.dirname(__file__), '..', 'notes')
|
| 151 |
+
os.makedirs(NOTES_DIR, exist_ok=True)
|
| 152 |
+
|
| 153 |
+
# ── 5a. needle-internals.md ──────────────────────────────────────────────────
|
| 154 |
+
|
| 155 |
+
def fmt_shape_dict(d, indent=0):
    """Render a shape summary dict as the lines of a nested markdown list.

    Args:
        d: Mapping of names to either nested dicts (branches) or leaf dicts
            that carry a ``'shape'`` key (and optionally ``'dtype'``).
        indent: Nesting depth; each level adds one space of indentation.

    Returns:
        list[str]: Bullet lines in key-sorted order. Branch names are bold,
        leaf names are in backticks followed by their shape and dtype
        (``'?'`` when a field is missing).
    """
    pad = ' ' * indent
    out = []
    for name in sorted(d):
        entry = d[name]
        is_branch = isinstance(entry, dict) and 'shape' not in entry
        if is_branch:
            out.append(f"{pad}- **{name}**")
            out += fmt_shape_dict(entry, indent + 1)
            continue
        shape, dtype = entry.get('shape', '?'), entry.get('dtype', '?')
        out.append(f"{pad} - `{name}`: shape={shape}, dtype={dtype}")
    return out
|
| 168 |
+
|
| 169 |
+
needle_internals = f"""# Needle (Cactus) Flax Model Internals
|
| 170 |
+
|
| 171 |
+
## Model: `SimpleAttentionNetwork`
|
| 172 |
+
|
| 173 |
+
**Architecture:** Encoderβdecoder transformer with shared embeddings, tied output projection, RoPE, bfloat16.
|
| 174 |
+
|
| 175 |
+
### `TransformerConfig` Defaults
|
| 176 |
+
|
| 177 |
+
| Field | Value | Notes |
|
| 178 |
+
|---|---|---|
|
| 179 |
+
| vocab_size | {config.vocab_size} | SentencePiece BPE vocab |
|
| 180 |
+
| d_model | {config.d_model} | Hidden dimension |
|
| 181 |
+
| num_heads | {config.num_heads} | Attention heads |
|
| 182 |
+
| num_kv_heads | {config.num_kv_heads} | GQA key/value heads |
|
| 183 |
+
| num_encoder_layers | {config.num_encoder_layers} | Encoder depth |
|
| 184 |
+
| num_decoder_layers | {config.num_decoder_layers} | Decoder depth |
|
| 185 |
+
| d_ff | {config.d_ff} | FFN hidden dim (unused by default) |
|
| 186 |
+
| max_seq_len | {config.max_seq_len} | |
|
| 187 |
+
| pad_token_id | {config.pad_token_id} | |
|
| 188 |
+
| rope_theta | {config.rope_theta} | RoPE base frequency |
|
| 189 |
+
| dtype | {config.dtype} | |
|
| 190 |
+
| activation | {config.activation} | dual-ReLU gate (drelu) |
|
| 191 |
+
| num_memory_slots | {config.num_memory_slots} | |
|
| 192 |
+
| dropout_rate | {config.dropout_rate} | |
|
| 193 |
+
| contrastive_dim | {config.contrastive_dim} | Contrastive projection output dim |
|
| 194 |
+
| no_feedforward | {config.no_feedforward} | **True β FFN blocks are SKIPPED** |
|
| 195 |
+
|
| 196 |
+
**Total parameters:** {total_params:,}
|
| 197 |
+
|
| 198 |
+
**Important:** `no_feedforward=True` by default β `FeedForward` / `gate_proj` / `up_proj` / `down_proj` are NOT present in the saved param tree. The `ffn_gate` scalar also absent. This is a pure attention-only transformer.
|
| 199 |
+
|
| 200 |
+
---
|
| 201 |
+
|
| 202 |
+
## Layer-by-Layer Architecture
|
| 203 |
+
|
| 204 |
+
### Encoder (`Encoder` module)
|
| 205 |
+
|
| 206 |
+
Uses `nn.scan` over `num_encoder_layers={config.num_encoder_layers}` layers. Each `EncoderBlock` contains:
|
| 207 |
+
|
| 208 |
+
1. **`attn_gate`** β scalar sigmoid gate `()` on attention residual
|
| 209 |
+
2. **Pre-norm**: `ZCRMSNorm` (scale initialized to 0, applied as `(1+Ξ³) * x / RMS(x)`)
|
| 210 |
+
3. **`self_attn`** (`MultiHeadAttention`):
|
| 211 |
+
- `q_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
|
| 212 |
+
- `k_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
|
| 213 |
+
- `v_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
|
| 214 |
+
- `out_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
|
| 215 |
+
- `q_norm`, `k_norm`: `ZCRMSNorm` on Q and K (pre-RoPE)
|
| 216 |
+
4. Residual: `x = x_in + gate * attn_out`
|
| 217 |
+
5. (FFN block **skipped** because `no_feedforward=True`)
|
| 218 |
+
6. **`final_norm`**: `ZCRMSNorm` after all layers (on encoder output)
|
| 219 |
+
|
| 220 |
+
### Decoder (`Decoder` module)
|
| 221 |
+
|
| 222 |
+
Uses `nn.scan` over `num_decoder_layers={config.num_decoder_layers}` layers. Each `DecoderBlock` contains:
|
| 223 |
+
|
| 224 |
+
1. **`self_attn_gate`** β scalar on causal self-attention residual
|
| 225 |
+
2. Pre-norm + **`self_attn`** (causal self-attention, same structure as encoder)
|
| 226 |
+
3. Residual: `x = x_in + self_gate * self_attn_out`
|
| 227 |
+
4. **`cross_attn_gate`** β scalar on cross-attention residual
|
| 228 |
+
5. Pre-norm + **`cross_attn`** (cross-attention: Q from decoder, K/V from encoder)
|
| 229 |
+
6. Residual: `x = x_in + cross_gate * cross_attn_out`
|
| 230 |
+
7. (FFN block **skipped**)
|
| 231 |
+
8. Final `ZCRMSNorm` after all layers
|
| 232 |
+
|
| 233 |
+
### Top-level `SimpleAttentionNetwork` params
|
| 234 |
+
|
| 235 |
+
- **`embedding`**: `nn.Embed` kernel `(vocab_size, d_model)` = `({config.vocab_size}, {config.d_model})`
|
| 236 |
+
Also used as tied output: `logits = hidden @ embedding.T`
|
| 237 |
+
- **`log_temp`**: scalar `()` β contrastive temperature
|
| 238 |
+
- **`contrastive_hidden`**: Dense `(d_model, d_model//4)` = `({config.d_model}, {config.d_model//4})`
|
| 239 |
+
- **`contrastive_proj`**: Dense `(d_model//4, contrastive_dim)` = `({config.d_model//4}, {config.contrastive_dim})`, no bias
|
| 240 |
+
|
| 241 |
+
---
|
| 242 |
+
|
| 243 |
+
## `nn.scan` Parameter Layout
|
| 244 |
+
|
| 245 |
+
Because Flax `nn.scan` with `variable_axes={{"params": 0}}` is used, each scanned layer stack is stored as a **single** parameter array with an extra leading dimension of size `num_layers`. So for encoder: `layers.attn_gate` has shape `(num_encoder_layers, ...)` = `({config.num_encoder_layers}, ...)`.
|
| 246 |
+
|
| 247 |
+
**This is critical for Task 2C:** When mapping Flax β PyTorch weights, you must index `params['encoder']['layers']['...'][layer_idx]` to get per-layer weights.
|
| 248 |
+
|
| 249 |
+
---
|
| 250 |
+
|
| 251 |
+
## Full Flax `params` Tree (shapes)
|
| 252 |
+
|
| 253 |
+
```json
|
| 254 |
+
{json.dumps(shape_dict, indent=2)}
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
+
---
|
| 258 |
+
|
| 259 |
+
## `SimpleAttentionNetwork.__call__` Source
|
| 260 |
+
|
| 261 |
+
```python
|
| 262 |
+
{inspect.getsource(SimpleAttentionNetwork.__call__)}
|
| 263 |
+
```
|
| 264 |
+
|
| 265 |
+
## `EncoderBlock.__call__` Source
|
| 266 |
+
|
| 267 |
+
```python
|
| 268 |
+
{inspect.getsource(EncoderBlock.__call__)}
|
| 269 |
+
```
|
| 270 |
+
|
| 271 |
+
## `DecoderBlock.__call__` Source
|
| 272 |
+
|
| 273 |
+
```python
|
| 274 |
+
{inspect.getsource(DecoderBlock.__call__)}
|
| 275 |
+
```
|
| 276 |
+
|
| 277 |
+
## `MultiHeadAttention.__call__` Source
|
| 278 |
+
|
| 279 |
+
```python
|
| 280 |
+
{inspect.getsource(MultiHeadAttention.__call__)}
|
| 281 |
+
```
|
| 282 |
+
"""
|
| 283 |
+
|
| 284 |
+
# The generated markdown contains non-ASCII characters, so write it as UTF-8
# explicitly instead of relying on the platform's locale-dependent default.
with open(os.path.join(NOTES_DIR, 'needle-internals.md'), 'w', encoding='utf-8') as f:
    f.write(needle_internals)
|
| 286 |
+
print(f"\nWrote needle-internals.md ({len(needle_internals.splitlines())} lines)")
|
| 287 |
+
|
| 288 |
+
# ── 5b. tokenizer-internals.md ───────────────────────────────────────────────
|
| 289 |
+
|
| 290 |
+
tokenizer_internals = f"""# Needle Tokenizer Internals
|
| 291 |
+
|
| 292 |
+
## Tokenizer Class
|
| 293 |
+
|
| 294 |
+
**Class:** `NeedleTokenizer` (defined in `external/needle/needle/dataset/tokenizer.py`)
|
| 295 |
+
|
| 296 |
+
**Backend:** SentencePiece BPE β the model file is downloaded from HuggingFace on first use.
|
| 297 |
+
|
| 298 |
+
**HuggingFace repo:** `Cactus-Compute/needle-tokenizer` (dataset repo)
|
| 299 |
+
|
| 300 |
+
**Local model file path:** `{TOKENIZER_PREFIX}.model`
|
| 301 |
+
(relative to the external/needle repo root: `needle/tokenizer/needle.model`)
|
| 302 |
+
|
| 303 |
+
---
|
| 304 |
+
|
| 305 |
+
## Tokenizer Properties
|
| 306 |
+
|
| 307 |
+
| Property | Value | Source |
|
| 308 |
+
|---|---|---|
|
| 309 |
+
| `vocab_size` | {tok.vocab_size} | `sp.GetPieceSize()` |
|
| 310 |
+
| `pad_token_id` | {PAD_ID} | hardcoded = 0 |
|
| 311 |
+
| `eos_token_id` | {EOS_ID} | hardcoded = 1 |
|
| 312 |
+
| `bos_token_id` | {BOS_ID} | hardcoded = 2 |
|
| 313 |
+
| `unk_id` | {UNK_ID} | hardcoded = 3 |
|
| 314 |
+
| `tool_call_token_id` | {TOOL_CALL_ID} | hardcoded = 4, symbol `<tool_call>` |
|
| 315 |
+
| `tools_token_id` | {TOOLS_ID} | hardcoded = 5, symbol `<tools>` |
|
| 316 |
+
|
| 317 |
+
---
|
| 318 |
+
|
| 319 |
+
## Encode Algorithm
|
| 320 |
+
|
| 321 |
+
**Algorithm:** SentencePiece **BPE** (Byte-Pair Encoding)
|
| 322 |
+
|
| 323 |
+
**Configured with:**
|
| 324 |
+
- `model_type="bpe"`
|
| 325 |
+
- `byte_fallback=True` β unknown bytes represented as `<0xNN>` hex tokens
|
| 326 |
+
- `normalization_rule_name="identity"` β **no Unicode normalization applied**
|
| 327 |
+
- `user_defined_symbols=["<tool_call>", "<tools>"]` β fixed IDs 4 and 5
|
| 328 |
+
|
| 329 |
+
**Encode call in NeedleTokenizer:**
|
| 330 |
+
```python
|
| 331 |
+
def encode(self, text):
|
| 332 |
+
return self.sp.Encode(text, out_type=int)
|
| 333 |
+
```
|
| 334 |
+
Returns a Python `list[int]`. Does NOT add BOS/EOS automatically.
|
| 335 |
+
|
| 336 |
+
**Decode call:**
|
| 337 |
+
```python
|
| 338 |
+
def decode(self, ids):
|
| 339 |
+
return self.sp.Decode(list(ids))
|
| 340 |
+
```
|
| 341 |
+
|
| 342 |
+
---
|
| 343 |
+
|
| 344 |
+
## Available Attributes on `NeedleTokenizer`
|
| 345 |
+
|
| 346 |
+
```
|
| 347 |
+
{[a for a in dir(tok) if not a.startswith('_')]}
|
| 348 |
+
```
|
| 349 |
+
|
| 350 |
+
The underlying SentencePiece processor is accessible as `tok.sp`.
|
| 351 |
+
|
| 352 |
+
---
|
| 353 |
+
|
| 354 |
+
## SentencePiece Model File
|
| 355 |
+
|
| 356 |
+
The `.model` file is a serialized SentencePiece protobuf. For Task 5 (port tokenizer to JS/ONNX/etc.), you need either:
|
| 357 |
+
1. The `.model` file directly (loadable by the `sentencepiece` library)
|
| 358 |
+
2. Export the vocabulary + merge rules via `tok.sp.GetPieceSize()`, `tok.sp.IdToPiece(i)`, etc.
|
| 359 |
+
|
| 360 |
+
To export the full vocabulary:
|
| 361 |
+
```python
|
| 362 |
+
vocab = [(i, tok.sp.IdToPiece(i), tok.sp.GetScore(i)) for i in range(tok.vocab_size)]
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
---
|
| 366 |
+
|
| 367 |
+
## Example Encoding
|
| 368 |
+
|
| 369 |
+
```
|
| 370 |
+
encode("{sample_text}")
|
| 371 |
+
-> {encoded}
|
| 372 |
+
-> decode -> "{decoded}"
|
| 373 |
+
```
|
| 374 |
+
"""
|
| 375 |
+
|
| 376 |
+
# The generated markdown contains non-ASCII characters, so write it as UTF-8
# explicitly instead of relying on the platform's locale-dependent default.
with open(os.path.join(NOTES_DIR, 'tokenizer-internals.md'), 'w', encoding='utf-8') as f:
    f.write(tokenizer_internals)
|
| 378 |
+
print(f"Wrote tokenizer-internals.md ({len(tokenizer_internals.splitlines())} lines)")
|
| 379 |
+
|
| 380 |
+
# ── 5c. prompt-format.md ─────────────────────────────────────────────────────
|
| 381 |
+
|
| 382 |
+
# Get the decoded full prompt
|
| 383 |
+
tok.encode = capture_encode
|
| 384 |
+
enc_tokens2 = _build_encoder_input(tok, query, tools_json)
|
| 385 |
+
tok.encode = original_encode # restore
|
| 386 |
+
|
| 387 |
+
# Reconstruct what each segment contains
|
| 388 |
+
q_toks_raw = tok.encode(query)
|
| 389 |
+
t_toks_raw = tok.encode(tools_json)
|
| 390 |
+
|
| 391 |
+
MAX_ENC = DEFAULT_MAX_ENC_LEN # 1024
|
| 392 |
+
max_query = MAX_ENC - 2
|
| 393 |
+
q_toks_trunc = q_toks_raw[:max_query]
|
| 394 |
+
remaining = MAX_ENC - len(q_toks_trunc) - 1
|
| 395 |
+
t_toks_trunc = t_toks_raw[:remaining]
|
| 396 |
+
|
| 397 |
+
final_tokens = q_toks_trunc + [TOOLS_ID] + t_toks_trunc
|
| 398 |
+
decoded_full = tok.decode(final_tokens)
|
| 399 |
+
|
| 400 |
+
# Decoder starts with EOS token (1), then model predicts <tool_call> (4) first
|
| 401 |
+
prompt_format = f"""# Needle Prompt Format
|
| 402 |
+
|
| 403 |
+
## Overview
|
| 404 |
+
|
| 405 |
+
Needle uses an **encoderβdecoder** architecture. The query + tools are encoded by the encoder; the decoder generates the tool-call JSON.
|
| 406 |
+
|
| 407 |
+
## Encoder Input Format
|
| 408 |
+
|
| 409 |
+
```
|
| 410 |
+
[query_tokens..., <tools>(id={TOOLS_ID}), tools_tokens...]
|
| 411 |
+
```
|
| 412 |
+
|
| 413 |
+
**Built by `_build_encoder_input(tokenizer, query, tools, max_enc_len={DEFAULT_MAX_ENC_LEN})`:**
|
| 414 |
+
|
| 415 |
+
```python
|
| 416 |
+
{inspect.getsource(_build_encoder_input)}
|
| 417 |
+
```
|
| 418 |
+
|
| 419 |
+
### Truncation Rules
|
| 420 |
+
|
| 421 |
+
1. Query is truncated to `max_enc_len - 2 = {MAX_ENC - 2}` tokens
|
| 422 |
+
2. Tools are truncated to `max_enc_len - len(q_toks) - 1` tokens (fills remaining space)
|
| 423 |
+
3. The `<tools>` separator token (id={TOOLS_ID}) is always inserted between query and tools
|
| 424 |
+
|
| 425 |
+
### Example: `query="set a 5 min timer"`, `tools=[{{"name": "set_timer", ...}}]`
|
| 426 |
+
|
| 427 |
+
**Encode calls made by `_build_encoder_input`:**
|
| 428 |
+
1. `tokenizer.encode("{query}")` β {q_toks_raw} ({len(q_toks_raw)} tokens)
|
| 429 |
+
2. `tokenizer.encode('{tools_json[:80]}...')` β {t_toks_raw[:20]}... ({len(t_toks_raw)} tokens total)
|
| 430 |
+
|
| 431 |
+
**Final encoder token sequence (len={len(final_tokens)}):**
|
| 432 |
+
```
|
| 433 |
+
{final_tokens}
|
| 434 |
+
```
|
| 435 |
+
|
| 436 |
+
**Decoded encoder input string:**
|
| 437 |
+
```
|
| 438 |
+
{repr(decoded_full)}
|
| 439 |
+
```
|
| 440 |
+
|
| 441 |
+
**Structural breakdown:**
|
| 442 |
+
- Query tokens ({len(q_toks_trunc)} tokens): `{q_toks_trunc}`
|
| 443 |
+
- Separator token (1 token): `[{TOOLS_ID}]` β `<tools>`
|
| 444 |
+
- Tools tokens ({len(t_toks_trunc)} tokens): `{t_toks_trunc}`
|
| 445 |
+
|
| 446 |
+
---
|
| 447 |
+
|
| 448 |
+
## Decoder Input Format
|
| 449 |
+
|
| 450 |
+
The decoder is initialized with a single **EOS token** (id={EOS_ID}) as prefix:
|
| 451 |
+
|
| 452 |
+
```python
|
| 453 |
+
dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32)
|
| 454 |
+
dec_buffer = dec_buffer.at[0, 0].set(eos_id) # prefix = [EOS]
|
| 455 |
+
```
|
| 456 |
+
|
| 457 |
+
At each step, the model predicts the next token autoregressively.
|
| 458 |
+
|
| 459 |
+
**Expected output sequence:** `<tool_call>(id={TOOL_CALL_ID}) + answer_json_tokens + EOS(id={EOS_ID})`
|
| 460 |
+
|
| 461 |
+
The `<tool_call>` prefix is stripped from the decoded output string in `generate()`.
|
| 462 |
+
|
| 463 |
+
---
|
| 464 |
+
|
| 465 |
+
## Tool Name Normalization
|
| 466 |
+
|
| 467 |
+
Before building encoder input, tool names are converted to `snake_case` via `normalize_tools()`:
|
| 468 |
+
- `camelCase` β `camel_case`
|
| 469 |
+
- `PascalCase` β `pascal_case`
|
| 470 |
+
- `dot.notation` β `dot_notation`
|
| 471 |
+
- Non-alphanumeric chars β underscores
|
| 472 |
+
|
| 473 |
+
After decoding, `restore_tool_names()` maps back to original names.
|
| 474 |
+
|
| 475 |
+
---
|
| 476 |
+
|
| 477 |
+
## Full `generate()` Signature
|
| 478 |
+
|
| 479 |
+
```python
|
| 480 |
+
{inspect.getsource(generate)}
|
| 481 |
+
```
|
| 482 |
+
"""
|
| 483 |
+
|
| 484 |
+
# The generated markdown contains non-ASCII characters, so write it as UTF-8
# explicitly instead of relying on the platform's locale-dependent default.
with open(os.path.join(NOTES_DIR, 'prompt-format.md'), 'w', encoding='utf-8') as f:
    f.write(prompt_format)
|
| 486 |
+
print(f"Wrote prompt-format.md ({len(prompt_format.splitlines())} lines)")
|
| 487 |
+
|
| 488 |
+
print("\n=== All notes written successfully ===")
|