# True Ternary Refactor 3

## Scope

This pass extends the Triton kernel coverage from `TernaryScaleTensor` (linear layers) to the remaining model components: `TernaryRMSNorm` (normalization) and `ByteEmbedding` (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.

Refactor 2 established the packed-ternary Triton path for `TernaryScaleTensor` forward, grad-x, sign-only weight gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.

## What Changed

### `tscale.py`

#### TernaryRMSNorm Triton kernels

- Added `_triton_rmsnorm_fwd_kernel`: fused RMS norm + ternary weight application in a single kernel. Each program loads a batch row of `x`, computes the RMS, normalizes, then loads the packed ternary weights and int8 exponents for that dimension, dequantizes via `sign * exp2(E)`, and multiplies them into the output.
- Added `_triton_rmsnorm_bwd_kernel`: fused backward for RMS norm with ternary weights. Computes `dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms` without materializing the weight tensor.
- Added `_TritonRMSNormFn(torch.autograd.Function)`: autograd wrapper that saves `x_2d`, `packed`, and `e` for backward. Forward calls `_triton_rmsnorm_fwd_kernel`; backward calls `_triton_rmsnorm_bwd_kernel`.
- `TernaryRMSNorm.forward()` now prefers the Triton path when the input is on CUDA. The CPU fallback is unchanged.
- `TernaryRMSNorm.ternary_step()` and `TernaryRMSNorm.update_E()` remain no-ops (these weights are frozen).

Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.

#### ByteEmbedding Triton kernels

- Added `_triton_ternary_embed_fwd_kernel`: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via `sign * exp2(E)`, and writes the output embedding vectors.
  Uses linear indexing into the flat packed buffer (`lin = idx * DIM + d`, `pack_idx = lin // 5`).
- Added `_triton_ternary_embed_bwd_accum_kernel`: scatter-adds gradient contributions from each index position into a float accumulator buffer shaped `[VOCAB, DIM]` using `tl.atomic_add`.
- Added `_triton_ternary_embed_bwd_sign_kernel`: converts the float accumulator to int8 sign (`1 / 0 / -1`) in a single pass over `[VOCAB, DIM]`.
- Added `_triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim)`: two-pass helper that runs the accum kernel, then the sign kernel.
- Updated `_TritonTernaryEmbedFn.backward()`: it no longer materializes a float `w_eff`. Instead, it calls `_triton_ternary_embed_grad_sign()` to compute `grad_sign` directly into int8 via atomic scatter-add + sign. It still unpacks `T` for `_hook_T` (used by the CPU-style `update_E` path on ByteEmbedding).

Correctness: forward max diff 0.0 vs CPU reference; grad_sign 100% match vs CPU for all 6 TScaleTypes, including duplicate indices.

#### Dead code removal

- Removed the dead `_hook_defer_e_to_ternary_step` branch from `ternary_step()`. This branch called the deleted `_triton_ternary_step_score` kernel. The direct path (`_triton_ternary_step_direct`) is now the only Triton ternary-step path.
- Removed the corresponding assertion from `test_cuda_triton_tscale_path` that checked `_hook_defer_e_to_ternary_step` was not set.
- No other code references `_hook_defer_e_to_ternary_step` or the deleted score kernels.

#### Bug fix: missing `_triton_ternary_embed_fwd_kernel`

The embed forward kernel definition was accidentally deleted during a previous session's cleanup (the helper `_triton_ternary_embed` still referenced it, causing a `NameError` at runtime). It was reconstructed with correct linear-index packed ternary decoding and verified from scratch: exact match vs CPU for all 6 TScaleTypes.

### `trigram.py`

No changes to model code in this pass.
The `ByteEmbedding` and `TernaryRMSNorm` classes already called the Triton path via `_TritonTernaryEmbedFn` and `_TritonRMSNormFn`; they just were not working because the kernel definitions were missing.

### `testing/test_tscale.py`

Added 5 low-level Triton vs CPU reference correctness tests:

| Test | What it checks |
|------|----------------|
| `test_cuda_triton_correctness_linear` | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| `test_cuda_triton_correctness_rmsnorm` | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| `test_cuda_triton_correctness_embedding` | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_update_E` | E update exact match vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_ternary_step` | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |

All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs/state updates.
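
The correctness tests above compare each Triton kernel against a CPU reference. To illustrate what such references compute, here is a minimal NumPy sketch; it is not the project's test code. It assumes a hypothetical packing layout (each `uint8` holds five base-3 digits storing `t + 1`, with element `lin` in digit `lin % 5` of byte `lin // 5`) and one int8 exponent per element. All function names are hypothetical. The RMSNorm backward uses the same formula as `_triton_rmsnorm_bwd_kernel` and is checked against central finite differences rather than autograd.

```python
import numpy as np

TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, consistent with pack_idx = lin // 5


def pack_ternary(t):
    """Hypothetical layout: (t + 1) as base-3 digits; element lin -> digit lin % 5 of byte lin // 5."""
    digits = np.asarray(t, dtype=np.int64).ravel() + 1  # {-1, 0, 1} -> {0, 1, 2}
    pad = (-digits.size) % TRITS_PER_BYTE
    digits = np.concatenate([digits, np.ones(pad, dtype=np.int64)])  # pad with trit 0
    place = 3 ** np.arange(TRITS_PER_BYTE)
    return (digits.reshape(-1, TRITS_PER_BYTE) * place).sum(axis=1).astype(np.uint8)


def decode_trits(packed, lin):
    """Decode ternary values at flat positions `lin` (any shape) from the packed buffer."""
    lin = np.asarray(lin, dtype=np.int64)
    byte = packed[lin // TRITS_PER_BYTE].astype(np.int64)
    digit = (byte // 3 ** (lin % TRITS_PER_BYTE)) % 3
    return (digit - 1).astype(np.int8)


def dequantize(t, e):
    """w = sign * exp2(E), assuming one int8 exponent per element."""
    return t.astype(np.float64) * np.exp2(e.astype(np.float64))


def embed_lookup_ref(indices, packed, e_flat, dim):
    """Embedding rows via the kernel's linear indexing: lin = idx * DIM + d."""
    lin = np.asarray(indices, dtype=np.int64)[:, None] * dim + np.arange(dim)
    return dequantize(decode_trits(packed, lin), e_flat[lin])


def embed_grad_sign_ref(indices, grad_out, vocab):
    """Scatter-add per vocab row (duplicates accumulate, like tl.atomic_add), then take the sign."""
    accum = np.zeros((vocab, grad_out.shape[1]), dtype=np.float64)
    np.add.at(accum, np.asarray(indices), grad_out)
    return np.sign(accum).astype(np.int8)


def rmsnorm_fwd_ref(x, w, eps=1e-6):
    """Row-wise RMS norm followed by the (dequantized) ternary weight."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    x_norm = x / rms
    return x_norm * w, x_norm, rms


def rmsnorm_bwd_ref(dy, x_norm, rms, w):
    """dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms, per row."""
    g = dy * w
    return (g - x_norm * np.mean(x_norm * g, axis=-1, keepdims=True)) / rms


def finite_diff_dx(x, w, dy, h=1e-6):
    """Central differences of sum(rmsnorm(x) * dy), used to check rmsnorm_bwd_ref."""
    dx = np.zeros_like(x)
    for idx in np.ndindex(*x.shape):
        xp, xm = x.copy(), x.copy()
        xp[idx] += h
        xm[idx] -= h
        f_plus = np.sum(rmsnorm_fwd_ref(xp, w)[0] * dy)
        f_minus = np.sum(rmsnorm_fwd_ref(xm, w)[0] * dy)
        dx[idx] = (f_plus - f_minus) / (2 * h)
    return dx
```

Under these assumptions, `embed_lookup_ref(indices, packed, e.ravel(), dim)` mirrors the kernel's `lin = idx * DIM + d` gather. `np.add.at` accumulates duplicate indices, unlike `accum[indices] += grad_out`, which is the property the atomic scatter-add has to preserve.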
## Kernel Inventory

### TernaryScaleTensor (from Refactor 2)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_fwd_kernel` | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| `_triton_ternary_grad_x_kernel` | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| `_triton_ternary_grad_sign_kernel` | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| `_triton_update_e_kernel` | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| `_triton_update_e_direct_kernel` | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| `_triton_ternary_step_kernel` | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| `_triton_ternary_step_direct_kernel` | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |

### TernaryRMSNorm (new in Refactor 3)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_rmsnorm_fwd_kernel` | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| `_triton_rmsnorm_bwd_kernel` | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |

### ByteEmbedding (new in Refactor 3)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_embed_fwd_kernel` | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| `_triton_ternary_embed_bwd_accum_kernel` | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| `_triton_ternary_embed_bwd_sign_kernel` | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |

### Autograd Functions

| Function | Module | Forward | Backward |
|----------|--------|---------|----------|
| `_TritonTernaryLinearFn` | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| `_TritonRMSNormFn` | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| `_TritonTernaryEmbedFn` | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |

### Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)

| Kernel | Reason |
|--------|--------|
| `_triton_ternary_step_score_kernel` | Replaced by direct group-owned E kernel |
| `_triton_ternary_step_score_block_kernel` | Same |
| `_triton_apply_e_score_kernel` | Same |
| `_triton_apply_e_score` | Same |
| `_triton_ternary_step_score` helper | Same |

## Data Flow

### Forward (CUDA)

```text
ByteEmbedding:
  indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed

TernaryScaleTensor (linear):
  x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]

TernaryRMSNorm:
  x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]
```

### Backward (CUDA)

```text
ByteEmbedding:
  grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
  (no float w_eff materialized; T unpacked for _hook_T only)

TernaryScaleTensor (linear):
  grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
  (grad_2d, x_2d retained for update_E and ternary_step)

TernaryRMSNorm:
  grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
  (no grad captured for frozen weights)
```

### State Updates (CUDA)

```text
update_E:
  _triton_update_e_direct(T_packed, grad_2d, x_2d, E)
  - Computes sign(grad^T @ x) internally
  - Reads packed T to get current ternary signs
  - Sums grad_sign * T per group
  - Applies delta = -sign(group_score) to E

ternary_step:
  _triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
  - Computes sign(grad^T @ x) internally
  - Updates T_accum += grad_sign
  - Flips T where |T_accum| > threshold
  - Repacks T_packed in-place
```

## Float Materialization Audit

| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| TernaryScaleTensor CUDA forward | `y[M,N]` float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | `grad_x[M,K]` float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | `out[B,D]` float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | `grad_x[B,D]` float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | `out[N,D]` float32 output | Ephemeral |
| ByteEmbedding CUDA backward | `accum[V,D]` float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | `grad_sign[V,D]` int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | `w_eff`, `S`, `T.float()` | Ephemeral (via detach+requires_grad) |

No persistent float parameters or float optimizer state exist in strict ternary mode.

## Loss Spike Investigation

Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. Root cause analysis:

### Primary cause: mass T-flip at step 2

With `T_accum` initialized to zeros and `accum_threshold=3`:

- Step 1: `T_accum` goes from 0 to +1 or -1.
- Step 2: `T_accum` reaches 2 or 3+.
- When `T_accum` hits the threshold, correlated initial gradients cause thousands of simultaneous T sign flips.

This is a catastrophic weight change at step 2 because the model's learned representations from step 1 are invalidated en masse.

### Secondary: E-then-T ordering

`update_E()` runs before `ternary_step()`. E is updated based on the pre-flip T, and `ternary_step()` may then flip T values, leaving E inconsistent with the new T state.

### Tertiary: redundant normalization

The training loop applies `clip_grad_norm_` and then scales by `inv_scale = 1/||grad||`. For SignSGD, `sign(x) = sign(x/||x||)`, so the `inv_scale` normalization has no effect on the update. The double normalization is dead code for the sign path.

### Suggested fixes

1. **Warmup `accum_threshold`**: start at 7+ and decay to 3 over ~100 steps.
2. **Swap update order**: call `ternary_step()` before `update_E()` so E sees the post-flip T.
3. **Initialize `T_accum` with small random values** (e.g. `torch.randint(-2, 3, T_accum.shape)`) to break synchronization.
4. **Remove redundant `inv_scale`** for SignSGD (it has no effect).
5. **Rate-limit T flips**: only flip the top-k% of positions per step.

## Verification

146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.

The new correctness tests verify exact or near-exact match between the Triton CUDA and CPU reference paths for:

- TernaryScaleTensor forward + backward (all 6 TScaleTypes)
- TernaryRMSNorm forward + backward (all 6 TScaleTypes)
- ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
- E update (exact match, all 6 TScaleTypes)
- T flip + T_accum (exact match, all 6 TScaleTypes)

## Remaining Work

1. **T_accum warmup or random init** to prevent the step-2 mass-flip loss spike
2. **Swap ternary_step/update_E ordering** for E/T consistency
3. **Remove redundant inv_scale** in the training loop for SignSGD
4. **Benchmark at target batch=1024** to check whether ~45s/step is JIT warmup
5. **Tune Triton block sizes** for production layer dimensions
6. **Evaluate TileLang** against the proven Triton kernels for the tensor-core path
7. **ByteEmbedding `_hook_T`**: backward still unpacks T for `_hook_T` (needed by the `ByteEmbedding.update_E` CPU path). This could be replaced with a Triton kernel that reads T_packed directly.
8. **GNNLoRAAdapter `self.B`**, **MemGram `nn.ParameterList`**, **GraphMoEGate `self.query`**: remaining trainable float params that break strict ternary in non-text-only mode
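
Suggested fixes 1, 3, and 5 (remaining-work item 1) are small enough to sketch. The NumPy code below is an illustrative stand-in for the torch code, not the project's implementation: the linear schedule shape, the function names, and the top-k tie-breaking are all assumptions. It only selects which positions are allowed to flip; it does not model how `_triton_ternary_step_direct` changes T at a flip.

```python
import math

import numpy as np


def accum_threshold_at(step, start=7, end=3, warmup_steps=100):
    """Hypothetical schedule: linearly decay the flip threshold from `start` to `end` (integers)."""
    progress = min(step, warmup_steps)
    return start - (start - end) * progress // warmup_steps


def init_t_accum(shape, rng):
    """Small random init in {-2, ..., 2}, the NumPy analogue of torch.randint(-2, 3, shape)."""
    return rng.integers(-2, 3, size=shape).astype(np.int8)


def select_flips(t_accum, threshold, max_flip_frac):
    """Mask of positions with |T_accum| > threshold, rate-limited to the top-k by |T_accum|."""
    mag = np.abs(t_accum.astype(np.int32)).ravel()  # widen first: abs(int8 -128) overflows
    eligible = mag > threshold
    k = math.ceil(max_flip_frac * mag.size)
    mask = np.zeros(mag.size, dtype=bool)
    if k == 0 or not eligible.any():
        return mask.reshape(t_accum.shape)
    if eligible.sum() <= k:
        return eligible.reshape(t_accum.shape)
    score = np.where(eligible, mag, -1)  # ineligible positions can never be selected
    top = np.argpartition(-score, k - 1)[:k]  # indices of the k largest scores (ties arbitrary)
    mask[top] = True
    return mask.reshape(t_accum.shape)
```

In a training loop, `accum_threshold_at(step)` would stand in for the fixed `accum_threshold=3`, and the rate limit bounds the number of simultaneous flips even when early gradients are strongly correlated. With the random init, positions no longer all start from the same accumulator value.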