TRUE TERNARY REFACTOR 19: KV/HCA Shape Safety and Training Entry Points

Problems Addressed

  • HCA attention could break, leaving the KV and positional caches with different key counts, when the full-context path had already supplied a strided cache and MLA then applied MLA_HCA_STRIDE to the positional cache a second time.
  • Training scripts made it hard to run quick loss checks because text training instantiated the full VQ/KG/MoE/output stack by default.
  • TileLang was still reachable from training even though its BigInt update path is not production-ready.
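
A minimal sketch of the double-stride drift. It assumes HCA compression can be modeled as a plain stride over the sequence and that MLA_HCA_STRIDE = 4. Both are assumptions: the real caches are tensors, and only the constant name comes from this note.

```python
# Assumption for illustration: HCA compression is modeled as a plain stride over
# the sequence, and MLA_HCA_STRIDE = 4. Real caches are tensors, not lists.
MLA_HCA_STRIDE = 4


def stride_cache(cache, stride):
    """Keep every `stride`-th entry along the sequence (the assumed HCA compression)."""
    return cache[::stride]


seq_len = 64
positions = list(range(seq_len))  # token positions feeding both caches

# The full-context path has already strided both caches once.
kv_cache = stride_cache(positions, MLA_HCA_STRIDE)  # keys at positions 0, 4, ..., 60
pe_cache = stride_cache(positions, MLA_HCA_STRIDE)  # matching positional entries

# Old behavior: MLA applied MLA_HCA_STRIDE again to the already-strided positional cache.
pe_cache_double = stride_cache(pe_cache, MLA_HCA_STRIDE)

print(len(kv_cache), len(pe_cache_double))  # 16 4
print(pe_cache_double)  # [0, 16, 32, 48]
```

In this model the KV cache keeps 16 keys while the doubly strided positional cache keeps 4. The surviving positions (0, 16, 32, 48) also no longer correspond to the first keys (0, 4, 8, 12).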

Changes

  • arbitor/attention/mla.py

    • Added explicit hca_pe_cache support.
    • Crops KV and positional caches to the same key count before attention.
    • Prevents HCA compressed cache and positional cache shape drift.
  • arbitor/attention/context_attention.py

    • Passes the matching HCA positional cache directly to MLA instead of relying on MLA to stride it a second time.
  • arbitor/kernel/ternary_scale.py

    • TileLang remains available for inference, but training now raises an error unless ARB_TILELANG_TRAINING=1 is explicitly set.
    • The experimental TileLang backward pass now streams the BigInt corr_accum through the Triton correlation-update kernel instead of the old T_accum/E_accum path.
  • arbitor/main.py

    • Added optional construction flags for attention, output router, video head, and talker head.
    • Forward passes that target text skip output routing entirely.
    • Training scripts can now instantiate only the components needed for the modality being trained.
  • training/text.py

    • Defaults to a lightweight byte-text stack: no VQ, graph, MoE, KV attention, video head, talker head, or output router.
    • --full-text-stack opts back into VQ + KG/MoEGraph + KV attention.
    • Adds --backend and --log-interval, and handles --eval-interval 0 safely.
  • training/pretrain.py

    • Adds --backend, defaulting to triton.
    • Adds --no-vq, --no-graph, --no-moe, --no-attention, and --enable-output-router.
    • Enables the video and talker heads only when the video and audio modality weights, respectively, are active.
    • Prints per-step losses every --log-interval steps.
  • training/audio.py, training/vision.py, training/diffusion.py

    • Added explicit backend selection.
    • Disabled unrelated output heads, the output router, and attention for modality-specific training.
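
The mla.py and context_attention.py changes can be sketched as follows. align_hca_caches is a hypothetical helper, not the mla.py API, and MLA_HCA_STRIDE = 4 is an assumed value. Two further behaviors are assumptions: when no hca_pe_cache is passed, the sketch strides the positional cache once, and when cropping, it keeps the leading entries.

```python
from typing import List, Optional, Sequence, Tuple

MLA_HCA_STRIDE = 4  # assumed value for illustration


def align_hca_caches(
    kv_cache: Sequence[int],
    pe_cache: Sequence[int],
    hca_pe_cache: Optional[Sequence[int]] = None,
    stride: int = MLA_HCA_STRIDE,
) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: return KV and positional caches with one shared key count."""
    if hca_pe_cache is not None:
        # The caller already strided the positional cache to match kv_cache: use it directly.
        pe = list(hca_pe_cache)
    else:
        # No explicit HCA cache: stride the full positional cache exactly once.
        pe = list(pe_cache)[::stride]
    # Crop both caches to the same number of keys before attention.
    # Keeping the leading entries is an assumption of this sketch.
    n_keys = min(len(kv_cache), len(pe))
    return list(kv_cache)[:n_keys], pe[:n_keys]


positions = list(range(64))
kv = positions[::MLA_HCA_STRIDE]  # 16 compressed keys from the full-context path
hca_pe = kv[:-1]  # deliberately one entry short to show the crop

keys, pes = align_hca_caches(kv, positions, hca_pe_cache=hca_pe)
print(len(keys), len(pes))  # 15 15

keys, pes = align_hca_caches(kv, positions)  # no hca_pe_cache: strided once inside
print(len(keys), len(pes))  # 16 16
```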
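
A sketch of the training guard in ternary_scale.py. Only the ARB_TILELANG_TRAINING variable comes from this note. The function check_tilelang_allowed and the use of RuntimeError are assumptions.

```python
import os


def check_tilelang_allowed(training: bool) -> None:
    """Hypothetical guard: TileLang is allowed for inference, opt-in only for training."""
    if training and os.environ.get("ARB_TILELANG_TRAINING") != "1":
        raise RuntimeError(
            "TileLang training is experimental; set ARB_TILELANG_TRAINING=1 to opt in"
        )


check_tilelang_allowed(training=False)  # inference never needs the opt-in

os.environ.pop("ARB_TILELANG_TRAINING", None)
try:
    check_tilelang_allowed(training=True)
except RuntimeError as err:
    print("blocked:", err)

os.environ["ARB_TILELANG_TRAINING"] = "1"
check_tilelang_allowed(training=True)  # explicit opt-in: no error
```

Because the sketch checks for the exact string "1", the opt-in stays explicit. Values such as "true" or "yes" remain blocked under this rule.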
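
One way the new flags might map onto component construction. ComponentFlags and its field names, text_components, and pretrain_components are hypothetical. Only the CLI flag names come from this note, and the weight defaults are placeholders.

```python
import argparse
from dataclasses import dataclass


@dataclass
class ComponentFlags:
    """Hypothetical stand-in for the optional construction flags in arbitor/main.py."""
    vq: bool = False
    graph: bool = False
    moe: bool = False
    attention: bool = False
    output_router: bool = False
    video_head: bool = False
    talker_head: bool = False


def text_components(full_text_stack: bool) -> ComponentFlags:
    """training/text.py: lite byte-text stack unless --full-text-stack is passed."""
    # --full-text-stack re-enables VQ + KG/MoEGraph + KV attention only;
    # the heads and the output router stay off either way.
    return ComponentFlags(
        vq=full_text_stack,
        graph=full_text_stack,
        moe=full_text_stack,
        attention=full_text_stack,
    )


def parse_pretrain_args(argv=None) -> argparse.Namespace:
    """Subset of the training/pretrain.py flags named in this refactor."""
    p = argparse.ArgumentParser()
    p.add_argument("--backend", default="triton")
    p.add_argument("--no-vq", action="store_true")
    p.add_argument("--no-graph", action="store_true")
    p.add_argument("--no-moe", action="store_true")
    p.add_argument("--no-attention", action="store_true")
    p.add_argument("--enable-output-router", action="store_true")
    # Placeholder defaults; the real script's defaults are not stated here.
    p.add_argument("--video-weight", type=float, default=0.0)
    p.add_argument("--audio-weight", type=float, default=0.0)
    return p.parse_args(argv)


def pretrain_components(args: argparse.Namespace) -> ComponentFlags:
    return ComponentFlags(
        vq=not args.no_vq,
        graph=not args.no_graph,
        moe=not args.no_moe,
        attention=not args.no_attention,
        output_router=args.enable_output_router,
        # Heads are built only when their modality contributes to the loss.
        video_head=args.video_weight > 0,
        talker_head=args.audio_weight > 0,
    )


# Component flags from the pretrain-smoke-lite command under Validation.
args = parse_pretrain_args([
    "--no-vq", "--no-graph", "--no-moe", "--no-attention",
    "--video-weight", "0", "--audio-weight", "0",
])
print(args.backend, pretrain_components(args))
```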

Validation

python -m compileall -q arbitor training testing
python -m pytest -q testing/attention/test_mla.py testing/test_tilelang_training.py
python -m pytest -q testing/attention/test_ring_buffer.py testing/attention/test_kq_cache.py testing/kg/test_kv_integration.py
python -m pytest -q testing/test_polarity_validation.py testing/test_tscale.py -k "small_ternary_training_loss_finite or no_float or cuda_triton_tscale_path"
python training/text.py --steps 3 --batch 1 --ctx 8 --eval-interval 0 --log-interval 1 --backend triton --run text-smoke
python training/pretrain.py --steps 2 --batch 1 --ctx 8 --text-data training/data/tinyshakespeare.txt --text-weight 1.0 --code-weight 0 --image-weight 0 --audio-weight 0 --video-weight 0 --eval-interval 0 --log-interval 1 --save-interval 0 --no-save --backend triton --no-vq --no-graph --no-moe --no-attention --run pretrain-smoke-lite
python training/text.py --steps 1 --batch 1 --ctx 4 --eval-interval 0 --log-interval 1 --backend triton --full-text-stack --run text-full-stack-smoke

Results:

attention/tilelang tests: 14 passed
KV/ring integration tests: 20 passed
focused ternary tests: 3 passed
text lite smoke losses: 6.272, 19.635, 21.286
pretrain lite smoke losses: 5.5805, 20.0597
full text stack smoke loss: 6.794

The loss values come from short smoke runs and are not convergence claims. The important result is that the loss path is visible, finite, and no longer blocked by unrelated modality components.
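
The pass criterion is finiteness, not a trend. Below is a minimal check over the reported values; all_finite is a hypothetical helper, not part of the training scripts.

```python
import math

# Smoke-run losses copied from the Results block.
smoke_losses = {
    "text lite": [6.272, 19.635, 21.286],
    "pretrain lite": [5.5805, 20.0597],
    "full text stack": [6.794],
}


def all_finite(losses):
    """A smoke run passes this check if every logged loss is finite (no NaN or inf)."""
    return all(math.isfinite(x) for x in losses)


for run, losses in smoke_losses.items():
    print(f"{run}: {len(losses)} step(s), finite={all_finite(losses)}")
```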