| # True Ternary Refactor 4 — True Ternary Exponent Dynamics |
|
|
| ## Scope |
|
|
| Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge", which would have upgraded the E buffers from int8 log2 to float8_e4m3fn. An exploration session determined this approach was architecturally wrong: an FP8 E reintroduces the IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as **True Ternary Exponent Dynamics**, replacing the FP8 approach with a mathematically correct logarithmic scaling system. |
| |
| This refactor implements the first three waves of the replanned Phase 9. |
| |
| ## What Changed |
| |
| ### Architecture Decisions (from exploration session) |
| |
| Five core principles were crystallized and are now encoded in code: |
| |
| | Principle | Meaning | Code Impact | |
| |-----------|---------|-------------| |
| | **S is never stored** | S = 2^E is a function, not a value. No float8/int16/float stored. | `_get_S()` restored to `torch.exp2(E.float())` | |
| | **E is hybrid state** | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | `update_E()` replaced SignSGD with EMA | |
| | **LossComponent = temperature** | Per-component loss signals control α (update energy) per group | `update_E(loss_signal)` modulates α via sigmoid | |
| | **TScaleType = fixed lattice** | Structure is stable; what's dynamic is energy routing across it | Group sizes unchanged | |
| | **Representation singular, learning ensemble** | Forward pass always single-scale; multi-scale exists only in update pathway | Forward uses single `W = T * 2^E` | |
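
The first and last principles (S is a function of E, and the forward pass uses a single `W = T * 2^E`) can be illustrated with a minimal pure-Python sketch. The helper names `encode_exponent`, `get_scale`, and `effective_weights` are hypothetical and are not the project's API (the real code uses `_get_S()` with `torch.exp2`); round-to-nearest quantization is also an assumption.

```python
import math

INT8_MIN, INT8_MAX = -128, 127


def encode_exponent(scale: float) -> int:
    """Quantize a positive scale to an int8 log2 exponent (assumed round-to-nearest)."""
    if scale <= 0.0:
        raise ValueError("scale must be positive")
    e = round(math.log2(scale))
    return max(INT8_MIN, min(INT8_MAX, e))


def get_scale(e: int) -> float:
    """Recompute S = 2**E on demand; only the integer E is ever kept."""
    return math.ldexp(1.0, e)


def effective_weights(trits: list[int], e: int) -> list[float]:
    """W = T * 2**E for one group of trits sharing a single exponent."""
    s = get_scale(e)
    return [t * s for t in trits]
```

In this sketch the only persistent values are the trits and the integer exponent; the float scale exists only for the duration of a call.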
|
|
| ### REPLAN (before execution) |
|
|
| **Commit:** `9103dc5` — replan Phase 9: True Ternary Exponent Dynamics |
|
|
| - ROADMAP.md updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", 3 plans instead of 4 |
| - REQUIREMENTS.md: HYB-01–06 replaced with TERN-E-01–05 |
| - CONTEXT.md rewritten from FP8-centric to true ternary principles |
| - STATE.md updated with new phase structure |
| - Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed |
|
|
| See `.planning/notes/true-ternary-architecture-principles.md` for full architecture rationale. |
|
|
| ### Plan 09-01 — Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02) |
|
|
| **Commit:** `80c6188` (+ SUMMARY `2bfc42a`) |
|
|
| Undid all FP8 changes from old Phase 9 Waves 1-2: |
|
|
| #### `tscale.py` |
|
|
| - **E buffer init** (lines 906, 1235): `clamp(-448, 448).to(float8_e4m3fn)` → `log2().clamp(-128, 127).to(int8)` |
| - **`_get_S`**: `E_exp.float()` → `torch.exp2(E_exp.float())` — restores log2→exp2 dequant |
| - **CPU update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn)` → `clamp(E.float() + grad_E, -128, 127).to(int8)` |
| - **tscale_to**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)` |
| - **5 Triton forward kernels** (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): load E with `other=0` (int) and cast it with `.to(tl.float32)`; `tl.exp2(e_val)` then replaces the direct float cast |
| - **2 Triton update kernels** (`_triton_update_e`, `_triton_update_e_direct`): `other=0.0` → `other=0`, `tl.minimum(448, ...)` → `tl.minimum(127, ...)`, store as `tl.int8` instead of `tl.float8e4nv` |
| - **TernaryRMSNorm E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)` |
|
|
| #### `trigram.py` |
|
|
| - **ByteEmbedding E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)` |
| - **ByteEmbedding CPU forward**: `E_exp.float()` → `torch.exp2(E_exp.float())` |
| - **ByteEmbedding update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8)` → `clamp(E.float() + grad_E, -128, 127).to(int8)` |
| |
| #### `ternary_audit.py` |
| |
| - Removed `buf.dtype != torch.float8_e4m3fn` exclusion from `float_buffers` filter |
| |
| #### `testing/test_tscale.py` |
| |
| - Removed 11 FP8-specific test functions and their registrations |
| - Updated update_E correctness test comment (was FP8-specific) |
| |
| ### Plan 09-02 — EMA-based E Update Rule (TERN-E-03) |
| |
| **Commit:** `97d0482` (+ SUMMARY `c1f4be7`) |
|
|
| Replaced the old SignSGD update_E with an EMA in log-space: |
| |
| ```python |
| # Old (SignSGD): |
| grp_mean_sign = grouped.mean(dim=2).sign() |
| grad_E = -grp_mean_sign.float() |
| E = clamp(E + grad_E, -128, 127) |
| |
| # New (EMA): |
| mu_g = grouped.abs().mean(dim=2) # mean abs gradient per group |
| e_proposed = round(log2(mu_g + 1e-10)).clamp(-128, 127) |
| E = (1-α) * E + α * e_proposed # α defaults to 0.1 |
| ``` |
| |
| Applied to both `TernaryScaleTensor.update_E()` and `ByteEmbedding.update_E()`. |
| |
| **Why this matters:** The old SignSGD only knew direction (±1 per group). The new EMA also uses magnitude: E converges toward the log2 of the mean absolute gradient in each group, smoothed by α. This prevents the "step-2 mass-flip" loss spike from REFACTOR3. |
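
To make the EMA rule concrete, here is a minimal pure-Python sketch of one update for a single group. The function name `ema_update_exponent` is hypothetical, and the way the blended value is rounded back to int8 is an assumption (the snippet above does not show that step).

```python
import math


def ema_update_exponent(e_old: int, group_grads, alpha: float = 0.1,
                        eps: float = 1e-10) -> int:
    """One log-space EMA step for the exponent of a single group."""
    # Mean absolute gradient of the group (mu_g in the snippet above).
    mu_g = sum(abs(g) for g in group_grads) / len(group_grads)
    # Proposed exponent: rounded log2 of the gradient magnitude, in int8 range.
    e_proposed = max(-128, min(127, round(math.log2(mu_g + eps))))
    # Blend old and proposed exponents, then store as int8 again.
    # Assumption: round-to-nearest on store.
    e_new = (1.0 - alpha) * e_old + alpha * e_proposed
    return max(-128, min(127, round(e_new)))


print(ema_update_exponent(0, [2 ** -10] * 4))   # proposal is -10, E moves to -1
print(ema_update_exponent(0, [0.0625] * 4))     # proposal is -4, E stays at 0
```

With round-to-nearest storage (an assumption of this sketch only), a proposal that lies close to the current E can round back to the same int8 value, as the second call shows.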
| |
| ### Plan 09-03 — LossComponent Temperature Routing (TERN-E-04) |
| |
| **Commit:** `d77180d` (+ SUMMARY `bd8a750`) |
| |
| Wired per-component loss signals as temperature for E updates: |
| |
| ```python |
| # update_E accepts loss_signal: |
| alpha = getattr(self, '_ema_alpha', 0.1) |
| if loss_signal is not None: |
|     temp_scale = getattr(self, '_loss_temp_scale', 1.0) |
|     alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item() |
| ``` |
| |
| Changes: |
| - `TernaryScaleTensor.update_E(loss_signal=...)`: α modulated by sigmoid(total_loss) |
| - `ByteEmbedding.update_E(loss_signal=...)`: same pattern |
| - `TernaryRMSNorm.update_E(loss_signal=...)`: accepts kwarg silently (frozen weights) |
| - `_ternary_update_memory(loss_signal=...)`: passes through to all modules |
| - `train.py`: passes `loss_comps.total` to `_ternary_update_memory` |
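
The α modulation can be checked in isolation with a small pure-Python sketch. `effective_alpha` is a hypothetical helper that mirrors the snippet above with scalar floats instead of tensors; it is not the project's function.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def effective_alpha(base_alpha: float = 0.1, loss_signal=None,
                    temp_scale: float = 1.0) -> float:
    """Scale the EMA rate by sigmoid(loss * temp_scale); unchanged if no signal."""
    if loss_signal is None:
        return base_alpha
    return base_alpha * sigmoid(loss_signal * temp_scale)


for loss in (0.0, 1.0, 3.0, 10.0):
    print(loss, effective_alpha(0.1, loss))
```

Because the sigmoid of a non-negative number lies between 0.5 and 1, a non-negative loss keeps the effective α between half the base rate and the base rate in this formulation.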
| |
| **Later use (TERN-E-05):** The multi-scale lattice where each TScaleType level proposes ΔE_s was deferred per plan scope. Reactivation trigger: if single-scale EMA saturates and cannot separate magnitude regimes. |
|
|
| ## Current State |
|
|
| ### E Buffer Architecture (what the code now does) |
|
|
| ``` |
| Storage: T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8) |
| Forward: W_eff = 2^(E_exp) * T where E_exp = expand(E) |
| Update:  E ← (1-α)*E + α*round(log2(|grad|_group)) |
|          T_accum += sign(grad) |
|          if |T_accum| > 3: flip T, reset T_accum |
| ``` |
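
The storage and accumulator lines above can be sketched in pure Python. Everything here is illustrative: the function names are hypothetical, the base-3 digit order in `pack5` is an assumed layout, and the direction in which a trit moves on a flip is an assumption, since the summary only says "flip T".

```python
def pack5(trits):
    """Pack five trits from {-1, 0, 1} into one byte as base-3 digits (3**5 = 243)."""
    assert len(trits) == 5
    value = 0
    for t in reversed(trits):          # first trit becomes the lowest digit
        value = value * 3 + (t + 1)
    return value


def unpack5(byte):
    """Inverse of pack5: recover the five trits from one byte."""
    out = []
    for _ in range(5):
        byte, digit = divmod(byte, 3)
        out.append(digit - 1)
    return out


def accumulate_and_flip(t, t_accum, grad, threshold=3):
    """Update one trit and its accumulator from a single gradient value.

    Assumption: a positive accumulated gradient moves the trit down one
    level and a negative one moves it up, saturating at -1 and +1.
    """
    sign = (grad > 0) - (grad < 0)
    t_accum = max(-128, min(127, t_accum + sign))
    if abs(t_accum) > threshold:
        step = -1 if t_accum > 0 else 1
        t = max(-1, min(1, t + step))
        t_accum = 0
    return t, t_accum
```

With the threshold of 3, a trit changes only on the fourth consecutive gradient of the same sign, which is what makes the accumulator act as a low-pass filter on flips.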
|
|
| ### Files Changed |
|
|
| | File | Lines Changed | |
| |------|---------------| |
| | `tscale.py` | ~40 changed, FP8 removed, EMA update added, LossComponent routing added | |
| | `trigram.py` | ~20 changed, FP8 removed, EMA update + LossComponent routing added | |
| | `ternary_audit.py` | ~1 changed, FP8 exclusion removed | |
| | `train.py` | ~1 changed, loss_signal passed to update | |
| | `benchmark_true_ternary.py` | New: 3-way benchmark comparing Adam/FP32, SignSGD(old), TrueTernary | |
| | `testing/test_tscale.py` | ~210 lines removed (FP8 tests), 1 comment updated | |
| | `.planning/ROADMAP.md` | Phase 9 goal and plans rewritten | |
| | `.planning/REQUIREMENTS.md` | HYB-01–06 replaced with TERN-E-01–05 | |
|
|
| ### Tests |
|
|
| All 18 CPU tscale tests pass. 110 morph tests pass. |
|
|
| ### Benchmark (in progress) |
|
|
| `benchmark_true_ternary.py` compares 3 configs on the full MORPHTernaryModel (133M params): |
| - **Adam_FP32**: standard FP32 Adam baseline |
| - **SignSGD_Old**: SignSGD without `_ternary_update_memory` (static E, no EMA) |
| - **TrueTernary**: our new int8 E + EMA update + LossComponent routing (CPU update) |
|
|
| Preliminary findings (50 steps, batch=8, seq=66): |
| - Adam_FP32: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM |
| - SignSGD_Old: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM |
| - TrueTernary: ~600ms/step, same VRAM (uses CPU fallback update) |
|
|
| ## Known Issue: Triton Kernel Caching Bug |
|
|
| **`_triton_ternary_step_direct_kernel`** takes ~13.6s to compile and does NOT cache across calls, even with identical shapes, causing every `ternary_step` call to take 13.6s. This affects the GPU path for `_ternary_update_memory`. |
| |
| Workaround in benchmark: CPU-based `cpu_update_memory()` function bypasses Triton for updates. |
| |
| Symptoms: |
| - First call: 13.6s (compilation) |
| - Second call with identical shapes: ALSO 13.6s (recompilation — cache miss) |
| |
| Triton cache directory has 8788 cached kernels but this specific kernel is not cached. Root cause not yet identified — likely a complex interaction of `tl.constexpr` parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) + kernel complexity. |
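
One way to reproduce the symptom outside the benchmark is a small timing harness. This is a generic sketch, not part of the repository: `time_repeated_calls` and `looks_like_recompile` are hypothetical helpers, and the 0.5 ratio is an arbitrary heuristic.

```python
import time


def time_repeated_calls(fn, repeats: int = 3):
    """Call fn() several times and return each wall-clock duration in seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def looks_like_recompile(durations, ratio: float = 0.5) -> bool:
    """Heuristic: later calls nearly as slow as the first suggest a cache miss."""
    first, rest = durations[0], durations[1:]
    return bool(rest) and all(d >= ratio * first for d in rest)


print(looks_like_recompile([13.6, 13.6]))          # matches the symptom above
print(looks_like_recompile([13.6, 0.002, 0.002]))  # what a working cache looks like
```

When timing a GPU kernel, the callable passed as `fn` should wait for the device to finish (for example with `torch.cuda.synchronize()`) so that asynchronous launches do not hide the real duration.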
| |
| ## Float Materialization Audit (post-rollback) |
| |
| | Path | Float tensors created | Persistent? | |
| |------|----------------------|-------------| |
| | Forward (CUDA) | `y[M,N]`, `S` (ephemeral exp2) | Output tensor only | |
| | Backward (CUDA) | `grad_x`, `grad_2d`, `x_2d` | Ephemeral (autograd) | |
| | update_E (CPU) | `grad_T`, `mu_g`, `e_proposed`, `E_float` | Ephemeral (computed, applied, freed) | |
| | E buffer | `int8` (always) | Persistent (buffer) | |
| | S (scale) | **Never stored** — `2^E` computed on-the-fly | — | |
|
|
| No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit. ✓ |
|
|
| ## Remaining Work |
|
|
| 1. **Resolve Triton caching bug** in `_triton_ternary_step_direct_kernel`: convert the `tl.constexpr` parameters to runtime arguments, or simplify the kernel |
| 2. **Run full benchmark** (all 3 configs to completion): the TrueTernary config currently hangs from step 1 onward due to CPU update overhead and needs profiling |
| 3. **TERN-E-05 (multi-scale lattice)**: evaluate when single-scale EMA saturates |
| 4. **LossComponent routing training validation**: verify that α modulation via sigmoid(loss) actually improves convergence vs. fixed α |
|
|
| ## Follow-Up Fix: Benchmark Progress And VRAM Attribution |
|
|
| The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal `MORPHTernaryModel`, passed trainable floating-point parameters to the optimizers, and kept the expensive baseline comparison path ahead of `TrueTernary`. That made all configs report roughly the same VRAM behavior even though strict ternary training has no persistent float weights. |
|
|
| ### Changes |
|
|
| #### `benchmark_true_ternary.py` |
|
|
| - Added CLI controls: |
|   - `--configs` |
|   - `--steps` |
|   - `--warmup` |
|   - `--batch` |
|   - `--ctx` |
|   - `--strict-true-ternary / --no-strict-true-ternary` |
|   - `--update-backend {gpu,gpu-signcache,dense-fallback,none}` |
|   - `--scale-update-interval` |
|   - `--accum-threshold` |
|   - `--print-every` |
|   - `--reuse-base / --no-reuse-base` |
| - `TrueTernary` now defaults to a strict text-only ternary model: |
|   - no image/audio encoders |
|   - no VQ |
|   - no graph |
|   - no recurrent memory modules |
|   - float params frozen |
| - Optimizers are now built only from trainable parameters. |
| - Per-step logs now show allocated, reserved, and peak CUDA memory. |
| - Short benchmark averages now divide by the available number of steps instead of always dividing by 20. |
| - The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM. |
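
The CLI flags listed above could be declared with `argparse` roughly as follows. This is a sketch only: the flag names and the backend choices come from the list, but every type and default value here is an assumption, and `build_parser` is a hypothetical name. `argparse.BooleanOptionalAction` (Python 3.9+) provides the paired `--flag / --no-flag` form.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="True-ternary benchmark (sketch)")
    # Config names follow the three benchmark configs; defaults are assumed.
    p.add_argument("--configs", nargs="+",
                   choices=["Adam_FP32", "SignSGD_Old", "TrueTernary"],
                   default=["TrueTernary"])
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--warmup", type=int, default=1)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--ctx", type=int, default=66)
    p.add_argument("--strict-true-ternary",
                   action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--update-backend",
                   choices=["gpu", "gpu-signcache", "dense-fallback", "none"],
                   default="gpu-signcache")
    p.add_argument("--scale-update-interval", type=int, default=1)
    p.add_argument("--accum-threshold", type=int, default=3)
    p.add_argument("--print-every", type=int, default=1)
    p.add_argument("--reuse-base",
                   action=argparse.BooleanOptionalAction, default=False)
    return p


args = build_parser().parse_args(
    "--configs TrueTernary --steps 4 --batch 2 --ctx 10 "
    "--update-backend gpu-signcache --scale-update-interval 4".split()
)
print(args.configs, args.steps, args.update_backend)
```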
|
|
| #### `tscale.py` |
|
|
| - CUDA `update_E()` now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update: |
|
|
| ```text |
| mu_g = mean(abs(sign(grad_w) * T)) per group |
| e_proposed = round(log2(mu_g + eps)) |
| E = (1 - alpha) * E + alpha * e_proposed |
| ``` |
|
|
| - The Triton `E` update kernels accept the loss-temperature-adjusted `alpha` computed in Python. |
|
|
| #### New benchmark update backend: `gpu-signcache` |
|
|
| The direct GPU update path is memory-clean but too slow for larger `M = batch * sequence` shapes because it recomputes the gradient reduction inside tiny packed-byte kernels. |
|
|
| `gpu-signcache` uses this flow per module: |
|
|
| ```text |
| grad_y, x |
| -> Triton tiled sign-only grad kernel |
| -> temporary int8 grad_sign[N, K] |
| -> Triton EMA E update |
| -> Triton T_accum/repack |
| -> free grad_sign |
| ``` |
|
|
| This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping persistent model state true ternary. |
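
The first stage of that flow can be sketched in pure Python to show what the temporary cache holds. This assumes a standard linear layer where the weight gradient is `grad_y^T @ x`; the real Triton kernel is tiled and is not shown in this document, and `sign_only_grad` and `sign_cache_bytes` are hypothetical names.

```python
def sign_only_grad(grad_y, x):
    """Return sign(grad_y^T @ x) as an N x K matrix of -1/0/+1 integers.

    grad_y is M x N and x is M x K, both given as nested lists of floats.
    Only the sign of each reduced gradient entry is kept.
    """
    m, n, k = len(grad_y), len(grad_y[0]), len(x[0])
    assert len(x) == m, "grad_y and x must share the M dimension"
    out = []
    for j in range(n):
        row = []
        for c in range(k):
            acc = sum(grad_y[i][j] * x[i][c] for i in range(m))
            row.append((acc > 0) - (acc < 0))
        out.append(row)
    return out


def sign_cache_bytes(n: int, k: int) -> int:
    """Size of the temporary int8 sign cache for one N x K layer."""
    return n * k


grad_y = [[1.0, -2.0], [0.5, 0.0]]           # M=2, N=2
x = [[1.0, 0.0, -1.0], [2.0, 0.0, 1.0]]      # M=2, K=3
print(sign_only_grad(grad_y, x))
```

The cache costs one byte per weight of the layer being updated and can be released before the next layer is processed, which is why it does not count as persistent state.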
|
|
| ### Verification |
|
|
| Focused CUDA tests: |
|
|
| ```text |
| PASS test_cuda_triton_correctness_update_E |
| PASS test_cuda_triton_tscale_path |
| ``` |
|
|
| Strict small benchmark: |
|
|
| ```text |
| python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4 |
| |
| logical ternary weights: 14,011,904 |
| trainable float params: 0 |
| float buffers: 0 |
| peak VRAM: ~30 MB |
| avg step after warmup: ~170 ms |
| ``` |
|
|
| Strict benchmark at previous benchmark shape: |
|
|
| ```text |
| python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4 |
| |
| logical ternary weights: 14,011,904 |
| ternary training state: 17.15 MB |
| trainable float params: 0 |
| float buffers: 0 |
| peak VRAM: ~430 MB |
| steady step: ~1.85 s |
| ``` |
|
|
| Previous direct backend at the same shape was roughly `49 s/step` after warmup. The benchmark now progresses; remaining speed work is kernel tuning and eventually replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal. |
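
As a rough worked comparison of the two backends at this shape, the approximate timings quoted above can be turned into a speedup and a throughput figure. The sketch assumes that every position in the batch counts as one token (`M = batch * sequence`); the inputs are the rounded numbers from this section, so the outputs are only indicative.

```python
batch, ctx = 8, 66
tokens_per_step = batch * ctx            # M = batch * sequence

direct_s_per_step = 49.0                 # previous direct backend (approximate)
signcache_s_per_step = 1.85              # gpu-signcache steady step (approximate)

speedup = direct_s_per_step / signcache_s_per_step
direct_tok_s = tokens_per_step / direct_s_per_step
signcache_tok_s = tokens_per_step / signcache_s_per_step

print(f"tokens/step: {tokens_per_step}")
print(f"speedup: ~{speedup:.1f}x")
print(f"throughput: ~{direct_tok_s:.1f} -> ~{signcache_tok_s:.0f} tok/s")
```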
|
|