# True Ternary Refactor 8 — MoE/Graph Triton Kernel Phase

## Scope

This phase keeps the focus on MoE and Graph kernels. It does not change the strict ternary architecture target from REFACTOR6/7.

## MoE Kernel Work

Added a Triton-backed dense MoE combine kernel for the fixed-shape small-batch path.

Previous dense path:

```text
for expert in range(num_experts):
    # per-token weight for this expert, [tokens, 1]; 0 where it was not selected
    expert_weight = torch.where(topk_idx == expert, topk_weights, 0).sum(dim=-1, keepdim=True)
    # expert_out: this expert's output, [tokens, dim]
    routed_out += expert_weight * expert_out
```

New dense path:

```text
expert_out_stack: [num_experts, tokens, dim]   # all expert outputs, stacked
routed_out = triton_moe_dense_combine(expert_out_stack, topk_idx, topk_weights)
```
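
To make the before/after equivalence concrete, here is a plain-Python sketch of what both paths compute. It describes the semantics only, not the Triton kernel: nested lists stand in for tensors, and `combine_per_expert` / `combine_gather` are hypothetical names. The second form is what a single launch can exploit, since each token reads only its `k` selected expert rows instead of visiting all `num_experts`.

```python
# Plain-Python reference for the dense MoE combine (semantics only, not the Triton kernel).
# Assumed layout: expert_out_stack[expert][token][dim]; topk_idx / topk_weights: [token][k].


def combine_per_expert(expert_out_stack, topk_idx, topk_weights):
    """Previous path: visit every expert and mask the routing weights per token."""
    tokens, dim = len(topk_idx), len(expert_out_stack[0][0])
    routed_out = [[0.0] * dim for _ in range(tokens)]
    for expert, expert_out in enumerate(expert_out_stack):
        for t in range(tokens):
            # weight this token gives to `expert` (0 if the expert was not selected)
            w = sum(wk for ek, wk in zip(topk_idx[t], topk_weights[t]) if ek == expert)
            for d in range(dim):
                routed_out[t][d] += w * expert_out[t][d]
    return routed_out


def combine_gather(expert_out_stack, topk_idx, topk_weights):
    """New path: each token reads only its k selected expert rows."""
    tokens, dim = len(topk_idx), len(expert_out_stack[0][0])
    routed_out = [[0.0] * dim for _ in range(tokens)]
    for t in range(tokens):
        for expert, w in zip(topk_idx[t], topk_weights[t]):
            row = expert_out_stack[expert][t]
            for d in range(dim):
                routed_out[t][d] += w * row[d]
    return routed_out


# Usage: 3 experts, 2 tokens, dim 2, top-2 routing.
expert_out_stack = [
    [[1.0, 2.0], [3.0, 4.0]],   # expert 0
    [[0.5, -1.0], [2.0, 0.0]],  # expert 1
    [[-2.0, 1.0], [1.0, 1.0]],  # expert 2
]
topk_idx = [[0, 2], [1, 2]]
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
out_loop = combine_per_expert(expert_out_stack, topk_idx, topk_weights)
out_gather = combine_gather(expert_out_stack, topk_idx, topk_weights)
print(out_gather)  # [[0.25, 1.75], [1.5, 0.5]]
```

The dyadic weights keep every sum exact here, so both paths agree bit-for-bit; with arbitrary float weights they would agree only up to rounding, because the accumulation order differs.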

The kernel accumulates the selected experts' outputs in a single Triton launch. Backward is also Triton-backed:

- `grad_expert_out`: accumulated atomically; each selected `(token, expert)` slot receives `grad_out[token]` scaled by its route weight.
- `grad_topk_weights`: dot product of `grad_out[token]` with the selected expert's output for that token.
- `topk_idx`: integer routing indices; no gradient.
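
The backward rules listed here follow from differentiating `routed_out[t] = sum_k topk_weights[t][k] * expert_out_stack[topk_idx[t][k]][t]`. A minimal plain-Python sketch of those rules, assuming `expert_out_stack` is indexed `[expert][token][dim]` and `topk_idx` / `topk_weights` are `[token][k]`; `combine_backward` is a hypothetical name, and the `+=` into `grad_expert_out` is the step a kernel would do with atomic adds.

```python
# Plain-Python sketch of the dense MoE combine backward (semantics only).
# Forward: routed_out[t] = sum_k topk_weights[t][k] * expert_out_stack[topk_idx[t][k]][t]


def combine_backward(expert_out_stack, topk_idx, topk_weights, grad_out):
    num_experts, tokens, dim = len(expert_out_stack), len(topk_idx), len(grad_out[0])
    grad_expert_out = [[[0.0] * dim for _ in range(tokens)] for _ in range(num_experts)]
    grad_topk_weights = [[0.0] * len(row) for row in topk_idx]
    for t in range(tokens):
        for k, (expert, w) in enumerate(zip(topk_idx[t], topk_weights[t])):
            # grad w.r.t. the selected expert row: route weight times grad_out[t]
            # (a GPU kernel would perform this accumulation with atomic adds)
            for d in range(dim):
                grad_expert_out[expert][t][d] += w * grad_out[t][d]
            # grad w.r.t. the route weight: <grad_out[t], selected expert row>
            row = expert_out_stack[expert][t]
            grad_topk_weights[t][k] = sum(g * x for g, x in zip(grad_out[t], row))
    # topk_idx is integer routing data and receives no gradient
    return grad_expert_out, grad_topk_weights


expert_out_stack = [
    [[1.0, 2.0], [3.0, 4.0]],   # expert 0
    [[0.5, -1.0], [2.0, 0.0]],  # expert 1
    [[-2.0, 1.0], [1.0, 1.0]],  # expert 2
]
topk_idx = [[0, 2], [1, 2]]
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
grad_out = [[1.0, 0.0], [0.5, 2.0]]
g_expert, g_weights = combine_backward(expert_out_stack, topk_idx, topk_weights, grad_out)
print(g_weights)  # [[1.0, -2.0], [1.0, 2.5]]
```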

Correctness check against the PyTorch reference:

```text
moe_fwd_maxdiff: 2.38e-07
moe_grad_expert_maxdiff: 0.0
moe_grad_weight_maxdiff: 9.54e-07
```

This does not yet fuse the expert ternary projections themselves. It fuses only the dense MoE routing combine stage and keeps the fixed-shape dispatch from REFACTOR7, which is the practical route to reducing recompilation churn without destabilizing training.
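
This note does not spell out how the REFACTOR7 fixed-shape dispatch works. One common way to bound recompilation is to round the token count up to a small set of bucket sizes and pad, so only a handful of distinct shapes ever reach the compiler. The sketch below illustrates that idea under assumed bucket sizes; `BUCKETS`, `bucket_tokens`, and `pad_rows` are hypothetical and not taken from the project.

```python
# Hypothetical bucketed padding for a fixed-shape small-batch path (illustrative only).
BUCKETS = (16, 32, 64, 128)  # assumed bucket sizes, not taken from REFACTOR7


def bucket_tokens(n_tokens):
    """Smallest bucket that fits n_tokens; each bucket corresponds to one compiled shape."""
    for size in BUCKETS:
        if n_tokens <= size:
            return size
    raise ValueError(f"{n_tokens} tokens exceeds the largest bucket ({BUCKETS[-1]})")


def pad_rows(rows, target, fill):
    """Pad a [tokens][...] list up to `target` rows with copies of `fill`."""
    return rows + [list(fill) for _ in range(target - len(rows))]


# Any token count from 1 to 128 maps onto one of only four shapes.
compiled_shapes = {bucket_tokens(n) for n in range(1, 129)}
# Padded tokens get zero route weight, so they add nothing to routed_out.
weights = pad_rows([[0.75, 0.25] for _ in range(5)], bucket_tokens(5), fill=[0.0, 0.0])
print(sorted(compiled_shapes), len(weights))  # [16, 32, 64, 128] 16
```

The trade-off is wasted work on padded rows in exchange for a bounded number of compilations; smaller bucket steps waste less but compile more shapes.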

## Graph Kernel Work

Added a Triton-backed Graph gather/add kernel.

Previous graph position path:

```text
graph_features = node_features[vq_indices]
per_position = vq_output + graph_features
```

New path:

```text
per_position = triton_graph_gather_add(vq_output, node_features, vq_indices)
```
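
The fused forward is one row gather plus one add per position. A plain-Python sketch of the semantics, assuming `vq_output` is `[positions][dim]`, `node_features` is `[nodes][dim]`, and `vq_indices` holds one node index per position; in a kernel, each program would typically handle one position or a block of positions. The function name is hypothetical.

```python
# Plain-Python sketch of the Graph gather/add forward (semantics only).
# Assumed layout: vq_output[pos][dim], node_features[node][dim], vq_indices[pos] -> node id.


def graph_gather_add(vq_output, node_features, vq_indices):
    per_position = []
    for vq_row, node in zip(vq_output, vq_indices):
        # one "program" per position: gather the node row, then add the VQ output
        node_row = node_features[node]
        per_position.append([v + n for v, n in zip(vq_row, node_row)])
    return per_position


node_features = [[1.0, 2.0], [10.0, 20.0], [100.0, 200.0]]
vq_indices = [2, 0, 2]  # positions 0 and 2 share node 2
vq_output = [[0.5, 0.5], [1.0, 1.0], [0.0, -1.0]]
per_position = graph_gather_add(vq_output, node_features, vq_indices)
print(per_position)  # [[100.5, 200.5], [2.0, 3.0], [100.0, 199.0]]
```

Fusing the gather into the add avoids materializing `graph_features` as a separate `[positions, dim]` tensor.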

Backward is Triton-backed:

- `grad_vq_output`: direct copy of `grad_out`.
- `grad_node_features`: `grad_out` rows atomically added into the node rows selected by `vq_indices`; atomics are needed because several positions can map to the same node.

Correctness check against the PyTorch reference:

```text
graph_fwd_maxdiff: 0.0
graph_grad_vq_maxdiff: 0.0
graph_grad_node_maxdiff: 4.77e-07
```
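
The nonzero differences in these checks all sit at float32 rounding scale. A plausible explanation, not verified in this phase, is summation order: the kernel and the reference can accumulate in different orders, and atomic adds into a shared node row can complete in any order, while floating-point addition is not associative. The snippet below shows the effect with Python floats (float64, so the magnitudes are smaller than the float32 diffs above); `sum_in_order` is a hypothetical helper.

```python
import itertools


def sum_in_order(values):
    """Sequential float accumulation, like one possible completion order of atomic adds."""
    total = 0.0
    for v in values:
        total += v
    return total


contributions = [0.1, 0.2, 0.3]  # three positions writing into the same node row
results = {sum_in_order(order) for order in itertools.permutations(contributions)}
print(sorted(results))  # [0.6, 0.6000000000000001]
spread = max(results) - min(results)
```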

The Graph path now has two Triton-backed custom stages:

1. Edge ternary weighting + target aggregation.
2. VQ-index node gather + residual add.
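
This note does not define stage 1 precisely. Assuming each edge carries a ternary weight in {-1, 0, +1} and its message is summed into the edge's target node, the stage reduces to signed adds with no multiplies. A plain-Python sketch under that assumption; `ternary_edge_aggregate` and its parameters are hypothetical.

```python
# Plain-Python sketch of ternary edge weighting + target aggregation.
# Assumption (not stated in the note): edge e sends trit[e] * x[src[e]] to node dst[e],
# with trit[e] in {-1, 0, +1}.


def ternary_edge_aggregate(x, edge_src, edge_dst, edge_trit, num_nodes):
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for src, dst, trit in zip(edge_src, edge_dst, edge_trit):
        if trit not in (-1, 0, 1):
            raise ValueError(f"edge weight {trit} is not ternary")
        if trit == 0:
            continue  # zero-weight edge: no message
        for d in range(dim):
            # a ternary weight means add or subtract; no multiply is needed
            out[dst][d] += x[src][d] if trit == 1 else -x[src][d]
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
agg = ternary_edge_aggregate(x, edge_src=[0, 1, 2, 0], edge_dst=[1, 1, 0, 2],
                             edge_trit=[1, -1, 1, 0], num_nodes=3)
print(agg)  # [[5.0, 6.0], [-2.0, -2.0], [0.0, 0.0]]
```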

The remaining unfused Graph work is the full hop body: fusing packed ternary message projection, aggregation, update projection, residual, and hop-LoRA composition into one larger kernel.
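
This note does not specify the packed ternary layout. As an assumption for illustration only, the sketch stores each weight as a 2-bit code, four weights per byte, and applies a packed row to a vector with adds and subtracts. `pack_trits`, `unpack_trits`, and `packed_ternary_matvec` are hypothetical names; a real kernel would unpack in registers rather than build Python lists.

```python
# Hypothetical 2-bit packing: four ternary weights per byte (assumed layout, for illustration).
ENCODE = {0: 0b00, 1: 0b01, -1: 0b10}  # 0b11 is unused
DECODE = {code: trit for trit, code in ENCODE.items()}


def pack_trits(trits):
    packed = bytearray((len(trits) + 3) // 4)
    for i, trit in enumerate(trits):
        packed[i // 4] |= ENCODE[trit] << (2 * (i % 4))
    return bytes(packed)


def unpack_trits(packed, n):
    return [DECODE[(packed[i // 4] >> (2 * (i % 4))) & 0b11] for i in range(n)]


def packed_ternary_matvec(packed_rows, x):
    """y[r] = sum_i W[r][i] * x[i], computed with adds and subtracts only."""
    y = []
    for row in packed_rows:
        acc = 0.0
        for w, xi in zip(unpack_trits(row, len(x)), x):
            if w == 1:
                acc += xi
            elif w == -1:
                acc -= xi
        y.append(acc)
    return y


W = [[1, 0, -1, 1, -1], [0, 0, 1, -1, 1]]
packed_W = [pack_trits(row) for row in W]  # 5 weights -> 2 bytes per row
x = [1.0, 2.0, 3.0, 4.0, 5.0]
print(packed_ternary_matvec(packed_W, x))  # [-3.0, 4.0]
```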

## Verification

- `python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path"`: `2 passed`
- Isolated MoE combine Triton kernel: forward and backward matched the PyTorch reference.
- Isolated Graph gather/add Triton kernel: forward and backward matched the PyTorch reference.
- Full CUDA model smoke test with VQ, Graph, Memory, and MoE enabled passed forward, backward, and `_ternary_update_memory()`.
- Strict one-step training smoke run stayed zero-float:

```text
trainable float params: 0
frozen float params: 0
float buffers: 0
training step: about 1.92 it/s before eval/checkpoint overhead
final val loss: 7.7302
```

## Remaining Kernel Work

1. Fuse the dense MoE expert body itself:
   - packed ternary `W_gate`
   - packed ternary `W_transform`
   - shared hidden multiply
   - packed ternary `shared_down`
   - selected route accumulation

2. Fuse the full Graph hop:
   - packed ternary source projection
   - ternary edge aggregation
   - packed ternary update projection
   - residual + hop-LoRA

3. Add benchmark counters that separate:
   - first-call Triton compile time
   - steady-state forward/backward time
   - ternary update time
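
The benchmark counters in item 3 could take the shape of a small timing harness that records the first call (where Triton JIT compilation lands) separately from a steady-state median. A sketch, assuming each phase is wrapped in a zero-argument callable, for example one around forward/backward and one around `_ternary_update_memory()`; `benchmark` and the phase names are hypothetical. On CUDA, each callable should end with `torch.cuda.synchronize()` so the host clock measures finished kernels rather than queued launches.

```python
import statistics
import time


def benchmark(phases, warmup=1, iters=10):
    """Time each phase: first call (includes any JIT compile) vs steady-state median.

    phases: dict of name -> zero-argument callable. On CUDA, each callable should
    end with torch.cuda.synchronize() so timings cover completed kernels.
    """
    report = {}
    for name, fn in phases.items():
        t0 = time.perf_counter()
        fn()  # first call: Triton compilation would land here
        first_call_s = time.perf_counter() - t0
        for _ in range(warmup):
            fn()
        samples = []
        for _ in range(iters):
            t0 = time.perf_counter()
            fn()
            samples.append(time.perf_counter() - t0)
        report[name] = {"first_call_s": first_call_s,
                        "steady_median_s": statistics.median(samples)}
    return report


# Usage with stand-ins: a slower first call simulates compilation.
calls = {"fwd_bwd": 0}


def fake_fwd_bwd():
    calls["fwd_bwd"] += 1
    time.sleep(0.02 if calls["fwd_bwd"] == 1 else 0.001)


report = benchmark({"fwd_bwd": fake_fwd_bwd, "ternary_update": lambda: None}, iters=5)
print(report["fwd_bwd"])
```

Using the median rather than the mean keeps one slow outlier (for example a late recompilation) from skewing the steady-state number, while the first-call column makes such compile costs visible on their own.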