True Ternary Refactor 4: True Ternary Exponent Dynamics
Scope
Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge": upgrading E buffers from int8 log2 to float8_e4m3fn. An exploration session determined that this approach was architecturally wrong, because an FP8 E reintroduces an IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as True Ternary Exponent Dynamics, replacing the FP8 approach with the mathematically correct logarithmic scaling system.
This refactor implements the first three waves of the replanned Phase 9.
What Changed
Architecture Decisions (from exploration session)
Five core principles were crystallized and are now encoded in code:
| Principle | Meaning | Code Impact |
|---|---|---|
| S is never stored | S = 2^E is a function, not a value. No float8/int16/float stored. | _get_S() restored to torch.exp2(E.float()) |
| E is hybrid state | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | update_E() replaced SignSGD with EMA |
| LossComponent = temperature | Per-component loss signals control α (update energy) per group | update_E(loss_signal) modulates α via sigmoid |
| TScaleType = fixed lattice | Structure is stable; what's dynamic is energy routing across it | Group sizes unchanged |
| Representation singular, learning ensemble | Forward pass always single-scale; multi-scale exists only in update pathway | Forward uses single W = T * 2^E |
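The "S is never stored" and "single-scale forward" principles can be illustrated with a small pure-Python sketch. It assumes one int8 exponent per contiguous group of weights; the helper names `expand_E` and `effective_weights` and the group size of 3 are hypothetical, and the real code operates on tensors via torch.exp2 rather than Python lists.

```python
def expand_E(E, group_size):
    """Repeat each per-group exponent so it lines up with the weights (hypothetical helper)."""
    return [e for e in E for _ in range(group_size)]


def effective_weights(T, E, group_size):
    """Compute W = T * 2^E on the fly; the scale S = 2^E is never stored."""
    E_exp = expand_E(E, group_size)
    if len(E_exp) != len(T):
        raise ValueError("len(T) must equal len(E) * group_size")
    # 2.0 ** e is exact for every integer e in the int8 range
    return [t * 2.0 ** e for t, e in zip(T, E_exp)]


T = [1, 0, -1, 1, -1, 1]   # ternary values in {-1, 0, +1}
E = [-2, 3]                # one int8 exponent per group of 3 weights
W = effective_weights(T, E, group_size=3)
print(W)  # [0.25, 0.0, -0.25, 8.0, -8.0, 8.0]
```

Only T and E are persistent here; W and the scale exist just for the duration of the call.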
REPLAN (before execution)
Commit: 9103dc5, "replan Phase 9: True Ternary Exponent Dynamics"
- ROADMAP.md updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", 3 plans instead of 4
- REQUIREMENTS.md: HYB-01 through HYB-06 replaced with TERN-E-01 through TERN-E-05
- CONTEXT.md rewritten from FP8-centric to true ternary principles
- STATE.md updated with new phase structure
- Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed
See .planning/notes/true-ternary-architecture-principles.md for full architecture rationale.
Plan 09-01: Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02)
Commit: 80c6188 (+ SUMMARY 2bfc42a)
Undid all FP8 changes from old Phase 9 Waves 1-2:
tscale.py
- E buffer init (lines 906, 1235): `clamp(-448, 448).to(float8_e4m3fn)` → `log2().clamp(-128, 127).to(int8)`
- `_get_S`: `E_exp.float()` → `torch.exp2(E_exp.float())` (restores log2 → exp2 dequant)
- CPU update_E: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
- tscale_to: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- 5 Triton forward kernels (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): load E with `other=0` (int); `.to(tl.float32)` → `tl.exp2(e_val)` replaces the direct float cast
- 2 Triton update kernels (`_triton_update_e`, `_triton_update_e_direct`): `other=0.0` → `other=0`, `tl.minimum(448, ...)` → `tl.minimum(127, ...)`, store as `tl.int8` instead of `tl.float8e4nv`
- TernaryRMSNorm E init: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
trigram.py
- ByteEmbedding E init: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- ByteEmbedding CPU forward: `E_exp.float()` → `torch.exp2(E_exp.float())`
- ByteEmbedding update_E: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
ternary_audit.py
- Removed the `buf.dtype != torch.float8_e4m3fn` exclusion from the `float_buffers` filter
testing/test_tscale.py
- Removed 11 FP8-specific test functions and their registrations
- Updated update_E correctness test comment (was FP8-specific)
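The int8 log2 encoding restored by this plan can be sketched in pure Python. This is an illustration under assumptions, not the code in tscale.py: `encode_E` and `decode_S` are hypothetical names, scalars stand in for tensors, and `int()` is used because it truncates toward zero, as a float-to-integer cast does.

```python
import math


def encode_E(scale):
    """Map a positive float scale to an int8 exponent: log2, clamp, cast."""
    e = math.log2(scale)
    e = max(-128.0, min(127.0, e))
    return int(e)  # truncates toward zero, like a float -> int8 cast


def decode_S(E):
    """Recover the scale on demand: S = 2^E (never stored)."""
    return 2.0 ** E


for scale in (0.125, 1.0, 3.0, 1e-50):
    E = encode_E(scale)
    print(scale, E, decode_S(E))
```

Powers of two round-trip exactly, other scales snap to a neighbouring power of two, and out-of-range values saturate at the int8 limits. By contrast, the largest finite float8_e4m3fn value is 448, which is why the FP8 version clamped to (-448, 448).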
Plan 09-02: EMA-based E Update Rule (TERN-E-03)
Commit: 97d0482 (+ SUMMARY c1f4be7)
Replaced the old SignSGD update_E with an EMA in log-space:
# Old (SignSGD):
grp_mean_sign = grouped.mean(dim=2).sign()
grad_E = -grp_mean_sign.float()
E = clamp(E + grad_E, -128, 127)
# New (EMA):
mu_g = grouped.abs().mean(dim=2) # mean abs gradient per group
e_proposed = round(log2(mu_g + 1e-10)).clamp(-128, 127)
E = (1-α) * E + α * e_proposed # α defaults to 0.1
Applied to both TernaryScaleTensor.update_E() and ByteEmbedding.update_E().
Why this matters: The old SignSGD only knew direction (±1 per group). The new EMA knows both direction and magnitude: E converges toward the log of the gradient energy, smoothed by α. This prevents the "step-2 mass-flip" loss spike from REFACTOR3.
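To make the difference between the two rules concrete, here is a minimal scalar sketch for a single group. The function names are hypothetical and plain floats stand in for tensors. How the fractional EMA result is written back into the int8 buffer is not shown in these notes, so the sketch leaves E as a float.

```python
import math


def signsgd_step(E, group_grads):
    """Old rule: move E by -sign(mean gradient) of the group, clamped to int8."""
    mean = sum(group_grads) / len(group_grads)
    sign = (mean > 0) - (mean < 0)
    return max(-128, min(127, E - sign))


def ema_step(E, group_grads, alpha=0.1, eps=1e-10):
    """New rule: pull E toward round(log2(mean |grad|)) with an EMA."""
    mu_g = sum(abs(g) for g in group_grads) / len(group_grads)
    e_proposed = max(-128, min(127, round(math.log2(mu_g + eps))))
    return (1 - alpha) * E + alpha * e_proposed


grads = [0.02, -0.03, 0.01, -0.04]   # mean |grad| = 0.025, log2 is about -5.32
E = 0.0
for _ in range(50):
    E = ema_step(E, grads)
print(round(E, 3))  # approaches the proposed exponent, -5
```

With a constant gradient signal the EMA settles geometrically on the proposed exponent, whereas the sign rule keeps stepping by 1 in the same direction for as long as the mean gradient keeps its sign.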
Plan 09-03: LossComponent Temperature Routing (TERN-E-04)
Commit: d77180d (+ SUMMARY bd8a750)
Wired per-component loss signals as temperature for E updates:
# update_E accepts loss_signal:
alpha = getattr(self, '_ema_alpha', 0.1)
if loss_signal is not None:
    temp_scale = getattr(self, '_loss_temp_scale', 1.0)
    alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()
Changes:
- `TernaryScaleTensor.update_E(loss_signal=...)`: α modulated by sigmoid(total_loss)
- `ByteEmbedding.update_E(loss_signal=...)`: same pattern
- `TernaryRMSNorm.update_E(loss_signal=...)`: accepts the kwarg silently (frozen weights)
- `_ternary_update_memory(loss_signal=...)`: passes through to all modules
- `train.py`: passes `loss_comps.total` to `_ternary_update_memory`
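The α modulation in this plan can be checked numerically with a scalar sketch. `modulated_alpha` is a hypothetical stand-in for the logic inside update_E, using math.exp in place of torch.sigmoid.

```python
import math


def modulated_alpha(base_alpha=0.1, loss_signal=None, temp_scale=1.0):
    """Scale the EMA rate by sigmoid(loss * temp_scale); unchanged when no signal is given."""
    if loss_signal is None:
        return base_alpha
    gate = 1.0 / (1.0 + math.exp(-loss_signal * temp_scale))
    return base_alpha * gate


for loss in (0.0, 0.5, 2.0, 8.0):
    print(loss, round(modulated_alpha(loss_signal=loss), 4))
```

Because the sigmoid of a non-negative loss lies in [0.5, 1), this form can only scale α to between half and (almost) all of its base value. A high loss leaves α near the base rate and a loss near zero halves it.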
Later use (TERN-E-05): The multi-scale lattice, in which each TScaleType level proposes its own ΔE_s, was deferred per plan scope. Reactivation trigger: the single-scale EMA saturates and cannot separate magnitude regimes.
Current State
E Buffer Architecture (what the code now does)
Storage: T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8)
Forward: W_eff = 2^(E_exp) * T where E_exp = expand(E)
Update: E ← (1-α)*E + α*round(log2(|grad|_group))
        T_accum += sign(grad)
        if |T_accum| > 3: flip T, reset T_accum
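The "5-trit/byte" storage works because 3^5 = 243 fits in one uint8. A minimal sketch of such a packing follows, assuming base-3 digits with the first trit in the least significant position. The actual bit layout of T_packed is not given in these notes, and both function names are hypothetical.

```python
def pack_trits(trits):
    """Pack 5 ternary values from {-1, 0, +1} into one byte (0..242) as base-3 digits."""
    if len(trits) != 5 or any(t not in (-1, 0, 1) for t in trits):
        raise ValueError("expected exactly 5 values in {-1, 0, +1}")
    byte = 0
    for t in reversed(trits):
        byte = byte * 3 + (t + 1)   # shift {-1, 0, 1} to digits {0, 1, 2}
    return byte


def unpack_trits(byte):
    """Inverse of pack_trits: recover the 5 ternary values from one byte."""
    if not 0 <= byte < 3 ** 5:
        raise ValueError("byte out of range for 5 trits")
    trits = []
    for _ in range(5):
        byte, digit = divmod(byte, 3)
        trits.append(digit - 1)
    return trits


packed = pack_trits([1, 0, -1, 1, -1])
print(packed, unpack_trits(packed))  # 59 [1, 0, -1, 1, -1]
```

This costs 1.6 bits per weight, compared with 2 bits for a simpler four-trits-per-byte layout.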
Files Changed
| File | Lines Changed |
|---|---|
| tscale.py | ~40 changed, FP8 removed, EMA update added, LossComponent routing added |
| trigram.py | ~20 changed, FP8 removed, EMA update + LossComponent routing added |
| ternary_audit.py | ~1 changed, FP8 exclusion removed |
| train.py | ~1 changed, loss_signal passed to update |
| benchmark_true_ternary.py | New: 3-way benchmark comparing Adam/FP32, SignSGD (old), TrueTernary |
| testing/test_tscale.py | ~210 lines removed (FP8 tests), 1 comment updated |
| .planning/ROADMAP.md | Phase 9 goal and plans rewritten |
| .planning/REQUIREMENTS.md | HYB-01 through HYB-06 replaced with TERN-E-01 through TERN-E-05 |
Tests
All 18 CPU tscale tests pass. 110 morph tests pass.
Benchmark (in progress)
benchmark_true_ternary.py compares 3 configs on the full MORPHTernaryModel (133M params):
- Adam_FP32: standard FP32 Adam baseline
- SignSGD_Old: SignSGD without `_ternary_update_memory` (static E, no EMA)
- TrueTernary: our new int8 E + EMA update + LossComponent routing (CPU update)
Preliminary findings (50 steps, batch=8, seq=66):
- Adam_FP32: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- SignSGD_Old: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- TrueTernary: ~600ms/step, same VRAM (uses CPU fallback update)
Known Issue: Triton Kernel Caching Bug
_triton_ternary_step_direct_kernel takes ~13.6s to compile and does NOT cache across calls, even with identical shapes, causing every ternary_step call to take 13.6s. This affects the GPU path for _ternary_update_memory.
Workaround in benchmark: CPU-based cpu_update_memory() function bypasses Triton for updates.
Symptoms:
- First call: 13.6s (compilation)
- Second call with identical shapes: ALSO 13.6s (recompilation, cache miss)
The Triton cache directory holds 8788 cached kernels, but this specific kernel is not among them. The root cause is not yet identified; a likely candidate is a complex interaction between the tl.constexpr parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) and the kernel's complexity.
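The symptom described here (a second call as slow as the first) can be checked with a small timing harness. This is a generic sketch with hypothetical helper names, not part of the benchmark; the 0.5 ratio is an arbitrary threshold.

```python
import time


def time_calls(fn, n_calls=3):
    """Return wall-clock seconds for each of n_calls consecutive calls to fn()."""
    durations = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn()
        durations.append(time.perf_counter() - start)
    return durations


def looks_uncached(durations, ratio=0.5):
    """Heuristic: with working caching, later calls are much faster than the first."""
    first, rest = durations[0], durations[1:]
    return any(d > ratio * first for d in rest)


print(looks_uncached([13.6, 13.6]))   # True: second call as slow as the first
print(looks_uncached([13.6, 0.02]))   # False: second call hit a cache
```

When timing a CUDA kernel, the callable passed to `time_calls` would need to synchronize the device before returning (for example with torch.cuda.synchronize()), otherwise only the launch is measured.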
Float Materialization Audit (post-rollback)
| Path | Float tensors created | Persistent? |
|---|---|---|
| Forward (CUDA) | y[M,N], S (ephemeral exp2) | Output tensor only |
| Backward (CUDA) | grad_x, grad_2d, x_2d | Ephemeral (autograd) |
| update_E (CPU) | grad_T, mu_g, e_proposed, E_float | Ephemeral (computed, applied, freed) |
| E buffer | int8 (always) | Persistent (buffer) |
| S (scale) | Never stored: 2^E computed on-the-fly | N/A |
No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit.
Remaining Work
- Resolve the Triton caching bug in `_triton_ternary_step_direct_kernel`: port the kernel to non-constexpr parameters, or simplify it
- Run the full benchmark (all 3 configs to completion): TrueTernary currently hangs from step 1 onward due to CPU update overhead; needs profiling
- TERN-E-05 (multi-scale lattice): evaluate when the single-scale EMA saturates
- LossComponent routing training validation: verify that α modulation via sigmoid(loss) actually improves convergence vs. fixed α
Follow-Up Fix: Benchmark Progress And VRAM Attribution
The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal MORPHTernaryModel, passed trainable floating-point parameters to the optimizers, and kept the expensive baseline comparison path ahead of TrueTernary. That made all configs report roughly the same VRAM behavior even though strict ternary training has no persistent float weights.
Changes
benchmark_true_ternary.py
- Added CLI controls:
  - `--configs`
  - `--steps`
  - `--warmup`
  - `--batch`
  - `--ctx`
  - `--strict-true-ternary / --no-strict-true-ternary`
  - `--update-backend {gpu,gpu-signcache,dense-fallback,none}`
  - `--scale-update-interval`
  - `--accum-threshold`
  - `--print-every`
  - `--reuse-base / --no-reuse-base`
- `TrueTernary` now defaults to a strict text-only ternary model:
  - no image/audio encoders
  - no VQ
  - no graph
  - no recurrent memory modules
  - float params frozen
- Optimizers are now built only from trainable parameters.
- Per-step logs now show allocated, reserved, and peak CUDA memory.
- Short benchmark averages now divide by the available number of steps instead of always dividing by 20.
- The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM.
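The CLI controls listed above could be declared with argparse roughly as follows. This is a sketch, not the parser in benchmark_true_ternary.py: the flag names and the `--update-backend` choices come from the list, while every type, default, and the `nargs="+"` on `--configs` are assumptions.

```python
import argparse


def build_parser():
    """Sketch of a parser exposing the listed flags; defaults and types are assumed."""
    p = argparse.ArgumentParser(description="true-ternary benchmark (sketch)")
    p.add_argument("--configs", nargs="+", default=["TrueTernary"])
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--warmup", type=int, default=1)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--ctx", type=int, default=66)
    # BooleanOptionalAction (Python 3.9+) generates the paired --flag / --no-flag form
    p.add_argument("--strict-true-ternary",
                   action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--update-backend",
                   choices=["gpu", "gpu-signcache", "dense-fallback", "none"],
                   default="gpu-signcache")
    p.add_argument("--scale-update-interval", type=int, default=1)
    p.add_argument("--accum-threshold", type=int, default=3)
    p.add_argument("--print-every", type=int, default=1)
    p.add_argument("--reuse-base",
                   action=argparse.BooleanOptionalAction, default=False)
    return p


args = build_parser().parse_args(
    "--configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 "
    "--strict-true-ternary --update-backend gpu-signcache "
    "--scale-update-interval 4".split()
)
print(args.strict_true_ternary, args.update_backend, args.steps)
```

`argparse.BooleanOptionalAction` is the standard-library way to get a `--x / --no-x` pair from a single declaration.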
tscale.py
- CUDA `update_E()` now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update:
mu_g = mean(abs(sign(grad_w) * T)) per group
e_proposed = round(log2(mu_g + eps))
E = (1 - alpha) * E + alpha * e_proposed
- The Triton `E` update kernels accept the loss-temperature-adjusted `alpha` computed in Python.
New benchmark update backend: gpu-signcache
The direct GPU update path is memory-clean but too slow for larger M = batch * sequence shapes because it recomputes the gradient reduction inside tiny packed-byte kernels.
gpu-signcache uses this flow per module:
grad_y, x
-> Triton tiled sign-only grad kernel
-> temporary int8 grad_sign[N, K]
-> Triton EMA E update
-> Triton T_accum/repack
-> free grad_sign
This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping persistent model state true ternary.
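The first stage of this flow, the sign-only weight gradient, can be written out as a pure-Python reference. It assumes a linear layer y = x @ W^T with W of shape N x K, so the weight gradient is grad_y^T @ x. The function names are hypothetical, and the real path is a tiled Triton kernel rather than nested loops.

```python
def sign(v):
    """Return -1, 0, or +1 as an int."""
    return (v > 0) - (v < 0)


def grad_sign_cache(grad_y, x):
    """Return sign(grad_y^T @ x) as an N x K table of ints in {-1, 0, +1}.

    grad_y is M x N and x is M x K (lists of rows). Only the sign of each
    weight gradient is kept, so the cache fits in int8 rather than float.
    """
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    cache = [[0] * K for _ in range(N)]
    for n in range(N):
        for k in range(K):
            acc = sum(grad_y[m][n] * x[m][k] for m in range(M))
            cache[n][k] = sign(acc)
    return cache


grad_y = [[0.5, -1.0], [0.25, 2.0]]          # M=2, N=2
x = [[1.0, -2.0, 0.0], [4.0, 1.0, 0.0]]      # M=2, K=3
print(grad_sign_cache(grad_y, x))  # [[1, -1, 0], [1, 1, 0]]
```

The reduction over M happens once per module here, which is the work the direct path repeats inside its packed-byte kernels.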
Verification
Focused CUDA tests:
PASS test_cuda_triton_correctness_update_E
PASS test_cuda_triton_tscale_path
Strict small benchmark:
python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4
logical ternary weights: 14,011,904
trainable float params: 0
float buffers: 0
peak VRAM: ~30 MB
avg step after warmup: ~170 ms
Strict benchmark at previous benchmark shape:
python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4
logical ternary weights: 14,011,904
ternary training state: 17.15 MB
trainable float params: 0
float buffers: 0
peak VRAM: ~430 MB
steady step: ~1.85 s
Previous direct backend at the same shape was roughly 49 s/step after warmup. The benchmark now progresses; remaining speed work is kernel tuning and eventually replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal.