TRUE TERNARY REFACTOR 16
Date: 2026-05-20
Goal
Fix the training memory stacking reported after REFACTOR15 while preserving the rule that ternary trainable components keep persistent state as packed trits or int8/int16 buffers, not fp16/fp32 weights.
Finding
The remaining OOM pattern was not caused by a new fp32 master weight. It came from retained backward hooks:
- Triton linear backward retained a full out_dim x in_dim int8 grad-sign tensor per ternary layer until _ternary_update_memory().
- The current shared VQ config is 131,072 x 1024, which was below the old sparse threshold, so it retained a dense, table-sized grad-sign hook.
Those tensors are int8, but they still stack up during backward and can add anywhere from hundreds of MB to gigabytes before cleanup.
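To put numbers on the stacking, here is a rough sizing helper. An int8 grad-sign hook costs one byte per weight, so a dense hook on the 131,072 x 1024 VQ table alone holds 128 MiB. The 7200 x 7200 linear shape is hypothetical, borrowed from HIDDEN_DIM in the config listed further down for illustration only. The real per-layer shapes are not listed in this report.

```python
from __future__ import annotations

MIB = 2 ** 20

def int8_hook_bytes(rows: int, cols: int) -> int:
    """Bytes held by one dense int8 grad-sign tensor of shape (rows, cols)."""
    return rows * cols  # int8 is 1 byte per element

def stacked_hook_mib(shapes: list[tuple[int, int]]) -> float:
    """MiB held if every listed hook is alive at the same time before cleanup."""
    return sum(int8_hook_bytes(r, c) for r, c in shapes) / MIB

vq_table = (131_072, 1024)  # shared VQ table from this report
linear = (7200, 7200)       # hypothetical HIDDEN_DIM x HIDDEN_DIM layer

print(stacked_hook_mib([vq_table]))            # 128.0
print(round(stacked_hook_mib([linear] * 20)))  # 989 for 20 hypothetical layers
```

Even a modest number of such layers, all alive between backward and cleanup, reaches the gigabyte range.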
Changes
1. Streaming CUDA Ternary Updates
Triton linear backward now streams updates directly into existing module buffers:
- T_accum: int8
- E_accum: int8
It does not retain _hook_grad_T_sign, _hook_grad_2d, or _hook_x_2d for Triton linears. _ternary_update_memory() now finalizes streamed state by applying:
- int8 residual exponent steps from E_accum
- packed-trit flips from T_accum
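The stream-then-finalize split can be modeled in a few lines. This is a minimal pure-Python sketch, not the Triton kernels. It assumes that each backward adds the sign of the gradient into a saturating int8 accumulator and that finalization steps a trit against the accumulated sign once it crosses a threshold. The threshold, the saturation at -128/127, and the reset-to-zero policy are illustrative assumptions. The E_accum exponent path is omitted.

```python
INT8_MIN, INT8_MAX = -128, 127

def sign(x: float) -> int:
    return (x > 0) - (x < 0)

def stream_backward(t_accum: list, grad: list) -> None:
    """Fold one backward's gradient signs into the int8 accumulator in place.

    Only the sign is added (with int8 saturation); the gradient is not kept.
    """
    for i, g in enumerate(grad):
        t_accum[i] = max(INT8_MIN, min(INT8_MAX, t_accum[i] + sign(g)))

def finalize(trits: list, t_accum: list, threshold: int) -> int:
    """Step trits whose accumulated sign crossed the threshold; return flip count.

    Descent direction (assumed): a positive accumulated gradient sign pushes
    the trit down, clamped to {-1, 0, 1}.
    """
    flips = 0
    for i, acc in enumerate(t_accum):
        if abs(acc) >= threshold:
            new = max(-1, min(1, trits[i] - sign(acc)))
            if new != trits[i]:
                flips += 1
            trits[i] = new
            t_accum[i] = 0  # hypothetical policy: reset once consumed
    return flips

weights = [0, 1, -1, 0]  # current trits
t_accum = [0, 0, 0, 0]   # int8 accumulator, one entry per weight
for _ in range(3):       # three backward passes, no gradient retained
    stream_backward(t_accum, [0.5, -0.2, 0.1, 0.0])
flipped = finalize(weights, t_accum, threshold=3)
# weights -> [-1, 1, -1, 0]: only index 0 moves; indices 1 and 2 are already at the clamp
```

The only weight-shaped state that survives between backward and finalize is the int8 accumulator, which matches the rule that persistent ternary state stays in int8 or packed form.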
The packed flip finalizer uses a Triton kernel, so it does not have to unpack large matrices in full just to flip the trits that crossed the threshold.
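A packed trit can be changed by arithmetic on the single byte that holds it, leaving every other byte untouched. This report does not specify the packing layout. The sketch below therefore assumes base-3 packing of five trits per byte (3^5 = 243 fits in a byte), with trit values -1/0/+1 stored as digits 0/1/2. The helper names pack, get_trit, and set_trit are hypothetical.

```python
TRITS_PER_BYTE = 5  # assumption: base-3 packing, 3**5 = 243 <= 256
POW3 = [3 ** k for k in range(TRITS_PER_BYTE)]

def pack(trits: list) -> bytearray:
    """Pack trits in {-1, 0, 1} into bytes, five per byte (little-endian digits)."""
    out = bytearray((len(trits) + TRITS_PER_BYTE - 1) // TRITS_PER_BYTE)
    for i, t in enumerate(trits):
        out[i // TRITS_PER_BYTE] += (t + 1) * POW3[i % TRITS_PER_BYTE]
    return out

def get_trit(packed: bytearray, i: int) -> int:
    b, k = divmod(i, TRITS_PER_BYTE)
    return (packed[b] // POW3[k]) % 3 - 1

def set_trit(packed: bytearray, i: int, value: int) -> None:
    """Overwrite one trit in place by editing only the byte that holds it."""
    b, k = divmod(i, TRITS_PER_BYTE)
    old_digit = (packed[b] // POW3[k]) % 3
    packed[b] += (value + 1 - old_digit) * POW3[k]

w = pack([1, -1, 0, 0, 1, -1])  # 6 trits -> 2 bytes
set_trit(w, 1, 1)               # flip trit 1 from -1 to +1; byte 1 is never touched
```

A GPU finalizer would apply the same per-byte update in parallel over the thresholded positions. Whether the actual Triton kernel uses this layout is not stated here.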
2. Pre-Backward Loss Scaling
Added ARBModel.prepare_ternary_backward(loss_signal, update_scales=True).
Pure ternary training scripts now call this immediately before loss.backward() so the streaming path can use the same loss-scaled ternary step without retaining weight-shaped gradients.
Updated:
- training/pretrain.py
- training/text.py
- training/audio.py
- training/vision.py
- training/diffusion.py
- training/smoke_50.py
training/smoke_50.py no longer uses AdamW; it freezes float parameters and uses _ternary_update_memory().
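The call ordering can be illustrated with a stand-in model. Everything here is hypothetical except the two method names (prepare_ternary_backward and _ternary_update_memory), including what loss_signal holds. The stub records calls and refuses to stream during backward unless the loss scale was prepared first. In a streaming design, the scale has to be known before loss.backward() runs, which is why the call must come first.

```python
from __future__ import annotations

class StubTernaryModel:
    """Hypothetical stand-in that records the call order of one ternary step."""

    def __init__(self) -> None:
        self.calls: list = []
        self._loss_scale = None

    def prepare_ternary_backward(self, loss_signal: float, update_scales: bool = True) -> None:
        # Streaming hooks read this scale during backward, so it must be set first.
        # update_scales is accepted for signature parity and unused in this stub.
        self._loss_scale = loss_signal
        self.calls.append("prepare")

    def _stream_hook(self) -> None:
        if self._loss_scale is None:
            raise RuntimeError("prepare_ternary_backward() must run before backward")
        self.calls.append("stream")

    def _ternary_update_memory(self) -> None:
        self.calls.append("finalize")
        self._loss_scale = None  # hypothetical: the scale is per step

class StubLoss:
    def __init__(self, model: StubTernaryModel, value: float) -> None:
        self.model, self.value = model, value

    def backward(self) -> None:
        # Stands in for Triton backward streaming into T_accum/E_accum.
        self.model._stream_hook()

def train_step(model: StubTernaryModel, loss: StubLoss) -> None:
    # No optimizer: float params are frozen and only ternary state is updated.
    model.prepare_ternary_backward(loss.value, update_scales=True)
    loss.backward()
    model._ternary_update_memory()

model = StubTernaryModel()
train_step(model, StubLoss(model, value=2.5))
print(model.calls)  # ['prepare', 'stream', 'finalize']
```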
3. Sparse Threshold Lowered for VQ Tables
TernaryEmbeddingTable.sparse_threshold is now 65,536.
This makes the current 131,072 x 1024 shared VQ table use sparse candidate-row hooks instead of a dense full-table hook. The targeted check showed:
dense_hooks []
sparse_hook_shapes torch.Size([512, 1024])
For the full model trace, shared VQ retained only the touched sparse candidate rows:
bridge.vq.table _hook_sparse_grad_sign (2560, 1024) int8
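The effect of the lower threshold is easiest to see in bytes. This sketch makes two assumptions about TernaryEmbeddingTable: the threshold is compared against the table's row count, and a sparse hook holds one int8 row per unique touched candidate row. hook_plan and hook_mib are hypothetical helpers. The old threshold value is not given in this report, so the dense case uses an arbitrary threshold of 131,072.

```python
from __future__ import annotations

SPARSE_THRESHOLD = 65_536  # new TernaryEmbeddingTable.sparse_threshold

def hook_plan(num_rows: int, dim: int, touched_rows: list,
              threshold: int = SPARSE_THRESHOLD) -> tuple:
    """Return (mode, hook shape) for one int8 grad-sign hook."""
    if num_rows > threshold:
        unique_rows = sorted(set(touched_rows))  # candidate rows seen this step
        return "sparse", (len(unique_rows), dim)
    return "dense", (num_rows, dim)

def hook_mib(shape: tuple) -> float:
    return shape[0] * shape[1] / 2 ** 20  # int8: 1 byte per element

mode, shape = hook_plan(131_072, 1024, touched_rows=list(range(2560)))
print(mode, shape, hook_mib(shape))  # sparse (2560, 1024) 2.5
print(hook_mib((131_072, 1024)))     # 128.0 for a dense full-table hook
```

Under these assumptions, the 2560 retained rows in the full-model trace cost about 2.5 MiB instead of 128 MiB.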
Full Model Trace
Local full default model run with max_moe_iters=1, text-only, batch 1, sequence length 12:
logical ternary weights: 2,216,567,968
ternary training state: 2896.13 MB
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
after model.cuda: alloc=3198MB reserved=3219MB peak=3198MB
step 0 after forward: alloc=3273MB reserved=3364MB peak=3325MB
step 0 after backward: alloc=3280MB reserved=3364MB peak=3325MB
step 0 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
step 1 after forward: alloc=3314MB reserved=3448MB peak=3389MB
step 1 after backward: alloc=3280MB reserved=3448MB peak=3389MB
step 1 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
The important result: allocated memory returned to the same post-cleanup value (3249 MB) after both steps, so backward hooks no longer stack in this local trace.
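The same check can be automated so that a longer run fails loudly if stacking returns. Below is a small parser for the trace format shown above. The helper names and the tolerance parameter are hypothetical, and the sample input is the step 0/1 trace from this run.

```python
from __future__ import annotations
import re

TRACE = """\
step 0 after forward: alloc=3273MB reserved=3364MB peak=3325MB
step 0 after backward: alloc=3280MB reserved=3364MB peak=3325MB
step 0 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
step 1 after forward: alloc=3314MB reserved=3448MB peak=3389MB
step 1 after backward: alloc=3280MB reserved=3448MB peak=3389MB
step 1 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
"""

LINE = re.compile(r"step (\d+) after (\w+): alloc=(\d+)MB reserved=(\d+)MB peak=(\d+)MB")

def post_cleanup_alloc(trace: str) -> dict:
    """Map step -> allocated MB after cleanup."""
    return {int(m.group(1)): int(m.group(3))
            for m in LINE.finditer(trace) if m.group(2) == "cleanup"}

def no_stacking(trace: str, tolerance_mb: int = 0) -> bool:
    """True if post-cleanup allocation never grows by more than tolerance_mb per step."""
    values = [v for _, v in sorted(post_cleanup_alloc(trace).items())]
    return all(b - a <= tolerance_mb for a, b in zip(values, values[1:]))

print(post_cleanup_alloc(TRACE))  # {0: 3249, 1: 3249}
print(no_stacking(TRACE))         # True
```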
The audited run does not use a 3B configuration. The current config is:
CODEBOOK_DIM = 1024
SHARED_VQ_SIZE = 131072
HIDDEN_DIM = 7200
KGVQ_CODEBOOK_SIZE = 5_000_000
KGVQ_CODEBOOK_DIM = 64
This produced 2.216B logical ternary weights locally.
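Here is a quick breakdown of the 2.216B figure, plus a bytes-per-weight check against the fp16 baseline that the Goal rules out. The breakdown assumes both codebook tables are counted as ternary weights, which the report does not state explicitly. It also assumes the reported "MB" means MiB. With decimal MB, the ratio comes out to about 1.31, still below fp16's 2 bytes.

```python
MIB = 2 ** 20

LOGICAL_TERNARY_WEIGHTS = 2_216_567_968  # reported for this run
TERNARY_STATE_MB = 2896.13               # reported; assumed to mean MiB

shared_vq = 131_072 * 1024   # SHARED_VQ_SIZE x CODEBOOK_DIM
kgvq = 5_000_000 * 64        # KGVQ_CODEBOOK_SIZE x KGVQ_CODEBOOK_DIM
rest = LOGICAL_TERNARY_WEIGHTS - shared_vq - kgvq

bytes_per_weight = TERNARY_STATE_MB * MIB / LOGICAL_TERNARY_WEIGHTS

print(f"shared VQ table: {shared_vq:,}")  # 134,217,728
print(f"KGVQ codebook:   {kgvq:,}")       # 320,000,000
print(f"everything else: {rest:,}")       # 1,762,350,240
print(f"state per logical weight: {bytes_per_weight:.2f} bytes (fp16 would be 2.00)")
```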
Validation
Passed:
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_cuda_triton_tscale_path
python -m pytest -q testing/kg/test_composite_head.py testing/test_gradient_capture.py testing/test_tscale.py::test_cuda_triton_tscale_path
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python training/pretrain.py --steps 1 --batch 1 --ctx 12 --text-data training/data/tinyshakespeare.txt --no-save --log-interval 1 --eval-interval 0 --save-interval 0 --max-moe-iters 1
The pretrain smoke completed on CUDA with the full default model.
Remaining Risk
The full dense ternary layers still initialize from temporary float random tensors before being packed. That is not persistent training state, but it can make cold construction slower and more memory-hungry on CPU. A packed-first initializer for dense TernaryScaleTensor would be the next cleanup if model construction itself becomes the bottleneck.
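A packed-first initializer could sample trits directly into packed bytes, so no float tensor of the full shape ever exists. The sketch below is hypothetical in several ways. It assumes a five-trits-per-byte base-3 layout (the actual TernaryScaleTensor layout is not given here) and a configurable zero probability p_zero. It also uses Python's random.Random as a stand-in for whatever RNG a real device-side initializer would use. Beyond the packed buffer itself, it only holds one byte's worth of working state at a time.

```python
from __future__ import annotations
import random

TRITS_PER_BYTE = 5  # assumption: base-3 packing, 3**5 = 243 <= 256

def packed_random_trits(n: int, p_zero: float = 1 / 3, seed=None) -> bytearray:
    """Sample n trits in {-1, 0, 1} straight into packed bytes, with no float buffer.

    Each trit is 0 with probability p_zero, otherwise +1 or -1 with equal odds.
    Digits are stored little-endian as t + 1 in {0, 1, 2}.
    """
    rng = random.Random(seed)
    out = bytearray((n + TRITS_PER_BYTE - 1) // TRITS_PER_BYTE)
    for b in range(len(out)):
        value, scale = 0, 1
        for _ in range(min(TRITS_PER_BYTE, n - b * TRITS_PER_BYTE)):
            if rng.random() < p_zero:
                t = 0
            else:
                t = 1 if rng.random() < 0.5 else -1
            value += (t + 1) * scale
            scale *= 3
        out[b] = value
    return out

def unpack(packed: bytearray, n: int) -> list:
    """Decode for checking only; a trainer would keep the packed form."""
    return [(packed[i // TRITS_PER_BYTE] // 3 ** (i % TRITS_PER_BYTE)) % 3 - 1
            for i in range(n)]

packed = packed_random_trits(1024 * 5, seed=0)  # 5120 trits -> 1024 bytes
```

A production version would run this per row, or in a kernel, rather than in a Python loop. The point of the sketch is only that initialization never needs a float copy of the weights.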