True Ternary Refactor 3
Scope
This pass extends the Triton kernel coverage from TernaryScaleTensor (linear layers) to the remaining model components: TernaryRMSNorm (normalization) and ByteEmbedding (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.
Refactor 2 established the packed-ternary Triton path for TernaryScaleTensor forward, grad-x, sign-only weight-gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.
What Changed
tscale.py
TernaryRMSNorm Triton kernels
- Added `_triton_rmsnorm_fwd_kernel`: fused RMS norm + ternary weight application in a single kernel. Each program loads a batch row of `x`, computes RMS, normalizes, then loads packed ternary weights and int8 exponents for that dimension, dequantizes via `sign * exp2(E)`, and multiplies into the output.
- Added `_triton_rmsnorm_bwd_kernel`: fused backward for RMS norm with ternary weights. Computes `dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms` without materializing the weight tensor.
- Added `_TritonRMSNormFn(torch.autograd.Function)`: autograd wrapper that saves `x_2d`, `packed`, `e` for backward. Forward calls `_triton_rmsnorm_fwd_kernel`, backward calls `_triton_rmsnorm_bwd_kernel`.
- `TernaryRMSNorm.forward()` now prefers the Triton path when input is on CUDA. CPU fallback remains unchanged.
- `TernaryRMSNorm.ternary_step()` and `TernaryRMSNorm.update_E()` remain no-ops (these weights are frozen).
Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.
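The forward and backward formulas above can be checked against a plain reference. The following is a minimal pure-Python sketch for a single row, not the Triton kernel: the function names and the `eps` term are assumptions for illustration, and weights are taken as already unpacked into per-element signs and exponents.

```python
import math

def rmsnorm_fwd(x, signs, exps, eps=1e-6):
    """Reference forward for one row: out = (x / rms) * w, with w = sign * 2**E.

    `eps` is an assumed stabilizer; the source does not state its value.
    """
    d = len(x)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    w = [s * 2.0 ** e for s, e in zip(signs, exps)]  # dequantized ternary weights
    x_norm = [v / rms for v in x]
    out = [xn * wi for xn, wi in zip(x_norm, w)]
    return out, x_norm, rms, w

def rmsnorm_bwd(dy, x_norm, rms, w):
    """dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms, as in the kernel description."""
    d = len(dy)
    g = [a * b for a, b in zip(dy, w)]
    m = sum(xn * gi for xn, gi in zip(x_norm, g)) / d
    return [(gi - xn * m) / rms for gi, xn in zip(g, x_norm)]

# Small example row with one zero weight.
x = [0.5, -1.25, 2.0, 0.75]
signs = [1, -1, 0, 1]
exps = [0, -1, 2, 1]
dy = [1.0, 0.5, -2.0, 0.25]
out, x_norm, rms, w = rmsnorm_fwd(x, signs, exps)
dx = rmsnorm_bwd(dy, x_norm, rms, w)
```

A finite-difference comparison on `sum(dy * out)` is a cheap way to confirm the backward expression before comparing a kernel against it.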
ByteEmbedding Triton kernels
- Added `_triton_ternary_embed_fwd_kernel`: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via `sign * exp2(E)`, and writes the output embedding vectors. Uses linear indexing into the flat packed buffer (`lin = idx * DIM + d`, `pack_idx = lin // 5`).
- Added `_triton_ternary_embed_bwd_accum_kernel`: scatter-adds gradient contributions from each index position into a float accumulator buffer shaped `[VOCAB, DIM]` using `tl.atomic_add`.
- Added `_triton_ternary_embed_bwd_sign_kernel`: converts the float accumulator to int8 sign (1 / 0 / -1) in a single pass over `[VOCAB, DIM]`.
- Added `_triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim)`: two-pass helper that runs the accum kernel then the sign kernel.
- Updated `_TritonTernaryEmbedFn.backward()`: no longer materializes float `w_eff`. Instead calls `_triton_ternary_embed_grad_sign()` to compute `grad_sign` directly into int8 via atomic scatter-add + sign. Still unpacks `T` for `_hook_T` (used by the `update_E` CPU-style path on ByteEmbedding).
Correctness: forward max diff 0.0 vs CPU reference, grad_sign 100% match vs CPU for all 6 TScaleTypes including duplicate indices.
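The linear-index decode used by the forward kernel can be illustrated in pure Python. This sketch assumes a layout that the source does not spell out: five trits per byte stored as base-3 digits (digit = trit + 1, least significant first), and one exponent per weight element. The helper names are hypothetical.

```python
def pack_trits(trits):
    """Pack values in {-1, 0, 1} five per byte as base-3 digits (assumed layout)."""
    packed = []
    for start in range(0, len(trits), 5):
        byte = 0
        for pos, t in enumerate(trits[start:start + 5]):
            byte += (t + 1) * 3 ** pos
        packed.append(byte)  # at most 242, so it fits in one uint8
    return packed

def unpack_trit(packed, lin):
    """Decode the trit at flat position `lin`: pack_idx = lin // 5, slot = lin % 5."""
    byte = packed[lin // 5]
    digit = (byte // 3 ** (lin % 5)) % 3
    return digit - 1

def embed_lookup(indices, packed, exps, dim):
    """Gather rows and dequantize via sign * 2**E using lin = idx * dim + d."""
    out = []
    for idx in indices:
        row = []
        for d in range(dim):
            lin = idx * dim + d
            row.append(unpack_trit(packed, lin) * 2.0 ** exps[lin])
        out.append(row)
    return out

# vocab = 3, dim = 4: rows straddle byte boundaries because 4 is not a multiple of 5.
trits = [1, 0, -1, 1, -1, -1, 0, 1, 0, 1, 1, -1]
exps = [0, 1, -1, 2, 0, 0, 3, -2, 1, 1, 0, 0]
packed = pack_trits(trits)
rows = embed_lookup([2, 0, 2], packed, exps, dim=4)
```

Because `DIM` need not be a multiple of 5, a row can start mid-byte, which is why the decode works on the flat index rather than per-row offsets.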
Dead code removal
- Removed `_hook_defer_e_to_ternary_step` dead branch from `ternary_step()`. This branch called the deleted `_triton_ternary_step_score` kernel. The direct path (`_triton_ternary_step_direct`) is now the only Triton ternary-step path.
- Removed the corresponding assertion from `test_cuda_triton_tscale_path` that checked `_hook_defer_e_to_ternary_step` was not set.
- No other code references `_hook_defer_e_to_ternary_step` or the deleted score kernels.
Bug fix: missing _triton_ternary_embed_fwd_kernel
The embed forward kernel definition was accidentally deleted during a previous session's cleanup (the helper _triton_ternary_embed still referenced it, causing NameError at runtime). Reconstructed with correct linear-index packed ternary decoding. The kernel was verified from scratch: exact match vs CPU for all 6 TScaleTypes.
trigram.py
No changes to model code in this pass. The ByteEmbedding and TernaryRMSNorm classes already called the Triton path via `_TritonTernaryEmbedFn` and `_TritonRMSNormFn`; they just were not working because the kernel definitions were missing.
testing/test_tscale.py
Added 5 low-level Triton vs CPU reference correctness tests:
| Test | What it checks |
|---|---|
| `test_cuda_triton_correctness_linear` | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| `test_cuda_triton_correctness_rmsnorm` | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| `test_cuda_triton_correctness_embedding` | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_update_E` | E update exact match vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_ternary_step` | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |
All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs/state updates.
Kernel Inventory
TernaryScaleTensor (from Refactor 2)
| Kernel | Purpose | Input | Output |
|---|---|---|---|
| `_triton_ternary_fwd_kernel` | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| `_triton_ternary_grad_x_kernel` | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| `_triton_ternary_grad_sign_kernel` | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| `_triton_update_e_kernel` | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| `_triton_update_e_direct_kernel` | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| `_triton_ternary_step_kernel` | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| `_triton_ternary_step_direct_kernel` | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |
TernaryRMSNorm (new in Refactor 3)
| Kernel | Purpose | Input | Output |
|---|---|---|---|
| `_triton_rmsnorm_fwd_kernel` | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| `_triton_rmsnorm_bwd_kernel` | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |
ByteEmbedding (new in Refactor 3)
| Kernel | Purpose | Input | Output |
|---|---|---|---|
| `_triton_ternary_embed_fwd_kernel` | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| `_triton_ternary_embed_bwd_accum_kernel` | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| `_triton_ternary_embed_bwd_sign_kernel` | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |
Autograd Functions
| Function | Module | Forward | Backward |
|---|---|---|---|
| `_TritonTernaryLinearFn` | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| `_TritonRMSNormFn` | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| `_TritonTernaryEmbedFn` | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |
Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)
| Kernel | Reason |
|---|---|
| `_triton_ternary_step_score_kernel` | Replaced by direct group-owned E kernel |
| `_triton_ternary_step_score_block_kernel` | Same |
| `_triton_apply_e_score_kernel` | Same |
| `_triton_apply_e_score` | Same |
| `_triton_ternary_step_score` helper | Same |
Data Flow
Forward (CUDA)
ByteEmbedding:
indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed
TernaryScaleTensor (linear):
x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]
TernaryRMSNorm:
x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]
Backward (CUDA)
ByteEmbedding:
grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
(no float w_eff materialized; T unpacked for _hook_T only)
TernaryScaleTensor (linear):
grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
(grad_2d, x_2d retained for update_E and ternary_step)
TernaryRMSNorm:
grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
(no grad captured for frozen weights)
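The ByteEmbedding backward flow (scatter-add, then sign) matters mostly for duplicate indices, where the sign of the summed gradient can differ from the sum of per-position signs. A minimal sequential sketch, with a hypothetical function name and Python lists standing in for the `[V, D]` accumulator:

```python
def embed_grad_sign(indices, grad_out, vocab, dim):
    """Two-pass reference: accumulate per-vocab gradient sums, then take the sign."""
    accum = [[0.0] * dim for _ in range(vocab)]
    # Pass 1: scatter-add. On the GPU this is the atomic-add step.
    for idx, g_row in zip(indices, grad_out):
        for d in range(dim):
            accum[idx][d] += g_row[d]
    # Pass 2: map each float sum to -1, 0, or 1.
    return [[(v > 0) - (v < 0) for v in row] for row in accum]

# Index 1 appears twice with opposing gradients in column 0.
grad_sign = embed_grad_sign(
    indices=[1, 1, 0],
    grad_out=[[1.0, -2.0], [-3.0, 2.0], [0.5, 0.0]],
    vocab=3,
    dim=2,
)
```

Here row 1, column 0 sums to -2.0 and yields -1, whereas adding the two per-position signs would give 0. Rows that are never indexed stay at 0.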
State Updates (CUDA)
update_E:
_triton_update_e_direct(T_packed, grad_2d, x_2d, E)
- Computes sign(grad^T @ x) internally
- Reads packed T to get current ternary signs
- Sums grad_sign * T per group
- Applies delta = -sign(group_score) to E
ternary_step:
_triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
- Computes sign(grad^T @ x) internally
- Updates T_accum += grad_sign
- Flips T where |T_accum| > threshold
- Repacks T_packed in-place
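The two state updates above can be sketched on unpacked data in pure Python. Several details here are assumptions, not taken from the kernels: one E entry per contiguous group of `group` columns, no int8 saturation, and hypothetical function names. The sketch stops at marking which positions exceed the threshold, since the flip target and repack format are not described in this section.

```python
def sign(v):
    return (v > 0) - (v < 0)

def grad_sign_matrix(grad_y, x):
    """sign(grad_y^T @ x) with grad_y as [M][N] and x as [M][K]; result is [N][K]."""
    m, n, k = len(grad_y), len(grad_y[0]), len(x[0])
    return [
        [sign(sum(grad_y[i][r] * x[i][c] for i in range(m))) for c in range(k)]
        for r in range(n)
    ]

def update_e(t, e, gs, group):
    """In place: E -= sign(sum over group of grad_sign * T). Group layout is assumed."""
    for r in range(len(t)):
        for j in range(len(e[r])):
            cols = range(j * group, (j + 1) * group)
            score = sum(gs[r][c] * t[r][c] for c in cols)
            e[r][j] -= sign(score)

def accumulate_and_mark(t_accum, gs, threshold):
    """In place: T_accum += grad_sign. Returns a mask of |T_accum| > threshold."""
    mask = []
    for r in range(len(t_accum)):
        row = []
        for c in range(len(t_accum[r])):
            t_accum[r][c] += gs[r][c]
            row.append(abs(t_accum[r][c]) > threshold)
        mask.append(row)
    return mask

grad_y = [[1.0, -1.0], [2.0, 0.5]]                     # M=2, N=2
x = [[1.0, 0.0, -1.0, 2.0], [0.5, 1.0, 1.0, -1.0]]     # M=2, K=4
t = [[1, 1, 0, -1], [-1, 1, -1, 0]]                    # unpacked ternary, [N][K]
e = [[0, 0], [3, -2]]                                  # one exponent per group of 2
t_accum = [[3, 0, -3, 0], [-3, 2, 3, 3]]

gs = grad_sign_matrix(grad_y, x)
update_e(t, e, gs, group=2)
flip_mask = accumulate_and_mark(t_accum, gs, threshold=3)
```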
Float Materialization Audit
| Path | Float tensors created | Persistent? |
|---|---|---|
| TernaryScaleTensor CUDA forward | y[M,N] float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | grad_x[M,K] float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | out[B,D] float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | grad_x[B,D] float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | out[N,D] float32 output | Ephemeral |
| ByteEmbedding CUDA backward | accum[V,D] float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | grad_sign[V,D] int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | w_eff, S, T.float() | Ephemeral (via detach+requires_grad) |
No persistent float parameters or float optimizer state in strict ternary mode.
Loss Spike Investigation
Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. Root cause analysis:
Primary cause: mass T-flip at step 2
With T_accum initialized to zeros and accum_threshold=3:
- Step 1: `T_accum` goes from 0 to +1 or -1
- Step 2: `T_accum` reaches 2 or 3+
- When `T_accum` hits threshold, correlated initial gradients cause thousands of simultaneous T sign flips
This is a catastrophic weight change at step 2 because the model's learned representations from step 1 are invalidated en masse.
Secondary: E-then-T ordering
update_E() runs before ternary_step(). After E is updated based on pre-flip T, ternary_step() may flip T values, making E inconsistent with the new T state.
Tertiary: redundant normalization
The training loop applies `clip_grad_norm_` then `inv_scale = 1/||grad||`. For SignSGD, sign(x) = sign(x/||x||) whenever ||x|| > 0, so `inv_scale` normalization has no effect on the optimizer. The double normalization is dead code for the sign path.
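The scale invariance behind this point is easy to confirm directly. A small pure-Python check, with hypothetical helper names, showing that multiplying a gradient by any positive factor leaves its elementwise sign unchanged:

```python
import math

def sign(v):
    return (v > 0) - (v < 0)

def sign_step(grad):
    """Elementwise sign, which is all a SignSGD-style update reads from the gradient."""
    return [sign(v) for v in grad]

def l2_norm(grad):
    return math.sqrt(sum(v * v for v in grad))

grad = [0.3, -4.0, 0.0, 12.5]
inv_scale = 1.0 / l2_norm(grad)              # assumes a nonzero gradient
normalized = [v * inv_scale for v in grad]

# Positive rescaling keeps every sign, including exact zeros.
same_signs = sign_step(grad) == sign_step(normalized)
```

The same argument covers norm clipping, which also rescales by a positive factor, so neither step can change what the sign path sees.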
Suggested fixes
- Warmup `accum_threshold`: start at 7+ and decay to 3 over ~100 steps
- Swap update order: call `ternary_step()` before `update_E()` so E sees post-flip T
- Initialize `T_accum` with small random values (e.g. `torch.randint(-2, 3)`) to break synchronization
- Remove redundant `inv_scale` for SignSGD (it does nothing)
- Rate-limit T flips: only flip top-k% of positions per step
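Two of these fixes, threshold warmup and random `T_accum` init, can be sketched in pure Python. The linear decay shape, the function names, and the use of `|T_accum| > threshold` as the flip condition are assumptions for illustration; the toy model drives every position with the same constant gradient sign to mimic fully correlated early gradients.

```python
import random

def warmup_threshold(step, start=7, end=3, warmup_steps=100):
    """Linearly decay the accumulator threshold from `start` to `end` (assumed schedule)."""
    if step >= warmup_steps:
        return end
    return round(start + (end - start) * step / warmup_steps)

def random_accum_init(n, seed=0):
    """Integer init in [-2, 2], the same range as torch.randint(-2, 3)."""
    rng = random.Random(seed)
    return [rng.randint(-2, 2) for _ in range(n)]

def steps_until_flip(acc0, grad_sign, threshold):
    """Steps of acc += grad_sign until |acc| > threshold. grad_sign must be +1 or -1."""
    acc, steps = acc0, 0
    while abs(acc) <= threshold:
        acc += grad_sign
        steps += 1
    return steps

# Zero init: every position crosses the threshold on the same step.
zero_init_steps = {steps_until_flip(0, 1, 3) for _ in range(5)}
# Spread init: crossings are staggered over several steps.
spread_steps = {steps_until_flip(a, 1, 3) for a in [-2, -1, 0, 1, 2]}
```

In this toy model zero init collapses to a single crossing step, while the spread init distributes crossings over five different steps, which is the desynchronization the random init is meant to provide.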
Verification
146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.
New correctness tests verify exact or near-exact match between Triton CUDA and CPU reference paths for:
- TernaryScaleTensor forward + backward (all 6 TScaleTypes)
- TernaryRMSNorm forward + backward (all 6 TScaleTypes)
- ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
- E update (exact match, all 6 TScaleTypes)
- T flip + T_accum (exact match, all 6 TScaleTypes)
Remaining Work
- T_accum warmup or random init to prevent step-2 mass-flip loss spike
- Swap ternary_step/update_E ordering for E/T consistency
- Remove redundant inv_scale in training loop for SignSGD
- Benchmark at target batch=1024 to check if ~45s/step is JIT warmup
- Tune Triton block sizes for production layer dimensions
- Evaluate TileLang against proven Triton kernels for tensor-core path
- ByteEmbedding `_hook_T`: backward still unpacks T for `_hook_T` (needed by the `ByteEmbedding.update_E` CPU path). Could be replaced with a Triton kernel that reads T_packed directly.
- GNNLoRAAdapter `self.B`, MemGram `nn.ParameterList`, GraphMoEGate `self.query`: remaining trainable float params that break strict ternary in non-text-only mode