| """ |
| inspect_needle.py — Introspect the Cactus/Needle Flax model, tokenizer, and prompt format. |
| |
| Outputs three markdown files: |
| ../notes/needle-internals.md |
| ../notes/tokenizer-internals.md |
| ../notes/prompt-format.md |
| """ |
|
|
| import json |
| import os |
| import sys |
| import textwrap |
| import inspect |
| import traceback |
|
|
| |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'external', 'needle')) |
|
|
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
|
|
| |
| from needle.model.architecture import ( |
| SimpleAttentionNetwork, TransformerConfig, |
| MultiHeadAttention, FeedForward, EncoderBlock, DecoderBlock, |
| Encoder, Decoder, ZCRMSNorm, make_causal_mask, make_padding_mask, |
| ) |
| from needle.model.run import generate, _build_encoder_input, load_checkpoint |
| from needle.dataset.tokenizer import ( |
| NeedleTokenizer, get_tokenizer, |
| PAD_ID, EOS_ID, BOS_ID, UNK_ID, TOOL_CALL_ID, TOOLS_ID, |
| DEFAULT_MAX_ENC_LEN, DEFAULT_MAX_DEC_LEN, DEFAULT_MAX_GEN_LEN, |
| TOKENIZER_PREFIX, |
| ) |
|
|
print("=== Needle/Cactus Inspector ===\n")

# Echo every dataclass field of the default TransformerConfig.
config = TransformerConfig()
print("TransformerConfig defaults:")
for field_name in config.__dataclass_fields__:
    field_value = getattr(config, field_name)
    print(f" {field_name} = {field_value}")

model = SimpleAttentionNetwork(config)

# model.init is given all-zero int32 token batches of shape (1, 16) for both
# source and target, purely so it can materialise the parameters.
# NOTE(review): assumes parameter shapes do not depend on the 16-token length.
dummy_src, dummy_tgt = (jnp.zeros((1, 16), dtype=jnp.int32) for _ in range(2))
variables = model.init(
    jax.random.PRNGKey(0),
    dummy_src,
    dummy_tgt,
    method='init_all',
)
params = variables['params']

print("\n=== Param Tree (paths + shapes) ===")
|
|
def print_param_tree(tree, prefix='', indent=0):
    """Render a nested parameter tree as indented text lines.

    Despite its name this function does not print anything; it returns the
    lines so the caller can print or post-process them.

    Args:
        tree: Nested mapping of parameter names to arrays. Any
            ``collections.abc.Mapping`` is accepted (plain ``dict`` as well as
            immutable mappings such as Flax ``FrozenDict``).
        prefix: Dotted path of ``tree`` within the root. It is propagated to
            recursive calls but is not included in the output.
        indent: Nesting depth; each level adds one leading space.

    Returns:
        A list of strings with keys sorted at every level. Sub-trees produce
        ``"name:"`` and leaves produce ``"name: shape=..., dtype=..."``.
        A non-mapping ``tree`` yields an empty list.
    """
    from collections.abc import Mapping

    lines = []
    if not isinstance(tree, Mapping):
        return lines
    pad = ' ' * indent
    for k, v in sorted(tree.items()):
        if isinstance(v, Mapping):
            lines.append(f'{pad}{k}:')
            lines.extend(print_param_tree(v, prefix=f'{prefix}{k}.', indent=indent + 1))
        else:
            # Leaves without .shape/.dtype are still listed rather than crashing.
            shape = tuple(v.shape) if hasattr(v, 'shape') else repr(v)
            dtype = str(v.dtype) if hasattr(v, 'dtype') else '?'
            lines.append(f'{pad}{k}: shape={shape}, dtype={dtype}')
    return lines
|
|
param_lines = print_param_tree(params)
# Emit the rendered tree; nothing at all is printed when the tree is empty.
if param_lines:
    print('\n'.join(param_lines))
|
|
| |
def build_shape_dict(tree):
    """Convert a nested parameter tree into a JSON-serialisable shape summary.

    Args:
        tree: Nested mapping whose leaves are arrays. Any
            ``collections.abc.Mapping`` is accepted (plain ``dict`` or an
            immutable mapping such as Flax ``FrozenDict``).

    Returns:
        A dict mirroring ``tree`` with keys sorted at every level. Each leaf
        becomes ``{'shape': [...], 'dtype': '...'}``. Leaves lacking
        ``shape``/``dtype`` fall back to ``repr(leaf)`` / ``'?'``, matching
        ``print_param_tree``.
    """
    from collections.abc import Mapping

    if isinstance(tree, Mapping):
        return {k: build_shape_dict(v) for k, v in sorted(tree.items())}
    shape = list(tree.shape) if hasattr(tree, 'shape') else repr(tree)
    dtype = str(tree.dtype) if hasattr(tree, 'dtype') else '?'
    return {'shape': shape, 'dtype': dtype}
|
|
shape_dict = build_shape_dict(params)

# Total number of scalar weights = sum of element counts over all pytree leaves.
leaf_sizes = [leaf.size for leaf in jax.tree.leaves(params)]
total_params = sum(leaf_sizes)
print(f"\nTotal parameters: {total_params:,}")
|
|
| |
print("\n=== Tokenizer ===")
tok = get_tokenizer()
print(f"Class: {type(tok).__name__}")
# Special-token ids and vocab size as reported by the tokenizer object itself.
for attr_name in (
    'vocab_size',
    'pad_token_id',
    'eos_token_id',
    'bos_token_id',
    'tool_call_token_id',
    'tools_token_id',
):
    print(f"{attr_name}: {getattr(tok, attr_name)}")
print(f"attrs: {[name for name in dir(tok) if not name.startswith('_')]}")

# Round-trip a short sample through encode/decode.
sample_text = "set a 5 min timer"
encoded = tok.encode(sample_text)
decoded = tok.decode(encoded)
print(f"\nExample encode('{sample_text}'):")
print(f" -> {encoded}")
print(f" -> decode -> '{decoded}'")

# Underlying SentencePiece processor. The three notes below are static text,
# not values read back from the model file.
sp = tok.sp
print(f"\nSentencePiece GetPieceSize(): {sp.GetPieceSize()}")
for note in (
    "Model type: BPE (confirmed in train_tokenizer source)",
    "byte_fallback: True (confirmed in train_tokenizer source)",
    "normalization_rule_name: identity (no Unicode normalization)",
):
    print(note)
|
|
| |
print("\n=== Prompt Format Capture ===")

# Fixed example request: one query plus a single timer tool schema.
query = "set a 5 min timer"
tools = [
    {
        "name": "set_timer",
        "description": "Set a timer.",
        "parameters": {
            "time_human": {"type": "string", "description": "duration", "required": True},
        },
    }
]
tools_json = json.dumps(tools)

# Log of ('encode', text) pairs recorded by the instrumented encode().
captured_inputs = []
# Keep the tokenizer's original encode callable so it can be restored later.
original_encode = tok.encode
|
|
def capture_encode(text, *a, **kw):
    """Record ``text`` in ``captured_inputs``, then delegate to ``original_encode``.

    Meant to be assigned over ``tok.encode`` so calls made from library code
    can be observed. The call is recorded before delegating, and the original
    encoder's return value is passed through unchanged.
    """
    entry = ('encode', text)
    captured_inputs.append(entry)
    return original_encode(text, *a, **kw)
|
|
# Run the real encoder-input builder with encode() instrumented, then show the
# raw ids and their decoded text.
tok.encode = capture_encode
enc_tokens = _build_encoder_input(tok, query, tools_json)
for label, value in (
    ("Captured encode calls", captured_inputs),
    ("Encoder token IDs", enc_tokens),
    ("Encoder token count", len(enc_tokens)),
):
    print(f"{label}: {value}")

decoded_enc = tok.decode(enc_tokens)
print(f"Decoded encoder input: {decoded_enc!r}")

# Put the real encode() back and tokenise the two segments on their own.
tok.encode = original_encode
q_toks = tok.encode(query)
t_toks = tok.encode(tools_json)
t_head = t_toks[:20]
print(f"\nQuery tokens: {q_toks} -> '{tok.decode(q_toks)}'")
print(f"tools_token_id (separator): {TOOLS_ID}")
print(f"Tools tokens: {t_head}... -> '{tok.decode(t_head)}'")
|
|
| |
|
|
# Markdown notes are written to <script dir>/../notes, created on demand.
NOTES_DIR = os.path.join(os.path.dirname(__file__), os.pardir, 'notes')
os.makedirs(NOTES_DIR, exist_ok=True)
|
|
| |
|
|
def fmt_shape_dict(d, indent=0):
    """Format a nested shape dict (as built by ``build_shape_dict``) as a markdown list.

    Args:
        d: Dict whose values are either sub-dicts (parameter groups) or leaf
            dicts carrying a ``'shape'`` key (and usually ``'dtype'``).
        indent: Nesting depth of ``d``. Each level is indented by two spaces,
            the minimum CommonMark needs to treat a ``- `` item as nested.

    Returns:
        List of markdown lines with keys sorted at every level. Groups render
        as ``- **group**`` and leaves as ``- `name`: shape=..., dtype=...``.
        Siblings share the same indentation whether they are groups or leaves.
    """
    lines = []
    prefix = '  ' * indent
    for k, v in sorted(d.items()):
        if isinstance(v, dict) and 'shape' not in v:
            lines.append(f"{prefix}- **{k}**")
            lines.extend(fmt_shape_dict(v, indent + 1))
        else:
            shape = v.get('shape', '?')
            dtype = v.get('dtype', '?')
            lines.append(f"{prefix}- `{k}`: shape={shape}, dtype={dtype}")
    return lines
|
|
# Derived attention projection widths used in the layer descriptions below.
head_dim = config.d_model // config.num_heads
kv_dim = config.num_kv_heads * head_dim

needle_internals = f"""# Needle (Cactus) Flax Model Internals

## Model: `SimpleAttentionNetwork`

**Architecture:** Encoder–decoder transformer with shared embeddings, tied output projection, RoPE, bfloat16.

### `TransformerConfig` Defaults

| Field | Value | Notes |
|---|---|---|
| vocab_size | {config.vocab_size} | SentencePiece BPE vocab |
| d_model | {config.d_model} | Hidden dimension |
| num_heads | {config.num_heads} | Attention heads |
| num_kv_heads | {config.num_kv_heads} | GQA key/value heads |
| num_encoder_layers | {config.num_encoder_layers} | Encoder depth |
| num_decoder_layers | {config.num_decoder_layers} | Decoder depth |
| d_ff | {config.d_ff} | FFN hidden dim (unused by default) |
| max_seq_len | {config.max_seq_len} | |
| pad_token_id | {config.pad_token_id} | |
| rope_theta | {config.rope_theta} | RoPE base frequency |
| dtype | {config.dtype} | |
| activation | {config.activation} | dual-ReLU gate (drelu) |
| num_memory_slots | {config.num_memory_slots} | |
| dropout_rate | {config.dropout_rate} | |
| contrastive_dim | {config.contrastive_dim} | Contrastive projection output dim |
| no_feedforward | {config.no_feedforward} | **True → FFN blocks are SKIPPED** |

**Total parameters:** {total_params:,}

**Important:** `no_feedforward=True` by default — `FeedForward` / `gate_proj` / `up_proj` / `down_proj` are NOT present in the saved param tree. The `ffn_gate` scalar is also absent. This is a pure attention-only transformer.

---

## Layer-by-Layer Architecture

### Encoder (`Encoder` module)

Uses `nn.scan` over `num_encoder_layers={config.num_encoder_layers}` layers. Each `EncoderBlock` contains:

1. **`attn_gate`** — scalar sigmoid gate `()` on attention residual
2. **Pre-norm**: `ZCRMSNorm` (scale initialized to 0, applied as `(1+γ) * x / RMS(x)`)
3. **`self_attn`** (`MultiHeadAttention`):
   - `q_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `k_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {kv_dim})`
   - `v_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {kv_dim})`
   - `out_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `q_norm`, `k_norm`: `ZCRMSNorm` on Q and K (pre-RoPE)
4. Residual: `x = x_in + gate * attn_out`
5. (FFN block **skipped** because `no_feedforward=True`)
6. **`final_norm`**: `ZCRMSNorm` after all layers (on encoder output)

### Decoder (`Decoder` module)

Uses `nn.scan` over `num_decoder_layers={config.num_decoder_layers}` layers. Each `DecoderBlock` contains:

1. **`self_attn_gate`** — scalar on causal self-attention residual
2. Pre-norm + **`self_attn`** (causal self-attention, same structure as encoder)
3. Residual: `x = x_in + self_gate * self_attn_out`
4. **`cross_attn_gate`** — scalar on cross-attention residual
5. Pre-norm + **`cross_attn`** (cross-attention: Q from decoder, K/V from encoder)
6. Residual: `x = x_in + cross_gate * cross_attn_out`
7. (FFN block **skipped**)
8. Final `ZCRMSNorm` after all layers

### Top-level `SimpleAttentionNetwork` params

- **`embedding`**: `nn.Embed` kernel `(vocab_size, d_model)` = `({config.vocab_size}, {config.d_model})`
  Also used as tied output: `logits = hidden @ embedding.T`
- **`log_temp`**: scalar `()` — contrastive temperature
- **`contrastive_hidden`**: Dense `(d_model, d_model//4)` = `({config.d_model}, {config.d_model//4})`
- **`contrastive_proj`**: Dense `(d_model//4, contrastive_dim)` = `({config.d_model//4}, {config.contrastive_dim})`, no bias

---

## `nn.scan` Parameter Layout

Because Flax `nn.scan` with `variable_axes={{"params": 0}}` is used, each scanned layer stack is stored as a **single** parameter array with an extra leading dimension of size `num_layers`. So for encoder: `layers.attn_gate` has shape `(num_encoder_layers, ...)` = `({config.num_encoder_layers}, ...)`.

**This is critical for Task 2C:** When mapping Flax → PyTorch weights, you must index `params['encoder']['layers']['...'][layer_idx]` to get per-layer weights.

---

## Full Flax `params` Tree (shapes)

```json
{json.dumps(shape_dict, indent=2)}
```

---

## `SimpleAttentionNetwork.__call__` Source

```python
{inspect.getsource(SimpleAttentionNetwork.__call__)}
```

## `EncoderBlock.__call__` Source

```python
{inspect.getsource(EncoderBlock.__call__)}
```

## `DecoderBlock.__call__` Source

```python
{inspect.getsource(DecoderBlock.__call__)}
```

## `MultiHeadAttention.__call__` Source

```python
{inspect.getsource(MultiHeadAttention.__call__)}
```
"""

# The notes contain non-ASCII characters (—, →, γ), so force UTF-8 instead of
# relying on the platform's locale encoding (e.g. cp1252 on Windows would raise).
with open(os.path.join(NOTES_DIR, 'needle-internals.md'), 'w', encoding='utf-8') as f:
    f.write(needle_internals)
print(f"\nWrote needle-internals.md ({len(needle_internals.splitlines())} lines)")
|
|
| |
|
|
tokenizer_internals = f"""# Needle Tokenizer Internals

## Tokenizer Class

**Class:** `NeedleTokenizer` (defined in `external/needle/needle/dataset/tokenizer.py`)

**Backend:** SentencePiece BPE — the model file is downloaded from HuggingFace on first use.

**HuggingFace repo:** `Cactus-Compute/needle-tokenizer` (dataset repo)

**Local model file path:** `{TOKENIZER_PREFIX}.model`
(relative to the external/needle repo root: `needle/tokenizer/needle.model`)

---

## Tokenizer Properties

| Property | Value | Source |
|---|---|---|
| `vocab_size` | {tok.vocab_size} | `sp.GetPieceSize()` |
| `pad_token_id` | {PAD_ID} | hardcoded = 0 |
| `eos_token_id` | {EOS_ID} | hardcoded = 1 |
| `bos_token_id` | {BOS_ID} | hardcoded = 2 |
| `unk_id` | {UNK_ID} | hardcoded = 3 |
| `tool_call_token_id` | {TOOL_CALL_ID} | hardcoded = 4, symbol `<tool_call>` |
| `tools_token_id` | {TOOLS_ID} | hardcoded = 5, symbol `<tools>` |

---

## Encode Algorithm

**Algorithm:** SentencePiece **BPE** (Byte-Pair Encoding)

**Configured with:**
- `model_type="bpe"`
- `byte_fallback=True` — unknown bytes represented as `<0xNN>` hex tokens
- `normalization_rule_name="identity"` — **no Unicode normalization applied**
- `user_defined_symbols=["<tool_call>", "<tools>"]` — fixed IDs 4 and 5

**Encode call in NeedleTokenizer:**
```python
def encode(self, text):
    return self.sp.Encode(text, out_type=int)
```
Returns a Python `list[int]`. Does NOT add BOS/EOS automatically.

**Decode call:**
```python
def decode(self, ids):
    return self.sp.Decode(list(ids))
```

---

## Available Attributes on `NeedleTokenizer`

```
{[a for a in dir(tok) if not a.startswith('_')]}
```

The underlying SentencePiece processor is accessible as `tok.sp`.

---

## SentencePiece Model File

The `.model` file is a serialized SentencePiece protobuf. For Task 5 (port tokenizer to JS/ONNX/etc.), you need either:
1. The `.model` file directly (loadable by the `sentencepiece` library)
2. Export the vocabulary + merge rules via `tok.sp.GetPieceSize()`, `tok.sp.IdToPiece(i)`, etc.

To export the full vocabulary:
```python
vocab = [(i, tok.sp.IdToPiece(i), tok.sp.GetScore(i)) for i in range(tok.vocab_size)]
```

---

## Example Encoding

```
encode("{sample_text}")
  -> {encoded}
  -> decode -> "{decoded}"
```
"""

# Non-ASCII text (dashes, decoded tokens) requires an explicit UTF-8 encoding;
# the platform default (e.g. cp1252) cannot represent every character.
with open(os.path.join(NOTES_DIR, 'tokenizer-internals.md'), 'w', encoding='utf-8') as f:
    f.write(tokenizer_internals)
print(f"Wrote tokenizer-internals.md ({len(tokenizer_internals.splitlines())} lines)")
|
|
| |
|
|
| |
# Re-run the real encoder-input builder with encode() instrumented. The
# original encode() is restored even if the builder raises.
tok.encode = capture_encode
try:
    enc_tokens2 = _build_encoder_input(tok, query, tools_json)
finally:
    tok.encode = original_encode

# Raw (untruncated) token ids of each segment, encoded directly.
q_toks_raw = tok.encode(query)
t_toks_raw = tok.encode(tools_json)

# Independent re-derivation of the truncation rules documented in
# prompt-format.md. It is used only as a cross-check on the builder's output.
MAX_ENC = DEFAULT_MAX_ENC_LEN
max_query = MAX_ENC - 2
q_toks_trunc = q_toks_raw[:max_query]
remaining = MAX_ENC - len(q_toks_trunc) - 1
t_toks_trunc = t_toks_raw[:remaining]
reconstructed_tokens = q_toks_trunc + [TOOLS_ID] + t_toks_trunc

# The builder's actual output is authoritative for the docs. Normalise it to a
# flat list of Python ints; the builder may return a list or an array.
final_tokens = np.asarray(enc_tokens2).reshape(-1).tolist()
if final_tokens != reconstructed_tokens:
    print("WARNING: manual reconstruction differs from _build_encoder_input output; "
          "documenting the actual builder output")
    # Re-derive the structural breakdown from the real sequence, split at the
    # first <tools> separator.
    if TOOLS_ID in final_tokens:
        sep_idx = final_tokens.index(TOOLS_ID)
        q_toks_trunc = final_tokens[:sep_idx]
        t_toks_trunc = final_tokens[sep_idx + 1:]
decoded_full = tok.decode(final_tokens)
|
| |
prompt_format = f"""# Needle Prompt Format

## Overview

Needle uses an **encoder–decoder** architecture. The query + tools are encoded by the encoder; the decoder generates the tool-call JSON.

## Encoder Input Format

```
[query_tokens..., <tools>(id={TOOLS_ID}), tools_tokens...]
```

**Built by `_build_encoder_input(tokenizer, query, tools, max_enc_len={DEFAULT_MAX_ENC_LEN})`:**

```python
{inspect.getsource(_build_encoder_input)}
```

### Truncation Rules

1. Query is truncated to `max_enc_len - 2 = {MAX_ENC - 2}` tokens
2. Tools are truncated to `max_enc_len - len(q_toks) - 1` tokens (fills remaining space)
3. The `<tools>` separator token (id={TOOLS_ID}) is always inserted between query and tools

### Example: `query="set a 5 min timer"`, `tools=[{{"name": "set_timer", ...}}]`

**Encode calls made by `_build_encoder_input`:**
1. `tokenizer.encode("{query}")` → {q_toks_raw} ({len(q_toks_raw)} tokens)
2. `tokenizer.encode('{tools_json[:80]}...')` → {t_toks_raw[:20]}... ({len(t_toks_raw)} tokens total)

**Final encoder token sequence (len={len(final_tokens)}):**
```
{final_tokens}
```

**Decoded encoder input string:**
```
{repr(decoded_full)}
```

**Structural breakdown:**
- Query tokens ({len(q_toks_trunc)} tokens): `{q_toks_trunc}`
- Separator token (1 token): `[{TOOLS_ID}]` → `<tools>`
- Tools tokens ({len(t_toks_trunc)} tokens): `{t_toks_trunc}`

---

## Decoder Input Format

The decoder is initialized with a single **EOS token** (id={EOS_ID}) as prefix:

```python
dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32)
dec_buffer = dec_buffer.at[0, 0].set(eos_id)  # prefix = [EOS]
```

At each step, the model predicts the next token autoregressively.

**Expected output sequence:** `<tool_call>(id={TOOL_CALL_ID}) + answer_json_tokens + EOS(id={EOS_ID})`

The `<tool_call>` prefix is stripped from the decoded output string in `generate()`.

---

## Tool Name Normalization

Before building encoder input, tool names are converted to `snake_case` via `normalize_tools()`:
- `camelCase` → `camel_case`
- `PascalCase` → `pascal_case`
- `dot.notation` → `dot_notation`
- Non-alphanumeric chars → underscores

After decoding, `restore_tool_names()` maps back to original names.

---

## Full `generate()` Source

```python
{inspect.getsource(generate)}
```
"""

# Arrows/dashes and decoded tokens are non-ASCII: write UTF-8 explicitly so the
# script does not depend on the platform's default locale encoding.
with open(os.path.join(NOTES_DIR, 'prompt-format.md'), 'w', encoding='utf-8') as f:
    f.write(prompt_format)
print(f"Wrote prompt-format.md ({len(prompt_format.splitlines())} lines)")

print("\n=== All notes written successfully ===")
|
|