# True Ternary Refactor 4 — True Ternary Exponent Dynamics
## Scope
Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge", which would upgrade E buffers from int8 log2 to float8_e4m3fn. An exploration session concluded that this approach was architecturally wrong: FP8 E reintroduces an IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as **True Ternary Exponent Dynamics**, replacing the FP8 approach with a mathematically correct logarithmic scaling system.
This refactor implements the first three waves of the replanned Phase 9.
## What Changed
### Architecture Decisions (from exploration session)
Five core principles came out of that session and are now reflected in the code:
| Principle | Meaning | Code Impact |
|-----------|---------|-------------|
| **S is never stored** | S = 2^E is a function, not a value. No float8/int16/float stored. | `_get_S()` restored to `torch.exp2(E.float())` |
| **E is hybrid state** | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | `update_E()` replaced SignSGD with EMA |
| **LossComponent = temperature** | Per-component loss signals control α (update energy) per group | `update_E(loss_signal)` modulates α via sigmoid |
| **TScaleType = fixed lattice** | Structure is stable; what's dynamic is energy routing across it | Group sizes unchanged |
| **Representation singular, learning ensemble** | Forward pass always single-scale; multi-scale exists only in update pathway | Forward uses single `W = T * 2^E` |
### REPLAN (before execution)
**Commit:** `9103dc5` — replan Phase 9: True Ternary Exponent Dynamics
- ROADMAP.md updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", 3 plans instead of 4
- REQUIREMENTS.md: HYB-01–06 replaced with TERN-E-01–05
- CONTEXT.md rewritten from FP8-centric to true ternary principles
- STATE.md updated with new phase structure
- Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed
See `.planning/notes/true-ternary-architecture-principles.md` for full architecture rationale.
### Plan 09-01 — Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02)
**Commit:** `80c6188` (+ SUMMARY `2bfc42a`)
Undid all FP8 changes from old Phase 9 Waves 1-2:
#### `tscale.py`
- **E buffer init** (lines 906, 1235): `clamp(-448, 448).to(float8_e4m3fn)` → `log2().clamp(-128, 127).to(int8)`
- **`_get_S`**: `E_exp.float()` → `torch.exp2(E_exp.float())`, restoring the log2→exp2 dequant
- **CPU update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
- **tscale_to**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **5 Triton forward kernels** (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): load E with `other=0` (int), convert with `.to(tl.float32)`, then apply `tl.exp2(e_val)` in place of the former direct float cast
- **2 Triton update kernels** (`_triton_update_e`, `_triton_update_e_direct`): `other=0.0` → `other=0`, `tl.minimum(448, ...)` → `tl.minimum(127, ...)`, store as `tl.int8` instead of `tl.float8e4nv`
- **TernaryRMSNorm E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
#### `trigram.py`
- **ByteEmbedding E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **ByteEmbedding CPU forward**: `E_exp.float()` → `torch.exp2(E_exp.float())`
- **ByteEmbedding update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
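A minimal sketch of the restored int8 log2 encoding used by both `tscale.py` and `trigram.py`, assuming E is initialized from a non-negative per-group scale (where that scale comes from is not shown here). `encode_E` and `get_S` are hypothetical names; the calls inside them are the ones listed above. One detail worth knowing: PyTorch's float-to-int8 cast truncates toward zero, so a non-integer log2 is not rounded to the nearest exponent.

```python
import torch

def encode_E(scale: torch.Tensor) -> torch.Tensor:
    """int8 log2 exponent, mirroring log2().clamp(-128, 127).to(int8).

    The float-to-int8 cast truncates toward zero; it does not round.
    """
    return scale.log2().clamp(-128, 127).to(torch.int8)

def get_S(E: torch.Tensor) -> torch.Tensor:
    """Recompute S = 2^E on demand; S itself is never stored."""
    return torch.exp2(E.float())

scale = torch.tensor([1.0, 3.0, 0.7, 0.3, 0.0])
E = encode_E(scale)
print(E.tolist())              # [0, 1, 0, -1, -128]  (log2(0) = -inf is clamped)
print(get_S(E[:4]).tolist())   # [1.0, 2.0, 1.0, 0.5]
```

Here log2(3) ≈ 1.58 is stored as 1, so the recovered S is 2 rather than 3; adding `.round()` before the cast would select the nearest exponent instead.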
#### `ternary_audit.py`
- Removed `buf.dtype != torch.float8_e4m3fn` exclusion from `float_buffers` filter
#### `testing/test_tscale.py`
- Removed 11 FP8-specific test functions and their registrations
- Updated update_E correctness test comment (was FP8-specific)
### Plan 09-02 — EMA-based E Update Rule (TERN-E-03)
**Commit:** `97d0482` (+ SUMMARY `c1f4be7`)
Replaced the old SignSGD update_E with an EMA in log-space:
```python
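import torch

# grouped: gradient viewed so that dim 2 runs over the elements of one E group
# Old (SignSGD): only the sign of the group mean, one ±1 step per call
grp_mean_sign = grouped.mean(dim=2).sign()
grad_E = -grp_mean_sign.float()
E = torch.clamp(E + grad_E, -128, 127)
# New (EMA): move E toward log2 of the mean |gradient| per group
alpha = 0.1  # α, the default EMA rate
mu_g = grouped.abs().mean(dim=2)  # mean abs gradient per group
e_proposed = torch.round(torch.log2(mu_g + 1e-10)).clamp(-128, 127)
E = (1 - alpha) * E + alpha * e_proposed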
```
Applied to both `TernaryScaleTensor.update_E()` and `ByteEmbedding.update_E()`.
**Why this matters:** The old SignSGD rule used only the sign of each group's mean gradient, so E could move only ±1 per step. The new EMA uses gradient magnitude: E moves toward log2 of the group's mean absolute gradient, smoothed by α, while gradient sign is handled by `T_accum`. This prevents the "step-2 mass-flip" loss spike from REFACTOR3.
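The EMA rule can be exercised end to end on a toy gradient. This is an illustration, not the `tscale.py` code: the name `ema_update_E`, the `[N, K // group_size, group_size]` grouping, and writing the result back with round-to-nearest are all assumptions, because the snippet above stops before E is stored as int8. Under that rounding assumption the stored E stops moving once |e_proposed − E| ≤ 4 at α = 0.1, since α·(e_proposed − E) then rounds back to zero.

```python
import torch

def ema_update_E(E: torch.Tensor, grad: torch.Tensor, group_size: int,
                 alpha: float = 0.1, eps: float = 1e-10) -> torch.Tensor:
    """Hypothetical EMA update of a per-group int8 log2 exponent.

    E: int8 [N, K // group_size]; grad: float [N, K] weight gradient.
    """
    N, K = grad.shape
    grouped = grad.reshape(N, K // group_size, group_size)
    mu_g = grouped.abs().mean(dim=2)                          # mean |grad| per group
    e_proposed = torch.round(torch.log2(mu_g + eps)).clamp(-128, 127)
    E_new = (1 - alpha) * E.float() + alpha * e_proposed      # EMA in log2 space
    # Assumption: round to nearest before storing back into the int8 buffer
    return torch.round(E_new).clamp(-128, 127).to(torch.int8)

# Two groups of 4 with mean |grad| = 2^-4 and 2^-10 (targets -4 and -10)
grad = torch.tensor([[0.0625, -0.0625, 0.0625, -0.0625,
                      2**-10, 2**-10, -2**-10, 2**-10]])
E = torch.zeros(1, 2, dtype=torch.int8)
print(ema_update_E(E, grad, group_size=4).tolist())  # [[0, -1]]
```

In the first group the step is −0.4, which rounds back to 0, so that group is already inside the dead band; the second group takes a full −1 step toward −10.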
### Plan 09-03 — LossComponent Temperature Routing (TERN-E-04)
**Commit:** `d77180d` (+ SUMMARY `bd8a750`)
Wired per-component loss signals as temperature for E updates:
```python
# update_E accepts loss_signal:
alpha = getattr(self, '_ema_alpha', 0.1)
if loss_signal is not None:
    # Loss acts as temperature: sigmoid(...) lies in (0, 1), so it scales α down
    temp_scale = getattr(self, '_loss_temp_scale', 1.0)
    alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()
```
Changes:
- `TernaryScaleTensor.update_E(loss_signal=...)`: α scaled by `sigmoid(loss_signal * _loss_temp_scale)`
- `ByteEmbedding.update_E(loss_signal=...)`: same pattern
- `TernaryRMSNorm.update_E(loss_signal=...)`: accepts the kwarg but ignores it (its weights are frozen)
- `_ternary_update_memory(loss_signal=...)`: passes through to all modules
- `train.py`: passes `loss_comps.total` to `_ternary_update_memory`
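To see what the temperature does numerically, here is the same formula as a standalone function. `effective_alpha` is a hypothetical helper; its defaults (`0.1`, `temp_scale=1.0`) are the `getattr` fallbacks from the snippet above. Because sigmoid maps any input into (0, 1), the loss signal can only shrink α, and for a positive loss with a positive `temp_scale` α stays between 0.05 and 0.1.

```python
import torch

def effective_alpha(loss_signal, base_alpha=0.1, temp_scale=1.0):
    """EMA rate after loss-temperature modulation (formula from update_E)."""
    if loss_signal is None:
        return base_alpha
    return base_alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()

for loss in (0.0, 1.0, 3.0, 8.0):
    print(loss, round(effective_alpha(torch.tensor(loss)), 4))
# 0.0 0.05
# 1.0 0.0731
# 3.0 0.0953
# 8.0 0.1
```

Larger losses push α toward its base value, a loss near zero halves it, and a negative loss signal would drive α toward 0.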
**Later use (TERN-E-05):** The multi-scale lattice where each TScaleType level proposes ΔE_s was deferred per plan scope. Reactivation trigger: if single-scale EMA saturates and cannot separate magnitude regimes.
## Current State
### E Buffer Architecture (what the code now does)
```
Storage: T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8)
Forward: W_eff = 2^(E_exp) * T   where E_exp = expand(E)
Update:  E ← (1-α)*E + α*round(log2(|grad|_group))
         T_accum += sign(grad)
         if |T_accum| > 3: flip T, reset T_accum
```
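`T_packed` fits five trits in one byte because 3^5 = 243 ≤ 256. The digit order and value mapping used in `tscale.py` are not described here, so this sketch assumes trits {-1, 0, +1} are shifted to {0, 1, 2} and packed base-3 with the first trit least significant; `pack_trits` and `unpack_trits` are hypothetical names, and the place-value table lives on the CPU.

```python
import torch

POW3 = torch.tensor([1, 3, 9, 27, 81], dtype=torch.int32)  # base-3 place values

def pack_trits(T: torch.Tensor) -> torch.Tensor:
    """Pack a 1-D ternary tensor (values in {-1, 0, 1}) into uint8, 5 trits per byte."""
    pad = (-T.numel()) % 5                      # pad with 0-trits to a multiple of 5
    digits = torch.cat([T.to(torch.int32), T.new_zeros(pad, dtype=torch.int32)]) + 1
    return (digits.view(-1, 5) * POW3).sum(dim=1).to(torch.uint8)  # max value 242

def unpack_trits(packed: torch.Tensor, n: int) -> torch.Tensor:
    """Inverse of pack_trits: recover the first n trits as int8 in {-1, 0, 1}."""
    vals = packed.to(torch.int32).unsqueeze(1)  # [bytes, 1]
    digits = (vals // POW3) % 3                 # [bytes, 5] base-3 digits
    return (digits.flatten()[:n] - 1).to(torch.int8)

T = torch.tensor([1, -1, 0, 0, 1, -1, 1], dtype=torch.int8)
packed = pack_trits(T)
print(packed.tolist())                          # [200, 123]: 7 trits in 2 bytes
print(torch.equal(unpack_trits(packed, T.numel()), T))  # True
```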
### Files Changed
| File | Lines Changed |
|------|---------------|
| `tscale.py` | ~40 changed, FP8 removed, EMA update added, LossComponent routing added |
| `trigram.py` | ~20 changed, FP8 removed, EMA update + LossComponent routing added |
| `ternary_audit.py` | ~1 changed, FP8 exclusion removed |
| `train.py` | ~1 changed, loss_signal passed to update |
| `benchmark_true_ternary.py` | New: 3-way benchmark comparing Adam_FP32, SignSGD_Old, and TrueTernary |
| `testing/test_tscale.py` | ~210 lines removed (FP8 tests), 1 comment updated |
| `.planning/ROADMAP.md` | Phase 9 goal and plans rewritten |
| `.planning/REQUIREMENTS.md` | HYB-01–06 replaced with TERN-E-01–05 |
### Tests
All 18 CPU tscale tests pass. 110 morph tests pass.
### Benchmark (in progress)
`benchmark_true_ternary.py` compares 3 configs on the full MORPHTernaryModel (133M params):
- **Adam_FP32**: standard FP32 Adam baseline
- **SignSGD_Old**: SignSGD without `_ternary_update_memory` (static E, no EMA)
- **TrueTernary**: our new int8 E + EMA update + LossComponent routing (CPU update)
Preliminary findings (50 steps, batch=8, seq=66):
- Adam_FP32: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- SignSGD_Old: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- TrueTernary: ~600ms/step, same peak VRAM as the baselines (uses the CPU fallback update)
## Known Issue: Triton Kernel Caching Bug
**`_triton_ternary_step_direct_kernel`** takes ~13.6s to compile and does NOT cache across calls, even with identical shapes, causing every `ternary_step` call to take 13.6s. This affects the GPU path for `_ternary_update_memory`.
Workaround in the benchmark: a CPU-based `cpu_update_memory()` function bypasses Triton for the updates.
Symptoms:
- First call: 13.6s (compilation)
- Second call with identical shapes: ALSO 13.6s (recompilation — cache miss)
The Triton cache directory holds 8788 cached kernels, but this specific kernel is not among them. The root cause is not yet identified; it is likely an interaction between the `tl.constexpr` parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) and the kernel's complexity.
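The cache-miss symptom can be checked with a small timing harness that calls the same function several times with identical inputs. `time_repeated_calls` and `looks_cached` are hypothetical helpers, not part of the repository, and the 10x speedup threshold is an arbitrary assumption. For a GPU kernel, the callable should call `torch.cuda.synchronize()` before returning so that the timing covers execution, not just the launch.

```python
import time
from typing import Callable, List

def time_repeated_calls(fn: Callable[[], object], n: int = 3) -> List[float]:
    """Wall-clock seconds for n back-to-back calls of fn with identical inputs."""
    times = []
    for _ in range(n):
        start = time.perf_counter()
        fn()
        times.append(time.perf_counter() - start)
    return times

def looks_cached(times: List[float], ratio: float = 0.1) -> bool:
    """Heuristic: after a cold first call, every later call should be much faster."""
    return all(t <= ratio * times[0] for t in times[1:])

# Hypothetical usage around the slow update path:
#   times = time_repeated_calls(lambda: (run_step(), torch.cuda.synchronize()))
#   print(times, looks_cached(times))
print(looks_cached([13.6, 13.6, 13.6]))  # False: the pattern described above
```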
## Float Materialization Audit (post-rollback)
| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| Forward (CUDA) | `y[M,N]`, `S` (ephemeral exp2) | Output tensor only |
| Backward (CUDA) | `grad_x`, `grad_2d`, `x_2d` | Ephemeral (autograd) |
| update_E (CPU) | `grad_T`, `mu_g`, `e_proposed`, `E_float` | Ephemeral (computed, applied, freed) |
| E buffer | `int8` (always) | Persistent (buffer) |
| S (scale) | **Never stored** — `2^E` computed on-the-fly | — |
No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit. ✓
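A dense reference for the forward path shows how S stays implicit: `2^E` is expanded from the int8 E and multiplied into T inside the function, and only the output tensor survives. The grouping of E along K with a fixed `group_size`, and the name `ternary_linear_ref`, are assumptions for illustration; the real path runs in the Triton kernels listed earlier.

```python
import torch

def ternary_linear_ref(x: torch.Tensor, T: torch.Tensor, E: torch.Tensor,
                       group_size: int) -> torch.Tensor:
    """y = x @ W_eff^T with W_eff = T * 2^E, E shared by group_size columns.

    x: [M, K] float, T: [N, K] int8 in {-1, 0, 1}, E: [N, K // group_size] int8.
    """
    E_exp = E.repeat_interleave(group_size, dim=1)   # expand(E) -> [N, K]
    W_eff = T.float() * torch.exp2(E_exp.float())    # ephemeral; freed on return
    return x @ W_eff.t()

x = torch.ones(1, 4)
T = torch.tensor([[1, -1, 0, 1]], dtype=torch.int8)
E = torch.tensor([[-1, 2]], dtype=torch.int8)        # one exponent per 2 columns
print(ternary_linear_ref(x, T, E, group_size=2).tolist())  # [[4.0]]
```

Here `W_eff = [0.5, -0.5, 0, 4]`, so the all-ones input sums to 4.0.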
## Remaining Work
1. **Resolve Triton caching bug** in `_triton_ternary_step_direct_kernel`: either turn some `tl.constexpr` parameters into runtime arguments or simplify the kernel
2. **Run full benchmark** (all 3 configs to completion): TrueTernary currently hangs from step 1 onward because of CPU update overhead; this needs profiling
3. **TERN-E-05 (multi-scale lattice)** — evaluate when single-scale EMA saturates
4. **LossComponent routing training validation** — verify that α modulation via sigmoid(loss) actually improves convergence vs fixed α
## Follow-Up Fix: Benchmark Progress And VRAM Attribution
The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal `MORPHTernaryModel`, passed trainable floating-point parameters to the optimizers, and kept the expensive baseline comparison path ahead of `TrueTernary`. That made all configs report roughly the same VRAM behavior even though strict ternary training has no persistent float weights.
### Changes
#### `benchmark_true_ternary.py`
- Added CLI controls:
  - `--configs`
  - `--steps`
  - `--warmup`
  - `--batch`
  - `--ctx`
  - `--strict-true-ternary / --no-strict-true-ternary`
  - `--update-backend {gpu,gpu-signcache,dense-fallback,none}`
  - `--scale-update-interval`
  - `--accum-threshold`
  - `--print-every`
  - `--reuse-base / --no-reuse-base`
- `TrueTernary` now defaults to a strict text-only ternary model:
  - no image/audio encoders
  - no VQ
  - no graph
  - no recurrent memory modules
  - float params frozen
- Optimizers are now built only from trainable parameters.
- Per-step logs now show allocated, reserved, and peak CUDA memory.
- Short benchmark averages now divide by the available number of steps instead of always dividing by 20.
- The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM.
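For reference, flags like these can be declared with the standard library's `argparse` (the paired on/off switches need `argparse.BooleanOptionalAction`, Python 3.9+). This is a sketch, not the benchmark's actual parser: every default below, and `nargs="+"` for `--configs`, is a placeholder assumption. Only the flag names and the `--update-backend` choices come from the list above.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(description="True-ternary benchmark (sketch)")
    # Placeholder defaults throughout; the real script's defaults are not documented here
    p.add_argument("--configs", nargs="+",
                   default=["Adam_FP32", "SignSGD_Old", "TrueTernary"])
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--warmup", type=int, default=1)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--ctx", type=int, default=66)
    # Generates both --strict-true-ternary and --no-strict-true-ternary
    p.add_argument("--strict-true-ternary", action=argparse.BooleanOptionalAction,
                   default=True)
    p.add_argument("--update-backend",
                   choices=["gpu", "gpu-signcache", "dense-fallback", "none"],
                   default="gpu-signcache")
    p.add_argument("--scale-update-interval", type=int, default=1)
    p.add_argument("--accum-threshold", type=int, default=3)
    p.add_argument("--print-every", type=int, default=1)
    p.add_argument("--reuse-base", action=argparse.BooleanOptionalAction,
                   default=False)
    return p

args = build_parser().parse_args(
    "--configs TrueTernary --steps 3 --update-backend gpu-signcache".split())
print(args.configs, args.steps, args.update_backend)  # ['TrueTernary'] 3 gpu-signcache
```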
#### `tscale.py`
- CUDA `update_E()` now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update:
```text
mu_g = mean(abs(sign(grad_w) * T)) per group
e_proposed = round(log2(mu_g + eps))
E = (1 - alpha) * E + alpha * e_proposed
```
- The Triton `E` update kernels accept the loss-temperature-adjusted `alpha` computed in Python.
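Because this variant feeds `sign(grad_w) * T` into the rule rather than raw gradient magnitudes, each element of `abs(sign(grad_w) * T)` is 0 or 1, so `mu_g` is the fraction of positions in the group where both the gradient sign and T are non-zero. The sketch below makes the resulting range concrete. `signcache_e_proposed` is a hypothetical name, and `eps = 1e-10` is borrowed from the CPU rule as an assumption; with it, `e_proposed` always lands in [-33, 0].

```python
import torch

def signcache_e_proposed(grad_sign: torch.Tensor, T: torch.Tensor,
                         group_size: int, eps: float = 1e-10) -> torch.Tensor:
    """round(log2(mean(|sign(grad_w) * T|) + eps)) per group, from an int8 sign cache."""
    N, K = grad_sign.shape
    # Each product is -1, 0 or +1, so its absolute value is 0 or 1
    prod = (grad_sign.to(torch.int32) * T.to(torch.int32)).abs().float()
    mu_g = prod.view(N, K // group_size, group_size).mean(dim=2)  # fraction in [0, 1]
    return torch.round(torch.log2(mu_g + eps))

grad_sign = torch.tensor([[1, -1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0]], dtype=torch.int8)
T = torch.tensor([[1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.int8)
print(signcache_e_proposed(grad_sign, T, group_size=4).tolist())  # [[0.0, -33.0, -2.0]]
```

A fully active group proposes 0, an inactive group hits the eps floor at -33, and a group with one active position out of four proposes log2(0.25) = -2.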
#### New benchmark update backend: `gpu-signcache`
The direct GPU update path is memory-clean but too slow for larger `M = batch * sequence` shapes because it recomputes the gradient reduction inside tiny packed-byte kernels.
`gpu-signcache` uses this flow per module:
```text
grad_y, x
-> Triton tiled sign-only grad kernel
-> temporary int8 grad_sign[N, K]
-> Triton EMA E update
-> Triton T_accum/repack
-> free grad_sign
```
This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping persistent model state true ternary.
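A plain PyTorch reference for the first two steps of that flow: compute `grad_w = grad_y^T @ x` one output tile at a time and keep only its sign as int8, so at most one tile of float values exists at any moment. The tile size and the name `tiled_grad_sign` are assumptions; the actual Triton kernel's tiling is not described here.

```python
import torch

def tiled_grad_sign(grad_y: torch.Tensor, x: torch.Tensor, tile: int = 64) -> torch.Tensor:
    """int8 sign of grad_y^T @ x, computed one [tile, tile] output block at a time.

    grad_y: [M, N] output gradient, x: [M, K] layer input -> grad_sign: [N, K].
    """
    N, K = grad_y.shape[1], x.shape[1]
    grad_sign = torch.empty(N, K, dtype=torch.int8, device=grad_y.device)
    for n0 in range(0, N, tile):
        for k0 in range(0, K, tile):
            # Only this float block is materialized; it is reduced to signs at once
            block = grad_y[:, n0:n0 + tile].t() @ x[:, k0:k0 + tile]
            grad_sign[n0:n0 + tile, k0:k0 + tile] = torch.sign(block).to(torch.int8)
    return grad_sign

grad_y = torch.randn(32, 100)   # [M, N]
x = torch.randn(32, 150)        # [M, K]
grad_sign = tiled_grad_sign(grad_y, x, tile=64)
# ... the EMA E update and T_accum/repack steps would consume grad_sign here ...
del grad_sign                   # the cache is temporary: drop it after the updates
```

For an N×K layer the cache costs N·K bytes, a quarter of a float32 gradient of the same shape.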
### Verification
Focused CUDA tests:
```text
PASS test_cuda_triton_correctness_update_E
PASS test_cuda_triton_tscale_path
```
Strict small benchmark:
```text
python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4
logical ternary weights: 14,011,904
trainable float params: 0
float buffers: 0
peak VRAM: ~30 MB
avg step after warmup: ~170 ms
```
Strict benchmark at previous benchmark shape:
```text
python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4
logical ternary weights: 14,011,904
ternary training state: 17.15 MB
trainable float params: 0
float buffers: 0
peak VRAM: ~430 MB
steady step: ~1.85 s
```
Previous direct backend at the same shape was roughly `49 s/step` after warmup. The benchmark now progresses; remaining speed work is kernel tuning and eventually replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal.