# True Ternary Refactor 3

## Scope

This pass extends the Triton kernel coverage from `TernaryScaleTensor` (linear layers) to the remaining model components: `TernaryRMSNorm` (normalization) and `ByteEmbedding` (token embedding). It also removes dead code, adds low-level correctness tests, and investigates the training loss spike.

Refactor 2 established the packed-ternary Triton path for `TernaryScaleTensor` forward, grad-x, sign-only weight-gradient, E-update, and ternary-step/repack. This pass ports the two remaining hot paths and cleans up the kernel surface.

## What Changed

### `tscale.py`

#### TernaryRMSNorm Triton kernels

- Added `_triton_rmsnorm_fwd_kernel`: fused RMS norm + ternary weight application in a single kernel. Each program loads a batch row of `x`, computes RMS, normalizes, then loads packed ternary weights and int8 exponents for that dimension, dequantizes via `sign * exp2(E)`, and multiplies into the output.
- Added `_triton_rmsnorm_bwd_kernel`: fused backward for RMS norm with ternary weights. Computes `dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms` without materializing the weight tensor.
- Added `_TritonRMSNormFn(torch.autograd.Function)`: autograd wrapper that saves `x_2d`, `packed`, `e` for backward. Forward calls `_triton_rmsnorm_fwd_kernel`, backward calls `_triton_rmsnorm_bwd_kernel`.
- `TernaryRMSNorm.forward()` now prefers the Triton path when input is on CUDA. CPU fallback remains unchanged.
- `TernaryRMSNorm.ternary_step()` and `TernaryRMSNorm.update_E()` remain no-ops (these weights are frozen).
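
As a rough reference for the math the two RMS norm kernels fuse, the sketch below works through one row in plain Python. It is not the Triton code: the `eps` term, the per-element exponent layout, and all function names are assumptions made for illustration. Only the `sign * exp2(E)` dequantization and the `dx` formula come from the notes above.

```python
import math

def dequant(signs, exps):
    # w = sign * 2**E, the dequantization described above
    return [s * 2.0 ** e for s, e in zip(signs, exps)]

def rmsnorm_fwd(x, w, eps=1e-6):
    # eps is an assumed stabilizer; the notes do not state its value
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * wi for v, wi in zip(x, w)]

def rmsnorm_bwd(x, w, dy, eps=1e-6):
    # dx = (dy * w - x_norm * mean(x_norm * dy * w)) / rms
    d = len(x)
    rms = math.sqrt(sum(v * v for v in x) / d + eps)
    x_norm = [v / rms for v in x]
    g = [dyi * wi for dyi, wi in zip(dy, w)]
    m = sum(xn * gi for xn, gi in zip(x_norm, g)) / d
    return [(gi - xn * m) / rms for gi, xn in zip(g, x_norm)]

x = [0.5, -1.25, 2.0, 0.75]
signs = [1, -1, 0, 1]      # ternary values
exps = [0, -1, 2, 1]       # int8 exponents (hypothetical values)
w = dequant(signs, exps)   # [1.0, -0.5, 0.0, 2.0]
dy = [0.3, -0.7, 1.1, 0.2]
dx = rmsnorm_bwd(x, w, dy)
```

A finite-difference check of `sum(dy * rmsnorm_fwd(x, w))` against `dx` is a cheap way to confirm the backward formula independently of autograd.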

Correctness: forward max diff 0.0 vs CPU reference, backward max diff 0.0 vs autograd reference, for all 6 TScaleTypes.

#### ByteEmbedding Triton kernels

- Added `_triton_ternary_embed_fwd_kernel`: embedding lookup from packed ternary + int8 E. Each program loads a batch of indices, gathers the corresponding packed ternary values and exponents, dequantizes via `sign * exp2(E)`, and writes the output embedding vectors. Uses linear indexing into the flat packed buffer (`lin = idx * DIM + d`, `pack_idx = lin // 5`).
- Added `_triton_ternary_embed_bwd_accum_kernel`: scatter-adds gradient contributions from each index position into a float accumulator buffer shaped `[VOCAB, DIM]` using `tl.atomic_add`.
- Added `_triton_ternary_embed_bwd_sign_kernel`: converts the float accumulator to int8 sign (`1 / 0 / -1`) in a single pass over `[VOCAB, DIM]`.
- Added `_triton_ternary_embed_grad_sign(indices, grad_output, vocab, dim)`: two-pass helper that runs the accum kernel then the sign kernel.
- Updated `_TritonTernaryEmbedFn.backward()`: no longer materializes float `w_eff`. Instead calls `_triton_ternary_embed_grad_sign()` to compute `grad_sign` directly into int8 via atomic scatter-add + sign. Still unpacks `T` for `_hook_T` (used by `update_E` CPU-style path on ByteEmbedding).
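
The two-pass structure of the embedding backward (accumulate, then take the sign) can be sketched on the CPU as follows. This is a plain-Python stand-in, not the Triton kernels: the sequential loop replaces `tl.atomic_add`, and the function name `embed_grad_sign_reference` is hypothetical.

```python
def embed_grad_sign_reference(indices, grad_output, vocab, dim):
    # Pass 1: scatter-add every position's gradient into its vocab row.
    accum = [[0.0] * dim for _ in range(vocab)]
    for pos, idx in enumerate(indices):
        for d in range(dim):
            accum[idx][d] += grad_output[pos][d]
    # Pass 2: reduce the float accumulator to 1 / 0 / -1.
    return [[(v > 0) - (v < 0) for v in row] for row in accum]

indices = [2, 0, 2]  # index 2 appears twice
grad_output = [[0.5, -1.0], [0.25, 0.0], [-0.75, 1.0]]
grad_sign = embed_grad_sign_reference(indices, grad_output, vocab=3, dim=2)
# grad_sign == [[1, 0], [0, 0], [-1, 0]]
```

Row 2 shows why the sum has to come before the sign: the two contributions to its first column are `0.5` and `-0.75`, so summing first gives `-1`, whereas adding the individual signs would give `0`.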

Correctness: forward max diff 0.0 vs CPU reference, grad_sign 100% match vs CPU for all 6 TScaleTypes including duplicate indices.

#### Dead code removal

- Removed `_hook_defer_e_to_ternary_step` dead branch from `ternary_step()`. This branch called the deleted `_triton_ternary_step_score` kernel. The direct path (`_triton_ternary_step_direct`) is now the only Triton ternary-step path.
- Removed the corresponding assertion from `test_cuda_triton_tscale_path` that checked `_hook_defer_e_to_ternary_step` was not set.
- No other code references `_hook_defer_e_to_ternary_step` or the deleted score kernels.

#### Bug fix: missing `_triton_ternary_embed_fwd_kernel`

The embed forward kernel definition was accidentally deleted during a previous session's cleanup. The helper `_triton_ternary_embed` still referenced it, which caused a `NameError` at runtime. The kernel was reconstructed with correct linear-index packed ternary decoding and verified from scratch: exact match vs CPU for all 6 TScaleTypes.
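
To make the linear-index decoding concrete, here is a small CPU sketch. Only `lin = idx * DIM + d` and `pack_idx = lin // 5` come from these notes. The base-3 digit order (least significant first), the `T + 1` digit mapping, and the one-exponent-per-element layout are assumptions chosen for illustration and may not match the real buffer format.

```python
TRITS_PER_BYTE = 5  # 3**5 = 243 values fit in one uint8

def pack_trits(trits):
    # Assumed encoding: digit = T + 1, least-significant digit first.
    packed = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        byte = 0
        for k, t in enumerate(trits[start:start + TRITS_PER_BYTE]):
            byte += (t + 1) * 3 ** k
        packed.append(byte)
    return bytes(packed)

def embed_lookup(packed, exps, idx, dim):
    out = []
    for d in range(dim):
        lin = idx * dim + d                # flat element index
        pack_idx = lin // TRITS_PER_BYTE   # byte that holds this trit
        slot = lin % TRITS_PER_BYTE        # position inside that byte
        sign = (packed[pack_idx] // 3 ** slot) % 3 - 1
        out.append(sign * 2.0 ** exps[lin])
    return out

DIM = 4
trits = [1, 0, -1, 1,  -1, -1, 0, 1,  0, 1, 1, -1]  # 3 rows of DIM trits
exps = [0, 1, -1, 2,   0, 0, 3, 1,   -2, 0, 1, 2]   # hypothetical exponents
packed = pack_trits(trits)
row1 = embed_lookup(packed, exps, idx=1, dim=DIM)
# row1 == [-1.0, -1.0, 0.0, 2.0]
```

Row 1 starts at `lin = 4`, so it reads the last trit of byte 0 and the first three trits of byte 1. Rows straddling byte boundaries is the case that per-row indexing gets wrong and flat linear indexing handles.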

### `trigram.py`

No changes to model code in this pass. The `ByteEmbedding` and `TernaryRMSNorm` classes already called the Triton path via `_TritonTernaryEmbedFn` and `_TritonRMSNormFn`; they just were not working because the kernel definitions were missing.

### `testing/test_tscale.py`

Added 5 low-level Triton vs CPU reference correctness tests:

| Test | What it checks |
|------|---------------|
| `test_cuda_triton_correctness_linear` | TernaryScaleTensor forward + grad-x vs CPU for all 6 TScaleTypes (atol 1e-3) |
| `test_cuda_triton_correctness_rmsnorm` | TernaryRMSNorm forward + backward vs CPU for all 6 TScaleTypes (atol 1e-5) |
| `test_cuda_triton_correctness_embedding` | ByteEmbedding forward + grad_sign vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_update_E` | E update exact match vs CPU for all 6 TScaleTypes |
| `test_cuda_triton_correctness_ternary_step` | T flip + T_accum exact match vs CPU for all 6 TScaleTypes |

All tests create CPU and GPU modules from the same random seed, run forward + backward independently, and compare outputs/state updates.

## Kernel Inventory

### TernaryScaleTensor (from Refactor 2)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_fwd_kernel` | Packed ternary GEMM | x[M,K], T_packed, E | y[M,N] float32 |
| `_triton_ternary_grad_x_kernel` | Grad-input GEMM | grad_y[M,N], T_packed, E | grad_x[M,K] float32 |
| `_triton_ternary_grad_sign_kernel` | Sign-only weight gradient | grad_y[M,N], x[M,K] | grad_sign[N,K] int8 |
| `_triton_update_e_kernel` | E update from precomputed grad_sign | T_packed, grad_sign[N,K], E | E (in-place) |
| `_triton_update_e_direct_kernel` | E update from raw grad/x (avoids grad_sign alloc) | T_packed, grad_y[M,N], x[M,K], E | E (in-place) |
| `_triton_ternary_step_kernel` | T_accum update + flip + repack from grad_sign | T_packed, grad_sign[N,K], T_accum | T_packed, T_accum (in-place) |
| `_triton_ternary_step_direct_kernel` | T_accum update + flip + repack from raw grad/x | T_packed, grad_y[M,N], x[M,K], T_accum | T_packed, T_accum (in-place) |

### TernaryRMSNorm (new in Refactor 3)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_rmsnorm_fwd_kernel` | RMS norm + ternary weight | x[B,D], T_packed, E | out[B,D] float32 |
| `_triton_rmsnorm_bwd_kernel` | RMS norm backward through ternary weight | grad_out[B,D], x[B,D], T_packed, E | grad_x[B,D] float32 |

### ByteEmbedding (new in Refactor 3)

| Kernel | Purpose | Input | Output |
|--------|---------|-------|--------|
| `_triton_ternary_embed_fwd_kernel` | Embedding lookup from packed ternary | indices[N], T_packed, E | out[N,D] float32 |
| `_triton_ternary_embed_bwd_accum_kernel` | Scatter-add grad into per-vocab accumulator | indices[N], grad_out[N,D], accum[V,D] | accum (atomic add) |
| `_triton_ternary_embed_bwd_sign_kernel` | Float accumulator to int8 sign | accum[V,D] | grad_sign[V,D] int8 |

### Autograd Functions

| Function | Module | Forward | Backward |
|----------|--------|---------|----------|
| `_TritonTernaryLinearFn` | TernaryScaleTensor | fwd kernel | grad_x kernel + retain grad_2d/x_2d |
| `_TritonRMSNormFn` | TernaryRMSNorm | rmsnorm_fwd kernel | rmsnorm_bwd kernel |
| `_TritonTernaryEmbedFn` | ByteEmbedding | embed_fwd kernel | embed_bwd accum+sign kernels + retain T |

### Deleted kernels (Refactor 2 cleanup, confirmed in Refactor 3)

| Kernel | Reason |
|--------|--------|
| `_triton_ternary_step_score_kernel` | Replaced by direct group-owned E kernel |
| `_triton_ternary_step_score_block_kernel` | Same |
| `_triton_apply_e_score_kernel` | Same |
| `_triton_apply_e_score` | Same |
| `_triton_ternary_step_score` helper | Same |

## Data Flow

### Forward (CUDA)

```text
ByteEmbedding:
  indices ──► _triton_ternary_embed_fwd_kernel ──► emb[N,D] ──► TernaryRMSNorm ──► normed

TernaryScaleTensor (linear):
  x[M,K] ──► _triton_ternary_fwd_kernel ──► y[M,N]

TernaryRMSNorm:
  x[B,D] ──► _triton_rmsnorm_fwd_kernel ──► out[B,D]
```

### Backward (CUDA)

```text
ByteEmbedding:
  grad_out[N,D] ──► atomic scatter-add ──► accum[V,D] ──► sign ──► grad_sign[V,D] int8
  (no float w_eff materialized; T unpacked for _hook_T only)

TernaryScaleTensor (linear):
  grad_y[M,N] ──► _triton_ternary_grad_x_kernel ──► grad_x[M,K]
  (grad_2d, x_2d retained for update_E and ternary_step)

TernaryRMSNorm:
  grad_out[B,D] ──► _triton_rmsnorm_bwd_kernel ──► grad_x[B,D]
  (no grad captured for frozen weights)
```

### State Updates (CUDA)

```text
update_E:
  _triton_update_e_direct(T_packed, grad_2d, x_2d, E)
  - Computes sign(grad^T @ x) internally
  - Reads packed T to get current ternary signs
  - Sums grad_sign * T per group
  - Applies delta = -sign(group_score) to E

ternary_step:
  _triton_ternary_step_direct(T_packed, grad_2d, x_2d, T_accum)
  - Computes sign(grad^T @ x) internally
  - Updates T_accum += grad_sign
  - Flips T where |T_accum| > threshold
  - Repacks T_packed in-place
```
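
The state-update steps listed above can be sketched on unpacked data in plain Python. Several details here are assumptions rather than facts from the kernels: groups are taken to be contiguous runs of `group_size` elements along K, `E` is clamped to the int8 range, and the flip itself is left out because its exact rule is not spelled out here, so the sketch only reports which positions cross the threshold. All function names are hypothetical.

```python
def sign(v):
    return (v > 0) - (v < 0)

def compute_grad_sign(grad_y, x):
    # sign(grad^T @ x): grad_y is [M][N], x is [M][K], result is [N][K]
    M, N, K = len(grad_y), len(grad_y[0]), len(x[0])
    return [[sign(sum(grad_y[m][n] * x[m][k] for m in range(M)))
             for k in range(K)] for n in range(N)]

def update_e(T, gs, E, group_size):
    # Sum grad_sign * T per group, then apply delta = -sign(group_score).
    for n, row in enumerate(T):
        for g in range(len(row) // group_size):
            ks = range(g * group_size, (g + 1) * group_size)
            score = sum(gs[n][k] * row[k] for k in ks)
            E[n][g] = max(-128, min(127, E[n][g] - sign(score)))

def accumulate(T_accum, gs, threshold):
    # T_accum += grad_sign; return positions where |T_accum| > threshold.
    crossed = []
    for n, row in enumerate(gs):
        for k, s in enumerate(row):
            T_accum[n][k] += s
            if abs(T_accum[n][k]) > threshold:
                crossed.append((n, k))
    return crossed

grad_y = [[1.0], [-0.5]]                              # M=2, N=1
x = [[2.0, -1.0, 0.0, 1.0], [1.0, 1.0, 0.0, 4.0]]     # M=2, K=4
gs = compute_grad_sign(grad_y, x)                     # [[1, -1, 0, -1]]
T = [[-1, 1, -1, -1]]
E = [[3, 3]]
update_e(T, gs, E, group_size=2)                      # E becomes [[4, 2]]
T_accum = [[3, 0, 0, -3]]
crossed = accumulate(T_accum, gs, threshold=3)        # [(0, 0), (0, 3)]
```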

## Float Materialization Audit

| Path | Float tensors created | Persistent? |
|------|----------------------|-------------|
| TernaryScaleTensor CUDA forward | `y[M,N]` float32 output | Ephemeral (output tensor) |
| TernaryScaleTensor CUDA backward | `grad_x[M,K]` float32 | Ephemeral (autograd) |
| TernaryScaleTensor CUDA update_E | None | All int8 in-place |
| TernaryScaleTensor CUDA ternary_step | None | All int8/uint8 in-place |
| TernaryRMSNorm CUDA forward | `out[B,D]` float32 output | Ephemeral |
| TernaryRMSNorm CUDA backward | `grad_x[B,D]` float32 | Ephemeral (autograd) |
| ByteEmbedding CUDA forward | `out[N,D]` float32 output | Ephemeral |
| ByteEmbedding CUDA backward | `accum[V,D]` float32 | Ephemeral (freed after sign) |
| ByteEmbedding CUDA backward | `grad_sign[V,D]` int8 | Hook (consumed by ternary_step) |
| CPU fallbacks | `w_eff`, `S`, `T.float()` | Ephemeral (via detach+requires_grad) |

No persistent float parameters or float optimizer state in strict ternary mode.

## Loss Spike Investigation

Strict ternary training shows a loss spike from 6.89 to 10.08 at step 2. Root cause analysis:

### Primary cause: mass T-flip at step 2

With `T_accum` initialized to zeros and `accum_threshold=3`:
- Step 1: `T_accum` goes from 0 to +1 or -1
- Step 2: `T_accum` reaches 2 or 3+
- When `T_accum` hits threshold, correlated initial gradients cause thousands of simultaneous T sign flips

This is a catastrophic weight change at step 2 because the model's learned representations from step 1 are invalidated en masse.
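
The synchronization effect can be reproduced with a toy simulation. It assumes one `+1`/`-1` increment per position per step and a gradient sign that stays fixed across steps (fully correlated), so the step at which the crossing lands differs from the real run. The point is only that a zero-initialized accumulator makes every position cross at the same step, while a randomized start spreads the crossings out.

```python
import random
from collections import Counter

def first_crossing_steps(init, grad_signs, steps, threshold):
    """For each position, the first step at which |T_accum| > threshold."""
    acc = list(init)
    first = [None] * len(acc)
    for step in range(1, steps + 1):
        for i, s in enumerate(grad_signs):
            acc[i] += s
            if first[i] is None and abs(acc[i]) > threshold:
                first[i] = step
    return first

rng = random.Random(0)
n, threshold, steps = 1000, 3, 8
grad_signs = [rng.choice((-1, 1)) for _ in range(n)]  # fixed per position

zero_first = first_crossing_steps([0] * n, grad_signs, steps, threshold)
rand_init = [rng.randint(-2, 2) for _ in range(n)]
rand_first = first_crossing_steps(rand_init, grad_signs, steps, threshold)

zero_hist = Counter(zero_first)  # every position crosses at the same step
rand_hist = Counter(rand_first)  # crossings spread over several steps
```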

### Secondary: E-then-T ordering

`update_E()` runs before `ternary_step()`. After E is updated based on pre-flip T, `ternary_step()` may flip T values, making E inconsistent with the new T state.

### Tertiary: redundant normalization

The training loop applies `clip_grad_norm_` then `inv_scale = 1/||grad||`. For SignSGD, `sign(x) = sign(x/||x||)`, so `inv_scale` normalization has no effect on the optimizer. The double normalization is dead code for the sign path.
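
A quick check of the identity, using made-up gradient values: multiplying by any positive, finite scale leaves every element's sign unchanged, so the sign path sees identical input with or without `inv_scale`.

```python
import math

def sign(v):
    return (v > 0) - (v < 0)

grad = [0.02, -3.5, 0.0, 7.25e-4, -1e3]   # hypothetical gradient values
norm = math.sqrt(sum(g * g for g in grad))
inv_scale = 1.0 / norm                    # requires norm > 0

signs_raw = [sign(g) for g in grad]
signs_scaled = [sign(g * inv_scale) for g in grad]
# signs_raw == signs_scaled == [1, -1, 0, 1, -1]
```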

### Suggested fixes

1. **Warmup `accum_threshold`**: start at 7+ and decay to 3 over ~100 steps
2. **Swap update order**: call `ternary_step()` before `update_E()` so E sees post-flip T
3. **Initialize `T_accum` with small random values** (e.g. `torch.randint(-2, 3)`) to break synchronization
4. **Remove redundant `inv_scale`** for SignSGD (it does nothing)
5. **Rate-limit T flips**: only flip top-k% of positions per step
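
Fix 1 could look roughly like the schedule below. The linear shape, the rounding, and the name `accum_threshold_at` are assumptions. The notes only specify a start of 7 or more, an end of 3, and about 100 steps.

```python
def accum_threshold_at(step, start=7, end=3, warmup_steps=100):
    """Linearly decay the flip threshold from `start` to `end`."""
    if step >= warmup_steps:
        return end
    frac = step / warmup_steps
    return round(start + (end - start) * frac)

schedule = [accum_threshold_at(s) for s in (0, 25, 50, 75, 100, 500)]
# schedule == [7, 6, 5, 4, 3, 3]
```

The training loop would pass `accum_threshold_at(step)` to `ternary_step()` in place of the constant `accum_threshold=3`, assuming that call accepts a per-step threshold.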

## Verification

146 tests pass: 27 tscale (22 original + 5 new correctness) + 119 morph.

New correctness tests verify exact or near-exact match between Triton CUDA and CPU reference paths for:
- TernaryScaleTensor forward + backward (all 6 TScaleTypes)
- TernaryRMSNorm forward + backward (all 6 TScaleTypes)
- ByteEmbedding forward + grad_sign (all 6 TScaleTypes)
- E update (exact match, all 6 TScaleTypes)
- T flip + T_accum (exact match, all 6 TScaleTypes)

## Remaining Work

1. **T_accum warmup or random init** to prevent step-2 mass-flip loss spike
2. **Swap ternary_step/update_E ordering** for E/T consistency
3. **Remove redundant inv_scale** in training loop for SignSGD
4. **Benchmark at target batch=1024** to check if ~45s/step is JIT warmup
5. **Tune Triton block sizes** for production layer dimensions
6. **Evaluate TileLang** against proven Triton kernels for tensor-core path
7. **ByteEmbedding `_hook_T`**: backward still unpacks T for `_hook_T` (needed by `ByteEmbedding.update_E` CPU path). Could be replaced with a Triton kernel that reads T_packed directly.
8. **GNNLoRAAdapter `self.B`**, **MemGram `nn.ParameterList`**, **GraphMoEGate `self.query`**: remaining trainable float params that break strict ternary in non-text-only mode