TRUE-TERNARY-REFACTOR13
Date: 2026-05-19
Scope
TileLang production-readiness pass:
- keep TileLang as a fast ephemeral compute backend,
- preserve packed ternary persistent state,
- avoid dense float weight-gradient scratch buffers,
- keep Triton as the stable fallback,
- make backend selection explicit for cloud/debug runs.
Backend Selection
TernaryScaleTensor now reads ARB_TERNARY_BACKEND:
- auto (default): prefer TileLang when installed, otherwise use Triton, otherwise the PyTorch fallback.
- tilelang: force TileLang and raise if unavailable or failing.
- triton: force Triton.
- torch: force the PyTorch fallback.
This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment.
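The selection rules above can be sketched as a small resolver. This is an illustration only, not the code in TernaryScaleTensor: the function name resolve_backend, its availability flags, and the ValueError on unknown values are assumptions made for the example.

```python
import os

VALID_BACKENDS = ("auto", "tilelang", "triton", "torch")


def resolve_backend(have_tilelang, have_triton, env=None):
    """Pick a backend name from ARB_TERNARY_BACKEND (hypothetical helper).

    have_tilelang / have_triton are availability flags supplied by the caller;
    how the real module probes for them is not shown here.
    """
    env = os.environ if env is None else env
    requested = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if requested not in VALID_BACKENDS:
        # Assumption: unknown values are rejected rather than ignored.
        raise ValueError(f"unknown ARB_TERNARY_BACKEND: {requested!r}")
    if requested == "auto":
        # Preference order from the notes: TileLang, then Triton, then PyTorch.
        if have_tilelang:
            return "tilelang"
        if have_triton:
            return "triton"
        return "torch"
    if requested == "tilelang" and not have_tilelang:
        # Forced TileLang must fail loudly instead of falling back.
        raise RuntimeError("ARB_TERNARY_BACKEND=tilelang but TileLang is unavailable")
    return requested
```

Passing the environment as a plain dict keeps the sketch easy to exercise without touching the real process environment.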
TileLang Integration Changes
- TileLang forward still consumes packed T_packed (uint8) and log-scale E (int8).
- TileLang output is cast back to the input activation dtype before returning.
- TileLang bias handling is now restored.
- TileLang backward no longer allocates dense grad_W (K x N float32) just to produce signs.
- TileLang backward now stores direct hooks:
  - _hook_grad_2d
  - _hook_x_2d
- Existing direct update kernels then update:
  - T_packed
  - T_accum
  - E
  - E_accum
- If Triton is unavailable, the fallback direct-update path computes the sign from grad_y.T @ x ephemerally and then discards the hooks.
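The idea of taking only the sign of grad_y.T @ x without keeping a dense float gradient can be sketched in plain Python. This is a toy illustration under assumptions: grad_y is taken as M x N and x as M x K (nested lists), the helper name is hypothetical, and the real fallback path operates on tensors rather than lists.

```python
def ephemeral_grad_sign_rows(grad_y, x):
    """Yield sign(grad_y.T @ x) one row at a time (hypothetical helper).

    grad_y: M x N nested list, x: M x K nested list. Only a single length-K
    float row exists at any moment; the full float product is never stored.
    """
    m = len(grad_y)
    n = len(grad_y[0])
    k = len(x[0])
    for j in range(n):
        row = [0.0] * k  # ephemeral float scratch for one output row
        for i in range(m):
            g = grad_y[i][j]
            for c in range(k):
                row[c] += g * x[i][c]
        # Reduce to ternary signs immediately; the float row is then dropped.
        yield [(v > 0) - (v < 0) for v in row]


if __name__ == "__main__":
    gy = [[1.0, -2.0], [0.5, 0.0]]
    xs = [[1.0, 0.0, -1.0], [2.0, 0.0, 1.0]]
    for signs in ephemeral_grad_sign_rows(gy, xs):
        print(signs)
```

A consumer would feed each sign row straight into the accumulator update and never retain the floats, which is the property the direct hooks are meant to preserve.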
Persistent model state remains ternary/integer:
- T_packed: torch.uint8
- E: torch.int8
- E_accum: torch.int8
- T_accum: torch.int8
- bias: torch.int32
The activation output is still a normal tensor because PyTorch autograd needs an activation dtype, but no TileLang float output is stored as model state.
Verification
TileLang is not installed in this local environment:
ModuleNotFoundError("No module named 'tilelang'")
So verification covered the production fallback and backend controls:
- python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training
- python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"
  - 3 passed, 24 deselected
- ARB_TERNARY_BACKEND=triton linear forward/backward/update smoke:
  - output dtype: torch.float32
  - T_packed: torch.uint8
  - E/E_accum/T_accum: torch.int8
  - direct hooks consumed after update.
- ARB_TERNARY_BACKEND=torch linear forward/backward/update smoke:
  - persistent buffers stayed integer/ternary.
- ARB_TERNARY_BACKEND=tilelang correctly raises when TileLang is unavailable.
- python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward
  - cold compile run: forward 34.320s, backward/update 50.808s
  - cached run: forward 0.560s, backward/update 1.379s
  - CUDA peak: 1652.45 MB
  - zero trainable float params and zero float buffers.
Operational Notes
Use this to force TileLang on a machine where it is installed:
ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward
Use this for stable production fallback:
ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1
Use this to isolate Triton regressions:
ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton
Remaining Work
- Run the same smoke on the machine where TileLang is actually installed and compare auto vs tilelang vs triton.
- If TileLang is consistently faster for the production shapes, add a prewarm_tilelang.py helper that walks the known M, N, K, group_size shapes before training.
- The next speed target remains fused sparse MoE dispatch for large-token batches.
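One possible shape for the proposed prewarm helper is sketched below. Nothing here exists yet: the function name prewarm, the compile_fn callback, and the timing dictionary are all hypothetical, and the real TileLang compile entry point is deliberately left behind the callback rather than guessed at.

```python
import time


def prewarm(shapes, compile_fn):
    """Walk known (M, N, K, group_size) shapes and trigger compilation once each.

    compile_fn is a hypothetical callback that builds (or loads from cache)
    the kernel for one shape. Returns per-shape wall-clock seconds.
    """
    timings = {}
    for shape in shapes:
        if shape in timings:
            continue  # duplicate shapes share one compiled kernel
        start = time.perf_counter()
        compile_fn(*shape)
        timings[shape] = time.perf_counter() - start
    return timings


if __name__ == "__main__":
    # Placeholder shapes and a no-op compile function, for illustration only.
    demo_shapes = [(1, 256, 256, 32), (1, 256, 256, 32), (4, 512, 256, 32)]
    print(prewarm(demo_shapes, lambda m, n, k, g: None))
```

Comparing the returned timings between a cold and a cached run would show whether the compile cost has actually been moved out of the first training step.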
TileLang NaN Hotfix
The first TileLang integration still allowed fp16 activation output and fp16 dequantized scale materialization. That can overflow for valid int8 log-scale values and poison training with NaN losses.
Fixes applied:
- TileLang forward output tensor is now float32, matching the stable Triton activation path.
- TileLang grad-x output tensor is now float32.
- The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range [-14, 15] before casting into the fp16 GEMM tile. Persistent E remains int8 and is not clamped in storage.
- ARB_TILELANG_CHECK_FINITE=1 is enabled by default. If TileLang produces non-finite activations in auto mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs.
- ARB_TERNARY_BACKEND=tilelang still raises hard on non-finite TileLang output so debugging does not silently hide a broken kernel.
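The overflow that motivated the exponent clamp can be reproduced with the standard library alone. The sketch assumes the dequantized scale is 2**E (the notes only say E is a log-scale, so base 2 is an assumption, though it matches the fp16 normal exponent range [-14, 15]); the helper names are hypothetical.

```python
import struct

FP16_EXP_MIN, FP16_EXP_MAX = -14, 15  # clamp range quoted in the fixes above


def fp16_round_trip(value):
    """Round a Python float through IEEE half precision; overflow becomes inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return float("inf")


def dequant_scale_fp16(e, clamp=True):
    """Assumed dequant rule: scale = 2**E, cast to fp16.

    With clamp=True the exponent is limited to the fp16-safe range first,
    mirroring the hotfix; the stored int8 E itself is left untouched.
    """
    if clamp:
        e = max(FP16_EXP_MIN, min(FP16_EXP_MAX, e))
    return fp16_round_trip(2.0 ** e)


if __name__ == "__main__":
    for e in (-128, -14, 0, 15, 16, 127):
        print(e, dequant_scale_fp16(e, clamp=False), dequant_scale_fp16(e, clamp=True))
```

Without the clamp, any E of 16 or more becomes inf in fp16 and very negative values flush to zero; with the clamp every int8 exponent maps to a finite, nonzero fp16 scale.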
Additional verification in this environment, where TileLang itself is not installed:
- python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training
- python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"
  - 3 passed, 24 deselected
- Minimal CUDA training smoke:
  - finite loss: 10.875505447387695
  - no leftover ternary update hooks after _ternary_update_memory()
TileLang Training Gate
Follow-up debug found that the remaining loss spikes/NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float.
Small reproducer:
- Persistent state stayed T_packed uint8 and E int8.
- A PyTorch TileLang-like fp16 path over packed ternary state showed the issue:
  - very negative E values are numerically floored by the fp16/clamped dequant path, making tiny weights much larger than the true ternary scale,
  - E >= 15 can produce huge logits through fp16 tile operands,
  - those logits can push training loss up or non-finite even though the stored model remains ternary.
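The two numeric effects in the reproducer can be quantified with a short calculation. As a stated assumption, the scale is modelled as 2**E with the exponent clamped to [-14, 15]; the function names and the sample activation value are illustrative only.

```python
FP16_MAX = 65504.0  # largest finite IEEE half-precision value


def clamped_exponent(e, lo=-14, hi=15):
    """Exponent actually used by the clamped fp16 dequant path."""
    return max(lo, min(hi, e))


def inflation_factor(e):
    """How much larger the clamped scale is than the true scale 2**e.

    Values above 1.0 mean the weight is inflated; below 1.0, shrunk.
    """
    return 2.0 ** (clamped_exponent(e) - e)


def exceeds_fp16_max(scale_exponent, activation):
    """True if a single scale * activation product is beyond finite fp16."""
    return abs((2.0 ** scale_exponent) * activation) > FP16_MAX


if __name__ == "__main__":
    print(inflation_factor(-40))       # weight inflation for a very negative E
    print(exceeds_fp16_max(15, 4.0))   # one product already past fp16 range
```

Under these assumptions a weight stored with E = -40 is inflated by a factor of 2**26 after clamping, and a scale of 2**15 needs only a modest activation to leave the finite fp16 range, which is consistent with the loss spikes described above.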
Production fix:
- TileLang is no longer used for grad-enabled training by default.
- ARB_TERNARY_BACKEND=auto uses TileLang only outside training, then falls back to Triton for training.
- ARB_TERNARY_BACKEND=tilelang now raises during training unless explicitly enabled.
- ARB_TILELANG_TRAINING=1 exists only for isolated debugging on the TileLang machine.
- _ternary_update_memory() now refuses to update ternary state after a NaN/Inf loss.
- arbitor.train and training/pretrain.py abort before backward if loss is non-finite.
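The gate and the loss guard can be sketched as two small functions. This is a hedged illustration, not the project's code: the names backend_for_step and require_finite_loss are hypothetical, Triton is assumed to be available, and the sketch assumes ARB_TILELANG_TRAINING=1 only unlocks the forced tilelang mode (the notes do not say whether it also affects auto).

```python
import math


def backend_for_step(requested, grad_enabled, have_tilelang, tilelang_training=False):
    """Choose the backend for one forward call (hypothetical helper).

    requested: value of ARB_TERNARY_BACKEND.
    grad_enabled: True when the call is part of grad-enabled training.
    tilelang_training: stands in for ARB_TILELANG_TRAINING=1.
    """
    if requested == "tilelang":
        if grad_enabled and not tilelang_training:
            raise RuntimeError("TileLang training is disabled; set ARB_TILELANG_TRAINING=1 to debug")
        return "tilelang"
    if requested == "auto":
        # Assumption: auto never trains on TileLang, even with the debug flag.
        if have_tilelang and not grad_enabled:
            return "tilelang"
        return "triton"
    return requested


def require_finite_loss(loss_value):
    """Abort before backward/update when the loss is NaN or Inf."""
    if not math.isfinite(loss_value):
        raise FloatingPointError(f"non-finite loss: {loss_value!r}")
    return loss_value
```

In a training loop the guard would run on the scalar loss before backward, so a bad step never reaches the ternary state update.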
Additional verification:
- python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"
  - 5 passed, 24 deselected
- python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward
  - finite loss: 16.085211
  - backward/update: 0.147s
  - zero trainable float params and zero float buffers
Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not a training backend.