# TRUE TERNARY REFACTOR 15
Date: 2026-05-20
## Goal
Fix the two regressions reported after the platform restructure:
- model capacity had dropped to about 1.9B parameters instead of the 3B target
- training was running out of memory (OOM) because fp16/fp32 state was leaking into ternary paths
## Changes
### 1. Restored the 3B target shape
`arbitor/config.py` now restores the large VQ targets while keeping the motif width small enough to stay near 3B:
- `CODEBOOK_DIM = 64`
- `SHARED_VQ_SIZE = 10_000_000`
- `KGVQ_CODEBOOK_SIZE = 5_000_000`
- `KGVQ_CODEBOOK_DIM = 64`
A no-allocation constructor trace of the assembled default model reports:
```text
dummy logical ternary total: 3,011,944,672
```
This keeps the requested 10M-entry shared VQ and 5M-entry KG VQ without the accidental 1024-wide VQ explosion.
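As a rough check on where the parameters go, the sketch below counts logical ternary weights per VQ table. It assumes each codebook row contributes exactly `dim` ternary values and ignores scales and other per-row state, so it approximates the trace rather than reproducing it. `vq_logical_params` is a hypothetical helper.

```python
# Hypothetical helper: logical ternary weights held by one VQ codebook,
# ASSUMING each row contributes exactly `dim` ternary values (no scales counted).
def vq_logical_params(rows: int, dim: int) -> int:
    return rows * dim


SHARED_VQ_SIZE = 10_000_000
KGVQ_CODEBOOK_SIZE = 5_000_000

shared_64 = vq_logical_params(SHARED_VQ_SIZE, 64)      # shared VQ at CODEBOOK_DIM = 64
kg_64 = vq_logical_params(KGVQ_CODEBOOK_SIZE, 64)      # KG VQ at KGVQ_CODEBOOK_DIM = 64
shared_1024 = vq_logical_params(SHARED_VQ_SIZE, 1024)  # the accidental 1024-wide case

print(f"shared VQ @ dim 64:   {shared_64:,}")
print(f"KG VQ @ dim 64:       {kg_64:,}")
print(f"both VQs @ dim 64:    {shared_64 + kg_64:,}")
print(f"shared VQ @ dim 1024: {shared_1024:,}")
```

Under that assumption the two 64-wide VQs account for about 0.96B of the reported 3.01B, while a 1024-wide shared VQ alone would come to 10.24B.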
### 2. Removed MoEGraph fp16 edge EMA
`MoEGraph` no longer allocates dense `codebook_size * 10` graph edges for large VQ graphs and no longer registers `edge_ema` as `float16`.
Large graphs now use bounded active edge state:
- `active_edge_src`: int32
- `active_edge_dst`: int32
- `active_edge_attr`: int8 ternary edge sign
- `active_edge_score`: int8 residual score
- `edge_index`: empty compatibility buffer for large active mode
Small graphs (as used in the tests) still use dense edges, but their scores now live in an int8 `edge_score` buffer instead of an fp16 EMA.
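For scale: with `codebook_size = 10_000_000`, the old dense layout implied 100M edges, and an fp16 `edge_ema` alone would take 200 MB before counting the edge index. The sketch below illustrates the bounded alternative using Python's `array` module with the dtypes listed above. `ActiveEdgeStore`, its saturating int8 score update, and its evict-the-weakest policy are hypothetical; they only show how a fixed capacity keeps edge memory independent of codebook size.

```python
from array import array

INT8_MIN, INT8_MAX = -128, 127


class ActiveEdgeStore:
    """Hypothetical sketch of bounded active edge state.

    src/dst: int32 ('i' is 4 bytes on mainstream platforms)
    attr:    int8 ternary edge sign in {-1, 0, +1}
    score:   int8 residual score, saturating at the int8 range
    When full, the lowest-scoring edge is replaced (an ASSUMED policy).
    """

    def __init__(self, capacity: int):
        if capacity < 1:
            raise ValueError("capacity must be >= 1")
        self.capacity = capacity
        self.src = array("i")
        self.dst = array("i")
        self.attr = array("b")
        self.score = array("b")
        self._slot = {}  # (src, dst) -> index into the typed arrays

    def __len__(self) -> int:
        return len(self.src)

    def observe(self, s: int, d: int, sign: int, delta: int) -> None:
        """Record one s -> d observation with a ternary sign and score delta."""
        sign = (sign > 0) - (sign < 0)  # clamp to {-1, 0, +1}
        idx = self._slot.get((s, d))
        if idx is not None:
            # Existing edge: saturating int8 update, latest sign wins.
            new = self.score[idx] + delta
            self.score[idx] = max(INT8_MIN, min(INT8_MAX, new))
            self.attr[idx] = sign
            return
        start = max(INT8_MIN, min(INT8_MAX, delta))
        if len(self) < self.capacity:
            self._slot[(s, d)] = len(self)
            self.src.append(s)
            self.dst.append(d)
            self.attr.append(sign)
            self.score.append(start)
            return
        # Full: replace the weakest edge only if the newcomer scores higher.
        victim = min(range(len(self)), key=self.score.__getitem__)
        if self.score[victim] >= start:
            return
        del self._slot[(self.src[victim], self.dst[victim])]
        self.src[victim], self.dst[victim] = s, d
        self.attr[victim], self.score[victim] = sign, start
        self._slot[(s, d)] = victim
```

Memory here is `capacity` times roughly 10 bytes (4 + 4 + 1 + 1), regardless of how many codebook entries exist.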
### 3. Removed float KG VQ buffers
The old `KGVQCodebook` kept float32 `embed` and `embed_avg` buffers. It is now a compatibility wrapper around `TernaryVQCodebook`, so the KG/composite VQ uses packed ternary rows, int8 scales, int8 accumulators, and int16 usage counts.
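The float buffers removed here were large at the KG VQ size. A quick estimate, assuming `embed` and `embed_avg` were both shaped `(KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM)` in float32 and that the replacement keeps one int16 usage count per row (both assumptions):

```python
# Size of the old float32 KG VQ buffers, ASSUMING embed and embed_avg
# were both shaped (KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM).
KGVQ_CODEBOOK_SIZE = 5_000_000
KGVQ_CODEBOOK_DIM = 64
FLOAT32_BYTES = 4
INT16_BYTES = 2

per_buffer = KGVQ_CODEBOOK_SIZE * KGVQ_CODEBOOK_DIM * FLOAT32_BYTES
old_total = 2 * per_buffer  # embed + embed_avg

# ASSUMED layout: one int16 usage count per codebook row.
usage_counts = KGVQ_CODEBOOK_SIZE * INT16_BYTES

print(f"{per_buffer / 1e9:.2f} GB per float32 buffer, {old_total / 1e9:.2f} GB total")
print(f"{usage_counts / 1e6:.0f} MB of int16 usage counts")
```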
### 4. Large VQ initialization is now packed-first
`TernaryEmbeddingTable` now detects million-entry tables and initializes directly into:
- packed `uint8` trits
- int8 `E`
- int8 `E_accum`
- int8 `T_accum`
This avoids building temporary multi-GB float tensors for the 10M shared VQ and 5M KG VQ.
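A float32 staging tensor for the 10M x 64 shared VQ would be 10,000,000 x 64 x 4 bytes = 2.56 GB, which is the kind of temporary that packed-first initialization avoids. The sketch below shows one way trits can live in `uint8` storage: base-3 packing of five trits per byte (3^5 = 243 fits in 256). The packing scheme and the helper names are assumptions; the actual `TernaryEmbeddingTable` layout may differ, for example by using 2 bits per trit.

```python
# HYPOTHETICAL packing scheme: 5 trits per uint8 in base 3 (3**5 = 243 <= 256).
TRITS_PER_BYTE = 5


def pack_trits(trits):
    """Pack a sequence of values in {-1, 0, +1} into bytes."""
    out = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        byte = 0
        for power, t in enumerate(trits[start:start + TRITS_PER_BYTE]):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            byte += (t + 1) * 3 ** power  # map -1/0/+1 to base-3 digit 0/1/2
        out.append(byte)
    return bytes(out)


def unpack_trits(packed, n):
    """Inverse of pack_trits; n is the original number of trits."""
    trits = []
    for byte in packed:
        for _ in range(TRITS_PER_BYTE):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:n]  # drop padding digits from the last byte


def packed_row_bytes(dim):
    """Bytes needed for one packed row of `dim` trits: ceil(dim / 5)."""
    return -(-dim // TRITS_PER_BYTE)
```

Under this scheme a 64-trit row takes 13 bytes, so the 10M-row table would hold 130 MB of packed trits (excluding `E`, `E_accum`, and `T_accum`), and a chunked initializer built on it would never need to materialize the full float table.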
### 5. Removed persistent fp32 Triton training hooks
The Triton ternary backward path now stores `_hook_grad_T_sign` as int8 instead of retaining the fp32 `_hook_grad_2d` and `_hook_x_2d` gradient/activation views on each ternary module after backward.
The direct fp32 hook fallback remains only for non-Triton compatibility paths, and the tests now assert that the CUDA Triton path does not retain fp32 grad/x hooks.
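The memory difference comes from what survives backward. The plain-Python sketch below computes the sign of a linear layer's weight gradient (`grad_out^T @ x`) with a transient float accumulator and keeps only an int8 result. It assumes `_hook_grad_T_sign` holds the sign of the weight gradient; `weight_grad_sign_int8`, `TernaryLinearStub`, and `on_backward` are hypothetical. Python floats are fp64, so this models the storage pattern, not the Triton kernel's numerics.

```python
from array import array


def weight_grad_sign_int8(grad_out, x):
    """Return sign(grad_out^T @ x) as a flat row-major int8 array.

    grad_out: N rows of out_dim floats (gradient w.r.t. the layer output)
    x:        N rows of in_dim floats  (the layer input)
    The float accumulator exists only inside this call; only the int8
    sign of the (out_dim x in_dim) weight gradient is returned.
    """
    if len(grad_out) != len(x):
        raise ValueError("grad_out and x must have the same number of rows")
    out_dim, in_dim = len(grad_out[0]), len(x[0])
    signs = array("b")
    for o in range(out_dim):
        for i in range(in_dim):
            acc = 0.0  # transient accumulator, discarded after each element
            for g_row, x_row in zip(grad_out, x):
                acc += g_row[o] * x_row[i]
            signs.append((acc > 0) - (acc < 0))
    return signs


class TernaryLinearStub:
    """Hypothetical module stand-in: keeps only the int8 sign after backward."""

    def __init__(self):
        self._hook_grad_T_sign = None

    def on_backward(self, grad_out, x):
        self._hook_grad_T_sign = weight_grad_sign_int8(grad_out, x)
        # grad_out and x are not stored on the module, so no float
        # activation/gradient views outlive backward.
```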
### 6. Ternary MoE centroids
MoEGraph routing centroids are now a `TernaryEmbeddingTable` instead of a float `nn.Parameter`.
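With ternary centroids, a dot-product routing score needs only additions and subtractions. The sketch assumes top-1 routing by an unscaled dot product; whether MoEGraph applies the table's int8 scales or routes to more than one expert is not stated here, and `route_ternary` is a hypothetical name.

```python
def route_ternary(x, centroids):
    """Pick the expert whose ternary centroid has the largest dot product with x.

    Centroid entries must be in {-1, 0, +1}: +1 adds x[i], -1 subtracts it,
    0 skips it, so no multiplications are needed. Returns (best_index, scores).
    """
    scores = []
    for c in centroids:
        s = 0.0
        for xi, ci in zip(x, c):
            if ci == 1:
                s += xi
            elif ci == -1:
                s -= xi
            elif ci != 0:
                raise ValueError(f"centroid entry {ci} is not ternary")
        scores.append(s)
    best = max(range(len(scores)), key=scores.__getitem__)
    return best, scores
```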
## Validation
Passed:
```bash
python -m compileall -q arbitor training testing
python -m pytest -q testing/kg/test_kg_edges.py testing/kg/test_composite_head.py testing/test_gradient_capture.py testing/test_tilelang_training.py
python -m pytest -q testing/test_tscale.py::test_cuda_triton_tscale_path
python -m pytest -q --import-mode=importlib testing/model/test_tscale.py::test_cuda_triton_tscale_path
```
Additional targeted checks passed:
- large active MoEGraph with `codebook_size=10_000_000` has `edge_index.shape == (2, 0)` and no float edge buffers
- 1M-entry `TernaryVQCodebook` has no float buffers and trains through sparse forward/backward/update
- small active MoEGraph forward/backward remains finite with ternary centroids
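The "no float buffers" checks above can be expressed as a small helper over `torch.nn.Module.named_buffers()`, using `Tensor.is_floating_point()`. `float_buffer_names` is a hypothetical helper, not part of the test suite; it is written against duck-typed `(name, tensor)` pairs so it runs without torch.

```python
def float_buffer_names(named_buffers):
    """Return the names of buffers whose dtype is floating point.

    `named_buffers` is any iterable of (name, tensor) pairs, such as the
    iterator returned by torch.nn.Module.named_buffers(); each tensor must
    provide is_floating_point(), as torch.Tensor does.
    """
    return [name for name, buf in named_buffers if buf.is_floating_point()]
```

In a test, `assert float_buffer_names(graph.named_buffers()) == []` would cover the buffer check for a hypothetical `graph = MoEGraph(...)`; passing `named_parameters()` the same way extends it to parameters.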
## Remaining Risk
The kernels still use fp32 accumulators internally, and losses are still floating-point scalars. This pass removes persistent fp16/fp32 ternary state and the retained fp32 training hooks, which were the memory-leak/OOM concern. A fully integer activation/loss path would require a separate, kernel-level redesign.