# True Ternary Refactor 3

## Scope

This pass extends the Triton kernel coverage from `TernaryScaleTensor` (linear layers) to the remaining model components: `TernaryRMSNorm` (normalization) and `ByteEmbedding` (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.

Refactor 2 established the packed-ternary Triton path for `TernaryScaleTensor` forward, grad-x, sign-only weight gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.

## What Changed

### `tscale.py`

#### TernaryRMSNorm Triton kernels

- Added `_triton_rmsnorm_fwd_kernel`: fused RMS norm + ternary weight application in a single kernel. Each program handles one row of `x`: it computes the RMS, normalizes, loads the packed ternary weights and int8 exponents along the hidden dimension, dequantizes them via `sign * exp2(E)`, and multiplies them into the output.
- Added `_triton_rmsnorm_bwd_kernel`: fused backward for RMS norm with ternary weights. Computes `dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms`, where `x_norm = x / rms` and `w = sign * exp2(E)`, without materializing the dense weight tensor.
- Added `_TritonRMSNormFn(torch.autograd.Function)`: autograd wrapper that saves `x_2d`, `packed`, and `e` for backward. Forward calls `_triton_rmsnorm_fwd_kernel`; backward calls `_triton_rmsnorm_bwd_kernel`.
- `TernaryRMSNorm.forward()` now prefers the Triton path when the input is on CUDA. The CPU fallback is unchanged.
- `TernaryRMSNorm.ternary_step()` and `TernaryRMSNorm.update_E()` remain no-ops (these weights are frozen).
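The math the two RMSNorm kernels fuse can be written as a plain-Python reference for a single row. This is a sketch, not the Triton code: it works on already-unpacked trits `t` in {-1, 0, +1} and integer exponents `e`, and it assumes `eps` sits inside the square root (where `eps` is applied is not stated above). It checks the closed-form backward against central finite differences of `sum(dy * y)`.

```python
import math

def dequant(t, e):
    # Effective weight per hidden dim: w = t * 2**e, with t in {-1, 0, +1}.
    return [ti * math.ldexp(1.0, ei) for ti, ei in zip(t, e)]

def rms(x, eps):
    return math.sqrt(sum(v * v for v in x) / len(x) + eps)

def rmsnorm_fwd(x, t, e, eps=1e-6):
    r = rms(x, eps)
    return [v / r * wi for v, wi in zip(x, dequant(t, e))]

def rmsnorm_bwd(x, dy, t, e, eps=1e-6):
    # dx = (dy*w - x_norm * mean(x_norm * dy * w)) / rms; weights are frozen.
    r = rms(x, eps)
    x_norm = [v / r for v in x]
    g = [a * b for a, b in zip(dy, dequant(t, e))]
    m = sum(xn * gi for xn, gi in zip(x_norm, g)) / len(x)
    return [(gi - xn * m) / r for gi, xn in zip(g, x_norm)]

x = [0.5, -1.25, 2.0, 0.75]
dy = [1.0, -0.5, 0.25, 2.0]
t = [1, 0, -1, 1]
e = [0, -1, 1, -2]

def loss(xs):
    # L = sum(dy * y), so dL/dx is exactly what rmsnorm_bwd returns.
    return sum(a * b for a, b in zip(dy, rmsnorm_fwd(xs, t, e)))

h = 1e-6
fd = []
for i in range(len(x)):
    xp, xm = list(x), list(x)
    xp[i] += h
    xm[i] -= h
    fd.append((loss(xp) - loss(xm)) / (2 * h))

err = max(abs(a - b) for a, b in zip(rmsnorm_bwd(x, dy, t, e), fd))
print(f"max |analytic - finite diff| = {err:.2e}")
```

Because the weights are frozen, the backward only needs `dx`; no weight gradient is formed, which matches the kernel producing a single `grad_x` output.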

Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.

#### ByteEmbedding Triton kernels

- Added `_triton_ternary_embed_fwd_kernel`: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via `sign * exp2(E)`, and writes the output embedding vectors. It uses linear indexing into the flat packed buffer (`lin = idx * DIM + d`, `pack_idx = lin // 5`, i.e. five ternary values per packed element).
- Added `_triton_ternary_embed_bwd_accum_kernel`: scatter-adds the gradient contribution of each index position into a float accumulator buffer shaped `[VOCAB, DIM]` using `tl.atomic_add`.
- Added `_triton_ternary_embed_bwd_sign_kernel`: converts the float accumulator to int8 sign (`1 / 0 / -1`) in a single pass over `[VOCAB, DIM]`.
- Added `_triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim)`: two-pass helper that runs the accum kernel, then the sign kernel.
- Updated `_TritonTernaryEmbedFn.backward()`: it no longer materializes a float `w_eff`. Instead, it calls `_triton_ternary_embed_grad_sign()` to compute `grad_sign` directly in int8 via atomic scatter-add + sign. It still unpacks `T` for `_hook_T` (used by the CPU-style `update_E` path on ByteEmbedding).
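The two-pass backward is a scatter-add followed by a sign. A serial plain-Python sketch (no atomics are needed when one thread does all the adds) shows why duplicate indices must be summed in float before the sign is taken: for token 2 below, the per-position signs in column 0 are +1, -1, -1, yet their float sum is positive. The function name mirrors the helper above, but the body is only an illustration.

```python
def embed_grad_sign(indices, grad_output, vocab, dim):
    # Pass 1: scatter-add each row of grad_output into its vocab row.
    # On the GPU this is the tl.atomic_add pass; serially a plain += suffices.
    accum = [[0.0] * dim for _ in range(vocab)]
    for pos, idx in enumerate(indices):
        row = accum[idx]
        for d in range(dim):
            row[d] += grad_output[pos][d]
    # Pass 2: collapse the float accumulator to signs in {-1, 0, +1}.
    return [[(v > 0) - (v < 0) for v in row] for row in accum]

indices = [2, 0, 2, 2]   # token 2 appears three times, token 1 never
grad_output = [
    [1.0, -0.5],
    [-2.0, 0.0],
    [-0.4, 0.5],
    [-0.4, 0.25],
]
grad_sign = embed_grad_sign(indices, grad_output, vocab=3, dim=2)
print(grad_sign)
```

One caveat: on the GPU, floating-point `tl.atomic_add` results depend on the order in which additions land, so an accumulator entry that sums to nearly zero could in principle come out on either side of zero from run to run.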

Correctness: forward max diff 0.0 vs CPU reference, and `grad_sign` 100% match vs CPU, for all 6 TScaleTypes, including duplicate indices.

#### Dead code removal

- Removed the dead `_hook_defer_e_to_ternary_step` branch from `ternary_step()`. This branch called the deleted `_triton_ternary_step_score` helper. The direct path (`_triton_ternary_step_direct`) is now the only Triton ternary-step path.
- Removed the corresponding assertion from `test_cuda_triton_tscale_path` that checked `_hook_defer_e_to_ternary_step` was not set.
- No other code references `_hook_defer_e_to_ternary_step` or the deleted score kernels.

#### Bug fix: missing `_triton_ternary_embed_fwd_kernel`

The embed forward kernel definition was accidentally deleted during a previous session's cleanup; the helper `_triton_ternary_embed` still referenced it, causing a `NameError` at runtime. The kernel was reconstructed with correct linear-index packed ternary decoding and re-verified from scratch: exact match vs CPU for all 6 TScaleTypes.
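The linear-index decoding can be sketched in plain Python. The document fixes only `lin = idx * DIM + d` and `pack_idx = lin // 5`; the byte layout below (trit + 1 stored as base-3 digit `lin % 5`, least significant first), the `group_size` exponent grouping, and the helper names are assumptions for illustration, not necessarily what `tscale.py` does. Five trits fit in one byte because 3**5 = 243 <= 256.

```python
def pack_trits(trits):
    # Hypothetical layout: byte p stores trits 5p..5p+4 as base-3 digits
    # (trit + 1), least-significant digit first; the max value 242 fits a uint8.
    packed = []
    for start in range(0, len(trits), 5):
        byte = 0
        for j, t in enumerate(trits[start:start + 5]):
            byte += (t + 1) * 3 ** j
        packed.append(byte)
    return packed

def unpack_trit(packed, lin):
    pack_idx = lin // 5            # which packed byte holds this element
    digit = lin % 5                # which base-3 digit inside it (assumed layout)
    return (packed[pack_idx] // 3 ** digit) % 3 - 1

def decode_embedding_row(packed, e, idx, dim, group_size):
    # Gather one embedding row straight from the flat [VOCAB, DIM] buffer.
    row = []
    for d in range(dim):
        lin = idx * dim + d
        exp = e[lin // group_size]  # hypothetical: one exponent per group_size elements
        row.append(unpack_trit(packed, lin) * 2.0 ** exp)
    return row

VOCAB, DIM = 3, 4
trits = [1, -1, 0, 1,   0, 0, -1, 1,   -1, 1, 1, 0]   # [VOCAB * DIM], row-major
e = [0, 1, -1, 0, 2, -2]                              # one exponent per 2 elements
packed = pack_trits(trits)

row1 = decode_embedding_row(packed, e, idx=1, dim=DIM, group_size=2)
print(packed, row1)
```

With `DIM = 4`, row 1 covers linear indices 4 to 7, which straddle packed bytes 0 and 1. Whenever `DIM` is not a multiple of 5, decoding has to go through the flat linear index rather than a per-row byte offset.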

### `trigram.py`

No changes to model code in this pass. The `ByteEmbedding` and `TernaryRMSNorm` classes already called the Triton path via `_TritonTernaryEmbedFn` and `_TritonRMSNormFn`; they were not working because the kernel definitions were missing.

### `testing/test_tscale.py`

Added 5 low-level correctness tests comparing the Triton path against the CPU reference:

| Test | What it checks |
|------|---------------|
| `test_cuda_triton_correctness_linear` | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| `test_cuda_triton_correctness_rmsnorm` | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| `test_cuda_triton_correctness_embedding` | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_update_E` | E update exact match vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_ternary_step` | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |

All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs and state updates.

## Kernel Inventory

### TernaryScaleTensor (from Refactor 2)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_fwd_kernel` | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| `_triton_ternary_grad_x_kernel` | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| `_triton_ternary_grad_sign_kernel` | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| `_triton_update_e_kernel` | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| `_triton_update_e_direct_kernel` | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| `_triton_ternary_step_kernel` | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| `_triton_ternary_step_direct_kernel` | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |

### TernaryRMSNorm (new in Refactor 3)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_rmsnorm_fwd_kernel` | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| `_triton_rmsnorm_bwd_kernel` | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |

### ByteEmbedding (new in Refactor 3)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_embed_fwd_kernel` | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| `_triton_ternary_embed_bwd_accum_kernel` | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| `_triton_ternary_embed_bwd_sign_kernel` | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |

### Autograd Functions

| Function | Module | Forward | Backward |
|----------|--------|---------|----------|
| `_TritonTernaryLinearFn` | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| `_TritonRMSNormFn` | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| `_TritonTernaryEmbedFn` | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |

### Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)

| Kernel | Reason |
|--------|--------|
| `_triton_ternary_step_score_kernel` | Replaced by direct group-owned E kernel |
| `_triton_ternary_step_score_block_kernel` | Same |
| `_triton_apply_e_score_kernel` | Same |
| `_triton_apply_e_score` | Same |
| `_triton_ternary_step_score` helper | Same |

## Data Flow

### Forward (CUDA)

```text
ByteEmbedding:
  indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed

TernaryScaleTensor (linear):
  x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]

TernaryRMSNorm:
  x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]
```

### Backward (CUDA)

```text
ByteEmbedding:
  grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
  (no float w_eff materialized; T unpacked for _hook_T only)

TernaryScaleTensor (linear):
  grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
  (grad_2d, x_2d retained for update_E and ternary_step)

TernaryRMSNorm:
  grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
  (no grad captured for frozen weights)
```

### State Updates (CUDA)

```text
update_E:
  _triton_update_e_direct(T_packed, grad_2d, x_2d, E)
  - Computes sign(grad^T @ x) internally
  - Reads packed T to get current ternary signs
  - Sums grad_sign * T per group
  - Applies delta = -sign(group_score) to E

ternary_step:
  _triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
  - Computes sign(grad^T @ x) internally
  - Updates T_accum += grad_sign
  - Flips T where |T_accum| > threshold
  - Repacks T_packed in-place
```
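Both state updates can be written as a small serial reference. This is a sketch under stated assumptions, not the Triton code: it assumes each E entry owns a contiguous group of `group_size` columns of a `[N, K]` row (the grouping is not specified above) and clamps E to the int8 range. It also models a "flip" as moving T one step against `sign(T_accum)`, clamped to {-1, 0, +1}, followed by resetting that accumulator to zero. `group_size`, the flip rule, and the reset are all hypothetical.

```python
def sign(v):
    # Integer sign in {-1, 0, +1}.
    return (v > 0) - (v < 0)

def grad_sign_from(grad_y, x):
    # sign(grad_y^T @ x): grad_y is [M][N], x is [M][K], result is [N][K].
    m_rows, n, k = len(grad_y), len(grad_y[0]), len(x[0])
    return [[sign(sum(grad_y[m][i] * x[m][j] for m in range(m_rows)))
             for j in range(k)]
            for i in range(n)]

def update_e(T, E, gs, group_size):
    # Assumed layout: E[i][g] scales T[i][g*group_size:(g+1)*group_size].
    for i, row in enumerate(T):
        for g in range(len(E[i])):
            cols = range(g * group_size, (g + 1) * group_size)
            score = sum(gs[i][j] * row[j] for j in cols)       # uses current (pre-flip) T
            E[i][g] = max(-128, min(127, E[i][g] - sign(score)))  # delta = -sign(score)

def ternary_step(T, T_accum, gs, threshold):
    # Assumed flip rule: step T against sign(T_accum), clamp, reset the accumulator.
    # Returns the number of entries whose value actually changed.
    flips = 0
    for i, row in enumerate(T):
        for j in range(len(row)):
            T_accum[i][j] += gs[i][j]
            if abs(T_accum[i][j]) > threshold:
                new_t = max(-1, min(1, row[j] - sign(T_accum[i][j])))
                flips += new_t != row[j]
                row[j] = new_t
                T_accum[i][j] = 0
    return flips

grad_y = [[1.0, -1.0], [0.5, 0.0]]                    # [M=2][N=2]
x = [[1.0, 2.0, -1.0, 0.0], [1.0, -4.0, 0.0, 1.0]]    # [M=2][K=4]
T = [[1, -1, 0, 1], [0, 1, -1, -1]]                   # [N][K] trits
E = [[0, 0], [0, 0]]                                  # one exponent per 2 columns
T_accum = [[3, 0, -3, 0], [0, 3, 0, 0]]

gs = grad_sign_from(grad_y, x)
update_e(T, E, gs, group_size=2)          # current order: E first...
flips = ternary_step(T, T_accum, gs, 3)   # ...then T flips
print(gs, E, flips, T, T_accum)
```

Calling `update_e` first mirrors the current ordering: E is computed against the pre-flip T, which is exactly the inconsistency discussed in the loss spike investigation.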

## Float Materialization Audit

| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| TernaryScaleTensor CUDA forward | `y[M,N]` float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | `grad_x[M,K]` float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | `out[B,D]` float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | `grad_x[B,D]` float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | `out[N,D]` float32 output | Ephemeral |
| ByteEmbedding CUDA backward | `accum[V,D]` float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | `grad_sign[V,D]` int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | `w_eff`, `S`, `T.float()` | Ephemeral (via detach+requires_grad) |

No persistent float parameters or float optimizer state in strict ternary mode.

## Loss Spike Investigation

Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. Root-cause analysis:

### Primary cause: mass T-flip at step 2

With `T_accum` initialized to zeros and `accum_threshold=3`:
- Step 1: `T_accum` goes from 0 to +1 or -1
- Step 2: `T_accum` reaches 2 or 3+
- When `T_accum` hits the threshold, correlated initial gradients cause thousands of simultaneous T sign flips

This is a catastrophic weight change at step 2: the representations learned in step 1 are invalidated en masse.
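A toy simulation illustrates the synchronization argument: if every position receives the same gradient sign on every step and `T_accum` starts at zero, every position crosses the threshold on the same step, whereas accumulators seeded uniformly in [-2, 2] cross over several steps. The assumptions are deliberately crude: fully correlated gradients, unit increments per step, the `> threshold` test from the state-update description, and a fixed RNG seed. `first_cross_steps` is a hypothetical helper, not part of the codebase.

```python
import random
from collections import Counter

def first_cross_steps(init, threshold, steps=20):
    # Toy model: every position receives grad_sign = +1 on every step.
    # Returns a histogram: step -> number of positions whose |T_accum|
    # first exceeds the threshold on that step.
    hist = Counter()
    for acc in init:
        for step in range(1, steps + 1):
            acc += 1
            if abs(acc) > threshold:
                hist[step] += 1
                break
    return hist

n, threshold = 10_000, 3
zero_init = first_cross_steps([0] * n, threshold)
rng = random.Random(0)
random_init = first_cross_steps([rng.randint(-2, 2) for _ in range(n)], threshold)

# Largest fraction of positions that cross on any single step.
print(max(zero_init.values()) / n, max(random_init.values()) / n)
```

The step on which the crossing happens depends on how far `T_accum` moves per optimizer step in the real loop; the sketch only shows that zero initialization makes all positions cross together.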

### Secondary: E-then-T ordering

`update_E()` runs before `ternary_step()`. E is therefore updated against the pre-flip T; when `ternary_step()` then flips T values, E becomes inconsistent with the new T state.

### Tertiary: redundant normalization

The training loop applies `clip_grad_norm_` and then `inv_scale = 1/||grad||`. Both rescale the gradient by a positive factor, and SignSGD only sees `sign(grad)`. Since `sign(c * x) = sign(x)` for any `c > 0` (in particular `sign(x) = sign(x/||x||)`), the `inv_scale` normalization has no effect on the optimizer. The double normalization is dead code for the sign path.
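The invariance is easy to check numerically. In this plain-Python sketch, `clip_scale` is a simplified stand-in for the positive, at-most-1 factor that gradient-norm clipping applies (it is not PyTorch's implementation), and the extra division by the norm plays the role of `inv_scale`; the SignSGD step comes out identical in all three cases.

```python
import math

def sign(v):
    return (v > 0) - (v < 0)

def signsgd_step(w, g, lr):
    # SignSGD uses only the sign of each gradient entry.
    return [wi - lr * sign(gi) for wi, gi in zip(w, g)]

def l2(v):
    return math.sqrt(sum(c * c for c in v))

g = [0.3, -2.0, 0.0, 5e-4, -0.1]
w = [1.0] * len(g)

# Clipping-style rescale: a positive factor capped at 1 (simplified).
max_norm = 1.0
clip_scale = min(1.0, max_norm / (l2(g) + 1e-6))
clipped = [v * clip_scale for v in g]

# The extra inv_scale = 1/||grad|| normalization applied on top.
normalized = [v / l2(clipped) for v in clipped]

base = signsgd_step(w, g, lr=0.01)
same_after_clip = signsgd_step(w, clipped, lr=0.01) == base
same_after_norm = signsgd_step(w, normalized, lr=0.01) == base
print(same_after_clip, same_after_norm)
```

The only way a positive rescale could change a sign is if an extremely small entry underflows to zero, turning its sign from ±1 into 0.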

### Suggested fixes

1. **Warmup `accum_threshold`**: start at 7+ and decay to 3 over ~100 steps
2. **Swap update order**: call `ternary_step()` before `update_E()` so E sees post-flip T
3. **Initialize `T_accum` with small random values** (e.g. `torch.randint(-2, 3, T_accum.shape, dtype=T_accum.dtype, device=T_accum.device)`, uniform over [-2, 2]) to break synchronization
4. **Remove redundant `inv_scale`** for SignSGD (it does nothing)
5. **Rate-limit T flips**: only flip the top-k% of positions per step
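Fixes 1 and 5 are small enough to sketch in plain Python. The assumptions are a linear step-down schedule (the text only says "start at 7+ and decay to 3 over ~100 steps"), ranking flip candidates by |T_accum|, and breaking ties by position. `threshold_at` and `select_flips` are hypothetical helper names.

```python
import math

def threshold_at(step, start=7, end=3, warmup_steps=100):
    # Step down linearly from `start` to `end` over `warmup_steps`, then hold.
    if step >= warmup_steps:
        return end
    return start - (start - end) * step // warmup_steps

def select_flips(t_accum, threshold, max_frac):
    # Candidates exceed the threshold; keep at most ceil(max_frac * n) of them,
    # strongest |T_accum| first (ties broken by position).
    candidates = [i for i, a in enumerate(t_accum) if abs(a) > threshold]
    budget = math.ceil(max_frac * len(t_accum))
    candidates.sort(key=lambda i: (-abs(t_accum[i]), i))
    return sorted(candidates[:budget])

schedule = [threshold_at(s) for s in (0, 25, 50, 75, 100, 500)]
chosen = select_flips([4, -6, 1, 5, -4, 0, 7, 2], threshold=3, max_frac=0.25)
print(schedule, chosen)
```

In this example five positions exceed the threshold, but a 25% budget over eight positions admits only the two strongest (indices 6 and 1); the rest keep accumulating and can flip on a later step.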

## Verification

146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.

New correctness tests verify exact or near-exact match between the Triton CUDA and CPU reference paths for:
- TernaryScaleTensor forward + backward (all 6 TScaleTypes)
- TernaryRMSNorm forward + backward (all 6 TScaleTypes)
- ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
- E update (exact match, all 6 TScaleTypes)
- T flip + T_accum (exact match, all 6 TScaleTypes)

## Remaining Work

1. **T_accum warmup or random init** to prevent the step-2 mass-flip loss spike
2. **Swap ternary_step/update_E ordering** for E/T consistency
3. **Remove redundant inv_scale** in the training loop for SignSGD
4. **Benchmark at target batch=1024** to check whether the ~45 s/step is JIT warmup
5. **Tune Triton block sizes** for production layer dimensions
6. **Evaluate TileLang** against the proven Triton kernels for the tensor-core path
7. **ByteEmbedding `_hook_T`**: backward still unpacks T for `_hook_T` (needed by the `ByteEmbedding.update_E` CPU path). It could be replaced with a Triton kernel that reads T_packed directly.
8. **GNNLoRAAdapter `self.B`**, **MemGram `nn.ParameterList`**, **GraphMoEGate `self.query`**: remaining trainable float params that break strict ternary in non-text-only mode