| # TRUE-TERNARY-REFACTOR13 |
|
|
| Date: 2026-05-19 |
|
|
| ## Scope |
|
|
| TileLang production-readiness pass: |
|
|
| - keep TileLang as a fast ephemeral compute backend, |
| - preserve packed ternary persistent state, |
| - avoid dense float weight-gradient scratch buffers, |
| - keep Triton as the stable fallback, |
| - make backend selection explicit for cloud/debug runs. |
|
|
| ## Backend Selection |
|
|
| `TernaryScaleTensor` now reads `ARB_TERNARY_BACKEND`: |
|
|
| - `auto` (default): prefer TileLang when installed, then Triton, then the PyTorch fallback. |
| - `tilelang`: force TileLang and raise if unavailable or failing. |
| - `triton`: force Triton. |
| - `torch`: force the PyTorch fallback. |
|
|
| This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment. |
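The selection rules above can be sketched as a small resolver. This is a minimal illustration, not the code in `TernaryScaleTensor`: the function name `resolve_backend`, its availability flags, and the `env` parameter are hypothetical, and only the `ARB_TERNARY_BACKEND` variable and its four values come from these notes.

```python
import os

VALID_BACKENDS = ("auto", "tilelang", "triton", "torch")


def resolve_backend(tilelang_ok: bool, triton_ok: bool, env=None) -> str:
    """Map ARB_TERNARY_BACKEND to a concrete backend name (hypothetical helper)."""
    env = os.environ if env is None else env
    choice = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if choice not in VALID_BACKENDS:
        raise ValueError(f"unknown ARB_TERNARY_BACKEND={choice!r}")
    if choice == "auto":
        # Preference order from the notes: TileLang, then Triton, then PyTorch.
        if tilelang_ok:
            return "tilelang"
        return "triton" if triton_ok else "torch"
    if choice == "tilelang" and not tilelang_ok:
        # Forced TileLang must fail loudly instead of silently degrading.
        raise RuntimeError("TileLang was forced but is not available")
    return choice
```

Passing a plain dict as `env` keeps the resolver testable without touching the process environment.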
|
|
| ## TileLang Integration Changes |
|
|
| - TileLang forward still consumes packed `T_packed` (`uint8`) and log-scale `E` (`int8`). |
| - TileLang output is cast back to the input activation dtype before returning. |
| - TileLang bias handling is now restored. |
| - TileLang backward no longer allocates dense `grad_W` (`K x N float32`) just to produce signs. |
| - TileLang backward now stores direct hooks: |
|   - `_hook_grad_2d` |
|   - `_hook_x_2d` |
| - Existing direct update kernels then update: |
|   - `T_packed` |
|   - `T_accum` |
|   - `E` |
|   - `E_accum` |
| - If Triton is unavailable, the fallback direct-update path computes the sign from `grad_y.T @ x` ephemerally and then discards the hooks. |
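The idea of producing signs without a dense float `grad_W` can be illustrated in plain Python. This is only a sketch of the principle, assuming `grad_y` has shape `M x N` and `x` has shape `M x K` as nested lists. The helper names are hypothetical, and the real path runs as GPU kernels rather than Python loops.

```python
def sign(v: float) -> int:
    """Return -1, 0, or +1 for a float."""
    return (v > 0) - (v < 0)


def grad_sign_rows(grad_y, x):
    """Yield sign(grad_y.T @ x) one row at a time.

    grad_y: M rows of N floats, x: M rows of K floats.
    Only one float accumulator is alive at any time, so the N x K float
    product is never materialized; just integer signs leave the loop.
    """
    m, n, k = len(grad_y), len(grad_y[0]), len(x[0])
    for j in range(n):
        row = []
        for i in range(k):
            acc = 0.0
            for b in range(m):
                acc += grad_y[b][j] * x[b][i]
            row.append(sign(acc))  # the float accumulator is dropped here
        yield row
```

A consumer can apply each yielded row to the integer accumulators and then drop it, which mirrors discarding the hooks after the update.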
|
|
| Persistent model state remains ternary/integer: |
|
|
| - `T_packed`: `torch.uint8` |
| - `E`: `torch.int8` |
| - `E_accum`: `torch.int8` |
| - `T_accum`: `torch.int8` |
| - `bias`: `torch.int32` |
|
|
| The activation output is still a floating-point tensor because PyTorch autograd only propagates gradients through floating-point dtypes, but no TileLang float output is stored as model state. |
|
|
| ## Verification |
|
|
| TileLang is not installed in this local environment: |
|
|
| ```text |
| ModuleNotFoundError("No module named 'tilelang'") |
| ``` |
|
|
| So verification covered the production fallback and backend controls: |
|
|
| - `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training` |
| - `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"` |
|   - `3 passed, 24 deselected` |
| - `ARB_TERNARY_BACKEND=triton` linear forward/backward/update smoke: |
|   - output dtype: `torch.float32` |
|   - `T_packed`: `torch.uint8` |
|   - `E`, `E_accum`, `T_accum`: `torch.int8` |
|   - direct hooks consumed after update. |
| - `ARB_TERNARY_BACKEND=torch` linear forward/backward/update smoke: |
|   - persistent buffers stayed integer/ternary. |
| - `ARB_TERNARY_BACKEND=tilelang` correctly raises when TileLang is unavailable. |
| - `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward` |
|   - cold compile run: forward `34.320s`, backward/update `50.808s` |
|   - cached run: forward `0.560s`, backward/update `1.379s` |
|   - CUDA peak: `1652.45 MB` |
|   - zero trainable float params and zero float buffers. |
|
|
| ## Operational Notes |
|
|
| Use this to force TileLang on a machine where it is installed: |
|
|
| ```bash |
| ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward |
| ``` |
|
|
| Use this for stable production fallback: |
|
|
| ```bash |
| ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1 |
| ``` |
|
|
| Use this to isolate Triton regressions: |
|
|
| ```bash |
| ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton |
| ``` |
|
|
| ## Remaining Work |
|
|
| - Run the same smoke on the machine where TileLang is actually installed and compare `auto` vs `tilelang` vs `triton`. |
| - If TileLang is consistently faster for the production shapes, add a `prewarm_tilelang.py` helper that walks the known `M,N,K,group_size` shapes before training. |
| - The next speed target remains fused sparse MoE dispatch for large-token batches. |
|
|
| ## TileLang NaN Hotfix |
|
|
| The first TileLang integration still allowed fp16 activation output and fp16 dequantized scale materialization. That can overflow for valid int8 log-scale values and poison training with NaN losses. |
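The overflow can be reproduced with the standard library alone. The sketch below assumes the scale is dequantized as `2**E` (a base-2 log scale, which these notes imply but do not state outright) and uses the `struct` half-precision format `e` to emulate an fp16 cast. The helper names are hypothetical.

```python
import struct


def to_fp16(value: float) -> float:
    """Round-trip a float through IEEE half precision.

    struct raises OverflowError for values beyond the fp16 range; a hardware
    cast would produce infinity instead, so that is emulated here.
    """
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return float("inf") if value > 0 else float("-inf")


def dequant_fp16(trit: int, e: int) -> float:
    """Dequantize one ternary weight as trit * 2**e with the scale held in fp16."""
    return trit * to_fp16(2.0 ** e)
```

Under this assumption `E = 15` still fits in fp16, `E = 16` becomes infinity, and a zero ternary weight multiplied by that infinite scale yields NaN, which is enough to poison a loss.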
|
|
| Fixes applied: |
|
|
| - TileLang forward output tensor is now `float32`, matching the stable Triton activation path. |
| - TileLang grad-x output tensor is now `float32`. |
| - The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range `[-14, 15]` before casting into the fp16 GEMM tile. Persistent `E` remains int8 and is not clamped in storage. |
| - `ARB_TILELANG_CHECK_FINITE=1` is enabled by default. If TileLang produces non-finite activations in `auto` mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs. |
| - `ARB_TERNARY_BACKEND=tilelang` still raises hard on non-finite TileLang output so debugging does not silently hide a broken kernel. |
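The finite-check policy in the last two bullets can be sketched as follows. `BackendState` and `check_tilelang_output` are hypothetical names used only for illustration; the environment variable and the auto-versus-forced behavior are taken from the list above.

```python
import math
import os


class BackendState:
    """Tiny stand-in for per-module backend bookkeeping (hypothetical)."""

    def __init__(self, mode: str):
        self.mode = mode              # "auto" or "tilelang"
        self.tilelang_enabled = True


def check_tilelang_output(values, state: BackendState, env=None) -> bool:
    """Return True if the TileLang output is usable.

    Returns False when the caller should recompute on the fallback backend.
    """
    env = os.environ if env is None else env
    if env.get("ARB_TILELANG_CHECK_FINITE", "1") != "1":
        return True                   # check explicitly disabled
    if all(math.isfinite(v) for v in values):
        return True
    if state.mode == "tilelang":
        # Forced mode: fail loudly so a broken kernel is not hidden.
        raise FloatingPointError("TileLang produced non-finite activations")
    state.tilelang_enabled = False    # auto mode: stop using TileLang
    return False
```

Disabling TileLang on the state object, rather than per call, matches the described behavior of the module switching to Triton/PyTorch instead of retrying the failing kernel.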
|
|
| Additional verification in this environment, where TileLang itself is not installed: |
|
|
| - `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training` |
| - `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"` |
|   - `3 passed, 24 deselected` |
| - Minimal CUDA training smoke: |
|   - finite loss: `10.875505447387695` |
|   - no leftover ternary update hooks after `_ternary_update_memory()` |
|
|
| ## TileLang Training Gate |
|
|
| Follow-up debug found that the remaining loss spikes/NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float. |
|
|
| Small reproducer: |
|
|
| - Persistent state stayed `T_packed uint8` and `E int8`. |
| - A PyTorch TileLang-like fp16 path over packed ternary state showed the issue: |
|   - very negative `E` values are numerically floored by the fp16/clamped dequant path, making tiny weights much larger than the true ternary scale, |
|   - `E >= 15` can produce huge logits through fp16 tile operands, |
|   - those logits can push the training loss up or make it non-finite even though the stored model remains ternary. |
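The size of the distortion is easy to quantify. The sketch below assumes the dequantized scale is `2**E` and that the compute path clamps the exponent to `[-14, 15]` as described in the hotfix section. The function names are hypothetical.

```python
FP16_MIN_EXP, FP16_MAX_EXP = -14, 15   # clamp range from the hotfix notes
FP16_MAX = 65504.0                     # largest finite fp16 value


def clamped_scale(e: int) -> float:
    """Scale actually seen by the fp16 GEMM tile after exponent clamping."""
    return 2.0 ** max(FP16_MIN_EXP, min(FP16_MAX_EXP, e))


def inflation(e: int) -> float:
    """Ratio of the clamped scale to the true scale 2**e."""
    return clamped_scale(e) / 2.0 ** e


def overflows_in_fp16(e: int, activation_sum: float) -> bool:
    """True if scale * activation_sum exceeds the finite fp16 range."""
    return abs(clamped_scale(e) * activation_sum) > FP16_MAX
```

With these definitions a stored `E = -20` is computed with a scale 64 times larger than intended, and at `E = 15` an accumulated activation of just `2.0` already exceeds the finite fp16 range.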
|
|
| Production fix: |
|
|
| - TileLang is no longer used for grad-enabled training by default. |
| - `ARB_TERNARY_BACKEND=auto` uses TileLang only outside training and falls back to Triton for training. |
| - `ARB_TERNARY_BACKEND=tilelang` now raises during training unless explicitly enabled. |
| - `ARB_TILELANG_TRAINING=1` exists only for isolated debugging on the TileLang machine. |
| - `_ternary_update_memory()` now refuses to update ternary state after a NaN/Inf loss. |
| - `arbitor.train` and `training/pretrain.py` abort before backward if loss is non-finite. |
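A minimal sketch of the gate and the loss guard described above. `tilelang_allowed` and `guard_loss` are hypothetical helpers. The sketch assumes TileLang is installed and that `ARB_TILELANG_TRAINING=1` lifts the gate for both `auto` and `tilelang`, which these notes do not spell out.

```python
import math
import os


def tilelang_allowed(backend: str, grad_enabled: bool, env=None) -> bool:
    """Decide whether TileLang may run this call (hypothetical helper)."""
    env = os.environ if env is None else env
    if backend not in ("auto", "tilelang"):
        return False                      # triton/torch never use TileLang
    debug_opt_in = env.get("ARB_TILELANG_TRAINING", "0") == "1"
    if not grad_enabled or debug_opt_in:
        return True                       # inference, or explicit debug opt-in
    if backend == "tilelang":
        # Forced TileLang during training must fail instead of degrading.
        raise RuntimeError("TileLang training requires ARB_TILELANG_TRAINING=1")
    return False                          # auto: fall back to Triton for training


def guard_loss(loss: float) -> float:
    """Abort before backward/update when the loss is NaN or Inf."""
    if not math.isfinite(loss):
        raise FloatingPointError(f"non-finite loss: {loss!r}")
    return loss
```

Calling `guard_loss` before backward and again inside the update step gives two independent chances to stop before integer state is modified.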
|
|
| Additional verification: |
|
|
| - `python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"` |
|   - `5 passed, 24 deselected` |
| - `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward` |
|   - finite loss: `16.085211` |
|   - backward/update: `0.147s` |
|   - zero trainable float params and zero float buffers |
|
|
| Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not a training backend. |
|
|