| # TRUE TERNARY REFACTOR 16 |
|
|
| Date: 2026-05-20 |
|
|
| ## Goal |
|
|
| Fix the training memory stacking reported after REFACTOR15 while preserving the rule that ternary trainable components keep persistent state as packed trits or int8/int16 buffers, not fp16/fp32 weights. |
|
|
| ## Finding |
|
|
| The remaining OOM pattern was not caused by a new fp32 master weight. It came from retained backward hooks: |
|
|
| - Triton linear backward retained a full `out_dim x in_dim` int8 grad-sign tensor per ternary layer until `_ternary_update_memory()`. |
| - The shared VQ table is currently `131,072 x 1024`. That size was below the old sparse threshold, so the table retained a dense, table-sized grad-sign hook. |
|
|
| Those tensors are only int8, but they accumulate across ternary layers during backward and can add anywhere from hundreds of MB to multiple GB before cleanup runs. |
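For a sense of scale, the size of one retained int8 hook follows directly from its shape, at one byte per element. The sketch below works this out for the shared VQ table shape from the config. The square `7200 x 7200` linear layer is only an assumed example shape built from `HIDDEN_DIM`, not a layer taken from the model.

```python
def int8_hook_bytes(rows: int, cols: int) -> int:
    """Bytes held by a dense int8 tensor of shape (rows, cols)."""
    return rows * cols  # int8 stores one byte per element

MIB = 1024 ** 2

# Shared VQ table shape from the current config: 131,072 x 1024.
vq_dense = int8_hook_bytes(131_072, 1024)
print(f"dense VQ grad-sign hook: {vq_dense / MIB:.0f} MiB")  # 128 MiB

# Assumed example: a square linear layer at HIDDEN_DIM = 7200.
linear = int8_hook_bytes(7200, 7200)
print(f"one 7200 x 7200 linear hook: {linear / MIB:.1f} MiB")
```

A few dozen linear hooks of that order, plus the dense VQ hook, are enough to reach the range described above.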
|
|
| ## Changes |
|
|
| ### 1. Streaming CUDA Ternary Updates |
|
|
| Triton linear backward now streams updates directly into existing module buffers: |
|
|
| - `T_accum`: int8 |
| - `E_accum`: int8 |
|
|
| It does not retain `_hook_grad_T_sign`, `_hook_grad_2d`, or `_hook_x_2d` for Triton linears. `_ternary_update_memory()` now finalizes streamed state by applying: |
|
|
| - int8 residual exponent steps from `E_accum` |
| - packed-trit flips from `T_accum` |
|
|
| The packed flip finalizer uses a Triton kernel, so it does not unpack full large matrices just to flip thresholded trits. |
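The streaming idea can be sketched in plain Python. This is an illustration only, not the Triton kernel: flat lists stand in for packed matrices, and the saturating add, the threshold value, and the flip rule (step one trit level against the accumulated gradient sign) are all assumptions made for the example.

```python
INT8_MIN, INT8_MAX = -128, 127

def sat_add_int8(a: int, b: int) -> int:
    """Add with saturation so the accumulator stays in int8 range (assumed behavior)."""
    return max(INT8_MIN, min(INT8_MAX, a + b))

def stream_grad_sign(t_accum, grad_signs):
    """Fold one backward pass of grad signs (-1/0/+1) into the accumulator in place.

    Nothing weight-shaped is kept after this call returns.
    """
    for i, s in enumerate(grad_signs):
        t_accum[i] = sat_add_int8(t_accum[i], s)

def finalize_flips(trits, t_accum, threshold):
    """Step trits whose accumulator reached the threshold, then clear those slots."""
    for i, acc in enumerate(t_accum):
        if abs(acc) >= threshold:
            step = -1 if acc > 0 else 1  # move against the gradient sign
            trits[i] = max(-1, min(1, trits[i] + step))
            t_accum[i] = 0

trits = [0, 1, -1, 0]
t_accum = [0, 0, 0, 0]
for signs in ([1, 1, -1, 0], [1, -1, -1, 0], [1, 1, -1, 1]):
    stream_grad_sign(t_accum, signs)
finalize_flips(trits, t_accum, threshold=3)
print(trits, t_accum)  # [-1, 1, 0, 0] [0, 1, 0, 1]
```

The point of the pattern is that only the persistent `T_accum`-style buffer grows in information, never in size, between backward and the finalize call.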
|
|
| ### 2. Pre-Backward Loss Scaling |
|
|
| Added `ARBModel.prepare_ternary_backward(loss_signal, update_scales=True)`. |
|
|
| Pure ternary training scripts now call this immediately before `loss.backward()` so the streaming path can use the same loss-scaled ternary step without retaining weight-shaped gradients. |
|
|
| Updated: |
|
|
| - `training/pretrain.py` |
| - `training/text.py` |
| - `training/audio.py` |
| - `training/vision.py` |
| - `training/diffusion.py` |
| - `training/smoke_50.py` |
|
|
| `training/smoke_50.py` no longer uses AdamW; it freezes float parameters and uses `_ternary_update_memory()`. |
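The call order described in this section can be shown with recording stand-ins. `RecordingModel` and `RecordingLoss` are hypothetical stubs, not the real `ARBModel` or a tensor, and passing the scalar loss value as `loss_signal` is an assumption about that argument.

```python
class RecordingModel:
    """Stand-in that only records call order; not the real ARBModel."""

    def __init__(self):
        self.calls = []

    def prepare_ternary_backward(self, loss_signal, update_scales=True):
        self.calls.append("prepare_ternary_backward")

    def _ternary_update_memory(self):
        self.calls.append("_ternary_update_memory")


class RecordingLoss:
    """Stand-in for a loss tensor; backward() just logs itself."""

    def __init__(self, model, value):
        self.model = model
        self.value = value

    def backward(self):
        self.model.calls.append("loss.backward")


def ternary_train_step(model, loss):
    # Assumption: loss_signal is the scalar loss value; the real argument may differ.
    model.prepare_ternary_backward(loss.value, update_scales=True)
    loss.backward()                 # streaming path consumes the prepared scale
    model._ternary_update_memory()  # finalize streamed state, no optimizer step


model = RecordingModel()
ternary_train_step(model, RecordingLoss(model, 2.5))
print(model.calls)
```

There is no optimizer in the step, which mirrors the `smoke_50.py` change: float parameters are frozen and the ternary finalizer is the only update.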
|
|
| ### 3. Sparse Threshold Lowered for VQ Tables |
|
|
| `TernaryEmbeddingTable.sparse_threshold` is now `65,536`. |
|
|
| This makes the current `131,072 x 1024` shared VQ table use sparse candidate-row hooks instead of a dense full-table hook. The targeted check showed: |
|
|
| ```text |
| dense_hooks [] |
| sparse_hook_shapes torch.Size([512, 1024]) |
| ``` |
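A minimal sketch of the dense-versus-sparse decision and what it saves is shown below. The comparison rule (at least `sparse_threshold` rows selects the sparse path) is an assumption, and `EXAMPLE_OLD_THRESHOLD` is a placeholder above 131,072 because the actual old value is not stated here.

```python
SPARSE_THRESHOLD = 65_536        # new TernaryEmbeddingTable.sparse_threshold
EXAMPLE_OLD_THRESHOLD = 262_144  # placeholder only; the real old value is not given

def hook_kind(num_rows: int, threshold: int = SPARSE_THRESHOLD) -> str:
    # Assumption: tables with at least `threshold` rows take the sparse path.
    return "sparse" if num_rows >= threshold else "dense"

def int8_bytes(rows: int, cols: int) -> int:
    return rows * cols  # one byte per int8 element

rows, dim = 131_072, 1024
print(hook_kind(rows, EXAMPLE_OLD_THRESHOLD))  # dense
print(hook_kind(rows))                         # sparse

dense_bytes = int8_bytes(rows, dim)   # full-table hook
sparse_bytes = int8_bytes(512, dim)   # 512 candidate rows, as in the targeted check
print(dense_bytes // sparse_bytes)    # 256
```

With 512 touched rows the sparse hook is 0.5 MiB instead of 128 MiB for the dense table-sized hook.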
|
|
| For the full model trace, shared VQ retained only the touched sparse candidate rows: |
|
|
| ```text |
| bridge.vq.table _hook_sparse_grad_sign (2560, 1024) int8 |
| ``` |
|
|
| ## Full Model Trace |
|
|
| Local full default model run with `max_moe_iters=1`, text-only, batch 1, sequence 12: |
|
|
| ```text |
| logical ternary weights: 2,216,567,968 |
| ternary training state: 2896.13 MB |
| trainable float params: 0 tensors, 0.00 MB |
| float buffers: 0 tensors, 0.00 MB |
| |
| after model.cuda: alloc=3198MB reserved=3219MB peak=3198MB |
| step 0 after forward: alloc=3273MB reserved=3364MB peak=3325MB |
| step 0 after backward: alloc=3280MB reserved=3364MB peak=3325MB |
| step 0 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB |
| step 1 after forward: alloc=3314MB reserved=3448MB peak=3389MB |
| step 1 after backward: alloc=3280MB reserved=3448MB peak=3389MB |
| step 1 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB |
| ``` |
|
|
| The key result: allocated memory returns to the same post-cleanup value (3249 MB) after both steps, so retained backward hooks no longer stack in this local trace. |
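The same check can be automated against a saved trace. The sketch below parses the post-cleanup lines from the trace above and flags stacking if the allocated value differs between steps. The helper name and the exact log line format it expects are assumptions based on the lines shown.

```python
import re

# Post-cleanup lines copied from the trace above.
TRACE = """\
step 0 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
step 1 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
"""

def post_cleanup_alloc(trace):
    """Return the allocated MB reported after cleanup for each step."""
    pattern = re.compile(r"step \d+ after cleanup:\s+alloc=(\d+)MB")
    return [int(m.group(1)) for m in pattern.finditer(trace)]

allocs = post_cleanup_alloc(TRACE)
stacked = len(set(allocs)) > 1  # any drift between steps means state is retained
print(allocs, "stacking" if stacked else "stable")  # [3249, 3249] stable
```

Two steps is a small sample, so a longer run with the same check would give more confidence that nothing accumulates slowly.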
|
|
| Note that the config used in the audited run does not produce a 3B model: |
|
|
| ```text |
| CODEBOOK_DIM = 1024 |
| SHARED_VQ_SIZE = 131072 |
| HIDDEN_DIM = 7200 |
| KGVQ_CODEBOOK_SIZE = 5_000_000 |
| KGVQ_CODEBOOK_DIM = 64 |
| ``` |
|
|
| This produced `2.216B` logical ternary weights locally. |
|
|
| ## Validation |
|
|
| Passed: |
|
|
| ```bash |
| python -m compileall -q arbitor training testing |
| python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_cuda_triton_tscale_path |
| python -m pytest -q testing/kg/test_composite_head.py testing/test_gradient_capture.py testing/test_tscale.py::test_cuda_triton_tscale_path |
| PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python training/pretrain.py --steps 1 --batch 1 --ctx 12 --text-data training/data/tinyshakespeare.txt --no-save --log-interval 1 --eval-interval 0 --save-interval 0 --max-moe-iters 1 |
| ``` |
|
|
| The pretrain smoke completed on CUDA with the full default model. |
|
|
| ## Remaining Risk |
|
|
| The full dense ternary layers still initialize from temporary float random tensors before being packed. That is not persistent training state, but it can make cold construction slower and more memory-hungry on CPU. A packed-first initializer for dense `TernaryScaleTensor` would be the next cleanup if model construction itself becomes the bottleneck. |
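As a rough illustration of what a packed-first initializer could look like, the sketch below samples trits directly into a byte buffer without ever building a float tensor. Everything here is hypothetical: the 2-bit-per-trit layout (four trits per byte), the uniform distribution, and the function names are assumptions, and the project's real packing format may differ.

```python
import random

def random_packed_trits(n_trits, seed=0):
    """Sample trits in {-1, 0, +1} and pack four per byte (assumed 2-bit layout)."""
    rng = random.Random(seed)
    packed = bytearray((n_trits + 3) // 4)
    for i in range(n_trits):
        code = rng.randrange(3)  # codes 0, 1, 2 stand for -1, 0, +1
        packed[i // 4] |= code << (2 * (i % 4))
    return packed

def unpack_trits(packed, n_trits):
    """Inverse of the layout above, used here only to verify the round trip."""
    return [((packed[i // 4] >> (2 * (i % 4))) & 0b11) - 1 for i in range(n_trits)]

packed = random_packed_trits(10)
trits = unpack_trits(packed, 10)
print(len(packed), len(trits))  # 3 10
```

Peak memory during construction would then be bounded by the packed buffer itself rather than by a temporary float matrix several times its size.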
|
|