# True Ternary Refactor 4: True Ternary Exponent Dynamics

## Scope

Phase 9 was originally planned as "Ternary-FP8 Hybrid Precision Bridge", which would have upgraded the E buffers from int8 log2 to float8_e4m3fn. An exploration session determined this approach was architecturally wrong: an FP8 E reintroduces an IEEE float mantissa/exponent into a system designed to eliminate it. The phase was replanned as **True Ternary Exponent Dynamics**, replacing the FP8 approach with the mathematically correct logarithmic scaling system.

This refactor implements the first three waves of the replanned Phase 9.

## What Changed

### Architecture Decisions (from exploration session)

Five core principles were crystallized and are now encoded in code:

| Principle | Meaning | Code Impact |
|-----------|---------|-------------|
| **S is never stored** | S = 2^E is a function, not a value. No float8/int16/float stored. | `_get_S()` restored to `torch.exp2(E.float())` |
| **E is hybrid state** | Persistent int8 buffer, updated via EMA with statistical guidance (not pure SignSGD) | `update_E()` replaced SignSGD with EMA |
| **LossComponent = temperature** | Per-component loss signals control α (update energy) per group | `update_E(loss_signal)` modulates α via sigmoid |
| **TScaleType = fixed lattice** | Structure is stable; what's dynamic is energy routing across it | Group sizes unchanged |
| **Representation singular, learning ensemble** | Forward pass always single-scale; multi-scale exists only in update pathway | Forward uses single `W = T * 2^E` |
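
The first and last principles can be illustrated with a minimal pure-Python sketch. It is not the project's `_get_S()` or its Triton kernels; the helper names `expand_scales` and `effective_weights` and the group size of 4 are assumptions made for illustration. The point is that only `T` and the int8 `E` exist as state, while `S = 2^E` appears only inside the computation.

```python
import math

def expand_scales(E, group_size):
    """Repeat each group's int8 exponent across the weights it covers."""
    return [e for e in E for _ in range(group_size)]

def effective_weights(T, E, group_size):
    """W = T * 2^E, with S = 2^E computed on the fly and never stored."""
    E_exp = expand_scales(E, group_size)
    assert len(E_exp) == len(T)
    # math.ldexp(1.0, e) is exactly 2**e for integer e
    return [t * math.ldexp(1.0, e) for t, e in zip(T, E_exp)]

T = [1, 0, -1, 1, -1, -1, 0, 1]   # ternary values in {-1, 0, +1}
E = [-2, 3]                       # one int8 exponent per group of 4 (assumed grouping)
W = effective_weights(T, E, group_size=4)
print(W)  # [0.25, 0.0, -0.25, 0.25, -8.0, -8.0, 0.0, 8.0]
```

Every effective weight is either zero or a signed power of two, which is why no mantissa has to be stored anywhere.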

### REPLAN (before execution)

**Commit:** `9103dc5` (replan Phase 9: True Ternary Exponent Dynamics)

- ROADMAP.md updated: Phase 9 renamed from "Ternary-FP8 Hybrid Precision Bridge" to "True Ternary Exponent Dynamics", 3 plans instead of 4
- REQUIREMENTS.md: HYB-01–06 replaced with TERN-E-01–05
- CONTEXT.md rewritten from FP8-centric to true ternary principles
- STATE.md updated with new phase structure
- Old FP8 plans (09-03, 09-04) and summaries (09-01, 09-02) removed

See `.planning/notes/true-ternary-architecture-principles.md` for full architecture rationale.

### Plan 09-01: Roll Back FP8 E to int8 (TERN-E-01, TERN-E-02)

**Commit:** `80c6188` (+ SUMMARY `2bfc42a`)

Undid all FP8 changes from old Phase 9 Waves 1-2:

#### `tscale.py`

- **E buffer init** (lines 906, 1235): `clamp(-448, 448).to(float8_e4m3fn)` → `log2().clamp(-128, 127).to(int8)`
- **`_get_S`**: `E_exp.float()` → `torch.exp2(E_exp.float())`, which restores the log2→exp2 dequant
- **CPU update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8_e4m3fn)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`
- **tscale_to**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **5 Triton forward kernels** (fwd, grad_x, embed_fwd, rmsnorm_fwd, rmsnorm_bwd): load E with `other=0` (int), `.to(tl.float32)` → `tl.exp2(e_val)` replaces direct float cast
- **2 Triton update kernels** (`_triton_update_e`, `_triton_update_e_direct`): `other=0.0` → `other=0`, `tl.minimum(448, ...)` → `tl.minimum(127, ...)`, store as `tl.int8` instead of `tl.float8e4nv`
- **TernaryRMSNorm E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
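
The init and dequant pair from the list above can be sketched in plain Python as a round trip. `init_E` and `get_S` are hypothetical stand-ins for the `log2().clamp(-128, 127).to(int8)` and `torch.exp2(...)` expressions; the sketch assumes no rounding is applied before the cast, so the fractional part is truncated toward zero the way a float-to-int8 cast behaves.

```python
import math

INT8_MIN, INT8_MAX = -128, 127

def init_E(scale):
    """Quantize a positive float scale to an int8 log2 exponent.

    Mirrors log2().clamp(-128, 127).to(int8). int() truncates toward zero;
    that no explicit rounding happens first is an assumption of this sketch.
    """
    e = math.log2(scale)
    e = max(INT8_MIN, min(INT8_MAX, e))
    return int(e)

def get_S(E):
    """Dequantize: S = 2^E, recomputed on every call and never stored."""
    return 2.0 ** E

for scale in (0.02, 0.5, 1.0, 6.0):
    E = init_E(scale)
    print(f"scale={scale:<5} E={E:>3} S={get_S(E)}")
```

Scales that are not powers of two lose their fractional exponent in the round trip (6.0 comes back as 4.0 here). In exchange, the int8 exponent spans 2^-128 to 2^127, whereas float8_e4m3fn tops out at 448.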

#### `trigram.py`

- **ByteEmbedding E init**: `clamp(-448, 448).to(float8)` → `log2().clamp(-128, 127).to(int8)`
- **ByteEmbedding CPU forward**: `E_exp.float()` → `torch.exp2(E_exp.float())`
- **ByteEmbedding update_E**: `clamp(E.float() + grad_E * 0.0625, -448, 448).to(float8)` → `clamp(E.float() + grad_E, -128, 127).to(int8)`

#### `ternary_audit.py`

- Removed `buf.dtype != torch.float8_e4m3fn` exclusion from `float_buffers` filter

#### `testing/test_tscale.py`

- Removed 11 FP8-specific test functions and their registrations
- Updated update_E correctness test comment (was FP8-specific)

### Plan 09-02: EMA-based E Update Rule (TERN-E-03)

**Commit:** `97d0482` (+ SUMMARY `c1f4be7`)

Replaced the old SignSGD update_E with an EMA in log-space:

```python
# Old (SignSGD):
grp_mean_sign = grouped.mean(dim=2).sign()
grad_E = -grp_mean_sign.float()
E = clamp(E + grad_E, -128, 127)

# New (EMA):
mu_g = grouped.abs().mean(dim=2)  # mean abs gradient per group
e_proposed = round(log2(mu_g + 1e-10)).clamp(-128, 127)
E = (1-α) * E + α * e_proposed       # α defaults to 0.1
```

Applied to both `TernaryScaleTensor.update_E()` and `ByteEmbedding.update_E()`.

**Why this matters:** The old SignSGD only knew direction (±1 per group). The new EMA knows both direction and magnitude: E converges toward the log of the gradient energy, smoothed by α. This prevents the "step-2 mass-flip" loss spike from REFACTOR3.
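
A scalar toy run makes the difference concrete. This is a sketch for one group, not the `update_E()` implementation: the function names are hypothetical, E is kept as a float, and how the real code writes the EMA result back into the int8 buffer is left out because it is not described here.

```python
import math

def signsgd_step(E, grad_sign):
    """Old rule: move E by exactly one step against the group's mean sign."""
    return max(-128, min(127, E - grad_sign))

def ema_step(E, mu_g, alpha=0.1, eps=1e-10):
    """New rule: pull E toward round(log2(mean |grad|)) at rate alpha."""
    e_proposed = max(-128, min(127, round(math.log2(mu_g + eps))))
    return (1 - alpha) * E + alpha * e_proposed

# SignSGD with a sign that alternates: E bounces by a full step each time.
E_old = 0
for s in (1, -1, 1, -1):
    E_old = signsgd_step(E_old, s)
    print("signsgd E =", E_old)

# EMA with a steady gradient magnitude of 2^-6: E drifts smoothly toward -6.
E_ema = 0.0
for step in range(50):
    E_ema = ema_step(E_ema, mu_g=2.0 ** -6)
print("ema E after 50 steps =", round(E_ema, 3))
```

With a constant target the EMA gap shrinks by a factor of (1-α) per step, so no single update can move E by more than α times the current gap.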

### Plan 09-03: LossComponent Temperature Routing (TERN-E-04)

**Commit:** `d77180d` (+ SUMMARY `bd8a750`)

Wired per-component loss signals as temperature for E updates:

```python
# update_E accepts loss_signal:
alpha = getattr(self, '_ema_alpha', 0.1)
if loss_signal is not None:
    temp_scale = getattr(self, '_loss_temp_scale', 1.0)
    alpha = alpha * torch.sigmoid(loss_signal.detach() * temp_scale).item()
```

Changes:
- `TernaryScaleTensor.update_E(loss_signal=...)`: α modulated by sigmoid(total_loss)
- `ByteEmbedding.update_E(loss_signal=...)`: same pattern
- `TernaryRMSNorm.update_E(loss_signal=...)`: accepts kwarg silently (frozen weights)
- `_ternary_update_memory(loss_signal=...)`: passes through to all modules
- `train.py`: passes `loss_comps.total` to `_ternary_update_memory`
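
To see what range of α this modulation can actually produce, here is a dependency-free sketch of the same formula. `effective_alpha` is a hypothetical helper; the defaults of 0.1 and 1.0 are taken from the `getattr` fallbacks in the snippet above.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def effective_alpha(loss_signal=None, base_alpha=0.1, temp_scale=1.0):
    """alpha * sigmoid(loss * temp_scale); plain alpha when no signal is given."""
    if loss_signal is None:
        return base_alpha
    return base_alpha * sigmoid(loss_signal * temp_scale)

for loss in (0.0, 0.5, 2.0, 8.0):
    print(f"loss={loss:<4} alpha={effective_alpha(loss):.4f}")
```

Because sigmoid of a non-negative input lies in [0.5, 1), a non-negative loss with the default settings can only move α between 0.05 and 0.1. A wider dynamic range would need a centered loss signal or a different `_loss_temp_scale`, which is worth keeping in mind for the validation item under Remaining Work.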

**Later use (TERN-E-05):** The multi-scale lattice where each TScaleType level proposes ΔE_s was deferred per plan scope. Reactivation trigger: the single-scale EMA saturates and cannot separate magnitude regimes.

## Current State

### E Buffer Architecture (what the code now does)

```
Storage:  T_packed (5-trit/byte uint8) + E (int8 log2) + T_accum (int8)
Forward:  W_eff = 2^(E_exp) * T   where E_exp = expand(E)
Update:   E ← (1-α)*E + α*round(log2(|grad|_group))
          T_accum += sign(grad)
          if |T_accum| > 3: flip T, reset T_accum
```
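
The "5-trit/byte" storage works because 3^5 = 243 fits in one uint8. A minimal sketch of such a packing follows; the digit order and the mapping of -1/0/+1 to base-3 digits 0/1/2 are assumptions, and the layout actually used by `T_packed` may differ.

```python
def pack_trits(trits):
    """Pack 5 ternary values in {-1, 0, +1} into one byte as base-3 digits."""
    assert len(trits) == 5 and all(t in (-1, 0, 1) for t in trits)
    byte = 0
    for t in reversed(trits):       # first trit ends up as the lowest digit
        byte = byte * 3 + (t + 1)   # map -1/0/+1 to digits 0/1/2
    return byte                     # always in 0..242, so it fits in uint8

def unpack_trits(byte):
    """Inverse of pack_trits: peel off base-3 digits, lowest first."""
    trits = []
    for _ in range(5):
        trits.append(byte % 3 - 1)
        byte //= 3
    return trits

packed = pack_trits([1, 0, -1, 1, -1])
print(packed)                 # 59
print(unpack_trits(packed))   # [1, 0, -1, 1, -1]
```

This stores each weight in 1.6 bits, close to the log2(3) ≈ 1.585-bit information content of a ternary value.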

### Files Changed

| File | Lines Changed |
|------|---------------|
| `tscale.py` | ~40 changed, FP8 removed, EMA update added, LossComponent routing added |
| `trigram.py` | ~20 changed, FP8 removed, EMA update + LossComponent routing added |
| `ternary_audit.py` | ~1 changed, FP8 exclusion removed |
| `train.py` | ~1 changed, loss_signal passed to update |
| `benchmark_true_ternary.py` | New: 3-way benchmark comparing Adam/FP32, SignSGD(old), TrueTernary |
| `testing/test_tscale.py` | ~210 lines removed (FP8 tests), 1 comment updated |
| `.planning/ROADMAP.md` | Phase 9 goal and plans rewritten |
| `.planning/REQUIREMENTS.md` | HYB-01–06 replaced with TERN-E-01–05 |

### Tests

All 18 CPU tscale tests pass. 110 morph tests pass.

### Benchmark (in progress)

`benchmark_true_ternary.py` compares 3 configs on the full MORPHTernaryModel (133M params):
- **Adam_FP32**: standard FP32 Adam baseline
- **SignSGD_Old**: SignSGD without `_ternary_update_memory` (static E, no EMA)
- **TrueTernary**: our new int8 E + EMA update + LossComponent routing (CPU update)

Preliminary findings (50 steps, batch=8, seq=66):
- Adam_FP32: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- SignSGD_Old: ~600ms/step, ~800 tok/s, 4.8GB peak VRAM
- TrueTernary: ~600ms/step, same VRAM (uses CPU fallback update)

## Known Issue: Triton Kernel Caching Bug

**`_triton_ternary_step_direct_kernel`** takes ~13.6s to compile and does NOT cache across calls, even with identical shapes, causing every `ternary_step` call to take 13.6s. This affects the GPU path for `_ternary_update_memory`.

Workaround in benchmark: CPU-based `cpu_update_memory()` function bypasses Triton for updates.

Symptoms:
- First call: 13.6s (compilation)
- Second call with identical shapes: ALSO 13.6s (recompilation, cache miss)

The Triton cache directory has 8788 cached kernels, but this specific kernel is not among them. The root cause is not yet identified; a likely candidate is the interaction of the `tl.constexpr` parameters (M, N, K, TOTAL, ACCUM_THRESHOLD) with kernel complexity.
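
A small timing harness is enough to reproduce the symptom and to check any fix. The sketch below is generic Python, not part of the repo: `time_calls` and `looks_like_recompile` are hypothetical helpers, and the 0.5 ratio is an arbitrary threshold. For a CUDA kernel, the timed callable should end with `torch.cuda.synchronize()` so asynchronous launches are included in the measurement.

```python
import time

def time_calls(fn, n_calls=3):
    """Return wall-clock seconds for each of n_calls consecutive calls to fn."""
    timings = []
    for _ in range(n_calls):
        start = time.perf_counter()
        fn()
        timings.append(time.perf_counter() - start)
    return timings

def looks_like_recompile(timings, ratio=0.5):
    """Flag the symptom: a later call costs at least `ratio` of the first call."""
    first, rest = timings[0], timings[1:]
    return any(t >= ratio * first for t in rest)

# The symptom described above versus healthy caching, using made-up timings:
print(looks_like_recompile([13.6, 13.6]))        # True
print(looks_like_recompile([13.6, 0.02, 0.02]))  # False
```

Running `time_calls` on a closure around `ternary_step` with fixed shapes, before and after a kernel change, gives a direct pass/fail signal for the caching behavior.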

## Float Materialization Audit (post-rollback)

| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| Forward (CUDA) | `y[M,N]`, `S` (ephemeral exp2) | Output tensor only |
| Backward (CUDA) | `grad_x`, `grad_2d`, `x_2d` | Ephemeral (autograd) |
| update_E (CPU) | `grad_T`, `mu_g`, `e_proposed`, `E_float` | Ephemeral (computed, applied, freed) |
| E buffer | `int8` (always) | Persistent (buffer) |
| S (scale) | **Never stored**: `2^E` computed on the fly | n/a |

No IEEE float in weight state. No FP32/BF16/FP8 master weights. S is implicit. ✓

## Remaining Work

1. **Resolve Triton caching bug** in `_triton_ternary_step_direct_kernel`: port the kernel to non-constexpr parameters or simplify it
2. **Run full benchmark** (all 3 configs to completion): TrueTernary currently hangs from step 1 onward due to CPU update overhead; needs profiling
3. **TERN-E-05 (multi-scale lattice)**: evaluate when single-scale EMA saturates
4. **LossComponent routing training validation**: verify that α modulation via sigmoid(loss) actually improves convergence vs fixed α

## Follow-Up Fix: Benchmark Progress And VRAM Attribution

The benchmark was not measuring the strict true-ternary path. It instantiated the full multimodal `MORPHTernaryModel`, passed trainable floating-point parameters to the optimizers, and kept the expensive baseline comparison path ahead of `TrueTernary`. That made all configs report roughly the same VRAM behavior even though strict ternary training has no persistent float weights.

### Changes

#### `benchmark_true_ternary.py`

- Added CLI controls:
  - `--configs`
  - `--steps`
  - `--warmup`
  - `--batch`
  - `--ctx`
  - `--strict-true-ternary / --no-strict-true-ternary`
  - `--update-backend {gpu,gpu-signcache,dense-fallback,none}`
  - `--scale-update-interval`
  - `--accum-threshold`
  - `--print-every`
  - `--reuse-base / --no-reuse-base`
- `TrueTernary` now defaults to a strict text-only ternary model:
  - no image/audio encoders
  - no VQ
  - no graph
  - no recurrent memory modules
  - float params frozen
- Optimizers are now built only from trainable parameters.
- Per-step logs now show allocated, reserved, and peak CUDA memory.
- Short benchmark averages now divide by the available number of steps instead of always dividing by 20.
- The reusable base model, when requested, is kept as CPU state so it does not inflate benchmark VRAM.
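
For orientation, the flag list above maps naturally onto `argparse`. The sketch below is not the benchmark's actual parser: the flag names come from the list, but every default and type is an assumption, and `average_step_ms` is a hypothetical helper showing the averaging fix (divide by the steps actually measured, not a fixed 20).

```python
import argparse

def build_parser():
    """Parser sketch; flag names from the list above, defaults assumed."""
    p = argparse.ArgumentParser(description="true-ternary benchmark (sketch)")
    p.add_argument("--configs", nargs="+", default=["TrueTernary"])
    p.add_argument("--steps", type=int, default=50)
    p.add_argument("--warmup", type=int, default=1)
    p.add_argument("--batch", type=int, default=8)
    p.add_argument("--ctx", type=int, default=66)
    # BooleanOptionalAction (Python 3.9+) generates the --no-... variant.
    p.add_argument("--strict-true-ternary",
                   action=argparse.BooleanOptionalAction, default=True)
    p.add_argument("--update-backend", default="gpu-signcache",
                   choices=["gpu", "gpu-signcache", "dense-fallback", "none"])
    p.add_argument("--scale-update-interval", type=int, default=1)
    p.add_argument("--accum-threshold", type=int, default=3)
    p.add_argument("--print-every", type=int, default=1)
    p.add_argument("--reuse-base",
                   action=argparse.BooleanOptionalAction, default=False)
    return p

def average_step_ms(step_times_ms, warmup):
    """Average over the steps actually measured after warmup."""
    measured = step_times_ms[warmup:]
    return sum(measured) / len(measured) if measured else float("nan")

args = build_parser().parse_args(
    ["--configs", "TrueTernary", "--steps", "4", "--warmup", "1",
     "--batch", "2", "--ctx", "10", "--update-backend", "gpu-signcache"])
print(args.update_backend, args.strict_true_ternary)
print(average_step_ms([900.0, 170.0, 170.0, 170.0], warmup=args.warmup))
```

Dividing by `len(measured)` keeps short smoke runs (for example 3 or 4 steps) from reporting an average that is several times too small.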

#### `tscale.py`

- CUDA `update_E()` now matches the Plan 09 logarithmic EMA rule instead of the older ±1 sign update:

```text
mu_g = mean(abs(sign(grad_w) * T)) per group
e_proposed = round(log2(mu_g + eps))
E = (1 - alpha) * E + alpha * e_proposed
```

- The Triton `E` update kernels accept the loss-temperature-adjusted `alpha` computed in Python.
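
The grouped form of this rule can be traced by hand with a flat list of weights. The sketch below takes the three lines of the rule literally; `propose_E` and `ema_update` are hypothetical names, the group size of 4 is arbitrary, and the clamp to the int8 range is carried over from the CPU rule as an assumption.

```python
import math

def propose_E(grad_sign, T, group_size, eps=1e-10):
    """Per group: e_proposed = round(log2(mean(|sign(grad_w) * T|) + eps))."""
    proposals = []
    for start in range(0, len(T), group_size):
        g = grad_sign[start:start + group_size]
        t = T[start:start + group_size]
        mu_g = sum(abs(gs * tv) for gs, tv in zip(g, t)) / len(t)
        e = round(math.log2(mu_g + eps))
        proposals.append(max(-128, min(127, e)))   # int8 range (assumed clamp)
    return proposals

def ema_update(E, proposals, alpha):
    """E = (1 - alpha) * E + alpha * e_proposed, elementwise per group."""
    return [(1 - alpha) * e + alpha * p for e, p in zip(E, proposals)]

grad_sign = [1, -1, 0, 1,   0, 0, 0, 0]    # sign(grad_w), two groups of 4
T         = [1,  0, -1, 1,  1, -1, 0, 1]   # ternary weights
proposals = propose_E(grad_sign, T, group_size=4)
print(proposals)                            # [-1, -33]
print(ema_update([0.0, 0.0], proposals, alpha=0.1))
```

Two properties follow from the formula as written: since both factors are in {-1, 0, +1}, `mu_g` is at most 1 and the proposal is never positive, and a group with no active sign/weight overlap falls to the `eps` floor (about -33 for `eps = 1e-10`).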

#### New benchmark update backend: `gpu-signcache`

The direct GPU update path is memory-clean but too slow for larger `M = batch * sequence` shapes because it recomputes the gradient reduction inside tiny packed-byte kernels.

`gpu-signcache` uses this flow per module:

```text
grad_y, x
  -> Triton tiled sign-only grad kernel
  -> temporary int8 grad_sign[N, K]
  -> Triton EMA E update
  -> Triton T_accum/repack
  -> free grad_sign
```

This intentionally reintroduces a temporary int8 per-layer sign cache, not a persistent float weight gradient. It is the practical speed path for the benchmark while keeping persistent model state true ternary.
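
The first stage of that flow can be sketched without Triton. The code assumes a linear layer of the form `y = x @ W^T` with `W` of shape `[N, K]`, so the weight gradient is `grad_y^T @ x`; only its sign is kept. `grad_sign_cache` and `cache_bytes` are hypothetical names, and the real kernel is tiled rather than a nested loop.

```python
def sign(v):
    """-1, 0 or +1 as a plain int."""
    return (v > 0) - (v < 0)

def grad_sign_cache(grad_y, x):
    """sign(grad_y^T @ x) as an [N][K] table of -1/0/+1 (int8 in the real cache)."""
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    assert len(x) == M
    return [
        [sign(sum(grad_y[m][n] * x[m][k] for m in range(M))) for k in range(K)]
        for n in range(N)
    ]

def cache_bytes(N, K):
    """Temporary int8 cache size versus a float32 gradient of the same shape."""
    return {"int8_sign_cache": N * K, "float32_grad": 4 * N * K}

grad_y = [[1.0, -2.0], [0.5, 0.0]]             # M=2 rows, N=2 outputs
x      = [[1.0, 0.0, -1.0], [2.0, 0.0, 1.0]]   # M=2 rows, K=3 inputs
print(grad_sign_cache(grad_y, x))              # [[1, 0, -1], [-1, 0, 1]]
print(cache_bytes(N=2, K=3))
```

The cache is a quarter of the size of an FP32 gradient for the same layer and exists only for the duration of one module's update.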

### Verification

Focused CUDA tests:

```text
PASS test_cuda_triton_correctness_update_E
PASS test_cuda_triton_tscale_path
```

Strict small benchmark:

```text
python benchmark_true_ternary.py --configs TrueTernary --steps 4 --warmup 1 --batch 2 --ctx 10 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4

logical ternary weights: 14,011,904
trainable float params: 0
float buffers: 0
peak VRAM: ~30 MB
avg step after warmup: ~170 ms
```

Strict benchmark at previous benchmark shape:

```text
python benchmark_true_ternary.py --configs TrueTernary --steps 3 --warmup 1 --batch 8 --ctx 66 --strict-true-ternary --update-backend gpu-signcache --scale-update-interval 4

logical ternary weights: 14,011,904
ternary training state: 17.15 MB
trainable float params: 0
float buffers: 0
peak VRAM: ~430 MB
steady step: ~1.85 s
```

Previous direct backend at the same shape was roughly `49 s/step` after warmup. The benchmark now progresses; remaining speed work is kernel tuning and eventually replacing the temporary sign cache with a fused tiled kernel that is both fast and memory-minimal.
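
As a sanity check on the reported numbers, the following back-of-the-envelope script estimates the two largest pieces of ternary training state and the speedup. It assumes one int8 `T_accum` entry per weight, 5 trits per packed byte, and decimal megabytes; it is an estimate, not a measurement.

```python
import math

n_weights = 14_011_904          # logical ternary weights reported above

t_packed_bytes = math.ceil(n_weights / 5)   # 5 trits per uint8
t_accum_bytes = n_weights                   # one int8 accumulator per weight (assumed)
lower_bound_mb = (t_packed_bytes + t_accum_bytes) / 1e6

speedup = 49 / 1.85             # previous direct backend vs gpu-signcache

print(f"T_packed ~ {t_packed_bytes / 1e6:.2f} MB")
print(f"T_accum  ~ {t_accum_bytes / 1e6:.2f} MB")
print(f"T_packed + T_accum ~ {lower_bound_mb:.2f} MB (reported state: 17.15 MB)")
print(f"speedup ~ {speedup:.1f}x")
```

Under these assumptions `T_packed` plus `T_accum` comes to just under the reported 17.15 MB, which leaves a small remainder for the `E` buffers and any padding. The accumulator, not the packed weights, dominates the training state.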