# TRUE TERNARY REFACTOR 17
Date: 2026-05-20
## Goal
Fix the post-REFACTOR16 training OOM/regression and compare the production training path against `training/train_small_float.py`, which was reported as the small learning reference.
## Finding
`training/train_small_float.py` is not a ternary-weight trainer. It is a float32 baseline built from standard PyTorch components:
- `nn.Embedding`
- `nn.Linear`
- `nn.RMSNorm`
- `torch.optim.AdamW`
Local run:
```text
python training/train_small_float.py --steps 200 --dim 256 --layers 4 --batch 16 --ctx 64 --device cuda
Model: 1.73M params
step 0: train=5.8205 val=5.5523
step 199: train=2.5854 val=2.6280
```
The parts worth carrying over into the ternary system were not AdamW or float weights. They were stable fan-in-scaled initialization, simple byte-level supervision, and small update steps.
## Changes
### 1. Restored Streaming Backward In Training
`ARBModel.prepare_ternary_backward()` was setting `_stream_backward_updates = False`, which disabled the REFACTOR16 streaming path before every training backward pass. That forced the old full weight-shaped hook tensors back into memory.
It now sets:
```text
_stream_backward_updates = True
```
This keeps Triton linears accumulating directly into `T_accum` and `E_accum`.
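The memory difference between the two paths can be illustrated with a small pure-Python sketch, assuming a single-token linear layer whose weight gradient is the outer product of `grad_out` and `x`. The helpers `linear_weight_grad_rows` and `stream_accumulate` are hypothetical and stand in for what the Triton kernels do on the GPU: each gradient row is produced, folded into the integer accumulator as `-sign(grad)`, and discarded, so only one row of floats is alive at a time instead of a full weight-shaped tensor.

```python
from typing import Iterator, List


def sign(v: float) -> int:
    """Return -1, 0, or +1 (0.0 and -0.0 both map to 0)."""
    return (v > 0) - (v < 0)


def linear_weight_grad_rows(grad_out: List[float], x: List[float]) -> Iterator[List[float]]:
    """Hypothetical helper: for y = W x, dL/dW[r][c] = grad_out[r] * x[c].

    Rows are yielded lazily, so the full out_dim x in_dim gradient is never built.
    """
    for go in grad_out:
        yield [go * xi for xi in x]


def stream_accumulate(grad_rows: Iterator[List[float]], t_accum: List[List[int]]) -> int:
    """Hypothetical helper: fold each gradient row into the accumulator in place.

    Accumulates -sign(grad) (the descent direction from change 2).
    Returns the largest number of float values held at once.
    """
    peak_floats = 0
    for r, row in enumerate(grad_rows):
        peak_floats = max(peak_floats, len(row))
        for c, g in enumerate(row):
            t_accum[r][c] -= sign(g)
    return peak_floats


# Usage: a 3x2 layer, one backward pass.
grad_out = [0.5, -2.0, 0.0]
x = [1.0, -1.0]
t_accum = [[0] * len(x) for _ in grad_out]
peak = stream_accumulate(linear_weight_grad_rows(grad_out, x), t_accum)
print(t_accum, peak)  # [[-1, 1], [1, -1], [0, 0]] 2
```

Transient float storage here scales with `in_dim` rather than `out_dim * in_dim`, which is the property the streaming path relies on for wide layers.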
### 2. Fixed Gradient Descent Direction For T Flips
The ternary T update was accumulating `+sign(grad)`, which is gradient ascent for the ternary sign. It now accumulates `-sign(grad)` in:
- Triton direct accumulation
- Triton legacy ternary step
- dense `TernaryScaleTensor.ternary_step`
- `ByteEmbedding.ternary_step`
- sparse `TernaryEmbeddingTable` rows
The CUDA test expectation was updated so a positive gradient moves the trit toward `-1`.
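A minimal sketch of the corrected update rule follows. The default threshold of `3` comes from change 6 below; the `>=` comparison, clamping the trit to `{-1, 0, +1}`, and resetting the accumulator to `0` after a move are assumptions made for illustration, and `ternary_step` here is a hypothetical stand-alone function, not the project's `TernaryScaleTensor.ternary_step`.

```python
from typing import List, Tuple


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def ternary_step(
    trits: List[int],
    accum: List[int],
    grads: List[float],
    step: int = 1,
    threshold: int = 3,
) -> Tuple[List[int], List[int]]:
    """Hypothetical sign-accumulation step for ternary weights."""
    new_trits, new_accum = [], []
    for t, a, g in zip(trits, accum, grads):
        # Descent: a positive gradient pushes the accumulator (and trit) down.
        a -= step * sign(g)
        if a >= threshold:
            t, a = min(t + 1, 1), 0    # move one level toward +1 (assumed reset)
        elif a <= -threshold:
            t, a = max(t - 1, -1), 0   # move one level toward -1 (assumed reset)
        new_trits.append(t)
        new_accum.append(a)
    return new_trits, new_accum


# Usage: three consistent backward passes with step 1.
trits, accum = [0, 1, -1], [0, 0, 0]
grads = [0.3, 0.3, -0.3]
for _ in range(3):
    trits, accum = ternary_step(trits, accum, grads)
print(trits, accum)  # [-1, 0, 0] [0, 0, 0]
```

With the old `+sign(grad)` accumulation the same positive gradients would have pushed the first two trits toward `+1`, increasing the loss.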
### 3. Removed Persistent Float Component Accumulator
The componentwise loss path could create `_T_accum_fp`, a full float32 buffer shaped like the ternary weight. That violates the packed/int8 training-state rule and can OOM large layers.
Component loss accumulation now uses int8 sign accumulation only and clears any stale `_T_accum_fp`.
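The sketch below shows int8 sign accumulation with saturation, plus the buffer-size arithmetic that makes a float32 accumulator costly. Saturating at the int8 limits is an assumption about how overflow is handled, and `accumulate_component_signs` and `buffer_mib` are hypothetical helpers; the 8192 x 8192 layer is an illustrative size chosen to match the 8192-wide layers mentioned in change 5.

```python
from typing import List

INT8_MIN, INT8_MAX = -128, 127


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def accumulate_component_signs(accum: List[int], grads: List[float]) -> List[int]:
    """Hypothetical: add -sign(grad) per element, saturating at the int8 range."""
    return [max(INT8_MIN, min(INT8_MAX, a - sign(g))) for a, g in zip(accum, grads)]


def buffer_mib(rows: int, cols: int, bytes_per_elem: int) -> float:
    """Size in MiB of a dense rows x cols buffer."""
    return rows * cols * bytes_per_elem / 1024 ** 2


# Usage: saturation at both ends of the int8 range.
acc = accumulate_component_signs([127, -128, 5], [-1.0, 1.0, 0.2])
print(acc)  # [127, -128, 4]

# A weight-shaped float32 accumulator costs 4x an int8 one.
fp32_mib = buffer_mib(8192, 8192, 4)  # 256.0
int8_mib = buffer_mib(8192, 8192, 1)  # 64.0
```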
### 4. Large RMSNorm CUDA Fallback
Large hidden RMSNorms were OOMing in the Triton backward kernel because it used `BLOCK_D=next_power_of_2(dim)`. For wide norms, CUDA now uses the identity-weight RMS path:
```text
x * rsqrt(mean(x*x) + eps)
```
This is correct for current `TernaryRMSNorm` because its T/E state initializes as identity and is not trained.
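The fallback formula can be checked in a few lines of pure Python. This sketch assumes `eps = 1e-6` (the actual epsilon is not stated here) and uses hypothetical helper names; it shows that a weighted RMSNorm with an all-ones weight reduces exactly to the identity-weight path, which is why the fallback is only valid while the norm's T/E state stays at its identity initialization.

```python
import math
from typing import List


def rms_norm_identity(x: List[float], eps: float = 1e-6) -> List[float]:
    """x * rsqrt(mean(x*x) + eps): RMSNorm with an implicit all-ones weight."""
    mean_sq = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_sq + eps)
    return [v * inv_rms for v in x]


def rms_norm_weighted(x: List[float], weight: List[float], eps: float = 1e-6) -> List[float]:
    """Standard RMSNorm with a per-channel weight."""
    return [w * y for w, y in zip(weight, rms_norm_identity(x, eps))]


# Usage
x = [3.0, -4.0]
y = rms_norm_identity(x)
y_w = rms_norm_weighted(x, [1.0, 1.0])
out_mean_sq = sum(v * v for v in y) / len(y)  # close to 1.0
```

Once the norm weights are trained away from identity, `rms_norm_weighted` and `rms_norm_identity` diverge and the fallback would need the weight applied.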
### 5. Fan-In-Aware Ternary Initialization
The full system was initializing every ternary linear with `std=0.1`, independent of fan-in. At 8192-wide layers this made the initial logits much too large.
Default `TernaryScaleTensor` init now uses:
```text
std = min(0.1, 1 / sqrt(in_dim))
threshold = min(config_threshold, 0.5 * std)
```
This keeps ternary masks active while matching the float baseline's stable fan-in scaling. `ByteEmbedding` and dense `TernaryEmbeddingTable` now also scale their ternary threshold from init std.
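The init rule can be sketched as follows. How the float draw becomes a trit is an assumption here (a magnitude threshold: `|w| <= threshold` maps to `0`, otherwise to `sign(w)`), and `config_threshold = 0.05` is a hypothetical value. Under those assumptions, using a threshold of `0.5 * std` leaves about 62% of trits nonzero, since for a Gaussian P(|w| > 0.5σ) ≈ 0.617, which is one way to read "keeps ternary masks active".

```python
import math
import random
from typing import List, Tuple


def fan_in_init_params(in_dim: int, config_threshold: float) -> Tuple[float, float]:
    """std = min(0.1, 1/sqrt(in_dim)); threshold = min(config_threshold, 0.5 * std)."""
    std = min(0.1, 1.0 / math.sqrt(in_dim))
    threshold = min(config_threshold, 0.5 * std)
    return std, threshold


def ternarize_init(n: int, in_dim: int, config_threshold: float, seed: int = 0) -> List[int]:
    """Hypothetical: draw N(0, std) floats and threshold them into trits."""
    std, threshold = fan_in_init_params(in_dim, config_threshold)
    rng = random.Random(seed)
    trits = []
    for _ in range(n):
        w = rng.gauss(0.0, std)
        trits.append(0 if abs(w) <= threshold else (1 if w > 0 else -1))
    return trits


# Usage: narrow vs. 8192-wide layers (config_threshold = 0.05 is hypothetical).
std_64, thr_64 = fan_in_init_params(64, 0.05)        # 0.1, 0.05
std_8192, thr_8192 = fan_in_init_params(8192, 0.05)  # ~0.01105, ~0.00552
trits = ternarize_init(20000, 8192, 0.05)
frac_nonzero = sum(t != 0 for t in trits) / len(trits)  # roughly 0.62
```

At `in_dim = 64` the `0.1` cap applies; at `in_dim = 8192` the `1/sqrt(in_dim)` term takes over and the threshold shrinks with it.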
### 6. Conservative T Step
`_ternary_t_step()` now returns `1`. The old loss-scaled value could be `4`, which crossed the default threshold `3` in one backward pass and caused a destructive mass flip immediately after good initialization.
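The arithmetic behind this change is small enough to check directly. The sketch assumes a flip fires when `|accum| >= threshold`, and the helper names are hypothetical. With step `4`, every element that receives any nonzero gradient crosses threshold `3` on the first pass, whatever the gradient's sign; with step `1`, an element needs three consistent-sign passes.

```python
import math
import random


def passes_to_flip(step: int, threshold: int = 3) -> int:
    """Consistent-sign passes needed before |accum| >= threshold."""
    return math.ceil(threshold / step)


def fraction_flipped_after_one_pass(step: int, threshold: int = 3, n: int = 1000, seed: int = 0) -> float:
    """Hypothetical: fraction of elements that flip on pass one under random-sign gradients."""
    rng = random.Random(seed)
    flipped = 0
    for _ in range(n):
        g = rng.choice((-1.0, 1.0))
        accum = -step * ((g > 0) - (g < 0))
        if abs(accum) >= threshold:
            flipped += 1
    return flipped / n


# Usage
old_passes = passes_to_flip(4)               # 1
new_passes = passes_to_flip(1)               # 3
old_mass = fraction_flipped_after_one_pass(4)  # 1.0: every element flips
new_mass = fraction_flipped_after_one_pass(1)  # 0.0: nothing flips yet
```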
### 7. Configurable MoEGraph Top-K
`ARBModel` no longer hard-codes MoEGraph `top_k=4`. It now uses `MG_TOP_K` from config, defaulting to `2`, which reduces routed expert work while keeping multi-expert routing active.
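A generic top-k routing sketch shows how a config-driven `k` with a default of `2` fits in. This is not the MoEGraph router: the function name `route_top_k`, passing the config as a plain dict, and renormalizing the gates with a softmax over only the selected logits are all assumptions chosen as a common gating pattern.

```python
import heapq
import math
from typing import Dict, List, Tuple


def route_top_k(router_logits: List[float], config: Dict[str, int]) -> List[Tuple[int, float]]:
    """Hypothetical router: pick MG_TOP_K experts (default 2) and renormalize their gates."""
    k = config.get("MG_TOP_K", 2)
    # Expert indices with the largest logits, in descending order.
    top = heapq.nlargest(k, range(len(router_logits)), key=lambda i: router_logits[i])
    # Softmax over the selected logits only (max-subtracted for stability).
    m = max(router_logits[i] for i in top)
    exps = [math.exp(router_logits[i] - m) for i in top]
    z = sum(exps)
    return [(i, e / z) for i, e in zip(top, exps)]


# Usage
logits = [0.1, 2.0, -1.0, 1.5]
routes_default = route_top_k(logits, {})               # experts 1 and 3
routes_k4 = route_top_k(logits, {"MG_TOP_K": 4})       # all four experts
```

Assuming each selected expert costs about the same, moving from `top_k=4` to `2` halves the routed expert evaluations per token while every token still mixes more than one expert.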
## Validation
Passed:
```text
python -m compileall -q arbitor training testing
python -m pytest -q testing/test_gradient_capture.py testing/test_tilelang_training.py testing/test_tscale.py::test_small_ternary_training_loss_finite testing/test_tscale.py::test_e_per_component_routing testing/test_tscale.py::test_cuda_triton_tscale_path
```
Full model memory and loss probe, text-only, `ctx=64`, `max_moe_iters=4`, `MG_TOP_K=2`:
```text
logical ternary weights: 3,122,925,280
ternary training state: 4071.0 MB
after cuda: alloc=4307.3MB reserved=4378.0MB
step 0 loss=7.565 cleanup alloc=4355.6MB peak=5058.0MB
step 1 loss=7.503 cleanup alloc=4355.6MB peak=5106.3MB
step 2 loss=7.668 cleanup alloc=4355.6MB peak=5106.3MB
```
The important result is that allocated VRAM returns to the same post-step baseline. The previous behavior, where allocated memory stacked up with every step, was not reproduced.
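The probe numbers above can be turned into two quick checks. The log does not say whether "MB" means 10^6 or 2^20 bytes, so both are computed. Either way the persistent ternary training state is roughly 1.3 to 1.4 bytes per logical weight, well under the 4 bytes per weight a float32 copy would need; the gap between the peak and the post-step baseline (about 751 MB) is per-step transient memory that is released after each step.

```python
# Values copied from the memory probe log above.
LOGICAL_WEIGHTS = 3_122_925_280
STATE_MB = 4071.0
cleanup_alloc_mb = [4355.6, 4355.6, 4355.6]
peaks_mb = [5058.0, 5106.3, 5106.3]


def bytes_per_weight(state_mb: float, n_weights: int, mb_bytes: int) -> float:
    """Persistent state bytes divided by logical ternary weights."""
    return state_mb * mb_bytes / n_weights


bpw_mib = bytes_per_weight(STATE_MB, LOGICAL_WEIGHTS, 1024 ** 2)  # ~1.367 if MB = MiB
bpw_mb = bytes_per_weight(STATE_MB, LOGICAL_WEIGHTS, 1000 ** 2)   # ~1.304 if MB = 10^6 B

# No stacking: the post-step allocation is identical on every step.
no_stacking = max(cleanup_alloc_mb) == min(cleanup_alloc_mb)

# Transient memory per step, above the post-step baseline.
transient_mb = max(peaks_mb) - cleanup_alloc_mb[0]  # ~750.7
```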
Smoke test through the actual pretrain entrypoint:
```text
python training/pretrain.py --steps 2 --batch 1 --ctx 64 --text-data /tmp/tinyshakespeare.txt --max-moe-iters 4 --no-save --log-interval 1 --eval-interval 0 --save-interval 0
logical ternary weights: 3,122,925,280
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
loss step 1: 7.5023
loss step 2: 7.4532
```
## Current Status
Locally, the model now exceeds the requested target of 3.1B logical ternary weights:
```text
3.122925B logical ternary weights
```
Persistent trainable state remains packed trits and int8 accumulators/scales. The remaining practical issue is speed: the first step can still pay compile/setup cost, and MoEGraph top-k/depth is expensive at full size. The memory blocker and immediate loss blow-up are fixed in the local CUDA probes.