# Source: needle-onnx/inspect_needle.py
# (uploaded by shreyask via huggingface_hub, commit 5b2a426)
"""
inspect_needle.py β€” Introspect the Cactus/Needle Flax model, tokenizer, and prompt format.
Outputs three markdown files:
../notes/needle-internals.md
../notes/tokenizer-internals.md
../notes/prompt-format.md
"""
import inspect
import json
import os
import sys
import textwrap
import traceback
from collections.abc import Mapping
# Make the vendored needle checkout importable: prepend
# <this script's dir>/../external/needle to sys.path so the `needle.*`
# imports below resolve to that checkout first.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'external', 'needle'))
import jax
import jax.numpy as jnp
import numpy as np
# ── 1. Import needle symbols ──────────────────────────────────────────────────
from needle.model.architecture import (
SimpleAttentionNetwork, TransformerConfig,
MultiHeadAttention, FeedForward, EncoderBlock, DecoderBlock,
Encoder, Decoder, ZCRMSNorm, make_causal_mask, make_padding_mask,
)
from needle.model.run import generate, _build_encoder_input, load_checkpoint
from needle.dataset.tokenizer import (
NeedleTokenizer, get_tokenizer,
PAD_ID, EOS_ID, BOS_ID, UNK_ID, TOOL_CALL_ID, TOOLS_ID,
DEFAULT_MAX_ENC_LEN, DEFAULT_MAX_DEC_LEN, DEFAULT_MAX_GEN_LEN,
TOKENIZER_PREFIX,
)
print("=== Needle/Cactus Inspector ===\n")
# ── 2. Instantiate model with default config ──────────────────────────────────
config = TransformerConfig()
print("TransformerConfig defaults:")
# Dump every dataclass field of the default config, one per line.
for field_name in config.__dataclass_fields__:
    field_value = getattr(config, field_name)
    print(f" {field_name} = {field_value}")
model = SimpleAttentionNetwork(config)
# Initialize through the `init_all` method instead of `__call__`; per the
# original author this creates parameters for both the text and the
# contrastive paths (not verified here).
dummy_src = jnp.zeros((1, 16), dtype=jnp.int32)
dummy_tgt = jnp.zeros_like(dummy_src)
variables = model.init(
    jax.random.PRNGKey(0),
    dummy_src,
    dummy_tgt,
    method='init_all',
)
params = variables['params']
print("\n=== Param Tree (paths + shapes) ===")
def print_param_tree(tree, prefix='', indent=0):
    """Render a nested parameter tree as indented ``name: shape=..., dtype=...`` lines.

    Despite its name, this function does not print; it returns the lines so the
    caller can print or write them.

    Args:
        tree: Nested mapping of parameters. Any ``collections.abc.Mapping`` is
            accepted: plain ``dict`` as well as read-only mappings such as
            Flax's ``FrozenDict``, which is not a ``dict`` subclass.
        prefix: Dotted path of ``tree`` within the root. It is threaded through
            the recursion but not included in the output; kept for compatibility.
        indent: Nesting depth; each level is indented by one space.

    Returns:
        One string per sub-tree header (``name:``) or leaf. Leaves without a
        ``shape`` / ``dtype`` attribute are rendered with ``repr(v)`` / ``'?'``.
        A non-mapping ``tree`` yields an empty list.
    """
    lines = []
    if not isinstance(tree, Mapping):
        return lines
    pad = ' ' * indent
    for k, v in sorted(tree.items()):
        if isinstance(v, Mapping):
            lines.append(f'{pad}{k}:')
            lines.extend(print_param_tree(v, prefix=f'{prefix}{k}.', indent=indent + 1))
        else:
            shape = tuple(v.shape) if hasattr(v, 'shape') else repr(v)
            dtype = str(v.dtype) if hasattr(v, 'dtype') else '?'
            lines.append(f'{pad}{k}: shape={shape}, dtype={dtype}')
    return lines
param_lines = print_param_tree(params)
# Emit the whole tree in a single print; skip entirely when there is nothing to show.
if param_lines:
    print("\n".join(param_lines))
# Build shape dict for JSON serialization
def build_shape_dict(tree):
    """Convert a nested parameter tree into a JSON-serializable shape summary.

    Sub-trees (any ``collections.abc.Mapping``, e.g. ``dict`` or Flax's
    ``FrozenDict``) become dicts with sorted keys. Each leaf becomes
    ``{'shape': [...], 'dtype': str}``. Leaves lacking ``shape`` / ``dtype``
    (e.g. plain Python scalars) get ``None`` / ``'?'`` instead of raising
    ``AttributeError``.
    """
    if isinstance(tree, Mapping):
        return {k: build_shape_dict(v) for k, v in sorted(tree.items())}
    shape = getattr(tree, 'shape', None)
    dtype = getattr(tree, 'dtype', None)
    return {
        'shape': None if shape is None else list(shape),
        'dtype': '?' if dtype is None else str(dtype),
    }
shape_dict = build_shape_dict(params)
# Parameter count = total element count summed over every leaf array.
total_params = sum(map(lambda leaf: leaf.size, jax.tree.leaves(params)))
print(f"\nTotal parameters: {total_params:,}")
# ── 3. Tokenizer info ─────────────────────────────────────────────────────────
print("\n=== Tokenizer ===")
tok = get_tokenizer()
print(f"Class: {type(tok).__name__}")
# Vocab size and special-token ids, each printed as "<attr>: <value>".
for attr_name in ('vocab_size', 'pad_token_id', 'eos_token_id',
                  'bos_token_id', 'tool_call_token_id', 'tools_token_id'):
    print(f"{attr_name}: {getattr(tok, attr_name)}")
public_attrs = [name for name in dir(tok) if not name.startswith('_')]
print(f"attrs: {public_attrs}")
# Round-trip a sample utterance through encode -> decode.
sample_text = "set a 5 min timer"
encoded = tok.encode(sample_text)
decoded = tok.decode(encoded)
print(f"\nExample encode('{sample_text}'):")
print(f" -> {encoded}")
print(f" -> decode -> '{decoded}'")
# Cross-check the vocabulary size against the SentencePiece processor itself.
sp = tok.sp
print(f"\nSentencePiece GetPieceSize(): {sp.GetPieceSize()}")
# The following lines are static text, not introspected from `sp`.
for static_line in (
    "Model type: BPE (confirmed in train_tokenizer source)",
    "byte_fallback: True (confirmed in train_tokenizer source)",
    "normalization_rule_name: identity (no Unicode normalization)",
):
    print(static_line)
# ── 4. Prompt format capture ──────────────────────────────────────────────────
print("\n=== Prompt Format Capture ===")
query = "set a 5 min timer"
tools = [{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration", "required": True}}}]
tools_json = json.dumps(tools)
# Temporarily wrap tok.encode so every text the prompt builder encodes is
# recorded in `captured_inputs` as ('encode', text).
captured_inputs = []
original_encode = tok.encode


def capture_encode(text, *a, **kw):
    """Record ``text`` in ``captured_inputs``, then delegate to the real encode."""
    captured_inputs.append(('encode', text))
    return original_encode(text, *a, **kw)


tok.encode = capture_encode
try:
    # Call _build_encoder_input directly, with its encode calls being recorded.
    enc_tokens = _build_encoder_input(tok, query, tools_json)
finally:
    # Always un-wrap, even if prompt building raises, so `tok` is not left
    # instrumented for later code or an interactive/re-run session.
    tok.encode = original_encode
print(f"Captured encode calls: {captured_inputs}")
print(f"Encoder token IDs: {enc_tokens}")
print(f"Encoder token count: {len(enc_tokens)}")
# Decode the encoder input to see the full prompt
decoded_enc = tok.decode(enc_tokens)
print(f"Decoded encoder input: {repr(decoded_enc)}")
# Also encode each segment on its own for comparison
q_toks = tok.encode(query)
t_toks = tok.encode(tools_json)
print(f"\nQuery tokens: {q_toks} -> '{tok.decode(q_toks)}'")
print(f"tools_token_id (separator): {TOOLS_ID}")
print(f"Tools tokens: {t_toks[:20]}... -> '{tok.decode(t_toks[:20])}'")
# ── 5. Write notes files ──────────────────────────────────────────────────────
# Notes are written to <this script's dir>/../notes, created on demand.
NOTES_DIR = os.path.join(os.path.dirname(__file__), os.pardir, 'notes')
os.makedirs(NOTES_DIR, exist_ok=True)
# ── 5a. needle-internals.md ───────────────────────────────────────────────────
def fmt_shape_dict(d, indent=0):
    """Format a nested shape dict as a markdown nested bullet list.

    NOTE(review): not called by the visible script code; kept as a helper.

    Args:
        d: Mapping of name -> sub-tree or leaf. A leaf is a mapping holding a
            ``'shape'`` key (and usually ``'dtype'``). Any other non-mapping
            value is rendered with ``repr``.
        indent: Nesting depth of ``d``. Each level is indented by two spaces,
            which is what CommonMark needs to nest a ``- `` item under its
            parent item.

    Returns:
        Markdown lines: ``- **name**`` for sub-trees and
        ``- `name`: shape=..., dtype=...`` for leaves. All items of one level
        share the same indentation.
    """
    lines = []
    prefix = '  ' * indent
    for k, v in sorted(d.items()):
        if isinstance(v, Mapping) and 'shape' not in v:
            lines.append(f"{prefix}- **{k}**")
            lines.extend(fmt_shape_dict(v, indent + 1))
        elif isinstance(v, Mapping):
            shape = v.get('shape', '?')
            dtype = v.get('dtype', '?')
            lines.append(f"{prefix}- `{k}`: shape={shape}, dtype={dtype}")
        else:
            # Unexpected non-mapping leaf: show it rather than crash on `.get`.
            lines.append(f"{prefix}- `{k}`: {v!r}")
    return lines
# Markdown summary of the model config, layer layout, param shapes and the
# relevant forward-pass sources. Blank lines separate markdown blocks so the
# table and the paragraphs after it are not merged when rendered.
needle_internals = f"""# Needle (Cactus) Flax Model Internals

## Model: `SimpleAttentionNetwork`

**Architecture:** Encoder–decoder transformer with shared embeddings, tied output projection, RoPE, bfloat16.

### `TransformerConfig` Defaults

| Field | Value | Notes |
|---|---|---|
| vocab_size | {config.vocab_size} | SentencePiece BPE vocab |
| d_model | {config.d_model} | Hidden dimension |
| num_heads | {config.num_heads} | Attention heads |
| num_kv_heads | {config.num_kv_heads} | GQA key/value heads |
| num_encoder_layers | {config.num_encoder_layers} | Encoder depth |
| num_decoder_layers | {config.num_decoder_layers} | Decoder depth |
| d_ff | {config.d_ff} | FFN hidden dim (unused by default) |
| max_seq_len | {config.max_seq_len} | |
| pad_token_id | {config.pad_token_id} | |
| rope_theta | {config.rope_theta} | RoPE base frequency |
| dtype | {config.dtype} | |
| activation | {config.activation} | dual-ReLU gate (drelu) |
| num_memory_slots | {config.num_memory_slots} | |
| dropout_rate | {config.dropout_rate} | |
| contrastive_dim | {config.contrastive_dim} | Contrastive projection output dim |
| no_feedforward | {config.no_feedforward} | **True → FFN blocks are SKIPPED** |

**Total parameters:** {total_params:,}

**Important:** `no_feedforward=True` by default — `FeedForward` / `gate_proj` / `up_proj` / `down_proj` are NOT present in the saved param tree. The `ffn_gate` scalar is also absent. This is a pure attention-only transformer.

---

## Layer-by-Layer Architecture

### Encoder (`Encoder` module)

Uses `nn.scan` over `num_encoder_layers={config.num_encoder_layers}` layers. Each `EncoderBlock` contains:

1. **`attn_gate`** — scalar sigmoid gate `()` on attention residual
2. **Pre-norm**: `ZCRMSNorm` (scale initialized to 0, applied as `(1+γ) * x / RMS(x)`)
3. **`self_attn`** (`MultiHeadAttention`):
   - `q_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `k_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
   - `v_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
   - `out_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `q_norm`, `k_norm`: `ZCRMSNorm` on Q and K (pre-RoPE)
4. Residual: `x = x_in + gate * attn_out`
5. (FFN block **skipped** because `no_feedforward=True`)
6. **`final_norm`**: `ZCRMSNorm` after all layers (on encoder output)

### Decoder (`Decoder` module)

Uses `nn.scan` over `num_decoder_layers={config.num_decoder_layers}` layers. Each `DecoderBlock` contains:

1. **`self_attn_gate`** — scalar on causal self-attention residual
2. Pre-norm + **`self_attn`** (causal self-attention, same structure as encoder)
3. Residual: `x = x_in + self_gate * self_attn_out`
4. **`cross_attn_gate`** — scalar on cross-attention residual
5. Pre-norm + **`cross_attn`** (cross-attention: Q from decoder, K/V from encoder)
6. Residual: `x = x_in + cross_gate * cross_attn_out`
7. (FFN block **skipped**)
8. Final `ZCRMSNorm` after all layers

### Top-level `SimpleAttentionNetwork` params

- **`embedding`**: `nn.Embed` kernel `(vocab_size, d_model)` = `({config.vocab_size}, {config.d_model})`
  Also used as tied output: `logits = hidden @ embedding.T`
- **`log_temp`**: scalar `()` — contrastive temperature
- **`contrastive_hidden`**: Dense `(d_model, d_model//4)` = `({config.d_model}, {config.d_model//4})`
- **`contrastive_proj`**: Dense `(d_model//4, contrastive_dim)` = `({config.d_model//4}, {config.contrastive_dim})`, no bias

---

## `nn.scan` Parameter Layout

Because Flax `nn.scan` with `variable_axes={{"params": 0}}` is used, each scanned layer stack is stored as a **single** parameter array with an extra leading dimension of size `num_layers`. So for encoder: `layers.attn_gate` has shape `(num_encoder_layers, ...)` = `({config.num_encoder_layers}, ...)`.

**This is critical for Task 2C:** When mapping Flax → PyTorch weights, you must index `params['encoder']['layers']['...'][layer_idx]` to get per-layer weights.

---

## Full Flax `params` Tree (shapes)

```json
{json.dumps(shape_dict, indent=2)}
```

---

## `SimpleAttentionNetwork.__call__` Source

```python
{inspect.getsource(SimpleAttentionNetwork.__call__)}
```

## `EncoderBlock.__call__` Source

```python
{inspect.getsource(EncoderBlock.__call__)}
```

## `DecoderBlock.__call__` Source

```python
{inspect.getsource(DecoderBlock.__call__)}
```

## `MultiHeadAttention.__call__` Source

```python
{inspect.getsource(MultiHeadAttention.__call__)}
```
"""
# Explicit UTF-8: the text contains non-ASCII characters (—, →, γ, –) that
# the platform default encoding (e.g. cp1252) cannot encode.
with open(os.path.join(NOTES_DIR, 'needle-internals.md'), 'w', encoding='utf-8') as f:
    f.write(needle_internals)
print(f"\nWrote needle-internals.md ({len(needle_internals.splitlines())} lines)")
# ── 5b. tokenizer-internals.md ────────────────────────────────────────────────
# Markdown summary of the tokenizer: special ids, encode/decode contract,
# public attributes and a worked example. Blank lines separate markdown
# blocks so paragraphs and tables render as intended.
tokenizer_internals = f"""# Needle Tokenizer Internals

## Tokenizer Class

**Class:** `NeedleTokenizer` (defined in `external/needle/needle/dataset/tokenizer.py`)

**Backend:** SentencePiece BPE — the model file is downloaded from HuggingFace on first use.

**HuggingFace repo:** `Cactus-Compute/needle-tokenizer` (dataset repo)

**Local model file path:** `{TOKENIZER_PREFIX}.model`
(relative to the external/needle repo root: `needle/tokenizer/needle.model`)

---

## Tokenizer Properties

| Property | Value | Source |
|---|---|---|
| `vocab_size` | {tok.vocab_size} | `sp.GetPieceSize()` |
| `pad_token_id` | {PAD_ID} | hardcoded = 0 |
| `eos_token_id` | {EOS_ID} | hardcoded = 1 |
| `bos_token_id` | {BOS_ID} | hardcoded = 2 |
| `unk_id` | {UNK_ID} | hardcoded = 3 |
| `tool_call_token_id` | {TOOL_CALL_ID} | hardcoded = 4, symbol `<tool_call>` |
| `tools_token_id` | {TOOLS_ID} | hardcoded = 5, symbol `<tools>` |

---

## Encode Algorithm

**Algorithm:** SentencePiece **BPE** (Byte-Pair Encoding)

**Configured with:**

- `model_type="bpe"`
- `byte_fallback=True` — unknown bytes represented as `<0xNN>` hex tokens
- `normalization_rule_name="identity"` — **no Unicode normalization applied**
- `user_defined_symbols=["<tool_call>", "<tools>"]` — fixed IDs 4 and 5

**Encode call in NeedleTokenizer:**

```python
def encode(self, text):
    return self.sp.Encode(text, out_type=int)
```

Returns a Python `list[int]`. Does NOT add BOS/EOS automatically.

**Decode call:**

```python
def decode(self, ids):
    return self.sp.Decode(list(ids))
```

---

## Available Attributes on `NeedleTokenizer`

```
{[a for a in dir(tok) if not a.startswith('_')]}
```

The underlying SentencePiece processor is accessible as `tok.sp`.

---

## SentencePiece Model File

The `.model` file is a serialized SentencePiece protobuf. For Task 5 (port tokenizer to JS/ONNX/etc.), you need either:

1. The `.model` file directly (loadable by the `sentencepiece` library)
2. Export the vocabulary + merge rules via `tok.sp.GetPieceSize()`, `tok.sp.IdToPiece(i)`, etc.

To export the full vocabulary:

```python
vocab = [(i, tok.sp.IdToPiece(i), tok.sp.GetScore(i)) for i in range(tok.vocab_size)]
```

---

## Example Encoding

```
encode("{sample_text}")
-> {encoded}
-> decode -> "{decoded}"
```
"""
# Explicit UTF-8: the text contains non-ASCII characters (—) and the platform
# default encoding is not guaranteed to be UTF-8.
with open(os.path.join(NOTES_DIR, 'tokenizer-internals.md'), 'w', encoding='utf-8') as f:
    f.write(tokenizer_internals)
print(f"Wrote tokenizer-internals.md ({len(tokenizer_internals.splitlines())} lines)")
# ── 5c. prompt-format.md ─────────────────────────────────────────────────────
# Get the real encoder input once more (no encode-capturing needed here) so the
# hand reconstruction below can be checked against it.
enc_tokens2 = _build_encoder_input(tok, query, tools_json)
# Reconstruct what each segment contains, mirroring the truncation rules
# documented in prompt-format.md.
q_toks_raw = tok.encode(query)
t_toks_raw = tok.encode(tools_json)
MAX_ENC = DEFAULT_MAX_ENC_LEN
max_query = MAX_ENC - 2
q_toks_trunc = q_toks_raw[:max_query]
remaining = MAX_ENC - len(q_toks_trunc) - 1
t_toks_trunc = t_toks_raw[:remaining]
final_tokens = q_toks_trunc + [TOOLS_ID] + t_toks_trunc
decoded_full = tok.decode(final_tokens)
# The notes document `final_tokens`; flag it if the reconstruction drifts from
# what _build_encoder_input actually produces.
if list(map(int, enc_tokens2)) != list(map(int, final_tokens)):
    print("WARNING: reconstructed encoder tokens differ from _build_encoder_input "
          "output; prompt-format.md may be inaccurate")
# Decoder starts with EOS token (1), then model predicts <tool_call> (4) first
# Markdown description of the encoder/decoder prompt format, with a worked
# example. Blank lines separate markdown blocks so lists and code fences
# render correctly.
prompt_format = f"""# Needle Prompt Format

## Overview

Needle uses an **encoder–decoder** architecture. The query + tools are encoded by the encoder; the decoder generates the tool-call JSON.

## Encoder Input Format

```
[query_tokens..., <tools>(id={TOOLS_ID}), tools_tokens...]
```

**Built by `_build_encoder_input(tokenizer, query, tools, max_enc_len={DEFAULT_MAX_ENC_LEN})`:**

```python
{inspect.getsource(_build_encoder_input)}
```

### Truncation Rules

1. Query is truncated to `max_enc_len - 2 = {MAX_ENC - 2}` tokens
2. Tools are truncated to `max_enc_len - len(q_toks) - 1` tokens (fills remaining space)
3. The `<tools>` separator token (id={TOOLS_ID}) is always inserted between query and tools

### Example: `query="set a 5 min timer"`, `tools=[{{"name": "set_timer", ...}}]`

**Encode calls made by `_build_encoder_input`:**

1. `tokenizer.encode("{query}")` → {q_toks_raw} ({len(q_toks_raw)} tokens)
2. `tokenizer.encode('{tools_json[:80]}...')` → {t_toks_raw[:20]}... ({len(t_toks_raw)} tokens total)

**Final encoder token sequence (len={len(final_tokens)}):**

```
{final_tokens}
```

**Decoded encoder input string:**

```
{repr(decoded_full)}
```

**Structural breakdown:**

- Query tokens ({len(q_toks_trunc)} tokens): `{q_toks_trunc}`
- Separator token (1 token): `[{TOOLS_ID}]` → `<tools>`
- Tools tokens ({len(t_toks_trunc)} tokens): `{t_toks_trunc}`

---

## Decoder Input Format

The decoder is initialized with a single **EOS token** (id={EOS_ID}) as prefix:

```python
dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32)
dec_buffer = dec_buffer.at[0, 0].set(eos_id)  # prefix = [EOS]
```

At each step, the model predicts the next token autoregressively.

**Expected output sequence:** `<tool_call>(id={TOOL_CALL_ID}) + answer_json_tokens + EOS(id={EOS_ID})`

The `<tool_call>` prefix is stripped from the decoded output string in `generate()`.

---

## Tool Name Normalization

Before building encoder input, tool names are converted to `snake_case` via `normalize_tools()`:

- `camelCase` → `camel_case`
- `PascalCase` → `pascal_case`
- `dot.notation` → `dot_notation`
- Non-alphanumeric chars → underscores

After decoding, `restore_tool_names()` maps back to original names.

---

## Full `generate()` Signature

```python
{inspect.getsource(generate)}
```
"""
# Explicit UTF-8: the text contains non-ASCII characters (→, –) that the
# platform default encoding (e.g. cp1252) may not be able to encode.
with open(os.path.join(NOTES_DIR, 'prompt-format.md'), 'w', encoding='utf-8') as f:
    f.write(prompt_format)
print(f"Wrote prompt-format.md ({len(prompt_format.splitlines())} lines)")
print("\n=== All notes written successfully ===")