# True Ternary Refactor 3
## Scope
This pass extends the Triton kernel coverage from `TernaryScaleTensor` (linear layers) to the remaining model components: `TernaryRMSNorm` (normalization) and `ByteEmbedding` (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.
Refactor 2 established the packed-ternary Triton path for `TernaryScaleTensor` forward, grad-x, sign-only weight-gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.
## What Changed
### `tscale.py`
#### TernaryRMSNorm Triton kernels
- Added `_triton_rmsnorm_fwd_kernel`: fused RMS norm + ternary weight application in a single kernel. Each program loads a batch row of `x`, computes RMS, normalizes, then loads packed ternary weights and int8 exponents for that dimension, dequantizes via `sign * exp2(E)`, and multiplies into the output.
- Added `_triton_rmsnorm_bwd_kernel`: fused backward for RMS norm with ternary weights. Computes `dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms` without materializing the weight tensor.
- Added `_TritonRMSNormFn(torch.autograd.Function)`: autograd wrapper that saves `x_2d`, `packed`, `e` for backward. Forward calls `_triton_rmsnorm_fwd_kernel`, backward calls `_triton_rmsnorm_bwd_kernel`.
- `TernaryRMSNorm.forward()` now prefers the Triton path when input is on CUDA. CPU fallback remains unchanged.
- `TernaryRMSNorm.ternary_step()` and `TernaryRMSNorm.update_E()` remain no-ops (these weights are frozen).
Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.
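As a rough illustration of the math the two RMSNorm kernels fuse, the following pure-Python sketch computes the forward pass and the backward formula quoted above for a single row, then checks `dx` against a central finite difference. The function names, the `eps` term, and the per-element exponent layout are assumptions made for this sketch; they are not taken from `tscale.py`.

```python
import math

def rmsnorm_forward(x, w, eps=1e-6):
    """One row: out = w * x / rms, with rms = sqrt(mean(x^2) + eps) (eps is assumed)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    x_norm = [v / rms for v in x]
    out = [wi * xn for wi, xn in zip(w, x_norm)]
    return out, x_norm, rms

def rmsnorm_backward(dy, w, x_norm, rms):
    """dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms."""
    g = [dyi * wi for dyi, wi in zip(dy, w)]
    m = sum(xn * gi for xn, gi in zip(x_norm, g)) / len(g)
    return [(gi - xn * m) / rms for gi, xn in zip(g, x_norm)]

# Ternary weights dequantized as sign * 2**E (one exponent per element here).
signs = [1, -1, 0, 1]
exps = [0, -1, 2, 1]
w = [s * 2.0 ** e for s, e in zip(signs, exps)]

x = [0.5, -1.25, 2.0, 0.75]
dy = [1.0, 0.5, -0.25, 2.0]

out, x_norm, rms = rmsnorm_forward(x, w)
dx = rmsnorm_backward(dy, w, x_norm, rms)

def scalar_loss(xv):
    """Loss whose gradient w.r.t. the output is exactly dy."""
    o, _, _ = rmsnorm_forward(xv, w)
    return sum(a * b for a, b in zip(o, dy))

# Central finite difference as an independent reference for dx.
h = 1e-6
dx_fd = []
for i in range(len(x)):
    xp = list(x)
    xm = list(x)
    xp[i] += h
    xm[i] -= h
    dx_fd.append((scalar_loss(xp) - scalar_loss(xm)) / (2 * h))
```

The weight vector `w` is built explicitly here only for readability; the kernels described above decode it on the fly instead of materializing it.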
#### ByteEmbedding Triton kernels
- Added `_triton_ternary_embed_fwd_kernel`: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via `sign * exp2(E)`, and writes the output embedding vectors. Uses linear indexing into the flat packed buffer (`lin = idx * DIM + d`, `pack_idx = lin // 5`).
- Added `_triton_ternary_embed_bwd_accum_kernel`: scatter-adds gradient contributions from each index position into a float accumulator buffer shaped `[VOCAB, DIM]` using `tl.atomic_add`.
- Added `_triton_ternary_embed_bwd_sign_kernel`: converts the float accumulator to int8 sign (`1 / 0 / -1`) in a single pass over `[VOCAB, DIM]`.
- Added `_triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim)`: two-pass helper that runs the accum kernel then the sign kernel.
- Updated `_TritonTernaryEmbedFn.backward()`: no longer materializes float `w_eff`. Instead calls `_triton_ternary_embed_grad_sign()` to compute `grad_sign` directly into int8 via atomic scatter-add + sign. Still unpacks `T` for `_hook_T` (used by `update_E` CPU-style path on ByteEmbedding).
Correctness: forward max diff 0.0 vs CPU reference, grad_sign 100% match vs CPU for all 6 TScaleTypes including duplicate indices.
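The linear-index decoding used by the embed forward kernel can be sketched in pure Python. Only `lin = idx * DIM + d` and `pack_idx = lin // 5` come from the description above. The base-3 digit order, the `{0, 1, 2} -> {-1, 0, +1}` mapping, the one-exponent-per-element layout, and the helper names `pack_trits`, `unpack_trit`, and `embed_lookup` are assumptions for illustration.

```python
def pack_trits(trits):
    """Pack a flat list of {-1, 0, 1} into bytes, 5 trits per byte in base 3."""
    packed = []
    for start in range(0, len(trits), 5):
        byte = 0
        for pos, t in enumerate(trits[start:start + 5]):
            byte += (t + 1) * 3 ** pos  # assumed: least-significant digit first
        packed.append(byte)
    return packed

def unpack_trit(packed, lin):
    """Decode the trit at flat position `lin` without unpacking the whole buffer."""
    pack_idx, slot = divmod(lin, 5)
    return (packed[pack_idx] // 3 ** slot) % 3 - 1

def embed_lookup(indices, packed, exps, dim):
    """Gather rows and dequantize each element as sign * 2**E."""
    out = []
    for idx in indices:
        row = []
        for d in range(dim):
            lin = idx * dim + d
            row.append(unpack_trit(packed, lin) * 2.0 ** exps[lin])
        out.append(row)
    return out

VOCAB, DIM = 4, 3
T = [1, 0, -1,   -1, -1, 1,   0, 0, 1,   1, 1, -1]   # flat [VOCAB * DIM]
E = [0, 1, -1,    2, 0, 0,    1, 1, -2,  0, -1, 3]   # flat [VOCAB * DIM]

packed = pack_trits(T)
emb = embed_lookup([2, 0, 2], packed, E, DIM)
```

Because rows are addressed through the flat index, a row may straddle a byte boundary whenever `DIM` is not a multiple of 5, which is why the sketch never assumes row-aligned packing.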
#### Dead code removal
- Removed `_hook_defer_e_to_ternary_step` dead branch from `ternary_step()`. This branch called the deleted `_triton_ternary_step_score` kernel. The direct path (`_triton_ternary_step_direct`) is now the only Triton ternary-step path.
- Removed the corresponding assertion from `test_cuda_triton_tscale_path` that checked `_hook_defer_e_to_ternary_step` was not set.
- No other code references `_hook_defer_e_to_ternary_step` or the deleted score kernels.
#### Bug fix: missing `_triton_ternary_embed_fwd_kernel`
The embed forward kernel definition was accidentally deleted during a previous session's cleanup (the helper `_triton_ternary_embed` still referenced it, causing `NameError` at runtime). Reconstructed with correct linear-index packed ternary decoding. The kernel was verified from scratch: exact match vs CPU for all 6 TScaleTypes.
### `trigram.py`
No changes to model code in this pass. The `ByteEmbedding` and `TernaryRMSNorm` classes already called the Triton path via `_TritonTernaryEmbedFn` and `_TritonRMSNormFn`; they just were not working because the kernel definitions were missing.
### `testing/test_tscale.py`
Added 5 low-level Triton vs CPU reference correctness tests:
| Test | What it checks |
|------|---------------|
| `test_cuda_triton_correctness_linear` | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| `test_cuda_triton_correctness_rmsnorm` | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| `test_cuda_triton_correctness_embedding` | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_update_E` | E update exact match vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_ternary_step` | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |
All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs/state updates.
## Kernel Inventory
### TernaryScaleTensor (from Refactor 2)
| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_fwd_kernel` | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| `_triton_ternary_grad_x_kernel` | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| `_triton_ternary_grad_sign_kernel` | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| `_triton_update_e_kernel` | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| `_triton_update_e_direct_kernel` | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| `_triton_ternary_step_kernel` | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| `_triton_ternary_step_direct_kernel` | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |
### TernaryRMSNorm (new in Refactor 3)
| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_rmsnorm_fwd_kernel` | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| `_triton_rmsnorm_bwd_kernel` | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |
### ByteEmbedding (new in Refactor 3)
| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_embed_fwd_kernel` | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| `_triton_ternary_embed_bwd_accum_kernel` | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| `_triton_ternary_embed_bwd_sign_kernel` | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |
### Autograd Functions
| Function | Module | Forward | Backward |
|----------|--------|---------|----------|
| `_TritonTernaryLinearFn` | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| `_TritonRMSNormFn` | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| `_TritonTernaryEmbedFn` | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |
### Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)
| Kernel | Reason |
|--------|--------|
| `_triton_ternary_step_score_kernel` | Replaced by direct group-owned E kernel |
| `_triton_ternary_step_score_block_kernel` | Same |
| `_triton_apply_e_score_kernel` | Same |
| `_triton_apply_e_score` | Same |
| `_triton_ternary_step_score` helper | Same |
## Data Flow
### Forward (CUDA)
```text
ByteEmbedding:
indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed
TernaryScaleTensor (linear):
x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]
TernaryRMSNorm:
x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]
```
### Backward (CUDA)
```text
ByteEmbedding:
grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
(no float w_eff materialized; T unpacked for _hook_T only)
TernaryScaleTensor (linear):
grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
(grad_2d, x_2d retained for update_E and ternary_step)
TernaryRMSNorm:
grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
(no grad captured for frozen weights)
```
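The two-pass ByteEmbedding backward (scatter-add, then sign) can be mirrored by a sequential pure-Python reference. This is a sketch of the arithmetic only: `embed_grad_sign` is a hypothetical name, and the loop stands in for the `tl.atomic_add` accumulation rather than reproducing it.

```python
def sign(v):
    """Return 1, 0, or -1."""
    return (v > 0) - (v < 0)

def embed_grad_sign(indices, grad_out, vocab, dim):
    """Pass 1: scatter-add grad rows into accum[vocab][dim]. Pass 2: take the sign."""
    accum = [[0.0] * dim for _ in range(vocab)]
    for n, idx in enumerate(indices):
        for d in range(dim):
            accum[idx][d] += grad_out[n][d]  # duplicate indices add into the same row
    return [[sign(v) for v in row] for row in accum]

indices = [1, 3, 1]  # index 1 appears twice
grad_out = [[0.5, -1.0], [2.0, 0.0], [-0.5, -0.25]]

grad_sign = embed_grad_sign(indices, grad_out, vocab=4, dim=2)
```

The duplicate index shows why the sign has to be taken after accumulation: the two contributions to row 1, column 0 cancel to zero, which taking the sign per position first would not capture.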
### State Updates (CUDA)
```text
update_E:
_triton_update_e_direct(T_packed, grad_2d, x_2d, E)
- Computes sign(grad^T @ x) internally
- Reads packed T to get current ternary signs
- Sums grad_sign * T per group
- Applies delta = -sign(group_score) to E
ternary_step:
_triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
- Computes sign(grad^T @ x) internally
- Updates T_accum += grad_sign
- Flips T where |T_accum| > threshold
- Repacks T_packed in-place
```
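The `update_E` steps listed above can be written out as a small pure-Python reference. The grouping layout (contiguous groups of `group` weights along K, one exponent per group) and the names `grad_sign_from` and `update_e` are assumptions for this sketch; int8 clamping of `E` is not modeled.

```python
def sign(v):
    """Return 1, 0, or -1."""
    return (v > 0) - (v < 0)

def grad_sign_from(grad, x):
    """sign(grad^T @ x), with grad shaped [M][N] and x shaped [M][K]; result is [N][K]."""
    M, N, K = len(grad), len(grad[0]), len(x[0])
    return [
        [sign(sum(grad[m][n] * x[m][k] for m in range(M))) for k in range(K)]
        for n in range(N)
    ]

def update_e(T, grad, x, E, group):
    """Apply delta = -sign(sum(grad_sign * T)) to each group's exponent, in place."""
    gs = grad_sign_from(grad, x)
    for n, row in enumerate(T):
        for g in range(len(row) // group):
            ks = range(g * group, (g + 1) * group)
            score = sum(gs[n][k] * row[k] for k in ks)
            E[n][g] -= sign(score)
    return E

grad = [[1.0], [-0.5]]                                # [M=2][N=1]
x = [[1.0, 2.0, -1.0, 0.0], [0.0, 2.0, 1.0, 4.0]]     # [M=2][K=4]
T = [[1, 0, 1, 1]]                                    # [N=1][K=4] ternary signs
E = [[0, -3]]                                         # [N=1][K/group] exponents

gs = grad_sign_from(grad, x)
E = update_e(T, grad, x, E, group=2)
```

In this example the first group's score is positive, so its exponent drops by one, while the second group's score is negative, so its exponent rises by one.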
## Float Materialization Audit
| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| TernaryScaleTensor CUDA forward | `y[M,N]` float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | `grad_x[M,K]` float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | `out[B,D]` float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | `grad_x[B,D]` float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | `out[N,D]` float32 output | Ephemeral |
| ByteEmbedding CUDA backward | `accum[V,D]` float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | `grad_sign[V,D]` int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | `w_eff`, `S`, `T.float()` | Ephemeral (via detach+requires_grad) |
No persistent float parameters or float optimizer state in strict ternary mode.
## Loss Spike Investigation
Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. The root-cause analysis follows.
### Primary cause: mass T-flip at step 2
With `T_accum` initialized to zeros and `accum_threshold=3`:
- Step 1: `T_accum` goes from 0 to +1 or -1
- Step 2: `T_accum` reaches 2 or 3+
- When `T_accum` crosses the threshold, correlated initial gradients cause thousands of T sign flips at once
This is a catastrophic weight change at step 2 because the model's learned representations from step 1 are invalidated en masse.
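The synchronization effect can be illustrated with a toy simulation. It assumes the worst case of fully correlated gradients (each position receives the same sign on every step) and uses the `|T_accum| > threshold` rule from the state-update description. The helper `crossings_per_step` is hypothetical, and the sketch does not try to reproduce the exact step at which the spike was observed, which depends on how the real `ternary_step` increments and compares `T_accum`.

```python
import random

def crossings_per_step(init_accum, grad_signs, threshold, steps):
    """Count positions whose |accum| first exceeds `threshold` at each step."""
    accum = list(init_accum)
    crossed = [False] * len(accum)
    counts = []
    for _ in range(steps):
        new = 0
        for i, s in enumerate(grad_signs):
            accum[i] += s
            if not crossed[i] and abs(accum[i]) > threshold:
                crossed[i] = True
                new += 1
        counts.append(new)
    return counts

rng = random.Random(0)
n = 1000
grad_signs = [rng.choice((-1, 1)) for _ in range(n)]  # same sign every step

zero_init = [0] * n
rand_init = [rng.randint(-2, 2) for _ in range(n)]    # small random offsets

sync = crossings_per_step(zero_init, grad_signs, threshold=3, steps=8)
spread = crossings_per_step(rand_init, grad_signs, threshold=3, steps=8)
```

With zero initialization every position crosses on the same step, so `sync` has a single entry equal to `n`. With random offsets the same crossings are spread over several steps in `spread`, which is the effect that randomizing `T_accum` at initialization aims for.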
### Secondary: E-then-T ordering
`update_E()` runs before `ternary_step()`. After E is updated based on pre-flip T, `ternary_step()` may flip T values, making E inconsistent with the new T state.
### Tertiary: redundant normalization
The training loop applies `clip_grad_norm_` and then `inv_scale = 1/||grad||`. For SignSGD, `sign(x) = sign(x/||x||)` whenever `||x|| > 0`, so the `inv_scale` normalization has no effect on the optimizer. The double normalization is dead code for the sign path.
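A quick check of the scale-invariance claim, as a minimal sketch with made-up gradient values:

```python
import math

def sign(v):
    """Return 1, 0, or -1."""
    return (v > 0) - (v < 0)

grad = [0.3, -4.0, 0.0, 1e-3, -2.5]       # illustrative values only
norm = math.sqrt(sum(g * g for g in grad))
inv_scale = 1.0 / norm                    # requires a nonzero, finite norm

raw_signs = [sign(g) for g in grad]
scaled_signs = [sign(g * inv_scale) for g in grad]
```

Multiplying by any positive finite scalar leaves every sign unchanged, so `raw_signs` and `scaled_signs` are identical.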
### Suggested fixes
1. **Warmup `accum_threshold`**: start at 7+ and decay to 3 over ~100 steps
2. **Swap update order**: call `ternary_step()` before `update_E()` so E sees post-flip T
3. **Initialize `T_accum` with small random values** (e.g. `torch.randint(-2, 3, T_accum.shape)`, which draws from -2 to 2 inclusive) to break synchronization
4. **Remove redundant `inv_scale`** for SignSGD (it does nothing)
5. **Rate-limit T flips**: only flip top-k% of positions per step
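The first suggestion could take the form of a simple schedule. The sketch below assumes a linear decay with rounding to an integer threshold; the function name `accum_threshold_at` and its parameters are hypothetical, and the shape of the decay is a design choice the notes above leave open.

```python
def accum_threshold_at(step, start=7, end=3, warmup_steps=100):
    """Linearly decay the flip threshold from `start` to `end` over `warmup_steps`."""
    if step >= warmup_steps:
        return end
    frac = step / warmup_steps
    return round(start + (end - start) * frac)

# Threshold to use at each of the first 120 training steps.
schedule = [accum_threshold_at(s) for s in range(120)]
```

A training loop would pass `accum_threshold_at(step)` to `ternary_step` in place of the fixed `accum_threshold=3`, so early flips require more accumulated evidence than later ones.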
## Verification
146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.
New correctness tests verify exact or near-exact match between Triton CUDA and CPU reference paths for:
- TernaryScaleTensor forward + backward (all 6 TScaleTypes)
- TernaryRMSNorm forward + backward (all 6 TScaleTypes)
- ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
- E update (exact match, all 6 TScaleTypes)
- T flip + T_accum (exact match, all 6 TScaleTypes)
## Remaining Work
1. **T_accum warmup or random init** to prevent step-2 mass-flip loss spike
2. **Swap ternary_step/update_E ordering** for E/T consistency
3. **Remove redundant inv_scale** in training loop for SignSGD
4. **Benchmark at target batch=1024** to check whether the ~45 s/step figure is JIT warmup rather than steady-state cost
5. **Tune Triton block sizes** for production layer dimensions
6. **Evaluate TileLang** against proven Triton kernels for tensor-core path
7. **ByteEmbedding `_hook_T`**: backward still unpacks T for `_hook_T` (needed by `ByteEmbedding.update_E` CPU path). Could be replaced with a Triton kernel that reads T_packed directly.
8. **GNNLoRAAdapter `self.B`**, **MemGram `nn.ParameterList`**, **GraphMoEGate `self.query`**: remaining trainable float params that break strict ternary in non-text-only mode