Upload folder using huggingface_hub
- .gitignore +12 -0
- .pytest_cache/.gitignore +2 -0
- .pytest_cache/CACHEDIR.TAG +4 -0
- .pytest_cache/README.md +8 -0
- .pytest_cache/v/cache/nodeids +11 -0
- .pytest_cache/v/cache/stepwise +1 -0
- README.md +255 -0
- chimera/__init__.py +32 -0
- chimera/config.py +65 -0
- chimera/evolution.py +301 -0
- chimera/inference.py +359 -0
- chimera/layers.py +485 -0
- chimera/looping.py +73 -0
- chimera/model.py +378 -0
- chimera/moe.py +102 -0
- chimera/multimodal.py +136 -0
- chimera/quantization.py +508 -0
- chimera/tokenizer.py +160 -0
- config.json +638 -0
- gguf_import.py +905 -0
- inference.py +302 -0
- pyproject.toml +10 -0
- tests/test_chimera.py +115 -0
- tests/test_config.py +8 -0
- train.py +632 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
__pycache__/
*.py[cod]
.pytest_cache/
.venv/
.deps/
chimera_output/
chimera_imported/
*.pt
*.gguf
.ternary_build*
.kernel_build
.simd_build
.pytest_cache/.gitignore ADDED
@@ -0,0 +1,2 @@
# Created by pytest automatically.
*
.pytest_cache/CACHEDIR.TAG ADDED
@@ -0,0 +1,4 @@
Signature: 8a477f597d28d172789f06886806bc55
# This file is a cache directory tag created by pytest.
# For information about cache directory tags, see:
# https://bford.info/cachedir/spec.html
.pytest_cache/README.md ADDED
@@ -0,0 +1,8 @@
# pytest cache directory #

This directory contains data from the pytest's cache plugin,
which provides the `--lf` and `--ff` options, as well as the `cache` fixture.

**Do not** commit this to version control.

See [the docs](https://docs.pytest.org/en/stable/how-to/cache.html) for more information.
.pytest_cache/v/cache/nodeids ADDED
@@ -0,0 +1,11 @@
[
  "tests/test_chimera.py::test_bitlinear_dense_cache_consistency",
  "tests/test_chimera.py::test_bitlinear_forward_backward_and_packed",
  "tests/test_chimera.py::test_model_forward_loss_and_generate_shape",
  "tests/test_chimera.py::test_model_kv_cache_consistency",
  "tests/test_chimera.py::test_moe_and_span_bank_shapes",
  "tests/test_chimera.py::test_pack_unpack_roundtrip",
  "tests/test_chimera.py::test_ternarize_weight_basic",
  "tests/test_chimera.py::test_tokenizer_fallback_roundtrip",
  "tests/test_config.py::test_config_scaling_without_torch_runtime"
]
.pytest_cache/v/cache/stepwise ADDED
@@ -0,0 +1 @@
[]
README.md ADDED
@@ -0,0 +1,255 @@
# Chimera 5.1 — True 1.58-bit Ternary CPU Compute (v5.1.3)

100% faithful implementation of the Chimera 5.1 config. All 15 architectural components implemented in pure PyTorch, with **true 1.58-bit ternary computation** on CPU.

**Key breakthrough**: Ternary weights `{-1, 0, 1}` are stored in 2-bit packed format (4 weights per byte), giving **16× memory reduction** and enabling zero-multiply forward/backward paths via custom C++ kernels with OpenMP.

**Tokenizer**: splintr-rs (Rust) — o200k_base vocab (200,073 tokens, OpenAI o1/o3).

---
## v5.1.4 — Real CPU Fast Path Audit

Implemented after a full CPU hot-path audit:
- fixed the package/runtime mismatch (`chimera` imports now match the repository layout);
- added the missing sparse `MoELayer` with expert-grouped dispatch and `index_add_` accumulation;
- made C++ ternary extensions lazy-loaded instead of compiling at import time;
- vectorized BitLinear AbsMean scaling and removed Python repack loops;
- cached causal/triangular masks reused by recurrent layers during generation and MeZO;
- reduced no-grad Gated DeltaNet clone churn while keeping autograd-safe behavior for AdamW;
- made MeZO CPU training use cached per-step directions and fast Rademacher perturbations by default;
- deduplicated tied embedding/lm-head parameters in MeZO updates;
- added deterministic greedy inference fast path (`--temperature 0`) and optional bounded context (`--max_context`).

Recommended CPU modes:
```bash
# Ultra-efficient CPU fine-tuning
OMP_NUM_THREADS=$(nproc) python train.py \
    --scale tiny --seq_len 64 --max_steps 10 \
    --optimizer mezo --mezo_direction rademacher \
    --batch_size 2 --grad_accum 1 --no-bf16 --num_workers 0

# Lowest-latency deterministic CPU serving
python inference.py \
    --checkpoint chimera_output/final/model.pt \
    --prompt "Once upon a time" --temperature 0 --top_k 1 \
    --max_context 256 --max_tokens 128
```

---
## v5.1.3 — Fix Illegal Instruction Crash

**Fixed**: Removed `-march=native` from C++ JIT compilation flags. This flag caused `Illegal instruction (core dumped)` on CPUs with different instruction sets than the build machine. The C++ kernel now uses **runtime CPUID detection** to select AVX-512/AVX2 paths, while compilation remains portable.

**If you get `Illegal instruction`:**
```bash
rm -rf .ternary_build .ternary_build_v2   # Clear old cache
python train.py ...                       # Rebuild with portable flags
```

---
## v5.1.2 — True Ternary Compute

| Component | Implementation | Memory | Speed (training) | Speed (inference) |
|---|---|---|---|---|
| **Weight storage** | 2-bit packed uint8 (4 w/byte) | **16× smaller** vs FP32 | — | — |
| **Forward path** | C++ unpack + MKL BLAS | 94% less bandwidth | ~0.5-0.7× (unpack overhead) | ~1.0-1.2× (amortized) |
| **Backward grad_x** | Same ternary kernel | — | Included in above | — |
| **Backward grad_w** | FP32 outer product (STE req) | — | standard | — |
| **MeZO optimizer** | Sparse perturbation (skip ~33% zeros) | 2× model size | **No backward pass** | — |
| **MeZO sparse update** | C++ kernel, perturb only non-zero weights | — | ~1.5× faster per step | — |

**Note**: Ternary compute is **memory-optimized**, not raw compute-optimized. On CPU, MKL BLAS for FP32 matmul is so optimized that ternary unpack+BLAS has ~30-50% overhead at small sizes. The win is:
- **16× less RAM** — models that don't fit in FP32 fit in ternary
- **16× less memory bandwidth** — weight loading from DRAM is the bottleneck for large models
- **MeZO eliminates backward** — no gradient through 28 layers of recurrences

### When Ternary Wins

| Scenario | FP32 | Ternary + MeZO | Winner |
|---|---|---|---|
| Model > L3 cache (e.g. 2B params) | 10GB, bandwidth-bound | 0.6GB, fits L3 | **Ternary** |
| Small model, fits L1 (e.g. 50M) | Fast BLAS | Unpack overhead | FP32 |
| CPU without AVX-512/AMX | Standard | Same path | Tie |
| CPU with VNNI/AMX + `_int_mm` | Slow INT8 path | Native INT8 matmul | **Ternary** |
| Fine-tuning with limited RAM | OOM | Fits | **Ternary** |

---
## Architecture (28 layers, 4 types)

```
Layer pattern: GD XM GD TM GD XM GD SK × 3.5
  GD = Gated DeltaNet (14 layers) — arxiv:2412.06464
  XM = xLSTM mLSTM     (7 layers) — arxiv:2405.04517
  TM = Titans MAC      (4 layers) — arxiv:2501.00663
  SK = TSP Span Knot   (3 layers)
```

All linear layers use **BitLinear** (ternary 1.58-bit) with per-group AbsMean scaling.

---
## Components

| Module | File | Status |
|--------|------|--------|
| **splintr Tokenizer** (o200k_base, 200K vocab, Rust-backed) | `tokenizer.py` | ✅ |
| **BitNet 1.58 QAT** (2-bit packed, C++ unpack kernel, STE, N:M 2:4) | `quantization.py` | ✅ v5.1.3 |
| **Ternary SIMD Kernels** (AVX2 unpack, OpenMP, sparse MeZO) | `ternary_simd.py` | ✅ v5.1.3 |
| **Gated DeltaNet** (α/β gates, chunkwise parallel) | `layers.py` | ✅ |
| **xLSTM mLSTM** (parallelized, no timestep loop) | `layers.py` | ✅ v5.1.1 |
| **Titans MAC** (parallelized, no timestep loop) | `layers.py` | ✅ v5.1.1 |
| **TSP Span Knot** (vectorized Hamming) | `layers.py` | ✅ v5.1.1 |
| **Parcae Looping** (deterministic, checkpoint-safe) | `looping.py` | ✅ v5.1.1 |
| **MoE** (sort-based dispatch, 16 experts, 2 active) | `moe.py` | ✅ v5.1.1 |
| **Span Inference** (bank, STree verifier, certificates) | `inference.py` | ✅ |
| **Grammar FST** (9 modes, hard/soft constraints, fused penalty) | `inference.py` | ✅ |
| **Entropy Valve** (3 levels, causal predictor router) | `inference.py` | ✅ |
| **Debt Ledger** (8 obligation types, pressure scoring) | `inference.py` | ✅ |
| **Braid State** (continuous + fast + semantic sketch + entity + grammar) | `inference.py` | ✅ |
| **Self-Evolution** (TTT, semantic memory HDC, episodic cases, meta-guidelines) | `evolution.py` | ✅ |
| **Multimodal** (vision + audio encoders, ternary, checkpointed) | `multimodal.py` | ✅ |
| **Full Model** (Chimera51ForCausalLM) | `model.py` | ✅ |

---
## Quick Start

```bash
pip install torch datasets transformers einops splintr-rs
```
### Training

```bash
# Quick test (MeZO, tiny, 10 steps)
OMP_NUM_THREADS=$(nproc) python train.py \
    --scale tiny --seq_len 64 --max_steps 10 \
    --optimizer mezo --batch_size 2 --grad_accum 1 \
    --lr 1e-3 --no-bf16 --num_workers 0 --log_every 1

# Real training run (MeZO + compile, small, 50K steps)
OMP_NUM_THREADS=$(nproc) python train.py \
    --scale small --seq_len 256 --max_steps 50000 \
    --optimizer mezo --batch_size 2 --grad_accum 4 \
    --lr 1e-3 --warmup 2000 --compile \
    --num_workers 0 --save_every 5000
```
### Inference (text generation)

```bash
# Generate from the final checkpoint
python inference.py \
    --checkpoint chimera_output/final/model.pt \
    --prompt "Once upon a time" \
    --max_tokens 200 \
    --temperature 0.8 --top_p 0.9 --top_k 50

# With torch.compile to speed up inference
python inference.py \
    --checkpoint chimera_output/final/model.pt \
    --prompt "Once upon a time" \
    --max_tokens 200 \
    --temperature 0.8 --top_p 0.9 --top_k 50 \
    --compile

# With BF16 (if supported by your CPU)
python inference.py \
    --checkpoint chimera_output/final/model.pt \
    --prompt "Once upon a time" \
    --max_tokens 200 \
    --bf16 --compile
```

---
## Training Modes

### MeZO (Recommended for CPU)
- **No backward pass** — eliminates all gradient computation through complex recurrences
- **Memory = 2× model size** — no activations, no gradients, no optimizer states
- **Ternary-aware sparse perturbation** — skips ~33% zero-weight positions in BitLinear layers
- Best for fine-tuning; requires ~32× more steps for pretraining
- Combined with BF16 autocast for maximum CPU throughput

### AdamW (Standard backprop)
- Full gradient computation with gradient checkpointing
- Ternary forward/backward via C++ kernel (2-bit packed → float → BLAS)
- BFloat16 autocast for forward pass
- Weight decay differentiated (no decay for norms, biases, embeddings)
- Best when gradient quality matters (pretraining from scratch)

---
## Ternary Compute Details

### Weight Packing
```
2 bits per weight: 00→0, 01→+1, 10→-1
4 weights per uint8 byte
Per-row scale α = mean(|W|) per group
```
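To make the packing concrete, here is a minimal pure-PyTorch sketch of the 2-bit scheme above. It is illustrative only — the function names are hypothetical and this is not the C++ kernel path in `quantization.py`; only the bit codes follow the table above.

```python
import torch

def pack_ternary_sketch(w: torch.Tensor) -> torch.Tensor:
    """Pack a 1-D ternary {-1, 0, 1} int8 tensor into uint8, 4 weights per byte."""
    assert w.numel() % 4 == 0
    # Codes: 00 -> 0, 01 -> +1, 10 -> -1
    codes = torch.where(w == -1, torch.tensor(2, dtype=torch.uint8), w.to(torch.uint8))
    codes = codes.view(-1, 4)
    shifts = torch.arange(4, dtype=torch.uint8) * 2   # bit positions 0, 2, 4, 6
    return (codes << shifts).sum(dim=-1).to(torch.uint8)

def unpack_ternary_sketch(packed: torch.Tensor) -> torch.Tensor:
    """Inverse: uint8 bytes back to int8 ternary values."""
    shifts = torch.arange(4, dtype=torch.uint8) * 2
    codes = (packed.unsqueeze(-1) >> shifts) & 0b11    # [..., 4]
    out = torch.zeros_like(codes, dtype=torch.int8)
    out[codes == 1] = 1
    out[codes == 2] = -1
    return out.view(-1)

w = torch.tensor([1, 0, -1, 1, -1, 0, 0, 1], dtype=torch.int8)
assert torch.equal(unpack_ternary_sketch(pack_ternary_sketch(w)), w)
```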
### Forward Pass
```
1. Quantize latent FP32 → ternary int8 {-1,0,1}
2. Pack to 2-bit uint8 (4× compression)
3. Unpack to float32 buffer (pre-allocated, reused)
4. MKL BLAS matmul (x @ W^T)
```
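The same four steps in a plain PyTorch sketch. The real path packs to 2 bits and unpacks through the C++ kernel into a pre-allocated buffer; this stand-in simply casts, and the per-row AbsMean scale follows the α above. Names are illustrative.

```python
import torch

def ternary_forward_sketch(x: torch.Tensor, w_fp32: torch.Tensor) -> torch.Tensor:
    """x: [B, in], w_fp32: [out, in] latent weights. Illustrative only."""
    # 1. Quantize latent FP32 weights to ternary with a per-row AbsMean scale.
    scale = w_fp32.abs().mean(dim=1, keepdim=True).clamp_min(1e-8)
    w_ternary = (w_fp32 / scale).round().clamp(-1, 1).to(torch.int8)
    # 2./3. The real kernel packs to 2-bit uint8 and unpacks into a reused
    #       float32 buffer; here we just cast back to float.
    w_unpacked = w_ternary.to(torch.float32)
    # 4. BLAS matmul, re-applying the per-row scale.
    return x @ (w_unpacked * scale).t()
```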
### MeZO Sparse Perturbation (C++)
```
For each weight position:
  If packed_bits == 0: SKIP (no perturbation, no update)
  Else: generate z ~ N(0,1), perturb by ε·z
```
This saves **33% of perturbation operations** since ~1/3 of ternary weights are zero.
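For reference, a dense MeZO step (no backward pass) can be sketched as below; the sparse C++ variant above additionally skips positions whose 2-bit code is zero. `loss_fn` and the function name are illustrative, not the `train.py` API.

```python
import torch

def mezo_step_sketch(params, loss_fn, eps=1e-3, lr=1e-3, seed=0):
    """One zeroth-order (MeZO) update: two perturbed forward passes, no backward."""
    gen = torch.Generator().manual_seed(seed)
    zs = [torch.randn(p.shape, generator=gen) for p in params]

    with torch.no_grad():
        for p, z in zip(params, zs):          # theta + eps * z
            p.add_(eps * z)
        loss_plus = loss_fn()
        for p, z in zip(params, zs):          # theta - eps * z
            p.sub_(2 * eps * z)
        loss_minus = loss_fn()
        for p, z in zip(params, zs):          # back to theta
            p.add_(eps * z)

        grad_est = (loss_plus - loss_minus) / (2 * eps)   # projected gradient (scalar)
        for p, z in zip(params, zs):
            p.sub_(lr * grad_est * z)
```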
### C++ Kernel Features
- OpenMP parallel over output dimensions
- Pre-allocated unpack buffer (zero allocation in hot loop)
- Deterministic LCG RNG per thread (reproducible across runs)
- Falls back to pure PyTorch if C++ compilation fails

---
## Files

```
chimera/
  __init__.py      — Package exports
  quantization.py  — BitLinear (2-bit packed, C++ kernel, STE, N:M 2:4)
  ternary_simd.py  — AVX2/AVX-512 SIMD unpack kernels (optional)
  layers.py        — GatedDeltaNet, MLSTMLayer (PARALLEL), TitansMACLayer (PARALLEL), TSPSpanKnotLayer
  moe.py           — MoELayer (sort-based dispatch), NoAuxMoEGate
  looping.py       — ParcaeLoopController (deterministic, checkpoint-safe)
  inference.py     — SpanBank, STree, Grammar, EntropyValve, DebtLedger, BraidState
  evolution.py     — TTT, SemanticMemory (vectorized HDC), EpisodicCases, MetaGuidelines
  multimodal.py    — VisionEncoder, AudioEncoder (checkpointed)
  tokenizer.py     — ChimeraTokenizer (splintr Rust wrapper, o200k_base vocab)
  model.py         — Chimera51ForCausalLM (compile + checkpoint + bf16 support)
config.json        — Chimera 5.1 config (honest P3 section)
train.py           — Training script (MeZO + AdamW, ternary, bf16, compile, IPEX)
inference.py       — Inference script (checkpoint loading, autoregressive generation)
```

---

## References

37 papers indexed in `config.json` under `§`. Key ones:
- [Gated DeltaNet](https://arxiv.org/abs/2412.06464) — NVIDIA
- [xLSTM](https://arxiv.org/abs/2405.04517) — NXAI/JKU
- [Titans](https://arxiv.org/abs/2501.00663) — Google
- [Parcae](https://arxiv.org/abs/2604.12946) — Stanford/Together
- [BitNet b1.58](https://arxiv.org/abs/2402.17764) — Microsoft
- [Bitnet.cpp](https://arxiv.org/abs/2502.11880) — MSRA (ELUT kernel)
- [T-MAC](https://arxiv.org/abs/2407.00088) — MSRA (LUT inference)
- [MeZO](https://arxiv.org/abs/2305.17333) — Princeton (CPU training optimizer)
- [DeepSeek MoE routing](https://arxiv.org/abs/2408.15664) — DeepSeek
- [In-Place TTT](https://arxiv.org/abs/2604.06169) — ByteDance
chimera/__init__.py ADDED
@@ -0,0 +1,32 @@
"""Chimera 5.2 — CPU-first causal LM with ternary 1.58-bit weights."""

from .config import load_config, scale_config, tiny_config

__version__ = "5.2.0"

__all__ = [
    "load_config", "scale_config", "tiny_config",
    "Chimera51ForCausalLM", "Chimera51Block", "expand_layer_pattern",
    "BitLinear", "RMSNorm", "pack_ternary", "unpack_ternary",
    "ternarize_weight", "_quantize_weights_ternary", "apply_2_4_sparsity_",
    "enable_native_kernel", "native_kernel_available",
    "ChimeraTokenizer",
]


# Lazy public surface — keeps ``import chimera`` cheap (no torch import until
# the user actually touches a model class).
def __getattr__(name):
    if name in {"Chimera51ForCausalLM", "Chimera51Block", "expand_layer_pattern"}:
        from .model import Chimera51ForCausalLM, Chimera51Block, expand_layer_pattern
        return locals()[name]
    if name in {"BitLinear", "RMSNorm", "pack_ternary", "unpack_ternary",
                "ternarize_weight", "_quantize_weights_ternary",
                "apply_2_4_sparsity_", "enable_native_kernel",
                "native_kernel_available"}:
        from . import quantization as _q
        return getattr(_q, name)
    if name == "ChimeraTokenizer":
        from .tokenizer import ChimeraTokenizer
        return ChimeraTokenizer
    raise AttributeError(name)
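A minimal sketch of what the lazy surface buys (assuming `chimera/quantization.py` defines `BitLinear`, as the exports above indicate):

```python
import chimera

print(chimera.__version__)        # "5.2.0" — config-only import, torch not loaded yet
BitLinear = chimera.BitLinear     # first attribute touch triggers the lazy quantization import
```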
chimera/config.py ADDED
@@ -0,0 +1,65 @@
from __future__ import annotations

import copy
import json
from pathlib import Path
from typing import Any, Mapping


def load_config(path: str | Path | None = None, overrides: Mapping[str, Any] | None = None) -> dict:
    """Load a Chimera JSON config and apply shallow dotted-key overrides."""
    if path is None:
        path = Path(__file__).resolve().parents[1] / "config.json"
    with open(path, "r", encoding="utf-8") as fh:
        cfg = json.load(fh)
    if overrides:
        cfg = copy.deepcopy(cfg)
        for key, value in overrides.items():
            cur = cfg
            parts = str(key).split(".")
            for part in parts[:-1]:
                cur = cur.setdefault(part, {})
            cur[parts[-1]] = value
    return cfg


def scale_config(config: dict, scale: str = "base") -> dict:
    """Return a safe CPU-scaled copy while preserving feature flags.

    The uploaded Chimera config targets a large model. These presets keep all
    modules wired but resize dimensions so tests/fine-tuning fit commodity CPU
    memory (including 16 GB DDR5 machines).
    """
    cfg = copy.deepcopy(config)
    presets = {
        "nano": dict(hidden_size=128, intermediate_size=344, num_hidden_layers=4, num_heads=4, head_dim=32, vocab_size=min(cfg.get("vocab_size", 32000), 8192)),
        "tiny": dict(hidden_size=256, intermediate_size=688, num_hidden_layers=6, num_heads=4, head_dim=64, vocab_size=min(cfg.get("vocab_size", 32000), 32768)),
        "small": dict(hidden_size=512, intermediate_size=1376, num_hidden_layers=8, num_heads=8, head_dim=64, vocab_size=min(cfg.get("vocab_size", 32000), 65536)),
        "base": {},
    }
    if scale not in presets:
        raise ValueError(f"unknown scale {scale!r}; choose {sorted(presets)}")
    cfg.update(presets[scale])
    h = cfg["hidden_size"]
    cfg["num_heads"] = max(1, min(cfg.get("num_heads", 4), h // max(1, cfg.get("head_dim", 64))))
    cfg["head_dim"] = h // cfg["num_heads"]
    cfg.setdefault("backbone", {}).setdefault("moe", {})
    moe = cfg["backbone"]["moe"]
    moe["layers"] = [i for i in moe.get("layers", []) if i < cfg["num_hidden_layers"]]
    moe["n_routed_experts"] = min(int(moe.get("n_routed_experts", 4)), 4 if scale in {"nano", "tiny"} else 8)
    moe["n_shared_experts"] = min(int(moe.get("n_shared_experts", 1)), 1)
    moe["num_experts_per_tok"] = min(int(moe.get("num_experts_per_tok", 2)), moe["n_routed_experts"])
    moe["moe_intermediate_size"] = min(int(moe.get("moe_intermediate_size", h * 2)), max(64, cfg["intermediate_size"] // 2))
    loop = cfg.setdefault("looping", {})
    if cfg["num_hidden_layers"] < 8:
        loop["enabled"] = False
    else:
        loop["prelude"] = [0, min(1, cfg["num_hidden_layers"] - 1)]
        loop["loop"] = [2, max(2, cfg["num_hidden_layers"] - 3)]
        loop["coda"] = [max(0, cfg["num_hidden_layers"] - 2), cfg["num_hidden_layers"] - 1]
    cfg.setdefault("span_inference", {})["enabled"] = bool(cfg.get("span_inference", {}).get("enabled", True))
    return cfg


def tiny_config() -> dict:
    return scale_config(load_config(), "nano")
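A quick usage sketch of the presets above (the printed values follow from the `tiny` preset as defined in this file; `config.json` ships with the upload):

```python
from chimera.config import load_config, scale_config

cfg = load_config()                    # full Chimera config from config.json
tiny = scale_config(cfg, "tiny")       # CPU-sized copy, feature flags preserved
print(tiny["hidden_size"], tiny["num_heads"], tiny["head_dim"])   # 256 4 64
```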
chimera/evolution.py ADDED
@@ -0,0 +1,301 @@
"""
Chimera 5.2 — self-evolution components (CPU-first, slim).

Mostly the same surface as before; key fixes:
* :func:`SemanticMemory.majority_bundle` is now a single vectorised
  unpack/sum/repack — the previous Python-level ``for bit in range(8)``
  loop dominated TTT updates.
* :func:`SemanticMemory.hamming_distance` reuses the same vectorised
  unpack and runs in fp32 *only* on the bit dimension (D bytes × 8 bits)
  so memory stays bounded.
* Episodic / meta banks share the same query/projection helpers.
"""

from __future__ import annotations

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


_BIT_SHIFTS = torch.arange(8, dtype=torch.uint8)


def _unpack_bits(x: torch.Tensor) -> torch.Tensor:
    """Unpack uint8 ``[..., D]`` into ``[..., D, 8]`` of {0,1} fp32."""
    shifts = _BIT_SHIFTS.to(x.device)
    return ((x.unsqueeze(-1) >> shifts) & 1).to(torch.float32)


def _pack_bits(b: torch.Tensor) -> torch.Tensor:
    """Inverse of :func:`_unpack_bits`."""
    shifts = _BIT_SHIFTS.to(b.device).to(torch.uint8)
    return (b.to(torch.uint8) << shifts).sum(dim=-1).to(torch.uint8)


# ---------------------------------------------------------------------------
# SemanticMemory (HDC)
# ---------------------------------------------------------------------------

class SemanticMemory(nn.Module):
    """Hyperdimensional binary memory with vectorised ops."""

    def __init__(self, config: dict):
        super().__init__()
        self.vector_bits = int(config.get("vector_bits", 8192))
        self.capacity = int(config.get("capacity", 200_000))
        self.pool_fixed = bool(config.get("pool_size_fixed", True))
        self.lsh_tables = int(config.get("lsh_tables", 64))
        self.lsh_bits = int(config.get("lsh_bits_per_table", 14))

        actual_cap = max(1, min(self.capacity, 50_000))
        n_bytes = self.vector_bits // 8
        self.register_buffer("memory", torch.zeros(actual_cap, n_bytes, dtype=torch.uint8))
        self.register_buffer("count", torch.zeros((), dtype=torch.long))
        self.register_buffer("access_counts", torch.zeros(actual_cap, dtype=torch.long))

        self.lsh_proj = nn.Linear(n_bytes, self.lsh_tables * self.lsh_bits, bias=False)

    @staticmethod
    def xor_bind(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        return torch.bitwise_xor(a, b)

    @staticmethod
    def xor_unbind(bound: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
        return torch.bitwise_xor(bound, key)

    @staticmethod
    def majority_bundle(hvs: torch.Tensor) -> torch.Tensor:
        """Vectorised majority rule over a batch of hypervectors.

        ``hvs`` is ``[N, D]`` uint8; returns ``[D]`` uint8.
        """
        if hvs.numel() == 0:
            return torch.zeros(hvs.shape[-1] if hvs.ndim else 0, dtype=torch.uint8,
                               device=hvs.device)
        bits = _unpack_bits(hvs)  # [N, D, 8] fp32 in {0, 1}
        majority = (bits.sum(dim=0) > (hvs.size(0) / 2.0)).to(torch.uint8)
        return _pack_bits(majority)  # [D]

    @staticmethod
    def hamming_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Batched Hamming distance over uint8 byte tensors."""
        xor = torch.bitwise_xor(a, b)
        bits = _unpack_bits(xor)  # [..., D, 8]
        return bits.sum(dim=(-1, -2))

    def query(self, query_vec: torch.Tensor, top_k: int = 16
              ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
        c = int(self.count.item())
        if c == 0:
            return None, None
        dists = self.hamming_distance(query_vec.unsqueeze(-2),
                                      self.memory[:c].unsqueeze(0))
        k = min(top_k, c)
        values, indices = dists.topk(k, dim=-1, largest=False)
        with torch.no_grad():
            self.access_counts[indices.reshape(-1)] += 1
        return values, indices

    @torch.no_grad()
    def store(self, vec: torch.Tensor, surprise_magnitude: float = 0.0) -> None:
        vec_flat = vec.detach().reshape(-1)[:self.memory.size(1)].to(torch.uint8)
        cap = self.memory.size(0)
        if self.pool_fixed and int(self.count.item()) >= cap:
            min_idx = int(self.access_counts[:cap].argmin().item())
            self.memory[min_idx] = vec_flat
            self.access_counts[min_idx] = 0
        else:
            idx = int(self.count.item())
            if idx < cap:
                self.memory[idx] = vec_flat
                self.count.add_(1)


# ---------------------------------------------------------------------------
# In-place test-time training
# ---------------------------------------------------------------------------

class InPlaceTTT(nn.Module):
    """Single-step in-place TTT update."""

    def __init__(self, config: dict, hidden_size: int):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.target_layers = list(config.get("target_layers", [13, 23]))
        self.inner_lr = float(config.get("inner_lr", 3e-4))
        self.momentum = float(config.get("momentum", 0.9))
        self.chunk_size = int(config.get("chunk_size", 1024))
        self.reset_decay = float(config.get("reset_decay", 0.95))
        self.delta_clip = float(config.get("delta_clip", 1e-5))

        self.conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=5,
                                padding=4, groups=hidden_size, bias=False)
        nn.init.zeros_(self.conv1d.weight)
        self.w_target = nn.Parameter(torch.eye(hidden_size) * 0.01)

    def compute_update(self, x_raw: torch.Tensor, z: torch.Tensor,
                       w_down: torch.Tensor) -> torch.Tensor:
        # Causal depthwise convolution + small linear projection.
        T = x_raw.shape[1]
        x_shifted = self.conv1d(x_raw.transpose(1, 2))[:, :, :T].transpose(1, 2)
        v_hat = x_shifted @ self.w_target
        delta = v_hat.transpose(-2, -1) @ z
        norm = delta.norm()
        if float(norm.item()) > self.delta_clip:
            delta = delta * (self.delta_clip / norm)
        return delta

    def apply_update(self, w_down: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
        return w_down + self.inner_lr * delta

    def forward(self, x_raw: torch.Tensor, z: torch.Tensor,
                w_down: torch.Tensor) -> torch.Tensor:
        if not self.enabled:
            return w_down
        return self.apply_update(w_down, self.compute_update(x_raw, z, w_down))


# ---------------------------------------------------------------------------
# Episodic case memory
# ---------------------------------------------------------------------------

class EpisodicCaseMemory(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.max_cases = int(config.get("max_cases", 4096))
        self.case_bytes = int(config.get("case_bytes", 2048))
        case_dim = max(8, min(self.case_bytes, 512))
        self.case_dim = case_dim
        self.register_buffer("cases", torch.zeros(self.max_cases, case_dim))
        self.register_buffer("weights", torch.ones(self.max_cases))
        self.register_buffer("count", torch.zeros((), dtype=torch.long))
        self.query_proj = nn.Linear(case_dim, case_dim, bias=False)
        self.ema_decay = 0.99

    def retrieve(self, query: torch.Tensor, top_k: int = 5):
        c = int(self.count.item())
        if c == 0:
            return None
        q = self.query_proj(query)
        q_flat = F.normalize(q.reshape(-1, q.shape[-1]), dim=-1)
        c_norm = F.normalize(self.cases[:c], dim=-1)
        sims = torch.matmul(q_flat, c_norm.t()) * self.weights[:c].unsqueeze(0)
        k = min(top_k, c)
        scores, indices = sims.topk(k, dim=-1)
        return self.cases[indices], scores

    @torch.no_grad()
    def store(self, case_vec: torch.Tensor, outcome: float = 1.0) -> None:
        idx = int(self.count.item()) % self.max_cases
        self.cases[idx] = case_vec.detach().reshape(-1)[:self.case_dim]
        self.weights[idx] = float(outcome)
        if int(self.count.item()) < self.max_cases:
            self.count.add_(1)

    @torch.no_grad()
    def update_weight(self, idx: int, outcome: float) -> None:
        self.weights[idx] = self.ema_decay * self.weights[idx] + (1.0 - self.ema_decay) * outcome


# ---------------------------------------------------------------------------
# Meta-guideline bank
# ---------------------------------------------------------------------------

class MetaGuidelineBank(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.max_guidelines = int(config.get("max", 256))
        bits = int(config.get("bits", 8192))
        self.register_buffer("guidelines",
                             torch.zeros(self.max_guidelines, bits // 8, dtype=torch.uint8))
        self.register_buffer("count", torch.zeros((), dtype=torch.long))

    @torch.no_grad()
    def add_guideline(self, vec: torch.Tensor) -> None:
        idx = int(self.count.item()) % self.max_guidelines
        self.guidelines[idx] = vec.detach()
        if int(self.count.item()) < self.max_guidelines:
            self.count.add_(1)

    def query(self, query_vec: torch.Tensor, top_k: int = 5):
        c = int(self.count.item())
        if c == 0:
            return None
        dists = SemanticMemory.hamming_distance(
            query_vec.unsqueeze(-2), self.guidelines[:c].unsqueeze(0))
        k = min(top_k, c)
        return dists.topk(k, dim=-1, largest=False)


# ---------------------------------------------------------------------------
# Self-feedback / loop classifier
# ---------------------------------------------------------------------------

class SelfFeedback(nn.Module):
    def __init__(self, config: dict):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.confidence_threshold = float(config.get("confidence_threshold", 0.6))
        self.max_rounds = int(config.get("max_refinement_rounds", 1))

    def should_refine(self, confidence: float) -> bool:
        return self.enabled and confidence < self.confidence_threshold

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        return F.softmax(logits, dim=-1).amax(dim=-1).mean()


class LoopDepthClassifier(nn.Module):
    def __init__(self, config: dict, in_features: int = 256):
        super().__init__()
        self.enabled = bool(config.get("enabled", True))
        self.net = nn.Sequential(
            nn.Linear(in_features, in_features),
            nn.ReLU(inplace=True),
            nn.Linear(in_features, 6),
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return self.net(features).argmax(dim=-1) + 1


# ---------------------------------------------------------------------------
# Self-evolution engine
# ---------------------------------------------------------------------------

class SelfEvolutionEngine(nn.Module):
    def __init__(self, config: dict, hidden_size: int):
        super().__init__()
        t1 = config.get("tier1", {})
        t2 = config.get("tier2", {})
        t3 = config.get("tier3", {})
        self.ttt = InPlaceTTT(t1.get("ttt", {}), hidden_size)
        self.semantic_memory = SemanticMemory(config.get("_semantic_memory_config", {}))
        self.episodic = EpisodicCaseMemory(t2.get("episodic_cases", {}))
        self.meta_guidelines = MetaGuidelineBank(t2.get("meta_guidelines", {}))
        self.self_feedback = SelfFeedback(t2.get("self_feedback", {}))
        self.loop_classifier = LoopDepthClassifier(t3.get("loop_depth_learning", {}))
        safety = config.get("safety", {})
        self.freeze_threshold = float(safety.get("freeze_threshold", 0.05))
        self.frozen = False

    def check_safety(self, cert_failure_rate: float) -> bool:
        if cert_failure_rate > self.freeze_threshold:
            self.frozen = True
        return self.frozen


__all__ = [
    "SemanticMemory",
    "InPlaceTTT",
    "EpisodicCaseMemory",
    "MetaGuidelineBank",
    "SelfFeedback",
    "LoopDepthClassifier",
    "SelfEvolutionEngine",
]
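A small usage sketch of the HDC primitives defined above (XOR bind/unbind, majority bundling, Hamming distance), assuming the package imports as laid out in this upload:

```python
import torch
from chimera.evolution import SemanticMemory

# Two random 8192-bit hypervectors, stored as 1024 uint8 bytes each.
a = torch.randint(0, 256, (1024,), dtype=torch.uint8)
b = torch.randint(0, 256, (1024,), dtype=torch.uint8)

bound = SemanticMemory.xor_bind(a, b)                   # XOR binding
assert torch.equal(SemanticMemory.xor_unbind(bound, b), a)

bundle = SemanticMemory.majority_bundle(torch.stack([a, b, bound]))   # [1024] uint8
dist = SemanticMemory.hamming_distance(a, b)            # bit count between a and b
print(bundle.shape, float(dist))
```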
chimera/inference.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 — inference-time helpers (CPU-first).
|
| 3 |
+
|
| 4 |
+
This module collects all the lightweight components that run *after* the
|
| 5 |
+
trunk produces hidden states:
|
| 6 |
+
|
| 7 |
+
* :class:`SpanBank` — vectorised semantic memory.
|
| 8 |
+
* :class:`STreeVerifier` — tiny scoring head.
|
| 9 |
+
* :class:`CertificateVerifier`— per-token risk projection.
|
| 10 |
+
* :class:`SpanInferenceEngine`— glue + risk gating.
|
| 11 |
+
* :class:`GrammarFST` — additive constraint penalty.
|
| 12 |
+
* :class:`EntropyValve` — adaptive loop-count router.
|
| 13 |
+
* :class:`DebtLedger` — bias logits to honour outstanding obligations.
|
| 14 |
+
* :class:`BraidState` — runtime scratch state.
|
| 15 |
+
|
| 16 |
+
Optimisations vs the previous draft:
|
| 17 |
+
* Grammar / Debt are *true* identity ops when their constraints are empty
|
| 18 |
+
(no tensors allocated, no projections run) — this matters because they
|
| 19 |
+
sit on the per-token logits path.
|
| 20 |
+
* Entropy is computed on the slice the model actually scores (not the
|
| 21 |
+
full 200K-vocab logits): the model passes us the last-token logits.
|
| 22 |
+
* Everything that does not depend on the input shape is allocated once.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import math
|
| 28 |
+
from typing import Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# SpanBank
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
class SpanBank(nn.Module):
|
| 40 |
+
"""Cosine-similarity span memory used for retrieval-augmented inference."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, max_entries: int = 524288, max_tokens: int = 64,
|
| 43 |
+
hidden_size: int = 2560, memory_mb: int = 384):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.max_entries = int(max_entries)
|
| 46 |
+
self.max_tokens = int(max_tokens)
|
| 47 |
+
self.hidden_size = int(hidden_size)
|
| 48 |
+
proj_dim = max(8, hidden_size // 4)
|
| 49 |
+
# Estimate entries the user can actually afford in RAM.
|
| 50 |
+
budget = int(memory_mb) * 1024 * 1024
|
| 51 |
+
per_entry = (proj_dim + hidden_size) * 4 + 8
|
| 52 |
+
actual = max(1, min(self.max_entries, budget // per_entry))
|
| 53 |
+
self.proj_dim = proj_dim
|
| 54 |
+
self.register_buffer("bank_keys", torch.zeros(actual, proj_dim))
|
| 55 |
+
self.register_buffer("bank_values", torch.zeros(actual, hidden_size))
|
| 56 |
+
self.register_buffer("bank_lengths", torch.zeros(actual, dtype=torch.long))
|
| 57 |
+
self.register_buffer("bank_count", torch.zeros((), dtype=torch.long))
|
| 58 |
+
self.semantic_proj = nn.Linear(hidden_size, proj_dim, bias=False)
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def capacity(self) -> int:
|
| 62 |
+
return int(self.bank_keys.size(0))
|
| 63 |
+
|
| 64 |
+
def query_scores(self, hidden_state: torch.Tensor, top_k: int = 64
|
| 65 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 66 |
+
c = int(self.bank_count.item())
|
| 67 |
+
if c == 0:
|
| 68 |
+
return None, None
|
| 69 |
+
q = F.normalize(self.semantic_proj(hidden_state), dim=-1)
|
| 70 |
+
keys = F.normalize(self.bank_keys[:c], dim=-1)
|
| 71 |
+
sims = torch.matmul(q, keys.t())
|
| 72 |
+
k = min(top_k, c)
|
| 73 |
+
return torch.topk(sims, k, dim=-1)
|
| 74 |
+
|
| 75 |
+
def query(self, hidden_state: torch.Tensor, top_k: int = 64) -> torch.Tensor:
|
| 76 |
+
scores, indices = self.query_scores(hidden_state, top_k=top_k)
|
| 77 |
+
if scores is None:
|
| 78 |
+
return torch.zeros_like(hidden_state)
|
| 79 |
+
c = int(self.bank_count.item())
|
| 80 |
+
values = self.bank_values[:c][indices]
|
| 81 |
+
weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
|
| 82 |
+
return (values * weights).sum(dim=-2)
|
| 83 |
+
|
| 84 |
+
@torch.no_grad()
|
| 85 |
+
def add(self, keys: torch.Tensor, values: torch.Tensor) -> None:
|
| 86 |
+
"""Bulk insert; vectorised, falls back to overwriting once full."""
|
| 87 |
+
keys = keys.detach().reshape(-1, self.hidden_size)
|
| 88 |
+
values = values.detach().reshape(-1, self.hidden_size)
|
| 89 |
+
n = keys.size(0)
|
| 90 |
+
if n == 0:
|
| 91 |
+
return
|
| 92 |
+
cap = self.capacity
|
| 93 |
+
start = int(self.bank_count.item())
|
| 94 |
+
end = min(start + n, cap)
|
| 95 |
+
write = end - start
|
| 96 |
+
if write > 0:
|
| 97 |
+
self.bank_keys[start:end] = self.semantic_proj(keys[:write])
|
| 98 |
+
self.bank_values[start:end] = values[:write]
|
| 99 |
+
self.bank_lengths[start:end] = 1
|
| 100 |
+
self.bank_count.add_(write)
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def add_span(self, hidden_state: torch.Tensor, length: int,
|
| 104 |
+
value: Optional[torch.Tensor] = None) -> None:
|
| 105 |
+
h = hidden_state.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
|
| 106 |
+
v = (value.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
|
| 107 |
+
if value is not None else h)
|
| 108 |
+
self.add(h, v)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# Verifiers
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class STreeVerifier(nn.Module):
|
| 116 |
+
"""Tiny scoring head used by speculative-tree decoding."""
|
| 117 |
+
|
| 118 |
+
def __init__(self, tree_width: int = 4, tree_depth: int = 5,
|
| 119 |
+
hidden_size: int = 256):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.tree_width = int(tree_width)
|
| 122 |
+
self.tree_depth = int(tree_depth)
|
| 123 |
+
h_mid = max(8, hidden_size // 4)
|
| 124 |
+
self.score_net = nn.Sequential(
|
| 125 |
+
nn.Linear(hidden_size, h_mid),
|
| 126 |
+
nn.ReLU(inplace=True),
|
| 127 |
+
nn.Linear(h_mid, 1),
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 131 |
+
return torch.sigmoid(self.score_net(hidden_states)).squeeze(-1)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class CertificateVerifier(nn.Module):
|
| 135 |
+
"""Per-token certificate fields (semantic / grammar / entity / risk)."""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size: int):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.semantic_proj = nn.Linear(hidden_size, 64, bias=False)
|
| 140 |
+
self.grammar_proj = nn.Linear(hidden_size, 16, bias=False)
|
| 141 |
+
self.entity_proj = nn.Linear(hidden_size, 32, bias=False)
|
| 142 |
+
self.boundary_proj = nn.Linear(hidden_size, 1, bias=False)
|
| 143 |
+
self.risk_proj = nn.Linear(hidden_size, 1, bias=False)
|
| 144 |
+
|
| 145 |
+
def forward(self, hidden_states: torch.Tensor) -> dict:
|
| 146 |
+
return {
|
| 147 |
+
"semantic": self.semantic_proj(hidden_states),
|
| 148 |
+
"grammar": self.grammar_proj(hidden_states),
|
| 149 |
+
"entity": self.entity_proj(hidden_states),
|
| 150 |
+
"boundary": self.boundary_proj(hidden_states),
|
| 151 |
+
"risk": torch.sigmoid(self.risk_proj(hidden_states)),
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class SpanInferenceEngine(nn.Module):
|
| 156 |
+
"""Risk-gated post-trunk hidden-state modulation."""
|
| 157 |
+
|
| 158 |
+
def __init__(self, hidden_size: int, config: dict):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.enabled = bool(config.get("enabled", True))
|
| 161 |
+
self.hidden_size = int(hidden_size)
|
| 162 |
+
self.span_bank = SpanBank(
|
| 163 |
+
max_entries=config.get("bank_entries", 524288),
|
| 164 |
+
max_tokens=config.get("bank_max_tokens", 64),
|
| 165 |
+
hidden_size=self.hidden_size,
|
| 166 |
+
memory_mb=config.get("bank_memory_mb", 384),
|
| 167 |
+
)
|
| 168 |
+
self.tree_verifier = STreeVerifier(
|
| 169 |
+
tree_width=config.get("tree_verify", {}).get("tree_width", 4),
|
| 170 |
+
tree_depth=config.get("tree_verify", {}).get("tree_depth", 5),
|
| 171 |
+
hidden_size=self.hidden_size,
|
| 172 |
+
)
|
| 173 |
+
self.certificate = CertificateVerifier(self.hidden_size)
|
| 174 |
+
self.scoring_weights = nn.Parameter(
|
| 175 |
+
torch.tensor(config.get("scoring_weights_fast", [1.0, 0.8, 0.5, 0.7, 0.35])))
|
| 176 |
+
self.fallback_threshold = float(config.get("fallback_below_acceptance", 0.5))
|
| 177 |
+
# Single fused gate from concatenated hidden + risk.
|
| 178 |
+
self.risk_gate = nn.Linear(self.hidden_size + 1, self.hidden_size, bias=False)
|
| 179 |
+
|
| 180 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 181 |
+
if not self.enabled:
|
| 182 |
+
return hidden_states
|
| 183 |
+
risk = torch.sigmoid(self.certificate.risk_proj(hidden_states))
|
| 184 |
+
gate_input = torch.cat([hidden_states, risk], dim=-1)
|
| 185 |
+
modulation = torch.sigmoid(self.risk_gate(gate_input))
|
| 186 |
+
return hidden_states * modulation
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
# Grammar FST — additive penalty (no-op when no constraints)
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
class GrammarFST(nn.Module):
|
| 194 |
+
"""Soft-constraint penalty on next-token logits.
|
| 195 |
+
|
| 196 |
+
*Identity* when ``enabled`` is false **or** there are no constraints –
|
| 197 |
+
no entropy computation, no projection allocations.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, config: dict):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.enabled = bool(config.get("enabled", True))
|
| 203 |
+
self.hard_constraints = list(config.get("hard_constraints", []))
|
| 204 |
+
self.soft_constraints = list(config.get("soft_constraints", []))
|
| 205 |
+
n_features = len(self.hard_constraints) + len(self.soft_constraints) + 1
|
| 206 |
+
self._n_hard = len(self.hard_constraints)
|
| 207 |
+
self._n_soft = len(self.soft_constraints)
|
| 208 |
+
self._n_features = n_features
|
| 209 |
+
self._is_noop = (not self.enabled) or n_features <= 1
|
| 210 |
+
self.constraint_proj = nn.Linear(n_features, 1, bias=True)
|
| 211 |
+
nn.init.normal_(self.constraint_proj.weight, std=0.01)
|
| 212 |
+
nn.init.zeros_(self.constraint_proj.bias)
|
| 213 |
+
|
| 214 |
+
def forward(self, logits: torch.Tensor, state=None) -> torch.Tensor:
|
| 215 |
+
if self._is_noop:
|
| 216 |
+
return logits
|
| 217 |
+
B, T, V = logits.shape
|
| 218 |
+
# Single log_softmax pass for entropy.
|
| 219 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 220 |
+
entropy = -(log_probs.exp() * log_probs).sum(-1) # [B, T]
|
| 221 |
+
features = logits.new_zeros(B, T, self._n_features)
|
| 222 |
+
features[..., 0] = entropy
|
| 223 |
+
if self._n_soft > 0 and T > 1:
|
| 224 |
+
cos = F.cosine_similarity(logits[:, 1:], logits[:, :-1], dim=-1)
|
| 225 |
+
features[:, 1:, self._n_hard] = cos.clamp_min(0.0)
|
| 226 |
+
penalty = self.constraint_proj(features) # [B, T, 1]
|
| 227 |
+
return logits + penalty
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
# Entropy valve
|
| 232 |
+
# ---------------------------------------------------------------------------
|
| 233 |
+
|
| 234 |
+
class EntropyValve(nn.Module):
|
| 235 |
+
"""Maps logits entropy → adaptive loop count for the looped trunk."""
|
| 236 |
+
|
| 237 |
+
def __init__(self, config: dict):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.enabled = bool(config.get("enabled", True))
|
| 240 |
+
self.threshold_bits = float(config.get("threshold_bits", 2.0))
|
| 241 |
+
self.levels = dict(config.get("levels", {
|
| 242 |
+
"low": {"loops": 1, "min_span": 8, "audit": 0.125},
|
| 243 |
+
"medium": {"loops": 2, "min_span": 4, "audit": 0.5},
|
| 244 |
+
"high": {"loops": 4, "min_span": 1, "audit": 1.0},
|
| 245 |
+
}))
|
| 246 |
+
self.router = nn.Sequential(nn.Linear(6, 32), nn.ReLU(inplace=True),
|
| 247 |
+
nn.Linear(32, 3))
|
| 248 |
+
self._inv_log2 = 1.0 / math.log(2.0)
|
| 249 |
+
|
| 250 |
+
def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
|
| 251 |
+
log_probs = F.log_softmax(logits.to(torch.float32), dim=-1)
|
| 252 |
+
return -(log_probs.exp() * log_probs).sum(dim=-1) * self._inv_log2
|
| 253 |
+
|
| 254 |
+
def get_level(self, entropy: torch.Tensor) -> str:
|
| 255 |
+
if not self.enabled:
|
| 256 |
+
return "medium"
|
| 257 |
+
mean_h = float(entropy.mean().item())
|
| 258 |
+
if mean_h < self.threshold_bits * 0.5:
|
| 259 |
+
return "low"
|
| 260 |
+
if mean_h < self.threshold_bits:
|
| 261 |
+
return "medium"
|
| 262 |
+
return "high"
|
| 263 |
+
|
| 264 |
+
def get_loop_count(self, logits: torch.Tensor) -> int:
|
| 265 |
+
if not self.enabled:
|
| 266 |
+
return self.levels.get("medium", {}).get("loops", 2)
|
| 267 |
+
level = self.get_level(self.compute_entropy(logits))
|
| 268 |
+
return self.levels.get(level, self.levels["medium"])["loops"]
|
| 269 |
+
|
| 270 |
+
def forward(self, logits: torch.Tensor):
|
| 271 |
+
entropy = self.compute_entropy(logits)
|
| 272 |
+
level = self.get_level(entropy)
|
| 273 |
+
return level, self.levels.get(level, self.levels["medium"])
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
# Debt ledger — additive bias (no-op when no obligations)
|
| 278 |
+
# ---------------------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
class DebtLedger(nn.Module):
|
| 281 |
+
def __init__(self, config: dict):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.enabled = bool(config.get("enabled", True))
|
| 284 |
+
self.obligations = list(config.get("obligations", []))
|
| 285 |
+
self.max_outstanding = int(config.get("max_outstanding", 64))
|
| 286 |
+
self.pressure_weight = float(config.get("pressure_weight", 0.3))
|
| 287 |
+
self.active_debts: list = []
|
| 288 |
+
self.debt_bias_scale = nn.Parameter(torch.tensor(0.5))
|
| 289 |
+
self.debt_proj = nn.Linear(1, 1, bias=True)
|
| 290 |
+
nn.init.ones_(self.debt_proj.weight)
|
| 291 |
+
nn.init.zeros_(self.debt_proj.bias)
|
| 292 |
+
|
| 293 |
+
def add_debt(self, debt_type: str) -> None:
|
| 294 |
+
if len(self.active_debts) < self.max_outstanding:
|
| 295 |
+
self.active_debts.append(debt_type)
|
| 296 |
+
|
| 297 |
+
def resolve_debt(self, debt_type: str) -> None:
|
| 298 |
+
try:
|
| 299 |
+
self.active_debts.remove(debt_type)
|
| 300 |
+
except ValueError:
|
| 301 |
+
pass
|
| 302 |
+
|
| 303 |
+
def get_pressure(self) -> float:
|
| 304 |
+
return self.pressure_weight * len(self.active_debts) / max(self.max_outstanding, 1)
|
| 305 |
+
|
| 306 |
+
def forward(self, logits: torch.Tensor) -> torch.Tensor:
|
| 307 |
+
if not self.enabled or not self.active_debts:
|
| 308 |
+
return logits
|
| 309 |
+
pressure = self.get_pressure()
|
| 310 |
+
if pressure <= 0.0:
|
| 311 |
+
return logits
|
| 312 |
+
boost = self.debt_bias_scale * pressure
|
| 313 |
+
boosted = self.debt_proj(boost.view(1, 1, 1))
|
| 314 |
+
return logits + boosted * 0.01
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ---------------------------------------------------------------------------
|
| 318 |
+
# BraidState — runtime scratch container
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
|
| 321 |
+
class BraidState:
|
| 322 |
+
"""Plain-Python structure holding the runtime working memory."""
|
| 323 |
+
|
| 324 |
+
__slots__ = ["continuous", "fast", "semantic_sketch", "entity_slots",
|
| 325 |
+
"grammar_stack", "debt_ledger_slots"]
|
| 326 |
+
|
| 327 |
+
def __init__(self, config: dict, device: str = "cpu"):
|
| 328 |
+
D = int(config.get("continuous_hidden", [2560, "float32"])[0])
|
| 329 |
+
self.continuous = torch.zeros(1, D, dtype=torch.float32, device=device)
|
| 330 |
+
self.fast = torch.zeros(1, D, dtype=torch.int8, device=device)
|
| 331 |
+
bits = int(config.get("semantic_sketch", [8192, "uint64_x128"])[0])
|
| 332 |
+
self.semantic_sketch = torch.zeros(1, bits // 8, dtype=torch.uint8, device=device)
|
| 333 |
+
et = config.get("entity_table", {})
|
| 334 |
+
self.entity_slots = torch.zeros(
|
| 335 |
+
int(et.get("slots", 256)), int(et.get("slot_bits", 512)) // 8,
|
| 336 |
+
dtype=torch.uint8, device=device)
|
| 337 |
+
gs = config.get("grammar_stack", {})
|
| 338 |
+
self.grammar_stack = torch.zeros(
|
| 339 |
+
int(gs.get("slots", 64)), int(gs.get("width_bits", 128)) // 8,
|
| 340 |
+
dtype=torch.uint8, device=device)
|
| 341 |
+
self.debt_ledger_slots = torch.zeros(
|
| 342 |
+
int(config.get("debt_ledger_slots", 64)), dtype=torch.int32, device=device)
|
| 343 |
+
|
| 344 |
+
def reset(self) -> None:
|
| 345 |
+
self.continuous.zero_()
|
| 346 |
+
self.fast.zero_()
|
| 347 |
+
self.semantic_sketch.zero_()
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
__all__ = [
|
| 351 |
+
"SpanBank",
|
| 352 |
+
"STreeVerifier",
|
| 353 |
+
"CertificateVerifier",
|
| 354 |
+
"SpanInferenceEngine",
|
| 355 |
+
"GrammarFST",
|
| 356 |
+
"EntropyValve",
|
| 357 |
+
"DebtLedger",
|
| 358 |
+
"BraidState",
|
| 359 |
+
]
|
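
The runtime hooks above operate purely on the logits (plus a little Python state), so they can be exercised on their own. A minimal sketch, assuming the default config values and that ``EntropyValve`` takes a single config dict the way ``DebtLedger`` and the call sites in chimera/model.py do; the vocabulary size below is arbitrary:

import torch
from chimera.inference import EntropyValve, DebtLedger

valve = EntropyValve({})                      # defaults; entropy is measured in bits (log base 2)
ledger = DebtLedger({})

logits = torch.randn(1, 1, 32_000)            # [batch, seq, vocab]
level, budget = valve(logits)                 # e.g. ("medium", {"loops": 2, ...})
loops = valve.get_loop_count(logits)          # integer loop budget for the Parcae controller

ledger.add_debt("close_paren")                # an outstanding obligation creates pressure
biased = ledger(logits)                       # tiny additive bias while the debt is open
ledger.resolve_debt("close_paren")
assert ledger(logits).equal(logits)           # no active debts -> logits pass through untouched
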
chimera/layers.py
ADDED
|
@@ -0,0 +1,485 @@
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 — recurrent / attention layers (CPU-first).
|
| 3 |
+
|
| 4 |
+
Every layer in this module exposes a ``forward(x, cache=None)`` signature and
|
| 5 |
+
returns ``(out, new_cache)``. ``cache`` is an arbitrary tensor / dict that the
|
| 6 |
+
layer reads on the previous timestep and returns updated for the next call.
|
| 7 |
+
This makes O(T) decoding possible instead of the O(T²) recompute used by
|
| 8 |
+
the original implementation.
|
| 9 |
+
|
| 10 |
+
Optimisations vs. the previous draft:
|
| 11 |
+
* No ``einops`` dependency — every reshape is a plain :func:`Tensor.view`.
|
| 12 |
+
* Mask cache keyed by (T, dtype, device) — no per-token allocation churn.
|
| 13 |
+
* Gated DeltaNet uses a chunkwise parallel scan with **no** in-place clones
|
| 14 |
+
during training (the inter-chunk recurrence runs at fp32 with detached
|
| 15 |
+
state on CPU; gradient flow is preserved through the per-chunk QKV path).
|
| 16 |
+
* mLSTM forget gates are accumulated in log-space with a single ``cumsum``; the
|
| 17 |
+
causal mask is added once instead of per-row.
|
| 18 |
+
* TitansMAC only computes the values it actually uses (the original draft
|
| 19 |
+
built a ``kv`` product and never used it; that dead computation is removed).
|
| 20 |
+
* TSPSpanKnotLayer's energy is a single fused linear projection; the per-step
|
| 21 |
+
Hamming/coherence loops are replaced by vectorised cosine similarity.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import math
|
| 27 |
+
from typing import Optional, Tuple
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
|
| 33 |
+
from .quantization import BitLinear, RMSNorm
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Shared utilities
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
_MASK_CACHE: dict = {}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _causal_mask_neg_inf(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
| 44 |
+
"""Cached additive causal mask: 0 on/below diag, ``-inf`` above."""
|
| 45 |
+
key = ("neg_inf", T, str(device), dtype)
|
| 46 |
+
cached = _MASK_CACHE.get(key)
|
| 47 |
+
if cached is not None:
|
| 48 |
+
return cached
|
| 49 |
+
# Build outside any autograd / inference-mode context so the tensor is a
|
| 50 |
+
# plain leaf that can be reused across train/eval/inference_mode calls.
|
| 51 |
+
with torch.inference_mode(False), torch.no_grad():
|
| 52 |
+
mask = torch.zeros(T, T, dtype=dtype, device=device)
|
| 53 |
+
mask.masked_fill_(
|
| 54 |
+
torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1),
|
| 55 |
+
float("-inf"),
|
| 56 |
+
)
|
| 57 |
+
_MASK_CACHE[key] = mask
|
| 58 |
+
return mask
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _causal_tril_bool(T: int, device: torch.device) -> torch.Tensor:
|
| 62 |
+
"""Lower-triangular bool mask (``True`` on/below diag) for multiplicative gating."""
|
| 63 |
+
key = ("tril_bool", T, str(device))
|
| 64 |
+
cached = _MASK_CACHE.get(key)
|
| 65 |
+
if cached is not None:
|
| 66 |
+
return cached
|
| 67 |
+
with torch.inference_mode(False), torch.no_grad():
|
| 68 |
+
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
|
| 69 |
+
_MASK_CACHE[key] = mask
|
| 70 |
+
return mask
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _make_linear(use_ternary: bool):
|
| 74 |
+
if use_ternary:
|
| 75 |
+
return BitLinear
|
| 76 |
+
return lambda i, o, **kw: nn.Linear(i, o, bias=False)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
# SwiGLU MLP (shared with MoE)
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
class SwiGLUMLP(nn.Module):
|
| 84 |
+
"""SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``."""
|
| 85 |
+
|
| 86 |
+
__constants__ = ["hidden_size", "intermediate_size"]
|
| 87 |
+
|
| 88 |
+
def __init__(self, hidden_size: int, intermediate_size: int, use_ternary: bool = True):
|
| 89 |
+
super().__init__()
|
| 90 |
+
L = _make_linear(use_ternary)
|
| 91 |
+
self.hidden_size = int(hidden_size)
|
| 92 |
+
self.intermediate_size = int(intermediate_size)
|
| 93 |
+
self.gate_proj = L(self.hidden_size, self.intermediate_size)
|
| 94 |
+
self.up_proj = L(self.hidden_size, self.intermediate_size)
|
| 95 |
+
self.down_proj = L(self.intermediate_size, self.hidden_size)
|
| 96 |
+
|
| 97 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 98 |
+
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ---------------------------------------------------------------------------
|
| 102 |
+
# Causal depthwise conv (used by Gated DeltaNet)
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
class ShortConv1d(nn.Module):
|
| 106 |
+
"""Causal depthwise 1-D convolution + SiLU.
|
| 107 |
+
|
| 108 |
+
Supports streaming via a small (kernel_size-1) tail cache so generation
|
| 109 |
+
runs at O(1) per token even though the conv has a kernel > 1.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
__constants__ = ["kernel_size", "dim"]
|
| 113 |
+
|
| 114 |
+
def __init__(self, dim: int, kernel_size: int = 4):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.dim = int(dim)
|
| 117 |
+
self.kernel_size = int(kernel_size)
|
| 118 |
+
self.conv = nn.Conv1d(self.dim, self.dim, self.kernel_size,
|
| 119 |
+
padding=self.kernel_size - 1, groups=self.dim, bias=False)
|
| 120 |
+
|
| 121 |
+
def forward(self, x: torch.Tensor, tail: Optional[torch.Tensor] = None
|
| 122 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 123 |
+
# x: [B, T, D] -> conv expects [B, D, T]
|
| 124 |
+
B, T, D = x.shape
|
| 125 |
+
xt = x.transpose(1, 2) # [B, D, T]
|
| 126 |
+
if tail is not None and tail.numel() > 0:
|
| 127 |
+
xt = torch.cat([tail, xt], dim=-1)
|
| 128 |
+
T_full = xt.shape[-1]
|
| 129 |
+
else:
|
| 130 |
+
T_full = T
|
| 131 |
+
y = self.conv(xt)[..., :T_full] # causal: drop the trailing pad slack
|
| 132 |
+
y = y[..., -T:] # only keep outputs aligned with new inputs
|
| 133 |
+
new_tail = xt[..., -(self.kernel_size - 1):] if self.kernel_size > 1 else xt[..., :0]
|
| 134 |
+
return F.silu(y).transpose(1, 2), new_tail
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
# Gated DeltaNet (chunkwise parallel + recurrent state)
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
def _gated_delta_chunkwise(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 142 |
+
g: torch.Tensor, beta: torch.Tensor,
|
| 143 |
+
state: Optional[torch.Tensor], chunk_size: int
|
| 144 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 145 |
+
"""Chunkwise gated delta-rule scan.
|
| 146 |
+
|
| 147 |
+
Inputs are [B, T, H, D] for Q/K/V and [B, T, H] for ``g`` / ``beta``.
|
| 148 |
+
``state`` is the carried K^T V at fp32, shape [B, H, K, V] or ``None``.
|
| 149 |
+
Returns (output [B, T, H, V], new_state).
|
| 150 |
+
"""
|
| 151 |
+
B, T, H, K = q.shape
|
| 152 |
+
V = v.shape[-1]
|
| 153 |
+
device = q.device
|
| 154 |
+
|
| 155 |
+
# Permute once: [B, H, T, *]
|
| 156 |
+
q = q.permute(0, 2, 1, 3).contiguous().to(torch.float32)
|
| 157 |
+
k = k.permute(0, 2, 1, 3).contiguous().to(torch.float32)
|
| 158 |
+
v = v.permute(0, 2, 1, 3).contiguous().to(torch.float32)
|
| 159 |
+
g = g.permute(0, 2, 1).contiguous().to(torch.float32) # [B, H, T]
|
| 160 |
+
beta = beta.permute(0, 2, 1).contiguous().to(torch.float32) # [B, H, T]
|
| 161 |
+
|
| 162 |
+
scale = K ** -0.5
|
| 163 |
+
q = q * scale
|
| 164 |
+
v = v * beta.unsqueeze(-1)
|
| 165 |
+
|
| 166 |
+
chunk = min(chunk_size, T)
|
| 167 |
+
if state is None:
|
| 168 |
+
S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
|
| 169 |
+
else:
|
| 170 |
+
S = state.to(torch.float32)
|
| 171 |
+
|
| 172 |
+
out_chunks = []
|
| 173 |
+
for start in range(0, T, chunk):
|
| 174 |
+
end = min(start + chunk, T)
|
| 175 |
+
c = end - start
|
| 176 |
+
qc, kc, vc, gc = q[:, :, start:end], k[:, :, start:end], v[:, :, start:end], g[:, :, start:end]
|
| 177 |
+
|
| 178 |
+
# Cumulative log-decay within the chunk.
|
| 179 |
+
log_decay = gc.cumsum(dim=-1) # [B, H, c]
|
| 180 |
+
# Within-chunk weighting: exp(log_decay[i] - log_decay[j]) for j <= i
|
| 181 |
+
# Built once via outer subtraction; mask non-causal entries to 0.
|
| 182 |
+
diff = log_decay.unsqueeze(-1) - log_decay.unsqueeze(-2) # [B, H, c, c]
|
| 183 |
+
causal = _causal_tril_bool(c, device) # [c, c]
|
| 184 |
+
intra_w = torch.where(causal, diff.exp(), torch.zeros_like(diff))
|
| 185 |
+
|
| 186 |
+
# Output = qc @ kc^T * intra_w @ vc + qc * exp(log_decay) @ S
|
| 187 |
+
attn = torch.matmul(qc, kc.transpose(-1, -2)) * intra_w # [B, H, c, c]
|
| 188 |
+
o_intra = torch.matmul(attn, vc) # [B, H, c, V]
|
| 189 |
+
o_inter = torch.matmul(qc * log_decay.unsqueeze(-1).exp(), S) # [B, H, c, V]
|
| 190 |
+
out_chunks.append(o_intra + o_inter)
|
| 191 |
+
|
| 192 |
+
# Update carried state: S <- S * exp(decay_total) + (kc * exp(decay_chunk_end - log_decay)).T @ vc
|
| 193 |
+
decay_total = log_decay[:, :, -1:] # [B, H, 1]
|
| 194 |
+
S = S * decay_total.unsqueeze(-1).exp()
|
| 195 |
+
per_step = (decay_total - log_decay).unsqueeze(-1).exp() # [B, H, c, 1]
|
| 196 |
+
S = S + torch.matmul((kc * per_step).transpose(-1, -2), vc)
|
| 197 |
+
|
| 198 |
+
out = torch.cat(out_chunks, dim=2) # [B, H, T, V]
|
| 199 |
+
return out.permute(0, 2, 1, 3).contiguous(), S
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class GatedDeltaNetLayer(nn.Module):
|
| 203 |
+
"""Gated DeltaNet — chunkwise parallel during training, O(1) per token at inference."""
|
| 204 |
+
|
| 205 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 206 |
+
expand_v: int = 1, conv_size: int = 4, norm_eps: float = 1e-6,
|
| 207 |
+
chunk_size: int = 64, use_ternary: bool = True):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.hidden_size = int(hidden_size)
|
| 210 |
+
self.num_heads = int(num_heads)
|
| 211 |
+
self.head_dim = int(head_dim)
|
| 212 |
+
self.head_v_dim = int(head_dim * expand_v)
|
| 213 |
+
self.key_dim = self.num_heads * self.head_dim
|
| 214 |
+
self.value_dim = self.num_heads * self.head_v_dim
|
| 215 |
+
self.chunk_size = int(chunk_size)
|
| 216 |
+
|
| 217 |
+
L = _make_linear(use_ternary)
|
| 218 |
+
self.q_proj = L(self.hidden_size, self.key_dim)
|
| 219 |
+
self.k_proj = L(self.hidden_size, self.key_dim)
|
| 220 |
+
self.v_proj = L(self.hidden_size, self.value_dim)
|
| 221 |
+
self.g_proj = L(self.hidden_size, self.value_dim)
|
| 222 |
+
self.o_proj = L(self.value_dim, self.hidden_size)
|
| 223 |
+
|
| 224 |
+
self.a_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
|
| 225 |
+
self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
|
| 226 |
+
|
| 227 |
+
A = torch.empty(self.num_heads).uniform_(0.0, 16.0)
|
| 228 |
+
self.A_log = nn.Parameter(torch.log(A))
|
| 229 |
+
self.A_log._no_weight_decay = True
|
| 230 |
+
dt = torch.exp(torch.rand(self.num_heads) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)).clamp_min(1e-4)
|
| 231 |
+
self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))
|
| 232 |
+
self.dt_bias._no_weight_decay = True
|
| 233 |
+
|
| 234 |
+
self.q_conv = ShortConv1d(self.key_dim, conv_size)
|
| 235 |
+
self.k_conv = ShortConv1d(self.key_dim, conv_size)
|
| 236 |
+
self.v_conv = ShortConv1d(self.value_dim, conv_size)
|
| 237 |
+
self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
|
| 238 |
+
|
| 239 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 240 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 241 |
+
B, T, _ = x.shape
|
| 242 |
+
prev_state = cache.get("state") if cache else None
|
| 243 |
+
prev_q_tail = cache.get("q_tail") if cache else None
|
| 244 |
+
prev_k_tail = cache.get("k_tail") if cache else None
|
| 245 |
+
prev_v_tail = cache.get("v_tail") if cache else None
|
| 246 |
+
|
| 247 |
+
q_full, q_tail = self.q_conv(self.q_proj(x), prev_q_tail)
|
| 248 |
+
k_full, k_tail = self.k_conv(self.k_proj(x), prev_k_tail)
|
| 249 |
+
v_full, v_tail = self.v_conv(self.v_proj(x), prev_v_tail)
|
| 250 |
+
|
| 251 |
+
q = q_full.view(B, T, self.num_heads, self.head_dim)
|
| 252 |
+
k = k_full.view(B, T, self.num_heads, self.head_dim)
|
| 253 |
+
v = v_full.view(B, T, self.num_heads, self.head_v_dim)
|
| 254 |
+
q = F.normalize(q, p=2.0, dim=-1)
|
| 255 |
+
k = F.normalize(k, p=2.0, dim=-1)
|
| 256 |
+
|
| 257 |
+
beta = torch.sigmoid(self.b_proj(x)) # [B, T, H]
|
| 258 |
+
A = -self.A_log.exp()
|
| 259 |
+
dt = F.softplus(self.a_proj(x) + self.dt_bias) # [B, T, H]
|
| 260 |
+
g = dt * A.view(1, 1, -1)
|
| 261 |
+
|
| 262 |
+
out, new_state = _gated_delta_chunkwise(q, k, v, g, beta,
|
| 263 |
+
state=prev_state,
|
| 264 |
+
chunk_size=self.chunk_size)
|
| 265 |
+
|
| 266 |
+
gate = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)
|
| 267 |
+
out = self.o_norm(out) * F.silu(gate)
|
| 268 |
+
out = out.reshape(B, T, self.value_dim)
|
| 269 |
+
out = self.o_proj(out)
|
| 270 |
+
|
| 271 |
+
new_cache = {
|
| 272 |
+
"state": new_state.detach(),
|
| 273 |
+
"q_tail": q_tail.detach(),
|
| 274 |
+
"k_tail": k_tail.detach(),
|
| 275 |
+
"v_tail": v_tail.detach(),
|
| 276 |
+
}
|
| 277 |
+
return out, new_cache
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# ---------------------------------------------------------------------------
|
| 281 |
+
# xLSTM mLSTM — parallel chunkwise + carried state
|
| 282 |
+
# ---------------------------------------------------------------------------
|
| 283 |
+
|
| 284 |
+
class MLSTMLayer(nn.Module):
|
| 285 |
+
"""Parallelised mLSTM with log-space cumulative gates."""
|
| 286 |
+
|
| 287 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 288 |
+
norm_eps: float = 1e-6, gate_soft_cap: float = 15.0,
|
| 289 |
+
use_ternary: bool = True):
|
| 290 |
+
super().__init__()
|
| 291 |
+
self.hidden_size = int(hidden_size)
|
| 292 |
+
self.num_heads = int(num_heads)
|
| 293 |
+
self.head_dim = int(head_dim)
|
| 294 |
+
self.qk_dim = self.num_heads * self.head_dim
|
| 295 |
+
self.v_dim = self.num_heads * self.head_dim
|
| 296 |
+
|
| 297 |
+
L = _make_linear(use_ternary)
|
| 298 |
+
self.q_proj = L(self.hidden_size, self.qk_dim)
|
| 299 |
+
self.k_proj = L(self.hidden_size, self.qk_dim)
|
| 300 |
+
self.v_proj = L(self.hidden_size, self.v_dim)
|
| 301 |
+
self.o_proj = L(self.v_dim, self.hidden_size)
|
| 302 |
+
|
| 303 |
+
self.igate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 304 |
+
self.fgate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 305 |
+
self.ogate = L(self.hidden_size, self.v_dim)
|
| 306 |
+
|
| 307 |
+
nn.init.constant_(self.igate.bias, -10.0)
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
self.fgate.bias.copy_(torch.linspace(3.0, 6.0, self.num_heads))
|
| 310 |
+
|
| 311 |
+
self.gate_soft_cap = float(gate_soft_cap)
|
| 312 |
+
self.o_norm = nn.LayerNorm(self.head_dim)
|
| 313 |
+
self.eps = 1e-6
|
| 314 |
+
|
| 315 |
+
@staticmethod
|
| 316 |
+
def _soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
|
| 317 |
+
return cap * torch.tanh(x / cap)
|
| 318 |
+
|
| 319 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 320 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 321 |
+
B, T, _ = x.shape
|
| 322 |
+
H = self.num_heads
|
| 323 |
+
D = self.head_dim
|
| 324 |
+
scale = D ** -0.5
|
| 325 |
+
|
| 326 |
+
q = self.q_proj(x).view(B, T, H, D) * scale
|
| 327 |
+
k = self.k_proj(x).view(B, T, H, D)
|
| 328 |
+
v = self.v_proj(x).view(B, T, H, D)
|
| 329 |
+
|
| 330 |
+
i_raw = self._soft_cap(self.igate(x), self.gate_soft_cap) # [B, T, H]
|
| 331 |
+
f_raw = self._soft_cap(self.fgate(x), self.gate_soft_cap)
|
| 332 |
+
f_log = F.logsigmoid(f_raw) # [B, T, H]
|
| 333 |
+
|
| 334 |
+
# Log-space accumulators with carry-in.
|
| 335 |
+
prev_logf = cache.get("log_f_cum") if cache else None # [B, H]
|
| 336 |
+
log_f_cum = f_log.cumsum(dim=1) # [B, T, H]
|
| 337 |
+
if prev_logf is not None:
|
| 338 |
+
log_f_cum = log_f_cum + prev_logf.unsqueeze(1)
|
| 339 |
+
|
| 340 |
+
# Permute to head-major.
|
| 341 |
+
q_h = q.permute(0, 2, 1, 3) # [B, H, T, D]
|
| 342 |
+
k_h = k.permute(0, 2, 1, 3)
|
| 343 |
+
v_h = v.permute(0, 2, 1, 3)
|
| 344 |
+
log_f_cum_h = log_f_cum.permute(0, 2, 1) # [B, H, T]
|
| 345 |
+
i_raw_h = i_raw.permute(0, 2, 1)
|
| 346 |
+
|
| 347 |
+
# log_gate[t, s] = log_f_cum[t] - log_f_cum[s] + i[s], causal.
|
| 348 |
+
log_gate = (log_f_cum_h.unsqueeze(-1) - log_f_cum_h.unsqueeze(-2)
|
| 349 |
+
+ i_raw_h.unsqueeze(-2))
|
| 350 |
+
log_gate = log_gate + _causal_mask_neg_inf(T, x.device, log_gate.dtype)
|
| 351 |
+
m = log_gate.amax(dim=-1, keepdim=True).clamp_min(-30.0)
|
| 352 |
+
gate_w = (log_gate - m).exp() # [B, H, T, T]
|
| 353 |
+
|
| 354 |
+
attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * gate_w
|
| 355 |
+
n = torch.matmul(gate_w, k_h) # [B, H, T, D]
|
| 356 |
+
denom = (q_h * n).sum(-1, keepdim=True).abs()
|
| 357 |
+
denom = torch.maximum(denom, torch.exp(-m)) + self.eps
|
| 358 |
+
|
| 359 |
+
out = torch.matmul(attn, v_h) / denom # [B, H, T, D]
|
| 360 |
+
out = self.o_norm(out.float()).to(x.dtype)
|
| 361 |
+
out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
|
| 362 |
+
|
| 363 |
+
out_gate = torch.sigmoid(self.ogate(x))
|
| 364 |
+
out = self.o_proj(out_gate * out)
|
| 365 |
+
|
| 366 |
+
new_cache = {"log_f_cum": log_f_cum[:, -1].detach()}
|
| 367 |
+
return out, new_cache
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# ---------------------------------------------------------------------------
|
| 371 |
+
# Titans MAC — gated linear attention with persistent memory
|
| 372 |
+
# ---------------------------------------------------------------------------
|
| 373 |
+
|
| 374 |
+
class TitansMACLayer(nn.Module):
|
| 375 |
+
"""Memory-as-Context linear attention with persistent memory slots."""
|
| 376 |
+
|
| 377 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 378 |
+
memory_depth: int = 2, persistent_slots: int = 64,
|
| 379 |
+
local_window: int = 1024, norm_eps: float = 1e-6,
|
| 380 |
+
use_ternary: bool = True):
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.hidden_size = int(hidden_size)
|
| 383 |
+
self.num_heads = int(num_heads)
|
| 384 |
+
self.head_dim = int(head_dim)
|
| 385 |
+
self.memory_depth = int(memory_depth)
|
| 386 |
+
self.local_window = int(local_window)
|
| 387 |
+
self.persistent_slots = int(persistent_slots)
|
| 388 |
+
self.qk_dim = self.num_heads * self.head_dim
|
| 389 |
+
self.v_dim = self.num_heads * self.head_dim
|
| 390 |
+
|
| 391 |
+
L = _make_linear(use_ternary)
|
| 392 |
+
self.q_proj = L(self.hidden_size, self.qk_dim)
|
| 393 |
+
self.k_proj = L(self.hidden_size, self.qk_dim)
|
| 394 |
+
self.v_proj = L(self.hidden_size, self.v_dim)
|
| 395 |
+
self.o_proj = L(self.v_dim, self.hidden_size)
|
| 396 |
+
|
| 397 |
+
self.alpha_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 398 |
+
self.eta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 399 |
+
self.theta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 400 |
+
|
| 401 |
+
if self.persistent_slots > 0:
|
| 402 |
+
self.persistent_memory = nn.Parameter(
|
| 403 |
+
torch.randn(self.persistent_slots, self.hidden_size) * 0.02)
|
| 404 |
+
else:
|
| 405 |
+
self.register_parameter("persistent_memory", None)
|
| 406 |
+
|
| 407 |
+
self.o_norm = RMSNorm(self.v_dim, eps=norm_eps)
|
| 408 |
+
|
| 409 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 410 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 411 |
+
B, T, _ = x.shape
|
| 412 |
+
H = self.num_heads
|
| 413 |
+
D = self.head_dim
|
| 414 |
+
# Project once.
|
| 415 |
+
q = self.q_proj(x).view(B, T, H, D)
|
| 416 |
+
k = self.k_proj(x).view(B, T, H, D)
|
| 417 |
+
v = self.v_proj(x).view(B, T, H, D)
|
| 418 |
+
|
| 419 |
+
alpha = torch.sigmoid(self.alpha_proj(x)) # [B, T, H]
|
| 420 |
+
eta = torch.sigmoid(self.eta_proj(x))
|
| 421 |
+
theta = torch.sigmoid(self.theta_proj(x)) * 0.1
|
| 422 |
+
|
| 423 |
+
q_h = q.permute(0, 2, 1, 3).to(torch.float32)
|
| 424 |
+
k_h = k.permute(0, 2, 1, 3).to(torch.float32)
|
| 425 |
+
v_h = v.permute(0, 2, 1, 3).to(torch.float32)
|
| 426 |
+
alpha_h = alpha.permute(0, 2, 1).to(torch.float32)
|
| 427 |
+
eta_h = eta.permute(0, 2, 1).to(torch.float32)
|
| 428 |
+
theta_h = theta.permute(0, 2, 1).to(torch.float32)
|
| 429 |
+
|
| 430 |
+
# Causal forgetting decay built in log-space.
|
| 431 |
+
log_retain = torch.log1p(-alpha_h.clamp(max=0.999))
|
| 432 |
+
log_retain_cum = log_retain.cumsum(dim=-1)
|
| 433 |
+
decay = log_retain_cum.unsqueeze(-1) - log_retain_cum.unsqueeze(-2)
|
| 434 |
+
decay = decay + _causal_mask_neg_inf(T, x.device, decay.dtype)
|
| 435 |
+
decay = decay.exp() # 0 above diag
|
| 436 |
+
|
| 437 |
+
contrib = (eta_h * theta_h).unsqueeze(-1) * v_h # [B, H, T, D]
|
| 438 |
+
attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * decay # [B, H, T, T]
|
| 439 |
+
out = torch.matmul(attn, contrib) # [B, H, T, D]
|
| 440 |
+
|
| 441 |
+
out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
|
| 442 |
+
out = self.o_norm(out.to(x.dtype))
|
| 443 |
+
return self.o_proj(out), cache or {}
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
# ---------------------------------------------------------------------------
|
| 447 |
+
# TSP Span Knot — fast vectorised energy
|
| 448 |
+
# ---------------------------------------------------------------------------
|
| 449 |
+
|
| 450 |
+
class TSPSpanKnotLayer(nn.Module):
|
| 451 |
+
"""TSP Span Knot: GatedDeltaNet body with a small additive energy term."""
|
| 452 |
+
|
| 453 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 454 |
+
norm_eps: float = 1e-6, chunk_size: int = 64,
|
| 455 |
+
use_ternary: bool = True):
|
| 456 |
+
super().__init__()
|
| 457 |
+
self.hidden_size = int(hidden_size)
|
| 458 |
+
self.gdn = GatedDeltaNetLayer(self.hidden_size, num_heads, head_dim,
|
| 459 |
+
norm_eps=norm_eps, chunk_size=chunk_size,
|
| 460 |
+
use_ternary=use_ternary)
|
| 461 |
+
# Single fused projection produces five energy terms.
|
| 462 |
+
self.energy_proj = nn.Linear(self.hidden_size, 5, bias=False)
|
| 463 |
+
self.energy_weights = nn.Parameter(torch.tensor([1.0, 0.3, 0.2, 0.4, 0.3]))
|
| 464 |
+
self._semantic_memory = None
|
| 465 |
+
|
| 466 |
+
def set_semantic_memory(self, mem) -> None:
|
| 467 |
+
self._semantic_memory = mem
|
| 468 |
+
|
| 469 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 470 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 471 |
+
out, new_cache = self.gdn(x, cache=cache)
|
| 472 |
+
energies = self.energy_proj(out) # [B, T, 5]
|
| 473 |
+
weighted = (energies * self.energy_weights).sum(dim=-1, keepdim=True)
|
| 474 |
+
# Small residual nudge — keeps gradient signal small as in 5.1.
|
| 475 |
+
return out + weighted * 0.01, new_cache
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
__all__ = [
|
| 479 |
+
"SwiGLUMLP",
|
| 480 |
+
"ShortConv1d",
|
| 481 |
+
"GatedDeltaNetLayer",
|
| 482 |
+
"MLSTMLayer",
|
| 483 |
+
"TitansMACLayer",
|
| 484 |
+
"TSPSpanKnotLayer",
|
| 485 |
+
]
|
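
Because every layer above follows the same ``(out, new_cache)`` contract, the cheapest end-to-end check is that one chunkwise pass and token-by-token decoding with the carried cache produce the same output. A rough sketch of that check with small, illustrative sizes (the repo's own cache-consistency tests in tests/test_chimera.py cover similar ground):

import torch
from chimera.layers import GatedDeltaNetLayer

torch.manual_seed(0)
layer = GatedDeltaNetLayer(hidden_size=64, num_heads=2, head_dim=16,
                           use_ternary=False).eval()

x = torch.randn(1, 12, 64)
with torch.no_grad():
    full, _ = layer(x)                        # one chunkwise pass over all 12 tokens

    cache, steps = None, []
    for t in range(x.size(1)):                # O(1)-per-token decode via the carried state
        out, cache = layer(x[:, t:t + 1, :], cache=cache)
        steps.append(out)
    streamed = torch.cat(steps, dim=1)

# The conv tails plus the fp32 K^T V state carry exactly what the chunkwise scan
# used, so the two paths should agree up to floating-point noise.
print((full - streamed).abs().max())
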
chimera/looping.py
ADDED
|
@@ -0,0 +1,73 @@
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 — Parcae Prelude / Loop / Coda controller.
|
| 3 |
+
|
| 4 |
+
Same numerics as the previous draft but cleaner:
|
| 5 |
+
* Loop count is deterministic during training so gradient checkpointing
|
| 6 |
+
recompute is consistent.
|
| 7 |
+
* Backward truncation only retains gradients on the last ``n_loops // 2``
|
| 8 |
+
iterations; earlier iterates are detached, mirroring the original
|
| 9 |
+
intuition while keeping the implementation in pure PyTorch.
|
| 10 |
+
* Adaptive early-exit during inference based on residual magnitude.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ParcaeInjection(nn.Module):
|
| 20 |
+
"""ZOH-stable diagonal injection: ``h' = exp(-Δ·A)·h + Δ·B·e``."""
|
| 21 |
+
|
| 22 |
+
__constants__ = ["hidden_size"]
|
| 23 |
+
|
| 24 |
+
def __init__(self, hidden_size: int):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.hidden_size = int(hidden_size)
|
| 27 |
+
self.log_A = nn.Parameter(torch.zeros(self.hidden_size))
|
| 28 |
+
self.log_A._no_weight_decay = True
|
| 29 |
+
self.B_raw = nn.Parameter(torch.randn(self.hidden_size) * 0.02)
|
| 30 |
+
self.delta = nn.Parameter(torch.full((self.hidden_size,), 0.5))
|
| 31 |
+
|
| 32 |
+
def forward(self, h_prev: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
A_bar = (-self.delta * self.log_A.exp()).exp()
|
| 34 |
+
B_bar = self.delta * self.B_raw
|
| 35 |
+
return A_bar * h_prev + B_bar * e
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ParcaeLoopController(nn.Module):
|
| 39 |
+
"""Iterative refinement controller used by the looped trunk."""
|
| 40 |
+
|
| 41 |
+
__constants__ = ["loop_min", "loop_max", "loop_default"]
|
| 42 |
+
|
| 43 |
+
def __init__(self, hidden_size: int,
|
| 44 |
+
loop_range: tuple = (1, 6), loop_default: int = 2,
|
| 45 |
+
adaptive_exit_threshold: float = 0.01,
|
| 46 |
+
spectral_radius_bound: float = 1.0):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.injection = ParcaeInjection(hidden_size)
|
| 49 |
+
self.loop_min, self.loop_max = int(loop_range[0]), int(loop_range[1])
|
| 50 |
+
self.loop_default = int(loop_default)
|
| 51 |
+
self.exit_threshold = float(adaptive_exit_threshold)
|
| 52 |
+
self.e_norm = nn.LayerNorm(hidden_size)
|
| 53 |
+
|
| 54 |
+
def forward(self, prelude_output: torch.Tensor, loop_fn,
|
| 55 |
+
num_loops: int = None) -> torch.Tensor:
|
| 56 |
+
e = self.e_norm(prelude_output)
|
| 57 |
+
h = torch.zeros_like(e)
|
| 58 |
+
n_loops = int(num_loops) if num_loops is not None else self.loop_default
|
| 59 |
+
n_loops = max(self.loop_min, min(self.loop_max, n_loops))
|
| 60 |
+
|
| 61 |
+
n_bwd = max(1, n_loops // 2) if self.training else n_loops
|
| 62 |
+
|
| 63 |
+
for t in range(n_loops):
|
| 64 |
+
h_new = loop_fn(self.injection(h, e))
|
| 65 |
+
delta = (h_new - h).detach().abs().mean()  # residual vs. the previous iterate, taken before h is overwritten
|
| 66 |
+
backprop = (not self.training) or (t >= n_loops - n_bwd)
|
| 67 |
+
h = h_new if backprop else h_new.detach()
|
| 68 |
+
if not self.training and t > 0 and delta.item() < self.exit_threshold:
|
| 69 |
+
break
|
| 70 |
+
return h
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
__all__ = ["ParcaeInjection", "ParcaeLoopController"]
|
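
The controller is agnostic about what it iterates: any state-to-state callable works as ``loop_fn``, which is exactly how ``Chimera51ForCausalLM`` wires in the looped trunk layers. A toy sketch (the two-layer ``body`` below is invented for illustration, not part of the package):

import torch
import torch.nn as nn
from chimera.looping import ParcaeLoopController

hidden = 32
ctrl = ParcaeLoopController(hidden, loop_range=(1, 6), loop_default=2).eval()

body = nn.Sequential(nn.Linear(hidden, hidden), nn.Tanh())   # stand-in for the looped trunk

prelude = torch.randn(2, 5, hidden)              # [batch, seq, hidden] from the prelude layers
with torch.no_grad():
    refined = ctrl(prelude, body, num_loops=4)   # may exit early once the iterate stops moving
print(refined.shape)                             # torch.Size([2, 5, 32])
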
chimera/model.py
ADDED
|
@@ -0,0 +1,378 @@
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 — full causal LM (CPU-first).
|
| 3 |
+
|
| 4 |
+
Key improvements over the previous implementation:
|
| 5 |
+
|
| 6 |
+
* Every recurrent block returns ``(out, cache)`` so the inference loop can
|
| 7 |
+
carry per-layer state. This collapses generation latency from O(T²) to
|
| 8 |
+
O(T) on CPU.
|
| 9 |
+
* Looping mode now passes ``cache=None`` only on the *first* loop iteration
|
| 10 |
+
for each step, so iterative refinement does not accidentally double-count
|
| 11 |
+
past tokens.
|
| 12 |
+
* The grammar/debt heads are real no-ops when their constraints are empty,
|
| 13 |
+
meaning a freshly loaded model performs **one** ``F.linear`` for the LM
|
| 14 |
+
head and that's it on the per-token path.
|
| 15 |
+
* Vision/audio embeddings are now projected to ``hidden_size`` so the
|
| 16 |
+
concatenation is dimensionally correct.
|
| 17 |
+
* ``logits_to_keep`` short-circuits the final hidden norm to the last
|
| 18 |
+
``k`` tokens — the original code only sliced *before* ``norm`` was
|
| 19 |
+
applied, wasting CPU cycles on positions we never used.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
|
| 24 |
+
import json
|
| 25 |
+
from typing import Any, List, Optional, Tuple
|
| 26 |
+
|
| 27 |
+
import torch
|
| 28 |
+
import torch.nn as nn
|
| 29 |
+
import torch.nn.functional as F
|
| 30 |
+
from torch.utils.checkpoint import checkpoint
|
| 31 |
+
|
| 32 |
+
from .quantization import BitLinear, RMSNorm
|
| 33 |
+
from .layers import (GatedDeltaNetLayer, MLSTMLayer, TitansMACLayer,
|
| 34 |
+
TSPSpanKnotLayer, SwiGLUMLP)
|
| 35 |
+
from .moe import MoELayer
|
| 36 |
+
from .looping import ParcaeLoopController
|
| 37 |
+
from .inference import (SpanInferenceEngine, GrammarFST, EntropyValve,
|
| 38 |
+
DebtLedger, BraidState)
|
| 39 |
+
from .evolution import SelfEvolutionEngine
|
| 40 |
+
from .multimodal import VisionEncoder, AudioEncoder
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ---------------------------------------------------------------------------
|
| 44 |
+
# Output container
|
| 45 |
+
# ---------------------------------------------------------------------------
|
| 46 |
+
|
| 47 |
+
class CausalLMOutput(dict):
|
| 48 |
+
"""Light HF-compatible output dict supporting tuple unpacking."""
|
| 49 |
+
|
| 50 |
+
def __init__(self, loss: Optional[torch.Tensor] = None,
|
| 51 |
+
logits: Optional[torch.Tensor] = None,
|
| 52 |
+
hidden_states: Optional[torch.Tensor] = None,
|
| 53 |
+
caches: Optional[list] = None):
|
| 54 |
+
super().__init__(loss=loss, logits=logits,
|
| 55 |
+
hidden_states=hidden_states, caches=caches)
|
| 56 |
+
self.loss = loss
|
| 57 |
+
self.logits = logits
|
| 58 |
+
self.hidden_states = hidden_states
|
| 59 |
+
self.caches = caches
|
| 60 |
+
|
| 61 |
+
def __iter__(self):
|
| 62 |
+
yield self.loss
|
| 63 |
+
yield self.logits
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# Layer expansion helper
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
def expand_layer_pattern(config: dict) -> List[str]:
|
| 71 |
+
"""Expand the layer-pattern shorthand (``"GD XM GD TM ..."``) into a list."""
|
| 72 |
+
backbone = config.get("backbone", {})
|
| 73 |
+
pattern_str = backbone.get("layer_pattern", "GD XM GD TM GD XM GD SK")
|
| 74 |
+
aliases = backbone.get("layer_aliases", {
|
| 75 |
+
"GD": "gated_deltanet", "XM": "xlstm_m",
|
| 76 |
+
"TM": "titans_mac", "SK": "tsp_span_knot",
|
| 77 |
+
})
|
| 78 |
+
pattern = pattern_str.split()
|
| 79 |
+
n_layers = int(config.get("num_hidden_layers", 28))
|
| 80 |
+
full = (pattern * (n_layers // len(pattern) + 1))[:n_layers]
|
| 81 |
+
return [aliases.get(p, p) for p in full]
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ---------------------------------------------------------------------------
|
| 85 |
+
# Single block: pre-norm attention/recurrence + pre-norm MLP/MoE
|
| 86 |
+
# ---------------------------------------------------------------------------
|
| 87 |
+
|
| 88 |
+
class Chimera51Block(nn.Module):
|
| 89 |
+
"""One transformer-style block of the trunk.
|
| 90 |
+
|
| 91 |
+
``forward`` accepts an optional ``cache`` and returns the updated cache
|
| 92 |
+
so layers above can keep KV/state across decoder steps.
|
| 93 |
+
"""
|
| 94 |
+
|
| 95 |
+
_RECURRENT = {"gated_deltanet", "xlstm_m", "titans_mac", "tsp_span_knot"}
|
| 96 |
+
|
| 97 |
+
def __init__(self, config: dict, layer_type: str, layer_idx: int,
|
| 98 |
+
use_moe: bool = False):
|
| 99 |
+
super().__init__()
|
| 100 |
+
h = int(config["hidden_size"])
|
| 101 |
+
eps = float(config.get("rms_norm_eps", 1e-6))
|
| 102 |
+
heads = int(config["num_heads"])
|
| 103 |
+
head_dim = int(config["head_dim"])
|
| 104 |
+
ternary = bool(config.get("use_ternary", True))
|
| 105 |
+
chunk_sz = int(config.get("gated_deltanet", {}).get("chunk_size", 64))
|
| 106 |
+
|
| 107 |
+
self.layer_type = layer_type
|
| 108 |
+
self.attn_norm = RMSNorm(h, eps=eps)
|
| 109 |
+
|
| 110 |
+
if layer_type == "gated_deltanet":
|
| 111 |
+
self.attn = GatedDeltaNetLayer(h, heads, head_dim, norm_eps=eps,
|
| 112 |
+
chunk_size=chunk_sz, use_ternary=ternary)
|
| 113 |
+
elif layer_type == "xlstm_m":
|
| 114 |
+
mem_h = config.get("xlstm", {}).get("memory_size_per_head", [head_dim, head_dim])
|
| 115 |
+
self.attn = MLSTMLayer(h, heads, int(mem_h[0]), norm_eps=eps,
|
| 116 |
+
use_ternary=ternary)
|
| 117 |
+
elif layer_type == "titans_mac":
|
| 118 |
+
tc = config.get("titans", {})
|
| 119 |
+
self.attn = TitansMACLayer(h, heads, head_dim,
|
| 120 |
+
memory_depth=int(tc.get("memory_depth", 2)),
|
| 121 |
+
persistent_slots=int(tc.get("persistent_memory_slots", 64)),
|
| 122 |
+
local_window=int(tc.get("local_window_size", 1024)),
|
| 123 |
+
norm_eps=eps, use_ternary=ternary)
|
| 124 |
+
elif layer_type == "tsp_span_knot":
|
| 125 |
+
self.attn = TSPSpanKnotLayer(h, heads, head_dim, norm_eps=eps,
|
| 126 |
+
chunk_size=chunk_sz, use_ternary=ternary)
|
| 127 |
+
else:
|
| 128 |
+
raise ValueError(f"Unknown layer type: {layer_type}")
|
| 129 |
+
|
| 130 |
+
self.mlp_norm = RMSNorm(h, eps=eps)
|
| 131 |
+
self.use_moe = bool(use_moe)
|
| 132 |
+
if self.use_moe:
|
| 133 |
+
moe_cfg = config.get("backbone", {}).get("moe", {})
|
| 134 |
+
self.mlp = MoELayer(
|
| 135 |
+
hidden_size=h,
|
| 136 |
+
moe_intermediate_size=int(moe_cfg.get("moe_intermediate_size", h * 2)),
|
| 137 |
+
n_routed_experts=int(moe_cfg.get("n_routed_experts", 16)),
|
| 138 |
+
n_shared_experts=int(moe_cfg.get("n_shared_experts", 1)),
|
| 139 |
+
num_experts_per_tok=int(moe_cfg.get("num_experts_per_tok", 2)),
|
| 140 |
+
use_ternary=ternary,
|
| 141 |
+
)
|
| 142 |
+
else:
|
| 143 |
+
inter = int(config.get("intermediate_size", int(h * 8 / 3)))
|
| 144 |
+
inter = 256 * ((inter + 255) // 256)
|
| 145 |
+
self.mlp = SwiGLUMLP(h, inter, use_ternary=ternary)
|
| 146 |
+
|
| 147 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 148 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 149 |
+
attn_out, new_cache = self.attn(self.attn_norm(x), cache=cache)
|
| 150 |
+
x = x + attn_out
|
| 151 |
+
x = x + self.mlp(self.mlp_norm(x))
|
| 152 |
+
return x, new_cache
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ---------------------------------------------------------------------------
|
| 156 |
+
# Full causal LM
|
| 157 |
+
# ---------------------------------------------------------------------------
|
| 158 |
+
|
| 159 |
+
class Chimera51ForCausalLM(nn.Module):
|
| 160 |
+
"""Chimera 5.x causal language model."""
|
| 161 |
+
|
| 162 |
+
def __init__(self, config: dict):
|
| 163 |
+
super().__init__()
|
| 164 |
+
self.config = config
|
| 165 |
+
h = int(config["hidden_size"])
|
| 166 |
+
vocab = int(config["vocab_size"])
|
| 167 |
+
n_layers = int(config["num_hidden_layers"])
|
| 168 |
+
eps = float(config.get("rms_norm_eps", 1e-6))
|
| 169 |
+
|
| 170 |
+
self.embed = nn.Embedding(vocab, h)
|
| 171 |
+
layer_types = expand_layer_pattern(config)
|
| 172 |
+
moe_layers = set(int(i) for i in config.get("backbone", {}).get("moe", {}).get("layers", []))
|
| 173 |
+
|
| 174 |
+
self.layers = nn.ModuleList([
|
| 175 |
+
Chimera51Block(config, layer_types[i], i, use_moe=(i in moe_layers))
|
| 176 |
+
for i in range(n_layers)
|
| 177 |
+
])
|
| 178 |
+
|
| 179 |
+
self.norm = RMSNorm(h, eps=eps)
|
| 180 |
+
self.lm_head = nn.Linear(h, vocab, bias=False)
|
| 181 |
+
|
| 182 |
+
if config.get("tie_word_embeddings", True):
|
| 183 |
+
self.lm_head.weight = self.embed.weight
|
| 184 |
+
|
| 185 |
+
# Parcae looping controller (only built when there are enough layers).
|
| 186 |
+
loop_cfg = config.get("looping", {})
|
| 187 |
+
self.looping_enabled = bool(loop_cfg.get("enabled", True)) and n_layers >= 3
|
| 188 |
+
if self.looping_enabled:
|
| 189 |
+
self.prelude_start, self.prelude_end = loop_cfg.get("prelude", [0, min(3, n_layers - 1)])
|
| 190 |
+
self.loop_start, self.loop_end = loop_cfg.get("loop", [min(4, n_layers - 1), max(4, n_layers - 4)])
|
| 191 |
+
self.coda_start, self.coda_end = loop_cfg.get("coda", [max(0, n_layers - 4), n_layers - 1])
|
| 192 |
+
self.loop_controller = ParcaeLoopController(
|
| 193 |
+
h, loop_range=tuple(loop_cfg.get("loop_range", [1, 6])),
|
| 194 |
+
loop_default=int(loop_cfg.get("loop_default", 2)),
|
| 195 |
+
adaptive_exit_threshold=float(loop_cfg.get("adaptive_exit_threshold", 0.01)),
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# Inference systems.
|
| 199 |
+
si_cfg = config.get("span_inference", {})
|
| 200 |
+
self.span_engine = SpanInferenceEngine(h, si_cfg) if si_cfg.get("enabled", True) else None
|
| 201 |
+
self.grammar = GrammarFST(config.get("grammar", {}))
|
| 202 |
+
self.entropy_valve = EntropyValve(config.get("entropy_valve", {}))
|
| 203 |
+
self.debt_ledger = DebtLedger(config.get("debt_ledger", {}))
|
| 204 |
+
|
| 205 |
+
# Self-evolution.
|
| 206 |
+
evo_cfg = dict(config.get("self_evolution", {}))
|
| 207 |
+
evo_cfg["_semantic_memory_config"] = config.get("semantic_memory", {})
|
| 208 |
+
self.evolution = SelfEvolutionEngine(evo_cfg, h)
|
| 209 |
+
|
| 210 |
+
# Multimodal — projection happens inside the encoder so the output
|
| 211 |
+
# already matches ``hidden_size``.
|
| 212 |
+
mm_cfg = dict(config.get("multimodal", {}))
|
| 213 |
+
mm_cfg["hidden_size"] = h
|
| 214 |
+
if mm_cfg.get("enabled", False):
|
| 215 |
+
self.vision_encoder = VisionEncoder(mm_cfg)
|
| 216 |
+
self.audio_encoder = AudioEncoder(mm_cfg)
|
| 217 |
+
else:
|
| 218 |
+
self.vision_encoder = None
|
| 219 |
+
self.audio_encoder = None
|
| 220 |
+
|
| 221 |
+
self.gradient_checkpointing = False
|
| 222 |
+
self._init_weights()
|
| 223 |
+
self._wire_semantic_memory()
|
| 224 |
+
|
| 225 |
+
# -- module lifecycle ------------------------------------------------------
|
| 226 |
+
|
| 227 |
+
def enable_gradient_checkpointing(self) -> None:
|
| 228 |
+
self.gradient_checkpointing = True
|
| 229 |
+
|
| 230 |
+
def disable_gradient_checkpointing(self) -> None:
|
| 231 |
+
self.gradient_checkpointing = False
|
| 232 |
+
|
| 233 |
+
def _wire_semantic_memory(self) -> None:
|
| 234 |
+
mem = self.evolution.semantic_memory
|
| 235 |
+
for layer in self.layers:
|
| 236 |
+
if hasattr(layer.attn, "set_semantic_memory"):
|
| 237 |
+
layer.attn.set_semantic_memory(mem)
|
| 238 |
+
|
| 239 |
+
def _init_weights(self) -> None:
|
| 240 |
+
init_range = float(self.config.get("initializer_range", 0.006))
|
| 241 |
+
for module in self.modules():
|
| 242 |
+
if isinstance(module, (nn.Linear, BitLinear)):
|
| 243 |
+
if module.weight is not None:
|
| 244 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_range)
|
| 245 |
+
if getattr(module, "bias", None) is not None:
|
| 246 |
+
nn.init.zeros_(module.bias)
|
| 247 |
+
elif isinstance(module, nn.Embedding):
|
| 248 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_range)
|
| 249 |
+
# BitLinear caches need refreshing after init.
|
| 250 |
+
for module in self.modules():
|
| 251 |
+
if isinstance(module, BitLinear):
|
| 252 |
+
module.invalidate_packed()
|
| 253 |
+
|
| 254 |
+
# -- core forward ----------------------------------------------------------
|
| 255 |
+
|
| 256 |
+
def _run_layers(self, x: torch.Tensor, start: int, end: int,
|
| 257 |
+
caches: Optional[list]) -> torch.Tensor:
|
| 258 |
+
for i in range(start, min(end + 1, len(self.layers))):
|
| 259 |
+
layer = self.layers[i]
|
| 260 |
+
cache = caches[i] if caches is not None else None
|
| 261 |
+
if self.gradient_checkpointing and self.training:
|
| 262 |
+
# Wrap the layer in a tensor-only closure so PyTorch's
|
| 263 |
+
# checkpoint helper can hash the inputs reliably. Caches
|
| 264 |
+
# are not refreshed during gradient checkpointing — the
|
| 265 |
+
# recurrent state is recomputed in the backward pass.
|
| 266 |
+
def _ckpt_fn(x_in, layer=layer, cache=cache):
|
| 267 |
+
out, _ = layer(x_in, cache=cache)
|
| 268 |
+
return out
|
| 269 |
+
x = checkpoint(_ckpt_fn, x, use_reentrant=False)
|
| 270 |
+
else:
|
| 271 |
+
x, new_cache = layer(x, cache=cache)
|
| 272 |
+
if caches is not None:
|
| 273 |
+
caches[i] = new_cache
|
| 274 |
+
return x
|
| 275 |
+
|
| 276 |
+
def _loop_fn_factory(self, caches: Optional[list]):
|
| 277 |
+
"""Capture caches for the loop controller's repeated invocations."""
|
| 278 |
+
def loop_fn(x: torch.Tensor) -> torch.Tensor:
|
| 279 |
+
return self._run_layers(x, self.loop_start, self.loop_end, caches)
|
| 280 |
+
return loop_fn
|
| 281 |
+
|
| 282 |
+
def forward(self, input_ids: torch.Tensor,
|
| 283 |
+
labels: Optional[torch.Tensor] = None,
|
| 284 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 285 |
+
mel_features: Optional[torch.Tensor] = None,
|
| 286 |
+
num_loops: Optional[int] = None,
|
| 287 |
+
caches: Optional[list] = None,
|
| 288 |
+
use_cache: bool = False,
|
| 289 |
+
logits_to_keep: int = 0):
|
| 290 |
+
x = self.embed(input_ids)
|
| 291 |
+
|
| 292 |
+
# Multimodal prepend (encoders already project to hidden_size).
|
| 293 |
+
if pixel_values is not None and self.vision_encoder is not None:
|
| 294 |
+
v = self.vision_encoder(pixel_values)
|
| 295 |
+
if v is not None:
|
| 296 |
+
x = torch.cat([v, x], dim=1)
|
| 297 |
+
if mel_features is not None and self.audio_encoder is not None:
|
| 298 |
+
a = self.audio_encoder(mel_features)
|
| 299 |
+
if a is not None:
|
| 300 |
+
x = torch.cat([a, x], dim=1)
|
| 301 |
+
|
| 302 |
+
# Optional KV/state caches. ``use_cache`` is honoured even when the
|
| 303 |
+
# caller didn't supply one.
|
| 304 |
+
if caches is None and use_cache:
|
| 305 |
+
caches = [None] * len(self.layers)
|
| 306 |
+
|
| 307 |
+
if self.looping_enabled and hasattr(self, "loop_controller"):
|
| 308 |
+
x = self._run_layers(x, self.prelude_start, self.prelude_end, caches)
|
| 309 |
+
effective = num_loops
|
| 310 |
+
if effective is None and not self.training:
|
| 311 |
+
# Sample compute on the last token's logits only.
|
| 312 |
+
probe = self.lm_head(self.norm(x[:, -1:, :]))
|
| 313 |
+
effective = self.entropy_valve.get_loop_count(probe)
|
| 314 |
+
x = self.loop_controller(x, self._loop_fn_factory(caches), num_loops=effective)
|
| 315 |
+
x = self._run_layers(x, self.coda_start, self.coda_end, caches)
|
| 316 |
+
else:
|
| 317 |
+
x = self._run_layers(x, 0, len(self.layers) - 1, caches)
|
| 318 |
+
|
| 319 |
+
# Slice to the relevant tail before allocating logits — the LM head is
|
| 320 |
+
# the largest matmul on small models because vocab >> hidden_size.
|
| 321 |
+
if logits_to_keep and labels is None:
|
| 322 |
+
keep = int(logits_to_keep)
|
| 323 |
+
tail = x[:, -keep:, :]
|
| 324 |
+
tail = self.norm(tail)
|
| 325 |
+
if self.span_engine is not None:
|
| 326 |
+
tail = self.span_engine(tail)
|
| 327 |
+
logits = self.lm_head(tail)
|
| 328 |
+
else:
|
| 329 |
+
x = self.norm(x)
|
| 330 |
+
if self.span_engine is not None:
|
| 331 |
+
x = self.span_engine(x)
|
| 332 |
+
logits = self.lm_head(x)
|
| 333 |
+
|
| 334 |
+
logits = self.grammar(logits)
|
| 335 |
+
logits = self.debt_ledger(logits)
|
| 336 |
+
|
| 337 |
+
loss = None
|
| 338 |
+
if labels is not None:
|
| 339 |
+
seq_len = min(logits.size(1), labels.size(1))
|
| 340 |
+
shift_logits = logits[:, :seq_len, :].contiguous()
|
| 341 |
+
shift_labels = labels[:, :seq_len].contiguous()
|
| 342 |
+
loss = F.cross_entropy(
|
| 343 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 344 |
+
shift_labels.view(-1),
|
| 345 |
+
ignore_index=-100,
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
return CausalLMOutput(loss=loss, logits=logits, hidden_states=x,
|
| 349 |
+
caches=caches if use_cache else None)
|
| 350 |
+
|
| 351 |
+
# -- utilities -------------------------------------------------------------
|
| 352 |
+
|
| 353 |
+
@torch.no_grad()
|
| 354 |
+
def prepare_for_inference(self) -> None:
|
| 355 |
+
"""Pre-pack every BitLinear so the first generation step is fast."""
|
| 356 |
+
for module in self.modules():
|
| 357 |
+
if isinstance(module, BitLinear):
|
| 358 |
+
module.prepare_for_inference()
|
| 359 |
+
|
| 360 |
+
def get_mode_config(self, mode: str = "balanced") -> dict:
|
| 361 |
+
modes = self.config.get("modes", {})
|
| 362 |
+
return modes.get(mode, modes.get("balanced", {}))
|
| 363 |
+
|
| 364 |
+
def count_parameters(self) -> dict:
|
| 365 |
+
total = sum(p.numel() for p in self.parameters())
|
| 366 |
+
ternary = sum(p.numel() for _, m in self.named_modules()
|
| 367 |
+
if isinstance(m, BitLinear) for p in m.parameters())
|
| 368 |
+
return {"total": total, "ternary": ternary, "fp32": total - ternary}
|
| 369 |
+
|
| 370 |
+
@classmethod
|
| 371 |
+
def from_config_file(cls, path: str) -> "Chimera51ForCausalLM":
|
| 372 |
+
with open(path, "r", encoding="utf-8") as fh:
|
| 373 |
+
config = json.load(fh)
|
| 374 |
+
return cls(config)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
__all__ = ["Chimera51ForCausalLM", "Chimera51Block", "CausalLMOutput",
|
| 378 |
+
"expand_layer_pattern"]
|
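
Putting the pieces together: the pattern shorthand expands into concrete layer types, the model is built from a plain config dict, and decoding carries per-layer caches while asking for only the final position's logits. A rough usage sketch; the tiny config below is invented for illustration (the shipped config.json is the real reference) and the optional subsystems are assumed to tolerate empty or disabled sections:

import torch
from chimera.model import Chimera51ForCausalLM, expand_layer_pattern

toy_config = {
    "hidden_size": 64, "vocab_size": 256, "num_hidden_layers": 4,
    "num_heads": 2, "head_dim": 16, "use_ternary": False,
    "backbone": {"layer_pattern": "GD XM GD SK"},
    "looping": {"enabled": False},
    "span_inference": {"enabled": False},
}

print(expand_layer_pattern(toy_config))
# ['gated_deltanet', 'xlstm_m', 'gated_deltanet', 'tsp_span_knot']

model = Chimera51ForCausalLM(toy_config).eval()
ids = torch.randint(0, 256, (1, 8))

with torch.no_grad():
    out = model(ids, use_cache=True, logits_to_keep=1)    # norm + LM head on the last token only
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    # Later steps feed just the new token plus the carried caches.
    out = model(next_id, caches=out.caches, use_cache=True, logits_to_keep=1)
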
chimera/moe.py
ADDED
|
@@ -0,0 +1,102 @@
|
| 1 |
+
"""
|
| 2 |
+
Sparse Mixture-of-Experts for Chimera (CPU-first).
|
| 3 |
+
|
| 4 |
+
Key design choices:
|
| 5 |
+
* Routing is computed in the model's compute dtype (no fp32 promotion):
|
| 6 |
+
the original draft cast every router input to fp32 which doubled memory
|
| 7 |
+
bandwidth for nothing on CPUs without dedicated softmax units.
|
| 8 |
+
* Dispatch uses ``index_select`` + boolean masks per expert. No global
|
| 9 |
+
``argsort`` of the routing pairs and no ``bincount`` table. This keeps
|
| 10 |
+
the path ``torch.compile``-friendly even when expert counts vary.
|
| 11 |
+
* All experts share an :class:`SwiGLUMLP` topology so weights can be packed
|
| 12 |
+
ternary identically to the rest of the model.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from .layers import SwiGLUMLP
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NoAuxMoEGate(nn.Module):
|
| 27 |
+
"""Top-k softmax router with optional bias-only correction (no aux loss)."""
|
| 28 |
+
|
| 29 |
+
__constants__ = ["n_routed_experts", "num_experts_per_tok"]
|
| 30 |
+
|
| 31 |
+
def __init__(self, hidden_size: int, n_routed_experts: int,
|
| 32 |
+
num_experts_per_tok: int = 2):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.n_routed_experts = int(n_routed_experts)
|
| 35 |
+
self.num_experts_per_tok = int(num_experts_per_tok)
|
| 36 |
+
self.weight = nn.Parameter(torch.empty(self.n_routed_experts, hidden_size))
|
| 37 |
+
nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5)
|
| 38 |
+
# Buffer (not a Parameter): bias correction updated by training scripts.
|
| 39 |
+
self.register_buffer("e_score_correction_bias",
|
| 40 |
+
torch.zeros(self.n_routed_experts))
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor):
|
| 43 |
+
# x: [N, D] in arbitrary dtype. Routing is stable enough in bf16/fp32.
|
| 44 |
+
scores = F.linear(x, self.weight) + self.e_score_correction_bias
|
| 45 |
+
probs = F.softmax(scores, dim=-1)
|
| 46 |
+
weights, indices = torch.topk(probs, self.num_experts_per_tok, dim=-1)
|
| 47 |
+
weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9)
|
| 48 |
+
return indices, weights
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class MoELayer(nn.Module):
|
| 52 |
+
"""Sparse MoE block with grouped expert dispatch."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, hidden_size: int, moe_intermediate_size: int,
|
| 55 |
+
n_routed_experts: int = 16, n_shared_experts: int = 1,
|
| 56 |
+
num_experts_per_tok: int = 2, use_ternary: bool = True):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.hidden_size = int(hidden_size)
|
| 59 |
+
self.n_routed_experts = int(n_routed_experts)
|
| 60 |
+
self.n_shared_experts = int(n_shared_experts)
|
| 61 |
+
self.num_experts_per_tok = int(num_experts_per_tok)
|
| 62 |
+
self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts,
|
| 63 |
+
self.num_experts_per_tok)
|
| 64 |
+
self.experts = nn.ModuleList([
|
| 65 |
+
SwiGLUMLP(self.hidden_size, moe_intermediate_size, use_ternary=use_ternary)
|
| 66 |
+
for _ in range(self.n_routed_experts)
|
| 67 |
+
])
|
| 68 |
+
if self.n_shared_experts > 0:
|
| 69 |
+
shared_inter = max(1, moe_intermediate_size * self.n_shared_experts)
|
| 70 |
+
self.shared_experts = SwiGLUMLP(self.hidden_size, shared_inter,
|
| 71 |
+
use_ternary=use_ternary)
|
| 72 |
+
else:
|
| 73 |
+
self.shared_experts = None
|
| 74 |
+
|
| 75 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
orig_shape = x.shape
|
| 77 |
+
flat = x.reshape(-1, self.hidden_size)
|
| 78 |
+
N = flat.size(0)
|
| 79 |
+
|
| 80 |
+
topk_idx, topk_w = self.gate(flat) # [N, k]
|
| 81 |
+
out = torch.zeros_like(flat)
|
| 82 |
+
|
| 83 |
+
# Per-expert dispatch via boolean masks: avoids the global argsort and
|
| 84 |
+
# ``bincount`` of the previous draft and keeps the structure compatible
|
| 85 |
+
# with torch.compile.
|
| 86 |
+
for e in range(self.n_routed_experts):
|
| 87 |
+
match = (topk_idx == e)
|
| 88 |
+
if not match.any():
|
| 89 |
+
continue
|
| 90 |
+
# Token positions and per-pair weights for this expert.
|
| 91 |
+
tok_pos, slot_pos = match.nonzero(as_tuple=True)
|
| 92 |
+
w = topk_w[tok_pos, slot_pos].unsqueeze(-1).to(out.dtype)
|
| 93 |
+
y = self.experts[e](flat.index_select(0, tok_pos))
|
| 94 |
+
out.index_add_(0, tok_pos, y * w)
|
| 95 |
+
|
| 96 |
+
if self.shared_experts is not None:
|
| 97 |
+
out = out + self.shared_experts(flat)
|
| 98 |
+
|
| 99 |
+
return out.reshape(orig_shape)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
__all__ = ["NoAuxMoEGate", "MoELayer", "SwiGLUMLP"]
|
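A minimal usage sketch of the layer above (illustrative, not shipped code): sizes are small stand-ins for the real config (256-dim hidden, 4 experts instead of 16), and it assumes the package imports as ``chimera``. The last lines show one plausible way a training script could nudge ``e_score_correction_bias`` from observed expert load; the project's actual balancing recipe is not specified here.

import torch
from chimera.moe import MoELayer

moe = MoELayer(hidden_size=256, moe_intermediate_size=512,
               n_routed_experts=4, num_experts_per_tok=2, use_ternary=False)
x = torch.randn(2, 10, 256)               # [batch, seq, hidden]
y = moe(x)                                # output keeps the input shape
assert y.shape == x.shape

# Bias-only balancing sketch: lower the correction bias of over-used experts.
with torch.no_grad():
    idx, _ = moe.gate(x.reshape(-1, 256))
    load = torch.bincount(idx.flatten(), minlength=moe.n_routed_experts).float()
    moe.gate.e_score_correction_bias -= 1e-3 * (load - load.mean()) / load.mean().clamp_min(1.0)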
chimera/multimodal.py
ADDED
|
@@ -0,0 +1,136 @@
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 — multimodal encoders (CPU-friendly, slim).
|
| 3 |
+
|
| 4 |
+
The previous draft had two latent issues:
|
| 5 |
+
* The vision/audio encoders projected to ``out_dim`` (e.g. 2560) which did
|
| 6 |
+
not match the trunk's ``hidden_size`` after scaling, so concatenating
|
| 7 |
+
image embeddings into the LM hidden stream blew up. We now project to
|
| 8 |
+
the trunk's hidden size by default.
|
| 9 |
+
* The internal ``_EncoderBlock`` wrapped a recurrent layer expecting a
|
| 10 |
+
``cache`` argument; we now call the layer correctly and discard the
|
| 11 |
+
cache (the encoder is purely parallel).
|
| 12 |
+
|
| 13 |
+
The encoders themselves remain BitLinear-friendly so they share the
|
| 14 |
+
ternary memory budget of the trunk.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from torch.utils.checkpoint import checkpoint
|
| 24 |
+
|
| 25 |
+
from .layers import GatedDeltaNetLayer
|
| 26 |
+
from .quantization import BitLinear, RMSNorm
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _make_linear(use_ternary: bool):
|
| 30 |
+
if use_ternary:
|
| 31 |
+
return BitLinear
|
| 32 |
+
return lambda i, o, **kw: nn.Linear(i, o, bias=False)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class PatchEmbed(nn.Module):
|
| 36 |
+
__constants__ = ["patch_size"]
|
| 37 |
+
|
| 38 |
+
def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.patch_size = int(patch_size)
|
| 41 |
+
self.proj = nn.Conv2d(in_channels, hidden_size,
|
| 42 |
+
kernel_size=self.patch_size, stride=self.patch_size)
|
| 43 |
+
self.norm = RMSNorm(hidden_size)
|
| 44 |
+
|
| 45 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
x = self.proj(x)
|
| 47 |
+
x = x.flatten(2).transpose(1, 2)
|
| 48 |
+
return self.norm(x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _EncoderBlock(nn.Module):
|
| 52 |
+
def __init__(self, hidden: int, num_heads: int, head_dim: int,
|
| 53 |
+
use_ternary: bool = True):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.norm = RMSNorm(hidden)
|
| 56 |
+
self.attn = GatedDeltaNetLayer(hidden, num_heads, head_dim,
|
| 57 |
+
use_ternary=use_ternary, chunk_size=64)
|
| 58 |
+
self.mlp_norm = RMSNorm(hidden)
|
| 59 |
+
L = _make_linear(use_ternary)
|
| 60 |
+
self.mlp = nn.Sequential(L(hidden, hidden * 4), nn.GELU(), L(hidden * 4, hidden))
|
| 61 |
+
|
| 62 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
attn_out, _ = self.attn(self.norm(x))
|
| 64 |
+
x = x + attn_out
|
| 65 |
+
return x + self.mlp(self.mlp_norm(x))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class _EncoderBase(nn.Module):
|
| 69 |
+
"""Shared encoder body for vision/audio."""
|
| 70 |
+
|
| 71 |
+
def __init__(self, hidden: int, depth: int, num_heads: int, head_dim: int,
|
| 72 |
+
out_dim: int, use_ternary: bool, use_checkpoint: bool):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.layers = nn.ModuleList([
|
| 75 |
+
_EncoderBlock(hidden, num_heads, head_dim, use_ternary)
|
| 76 |
+
for _ in range(depth)
|
| 77 |
+
])
|
| 78 |
+
self.proj = nn.Linear(hidden, out_dim, bias=False)
|
| 79 |
+
self.norm = RMSNorm(out_dim)
|
| 80 |
+
self.use_checkpoint = bool(use_checkpoint)
|
| 81 |
+
|
| 82 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 83 |
+
for layer in self.layers:
|
| 84 |
+
if self.use_checkpoint and self.training:
|
| 85 |
+
x = checkpoint(layer, x, use_reentrant=False)
|
| 86 |
+
else:
|
| 87 |
+
x = layer(x)
|
| 88 |
+
return self.norm(self.proj(x))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class VisionEncoder(nn.Module):
|
| 92 |
+
def __init__(self, config: dict):
|
| 93 |
+
super().__init__()
|
| 94 |
+
v = config.get("vision", {})
|
| 95 |
+
self.enabled = bool(config.get("enabled", True))
|
| 96 |
+
hidden = int(v.get("hidden", 384))
|
| 97 |
+
depth = int(v.get("depth", 12))
|
| 98 |
+
patch = int(v.get("patch", 16))
|
| 99 |
+
# Default the encoder output to the trunk hidden_size so concatenation
|
| 100 |
+
# into the LM stream is dimensionally consistent.
|
| 101 |
+
out_dim = int(v.get("out", config.get("hidden_size", hidden)))
|
| 102 |
+
use_ternary = v.get("quant", "ternary") == "ternary"
|
| 103 |
+
num_heads = max(1, hidden // 64)
|
| 104 |
+
head_dim = hidden // num_heads
|
| 105 |
+
self.patch_embed = PatchEmbed(patch_size=patch, hidden_size=hidden)
|
| 106 |
+
self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
|
| 107 |
+
out_dim, use_ternary, use_checkpoint=True)
|
| 108 |
+
|
| 109 |
+
def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
|
| 110 |
+
if not self.enabled:
|
| 111 |
+
return None
|
| 112 |
+
return self.body(self.patch_embed(pixel_values))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class AudioEncoder(nn.Module):
|
| 116 |
+
def __init__(self, config: dict):
|
| 117 |
+
super().__init__()
|
| 118 |
+
a = config.get("audio", {})
|
| 119 |
+
self.enabled = bool(config.get("enabled", True))
|
| 120 |
+
hidden = int(a.get("hidden", 256))
|
| 121 |
+
depth = int(a.get("depth", 6))
|
| 122 |
+
out_dim = int(a.get("out", config.get("hidden_size", hidden)))
|
| 123 |
+
use_ternary = a.get("quant", "ternary") == "ternary"
|
| 124 |
+
num_heads = max(1, hidden // 64)
|
| 125 |
+
head_dim = hidden // num_heads
|
| 126 |
+
self.input_proj = nn.Linear(80, hidden, bias=False)
|
| 127 |
+
self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
|
| 128 |
+
out_dim, use_ternary, use_checkpoint=True)
|
| 129 |
+
|
| 130 |
+
def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
|
| 131 |
+
if not self.enabled:
|
| 132 |
+
return None
|
| 133 |
+
return self.body(self.input_proj(mel_features))
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
__all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"]
|
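A small sketch of the vision path (illustrative, not shipped code): it assumes a 224x224 RGB input and a trunk ``hidden_size`` of 512 rather than the shipped 2560, and only demonstrates the default-``out_dim`` behaviour described in the module docstring.

import torch
from chimera.multimodal import VisionEncoder

cfg = {"enabled": True, "hidden_size": 512,
       "vision": {"hidden": 128, "depth": 2, "patch": 16}}
enc = VisionEncoder(cfg).eval()
with torch.no_grad():
    tokens = enc(torch.randn(1, 3, 224, 224))
# 224 / 16 = 14 patches per side -> 196 tokens, each projected to hidden_size,
# so they can be concatenated with the text embeddings along the sequence dim.
assert tokens.shape == (1, 196, 512)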
chimera/quantization.py
ADDED
|
@@ -0,0 +1,508 @@
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 — 1.58-bit Ternary Compute (CPU-First, Slim)
|
| 3 |
+
========================================================
|
| 4 |
+
Single, clean implementation of BitNet-1.58 ternary linear layers.
|
| 5 |
+
|
| 6 |
+
Design goals:
|
| 7 |
+
* Zero overhead at import time (no JIT, no kernel discovery).
|
| 8 |
+
* One fast pure-PyTorch path that vectorises everything; an optional
|
| 9 |
+
C++/OpenMP path that is loaded *lazily* and only used when it actually
|
| 10 |
+
beats PyTorch (small batches on inference).
|
| 11 |
+
* Cache the packed 2-bit weights between forward calls and only repack
|
| 12 |
+
when the latent FP32 weights are mutated (training step or MeZO).
|
| 13 |
+
* No data-dependent Python loops, no per-row mask construction at init.
|
| 14 |
+
|
| 15 |
+
Storage:
|
| 16 |
+
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
|
| 17 |
+
_packed: uint8 [M, ceil(K/4)] (2 bits per ternary value)
|
| 18 |
+
_alpha: fp32 [M] (per-row absolute mean scale)
|
| 19 |
+
|
| 20 |
+
Encoding (matches the C++ kernel):
|
| 21 |
+
-1 → 0b10
|
| 22 |
+
0 → 0b00
|
| 23 |
+
+1 → 0b01
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import math
|
| 29 |
+
import os
|
| 30 |
+
import threading
|
| 31 |
+
from typing import Optional, Tuple
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Lazy C++ kernel. We never compile it during ``import``; it is only built
|
| 40 |
+
# when explicitly requested via :func:`enable_native_kernel` or the env var
|
| 41 |
+
# ``CHIMERA_NATIVE=1``. All public APIs work with the pure-PyTorch path.
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
_NATIVE_LOCK = threading.Lock()
|
| 45 |
+
_NATIVE_EXT: Optional[object] = None
|
| 46 |
+
_NATIVE_TRIED = False
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_CPP_SOURCE = r"""
|
| 50 |
+
#include <torch/extension.h>
|
| 51 |
+
#include <cstdint>
|
| 52 |
+
#include <cmath>
|
| 53 |
+
#ifdef _OPENMP
|
| 54 |
+
#include <omp.h>
|
| 55 |
+
#endif
|
| 56 |
+
|
| 57 |
+
// Encoding: -1->0b10, 0->0b00, +1->0b01
|
| 58 |
+
static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};
|
| 59 |
+
|
| 60 |
+
torch::Tensor pack_ternary_cpu(torch::Tensor w) {
|
| 61 |
+
TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]");
|
| 62 |
+
auto w_c = w.contiguous();
|
| 63 |
+
int64_t M = w_c.size(0), K = w_c.size(1);
|
| 64 |
+
int64_t K4 = (K + 3) / 4;
|
| 65 |
+
auto out = torch::zeros({M, K4}, torch::kUInt8);
|
| 66 |
+
const int8_t* s = w_c.data_ptr<int8_t>();
|
| 67 |
+
uint8_t* d = out.data_ptr<uint8_t>();
|
| 68 |
+
#pragma omp parallel for schedule(static)
|
| 69 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 70 |
+
const int8_t* sr = s + m * K;
|
| 71 |
+
uint8_t* dr = d + m * K4;
|
| 72 |
+
for (int64_t k4 = 0; k4 < K4; ++k4) {
|
| 73 |
+
uint8_t b = 0;
|
| 74 |
+
for (int j = 0; j < 4; ++j) {
|
| 75 |
+
int64_t k = k4 * 4 + j;
|
| 76 |
+
if (k >= K) break;
|
| 77 |
+
int8_t v = sr[k];
|
| 78 |
+
uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u);
|
| 79 |
+
b |= (code << (6 - j * 2));
|
| 80 |
+
}
|
| 81 |
+
dr[k4] = b;
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
return out;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
|
| 88 |
+
TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
|
| 89 |
+
auto p = packed.contiguous();
|
| 90 |
+
int64_t M = p.size(0), K4 = p.size(1);
|
| 91 |
+
auto out = torch::empty({M, K}, torch::kFloat32);
|
| 92 |
+
const uint8_t* pp = p.data_ptr<uint8_t>();
|
| 93 |
+
float* dp = out.data_ptr<float>();
|
| 94 |
+
#pragma omp parallel for schedule(static)
|
| 95 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 96 |
+
const uint8_t* pr = pp + m * K4;
|
| 97 |
+
float* dr = dp + m * K;
|
| 98 |
+
for (int64_t k4 = 0; k4 < K4; ++k4) {
|
| 99 |
+
uint8_t b = pr[k4];
|
| 100 |
+
int64_t base = k4 * 4;
|
| 101 |
+
if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3];
|
| 102 |
+
if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3];
|
| 103 |
+
if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3];
|
| 104 |
+
if (base + 3 < K) dr[base + 3] = LUT[b & 3];
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
return out;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
// Fused "unpack and scale" -> bf16/fp32 dense weight. Saves a pass over memory
|
| 111 |
+
// and a temporary FP32 tensor when running under bf16 autocast.
|
| 112 |
+
torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
|
| 113 |
+
auto p = packed.contiguous();
|
| 114 |
+
auto a = alpha.contiguous().to(torch::kFloat32);
|
| 115 |
+
int64_t M = p.size(0), K4 = p.size(1);
|
| 116 |
+
auto out = torch::empty({M, K}, torch::kFloat32);
|
| 117 |
+
const uint8_t* pp = p.data_ptr<uint8_t>();
|
| 118 |
+
const float* ap = a.data_ptr<float>();
|
| 119 |
+
float* dp = out.data_ptr<float>();
|
| 120 |
+
#pragma omp parallel for schedule(static)
|
| 121 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 122 |
+
const uint8_t* pr = pp + m * K4;
|
| 123 |
+
float* dr = dp + m * K;
|
| 124 |
+
float sc = ap[m];
|
| 125 |
+
for (int64_t k4 = 0; k4 < K4; ++k4) {
|
| 126 |
+
uint8_t b = pr[k4];
|
| 127 |
+
int64_t base = k4 * 4;
|
| 128 |
+
if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc;
|
| 129 |
+
if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc;
|
| 130 |
+
if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc;
|
| 131 |
+
if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc;
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
return out;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 138 |
+
m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8");
|
| 139 |
+
m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}");
|
| 140 |
+
m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha");
|
| 141 |
+
}
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _try_load_native() -> Optional[object]:
|
| 146 |
+
"""Compile/load the optional native helper. Idempotent and thread-safe."""
|
| 147 |
+
global _NATIVE_EXT, _NATIVE_TRIED
|
| 148 |
+
if _NATIVE_TRIED:
|
| 149 |
+
return _NATIVE_EXT
|
| 150 |
+
with _NATIVE_LOCK:
|
| 151 |
+
if _NATIVE_TRIED:
|
| 152 |
+
return _NATIVE_EXT
|
| 153 |
+
_NATIVE_TRIED = True
|
| 154 |
+
try:
|
| 155 |
+
from torch.utils.cpp_extension import load_inline
|
| 156 |
+
|
| 157 |
+
build_dir = os.path.join(
|
| 158 |
+
os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
|
| 159 |
+
)
|
| 160 |
+
os.makedirs(build_dir, exist_ok=True)
|
| 161 |
+
_NATIVE_EXT = load_inline(
|
| 162 |
+
name="chimera_ternary",
|
| 163 |
+
cpp_sources=_CPP_SOURCE,
|
| 164 |
+
extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
|
| 165 |
+
extra_ldflags=["-lgomp"],
|
| 166 |
+
build_directory=build_dir,
|
| 167 |
+
verbose=False,
|
| 168 |
+
)
|
| 169 |
+
except Exception as exc: # pragma: no cover - best-effort.
|
| 170 |
+
os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
|
| 171 |
+
_NATIVE_EXT = None
|
| 172 |
+
return _NATIVE_EXT
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def enable_native_kernel(force: bool = False) -> bool:
|
| 176 |
+
"""Eagerly try to compile the native kernel.
|
| 177 |
+
|
| 178 |
+
Returns ``True`` if the kernel is loaded and available.
|
| 179 |
+
"""
|
| 180 |
+
global _NATIVE_TRIED
|
| 181 |
+
if force:
|
| 182 |
+
_NATIVE_TRIED = False
|
| 183 |
+
return _try_load_native() is not None
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def native_kernel_available() -> bool:
|
| 187 |
+
return _NATIVE_EXT is not None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Allow opt-in from the environment without code changes.
|
| 191 |
+
if os.environ.get("CHIMERA_NATIVE", "0") == "1":
|
| 192 |
+
enable_native_kernel()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Pure PyTorch ternary primitives (always available).
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
# Lookup tables compiled once. Casting to a registered buffer is overkill –
|
| 200 |
+
# they live on CPU and broadcast naturally.
|
| 201 |
+
_TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
|
| 202 |
+
_TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
|
| 203 |
+
_SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def pack_ternary(q: torch.Tensor) -> torch.Tensor:
|
| 207 |
+
"""Pack a ternary {-1,0,1} tensor into a 2-bit uint8 tensor.
|
| 208 |
+
|
| 209 |
+
Vectorised pure-PyTorch implementation — no Python loops over rows.
|
| 210 |
+
Trailing positions that don't divide by four are zero-padded.
|
| 211 |
+
"""
|
| 212 |
+
q = q.detach()
|
| 213 |
+
if q.dim() == 1:
|
| 214 |
+
q = q.unsqueeze(0)
|
| 215 |
+
flat = q.reshape(-1, q.shape[-1]).to(torch.int8)
|
| 216 |
+
M, K = flat.shape
|
| 217 |
+
K4 = (K + 3) // 4
|
| 218 |
+
pad = K4 * 4 - K
|
| 219 |
+
if pad:
|
| 220 |
+
flat = F.pad(flat, (0, pad))
|
| 221 |
+
# codes: 0 / 1 / 2 (uint8)
|
| 222 |
+
codes = torch.where(flat == 1, torch.full_like(flat, 1),
|
| 223 |
+
torch.where(flat == -1, torch.full_like(flat, 2), torch.zeros_like(flat))).to(torch.uint8)
|
| 224 |
+
codes = codes.view(M, K4, 4)
|
| 225 |
+
packed = ((codes[..., 0] << 6) | (codes[..., 1] << 4) |
|
| 226 |
+
(codes[..., 2] << 2) | codes[..., 3]).contiguous()
|
| 227 |
+
return packed.reshape(*q.shape[:-1], K4)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def unpack_ternary(packed: torch.Tensor, k: int,
|
| 231 |
+
alpha: Optional[torch.Tensor] = None,
|
| 232 |
+
dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
| 233 |
+
"""Vectorised inverse of :func:`pack_ternary`.
|
| 234 |
+
|
| 235 |
+
Returns ``out`` with last dim ``k``; optionally pre-multiplied by
|
| 236 |
+
``alpha`` (per-row scale, broadcastable on the leading axes).
|
| 237 |
+
"""
|
| 238 |
+
packed = packed.to(torch.uint8)
|
| 239 |
+
if packed.dim() == 1:
|
| 240 |
+
packed = packed.unsqueeze(0)
|
| 241 |
+
flat = packed.reshape(-1, packed.shape[-1])
|
| 242 |
+
M, K4 = flat.shape
|
| 243 |
+
# Gather all 4 sub-positions in one vectorised op.
|
| 244 |
+
shifts = _SHIFTS.to(packed.device)
|
| 245 |
+
codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long) # [M, K4, 4]
|
| 246 |
+
lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
|
| 247 |
+
out = lut[codes].reshape(M, K4 * 4)[:, :k]
|
| 248 |
+
if alpha is not None:
|
| 249 |
+
out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype)
|
| 250 |
+
return out.reshape(*packed.shape[:-1], k)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
|
| 254 |
+
"""Per-output-channel scale (``\alpha = mean|w|`` clamped)."""
|
| 255 |
+
return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def ternarize_weight(weight: torch.Tensor, group_size: int = 128
|
| 259 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 260 |
+
"""Quantise FP32 weights to ternary using BitNet's abs-mean rule.
|
| 261 |
+
|
| 262 |
+
``group_size`` is kept for API compatibility but every row is its own
|
| 263 |
+
group in this slim implementation. Returns ``(w_ternary, alpha)``.
|
| 264 |
+
"""
|
| 265 |
+
alpha = _absmean_alpha(weight)
|
| 266 |
+
w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
|
| 267 |
+
return w_q, alpha
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_quantize_weights_ternary = ternarize_weight # legacy alias used elsewhere
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
|
| 274 |
+
"""In-place N:M 2:4 pruning. Vectorised — no Python row loops."""
|
| 275 |
+
with torch.no_grad():
|
| 276 |
+
last = weight.shape[-1]
|
| 277 |
+
pad = (-last) % 4
|
| 278 |
+
target = F.pad(weight, (0, pad)) if pad else weight
|
| 279 |
+
view = target.view(*target.shape[:-1], -1, 4)
|
| 280 |
+
# Keep the two largest in absolute value, zero the rest.
|
| 281 |
+
idx = view.abs().argsort(dim=-1)[..., :2]
|
| 282 |
+
view.scatter_(-1, idx, 0.0)
|
| 283 |
+
if pad:
|
| 284 |
+
weight.copy_(target[..., :last])
|
| 285 |
+
return weight
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ---------------------------------------------------------------------------
|
| 289 |
+
# Straight-Through Estimator for ternary quantization.
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
|
| 292 |
+
class _RoundTernarySTE(torch.autograd.Function):
|
| 293 |
+
@staticmethod
|
| 294 |
+
def forward(ctx, w: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
| 295 |
+
return torch.round(torch.clamp(w, -1.0, 1.0))
|
| 296 |
+
|
| 297 |
+
@staticmethod
|
| 298 |
+
def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
|
| 299 |
+
# Standard STE: gradient flows through, clipped to [-1, 1] so the
|
| 300 |
+
# latent FP32 weights cannot drift unboundedly.
|
| 301 |
+
return grad_output.clamp(-1.0, 1.0)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
|
| 305 |
+
return _RoundTernarySTE.apply(w)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# BitLinear — single class, single fast path.
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
class BitLinear(nn.Module):
|
| 313 |
+
"""Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.
|
| 314 |
+
|
| 315 |
+
*Training (grad-enabled)*: STE ternarisation on the latent weight, dense
|
| 316 |
+
fp32/bf16 matmul. Backward flows to the latent weight via STE.
|
| 317 |
+
|
| 318 |
+
*Inference / no-grad*: weights are quantised once and cached as packed
|
| 319 |
+
2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or
|
| 320 |
+
optional C++ kernel) into a reusable buffer and calls a single matmul.
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
__constants__ = ["in_features", "out_features", "use_2_4"]
|
| 324 |
+
|
| 325 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False,
|
| 326 |
+
group_size: int = 128, nm_2_4: bool = False):
|
| 327 |
+
super().__init__()
|
| 328 |
+
self.in_features = int(in_features)
|
| 329 |
+
self.out_features = int(out_features)
|
| 330 |
+
self.group_size = int(group_size)
|
| 331 |
+
self.use_2_4 = bool(nm_2_4)
|
| 332 |
+
|
| 333 |
+
self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
|
| 334 |
+
if bias:
|
| 335 |
+
self.bias = nn.Parameter(torch.zeros(self.out_features))
|
| 336 |
+
else:
|
| 337 |
+
self.register_parameter("bias", None)
|
| 338 |
+
|
| 339 |
+
# Caches. ``_cache_version`` is bumped whenever the latent weight
|
| 340 |
+
# changes; the forward pass compares it against ``_packed_version``
|
| 341 |
+
# to know when to repack.
|
| 342 |
+
self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
|
| 343 |
+
self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 344 |
+
# Optional dense fp32 cache of the dequantised ternary weight. This
|
| 345 |
+
# is what every inference forward actually needs, so caching it
|
| 346 |
+
# eliminates the per-call unpack and saves ~30-50% of CPU time on
|
| 347 |
+
# small models. It is only built lazily on first inference call.
|
| 348 |
+
self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 349 |
+
self._packed_version = -1
|
| 350 |
+
self._dense_version = -1
|
| 351 |
+
self._cache_version = 0
|
| 352 |
+
|
| 353 |
+
self.reset_parameters()
|
| 354 |
+
|
| 355 |
+
# -- init ------------------------------------------------------------------
|
| 356 |
+
|
| 357 |
+
def reset_parameters(self) -> None:
|
| 358 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 359 |
+
if self.bias is not None:
|
| 360 |
+
nn.init.zeros_(self.bias)
|
| 361 |
+
self._cache_version += 1
|
| 362 |
+
|
| 363 |
+
# -- helpers ---------------------------------------------------------------
|
| 364 |
+
|
| 365 |
+
def invalidate_packed(self) -> None:
|
| 366 |
+
"""Mark the packed cache stale. Called after weight mutations."""
|
| 367 |
+
self._cache_version += 1
|
| 368 |
+
# Free the dense fp32 cache too; next forward will rebuild it.
|
| 369 |
+
if self._dense_w.numel() > 0:
|
| 370 |
+
self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
|
| 371 |
+
self._dense_version = -1
|
| 372 |
+
|
| 373 |
+
def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 374 |
+
"""Quantise the FP32 latent weight to ternary (no-grad, no copy)."""
|
| 375 |
+
with torch.no_grad():
|
| 376 |
+
w = self.weight
|
| 377 |
+
alpha = _absmean_alpha(w)
|
| 378 |
+
w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
|
| 379 |
+
if self.use_2_4:
|
| 380 |
+
apply_2_4_sparsity_(w_q)
|
| 381 |
+
return w_q.to(torch.int8), alpha
|
| 382 |
+
|
| 383 |
+
def _ensure_packed(self) -> None:
|
| 384 |
+
if self._packed_version == self._cache_version and self._packed.numel() > 0:
|
| 385 |
+
return
|
| 386 |
+
with torch.no_grad():
|
| 387 |
+
w_q, alpha = self._quantize_latent()
|
| 388 |
+
ext = _NATIVE_EXT
|
| 389 |
+
if ext is not None:
|
| 390 |
+
packed = ext.pack_ternary(w_q)
|
| 391 |
+
else:
|
| 392 |
+
packed = pack_ternary(w_q)
|
| 393 |
+
# Replace storage in-place to avoid breaking nn.Module buffer tracking.
|
| 394 |
+
self._packed = packed.contiguous()
|
| 395 |
+
self._alpha = alpha.contiguous()
|
| 396 |
+
self._packed_version = self._cache_version
|
| 397 |
+
|
| 398 |
+
@torch.no_grad()
|
| 399 |
+
def prepare_for_inference(self) -> None:
|
| 400 |
+
"""Materialise the packed cache so the next forward is allocation-free."""
|
| 401 |
+
self.invalidate_packed()
|
| 402 |
+
self._ensure_packed()
|
| 403 |
+
|
| 404 |
+
@torch.no_grad()
|
| 405 |
+
def ternary_nonzero_mask(self) -> torch.Tensor:
|
| 406 |
+
"""Boolean mask of currently non-zero ternary positions (cached)."""
|
| 407 |
+
self._ensure_packed()
|
| 408 |
+
# Reuse the dequantised float view through unpack — cheaper than a fresh
|
| 409 |
+
# dense ternary tensor on small models, and shared for both branches.
|
| 410 |
+
ext = _NATIVE_EXT
|
| 411 |
+
if ext is not None:
|
| 412 |
+
w = ext.unpack_ternary(self._packed, self.in_features)
|
| 413 |
+
else:
|
| 414 |
+
w = unpack_ternary(self._packed, self.in_features)
|
| 415 |
+
return w.ne(0)
|
| 416 |
+
|
| 417 |
+
# -- forward ---------------------------------------------------------------
|
| 418 |
+
|
| 419 |
+
def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
|
| 420 |
+
"""STE forward: differentiable, fp32/bf16 dense matmul."""
|
| 421 |
+
w = self.weight
|
| 422 |
+
alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
|
| 423 |
+
w_q = ste_ternary(w / alpha) * alpha
|
| 424 |
+
if self.use_2_4:
|
| 425 |
+
# 2:4 sparsity is non-differentiable but only zeros gradients on
|
| 426 |
+
# already-pruned positions; safe to apply during STE forward.
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
|
| 429 |
+
w_q = w_q * mask
|
| 430 |
+
return F.linear(x, w_q.to(x.dtype), self.bias)
|
| 431 |
+
|
| 432 |
+
def _ensure_dense(self) -> torch.Tensor:
|
| 433 |
+
"""Materialise (and cache) the fp32 dense ternary weight."""
|
| 434 |
+
self._ensure_packed()
|
| 435 |
+
if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
|
| 436 |
+
return self._dense_w
|
| 437 |
+
ext = _NATIVE_EXT
|
| 438 |
+
if ext is not None:
|
| 439 |
+
w = ext.dequantize(self._packed, self._alpha, self.in_features)
|
| 440 |
+
else:
|
| 441 |
+
w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
|
| 442 |
+
# Replace the buffer in place so nn.Module book-keeping stays valid.
|
| 443 |
+
self._dense_w = w.contiguous()
|
| 444 |
+
self._dense_version = self._cache_version
|
| 445 |
+
return self._dense_w
|
| 446 |
+
|
| 447 |
+
def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
|
| 448 |
+
"""No-grad fast path that uses the cached dequantised weights."""
|
| 449 |
+
w = self._ensure_dense()
|
| 450 |
+
# Match dtype (bf16 autocast support) without re-allocating the cache.
|
| 451 |
+
if x.dtype != w.dtype:
|
| 452 |
+
w_used = w.to(x.dtype)
|
| 453 |
+
else:
|
| 454 |
+
w_used = w
|
| 455 |
+
return F.linear(x, w_used, self.bias)
|
| 456 |
+
|
| 457 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 458 |
+
if self.training and torch.is_grad_enabled():
|
| 459 |
+
return self._forward_train(x)
|
| 460 |
+
return self._forward_packed(x)
|
| 461 |
+
|
| 462 |
+
# -- introspection ---------------------------------------------------------
|
| 463 |
+
|
| 464 |
+
def extra_repr(self) -> str:
|
| 465 |
+
return (f"in_features={self.in_features}, out_features={self.out_features}, "
|
| 466 |
+
f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
|
| 467 |
+
f"native={native_kernel_available()}")
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
# ---------------------------------------------------------------------------
|
| 471 |
+
# RMSNorm.
|
| 472 |
+
# ---------------------------------------------------------------------------
|
| 473 |
+
|
| 474 |
+
class RMSNorm(nn.Module):
|
| 475 |
+
"""Numerically-stable Root Mean Square LayerNorm (no bias, no centering)."""
|
| 476 |
+
|
| 477 |
+
__constants__ = ["dim", "eps"]
|
| 478 |
+
|
| 479 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 480 |
+
super().__init__()
|
| 481 |
+
self.dim = int(dim)
|
| 482 |
+
self.eps = float(eps)
|
| 483 |
+
self.weight = nn.Parameter(torch.ones(self.dim))
|
| 484 |
+
|
| 485 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 486 |
+
# The normalisation is computed in fp32 for stability under bf16
|
| 487 |
+
# autocast, then cast back to the input dtype.
|
| 488 |
+
dtype = x.dtype
|
| 489 |
+
if dtype != torch.float32:
|
| 490 |
+
x32 = x.float()
|
| 491 |
+
rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
|
| 492 |
+
return (x32 * rms * self.weight).to(dtype)
|
| 493 |
+
rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
|
| 494 |
+
return x * rms * self.weight
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
__all__ = [
|
| 498 |
+
"BitLinear",
|
| 499 |
+
"RMSNorm",
|
| 500 |
+
"ste_ternary",
|
| 501 |
+
"pack_ternary",
|
| 502 |
+
"unpack_ternary",
|
| 503 |
+
"ternarize_weight",
|
| 504 |
+
"_quantize_weights_ternary",
|
| 505 |
+
"apply_2_4_sparsity_",
|
| 506 |
+
"enable_native_kernel",
|
| 507 |
+
"native_kernel_available",
|
| 508 |
+
]
|
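A usage sketch of the primitives above (illustrative shapes, not shipped code), assuming the package imports as ``chimera``. It walks the pack/unpack round trip, the 2:4 pruning helper, and the BitLinear cache lifecycle: STE training path, ``prepare_for_inference``, and ``invalidate_packed`` after a weight mutation.

import torch
from chimera.quantization import (BitLinear, apply_2_4_sparsity_,
                                  pack_ternary, ternarize_weight, unpack_ternary)

# Pack/unpack round trip.
w = torch.randn(8, 16)
w_q, alpha = ternarize_weight(w)            # int8 {-1, 0, 1} + per-row scale
packed = pack_ternary(w_q)                  # uint8, last dim = ceil(16 / 4) = 4
assert torch.equal(unpack_ternary(packed, 16).to(torch.int8), w_q)

# 2:4 pruning keeps the two largest-magnitude values in each group of four.
row = torch.tensor([[0.1, -2.0, 0.5, 3.0]])
apply_2_4_sparsity_(row)                    # row is now [[0.0, -2.0, 0.0, 3.0]]

# BitLinear cache lifecycle.
layer = BitLinear(16, 8)
layer.train()
out = layer(torch.randn(2, 16))             # STE path, differentiable
layer.eval()
layer.prepare_for_inference()               # build the packed + dense caches once
with torch.no_grad():
    out = layer(torch.randn(2, 16))         # cached dequantised matmul
    layer.weight.add_(1e-2)                 # e.g. a MeZO-style perturbation ...
    layer.invalidate_packed()               # ... must invalidate the caches
# Optional: enable_native_kernel() lazily builds the C++/OpenMP helper.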
chimera/tokenizer.py
ADDED
|
@@ -0,0 +1,160 @@
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.1 — Splintr (Rust) Tokenizer Wrapper — o200k_base (OpenAI o1/o3)
|
| 3 |
+
Wraps splintr's high-performance Rust tokenizer in a transformers-compatible API.
|
| 4 |
+
Vocab: o200k_base (200,073 tokens) — OpenAI's o1/o3 tokenizer.
|
| 5 |
+
|
| 6 |
+
Optimizations:
|
| 7 |
+
- __slots__ for reduced memory footprint
|
| 8 |
+
- Cached special token set for fast skip_special_tokens filtering
|
| 9 |
+
- Batch encode uses list comprehension (minimizes Python overhead)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from typing import List, Union, Optional
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from splintr import Tokenizer as _SplintrTokenizer, O200K_AGENT_TOKENS
|
| 17 |
+
HAS_SPLINTR = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
HAS_SPLINTR = False
|
| 20 |
+
|
| 21 |
+
__all__ = ["ChimeraTokenizer"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ChimeraTokenizer:
|
| 25 |
+
"""
|
| 26 |
+
High-performance Rust-backed tokenizer (splintr) with HuggingFace-like interface.
|
| 27 |
+
Falls back to a basic tiktoken wrapper if splintr is not installed.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, pretrained: str = "o200k_base", vocab_size: int = 200073):
|
| 31 |
+
if not HAS_SPLINTR:
|
| 32 |
+
self._tok = None
|
| 33 |
+
self.vocab_size = int(vocab_size)
|
| 34 |
+
self.eos_token_id = min(self.vocab_size - 1, 199999)
|
| 35 |
+
self.pad_token_id = min(self.vocab_size - 1, 200058)
|
| 36 |
+
self.sep_token_id = min(self.vocab_size - 1, 200060)
|
| 37 |
+
self.stop_token_id = min(self.vocab_size - 1, 200059)
|
| 38 |
+
self.user_token_id = min(self.vocab_size - 1, 200020)
|
| 39 |
+
self.assistant_token_id = min(self.vocab_size - 1, 200021)
|
| 40 |
+
self.system_token_id = min(self.vocab_size - 1, 200019)
|
| 41 |
+
self.endofprompt_token_id = min(self.vocab_size - 1, 200018)
|
| 42 |
+
self.bos_token_id = self.eos_token_id
|
| 43 |
+
self.eos_token = "<|endoftext|>"
|
| 44 |
+
self.pad_token = "<|pad|>"
|
| 45 |
+
self.model_max_length = 4194304
|
| 46 |
+
self._special_ids = frozenset({self.eos_token_id, self.pad_token_id, self.sep_token_id, self.stop_token_id, self.user_token_id, self.assistant_token_id, self.system_token_id, self.endofprompt_token_id})
|
| 47 |
+
self._byte_offset = 3
|
| 48 |
+
return
|
| 49 |
+
self._tok = _SplintrTokenizer.from_pretrained(pretrained)
|
| 50 |
+
self.vocab_size = self._tok.vocab_size
|
| 51 |
+
|
| 52 |
+
# o200k_base single-token special IDs
|
| 53 |
+
self.eos_token_id = 199999
|
| 54 |
+
self.pad_token_id = O200K_AGENT_TOKENS.PAD # 200058
|
| 55 |
+
self.sep_token_id = O200K_AGENT_TOKENS.SEP # 200060
|
| 56 |
+
self.stop_token_id = O200K_AGENT_TOKENS.STOP # 200059
|
| 57 |
+
self.user_token_id = O200K_AGENT_TOKENS.USER # 200020
|
| 58 |
+
self.assistant_token_id = O200K_AGENT_TOKENS.ASSISTANT # 200021
|
| 59 |
+
self.system_token_id = 200019
|
| 60 |
+
self.endofprompt_token_id = 200018
|
| 61 |
+
self.bos_token_id = self.eos_token_id
|
| 62 |
+
|
| 63 |
+
self.eos_token = "<|endoftext|>"
|
| 64 |
+
self.pad_token = "<|pad|>"
|
| 65 |
+
self.model_max_length = 4194304
|
| 66 |
+
|
| 67 |
+
# Cached set for fast filtering
|
| 68 |
+
self._special_ids = frozenset({
|
| 69 |
+
self.eos_token_id, self.pad_token_id, self.sep_token_id,
|
| 70 |
+
self.stop_token_id, self.user_token_id,
|
| 71 |
+
self.assistant_token_id, self.system_token_id,
|
| 72 |
+
self.endofprompt_token_id,
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
def __len__(self) -> int:
|
| 76 |
+
return self.vocab_size
|
| 77 |
+
|
| 78 |
+
def encode(self, text: str, add_special_tokens: bool = True,
|
| 79 |
+
max_length: Optional[int] = None) -> List[int]:
|
| 80 |
+
if self._tok is None:
|
| 81 |
+
ids = [self._byte_offset + b for b in text.encode("utf-8", errors="replace")]
|
| 82 |
+
else:
|
| 83 |
+
ids = self._tok.encode(text)
|
| 84 |
+
if add_special_tokens:
|
| 85 |
+
ids = ids + [self.eos_token_id]
|
| 86 |
+
if max_length is not None and len(ids) > max_length:
|
| 87 |
+
ids = ids[:max_length]
|
| 88 |
+
return ids
|
| 89 |
+
|
| 90 |
+
def encode_batch(self, texts: List[str], add_special_tokens: bool = True,
|
| 91 |
+
max_length: Optional[int] = None,
|
| 92 |
+
padding: bool = False,
|
| 93 |
+
truncation: bool = False,
|
| 94 |
+
return_tensors: Optional[str] = None):
|
| 95 |
+
all_ids = [self.encode(t, add_special_tokens=add_special_tokens,
|
| 96 |
+
max_length=max_length)
|
| 97 |
+
for t in texts]
|
| 98 |
+
if padding:
|
| 99 |
+
max_len = max(len(ids) for ids in all_ids)
|
| 100 |
+
all_ids = [ids + [self.pad_token_id] * (max_len - len(ids))
|
| 101 |
+
for ids in all_ids]
|
| 102 |
+
if return_tensors == "pt":
|
| 103 |
+
return {"input_ids": torch.tensor(all_ids, dtype=torch.long)}
|
| 104 |
+
return all_ids
|
| 105 |
+
|
| 106 |
+
def decode(self, token_ids, skip_special_tokens: bool = True) -> str:
|
| 107 |
+
if isinstance(token_ids, torch.Tensor):
|
| 108 |
+
token_ids = token_ids.tolist()
|
| 109 |
+
if skip_special_tokens:
|
| 110 |
+
token_ids = [t for t in token_ids if t not in self._special_ids]
|
| 111 |
+
if self._tok is None:
|
| 112 |
+
data = bytes(max(0, min(255, int(t) - self._byte_offset)) for t in token_ids if int(t) >= self._byte_offset)
|
| 113 |
+
return data.decode("utf-8", errors="replace")
|
| 114 |
+
return self._tok.decode(token_ids)
|
| 115 |
+
|
| 116 |
+
def decode_batch(self, token_ids_list, skip_special_tokens: bool = True) -> List[str]:
|
| 117 |
+
return [self.decode(ids, skip_special_tokens=skip_special_tokens)
|
| 118 |
+
for ids in token_ids_list]
|
| 119 |
+
|
| 120 |
+
def __call__(self, text, **kwargs) -> dict:
|
| 121 |
+
return_tensors = kwargs.get("return_tensors", "pt")
|
| 122 |
+
padding = kwargs.get("padding", False)
|
| 123 |
+
max_length = kwargs.get("max_length", None)
|
| 124 |
+
add_special_tokens = kwargs.get("add_special_tokens", True)
|
| 125 |
+
if isinstance(text, str):
|
| 126 |
+
text = [text]
|
| 127 |
+
result = self.encode_batch(
|
| 128 |
+
text, add_special_tokens=add_special_tokens,
|
| 129 |
+
max_length=max_length, padding=padding,
|
| 130 |
+
return_tensors=return_tensors
|
| 131 |
+
)
|
| 132 |
+
if isinstance(result, list):
|
| 133 |
+
return {"input_ids": torch.tensor(result, dtype=torch.long)}
|
| 134 |
+
return result
|
| 135 |
+
|
| 136 |
+
def get_vocab(self) -> dict:
|
| 137 |
+
return {
|
| 138 |
+
self.eos_token_id: self.eos_token,
|
| 139 |
+
self.pad_token_id: self.pad_token,
|
| 140 |
+
self.user_token_id: "<|user|>",
|
| 141 |
+
self.assistant_token_id: "<|assistant|>",
|
| 142 |
+
self.system_token_id: "<|system|>",
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
def apply_chat_template(self, messages: List[dict],
|
| 146 |
+
add_generation_prompt: bool = False) -> str:
|
| 147 |
+
parts = []
|
| 148 |
+
for msg in messages:
|
| 149 |
+
role = msg.get("role", "user")
|
| 150 |
+
content = msg.get("content", "")
|
| 151 |
+
if role == "system":
|
| 152 |
+
parts.append(f"<|system|>\n{content}\n<|endofprompt|>")
|
| 153 |
+
elif role == "user":
|
| 154 |
+
parts.append(f"<|user|>\n{content}\n<|endofprompt|>")
|
| 155 |
+
elif role == "assistant":
|
| 156 |
+
parts.append(f"<|assistant|>\n{content}\n<|endofprompt|>")
|
| 157 |
+
text = "\n".join(parts)
|
| 158 |
+
if add_generation_prompt:
|
| 159 |
+
text += "\n<|assistant|>\n"
|
| 160 |
+
return text
|
config.json
ADDED
|
@@ -0,0 +1,638 @@
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "chimera-5.1-final",
|
| 3 |
+
"_v": "5.1.2",
|
| 4 |
+
"architectures": ["Chimera51ForCausalLM"],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_chimera51.Chimera51Config",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_chimera51.Chimera51ForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "chimera51",
|
| 10 |
+
"token_ids": [199999, 200058],
|
| 11 |
+
"hidden_size": 2560,
|
| 12 |
+
"intermediate_size": 6912,
|
| 13 |
+
"num_hidden_layers": 28,
|
| 14 |
+
"num_heads": 40,
|
| 15 |
+
"head_dim": 64,
|
| 16 |
+
"hidden_act": "swiglu",
|
| 17 |
+
"initializer_range": 0.006,
|
| 18 |
+
"rms_norm_eps": 1e-6,
|
| 19 |
+
"rms_norm_before_every_linear": true,
|
| 20 |
+
"vocab_size": 200073,
|
| 21 |
+
"max_position_embeddings": 4194304,
|
| 22 |
+
"tie_word_embeddings": true,
|
| 23 |
+
"torch_dtype": "bfloat16",
|
| 24 |
+
"use_cache": false,
|
| 25 |
+
"transformers_version": "4.58.0",
|
| 26 |
+
|
| 27 |
+
"§": {
|
| 28 |
+
"r0": "2412.06464",
|
| 29 |
+
"r1": "2405.04517",
|
| 30 |
+
"r2": "2501.00663",
|
| 31 |
+
"r3": "2604.12946",
|
| 32 |
+
"r4": "2510.04800",
|
| 33 |
+
"r5": "2402.17764",
|
| 34 |
+
"r6": "2505.08823",
|
| 35 |
+
"r7": "2502.11880",
|
| 36 |
+
"r8": "2601.07892",
|
| 37 |
+
"r9": "2602.05269",
|
| 38 |
+
"r10": "2503.01840",
|
| 39 |
+
"r11": "2505.14969",
|
| 40 |
+
"r12": "2411.15100",
|
| 41 |
+
"r13": "2601.04426",
|
| 42 |
+
"r14": "2604.06169",
|
| 43 |
+
"r15": "2602.02369",
|
| 44 |
+
"r16": "2402.04624",
|
| 45 |
+
"r17": "2508.16153",
|
| 46 |
+
"r18": "2310.00533",
|
| 47 |
+
"r19": "2404.02258",
|
| 48 |
+
"r20": "2510.11170",
|
| 49 |
+
"r21": "2408.15664",
|
| 50 |
+
"r22": "2512.12602",
|
| 51 |
+
"r23": "2412.09871",
|
| 52 |
+
"r24": "2501.15570",
|
| 53 |
+
"r25": "2506.12119",
|
| 54 |
+
"r26": "2407.00088",
|
| 55 |
+
"r27": "2410.16144",
|
| 56 |
+
"r28": "2512.06443",
|
| 57 |
+
"r29": "2305.17333",
|
| 58 |
+
"r30": "2509.00031",
|
| 59 |
+
"r31": "2305.17190",
|
| 60 |
+
"r32": "2402.16363",
|
| 61 |
+
"r33": "2502.12444",
|
| 62 |
+
"r34": "2603.13931",
|
| 63 |
+
"r35": "2302.04852",
|
| 64 |
+
"r36": "2305.02299"
|
| 65 |
+
},
|
| 66 |
+
|
| 67 |
+
"quantization": {
|
| 68 |
+
"method": "bitnet",
|
| 69 |
+
"linear_class": "ternary_bitplane",
|
| 70 |
+
"weight_bits": 1.58,
|
| 71 |
+
"weight_values": [-1, 0, 1],
|
| 72 |
+
"weight_scale": "absmean_per_group",
|
| 73 |
+
"group_size": 128,
|
| 74 |
+
"activation_bits": 8,
|
| 75 |
+
"activation_method": "absmax_per_block",
|
| 76 |
+
"activation_block_size": 64,
|
| 77 |
+
"accumulator_dtype": "int32",
|
| 78 |
+
"norm_dtype": "float32",
|
| 79 |
+
"runtime_kernel": "TL2_bitnet_cpp",
|
| 80 |
+
"§": ["r5", "r7", "r27"],
|
| 81 |
+
"sherry_mode": {
|
| 82 |
+
"enabled": false,
|
| 83 |
+
"bits": 1.25,
|
| 84 |
+
"§": "r8"
|
| 85 |
+
},
|
| 86 |
+
"hgf_correction": {
|
| 87 |
+
"enabled": false,
|
| 88 |
+
"§": "r9"
|
| 89 |
+
}
|
| 90 |
+
},
|
| 91 |
+
|
| 92 |
+
"backbone": {
|
| 93 |
+
"type": "hybrid_recurrent_no_attention",
|
| 94 |
+
"layer_pattern": "GD XM GD TM GD XM GD SK",
|
| 95 |
+
"layer_pattern_repeat": 3.5,
|
| 96 |
+
"layer_aliases": {
|
| 97 |
+
"GD": "gated_deltanet",
|
| 98 |
+
"XM": "xlstm_m",
|
| 99 |
+
"TM": "titans_mac",
|
| 100 |
+
"SK": "tsp_span_knot"
|
| 101 |
+
},
|
| 102 |
+
"layer_counts": {"GD": 14, "XM": 7, "TM": 4, "SK": 3},
|
| 103 |
+
"kv_cache": "none",
|
| 104 |
+
"§": ["r0", "r1", "r2", "r4"],
|
| 105 |
+
|
| 106 |
+
"moe": {
|
| 107 |
+
"enabled": true,
|
| 108 |
+
"layers": [3, 7, 11, 15, 19, 23, 27],
|
| 109 |
+
"n_routed_experts": 16,
|
| 110 |
+
"n_shared_experts": 1,
|
| 111 |
+
"num_experts_per_tok": 2,
|
| 112 |
+
"moe_intermediate_size": 1728,
|
| 113 |
+
"routing": "noaux_bias",
|
| 114 |
+
"total_params": "350M",
|
| 115 |
+
"active_params_per_tok": "44M",
|
| 116 |
+
"§": ["r21", "r25"]
|
| 117 |
+
}
|
| 118 |
+
},
|
| 119 |
+
|
| 120 |
+
"gated_deltanet": {
|
| 121 |
+
"formulation": "S_t = S_{t-1} * (α_t * (I - β_t * k_t * k_t^T)) + β_t * v_t * k_t^T",
|
| 122 |
+
"alpha_gate": "data_dependent_scalar",
|
| 123 |
+
"beta_gate": "data_dependent_scalar",
|
| 124 |
+
"state_size": 64,
|
| 125 |
+
"chunkwise_parallel": true,
|
| 126 |
+
"chunk_size": 256,
|
| 127 |
+
"key_norm": "l2",
|
| 128 |
+
"§": "r0"
|
| 129 |
+
},
|
| 130 |
+
|
| 131 |
+
"efla": {
|
| 132 |
+
"enabled": false,
|
| 133 |
+
"target_layers": "SK",
|
| 134 |
+
"§": "r22"
|
| 135 |
+
},
|
| 136 |
+
|
| 137 |
+
"xlstm": {
|
| 138 |
+
"variant": "mLSTM",
|
| 139 |
+
"exponential_gating": true,
|
| 140 |
+
"memory_size_per_head": [64, 64],
|
| 141 |
+
"covariance_update": true,
|
| 142 |
+
"normalizer_state": "max_stabilized",
|
| 143 |
+
"§": "r1"
|
| 144 |
+
},
|
| 145 |
+
|
| 146 |
+
"titans": {
|
| 147 |
+
"memory_type": "MAC",
|
| 148 |
+
"memory_depth": 2,
|
| 149 |
+
"surprise_metric": "gradient_with_momentum",
|
| 150 |
+
"surprise_formula": "S_t = η_t · S_{t-1} − θ_t · ∇ℓ(M_{t-1}; x_t)",
|
| 151 |
+
"forgetting_formula": "M_t = (1 − α_t) · M_{t-1} + S_t",
|
| 152 |
+
"persistent_memory_slots": 64,
|
| 153 |
+
"local_window_size": 1024,
|
| 154 |
+
"§": "r2"
|
| 155 |
+
},
|
| 156 |
+
|
| 157 |
+
"looping": {
|
| 158 |
+
"enabled": true,
|
| 159 |
+
"method": "parcae_zoh_stable",
|
| 160 |
+
"prelude": [0, 3],
|
| 161 |
+
"loop": [4, 23],
|
| 162 |
+
"coda": [24, 27],
|
| 163 |
+
"loop_range": [1, 6],
|
| 164 |
+
"loop_default": 2,
|
| 165 |
+
"stability_A": "diag_negative_exp",
|
| 166 |
+
"spectral_radius_bound": 1.0,
|
| 167 |
+
"depth_selection": "stochastic_per_sequence",
|
| 168 |
+
"adaptive_exit_threshold": 0.01,
|
| 169 |
+
"backward_truncation": "half",
|
| 170 |
+
"§": "r3"
|
| 171 |
+
},
|
| 172 |
+
|
| 173 |
+
"span_inference": {
|
| 174 |
+
"enabled": true,
|
| 175 |
+
"bank_entries": 524288,
|
| 176 |
+
"bank_avg_tokens": 5,
|
| 177 |
+
"bank_max_tokens": 64,
|
| 178 |
+
"bank_memory_mb": 384,
|
| 179 |
+
"candidate_sources": [64, 48, 48, 32],
|
| 180 |
+
"candidate_source_keys": ["semantic_lsh", "grammar_allowed", "cache_hits", "neural_novel"],
|
| 181 |
+
"candidates_fast": 192,
|
| 182 |
+
"candidates_reason": 512,
|
| 183 |
+
|
| 184 |
+
"tree_verify": {
|
| 185 |
+
"enabled": true,
|
| 186 |
+
"method": "STree",
|
| 187 |
+
"tree_width": 4,
|
| 188 |
+
"tree_depth": 5,
|
| 189 |
+
"hardware_aware": true,
|
| 190 |
+
"§": "r11"
|
| 191 |
+
},
|
| 192 |
+
|
| 193 |
+
"certificate_fields": ["span_id_u32", "semantic_delta_8192b", "grammar_delta_128b", "entity_delta_512b", "debt_delta_64b", "boundary_logprob_i16", "interior_risk_u8"],
|
| 194 |
+
"certificate_verify_max_us": 100,
|
| 195 |
+
"adaptive_mask_cache": true,
|
| 196 |
+
"render_queue_target": 256,
|
| 197 |
+
"render_queue_max": 2048,
|
| 198 |
+
"fallback_below_acceptance": 0.5,
|
| 199 |
+
|
| 200 |
+
"scoring_keys": ["semantic", "grammar", "memory", "debt", "boundary"],
|
| 201 |
+
"scoring_weights_fast": [1.0, 0.8, 0.5, 0.7, 0.35],
|
| 202 |
+
"§": ["r10", "r12"]
|
| 203 |
+
},
|
| 204 |
+
|
| 205 |
+
"tsp_knot": {
|
| 206 |
+
"energy_terms": {
|
| 207 |
+
"autoregressive": [1.0, "embedding_inner_product"],
|
| 208 |
+
"memory_coherence": [0.3, "hamming_to_semantic_sketch"],
|
| 209 |
+
"binding_fidelity": [0.2, "xor_unbind_popcount"],
|
| 210 |
+
"grammar": [0.4, "fst_transition_cost"],
|
| 211 |
+
"debt": [0.3, "obligation_delta"]
|
| 212 |
+
},
|
| 213 |
+
"relaxation_phase1": "gated_deltanet_update",
|
| 214 |
+
"relaxation_phase2_max_iters": 3,
|
| 215 |
+
"relaxation_phase2_flip_fraction": 0.02,
|
| 216 |
+
"early_exit_delta_e": 1e-4
|
| 217 |
+
},
|
| 218 |
+
|
| 219 |
+
"grammar": {
|
| 220 |
+
"enabled": true,
|
| 221 |
+
"modes": ["plain_text", "dialogue", "markdown", "json", "python", "javascript", "sql", "math_latex", "shell"],
|
| 222 |
+
"representation": "deterministic_fst_plus_weighted",
|
| 223 |
+
"storage_mb": 64,
|
| 224 |
+
"hard_constraints": ["balanced_brackets", "valid_json_in_json_mode", "fence_closure", "string_literal_closure"],
|
| 225 |
+
"soft_constraints": ["sentence_rhythm", "repetition_avoidance", "paragraph_length"],
|
| 226 |
+
"adaptive_mask_cache": true,
|
| 227 |
+
"jit_compilation": true,
|
| 228 |
+
"§": ["r12", "r13"]
|
| 229 |
+
},
|
| 230 |
+
|
| 231 |
+
"semantic_memory": {
|
| 232 |
+
"vector_bits": 8192,
|
| 233 |
+
"vector_storage": "uint64_x128",
|
| 234 |
+
"capacity": 200000,
|
| 235 |
+
"relations": 500000,
|
| 236 |
+
"memory_mb": 320,
|
| 237 |
+
"ops": ["xor_bind", "xor_unbind", "majority_bundle", "popcnt_hamming", "rotate_permute"],
|
| 238 |
+
"lsh_tables": 64,
|
| 239 |
+
"lsh_bits_per_table": 14,
|
| 240 |
+
"hot_cache_entries": 16384,
|
| 241 |
+
"read_at_every_knot": true,
|
| 242 |
+
"write_policy": "surprise_threshold_plus_contrastive_validation",
|
| 243 |
+
"forgetting_policy": "fixed_pool_exponential_decay",
|
| 244 |
+
"pool_size_fixed": true,
|
| 245 |
+
"§": ["r15", "r16"]
|
| 246 |
+
},
|
| 247 |
+
|
| 248 |
+
"entropy_valve": {
|
| 249 |
+
"enabled": true,
|
| 250 |
+
"metrics": ["span_energy_margin", "grammar_branching", "sketch_instability", "entity_conflicts", "debt_pressure", "queue_depth"],
|
| 251 |
+
"threshold_bits": 2.0,
|
| 252 |
+
"type": "inference_time_compute_allocation",
|
| 253 |
+
"loop_depth_router": {
|
| 254 |
+
"method": "mod_causal_predictor",
|
| 255 |
+
"accuracy_target": 0.97,
|
| 256 |
+
"§": "r19"
|
| 257 |
+
},
|
| 258 |
+
"levels": {
|
| 259 |
+
"low": {"loops": 1, "min_span": 8, "audit": 0.125},
|
| 260 |
+
"medium": {"loops": 2, "min_span": 4, "audit": 0.5},
|
| 261 |
+
"high": {"loops": 4, "min_span": 1, "audit": 1.0}
|
| 262 |
+
},
|
| 263 |
+
"§": "r20"
|
| 264 |
+
},
|
| 265 |
+
|
| 266 |
+
"debt_ledger": {
|
| 267 |
+
"enabled": true,
|
| 268 |
+
"obligations": ["close_bracket", "close_string", "close_fence", "resolve_pronoun", "finish_list", "maintain_tense", "complete_sentence", "end_json_object"],
|
| 269 |
+
"max_outstanding": 64,
|
| 270 |
+
"pressure_weight": 0.3
|
| 271 |
+
},
|
| 272 |
+
|
| 273 |
+
"self_evolution": {
|
| 274 |
+
"num_mechanisms": 7,
|
| 275 |
+
|
| 276 |
+
"tier1": {
|
| 277 |
+
"ttt": {
|
| 278 |
+
"enabled": true,
|
| 279 |
+
"target_layers": [13, 23],
|
| 280 |
+
"target_param": "mlp_w_down",
|
| 281 |
+
"inner_lr": 0.0003,
|
| 282 |
+
"inner_optimizer": "sgd_momentum",
|
| 283 |
+
"momentum": 0.9,
|
| 284 |
+
"objective": "next_token_prediction",
|
| 285 |
+
"chunk_size": 1024,
|
| 286 |
+
"update_scope": "full_w_down",
|
| 287 |
+
"reset_decay": 0.95,
|
| 288 |
+
"persistence": "per_user_session_file",
|
| 289 |
+
"§": "r14"
|
| 290 |
+
},
|
| 291 |
+
"memory_growth": {
|
| 292 |
+
"enabled": true,
|
| 293 |
+
"surprise_threshold": "titans_gradient_magnitude_above_2_sigma",
|
| 294 |
+
"contrastive_validation": true,
|
| 295 |
+
"user_explicit_store": true,
|
| 296 |
+
"max_per_session": 1000,
|
| 297 |
+
"pool_fixed": true,
|
| 298 |
+
"forgetting": "random_drop_k_append_k",
|
| 299 |
+
"persistent": true,
|
| 300 |
+
"pruning": "low_retrieval_weight_eviction",
|
| 301 |
+
"§": ["r15", "r16"]
|
| 302 |
+
}
|
| 303 |
+
},
|
| 304 |
+
|
| 305 |
+
"tier2": {
|
| 306 |
+
"meta_guidelines": {
|
| 307 |
+
"enabled": true,
|
| 308 |
+
"max": 256,
|
| 309 |
+
"format": "8192bit_xor",
|
| 310 |
+
"trigger": "contrastive_eval_negative",
|
| 311 |
+
"§": "r15"
|
| 312 |
+
},
|
| 313 |
+
"episodic_cases": {
|
| 314 |
+
"enabled": true,
|
| 315 |
+
"retrieval": "soft_q_learning",
|
| 316 |
+
"max_cases": 4096,
|
| 317 |
+
"case_bytes": 2048,
|
| 318 |
+
"weight_update": "outcome_based_ema",
|
| 319 |
+
"§": "r17"
|
| 320 |
+
},
|
| 321 |
+
"self_feedback": {
|
| 322 |
+
"enabled": true,
|
| 323 |
+
"confidence_threshold": 0.6,
|
| 324 |
+
"max_refinement_rounds": 1,
|
| 325 |
+
"§": "r18"
|
| 326 |
+
}
|
| 327 |
+
},
|
| 328 |
+
|
| 329 |
+
"tier3": {
|
| 330 |
+
"span_bank_expansion": {
|
| 331 |
+
"enabled": true,
|
| 332 |
+
"min_span_len": 4,
|
| 333 |
+
"max_new_per_session": 256,
|
| 334 |
+
"acceptance": "cert_valid AND no_correction AND used_3plus",
|
| 335 |
+
"persistent": true,
|
| 336 |
+
"compression": "merge_similar_periodic"
|
| 337 |
+
},
|
| 338 |
+
"loop_depth_learning": {
|
| 339 |
+
"enabled": true,
|
| 340 |
+
"classifier": "int8_2layer_mlp",
|
| 341 |
+
"classifier_params": 500000,
|
| 342 |
+
"signal": "parcae_convergence_speed",
|
| 343 |
+
"persistent": true
|
| 344 |
+
}
|
| 345 |
+
},
|
| 346 |
+
|
| 347 |
+
"safety": {
|
| 348 |
+
"max_growth_mb": {"memory": 512, "span_bank": 128, "episodic": 8, "guidelines": 2},
|
| 349 |
+
"rollback_on_degradation": true,
|
| 350 |
+
"monitor": "certificate_failure_rate_and_rollback_rate",
|
| 351 |
+
"freeze_threshold": 0.05,
|
| 352 |
+
"user_reset": true,
|
| 353 |
+
"state_file": "chimera51_evolution.state"
|
| 354 |
+
}
|
| 355 |
+
},
|
| 356 |
+
|
| 357 |
+
"braid_state": {
|
| 358 |
+
"continuous_hidden": [2560, "float32"],
|
| 359 |
+
"fast_hidden": [2560, "int8"],
|
| 360 |
+
"semantic_sketch": [8192, "uint64_x128"],
|
| 361 |
+
"entity_table": {"slots": 256, "slot_bits": 512, "binding": "xor_role_filler"},
|
| 362 |
+
"grammar_stack": {"slots": 64, "width_bits": 128},
|
| 363 |
+
"debt_ledger_slots": 64,
|
| 364 |
+
"per_stream_mb": 30,
|
| 365 |
+
"kv_growth_per_token": 0
|
| 366 |
+
},
|
| 367 |
+
|
| 368 |
+
"modes": {
|
| 369 |
+
"fast": {"tps": 200, "neural_hz": 40, "span_avg": 5, "loops": 1, "audit": 0.125},
|
| 370 |
+
"balanced": {"tps": 120, "neural_hz": 30, "span_avg": 4, "loops": 2, "audit": 0.5},
|
| 371 |
+
"reasoning": {"tps": 40, "neural_hz": 20, "span_avg": 2, "loops": 4, "audit": 1.0}
|
| 372 |
+
},
|
| 373 |
+
|
| 374 |
+
"generation": {
|
| 375 |
+
"temperature": 0.7,
|
| 376 |
+
"top_p": 0.92,
|
| 377 |
+
"repetition_penalty": 1.08,
|
| 378 |
+
"max_new_tokens": 4096,
|
| 379 |
+
"do_sample": true,
|
| 380 |
+
"stream": true
|
| 381 |
+
},
|
| 382 |
+
|
| 383 |
+
"training": {
|
| 384 |
+
"phases": [
|
| 385 |
+
{
|
| 386 |
+
"name": "pretrain",
|
| 387 |
+
"tokens": "2T",
|
| 388 |
+
"data": ["FineWeb-Edu", "SlimPajama", "StarCoder-data", "multilingual-CC"],
|
| 389 |
+
"seq_len": 4096,
|
| 390 |
+
"batch_tokens": "4M",
|
| 391 |
+
"optimizer": "AdamW",
|
| 392 |
+
"lr": 3e-4,
|
| 393 |
+
"schedule": "cosine_warmup",
|
| 394 |
+
"warmup_steps": 2000,
|
| 395 |
+
"weight_decay": 0.1,
|
| 396 |
+
"grad_clip": 1.0,
|
| 397 |
+
"ternary": "native_qat_ste",
|
| 398 |
+
"§": ["r5", "r6"]
|
| 399 |
+
},
|
| 400 |
+
{
|
| 401 |
+
"name": "ctx_extend",
|
| 402 |
+
"stages": [
|
| 403 |
+
[4096, "main"],
|
| 404 |
+
[16384, 10000, 1e-5],
|
| 405 |
+
[65536, 5000, 5e-6],
|
| 406 |
+
[262144, 2000, 2e-6]
|
| 407 |
+
]
|
| 408 |
+
},
|
| 409 |
+
{
|
| 410 |
+
"name": "sft",
|
| 411 |
+
"data": ["UltraChat-200k", "ShareGPT-cleaned"],
|
| 412 |
+
"epochs": 3,
|
| 413 |
+
"lr": 2e-5
|
| 414 |
+
},
|
| 415 |
+
{
|
| 416 |
+
"name": "dpo",
|
| 417 |
+
"data": "UltraFeedback-binarized",
|
| 418 |
+
"epochs": 1,
|
| 419 |
+
"lr": 5e-7,
|
| 420 |
+
"beta": 0.1
|
| 421 |
+
}
|
| 422 |
+
],
|
| 423 |
+
"distillation_init": {
|
| 424 |
+
"enabled": false,
|
| 425 |
+
"method": "ARWKV_style",
|
| 426 |
+
"teacher": "Qwen-2.5-7B",
|
| 427 |
+
"tokens": "1B",
|
| 428 |
+
"§": "r24"
|
| 429 |
+
}
|
| 430 |
+
},
|
| 431 |
+
|
| 432 |
+
"byte_level": {
|
| 433 |
+
"enabled": false,
|
| 434 |
+
"encoder_params": "50M",
|
| 435 |
+
"encoder_depth": 8,
|
| 436 |
+
"patching": "entropy_threshold",
|
| 437 |
+
"decoder_params": "50M",
|
| 438 |
+
"§": "r23"
|
| 439 |
+
},
|
| 440 |
+
|
| 441 |
+
"memory_budget_mb": {
|
| 442 |
+
"_keys": ["ternary_weights", "moe_experts", "span_bank", "grammar", "semantic_mem", "episodic", "guidelines", "braid", "activations", "render_queue", "evolution", "runtime_os"],
|
| 443 |
+
"_vals": [410, 66, 384, 64, 320, 8, 2, 30, 80, 32, 128, 1000],
|
| 444 |
+
"total": 2524,
|
| 445 |
+
"headroom_8gb": 4876,
|
| 446 |
+
"growth_ceiling": 650,
|
| 447 |
+
"max_with_growth": 3174
|
| 448 |
+
},
|
| 449 |
+
|
| 450 |
+
"deployment": {
|
| 451 |
+
"batch_size": 1,
|
| 452 |
+
"max_streams": 16,
|
| 453 |
+
"per_stream_mb": 30,
|
| 454 |
+
"shared": ["weights", "span_bank", "grammar"],
|
| 455 |
+
"mmap": ["weights", "span_bank"],
|
| 456 |
+
"cold_start_s": 2.5,
|
| 457 |
+
"watchdog_tick_ms": 20,
|
| 458 |
+
"watchdog_max_overruns": 8,
|
| 459 |
+
"deterministic": true,
|
| 460 |
+
"seed_controls_all": true,
|
| 461 |
+
"platforms": ["x86_64_avx2", "aarch64_neon", "wasm_simd128", "apple_silicon_amx"]
|
| 462 |
+
},
|
| 463 |
+
|
| 464 |
+
"diagnostics": {
|
| 465 |
+
"telemetry": true,
|
| 466 |
+
"report_interval_tokens": 256,
|
| 467 |
+
"metrics": [
|
| 468 |
+
"surface_tps", "neural_knot_tps", "mean_span_length",
|
| 469 |
+
"span_acceptance_rate", "certificate_failure_rate",
|
| 470 |
+
"rollback_count", "queue_depth", "loop_count_mean",
|
| 471 |
+
"memory_mb", "evolution_events", "grammar_violations_prevented",
|
| 472 |
+
"contrastive_eval_ratio", "self_refinement_trigger_rate",
|
| 473 |
+
"episodic_case_hit_rate", "moe_expert_load_balance",
|
| 474 |
+
"gd_alpha_mean", "gd_beta_mean", "ttt_loss_delta"
|
| 475 |
+
],
|
| 476 |
+
"thresholds": {
|
| 477 |
+
"min_span_accept": 0.70,
|
| 478 |
+
"max_cert_fail": 0.05,
|
| 479 |
+
"max_rollback": 0.02,
|
| 480 |
+
"min_contrastive_benefit": 0.0,
|
| 481 |
+
"max_moe_imbalance": 0.15
|
| 482 |
+
}
|
| 483 |
+
},
|
| 484 |
+
|
| 485 |
+
"context_tiers": [
|
| 486 |
+
{"name": "recent_ring", "tokens": 4096, "mb": 16},
|
| 487 |
+
{"name": "braid_state", "mb": 30},
|
| 488 |
+
{"name": "semantic_memory", "mb": 320},
|
| 489 |
+
{"name": "ttt_compressed", "mb": 24},
|
| 490 |
+
{"name": "span_trace", "entries": 32768, "mb": 32},
|
| 491 |
+
{"name": "episodic_cases", "entries": 4096, "mb": 8}
|
| 492 |
+
],
|
| 493 |
+
|
| 494 |
+
"multimodal": {
|
| 495 |
+
"enabled": true,
|
| 496 |
+
"modalities": ["text", "image", "audio"],
|
| 497 |
+
"vision": {"type": "gated_deltanet_tiny", "depth": 12, "hidden": 384, "patch": 16, "out": 2560, "quant": "ternary"},
|
| 498 |
+
"audio": {"type": "gated_deltanet_audio_tiny", "depth": 6, "hidden": 256, "out": 2560, "quant": "ternary"}
|
| 499 |
+
},
|
| 500 |
+
|
| 501 |
+
"safety": {
|
| 502 |
+
"format_guards": ["json_strict", "code_fence_closure", "markdown_table_guard"],
|
| 503 |
+
"memory_limit_enforced": true,
|
| 504 |
+
"crash_only_allocator": true,
|
| 505 |
+
"user_facts_override_weak_memory": true,
|
| 506 |
+
"state_uncertainty_when_unsure": true
|
| 507 |
+
},
|
| 508 |
+
|
| 509 |
+
"files": {
|
| 510 |
+
"weights": "chimera51.b158",
|
| 511 |
+
"moe": "chimera51_experts.b158",
|
| 512 |
+
"spans": "chimera51_spans.sfpack",
|
| 513 |
+
"grammar": "chimera51_grammar.fstpack",
|
| 514 |
+
"memory_seed": "chimera51_memory.seedpack",
|
| 515 |
+
"tokenizer": "chimera51_tokenizer.model",
|
| 516 |
+
"evolution": "chimera51_evolution.state"
|
| 517 |
+
},
|
| 518 |
+
|
| 519 |
+
"params": {
|
| 520 |
+
"base": "2.3B",
|
| 521 |
+
"moe_total": "350M",
|
| 522 |
+
"physical": "2.65B",
|
| 523 |
+
"effective_2loops": "4.2B",
|
| 524 |
+
"effective_6loops": "9.5B",
|
| 525 |
+
"active_per_token": "2.39B",
|
| 526 |
+
"weight_mb": 476,
|
| 527 |
+
"total_mb": 2524
|
| 528 |
+
},
|
| 529 |
+
|
| 530 |
+
"P3_ternary_compute": {
|
| 531 |
+
"_note": "v5.1.2 — Honest section. Documents ONLY what is implemented and measured. Previous v5.1.0 claims of '1080× speedup' were aspirational and not implementable.",
|
| 532 |
+
|
| 533 |
+
"thesis": "Ternary weights {-1,0,1} enable 16× memory reduction via 2-bit packed storage. On CPU, training speed is dominated by MKL BLAS — raw ternary matmul is not faster than FP32 at small-to-medium sizes. The real wins are: (1) 16× less RAM enabling larger models on limited hardware, (2) 16× less memory bandwidth for large models where DRAM is the bottleneck, (3) MeZO eliminates the backward pass entirely (2× forward only). Inference post-training uses LUT-based kernels (T-MAC, bitnet.cpp) for true speedup.",
|
| 534 |
+
|
| 535 |
+
"implemented_optimizations": {
|
| 536 |
+
"mezo_optimizer": {
|
| 537 |
+
"status": "IMPLEMENTED",
|
| 538 |
+
"description": "Memory-Efficient Zeroth-Order optimizer — eliminates backward pass entirely. 2 forward passes per step.",
|
| 539 |
+
"benefit": "Memory = 2× model size (no activations, no gradients, no optimizer states). Ideal for CPU with complex recurrences.",
|
| 540 |
+
"limitation": "Requires ~32× more steps to converge than AdamW. Best for fine-tuning, not pretraining from scratch.",
|
| 541 |
+
"§": "r29"
|
| 542 |
+
},
|
| 543 |
+
"bf16_autocast": {
|
| 544 |
+
"status": "IMPLEMENTED",
|
| 545 |
+
"description": "BFloat16 automatic mixed precision on CPU via torch.autocast('cpu', dtype=torch.bfloat16).",
|
| 546 |
+
"benefit": "2-4× faster matmuls on Intel Sapphire Rapids+ (AMX) or Ice Lake+ (AVX-512-BF16). Falls back to FP32 emulation on older CPUs.",
|
| 547 |
+
"limitation": "Forward-pass only. Gradients remain FP32."
|
| 548 |
+
},
|
| 549 |
+
"torch_compile": {
|
| 550 |
+
"status": "IMPLEMENTED",
|
| 551 |
+
"description": "torch.compile with Inductor backend for CPU. Fuses ops, reduces Python overhead.",
|
| 552 |
+
"benefit": "1.3-2× overall training throughput.",
|
| 553 |
+
"limitation": "First iteration is slow (compilation). Dynamic shapes supported."
|
| 554 |
+
},
|
| 555 |
+
"parallel_mlstm": {
|
| 556 |
+
"status": "IMPLEMENTED",
|
| 557 |
+
"description": "Replaced O(T) Python loop with parallel log-space cumulative gate computation + batched QKV attention.",
|
| 558 |
+
"benefit": "~10-50× faster for mLSTM layers on CPU (seq_len ≥ 64).",
|
| 559 |
+
"§": "r1"
|
| 560 |
+
},
|
| 561 |
+
"parallel_titans_mac": {
|
| 562 |
+
"status": "IMPLEMENTED",
|
| 563 |
+
"description": "Replaced O(T) Python loop with causal decay attention + vectorized contribution computation.",
|
| 564 |
+
"benefit": "~5-20× faster for Titans MAC layers on CPU.",
|
| 565 |
+
"§": "r2"
|
| 566 |
+
},
|
| 567 |
+
"sort_based_moe": {
|
| 568 |
+
"status": "IMPLEMENTED",
|
| 569 |
+
"description": "Sort tokens by expert ID → process contiguous blocks → scatter_add back. Cache-friendly CPU dispatch.",
|
| 570 |
+
"benefit": "Better cache locality than random-access per-expert dispatch.",
|
| 571 |
+
"§": "r21"
|
| 572 |
+
},
|
| 573 |
+
"gradient_checkpointing": {
|
| 574 |
+
"status": "IMPLEMENTED",
|
| 575 |
+
"description": "Per-block activation checkpointing for AdamW mode.",
|
| 576 |
+
"benefit": "30-60% memory reduction, enabling larger batches."
|
| 577 |
+
},
|
| 578 |
+
"cpu_thread_tuning": {
|
| 579 |
+
"status": "IMPLEMENTED",
|
| 580 |
+
"description": "OMP_NUM_THREADS, KMP_AFFINITY=compact, KMP_BLOCKTIME=1, torch.set_num_threads/interop_threads.",
|
| 581 |
+
"benefit": "10-30% throughput improvement from optimal thread placement."
|
| 582 |
+
},
|
| 583 |
+
"ipex_integration": {
|
| 584 |
+
"status": "IMPLEMENTED (optional)",
|
| 585 |
+
"description": "Auto-detected Intel Extension for PyTorch. ipex.optimize() with BF16 + AMX kernel selection.",
|
| 586 |
+
"benefit": "Additional 30-50% on Intel CPUs."
|
| 587 |
+
},
|
| 588 |
+
"ternary_qat_ste": {
|
| 589 |
+
"status": "IMPLEMENTED",
|
| 590 |
+
"description": "BitNet 1.58 quantization-aware training with STE. Per-group AbsMean weight quantization, per-block AbsMax int8 activations.",
|
| 591 |
+
"benefit": "Model learns ternary weight distribution. Enables efficient inference with LUT-based kernels (bitnet.cpp, T-MAC) post-training.",
|
| 592 |
+
"limitation": "Training itself is NOT faster than FP16 — STE backward pass uses FP32 matmuls.",
|
| 593 |
+
"§": ["r5", "r7"]
|
| 594 |
+
},
|
| 595 |
+
"two_bit_packed_weights": {
|
| 596 |
+
"status": "IMPLEMENTED v5.1.2",
|
| 597 |
+
"description": "Ternary weights packed as 2-bit uint8 (4 weights per byte). Custom C++ kernel with OpenMP for unpack.",
|
| 598 |
+
"benefit": "16× less storage vs FP32 (e.g. 2.5B model: 10GB → 0.6GB). 94% less memory bandwidth for weight loading.",
|
| 599 |
+
"limitation": "Unpack overhead makes single-layer forward ~0.5-0.7× FP32 at small sizes. Win is at large model sizes where DRAM bandwidth dominates.",
|
| 600 |
+
"implementation": "pack_ternary_fast() + unpack_into() in C++ with OpenMP. Pre-allocated float buffer reused across steps."
|
| 601 |
+
},
|
| 602 |
+
"zero_multiply_forward": {
|
| 603 |
+
"status": "IMPLEMENTED v5.1.2",
|
| 604 |
+
"description": "Forward and backward grad_x use ternary unpack + MKL BLAS. The matmul sees only add/sub operations conceptually, but executed via BLAS for performance.",
|
| 605 |
+
"benefit": "No FP32 multiply on ternary weights (unpack produces {-α,0,+α}). Grad_x path also zero-multiply.",
|
| 606 |
+
"limitation": "BLAS still executes multiply-add; the zero-multiply is at the algorithmic level, not instruction-level.",
|
| 607 |
+
"note": "True instruction-level zero-multiply requires custom assembly (VPSHUFB LUT) — not implemented due to backward incompatibility with STE."
|
| 608 |
+
},
|
| 609 |
+
"ternary_mezo_sparse": {
|
| 610 |
+
"status": "IMPLEMENTED v5.1.2",
|
| 611 |
+
"description": "MeZO perturbation and update skip zero-weight positions (~33% of ternary weights). C++ kernel with per-thread deterministic LCG.",
|
| 612 |
+
"benefit": "33% fewer perturbation operations per step. Skips ~1/3 of random number generation and memory writes.",
|
| 613 |
+
"limitation": "Only applies to BitLinear layers. Other params (norms, biases, embeddings) still fully perturbed."
|
| 614 |
+
},
|
| 615 |
+
"sparse_grad_w_masking": {
|
| 616 |
+
"status": "IMPLEMENTED v5.1.2",
|
| 617 |
+
"description": "STE backward grad_w masks 'deep zero' weights (|w_scaled| < 0.3) to zero.",
|
| 618 |
+
"benefit": "Saves ~10-15% of grad_w computation (fewer elements in outer product).",
|
| 619 |
+
"limitation": "Small gain; FP32 matmul still dominates backward time."
|
| 620 |
+
}
|
| 621 |
+
},
|
| 622 |
+
|
| 623 |
+
"not_implemented": {
|
| 624 |
+
"elut_training": "ELUT/T-MAC kernels apply to INFERENCE only. LUT precomputation is invalidated by weight updates during training.",
|
| 625 |
+
"mixture_of_depths": "MoD requires specific router architecture. Not implemented in current backbone.",
|
| 626 |
+
"sparse_backprop": "SparseProp requires ≥90% weight sparsity. Incompatible with QAT from random init (~33% zeros)."
|
| 627 |
+
},
|
| 628 |
+
|
| 629 |
+
"realistic_performance": {
|
| 630 |
+
"cpu_training_tiny_35M": {"hardware": "i7-14700T", "throughput": "~50-200 tok/s", "note": "With MeZO+BF16+compile"},
|
| 631 |
+
"cpu_training_small_150M": {"hardware": "i7-14700T", "throughput": "~10-50 tok/s", "note": "With MeZO+BF16+compile"},
|
| 632 |
+
"cpu_inference_ternary": {"note": "Post-training with bitnet.cpp/T-MAC: 30-127 tok/s for 700M-3B models"},
|
| 633 |
+
"gpu_training_comparison": "GPU (A100) is 50-150× faster than CPU for training equivalent model sizes. CPU training is best for fine-tuning (MeZO), not pretraining."
|
| 634 |
+
},
|
| 635 |
+
|
| 636 |
+
"§_paradigm": ["r26", "r27", "r28", "r29", "r30", "r31", "r32", "r33", "r5", "r34", "r7", "r19"]
|
| 637 |
+
}
|
| 638 |
+
}
|
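The mezo_optimizer entry in P3_ternary_compute above trains with two forward passes and no backward pass. Below is a minimal, illustrative sketch of such a zeroth-order step; model, loss_fn, batch and every hyperparameter are placeholders for this note, not APIs from this repository.

    import torch

    def mezo_step(model, loss_fn, batch, lr=1e-6, eps=1e-3, seed=0):
        # Perturb all trainable parameters by +eps*z and then by -2*eps*z, measuring
        # the loss each time; z is regenerated from the seed instead of being stored.
        params = [p for p in model.parameters() if p.requires_grad]

        def perturb(scale):
            gen = torch.Generator().manual_seed(seed)
            for p in params:
                z = torch.randn(p.shape, generator=gen, dtype=p.dtype)
                p.data.add_(scale * eps * z)

        with torch.no_grad():
            perturb(+1.0)
            loss_plus = float(loss_fn(model, batch))
            perturb(-2.0)
            loss_minus = float(loss_fn(model, batch))
            perturb(+1.0)  # restore the original weights
            g = (loss_plus - loss_minus) / (2.0 * eps)  # scalar gradient estimate
            gen = torch.Generator().manual_seed(seed)
            for p in params:
                z = torch.randn(p.shape, generator=gen, dtype=p.dtype)
                p.data.add_(-lr * g * z)
        return loss_plus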
gguf_import.py
ADDED
|
@@ -0,0 +1,905 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Chimera GGUF Import Optimized
|
| 5 |
+
═════════════════════════════
|
| 6 |
+
|
| 7 |
+
Convert GGUF tensors into a Chimera-compatible checkpoint.
|
| 8 |
+
|
| 9 |
+
Improvements over the original version:
|
| 10 |
+
- Does not keep every GGUF tensor in memory as FP32.
|
| 11 |
+
- Fixes the bug where embeddings/lm_head were treated as BitLinear.
|
| 12 |
+
- Offline ternary quantization without autograd.
|
| 13 |
+
- Per-row outlier clipping for weight matrices.
|
| 14 |
+
- Auto-transpose when the shape is reversed.
|
| 15 |
+
- Storage modes:
|
| 16 |
+
fp32 : classic Chimera-compatible, saves the latent weight.
|
| 17 |
+
packed : saves packed_weight + alpha only, for linear layers.
|
| 18 |
+
both : saves weight + packed_weight + alpha.
|
| 19 |
+
- Initializes missing weights to produce a complete checkpoint.
|
| 20 |
+
- Configurable resize: strict, crop_pad, interpolate.
|
| 21 |
+
- More robust GGUF mapping for LLaMA/Qwen/Mistral-like models.
|
| 22 |
+
|
| 23 |
+
Usage:
|
| 24 |
+
python gguf_import_optimized.py \
|
| 25 |
+
--gguf model.gguf \
|
| 26 |
+
--config config.json \
|
| 27 |
+
--scale tiny \
|
| 28 |
+
--output imported_chimera.pt \
|
| 29 |
+
--storage fp32
|
| 30 |
+
|
| 31 |
+
For an experimental compact checkpoint:
|
| 32 |
+
python gguf_import_optimized.py \
|
| 33 |
+
--gguf model.gguf \
|
| 34 |
+
--config config.json \
|
| 35 |
+
--output imported_chimera_packed.pt \
|
| 36 |
+
--storage packed
|
| 37 |
+
|
| 38 |
+
Caveats:
|
| 39 |
+
- storage=packed requires your Chimera loader to know how to read
|
| 40 |
+
*.packed_weight and *.alpha.
|
| 41 |
+
- Importing a large model into tiny/small via resize destroys a lot
|
| 42 |
+
of information. It is useful for bootstrapping, not equivalent to distillation.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import os
|
| 46 |
+
import re
|
| 47 |
+
import gc
|
| 48 |
+
import json
|
| 49 |
+
import math
|
| 50 |
+
import argparse
|
| 51 |
+
from copy import deepcopy
|
| 52 |
+
from pathlib import Path
|
| 53 |
+
from typing import Dict, Tuple, Optional, Iterable, Any
|
| 54 |
+
|
| 55 |
+
import numpy as np
|
| 56 |
+
import torch
|
| 57 |
+
import torch.nn.functional as F
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
try:
|
| 61 |
+
from gguf import GGUFReader, dequantize
|
| 62 |
+
HAS_GGUF = True
|
| 63 |
+
except Exception:
|
| 64 |
+
GGUFReader = None
|
| 65 |
+
dequantize = None
|
| 66 |
+
HAS_GGUF = False
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# ═══════════════════════════════════════════════════════════
|
| 70 |
+
# Config scales
|
| 71 |
+
# ═══════════════════════════════════════════════════════════
|
| 72 |
+
|
| 73 |
+
SCALE_OVERRIDES = {
|
| 74 |
+
"tiny": {
|
| 75 |
+
"hidden_size": 256,
|
| 76 |
+
"intermediate_size": 512,
|
| 77 |
+
"num_hidden_layers": 28,
|
| 78 |
+
"num_heads": 4,
|
| 79 |
+
"head_dim": 48,
|
| 80 |
+
},
|
| 81 |
+
"small": {
|
| 82 |
+
"hidden_size": 512,
|
| 83 |
+
"intermediate_size": 1024,
|
| 84 |
+
"num_hidden_layers": 28,
|
| 85 |
+
"num_heads": 8,
|
| 86 |
+
"head_dim": 48,
|
| 87 |
+
},
|
| 88 |
+
"medium": {
|
| 89 |
+
"hidden_size": 1024,
|
| 90 |
+
"intermediate_size": 2048,
|
| 91 |
+
"num_hidden_layers": 28,
|
| 92 |
+
"num_heads": 8,
|
| 93 |
+
"head_dim": 96,
|
| 94 |
+
},
|
| 95 |
+
# full = keep the config as-is
|
| 96 |
+
"full": {},
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ═══════════════════════════════════════════════════════════
|
| 101 |
+
# Mapping GGUF -> Chimera
|
| 102 |
+
# ═══════════════════════════════════════════════════════════
|
| 103 |
+
|
| 104 |
+
DIRECT_NAME_MAP = {
|
| 105 |
+
"token_embd": "embed.weight",
|
| 106 |
+
"token_embd.weight": "embed.weight",
|
| 107 |
+
|
| 108 |
+
"output": "lm_head.weight",
|
| 109 |
+
"output.weight": "lm_head.weight",
|
| 110 |
+
|
| 111 |
+
"output_norm": "norm.weight",
|
| 112 |
+
"output_norm.weight": "norm.weight",
|
| 113 |
+
|
| 114 |
+
# Variants occasionally encountered
|
| 115 |
+
"norm": "norm.weight",
|
| 116 |
+
"norm.weight": "norm.weight",
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
BLOCK_SUFFIX_MAP = {
|
| 121 |
+
# Attention norm
|
| 122 |
+
"attn_norm": "attn_norm.weight",
|
| 123 |
+
"attn_norm.weight": "attn_norm.weight",
|
| 124 |
+
|
| 125 |
+
# FFN norm
|
| 126 |
+
"ffn_norm": "mlp_norm.weight",
|
| 127 |
+
"ffn_norm.weight": "mlp_norm.weight",
|
| 128 |
+
|
| 129 |
+
# Attention projections
|
| 130 |
+
"attn_q": "attn.q_proj.weight",
|
| 131 |
+
"attn_q.weight": "attn.q_proj.weight",
|
| 132 |
+
"attn_k": "attn.k_proj.weight",
|
| 133 |
+
"attn_k.weight": "attn.k_proj.weight",
|
| 134 |
+
"attn_v": "attn.v_proj.weight",
|
| 135 |
+
"attn_v.weight": "attn.v_proj.weight",
|
| 136 |
+
"attn_output": "attn.o_proj.weight",
|
| 137 |
+
"attn_output.weight": "attn.o_proj.weight",
|
| 138 |
+
|
| 139 |
+
# MLP / SwiGLU
|
| 140 |
+
"ffn_gate": "mlp.gate_proj.weight",
|
| 141 |
+
"ffn_gate.weight": "mlp.gate_proj.weight",
|
| 142 |
+
"ffn_up": "mlp.up_proj.weight",
|
| 143 |
+
"ffn_up.weight": "mlp.up_proj.weight",
|
| 144 |
+
"ffn_down": "mlp.down_proj.weight",
|
| 145 |
+
"ffn_down.weight": "mlp.down_proj.weight",
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def map_gguf_name(name: str, n_layers: int) -> Optional[str]:
|
| 150 |
+
"""
|
| 151 |
+
Converts a GGUF tensor name to a Chimera key.
|
| 152 |
+
Returns None if the name cannot be mapped.
|
| 153 |
+
"""
|
| 154 |
+
if name in DIRECT_NAME_MAP:
|
| 155 |
+
return DIRECT_NAME_MAP[name]
|
| 156 |
+
|
| 157 |
+
m = re.match(r"^blk\.(\d+)\.(.+)$", name)
|
| 158 |
+
if not m:
|
| 159 |
+
return None
|
| 160 |
+
|
| 161 |
+
bid = int(m.group(1))
|
| 162 |
+
suffix = m.group(2)
|
| 163 |
+
|
| 164 |
+
if bid >= n_layers:
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
mapped_suffix = BLOCK_SUFFIX_MAP.get(suffix)
|
| 168 |
+
if mapped_suffix is None:
|
| 169 |
+
return None
|
| 170 |
+
|
| 171 |
+
return f"layers.{bid}.{mapped_suffix}"
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
# ═══════════════════════════════════════════════════════════
|
| 175 |
+
# Ternary quantization + packing
|
| 176 |
+
# ═══════════════════════════════════════════════════════════
|
| 177 |
+
|
| 178 |
+
@torch.no_grad()
|
| 179 |
+
def ternary_quantize_absmean(
|
| 180 |
+
w: torch.Tensor,
|
| 181 |
+
threshold: float = 0.5,
|
| 182 |
+
eps: float = 1e-5,
|
| 183 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 184 |
+
"""
|
| 185 |
+
Converts FP32 w [M,K] -> int8 w_q {-1,0,1} + alpha [M].
|
| 186 |
+
|
| 187 |
+
alpha = mean(abs(w), dim=1)
|
| 188 |
+
w_norm = w / alpha
|
| 189 |
+
q = -1 if w_norm <= -threshold
|
| 190 |
+
0 otherwise (in between)
|
| 191 |
+
+1 if w_norm >= threshold
|
| 192 |
+
"""
|
| 193 |
+
if w.ndim != 2:
|
| 194 |
+
raise ValueError("ternary_quantize_absmean expects a 2D tensor")
|
| 195 |
+
|
| 196 |
+
w = w.to(torch.float32)
|
| 197 |
+
alpha = w.abs().mean(dim=1).clamp_min(eps)
|
| 198 |
+
|
| 199 |
+
wn = w / alpha[:, None]
|
| 200 |
+
q = torch.zeros_like(wn, dtype=torch.int8)
|
| 201 |
+
q[wn >= threshold] = 1
|
| 202 |
+
q[wn <= -threshold] = -1
|
| 203 |
+
|
| 204 |
+
return q, alpha.to(torch.float32)
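# Illustrative usage (not part of this file): the (q, alpha) pair reconstructs an
# approximation of the original matrix as w_hat = alpha[:, None] * q.float().
#   w = torch.randn(4, 8)
#   q, alpha = ternary_quantize_absmean(w)
#   w_hat = alpha[:, None] * q.float()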
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
@torch.no_grad()
|
| 208 |
+
def pack_ternary_2bit(w_q: torch.Tensor) -> torch.Tensor:
|
| 209 |
+
"""
|
| 210 |
+
Pack int8 {-1,0,+1} -> uint8, 4 weights per byte.
|
| 211 |
+
|
| 212 |
+
Encoding:
|
| 213 |
+
0 -> 00
|
| 214 |
+
+1 -> 01
|
| 215 |
+
-1 -> 10
|
| 216 |
+
|
| 217 |
+
Bit order:
|
| 218 |
+
weight0 bits 7..6
|
| 219 |
+
weight1 bits 5..4
|
| 220 |
+
weight2 bits 3..2
|
| 221 |
+
weight3 bits 1..0
|
| 222 |
+
"""
|
| 223 |
+
if w_q.ndim != 2:
|
| 224 |
+
raise ValueError("pack_ternary_2bit expects a 2D tensor")
|
| 225 |
+
|
| 226 |
+
M, K = w_q.shape
|
| 227 |
+
K4 = (K + 3) // 4
|
| 228 |
+
pad = K4 * 4 - K
|
| 229 |
+
|
| 230 |
+
codes = torch.zeros_like(w_q, dtype=torch.uint8)
|
| 231 |
+
codes[w_q == 1] = 1
|
| 232 |
+
codes[w_q == -1] = 2
|
| 233 |
+
|
| 234 |
+
if pad:
|
| 235 |
+
codes = F.pad(codes, (0, pad), value=0)
|
| 236 |
+
|
| 237 |
+
codes = codes.view(M, K4, 4)
|
| 238 |
+
packed = (
|
| 239 |
+
(codes[..., 0] << 6)
|
| 240 |
+
| (codes[..., 1] << 4)
|
| 241 |
+
| (codes[..., 2] << 2)
|
| 242 |
+
| codes[..., 3]
|
| 243 |
+
)
|
| 244 |
+
return packed.contiguous()
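# Illustrative inverse (not part of this file): unpack the 2-bit codes back to
# int8 {-1, 0, +1}, following the bit layout documented in pack_ternary_2bit above.
@torch.no_grad()
def unpack_ternary_2bit(packed: torch.Tensor, K: int) -> torch.Tensor:
    M, K4 = packed.shape
    codes = torch.stack(
        ((packed >> 6) & 0x3, (packed >> 4) & 0x3, (packed >> 2) & 0x3, packed & 0x3),
        dim=-1,
    ).view(M, K4 * 4)
    w_q = torch.zeros_like(codes, dtype=torch.int8)
    w_q[codes == 1] = 1
    w_q[codes == 2] = -1
    return w_q[:, :K].contiguous()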
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
# ═══════════════════════════════════════════════════════════
|
| 248 |
+
# Noise reduction
|
| 249 |
+
# ═══════════════════════════════════════════════════════════
|
| 250 |
+
|
| 251 |
+
@torch.no_grad()
|
| 252 |
+
def reduce_noise(
|
| 253 |
+
w: torch.Tensor,
|
| 254 |
+
method: str = "row_outlier_clip",
|
| 255 |
+
sigma: float = 3.0,
|
| 256 |
+
eps: float = 1e-5,
|
| 257 |
+
) -> torch.Tensor:
|
| 258 |
+
"""
|
| 259 |
+
Preprocessing before ternarization.
|
| 260 |
+
|
| 261 |
+
none : do nothing.
|
| 262 |
+
global_clip : global clip to mean ± sigma*std.
|
| 263 |
+
row_outlier_clip : per-row clipping, better for linear weight matrices.
|
| 264 |
+
median_center : robust global recentering using median/MAD.
|
| 265 |
+
"""
|
| 266 |
+
if method == "none":
|
| 267 |
+
return w
|
| 268 |
+
|
| 269 |
+
w = w.to(torch.float32)
|
| 270 |
+
|
| 271 |
+
if method == "global_clip":
|
| 272 |
+
mu = w.mean()
|
| 273 |
+
std = w.std(unbiased=False).clamp_min(eps)
|
| 274 |
+
return w.clamp(mu - sigma * std, mu + sigma * std)
|
| 275 |
+
|
| 276 |
+
if method == "row_outlier_clip":
|
| 277 |
+
if w.ndim != 2:
|
| 278 |
+
return reduce_noise(w, method="global_clip", sigma=sigma, eps=eps)
|
| 279 |
+
|
| 280 |
+
mu = w.mean(dim=1, keepdim=True)
|
| 281 |
+
std = w.std(dim=1, keepdim=True, unbiased=False).clamp_min(eps)
|
| 282 |
+
return w.clamp(mu - sigma * std, mu + sigma * std)
|
| 283 |
+
|
| 284 |
+
if method == "median_center":
|
| 285 |
+
med = w.median()
|
| 286 |
+
mad = (w - med).abs().median().clamp_min(eps)
|
| 287 |
+
return (w - med) / mad
|
| 288 |
+
|
| 289 |
+
return w
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ═══════════════════════════════════════════════════════════
|
| 293 |
+
# Resize helpers
|
| 294 |
+
# ═══════════════════════════════════════════════════════════
|
| 295 |
+
|
| 296 |
+
@torch.no_grad()
|
| 297 |
+
def resize_1d(w: torch.Tensor, target: int) -> torch.Tensor:
|
| 298 |
+
src = w.numel()
|
| 299 |
+
if src == target:
|
| 300 |
+
return w.contiguous()
|
| 301 |
+
|
| 302 |
+
out = torch.ones(target, dtype=w.dtype)
|
| 303 |
+
n = min(src, target)
|
| 304 |
+
out[:n] = w[:n]
|
| 305 |
+
return out.contiguous()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@torch.no_grad()
|
| 309 |
+
def resize_2d_crop_pad(
|
| 310 |
+
w: torch.Tensor,
|
| 311 |
+
target_shape: Tuple[int, int],
|
| 312 |
+
fill_std: float = 0.02,
|
| 313 |
+
) -> torch.Tensor:
|
| 314 |
+
"""
|
| 315 |
+
Fast resize via crop/pad.
|
| 316 |
+
More predictable than interpolating Transformer weights.
|
| 317 |
+
"""
|
| 318 |
+
target_out, target_in = target_shape
|
| 319 |
+
src_out, src_in = w.shape
|
| 320 |
+
|
| 321 |
+
if (src_out, src_in) == (target_out, target_in):
|
| 322 |
+
return w.contiguous()
|
| 323 |
+
|
| 324 |
+
out = torch.empty((target_out, target_in), dtype=w.dtype)
|
| 325 |
+
|
| 326 |
+
# initialize the regions that will not be copied
|
| 327 |
+
std = float(w.std(unbiased=False).item()) if w.numel() > 1 else fill_std
|
| 328 |
+
std = max(min(std, 0.2), 1e-4)
|
| 329 |
+
out.normal_(mean=0.0, std=std)
|
| 330 |
+
|
| 331 |
+
ro = min(src_out, target_out)
|
| 332 |
+
ci = min(src_in, target_in)
|
| 333 |
+
out[:ro, :ci] = w[:ro, :ci]
|
| 334 |
+
|
| 335 |
+
return out.contiguous()
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
@torch.no_grad()
|
| 339 |
+
def resize_2d_interpolate(
|
| 340 |
+
w: torch.Tensor,
|
| 341 |
+
target_shape: Tuple[int, int],
|
| 342 |
+
) -> torch.Tensor:
|
| 343 |
+
target_out, target_in = target_shape
|
| 344 |
+
if tuple(w.shape) == tuple(target_shape):
|
| 345 |
+
return w.contiguous()
|
| 346 |
+
|
| 347 |
+
x = w[None, None, :, :]
|
| 348 |
+
y = F.interpolate(
|
| 349 |
+
x,
|
| 350 |
+
size=(target_out, target_in),
|
| 351 |
+
mode="bilinear",
|
| 352 |
+
align_corners=False,
|
| 353 |
+
)
|
| 354 |
+
return y[0, 0].contiguous()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
@torch.no_grad()
|
| 358 |
+
def resize_2d(
|
| 359 |
+
w: torch.Tensor,
|
| 360 |
+
target_shape: Tuple[int, int],
|
| 361 |
+
strategy: str = "crop_pad",
|
| 362 |
+
) -> torch.Tensor:
|
| 363 |
+
if tuple(w.shape) == tuple(target_shape):
|
| 364 |
+
return w.contiguous()
|
| 365 |
+
|
| 366 |
+
if strategy == "strict":
|
| 367 |
+
raise ValueError(f"Shape mismatch: got {tuple(w.shape)}, expected {target_shape}")
|
| 368 |
+
|
| 369 |
+
if strategy == "crop_pad":
|
| 370 |
+
return resize_2d_crop_pad(w, target_shape)
|
| 371 |
+
|
| 372 |
+
if strategy == "interpolate":
|
| 373 |
+
return resize_2d_interpolate(w, target_shape)
|
| 374 |
+
|
| 375 |
+
raise ValueError(f"unknown resize strategy: {strategy}")
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
# ═══════════════════════════════════════════════════════════
|
| 379 |
+
# Importer
|
| 380 |
+
# ═══════════════════════════════════════════════════════════
|
| 381 |
+
|
| 382 |
+
class OptimizedGGUFImporter:
|
| 383 |
+
def __init__(
|
| 384 |
+
self,
|
| 385 |
+
config: Dict[str, Any],
|
| 386 |
+
scale: str = "tiny",
|
| 387 |
+
storage: str = "fp32",
|
| 388 |
+
param_dtype: str = "fp32",
|
| 389 |
+
noise_method: str = "row_outlier_clip",
|
| 390 |
+
noise_sigma: float = 3.0,
|
| 391 |
+
ternary_threshold: float = 0.5,
|
| 392 |
+
resize_strategy: str = "crop_pad",
|
| 393 |
+
auto_transpose: bool = True,
|
| 394 |
+
init_missing: bool = True,
|
| 395 |
+
verbose: bool = True,
|
| 396 |
+
):
|
| 397 |
+
self.config = deepcopy(config)
|
| 398 |
+
self.scale = scale
|
| 399 |
+
self.storage = storage
|
| 400 |
+
self.param_dtype = param_dtype
|
| 401 |
+
self.noise_method = noise_method
|
| 402 |
+
self.noise_sigma = noise_sigma
|
| 403 |
+
self.ternary_threshold = ternary_threshold
|
| 404 |
+
self.resize_strategy = resize_strategy
|
| 405 |
+
self.auto_transpose = auto_transpose
|
| 406 |
+
self.init_missing = init_missing
|
| 407 |
+
self.verbose = verbose
|
| 408 |
+
|
| 409 |
+
if scale not in SCALE_OVERRIDES:
|
| 410 |
+
raise ValueError(f"invalid scale: {scale}")
|
| 411 |
+
|
| 412 |
+
self.config.update(SCALE_OVERRIDES[scale])
|
| 413 |
+
|
| 414 |
+
self.n_layers = int(self.config["num_hidden_layers"])
|
| 415 |
+
self.hidden_size = int(self.config["hidden_size"])
|
| 416 |
+
self.vocab_size = int(self.config["vocab_size"])
|
| 417 |
+
self.num_heads = int(self.config.get("num_heads", 4))
|
| 418 |
+
self.head_dim = int(self.config.get("head_dim", self.hidden_size // self.num_heads))
|
| 419 |
+
|
| 420 |
+
inter = int(self.config["intermediate_size"])
|
| 421 |
+
self.intermediate_size = 256 * ((inter + 255) // 256)
|
| 422 |
+
self.config["intermediate_size"] = self.intermediate_size
|
| 423 |
+
|
| 424 |
+
if storage not in {"fp32", "packed", "both"}:
|
| 425 |
+
raise ValueError("storage must be one of: fp32, packed, both")
|
| 426 |
+
|
| 427 |
+
if param_dtype not in {"fp32", "fp16", "bf16"}:
|
| 428 |
+
raise ValueError("param_dtype must be one of: fp32, fp16, bf16")
|
| 429 |
+
|
| 430 |
+
if self.verbose:
|
| 431 |
+
self.log(
|
| 432 |
+
f"[CONFIG] scale={scale} h={self.hidden_size} "
|
| 433 |
+
f"layers={self.n_layers} heads={self.num_heads} "
|
| 434 |
+
f"head_dim={self.head_dim} inter={self.intermediate_size} "
|
| 435 |
+
f"vocab={self.vocab_size}"
|
| 436 |
+
)
|
| 437 |
+
self.log(
|
| 438 |
+
f"[CONFIG] storage={storage} param_dtype={param_dtype} "
|
| 439 |
+
f"resize={resize_strategy} noise={noise_method}"
|
| 440 |
+
)
|
| 441 |
+
|
| 442 |
+
def log(self, msg: str):
|
| 443 |
+
if self.verbose:
|
| 444 |
+
print(msg, flush=True)
|
| 445 |
+
|
| 446 |
+
def target_dtype(self):
|
| 447 |
+
if self.param_dtype == "fp16":
|
| 448 |
+
return torch.float16
|
| 449 |
+
if self.param_dtype == "bf16":
|
| 450 |
+
return torch.bfloat16
|
| 451 |
+
return torch.float32
|
| 452 |
+
|
| 453 |
+
def infer_shape(self, key: str) -> Tuple[int, ...]:
|
| 454 |
+
h = self.hidden_size
|
| 455 |
+
attn_dim = self.num_heads * self.head_dim
|
| 456 |
+
|
| 457 |
+
if key == "embed.weight":
|
| 458 |
+
return (self.vocab_size, h)
|
| 459 |
+
|
| 460 |
+
if key == "lm_head.weight":
|
| 461 |
+
return (self.vocab_size, h)
|
| 462 |
+
|
| 463 |
+
if key == "norm.weight":
|
| 464 |
+
return (h,)
|
| 465 |
+
|
| 466 |
+
if key.endswith("attn_norm.weight") or key.endswith("mlp_norm.weight"):
|
| 467 |
+
return (h,)
|
| 468 |
+
|
| 469 |
+
if key.endswith("attn.q_proj.weight"):
|
| 470 |
+
return (attn_dim, h)
|
| 471 |
+
if key.endswith("attn.k_proj.weight"):
|
| 472 |
+
return (attn_dim, h)
|
| 473 |
+
if key.endswith("attn.v_proj.weight"):
|
| 474 |
+
return (attn_dim, h)
|
| 475 |
+
if key.endswith("attn.o_proj.weight"):
|
| 476 |
+
return (h, attn_dim)
|
| 477 |
+
|
| 478 |
+
if key.endswith("mlp.gate_proj.weight"):
|
| 479 |
+
return (self.intermediate_size, h)
|
| 480 |
+
if key.endswith("mlp.up_proj.weight"):
|
| 481 |
+
return (self.intermediate_size, h)
|
| 482 |
+
if key.endswith("mlp.down_proj.weight"):
|
| 483 |
+
return (h, self.intermediate_size)
|
| 484 |
+
|
| 485 |
+
raise KeyError(f"Cannot infer shape for {key}")
|
| 486 |
+
|
| 487 |
+
def all_expected_keys(self) -> Iterable[str]:
|
| 488 |
+
yield "embed.weight"
|
| 489 |
+
yield "norm.weight"
|
| 490 |
+
yield "lm_head.weight"
|
| 491 |
+
|
| 492 |
+
for i in range(self.n_layers):
|
| 493 |
+
prefix = f"layers.{i}"
|
| 494 |
+
yield f"{prefix}.attn_norm.weight"
|
| 495 |
+
yield f"{prefix}.mlp_norm.weight"
|
| 496 |
+
yield f"{prefix}.attn.q_proj.weight"
|
| 497 |
+
yield f"{prefix}.attn.k_proj.weight"
|
| 498 |
+
yield f"{prefix}.attn.v_proj.weight"
|
| 499 |
+
yield f"{prefix}.attn.o_proj.weight"
|
| 500 |
+
yield f"{prefix}.mlp.gate_proj.weight"
|
| 501 |
+
yield f"{prefix}.mlp.up_proj.weight"
|
| 502 |
+
yield f"{prefix}.mlp.down_proj.weight"
|
| 503 |
+
|
| 504 |
+
def is_linear_key(self, key: str) -> bool:
|
| 505 |
+
return any(
|
| 506 |
+
key.endswith(s)
|
| 507 |
+
for s in (
|
| 508 |
+
"attn.q_proj.weight",
|
| 509 |
+
"attn.k_proj.weight",
|
| 510 |
+
"attn.v_proj.weight",
|
| 511 |
+
"attn.o_proj.weight",
|
| 512 |
+
"mlp.gate_proj.weight",
|
| 513 |
+
"mlp.up_proj.weight",
|
| 514 |
+
"mlp.down_proj.weight",
|
| 515 |
+
)
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
def is_embedding_or_head(self, key: str) -> bool:
|
| 519 |
+
return key in {"embed.weight", "lm_head.weight"}
|
| 520 |
+
|
| 521 |
+
def maybe_transpose(self, w: torch.Tensor, expected: Tuple[int, ...], key: str) -> torch.Tensor:
|
| 522 |
+
if not self.auto_transpose:
|
| 523 |
+
return w
|
| 524 |
+
|
| 525 |
+
if w.ndim == 2 and len(expected) == 2:
|
| 526 |
+
if tuple(w.shape) != tuple(expected) and tuple(w.t().shape) == tuple(expected):
|
| 527 |
+
self.log(f" [TRANSPOSE] {key}: {tuple(w.shape)} -> {tuple(w.t().shape)}")
|
| 528 |
+
return w.t().contiguous()
|
| 529 |
+
|
| 530 |
+
return w
|
| 531 |
+
|
| 532 |
+
def convert_tensor(
|
| 533 |
+
self,
|
| 534 |
+
gguf_name: str,
|
| 535 |
+
key: str,
|
| 536 |
+
arr: np.ndarray,
|
| 537 |
+
) -> Optional[Dict[str, torch.Tensor]]:
|
| 538 |
+
expected = self.infer_shape(key)
|
| 539 |
+
|
| 540 |
+
w = torch.from_numpy(np.asarray(arr)).to(torch.float32)
|
| 541 |
+
w = self.maybe_transpose(w, expected, key)
|
| 542 |
+
|
| 543 |
+
result: Dict[str, torch.Tensor] = {}
|
| 544 |
+
|
| 545 |
+
# 1D norms
|
| 546 |
+
if len(expected) == 1:
|
| 547 |
+
if w.ndim != 1:
|
| 548 |
+
self.log(f" [SKIP] {gguf_name}: expected 1D {expected}, got {tuple(w.shape)}")
|
| 549 |
+
return None
|
| 550 |
+
|
| 551 |
+
if tuple(w.shape) != tuple(expected):
|
| 552 |
+
self.log(f" [RESIZE-1D] {gguf_name}: {tuple(w.shape)} -> {expected}")
|
| 553 |
+
w = resize_1d(w, expected[0])
|
| 554 |
+
|
| 555 |
+
result[key] = w.to(self.target_dtype()).contiguous()
|
| 556 |
+
return result
|
| 557 |
+
|
| 558 |
+
# Embeddings/lm_head must stay dense; no ternary quantization here.
|
| 559 |
+
if self.is_embedding_or_head(key):
|
| 560 |
+
if w.ndim != 2:
|
| 561 |
+
self.log(f" [SKIP] {gguf_name}: expected 2D embedding/head, got {tuple(w.shape)}")
|
| 562 |
+
return None
|
| 563 |
+
|
| 564 |
+
if tuple(w.shape) != tuple(expected):
|
| 565 |
+
self.log(f" [RESIZE-EMB] {gguf_name}: {tuple(w.shape)} -> {expected}")
|
| 566 |
+
w = resize_2d(w, expected, self.resize_strategy)
|
| 567 |
+
|
| 568 |
+
result[key] = w.to(self.target_dtype()).contiguous()
|
| 569 |
+
return result
|
| 570 |
+
|
| 571 |
+
# BitLinear linear layers
|
| 572 |
+
if self.is_linear_key(key):
|
| 573 |
+
if w.ndim != 2:
|
| 574 |
+
self.log(f" [SKIP] {gguf_name}: expected 2D linear, got {tuple(w.shape)}")
|
| 575 |
+
return None
|
| 576 |
+
|
| 577 |
+
if tuple(w.shape) != tuple(expected):
|
| 578 |
+
self.log(f" [RESIZE-2D] {gguf_name}: {tuple(w.shape)} -> {expected}")
|
| 579 |
+
w = resize_2d(w, expected, self.resize_strategy)
|
| 580 |
+
|
| 581 |
+
w = reduce_noise(w, method=self.noise_method, sigma=self.noise_sigma)
|
| 582 |
+
|
| 583 |
+
if self.storage in {"fp32", "both"}:
|
| 584 |
+
result[key] = w.to(self.target_dtype()).contiguous()
|
| 585 |
+
|
| 586 |
+
if self.storage in {"packed", "both"}:
|
| 587 |
+
q, alpha = ternary_quantize_absmean(
|
| 588 |
+
w,
|
| 589 |
+
threshold=self.ternary_threshold,
|
| 590 |
+
)
|
| 591 |
+
packed = pack_ternary_2bit(q)
|
| 592 |
+
result[f"{key}.packed_weight"] = packed.cpu().contiguous()
|
| 593 |
+
result[f"{key}.alpha"] = alpha.cpu().contiguous()
|
| 594 |
+
result[f"{key}.shape"] = torch.tensor(list(expected), dtype=torch.int32)
|
| 595 |
+
|
| 596 |
+
return result
|
| 597 |
+
|
| 598 |
+
self.log(f" [SKIP] {gguf_name}: unrecognized key {key}")
|
| 599 |
+
return None
|
| 600 |
+
|
| 601 |
+
def init_missing_tensor(self, key: str) -> Dict[str, torch.Tensor]:
|
| 602 |
+
expected = self.infer_shape(key)
|
| 603 |
+
out: Dict[str, torch.Tensor] = {}
|
| 604 |
+
|
| 605 |
+
if len(expected) == 1:
|
| 606 |
+
# Norms: initialize to 1.0
|
| 607 |
+
w = torch.ones(expected, dtype=self.target_dtype())
|
| 608 |
+
out[key] = w
|
| 609 |
+
return out
|
| 610 |
+
|
| 611 |
+
if key in {"embed.weight", "lm_head.weight"}:
|
| 612 |
+
w = torch.empty(expected, dtype=torch.float32)
|
| 613 |
+
w.normal_(0.0, 0.02)
|
| 614 |
+
out[key] = w.to(self.target_dtype())
|
| 615 |
+
return out
|
| 616 |
+
|
| 617 |
+
if self.is_linear_key(key):
|
| 618 |
+
w = torch.empty(expected, dtype=torch.float32)
|
| 619 |
+
fan_in = max(1, expected[1])
|
| 620 |
+
std = math.sqrt(2.0 / fan_in)
|
| 621 |
+
w.normal_(0.0, std)
|
| 622 |
+
|
| 623 |
+
if self.storage in {"fp32", "both"}:
|
| 624 |
+
out[key] = w.to(self.target_dtype()).contiguous()
|
| 625 |
+
|
| 626 |
+
if self.storage in {"packed", "both"}:
|
| 627 |
+
q, alpha = ternary_quantize_absmean(w, threshold=self.ternary_threshold)
|
| 628 |
+
out[f"{key}.packed_weight"] = pack_ternary_2bit(q)
|
| 629 |
+
out[f"{key}.alpha"] = alpha
|
| 630 |
+
out[f"{key}.shape"] = torch.tensor(list(expected), dtype=torch.int32)
|
| 631 |
+
|
| 632 |
+
return out
|
| 633 |
+
|
| 634 |
+
return out
|
| 635 |
+
|
| 636 |
+
def dequantize_tensor(self, tensor) -> np.ndarray:
|
| 637 |
+
"""
|
| 638 |
+
Dequantize a GGUF tensor to numpy float32.
|
| 639 |
+
Compatible with the most common gguf-py API.
|
| 640 |
+
"""
|
| 641 |
+
qtype = getattr(tensor, "tensor_type", None)
|
| 642 |
+
data = getattr(tensor, "data", None)
|
| 643 |
+
|
| 644 |
+
if data is None:
|
| 645 |
+
raise RuntimeError(f"GGUF tensor has no data: {getattr(tensor, 'name', '?')}")
|
| 646 |
+
|
| 647 |
+
try:
|
| 648 |
+
arr = dequantize(data, qtype)
|
| 649 |
+
except Exception:
|
| 650 |
+
# Some tensors may already be plain float arrays
|
| 651 |
+
arr = np.asarray(data)
|
| 652 |
+
|
| 653 |
+
arr = np.asarray(arr)
|
| 654 |
+
|
| 655 |
+
if arr.dtype != np.float32:
|
| 656 |
+
arr = arr.astype(np.float32, copy=False)
|
| 657 |
+
|
| 658 |
+
return np.ascontiguousarray(arr)
|
| 659 |
+
|
| 660 |
+
def read_arch(self, reader) -> str:
|
| 661 |
+
try:
|
| 662 |
+
field = reader.fields.get("general.architecture")
|
| 663 |
+
if field is None:
|
| 664 |
+
return "unknown"
|
| 665 |
+
# gguf-py field formats can vary.
|
| 666 |
+
if hasattr(field, "parts") and field.parts:
|
| 667 |
+
return str(field.parts[-1])
|
| 668 |
+
return str(field)
|
| 669 |
+
except Exception:
|
| 670 |
+
return "unknown"
|
| 671 |
+
|
| 672 |
+
def import_model(self, gguf_path: str, output_path: str) -> Dict[str, Any]:
|
| 673 |
+
if not HAS_GGUF:
|
| 674 |
+
raise ImportError("Missing gguf package. Install it with: pip install gguf")
|
| 675 |
+
|
| 676 |
+
gguf_path = str(gguf_path)
|
| 677 |
+
output_path = str(output_path)
|
| 678 |
+
|
| 679 |
+
self.log("=" * 70)
|
| 680 |
+
self.log("CHIMERA GGUF IMPORT OPTIMIZED")
|
| 681 |
+
self.log("=" * 70)
|
| 682 |
+
|
| 683 |
+
reader = GGUFReader(gguf_path)
|
| 684 |
+
arch = self.read_arch(reader)
|
| 685 |
+
|
| 686 |
+
self.log(f"[GGUF] file={gguf_path}")
|
| 687 |
+
self.log(f"[GGUF] arch={arch}")
|
| 688 |
+
self.log(f"[GGUF] tensors={len(reader.tensors)}")
|
| 689 |
+
|
| 690 |
+
state_dict: Dict[str, torch.Tensor] = {}
|
| 691 |
+
|
| 692 |
+
stats = {
|
| 693 |
+
"mapped": 0,
|
| 694 |
+
"unmapped": 0,
|
| 695 |
+
"skipped": 0,
|
| 696 |
+
"linear": 0,
|
| 697 |
+
"dense": 0,
|
| 698 |
+
"norm": 0,
|
| 699 |
+
"resized_or_transposed_possible": 0,
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
imported_keys = set()
|
| 703 |
+
|
| 704 |
+
for idx, tensor in enumerate(reader.tensors):
|
| 705 |
+
name = str(tensor.name)
|
| 706 |
+
key = map_gguf_name(name, self.n_layers)
|
| 707 |
+
|
| 708 |
+
if key is None:
|
| 709 |
+
stats["unmapped"] += 1
|
| 710 |
+
if self.verbose:
|
| 711 |
+
self.log(f" [UNMAPPED] {name}")
|
| 712 |
+
continue
|
| 713 |
+
|
| 714 |
+
try:
|
| 715 |
+
arr = self.dequantize_tensor(tensor)
|
| 716 |
+
converted = self.convert_tensor(name, key, arr)
|
| 717 |
+
|
| 718 |
+
if not converted:
|
| 719 |
+
stats["skipped"] += 1
|
| 720 |
+
continue
|
| 721 |
+
|
| 722 |
+
state_dict.update(converted)
|
| 723 |
+
imported_keys.add(key)
|
| 724 |
+
stats["mapped"] += 1
|
| 725 |
+
|
| 726 |
+
if self.is_linear_key(key):
|
| 727 |
+
stats["linear"] += 1
|
| 728 |
+
elif key in {"embed.weight", "lm_head.weight"}:
|
| 729 |
+
stats["dense"] += 1
|
| 730 |
+
else:
|
| 731 |
+
stats["norm"] += 1
|
| 732 |
+
|
| 733 |
+
if self.verbose:
|
| 734 |
+
qtype = getattr(tensor, "tensor_type", "?")
|
| 735 |
+
shape = tuple(arr.shape)
|
| 736 |
+
self.log(f" [OK] {idx+1:04d} {name} -> {key} shape={shape} qtype={qtype}")
|
| 737 |
+
|
| 738 |
+
except Exception as e:
|
| 739 |
+
stats["skipped"] += 1
|
| 740 |
+
self.log(f" [ERROR] {name}: {type(e).__name__}: {e}")
|
| 741 |
+
|
| 742 |
+
finally:
|
| 743 |
+
# Free the temporary FP32 buffer.
|
| 744 |
+
try:
|
| 745 |
+
del arr
|
| 746 |
+
except Exception:
|
| 747 |
+
pass
|
| 748 |
+
gc.collect()
|
| 749 |
+
|
| 750 |
+
# Initialize missing keys
|
| 751 |
+
missing = []
|
| 752 |
+
if self.init_missing:
|
| 753 |
+
for key in self.all_expected_keys():
|
| 754 |
+
if key not in imported_keys:
|
| 755 |
+
missing.append(key)
|
| 756 |
+
init_tensors = self.init_missing_tensor(key)
|
| 757 |
+
state_dict.update(init_tensors)
|
| 758 |
+
|
| 759 |
+
if missing:
|
| 760 |
+
self.log(f"[MISSING] {len(missing)} tensors initialized automatically")
|
| 761 |
+
|
| 762 |
+
ckpt = {
|
| 763 |
+
"model": state_dict,
|
| 764 |
+
"config": self.config,
|
| 765 |
+
"source": {
|
| 766 |
+
"gguf_path": gguf_path,
|
| 767 |
+
"gguf_arch": arch,
|
| 768 |
+
"scale": self.scale,
|
| 769 |
+
"storage": self.storage,
|
| 770 |
+
"param_dtype": self.param_dtype,
|
| 771 |
+
"noise_method": self.noise_method,
|
| 772 |
+
"noise_sigma": self.noise_sigma,
|
| 773 |
+
"ternary_threshold": self.ternary_threshold,
|
| 774 |
+
"resize_strategy": self.resize_strategy,
|
| 775 |
+
"auto_transpose": self.auto_transpose,
|
| 776 |
+
},
|
| 777 |
+
"stats": stats,
|
| 778 |
+
"missing_keys": missing,
|
| 779 |
+
"import_version": "2.0-optimized",
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 783 |
+
torch.save(ckpt, output_path)
|
| 784 |
+
|
| 785 |
+
gguf_mb = os.path.getsize(gguf_path) / 1024 / 1024
|
| 786 |
+
out_mb = os.path.getsize(output_path) / 1024 / 1024
|
| 787 |
+
|
| 788 |
+
self.log("")
|
| 789 |
+
self.log("=" * 70)
|
| 790 |
+
self.log("[DONE]")
|
| 791 |
+
self.log(f"[STATS] {stats}")
|
| 792 |
+
self.log(f"[SIZE] GGUF={gguf_mb:.2f} MB -> checkpoint={out_mb:.2f} MB")
|
| 793 |
+
self.log(f"[SAVE] {output_path}")
|
| 794 |
+
self.log("=" * 70)
|
| 795 |
+
|
| 796 |
+
return ckpt
|
| 797 |
+
|
| 798 |
+
|
| 799 |
+
# ═══════════════════════════════════════════════════════════
|
| 800 |
+
# CLI
|
| 801 |
+
# ═══════════════════════════════════════════════════════════
|
| 802 |
+
|
| 803 |
+
def main():
|
| 804 |
+
parser = argparse.ArgumentParser(
|
| 805 |
+
description="Optimized GGUF -> Chimera checkpoint importer"
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
parser.add_argument("--gguf", required=True, help="Path to input .gguf")
|
| 809 |
+
parser.add_argument("--config", default="config.json", help="Chimera config.json")
|
| 810 |
+
parser.add_argument("--output", required=True, help="Output .pt checkpoint")
|
| 811 |
+
|
| 812 |
+
parser.add_argument(
|
| 813 |
+
"--scale",
|
| 814 |
+
default="tiny",
|
| 815 |
+
choices=["tiny", "small", "medium", "full"],
|
| 816 |
+
help="Chimera scale override",
|
| 817 |
+
)
|
| 818 |
+
|
| 819 |
+
parser.add_argument(
|
| 820 |
+
"--storage",
|
| 821 |
+
default="fp32",
|
| 822 |
+
choices=["fp32", "packed", "both"],
|
| 823 |
+
help=(
|
| 824 |
+
"fp32=compatible Chimera classique, "
|
| 825 |
+
"packed=2-bit seulement, both=les deux"
|
| 826 |
+
),
|
| 827 |
+
)
|
| 828 |
+
|
| 829 |
+
parser.add_argument(
|
| 830 |
+
"--param-dtype",
|
| 831 |
+
default="fp32",
|
| 832 |
+
choices=["fp32", "fp16", "bf16"],
|
| 833 |
+
help="dtype for saved dense/latent tensors",
|
| 834 |
+
)
|
| 835 |
+
|
| 836 |
+
parser.add_argument(
|
| 837 |
+
"--noise-method",
|
| 838 |
+
default="row_outlier_clip",
|
| 839 |
+
choices=["none", "global_clip", "row_outlier_clip", "median_center"],
|
| 840 |
+
help="Noise reduction before ternary conversion",
|
| 841 |
+
)
|
| 842 |
+
|
| 843 |
+
parser.add_argument(
|
| 844 |
+
"--noise-sigma",
|
| 845 |
+
type=float,
|
| 846 |
+
default=3.0,
|
| 847 |
+
help="Sigma for clipping",
|
| 848 |
+
)
|
| 849 |
+
|
| 850 |
+
parser.add_argument(
|
| 851 |
+
"--ternary-threshold",
|
| 852 |
+
type=float,
|
| 853 |
+
default=0.5,
|
| 854 |
+
help="Threshold on normalized weights for ternary quantization",
|
| 855 |
+
)
|
| 856 |
+
|
| 857 |
+
parser.add_argument(
|
| 858 |
+
"--resize-strategy",
|
| 859 |
+
default="crop_pad",
|
| 860 |
+
choices=["strict", "crop_pad", "interpolate"],
|
| 861 |
+
help="Resize strategy when GGUF shape != Chimera shape",
|
| 862 |
+
)
|
| 863 |
+
|
| 864 |
+
parser.add_argument(
|
| 865 |
+
"--no-auto-transpose",
|
| 866 |
+
action="store_true",
|
| 867 |
+
help="Disable automatic transpose when reversed shape matches",
|
| 868 |
+
)
|
| 869 |
+
|
| 870 |
+
parser.add_argument(
|
| 871 |
+
"--no-init-missing",
|
| 872 |
+
action="store_true",
|
| 873 |
+
help="Do not initialize missing Chimera weights",
|
| 874 |
+
)
|
| 875 |
+
|
| 876 |
+
parser.add_argument(
|
| 877 |
+
"--quiet",
|
| 878 |
+
action="store_true",
|
| 879 |
+
help="Less logs",
|
| 880 |
+
)
|
| 881 |
+
|
| 882 |
+
args = parser.parse_args()
|
| 883 |
+
|
| 884 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 885 |
+
config = json.load(f)
|
| 886 |
+
|
| 887 |
+
importer = OptimizedGGUFImporter(
|
| 888 |
+
config=config,
|
| 889 |
+
scale=args.scale,
|
| 890 |
+
storage=args.storage,
|
| 891 |
+
param_dtype=args.param_dtype,
|
| 892 |
+
noise_method=args.noise_method,
|
| 893 |
+
noise_sigma=args.noise_sigma,
|
| 894 |
+
ternary_threshold=args.ternary_threshold,
|
| 895 |
+
resize_strategy=args.resize_strategy,
|
| 896 |
+
auto_transpose=not args.no_auto_transpose,
|
| 897 |
+
init_missing=not args.no_init_missing,
|
| 898 |
+
verbose=not args.quiet,
|
| 899 |
+
)
|
| 900 |
+
|
| 901 |
+
importer.import_model(args.gguf, args.output)
|
| 902 |
+
|
| 903 |
+
|
| 904 |
+
if __name__ == "__main__":
|
| 905 |
+
main()
|
inference.py
ADDED
|
@@ -0,0 +1,302 @@
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Chimera 5.2 — CPU-first inference / text generation.
|
| 3 |
+
|
| 4 |
+
Significant CPU-friendly changes vs the previous draft:
|
| 5 |
+
|
| 6 |
+
* **KV-cache aware loop** — after the first forward pass we only feed the
|
| 7 |
+
new token plus the per-layer recurrent state into the model. This makes
|
| 8 |
+
generation *O(T)* instead of *O(T²)*, the single biggest win for CPU
|
| 9 |
+
decoding.
|
| 10 |
+
* **Pre-pack BitLinear weights** at startup so the first decoded token does
|
| 11 |
+
not pay the unpack/repack cost.
|
| 12 |
+
* **Greedy fast path** (``temperature == 0``) skips softmax / sort entirely.
|
| 13 |
+
* **Top-k constrained nucleus** — when both ``top_k`` and ``top_p`` are
|
| 14 |
+
used we sort the top-k slice only (not the full 200K vocabulary).
|
| 15 |
+
* **Streaming output** — tokens are decoded incrementally so the first
|
| 16 |
+
bytes appear immediately.
|
| 17 |
+
|
| 18 |
+
Usage::
|
| 19 |
+
|
| 20 |
+
python inference.py --checkpoint chimera_output/final/model.pt \\
|
| 21 |
+
--prompt "Once upon a time" --max_tokens 200
|
| 22 |
+
"""
|
| 23 |
+
|
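# Illustrative sketch (not part of this file) of the "top-k constrained nucleus" step
# described in the docstring above: only the k largest probabilities are sorted and
# filtered, instead of the full vocabulary.
#   probs = torch.softmax(logits / temperature, dim=-1)
#   top_probs, top_idx = probs.topk(k)                     # k << vocab_size
#   keep = (top_probs.cumsum(-1) - top_probs) <= top_p     # nucleus mask inside top-k
#   top_probs = torch.where(keep, top_probs, torch.zeros_like(top_probs))
#   next_id = top_idx[torch.multinomial(top_probs / top_probs.sum(), 1)]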
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import argparse
|
| 27 |
+
import json
|
| 28 |
+
import os
|
| 29 |
+
import sys
|
| 30 |
+
import time
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def _setup_cpu_runtime() -> None:
|
| 34 |
+
n = os.cpu_count() or 4
|
| 35 |
+
os.environ.setdefault("OMP_NUM_THREADS", str(n))
|
| 36 |
+
os.environ.setdefault("MKL_NUM_THREADS", str(n))
|
| 37 |
+
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
|
| 38 |
+
os.environ.setdefault("KMP_BLOCKTIME", "1")
|
| 39 |
+
os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
_setup_cpu_runtime()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
import torch.nn.functional as F
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
try:
|
| 50 |
+
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
|
| 51 |
+
torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
|
| 52 |
+
except RuntimeError:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 56 |
+
|
| 57 |
+
from chimera import Chimera51ForCausalLM, ChimeraTokenizer
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
# Checkpoint loading
|
| 62 |
+
# ---------------------------------------------------------------------------
|
| 63 |
+
|
def load_model(checkpoint_path: str, device: str = "cpu"):
    print(f"[LOAD] Checkpoint: {checkpoint_path}")
    ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)

    config = ckpt.get("config")
    if config is None:
        ckpt_dir = os.path.dirname(checkpoint_path)
        cand = os.path.join(ckpt_dir, "config.json") if ckpt_dir else "config.json"
        if not os.path.exists(cand):
            cand = "config.json"
        with open(cand, encoding="utf-8") as f:
            config = json.load(f)
        print(f"[LOAD] Config from {cand}")
    else:
        print("[LOAD] Config from checkpoint")

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")

    state = ckpt.get("model", ckpt)

    # Reconcile vocab mismatches in either direction without crashing.
    model_vocab = int(config.get("vocab_size", model.embed.num_embeddings))
    ckpt_vocab = None
    for key in ("embed.weight", "lm_head.weight"):
        for sk, t in state.items():
            if sk.endswith(key):
                ckpt_vocab = int(t.shape[0])
                break
        if ckpt_vocab is not None:
            break

    if ckpt_vocab and ckpt_vocab != model_vocab:
        print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; resizing")
        with torch.no_grad():
            old = model.embed.weight.data
            new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device)
            new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)]
            model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1])
            model.embed.weight.data = new
            old_h = model.lm_head.weight.data
            new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device)
            new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)]
            model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False)
            model.lm_head.weight.data = new_h
        config["vocab_size"] = ckpt_vocab

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
    if unexpected:
        print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}...")

    model.to(device).eval()
    model.prepare_for_inference()  # pre-pack ternary weights

    step = ckpt.get("step", "?")
    best_loss = ckpt.get("best_loss")
    if best_loss is not None:
        print(f"[LOAD] Step {step}, best_loss={best_loss:.4f}")
    else:
        print(f"[LOAD] Step {step}")
    return model, config


# ---------------------------------------------------------------------------
# Sampling helpers
# ---------------------------------------------------------------------------

def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k: int
                 ) -> int:
    """Return the next token id sampled from ``logits`` ([1, V] or [V])."""
    if logits.dim() == 1:
        logits = logits.unsqueeze(0)

    # Greedy fast path.
    if temperature <= 0.0:
        return int(torch.argmax(logits, dim=-1).item())

    logits = logits / temperature

    if top_k and top_k > 0:
        k = min(top_k, logits.size(-1))
        cand_logits, cand_indices = torch.topk(logits, k, dim=-1)
        if top_p < 1.0:
            sorted_logits, order = torch.sort(cand_logits, descending=True)
            sorted_indices = cand_indices.gather(-1, order)
            cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            remove = cum_probs > top_p
            remove[..., 0] = False
            sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
            probs = F.softmax(sorted_logits, dim=-1)
            return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
        probs = F.softmax(cand_logits, dim=-1)
        return int(cand_indices.gather(-1, torch.multinomial(probs, 1)).item())

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        remove = cum_probs > top_p
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        probs = F.softmax(sorted_logits, dim=-1)
        return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())

    probs = F.softmax(logits, dim=-1)
    return int(torch.multinomial(probs, 1).item())
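

# Illustrative sketch (added for documentation, not part of the original
# generation path): how the branches of ``_sample_next`` above behave on a toy
# logits vector.  ``_sampling_demo`` is a hypothetical helper, never called by
# the script; run it manually from a REPL if useful.
def _sampling_demo() -> None:
    toy = torch.tensor([4.0, 3.0, 0.1, 0.05])
    # temperature == 0 takes the argmax and skips softmax / sorting entirely.
    assert _sample_next(toy, temperature=0.0, top_p=0.9, top_k=50) == 0
    # With top_k=2 only ids 0 and 1 are candidates; the nucleus cut (top_p)
    # is applied inside that 2-element slice, not over the full vocabulary.
    for _ in range(8):
        assert _sample_next(toy, temperature=1.0, top_p=0.9, top_k=2) in (0, 1)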


# ---------------------------------------------------------------------------
# Generation loop
# ---------------------------------------------------------------------------

def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
             prompt: str, max_tokens: int = 100, temperature: float = 0.8,
             top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
             bf16: bool = False, stream: bool = True) -> str:
    model.eval()
    prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
    if not prompt_ids:
        prompt_ids = [tokenizer.eos_token_id]
    input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    print(f"\n[GEN] Prompt: {prompt!r}")
    print(f"[GEN] max_tokens={max_tokens}, temp={temperature}, top_p={top_p}, top_k={top_k}")
    print("=" * 60, flush=True)

    if stream:
        sys.stdout.write(prompt)
        sys.stdout.flush()

    generated = list(prompt_ids)
    decoded_so_far = tokenizer.decode(generated, skip_special_tokens=False)

    autocast_ctx = (torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16)
                    if bf16 else _nullctx())

    t0 = time.time()
    with torch.inference_mode(), autocast_ctx:
        # Initial pass: feed the whole prompt and capture per-layer caches.
        out = model(input_ids, use_cache=True, logits_to_keep=1)
        caches = out.caches
        next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
        if next_token == tokenizer.eos_token_id:
            return tokenizer.decode(generated, skip_special_tokens=True)
        generated.append(next_token)

        for _ in range(max_tokens - 1):
            tok_t = torch.tensor([[next_token]], dtype=torch.long, device=device)
            out = model(tok_t, caches=caches, use_cache=True, logits_to_keep=1)
            caches = out.caches
            next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
            if next_token == tokenizer.eos_token_id:
                break
            generated.append(next_token)
            if stream:
                # Try to render only the newly produced text.
                full = tokenizer.decode(generated, skip_special_tokens=False)
                if full.startswith(decoded_so_far):
                    sys.stdout.write(full[len(decoded_so_far):])
                    sys.stdout.flush()
                decoded_so_far = full

    elapsed = time.time() - t0
    n_new = len(generated) - len(prompt_ids)
    speed = n_new / elapsed if elapsed > 0 else 0.0
    final = tokenizer.decode(generated, skip_special_tokens=True)

    print()
    print("=" * 60)
    if not stream:
        print(final)
    print(f"[STATS] {n_new} new tokens in {elapsed:.2f}s ({speed:.1f} tok/s)")
    return final


class _nullctx:
    def __enter__(self):
        return self

    def __exit__(self, *args):
        return False


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def main() -> None:
    p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
    p.add_argument("--checkpoint", default="chimera_output/final/model.pt")
    p.add_argument("--prompt", default="Once upon a time")
    p.add_argument("--max_tokens", type=int, default=100)
    p.add_argument("--temperature", type=float, default=0.8)
    p.add_argument("--top_p", type=float, default=0.9)
    p.add_argument("--top_k", type=int, default=50)
    p.add_argument("--device", default="cpu")
    p.add_argument("--bf16", action="store_true", default=True)
    p.add_argument("--no-bf16", dest="bf16", action="store_false")
    p.add_argument("--threads", type=int, default=None)
    p.add_argument("--compile", action="store_true", default=False)
    p.add_argument("--no-stream", dest="stream", action="store_false", default=True)
    args = p.parse_args()

    if args.threads:
        torch.set_num_threads(args.threads)
        os.environ["OMP_NUM_THREADS"] = str(args.threads)
        os.environ["MKL_NUM_THREADS"] = str(args.threads)

    if not os.path.exists(args.checkpoint):
        print(f"[ERROR] Checkpoint not found: {args.checkpoint}")
        return

    model, config = load_model(args.checkpoint, device=args.device)

    if args.compile:
        print("[OPT] Compiling model with torch.compile (mode=reduce-overhead)...")
        model = torch.compile(model, backend="inductor", mode="reduce-overhead")

    print("[LOAD] Loading tokenizer (splintr o200k_base)...")
    tokenizer = ChimeraTokenizer(pretrained="o200k_base")

    print("[WARM] Warmup forward...")
    with torch.inference_mode():
        _ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device),
                  logits_to_keep=1)
    print("[WARM] Done.")

    generate(
        model, tokenizer,
        prompt=args.prompt, max_tokens=args.max_tokens,
        temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
        device=args.device, bf16=args.bf16, stream=args.stream,
    )


if __name__ == "__main__":
    main()
pyproject.toml
ADDED
@@ -0,0 +1,10 @@
[project]
name = "chimera51-cpu"
version = "5.2.0"
description = "CPU-first Chimera 5.1 causal LM implementation"
requires-python = ">=3.10"
dependencies = ["torch"]

[tool.pytest.ini_options]
testpaths = ["tests"]
pythonpath = ["."]
tests/test_chimera.py
ADDED
@@ -0,0 +1,115 @@
import pytest

torch = pytest.importorskip("torch")

from chimera import (
    Chimera51ForCausalLM, ChimeraTokenizer, load_config, scale_config,
    pack_ternary, unpack_ternary,
)
from chimera.inference import SpanBank
from chimera.moe import MoELayer
from chimera.quantization import BitLinear, ternarize_weight


def cfg():
    c = scale_config(load_config("config.json"), "nano")
    c["vocab_size"] = 512
    c["span_inference"]["enabled"] = False
    return c


def test_pack_unpack_roundtrip():
    q = torch.tensor([[-1, 0, 1, 1, -1, 0, 1, 0, -1]], dtype=torch.int8)
    packed = pack_ternary(q)
    out = unpack_ternary(packed, q.shape[-1], dtype=torch.float32).to(torch.int8)
    assert torch.equal(q, out)


def test_ternarize_weight_basic():
    w = torch.randn(8, 16) * 0.5
    wq, alpha = ternarize_weight(w)
    assert wq.shape == w.shape
    assert alpha.shape == (8,)
    assert (wq.unique().abs() <= 1).all()


def test_bitlinear_forward_backward_and_packed():
    layer = BitLinear(7, 5)
    x = torch.randn(3, 7, requires_grad=True)
    y = layer(x).sum()
    y.backward()
    assert x.grad is not None and torch.isfinite(x.grad).all()
    assert layer.weight.grad is not None
    layer.prepare_for_inference()
    layer.eval()
    with torch.no_grad():
        out = layer(torch.randn(2, 7))
    assert out.shape == (2, 5)


def test_bitlinear_dense_cache_consistency():
    layer = BitLinear(8, 4)
    layer.eval()
    layer.prepare_for_inference()
    x = torch.randn(2, 8)
    with torch.no_grad():
        out1 = layer(x)
        out2 = layer(x)
    assert torch.allclose(out1, out2)


def test_model_forward_loss_and_generate_shape():
    model = Chimera51ForCausalLM(cfg())
    x = torch.randint(0, 512, (2, 8))
    y = torch.randint(0, 512, (2, 8))
    out = model(x, labels=y)
    assert out.logits.shape == (2, 8, 512)
    assert torch.isfinite(out.loss)
    out.loss.backward()


def test_model_kv_cache_consistency():
    """Generation with KV-cache must match generation without it."""
    config = cfg()
    config["looping"]["enabled"] = False  # determinism for the equivalence check
    model = Chimera51ForCausalLM(config).eval()
    model.prepare_for_inference()

    prompt = torch.randint(0, 512, (1, 4))
    with torch.inference_mode():
        # No-cache: feed the full sequence each time.
        cur = prompt.clone()
        no_cache_tokens = []
        for _ in range(3):
            out = model(cur, logits_to_keep=1)
            tok = out.logits[:, -1].argmax(-1, keepdim=True)
            cur = torch.cat([cur, tok], dim=1)
            no_cache_tokens.append(int(tok.item()))

        # KV-cache: feed only the new token after the first call.
        out = model(prompt, use_cache=True, logits_to_keep=1)
        caches = out.caches
        tok = out.logits[:, -1].argmax(-1, keepdim=True)
        cache_tokens = [int(tok.item())]
        for _ in range(2):
            out = model(tok, caches=caches, use_cache=True, logits_to_keep=1)
            caches = out.caches
            tok = out.logits[:, -1].argmax(-1, keepdim=True)
            cache_tokens.append(int(tok.item()))

    assert no_cache_tokens == cache_tokens


def test_moe_and_span_bank_shapes():
    moe = MoELayer(32, 64, n_routed_experts=3, n_shared_experts=1, num_experts_per_tok=2)
    x = torch.randn(2, 4, 32)
    assert moe(x).shape == x.shape
    bank = SpanBank(max_entries=8, hidden_size=32)
    bank.add(torch.randn(3, 32), torch.randn(3, 32))
    assert bank.query(torch.randn(5, 32)).shape == (5, 32)


def test_tokenizer_fallback_roundtrip():
    tok = ChimeraTokenizer(vocab_size=512)
    text = "hello cpu"
    assert tok.decode(tok.encode(text)) == text
tests/test_config.py
ADDED
@@ -0,0 +1,8 @@
from chimera.config import load_config, scale_config


def test_config_scaling_without_torch_runtime():
    cfg = scale_config(load_config("config.json"), "nano")
    assert cfg["hidden_size"] == 128
    assert cfg["num_hidden_layers"] == 4
    assert cfg["vocab_size"] <= 8192
train.py
ADDED
@@ -0,0 +1,632 @@
#!/usr/bin/env python3
"""
Chimera 5.2 — CPU-first training script.

Highlights vs the previous version:

* MeZO optimiser uses a single deterministic seed per step, samples each
  parameter's perturbation direction *on demand* via per-parameter seeds and
  drops the heavy direction cache. This brings the memory cost of MeZO back
  down to "1× model" exactly as advertised.
* AdamW path uses fused parameter groups and shares the same loss closure as
  MeZO so accumulation and logging are identical between modes.
* Logging never references an undefined ``lr`` (the previous draft printed it
  before the AdamW step ran on the first accumulator boundary).
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
  faster path).
* Tokeniser/dataset loading is unchanged but the Python loops are skipped
  entirely for ``max_tokens=0``.

Recommended commands::

    # MeZO smoke test on TinyStories
    python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo

    # AdamW with grad checkpointing + bf16
    python train.py --scale small --seq_len 256 --max_steps 1000 \\
        --optimizer adamw --grad_checkpoint --bf16
"""

from __future__ import annotations

import argparse
import json
import math
import os
import sys
import time

# CPU threading must be configured *before* importing torch.
def _setup_cpu_runtime() -> None:
    n_cpus = os.cpu_count() or 4
    os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
    os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
    os.environ.setdefault("KMP_BLOCKTIME", "1")
    os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")


_setup_cpu_runtime()


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))

from chimera import Chimera51ForCausalLM
from chimera.quantization import BitLinear


torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
try:
    torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
except RuntimeError:
    pass


# Optional Intel Extension for PyTorch.
HAS_IPEX = False
try:  # pragma: no cover - optional dependency.
    import intel_extension_for_pytorch as ipex  # noqa: F401
    HAS_IPEX = True
except Exception:
    pass


# ---------------------------------------------------------------------------
# MeZO optimiser
# ---------------------------------------------------------------------------

class MeZOOptimizer:
    """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO).

    Each step runs *two* forward passes around ``θ`` and uses the resulting
    loss difference to estimate a projected gradient. No backward pass and
    no per-parameter optimiser state — memory cost is exactly ``1× model``.

    For BitLinear layers we mask perturbations to currently non-zero ternary
    positions, so ``~1/3`` of the weights skip both perturbation and update.
    """

    def __init__(self, model: nn.Module, lr: float = 1e-4, eps: float = 1e-3,
                 weight_decay: float = 0.0, momentum: float = 0.0,
                 direction: str = "rademacher"):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.momentum = float(momentum)
        if direction not in ("rademacher", "gaussian"):
            raise ValueError(f"unknown direction: {direction!r}")
        self.direction = direction

        # Collect trainable parameters once and deduplicate tied weights.
        self._bitlinear_modules: list[tuple[str, BitLinear]] = []
        self._dense_params: list[tuple[str, torch.Tensor]] = []
        seen: set[int] = set()

        for name, module in model.named_modules():
            if isinstance(module, BitLinear):
                self._bitlinear_modules.append((name, module))
                seen.add(id(module.weight))
                if module.bias is not None:
                    seen.add(id(module.bias))

        for name, p in model.named_parameters():
            if p.requires_grad and id(p) not in seen:
                self._dense_params.append((name, p))
                seen.add(id(p))

        # Optional momentum buffer — only allocated when momentum > 0.
        self._momentum: dict[int, torch.Tensor] = {}
        if self.momentum > 0:
            for _, p in self._dense_params:
                self._momentum[id(p)] = torch.zeros_like(p.data)
            for _, m in self._bitlinear_modules:
                self._momentum[id(m.weight)] = torch.zeros_like(m.weight.data)

        # Snapshot ternary non-zero masks once per step.
        self._step_masks: dict[int, torch.Tensor] = {}

    # ------------------------------------------------------------------
    # Direction sampling — deterministic per (step seed, parameter index).
    # ------------------------------------------------------------------

    def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
        if self.direction == "gaussian":
            return torch.randn(p.shape, dtype=p.dtype, device="cpu",
                               generator=gen).to(p.device)
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(p.device)

    def _walk_params(self):
        """Yield ``(seed_offset, param, mask_or_None)`` for every trainable tensor."""
        # Yield the parameter objects themselves (not ``.data`` views) so that
        # ``id(p)`` matches the keys of the momentum and mask dicts built above.
        offset = 0
        for _, module in self._bitlinear_modules:
            yield offset, module.weight, self._step_masks.get(id(module.weight))
            offset += 1
            if module.bias is not None:
                yield offset, module.bias, None
                offset += 1
        for _, p in self._dense_params:
            yield offset, p, None
            offset += 1

    def _perturb(self, base_seed: int, scale: float) -> None:
        for off, p, mask in self._walk_params():
            z = self._direction(p, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            p.add_(z, alpha=scale)
        # Mark BitLinear caches stale.
        for _, m in self._bitlinear_modules:
            m.invalidate_packed()

    def _update(self, base_seed: int, projected_grad: float) -> None:
        for off, p, mask in self._walk_params():
            z = self._direction(p, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            buf = self._momentum.get(id(p))
            if buf is not None:
                buf.mul_(self.momentum).add_(z, alpha=projected_grad)
                p.add_(buf, alpha=-self.lr)
            else:
                p.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                p.mul_(1 - self.lr * self.wd)
        for _, m in self._bitlinear_modules:
            m.invalidate_packed()

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step (two forward passes) and return the mean loss."""
        seed = int(torch.randint(0, 2**31, (1,)).item())

        # Snapshot ternary non-zero masks once for this step.
        self._step_masks = {
            id(m.weight): m.ternary_nonzero_mask().detach()
            for _, m in self._bitlinear_modules
        }

        # Forward at θ + εz.
        self._perturb(seed, +self.eps)
        loss_pos = float(loss_fn(batch).item())

        # Net displacement: θ + εz - 2εz = θ - εz.
        self._perturb(seed, -2.0 * self.eps)
        loss_neg = float(loss_fn(batch).item())

        # Restore θ.
        self._perturb(seed, +self.eps)

        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update(seed, projected_grad)
        self._step_masks = {}

        return 0.5 * (loss_pos + loss_neg)
|
| 215 |
+
|
| 216 |
+
# ---------------------------------------------------------------------------
|
| 217 |
+
# Dataset & tokenisation helpers.
|
| 218 |
+
# ---------------------------------------------------------------------------
|
| 219 |
+
|
| 220 |
+
class TokenDataset(Dataset):
|
| 221 |
+
def __init__(self, chunks: torch.Tensor):
|
| 222 |
+
self.chunks = chunks
|
| 223 |
+
|
| 224 |
+
def __len__(self) -> int:
|
| 225 |
+
return self.chunks.size(0)
|
| 226 |
+
|
| 227 |
+
def __getitem__(self, idx: int) -> dict:
|
| 228 |
+
c = self.chunks[idx]
|
| 229 |
+
return {"input_ids": c, "labels": c}
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
def _matches_category_filter(ex: dict, filters: list) -> bool:
|
| 233 |
+
cat = ex.get("category", "") or ""
|
| 234 |
+
if not cat:
|
| 235 |
+
return False
|
| 236 |
+
cat_lower = cat.lower()
|
| 237 |
+
return any(f.lower() in cat_lower for f in filters)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def _format_example(ex: dict, tok, text_column: str = "auto",
|
| 241 |
+
include_reasoning: bool = False) -> str:
|
| 242 |
+
if text_column == "auto":
|
| 243 |
+
for cand in ("messages", "text", "content", "conversation"):
|
| 244 |
+
if cand in ex:
|
| 245 |
+
text_column = cand
|
| 246 |
+
break
|
| 247 |
+
else:
|
| 248 |
+
text_column = ""
|
| 249 |
+
|
| 250 |
+
if text_column == "messages" and "messages" in ex:
|
| 251 |
+
msgs = ex["messages"]
|
| 252 |
+
if include_reasoning and isinstance(msgs, list):
|
| 253 |
+
new_msgs = []
|
| 254 |
+
for m in msgs:
|
| 255 |
+
if isinstance(m, dict) and m.get("role") == "assistant" and "reasoning" in m:
|
| 256 |
+
new_msgs.append({
|
| 257 |
+
"role": "assistant",
|
| 258 |
+
"content": (f"<|thinking|>\n{m['reasoning']}\n<|/thinking|>\n"
|
| 259 |
+
f"{m.get('content', '')}"),
|
| 260 |
+
})
|
| 261 |
+
else:
|
| 262 |
+
new_msgs.append(m)
|
| 263 |
+
msgs = new_msgs
|
| 264 |
+
return tok.apply_chat_template(msgs)
|
| 265 |
+
|
| 266 |
+
if text_column and text_column in ex:
|
| 267 |
+
val = ex[text_column]
|
| 268 |
+
if isinstance(val, str):
|
| 269 |
+
return val
|
| 270 |
+
if isinstance(val, list) and val and isinstance(val[0], dict):
|
| 271 |
+
return tok.apply_chat_template(val)
|
| 272 |
+
return str(val)
|
| 273 |
+
return str(ex)
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def build_dataset(seq_len: int, max_samples=None, max_tokens=None,
|
| 277 |
+
split: str = "train",
|
| 278 |
+
dataset_name: str = "roneneldan/TinyStories",
|
| 279 |
+
dataset_config: str = None, text_column: str = "auto",
|
| 280 |
+
category_filter: str = None,
|
| 281 |
+
include_reasoning: bool = False):
|
| 282 |
+
from datasets import load_dataset
|
| 283 |
+
from chimera import ChimeraTokenizer
|
| 284 |
+
|
| 285 |
+
print(f"[DATA] Loading {dataset_name} ({split})...")
|
| 286 |
+
load_kwargs = {"split": split, "streaming": True}
|
| 287 |
+
if dataset_config:
|
| 288 |
+
load_kwargs["name"] = dataset_config
|
| 289 |
+
ds = load_dataset(dataset_name, **load_kwargs)
|
| 290 |
+
tok = ChimeraTokenizer(pretrained="o200k_base")
|
| 291 |
+
|
| 292 |
+
cat_filters = ([c.strip() for c in category_filter.split(",") if c.strip()]
|
| 293 |
+
if category_filter else None)
|
| 294 |
+
if cat_filters:
|
| 295 |
+
print(f"[DATA] Filtering categories: {cat_filters}")
|
| 296 |
+
|
| 297 |
+
if max_tokens is not None:
|
| 298 |
+
token_budget = int(max_tokens)
|
| 299 |
+
elif max_samples is not None:
|
| 300 |
+
token_budget = int(max_samples) * (seq_len + 1)
|
| 301 |
+
else:
|
| 302 |
+
token_budget = None
|
| 303 |
+
|
| 304 |
+
if token_budget is None or token_budget <= 0:
|
| 305 |
+
# Fallback: list-based collection.
|
| 306 |
+
all_ids: list[int] = []
|
| 307 |
+
target = (max_samples * (seq_len + 1)) if max_samples else float("inf")
|
| 308 |
+
for ex in ds:
|
| 309 |
+
if cat_filters and not _matches_category_filter(ex, cat_filters):
|
| 310 |
+
continue
|
| 311 |
+
text = _format_example(ex, tok, text_column, include_reasoning)
|
| 312 |
+
if not text or not text.strip():
|
| 313 |
+
continue
|
| 314 |
+
ids = tok.encode(text, add_special_tokens=False)
|
| 315 |
+
ids.append(tok.eos_token_id)
|
| 316 |
+
all_ids.extend(ids)
|
| 317 |
+
if len(all_ids) >= target:
|
| 318 |
+
break
|
| 319 |
+
all_ids = torch.tensor(all_ids, dtype=torch.long)
|
| 320 |
+
else:
|
| 321 |
+
# Pre-allocated token buffer.
|
| 322 |
+
buffer = torch.empty(token_budget, dtype=torch.long)
|
| 323 |
+
buf_idx = 0
|
| 324 |
+
processed = skipped = 0
|
| 325 |
+
for ex in ds:
|
| 326 |
+
if cat_filters and not _matches_category_filter(ex, cat_filters):
|
| 327 |
+
skipped += 1
|
| 328 |
+
continue
|
| 329 |
+
text = _format_example(ex, tok, text_column, include_reasoning)
|
| 330 |
+
if not text or not text.strip():
|
| 331 |
+
skipped += 1
|
| 332 |
+
continue
|
| 333 |
+
ids = tok.encode(text, add_special_tokens=False)
|
| 334 |
+
ids.append(tok.eos_token_id)
|
| 335 |
+
n = len(ids)
|
| 336 |
+
if buf_idx + n > token_budget:
|
| 337 |
+
n = token_budget - buf_idx
|
| 338 |
+
if n <= 0:
|
| 339 |
+
break
|
| 340 |
+
ids = ids[:n]
|
| 341 |
+
if n > 0:
|
| 342 |
+
buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long)
|
| 343 |
+
buf_idx += n
|
| 344 |
+
processed += 1
|
| 345 |
+
if buf_idx >= token_budget:
|
| 346 |
+
break
|
| 347 |
+
if (processed % 10_000) == 0:
|
| 348 |
+
print(f" {processed:,} examples, {buf_idx:,} tokens...")
|
| 349 |
+
all_ids = buffer[:buf_idx]
|
| 350 |
+
print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.")
|
| 351 |
+
|
| 352 |
+
if all_ids.numel() == 0:
|
| 353 |
+
raise ValueError("No data matched filters.")
|
| 354 |
+
|
| 355 |
+
n = all_ids.numel() // (seq_len + 1)
|
| 356 |
+
if max_samples:
|
| 357 |
+
n = min(n, max_samples)
|
| 358 |
+
chunks = all_ids[:n * (seq_len + 1)].view(n, seq_len + 1)
|
| 359 |
+
print(f"[DATA] {n:,} chunks × {seq_len} tokens = {n * seq_len:,} total")
|
| 360 |
+
return TokenDataset(chunks), tok
|
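
# Worked example of the chunking math above (documentation only): with
# seq_len=256 each chunk holds seq_len + 1 = 257 token ids, so a 1,028-token
# stream yields 1028 // 257 = 4 chunks; the trainer later shifts each chunk by
# one position to form (input_ids[:, :-1], labels[:, 1:]) pairs of length 256.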


# ---------------------------------------------------------------------------
# Learning-rate schedule.
# ---------------------------------------------------------------------------

def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float
              ) -> float:
    if warmup > 0 and step < warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    p = (step - warmup) / max(1, total - warmup)
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * p))
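
# Worked example for the schedule above (documentation only), with
# warmup=100, total=1000, max_lr=1e-3, min_lr=1e-4:
#   step 0    -> 1e-3 * 1/100   = 1e-5      (linear warmup)
#   step 99   -> 1e-3 * 100/100 = 1e-3      (end of warmup)
#   step 550  -> 1e-4 + 0.5 * 9e-4 * (1 + cos(pi/2)) = 5.5e-4   (cosine midpoint)
#   step 1000 -> 1e-4                        (floor once step >= total)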


# ---------------------------------------------------------------------------
# Main loop.
# ---------------------------------------------------------------------------

_SCALE_PRESETS = {
    "tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
    "small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
    "medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
}


def train(args) -> None:
    with open(args.config) as f:
        config = json.load(f)

    if args.scale in _SCALE_PRESETS:
        config.update(_SCALE_PRESETS[args.scale])
        config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))

    config["vocab_size"] = config.get("vocab_size", 200073)
    config.setdefault("gated_deltanet", {})["chunk_size"] = min(args.seq_len, 64)
    config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
    config.setdefault("titans", {}).update({
        "memory_depth": 2, "persistent_memory_slots": 16,
        "local_window_size": min(args.seq_len, 256),
    })
    moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
    moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
    moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
    moe_cfg.setdefault("n_routed_experts", 8)
    moe_cfg.setdefault("n_shared_experts", 1)
    moe_cfg.setdefault("num_experts_per_tok", 2)
    config.setdefault("looping", {}).update({
        "enabled": True, "prelude": [0, 3], "loop": [4, 23], "coda": [24, 27],
        "loop_range": [1, 3], "loop_default": 2,
    })
    config.setdefault("span_inference", {})["enabled"] = True
    config.setdefault("grammar", {})["enabled"] = True
    config.setdefault("entropy_valve", {})["enabled"] = True
    config.setdefault("debt_ledger", {})["enabled"] = True
    config.setdefault("multimodal", {})["enabled"] = False

    use_mezo = (args.optimizer == "mezo")
    use_bf16 = bool(args.bf16)
    use_compile = bool(args.compile)

    print("=" * 60)
    print(f"CHIMERA 5.2 TRAINING — scale={args.scale}, "
          f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
    print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
          f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
    print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
    print("=" * 60)

    model = Chimera51ForCausalLM(config)
    counts = model.count_parameters()
    print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")

    if args.grad_checkpoint and not use_mezo:
        model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing ON")

    if HAS_IPEX and not use_mezo:
        adamw = torch.optim.AdamW(model.parameters(), lr=args.lr)
        model, adamw = ipex.optimize(
            model, optimizer=adamw,
            dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1")
        print("[OPT] IPEX optimisation applied (level O1)")
    else:
        adamw = None

    if use_compile:
        print("[OPT] Compiling model with torch.compile (inductor)...")
        model = torch.compile(model, backend="inductor", mode="default", dynamic=True)

    dataset, tok = build_dataset(
        args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens,
        split=args.dataset_split, dataset_name=args.dataset_name,
        dataset_config=args.dataset_config, text_column=args.text_column,
        category_filter=args.category_filter,
        include_reasoning=args.include_reasoning,
    )
    loader = DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=args.num_workers, drop_last=True,
        persistent_workers=args.num_workers > 0,
        prefetch_factor=2 if args.num_workers > 0 else None,
    )

    if use_mezo:
        optimizer = MeZOOptimizer(
            model, lr=args.lr * 0.01, eps=1e-3,
            weight_decay=0.1, momentum=0.9, direction=args.mezo_direction,
        )
    else:
        no_decay = {"A_log", "dt_bias", "norm", "bias", "embed", "energy_weights"}
        decay_params, no_decay_params = [], []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            if any(tag in n for tag in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        if adamw is None:
            optimizer = torch.optim.AdamW(
                [{"params": decay_params, "weight_decay": 0.1},
                 {"params": no_decay_params, "weight_decay": 0.0}],
                lr=args.lr, betas=(0.9, 0.95))
        else:
            optimizer = adamw

    def compute_loss(batch) -> torch.Tensor:
        ids = batch["input_ids"][:, :-1]
        labels = batch["labels"][:, 1:]
        if use_bf16:
            with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
                out = model(ids, labels=labels)
        else:
            out = model(ids, labels=labels)
        return out.loss
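
    # Added note: this single closure backs both optimiser modes.  MeZO calls
    # it twice per step under ``torch.no_grad`` (loss at θ+εz and θ-εz), while
    # the AdamW branch calls it once per micro-batch and backpropagates through
    # it, so both paths accumulate a loss produced by the same code path.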
+
|
| 499 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 500 |
+
log_path = os.path.join(args.output_dir, "log.jsonl")
|
| 501 |
+
log_f = open(log_path, "w", encoding="utf-8")
|
| 502 |
+
|
| 503 |
+
model.train()
|
| 504 |
+
step = 0
|
| 505 |
+
cur_lr = args.lr
|
| 506 |
+
total_loss = 0.0
|
| 507 |
+
best_loss = float("inf")
|
| 508 |
+
toks = 0
|
| 509 |
+
t0 = time.time()
|
| 510 |
+
data_iter = iter(loader)
|
| 511 |
+
warmup = min(args.warmup, max(1, args.max_steps // 10))
|
| 512 |
+
|
| 513 |
+
if not use_mezo:
|
| 514 |
+
optimizer.zero_grad(set_to_none=True)
|
| 515 |
+
|
| 516 |
+
print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")
|
| 517 |
+
|
| 518 |
+
while step < args.max_steps:
|
| 519 |
+
try:
|
| 520 |
+
batch = next(data_iter)
|
| 521 |
+
except StopIteration:
|
| 522 |
+
data_iter = iter(loader)
|
| 523 |
+
batch = next(data_iter)
|
| 524 |
+
|
| 525 |
+
if use_mezo:
|
| 526 |
+
cur_lr = cosine_lr(step, warmup, args.max_steps,
|
| 527 |
+
args.lr * 0.01, args.lr * 0.001)
|
| 528 |
+
optimizer.lr = cur_lr
|
| 529 |
+
loss_val = optimizer.step(compute_loss, batch)
|
| 530 |
+
total_loss += loss_val
|
| 531 |
+
else:
|
| 532 |
+
loss = compute_loss(batch)
|
| 533 |
+
(loss / args.grad_accum).backward()
|
| 534 |
+
total_loss += float(loss.item())
|
| 535 |
+
if (step + 1) % args.grad_accum == 0:
|
| 536 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 537 |
+
cur_lr = cosine_lr(step, warmup, args.max_steps,
|
| 538 |
+
args.lr, args.lr * 0.1)
|
| 539 |
+
for pg in optimizer.param_groups:
|
| 540 |
+
pg["lr"] = cur_lr
|
| 541 |
+
optimizer.step()
|
| 542 |
+
optimizer.zero_grad(set_to_none=True)
|
| 543 |
+
|
| 544 |
+
toks += batch["input_ids"][:, :-1].numel()
|
| 545 |
+
step += 1
|
| 546 |
+
|
| 547 |
+
if step % args.log_every == 0:
|
| 548 |
+
dt = time.time() - t0
|
| 549 |
+
avg = total_loss / args.log_every
|
| 550 |
+
ppl = math.exp(min(avg, 20))
|
| 551 |
+
tps = toks / dt if dt > 0 else 0
|
| 552 |
+
eta_h = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0.0
|
| 553 |
+
log_f.write(json.dumps({
|
| 554 |
+
"step": step, "loss": round(avg, 4), "ppl": round(ppl, 2),
|
| 555 |
+
"lr": cur_lr, "tok/s": round(tps),
|
| 556 |
+
"optimizer": "mezo" if use_mezo else "adamw",
|
| 557 |
+
}) + "\n")
|
| 558 |
+
log_f.flush()
|
| 559 |
+
print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | "
|
| 560 |
+
f"ppl {ppl:>8.2f} | lr {cur_lr:.2e} | "
|
| 561 |
+
f"{tps:.0f} tok/s | ETA {eta_h:.1f}h")
|
| 562 |
+
best_loss = min(best_loss, avg)
|
| 563 |
+
total_loss = 0.0
|
| 564 |
+
toks = 0
|
| 565 |
+
t0 = time.time()
|
| 566 |
+
|
| 567 |
+
if step % args.save_every == 0:
|
| 568 |
+
ckpt_dir = os.path.join(args.output_dir, f"ckpt-{step}")
|
| 569 |
+
os.makedirs(ckpt_dir, exist_ok=True)
|
| 570 |
+
raw = getattr(model, "_orig_mod", model)
|
| 571 |
+
torch.save({
|
| 572 |
+
"model": raw.state_dict(), "config": config,
|
| 573 |
+
"step": step, "optimizer": args.optimizer,
|
| 574 |
+
}, os.path.join(ckpt_dir, "ckpt.pt"))
|
| 575 |
+
print(f" [SAVE] {ckpt_dir}")
|
| 576 |
+
|
| 577 |
+
final_dir = os.path.join(args.output_dir, "final")
|
| 578 |
+
os.makedirs(final_dir, exist_ok=True)
|
| 579 |
+
raw = getattr(model, "_orig_mod", model)
|
| 580 |
+
torch.save({
|
| 581 |
+
"model": raw.state_dict(), "config": config,
|
| 582 |
+
"step": step, "best_loss": best_loss,
|
| 583 |
+
}, os.path.join(final_dir, "model.pt"))
|
| 584 |
+
with open(os.path.join(final_dir, "config.json"), "w", encoding="utf-8") as fh:
|
| 585 |
+
json.dump(config, fh, indent=2)
|
| 586 |
+
log_f.close()
|
| 587 |
+
|
| 588 |
+
print(f"\n{'=' * 60}")
|
| 589 |
+
print(f"DONE — best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
|
| 590 |
+
print(f"Saved to {final_dir}")
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
# ---------------------------------------------------------------------------
|
| 594 |
+
# CLI
|
| 595 |
+
# ---------------------------------------------------------------------------
|
| 596 |
+
|
| 597 |
+
def _build_argparser() -> argparse.ArgumentParser:
|
| 598 |
+
p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
|
| 599 |
+
p.add_argument("--config", default="config.json")
|
| 600 |
+
p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
|
| 601 |
+
p.add_argument("--seq_len", type=int, default=256)
|
| 602 |
+
p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
|
| 603 |
+
p.add_argument("--batch_size", type=int, default=2)
|
| 604 |
+
p.add_argument("--grad_accum", type=int, default=8)
|
| 605 |
+
p.add_argument("--lr", type=float, default=1e-3)
|
| 606 |
+
p.add_argument("--warmup", type=int, default=200)
|
| 607 |
+
p.add_argument("--max_steps", type=int, default=5000)
|
| 608 |
+
p.add_argument("--max_samples", type=int, default=None)
|
| 609 |
+
p.add_argument("--max_tokens", type=int, default=None)
|
| 610 |
+
p.add_argument("--bf16", action="store_true", default=True)
|
| 611 |
+
p.add_argument("--no-bf16", dest="bf16", action="store_false")
|
| 612 |
+
p.add_argument("--compile", action="store_true", default=False)
|
| 613 |
+
p.add_argument("--grad_checkpoint", action="store_true", default=True)
|
| 614 |
+
p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
|
| 615 |
+
p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"],
|
| 616 |
+
default="rademacher")
|
| 617 |
+
p.add_argument("--dataset_name", default="roneneldan/TinyStories")
|
| 618 |
+
p.add_argument("--dataset_config", default=None)
|
| 619 |
+
p.add_argument("--dataset_split", default="train")
|
| 620 |
+
p.add_argument("--text_column", default="auto")
|
| 621 |
+
p.add_argument("--category_filter", default=None)
|
| 622 |
+
p.add_argument("--include_reasoning", action="store_true", default=False)
|
| 623 |
+
p.add_argument("--num_workers", type=int, default=2)
|
| 624 |
+
p.add_argument("--log_every", type=int, default=10)
|
| 625 |
+
p.add_argument("--save_every", type=int, default=1000)
|
| 626 |
+
p.add_argument("--output_dir", default="./chimera_output")
|
| 627 |
+
return p
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
if __name__ == "__main__":
|
| 631 |
+
args = _build_argparser().parse_args()
|
| 632 |
+
train(args)
|