# True Ternary Refactor 8 — MoE/Graph Triton Kernel Phase

## Scope

This phase keeps the focus on MoE and Graph kernels. It does not change the strict ternary architecture target from REFACTOR6/7.

## MoE Kernel Work

Added a Triton-backed dense MoE combine kernel for the fixed-shape small-batch path.

Previous dense path:

```text
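for expert in range(num_experts):
    # per-token weight: sum of the top-k weights routed to this expert (0 if none)
    expert_weight = torch.where(topk_idx == expert, topk_weights, 0).sum(dim=-1, keepdim=True)
    routed_out += expert_weight * expert_out  # expert_out: this expert's [tokens, dim] output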
```

New dense path:

```text
# expert_out_stack shape: [num_experts, tokens, dim]
routed_out = triton_moe_dense_combine(expert_out_stack, topk_idx, topk_weights)
```
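To pin down what the combine computes, here is a plain-Python reference sketch (not the Triton kernel). It assumes `expert_out_stack[e][t][d]`, `topk_idx[t][k]`, and `topk_weights[t][k]` indexing, and the function names are hypothetical. It checks that the per-expert masked loop of the previous path and a gathered per-token form, which touches only the K selected expert rows per token, give the same output.

```python
from typing import List

Matrix = List[List[float]]  # [tokens][dim] or [tokens][k]
Stack = List[Matrix]        # [num_experts][tokens][dim]


def combine_loop_over_experts(expert_out_stack: Stack, topk_idx: List[List[int]],
                              topk_weights: Matrix) -> Matrix:
    """Previous dense path: one masked, weighted add per expert."""
    tokens, dim = len(expert_out_stack[0]), len(expert_out_stack[0][0])
    routed_out = [[0.0] * dim for _ in range(tokens)]
    for expert, expert_out in enumerate(expert_out_stack):
        for t in range(tokens):
            # Per-token weight: sum of top-k weights routed to this expert (0 if none).
            w = sum(wk for ik, wk in zip(topk_idx[t], topk_weights[t]) if ik == expert)
            for d in range(dim):
                routed_out[t][d] += w * expert_out[t][d]
    return routed_out


def combine_gather(expert_out_stack: Stack, topk_idx: List[List[int]],
                   topk_weights: Matrix) -> Matrix:
    """Gathered form: each token reads only its K selected expert rows."""
    tokens, dim = len(expert_out_stack[0]), len(expert_out_stack[0][0])
    routed_out = [[0.0] * dim for _ in range(tokens)]
    for t in range(tokens):
        for expert, w in zip(topk_idx[t], topk_weights[t]):
            for d in range(dim):
                routed_out[t][d] += w * expert_out_stack[expert][t][d]
    return routed_out


# 3 experts, 2 tokens, dim 2, top-2 routing.
expert_out_stack = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
    [[9.0, 10.0], [11.0, 12.0]],
]
topk_idx = [[0, 2], [1, 2]]
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
print(combine_loop_over_experts(expert_out_stack, topk_idx, topk_weights))  # [[3.0, 4.0], [9.0, 10.0]]
print(combine_gather(expert_out_stack, topk_idx, topk_weights))             # [[3.0, 4.0], [9.0, 10.0]]
```

The gathered form does K weighted adds per token instead of one masked pass per expert, which is one natural way to express the whole combine as a single launch over tokens.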

The kernel accumulates each token's selected expert outputs, scaled by their routing weights, in a single Triton launch. The backward pass is also Triton-backed:

- `grad_expert_out`: `grad_out` scaled by each selected route weight, accumulated into the chosen expert's slot with atomic adds.
- `grad_topk_weights`: dot product of `grad_out` with the selected expert's output.
- `topk_idx`: integer routing indices; no gradient.
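The backward rules in the list above can be written out the same way. This is a hedged plain-Python sketch with the same assumed indexing; `combine_backward` is a hypothetical name, and the `+=` into `grad_expert_out` marks where a parallel kernel would need an atomic add if several programs can target the same slot.

```python
from typing import List, Tuple

Matrix = List[List[float]]
Stack = List[Matrix]  # [num_experts][tokens][dim]


def combine_backward(expert_out_stack: Stack, topk_idx: List[List[int]],
                     topk_weights: Matrix, grad_out: Matrix) -> Tuple[Stack, Matrix]:
    """Gradients of routed_out[t] = sum_k w[t][k] * expert_out_stack[idx[t][k]][t]."""
    num_experts = len(expert_out_stack)
    tokens, dim = len(grad_out), len(grad_out[0])
    grad_expert_out = [[[0.0] * dim for _ in range(tokens)] for _ in range(num_experts)]
    grad_topk_weights = [[0.0] * len(row) for row in topk_idx]
    for t in range(tokens):
        for k, (expert, w) in enumerate(zip(topk_idx[t], topk_weights[t])):
            for d in range(dim):
                # grad_out scaled by the route weight; a parallel kernel
                # would make this += an atomic add.
                grad_expert_out[expert][t][d] += w * grad_out[t][d]
            # Dot product of grad_out with the selected expert's output row.
            grad_topk_weights[t][k] = sum(
                g * x for g, x in zip(grad_out[t], expert_out_stack[expert][t]))
    # topk_idx holds integer routing indices, so no gradient is produced for it.
    return grad_expert_out, grad_topk_weights


expert_out_stack = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[5.0, 6.0], [7.0, 8.0]],
    [[9.0, 10.0], [11.0, 12.0]],
]
topk_idx = [[0, 2], [1, 2]]
topk_weights = [[0.75, 0.25], [0.5, 0.5]]
grad_out = [[1.0, 0.0], [0.0, 2.0]]
g_expert, g_weights = combine_backward(expert_out_stack, topk_idx, topk_weights, grad_out)
print(g_weights)  # [[1.0, 9.0], [16.0, 24.0]]
```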

Correctness check against PyTorch reference:

```text
moe_fwd_maxdiff:          2.38e-07
moe_grad_expert_maxdiff:  0.0
moe_grad_weight_maxdiff:  9.54e-07
```

This does not yet fuse the expert ternary projections themselves. It fuses only the dense MoE routing-combine stage and keeps the fixed-shape dispatch from REFACTOR7; fixed shapes are the practical route to reducing recompilation churn without destabilizing training.

## Graph Kernel Work

Added a Triton-backed Graph gather/add kernel that fuses the VQ-index node-feature lookup with the residual add.

Previous graph position path:

```text
graph_features = node_features[vq_indices]
per_position = vq_output + graph_features
```

New path:

```text
per_position = triton_graph_gather_add(vq_output, node_features, vq_indices)
```

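The backward pass is also Triton-backed:

- `grad_vq_output`: direct copy of `grad_out` (residual path).
- `grad_node_features`: `grad_out` rows scatter-added into the rows selected by `vq_indices`, using atomic adds because several positions can map to the same node.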
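A plain-Python sketch of the gather/add forward and its backward, assuming `vq_output[p][d]`, `node_features[n][d]`, and integer `vq_indices[p]`; the function names are hypothetical. The repeated index in the example shows why the node-feature gradient is a scatter-add rather than a copy.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def graph_gather_add(vq_output: Matrix, node_features: Matrix,
                     vq_indices: List[int]) -> Matrix:
    """per_position[p] = vq_output[p] + node_features[vq_indices[p]]."""
    return [[v + n for v, n in zip(vq_output[p], node_features[i])]
            for p, i in enumerate(vq_indices)]


def graph_gather_add_backward(grad_out: Matrix, num_nodes: int,
                              vq_indices: List[int]) -> Tuple[Matrix, Matrix]:
    dim = len(grad_out[0])
    # Residual branch: the gradient passes straight through.
    grad_vq_output = [row[:] for row in grad_out]
    # Gather branch: scatter-add each position's gradient into its node row.
    grad_node_features = [[0.0] * dim for _ in range(num_nodes)]
    for p, i in enumerate(vq_indices):
        for d in range(dim):
            grad_node_features[i][d] += grad_out[p][d]  # atomic in a parallel kernel
    return grad_vq_output, grad_node_features


vq_output = [[0.5, 1.0], [1.5, 2.0], [2.5, 3.0]]
node_features = [[10.0, 20.0], [30.0, 40.0]]
vq_indices = [1, 0, 1]  # positions 0 and 2 share node 1
print(graph_gather_add(vq_output, node_features, vq_indices))
# [[30.5, 41.0], [11.5, 22.0], [32.5, 43.0]]
grad_out = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]
grad_vq, grad_nodes = graph_gather_add_backward(grad_out, len(node_features), vq_indices)
print(grad_nodes)  # [[2.0, 2.0], [4.0, 4.0]]
```

With atomics the summation order for a shared node is not fixed, so float results can differ from a sequential reference in the last bits; that is one plausible source of the small nonzero `graph_grad_node_maxdiff`.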

Correctness check against PyTorch reference:

```text
graph_fwd_maxdiff:       0.0
graph_grad_vq_maxdiff:   0.0
graph_grad_node_maxdiff: 4.77e-07
```

The Graph path now has two Triton-backed custom stages:

1. Edge ternary weighting + target aggregation.
2. VQ-index node gather + residual add.
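Stage 1 is described here only by name. Assuming it means each edge carries a weight in {-1, 0, +1} and weighted source messages are summed at the target node, a plain-Python sketch looks like this; the edge-list layout and the function name are hypothetical.

```python
from typing import List

Matrix = List[List[float]]


def ternary_edge_aggregate(messages: Matrix, edge_src: List[int], edge_dst: List[int],
                           edge_weight: List[int], num_nodes: int) -> Matrix:
    """agg[dst] += w * messages[src] for every edge, with w in {-1, 0, +1}."""
    dim = len(messages[0])
    agg = [[0.0] * dim for _ in range(num_nodes)]
    for src, dst, w in zip(edge_src, edge_dst, edge_weight):
        if w not in (-1, 0, 1):
            raise ValueError(f"edge weight must be ternary, got {w}")
        if w == 0:
            continue  # zero-weight edge contributes nothing
        for d in range(dim):
            # A ternary weight turns the multiply into an add or a subtract.
            agg[dst][d] += messages[src][d] if w == 1 else -messages[src][d]
    return agg


messages = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
# Edges: 0->2 (+1), 1->2 (-1), 2->0 (+1), 1->0 (0)
print(ternary_edge_aggregate(messages, [0, 1, 2, 1], [2, 2, 0, 0], [1, -1, 1, 0], 3))
# [[5.0, 6.0], [0.0, 0.0], [-2.0, -2.0]]
```

Because the weights are ternary, each edge reduces to an add, a subtract, or a skip, so zero-weight edges cost no arithmetic at all.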

The remaining unfused Graph work is the full hop body: packed ternary message projection, aggregation, update projection, residual, and hop-LoRA composition in one larger kernel.
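The hop body starts with a packed ternary projection. The packing format used in this codebase is not described here, so the sketch below assumes a hypothetical layout: 2-bit codes, four weights per byte, lowest bit pair first, with `00` = 0, `01` = +1, `10` = -1. It shows packing, unpacking, and a matrix-vector product that decodes weights on the fly; all names are illustrative.

```python
from typing import Dict, List

# Hypothetical 2-bit codes (the project's real packing layout may differ).
ENCODE: Dict[int, int] = {0: 0b00, 1: 0b01, -1: 0b10}
DECODE: Dict[int, int] = {code: w for w, code in ENCODE.items()}


def pack_ternary(weights: List[int]) -> bytes:
    """Pack ternary weights four per byte, lowest bit pair first."""
    out = bytearray((len(weights) + 3) // 4)
    for i, w in enumerate(weights):
        out[i // 4] |= ENCODE[w] << (2 * (i % 4))
    return bytes(out)


def unpack_ternary(packed: bytes, n: int) -> List[int]:
    """Inverse of pack_ternary for the first n weights."""
    return [DECODE[(packed[i // 4] >> (2 * (i % 4))) & 0b11] for i in range(n)]


def packed_ternary_matvec(packed_rows: List[bytes], in_dim: int,
                          x: List[float]) -> List[float]:
    """y[o] = sum_i W[o][i] * x[i], decoding each packed row on the fly."""
    y = []
    for row in packed_rows:
        acc = 0.0
        for w, xi in zip(unpack_ternary(row, in_dim), x):
            # +1 adds, -1 subtracts, 0 is skipped: no multiplies needed.
            if w == 1:
                acc += xi
            elif w == -1:
                acc -= xi
        y.append(acc)
    return y


W = [[1, 0, -1, 1, -1], [0, 1, 1, -1, 0]]
packed = [pack_ternary(row) for row in W]
print(packed[0])  # b'a\x02'
print(packed_ternary_matvec(packed, 5, [1.0, 2.0, 3.0, 4.0, 5.0]))  # [-3.0, 1.0]
```

A fused hop kernel would presumably do this decode in the same kernel as aggregation and the update projection, instead of materializing the projected messages in between.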

## Verification

- `python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path"`: `2 passed`
- Isolated MoE combine Triton kernel forward/backward matched PyTorch reference.
- Isolated Graph gather/add Triton kernel forward/backward matched PyTorch reference.
- Full CUDA model smoke with VQ, Graph, Memory, and MoE enabled passed forward, backward, and `_ternary_update_memory()`.
- Strict one-step train smoke stayed zero-float:

```text
trainable float params: 0
frozen float params: 0
float buffers: 0
training step: about 1.92 it/s before eval/checkpoint overhead
final val loss: 7.7302
```

## Remaining Kernel Work

1. Fuse the dense MoE expert body itself:
   - packed ternary `W_gate`
   - packed ternary `W_transform`
   - shared hidden multiply
   - packed ternary `shared_down`
   - selected route accumulation

2. Fuse the full Graph hop:
   - packed ternary source projection
   - ternary edge aggregation
   - packed ternary update projection
   - residual + hop-LoRA

3. Add benchmark counters that separate:
   - first-call Triton compile time
   - steady-state forward/backward time
   - ternary update time
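For item 3, a minimal timing helper, assuming the step under test is wrapped in a zero-argument callable. The helper and its parameters are hypothetical, not an existing benchmark API. The first call absorbs one-time costs such as Triton JIT compilation; the remaining calls are averaged as steady state. CUDA launches are asynchronous, so on GPU pass `torch.cuda.synchronize` as `sync`.

```python
import time
from typing import Callable, Optional, Tuple


def time_first_and_steady(step: Callable[[], None], iters: int = 10,
                          sync: Optional[Callable[[], None]] = None) -> Tuple[float, float]:
    """Return (first_call_seconds, mean_steady_state_seconds) for `step`.

    The first call absorbs one-time costs such as JIT compilation; the next
    `iters` calls are averaged as steady state.
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    def timed_call() -> float:
        if sync is not None:
            sync()  # drain work queued before the timer starts
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # wait for asynchronous work launched by step()
        return time.perf_counter() - start

    first = timed_call()
    steady = [timed_call() for _ in range(iters)]
    return first, sum(steady) / len(steady)


# Stand-in step: slow on the first call only, like a compile-on-first-use kernel.
calls: list = []


def fake_step() -> None:
    calls.append(None)
    if len(calls) == 1:
        time.sleep(0.05)


first, steady = time_first_and_steady(fake_step, iters=5)
print(f"first call: {first:.3f} s, steady state: {steady * 1e6:.1f} us")
```

Forward/backward and the ternary update could each be wrapped in their own callable and timed separately to produce the three counters.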