parameter-golf-v2 / ANALYSIS.md
# Parameter Golf β€” Competitive Analysis & Implementation Plan
## Executive Summary
Your original BitNet b1.58 submission has strong fundamentals (depth recurrence, Muon optimizer, sliding window eval) but is missing **5 critical techniques** used by the top entries. I've built an improved script incorporating all of them. Here's the path to the top of the leaderboard.
## Gap Analysis: Your Submission vs. Top (1.0810 BPB)
| Technique | Your Code | Top Entries | Expected BPB Impact |
|---|---|---|---|
| **Tokenizer** | SP1024 (1024 vocab) | SP8192 (8192 vocab) | **~5-8% BPB improvement** |
| **Quantization** | Ternary QAT (BitNet) | Int6 GPTQ + SDClip | **~2-3% better quality** |
| **Architecture** | Serial residual | Parallel residual (PAF) | **~1-2% BPB** |
| **TTT** | None | Score-first TTT at eval | **~1-3% BPB** |
| **QK-Gain** | 1.5 | 5.0-5.25 | **~0.5-1%** |
| **Weight Decay** | 0.02 | 0.09 | **~0.3-0.5%** |
| **Warmdown** | 1200 steps | 3500 steps | **~0.2-0.5%** |
| **EMA** | None | EMA (decay 0.999) | **~0.3-0.5%** |
| **Depth Recurrence** | 4 unique Γ— 6 loops βœ… | 3 unique Γ— 8 loops | Similar |
| **Residual mixing** | x0 anchor βœ… | x0 anchor βœ… | Same |
**Estimated total improvement: from your ~1.15-1.20 BPB down to ~1.08-1.10 BPB**
## Key Technique Explanations
### 1. SP8192 Vocabulary (Biggest Single Win)
The BPB metric is **bits per byte**, not bits per token. With a 1024-token vocabulary, each token covers fewer bytes on average, so the model spends more predictions (and accumulates more loss) on the same text. With 8192 tokens, each prediction is amortized over more bytes: per-token loss rises somewhat, but it is divided by a larger byte count, which lowers BPB.
The top entries all use SP8192 because:
- 8Γ— more vocab = better text coverage = lower BPB
- The embedding table (8192 Γ— 768 = 6.3M params) still fits in budget with int6/int8 quantization
- Domain-tuned on FineWeb data for optimal subword splits
**Action**: Use the SP8192 tokenizer and matching data shards provided in the competition repo.
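To make the arithmetic concrete, here is a minimal sketch of the nats-per-token → bits-per-byte conversion. The per-token losses and bytes-per-token figures below are illustrative assumptions, not measurements:

```python
import math

def bits_per_byte(loss_nats_per_token: float, bytes_per_token: float) -> float:
    """Convert mean cross-entropy (nats/token) into bits per byte."""
    return loss_nats_per_token / math.log(2) / bytes_per_token

# Illustrative numbers only: a bigger vocab raises per-token loss,
# but spreads it over more bytes, and can still win on BPB.
bpb_1024 = bits_per_byte(2.60, 3.4)  # SP1024: ~3.4 bytes/token (assumed)
bpb_8192 = bits_per_byte(3.20, 4.4)  # SP8192: ~4.4 bytes/token (assumed)
print(f"SP1024: {bpb_1024:.3f} BPB, SP8192: {bpb_8192:.3f} BPB")
```

The point is that BPB divides by bytes, so a vocabulary that packs more bytes into each token gets a denominator boost that can outweigh the harder per-token prediction task.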
### 2. Int6 GPTQ + SDClip (vs. Your BitNet Ternary)
Your ternary QAT approach is creative but suboptimal here because:
- **Ternary ≈ 1.58 bits/param** vs. int6 = 6 bits/param → each parameter carries ~3.8× more raw information
- In a 16MB budget, int6 fits ~29M params vs ternary fitting ~64M params
- BUT: 29M params at int6 quality > 64M params at ternary quality for language modeling
- The effective "information per parameter" is much higher with int6
**SDClip** (std-based clipping): Before quantizing, clip each row's values to `mean Β± 2.5*std`. This removes outliers that would otherwise dominate the quantization grid, dramatically reducing quantization error.
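A minimal numpy sketch of the idea, using plain per-row round-to-nearest int6 in place of the full GPTQ step so the clipping effect is isolated (function names here are illustrative, not from the competition code):

```python
import numpy as np

def sdclip(w, n_std=2.5):
    """Clip each row to mean +/- n_std * std so outliers do not
    stretch the quantization grid."""
    mu = w.mean(axis=1, keepdims=True)
    sd = w.std(axis=1, keepdims=True)
    return np.clip(w, mu - n_std * sd, mu + n_std * sd)

def rtn_int6(w):
    """Symmetric per-row round-to-nearest int6 (levels -31..31),
    returned dequantized so the error is easy to measure."""
    scale = np.abs(w).max(axis=1, keepdims=True) / 31.0
    return np.round(w / scale) * scale

rng = np.random.default_rng(0)
w = rng.normal(size=(4, 512))
w[:, 0] = 40.0                                        # inject outliers
err_plain = np.abs(rtn_int6(w) - w)[:, 1:].mean()     # grid set by outlier
err_clip = np.abs(rtn_int6(sdclip(w)) - w)[:, 1:].mean()
print(err_clip < err_plain)  # clipping shrinks error on the bulk of weights
```

Without clipping, the single outlier forces a coarse grid (step ≈ 40/31) over the whole row; after SDClip the grid is set by the bulk of the distribution.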
### 3. Parallel Residuals (PAF Architecture)
Standard transformer:
```python
x = x + attn(norm1(x)) # step 1
x = x + mlp(norm2(x)) # step 2 (uses updated x)
```
Parallel (GPT-J/PaLM style):
```python
h = norm(x) # single norm
x = x + attn(h) + mlp(h) # both use same input
```
Benefits: saves one norm (fewer params), both branches see the same input (wider information flow), empirically ~1-2% better BPB at small scale.
### 4. Score-First TTT (Test-Time Training)
At evaluation time only (free compute!):
1. Process tokens in chunks of 64
2. For each chunk: **score first** (compute loss with current weights)
3. Then **update**: gradient step on MLP.proj weights using reconstruction loss
4. Next chunk benefits from the updated weights
This is "legal" because it's strictly causal β€” predictions for chunk i only depend on chunks 0..i-1. The competition allows arbitrary test-time compute.
### 5. Higher QK-Gain (5.25 vs 1.5)
At small model dimensions (768), the QK dot products are too small to create sharp attention patterns. QK-Gain multiplies the queries by a learned scalar, effectively controlling the "temperature" of attention. 5.25 is the empirically optimal value found by the top entries.
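A tiny illustration of the effect. In the real models the gain is a learned scalar; here it is fixed, and `q`, `k` are random stand-ins:

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def attention_scores(q, k, qk_gain=5.25):
    """Dot-product attention logits with a QK-gain multiplier on the
    queries. Higher gain lowers the effective softmax temperature."""
    return (qk_gain * q) @ k.T / np.sqrt(q.shape[-1])

rng = np.random.default_rng(0)
q, k = rng.normal(size=(1, 64)), rng.normal(size=(16, 64))
p_low = softmax(attention_scores(q, k, qk_gain=1.5))
p_high = softmax(attention_scores(q, k, qk_gain=5.25))
print(p_low.max(), p_high.max())  # higher gain -> sharper attention pattern
```

Multiplying the queries by 5.25 instead of 1.5 scales every logit by the same factor, which concentrates the attention distribution on its argmax.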
## Architecture Config
```
Vocab: 8192 (SP8192 BPE)
Model dim: 768
Query heads: 12
KV heads: 4 (GQA)
MLP multiplier: 4Γ— (hidden = 3072)
Unique layers: 3
Train recurrence: 8 (24 effective layers)
Eval recurrence: 16 (48 effective layers)
QK-Gain: 5.25
Logit softcap: 30.0
RoPE base: 10000
Unique params: ~25.2M
Compressed: ~13-14 MB (int6 + zlib)
```
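A back-of-envelope check of the parameter and size figures above, assuming tied input/output embeddings, no biases, and head_dim = 768/12 = 64 (all assumptions on my part, not confirmed layer shapes):

```python
# Unique parameter count for the config above.
d, vocab, n_layers = 768, 8192, 3
kv_dim = 4 * (d // 12)                 # 4 KV heads x head_dim 64 = 256 (GQA)
attn = d * d + 2 * d * kv_dim + d * d  # Wq, Wk, Wv, Wo (no biases assumed)
mlp = 2 * d * (4 * d)                  # up + down projections, 4x multiplier
embed = vocab * d                      # embedding table, tied with LM head
total = embed + n_layers * (attn + mlp)
print(total)                           # 25_165_824, i.e. ~25.2M
print(total * 6 / 8 / 1e6)             # ~18.9 MB raw at int6, before zlib
```

Raw int6 storage lands near 18.9 MB, so the quoted ~13-14 MB after zlib implies roughly a 0.7× compression ratio on the quantized weights.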
## Hyperparameter Choices
| Param | Value | Rationale |
|---|---|---|
| Matrix LR | 0.04 | Muon with NS5 orthogonalization |
| Embed LR | 0.05 | Adam for embeddings |
| Scalar LR | 0.04 | Adam for norms/scales |
| Weight Decay | 0.09 | High WD regularizes small models (PR #1285 showed 0.09 > 0.04) |
| Warmdown | 3500 steps | Longer warmdown preserves learned representations (PR #374) |
| EMA start | 40% through training | Only average later checkpoints |
| EMA decay | 0.999 | Standard for small models |
| TTT LR | 0.01 | Inner loop learning rate for test-time adaptation |
| TTT chunk | 64 | Score-first TTT chunk size |
| SDClip n_std | 2.5 | Standard deviation clipping range |
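The EMA rows in the table can be sketched as a single update rule: track the raw weights until 40% of training, then average with decay 0.999 (a minimal sketch, not the competition implementation):

```python
import numpy as np

def ema_update(ema, w, step, total_steps, decay=0.999, start_frac=0.4):
    """Update an EMA copy of the weights in place. Before start_frac of
    training the EMA simply mirrors the raw weights, so early noisy
    checkpoints never enter the average."""
    if step < start_frac * total_steps:
        ema[...] = w
    else:
        ema *= decay
        ema += (1 - decay) * w
    return ema
```

At evaluation time the EMA copy, not the raw weights, is quantized and submitted.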
## How to Run
```bash
# On 8Γ—H100 (competition standard):
torchrun --standalone --nproc_per_node=8 train_final.py
# Override any hyperparameter via env vars:
V=8192 D=768 NUL=3 NR=8 QKG=5.25 MWD=0.09 torchrun --standalone --nproc_per_node=8 train_final.py
# Use SP8192 data (must match tokenizer):
DATA_PATH=./data/datasets/fineweb10B_sp8192 TOKENIZER_PATH=./data/tokenizers/fineweb_8192_bpe.model torchrun ...
```
## Next Steps to Push Further
### Immediate (before submission)
1. **Verify SP8192 data availability** in the competition repo
2. **Run on 8Γ—H100** β€” the 10-minute wall clock starts here
3. **Tune TTT LR** β€” sweep [0.001, 0.005, 0.01, 0.02, 0.05] on val set
### If time permits (iterative improvements)
4. **Self-generated GPTQ calibration data**: Generate text from the trained model, use as calibration data for GPTQ quantization (PR #1019 technique)
5. **XSA (Cross-Sequence Attention)**: On the last 3-4 layers, attend across sequence boundaries in the sliding window β€” effectively increasing context length
6. **Progressive recurrence**: Start with fewer recurrences and increase during training β€” warm up the depth gradually
7. **Hessian-aware SDClip**: Use actual Hessian diagonal (from Fisher information) to set per-row clip ranges, instead of simple std-based clipping
8. **BigramHash embeddings**: Hash bigrams to augment the embedding table β€” more input information for free
### Longer-term experiments
9. **Larger vocabulary**: If budget allows after int6 compression, try SP16384
10. **Mixed quantization**: Int4 for some layers, int6 for critical ones (first/last layers)
11. **Depth-conditional scaling**: Different attn_scale/mlp_scale for each recurrence step (not just each unique layer)