# True Ternary Refactor 8 — MoE/Graph Triton Kernel Phase
## Scope
This phase keeps the focus on MoE and Graph kernels. It does not change the strict ternary architecture target from REFACTOR6/7.
## MoE Kernel Work
Added a Triton-backed dense MoE combine kernel for the fixed-shape small-batch path.
Previous dense path:
```text
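for expert in range(num_experts):
    # per-token weight for this expert; 0 where the expert was not selected
    expert_weight = torch.where(topk_idx == expert, topk_weights, 0).sum(dim=-1, keepdim=True)
    routed_out += expert_weight * expert_out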
```
New dense path:
```text
expert_out_stack = [num_experts, tokens, dim]
routed_out = triton_moe_dense_combine(expert_out_stack, topk_idx, topk_weights)
```
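
To pin down what the combine computes, here is a minimal pure-Python reference sketch (not the Triton kernel). It assumes `topk_idx` and `topk_weights` are `[tokens, k]` and `expert_out_stack` is `[num_experts, tokens, dim]`; `loop_combine` and `dense_combine` are hypothetical names. The per-expert loop and the single-pass gather give the same result, because each token only picks up the experts it actually routed to.

```python
from typing import List

Matrix = List[List[float]]


def loop_combine(expert_out_stack: List[Matrix], topk_idx: List[List[int]],
                 topk_weights: Matrix) -> Matrix:
    """Previous path: one pass per expert, masking route weights by expert id."""
    tokens, dim = len(topk_idx), len(expert_out_stack[0][0])
    routed_out = [[0.0] * dim for _ in range(tokens)]
    for expert, expert_out in enumerate(expert_out_stack):
        for t in range(tokens):
            # Same as torch.where(topk_idx == expert, topk_weights, 0).sum(-1)[t]
            w = sum(wk for ek, wk in zip(topk_idx[t], topk_weights[t]) if ek == expert)
            for d in range(dim):
                routed_out[t][d] += w * expert_out[t][d]
    return routed_out


def dense_combine(expert_out_stack: List[Matrix], topk_idx: List[List[int]],
                  topk_weights: Matrix) -> Matrix:
    """Single pass: each token gathers only its k selected expert rows."""
    tokens, dim = len(topk_idx), len(expert_out_stack[0][0])
    routed_out = [[0.0] * dim for _ in range(tokens)]
    for t in range(tokens):
        for expert, w in zip(topk_idx[t], topk_weights[t]):
            row = expert_out_stack[expert][t]
            for d in range(dim):
                routed_out[t][d] += w * row[d]
    return routed_out


# 3 experts, 2 tokens, dim 2, top-2 routing
expert_out_stack = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
    [[9.0, 10.0], [11.0, 12.0]],
]
topk_idx = [[0, 2], [1, 2]]
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
print(dense_combine(expert_out_stack, topk_idx, topk_weights))  # [[3.0, 4.0], [9.0, 10.0]]
```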
The kernel accumulates the selected experts' outputs in a single Triton launch. Backward is also Triton-backed:
- `grad_expert_out`: accumulated with atomic adds; each selected route adds its route weight times `grad_out` into that expert's slot.
- `grad_topk_weights`: for each route, the dot product of `grad_out` with the selected expert's output.
- `topk_idx`: treated as integer routing indices; no gradient.
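
The backward rules in the list above follow from `routed_out[t] = sum_k w[t, k] * expert_out_stack[topk_idx[t, k], t]`. Below is a minimal pure-Python sketch of them, assuming `[tokens, k]` routing tensors and a `[num_experts, tokens, dim]` expert stack; `dense_combine_backward` is a hypothetical name, and the serial `+=` stands in for the kernel's atomic adds.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def dense_combine_backward(grad_out: Matrix, expert_out_stack: List[Matrix],
                           topk_idx: List[List[int]], topk_weights: Matrix
                           ) -> Tuple[List[Matrix], Matrix]:
    """Gradients of routed_out[t] = sum_k w[t][k] * expert_out_stack[idx[t][k]][t]."""
    num_experts, tokens, dim = len(expert_out_stack), len(topk_idx), len(grad_out[0])
    # Same [num_experts, tokens, dim] shape as the stack; unselected slots stay 0.
    grad_expert_out = [[[0.0] * dim for _ in range(tokens)] for _ in range(num_experts)]
    grad_topk_weights = [[0.0] * len(row) for row in topk_idx]
    for t in range(tokens):
        for k, (expert, w) in enumerate(zip(topk_idx[t], topk_weights[t])):
            for d in range(dim):
                # d(routed_out)/d(expert_out) = w; the Triton kernel does this with atomic adds.
                grad_expert_out[expert][t][d] += w * grad_out[t][d]
            # d(routed_out)/d(w) = expert row, so the weight gradient is a dot product.
            row = expert_out_stack[expert][t]
            grad_topk_weights[t][k] = sum(g * x for g, x in zip(grad_out[t], row))
    # topk_idx is integer routing data: no gradient is produced for it.
    return grad_expert_out, grad_topk_weights


expert_out_stack = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
    [[9.0, 10.0], [11.0, 12.0]],
]
topk_idx = [[0, 2], [1, 2]]
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
grad_out = [[1.0, 0.0], [0.0, 2.0]]
g_expert, g_weights = dense_combine_backward(grad_out, expert_out_stack, topk_idx, topk_weights)
print(g_weights)  # [[1.0, 9.0], [16.0, 24.0]]
```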

Correctness check against the PyTorch reference:
```text
moe_fwd_maxdiff: 2.38e-07
moe_grad_expert_maxdiff: 0.0
moe_grad_weight_maxdiff: 9.54e-07
```
This phase does not yet fuse the expert ternary projections themselves. It fuses only the dense MoE routing-combine stage and keeps the fixed-shape dispatch from REFACTOR7. Keeping shapes fixed is the practical route to reducing recompilation churn without destabilizing training.
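
One way to see why the dense, fixed-shape path limits recompilation, assuming the compiled kernels are specialized on tensor shapes: under a sparse, gather-based dispatch (a hypothetical alternative, not something this phase implements), each expert's input shape is its routed-token count, which changes from batch to batch. In the dense path the stacked shape depends only on `num_experts`, `tokens`, and `dim`. The helper names below are illustrative only.

```python
from collections import Counter
from typing import List, Tuple


def sparse_dispatch_shapes(topk_idx: List[List[int]], num_experts: int,
                           dim: int) -> List[Tuple[int, int]]:
    """Per-expert input shape if tokens were gathered per expert (routing-dependent)."""
    counts = Counter(expert for row in topk_idx for expert in row)
    return [(counts[e], dim) for e in range(num_experts)]


def dense_stack_shape(topk_idx: List[List[int]], num_experts: int,
                      dim: int) -> Tuple[int, int, int]:
    """Dense path: every expert sees every token, so routing never changes the shape."""
    return (num_experts, len(topk_idx), dim)


batch_a = [[0, 1], [0, 2]]  # top-2 routing for 2 tokens
batch_b = [[1, 2], [1, 2]]
for batch in (batch_a, batch_b):
    print(sparse_dispatch_shapes(batch, 3, 4), dense_stack_shape(batch, 3, 4))
# [(2, 4), (1, 4), (1, 4)] (3, 2, 4)
# [(0, 4), (2, 4), (2, 4)] (3, 2, 4)
```

The sparse shapes differ between the two batches while the dense shape does not, so a shape-specialized compiler would see only one dense signature.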
## Graph Kernel Work
Added a Triton-backed Graph gather/add kernel.
Previous graph position path:
```text
graph_features = node_features[vq_indices]
per_position = vq_output + graph_features
```
New path:
```text
per_position = triton_graph_gather_add(vq_output, node_features, vq_indices)
```
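
As a reference for the semantics (not the Triton kernel), the gather/add is a per-position row lookup followed by an elementwise add. This minimal pure-Python sketch assumes `vq_output` is `[positions, dim]`, `node_features` is `[num_nodes, dim]`, and `vq_indices` holds one node index per position; `graph_gather_add` is a hypothetical name. Each output element is a single add of two inputs with no reduction whose order could vary, which is consistent with the exact `graph_fwd_maxdiff: 0.0` reported below.

```python
from typing import List

Matrix = List[List[float]]


def graph_gather_add(vq_output: Matrix, node_features: Matrix,
                     vq_indices: List[int]) -> Matrix:
    """per_position[p] = vq_output[p] + node_features[vq_indices[p]]."""
    if len(vq_output) != len(vq_indices):
        raise ValueError("need exactly one node index per position")
    return [
        [v + n for v, n in zip(vq_row, node_features[node])]
        for vq_row, node in zip(vq_output, vq_indices)
    ]


node_features = [[10.0, 20.0], [30.0, 40.0]]
vq_output = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
vq_indices = [1, 0, 1]  # positions 0 and 2 share node 1
print(graph_gather_add(vq_output, node_features, vq_indices))
# [[31.0, 42.0], [13.0, 24.0], [35.0, 46.0]]
```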
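Backward is Triton-backed:
- `grad_vq_output`: a direct copy of `grad_out`.
- `grad_node_features`: `grad_out` rows scatter-added at `vq_indices` with atomic adds, so positions that share a node index accumulate into the same row.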
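
A minimal pure-Python sketch of those two backward rules, assuming `grad_out` is `[positions, dim]`, `vq_indices` has one node index per position, and there are `num_nodes` node rows; `graph_gather_add_backward` is a hypothetical name. Positions that share a node index add into the same gradient row, which is why the kernel needs atomics there. Since floating-point addition is not associative, the order in which those adds land can change the last bits, which plausibly explains why `graph_grad_node_maxdiff` below is small but non-zero while the copy-only `grad_vq_output` matches exactly.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def graph_gather_add_backward(grad_out: Matrix, vq_indices: List[int],
                              num_nodes: int) -> Tuple[Matrix, Matrix]:
    """Backward of per_position = vq_output + node_features[vq_indices]."""
    # The residual input receives grad_out unchanged.
    grad_vq_output = [row[:] for row in grad_out]
    dim = len(grad_out[0])
    grad_node_features = [[0.0] * dim for _ in range(num_nodes)]
    for p, node in enumerate(vq_indices):
        for d in range(dim):
            # Scatter-add; the Triton kernel uses an atomic add for this step.
            grad_node_features[node][d] += grad_out[p][d]
    return grad_vq_output, grad_node_features


grad_out = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
g_vq, g_node = graph_gather_add_backward(grad_out, [1, 0, 1], num_nodes=3)
print(g_node)  # [[3.0, 4.0], [6.0, 8.0], [0.0, 0.0]]

# Float addition is order-sensitive, so atomic accumulation order can shift low bits.
print((0.1 + 0.2) + 0.3 == 0.1 + (0.2 + 0.3))  # False
```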

Correctness check against the PyTorch reference:
```text
graph_fwd_maxdiff: 0.0
graph_grad_vq_maxdiff: 0.0
graph_grad_node_maxdiff: 4.77e-07
```
The Graph path now has two Triton-backed custom stages:
1. Edge ternary weighting + target aggregation.
2. VQ-index node gather + residual add.

The remaining unfused Graph work is the full hop body. The goal is to combine packed ternary message projection, aggregation, update projection, residual, and hop-LoRA composition into one larger kernel.
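
Stage 1 above (edge ternary weighting + target aggregation) can be read as a signed, sparse message sum. The sketch below is a hypothetical pure-Python reference that assumes each edge `(src, dst)` carries a ternary weight in {-1, 0, +1} and that target aggregation is a plain sum; the function name, edge-list layout, and sum reduction are assumptions, not the kernel's actual interface. With ternary weights every message is an add, a subtract, or a skip, so no multiplies are needed.

```python
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def ternary_edge_aggregate(h: Matrix, edges: Sequence[Tuple[int, int]],
                           edge_weights: Sequence[int], num_nodes: int) -> Matrix:
    """out[dst] = sum over edges (src, dst) of w * h[src], with w in {-1, 0, +1}."""
    dim = len(h[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for (src, dst), w in zip(edges, edge_weights):
        if w not in (-1, 0, 1):
            raise ValueError(f"edge weight {w} is not ternary")
        if w == 0:
            continue  # zero-weight edge: contributes nothing
        for d in range(dim):
            # +1 adds the source feature, -1 subtracts it
            out[dst][d] += h[src][d] if w == 1 else -h[src][d]
    return out


h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
edges = [(0, 2), (1, 2), (2, 0), (1, 0)]
print(ternary_edge_aggregate(h, edges, [1, -1, 1, 0], num_nodes=3))
# [[5.0, 6.0], [0.0, 0.0], [-2.0, -2.0]]
```

Node 1 has no incoming non-zero edges and stays at zero; node 2 receives `h[0] - h[1]`.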
## Verification
- `python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path"`: `2 passed`
- Isolated MoE combine Triton kernel forward/backward matched the PyTorch reference.
- Isolated Graph gather/add Triton kernel forward/backward matched the PyTorch reference.
- Full CUDA model smoke with VQ, Graph, Memory, and MoE enabled passed forward, backward, and `_ternary_update_memory()`.
- Strict one-step train smoke stayed zero-float:
```text
trainable float params: 0
frozen float params: 0
float buffers: 0
training step: about 1.92 it/s before eval/checkpoint overhead
final val loss: 7.7302
```
## Remaining Kernel Work
1. Fuse the dense MoE expert body itself:
   - packed ternary `W_gate`
   - packed ternary `W_transform`
   - shared hidden multiply
   - packed ternary `shared_down`
   - selected route accumulation
2. Fuse the full Graph hop:
   - packed ternary source projection
   - ternary edge aggregation
   - packed ternary update projection
   - residual + hop-LoRA
3. Add benchmark counters that separate:
   - first-call Triton compile time
   - steady-state forward/backward time
   - ternary update time
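
For item 3 in the list above, a minimal timing harness could look like the sketch below. It assumes the step and update are passed as zero-argument callables (for example, one wrapper that runs a forward/backward step and one that calls `_ternary_update_memory()`); `benchmark_phases` and its result keys are hypothetical. The first call is timed on its own because it includes any one-time Triton compilation, and subtracting the steady-state mean gives a rough compile-overhead estimate. On CUDA, each callable should end with `torch.cuda.synchronize()` so the timers measure finished GPU work rather than asynchronous kernel launches.

```python
import time
from statistics import mean
from typing import Callable, Dict


def benchmark_phases(step: Callable[[], None], ternary_update: Callable[[], None],
                     steady_iters: int = 10) -> Dict[str, float]:
    """Time the first call, steady-state steps, and the ternary update separately."""
    if steady_iters < 1:
        raise ValueError("steady_iters must be at least 1")
    start = time.perf_counter()
    step()  # first call: includes one-time compile/autotune cost, if any
    first_call_s = time.perf_counter() - start

    step_times = []
    for _ in range(steady_iters):
        start = time.perf_counter()
        step()
        step_times.append(time.perf_counter() - start)
    steady_step_s = mean(step_times)

    start = time.perf_counter()
    ternary_update()
    ternary_update_s = time.perf_counter() - start

    return {
        "first_call_s": first_call_s,
        "steady_step_s": steady_step_s,
        # Rough estimate: first call minus one typical steady-state step.
        "compile_overhead_est_s": first_call_s - steady_step_s,
        "ternary_update_s": ternary_update_s,
    }


# Stand-in callables: the first step "compiles" by sleeping once.
_compiled = {"done": False}


def fake_step() -> None:
    if not _compiled["done"]:
        time.sleep(0.05)
        _compiled["done"] = True


stats = benchmark_phases(fake_step, lambda: None)
print(stats)
```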