"""
inspect_needle.py — Introspect the Cactus/Needle Flax model, tokenizer, and prompt format.

Run as a script. It instantiates the model with its default config, inspects
the tokenizer, captures the encoder prompt layout produced by the library, and
writes three markdown files:

    ../notes/needle-internals.md
    ../notes/tokenizer-internals.md
    ../notes/prompt-format.md

Paths are resolved relative to this file's directory; the needle package is
loaded from ../external/needle.
"""
import json
import os
import sys
import textwrap
import inspect
import traceback

# Add the needle package to path. This must run before the `needle.*` imports
# below, which are resolved against this entry.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'external', 'needle'))

import jax
import jax.numpy as jnp
import numpy as np

# ── 1. Import needle symbols ──────────────────────────────────────────────────
# NOTE(review): not every imported name is referenced by this script
# (e.g. FeedForward, make_causal_mask, load_checkpoint, textwrap, traceback);
# they appear to be kept for reference / interactive exploration.
from needle.model.architecture import (
    SimpleAttentionNetwork,
    TransformerConfig,
    MultiHeadAttention,
    FeedForward,
    EncoderBlock,
    DecoderBlock,
    Encoder,
    Decoder,
    ZCRMSNorm,
    make_causal_mask,
    make_padding_mask,
)
from needle.model.run import generate, _build_encoder_input, load_checkpoint
from needle.dataset.tokenizer import (
    NeedleTokenizer,
    get_tokenizer,
    PAD_ID,
    EOS_ID,
    BOS_ID,
    UNK_ID,
    TOOL_CALL_ID,
    TOOLS_ID,
    DEFAULT_MAX_ENC_LEN,
    DEFAULT_MAX_DEC_LEN,
    DEFAULT_MAX_GEN_LEN,
    TOKENIZER_PREFIX,
)

# Banner for the console log produced while the notes are generated.
print("=== Needle/Cactus Inspector ===\n")
# ── 2. Instantiate model with default config ──────────────────────────────────
# Build the model with default hyper-parameters, initialise every parameter,
# and record the parameter tree (paths, shapes, dtypes) plus the total count.
config = TransformerConfig()
print("TransformerConfig defaults:")
for field in config.__dataclass_fields__:
    print(f" {field} = {getattr(config, field)}")

model = SimpleAttentionNetwork(config)

# Use init_all to initialize all parameters (both text + contrastive paths)
dummy_src = jnp.zeros((1, 16), dtype=jnp.int32)
dummy_tgt = jnp.zeros((1, 16), dtype=jnp.int32)
variables = model.init(jax.random.PRNGKey(0), dummy_src, dummy_tgt, method='init_all')
params = variables['params']

print("\n=== Param Tree (paths + shapes) ===")


def _is_param_container(node):
    """Return True if *node* is a nested parameter container rather than a leaf.

    Uses ``collections.abc.Mapping`` instead of ``dict`` so that Flax
    ``FrozenDict`` trees (returned by ``init`` in some Flax versions) are
    traversed too; ``FrozenDict`` is a Mapping but not a ``dict`` subclass.
    """
    from collections.abc import Mapping
    return isinstance(node, Mapping)


def print_param_tree(tree, prefix='', indent=0):
    """Return indented text lines describing every leaf of a param tree.

    Despite the name, nothing is printed here; the caller prints the lines.

    Args:
        tree: Nested mapping of parameter names to arrays.
        prefix: Dotted path accumulated during recursion (kept for API
            compatibility; not rendered in the output).
        indent: Current nesting depth (one space per level).

    Returns:
        A list of strings: ``name:`` for sub-trees and
        ``name: shape=(...), dtype=...`` for leaves, sorted by key.
        A non-mapping *tree* yields an empty list.
    """
    lines = []
    if not _is_param_container(tree):
        return lines
    pad = ' ' * indent
    for k, v in sorted(tree.items()):
        if _is_param_container(v):
            lines.append(pad + f'{k}:')
            lines.extend(print_param_tree(v, prefix=prefix + k + '.', indent=indent + 1))
        else:
            shape = tuple(v.shape) if hasattr(v, 'shape') else repr(v)
            dtype = str(v.dtype) if hasattr(v, 'dtype') else '?'
            lines.append(pad + f'{k}: shape={shape}, dtype={dtype}')
    return lines


param_lines = print_param_tree(params)
for line in param_lines:
    print(line)


# Build shape dict for JSON serialization
def build_shape_dict(tree):
    """Convert a param tree into plain nested dicts with ``{'shape', 'dtype'}`` leaves.

    The result contains only JSON-serialisable values (lists and strings).
    """
    if _is_param_container(tree):
        return {k: build_shape_dict(v) for k, v in sorted(tree.items())}
    return {'shape': list(tree.shape), 'dtype': str(tree.dtype)}


shape_dict = build_shape_dict(params)

# Count parameters. jax.tree_util.tree_leaves is available on all JAX versions
# (the jax.tree.* alias only exists in newer releases).
total_params = sum(int(x.size) for x in jax.tree_util.tree_leaves(params))
print(f"\nTotal parameters: {total_params:,}")
# ── 3. Tokenizer info ─────────────────────────────────────────────────────────
# Dump the tokenizer's public surface: class, special-token ids, attribute
# names, an encode/decode round trip, and the SentencePiece piece count.
print("\n=== Tokenizer ===")
tok = get_tokenizer()
print(f"Class: {type(tok).__name__}")
for _id_attr in (
    "vocab_size",
    "pad_token_id",
    "eos_token_id",
    "bos_token_id",
    "tool_call_token_id",
    "tools_token_id",
):
    print(f"{_id_attr}: {getattr(tok, _id_attr)}")
print("attrs: {}".format([name for name in dir(tok) if not name.startswith('_')]))

# Round-trip a sample query through encode/decode.
sample_text = "set a 5 min timer"
encoded = tok.encode(sample_text)
decoded = tok.decode(encoded)
print("\n".join((
    f"\nExample encode('{sample_text}'):",
    f" -> {encoded}",
    f" -> decode -> '{decoded}'",
)))

# Piece count straight from the SentencePiece processor, plus the training
# settings recorded by hand from the train_tokenizer source.
sp = tok.sp
print("\n".join((
    f"\nSentencePiece GetPieceSize(): {sp.GetPieceSize()}",
    "Model type: BPE (confirmed in train_tokenizer source)",
    "byte_fallback: True (confirmed in train_tokenizer source)",
    "normalization_rule_name: identity (no Unicode normalization)",
)))
del _id_attr
# ── 4. Prompt format capture ──────────────────────────────────────────────────
# Run the library's own encoder-input builder while recording every string it
# passes to tok.encode, so the notes reflect the real prompt layout.
print("\n=== Prompt Format Capture ===")
query = "set a 5 min timer"
tools = [{"name": "set_timer", "description": "Set a timer.",
          "parameters": {"time_human": {"type": "string", "description": "duration", "required": True}}}]
tools_json = json.dumps(tools)

captured_inputs = []
original_encode = tok.encode


def capture_encode(text, *a, **kw):
    """Append ``('encode', text)`` to ``captured_inputs``, then delegate to the real encoder."""
    captured_inputs.append(('encode', text))
    return original_encode(text, *a, **kw)


# Patch tok.encode only for the duration of the call. Restoring in `finally`
# guarantees an exception inside _build_encoder_input cannot leave the
# tokenizer patched.
tok.encode = capture_encode
try:
    # Call _build_encoder_input directly
    enc_tokens = _build_encoder_input(tok, query, tools_json)
finally:
    tok.encode = original_encode  # restore

print(f"Captured encode calls: {captured_inputs}")
print(f"Encoder token IDs: {enc_tokens}")
print(f"Encoder token count: {len(enc_tokens)}")

# Decode the encoder input to see the full prompt
decoded_enc = tok.decode(enc_tokens)
print(f"Decoded encoder input: {repr(decoded_enc)}")

# Also decode each segment
q_toks = tok.encode(query)
t_toks = tok.encode(tools_json)
print(f"\nQuery tokens: {q_toks} -> '{tok.decode(q_toks)}'")
print(f"tools_token_id (separator): {TOOLS_ID}")
print(f"Tools tokens: {t_toks[:20]}... -> '{tok.decode(t_toks[:20])}'")

# ── 5. Write notes files ──────────────────────────────────────────────────────
NOTES_DIR = os.path.join(os.path.dirname(__file__), '..', 'notes')
os.makedirs(NOTES_DIR, exist_ok=True)
# ── 5a. needle-internals.md ───────────────────────────────────────────────────
def fmt_shape_dict(d, indent=0):
    """Format a nested shape dict (as built by ``build_shape_dict``) as a markdown nested list.

    Interior nodes become bold bullets; leaves (dicts carrying a ``'shape'``
    key) become ``name: shape=..., dtype=...`` bullets. The rendering below
    embeds the shape dict as JSON and does not call this helper.

    Args:
        d: Nested mapping of names to sub-trees or ``{'shape', 'dtype'}`` leaves.
        indent: Current nesting depth.

    Returns:
        A list of markdown lines.
    """
    lines = []
    # Two spaces per level: CommonMark only nests a "- " item when it is
    # indented at least to its parent's content column (2).
    prefix = '  ' * indent
    for k, v in sorted(d.items()):
        if isinstance(v, dict) and 'shape' not in v:
            lines.append(f"{prefix}- **{k}**")
            lines.extend(fmt_shape_dict(v, indent + 1))
        else:
            # Leaves use the same column as sibling sub-trees so they stay in one list.
            shape = v.get('shape', '?') if isinstance(v, dict) else '?'
            dtype = v.get('dtype', '?') if isinstance(v, dict) else '?'
            lines.append(f"{prefix}- `{k}`: shape={shape}, dtype={dtype}")
    return lines


# Derived projection sizes used in the attention section below.
head_dim = config.d_model // config.num_heads
kv_dim = config.num_kv_heads * head_dim

needle_internals = f"""# Needle (Cactus) Flax Model Internals

## Model: `SimpleAttentionNetwork`

**Architecture:** Encoder–decoder transformer with shared embeddings, tied output projection, RoPE, bfloat16.

### `TransformerConfig` Defaults

| Field | Value | Notes |
|---|---|---|
| vocab_size | {config.vocab_size} | SentencePiece BPE vocab |
| d_model | {config.d_model} | Hidden dimension |
| num_heads | {config.num_heads} | Attention heads |
| num_kv_heads | {config.num_kv_heads} | GQA key/value heads |
| num_encoder_layers | {config.num_encoder_layers} | Encoder depth |
| num_decoder_layers | {config.num_decoder_layers} | Decoder depth |
| d_ff | {config.d_ff} | FFN hidden dim (unused by default) |
| max_seq_len | {config.max_seq_len} | |
| pad_token_id | {config.pad_token_id} | |
| rope_theta | {config.rope_theta} | RoPE base frequency |
| dtype | {config.dtype} | |
| activation | {config.activation} | dual-ReLU gate (drelu) |
| num_memory_slots | {config.num_memory_slots} | |
| dropout_rate | {config.dropout_rate} | |
| contrastive_dim | {config.contrastive_dim} | Contrastive projection output dim |
| no_feedforward | {config.no_feedforward} | **True → FFN blocks are SKIPPED** |

**Total parameters:** {total_params:,}

**Important:** `no_feedforward=True` by default — `FeedForward` / `gate_proj` / `up_proj` / `down_proj` are NOT present in the saved param tree. The `ffn_gate` scalar also absent. This is a pure attention-only transformer.

---

## Layer-by-Layer Architecture

### Encoder (`Encoder` module)

Uses `nn.scan` over `num_encoder_layers={config.num_encoder_layers}` layers. Each `EncoderBlock` contains:

1. **`attn_gate`** — scalar sigmoid gate `()` on attention residual
2. **Pre-norm**: `ZCRMSNorm` (scale initialized to 0, applied as `(1+γ) * x / RMS(x)`)
3. **`self_attn`** (`MultiHeadAttention`):
   - `q_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `k_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {kv_dim})`
   - `v_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {kv_dim})`
   - `out_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `q_norm`, `k_norm`: `ZCRMSNorm` on Q and K (pre-RoPE)
4. Residual: `x = x_in + gate * attn_out`
5. (FFN block **skipped** because `no_feedforward=True`)
6. **`final_norm`**: `ZCRMSNorm` after all layers (on encoder output)

### Decoder (`Decoder` module)

Uses `nn.scan` over `num_decoder_layers={config.num_decoder_layers}` layers. Each `DecoderBlock` contains:

1. **`self_attn_gate`** — scalar on causal self-attention residual
2. Pre-norm + **`self_attn`** (causal self-attention, same structure as encoder)
3. Residual: `x = x_in + self_gate * self_attn_out`
4. **`cross_attn_gate`** — scalar on cross-attention residual
5. Pre-norm + **`cross_attn`** (cross-attention: Q from decoder, K/V from encoder)
6. Residual: `x = x_in + cross_gate * cross_attn_out`
7. (FFN block **skipped**)
8. Final `ZCRMSNorm` after all layers

### Top-level `SimpleAttentionNetwork` params

- **`embedding`**: `nn.Embed` kernel `(vocab_size, d_model)` = `({config.vocab_size}, {config.d_model})`
  Also used as tied output: `logits = hidden @ embedding.T`
- **`log_temp`**: scalar `()` — contrastive temperature
- **`contrastive_hidden`**: Dense `(d_model, d_model//4)` = `({config.d_model}, {config.d_model//4})`
- **`contrastive_proj`**: Dense `(d_model//4, contrastive_dim)` = `({config.d_model//4}, {config.contrastive_dim})`, no bias

---

## `nn.scan` Parameter Layout

Because Flax `nn.scan` with `variable_axes={{"params": 0}}` is used, each scanned layer stack is stored as a **single** parameter array with an extra leading dimension of size `num_layers`.

So for encoder: `layers.attn_gate` has shape `(num_encoder_layers, ...)` = `({config.num_encoder_layers}, ...)`.

**This is critical for Task 2C:** When mapping Flax → PyTorch weights, you must index `params['encoder']['layers']['...'][layer_idx]` to get per-layer weights.

---

## Full Flax `params` Tree (shapes)

```json
{json.dumps(shape_dict, indent=2)}
```

---

## `SimpleAttentionNetwork.__call__` Source

```python
{inspect.getsource(SimpleAttentionNetwork.__call__)}
```

## `EncoderBlock.__call__` Source

```python
{inspect.getsource(EncoderBlock.__call__)}
```

## `DecoderBlock.__call__` Source

```python
{inspect.getsource(DecoderBlock.__call__)}
```

## `MultiHeadAttention.__call__` Source

```python
{inspect.getsource(MultiHeadAttention.__call__)}
```
"""

# The text above hard-codes the attention-only (no FFN) layout; make any
# disagreement with the live config visible instead of writing stale notes.
if not config.no_feedforward:
    print("WARNING: needle-internals.md assumes no_feedforward=True, "
          "but the config disagrees; review the FFN notes.")

# Explicit UTF-8: the notes contain non-ASCII characters (–, →, γ) that the
# locale default encoding (e.g. cp1252 on Windows) cannot represent.
with open(os.path.join(NOTES_DIR, 'needle-internals.md'), 'w', encoding='utf-8') as f:
    f.write(needle_internals)
print(f"\nWrote needle-internals.md ({len(needle_internals.splitlines())} lines)")
# ── 5b. tokenizer-internals.md ────────────────────────────────────────────────
# Surface strings of the two special tokens, read from the SentencePiece model
# itself rather than typed by hand, so the notes show the real symbols.
tool_call_piece = tok.sp.IdToPiece(TOOL_CALL_ID)
tools_piece = tok.sp.IdToPiece(TOOLS_ID)

tokenizer_internals = f"""# Needle Tokenizer Internals

## Tokenizer Class

**Class:** `NeedleTokenizer` (defined in `external/needle/needle/dataset/tokenizer.py`)

**Backend:** SentencePiece BPE — the model file is downloaded from HuggingFace on first use.

**HuggingFace repo:** `Cactus-Compute/needle-tokenizer` (dataset repo)

**Local model file path:** `{TOKENIZER_PREFIX}.model` (relative to the external/needle repo root: `needle/tokenizer/needle.model`)

---

## Tokenizer Properties

| Property | Value | Source |
|---|---|---|
| `vocab_size` | {tok.vocab_size} | `sp.GetPieceSize()` |
| `pad_token_id` | {PAD_ID} | hardcoded = 0 |
| `eos_token_id` | {EOS_ID} | hardcoded = 1 |
| `bos_token_id` | {BOS_ID} | hardcoded = 2 |
| `unk_id` | {UNK_ID} | hardcoded = 3 |
| `tool_call_token_id` | {TOOL_CALL_ID} | hardcoded = 4, symbol `{tool_call_piece}` |
| `tools_token_id` | {TOOLS_ID} | hardcoded = 5, symbol `{tools_piece}` |

---

## Encode Algorithm

**Algorithm:** SentencePiece **BPE** (Byte-Pair Encoding)

**Configured with:**
- `model_type="bpe"`
- `byte_fallback=True` — unknown bytes represented as `<0xNN>` hex tokens
- `normalization_rule_name="identity"` — **no Unicode normalization applied**
- `user_defined_symbols={json.dumps([tool_call_piece, tools_piece])}` — fixed IDs {TOOL_CALL_ID} and {TOOLS_ID}

**Encode call in NeedleTokenizer:**

```python
def encode(self, text):
    return self.sp.Encode(text, out_type=int)
```

Returns a Python `list[int]`. Does NOT add BOS/EOS automatically.

**Decode call:**

```python
def decode(self, ids):
    return self.sp.Decode(list(ids))
```

---

## Available Attributes on `NeedleTokenizer`

```
{[a for a in dir(tok) if not a.startswith('_')]}
```

The underlying SentencePiece processor is accessible as `tok.sp`.

---

## SentencePiece Model File

The `.model` file is a serialized SentencePiece protobuf. For Task 5 (port tokenizer to JS/ONNX/etc.), you need either:

1. The `.model` file directly (loadable by the `sentencepiece` library)
2. Export the vocabulary + merge rules via `tok.sp.GetPieceSize()`, `tok.sp.IdToPiece(i)`, etc.

To export the full vocabulary:

```python
vocab = [(i, tok.sp.IdToPiece(i), tok.sp.GetScore(i)) for i in range(tok.vocab_size)]
```

---

## Example Encoding

```
encode("{sample_text}")
  -> {encoded}
  -> decode -> "{decoded}"
```
"""

# Explicit UTF-8: the notes contain non-ASCII characters (—, →) that the
# locale default encoding may not be able to represent.
with open(os.path.join(NOTES_DIR, 'tokenizer-internals.md'), 'w', encoding='utf-8') as f:
    f.write(tokenizer_internals)
print(f"Wrote tokenizer-internals.md ({len(tokenizer_internals.splitlines())} lines)")

# ── 5c. prompt-format.md ──────────────────────────────────────────────────────
# Ground truth: the encoder input exactly as the library builds it, flattened
# to a list of Python ints so it can be compared with the breakdown below.
enc_tokens2 = np.asarray(_build_encoder_input(tok, query, tools_json)).reshape(-1).tolist()

# Reconstruct what each segment contains by mirroring the truncation rules.
q_toks_raw = tok.encode(query)
t_toks_raw = tok.encode(tools_json)
MAX_ENC = DEFAULT_MAX_ENC_LEN
max_query = MAX_ENC - 2
q_toks_trunc = q_toks_raw[:max_query]
remaining = MAX_ENC - len(q_toks_trunc) - 1
t_toks_trunc = t_toks_raw[:remaining]
final_tokens = q_toks_trunc + [TOOLS_ID] + t_toks_trunc
decoded_full = tok.decode(final_tokens)

# The hand-written breakdown can drift from _build_encoder_input; surface any
# disagreement both on the console and in the generated notes.
reconstruction_matches = enc_tokens2 == list(final_tokens)
if reconstruction_matches:
    mismatch_note = ""
else:
    mismatch_note = (
        "\n> **Warning:** this hand-reconstructed sequence differs from what "
        f"`_build_encoder_input` returned (len={len(enc_tokens2)}); "
        "the structural breakdown below may be stale.\n"
    )
    print("WARNING: hand-reconstructed encoder input differs from "
          "_build_encoder_input() output; prompt-format.md breakdown may be stale.")

# Decoder starts with EOS; the model is then expected to emit the tool-call token first.
prompt_format = f"""# Needle Prompt Format

## Overview

Needle uses an **encoder–decoder** architecture. The query + tools are encoded by the encoder; the decoder generates the tool-call JSON.

## Encoder Input Format

```
[query_tokens..., {tools_piece} (id={TOOLS_ID}), tools_tokens...]
```

**Built by `_build_encoder_input(tokenizer, query, tools, max_enc_len={DEFAULT_MAX_ENC_LEN})`:**

```python
{inspect.getsource(_build_encoder_input)}
```

### Truncation Rules

1. Query is truncated to `max_enc_len - 2 = {MAX_ENC - 2}` tokens
2. Tools are truncated to `max_enc_len - len(q_toks) - 1` tokens (fills remaining space)
3. The `{tools_piece}` separator token (id={TOOLS_ID}) is always inserted between query and tools

### Example: `query="{query}"`, `tools=[{{"name": "set_timer", ...}}]`

**Encode calls made by `_build_encoder_input`:**

1. `tokenizer.encode("{query}")` → {q_toks_raw} ({len(q_toks_raw)} tokens)
2. `tokenizer.encode('{tools_json[:80]}...')` → {t_toks_raw[:20]}... ({len(t_toks_raw)} tokens total)

**Final encoder token sequence (len={len(final_tokens)}):**

```
{final_tokens}
```
{mismatch_note}
**Decoded encoder input string:**

```
{repr(decoded_full)}
```

**Structural breakdown:**

- Query tokens ({len(q_toks_trunc)} tokens): `{q_toks_trunc}`
- Separator token (1 token): `[{TOOLS_ID}]` → `{tools_piece}`
- Tools tokens ({len(t_toks_trunc)} tokens): `{t_toks_trunc}`

---

## Decoder Input Format

The decoder is initialized with a single **EOS token** (id={EOS_ID}) as prefix:

```python
dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32)
dec_buffer = dec_buffer.at[0, 0].set(eos_id)  # prefix = [EOS]
```

At each step, the model predicts the next token autoregressively.

**Expected output sequence:** `{tool_call_piece} (id={TOOL_CALL_ID}) + answer_json_tokens + EOS(id={EOS_ID})`

The `{tool_call_piece}` prefix is stripped from the decoded output string in `generate()`.

---

## Tool Name Normalization

Before building encoder input, tool names are converted to `snake_case` via `normalize_tools()`:

- `camelCase` → `camel_case`
- `PascalCase` → `pascal_case`
- `dot.notation` → `dot_notation`
- Non-alphanumeric chars → underscores

After decoding, `restore_tool_names()` maps back to original names.

---

## Full `generate()` Signature

```python
{inspect.getsource(generate)}
```
"""

with open(os.path.join(NOTES_DIR, 'prompt-format.md'), 'w', encoding='utf-8') as f:
    f.write(prompt_format)
print(f"Wrote prompt-format.md ({len(prompt_format.splitlines())} lines)")

print("\n=== All notes written successfully ===")