# True Ternary Refactor 4 — True Ternary Exponent Dynamics

## Scope

Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge": upgrading E buffers from int8 log2 to `float8_e4m3fn`. An exploration session determined that this approach was architecturally wrong, because FP8 E reintroduces an IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as **True Ternary Exponent Dynamics**, replacing the FP8 approach with a logarithmic scaling system in which every scale is an integer power of two.

This refactor implements the first three waves of the replanned Phase 9.

## What Changed

### Architecture Decisions (from exploration session)

Five core principles were crystallized and are now encoded in code:

| Principle | Meaning | Code Impact |
|-----------|---------|-------------|
| **S is never stored** | S = 2^E is a function, not a value. No float8/int16/float is stored. | `_get_S()` restored to `torch.exp2(E.float())` |
| **E is hybrid state** | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | `update_E()`: SignSGD replaced with EMA |
| **LossComponent = temperature** | Per-component loss signals control α (update energy) per group | `update_E(loss_signal)` modulates α via sigmoid |
| **TScaleType = fixed lattice** | Structure is stable; what is dynamic is energy routing across it | Group sizes unchanged |
| **Representation singular, learning ensemble** | Forward pass is always single-scale; multi-scale exists only in the update pathway | Forward uses a single `W = T * 2^E` |

### REPLAN (before execution)

**Commit:** `9103dc5` — replan Phase 9: True Ternary Exponent Dynamics

- `ROADMAP.md` updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", with 3 plans instead of 4
- `REQUIREMENTS.md`: HYB-01–06 replaced with TERN-E-01–05
- `CONTEXT.md` rewritten from FP8-centric to true ternary principles
- `STATE.md` updated with the new phase structure
- Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed

See `.planning/notes/true-ternary-architecture-principles.md` for the full architecture rationale.

### Plan 09-01 — Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02)

**Commit:** `80c6188` (+ SUMMARY `2bfc42a`)

Undid all FP8 changes from the old Phase 9 Waves 1-2:

#### `tscale.py`

- **E buffer init** (lines 906, 1235): `clamp(-448, 448).to(float8_e4m3fn)` → `log2().clamp(-128, 127).to(int8)`
- **`_get_S`**: `E_exp.float()` → `torch.exp2(E_exp.float())`, restoring log2 → exp2 dequantization
- **CPU update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
- **tscale_to**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **5 Triton forward kernels** (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): E is loaded with `other=0` (int), and `tl.exp2(e_val)` replaces the direct `.to(tl.float32)` float cast
- **2 Triton update kernels** (`_triton_update_e`, `_triton_update_e_direct`): `other=0.0` → `other=0`, `tl.minimum(448, ...)` → `tl.minimum(127, ...)`, and results are stored as `tl.int8` instead of `tl.float8e4nv`
- **TernaryRMSNorm E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`

#### `trigram.py`

- **ByteEmbedding E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **ByteEmbedding CPU forward**: `E_exp.float()` → `torch.exp2(E_exp.float())`
- **ByteEmbedding update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`

#### `ternary_audit.py`

- Removed the `buf.dtype != torch.float8_e4m3fn` exclusion from the `float_buffers` filter

#### `testing/test_tscale.py`

- Removed 11 FP8-specific test functions and their registrations
- Updated the update_E correctness test comment (it was FP8-specific)

### Plan 09-02 — EMA-based E Update Rule (TERN-E-03)

**Commit:** `97d0482` (+ SUMMARY `c1f4be7`)

Replaced the old SignSGD `update_E` with an EMA in log space:

```python
import torch

# Old (SignSGD): direction only, one ±1 step per group
grp_mean_sign = grouped.mean(dim=2).sign()
grad_E = -grp_mean_sign.float()
E = torch.clamp(E + grad_E, -128, 127)

# New (EMA): move E toward log2 of the per-group mean |gradient|
mu_g = grouped.abs().mean(dim=2)  # mean abs gradient per group
e_proposed = torch.round(torch.log2(mu_g + 1e-10)).clamp(-128, 127)
E = (1 - alpha) * E + alpha * e_proposed  # alpha (α) defaults to 0.1
```

Applied to both `TernaryScaleTensor.update_E()` and `ByteEmbedding.update_E()`.

**Why this matters:** The old SignSGD only knew direction (±1 per group). The new EMA knows both direction and magnitude: E converges toward the log of the gradient energy, smoothed by α. This prevents the "step-2 mass-flip" loss spike from REFACTOR3.

### Plan 09-03 — LossComponent Temperature Routing (TERN-E-04)

**Commit:** `d77180d` (+ SUMMARY `bd8a750`)

Wired per-component loss signals in as the temperature for E updates:

```python
# update_E accepts loss_signal:
alpha = getattr(self, '_ema_alpha', 0.1)
if loss_signal is not None:
    temp_scale = getattr(self, '_loss_temp_scale', 1.0)
    alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()
```

Changes:

- `TernaryScaleTensor.update_E(loss_signal=...)`: α modulated by sigmoid(total_loss)
- `ByteEmbedding.update_E(loss_signal=...)`: same pattern
- `TernaryRMSNorm.update_E(loss_signal=...)`: accepts the kwarg and ignores it (frozen weights)
- `_ternary_update_memory(loss_signal=...)`: passes it through to all modules
- `train.py`: passes `loss_comps.total` to `_ternary_update_memory`

**Later use (TERN-E-05):** The multi-scale lattice, in which each TScaleType level proposes its own ΔE_s, was deferred per plan scope. Reactivation trigger: the single-scale EMA saturates and cannot separate magnitude regimes.
## Current State

### E Buffer Architecture (what the code now does)

```
Storage: T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8)
Forward: W_eff = 2^(E_exp) * T   where E_exp = expand(E)
Update:  E ← (1-α)*E + α*round(log2(mean|grad|_group))
         T_accum += sign(grad)
         if |T_accum| > 3: flip T, reset T_accum
```

### Files Changed

| File | Changes |
|------|---------|
| `tscale.py` | ~40 lines changed: FP8 removed, EMA update added, LossComponent routing added |
| `trigram.py` | ~20 lines changed: FP8 removed, EMA update + LossComponent routing added |
| `ternary_audit.py` | ~1 line changed: FP8 exclusion removed |
| `train.py` | ~1 line changed: loss_signal passed to update |
| `benchmark_true_ternary.py` | New: 3-way benchmark comparing Adam/FP32, SignSGD (old), TrueTernary |
| `testing/test_tscale.py` | ~210 lines removed (FP8 tests), 1 comment updated |
| `.planning/ROADMAP.md` | Phase 9 goal and plans rewritten |
| `.planning/REQUIREMENTS.md` | HYB-01–06 replaced with TERN-E-01–05 |

### Tests

All 18 CPU tscale tests pass. All 110 morph tests pass.

### Benchmark (in progress)

`benchmark_true_ternary.py` compares 3 configs on the full `MORPHTernaryModel` (133M params):

- **Adam_FP32**: standard FP32 Adam baseline
- **SignSGD_Old**: SignSGD without `_ternary_update_memory` (static E, no EMA)
- **TrueTernary**: the new int8 E + EMA update + LossComponent routing (CPU update)

Preliminary findings (50 steps, batch=8, seq=66):

- Adam_FP32: ~600 ms/step, ~800 tok/s, 4.8 GB peak VRAM
- SignSGD_Old: ~600 ms/step, ~800 tok/s, 4.8 GB peak VRAM
- TrueTernary: ~600 ms/step, same VRAM (uses CPU fallback update)

## Known Issue: Triton Kernel Caching Bug

**`_triton_ternary_step_direct_kernel`** takes ~13.6 s to compile and is NOT cached across calls, even with identical shapes, so every `ternary_step` call takes ~13.6 s. This affects the GPU path of `_ternary_update_memory`.

Workaround in the benchmark: a CPU-based `cpu_update_memory()` function bypasses Triton for updates.
Symptoms:

- First call: 13.6 s (compilation)
- Second call with identical shapes: ALSO 13.6 s (recompilation, i.e. a cache miss)

The Triton cache directory holds 8788 cached kernels, but this specific kernel is not among them. Root cause not yet identified; it is likely a complex interaction of the `tl.constexpr` parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) with the kernel's complexity.

## Float Materialization Audit (post-rollback)

| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| Forward (CUDA) | `y[M,N]`, `S` (ephemeral exp2) | Output tensor only |
| Backward (CUDA) | `grad_x`, `grad_2d`, `x_2d` | Ephemeral (autograd) |
| update_E (CPU) | `grad_T`, `mu_g`, `e_proposed`, `E_float` | Ephemeral (computed, applied, freed) |
| E buffer | `int8` (always) | Persistent (buffer) |
| S (scale) | **Never stored**: `2^E` is computed on the fly | n/a |

No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit. ✓

## Remaining Work

1. **Resolve the Triton caching bug** in `_triton_ternary_step_direct_kernel`: port the kernel to non-constexpr parameters or simplify it
2. **Run the full benchmark** (all 3 configs to completion): TrueTernary currently hangs from step 1 onward due to CPU update overhead, which needs profiling
3. **TERN-E-05 (multi-scale lattice)**: evaluate once the single-scale EMA saturates
4. **LossComponent routing training validation**: verify that α modulation via sigmoid(loss) actually improves convergence versus a fixed α

## Follow-Up Fix: Benchmark Progress And VRAM Attribution

The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal `MORPHTernaryModel`, passed trainable floating-point parameters to the optimizers, and ran the expensive baseline comparison path ahead of `TrueTernary`. As a result, all configs reported roughly the same VRAM behavior, even though strict ternary training has no persistent float weights.
### Changes

#### `benchmark_true_ternary.py`

- Added CLI controls:
  - `--configs`
  - `--steps`
  - `--warmup`
  - `--batch`
  - `--ctx`
  - `--strict-true-ternary / --no-strict-true-ternary`
  - `--update-backend {gpu,gpu-signcache,dense-fallback,none}`
  - `--scale-update-interval`
  - `--accum-threshold`
  - `--print-every`
  - `--reuse-base / --no-reuse-base`
- `TrueTernary` now defaults to a strict text-only ternary model:
  - no image/audio encoders
  - no VQ
  - no graph
  - no recurrent memory modules
  - float params frozen
- Optimizers are now built only from trainable parameters.
- Per-step logs now show allocated, reserved, and peak CUDA memory.
- Short benchmark averages now divide by the number of available steps instead of always dividing by 20.
- The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM.

#### `tscale.py`

- CUDA `update_E()` now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update:

  ```text
  mu_g = mean(abs(sign(grad_w) * T)) per group
  e_proposed = round(log2(mu_g + eps))
  E = (1 - alpha) * E + alpha * e_proposed
  ```

- The Triton `E` update kernels accept the loss-temperature-adjusted `alpha` computed in Python.

#### New benchmark update backend: `gpu-signcache`

The direct GPU update path is memory-clean but too slow for larger `M = batch * sequence` shapes, because it recomputes the gradient reduction inside tiny packed-byte kernels.

`gpu-signcache` uses this flow per module:

```text
grad_y, x
  -> Triton tiled sign-only grad kernel
  -> temporary int8 grad_sign[N, K]
  -> Triton EMA E update
  -> Triton T_accum/repack
  -> free grad_sign
```

This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping the persistent model state true ternary.
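The first two stages of the per-module `gpu-signcache` flow can be sketched on the CPU in pure Python: a sign-only weight gradient `grad_sign[N, K] = sign(grad_yᵀ · x)` held as a temporary integer matrix, followed by the CUDA-path E rule `mu_g = mean(|sign(grad_w) * T|)`. This is an illustration under assumptions, not the Triton kernels: the function names are hypothetical, groups are assumed to be contiguous runs of `group_size` entries in the flattened `[N, K]` layout, E is re-rounded to an integer after the EMA step, and the `T_accum`/repack stage is omitted.

```python
import math


def sign(v):
    """Return -1, 0, or +1 (an int8-range sign)."""
    return (v > 0) - (v < 0)


def sign_grad(grad_y, x):
    """Sign-only weight gradient: grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k]).

    grad_y is [M][N], x is [M][K]; the result is a temporary [N][K] matrix of
    -1/0/+1 values (a stand-in for the int8 grad_sign cache).
    """
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    return [[sign(sum(grad_y[m][n] * x[m][k] for m in range(M))) for k in range(K)]
            for n in range(N)]


def update_E_from_signs(E, grad_sign, T, group_size, alpha=0.1, eps=1e-10):
    """EMA E update using mu_g = mean(|sign(grad_w) * T|) per contiguous group."""
    flat = [abs(s * t) for s_row, t_row in zip(grad_sign, T)
            for s, t in zip(s_row, t_row)]
    if len(flat) != len(E) * group_size:
        raise ValueError("N * K must equal len(E) * group_size")
    new_E = []
    for g, e in enumerate(E):
        mu_g = sum(flat[g * group_size:(g + 1) * group_size]) / group_size
        e_proposed = max(-128, min(127, round(math.log2(mu_g + eps))))
        e_float = (1 - alpha) * e + alpha * e_proposed
        new_E.append(max(-128, min(127, round(e_float))))  # assumed re-rounding to int
    return new_E


# Usage: M=2 tokens, N=2 outputs, K=2 inputs, one group per output row.
x = [[1, 2], [3, 4]]
grad_y = [[1, -1], [1, -1]]
grad_sign = sign_grad(grad_y, x)   # temporary per-layer sign cache
E = update_E_from_signs([0, 0], grad_sign, T=[[1, 0], [-1, 1]], group_size=2, alpha=1.0)
del grad_sign                      # mirrors "free grad_sign"
print(E)                           # prints [-1, 0]
```

Because each |sign(grad_w) * T| entry is 0 or 1, `mu_g` under this rule is the fraction of weights in the group that are nonzero and received a nonzero gradient sign. `e_proposed` is therefore never positive, and a group with no active entries is pulled toward round(log2(eps)).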
### Verification

Focused CUDA tests:

```text
PASS test_cuda_triton_correctness_update_E
PASS test_cuda_triton_tscale_path
```

Strict small benchmark:

```text
python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4

logical ternary weights: 14,011,904
trainable float params: 0
float buffers: 0
peak VRAM: ~30 MB
avg step after warmup: ~170 ms
```

Strict benchmark at the previous benchmark shape:

```text
python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4

logical ternary weights: 14,011,904
ternary training state: 17.15 MB
trainable float params: 0
float buffers: 0
peak VRAM: ~430 MB
steady step: ~1.85 s
```

The previous direct backend at the same shape took roughly `49 s/step` after warmup. The benchmark now progresses; the remaining speed work is kernel tuning and, eventually, replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal.
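The `ternary training state: 17.15 MB` figure can be related to the storage layout under Current State (`T_packed` at 5 trits per byte, int8 `T_accum`, int8 `E`) with a small budget calculation. The sketch below is hypothetical: it packs five trits into one byte as base-3 digits (3^5 = 243 ≤ 256), which is one possible encoding and not necessarily the digit order `tscale.py` uses, and it assumes one int8 exponent per group of `group_size` weights with no per-layer padding. The function names are illustrative.

```python
def pack_trits(trits):
    """Pack exactly five trits in {-1, 0, 1} into one byte (base-3, least significant first)."""
    if len(trits) != 5 or any(t not in (-1, 0, 1) for t in trits):
        raise ValueError("expected five trits in {-1, 0, 1}")
    value = 0
    for i, t in enumerate(trits):
        value += (t + 1) * 3 ** i        # map -1/0/+1 to base-3 digits 0/1/2
    return value                         # always in 0..242, so it fits in a uint8


def unpack_trits(byte):
    """Inverse of pack_trits."""
    if not 0 <= byte <= 242:
        raise ValueError("byte out of range for 5-trit packing")
    trits = []
    for _ in range(5):
        byte, digit = divmod(byte, 3)
        trits.append(digit - 1)
    return trits


def ternary_state_bytes(n_weights, group_size):
    """Bytes for T_packed (5 trits/byte) + T_accum (int8) + E (one int8 per group)."""
    t_packed = -(-n_weights // 5)        # integer ceil(n_weights / 5)
    t_accum = n_weights
    e = -(-n_weights // group_size)      # integer ceil(n_weights / group_size)
    return {"T_packed": t_packed, "T_accum": t_accum, "E": e,
            "total": t_packed + t_accum + e}


budget = ternary_state_bytes(14_011_904, group_size=64)  # group_size=64 is an arbitrary example
print(budget["T_packed"] + budget["T_accum"])            # prints 16814285
```

For the 14,011,904 logical weights reported above, `T_packed` and `T_accum` alone come to 2,802,381 + 14,011,904 = 16,814,285 bytes. If the benchmark's MB means 10^6 bytes, roughly 0.34 MB of the reported 17.15 MB would be left for E plus any padding or bookkeeping; the actual group sizes are not stated here, so that split is only an estimate.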