True Ternary Refactor 5 — Trainable Integer Scale Residuals
Problem
Strict true-ternary mode has no trainable PyTorch parameters by design. The model learns only by mutating ternary buffers:
- T_packed: packed {-1, 0, +1} weights
- T_accum: int8 sign-gradient accumulator for ternary flips
- E: int8 logarithmic scale exponent
The previous logarithmic E update tried to behave like a smoothed float EMA:
E = (1 - alpha) * E + alpha * e_proposed
but E is stored as int8. Casting every update back to int8 destroys most sub-integer movement. With alpha=0.1, small scale-learning signals often round away before they can affect inference. That makes strict mode look like it is not training its scale field, and it leaves learning mostly dependent on coarse T_accum sign flips.
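The lost movement is easy to reproduce in isolation. The sketch below is illustrative only: the helper name ema_update_int8 is hypothetical, and it assumes the cast rounds to nearest. The source only says the result is cast back to int8; a truncating cast differs in detail but discards sub-integer movement in the same way.

```python
def ema_update_int8(E: int, e_proposed: int, alpha: float = 0.1) -> int:
    """One EMA step on an integer exponent, rounded back to an integer.

    Assumption: the cast rounds to nearest. The rounding mode of the
    real update is not specified in this note.
    """
    blended = (1.0 - alpha) * E + alpha * e_proposed
    return int(round(blended))


# A proposal two units above E never moves it, no matter how many
# times the same evidence is repeated: 0.9 * 5 + 0.1 * 7 = 5.2 -> 5.
E = 5
for _ in range(100):
    E = ema_update_int8(E, e_proposed=7)
print(E)
```

Under this assumption, with alpha=0.1 a proposal has to sit roughly five or more exponent units away from E before a single step changes the stored value.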
Storing per-group Python floats is not viable. A Python float object per scale group would be far larger than the ternary weights. FP8 would work mechanically, but it reintroduces floating-point scale state and weakens the true-ternary constraint.
Decision
Keep E as int8 and add E_accum, an int8 residual accumulator per scale group.
This is fixed-point learning state:
E       : int8 stored log2 exponent used by inference
E_accum : int8 residual update energy used only during training
Update rule:
score_g = sum(sign(grad_w) * T) over group g
delta_g = -sign(score_g)
E_accum_g += delta_g
if E_accum_g >= threshold:
    E_g += 1
    E_accum_g -= threshold
if E_accum_g <= -threshold:
    E_g -= 1
    E_accum_g += threshold
Default threshold is 4, which behaves like a quarter-step scale learning rate without storing fractional floats.
This mirrors the useful part of a per-group float accumulator while storing only one extra byte per scale group.
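The rule above can be sketched for a single scale group in plain Python. This is a minimal illustration, not the CPU or Triton implementation: the function name update_scale_group is hypothetical, the values are ordinary Python ints, and int8 saturation is not modeled.

```python
def update_scale_group(E: int, E_accum: int, score: int, threshold: int = 4):
    """Apply one fixed-point residual update to a single scale group.

    Returns the new (E, E_accum). Assumption: plain Python ints stand in
    for the int8 buffers, so overflow and saturation are ignored.
    """
    sign = (score > 0) - (score < 0)
    E_accum += -sign  # delta_g = -sign(score_g)
    if E_accum >= threshold:
        E += 1
        E_accum -= threshold
    if E_accum <= -threshold:
        E -= 1
        E_accum += threshold
    return E, E_accum


# Four consecutive steps with a negative score each add +1 to the
# residual; only the fourth step moves the stored exponent.
E, E_accum = 3, 0
history = []
for _ in range(4):
    E, E_accum = update_scale_group(E, E_accum, score=-2)
    history.append((E, E_accum))
print(history)
```

A zero score leaves both values unchanged, so groups with no net evidence in a step carry their residual forward untouched.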
Why This Fits True Ternary
- Inference still uses only T_packed and int8 E.
- S is still not stored; it is derived as 2^E.
- There are no FP32/BF16/FP8 master weights.
- There is no persistent full-precision scale tensor.
- Training state grows by 1 byte * num_E_groups, not by Python float objects or FP tensors.
For the current strict model, audit overhead increases by roughly the size of E:
E ~= 1.12 MB
E_accum ~= 1.12 MB
For larger models this is still far cheaper than hidden full-precision weights or optimizer state.
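The byte accounting can be checked with simple arithmetic. In the sketch below the helper name scale_state_bytes and the group count of 1,000,000 are hypothetical, chosen only to make the comparison concrete. The FP32 and Python-float rows show the alternatives this design avoids and are not part of the model.

```python
import sys


def scale_state_bytes(num_e_groups: int) -> dict:
    """Bytes of per-group scale state under different storage choices."""
    return {
        "E_int8": 1 * num_e_groups,
        "E_accum_int8": 1 * num_e_groups,
        # Comparison only: a persistent FP32 scale tensor.
        "fp32_scale_tensor": 4 * num_e_groups,
        # Comparison only: one Python float object per group, using the
        # per-object size reported by the running interpreter.
        "python_float_objects": sys.getsizeof(0.0) * num_e_groups,
    }


sizes = scale_state_bytes(1_000_000)  # hypothetical group count
for name, nbytes in sizes.items():
    print(f"{name:>22}: {nbytes / 1e6:.2f} MB")
```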
Code Changes
tscale.py
- TernaryScaleTensor now registers:
  self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
- Added _ensure_E_accum() for compatibility with older checkpoints that do not have E_accum.
- CPU update_E() now uses fixed-point residual updates instead of float EMA truncation.
- Triton _triton_update_e_kernel and _triton_update_e_direct_kernel now update both:
  - E
  - E_accum
- Triton wrappers now pass E_accum and integer E_ACCUM_THRESHOLD.
- tscale_to() resets E_accum when regrouping changes the shape of E.
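The checkpoint-compatibility and regrouping behavior described above can be sketched without PyTorch. This does not reproduce the real _ensure_E_accum() or tscale_to(). The function ensure_e_accum and the dict-of-lists state are hypothetical stand-ins for a module's buffers, used only to show when the residual is backfilled or reset.

```python
def ensure_e_accum(state: dict) -> dict:
    """Backfill a zeroed E_accum next to E when it is missing or stale.

    Assumption: `state` maps buffer names to flat lists of ints, one
    entry per scale group. The real code operates on int8 tensors.
    """
    E = state["E"]
    accum = state.get("E_accum")
    if accum is None or len(accum) != len(E):
        # Old checkpoint, or regrouping changed the number of groups:
        # start the residual from zero rather than reuse stale evidence.
        state["E_accum"] = [0] * len(E)
    return state


old_checkpoint = {"E": [-3, 0, 2]}
print(ensure_e_accum(old_checkpoint)["E_accum"])
```

Resetting to zero discards at most threshold - 1 steps of accumulated evidence per group, which seems an acceptable cost for keeping E and E_accum shape-consistent.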
trigram.py
- ByteEmbedding now registers E_accum.
- Added _ensure_E_accum().
- ByteEmbedding.update_E() now uses the same fixed-point residual rule as TernaryScaleTensor.
ternary_audit.py
- Audit now reports E_accum bytes separately in ternary training state.
benchmark_true_ternary.py
- CPU fallback update and gpu-signcache update now use the fixed-point residual scale path.
- Benchmark output reports both E and E_accum counts.
testing/test_tscale.py
- CUDA E-update correctness now compares both E and E_accum.
- The direct CUDA path test now accepts a residual-only update as valid; with threshold 4, a single update is expected to modify E_accum before E itself moves.
Expected Behavior
Strict mode should no longer rely on immediate whole-step exponent jumps or no-op float EMA truncation. Scale-learning evidence can accumulate across steps in E_accum and eventually move E by one integer log2 unit.
This does not solve all optimization quality issues. T is still discrete, E still changes in integer exponent steps, and the model can still be unstable if T_accum flips too many weights together. But it fixes the missing scale-learning mechanism without adding persistent floating-point model state.
LossComponent And Exact Weight Gradients
Important correction: strict mode has zero trainable PyTorch parameters. The existing pinpoint_backward() path only calls torch.autograd.grad() for parameter groups. With no parameter groups, it could skip backward entirely, which means the custom ternary autograd hooks never run.
Fix in train.py:
if not items:
    (loss_comps.total / grad_accum).backward()
This keeps strict true-ternary training connected to autograd even when there are no float parameters. The gradient captured by each TernaryScaleTensor is still exact per logical weight:
grad_sign[n, k] = sign(sum_m grad_y[m, n] * x[m, k])
That per-weight gradient sign drives:
- T_accum[n, k] for exact ternary weight flips
- grouped E_accum[n, group(k)] / E[n, group(k)] for scale updates
So the update is per exact weight first, then grouped only at the scale field. That matches the intended split:
- T learns at individual weight resolution.
- E learns at group resolution because scale is shared by the group.
- LossComponent weights still affect the gradient field through loss_comps.total.
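The per-weight-then-grouped split can be illustrated on nested lists. This is a sketch under stated assumptions, not the hook code: the function names are hypothetical, and groups are assumed to be contiguous runs of group_size columns along k, which this note does not specify.

```python
def sign(v) -> int:
    return (v > 0) - (v < 0)


def grad_sign(grad_y, x):
    """grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k])."""
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    return [
        [sign(sum(grad_y[m][n] * x[m][k] for m in range(M))) for k in range(K)]
        for n in range(N)
    ]


def group_scores(gs, T, group_size):
    """score[n][g] = sum of gs[n][k] * T[n][k] over the columns in group g.

    Assumption: group(k) = k // group_size (contiguous column groups).
    """
    scores = []
    for gs_row, t_row in zip(gs, T):
        products = [g * t for g, t in zip(gs_row, t_row)]
        scores.append(
            [sum(products[i:i + group_size]) for i in range(0, len(products), group_size)]
        )
    return scores


grad_y = [[1.0, -2.0], [0.5, 1.0]]                     # M=2 rows, N=2 outputs
x = [[1.0, 0.0, -1.0, 2.0], [2.0, 1.0, 1.0, -1.0]]     # M=2 rows, K=4 inputs
T = [[1, -1, 0, 1], [1, 1, -1, -1]]                    # ternary weights, N x K

gs = grad_sign(grad_y, x)        # drives T_accum per weight
scores = group_scores(gs, T, 2)  # drives E_accum per group
print(gs)
print(scores)
```

Each entry of gs feeds one T_accum cell, while each entry of scores feeds one E_accum cell, so the scale field sees only the group's net agreement between gradient sign and weight sign.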
Current limitation: in strict mode, component-specific routing does not separately update different ternary module groups, because there are no trainable parameter groups to target. It applies the weighted total loss to all ternary hooks. A future improvement would add ternary module groups to the LossComponent routing map so component losses can be backpropagated separately into selected ternary modules while still avoiding float parameters.
Next Checks
- Verify strict training mutates:
  - T_packed
  - T_accum
  - E_accum
  - E after enough scale updates
- Track loss with:
  --strict_ternary --scale_update_interval 1
  and compare against interval 4.
- If E still moves too slowly, reduce _e_accum_threshold to 2 for selected layers.
- If E moves too violently, raise _e_accum_threshold to 8.
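To get a feel for the threshold choices above before running a full training job, a toy simulation may help. The function simulate_E is hypothetical and feeds synthetic per-step deltas into the residual rule. It says nothing about real loss curves, only about how quickly E can move for a given threshold.

```python
def simulate_E(deltas, threshold):
    """Run the residual rule over a sequence of per-step deltas (+1, -1, 0).

    Returns (final E, number of whole-unit moves), starting from E = 0.
    """
    E, accum, moves = 0, 0, 0
    for delta in deltas:
        accum += delta
        if accum >= threshold:
            E, accum, moves = E + 1, accum - threshold, moves + 1
        if accum <= -threshold:
            E, accum, moves = E - 1, accum + threshold, moves + 1
    return E, moves


consistent = [+1] * 16       # evidence always points the same way
alternating = [+1, -1] * 8   # evidence cancels from step to step

for threshold in (2, 4, 8):
    print(threshold, simulate_E(consistent, threshold), simulate_E(alternating, threshold))
```

With consistent evidence, E moves once every threshold steps, so halving the threshold doubles the effective scale learning rate. Alternating evidence cancels in the residual and never moves E at any of these thresholds.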