# True Ternary Refactor 2

## What Changed In This Pass

### `tscale.py`

- Added Triton detection and a CUDA/Triton execution path for `TernaryScaleTensor`.
- Added packed ternary forward and grad-input kernels that read:
  - `T_packed` as 5 trits per byte
  - `E` as int8 log-scale groups
  - `x` / `grad_y` on CUDA
- `TernaryScaleTensor.forward()` now prefers the Triton path when input tensors are CUDA tensors.
- Fixed the old TileLang negative exponent issue. The previous integer shift path made `2^-k` become zero. The TileLang fallback now reconstructs `2^E` before multiplying by sign.
- Fixed the TileLang kernel cache key so CUDA forward kernels are compiled for the actual flattened batch size `M`, not `N`.
- Fixed ternary update ordering by calling `update_E()` before `ternary_step()`, so the gradient sign is not deleted before the exponent update sees it.
- Kept `T_packed` on the original device after repacking. Without this, repacking moved the buffer back to CPU.
- Added a Triton sign-only weight-gradient reduction kernel. Backward no longer materializes:

  ```python
  grad_w = grad_y.T @ x
  ```

  as a dense FP32 tensor. It now computes only:

  ```text
  sign(sum_m grad_y[m,n] * x[m,k])
  ```

  directly into an `int8` CUDA tensor.
- Added a Triton `E` update kernel. CUDA exponent updates now read packed trits from `T_packed` and int8 gradient signs directly, then update int8 `E` in-place without unpacking full `T` through PyTorch.
- Added a Triton ternary-step/repack kernel. CUDA `ternary_step()` now updates `T_accum`, applies flip thresholds, resets flipped accumulators, and rewrites `T_packed` in-place without calling Python `pack_ternary()`.
- Removed the dense `int8 grad_sign[N,K]` allocation from the normal CUDA autograd path. Backward now retains compact `grad_y` and `x` views, and the CUDA `E` update / ternary-step kernels recompute the sign reduction directly inside the state-update kernels.
- Fused the expensive sign reduction into the ternary-step/repack pass.
  The fused CUDA path now updates `T_accum`, rewrites `T_packed`, and atomically accumulates per-group `E` scores in one pass over `grad_y` and `x`.
- Replaced the separate direct `E` reduction with a tiny score-apply kernel over `E` groups. The remaining temporary is `int32` per scale group, not per logical weight.

### `trigram.py`

- Replaced MoE routers with `TernaryScaleTensor`:
  - `moe.router`
  - `moe.router_h`
- Added constructor gates for strict text-only ternary training:
  - `enable_image`
  - `enable_vq`
  - `enable_graph`
  - `enable_memory_modules`
  - `enable_moe`
- In strict mode, the model can be built without VQ, graph, image, LSTM, MemGram, or ConvVQ modules, which removes the hidden trainable float state from the core text model.
- Added a no-VQ forward path where text relational states go directly into MoE and ByteHead.

### `train.py`

- Default optimizer changed to `signsgd`.
- Added `--compute_dtype {bf16,fp16,none}`.
- Added `--strict_ternary`.
  - Forces SignSGD.
  - Forces `compute_dtype=none`.
  - Disables VQ, graph, image, and memory modules.
  - Freezes any remaining trainable float parameters.
- Added `--freeze_float_params` for non-strict runs.
- Added model state audit logging before training.
- Fixed the main training loop indentation so optimizer steps run inside the data batch loop.
- Fixed gradient clipping and optimizer construction to use only trainable parameters.

### `ternary_audit.py`

New helper module for reporting:

- logical ternary weights
- packed ternary bytes
- int8 exponent bytes
- int8 accumulator bytes
- trainable floating-point parameters
- frozen floating-point parameters
- floating-point buffers

Strict text-only audit currently reports zero trainable float params and zero float buffers.

### `testing/test_tscale.py`

- Added a CUDA/Triton path test for `TernaryScaleTensor`.
- The CUDA/Triton test now compares the Triton sign-only gradient against a PyTorch reference sign and asserts the captured gradient state is `int8` on CUDA.
- The CUDA/Triton test now verifies the device-side `E` update modifies exponent groups and keeps ternary buffers on CUDA.
- The CUDA/Triton test now forces threshold crossings and verifies GPU repack flips packed trits to `+1` and resets accumulators.
- The CUDA/Triton test now asserts normal backward does not create `_hook_grad_T_sign`; it uses retained `grad_y` and `x` views for direct state updates.
- The CUDA/Triton test now covers the fused ternary-step plus `E` score application path.
- Made missing TileLang reference tests skip instead of failing when `tilelang/kernels/dequant_gemm.py` is absent.

## Verification Run

Direct CUDA/Triton smoke:

```text
cuda True
triton True
y_device cuda:0
packed_device cuda:0
E_device cuda:0
```

Strict one-step train:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
Optimizer: SignSGD
```

Sign-only gradient kernel test:

```text
PASS test_cuda_triton_tscale_path
```

The same test now also covers the CUDA-side `E` update kernel and CUDA-side packed ternary repack.

Strict one-step train after replacing dense `grad_w`:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after moving `E` update to Triton:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after moving `ternary_step()` repack to Triton:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after removing dense `grad_sign[N,K]` from the normal CUDA path:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

Strict one-step train after fusing sign reduction into ternary-step/repack and applying `E` from per-group scores:

```text
Device: cuda
trainable float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
train=6.7644 val=10.3655
```

## Remaining Problem: `grad_w = grad_y.T @ x`

The first version of the Triton path fixed forward and grad-input memory behavior, but backward still materialized a dense temporary `grad_w` in Python:

```python
grad_w = grad_2d.float().t() @ x_2d.float()
grad_sign = grad_w.sign().to(torch.int8)
```

That has now been replaced by a Triton sign-only reduction. The dense FP32 `grad_w` tensor is gone from the Triton backward path.

Current remaining issue:

- The CUDA update path now allocates a small `int32` score buffer shaped like `E`, not like the full logical weight matrix.
- The score buffer is `out_dim * ceil(in_dim / group_size)` entries. At `T32` group size 12, this is about 1/12 as many elements as the logical weights.
- The fused kernel uses `tl.atomic_add` into group scores because 5-trit packed bytes and scale groups are not naturally aligned.

This fits the 3B-on-8GB memory goal much better than the previous path.
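
To make the size claim concrete, here is a small arithmetic sketch. The helper name `score_buffer_entries` and the 4096 x 4096 layer shape are illustrative assumptions, not values from the codebase; only the formula `out_dim * ceil(in_dim / group_size)` and the group size of 12 come from the text above.

```python
import math


def score_buffer_entries(out_dim: int, in_dim: int, group_size: int) -> int:
    """Number of int32 score entries: one per (output row, scale group)."""
    groups_per_row = math.ceil(in_dim / group_size)
    return out_dim * groups_per_row


# Hypothetical layer shape, chosen only for illustration.
out_dim, in_dim, group_size = 4096, 4096, 12

entries = score_buffer_entries(out_dim, in_dim, group_size)
logical_weights = out_dim * in_dim

print(entries)                    # int32 score entries
print(entries * 4 / 2**20)        # score buffer size in MiB (4 bytes per entry)
print(entries / logical_weights)  # close to 1 / group_size
```

Because `in_dim` need not be a multiple of `group_size`, the ratio is slightly above 1/12 when the last group in each row is partial.
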
The next optimization is performance tuning: reduce atomics, tune tile sizes, and consider group/pack alignment changes.

The actual required update is:

```text
T_accum[i, j] += sign(sum_m grad_y[m, i] * x[m, j])
E update uses grouped sign statistics from the same reduction.
```

The fused path (Stage 2 below) therefore avoids storing even the dense `int8` sign tensor by fusing the reduction directly into the ternary state update.

## Proposed Fix For Gradient Sign Capture

### Stage 1: Sign-Only Grad Kernel

Implemented. Added a Triton kernel:

```text
input:
  grad_y[M, N]
  x[M, K]
output:
  grad_sign[N, K] int8
```

Each program owns a tile of `(N, K)`, loops across `M`, accumulates a local `float32` (or `int32`-style) sum, then immediately converts it to a sign:

```text
s = sum_m grad_y[m, n] * x[m, k]
g = sign(s)
grad_sign[n, k] = g
```

This removes the dense FP32 `grad_w` allocation. It still performs the same math, but the result is reduced to int8 at the end of each tile instead of returning a full-precision matrix to PyTorch.

### Stage 2: GPU T Accumulator Update And Repack

Implemented as a fused Triton kernel. It computes the sign reduction from `grad_y` and `x`, updates `T_accum`, applies threshold flips, resets flipped accumulator entries, rewrites `T_packed` in-place, and emits per-group scale scores:

```text
g = sign(sum_m grad_y[m,n] * x[m,k])
score[n, group(k)] += g * old_T[n,k]
T_accum[n, k] = clamp(T_accum[n, k] + g, -128, 127)
if T_accum[n,k] > threshold: T[n,k] = +1, T_accum[n,k] = 0
if T_accum[n,k] < -threshold: T[n,k] = -1, T_accum[n,k] = 0
repack 5 T values into one uint8
```

This removes Python unpack/repack from the CUDA path.

### Stage 3: Separate E Group Update

Implemented for the CUDA/Triton path as a small score-apply kernel.
For scale exponents, update per group rather than per weight:

```text
group_score[n, g] = sum_{k in group} sign(grad_w[n, k]) * T[n, k]
E[n, g] = clamp(E[n, g] - sign(group_score[n, g]), -128, 127)
```

This is now a second Triton kernel over `(N, groups)` that applies the score emitted by the fused ternary-step/repack kernel.

Current implemented split:

1. fused sign-reduction + `T_accum` + repack + group-score kernel
2. small `E` score-apply kernel

No sign tensor is stored in the normal CUDA path.

### Stage 4: Tune Or Remove Group-Score Atomics

Next step. The current fused CUDA path uses atomics to accumulate group scores because packed bytes are 5-trit chunks while `E` groups are 12, 6, 24, 48, 64, or 96 weights depending on TScale type.

Options:

- Keep atomics and tune block sizes.
- Change group sizes to align with 5-trit packing where possible.
- Use a group-owned kernel for `E` and a pack-owned kernel for `T`, accepting two reductions but no atomics.
- Move to a custom CUDA/CUTLASS kernel if Triton atomics become the bottleneck.

This is now a speed optimization, not a memory correctness blocker.

### Stage 5: Activation Ternary Mode

Forward currently consumes normal activation tensors. True ternary training should eventually use:

```text
A in {-1, 0, +1}
W in {-1, 0, +1}
accumulator in int32 or fp32
scale from int8 E
```

This makes the gradient-sign kernel cheaper because `x[m, k]` is also sign/zero rather than a dense float.

## Native CUDA Speed Path

### cuBLAS

cuBLAS is the wrong direct target for packed ternary weights. NVIDIA cuBLAS is highly optimized for standard BLAS and low/mixed precision types, including tensor-core-backed FP/INT formats, but it does not accept a custom 5-trit-per-byte ternary matrix as a native GEMM input.

Use cuBLAS only for baselines or fallback dequantized GEMM.

### Triton

Triton is the best immediate path.

Reasons:

- Already installed in this environment.
- Integrates with PyTorch autograd quickly.
- Good for custom packed formats and fused update kernels.
- Lets us remove the dense `grad_w` allocation without waiting on TileLang setup.

Near-term target:

```text
Triton forward packed ternary GEMM
Triton grad-input packed ternary GEMM
Triton sign-only grad/state-update kernel
Triton device repack kernel
```

### TileLang

TileLang is still worth getting working after the algorithm is stable. Its docs describe `T.gemm` lowering to target-specific tensor cores, and it is designed for tiled AI kernels such as dequant GEMM and FlashAttention-style workloads.

Use TileLang when:

- Triton semantics are proven.
- Tile sizes and data layout are stable.
- We want a cleaner path toward tensor-core-style tiling and schedule tuning.

TileLang should not be the first blocker because it is not currently installed in this environment.

### Custom CUDA/CUTLASS

This is the final performance path, not the first implementation path.

Use custom CUDA/CUTLASS when:

- The Triton kernels prove the training algorithm.
- Profiling shows Triton is bottlenecked by decode/repack/reduction overhead.
- We need warp-level bit/trit decode, shared-memory staging, and tuned occupancy.

This path has the highest ceiling and highest development cost.

## Recommended Roadmap

1. Keep the current Triton forward and grad-input path.
2. Add `ternary_grad_sign_accum_kernel` to eliminate dense `grad_w`.
3. Add GPU-side `E` update.
4. Add GPU-side repack of `T_packed`.
5. Benchmark strict ternary training memory with `torch.cuda.max_memory_allocated()`.
6. Tune Triton block sizes.
7. Install and evaluate TileLang against the Triton kernels.
8. Move only proven hot kernels to CUDA/CUTLASS if needed.

## Key Constraint

There is no way around accumulation. Even a ternary model must accumulate dot products and gradient reductions in something wider than ternary.
The goal is not "no accumulator precision"; the goal is:

- no persistent BF16/FP32/FP8 weights
- no persistent FP optimizer state
- no dense full-precision weight gradients
- packed ternary storage
- int8 scale memory
- streaming reductions into ternary/int8 state

## Speed Pass: Scale Update Scheduling

The bottleneck going into this pass was the exponent scale update, not the packed forward path. A 4-step strict smoke run with scheduled scale updates reached about `1.87 it/s`, while disabling `E` updates reached about `4.90 it/s` on the same small batch after compile.

Changes made:

- Added `--scale_update_interval`.
  - Default: `4`.
  - `1`: update int8 `E` every step.
  - `0`: disable `E` updates and always use the fast direct ternary repack path.
- Changed CUDA `update_E()` back to the direct group-owned Triton kernel for the normal retained `grad_y/x` path.
- `ternary_step()` now uses the fast direct repack kernel after `update_E()` instead of the fused score/repack kernel on scheduled scale-update steps.
- This removes the temporary int32 score buffer and `tl.atomic_add` from the default scheduled scale-update path.

Observed before the direct `E` change:

```text
scale_update_interval=4: ~1.87 it/s overall on 4-step strict smoke
scale_update_interval=0: ~4.90 it/s overall on 4-step strict smoke
```

Observed after the direct group-owned `E` change:

```text
scale_update_interval=4: ~4.80 it/s overall on 4-step strict smoke
scale_update_interval=1: ~2.82 it/s overall on 4-step strict smoke
```

Interpretation:

- The fused score/repack path was slower than expected because the score buffer and atomics dominated this small-batch update path.
- The direct group-owned `E` kernel is faster even though it performs a separate reduction, because each program owns an `E` group and writes without atomics.
- The default `scale_update_interval=4` now preserves scheduled int8 scale learning while keeping most steps on the fast direct repack path.
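
The `--scale_update_interval` semantics described in this section can be sketched as a small predicate. The function name `should_update_scale`, the 1-based step counter, and the choice to fire on multiples of the interval are assumptions for illustration; the actual phase used in `train.py` is not stated here.

```python
def should_update_scale(step: int, interval: int) -> bool:
    """Return True when int8 `E` should be updated on this optimizer step.

    Assumption: `step` is 1-based and updates fire on multiples of `interval`.
    """
    if interval <= 0:
        # 0 disables E updates, so every step uses the fast direct repack path.
        # Treating negative values the same way is also an assumption.
        return False
    return step % interval == 0


# Which of the first 8 steps would run the E update for each setting.
for interval in (4, 1, 0):
    steps = [s for s in range(1, 9) if should_update_scale(s, interval)]
    print(interval, steps)
```

Under these assumptions, the default interval of 4 runs the `E` update on one step in four and leaves the other three on the direct repack path, which matches the interpretation above.
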

Next speed target:

- Benchmark the direct group-owned `E` kernel against the fused score kernel on real layer sizes.
- If direct `E` is faster there too, remove the fused score kernel or demote it to an experimental path.
- Tune `BLOCK_N`, `BLOCK_K`, and `BLOCK_M` in `_triton_update_e_direct_kernel` and `_triton_ternary_step_direct_kernel`.
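
As a correctness reference while benchmarking or tuning these kernels, the group-owned `E` update can be written in plain Python. This is a minimal CPU sketch of the Stage 3 formulas, not the Triton kernel: the function name `update_e_reference`, the list-of-lists layout, and the tiny input values are all assumptions made for illustration.

```python
def sign(v):
    """Return -1, 0, or +1 according to the sign of v."""
    return (v > 0) - (v < 0)


def update_e_reference(grad_y, x, T, E, group_size):
    """Group-owned E update: each E[n][g] cell is written by exactly one loop body.

    grad_y: M x N floats, x: M x K floats, T: N x K trits in {-1, 0, +1},
    E: N x ceil(K / group_size) ints, updated in place and returned.
    """
    M, N, K = len(grad_y), len(T), len(T[0])
    for n in range(N):
        for g in range(len(E[n])):
            score = 0
            for k in range(g * group_size, min((g + 1) * group_size, K)):
                # Weight gradient for (n, k), reduced over the batch dimension.
                grad_w_nk = sum(grad_y[m][n] * x[m][k] for m in range(M))
                score += sign(grad_w_nk) * T[n][k]
            # Step the exponent against the score and clamp to the int8 range.
            E[n][g] = max(-128, min(127, E[n][g] - sign(score)))
    return E


# Tiny illustrative input: M=2, N=1, K=4, two scale groups of 2 weights each.
grad_y = [[1.0], [0.5]]
x = [[1.0, -2.0, 0.5, 0.0], [1.0, 1.0, -3.0, 0.0]]
T = [[1, -1, -1, 0]]
E = [[0, -128]]
print(update_e_reference(grad_y, x, T, E, group_size=2))
```

Because each `(n, g)` cell has a single writer, no atomics are needed, at the cost of recomputing the sign reduction separately from the repack pass. The second group in the example starts at `-128` to show the clamp keeping `E` inside the int8 range.
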