# True Ternary Refactor 2

## What Changed In This Pass

### `tscale.py`

- Added Triton detection and a CUDA/Triton execution path for `TernaryScaleTensor`.
- Added packed ternary forward and grad-input kernels that read:
  - `T_packed` as 5 trits per byte
  - `E` as int8 log-scale groups
  - `x` / `grad_y` on CUDA
- `TernaryScaleTensor.forward()` now prefers the Triton path when input tensors are CUDA tensors.
- Fixed the old TileLang negative exponent issue. The previous integer shift path made `2^-k` become zero. The TileLang fallback now reconstructs `2^E` before multiplying by sign.
- Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size `M`, not `N`.
- Fixed ternary update ordering by calling `update_E()` before `ternary_step()`, so the gradient sign is not deleted before the exponent update sees it.
- Kept `T_packed` on the original device after repacking. Without this, repacking moved the buffer back to CPU.
- Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes:

  ```python
  grad_w = grad_y.T @ x
  ```

  as a dense FP32 tensor. It now computes only:

  ```text
  sign(sum_m grad_y[m,n] * x[m,k])
  ```

  directly into an `int8` CUDA tensor.
- Added a Triton `E` update kernel. CUDA exponent updates now read packed trits from `T_packed` and int8 gradient signs directly, then update int8 `E` in-place without unpacking full `T` through PyTorch.
- Added a Triton ternary-step/repack kernel. CUDA `ternary_step()` now updates `T_accum`, applies flip thresholds, resets flipped accumulators, and rewrites `T_packed` in-place without calling Python `pack_ternary()`.
- Removed the dense `int8 grad_sign[N,K]` allocation from the normal CUDA autograd path. Backward now retains compact `grad_y` and `x` views, and the CUDA `E` update / ternary-step kernels recompute the sign reduction directly inside the state-update kernels.
- Fused the expensive sign reduction into the ternary-step/repack pass. The fused CUDA path now updates `T_accum`, rewrites `T_packed`, and atomically accumulates per-group `E` scores in one pass over `grad_y` and `x`.
- Replaced the separate direct `E` reduction with a tiny score-apply kernel over `E` groups. The remaining temporary is `int32` per scale group, not per logical weight.
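
The `T_packed` layout above stores 5 trits per byte, which works because 3^5 = 243 fits in one `uint8`. The following is a minimal pure-Python sketch of such a base-3 packing. The digit order (least-significant trit first), the mapping of {-1, 0, +1} to {0, 1, 2}, and the names `pack_trits` / `unpack_trits` are assumptions for illustration; the actual `pack_ternary()` in `tscale.py` may use a different layout.

```python
TRITS_PER_BYTE = 5  # 3**5 == 243 <= 256, so five trits fit in one byte


def pack_trits(trits):
    """Pack values in {-1, 0, +1} into bytes, 5 trits per byte.

    Assumed layout: each trit is shifted to {0, 1, 2} and stored as a
    base-3 digit, least-significant digit first. The tail is padded with 0.
    """
    padded = list(trits) + [0] * (-len(trits) % TRITS_PER_BYTE)
    out = bytearray()
    for i in range(0, len(padded), TRITS_PER_BYTE):
        byte = 0
        for j, t in enumerate(padded[i:i + TRITS_PER_BYTE]):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            byte += (t + 1) * 3 ** j
        out.append(byte)
    return bytes(out)


def unpack_trits(packed, count):
    """Inverse of pack_trits; returns the first `count` trits."""
    trits = []
    for byte in packed:
        for _ in range(TRITS_PER_BYTE):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:count]


weights = [1, -1, 0, 0, 1, -1, 1]
packed = pack_trits(weights)          # 7 trits -> 2 bytes
restored = unpack_trits(packed, len(weights))
```

A round trip through `pack_trits` and `unpack_trits` should return the original trits, and the largest byte value this layout can produce is 242, so it never overflows `uint8`.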

### `trigram.py`

- Replaced MoE routers with `TernaryScaleTensor`:
  - `moe.router`
  - `moe.router_h`
- Added constructor gates for strict text-only ternary training:
  - `enable_image`
  - `enable_vq`
  - `enable_graph`
  - `enable_memory_modules`
  - `enable_moe`
- In strict mode, the model can be built without VQ, graph, image, LSTM, MemGram, or ConvVQ modules, which removes the hidden trainable float state from the core text model.
- Added a no-VQ forward path where text relational states go directly into MoE and ByteHead.

### `train.py`

- Default optimizer changed to `signsgd`.
- Added `--compute_dtype {bf16,fp16,none}`.
- Added `--strict_ternary`.
  - Forces SignSGD.
  - Forces `compute_dtype=none`.
  - Disables VQ, graph, image, and memory modules.
  - Freezes any remaining trainable float parameters.
- Added `--freeze_float_params` for non-strict runs.
- Added model state audit logging before training.
- Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
- Fixed gradient clipping and optimizer construction to use only trainable parameters.
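
The `--strict_ternary` overrides listed above could be wired roughly as follows. This is a hedged sketch, not the code in `train.py`: the `--optimizer` flag name, the `adamw` choice, the default for `--compute_dtype`, and the `enable_*` attributes on `args` are assumptions made for illustration.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # "--optimizer" and the "adamw" choice are assumed names for illustration.
    parser.add_argument("--optimizer", choices=["signsgd", "adamw"], default="signsgd")
    # The default here is an assumption; the notes only list the choices.
    parser.add_argument("--compute_dtype", choices=["bf16", "fp16", "none"], default="bf16")
    parser.add_argument("--strict_ternary", action="store_true")
    parser.add_argument("--freeze_float_params", action="store_true")
    return parser


def apply_strict_overrides(args):
    """Apply the overrides listed for --strict_ternary."""
    if args.strict_ternary:
        args.optimizer = "signsgd"        # forces SignSGD
        args.compute_dtype = "none"       # forces compute_dtype=none
        args.freeze_float_params = True   # freezes remaining float params
        # Hypothetical attributes mirroring the trigram.py constructor gates.
        args.enable_image = False
        args.enable_vq = False
        args.enable_graph = False
        args.enable_memory_modules = False
    return args


args = apply_strict_overrides(
    build_parser().parse_args(["--strict_ternary", "--optimizer", "adamw"])
)
```

Applying the overrides after parsing means strict mode wins even when a conflicting optimizer is passed on the command line.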

### `ternary_audit.py`

New helper module for reporting:

- logical ternary weights
- packed ternary bytes
- int8 exponent bytes
- int8 accumulator bytes
- trainable floating-point parameters
- frozen floating-point parameters
- floating-point buffers

Strict text-only audit currently reports zero trainable float params and zero float buffers.

### `testing/test_tscale.py`

- Added a CUDA/Triton path test for `TernaryScaleTensor`.
- The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts the captured gradient state is `int8` on CUDA.
- The CUDA/Triton test now verifies the device-side `E` update modifies exponent groups and keeps ternary buffers on CUDA.
- The CUDA/Triton test now forces threshold crossings and verifies GPU repack flips packed trits to `+1` and resets accumulators.
- The CUDA/Triton test now asserts normal backward does not create `_hook_grad_T_sign`; it uses retained `grad_y` and `x` views for direct state updates.
- The CUDA/Triton test now covers the fused ternary-step plus `E` score application path.
- Made missing TileLang reference tests skip instead of failing when `tilelang/kernels/dequant_gemm.py` is absent.

## Verification Run

Direct CUDA/Triton smoke:

```text
cuda True triton True y_device cuda:0 packed_device cuda:0 E_device cuda:0
```

Strict one-step train:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD
```

Sign-only gradient kernel test:

```text
PASS test_cuda_triton_tscale_path
```

The same test now also covers the CUDA-side `E` update kernel and CUDA-side packed ternary repack.

Strict one-step train after replacing dense `grad_w`:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after moving `E` update to Triton:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after moving `ternary_step()` repack to Triton:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after removing dense `grad_sign[N,K]` from the normal CUDA path:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after fusing sign reduction into ternary-step/repack and applying `E` from per-group scores:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

## Remaining Problem: `grad_w = grad_y.T @ x`

The first version of the Triton path fixed forward and grad-input memory behavior, but backward still materialized a dense temporary `grad_w` in Python:

```python
grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)
```

That has now been replaced by a Triton sign-only reduction. The dense FP32 `grad_w` tensor is gone from the Triton backward path.

Current remaining issue:

- The CUDA update path now allocates a small `int32` score buffer shaped like `E`, not like the full logical weight matrix.
- The score buffer is `out_dim * ceil(in_dim / group_size)` entries. At `T32` group size 12, this is about 1/12 as many elements as the logical weights.
- The fused kernel uses `tl.atomic_add` into group scores because 5-trit packed bytes and scale groups are not naturally aligned.

This fits the 3B-on-8GB memory goal much better than the previous path. The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider group/pack alignment changes.
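
To make the score-buffer size concrete, the arithmetic can be sketched as below. The 4096 x 4096 layer shape is a hypothetical example, not a size taken from the model.

```python
import math


def score_buffer_entries(out_dim, in_dim, group_size):
    """Number of int32 scores: one per (output row, scale group)."""
    return out_dim * math.ceil(in_dim / group_size)


out_dim, in_dim, group_size = 4096, 4096, 12  # hypothetical layer shape
entries = score_buffer_entries(out_dim, in_dim, group_size)
score_bytes = entries * 4                      # int32 is 4 bytes per score
logical_weights = out_dim * in_dim
ratio = logical_weights / entries              # close to group_size

print(entries, score_bytes, round(ratio, 2))   # 1400832 5603328 11.98
```

The ratio lands slightly below 12 because the last group in each row is only partially filled when `in_dim` is not a multiple of the group size.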

The actual required update is:

```text
T_accum[i, j] += sign(sum_m grad_y[m, i] * x[m, j])
E update uses grouped sign statistics from the same reduction.
```

The dense `int8` sign tensor is therefore avoided as well: the reduction is fused directly into the ternary state update, as described in the stages below.

## Proposed Fix For Gradient Sign Capture

### Stage 1: Sign-Only Grad Kernel

Implemented. Added a Triton kernel:

```text
input:
  grad_y[M, N]
  x[M, K]

output:
  grad_sign[N, K] int8
```

Each program owns a tile of `(N, K)`, loops across `M`, accumulates a local `float32` (or integer) sum, then immediately converts it to a sign:

```text
s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g
```

This removes the dense FP32 `grad_w` allocation. It still performs the same math, but the result is reduced to int8 at the end of each tile instead of returning a full precision matrix to PyTorch.
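
As a reference for what the kernel computes, here is a pure-Python model of the tiled sign-only reduction. It is not Triton code and says nothing about the real kernel's launch grid; the `block_n` / `block_k` parameters and the function name `sign_only_grad` are illustrative assumptions.

```python
def sign(v):
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def sign_only_grad(grad_y, x, block_n=2, block_k=2):
    """Reference for grad_sign[n, k] = sign(sum_m grad_y[m, n] * x[m, k]).

    Each (block_n, block_k) tile plays the role of one kernel program: it
    loops over M with a local accumulator and stores only the sign.
    """
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    grad_sign = [[0] * K for _ in range(N)]
    for n0 in range(0, N, block_n):
        for k0 in range(0, K, block_k):
            for n in range(n0, min(n0 + block_n, N)):
                for k in range(k0, min(k0 + block_k, K)):
                    acc = 0.0  # local accumulator, never written out
                    for m in range(M):
                        acc += grad_y[m][n] * x[m][k]
                    grad_sign[n][k] = sign(acc)
    return grad_sign


grad_y = [[0.5, -1.0], [0.25, 2.0]]        # M=2, N=2
x = [[1.0, -2.0, 0.0], [4.0, 1.0, 0.0]]    # M=2, K=3
grad_sign = sign_only_grad(grad_y, x)      # [[1, -1, 0], [1, 1, 0]]
```

The full-precision sum only ever exists in `acc`, which is the property that lets the real kernel skip the dense FP32 `grad_w` allocation.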

### Stage 2: GPU T Accumulator Update And Repack

Implemented as a fused Triton kernel. It computes the sign reduction from `grad_y` and `x`, updates `T_accum`, applies threshold flips, resets flipped accumulator entries, rewrites `T_packed` in-place, and emits per-group scale scores:

```text
g = sign(sum_m grad_y[m,n] * x[m,k])
score[n, group(k)] += g * old_T[n,k]
T_accum[n,k] = clamp(T_accum[n,k] + g, -128, 127)
if T_accum[n,k] > threshold:  T[n,k] = +1, T_accum[n,k] = 0
if T_accum[n,k] < -threshold: T[n,k] = -1, T_accum[n,k] = 0
repack 5 T values into one uint8
```

This removes Python unpack/repack from the CUDA path.
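
The per-weight logic of the fused update can be modelled in plain Python for a single output row. This sketch takes the gradient signs `g` as an input instead of recomputing them and leaves out the repack step; the function name `ternary_step_row` and the example values are assumptions for illustration.

```python
def clamp(v, lo, hi):
    return max(lo, min(hi, v))


def ternary_step_row(T, T_accum, g, threshold, group_size):
    """Apply one state update to a single output row.

    T, T_accum and g are equal-length lists; g holds the gradient signs.
    Returns (new_T, new_accum, group_scores).
    """
    n_groups = -(-len(T) // group_size)  # ceil division
    scores = [0] * n_groups
    new_T, new_accum = list(T), list(T_accum)
    for k, (t, acc, gk) in enumerate(zip(T, T_accum, g)):
        scores[k // group_size] += gk * t      # uses the old trit value
        acc = clamp(acc + gk, -128, 127)       # int8 accumulator range
        if acc > threshold:
            new_T[k], acc = 1, 0               # flip to +1, reset accumulator
        elif acc < -threshold:
            new_T[k], acc = -1, 0              # flip to -1, reset accumulator
        new_accum[k] = acc
    return new_T, new_accum, scores


T = [0, 1, -1, 0]
T_accum = [3, -3, 0, 1]
g = [1, -1, 1, 0]
new_T, new_accum, scores = ternary_step_row(T, T_accum, g, threshold=3, group_size=2)
# new_T == [1, -1, -1, 0], new_accum == [0, 0, 1, 1], scores == [-1, -1]
```

Scoring against the old trit value before any flip matches the `old_T` term in the pseudocode above.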

### Stage 3: Separate E Group Update

Implemented for the CUDA/Triton path as a small score-apply kernel.

For scale exponents, update per group rather than per weight:

```text
group_score[n, g] = sum_{k in group} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)
```

This is now a second Triton kernel over `(N, groups)` that applies the score emitted by the fused ternary-step/repack kernel.

Current implemented split:

1. fused sign-reduction + `T_accum` + repack + group-score kernel
2. small `E` score-apply kernel

No sign tensor is stored in the normal CUDA path.
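
The score-apply step is small enough to model directly. The sketch below follows the `E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)` rule on nested lists; the function name `apply_e_scores` and the sample values are illustrative assumptions.

```python
def sign(v):
    return (v > 0) - (v < 0)


def clamp(v, lo, hi):
    return max(lo, min(hi, v))


def apply_e_scores(E, scores):
    """Move each int8 exponent one step against the sign of its group score."""
    return [
        [clamp(e - sign(s), -128, 127) for e, s in zip(e_row, s_row)]
        for e_row, s_row in zip(E, scores)
    ]


E = [[0, -128, 127, 4]]
scores = [[5, 2, -7, 0]]
new_E = apply_e_scores(E, scores)   # [[-1, -128, 127, 4]]
```

Exponents already at the int8 limits stay clamped, and a zero score leaves the exponent unchanged.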

### Stage 4: Tune Or Remove Group-Score Atomics

Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes are 5-trit chunks while `E` groups are 12, 6, 24, 48, 64, or 96 weights depending on TScale type.

Options:

- Keep atomics and tune block sizes.
- Change group sizes to align with 5-trit packing where possible.
- Use a group-owned kernel for `E` and a pack-owned kernel for `T`, accepting two reductions but no atomics.
- Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.

This is now a speed optimization, not a memory correctness blocker.
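
The alignment problem behind the atomics can be seen by listing which packed bytes span two `E` groups. The sketch assumes trits are packed contiguously along the input dimension of each row, which may not match the real layout; `straddling_bytes` is a hypothetical helper, and the group size of 10 is used only as an example of a size aligned to the packing.

```python
import math


def straddling_bytes(in_dim, group_size, trits_per_byte=5):
    """Indices of packed bytes whose trits fall into more than one E group."""
    result = []
    n_bytes = -(-in_dim // trits_per_byte)  # ceil division
    for b in range(n_bytes):
        first = b * trits_per_byte
        last = min(first + trits_per_byte, in_dim) - 1
        if first // group_size != last // group_size:
            result.append(b)
    return result


# Group size 12: bytes and groups only realign every lcm(5, 12) = 60 weights.
period = 5 * 12 // math.gcd(5, 12)
straddlers_12 = straddling_bytes(in_dim=60, group_size=12)   # [2, 4, 7, 9]

# A group size that is a multiple of 5 has no straddling bytes.
straddlers_10 = straddling_bytes(in_dim=60, group_size=10)   # []
```

With group size 12, a pack-owned program that handles byte 2 contributes to both group 0 and group 1, which is why a shared score buffer needs atomic accumulation.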

### Stage 5: Activation Ternary Mode

Forward currently consumes normal activation tensors. True ternary training should eventually use:

```text
A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E
```

This makes the gradient-sign kernel cheaper because `x[m, k]` is also sign/zero rather than a dense float.
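
A minimal sketch of that arithmetic for one output element, assuming a single scale exponent covers the whole dot product (the real layout uses per-group exponents) and using a plain Python `int` in place of an int32 accumulator:

```python
import math


def ternary_dot(a, w, e):
    """Dot product of ternary activations and weights, scaled by 2**e."""
    acc = 0  # stands in for an int32 accumulator
    for ai, wi in zip(a, w):
        acc += ai * wi  # each product is -1, 0, or +1
    # ldexp computes acc * 2**e in floating point, so negative exponents
    # are preserved. An integer right shift would lose them: 1 >> 3 == 0.
    return math.ldexp(float(acc), e)


a = [1, -1, 0, 1]
w = [1, 1, -1, -1]
y = ternary_dot(a, w, -3)   # acc == -1, so y == -0.125
```

Reconstructing `2^E` in floating point before multiplying is the same idea as the negative-exponent fix described for the TileLang fallback.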

## Native CUDA Speed Path

### cuBLAS

cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and low/mixed precision types, including tensor-core-backed FP/INT formats, but it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.

Use cuBLAS only for baselines or fallback dequantized GEMM.

### Triton

Triton is the best immediate path.

Reasons:

- Already installed in this environment.
- Integrates with PyTorch autograd quickly.
- Good for custom packed formats and fused update kernels.
- Lets us remove the dense `grad_w` allocation without waiting on TileLang setup.

Near-term target:

```text
Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel
```

### TileLang

TileLang is still worth getting working after the algorithm is stable. Its docs describe `T.gemm` lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.

Use TileLang when:

- Triton semantics are proven.
- Tile sizes and data layout are stable.
- We want a cleaner path toward tensor-core-style tiling and schedule tuning.

TileLang should not be the first blocker because it is not currently installed in this environment.

### Custom CUDA/CUTLASS

This is the final performance path, not the first implementation path.

Use custom CUDA/CUTLASS when:

- The Triton kernels prove the training algorithm.
- Profiling shows Triton is bottlenecked by decode/repack/reduction overhead.
- We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.

This path has the highest ceiling and highest development cost.

## Recommended Roadmap

1. Keep the current Triton forward and grad-input path.
2. Add `ternary_grad_sign_accum_kernel` to eliminate dense `grad_w`.
3. Add GPU-side `E` update.
4. Add GPU-side repack of `T_packed`.
5. Benchmark strict ternary training memory with `torch.cuda.max_memory_allocated()`.
6. Tune Triton block sizes.
7. Install and evaluate TileLang against the Triton kernels.
8. Move only proven hot kernels to CUDA/CUTLASS if needed.

## Key Constraint

There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than ternary. The goal is not "no accumulator precision"; the goal is:

- no persistent BF16/FP32/FP8 weights
- no persistent FP optimizer state
- no dense full-precision weight gradients
- packed ternary storage
- int8 scale memory
- streaming reductions into ternary/int8 state

## Speed Pass: Scale Update Scheduling

The current bottleneck is the exponent scale update, not the packed forward path. A 4-step strict smoke run with scheduled scale updates was slower than the same run with `E` updates disabled, which reached about `4.90 it/s` on the same small batch after compile.

Changes made:

- Added `--scale_update_interval`.
  - Default: `4`.
  - `1`: update int8 `E` every step.
  - `0`: disable `E` updates and always use the fast direct ternary repack path.
- Changed CUDA `update_E()` back to the direct group-owned Triton kernel for the normal retained `grad_y/x` path.
- `ternary_step()` now uses the fast direct repack kernel after `update_E()` instead of the fused score/repack kernel on scheduled scale-update steps.
- This removes the temporary int32 score buffer and `tl.atomic_add` from the default scheduled scale-update path.
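
The scheduling rule for `--scale_update_interval` can be sketched as a small predicate. Counting steps from 1 and updating when the step is a multiple of the interval are assumptions of this sketch; the training loop may count differently.

```python
def should_update_scale(step, interval):
    """Return True when int8 E should be updated on this step.

    interval == 0 disables E updates; interval == 1 updates every step.
    """
    if interval <= 0:
        return False
    return step % interval == 0


default_schedule = [should_update_scale(s, 4) for s in range(1, 9)]
# [False, False, False, True, False, False, False, True]
every_step = [should_update_scale(s, 1) for s in range(1, 5)]   # all True
disabled = [should_update_scale(s, 0) for s in range(1, 5)]     # all False
```

With the default interval of 4, three out of every four steps skip `update_E()` and only run the ternary repack.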

Observed before the direct `E` change:

```text
scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke
```

Observed after the direct group-owned `E` change:

```text
scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke
```

Interpretation:

- The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
- The direct group-owned `E` kernel is faster even though it performs a separate reduction, because each program owns an `E` group and writes without atomics.
- The default `scale_update_interval=4` now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.

Next speed target:

- Benchmark the direct group-owned `E` kernel versus the fused score kernel on real layer sizes.
- If direct `E` is faster, remove or demote the fused score kernel to an experimental path.
- Tune `BLOCK_N`, `BLOCK_K`, and `BLOCK_M` in `_triton_update_e_direct_kernel` and `_triton_ternary_step_direct_kernel`.