TRUE TERNARY REFACTOR 17

Date: 2026-05-20

Goal

Fix the post-REFACTOR16 training OOM/regression and compare the production training path against training/train_small_float.py, which was reported as the small reference model that learns.

Finding

training/train_small_float.py is not a ternary-weight trainer. It is a float32 baseline built from standard PyTorch components:

  • nn.Embedding
  • nn.Linear
  • nn.RMSNorm
  • torch.optim.AdamW

Local run:

python training/train_small_float.py --steps 200 --dim 256 --layers 4 --batch 16 --ctx 64 --device cuda
Model: 1.73M params
step 0:   train=5.8205 val=5.5523
step 199: train=2.5854 val=2.6280

The useful parts to copy into the ternary system were not AdamW or float weights. They were stable fan-in-scaled initialization, simple byte supervision, and small update steps.

Changes

1. Restored Streaming Backward In Training

ARBModel.prepare_ternary_backward() was setting _stream_backward_updates = False, which disabled the REFACTOR16 streaming path before every training backward pass. That forced the old, full weight-shaped hook tensors back into memory.

It now sets:

_stream_backward_updates = True

This keeps Triton linears accumulating directly into T_accum and E_accum.

2. Fixed Gradient Descent Direction For T Flips

The ternary T update was accumulating +sign(grad), which is gradient ascent for the ternary sign. It now accumulates -sign(grad) in:

  • Triton direct accumulation
  • Triton legacy ternary step
  • dense TernaryScaleTensor.ternary_step
  • ByteEmbedding.ternary_step
  • sparse TernaryEmbeddingTable rows

The CUDA test expectation was updated so a positive gradient moves the trit toward -1.
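
The direction fix can be illustrated with a minimal pure-Python sketch. The function name accumulate_sign and the list-based accumulator are hypothetical stand-ins for the real T_accum buffers; only the -sign(grad) rule comes from the change described above.

```python
def sign(x: float) -> int:
    """Return -1, 0, or +1 according to the sign of x."""
    return (x > 0) - (x < 0)


def accumulate_sign(t_accum: list[int], grad: list[float]) -> list[int]:
    """Add -sign(grad) to each accumulator entry (the descent direction)."""
    return [a - sign(g) for a, g in zip(t_accum, grad)]


# A positive gradient pushes the accumulator, and eventually the trit, toward -1.
accum = accumulate_sign([0, 0, 0], [0.7, -0.2, 0.0])
print(accum)  # [-1, 1, 0]
```

Accumulating +sign(grad) instead would flip every nonzero entry of that result, which is the ascent behavior the fix removes.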

3. Removed Persistent Float Component Accumulator

The componentwise loss path could create _T_accum_fp, a full float32 buffer shaped like the ternary weight. That violates the packed/int8 training-state rule and can OOM large layers.

Component loss accumulation now uses int8 sign accumulation only and clears any stale _T_accum_fp.
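
To make the memory argument concrete, here is a minimal sketch using only the standard library. The helper accumulate_int8 and the saturating clamp are assumptions for illustration; the source only states that accumulation is int8 sign accumulation with no persistent float32 buffer.

```python
from array import array

INT8_MIN, INT8_MAX = -128, 127


def accumulate_int8(t_accum: array, grad_signs: list[int]) -> None:
    """Subtract each gradient sign in place, clamping to the int8 range (assumed)."""
    for i, s in enumerate(grad_signs):
        t_accum[i] = max(INT8_MIN, min(INT8_MAX, t_accum[i] - s))


t_accum = array("b", [0, 127, -128])  # "b" stores signed 8-bit integers
accumulate_int8(t_accum, [1, -1, 1])
print(list(t_accum))  # [-1, 127, -128]

# One byte per element for int8 versus four for a float32 buffer of the same shape.
print(t_accum.itemsize, array("f").itemsize)
```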

4. Large RMSNorm CUDA Fallback

Large hidden RMSNorms were OOMing in Triton backward because the kernel used BLOCK_D=next_power_of_2(dim). For wide norms, CUDA now uses the identity-weight RMS path:

x * rsqrt(mean(x*x) + eps)

This is correct for the current TernaryRMSNorm because its T/E state initializes as identity and is not trained.
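
The identity-weight formula above can be checked with a small pure-Python sketch over a single vector. The function name rms_normalize and the default eps of 1e-6 are assumptions; the real path runs on CUDA tensors.

```python
import math


def rms_normalize(x: list[float], eps: float = 1e-6) -> list[float]:
    """Compute x * rsqrt(mean(x*x) + eps) with no learned weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


y = rms_normalize([3.0, 4.0])
# After normalization the mean square of the output is approximately 1.
print(sum(v * v for v in y) / len(y))
```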

5. Fan-In-Aware Ternary Initialization

The full system was initializing every ternary linear with std=0.1, independent of fan-in. At 8192-wide layers this made the initial logits much too large.

Default TernaryScaleTensor init now uses:

std = min(0.1, 1 / sqrt(in_dim))
threshold = min(config_threshold, 0.5 * std)

This keeps ternary masks active while matching the float baseline's stable fan-in scaling. ByteEmbedding and dense TernaryEmbeddingTable now also scale their ternary threshold from init std.
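
As a worked example of the two formulas above, the following sketch evaluates them for a few layer widths. The function name ternary_init_params and the config_threshold value of 0.05 are hypothetical; the actual configured threshold is not stated here.

```python
import math


def ternary_init_params(in_dim: int, config_threshold: float) -> tuple[float, float]:
    """Return (std, threshold) using the fan-in-aware rule described above."""
    std = min(0.1, 1.0 / math.sqrt(in_dim))
    threshold = min(config_threshold, 0.5 * std)
    return std, threshold


for in_dim in (64, 256, 8192):
    # config_threshold=0.05 is an illustrative value, not the project's default.
    std, threshold = ternary_init_params(in_dim, config_threshold=0.05)
    print(f"in_dim={in_dim:5d} std={std:.5f} threshold={threshold:.5f}")
```

For in_dim=64 the 0.1 cap applies, since 1/sqrt(64) = 0.125. For in_dim=8192 the std falls to about 0.011, roughly nine times smaller than the old fixed 0.1.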

6. Conservative T Step

_ternary_t_step() now returns 1. The old loss-scaled value could be 4, which crossed the default threshold of 3 in a single backward pass and caused a destructive mass flip immediately after good initialization.
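
The effect of the step size can be sketched by counting how many same-sign backward passes it takes for an accumulator to reach the threshold. This sketch assumes a flip triggers once |accum| >= threshold; the exact comparison used in the real code is not stated here, and passes_until_flip is a hypothetical helper.

```python
def passes_until_flip(step: int, threshold: int = 3) -> int:
    """Count consecutive same-sign passes before |accum| reaches threshold."""
    accum = 0
    passes = 0
    while abs(accum) < threshold:  # assumed flip condition: |accum| >= threshold
        accum -= step  # positive gradient: accumulate -step
        passes += 1
    return passes


print(passes_until_flip(4))  # 1
print(passes_until_flip(1))  # 3
```

With a step of 1, a weight flips only after several passes agree on the direction, so a single noisy gradient can no longer flip it.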

7. Configurable MoEGraph Top-K

ARBModel no longer hard-codes MoEGraph top_k=4. It now uses MG_TOP_K from config, defaulting to 2, which reduces routed expert work while keeping multi-expert routing active.
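
A config lookup with a default of 2 might look like the sketch below. Only the MG_TOP_K name and its default come from the text above; resolve_top_k, the num_experts argument, and the range check are hypothetical additions for illustration.

```python
from types import SimpleNamespace


def resolve_top_k(config, num_experts: int) -> int:
    """Read MG_TOP_K from config, falling back to 2 when it is absent."""
    top_k = getattr(config, "MG_TOP_K", 2)
    # Hypothetical sanity check: top-k cannot exceed the number of experts.
    if not 1 <= top_k <= num_experts:
        raise ValueError(f"MG_TOP_K={top_k} must be in [1, {num_experts}]")
    return top_k


print(resolve_top_k(SimpleNamespace(), num_experts=8))            # 2
print(resolve_top_k(SimpleNamespace(MG_TOP_K=4), num_experts=8))  # 4
```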

Validation

Passed:

python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_small_ternary_training_loss_finite testing/test_tscale.py::test_e_per_component_routing testing/test_tscale.py::test_cuda_triton_tscale_path

Full model memory and loss probe, text-only, ctx=64, max_moe_iters=4, MG_TOP_K=2:

logical ternary weights: 3,122,925,280
ternary training state: 4071.0 MB
after cuda: alloc=4307.3MB reserved=4378.0MB

step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB

The important result is that allocated VRAM returns to the same post-step baseline (4355.6 MB) after every step. The previous behavior, where allocation stacked up by a constant amount each step, was not reproduced.
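
The same check can be automated against the probe output. The sketch below parses the three log lines shown above and reports the spread in post-step allocation; post_step_allocs is a hypothetical helper and the regular expression assumes exactly this log format.

```python
import re

LOG = """\
step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB
"""


def post_step_allocs(log: str) -> list[float]:
    """Extract the allocated MB reported after each step's cleanup."""
    return [float(m) for m in re.findall(r"cleanup alloc=([\d.]+)MB", log)]


allocs = post_step_allocs(LOG)
growth = max(allocs) - min(allocs)
print(allocs)  # [4355.6, 4355.6, 4355.6]
print(growth)  # 0.0
```

A growth value that increases with the number of steps would indicate the stacking behavior has returned.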

Actual pretrain entrypoint smoke:

python training/pretrain.py --steps 2 --batch 1 --ctx 64 --text-data /tmp/tinyshakespeare.txt --max-moe-iters 4 --no-save --log-interval 1 --eval-interval 0 --save-interval 0

logical ternary weights: 3,122,925,280
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
loss step 1: 7.5023
loss step 2: 7.4532

Current Status

The local model now exceeds the requested target of 3.1B logical ternary weights:

3.122925B logical ternary weights

Persistent trainable state remains packed trits and int8 accumulators/scales. The remaining practical issue is speed: the first step can still pay compile/setup cost, and MoEGraph top-k/depth is expensive at full size. The memory blocker and immediate loss blow-up are fixed in the local CUDA probes.