True Ternary Refactor 8 — MoE/Graph Triton Kernel Phase

Scope

This phase keeps the focus on MoE and Graph kernels. It does not change the strict ternary architecture target from REFACTOR6/7.

MoE Kernel Work

Added a Triton-backed dense MoE combine kernel for the fixed-shape small-batch path.

Previous dense path:

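for expert in range(num_experts):
    # Per-token weight routed to this expert; zero where it was not selected.
    expert_weight = torch.where(topk_idx == expert, topk_weights, 0).sum(dim=-1, keepdim=True)
    # expert_out is this expert's [tokens, dim] output.
    routed_out += expert_weight * expert_out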

New dense path:

# expert_out_stack has shape [num_experts, tokens, dim]
routed_out = triton_moe_dense_combine(expert_out_stack, topk_idx, topk_weights)

The kernel accumulates the selected expert outputs for every token in a single Triton launch. Backward is also Triton-backed:

  • grad_expert_out: accumulated with atomic adds; each selected expert slot receives grad_out scaled by its route weight.
  • grad_topk_weights: dot product of grad_out with the selected expert's output.
  • topk_idx: integer routing indices; no gradient.
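As a rough illustration of what the combine and its backward compute, here is a minimal eager PyTorch sketch. The function names `moe_dense_combine_reference` and `moe_dense_combine_backward_reference` are hypothetical, as are the shapes `topk_idx: [tokens, k]` and `topk_weights: [tokens, k]`; this is not the Triton kernel or the reference used for the numbers in this document. It only checks the hand-written gradients from the bullets above against autograd, and checks the stacked forward against the per-expert loop.

```python
import torch


def moe_dense_combine_reference(expert_out_stack, topk_idx, topk_weights):
    """Hypothetical eager reference: weighted sum of each token's selected experts."""
    tokens = topk_idx.shape[0]
    token_ids = torch.arange(tokens, device=topk_idx.device).unsqueeze(-1)  # [tokens, 1]
    selected = expert_out_stack[topk_idx, token_ids]  # [tokens, k, dim]
    return (topk_weights.unsqueeze(-1) * selected).sum(dim=1)  # [tokens, dim]


def moe_dense_combine_backward_reference(grad_out, expert_out_stack, topk_idx, topk_weights):
    """Hand-written gradients following the bullets above (assumed shapes)."""
    tokens = topk_idx.shape[0]
    token_ids = torch.arange(tokens, device=topk_idx.device).unsqueeze(-1)
    selected = expert_out_stack[topk_idx, token_ids]  # [tokens, k, dim]
    # grad_topk_weights[t, j] = dot(grad_out[t], selected[t, j])
    grad_topk_weights = (selected * grad_out.unsqueeze(1)).sum(dim=-1)
    # grad_expert_out[topk_idx[t, j], t] += topk_weights[t, j] * grad_out[t]
    contrib = topk_weights.unsqueeze(-1) * grad_out.unsqueeze(1)  # [tokens, k, dim]
    grad_expert_out = torch.zeros_like(expert_out_stack)
    grad_expert_out.index_put_(
        (topk_idx, token_ids.expand_as(topk_idx)), contrib, accumulate=True
    )
    return grad_expert_out, grad_topk_weights


torch.manual_seed(0)
num_experts, tokens, dim, k = 4, 6, 8, 2
expert_out_stack = torch.randn(
    num_experts, tokens, dim, dtype=torch.float64, requires_grad=True
)
router_probs = torch.softmax(torch.randn(tokens, num_experts, dtype=torch.float64), dim=-1)
topk_weights, topk_idx = router_probs.topk(k, dim=-1)
topk_weights.requires_grad_(True)

routed_out = moe_dense_combine_reference(expert_out_stack, topk_idx, topk_weights)

# The per-expert loop of the previous dense path should give the same result.
loop_out = torch.zeros(tokens, dim, dtype=torch.float64)
for expert in range(num_experts):
    mask = topk_idx == expert
    expert_weight = torch.where(mask, topk_weights, torch.zeros_like(topk_weights))
    loop_out = loop_out + expert_weight.sum(dim=-1, keepdim=True) * expert_out_stack[expert]

grad_out = torch.randn_like(routed_out)
routed_out.backward(grad_out)
grad_expert_out, grad_topk_weights = moe_dense_combine_backward_reference(
    grad_out, expert_out_stack.detach(), topk_idx, topk_weights.detach()
)

fwd_maxdiff = (routed_out - loop_out).abs().max().item()
grad_expert_maxdiff = (expert_out_stack.grad - grad_expert_out).abs().max().item()
grad_weight_maxdiff = (topk_weights.grad - grad_topk_weights).abs().max().item()
print(fwd_maxdiff, grad_expert_maxdiff, grad_weight_maxdiff)
```

The scatter with `accumulate=True` plays the role that atomic adds play on the GPU: several (token, slot) pairs may write into the same expert row, so writes must be summed instead of overwritten.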

Correctness check against PyTorch reference:

moe_fwd_maxdiff:          2.38e-07
moe_grad_expert_maxdiff:  0.0
moe_grad_weight_maxdiff:  9.54e-07

This does not yet fuse the expert ternary projections themselves. It fuses only the dense MoE routing combine stage and keeps the fixed-shape dispatch from REFACTOR7, which remains the practical way to reduce recompilation churn without destabilizing training.

Graph Kernel Work

Added a Triton-backed Graph gather/add kernel.

Previous graph position path:

graph_features = node_features[vq_indices]
per_position = vq_output + graph_features

New path:

per_position = triton_graph_gather_add(vq_output, node_features, vq_indices)

Backward is Triton-backed:

  • grad_vq_output: direct copy from grad_out.
  • grad_node_features: atomic add by vq_indices.
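The gather/add and its two gradients can be sketched in eager PyTorch as follows. This is an illustrative reference only, not the Triton kernel: the function names are hypothetical, and it assumes `vq_output` is flattened to `[positions, dim]`, `node_features` is `[num_nodes, dim]`, and `vq_indices` is an int64 tensor of shape `[positions]`.

```python
import torch


def graph_gather_add_reference(vq_output, node_features, vq_indices):
    """Hypothetical eager reference: gather node rows by VQ index, then residual add."""
    return vq_output + node_features[vq_indices]


def graph_gather_add_backward_reference(grad_out, vq_indices, num_nodes):
    """Hand-written gradients following the two bullets above (assumed shapes)."""
    # grad_vq_output is a direct copy of grad_out.
    grad_vq_output = grad_out.clone()
    # grad_node_features sums grad_out rows that share the same VQ index.
    grad_node_features = torch.zeros(
        num_nodes, grad_out.shape[-1], dtype=grad_out.dtype, device=grad_out.device
    )
    grad_node_features.index_add_(0, vq_indices, grad_out)
    return grad_vq_output, grad_node_features


torch.manual_seed(0)
positions, num_nodes, dim = 10, 4, 8
vq_output = torch.randn(positions, dim, dtype=torch.float64, requires_grad=True)
node_features = torch.randn(num_nodes, dim, dtype=torch.float64, requires_grad=True)
vq_indices = torch.randint(0, num_nodes, (positions,))

per_position = graph_gather_add_reference(vq_output, node_features, vq_indices)
grad_out = torch.randn_like(per_position)
per_position.backward(grad_out)

grad_vq_output, grad_node_features = graph_gather_add_backward_reference(
    grad_out, vq_indices, num_nodes
)
grad_vq_maxdiff = (vq_output.grad - grad_vq_output).abs().max().item()
grad_node_maxdiff = (node_features.grad - grad_node_features).abs().max().item()
print(grad_vq_maxdiff, grad_node_maxdiff)
```

With more positions than nodes, several positions map to the same node row, which is why the node gradient needs an accumulating write (`index_add_` here, atomic adds in a GPU kernel) rather than a plain indexed store.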

Correctness check against PyTorch reference:

graph_fwd_maxdiff:       0.0
graph_grad_vq_maxdiff:   0.0
graph_grad_node_maxdiff: 4.77e-07

The Graph path now has two Triton-backed custom stages:

  1. Edge ternary weighting + target aggregation.
  2. VQ-index node gather + residual add.

The remaining unfused Graph work is the full hop body: packed ternary message projection, aggregation, update projection, residual, and hop-LoRA composition in one larger kernel.

Verification

  • python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py
  • python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path": 2 passed
  • Isolated MoE combine Triton kernel forward/backward matched PyTorch reference.
  • Isolated Graph gather/add Triton kernel forward/backward matched PyTorch reference.
  • Full CUDA model smoke test with VQ, Graph, Memory, and MoE enabled passed forward, backward, and _ternary_update_memory().
  • Strict one-step training smoke test stayed zero-float:
    trainable float params: 0
    frozen float params: 0
    float buffers: 0
    training step: about 1.92 it/s before eval/checkpoint overhead
    final val loss: 7.7302

Remaining Kernel Work

  1. Fuse the dense MoE expert body itself:

    • packed ternary W_gate
    • packed ternary W_transform
    • shared hidden multiply
    • packed ternary shared_down
    • selected route accumulation
  2. Fuse the full Graph hop:

    • packed ternary source projection
    • ternary edge aggregation
    • packed ternary update projection
    • residual + hop-LoRA
  3. Add benchmark counters that separate:

    • first-call Triton compile time
    • steady-state forward/backward time
    • ternary update time
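One possible shape for such counters, sketched with the standard library only. The helper name `time_first_and_steady` and its `sync` parameter are hypothetical; for CUDA work a synchronization callable such as `torch.cuda.synchronize` would need to be passed so that queued kernels are included in the measurement. The stand-in `fake_kernel_call` merely imitates an op that is slow on its first call.

```python
import time


def time_first_and_steady(fn, steady_iters=10, sync=None):
    """Return (first_call_seconds, mean_steady_seconds) for repeated calls to fn.

    sync: optional callable run before each clock read, e.g. a device
    synchronize, so asynchronous work is counted (assumption for GPU use).
    """
    def timed_call():
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        return time.perf_counter() - start

    first = timed_call()  # includes any one-time compile or autotune cost
    steady = [timed_call() for _ in range(steady_iters)]
    return first, sum(steady) / len(steady)


_compiled = {}


def fake_kernel_call():
    # Stand-in for a JIT-compiled op: slow on the first call, fast afterwards.
    if "kernel" not in _compiled:
        time.sleep(0.05)
        _compiled["kernel"] = True
    return sum(range(1000))


first_s, steady_s = time_first_and_steady(fake_kernel_call)
print(f"first call: {first_s * 1e3:.2f} ms, steady state: {steady_s * 1e3:.3f} ms")
```

Calling the helper separately on the forward/backward step and on the ternary update would give the three counters listed above without mixing one-time compile cost into the steady-state numbers.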