True Ternary Refactor 5 — Trainable Integer Scale Residuals

Problem

Strict true-ternary mode has no trainable PyTorch parameters by design. The model learns only by mutating ternary buffers:

  • T_packed: packed {-1, 0, +1} weights
  • T_accum: int8 sign-gradient accumulator for ternary flips
  • E: int8 logarithmic scale exponent

The previous logarithmic E update tried to behave like a smoothed float EMA:

E = (1 - alpha) * E + alpha * e_proposed

but E is stored as int8. Casting every update back to int8 destroys most sub-integer movement. For example, with alpha=0.1, E=5, and e_proposed=6, the blended value is 5.1, which casts back to 5, so E never moves. Small scale-learning signals therefore often round away before they can affect inference. That makes strict mode look like it is not training its scale field, and it leaves learning mostly dependent on coarse T_accum sign flips.
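The stall described here can be reproduced with a few lines of plain Python. This is a minimal sketch, not the project's update_E(): the helper name ema_int8_step is hypothetical, and int() stands in for the float-to-int8 cast on the assumption that the cast truncates toward zero.

```python
def ema_int8_step(e: int, e_proposed: int, alpha: float = 0.1) -> int:
    """One float EMA update followed by a cast back to an integer exponent.

    Hypothetical helper: int() truncates toward zero, which is assumed
    here to match the float -> int8 cast used by the old update path.
    """
    blended = (1.0 - alpha) * e + alpha * e_proposed
    return int(blended)


e = 5
history = []
for _ in range(20):
    # The proposal keeps asking for a larger exponent on every step.
    e = ema_int8_step(e, e_proposed=6)
    history.append(e)

# E is still 5 after 20 consistent proposals: each blend is 5.1,
# and the fractional 0.1 is discarded by the cast every time.
print(history)
```

Because the fractional part is never carried between steps, no number of consistent proposals one unit above E can move it under this scheme.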

Storing per-group Python floats is not viable. A Python float object per scale group would be far larger than the ternary weights. FP8 would work mechanically, but it reintroduces floating-point scale state and weakens the true-ternary constraint.
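A rough size comparison illustrates the point. The sketch below is only an estimate under stated assumptions: num_groups is a hypothetical group count, and sys.getsizeof reports the object size on the running interpreter (CPython on a 64-bit platform is assumed), without counting the pointer a container would also hold for each float.

```python
import sys
from array import array

num_groups = 1_000_000  # hypothetical number of scale groups

# One signed byte per group, as in an int8 buffer.
int8_bytes = array("b").itemsize * num_groups

# One boxed Python float object per group (container pointers excluded).
float_obj_bytes = sys.getsizeof(0.0) * num_groups

print(f"int8 residuals : {int8_bytes / 1e6:.2f} MB")
print(f"float objects  : {float_obj_bytes / 1e6:.2f} MB")
print(f"ratio          : {float_obj_bytes / int8_bytes:.0f}x")
```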

Decision

Keep E as int8 and add E_accum, an int8 residual accumulator per scale group.

This is fixed-point learning state:

E      : int8 stored log2 exponent used by inference
E_accum: int8 residual update energy used only during training

Update rule:

score_g = sum(sign(grad_w) * T) over group g
delta_g = -sign(score_g)
E_accum_g += delta_g

if E_accum_g >= threshold:
    E_g += 1
    E_accum_g -= threshold

if E_accum_g <= -threshold:
    E_g -= 1
    E_accum_g += threshold

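Default threshold is 4. Each update moves E_accum by at most one unit, so E changes only after four net same-direction updates. This behaves like a quarter-step scale learning rate without storing fractional floats.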

This mirrors the useful part of a per-group float accumulator while storing only one extra byte per scale group.
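The update rule above can be sketched for a single scale group in plain Python. This is an illustration rather than the code in tscale.py: update_group and sign are hypothetical names, and clamping E to the int8 range is an assumption that the rule as written does not state.

```python
INT8_MIN, INT8_MAX = -128, 127


def sign(v: int) -> int:
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def update_group(e, e_accum, grad_signs, t_values, threshold=4):
    """Apply one fixed-point residual update to one scale group.

    grad_signs: sign(grad_w) for each weight in the group
    t_values:   ternary weights {-1, 0, +1} for the same positions
    """
    score = sum(g * t for g, t in zip(grad_signs, t_values))
    e_accum += -sign(score)          # delta_g = -sign(score_g)
    if e_accum >= threshold:
        e = min(e + 1, INT8_MAX)     # clamp is an assumption
        e_accum -= threshold
    elif e_accum <= -threshold:
        e = max(e - 1, INT8_MIN)     # clamp is an assumption
        e_accum += threshold
    return e, e_accum


# Four identical steps whose score is negative, so delta is +1 each time.
grad_signs = [-1, -1, 1, 0]
t_values = [1, 1, -1, 0]

e, e_accum = 0, 0
trace = []
for _ in range(4):
    e, e_accum = update_group(e, e_accum, grad_signs, t_values)
    trace.append((e, e_accum))

print(trace)
```

With the default threshold, the first three steps change only E_accum; the fourth carries into E and resets the residual to zero.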

Why This Fits True Ternary

  • Inference still uses only T_packed and int8 E.
  • S is still not stored; it is derived as 2^E.
  • There are no FP32/BF16/FP8 master weights.
  • There is no persistent full-precision scale tensor.
  • Training state grows by 1 byte * num_E_groups, not by Python float objects or FP tensors.

For the current strict model, audit overhead increases by roughly the size of E:

E       ~= 1.12 MB
E_accum ~= 1.12 MB

For larger models this is still far cheaper than hidden full-precision weights or optimizer state.

Code Changes

tscale.py

  • TernaryScaleTensor now registers:
self.register_buffer("E_accum", torch.zeros_like(self.E, dtype=torch.int8))
  • Added _ensure_E_accum() for compatibility with older checkpoints that do not have E_accum.
  • CPU update_E() now uses fixed-point residual updates instead of float EMA truncation.
  • Triton _triton_update_e_kernel and _triton_update_e_direct_kernel now update both:
    • E
    • E_accum
  • Triton wrappers now pass E_accum and integer E_ACCUM_THRESHOLD.
  • tscale_to() resets E_accum when regrouping changes the shape of E.

trigram.py

  • ByteEmbedding now registers E_accum.
  • Added _ensure_E_accum().
  • ByteEmbedding.update_E() now uses the same fixed-point residual rule as TernaryScaleTensor.

ternary_audit.py

  • Audit now reports E_accum bytes separately in ternary training state.

benchmark_true_ternary.py

  • CPU fallback update and gpu-signcache update now use the fixed-point residual scale path.
  • Benchmark output reports both E and E_accum counts.

testing/test_tscale.py

  • CUDA E-update correctness now compares both E and E_accum.
  • The direct CUDA path test now accepts a residual-only update as valid; with threshold 4, a single update is expected to modify E_accum before E itself moves.

Expected Behavior

Strict mode should no longer rely on immediate whole-step exponent jumps or no-op float EMA truncation. Scale-learning evidence can accumulate across steps in E_accum and eventually move E by one integer log2 unit.

This does not solve all optimization quality issues. T is still discrete, E still changes in integer exponent steps, and the model can still be unstable if T_accum flips too many weights together. But it fixes the missing scale-learning mechanism without adding persistent floating-point model state.

LossComponent And Exact Weight Gradients

Important correction: strict mode has zero trainable PyTorch parameters. The existing pinpoint_backward() path only calls torch.autograd.grad() for parameter groups. With no parameter groups, it could skip backward entirely, which means the custom ternary autograd hooks never run.

Fix in train.py:

if not items:
    (loss_comps.total / grad_accum).backward()

This keeps strict true-ternary training connected to autograd even when there are no float parameters. The gradient captured by each TernaryScaleTensor is still exact per logical weight:

grad_sign[n, k] = sign(sum_m grad_y[m, n] * x[m, k])

That per-weight gradient sign drives:

  • T_accum[n, k] for exact ternary weight flips
  • grouped E_accum[n, group(k)] / E[n, group(k)] for scale updates

So the update is per exact weight first, then grouped only at the scale field. That matches the intended split:

  • T learns at individual weight resolution.
  • E learns at group resolution because scale is shared by the group.
  • LossComponent weights still affect the gradient field through loss_comps.total.
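The split between per-weight signs and per-group scores can be sketched in plain Python on nested lists. The function names are hypothetical, and the grouping is assumed to be contiguous runs of group_size columns along k; the real layout of group(k) may differ.

```python
def sign(v) -> int:
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def grad_sign_matrix(grad_y, x):
    """grad_y is M x N, x is M x K; returns the N x K matrix
    grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k])."""
    m_dim, n_dim, k_dim = len(x), len(grad_y[0]), len(x[0])
    return [
        [sign(sum(grad_y[m][n] * x[m][k] for m in range(m_dim)))
         for k in range(k_dim)]
        for n in range(n_dim)
    ]


def group_scores(grad_sign, t, group_size):
    """Reduce per-weight signs to one score per (row, group):
    score = sum(grad_sign * T) over each contiguous group of columns."""
    scores = []
    for gs_row, t_row in zip(grad_sign, t):
        row = []
        for start in range(0, len(t_row), group_size):
            stop = start + group_size
            row.append(sum(g * w for g, w in zip(gs_row[start:stop],
                                                 t_row[start:stop])))
        scores.append(row)
    return scores


# Tiny example: M=2 samples, N=1 output row, K=4 inputs, groups of 2.
grad_y = [[1.0], [-0.5]]
x = [[1.0, -2.0, 0.0, 0.5],
     [1.0, 1.0, 0.0, 2.0]]
t = [[1, 1, -1, -1]]

grad_sign = grad_sign_matrix(grad_y, x)     # drives T_accum per weight
scores = group_scores(grad_sign, t, 2)      # drives E_accum per group

print(grad_sign)
print(scores)
```

In this example the first group's two weights disagree and cancel to a score of zero, so only the second group would receive a scale residual update.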

Current limitation: in strict mode, component-specific routing does not separately update different ternary module groups, because there are no trainable parameter groups to target. It applies the weighted total loss to all ternary hooks. A future improvement would add ternary module groups to the LossComponent routing map so component losses can be backpropagated separately into selected ternary modules while still avoiding float parameters.

Next Checks

  1. Verify strict training mutates:
    • T_packed
    • T_accum
    • E_accum
    • E after enough scale updates
  2. Track loss with:
--strict_ternary --scale_update_interval 1

and compare against interval 4.

  3. If E still moves too slowly, reduce _e_accum_threshold to 2 for selected layers.
  4. If E moves too violently, raise _e_accum_threshold to 8.