# TRUE TERNARY REFACTOR 17
Date: 2026-05-20
## Goal
Fix the training OOM and loss regression that appeared after REFACTOR16, and compare the production training path against `training/train_small_float.py`, which had been reported as the small-scale learning reference.
## Finding
`training/train_small_float.py` is not a ternary-weight trainer. It is a float32 baseline:
- `nn.Embedding`
- `nn.Linear`
- `nn.RMSNorm`
- `torch.optim.AdamW`
Local run:
```text
python training/train_small_float.py --steps 200 --dim 256 --layers 4 --batch 16 --ctx 64 --device cuda
Model: 1.73M params
step 0: train=5.8205 val=5.5523
step 199: train=2.5854 val=2.6280
```
The useful parts to copy into the ternary system were not AdamW or float weights. They were stable fan-in-scaled initialization, simple byte supervision, and small update steps.
## Changes
### 1. Restored Streaming Backward In Training
`ARBModel.prepare_ternary_backward()` was setting `_stream_backward_updates = False`, which disabled the REFACTOR16 streaming path before every training backward pass. That forced the old full-size, weight-shaped hook tensors back into memory.
It now sets:
```text
_stream_backward_updates = True
```
This keeps Triton linears accumulating directly into `T_accum` and `E_accum`.
### 2. Fixed Gradient Descent Direction For T Flips
The ternary T update was accumulating `+sign(grad)`, which is gradient ascent for the ternary sign. It now accumulates `-sign(grad)` in:
- Triton direct accumulation
- Triton legacy ternary step
- dense `TernaryScaleTensor.ternary_step`
- `ByteEmbedding.ternary_step`
- sparse `TernaryEmbeddingTable` rows
The CUDA test expectation was updated so a positive gradient moves the trit toward `-1`.
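A toy single-trit example makes the sign convention concrete. This is a minimal sketch, not the Triton or `TernaryScaleTensor` code: it assumes the quadratic loss `(w - target)**2`, moves the trit one level per step, and omits the accumulation threshold so that only the direction of the update is visible. `ternary_update` and its arguments are hypothetical names.

```python
def sign(x: float) -> int:
    """Return -1, 0, or +1 according to the sign of x."""
    return (x > 0) - (x < 0)


def ternary_update(w: int, target: float, steps: int, descend: bool = True) -> int:
    """Toy single-trit update on the loss (w - target) ** 2.

    descend=True accumulates -sign(grad) (the fixed behaviour);
    descend=False accumulates +sign(grad) (the old, ascending behaviour).
    The trit moves one level per step and is clamped to {-1, 0, +1}.
    Thresholding is deliberately omitted (simplifying assumption).
    """
    for _ in range(steps):
        grad = 2.0 * (w - target)  # dL/dw for the toy loss
        direction = -sign(grad) if descend else sign(grad)
        w = max(-1, min(1, w + direction))
    return w


if __name__ == "__main__":
    # Start at +1 with target -1: the gradient is positive, so descent walks toward -1.
    print(ternary_update(1, -1.0, steps=3))                 # -1
    print(ternary_update(1, -1.0, steps=3, descend=False))  # 1 (stuck at the wrong end)
```

With the old `+sign(grad)` rule the trit is pushed away from the target and pinned at the clamp, which is the gradient-ascent behavior described above.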
### 3. Removed Persistent Float Component Accumulator
The componentwise loss path could create `_T_accum_fp`, a full float32 buffer shaped like the ternary weight. That violates the rule that training state stays packed or int8, and it can cause OOMs on large layers.
Component loss accumulation now uses int8 sign accumulation only and clears any stale `_T_accum_fp`.
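The saving from dropping `_T_accum_fp` is easy to estimate. This is a minimal sketch assuming an 8192 x 8192 weight (a hypothetical shape chosen to match the 8192-wide layers discussed in this note) and a saturating int8 counter; `accumulator_bytes` and `accumulate_sign_int8` are illustrative names, and saturation (rather than wraparound) is an assumption, not documented behavior.

```python
BYTES_PER_ELEMENT = {"float32": 4, "int8": 1}
INT8_MIN, INT8_MAX = -128, 127


def accumulator_bytes(out_dim: int, in_dim: int, dtype: str) -> int:
    """Bytes for one accumulator shaped like an (out_dim, in_dim) weight."""
    return out_dim * in_dim * BYTES_PER_ELEMENT[dtype]


def accumulate_sign_int8(acc: int, grad: float) -> int:
    """Add -sign(grad) to an int8 counter, saturating instead of wrapping (assumed)."""
    step = -((grad > 0) - (grad < 0))
    return max(INT8_MIN, min(INT8_MAX, acc + step))


if __name__ == "__main__":
    for dtype in ("float32", "int8"):
        mib = accumulator_bytes(8192, 8192, dtype) / 2**20
        print(f"{dtype}: {mib:.0f} MiB")  # float32: 256 MiB, int8: 64 MiB
```

The float32 copy costs four times the int8 accumulator per layer, on top of the int8 state that already exists.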
### 4. Large RMSNorm CUDA Fallback
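Large hidden-size RMSNorms were running out of memory in the Triton backward kernel because it used `BLOCK_D=next_power_of_2(dim)`, so a single block had to cover the entire row rounded up to a power of two. For wide norms, the CUDA path now falls back to the identity-weight RMS computation: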
```text
x * rsqrt(mean(x*x) + eps)
```
This is correct for current `TernaryRMSNorm` because its T/E state initializes as identity and is not trained.
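The fallback formula can be checked on a single row. This is a minimal pure-Python sketch, not the CUDA path: `rms_norm` and the default `eps=1e-6` are assumptions for illustration. It shows that the identity-weight path is the general RMSNorm with every per-element weight equal to 1, which is why skipping the weight is safe while the T/E state stays at identity.

```python
import math
from typing import List, Optional, Sequence


def rms_norm(x: Sequence[float], weight: Optional[Sequence[float]] = None,
             eps: float = 1e-6) -> List[float]:
    """RMSNorm over one row: x * rsqrt(mean(x*x) + eps), optionally scaled per element."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    if weight is None:  # identity-weight path: no per-element scale
        return [v * inv_rms for v in x]
    return [v * inv_rms * w for v, w in zip(x, weight)]


if __name__ == "__main__":
    row = [3.0, 4.0, -1.0, 0.5]
    print(rms_norm(row) == rms_norm(row, [1.0] * len(row)))  # True
```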
### 5. Fan-In-Aware Ternary Initialization
The full system was initializing every ternary linear with `std=0.1`, independent of fan-in. At 8192-wide layers this made the initial logits much too large.
Default `TernaryScaleTensor` init now uses:
```text
std = min(0.1, 1 / sqrt(in_dim))
threshold = min(config_threshold, 0.5 * std)
```
This keeps ternary masks active while matching the float baseline's stable fan-in scaling. `ByteEmbedding` and dense `TernaryEmbeddingTable` now also scale their ternary threshold from init std.
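The scale problem follows from fan-in arithmetic. For a dense float layer with independent unit-variance inputs, the output std is `std * sqrt(in_dim)`: `0.1 * sqrt(8192) ≈ 9.05`, while `1 / sqrt(8192)` gives about 1.0. The ternary layers quantize and rescale their weights, so treat this as float-weight intuition rather than an exact statement about them. The sketch below also assumes a hypothetical ternarization rule (zero when `|w| <= threshold`, otherwise `sign(w)`) and a made-up `config_threshold=0.05`; the function names are illustrative.

```python
import math
import random
from typing import List


def init_std(in_dim: int) -> float:
    """Fan-in-aware init std: min(0.1, 1 / sqrt(in_dim))."""
    return min(0.1, 1.0 / math.sqrt(in_dim))


def init_threshold(std: float, config_threshold: float) -> float:
    """Ternarization threshold tied to the init std: min(config_threshold, 0.5 * std)."""
    return min(config_threshold, 0.5 * std)


def float_output_gain(std: float, in_dim: int) -> float:
    """Output std / input std for a dense float layer with i.i.d. unit-variance inputs."""
    return std * math.sqrt(in_dim)


def ternarize_init(n: int, in_dim: int, config_threshold: float, seed: int = 0) -> List[int]:
    """Hypothetical init: sample N(0, std), zero out |w| <= threshold, keep the sign otherwise."""
    rng = random.Random(seed)
    std = init_std(in_dim)
    thr = init_threshold(std, config_threshold)
    trits = []
    for _ in range(n):
        w = rng.gauss(0.0, std)
        trits.append(0 if abs(w) <= thr else (1 if w > 0 else -1))
    return trits


if __name__ == "__main__":
    in_dim = 8192
    print("old gain:", float_output_gain(0.1, in_dim))               # ~9.05
    print("new gain:", float_output_gain(init_std(in_dim), in_dim))  # ~1.0
    trits = ternarize_init(10_000, in_dim, config_threshold=0.05)   # 0.05 is made up
    # Theory: P(|z| > 0.5) ≈ 0.617, so roughly 60% of trits start nonzero.
    print("nonzero fraction:", sum(t != 0 for t in trits) / len(trits))
```

Because the threshold is half of the std, a Gaussian init leaves roughly 62% of trits nonzero regardless of width, which is one way to read "keeps ternary masks active."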
### 6. Conservative T Step
`_ternary_t_step()` now returns `1`. The old loss-scaled value could reach `4`, which exceeded the default threshold of `3` in a single backward pass and triggered a destructive mass flip immediately after a good initialization.
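The interaction between step size and threshold reduces to counting passes. This minimal sketch assumes the accumulator starts at 0, every pass pushes in the same direction, and a flip fires once `|accum| > threshold`; the exact comparison used in the kernels is not stated here, and `passes_until_flip` is a hypothetical helper.

```python
def passes_until_flip(step: int, threshold: int) -> int:
    """Count consistent backward passes until |accum| exceeds threshold.

    Hypothetical rule for illustration: the accumulator starts at 0, every
    pass adds `step` in the same direction, and a flip fires once
    |accum| > threshold.
    """
    if step <= 0:
        raise ValueError("step must be positive")
    accum = 0
    passes = 0
    while abs(accum) <= threshold:
        accum += step
        passes += 1
    return passes


if __name__ == "__main__":
    print(passes_until_flip(step=4, threshold=3))  # 1: a single pass already flips
    print(passes_until_flip(step=1, threshold=3))  # 4: needs four agreeing passes
```

Under this rule a step of 4 lets one noisy backward pass flip every weight with a nonzero gradient at once, while a step of 1 requires several passes to agree before any trit changes.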
### 7. Configurable MoEGraph Top-K
`ARBModel` no longer hard-codes MoEGraph `top_k=4`. It now uses `MG_TOP_K` from config, defaulting to `2`, which reduces routed expert work while keeping multi-expert routing active.
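A generic top-k gate shows what lowering `top_k` from 4 to 2 changes: only the k highest-scoring experts run for each token. This sketch assumes the config behaves like a dict and that the selected experts are weighted by a softmax renormalized over the top k; MoEGraph's actual scoring and weighting are not described in this note, and `get_top_k` and `route_top_k` are hypothetical.

```python
import math
from typing import Dict, List, Sequence, Tuple


def get_top_k(config: Dict[str, int]) -> int:
    """Read MG_TOP_K from a config mapping, defaulting to 2."""
    return int(config.get("MG_TOP_K", 2))


def route_top_k(logits: Sequence[float], k: int) -> List[Tuple[int, float]]:
    """Pick the k highest-scoring experts and renormalize their softmax weights."""
    if not 1 <= k <= len(logits):
        raise ValueError("k must be between 1 and the number of experts")
    top = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]
    m = max(logits[i] for i in top)  # subtract the max for numerical stability
    exp = {i: math.exp(logits[i] - m) for i in top}
    z = sum(exp.values())
    return [(i, exp[i] / z) for i in top]


if __name__ == "__main__":
    scores = [0.1, 2.0, -1.0, 1.0]
    print(route_top_k(scores, get_top_k({})))  # experts 1 and 3 selected
```

Halving k halves the number of expert evaluations per routed token while still mixing more than one expert.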
## Validation
Passed:
```text
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_small_ternary_training_loss_finite testing/test_tscale.py::test_e_per_component_routing testing/test_tscale.py::test_cuda_triton_tscale_path
```
Full model memory and loss probe, text-only, `ctx=64`, `max_moe_iters=4`, `MG_TOP_K=2`:
```text
logical ternary weights: 3,122,925,280
ternary training state: 4071.0 MB
after cuda: alloc=4307.3MB reserved=4378.0MB
step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB
```
The key result is that allocated VRAM returns to the same post-step baseline (4355.6 MB) after every step. The earlier behavior, where allocated memory stacked up from step to step, was not reproduced.
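The baseline check can be automated against probe output in the format shown above. This is a minimal sketch: the regex matches the step lines of this probe log, the tolerance `tol_mb` is an arbitrary assumption, and `post_step_allocs` and `baseline_is_stable` are hypothetical helpers rather than part of the repo.

```python
import re
from typing import List

STEP_RE = re.compile(
    r"step (\d+) loss=([\d.]+) cleanup alloc=([\d.]+)MB peak=([\d.]+)MB"
)


def post_step_allocs(log: str) -> List[float]:
    """Extract the post-cleanup allocated MB from each step line of the probe log."""
    return [float(m.group(3)) for m in STEP_RE.finditer(log)]


def baseline_is_stable(allocs: List[float], tol_mb: float = 1.0) -> bool:
    """True if every post-step allocation stays within tol_mb of the first one."""
    return bool(allocs) and all(abs(a - allocs[0]) <= tol_mb for a in allocs)


PROBE_LOG = """\
step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB
"""

if __name__ == "__main__":
    print(baseline_is_stable(post_step_allocs(PROBE_LOG)))  # True
```

A stacking leak would show up as a monotonically rising `alloc` value and fail the check, while peak memory is allowed to vary.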
Smoke test of the actual pretraining entrypoint:
```text
python training/pretrain.py --steps 2 --batch 1 --ctx 64 --text-data /tmp/tinyshakespeare.txt --max-moe-iters 4 --no-save --log-interval 1 --eval-interval 0 --save-interval 0
logical ternary weights: 3,122,925,280
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
loss step 1: 7.5023
loss step 2: 7.4532
```
## Current Status
Locally, the model now exceeds the requested target of 3.1B logical ternary weights:
```text
3.122925B logical ternary weights
```
Persistent trainable state remains limited to packed trits and int8 accumulators and scales. The remaining practical issue is speed: the first step can still incur compile and setup cost, and MoEGraph top-k routing and depth are expensive at full size. In the local CUDA probes, the memory blocker and the immediate loss blow-up are fixed.