
TRUE TERNARY REFACTOR 15

Date: 2026-05-20

Goal

Fix the two regressions reported after the platform restructure:

  • model capacity had dropped to about 1.9B instead of the 3B target
  • training was hitting OOM errors because fp16/fp32 state was leaking into ternary paths

Changes

1. Restored the 3B target shape

arbitor/config.py now restores the large VQ targets while keeping the motif width small enough for the model to stay near the 3B target:

  • CODEBOOK_DIM = 64
  • SHARED_VQ_SIZE = 10_000_000
  • KGVQ_CODEBOOK_SIZE = 5_000_000
  • KGVQ_CODEBOOK_DIM = 64

A no-allocation constructor trace of the assembled default model reports:

dummy logical ternary total: 3,011,944,672

This keeps the requested 10M-entry shared VQ and 5M-entry KG VQ while avoiding the accidental 1024-wide VQ explosion.
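To see why the codebook width matters so much, multiply rows by width for the two VQ tables. The sketch below does only that arithmetic, reusing the config constant names listed above. It does not model the rest of the network, and the 1024 case simply substitutes the accidental width for CODEBOOK_DIM on the shared table.

```python
# Logical trit counts (rows x width) for the two VQ tables only;
# the rest of the ~3.01B model is not modeled here.
SHARED_VQ_SIZE = 10_000_000
KGVQ_CODEBOOK_SIZE = 5_000_000
CODEBOOK_DIM = 64
KGVQ_CODEBOOK_DIM = 64
TRACED_TOTAL = 3_011_944_672  # from the constructor trace above


def table_trits(rows: int, width: int) -> int:
    """Logical ternary weights in a rows x width codebook."""
    return rows * width


shared = table_trits(SHARED_VQ_SIZE, CODEBOOK_DIM)
kg = table_trits(KGVQ_CODEBOOK_SIZE, KGVQ_CODEBOOK_DIM)
shared_wide = table_trits(SHARED_VQ_SIZE, 1024)  # the accidental 1024-wide case

print(f"shared VQ: {shared:,}")
print(f"KG VQ:     {kg:,}")
print(f"VQ share of traced total: {(shared + kg) / TRACED_TOTAL:.1%}")
print(f"shared VQ at width 1024: {shared_wide:,}")
```

At width 64 the two tables hold 960,000,000 logical trits, roughly a third of the traced total. At width 1024, the shared table alone would hold 10,240,000,000, which is more than three times the entire 3B budget.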

2. Removed MoEGraph fp16 edge EMA

MoEGraph no longer allocates dense codebook_size * 10 graph edges for large VQ graphs, and it no longer registers edge_ema as a float16 buffer.

Large graphs now use bounded active edge state:

  • active_edge_src: int32
  • active_edge_dst: int32
  • active_edge_attr: int8 ternary edge sign
  • active_edge_score: int8 residual score
  • edge_index: empty compatibility buffer for large active mode
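The field list above fixes only the dtypes of the active edge state. It does not describe the capacity or replacement policy. The following is a minimal sketch of a bounded store with those dtypes, using the stdlib array module. The class name, the fixed capacity and the "replace the lowest-score edge" rule are all hypothetical choices made for illustration.

```python
from array import array


class ActiveEdgeStore:
    """Hypothetical sketch of bounded active edge state for a large VQ graph.

    Each kept edge costs 4 + 4 + 1 + 1 = 10 bytes (int32 src/dst, int8 sign,
    int8 score), and at most `capacity` edges are kept regardless of codebook size.
    """

    def __init__(self, capacity: int):
        if capacity < 1:
            raise ValueError("capacity must be at least 1")
        self.capacity = capacity
        self.src = array("i")    # C int, 4 bytes (int32) on mainstream platforms
        self.dst = array("i")
        self.attr = array("b")   # int8 ternary edge sign in {-1, 0, +1}
        self.score = array("b")  # int8 residual score

    def __len__(self) -> int:
        return len(self.src)

    def add(self, src: int, dst: int, sign: int, score: int) -> bool:
        """Insert an edge; when full, replace the lowest-score edge if the new one beats it."""
        if sign not in (-1, 0, 1):
            raise ValueError("edge sign must be -1, 0 or +1")
        score = max(-128, min(127, score))  # clamp into int8 range
        if len(self) < self.capacity:
            self.src.append(src)
            self.dst.append(dst)
            self.attr.append(sign)
            self.score.append(score)
            return True
        worst = min(range(len(self)), key=self.score.__getitem__)
        if score <= self.score[worst]:
            return False
        self.src[worst], self.dst[worst] = src, dst
        self.attr[worst], self.score[worst] = sign, score
        return True

    def nbytes(self) -> int:
        return sum(a.itemsize * len(a) for a in (self.src, self.dst, self.attr, self.score))


store = ActiveEdgeStore(capacity=2)
store.add(0, 5, +1, 10)
store.add(1, 7, -1, 3)
store.add(2, 9, +1, 20)  # store is full: replaces the score-3 edge
print(list(store.src), list(store.score), store.nbytes())
```

For scale, the dense layout with codebook_size=10_000_000 would mean 100,000,000 edges (codebook_size * 10) before any per-edge payload is counted. A bounded store's footprint depends only on its capacity.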

Small-graph tests still use dense edges, but their score path now uses an int8 edge_score instead of the fp16 EMA.
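This note says only that the dense path keeps an int8 edge_score instead of an fp16 EMA. It does not say how that score is updated. The sketch below shows one way an EMA-like update can stay entirely in int8: a power-of-two decay implemented as a right shift. The function name and the shift-based rule are assumptions, not the project's code.

```python
def int8_ema_update(score: int, observation: int, shift: int = 3) -> int:
    """Move an int8 score toward an int8 observation by about 1 / 2**shift.

    Integer-only analogue of an EMA with decay 1 - 2**-shift. Python's >>
    floors toward negative infinity, so the result always lies between
    `score` and `observation` and therefore stays inside the int8 range.
    """
    for v in (score, observation):
        if not -128 <= v <= 127:
            raise ValueError("score and observation must fit in int8")
    return score + ((observation - score) >> shift)


s = 0
for _ in range(20):
    s = int8_ema_update(s, 100)
print(s)
```

Because of the flooring, positive gaps smaller than 2**shift stop moving the score: with shift=3, a score of 95 stays at 95 when the target is 100. Negative gaps always move it by at least one step. A real kernel might round instead to remove that asymmetry.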

3. Removed float KG VQ buffers

The old KGVQCodebook kept float32 embed and embed_avg buffers. It is now a compatibility wrapper around TernaryVQCodebook, so the KG/composite VQ uses packed ternary rows, int8 scales, int8 accumulators, and int16 usage counts.
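A compatibility wrapper of this kind can be as thin as a subclass that keeps the old constructor arguments and inherits integer-only storage. The sketch below is hypothetical throughout. The class names, the `codebook_size`/`codebook_dim` argument names, 2-bit packing, and one int8 accumulator per element are assumptions chosen to match the buffer dtypes listed above, not TernaryVQCodebook's real layout.

```python
from array import array


class TernaryVQCodebookSketch:
    """Hypothetical stand-in for TernaryVQCodebook with integer-only storage."""

    TRITS_PER_BYTE = 4  # assumption: 2-bit packing per trit

    def __init__(self, num_codes: int, dim: int):
        self.num_codes, self.dim = num_codes, dim
        row_bytes = -(-dim // self.TRITS_PER_BYTE)       # ceil(dim / 4)
        self.packed = bytearray(num_codes * row_bytes)   # packed ternary rows
        self.scales = array("b", bytes(num_codes))       # int8 per-row scale
        self.accum = array("b", bytes(num_codes * dim))  # int8 update accumulators
        self.usage = array("h", bytes(2 * num_codes))    # int16 usage counts, zeroed

    def record_usage(self, code: int) -> None:
        """Count one hit for `code`, saturating at the int16 maximum rather than overflowing."""
        if self.usage[code] < 32767:
            self.usage[code] += 1


class KGVQCodebookSketch(TernaryVQCodebookSketch):
    """Compatibility wrapper: old KG-style constructor arguments, no float buffers."""

    def __init__(self, codebook_size: int, codebook_dim: int):
        super().__init__(num_codes=codebook_size, dim=codebook_dim)


kg = KGVQCodebookSketch(codebook_size=1000, codebook_dim=64)
kg.record_usage(3)
print(len(kg.packed), kg.usage[3], [a.typecode for a in (kg.scales, kg.accum, kg.usage)])
```

Old call sites keep constructing the KG codebook the same way. Every buffer it now owns is a packed byte array or an int8/int16 array.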

4. Large VQ initialization is now packed-first

TernaryEmbeddingTable now detects million-entry tables and initializes directly into:

  • packed uint8 trits
  • int8 E
  • int8 E_accum
  • int8 T_accum

This avoids building temporary multi-GB float tensors for the 10M shared VQ and 5M KG VQ.
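The key idea of packed-first initialization is that random trits are generated as small integers and written straight into the packed buffer. No full-size float tensor is built first and quantized afterward. The sketch below assumes 2 bits per trit (stored as trit + 1, four trits per byte), which is one common layout. The source does not specify the packing, and the int8 E, E_accum and T_accum buffers are omitted. All function names are hypothetical.

```python
import random

TRITS_PER_BYTE = 4  # assumption: 2 bits per trit, stored as code = trit + 1


def pack_row(trits):
    """Pack trits in {-1, 0, +1} into bytes, 4 trits per byte, low bits first."""
    out = bytearray(-(-len(trits) // TRITS_PER_BYTE))
    for i, t in enumerate(trits):
        out[i // TRITS_PER_BYTE] |= (t + 1) << (2 * (i % TRITS_PER_BYTE))
    return out


def unpack_row(packed, dim):
    """Inverse of pack_row for the first `dim` trits."""
    return [((packed[i // TRITS_PER_BYTE] >> (2 * (i % TRITS_PER_BYTE))) & 0b11) - 1
            for i in range(dim)]


def init_packed_table(rows, dim, seed=0):
    """Fill a ternary table directly into packed bytes, one row at a time.

    Only one row of small ints exists at any moment; no (rows, dim) float
    matrix is ever created and then quantized.
    """
    rng = random.Random(seed)
    row_bytes = -(-dim // TRITS_PER_BYTE)
    table = bytearray(rows * row_bytes)
    for r in range(rows):
        row = [rng.choice((-1, 0, 1)) for _ in range(dim)]
        table[r * row_bytes:(r + 1) * row_bytes] = pack_row(row)
    return table


# Storage for a 10M x 64 table: fp32 staging vs 2-bit packed.
print(f"fp32:   {10_000_000 * 64 * 4:,} bytes")
print(f"packed: {10_000_000 * (64 // TRITS_PER_BYTE):,} bytes")
```

For the 10M shared VQ, that is 2,560,000,000 bytes of fp32 staging versus 160,000,000 packed bytes, a 16x difference before the int8 side buffers are counted.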

5. Removed persistent fp32 Triton training hooks

The Triton ternary backward path now stores only _hook_grad_T_sign as int8. It no longer keeps the fp32 _hook_grad_2d and _hook_x_2d gradient/activation views on each ternary module after backward.

The direct fp32 hook fallback remains only for non-Triton compatibility paths, and the tests now assert that the CUDA Triton path does not retain fp32 grad/x hooks.
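Assuming _hook_grad_T_sign holds the sign of the weight gradient (grad_out transposed times x), the change amounts to reducing over the batch and keeping only a ternary sign per weight. This lets both float views be released. The pure-Python sketch below shows that data flow only. In the real path the reduction would run inside the Triton kernel, and the function name here is hypothetical.

```python
from array import array


def grad_sign_int8(x_2d, grad_2d):
    """Return sign(grad_2d^T @ x_2d) as a flat, row-major int8 array of shape (out, in).

    x_2d:    N rows of `in` floats (layer inputs)
    grad_2d: N rows of `out` floats (gradient w.r.t. layer outputs)
    Only the int8 signs come back, so the caller can drop both float views.
    """
    n_rows, n_in, n_out = len(x_2d), len(x_2d[0]), len(grad_2d[0])
    signs = array("b")
    for o in range(n_out):
        for i in range(n_in):
            g = sum(grad_2d[r][o] * x_2d[r][i] for r in range(n_rows))
            signs.append((g > 0) - (g < 0))  # -1, 0 or +1
    return signs


x = [[1.0, 2.0], [1.0, -2.0]]
grad = [[1.0], [1.0]]
print(list(grad_sign_int8(x, grad)))
```

As an illustration with made-up sizes, a 4096 x 4096 layer would keep a 16 MiB int8 sign map. A retained fp32 x_2d view for 8192 tokens at width 4096 would alone take 128 MiB per module.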

6. Ternary MoE centroids

MoEGraph routing centroids are now a TernaryEmbeddingTable instead of a float nn.Parameter.
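With ternary centroids, the routing score against an input needs only additions and subtractions. The sketch below assumes the router picks the top-k experts by dot product with each centroid row. The source does not state the routing rule, the function names are hypothetical, and any per-row int8 scale from the embedding table is ignored.

```python
def ternary_scores(centroids, x):
    """Dot product of each ternary centroid row with x using only adds and subtracts."""
    scores = []
    for row in centroids:
        s = 0.0
        for t, v in zip(row, x):
            if t == 1:
                s += v
            elif t == -1:
                s -= v
        scores.append(s)
    return scores


def route_topk(centroids, x, k=1):
    """Indices of the k highest-scoring experts (ties keep the lower index first)."""
    scores = ternary_scores(centroids, x)
    return sorted(range(len(scores)), key=lambda e: scores[e], reverse=True)[:k]


centroids = [[1, -1, 0], [0, 1, 1], [-1, 0, 1]]  # one ternary row per expert
x = [0.5, 2.0, 1.0]
print(ternary_scores(centroids, x), route_topk(centroids, x, k=2))
```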

Validation

Passed:

python -m compileall -q arbitor training testing
python -m pytest -q testing/kg/test_kg_edges.py testing/kg/test_composite_head.py testing/test_gradient_capture.py testing/test_tilelang_training.py
python -m pytest -q testing/test_tscale.py::test_cuda_triton_tscale_path
python -m pytest -q --import-mode=importlib testing/model/test_tscale.py::test_cuda_triton_tscale_path

Additional targeted checks passed:

  • large active MoEGraph with codebook_size=10_000_000 has edge_index.shape == (2, 0) and no float edge buffers
  • 1M-entry TernaryVQCodebook has no float buffers and trains through sparse forward/backward/update
  • small active MoEGraph forward/backward remains finite with ternary centroids

Remaining Risk

The kernels still use fp32 accumulators internally, and the losses still produce floating-point scalar values. This pass removes the persistent fp16/fp32 ternary state and the retained fp32 training hooks, which were the source of the memory leak and OOM. A fully integer activation/loss path would require a separate, kernel-level redesign.