# True Ternary Refactor 2

## What Changed In This Pass

### `tscale.py`

- Added Triton detection and a CUDA/Triton execution path for `TernaryScaleTensor`.
- Added packed ternary forward and grad-input kernels that read:
  - `T_packed` as 5 trits per byte
  - `E` as int8 log-scale groups
  - `x` / `grad_y` on CUDA
- `TernaryScaleTensor.forward()` now prefers the Triton path when its inputs are CUDA tensors.
- Fixed the old TileLang negative-exponent issue. The previous integer-shift path turned `2^-k` into zero. The TileLang fallback now reconstructs `2^E` before multiplying by the sign.
- Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size `M`, not `N`.
- Fixed the ternary update ordering by calling `update_E()` before `ternary_step()`, so the gradient sign is not cleared before the exponent update reads it.
- Kept `T_packed` on its original device after repacking. Previously, repacking moved the buffer back to the CPU.
- Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes

  ```python
  grad_w = grad_y.T @ x
  ```

  as a dense FP32 tensor. It now computes only

  ```text
  sign(sum_m grad_y[m, n] * x[m, k])
  ```

  directly into an `int8` CUDA tensor.
- Added a Triton `E` update kernel. CUDA exponent updates now read packed trits from `T_packed` and int8 gradient signs directly, then update the int8 `E` in-place without unpacking the full `T` through PyTorch.
- Added a Triton ternary-step/repack kernel. CUDA `ternary_step()` now updates `T_accum`, applies flip thresholds, resets flipped accumulators, and rewrites `T_packed` in-place without calling the Python `pack_ternary()`.
- Removed the dense `int8 grad_sign[N,K]` allocation from the normal CUDA autograd path. Backward now retains compact `grad_y` and `x` views, and the CUDA `E` update and ternary-step kernels recompute the sign reduction themselves.
- Fused the expensive sign reduction into the ternary-step/repack pass. The fused CUDA path now updates `T_accum`, rewrites `T_packed`, and atomically accumulates per-group `E` scores in a single pass over `grad_y` and `x`.
- Replaced the separate direct `E` reduction with a small score-apply kernel over the `E` groups. The only remaining temporary is an `int32` buffer with one entry per scale group, not per logical weight.

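The two encodings these kernels read can be illustrated with a small pure-Python sketch. It assumes base-3 packing in which a trit `t` is stored as the digit `t + 1` and the first trit of each chunk is the least-significant digit. The real digit order in `tscale.py` may differ, and `pack_trits`, `unpack_trits`, and `dequant_row` are hypothetical names. The dequantization step also shows why the TileLang fix matters: an integer shift cannot represent `2^-k` (it truncates to zero), whereas `math.ldexp(1.0, e)` builds the exact power of two even for negative `e`.

```python
import math

TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, so five trits fit in one uint8


def pack_trits(trits):
    """Pack values in {-1, 0, +1} into bytes, 5 trits per byte (base 3).

    Assumption: trit t is stored as digit t + 1, and the first trit of each
    chunk is the least-significant base-3 digit.
    """
    packed = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        chunk = trits[start:start + TRITS_PER_BYTE]
        value = 0
        for i, t in enumerate(chunk):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            value += (t + 1) * 3 ** i
        packed.append(value)  # always <= 242
    return bytes(packed)


def unpack_trits(packed, n):
    """Inverse of pack_trits; n is the logical number of trits."""
    out = []
    for byte in packed:
        for _ in range(TRITS_PER_BYTE):
            out.append(byte % 3 - 1)
            byte //= 3
    # Missing digits in a short last chunk decode as -1; [:n] drops them.
    return out[:n]


def dequant_row(trits, exps, group_size):
    """w[k] = T[k] * 2**E[k // group_size], built as a float so E < 0 works."""
    return [t * math.ldexp(1.0, exps[k // group_size]) for k, t in enumerate(trits)]


row = [1, -1, 0, 0, 1, -1, 1]
packed = pack_trits(row)
assert len(packed) == math.ceil(len(row) / TRITS_PER_BYTE)
assert unpack_trits(packed, len(row)) == row
print(dequant_row(row, exps=[-2, 3], group_size=6))
# → [0.25, -0.25, 0.0, 0.0, 0.25, -0.25, 8.0]
```

Five trits fit because 243 distinct codes are at most 256. That gives 1.6 bits per weight and leaves codes 243 to 255 unused.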
### `trigram.py`

- Replaced the MoE routers with `TernaryScaleTensor`:
  - `moe.router`
  - `moe.router_h`
- Added constructor gates for strict text-only ternary training:
  - `enable_image`
  - `enable_vq`
  - `enable_graph`
  - `enable_memory_modules`
  - `enable_moe`
- In strict mode, the model can be built without the VQ, graph, image, LSTM, MemGram, or ConvVQ modules. This removes their hidden trainable float state from the core text model.
- Added a no-VQ forward path in which text relational states go directly into the MoE and ByteHead.

### `train.py`

- Changed the default optimizer to `signsgd`.
- Added `--compute_dtype {bf16,fp16,none}`.
- Added `--strict_ternary`, which:
  - forces SignSGD
  - forces `compute_dtype=none`
  - disables the VQ, graph, image, and memory modules
  - freezes any remaining trainable float parameters
- Added `--freeze_float_params` for non-strict runs.
- Added model-state audit logging before training.
- Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
- Fixed gradient clipping and optimizer construction to use only trainable parameters.

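Here is a minimal sketch of how the `--strict_ternary` overrides could be applied after argument parsing. Only the `--compute_dtype`, `--strict_ternary`, and `--freeze_float_params` flags and the `enable_*` gate names come from this pass. The `--optimizer` flag name, its `adamw` choice, the `bf16` default, and the `resolve_strict_ternary` helper are hypothetical.

```python
import argparse


def build_parser():
    p = argparse.ArgumentParser()
    # Hypothetical flag name and choices; this pass only states the default is signsgd.
    p.add_argument("--optimizer", default="signsgd", choices=["signsgd", "adamw"])
    # Hypothetical default; the choices are the ones listed above.
    p.add_argument("--compute_dtype", default="bf16", choices=["bf16", "fp16", "none"])
    p.add_argument("--strict_ternary", action="store_true")
    p.add_argument("--freeze_float_params", action="store_true")
    return p


def resolve_strict_ternary(args):
    """Apply the strict-mode overrides listed above (hypothetical helper)."""
    model_gates = dict(enable_image=True, enable_vq=True, enable_graph=True,
                       enable_memory_modules=True, enable_moe=True)
    if args.strict_ternary:
        args.optimizer = "signsgd"
        args.compute_dtype = "none"
        args.freeze_float_params = True
        # MoE stays on: its routers are TernaryScaleTensor layers.
        for gate in ("enable_image", "enable_vq", "enable_graph", "enable_memory_modules"):
            model_gates[gate] = False
    return args, model_gates


args, gates = resolve_strict_ternary(
    build_parser().parse_args(["--strict_ternary", "--compute_dtype", "bf16"]))
print(args.optimizer, args.compute_dtype, gates["enable_vq"], gates["enable_moe"])
# → signsgd none False True
```

Keeping all overrides in one post-parse function makes it easy to log the effective configuration next to the model-state audit.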
### `ternary_audit.py`

New helper module that reports:

- logical ternary weights
- packed ternary bytes
- int8 exponent bytes
- int8 accumulator bytes
- trainable floating-point parameters
- frozen floating-point parameters
- floating-point buffers

The strict text-only audit currently reports zero trainable float parameters and zero float buffers.

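The byte counts such an audit reports for one `TernaryScaleTensor` of shape `[N, K]` can be estimated with a back-of-the-envelope sketch. It assumes 5 trits per packed byte with each row padded independently, one int8 exponent per group of `group_size` consecutive inputs, and one int8 accumulator per logical weight. The helper name and the per-row padding are assumptions, not the module's code.

```python
from dataclasses import dataclass
from math import ceil


@dataclass
class TernaryFootprint:
    logical_weights: int
    packed_bytes: int
    exponent_bytes: int
    accumulator_bytes: int

    @property
    def total_bytes(self) -> int:
        return self.packed_bytes + self.exponent_bytes + self.accumulator_bytes


def ternary_footprint(n_out: int, n_in: int, group_size: int) -> TernaryFootprint:
    """Hypothetical helper: persistent state for one [n_out, n_in] ternary layer."""
    return TernaryFootprint(
        logical_weights=n_out * n_in,
        packed_bytes=n_out * ceil(n_in / 5),             # 5 trits per uint8, rows padded
        exponent_bytes=n_out * ceil(n_in / group_size),  # one int8 E per group
        accumulator_bytes=n_out * n_in,                  # one int8 T_accum per weight
    )


fp = ternary_footprint(4096, 4096, group_size=12)
print(fp.packed_bytes, fp.exponent_bytes, fp.accumulator_bytes)
# → 3358720 1400832 16777216
print(f"{fp.total_bytes / fp.logical_weights:.3f} bytes per logical weight")
# → 1.284 bytes per logical weight
```

Under these assumptions, the int8 `T_accum` buffer dominates persistent state, not the packed weights.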
### `testing/test_tscale.py`

- Added a CUDA/Triton path test for `TernaryScaleTensor`.
- The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts that the captured gradient state is `int8` on CUDA.
- The CUDA/Triton test now verifies that the device-side `E` update modifies exponent groups and keeps the ternary buffers on CUDA.
- The CUDA/Triton test now forces threshold crossings and verifies that the GPU repack flips packed trits to `+1` and resets the accumulators.
- The CUDA/Triton test now asserts that a normal backward pass does not create `_hook_grad_T_sign`. Instead, it uses the retained `grad_y` and `x` views for direct state updates.
- The CUDA/Triton test now covers the fused ternary-step plus `E` score application path.
- Made the TileLang reference tests skip instead of fail when `tilelang/kernels/dequant_gemm.py` is absent.

## Verification Run

Direct CUDA/Triton smoke test:

```text
cuda True triton True y_device cuda:0 packed_device cuda:0 E_device cuda:0
```

Strict one-step train:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD
```

Sign-only gradient kernel test:

```text
PASS test_cuda_triton_tscale_path
```

The same test now also covers the CUDA-side `E` update kernel and the CUDA-side packed ternary repack.

Strict one-step train after replacing the dense `grad_w`:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after moving the `E` update to Triton:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after moving the `ternary_step()` repack to Triton:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after removing the dense `grad_sign[N,K]` from the normal CUDA path:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after fusing the sign reduction into ternary-step/repack and applying `E` from per-group scores:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

## Remaining Problem: `grad_w = grad_y.T @ x`

The first version of the Triton path fixed the forward and grad-input memory behavior, but backward still materialized a dense temporary `grad_w` in Python:

```python
grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)
```

That has now been replaced by a Triton sign-only reduction. The dense FP32 `grad_w` tensor is gone from the Triton backward path.

Current remaining issues:

- The CUDA update path still allocates a small `int32` score buffer shaped like `E`, not like the full logical weight matrix.
- The score buffer has `out_dim * ceil(in_dim / group_size)` entries. For `T32` (group size 12), that is about 1/12 as many elements as there are logical weights.
- The fused kernel uses `tl.atomic_add` to accumulate group scores because 5-trit packed bytes and scale groups are not naturally aligned.

This path fits the 3B-on-8GB memory goal much better than the previous one. The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider changing the group/pack alignment.

The actual required update is:

```text
T_accum[n, k] += sign(sum_m grad_y[m, n] * x[m, k])
E update uses grouped sign statistics from the same reduction.
```

The fused path therefore avoids storing even the dense `int8` sign tensor, because it computes the reduction directly inside the ternary state update.

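The memory progression described in this section can be quantified for a single layer. This sketch assumes a hypothetical 4096 x 4096 layer with group size 12 (the `T32` case). It counts only the per-step temporaries, not persistent state, and the helper name is illustrative.

```python
from math import ceil

MiB = 1024 * 1024


def update_temporaries(n_out: int, n_in: int, group_size: int) -> dict:
    """Bytes of the per-step temporary each backward variant allocates."""
    return {
        "dense fp32 grad_w": n_out * n_in * 4,
        "dense int8 grad_sign": n_out * n_in * 1,
        "int32 group scores": n_out * ceil(n_in / group_size) * 4,
    }


for name, nbytes in update_temporaries(4096, 4096, group_size=12).items():
    print(f"{name:>22}: {nbytes / MiB:6.2f} MiB")
```

For that shape, the three variants allocate 64 MiB, 16 MiB, and about 5.3 MiB. The score buffer has 1/12 as many elements as the sign tensor, but each entry is 4 bytes wide. In bytes, it is therefore about a third of the dense int8 tensor it replaced.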
## Proposed Fix For Gradient Sign Capture

### Stage 1: Sign-Only Grad Kernel

Implemented. Added a Triton kernel:

```text
input:
  grad_y[M, N]
  x[M, K]

output:
  grad_sign[N, K] int8
```

Each program owns an `(N, K)` tile, loops across `M`, and accumulates a local `float32` (or `int32`) sum. It then immediately converts the sum to a sign:

```text
s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g
```

This removes the dense FP32 `grad_w` allocation. The kernel still performs the same math, but it reduces each tile to int8 as soon as the tile's sums are complete instead of returning a full-precision matrix to PyTorch.

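A pure-Python reference of what each program computes can serve as a test oracle. It mirrors the tile ownership described above: a tile of `(n, k)` outputs with a loop over `m`. The tile sizes and the function name are illustrative, not the Triton kernel's actual `BLOCK_*` parameters.

```python
def sign(v):
    return (v > 0) - (v < 0)


def grad_sign_tiled(grad_y, x, block_n=2, block_k=2):
    """grad_sign[n][k] = sign(sum_m grad_y[m][n] * x[m][k]), computed tile by tile.

    grad_y is M x N and x is M x K (lists of rows). Only the final sign of each
    output is stored, never the full-precision sum matrix.
    """
    m_dim, n_dim, k_dim = len(grad_y), len(grad_y[0]), len(x[0])
    out = [[0] * k_dim for _ in range(n_dim)]
    for n0 in range(0, n_dim, block_n):          # one "program" per (n0, k0) tile
        for k0 in range(0, k_dim, block_k):
            n_idx = range(n0, min(n0 + block_n, n_dim))
            k_idx = range(k0, min(k0 + block_k, k_dim))
            acc = {(n, k): 0.0 for n in n_idx for k in k_idx}  # tile-local accumulator
            for m in range(m_dim):
                for n in n_idx:
                    for k in k_idx:
                        acc[(n, k)] += grad_y[m][n] * x[m][k]
            for (n, k), s in acc.items():
                out[n][k] = sign(s)              # reduce to {-1, 0, +1}
    return out


grad_y = [[0.5, -1.0, 2.0], [1.5, 0.25, -2.0]]  # M=2, N=3
x = [[1.0, -2.0], [1.0, 2.0]]                   # M=2, K=2
print(grad_sign_tiled(grad_y, x))
# → [[1, 1], [-1, 1], [0, -1]]
```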
### Stage 2: GPU T Accumulator Update And Repack

Implemented as a fused Triton kernel. The kernel performs the following steps:

- computes the sign reduction from `grad_y` and `x`
- updates `T_accum`
- applies threshold flips and resets the flipped accumulator entries
- rewrites `T_packed` in-place
- emits per-group scale scores

```text
g = sign(sum_m grad_y[m, n] * x[m, k])
score[n, group(k)] += g * old_T[n, k]
T_accum[n, k] = clamp(T_accum[n, k] + g, -128, 127)
if T_accum[n, k] > threshold:  T[n, k] = +1, T_accum[n, k] = 0
if T_accum[n, k] < -threshold: T[n, k] = -1, T_accum[n, k] = 0
repack 5 T values into one uint8
```

This removes the Python unpack/repack from the CUDA path.

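A pure-Python reference for one output row of the fused update is useful for checking the kernel on small shapes. It assumes `group(k) = k // group_size` within a row and takes the gradient signs `g` as given.

There is one deliberate difference from the pseudocode above. Here the accumulator moves against the gradient sign (`T_accum -= g`), so a persistently positive gradient pushes the trit toward `-1`. That matches the descent direction of the `E` rule in Stage 3, which subtracts the score sign. The pseudocode's `+= g` is equivalent only if the captured gradient is already negated. The function name is hypothetical.

```python
def clamp(v, lo=-128, hi=127):
    return max(lo, min(hi, v))


def ternary_step_row(T, T_accum, g_row, group_size, threshold):
    """One output row of the fused step: score, accumulate, flip, reset.

    T and T_accum are this row's trits and int8-range accumulators, updated
    in place. g_row holds this row's gradient signs in {-1, 0, +1}.
    Returns the per-group scores sum(g * old_T) used by the E update.
    """
    n_groups = (len(T) + group_size - 1) // group_size
    score = [0] * n_groups
    for k, g in enumerate(g_row):
        score[k // group_size] += g * T[k]       # uses the trit before any flip
        # Assumption: descent direction, so the accumulator moves against g.
        T_accum[k] = clamp(T_accum[k] - g)
        if T_accum[k] > threshold:
            T[k], T_accum[k] = 1, 0
        elif T_accum[k] < -threshold:
            T[k], T_accum[k] = -1, 0
    return score


T = [0, 1, -1, 0, 0, 1]
acc = [2, 0, -2, 0, 1, 0]
score = ternary_step_row(T, acc, g_row=[-1, 1, 1, 0, 0, -1], group_size=3, threshold=2)
print(T, acc, score)
# → [1, 1, -1, 0, 0, 1] [0, -1, 0, 0, 1, 1] [0, -1]
```

After this step, repacking the updated row is an ordinary 5-trits-per-byte pack of `T`.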
### Stage 3: Separate E Group Update

Implemented for the CUDA/Triton path as a small score-apply kernel.

Scale exponents are updated per group rather than per weight:

```text
group_score[n, g] = sum_{k in group g} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)
```

This is now a second Triton kernel over `(N, groups)` that applies the scores emitted by the fused ternary-step/repack kernel.

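A short derivation shows why subtracting the score sign is a descent step. It assumes the dequantization `w = T * 2^E` with one exponent per group. Because `ln 2 * 2^E` is always positive, that factor drops out of the sign, so `E -= sign(group_score)` is a sign-SGD step of size 1 in log2 space. Each update halves, keeps, or doubles the group's scale. The final approximation below, which replaces each `grad_w` by its sign, is the substitution `group_score` makes.

```latex
\begin{aligned}
w_{n,k} &= T_{n,k}\,2^{E_{n,g}}, \qquad k \in g \\
\frac{\partial L}{\partial E_{n,g}}
  &= \sum_{k \in g} \frac{\partial L}{\partial w_{n,k}}\,\frac{\partial w_{n,k}}{\partial E_{n,g}}
   = \ln 2 \cdot 2^{E_{n,g}} \sum_{k \in g} \mathrm{grad\_w}_{n,k}\,T_{n,k} \\
\operatorname{sign}\!\Big(\frac{\partial L}{\partial E_{n,g}}\Big)
  &= \operatorname{sign}\!\Big(\sum_{k \in g} \mathrm{grad\_w}_{n,k}\,T_{n,k}\Big)
   \approx \operatorname{sign}\big(\mathrm{group\_score}_{n,g}\big)
\end{aligned}
```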
Currently implemented split:

1. fused sign-reduction + `T_accum` + repack + group-score kernel
2. small `E` score-apply kernel

No sign tensor is stored in the normal CUDA path.

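Here is a pure-Python sketch of the score-apply rule, written in the group-owned style that the Speed Pass section later adopts for `update_E()`. One owner reduces and writes each `(n, g)` group, which is why no atomics are needed. The function name is hypothetical. `grad_sign` is materialized here only for clarity; the CUDA path recomputes it from `grad_y` and `x`.

```python
def sign(v):
    return (v > 0) - (v < 0)


def update_E_group_owned(E, T, grad_sign, group_size):
    """Apply E[n][g] = clamp(E[n][g] - sign(group_score[n][g]), -128, 127).

    Each (n, g) iteration plays the role of one program that owns a single
    E group. It reduces its own score and writes its own exponent, so no
    other iteration touches the same entry and no atomics are needed.
    E is updated in place. T and grad_sign are N x K lists of {-1, 0, +1}.
    """
    for n, row in enumerate(E):
        for g in range(len(row)):
            ks = range(g * group_size, min((g + 1) * group_size, len(T[n])))
            score = sum(grad_sign[n][k] * T[n][k] for k in ks)
            row[g] = max(-128, min(127, row[g] - sign(score)))
    return E


E = [[0, 127, -128]]
T = [[1, 1, -1, 1, 0, 1]]
G = [[1, 1, 1, -1, -1, -1]]
print(update_E_group_owned(E, T, G, group_size=2))
# → [[-1, 127, -127]]  (the middle group is clamped at 127)
```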
### Stage 4: Tune Or Remove Group-Score Atomics

Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes hold 5-trit chunks, while `E` groups span 6, 12, 24, 48, 64, or 96 weights depending on the TScale type.

Options:

- Keep atomics and tune block sizes.
- Change group sizes to align with 5-trit packing where possible.
- Use a group-owned kernel for `E` and a pack-owned kernel for `T`, accepting two reductions but no atomics.
- Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.

This is now a speed optimization, not a memory-correctness blocker.

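One way to see the misalignment is to count how many packed bytes per row straddle an `E` group boundary for each group size listed above. This sketch assumes each row starts on a fresh byte. The 960-wide row and the helper name are hypothetical. A straddling byte contributes to two groups, so a pack-owned program cannot write a single group score without coordination. Only a group size that is a multiple of 5 removes straddling entirely.

```python
from math import gcd


def straddling_bytes(row_len: int, group_size: int, trits_per_byte: int = 5) -> int:
    """Count packed bytes in one row whose trits fall in more than one E group."""
    count = 0
    for start in range(0, row_len, trits_per_byte):
        end = min(start + trits_per_byte, row_len) - 1  # last trit index in this byte
        if start // group_size != end // group_size:
            count += 1
    return count


row_len = 960  # hypothetical row length: a multiple of 5 and of every group size below
for group_size in (6, 12, 24, 48, 64, 96):
    # Byte and group boundaries realign every lcm(5, group_size) trits.
    period = group_size * 5 // gcd(group_size, 5)
    print(group_size, straddling_bytes(row_len, group_size), period)
```

None of the listed group sizes is divisible by 5, so four out of five interior group boundaries land inside a packed byte. For example, 128 of the 159 boundaries in a 960-wide row fall inside a byte at group size 6.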
### Stage 5: Activation Ternary Mode

Forward currently consumes normal activation tensors. True ternary training should eventually use:

```text
A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E
```

This makes the gradient-sign kernel cheaper, because `x[m, k]` is then a sign/zero value rather than a dense float.

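Here is a sketch of one output of such a layer. It assumes a simple magnitude threshold for ternarizing activations. The threshold rule, the function names, and the omission of any activation scale are assumptions; the source only fixes the value sets. The product of two trits is itself a trit, so the dot product reduces to integer adds and subtracts into a wide integer accumulator. The int8 exponent is then applied once per group at the end.

```python
import math


def ternarize(values, threshold):
    """Map floats to {-1, 0, +1}: zero inside [-threshold, threshold]."""
    return [0 if abs(v) <= threshold else (1 if v > 0 else -1) for v in values]


def ternary_dot(a_trits, w_trits, exps, group_size):
    """sum_k A[k] * W[k] * 2**E[k // group_size] with integer partial sums.

    Each group's partial sum only ever adds or subtracts 1, so it is exact in
    an integer accumulator. The power-of-two scale is applied once per group.
    """
    total = 0.0
    for g, e in enumerate(exps):
        lo, hi = g * group_size, min((g + 1) * group_size, len(w_trits))
        acc = 0  # stands in for the int32 accumulator
        for a, w in zip(a_trits[lo:hi], w_trits[lo:hi]):
            if a and w:
                acc += 1 if a == w else -1
        total += math.ldexp(acc, e)  # acc * 2**e
    return total


a = ternarize([0.9, -0.05, -1.2, 0.3, 0.7, -0.8], threshold=0.1)
w = [1, 1, 1, -1, 1, -1]
print(a, ternary_dot(a, w, exps=[-1, 2], group_size=3))
# → [1, 0, -1, 1, 1, -1] 4.0
```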
## Native CUDA Speed Path

### cuBLAS

cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and for low- and mixed-precision types, including tensor-core-backed FP and INT formats. However, it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.

Use cuBLAS only for baselines or for a fallback dequantized GEMM.

### Triton

Triton is the best immediate path.

Reasons:

- Already installed in this environment.
- Integrates quickly with PyTorch autograd.
- Good for custom packed formats and fused update kernels.
- Lets us remove the dense `grad_w` allocation without waiting on TileLang setup.

Near-term target:

```text
Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel
```

### TileLang

TileLang is still worth getting working once the algorithm is stable. Its docs describe `T.gemm` lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.

Use TileLang when:

- Triton semantics are proven.
- Tile sizes and data layout are stable.
- We want a cleaner path toward tensor-core-style tiling and schedule tuning.

TileLang should not be the first blocker, because it is not currently installed in this environment.

### Custom CUDA/CUTLASS

This is the final performance path, not the first implementation path.

Use custom CUDA/CUTLASS when:

- The Triton kernels prove the training algorithm.
- Profiling shows that Triton is bottlenecked by decode/repack/reduction overhead.
- We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.

This path has the highest ceiling and the highest development cost.

## Recommended Roadmap

1. Keep the current Triton forward and grad-input path.
2. Add `ternary_grad_sign_accum_kernel` to eliminate the dense `grad_w` (implemented; see Stages 1 and 2).
3. Add a GPU-side `E` update (implemented; see Stage 3).
4. Add a GPU-side repack of `T_packed` (implemented; see Stage 2).
5. Benchmark strict ternary training memory with `torch.cuda.max_memory_allocated()`.
6. Tune Triton block sizes.
7. Install and evaluate TileLang against the Triton kernels.
8. Move only proven hot kernels to CUDA/CUTLASS if needed.

## Key Constraint

There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than a trit. The goal is not "no accumulator precision"; the goal is:

- no persistent BF16/FP32/FP8 weights
- no persistent FP optimizer state
- no dense full-precision weight gradients
- packed ternary storage
- int8 scale memory
- streaming reductions into ternary/int8 state

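The width requirement can be made concrete. A dot product of `K` trit products lies in `[-K, K]`, so the accumulator needs enough bits to hold that range. This sketch computes the minimum signed width for a few reduction lengths, chosen as arbitrary examples. It shows that int8 already overflows at `K = 128`, while int32 has ample headroom for realistic layer widths.

```python
def min_signed_bits(max_abs: int) -> int:
    """Smallest two's-complement width that holds every value in [-max_abs, max_abs]."""
    bits = 1
    while (1 << (bits - 1)) - 1 < max_abs:  # largest positive value is 2**(bits-1) - 1
        bits += 1
    return bits


for k in (127, 128, 4096, 65536):
    print(k, min_signed_bits(k))
# → 127 8 / 128 9 / 4096 14 / 65536 18 (one pair per line)
```

The `M`-length gradient reduction has real-valued terms rather than trits, which is why that sum stays in `float32` until its sign is taken.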
## Speed Pass: Scale Update Scheduling

The current bottleneck is the exponent scale update, not the packed forward path. A 4-step strict smoke test with scheduled scale updates ran slower than the no-scale path. In contrast, disabling `E` updates reached about `4.90 it/s` on the same small batch after compilation.

Changes made:

- Added `--scale_update_interval`.
  - Default: `4`.
  - `1`: update int8 `E` every step.
  - `0`: disable `E` updates and always use the fast direct ternary repack path.
- Changed CUDA `update_E()` back to the direct group-owned Triton kernel for the normal retained `grad_y`/`x` path.
- On scheduled scale-update steps, `ternary_step()` now uses the fast direct repack kernel after `update_E()` instead of the fused score/repack kernel.
- This removes the temporary int32 score buffer and `tl.atomic_add` from the default scheduled scale-update path.

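Here is a sketch of the step-selection rule implied by these settings, including the `update_E()`-before-`ternary_step()` ordering fix. The helper names are hypothetical. The choice of which step indices trigger an update (steps 0, 4, 8, ... for an interval of 4) is an assumption, as is treating negative intervals like `0`.

```python
def should_update_scale(step: int, interval: int) -> bool:
    """True when int8 E should be updated on this optimizer step.

    interval == 0 disables E updates; interval == 1 updates every step.
    Assumption: updates fire on steps 0, interval, 2 * interval, ...
    """
    if interval <= 0:
        return False
    return step % interval == 0


def step_plan(step: int, interval: int) -> list:
    """Order of state updates for one step: update_E first, then ternary_step."""
    if should_update_scale(step, interval):
        return ["update_E (direct, group-owned)", "ternary_step (direct repack)"]
    return ["ternary_step (direct repack)"]


print([s for s in range(10) if should_update_scale(s, 4)])
# → [0, 4, 8]
```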
Observed before the direct `E` change:

```text
scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke
```

Observed after the direct group-owned `E` change:

```text
scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke
```

Interpretation:

- The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
- The direct group-owned `E` kernel is faster even though it performs a separate reduction. Each program owns an `E` group and writes it without atomics.
- The default `scale_update_interval=4` now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.

Next speed targets:

- Benchmark the direct group-owned `E` kernel against the fused score kernel on real layer sizes.
- If direct `E` is faster, remove the fused score kernel or demote it to an experimental path.
- Tune `BLOCK_N`, `BLOCK_K`, and `BLOCK_M` in `_triton_update_e_direct_kernel` and `_triton_ternary_step_direct_kernel`.
|