"""
inspect_needle.py: Introspect the Cactus/Needle Flax model, tokenizer, and prompt format.

Outputs three markdown files:
  ../notes/needle-internals.md
  ../notes/tokenizer-internals.md
  ../notes/prompt-format.md
"""

import json
import os
import sys
import textwrap
import inspect
import traceback

# Add the needle package to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..', 'external', 'needle'))

import jax
import jax.numpy as jnp
import numpy as np

# ── 1. Import needle symbols ──────────────────────────────────────────────────
from needle.model.architecture import (
    SimpleAttentionNetwork, TransformerConfig,
    MultiHeadAttention, FeedForward, EncoderBlock, DecoderBlock,
    Encoder, Decoder, ZCRMSNorm, make_causal_mask, make_padding_mask,
)
from needle.model.run import generate, _build_encoder_input, load_checkpoint
from needle.dataset.tokenizer import (
    NeedleTokenizer, get_tokenizer,
    PAD_ID, EOS_ID, BOS_ID, UNK_ID, TOOL_CALL_ID, TOOLS_ID,
    DEFAULT_MAX_ENC_LEN, DEFAULT_MAX_DEC_LEN, DEFAULT_MAX_GEN_LEN,
    TOKENIZER_PREFIX,
)

print("=== Needle/Cactus Inspector ===\n")

# ── 2. Instantiate model with default config ──────────────────────────────────
config = TransformerConfig()
print("TransformerConfig defaults:")
for field in config.__dataclass_fields__:
    print(f"  {field} = {getattr(config, field)}")

model = SimpleAttentionNetwork(config)

# Use init_all to initialize all parameters (both text + contrastive paths)
dummy_src = jnp.zeros((1, 16), dtype=jnp.int32)
dummy_tgt = jnp.zeros((1, 16), dtype=jnp.int32)
variables = model.init(jax.random.PRNGKey(0), dummy_src, dummy_tgt, method='init_all')
params = variables['params']

print("\n=== Param Tree (paths + shapes) ===")

def print_param_tree(tree, prefix='', indent=0):
    """Recursively print param tree with shapes."""
    lines = []
    if isinstance(tree, dict):
        for k, v in sorted(tree.items()):
            if isinstance(v, dict):
                lines.append('  ' * indent + f'{k}:')
                lines.extend(print_param_tree(v, prefix=prefix+k+'.', indent=indent+1))
            else:
                shape = tuple(v.shape) if hasattr(v, 'shape') else repr(v)
                dtype = str(v.dtype) if hasattr(v, 'dtype') else '?'
                lines.append('  ' * indent + f'{k}: shape={shape}, dtype={dtype}')
    return lines

param_lines = print_param_tree(params)
for line in param_lines:
    print(line)

# Build shape dict for JSON serialization
def build_shape_dict(tree):
    if isinstance(tree, dict):
        return {k: build_shape_dict(v) for k, v in sorted(tree.items())}
    else:
        return {'shape': list(tree.shape), 'dtype': str(tree.dtype)}

shape_dict = build_shape_dict(params)

# Count parameters
total_params = sum(x.size for x in jax.tree.leaves(params))
print(f"\nTotal parameters: {total_params:,}")

# ── 3. Tokenizer info ─────────────────────────────────────────────────────────
print("\n=== Tokenizer ===")
tok = get_tokenizer()
print(f"Class: {type(tok).__name__}")
print(f"vocab_size: {tok.vocab_size}")
print(f"pad_token_id: {tok.pad_token_id}")
print(f"eos_token_id: {tok.eos_token_id}")
print(f"bos_token_id: {tok.bos_token_id}")
print(f"tool_call_token_id: {tok.tool_call_token_id}")
print(f"tools_token_id: {tok.tools_token_id}")
print(f"attrs: {[a for a in dir(tok) if not a.startswith('_')]}")

# Test encode/decode
sample_text = "set a 5 min timer"
encoded = tok.encode(sample_text)
decoded = tok.decode(encoded)
print(f"\nExample encode('{sample_text}'):")
print(f"  -> {encoded}")
print(f"  -> decode -> '{decoded}'")

# Get vocab size from SentencePiece directly
sp = tok.sp
print(f"\nSentencePiece GetPieceSize(): {sp.GetPieceSize()}")
print("Model type: BPE (confirmed in train_tokenizer source)")
print("byte_fallback: True (confirmed in train_tokenizer source)")
print("normalization_rule_name: identity (no Unicode normalization)")

# ── 4. Prompt format capture ──────────────────────────────────────────────────
print("\n=== Prompt Format Capture ===")

query = "set a 5 min timer"
tools = [{"name": "set_timer", "description": "Set a timer.", "parameters": {"time_human": {"type": "string", "description": "duration", "required": True}}}]
tools_json = json.dumps(tools)

captured_inputs = []
original_encode = tok.encode

def capture_encode(text, *a, **kw):
    captured_inputs.append(('encode', text))
    return original_encode(text, *a, **kw)

tok.encode = capture_encode

# Call _build_encoder_input directly
enc_tokens = _build_encoder_input(tok, query, tools_json)
print(f"Captured encode calls: {captured_inputs}")
print(f"Encoder token IDs: {enc_tokens}")
print(f"Encoder token count: {len(enc_tokens)}")

# Decode the encoder input to see the full prompt
decoded_enc = tok.decode(enc_tokens)
print(f"Decoded encoder input: {repr(decoded_enc)}")

# Also decode each segment
tok.encode = original_encode  # restore
q_toks = tok.encode(query)
t_toks = tok.encode(tools_json)
print(f"\nQuery tokens: {q_toks} -> '{tok.decode(q_toks)}'")
print(f"tools_token_id (separator): {TOOLS_ID}")
print(f"Tools tokens: {t_toks[:20]}... -> '{tok.decode(t_toks[:20])}'")

# ── 5. Write notes files ──────────────────────────────────────────────────────

NOTES_DIR = os.path.join(os.path.dirname(__file__), '..', 'notes')
os.makedirs(NOTES_DIR, exist_ok=True)

# ── 5a. needle-internals.md ───────────────────────────────────────────────────

def fmt_shape_dict(d, indent=0):
    """Format shape dict as markdown nested list."""
    lines = []
    prefix = '  ' * indent
    for k, v in sorted(d.items()):
        if isinstance(v, dict) and 'shape' not in v:
            lines.append(f"{prefix}- **{k}**")
            lines.extend(fmt_shape_dict(v, indent+1))
        else:
            shape = v.get('shape', '?')
            dtype = v.get('dtype', '?')
            lines.append(f"{prefix}  - `{k}`: shape={shape}, dtype={dtype}")
    return lines

needle_internals = f"""# Needle (Cactus) Flax Model Internals

## Model: `SimpleAttentionNetwork`

**Architecture:** Encoder–decoder transformer with shared embeddings, tied output projection, RoPE, bfloat16.

### `TransformerConfig` Defaults

| Field | Value | Notes |
|---|---|---|
| vocab_size | {config.vocab_size} | SentencePiece BPE vocab |
| d_model | {config.d_model} | Hidden dimension |
| num_heads | {config.num_heads} | Attention heads |
| num_kv_heads | {config.num_kv_heads} | GQA key/value heads |
| num_encoder_layers | {config.num_encoder_layers} | Encoder depth |
| num_decoder_layers | {config.num_decoder_layers} | Decoder depth |
| d_ff | {config.d_ff} | FFN hidden dim (unused by default) |
| max_seq_len | {config.max_seq_len} | |
| pad_token_id | {config.pad_token_id} | |
| rope_theta | {config.rope_theta} | RoPE base frequency |
| dtype | {config.dtype} | |
| activation | {config.activation} | dual-ReLU gate (drelu) |
| num_memory_slots | {config.num_memory_slots} | |
| dropout_rate | {config.dropout_rate} | |
| contrastive_dim | {config.contrastive_dim} | Contrastive projection output dim |
| no_feedforward | {config.no_feedforward} | **True → FFN blocks are SKIPPED** |

**Total parameters:** {total_params:,}

**Important:** `no_feedforward=True` by default, so `FeedForward` / `gate_proj` / `up_proj` / `down_proj` are NOT present in the saved param tree. The `ffn_gate` scalar is also absent. This is a pure attention-only transformer.

---

## Layer-by-Layer Architecture

### Encoder (`Encoder` module)

Uses `nn.scan` over `num_encoder_layers={config.num_encoder_layers}` layers. Each `EncoderBlock` contains:

1. **`attn_gate`**: scalar sigmoid gate `()` on attention residual
2. **Pre-norm**: `ZCRMSNorm` (scale initialized to 0, applied as `(1+γ) * x / RMS(x)`)
3. **`self_attn`** (`MultiHeadAttention`):
   - `q_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `k_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
   - `v_proj`: Dense `(d_model, num_kv_heads * head_dim)` = `({config.d_model}, {config.num_kv_heads * (config.d_model // config.num_heads)})`
   - `out_proj`: Dense `(d_model, d_model)` = `({config.d_model}, {config.d_model})`
   - `q_norm`, `k_norm`: `ZCRMSNorm` on Q and K (pre-RoPE)
4. Residual: `x = x_in + gate * attn_out`
5. (FFN block **skipped** because `no_feedforward=True`)
6. **`final_norm`**: `ZCRMSNorm` after all layers (on encoder output)

### Decoder (`Decoder` module)

Uses `nn.scan` over `num_decoder_layers={config.num_decoder_layers}` layers. Each `DecoderBlock` contains:

1. **`self_attn_gate`**: scalar on causal self-attention residual
2. Pre-norm + **`self_attn`** (causal self-attention, same structure as encoder)
3. Residual: `x = x_in + self_gate * self_attn_out`
4. **`cross_attn_gate`**: scalar on cross-attention residual
5. Pre-norm + **`cross_attn`** (cross-attention: Q from decoder, K/V from encoder)
6. Residual: `x = x_in + cross_gate * cross_attn_out`
7. (FFN block **skipped**)
8. Final `ZCRMSNorm` after all layers

### Top-level `SimpleAttentionNetwork` params

- **`embedding`**: `nn.Embed` kernel `(vocab_size, d_model)` = `({config.vocab_size}, {config.d_model})`
  Also used as tied output: `logits = hidden @ embedding.T`
- **`log_temp`**: scalar `()`, the contrastive temperature
- **`contrastive_hidden`**: Dense `(d_model, d_model//4)` = `({config.d_model}, {config.d_model//4})`
- **`contrastive_proj`**: Dense `(d_model//4, contrastive_dim)` = `({config.d_model//4}, {config.contrastive_dim})`, no bias

---

## `nn.scan` Parameter Layout

Because Flax `nn.scan` with `variable_axes={{"params": 0}}` is used, each scanned layer stack is stored as a **single** parameter array with an extra leading dimension of size `num_layers`. So for encoder: `layers.attn_gate` has shape `(num_encoder_layers, ...)` = `({config.num_encoder_layers}, ...)`.

**This is critical for Task 2C:** When mapping Flax → PyTorch weights, you must index `params['encoder']['layers']['...'][layer_idx]` to get per-layer weights.
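
A minimal sketch of that per-layer indexing, using plain nested lists as a stand-in for the stacked arrays so it runs without JAX. The helper name `slice_layer`, the toy sizes, and the `kernel` leaf name are illustrative assumptions, not values read from a checkpoint:

```python
def slice_layer(tree, layer_idx):
    # Walk the nested dict and take index layer_idx along the leading
    # (scan) axis of every leaf.
    if isinstance(tree, dict):
        return dict((k, slice_layer(v, layer_idx)) for k, v in tree.items())
    return tree[layer_idx]

# Toy stand-in for params['encoder']['layers'] (hypothetical sizes).
num_layers, d_model = 3, 4
stacked = dict(
    attn_gate=[0.0] * num_layers,
    self_attn=dict(
        q_proj=dict(
            kernel=[[[float(n)] * d_model for _ in range(d_model)]
                    for n in range(num_layers)]
        )
    ),
)

layer1 = slice_layer(stacked, 1)
kernel = layer1["self_attn"]["q_proj"]["kernel"]
print(len(kernel), len(kernel[0]))  # rows, cols of one layer's kernel
```

With real Flax params the leaves are arrays, so the same `tree[layer_idx]` indexing drops the leading layer axis.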

---

## Full Flax `params` Tree (shapes)

```json
{json.dumps(shape_dict, indent=2)}
```

---

## `SimpleAttentionNetwork.__call__` Source

```python
{inspect.getsource(SimpleAttentionNetwork.__call__)}
```

## `EncoderBlock.__call__` Source

```python
{inspect.getsource(EncoderBlock.__call__)}
```

## `DecoderBlock.__call__` Source

```python
{inspect.getsource(DecoderBlock.__call__)}
```

## `MultiHeadAttention.__call__` Source

```python
{inspect.getsource(MultiHeadAttention.__call__)}
```
"""

with open(os.path.join(NOTES_DIR, 'needle-internals.md'), 'w', encoding='utf-8') as f:
    f.write(needle_internals)
print(f"\nWrote needle-internals.md ({len(needle_internals.splitlines())} lines)")

# ── 5b. tokenizer-internals.md ────────────────────────────────────────────────

tokenizer_internals = f"""# Needle Tokenizer Internals

## Tokenizer Class

**Class:** `NeedleTokenizer` (defined in `external/needle/needle/dataset/tokenizer.py`)

**Backend:** SentencePiece BPE. The model file is downloaded from HuggingFace on first use.

**HuggingFace repo:** `Cactus-Compute/needle-tokenizer` (dataset repo)

**Local model file path:** `{TOKENIZER_PREFIX}.model`
(relative to the external/needle repo root: `needle/tokenizer/needle.model`)

---

## Tokenizer Properties

| Property | Value | Source |
|---|---|---|
| `vocab_size` | {tok.vocab_size} | `sp.GetPieceSize()` |
| `pad_token_id` | {PAD_ID} | hardcoded = 0 |
| `eos_token_id` | {EOS_ID} | hardcoded = 1 |
| `bos_token_id` | {BOS_ID} | hardcoded = 2 |
| `unk_id` | {UNK_ID} | hardcoded = 3 |
| `tool_call_token_id` | {TOOL_CALL_ID} | hardcoded = 4, symbol `<tool_call>` |
| `tools_token_id` | {TOOLS_ID} | hardcoded = 5, symbol `<tools>` |

---

## Encode Algorithm

**Algorithm:** SentencePiece **BPE** (Byte-Pair Encoding)

**Configured with:**
- `model_type="bpe"`
- `byte_fallback=True`: unknown bytes are represented as `<0xNN>` hex tokens
- `normalization_rule_name="identity"`: **no Unicode normalization applied**
- `user_defined_symbols=["<tool_call>", "<tools>"]`: fixed IDs 4 and 5

**Encode call in NeedleTokenizer:**
```python
def encode(self, text):
    return self.sp.Encode(text, out_type=int)
```
Returns a Python `list[int]`. Does NOT add BOS/EOS automatically.

**Decode call:**
```python
def decode(self, ids):
    return self.sp.Decode(list(ids))
```

---

## Available Attributes on `NeedleTokenizer`

```
{[a for a in dir(tok) if not a.startswith('_')]}
```

The underlying SentencePiece processor is accessible as `tok.sp`.

---

## SentencePiece Model File

The `.model` file is a serialized SentencePiece protobuf. For Task 5 (port tokenizer to JS/ONNX/etc.), you need either:
1. The `.model` file directly (loadable by the `sentencepiece` library)
2. Export the vocabulary + merge rules via `tok.sp.GetPieceSize()`, `tok.sp.IdToPiece(i)`, etc.

To export the full vocabulary:
```python
vocab = [(i, tok.sp.IdToPiece(i), tok.sp.GetScore(i)) for i in range(tok.vocab_size)]
```

---

## Example Encoding

```
encode("{sample_text}")
  -> {encoded}
  -> decode -> "{decoded}"
```
"""

with open(os.path.join(NOTES_DIR, 'tokenizer-internals.md'), 'w', encoding='utf-8') as f:
    f.write(tokenizer_internals)
print(f"Wrote tokenizer-internals.md ({len(tokenizer_internals.splitlines())} lines)")

# ── 5c. prompt-format.md ─────────────────────────────────────────────────────

# Re-run the builder through the capturing wrapper, then restore the original encode
tok.encode = capture_encode
enc_tokens2 = _build_encoder_input(tok, query, tools_json)
tok.encode = original_encode  # restore

# Reconstruct what each segment contains
q_toks_raw = tok.encode(query)
t_toks_raw = tok.encode(tools_json)

MAX_ENC = DEFAULT_MAX_ENC_LEN  # 1024
max_query = MAX_ENC - 2
q_toks_trunc = q_toks_raw[:max_query]
remaining = MAX_ENC - len(q_toks_trunc) - 1
t_toks_trunc = t_toks_raw[:remaining]

final_tokens = q_toks_trunc + [TOOLS_ID] + t_toks_trunc
decoded_full = tok.decode(final_tokens)

# Decoder starts with EOS token (1), then model predicts <tool_call> (4) first
prompt_format = f"""# Needle Prompt Format

## Overview

Needle uses an **encoder–decoder** architecture. The query + tools are encoded by the encoder; the decoder generates the tool-call JSON.

## Encoder Input Format

```
[query_tokens..., <tools>(id={TOOLS_ID}), tools_tokens...]
```

**Built by `_build_encoder_input(tokenizer, query, tools, max_enc_len={DEFAULT_MAX_ENC_LEN})`:**

```python
{inspect.getsource(_build_encoder_input)}
```

### Truncation Rules

1. Query is truncated to `max_enc_len - 2 = {MAX_ENC - 2}` tokens
2. Tools are truncated to `max_enc_len - len(q_toks) - 1` tokens (fills remaining space)
3. The `<tools>` separator token (id={TOOLS_ID}) is always inserted between query and tools
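
The three rules above can be sketched on plain ID lists as follows. This is an illustration under assumptions, not the library code: `build_encoder_ids` is a hypothetical name, the token IDs are made up, and the separator ID is hardcoded to the value listed in the tokenizer notes.

```python
TOOLS_ID = 5  # separator id as listed in the tokenizer notes

def build_encoder_ids(q_ids, t_ids, max_enc_len):
    # Rule 1: cap the query, leaving room for the separator and one tool token.
    q = q_ids[: max_enc_len - 2]
    # Rule 2: tools fill whatever is left after the query and the separator.
    remaining = max_enc_len - len(q) - 1
    # Rule 3: the separator always sits between the two segments.
    return q + [TOOLS_ID] + t_ids[:remaining]

# A short query keeps all of its tokens; a long one is cut to max_enc_len - 2.
short = build_encoder_ids(list(range(100, 110)), list(range(200, 220)), 16)
long_q = build_encoder_ids(list(range(100, 120)), list(range(200, 220)), 16)
print(len(short), short.index(TOOLS_ID))
print(len(long_q), long_q.index(TOOLS_ID))
```

In both cases the result never exceeds `max_enc_len`, and at least one tools token survives even when the query is truncated.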

### Example: `query="set a 5 min timer"`, `tools=[{{"name": "set_timer", ...}}]`

**Encode calls made by `_build_encoder_input`:**
1. `tokenizer.encode("{query}")` → {q_toks_raw} ({len(q_toks_raw)} tokens)
2. `tokenizer.encode('{tools_json[:80]}...')` → {t_toks_raw[:20]}... ({len(t_toks_raw)} tokens total)

**Final encoder token sequence (len={len(final_tokens)}):**
```
{final_tokens}
```

**Decoded encoder input string:**
```
{repr(decoded_full)}
```

**Structural breakdown:**
- Query tokens ({len(q_toks_trunc)} tokens): `{q_toks_trunc}`
- Separator token (1 token): `[{TOOLS_ID}]` → `<tools>`
- Tools tokens ({len(t_toks_trunc)} tokens): `{t_toks_trunc}`

---

## Decoder Input Format

The decoder is initialized with a single **EOS token** (id={EOS_ID}) as prefix:

```python
dec_buffer = jnp.full((1, max_gen_len), pad_id, dtype=jnp.int32)
dec_buffer = dec_buffer.at[0, 0].set(eos_id)   # prefix = [EOS]
```

At each step, the model predicts the next token autoregressively.

**Expected output sequence:** `<tool_call>(id={TOOL_CALL_ID}) + answer_json_tokens + EOS(id={EOS_ID})`

The `<tool_call>` prefix is stripped from the decoded output string in `generate()`.
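
As a rough illustration of the expected output layout, the sketch below trims a generated ID sequence at the first EOS and drops a leading `<tool_call>` ID. It works on token IDs, whereas `generate()` strips the prefix from the decoded string, so treat it as an assumed variant; `extract_answer_ids` is a hypothetical helper and the IDs are hardcoded from the tokenizer notes.

```python
EOS_ID = 1        # as listed in the tokenizer notes
TOOL_CALL_ID = 4  # as listed in the tokenizer notes

def extract_answer_ids(generated):
    # `generated` holds the tokens produced after the initial [EOS] prefix.
    out = []
    for tok in generated:
        if tok == EOS_ID:
            break  # stop at the end-of-sequence marker
        out.append(tok)
    # Drop the leading <tool_call> marker if the model emitted it.
    if out and out[0] == TOOL_CALL_ID:
        out = out[1:]
    return out

print(extract_answer_ids([4, 10, 11, 1, 0, 0]))
```

The remaining IDs are the answer JSON tokens, ready to pass to the tokenizer's `decode`.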

---

## Tool Name Normalization

Before building encoder input, tool names are converted to `snake_case` via `normalize_tools()`:
- `camelCase` → `camel_case`
- `PascalCase` → `pascal_case`
- `dot.notation` → `dot_notation`
- Non-alphanumeric chars → underscores

After decoding, `restore_tool_names()` maps back to original names.
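
A minimal sketch of the conversions listed above, plus the reverse lookup needed to restore names. The helpers `to_snake_case` and `build_name_map` are hypothetical approximations; the real `normalize_tools()` and `restore_tool_names()` may treat edge cases such as acronyms differently.

```python
import re

def to_snake_case(name):
    # Insert an underscore at each lower-to-upper boundary (camelCase, PascalCase).
    s = re.sub("([a-z0-9])([A-Z])", lambda m: m.group(1) + "_" + m.group(2), name)
    # Collapse every run of non-alphanumeric characters into one underscore.
    s = re.sub("[^A-Za-z0-9]+", "_", s)
    return s.strip("_").lower()

def build_name_map(original_names):
    # normalized name -> original name, used to restore names after decoding
    return dict((to_snake_case(n), n) for n in original_names)

name_map = build_name_map(["setTimer", "GetWeather", "calendar.create"])
print(sorted(name_map))
```

Looking up a decoded tool name in `name_map` recovers the caller's original spelling.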

---

## Full `generate()` Source

```python
{inspect.getsource(generate)}
```
"""

with open(os.path.join(NOTES_DIR, 'prompt-format.md'), 'w', encoding='utf-8') as f:
    f.write(prompt_format)
print(f"Wrote prompt-format.md ({len(prompt_format.splitlines())} lines)")

print("\n=== All notes written successfully ===")