TRUE-TERNARY-REFACTOR13

Date: 2026-05-19

Scope

TileLang production-readiness pass:

  • keep TileLang as a fast ephemeral compute backend,
  • preserve packed ternary persistent state,
  • avoid dense float weight-gradient scratch buffers,
  • keep Triton as the stable fallback,
  • make backend selection explicit for cloud/debug runs.

Backend Selection

TernaryScaleTensor now reads the ARB_TERNARY_BACKEND environment variable:

  • auto (default): prefer TileLang when installed, otherwise Triton, otherwise the PyTorch fallback.
  • tilelang: force TileLang and raise if unavailable or failing.
  • triton: force Triton.
  • torch: force the PyTorch fallback.
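The resolution order above could be expressed roughly as follows. This is a minimal sketch: `select_backend` is a hypothetical helper, not the arbitor API, and only the ARB_TERNARY_BACKEND variable and the four mode names come from these notes. Installation is probed with `importlib.util.find_spec`, which locates a module without importing it; a TileLang kernel that is installed but fails at compile time is not modeled here.

```python
import importlib.util
import os
from typing import Mapping, Optional

VALID_MODES = ("auto", "tilelang", "triton", "torch")


def _installed(module: str) -> bool:
    # find_spec returns None when the module cannot be located.
    return importlib.util.find_spec(module) is not None


def select_backend(env: Optional[Mapping[str, str]] = None,
                   available: Optional[Mapping[str, bool]] = None) -> str:
    """Hypothetical resolver for ARB_TERNARY_BACKEND."""
    env = os.environ if env is None else env
    if available is None:
        available = {"tilelang": _installed("tilelang"),
                     "triton": _installed("triton")}
    mode = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if mode not in VALID_MODES:
        raise ValueError(f"unknown ARB_TERNARY_BACKEND={mode!r}")
    if mode == "auto":
        # Preference order: TileLang, then Triton, then PyTorch.
        for candidate in ("tilelang", "triton"):
            if available.get(candidate, False):
                return candidate
        return "torch"
    if mode == "tilelang" and not available.get("tilelang", False):
        # A forced TileLang run fails loudly instead of silently degrading.
        raise RuntimeError("ARB_TERNARY_BACKEND=tilelang but TileLang is unavailable")
    return mode
```

For example, `select_backend({"ARB_TERNARY_BACKEND": "auto"}, {"tilelang": False, "triton": True})` resolves to `"triton"`, which matches the local environment described under Verification.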

This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment.

TileLang Integration Changes

  • TileLang forward still consumes packed T_packed (uint8) and log-scale E (int8).
  • TileLang output is cast back to the input activation dtype before returning.
  • TileLang bias handling is now restored.
  • TileLang backward no longer allocates dense grad_W (K x N float32) just to produce signs.
  • TileLang backward now stores the direct-update hooks:
    • _hook_grad_2d
    • _hook_x_2d
  • The existing direct-update kernels then consume these hooks to update:
    • T_packed
    • T_accum
    • E
    • E_accum
  • If Triton is unavailable, the fallback direct-update path computes sign(grad_y.T @ x) as a short-lived temporary and then discards the hooks.
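The fallback path can be pictured with the NumPy stand-in below. It assumes `_hook_grad_2d` holds grad_y with shape (M, N) and `_hook_x_2d` holds x with shape (M, K), so grad_y.T @ x is an (N, K) float temporary that lives only long enough to take its sign. The dict-based state and the `consume_sign_hooks` name are hypothetical, and the actual updates to T_packed, T_accum, E, and E_accum are not reproduced.

```python
import numpy as np


def consume_sign_hooks(state: dict) -> np.ndarray:
    """Hypothetical: turn the stored hooks into an int8 sign matrix, then drop them."""
    grad_y = state.pop("_hook_grad_2d")  # (M, N) output gradient
    x = state.pop("_hook_x_2d")          # (M, K) layer input
    # The float product is a temporary; only its sign survives.
    return np.sign(grad_y.T @ x).astype(np.int8)  # (N, K), values in {-1, 0, 1}


rng = np.random.default_rng(0)
state = {"_hook_grad_2d": rng.standard_normal((8, 3)).astype(np.float32),
         "_hook_x_2d": rng.standard_normal((8, 5)).astype(np.float32)}
sign = consume_sign_hooks(state)
```

Popping the hooks inside the helper means no float activation or gradient copy outlives the update, which is the property the smoke checks as "direct hooks consumed after update".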

Persistent model state remains ternary/integer:

  • T_packed: torch.uint8
  • E: torch.int8
  • E_accum: torch.int8
  • T_accum: torch.int8
  • bias: torch.int32
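A check in the spirit of the smoke test's "zero float buffers" assertion could compare each persistent buffer against the dtypes listed above. This sketch uses NumPy dtypes as stand-ins for the torch ones, and `EXPECTED_STATE_DTYPES` and `check_persistent_state` are hypothetical names.

```python
import numpy as np

# NumPy stand-ins for the torch dtypes listed above.
EXPECTED_STATE_DTYPES = {
    "T_packed": np.uint8,
    "E": np.int8,
    "E_accum": np.int8,
    "T_accum": np.int8,
    "bias": np.int32,
}


def check_persistent_state(buffers: dict) -> list:
    """Return a list of problems; an empty list means the state is integer-only."""
    problems = []
    for name, arr in buffers.items():
        if np.issubdtype(arr.dtype, np.floating):
            problems.append(f"{name}: float buffer ({arr.dtype})")
        expected = EXPECTED_STATE_DTYPES.get(name)
        if expected is not None and arr.dtype != np.dtype(expected):
            problems.append(f"{name}: expected {np.dtype(expected)}, got {arr.dtype}")
    return problems
```

Returning a list of problems rather than raising on the first one makes it easy to report every offending buffer in a single smoke run.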

The activation output is still a regular floating-point tensor, because PyTorch autograd requires a floating-point activation dtype, but no TileLang float output is kept as model state.

Verification

TileLang is not installed in this local environment:

ModuleNotFoundError("No module named 'tilelang'")

Verification therefore covered the production fallback path and the backend controls:

  • python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training
  • python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"
    • 3 passed, 24 deselected
  • ARB_TERNARY_BACKEND=triton linear forward/backward/update smoke:
    • output dtype: torch.float32
    • T_packed: torch.uint8
    • E/E_accum/T_accum: torch.int8
    • direct hooks consumed after update.
  • ARB_TERNARY_BACKEND=torch linear forward/backward/update smoke:
    • persistent buffers stayed integer/ternary.
  • ARB_TERNARY_BACKEND=tilelang correctly raises when TileLang is unavailable.
  • python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward
    • cold compile run: forward 34.320s, backward/update 50.808s
    • cached run: forward 0.560s, backward/update 1.379s
    • CUDA peak: 1652.45 MB
    • zero trainable float params and zero float buffers.

Operational Notes

Use this to force TileLang on a machine where it is installed:

ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward

Use this for production runs with automatic, stable fallback:

ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1

Use this to isolate Triton regressions:

ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton
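To compare backends on an identical command, a small driver could rerun the same smoke invocation while changing only ARB_TERNARY_BACKEND. This is a sketch: `build_runs` and `run_all` are hypothetical helpers, and the argument list is copied from the TileLang command above. Actually executing the runs requires arbitor and a CUDA device.

```python
import os
import subprocess
import sys

SMOKE_ARGS = ["-m", "arbitor.smoke", "--device", "cuda", "--ctx", "4",
              "--batch", "1", "--max-moe-iters", "1", "--backward"]


def build_runs(backends=("auto", "tilelang", "triton")):
    """Return (backend, env, argv) triples; only ARB_TERNARY_BACKEND differs."""
    runs = []
    for backend in backends:
        env = dict(os.environ, ARB_TERNARY_BACKEND=backend)
        runs.append((backend, env, [sys.executable, *SMOKE_ARGS]))
    return runs


def run_all(runs):
    """Run each command and collect its exit code keyed by backend."""
    results = {}
    for backend, env, argv in runs:
        proc = subprocess.run(argv, env=env, capture_output=True, text=True)
        results[backend] = proc.returncode
    return results
```

Keeping the command identical across runs isolates the backend as the only variable, so timing or stability differences in the captured output can be attributed to it.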

Remaining Work

  • Run the same smoke on the machine where TileLang is actually installed and compare auto vs tilelang vs triton.
  • If TileLang is consistently faster for the production shapes, add a prewarm_tilelang.py helper that walks the known M,N,K,group_size shapes before training.
  • The next speed target remains fused sparse MoE dispatch for large-token batches.
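One possible shape for the proposed prewarm_tilelang.py helper: deduplicate the (M, N, K, group_size) shapes and trigger one compile per distinct shape before training, so the cold-compile cost seen in the smoke (34.320s forward versus 0.560s cached) is paid up front. The `compile_fn` callable is a hypothetical hook; these notes do not name a TileLang compile entry point, so none is assumed here.

```python
from typing import Callable, Iterable, List, Tuple

Shape = Tuple[int, int, int, int]  # (M, N, K, group_size)


def prewarm(shapes: Iterable[Shape], compile_fn: Callable[[Shape], None]) -> List[Shape]:
    """Compile each distinct shape once, in first-seen order; return what was compiled."""
    seen = set()
    compiled = []
    for shape in shapes:
        if shape in seen:
            continue
        seen.add(shape)
        compile_fn(shape)  # hypothetical hook into the TileLang kernel cache
        compiled.append(shape)
    return compiled
```

Usage with a recording stub in place of the real compiler:

```python
calls = []
done = prewarm([(4, 256, 256, 64), (4, 256, 256, 64), (128, 512, 256, 64)], calls.append)
```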

TileLang NaN Hotfix

The first TileLang integration still produced fp16 activation output and materialized the dequantized scale in fp16. Both can overflow for valid int8 log-scale values and poison training with NaN losses.
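The overflow is easy to reproduce in isolation, assuming the int8 log-scale E dequantizes as 2**E. Base 2 is an assumption here, although it is consistent with the fp16-safe [-14, 15] exponent clamp described below. fp16 tops out at 65504, so 2**15 is representable but 2**16 and above become inf, while int8 allows E up to 127. Multiplying an inf scale by a zero is one simple way such an inf turns into a NaN.

```python
import numpy as np

# Every value here is a legal int8 log-scale.
E = np.array([0, 15, 16, 40, 127], dtype=np.int8)
scale32 = np.exp2(E.astype(np.float32))  # float32 still represents 2**127
with np.errstate(over="ignore", invalid="ignore"):
    scale16 = scale32.astype(np.float16)  # fp16 max is 65504, so 2**16 and up become inf
    poisoned = scale16 * np.float16(0)    # inf * 0 (e.g. a zero ternary weight) gives NaN
```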

Fixes applied:

  • TileLang forward output tensor is now float32, matching the stable Triton activation path.
  • TileLang grad-x output tensor is now float32.
  • The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range [-14, 15] before casting into the fp16 GEMM tile. Persistent E remains int8 and is not clamped in storage.
  • ARB_TILELANG_CHECK_FINITE defaults to 1. If TileLang produces non-finite activations in auto mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs.
  • ARB_TERNARY_BACKEND=tilelang still raises hard on non-finite TileLang output, so that a broken kernel is not silently hidden during debugging.
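The two guards above can be sketched in NumPy, again assuming scale = 2**E. `dequant_scale_fp16` clamps only the temporary fp16 operand and leaves the stored int8 E untouched, and `check_tilelang_output` mirrors the described auto-versus-forced behavior. Both names are hypothetical.

```python
import os
import numpy as np

FP16_EXP_MIN, FP16_EXP_MAX = -14, 15  # fp16-safe exponent range from the notes


def dequant_scale_fp16(E: np.ndarray) -> np.ndarray:
    """Build the fp16 scale operand from int8 E without modifying E itself."""
    clamped = np.clip(E.astype(np.int32), FP16_EXP_MIN, FP16_EXP_MAX)
    return np.exp2(clamped.astype(np.float32)).astype(np.float16)


def check_tilelang_output(out: np.ndarray, mode: str, env=None) -> bool:
    """Return True to keep TileLang, False to fall back; raise in forced mode."""
    env = os.environ if env is None else env
    if env.get("ARB_TILELANG_CHECK_FINITE", "1") != "1":
        return True  # check disabled explicitly
    if np.isfinite(out).all():
        return True
    if mode == "tilelang":
        raise FloatingPointError("non-finite TileLang output")
    return False  # auto mode: caller disables TileLang and uses Triton/PyTorch
```

Clamping a copy rather than the stored tensor keeps the persistent state exact; only the fp16 compute operand loses range.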

Additional verification in this environment, where TileLang itself is not installed:

  • python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training
  • python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"
    • 3 passed, 24 deselected
  • Minimal CUDA training smoke:
    • finite loss: 10.875505447387695
    • no leftover ternary update hooks after _ternary_update_memory()

TileLang Training Gate

Follow-up debugging found that the remaining loss spikes/NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float.

Small reproducer:

  • Persistent state stayed T_packed uint8 and E int8.
  • A PyTorch emulation of the TileLang fp16 path over the packed ternary state reproduced the issue:
    • very negative E values are floored by the clamped fp16 dequant path, making tiny weights much larger than their true ternary scale,
    • E >= 15 can produce huge logits through fp16 tile operands,
    • those logits can push training loss up or non-finite even though the stored model remains ternary.
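Both failure modes can be quantified with plain arithmetic, assuming scale = 2**E and the [-14, 15] operand clamp. A weight whose true scale is 2**-20 is inflated 64x after clamping, and a four-term fp16 accumulation of 2**15-scaled operands overflows even though every operand is finite. Whether the real TileLang GEMM accumulates in fp16 or fp32 is not stated in these notes, so the second half only illustrates how little headroom fp16 leaves at E = 15.

```python
import numpy as np

# Failure mode 1: very negative E is floored by the clamp.
true_scale = 2.0 ** -20
clamped_scale = 2.0 ** max(-20, -14)
inflation = clamped_scale / true_scale  # 2**6, i.e. 64x too large

# Failure mode 2: E = 15 operands overflow an fp16 accumulation.
w = np.full(4, 2.0 ** 15, dtype=np.float16)  # four finite fp16 operands
x = np.ones(4, dtype=np.float16)
with np.errstate(over="ignore"):
    acc16 = np.float16(0)
    for wi, xi in zip(w, x):
        acc16 = np.float16(acc16 + wi * xi)  # running sum kept in fp16
acc32 = float(np.dot(w.astype(np.float32), x.astype(np.float32)))  # same sum in fp32
```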

Production fix:

  • TileLang is no longer used for grad-enabled training by default.
  • ARB_TERNARY_BACKEND=auto uses TileLang only outside training and falls back to Triton for training.
  • ARB_TERNARY_BACKEND=tilelang now raises during training unless ARB_TILELANG_TRAINING=1 is set.
  • ARB_TILELANG_TRAINING=1 exists only for isolated debugging on the TileLang machine.
  • _ternary_update_memory() now refuses to update ternary state after a NaN/Inf loss.
  • arbitor.train and training/pretrain.py abort before backward if loss is non-finite.
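The gating rules above could be expressed roughly as follows. `resolve_training_backend` and `guard_loss` are hypothetical names; the environment variables are the ones these notes describe. The sketch assumes ARB_TILELANG_TRAINING=1 affects only the forced tilelang mode, so auto never picks TileLang while training. The loss guard runs before backward, so no ternary update is attempted on a NaN/Inf loss.

```python
import math
import os


def resolve_training_backend(mode: str, training: bool, tilelang_ok: bool,
                             triton_ok: bool, env=None) -> str:
    """Pick a backend, refusing TileLang for grad-enabled training by default."""
    env = os.environ if env is None else env
    allow_tilelang_training = env.get("ARB_TILELANG_TRAINING", "0") == "1"
    if mode == "tilelang":
        if not tilelang_ok:
            raise RuntimeError("TileLang requested but unavailable")
        if training and not allow_tilelang_training:
            raise RuntimeError("TileLang training requires ARB_TILELANG_TRAINING=1")
        return "tilelang"
    if mode == "auto":
        if tilelang_ok and not training:
            return "tilelang"  # inference/prewarm only
        return "triton" if triton_ok else "torch"
    return mode  # "triton" and "torch" are honored as-is


def guard_loss(loss: float) -> float:
    """Abort before backward/update when the loss is NaN or Inf."""
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss {loss!r}; skipping backward and ternary update")
    return loss
```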

Additional verification:

  • python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"
    • 5 passed, 24 deselected
  • python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward
    • finite loss: 16.085211
    • backward/update: 0.147s
    • zero trainable float params and zero float buffers

Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not a training backend.