True Ternary Refactor 7 — MoE/Graph Kernel Hardening

Scope

This phase prioritizes MoE and Graph runtime hardening after REFACTOR6 confirmed the internal model can run with zero trainable float params and zero float buffers in strict mode.

The agent's REFACTOR6 readout is consistent with the code behavior:

  • E_accum being active in most groups is a real signal that residual scale updates are now preserving small gradient evidence instead of dropping it.
  • A 20-step loss curve is not enough to judge convergence for a randomly initialized pure ternary model with sign flips and integer scale residuals.
  • The reported step-time swings are plausibly a kernel-shape/recompile issue, especially in MoE sparse routing where each expert receives a different token count per step.

MoE Dispatch Hardening

Added a fixed-shape dense dispatch path for small token counts in SharedProjectionMoE.

Previous small-batch sparse path:

top-k slot loop -> sort tokens -> expert loop -> variable [n_tokens, D] ternary kernels

That path is compute-efficient when routing large batches, but the Triton kernel shapes become unstable because n_tokens changes per expert from step to step. In strict training with small context/batch sizes this can trigger repeated recompilation and slow outlier steps.

New small-batch path:

if N <= dense_dispatch_max_tokens:
    run every expert on fixed [N, D] inputs
    multiply each expert output by its top-k routing weight, zero for unselected tokens
else:
    keep sparse dynamic expert dispatch

Default cutoff:

dense_dispatch_max_tokens = 128

This keeps the sparse path available for larger runs while making strict-mode smoke and small-batch training use stable shapes.
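The two paths can be sketched in plain Python to show that they produce the same output while differing only in how many rows each expert sees. This is an illustrative sketch, not the SharedProjectionMoE code: lists stand in for tensors, `toy_expert` is a hypothetical stand-in for a ternary expert projection, and the `routes` layout (one dict of expert id to top-k weight per token) is an assumption made for the example.

```python
# Illustrative sketch only: plain Python lists stand in for tensors, and
# toy_expert is a hypothetical stand-in for a ternary expert projection.

def toy_expert(expert_id, token):
    # Hypothetical expert: scale every feature by (expert_id + 1).
    return [(expert_id + 1) * x for x in token]

def dense_dispatch(tokens, routes, num_experts):
    """Run every expert on all N tokens; unselected tokens get weight 0."""
    out = [[0.0] * len(tokens[0]) for _ in tokens]
    for e in range(num_experts):
        for n, token in enumerate(tokens):       # always N rows per expert
            weight = routes[n].get(e, 0.0)       # zero if e is not in the top-k
            for d, value in enumerate(toy_expert(e, token)):
                out[n][d] += weight * value
    return out

def sparse_dispatch(tokens, routes, num_experts):
    """Run each expert only on the tokens routed to it (variable row count)."""
    out = [[0.0] * len(tokens[0]) for _ in tokens]
    for e in range(num_experts):
        selected = [n for n in range(len(tokens)) if e in routes[n]]
        for n in selected:                       # row count changes per step
            for d, value in enumerate(toy_expert(e, tokens[n])):
                out[n][d] += routes[n][e] * value
    return out

def dispatch(tokens, routes, num_experts, dense_dispatch_max_tokens=128):
    # Small token counts take the fixed-shape path; larger ones stay sparse.
    if len(tokens) <= dense_dispatch_max_tokens:
        return dense_dispatch(tokens, routes, num_experts)
    return sparse_dispatch(tokens, routes, num_experts)

tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# routes[n] maps expert id -> top-k routing weight for token n (k = 2 here).
routes = [{0: 0.75, 2: 0.25}, {1: 0.5, 2: 0.5}, {0: 0.25, 1: 0.75}]
dense_out = dense_dispatch(tokens, routes, num_experts=3)
sparse_out = sparse_dispatch(tokens, routes, num_experts=3)
```

In the dense path every expert processes the same N rows on every step, which is what keeps the kernel shape fixed. The cost is wasted work on tokens whose routing weight is zero, which is why the cutoff keeps the sparse path for large N.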

Small CUDA timing on B=2, L=8, D=512:

dense_fixed_shape:    first call 0.3286 s, then 0.0047 s, 0.0037 s, 0.0037 s, 0.0036 s
sparse_dynamic_shape: first call 0.0415 s, then 0.0077 s, 0.0088 s, 0.0076 s, 0.0080 s

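The first call still includes compilation/setup, and that one-time cost is higher for the dense path. In steady state, however, the fixed-shape path is roughly twice as fast for this strict small-batch shape (about 0.004 s versus 0.008 s per call), and it avoids per-expert token-count shape churn.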
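The arithmetic behind that comparison can be checked directly from the numbers above. The snippet below is a rough worked example, not part of the benchmark script. The break-even estimate assumes the sparse path never recompiles after its first call, which is the optimistic case for sparse dispatch.

```python
from statistics import mean

# Timings in seconds, copied from the measurements above.
dense_first, sparse_first = 0.3286, 0.0415
dense_steady = [0.0047, 0.0037, 0.0037, 0.0036]
sparse_steady = [0.0077, 0.0088, 0.0076, 0.0080]

dense_mean = mean(dense_steady)
sparse_mean = mean(sparse_steady)

# Steady-state ratio: how much faster the fixed-shape path is per call.
speedup = sparse_mean / dense_mean

# Calls needed for the dense path to repay its larger first-call cost,
# assuming (optimistically for sparse) no further sparse recompiles.
break_even_calls = (dense_first - sparse_first) / (sparse_mean - dense_mean)

print(f"dense mean:  {dense_mean * 1e3:.2f} ms")
print(f"sparse mean: {sparse_mean * 1e3:.2f} ms")
print(f"steady-state speedup: {speedup:.2f}x")
print(f"break-even after about {break_even_calls:.0f} calls")
```

On these five samples the dense path repays its extra first-call cost after about 70 calls. Any sparse recompile triggered by a new per-expert token count would shorten that further.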

Graph Hardening

REFACTOR6 already added a Triton Graph aggregation kernel:

projected messages + ternary edge weighting + target aggregation

This replaces the previous CUDA path that materialized messages and then called scatter_add_. The Graph path still has a Python hop loop and still calls the ternary projections once per hop. Fusing a full graph hop is left as a separate kernel because it would need packed ternary decode, edge aggregation, update projection, residual, and hop LoRA scheduling in one launch.
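The difference between the two aggregation strategies can be sketched in plain Python. This is only an illustration of the data flow, not the Triton kernel: the function names, the edge-list layout, and the use of integer features are assumptions made for the example.

```python
def aggregate_materialized(projected, edges, edge_trits, num_nodes):
    """Old-style path: build one message per edge, then scatter-add."""
    # Step 1: materialize a weighted message for every edge ([E, D] storage).
    messages = [
        [trit * x for x in projected[src]]
        for (src, _dst), trit in zip(edges, edge_trits)
    ]
    # Step 2: scatter-add each message into its target node.
    out = [[0] * len(projected[0]) for _ in range(num_nodes)]
    for (_src, dst), msg in zip(edges, messages):
        for d, value in enumerate(msg):
            out[dst][d] += value
    return out

def aggregate_fused(projected, edges, edge_trits, num_nodes):
    """Fused-style path: weight and accumulate per edge, no message buffer."""
    out = [[0] * len(projected[0]) for _ in range(num_nodes)]
    for (src, dst), trit in zip(edges, edge_trits):
        if trit == 0:
            continue                      # zero trits contribute nothing
        for d, x in enumerate(projected[src]):
            # A ternary weight only adds or subtracts; no multiply is needed.
            out[dst][d] += x if trit > 0 else -x
    return out

# Projected node features, directed edges (src, dst), and ternary edge weights.
projected = [[1, 2], [3, 4], [5, 6]]
edges = [(0, 1), (2, 1), (1, 0), (0, 2)]
edge_trits = [1, -1, 0, 1]

materialized = aggregate_materialized(projected, edges, edge_trits, num_nodes=3)
fused = aggregate_fused(projected, edges, edge_trits, num_nodes=3)
```

Both functions return the same per-node sums. The fused form simply never allocates the intermediate per-edge message buffer.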

Learning Hardening

Fixed ByteEmbedding.ternary_step() to honor _t_accum_step.

Before this fix, TernaryScaleTensor and TernaryEmbeddingTable consumed the loss-scaled integer step, but ByteEmbedding still did:

T_accum += sign(grad)

Now it does:

T_accum += sign(grad) * t_accum_step

This matters because text byte embeddings are part of the strict model's trainable ternary state. Without this, the embedding trits could lag behind the rest of the model and require more updates to reach threshold.
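The effect of the fix on update latency can be illustrated with a single-trit toy model. Only the accumulator line mirrors the update described above. The flip rule, the `flip_threshold` value, and the direction of the flip are hypothetical choices made for this sketch and are not taken from ByteEmbedding.

```python
def sign(x):
    # Integer sign: -1, 0, or +1.
    return (x > 0) - (x < 0)

def ternary_step(trit, t_accum, grad, t_accum_step=1, flip_threshold=8):
    """One accumulator update for a single trit (toy model)."""
    # This line mirrors the fixed update: the sign is scaled by the step.
    t_accum += sign(grad) * t_accum_step
    # Hypothetical flip rule: move the trit against the gradient sign once
    # the accumulated evidence reaches the threshold, then reset.
    if t_accum >= flip_threshold:
        trit, t_accum = max(trit - 1, -1), 0
    elif t_accum <= -flip_threshold:
        trit, t_accum = min(trit + 1, 1), 0
    return trit, t_accum

def steps_to_flip(t_accum_step, flip_threshold=8):
    """Count updates with a constant positive gradient until the trit moves."""
    trit, t_accum, steps = 1, 0, 0
    while trit == 1:
        trit, t_accum = ternary_step(
            trit, t_accum, grad=0.3,
            t_accum_step=t_accum_step, flip_threshold=flip_threshold,
        )
        steps += 1
    return steps

unscaled = steps_to_flip(t_accum_step=1)   # old behavior: step ignored
scaled = steps_to_flip(t_accum_step=4)     # fixed behavior: step honored
```

With these assumed values the unscaled accumulator needs 8 updates to flip and the scaled one needs 2. That is the kind of lag the fix removes when the rest of the model already consumes the loss-scaled step.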

Verification

  • python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py
  • python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path": 2 passed
  • moe_dense_dispatch_cuda_ok: forward/backward passed.
  • byte_embedding_t_step_ok: embedding accumulator consumed the loss-scaled step and flipped at threshold.
  • Full CUDA model smoke with VQ, Graph, Memory, and MoE enabled passed forward, backward, and _ternary_update_memory().
  • Strict one-step train smoke remained zero-float and ran the training step at about 1.91 it/s before eval/checkpoint overhead.

Remaining Kernel Work

  1. Full fused MoE expert kernel:

    • packed ternary expert decode
    • top-k route weights
    • per-expert low-rank projection
    • shared hidden multiply
    • weighted output accumulation
  2. Full fused Graph hop kernel:

    • packed ternary projection decode
    • edge aggregation
    • update projection
    • residual/hop-LoRA composition
  3. Scale precision extension:

    • The current scale is a pure power of two, S = 2^E.
    • Representing values such as S = 99.9 accurately needs a low-overhead integer/fixed-point mantissa or a residual scale lattice.
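As a rough illustration of the scale precision item, the sketch below compares a pure power-of-two scale with one possible fixed-point mantissa form, S ≈ m · 2^(E − frac_bits). The mantissa layout, the `frac_bits` value, and both function names are assumptions made for this example, not a committed design.

```python
import math

def pow2_scale(target):
    """Nearest power-of-two scale in the log domain: S = 2^E."""
    exponent = round(math.log2(target))
    return exponent, 2.0 ** exponent

def mantissa_scale(target, frac_bits=7):
    """Hypothetical fixed-point form: S = m * 2^(E - frac_bits), integer m."""
    exponent = math.floor(math.log2(target))
    # Integer mantissa covering the range [1, 2) in steps of 2^-frac_bits.
    mantissa = round(target / 2.0 ** exponent * (1 << frac_bits))
    return exponent, mantissa, mantissa * 2.0 ** (exponent - frac_bits)

target = 99.9
e_pow2, s_pow2 = pow2_scale(target)
e_fix, m_fix, s_fix = mantissa_scale(target)

pow2_error = abs(s_pow2 - target) / target
fixed_error = abs(s_fix - target) / target

print(f"power of two: E={e_pow2}, S={s_pow2}, relative error {pow2_error:.1%}")
print(f"fixed point:  E={e_fix}, m={m_fix}, S={s_fix}, relative error {fixed_error:.1%}")
```

For a target of 99.9 the power-of-two scale lands on 128 (about 28% off), while an integer mantissa with 7 fractional bits lands on 100.0 (about 0.1% off) at the cost of one extra small integer per scale.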