
True Ternary Refactor 4: True Ternary Exponent Dynamics

Scope

Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge", which would have upgraded E buffers from int8 log2 to float8_e4m3fn. An exploration session determined that this approach was architecturally wrong: an FP8 E reintroduces the IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as True Ternary Exponent Dynamics, replacing the FP8 approach with the mathematically correct logarithmic scaling system.

This refactor implements the first three waves of the replanned Phase 9.

What Changed

Architecture Decisions (from exploration session)

Five core principles were crystallized and are now encoded in code:

| Principle | Meaning | Code Impact |
|---|---|---|
| S is never stored | S = 2^E is a function, not a value. No float8/int16/float stored. | _get_S() restored to torch.exp2(E.float()) |
| E is hybrid state | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | update_E() replaced SignSGD with EMA |
| LossComponent = temperature | Per-component loss signals control α (update energy) per group | update_E(loss_signal) modulates α via sigmoid |
| TScaleType = fixed lattice | Structure is stable; what's dynamic is energy routing across it | Group sizes unchanged |
| Representation singular, learning ensemble | Forward pass always single-scale; multi-scale exists only in update pathway | Forward uses single W = T * 2^E |

REPLAN (before execution)

Commit: 9103dc5 (replan Phase 9: True Ternary Exponent Dynamics)

  • ROADMAP.md updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", 3 plans instead of 4
  • REQUIREMENTS.md: HYB-01–06 replaced with TERN-E-01–05
  • CONTEXT.md rewritten from FP8-centric to true ternary principles
  • STATE.md updated with new phase structure
  • Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed

See .planning/notes/true-ternary-architecture-principles.md for full architecture rationale.

Plan 09-01: Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02)

Commit: 80c6188 (+ SUMMARY 2bfc42a)

Undid all FP8 changes from old Phase 9 Waves 1-2:

tscale.py

  • E buffer init (lines 906, 1235): clamp(-448, 448).to(float8_e4m3fn) → log2().clamp(-128, 127).to(int8)
  • _get_S: E_exp.float() → torch.exp2(E_exp.float()), which restores the log2→exp2 dequant
  • CPU update_E: clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn) → clamp(E.float() + grad_E, -128, 127).to(int8)
  • tscale_to: clamp(-448, 448).to(float8) → log2().clamp(-128, 127).to(int8)
  • 5 Triton forward kernels (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): load E with other=0 (int), .to(tl.float32) → tl.exp2(e_val) replaces direct float cast
  • 2 Triton update kernels (_triton_update_e, _triton_update_e_direct): other=0.0 → other=0, tl.minimum(448, ...) → tl.minimum(127, ...), store as tl.int8 instead of tl.float8e4nv
  • TernaryRMSNorm E init: clamp(-448, 448).to(float8) → log2().clamp(-128, 127).to(int8)
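
The int8 log2 encoding that these rollbacks restore can be illustrated with a small round trip. This is a minimal sketch, not the code in tscale.py: the helper names quantize_scale_to_e and get_s are hypothetical, and the explicit round() before the cast is an assumption (the bullets above show only log2().clamp(-128, 127).to(int8)).

```python
import torch


def quantize_scale_to_e(scale: torch.Tensor) -> torch.Tensor:
    """Encode positive scales as int8 base-2 exponents (hypothetical helper)."""
    # Assumption: round to the nearest exponent before casting to int8.
    return torch.log2(scale).round().clamp(-128, 127).to(torch.int8)


def get_s(e: torch.Tensor) -> torch.Tensor:
    """Recompute S = 2^E on the fly; only the int8 exponent is stored."""
    return torch.exp2(e.float())


scale = torch.tensor([0.25, 1.0, 8.0])
e = quantize_scale_to_e(scale)   # int8 exponents
s = get_s(e)                     # float scales, recomputed when needed
```

With this encoding the representable scales are exactly the powers of two from 2^-128 to 2^127, and no float scale needs to persist between steps.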

trigram.py

  • ByteEmbedding E init: clamp(-448, 448).to(float8) → log2().clamp(-128, 127).to(int8)
  • ByteEmbedding CPU forward: E_exp.float() → torch.exp2(E_exp.float())
  • ByteEmbedding update_E: clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8) → clamp(E.float() + grad_E, -128, 127).to(int8)

ternary_audit.py

  • Removed buf.dtype != torch.float8_e4m3fn exclusion from float_buffers filter

testing/test_tscale.py

  • Removed 11 FP8-specific test functions and their registrations
  • Updated update_E correctness test comment (was FP8-specific)

Plan 09-02: EMA-based E Update Rule (TERN-E-03)

Commit: 97d0482 (+ SUMMARY c1f4be7)

Replaced the old SignSGD update_E with an EMA in log-space:

# Old (SignSGD):
grp_mean_sign = grouped.mean(dim=2).sign()
grad_E = -grp_mean_sign.float()
E = clamp(E + grad_E, -128, 127)

# New (EMA):
mu_g = grouped.abs().mean(dim=2)  # mean abs gradient per group
e_proposed = round(log2(mu_g + 1e-10)).clamp(-128, 127)
E = (1-α) * E + α * e_proposed       # α defaults to 0.1

Applied to both TernaryScaleTensor.update_E() and ByteEmbedding.update_E().

Why this matters: The old SignSGD only knew direction (±1 per group). The new EMA knows both direction and magnitude: E converges toward the log of the gradient energy, smoothed by α. This prevents the "step-2 mass-flip" loss spike from REFACTOR3.
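
The EMA rule above can be written as a runnable reference for a single weight matrix. This is a sketch under assumptions rather than the actual update_E(): the function name ema_update_e, the [N, K] gradient layout, the contiguous grouping along K, and the round-to-nearest write-back to int8 are all hypothetical.

```python
import torch


def ema_update_e(e: torch.Tensor, grad: torch.Tensor,
                 group_size: int, alpha: float = 0.1) -> torch.Tensor:
    """Return an updated int8 exponent buffer (hypothetical reference).

    e:    int8 tensor of shape [N, K // group_size]
    grad: float tensor of shape [N, K]
    """
    n, k = grad.shape
    grouped = grad.reshape(n, k // group_size, group_size)
    mu_g = grouped.abs().mean(dim=2)                      # mean |grad| per group
    e_proposed = torch.log2(mu_g + 1e-10).round().clamp(-128, 127)
    e_float = (1.0 - alpha) * e.float() + alpha * e_proposed
    # Assumption: round to nearest before storing back as int8.
    return e_float.round().clamp(-128, 127).to(torch.int8)


e = torch.zeros(1, 2, dtype=torch.int8)
grad = torch.cat([torch.full((1, 4), 2.0 ** -10), torch.ones(1, 4)], dim=1)
e_new = ema_update_e(e, grad, group_size=4)   # group 0 moves to -1, group 1 stays 0
```

One consequence of storing E as int8 in this sketch: with α = 0.1 and nearest rounding, a proposal has to differ from the current E by about 5 or more before the stored value changes at all. Whether the real code handles sub-integer progress differently is not stated here.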

Plan 09-03: LossComponent Temperature Routing (TERN-E-04)

Commit: d77180d (+ SUMMARY bd8a750)

Wired per-component loss signals as temperature for E updates:

# update_E accepts loss_signal:
alpha = getattr(self, '_ema_alpha', 0.1)
if loss_signal is not None:
    temp_scale = getattr(self, '_loss_temp_scale', 1.0)
    alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()

Changes:

  • TernaryScaleTensor.update_E(loss_signal=...): α modulated by sigmoid(total_loss)
  • ByteEmbedding.update_E(loss_signal=...): same pattern
  • TernaryRMSNorm.update_E(loss_signal=...): accepts kwarg silently (frozen weights)
  • _ternary_update_memory(loss_signal=...): passes through to all modules
  • train.py: passes loss_comps.total to _ternary_update_memory
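
To make the effect of the sigmoid modulation concrete, the same arithmetic can be reproduced on plain scalars. This is a dependency-free sketch: effective_alpha is a hypothetical name, and the defaults mirror the _ema_alpha = 0.1 and _loss_temp_scale = 1.0 fallbacks shown above.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def effective_alpha(base_alpha: float = 0.1, loss_signal=None,
                    temp_scale: float = 1.0) -> float:
    """Scale the EMA rate by sigmoid(loss * temp_scale) (hypothetical helper)."""
    if loss_signal is None:
        return base_alpha              # no routing: plain EMA rate
    return base_alpha * sigmoid(loss_signal * temp_scale)


alpha_at_zero = effective_alpha(0.1, 0.0)    # sigmoid(0) = 0.5
alpha_at_five = effective_alpha(0.1, 5.0)    # approaches the base rate
```

Since a loss is normally non-negative, the multiplier stays in [0.5, 1): the effective α ranges from half the base rate at zero loss up toward the full base rate for large losses, and never exceeds it.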

Later use (TERN-E-05): The multi-scale lattice where each TScaleType level proposes ΔE_s was deferred per plan scope. Reactivation trigger: if single-scale EMA saturates and cannot separate magnitude regimes.

Current State

E Buffer Architecture (what the code now does)

Storage:  T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8)
Forward:  W_eff = 2^(E_exp) * T   where E_exp = expand(E)
Update:   E ← (1-α)*E + α*round(log2(|grad|_group))
          T_accum += sign(grad)
          if |T_accum| > 3: flip T, reset T_accum
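
The "5-trit/byte" storage works because 3^5 = 243 fits in one uint8. The sketch below shows one way to pack and unpack five ternary values with base-3 digits. The digit order and the mapping of {-1, 0, +1} to {0, 1, 2} are assumptions for illustration; the actual T_packed layout is not described in this document.

```python
import itertools


def pack_trits(trits):
    """Pack five values from {-1, 0, 1} into one byte (0..242)."""
    if len(trits) != 5:
        raise ValueError("expected exactly 5 trits")
    value = 0
    for t in reversed(trits):
        value = value * 3 + (t + 1)    # first trit ends up least significant
    return value


def unpack_trits(byte):
    """Inverse of pack_trits: recover the five trits from one byte."""
    out = []
    for _ in range(5):
        byte, digit = divmod(byte, 3)
        out.append(digit - 1)
    return out


packed = pack_trits([1, 0, -1, 1, -1])
round_trip_ok = all(
    unpack_trits(pack_trits(list(t))) == list(t)
    for t in itertools.product((-1, 0, 1), repeat=5)
)
```

At five trits per byte the packed weights cost 1.6 bits per ternary value, close to the log2(3) ≈ 1.585 bit lower bound.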

Files Changed

| File | Lines Changed |
|---|---|
| tscale.py | ~40 changed, FP8 removed, EMA update added, LossComponent routing added |
| trigram.py | ~20 changed, FP8 removed, EMA update + LossComponent routing added |
| ternary_audit.py | ~1 changed, FP8 exclusion removed |
| train.py | ~1 changed, loss_signal passed to update |
| benchmark_true_ternary.py | New: benchmark comparing Adam/FP32, SignSGD (old), and TrueTernary |
| testing/test_tscale.py | ~210 lines removed (FP8 tests), 1 comment updated |
| .planning/ROADMAP.md | Phase 9 goal and plans rewritten |
| .planning/REQUIREMENTS.md | HYB-01–06 replaced with TERN-E-01–05 |

Tests

All 18 CPU tscale tests pass. 110 morph tests pass.

Benchmark (in progress)

benchmark_true_ternary.py compares 3 configs on the full MORPHTernaryModel (133M params):

  • Adam_FP32: standard FP32 Adam baseline
  • SignSGD_Old: SignSGD without _ternary_update_memory (static E, no EMA)
  • TrueTernary: our new int8 E + EMA update + LossComponent routing (CPU update)

Preliminary findings (50 steps, batch=8, seq=66):

  • Adam_FP32: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
  • SignSGD_Old: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
  • TrueTernary: ~600ms/step, same VRAM (uses CPU fallback update)
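
The throughput figures relate to step time through the batch shape. A small sketch of that arithmetic follows, assuming tokens per step is simply batch × sequence length; tokens_per_second is a hypothetical helper, not part of the benchmark script.

```python
def tokens_per_second(batch: int, seq_len: int, step_ms: float) -> float:
    """Tokens processed per second for one training step of the given shape."""
    return batch * seq_len * 1000.0 / step_ms


# Shape from the preliminary run above: batch=8, seq=66, about 600 ms per step.
tps = tokens_per_second(batch=8, seq_len=66, step_ms=600.0)
```

At exactly 600 ms per step this gives 880 tok/s, which is in line with the approximate figures listed above.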

Known Issue: Triton Kernel Caching Bug

_triton_ternary_step_direct_kernel takes ~13.6s to compile and does NOT cache across calls, even with identical shapes, causing every ternary_step call to take 13.6s. This affects the GPU path for _ternary_update_memory.

Workaround in benchmark: CPU-based cpu_update_memory() function bypasses Triton for updates.

Symptoms:

  • First call: 13.6s (compilation)
  • Second call with identical shapes: ALSO 13.6s (recompilation, cache miss)

The Triton cache directory holds 8788 cached kernels, but this specific kernel is not among them. The root cause has not yet been identified; a likely candidate is a complex interaction between the tl.constexpr parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) and the kernel's complexity.

Float Materialization Audit (post-rollback)

| Path | Float tensors created | Persistent? |
|---|---|---|
| Forward (CUDA) | y[M,N], S (ephemeral exp2) | Output tensor only |
| Backward (CUDA) | grad_x, grad_2d, x_2d | Ephemeral (autograd) |
| update_E (CPU) | grad_T, mu_g, e_proposed, E_float | Ephemeral (computed, applied, freed) |
| E buffer | int8 (always) | Persistent (buffer) |
| S (scale) | Never stored; 2^E computed on-the-fly | n/a |

No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit. ✓

Remaining Work

  1. Resolve the Triton caching bug in _triton_ternary_step_direct_kernel: port the kernel to non-constexpr parameters or simplify it
  2. Run the full benchmark (all 3 configs to completion): TrueTernary currently hangs from step 1 onward due to CPU update overhead and needs profiling
  3. TERN-E-05 (multi-scale lattice): evaluate when single-scale EMA saturates
  4. LossComponent routing training validation: verify that α modulation via sigmoid(loss) actually improves convergence vs fixed α

Follow-Up Fix: Benchmark Progress And VRAM Attribution

The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal MORPHTernaryModel, passed trainable floating-point parameters to the optimizers, and kept the expensive baseline comparison path ahead of TrueTernary. That made all configs report roughly the same VRAM behavior even though strict ternary training has no persistent float weights.

Changes

benchmark_true_ternary.py

  • Added CLI controls:
    • --configs
    • --steps
    • --warmup
    • --batch
    • --ctx
    • --strict-true-ternary / --no-strict-true-ternary
    • --update-backend {gpu,gpu-signcache,dense-fallback,none}
    • --scale-update-interval
    • --accum-threshold
    • --print-every
    • --reuse-base / --no-reuse-base
  • TrueTernary now defaults to a strict text-only ternary model:
    • no image/audio encoders
    • no VQ
    • no graph
    • no recurrent memory modules
    • float params frozen
  • Optimizers are now built only from trainable parameters.
  • Per-step logs now show allocated, reserved, and peak CUDA memory.
  • Short benchmark averages now divide by the available number of steps instead of always dividing by 20.
  • The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM.
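
The CLI surface listed above could be declared with argparse roughly as follows. This is only a sketch: the flag names come from the list, but every type, default, and the space-separated form of --configs are assumptions, since the script's actual parser is not shown here.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the flags listed above."""
    p = argparse.ArgumentParser(prog="benchmark_true_ternary.py")
    # All defaults below are placeholders, not the script's real defaults.
    p.add_argument("--configs", nargs="+", default=["TrueTernary"])
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--warmup", type=int, default=1)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--ctx", type=int, default=66)
    # BooleanOptionalAction (Python 3.9+) generates the --no-... variants.
    p.add_argument("--strict-true-ternary",
                   action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--update-backend", default="gpu-signcache",
                   choices=["gpu", "gpu-signcache", "dense-fallback", "none"])
    p.add_argument("--scale-update-interval", type=int, default=1)
    p.add_argument("--accum-threshold", type=int, default=3)
    p.add_argument("--print-every", type=int, default=1)
    p.add_argument("--reuse-base",
                   action=argparse.BooleanOptionalAction, default=False)
    return p


args = build_parser().parse_args(
    "--configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 "
    "--strict-true-ternary --update-backend gpu-signcache "
    "--scale-update-interval 4".split()
)
```

Restricting --update-backend with choices makes a mistyped backend name fail at parse time instead of silently falling through to a default path.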

tscale.py

  • CUDA update_E() now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update:
mu_g = mean(abs(sign(grad_w) * T)) per group
e_proposed = round(log2(mu_g + eps))
E = (1 - alpha) * E + alpha * e_proposed
  • The Triton E update kernels accept the loss-temperature-adjusted alpha computed in Python.

New benchmark update backend: gpu-signcache

The direct GPU update path is memory-clean but too slow for larger M = batch * sequence shapes because it recomputes the gradient reduction inside tiny packed-byte kernels.

gpu-signcache uses this flow per module:

grad_y, x
  -> Triton tiled sign-only grad kernel
  -> temporary int8 grad_sign[N, K]
  -> Triton EMA E update
  -> Triton T_accum/repack
  -> free grad_sign

This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping persistent model state true ternary.
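
The first stage of that flow, producing an int8 grad_sign[N, K] from grad_y and x, can be expressed as a dense PyTorch reference. This is not the Triton kernel: the sketch materializes the full float weight gradient before taking its sign, which a tiled sign-only kernel would not need to do. The function name sign_cache and the [M, N] / [M, K] layouts are assumptions.

```python
import torch


def sign_cache(grad_y: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    """Dense reference for a temporary int8 sign cache (hypothetical).

    grad_y: [M, N] gradient of the loss w.r.t. the layer output
    x:      [M, K] layer input
    returns int8 tensor [N, K] with entries in {-1, 0, 1}
    """
    grad_w = grad_y.t() @ x              # float [N, K], reference only
    return torch.sign(grad_w).to(torch.int8)


grad_y = torch.tensor([[1.0, -1.0], [2.0, 0.0]])           # M=2, N=2
x = torch.tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 1.0]])      # M=2, K=3
cache = sign_cache(grad_y, x)
```

The cache holds one byte per weight, a quarter of what a float32 weight gradient of the same shape would occupy, and it can be freed as soon as the E update and T_accum/repack steps for that module finish.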

Verification

Focused CUDA tests:

PASS test_cuda_triton_correctness_update_E
PASS test_cuda_triton_tscale_path

Strict small benchmark:

python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4

logical ternary weights: 14,011,904
trainable float params: 0
float buffers: 0
peak VRAM: ~30 MB
avg step after warmup: ~170 ms

Strict benchmark at previous benchmark shape:

python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4

logical ternary weights: 14,011,904
ternary training state: 17.15 MB
trainable float params: 0
float buffers: 0
peak VRAM: ~430 MB
steady step: ~1.85 s

Previous direct backend at the same shape was roughly 49 s/step after warmup. The benchmark now progresses; remaining speed work is kernel tuning and eventually replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal.