# True Ternary Refactor 2
## What Changed In This Pass
### `tscale.py`
- Added Triton detection and a CUDA/Triton execution path for `TernaryScaleTensor`.
- Added packed ternary forward and grad-input kernels that read:
  - `T_packed` as 5 trits per byte
  - `E` as int8 log-scale groups
  - `x` / `grad_y` on CUDA
- `TernaryScaleTensor.forward()` now prefers the Triton path when input tensors are CUDA tensors.
- Fixed the old TileLang negative exponent issue. The previous integer shift path made `2^-k` become zero. The TileLang fallback now reconstructs `2^E` before multiplying by sign.
- Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size `M`, not `N`.
- Fixed ternary update ordering by calling `update_E()` before `ternary_step()`, so the gradient sign is not deleted before the exponent update sees it.
- Kept `T_packed` on the original device after repacking. Without this, repacking moved the buffer back to CPU.
- Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes:
```python
grad_w = grad_y.T @ x
```
as a dense FP32 tensor. It now computes only:
```text
sign(sum_m grad_y[m,n] * x[m,k])
```
directly into an `int8` CUDA tensor.
- Added a Triton `E` update kernel. CUDA exponent updates now read packed trits from `T_packed` and int8 gradient signs directly, then update int8 `E` in-place without unpacking full `T` through PyTorch.
- Added a Triton ternary-step/repack kernel. CUDA `ternary_step()` now updates `T_accum`, applies flip thresholds, resets flipped accumulators, and rewrites `T_packed` in-place without calling Python `pack_ternary()`.
- Removed the dense `int8 grad_sign[N,K]` allocation from the normal CUDA autograd path. Backward now retains compact `grad_y` and `x` views, and the CUDA `E` update / ternary-step kernels recompute the sign reduction directly inside the state-update kernels.
- Fused the expensive sign reduction into the ternary-step/repack pass. The fused CUDA path now updates `T_accum`, rewrites `T_packed`, and atomically accumulates per-group `E` scores in one pass over `grad_y` and `x`.
- Replaced the separate direct `E` reduction with a tiny score-apply kernel over `E` groups. The remaining temporary is `int32` per scale group, not per logical weight.
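
To make the "5 trits per byte" layout concrete, here is a minimal pure-Python sketch of one way such packing can work. It assumes a base-3 encoding with the first trit as the least significant digit. The actual digit order and the `pack_ternary()` implementation in `tscale.py` are not shown in these notes, so the function names below are hypothetical.

```python
def pack_trits(trits):
    """Pack values in {-1, 0, +1} into bytes, 5 trits per byte (assumed base-3 layout)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        chunk = list(trits[start:start + 5])
        chunk += [0] * (5 - len(chunk))  # pad the tail with zero trits
        value = 0
        for t in reversed(chunk):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            value = value * 3 + (t + 1)  # map -1/0/+1 to base-3 digits 0/1/2
        out.append(value)  # at most 3**5 - 1 = 242, so it fits in one uint8
    return bytes(out)


def unpack_trits(packed, count):
    """Inverse of pack_trits; returns the first `count` trits."""
    trits = []
    for byte in packed:
        for _ in range(5):
            trits.append(byte % 3 - 1)  # least significant digit first
            byte //= 3
    return trits[:count]
```

Five trits need `3**5 = 243` distinct values, which is why five is the most that fit in one byte.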
### `trigram.py`
- Replaced MoE routers with `TernaryScaleTensor`:
  - `moe.router`
  - `moe.router_h`
- Added constructor gates for strict text-only ternary training:
  - `enable_image`
  - `enable_vq`
  - `enable_graph`
  - `enable_memory_modules`
  - `enable_moe`
- In strict mode, the model can be built without VQ, graph, image, LSTM, MemGram, or ConvVQ modules, which removes the hidden trainable float state from the core text model.
- Added a no-VQ forward path where text relational states go directly into MoE and ByteHead.
### `train.py`
- Default optimizer changed to `signsgd`.
- Added `--compute_dtype {bf16,fp16,none}`.
- Added `--strict_ternary`, which:
  - Forces SignSGD.
  - Forces `compute_dtype=none`.
  - Disables VQ, graph, image, and memory modules.
  - Freezes any remaining trainable float parameters.
- Added `--freeze_float_params` for non-strict runs.
- Added model state audit logging before training.
- Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
- Fixed gradient clipping and optimizer construction to use only trainable parameters.
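
The flag interactions above can be sketched with `argparse`. This is an illustration, not the code in `train.py`: the `--optimizer` flag name, the `--compute_dtype` default, and the helper `apply_strict_overrides` are assumptions. Only the flag names `--compute_dtype`, `--strict_ternary`, and `--freeze_float_params` and the listed overrides come from these notes.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser()
    # `--optimizer` is a hypothetical flag name; the notes only state the default.
    parser.add_argument("--optimizer", default="signsgd")
    # The default here is a placeholder; the notes do not state it.
    parser.add_argument("--compute_dtype", choices=["bf16", "fp16", "none"], default="bf16")
    parser.add_argument("--strict_ternary", action="store_true")
    parser.add_argument("--freeze_float_params", action="store_true")
    return parser


def apply_strict_overrides(args):
    """Apply the overrides listed for --strict_ternary (hypothetical helper)."""
    if args.strict_ternary:
        args.optimizer = "signsgd"
        args.compute_dtype = "none"
        args.freeze_float_params = True
        # Attribute names mirror the trigram.py constructor gates.
        for gate in ("enable_image", "enable_vq", "enable_graph", "enable_memory_modules"):
            setattr(args, gate, False)
    return args
```

Applying the overrides after parsing keeps `--strict_ternary` authoritative even when a conflicting `--compute_dtype` is passed on the same command line.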
### `ternary_audit.py`
New helper module for reporting:
- logical ternary weights
- packed ternary bytes
- int8 exponent bytes
- int8 accumulator bytes
- trainable floating-point parameters
- frozen floating-point parameters
- floating-point buffers
Strict text-only audit currently reports zero trainable float params and zero float buffers.
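
As a rough cross-check for the audit numbers, the persistent state of a single ternary layer can be estimated from its shape. The sketch below is not part of `ternary_audit.py`; `ternary_layer_bytes` is a hypothetical helper. It assumes trits are packed over the flattened weight matrix and that `T_accum` holds one int8 per logical weight.

```python
import math


def ternary_layer_bytes(out_dim, in_dim, group_size):
    """Estimate persistent state for one out_dim x in_dim ternary layer."""
    logical = out_dim * in_dim
    return {
        "logical_ternary_weights": logical,
        # 5 trits per byte, assuming packing over the flattened matrix
        "packed_ternary_bytes": math.ceil(logical / 5),
        # one int8 exponent per scale group
        "int8_exponent_bytes": out_dim * math.ceil(in_dim / group_size),
        # assumed: one int8 T_accum entry per logical weight
        "int8_accumulator_bytes": logical,
    }
```

If packing is done per row instead, the packed size becomes `out_dim * ceil(in_dim / 5)`, which differs only when `in_dim` is not a multiple of 5.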
### `testing/test_tscale.py`
- Added a CUDA/Triton path test for `TernaryScaleTensor`.
- The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts the captured gradient state is `int8` on CUDA.
- The CUDA/Triton test now verifies the device-side `E` update modifies exponent groups and keeps ternary buffers on CUDA.
- The CUDA/Triton test now forces threshold crossings and verifies GPU repack flips packed trits to `+1` and resets accumulators.
- The CUDA/Triton test now asserts normal backward does not create `_hook_grad_T_sign`; it uses retained `grad_y` and `x` views for direct state updates.
- The CUDA/Triton test now covers the fused ternary-step plus `E` score application path.
- Made missing TileLang reference tests skip instead of failing when `tilelang/kernels/dequant_gemm.py` is absent.
## Verification Run
Direct CUDA/Triton smoke:
```text
cuda True triton True y_device cuda:0 packed_device cuda:0 E_device cuda:0
```
Strict one-step train:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD
```
Sign-only gradient kernel test:
```text
PASS test_cuda_triton_tscale_path
```
The same test now also covers the CUDA-side `E` update kernel and CUDA-side packed ternary repack.
Strict one-step train after replacing dense `grad_w`:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after moving `E` update to Triton:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after moving `ternary_step()` repack to Triton:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after removing dense `grad_sign[N,K]` from the normal CUDA path:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after fusing sign reduction into ternary-step/repack and applying `E` from per-group scores:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
## Remaining Problem: `grad_w = grad_y.T @ x`
The first version of the Triton path fixed forward and grad-input memory behavior, but backward still materialized a dense temporary `grad_w` in Python:
```python
grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)
```
That has now been replaced by a Triton sign-only reduction. The dense FP32 `grad_w` tensor is gone from the Triton backward path.
Current remaining issue:
- The CUDA update path now allocates a small `int32` score buffer shaped like `E`, not like the full logical weight matrix.
- The score buffer is `out_dim * ceil(in_dim / group_size)` entries. At `T32` group size 12, this is about 1/12 as many elements as the logical weights.
- The fused kernel uses `tl.atomic_add` into group scores because 5-trit packed bytes and scale groups are not naturally aligned.
This fits the 3B-on-8GB memory goal much better than the previous path. The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider group/pack alignment changes.
The actual required update is:
```text
T_accum[i, j] += sign(sum_m grad_y[m, i] * x[m, j])
E update uses grouped sign statistics from the same reduction.
```
The fused ternary-step/repack kernel now computes this reduction inside the state update, so the dense `int8` sign tensor is no longer stored in the normal CUDA path.
## Proposed Fix For Gradient Sign Capture
### Stage 1: Sign-Only Grad Kernel
Implemented. Added a Triton kernel:
```text
input:
grad_y[M, N]
x[M, K]
output:
grad_sign[N, K] int8
```
Each program owns a tile of `(N, K)`, loops across `M`, accumulates a local sum in `float32` or a 32-bit integer accumulator, then immediately converts it to a sign:
```text
s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g
```
This removes the dense FP32 `grad_w` allocation. It still performs the same math, but the result is reduced to int8 at the end of each tile instead of returning a full precision matrix to PyTorch.
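
For checking a kernel like this on tiny inputs, a pure-Python reference of the same math can help. The sketch below is not the Triton kernel and not the PyTorch reference used in `testing/test_tscale.py`; `grad_sign_reference` is a hypothetical name, and it uses nested lists in place of tensors.

```python
def sign(v):
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def grad_sign_reference(grad_y, x):
    """Return sign(grad_y.T @ x) as nested lists of ints in {-1, 0, +1}.

    grad_y has shape [M][N], x has shape [M][K]; the result has shape [N][K].
    """
    m_dim, n_dim, k_dim = len(grad_y), len(grad_y[0]), len(x[0])
    out = [[0] * k_dim for _ in range(n_dim)]
    for n in range(n_dim):
        for k in range(k_dim):
            s = 0.0  # local accumulator, wider than the int8 result
            for m in range(m_dim):
                s += grad_y[m][n] * x[m][k]
            out[n][k] = sign(s)  # only the sign leaves the inner loop
    return out
```

The full-precision sum `s` exists only inside the loop body, which mirrors how the tile-local accumulator never reaches PyTorch as a dense matrix.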
### Stage 2: GPU T Accumulator Update And Repack
Implemented as a fused Triton kernel. It computes the sign reduction from `grad_y` and `x`, updates `T_accum`, applies threshold flips, resets flipped accumulator entries, rewrites `T_packed` in-place, and emits per-group scale scores:
```text
g = sign(sum_m grad_y[m, n] * x[m, k])
score[n, group(k)] += g * old_T[n, k]
T_accum[n, k] = clamp(T_accum[n, k] + g, -128, 127)
if T_accum[n, k] > threshold: T[n, k] = +1, T_accum[n, k] = 0
if T_accum[n, k] < -threshold: T[n, k] = -1, T_accum[n, k] = 0
repack 5 T values into one uint8
```
This removes Python unpack/repack from the CUDA path.
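
The per-element update rule above can be written as a small pure-Python reference. This is a sketch for reasoning about the semantics, not the fused Triton kernel: it takes precomputed gradient signs as input, works on unpacked trits, and skips the repack step. The function name and argument layout are hypothetical.

```python
def clamp(v, lo, hi):
    return max(lo, min(hi, v))


def ternary_step_reference(T, T_accum, g, threshold, group_size):
    """Apply one step in place and return per-group scores.

    T, T_accum and g are [N][K] lists of ints; g holds gradient signs.
    """
    n_dim, k_dim = len(T), len(T[0])
    n_groups = -(-k_dim // group_size)  # ceil division
    score = [[0] * n_groups for _ in range(n_dim)]
    for n in range(n_dim):
        for k in range(k_dim):
            # The score uses the old T value, read before any flip.
            score[n][k // group_size] += g[n][k] * T[n][k]
            acc = clamp(T_accum[n][k] + g[n][k], -128, 127)
            if acc > threshold:
                T[n][k], acc = 1, 0  # flip to +1 and reset the accumulator
            elif acc < -threshold:
                T[n][k], acc = -1, 0  # flip to -1 and reset the accumulator
            T_accum[n][k] = acc
    return score
```

Reading `T[n][k]` for the score before applying the flip matches the `old_T` term in the rule above.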
### Stage 3: Separate E Group Update
Implemented for the CUDA/Triton path as a small score-apply kernel.
For scale exponents, update per group rather than per weight:
```text
group_score[n, g] = sum_{k in group} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)
```
This is now a second Triton kernel over `(N, groups)` that applies the score emitted by the fused ternary-step/repack kernel.
Current implemented split:
1. fused sign-reduction + `T_accum` + repack + group-score kernel
2. small `E` score-apply kernel
No sign tensor is stored in the normal CUDA path.
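
The score-apply step is small enough to state as a pure-Python reference. The sketch assumes the formula above is applied once per `(n, group)` entry. `apply_group_scores` is a hypothetical name, not the Triton kernel.

```python
def apply_group_scores(E, score):
    """In place: E[n][g] = clamp(E[n][g] - sign(score[n][g]), -128, 127)."""
    for n, row in enumerate(score):
        for g, s in enumerate(row):
            step = (s > 0) - (s < 0)  # sign of the group score
            E[n][g] = max(-128, min(127, E[n][g] - step))  # stay in int8 range
    return E
```

Each exponent moves by at most one per update, so a group's scale changes by at most a factor of two per applied step.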
### Stage 4: Tune Or Remove Group-Score Atomics
Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes are 5-trit chunks while `E` groups are 12, 6, 24, 48, 64, or 96 weights depending on TScale type.
Options:
- Keep atomics and tune block sizes.
- Change group sizes to align with 5-trit packing where possible.
- Use a group-owned kernel for `E` and a pack-owned kernel for `T`, accepting two reductions but no atomics.
- Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.
This is now a speed optimization, not a memory correctness blocker.
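
The misalignment behind the atomics can be shown with a small calculation. The sketch below assumes each row is packed starting at `k = 0` with 5 trits per byte, and lists the packed bytes whose trits fall into two different `E` groups. `straddling_bytes` is a hypothetical helper used only for illustration.

```python
def straddling_bytes(in_dim, group_size, trits_per_byte=5):
    """Return indices of packed bytes whose trits span more than one E group."""
    straddlers = []
    n_bytes = -(-in_dim // trits_per_byte)  # ceil division
    for b in range(n_bytes):
        first = b * trits_per_byte
        last = min(first + trits_per_byte, in_dim) - 1
        if first // group_size != last // group_size:
            straddlers.append(b)
    return straddlers
```

Under these assumptions, a row of 60 weights with group size 12 has four straddling bytes out of twelve, while a group size that is a multiple of 5 has none. A program that owns a straddling byte contributes to two group scores, which is one reason a pack-owned kernel needs atomics or a second reduction.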
### Stage 5: Activation Ternary Mode
Forward currently consumes normal activation tensors. True ternary training should eventually use:
```text
A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E
```
This makes the gradient-sign kernel cheaper because `x[m, k]` is also sign/zero rather than a dense float.
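
A minimal sketch of that arithmetic for a single scale group, assuming one int8 exponent `e` covers the whole vector. It uses `math.ldexp` to build the power-of-two scale in floating point, so a negative exponent does not collapse to zero the way an integer shift would. `ternary_dot` is a hypothetical name.

```python
import math


def ternary_dot(a, w, e):
    """Dot product of two ternary vectors, scaled by 2**e."""
    acc = 0  # a Python int stands in for the int32 accumulator
    for ai, wi in zip(a, w):
        acc += ai * wi  # each product is -1, 0, or +1
    return math.ldexp(float(acc), e)  # acc * 2**e, valid for negative e
```

With both operands ternary, the inner loop needs only additions and subtractions; the accumulator is the only value wider than a trit.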
## Native CUDA Speed Path
### cuBLAS
cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and low/mixed precision types, including tensor-core-backed FP/INT formats, but it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.
Use cuBLAS only for baselines or fallback dequantized GEMM.
### Triton
Triton is the best immediate path.
Reasons:
- Already installed in this environment.
- Integrates with PyTorch autograd quickly.
- Good for custom packed formats and fused update kernels.
- Lets us remove the dense `grad_w` allocation without waiting on TileLang setup.
Near-term target:
```text
Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel
```
### TileLang
TileLang is still worth getting working after the algorithm is stable. Its docs describe `T.gemm` lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.
Use TileLang when:
- Triton semantics are proven.
- Tile sizes and data layout are stable.
- We want a cleaner path toward tensor-core-style tiling and schedule tuning.
TileLang should not be the first blocker because it is not currently installed in this environment.
### Custom CUDA/CUTLASS
This is the final performance path, not the first implementation path.
Use custom CUDA/CUTLASS when:
- The Triton kernels prove the training algorithm.
- Profiling shows Triton is bottlenecked by decode/repack/reduction overhead.
- We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.
This path has the highest ceiling and highest development cost.
## Recommended Roadmap
1. Keep the current Triton forward and grad-input path.
2. Add `ternary_grad_sign_accum_kernel` to eliminate dense `grad_w`.
3. Add GPU-side `E` update.
4. Add GPU-side repack of `T_packed`.
5. Benchmark strict ternary training memory with `torch.cuda.max_memory_allocated()`.
6. Tune Triton block sizes.
7. Install and evaluate TileLang against the Triton kernels.
8. Move only proven hot kernels to CUDA/CUTLASS if needed.
## Key Constraint
There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than ternary. The goal is not "no accumulator precision"; the goal is:
- no persistent BF16/FP32/FP8 weights
- no persistent FP optimizer state
- no dense full-precision weight gradients
- packed ternary storage
- int8 scale memory
- streaming reductions into ternary/int8 state
## Speed Pass: Scale Update Scheduling
The current bottleneck is the exponent scale update, not the packed forward path. A 4-step strict smoke run with `scale_update_interval=4` reached only about `1.87 it/s`, while disabling `E` updates reached about `4.90 it/s` on the same small batch after compile.
Changes made:
- Added `--scale_update_interval`.
  - Default: `4`.
  - `1`: update int8 `E` every step.
  - `0`: disable `E` updates and always use the fast direct ternary repack path.
- Changed CUDA `update_E()` back to the direct group-owned Triton kernel for the normal retained `grad_y/x` path.
- `ternary_step()` now uses the fast direct repack kernel after `update_E()` instead of the fused score/repack kernel on scheduled scale-update steps.
- This removes the temporary int32 score buffer and `tl.atomic_add` from the default scheduled scale-update path.
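
The interval semantics can be sketched as a small predicate. This assumes a 0-based step counter and that updates fire when the step is a multiple of the interval. The exact phase used in `train.py` is not stated in these notes, and `should_update_scale` is a hypothetical name.

```python
def should_update_scale(step, scale_update_interval):
    """Return True when int8 E should be updated on this step."""
    if scale_update_interval <= 0:
        return False  # 0 disables E updates entirely
    return step % scale_update_interval == 0
```

With the default interval of 4, three out of every four steps would skip the `E` update under this scheme.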
Observed before the direct `E` change:
```text
scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke
```
Observed after the direct group-owned `E` change:
```text
scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke
```
Interpretation:
- The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
- The direct group-owned `E` kernel is faster even though it performs a separate reduction, because each program owns an `E` group and writes without atomics.
- The default `scale_update_interval=4` now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.
Next speed target:
- Benchmark the direct group-owned `E` kernel versus the fused score kernel on real layer sizes.
- If direct `E` is faster, remove or demote the fused score kernel to an experimental path.
- Tune `BLOCK_N`, `BLOCK_K`, and `BLOCK_M` in `_triton_update_e_direct_kernel` and `_triton_ternary_step_direct_kernel`.