
True Ternary Refactor 3

Scope

This pass extends the Triton kernel coverage from TernaryScaleTensor (linear layers) to the remaining model components: TernaryRMSNorm (normalization) and ByteEmbedding (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.

Refactor 2 established the packed-ternary Triton path for TernaryScaleTensor forward, grad-x, sign-only weight-gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.

What Changed

tscale.py

TernaryRMSNorm Triton kernels

  • Added _triton_rmsnorm_fwd_kernel: fused RMS norm + ternary weight application in a single kernel. Each program loads a batch row of x, computes RMS, normalizes, then loads packed ternary weights and int8 exponents for that dimension, dequantizes via sign * exp2(E), and multiplies into the output.
  • Added _triton_rmsnorm_bwd_kernel: fused backward for RMS norm with ternary weights. Computes dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms without materializing the weight tensor.
  • Added _TritonRMSNormFn(torch.autograd.Function): autograd wrapper that saves x_2d, packed, e for backward. Forward calls _triton_rmsnorm_fwd_kernel, backward calls _triton_rmsnorm_bwd_kernel.
  • TernaryRMSNorm.forward() now prefers the Triton path when input is on CUDA. CPU fallback remains unchanged.
  • TernaryRMSNorm.ternary_step() and TernaryRMSNorm.update_E() remain no-ops (these weights are frozen).

Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.
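The backward formula quoted above can be checked against finite differences without any GPU code. The following is a minimal pure-Python sketch for a single row, not the Triton kernel: the function names, the `eps` stabilizer, and the one-exponent-per-element layout are assumptions made for illustration.

```python
import math

def dequantize(signs, exps):
    """w = sign * 2**E for sign in {-1, 0, 1} and an integer exponent E."""
    return [s * math.ldexp(1.0, e) for s, e in zip(signs, exps)]

def rmsnorm_fwd(x, w, eps=1e-6):
    """Return (y, x_norm, rms) for one row; eps is an assumed stabilizer."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    x_norm = [v / rms for v in x]
    return [n * wi for n, wi in zip(x_norm, w)], x_norm, rms

def rmsnorm_bwd(dy, w, x_norm, rms):
    """dx = (dy*w - x_norm * mean(x_norm * dy * w)) / rms."""
    g = [d * wi for d, wi in zip(dy, w)]
    mean_term = sum(n * gi for n, gi in zip(x_norm, g)) / len(g)
    return [(gi - n * mean_term) / rms for gi, n in zip(g, x_norm)]

def numeric_dx(x, w, dy, h=1e-6):
    """Central finite differences of loss = sum(dy * y) with respect to x."""
    def loss(xs):
        y, _, _ = rmsnorm_fwd(xs, w)
        return sum(d * yi for d, yi in zip(dy, y))
    out = []
    for i in range(len(x)):
        hi = x[:i] + [x[i] + h] + x[i + 1:]
        lo = x[:i] + [x[i] - h] + x[i + 1:]
        out.append((loss(hi) - loss(lo)) / (2 * h))
    return out

x = [0.5, -1.25, 2.0, 0.75]
w = dequantize([1, -1, 0, 1], [0, -1, 2, 1])   # [1.0, -0.5, 0.0, 2.0]
dy = [0.3, -0.7, 1.1, 0.2]
_, x_norm, rms = rmsnorm_fwd(x, w)
analytic = rmsnorm_bwd(dy, w, x_norm, rms)
max_err = max(abs(a - n) for a, n in zip(analytic, numeric_dx(x, w, dy)))
```

Here `max_err` is the largest gap between the closed-form gradient and the numerical one. The weight is only ever used as `dy * w`, which is why a fused kernel can dequantize on the fly instead of materializing a weight tensor.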

ByteEmbedding Triton kernels

  • Added _triton_ternary_embed_fwd_kernel: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via sign * exp2(E), and writes the output embedding vectors. Uses linear indexing into the flat packed buffer (lin = idx * DIM + d, pack_idx = lin // 5).
  • Added _triton_ternary_embed_bwd_accum_kernel: scatter-adds gradient contributions from each index position into a float accumulator buffer shaped [VOCAB, DIM] using tl.atomic_add.
  • Added _triton_ternary_embed_bwd_sign_kernel: converts the float accumulator to int8 sign (1 / 0 / -1) in a single pass over [VOCAB, DIM].
  • Added _triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim): two-pass helper that runs the accum kernel then the sign kernel.
  • Updated _TritonTernaryEmbedFn.backward(): no longer materializes float w_eff. Instead calls _triton_ternary_embed_grad_sign() to compute grad_sign directly into int8 via atomic scatter-add + sign. Still unpacks T for _hook_T (used by update_E CPU-style path on ByteEmbedding).

Correctness: forward max diff 0.0 vs CPU reference, grad_sign 100% match vs CPU for all 6 TScaleTypes including duplicate indices.
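The two-pass grad_sign computation (scatter-add, then sign) can be illustrated with a sequential pure-Python reference. This is a sketch of the logic only: `embed_grad_sign_reference` is a hypothetical name, and the loop stands in for what the kernels do with `tl.atomic_add`.

```python
def embed_grad_sign_reference(indices, grad_output, vocab, dim):
    """Accumulate per-vocab gradients, then reduce each entry to -1 / 0 / 1."""
    # Pass 1: scatter-add. Rows that share an index land in the same slot,
    # which is why the GPU version needs atomic adds.
    accum = [[0.0] * dim for _ in range(vocab)]
    for row, idx in enumerate(indices):
        for d in range(dim):
            accum[idx][d] += grad_output[row][d]
    # Pass 2: sign over the whole [vocab, dim] accumulator.
    return [[(v > 0) - (v < 0) for v in r] for r in accum]

indices = [2, 0, 2]                        # index 2 appears twice
grad_output = [[1.0, -2.0], [0.5, 0.0], [-3.0, 2.0]]
grad_sign = embed_grad_sign_reference(indices, grad_output, vocab=4, dim=2)
# grad_sign == [[1, 0], [0, 0], [-1, 0], [0, 0]]
```

The duplicate index matters: row 2 receives `1.0 + (-3.0)` and `-2.0 + 2.0`, so its signs are `-1` and `0`, not the sign of either contribution alone. Floating-point addition is not associative, so in general an atomic scatter-add can differ from a sequential sum in the last bits, which only affects the sign when a sum is extremely close to zero.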

Dead code removal

  • Removed _hook_defer_e_to_ternary_step dead branch from ternary_step(). This branch called the deleted _triton_ternary_step_score kernel. The direct path (_triton_ternary_step_direct) is now the only Triton ternary-step path.
  • Removed the corresponding assertion from test_cuda_triton_tscale_path that checked _hook_defer_e_to_ternary_step was not set.
  • No other code references _hook_defer_e_to_ternary_step or the deleted score kernels.

Bug fix: missing _triton_ternary_embed_fwd_kernel

The embed forward kernel definition was accidentally deleted during a previous session's cleanup (the helper _triton_ternary_embed still referenced it, causing NameError at runtime). Reconstructed with correct linear-index packed ternary decoding. The kernel was verified from scratch: exact match vs CPU for all 6 TScaleTypes.
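The linear-index decoding can be sketched in pure Python. The `lin = idx * DIM + d` and `pack_idx = lin // 5` arithmetic comes from the notes above; everything else is an assumption made for illustration, in particular the base-3 digit order within a byte, the `trit + 1` digit encoding, and one exponent per element. The real packed layout may differ.

```python
import math

def pack_trits(trits):
    """Pack trits in {-1, 0, 1} five per byte as base-3 digits (assumed layout)."""
    packed = bytearray((len(trits) + 4) // 5)
    for lin, t in enumerate(trits):
        packed[lin // 5] += (t + 1) * 3 ** (lin % 5)   # max byte value is 242
    return bytes(packed)

def unpack_trit(packed, lin):
    """Recover the trit at flat position lin from its packed byte."""
    pack_idx, slot = lin // 5, lin % 5
    return (packed[pack_idx] // 3 ** slot) % 3 - 1

def embed_lookup(indices, packed, exps, dim):
    """Gather rows and dequantize each element as sign * 2**E."""
    out = []
    for idx in indices:
        row = []
        for d in range(dim):
            lin = idx * dim + d                         # flat element index
            row.append(unpack_trit(packed, lin) * math.ldexp(1.0, exps[lin]))
        out.append(row)
    return out

trits = [1, 0, -1, 1,  -1, -1, 0, 1,  0, 1, 1, -1]      # vocab=3, dim=4
exps = [0, 1, -1, 2,   0, 0, 3, -2,   1, 1, 0, 0]
packed = pack_trits(trits)                               # 12 trits -> 3 bytes
emb = embed_lookup([2, 0], packed, exps, dim=4)
# emb == [[0.0, 2.0, 1.0, -1.0], [1.0, 0.0, -0.5, 4.0]]
```

Because `dim` is not a multiple of 5 here, rows straddle byte boundaries. That is the case a row-aligned decoder would get wrong and a flat linear index handles uniformly.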

trigram.py

No changes to model code in this pass. The ByteEmbedding and TernaryRMSNorm classes already called the Triton path via _TritonTernaryEmbedFn and _TritonRMSNormFn; they were failing only because the kernel definitions were missing.

testing/test_tscale.py

Added 5 low-level Triton vs CPU reference correctness tests:

| Test | What it checks |
| --- | --- |
| test_cuda_triton_correctness_linear | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| test_cuda_triton_correctness_rmsnorm | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| test_cuda_triton_correctness_embedding | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| test_cuda_triton_correctness_update_E | E update exact match vs CPU for all 6 TScaleTypes |
| test_cuda_triton_correctness_ternary_step | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |

All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs/state updates.

Kernel Inventory

TernaryScaleTensor (from Refactor 2)

| Kernel | Purpose | Input | Output |
| --- | --- | --- | --- |
| _triton_ternary_fwd_kernel | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| _triton_ternary_grad_x_kernel | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| _triton_ternary_grad_sign_kernel | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| _triton_update_e_kernel | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| _triton_update_e_direct_kernel | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| _triton_ternary_step_kernel | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| _triton_ternary_step_direct_kernel | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |

TernaryRMSNorm (new in Refactor 3)

| Kernel | Purpose | Input | Output |
| --- | --- | --- | --- |
| _triton_rmsnorm_fwd_kernel | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| _triton_rmsnorm_bwd_kernel | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |

ByteEmbedding (new in Refactor 3)

| Kernel | Purpose | Input | Output |
| --- | --- | --- | --- |
| _triton_ternary_embed_fwd_kernel | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| _triton_ternary_embed_bwd_accum_kernel | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| _triton_ternary_embed_bwd_sign_kernel | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |

Autograd Functions

| Function | Module | Forward | Backward |
| --- | --- | --- | --- |
| _TritonTernaryLinearFn | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| _TritonRMSNormFn | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| _TritonTernaryEmbedFn | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |

Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)

| Kernel | Reason |
| --- | --- |
| _triton_ternary_step_score_kernel | Replaced by direct group-owned E kernel |
| _triton_ternary_step_score_block_kernel | Same |
| _triton_apply_e_score_kernel | Same |
| _triton_apply_e_score | Same |
| _triton_ternary_step_score helper | Same |

Data Flow

Forward (CUDA)

ByteEmbedding:
  indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed

TernaryScaleTensor (linear):
  x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]

TernaryRMSNorm:
  x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]

Backward (CUDA)

ByteEmbedding:
  grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
  (no float w_eff materialized; T unpacked for _hook_T only)

TernaryScaleTensor (linear):
  grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
  (grad_2d, x_2d retained for update_E and ternary_step)

TernaryRMSNorm:
  grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
  (no grad captured for frozen weights)

State Updates (CUDA)

update_E:
  _triton_update_e_direct(T_packed, grad_2d, x_2d, E)
  - Computes sign(grad^T @ x) internally
  - Reads packed T to get current ternary signs
  - Sums grad_sign * T per group
  - Applies delta = -sign(group_score) to E

ternary_step:
  _triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
  - Computes sign(grad^T @ x) internally
  - Updates T_accum += grad_sign
  - Flips T where |T_accum| > threshold
  - Repacks T_packed in-place
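
The two state updates above can be written out as a small pure-Python reference operating on unpacked lists. This is a sketch under stated assumptions rather than the kernels' actual behavior: groups are taken to be `group` consecutive columns within a row, a "flip" is taken to mean setting T to `-sign(T_accum)`, and the accumulator is reset to zero after a flip. None of these details are specified in the notes, and packing/repacking is omitted.

```python
def sign(v):
    return (v > 0) - (v < 0)

def grad_sign(grad, x):
    """sign(grad^T @ x): grad is [M][N], x is [M][K], result is [N][K]."""
    n_out, k_in = len(grad[0]), len(x[0])
    return [[sign(sum(g[n] * r[k] for g, r in zip(grad, x)))
             for k in range(k_in)] for n in range(n_out)]

def update_E(T, gs, E, group):
    """Per group: score = sum(grad_sign * T), then E += -sign(score)."""
    for n, row in enumerate(T):
        for j in range(len(row) // group):
            cols = range(j * group, (j + 1) * group)   # assumed group layout
            score = sum(gs[n][k] * row[k] for k in cols)
            E[n][j] -= sign(score)

def ternary_step(T, gs, T_accum, threshold):
    """T_accum += grad_sign, then flip T where |T_accum| > threshold."""
    for n, row in enumerate(T):
        for k in range(len(row)):
            T_accum[n][k] += gs[n][k]
            if abs(T_accum[n][k]) > threshold:
                row[k] = -sign(T_accum[n][k])          # assumed flip rule
                T_accum[n][k] = 0                      # assumed reset

grad = [[1.0, -1.0]]                  # M=1, N=2
x = [[2.0, -3.0, 0.0, 1.0]]           # M=1, K=4
gs = grad_sign(grad, x)               # [[1, -1, 0, 1], [-1, 1, 0, -1]]
T = [[1, -1, -1, 1], [1, -1, 1, 1]]
E = [[0, 0], [0, 0]]
T_accum = [[0] * 4 for _ in range(2)]
update_E(T, gs, E, group=2)           # E is now [[-1, -1], [1, 1]]
ternary_step(T, gs, T_accum, threshold=1)   # |T_accum| == 1, nothing flips
ternary_step(T, gs, T_accum, threshold=1)   # |T_accum| == 2, flips fire
```

In this toy run every position with a nonzero gradient sign crosses the threshold on the same call, which is the synchronization effect the loss spike analysis describes.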

Float Materialization Audit

| Path | Float tensors created | Persistent? |
| --- | --- | --- |
| TernaryScaleTensor CUDA forward | y[M,N] float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | grad_x[M,K] float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | out[B,D] float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | grad_x[B,D] float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | out[N,D] float32 output | Ephemeral |
| ByteEmbedding CUDA backward | accum[V,D] float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | grad_sign[V,D] int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | w_eff, S, T.float() | Ephemeral (via detach+requires_grad) |

No persistent float parameters or float optimizer state in strict ternary mode.

Loss Spike Investigation

Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. Root cause analysis:

Primary cause: mass T-flip at step 2

With T_accum initialized to zeros and accum_threshold=3:

  • Step 1: T_accum goes from 0 to +1 or -1
  • Step 2: T_accum reaches 2 or 3+
  • When T_accum hits threshold, correlated initial gradients cause thousands of simultaneous T sign flips

This is a catastrophic weight change at step 2 because the model's learned representations from step 1 are invalidated en masse.

Secondary: E-then-T ordering

update_E() runs before ternary_step(). After E is updated based on pre-flip T, ternary_step() may flip T values, making E inconsistent with the new T state.

Tertiary: redundant normalization

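The training loop applies clip_grad_norm_ and then scales by inv_scale = 1/||grad||. Both are multiplications by a positive scalar, and sign(c * x) = sign(x) for any c > 0. For SignSGD this means sign(x) = sign(x/||x||), so the inv_scale normalization has no effect on the optimizer. The double normalization is dead code for the sign path.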
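
The sign-invariance argument can be checked directly. The sketch below uses a simplified stand-in for norm clipping (`clip_scale` is a hypothetical helper, not the PyTorch implementation) followed by the `1/||grad||` scaling, and compares signs before and after.

```python
import math

def sign(v):
    return (v > 0) - (v < 0)

def clip_scale(grad, max_norm):
    """Scalar applied by norm clipping: min(1, max_norm / ||grad||)."""
    norm = math.sqrt(sum(g * g for g in grad))
    return min(1.0, max_norm / norm)

grad = [0.3, -4.0, 0.0, 1e-8, -2.5]
scale = clip_scale(grad, max_norm=1.0)
clipped = [g * scale for g in grad]
inv_scale = 1.0 / math.sqrt(sum(g * g for g in clipped))
normalized = [g * inv_scale for g in clipped]

# Positive rescaling never changes which side of zero an entry is on.
same_signs = [sign(g) for g in grad] == [sign(g) for g in normalized]
```

The equality holds as long as the norm is finite and nonzero and no entry underflows to zero. An all-zero gradient would make `inv_scale` undefined, so removing the scaling also removes that edge case.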

Suggested fixes

  1. Warmup accum_threshold: start at 7+ and decay to 3 over ~100 steps
  2. Swap update order: call ternary_step() before update_E() so E sees post-flip T
  3. Initialize T_accum with small random values (e.g. torch.randint(-2, 3, T_accum.shape), which draws integers in [-2, 2]) to break synchronization
  4. Remove redundant inv_scale for SignSGD (it does nothing)
  5. Rate-limit T flips: only flip top-k% of positions per step
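
Fixes 1 and 3 can be explored with a toy model before touching the training loop. The sketch below is an illustration under assumptions, not a reproduction of the observed spike: each position receives the same ±1 gradient sign every step (fully correlated gradients), a flip resets the accumulator to zero, and the warmup schedule is linear. The function names and defaults are hypothetical.

```python
import random

def accum_threshold(step, start=7, end=3, warmup_steps=100):
    """Linear decay of the flip threshold from start to end (assumed shape)."""
    if step >= warmup_steps:
        return end
    return round(start + (end - start) * step / warmup_steps)

def flips_per_step(init, grad_sign, threshold, steps):
    """Count how many positions cross the threshold at each step."""
    accum = list(init)
    counts = []
    for _ in range(steps):
        n = 0
        for i, g in enumerate(grad_sign):
            accum[i] += g
            if abs(accum[i]) > threshold:
                accum[i] = 0          # assumed reset after a flip
                n += 1
        counts.append(n)
    return counts

rng = random.Random(0)
n = 1000
grad_sign = [rng.choice((-1, 1)) for _ in range(n)]

# Zero init: every position crosses the threshold on the same step.
zero_init = flips_per_step([0] * n, grad_sign, threshold=3, steps=8)

# Random init in [-2, 2]: crossings are spread over several steps.
rand_init = flips_per_step([rng.randint(-2, 2) for _ in range(n)],
                           grad_sign, threshold=3, steps=8)
```

With zero initialization the peak of `zero_init` equals `n`, meaning all positions flip together. With random initialization the same total number of crossings is distributed across steps, so the per-step peak is lower. A higher starting threshold from `accum_threshold` delays the first crossing but does not by itself desynchronize positions.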

Verification

146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.

New correctness tests verify exact or near-exact match between Triton CUDA and CPU reference paths for:

  • TernaryScaleTensor forward + backward (all 6 TScaleTypes)
  • TernaryRMSNorm forward + backward (all 6 TScaleTypes)
  • ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
  • E update (exact match, all 6 TScaleTypes)
  • T flip + T_accum (exact match, all 6 TScaleTypes)

Remaining Work

  1. T_accum warmup or random init to prevent step-2 mass-flip loss spike
  2. Swap ternary_step/update_E ordering for E/T consistency
  3. Remove redundant inv_scale in training loop for SignSGD
  4. Benchmark at target batch=1024 to check if ~45s/step is JIT warmup
  5. Tune Triton block sizes for production layer dimensions
  6. Evaluate TileLang against proven Triton kernels for tensor-core path
  7. ByteEmbedding _hook_T: backward still unpacks T for _hook_T (needed by ByteEmbedding.update_E CPU path). Could be replaced with a Triton kernel that reads T_packed directly.
  8. GNNLoRAAdapter self.B, MemGram nn.ParameterList, GraphMoEGate self.query: remaining trainable float params that break strict ternary in non-text-only mode