TRUE TERNARY REFACTOR 17
Date: 2026-05-20
Goal
Fix the post-REFACTOR16 training OOM/regression, and compare the production training path against training/train_small_float.py, which had been reported as the small-scale learning reference.
Finding
training/train_small_float.py is not a ternary-weight trainer. It is a float32 baseline:
- nn.Embedding
- nn.Linear
- nn.RMSNorm
- torch.optim.AdamW
Local run:
python training/train_small_float.py --steps 200 --dim 256 --layers 4 --batch 16 --ctx 64 --device cuda
Model: 1.73M params
step 0: train=5.8205 val=5.5523
step 199: train=2.5854 val=2.6280
The useful parts to copy into the ternary system were not AdamW or float weights. They were stable fan-in-scaled initialization, simple byte supervision, and small update steps.
Changes
1. Restored Streaming Backward In Training
ARBModel.prepare_ternary_backward() was setting _stream_backward_updates = False, which disabled the REFACTOR16 streaming path before every training backward pass. This brought the old full weight-shaped hook tensors back into memory.
It now sets:
_stream_backward_updates = True
This keeps Triton linears accumulating directly into T_accum and E_accum.
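The difference between the two backward modes can be sketched in plain Python. This is a toy model, not the Triton kernel: it assumes a single linear layer, shows only T_accum (not E_accum), and assumes T_accum is an int8-range counter that receives -sign(dL/dW) on each backward pass. grad_full mirrors what the old hooks held in memory; stream_accumulate computes each weight-gradient element as a temporary scalar and folds it into the accumulator immediately.

```python
def sign(v: float) -> int:
    """Return -1, 0, or 1."""
    return (v > 0) - (v < 0)


def grad_full(grad_out: list[float], x: list[float]) -> list[list[float]]:
    """Materialize dL/dW = outer(grad_out, x) as a full out_dim x in_dim float matrix."""
    return [[g * xi for xi in x] for g in grad_out]


def stream_accumulate(
    t_accum: list[list[int]], grad_out: list[float], x: list[float]
) -> list[list[int]]:
    """Fold -sign(dL/dW) into t_accum element by element, saturating to int8 range.

    No weight-shaped float buffer is built; each gradient element lives only
    for one loop iteration.
    """
    for r, g in enumerate(grad_out):
        row = t_accum[r]
        for c, xi in enumerate(x):
            gw = g * xi  # dL/dW[r][c], discarded after this iteration
            row[c] = max(-128, min(127, row[c] - sign(gw)))
    return t_accum


grad_out = [1.0, -2.0]
x = [0.5, 0.0, -3.0]
acc = stream_accumulate([[0] * 3 for _ in range(2)], grad_out, x)
print(acc)  # [[-1, 0, 1], [1, 0, -1]]
```

Both paths produce the same accumulator update; only the streaming one avoids holding an out_dim x in_dim float tensor.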
2. Fixed Gradient Descent Direction For T Flips
The ternary T update was accumulating +sign(grad), which is gradient ascent for the ternary sign. It now accumulates -sign(grad) in:
- Triton direct accumulation
- Triton legacy ternary step
- dense TernaryScaleTensor.ternary_step and ByteEmbedding.ternary_step
- sparse TernaryEmbeddingTable rows
The CUDA test expectation was updated so a positive gradient moves the trit toward -1.
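A one-weight example illustrates why -sign(grad) is the descent direction. The toy squared-error loss, the fixed input, and the rule that a trit moves one unit per update are all assumptions made for illustration; in the real system the update goes through T_accum and a flip threshold rather than moving the trit directly.

```python
def sign(v: float) -> int:
    """Return -1, 0, or 1."""
    return (v > 0) - (v < 0)


def loss(w: float, x: float, y: float) -> float:
    """Squared error of a one-weight linear model (toy loss)."""
    return (w * x - y) ** 2


def grad_w(w: float, x: float, y: float) -> float:
    """dL/dw for the toy loss."""
    return 2.0 * (w * x - y) * x


def step_trit(trit: int, g: float, direction: int) -> int:
    """Move a trit one unit by direction * sign(g), clamped to {-1, 0, 1}."""
    return max(-1, min(1, trit + direction * sign(g)))


x, y, trit = 1.0, -1.0, 0
g = grad_w(trit, x, y)             # positive gradient: 2.0
descent = step_trit(trit, g, -1)   # -sign(grad): trit moves to -1
ascent = step_trit(trit, g, +1)    # +sign(grad): trit moves to 1
print(loss(descent, x, y), loss(trit, x, y), loss(ascent, x, y))  # 0.0 1.0 4.0
```

With a positive gradient, the -sign(grad) move lowers the loss and the old +sign(grad) move raises it, which matches the updated CUDA test expectation.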
3. Removed Persistent Float Component Accumulator
The componentwise loss path could create _T_accum_fp, a full float32 buffer shaped like the ternary weight. That violates the packed/int8 training-state rule and can OOM large layers.
Component loss accumulation now uses int8 sign accumulation only and clears any stale _T_accum_fp.
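Rough arithmetic shows why a persistent float32 accumulator is a problem. The sketch assumes a square 8192 x 8192 layer; the 8192 width appears elsewhere in these notes, but the square shape is a hypothetical choice for illustration.

```python
def buffer_mib(out_dim: int, in_dim: int, bytes_per_elem: int) -> float:
    """Size in MiB of one weight-shaped buffer."""
    return out_dim * in_dim * bytes_per_elem / 2**20


# Hypothetical 8192 x 8192 layer.
fp32_accum = buffer_mib(8192, 8192, 4)  # float32 buffer like _T_accum_fp
int8_accum = buffer_mib(8192, 8192, 1)  # int8 sign accumulator
print(fp32_accum, int8_accum)  # 256.0 64.0
```

Under that assumption, each such layer would carry an extra 256 MiB of float state on top of its int8 accumulator, four times the accumulator's own size.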
4. Large RMSNorm CUDA Fallback
RMSNorms over wide hidden dimensions were OOMing in the Triton backward kernel because it used BLOCK_D=next_power_of_2(dim), so the block size grew with the full row width. For wide norms, CUDA now uses the identity-weight RMS path:
x * rsqrt(mean(x*x) + eps)
This is correct for the current TernaryRMSNorm because its T/E state initializes to identity and is not trained.
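A minimal pure-Python sketch of the fallback, assuming eps=1e-6 (the real eps is not stated here). It also shows the BLOCK_D sizing with a stdlib stand-in for triton.next_power_of_2: a 5000-wide row, for example, would still get an 8192-wide block. The all-ones-weight case is exactly what the identity path computes, which is why dropping the weight is safe only while the norm's T/E state stays at identity.

```python
import math


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (for n >= 1), mirroring triton.next_power_of_2."""
    return 1 << (n - 1).bit_length()


def rms_norm_identity(x: list[float], eps: float = 1e-6) -> list[float]:
    """x * rsqrt(mean(x*x) + eps): RMSNorm with an implicit all-ones weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def rms_norm_weighted(x: list[float], weight: list[float], eps: float = 1e-6) -> list[float]:
    """General RMSNorm: per-channel weight applied after normalization."""
    return [w * v for w, v in zip(weight, rms_norm_identity(x, eps))]


print(next_power_of_2(5000), next_power_of_2(8192))  # 8192 8192
x = [3.0, 4.0]
print(rms_norm_identity(x) == rms_norm_weighted(x, [1.0, 1.0]))  # True
```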
5. Fan-In-Aware Ternary Initialization
The full system was initializing every ternary linear with std=0.1, independent of fan-in. For 8192-wide layers this made the initial logits much too large.
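The size of the problem follows from the variance of a sum. Assuming zero-mean, independent weights and unit-variance inputs, and treating the weights as floats before ternarization, a fan-in of n scales the output standard deviation by std * sqrt(n):

```python
import math


def preactivation_std(weight_std: float, fan_in: int, input_std: float = 1.0) -> float:
    """Std of sum_i w_i * x_i for independent zero-mean w_i and x_i.

    Var = fan_in * weight_std**2 * input_std**2.
    """
    return weight_std * math.sqrt(fan_in) * input_std


old = preactivation_std(0.1, 8192)                    # fixed std=0.1
new = preactivation_std(1.0 / math.sqrt(8192), 8192)  # std = 1/sqrt(in_dim)
print(round(old, 2), round(new, 2))  # 9.05 1.0
```

Under these assumptions a fixed std=0.1 amplifies activations roughly ninefold per 8192-wide layer at initialization, while 1/sqrt(in_dim) keeps them near unit scale.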
Default TernaryScaleTensor init now uses:
std = min(0.1, 1 / sqrt(in_dim))
threshold = min(config_threshold, 0.5 * std)
This keeps the ternary masks active while matching the float baseline's stable fan-in scaling. ByteEmbedding and dense TernaryEmbeddingTable now also derive their ternary threshold from the init std.
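A sketch of why the threshold rule keeps masks active, assuming weights are first drawn from N(0, std^2) and a trit is nonzero when |w| exceeds the threshold (both are assumptions about the init procedure). Whenever 0.5 * std is the smaller threshold, the ratio threshold/std is fixed at 0.5, so the expected nonzero fraction is the same at every fan-in. config_threshold=0.05 below is a hypothetical value, not the project default.

```python
import math


def ternary_init_params(in_dim: int, config_threshold: float) -> tuple[float, float]:
    """Fan-in-aware init: returns (std, threshold)."""
    std = min(0.1, 1.0 / math.sqrt(in_dim))
    threshold = min(config_threshold, 0.5 * std)
    return std, threshold


def active_fraction(std: float, threshold: float) -> float:
    """P(|w| > threshold) for w ~ N(0, std**2): the expected share of nonzero trits."""
    return math.erfc(threshold / (std * math.sqrt(2.0)))


for in_dim in (256, 8192):
    std, thr = ternary_init_params(in_dim, config_threshold=0.05)  # hypothetical config value
    print(in_dim, std, thr, active_fraction(std, thr))
```

Both rows report an active fraction of about 0.617, i.e. P(|Z| > 0.5) for a standard normal Z.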
6. Conservative T Step
_ternary_t_step() now returns 1. The old loss-scaled value could reach 4, which exceeded the default threshold of 3 in a single backward pass and caused a destructive mass flip immediately after a good initialization.
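A small simulation of one weight whose gradient keeps the same sign on every pass. It assumes an int8-range accumulator and that a flip fires once |T_accum| >= threshold; the exact comparison used by the real code is not stated here.

```python
from typing import Optional


def passes_until_flip(t_step: int, threshold: int, max_passes: int = 100) -> Optional[int]:
    """Backward passes with a same-sign gradient before |T_accum| reaches threshold."""
    accum = 0  # int8-range accumulator for a single weight
    for n in range(1, max_passes + 1):
        # Same gradient sign every pass, so only the step magnitude matters here.
        accum = max(-128, min(127, accum + t_step))
        if abs(accum) >= threshold:  # assumed flip condition
            return n
    return None


print(passes_until_flip(4, 3))  # 1: old loss-scaled step flips on the first pass
print(passes_until_flip(1, 3))  # 3: step 1 needs three agreeing passes
```

With a step of 4, any weight with a nonzero gradient sign reaches the threshold on the first pass, which is the mass flip described above; a step of 1 requires the gradient direction to persist first.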
7. Configurable MoEGraph Top-K
ARBModel no longer hard-codes MoEGraph top_k=4. It now uses MG_TOP_K from config, defaulting to 2, which reduces routed expert work while keeping multi-expert routing active.
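A sketch of config-driven top-k routing. The plain-dict config, the resolve_top_k helper, and softmax normalization over the selected experts are all assumptions (softmax over the chosen experts is a common gating choice, not confirmed for MoEGraph).

```python
import math


def resolve_top_k(config: dict) -> int:
    """Read MG_TOP_K from a dict-like config, defaulting to 2 (hypothetical helper)."""
    return int(config.get("MG_TOP_K", 2))


def top_k_route(scores: list[float], top_k: int) -> list[tuple[int, float]]:
    """Pick the top_k experts by router score and softmax-normalize their gates."""
    chosen = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
    m = max(scores[i] for i in chosen)  # subtract the max for numerical stability
    exps = [math.exp(scores[i] - m) for i in chosen]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(chosen, exps)]


routes = top_k_route([0.1, 2.0, -1.0, 1.5], resolve_top_k({}))
print([i for i, _ in routes])  # [1, 3]
```

If each selected expert runs once per token per routing iteration, going from top_k=4 to top_k=2 halves the routed expert evaluations while still mixing two experts.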
Validation
Passed:
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_small_ternary_training_loss_finite testing/test_tscale.py::test_e_per_component_routing testing/test_tscale.py::test_cuda_triton_tscale_path
Full model memory and loss probe, text-only, ctx=64, max_moe_iters=4, MG_TOP_K=2:
logical ternary weights: 3,122,925,280
ternary training state: 4071.0 MB
after cuda: alloc=4307.3MB reserved=4378.0MB
step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB
The important result is that allocated VRAM returns to the same post-step baseline (4355.6MB) after every step. The earlier behavior, where allocated memory kept stacking up from step to step, was not reproduced.
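The stacking check can be made mechanical. This sketch parses probe lines in the format shown in the memory probe log and tests that post-step allocation stays flat; the regex matches only that log format, and tol_mb is a hypothetical tolerance. In a live run, comparable numbers are available from torch.cuda.memory_allocated() and torch.cuda.max_memory_allocated().

```python
import re

# Matches lines such as "step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB".
PROBE_RE = re.compile(
    r"step (\d+) loss=([\d.]+) cleanup alloc=([\d.]+)MB peak=([\d.]+)MB"
)


def parse_probe(lines: list[str]) -> list[tuple[int, float, float, float]]:
    """Extract (step, loss, alloc_mb, peak_mb) from each matching log line."""
    rows = []
    for line in lines:
        m = PROBE_RE.search(line)
        if m:
            step, loss, alloc, peak = m.groups()
            rows.append((int(step), float(loss), float(alloc), float(peak)))
    return rows


def alloc_is_flat(rows: list[tuple[int, float, float, float]], tol_mb: float = 1.0) -> bool:
    """True if post-step allocated memory stays within tol_mb across all steps."""
    allocs = [alloc for _, _, alloc, _ in rows]
    if not allocs:
        return False  # nothing parsed, so nothing was verified
    return max(allocs) - min(allocs) <= tol_mb


log = [
    "step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB",
    "step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB",
    "step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB",
]
print(alloc_is_flat(parse_probe(log)))  # True
```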
Smoke test of the actual pretrain entrypoint:
python training/pretrain.py --steps 2 --batch 1 --ctx 64 --text-data /tmp/tinyshakespeare.txt --max-moe-iters 4 --no-save --log-interval 1 --eval-interval 0 --save-interval 0
logical ternary weights: 3,122,925,280
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
loss step 1: 7.5023
loss step 2: 7.4532
Current Status
Locally, the model now exceeds the requested target of 3.1B logical ternary weights:
3.122925B logical ternary weights
Persistent trainable state is still limited to packed trits and int8 accumulators/scales. The remaining practical issue is speed: the first step can still pay compile/setup cost, and MoEGraph top-k routing and depth are expensive at full size. In the local CUDA probes, the memory blocker and the immediate loss blow-up are fixed.