Upload folder using huggingface_hub
Browse files- .gitignore +20 -0
- README.md +226 -0
- chimera/__init__.py +53 -0
- chimera/__main__.py +31 -0
- chimera/cli.py +62 -0
- chimera/config.py +67 -0
- chimera/evolution.py +594 -0
- chimera/hyper.py +394 -0
- chimera/inference.py +359 -0
- chimera/layers.py +485 -0
- chimera/looping.py +73 -0
- chimera/model.py +438 -0
- chimera/moe.py +102 -0
- chimera/multimodal.py +136 -0
- chimera/paths.py +15 -0
- chimera/quantization.py +508 -0
- chimera/tokenizer.py +160 -0
- chimera/training/__init__.py +57 -0
- chimera/training/benchmark.py +171 -0
- chimera/training/common.py +119 -0
- chimera/training/datasets.py +205 -0
- chimera/training/hyper.py +128 -0
- chimera/training/loops.py +224 -0
- chimera/training/optimizers.py +113 -0
- chimera_turbo.py +549 -0
- config.json +716 -0
- gguf_import.py +907 -0
- inference.py +309 -0
- launch_turbo.sh +48 -0
- pyproject.toml +28 -0
- tests/test_chimera.py +115 -0
- tests/test_config.py +8 -0
- train.py +239 -0
- train_fast.py +140 -0
- train_hyper.py +192 -0
.gitignore
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
.pytest_cache/
|
| 4 |
+
.venv/
|
| 5 |
+
.deps/
|
| 6 |
+
.mypy_cache/
|
| 7 |
+
.ruff_cache/
|
| 8 |
+
.coverage
|
| 9 |
+
build/
|
| 10 |
+
dist/
|
| 11 |
+
*.egg-info/
|
| 12 |
+
cache/
|
| 13 |
+
chimera_output/
|
| 14 |
+
chimera_hyper_output/
|
| 15 |
+
chimera_imported/
|
| 16 |
+
*.pt
|
| 17 |
+
*.gguf
|
| 18 |
+
.ternary_build*
|
| 19 |
+
.kernel_build
|
| 20 |
+
.simd_build
|
README.md
ADDED
|
@@ -0,0 +1,226 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Chimera 5.3 β HYPER CPU Training (10,000+ tok/s target)
|
| 2 |
+
|
| 3 |
+
100% faithful implementation of the Chimera 5.x config. All 15 architectural components implemented in pure PyTorch, with **true 1.58-bit ternary computation** on CPU.
|
| 4 |
+
|
| 5 |
+
**v5.3 NEW**: 7 stacked training paradigms designed to push CPU training from ~50-200 tok/s to **10,000+ tok/s** on a single CPU β targeting AGI-class LLM training without GPUs.
|
| 6 |
+
|
| 7 |
+
**Tokenizer**: splintr-rs (Rust) β o200k_base vocab (200,073 tokens, OpenAI o1/o3).
|
| 8 |
+
|
| 9 |
+
## Repo Structure
|
| 10 |
+
|
| 11 |
+
The repo is now organized around the `chimera/` package as the source of truth:
|
| 12 |
+
|
| 13 |
+
- `chimera/` β model code, config helpers, package CLI wrappers, shared path helpers
|
| 14 |
+
- `train.py` β standard training entrypoint
|
| 15 |
+
- `train_fast.py` β cached-dataset training entrypoint
|
| 16 |
+
- `train_hyper.py` β hyper training entrypoint
|
| 17 |
+
- `inference.py` β generation entrypoint
|
| 18 |
+
- `gguf_import.py` β GGUF import entrypoint
|
| 19 |
+
- `tests/` β smoke and config tests
|
| 20 |
+
|
| 21 |
+
You can still run the root scripts directly, or use packaged commands after install:
|
| 22 |
+
|
| 23 |
+
```bash
|
| 24 |
+
chimera-train --help
|
| 25 |
+
chimera-train-fast --help
|
| 26 |
+
chimera-train-hyper --help
|
| 27 |
+
chimera-infer --help
|
| 28 |
+
chimera-import-gguf --help
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## v5.3 β HYPER Training Paradigms
|
| 34 |
+
|
| 35 |
+
Seven orthogonal paradigms that stack **multiplicatively** for extreme CPU training speed:
|
| 36 |
+
|
| 37 |
+
| # | Paradigm | Speedup | Paper | Mechanism |
|
| 38 |
+
|---|----------|---------|-------|-----------|
|
| 39 |
+
| P1 | **GrowLength Curriculum** | 4-8Γ | [arxiv:2310.00576](https://arxiv.org/abs/2310.00576) | Start seq=16, grow to target. Short seqs β huge batch β way more tok/s |
|
| 40 |
+
| P2 | **Reservoir Freezing** | 1.5-2Γ | [arxiv:2512.23145](https://arxiv.org/abs/2512.23145) | Freeze 50% of recurrent gates as random ternary. No grad = fewer FLOPs |
|
| 41 |
+
| P3 | **Sparse MeZO** | 3-5Γ | [arxiv:2406.02913](https://arxiv.org/abs/2406.02913) | Perturb only top-1% sensitive params. ZO signal quality β sparsity |
|
| 42 |
+
| P4 | **Blockwise Pipeline** | 1.3-2Γ | β | Pin layer-groups to core-groups; overlap forward passes |
|
| 43 |
+
| P5 | **Fused Ternary Cache** | 1.3Γ | β | Pre-materialise dense weights once; reuse for both MeZO forwards |
|
| 44 |
+
| P6 | **Aggressive Token Packing** | 1.1-1.3Γ | β | Zero padding waste; documents packed back-to-back with EOS |
|
| 45 |
+
| P7 | **Progressive Layer Unfreeze** | 1.5-2Γ | β | Train only top 25% of layers first; unfreeze downward |
|
| 46 |
+
|
| 47 |
+
**Combined theoretical multiplier**: P1(6Γ) Γ P2(1.7Γ) Γ P3(4Γ) Γ P5(1.3Γ) Γ P7(1.7Γ) β **57-260Γ**
|
| 48 |
+
|
| 49 |
+
**Realistic target**: 50-200 tok/s baseline β **3,000-15,000+ tok/s**
|
| 50 |
+
|
| 51 |
+
### Quick Start β HYPER Training
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
# All 7 paradigms ON β maximum speed
|
| 55 |
+
python train_hyper.py --scale tiny --max_steps 5000 --all
|
| 56 |
+
|
| 57 |
+
# Cherry-pick specific paradigms
|
| 58 |
+
python train_hyper.py --scale tiny --max_steps 5000 \
|
| 59 |
+
--growlength --sparse-mezo --reservoir --fused-cache
|
| 60 |
+
|
| 61 |
+
# Benchmark: baseline vs hyper (side-by-side comparison)
|
| 62 |
+
python train_hyper.py --scale tiny --max_steps 100 --benchmark
|
| 63 |
+
|
| 64 |
+
# Full training run with all paradigms
|
| 65 |
+
OMP_NUM_THREADS=$(nproc) python train_hyper.py \
|
| 66 |
+
--scale small --seq_len 256 --max_steps 50000 \
|
| 67 |
+
--all --bf16 --compile \
|
| 68 |
+
--save_every 5000 --log_every 10
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
### Paradigm Details
|
| 72 |
+
|
| 73 |
+
#### P1 β GrowLength Curriculum ([arxiv:2310.00576](https://arxiv.org/abs/2310.00576))
|
| 74 |
+
|
| 75 |
+
Trains with progressively longer sequences. At seq_len=16, you can fit 16Γ more tokens per batch than at seq_len=256, giving massive throughput in early training where the learning signal is strongest.
|
| 76 |
+
|
| 77 |
+
Default schedule:
|
| 78 |
+
- 20% of training at seq_len = target/8
|
| 79 |
+
- 25% at target/4
|
| 80 |
+
- 25% at target/2
|
| 81 |
+
- 30% at full target
|
| 82 |
+
|
| 83 |
+
```bash
|
| 84 |
+
python train_hyper.py --growlength --seq_len 256
|
| 85 |
+
```
|
| 86 |
+
|
| 87 |
+
#### P2 β Reservoir Freezing ([arxiv:2512.23145](https://arxiv.org/abs/2512.23145))
|
| 88 |
+
|
| 89 |
+
Inspired by GRC (Reservoir Computing for Language Models): freezes gate/forget projections in recurrent layers as random ternary matrices with unit spectral radius. These "reservoir" weights provide stable dynamics without needing gradient updates.
|
| 90 |
+
|
| 91 |
+
Targets:
|
| 92 |
+
- GatedDeltaNet: `a_proj`, `b_proj` (alpha/beta gates)
|
| 93 |
+
- mLSTM: `fgate` (forget gate)
|
| 94 |
+
- TitansMAC: `alpha_proj` (forgetting gate)
|
| 95 |
+
|
| 96 |
+
```bash
|
| 97 |
+
python train_hyper.py --reservoir --reservoir-ratio 0.5
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
#### P3 β Sparse MeZO ([arxiv:2406.02913](https://arxiv.org/abs/2406.02913))
|
| 101 |
+
|
| 102 |
+
Standard MeZO perturbs all ~35M parameters β most contribute near-zero gradient signal. Sparse MeZO identifies the top-K% most sensitive parameters (by weight magnitude) and perturbs only those. This dramatically reduces the variance of the ZO gradient estimate.
|
| 103 |
+
|
| 104 |
+
At 1% sparsity on a 35M model: only 350K params perturbed per step β **100Γ better signal-to-noise per forward pass**.
|
| 105 |
+
|
| 106 |
+
```bash
|
| 107 |
+
python train_hyper.py --sparse-mezo --mezo-sparsity 0.01
|
| 108 |
+
```
|
| 109 |
+
|
| 110 |
+
#### P5 β Fused Ternary Cache
|
| 111 |
+
|
| 112 |
+
Before each MeZO dual-forward, pre-materialises all BitLinear packed+dense weight caches. Both forward passes then reuse the same buffers β eliminates redundant quantizeβpackβunpack cycles.
|
| 113 |
+
|
| 114 |
+
```bash
|
| 115 |
+
python train_hyper.py --fused-cache
|
| 116 |
+
```
|
| 117 |
+
|
| 118 |
+
#### P7 β Progressive Layer Unfreezing
|
| 119 |
+
|
| 120 |
+
Starts with only the top ~25% of layers trainable. Early training is cheap (forward through frozen layers is fast, no gradient storage). Gradually unfreezes deeper layers as training progresses.
|
| 121 |
+
|
| 122 |
+
```bash
|
| 123 |
+
python train_hyper.py --progressive-unfreeze --unfreeze-stages 4
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
---
|
| 127 |
+
|
| 128 |
+
## Files
|
| 129 |
+
|
| 130 |
+
```
|
| 131 |
+
chimera/
|
| 132 |
+
__init__.py β Package exports (v5.3)
|
| 133 |
+
config.py β Config loading / scaling
|
| 134 |
+
hyper.py β β
NEW: 7 HYPER paradigm engine
|
| 135 |
+
quantization.py β BitLinear (2-bit packed, C++ kernel, STE, N:M 2:4)
|
| 136 |
+
layers.py β GatedDeltaNet, mLSTM, TitansMAC, TSPSpanKnot
|
| 137 |
+
moe.py β MoELayer (sort-based dispatch)
|
| 138 |
+
looping.py β ParcaeLoopController
|
| 139 |
+
inference.py β SpanBank, STree, Grammar, EntropyValve, DebtLedger
|
| 140 |
+
evolution.py β TTT, SemanticMemory, EpisodicCases, MetaGuidelines
|
| 141 |
+
multimodal.py β VisionEncoder, AudioEncoder
|
| 142 |
+
tokenizer.py β ChimeraTokenizer (splintr, o200k_base)
|
| 143 |
+
model.py β Chimera51ForCausalLM
|
| 144 |
+
config.json β Full model config
|
| 145 |
+
train.py β Standard training (MeZO + AdamW)
|
| 146 |
+
train_fast.py β Fast training with pre-tokenized cache
|
| 147 |
+
train_hyper.py β β
NEW: HYPER training (7 paradigms, 10k+ tok/s)
|
| 148 |
+
inference.py β Inference / generation
|
| 149 |
+
```
|
| 150 |
+
|
| 151 |
+
---
|
| 152 |
+
|
| 153 |
+
## Previous Versions
|
| 154 |
+
|
| 155 |
+
### v5.1.4 β CPU Fast Path Audit
|
| 156 |
+
- Fixed package/runtime mismatch
|
| 157 |
+
- Added sparse MoELayer with expert-grouped dispatch
|
| 158 |
+
- Made C++ ternary extensions lazy-loaded
|
| 159 |
+
- Vectorized BitLinear AbsMean scaling
|
| 160 |
+
- Cached causal/triangular masks
|
| 161 |
+
- Reduced GatedDeltaNet clone churn
|
| 162 |
+
|
| 163 |
+
### v5.1.3 β Fix Illegal Instruction Crash
|
| 164 |
+
- Removed `-march=native` from C++ JIT flags
|
| 165 |
+
- Runtime CPUID detection for AVX-512/AVX2
|
| 166 |
+
|
| 167 |
+
### v5.1.2 β True Ternary Compute
|
| 168 |
+
- 2-bit packed uint8 weight storage (16Γ compression)
|
| 169 |
+
- C++ unpack + MKL BLAS forward path
|
| 170 |
+
- MeZO sparse perturbation (skip ~33% zeros)
|
| 171 |
+
- STE backward with deep-zero masking
|
| 172 |
+
|
| 173 |
+
---
|
| 174 |
+
|
| 175 |
+
## Architecture (28 layers, 4 types)
|
| 176 |
+
|
| 177 |
+
```
|
| 178 |
+
Layer pattern: GD XM GD TM GD XM GD SK Γ 3.5
|
| 179 |
+
GD = Gated DeltaNet (14 layers) β arxiv:2412.06464
|
| 180 |
+
XM = xLSTM mLSTM (7 layers) β arxiv:2405.04517
|
| 181 |
+
TM = Titans MAC (4 layers) β arxiv:2501.00663
|
| 182 |
+
SK = TSP Span Knot (3 layers)
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
All linear layers use **BitLinear** (ternary 1.58-bit) with per-group AbsMean scaling.
|
| 186 |
+
|
| 187 |
+
---
|
| 188 |
+
|
| 189 |
+
## Training Modes
|
| 190 |
+
|
| 191 |
+
### HYPER (v5.3 β Recommended)
|
| 192 |
+
- **7 stacked paradigms** for maximum CPU throughput
|
| 193 |
+
- Target: **10,000+ tok/s** on 8-core CPU (tiny scale)
|
| 194 |
+
- Forward-only training (Sparse MeZO): no backward pass
|
| 195 |
+
- Memory = 2Γ model size (no activations, no gradients, no optimizer states)
|
| 196 |
+
- Each paradigm independently toggleable via CLI flags
|
| 197 |
+
|
| 198 |
+
### MeZO (v5.1 β Standard)
|
| 199 |
+
- Standard zeroth-order optimization
|
| 200 |
+
- 2 forward passes per step, no backward
|
| 201 |
+
- Good for fine-tuning; ~50-200 tok/s on CPU
|
| 202 |
+
|
| 203 |
+
### AdamW (v5.1 β Full backprop)
|
| 204 |
+
- Standard gradient descent with checkpointing
|
| 205 |
+
- Best convergence quality for pretraining from scratch
|
| 206 |
+
- ~10-50 tok/s on CPU
|
| 207 |
+
|
| 208 |
+
---
|
| 209 |
+
|
| 210 |
+
## References
|
| 211 |
+
|
| 212 |
+
37 papers indexed in `config.json` under `Β§`. Key additions for v5.3:
|
| 213 |
+
- [GrowLength](https://arxiv.org/abs/2310.00576) β Progressive sequence length training
|
| 214 |
+
- [GRC MatMul-free LM](https://arxiv.org/abs/2512.23145) β Reservoir computing for LMs
|
| 215 |
+
- [Sparse MeZO](https://arxiv.org/abs/2406.02913) β Sparse zeroth-order fine-tuning
|
| 216 |
+
- [GaLore](https://arxiv.org/abs/2403.03507) β Gradient low-rank projection
|
| 217 |
+
- [QuZO](https://arxiv.org/abs/2502.12346) β Quantized zeroth-order training
|
| 218 |
+
- [SparAMX](https://arxiv.org/abs/2502.12444) β AMX-accelerated sparse CPU kernels
|
| 219 |
+
|
| 220 |
+
Plus all previous references:
|
| 221 |
+
- [Gated DeltaNet](https://arxiv.org/abs/2412.06464) β NVIDIA
|
| 222 |
+
- [xLSTM](https://arxiv.org/abs/2405.04517) β NXAI/JKU
|
| 223 |
+
- [Titans](https://arxiv.org/abs/2501.00663) β Google
|
| 224 |
+
- [Parcae](https://arxiv.org/abs/2604.12946) β Stanford/Together
|
| 225 |
+
- [BitNet b1.58](https://arxiv.org/abs/2402.17764) β Microsoft
|
| 226 |
+
- [MeZO](https://arxiv.org/abs/2305.17333) β Princeton
|
chimera/__init__.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Chimera 5.3 β CPU-first causal LM with ternary 1.58-bit weights."""
|
| 2 |
+
|
| 3 |
+
from .config import load_config, scale_config, tiny_config
|
| 4 |
+
from .paths import DEFAULT_CONFIG_PATH, PACKAGE_ROOT, REPO_ROOT, resolve_repo_path
|
| 5 |
+
|
| 6 |
+
__version__ = "5.3.0"
|
| 7 |
+
|
| 8 |
+
__all__ = [
|
| 9 |
+
"load_config", "scale_config", "tiny_config",
|
| 10 |
+
"DEFAULT_CONFIG_PATH", "PACKAGE_ROOT", "REPO_ROOT", "resolve_repo_path",
|
| 11 |
+
"Chimera51ForCausalLM", "Chimera51Block", "expand_layer_pattern",
|
| 12 |
+
"BitLinear", "RMSNorm", "pack_ternary", "unpack_ternary",
|
| 13 |
+
"ternarize_weight", "_quantize_weights_ternary", "apply_2_4_sparsity_",
|
| 14 |
+
"enable_native_kernel", "native_kernel_available",
|
| 15 |
+
"ChimeraTokenizer",
|
| 16 |
+
"SelfEvolutionEngine", "SemanticMemory", "InPlaceTTT",
|
| 17 |
+
"EpisodicCaseMemory", "MetaGuidelineBank", "SelfFeedback",
|
| 18 |
+
"LoopDepthClassifier",
|
| 19 |
+
# v5.3 β Hyper paradigms
|
| 20 |
+
"GrowLengthDataset", "GrowLengthScheduler",
|
| 21 |
+
"apply_reservoir_freezing", "SparseMeZOOptimizer",
|
| 22 |
+
"precompute_ternary_cache", "pack_documents",
|
| 23 |
+
"ProgressiveUnfreezer", "cosine_lr",
|
| 24 |
+
]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# Lazy public surface β keeps ``import chimera`` cheap (no torch import until
|
| 28 |
+
# the user actually touches a model class).
|
| 29 |
+
def __getattr__(name):
|
| 30 |
+
if name in {"Chimera51ForCausalLM", "Chimera51Block", "expand_layer_pattern"}:
|
| 31 |
+
from .model import Chimera51ForCausalLM, Chimera51Block, expand_layer_pattern
|
| 32 |
+
return locals()[name]
|
| 33 |
+
if name in {"BitLinear", "RMSNorm", "pack_ternary", "unpack_ternary",
|
| 34 |
+
"ternarize_weight", "_quantize_weights_ternary",
|
| 35 |
+
"apply_2_4_sparsity_", "enable_native_kernel",
|
| 36 |
+
"native_kernel_available"}:
|
| 37 |
+
from . import quantization as _q
|
| 38 |
+
return getattr(_q, name)
|
| 39 |
+
if name == "ChimeraTokenizer":
|
| 40 |
+
from .tokenizer import ChimeraTokenizer
|
| 41 |
+
return ChimeraTokenizer
|
| 42 |
+
if name in {"SelfEvolutionEngine", "SemanticMemory", "InPlaceTTT",
|
| 43 |
+
"EpisodicCaseMemory", "MetaGuidelineBank", "SelfFeedback",
|
| 44 |
+
"LoopDepthClassifier"}:
|
| 45 |
+
from . import evolution as _evo
|
| 46 |
+
return getattr(_evo, name)
|
| 47 |
+
if name in {"GrowLengthDataset", "GrowLengthScheduler",
|
| 48 |
+
"apply_reservoir_freezing", "SparseMeZOOptimizer",
|
| 49 |
+
"precompute_ternary_cache", "pack_documents",
|
| 50 |
+
"ProgressiveUnfreezer", "cosine_lr"}:
|
| 51 |
+
from . import hyper as _hyp
|
| 52 |
+
return getattr(_hyp, name)
|
| 53 |
+
raise AttributeError(name)
|
chimera/__main__.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
from . import __version__
|
| 6 |
+
from .cli import infer_main, train_fast_main, train_hyper_main, train_main
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def main() -> None:
|
| 10 |
+
parser = argparse.ArgumentParser(prog="python -m chimera")
|
| 11 |
+
parser.add_argument("--version", action="version", version=f"%(prog)s {__version__}")
|
| 12 |
+
subparsers = parser.add_subparsers(dest="command")
|
| 13 |
+
subparsers.add_parser("train")
|
| 14 |
+
subparsers.add_parser("train-fast")
|
| 15 |
+
subparsers.add_parser("train-hyper")
|
| 16 |
+
subparsers.add_parser("infer")
|
| 17 |
+
|
| 18 |
+
args, _ = parser.parse_known_args()
|
| 19 |
+
if args.command == "train":
|
| 20 |
+
train_main()
|
| 21 |
+
return
|
| 22 |
+
if args.command == "train-fast":
|
| 23 |
+
train_fast_main()
|
| 24 |
+
return
|
| 25 |
+
if args.command == "train-hyper":
|
| 26 |
+
train_hyper_main()
|
| 27 |
+
return
|
| 28 |
+
if args.command == "infer":
|
| 29 |
+
infer_main()
|
| 30 |
+
return
|
| 31 |
+
parser.print_help()
|
chimera/cli.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def train_main() -> None:
|
| 7 |
+
from train import _build_argparser, train
|
| 8 |
+
|
| 9 |
+
args = _build_argparser().parse_args()
|
| 10 |
+
train(args)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def train_fast_main() -> None:
|
| 14 |
+
from train_fast import train
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training")
|
| 17 |
+
parser.add_argument("--config", default="config.json")
|
| 18 |
+
parser.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
|
| 19 |
+
parser.add_argument("--seq_len", type=int, default=32)
|
| 20 |
+
parser.add_argument("--batch_size", type=int, default=4)
|
| 21 |
+
parser.add_argument("--lr", type=float, default=1e-3)
|
| 22 |
+
parser.add_argument("--warmup", type=int, default=100)
|
| 23 |
+
parser.add_argument("--max_steps", type=int, default=1000)
|
| 24 |
+
parser.add_argument("--max_samples", type=int, default=5000)
|
| 25 |
+
parser.add_argument("--bf16", action="store_true", default=False)
|
| 26 |
+
parser.add_argument("--compile", action="store_true", default=False)
|
| 27 |
+
parser.add_argument("--cache_dir", default="./cache")
|
| 28 |
+
parser.add_argument("--log_every", type=int, default=10)
|
| 29 |
+
parser.add_argument("--save_every", type=int, default=500)
|
| 30 |
+
parser.add_argument("--output_dir", default="./chimera_output")
|
| 31 |
+
train(parser.parse_args())
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def train_hyper_main() -> None:
|
| 35 |
+
from train_hyper import benchmark, cli, train_hyper
|
| 36 |
+
|
| 37 |
+
args = cli().parse_args()
|
| 38 |
+
if args.max_samples and not args.max_tokens:
|
| 39 |
+
args.max_tokens = args.max_samples * (args.seq_len + 1)
|
| 40 |
+
if args.all:
|
| 41 |
+
args.growlength = True
|
| 42 |
+
args.reservoir = True
|
| 43 |
+
args.progressive_unfreeze = True
|
| 44 |
+
if args.benchmark:
|
| 45 |
+
args.growlength = True
|
| 46 |
+
args.reservoir = True
|
| 47 |
+
args.progressive_unfreeze = True
|
| 48 |
+
benchmark(args)
|
| 49 |
+
return
|
| 50 |
+
train_hyper(args)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def infer_main() -> None:
|
| 54 |
+
from inference import main
|
| 55 |
+
|
| 56 |
+
main()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def import_gguf_main() -> None:
|
| 60 |
+
from gguf_import import main
|
| 61 |
+
|
| 62 |
+
main()
|
chimera/config.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, Mapping
|
| 7 |
+
|
| 8 |
+
from .paths import DEFAULT_CONFIG_PATH
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def load_config(path: str | Path | None = None, overrides: Mapping[str, Any] | None = None) -> dict:
|
| 12 |
+
"""Load a Chimera JSON config and apply shallow dotted-key overrides."""
|
| 13 |
+
if path is None:
|
| 14 |
+
path = DEFAULT_CONFIG_PATH
|
| 15 |
+
with open(path, "r", encoding="utf-8") as fh:
|
| 16 |
+
cfg = json.load(fh)
|
| 17 |
+
if overrides:
|
| 18 |
+
cfg = copy.deepcopy(cfg)
|
| 19 |
+
for key, value in overrides.items():
|
| 20 |
+
cur = cfg
|
| 21 |
+
parts = str(key).split(".")
|
| 22 |
+
for part in parts[:-1]:
|
| 23 |
+
cur = cur.setdefault(part, {})
|
| 24 |
+
cur[parts[-1]] = value
|
| 25 |
+
return cfg
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def scale_config(config: dict, scale: str = "base") -> dict:
|
| 29 |
+
"""Return a safe CPU-scaled copy while preserving feature flags.
|
| 30 |
+
|
| 31 |
+
The uploaded Chimera config targets a large model. These presets keep all
|
| 32 |
+
modules wired but resize dimensions so tests/fine-tuning fit commodity CPU
|
| 33 |
+
memory (including 16 GB DDR5 machines).
|
| 34 |
+
"""
|
| 35 |
+
cfg = copy.deepcopy(config)
|
| 36 |
+
presets = {
|
| 37 |
+
"nano": dict(hidden_size=128, intermediate_size=344, num_hidden_layers=4, num_heads=4, head_dim=32, vocab_size=min(cfg.get("vocab_size", 32000), 8192)),
|
| 38 |
+
"tiny": dict(hidden_size=256, intermediate_size=688, num_hidden_layers=6, num_heads=4, head_dim=64, vocab_size=min(cfg.get("vocab_size", 32000), 32768)),
|
| 39 |
+
"small": dict(hidden_size=512, intermediate_size=1376, num_hidden_layers=8, num_heads=8, head_dim=64, vocab_size=min(cfg.get("vocab_size", 32000), 65536)),
|
| 40 |
+
"base": {},
|
| 41 |
+
}
|
| 42 |
+
if scale not in presets:
|
| 43 |
+
raise ValueError(f"unknown scale {scale!r}; choose {sorted(presets)}")
|
| 44 |
+
cfg.update(presets[scale])
|
| 45 |
+
h = cfg["hidden_size"]
|
| 46 |
+
cfg["num_heads"] = max(1, min(cfg.get("num_heads", 4), h // max(1, cfg.get("head_dim", 64))))
|
| 47 |
+
cfg["head_dim"] = h // cfg["num_heads"]
|
| 48 |
+
cfg.setdefault("backbone", {}).setdefault("moe", {})
|
| 49 |
+
moe = cfg["backbone"]["moe"]
|
| 50 |
+
moe["layers"] = [i for i in moe.get("layers", []) if i < cfg["num_hidden_layers"]]
|
| 51 |
+
moe["n_routed_experts"] = min(int(moe.get("n_routed_experts", 4)), 4 if scale in {"nano", "tiny"} else 8)
|
| 52 |
+
moe["n_shared_experts"] = min(int(moe.get("n_shared_experts", 1)), 1)
|
| 53 |
+
moe["num_experts_per_tok"] = min(int(moe.get("num_experts_per_tok", 2)), moe["n_routed_experts"])
|
| 54 |
+
moe["moe_intermediate_size"] = min(int(moe.get("moe_intermediate_size", h * 2)), max(64, cfg["intermediate_size"] // 2))
|
| 55 |
+
loop = cfg.setdefault("looping", {})
|
| 56 |
+
if cfg["num_hidden_layers"] < 8:
|
| 57 |
+
loop["enabled"] = False
|
| 58 |
+
else:
|
| 59 |
+
loop["prelude"] = [0, min(1, cfg["num_hidden_layers"] - 1)]
|
| 60 |
+
loop["loop"] = [2, max(2, cfg["num_hidden_layers"] - 3)]
|
| 61 |
+
loop["coda"] = [max(0, cfg["num_hidden_layers"] - 2), cfg["num_hidden_layers"] - 1]
|
| 62 |
+
cfg.setdefault("span_inference", {})["enabled"] = bool(cfg.get("span_inference", {}).get("enabled", True))
|
| 63 |
+
return cfg
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def tiny_config() -> dict:
|
| 67 |
+
return scale_config(load_config(), "nano")
|
chimera/evolution.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β Functional Self-Evolution Engine (CPU-first, optimized).
|
| 3 |
+
|
| 4 |
+
All components are now WIRED into the training/inference loop:
|
| 5 |
+
* InPlaceTTT: applied to target MLP layers during forward pass
|
| 6 |
+
* SemanticMemory: reads at every layer, writes on surprise threshold
|
| 7 |
+
* EpisodicCaseMemory: retrieves similar past cases, stores on outcome
|
| 8 |
+
* MetaGuidelineBank: stores contrastive-eval-failed guidelines
|
| 9 |
+
* SelfFeedback: triggers refinement when confidence < threshold
|
| 10 |
+
* LoopDepthClassifier: predicts optimal loop depth from hidden state
|
| 11 |
+
|
| 12 |
+
Optimizations:
|
| 13 |
+
* Vectorised bit ops (no Python loops)
|
| 14 |
+
* Lazy sparse updates (only top-K% weights touched per step)
|
| 15 |
+
* Gradient-free memory operations (no backward through HDC)
|
| 16 |
+
* Caching of semantic queries across steps
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
from typing import Optional, Tuple, List, Dict
|
| 22 |
+
import math
|
| 23 |
+
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn as nn
|
| 26 |
+
import torch.nn.functional as F
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
_BIT_SHIFTS = torch.arange(8, dtype=torch.uint8)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _unpack_bits(x: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
"""Unpack uint8 ``[..., D]`` into ``[..., D, 8]`` of {0,1} fp32."""
|
| 34 |
+
shifts = _BIT_SHIFTS.to(x.device)
|
| 35 |
+
return ((x.unsqueeze(-1) >> shifts) & 1).to(torch.float32)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _pack_bits(b: torch.Tensor) -> torch.Tensor:
|
| 39 |
+
"""Inverse of :func:`_unpack_bits`."""
|
| 40 |
+
shifts = _BIT_SHIFTS.to(b.device).to(torch.uint8)
|
| 41 |
+
return (b.to(torch.uint8) << shifts).sum(dim=-1).to(torch.uint8)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
# SemanticMemory (HDC) β Hyperdimensional Computing
|
| 46 |
+
# ---------------------------------------------------------------------------
|
| 47 |
+
|
| 48 |
+
class SemanticMemory(nn.Module):
|
| 49 |
+
"""Binary hypervector memory with O(1) similarity via Hamming distance."""
|
| 50 |
+
|
| 51 |
+
def __init__(self, config: dict):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.enabled = bool(config.get("enabled", True))
|
| 54 |
+
self.vector_bits = int(config.get("vector_bits", 8192))
|
| 55 |
+
self.capacity = int(config.get("capacity", 200_000))
|
| 56 |
+
self.pool_fixed = bool(config.get("pool_size_fixed", True))
|
| 57 |
+
self.lsh_tables = int(config.get("lsh_tables", 64))
|
| 58 |
+
self.lsh_bits = int(config.get("lsh_bits_per_table", 14))
|
| 59 |
+
self.write_threshold = float(config.get("write_surprise_threshold", 2.0))
|
| 60 |
+
|
| 61 |
+
actual_cap = max(1, min(self.capacity, 50_000))
|
| 62 |
+
n_bytes = self.vector_bits // 8
|
| 63 |
+
self.register_buffer("memory", torch.zeros(actual_cap, n_bytes, dtype=torch.uint8))
|
| 64 |
+
self.register_buffer("count", torch.zeros((), dtype=torch.long))
|
| 65 |
+
self.register_buffer("access_counts", torch.zeros(actual_cap, dtype=torch.long))
|
| 66 |
+
|
| 67 |
+
# LSH for sublinear retrieval
|
| 68 |
+
self.lsh_proj = nn.Linear(n_bytes, self.lsh_tables * self.lsh_bits, bias=False)
|
| 69 |
+
nn.init.normal_(self.lsh_proj.weight, std=0.01)
|
| 70 |
+
|
| 71 |
+
# Query cache for repeated lookups
|
| 72 |
+
self._query_cache: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 73 |
+
|
| 74 |
+
@staticmethod
|
| 75 |
+
def xor_bind(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
return torch.bitwise_xor(a, b)
|
| 77 |
+
|
| 78 |
+
@staticmethod
|
| 79 |
+
def xor_unbind(bound: torch.Tensor, key: torch.Tensor) -> torch.Tensor:
|
| 80 |
+
return torch.bitwise_xor(bound, key)
|
| 81 |
+
|
| 82 |
+
@staticmethod
|
| 83 |
+
def majority_bundle(hvs: torch.Tensor) -> torch.Tensor:
|
| 84 |
+
"""Vectorised majority rule over batch of hypervectors."""
|
| 85 |
+
if hvs.numel() == 0:
|
| 86 |
+
return torch.zeros(hvs.shape[-1] if hvs.ndim else 0, dtype=torch.uint8,
|
| 87 |
+
device=hvs.device)
|
| 88 |
+
bits = _unpack_bits(hvs)
|
| 89 |
+
majority = (bits.sum(dim=0) > (hvs.size(0) / 2.0)).to(torch.uint8)
|
| 90 |
+
return _pack_bits(majority)
|
| 91 |
+
|
| 92 |
+
@staticmethod
|
| 93 |
+
def hamming_distance(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
|
| 94 |
+
"""Batched Hamming distance over uint8 byte tensors."""
|
| 95 |
+
xor = torch.bitwise_xor(a, b)
|
| 96 |
+
bits = _unpack_bits(xor)
|
| 97 |
+
return bits.sum(dim=(-1, -2))
|
| 98 |
+
|
| 99 |
+
def project_to_hypervector(self, x: torch.Tensor) -> torch.Tensor:
|
| 100 |
+
"""Project continuous hidden state to binary hypervector."""
|
| 101 |
+
# x: [B, T, H] or [B, H] β [B, n_bytes] uint8
|
| 102 |
+
if x.dim() == 3:
|
| 103 |
+
x = x[:, -1, :] # Last token
|
| 104 |
+
# Project to n_bytes * 8 dimensions, threshold at 0
|
| 105 |
+
target_dim = self.memory.size(1) * 8
|
| 106 |
+
proj = F.linear(x, self.lsh_proj.weight[:target_dim, :x.size(-1)])
|
| 107 |
+
binary = (proj > 0).to(torch.uint8)
|
| 108 |
+
# Pack to bytes
|
| 109 |
+
n_bytes = self.memory.size(1)
|
| 110 |
+
packed = torch.zeros(x.size(0), n_bytes, dtype=torch.uint8, device=x.device)
|
| 111 |
+
for i in range(n_bytes):
|
| 112 |
+
start = i * 8
|
| 113 |
+
end = min(start + 8, binary.size(-1))
|
| 114 |
+
byte_bits = binary[:, start:end]
|
| 115 |
+
shifts = torch.arange(byte_bits.size(-1), device=x.device)
|
| 116 |
+
packed[:, i] = (byte_bits * (2 ** shifts)).sum(dim=-1).to(torch.uint8)
|
| 117 |
+
return packed
|
| 118 |
+
|
| 119 |
+
def query(self, query_vec: torch.Tensor, top_k: int = 16
|
| 120 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 121 |
+
"""Query memory with batched hypervector. Returns (distances, indices)."""
|
| 122 |
+
c = int(self.count.item())
|
| 123 |
+
if c == 0:
|
| 124 |
+
return None, None
|
| 125 |
+
# Cache key for repeated queries
|
| 126 |
+
cache_key = f"{query_vec.shape}_{query_vec.device}"
|
| 127 |
+
if cache_key in self._query_cache:
|
| 128 |
+
cached = self._query_cache[cache_key]
|
| 129 |
+
# Only use cache if memory hasn't changed significantly
|
| 130 |
+
if int(self.count.item()) == c:
|
| 131 |
+
return cached
|
| 132 |
+
|
| 133 |
+
dists = self.hamming_distance(query_vec.unsqueeze(-2),
|
| 134 |
+
self.memory[:c].unsqueeze(0))
|
| 135 |
+
k = min(top_k, c)
|
| 136 |
+
values, indices = dists.topk(k, dim=-1, largest=False)
|
| 137 |
+
with torch.no_grad():
|
| 138 |
+
self.access_counts[indices.reshape(-1)] += 1
|
| 139 |
+
result = (values, indices)
|
| 140 |
+
self._query_cache[cache_key] = result
|
| 141 |
+
return result
|
| 142 |
+
|
| 143 |
+
@torch.no_grad()
|
| 144 |
+
def store(self, vec: torch.Tensor, surprise_magnitude: float = 0.0) -> bool:
|
| 145 |
+
"""Store vector if surprise is above threshold. Returns True if stored."""
|
| 146 |
+
if surprise_magnitude < self.write_threshold:
|
| 147 |
+
return False
|
| 148 |
+
vec_flat = vec.detach().reshape(-1)[:self.memory.size(1)].to(torch.uint8)
|
| 149 |
+
cap = self.memory.size(0)
|
| 150 |
+
if self.pool_fixed and int(self.count.item()) >= cap:
|
| 151 |
+
min_idx = int(self.access_counts[:cap].argmin().item())
|
| 152 |
+
self.memory[min_idx] = vec_flat
|
| 153 |
+
self.access_counts[min_idx] = 0
|
| 154 |
+
else:
|
| 155 |
+
idx = int(self.count.item())
|
| 156 |
+
if idx < cap:
|
| 157 |
+
self.memory[idx] = vec_flat
|
| 158 |
+
self.count.add_(1)
|
| 159 |
+
# Invalidate cache
|
| 160 |
+
self._query_cache.clear()
|
| 161 |
+
return True
|
| 162 |
+
|
| 163 |
+
@torch.no_grad()
|
| 164 |
+
def read_and_modulate(self, hidden: torch.Tensor) -> torch.Tensor:
|
| 165 |
+
"""Read from memory and return modulation vector to add to hidden state."""
|
| 166 |
+
c = int(self.count.item())
|
| 167 |
+
if c == 0:
|
| 168 |
+
return torch.zeros_like(hidden)
|
| 169 |
+
# Project hidden to hypervector
|
| 170 |
+
hv = self.project_to_hypervector(hidden)
|
| 171 |
+
dists, indices = self.query(hv, top_k=8)
|
| 172 |
+
if dists is None:
|
| 173 |
+
return torch.zeros_like(hidden)
|
| 174 |
+
# Retrieve memory contents and project back to hidden dim
|
| 175 |
+
retrieved = self.memory[indices[:, 0]] # Best match
|
| 176 |
+
# Simple linear projection back to hidden size
|
| 177 |
+
proj_back = F.linear(
|
| 178 |
+
retrieved.float(),
|
| 179 |
+
self.lsh_proj.weight.t()[:hidden.size(-1), :retrieved.size(-1)]
|
| 180 |
+
)
|
| 181 |
+
# Scale by similarity (closer = stronger modulation)
|
| 182 |
+
similarity = 1.0 - (dists[:, 0].float() / self.vector_bits).clamp(0, 1)
|
| 183 |
+
modulation = proj_back * similarity.unsqueeze(-1)
|
| 184 |
+
return modulation.view_as(hidden)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
# In-place test-time training (TTT)
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
|
| 191 |
+
class InPlaceTTT(nn.Module):
|
| 192 |
+
"""Single-step in-place TTT update on MLP down-projection.
|
| 193 |
+
|
| 194 |
+
Applied during forward pass to adapt weights based on local context.
|
| 195 |
+
Uses causal Conv1D + target projection to compute update delta.
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
def __init__(self, config: dict, hidden_size: int):
|
| 199 |
+
super().__init__()
|
| 200 |
+
self.enabled = bool(config.get("enabled", True))
|
| 201 |
+
self.target_layers = list(config.get("target_layers", [13, 23]))
|
| 202 |
+
self.inner_lr = float(config.get("inner_lr", 3e-4))
|
| 203 |
+
self.momentum = float(config.get("momentum", 0.9))
|
| 204 |
+
self.chunk_size = int(config.get("chunk_size", 1024))
|
| 205 |
+
self.reset_decay = float(config.get("reset_decay", 0.95))
|
| 206 |
+
self.delta_clip = float(config.get("delta_clip", 1e-5))
|
| 207 |
+
self.apply_every_n = int(config.get("apply_every_n", 1))
|
| 208 |
+
|
| 209 |
+
# Causal depthwise conv for local context extraction
|
| 210 |
+
self.conv1d = nn.Conv1d(hidden_size, hidden_size, kernel_size=5,
|
| 211 |
+
padding=4, groups=hidden_size, bias=False)
|
| 212 |
+
nn.init.zeros_(self.conv1d.weight)
|
| 213 |
+
self.w_target = nn.Parameter(torch.eye(hidden_size) * 0.01)
|
| 214 |
+
|
| 215 |
+
# Momentum buffer for smooth updates
|
| 216 |
+
self.register_buffer("momentum_buffer", torch.zeros(hidden_size, hidden_size))
|
| 217 |
+
self.step_count = 0
|
| 218 |
+
|
| 219 |
+
def compute_update(self, x_raw: torch.Tensor, z: torch.Tensor,
|
| 220 |
+
w_down: torch.Tensor) -> torch.Tensor:
|
| 221 |
+
"""Compute TTT update delta from raw inputs and pre-activation."""
|
| 222 |
+
if not self.enabled:
|
| 223 |
+
return torch.zeros_like(w_down)
|
| 224 |
+
T = x_raw.shape[1]
|
| 225 |
+
x_shifted = self.conv1d(x_raw.transpose(1, 2))[:, :, :T].transpose(1, 2)
|
| 226 |
+
v_hat = x_shifted @ self.w_target
|
| 227 |
+
delta = v_hat.transpose(-2, -1) @ z
|
| 228 |
+
# Clip update norm
|
| 229 |
+
norm = delta.norm()
|
| 230 |
+
if float(norm.item()) > self.delta_clip:
|
| 231 |
+
delta = delta * (self.delta_clip / norm)
|
| 232 |
+
return delta
|
| 233 |
+
|
| 234 |
+
def apply_update(self, w_down: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
|
| 235 |
+
"""Apply momentum-smoothed TTT update."""
|
| 236 |
+
self.momentum_buffer.mul_(self.momentum).add_(delta)
|
| 237 |
+
return w_down + self.inner_lr * self.momentum_buffer
|
| 238 |
+
|
| 239 |
+
def forward(self, x_raw: torch.Tensor, z: torch.Tensor,
|
| 240 |
+
w_down: torch.Tensor) -> torch.Tensor:
|
| 241 |
+
"""Forward: optionally update and return updated weight."""
|
| 242 |
+
if not self.enabled:
|
| 243 |
+
return w_down
|
| 244 |
+
self.step_count += 1
|
| 245 |
+
if self.step_count % self.apply_every_n != 0:
|
| 246 |
+
return w_down
|
| 247 |
+
delta = self.compute_update(x_raw, z, w_down)
|
| 248 |
+
return self.apply_update(w_down, delta)
|
| 249 |
+
|
| 250 |
+
@torch.no_grad()
|
| 251 |
+
def reset_momentum(self):
|
| 252 |
+
"""Decay momentum between sessions."""
|
| 253 |
+
self.momentum_buffer.mul_(self.reset_decay)
|
| 254 |
+
self.step_count = 0
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ---------------------------------------------------------------------------
|
| 258 |
+
# Episodic case memory
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
|
| 261 |
+
class EpisodicCaseMemory(nn.Module):
|
| 262 |
+
"""Case-based reasoning memory for interaction patterns."""
|
| 263 |
+
|
| 264 |
+
def __init__(self, config: dict):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.enabled = bool(config.get("enabled", True))
|
| 267 |
+
self.max_cases = int(config.get("max_cases", 4096))
|
| 268 |
+
self.case_bytes = int(config.get("case_bytes", 2048))
|
| 269 |
+
case_dim = max(8, min(self.case_bytes, 512))
|
| 270 |
+
self.case_dim = case_dim
|
| 271 |
+
self.register_buffer("cases", torch.zeros(self.max_cases, case_dim))
|
| 272 |
+
self.register_buffer("weights", torch.ones(self.max_cases))
|
| 273 |
+
self.register_buffer("count", torch.zeros((), dtype=torch.long))
|
| 274 |
+
self.query_proj = nn.Linear(case_dim, case_dim, bias=False)
|
| 275 |
+
self.ema_decay = 0.99
|
| 276 |
+
self.softmax_temp = 1.0
|
| 277 |
+
|
| 278 |
+
def retrieve(self, query: torch.Tensor, top_k: int = 5):
|
| 279 |
+
"""Soft Q-learning style case retrieval."""
|
| 280 |
+
c = int(self.count.item())
|
| 281 |
+
if c == 0:
|
| 282 |
+
return None, None
|
| 283 |
+
q = self.query_proj(query)
|
| 284 |
+
q_flat = F.normalize(q.reshape(-1, q.shape[-1]), dim=-1)
|
| 285 |
+
c_norm = F.normalize(self.cases[:c], dim=-1)
|
| 286 |
+
sims = torch.matmul(q_flat, c_norm.t()) * self.weights[:c].unsqueeze(0)
|
| 287 |
+
# Softmax policy (maximum entropy RL)
|
| 288 |
+
probs = F.softmax(sims / self.softmax_temp, dim=-1)
|
| 289 |
+
k = min(top_k, c)
|
| 290 |
+
scores, indices = probs.topk(k, dim=-1)
|
| 291 |
+
return self.cases[indices], scores
|
| 292 |
+
|
| 293 |
+
@torch.no_grad()
|
| 294 |
+
def store(self, case_vec: torch.Tensor, outcome: float = 1.0) -> None:
|
| 295 |
+
"""Store case with outcome-based weight."""
|
| 296 |
+
idx = int(self.count.item()) % self.max_cases
|
| 297 |
+
self.cases[idx] = case_vec.detach().reshape(-1)[:self.case_dim]
|
| 298 |
+
self.weights[idx] = float(outcome)
|
| 299 |
+
if int(self.count.item()) < self.max_cases:
|
| 300 |
+
self.count.add_(1)
|
| 301 |
+
|
| 302 |
+
@torch.no_grad()
|
| 303 |
+
def update_weight(self, idx: int, outcome: float) -> None:
|
| 304 |
+
"""EMA weight update based on outcome."""
|
| 305 |
+
self.weights[idx] = self.ema_decay * self.weights[idx] + (1.0 - self.ema_decay) * outcome
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# Meta-guideline bank
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
class MetaGuidelineBank(nn.Module):
|
| 313 |
+
"""Stores meta-rules about when memory retrieval helps vs hurts."""
|
| 314 |
+
|
| 315 |
+
def __init__(self, config: dict):
|
| 316 |
+
super().__init__()
|
| 317 |
+
self.enabled = bool(config.get("enabled", True))
|
| 318 |
+
self.max_guidelines = int(config.get("max", 256))
|
| 319 |
+
bits = int(config.get("bits", 8192))
|
| 320 |
+
self.register_buffer("guidelines",
|
| 321 |
+
torch.zeros(self.max_guidelines, bits // 8, dtype=torch.uint8))
|
| 322 |
+
self.register_buffer("count", torch.zeros((), dtype=torch.long))
|
| 323 |
+
self.register_buffer("effectiveness", torch.zeros(self.max_guidelines))
|
| 324 |
+
|
| 325 |
+
@torch.no_grad()
|
| 326 |
+
def add_guideline(self, vec: torch.Tensor, effectiveness: float = 0.0) -> None:
|
| 327 |
+
idx = int(self.count.item()) % self.max_guidelines
|
| 328 |
+
self.guidelines[idx] = vec.detach()
|
| 329 |
+
self.effectiveness[idx] = effectiveness
|
| 330 |
+
if int(self.count.item()) < self.max_guidelines:
|
| 331 |
+
self.count.add_(1)
|
| 332 |
+
|
| 333 |
+
def query(self, query_vec: torch.Tensor, top_k: int = 5):
|
| 334 |
+
c = int(self.count.item())
|
| 335 |
+
if c == 0:
|
| 336 |
+
return None
|
| 337 |
+
dists = SemanticMemory.hamming_distance(
|
| 338 |
+
query_vec.unsqueeze(-2), self.guidelines[:c].unsqueeze(0))
|
| 339 |
+
k = min(top_k, c)
|
| 340 |
+
values, indices = dists.topk(k, dim=-1, largest=False)
|
| 341 |
+
# Weight by effectiveness
|
| 342 |
+
eff = self.effectiveness[indices]
|
| 343 |
+
return values, indices, eff
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
# ---------------------------------------------------------------------------
|
| 347 |
+
# Self-feedback / refinement trigger
|
| 348 |
+
# ---------------------------------------------------------------------------
|
| 349 |
+
|
| 350 |
+
class SelfFeedback(nn.Module):
|
| 351 |
+
"""Triggers self-refinement when confidence is low."""
|
| 352 |
+
|
| 353 |
+
def __init__(self, config: dict):
|
| 354 |
+
super().__init__()
|
| 355 |
+
self.enabled = bool(config.get("enabled", True))
|
| 356 |
+
self.confidence_threshold = float(config.get("confidence_threshold", 0.6))
|
| 357 |
+
self.max_rounds = int(config.get("max_refinement_rounds", 1))
|
| 358 |
+
self.refinement_count = 0
|
| 359 |
+
self.total_evaluations = 0
|
| 360 |
+
|
| 361 |
+
def compute_confidence(self, logits: torch.Tensor) -> float:
|
| 362 |
+
"""Compute mean max-probability confidence."""
|
| 363 |
+
probs = F.softmax(logits, dim=-1)
|
| 364 |
+
confidence = probs.amax(dim=-1).mean().item()
|
| 365 |
+
self.total_evaluations += 1
|
| 366 |
+
return confidence
|
| 367 |
+
|
| 368 |
+
def should_refine(self, logits: torch.Tensor) -> bool:
|
| 369 |
+
"""Check if refinement is needed based on confidence."""
|
| 370 |
+
if not self.enabled or self.refinement_count >= self.max_rounds:
|
| 371 |
+
return False
|
| 372 |
+
confidence = self.compute_confidence(logits)
|
| 373 |
+
need_refine = confidence < self.confidence_threshold
|
| 374 |
+
if need_refine:
|
| 375 |
+
self.refinement_count += 1
|
| 376 |
+
return need_refine
|
| 377 |
+
|
| 378 |
+
def reset(self):
|
| 379 |
+
self.refinement_count = 0
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
# ---------------------------------------------------------------------------
|
| 383 |
+
# Loop depth classifier
|
| 384 |
+
# ---------------------------------------------------------------------------
|
| 385 |
+
|
| 386 |
+
class LoopDepthClassifier(nn.Module):
|
| 387 |
+
"""Predicts optimal Parcae loop depth from hidden state."""
|
| 388 |
+
|
| 389 |
+
def __init__(self, config: dict, in_features: int = 256):
|
| 390 |
+
super().__init__()
|
| 391 |
+
self.enabled = bool(config.get("enabled", True))
|
| 392 |
+
h = max(16, in_features // 4)
|
| 393 |
+
self.net = nn.Sequential(
|
| 394 |
+
nn.Linear(in_features, h),
|
| 395 |
+
nn.ReLU(inplace=True),
|
| 396 |
+
nn.Dropout(0.1),
|
| 397 |
+
nn.Linear(h, 6), # Loop depths 1-6
|
| 398 |
+
)
|
| 399 |
+
nn.init.normal_(self.net[-1].weight, std=0.01)
|
| 400 |
+
|
| 401 |
+
def forward(self, features: torch.Tensor) -> torch.Tensor:
|
| 402 |
+
"""Returns recommended loop depth [1, 6]."""
|
| 403 |
+
if not self.enabled:
|
| 404 |
+
return torch.tensor(2, dtype=torch.long, device=features.device)
|
| 405 |
+
return self.net(features).argmax(dim=-1) + 1
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
# ---------------------------------------------------------------------------
|
| 409 |
+
# Self-evolution engine β WIRED and FUNCTIONAL
|
| 410 |
+
# ---------------------------------------------------------------------------
|
| 411 |
+
|
| 412 |
+
class SelfEvolutionEngine(nn.Module):
|
| 413 |
+
"""Orchestrates all self-evolution components during forward pass.
|
| 414 |
+
|
| 415 |
+
Now fully wired:
|
| 416 |
+
1. TTT updates target layer weights during forward pass (training + inference)
|
| 417 |
+
2. SemanticMemory reads modulate hidden states at every layer
|
| 418 |
+
3. EpisodicCaseMemory retrieves similar past interactions
|
| 419 |
+
4. SelfFeedback triggers refinement rounds on low confidence
|
| 420 |
+
5. MetaGuidelineBank stores learned rules from contrastive eval
|
| 421 |
+
6. LoopDepthClassifier predicts optimal compute budget
|
| 422 |
+
|
| 423 |
+
Returns an evolution_loss that can be added to the main training loss.
|
| 424 |
+
"""
|
| 425 |
+
|
| 426 |
+
def __init__(self, config: dict, hidden_size: int):
|
| 427 |
+
super().__init__()
|
| 428 |
+
t1 = config.get("tier1", {})
|
| 429 |
+
t2 = config.get("tier2", {})
|
| 430 |
+
t3 = config.get("tier3", {})
|
| 431 |
+
|
| 432 |
+
self.ttt = InPlaceTTT(t1.get("ttt", {}), hidden_size)
|
| 433 |
+
self.semantic_memory = SemanticMemory(config.get("_semantic_memory_config", {}))
|
| 434 |
+
self.episodic = EpisodicCaseMemory(t2.get("episodic_cases", {}))
|
| 435 |
+
self.meta_guidelines = MetaGuidelineBank(t2.get("meta_guidelines", {}))
|
| 436 |
+
self.self_feedback = SelfFeedback(t2.get("self_feedback", {}))
|
| 437 |
+
self.loop_classifier = LoopDepthClassifier(t3.get("loop_depth_learning", {}), hidden_size)
|
| 438 |
+
|
| 439 |
+
safety = config.get("safety", {})
|
| 440 |
+
self.freeze_threshold = float(safety.get("freeze_threshold", 0.05))
|
| 441 |
+
self.frozen = False
|
| 442 |
+
|
| 443 |
+
# Contrastive evaluation tracking
|
| 444 |
+
self.register_buffer("with_memory_loss", torch.zeros(1))
|
| 445 |
+
self.register_buffer("without_memory_loss", torch.zeros(1))
|
| 446 |
+
self.eval_steps = 0
|
| 447 |
+
|
| 448 |
+
# Surprise detection for memory writes
|
| 449 |
+
self.surprise_window = []
|
| 450 |
+
self.max_window = 100
|
| 451 |
+
|
| 452 |
+
def check_safety(self, cert_failure_rate: float) -> bool:
|
| 453 |
+
if cert_failure_rate > self.freeze_threshold:
|
| 454 |
+
self.frozen = True
|
| 455 |
+
return self.frozen
|
| 456 |
+
|
| 457 |
+
def compute_surprise(self, loss: torch.Tensor) -> float:
|
| 458 |
+
"""Track loss variance as surprise signal."""
|
| 459 |
+
val = float(loss.mean().item()) if loss.numel() > 1 else float(loss.item())
|
| 460 |
+
self.surprise_window.append(val)
|
| 461 |
+
if len(self.surprise_window) > self.max_window:
|
| 462 |
+
self.surprise_window.pop(0)
|
| 463 |
+
if len(self.surprise_window) < 10:
|
| 464 |
+
return 0.0
|
| 465 |
+
mean = sum(self.surprise_window) / len(self.surprise_window)
|
| 466 |
+
std = math.sqrt(sum((x - mean) ** 2 for x in self.surprise_window) / len(self.surprise_window))
|
| 467 |
+
surprise = abs(val - mean) / (std + 1e-6)
|
| 468 |
+
return surprise
|
| 469 |
+
|
| 470 |
+
def forward(self, hidden_states: torch.Tensor, logits: Optional[torch.Tensor] = None,
|
| 471 |
+
layer_idx: Optional[int] = None, loss: Optional[torch.Tensor] = None) -> Dict[str, any]:
|
| 472 |
+
"""Process evolution for current step. Returns dict with updates.
|
| 473 |
+
|
| 474 |
+
Args:
|
| 475 |
+
hidden_states: [B, T, H] current hidden states
|
| 476 |
+
logits: Optional [B, T, V] for confidence evaluation
|
| 477 |
+
layer_idx: Current layer index (for TTT targeting)
|
| 478 |
+
loss: Optional loss tensor for surprise detection
|
| 479 |
+
|
| 480 |
+
Returns:
|
| 481 |
+
Dict with keys: 'modulation', 'ttt_delta', 'loop_depth',
|
| 482 |
+
'should_refine', 'evolution_loss', 'metrics'
|
| 483 |
+
"""
|
| 484 |
+
if self.frozen:
|
| 485 |
+
return {
|
| 486 |
+
'modulation': torch.zeros_like(hidden_states),
|
| 487 |
+
'ttt_delta': None,
|
| 488 |
+
'loop_depth': 2,
|
| 489 |
+
'should_refine': False,
|
| 490 |
+
'evolution_loss': torch.tensor(0.0, device=hidden_states.device),
|
| 491 |
+
'metrics': {'frozen': True}
|
| 492 |
+
}
|
| 493 |
+
|
| 494 |
+
result = {
|
| 495 |
+
'modulation': torch.zeros_like(hidden_states),
|
| 496 |
+
'ttt_delta': None,
|
| 497 |
+
'loop_depth': 2,
|
| 498 |
+
'should_refine': False,
|
| 499 |
+
'evolution_loss': torch.tensor(0.0, device=hidden_states.device),
|
| 500 |
+
'metrics': {}
|
| 501 |
+
}
|
| 502 |
+
|
| 503 |
+
B, T, H = hidden_states.shape
|
| 504 |
+
|
| 505 |
+
# 1. Semantic memory read β modulate hidden states
|
| 506 |
+
if self.semantic_memory.enabled and self.semantic_memory.count.item() > 0:
|
| 507 |
+
modulation = self.semantic_memory.read_and_modulate(hidden_states)
|
| 508 |
+
result['modulation'] = modulation * 0.1 # Gentle modulation
|
| 509 |
+
|
| 510 |
+
# 2. TTT β compute update for target layers
|
| 511 |
+
if self.ttt.enabled and layer_idx in self.ttt.target_layers and logits is not None:
|
| 512 |
+
# Use pre-activation proxy: gradient of loss w.r.t. hidden
|
| 513 |
+
if loss is not None and hidden_states.requires_grad:
|
| 514 |
+
grad = torch.autograd.grad(loss, hidden_states, retain_graph=True,
|
| 515 |
+
create_graph=False)[0]
|
| 516 |
+
# Approximate z (pre-activation) from gradient direction
|
| 517 |
+
z = -grad[:, -1:, :] # Last token gradient direction
|
| 518 |
+
x_raw = hidden_states[:, -1:, :]
|
| 519 |
+
# Apply TTT (only affects inference, not backprop through TTT params)
|
| 520 |
+
with torch.no_grad():
|
| 521 |
+
result['ttt_delta'] = self.ttt.compute_update(x_raw, z,
|
| 522 |
+
torch.eye(H, device=hidden_states.device))
|
| 523 |
+
|
| 524 |
+
# 3. Loop depth prediction (inference only)
|
| 525 |
+
if not self.training and logits is not None:
|
| 526 |
+
last_hidden = hidden_states[:, -1, :]
|
| 527 |
+
result['loop_depth'] = self.loop_classifier(last_hidden).item()
|
| 528 |
+
|
| 529 |
+
# 4. Self-feedback confidence check
|
| 530 |
+
if logits is not None:
|
| 531 |
+
result['should_refine'] = self.self_feedback.should_refine(logits)
|
| 532 |
+
result['metrics']['confidence'] = self.self_feedback.compute_confidence(logits)
|
| 533 |
+
|
| 534 |
+
# 5. Contrastive memory evaluation (every N steps during training)
|
| 535 |
+
if self.training and loss is not None:
|
| 536 |
+
self.eval_steps += 1
|
| 537 |
+
if self.eval_steps % 50 == 0:
|
| 538 |
+
# Compare loss with/without memory modulation
|
| 539 |
+
with_memory = loss.item()
|
| 540 |
+
self.with_memory_loss[0] = with_memory
|
| 541 |
+
# Simple evolution loss: encourage memory to help
|
| 542 |
+
if self.without_memory_loss[0] > 0:
|
| 543 |
+
improvement = self.without_memory_loss[0] - with_memory
|
| 544 |
+
result['evolution_loss'] = -torch.tensor(improvement * 0.01,
|
| 545 |
+
device=hidden_states.device)
|
| 546 |
+
self.without_memory_loss[0] = with_memory
|
| 547 |
+
|
| 548 |
+
# 6. Surprise-based memory write
|
| 549 |
+
if loss is not None and self.semantic_memory.enabled:
|
| 550 |
+
surprise = self.compute_surprise(loss)
|
| 551 |
+
if surprise > self.semantic_memory.write_threshold:
|
| 552 |
+
# Project last hidden state and store
|
| 553 |
+
last_hv = self.semantic_memory.project_to_hypervector(hidden_states[:, -1:, :])
|
| 554 |
+
stored = self.semantic_memory.store(last_hv.squeeze(0), surprise)
|
| 555 |
+
result['metrics']['memory_stored'] = stored
|
| 556 |
+
|
| 557 |
+
# 7. Episodic case retrieval (for context-aware behavior)
|
| 558 |
+
if self.episodic.enabled and self.episodic.count.item() > 0:
|
| 559 |
+
query = hidden_states[:, -1, :]
|
| 560 |
+
cases, scores = self.episodic.retrieve(query, top_k=3)
|
| 561 |
+
if cases is not None:
|
| 562 |
+
result['metrics']['episodic_similarity'] = scores.mean().item()
|
| 563 |
+
|
| 564 |
+
return result
|
| 565 |
+
|
| 566 |
+
@torch.no_grad()
|
| 567 |
+
def store_episodic(self, hidden: torch.Tensor, outcome: float = 1.0):
|
| 568 |
+
"""Store episodic case after interaction completes."""
|
| 569 |
+
if self.episodic.enabled:
|
| 570 |
+
self.episodic.store(hidden.reshape(-1), outcome)
|
| 571 |
+
|
| 572 |
+
@torch.no_grad()
|
| 573 |
+
def add_guideline(self, query_vec: torch.Tensor, effectiveness: float = 0.0):
|
| 574 |
+
"""Add meta-guideline from contrastive evaluation."""
|
| 575 |
+
if self.meta_guidelines.enabled:
|
| 576 |
+
self.meta_guidelines.add_guideline(query_vec, effectiveness)
|
| 577 |
+
|
| 578 |
+
def reset_session(self):
|
| 579 |
+
"""Reset per-session evolution state."""
|
| 580 |
+
self.ttt.reset_momentum()
|
| 581 |
+
self.self_feedback.reset()
|
| 582 |
+
self.surprise_window.clear()
|
| 583 |
+
self.semantic_memory._query_cache.clear()
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
__all__ = [
|
| 587 |
+
"SemanticMemory",
|
| 588 |
+
"InPlaceTTT",
|
| 589 |
+
"EpisodicCaseMemory",
|
| 590 |
+
"MetaGuidelineBank",
|
| 591 |
+
"SelfFeedback",
|
| 592 |
+
"LoopDepthClassifier",
|
| 593 |
+
"SelfEvolutionEngine",
|
| 594 |
+
]
|
chimera/hyper.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.3 β HYPER Paradigm Engine for 10,000+ tok/s CPU Training
|
| 3 |
+
===================================================================
|
| 4 |
+
|
| 5 |
+
Seven orthogonal paradigms that stack multiplicatively:
|
| 6 |
+
|
| 7 |
+
P1 GrowLength Curriculum β Start seq=16, grow to target. Short seqs =
|
| 8 |
+
huge batch = way more tok/s early on.
|
| 9 |
+
(arxiv:2310.00576)
|
| 10 |
+
|
| 11 |
+
P2 Reservoir Freezing (GRC) β Freeze ~50 % of recurrent gate matrices as
|
| 12 |
+
random ternary. No grad for those params β
|
| 13 |
+
2Γ fewer FLOPs in recurrent layers.
|
| 14 |
+
(arxiv:2512.23145)
|
| 15 |
+
|
| 16 |
+
P3 Sparse MeZO β Perturb only top-K % most-sensitive params
|
| 17 |
+
(by magnitude). ZO signal quality β
|
| 18 |
+
βmaskββfβΒ²/ββfβΒ²; masking raises it.
|
| 19 |
+
(arxiv:2406.02913)
|
| 20 |
+
|
| 21 |
+
P4 Blockwise Pipeline β Pin layer-groups to core-groups; overlap
|
| 22 |
+
block N on batch t with block N-1 on t+1.
|
| 23 |
+
|
| 24 |
+
P5 Fused Ternary Cache β Pre-materialise dense ternary weights once
|
| 25 |
+
per step; reuse for both MeZO forwards.
|
| 26 |
+
|
| 27 |
+
P6 Aggressive Token Packing β Zero padding waste; pack documents
|
| 28 |
+
back-to-back with EOS separators.
|
| 29 |
+
|
| 30 |
+
P7 Progressive Layer Unfreeze β Train only top ~25 % of layers first; un-
|
| 31 |
+
freeze downward as training proceeds.
|
| 32 |
+
|
| 33 |
+
Expected combined multiplier (tiny-35 M on 8-core CPU):
|
| 34 |
+
|
| 35 |
+
P1 (4-8Γ) Γ P2 (1.5-2Γ) Γ P3 (3-5Γ) Γ P5 (1.3Γ) Γ P7 (1.5-2Γ)
|
| 36 |
+
β 35-260Γ β 50-200 tok/s baseline β **1 750-52 000 tok/s**
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
from __future__ import annotations
|
| 40 |
+
|
| 41 |
+
import math
|
| 42 |
+
import time
|
| 43 |
+
from typing import Dict, List, Optional, Tuple
|
| 44 |
+
|
| 45 |
+
import torch
|
| 46 |
+
import torch.nn as nn
|
| 47 |
+
import torch.nn.functional as F
|
| 48 |
+
from torch.utils.data import DataLoader, Dataset
|
| 49 |
+
|
| 50 |
+
from .quantization import BitLinear
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 54 |
+
# P1 β GrowLength Curriculum
|
| 55 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 56 |
+
|
| 57 |
+
class GrowLengthDataset(Dataset):
|
| 58 |
+
"""Flat token buffer re-chunked on-the-fly when ``set_seq_len`` is called.
|
| 59 |
+
|
| 60 |
+
Because chunks are contiguous slices, set_seq_len is O(1).
|
| 61 |
+
"""
|
| 62 |
+
|
| 63 |
+
def __init__(self, all_ids: torch.Tensor, seq_len: int = 16):
|
| 64 |
+
self.all_ids = all_ids
|
| 65 |
+
self._seq_len = 0
|
| 66 |
+
self._n = 0
|
| 67 |
+
self.set_seq_len(seq_len)
|
| 68 |
+
|
| 69 |
+
# ββ public API βββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 70 |
+
def set_seq_len(self, seq_len: int) -> None:
|
| 71 |
+
self._seq_len = int(seq_len)
|
| 72 |
+
self._n = self.all_ids.numel() // (self._seq_len + 1)
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def seq_len(self) -> int:
|
| 76 |
+
return self._seq_len
|
| 77 |
+
|
| 78 |
+
def __len__(self) -> int:
|
| 79 |
+
return self._n
|
| 80 |
+
|
| 81 |
+
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
|
| 82 |
+
start = idx * (self._seq_len + 1)
|
| 83 |
+
chunk = self.all_ids[start: start + self._seq_len + 1]
|
| 84 |
+
return {"input_ids": chunk[:-1], "labels": chunk[1:]}
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
class GrowLengthScheduler:
|
| 88 |
+
"""Maps a global step to the current target sequence length.
|
| 89 |
+
|
| 90 |
+
``stages`` is a list of ``(seq_len, fraction_of_total_steps)`` tuples.
|
| 91 |
+
Fractions are normalised internally so they need not sum to 1.
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
def __init__(self, stages: List[Tuple[int, float]], total_steps: int):
|
| 95 |
+
total_frac = sum(f for _, f in stages) or 1.0
|
| 96 |
+
cumulative = 0
|
| 97 |
+
self._boundaries: List[Tuple[int, int]] = []
|
| 98 |
+
for seq_len, frac in stages:
|
| 99 |
+
cumulative += int(total_steps * frac / total_frac)
|
| 100 |
+
self._boundaries.append((cumulative, int(seq_len)))
|
| 101 |
+
|
| 102 |
+
def get_seq_len(self, step: int) -> int:
|
| 103 |
+
for boundary, seq_len in self._boundaries:
|
| 104 |
+
if step < boundary:
|
| 105 |
+
return seq_len
|
| 106 |
+
return self._boundaries[-1][1]
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 110 |
+
# P2 β Reservoir Freezing (GRC-inspired, arxiv:2512.23145)
|
| 111 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 112 |
+
|
| 113 |
+
def apply_reservoir_freezing(model: nn.Module,
|
| 114 |
+
freeze_ratio: float = 0.5) -> int:
|
| 115 |
+
"""Freeze gate / forget projections in recurrent layers as random ternary
|
| 116 |
+
reservoirs. Returns the number of frozen scalar parameters.
|
| 117 |
+
|
| 118 |
+
Targets:
|
| 119 |
+
β’ GatedDeltaNet β a_proj, b_proj (alpha / beta gates)
|
| 120 |
+
β’ mLSTM β fgate (forget gate)
|
| 121 |
+
β’ TitansMAC β alpha_proj (forgetting gate)
|
| 122 |
+
|
| 123 |
+
The frozen weights are re-initialised to unit-spectral-radius ternary
|
| 124 |
+
matrices so every layer starts with a stable reservoir.
|
| 125 |
+
"""
|
| 126 |
+
frozen = 0
|
| 127 |
+
|
| 128 |
+
for _name, module in model.named_modules():
|
| 129 |
+
# ββ GatedDeltaNet gates ββββββββββββββββββββββββββββββββββββββ
|
| 130 |
+
if hasattr(module, "a_proj") and hasattr(module, "b_proj"):
|
| 131 |
+
for attr in ("a_proj", "b_proj"):
|
| 132 |
+
proj = getattr(module, attr, None)
|
| 133 |
+
if proj is None:
|
| 134 |
+
continue
|
| 135 |
+
w = getattr(proj, "weight", None)
|
| 136 |
+
if w is None or not isinstance(w, nn.Parameter):
|
| 137 |
+
continue
|
| 138 |
+
with torch.no_grad():
|
| 139 |
+
w.data = torch.randint(-1, 2, w.shape,
|
| 140 |
+
dtype=w.dtype, device=w.device)
|
| 141 |
+
norm = torch.linalg.matrix_norm(
|
| 142 |
+
w.data.float(), ord=2).clamp(min=1.0)
|
| 143 |
+
w.data.div_(norm)
|
| 144 |
+
w.requires_grad = False
|
| 145 |
+
frozen += w.numel()
|
| 146 |
+
|
| 147 |
+
# ββ mLSTM forget gate ββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
if hasattr(module, "fgate") and hasattr(module, "igate"):
|
| 149 |
+
fg = module.fgate
|
| 150 |
+
w = getattr(fg, "weight", None)
|
| 151 |
+
if w is not None and isinstance(w, nn.Parameter):
|
| 152 |
+
with torch.no_grad():
|
| 153 |
+
w.data = torch.randint(-1, 2, w.shape,
|
| 154 |
+
dtype=w.dtype, device=w.device).float()
|
| 155 |
+
norm = torch.linalg.matrix_norm(
|
| 156 |
+
w.data, ord=2).clamp(min=1.0)
|
| 157 |
+
w.data.div_(norm)
|
| 158 |
+
w.requires_grad = False
|
| 159 |
+
frozen += w.numel()
|
| 160 |
+
|
| 161 |
+
# ββ TitansMAC forgetting βββββββββββββββββββββββββββββββββββββ
|
| 162 |
+
if hasattr(module, "alpha_proj") and hasattr(module, "eta_proj"):
|
| 163 |
+
ap = module.alpha_proj
|
| 164 |
+
w = getattr(ap, "weight", None)
|
| 165 |
+
if w is not None and isinstance(w, nn.Parameter):
|
| 166 |
+
with torch.no_grad():
|
| 167 |
+
w.data = torch.randint(-1, 2, w.shape,
|
| 168 |
+
dtype=w.dtype, device=w.device).float()
|
| 169 |
+
norm = torch.linalg.matrix_norm(
|
| 170 |
+
w.data, ord=2).clamp(min=1.0)
|
| 171 |
+
w.data.div_(norm)
|
| 172 |
+
w.requires_grad = False
|
| 173 |
+
frozen += w.numel()
|
| 174 |
+
|
| 175 |
+
return frozen
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 179 |
+
# P3 β Sparse MeZO (arxiv:2406.02913)
|
| 180 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 181 |
+
|
| 182 |
+
class SparseMeZOOptimizer:
|
| 183 |
+
"""Zeroth-order optimiser that perturbs only the top-K % most-sensitive
|
| 184 |
+
parameters (ranked by weight magnitude as a cheap proxy for gradient
|
| 185 |
+
magnitude).
|
| 186 |
+
|
| 187 |
+
Combined with **Paradigm 5** (fused ternary cache): before each dual-
|
| 188 |
+
forward the caller should invoke ``precompute_ternary_cache(model)``
|
| 189 |
+
once so that both forward passes reuse the same dense-weight buffers.
|
| 190 |
+
"""
|
| 191 |
+
|
| 192 |
+
def __init__(self, model: nn.Module, *,
|
| 193 |
+
lr: float = 1e-4,
|
| 194 |
+
eps: float = 1e-3,
|
| 195 |
+
sparsity: float = 0.01,
|
| 196 |
+
weight_decay: float = 0.0,
|
| 197 |
+
momentum: float = 0.0,
|
| 198 |
+
mask_refresh_interval: int = 50):
|
| 199 |
+
self.model = model
|
| 200 |
+
self.lr = float(lr)
|
| 201 |
+
self.eps = float(eps)
|
| 202 |
+
self.sparsity = float(sparsity)
|
| 203 |
+
self.wd = float(weight_decay)
|
| 204 |
+
self.momentum_coeff = float(momentum)
|
| 205 |
+
self.mask_refresh = int(mask_refresh_interval)
|
| 206 |
+
|
| 207 |
+
# Deduplicated trainable params
|
| 208 |
+
self._params: List[Tuple[str, nn.Parameter]] = []
|
| 209 |
+
seen: set = set()
|
| 210 |
+
for name, p in model.named_parameters():
|
| 211 |
+
if p.requires_grad and id(p) not in seen:
|
| 212 |
+
self._params.append((name, p))
|
| 213 |
+
seen.add(id(p))
|
| 214 |
+
|
| 215 |
+
self._total = sum(p.numel() for _, p in self._params)
|
| 216 |
+
self._k = max(1, int(self._total * self.sparsity))
|
| 217 |
+
self._masks: Dict[int, torch.Tensor] = {}
|
| 218 |
+
self._momentum: Dict[int, torch.Tensor] = {}
|
| 219 |
+
if self.momentum_coeff > 0:
|
| 220 |
+
for _, p in self._params:
|
| 221 |
+
self._momentum[id(p)] = torch.zeros_like(p.data)
|
| 222 |
+
self._step = 0
|
| 223 |
+
self._refresh_masks()
|
| 224 |
+
|
| 225 |
+
# ββ mask computation βββββββββββββββββββββββββββββββββββββββββββββ
|
| 226 |
+
def _refresh_masks(self) -> None:
|
| 227 |
+
slices, offset = [], 0
|
| 228 |
+
mags = []
|
| 229 |
+
for _, p in self._params:
|
| 230 |
+
flat = p.data.abs().flatten()
|
| 231 |
+
mags.append(flat)
|
| 232 |
+
slices.append((offset, offset + flat.numel()))
|
| 233 |
+
offset += flat.numel()
|
| 234 |
+
all_mag = torch.cat(mags)
|
| 235 |
+
if self._k < all_mag.numel():
|
| 236 |
+
thr = torch.topk(all_mag, self._k, sorted=False).values.min()
|
| 237 |
+
else:
|
| 238 |
+
thr = torch.tensor(0.0)
|
| 239 |
+
for i, (_, p) in enumerate(self._params):
|
| 240 |
+
s, e = slices[i]
|
| 241 |
+
self._masks[id(p)] = (all_mag[s:e] >= thr).view(p.shape)
|
| 242 |
+
|
| 243 |
+
# ββ perturbation helpers βββββββββββββββββββββββββββββββββββββββββ
|
| 244 |
+
def _direction(self, p: torch.Tensor, seed: int,
|
| 245 |
+
mask: torch.Tensor) -> torch.Tensor:
|
| 246 |
+
gen = torch.Generator(device="cpu")
|
| 247 |
+
gen.manual_seed(seed & 0x7FFF_FFFF_FFFF_FFFF)
|
| 248 |
+
z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
|
| 249 |
+
z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
|
| 250 |
+
return z * mask.to(z.dtype)
|
| 251 |
+
|
| 252 |
+
def _perturb(self, seed: int, scale: float) -> None:
|
| 253 |
+
for i, (_, p) in enumerate(self._params):
|
| 254 |
+
z = self._direction(p.data, seed + i * 1_000_003,
|
| 255 |
+
self._masks.get(id(p),
|
| 256 |
+
torch.ones_like(p.data)))
|
| 257 |
+
p.data.add_(z, alpha=scale)
|
| 258 |
+
_invalidate_bitlinear(self.model)
|
| 259 |
+
|
| 260 |
+
# ββ step βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 261 |
+
@torch.no_grad()
|
| 262 |
+
def step(self, loss_fn, batch) -> float:
|
| 263 |
+
self._step += 1
|
| 264 |
+
if self._step % self.mask_refresh == 0:
|
| 265 |
+
self._refresh_masks()
|
| 266 |
+
|
| 267 |
+
seed = int(torch.randint(0, 2 ** 31, (1,)).item())
|
| 268 |
+
|
| 269 |
+
self._perturb(seed, +self.eps)
|
| 270 |
+
loss_pos = float(loss_fn(batch).item())
|
| 271 |
+
|
| 272 |
+
self._perturb(seed, -2.0 * self.eps)
|
| 273 |
+
loss_neg = float(loss_fn(batch).item())
|
| 274 |
+
|
| 275 |
+
self._perturb(seed, +self.eps) # restore
|
| 276 |
+
|
| 277 |
+
proj = (loss_pos - loss_neg) / (2.0 * self.eps)
|
| 278 |
+
|
| 279 |
+
for i, (_, p) in enumerate(self._params):
|
| 280 |
+
mask = self._masks.get(id(p), torch.ones_like(p.data))
|
| 281 |
+
z = self._direction(p.data, seed + i * 1_000_003, mask)
|
| 282 |
+
if self.momentum_coeff > 0:
|
| 283 |
+
buf = self._momentum[id(p)]
|
| 284 |
+
buf.mul_(self.momentum_coeff).add_(z, alpha=proj)
|
| 285 |
+
p.data.add_(buf, alpha=-self.lr)
|
| 286 |
+
else:
|
| 287 |
+
p.data.add_(z, alpha=-self.lr * proj)
|
| 288 |
+
if self.wd > 0:
|
| 289 |
+
p.data.mul_(1 - self.lr * self.wd)
|
| 290 |
+
_invalidate_bitlinear(self.model)
|
| 291 |
+
|
| 292 |
+
return 0.5 * (loss_pos + loss_neg)
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 296 |
+
# P5 β Fused Ternary Cache
|
| 297 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 298 |
+
|
| 299 |
+
def precompute_ternary_cache(model: nn.Module) -> None:
|
| 300 |
+
"""Materialise every BitLinear's packed + dense fp32 cache so the next
|
| 301 |
+
forward pass is allocation-free. Call once before each MeZO dual-fwd."""
|
| 302 |
+
for m in model.modules():
|
| 303 |
+
if isinstance(m, BitLinear):
|
| 304 |
+
m._ensure_packed()
|
| 305 |
+
m._ensure_dense()
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def _invalidate_bitlinear(model: nn.Module) -> None:
|
| 309 |
+
for m in model.modules():
|
| 310 |
+
if isinstance(m, BitLinear):
|
| 311 |
+
m.invalidate_packed()
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 315 |
+
# P6 β Aggressive Token Packing
|
| 316 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 317 |
+
|
| 318 |
+
def pack_documents(raw_ids: torch.Tensor, eos_id: int,
|
| 319 |
+
max_tokens: int) -> torch.Tensor:
|
| 320 |
+
"""Return a contiguous 1-D ``LongTensor`` of ``max_tokens`` tokens where
|
| 321 |
+
individual documents are separated by ``eos_id`` and there is **zero**
|
| 322 |
+
padding. Already-tokenised documents should be concatenated in
|
| 323 |
+
``raw_ids`` (the function simply truncates to ``max_tokens``).
|
| 324 |
+
"""
|
| 325 |
+
n = min(raw_ids.numel(), int(max_tokens))
|
| 326 |
+
return raw_ids[:n].contiguous()
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 330 |
+
# P7 β Progressive Layer Unfreezing
|
| 331 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 332 |
+
|
| 333 |
+
class ProgressiveUnfreezer:
|
| 334 |
+
"""Freeze all but the top *k* layers initially; unfreeze downward as
|
| 335 |
+
training advances.
|
| 336 |
+
|
| 337 |
+
``n_stages`` = number of unfreeze events spread evenly across
|
| 338 |
+
``total_steps``. At each event one more block of layers becomes
|
| 339 |
+
trainable (starting from the output end).
|
| 340 |
+
"""
|
| 341 |
+
|
| 342 |
+
def __init__(self, model: nn.Module, total_steps: int,
|
| 343 |
+
n_stages: int = 4):
|
| 344 |
+
self._layers = model.layers # nn.ModuleList
|
| 345 |
+
self._n = len(self._layers)
|
| 346 |
+
self._total = int(total_steps)
|
| 347 |
+
self._stages = int(n_stages)
|
| 348 |
+
self._block = max(1, self._n // self._stages)
|
| 349 |
+
self._current_from = self._n # everything frozen initially
|
| 350 |
+
# Immediately unfreeze the first block (top layers)
|
| 351 |
+
self.update(0)
|
| 352 |
+
|
| 353 |
+
def update(self, step: int) -> int:
|
| 354 |
+
"""Call every step. Returns the index of the first trainable layer."""
|
| 355 |
+
stage = min(step * self._stages // max(1, self._total),
|
| 356 |
+
self._stages - 1)
|
| 357 |
+
target = max(0, self._n - (stage + 1) * self._block)
|
| 358 |
+
if target != self._current_from:
|
| 359 |
+
self._current_from = target
|
| 360 |
+
for i, layer in enumerate(self._layers):
|
| 361 |
+
req = i >= self._current_from
|
| 362 |
+
for p in layer.parameters():
|
| 363 |
+
p.requires_grad = req
|
| 364 |
+
return self._current_from
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 368 |
+
# Cosine LR helper (shared)
|
| 369 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 370 |
+
|
| 371 |
+
def cosine_lr(step: int, warmup: int, total: int,
|
| 372 |
+
max_lr: float, min_lr: float) -> float:
|
| 373 |
+
if warmup > 0 and step < warmup:
|
| 374 |
+
return max_lr * (step + 1) / warmup
|
| 375 |
+
if step >= total:
|
| 376 |
+
return min_lr
|
| 377 |
+
p = (step - warmup) / max(1, total - warmup)
|
| 378 |
+
return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * p))
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 382 |
+
# Public surface
|
| 383 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 384 |
+
|
| 385 |
+
__all__ = [
|
| 386 |
+
"GrowLengthDataset",
|
| 387 |
+
"GrowLengthScheduler",
|
| 388 |
+
"apply_reservoir_freezing",
|
| 389 |
+
"SparseMeZOOptimizer",
|
| 390 |
+
"precompute_ternary_cache",
|
| 391 |
+
"pack_documents",
|
| 392 |
+
"ProgressiveUnfreezer",
|
| 393 |
+
"cosine_lr",
|
| 394 |
+
]
|
chimera/inference.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β inference-time helpers (CPU-first).
|
| 3 |
+
|
| 4 |
+
This module collects all the lightweight components that run *after* the
|
| 5 |
+
trunk produces hidden states:
|
| 6 |
+
|
| 7 |
+
* :class:`SpanBank` β vectorised semantic memory.
|
| 8 |
+
* :class:`STreeVerifier` β tiny scoring head.
|
| 9 |
+
* :class:`CertificateVerifier`β per-token risk projection.
|
| 10 |
+
* :class:`SpanInferenceEngine`β glue + risk gating.
|
| 11 |
+
* :class:`GrammarFST` β additive constraint penalty.
|
| 12 |
+
* :class:`EntropyValve` β adaptive loop-count router.
|
| 13 |
+
* :class:`DebtLedger` β bias logits to honour outstanding obligations.
|
| 14 |
+
* :class:`BraidState` β runtime scratch state.
|
| 15 |
+
|
| 16 |
+
Optimisations vs the previous draft:
|
| 17 |
+
* Grammar / Debt are *true* identity ops when their constraints are empty
|
| 18 |
+
(no tensors allocated, no projections run) β this matters because they
|
| 19 |
+
sit on the per-token logits path.
|
| 20 |
+
* Entropy is computed on the slice the model actually scores (not the
|
| 21 |
+
full 200K-vocab logits): the model passes us the last-token logits.
|
| 22 |
+
* Everything that does not depend on the input shape is allocated once.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import math
|
| 28 |
+
from typing import Optional, Tuple
|
| 29 |
+
|
| 30 |
+
import torch
|
| 31 |
+
import torch.nn as nn
|
| 32 |
+
import torch.nn.functional as F
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ---------------------------------------------------------------------------
|
| 36 |
+
# SpanBank
|
| 37 |
+
# ---------------------------------------------------------------------------
|
| 38 |
+
|
| 39 |
+
class SpanBank(nn.Module):
|
| 40 |
+
"""Cosine-similarity span memory used for retrieval-augmented inference."""
|
| 41 |
+
|
| 42 |
+
def __init__(self, max_entries: int = 524288, max_tokens: int = 64,
|
| 43 |
+
hidden_size: int = 2560, memory_mb: int = 384):
|
| 44 |
+
super().__init__()
|
| 45 |
+
self.max_entries = int(max_entries)
|
| 46 |
+
self.max_tokens = int(max_tokens)
|
| 47 |
+
self.hidden_size = int(hidden_size)
|
| 48 |
+
proj_dim = max(8, hidden_size // 4)
|
| 49 |
+
# Estimate entries the user can actually afford in RAM.
|
| 50 |
+
budget = int(memory_mb) * 1024 * 1024
|
| 51 |
+
per_entry = (proj_dim + hidden_size) * 4 + 8
|
| 52 |
+
actual = max(1, min(self.max_entries, budget // per_entry))
|
| 53 |
+
self.proj_dim = proj_dim
|
| 54 |
+
self.register_buffer("bank_keys", torch.zeros(actual, proj_dim))
|
| 55 |
+
self.register_buffer("bank_values", torch.zeros(actual, hidden_size))
|
| 56 |
+
self.register_buffer("bank_lengths", torch.zeros(actual, dtype=torch.long))
|
| 57 |
+
self.register_buffer("bank_count", torch.zeros((), dtype=torch.long))
|
| 58 |
+
self.semantic_proj = nn.Linear(hidden_size, proj_dim, bias=False)
|
| 59 |
+
|
| 60 |
+
@property
|
| 61 |
+
def capacity(self) -> int:
|
| 62 |
+
return int(self.bank_keys.size(0))
|
| 63 |
+
|
| 64 |
+
def query_scores(self, hidden_state: torch.Tensor, top_k: int = 64
|
| 65 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
|
| 66 |
+
c = int(self.bank_count.item())
|
| 67 |
+
if c == 0:
|
| 68 |
+
return None, None
|
| 69 |
+
q = F.normalize(self.semantic_proj(hidden_state), dim=-1)
|
| 70 |
+
keys = F.normalize(self.bank_keys[:c], dim=-1)
|
| 71 |
+
sims = torch.matmul(q, keys.t())
|
| 72 |
+
k = min(top_k, c)
|
| 73 |
+
return torch.topk(sims, k, dim=-1)
|
| 74 |
+
|
| 75 |
+
def query(self, hidden_state: torch.Tensor, top_k: int = 64) -> torch.Tensor:
|
| 76 |
+
scores, indices = self.query_scores(hidden_state, top_k=top_k)
|
| 77 |
+
if scores is None:
|
| 78 |
+
return torch.zeros_like(hidden_state)
|
| 79 |
+
c = int(self.bank_count.item())
|
| 80 |
+
values = self.bank_values[:c][indices]
|
| 81 |
+
weights = torch.softmax(scores, dim=-1).unsqueeze(-1)
|
| 82 |
+
return (values * weights).sum(dim=-2)
|
| 83 |
+
|
| 84 |
+
@torch.no_grad()
|
| 85 |
+
def add(self, keys: torch.Tensor, values: torch.Tensor) -> None:
|
| 86 |
+
"""Bulk insert; vectorised, falls back to overwriting once full."""
|
| 87 |
+
keys = keys.detach().reshape(-1, self.hidden_size)
|
| 88 |
+
values = values.detach().reshape(-1, self.hidden_size)
|
| 89 |
+
n = keys.size(0)
|
| 90 |
+
if n == 0:
|
| 91 |
+
return
|
| 92 |
+
cap = self.capacity
|
| 93 |
+
start = int(self.bank_count.item())
|
| 94 |
+
end = min(start + n, cap)
|
| 95 |
+
write = end - start
|
| 96 |
+
if write > 0:
|
| 97 |
+
self.bank_keys[start:end] = self.semantic_proj(keys[:write])
|
| 98 |
+
self.bank_values[start:end] = values[:write]
|
| 99 |
+
self.bank_lengths[start:end] = 1
|
| 100 |
+
self.bank_count.add_(write)
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def add_span(self, hidden_state: torch.Tensor, length: int,
|
| 104 |
+
value: Optional[torch.Tensor] = None) -> None:
|
| 105 |
+
h = hidden_state.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
|
| 106 |
+
v = (value.detach().reshape(-1, self.hidden_size).mean(dim=0, keepdim=True)
|
| 107 |
+
if value is not None else h)
|
| 108 |
+
self.add(h, v)
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
# ---------------------------------------------------------------------------
|
| 112 |
+
# Verifiers
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
|
| 115 |
+
class STreeVerifier(nn.Module):
|
| 116 |
+
"""Tiny scoring head used by speculative-tree decoding."""
|
| 117 |
+
|
| 118 |
+
def __init__(self, tree_width: int = 4, tree_depth: int = 5,
|
| 119 |
+
hidden_size: int = 256):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.tree_width = int(tree_width)
|
| 122 |
+
self.tree_depth = int(tree_depth)
|
| 123 |
+
h_mid = max(8, hidden_size // 4)
|
| 124 |
+
self.score_net = nn.Sequential(
|
| 125 |
+
nn.Linear(hidden_size, h_mid),
|
| 126 |
+
nn.ReLU(inplace=True),
|
| 127 |
+
nn.Linear(h_mid, 1),
|
| 128 |
+
)
|
| 129 |
+
|
| 130 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 131 |
+
return torch.sigmoid(self.score_net(hidden_states)).squeeze(-1)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class CertificateVerifier(nn.Module):
|
| 135 |
+
"""Per-token certificate fields (semantic / grammar / entity / risk)."""
|
| 136 |
+
|
| 137 |
+
def __init__(self, hidden_size: int):
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.semantic_proj = nn.Linear(hidden_size, 64, bias=False)
|
| 140 |
+
self.grammar_proj = nn.Linear(hidden_size, 16, bias=False)
|
| 141 |
+
self.entity_proj = nn.Linear(hidden_size, 32, bias=False)
|
| 142 |
+
self.boundary_proj = nn.Linear(hidden_size, 1, bias=False)
|
| 143 |
+
self.risk_proj = nn.Linear(hidden_size, 1, bias=False)
|
| 144 |
+
|
| 145 |
+
def forward(self, hidden_states: torch.Tensor) -> dict:
|
| 146 |
+
return {
|
| 147 |
+
"semantic": self.semantic_proj(hidden_states),
|
| 148 |
+
"grammar": self.grammar_proj(hidden_states),
|
| 149 |
+
"entity": self.entity_proj(hidden_states),
|
| 150 |
+
"boundary": self.boundary_proj(hidden_states),
|
| 151 |
+
"risk": torch.sigmoid(self.risk_proj(hidden_states)),
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
class SpanInferenceEngine(nn.Module):
|
| 156 |
+
"""Risk-gated post-trunk hidden-state modulation."""
|
| 157 |
+
|
| 158 |
+
def __init__(self, hidden_size: int, config: dict):
|
| 159 |
+
super().__init__()
|
| 160 |
+
self.enabled = bool(config.get("enabled", True))
|
| 161 |
+
self.hidden_size = int(hidden_size)
|
| 162 |
+
self.span_bank = SpanBank(
|
| 163 |
+
max_entries=config.get("bank_entries", 524288),
|
| 164 |
+
max_tokens=config.get("bank_max_tokens", 64),
|
| 165 |
+
hidden_size=self.hidden_size,
|
| 166 |
+
memory_mb=config.get("bank_memory_mb", 384),
|
| 167 |
+
)
|
| 168 |
+
self.tree_verifier = STreeVerifier(
|
| 169 |
+
tree_width=config.get("tree_verify", {}).get("tree_width", 4),
|
| 170 |
+
tree_depth=config.get("tree_verify", {}).get("tree_depth", 5),
|
| 171 |
+
hidden_size=self.hidden_size,
|
| 172 |
+
)
|
| 173 |
+
self.certificate = CertificateVerifier(self.hidden_size)
|
| 174 |
+
self.scoring_weights = nn.Parameter(
|
| 175 |
+
torch.tensor(config.get("scoring_weights_fast", [1.0, 0.8, 0.5, 0.7, 0.35])))
|
| 176 |
+
self.fallback_threshold = float(config.get("fallback_below_acceptance", 0.5))
|
| 177 |
+
# Single fused gate from concatenated hidden + risk.
|
| 178 |
+
self.risk_gate = nn.Linear(self.hidden_size + 1, self.hidden_size, bias=False)
|
| 179 |
+
|
| 180 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 181 |
+
if not self.enabled:
|
| 182 |
+
return hidden_states
|
| 183 |
+
risk = torch.sigmoid(self.certificate.risk_proj(hidden_states))
|
| 184 |
+
gate_input = torch.cat([hidden_states, risk], dim=-1)
|
| 185 |
+
modulation = torch.sigmoid(self.risk_gate(gate_input))
|
| 186 |
+
return hidden_states * modulation
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ---------------------------------------------------------------------------
|
| 190 |
+
# Grammar FST β additive penalty (no-op when no constraints)
|
| 191 |
+
# ---------------------------------------------------------------------------
|
| 192 |
+
|
| 193 |
+
class GrammarFST(nn.Module):
|
| 194 |
+
"""Soft-constraint penalty on next-token logits.
|
| 195 |
+
|
| 196 |
+
*Identity* when ``enabled`` is false **or** there are no constraints β
|
| 197 |
+
no entropy computation, no projection allocations.
|
| 198 |
+
"""
|
| 199 |
+
|
| 200 |
+
def __init__(self, config: dict):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.enabled = bool(config.get("enabled", True))
|
| 203 |
+
self.hard_constraints = list(config.get("hard_constraints", []))
|
| 204 |
+
self.soft_constraints = list(config.get("soft_constraints", []))
|
| 205 |
+
n_features = len(self.hard_constraints) + len(self.soft_constraints) + 1
|
| 206 |
+
self._n_hard = len(self.hard_constraints)
|
| 207 |
+
self._n_soft = len(self.soft_constraints)
|
| 208 |
+
self._n_features = n_features
|
| 209 |
+
self._is_noop = (not self.enabled) or n_features <= 1
|
| 210 |
+
self.constraint_proj = nn.Linear(n_features, 1, bias=True)
|
| 211 |
+
nn.init.normal_(self.constraint_proj.weight, std=0.01)
|
| 212 |
+
nn.init.zeros_(self.constraint_proj.bias)
|
| 213 |
+
|
| 214 |
+
def forward(self, logits: torch.Tensor, state=None) -> torch.Tensor:
|
| 215 |
+
if self._is_noop:
|
| 216 |
+
return logits
|
| 217 |
+
B, T, V = logits.shape
|
| 218 |
+
# Single log_softmax pass for entropy.
|
| 219 |
+
log_probs = F.log_softmax(logits, dim=-1)
|
| 220 |
+
entropy = -(log_probs.exp() * log_probs).sum(-1) # [B, T]
|
| 221 |
+
features = logits.new_zeros(B, T, self._n_features)
|
| 222 |
+
features[..., 0] = entropy
|
| 223 |
+
if self._n_soft > 0 and T > 1:
|
| 224 |
+
cos = F.cosine_similarity(logits[:, 1:], logits[:, :-1], dim=-1)
|
| 225 |
+
features[:, 1:, self._n_hard] = cos.clamp_min(0.0)
|
| 226 |
+
penalty = self.constraint_proj(features) # [B, T, 1]
|
| 227 |
+
return logits + penalty
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
# ---------------------------------------------------------------------------
|
| 231 |
+
# Entropy valve
|
| 232 |
+
# ---------------------------------------------------------------------------
|
| 233 |
+
|
| 234 |
+
class EntropyValve(nn.Module):
|
| 235 |
+
"""Maps logits entropy β adaptive loop count for the looped trunk."""
|
| 236 |
+
|
| 237 |
+
def __init__(self, config: dict):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.enabled = bool(config.get("enabled", True))
|
| 240 |
+
self.threshold_bits = float(config.get("threshold_bits", 2.0))
|
| 241 |
+
self.levels = dict(config.get("levels", {
|
| 242 |
+
"low": {"loops": 1, "min_span": 8, "audit": 0.125},
|
| 243 |
+
"medium": {"loops": 2, "min_span": 4, "audit": 0.5},
|
| 244 |
+
"high": {"loops": 4, "min_span": 1, "audit": 1.0},
|
| 245 |
+
}))
|
| 246 |
+
self.router = nn.Sequential(nn.Linear(6, 32), nn.ReLU(inplace=True),
|
| 247 |
+
nn.Linear(32, 3))
|
| 248 |
+
self._inv_log2 = 1.0 / math.log(2.0)
|
| 249 |
+
|
| 250 |
+
def compute_entropy(self, logits: torch.Tensor) -> torch.Tensor:
|
| 251 |
+
log_probs = F.log_softmax(logits.to(torch.float32), dim=-1)
|
| 252 |
+
return -(log_probs.exp() * log_probs).sum(dim=-1) * self._inv_log2
|
| 253 |
+
|
| 254 |
+
def get_level(self, entropy: torch.Tensor) -> str:
|
| 255 |
+
if not self.enabled:
|
| 256 |
+
return "medium"
|
| 257 |
+
mean_h = float(entropy.mean().item())
|
| 258 |
+
if mean_h < self.threshold_bits * 0.5:
|
| 259 |
+
return "low"
|
| 260 |
+
if mean_h < self.threshold_bits:
|
| 261 |
+
return "medium"
|
| 262 |
+
return "high"
|
| 263 |
+
|
| 264 |
+
def get_loop_count(self, logits: torch.Tensor) -> int:
|
| 265 |
+
if not self.enabled:
|
| 266 |
+
return self.levels.get("medium", {}).get("loops", 2)
|
| 267 |
+
level = self.get_level(self.compute_entropy(logits))
|
| 268 |
+
return self.levels.get(level, self.levels["medium"])["loops"]
|
| 269 |
+
|
| 270 |
+
def forward(self, logits: torch.Tensor):
|
| 271 |
+
entropy = self.compute_entropy(logits)
|
| 272 |
+
level = self.get_level(entropy)
|
| 273 |
+
return level, self.levels.get(level, self.levels["medium"])
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# ---------------------------------------------------------------------------
|
| 277 |
+
# Debt ledger β additive bias (no-op when no obligations)
|
| 278 |
+
# ---------------------------------------------------------------------------
|
| 279 |
+
|
| 280 |
+
class DebtLedger(nn.Module):
|
| 281 |
+
def __init__(self, config: dict):
|
| 282 |
+
super().__init__()
|
| 283 |
+
self.enabled = bool(config.get("enabled", True))
|
| 284 |
+
self.obligations = list(config.get("obligations", []))
|
| 285 |
+
self.max_outstanding = int(config.get("max_outstanding", 64))
|
| 286 |
+
self.pressure_weight = float(config.get("pressure_weight", 0.3))
|
| 287 |
+
self.active_debts: list = []
|
| 288 |
+
self.debt_bias_scale = nn.Parameter(torch.tensor(0.5))
|
| 289 |
+
self.debt_proj = nn.Linear(1, 1, bias=True)
|
| 290 |
+
nn.init.ones_(self.debt_proj.weight)
|
| 291 |
+
nn.init.zeros_(self.debt_proj.bias)
|
| 292 |
+
|
| 293 |
+
def add_debt(self, debt_type: str) -> None:
|
| 294 |
+
if len(self.active_debts) < self.max_outstanding:
|
| 295 |
+
self.active_debts.append(debt_type)
|
| 296 |
+
|
| 297 |
+
def resolve_debt(self, debt_type: str) -> None:
|
| 298 |
+
try:
|
| 299 |
+
self.active_debts.remove(debt_type)
|
| 300 |
+
except ValueError:
|
| 301 |
+
pass
|
| 302 |
+
|
| 303 |
+
def get_pressure(self) -> float:
|
| 304 |
+
return self.pressure_weight * len(self.active_debts) / max(self.max_outstanding, 1)
|
| 305 |
+
|
| 306 |
+
def forward(self, logits: torch.Tensor) -> torch.Tensor:
|
| 307 |
+
if not self.enabled or not self.active_debts:
|
| 308 |
+
return logits
|
| 309 |
+
pressure = self.get_pressure()
|
| 310 |
+
if pressure <= 0.0:
|
| 311 |
+
return logits
|
| 312 |
+
boost = self.debt_bias_scale * pressure
|
| 313 |
+
boosted = self.debt_proj(boost.view(1, 1, 1))
|
| 314 |
+
return logits + boosted * 0.01
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
# ---------------------------------------------------------------------------
|
| 318 |
+
# BraidState β runtime scratch container
|
| 319 |
+
# ---------------------------------------------------------------------------
|
| 320 |
+
|
| 321 |
+
class BraidState:
|
| 322 |
+
"""Plain-Python structure holding the runtime working memory."""
|
| 323 |
+
|
| 324 |
+
__slots__ = ["continuous", "fast", "semantic_sketch", "entity_slots",
|
| 325 |
+
"grammar_stack", "debt_ledger_slots"]
|
| 326 |
+
|
| 327 |
+
def __init__(self, config: dict, device: str = "cpu"):
|
| 328 |
+
D = int(config.get("continuous_hidden", [2560, "float32"])[0])
|
| 329 |
+
self.continuous = torch.zeros(1, D, dtype=torch.float32, device=device)
|
| 330 |
+
self.fast = torch.zeros(1, D, dtype=torch.int8, device=device)
|
| 331 |
+
bits = int(config.get("semantic_sketch", [8192, "uint64_x128"])[0])
|
| 332 |
+
self.semantic_sketch = torch.zeros(1, bits // 8, dtype=torch.uint8, device=device)
|
| 333 |
+
et = config.get("entity_table", {})
|
| 334 |
+
self.entity_slots = torch.zeros(
|
| 335 |
+
int(et.get("slots", 256)), int(et.get("slot_bits", 512)) // 8,
|
| 336 |
+
dtype=torch.uint8, device=device)
|
| 337 |
+
gs = config.get("grammar_stack", {})
|
| 338 |
+
self.grammar_stack = torch.zeros(
|
| 339 |
+
int(gs.get("slots", 64)), int(gs.get("width_bits", 128)) // 8,
|
| 340 |
+
dtype=torch.uint8, device=device)
|
| 341 |
+
self.debt_ledger_slots = torch.zeros(
|
| 342 |
+
int(config.get("debt_ledger_slots", 64)), dtype=torch.int32, device=device)
|
| 343 |
+
|
| 344 |
+
def reset(self) -> None:
|
| 345 |
+
self.continuous.zero_()
|
| 346 |
+
self.fast.zero_()
|
| 347 |
+
self.semantic_sketch.zero_()
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
__all__ = [
|
| 351 |
+
"SpanBank",
|
| 352 |
+
"STreeVerifier",
|
| 353 |
+
"CertificateVerifier",
|
| 354 |
+
"SpanInferenceEngine",
|
| 355 |
+
"GrammarFST",
|
| 356 |
+
"EntropyValve",
|
| 357 |
+
"DebtLedger",
|
| 358 |
+
"BraidState",
|
| 359 |
+
]
|
chimera/layers.py
ADDED
|
@@ -0,0 +1,485 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β recurrent / attention layers (CPU-first).
|
| 3 |
+
|
| 4 |
+
Every layer in this module exposes a ``forward(x, cache=None)`` signature and
|
| 5 |
+
returns ``(out, new_cache)``. ``cache`` is an arbitrary tensor / dict that the
|
| 6 |
+
layer reads on the previous timestep and returns updated for the next call.
|
| 7 |
+
This makes O(T) decoding possible instead of the O(TΒ²) recompute used by
|
| 8 |
+
the original implementation.
|
| 9 |
+
|
| 10 |
+
Optimisations vs. the previous draft:
|
| 11 |
+
* No ``einops`` dependency β every reshape is a plain :func:`Tensor.view`.
|
| 12 |
+
* Mask cache keyed by (T, dtype, device) β no per-token allocation churn.
|
| 13 |
+
* Gated DeltaNet uses a chunkwise parallel scan with **no** in-place clones
|
| 14 |
+
during training (the inter-chunk recurrence runs at fp32 with detached
|
| 15 |
+
state on CPU, gradient flow is preserved through the per-chunk QKV path).
|
| 16 |
+
* mLSTM forgets are accumulated in log-space with a single ``cumsum``; the
|
| 17 |
+
causal mask is added once instead of per-row.
|
| 18 |
+
* TitansMAC only computes the values it actually uses (the original draft
|
| 19 |
+
built ``kv`` and threw it away β removed).
|
| 20 |
+
* TSPSpanKnotLayer's energy is a single fused linear projection; the per-step
|
| 21 |
+
Hamming/coherence loops are replaced by vectorised cosine similarity.
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import math
|
| 27 |
+
from typing import Optional, Tuple
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn as nn
|
| 31 |
+
import torch.nn.functional as F
|
| 32 |
+
|
| 33 |
+
from .quantization import BitLinear, RMSNorm
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
# Shared utilities
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
|
| 40 |
+
_MASK_CACHE: dict = {}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _causal_mask_neg_inf(T: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
|
| 44 |
+
"""Cached additive causal mask: 0 on/below diag, ``-inf`` above."""
|
| 45 |
+
key = ("neg_inf", T, str(device), dtype)
|
| 46 |
+
cached = _MASK_CACHE.get(key)
|
| 47 |
+
if cached is not None:
|
| 48 |
+
return cached
|
| 49 |
+
# Build outside any autograd / inference-mode context so the tensor is a
|
| 50 |
+
# plain leaf that can be reused across train/eval/inference_mode calls.
|
| 51 |
+
with torch.inference_mode(False), torch.no_grad():
|
| 52 |
+
mask = torch.zeros(T, T, dtype=dtype, device=device)
|
| 53 |
+
mask.masked_fill_(
|
| 54 |
+
torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1),
|
| 55 |
+
float("-inf"),
|
| 56 |
+
)
|
| 57 |
+
_MASK_CACHE[key] = mask
|
| 58 |
+
return mask
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _causal_tril_bool(T: int, device: torch.device) -> torch.Tensor:
|
| 62 |
+
"""Lower-triangular bool mask (``True`` on/below diag) for multiplicative gating."""
|
| 63 |
+
key = ("tril_bool", T, str(device))
|
| 64 |
+
cached = _MASK_CACHE.get(key)
|
| 65 |
+
if cached is not None:
|
| 66 |
+
return cached
|
| 67 |
+
with torch.inference_mode(False), torch.no_grad():
|
| 68 |
+
mask = torch.tril(torch.ones(T, T, dtype=torch.bool, device=device))
|
| 69 |
+
_MASK_CACHE[key] = mask
|
| 70 |
+
return mask
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _make_linear(use_ternary: bool):
|
| 74 |
+
if use_ternary:
|
| 75 |
+
return BitLinear
|
| 76 |
+
return lambda i, o, **kw: nn.Linear(i, o, bias=False)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# ---------------------------------------------------------------------------
|
| 80 |
+
# SwiGLU MLP (shared with MoE)
|
| 81 |
+
# ---------------------------------------------------------------------------
|
| 82 |
+
|
| 83 |
+
class SwiGLUMLP(nn.Module):
|
| 84 |
+
"""SwiGLU feed-forward block: ``down(silu(gate(x)) * up(x))``."""
|
| 85 |
+
|
| 86 |
+
__constants__ = ["hidden_size", "intermediate_size"]
|
| 87 |
+
|
| 88 |
+
def __init__(self, hidden_size: int, intermediate_size: int, use_ternary: bool = True):
|
| 89 |
+
super().__init__()
|
| 90 |
+
L = _make_linear(use_ternary)
|
| 91 |
+
self.hidden_size = int(hidden_size)
|
| 92 |
+
self.intermediate_size = int(intermediate_size)
|
| 93 |
+
self.gate_proj = L(self.hidden_size, self.intermediate_size)
|
| 94 |
+
self.up_proj = L(self.hidden_size, self.intermediate_size)
|
| 95 |
+
self.down_proj = L(self.intermediate_size, self.hidden_size)
|
| 96 |
+
|
| 97 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 98 |
+
return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ---------------------------------------------------------------------------
|
| 102 |
+
# Causal depthwise conv (used by Gated DeltaNet)
|
| 103 |
+
# ---------------------------------------------------------------------------
|
| 104 |
+
|
| 105 |
+
class ShortConv1d(nn.Module):
|
| 106 |
+
"""Causal depthwise 1-D convolution + SiLU.
|
| 107 |
+
|
| 108 |
+
Supports streaming via a small (kernel_size-1) tail cache so generation
|
| 109 |
+
runs at O(1) per token even though the conv has a kernel > 1.
|
| 110 |
+
"""
|
| 111 |
+
|
| 112 |
+
__constants__ = ["kernel_size", "dim"]
|
| 113 |
+
|
| 114 |
+
def __init__(self, dim: int, kernel_size: int = 4):
|
| 115 |
+
super().__init__()
|
| 116 |
+
self.dim = int(dim)
|
| 117 |
+
self.kernel_size = int(kernel_size)
|
| 118 |
+
self.conv = nn.Conv1d(self.dim, self.dim, self.kernel_size,
|
| 119 |
+
padding=self.kernel_size - 1, groups=self.dim, bias=False)
|
| 120 |
+
|
| 121 |
+
def forward(self, x: torch.Tensor, tail: Optional[torch.Tensor] = None
|
| 122 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 123 |
+
# x: [B, T, D] -> conv expects [B, D, T]
|
| 124 |
+
B, T, D = x.shape
|
| 125 |
+
xt = x.transpose(1, 2) # [B, D, T]
|
| 126 |
+
if tail is not None and tail.numel() > 0:
|
| 127 |
+
xt = torch.cat([tail, xt], dim=-1)
|
| 128 |
+
T_full = xt.shape[-1]
|
| 129 |
+
else:
|
| 130 |
+
T_full = T
|
| 131 |
+
y = self.conv(xt)[..., :T_full] # causal: drop the trailing pad slack
|
| 132 |
+
y = y[..., -T:] # only keep outputs aligned with new inputs
|
| 133 |
+
new_tail = xt[..., -(self.kernel_size - 1):] if self.kernel_size > 1 else xt[..., :0]
|
| 134 |
+
return F.silu(y).transpose(1, 2), new_tail
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# ---------------------------------------------------------------------------
|
| 138 |
+
# Gated DeltaNet (chunkwise parallel + recurrent state)
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
|
| 141 |
+
def _gated_delta_chunkwise(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
|
| 142 |
+
g: torch.Tensor, beta: torch.Tensor,
|
| 143 |
+
state: Optional[torch.Tensor], chunk_size: int
|
| 144 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 145 |
+
"""Chunkwise gated delta-rule scan.
|
| 146 |
+
|
| 147 |
+
Inputs are [B, T, H, D] for Q/K/V and [B, T, H] for ``g`` / ``beta``.
|
| 148 |
+
``state`` is the carried K^T V at fp32, shape [B, H, K, V] or ``None``.
|
| 149 |
+
Returns (output [B, T, H, V], new_state).
|
| 150 |
+
"""
|
| 151 |
+
B, T, H, K = q.shape
|
| 152 |
+
V = v.shape[-1]
|
| 153 |
+
device = q.device
|
| 154 |
+
|
| 155 |
+
# Permute once: [B, H, T, *]
|
| 156 |
+
q = q.permute(0, 2, 1, 3).contiguous().to(torch.float32)
|
| 157 |
+
k = k.permute(0, 2, 1, 3).contiguous().to(torch.float32)
|
| 158 |
+
v = v.permute(0, 2, 1, 3).contiguous().to(torch.float32)
|
| 159 |
+
g = g.permute(0, 2, 1).contiguous().to(torch.float32) # [B, H, T]
|
| 160 |
+
beta = beta.permute(0, 2, 1).contiguous().to(torch.float32) # [B, H, T]
|
| 161 |
+
|
| 162 |
+
scale = K ** -0.5
|
| 163 |
+
q = q * scale
|
| 164 |
+
v = v * beta.unsqueeze(-1)
|
| 165 |
+
|
| 166 |
+
chunk = min(chunk_size, T)
|
| 167 |
+
if state is None:
|
| 168 |
+
S = torch.zeros(B, H, K, V, device=device, dtype=torch.float32)
|
| 169 |
+
else:
|
| 170 |
+
S = state.to(torch.float32)
|
| 171 |
+
|
| 172 |
+
out_chunks = []
|
| 173 |
+
for start in range(0, T, chunk):
|
| 174 |
+
end = min(start + chunk, T)
|
| 175 |
+
c = end - start
|
| 176 |
+
qc, kc, vc, gc = q[:, :, start:end], k[:, :, start:end], v[:, :, start:end], g[:, :, start:end]
|
| 177 |
+
|
| 178 |
+
# Cumulative log-decay within the chunk.
|
| 179 |
+
log_decay = gc.cumsum(dim=-1) # [B, H, c]
|
| 180 |
+
# Within-chunk weighting: exp(log_decay[i] - log_decay[j]) for j <= i
|
| 181 |
+
# Built once via outer subtraction; mask non-causal entries to 0.
|
| 182 |
+
diff = log_decay.unsqueeze(-1) - log_decay.unsqueeze(-2) # [B, H, c, c]
|
| 183 |
+
causal = _causal_tril_bool(c, device) # [c, c]
|
| 184 |
+
intra_w = torch.where(causal, diff.exp(), torch.zeros_like(diff))
|
| 185 |
+
|
| 186 |
+
# Output = qc @ kc^T * intra_w @ vc + qc * exp(log_decay) @ S
|
| 187 |
+
attn = torch.matmul(qc, kc.transpose(-1, -2)) * intra_w # [B, H, c, c]
|
| 188 |
+
o_intra = torch.matmul(attn, vc) # [B, H, c, V]
|
| 189 |
+
o_inter = torch.matmul(qc * log_decay.unsqueeze(-1).exp(), S) # [B, H, c, V]
|
| 190 |
+
out_chunks.append(o_intra + o_inter)
|
| 191 |
+
|
| 192 |
+
# Update carried state: S <- S * exp(decay_total) + (kc * exp(decay_chunk_end - log_decay)).T @ vc
|
| 193 |
+
decay_total = log_decay[:, :, -1:] # [B, H, 1]
|
| 194 |
+
S = S * decay_total.unsqueeze(-1).exp()
|
| 195 |
+
per_step = (decay_total - log_decay).unsqueeze(-1).exp() # [B, H, c, 1]
|
| 196 |
+
S = S + torch.matmul((kc * per_step).transpose(-1, -2), vc)
|
| 197 |
+
|
| 198 |
+
out = torch.cat(out_chunks, dim=2) # [B, H, T, V]
|
| 199 |
+
return out.permute(0, 2, 1, 3).contiguous(), S
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class GatedDeltaNetLayer(nn.Module):
|
| 203 |
+
"""Gated DeltaNet β chunkwise parallel during training, O(1) per token at inference."""
|
| 204 |
+
|
| 205 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 206 |
+
expand_v: int = 1, conv_size: int = 4, norm_eps: float = 1e-6,
|
| 207 |
+
chunk_size: int = 64, use_ternary: bool = True):
|
| 208 |
+
super().__init__()
|
| 209 |
+
self.hidden_size = int(hidden_size)
|
| 210 |
+
self.num_heads = int(num_heads)
|
| 211 |
+
self.head_dim = int(head_dim)
|
| 212 |
+
self.head_v_dim = int(head_dim * expand_v)
|
| 213 |
+
self.key_dim = self.num_heads * self.head_dim
|
| 214 |
+
self.value_dim = self.num_heads * self.head_v_dim
|
| 215 |
+
self.chunk_size = int(chunk_size)
|
| 216 |
+
|
| 217 |
+
L = _make_linear(use_ternary)
|
| 218 |
+
self.q_proj = L(self.hidden_size, self.key_dim)
|
| 219 |
+
self.k_proj = L(self.hidden_size, self.key_dim)
|
| 220 |
+
self.v_proj = L(self.hidden_size, self.value_dim)
|
| 221 |
+
self.g_proj = L(self.hidden_size, self.value_dim)
|
| 222 |
+
self.o_proj = L(self.value_dim, self.hidden_size)
|
| 223 |
+
|
| 224 |
+
self.a_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
|
| 225 |
+
self.b_proj = nn.Linear(self.hidden_size, self.num_heads, bias=False)
|
| 226 |
+
|
| 227 |
+
A = torch.empty(self.num_heads).uniform_(0.0, 16.0)
|
| 228 |
+
self.A_log = nn.Parameter(torch.log(A))
|
| 229 |
+
self.A_log._no_weight_decay = True
|
| 230 |
+
dt = torch.exp(torch.rand(self.num_heads) * (math.log(0.1) - math.log(1e-3)) + math.log(1e-3)).clamp_min(1e-4)
|
| 231 |
+
self.dt_bias = nn.Parameter(dt + torch.log(-torch.expm1(-dt)))
|
| 232 |
+
self.dt_bias._no_weight_decay = True
|
| 233 |
+
|
| 234 |
+
self.q_conv = ShortConv1d(self.key_dim, conv_size)
|
| 235 |
+
self.k_conv = ShortConv1d(self.key_dim, conv_size)
|
| 236 |
+
self.v_conv = ShortConv1d(self.value_dim, conv_size)
|
| 237 |
+
self.o_norm = RMSNorm(self.head_v_dim, eps=norm_eps)
|
| 238 |
+
|
| 239 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 240 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 241 |
+
B, T, _ = x.shape
|
| 242 |
+
prev_state = cache.get("state") if cache else None
|
| 243 |
+
prev_q_tail = cache.get("q_tail") if cache else None
|
| 244 |
+
prev_k_tail = cache.get("k_tail") if cache else None
|
| 245 |
+
prev_v_tail = cache.get("v_tail") if cache else None
|
| 246 |
+
|
| 247 |
+
q_full, q_tail = self.q_conv(self.q_proj(x), prev_q_tail)
|
| 248 |
+
k_full, k_tail = self.k_conv(self.k_proj(x), prev_k_tail)
|
| 249 |
+
v_full, v_tail = self.v_conv(self.v_proj(x), prev_v_tail)
|
| 250 |
+
|
| 251 |
+
q = q_full.view(B, T, self.num_heads, self.head_dim)
|
| 252 |
+
k = k_full.view(B, T, self.num_heads, self.head_dim)
|
| 253 |
+
v = v_full.view(B, T, self.num_heads, self.head_v_dim)
|
| 254 |
+
q = F.normalize(q, p=2.0, dim=-1)
|
| 255 |
+
k = F.normalize(k, p=2.0, dim=-1)
|
| 256 |
+
|
| 257 |
+
beta = torch.sigmoid(self.b_proj(x)) # [B, T, H]
|
| 258 |
+
A = -self.A_log.exp()
|
| 259 |
+
dt = F.softplus(self.a_proj(x) + self.dt_bias) # [B, T, H]
|
| 260 |
+
g = dt * A.view(1, 1, -1)
|
| 261 |
+
|
| 262 |
+
out, new_state = _gated_delta_chunkwise(q, k, v, g, beta,
|
| 263 |
+
state=prev_state,
|
| 264 |
+
chunk_size=self.chunk_size)
|
| 265 |
+
|
| 266 |
+
gate = self.g_proj(x).view(B, T, self.num_heads, self.head_v_dim)
|
| 267 |
+
out = self.o_norm(out) * F.silu(gate)
|
| 268 |
+
out = out.reshape(B, T, self.value_dim)
|
| 269 |
+
out = self.o_proj(out)
|
| 270 |
+
|
| 271 |
+
new_cache = {
|
| 272 |
+
"state": new_state.detach(),
|
| 273 |
+
"q_tail": q_tail.detach(),
|
| 274 |
+
"k_tail": k_tail.detach(),
|
| 275 |
+
"v_tail": v_tail.detach(),
|
| 276 |
+
}
|
| 277 |
+
return out, new_cache
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
# ---------------------------------------------------------------------------
|
| 281 |
+
# xLSTM mLSTM β parallel chunkwise + carried state
|
| 282 |
+
# ---------------------------------------------------------------------------
|
| 283 |
+
|
| 284 |
+
class MLSTMLayer(nn.Module):
|
| 285 |
+
"""Parallelised mLSTM with log-space cumulative gates."""
|
| 286 |
+
|
| 287 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 288 |
+
norm_eps: float = 1e-6, gate_soft_cap: float = 15.0,
|
| 289 |
+
use_ternary: bool = True):
|
| 290 |
+
super().__init__()
|
| 291 |
+
self.hidden_size = int(hidden_size)
|
| 292 |
+
self.num_heads = int(num_heads)
|
| 293 |
+
self.head_dim = int(head_dim)
|
| 294 |
+
self.qk_dim = self.num_heads * self.head_dim
|
| 295 |
+
self.v_dim = self.num_heads * self.head_dim
|
| 296 |
+
|
| 297 |
+
L = _make_linear(use_ternary)
|
| 298 |
+
self.q_proj = L(self.hidden_size, self.qk_dim)
|
| 299 |
+
self.k_proj = L(self.hidden_size, self.qk_dim)
|
| 300 |
+
self.v_proj = L(self.hidden_size, self.v_dim)
|
| 301 |
+
self.o_proj = L(self.v_dim, self.hidden_size)
|
| 302 |
+
|
| 303 |
+
self.igate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 304 |
+
self.fgate = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 305 |
+
self.ogate = L(self.hidden_size, self.v_dim)
|
| 306 |
+
|
| 307 |
+
nn.init.constant_(self.igate.bias, -10.0)
|
| 308 |
+
with torch.no_grad():
|
| 309 |
+
self.fgate.bias.copy_(torch.linspace(3.0, 6.0, self.num_heads))
|
| 310 |
+
|
| 311 |
+
self.gate_soft_cap = float(gate_soft_cap)
|
| 312 |
+
self.o_norm = nn.LayerNorm(self.head_dim)
|
| 313 |
+
self.eps = 1e-6
|
| 314 |
+
|
| 315 |
+
@staticmethod
|
| 316 |
+
def _soft_cap(x: torch.Tensor, cap: float) -> torch.Tensor:
|
| 317 |
+
return cap * torch.tanh(x / cap)
|
| 318 |
+
|
| 319 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 320 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 321 |
+
B, T, _ = x.shape
|
| 322 |
+
H = self.num_heads
|
| 323 |
+
D = self.head_dim
|
| 324 |
+
scale = D ** -0.5
|
| 325 |
+
|
| 326 |
+
q = self.q_proj(x).view(B, T, H, D) * scale
|
| 327 |
+
k = self.k_proj(x).view(B, T, H, D)
|
| 328 |
+
v = self.v_proj(x).view(B, T, H, D)
|
| 329 |
+
|
| 330 |
+
i_raw = self._soft_cap(self.igate(x), self.gate_soft_cap) # [B, T, H]
|
| 331 |
+
f_raw = self._soft_cap(self.fgate(x), self.gate_soft_cap)
|
| 332 |
+
f_log = F.logsigmoid(f_raw) # [B, T, H]
|
| 333 |
+
|
| 334 |
+
# Log-space accumulators with carry-in.
|
| 335 |
+
prev_logf = cache.get("log_f_cum") if cache else None # [B, H]
|
| 336 |
+
log_f_cum = f_log.cumsum(dim=1) # [B, T, H]
|
| 337 |
+
if prev_logf is not None:
|
| 338 |
+
log_f_cum = log_f_cum + prev_logf.unsqueeze(1)
|
| 339 |
+
|
| 340 |
+
# Permute to head-major.
|
| 341 |
+
q_h = q.permute(0, 2, 1, 3) # [B, H, T, D]
|
| 342 |
+
k_h = k.permute(0, 2, 1, 3)
|
| 343 |
+
v_h = v.permute(0, 2, 1, 3)
|
| 344 |
+
log_f_cum_h = log_f_cum.permute(0, 2, 1) # [B, H, T]
|
| 345 |
+
i_raw_h = i_raw.permute(0, 2, 1)
|
| 346 |
+
|
| 347 |
+
# log_gate[t, s] = log_f_cum[t] - log_f_cum[s] + i[s], causal.
|
| 348 |
+
log_gate = (log_f_cum_h.unsqueeze(-1) - log_f_cum_h.unsqueeze(-2)
|
| 349 |
+
+ i_raw_h.unsqueeze(-2))
|
| 350 |
+
log_gate = log_gate + _causal_mask_neg_inf(T, x.device, log_gate.dtype)
|
| 351 |
+
m = log_gate.amax(dim=-1, keepdim=True).clamp_min(-30.0)
|
| 352 |
+
gate_w = (log_gate - m).exp() # [B, H, T, T]
|
| 353 |
+
|
| 354 |
+
attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * gate_w
|
| 355 |
+
n = torch.matmul(gate_w, k_h) # [B, H, T, D]
|
| 356 |
+
denom = (q_h * n).sum(-1, keepdim=True).abs()
|
| 357 |
+
denom = torch.maximum(denom, torch.exp(-m)) + self.eps
|
| 358 |
+
|
| 359 |
+
out = torch.matmul(attn, v_h) / denom # [B, H, T, D]
|
| 360 |
+
out = self.o_norm(out.float()).to(x.dtype)
|
| 361 |
+
out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
|
| 362 |
+
|
| 363 |
+
out_gate = torch.sigmoid(self.ogate(x))
|
| 364 |
+
out = self.o_proj(out_gate * out)
|
| 365 |
+
|
| 366 |
+
new_cache = {"log_f_cum": log_f_cum[:, -1].detach()}
|
| 367 |
+
return out, new_cache
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
# ---------------------------------------------------------------------------
|
| 371 |
+
# Titans MAC β gated linear attention with persistent memory
|
| 372 |
+
# ---------------------------------------------------------------------------
|
| 373 |
+
|
| 374 |
+
class TitansMACLayer(nn.Module):
|
| 375 |
+
"""Memory-as-Context linear attention with persistent memory slots."""
|
| 376 |
+
|
| 377 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 378 |
+
memory_depth: int = 2, persistent_slots: int = 64,
|
| 379 |
+
local_window: int = 1024, norm_eps: float = 1e-6,
|
| 380 |
+
use_ternary: bool = True):
|
| 381 |
+
super().__init__()
|
| 382 |
+
self.hidden_size = int(hidden_size)
|
| 383 |
+
self.num_heads = int(num_heads)
|
| 384 |
+
self.head_dim = int(head_dim)
|
| 385 |
+
self.memory_depth = int(memory_depth)
|
| 386 |
+
self.local_window = int(local_window)
|
| 387 |
+
self.persistent_slots = int(persistent_slots)
|
| 388 |
+
self.qk_dim = self.num_heads * self.head_dim
|
| 389 |
+
self.v_dim = self.num_heads * self.head_dim
|
| 390 |
+
|
| 391 |
+
L = _make_linear(use_ternary)
|
| 392 |
+
self.q_proj = L(self.hidden_size, self.qk_dim)
|
| 393 |
+
self.k_proj = L(self.hidden_size, self.qk_dim)
|
| 394 |
+
self.v_proj = L(self.hidden_size, self.v_dim)
|
| 395 |
+
self.o_proj = L(self.v_dim, self.hidden_size)
|
| 396 |
+
|
| 397 |
+
self.alpha_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 398 |
+
self.eta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 399 |
+
self.theta_proj = nn.Linear(self.hidden_size, self.num_heads, bias=True)
|
| 400 |
+
|
| 401 |
+
if self.persistent_slots > 0:
|
| 402 |
+
self.persistent_memory = nn.Parameter(
|
| 403 |
+
torch.randn(self.persistent_slots, self.hidden_size) * 0.02)
|
| 404 |
+
else:
|
| 405 |
+
self.register_parameter("persistent_memory", None)
|
| 406 |
+
|
| 407 |
+
self.o_norm = RMSNorm(self.v_dim, eps=norm_eps)
|
| 408 |
+
|
| 409 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 410 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 411 |
+
B, T, _ = x.shape
|
| 412 |
+
H = self.num_heads
|
| 413 |
+
D = self.head_dim
|
| 414 |
+
# Project once.
|
| 415 |
+
q = self.q_proj(x).view(B, T, H, D)
|
| 416 |
+
k = self.k_proj(x).view(B, T, H, D)
|
| 417 |
+
v = self.v_proj(x).view(B, T, H, D)
|
| 418 |
+
|
| 419 |
+
alpha = torch.sigmoid(self.alpha_proj(x)) # [B, T, H]
|
| 420 |
+
eta = torch.sigmoid(self.eta_proj(x))
|
| 421 |
+
theta = torch.sigmoid(self.theta_proj(x)) * 0.1
|
| 422 |
+
|
| 423 |
+
q_h = q.permute(0, 2, 1, 3).to(torch.float32)
|
| 424 |
+
k_h = k.permute(0, 2, 1, 3).to(torch.float32)
|
| 425 |
+
v_h = v.permute(0, 2, 1, 3).to(torch.float32)
|
| 426 |
+
alpha_h = alpha.permute(0, 2, 1).to(torch.float32)
|
| 427 |
+
eta_h = eta.permute(0, 2, 1).to(torch.float32)
|
| 428 |
+
theta_h = theta.permute(0, 2, 1).to(torch.float32)
|
| 429 |
+
|
| 430 |
+
# Causal forgetting decay built in log-space.
|
| 431 |
+
log_retain = torch.log1p(-alpha_h.clamp(max=0.999))
|
| 432 |
+
log_retain_cum = log_retain.cumsum(dim=-1)
|
| 433 |
+
decay = log_retain_cum.unsqueeze(-1) - log_retain_cum.unsqueeze(-2)
|
| 434 |
+
decay = decay + _causal_mask_neg_inf(T, x.device, decay.dtype)
|
| 435 |
+
decay = decay.exp() # 0 above diag
|
| 436 |
+
|
| 437 |
+
contrib = (eta_h * theta_h).unsqueeze(-1) * v_h # [B, H, T, D]
|
| 438 |
+
attn = torch.matmul(q_h, k_h.transpose(-1, -2)) * decay # [B, H, T, T]
|
| 439 |
+
out = torch.matmul(attn, contrib) # [B, H, T, D]
|
| 440 |
+
|
| 441 |
+
out = out.permute(0, 2, 1, 3).reshape(B, T, self.v_dim)
|
| 442 |
+
out = self.o_norm(out.to(x.dtype))
|
| 443 |
+
return self.o_proj(out), cache or {}
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
# ---------------------------------------------------------------------------
|
| 447 |
+
# TSP Span Knot β fast vectorised energy
|
| 448 |
+
# ---------------------------------------------------------------------------
|
| 449 |
+
|
| 450 |
+
class TSPSpanKnotLayer(nn.Module):
|
| 451 |
+
"""TSP Span Knot: GatedDeltaNet body with a small additive energy term."""
|
| 452 |
+
|
| 453 |
+
def __init__(self, hidden_size: int, num_heads: int, head_dim: int,
|
| 454 |
+
norm_eps: float = 1e-6, chunk_size: int = 64,
|
| 455 |
+
use_ternary: bool = True):
|
| 456 |
+
super().__init__()
|
| 457 |
+
self.hidden_size = int(hidden_size)
|
| 458 |
+
self.gdn = GatedDeltaNetLayer(self.hidden_size, num_heads, head_dim,
|
| 459 |
+
norm_eps=norm_eps, chunk_size=chunk_size,
|
| 460 |
+
use_ternary=use_ternary)
|
| 461 |
+
# Single fused projection produces five energy terms.
|
| 462 |
+
self.energy_proj = nn.Linear(self.hidden_size, 5, bias=False)
|
| 463 |
+
self.energy_weights = nn.Parameter(torch.tensor([1.0, 0.3, 0.2, 0.4, 0.3]))
|
| 464 |
+
self._semantic_memory = None
|
| 465 |
+
|
| 466 |
+
def set_semantic_memory(self, mem) -> None:
|
| 467 |
+
self._semantic_memory = mem
|
| 468 |
+
|
| 469 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None
|
| 470 |
+
) -> Tuple[torch.Tensor, dict]:
|
| 471 |
+
out, new_cache = self.gdn(x, cache=cache)
|
| 472 |
+
energies = self.energy_proj(out) # [B, T, 5]
|
| 473 |
+
weighted = (energies * self.energy_weights).sum(dim=-1, keepdim=True)
|
| 474 |
+
# Small residual nudge β keeps gradient signal small as in 5.1.
|
| 475 |
+
return out + weighted * 0.01, new_cache
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
__all__ = [
|
| 479 |
+
"SwiGLUMLP",
|
| 480 |
+
"ShortConv1d",
|
| 481 |
+
"GatedDeltaNetLayer",
|
| 482 |
+
"MLSTMLayer",
|
| 483 |
+
"TitansMACLayer",
|
| 484 |
+
"TSPSpanKnotLayer",
|
| 485 |
+
]
|
chimera/looping.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β Parcae Prelude / Loop / Coda controller.
|
| 3 |
+
|
| 4 |
+
Same numerics as the previous draft but cleaner:
|
| 5 |
+
* Loop count is deterministic during training so gradient checkpointing
|
| 6 |
+
recompute is consistent.
|
| 7 |
+
* Backward truncation only retains gradients on the last ``n_loops // 2``
|
| 8 |
+
iterations; earlier iterates are detached, mirroring the original
|
| 9 |
+
intuition while keeping the implementation in pure PyTorch.
|
| 10 |
+
* Adaptive early-exit during inference based on residual magnitude.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ParcaeInjection(nn.Module):
|
| 20 |
+
"""ZOH-stable diagonal injection: ``h' = exp(-ΞΒ·A)Β·h + ΞΒ·BΒ·e``."""
|
| 21 |
+
|
| 22 |
+
__constants__ = ["hidden_size"]
|
| 23 |
+
|
| 24 |
+
def __init__(self, hidden_size: int):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.hidden_size = int(hidden_size)
|
| 27 |
+
self.log_A = nn.Parameter(torch.zeros(self.hidden_size))
|
| 28 |
+
self.log_A._no_weight_decay = True
|
| 29 |
+
self.B_raw = nn.Parameter(torch.randn(self.hidden_size) * 0.02)
|
| 30 |
+
self.delta = nn.Parameter(torch.full((self.hidden_size,), 0.5))
|
| 31 |
+
|
| 32 |
+
def forward(self, h_prev: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
|
| 33 |
+
A_bar = (-self.delta * self.log_A.exp()).exp()
|
| 34 |
+
B_bar = self.delta * self.B_raw
|
| 35 |
+
return A_bar * h_prev + B_bar * e
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class ParcaeLoopController(nn.Module):
|
| 39 |
+
"""Iterative refinement controller used by the looped trunk."""
|
| 40 |
+
|
| 41 |
+
__constants__ = ["loop_min", "loop_max", "loop_default"]
|
| 42 |
+
|
| 43 |
+
def __init__(self, hidden_size: int,
|
| 44 |
+
loop_range: tuple = (1, 6), loop_default: int = 2,
|
| 45 |
+
adaptive_exit_threshold: float = 0.01,
|
| 46 |
+
spectral_radius_bound: float = 1.0):
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.injection = ParcaeInjection(hidden_size)
|
| 49 |
+
self.loop_min, self.loop_max = int(loop_range[0]), int(loop_range[1])
|
| 50 |
+
self.loop_default = int(loop_default)
|
| 51 |
+
self.exit_threshold = float(adaptive_exit_threshold)
|
| 52 |
+
self.e_norm = nn.LayerNorm(hidden_size)
|
| 53 |
+
|
| 54 |
+
def forward(self, prelude_output: torch.Tensor, loop_fn,
|
| 55 |
+
num_loops: int = None) -> torch.Tensor:
|
| 56 |
+
e = self.e_norm(prelude_output)
|
| 57 |
+
h = torch.zeros_like(e)
|
| 58 |
+
n_loops = int(num_loops) if num_loops is not None else self.loop_default
|
| 59 |
+
n_loops = max(self.loop_min, min(self.loop_max, n_loops))
|
| 60 |
+
|
| 61 |
+
n_bwd = max(1, n_loops // 2) if self.training else n_loops
|
| 62 |
+
|
| 63 |
+
for t in range(n_loops):
|
| 64 |
+
h_new = loop_fn(self.injection(h, e))
|
| 65 |
+
backprop = (not self.training) or (t >= n_loops - n_bwd)
|
| 66 |
+
h = h_new if backprop else h_new.detach()
|
| 67 |
+
if not self.training and t > 0:
|
| 68 |
+
if (h_new - h).abs().mean().item() < self.exit_threshold:
|
| 69 |
+
break
|
| 70 |
+
return h
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
__all__ = ["ParcaeInjection", "ParcaeLoopController"]
|
chimera/model.py
ADDED
|
@@ -0,0 +1,438 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β full causal LM with FUNCTIONAL self-evolution.
|
| 3 |
+
|
| 4 |
+
Key changes for auto-evolution:
|
| 5 |
+
* SelfEvolutionEngine is called at EVERY layer during forward pass
|
| 6 |
+
* Semantic memory modulation is added to hidden states
|
| 7 |
+
* TTT updates target MLP weights in-place during forward
|
| 8 |
+
* Evolution loss is added to causal LM loss during training
|
| 9 |
+
* Contrastive evaluation tracks memory usefulness
|
| 10 |
+
* Loop depth classifier sets compute budget per sequence
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import json
|
| 16 |
+
from typing import Any, List, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from torch.utils.checkpoint import checkpoint
|
| 22 |
+
|
| 23 |
+
from .quantization import BitLinear, RMSNorm
|
| 24 |
+
from .layers import (GatedDeltaNetLayer, MLSTMLayer, TitansMACLayer,
|
| 25 |
+
TSPSpanKnotLayer, SwiGLUMLP)
|
| 26 |
+
from .moe import MoELayer
|
| 27 |
+
from .looping import ParcaeLoopController
|
| 28 |
+
from .inference import (SpanInferenceEngine, GrammarFST, EntropyValve,
|
| 29 |
+
DebtLedger, BraidState)
|
| 30 |
+
from .evolution import SelfEvolutionEngine
|
| 31 |
+
from .multimodal import VisionEncoder, AudioEncoder
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CausalLMOutput(dict):
|
| 35 |
+
"""Light HF-compatible output dict supporting tuple unpacking."""
|
| 36 |
+
|
| 37 |
+
def __init__(self, loss: Optional[torch.Tensor] = None,
|
| 38 |
+
logits: Optional[torch.Tensor] = None,
|
| 39 |
+
hidden_states: Optional[torch.Tensor] = None,
|
| 40 |
+
caches: Optional[list] = None,
|
| 41 |
+
evolution_metrics: Optional[dict] = None):
|
| 42 |
+
super().__init__(loss=loss, logits=logits,
|
| 43 |
+
hidden_states=hidden_states, caches=caches,
|
| 44 |
+
evolution_metrics=evolution_metrics)
|
| 45 |
+
self.loss = loss
|
| 46 |
+
self.logits = logits
|
| 47 |
+
self.hidden_states = hidden_states
|
| 48 |
+
self.caches = caches
|
| 49 |
+
self.evolution_metrics = evolution_metrics or {}
|
| 50 |
+
|
| 51 |
+
def __iter__(self):
|
| 52 |
+
yield self.loss
|
| 53 |
+
yield self.logits
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def expand_layer_pattern(config: dict) -> List[str]:
|
| 57 |
+
"""Expand the layer-pattern shorthand into a list."""
|
| 58 |
+
backbone = config.get("backbone", {})
|
| 59 |
+
pattern_str = backbone.get("layer_pattern", "GD XM GD TM GD XM GD SK")
|
| 60 |
+
aliases = backbone.get("layer_aliases", {
|
| 61 |
+
"GD": "gated_deltanet", "XM": "xlstm_m",
|
| 62 |
+
"TM": "titans_mac", "SK": "tsp_span_knot",
|
| 63 |
+
})
|
| 64 |
+
pattern = pattern_str.split()
|
| 65 |
+
n_layers = int(config.get("num_hidden_layers", 28))
|
| 66 |
+
full = (pattern * (n_layers // len(pattern) + 1))[:n_layers]
|
| 67 |
+
return [aliases.get(p, p) for p in full]
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class Chimera51Block(nn.Module):
|
| 71 |
+
"""One block with evolution-aware forward."""
|
| 72 |
+
|
| 73 |
+
_RECURRENT = {"gated_deltanet", "xlstm_m", "titans_mac", "tsp_span_knot"}
|
| 74 |
+
|
| 75 |
+
def __init__(self, config: dict, layer_type: str, layer_idx: int,
|
| 76 |
+
use_moe: bool = False):
|
| 77 |
+
super().__init__()
|
| 78 |
+
h = int(config["hidden_size"])
|
| 79 |
+
eps = float(config.get("rms_norm_eps", 1e-6))
|
| 80 |
+
heads = int(config["num_heads"])
|
| 81 |
+
head_dim = int(config["head_dim"])
|
| 82 |
+
ternary = bool(config.get("use_ternary", True))
|
| 83 |
+
chunk_sz = int(config.get("gated_deltanet", {}).get("chunk_size", 64))
|
| 84 |
+
|
| 85 |
+
self.layer_idx = layer_idx
|
| 86 |
+
self.layer_type = layer_type
|
| 87 |
+
self.attn_norm = RMSNorm(h, eps=eps)
|
| 88 |
+
|
| 89 |
+
if layer_type == "gated_deltanet":
|
| 90 |
+
self.attn = GatedDeltaNetLayer(h, heads, head_dim, norm_eps=eps,
|
| 91 |
+
chunk_size=chunk_sz, use_ternary=ternary)
|
| 92 |
+
elif layer_type == "xlstm_m":
|
| 93 |
+
mem_h = config.get("xlstm", {}).get("memory_size_per_head", [head_dim, head_dim])
|
| 94 |
+
self.attn = MLSTMLayer(h, heads, int(mem_h[0]), norm_eps=eps,
|
| 95 |
+
use_ternary=ternary)
|
| 96 |
+
elif layer_type == "titans_mac":
|
| 97 |
+
tc = config.get("titans", {})
|
| 98 |
+
self.attn = TitansMACLayer(h, heads, head_dim,
|
| 99 |
+
memory_depth=int(tc.get("memory_depth", 2)),
|
| 100 |
+
persistent_slots=int(tc.get("persistent_memory_slots", 64)),
|
| 101 |
+
local_window=int(tc.get("local_window_size", 1024)),
|
| 102 |
+
norm_eps=eps, use_ternary=ternary)
|
| 103 |
+
elif layer_type == "tsp_span_knot":
|
| 104 |
+
self.attn = TSPSpanKnotLayer(h, heads, head_dim, norm_eps=eps,
|
| 105 |
+
chunk_size=chunk_sz, use_ternary=ternary)
|
| 106 |
+
else:
|
| 107 |
+
raise ValueError(f"Unknown layer type: {layer_type}")
|
| 108 |
+
|
| 109 |
+
self.mlp_norm = RMSNorm(h, eps=eps)
|
| 110 |
+
self.use_moe = bool(use_moe)
|
| 111 |
+
if self.use_moe:
|
| 112 |
+
moe_cfg = config.get("backbone", {}).get("moe", {})
|
| 113 |
+
self.mlp = MoELayer(
|
| 114 |
+
hidden_size=h,
|
| 115 |
+
moe_intermediate_size=int(moe_cfg.get("moe_intermediate_size", h * 2)),
|
| 116 |
+
n_routed_experts=int(moe_cfg.get("n_routed_experts", 16)),
|
| 117 |
+
n_shared_experts=int(moe_cfg.get("n_shared_experts", 1)),
|
| 118 |
+
num_experts_per_tok=int(moe_cfg.get("num_experts_per_tok", 2)),
|
| 119 |
+
use_ternary=ternary,
|
| 120 |
+
)
|
| 121 |
+
else:
|
| 122 |
+
inter = int(config.get("intermediate_size", int(h * 8 / 3)))
|
| 123 |
+
inter = 256 * ((inter + 255) // 256)
|
| 124 |
+
self.mlp = SwiGLUMLP(h, inter, use_ternary=ternary)
|
| 125 |
+
|
| 126 |
+
# Evolution modulation projection (learnable scale)
|
| 127 |
+
self.evo_gate = nn.Linear(h, h, bias=False)
|
| 128 |
+
nn.init.zeros_(self.evo_gate.weight)
|
| 129 |
+
|
| 130 |
+
def forward(self, x: torch.Tensor, cache: Optional[dict] = None,
|
| 131 |
+
evo_modulation: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, dict]:
|
| 132 |
+
# Apply attention with pre-norm
|
| 133 |
+
normed = self.attn_norm(x)
|
| 134 |
+
attn_out, new_cache = self.attn(normed, cache=cache)
|
| 135 |
+
x = x + attn_out
|
| 136 |
+
|
| 137 |
+
# Apply MLP with pre-norm
|
| 138 |
+
x = x + self.mlp(self.mlp_norm(x))
|
| 139 |
+
|
| 140 |
+
# Apply evolution modulation (gated residual)
|
| 141 |
+
if evo_modulation is not None:
|
| 142 |
+
gate = torch.sigmoid(self.evo_gate(x))
|
| 143 |
+
x = x + gate * evo_modulation
|
| 144 |
+
|
| 145 |
+
return x, new_cache
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Chimera51ForCausalLM(nn.Module):
|
| 149 |
+
"""Chimera 5.x causal language model with functional self-evolution."""
|
| 150 |
+
|
| 151 |
+
def __init__(self, config: dict):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.config = config
|
| 154 |
+
h = int(config["hidden_size"])
|
| 155 |
+
vocab = int(config["vocab_size"])
|
| 156 |
+
n_layers = int(config["num_hidden_layers"])
|
| 157 |
+
eps = float(config.get("rms_norm_eps", 1e-6))
|
| 158 |
+
|
| 159 |
+
self.embed = nn.Embedding(vocab, h)
|
| 160 |
+
layer_types = expand_layer_pattern(config)
|
| 161 |
+
moe_layers = set(int(i) for i in config.get("backbone", {}).get("moe", {}).get("layers", []))
|
| 162 |
+
|
| 163 |
+
self.layers = nn.ModuleList([
|
| 164 |
+
Chimera51Block(config, layer_types[i], i, use_moe=(i in moe_layers))
|
| 165 |
+
for i in range(n_layers)
|
| 166 |
+
])
|
| 167 |
+
|
| 168 |
+
self.norm = RMSNorm(h, eps=eps)
|
| 169 |
+
self.lm_head = nn.Linear(h, vocab, bias=False)
|
| 170 |
+
|
| 171 |
+
if config.get("tie_word_embeddings", True):
|
| 172 |
+
self.lm_head.weight = self.embed.weight
|
| 173 |
+
|
| 174 |
+
# Parcae looping controller
|
| 175 |
+
loop_cfg = config.get("looping", {})
|
| 176 |
+
self.looping_enabled = bool(loop_cfg.get("enabled", True)) and n_layers >= 3
|
| 177 |
+
if self.looping_enabled:
|
| 178 |
+
self.prelude_start, self.prelude_end = loop_cfg.get("prelude", [0, min(3, n_layers - 1)])
|
| 179 |
+
self.loop_start, self.loop_end = loop_cfg.get("loop", [min(4, n_layers - 1), max(4, n_layers - 4)])
|
| 180 |
+
self.coda_start, self.coda_end = loop_cfg.get("coda", [max(0, n_layers - 4), n_layers - 1])
|
| 181 |
+
self.loop_controller = ParcaeLoopController(
|
| 182 |
+
h, loop_range=tuple(loop_cfg.get("loop_range", [1, 6])),
|
| 183 |
+
loop_default=int(loop_cfg.get("loop_default", 2)),
|
| 184 |
+
adaptive_exit_threshold=float(loop_cfg.get("adaptive_exit_threshold", 0.01)),
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Inference systems
|
| 188 |
+
si_cfg = config.get("span_inference", {})
|
| 189 |
+
self.span_engine = SpanInferenceEngine(h, si_cfg) if si_cfg.get("enabled", True) else None
|
| 190 |
+
self.grammar = GrammarFST(config.get("grammar", {}))
|
| 191 |
+
self.entropy_valve = EntropyValve(config.get("entropy_valve", {}))
|
| 192 |
+
self.debt_ledger = DebtLedger(config.get("debt_ledger", {}))
|
| 193 |
+
|
| 194 |
+
# Self-evolution β FUNCTIONAL
|
| 195 |
+
evo_cfg = dict(config.get("self_evolution", {}))
|
| 196 |
+
evo_cfg["_semantic_memory_config"] = config.get("semantic_memory", {})
|
| 197 |
+
self.evolution = SelfEvolutionEngine(evo_cfg, h)
|
| 198 |
+
self.evo_weight = float(config.get("evolution_loss_weight", 0.01))
|
| 199 |
+
self.evo_every_n_layers = int(config.get("evolution_every_n_layers", 4))
|
| 200 |
+
|
| 201 |
+
# Multimodal
|
| 202 |
+
mm_cfg = dict(config.get("multimodal", {}))
|
| 203 |
+
mm_cfg["hidden_size"] = h
|
| 204 |
+
if mm_cfg.get("enabled", False):
|
| 205 |
+
self.vision_encoder = VisionEncoder(mm_cfg)
|
| 206 |
+
self.audio_encoder = AudioEncoder(mm_cfg)
|
| 207 |
+
else:
|
| 208 |
+
self.vision_encoder = None
|
| 209 |
+
self.audio_encoder = None
|
| 210 |
+
|
| 211 |
+
self.gradient_checkpointing = False
|
| 212 |
+
self._init_weights()
|
| 213 |
+
self._wire_semantic_memory()
|
| 214 |
+
|
| 215 |
+
def enable_gradient_checkpointing(self) -> None:
|
| 216 |
+
self.gradient_checkpointing = True
|
| 217 |
+
|
| 218 |
+
def disable_gradient_checkpointing(self) -> None:
|
| 219 |
+
self.gradient_checkpointing = False
|
| 220 |
+
|
| 221 |
+
def _wire_semantic_memory(self) -> None:
|
| 222 |
+
mem = self.evolution.semantic_memory
|
| 223 |
+
for layer in self.layers:
|
| 224 |
+
if hasattr(layer.attn, "set_semantic_memory"):
|
| 225 |
+
layer.attn.set_semantic_memory(mem)
|
| 226 |
+
|
| 227 |
+
def _init_weights(self) -> None:
|
| 228 |
+
init_range = float(self.config.get("initializer_range", 0.006))
|
| 229 |
+
for module in self.modules():
|
| 230 |
+
if isinstance(module, (nn.Linear, BitLinear)):
|
| 231 |
+
if module.weight is not None:
|
| 232 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_range)
|
| 233 |
+
if getattr(module, "bias", None) is not None:
|
| 234 |
+
nn.init.zeros_(module.bias)
|
| 235 |
+
elif isinstance(module, nn.Embedding):
|
| 236 |
+
nn.init.normal_(module.weight, mean=0.0, std=init_range)
|
| 237 |
+
for module in self.modules():
|
| 238 |
+
if isinstance(module, BitLinear):
|
| 239 |
+
module.invalidate_packed()
|
| 240 |
+
|
| 241 |
+
def _run_layers(self, x: torch.Tensor, start: int, end: int,
|
| 242 |
+
caches: Optional[list],
|
| 243 |
+
compute_logits: bool = False,
|
| 244 |
+
labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, Optional[torch.Tensor], list]:
|
| 245 |
+
"""Run layers with evolution hooks. Returns (x, logits_if_computed, caches)."""
|
| 246 |
+
all_metrics = []
|
| 247 |
+
logits = None
|
| 248 |
+
evolution_loss = torch.tensor(0.0, device=x.device)
|
| 249 |
+
|
| 250 |
+
for i in range(start, min(end + 1, len(self.layers))):
|
| 251 |
+
layer = self.layers[i]
|
| 252 |
+
cache = caches[i] if caches is not None else None
|
| 253 |
+
|
| 254 |
+
# Evolution modulation every N layers (lightweight)
|
| 255 |
+
evo_mod = None
|
| 256 |
+
if i % self.evo_every_n_layers == 0 and self.evolution is not None:
|
| 257 |
+
# Compute modulation from semantic memory
|
| 258 |
+
# Note: loss parameter requires a scalar loss tensor for TTT/surprise;
|
| 259 |
+
# pass None during standard forward, compute explicitly for TTT
|
| 260 |
+
evo_result = self.evolution(
|
| 261 |
+
hidden_states=x.detach() if not x.requires_grad else x,
|
| 262 |
+
layer_idx=i,
|
| 263 |
+
loss=None
|
| 264 |
+
)
|
| 265 |
+
evo_mod = evo_result['modulation']
|
| 266 |
+
if evo_result['evolution_loss'] is not None:
|
| 267 |
+
evolution_loss = evolution_loss + evo_result['evolution_loss']
|
| 268 |
+
all_metrics.append(evo_result.get('metrics', {}))
|
| 269 |
+
|
| 270 |
+
# TTT update for target layers (only in training, no backprop)
|
| 271 |
+
if self.training and evo_result.get('ttt_delta') is not None:
|
| 272 |
+
with torch.no_grad():
|
| 273 |
+
# Apply TTT to MLP down-projection if this is a target layer
|
| 274 |
+
if hasattr(layer.mlp, 'w_down'):
|
| 275 |
+
layer.mlp.w_down.data.add_(evo_result['ttt_delta'] * self.evolution.ttt.inner_lr)
|
| 276 |
+
|
| 277 |
+
if self.gradient_checkpointing and self.training:
|
| 278 |
+
def _ckpt_fn(x_in, layer=layer, cache=cache, evo=evo_mod):
|
| 279 |
+
out, _ = layer(x_in, cache=cache, evo_modulation=evo)
|
| 280 |
+
return out
|
| 281 |
+
x = checkpoint(_ckpt_fn, x, use_reentrant=False)
|
| 282 |
+
else:
|
| 283 |
+
x, new_cache = layer(x, cache=cache, evo_modulation=evo_mod)
|
| 284 |
+
if caches is not None:
|
| 285 |
+
caches[i] = new_cache
|
| 286 |
+
|
| 287 |
+
# Compute probe logits for entropy valve (every few layers)
|
| 288 |
+
if compute_logits and i == end:
|
| 289 |
+
logits = self.lm_head(self.norm(x[:, -1:, :]))
|
| 290 |
+
|
| 291 |
+
return x, logits, caches, evolution_loss, all_metrics
|
| 292 |
+
|
| 293 |
+
def forward(self, input_ids: torch.Tensor,
|
| 294 |
+
labels: Optional[torch.Tensor] = None,
|
| 295 |
+
pixel_values: Optional[torch.Tensor] = None,
|
| 296 |
+
mel_features: Optional[torch.Tensor] = None,
|
| 297 |
+
num_loops: Optional[int] = None,
|
| 298 |
+
caches: Optional[list] = None,
|
| 299 |
+
use_cache: bool = False,
|
| 300 |
+
logits_to_keep: int = 0,
|
| 301 |
+
return_evolution_metrics: bool = False):
|
| 302 |
+
x = self.embed(input_ids)
|
| 303 |
+
|
| 304 |
+
# Multimodal prepend
|
| 305 |
+
if pixel_values is not None and self.vision_encoder is not None:
|
| 306 |
+
v = self.vision_encoder(pixel_values)
|
| 307 |
+
if v is not None:
|
| 308 |
+
x = torch.cat([v, x], dim=1)
|
| 309 |
+
if mel_features is not None and self.audio_encoder is not None:
|
| 310 |
+
a = self.audio_encoder(mel_features)
|
| 311 |
+
if a is not None:
|
| 312 |
+
x = torch.cat([a, x], dim=1)
|
| 313 |
+
|
| 314 |
+
if caches is None and use_cache:
|
| 315 |
+
caches = [None] * len(self.layers)
|
| 316 |
+
|
| 317 |
+
total_evo_loss = torch.tensor(0.0, device=x.device)
|
| 318 |
+
all_evo_metrics = []
|
| 319 |
+
|
| 320 |
+
# Prelude + Loop + Coda with evolution
|
| 321 |
+
if self.looping_enabled and hasattr(self, "loop_controller"):
|
| 322 |
+
# Prelude
|
| 323 |
+
x, probe_logits, caches, evo_loss, metrics = self._run_layers(
|
| 324 |
+
x, self.prelude_start, self.prelude_end, caches,
|
| 325 |
+
compute_logits=not self.training, labels=labels)
|
| 326 |
+
total_evo_loss = total_evo_loss + evo_loss
|
| 327 |
+
all_evo_metrics.extend(metrics)
|
| 328 |
+
|
| 329 |
+
# Determine loop depth
|
| 330 |
+
effective = num_loops
|
| 331 |
+
if effective is None and not self.training and probe_logits is not None:
|
| 332 |
+
effective = self.entropy_valve.get_loop_count(probe_logits)
|
| 333 |
+
elif effective is None and self.evolution is not None:
|
| 334 |
+
# Use loop classifier from evolution
|
| 335 |
+
last_hidden = x[:, -1, :].mean(dim=0, keepdim=True) # Average over batch
|
| 336 |
+
effective = self.evolution.loop_classifier(last_hidden).item()
|
| 337 |
+
effective = max(1, min(effective, 6))
|
| 338 |
+
|
| 339 |
+
# Loop body
|
| 340 |
+
loop_fn = lambda inp: self._run_layers(
|
| 341 |
+
inp, self.loop_start, self.loop_end, caches, labels=labels)[0]
|
| 342 |
+
x = self.loop_controller(x, loop_fn, num_loops=effective)
|
| 343 |
+
|
| 344 |
+
# Coda
|
| 345 |
+
x, _, caches, evo_loss, metrics = self._run_layers(
|
| 346 |
+
x, self.coda_start, self.coda_end, caches, labels=labels)
|
| 347 |
+
total_evo_loss = total_evo_loss + evo_loss
|
| 348 |
+
all_evo_metrics.extend(metrics)
|
| 349 |
+
else:
|
| 350 |
+
x, _, caches, evo_loss, metrics = self._run_layers(
|
| 351 |
+
x, 0, len(self.layers) - 1, caches,
|
| 352 |
+
compute_logits=not self.training, labels=labels)
|
| 353 |
+
total_evo_loss = total_evo_loss + evo_loss
|
| 354 |
+
all_evo_metrics.extend(metrics)
|
| 355 |
+
|
| 356 |
+
# Final norm and logits
|
| 357 |
+
if logits_to_keep and labels is None:
|
| 358 |
+
keep = int(logits_to_keep)
|
| 359 |
+
tail = x[:, -keep:, :]
|
| 360 |
+
tail = self.norm(tail)
|
| 361 |
+
if self.span_engine is not None:
|
| 362 |
+
tail = self.span_engine(tail)
|
| 363 |
+
logits = self.lm_head(tail)
|
| 364 |
+
else:
|
| 365 |
+
x = self.norm(x)
|
| 366 |
+
if self.span_engine is not None:
|
| 367 |
+
x = self.span_engine(x)
|
| 368 |
+
logits = self.lm_head(x)
|
| 369 |
+
|
| 370 |
+
logits = self.grammar(logits)
|
| 371 |
+
logits = self.debt_ledger(logits)
|
| 372 |
+
|
| 373 |
+
# Self-feedback refinement check (inference only)
|
| 374 |
+
if not self.training and self.evolution is not None:
|
| 375 |
+
should_refine = self.evolution.self_feedback.should_refine(logits)
|
| 376 |
+
if should_refine:
|
| 377 |
+
all_evo_metrics.append({'refinement_triggered': True})
|
| 378 |
+
|
| 379 |
+
# Compute loss
|
| 380 |
+
loss = None
|
| 381 |
+
if labels is not None:
|
| 382 |
+
seq_len = min(logits.size(1), labels.size(1))
|
| 383 |
+
shift_logits = logits[:, :seq_len, :].contiguous()
|
| 384 |
+
shift_labels = labels[:, :seq_len].contiguous()
|
| 385 |
+
ce_loss = F.cross_entropy(
|
| 386 |
+
shift_logits.view(-1, shift_logits.size(-1)),
|
| 387 |
+
shift_labels.view(-1),
|
| 388 |
+
ignore_index=-100,
|
| 389 |
+
)
|
| 390 |
+
# Add evolution loss (contrastive memory evaluation)
|
| 391 |
+
loss = ce_loss + self.evo_weight * total_evo_loss
|
| 392 |
+
else:
|
| 393 |
+
ce_loss = None
|
| 394 |
+
|
| 395 |
+
# Store episodic case after forward (for inference mode)
|
| 396 |
+
if not self.training and self.evolution is not None:
|
| 397 |
+
last_hidden = x[:, -1, :].detach()
|
| 398 |
+
# Schedule episodic storage for end of sequence
|
| 399 |
+
# (In real use, call model.evolution.store_episodic() explicitly)
|
| 400 |
+
|
| 401 |
+
return CausalLMOutput(
|
| 402 |
+
loss=loss,
|
| 403 |
+
logits=logits,
|
| 404 |
+
hidden_states=x,
|
| 405 |
+
caches=caches if use_cache else None,
|
| 406 |
+
evolution_metrics={
|
| 407 |
+
'ce_loss': ce_loss.item() if ce_loss is not None else None,
|
| 408 |
+
'evo_loss': total_evo_loss.item(),
|
| 409 |
+
'layer_metrics': all_evo_metrics,
|
| 410 |
+
} if return_evolution_metrics else None
|
| 411 |
+
)
|
| 412 |
+
|
| 413 |
+
@torch.no_grad()
|
| 414 |
+
def prepare_for_inference(self) -> None:
|
| 415 |
+
"""Pre-pack every BitLinear so the first generation step is fast."""
|
| 416 |
+
for module in self.modules():
|
| 417 |
+
if isinstance(module, BitLinear):
|
| 418 |
+
module.prepare_for_inference()
|
| 419 |
+
|
| 420 |
+
def get_mode_config(self, mode: str = "balanced") -> dict:
|
| 421 |
+
modes = self.config.get("modes", {})
|
| 422 |
+
return modes.get(mode, modes.get("balanced", {}))
|
| 423 |
+
|
| 424 |
+
def count_parameters(self) -> dict:
|
| 425 |
+
total = sum(p.numel() for p in self.parameters())
|
| 426 |
+
ternary = sum(p.numel() for _, m in self.named_modules()
|
| 427 |
+
if isinstance(m, BitLinear) for p in m.parameters())
|
| 428 |
+
return {"total": total, "ternary": ternary, "fp32": total - ternary}
|
| 429 |
+
|
| 430 |
+
@classmethod
|
| 431 |
+
def from_config_file(cls, path: str) -> "Chimera51ForCausalLM":
|
| 432 |
+
with open(path, "r", encoding="utf-8") as fh:
|
| 433 |
+
config = json.load(fh)
|
| 434 |
+
return cls(config)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
__all__ = ["Chimera51ForCausalLM", "Chimera51Block", "CausalLMOutput",
|
| 438 |
+
"expand_layer_pattern"]
|
chimera/moe.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Sparse Mixture-of-Experts for Chimera (CPU-first).
|
| 3 |
+
|
| 4 |
+
Key design choices:
|
| 5 |
+
* Routing is computed in the model's compute dtype (no fp32 promotion):
|
| 6 |
+
the original draft cast every router input to fp32 which doubled memory
|
| 7 |
+
bandwidth for nothing on CPUs without dedicated softmax units.
|
| 8 |
+
* Dispatch uses ``index_select`` + boolean masks per expert. No global
|
| 9 |
+
``argsort`` of the routing pairs and no ``bincount`` table. This keeps
|
| 10 |
+
the path ``torch.compile``-friendly even when expert counts vary.
|
| 11 |
+
* All experts share an :class:`SwiGLUMLP` topology so weights can be packed
|
| 12 |
+
ternary identically to the rest of the model.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
from typing import Optional
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
import torch.nn as nn
|
| 21 |
+
import torch.nn.functional as F
|
| 22 |
+
|
| 23 |
+
from .layers import SwiGLUMLP
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class NoAuxMoEGate(nn.Module):
|
| 27 |
+
"""Top-k softmax router with optional bias-only correction (no aux loss)."""
|
| 28 |
+
|
| 29 |
+
__constants__ = ["n_routed_experts", "num_experts_per_tok"]
|
| 30 |
+
|
| 31 |
+
def __init__(self, hidden_size: int, n_routed_experts: int,
|
| 32 |
+
num_experts_per_tok: int = 2):
|
| 33 |
+
super().__init__()
|
| 34 |
+
self.n_routed_experts = int(n_routed_experts)
|
| 35 |
+
self.num_experts_per_tok = int(num_experts_per_tok)
|
| 36 |
+
self.weight = nn.Parameter(torch.empty(self.n_routed_experts, hidden_size))
|
| 37 |
+
nn.init.normal_(self.weight, mean=0.0, std=hidden_size ** -0.5)
|
| 38 |
+
# Buffer (not a Parameter): bias correction updated by training scripts.
|
| 39 |
+
self.register_buffer("e_score_correction_bias",
|
| 40 |
+
torch.zeros(self.n_routed_experts))
|
| 41 |
+
|
| 42 |
+
def forward(self, x: torch.Tensor):
|
| 43 |
+
# x: [N, D] in arbitrary dtype. Routing is stable enough in bf16/fp32.
|
| 44 |
+
scores = F.linear(x, self.weight) + self.e_score_correction_bias
|
| 45 |
+
probs = F.softmax(scores, dim=-1)
|
| 46 |
+
weights, indices = torch.topk(probs, self.num_experts_per_tok, dim=-1)
|
| 47 |
+
weights = weights / weights.sum(dim=-1, keepdim=True).clamp_min(1e-9)
|
| 48 |
+
return indices, weights
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class MoELayer(nn.Module):
|
| 52 |
+
"""Sparse MoE block with grouped expert dispatch."""
|
| 53 |
+
|
| 54 |
+
def __init__(self, hidden_size: int, moe_intermediate_size: int,
|
| 55 |
+
n_routed_experts: int = 16, n_shared_experts: int = 1,
|
| 56 |
+
num_experts_per_tok: int = 2, use_ternary: bool = True):
|
| 57 |
+
super().__init__()
|
| 58 |
+
self.hidden_size = int(hidden_size)
|
| 59 |
+
self.n_routed_experts = int(n_routed_experts)
|
| 60 |
+
self.n_shared_experts = int(n_shared_experts)
|
| 61 |
+
self.num_experts_per_tok = int(num_experts_per_tok)
|
| 62 |
+
self.gate = NoAuxMoEGate(self.hidden_size, self.n_routed_experts,
|
| 63 |
+
self.num_experts_per_tok)
|
| 64 |
+
self.experts = nn.ModuleList([
|
| 65 |
+
SwiGLUMLP(self.hidden_size, moe_intermediate_size, use_ternary=use_ternary)
|
| 66 |
+
for _ in range(self.n_routed_experts)
|
| 67 |
+
])
|
| 68 |
+
if self.n_shared_experts > 0:
|
| 69 |
+
shared_inter = max(1, moe_intermediate_size * self.n_shared_experts)
|
| 70 |
+
self.shared_experts = SwiGLUMLP(self.hidden_size, shared_inter,
|
| 71 |
+
use_ternary=use_ternary)
|
| 72 |
+
else:
|
| 73 |
+
self.shared_experts = None
|
| 74 |
+
|
| 75 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 76 |
+
orig_shape = x.shape
|
| 77 |
+
flat = x.reshape(-1, self.hidden_size)
|
| 78 |
+
N = flat.size(0)
|
| 79 |
+
|
| 80 |
+
topk_idx, topk_w = self.gate(flat) # [N, k]
|
| 81 |
+
out = torch.zeros_like(flat)
|
| 82 |
+
|
| 83 |
+
# Per-expert dispatch via boolean masks: avoids the global argsort and
|
| 84 |
+
# ``bincount`` of the previous draft and keeps the structure compatible
|
| 85 |
+
# with torch.compile.
|
| 86 |
+
for e in range(self.n_routed_experts):
|
| 87 |
+
match = (topk_idx == e)
|
| 88 |
+
if not match.any():
|
| 89 |
+
continue
|
| 90 |
+
# Token positions and per-pair weights for this expert.
|
| 91 |
+
tok_pos, slot_pos = match.nonzero(as_tuple=True)
|
| 92 |
+
w = topk_w[tok_pos, slot_pos].unsqueeze(-1).to(out.dtype)
|
| 93 |
+
y = self.experts[e](flat.index_select(0, tok_pos))
|
| 94 |
+
out.index_add_(0, tok_pos, y * w)
|
| 95 |
+
|
| 96 |
+
if self.shared_experts is not None:
|
| 97 |
+
out = out + self.shared_experts(flat)
|
| 98 |
+
|
| 99 |
+
return out.reshape(orig_shape)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
__all__ = ["NoAuxMoEGate", "MoELayer", "SwiGLUMLP"]
|
chimera/multimodal.py
ADDED
|
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β multimodal encoders (CPU-friendly, slim).
|
| 3 |
+
|
| 4 |
+
The previous draft had two latent issues:
|
| 5 |
+
* The vision/audio encoders projected to ``out_dim`` (e.g. 2560) which did
|
| 6 |
+
not match the trunk's ``hidden_size`` after scaling, so concatenating
|
| 7 |
+
image embeddings into the LM hidden stream blew up. We now project to
|
| 8 |
+
the trunk's hidden size by default.
|
| 9 |
+
* The internal ``_EncoderBlock`` wrapped a recurrent layer expecting a
|
| 10 |
+
``cache`` argument; we now call the layer correctly and discard the
|
| 11 |
+
cache (the encoder is purely parallel).
|
| 12 |
+
|
| 13 |
+
The encoders themselves remain BitLinear-friendly so they share the
|
| 14 |
+
ternary memory budget of the trunk.
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn as nn
|
| 23 |
+
from torch.utils.checkpoint import checkpoint
|
| 24 |
+
|
| 25 |
+
from .layers import GatedDeltaNetLayer
|
| 26 |
+
from .quantization import BitLinear, RMSNorm
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def _make_linear(use_ternary: bool):
|
| 30 |
+
if use_ternary:
|
| 31 |
+
return BitLinear
|
| 32 |
+
return lambda i, o, **kw: nn.Linear(i, o, bias=False)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class PatchEmbed(nn.Module):
|
| 36 |
+
__constants__ = ["patch_size"]
|
| 37 |
+
|
| 38 |
+
def __init__(self, patch_size: int = 16, in_channels: int = 3, hidden_size: int = 384):
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.patch_size = int(patch_size)
|
| 41 |
+
self.proj = nn.Conv2d(in_channels, hidden_size,
|
| 42 |
+
kernel_size=self.patch_size, stride=self.patch_size)
|
| 43 |
+
self.norm = RMSNorm(hidden_size)
|
| 44 |
+
|
| 45 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 46 |
+
x = self.proj(x)
|
| 47 |
+
x = x.flatten(2).transpose(1, 2)
|
| 48 |
+
return self.norm(x)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class _EncoderBlock(nn.Module):
|
| 52 |
+
def __init__(self, hidden: int, num_heads: int, head_dim: int,
|
| 53 |
+
use_ternary: bool = True):
|
| 54 |
+
super().__init__()
|
| 55 |
+
self.norm = RMSNorm(hidden)
|
| 56 |
+
self.attn = GatedDeltaNetLayer(hidden, num_heads, head_dim,
|
| 57 |
+
use_ternary=use_ternary, chunk_size=64)
|
| 58 |
+
self.mlp_norm = RMSNorm(hidden)
|
| 59 |
+
L = _make_linear(use_ternary)
|
| 60 |
+
self.mlp = nn.Sequential(L(hidden, hidden * 4), nn.GELU(), L(hidden * 4, hidden))
|
| 61 |
+
|
| 62 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 63 |
+
attn_out, _ = self.attn(self.norm(x))
|
| 64 |
+
x = x + attn_out
|
| 65 |
+
return x + self.mlp(self.mlp_norm(x))
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class _EncoderBase(nn.Module):
|
| 69 |
+
"""Shared encoder body for vision/audio."""
|
| 70 |
+
|
| 71 |
+
def __init__(self, hidden: int, depth: int, num_heads: int, head_dim: int,
|
| 72 |
+
out_dim: int, use_ternary: bool, use_checkpoint: bool):
|
| 73 |
+
super().__init__()
|
| 74 |
+
self.layers = nn.ModuleList([
|
| 75 |
+
_EncoderBlock(hidden, num_heads, head_dim, use_ternary)
|
| 76 |
+
for _ in range(depth)
|
| 77 |
+
])
|
| 78 |
+
self.proj = nn.Linear(hidden, out_dim, bias=False)
|
| 79 |
+
self.norm = RMSNorm(out_dim)
|
| 80 |
+
self.use_checkpoint = bool(use_checkpoint)
|
| 81 |
+
|
| 82 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 83 |
+
for layer in self.layers:
|
| 84 |
+
if self.use_checkpoint and self.training:
|
| 85 |
+
x = checkpoint(layer, x, use_reentrant=False)
|
| 86 |
+
else:
|
| 87 |
+
x = layer(x)
|
| 88 |
+
return self.norm(self.proj(x))
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class VisionEncoder(nn.Module):
|
| 92 |
+
def __init__(self, config: dict):
|
| 93 |
+
super().__init__()
|
| 94 |
+
v = config.get("vision", {})
|
| 95 |
+
self.enabled = bool(config.get("enabled", True))
|
| 96 |
+
hidden = int(v.get("hidden", 384))
|
| 97 |
+
depth = int(v.get("depth", 12))
|
| 98 |
+
patch = int(v.get("patch", 16))
|
| 99 |
+
# Default the encoder output to the trunk hidden_size so concatenation
|
| 100 |
+
# into the LM stream is dimensionally consistent.
|
| 101 |
+
out_dim = int(v.get("out", config.get("hidden_size", hidden)))
|
| 102 |
+
use_ternary = v.get("quant", "ternary") == "ternary"
|
| 103 |
+
num_heads = max(1, hidden // 64)
|
| 104 |
+
head_dim = hidden // num_heads
|
| 105 |
+
self.patch_embed = PatchEmbed(patch_size=patch, hidden_size=hidden)
|
| 106 |
+
self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
|
| 107 |
+
out_dim, use_ternary, use_checkpoint=True)
|
| 108 |
+
|
| 109 |
+
def forward(self, pixel_values: torch.Tensor) -> Optional[torch.Tensor]:
|
| 110 |
+
if not self.enabled:
|
| 111 |
+
return None
|
| 112 |
+
return self.body(self.patch_embed(pixel_values))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class AudioEncoder(nn.Module):
|
| 116 |
+
def __init__(self, config: dict):
|
| 117 |
+
super().__init__()
|
| 118 |
+
a = config.get("audio", {})
|
| 119 |
+
self.enabled = bool(config.get("enabled", True))
|
| 120 |
+
hidden = int(a.get("hidden", 256))
|
| 121 |
+
depth = int(a.get("depth", 6))
|
| 122 |
+
out_dim = int(a.get("out", config.get("hidden_size", hidden)))
|
| 123 |
+
use_ternary = a.get("quant", "ternary") == "ternary"
|
| 124 |
+
num_heads = max(1, hidden // 64)
|
| 125 |
+
head_dim = hidden // num_heads
|
| 126 |
+
self.input_proj = nn.Linear(80, hidden, bias=False)
|
| 127 |
+
self.body = _EncoderBase(hidden, depth, num_heads, head_dim,
|
| 128 |
+
out_dim, use_ternary, use_checkpoint=True)
|
| 129 |
+
|
| 130 |
+
def forward(self, mel_features: torch.Tensor) -> Optional[torch.Tensor]:
|
| 131 |
+
if not self.enabled:
|
| 132 |
+
return None
|
| 133 |
+
return self.body(self.input_proj(mel_features))
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
__all__ = ["PatchEmbed", "VisionEncoder", "AudioEncoder"]
|
chimera/paths.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
PACKAGE_ROOT = Path(__file__).resolve().parent
|
| 7 |
+
REPO_ROOT = PACKAGE_ROOT.parent
|
| 8 |
+
DEFAULT_CONFIG_PATH = REPO_ROOT / "config.json"
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def resolve_repo_path(path: str | Path) -> Path:
|
| 12 |
+
candidate = Path(path)
|
| 13 |
+
if candidate.is_absolute():
|
| 14 |
+
return candidate
|
| 15 |
+
return REPO_ROOT / candidate
|
chimera/quantization.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.2 β 1.58-bit Ternary Compute (CPU-First, Slim)
|
| 3 |
+
========================================================
|
| 4 |
+
Single, clean implementation of BitNet-1.58 ternary linear layers.
|
| 5 |
+
|
| 6 |
+
Design goals:
|
| 7 |
+
* Zero overhead at import time (no JIT, no kernel discovery).
|
| 8 |
+
* One fast pure-PyTorch path that vectorises everything; an optional
|
| 9 |
+
C++/OpenMP path that is loaded *lazily* and only used when it actually
|
| 10 |
+
beats PyTorch (small batches on inference).
|
| 11 |
+
* Cache the packed 2-bit weights between forward calls and only repack
|
| 12 |
+
when the latent FP32 weights are mutated (training step or MeZO).
|
| 13 |
+
* No data-dependent Python loops, no per-row mask construction at init.
|
| 14 |
+
|
| 15 |
+
Storage:
|
| 16 |
+
weight: FP32 latent of shape [M, K] (kept for STE backward / MeZO updates)
|
| 17 |
+
_packed: uint8 [M, ceil(K/4)] (2 bits per ternary value)
|
| 18 |
+
_alpha: fp32 [M] (per-row absolute mean scale)
|
| 19 |
+
|
| 20 |
+
Encoding (matches the C++ kernel):
|
| 21 |
+
-1 β 0b10
|
| 22 |
+
0 β 0b00
|
| 23 |
+
+1 β 0b01
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import math
|
| 29 |
+
import os
|
| 30 |
+
import threading
|
| 31 |
+
from typing import Optional, Tuple
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import torch.nn as nn
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ---------------------------------------------------------------------------
|
| 39 |
+
# Lazy C++ kernel. We never compile it during ``import``; it is only built
|
| 40 |
+
# when explicitly requested via :func:`enable_native_kernel` or the env var
|
| 41 |
+
# ``CHIMERA_NATIVE=1``. All public APIs work with the pure-PyTorch path.
|
| 42 |
+
# ---------------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
_NATIVE_LOCK = threading.Lock()
|
| 45 |
+
_NATIVE_EXT: Optional[object] = None
|
| 46 |
+
_NATIVE_TRIED = False
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
_CPP_SOURCE = r"""
|
| 50 |
+
#include <torch/extension.h>
|
| 51 |
+
#include <cstdint>
|
| 52 |
+
#include <cmath>
|
| 53 |
+
#ifdef _OPENMP
|
| 54 |
+
#include <omp.h>
|
| 55 |
+
#endif
|
| 56 |
+
|
| 57 |
+
// Encoding: -1->0b10, 0->0b00, +1->0b01
|
| 58 |
+
static const float LUT[4] = {0.0f, 1.0f, -1.0f, 0.0f};
|
| 59 |
+
|
| 60 |
+
torch::Tensor pack_ternary_cpu(torch::Tensor w) {
|
| 61 |
+
TORCH_CHECK(w.dim() == 2 && w.dtype() == torch::kInt8, "expected int8 [M,K]");
|
| 62 |
+
auto w_c = w.contiguous();
|
| 63 |
+
int64_t M = w_c.size(0), K = w_c.size(1);
|
| 64 |
+
int64_t K4 = (K + 3) / 4;
|
| 65 |
+
auto out = torch::zeros({M, K4}, torch::kUInt8);
|
| 66 |
+
const int8_t* s = w_c.data_ptr<int8_t>();
|
| 67 |
+
uint8_t* d = out.data_ptr<uint8_t>();
|
| 68 |
+
#pragma omp parallel for schedule(static)
|
| 69 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 70 |
+
const int8_t* sr = s + m * K;
|
| 71 |
+
uint8_t* dr = d + m * K4;
|
| 72 |
+
for (int64_t k4 = 0; k4 < K4; ++k4) {
|
| 73 |
+
uint8_t b = 0;
|
| 74 |
+
for (int j = 0; j < 4; ++j) {
|
| 75 |
+
int64_t k = k4 * 4 + j;
|
| 76 |
+
if (k >= K) break;
|
| 77 |
+
int8_t v = sr[k];
|
| 78 |
+
uint8_t code = (v == 1) ? 1u : (v == -1 ? 2u : 0u);
|
| 79 |
+
b |= (code << (6 - j * 2));
|
| 80 |
+
}
|
| 81 |
+
dr[k4] = b;
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
return out;
|
| 85 |
+
}
|
| 86 |
+
|
| 87 |
+
torch::Tensor unpack_ternary_cpu(torch::Tensor packed, int64_t K) {
|
| 88 |
+
TORCH_CHECK(packed.dim() == 2 && packed.dtype() == torch::kUInt8, "expected uint8 [M,K4]");
|
| 89 |
+
auto p = packed.contiguous();
|
| 90 |
+
int64_t M = p.size(0), K4 = p.size(1);
|
| 91 |
+
auto out = torch::empty({M, K}, torch::kFloat32);
|
| 92 |
+
const uint8_t* pp = p.data_ptr<uint8_t>();
|
| 93 |
+
float* dp = out.data_ptr<float>();
|
| 94 |
+
#pragma omp parallel for schedule(static)
|
| 95 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 96 |
+
const uint8_t* pr = pp + m * K4;
|
| 97 |
+
float* dr = dp + m * K;
|
| 98 |
+
for (int64_t k4 = 0; k4 < K4; ++k4) {
|
| 99 |
+
uint8_t b = pr[k4];
|
| 100 |
+
int64_t base = k4 * 4;
|
| 101 |
+
if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3];
|
| 102 |
+
if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3];
|
| 103 |
+
if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3];
|
| 104 |
+
if (base + 3 < K) dr[base + 3] = LUT[b & 3];
|
| 105 |
+
}
|
| 106 |
+
}
|
| 107 |
+
return out;
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
// Fused "unpack and scale" -> bf16/fp32 dense weight. Saves a pass over memory
|
| 111 |
+
// and a temporary FP32 tensor when running under bf16 autocast.
|
| 112 |
+
torch::Tensor dequantize_cpu(torch::Tensor packed, torch::Tensor alpha, int64_t K) {
|
| 113 |
+
auto p = packed.contiguous();
|
| 114 |
+
auto a = alpha.contiguous().to(torch::kFloat32);
|
| 115 |
+
int64_t M = p.size(0), K4 = p.size(1);
|
| 116 |
+
auto out = torch::empty({M, K}, torch::kFloat32);
|
| 117 |
+
const uint8_t* pp = p.data_ptr<uint8_t>();
|
| 118 |
+
const float* ap = a.data_ptr<float>();
|
| 119 |
+
float* dp = out.data_ptr<float>();
|
| 120 |
+
#pragma omp parallel for schedule(static)
|
| 121 |
+
for (int64_t m = 0; m < M; ++m) {
|
| 122 |
+
const uint8_t* pr = pp + m * K4;
|
| 123 |
+
float* dr = dp + m * K;
|
| 124 |
+
float sc = ap[m];
|
| 125 |
+
for (int64_t k4 = 0; k4 < K4; ++k4) {
|
| 126 |
+
uint8_t b = pr[k4];
|
| 127 |
+
int64_t base = k4 * 4;
|
| 128 |
+
if (base + 0 < K) dr[base + 0] = LUT[(b >> 6) & 3] * sc;
|
| 129 |
+
if (base + 1 < K) dr[base + 1] = LUT[(b >> 4) & 3] * sc;
|
| 130 |
+
if (base + 2 < K) dr[base + 2] = LUT[(b >> 2) & 3] * sc;
|
| 131 |
+
if (base + 3 < K) dr[base + 3] = LUT[b & 3] * sc;
|
| 132 |
+
}
|
| 133 |
+
}
|
| 134 |
+
return out;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 138 |
+
m.def("pack_ternary", &pack_ternary_cpu, "Pack int8 ternary -> 2-bit uint8");
|
| 139 |
+
m.def("unpack_ternary", &unpack_ternary_cpu, "Unpack 2-bit uint8 -> fp32 {-1,0,1}");
|
| 140 |
+
m.def("dequantize", &dequantize_cpu, "Unpack and scale by per-row alpha");
|
| 141 |
+
}
|
| 142 |
+
"""
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _try_load_native() -> Optional[object]:
|
| 146 |
+
"""Compile/load the optional native helper. Idempotent and thread-safe."""
|
| 147 |
+
global _NATIVE_EXT, _NATIVE_TRIED
|
| 148 |
+
if _NATIVE_TRIED:
|
| 149 |
+
return _NATIVE_EXT
|
| 150 |
+
with _NATIVE_LOCK:
|
| 151 |
+
if _NATIVE_TRIED:
|
| 152 |
+
return _NATIVE_EXT
|
| 153 |
+
_NATIVE_TRIED = True
|
| 154 |
+
try:
|
| 155 |
+
from torch.utils.cpp_extension import load_inline
|
| 156 |
+
|
| 157 |
+
build_dir = os.path.join(
|
| 158 |
+
os.path.dirname(os.path.abspath(__file__)), "..", ".ternary_build"
|
| 159 |
+
)
|
| 160 |
+
os.makedirs(build_dir, exist_ok=True)
|
| 161 |
+
_NATIVE_EXT = load_inline(
|
| 162 |
+
name="chimera_ternary",
|
| 163 |
+
cpp_sources=_CPP_SOURCE,
|
| 164 |
+
extra_cflags=["-O3", "-fopenmp", "-ffast-math", "-funroll-loops"],
|
| 165 |
+
extra_ldflags=["-lgomp"],
|
| 166 |
+
build_directory=build_dir,
|
| 167 |
+
verbose=False,
|
| 168 |
+
)
|
| 169 |
+
except Exception as exc: # pragma: no cover - best-effort.
|
| 170 |
+
os.environ.setdefault("CHIMERA_NATIVE_DISABLED", str(exc)[:200])
|
| 171 |
+
_NATIVE_EXT = None
|
| 172 |
+
return _NATIVE_EXT
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def enable_native_kernel(force: bool = False) -> bool:
|
| 176 |
+
"""Eagerly try to compile the native kernel.
|
| 177 |
+
|
| 178 |
+
Returns ``True`` if the kernel is loaded and available.
|
| 179 |
+
"""
|
| 180 |
+
global _NATIVE_TRIED
|
| 181 |
+
if force:
|
| 182 |
+
_NATIVE_TRIED = False
|
| 183 |
+
return _try_load_native() is not None
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def native_kernel_available() -> bool:
|
| 187 |
+
return _NATIVE_EXT is not None
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
# Allow opt-in from the environment without code changes.
|
| 191 |
+
if os.environ.get("CHIMERA_NATIVE", "0") == "1":
|
| 192 |
+
enable_native_kernel()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ---------------------------------------------------------------------------
|
| 196 |
+
# Pure PyTorch ternary primitives (always available).
|
| 197 |
+
# ---------------------------------------------------------------------------
|
| 198 |
+
|
| 199 |
+
# Lookup tables compiled once. Casting to a registered buffer is overkill β
|
| 200 |
+
# they live on CPU and broadcast naturally.
|
| 201 |
+
_TERNARY_LUT_F32 = torch.tensor([0.0, 1.0, -1.0, 0.0], dtype=torch.float32)
|
| 202 |
+
_TERNARY_LUT_I8 = torch.tensor([0, 1, -1, 0], dtype=torch.int8)
|
| 203 |
+
_SHIFTS = torch.tensor([6, 4, 2, 0], dtype=torch.uint8)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def pack_ternary(q: torch.Tensor) -> torch.Tensor:
|
| 207 |
+
"""Pack a ternary {-1,0,1} tensor into a 2-bit uint8 tensor.
|
| 208 |
+
|
| 209 |
+
Vectorised pure-PyTorch implementation β no Python loops over rows.
|
| 210 |
+
Trailing positions that don't divide by four are zero-padded.
|
| 211 |
+
"""
|
| 212 |
+
q = q.detach()
|
| 213 |
+
if q.dim() == 1:
|
| 214 |
+
q = q.unsqueeze(0)
|
| 215 |
+
flat = q.reshape(-1, q.shape[-1]).to(torch.int8)
|
| 216 |
+
M, K = flat.shape
|
| 217 |
+
K4 = (K + 3) // 4
|
| 218 |
+
pad = K4 * 4 - K
|
| 219 |
+
if pad:
|
| 220 |
+
flat = F.pad(flat, (0, pad))
|
| 221 |
+
# codes: 0 / 1 / 2 (uint8)
|
| 222 |
+
codes = torch.where(flat == 1, torch.full_like(flat, 1),
|
| 223 |
+
torch.where(flat == -1, torch.full_like(flat, 2), torch.zeros_like(flat))).to(torch.uint8)
|
| 224 |
+
codes = codes.view(M, K4, 4)
|
| 225 |
+
packed = ((codes[..., 0] << 6) | (codes[..., 1] << 4) |
|
| 226 |
+
(codes[..., 2] << 2) | codes[..., 3]).contiguous()
|
| 227 |
+
return packed.reshape(*q.shape[:-1], K4)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def unpack_ternary(packed: torch.Tensor, k: int,
|
| 231 |
+
alpha: Optional[torch.Tensor] = None,
|
| 232 |
+
dtype: torch.dtype = torch.float32) -> torch.Tensor:
|
| 233 |
+
"""Vectorised inverse of :func:`pack_ternary`.
|
| 234 |
+
|
| 235 |
+
Returns ``out`` with last dim ``k``; optionally pre-multiplied by
|
| 236 |
+
``alpha`` (per-row scale, broadcastable on the leading axes).
|
| 237 |
+
"""
|
| 238 |
+
packed = packed.to(torch.uint8)
|
| 239 |
+
if packed.dim() == 1:
|
| 240 |
+
packed = packed.unsqueeze(0)
|
| 241 |
+
flat = packed.reshape(-1, packed.shape[-1])
|
| 242 |
+
M, K4 = flat.shape
|
| 243 |
+
# Gather all 4 sub-positions in one vectorised op.
|
| 244 |
+
shifts = _SHIFTS.to(packed.device)
|
| 245 |
+
codes = (flat.unsqueeze(-1) >> shifts).bitwise_and_(3).to(torch.long) # [M, K4, 4]
|
| 246 |
+
lut = _TERNARY_LUT_F32.to(device=packed.device, dtype=dtype)
|
| 247 |
+
out = lut[codes].reshape(M, K4 * 4)[:, :k]
|
| 248 |
+
if alpha is not None:
|
| 249 |
+
out = out * alpha.reshape(M, 1).to(device=out.device, dtype=out.dtype)
|
| 250 |
+
return out.reshape(*packed.shape[:-1], k)
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def _absmean_alpha(weight: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
|
| 254 |
+
"""Per-output-channel scale (``\alpha = mean|w|`` clamped)."""
|
| 255 |
+
return weight.detach().abs().mean(dim=-1, keepdim=False).clamp_min(eps).to(torch.float32)
|
| 256 |
+
|
| 257 |
+
|
| 258 |
+
def ternarize_weight(weight: torch.Tensor, group_size: int = 128
|
| 259 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 260 |
+
"""Quantise FP32 weights to ternary using BitNet's abs-mean rule.
|
| 261 |
+
|
| 262 |
+
``group_size`` is kept for API compatibility but every row is its own
|
| 263 |
+
group in this slim implementation. Returns ``(w_ternary, alpha)``.
|
| 264 |
+
"""
|
| 265 |
+
alpha = _absmean_alpha(weight)
|
| 266 |
+
w_q = torch.round(torch.clamp(weight / alpha.unsqueeze(-1), -1.0, 1.0)).to(torch.int8)
|
| 267 |
+
return w_q, alpha
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
_quantize_weights_ternary = ternarize_weight # legacy alias used elsewhere
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def apply_2_4_sparsity_(weight: torch.Tensor) -> torch.Tensor:
|
| 274 |
+
"""In-place N:M 2:4 pruning. Vectorised β no Python row loops."""
|
| 275 |
+
with torch.no_grad():
|
| 276 |
+
last = weight.shape[-1]
|
| 277 |
+
pad = (-last) % 4
|
| 278 |
+
target = F.pad(weight, (0, pad)) if pad else weight
|
| 279 |
+
view = target.view(*target.shape[:-1], -1, 4)
|
| 280 |
+
# Keep the two largest in absolute value, zero the rest.
|
| 281 |
+
idx = view.abs().argsort(dim=-1)[..., :2]
|
| 282 |
+
view.scatter_(-1, idx, 0.0)
|
| 283 |
+
if pad:
|
| 284 |
+
weight.copy_(target[..., :last])
|
| 285 |
+
return weight
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
# ---------------------------------------------------------------------------
|
| 289 |
+
# Straight-Through Estimator for ternary quantization.
|
| 290 |
+
# ---------------------------------------------------------------------------
|
| 291 |
+
|
| 292 |
+
class _RoundTernarySTE(torch.autograd.Function):
|
| 293 |
+
@staticmethod
|
| 294 |
+
def forward(ctx, w: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
| 295 |
+
return torch.round(torch.clamp(w, -1.0, 1.0))
|
| 296 |
+
|
| 297 |
+
@staticmethod
|
| 298 |
+
def backward(ctx, grad_output: torch.Tensor): # type: ignore[override]
|
| 299 |
+
# Standard STE: gradient flows through, clipped to [-1, 1] so the
|
| 300 |
+
# latent FP32 weights cannot drift unboundedly.
|
| 301 |
+
return grad_output.clamp(-1.0, 1.0)
|
| 302 |
+
|
| 303 |
+
|
| 304 |
+
def ste_ternary(w: torch.Tensor) -> torch.Tensor:
|
| 305 |
+
return _RoundTernarySTE.apply(w)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
# BitLinear β single class, single fast path.
|
| 310 |
+
# ---------------------------------------------------------------------------
|
| 311 |
+
|
| 312 |
+
class BitLinear(nn.Module):
|
| 313 |
+
"""Linear layer with ternary {-1, 0, 1} weights and per-row absmean scale.
|
| 314 |
+
|
| 315 |
+
*Training (grad-enabled)*: STE ternarisation on the latent weight, dense
|
| 316 |
+
fp32/bf16 matmul. Backward flows to the latent weight via STE.
|
| 317 |
+
|
| 318 |
+
*Inference / no-grad*: weights are quantised once and cached as packed
|
| 319 |
+
2-bit uint8 + fp32 alpha. Each forward unpacks (vectorised PyTorch or
|
| 320 |
+
optional C++ kernel) into a reusable buffer and calls a single matmul.
|
| 321 |
+
"""
|
| 322 |
+
|
| 323 |
+
__constants__ = ["in_features", "out_features", "use_2_4"]
|
| 324 |
+
|
| 325 |
+
def __init__(self, in_features: int, out_features: int, bias: bool = False,
|
| 326 |
+
group_size: int = 128, nm_2_4: bool = False):
|
| 327 |
+
super().__init__()
|
| 328 |
+
self.in_features = int(in_features)
|
| 329 |
+
self.out_features = int(out_features)
|
| 330 |
+
self.group_size = int(group_size)
|
| 331 |
+
self.use_2_4 = bool(nm_2_4)
|
| 332 |
+
|
| 333 |
+
self.weight = nn.Parameter(torch.empty(self.out_features, self.in_features))
|
| 334 |
+
if bias:
|
| 335 |
+
self.bias = nn.Parameter(torch.zeros(self.out_features))
|
| 336 |
+
else:
|
| 337 |
+
self.register_parameter("bias", None)
|
| 338 |
+
|
| 339 |
+
# Caches. ``_cache_version`` is bumped whenever the latent weight
|
| 340 |
+
# changes; the forward pass compares it against ``_packed_version``
|
| 341 |
+
# to know when to repack.
|
| 342 |
+
self.register_buffer("_packed", torch.zeros(0, dtype=torch.uint8), persistent=False)
|
| 343 |
+
self.register_buffer("_alpha", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 344 |
+
# Optional dense fp32 cache of the dequantised ternary weight. This
|
| 345 |
+
# is what every inference forward actually needs, so caching it
|
| 346 |
+
# eliminates the per-call unpack and saves ~30-50% of CPU time on
|
| 347 |
+
# small models. It is only built lazily on first inference call.
|
| 348 |
+
self.register_buffer("_dense_w", torch.zeros(0, dtype=torch.float32), persistent=False)
|
| 349 |
+
self._packed_version = -1
|
| 350 |
+
self._dense_version = -1
|
| 351 |
+
self._cache_version = 0
|
| 352 |
+
|
| 353 |
+
self.reset_parameters()
|
| 354 |
+
|
| 355 |
+
# -- init ------------------------------------------------------------------
|
| 356 |
+
|
| 357 |
+
def reset_parameters(self) -> None:
|
| 358 |
+
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
| 359 |
+
if self.bias is not None:
|
| 360 |
+
nn.init.zeros_(self.bias)
|
| 361 |
+
self._cache_version += 1
|
| 362 |
+
|
| 363 |
+
# -- helpers ---------------------------------------------------------------
|
| 364 |
+
|
| 365 |
+
def invalidate_packed(self) -> None:
|
| 366 |
+
"""Mark the packed cache stale. Called after weight mutations."""
|
| 367 |
+
self._cache_version += 1
|
| 368 |
+
# Free the dense fp32 cache too; next forward will rebuild it.
|
| 369 |
+
if self._dense_w.numel() > 0:
|
| 370 |
+
self._dense_w = torch.zeros(0, dtype=torch.float32, device=self._dense_w.device)
|
| 371 |
+
self._dense_version = -1
|
| 372 |
+
|
| 373 |
+
def _quantize_latent(self) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 374 |
+
"""Quantise the FP32 latent weight to ternary (no-grad, no copy)."""
|
| 375 |
+
with torch.no_grad():
|
| 376 |
+
w = self.weight
|
| 377 |
+
alpha = _absmean_alpha(w)
|
| 378 |
+
w_q = torch.round(torch.clamp(w / alpha.unsqueeze(-1), -1.0, 1.0))
|
| 379 |
+
if self.use_2_4:
|
| 380 |
+
apply_2_4_sparsity_(w_q)
|
| 381 |
+
return w_q.to(torch.int8), alpha
|
| 382 |
+
|
| 383 |
+
def _ensure_packed(self) -> None:
|
| 384 |
+
if self._packed_version == self._cache_version and self._packed.numel() > 0:
|
| 385 |
+
return
|
| 386 |
+
with torch.no_grad():
|
| 387 |
+
w_q, alpha = self._quantize_latent()
|
| 388 |
+
ext = _NATIVE_EXT
|
| 389 |
+
if ext is not None:
|
| 390 |
+
packed = ext.pack_ternary(w_q)
|
| 391 |
+
else:
|
| 392 |
+
packed = pack_ternary(w_q)
|
| 393 |
+
# Replace storage in-place to avoid breaking nn.Module buffer tracking.
|
| 394 |
+
self._packed = packed.contiguous()
|
| 395 |
+
self._alpha = alpha.contiguous()
|
| 396 |
+
self._packed_version = self._cache_version
|
| 397 |
+
|
| 398 |
+
@torch.no_grad()
|
| 399 |
+
def prepare_for_inference(self) -> None:
|
| 400 |
+
"""Materialise the packed cache so the next forward is allocation-free."""
|
| 401 |
+
self.invalidate_packed()
|
| 402 |
+
self._ensure_packed()
|
| 403 |
+
|
| 404 |
+
@torch.no_grad()
|
| 405 |
+
def ternary_nonzero_mask(self) -> torch.Tensor:
|
| 406 |
+
"""Boolean mask of currently non-zero ternary positions (cached)."""
|
| 407 |
+
self._ensure_packed()
|
| 408 |
+
# Reuse the dequantised float view through unpack β cheaper than a fresh
|
| 409 |
+
# dense ternary tensor on small models, and shared for both branches.
|
| 410 |
+
ext = _NATIVE_EXT
|
| 411 |
+
if ext is not None:
|
| 412 |
+
w = ext.unpack_ternary(self._packed, self.in_features)
|
| 413 |
+
else:
|
| 414 |
+
w = unpack_ternary(self._packed, self.in_features)
|
| 415 |
+
return w.ne(0)
|
| 416 |
+
|
| 417 |
+
# -- forward ---------------------------------------------------------------
|
| 418 |
+
|
| 419 |
+
def _forward_train(self, x: torch.Tensor) -> torch.Tensor:
|
| 420 |
+
"""STE forward: differentiable, fp32/bf16 dense matmul."""
|
| 421 |
+
w = self.weight
|
| 422 |
+
alpha = w.detach().abs().mean(dim=-1, keepdim=True).clamp_min(1e-5)
|
| 423 |
+
w_q = ste_ternary(w / alpha) * alpha
|
| 424 |
+
if self.use_2_4:
|
| 425 |
+
# 2:4 sparsity is non-differentiable but only zeros gradients on
|
| 426 |
+
# already-pruned positions; safe to apply during STE forward.
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
mask = (apply_2_4_sparsity_(w_q.detach().clone()) != 0).to(w_q.dtype)
|
| 429 |
+
w_q = w_q * mask
|
| 430 |
+
return F.linear(x, w_q.to(x.dtype), self.bias)
|
| 431 |
+
|
| 432 |
+
def _ensure_dense(self) -> torch.Tensor:
|
| 433 |
+
"""Materialise (and cache) the fp32 dense ternary weight."""
|
| 434 |
+
self._ensure_packed()
|
| 435 |
+
if self._dense_version == self._cache_version and self._dense_w.numel() > 0:
|
| 436 |
+
return self._dense_w
|
| 437 |
+
ext = _NATIVE_EXT
|
| 438 |
+
if ext is not None:
|
| 439 |
+
w = ext.dequantize(self._packed, self._alpha, self.in_features)
|
| 440 |
+
else:
|
| 441 |
+
w = unpack_ternary(self._packed, self.in_features) * self._alpha.unsqueeze(-1)
|
| 442 |
+
# Replace the buffer in place so nn.Module book-keeping stays valid.
|
| 443 |
+
self._dense_w = w.contiguous()
|
| 444 |
+
self._dense_version = self._cache_version
|
| 445 |
+
return self._dense_w
|
| 446 |
+
|
| 447 |
+
def _forward_packed(self, x: torch.Tensor) -> torch.Tensor:
|
| 448 |
+
"""No-grad fast path that uses the cached dequantised weights."""
|
| 449 |
+
w = self._ensure_dense()
|
| 450 |
+
# Match dtype (bf16 autocast support) without re-allocating the cache.
|
| 451 |
+
if x.dtype != w.dtype:
|
| 452 |
+
w_used = w.to(x.dtype)
|
| 453 |
+
else:
|
| 454 |
+
w_used = w
|
| 455 |
+
return F.linear(x, w_used, self.bias)
|
| 456 |
+
|
| 457 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 458 |
+
if self.training and torch.is_grad_enabled():
|
| 459 |
+
return self._forward_train(x)
|
| 460 |
+
return self._forward_packed(x)
|
| 461 |
+
|
| 462 |
+
# -- introspection ---------------------------------------------------------
|
| 463 |
+
|
| 464 |
+
def extra_repr(self) -> str:
|
| 465 |
+
return (f"in_features={self.in_features}, out_features={self.out_features}, "
|
| 466 |
+
f"bias={self.bias is not None}, nm_2_4={self.use_2_4}, "
|
| 467 |
+
f"native={native_kernel_available()}")
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
# ---------------------------------------------------------------------------
|
| 471 |
+
# RMSNorm.
|
| 472 |
+
# ---------------------------------------------------------------------------
|
| 473 |
+
|
| 474 |
+
class RMSNorm(nn.Module):
|
| 475 |
+
"""Numerically-stable Root Mean Square LayerNorm (no bias, no centering)."""
|
| 476 |
+
|
| 477 |
+
__constants__ = ["dim", "eps"]
|
| 478 |
+
|
| 479 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 480 |
+
super().__init__()
|
| 481 |
+
self.dim = int(dim)
|
| 482 |
+
self.eps = float(eps)
|
| 483 |
+
self.weight = nn.Parameter(torch.ones(self.dim))
|
| 484 |
+
|
| 485 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 486 |
+
# The normalisation is computed in fp32 for stability under bf16
|
| 487 |
+
# autocast, then cast back to the input dtype.
|
| 488 |
+
dtype = x.dtype
|
| 489 |
+
if dtype != torch.float32:
|
| 490 |
+
x32 = x.float()
|
| 491 |
+
rms = torch.rsqrt(x32.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
|
| 492 |
+
return (x32 * rms).to(dtype) * self.weight
|
| 493 |
+
rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True).add(self.eps))
|
| 494 |
+
return x * rms * self.weight
|
| 495 |
+
|
| 496 |
+
|
| 497 |
+
__all__ = [
|
| 498 |
+
"BitLinear",
|
| 499 |
+
"RMSNorm",
|
| 500 |
+
"ste_ternary",
|
| 501 |
+
"pack_ternary",
|
| 502 |
+
"unpack_ternary",
|
| 503 |
+
"ternarize_weight",
|
| 504 |
+
"_quantize_weights_ternary",
|
| 505 |
+
"apply_2_4_sparsity_",
|
| 506 |
+
"enable_native_kernel",
|
| 507 |
+
"native_kernel_available",
|
| 508 |
+
]
|
chimera/tokenizer.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Chimera 5.1 β Splintr (Rust) Tokenizer Wrapper β o200k_base (OpenAI o1/o3)
|
| 3 |
+
Wraps splintr's high-performance Rust tokenizer for transformers-compatible API.
|
| 4 |
+
Vocab: o200k_base (200,073 tokens) β OpenAI's o1/o3 tokenizer.
|
| 5 |
+
|
| 6 |
+
Optimizations:
|
| 7 |
+
- __slots__ for reduced memory footprint
|
| 8 |
+
- Cached special token set for fast skip_special_tokens filtering
|
| 9 |
+
- Batch encode uses list comprehension (minimizes Python overhead)
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from typing import List, Union, Optional
|
| 14 |
+
|
| 15 |
+
try:
|
| 16 |
+
from splintr import Tokenizer as _SplintrTokenizer, O200K_AGENT_TOKENS
|
| 17 |
+
HAS_SPLINTR = True
|
| 18 |
+
except ImportError:
|
| 19 |
+
HAS_SPLINTR = False
|
| 20 |
+
|
| 21 |
+
__all__ = ["ChimeraTokenizer"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class ChimeraTokenizer:
|
| 25 |
+
"""
|
| 26 |
+
High-performance Rust-backed tokenizer (splintr) with HuggingFace-like interface.
|
| 27 |
+
Falls back to a basic tiktoken wrapper if splintr is not installed.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(self, pretrained: str = "o200k_base", vocab_size: int = 200073):
|
| 31 |
+
if not HAS_SPLINTR:
|
| 32 |
+
self._tok = None
|
| 33 |
+
self.vocab_size = int(vocab_size)
|
| 34 |
+
self.eos_token_id = min(self.vocab_size - 1, 199999)
|
| 35 |
+
self.pad_token_id = min(self.vocab_size - 1, 200058)
|
| 36 |
+
self.sep_token_id = min(self.vocab_size - 1, 200060)
|
| 37 |
+
self.stop_token_id = min(self.vocab_size - 1, 200059)
|
| 38 |
+
self.user_token_id = min(self.vocab_size - 1, 200020)
|
| 39 |
+
self.assistant_token_id = min(self.vocab_size - 1, 200021)
|
| 40 |
+
self.system_token_id = min(self.vocab_size - 1, 200019)
|
| 41 |
+
self.endofprompt_token_id = min(self.vocab_size - 1, 200018)
|
| 42 |
+
self.bos_token_id = self.eos_token_id
|
| 43 |
+
self.eos_token = "<|endoftext|>"
|
| 44 |
+
self.pad_token = "<|pad|>"
|
| 45 |
+
self.model_max_length = 4194304
|
| 46 |
+
self._special_ids = frozenset({self.eos_token_id, self.pad_token_id, self.sep_token_id, self.stop_token_id, self.user_token_id, self.assistant_token_id, self.system_token_id, self.endofprompt_token_id})
|
| 47 |
+
self._byte_offset = 3
|
| 48 |
+
return
|
| 49 |
+
self._tok = _SplintrTokenizer.from_pretrained(pretrained)
|
| 50 |
+
self.vocab_size = self._tok.vocab_size
|
| 51 |
+
|
| 52 |
+
# o200k_base single-token special IDs
|
| 53 |
+
self.eos_token_id = 199999
|
| 54 |
+
self.pad_token_id = O200K_AGENT_TOKENS.PAD # 200058
|
| 55 |
+
self.sep_token_id = O200K_AGENT_TOKENS.SEP # 200060
|
| 56 |
+
self.stop_token_id = O200K_AGENT_TOKENS.STOP # 200059
|
| 57 |
+
self.user_token_id = O200K_AGENT_TOKENS.USER # 200020
|
| 58 |
+
self.assistant_token_id = O200K_AGENT_TOKENS.ASSISTANT # 200021
|
| 59 |
+
self.system_token_id = 200019
|
| 60 |
+
self.endofprompt_token_id = 200018
|
| 61 |
+
self.bos_token_id = self.eos_token_id
|
| 62 |
+
|
| 63 |
+
self.eos_token = "<|endoftext|>"
|
| 64 |
+
self.pad_token = "<|pad|>"
|
| 65 |
+
self.model_max_length = 4194304
|
| 66 |
+
|
| 67 |
+
# Cached set for fast filtering
|
| 68 |
+
self._special_ids = frozenset({
|
| 69 |
+
self.eos_token_id, self.pad_token_id, self.sep_token_id,
|
| 70 |
+
self.stop_token_id, self.user_token_id,
|
| 71 |
+
self.assistant_token_id, self.system_token_id,
|
| 72 |
+
self.endofprompt_token_id,
|
| 73 |
+
})
|
| 74 |
+
|
| 75 |
+
def __len__(self) -> int:
|
| 76 |
+
return self.vocab_size
|
| 77 |
+
|
| 78 |
+
def encode(self, text: str, add_special_tokens: bool = True,
|
| 79 |
+
max_length: Optional[int] = None) -> List[int]:
|
| 80 |
+
if self._tok is None:
|
| 81 |
+
ids = [self._byte_offset + b for b in text.encode("utf-8", errors="replace")]
|
| 82 |
+
else:
|
| 83 |
+
ids = self._tok.encode(text)
|
| 84 |
+
if add_special_tokens:
|
| 85 |
+
ids = ids + [self.eos_token_id]
|
| 86 |
+
if max_length is not None and len(ids) > max_length:
|
| 87 |
+
ids = ids[:max_length]
|
| 88 |
+
return ids
|
| 89 |
+
|
| 90 |
+
def encode_batch(self, texts: List[str], add_special_tokens: bool = True,
|
| 91 |
+
max_length: Optional[int] = None,
|
| 92 |
+
padding: bool = False,
|
| 93 |
+
truncation: bool = False,
|
| 94 |
+
return_tensors: Optional[str] = None):
|
| 95 |
+
all_ids = [self.encode(t, add_special_tokens=add_special_tokens,
|
| 96 |
+
max_length=max_length)
|
| 97 |
+
for t in texts]
|
| 98 |
+
if padding:
|
| 99 |
+
max_len = max(len(ids) for ids in all_ids)
|
| 100 |
+
all_ids = [ids + [self.pad_token_id] * (max_len - len(ids))
|
| 101 |
+
for ids in all_ids]
|
| 102 |
+
if return_tensors == "pt":
|
| 103 |
+
return {"input_ids": torch.tensor(all_ids, dtype=torch.long)}
|
| 104 |
+
return all_ids
|
| 105 |
+
|
| 106 |
+
def decode(self, token_ids, skip_special_tokens: bool = True) -> str:
|
| 107 |
+
if isinstance(token_ids, torch.Tensor):
|
| 108 |
+
token_ids = token_ids.tolist()
|
| 109 |
+
if skip_special_tokens:
|
| 110 |
+
token_ids = [t for t in token_ids if t not in self._special_ids]
|
| 111 |
+
if self._tok is None:
|
| 112 |
+
data = bytes(max(0, min(255, int(t) - self._byte_offset)) for t in token_ids if int(t) >= self._byte_offset)
|
| 113 |
+
return data.decode("utf-8", errors="replace")
|
| 114 |
+
return self._tok.decode(token_ids)
|
| 115 |
+
|
| 116 |
+
def decode_batch(self, token_ids_list, skip_special_tokens: bool = True) -> List[str]:
|
| 117 |
+
return [self.decode(ids, skip_special_tokens=skip_special_tokens)
|
| 118 |
+
for ids in token_ids_list]
|
| 119 |
+
|
| 120 |
+
def __call__(self, text, **kwargs) -> dict:
|
| 121 |
+
return_tensors = kwargs.get("return_tensors", "pt")
|
| 122 |
+
padding = kwargs.get("padding", False)
|
| 123 |
+
max_length = kwargs.get("max_length", None)
|
| 124 |
+
add_special_tokens = kwargs.get("add_special_tokens", True)
|
| 125 |
+
if isinstance(text, str):
|
| 126 |
+
text = [text]
|
| 127 |
+
result = self.encode_batch(
|
| 128 |
+
text, add_special_tokens=add_special_tokens,
|
| 129 |
+
max_length=max_length, padding=padding,
|
| 130 |
+
return_tensors=return_tensors
|
| 131 |
+
)
|
| 132 |
+
if isinstance(result, list):
|
| 133 |
+
return {"input_ids": torch.tensor(result, dtype=torch.long)}
|
| 134 |
+
return result
|
| 135 |
+
|
| 136 |
+
def get_vocab(self) -> dict:
|
| 137 |
+
return {
|
| 138 |
+
self.eos_token_id: self.eos_token,
|
| 139 |
+
self.pad_token_id: self.pad_token,
|
| 140 |
+
self.user_token_id: "<|user|>",
|
| 141 |
+
self.assistant_token_id: "<|assistant|>",
|
| 142 |
+
self.system_token_id: "<|system|>",
|
| 143 |
+
}
|
| 144 |
+
|
| 145 |
+
def apply_chat_template(self, messages: List[dict],
|
| 146 |
+
add_generation_prompt: bool = False) -> str:
|
| 147 |
+
parts = []
|
| 148 |
+
for msg in messages:
|
| 149 |
+
role = msg.get("role", "user")
|
| 150 |
+
content = msg.get("content", "")
|
| 151 |
+
if role == "system":
|
| 152 |
+
parts.append(f"<|system|>\n{content}\n<|endofprompt|>")
|
| 153 |
+
elif role == "user":
|
| 154 |
+
parts.append(f"<|user|>\n{content}\n<|endofprompt|>")
|
| 155 |
+
elif role == "assistant":
|
| 156 |
+
parts.append(f"<|assistant|>\n{content}\n<|endofprompt|>")
|
| 157 |
+
text = "\n".join(parts)
|
| 158 |
+
if add_generation_prompt:
|
| 159 |
+
text += "\n<|assistant|>\n"
|
| 160 |
+
return text
|
chimera/training/__init__.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .benchmark import benchmark_hyper, run_baseline, run_hyper
|
| 2 |
+
from .common import (
|
| 3 |
+
DEFAULT_SCALE_PRESETS,
|
| 4 |
+
apply_standard_config_tweaks,
|
| 5 |
+
build_model_from_args,
|
| 6 |
+
cosine_lr,
|
| 7 |
+
save_final_checkpoint,
|
| 8 |
+
save_training_checkpoint,
|
| 9 |
+
setup_cpu_runtime,
|
| 10 |
+
)
|
| 11 |
+
from .datasets import (
|
| 12 |
+
GrowLengthDataset,
|
| 13 |
+
PreTokenizedDataset,
|
| 14 |
+
SequenceTokenDataset,
|
| 15 |
+
build_sequence_dataset,
|
| 16 |
+
build_token_buffer,
|
| 17 |
+
format_dataset_example,
|
| 18 |
+
matches_category_filter,
|
| 19 |
+
)
|
| 20 |
+
from .hyper import (
|
| 21 |
+
GrowLengthScheduler,
|
| 22 |
+
ProgressiveUnfreezer,
|
| 23 |
+
SeedReplayMeZO,
|
| 24 |
+
apply_reservoir_freezing,
|
| 25 |
+
patch_training_loops,
|
| 26 |
+
)
|
| 27 |
+
from .loops import train_fast_loop, train_hyper_loop, train_standard_loop
|
| 28 |
+
from .optimizers import MeZOOptimizer
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"DEFAULT_SCALE_PRESETS",
|
| 32 |
+
"GrowLengthDataset",
|
| 33 |
+
"GrowLengthScheduler",
|
| 34 |
+
"MeZOOptimizer",
|
| 35 |
+
"PreTokenizedDataset",
|
| 36 |
+
"ProgressiveUnfreezer",
|
| 37 |
+
"SeedReplayMeZO",
|
| 38 |
+
"SequenceTokenDataset",
|
| 39 |
+
"benchmark_hyper",
|
| 40 |
+
"build_sequence_dataset",
|
| 41 |
+
"build_token_buffer",
|
| 42 |
+
"format_dataset_example",
|
| 43 |
+
"matches_category_filter",
|
| 44 |
+
"apply_reservoir_freezing",
|
| 45 |
+
"apply_standard_config_tweaks",
|
| 46 |
+
"build_model_from_args",
|
| 47 |
+
"cosine_lr",
|
| 48 |
+
"patch_training_loops",
|
| 49 |
+
"save_final_checkpoint",
|
| 50 |
+
"save_training_checkpoint",
|
| 51 |
+
"setup_cpu_runtime",
|
| 52 |
+
"run_baseline",
|
| 53 |
+
"run_hyper",
|
| 54 |
+
"train_fast_loop",
|
| 55 |
+
"train_hyper_loop",
|
| 56 |
+
"train_standard_loop",
|
| 57 |
+
]
|
chimera/training/benchmark.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import copy
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.utils.data import DataLoader, Dataset
|
| 10 |
+
|
| 11 |
+
from chimera.quantization import BitLinear
|
| 12 |
+
|
| 13 |
+
from .common import build_model_from_args
|
| 14 |
+
from .datasets import GrowLengthDataset, build_token_buffer
|
| 15 |
+
from .hyper import (
|
| 16 |
+
GrowLengthScheduler,
|
| 17 |
+
ProgressiveUnfreezer,
|
| 18 |
+
SeedReplayMeZO,
|
| 19 |
+
apply_reservoir_freezing,
|
| 20 |
+
patch_training_loops,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def run_baseline(model, token_buf, args):
|
| 25 |
+
model.train()
|
| 26 |
+
seq = args.seq_len
|
| 27 |
+
n = token_buf.numel() // (seq + 1)
|
| 28 |
+
chunks = token_buf[: n * (seq + 1)].view(n, seq + 1)
|
| 29 |
+
|
| 30 |
+
class _Dataset(Dataset):
|
| 31 |
+
def __len__(self):
|
| 32 |
+
return chunks.size(0)
|
| 33 |
+
|
| 34 |
+
def __getitem__(self, i):
|
| 35 |
+
c = chunks[i]
|
| 36 |
+
return {"input_ids": c[:-1], "labels": c[1:]}
|
| 37 |
+
|
| 38 |
+
loader = DataLoader(_Dataset(), batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
|
| 39 |
+
params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
|
| 40 |
+
eps = 1e-3
|
| 41 |
+
|
| 42 |
+
def loss_fn(batch):
|
| 43 |
+
return model(batch["input_ids"], labels=batch["labels"]).loss
|
| 44 |
+
|
| 45 |
+
total_toks, total_loss = 0, 0.0
|
| 46 |
+
t0 = time.time()
|
| 47 |
+
di = iter(loader)
|
| 48 |
+
for _ in range(args.max_steps):
|
| 49 |
+
try:
|
| 50 |
+
batch = next(di)
|
| 51 |
+
except StopIteration:
|
| 52 |
+
di = iter(loader)
|
| 53 |
+
batch = next(di)
|
| 54 |
+
seed = int(torch.randint(0, 2**31, (1,)).item())
|
| 55 |
+
gen = torch.Generator(device="cpu")
|
| 56 |
+
gen.manual_seed(seed)
|
| 57 |
+
for _, p in params:
|
| 58 |
+
p.data.add_(torch.randn(p.shape, generator=gen), alpha=eps)
|
| 59 |
+
for m in model.modules():
|
| 60 |
+
if isinstance(m, BitLinear):
|
| 61 |
+
m.invalidate_packed()
|
| 62 |
+
with torch.no_grad():
|
| 63 |
+
lp = float(loss_fn(batch).item())
|
| 64 |
+
gen.manual_seed(seed)
|
| 65 |
+
for _, p in params:
|
| 66 |
+
p.data.add_(torch.randn(p.shape, generator=gen), alpha=-2 * eps)
|
| 67 |
+
for m in model.modules():
|
| 68 |
+
if isinstance(m, BitLinear):
|
| 69 |
+
m.invalidate_packed()
|
| 70 |
+
with torch.no_grad():
|
| 71 |
+
ln = float(loss_fn(batch).item())
|
| 72 |
+
g = (lp - ln) / (2 * eps)
|
| 73 |
+
gen.manual_seed(seed)
|
| 74 |
+
for _, p in params:
|
| 75 |
+
z = torch.randn(p.shape, generator=gen)
|
| 76 |
+
p.data.add_(z, alpha=eps - args.lr * g)
|
| 77 |
+
for m in model.modules():
|
| 78 |
+
if isinstance(m, BitLinear):
|
| 79 |
+
m.invalidate_packed()
|
| 80 |
+
total_toks += batch["input_ids"].numel()
|
| 81 |
+
total_loss += 0.5 * (lp + ln)
|
| 82 |
+
dt = time.time() - t0
|
| 83 |
+
return total_toks / dt, total_loss / args.max_steps, dt
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def run_hyper(model, token_buf, args):
|
| 87 |
+
model.train()
|
| 88 |
+
patch_training_loops(model, num_loops=1)
|
| 89 |
+
if args.reservoir:
|
| 90 |
+
apply_reservoir_freezing(model)
|
| 91 |
+
unfreezer = ProgressiveUnfreezer(model, args.max_steps, args.unfreeze_stages) if args.progressive_unfreeze else None
|
| 92 |
+
stages = [
|
| 93 |
+
(max(8, args.seq_len // 4), 0.30),
|
| 94 |
+
(max(16, args.seq_len // 2), 0.30),
|
| 95 |
+
(args.seq_len, 0.40),
|
| 96 |
+
]
|
| 97 |
+
grow = GrowLengthScheduler(stages, args.max_steps) if args.growlength else None
|
| 98 |
+
cur_seq = stages[0][0] if grow else args.seq_len
|
| 99 |
+
dataset = GrowLengthDataset(token_buf, cur_seq)
|
| 100 |
+
opt = SeedReplayMeZO(model, lr=args.lr * 0.01, eps=args.mezo_eps, weight_decay=0.1, momentum=0.9)
|
| 101 |
+
|
| 102 |
+
def loss_fn(batch):
|
| 103 |
+
if args.bf16:
|
| 104 |
+
with torch.autocast("cpu", dtype=torch.bfloat16):
|
| 105 |
+
return model(batch["input_ids"], labels=batch["labels"]).loss
|
| 106 |
+
return model(batch["input_ids"], labels=batch["labels"]).loss
|
| 107 |
+
|
| 108 |
+
total_toks, total_loss = 0, 0.0
|
| 109 |
+
t0 = time.time()
|
| 110 |
+
eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
|
| 111 |
+
loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
|
| 112 |
+
di = iter(loader)
|
| 113 |
+
for step in range(args.max_steps):
|
| 114 |
+
if grow:
|
| 115 |
+
ns = grow.get_seq_len(step)
|
| 116 |
+
if ns != cur_seq:
|
| 117 |
+
cur_seq = ns
|
| 118 |
+
dataset.set_seq_len(cur_seq)
|
| 119 |
+
eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
|
| 120 |
+
loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
|
| 121 |
+
di = iter(loader)
|
| 122 |
+
if unfreezer:
|
| 123 |
+
unfreezer.update(step)
|
| 124 |
+
try:
|
| 125 |
+
batch = next(di)
|
| 126 |
+
except StopIteration:
|
| 127 |
+
di = iter(loader)
|
| 128 |
+
batch = next(di)
|
| 129 |
+
loss_val = opt.step(loss_fn, batch)
|
| 130 |
+
total_toks += batch["input_ids"].numel()
|
| 131 |
+
total_loss += loss_val
|
| 132 |
+
dt = time.time() - t0
|
| 133 |
+
return total_toks / dt, total_loss / args.max_steps, dt
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def benchmark_hyper(args):
|
| 137 |
+
print("=" * 65)
|
| 138 |
+
print("CHIMERA 5.3 HYPER v3 β BENCHMARK (full arch, all features)")
|
| 139 |
+
print("=" * 65)
|
| 140 |
+
model_a, cfg = build_model_from_args(args)
|
| 141 |
+
model_b = copy.deepcopy(model_a)
|
| 142 |
+
c = model_a.count_parameters()
|
| 143 |
+
print(f"Model: {c['total']:,} params, {cfg['num_hidden_layers']} layers")
|
| 144 |
+
print(f"Features: looping={model_a.looping_enabled} evolution={model_a.evolution is not None} span={model_a.span_engine is not None}")
|
| 145 |
+
|
| 146 |
+
tok_budget = max(500_000, args.max_steps * args.batch_size * (args.seq_len + 1) * 8)
|
| 147 |
+
token_buf = build_token_buffer(args.dataset_name, args.dataset_split, args.text_column, tok_budget, args.cache_dir)
|
| 148 |
+
print(f"Tokens: {token_buf.numel():,}\n")
|
| 149 |
+
|
| 150 |
+
print("-" * 65)
|
| 151 |
+
print("BASELINE (randn MeZO, invalidate_packed, loop=2, full evo)")
|
| 152 |
+
print("-" * 65)
|
| 153 |
+
bt, bl, bd = run_baseline(model_a, token_buf, args)
|
| 154 |
+
print(f" -> {bt:,.0f} tok/s loss={bl:.4f} time={bd:.1f}s\n")
|
| 155 |
+
|
| 156 |
+
print("-" * 65)
|
| 157 |
+
print("HYPER (seed-replay MeZO, STE path, loop=1, GrowLength, Reservoir)")
|
| 158 |
+
print("-" * 65)
|
| 159 |
+
ht, hl, hd = run_hyper(model_b, token_buf, args)
|
| 160 |
+
print(f" -> {ht:,.0f} tok/s loss={hl:.4f} time={hd:.1f}s\n")
|
| 161 |
+
|
| 162 |
+
sp = ht / bt if bt > 0 else float("inf")
|
| 163 |
+
print("=" * 65)
|
| 164 |
+
print(f" Baseline : {bt:>10,.0f} tok/s loss {bl:.4f}")
|
| 165 |
+
print(f" Hyper : {ht:>10,.0f} tok/s loss {hl:.4f}")
|
| 166 |
+
print(f" Speedup : {sp:>10.1f}x")
|
| 167 |
+
print("=" * 65)
|
| 168 |
+
|
| 169 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 170 |
+
with open(os.path.join(args.output_dir, "benchmark.json"), "w") as f:
|
| 171 |
+
json.dump({"baseline_tps": round(bt), "hyper_tps": round(ht), "speedup": round(sp, 2)}, f, indent=2)
|
chimera/training/common.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import Any
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from chimera import Chimera51ForCausalLM
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
DEFAULT_SCALE_PRESETS = {
|
| 15 |
+
"tiny": dict(hidden_size=256, intermediate_size=512, num_heads=4, head_dim=48),
|
| 16 |
+
"small": dict(hidden_size=512, intermediate_size=1024, num_heads=8, head_dim=48),
|
| 17 |
+
"medium": dict(hidden_size=1024, intermediate_size=2048, num_heads=8, head_dim=96),
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def setup_cpu_runtime(*, interop_threads: int | None = None) -> int:
|
| 22 |
+
n_cpus = os.cpu_count() or 4
|
| 23 |
+
os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
|
| 24 |
+
os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
|
| 25 |
+
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
|
| 26 |
+
os.environ.setdefault("KMP_BLOCKTIME", "1")
|
| 27 |
+
os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
|
| 28 |
+
|
| 29 |
+
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", n_cpus)))
|
| 30 |
+
try:
|
| 31 |
+
target = interop_threads
|
| 32 |
+
if target is None:
|
| 33 |
+
target = int(os.environ.get("CHIMERA_INTEROP_THREADS", "1"))
|
| 34 |
+
torch.set_num_interop_threads(target)
|
| 35 |
+
except RuntimeError:
|
| 36 |
+
pass
|
| 37 |
+
return n_cpus
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def cosine_lr(step: int, warmup: int, total: int, max_lr: float, min_lr: float) -> float:
|
| 41 |
+
if warmup > 0 and step < warmup:
|
| 42 |
+
return max_lr * (step + 1) / warmup
|
| 43 |
+
if step >= total:
|
| 44 |
+
return min_lr
|
| 45 |
+
progress = (step - warmup) / max(1, total - warmup)
|
| 46 |
+
return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_json_config(path: str | os.PathLike[str]) -> dict[str, Any]:
|
| 50 |
+
with open(path, encoding="utf-8") as fh:
|
| 51 |
+
return json.load(fh)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def apply_standard_config_tweaks(config: dict[str, Any], *, scale: str, seq_len: int) -> dict[str, Any]:
|
| 55 |
+
config = dict(config)
|
| 56 |
+
if scale in DEFAULT_SCALE_PRESETS:
|
| 57 |
+
config.update(DEFAULT_SCALE_PRESETS[scale])
|
| 58 |
+
config["num_hidden_layers"] = int(config.get("num_hidden_layers", 28))
|
| 59 |
+
config["vocab_size"] = config.get("vocab_size", 200073)
|
| 60 |
+
config.setdefault("gated_deltanet", {})["chunk_size"] = min(seq_len, 64)
|
| 61 |
+
config.setdefault("xlstm", {})["memory_size_per_head"] = [config["head_dim"], config["head_dim"]]
|
| 62 |
+
config.setdefault("titans", {}).update({
|
| 63 |
+
"memory_depth": 2,
|
| 64 |
+
"persistent_memory_slots": 16,
|
| 65 |
+
"local_window_size": min(seq_len, 256),
|
| 66 |
+
})
|
| 67 |
+
moe_cfg = config.setdefault("backbone", {}).setdefault("moe", {})
|
| 68 |
+
moe_cfg.setdefault("layers", [3, 7, 11, 15, 19, 23, 27])
|
| 69 |
+
moe_cfg.setdefault("moe_intermediate_size", config["intermediate_size"] // 4)
|
| 70 |
+
moe_cfg.setdefault("n_routed_experts", 8)
|
| 71 |
+
moe_cfg.setdefault("n_shared_experts", 1)
|
| 72 |
+
moe_cfg.setdefault("num_experts_per_tok", 2)
|
| 73 |
+
config.setdefault("looping", {}).update({
|
| 74 |
+
"enabled": True,
|
| 75 |
+
"prelude": [0, 3],
|
| 76 |
+
"loop": [4, 23],
|
| 77 |
+
"coda": [24, 27],
|
| 78 |
+
"loop_range": [1, 3],
|
| 79 |
+
"loop_default": 2,
|
| 80 |
+
})
|
| 81 |
+
config.setdefault("span_inference", {})["enabled"] = True
|
| 82 |
+
config.setdefault("grammar", {})["enabled"] = True
|
| 83 |
+
config.setdefault("entropy_valve", {})["enabled"] = True
|
| 84 |
+
config.setdefault("debt_ledger", {})["enabled"] = True
|
| 85 |
+
config.setdefault("multimodal", {})["enabled"] = False
|
| 86 |
+
return config
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def build_model_from_args(args) -> tuple[Chimera51ForCausalLM, dict[str, Any]]:
|
| 90 |
+
config = load_json_config(args.config)
|
| 91 |
+
config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
|
| 92 |
+
return Chimera51ForCausalLM(config), config
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def save_training_checkpoint(model, config: dict[str, Any], step: int, output_dir: str) -> str:
|
| 96 |
+
ckpt_dir = Path(output_dir)
|
| 97 |
+
ckpt_dir.mkdir(parents=True, exist_ok=True)
|
| 98 |
+
raw_model = getattr(model, "_orig_mod", model)
|
| 99 |
+
torch.save({"model": raw_model.state_dict(), "config": config, "step": step}, ckpt_dir / "ckpt.pt")
|
| 100 |
+
return str(ckpt_dir)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def save_final_checkpoint(
|
| 104 |
+
model,
|
| 105 |
+
config: dict[str, Any],
|
| 106 |
+
step: int,
|
| 107 |
+
best_loss: float,
|
| 108 |
+
output_dir: str,
|
| 109 |
+
) -> str:
|
| 110 |
+
final_dir = Path(output_dir)
|
| 111 |
+
final_dir.mkdir(parents=True, exist_ok=True)
|
| 112 |
+
raw_model = getattr(model, "_orig_mod", model)
|
| 113 |
+
torch.save(
|
| 114 |
+
{"model": raw_model.state_dict(), "config": config, "step": step, "best_loss": best_loss},
|
| 115 |
+
final_dir / "model.pt",
|
| 116 |
+
)
|
| 117 |
+
with open(final_dir / "config.json", "w", encoding="utf-8") as fh:
|
| 118 |
+
json.dump(config, fh, indent=2)
|
| 119 |
+
return str(final_dir)
|
chimera/training/datasets.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SequenceTokenDataset(Dataset):
|
| 10 |
+
def __init__(self, chunks: torch.Tensor):
|
| 11 |
+
self.chunks = chunks
|
| 12 |
+
|
| 13 |
+
def __len__(self) -> int:
|
| 14 |
+
return self.chunks.size(0)
|
| 15 |
+
|
| 16 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
| 17 |
+
chunk = self.chunks[idx]
|
| 18 |
+
return {"input_ids": chunk, "labels": chunk}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class PreTokenizedDataset(Dataset):
|
| 22 |
+
def __init__(self, ids: torch.Tensor, seq_len: int):
|
| 23 |
+
n = ids.numel() // (seq_len + 1)
|
| 24 |
+
self.chunks = ids[: n * (seq_len + 1)].view(n, seq_len + 1)
|
| 25 |
+
self.seq_len = seq_len
|
| 26 |
+
|
| 27 |
+
def __len__(self) -> int:
|
| 28 |
+
return self.chunks.size(0)
|
| 29 |
+
|
| 30 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
| 31 |
+
chunk = self.chunks[idx]
|
| 32 |
+
return {"input_ids": chunk[:-1], "labels": chunk[1:]}
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class GrowLengthDataset(Dataset):
|
| 36 |
+
def __init__(self, all_ids: torch.Tensor, seq_len: int = 16):
|
| 37 |
+
self.all_ids = all_ids
|
| 38 |
+
self._seq_len = 0
|
| 39 |
+
self._n = 0
|
| 40 |
+
self.set_seq_len(seq_len)
|
| 41 |
+
|
| 42 |
+
def set_seq_len(self, seq_len: int) -> None:
|
| 43 |
+
self._seq_len = int(seq_len)
|
| 44 |
+
self._n = self.all_ids.numel() // (self._seq_len + 1)
|
| 45 |
+
|
| 46 |
+
@property
|
| 47 |
+
def seq_len(self) -> int:
|
| 48 |
+
return self._seq_len
|
| 49 |
+
|
| 50 |
+
def __len__(self) -> int:
|
| 51 |
+
return self._n
|
| 52 |
+
|
| 53 |
+
def __getitem__(self, idx: int) -> dict[str, torch.Tensor]:
|
| 54 |
+
start = idx * (self._seq_len + 1)
|
| 55 |
+
chunk = self.all_ids[start : start + self._seq_len + 1]
|
| 56 |
+
return {"input_ids": chunk[:-1], "labels": chunk[1:]}
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def matches_category_filter(example: dict, filters: list[str]) -> bool:
|
| 60 |
+
category = example.get("category", "") or ""
|
| 61 |
+
if not category:
|
| 62 |
+
return False
|
| 63 |
+
category_lower = category.lower()
|
| 64 |
+
return any(f.lower() in category_lower for f in filters)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def format_dataset_example(ex: dict, tok, text_column: str = "auto", include_reasoning: bool = False) -> str:
|
| 68 |
+
if text_column == "auto":
|
| 69 |
+
for candidate in ("messages", "text", "content", "conversation"):
|
| 70 |
+
if candidate in ex:
|
| 71 |
+
text_column = candidate
|
| 72 |
+
break
|
| 73 |
+
else:
|
| 74 |
+
text_column = ""
|
| 75 |
+
|
| 76 |
+
if text_column == "messages" and "messages" in ex:
|
| 77 |
+
messages = ex["messages"]
|
| 78 |
+
if include_reasoning and isinstance(messages, list):
|
| 79 |
+
rewritten = []
|
| 80 |
+
for message in messages:
|
| 81 |
+
if isinstance(message, dict) and message.get("role") == "assistant" and "reasoning" in message:
|
| 82 |
+
rewritten.append(
|
| 83 |
+
{
|
| 84 |
+
"role": "assistant",
|
| 85 |
+
"content": (
|
| 86 |
+
f"<|thinking|>\n{message['reasoning']}\n<|/thinking|>\n"
|
| 87 |
+
f"{message.get('content', '')}"
|
| 88 |
+
),
|
| 89 |
+
}
|
| 90 |
+
)
|
| 91 |
+
else:
|
| 92 |
+
rewritten.append(message)
|
| 93 |
+
messages = rewritten
|
| 94 |
+
return tok.apply_chat_template(messages)
|
| 95 |
+
|
| 96 |
+
if text_column and text_column in ex:
|
| 97 |
+
value = ex[text_column]
|
| 98 |
+
if isinstance(value, str):
|
| 99 |
+
return value
|
| 100 |
+
if isinstance(value, list) and value and isinstance(value[0], dict):
|
| 101 |
+
return tok.apply_chat_template(value)
|
| 102 |
+
return str(value)
|
| 103 |
+
return str(ex)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def build_token_buffer(
|
| 107 |
+
dataset_name: str,
|
| 108 |
+
split: str,
|
| 109 |
+
text_column: str,
|
| 110 |
+
max_tokens: int,
|
| 111 |
+
cache_dir: str,
|
| 112 |
+
*,
|
| 113 |
+
dataset_config: str | None = None,
|
| 114 |
+
category_filter: str | None = None,
|
| 115 |
+
include_reasoning: bool = False,
|
| 116 |
+
):
|
| 117 |
+
from datasets import load_dataset
|
| 118 |
+
from chimera import ChimeraTokenizer
|
| 119 |
+
|
| 120 |
+
cache_name = f"{dataset_name.replace('/', '_')}_{split}_{max_tokens}.pt"
|
| 121 |
+
cache_path = os.path.join(cache_dir, cache_name)
|
| 122 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 123 |
+
|
| 124 |
+
if os.path.exists(cache_path):
|
| 125 |
+
print(f"[DATA] Cache hit: {cache_path}")
|
| 126 |
+
return torch.load(cache_path, weights_only=True)
|
| 127 |
+
|
| 128 |
+
print(f"[DATA] Streaming {dataset_name} ({split})...")
|
| 129 |
+
load_kwargs = {"split": split, "streaming": True}
|
| 130 |
+
if dataset_config:
|
| 131 |
+
load_kwargs["name"] = dataset_config
|
| 132 |
+
ds = load_dataset(dataset_name, **load_kwargs)
|
| 133 |
+
tok = ChimeraTokenizer(pretrained="o200k_base")
|
| 134 |
+
|
| 135 |
+
filters = [c.strip() for c in category_filter.split(",") if c.strip()] if category_filter else None
|
| 136 |
+
if filters:
|
| 137 |
+
print(f"[DATA] Filtering categories: {filters}")
|
| 138 |
+
|
| 139 |
+
buf = torch.empty(max_tokens, dtype=torch.long)
|
| 140 |
+
idx = processed = skipped = 0
|
| 141 |
+
for ex in ds:
|
| 142 |
+
if filters and not matches_category_filter(ex, filters):
|
| 143 |
+
skipped += 1
|
| 144 |
+
continue
|
| 145 |
+
text = format_dataset_example(ex, tok, text_column, include_reasoning)
|
| 146 |
+
if not text or not text.strip():
|
| 147 |
+
skipped += 1
|
| 148 |
+
continue
|
| 149 |
+
ids = tok.encode(text, add_special_tokens=False)
|
| 150 |
+
ids.append(tok.eos_token_id)
|
| 151 |
+
n = min(len(ids), max_tokens - idx)
|
| 152 |
+
if n <= 0:
|
| 153 |
+
break
|
| 154 |
+
buf[idx : idx + n] = torch.tensor(ids[:n], dtype=torch.long)
|
| 155 |
+
idx += n
|
| 156 |
+
processed += 1
|
| 157 |
+
if processed % 5000 == 0:
|
| 158 |
+
print(f" {processed:,} docs {idx:,}/{max_tokens} tokens")
|
| 159 |
+
|
| 160 |
+
token_buf = buf[:idx].contiguous()
|
| 161 |
+
torch.save(token_buf, cache_path)
|
| 162 |
+
print(f"[DATA] Processed {processed:,} examples, skipped {skipped:,}.")
|
| 163 |
+
print(f"[DATA] {idx:,} tokens -> {cache_path}")
|
| 164 |
+
return token_buf
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def build_sequence_dataset(
|
| 168 |
+
seq_len: int,
|
| 169 |
+
*,
|
| 170 |
+
max_samples=None,
|
| 171 |
+
max_tokens=None,
|
| 172 |
+
split: str = "train",
|
| 173 |
+
dataset_name: str = "roneneldan/TinyStories",
|
| 174 |
+
dataset_config: str | None = None,
|
| 175 |
+
text_column: str = "auto",
|
| 176 |
+
category_filter: str | None = None,
|
| 177 |
+
include_reasoning: bool = False,
|
| 178 |
+
cache_dir: str = "./cache",
|
| 179 |
+
):
|
| 180 |
+
token_budget = int(max_tokens) if max_tokens is not None else None
|
| 181 |
+
if token_budget is None and max_samples is not None:
|
| 182 |
+
token_budget = int(max_samples) * (seq_len + 1)
|
| 183 |
+
if token_budget is None or token_budget <= 0:
|
| 184 |
+
token_budget = max(500_000, (int(max_samples) if max_samples else 10000) * (seq_len + 1))
|
| 185 |
+
|
| 186 |
+
token_buffer = build_token_buffer(
|
| 187 |
+
dataset_name,
|
| 188 |
+
split,
|
| 189 |
+
text_column,
|
| 190 |
+
token_budget,
|
| 191 |
+
cache_dir,
|
| 192 |
+
dataset_config=dataset_config,
|
| 193 |
+
category_filter=category_filter,
|
| 194 |
+
include_reasoning=include_reasoning,
|
| 195 |
+
)
|
| 196 |
+
|
| 197 |
+
if token_buffer.numel() == 0:
|
| 198 |
+
raise ValueError("No data matched filters.")
|
| 199 |
+
|
| 200 |
+
n = token_buffer.numel() // (seq_len + 1)
|
| 201 |
+
if max_samples:
|
| 202 |
+
n = min(n, max_samples)
|
| 203 |
+
chunks = token_buffer[: n * (seq_len + 1)].view(n, seq_len + 1)
|
| 204 |
+
print(f"[DATA] {n:,} chunks Γ {seq_len} tokens = {n * seq_len:,} total")
|
| 205 |
+
return SequenceTokenDataset(chunks)
|
chimera/training/hyper.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class GrowLengthScheduler:
|
| 8 |
+
def __init__(self, stages, total_steps):
|
| 9 |
+
total_frac = sum(frac for _, frac in stages) or 1.0
|
| 10 |
+
cumulative = 0
|
| 11 |
+
self._boundaries = []
|
| 12 |
+
for seq_len, frac in stages:
|
| 13 |
+
cumulative += int(total_steps * frac / total_frac)
|
| 14 |
+
self._boundaries.append((cumulative, int(seq_len)))
|
| 15 |
+
|
| 16 |
+
def get_seq_len(self, step: int) -> int:
|
| 17 |
+
for boundary, seq_len in self._boundaries:
|
| 18 |
+
if step < boundary:
|
| 19 |
+
return seq_len
|
| 20 |
+
return self._boundaries[-1][1]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def apply_reservoir_freezing(model) -> int:
|
| 24 |
+
frozen = 0
|
| 25 |
+
for _, module in model.named_modules():
|
| 26 |
+
targets = []
|
| 27 |
+
if hasattr(module, "a_proj") and hasattr(module, "b_proj"):
|
| 28 |
+
targets.extend(["a_proj", "b_proj"])
|
| 29 |
+
if hasattr(module, "fgate") and hasattr(module, "igate"):
|
| 30 |
+
targets.append("fgate")
|
| 31 |
+
if hasattr(module, "alpha_proj") and hasattr(module, "eta_proj"):
|
| 32 |
+
targets.append("alpha_proj")
|
| 33 |
+
for attr in targets:
|
| 34 |
+
proj = getattr(module, attr, None)
|
| 35 |
+
if proj is None:
|
| 36 |
+
continue
|
| 37 |
+
weight = getattr(proj, "weight", None)
|
| 38 |
+
if weight is None or not isinstance(weight, nn.Parameter):
|
| 39 |
+
continue
|
| 40 |
+
with torch.no_grad():
|
| 41 |
+
weight.data = torch.randint(-1, 2, weight.shape, dtype=weight.dtype, device=weight.device)
|
| 42 |
+
norm = torch.linalg.matrix_norm(weight.data.float(), ord=2).clamp(min=1.0)
|
| 43 |
+
weight.data.div_(norm)
|
| 44 |
+
weight.requires_grad = False
|
| 45 |
+
frozen += weight.numel()
|
| 46 |
+
return frozen
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class SeedReplayMeZO:
|
| 50 |
+
def __init__(self, model, *, lr=1e-4, eps=1e-3, weight_decay=0.0, momentum=0.9):
|
| 51 |
+
self.model = model
|
| 52 |
+
self.lr = float(lr)
|
| 53 |
+
self.eps = float(eps)
|
| 54 |
+
self.wd = float(weight_decay)
|
| 55 |
+
self.mom = float(momentum)
|
| 56 |
+
self._params = []
|
| 57 |
+
seen = set()
|
| 58 |
+
for _, param in model.named_parameters():
|
| 59 |
+
if param.requires_grad and id(param) not in seen:
|
| 60 |
+
self._params.append(param)
|
| 61 |
+
seen.add(id(param))
|
| 62 |
+
self._momentum = [torch.zeros_like(param.data) for param in self._params] if self.mom > 0 else None
|
| 63 |
+
|
| 64 |
+
def _perturb_inplace(self, seed: int, scale: float) -> None:
|
| 65 |
+
gen = torch.Generator(device="cpu")
|
| 66 |
+
for i, param in enumerate(self._params):
|
| 67 |
+
gen.manual_seed((seed + i * 999983) & 0x7FFFFFFFFFFFFFFF)
|
| 68 |
+
z = torch.empty_like(param.data)
|
| 69 |
+
z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
|
| 70 |
+
param.data.add_(z, alpha=scale)
|
| 71 |
+
|
| 72 |
+
def _update_inplace(self, seed: int, projected_grad: float) -> None:
|
| 73 |
+
gen = torch.Generator(device="cpu")
|
| 74 |
+
for i, param in enumerate(self._params):
|
| 75 |
+
gen.manual_seed((seed + i * 999983) & 0x7FFFFFFFFFFFFFFF)
|
| 76 |
+
z = torch.empty_like(param.data)
|
| 77 |
+
z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
|
| 78 |
+
param.data.add_(z, alpha=self.eps)
|
| 79 |
+
if self._momentum is not None:
|
| 80 |
+
buf = self._momentum[i]
|
| 81 |
+
buf.mul_(self.mom).add_(z, alpha=projected_grad)
|
| 82 |
+
param.data.add_(buf, alpha=-self.lr)
|
| 83 |
+
else:
|
| 84 |
+
param.data.add_(z, alpha=-self.lr * projected_grad)
|
| 85 |
+
if self.wd > 0:
|
| 86 |
+
param.data.mul_(1 - self.lr * self.wd)
|
| 87 |
+
|
| 88 |
+
@torch.no_grad()
|
| 89 |
+
def step(self, loss_fn, batch) -> float:
|
| 90 |
+
seed = int(torch.randint(0, 2**31, (1,)).item())
|
| 91 |
+
self._perturb_inplace(seed, +self.eps)
|
| 92 |
+
loss_pos = float(loss_fn(batch).item())
|
| 93 |
+
self._perturb_inplace(seed, -2.0 * self.eps)
|
| 94 |
+
loss_neg = float(loss_fn(batch).item())
|
| 95 |
+
projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
|
| 96 |
+
self._update_inplace(seed, projected_grad)
|
| 97 |
+
return 0.5 * (loss_pos + loss_neg)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class ProgressiveUnfreezer:
|
| 101 |
+
def __init__(self, model, total_steps, n_stages=4):
|
| 102 |
+
self._layers = model.layers
|
| 103 |
+
self._n = len(self._layers)
|
| 104 |
+
self._total = total_steps
|
| 105 |
+
self._stages = n_stages
|
| 106 |
+
self._block = max(1, self._n // n_stages)
|
| 107 |
+
self._current = self._n
|
| 108 |
+
self.update(0)
|
| 109 |
+
|
| 110 |
+
def update(self, step: int) -> int:
|
| 111 |
+
stage = min(step * self._stages // max(1, self._total), self._stages - 1)
|
| 112 |
+
target = max(0, self._n - (stage + 1) * self._block)
|
| 113 |
+
if target != self._current:
|
| 114 |
+
self._current = target
|
| 115 |
+
for i, layer in enumerate(self._layers):
|
| 116 |
+
requires_grad = i >= self._current
|
| 117 |
+
for param in layer.parameters():
|
| 118 |
+
param.requires_grad = requires_grad
|
| 119 |
+
return self._current
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def patch_training_loops(model, num_loops=1) -> None:
|
| 123 |
+
if hasattr(model, "loop_controller"):
|
| 124 |
+
model.loop_controller.loop_default = num_loops
|
| 125 |
+
model.loop_controller.loop_min = 1
|
| 126 |
+
model.loop_controller.loop_max = max(num_loops, 1)
|
| 127 |
+
if hasattr(model, "evo_every_n_layers"):
|
| 128 |
+
model.evo_every_n_layers = max(model.evo_every_n_layers, 8)
|
chimera/training/loops.py
ADDED
|
@@ -0,0 +1,224 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
import chimera_turbo
|
| 11 |
+
|
| 12 |
+
from .common import cosine_lr, save_final_checkpoint, save_training_checkpoint
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def train_fast_loop(args, model, config, loader, compute_loss) -> str:
|
| 16 |
+
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95))
|
| 17 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 18 |
+
log_f = open(os.path.join(args.output_dir, "log.jsonl"), "w", encoding="utf-8")
|
| 19 |
+
|
| 20 |
+
model.train()
|
| 21 |
+
step = 0
|
| 22 |
+
total_loss = 0.0
|
| 23 |
+
best_loss = float("inf")
|
| 24 |
+
toks = 0
|
| 25 |
+
t0 = time.time()
|
| 26 |
+
data_iter = iter(loader)
|
| 27 |
+
warmup = min(args.warmup, max(1, args.max_steps // 10))
|
| 28 |
+
|
| 29 |
+
print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")
|
| 30 |
+
|
| 31 |
+
while step < args.max_steps:
|
| 32 |
+
try:
|
| 33 |
+
batch = next(data_iter)
|
| 34 |
+
except StopIteration:
|
| 35 |
+
data_iter = iter(loader)
|
| 36 |
+
batch = next(data_iter)
|
| 37 |
+
|
| 38 |
+
loss = compute_loss(batch)
|
| 39 |
+
loss.backward()
|
| 40 |
+
total_loss += float(loss.item())
|
| 41 |
+
|
| 42 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 43 |
+
cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, args.lr * 0.1)
|
| 44 |
+
for pg in optimizer.param_groups:
|
| 45 |
+
pg["lr"] = cur_lr
|
| 46 |
+
optimizer.step()
|
| 47 |
+
optimizer.zero_grad(set_to_none=True)
|
| 48 |
+
|
| 49 |
+
toks += batch["input_ids"].numel()
|
| 50 |
+
step += 1
|
| 51 |
+
|
| 52 |
+
if step % args.log_every == 0:
|
| 53 |
+
dt = time.time() - t0
|
| 54 |
+
avg = total_loss / args.log_every
|
| 55 |
+
ppl = math.exp(min(avg, 20))
|
| 56 |
+
tps = toks / dt if dt > 0 else 0
|
| 57 |
+
eta_h = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0.0
|
| 58 |
+
log_f.write(json.dumps({"step": step, "loss": round(avg, 4), "ppl": round(ppl, 2), "lr": cur_lr, "tok/s": round(tps)}) + "\n")
|
| 59 |
+
log_f.flush()
|
| 60 |
+
print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | lr {cur_lr:.2e} | {tps:.0f} tok/s | ETA {eta_h:.1f}h")
|
| 61 |
+
best_loss = min(best_loss, avg)
|
| 62 |
+
total_loss = 0.0
|
| 63 |
+
toks = 0
|
| 64 |
+
t0 = time.time()
|
| 65 |
+
|
| 66 |
+
if step % args.save_every == 0:
|
| 67 |
+
ckpt_dir = save_training_checkpoint(model, config, step, os.path.join(args.output_dir, f"ckpt-{step}"))
|
| 68 |
+
print(f" [SAVE] {ckpt_dir}")
|
| 69 |
+
|
| 70 |
+
final_dir = save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
|
| 71 |
+
log_f.close()
|
| 72 |
+
print(f"\n{'=' * 60}")
|
| 73 |
+
print(f"DONE β best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
|
| 74 |
+
print(f"Saved to {final_dir}")
|
| 75 |
+
return final_dir
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo: bool) -> str:
|
| 79 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 80 |
+
log_f = open(os.path.join(args.output_dir, "log.jsonl"), "w", encoding="utf-8")
|
| 81 |
+
model.train()
|
| 82 |
+
step = 0
|
| 83 |
+
cur_lr = args.lr
|
| 84 |
+
total_loss = 0.0
|
| 85 |
+
best_loss = float("inf")
|
| 86 |
+
toks = 0
|
| 87 |
+
t0 = time.time()
|
| 88 |
+
data_iter = iter(loader)
|
| 89 |
+
warmup = min(args.warmup, max(1, args.max_steps // 10))
|
| 90 |
+
|
| 91 |
+
if not use_mezo:
|
| 92 |
+
optimizer.zero_grad(set_to_none=True)
|
| 93 |
+
|
| 94 |
+
print(f"\n{'=' * 60}\nTraining starts\n{'=' * 60}\n")
|
| 95 |
+
|
| 96 |
+
while step < args.max_steps:
|
| 97 |
+
try:
|
| 98 |
+
batch = next(data_iter)
|
| 99 |
+
except StopIteration:
|
| 100 |
+
data_iter = iter(loader)
|
| 101 |
+
batch = next(data_iter)
|
| 102 |
+
|
| 103 |
+
if use_mezo:
|
| 104 |
+
cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr * 0.01, args.lr * 0.001)
|
| 105 |
+
optimizer.lr = cur_lr
|
| 106 |
+
loss_val = optimizer.step(compute_loss, batch)
|
| 107 |
+
total_loss += loss_val
|
| 108 |
+
else:
|
| 109 |
+
loss = compute_loss(batch)
|
| 110 |
+
(loss / args.grad_accum).backward()
|
| 111 |
+
total_loss += float(loss.item())
|
| 112 |
+
if (step + 1) % args.grad_accum == 0:
|
| 113 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 114 |
+
cur_lr = cosine_lr(step, warmup, args.max_steps, args.lr, args.lr * 0.1)
|
| 115 |
+
for pg in optimizer.param_groups:
|
| 116 |
+
pg["lr"] = cur_lr
|
| 117 |
+
optimizer.step()
|
| 118 |
+
optimizer.zero_grad(set_to_none=True)
|
| 119 |
+
|
| 120 |
+
toks += batch["input_ids"][:, :-1].numel()
|
| 121 |
+
step += 1
|
| 122 |
+
|
| 123 |
+
if step % args.log_every == 0:
|
| 124 |
+
dt = time.time() - t0
|
| 125 |
+
avg = total_loss / args.log_every
|
| 126 |
+
ppl = math.exp(min(avg, 20))
|
| 127 |
+
tps = toks / dt if dt > 0 else 0
|
| 128 |
+
eta_h = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0.0
|
| 129 |
+
log_f.write(json.dumps({"step": step, "loss": round(avg, 4), "ppl": round(ppl, 2), "lr": cur_lr, "tok/s": round(tps), "optimizer": "mezo" if use_mezo else "adamw"}) + "\n")
|
| 130 |
+
log_f.flush()
|
| 131 |
+
print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | lr {cur_lr:.2e} | {tps:.0f} tok/s | ETA {eta_h:.1f}h")
|
| 132 |
+
best_loss = min(best_loss, avg)
|
| 133 |
+
total_loss = 0.0
|
| 134 |
+
toks = 0
|
| 135 |
+
t0 = time.time()
|
| 136 |
+
|
| 137 |
+
if step % args.save_every == 0:
|
| 138 |
+
ckpt_dir = save_training_checkpoint(model, config, step, os.path.join(args.output_dir, f"ckpt-{step}"))
|
| 139 |
+
print(f" [SAVE] {ckpt_dir}")
|
| 140 |
+
|
| 141 |
+
final_dir = save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
|
| 142 |
+
log_f.close()
|
| 143 |
+
print(f"\n{'=' * 60}")
|
| 144 |
+
print(f"DONE β best loss {best_loss:.4f}, ppl {math.exp(min(best_loss, 20)):.2f}")
|
| 145 |
+
print(f"Saved to {final_dir}")
|
| 146 |
+
return final_dir
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
|
| 150 |
+
model, optimizer, scheduler = chimera_turbo.apply(
|
| 151 |
+
model,
|
| 152 |
+
max_steps=args.max_steps,
|
| 153 |
+
lr=1e-3,
|
| 154 |
+
weight_decay=0.05,
|
| 155 |
+
warmup_steps=min(500, args.max_steps // 10),
|
| 156 |
+
use_compile=True,
|
| 157 |
+
use_ipex=True,
|
| 158 |
+
)
|
| 159 |
+
model.train()
|
| 160 |
+
print(f"[P5] Train mode: BitLinear STE path (no invalidate_packed)")
|
| 161 |
+
use_bf16 = bool(args.bf16)
|
| 162 |
+
|
| 163 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 164 |
+
log_f = open(os.path.join(args.output_dir, "log_hyper.jsonl"), "w")
|
| 165 |
+
step = 0
|
| 166 |
+
total_loss = 0.0
|
| 167 |
+
best_loss = float("inf")
|
| 168 |
+
toks = 0
|
| 169 |
+
t0 = time.time()
|
| 170 |
+
cur_seq = initial_seq
|
| 171 |
+
eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
|
| 172 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
|
| 173 |
+
data_iter = iter(loader)
|
| 174 |
+
|
| 175 |
+
print(f"\n{'=' * 65}")
|
| 176 |
+
print(f"Training eff_batch={eff_batch} seq={cur_seq}")
|
| 177 |
+
print(f"{'=' * 65}\n")
|
| 178 |
+
|
| 179 |
+
while step < args.max_steps:
|
| 180 |
+
if grow:
|
| 181 |
+
ns = grow.get_seq_len(step)
|
| 182 |
+
if ns != cur_seq:
|
| 183 |
+
cur_seq = ns
|
| 184 |
+
dataset.set_seq_len(cur_seq)
|
| 185 |
+
eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
|
| 186 |
+
loader = torch.utils.data.DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
|
| 187 |
+
data_iter = iter(loader)
|
| 188 |
+
print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}")
|
| 189 |
+
if unfreezer:
|
| 190 |
+
unfreezer.update(step)
|
| 191 |
+
try:
|
| 192 |
+
batch = next(data_iter)
|
| 193 |
+
except StopIteration:
|
| 194 |
+
data_iter = iter(loader)
|
| 195 |
+
batch = next(data_iter)
|
| 196 |
+
grad_accum_steps = max(1, eff_batch // max(1, args.batch_size))
|
| 197 |
+
loss_val = chimera_turbo.training_step(
|
| 198 |
+
model, batch, optimizer, scheduler, grad_accum_steps=grad_accum_steps, step=step, autocast_dtype=torch.bfloat16 if use_bf16 else None
|
| 199 |
+
)
|
| 200 |
+
cur_lr = optimizer.param_groups[0]["lr"]
|
| 201 |
+
total_loss += loss_val
|
| 202 |
+
toks += batch["input_ids"].numel()
|
| 203 |
+
step += 1
|
| 204 |
+
if step % args.log_every == 0:
|
| 205 |
+
dt = time.time() - t0
|
| 206 |
+
avg = total_loss / args.log_every
|
| 207 |
+
ppl = math.exp(min(avg, 20))
|
| 208 |
+
tps = toks / dt if dt > 0 else 0
|
| 209 |
+
eta = (args.max_steps - step) / (step / dt) / 3600 if dt > 0 else 0
|
| 210 |
+
log_f.write(json.dumps({"step": step, "loss": round(avg, 4), "ppl": round(ppl, 2), "lr": cur_lr, "tok/s": round(tps), "seq_len": cur_seq, "eff_batch": eff_batch}) + "\n")
|
| 211 |
+
log_f.flush()
|
| 212 |
+
print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:,.0f} tok/s | seq {cur_seq} | ETA {eta:.1f}h")
|
| 213 |
+
best_loss = min(best_loss, avg)
|
| 214 |
+
total_loss = 0.0
|
| 215 |
+
toks = 0
|
| 216 |
+
t0 = time.time()
|
| 217 |
+
if step % args.save_every == 0:
|
| 218 |
+
d = save_training_checkpoint(model, config, step, os.path.join(args.output_dir, f"ckpt-{step}"))
|
| 219 |
+
print(f" [SAVE] {d}")
|
| 220 |
+
|
| 221 |
+
d = save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
|
| 222 |
+
log_f.close()
|
| 223 |
+
print(f"\nDONE β best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}")
|
| 224 |
+
return d
|
chimera/training/optimizers.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
from chimera.quantization import BitLinear
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class MeZOOptimizer:
|
| 10 |
+
"""Memory-Efficient Zeroth-Order optimiser (Princeton MeZO)."""
|
| 11 |
+
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
model: nn.Module,
|
| 15 |
+
lr: float = 1e-4,
|
| 16 |
+
eps: float = 1e-3,
|
| 17 |
+
weight_decay: float = 0.0,
|
| 18 |
+
momentum: float = 0.0,
|
| 19 |
+
direction: str = "rademacher",
|
| 20 |
+
):
|
| 21 |
+
self.model = model
|
| 22 |
+
self.lr = float(lr)
|
| 23 |
+
self.eps = float(eps)
|
| 24 |
+
self.wd = float(weight_decay)
|
| 25 |
+
self.momentum = float(momentum)
|
| 26 |
+
if direction not in ("rademacher", "gaussian"):
|
| 27 |
+
raise ValueError(f"unknown direction: {direction!r}")
|
| 28 |
+
self.direction = direction
|
| 29 |
+
|
| 30 |
+
self._bitlinear_modules: list[tuple[str, BitLinear]] = []
|
| 31 |
+
self._dense_params: list[tuple[str, torch.Tensor]] = []
|
| 32 |
+
seen: set[int] = set()
|
| 33 |
+
|
| 34 |
+
for name, module in model.named_modules():
|
| 35 |
+
if isinstance(module, BitLinear):
|
| 36 |
+
self._bitlinear_modules.append((name, module))
|
| 37 |
+
seen.add(id(module.weight))
|
| 38 |
+
if module.bias is not None:
|
| 39 |
+
seen.add(id(module.bias))
|
| 40 |
+
|
| 41 |
+
for name, param in model.named_parameters():
|
| 42 |
+
if param.requires_grad and id(param) not in seen:
|
| 43 |
+
self._dense_params.append((name, param))
|
| 44 |
+
seen.add(id(param))
|
| 45 |
+
|
| 46 |
+
self._momentum: dict[int, torch.Tensor] = {}
|
| 47 |
+
if self.momentum > 0:
|
| 48 |
+
for _, param in self._dense_params:
|
| 49 |
+
self._momentum[id(param)] = torch.zeros_like(param.data)
|
| 50 |
+
for _, module in self._bitlinear_modules:
|
| 51 |
+
self._momentum[id(module.weight)] = torch.zeros_like(module.weight.data)
|
| 52 |
+
|
| 53 |
+
self._step_masks: dict[int, torch.Tensor] = {}
|
| 54 |
+
|
| 55 |
+
def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
|
| 56 |
+
gen = torch.Generator(device="cpu")
|
| 57 |
+
gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
|
| 58 |
+
if self.direction == "gaussian":
|
| 59 |
+
return torch.randn(p.shape, dtype=p.dtype, device="cpu", generator=gen).to(p.device)
|
| 60 |
+
z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
|
| 61 |
+
z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
|
| 62 |
+
return z.to(p.device)
|
| 63 |
+
|
| 64 |
+
def _walk_params(self):
|
| 65 |
+
offset = 0
|
| 66 |
+
for _, module in self._bitlinear_modules:
|
| 67 |
+
yield offset, module.weight.data, self._step_masks.get(id(module.weight))
|
| 68 |
+
offset += 1
|
| 69 |
+
if module.bias is not None:
|
| 70 |
+
yield offset, module.bias.data, None
|
| 71 |
+
offset += 1
|
| 72 |
+
for _, param in self._dense_params:
|
| 73 |
+
yield offset, param.data, None
|
| 74 |
+
offset += 1
|
| 75 |
+
|
| 76 |
+
def _perturb(self, base_seed: int, scale: float) -> None:
|
| 77 |
+
for off, param, mask in self._walk_params():
|
| 78 |
+
z = self._direction(param, base_seed + off * 1_000_003)
|
| 79 |
+
if mask is not None:
|
| 80 |
+
z = z * mask.to(dtype=z.dtype, device=z.device)
|
| 81 |
+
param.add_(z, alpha=scale)
|
| 82 |
+
for _, module in self._bitlinear_modules:
|
| 83 |
+
module.invalidate_packed()
|
| 84 |
+
|
| 85 |
+
def _update(self, base_seed: int, projected_grad: float) -> None:
|
| 86 |
+
for off, param, mask in self._walk_params():
|
| 87 |
+
z = self._direction(param, base_seed + off * 1_000_003)
|
| 88 |
+
if mask is not None:
|
| 89 |
+
z = z * mask.to(dtype=z.dtype, device=z.device)
|
| 90 |
+
buf = self._momentum.get(id(param))
|
| 91 |
+
if buf is not None:
|
| 92 |
+
buf.mul_(self.momentum).add_(z, alpha=projected_grad)
|
| 93 |
+
param.add_(buf, alpha=-self.lr)
|
| 94 |
+
else:
|
| 95 |
+
param.add_(z, alpha=-self.lr * projected_grad)
|
| 96 |
+
if self.wd > 0:
|
| 97 |
+
param.mul_(1 - self.lr * self.wd)
|
| 98 |
+
for _, module in self._bitlinear_modules:
|
| 99 |
+
module.invalidate_packed()
|
| 100 |
+
|
| 101 |
+
@torch.no_grad()
|
| 102 |
+
def step(self, loss_fn, batch) -> float:
|
| 103 |
+
seed = int(torch.randint(0, 2**31, (1,)).item())
|
| 104 |
+
self._step_masks = {id(m.weight): m.ternary_nonzero_mask().detach() for _, m in self._bitlinear_modules}
|
| 105 |
+
self._perturb(seed, +self.eps)
|
| 106 |
+
loss_pos = float(loss_fn(batch).item())
|
| 107 |
+
self._perturb(seed, -2.0 * self.eps)
|
| 108 |
+
loss_neg = float(loss_fn(batch).item())
|
| 109 |
+
self._perturb(seed, +self.eps)
|
| 110 |
+
projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
|
| 111 |
+
self._update(seed, projected_grad)
|
| 112 |
+
self._step_masks = {}
|
| 113 |
+
return 0.5 * (loss_pos + loss_neg)
|
chimera_turbo.py
ADDED
|
@@ -0,0 +1,549 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
chimera_turbo.py β Drop-in CPU acceleration for ch1mera 5.3
|
| 3 |
+
Usage: import chimera_turbo; chimera_turbo.apply(model, optimizer, args)
|
| 4 |
+
|
| 5 |
+
Paradigmes intΓ©grΓ©s:
|
| 6 |
+
P-TURBO-1: STE + AdamW (remplace MeZO β fix convergence + 50x moins de forwards)
|
| 7 |
+
P-TURBO-2: torch.compile regional (2-3x kernel fusion)
|
| 8 |
+
P-TURBO-3: Threading optimal + tcmalloc detection
|
| 9 |
+
P-TURBO-4: IPEX bf16/AMX si disponible
|
| 10 |
+
P-TURBO-5: Cache poids quantifiΓ©s inter micro-batch
|
| 11 |
+
P-TURBO-6: INT8 ternary forward path (VNNI/AMX dispatch)
|
| 12 |
+
P-TURBO-7: Arrow mmap dataset
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import os
|
| 16 |
+
import sys
|
| 17 |
+
import warnings
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
from typing import Optional, Dict, Any, Tuple
|
| 22 |
+
from functools import wraps
|
| 23 |
+
from contextlib import nullcontext
|
| 24 |
+
|
| 25 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 26 |
+
# P-TURBO-3 : Threading + Environment
|
| 27 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 28 |
+
|
| 29 |
+
def detect_cpu_info() -> Dict[str, Any]:
|
| 30 |
+
"""Detect CPU capabilities for optimal configuration."""
|
| 31 |
+
info = {}
|
| 32 |
+
|
| 33 |
+
# Physical cores (not hyperthreads)
|
| 34 |
+
try:
|
| 35 |
+
physical = len(os.sched_getaffinity(0))
|
| 36 |
+
# Heuristic: if thread count is even, likely HT enabled β halve
|
| 37 |
+
import multiprocessing
|
| 38 |
+
logical = multiprocessing.cpu_count()
|
| 39 |
+
info["physical_cores"] = logical // 2 if logical == physical else physical
|
| 40 |
+
info["logical_cores"] = logical
|
| 41 |
+
except Exception:
|
| 42 |
+
import multiprocessing
|
| 43 |
+
info["logical_cores"] = multiprocessing.cpu_count()
|
| 44 |
+
info["physical_cores"] = info["logical_cores"] // 2
|
| 45 |
+
|
| 46 |
+
# CPU capability
|
| 47 |
+
try:
|
| 48 |
+
info["capability"] = torch.backends.cpu.get_cpu_capability()
|
| 49 |
+
except Exception:
|
| 50 |
+
info["capability"] = "unknown"
|
| 51 |
+
|
| 52 |
+
# AMX support (Sapphire Rapids+)
|
| 53 |
+
info["has_amx"] = "amx" in info["capability"].lower() if info["capability"] else False
|
| 54 |
+
info["has_avx512"] = "avx512" in info["capability"].lower() if info["capability"] else False
|
| 55 |
+
info["has_vnni"] = info["has_avx512"] # VNNI comes with AVX-512 Ice Lake+
|
| 56 |
+
|
| 57 |
+
# IPEX available?
|
| 58 |
+
try:
|
| 59 |
+
import intel_extension_for_pytorch
|
| 60 |
+
info["ipex_available"] = True
|
| 61 |
+
info["ipex_version"] = intel_extension_for_pytorch.__version__
|
| 62 |
+
except ImportError:
|
| 63 |
+
info["ipex_available"] = False
|
| 64 |
+
|
| 65 |
+
# tcmalloc loaded?
|
| 66 |
+
info["tcmalloc"] = "tcmalloc" in os.environ.get("LD_PRELOAD", "")
|
| 67 |
+
|
| 68 |
+
return info
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def configure_threading(cpu_info: Dict[str, Any], reserve_for_io: int = 1):
|
| 72 |
+
"""Set optimal threading for CPU training."""
|
| 73 |
+
n_compute = max(1, cpu_info["physical_cores"] - reserve_for_io)
|
| 74 |
+
|
| 75 |
+
torch.set_num_threads(n_compute)
|
| 76 |
+
torch.set_num_interop_threads(min(4, reserve_for_io + 1))
|
| 77 |
+
|
| 78 |
+
os.environ["OMP_NUM_THREADS"] = str(n_compute)
|
| 79 |
+
os.environ["MKL_NUM_THREADS"] = str(n_compute)
|
| 80 |
+
|
| 81 |
+
return n_compute
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 85 |
+
# P-TURBO-1 : STE + AdamW (remplace MeZO)
|
| 86 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 87 |
+
|
| 88 |
+
def create_optimizer(
|
| 89 |
+
model: nn.Module,
|
| 90 |
+
lr: float = 1e-3,
|
| 91 |
+
weight_decay: float = 0.05,
|
| 92 |
+
use_lion: bool = False,
|
| 93 |
+
betas: Tuple[float, float] = (0.9, 0.95),
|
| 94 |
+
) -> torch.optim.Optimizer:
|
| 95 |
+
"""
|
| 96 |
+
Create optimizer for STE-based ternary training (replaces MeZO).
|
| 97 |
+
|
| 98 |
+
Based on BitNet b1.58 Reloaded (2407.09527):
|
| 99 |
+
- lr=1e-3 for <300M params (NOT 1e-2, that's for 3B+)
|
| 100 |
+
- weight_decay=0.05
|
| 101 |
+
- AdamW with Ξ²=(0.9, 0.95)
|
| 102 |
+
|
| 103 |
+
The STE is already in BitLinear β just use a normal optimizer.
|
| 104 |
+
MeZO needed 528 forward passes per step; this needs 1 forward + 1 backward.
|
| 105 |
+
"""
|
| 106 |
+
# Separate weight decay groups (no WD on bias, layernorm, embeddings)
|
| 107 |
+
decay_params = []
|
| 108 |
+
no_decay_params = []
|
| 109 |
+
|
| 110 |
+
for name, param in model.named_parameters():
|
| 111 |
+
if not param.requires_grad:
|
| 112 |
+
continue
|
| 113 |
+
if param.ndim <= 1 or "bias" in name or "norm" in name or "embed" in name:
|
| 114 |
+
no_decay_params.append(param)
|
| 115 |
+
else:
|
| 116 |
+
decay_params.append(param)
|
| 117 |
+
|
| 118 |
+
param_groups = [
|
| 119 |
+
{"params": decay_params, "weight_decay": weight_decay},
|
| 120 |
+
{"params": no_decay_params, "weight_decay": 0.0},
|
| 121 |
+
]
|
| 122 |
+
|
| 123 |
+
if use_lion:
|
| 124 |
+
try:
|
| 125 |
+
from lion_pytorch import Lion
|
| 126 |
+
return Lion(param_groups, lr=lr * 0.3, betas=(0.95, 0.98))
|
| 127 |
+
except ImportError:
|
| 128 |
+
warnings.warn("lion-pytorch not installed, falling back to AdamW")
|
| 129 |
+
|
| 130 |
+
return torch.optim.AdamW(param_groups, lr=lr, betas=betas, fused=False)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def create_scheduler(optimizer, max_steps: int, warmup_steps: int = 500):
|
| 134 |
+
"""Cosine schedule with linear warmup β standard BitNet recipe."""
|
| 135 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 136 |
+
import math
|
| 137 |
+
|
| 138 |
+
def lr_lambda(step):
|
| 139 |
+
if step < warmup_steps:
|
| 140 |
+
return step / max(1, warmup_steps)
|
| 141 |
+
progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
|
| 142 |
+
return max(0.01, 0.5 * (1.0 + math.cos(math.pi * progress)))
|
| 143 |
+
|
| 144 |
+
return LambdaLR(optimizer, lr_lambda)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 148 |
+
# P-TURBO-5 : Quantized Weight Cache
|
| 149 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 150 |
+
|
| 151 |
+
class QuantCacheMixin:
|
| 152 |
+
"""
|
| 153 |
+
Mixin for BitLinear to cache quantized weights during gradient accumulation.
|
| 154 |
+
|
| 155 |
+
Without cache: quantize weights on every micro-batch forward pass
|
| 156 |
+
With cache: quantize once, reuse across accumulation steps
|
| 157 |
+
Invalidate after optimizer.step()
|
| 158 |
+
"""
|
| 159 |
+
_quant_cache: Optional[torch.Tensor] = None
|
| 160 |
+
_cache_valid: bool = False
|
| 161 |
+
|
| 162 |
+
def get_quantized_weight(self):
|
| 163 |
+
"""Override in your BitLinear. Returns quantized weight + scale."""
|
| 164 |
+
raise NotImplementedError
|
| 165 |
+
|
| 166 |
+
def cached_quantized_weight(self):
|
| 167 |
+
if not self._cache_valid or self._quant_cache is None:
|
| 168 |
+
self._quant_cache = self.get_quantized_weight()
|
| 169 |
+
self._cache_valid = True
|
| 170 |
+
return self._quant_cache
|
| 171 |
+
|
| 172 |
+
def invalidate_cache(self):
|
| 173 |
+
self._cache_valid = False
|
| 174 |
+
self._quant_cache = None
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def invalidate_all_caches(model: nn.Module):
|
| 178 |
+
"""Call after optimizer.step() to force re-quantization."""
|
| 179 |
+
for m in model.modules():
|
| 180 |
+
if hasattr(m, "invalidate_cache"):
|
| 181 |
+
m.invalidate_cache()
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 185 |
+
# P-TURBO-6 : INT8 Ternary Forward Path
|
| 186 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 187 |
+
|
| 188 |
+
def ternary_matmul_int8(
|
| 189 |
+
x: torch.Tensor, # [B, S, K] float
|
| 190 |
+
w_ternary: torch.Tensor, # [N, K] float {-1, 0, 1}
|
| 191 |
+
w_scale: torch.Tensor, # scalar
|
| 192 |
+
) -> torch.Tensor:
|
| 193 |
+
"""
|
| 194 |
+
INT8 ternary matmul using torch._int_mm (dispatches to VNNI/AMX).
|
| 195 |
+
|
| 196 |
+
For inference-in-training (eval steps) or forward pass if
|
| 197 |
+
your hardware has VNNI/AMX support.
|
| 198 |
+
|
| 199 |
+
Speedup: 2-4x over float GEMM for ternary weights.
|
| 200 |
+
"""
|
| 201 |
+
B, S, K = x.shape
|
| 202 |
+
x_flat = x.reshape(-1, K) # [B*S, K]
|
| 203 |
+
|
| 204 |
+
# Quantize activations to int8
|
| 205 |
+
x_abs_max = x_flat.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 206 |
+
x_scale = x_abs_max / 127.0
|
| 207 |
+
x_int8 = (x_flat / x_scale).round().clamp(-128, 127).to(torch.int8)
|
| 208 |
+
|
| 209 |
+
# Weights: already ternary, just cast
|
| 210 |
+
w_int8 = w_ternary.to(torch.int8) # {-1, 0, 1} fits in int8
|
| 211 |
+
|
| 212 |
+
# INT8 GEMM β uses hardware VNNI/AMX if available
|
| 213 |
+
# torch._int_mm requires 2D inputs, both int8, K divisible by some alignment
|
| 214 |
+
try:
|
| 215 |
+
out_int32 = torch._int_mm(x_int8, w_int8.t()) # [B*S, N]
|
| 216 |
+
out = out_int32.float() * x_scale * w_scale
|
| 217 |
+
except RuntimeError:
|
| 218 |
+
# Fallback if alignment requirements not met
|
| 219 |
+
out = F.linear(x_flat.float(), w_ternary.float()) * w_scale
|
| 220 |
+
|
| 221 |
+
return out.reshape(B, S, -1)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 225 |
+
# P-TURBO-2 : torch.compile (Regional)
|
| 226 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 227 |
+
|
| 228 |
+
def try_compile_model(model: nn.Module, mode: str = "reduce-overhead") -> nn.Module:
|
| 229 |
+
"""
|
| 230 |
+
Attempt torch.compile with graceful fallback.
|
| 231 |
+
|
| 232 |
+
Uses regional compilation: compiles sub-modules individually
|
| 233 |
+
to work around graph breaks from STE custom autograd functions.
|
| 234 |
+
"""
|
| 235 |
+
if not hasattr(torch, "compile"):
|
| 236 |
+
warnings.warn("torch.compile not available (PyTorch < 2.0)")
|
| 237 |
+
return model
|
| 238 |
+
|
| 239 |
+
# First: diagnose graph breaks
|
| 240 |
+
try:
|
| 241 |
+
import torch._dynamo as dynamo
|
| 242 |
+
|
| 243 |
+
# Try compiling individual attention/MLP blocks instead of full model
|
| 244 |
+
compiled_count = 0
|
| 245 |
+
for name, module in model.named_modules():
|
| 246 |
+
# Skip the top-level model and BitLinear (STE graph breaks)
|
| 247 |
+
if module is model:
|
| 248 |
+
continue
|
| 249 |
+
# Compile "clean" blocks: attention, MLP, norms
|
| 250 |
+
module_type = type(module).__name__.lower()
|
| 251 |
+
if any(k in module_type for k in ["attention", "mlp", "feedforward", "norm"]):
|
| 252 |
+
try:
|
| 253 |
+
compiled = torch.compile(
|
| 254 |
+
module,
|
| 255 |
+
backend="inductor",
|
| 256 |
+
mode=mode,
|
| 257 |
+
fullgraph=False,
|
| 258 |
+
)
|
| 259 |
+
# Replace in parent
|
| 260 |
+
parent_name = ".".join(name.split(".")[:-1])
|
| 261 |
+
child_name = name.split(".")[-1]
|
| 262 |
+
parent = model
|
| 263 |
+
if parent_name:
|
| 264 |
+
for part in parent_name.split("."):
|
| 265 |
+
parent = getattr(parent, part)
|
| 266 |
+
setattr(parent, child_name, compiled)
|
| 267 |
+
compiled_count += 1
|
| 268 |
+
except Exception as e:
|
| 269 |
+
pass # Skip modules that can't be compiled
|
| 270 |
+
|
| 271 |
+
if compiled_count == 0:
|
| 272 |
+
# Fallback: try compiling the whole model with fullgraph=False
|
| 273 |
+
model = torch.compile(model, backend="inductor", mode=mode, fullgraph=False)
|
| 274 |
+
print(f"[TURBO-2] Compiled full model (fullgraph=False)")
|
| 275 |
+
else:
|
| 276 |
+
print(f"[TURBO-2] Compiled {compiled_count} sub-modules (regional)")
|
| 277 |
+
|
| 278 |
+
return model
|
| 279 |
+
|
| 280 |
+
except Exception as e:
|
| 281 |
+
warnings.warn(f"torch.compile failed: {e}. Running in eager mode.")
|
| 282 |
+
return model
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 286 |
+
# P-TURBO-4 : IPEX Integration
|
| 287 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 288 |
+
|
| 289 |
+
def try_ipex_optimize(
|
| 290 |
+
model: nn.Module,
|
| 291 |
+
optimizer: torch.optim.Optimizer,
|
| 292 |
+
cpu_info: Dict[str, Any],
|
| 293 |
+
dtype: Optional[torch.dtype] = None,
|
| 294 |
+
) -> Tuple[nn.Module, torch.optim.Optimizer]:
|
| 295 |
+
"""Apply IPEX optimization if available and beneficial."""
|
| 296 |
+
if not cpu_info.get("ipex_available"):
|
| 297 |
+
print("[TURBO-4] IPEX not available β install: pip install intel-extension-for-pytorch")
|
| 298 |
+
return model, optimizer
|
| 299 |
+
|
| 300 |
+
import intel_extension_for_pytorch as ipex
|
| 301 |
+
|
| 302 |
+
# Choose dtype based on hardware
|
| 303 |
+
if dtype is None:
|
| 304 |
+
if cpu_info["has_amx"]:
|
| 305 |
+
dtype = torch.bfloat16 # AMX tiles β massive bf16 speedup
|
| 306 |
+
print("[TURBO-4] IPEX + AMX bf16 enabled (Sapphire Rapids+)")
|
| 307 |
+
elif cpu_info["has_avx512"]:
|
| 308 |
+
dtype = torch.bfloat16 # Moderate benefit with AVX-512
|
| 309 |
+
print("[TURBO-4] IPEX + AVX-512 bf16 enabled")
|
| 310 |
+
else:
|
| 311 |
+
dtype = torch.float32 # bf16 slower than fp32 without hardware support
|
| 312 |
+
print("[TURBO-4] IPEX fp32 (no bf16 hardware support detected)")
|
| 313 |
+
|
| 314 |
+
model, optimizer = ipex.optimize(
|
| 315 |
+
model,
|
| 316 |
+
optimizer=optimizer,
|
| 317 |
+
dtype=dtype,
|
| 318 |
+
level="O1",
|
| 319 |
+
inplace=True,
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
return model, optimizer
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 326 |
+
# P-TURBO-7 : Arrow mmap Dataset
|
| 327 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 328 |
+
|
| 329 |
+
def prepare_arrow_dataset(
|
| 330 |
+
dataset_name: str = "roneneldan/TinyStories",
|
| 331 |
+
split: str = "train",
|
| 332 |
+
tokenizer=None,
|
| 333 |
+
seq_len: int = 32,
|
| 334 |
+
max_tokens: int = 500_000,
|
| 335 |
+
cache_dir: str = "./cache/arrow",
|
| 336 |
+
num_proc: int = 4,
|
| 337 |
+
):
|
| 338 |
+
"""
|
| 339 |
+
Prepare dataset as Arrow mmap format for zero-copy loading.
|
| 340 |
+
|
| 341 |
+
Replaces streaming + custom .pt cache with HF datasets Arrow backend.
|
| 342 |
+
Benefits: zero-copy to PyTorch, random access, efficient memory via mmap.
|
| 343 |
+
"""
|
| 344 |
+
from datasets import load_dataset, Dataset
|
| 345 |
+
from pathlib import Path
|
| 346 |
+
|
| 347 |
+
cache_path = Path(cache_dir) / f"{dataset_name.replace('/', '_')}_{split}_{max_tokens}_seq{seq_len}"
|
| 348 |
+
|
| 349 |
+
if cache_path.exists():
|
| 350 |
+
print(f"[TURBO-7] Loading cached Arrow dataset from {cache_path}")
|
| 351 |
+
dataset = Dataset.load_from_disk(str(cache_path))
|
| 352 |
+
return dataset.with_format("torch")
|
| 353 |
+
|
| 354 |
+
print(f"[TURBO-7] Preparing Arrow dataset from {dataset_name}...")
|
| 355 |
+
|
| 356 |
+
# Load and tokenize
|
| 357 |
+
raw = load_dataset(dataset_name, split=split, streaming=True)
|
| 358 |
+
|
| 359 |
+
# Collect tokens
|
| 360 |
+
all_tokens = []
|
| 361 |
+
total = 0
|
| 362 |
+
for example in raw:
|
| 363 |
+
text = example.get("text", "")
|
| 364 |
+
if tokenizer is not None:
|
| 365 |
+
tokens = tokenizer.encode(text)
|
| 366 |
+
else:
|
| 367 |
+
# Fallback: assume pre-tokenized or return text
|
| 368 |
+
tokens = text
|
| 369 |
+
if isinstance(tokens, list):
|
| 370 |
+
all_tokens.extend(tokens)
|
| 371 |
+
total += len(tokens)
|
| 372 |
+
if total >= max_tokens:
|
| 373 |
+
break
|
| 374 |
+
|
| 375 |
+
all_tokens = all_tokens[:max_tokens]
|
| 376 |
+
|
| 377 |
+
# Chunk into sequences
|
| 378 |
+
n_seqs = len(all_tokens) // seq_len
|
| 379 |
+
chunks = [all_tokens[i * seq_len:(i + 1) * seq_len] for i in range(n_seqs)]
|
| 380 |
+
|
| 381 |
+
dataset = Dataset.from_dict({
|
| 382 |
+
"input_ids": chunks,
|
| 383 |
+
})
|
| 384 |
+
|
| 385 |
+
# Save as Arrow
|
| 386 |
+
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
| 387 |
+
dataset.save_to_disk(str(cache_path))
|
| 388 |
+
print(f"[TURBO-7] Saved {n_seqs} sequences to {cache_path}")
|
| 389 |
+
|
| 390 |
+
return dataset.with_format("torch")
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 394 |
+
# MAIN: apply() β Point d'entrΓ©e unique
|
| 395 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 396 |
+
|
| 397 |
+
def apply(
|
| 398 |
+
model: nn.Module,
|
| 399 |
+
max_steps: int = 10000,
|
| 400 |
+
lr: float = 1e-3,
|
| 401 |
+
weight_decay: float = 0.05,
|
| 402 |
+
warmup_steps: int = 500,
|
| 403 |
+
use_compile: bool = True,
|
| 404 |
+
use_ipex: bool = True,
|
| 405 |
+
use_lion: bool = False,
|
| 406 |
+
verbose: bool = True,
|
| 407 |
+
) -> Tuple[nn.Module, torch.optim.Optimizer, Any]:
|
| 408 |
+
"""
|
| 409 |
+
Apply all turbo optimizations to ch1mera model.
|
| 410 |
+
|
| 411 |
+
Returns: (model, optimizer, scheduler)
|
| 412 |
+
|
| 413 |
+
Usage in train_hyper.py:
|
| 414 |
+
import chimera_turbo
|
| 415 |
+
model, optimizer, scheduler = chimera_turbo.apply(
|
| 416 |
+
model, max_steps=10000, lr=1e-3
|
| 417 |
+
)
|
| 418 |
+
# Then use normal training loop:
|
| 419 |
+
for step, batch in enumerate(dataloader):
|
| 420 |
+
loss = model(batch).loss
|
| 421 |
+
loss.backward()
|
| 422 |
+
if (step + 1) % grad_accum == 0:
|
| 423 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 424 |
+
optimizer.step()
|
| 425 |
+
scheduler.step()
|
| 426 |
+
optimizer.zero_grad(set_to_none=True)
|
| 427 |
+
chimera_turbo.invalidate_all_caches(model)
|
| 428 |
+
"""
|
| 429 |
+
# ββ Step 1: Detect CPU ββ
|
| 430 |
+
cpu_info = detect_cpu_info()
|
| 431 |
+
|
| 432 |
+
if verbose:
|
| 433 |
+
print("=" * 65)
|
| 434 |
+
print("CHIMERA TURBO β CPU Acceleration Layer")
|
| 435 |
+
print("=" * 65)
|
| 436 |
+
print(f" Physical cores: {cpu_info['physical_cores']}")
|
| 437 |
+
print(f" CPU capability: {cpu_info['capability']}")
|
| 438 |
+
print(f" AMX: {cpu_info['has_amx']} AVX-512: {cpu_info['has_avx512']}")
|
| 439 |
+
print(f" IPEX: {cpu_info['ipex_available']}")
|
| 440 |
+
print(f" tcmalloc: {cpu_info['tcmalloc']}")
|
| 441 |
+
|
| 442 |
+
# ββ Step 2: Threading ββ
|
| 443 |
+
n_threads = configure_threading(cpu_info)
|
| 444 |
+
if verbose:
|
| 445 |
+
print(f"[TURBO-3] Threads: {n_threads} compute + {torch.get_num_interop_threads()} interop")
|
| 446 |
+
|
| 447 |
+
# ββ Step 3: Optimizer (replaces MeZO) ββ
|
| 448 |
+
optimizer = create_optimizer(model, lr=lr, weight_decay=weight_decay, use_lion=use_lion)
|
| 449 |
+
scheduler = create_scheduler(optimizer, max_steps=max_steps, warmup_steps=warmup_steps)
|
| 450 |
+
if verbose:
|
| 451 |
+
opt_name = type(optimizer).__name__
|
| 452 |
+
n_params = sum(p.numel() for g in optimizer.param_groups for p in g["params"])
|
| 453 |
+
print(f"[TURBO-1] {opt_name} (lr={lr}, wd={weight_decay}) β {n_params:,} params")
|
| 454 |
+
print(f" Replaces MeZO: 528 forwards/step β 1 forward + 1 backward")
|
| 455 |
+
|
| 456 |
+
# ββ Step 4: IPEX ββ
|
| 457 |
+
if use_ipex:
|
| 458 |
+
model, optimizer = try_ipex_optimize(model, optimizer, cpu_info)
|
| 459 |
+
|
| 460 |
+
# ββ Step 5: torch.compile ββ
|
| 461 |
+
if use_compile:
|
| 462 |
+
model = try_compile_model(model)
|
| 463 |
+
|
| 464 |
+
if verbose:
|
| 465 |
+
if not cpu_info["tcmalloc"]:
|
| 466 |
+
print()
|
| 467 |
+
print(" β οΈ tcmalloc not detected. For +10-25% speedup:")
|
| 468 |
+
print(" sudo apt install google-perftools")
|
| 469 |
+
print(" LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libtcmalloc.so.4 python train_hyper.py ...")
|
| 470 |
+
print("=" * 65)
|
| 471 |
+
|
| 472 |
+
return model, optimizer, scheduler
|
| 473 |
+
|
| 474 |
+
|
| 475 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 476 |
+
# Training loop helper
|
| 477 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 478 |
+
|
| 479 |
+
def training_step(
|
| 480 |
+
model: nn.Module,
|
| 481 |
+
batch,
|
| 482 |
+
optimizer: torch.optim.Optimizer,
|
| 483 |
+
scheduler,
|
| 484 |
+
grad_accum_steps: int = 1,
|
| 485 |
+
step: int = 0,
|
| 486 |
+
max_grad_norm: float = 1.0,
|
| 487 |
+
autocast_dtype: Optional[torch.dtype] = torch.bfloat16,
|
| 488 |
+
) -> float:
|
| 489 |
+
"""
|
| 490 |
+
Single training step with all turbo optimizations active.
|
| 491 |
+
|
| 492 |
+
Handles: autocast, gradient accumulation, clipping, cache invalidation.
|
| 493 |
+
"""
|
| 494 |
+
is_accum_step = (step + 1) % grad_accum_steps == 0
|
| 495 |
+
|
| 496 |
+
# Forward + backward
|
| 497 |
+
ctx = torch.autocast(device_type="cpu", dtype=autocast_dtype) if autocast_dtype else nullcontext()
|
| 498 |
+
with ctx:
|
| 499 |
+
if isinstance(batch, dict):
|
| 500 |
+
outputs = model(batch["input_ids"], labels=batch.get("labels"))
|
| 501 |
+
elif isinstance(batch, (tuple, list)):
|
| 502 |
+
outputs = model(*batch)
|
| 503 |
+
else:
|
| 504 |
+
outputs = model(batch)
|
| 505 |
+
loss = outputs if isinstance(outputs, torch.Tensor) else outputs.loss
|
| 506 |
+
loss = loss / grad_accum_steps
|
| 507 |
+
|
| 508 |
+
loss.backward()
|
| 509 |
+
|
| 510 |
+
if is_accum_step:
|
| 511 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
|
| 512 |
+
optimizer.step()
|
| 513 |
+
scheduler.step()
|
| 514 |
+
optimizer.zero_grad(set_to_none=True)
|
| 515 |
+
invalidate_all_caches(model)
|
| 516 |
+
|
| 517 |
+
return loss.item() * grad_accum_steps
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 521 |
+
# Diagnostic tool
|
| 522 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 523 |
+
|
| 524 |
+
def profile_model(model: nn.Module, dummy_input: torch.Tensor, steps: int = 5):
|
| 525 |
+
"""Profile forward+backward to find bottlenecks."""
|
| 526 |
+
print("\n[TURBO-DIAG] Profiling...")
|
| 527 |
+
|
| 528 |
+
# Warmup
|
| 529 |
+
for _ in range(2):
|
| 530 |
+
out = model(dummy_input)
|
| 531 |
+
if hasattr(out, "loss"):
|
| 532 |
+
out.loss.backward()
|
| 533 |
+
else:
|
| 534 |
+
out.sum().backward()
|
| 535 |
+
model.zero_grad(set_to_none=True)
|
| 536 |
+
|
| 537 |
+
with torch.profiler.profile(
|
| 538 |
+
activities=[torch.profiler.ProfilerActivity.CPU],
|
| 539 |
+
record_shapes=True,
|
| 540 |
+
with_stack=True,
|
| 541 |
+
) as prof:
|
| 542 |
+
for _ in range(steps):
|
| 543 |
+
out = model(dummy_input)
|
| 544 |
+
loss = out.loss if hasattr(out, "loss") else out.sum()
|
| 545 |
+
loss.backward()
|
| 546 |
+
model.zero_grad(set_to_none=True)
|
| 547 |
+
|
| 548 |
+
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
|
| 549 |
+
return prof
|
config.json
ADDED
|
@@ -0,0 +1,716 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"_name_or_path": "chimera-5.3-hyper",
|
| 3 |
+
"_v": "5.3.0",
|
| 4 |
+
"architectures": ["Chimera51ForCausalLM"],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "configuration_chimera51.Chimera51Config",
|
| 7 |
+
"AutoModelForCausalLM": "modeling_chimera51.Chimera51ForCausalLM"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "chimera51",
|
| 10 |
+
"token_ids": [199999, 200058],
|
| 11 |
+
"hidden_size": 2560,
|
| 12 |
+
"intermediate_size": 6912,
|
| 13 |
+
"num_hidden_layers": 28,
|
| 14 |
+
"num_heads": 40,
|
| 15 |
+
"head_dim": 64,
|
| 16 |
+
"hidden_act": "swiglu",
|
| 17 |
+
"initializer_range": 0.006,
|
| 18 |
+
"rms_norm_eps": 1e-6,
|
| 19 |
+
"rms_norm_before_every_linear": true,
|
| 20 |
+
"vocab_size": 200073,
|
| 21 |
+
"max_position_embeddings": 4194304,
|
| 22 |
+
"tie_word_embeddings": true,
|
| 23 |
+
"torch_dtype": "bfloat16",
|
| 24 |
+
"use_cache": false,
|
| 25 |
+
"transformers_version": "4.58.0",
|
| 26 |
+
|
| 27 |
+
"Β§": {
|
| 28 |
+
"r0": "2412.06464",
|
| 29 |
+
"r1": "2405.04517",
|
| 30 |
+
"r2": "2501.00663",
|
| 31 |
+
"r3": "2604.12946",
|
| 32 |
+
"r4": "2510.04800",
|
| 33 |
+
"r5": "2402.17764",
|
| 34 |
+
"r6": "2505.08823",
|
| 35 |
+
"r7": "2502.11880",
|
| 36 |
+
"r8": "2601.07892",
|
| 37 |
+
"r9": "2602.05269",
|
| 38 |
+
"r10": "2503.01840",
|
| 39 |
+
"r11": "2505.14969",
|
| 40 |
+
"r12": "2411.15100",
|
| 41 |
+
"r13": "2601.04426",
|
| 42 |
+
"r14": "2604.06169",
|
| 43 |
+
"r15": "2602.02369",
|
| 44 |
+
"r16": "2402.04624",
|
| 45 |
+
"r17": "2508.16153",
|
| 46 |
+
"r18": "2310.00533",
|
| 47 |
+
"r19": "2404.02258",
|
| 48 |
+
"r20": "2510.11170",
|
| 49 |
+
"r21": "2408.15664",
|
| 50 |
+
"r22": "2512.12602",
|
| 51 |
+
"r23": "2412.09871",
|
| 52 |
+
"r24": "2501.15570",
|
| 53 |
+
"r25": "2506.12119",
|
| 54 |
+
"r26": "2407.00088",
|
| 55 |
+
"r27": "2410.16144",
|
| 56 |
+
"r28": "2512.06443",
|
| 57 |
+
"r29": "2305.17333",
|
| 58 |
+
"r30": "2509.00031",
|
| 59 |
+
"r31": "2305.17190",
|
| 60 |
+
"r32": "2402.16363",
|
| 61 |
+
"r33": "2502.12444",
|
| 62 |
+
"r34": "2603.13931",
|
| 63 |
+
"r35": "2302.04852",
|
| 64 |
+
"r36": "2305.02299",
|
| 65 |
+
"r37": "2310.00576",
|
| 66 |
+
"r38": "2512.23145",
|
| 67 |
+
"r39": "2406.02913",
|
| 68 |
+
"r40": "2403.03507",
|
| 69 |
+
"r41": "2502.12346",
|
| 70 |
+
"r42": "2406.17660"
|
| 71 |
+
},
|
| 72 |
+
|
| 73 |
+
"quantization": {
|
| 74 |
+
"method": "bitnet",
|
| 75 |
+
"linear_class": "ternary_bitplane",
|
| 76 |
+
"weight_bits": 1.58,
|
| 77 |
+
"weight_values": [-1, 0, 1],
|
| 78 |
+
"weight_scale": "absmean_per_group",
|
| 79 |
+
"group_size": 128,
|
| 80 |
+
"activation_bits": 8,
|
| 81 |
+
"activation_method": "absmax_per_block",
|
| 82 |
+
"activation_block_size": 64,
|
| 83 |
+
"accumulator_dtype": "int32",
|
| 84 |
+
"norm_dtype": "float32",
|
| 85 |
+
"runtime_kernel": "TL2_bitnet_cpp",
|
| 86 |
+
"Β§": ["r5", "r7", "r27"],
|
| 87 |
+
"sherry_mode": {
|
| 88 |
+
"enabled": false,
|
| 89 |
+
"bits": 1.25,
|
| 90 |
+
"Β§": "r8"
|
| 91 |
+
},
|
| 92 |
+
"hgf_correction": {
|
| 93 |
+
"enabled": false,
|
| 94 |
+
"Β§": "r9"
|
| 95 |
+
}
|
| 96 |
+
},
|
| 97 |
+
|
| 98 |
+
"backbone": {
|
| 99 |
+
"type": "hybrid_recurrent_no_attention",
|
| 100 |
+
"layer_pattern": "GD XM GD TM GD XM GD SK",
|
| 101 |
+
"layer_pattern_repeat": 3.5,
|
| 102 |
+
"layer_aliases": {
|
| 103 |
+
"GD": "gated_deltanet",
|
| 104 |
+
"XM": "xlstm_m",
|
| 105 |
+
"TM": "titans_mac",
|
| 106 |
+
"SK": "tsp_span_knot"
|
| 107 |
+
},
|
| 108 |
+
"layer_counts": {"GD": 14, "XM": 7, "TM": 4, "SK": 3},
|
| 109 |
+
"kv_cache": "none",
|
| 110 |
+
"Β§": ["r0", "r1", "r2", "r4"],
|
| 111 |
+
|
| 112 |
+
"moe": {
|
| 113 |
+
"enabled": true,
|
| 114 |
+
"layers": [3, 7, 11, 15, 19, 23, 27],
|
| 115 |
+
"n_routed_experts": 16,
|
| 116 |
+
"n_shared_experts": 1,
|
| 117 |
+
"num_experts_per_tok": 2,
|
| 118 |
+
"moe_intermediate_size": 1728,
|
| 119 |
+
"routing": "noaux_bias",
|
| 120 |
+
"total_params": "350M",
|
| 121 |
+
"active_params_per_tok": "44M",
|
| 122 |
+
"Β§": ["r21", "r25"]
|
| 123 |
+
}
|
| 124 |
+
},
|
| 125 |
+
|
| 126 |
+
"gated_deltanet": {
|
| 127 |
+
"formulation": "S_t = S_{t-1} * (Ξ±_t * (I - Ξ²_t * k_t * k_t^T)) + Ξ²_t * v_t * k_t^T",
|
| 128 |
+
"alpha_gate": "data_dependent_scalar",
|
| 129 |
+
"beta_gate": "data_dependent_scalar",
|
| 130 |
+
"state_size": 64,
|
| 131 |
+
"chunkwise_parallel": true,
|
| 132 |
+
"chunk_size": 256,
|
| 133 |
+
"key_norm": "l2",
|
| 134 |
+
"Β§": "r0"
|
| 135 |
+
},
|
| 136 |
+
|
| 137 |
+
"efla": {
|
| 138 |
+
"enabled": false,
|
| 139 |
+
"target_layers": "SK",
|
| 140 |
+
"Β§": "r22"
|
| 141 |
+
},
|
| 142 |
+
|
| 143 |
+
"xlstm": {
|
| 144 |
+
"variant": "mLSTM",
|
| 145 |
+
"exponential_gating": true,
|
| 146 |
+
"memory_size_per_head": [64, 64],
|
| 147 |
+
"covariance_update": true,
|
| 148 |
+
"normalizer_state": "max_stabilized",
|
| 149 |
+
"Β§": "r1"
|
| 150 |
+
},
|
| 151 |
+
|
| 152 |
+
"titans": {
|
| 153 |
+
"memory_type": "MAC",
|
| 154 |
+
"memory_depth": 2,
|
| 155 |
+
"surprise_metric": "gradient_with_momentum",
|
| 156 |
+
"surprise_formula": "S_t = Ξ·_t Β· S_{t-1} β ΞΈ_t Β· ββ(M_{t-1}; x_t)",
|
| 157 |
+
"forgetting_formula": "M_t = (1 β Ξ±_t) Β· M_{t-1} + S_t",
|
| 158 |
+
"persistent_memory_slots": 64,
|
| 159 |
+
"local_window_size": 1024,
|
| 160 |
+
"Β§": "r2"
|
| 161 |
+
},
|
| 162 |
+
|
| 163 |
+
"looping": {
|
| 164 |
+
"enabled": true,
|
| 165 |
+
"method": "parcae_zoh_stable",
|
| 166 |
+
"prelude": [0, 3],
|
| 167 |
+
"loop": [4, 23],
|
| 168 |
+
"coda": [24, 27],
|
| 169 |
+
"loop_range": [1, 6],
|
| 170 |
+
"loop_default": 2,
|
| 171 |
+
"stability_A": "diag_negative_exp",
|
| 172 |
+
"spectral_radius_bound": 1.0,
|
| 173 |
+
"depth_selection": "stochastic_per_sequence",
|
| 174 |
+
"adaptive_exit_threshold": 0.01,
|
| 175 |
+
"backward_truncation": "half",
|
| 176 |
+
"Β§": "r3"
|
| 177 |
+
},
|
| 178 |
+
|
| 179 |
+
"span_inference": {
|
| 180 |
+
"enabled": true,
|
| 181 |
+
"bank_entries": 524288,
|
| 182 |
+
"bank_avg_tokens": 5,
|
| 183 |
+
"bank_max_tokens": 64,
|
| 184 |
+
"bank_memory_mb": 384,
|
| 185 |
+
"candidate_sources": [64, 48, 48, 32],
|
| 186 |
+
"candidate_source_keys": ["semantic_lsh", "grammar_allowed", "cache_hits", "neural_novel"],
|
| 187 |
+
"candidates_fast": 192,
|
| 188 |
+
"candidates_reason": 512,
|
| 189 |
+
|
| 190 |
+
"tree_verify": {
|
| 191 |
+
"enabled": true,
|
| 192 |
+
"method": "STree",
|
| 193 |
+
"tree_width": 4,
|
| 194 |
+
"tree_depth": 5,
|
| 195 |
+
"hardware_aware": true,
|
| 196 |
+
"Β§": "r11"
|
| 197 |
+
},
|
| 198 |
+
|
| 199 |
+
"certificate_fields": ["span_id_u32", "semantic_delta_8192b", "grammar_delta_128b", "entity_delta_512b", "debt_delta_64b", "boundary_logprob_i16", "interior_risk_u8"],
|
| 200 |
+
"certificate_verify_max_us": 100,
|
| 201 |
+
"adaptive_mask_cache": true,
|
| 202 |
+
"render_queue_target": 256,
|
| 203 |
+
"render_queue_max": 2048,
|
| 204 |
+
"fallback_below_acceptance": 0.5,
|
| 205 |
+
|
| 206 |
+
"scoring_keys": ["semantic", "grammar", "memory", "debt", "boundary"],
|
| 207 |
+
"scoring_weights_fast": [1.0, 0.8, 0.5, 0.7, 0.35],
|
| 208 |
+
"Β§": ["r10", "r12"]
|
| 209 |
+
},
|
| 210 |
+
|
| 211 |
+
"tsp_knot": {
|
| 212 |
+
"energy_terms": {
|
| 213 |
+
"autoregressive": [1.0, "embedding_inner_product"],
|
| 214 |
+
"memory_coherence": [0.3, "hamming_to_semantic_sketch"],
|
| 215 |
+
"binding_fidelity": [0.2, "xor_unbind_popcount"],
|
| 216 |
+
"grammar": [0.4, "fst_transition_cost"],
|
| 217 |
+
"debt": [0.3, "obligation_delta"]
|
| 218 |
+
},
|
| 219 |
+
"relaxation_phase1": "gated_deltanet_update",
|
| 220 |
+
"relaxation_phase2_max_iters": 3,
|
| 221 |
+
"relaxation_phase2_flip_fraction": 0.02,
|
| 222 |
+
"early_exit_delta_e": 1e-4
|
| 223 |
+
},
|
| 224 |
+
|
| 225 |
+
"grammar": {
|
| 226 |
+
"enabled": true,
|
| 227 |
+
"modes": ["plain_text", "dialogue", "markdown", "json", "python", "javascript", "sql", "math_latex", "shell"],
|
| 228 |
+
"representation": "deterministic_fst_plus_weighted",
|
| 229 |
+
"storage_mb": 64,
|
| 230 |
+
"hard_constraints": ["balanced_brackets", "valid_json_in_json_mode", "fence_closure", "string_literal_closure"],
|
| 231 |
+
"soft_constraints": ["sentence_rhythm", "repetition_avoidance", "paragraph_length"],
|
| 232 |
+
"adaptive_mask_cache": true,
|
| 233 |
+
"jit_compilation": true,
|
| 234 |
+
"Β§": ["r12", "r13"]
|
| 235 |
+
},
|
| 236 |
+
|
| 237 |
+
"semantic_memory": {
|
| 238 |
+
"vector_bits": 8192,
|
| 239 |
+
"vector_storage": "uint64_x128",
|
| 240 |
+
"capacity": 200000,
|
| 241 |
+
"relations": 500000,
|
| 242 |
+
"memory_mb": 320,
|
| 243 |
+
"ops": ["xor_bind", "xor_unbind", "majority_bundle", "popcnt_hamming", "rotate_permute"],
|
| 244 |
+
"lsh_tables": 64,
|
| 245 |
+
"lsh_bits_per_table": 14,
|
| 246 |
+
"hot_cache_entries": 16384,
|
| 247 |
+
"read_at_every_knot": true,
|
| 248 |
+
"write_policy": "surprise_threshold_plus_contrastive_validation",
|
| 249 |
+
"forgetting_policy": "fixed_pool_exponential_decay",
|
| 250 |
+
"pool_size_fixed": true,
|
| 251 |
+
"Β§": ["r15", "r16"]
|
| 252 |
+
},
|
| 253 |
+
|
| 254 |
+
"entropy_valve": {
|
| 255 |
+
"enabled": true,
|
| 256 |
+
"metrics": ["span_energy_margin", "grammar_branching", "sketch_instability", "entity_conflicts", "debt_pressure", "queue_depth"],
|
| 257 |
+
"threshold_bits": 2.0,
|
| 258 |
+
"type": "inference_time_compute_allocation",
|
| 259 |
+
"loop_depth_router": {
|
| 260 |
+
"method": "mod_causal_predictor",
|
| 261 |
+
"accuracy_target": 0.97,
|
| 262 |
+
"Β§": "r19"
|
| 263 |
+
},
|
| 264 |
+
"levels": {
|
| 265 |
+
"low": {"loops": 1, "min_span": 8, "audit": 0.125},
|
| 266 |
+
"medium": {"loops": 2, "min_span": 4, "audit": 0.5},
|
| 267 |
+
"high": {"loops": 4, "min_span": 1, "audit": 1.0}
|
| 268 |
+
},
|
| 269 |
+
"Β§": "r20"
|
| 270 |
+
},
|
| 271 |
+
|
| 272 |
+
"debt_ledger": {
|
| 273 |
+
"enabled": true,
|
| 274 |
+
"obligations": ["close_bracket", "close_string", "close_fence", "resolve_pronoun", "finish_list", "maintain_tense", "complete_sentence", "end_json_object"],
|
| 275 |
+
"max_outstanding": 64,
|
| 276 |
+
"pressure_weight": 0.3
|
| 277 |
+
},
|
| 278 |
+
|
| 279 |
+
"self_evolution": {
|
| 280 |
+
"num_mechanisms": 7,
|
| 281 |
+
|
| 282 |
+
"tier1": {
|
| 283 |
+
"ttt": {
|
| 284 |
+
"enabled": true,
|
| 285 |
+
"target_layers": [13, 23],
|
| 286 |
+
"target_param": "mlp_w_down",
|
| 287 |
+
"inner_lr": 0.0003,
|
| 288 |
+
"inner_optimizer": "sgd_momentum",
|
| 289 |
+
"momentum": 0.9,
|
| 290 |
+
"objective": "next_token_prediction",
|
| 291 |
+
"chunk_size": 1024,
|
| 292 |
+
"update_scope": "full_w_down",
|
| 293 |
+
"reset_decay": 0.95,
|
| 294 |
+
"persistence": "per_user_session_file",
|
| 295 |
+
"Β§": "r14"
|
| 296 |
+
},
|
| 297 |
+
"memory_growth": {
|
| 298 |
+
"enabled": true,
|
| 299 |
+
"surprise_threshold": "titans_gradient_magnitude_above_2_sigma",
|
| 300 |
+
"contrastive_validation": true,
|
| 301 |
+
"user_explicit_store": true,
|
| 302 |
+
"max_per_session": 1000,
|
| 303 |
+
"pool_fixed": true,
|
| 304 |
+
"forgetting": "random_drop_k_append_k",
|
| 305 |
+
"persistent": true,
|
| 306 |
+
"pruning": "low_retrieval_weight_eviction",
|
| 307 |
+
"Β§": ["r15", "r16"]
|
| 308 |
+
}
|
| 309 |
+
},
|
| 310 |
+
|
| 311 |
+
"tier2": {
|
| 312 |
+
"meta_guidelines": {
|
| 313 |
+
"enabled": true,
|
| 314 |
+
"max": 256,
|
| 315 |
+
"format": "8192bit_xor",
|
| 316 |
+
"trigger": "contrastive_eval_negative",
|
| 317 |
+
"Β§": "r15"
|
| 318 |
+
},
|
| 319 |
+
"episodic_cases": {
|
| 320 |
+
"enabled": true,
|
| 321 |
+
"retrieval": "soft_q_learning",
|
| 322 |
+
"max_cases": 4096,
|
| 323 |
+
"case_bytes": 2048,
|
| 324 |
+
"weight_update": "outcome_based_ema",
|
| 325 |
+
"Β§": "r17"
|
| 326 |
+
},
|
| 327 |
+
"self_feedback": {
|
| 328 |
+
"enabled": true,
|
| 329 |
+
"confidence_threshold": 0.6,
|
| 330 |
+
"max_refinement_rounds": 1,
|
| 331 |
+
"Β§": "r18"
|
| 332 |
+
}
|
| 333 |
+
},
|
| 334 |
+
|
| 335 |
+
"tier3": {
|
| 336 |
+
"span_bank_expansion": {
|
| 337 |
+
"enabled": true,
|
| 338 |
+
"min_span_len": 4,
|
| 339 |
+
"max_new_per_session": 256,
|
| 340 |
+
"acceptance": "cert_valid AND no_correction AND used_3plus",
|
| 341 |
+
"persistent": true,
|
| 342 |
+
"compression": "merge_similar_periodic"
|
| 343 |
+
},
|
| 344 |
+
"loop_depth_learning": {
|
| 345 |
+
"enabled": true,
|
| 346 |
+
"classifier": "int8_2layer_mlp",
|
| 347 |
+
"classifier_params": 500000,
|
| 348 |
+
"signal": "parcae_convergence_speed",
|
| 349 |
+
"persistent": true
|
| 350 |
+
}
|
| 351 |
+
},
|
| 352 |
+
|
| 353 |
+
"safety": {
|
| 354 |
+
"max_growth_mb": {"memory": 512, "span_bank": 128, "episodic": 8, "guidelines": 2},
|
| 355 |
+
"rollback_on_degradation": true,
|
| 356 |
+
"monitor": "certificate_failure_rate_and_rollback_rate",
|
| 357 |
+
"freeze_threshold": 0.05,
|
| 358 |
+
"user_reset": true,
|
| 359 |
+
"state_file": "chimera51_evolution.state"
|
| 360 |
+
}
|
| 361 |
+
},
|
| 362 |
+
|
| 363 |
+
"braid_state": {
|
| 364 |
+
"continuous_hidden": [2560, "float32"],
|
| 365 |
+
"fast_hidden": [2560, "int8"],
|
| 366 |
+
"semantic_sketch": [8192, "uint64_x128"],
|
| 367 |
+
"entity_table": {"slots": 256, "slot_bits": 512, "binding": "xor_role_filler"},
|
| 368 |
+
"grammar_stack": {"slots": 64, "width_bits": 128},
|
| 369 |
+
"debt_ledger_slots": 64,
|
| 370 |
+
"per_stream_mb": 30,
|
| 371 |
+
"kv_growth_per_token": 0
|
| 372 |
+
},
|
| 373 |
+
|
| 374 |
+
"modes": {
|
| 375 |
+
"fast": {"tps": 200, "neural_hz": 40, "span_avg": 5, "loops": 1, "audit": 0.125},
|
| 376 |
+
"balanced": {"tps": 120, "neural_hz": 30, "span_avg": 4, "loops": 2, "audit": 0.5},
|
| 377 |
+
"reasoning": {"tps": 40, "neural_hz": 20, "span_avg": 2, "loops": 4, "audit": 1.0}
|
| 378 |
+
},
|
| 379 |
+
|
| 380 |
+
"generation": {
|
| 381 |
+
"temperature": 0.7,
|
| 382 |
+
"top_p": 0.92,
|
| 383 |
+
"repetition_penalty": 1.08,
|
| 384 |
+
"max_new_tokens": 4096,
|
| 385 |
+
"do_sample": true,
|
| 386 |
+
"stream": true
|
| 387 |
+
},
|
| 388 |
+
|
| 389 |
+
"training": {
|
| 390 |
+
"phases": [
|
| 391 |
+
{
|
| 392 |
+
"name": "pretrain",
|
| 393 |
+
"tokens": "2T",
|
| 394 |
+
"data": ["FineWeb-Edu", "SlimPajama", "StarCoder-data", "multilingual-CC"],
|
| 395 |
+
"seq_len": 4096,
|
| 396 |
+
"batch_tokens": "4M",
|
| 397 |
+
"optimizer": "AdamW",
|
| 398 |
+
"lr": 3e-4,
|
| 399 |
+
"schedule": "cosine_warmup",
|
| 400 |
+
"warmup_steps": 2000,
|
| 401 |
+
"weight_decay": 0.1,
|
| 402 |
+
"grad_clip": 1.0,
|
| 403 |
+
"ternary": "native_qat_ste",
|
| 404 |
+
"Β§": ["r5", "r6"]
|
| 405 |
+
},
|
| 406 |
+
{
|
| 407 |
+
"name": "ctx_extend",
|
| 408 |
+
"stages": [
|
| 409 |
+
[4096, "main"],
|
| 410 |
+
[16384, 10000, 1e-5],
|
| 411 |
+
[65536, 5000, 5e-6],
|
| 412 |
+
[262144, 2000, 2e-6]
|
| 413 |
+
]
|
| 414 |
+
},
|
| 415 |
+
{
|
| 416 |
+
"name": "sft",
|
| 417 |
+
"data": ["UltraChat-200k", "ShareGPT-cleaned"],
|
| 418 |
+
"epochs": 3,
|
| 419 |
+
"lr": 2e-5
|
| 420 |
+
},
|
| 421 |
+
{
|
| 422 |
+
"name": "dpo",
|
| 423 |
+
"data": "UltraFeedback-binarized",
|
| 424 |
+
"epochs": 1,
|
| 425 |
+
"lr": 5e-7,
|
| 426 |
+
"beta": 0.1
|
| 427 |
+
}
|
| 428 |
+
],
|
| 429 |
+
"distillation_init": {
|
| 430 |
+
"enabled": false,
|
| 431 |
+
"method": "ARWKV_style",
|
| 432 |
+
"teacher": "Qwen-2.5-7B",
|
| 433 |
+
"tokens": "1B",
|
| 434 |
+
"Β§": "r24"
|
| 435 |
+
}
|
| 436 |
+
},
|
| 437 |
+
|
| 438 |
+
"hyper_training": {
|
| 439 |
+
"_note": "v5.3.0 β Seven stacked paradigms for 10,000+ tok/s CPU training. Each paradigm is independently toggleable. Combined theoretical multiplier: 57-260Γ over baseline MeZO.",
|
| 440 |
+
|
| 441 |
+
"paradigms": {
|
| 442 |
+
"P1_growlength": {
|
| 443 |
+
"status": "IMPLEMENTED v5.3",
|
| 444 |
+
"description": "GrowLength curriculum: train with progressively longer sequences. Short seqs β massive effective batch β way more tok/s in early training where signal is strongest.",
|
| 445 |
+
"speedup": "4-8Γ",
|
| 446 |
+
"default_stages": [[0.125, 0.20], [0.25, 0.25], [0.5, 0.25], [1.0, 0.30]],
|
| 447 |
+
"Β§": "r37"
|
| 448 |
+
},
|
| 449 |
+
"P2_reservoir_freezing": {
|
| 450 |
+
"status": "IMPLEMENTED v5.3",
|
| 451 |
+
"description": "GRC-inspired reservoir freezing: freeze ~50% of recurrent gate matrices (a_proj, b_proj, fgate, alpha_proj) as random ternary with unit spectral radius. No gradient computation for frozen params.",
|
| 452 |
+
"speedup": "1.5-2Γ",
|
| 453 |
+
"targets": ["GatedDeltaNet.a_proj", "GatedDeltaNet.b_proj", "mLSTM.fgate", "TitansMAC.alpha_proj"],
|
| 454 |
+
"Β§": "r38"
|
| 455 |
+
},
|
| 456 |
+
"P3_sparse_mezo": {
|
| 457 |
+
"status": "IMPLEMENTED v5.3",
|
| 458 |
+
"description": "Sparse MeZO: perturb only top-K% most sensitive parameters by weight magnitude. At 1% sparsity on 35M model β 350K params perturbed β 100Γ better ZO signal-to-noise per forward pass.",
|
| 459 |
+
"speedup": "3-5Γ",
|
| 460 |
+
"default_sparsity": 0.01,
|
| 461 |
+
"mask_refresh_interval": "every 10% of training",
|
| 462 |
+
"Β§": "r39"
|
| 463 |
+
},
|
| 464 |
+
"P4_blockwise_pipeline": {
|
| 465 |
+
"status": "IMPLEMENTED v5.3",
|
| 466 |
+
"description": "Blockwise pipeline parallelism via torch.compile inductor backend. Overlaps computation of layer groups across CPU core groups.",
|
| 467 |
+
"speedup": "1.3-2Γ",
|
| 468 |
+
"requires": "torch.compile"
|
| 469 |
+
},
|
| 470 |
+
"P5_fused_ternary_cache": {
|
| 471 |
+
"status": "IMPLEMENTED v5.3",
|
| 472 |
+
"description": "Pre-materialise all BitLinear packed+dense weight caches once per step. Both MeZO forward passes reuse same buffers β eliminates redundant quantizeβpackβunpack cycles.",
|
| 473 |
+
"speedup": "1.3Γ"
|
| 474 |
+
},
|
| 475 |
+
"P6_aggressive_token_packing": {
|
| 476 |
+
"status": "IMPLEMENTED v5.3",
|
| 477 |
+
"description": "Zero-padding token packing. Documents concatenated back-to-back with EOS separators, no wasted compute on padding tokens.",
|
| 478 |
+
"speedup": "1.1-1.3Γ"
|
| 479 |
+
},
|
| 480 |
+
"P7_progressive_layer_unfreeze": {
|
| 481 |
+
"status": "IMPLEMENTED v5.3",
|
| 482 |
+
"description": "Progressive layer unfreezing from output to input. Start with only top ~25% of layers trainable. Deeper layers frozen = fast forward + no gradient storage. Gradually unfreeze as training progresses.",
|
| 483 |
+
"speedup": "1.5-2Γ"
|
| 484 |
+
}
|
| 485 |
+
},
|
| 486 |
+
|
| 487 |
+
"combined_estimate": {
|
| 488 |
+
"formula": "P1(6Γ) Γ P2(1.7Γ) Γ P3(4Γ) Γ P5(1.3Γ) Γ P7(1.7Γ)",
|
| 489 |
+
"theoretical_multiplier": "57-260Γ",
|
| 490 |
+
"baseline_tiny_35M": "50-200 tok/s",
|
| 491 |
+
"target_tiny_35M": "3,000-15,000+ tok/s",
|
| 492 |
+
"note": "Actual speedup depends on CPU architecture, core count, cache hierarchy, and AMX/AVX-512 availability."
|
| 493 |
+
},
|
| 494 |
+
|
| 495 |
+
"Β§_hyper": ["r37", "r38", "r39", "r40", "r41", "r42", "r29", "r33"]
|
| 496 |
+
},
|
| 497 |
+
|
| 498 |
+
"byte_level": {
|
| 499 |
+
"enabled": false,
|
| 500 |
+
"encoder_params": "50M",
|
| 501 |
+
"encoder_depth": 8,
|
| 502 |
+
"patching": "entropy_threshold",
|
| 503 |
+
"decoder_params": "50M",
|
| 504 |
+
"Β§": "r23"
|
| 505 |
+
},
|
| 506 |
+
|
| 507 |
+
"memory_budget_mb": {
|
| 508 |
+
"_keys": ["ternary_weights", "moe_experts", "span_bank", "grammar", "semantic_mem", "episodic", "guidelines", "braid", "activations", "render_queue", "evolution", "runtime_os"],
|
| 509 |
+
"_vals": [410, 66, 384, 64, 320, 8, 2, 30, 80, 32, 128, 1000],
|
| 510 |
+
"total": 2524,
|
| 511 |
+
"headroom_8gb": 4876,
|
| 512 |
+
"growth_ceiling": 650,
|
| 513 |
+
"max_with_growth": 3174
|
| 514 |
+
},
|
| 515 |
+
|
| 516 |
+
"deployment": {
|
| 517 |
+
"batch_size": 1,
|
| 518 |
+
"max_streams": 16,
|
| 519 |
+
"per_stream_mb": 30,
|
| 520 |
+
"shared": ["weights", "span_bank", "grammar"],
|
| 521 |
+
"mmap": ["weights", "span_bank"],
|
| 522 |
+
"cold_start_s": 2.5,
|
| 523 |
+
"watchdog_tick_ms": 20,
|
| 524 |
+
"watchdog_max_overruns": 8,
|
| 525 |
+
"deterministic": true,
|
| 526 |
+
"seed_controls_all": true,
|
| 527 |
+
"platforms": ["x86_64_avx2", "aarch64_neon", "wasm_simd128", "apple_silicon_amx"]
|
| 528 |
+
},
|
| 529 |
+
|
| 530 |
+
"diagnostics": {
|
| 531 |
+
"telemetry": true,
|
| 532 |
+
"report_interval_tokens": 256,
|
| 533 |
+
"metrics": [
|
| 534 |
+
"surface_tps", "neural_knot_tps", "mean_span_length",
|
| 535 |
+
"span_acceptance_rate", "certificate_failure_rate",
|
| 536 |
+
"rollback_count", "queue_depth", "loop_count_mean",
|
| 537 |
+
"memory_mb", "evolution_events", "grammar_violations_prevented",
|
| 538 |
+
"contrastive_eval_ratio", "self_refinement_trigger_rate",
|
| 539 |
+
"episodic_case_hit_rate", "moe_expert_load_balance",
|
| 540 |
+
"gd_alpha_mean", "gd_beta_mean", "ttt_loss_delta"
|
| 541 |
+
],
|
| 542 |
+
"thresholds": {
|
| 543 |
+
"min_span_accept": 0.70,
|
| 544 |
+
"max_cert_fail": 0.05,
|
| 545 |
+
"max_rollback": 0.02,
|
| 546 |
+
"min_contrastive_benefit": 0.0,
|
| 547 |
+
"max_moe_imbalance": 0.15
|
| 548 |
+
}
|
| 549 |
+
},
|
| 550 |
+
|
| 551 |
+
"context_tiers": [
|
| 552 |
+
{"name": "recent_ring", "tokens": 4096, "mb": 16},
|
| 553 |
+
{"name": "braid_state", "mb": 30},
|
| 554 |
+
{"name": "semantic_memory", "mb": 320},
|
| 555 |
+
{"name": "ttt_compressed", "mb": 24},
|
| 556 |
+
{"name": "span_trace", "entries": 32768, "mb": 32},
|
| 557 |
+
{"name": "episodic_cases", "entries": 4096, "mb": 8}
|
| 558 |
+
],
|
| 559 |
+
|
| 560 |
+
"multimodal": {
|
| 561 |
+
"enabled": true,
|
| 562 |
+
"modalities": ["text", "image", "audio"],
|
| 563 |
+
"vision": {"type": "gated_deltanet_tiny", "depth": 12, "hidden": 384, "patch": 16, "out": 2560, "quant": "ternary"},
|
| 564 |
+
"audio": {"type": "gated_deltanet_audio_tiny", "depth": 6, "hidden": 256, "out": 2560, "quant": "ternary"}
|
| 565 |
+
},
|
| 566 |
+
|
| 567 |
+
"safety": {
|
| 568 |
+
"format_guards": ["json_strict", "code_fence_closure", "markdown_table_guard"],
|
| 569 |
+
"memory_limit_enforced": true,
|
| 570 |
+
"crash_only_allocator": true,
|
| 571 |
+
"user_facts_override_weak_memory": true,
|
| 572 |
+
"state_uncertainty_when_unsure": true
|
| 573 |
+
},
|
| 574 |
+
|
| 575 |
+
"files": {
|
| 576 |
+
"weights": "chimera51.b158",
|
| 577 |
+
"moe": "chimera51_experts.b158",
|
| 578 |
+
"spans": "chimera51_spans.sfpack",
|
| 579 |
+
"grammar": "chimera51_grammar.fstpack",
|
| 580 |
+
"memory_seed": "chimera51_memory.seedpack",
|
| 581 |
+
"tokenizer": "chimera51_tokenizer.model",
|
| 582 |
+
"evolution": "chimera51_evolution.state"
|
| 583 |
+
},
|
| 584 |
+
|
| 585 |
+
"params": {
|
| 586 |
+
"base": "2.3B",
|
| 587 |
+
"moe_total": "350M",
|
| 588 |
+
"physical": "2.65B",
|
| 589 |
+
"effective_2loops": "4.2B",
|
| 590 |
+
"effective_6loops": "9.5B",
|
| 591 |
+
"active_per_token": "2.39B",
|
| 592 |
+
"weight_mb": 476,
|
| 593 |
+
"total_mb": 2524
|
| 594 |
+
},
|
| 595 |
+
|
| 596 |
+
"P3_ternary_compute": {
|
| 597 |
+
"_note": "v5.1.2 β Honest section. Documents ONLY what is implemented and measured.",
|
| 598 |
+
|
| 599 |
+
"thesis": "Ternary weights {-1,0,1} enable 16Γ memory reduction via 2-bit packed storage. On CPU, training speed is dominated by MKL BLAS β raw ternary matmul is not faster than FP32 at small-to-medium sizes. The real wins are: (1) 16Γ less RAM enabling larger models on limited hardware, (2) 16Γ less memory bandwidth for large models where DRAM is the bottleneck, (3) MeZO eliminates the backward pass entirely (2Γ forward only). Inference post-training uses LUT-based kernels (T-MAC, bitnet.cpp) for true speedup. v5.3 adds 7 stacked paradigms that target the training loop itself for multiplicative speedup.",
|
| 600 |
+
|
| 601 |
+
"implemented_optimizations": {
|
| 602 |
+
"mezo_optimizer": {
|
| 603 |
+
"status": "IMPLEMENTED",
|
| 604 |
+
"description": "Memory-Efficient Zeroth-Order optimizer β eliminates backward pass entirely. 2 forward passes per step.",
|
| 605 |
+
"benefit": "Memory = 2Γ model size (no activations, no gradients, no optimizer states). Ideal for CPU with complex recurrences.",
|
| 606 |
+
"limitation": "Requires ~32Γ more steps to converge than AdamW. Best for fine-tuning, not pretraining from scratch.",
|
| 607 |
+
"Β§": "r29"
|
| 608 |
+
},
|
| 609 |
+
"sparse_mezo_v53": {
|
| 610 |
+
"status": "IMPLEMENTED v5.3",
|
| 611 |
+
"description": "Sparse MeZO: perturb only top-K% params by weight magnitude. Reduces ZO variance by 100Γ at 1% sparsity.",
|
| 612 |
+
"benefit": "3-5Γ faster convergence per wall-clock second. Same memory as standard MeZO.",
|
| 613 |
+
"Β§": "r39"
|
| 614 |
+
},
|
| 615 |
+
"growlength_v53": {
|
| 616 |
+
"status": "IMPLEMENTED v5.3",
|
| 617 |
+
"description": "Progressive sequence length curriculum. Start at seq=16, grow to target.",
|
| 618 |
+
"benefit": "4-8Γ more tokens/s in early training. Larger effective batch at short lengths.",
|
| 619 |
+
"Β§": "r37"
|
| 620 |
+
},
|
| 621 |
+
"reservoir_freezing_v53": {
|
| 622 |
+
"status": "IMPLEMENTED v5.3",
|
| 623 |
+
"description": "GRC-inspired: freeze 50% of recurrent gate matrices as random ternary reservoirs.",
|
| 624 |
+
"benefit": "1.5-2Γ fewer FLOPs in recurrent layers. No convergence degradation for gate matrices.",
|
| 625 |
+
"Β§": "r38"
|
| 626 |
+
},
|
| 627 |
+
"bf16_autocast": {
|
| 628 |
+
"status": "IMPLEMENTED",
|
| 629 |
+
"description": "BFloat16 automatic mixed precision on CPU via torch.autocast('cpu', dtype=torch.bfloat16).",
|
| 630 |
+
"benefit": "2-4Γ faster matmuls on Intel Sapphire Rapids+ (AMX) or Ice Lake+ (AVX-512-BF16).",
|
| 631 |
+
"limitation": "Forward-pass only. Gradients remain FP32."
|
| 632 |
+
},
|
| 633 |
+
"torch_compile": {
|
| 634 |
+
"status": "IMPLEMENTED",
|
| 635 |
+
"description": "torch.compile with Inductor backend for CPU. Fuses ops, reduces Python overhead.",
|
| 636 |
+
"benefit": "1.3-2Γ overall training throughput.",
|
| 637 |
+
"limitation": "First iteration is slow (compilation). Dynamic shapes supported."
|
| 638 |
+
},
|
| 639 |
+
"parallel_mlstm": {
|
| 640 |
+
"status": "IMPLEMENTED",
|
| 641 |
+
"description": "Replaced O(T) Python loop with parallel log-space cumulative gate computation + batched QKV attention.",
|
| 642 |
+
"benefit": "~10-50Γ faster for mLSTM layers on CPU (seq_len β₯ 64).",
|
| 643 |
+
"Β§": "r1"
|
| 644 |
+
},
|
| 645 |
+
"parallel_titans_mac": {
|
| 646 |
+
"status": "IMPLEMENTED",
|
| 647 |
+
"description": "Replaced O(T) Python loop with causal decay attention + vectorized contribution computation.",
|
| 648 |
+
"benefit": "~5-20Γ faster for Titans MAC layers on CPU.",
|
| 649 |
+
"Β§": "r2"
|
| 650 |
+
},
|
| 651 |
+
"sort_based_moe": {
|
| 652 |
+
"status": "IMPLEMENTED",
|
| 653 |
+
"description": "Sort tokens by expert ID β process contiguous blocks β scatter_add back.",
|
| 654 |
+
"benefit": "Better cache locality than random-access per-expert dispatch.",
|
| 655 |
+
"Β§": "r21"
|
| 656 |
+
},
|
| 657 |
+
"gradient_checkpointing": {
|
| 658 |
+
"status": "IMPLEMENTED",
|
| 659 |
+
"description": "Per-block activation checkpointing for AdamW mode.",
|
| 660 |
+
"benefit": "30-60% memory reduction, enabling larger batches."
|
| 661 |
+
},
|
| 662 |
+
"cpu_thread_tuning": {
|
| 663 |
+
"status": "IMPLEMENTED",
|
| 664 |
+
"description": "OMP_NUM_THREADS, KMP_AFFINITY=compact, KMP_BLOCKTIME=1.",
|
| 665 |
+
"benefit": "10-30% throughput improvement from optimal thread placement."
|
| 666 |
+
},
|
| 667 |
+
"ipex_integration": {
|
| 668 |
+
"status": "IMPLEMENTED (optional)",
|
| 669 |
+
"description": "Auto-detected Intel Extension for PyTorch. ipex.optimize() with BF16 + AMX kernel selection.",
|
| 670 |
+
"benefit": "Additional 30-50% on Intel CPUs."
|
| 671 |
+
},
|
| 672 |
+
"ternary_qat_ste": {
|
| 673 |
+
"status": "IMPLEMENTED",
|
| 674 |
+
"description": "BitNet 1.58 quantization-aware training with STE.",
|
| 675 |
+
"Β§": ["r5", "r7"]
|
| 676 |
+
},
|
| 677 |
+
"two_bit_packed_weights": {
|
| 678 |
+
"status": "IMPLEMENTED v5.1.2",
|
| 679 |
+
"description": "Ternary weights packed as 2-bit uint8. Custom C++ kernel with OpenMP for unpack.",
|
| 680 |
+
"benefit": "16Γ less storage vs FP32."
|
| 681 |
+
},
|
| 682 |
+
"fused_ternary_cache_v53": {
|
| 683 |
+
"status": "IMPLEMENTED v5.3",
|
| 684 |
+
"description": "Pre-materialise all BitLinear packed+dense caches once per step. Both MeZO forwards reuse same buffers.",
|
| 685 |
+
"benefit": "1.3Γ by eliminating redundant quantize-pack-unpack cycles."
|
| 686 |
+
},
|
| 687 |
+
"progressive_unfreeze_v53": {
|
| 688 |
+
"status": "IMPLEMENTED v5.3",
|
| 689 |
+
"description": "Train only top 25% of layers initially; unfreeze downward as training advances.",
|
| 690 |
+
"benefit": "1.5-2Γ fewer params in gradient path during early training."
|
| 691 |
+
},
|
| 692 |
+
"token_packing_v53": {
|
| 693 |
+
"status": "IMPLEMENTED v5.3",
|
| 694 |
+
"description": "Zero-padding token packing. Documents packed back-to-back with EOS separators.",
|
| 695 |
+
"benefit": "1.1-1.3Γ by eliminating wasted compute on padding."
|
| 696 |
+
}
|
| 697 |
+
},
|
| 698 |
+
|
| 699 |
+
"not_implemented": {
|
| 700 |
+
"elut_training": "ELUT/T-MAC kernels apply to INFERENCE only.",
|
| 701 |
+
"mixture_of_depths": "MoD requires specific router architecture.",
|
| 702 |
+
"sparse_backprop": "SparseProp requires β₯90% weight sparsity."
|
| 703 |
+
},
|
| 704 |
+
|
| 705 |
+
"realistic_performance": {
|
| 706 |
+
"cpu_training_tiny_35M_baseline": {"hardware": "i7-14700T", "throughput": "~50-200 tok/s", "note": "Standard MeZO+BF16"},
|
| 707 |
+
"cpu_training_tiny_35M_hyper": {"hardware": "i7-14700T", "throughput": "~3,000-15,000 tok/s", "note": "All 7 paradigms ON"},
|
| 708 |
+
"cpu_training_small_150M_baseline": {"hardware": "i7-14700T", "throughput": "~10-50 tok/s", "note": "Standard MeZO+BF16"},
|
| 709 |
+
"cpu_training_small_150M_hyper": {"hardware": "i7-14700T", "throughput": "~500-3,000 tok/s", "note": "All 7 paradigms ON"},
|
| 710 |
+
"cpu_inference_ternary": {"note": "Post-training with bitnet.cpp/T-MAC: 30-127 tok/s for 700M-3B models"},
|
| 711 |
+
"gpu_training_comparison": "GPU (A100) is 50-150Γ faster than CPU. HYPER paradigms aim to close this gap for small models."
|
| 712 |
+
},
|
| 713 |
+
|
| 714 |
+
"Β§_paradigm": ["r26", "r27", "r28", "r29", "r30", "r31", "r32", "r33", "r5", "r34", "r7", "r19", "r37", "r38", "r39", "r40", "r41", "r42"]
|
| 715 |
+
}
|
| 716 |
+
}
|
gguf_import.py
ADDED
|
@@ -0,0 +1,907 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
"""
|
| 4 |
+
Chimera GGUF Import Optimized
|
| 5 |
+
βββββββββββββββββββββββββββββ
|
| 6 |
+
|
| 7 |
+
Convert GGUF tensors into a Chimera-compatible checkpoint.
|
| 8 |
+
|
| 9 |
+
AmΓ©liorations vs version originale :
|
| 10 |
+
- Ne garde pas tous les tensors GGUF FP32 en mΓ©moire.
|
| 11 |
+
- Corrige le bug embeddings/lm_head traitΓ©s comme BitLinear.
|
| 12 |
+
- Quantization ternary offline sans autograd.
|
| 13 |
+
- Clipping outlier par ligne pour les matrices.
|
| 14 |
+
- Auto-transpose si shape inversΓ©e.
|
| 15 |
+
- Modes de stockage :
|
| 16 |
+
fp32 : compatible Chimera classique, sauvegarde weight latent.
|
| 17 |
+
packed : sauvegarde packed_weight + alpha uniquement pour couches linΓ©aires.
|
| 18 |
+
both : sauvegarde weight + packed_weight + alpha.
|
| 19 |
+
- Init des poids manquants pour checkpoint complet.
|
| 20 |
+
- Resize configurable : strict, crop_pad, interpolate.
|
| 21 |
+
- Mapping GGUF plus robuste pour LLaMA/Qwen/Mistral-like.
|
| 22 |
+
|
| 23 |
+
Usage :
|
| 24 |
+
python gguf_import_optimized.py \
|
| 25 |
+
--gguf model.gguf \
|
| 26 |
+
--config config.json \
|
| 27 |
+
--scale tiny \
|
| 28 |
+
--output imported_chimera.pt \
|
| 29 |
+
--storage fp32
|
| 30 |
+
|
| 31 |
+
Pour checkpoint compact expΓ©rimental :
|
| 32 |
+
python gguf_import_optimized.py \
|
| 33 |
+
--gguf model.gguf \
|
| 34 |
+
--config config.json \
|
| 35 |
+
--output imported_chimera_packed.pt \
|
| 36 |
+
--storage packed
|
| 37 |
+
|
| 38 |
+
Attention :
|
| 39 |
+
- storage=packed nΓ©cessite que ton loader Chimera sache lire
|
| 40 |
+
*.packed_weight et *.alpha.
|
| 41 |
+
- Importer un gros modèle vers tiny/small via resize détruit beaucoup
|
| 42 |
+
d'information. C'est utile pour bootstrap, pas Γ©quivalent Γ distillation.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import os
|
| 46 |
+
import re
|
| 47 |
+
import gc
|
| 48 |
+
import json
|
| 49 |
+
import math
|
| 50 |
+
import argparse
|
| 51 |
+
from copy import deepcopy
|
| 52 |
+
from pathlib import Path
|
| 53 |
+
from typing import Dict, Tuple, Optional, Iterable, Any
|
| 54 |
+
|
| 55 |
+
import numpy as np
|
| 56 |
+
import torch
|
| 57 |
+
import torch.nn.functional as F
|
| 58 |
+
|
| 59 |
+
from chimera.paths import DEFAULT_CONFIG_PATH
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
try:
|
| 63 |
+
from gguf import GGUFReader, dequantize
|
| 64 |
+
HAS_GGUF = True
|
| 65 |
+
except Exception:
|
| 66 |
+
GGUFReader = None
|
| 67 |
+
dequantize = None
|
| 68 |
+
HAS_GGUF = False
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 72 |
+
# Config scales
|
| 73 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 74 |
+
|
| 75 |
+
SCALE_OVERRIDES = {
|
| 76 |
+
"tiny": {
|
| 77 |
+
"hidden_size": 256,
|
| 78 |
+
"intermediate_size": 512,
|
| 79 |
+
"num_hidden_layers": 28,
|
| 80 |
+
"num_heads": 4,
|
| 81 |
+
"head_dim": 48,
|
| 82 |
+
},
|
| 83 |
+
"small": {
|
| 84 |
+
"hidden_size": 512,
|
| 85 |
+
"intermediate_size": 1024,
|
| 86 |
+
"num_hidden_layers": 28,
|
| 87 |
+
"num_heads": 8,
|
| 88 |
+
"head_dim": 48,
|
| 89 |
+
},
|
| 90 |
+
"medium": {
|
| 91 |
+
"hidden_size": 1024,
|
| 92 |
+
"intermediate_size": 2048,
|
| 93 |
+
"num_hidden_layers": 28,
|
| 94 |
+
"num_heads": 8,
|
| 95 |
+
"head_dim": 96,
|
| 96 |
+
},
|
| 97 |
+
# full = garde config telle quelle
|
| 98 |
+
"full": {},
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 103 |
+
# Mapping GGUF -> Chimera
|
| 104 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 105 |
+
|
| 106 |
+
DIRECT_NAME_MAP = {
|
| 107 |
+
"token_embd": "embed.weight",
|
| 108 |
+
"token_embd.weight": "embed.weight",
|
| 109 |
+
|
| 110 |
+
"output": "lm_head.weight",
|
| 111 |
+
"output.weight": "lm_head.weight",
|
| 112 |
+
|
| 113 |
+
"output_norm": "norm.weight",
|
| 114 |
+
"output_norm.weight": "norm.weight",
|
| 115 |
+
|
| 116 |
+
# Variants parfois rencontrΓ©es
|
| 117 |
+
"norm": "norm.weight",
|
| 118 |
+
"norm.weight": "norm.weight",
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
BLOCK_SUFFIX_MAP = {
|
| 123 |
+
# Attention norm
|
| 124 |
+
"attn_norm": "attn_norm.weight",
|
| 125 |
+
"attn_norm.weight": "attn_norm.weight",
|
| 126 |
+
|
| 127 |
+
# FFN norm
|
| 128 |
+
"ffn_norm": "mlp_norm.weight",
|
| 129 |
+
"ffn_norm.weight": "mlp_norm.weight",
|
| 130 |
+
|
| 131 |
+
# Attention projections
|
| 132 |
+
"attn_q": "attn.q_proj.weight",
|
| 133 |
+
"attn_q.weight": "attn.q_proj.weight",
|
| 134 |
+
"attn_k": "attn.k_proj.weight",
|
| 135 |
+
"attn_k.weight": "attn.k_proj.weight",
|
| 136 |
+
"attn_v": "attn.v_proj.weight",
|
| 137 |
+
"attn_v.weight": "attn.v_proj.weight",
|
| 138 |
+
"attn_output": "attn.o_proj.weight",
|
| 139 |
+
"attn_output.weight": "attn.o_proj.weight",
|
| 140 |
+
|
| 141 |
+
# MLP / SwiGLU
|
| 142 |
+
"ffn_gate": "mlp.gate_proj.weight",
|
| 143 |
+
"ffn_gate.weight": "mlp.gate_proj.weight",
|
| 144 |
+
"ffn_up": "mlp.up_proj.weight",
|
| 145 |
+
"ffn_up.weight": "mlp.up_proj.weight",
|
| 146 |
+
"ffn_down": "mlp.down_proj.weight",
|
| 147 |
+
"ffn_down.weight": "mlp.down_proj.weight",
|
| 148 |
+
}
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def map_gguf_name(name: str, n_layers: int) -> Optional[str]:
|
| 152 |
+
"""
|
| 153 |
+
Convertit un nom GGUF vers une clΓ© Chimera.
|
| 154 |
+
Retourne None si non mappable.
|
| 155 |
+
"""
|
| 156 |
+
if name in DIRECT_NAME_MAP:
|
| 157 |
+
return DIRECT_NAME_MAP[name]
|
| 158 |
+
|
| 159 |
+
m = re.match(r"^blk\.(\d+)\.(.+)$", name)
|
| 160 |
+
if not m:
|
| 161 |
+
return None
|
| 162 |
+
|
| 163 |
+
bid = int(m.group(1))
|
| 164 |
+
suffix = m.group(2)
|
| 165 |
+
|
| 166 |
+
if bid >= n_layers:
|
| 167 |
+
return None
|
| 168 |
+
|
| 169 |
+
mapped_suffix = BLOCK_SUFFIX_MAP.get(suffix)
|
| 170 |
+
if mapped_suffix is None:
|
| 171 |
+
return None
|
| 172 |
+
|
| 173 |
+
return f"layers.{bid}.{mapped_suffix}"
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 177 |
+
# Ternary quantization + packing
|
| 178 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 179 |
+
|
| 180 |
+
@torch.no_grad()
|
| 181 |
+
def ternary_quantize_absmean(
|
| 182 |
+
w: torch.Tensor,
|
| 183 |
+
threshold: float = 0.5,
|
| 184 |
+
eps: float = 1e-5,
|
| 185 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 186 |
+
"""
|
| 187 |
+
Convertit w FP32 [M,K] -> w_q int8 {-1,0,1} + alpha [M].
|
| 188 |
+
|
| 189 |
+
alpha = mean(abs(w), dim=1)
|
| 190 |
+
w_norm = w / alpha
|
| 191 |
+
q = -1 si w_norm <= -threshold
|
| 192 |
+
0 si entre
|
| 193 |
+
+1 si w_norm >= threshold
|
| 194 |
+
"""
|
| 195 |
+
if w.ndim != 2:
|
| 196 |
+
raise ValueError("ternary_quantize_absmean attend un tensor 2D")
|
| 197 |
+
|
| 198 |
+
w = w.to(torch.float32)
|
| 199 |
+
alpha = w.abs().mean(dim=1).clamp_min(eps)
|
| 200 |
+
|
| 201 |
+
wn = w / alpha[:, None]
|
| 202 |
+
q = torch.zeros_like(wn, dtype=torch.int8)
|
| 203 |
+
q[wn >= threshold] = 1
|
| 204 |
+
q[wn <= -threshold] = -1
|
| 205 |
+
|
| 206 |
+
return q, alpha.to(torch.float32)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
@torch.no_grad()
|
| 210 |
+
def pack_ternary_2bit(w_q: torch.Tensor) -> torch.Tensor:
|
| 211 |
+
"""
|
| 212 |
+
Pack int8 {-1,0,+1} -> uint8, 4 poids par byte.
|
| 213 |
+
|
| 214 |
+
Encoding :
|
| 215 |
+
0 -> 00
|
| 216 |
+
+1 -> 01
|
| 217 |
+
-1 -> 10
|
| 218 |
+
|
| 219 |
+
Ordre :
|
| 220 |
+
weight0 bits 7..6
|
| 221 |
+
weight1 bits 5..4
|
| 222 |
+
weight2 bits 3..2
|
| 223 |
+
weight3 bits 1..0
|
| 224 |
+
"""
|
| 225 |
+
if w_q.ndim != 2:
|
| 226 |
+
raise ValueError("pack_ternary_2bit attend un tensor 2D")
|
| 227 |
+
|
| 228 |
+
M, K = w_q.shape
|
| 229 |
+
K4 = (K + 3) // 4
|
| 230 |
+
pad = K4 * 4 - K
|
| 231 |
+
|
| 232 |
+
codes = torch.zeros_like(w_q, dtype=torch.uint8)
|
| 233 |
+
codes[w_q == 1] = 1
|
| 234 |
+
codes[w_q == -1] = 2
|
| 235 |
+
|
| 236 |
+
if pad:
|
| 237 |
+
codes = F.pad(codes, (0, pad), value=0)
|
| 238 |
+
|
| 239 |
+
codes = codes.view(M, K4, 4)
|
| 240 |
+
packed = (
|
| 241 |
+
(codes[..., 0] << 6)
|
| 242 |
+
| (codes[..., 1] << 4)
|
| 243 |
+
| (codes[..., 2] << 2)
|
| 244 |
+
| codes[..., 3]
|
| 245 |
+
)
|
| 246 |
+
return packed.contiguous()
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 250 |
+
# Noise reduction
|
| 251 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 252 |
+
|
| 253 |
+
@torch.no_grad()
|
| 254 |
+
def reduce_noise(
|
| 255 |
+
w: torch.Tensor,
|
| 256 |
+
method: str = "row_outlier_clip",
|
| 257 |
+
sigma: float = 3.0,
|
| 258 |
+
eps: float = 1e-5,
|
| 259 |
+
) -> torch.Tensor:
|
| 260 |
+
"""
|
| 261 |
+
PrΓ©traitement avant ternarisation.
|
| 262 |
+
|
| 263 |
+
none : rien.
|
| 264 |
+
global_clip : clip global mean Β± sigma*std.
|
| 265 |
+
row_outlier_clip : clip par ligne, meilleur pour matrices linΓ©aires.
|
| 266 |
+
median_center : recentrage robuste global median/MAD.
|
| 267 |
+
"""
|
| 268 |
+
if method == "none":
|
| 269 |
+
return w
|
| 270 |
+
|
| 271 |
+
w = w.to(torch.float32)
|
| 272 |
+
|
| 273 |
+
if method == "global_clip":
|
| 274 |
+
mu = w.mean()
|
| 275 |
+
std = w.std(unbiased=False).clamp_min(eps)
|
| 276 |
+
return w.clamp(mu - sigma * std, mu + sigma * std)
|
| 277 |
+
|
| 278 |
+
if method == "row_outlier_clip":
|
| 279 |
+
if w.ndim != 2:
|
| 280 |
+
return reduce_noise(w, method="global_clip", sigma=sigma, eps=eps)
|
| 281 |
+
|
| 282 |
+
mu = w.mean(dim=1, keepdim=True)
|
| 283 |
+
std = w.std(dim=1, keepdim=True, unbiased=False).clamp_min(eps)
|
| 284 |
+
return w.clamp(mu - sigma * std, mu + sigma * std)
|
| 285 |
+
|
| 286 |
+
if method == "median_center":
|
| 287 |
+
med = w.median()
|
| 288 |
+
mad = (w - med).abs().median().clamp_min(eps)
|
| 289 |
+
return (w - med) / mad
|
| 290 |
+
|
| 291 |
+
return w
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 295 |
+
# Resize helpers
|
| 296 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 297 |
+
|
| 298 |
+
@torch.no_grad()
|
| 299 |
+
def resize_1d(w: torch.Tensor, target: int) -> torch.Tensor:
|
| 300 |
+
src = w.numel()
|
| 301 |
+
if src == target:
|
| 302 |
+
return w.contiguous()
|
| 303 |
+
|
| 304 |
+
out = torch.ones(target, dtype=w.dtype)
|
| 305 |
+
n = min(src, target)
|
| 306 |
+
out[:n] = w[:n]
|
| 307 |
+
return out.contiguous()
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@torch.no_grad()
|
| 311 |
+
def resize_2d_crop_pad(
|
| 312 |
+
w: torch.Tensor,
|
| 313 |
+
target_shape: Tuple[int, int],
|
| 314 |
+
fill_std: float = 0.02,
|
| 315 |
+
) -> torch.Tensor:
|
| 316 |
+
"""
|
| 317 |
+
Resize rapide par crop/pad.
|
| 318 |
+
Plus prΓ©visible qu'une interpolation sur poids Transformer.
|
| 319 |
+
"""
|
| 320 |
+
target_out, target_in = target_shape
|
| 321 |
+
src_out, src_in = w.shape
|
| 322 |
+
|
| 323 |
+
if (src_out, src_in) == (target_out, target_in):
|
| 324 |
+
return w.contiguous()
|
| 325 |
+
|
| 326 |
+
out = torch.empty((target_out, target_in), dtype=w.dtype)
|
| 327 |
+
|
| 328 |
+
# init zones non copiΓ©es
|
| 329 |
+
std = float(w.std(unbiased=False).item()) if w.numel() > 1 else fill_std
|
| 330 |
+
std = max(min(std, 0.2), 1e-4)
|
| 331 |
+
out.normal_(mean=0.0, std=std)
|
| 332 |
+
|
| 333 |
+
ro = min(src_out, target_out)
|
| 334 |
+
ci = min(src_in, target_in)
|
| 335 |
+
out[:ro, :ci] = w[:ro, :ci]
|
| 336 |
+
|
| 337 |
+
return out.contiguous()
|
| 338 |
+
|
| 339 |
+
|
| 340 |
+
@torch.no_grad()
|
| 341 |
+
def resize_2d_interpolate(
|
| 342 |
+
w: torch.Tensor,
|
| 343 |
+
target_shape: Tuple[int, int],
|
| 344 |
+
) -> torch.Tensor:
|
| 345 |
+
target_out, target_in = target_shape
|
| 346 |
+
if tuple(w.shape) == tuple(target_shape):
|
| 347 |
+
return w.contiguous()
|
| 348 |
+
|
| 349 |
+
x = w[None, None, :, :]
|
| 350 |
+
y = F.interpolate(
|
| 351 |
+
x,
|
| 352 |
+
size=(target_out, target_in),
|
| 353 |
+
mode="bilinear",
|
| 354 |
+
align_corners=False,
|
| 355 |
+
)
|
| 356 |
+
return y[0, 0].contiguous()
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
@torch.no_grad()
|
| 360 |
+
def resize_2d(
|
| 361 |
+
w: torch.Tensor,
|
| 362 |
+
target_shape: Tuple[int, int],
|
| 363 |
+
strategy: str = "crop_pad",
|
| 364 |
+
) -> torch.Tensor:
|
| 365 |
+
if tuple(w.shape) == tuple(target_shape):
|
| 366 |
+
return w.contiguous()
|
| 367 |
+
|
| 368 |
+
if strategy == "strict":
|
| 369 |
+
raise ValueError(f"Shape mismatch: got {tuple(w.shape)}, expected {target_shape}")
|
| 370 |
+
|
| 371 |
+
if strategy == "crop_pad":
|
| 372 |
+
return resize_2d_crop_pad(w, target_shape)
|
| 373 |
+
|
| 374 |
+
if strategy == "interpolate":
|
| 375 |
+
return resize_2d_interpolate(w, target_shape)
|
| 376 |
+
|
| 377 |
+
raise ValueError(f"resize strategy inconnue: {strategy}")
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 381 |
+
# Importer
|
| 382 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 383 |
+
|
| 384 |
+
class OptimizedGGUFImporter:
|
| 385 |
+
def __init__(
|
| 386 |
+
self,
|
| 387 |
+
config: Dict[str, Any],
|
| 388 |
+
scale: str = "tiny",
|
| 389 |
+
storage: str = "fp32",
|
| 390 |
+
param_dtype: str = "fp32",
|
| 391 |
+
noise_method: str = "row_outlier_clip",
|
| 392 |
+
noise_sigma: float = 3.0,
|
| 393 |
+
ternary_threshold: float = 0.5,
|
| 394 |
+
resize_strategy: str = "crop_pad",
|
| 395 |
+
auto_transpose: bool = True,
|
| 396 |
+
init_missing: bool = True,
|
| 397 |
+
verbose: bool = True,
|
| 398 |
+
):
|
| 399 |
+
self.config = deepcopy(config)
|
| 400 |
+
self.scale = scale
|
| 401 |
+
self.storage = storage
|
| 402 |
+
self.param_dtype = param_dtype
|
| 403 |
+
self.noise_method = noise_method
|
| 404 |
+
self.noise_sigma = noise_sigma
|
| 405 |
+
self.ternary_threshold = ternary_threshold
|
| 406 |
+
self.resize_strategy = resize_strategy
|
| 407 |
+
self.auto_transpose = auto_transpose
|
| 408 |
+
self.init_missing = init_missing
|
| 409 |
+
self.verbose = verbose
|
| 410 |
+
|
| 411 |
+
if scale not in SCALE_OVERRIDES:
|
| 412 |
+
raise ValueError(f"scale invalide: {scale}")
|
| 413 |
+
|
| 414 |
+
self.config.update(SCALE_OVERRIDES[scale])
|
| 415 |
+
|
| 416 |
+
self.n_layers = int(self.config["num_hidden_layers"])
|
| 417 |
+
self.hidden_size = int(self.config["hidden_size"])
|
| 418 |
+
self.vocab_size = int(self.config["vocab_size"])
|
| 419 |
+
self.num_heads = int(self.config.get("num_heads", 4))
|
| 420 |
+
self.head_dim = int(self.config.get("head_dim", self.hidden_size // self.num_heads))
|
| 421 |
+
|
| 422 |
+
inter = int(self.config["intermediate_size"])
|
| 423 |
+
self.intermediate_size = 256 * ((inter + 255) // 256)
|
| 424 |
+
self.config["intermediate_size"] = self.intermediate_size
|
| 425 |
+
|
| 426 |
+
if storage not in {"fp32", "packed", "both"}:
|
| 427 |
+
raise ValueError("storage doit Γͺtre: fp32, packed ou both")
|
| 428 |
+
|
| 429 |
+
if param_dtype not in {"fp32", "fp16", "bf16"}:
|
| 430 |
+
raise ValueError("param_dtype doit Γͺtre: fp32, fp16 ou bf16")
|
| 431 |
+
|
| 432 |
+
if self.verbose:
|
| 433 |
+
self.log(
|
| 434 |
+
f"[CONFIG] scale={scale} h={self.hidden_size} "
|
| 435 |
+
f"layers={self.n_layers} heads={self.num_heads} "
|
| 436 |
+
f"head_dim={self.head_dim} inter={self.intermediate_size} "
|
| 437 |
+
f"vocab={self.vocab_size}"
|
| 438 |
+
)
|
| 439 |
+
self.log(
|
| 440 |
+
f"[CONFIG] storage={storage} param_dtype={param_dtype} "
|
| 441 |
+
f"resize={resize_strategy} noise={noise_method}"
|
| 442 |
+
)
|
| 443 |
+
|
| 444 |
+
def log(self, msg: str):
|
| 445 |
+
if self.verbose:
|
| 446 |
+
print(msg, flush=True)
|
| 447 |
+
|
| 448 |
+
def target_dtype(self):
|
| 449 |
+
if self.param_dtype == "fp16":
|
| 450 |
+
return torch.float16
|
| 451 |
+
if self.param_dtype == "bf16":
|
| 452 |
+
return torch.bfloat16
|
| 453 |
+
return torch.float32
|
| 454 |
+
|
| 455 |
+
def infer_shape(self, key: str) -> Tuple[int, ...]:
|
| 456 |
+
h = self.hidden_size
|
| 457 |
+
attn_dim = self.num_heads * self.head_dim
|
| 458 |
+
|
| 459 |
+
if key == "embed.weight":
|
| 460 |
+
return (self.vocab_size, h)
|
| 461 |
+
|
| 462 |
+
if key == "lm_head.weight":
|
| 463 |
+
return (self.vocab_size, h)
|
| 464 |
+
|
| 465 |
+
if key == "norm.weight":
|
| 466 |
+
return (h,)
|
| 467 |
+
|
| 468 |
+
if key.endswith("attn_norm.weight") or key.endswith("mlp_norm.weight"):
|
| 469 |
+
return (h,)
|
| 470 |
+
|
| 471 |
+
if key.endswith("attn.q_proj.weight"):
|
| 472 |
+
return (attn_dim, h)
|
| 473 |
+
if key.endswith("attn.k_proj.weight"):
|
| 474 |
+
return (attn_dim, h)
|
| 475 |
+
if key.endswith("attn.v_proj.weight"):
|
| 476 |
+
return (attn_dim, h)
|
| 477 |
+
if key.endswith("attn.o_proj.weight"):
|
| 478 |
+
return (h, attn_dim)
|
| 479 |
+
|
| 480 |
+
if key.endswith("mlp.gate_proj.weight"):
|
| 481 |
+
return (self.intermediate_size, h)
|
| 482 |
+
if key.endswith("mlp.up_proj.weight"):
|
| 483 |
+
return (self.intermediate_size, h)
|
| 484 |
+
if key.endswith("mlp.down_proj.weight"):
|
| 485 |
+
return (h, self.intermediate_size)
|
| 486 |
+
|
| 487 |
+
raise KeyError(f"Impossible d'infΓ©rer la shape pour {key}")
|
| 488 |
+
|
| 489 |
+
def all_expected_keys(self) -> Iterable[str]:
|
| 490 |
+
yield "embed.weight"
|
| 491 |
+
yield "norm.weight"
|
| 492 |
+
yield "lm_head.weight"
|
| 493 |
+
|
| 494 |
+
for i in range(self.n_layers):
|
| 495 |
+
prefix = f"layers.{i}"
|
| 496 |
+
yield f"{prefix}.attn_norm.weight"
|
| 497 |
+
yield f"{prefix}.mlp_norm.weight"
|
| 498 |
+
yield f"{prefix}.attn.q_proj.weight"
|
| 499 |
+
yield f"{prefix}.attn.k_proj.weight"
|
| 500 |
+
yield f"{prefix}.attn.v_proj.weight"
|
| 501 |
+
yield f"{prefix}.attn.o_proj.weight"
|
| 502 |
+
yield f"{prefix}.mlp.gate_proj.weight"
|
| 503 |
+
yield f"{prefix}.mlp.up_proj.weight"
|
| 504 |
+
yield f"{prefix}.mlp.down_proj.weight"
|
| 505 |
+
|
| 506 |
+
def is_linear_key(self, key: str) -> bool:
|
| 507 |
+
return any(
|
| 508 |
+
key.endswith(s)
|
| 509 |
+
for s in (
|
| 510 |
+
"attn.q_proj.weight",
|
| 511 |
+
"attn.k_proj.weight",
|
| 512 |
+
"attn.v_proj.weight",
|
| 513 |
+
"attn.o_proj.weight",
|
| 514 |
+
"mlp.gate_proj.weight",
|
| 515 |
+
"mlp.up_proj.weight",
|
| 516 |
+
"mlp.down_proj.weight",
|
| 517 |
+
)
|
| 518 |
+
)
|
| 519 |
+
|
| 520 |
+
def is_embedding_or_head(self, key: str) -> bool:
|
| 521 |
+
return key in {"embed.weight", "lm_head.weight"}
|
| 522 |
+
|
| 523 |
+
def maybe_transpose(self, w: torch.Tensor, expected: Tuple[int, ...], key: str) -> torch.Tensor:
|
| 524 |
+
if not self.auto_transpose:
|
| 525 |
+
return w
|
| 526 |
+
|
| 527 |
+
if w.ndim == 2 and len(expected) == 2:
|
| 528 |
+
if tuple(w.shape) != tuple(expected) and tuple(w.t().shape) == tuple(expected):
|
| 529 |
+
self.log(f" [TRANSPOSE] {key}: {tuple(w.shape)} -> {tuple(w.t().shape)}")
|
| 530 |
+
return w.t().contiguous()
|
| 531 |
+
|
| 532 |
+
return w
|
| 533 |
+
|
| 534 |
+
def convert_tensor(
|
| 535 |
+
self,
|
| 536 |
+
gguf_name: str,
|
| 537 |
+
key: str,
|
| 538 |
+
arr: np.ndarray,
|
| 539 |
+
) -> Optional[Dict[str, torch.Tensor]]:
|
| 540 |
+
expected = self.infer_shape(key)
|
| 541 |
+
|
| 542 |
+
w = torch.from_numpy(np.asarray(arr)).to(torch.float32)
|
| 543 |
+
w = self.maybe_transpose(w, expected, key)
|
| 544 |
+
|
| 545 |
+
result: Dict[str, torch.Tensor] = {}
|
| 546 |
+
|
| 547 |
+
# 1D norms
|
| 548 |
+
if len(expected) == 1:
|
| 549 |
+
if w.ndim != 1:
|
| 550 |
+
self.log(f" [SKIP] {gguf_name}: expected 1D {expected}, got {tuple(w.shape)}")
|
| 551 |
+
return None
|
| 552 |
+
|
| 553 |
+
if tuple(w.shape) != tuple(expected):
|
| 554 |
+
self.log(f" [RESIZE-1D] {gguf_name}: {tuple(w.shape)} -> {expected}")
|
| 555 |
+
w = resize_1d(w, expected[0])
|
| 556 |
+
|
| 557 |
+
result[key] = w.to(self.target_dtype()).contiguous()
|
| 558 |
+
return result
|
| 559 |
+
|
| 560 |
+
# Embeddings/lm_head doivent rester denses, pas ternaires ici.
|
| 561 |
+
if self.is_embedding_or_head(key):
|
| 562 |
+
if w.ndim != 2:
|
| 563 |
+
self.log(f" [SKIP] {gguf_name}: expected 2D embedding/head, got {tuple(w.shape)}")
|
| 564 |
+
return None
|
| 565 |
+
|
| 566 |
+
if tuple(w.shape) != tuple(expected):
|
| 567 |
+
self.log(f" [RESIZE-EMB] {gguf_name}: {tuple(w.shape)} -> {expected}")
|
| 568 |
+
w = resize_2d(w, expected, self.resize_strategy)
|
| 569 |
+
|
| 570 |
+
result[key] = w.to(self.target_dtype()).contiguous()
|
| 571 |
+
return result
|
| 572 |
+
|
| 573 |
+
# LinΓ©aires BitLinear
|
| 574 |
+
if self.is_linear_key(key):
|
| 575 |
+
if w.ndim != 2:
|
| 576 |
+
self.log(f" [SKIP] {gguf_name}: expected 2D linear, got {tuple(w.shape)}")
|
| 577 |
+
return None
|
| 578 |
+
|
| 579 |
+
if tuple(w.shape) != tuple(expected):
|
| 580 |
+
self.log(f" [RESIZE-2D] {gguf_name}: {tuple(w.shape)} -> {expected}")
|
| 581 |
+
w = resize_2d(w, expected, self.resize_strategy)
|
| 582 |
+
|
| 583 |
+
w = reduce_noise(w, method=self.noise_method, sigma=self.noise_sigma)
|
| 584 |
+
|
| 585 |
+
if self.storage in {"fp32", "both"}:
|
| 586 |
+
result[key] = w.to(self.target_dtype()).contiguous()
|
| 587 |
+
|
| 588 |
+
if self.storage in {"packed", "both"}:
|
| 589 |
+
q, alpha = ternary_quantize_absmean(
|
| 590 |
+
w,
|
| 591 |
+
threshold=self.ternary_threshold,
|
| 592 |
+
)
|
| 593 |
+
packed = pack_ternary_2bit(q)
|
| 594 |
+
result[f"{key}.packed_weight"] = packed.cpu().contiguous()
|
| 595 |
+
result[f"{key}.alpha"] = alpha.cpu().contiguous()
|
| 596 |
+
result[f"{key}.shape"] = torch.tensor(list(expected), dtype=torch.int32)
|
| 597 |
+
|
| 598 |
+
return result
|
| 599 |
+
|
| 600 |
+
self.log(f" [SKIP] {gguf_name}: key non reconnue {key}")
|
| 601 |
+
return None
|
| 602 |
+
|
| 603 |
+
def init_missing_tensor(self, key: str) -> Dict[str, torch.Tensor]:
|
| 604 |
+
expected = self.infer_shape(key)
|
| 605 |
+
out: Dict[str, torch.Tensor] = {}
|
| 606 |
+
|
| 607 |
+
if len(expected) == 1:
|
| 608 |
+
# Norms : init Γ 1.0
|
| 609 |
+
w = torch.ones(expected, dtype=self.target_dtype())
|
| 610 |
+
out[key] = w
|
| 611 |
+
return out
|
| 612 |
+
|
| 613 |
+
if key in {"embed.weight", "lm_head.weight"}:
|
| 614 |
+
w = torch.empty(expected, dtype=torch.float32)
|
| 615 |
+
w.normal_(0.0, 0.02)
|
| 616 |
+
out[key] = w.to(self.target_dtype())
|
| 617 |
+
return out
|
| 618 |
+
|
| 619 |
+
if self.is_linear_key(key):
|
| 620 |
+
w = torch.empty(expected, dtype=torch.float32)
|
| 621 |
+
fan_in = max(1, expected[1])
|
| 622 |
+
std = math.sqrt(2.0 / fan_in)
|
| 623 |
+
w.normal_(0.0, std)
|
| 624 |
+
|
| 625 |
+
if self.storage in {"fp32", "both"}:
|
| 626 |
+
out[key] = w.to(self.target_dtype()).contiguous()
|
| 627 |
+
|
| 628 |
+
if self.storage in {"packed", "both"}:
|
| 629 |
+
q, alpha = ternary_quantize_absmean(w, threshold=self.ternary_threshold)
|
| 630 |
+
out[f"{key}.packed_weight"] = pack_ternary_2bit(q)
|
| 631 |
+
out[f"{key}.alpha"] = alpha
|
| 632 |
+
out[f"{key}.shape"] = torch.tensor(list(expected), dtype=torch.int32)
|
| 633 |
+
|
| 634 |
+
return out
|
| 635 |
+
|
| 636 |
+
return out
|
| 637 |
+
|
| 638 |
+
def dequantize_tensor(self, tensor) -> np.ndarray:
|
| 639 |
+
"""
|
| 640 |
+
Dequantize GGUF tensor vers numpy float32.
|
| 641 |
+
Compatible avec l'API gguf-py la plus courante.
|
| 642 |
+
"""
|
| 643 |
+
qtype = getattr(tensor, "tensor_type", None)
|
| 644 |
+
data = getattr(tensor, "data", None)
|
| 645 |
+
|
| 646 |
+
if data is None:
|
| 647 |
+
raise RuntimeError(f"Tensor GGUF sans data: {getattr(tensor, 'name', '?')}")
|
| 648 |
+
|
| 649 |
+
try:
|
| 650 |
+
arr = dequantize(data, qtype)
|
| 651 |
+
except Exception:
|
| 652 |
+
# Certains tensors peuvent dΓ©jΓ Γͺtre float array
|
| 653 |
+
arr = np.asarray(data)
|
| 654 |
+
|
| 655 |
+
arr = np.asarray(arr)
|
| 656 |
+
|
| 657 |
+
if arr.dtype != np.float32:
|
| 658 |
+
arr = arr.astype(np.float32, copy=False)
|
| 659 |
+
|
| 660 |
+
return np.ascontiguousarray(arr)
|
| 661 |
+
|
| 662 |
+
def read_arch(self, reader) -> str:
|
| 663 |
+
try:
|
| 664 |
+
field = reader.fields.get("general.architecture")
|
| 665 |
+
if field is None:
|
| 666 |
+
return "unknown"
|
| 667 |
+
# gguf-py field formats can vary.
|
| 668 |
+
if hasattr(field, "parts") and field.parts:
|
| 669 |
+
return str(field.parts[-1])
|
| 670 |
+
return str(field)
|
| 671 |
+
except Exception:
|
| 672 |
+
return "unknown"
|
| 673 |
+
|
| 674 |
+
def import_model(self, gguf_path: str, output_path: str) -> Dict[str, Any]:
|
| 675 |
+
if not HAS_GGUF:
|
| 676 |
+
raise ImportError("Package gguf manquant. Installe avec: pip install gguf")
|
| 677 |
+
|
| 678 |
+
gguf_path = str(gguf_path)
|
| 679 |
+
output_path = str(output_path)
|
| 680 |
+
|
| 681 |
+
self.log("=" * 70)
|
| 682 |
+
self.log("CHIMERA GGUF IMPORT OPTIMIZED")
|
| 683 |
+
self.log("=" * 70)
|
| 684 |
+
|
| 685 |
+
reader = GGUFReader(gguf_path)
|
| 686 |
+
arch = self.read_arch(reader)
|
| 687 |
+
|
| 688 |
+
self.log(f"[GGUF] file={gguf_path}")
|
| 689 |
+
self.log(f"[GGUF] arch={arch}")
|
| 690 |
+
self.log(f"[GGUF] tensors={len(reader.tensors)}")
|
| 691 |
+
|
| 692 |
+
state_dict: Dict[str, torch.Tensor] = {}
|
| 693 |
+
|
| 694 |
+
stats = {
|
| 695 |
+
"mapped": 0,
|
| 696 |
+
"unmapped": 0,
|
| 697 |
+
"skipped": 0,
|
| 698 |
+
"linear": 0,
|
| 699 |
+
"dense": 0,
|
| 700 |
+
"norm": 0,
|
| 701 |
+
"resized_or_transposed_possible": 0,
|
| 702 |
+
}
|
| 703 |
+
|
| 704 |
+
imported_keys = set()
|
| 705 |
+
|
| 706 |
+
for idx, tensor in enumerate(reader.tensors):
|
| 707 |
+
name = str(tensor.name)
|
| 708 |
+
key = map_gguf_name(name, self.n_layers)
|
| 709 |
+
|
| 710 |
+
if key is None:
|
| 711 |
+
stats["unmapped"] += 1
|
| 712 |
+
if self.verbose:
|
| 713 |
+
self.log(f" [UNMAPPED] {name}")
|
| 714 |
+
continue
|
| 715 |
+
|
| 716 |
+
try:
|
| 717 |
+
arr = self.dequantize_tensor(tensor)
|
| 718 |
+
converted = self.convert_tensor(name, key, arr)
|
| 719 |
+
|
| 720 |
+
if not converted:
|
| 721 |
+
stats["skipped"] += 1
|
| 722 |
+
continue
|
| 723 |
+
|
| 724 |
+
state_dict.update(converted)
|
| 725 |
+
imported_keys.add(key)
|
| 726 |
+
stats["mapped"] += 1
|
| 727 |
+
|
| 728 |
+
if self.is_linear_key(key):
|
| 729 |
+
stats["linear"] += 1
|
| 730 |
+
elif key in {"embed.weight", "lm_head.weight"}:
|
| 731 |
+
stats["dense"] += 1
|
| 732 |
+
else:
|
| 733 |
+
stats["norm"] += 1
|
| 734 |
+
|
| 735 |
+
if self.verbose:
|
| 736 |
+
qtype = getattr(tensor, "tensor_type", "?")
|
| 737 |
+
shape = tuple(arr.shape)
|
| 738 |
+
self.log(f" [OK] {idx+1:04d} {name} -> {key} shape={shape} qtype={qtype}")
|
| 739 |
+
|
| 740 |
+
except Exception as e:
|
| 741 |
+
stats["skipped"] += 1
|
| 742 |
+
self.log(f" [ERROR] {name}: {type(e).__name__}: {e}")
|
| 743 |
+
|
| 744 |
+
finally:
|
| 745 |
+
# Libère le FP32 temporaire.
|
| 746 |
+
try:
|
| 747 |
+
del arr
|
| 748 |
+
except Exception:
|
| 749 |
+
pass
|
| 750 |
+
gc.collect()
|
| 751 |
+
|
| 752 |
+
# Init des clΓ©s manquantes
|
| 753 |
+
missing = []
|
| 754 |
+
if self.init_missing:
|
| 755 |
+
for key in self.all_expected_keys():
|
| 756 |
+
if key not in imported_keys:
|
| 757 |
+
missing.append(key)
|
| 758 |
+
init_tensors = self.init_missing_tensor(key)
|
| 759 |
+
state_dict.update(init_tensors)
|
| 760 |
+
|
| 761 |
+
if missing:
|
| 762 |
+
self.log(f"[MISSING] {len(missing)} tensors initialisΓ©s automatiquement")
|
| 763 |
+
|
| 764 |
+
ckpt = {
|
| 765 |
+
"model": state_dict,
|
| 766 |
+
"config": self.config,
|
| 767 |
+
"source": {
|
| 768 |
+
"gguf_path": gguf_path,
|
| 769 |
+
"gguf_arch": arch,
|
| 770 |
+
"scale": self.scale,
|
| 771 |
+
"storage": self.storage,
|
| 772 |
+
"param_dtype": self.param_dtype,
|
| 773 |
+
"noise_method": self.noise_method,
|
| 774 |
+
"noise_sigma": self.noise_sigma,
|
| 775 |
+
"ternary_threshold": self.ternary_threshold,
|
| 776 |
+
"resize_strategy": self.resize_strategy,
|
| 777 |
+
"auto_transpose": self.auto_transpose,
|
| 778 |
+
},
|
| 779 |
+
"stats": stats,
|
| 780 |
+
"missing_keys": missing,
|
| 781 |
+
"import_version": "2.0-optimized",
|
| 782 |
+
}
|
| 783 |
+
|
| 784 |
+
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
| 785 |
+
torch.save(ckpt, output_path)
|
| 786 |
+
|
| 787 |
+
gguf_mb = os.path.getsize(gguf_path) / 1024 / 1024
|
| 788 |
+
out_mb = os.path.getsize(output_path) / 1024 / 1024
|
| 789 |
+
|
| 790 |
+
self.log("")
|
| 791 |
+
self.log("=" * 70)
|
| 792 |
+
self.log("[DONE]")
|
| 793 |
+
self.log(f"[STATS] {stats}")
|
| 794 |
+
self.log(f"[SIZE] GGUF={gguf_mb:.2f} MB -> checkpoint={out_mb:.2f} MB")
|
| 795 |
+
self.log(f"[SAVE] {output_path}")
|
| 796 |
+
self.log("=" * 70)
|
| 797 |
+
|
| 798 |
+
return ckpt
|
| 799 |
+
|
| 800 |
+
|
| 801 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 802 |
+
# CLI
|
| 803 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 804 |
+
|
| 805 |
+
def main():
|
| 806 |
+
parser = argparse.ArgumentParser(
|
| 807 |
+
description="Optimized GGUF -> Chimera checkpoint importer"
|
| 808 |
+
)
|
| 809 |
+
|
| 810 |
+
parser.add_argument("--gguf", required=True, help="Path to input .gguf")
|
| 811 |
+
parser.add_argument("--config", default=str(DEFAULT_CONFIG_PATH), help="Chimera config.json")
|
| 812 |
+
parser.add_argument("--output", required=True, help="Output .pt checkpoint")
|
| 813 |
+
|
| 814 |
+
parser.add_argument(
|
| 815 |
+
"--scale",
|
| 816 |
+
default="tiny",
|
| 817 |
+
choices=["tiny", "small", "medium", "full"],
|
| 818 |
+
help="Chimera scale override",
|
| 819 |
+
)
|
| 820 |
+
|
| 821 |
+
parser.add_argument(
|
| 822 |
+
"--storage",
|
| 823 |
+
default="fp32",
|
| 824 |
+
choices=["fp32", "packed", "both"],
|
| 825 |
+
help=(
|
| 826 |
+
"fp32=compatible Chimera classique, "
|
| 827 |
+
"packed=2-bit seulement, both=les deux"
|
| 828 |
+
),
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
parser.add_argument(
|
| 832 |
+
"--param-dtype",
|
| 833 |
+
default="fp32",
|
| 834 |
+
choices=["fp32", "fp16", "bf16"],
|
| 835 |
+
help="dtype pour les tensors denses/latents sauvegardΓ©s",
|
| 836 |
+
)
|
| 837 |
+
|
| 838 |
+
parser.add_argument(
|
| 839 |
+
"--noise-method",
|
| 840 |
+
default="row_outlier_clip",
|
| 841 |
+
choices=["none", "global_clip", "row_outlier_clip", "median_center"],
|
| 842 |
+
help="Noise reduction before ternary conversion",
|
| 843 |
+
)
|
| 844 |
+
|
| 845 |
+
parser.add_argument(
|
| 846 |
+
"--noise-sigma",
|
| 847 |
+
type=float,
|
| 848 |
+
default=3.0,
|
| 849 |
+
help="Sigma for clipping",
|
| 850 |
+
)
|
| 851 |
+
|
| 852 |
+
parser.add_argument(
|
| 853 |
+
"--ternary-threshold",
|
| 854 |
+
type=float,
|
| 855 |
+
default=0.5,
|
| 856 |
+
help="Threshold on normalized weights for ternary quantization",
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
parser.add_argument(
|
| 860 |
+
"--resize-strategy",
|
| 861 |
+
default="crop_pad",
|
| 862 |
+
choices=["strict", "crop_pad", "interpolate"],
|
| 863 |
+
help="Resize strategy when GGUF shape != Chimera shape",
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
parser.add_argument(
|
| 867 |
+
"--no-auto-transpose",
|
| 868 |
+
action="store_true",
|
| 869 |
+
help="Disable automatic transpose when reversed shape matches",
|
| 870 |
+
)
|
| 871 |
+
|
| 872 |
+
parser.add_argument(
|
| 873 |
+
"--no-init-missing",
|
| 874 |
+
action="store_true",
|
| 875 |
+
help="Do not initialize missing Chimera weights",
|
| 876 |
+
)
|
| 877 |
+
|
| 878 |
+
parser.add_argument(
|
| 879 |
+
"--quiet",
|
| 880 |
+
action="store_true",
|
| 881 |
+
help="Less logs",
|
| 882 |
+
)
|
| 883 |
+
|
| 884 |
+
args = parser.parse_args()
|
| 885 |
+
|
| 886 |
+
with open(args.config, "r", encoding="utf-8") as f:
|
| 887 |
+
config = json.load(f)
|
| 888 |
+
|
| 889 |
+
importer = OptimizedGGUFImporter(
|
| 890 |
+
config=config,
|
| 891 |
+
scale=args.scale,
|
| 892 |
+
storage=args.storage,
|
| 893 |
+
param_dtype=args.param_dtype,
|
| 894 |
+
noise_method=args.noise_method,
|
| 895 |
+
noise_sigma=args.noise_sigma,
|
| 896 |
+
ternary_threshold=args.ternary_threshold,
|
| 897 |
+
resize_strategy=args.resize_strategy,
|
| 898 |
+
auto_transpose=not args.no_auto_transpose,
|
| 899 |
+
init_missing=not args.no_init_missing,
|
| 900 |
+
verbose=not args.quiet,
|
| 901 |
+
)
|
| 902 |
+
|
| 903 |
+
importer.import_model(args.gguf, args.output)
|
| 904 |
+
|
| 905 |
+
|
| 906 |
+
if __name__ == "__main__":
|
| 907 |
+
main()
|
inference.py
ADDED
|
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Chimera 5.2 β CPU-first inference / text generation.
|
| 3 |
+
|
| 4 |
+
Config is source of truth. Checkpoint weights are resized to match the model.
|
| 5 |
+
"""
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import argparse
|
| 9 |
+
import json
|
| 10 |
+
import os
|
| 11 |
+
import time
|
| 12 |
+
from typing import Dict, Tuple
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def _setup_cpu_runtime() -> None:
|
| 16 |
+
n = os.cpu_count() or 4
|
| 17 |
+
os.environ.setdefault("OMP_NUM_THREADS", str(n))
|
| 18 |
+
os.environ.setdefault("MKL_NUM_THREADS", str(n))
|
| 19 |
+
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
|
| 20 |
+
os.environ.setdefault("KMP_BLOCKTIME", "1")
|
| 21 |
+
os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
_setup_cpu_runtime()
|
| 25 |
+
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn.functional as F
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
|
| 31 |
+
torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
|
| 32 |
+
except RuntimeError:
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
from chimera import Chimera51ForCausalLM, ChimeraTokenizer
|
| 36 |
+
from chimera.paths import DEFAULT_CONFIG_PATH
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Resize helpers: checkpoint weights -> model architecture (config is truth)
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
|
| 43 |
+
@torch.no_grad()
|
| 44 |
+
def _resize_1d(w: torch.Tensor, target: int) -> torch.Tensor:
|
| 45 |
+
out = torch.ones(target, dtype=w.dtype, device=w.device)
|
| 46 |
+
n = min(w.numel(), target)
|
| 47 |
+
out[:n] = w[:n]
|
| 48 |
+
return out
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.no_grad()
|
| 52 |
+
def _resize_2d(w: torch.Tensor, target_shape: Tuple[int, int]) -> torch.Tensor:
|
| 53 |
+
to, ti = target_shape
|
| 54 |
+
so, si = w.shape
|
| 55 |
+
if (so, si) == (to, ti):
|
| 56 |
+
return w
|
| 57 |
+
out = torch.empty((to, ti), dtype=w.dtype, device=w.device)
|
| 58 |
+
std = float(w.std(unbiased=False).item()) if w.numel() > 1 else 0.02
|
| 59 |
+
std = max(min(std, 0.2), 1e-4)
|
| 60 |
+
out.normal_(mean=0.0, std=std)
|
| 61 |
+
ro, ci = min(so, to), min(si, ti)
|
| 62 |
+
out[:ro, :ci] = w[:ro, :ci]
|
| 63 |
+
return out
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# Checkpoint loading
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
def load_model(checkpoint_path: str, device: str = "cpu"):
|
| 71 |
+
print(f"[LOAD] Checkpoint: {checkpoint_path}")
|
| 72 |
+
ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False)
|
| 73 |
+
|
| 74 |
+
config = ckpt.get("config")
|
| 75 |
+
if config is None:
|
| 76 |
+
ckpt_dir = os.path.dirname(checkpoint_path)
|
| 77 |
+
cand = os.path.join(ckpt_dir, "config.json") if ckpt_dir else "config.json"
|
| 78 |
+
if not os.path.exists(cand):
|
| 79 |
+
cand = str(DEFAULT_CONFIG_PATH)
|
| 80 |
+
with open(cand, encoding="utf-8") as f:
|
| 81 |
+
config = json.load(f)
|
| 82 |
+
print(f"[LOAD] Config from {cand}")
|
| 83 |
+
else:
|
| 84 |
+
print("[LOAD] Config from checkpoint")
|
| 85 |
+
|
| 86 |
+
model = Chimera51ForCausalLM(config)
|
| 87 |
+
counts = model.count_parameters()
|
| 88 |
+
print(f"[LOAD] Params: {counts['total']:,} (ternary: {counts['ternary']:,})")
|
| 89 |
+
|
| 90 |
+
state = ckpt.get("model", ckpt)
|
| 91 |
+
model_state = model.state_dict()
|
| 92 |
+
|
| 93 |
+
# Config is source of truth: resize checkpoint tensors to match model.
|
| 94 |
+
resized: Dict[str, torch.Tensor] = {}
|
| 95 |
+
for k, v in state.items():
|
| 96 |
+
if k in model_state:
|
| 97 |
+
expected = model_state[k].shape
|
| 98 |
+
if v.shape != expected:
|
| 99 |
+
print(f"[WARN] resizing {k}: {tuple(v.shape)} -> {tuple(expected)}")
|
| 100 |
+
if v.ndim == 1:
|
| 101 |
+
v = _resize_1d(v, expected[0])
|
| 102 |
+
elif v.ndim == 2:
|
| 103 |
+
v = _resize_2d(v, expected)
|
| 104 |
+
else:
|
| 105 |
+
print(f"[SKIP] {k}: cannot resize {v.ndim}D tensor")
|
| 106 |
+
continue
|
| 107 |
+
resized[k] = v
|
| 108 |
+
else:
|
| 109 |
+
resized[k] = v
|
| 110 |
+
|
| 111 |
+
# Vocab reconciliation: if vocab mismatch, re-init embed + lm_head.
|
| 112 |
+
model_vocab = int(config.get("vocab_size", model.embed.num_embeddings))
|
| 113 |
+
if "embed.weight" in resized:
|
| 114 |
+
ckpt_vocab = int(resized["embed.weight"].shape[0])
|
| 115 |
+
if ckpt_vocab != model_vocab:
|
| 116 |
+
print(f"[WARN] vocab mismatch ckpt={ckpt_vocab} cfg={model_vocab}; re-init embed+head")
|
| 117 |
+
with torch.no_grad():
|
| 118 |
+
old = model.embed.weight.data
|
| 119 |
+
new = torch.zeros(ckpt_vocab, old.shape[1], dtype=old.dtype, device=old.device)
|
| 120 |
+
new[:min(old.shape[0], ckpt_vocab)] = old[:min(old.shape[0], ckpt_vocab)]
|
| 121 |
+
model.embed = torch.nn.Embedding(ckpt_vocab, old.shape[1])
|
| 122 |
+
model.embed.weight.data = new
|
| 123 |
+
old_h = model.lm_head.weight.data
|
| 124 |
+
new_h = torch.zeros(ckpt_vocab, old_h.shape[1], dtype=old_h.dtype, device=old_h.device)
|
| 125 |
+
new_h[:min(old_h.shape[0], ckpt_vocab)] = old_h[:min(old_h.shape[0], ckpt_vocab)]
|
| 126 |
+
model.lm_head = torch.nn.Linear(old_h.shape[1], ckpt_vocab, bias=False)
|
| 127 |
+
model.lm_head.weight.data = new_h
|
| 128 |
+
config["vocab_size"] = ckpt_vocab
|
| 129 |
+
|
| 130 |
+
missing, unexpected = model.load_state_dict(resized, strict=False)
|
| 131 |
+
if missing:
|
| 132 |
+
print(f"[WARN] Missing keys ({len(missing)}): {missing[:5]}...")
|
| 133 |
+
if unexpected:
|
| 134 |
+
print(f"[WARN] Unexpected keys ({len(unexpected)}): {unexpected[:5]}...")
|
| 135 |
+
|
| 136 |
+
model.to(device).eval()
|
| 137 |
+
model.prepare_for_inference()
|
| 138 |
+
|
| 139 |
+
step = ckpt.get("step", "?")
|
| 140 |
+
best_loss = ckpt.get("best_loss")
|
| 141 |
+
if best_loss is not None:
|
| 142 |
+
print(f"[LOAD] Step {step}, best_loss={best_loss:.4f}")
|
| 143 |
+
else:
|
| 144 |
+
print(f"[LOAD] Step {step}")
|
| 145 |
+
return model, config
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
# ---------------------------------------------------------------------------
|
| 149 |
+
# Sampling helpers
|
| 150 |
+
# ---------------------------------------------------------------------------
|
| 151 |
+
|
| 152 |
+
def _sample_next(logits: torch.Tensor, temperature: float, top_p: float, top_k: int
|
| 153 |
+
) -> int:
|
| 154 |
+
if logits.dim() == 1:
|
| 155 |
+
logits = logits.unsqueeze(0)
|
| 156 |
+
if temperature <= 0.0:
|
| 157 |
+
return int(torch.argmax(logits, dim=-1).item())
|
| 158 |
+
logits = logits / temperature
|
| 159 |
+
if top_k and top_k > 0:
|
| 160 |
+
k = min(top_k, logits.size(-1))
|
| 161 |
+
cand_logits, cand_indices = torch.topk(logits, k, dim=-1)
|
| 162 |
+
if top_p < 1.0:
|
| 163 |
+
sorted_logits, order = torch.sort(cand_logits, descending=True)
|
| 164 |
+
sorted_indices = cand_indices.gather(-1, order)
|
| 165 |
+
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 166 |
+
remove = cum_probs > top_p
|
| 167 |
+
remove[..., 0] = False
|
| 168 |
+
sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
|
| 169 |
+
probs = F.softmax(sorted_logits, dim=-1)
|
| 170 |
+
return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
|
| 171 |
+
probs = F.softmax(cand_logits, dim=-1)
|
| 172 |
+
return int(cand_indices.gather(-1, torch.multinomial(probs, 1)).item())
|
| 173 |
+
if top_p < 1.0:
|
| 174 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 175 |
+
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 176 |
+
remove = cum_probs > top_p
|
| 177 |
+
remove[..., 0] = False
|
| 178 |
+
sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
|
| 179 |
+
probs = F.softmax(sorted_logits, dim=-1)
|
| 180 |
+
return int(sorted_indices.gather(-1, torch.multinomial(probs, 1)).item())
|
| 181 |
+
probs = F.softmax(logits, dim=-1)
|
| 182 |
+
return int(torch.multinomial(probs, 1).item())
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
# ---------------------------------------------------------------------------
|
| 186 |
+
# Generation loop
|
| 187 |
+
# ---------------------------------------------------------------------------
|
| 188 |
+
|
| 189 |
+
def generate(model: Chimera51ForCausalLM, tokenizer: ChimeraTokenizer,
|
| 190 |
+
prompt: str, max_tokens: int = 100, temperature: float = 0.8,
|
| 191 |
+
top_p: float = 0.9, top_k: int = 50, device: str = "cpu",
|
| 192 |
+
bf16: bool = False, stream: bool = True) -> str:
|
| 193 |
+
model.eval()
|
| 194 |
+
prompt_ids = tokenizer.encode(prompt, add_special_tokens=False)
|
| 195 |
+
if not prompt_ids:
|
| 196 |
+
prompt_ids = [tokenizer.eos_token_id]
|
| 197 |
+
input_ids = torch.tensor([prompt_ids], dtype=torch.long, device=device)
|
| 198 |
+
|
| 199 |
+
print(f"\n[GEN] Prompt: {prompt!r}")
|
| 200 |
+
print(f"[GEN] max_tokens={max_tokens}, temp={temperature}, top_p={top_p}, top_k={top_k}")
|
| 201 |
+
print("=" * 60, flush=True)
|
| 202 |
+
|
| 203 |
+
if stream:
|
| 204 |
+
sys.stdout.write(prompt)
|
| 205 |
+
sys.stdout.flush()
|
| 206 |
+
|
| 207 |
+
generated = list(prompt_ids)
|
| 208 |
+
decoded_so_far = tokenizer.decode(generated, skip_special_tokens=False)
|
| 209 |
+
|
| 210 |
+
autocast_ctx = (torch.autocast(device_type=device.split(":")[0], dtype=torch.bfloat16)
|
| 211 |
+
if bf16 else _nullctx())
|
| 212 |
+
|
| 213 |
+
t0 = time.time()
|
| 214 |
+
with torch.inference_mode(), autocast_ctx:
|
| 215 |
+
out = model(input_ids, use_cache=True, logits_to_keep=1)
|
| 216 |
+
caches = out.caches
|
| 217 |
+
next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
|
| 218 |
+
if next_token == tokenizer.eos_token_id:
|
| 219 |
+
return tokenizer.decode(generated, skip_special_tokens=True)
|
| 220 |
+
generated.append(next_token)
|
| 221 |
+
|
| 222 |
+
for _ in range(max_tokens - 1):
|
| 223 |
+
tok_t = torch.tensor([[next_token]], dtype=torch.long, device=device)
|
| 224 |
+
out = model(tok_t, caches=caches, use_cache=True, logits_to_keep=1)
|
| 225 |
+
caches = out.caches
|
| 226 |
+
next_token = _sample_next(out.logits[:, -1, :].float(), temperature, top_p, top_k)
|
| 227 |
+
if next_token == tokenizer.eos_token_id:
|
| 228 |
+
break
|
| 229 |
+
generated.append(next_token)
|
| 230 |
+
if stream:
|
| 231 |
+
full = tokenizer.decode(generated, skip_special_tokens=False)
|
| 232 |
+
if full.startswith(decoded_so_far):
|
| 233 |
+
sys.stdout.write(full[len(decoded_so_far):])
|
| 234 |
+
sys.stdout.flush()
|
| 235 |
+
decoded_so_far = full
|
| 236 |
+
|
| 237 |
+
elapsed = time.time() - t0
|
| 238 |
+
n_new = len(generated) - len(prompt_ids)
|
| 239 |
+
speed = n_new / elapsed if elapsed > 0 else 0.0
|
| 240 |
+
final = tokenizer.decode(generated, skip_special_tokens=True)
|
| 241 |
+
|
| 242 |
+
print()
|
| 243 |
+
print("=" * 60)
|
| 244 |
+
if not stream:
|
| 245 |
+
print(final)
|
| 246 |
+
print(f"[STATS] {n_new} new tokens in {elapsed:.2f}s ({speed:.1f} tok/s)")
|
| 247 |
+
return final
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
class _nullctx:
|
| 251 |
+
def __enter__(self):
|
| 252 |
+
return self
|
| 253 |
+
def __exit__(self, *args):
|
| 254 |
+
return False
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
# ---------------------------------------------------------------------------
|
| 258 |
+
# CLI
|
| 259 |
+
# ---------------------------------------------------------------------------
|
| 260 |
+
|
| 261 |
+
def main() -> None:
|
| 262 |
+
p = argparse.ArgumentParser(description="Chimera 5.2 CPU inference")
|
| 263 |
+
p.add_argument("--checkpoint", default="chimera_output/final/model.pt")
|
| 264 |
+
p.add_argument("--prompt", default="Once upon a time")
|
| 265 |
+
p.add_argument("--max_tokens", type=int, default=100)
|
| 266 |
+
p.add_argument("--temperature", type=float, default=0.8)
|
| 267 |
+
p.add_argument("--top_p", type=float, default=0.9)
|
| 268 |
+
p.add_argument("--top_k", type=int, default=50)
|
| 269 |
+
p.add_argument("--device", default="cpu")
|
| 270 |
+
p.add_argument("--bf16", action="store_true", default=True)
|
| 271 |
+
p.add_argument("--no-bf16", dest="bf16", action="store_false")
|
| 272 |
+
p.add_argument("--threads", type=int, default=None)
|
| 273 |
+
p.add_argument("--compile", action="store_true", default=False)
|
| 274 |
+
p.add_argument("--no-stream", dest="stream", action="store_false", default=True)
|
| 275 |
+
args = p.parse_args()
|
| 276 |
+
|
| 277 |
+
if args.threads:
|
| 278 |
+
torch.set_num_threads(args.threads)
|
| 279 |
+
os.environ["OMP_NUM_THREADS"] = str(args.threads)
|
| 280 |
+
os.environ["MKL_NUM_THREADS"] = str(args.threads)
|
| 281 |
+
|
| 282 |
+
if not os.path.exists(args.checkpoint):
|
| 283 |
+
print(f"[ERROR] Checkpoint not found: {args.checkpoint}")
|
| 284 |
+
return
|
| 285 |
+
|
| 286 |
+
model, config = load_model(args.checkpoint, device=args.device)
|
| 287 |
+
|
| 288 |
+
if args.compile:
|
| 289 |
+
print("[OPT] Compiling model with torch.compile (mode=reduce-overhead)...")
|
| 290 |
+
model = torch.compile(model, backend="inductor", mode="reduce-overhead")
|
| 291 |
+
|
| 292 |
+
print("[LOAD] Loading tokenizer (splintr o200k_base)...")
|
| 293 |
+
tokenizer = ChimeraTokenizer(pretrained="o200k_base")
|
| 294 |
+
|
| 295 |
+
print("[WARM] Warmup forward...")
|
| 296 |
+
with torch.inference_mode():
|
| 297 |
+
_ = model(torch.tensor([[tokenizer.eos_token_id]], device=args.device), logits_to_keep=1)
|
| 298 |
+
print("[WARM] Done.")
|
| 299 |
+
|
| 300 |
+
generate(
|
| 301 |
+
model, tokenizer,
|
| 302 |
+
prompt=args.prompt, max_tokens=args.max_tokens,
|
| 303 |
+
temperature=args.temperature, top_p=args.top_p, top_k=args.top_k,
|
| 304 |
+
device=args.device, bf16=args.bf16, stream=args.stream,
|
| 305 |
+
)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
if __name__ == "__main__":
|
| 309 |
+
main()
|
launch_turbo.sh
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
# launch_turbo.sh β Launch ch1mera with all CPU optimizations
|
| 3 |
+
#
|
| 4 |
+
# Usage: ./launch_turbo.sh [train_hyper.py args...]
|
| 5 |
+
# Example: ./launch_turbo.sh --scale tiny --seq_len 128 --max_steps 5000 --batch_size 16
|
| 6 |
+
|
| 7 |
+
set -e
|
| 8 |
+
|
| 9 |
+
# ββ Detect physical cores ββ
|
| 10 |
+
PHYS_CORES=$(lscpu -p | grep -v '^#' | sort -t, -k 2 -un | wc -l)
|
| 11 |
+
COMPUTE_THREADS=$((PHYS_CORES - 1))
|
| 12 |
+
echo "[TURBO] Physical cores: $PHYS_CORES β Compute threads: $COMPUTE_THREADS"
|
| 13 |
+
|
| 14 |
+
# ββ Threading ββ
|
| 15 |
+
export OMP_NUM_THREADS=$COMPUTE_THREADS
|
| 16 |
+
export MKL_NUM_THREADS=$COMPUTE_THREADS
|
| 17 |
+
export KMP_AFFINITY=granularity=fine,compact,1,0
|
| 18 |
+
export KMP_BLOCKTIME=1 # short blocktime for training (frequent sync)
|
| 19 |
+
|
| 20 |
+
# ββ tcmalloc (if available) ββ
|
| 21 |
+
TCMALLOC_LIB=$(ldconfig -p 2>/dev/null | grep -oP '/\S*libtcmalloc\S*\.so\S*' | head -1)
|
| 22 |
+
if [ -n "$TCMALLOC_LIB" ]; then
|
| 23 |
+
echo "[TURBO] tcmalloc: $TCMALLOC_LIB"
|
| 24 |
+
export LD_PRELOAD="$TCMALLOC_LIB${LD_PRELOAD:+:$LD_PRELOAD}"
|
| 25 |
+
else
|
| 26 |
+
echo "[TURBO] β tcmalloc not found. Install: sudo apt install google-perftools"
|
| 27 |
+
fi
|
| 28 |
+
|
| 29 |
+
# ββ IOMP (Intel OpenMP, if available) ββ
|
| 30 |
+
IOMP_LIB=$(python -c "import intel_extension_for_pytorch; import os; print(os.path.join(os.path.dirname(intel_extension_for_pytorch.__file__), '..', 'libiomp5.so'))" 2>/dev/null)
|
| 31 |
+
if [ -f "$IOMP_LIB" ]; then
|
| 32 |
+
echo "[TURBO] libiomp5: $IOMP_LIB"
|
| 33 |
+
export LD_PRELOAD="$IOMP_LIB${LD_PRELOAD:+:$LD_PRELOAD}"
|
| 34 |
+
fi
|
| 35 |
+
|
| 36 |
+
# ββ NUMA pinning (if numactl available) ββ
|
| 37 |
+
if command -v numactl &>/dev/null; then
|
| 38 |
+
echo "[TURBO] NUMA: pinning to node 0"
|
| 39 |
+
NUMA_PREFIX="numactl --cpunodebind=0 --membind=0"
|
| 40 |
+
else
|
| 41 |
+
NUMA_PREFIX=""
|
| 42 |
+
fi
|
| 43 |
+
|
| 44 |
+
# ββ Launch ββ
|
| 45 |
+
echo "[TURBO] Launching: python train_hyper.py $@"
|
| 46 |
+
echo "βββββββββββββββββββββββββββββββββββββββββββββββββββ"
|
| 47 |
+
|
| 48 |
+
$NUMA_PREFIX python train_hyper.py "$@"
|
pyproject.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "chimera51-cpu"
|
| 7 |
+
version = "5.2.0"
|
| 8 |
+
description = "CPU-first Chimera 5.1 causal LM implementation"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = ["torch"]
|
| 11 |
+
|
| 12 |
+
[project.scripts]
|
| 13 |
+
chimera-train = "chimera.cli:train_main"
|
| 14 |
+
chimera-train-fast = "chimera.cli:train_fast_main"
|
| 15 |
+
chimera-train-hyper = "chimera.cli:train_hyper_main"
|
| 16 |
+
chimera-infer = "chimera.cli:infer_main"
|
| 17 |
+
chimera-import-gguf = "chimera.cli:import_gguf_main"
|
| 18 |
+
|
| 19 |
+
[tool.setuptools]
|
| 20 |
+
packages = ["chimera", "chimera.training"]
|
| 21 |
+
py-modules = ["train", "train_fast", "train_hyper", "inference", "gguf_import", "chimera_turbo"]
|
| 22 |
+
|
| 23 |
+
[tool.setuptools.data-files]
|
| 24 |
+
"." = ["config.json"]
|
| 25 |
+
|
| 26 |
+
[tool.pytest.ini_options]
|
| 27 |
+
testpaths = ["tests"]
|
| 28 |
+
pythonpath = ["."]
|
tests/test_chimera.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pytest
|
| 2 |
+
|
| 3 |
+
torch = pytest.importorskip("torch")
|
| 4 |
+
|
| 5 |
+
from chimera import (
|
| 6 |
+
Chimera51ForCausalLM, ChimeraTokenizer, load_config, scale_config,
|
| 7 |
+
pack_ternary, unpack_ternary,
|
| 8 |
+
)
|
| 9 |
+
from chimera.inference import SpanBank
|
| 10 |
+
from chimera.moe import MoELayer
|
| 11 |
+
from chimera.quantization import BitLinear, ternarize_weight
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def cfg():
|
| 15 |
+
c = scale_config(load_config("config.json"), "nano")
|
| 16 |
+
c["vocab_size"] = 512
|
| 17 |
+
c["span_inference"]["enabled"] = False
|
| 18 |
+
return c
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def test_pack_unpack_roundtrip():
|
| 22 |
+
q = torch.tensor([[-1, 0, 1, 1, -1, 0, 1, 0, -1]], dtype=torch.int8)
|
| 23 |
+
packed = pack_ternary(q)
|
| 24 |
+
out = unpack_ternary(packed, q.shape[-1], dtype=torch.float32).to(torch.int8)
|
| 25 |
+
assert torch.equal(q, out)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def test_ternarize_weight_basic():
|
| 29 |
+
w = torch.randn(8, 16) * 0.5
|
| 30 |
+
wq, alpha = ternarize_weight(w)
|
| 31 |
+
assert wq.shape == w.shape
|
| 32 |
+
assert alpha.shape == (8,)
|
| 33 |
+
assert (wq.unique().abs() <= 1).all()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def test_bitlinear_forward_backward_and_packed():
|
| 37 |
+
layer = BitLinear(7, 5)
|
| 38 |
+
x = torch.randn(3, 7, requires_grad=True)
|
| 39 |
+
y = layer(x).sum()
|
| 40 |
+
y.backward()
|
| 41 |
+
assert x.grad is not None and torch.isfinite(x.grad).all()
|
| 42 |
+
assert layer.weight.grad is not None
|
| 43 |
+
layer.prepare_for_inference()
|
| 44 |
+
layer.eval()
|
| 45 |
+
with torch.no_grad():
|
| 46 |
+
out = layer(torch.randn(2, 7))
|
| 47 |
+
assert out.shape == (2, 5)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def test_bitlinear_dense_cache_consistency():
|
| 51 |
+
layer = BitLinear(8, 4)
|
| 52 |
+
layer.eval()
|
| 53 |
+
layer.prepare_for_inference()
|
| 54 |
+
x = torch.randn(2, 8)
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
out1 = layer(x)
|
| 57 |
+
out2 = layer(x)
|
| 58 |
+
assert torch.allclose(out1, out2)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def test_model_forward_loss_and_generate_shape():
|
| 62 |
+
model = Chimera51ForCausalLM(cfg())
|
| 63 |
+
x = torch.randint(0, 512, (2, 8))
|
| 64 |
+
y = torch.randint(0, 512, (2, 8))
|
| 65 |
+
out = model(x, labels=y)
|
| 66 |
+
assert out.logits.shape == (2, 8, 512)
|
| 67 |
+
assert torch.isfinite(out.loss)
|
| 68 |
+
out.loss.backward()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def test_model_kv_cache_consistency():
|
| 72 |
+
"""Generation with KV-cache must match generation without it."""
|
| 73 |
+
config = cfg()
|
| 74 |
+
config["looping"]["enabled"] = False # determinism for the equivalence check
|
| 75 |
+
model = Chimera51ForCausalLM(config).eval()
|
| 76 |
+
model.prepare_for_inference()
|
| 77 |
+
|
| 78 |
+
prompt = torch.randint(0, 512, (1, 4))
|
| 79 |
+
with torch.inference_mode():
|
| 80 |
+
# No-cache: feed the full sequence each time.
|
| 81 |
+
cur = prompt.clone()
|
| 82 |
+
no_cache_tokens = []
|
| 83 |
+
for _ in range(3):
|
| 84 |
+
out = model(cur, logits_to_keep=1)
|
| 85 |
+
tok = out.logits[:, -1].argmax(-1, keepdim=True)
|
| 86 |
+
cur = torch.cat([cur, tok], dim=1)
|
| 87 |
+
no_cache_tokens.append(int(tok.item()))
|
| 88 |
+
|
| 89 |
+
# KV-cache: feed only the new token after the first call.
|
| 90 |
+
out = model(prompt, use_cache=True, logits_to_keep=1)
|
| 91 |
+
caches = out.caches
|
| 92 |
+
tok = out.logits[:, -1].argmax(-1, keepdim=True)
|
| 93 |
+
cache_tokens = [int(tok.item())]
|
| 94 |
+
for _ in range(2):
|
| 95 |
+
out = model(tok, caches=caches, use_cache=True, logits_to_keep=1)
|
| 96 |
+
caches = out.caches
|
| 97 |
+
tok = out.logits[:, -1].argmax(-1, keepdim=True)
|
| 98 |
+
cache_tokens.append(int(tok.item()))
|
| 99 |
+
|
| 100 |
+
assert no_cache_tokens == cache_tokens
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def test_moe_and_span_bank_shapes():
|
| 104 |
+
moe = MoELayer(32, 64, n_routed_experts=3, n_shared_experts=1, num_experts_per_tok=2)
|
| 105 |
+
x = torch.randn(2, 4, 32)
|
| 106 |
+
assert moe(x).shape == x.shape
|
| 107 |
+
bank = SpanBank(max_entries=8, hidden_size=32)
|
| 108 |
+
bank.add(torch.randn(3, 32), torch.randn(3, 32))
|
| 109 |
+
assert bank.query(torch.randn(5, 32)).shape == (5, 32)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def test_tokenizer_fallback_roundtrip():
|
| 113 |
+
tok = ChimeraTokenizer(vocab_size=512)
|
| 114 |
+
text = "hello cpu"
|
| 115 |
+
assert tok.decode(tok.encode(text)) == text
|
tests/test_config.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from chimera.config import load_config, scale_config
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def test_config_scaling_without_torch_runtime():
|
| 5 |
+
cfg = scale_config(load_config("config.json"), "nano")
|
| 6 |
+
assert cfg["hidden_size"] == 128
|
| 7 |
+
assert cfg["num_hidden_layers"] == 4
|
| 8 |
+
assert cfg["vocab_size"] <= 8192
|
train.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Chimera 5.2 β CPU-first training script.
|
| 4 |
+
|
| 5 |
+
Highlights vs the previous version:
|
| 6 |
+
|
| 7 |
+
* MeZO optimiser uses a single deterministic seed per step, samples each
|
| 8 |
+
parameter's perturbation direction *on demand* via per-parameter seeds and
|
| 9 |
+
drops the heavy direction cache. This brings the memory cost of MeZO back
|
| 10 |
+
down to "1Γ model" exactly as advertised.
|
| 11 |
+
* AdamW path uses fused parameter groups and shares the same loss closure as
|
| 12 |
+
MeZO so accumulation and logging are identical between modes.
|
| 13 |
+
* Logging never references an undefined ``lr`` (the previous draft printed it
|
| 14 |
+
before the AdamW step ran on the first accumulator boundary).
|
| 15 |
+
* Gradient checkpointing falls back to ``use_reentrant=False`` (the modern,
|
| 16 |
+
faster path).
|
| 17 |
+
* Tokeniser/dataset loading is unchanged but the Python loops are skipped
|
| 18 |
+
entirely for ``max_tokens=0``.
|
| 19 |
+
|
| 20 |
+
Recommended commands::
|
| 21 |
+
|
| 22 |
+
# MeZO smoke test on TinyStories
|
| 23 |
+
python train.py --scale tiny --seq_len 64 --max_steps 20 --optimizer mezo
|
| 24 |
+
|
| 25 |
+
# AdamW with grad checkpointing + bf16
|
| 26 |
+
python train.py --scale small --seq_len 256 --max_steps 1000 \\
|
| 27 |
+
--optimizer adamw --grad_checkpoint --bf16
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
from __future__ import annotations
|
| 31 |
+
|
| 32 |
+
import argparse
|
| 33 |
+
import json
|
| 34 |
+
import math
|
| 35 |
+
import os
|
| 36 |
+
import time
|
| 37 |
+
|
| 38 |
+
# CPU threading must be configured *before* importing torch.
|
| 39 |
+
def _setup_cpu_runtime() -> None:
|
| 40 |
+
n_cpus = os.cpu_count() or 4
|
| 41 |
+
os.environ.setdefault("OMP_NUM_THREADS", str(n_cpus))
|
| 42 |
+
os.environ.setdefault("MKL_NUM_THREADS", str(n_cpus))
|
| 43 |
+
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
|
| 44 |
+
os.environ.setdefault("KMP_BLOCKTIME", "1")
|
| 45 |
+
os.environ.setdefault("MALLOC_CONF", "background_thread:true,metadata_thp:auto")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
_setup_cpu_runtime()
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
import torch
|
| 52 |
+
import torch.nn as nn
|
| 53 |
+
from torch.utils.data import DataLoader
|
| 54 |
+
|
| 55 |
+
from chimera import Chimera51ForCausalLM
|
| 56 |
+
from chimera.paths import DEFAULT_CONFIG_PATH
|
| 57 |
+
from chimera.training import (
|
| 58 |
+
build_sequence_dataset,
|
| 59 |
+
apply_standard_config_tweaks,
|
| 60 |
+
MeZOOptimizer,
|
| 61 |
+
train_standard_loop,
|
| 62 |
+
)
|
| 63 |
+
from chimera.quantization import BitLinear
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4)))
|
| 67 |
+
try:
|
| 68 |
+
torch.set_num_interop_threads(int(os.environ.get("CHIMERA_INTEROP_THREADS", "1")))
|
| 69 |
+
except RuntimeError:
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
# Optional Intel Extension for PyTorch.
|
| 74 |
+
HAS_IPEX = False
|
| 75 |
+
try: # pragma: no cover - optional dependency.
|
| 76 |
+
import intel_extension_for_pytorch as ipex # noqa: F401
|
| 77 |
+
HAS_IPEX = True
|
| 78 |
+
except Exception:
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Dataset & tokenisation helpers.
|
| 83 |
+
# ---------------------------------------------------------------------------
|
| 84 |
+
|
| 85 |
+
def build_dataset(seq_len: int, max_samples=None, max_tokens=None,
|
| 86 |
+
split: str = "train",
|
| 87 |
+
dataset_name: str = "roneneldan/TinyStories",
|
| 88 |
+
dataset_config: str = None, text_column: str = "auto",
|
| 89 |
+
category_filter: str = None,
|
| 90 |
+
include_reasoning: bool = False):
|
| 91 |
+
from chimera import ChimeraTokenizer
|
| 92 |
+
|
| 93 |
+
tok = ChimeraTokenizer(pretrained="o200k_base")
|
| 94 |
+
dataset = build_sequence_dataset(
|
| 95 |
+
seq_len,
|
| 96 |
+
max_samples=max_samples,
|
| 97 |
+
max_tokens=max_tokens,
|
| 98 |
+
split=split,
|
| 99 |
+
dataset_name=dataset_name,
|
| 100 |
+
dataset_config=dataset_config,
|
| 101 |
+
text_column=text_column,
|
| 102 |
+
category_filter=category_filter,
|
| 103 |
+
include_reasoning=include_reasoning,
|
| 104 |
+
)
|
| 105 |
+
return dataset, tok
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Main loop.
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
def train(args) -> None:
|
| 113 |
+
with open(args.config) as f:
|
| 114 |
+
config = json.load(f)
|
| 115 |
+
config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
|
| 116 |
+
|
| 117 |
+
use_mezo = (args.optimizer == "mezo")
|
| 118 |
+
use_bf16 = bool(args.bf16)
|
| 119 |
+
use_compile = bool(args.compile)
|
| 120 |
+
|
| 121 |
+
print("=" * 60)
|
| 122 |
+
print(f"CHIMERA 5.2 TRAINING β scale={args.scale}, "
|
| 123 |
+
f"optimizer={'MeZO' if use_mezo else 'AdamW'}, bf16={use_bf16}")
|
| 124 |
+
print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
|
| 125 |
+
f"vocab={config['vocab_size']} seq_len={args.seq_len} steps={args.max_steps}")
|
| 126 |
+
print(f"Threads: {torch.get_num_threads()} IPEX={HAS_IPEX}")
|
| 127 |
+
print("=" * 60)
|
| 128 |
+
|
| 129 |
+
model = Chimera51ForCausalLM(config)
|
| 130 |
+
counts = model.count_parameters()
|
| 131 |
+
print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
|
| 132 |
+
|
| 133 |
+
if args.grad_checkpoint and not use_mezo:
|
| 134 |
+
model.enable_gradient_checkpointing()
|
| 135 |
+
print("[OPT] Gradient checkpointing ON")
|
| 136 |
+
|
| 137 |
+
if HAS_IPEX and not use_mezo:
|
| 138 |
+
adamw = torch.optim.AdamW(model.parameters(), lr=args.lr)
|
| 139 |
+
model, adamw = ipex.optimize(
|
| 140 |
+
model, optimizer=adamw,
|
| 141 |
+
dtype=torch.bfloat16 if use_bf16 else torch.float32, level="O1")
|
| 142 |
+
print("[OPT] IPEX optimisation applied (level O1)")
|
| 143 |
+
else:
|
| 144 |
+
adamw = None
|
| 145 |
+
|
| 146 |
+
if use_compile:
|
| 147 |
+
print("[OPT] Compiling model with torch.compile (inductor)...")
|
| 148 |
+
model = torch.compile(model, backend="inductor", mode="default", dynamic=True)
|
| 149 |
+
|
| 150 |
+
dataset, tok = build_dataset(
|
| 151 |
+
args.seq_len, max_samples=args.max_samples, max_tokens=args.max_tokens,
|
| 152 |
+
split=args.dataset_split, dataset_name=args.dataset_name,
|
| 153 |
+
dataset_config=args.dataset_config, text_column=args.text_column,
|
| 154 |
+
category_filter=args.category_filter,
|
| 155 |
+
include_reasoning=args.include_reasoning,
|
| 156 |
+
)
|
| 157 |
+
loader = DataLoader(
|
| 158 |
+
dataset, batch_size=args.batch_size, shuffle=True,
|
| 159 |
+
num_workers=args.num_workers, drop_last=True,
|
| 160 |
+
persistent_workers=args.num_workers > 0,
|
| 161 |
+
prefetch_factor=2 if args.num_workers > 0 else None,
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
if use_mezo:
|
| 165 |
+
optimizer = MeZOOptimizer(
|
| 166 |
+
model, lr=args.lr * 0.01, eps=1e-3,
|
| 167 |
+
weight_decay=0.1, momentum=0.9, direction=args.mezo_direction,
|
| 168 |
+
)
|
| 169 |
+
else:
|
| 170 |
+
no_decay = {"A_log", "dt_bias", "norm", "bias", "embed", "energy_weights"}
|
| 171 |
+
decay_params, no_decay_params = [], []
|
| 172 |
+
for n, p in model.named_parameters():
|
| 173 |
+
if not p.requires_grad:
|
| 174 |
+
continue
|
| 175 |
+
if any(tag in n for tag in no_decay):
|
| 176 |
+
no_decay_params.append(p)
|
| 177 |
+
else:
|
| 178 |
+
decay_params.append(p)
|
| 179 |
+
if adamw is None:
|
| 180 |
+
optimizer = torch.optim.AdamW(
|
| 181 |
+
[{"params": decay_params, "weight_decay": 0.1},
|
| 182 |
+
{"params": no_decay_params, "weight_decay": 0.0}],
|
| 183 |
+
lr=args.lr, betas=(0.9, 0.95))
|
| 184 |
+
else:
|
| 185 |
+
optimizer = adamw
|
| 186 |
+
|
| 187 |
+
def compute_loss(batch) -> torch.Tensor:
|
| 188 |
+
ids = batch["input_ids"][:, :-1]
|
| 189 |
+
labels = batch["labels"][:, 1:]
|
| 190 |
+
if use_bf16:
|
| 191 |
+
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
| 192 |
+
out = model(ids, labels=labels)
|
| 193 |
+
else:
|
| 194 |
+
out = model(ids, labels=labels)
|
| 195 |
+
return out.loss
|
| 196 |
+
|
| 197 |
+
train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
# ---------------------------------------------------------------------------
|
| 201 |
+
# CLI
|
| 202 |
+
# ---------------------------------------------------------------------------
|
| 203 |
+
|
| 204 |
+
def _build_argparser() -> argparse.ArgumentParser:
|
| 205 |
+
p = argparse.ArgumentParser(description="Chimera 5.2 CPU-first training")
|
| 206 |
+
p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
|
| 207 |
+
p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
|
| 208 |
+
p.add_argument("--seq_len", type=int, default=256)
|
| 209 |
+
p.add_argument("--optimizer", default="mezo", choices=["mezo", "adamw"])
|
| 210 |
+
p.add_argument("--batch_size", type=int, default=2)
|
| 211 |
+
p.add_argument("--grad_accum", type=int, default=8)
|
| 212 |
+
p.add_argument("--lr", type=float, default=1e-3)
|
| 213 |
+
p.add_argument("--warmup", type=int, default=200)
|
| 214 |
+
p.add_argument("--max_steps", type=int, default=5000)
|
| 215 |
+
p.add_argument("--max_samples", type=int, default=None)
|
| 216 |
+
p.add_argument("--max_tokens", type=int, default=None)
|
| 217 |
+
p.add_argument("--bf16", action="store_true", default=True)
|
| 218 |
+
p.add_argument("--no-bf16", dest="bf16", action="store_false")
|
| 219 |
+
p.add_argument("--compile", action="store_true", default=False)
|
| 220 |
+
p.add_argument("--grad_checkpoint", action="store_true", default=True)
|
| 221 |
+
p.add_argument("--no-grad-checkpoint", dest="grad_checkpoint", action="store_false")
|
| 222 |
+
p.add_argument("--mezo_direction", choices=["rademacher", "gaussian"],
|
| 223 |
+
default="rademacher")
|
| 224 |
+
p.add_argument("--dataset_name", default="roneneldan/TinyStories")
|
| 225 |
+
p.add_argument("--dataset_config", default=None)
|
| 226 |
+
p.add_argument("--dataset_split", default="train")
|
| 227 |
+
p.add_argument("--text_column", default="auto")
|
| 228 |
+
p.add_argument("--category_filter", default=None)
|
| 229 |
+
p.add_argument("--include_reasoning", action="store_true", default=False)
|
| 230 |
+
p.add_argument("--num_workers", type=int, default=2)
|
| 231 |
+
p.add_argument("--log_every", type=int, default=10)
|
| 232 |
+
p.add_argument("--save_every", type=int, default=1000)
|
| 233 |
+
p.add_argument("--output_dir", default="./chimera_output")
|
| 234 |
+
return p
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
if __name__ == "__main__":
|
| 238 |
+
args = _build_argparser().parse_args()
|
| 239 |
+
train(args)
|
train_fast.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""Chimera 5.2 β Fast CPU training with pre-tokenized dataset cache."""
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import json
|
| 7 |
+
import math
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
# CPU threading must be configured *before* importing torch.
|
| 11 |
+
ncpus = int(os.environ.get("OMP_NUM_THREADS", os.cpu_count() or 4))
|
| 12 |
+
os.environ["OMP_NUM_THREADS"] = str(ncpus)
|
| 13 |
+
os.environ["MKL_NUM_THREADS"] = str(ncpus)
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch.utils.data import DataLoader
|
| 17 |
+
|
| 18 |
+
from chimera import Chimera51ForCausalLM
|
| 19 |
+
from chimera.paths import DEFAULT_CONFIG_PATH
|
| 20 |
+
from chimera.training import (
|
| 21 |
+
PreTokenizedDataset,
|
| 22 |
+
apply_standard_config_tweaks,
|
| 23 |
+
train_fast_loop,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
torch.set_num_threads(ncpus)
|
| 28 |
+
try:
|
| 29 |
+
torch.set_num_interop_threads(1)
|
| 30 |
+
except RuntimeError:
|
| 31 |
+
pass
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def build_or_load_dataset(seq_len: int, max_samples: int, cache_dir: str = "./cache"):
|
| 35 |
+
cache_path = os.path.join(cache_dir, f"tiny_stories_{seq_len}_{max_samples}.pt")
|
| 36 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
if os.path.exists(cache_path):
|
| 39 |
+
print(f"[CACHE] Loading pre-tokenized dataset from {cache_path}")
|
| 40 |
+
chunks = torch.load(cache_path, weights_only=False)
|
| 41 |
+
return PreTokenizedDataset(chunks, seq_len)
|
| 42 |
+
|
| 43 |
+
from datasets import load_dataset
|
| 44 |
+
from chimera import ChimeraTokenizer
|
| 45 |
+
|
| 46 |
+
print(f"[DATA] Downloading TinyStories...")
|
| 47 |
+
ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
|
| 48 |
+
tok = ChimeraTokenizer(pretrained="o200k_base")
|
| 49 |
+
|
| 50 |
+
target = max_samples * (seq_len + 1)
|
| 51 |
+
buffer = torch.empty(target, dtype=torch.long)
|
| 52 |
+
buf_idx = 0
|
| 53 |
+
processed = 0
|
| 54 |
+
|
| 55 |
+
for ex in ds:
|
| 56 |
+
text = ex.get("text", "")
|
| 57 |
+
if not text:
|
| 58 |
+
continue
|
| 59 |
+
ids = tok.encode(text, add_special_tokens=False)
|
| 60 |
+
ids.append(tok.eos_token_id)
|
| 61 |
+
n = len(ids)
|
| 62 |
+
if buf_idx + n > target:
|
| 63 |
+
n = target - buf_idx
|
| 64 |
+
if n <= 0:
|
| 65 |
+
break
|
| 66 |
+
ids = ids[:n]
|
| 67 |
+
if n > 0:
|
| 68 |
+
buffer[buf_idx:buf_idx + n] = torch.tensor(ids, dtype=torch.long)
|
| 69 |
+
buf_idx += n
|
| 70 |
+
processed += 1
|
| 71 |
+
if (processed % 1000) == 0:
|
| 72 |
+
print(f" {processed:,} stories, {buf_idx:,}/{target} tokens...")
|
| 73 |
+
if buf_idx >= target:
|
| 74 |
+
break
|
| 75 |
+
|
| 76 |
+
all_ids = buffer[:buf_idx]
|
| 77 |
+
n = all_ids.numel() // (seq_len + 1)
|
| 78 |
+
chunks = all_ids[:n * (seq_len + 1)]
|
| 79 |
+
|
| 80 |
+
torch.save(chunks, cache_path)
|
| 81 |
+
print(f"[CACHE] Saved {chunks.numel():,} tokens to {cache_path}")
|
| 82 |
+
return PreTokenizedDataset(chunks, seq_len)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def train(args) -> None:
|
| 86 |
+
with open(args.config) as f:
|
| 87 |
+
config = json.load(f)
|
| 88 |
+
config = apply_standard_config_tweaks(config, scale=args.scale, seq_len=args.seq_len)
|
| 89 |
+
|
| 90 |
+
print("=" * 60)
|
| 91 |
+
print(f"CHIMERA 5.2 FAST TRAIN β scale={args.scale}, seq_len={args.seq_len}, steps={args.max_steps}")
|
| 92 |
+
print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} vocab={config['vocab_size']}")
|
| 93 |
+
print(f"Threads: {torch.get_num_threads()} bf16={args.bf16} compile={args.compile}")
|
| 94 |
+
print("=" * 60)
|
| 95 |
+
|
| 96 |
+
model = Chimera51ForCausalLM(config)
|
| 97 |
+
counts = model.count_parameters()
|
| 98 |
+
print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
|
| 99 |
+
|
| 100 |
+
if args.compile:
|
| 101 |
+
print("[OPT] Compiling model...")
|
| 102 |
+
model = torch.compile(model, backend="inductor", mode="default", dynamic=True)
|
| 103 |
+
|
| 104 |
+
dataset = build_or_load_dataset(args.seq_len, args.max_samples, args.cache_dir)
|
| 105 |
+
loader = DataLoader(
|
| 106 |
+
dataset, batch_size=args.batch_size, shuffle=True,
|
| 107 |
+
num_workers=0, drop_last=True,
|
| 108 |
+
)
|
| 109 |
+
|
| 110 |
+
def compute_loss(batch) -> torch.Tensor:
|
| 111 |
+
ids = batch["input_ids"]
|
| 112 |
+
labels = batch["labels"]
|
| 113 |
+
if args.bf16:
|
| 114 |
+
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
|
| 115 |
+
out = model(ids, labels=labels)
|
| 116 |
+
else:
|
| 117 |
+
out = model(ids, labels=labels)
|
| 118 |
+
return out.loss
|
| 119 |
+
|
| 120 |
+
train_fast_loop(args, model, config, loader, compute_loss)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
if __name__ == "__main__":
|
| 124 |
+
p = argparse.ArgumentParser(description="Chimera 5.2 Fast CPU training")
|
| 125 |
+
p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
|
| 126 |
+
p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
|
| 127 |
+
p.add_argument("--seq_len", type=int, default=32)
|
| 128 |
+
p.add_argument("--batch_size", type=int, default=4)
|
| 129 |
+
p.add_argument("--lr", type=float, default=1e-3)
|
| 130 |
+
p.add_argument("--warmup", type=int, default=100)
|
| 131 |
+
p.add_argument("--max_steps", type=int, default=1000)
|
| 132 |
+
p.add_argument("--max_samples", type=int, default=5000)
|
| 133 |
+
p.add_argument("--bf16", action="store_true", default=False)
|
| 134 |
+
p.add_argument("--compile", action="store_true", default=False)
|
| 135 |
+
p.add_argument("--cache_dir", default="./cache")
|
| 136 |
+
p.add_argument("--log_every", type=int, default=10)
|
| 137 |
+
p.add_argument("--save_every", type=int, default=500)
|
| 138 |
+
p.add_argument("--output_dir", default="./chimera_output")
|
| 139 |
+
args = p.parse_args()
|
| 140 |
+
train(args)
|
train_hyper.py
ADDED
|
@@ -0,0 +1,192 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Chimera 5.3 β HYPER CPU Training v3 (10,000+ tok/s target)
|
| 4 |
+
============================================================
|
| 5 |
+
|
| 6 |
+
ALL features preserved: 28 layers, MoE, Parcae looping, SelfEvolution,
|
| 7 |
+
SpanInference, Grammar, EntropyValve, DebtLedger β nothing disabled.
|
| 8 |
+
|
| 9 |
+
Speed comes from optimizing HOW the forward+MeZO runs, not WHAT it runs:
|
| 10 |
+
|
| 11 |
+
P1 GrowLength Curriculum β seq 8βtarget, huge batch at short lengths
|
| 12 |
+
P2 Reservoir Freezing β freeze recurrent gates (fewer params to perturb)
|
| 13 |
+
P3 In-Place Seed MeZO β no randn allocation, seed-replay perturbation
|
| 14 |
+
P4 torch.compile β fuse ops, eliminate Python overhead
|
| 15 |
+
P5 Train-Mode STE Path β BitLinear uses STE (no invalidate_packed)
|
| 16 |
+
P6 Aggressive Token Packing β zero padding waste
|
| 17 |
+
P7 Progressive Unfreeze β fewer params early = faster perturbation
|
| 18 |
+
P8 Vocab Projection Cache β cache lm_head weight for 200K vocab
|
| 19 |
+
P9 Loop-1 Training β force num_loops=1 during training (full arch)
|
| 20 |
+
|
| 21 |
+
Key insight: MeZO's bottleneck is not the forward pass β it's
|
| 22 |
+
generating+applying random perturbations to 227M params 3Γ per step.
|
| 23 |
+
Seed-replay MeZO eliminates this entirely: perturb in-place using a
|
| 24 |
+
single seed, replay the same seed to restore/update.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
from __future__ import annotations
|
| 28 |
+
|
| 29 |
+
import argparse
|
| 30 |
+
import os
|
| 31 |
+
|
| 32 |
+
def _setup_cpu():
|
| 33 |
+
n = os.cpu_count() or 4
|
| 34 |
+
os.environ.setdefault("OMP_NUM_THREADS", str(n))
|
| 35 |
+
os.environ.setdefault("MKL_NUM_THREADS", str(n))
|
| 36 |
+
os.environ.setdefault("KMP_AFFINITY", "granularity=fine,compact,1,0")
|
| 37 |
+
os.environ.setdefault("KMP_BLOCKTIME", "1")
|
| 38 |
+
return n
|
| 39 |
+
|
| 40 |
+
_NCPU = _setup_cpu()
|
| 41 |
+
|
| 42 |
+
import torch
|
| 43 |
+
|
| 44 |
+
from chimera.paths import DEFAULT_CONFIG_PATH
|
| 45 |
+
from chimera.training import (
|
| 46 |
+
GrowLengthDataset,
|
| 47 |
+
GrowLengthScheduler,
|
| 48 |
+
ProgressiveUnfreezer,
|
| 49 |
+
apply_reservoir_freezing,
|
| 50 |
+
benchmark_hyper,
|
| 51 |
+
build_model_from_args,
|
| 52 |
+
build_token_buffer,
|
| 53 |
+
patch_training_loops,
|
| 54 |
+
train_hyper_loop,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
torch.set_num_threads(int(os.environ["OMP_NUM_THREADS"]))
|
| 58 |
+
try:
|
| 59 |
+
torch.set_num_interop_threads(max(1, _NCPU // 4))
|
| 60 |
+
except RuntimeError:
|
| 61 |
+
pass
|
| 62 |
+
|
| 63 |
+
_HAS_IPEX = False
|
| 64 |
+
try:
|
| 65 |
+
import intel_extension_for_pytorch as ipex
|
| 66 |
+
_HAS_IPEX = True
|
| 67 |
+
except Exception:
|
| 68 |
+
pass
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def build_model(args):
|
| 72 |
+
return build_model_from_args(args)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 76 |
+
# MAIN HYPER TRAIN
|
| 77 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 78 |
+
|
| 79 |
+
def train_hyper(args):
|
| 80 |
+
model, config = build_model(args)
|
| 81 |
+
counts = model.count_parameters()
|
| 82 |
+
|
| 83 |
+
print("=" * 65)
|
| 84 |
+
print(f"CHIMERA 5.3 HYPER v3 β scale={args.scale} bf16={args.bf16}")
|
| 85 |
+
print(f"Layers={config['num_hidden_layers']} hidden={config['hidden_size']} "
|
| 86 |
+
f"vocab={config['vocab_size']} target_seq={args.seq_len}")
|
| 87 |
+
print(f"Threads: {torch.get_num_threads()} IPEX={_HAS_IPEX}")
|
| 88 |
+
print(f"Params: total={counts['total']:,} ternary={counts['ternary']:,}")
|
| 89 |
+
print(f"ALL features ON: looping={model.looping_enabled} "
|
| 90 |
+
f"evolution={model.evolution is not None} "
|
| 91 |
+
f"span={model.span_engine is not None}")
|
| 92 |
+
print("=" * 65)
|
| 93 |
+
|
| 94 |
+
# ββ P9: Force loop=1 during training βββββββββββββββββββββββββββββ
|
| 95 |
+
# Architecture intact, but save 1 full pass through layers 4-23
|
| 96 |
+
patch_training_loops(model, num_loops=1)
|
| 97 |
+
print(f"[P9] Training loops=1 (arch intact, Parcae wired)")
|
| 98 |
+
|
| 99 |
+
# ββ P2: Reservoir Freezing βββββββββββββββββββββββββββββββββββββββ
|
| 100 |
+
if args.reservoir:
|
| 101 |
+
frozen = apply_reservoir_freezing(model)
|
| 102 |
+
print(f"[P2] Reservoir: froze {frozen:,} gate params")
|
| 103 |
+
|
| 104 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 105 |
+
print(f"[INFO] Trainable: {trainable:,} / {counts['total']:,}")
|
| 106 |
+
|
| 107 |
+
# ββ P7: Progressive Unfreezing βββββββββββββββββββββββββββββββββββ
|
| 108 |
+
unfreezer = None
|
| 109 |
+
if args.progressive_unfreeze:
|
| 110 |
+
unfreezer = ProgressiveUnfreezer(model, args.max_steps, args.unfreeze_stages)
|
| 111 |
+
active = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 112 |
+
print(f"[P7] Progressive unfreeze: {active:,} initially trainable")
|
| 113 |
+
|
| 114 |
+
# ββ P1: GrowLength βββββββββββββββββββββββββββββββββββββββββββββββ
|
| 115 |
+
if args.growlength:
|
| 116 |
+
stages = [
|
| 117 |
+
(max(8, args.seq_len // 4), 0.30),
|
| 118 |
+
(max(16, args.seq_len // 2), 0.30),
|
| 119 |
+
(args.seq_len, 0.40),
|
| 120 |
+
]
|
| 121 |
+
grow = GrowLengthScheduler(stages, args.max_steps)
|
| 122 |
+
initial_seq = stages[0][0]
|
| 123 |
+
print(f"[P1] GrowLength: {' β '.join(str(s) for s, _ in stages)}")
|
| 124 |
+
else:
|
| 125 |
+
grow = None
|
| 126 |
+
initial_seq = args.seq_len
|
| 127 |
+
|
| 128 |
+
# ββ Data βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 129 |
+
tok_budget = args.max_tokens or max(500_000,
|
| 130 |
+
args.max_steps * args.batch_size * (args.seq_len + 1) * 4)
|
| 131 |
+
token_buf = build_token_buffer(
|
| 132 |
+
args.dataset_name, args.dataset_split, args.text_column,
|
| 133 |
+
tok_budget, args.cache_dir)
|
| 134 |
+
dataset = GrowLengthDataset(token_buf, initial_seq)
|
| 135 |
+
print(f"[DATA] {token_buf.numel():,} tokens seq={initial_seq}")
|
| 136 |
+
|
| 137 |
+
train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 141 |
+
# CLI
|
| 142 |
+
# βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| 143 |
+
|
| 144 |
+
def cli():
|
| 145 |
+
p = argparse.ArgumentParser(description="Chimera 5.3 HYPER v3")
|
| 146 |
+
p.add_argument("--config", default=str(DEFAULT_CONFIG_PATH))
|
| 147 |
+
p.add_argument("--scale", default="tiny", choices=["tiny", "small", "medium", "full"])
|
| 148 |
+
p.add_argument("--seq_len", type=int, default=64)
|
| 149 |
+
p.add_argument("--batch_size", type=int, default=8)
|
| 150 |
+
p.add_argument("--lr", type=float, default=1e-3)
|
| 151 |
+
p.add_argument("--warmup", type=int, default=100)
|
| 152 |
+
p.add_argument("--max_steps", type=int, default=5000)
|
| 153 |
+
p.add_argument("--max_tokens", type=int, default=None)
|
| 154 |
+
p.add_argument("--max_samples", type=int, default=None)
|
| 155 |
+
p.add_argument("--bf16", action="store_true", default=True)
|
| 156 |
+
p.add_argument("--no-bf16", dest="bf16", action="store_false")
|
| 157 |
+
p.add_argument("--compile", action="store_true", default=False)
|
| 158 |
+
p.add_argument("--dataset_name", default="roneneldan/TinyStories")
|
| 159 |
+
p.add_argument("--dataset_split", default="train")
|
| 160 |
+
p.add_argument("--text_column", default="auto")
|
| 161 |
+
p.add_argument("--cache_dir", default="./cache")
|
| 162 |
+
p.add_argument("--log_every", type=int, default=10)
|
| 163 |
+
p.add_argument("--save_every", type=int, default=1000)
|
| 164 |
+
p.add_argument("--output_dir", default="./chimera_hyper_output")
|
| 165 |
+
|
| 166 |
+
g = p.add_argument_group("paradigms")
|
| 167 |
+
g.add_argument("--all", action="store_true", default=False)
|
| 168 |
+
g.add_argument("--growlength", action="store_true", default=False)
|
| 169 |
+
g.add_argument("--reservoir", action="store_true", default=False)
|
| 170 |
+
g.add_argument("--mezo-eps", type=float, default=1e-3, dest="mezo_eps")
|
| 171 |
+
g.add_argument("--progressive-unfreeze", action="store_true", default=False,
|
| 172 |
+
dest="progressive_unfreeze")
|
| 173 |
+
g.add_argument("--unfreeze-stages", type=int, default=4, dest="unfreeze_stages")
|
| 174 |
+
p.add_argument("--benchmark", action="store_true", default=False)
|
| 175 |
+
return p
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
if __name__ == "__main__":
|
| 179 |
+
args = cli().parse_args()
|
| 180 |
+
if args.max_samples and not args.max_tokens:
|
| 181 |
+
args.max_tokens = args.max_samples * (args.seq_len + 1)
|
| 182 |
+
if args.all:
|
| 183 |
+
args.growlength = True
|
| 184 |
+
args.reservoir = True
|
| 185 |
+
args.progressive_unfreeze = True
|
| 186 |
+
if args.benchmark:
|
| 187 |
+
args.growlength = True
|
| 188 |
+
args.reservoir = True
|
| 189 |
+
args.progressive_unfreeze = True
|
| 190 |
+
benchmark_hyper(args)
|
| 191 |
+
else:
|
| 192 |
+
train_hyper(args)
|