True Ternary Refactor 8 — MoE/Graph Triton Kernel Phase
Scope
This phase keeps the focus on MoE and Graph kernels. It does not change the strict ternary architecture target from REFACTOR6/7.
MoE Kernel Work
Added a Triton-backed dense MoE combine kernel for the fixed-shape small-batch path.
Previous dense path:
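for expert in range(num_experts):
    # per-token weight for this expert; 0 where the expert was not selected
    expert_weight = torch.where(topk_idx == expert, topk_weights, 0.0).sum(dim=-1, keepdim=True)
    routed_out += expert_weight * expert_out  # expert_out: this expert's [tokens, dim] output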
New dense path:
# expert_out_stack: [num_experts, tokens, dim] (per-expert outputs stacked along dim 0)
routed_out = triton_moe_dense_combine(expert_out_stack, topk_idx, topk_weights)
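As a reference for what the combine stage computes, here is a minimal NumPy sketch. It assumes `topk_idx` and `topk_weights` are `[tokens, k]` arrays and that each token's output is the weight-scaled sum of its selected experts' rows. The helper names `moe_dense_combine_ref` and `moe_dense_loop_ref` are hypothetical; the sketch only checks that the single gather form matches the per-expert loop from the previous path.

```python
import numpy as np


def moe_dense_combine_ref(expert_out_stack, topk_idx, topk_weights):
    """Gather form: out[t] = sum_k topk_weights[t, k] * expert_out_stack[topk_idx[t, k], t]."""
    tokens = expert_out_stack.shape[1]
    token_ids = np.arange(tokens)[:, None]                    # [tokens, 1], broadcasts over k
    selected = expert_out_stack[topk_idx, token_ids]          # [tokens, k, dim]
    return (topk_weights[..., None] * selected).sum(axis=1)   # [tokens, dim]


def moe_dense_loop_ref(expert_out_stack, topk_idx, topk_weights):
    """Per-expert loop, mirroring the previous dense path."""
    num_experts, tokens, dim = expert_out_stack.shape
    routed_out = np.zeros((tokens, dim), dtype=expert_out_stack.dtype)
    for expert in range(num_experts):
        # Per-token weight for this expert (0 where the expert was not selected).
        expert_weight = np.where(topk_idx == expert, topk_weights, 0.0).sum(axis=-1, keepdims=True)
        routed_out += expert_weight * expert_out_stack[expert]
    return routed_out


rng = np.random.default_rng(0)
E, T, D, K = 4, 6, 8, 2
stack = rng.standard_normal((E, T, D))
idx = np.stack([rng.choice(E, size=K, replace=False) for _ in range(T)])
w = rng.random((T, K))

out = moe_dense_combine_ref(stack, idx, w)
max_diff = np.abs(out - moe_dense_loop_ref(stack, idx, w)).max()
print(f"max |gather - loop| = {max_diff:.3e}")
```

The gather form touches only the k selected experts per token, while the loop visits every expert; that difference is what makes a single-launch kernel attractive.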
The kernel accumulates the selected experts' outputs in a single Triton launch. Backward is also Triton-backed:
- grad_expert_out: accumulated atomically from the selected route weights.
- grad_topk_weights: dot product of grad_out with the selected expert output.
- topk_idx: treated as routing indices; no gradient.
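The three backward rules can be written as a NumPy reference and checked with finite differences. This is a sketch under the same `[tokens, k]` routing assumption; `moe_dense_combine_backward_ref` is a hypothetical name, and `np.add.at` stands in for the kernel's atomic accumulation because it sums contributions that land on the same index.

```python
import numpy as np


def moe_dense_combine_backward_ref(grad_out, expert_out_stack, topk_idx, topk_weights):
    """Reference gradients for out[t] = sum_k w[t, k] * expert_out_stack[idx[t, k], t]."""
    tokens = expert_out_stack.shape[1]
    token_ids = np.broadcast_to(np.arange(tokens)[:, None], topk_idx.shape)  # [tokens, k]

    # grad_expert_out: scatter w[t, k] * grad_out[t] into row (idx[t, k], t).
    # np.add.at sums repeated indices, standing in for the kernel's atomic accumulation.
    grad_expert_out = np.zeros_like(expert_out_stack)
    np.add.at(grad_expert_out, (topk_idx, token_ids),
              topk_weights[..., None] * grad_out[:, None, :])

    # grad_topk_weights: dot product of grad_out[t] with the selected expert output row.
    selected = expert_out_stack[topk_idx, token_ids]                  # [tokens, k, dim]
    grad_topk_weights = np.einsum("td,tkd->tk", grad_out, selected)

    # topk_idx holds integer routing indices, so it receives no gradient.
    return grad_expert_out, grad_topk_weights


rng = np.random.default_rng(1)
E, T, D, K = 4, 5, 3, 2
stack = rng.standard_normal((E, T, D))
idx = np.stack([rng.choice(E, size=K, replace=False) for _ in range(T)])
w = rng.random((T, K))
g = rng.standard_normal((T, D))  # upstream gradient grad_out


def loss(stack_, w_):
    """Scalar whose gradient with respect to the combine output is exactly g."""
    out = (w_[..., None] * stack_[idx, np.arange(T)[:, None]]).sum(axis=1)
    return float((out * g).sum())


g_stack, g_w = moe_dense_combine_backward_ref(g, stack, idx, w)

# Central differences; the loss is linear in both inputs, so these match up to rounding.
eps = 1e-6
w_p, w_m = w.copy(), w.copy()
w_p[2, 1] += eps
w_m[2, 1] -= eps
fd_w = (loss(stack, w_p) - loss(stack, w_m)) / (2 * eps)

e0 = idx[0, 0]
s_p, s_m = stack.copy(), stack.copy()
s_p[e0, 0, 1] += eps
s_m[e0, 0, 1] -= eps
fd_s = (loss(s_p, w) - loss(s_m, w)) / (2 * eps)
print(abs(fd_w - g_w[2, 1]), abs(fd_s - g_stack[e0, 0, 1]))
```

Only the `tokens * k` selected `(expert, token)` rows of `grad_expert_out` are nonzero, which is why the kernel can write them with scattered atomics instead of a dense pass over all experts.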
Correctness check against the PyTorch reference:
moe_fwd_maxdiff: 2.38e-07
moe_grad_expert_maxdiff: 0.0
moe_grad_weight_maxdiff: 9.54e-07
This does not yet fuse the expert ternary projections themselves. It fuses the dense MoE routing combine stage and keeps the fixed-shape dispatch from REFACTOR7, which is the practical route to reducing recompilation churn without destabilizing training.
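These notes do not describe how the REFACTOR7 fixed-shape dispatch is implemented. One common way to bound recompilation, shown here purely as an assumption, is to pad the token dimension up to a small set of bucket sizes so a JIT-compiled kernel only ever sees those shapes. `BUCKETS` and `pad_to_bucket` are hypothetical.

```python
import bisect

import numpy as np

BUCKETS = (16, 32, 64, 128)  # hypothetical token-count buckets


def pad_to_bucket(x, buckets=BUCKETS):
    """Zero-pad x along dim 0 to the smallest bucket >= len(x); return (padded, valid_count)."""
    n = x.shape[0]
    i = bisect.bisect_left(buckets, n)
    if i == len(buckets):
        raise ValueError(f"{n} tokens exceeds largest bucket {buckets[-1]}")
    padded = np.zeros((buckets[i],) + x.shape[1:], dtype=x.dtype)
    padded[:n] = x
    return padded, n


shapes_seen = set()
for n_tokens in (3, 17, 20, 31, 64, 100):
    padded, valid = pad_to_bucket(np.ones((n_tokens, 8), dtype=np.float32))
    shapes_seen.add(padded.shape)
print(sorted(shapes_seen))
```

Six different batch sizes map onto four shapes. For the MoE combine, padded rows would also need zero route weights so they contribute nothing, and outputs would be sliced back to `valid` rows.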
Graph Kernel Work
Added a Triton-backed Graph gather/add kernel:
Previous graph position path:
graph_features = node_features[vq_indices]
per_position = vq_output + graph_features
New path:
per_position = triton_graph_gather_add(vq_output, node_features, vq_indices)
Backward is Triton-backed:
- grad_vq_output: direct copy from grad_out.
- grad_node_features: atomic add of grad_out rows into the nodes indexed by vq_indices.
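The gather/add stage and its backward can be stated as a short NumPy reference. The function names are hypothetical. The example shows why `grad_node_features` needs an atomic (accumulating) add: several positions can map to the same node, and a plain indexed write keeps only one of their contributions.

```python
import numpy as np


def graph_gather_add_ref(vq_output, node_features, vq_indices):
    """Forward: per_position[p] = vq_output[p] + node_features[vq_indices[p]]."""
    return vq_output + node_features[vq_indices]


def graph_gather_add_backward_ref(grad_out, num_nodes, vq_indices):
    """Backward of the gather/add above."""
    grad_vq_output = grad_out.copy()  # residual branch: direct copy
    grad_node_features = np.zeros((num_nodes,) + grad_out.shape[1:], dtype=grad_out.dtype)
    # np.add.at sums duplicate indices, like the kernel's atomic add.
    np.add.at(grad_node_features, vq_indices, grad_out)
    return grad_vq_output, grad_node_features


vq_indices = np.array([0, 2, 2, 1, 2])
vq_output = np.zeros((5, 3))
node_features = np.arange(9, dtype=np.float64).reshape(3, 3)
per_position = graph_gather_add_ref(vq_output, node_features, vq_indices)

grad_out = np.ones((5, 3))
g_vq, g_nodes = graph_gather_add_backward_ref(grad_out, 3, vq_indices)
print(g_nodes[:, 0])  # [1. 1. 3.]  (node 2 is referenced three times)

naive = np.zeros_like(g_nodes)
naive[vq_indices] += grad_out  # buffered write: duplicate indices are applied once, not summed
print(naive[:, 0])    # [1. 1. 1.]
```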
Correctness check against the PyTorch reference:
graph_fwd_maxdiff: 0.0
graph_grad_vq_maxdiff: 0.0
graph_grad_node_maxdiff: 4.77e-07
The Graph path now has two Triton-backed custom stages:
- Edge ternary weighting + target aggregation.
- VQ-index node gather + residual add.
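For the first of these stages, a plausible reading of "edge ternary weighting + target aggregation" is sketched below. It assumes each edge `(src -> dst)` carries a weight in {-1, 0, +1} and that weighted source features are summed at the target node; the edge representation and the name `ternary_edge_aggregate_ref` are assumptions, not taken from the kernel.

```python
import numpy as np


def ternary_edge_aggregate_ref(h, edge_src, edge_dst, edge_weight):
    """agg[v] = sum over edges e = (u -> v) of edge_weight[e] * h[u], edge_weight in {-1, 0, +1}."""
    assert np.isin(edge_weight, (-1, 0, 1)).all(), "edge weights must be ternary"
    # A ternary weight turns each message into add, subtract, or skip.
    messages = edge_weight[:, None].astype(h.dtype) * h[edge_src]  # [edges, dim]
    agg = np.zeros_like(h)
    np.add.at(agg, edge_dst, messages)  # several edges can share a target node
    return agg


h = np.array([[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]])
edge_src = np.array([0, 1, 2, 2])
edge_dst = np.array([1, 1, 0, 1])
edge_weight = np.array([1, -1, 1, 0])
agg = ternary_edge_aggregate_ref(h, edge_src, edge_dst, edge_weight)
print(agg)
```

Node 1 receives `+h[0] - h[1]` and nothing from the zero-weight edge; node 0 receives `+h[2]`.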
The remaining unfused Graph work is the full hop body: packed ternary message projection, aggregation, update projection, residual, and hop-LoRA composition in one larger kernel.
Verification
- python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py
- python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path": 2 passed
- Isolated MoE combine Triton kernel forward/backward matched the PyTorch reference.
- Isolated Graph gather/add Triton kernel forward/backward matched the PyTorch reference.
- Full CUDA model smoke with VQ, Graph, Memory, and MoE enabled passed forward, backward, and _ternary_update_memory().
- Strict one-step train smoke stayed zero-float:
trainable float params: 0
frozen float params: 0
float buffers: 0
training step: about 1.92 it/s before eval/checkpoint overhead
final val loss: 7.7302
Remaining Kernel Work
Fuse the dense MoE expert body itself:
- packed ternary W_gate
- packed ternary W_transform
- shared hidden multiply
- packed ternary shared_down
- selected route accumulation
- packed ternary
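Most of the remaining items operate on packed ternary weights. The notes do not specify the packed layout, so this NumPy sketch assumes a hypothetical 2-bit code (0 -> 0, 1 -> +1, 2 -> -1) with four weights per `uint8`. `pack_ternary`, `unpack_ternary`, and `packed_ternary_matmul` are illustrative names only.

```python
import numpy as np

# Hypothetical 2-bit decode table: code 0 -> 0, 1 -> +1, 2 -> -1 (3 is unused).
_DECODE = np.array([0, 1, -1, 0], dtype=np.int8)
_SHIFTS = np.array([0, 2, 4, 6], dtype=np.uint8)


def pack_ternary(w):
    """Pack values in {-1, 0, 1} along the last axis, 4 per byte (zero-padded)."""
    assert np.isin(w, (-1, 0, 1)).all()
    codes = np.where(w == -1, 2, w).astype(np.uint8)          # map to 2-bit codes
    pad = (-codes.shape[-1]) % 4
    codes = np.pad(codes, [(0, 0)] * (codes.ndim - 1) + [(0, pad)])
    codes = codes.reshape(*codes.shape[:-1], -1, 4)           # [..., bytes, 4]
    return (codes << _SHIFTS).sum(axis=-1).astype(np.uint8)


def unpack_ternary(packed, n):
    """Inverse of pack_ternary: recover the first n ternary values along the last axis."""
    codes = (packed[..., None] >> _SHIFTS) & 0b11              # [..., bytes, 4]
    return _DECODE[codes.reshape(*packed.shape[:-1], -1)[..., :n]]


def packed_ternary_matmul(x, packed, n_in):
    """y = x @ W.T with W stored packed as [out, ceil(n_in / 4)] bytes."""
    # A fused kernel would decode tiles in registers instead of materializing W.
    w = unpack_ternary(packed, n_in).astype(x.dtype)
    return x @ w.T


rng = np.random.default_rng(0)
W = rng.integers(-1, 2, size=(3, 10))  # ternary weights, [out=3, in=10]
packed = pack_ternary(W)
print(packed.shape)  # (3, 3): 10 values need 3 bytes per row
x = rng.standard_normal((2, 10))
y = packed_ternary_matmul(x, packed, 10)
```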
Fuse the full Graph hop:
- packed ternary source projection
- ternary edge aggregation
- packed ternary update projection
- residual + hop-LoRA
Add benchmark counters that separate:
- first-call Triton compile time
- steady-state forward/backward time
- ternary update time
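A minimal timing harness along these lines could keep the three counters apart. It is a sketch: `bench_phases` is a hypothetical helper, and the `sync` argument is where a GPU run would pass `torch.cuda.synchronize` so queued kernels finish before the clock stops. The first call is timed on its own because Triton compiles a kernel specialization on its first launch.

```python
import statistics
import time


def bench_phases(step, sync=lambda: None, warmup=3, iters=20):
    """Time the first call of `step` separately from its steady-state median.

    `step` is a zero-argument callable (e.g. one forward/backward, or one ternary
    update); `sync` should block until queued device work has finished.
    """
    sync()
    t0 = time.perf_counter()
    step()                       # first call: includes any JIT/Triton compile
    sync()
    first_call_s = time.perf_counter() - t0

    for _ in range(warmup):      # let caches and autotuning settle
        step()
    sync()

    samples = []
    for _ in range(iters):
        t0 = time.perf_counter()
        step()
        sync()
        samples.append(time.perf_counter() - t0)
    return {"first_call_s": first_call_s, "steady_median_s": statistics.median(samples)}


calls = {"n": 0}


def fake_step():
    calls["n"] += 1
    time.sleep(0.05 if calls["n"] == 1 else 0.001)  # simulated compile cost on first call


stats = bench_phases(fake_step, warmup=1, iters=5)
print(stats)
```

Running the harness once with a forward/backward callable and once with a callable that performs only the ternary update would yield the three separate numbers listed above.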