# True Ternary Refactor 2
## What Changed In This Pass
### `tscale.py`
- Added Triton detection and a CUDA/Triton execution path for `TernaryScaleTensor`.
- Added packed ternary forward and grad-input kernels that read:
  - `T_packed` as 5 trits per byte
  - `E` as int8 log-scale groups
  - `x` / `grad_y` on CUDA
- `TernaryScaleTensor.forward()` now prefers the Triton path when input tensors are CUDA tensors.
- Fixed the old TileLang negative exponent issue. The previous integer shift path made `2^-k` become zero. The TileLang fallback now reconstructs `2^E` before multiplying by sign.
- Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size `M`, not `N`.
- Fixed ternary update ordering by calling `update_E()` before `ternary_step()`, so the gradient sign is not deleted before the exponent update sees it.
- Kept `T_packed` on the original device after repacking. Without this, repacking moved the buffer back to CPU.
- Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes:
```python
grad_w = grad_y.T @ x
```
as a dense FP32 tensor. It now computes only:
```text
sign(sum_m grad_y[m,n] * x[m,k])
```
directly into an `int8` CUDA tensor.
- Added a Triton `E` update kernel. CUDA exponent updates now read packed trits from `T_packed` and int8 gradient signs directly, then update int8 `E` in-place without unpacking full `T` through PyTorch.
- Added a Triton ternary-step/repack kernel. CUDA `ternary_step()` now updates `T_accum`, applies flip thresholds, resets flipped accumulators, and rewrites `T_packed` in-place without calling Python `pack_ternary()`.
- Removed the dense `int8 grad_sign[N,K]` allocation from the normal CUDA autograd path. Backward now retains compact `grad_y` and `x` views, and the CUDA `E` update / ternary-step kernels recompute the sign reduction directly inside the state-update kernels.
- Fused the expensive sign reduction into the ternary-step/repack pass. The fused CUDA path now updates `T_accum`, rewrites `T_packed`, and atomically accumulates per-group `E` scores in one pass over `grad_y` and `x`.
- Replaced the separate direct `E` reduction with a tiny score-apply kernel over `E` groups. The remaining temporary is `int32` per scale group, not per logical weight.
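The "5 trits per byte" layout works because `3^5 = 243` fits in one byte. The sketch below shows one way to pack and unpack such a buffer in plain Python. The base-3 digit order, the zero padding of the last chunk, and the names `pack_trits` / `unpack_trits` are assumptions for illustration; the actual `T_packed` encoding in `tscale.py` may differ.

```python
def pack_trits(trits):
    """Pack values from {-1, 0, +1} into bytes, 5 trits per byte (base-3 digits)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        chunk = list(trits[start:start + 5])
        chunk += [0] * (5 - len(chunk))      # pad the final chunk with zero trits
        value = 0
        for t in reversed(chunk):
            value = value * 3 + (t + 1)      # map -1/0/+1 to digits 0/1/2
        out.append(value)                    # at most 3**5 - 1 = 242, fits in a byte
    return bytes(out)


def unpack_trits(packed, count):
    """Inverse of pack_trits; returns the first `count` trits."""
    trits = []
    for byte in packed:
        for _ in range(5):
            trits.append(byte % 3 - 1)       # least significant digit first
            byte //= 3
    return trits[:count]


weights = [1, -1, 0, 0, 1, -1, 1]
packed = pack_trits(weights)                 # 7 trits occupy 2 bytes
restored = unpack_trits(packed, len(weights))
```

A device-side repack kernel would perform the same digit arithmetic per byte, which is why the packed buffer can be rewritten in-place without a Python round trip.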
### `trigram.py`
- Replaced MoE routers with `TernaryScaleTensor`:
  - `moe.router`
  - `moe.router_h`
- Added constructor gates for strict text-only ternary training:
  - `enable_image`
  - `enable_vq`
  - `enable_graph`
  - `enable_memory_modules`
  - `enable_moe`
- In strict mode, the model can be built without VQ, graph, image, LSTM, MemGram, or ConvVQ modules, which removes the hidden trainable float state from the core text model.
- Added a no-VQ forward path where text relational states go directly into MoE and ByteHead.
### `train.py`
- Default optimizer changed to `signsgd`.
- Added `--compute_dtype {bf16,fp16,none}`.
- Added `--strict_ternary`.
  - Forces SignSGD.
  - Forces `compute_dtype=none`.
  - Disables VQ, graph, image, and memory modules.
  - Freezes any remaining trainable float parameters.
- Added `--freeze_float_params` for non-strict runs.
- Added model state audit logging before training.
- Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
- Fixed gradient clipping and optimizer construction to use only trainable parameters.
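The strict-mode overrides listed above can be sketched with `argparse`. This is a minimal illustration, not the code in `train.py`: the `--optimizer` flag name, the `adamw` choice, the `bf16` default for `--compute_dtype`, and the helper `resolve_strict` are all assumptions. Only the flag names `--compute_dtype`, `--strict_ternary`, and `--freeze_float_params` and the `signsgd` default come from the notes above.

```python
import argparse


def build_parser():
    """Flag sketch; `--optimizer`, its choices, and the bf16 default are assumptions."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--optimizer", default="signsgd", choices=["signsgd", "adamw"])
    parser.add_argument("--compute_dtype", default="bf16", choices=["bf16", "fp16", "none"])
    parser.add_argument("--strict_ternary", action="store_true")
    parser.add_argument("--freeze_float_params", action="store_true")
    return parser


def resolve_strict(args):
    """Return the effective settings after the strict-mode overrides."""
    settings = vars(args).copy()
    # Module gates mirror the trigram.py constructor arguments; assumed on by default.
    settings.update(enable_image=True, enable_vq=True,
                    enable_graph=True, enable_memory_modules=True)
    if args.strict_ternary:
        settings.update(
            optimizer="signsgd",          # strict mode forces SignSGD
            compute_dtype="none",         # and disables reduced-precision compute
            freeze_float_params=True,     # any remaining float params are frozen
            enable_image=False, enable_vq=False,
            enable_graph=False, enable_memory_modules=False,
        )
    return settings


strict = resolve_strict(build_parser().parse_args(
    ["--strict_ternary", "--optimizer", "adamw", "--compute_dtype", "fp16"]))
default = resolve_strict(build_parser().parse_args([]))
```

Resolving the overrides in one place keeps conflicting command-line values (for example `--optimizer adamw` together with `--strict_ternary`) from silently reintroducing float state.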
### `ternary_audit.py`
New helper module for reporting:
- logical ternary weights
- packed ternary bytes
- int8 exponent bytes
- int8 accumulator bytes
- trainable floating-point parameters
- frozen floating-point parameters
- floating-point buffers
Strict text-only audit currently reports zero trainable float params and zero float buffers.
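For orientation, the first four audit categories can be estimated from a layer shape alone. The helper below is a hypothetical back-of-envelope calculation, not the `ternary_audit.py` API. It assumes trits are packed across the whole flattened weight tensor and that there is one int8 accumulator per logical weight.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


def ternary_state_bytes(out_dim, in_dim, group_size):
    """Estimate persistent ternary state for one out_dim x in_dim layer."""
    logical = out_dim * in_dim
    return {
        "logical_ternary_weights": logical,
        "packed_ternary_bytes": ceil_div(logical, 5),                     # 5 trits per byte
        "int8_exponent_bytes": out_dim * ceil_div(in_dim, group_size),    # one E per group
        "int8_accumulator_bytes": logical,                                # assumed 1 per weight
    }


report = ternary_state_bytes(out_dim=1024, in_dim=1024, group_size=12)
```

Under these assumptions the accumulator, not the packed weights, dominates persistent memory: it is five times the size of `T_packed`.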
### `testing/test_tscale.py`
- Added a CUDA/Triton path test for `TernaryScaleTensor`.
- The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts the captured gradient state is `int8` on CUDA.
- The CUDA/Triton test now verifies the device-side `E` update modifies exponent groups and keeps ternary buffers on CUDA.
- The CUDA/Triton test now forces threshold crossings and verifies GPU repack flips packed trits to `+1` and resets accumulators.
- The CUDA/Triton test now asserts normal backward does not create `_hook_grad_T_sign`; it uses retained `grad_y` and `x` views for direct state updates.
- The CUDA/Triton test now covers the fused ternary-step plus `E` score application path.
- Made missing TileLang reference tests skip instead of failing when `tilelang/kernels/dequant_gemm.py` is absent.
## Verification Run
Direct CUDA/Triton smoke:
```text
cuda True triton True y_device cuda:0 packed_device cuda:0 E_device cuda:0
```
Strict one-step train:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD
```
Sign-only gradient kernel test:
```text
PASS test_cuda_triton_tscale_path
```
The same test now also covers the CUDA-side `E` update kernel and CUDA-side packed ternary repack.
Strict one-step train after replacing dense `grad_w`:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after moving `E` update to Triton:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after moving `ternary_step()` repack to Triton:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after removing dense `grad_sign[N,K]` from the normal CUDA path:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
Strict one-step train after fusing sign reduction into ternary-step/repack and applying `E` from per-group scores:
```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```
## Remaining Problem: `grad_w = grad_y.T @ x`
The first version of the Triton path fixed forward and grad-input memory behavior, but backward still materialized a dense temporary `grad_w` in Python:
```python
grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)
```
That has now been replaced by a Triton sign-only reduction. The dense FP32 `grad_w` tensor is gone from the Triton backward path.
Current remaining issue:
- The CUDA update path now allocates a small `int32` score buffer shaped like `E`, not like the full logical weight matrix.
- The score buffer is `out_dim * ceil(in_dim / group_size)` entries. At `T32` group size 12, this is about 1/12 as many elements as the logical weights.
- The fused kernel uses `tl.atomic_add` into group scores because 5-trit packed bytes and scale groups are not naturally aligned.
This fits the 3B-on-8GB memory goal much better than the previous path. The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider changes to group/pack alignment.
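As a rough check of the sizes above, the temporaries for a single layer can be compared directly. The 4096 x 4096 layer shape and the helper name `update_temporaries` are illustrative assumptions. The score-buffer formula is the one stated in the list.

```python
def ceil_div(a, b):
    """Integer ceiling division."""
    return -(-a // b)


def update_temporaries(out_dim, in_dim, group_size):
    """Sizes of the temporaries each backward design needs for one layer."""
    logical = out_dim * in_dim
    score_entries = out_dim * ceil_div(in_dim, group_size)
    return {
        "dense_fp32_grad_w_bytes": logical * 4,      # original Python path
        "dense_int8_grad_sign_bytes": logical,       # Stage 1 sign-only kernel
        "int32_score_entries": score_entries,        # fused path, shaped like E
        "int32_score_bytes": score_entries * 4,
    }


sizes = update_temporaries(out_dim=4096, in_dim=4096, group_size=12)
ratio = sizes["int32_score_entries"] / (4096 * 4096)   # close to 1/12
```

Note that the 1/12 figure counts elements. In bytes, the `int32` score buffer is about one third of the dense `int8` sign tensor and about one twelfth of the dense FP32 `grad_w`.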
The actual required update is:
```text
T_accum[i, j] += sign(sum_m grad_y[m, i] * x[m, j])
E update uses grouped sign statistics from the same reduction.
```
The fused ternary-step/repack kernel described in Stage 2 below performs this update without storing even the dense `int8` sign tensor, by folding the reduction directly into the ternary state update.
## Proposed Fix For Gradient Sign Capture
### Stage 1: Sign-Only Grad Kernel
Implemented. Added a Triton kernel:
```text
input:
grad_y[M, N]
x[M, K]
output:
grad_sign[N, K] int8
```
Each program owns a tile of `(N, K)`, loops across `M`, accumulates a local sum in `float32` (or an `int32`-style integer accumulator), then immediately converts it to a sign:
```text
s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g
```
This removes the dense FP32 `grad_w` allocation. It still performs the same math, but the result is reduced to int8 at the end of each tile instead of returning a full precision matrix to PyTorch.
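The tile ownership described above can be written as a pure-Python reference, useful for checking a kernel on tiny shapes. This is a sketch of the reduction semantics only, not the Triton kernel. The name `grad_sign_reference` and the tile sizes are hypothetical.

```python
def sign(v):
    """Return -1, 0, or +1 as an int."""
    return (v > 0) - (v < 0)


def grad_sign_reference(grad_y, x, tile_n=2, tile_k=2):
    """grad_y is M x N, x is M x K (lists of rows); returns N x K signs."""
    m_dim, n_dim, k_dim = len(grad_y), len(grad_y[0]), len(x[0])
    out = [[0] * k_dim for _ in range(n_dim)]
    for n0 in range(0, n_dim, tile_n):               # one "program" per (N, K) tile
        for k0 in range(0, k_dim, tile_k):
            for n in range(n0, min(n0 + tile_n, n_dim)):
                for k in range(k0, min(k0 + tile_k, k_dim)):
                    acc = 0.0                        # local accumulator, never stored
                    for m in range(m_dim):           # reduce across the batch dim M
                        acc += grad_y[m][n] * x[m][k]
                    out[n][k] = sign(acc)            # only the sign leaves the tile
    return out


grad_y = [[1.0, -2.0], [0.5, 1.0]]                   # M=2, N=2
x = [[1.0, 0.0, -1.0], [2.0, 0.0, 1.0]]              # M=2, K=3
signs = grad_sign_reference(grad_y, x)
```

The result does not depend on the tile sizes, since each output element is reduced independently. Tiling only changes which program owns which elements.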
### Stage 2: GPU T Accumulator Update And Repack
Implemented as a fused Triton kernel. It computes the sign reduction from `grad_y` and `x`, updates `T_accum`, applies threshold flips, resets flipped accumulator entries, rewrites `T_packed` in-place, and emits per-group scale scores:
```text
g = sign(sum_m grad_y[m,n] * x[m,k])
score[n, group(k)] += g * old_T[n,k]
T_accum[n, k] = clamp(T_accum[n, k] + g, -128, 127)
if T_accum[n,k] > threshold: T[n,k] = +1, T_accum[n,k] = 0
if T_accum[n,k] < -threshold: T[n,k] = -1, T_accum[n,k] = 0
repack 5 T values into one uint8
```
This removes Python unpack/repack from the CUDA path.
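The accumulator and flip rules in the pseudocode above can be sketched for a single weight row as follows. This is a plain-Python reference that follows the listed rules literally. The name `ternary_step_reference` is hypothetical, and repacking is left out.

```python
def ternary_step_reference(t_row, accum_row, g_row, threshold):
    """Apply one accumulator update and threshold flip to a row of trits."""
    new_t, new_accum = [], []
    for t, acc, g in zip(t_row, accum_row, g_row):
        acc = max(-128, min(127, acc + g))   # saturate like an int8 accumulator
        if acc > threshold:
            t, acc = 1, 0                    # flip to +1 and reset the accumulator
        elif acc < -threshold:
            t, acc = -1, 0                   # flip to -1 and reset the accumulator
        new_t.append(t)
        new_accum.append(acc)
    return new_t, new_accum


t_row = [0, 1, -1, 0]
accum_row = [3, -3, 1, 0]
g_row = [1, -1, 0, 1]
new_t, new_accum = ternary_step_reference(t_row, accum_row, g_row, threshold=3)
```

Entries whose accumulator stays within the threshold keep their current trit, so most steps leave most of `T_packed` unchanged.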
### Stage 3: Separate E Group Update
Implemented for the CUDA/Triton path as a small score-apply kernel.
For scale exponents, update per group rather than per weight:
```text
group_score[n, g] = sum_{k in group} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)
```
This is now a second Triton kernel over `(N, groups)` that applies the score emitted by the fused ternary-step/repack kernel.
Current implemented split:
1. fused sign-reduction + `T_accum` + repack + group-score kernel
2. small `E` score-apply kernel
No sign tensor is stored in the normal CUDA path.
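The group-score formula and the clamped `E` update can be checked with a small reference for one weight row. This is an illustrative sketch: the name `update_e_reference` is hypothetical, and the gradient signs are passed in directly rather than recomputed from `grad_y` and `x`.

```python
def sign(v):
    """Return -1, 0, or +1 as an int."""
    return (v > 0) - (v < 0)


def update_e_reference(e_row, t_row, g_row, group_size):
    """Update one row of int8 exponents from per-group sign agreement scores."""
    new_e = []
    for gi, e in enumerate(e_row):
        lo, hi = gi * group_size, (gi + 1) * group_size
        # group_score = sum over the group of sign(grad_w) * T
        score = sum(g * t for g, t in zip(g_row[lo:hi], t_row[lo:hi]))
        new_e.append(max(-128, min(127, e - sign(score))))   # clamp to int8 range
    return new_e


t_row = [1, 1, 0,   1, 1, -1,   0, 0, 0]
g_row = [1, 1, 1,  -1, -1, -1,  1, 1, 1]
e_row = [0, 127, 5]
new_e = update_e_reference(e_row, t_row, g_row, group_size=3)
```

The second group shows the clamp: its score is negative, so `E` would rise to 128 and is held at 127. The third group has all-zero trits, so its score is zero and `E` is unchanged.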
### Stage 4: Tune Or Remove Group-Score Atomics
Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes are 5-trit chunks while `E` groups are 12, 6, 24, 48, 64, or 96 weights depending on TScale type.
Options:
- Keep atomics and tune block sizes.
- Change group sizes to align with 5-trit packing where possible.
- Use a group-owned kernel for `E` and a pack-owned kernel for `T`, accepting two reductions but no atomics.
- Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.
This is now a speed optimization, not a memory correctness blocker.
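The misalignment behind the atomics can be made concrete by listing which packed bytes contain trits from two different `E` groups. The sketch assumes packing restarts at `k = 0` of each row and looks at one alignment period (the least common multiple of the group size and 5). The helper name and the aligned group size of 10 are hypothetical.

```python
from math import gcd


def straddling_bytes(group_size, trits_per_byte=5):
    """Return (period, byte indices within one period that span two E groups)."""
    period = group_size * trits_per_byte // gcd(group_size, trits_per_byte)
    straddlers = []
    for b in range(period // trits_per_byte):
        first = b * trits_per_byte                 # first trit index in this byte
        last = first + trits_per_byte - 1          # last trit index in this byte
        if first // group_size != last // group_size:
            straddlers.append(b)
    return period, straddlers


period_12, bytes_12 = straddling_bytes(12)   # group size quoted above for T32
period_6, bytes_6 = straddling_bytes(6)
period_10, bytes_10 = straddling_bytes(10)   # hypothetical size aligned to the packing
```

With a group size of 12, a third of the packed bytes contribute to two group scores, so a byte-owned program cannot write a single group score exclusively. A group size that is a multiple of 5 has no straddling bytes, which is what the alignment option above would exploit.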
### Stage 5: Activation Ternary Mode
Forward currently consumes normal activation tensors. True ternary training should eventually use:
```text
A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E
```
This makes the gradient-sign kernel cheaper because `x[m, k]` is also sign/zero rather than a dense float.
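A scalar sketch of that target arithmetic: ternary activations times ternary weights, accumulated in an integer per scale group, then scaled by `2^E`. The function name and the per-group scaling order are assumptions for illustration. The scale is applied as a float power so that negative exponents survive, which is the behavior the TileLang fix in `tscale.py` restored.

```python
def ternary_dot_scaled(a, w, e_groups, group_size):
    """Dot product of ternary a and w with one int8 exponent per weight group."""
    total = 0.0
    for gi, e in enumerate(e_groups):
        lo, hi = gi * group_size, (gi + 1) * group_size
        acc = 0                                   # integer accumulator, wider than ternary
        for ak, wk in zip(a[lo:hi], w[lo:hi]):
            acc += ak * wk                        # each product is in {-1, 0, +1}
        total += acc * 2.0 ** e                   # float power keeps 2**-k nonzero
    return total


a = [1, 1, 0, 1]
w = [1, 1, -1, -1]
y = ternary_dot_scaled(a, w, e_groups=[-1, 2], group_size=2)
shifted = 1 >> 1                                  # integer shift loses 2**-1 entirely
```

Because every product is a sign or zero, the inner loop needs no multiplier in a real kernel. It reduces to conditional adds and subtracts.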
## Native CUDA Speed Path
### cuBLAS
cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and low/mixed precision types, including tensor-core-backed FP/INT formats, but it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.
Use cuBLAS only for baselines or fallback dequantized GEMM.
### Triton
Triton is the best immediate path.
Reasons:
- Already installed in this environment.
- Integrates with PyTorch autograd quickly.
- Good for custom packed formats and fused update kernels.
- Lets us remove the dense `grad_w` allocation without waiting on TileLang setup.
Near-term target:
```text
Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel
```
### TileLang
TileLang is still worth getting working after the algorithm is stable. Its docs describe `T.gemm` lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.
Use TileLang when:
- Triton semantics are proven.
- Tile sizes and data layout are stable.
- We want a cleaner path toward tensor-core-style tiling and schedule tuning.
TileLang should not be the first blocker because it is not currently installed in this environment.
### Custom CUDA/CUTLASS
This is the final performance path, not the first implementation path.
Use custom CUDA/CUTLASS when:
- The Triton kernels prove the training algorithm.
- Profiling shows Triton is bottlenecked by decode/repack/reduction overhead.
- We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.
This path has the highest ceiling and highest development cost.
## Recommended Roadmap
1. Keep the current Triton forward and grad-input path.
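2. Add `ternary_grad_sign_accum_kernel` to eliminate dense `grad_w` (implemented in this pass as the Triton sign-only reduction).
3. Add GPU-side `E` update (implemented in this pass).
4. Add GPU-side repack of `T_packed` (implemented in this pass).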
5. Benchmark strict ternary training memory with `torch.cuda.max_memory_allocated()`.
6. Tune Triton block sizes.
7. Install and evaluate TileLang against the Triton kernels.
8. Move only proven hot kernels to CUDA/CUTLASS if needed.
## Key Constraint
There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than ternary. The goal is not "no accumulator precision"; the goal is:
- no persistent BF16/FP32/FP8 weights
- no persistent FP optimizer state
- no dense full-precision weight gradients
- packed ternary storage
- int8 scale memory
- streaming reductions into ternary/int8 state
## Speed Pass: Scale Update Scheduling
The current bottleneck is the exponent scale update, not the packed forward path. A 4-step strict smoke with scheduled scale updates ran slower than the no-scale path, while disabling `E` updates reached about `4.90 it/s` on the same small batch after compile.
Changes made:
- Added `--scale_update_interval`.
  - Default: `4`.
  - `1`: update int8 `E` every step.
  - `0`: disable `E` updates and always use the fast direct ternary repack path.
- Changed CUDA `update_E()` back to the direct group-owned Triton kernel for the normal retained `grad_y/x` path.
- `ternary_step()` now uses the fast direct repack kernel after `update_E()` instead of the fused score/repack kernel on scheduled scale-update steps.
- This removes the temporary int32 score buffer and `tl.atomic_add` from the default scheduled scale-update path.
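The scheduling rule can be sketched as a small predicate. The notes above do not say which steps within an interval trigger the update. This sketch assumes zero-based steps with updates on multiples of the interval, and `should_update_scale` is a hypothetical name.

```python
def should_update_scale(step, scale_update_interval):
    """Decide whether this training step also runs the int8 E update."""
    if scale_update_interval <= 0:
        return False                        # 0 disables E updates entirely
    return step % scale_update_interval == 0


def plan(num_steps, scale_update_interval):
    """Label each step with the update path it would take."""
    return [
        "update_E+repack" if should_update_scale(s, scale_update_interval) else "repack"
        for s in range(num_steps)
    ]


default_plan = plan(8, 4)    # E update on steps 0 and 4 only
every_step = plan(3, 1)
disabled = plan(3, 0)
```

With the default interval, three of every four steps skip the `E` reduction, consistent with the interval-4 rate quoted below sitting close to the rate measured with `E` updates disabled.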
Observed before the direct `E` change:
```text
scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke
```
Observed after the direct group-owned `E` change:
```text
scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke
```
Interpretation:
- The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
- The direct group-owned `E` kernel is faster even though it performs a separate reduction, because each program owns an `E` group and writes without atomics.
- The default `scale_update_interval=4` now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.
Next speed target:
- Benchmark the direct group-owned `E` kernel versus the fused score kernel on real layer sizes.
- If direct `E` is faster, remove or demote the fused score kernel to an experimental path.
- Tune `BLOCK_N`, `BLOCK_K`, and `BLOCK_M` in `_triton_update_e_direct_kernel` and `_triton_ternary_step_direct_kernel`.