# True Ternary Refactor 3
## Scope
This pass extends the Triton kernel coverage from `TernaryScaleTensor` (linear layers) to the remaining model components: `TernaryRMSNorm` (normalization) and `ByteEmbedding` (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.
Refactor 2 established the packed-ternary Triton path for `TernaryScaleTensor` forward, grad-x, sign-only weight-gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.
## What Changed
### `tscale.py`
#### TernaryRMSNorm Triton kernels
- Added `_triton_rmsnorm_fwd_kernel`: fused RMS norm + ternary weight application in a single kernel. Each program loads a batch row of `x`, computes RMS, normalizes, then loads packed ternary weights and int8 exponents for that dimension, dequantizes via `sign * exp2(E)`, and multiplies into the output.
- Added `_triton_rmsnorm_bwd_kernel`: fused backward for RMS norm with ternary weights. Computes `dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms` without materializing the weight tensor.
- Added `_TritonRMSNormFn(torch.autograd.Function)`: autograd wrapper that saves `x_2d`, `packed`, `e` for backward. Forward calls `_triton_rmsnorm_fwd_kernel`, backward calls `_triton_rmsnorm_bwd_kernel`.
- `TernaryRMSNorm.forward()` now prefers the Triton path when input is on CUDA. CPU fallback remains unchanged.
- `TernaryRMSNorm.ternary_step()` and `TernaryRMSNorm.update_E()` remain no-ops (these weights are frozen).
Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.
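The fused backward formula can be checked against a finite-difference gradient using a small per-row reference in plain Python. This is a sketch, not the kernel. It handles one row, which is one Triton program's worth of work. It dequantizes one trit and one exponent per dimension as `sign * 2**E`. It also assumes an `eps` of `1e-6`, which this note does not specify.

```python
import math

def dequant(signs, exps):
    """Dequantize per-dimension ternary weights: w = sign * 2**E."""
    return [s * 2.0 ** e for s, e in zip(signs, exps)]

def rmsnorm_fwd(x, w, eps=1e-6):
    """One row of RMS norm followed by the ternary weight (eps is an assumed value)."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    x_norm = [v / rms for v in x]
    out = [xn * wi for xn, wi in zip(x_norm, w)]
    return out, x_norm, rms

def rmsnorm_bwd(dy, x_norm, rms, w):
    """dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms."""
    g = [d * wi for d, wi in zip(dy, w)]
    mean_term = sum(xn * gi for xn, gi in zip(x_norm, g)) / len(g)
    return [(gi - xn * mean_term) / rms for gi, xn in zip(g, x_norm)]

x = [0.5, -1.25, 2.0, 0.75]
w = dequant([1, -1, 0, 1], [0, 1, -2, 3])  # [1.0, -2.0, 0.0, 8.0]
dy = [0.1, -0.3, 0.7, 0.2]
out, x_norm, rms = rmsnorm_fwd(x, w)
dx = rmsnorm_bwd(dy, x_norm, rms, w)
```

Comparing `dx` with central differences of `sum(dy * out)` is a cheap way to confirm the sign and scale of each term before relying on the fused kernel.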
#### ByteEmbedding Triton kernels
- Added `_triton_ternary_embed_fwd_kernel`: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via `sign * exp2(E)`, and writes the output embedding vectors. Uses linear indexing into the flat packed buffer (`lin = idx * DIM + d`, `pack_idx = lin // 5`).
- Added `_triton_ternary_embed_bwd_accum_kernel`: scatter-adds gradient contributions from each index position into a float accumulator buffer shaped `[VOCAB, DIM]` using `tl.atomic_add`.
- Added `_triton_ternary_embed_bwd_sign_kernel`: converts the float accumulator to int8 sign (`1 / 0 / -1`) in a single pass over `[VOCAB, DIM]`.
- Added `_triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim)`: two-pass helper that runs the accum kernel then the sign kernel.
- Updated `_TritonTernaryEmbedFn.backward()`: no longer materializes float `w_eff`. Instead calls `_triton_ternary_embed_grad_sign()` to compute `grad_sign` directly into int8 via atomic scatter-add + sign. Still unpacks `T` for `_hook_T` (used by `update_E` CPU-style path on ByteEmbedding).
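The linear-index decoding used by the embed forward kernel can be modelled in plain Python. This is a sketch under assumptions. The base-3 digit order inside each byte, the `t + 1` digit mapping, and the one-exponent-per-row layout (`e_row`) are illustrative choices, not confirmed details of `tscale.py`. Only `lin = idx * DIM + d` and `pack_idx = lin // 5` come from the description above. `pack_ternary`, `unpack_trit` and `embed_lookup` are hypothetical names.

```python
TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, so five trits fit in one byte

def pack_ternary(trits):
    """Pack a flat list of -1/0/1 values, 5 per byte, as base-3 digits (t + 1)."""
    out = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        byte, place = 0, 1
        for t in trits[start:start + TRITS_PER_BYTE]:
            byte += (t + 1) * place
            place *= 3
        out.append(byte)
    return bytes(out)

def unpack_trit(packed, lin):
    """Decode the trit at flat position lin (pack_idx = lin // 5)."""
    pack_idx, slot = divmod(lin, TRITS_PER_BYTE)
    return (packed[pack_idx] // 3 ** slot) % 3 - 1

def embed_lookup(indices, packed, e_row, dim):
    """out[n][d] = trit(idx * dim + d) * 2**e_row[idx] for idx = indices[n]."""
    return [
        [unpack_trit(packed, idx * dim + d) * 2.0 ** e_row[idx] for d in range(dim)]
        for idx in indices
    ]

# Hypothetical 3 x 4 table, packed row-major into ceil(12 / 5) = 3 bytes.
T = [1, 0, -1, 1, -1, -1, 0, 0, 0, 1, 1, -1]
packed = pack_ternary(T)
print(embed_lookup([2, 0], packed, e_row=[0, 1, -1], dim=4))
# [[0.0, 0.5, 0.5, -0.5], [1.0, 0.0, -1.0, 1.0]]
```

A packed group can straddle a row boundary, as `T[5:10]` does here. Decoding from the flat index `lin` rather than from a per-row offset is what keeps this layout correct.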
Correctness: forward max diff 0.0 vs CPU reference, grad_sign 100% match vs CPU for all 6 TScaleTypes including duplicate indices.
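The two-pass backward can be sketched sequentially to show why duplicate indices need the float accumulator. The sign is taken after summing, so the summed gradient decides the result, not a vote over per-occurrence signs. The Python loop stands in for `tl.atomic_add`. Because it runs serially, it says nothing about the ordering of atomic float adds on the GPU. `embed_grad_sign` is a hypothetical name.

```python
def sign(v):
    return (v > 0) - (v < 0)

def embed_grad_sign(indices, grad_out, vocab, dim):
    """Scatter-add into a float [vocab][dim] accumulator, then take the sign."""
    accum = [[0.0] * dim for _ in range(vocab)]
    for n, idx in enumerate(indices):  # pass 1: stand-in for tl.atomic_add
        for d in range(dim):
            accum[idx][d] += grad_out[n][d]
    return [[sign(a) for a in row] for row in accum]  # pass 2: 1 / 0 / -1

indices = [1, 0, 1]  # index 1 appears twice
grad_out = [[0.5, -1.0], [0.0, 2.0], [-0.75, 1.0]]
print(embed_grad_sign(indices, grad_out, vocab=3, dim=2))  # [[0, 1], [-1, 0], [0, 0]]
```

In row 1, column 0, the per-occurrence signs are +1 and -1. Summing signs would give 0, but the summed gradient is -0.25, so the result is -1.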
#### Dead code removal
- Removed `_hook_defer_e_to_ternary_step` dead branch from `ternary_step()`. This branch called the deleted `_triton_ternary_step_score` kernel. The direct path (`_triton_ternary_step_direct`) is now the only Triton ternary-step path.
- Removed the corresponding assertion from `test_cuda_triton_tscale_path` that checked `_hook_defer_e_to_ternary_step` was not set.
- No other code references `_hook_defer_e_to_ternary_step` or the deleted score kernels.
#### Bug fix: missing `_triton_ternary_embed_fwd_kernel`
The embed forward kernel definition was accidentally deleted during a previous session's cleanup. The helper `_triton_ternary_embed` still referenced it, which caused a `NameError` at runtime. The kernel was reconstructed with correct linear-index packed ternary decoding and re-verified from scratch: exact match vs CPU for all 6 TScaleTypes.
### `trigram.py`
No changes to model code in this pass. The `ByteEmbedding` and `TernaryRMSNorm` classes already called the Triton path via `_TritonTernaryEmbedFn` and `_TritonRMSNormFn`; they just were not working because the kernel definitions were missing.
### `testing/test_tscale.py`
Added 5 low-level Triton vs CPU reference correctness tests:
| Test | What it checks |
|------|---------------|
| `test_cuda_triton_correctness_linear` | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| `test_cuda_triton_correctness_rmsnorm` | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| `test_cuda_triton_correctness_embedding` | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_update_E` | E update exact match vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_ternary_step` | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |
All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs/state updates.
## Kernel Inventory
### TernaryScaleTensor (from Refactor 2)
| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_fwd_kernel` | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| `_triton_ternary_grad_x_kernel` | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| `_triton_ternary_grad_sign_kernel` | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| `_triton_update_e_kernel` | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| `_triton_update_e_direct_kernel` | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| `_triton_ternary_step_kernel` | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| `_triton_ternary_step_direct_kernel` | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |
### TernaryRMSNorm (new in Refactor 3)
| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_rmsnorm_fwd_kernel` | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| `_triton_rmsnorm_bwd_kernel` | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |
### ByteEmbedding (new in Refactor 3)
| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_embed_fwd_kernel` | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| `_triton_ternary_embed_bwd_accum_kernel` | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| `_triton_ternary_embed_bwd_sign_kernel` | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |
### Autograd Functions
| Function | Module | Forward | Backward |
|----------|--------|---------|----------|
| `_TritonTernaryLinearFn` | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| `_TritonRMSNormFn` | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| `_TritonTernaryEmbedFn` | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |
### Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)
| Kernel | Reason |
|--------|--------|
| `_triton_ternary_step_score_kernel` | Replaced by direct group-owned E kernel |
| `_triton_ternary_step_score_block_kernel` | Same |
| `_triton_apply_e_score_kernel` | Same |
| `_triton_apply_e_score` | Same |
| `_triton_ternary_step_score` helper | Same |
## Data Flow
### Forward (CUDA)
```text
ByteEmbedding:
  indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed
TernaryScaleTensor (linear):
  x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]
TernaryRMSNorm:
  x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]
```
### Backward (CUDA)
```text
ByteEmbedding:
  grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
  (no float w_eff materialized; T unpacked for _hook_T only)
TernaryScaleTensor (linear):
  grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
  (grad_2d, x_2d retained for update_E and ternary_step)
TernaryRMSNorm:
  grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
(no grad captured for frozen weights)
```
### State Updates (CUDA)
```text
update_E:
_triton_update_e_direct(T_packed, grad_2d, x_2d, E)
- Computes sign(grad^T @ x) internally
- Reads packed T to get current ternary signs
- Sums grad_sign * T per group
- Applies delta = -sign(group_score) to E
ternary_step:
_triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
- Computes sign(grad^T @ x) internally
- Updates T_accum += grad_sign
- Flips T where |T_accum| > threshold
- Repacks T_packed in-place
```
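A pure-Python reference for the two direct state-update paths may help when reading the kernels. It is a sketch, not the Triton code. Three details are assumptions made for illustration:

- the contiguous K-slice grouping (`group_size`);
- the int8 clamp on E;
- the rule for what a "flip" does to a ternary value (step one level against the accumulated sign, then reset the accumulator).

The computation `grad_sign = sign(grad_y^T @ x)`, the per-group `-sign(score)` E delta, and the `|T_accum| > threshold` test follow the outline above. The `*_ref` function names are hypothetical.

```python
def sign(v):
    return (v > 0) - (v < 0)

def grad_sign_matrix(grad_y, x):
    """sign(grad_y^T @ x): grad_y is [M][N], x is [M][K], result is [N][K]."""
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    return [[sign(sum(grad_y[m][n] * x[m][k] for m in range(M))) for k in range(K)]
            for n in range(N)]

def update_e_direct_ref(T, grad_y, x, E, group_size):
    """E[n][g] += -sign(group score); groups are contiguous K-slices (assumed layout)."""
    gs = grad_sign_matrix(grad_y, x)
    for n, row in enumerate(T):
        for g in range(len(E[n])):
            ks = range(g * group_size, (g + 1) * group_size)
            score = sum(gs[n][k] * row[k] for k in ks)
            E[n][g] = max(-128, min(127, E[n][g] - sign(score)))  # assumed int8 clamp
    return E

def ternary_step_direct_ref(T, grad_y, x, T_accum, threshold):
    """T_accum += grad_sign; where |T_accum| > threshold, step T and reset (assumed rule)."""
    gs = grad_sign_matrix(grad_y, x)
    flips = 0
    for n in range(len(T)):
        for k in range(len(T[n])):
            T_accum[n][k] += gs[n][k]
            if abs(T_accum[n][k]) > threshold:
                # Move one level against the accumulated gradient sign, clamped to {-1, 0, 1}.
                T[n][k] = max(-1, min(1, T[n][k] - sign(T_accum[n][k])))
                T_accum[n][k] = 0
                flips += 1
    return flips

grad_y = [[1.0, -1.0], [0.5, 0.5]]                  # M=2, N=2
x = [[1.0, 2.0, -1.0, 0.0], [2.0, -2.0, 0.0, 1.0]]  # M=2, K=4
T = [[1, 0, -1, 1], [1, -1, 0, 0]]                  # N=2, K=4
print(update_e_direct_ref(T, grad_y, x, [[0, 0], [3, -2]], group_size=2))  # [[-1, -1], [2, -2]]
```

Calling `update_e_direct_ref` before `ternary_step_direct_ref` reproduces the current E-then-T order. Swapping the two calls is all that the proposed reordering changes.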
## Float Materialization Audit
| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| TernaryScaleTensor CUDA forward | `y[M,N]` float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | `grad_x[M,K]` float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | `out[B,D]` float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | `grad_x[B,D]` float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | `out[N,D]` float32 output | Ephemeral |
| ByteEmbedding CUDA backward | `accum[V,D]` float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | `grad_sign[V,D]` int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | `w_eff`, `S`, `T.float()` | Ephemeral (via detach+requires_grad) |
No persistent float parameters or float optimizer state in strict ternary mode.
## Loss Spike Investigation
Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. Root cause analysis:
### Primary cause: mass T-flip at step 2
With `T_accum` initialized to zeros and `accum_threshold=3`:
- Step 1: `T_accum` goes from 0 to +1 or -1
- Step 2: `T_accum` reaches 2 or 3+
- When `T_accum` crosses the threshold, the correlated initial gradients cause thousands of T signs to flip in the same step
This is a catastrophic weight change at step 2: the representations learned in step 1 are invalidated en masse.
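A toy simulation makes the synchronization argument concrete and shows why jittering the initial `T_accum` helps. It assumes the worst case of perfectly correlated unit gradient signs. It uses the `|T_accum| > threshold` test from the ternary-step outline and records only the first crossing per position. The simulated crossing step should not be read as the step-2 timing observed in training; only the spread of crossing steps is meaningful. `first_flip_steps` is a hypothetical helper.

```python
import random

def first_flip_steps(init, grad_signs, threshold, max_steps=20):
    """Per position, the first step at which |T_accum| > threshold (None if it never crosses)."""
    accum = list(init)
    first = [None] * len(accum)
    for step in range(1, max_steps + 1):
        for i, g in enumerate(grad_signs):
            accum[i] += g
            if first[i] is None and abs(accum[i]) > threshold:
                first[i] = step
    return first

n = 1000
correlated = [1] * n  # worst case: every position sees the same gradient sign on every step
zero = first_flip_steps([0] * n, correlated, threshold=3)
rng = random.Random(0)
jittered = first_flip_steps([rng.randint(-2, 2) for _ in range(n)], correlated, threshold=3)
print(len(set(zero)), "distinct flip step(s) with zero init")
print(len(set(jittered)), "distinct flip step(s) with init drawn from {-2, ..., 2}")
```

With zero init, every position crosses on the same step. With a random init, crossings spread over several steps, which is the desynchronization that suggested fix 3 relies on.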
### Secondary: E-then-T ordering
`update_E()` runs before `ternary_step()`. After E is updated based on pre-flip T, `ternary_step()` may flip T values, making E inconsistent with the new T state.
### Tertiary: redundant normalization
The training loop applies `clip_grad_norm_` then `inv_scale = 1/||grad||`. For SignSGD, `sign(x) = sign(x/||x||)`, so `inv_scale` normalization has no effect on the optimizer. The double normalization is dead code for the sign path.
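The argument rests on one fact: multiplying a gradient by any positive scalar leaves every sign unchanged. The check below uses a made-up gradient vector. The clipping factor mirrors the positive coefficient `max_norm / (total_norm + 1e-6)` that `clip_grad_norm_` applies (capped at 1), with `max_norm = 1.0` chosen arbitrarily, so clipping cannot change a sign either.

```python
def sign(v):
    """Return 1, 0 or -1 for a scalar."""
    return (v > 0) - (v < 0)

def sign_after_scaling(grad, scale):
    """Signs of grad after multiplying every entry by a positive scale."""
    assert scale > 0
    return [sign(g * scale) for g in grad]

grad = [0.3, -2.5, 0.0, 1e-4, -7.0]
norm = sum(g * g for g in grad) ** 0.5
inv_scale = 1.0 / norm                     # the training loop's 1/||grad||
clip_coef = min(1.0, 1.0 / (norm + 1e-6))  # clip_grad_norm_-style factor, max_norm = 1.0
raw = [sign(g) for g in grad]
print(sign_after_scaling(grad, inv_scale) == raw)  # True
print(sign_after_scaling(grad, clip_coef) == raw)  # True
```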
### Suggested fixes
1. **Warmup `accum_threshold`**: start at 7+ and decay to 3 over ~100 steps
2. **Swap update order**: call `ternary_step()` before `update_E()` so E sees post-flip T
3. **Initialize `T_accum` with small random values** (e.g. `torch.randint(-2, 3, T_accum.shape, dtype=T_accum.dtype, device=T_accum.device)`, which draws from {-2, ..., 2}) to break synchronization
4. **Remove redundant `inv_scale`** for SignSGD (it does nothing)
5. **Rate-limit T flips**: only flip top-k% of positions per step
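Fixes 1 and 5 reduce to small scheduling rules, sketched here under assumptions. The linear decay shape, rounding the threshold to an integer, and ranking flip candidates by `|T_accum|` are illustrative choices. `threshold_schedule` and `rate_limited_flips` are hypothetical helpers, not functions from `tscale.py`.

```python
def threshold_schedule(step, start=7, end=3, warmup_steps=100):
    """Linearly decay the flip threshold from start to end, then hold it at end."""
    if step >= warmup_steps:
        return end
    return round(start + (end - start) * step / warmup_steps)

def rate_limited_flips(t_accum, threshold, max_frac):
    """Positions to flip: |T_accum| > threshold, capped at max_frac of all positions,
    keeping the largest |T_accum| first."""
    candidates = [i for i, a in enumerate(t_accum) if abs(a) > threshold]
    budget = int(len(t_accum) * max_frac)
    candidates.sort(key=lambda i: abs(t_accum[i]), reverse=True)
    return sorted(candidates[:budget])

print([threshold_schedule(s) for s in (0, 25, 50, 75, 100)])  # [7, 6, 5, 4, 3]
print(rate_limited_flips([5, -6, 1, 4, -4, 0, 2, 7, -1, 0], threshold=3, max_frac=0.2))  # [1, 7]
```

Rate-limiting by magnitude keeps the most strongly supported flips. Positions left over keep their accumulator value and remain candidates on later steps.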
## Verification
146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.
New correctness tests verify exact or near-exact match between Triton CUDA and CPU reference paths for:
- TernaryScaleTensor forward + backward (all 6 TScaleTypes)
- TernaryRMSNorm forward + backward (all 6 TScaleTypes)
- ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
- E update (exact match, all 6 TScaleTypes)
- T flip + T_accum (exact match, all 6 TScaleTypes)
## Remaining Work
1. **T_accum warmup or random init** to prevent step-2 mass-flip loss spike
2. **Swap ternary_step/update_E ordering** for E/T consistency
3. **Remove redundant inv_scale** in training loop for SignSGD
4. **Benchmark at target batch=1024** to check whether the current ~45 s/step is JIT warmup
5. **Tune Triton block sizes** for production layer dimensions
6. **Evaluate TileLang** against proven Triton kernels for tensor-core path
7. **ByteEmbedding `_hook_T`**: backward still unpacks T for `_hook_T` (needed by `ByteEmbedding.update_E` CPU path). Could be replaced with a Triton kernel that reads T_packed directly.
8. **GNNLoRAAdapter `self.B`**, **MemGram `nn.ParameterList`**, **GraphMoEGate `self.query`**: remaining trainable float params that break strict ternary in non-text-only mode