# TRUE TERNARY REFACTOR 17

Date: 2026-05-20

## Goal

Fix the post-REFACTOR16 training OOM/regression, and compare the production training path against `training/train_small_float.py`, which was reported to be the small-scale reference that learns.

## Finding

`training/train_small_float.py` is not a ternary-weight trainer. It is a float32 baseline built from:

- `nn.Embedding`
- `nn.Linear`
- `nn.RMSNorm`
- `torch.optim.AdamW`

Local run:

```text
python training/train_small_float.py --steps 200 --dim 256 --layers 4 --batch 16 --ctx 64 --device cuda
Model: 1.73M params
step 0: train=5.8205 val=5.5523
step 199: train=2.5854 val=2.6280
```

The parts worth copying into the ternary system were not AdamW or the float weights. They were stable fan-in-scaled initialization, simple byte-level supervision, and small update steps.

## Changes

### 1. Restored Streaming Backward In Training

`ARBModel.prepare_ternary_backward()` was setting `_stream_backward_updates = False`, which disabled the REFACTOR16 streaming path before every training backward pass. That forced the old hook tensors, each the full shape of its weight, back into memory.

It now sets:

```text
_stream_backward_updates = True
```

With streaming enabled, the Triton linears accumulate directly into `T_accum` and `E_accum` instead of materializing weight-shaped tensors.
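
The difference between the two backward paths can be sketched in pure Python. This is a minimal illustration, not the Triton kernel. It assumes a single example, so the weight gradient is the outer product `grad_out[i] * x[j]`. The names `accumulate_materialized` and `accumulate_streaming`, and the list-of-lists `t_accum`, are hypothetical stand-ins for the real hooks and the int8 `T_accum` buffer.

```python
from typing import List

# Hypothetical sketch; not the ARBModel/Triton API.
Matrix = List[List[int]]


def sign(v: float) -> int:
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def accumulate_materialized(t_accum: Matrix, grad_out: List[float], x: List[float]) -> None:
    # Old path (sketch): build the full weight-shaped gradient first,
    # then fold its signs into the accumulator.
    grad_w = [[g * xj for xj in x] for g in grad_out]  # out_dim x in_dim floats
    for i, row in enumerate(grad_w):
        for j, gij in enumerate(row):
            t_accum[i][j] -= sign(gij)  # descent: accumulate -sign(grad)


def accumulate_streaming(t_accum: Matrix, grad_out: List[float], x: List[float]) -> None:
    # Streaming path (sketch): compute each gradient element on the fly and
    # accumulate it immediately; no weight-shaped temporary is kept.
    for i, g in enumerate(grad_out):
        acc_row = t_accum[i]
        for j, xj in enumerate(x):
            acc_row[j] -= sign(g * xj)
```

Both functions leave `t_accum` in the same state. The streaming version's extra memory is a single scalar per element rather than an `out_dim * in_dim` float buffer.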

### 2. Fixed Gradient Descent Direction For T Flips

The ternary T update was accumulating `+sign(grad)`, which performs gradient ascent on the ternary sign. It now accumulates `-sign(grad)` in:

- Triton direct accumulation
- Triton legacy ternary step
- dense `TernaryScaleTensor.ternary_step`
- `ByteEmbedding.ternary_step`
- sparse `TernaryEmbeddingTable` rows

The CUDA test expectation was updated so that a positive gradient moves the trit toward `-1`.
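
A toy check of the sign convention follows. It assumes a single ternary weight and a squared-error loss. It also assumes a flip rule in which the trit moves one level and the accumulator resets once `|accum|` reaches the threshold. The flip rule and all function names are assumptions for illustration, and the real `ternary_step` implementations may differ.

```python
# Hypothetical sketch of a ternary sign update; not the project's ternary_step.


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def loss(w: int, x: float, y: float) -> float:
    # Toy squared-error loss for a single ternary weight.
    return (w * x - y) ** 2


def grad(w: int, x: float, y: float) -> float:
    # d/dw of (w*x - y)^2
    return 2.0 * x * (w * x - y)


def ternary_update(trit: int, accum: int, g: float, threshold: int = 3,
                   step: int = 1, descent: bool = True):
    """Accumulate -sign(g) (descent) or +sign(g) (the old ascent behavior),
    then move the trit one level once |accum| reaches the threshold (assumed rule)."""
    accum += (-step if descent else step) * sign(g)
    if accum >= threshold:
        trit, accum = min(trit + 1, 1), 0
    elif accum <= -threshold:
        trit, accum = max(trit - 1, -1), 0
    return trit, accum


def train(descent: bool, passes: int = 3, x: float = 1.0, y: float = -1.0) -> int:
    trit, accum = 0, 0
    for _ in range(passes):
        trit, accum = ternary_update(trit, accum, grad(trit, x, y), descent=descent)
    return trit
```

With `x=1.0` and `y=-1.0`, the gradient at `trit=0` is positive. Descent therefore moves the trit to `-1` (loss `0`). The old `+sign(grad)` accumulation moves it to `+1` (loss `4`).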

### 3. Removed Persistent Float Component Accumulator

The componentwise loss path could create `_T_accum_fp`, a full float32 buffer with the same shape as the ternary weight. That violates the packed/int8 training-state rule and can cause OOMs on large layers.

Component loss accumulation now uses int8 sign accumulation only, and any stale `_T_accum_fp` is cleared.
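
The memory cost is easy to quantify. Take a hypothetical 8192 x 8192 layer as a worked example. A float32 accumulator takes 4 bytes per weight (256 MiB), while an int8 accumulator takes 1 byte per weight (64 MiB). The sketch below computes this and shows one way to do sign accumulation in int8. Saturating at the int8 range is an assumption about how overflow is handled, and the function names are hypothetical.

```python
# Hypothetical helpers illustrating the float32 vs int8 accumulator trade-off.


def buffer_bytes(out_dim: int, in_dim: int, bytes_per_elem: int) -> int:
    """Size in bytes of a dense per-weight buffer."""
    return out_dim * in_dim * bytes_per_elem


def int8_sign_accumulate(accum: int, g: float) -> int:
    """Add -sign(g) to an int8 accumulator, saturating at [-128, 127] (assumed policy)."""
    s = (g > 0) - (g < 0)
    return max(-128, min(127, accum - s))


MIB = 1 << 20
fp32_mib = buffer_bytes(8192, 8192, 4) / MIB  # 256.0
int8_mib = buffer_bytes(8192, 8192, 1) / MIB  # 64.0
```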

### 4. Large RMSNorm CUDA Fallback

Large hidden RMSNorms were OOMing in the Triton backward pass because the kernel used `BLOCK_D=next_power_of_2(dim)`, so the block size grows with the full hidden width. For wide norms, CUDA now uses the identity-weight RMS path:

```text
x * rsqrt(mean(x*x) + eps)
```

This is correct for the current `TernaryRMSNorm` because its T/E state initializes to identity and is not trained.
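
The two pieces involved can be sketched in pure Python under stated assumptions. `next_power_of_2` mirrors the `BLOCK_D` sizing rule, so a hypothetical width of 12288 would be padded to a 16384-wide block. `rms_norm_identity` applies the fallback formula to one row. The `eps` default of `1e-6` is an assumption, not necessarily the value used by `TernaryRMSNorm`.

```python
import math
from typing import List


def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (n >= 1); mirrors the BLOCK_D sizing rule."""
    return 1 << (n - 1).bit_length()


def rms_norm_identity(x: List[float], eps: float = 1e-6) -> List[float]:
    """Identity-weight RMSNorm over one row: x * rsqrt(mean(x*x) + eps).
    eps=1e-6 is an assumed default for illustration."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]
```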

### 5. Fan-In-Aware Ternary Initialization

The full system was initializing every ternary linear with `std=0.1`, independent of fan-in. At 8192-wide layers this made the initial logits much too large.

Default `TernaryScaleTensor` init now uses:

```text
std = min(0.1, 1 / sqrt(in_dim))
threshold = min(config_threshold, 0.5 * std)
```

This keeps ternary masks active while matching the float baseline's stable fan-in scaling. `ByteEmbedding` and dense `TernaryEmbeddingTable` now also scale their ternary threshold from the init std.
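
The effect of the fan-in rule can be checked numerically. For `in_dim=8192`, `1/sqrt(8192)` is about `0.011`, roughly 9x smaller than the old `0.1`. The sketch below assumes Gaussian initialization and a hypothetical `config_threshold` of `0.05`, since the real default is not stated here. It estimates what fraction of weights become nonzero trits. With the threshold tied to `0.5 * std`, about 62% of weights are active, because `P(|z| > 0.5)` is about 0.617 for a standard normal `z`. Keeping a fixed `0.05` threshold with the smaller std would leave almost every trit at zero.

```python
import math
import random


def ternary_init_params(in_dim: int, config_threshold: float):
    """Fan-in-aware std and threshold, following the formulas in this section."""
    std = min(0.1, 1.0 / math.sqrt(in_dim))
    threshold = min(config_threshold, 0.5 * std)
    return std, threshold


def active_fraction(std: float, threshold: float, n: int = 20000, seed: int = 0) -> float:
    """Monte Carlo estimate of the fraction of Gaussian-initialized weights whose
    magnitude exceeds the threshold, i.e. that become nonzero trits (assumed init)."""
    rng = random.Random(seed)
    hits = sum(abs(rng.gauss(0.0, std)) > threshold for _ in range(n))
    return hits / n


# config_threshold=0.05 is a hypothetical value for illustration.
std, thr = ternary_init_params(8192, 0.05)
```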

### 6. Conservative T Step

`_ternary_t_step()` now returns `1`. The old loss-scaled value could be `4`, which exceeded the default flip threshold of `3` in a single backward pass and caused a destructive mass flip immediately after a good initialization.
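
The interaction between step size and threshold can be counted directly. This sketch assumes the accumulator flips once `|accum| >= threshold`, and that every backward pass sees the same gradient sign. The helper names are hypothetical. Under those assumptions, a step of `4` flips every weight with a nonzero gradient after one pass, while a step of `1` needs three consistent passes.

```python
from typing import Iterable

# Hypothetical helpers; the >= flip rule is an assumption.


def passes_until_flip(step: int, threshold: int = 3) -> int:
    """Count same-sign backward passes before |accum| reaches the threshold."""
    accum, passes = 0, 0
    while abs(accum) < threshold:
        accum -= step  # every pass sees the same positive gradient sign
        passes += 1
    return passes


def flipped_after_one_pass(signs: Iterable[int], step: int, threshold: int = 3) -> int:
    """Number of weights that flip after a single pass, given per-weight gradient signs."""
    return sum(abs(-step * s) >= threshold for s in signs)
```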

### 7. Configurable MoEGraph Top-K

`ARBModel` no longer hard-codes MoEGraph `top_k=4`. It now reads `MG_TOP_K` from config (default `2`), which reduces routed expert work while keeping multi-expert routing active.
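
A minimal sketch of config-driven top-k routing follows. The plain `dict` config, the helper names, and the softmax renormalization over the selected experts are all assumptions for illustration. The source only states that `top_k` now comes from `MG_TOP_K` with a default of `2`.

```python
import math
from typing import Dict, List, Tuple

# Hypothetical sketch; not the MoEGraph implementation.


def get_top_k(config: Dict[str, int]) -> int:
    # Read MG_TOP_K from a config mapping, defaulting to 2.
    return int(config.get("MG_TOP_K", 2))


def route(logits: List[float], top_k: int) -> List[Tuple[int, float]]:
    """Pick the top_k experts by router logit and softmax-normalize their weights
    (renormalizing over the selected experts is an assumption)."""
    chosen = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:top_k]
    m = max(logits[i] for i in chosen)  # subtract the max for numerical stability
    exps = [math.exp(logits[i] - m) for i in chosen]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(chosen, exps)]
```

For router logits `[0.1, 2.0, -1.0, 1.5]` and the default `top_k=2`, experts `1` and `3` are selected. In this sketch, each routed token then runs two experts instead of four.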

## Validation

Passed:

```text
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_small_ternary_training_loss_finite testing/test_tscale.py::test_e_per_component_routing testing/test_tscale.py::test_cuda_triton_tscale_path
```

Full-model memory and loss probe (text-only, `ctx=64`, `max_moe_iters=4`, `MG_TOP_K=2`):

```text
logical ternary weights: 3,122,925,280
ternary training state: 4071.0 MB
after cuda: alloc=4307.3MB reserved=4378.0MB

step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB
```

The important result is that allocated VRAM returns to the same post-step baseline (`4355.6MB` after cleanup) on every step. The previous behavior, where allocated memory stacked up from one step to the next, was not reproduced.
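
The post-step baseline criterion can be checked mechanically from probe output. This hypothetical helper pulls the `cleanup alloc=<N>MB` values out of step log lines like the ones above. It then flags any step-to-step increase beyond a tolerance. The regex, function names, and the 1 MB default tolerance are assumptions, not part of the probe.

```python
import re
from typing import List

# Hypothetical log checker; the pattern assumes the probe's "cleanup alloc=<N>MB" format.
CLEANUP_RE = re.compile(r"cleanup alloc=([0-9.]+)MB")


def cleanup_allocs(log: str) -> List[float]:
    """Extract the post-cleanup allocated-MB values from probe log lines."""
    return [float(m) for m in CLEANUP_RE.findall(log)]


def memory_stacks(allocs: List[float], tol_mb: float = 1.0) -> bool:
    """True if post-cleanup allocation rises by more than tol_mb between any two consecutive steps."""
    return any(b - a > tol_mb for a, b in zip(allocs, allocs[1:]))
```

On the three step lines above, `cleanup_allocs` returns `[4355.6, 4355.6, 4355.6]` and `memory_stacks` returns `False`.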

Smoke test of the actual pretraining entrypoint:

```text
python training/pretrain.py --steps 2 --batch 1 --ctx 64 --text-data /tmp/tinyshakespeare.txt --max-moe-iters 4 --no-save --log-interval 1 --eval-interval 0 --save-interval 0

logical ternary weights: 3,122,925,280
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
loss step 1: 7.5023
loss step 2: 7.4532
```

## Current Status

The local model now exceeds the requested target of 3.1B logical ternary weights:

```text
3.122925B logical ternary weights
```

Persistent trainable state remains packed trits plus int8 accumulators and scales. In the local CUDA probes, the memory blocker and the immediate loss blow-up are fixed. The remaining practical issue is speed: the first step can still pay compile/setup cost, and MoEGraph top-k routing and depth are expensive at full size.