# TRUE TERNARY REFACTOR 16
Date: 2026-05-20
## Goal
Fix the training-memory stacking reported after REFACTOR15 (per-layer tensors piling up during backward) while preserving the rule that trainable ternary components keep their persistent state as packed trits or int8/int16 buffers, not fp16/fp32 weights.
## Finding
The remaining OOM pattern was not caused by a new fp32 master weight. It came from retained backward hooks:
- Triton linear backward retained a full `out_dim x in_dim` int8 grad-sign tensor per ternary layer until `_ternary_update_memory()` ran.
- The current shared VQ table is `131,072 x 1024`. That size was below the old sparse threshold, so the table retained a dense, table-sized grad-sign hook.
These tensors are int8, but they still stack up during backward and can add anywhere from hundreds of MB to GB scale before cleanup.
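A quick size estimate for these retained tensors, assuming one byte per int8 element and ignoring allocator rounding. The `7200 x 7200` layer is a hypothetical shape built from `HIDDEN_DIM`, not a layer confirmed in the model. The upper bound uses the 2,216,567,968 logical ternary weights reported for the audited run.
```python
# Size of dense int8 grad-sign tensors retained by backward hooks before this refactor.
# Assumes 1 byte per int8 element and ignores allocator block rounding.
MIB = 1024 * 1024
def int8_hook_bytes(out_dim: int, in_dim: int) -> int:
    """Bytes held by one dense out_dim x in_dim int8 grad-sign tensor."""
    return out_dim * in_dim
# Hypothetical square layer using HIDDEN_DIM = 7200 (shape is illustrative only).
layer_bytes = int8_hook_bytes(7200, 7200)
# Upper bound: every logical ternary weight in the audited run retained at once.
LOGICAL_TERNARY_WEIGHTS = 2_216_567_968
upper_bound_bytes = LOGICAL_TERNARY_WEIGHTS  # one int8 byte per weight
print(f"one 7200 x 7200 hook: {layer_bytes / MIB:.1f} MiB")
print(f"all-layer upper bound: {upper_bound_bytes / MIB:.0f} MiB")
```
A single hidden-sized layer already holds tens of MiB. If every layer kept its hook alive until cleanup, the worst case would be about 2 GiB.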
## Changes
### 1. Streaming CUDA Ternary Updates
Triton linear backward now streams updates directly into existing module buffers:
- `T_accum`: int8
- `E_accum`: int8
It does not retain `_hook_grad_T_sign`, `_hook_grad_2d`, or `_hook_x_2d` for Triton linears. `_ternary_update_memory()` now finalizes streamed state by applying:
- int8 residual exponent steps from `E_accum`
- packed-trit flips from `T_accum`
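The streaming idea can be sketched in plain Python for the trit (`T_accum`) path only. `E_accum` exponent steps are omitted. This is an illustration under hypothetical assumptions, not the Triton implementation: the accumulator stores the descent direction `-sign(grad)` and saturates at the int8 range, and a trit moves one level once `|T_accum|` reaches a made-up `FLIP_THRESHOLD` of 4.
```python
# Streaming ternary update sketch (T_accum path only; E_accum omitted).
# HYPOTHETICAL: descent-direction accumulation, int8 saturation, FLIP_THRESHOLD = 4.
INT8_MIN, INT8_MAX = -128, 127
FLIP_THRESHOLD = 4  # hypothetical threshold, not the ARBS value
def sign(x: float) -> int:
    return (x > 0) - (x < 0)
def stream_backward(t_accum: list[int], grad: list[float]) -> None:
    """Fold one backward pass into t_accum in place; grad is not retained."""
    for i, g in enumerate(grad):
        # Add the descent direction (-sign(grad)), saturating at the int8 range.
        t_accum[i] = max(INT8_MIN, min(INT8_MAX, t_accum[i] - sign(g)))
def finalize(trits: list[int], t_accum: list[int]) -> None:
    """Move trits whose accumulator reached the threshold one level, then reset it."""
    for i, acc in enumerate(t_accum):
        if abs(acc) >= FLIP_THRESHOLD:
            trits[i] = max(-1, min(1, trits[i] + sign(acc)))
            t_accum[i] = 0  # consumed; sub-threshold counts carry over to later steps
trits = [0, 1, -1, 0]
t_accum = [0, 0, 0, 0]
for _ in range(4):
    stream_backward(t_accum, [-0.5, 0.3, 0.2, 0.0])
finalize(trits, t_accum)
print(trits, t_accum)
```
Between backward and finalize, only the int8 accumulator persists. No weight-shaped gradient is kept, which is the property this change relies on.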
The packed flip finalizer uses a Triton kernel, so it never unpacks a full large matrix just to flip the trits that crossed the threshold.
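A small example shows why a flip does not require unpacking. It uses a hypothetical packing of 5 trits per byte in base 3 (digit = trit + 1, so a byte holds at most 242). Under that layout, replacing one trit is a single add of `(new - old) * 3**k`, and the other four digits stay untouched. The real ARBS packing and the Triton finalizer may use a different layout.
```python
# Flip one trit inside a packed byte without decoding its neighbours.
# HYPOTHETICAL layout: 5 trits per uint8, base 3, digit = trit + 1.
TRITS_PER_BYTE = 5
def get_trit(byte: int, k: int) -> int:
    """Read trit k (0 = least significant base-3 digit)."""
    return (byte // 3**k) % 3 - 1
def set_trit(byte: int, k: int, new: int) -> int:
    """Return byte with trit k replaced by new in {-1, 0, 1}; other digits unchanged."""
    old = get_trit(byte, k)
    return byte + (new - old) * 3**k
def pack(trits: list[int]) -> int:
    assert len(trits) == TRITS_PER_BYTE
    return sum((t + 1) * 3**k for k, t in enumerate(trits))
def unpack(byte: int) -> list[int]:
    return [get_trit(byte, k) for k in range(TRITS_PER_BYTE)]
b = pack([1, 0, -1, 0, 1])
b_flipped = set_trit(b, 2, 1)  # flip only trit 2 from -1 to +1
print(b, b_flipped, unpack(b_flipped))
```
The byte goes from 194 to 212, and only digit 2 changes.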
### 2. Pre-Backward Loss Scaling
Added `ARBModel.prepare_ternary_backward(loss_signal, update_scales=True)`.
Pure ternary training scripts now call this immediately before `loss.backward()` so the streaming path can use the same loss-scaled ternary step without retaining weight-shaped gradients.
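The per-step call order can be sketched with stand-ins. `StubModel` and `StubLoss` are hypothetical stubs, not part of ARBS, so the ordering can be checked without CUDA. Three details are assumptions: that the model returns the loss directly, that `loss.detach()` is what gets passed as `loss_signal`, and that `_ternary_update_memory()` is called on the model once per step.
```python
# Call-order sketch for one pure-ternary training step, using hypothetical stubs.
class StubLoss:
    def __init__(self, log: list[str]):
        self.log = log
    def detach(self) -> "StubLoss":
        return self
    def backward(self) -> None:
        self.log.append("backward")
class StubModel:
    def __init__(self):
        self.log: list[str] = []
    def __call__(self, batch) -> StubLoss:
        self.log.append("forward")
        return StubLoss(self.log)
    def prepare_ternary_backward(self, loss_signal, update_scales=True) -> None:
        self.log.append("prepare_ternary_backward")
    def _ternary_update_memory(self) -> None:
        self.log.append("_ternary_update_memory")
def train_step(model, batch):
    loss = model(batch)  # forward pass returning the loss (assumed)
    # Set the loss-scaled step immediately before backward.
    model.prepare_ternary_backward(loss.detach(), update_scales=True)
    loss.backward()  # in the real model, streams into T_accum / E_accum
    model._ternary_update_memory()  # finalize streamed state; no AdamW step
    return loss
model = StubModel()
train_step(model, batch=None)
print(model.log)
```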
Updated:
- `training/pretrain.py`
- `training/text.py`
- `training/audio.py`
- `training/vision.py`
- `training/diffusion.py`
- `training/smoke_50.py`
`training/smoke_50.py` no longer uses AdamW; it freezes float parameters and uses `_ternary_update_memory()`.
### 3. Sparse Threshold Lowered for VQ Tables
`TernaryEmbeddingTable.sparse_threshold` is now `65,536`.
This makes the current `131,072 x 1024` shared VQ table use sparse candidate-row hooks instead of a dense full-table hook. The targeted check showed:
```text
dense_hooks []
sparse_hook_shapes torch.Size([512, 1024])
```
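The threshold rule can be sketched as follows. The comparison operator and the way touched rows are collected are assumptions. The 512 unique touched rows simply mirror the targeted check above.
```python
# Dense vs sparse grad-sign hook choice for a ternary embedding table (sketch).
# HYPOTHETICAL: the threshold is compared against the row count with >=, and a sparse
# hook keeps one int8 row per unique touched row.
SPARSE_THRESHOLD = 65_536  # TernaryEmbeddingTable.sparse_threshold after this change
def uses_sparse_hook(num_rows: int, threshold: int = SPARSE_THRESHOLD) -> bool:
    return num_rows >= threshold
def hook_shape(num_rows: int, dim: int, touched_rows: list[int]) -> tuple[int, int]:
    """Shape of the retained int8 grad-sign tensor for one step."""
    if uses_sparse_hook(num_rows):
        return (len(set(touched_rows)), dim)  # only candidate rows seen this step
    return (num_rows, dim)  # dense, table-sized
touched = list(range(512)) + list(range(256))  # 512 unique rows, some repeated
print(hook_shape(131_072, 1024, touched))
```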
In the full model trace, the shared VQ table retained only the touched candidate rows:
```text
bridge.vq.table _hook_sparse_grad_sign (2560, 1024) int8
```
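The memory saved on the VQ table follows directly from the reported shapes, assuming one byte per int8 element:
```python
# Retained int8 grad-sign bytes for the shared VQ table, dense vs sparse hooks.
# Shapes are the ones reported in this note; 1 byte per int8 element.
MIB = 1024 * 1024
DIM = 1024
dense_bytes = 131_072 * DIM      # old behaviour: full-table hook
targeted_bytes = 512 * DIM       # targeted check: [512, 1024]
full_trace_bytes = 2560 * DIM    # full model trace: (2560, 1024)
print(f"dense:      {dense_bytes / MIB:.1f} MiB")
print(f"targeted:   {targeted_bytes / MIB:.1f} MiB")
print(f"full trace: {full_trace_bytes / MIB:.1f} MiB "
      f"({dense_bytes / full_trace_bytes:.1f}x smaller than dense)")
```
The full-trace sparse hook holds 2.5 MiB instead of 128 MiB, which is 51.2x smaller.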
## Full Model Trace
Local run of the full default model with `max_moe_iters=1`, text-only input, batch size 1, sequence length 12:
```text
logical ternary weights: 2,216,567,968
ternary training state: 2896.13 MB
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
after model.cuda: alloc=3198MB reserved=3219MB peak=3198MB
step 0 after forward: alloc=3273MB reserved=3364MB peak=3325MB
step 0 after backward: alloc=3280MB reserved=3364MB peak=3325MB
step 0 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
step 1 after forward: alloc=3314MB reserved=3448MB peak=3389MB
step 1 after backward: alloc=3280MB reserved=3448MB peak=3389MB
step 1 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
```
The key result: allocated memory returned to the same post-cleanup value (3249 MB) after both steps, so backward-hook stacking no longer occurs in this local trace.
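The stability check can be turned into a small regression assertion over the trace. The helper names and the zero-MB tolerance are hypothetical. Only the log format is taken from the run above.
```python
import re
# Trace lines copied from the run above.
TRACE = """\
after model.cuda: alloc=3198MB reserved=3219MB peak=3198MB
step 0 after forward: alloc=3273MB reserved=3364MB peak=3325MB
step 0 after backward: alloc=3280MB reserved=3364MB peak=3325MB
step 0 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
step 1 after forward: alloc=3314MB reserved=3448MB peak=3389MB
step 1 after backward: alloc=3280MB reserved=3448MB peak=3389MB
step 1 after cleanup: alloc=3249MB reserved=3284MB peak=3389MB
"""
STEP_LINE = re.compile(r"step (\d+) after (\w+): alloc=(\d+)MB")
def post_cleanup_alloc(trace: str) -> dict[int, int]:
    """Map step index -> allocated MB reported after cleanup."""
    result = {}
    for match in STEP_LINE.finditer(trace):
        step, phase, alloc = int(match.group(1)), match.group(2), int(match.group(3))
        if phase == "cleanup":
            result[step] = alloc
    return result
def has_stacking(trace: str, tolerance_mb: int = 0) -> bool:
    """True if post-cleanup allocation drifts by more than tolerance_mb across steps."""
    values = list(post_cleanup_alloc(trace).values())
    return max(values) - min(values) > tolerance_mb
print(post_cleanup_alloc(TRACE), has_stacking(TRACE))
```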
The audited run used the current config, which is not a 3B model:
```text
CODEBOOK_DIM = 1024
SHARED_VQ_SIZE = 131072
HIDDEN_DIM = 7200
KGVQ_CODEBOOK_SIZE = 5_000_000
KGVQ_CODEBOOK_DIM = 64
```
This produced `2.216B` logical ternary weights locally.
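As a sanity check on the Goal's storage rule, the reported ternary training state can be divided by the logical weight count. This is plain arithmetic on the logged numbers. It treats the whole 2896.13 MB as per-weight state, although part of it may be buffers that are not per-weight.
```python
# Persistent ternary training state per logical weight, from the audited run.
# The log does not say whether "MB" means 10**6 or 2**20 bytes, so both are shown.
STATE_MB = 2896.13
LOGICAL_WEIGHTS = 2_216_567_968
bytes_per_weight_decimal = STATE_MB * 10**6 / LOGICAL_WEIGHTS
bytes_per_weight_binary = STATE_MB * 2**20 / LOGICAL_WEIGHTS
print(f"{bytes_per_weight_decimal:.2f} to {bytes_per_weight_binary:.2f} bytes/weight "
      f"(fp16 = 2, fp32 = 4)")
```
Under either convention the state stays below the 2 bytes per weight that a single fp16 copy of the weights would need.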
## Validation
Passed:
```bash
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_cuda_triton_tscale_path
python -m pytest -q testing/kg/test_composite_head.py testing/test_gradient_capture.py testing/test_tscale.py::test_cuda_triton_tscale_path
PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True python training/pretrain.py --steps 1 --batch 1 --ctx 12 --text-data training/data/tinyshakespeare.txt --no-save --log-interval 1 --eval-interval 0 --save-interval 0 --max-moe-iters 1
```
The pretrain smoke test completed on CUDA with the full default model.
## Remaining Risk
The full dense ternary layers still initialize from temporary float random tensors before being packed. That is not persistent training state, but it can make cold construction slower and more memory-hungry on CPU. A packed-first initializer for dense `TernaryScaleTensor` would be the next cleanup if model construction itself becomes the bottleneck.
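A packed-first initializer could sample trits and write them straight into packed bytes, so no weight-shaped float tensor ever exists. The sketch below is hypothetical in several ways. It assumes a 5-trits-per-byte base-3 layout (digit = trit + 1), a symmetric distribution controlled by a made-up `p_zero`, and pure-Python sampling. A real version would generate packed bytes in chunks on the device, and `TernaryScaleTensor`'s actual init distribution may differ.
```python
import random
TRITS_PER_BYTE = 5  # hypothetical base-3 packing, digit = trit + 1
def packed_random_init(num_trits: int, p_zero: float = 1 / 3, seed: int = 0) -> bytearray:
    """Sample ternary weights directly into packed bytes; padding trits are 0."""
    rng = random.Random(seed)
    p_side = (1.0 - p_zero) / 2  # probability of each of -1 and +1
    num_bytes = -(-num_trits // TRITS_PER_BYTE)  # ceiling division
    packed = bytearray(num_bytes)
    for b in range(num_bytes):
        value = 0
        for k in range(TRITS_PER_BYTE):
            if b * TRITS_PER_BYTE + k < num_trits:
                trit = rng.choices((-1, 0, 1), weights=(p_side, p_zero, p_side))[0]
            else:
                trit = 0  # padding beyond num_trits
            value += (trit + 1) * 3**k
        packed[b] = value
    return packed
def unpack(packed: bytearray, num_trits: int) -> list[int]:
    """Decode the first num_trits trits (for checking only)."""
    return [(packed[i // TRITS_PER_BYTE] // 3 ** (i % TRITS_PER_BYTE)) % 3 - 1
            for i in range(num_trits)]
weights = packed_random_init(12, seed=42)
print(len(weights), unpack(weights, 12))
```
The only allocation is the packed output, about one byte per five weights under this layout, instead of 4 bytes per weight for a temporary fp32 tensor.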