# TRUE-TERNARY-REFACTOR13

Date: 2026-05-19

## Scope

TileLang production-readiness pass:

- keep TileLang as a fast ephemeral compute backend,
- preserve packed ternary persistent state,
- avoid dense float weight-gradient scratch buffers,
- keep Triton as the stable fallback,
- make backend selection explicit for cloud/debug runs.

## Backend Selection

`TernaryScaleTensor` now reads `ARB_TERNARY_BACKEND`:

- `auto` (default): prefer TileLang when installed, then Triton, then the PyTorch fallback.
- `tilelang`: force TileLang and raise if unavailable or failing.
- `triton`: force Triton.
- `torch`: force the PyTorch fallback.

This lets cloud runs keep TileLang speed while still providing a stable fallback path during deployment.
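
The selection rules above can be sketched as a small resolver. This is a minimal illustration, not the code in `TernaryScaleTensor`: the function name `resolve_backend` and the `tilelang_ok` / `triton_ok` flags are hypothetical, and the behaviour of a forced `triton` when Triton is missing is not specified in this report, so the sketch simply returns the requested name.

```python
VALID_BACKENDS = ("auto", "tilelang", "triton", "torch")


def resolve_backend(env, tilelang_ok, triton_ok):
    """Pick a backend name from ARB_TERNARY_BACKEND and availability flags.

    env is any mapping of environment variables (os.environ works).
    """
    choice = env.get("ARB_TERNARY_BACKEND", "auto").strip().lower()
    if choice not in VALID_BACKENDS:
        raise ValueError(f"unknown ARB_TERNARY_BACKEND: {choice!r}")
    if choice == "auto":
        # Preference order from the list above: TileLang, Triton, PyTorch.
        if tilelang_ok:
            return "tilelang"
        return "triton" if triton_ok else "torch"
    if choice == "tilelang" and not tilelang_ok:
        # Forced mode raises instead of silently falling back.
        raise RuntimeError("TileLang was requested but is not available")
    # Forced triton/torch: returned as-is (assumption, see lead-in).
    return choice
```

Passing `os.environ` as `env` would mirror how the variable is read at runtime; a plain dict is convenient for tests.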

## TileLang Integration Changes

- TileLang forward still consumes packed `T_packed` (`uint8`) and log-scale `E` (`int8`).
- TileLang output is cast back to the input activation dtype before returning.
- TileLang bias handling is now restored.
- TileLang backward no longer allocates dense `grad_W` (`K x N float32`) just to produce signs.
- TileLang backward now stores direct hooks:
  - `_hook_grad_2d`
  - `_hook_x_2d`
- Existing direct update kernels then update:
  - `T_packed`
  - `T_accum`
  - `E`
  - `E_accum`
- If Triton is unavailable, the fallback direct-update path computes the sign from `grad_y.T @ x` ephemerally and then discards the hooks.
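
One way to picture the sign-only fallback is the plain-Python sketch below. It is an illustration under assumptions, not the production kernel: the real path works on CUDA tensors, the helper names are hypothetical, and the `(N, K)` orientation is taken from the expression `grad_y.T @ x` with `grad_y` of shape `(M, N)` and `x` of shape `(M, K)`. The point it shows is that each dot product can be reduced to its sign as soon as it is accumulated, so no float weight-gradient matrix has to be kept.

```python
def sign(v):
    """Return -1, 0, or +1 for a scalar."""
    return (v > 0) - (v < 0)


def weight_grad_signs(grad_y, x):
    """Signs of grad_y.T @ x for grad_y of shape (M, N) and x of shape (M, K).

    Each dot product lives in a scalar accumulator and is reduced to its
    sign immediately, so only an N x K table of small integers is returned.
    """
    m = len(grad_y)
    if m != len(x):
        raise ValueError("grad_y and x must share the batch dimension M")
    n, k = len(grad_y[0]), len(x[0])
    return [
        [sign(sum(grad_y[i][r] * x[i][c] for i in range(m))) for c in range(k)]
        for r in range(n)
    ]
```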

Persistent model state remains ternary/integer:

- `T_packed`: `torch.uint8`
- `E`: `torch.int8`
- `E_accum`: `torch.int8`
- `T_accum`: `torch.int8`
- `bias`: `torch.int32`

The activation output is still an ordinary floating-point tensor because PyTorch autograd only tracks gradients for floating-point (or complex) tensors, but no TileLang float output is stored as model state.

## Verification

TileLang is not installed in this local environment:

```text
ModuleNotFoundError("No module named 'tilelang'")
```

So verification covered the production fallback and backend controls:

- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- `ARB_TERNARY_BACKEND=triton` linear forward/backward/update smoke:
  - output dtype: `torch.float32`
  - `T_packed`: `torch.uint8`
  - `E/E_accum/T_accum`: `torch.int8`
  - direct hooks consumed after update.
- `ARB_TERNARY_BACKEND=torch` linear forward/backward/update smoke:
  - persistent buffers stayed integer/ternary.
- `ARB_TERNARY_BACKEND=tilelang` correctly raises when TileLang is unavailable.
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --no-vq --no-graph --backward`
  - cold compile run: forward `34.320s`, backward/update `50.808s`
  - cached run: forward `0.560s`, backward/update `1.379s`
  - CUDA peak: `1652.45 MB`
  - zero trainable float params and zero float buffers.

## Operational Notes

Use this to force TileLang on a machine where it is installed:

```bash
ARB_TERNARY_BACKEND=tilelang python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --max-moe-iters 1 --backward
```

Use this for stable production fallback:

```bash
ARB_TERNARY_BACKEND=auto python -m arbitor.train --ctx 128 --batch 1 --accum 4 --max-moe-iters 1
```

Use this to isolate Triton regressions:

```bash
ARB_TERNARY_BACKEND=triton python -m pytest -q testing/test_tscale.py -k cuda_triton
```

## Remaining Work

- Run the same smoke on the machine where TileLang is actually installed and compare `auto` vs `tilelang` vs `triton`.
- If TileLang is consistently faster for the production shapes, add a `prewarm_tilelang.py` helper that walks the known `M,N,K,group_size` shapes before training.
- The next speed target remains fused sparse MoE dispatch for large-token batches.
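
A possible shape for the prewarm helper mentioned above is sketched here. Everything in it is hypothetical: `prewarm` and `compile_fn` are placeholder names, `compile_fn` stands in for whatever call triggers kernel compilation (it is not a TileLang API), and the production shape list is not given in this report.

```python
def prewarm(shapes, compile_fn):
    """Call compile_fn once per unique (M, N, K, group_size) tuple.

    Failures are collected instead of raised so that one bad shape does
    not stop the walk over the remaining shapes.
    """
    done, failed = [], []
    for shape in dict.fromkeys(shapes):  # de-duplicate, keep first-seen order
        try:
            compile_fn(*shape)
            done.append(shape)
        except Exception as exc:  # report and continue
            failed.append((shape, repr(exc)))
    return done, failed
```

Returning the failed shapes rather than raising lets a cloud job log which shapes will hit a cold compile later without blocking the start of training.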

## TileLang NaN Hotfix

The first TileLang integration still allowed fp16 activation output and fp16 dequantized scale materialization. That can overflow for valid int8 log-scale values and poison training with NaN losses.

Fixes applied:

- TileLang forward output tensor is now `float32`, matching the stable Triton activation path.
- TileLang grad-x output tensor is now `float32`.
- The TileLang fp16 dequant operand clamps the exponent to the fp16-safe range `[-14, 15]` before casting into the fp16 GEMM tile. Persistent `E` remains int8 and is not clamped in storage.
- `ARB_TILELANG_CHECK_FINITE=1` is enabled by default. If TileLang produces non-finite activations in `auto` mode, the module disables TileLang and falls back to Triton/PyTorch instead of training on NaNs.
- `ARB_TERNARY_BACKEND=tilelang` still raises hard on non-finite TileLang output so debugging does not silently hide a broken kernel.
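
The finite-check policy in the last two bullets can be summarised in a few lines. This is a sketch under assumptions: `accept_tilelang_output`, the `state` dict and its `tilelang_enabled` key are hypothetical names, and plain Python floats stand in for the activation tensor.

```python
import math


def accept_tilelang_output(values, mode, state):
    """Return True when the TileLang result may be used.

    values: iterable of activation values (plain floats in this sketch)
    mode:   "auto" or "tilelang"
    state:  mutable dict; state["tilelang_enabled"] is cleared on fallback
    """
    if all(math.isfinite(v) for v in values):
        return True
    if mode == "tilelang":
        # Forced mode surfaces the broken kernel instead of hiding it.
        raise FloatingPointError("TileLang produced non-finite activations")
    # auto mode: stop using TileLang so later calls take Triton/PyTorch.
    state["tilelang_enabled"] = False
    return False
```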

Additional verification in this environment, where TileLang itself is not installed:

- `python -m compileall -q arbitor/kernel/ternary_scale.py arbitor training`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `3 passed, 24 deselected`
- Minimal CUDA training smoke:
  - finite loss: `10.875505447387695`
  - no leftover ternary update hooks after `_ternary_update_memory()`

## TileLang Training Gate

Follow-up debugging found that the remaining loss spikes/NaNs are caused by the TileLang fp16 compute path itself, not by persistent ternary state becoming float.

Small reproducer:

- Persistent state stayed `T_packed uint8` and `E int8`.
- A PyTorch TileLang-like fp16 path over packed ternary state showed the issue:
  - very negative `E` values are numerically floored by the fp16/clamped dequant path, making tiny weights much larger than the true ternary scale,
  - `E >= 15` can produce huge logits through fp16 tile operands,
  - those logits can drive the training loss up or make it non-finite even though the stored model remains ternary.
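
The two failure modes can be reproduced without a GPU. The sketch below assumes the dequantized scale is `2**E` (a base-2 log-scale, which is consistent with the `[-14, 15]` clamp range but is not stated explicitly in this report) and uses the standard-library `struct` module's half-precision format to emulate fp16 rounding. The helper names are hypothetical.

```python
import struct

FP16_EXP_MIN, FP16_EXP_MAX = -14, 15  # clamp range quoted in this report


def to_fp16(value):
    """Round a Python float to IEEE fp16; overflow becomes infinity."""
    try:
        return struct.unpack("<e", struct.pack("<e", value))[0]
    except OverflowError:
        # struct refuses values beyond the fp16 range; model that as inf.
        return float("inf") if value > 0 else float("-inf")


def clamped_scale(e):
    """Scale seen by the fp16 tile after clamping the exponent."""
    return 2.0 ** max(FP16_EXP_MIN, min(FP16_EXP_MAX, e))


def inflation(e):
    """How much larger the clamped scale is than the assumed true scale 2**e."""
    return clamped_scale(e) / 2.0 ** e
```

Under these assumptions, `inflation(-40)` is `2**26`: a weight that should be negligible is seen by the GEMM as `2**-14`. At the other end, `to_fp16(2.0 ** 15 * 2.0)` is infinite, so a scale at the top of the clamp range overflows as soon as it is multiplied by an activation of magnitude 2.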

Production fix:

- TileLang is no longer used for grad-enabled training by default.
- `ARB_TERNARY_BACKEND=auto` uses TileLang only outside training and falls back to Triton for grad-enabled training.
- `ARB_TERNARY_BACKEND=tilelang` now raises during training unless explicitly enabled.
- `ARB_TILELANG_TRAINING=1` exists only for isolated debugging on the TileLang machine.
- `_ternary_update_memory()` now refuses to update ternary state after a NaN/Inf loss.
- `arbitor.train` and `training/pretrain.py` abort before backward if loss is non-finite.
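
The gate and the loss guard can be sketched as follows. The function names are hypothetical, and two details are assumptions because this report does not spell them out: the override is read as the literal string `"1"`, and in this sketch `ARB_TILELANG_TRAINING=1` only unlocks the forced `tilelang` mode, not `auto`.

```python
import math


def training_backend(requested, grad_enabled, env, tilelang_ok=True):
    """Backend for one call under the training gate described above."""
    allow = env.get("ARB_TILELANG_TRAINING", "0") == "1"
    if requested == "tilelang":
        if grad_enabled and not allow:
            raise RuntimeError("TileLang training requires ARB_TILELANG_TRAINING=1")
        return "tilelang"
    if requested == "auto":
        # TileLang only outside grad-enabled training; Triton otherwise.
        if tilelang_ok and not grad_enabled:
            return "tilelang"
        return "triton"
    return requested


def loss_is_usable(loss):
    """True when backward and the ternary update may proceed."""
    return math.isfinite(loss)
```

A training loop would check `loss_is_usable(loss)` before calling backward and skip the ternary update when it returns `False`.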

Additional verification:

- `python -m pytest -q testing/test_tscale.py -k "small_ternary_training_loss_finite or ternary_update_rejects_nonfinite_loss or cuda_triton_correctness_update_E or cuda_triton_tscale_path or cuda_triton_correctness_ternary_step"`
  - `5 passed, 24 deselected`
- `python -m arbitor.smoke --device cuda --ctx 4 --batch 1 --no-moe --no-vq --no-graph --backward`
  - finite loss: `16.085211`
  - backward/update: `0.147s`
  - zero trainable float params and zero float buffers

Until the TileLang kernel has a bf16/fp32 dequant/GEMM path or an integer-scale matmul path, it should be treated as inference/prewarm acceleration, not a training backend.