# TRUE-TERNARY-REFACTOR13
Date: 2026-05-19
## Scope
TileLang production-readiness pass:
- keep TileLang as a fast ephemeral compute backend,
- preserve packed ternary persistent state,
- avoid dense float weight-gradient scratch buffers,
- keep Triton as the stable fallback,
- make backend selection explicit for cloud/debug runs.
## Backend Selection
`TernaryScaleTensor` now reads `ARB_TERNARY_BACKEND`:
- `auto` (default): prefer TileLang when installed, then Triton, then the PyTorch fallback. For grad-enabled training, `auto` now skips TileLang (see "TileLang Training Gate").
- `tilelang`: force TileLang and raise if unavailable or failing.
- `triton`: force Triton.
- `torch`: force the PyTorch fallback.
This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment.
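The selection rules listed above can be sketched as a small resolver. This is a minimal illustration, not the code in `TernaryScaleTensor`: the function name `resolve_backend` and the two availability flags are hypothetical, and the source does not say what `triton` does when Triton is missing, so the sketch simply returns the forced name in that case.

```python
import os

VALID_BACKENDS = ("auto", "tilelang", "triton", "torch")


def resolve_backend(env=None, tilelang_available=False, triton_available=False):
    """Map ARB_TERNARY_BACKEND to a concrete backend name (hypothetical helper)."""
    env = os.environ if env is None else env
    mode = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if mode not in VALID_BACKENDS:
        raise ValueError(f"unknown ARB_TERNARY_BACKEND value: {mode!r}")

    if mode == "auto":
        # Preference order from the notes: TileLang, then Triton, then PyTorch.
        if tilelang_available:
            return "tilelang"
        if triton_available:
            return "triton"
        return "torch"

    if mode == "tilelang" and not tilelang_available:
        # Forced TileLang must fail loudly instead of falling back.
        raise RuntimeError("ARB_TERNARY_BACKEND=tilelang but TileLang is not installed")

    # "triton" and "torch" are returned as forced; the behaviour when Triton
    # is missing is not specified in the notes, so it is not modelled here.
    return mode
```

Passing the environment as a plain mapping keeps the rule testable without touching `os.environ`, for example `resolve_backend({"ARB_TERNARY_BACKEND": "triton"})`.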
## TileLang Integration Changes
- TileLang forward still consumes packed `T_packed` (`uint8`) and log-scale `E` (`int8`).
- TileLang output is cast back to the input activation dtype before returning.
- TileLang bias handling is now restored.
- TileLang backward no longer allocates dense `grad_W` (`K x N float32`) just to produce signs.
- TileLang backward now stores direct hooks:
  - `_hook_grad_2d`
  - `_hook_x_2d`
- Existing direct update kernels then update:
  - `T_packed`
  - `T_accum`
  - `E`
  - `E_accum`
- If Triton is unavailable, the fallback direct-update path computes the sign from `grad_y.T @ x` ephemerally and then discards the hooks.
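The idea behind the fallback sign computation can be shown without any dense float gradient buffer. The sketch below is pure Python over nested lists, purely for illustration: the real path runs on GPU tensors, and the function name `sign_of_grad` is hypothetical. It accumulates each entry of `grad_y.T @ x` in a scalar and keeps only its sign, so the only stored result is a small integer per weight.

```python
def sign_of_grad(grad_y, x):
    """Return sign(grad_y.T @ x) as an N x K list of ints in {-1, 0, 1}.

    grad_y has shape M x N and x has shape M x K (nested lists of floats).
    Each dot product lives in one scalar accumulator, so no N x K float
    matrix is ever materialised.
    """
    m = len(grad_y)
    if m == 0 or m != len(x):
        raise ValueError("grad_y and x must share a non-empty leading dimension")
    n, k = len(grad_y[0]), len(x[0])

    signs = []
    for j in range(n):
        row = []
        for c in range(k):
            acc = 0.0
            for i in range(m):
                acc += grad_y[i][j] * x[i][c]
            # Keep only the sign; the float accumulator is discarded.
            row.append((acc > 0) - (acc < 0))
        signs.append(row)
    return signs


grad_y = [[1.0, -2.0], [0.5, 0.0]]            # M=2, N=2
x = [[2.0, 0.0, -1.0], [-4.0, 1.0, 1.0]]      # M=2, K=3
print(sign_of_grad(grad_y, x))                # [[0, 1, -1], [-1, 0, 1]]
```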
Persistent model state remains ternary/integer:
- `T_packed`: `torch.uint8`
- `E`: `torch.int8`
- `E_accum`: `torch.int8`
- `T_accum`: `torch.int8`
- `bias`: `torch.int32`
The activation output is still a regular floating-point tensor because PyTorch autograd only tracks floating-point (or complex) dtypes, but no TileLang float output is stored as model state.
## Verification
TileLang is not installed in this local environment:
```text
ModuleNotFoundError("No module named 'tilelang'")
```
So verification covered the production fallback and backend controls:
- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- `ARB_TERNARY_BACKEND=triton` linear forward/backward/update smoke:
  - output dtype: `torch.float32`
  - `T_packed`: `torch.uint8`
  - `E/E_accum/T_accum`: `torch.int8`
  - direct hooks consumed after update.
- `ARB_TERNARY_BACKEND=torch` linear forward/backward/update smoke:
  - persistent buffers stayed integer/ternary.
- `ARB_TERNARY_BACKEND=tilelang` correctly raises when TileLang is unavailable.
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward`
  - cold compile run: forward `34.320s`, backward/update `50.808s`
  - cached run: forward `0.560s`, backward/update `1.379s`
  - CUDA peak: `1652.45 MB`
  - zero trainable float params and zero float buffers.
## Operational Notes
Use this to force TileLang on a machine where it is installed:
```bash
ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward
```
Use this for stable production fallback:
```bash
ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1
```
Use this to isolate Triton regressions:
```bash
ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton
```
## Remaining Work
- Run the same smoke on the machine where TileLang is actually installed and compare `auto` vs `tilelang` vs `triton`.
- If TileLang is consistently faster for the production shapes, add a `prewarm_tilelang.py` helper that walks the known `M,N,K,group_size` shapes before training.
- The next speed target remains fused sparse MoE dispatch for large-token batches.
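For the prewarm item above, one possible shape for the helper is a loop that triggers compilation once per distinct `(M, N, K, group_size)` tuple. This is only a sketch of the proposed `prewarm_tilelang.py`, which does not exist yet: `prewarm_shapes` and the `compile_fn` callback are hypothetical, and the shapes in the usage lines are placeholders rather than the production shapes.

```python
def prewarm_shapes(shapes, compile_fn):
    """Call compile_fn(M, N, K, group_size) once per distinct shape.

    compile_fn stands in for whatever triggers kernel compilation
    (hypothetical). Returns the shapes that were warmed, in first-seen order.
    """
    seen = set()
    warmed = []
    for shape in shapes:
        key = tuple(shape)
        if len(key) != 4:
            raise ValueError(f"expected (M, N, K, group_size), got {key!r}")
        if key in seen:
            continue  # this shape was already compiled in this run
        seen.add(key)
        compile_fn(*key)
        warmed.append(key)
    return warmed


# Usage with a stand-in compile function that only records its calls.
calls = []
prewarm_shapes([(4, 256, 256, 32), (4, 256, 256, 32), (4, 512, 256, 32)],
               lambda m, n, k, g: calls.append((m, n, k, g)))
```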
## TileLang NaN Hotfix
The first TileLang integration still allowed fp16 activation output and fp16 dequantized scale materialization. That can overflow for valid int8 log-scale values and poison training with NaN losses.
Fixes applied:
- TileLang forward output tensor is now `float32`, matching the stable Triton activation path.
- TileLang grad-x output tensor is now `float32`.
- The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range `[-14, 15]` before casting into the fp16 GEMM tile. Persistent `E` remains int8 and is not clamped in storage.
- `ARB_TILELANG_CHECK_FINITE=1` is enabled by default. If TileLang produces non-finite activations in `auto` mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs.
- `ARB_TERNARY_BACKEND=tilelang` still raises hard on non-finite TileLang output so debugging does not silently hide a broken kernel.
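The two non-finite policies above (silent fallback in `auto`, hard failure in `tilelang`) can be summarised in a few lines. The sketch is a simplified model on plain Python floats: `BackendState` and `accept_tilelang_output` are hypothetical names, and how the module parses `ARB_TILELANG_CHECK_FINITE` is not shown in the notes, so the check is passed in as a boolean.

```python
import math


class BackendState:
    """Minimal stand-in for per-module backend state (hypothetical)."""

    def __init__(self, mode):
        self.mode = mode              # "auto" or "tilelang"
        self.tilelang_enabled = True


def accept_tilelang_output(values, state, check_finite=True):
    """Return True to keep the TileLang result, False to recompute on a fallback."""
    if not check_finite or all(math.isfinite(v) for v in values):
        return True
    if state.mode == "tilelang":
        # Forced TileLang: surface the broken kernel instead of hiding it.
        raise FloatingPointError("TileLang produced non-finite activations")
    # auto mode: stop using TileLang for this module and fall back.
    state.tilelang_enabled = False
    return False


state = BackendState("auto")
print(accept_tilelang_output([0.5, float("nan")], state), state.tilelang_enabled)
# False False
```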
Additional verification in this environment, where TileLang itself is not installed:
- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- Minimal CUDA training smoke:
  - finite loss: `10.875505447387695`
  - no leftover ternary update hooks after `_ternary_update_memory()`
## TileLang Training Gate
Follow-up debug found that the remaining loss spikes/NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float.
Small reproducer:
- Persistent state stayed `T_packed uint8` and `E int8`.
- A PyTorch TileLang-like fp16 path over packed ternary state showed the issue:
  - very negative `E` values are numerically floored by the fp16/clamped dequant path, making tiny weights much larger than the true ternary scale,
  - `E >= 15` can produce huge logits through fp16 tile operands,
  - those logits can push training loss up or non-finite even though the stored model remains ternary.
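The arithmetic behind both failure modes can be reproduced with the standard library alone. The sketch assumes the dequantized scale is `2**E`, which the notes imply by calling `E` a log-scale but do not state outright. `to_fp16` and `dequant_scale` are hypothetical helpers. Half-precision rounding is emulated with `struct`'s `e` format, and overflow is mapped to `inf` to mimic what an fp16 cast does on the GPU.

```python
import struct


def to_fp16(value):
    """Round a Python float to IEEE half precision; overflow becomes inf."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        return float("inf") if value > 0 else float("-inf")


def dequant_scale(e, clamp=None):
    """fp16 value of 2**e, optionally clamping the exponent first."""
    if clamp is not None:
        lo, hi = clamp
        e = max(lo, min(hi, e))
    return to_fp16(2.0 ** e)


FP16_SAFE = (-14, 15)

# Unclamped: valid int8 exponents overflow or vanish in fp16.
print(dequant_scale(16))                       # inf
print(dequant_scale(-30))                      # 0.0

# Clamped: finite, but E = -30 is floored to 2**-14, 65536x too large.
print(dequant_scale(-30, FP16_SAFE) / 2.0 ** -30)   # 65536.0

# Even the largest clamped scale overflows after one multiply by 2.
print(to_fp16(dequant_scale(15, FP16_SAFE) * 2.0))  # inf
```

The clamp therefore trades NaNs for a silent change in effective weight magnitude, which matches the loss spikes described above.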
Production fix:
- TileLang is no longer used for grad-enabled training by default.
- `ARB_TERNARY_BACKEND=auto` uses TileLang only outside training, then falls back to Triton for training.
- `ARB_TERNARY_BACKEND=tilelang` now raises during training unless explicitly enabled.
- `ARB_TILELANG_TRAINING=1` exists only for isolated debugging on the TileLang machine.
- `_ternary_update_memory()` now refuses to update ternary state after a NaN/Inf loss.
- `arbitor.train` and `training/pretrain.py` abort before backward if loss is non-finite.
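The gate and the non-finite loss guard listed above can be modelled as two small functions. This is a sketch under assumptions, not the module's code: `backend_for_call` and `loss_is_usable` are hypothetical names, `allow_tilelang_training` stands for `ARB_TILELANG_TRAINING=1`, whether that flag also re-enables TileLang in `auto` mode is assumed here, and Triton availability is ignored for brevity.

```python
import math


def backend_for_call(mode, grad_enabled, tilelang_available,
                     allow_tilelang_training=False):
    """Pick the backend for one forward call under the training gate."""
    if mode in ("triton", "torch"):
        return mode
    if mode not in ("auto", "tilelang"):
        raise ValueError(f"unknown backend mode: {mode!r}")
    if mode == "tilelang" and not tilelang_available:
        raise RuntimeError("TileLang requested but not installed")

    training_blocked = grad_enabled and not allow_tilelang_training
    if mode == "tilelang":
        if training_blocked:
            raise RuntimeError(
                "TileLang is gated for training; set ARB_TILELANG_TRAINING=1 to debug")
        return "tilelang"

    # auto: TileLang only outside grad-enabled training, otherwise Triton.
    if tilelang_available and not training_blocked:
        return "tilelang"
    return "triton"


def loss_is_usable(loss_value):
    """True when the loss is finite, so backward and ternary updates may run."""
    return math.isfinite(loss_value)
```

In a training loop the guard would run before backward, for example `if not loss_is_usable(loss.item()): raise RuntimeError("non-finite loss")`, so that neither gradients nor ternary state updates are computed from a poisoned step.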
Additional verification:
- `python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `5 passed, 24 deselected`
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward`
  - finite loss: `16.085211`
  - backward/update: `0.147s`
  - zero trainable float params and zero float buffers
Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not a training backend.