# TRUE-TERNARY-REFACTOR13

Date: 2026-05-19

## Scope

TileLang production-readiness pass:

- keep TileLang as a fast ephemeral compute backend,
- preserve packed ternary persistent state,
- avoid dense float weight-gradient scratch buffers,
- keep Triton as the stable fallback,
- make backend selection explicit for cloud/debug runs.

## Backend Selection

`TernaryScaleTensor` now reads `ARB_TERNARY_BACKEND`:

- `auto` (default): prefer TileLang when installed, otherwise use Triton, otherwise the PyTorch fallback.
- `tilelang`: force TileLang and raise if it is unavailable or failing.
- `triton`: force Triton.
- `torch`: force the PyTorch fallback.

This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment.

## TileLang Integration Changes

- TileLang forward still consumes packed `T_packed` (`uint8`) and log-scale `E` (`int8`).
- TileLang output is cast back to the input activation dtype before returning.
- TileLang bias handling is now restored.
- TileLang backward no longer allocates a dense `grad_W` (`K x N float32`) just to produce signs.
- TileLang backward now stores direct hooks:
  - `_hook_grad_2d`
  - `_hook_x_2d`
- Existing direct-update kernels then update:
  - `T_packed`
  - `T_accum`
  - `E`
  - `E_accum`
- If Triton is unavailable, the fallback direct-update path computes the sign of `grad_y.T @ x` ephemerally and then discards the hooks.

Persistent model state remains ternary/integer:

- `T_packed`: `torch.uint8`
- `E`: `torch.int8`
- `E_accum`: `torch.int8`
- `T_accum`: `torch.int8`
- `bias`: `torch.int32`

The activation output is still a normal floating-point tensor because PyTorch autograd needs an activation dtype, but no TileLang float output is stored as model state.
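
The `ARB_TERNARY_BACKEND` rules under Backend Selection can be sketched as a small resolver. This is a minimal, hypothetical illustration, not the actual `TernaryScaleTensor` code: `select_backend`, `module_installed`, and the injectable `installed` predicate are names invented here, and availability is approximated with `importlib.util.find_spec` rather than by compiling or launching a kernel. The sketch returns `triton` and `torch` as-is when forced, since the text does not say how a missing Triton is reported.

```python
import importlib.util
import os
from typing import Callable, Mapping, Optional

# Hypothetical sketch of the documented ARB_TERNARY_BACKEND policy.
BACKENDS = ("auto", "tilelang", "triton", "torch")


def module_installed(name: str) -> bool:
    """Return True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


def select_backend(
    env: Optional[Mapping[str, str]] = None,
    installed: Callable[[str], bool] = module_installed,
) -> str:
    env = os.environ if env is None else env
    mode = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if mode not in BACKENDS:
        raise ValueError(f"ARB_TERNARY_BACKEND={mode!r}; expected one of {BACKENDS}")
    if mode == "auto":
        # Preference order: TileLang, then Triton, then the PyTorch fallback.
        for name in ("tilelang", "triton"):
            if installed(name):
                return name
        return "torch"
    if mode == "tilelang" and not installed("tilelang"):
        # A forced TileLang backend fails loudly instead of silently falling back.
        raise RuntimeError("ARB_TERNARY_BACKEND=tilelang but TileLang is not installed")
    # Forced `triton` / `torch` are returned unchanged (availability not checked here).
    return mode
```

Passing a fake predicate makes the policy testable without a GPU: `select_backend({"ARB_TERNARY_BACKEND": "auto"}, installed=lambda n: n == "triton")` returns `"triton"`, which corresponds to this local environment where TileLang is absent.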
## Verification

TileLang is not installed in this local environment:

```text
ModuleNotFoundError("No module named 'tilelang'")
```

So verification covered the production fallback and backend controls:

- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- `ARB_TERNARY_BACKEND=triton` linear forward/backward/update smoke:
  - output dtype: `torch.float32`
  - `T_packed`: `torch.uint8`
  - `E/E_accum/T_accum`: `torch.int8`
  - direct hooks consumed after update.
- `ARB_TERNARY_BACKEND=torch` linear forward/backward/update smoke:
  - persistent buffers stayed integer/ternary.
- `ARB_TERNARY_BACKEND=tilelang` correctly raises when TileLang is unavailable.
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward`
  - cold compile run: forward `34.320s`, backward/update `50.808s`
  - cached run: forward `0.560s`, backward/update `1.379s`
  - CUDA peak: `1652.45 MB`
  - zero trainable float params and zero float buffers.

## Operational Notes

Use this to force TileLang on a machine where it is installed:

```bash
ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward
```

Use this for stable production fallback:

```bash
ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1
```

Use this to isolate Triton regressions:

```bash
ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton
```

## Remaining Work

- Run the same smoke on the machine where TileLang is actually installed and compare `auto` vs `tilelang` vs `triton`.
- If TileLang is consistently faster for the production shapes, add a `prewarm_tilelang.py` helper that walks the known `M,N,K,group_size` shapes before training.
- The next speed target remains fused sparse MoE dispatch for large-token batches.

## TileLang NaN Hotfix

The first TileLang integration still allowed fp16 activation output and fp16 materialization of the dequantized scale. This can overflow for valid int8 log-scale values and poison training with NaN losses.

Fixes applied:

- The TileLang forward output tensor is now `float32`, matching the stable Triton activation path.
- The TileLang grad-x output tensor is now `float32`.
- The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range `[-14, 15]` before casting into the fp16 GEMM tile. Persistent `E` remains int8 and is not clamped in storage.
- `ARB_TILELANG_CHECK_FINITE` defaults to `1`. If TileLang produces non-finite activations in `auto` mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs.
- `ARB_TERNARY_BACKEND=tilelang` still raises hard on non-finite TileLang output, so debugging does not silently hide a broken kernel.

Additional verification in this environment, where TileLang itself is not installed:

- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- Minimal CUDA training smoke:
  - finite loss: `10.875505447387695`
  - no leftover ternary update hooks after `_ternary_update_memory()`

## TileLang Training Gate

Follow-up debugging found that the remaining loss spikes/NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float.

Small reproducer:

- Persistent state stayed `T_packed uint8` and `E int8`.
- A PyTorch TileLang-like fp16 path over packed ternary state showed the issue:
  - very negative `E` values are numerically floored by the fp16/clamped dequant path, making tiny weights much larger than the true ternary scale,
  - `E >= 15` can produce huge logits through fp16 tile operands,
  - those logits can spike the training loss or make it non-finite, even though the stored model remains ternary.

Production fix:

- TileLang is no longer used for grad-enabled training by default.
- `ARB_TERNARY_BACKEND=auto` uses TileLang only outside training and falls back to Triton for training.
- `ARB_TERNARY_BACKEND=tilelang` now raises during training unless TileLang training is explicitly enabled.
- `ARB_TILELANG_TRAINING=1` exists only for isolated debugging on the TileLang machine.
- `_ternary_update_memory()` now refuses to update ternary state after a NaN/Inf loss.
- `arbitor.train` and `training/pretrain.py` abort before backward if the loss is non-finite.

Additional verification:

- `python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `5 passed, 24 deselected`
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward`
  - finite loss: `16.085211`
  - backward/update: `0.147s`
  - zero trainable float params and zero float buffers

Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not as a training backend.
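
The first two reproducer bullets can be checked with plain arithmetic. This sketch assumes, which the text does not state, that the int8 log-scale dequantizes as `2**E` and that a ternary weight contributes `+2**E`, `0`, or `-2**E`; it uses Python's `struct` half-precision (`"e"`) format to round values to fp16. `to_fp16`, `true_scale`, `clamped_fp16_scale`, and `logit_fp16` are hypothetical helpers, not functions from `arbitor`.

```python
import struct

FP16_MAX = 65504.0      # largest finite IEEE 754 half-precision value
E_MIN, E_MAX = -14, 15  # fp16-safe exponent range used by the hotfix clamp


def to_fp16(x: float) -> float:
    """Round x to IEEE half precision; struct raises OverflowError past the fp16 range."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def true_scale(e: int) -> float:
    # Assumption: the stored int8 log-scale E means a weight magnitude of 2**E.
    return 2.0 ** e


def clamped_fp16_scale(e: int) -> float:
    # Clamp E into [-14, 15] before materializing the scale as an fp16 operand.
    return to_fp16(2.0 ** min(max(e, E_MIN), E_MAX))


def logit_fp16(e: int, k: int, activation: float = 1.0) -> float:
    """Sum of k weights equal to +clamped scale, times `activation`, stored in fp16."""
    exact = k * activation * clamped_fp16_scale(e)
    try:
        return to_fp16(exact)
    except OverflowError:
        # IEEE round-to-nearest conversion yields inf here; struct raises instead.
        return float("inf")


for e in (-20, -14, 0, 15):
    ratio = clamped_fp16_scale(e) / true_scale(e)
    print(f"E={e:4d} true={true_scale(e):.3e} fp16-clamped={clamped_fp16_scale(e):.3e} ratio={ratio:g}")

print(logit_fp16(15, k=1), logit_fp16(15, k=4))
```

With `E = -20`, clamping to `-14` makes the effective scale exactly 64x the true scale, which is the flooring described above. With `E = 15`, four `+1` weights on unit activations already sum to `131072`, which exceeds the fp16 maximum of `65504` when stored in fp16 (as the first integration's fp16 activation output was) and is still a huge logit in float32.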
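
The production fix amounts to a training-aware backend gate plus a finite-loss check in two places. The following is a minimal, hypothetical sketch of that policy: `resolve_backend`, `check_loss_before_backward`, and `ternary_update_if_finite` are illustrative names, losses are plain Python floats rather than tensors (real code would test a tensor, for example with `torch.isfinite`), and the choice that `auto` never trains on TileLang even when `ARB_TILELANG_TRAINING=1` is set is an assumption.

```python
import math
import os
from typing import Callable, Mapping, Optional


def resolve_backend(
    mode: str,
    training: bool,
    tilelang_installed: bool,
    triton_installed: bool,
    env: Optional[Mapping[str, str]] = None,
) -> str:
    """Hypothetical training-aware version of the ARB_TERNARY_BACKEND policy."""
    env = os.environ if env is None else env
    debug_training = env.get("ARB_TILELANG_TRAINING") == "1"
    if mode == "tilelang":
        if not tilelang_installed:
            raise RuntimeError("ARB_TERNARY_BACKEND=tilelang but TileLang is not installed")
        if training and not debug_training:
            raise RuntimeError(
                "TileLang training is disabled; set ARB_TILELANG_TRAINING=1 only for debugging"
            )
        return "tilelang"
    if mode == "auto":
        # Assumption: auto never trains on TileLang, even with ARB_TILELANG_TRAINING=1.
        if tilelang_installed and not training:
            return "tilelang"
        return "triton" if triton_installed else "torch"
    return mode  # forced "triton" or "torch"


def check_loss_before_backward(loss: float, step: int) -> None:
    """Abort the run before backward when the loss is NaN or Inf."""
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss {loss!r} at step {step}; aborting before backward")


def ternary_update_if_finite(loss: float, apply_update: Callable[[], None]) -> bool:
    """Run the ternary state update only for a finite loss; return whether it ran."""
    if not math.isfinite(loss):
        return False  # leave T_packed / T_accum / E / E_accum untouched
    apply_update()
    return True
```

Raising before backward and refusing the update are complementary: the first stops the run loudly, while the second guarantees that `T_packed`, `T_accum`, `E`, and `E_accum` are never modified by an update derived from a NaN/Inf loss.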