# TRUE-TERNARY-REFACTOR13
Date: 2026-05-19
## Scope
TileLang production-readiness pass:
- keep TileLang as a fast ephemeral compute backend,
- preserve packed ternary persistent state,
- avoid dense float weight-gradient scratch buffers,
- keep Triton as the stable fallback,
- make backend selection explicit for cloud/debug runs.
## Backend Selection
`TernaryScaleTensor` now reads `ARB_TERNARY_BACKEND`:
- `auto` (default): prefer TileLang when installed, then Triton, then the PyTorch fallback.
- `tilelang`: force TileLang and raise if unavailable or failing.
- `triton`: force Triton.
- `torch`: force the PyTorch fallback.
This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment.
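
The selection rules in this section can be sketched as a small resolver. This is a minimal illustration, not the code in `TernaryScaleTensor`: the function name `resolve_backend` and the availability flags `tilelang_ok` / `triton_ok` are hypothetical, and raising when a forced `triton` backend is missing is an assumption (the notes only state that `tilelang` raises).

```python
import os

VALID_BACKENDS = ("auto", "tilelang", "triton", "torch")


def resolve_backend(tilelang_ok: bool, triton_ok: bool, env=None) -> str:
    """Pick a backend from ARB_TERNARY_BACKEND (hypothetical helper).

    tilelang_ok / triton_ok are assumed availability flags; the real module
    presumably derives them from import attempts.
    """
    env = os.environ if env is None else env
    requested = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if requested not in VALID_BACKENDS:
        raise ValueError(f"unknown ARB_TERNARY_BACKEND={requested!r}")

    if requested == "auto":
        # Preference order from the list above: TileLang, Triton, PyTorch.
        if tilelang_ok:
            return "tilelang"
        if triton_ok:
            return "triton"
        return "torch"

    if requested == "tilelang" and not tilelang_ok:
        raise RuntimeError("ARB_TERNARY_BACKEND=tilelang but TileLang is unavailable")
    if requested == "triton" and not triton_ok:
        # Assumption: a forced backend that is missing is an error.
        raise RuntimeError("ARB_TERNARY_BACKEND=triton but Triton is unavailable")
    return requested
```

Passing `env` explicitly keeps the resolver testable without mutating the process environment.
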
## TileLang Integration Changes
- TileLang forward still consumes packed `T_packed` (`uint8`) and log-scale `E` (`int8`).
- TileLang output is cast back to the input activation dtype before returning.
- TileLang bias handling is now restored.
- TileLang backward no longer allocates dense `grad_W` (`K x N float32`) just to produce signs.
- TileLang backward now stores direct hooks:
  - `_hook_grad_2d`
  - `_hook_x_2d`
- Existing direct update kernels then update:
  - `T_packed`
  - `T_accum`
  - `E`
  - `E_accum`
- If Triton is unavailable, the fallback direct-update path computes the sign from `grad_y.T @ x` ephemerally and then discards the hooks.
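
To illustrate why the sign of `grad_y.T @ x` does not require a persistent dense float gradient, here is a pure-Python sketch that produces the sign matrix one row at a time. It is only an illustration of the idea: the function names are hypothetical, plain nested lists stand in for tensors, and the `N x K` result layout simply follows the expression `grad_y.T @ x` rather than any layout confirmed by the kernels.

```python
def sign(v: float) -> int:
    """Return -1, 0, or +1 for a scalar."""
    return (v > 0) - (v < 0)


def grad_sign_rows(grad_y, x):
    """Yield sign(grad_y.T @ x) row by row (hypothetical helper).

    grad_y is an M x N nested list and x is an M x K nested list.
    Only one float row of length K is alive at any time; each yielded
    row holds K integers in {-1, 0, +1}.
    """
    m = len(grad_y)
    n = len(grad_y[0])
    k = len(x[0])
    for j in range(n):
        row = [0.0] * k  # ephemeral float scratch for a single output row
        for i in range(m):
            g = grad_y[i][j]
            if g == 0.0:
                continue  # zero gradient entries contribute nothing
            xi = x[i]
            for c in range(k):
                row[c] += g * xi[c]
        yield [sign(v) for v in row]


grad_y = [[1.0, -2.0], [0.5, 0.0]]            # M=2, N=2
x = [[1.0, -1.0, 0.0], [2.0, 2.0, 0.0]]       # M=2, K=3
signs = list(grad_sign_rows(grad_y, x))
```

A real kernel would tile this computation rather than loop over scalars, but the memory argument is the same: the float products are consumed as signs and then dropped.
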
Persistent model state remains ternary/integer:
- `T_packed`: `torch.uint8`
- `E`: `torch.int8`
- `E_accum`: `torch.int8`
- `T_accum`: `torch.int8`
- `bias`: `torch.int32`
The activation output is still a normal tensor because PyTorch autograd needs an activation dtype, but no TileLang float output is stored as model state.
## Verification
TileLang is not installed in this local environment:
```text
ModuleNotFoundError("No module named 'tilelang'")
```
So verification covered the production fallback and backend controls:
- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- `ARB_TERNARY_BACKEND=triton` linear forward/backward/update smoke:
  - output dtype: `torch.float32`
  - `T_packed`: `torch.uint8`
  - `E/E_accum/T_accum`: `torch.int8`
  - direct hooks consumed after update.
- `ARB_TERNARY_BACKEND=torch` linear forward/backward/update smoke:
  - persistent buffers stayed integer/ternary.
- `ARB_TERNARY_BACKEND=tilelang` correctly raises when TileLang is unavailable.
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward`
  - cold compile run: forward `34.320s`, backward/update `50.808s`
  - cached run: forward `0.560s`, backward/update `1.379s`
  - CUDA peak: `1652.45 MB`
  - zero trainable float params and zero float buffers.
## Operational Notes
Use this to force TileLang on a machine where it is installed:
```bash
ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward
```
Use this for stable production fallback:
```bash
ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1
```
Use this to isolate Triton regressions:
```bash
ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton
```
## Remaining Work
- Run the same smoke on the machine where TileLang is actually installed and compare `auto` vs `tilelang` vs `triton`.
- If TileLang is consistently faster for the production shapes, add a `prewarm_tilelang.py` helper that walks the known `M,N,K,group_size` shapes before training.
- The next speed target remains fused sparse MoE dispatch for large-token batches.
## TileLang NaN Hotfix
The first TileLang integration still allowed fp16 activation output and fp16 dequantized scale materialization. That can overflow for valid int8 log-scale values and poison training with NaN losses.
Fixes applied:
- TileLang forward output tensor is now `float32`, matching the stable Triton activation path.
- TileLang grad-x output tensor is now `float32`.
- The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range `[-14, 15]` before casting into the fp16 GEMM tile. Persistent `E` remains int8 and is not clamped in storage.
- `ARB_TILELANG_CHECK_FINITE=1` is enabled by default. If TileLang produces non-finite activations in `auto` mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs.
- `ARB_TERNARY_BACKEND=tilelang` still raises hard on non-finite TileLang output so debugging does not silently hide a broken kernel.
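
The clamp and the finite check can be sketched with the standard library alone. This is a hedged illustration rather than the kernel code: it assumes the dequantized scale is `2**E`, uses `struct`'s half-precision format to emulate an fp16 cast, and the names `clamp_exponent`, `to_fp16`, and `backend_after_check` are hypothetical.

```python
import math
import struct

FP16_MIN_EXP, FP16_MAX_EXP = -14, 15  # fp16-safe exponent range quoted above


def clamp_exponent(e: int) -> int:
    """Clamp an int8 log-scale exponent for the fp16 operand only."""
    return max(FP16_MIN_EXP, min(FP16_MAX_EXP, e))


def to_fp16(value: float) -> float:
    """Round a Python float to IEEE half precision; overflow becomes +/-inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return math.copysign(math.inf, value)


def backend_after_check(outputs, mode: str) -> str:
    """Hypothetical finite check: fall back in auto mode, raise in tilelang mode."""
    if all(math.isfinite(v) for v in outputs):
        return "tilelang"
    if mode == "tilelang":
        raise FloatingPointError("TileLang produced non-finite activations")
    return "triton"  # or the PyTorch fallback if Triton is missing


stored_e = 127  # a valid int8 value; storage is never clamped
unclamped = to_fp16(2.0 ** stored_e)                # too large for fp16
clamped = to_fp16(2.0 ** clamp_exponent(stored_e))  # 2**15 fits in fp16
```

Note that the clamp changes the numeric value of the operand whenever `E` lies outside `[-14, 15]`, so it prevents overflow at the cost of accuracy for extreme exponents.
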
Additional verification in this environment, where TileLang itself is not installed:
- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- Minimal CUDA training smoke:
  - finite loss: `10.875505447387695`
  - no leftover ternary update hooks after `_ternary_update_memory()`
## TileLang Training Gate
Follow-up debugging found that the remaining loss spikes and NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float.
Small reproducer:
- Persistent state stayed `T_packed uint8` and `E int8`.
- A PyTorch TileLang-like fp16 path over packed ternary state showed the issue:
  - very negative `E` values are numerically floored by the fp16/clamped dequant path, making tiny weights much larger than the true ternary scale,
  - `E >= 15` can produce huge logits through fp16 tile operands,
  - those logits can push the training loss up or make it non-finite even though the stored model remains ternary.
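
The size of both effects can be estimated with simple arithmetic. The sketch below assumes the effective weight is `t * 2**E` with `t` in `{-1, 0, +1}` and that the fp16 path clamps the exponent to `[-14, 15]`; the helper names and the fan-in of 64 are illustrative choices, not values taken from the reproducer.

```python
FP16_MAX = 65504.0  # largest finite IEEE half-precision value


def clamp_exponent(e: int) -> int:
    """Exponent as seen by the fp16 dequant path (assumed clamp range)."""
    return max(-14, min(15, e))


def scale_ratio(e: int) -> float:
    """Ratio of the fp16-path scale to the true scale 2**e."""
    return 2.0 ** (clamp_exponent(e) - e)


def logit_bound(e: int, fan_in: int, activation: float = 1.0) -> float:
    """Magnitude of a dot product when all fan_in weights share exponent e."""
    return fan_in * activation * 2.0 ** clamp_exponent(e)


tiny_weight_inflation = scale_ratio(-40)   # weight meant to be 2**-40
large_logit = logit_bound(15, fan_in=64)   # 64 aligned weights at 2**15
exceeds_fp16 = large_logit > FP16_MAX
```

Under these assumptions a weight stored with `E = -40` is treated as `2**26` times larger than intended, and 64 aligned weights at `E = 15` with unit activations already exceed the fp16 maximum.
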
Production fix:
- TileLang is no longer used for grad-enabled training by default.
- `ARB_TERNARY_BACKEND=auto` uses TileLang only outside training, then falls back to Triton for training.
- `ARB_TERNARY_BACKEND=tilelang` now raises during training unless explicitly enabled.
- `ARB_TILELANG_TRAINING=1` exists only for isolated debugging on the TileLang machine.
- `_ternary_update_memory()` now refuses to update ternary state after a NaN/Inf loss.
- `arbitor.train` and `training/pretrain.py` abort before backward if loss is non-finite.
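
The gate and the loss guard described above can be sketched as two small functions. The names `backend_for_call` and `should_apply_update` are hypothetical, and the notes do not say whether `ARB_TILELANG_TRAINING=1` also affects `auto` mode, so this sketch applies the opt-in only to the explicit `tilelang` mode.

```python
import math
import os


def backend_for_call(requested: str, grad_enabled: bool, tilelang_ok: bool, env=None) -> str:
    """Choose the backend for one call under the training gate (hypothetical)."""
    env = os.environ if env is None else env
    training_opt_in = env.get("ARB_TILELANG_TRAINING", "0") == "1"

    if requested == "tilelang":
        if grad_enabled and not training_opt_in:
            raise RuntimeError(
                "TileLang is disabled for training; set ARB_TILELANG_TRAINING=1 to debug"
            )
        return "tilelang"

    if requested == "auto":
        # TileLang only outside training; Triton for grad-enabled calls.
        if tilelang_ok and not grad_enabled:
            return "tilelang"
        return "triton"

    return requested  # "triton" or "torch" pass through unchanged


def should_apply_update(loss: float) -> bool:
    """Mirror of the guard: skip the ternary update after a NaN/Inf loss."""
    return math.isfinite(loss)
```

A training loop would call `should_apply_update(loss)` before backward and abort the step when it returns `False`, which matches the abort behavior described for `arbitor.train` and `training/pretrain.py`.
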
Additional verification:
- `python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `5 passed, 24 deselected`
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward`
  - finite loss: `16.085211`
  - backward/update: `0.147s`
  - zero trainable float params and zero float buffers
Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not a training backend.