# True Ternary Refactor 5 — Trainable Integer Scale Residuals
## Problem
Strict true-ternary mode has no trainable PyTorch parameters by design. The model learns only by mutating ternary buffers:
- `T_packed`: packed {-1, 0, +1} weights
- `T_accum`: int8 sign-gradient accumulator for ternary flips
- `E`: int8 logarithmic scale exponent
The previous logarithmic `E` update tried to behave like a smoothed float EMA:
```text
E = (1 - alpha) * E + alpha * e_proposed
```
but `E` is stored as `int8`. Casting every update back to `int8` destroys most sub-integer movement. With `alpha=0.1`, small scale-learning signals often round away before they can affect inference. That makes strict mode look like it is not training its scale field, and it leaves learning mostly dependent on coarse `T_accum` sign flips.
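The stall can be reproduced with plain integers. The sketch below assumes the cast truncates toward zero, as Python's `int()` does; `ema_int8_step` is a hypothetical helper for illustration, not the project's code.

```python
def ema_int8_step(E: int, e_proposed: int, alpha: float = 0.1) -> int:
    """One float-EMA update whose result is cast straight back to an integer.

    Assumption: the cast truncates toward zero, like Python's int().
    """
    blended = (1.0 - alpha) * E + alpha * e_proposed
    return int(blended)


E = 5
for _ in range(100):
    # Every proposal asks for one step up, but 0.9 * 5 + 0.1 * 6 = 5.1
    # truncates back to 5, so the stored exponent never changes.
    E = ema_int8_step(E, e_proposed=6)

print(E)  # 5
```

The 0.1 of progress made by each update is discarded by the cast, so no amount of repetition helps.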
Storing per-group Python floats is not viable. A Python float object per scale group would be far larger than the ternary weights. FP8 would work mechanically, but it reintroduces floating-point scale state and weakens the true-ternary constraint.
## Decision
Keep `E` as int8 and add `E_accum`, an int8 residual accumulator per scale group.
This is fixed-point learning state:
```text
E       : int8  stored log2 exponent, used by inference
E_accum : int8  residual update energy, used only during training
```
Update rule:
```text
score_g = sum(sign(grad_w) * T) over group g
delta_g = -sign(score_g)
E_accum_g += delta_g
if E_accum_g >= threshold:
    E_g += 1
    E_accum_g -= threshold
if E_accum_g <= -threshold:
    E_g -= 1
    E_accum_g += threshold
```
The default threshold is `4`: four net votes in the same direction are needed to move `E` by one step. This behaves like a quarter-step scale learning rate without storing fractional floats.
This mirrors the useful part of a per-group float accumulator while storing only one extra byte per scale group.
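As a minimal sketch, the update rule above can be written for a single scale group in plain Python. The function name `update_scale_group` is hypothetical, and the clamp to the int8 range is an assumption added here; the rule as stated does not say how saturation is handled.

```python
INT8_MIN, INT8_MAX = -128, 127


def sign(v: int) -> int:
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def update_scale_group(E: int, E_accum: int, score: int, threshold: int = 4):
    """Apply one fixed-point residual update to a single scale group."""
    E_accum += -sign(score)  # delta_g = -sign(score_g)
    if E_accum >= threshold:
        E = min(E + 1, INT8_MAX)  # clamp is an assumption, not from the note
        E_accum -= threshold
    if E_accum <= -threshold:
        E = max(E - 1, INT8_MIN)  # clamp is an assumption, not from the note
        E_accum += threshold
    return E, E_accum


E, acc = 0, 0
history = []
for _ in range(4):
    # A negative score votes for a larger exponent (delta = +1).
    E, acc = update_scale_group(E, acc, score=-3)
    history.append((E, acc))

print(history)  # [(0, 1), (0, 2), (0, 3), (1, 0)]
```

The first three updates change only the residual; the fourth crosses the threshold and moves `E`.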
## Why This Fits True Ternary
- Inference still uses only `T_packed` and int8 `E`.
- `S` is still not stored; it is derived as `2^E`.
- There are no FP32/BF16/FP8 master weights.
- There is no persistent full-precision scale tensor.
- Training state grows by `1 byte * num_E_groups`, not by Python float objects or FP tensors.
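To illustrate the second point, here is a small sketch of deriving the scale on the fly from the stored exponent. The helper names are hypothetical, and a flat Python list stands in for one unpacked ternary group.

```python
import math


def derived_scale(E: int) -> float:
    """Compute S = 2^E exactly; powers of two are exactly representable."""
    return math.ldexp(1.0, E)


def effective_weights(T_group, E: int):
    """Dequantize one group of ternary values with its shared exponent."""
    scale = derived_scale(E)
    return [t * scale for t in T_group]


print(effective_weights([-1, 0, 1, 1], E=-2))  # [-0.25, 0.0, 0.25, 0.25]
```

Because the scale is recomputed from `E` whenever it is needed, nothing beyond the int8 exponent has to persist.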
For the current strict model, audit overhead increases by roughly the size of `E`:
```text
E ~= 1.12 MB
E_accum ~= 1.12 MB
```
For larger models this is still far cheaper than hidden full-precision weights or optimizer state.
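The arithmetic can be checked with a short script. The group count below is an illustrative assumption, back-derived from the ~1.12 MB figure by treating 1 MB as 10^6 bytes; `scale_state_bytes` is a hypothetical helper.

```python
import sys


def scale_state_bytes(num_groups: int) -> dict:
    """Bytes needed for per-group scale state under different storage choices."""
    return {
        "E (int8)": num_groups,
        "E + E_accum (int8)": 2 * num_groups,
        "float32 tensor": 4 * num_groups,
        # Size of one float object on this interpreter, excluding any container.
        "Python float objects": sys.getsizeof(0.0) * num_groups,
    }


num_groups = 1_120_000  # illustrative assumption, not a measured value
for name, nbytes in scale_state_bytes(num_groups).items():
    print(f"{name:>22}: {nbytes / 1e6:.2f} MB")
```

Even with the residual included, the int8 pair uses half the bytes of a float32 scale tensor with one value per group.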
## Code Changes
### `tscale.py`
- `TernaryScaleTensor` now registers:
```python
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
```
- Added `_ensure_E_accum()` for compatibility with older checkpoints that do not have `E_accum`.
- CPU `update_E()` now uses fixed-point residual updates instead of float EMA truncation.
- Triton `_triton_update_e_kernel` and `_triton_update_e_direct_kernel` now update both:
  - `E`
  - `E_accum`
- Triton wrappers now pass `E_accum` and integer `E_ACCUM_THRESHOLD`.
- `tscale_to()` resets `E_accum` when regrouping changes the shape of `E`.
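The checkpoint-compatibility and regrouping behavior could look roughly like the following. This is a torch-free sketch over a plain dict of lists; the name `ensure_e_accum` and the dict layout are hypothetical, and the real `_ensure_E_accum()` works on module buffers.

```python
def ensure_e_accum(state: dict) -> dict:
    """Guarantee that `E_accum` exists and has one entry per entry of `E`.

    A missing residual (older checkpoint) or a mismatched one (after
    regrouping) is replaced with zeros, since stale residuals would not
    line up with the new scale groups.
    """
    E = state["E"]
    E_accum = state.get("E_accum")
    if E_accum is None or len(E_accum) != len(E):
        state["E_accum"] = [0] * len(E)
    return state


old_checkpoint = {"E": [3, -1, 0]}
print(ensure_e_accum(old_checkpoint)["E_accum"])  # [0, 0, 0]

regrouped = {"E": [3, -1], "E_accum": [2, -3, 1]}
print(ensure_e_accum(regrouped)["E_accum"])  # [0, 0]
```

A residual that already matches `E` is left untouched, so resuming from a current checkpoint loses no accumulated evidence.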
### `trigram.py`
- `ByteEmbedding` now registers `E_accum`.
- Added `_ensure_E_accum()`.
- `ByteEmbedding.update_E()` now uses the same fixed-point residual rule as `TernaryScaleTensor`.
### `ternary_audit.py`
- Audit now reports `E_accum` bytes separately in ternary training state.
### `benchmark_true_ternary.py`
- CPU fallback update and `gpu-signcache` update now use the fixed-point residual scale path.
- Benchmark output reports both `E` and `E_accum` counts.
### `testing/test_tscale.py`
- CUDA E-update correctness now compares both `E` and `E_accum`.
- The direct CUDA path test now accepts a residual-only update as valid; with threshold `4`, a single update is expected to modify `E_accum` before `E` itself moves.
## Expected Behavior
Strict mode should no longer rely on immediate whole-step exponent jumps or no-op float EMA truncation. Scale-learning evidence can accumulate across steps in `E_accum` and eventually move `E` by one integer log2 unit.
This does not solve all optimization quality issues. `T` is still discrete, `E` still changes in integer exponent steps, and the model can still be unstable if `T_accum` flips too many weights together. But it fixes the missing scale-learning mechanism without adding persistent floating-point model state.
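The accumulation behavior can be illustrated with a toy vote stream in which three of every four steps agree. The sketch below is self-contained; `count_exponent_moves` is a hypothetical helper that omits int8 clamping, and the vote pattern is made up for illustration.

```python
def count_exponent_moves(deltas, threshold: int):
    """Run the residual rule over a stream of +1/-1 votes for one group.

    Returns (final E, final residual, number of times E changed).
    """
    E, acc, moves = 0, 0, 0
    for d in deltas:
        acc += d
        if acc >= threshold:
            E += 1
            acc -= threshold
            moves += 1
        elif acc <= -threshold:
            E -= 1
            acc += threshold
            moves += 1
    return E, acc, moves


# 24 noisy votes with a net drift of +12.
deltas = [+1, +1, -1, +1] * 6

for threshold in (2, 4, 8):
    print(threshold, count_exponent_moves(deltas, threshold))
# 2 (6, 0, 6)
# 4 (3, 0, 3)
# 8 (1, 4, 1)
```

Dissenting votes cancel inside the residual instead of flipping `E` back and forth, and a larger threshold trades responsiveness for stability.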
## LossComponent And Exact Weight Gradients
Important correction: strict mode has zero trainable PyTorch parameters. The existing `pinpoint_backward()` path only calls `torch.autograd.grad()` for parameter groups. With no parameter groups, it could skip backward entirely, which means the custom ternary autograd hooks never run.
Fix in `train.py`:
```python
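if not items:
    # No parameter groups to target: backpropagate the weighted total loss
    # so the custom ternary autograd hooks still run.
    (loss_comps.total / grad_accum).backward()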
```
This keeps strict true-ternary training connected to autograd even when there are no float parameters. The gradient captured by each `TernaryScaleTensor` is still exact per logical weight:
```text
grad_sign[n, k] = sign(sum_m grad_y[m, n] * x[m, k])
```
That per-weight gradient sign drives:
- `T_accum[n, k]` for exact ternary weight flips
- grouped `E_accum[n, group(k)]` / `E[n, group(k)]` for scale updates
So the update is per exact weight first, then grouped only at the scale field. That matches the intended split:
- `T` learns at individual weight resolution.
- `E` learns at group resolution because scale is shared by the group.
- LossComponent weights still affect the gradient field through `loss_comps.total`.
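The split between per-weight signs and per-group scale scores can be sketched in plain Python with nested lists. The function names and the contiguous grouping along `k` are assumptions for illustration; the actual kernels operate on packed tensors.

```python
def sign(v: float) -> int:
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def grad_sign(grad_y, x):
    """grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k])."""
    M, N, K = len(x), len(grad_y[0]), len(x[0])
    return [
        [sign(sum(grad_y[m][n] * x[m][k] for m in range(M))) for k in range(K)]
        for n in range(N)
    ]


def group_scores(gsign, T, group_size: int):
    """score[n][g] = sum over k in group g of gsign[n][k] * T[n][k]."""
    K = len(T[0])
    return [
        [
            sum(gs_row[k] * t_row[k] for k in range(start, min(start + group_size, K)))
            for start in range(0, K, group_size)
        ]
        for gs_row, t_row in zip(gsign, T)
    ]


grad_y = [[1.0, -2.0], [0.5, 1.0]]                  # shape (M=2, N=2)
x = [[1.0, 0.0, -1.0, 2.0], [2.0, 1.0, 1.0, -1.0]]  # shape (M=2, K=4)
T = [[1, 0, -1, 1], [-1, 1, 0, -1]]                 # shape (N=2, K=4)

gsign = grad_sign(grad_y, x)
print(gsign)                                # [[1, 1, -1, 1], [0, 1, 1, -1]]
print(group_scores(gsign, T, group_size=2)) # [[1, 2], [1, 1]]
```

Each entry of `gsign` would feed `T_accum` directly, while only the grouped sums reach the scale update.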
Current limitation: in strict mode, component-specific routing does not separately update different ternary module groups, because there are no trainable parameter groups to target. It applies the weighted total loss to all ternary hooks. A future improvement would add ternary module groups to the LossComponent routing map so component losses can be backpropagated separately into selected ternary modules while still avoiding float parameters.
## Next Checks
1. Verify strict training mutates:
   - `T_packed`
   - `T_accum`
   - `E_accum`
   - `E` after enough scale updates
2. Track loss with:
   ```text
   --strict_ternary --scale_update_interval 1
   ```
   and compare against interval `4`.
3. If `E` still moves too slowly, reduce `_e_accum_threshold` to `2` for selected layers.
4. If `E` moves too violently, raise `_e_accum_threshold` to `8`.
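For the first check, one possible approach is to snapshot the buffers before a training step and compare afterwards. The sketch below uses plain Python containers as stand-ins, and `changed_buffers` is a hypothetical helper; with real tensors the comparison would instead use cloned buffers and something like `torch.equal`.

```python
def changed_buffers(before: dict, after: dict) -> list:
    """Return the sorted names of buffers whose contents differ."""
    return sorted(name for name in before if before[name] != after.get(name))


# Made-up snapshots around a single training step.
before = {
    "T_packed": bytes([0b01100100]),
    "T_accum": [0, 0, 0, 0],
    "E_accum": [0],
    "E": [3],
}
after = {
    "T_packed": bytes([0b01100100]),
    "T_accum": [1, -1, 0, 1],
    "E_accum": [1],
    "E": [3],
}

print(changed_buffers(before, after))  # ['E_accum', 'T_accum']
```

After one step only the accumulators are expected to differ; `T_packed` and `E` should appear in the list once enough updates have crossed their thresholds.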