# True Ternary Refactor 4: True Ternary Exponent Dynamics
## Scope
Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge", upgrading E buffers from int8 log2 to float8_e4m3fn. An exploration session determined this approach was architecturally wrong: FP8 E reintroduces IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as **True Ternary Exponent Dynamics**, replacing the FP8 approach with the mathematically-correct logarithmic scaling system.
This refactor implements the first three waves of the replanned Phase 9.
## What Changed
### Architecture Decisions (from exploration session)
Five core principles were crystallized and are now encoded in code:
| Principle | Meaning | Code Impact |
|-----------|---------|-------------|
| **S is never stored** | S = 2^E is a function, not a value. No float8/int16/float stored. | `_get_S()` restored to `torch.exp2(E.float())` |
| **E is hybrid state** | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | `update_E()` replaced SignSGD with EMA |
| **LossComponent = temperature** | Per-component loss signals control α (update energy) per group | `update_E(loss_signal)` modulates α via sigmoid |
| **TScaleType = fixed lattice** | Structure is stable; what's dynamic is energy routing across it | Group sizes unchanged |
| **Representation singular, learning ensemble** | Forward pass always single-scale; multi-scale exists only in update pathway | Forward uses single `W = T * 2^E` |
### REPLAN (before execution)
**Commit:** `9103dc5` (replan Phase 9: True Ternary Exponent Dynamics)
- ROADMAP.md updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", 3 plans instead of 4
- REQUIREMENTS.md: HYB-01 through HYB-06 replaced with TERN-E-01 through TERN-E-05
- CONTEXT.md rewritten from FP8-centric to true ternary principles
- STATE.md updated with new phase structure
- Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed
See `.planning/notes/true-ternary-architecture-principles.md` for full architecture rationale.
### Plan 09-01: Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02)
**Commit:** `80c6188` (+ SUMMARY `2bfc42a`)
Undid all FP8 changes from old Phase 9 Waves 1-2:
#### `tscale.py`
- **E buffer init** (lines 906, 1235): `clamp(-448, 448).to(float8_e4m3fn)` → `log2().clamp(-128, 127).to(int8)`
- **`_get_S`**: `E_exp.float()` → `torch.exp2(E_exp.float())` (restores log2 → exp2 dequant)
- **CPU update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
- **tscale_to**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **5 Triton forward kernels** (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): load E with `other=0` (int), `.to(tl.float32)` → `tl.exp2(e_val)` replaces direct float cast
- **2 Triton update kernels** (`_triton_update_e`, `_triton_update_e_direct`): `other=0.0` → `other=0`, `tl.minimum(448, ...)` → `tl.minimum(127, ...)`, store as `tl.int8` instead of `tl.float8e4nv`
- **TernaryRMSNorm E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
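
The int8 log2 round trip restored by this rollback can be illustrated without any tensor library. The following is a minimal sketch, not the repository's code: `encode_E` and `get_S` are hypothetical names, and rounding to the nearest integer before clamping is an assumption (the `.to(int8)` cast listed above may truncate instead).

```python
import math

INT8_MIN, INT8_MAX = -128, 127

def encode_E(scale: float) -> int:
    """Quantize a positive scale to an int8 log2 exponent.

    Rounding to nearest is an assumption made for this sketch.
    """
    e = round(math.log2(scale))
    return max(INT8_MIN, min(INT8_MAX, e))

def get_S(e: int) -> float:
    """Recompute the scale from the stored exponent; S itself is never stored."""
    return 2.0 ** e

scales = [0.03, 0.5, 1.0, 6.0]
exponents = [encode_E(s) for s in scales]   # [-5, -1, 0, 3]
recovered = [get_S(e) for e in exponents]   # [0.03125, 0.5, 1.0, 8.0]
```

Every recovered scale is an exact power of two, which is the point of the design: the only persistent quantity is the integer exponent.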
#### `trigram.py`
- **ByteEmbedding E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **ByteEmbedding CPU forward**: `E_exp.float()` → `torch.exp2(E_exp.float())`
- **ByteEmbedding update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
#### `ternary_audit.py`
- Removed `buf.dtype != torch.float8_e4m3fn` exclusion from `float_buffers` filter
#### `testing/test_tscale.py`
- Removed 11 FP8-specific test functions and their registrations
- Updated update_E correctness test comment (was FP8-specific)
### Plan 09-02: EMA-based E Update Rule (TERN-E-03)
**Commit:** `97d0482` (+ SUMMARY `c1f4be7`)
Replaced the old SignSGD update_E with an EMA in log-space:
```python
# Old (SignSGD):
grp_mean_sign = grouped.mean(dim=2).sign()
grad_E = -grp_mean_sign.float()
E = clamp(E + grad_E, -128, 127)
# New (EMA):
mu_g = grouped.abs().mean(dim=2) # mean abs gradient per group
e_proposed = round(log2(mu_g + 1e-10)).clamp(-128, 127)
E = (1-α) * E + α * e_proposed  # α defaults to 0.1
```
Applied to both `TernaryScaleTensor.update_E()` and `ByteEmbedding.update_E()`.
**Why this matters:** The old SignSGD only knew direction (±1 per group). The new EMA knows both direction and magnitude: E converges toward the log of the gradient energy, smoothed by α. This prevents the "step-2 mass-flip" loss spike from REFACTOR3.
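
To make the per-group behaviour concrete, here is a framework-free sketch of one EMA step. It is an illustration under stated assumptions rather than the code in `tscale.py`: `ema_update_E` is a hypothetical helper, groups are plain Python lists, and the blended value is rounded to the nearest integer before being stored back as int8 (the notes do not say how the real implementation converts back).

```python
import math

def ema_update_E(E, group_grads, alpha=0.1, eps=1e-10):
    """Apply one log-space EMA step to a list of int8 exponents.

    E            -- current exponent per group (ints in [-128, 127])
    group_grads  -- one list of gradient values per group
    """
    new_E = []
    for e, grads in zip(E, group_grads):
        mu_g = sum(abs(g) for g in grads) / len(grads)       # mean |grad| per group
        e_proposed = round(math.log2(mu_g + eps))
        e_proposed = max(-128, min(127, e_proposed))
        blended = (1 - alpha) * e + alpha * e_proposed       # EMA in log-space
        # Assumption: round back to an integer so the state stays int8.
        new_E.append(max(-128, min(127, round(blended))))
    return new_E

E = [0, -8]
grads = [[2 ** -10] * 4, [4.0, -4.0, 4.0, -4.0]]
E = ema_update_E(E, grads)   # group 0 moves toward -10, group 1 toward +2
```

With nearest-integer rounding, the stored exponent stops moving once α·|e_proposed − E| falls below one half, so this sketch settles near the target rather than exactly on it. Whether the actual kernels round, truncate, or carry a fractional residue is not stated in these notes.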
### Plan 09-03: LossComponent Temperature Routing (TERN-E-04)
**Commit:** `d77180d` (+ SUMMARY `bd8a750`)
Wired per-component loss signals as temperature for E updates:
```python
# update_E accepts loss_signal:
alpha = getattr(self, '_ema_alpha', 0.1)
if loss_signal is not None:
    temp_scale = getattr(self, '_loss_temp_scale', 1.0)
    alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()
```
Changes:
- `TernaryScaleTensor.update_E(loss_signal=...)`: α modulated by sigmoid(total_loss)
- `ByteEmbedding.update_E(loss_signal=...)`: same pattern
- `TernaryRMSNorm.update_E(loss_signal=...)`: accepts kwarg silently (frozen weights)
- `_ternary_update_memory(loss_signal=...)`: passes through to all modules
- `train.py`: passes `loss_comps.total` to `_ternary_update_memory`
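
The effect of the sigmoid on α is easy to check numerically. The sketch below mirrors the snippet above using only the standard library; `modulated_alpha` is a hypothetical stand-in for the logic inside `update_E`, and the loss values are made up for illustration.

```python
import math

def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)

def modulated_alpha(base_alpha=0.1, loss_signal=None, temp_scale=1.0):
    """Scale the EMA rate by sigmoid(loss * temp_scale); unchanged if no signal."""
    if loss_signal is None:
        return base_alpha
    return base_alpha * sigmoid(loss_signal * temp_scale)

for loss in (0.0, 2.0, 10.0):
    print(f"loss={loss:5.1f} -> alpha={modulated_alpha(loss_signal=loss):.5f}")
```

Because a non-negative loss gives a sigmoid value in [0.5, 1), the effective α under these defaults stays between half the base rate and the base rate: a high loss keeps updates near full energy, while a loss near zero halves them.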
**Later use (TERN-E-05):** The multi-scale lattice where each TScaleType level proposes ΔE_s was deferred per plan scope. Reactivation trigger: if single-scale EMA saturates and cannot separate magnitude regimes.
## Current State
### E Buffer Architecture (what the code now does)
```
Storage:  T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8)
Forward:  W_eff = 2^(E_exp) * T    where E_exp = expand(E)
Update:   E ← (1-α)*E + α*round(log2(|grad|_group))
          T_accum += sign(grad)
          if |T_accum| > 3: flip T, reset T_accum
```
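
The "5-trit/byte" storage works because 3^5 = 243 fits in one unsigned byte. The sketch below shows one possible base-3 packing; the digit order and the mapping of {-1, 0, +1} to {0, 1, 2} are assumptions for illustration, and the repository's actual `T_packed` layout may differ.

```python
def pack_trits(trits):
    """Pack ternary values (-1, 0, +1) into bytes, five per byte (base 3)."""
    if len(trits) % 5 != 0:
        raise ValueError("length must be a multiple of 5")
    out = bytearray()
    for i in range(0, len(trits), 5):
        value = 0
        # Process in reverse so the first trit ends up as the lowest digit.
        for t in reversed(trits[i:i + 5]):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            value = value * 3 + (t + 1)
        out.append(value)            # at most 242, so it fits in a uint8
    return bytes(out)

def unpack_trits(packed):
    """Inverse of pack_trits: recover five trits from each byte."""
    trits = []
    for byte in packed:
        for _ in range(5):
            byte, digit = divmod(byte, 3)
            trits.append(digit - 1)
    return trits

weights = [1, -1, 0, 0, 1, -1, -1, -1, -1, -1]
packed = pack_trits(weights)         # 10 trits -> 2 bytes
restored = unpack_trits(packed)
```

At five trits per byte the packed weights cost 1.6 bits each, compared with 8 bits for the `T_accum` counter kept alongside them.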
### Files Changed
| File | Lines Changed |
|------|---------------|
| `tscale.py` | ~40 changed, FP8 removed, EMA update added, LossComponent routing added |
| `trigram.py` | ~20 changed, FP8 removed, EMA update + LossComponent routing added |
| `ternary_audit.py` | ~1 changed, FP8 exclusion removed |
| `train.py` | ~1 changed, loss_signal passed to update |
| `benchmark_true_ternary.py` | New: benchmark comparing Adam/FP32, SignSGD (old), and TrueTernary |
| `testing/test_tscale.py` | ~210 lines removed (FP8 tests), 1 comment updated |
| `.planning/ROADMAP.md` | Phase 9 goal and plans rewritten |
| `.planning/REQUIREMENTS.md` | HYB-01 through HYB-06 replaced with TERN-E-01 through TERN-E-05 |
### Tests
All 18 CPU tscale tests pass. 110 morph tests pass.
### Benchmark (in progress)
`benchmark_true_ternary.py` compares 3 configs on the full MORPHTernaryModel (133M params):
- **Adam_FP32**: standard FP32 Adam baseline
- **SignSGD_Old**: SignSGD without `_ternary_update_memory` (static E, no EMA)
- **TrueTernary**: our new int8 E + EMA update + LossComponent routing (CPU update)
Preliminary findings (50 steps, batch=8, seq=66):
- Adam_FP32: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- SignSGD_Old: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- TrueTernary: ~600ms/step, same VRAM (uses CPU fallback update)
## Known Issue: Triton Kernel Caching Bug
**`_triton_ternary_step_direct_kernel`** takes ~13.6s to compile and does NOT cache across calls, even with identical shapes, causing every `ternary_step` call to take 13.6s. This affects the GPU path for `_ternary_update_memory`.
Workaround in benchmark: CPU-based `cpu_update_memory()` function bypasses Triton for updates.
Symptoms:
- First call: 13.6s (compilation)
- Second call with identical shapes: ALSO 13.6s (recompilation, i.e. a cache miss)
Triton cache directory has 8788 cached kernels but this specific kernel is not cached. Root cause not yet identified; likely a complex interaction of `tl.constexpr` parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) + kernel complexity.
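
A small timing harness is enough to reproduce the symptom and to confirm any later fix. The sketch below is generic and does not touch Triton: `time_calls` and `looks_uncached` are hypothetical helpers, and the 0.5 ratio is an arbitrary threshold. When timing a real GPU kernel, the callable would also need to synchronize the device (for example with `torch.cuda.synchronize()`) before returning, otherwise only the launch is measured.

```python
import time

def time_calls(fn, repeats=3):
    """Call fn repeatedly and return the wall-clock duration of each call."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations

def looks_uncached(durations, ratio=0.5):
    """True if no later call is clearly faster than the first one.

    A working compile cache should make every call after the first much
    cheaper; the ratio is an arbitrary cut-off chosen for this sketch.
    """
    first, rest = durations[0], durations[1:]
    return bool(rest) and min(rest) > ratio * first

# Durations matching the symptom described above versus a healthy cache:
print(looks_uncached([13.6, 13.6, 13.6]))    # True
print(looks_uncached([13.6, 0.002, 0.002]))  # False
```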
## Float Materialization Audit (post-rollback)
| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| Forward (CUDA) | `y[M,N]`, `S` (ephemeral exp2) | Output tensor only |
| Backward (CUDA) | `grad_x`, `grad_2d`, `x_2d` | Ephemeral (autograd) |
| update_E (CPU) | `grad_T`, `mu_g`, `e_proposed`, `E_float` | Ephemeral (computed, applied, freed) |
| E buffer | `int8` (always) | Persistent (buffer) |
| S (scale) | **Never stored**: `2^E` computed on-the-fly | n/a |
No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit.
## Remaining Work
1. **Resolve Triton caching bug** in `_triton_ternary_step_direct_kernel`: port kernel to use non-constexpr or simplify
2. **Run full benchmark** (all 3 configs to completion): currently TrueTernary step 1+ hangs due to CPU update overhead; need to profile
3. **TERN-E-05 (multi-scale lattice)**: evaluate when single-scale EMA saturates
4. **LossComponent routing training validation**: verify that α modulation via sigmoid(loss) actually improves convergence vs fixed α
## Follow-Up Fix: Benchmark Progress And VRAM Attribution
The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal `MORPHTernaryModel`, passed trainable floating-point parameters to the optimizers, and kept the expensive baseline comparison path ahead of `TrueTernary`. That made all configs report roughly the same VRAM behavior even though strict ternary training has no persistent float weights.
### Changes
#### `benchmark_true_ternary.py`
- Added CLI controls:
  - `--configs`
  - `--steps`
  - `--warmup`
  - `--batch`
  - `--ctx`
  - `--strict-true-ternary / --no-strict-true-ternary`
  - `--update-backend {gpu,gpu-signcache,dense-fallback,none}`
  - `--scale-update-interval`
  - `--accum-threshold`
  - `--print-every`
  - `--reuse-base / --no-reuse-base`
- `TrueTernary` now defaults to a strict text-only ternary model:
  - no image/audio encoders
  - no VQ
  - no graph
  - no recurrent memory modules
  - float params frozen
- Optimizers are now built only from trainable parameters.
- Per-step logs now show allocated, reserved, and peak CUDA memory.
- Short benchmark averages now divide by the available number of steps instead of always dividing by 20.
- The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM.
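
The CLI surface and the averaging fix listed above could be wired up roughly as follows. This is a sketch, not the contents of `benchmark_true_ternary.py`: the flag names come from the list, but every default, the `nargs="+"` handling of `--configs`, and the helper names `build_parser` and `average_step_ms` are assumptions. `argparse.BooleanOptionalAction` requires Python 3.9 or later.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Build a parser with the flags named in the change list (defaults assumed)."""
    p = argparse.ArgumentParser(description="True-ternary benchmark (sketch)")
    p.add_argument("--configs", nargs="+", default=["TrueTernary"])
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--warmup", type=int, default=1)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--ctx", type=int, default=66)
    # BooleanOptionalAction generates the matching --no-... flag automatically.
    p.add_argument("--strict-true-ternary",
                   action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--update-backend", default="gpu-signcache",
                   choices=["gpu", "gpu-signcache", "dense-fallback", "none"])
    p.add_argument("--scale-update-interval", type=int, default=4)
    p.add_argument("--accum-threshold", type=int, default=3)
    p.add_argument("--print-every", type=int, default=1)
    p.add_argument("--reuse-base",
                   action=argparse.BooleanOptionalAction, default=False)
    return p

def average_step_ms(step_ms, warmup):
    """Average over the steps actually measured, not a fixed count of 20."""
    measured = step_ms[warmup:]
    if not measured:
        raise ValueError("no steps left after warmup")
    return sum(measured) / len(measured)

args = build_parser().parse_args(
    "--configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 "
    "--strict-true-ternary --update-backend gpu-signcache".split()
)
avg = average_step_ms([900.0, 170.0, 170.0, 170.0], args.warmup)
```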
#### `tscale.py`
- CUDA `update_E()` now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update:
```text
mu_g = mean(abs(sign(grad_w) * T)) per group
e_proposed = round(log2(mu_g + eps))
E = (1 - alpha) * E + alpha * e_proposed
```
- The Triton `E` update kernels accept the loss-temperature-adjusted `alpha` computed in Python.
#### New benchmark update backend: `gpu-signcache`
The direct GPU update path is memory-clean but too slow for larger `M = batch * sequence` shapes because it recomputes the gradient reduction inside tiny packed-byte kernels.
`gpu-signcache` uses this flow per module:
```text
grad_y, x
-> Triton tiled sign-only grad kernel
-> temporary int8 grad_sign[N, K]
-> Triton EMA E update
-> Triton T_accum/repack
-> free grad_sign
```
This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping persistent model state true ternary.
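
The first stage of that flow, reduced to its arithmetic, looks like the sketch below. It assumes a linear layer `y = x @ W.T` with `W` of shape `[N, K]`, so the weight gradient is `grad_y.T @ x`. Only the sign of each entry is kept, as a value in {-1, 0, +1} that fits in int8. The function name and the nested-list representation are illustrative, and the real path is a tiled Triton kernel rather than Python loops.

```python
def sign_only_weight_grad(grad_y, x):
    """Return sign(grad_y^T @ x) as an N x K table of -1 / 0 / +1.

    grad_y -- M x N upstream gradient (list of rows)
    x      -- M x K layer input (list of rows)
    """
    M, N, K = len(x), len(grad_y[0]), len(x[0])
    grad_sign = [[0] * K for _ in range(N)]
    for n in range(N):
        for k in range(K):
            # Reduce over the batch*sequence dimension M.
            acc = sum(grad_y[m][n] * x[m][k] for m in range(M))
            grad_sign[n][k] = (acc > 0) - (acc < 0)   # sign as a small int
    return grad_sign

grad_y = [[1.0, -2.0],
          [0.5,  0.0]]          # M=2, N=2
x = [[1.0, 0.0, -1.0],
     [2.0, 0.0,  1.0]]          # M=2, K=3
grad_sign = sign_only_weight_grad(grad_y, x)
# grad_sign == [[1, 0, -1], [-1, 0, 1]]
```

The float accumulator exists only inside the reduction; what is materialized per layer is the one-byte-per-weight sign table, which is then consumed by the E and `T_accum` updates and freed.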
### Verification
Focused CUDA tests:
```text
PASS test_cuda_triton_correctness_update_E
PASS test_cuda_triton_tscale_path
```
Strict small benchmark:
```text
python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4
logical ternary weights: 14,011,904
trainable float params: 0
float buffers: 0
peak VRAM: ~30 MB
avg step after warmup: ~170 ms
```
Strict benchmark at previous benchmark shape:
```text
python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4
logical ternary weights: 14,011,904
ternary training state: 17.15 MB
trainable float params: 0
float buffers: 0
peak VRAM: ~430 MB
steady step: ~1.85 s
```
Previous direct backend at the same shape was roughly `49 s/step` after warmup. The benchmark now progresses; remaining speed work is kernel tuning and eventually replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal.