# True Ternary Refactor 7 — MoE/Graph Kernel Hardening
## Scope
This phase prioritizes MoE and Graph runtime hardening now that REFACTOR6 has confirmed that the internal model can run in strict mode with zero trainable float parameters and zero float buffers.
The agent's REFACTOR6 readout is consistent with the code behavior:
- `E_accum` being active in most groups is a real signal that residual scale updates are now preserving small gradient evidence instead of dropping it.
- A 20-step loss curve is not enough to judge convergence for a random-initialized pure ternary model with sign flips and integer scale residuals.
- The reported step-time swings are plausibly a kernel-shape/recompile issue, especially in MoE sparse routing where each expert receives a different token count per step.
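
One way to read the `E_accum` bullet: if a single step's scale update is smaller than one integer exponent step, rounding it directly discards it, while a residual accumulator carries the remainder forward until it adds up to a whole step. The sketch below assumes an integer exponent `E` and fractional per-step evidence `u` (exact `Fraction` values avoid float noise). The function names are illustrative only and are not the project's API.

```python
from fractions import Fraction


def naive_exponent_updates(updates):
    """Round each update to a whole exponent step; sub-step evidence is discarded."""
    E = 0
    for u in updates:
        E += round(u)  # round(Fraction(3, 10)) == 0, so small updates vanish
    return E


def residual_exponent_updates(updates):
    """Carry the unapplied remainder in an E_accum-style residual (assumed semantics)."""
    E = 0
    E_accum = Fraction(0)
    for u in updates:
        E_accum += u
        step = int(E_accum)  # truncate toward zero: apply only whole exponent steps
        E += step
        E_accum -= step      # keep the fractional remainder for later steps
    return E, E_accum


updates = [Fraction(3, 10)] * 10
print(naive_exponent_updates(updates))     # 0
print(residual_exponent_updates(updates))  # (3, Fraction(0, 1))
```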
## MoE Dispatch Hardening
Added a fixed-shape dense dispatch path for small token counts in `SharedProjectionMoE`.
Previous small-batch sparse path:
```text
top-k slot loop -> sort tokens -> expert loop -> variable [n_tokens, D] ternary kernels
```
This is compute-efficient for large batches, but it produces unstable Triton kernel shapes because `n_tokens` changes per expert from step to step. In strict training with small context/batch sizes, this can trigger repeated compilation and slow outlier steps.
New small-batch path:
```text
if N <= dense_dispatch_max_tokens:
    run every expert on fixed [N, D] inputs
    multiply each expert output by its top-k routing weight, zero for unselected tokens
else:
    keep sparse dynamic expert dispatch
```
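
A pure-Python sketch of why the two paths compute the same output. The dense path runs every expert on all `N` tokens and multiplies by a per-token weight that is zero for tokens not routed to that expert. The sparse path gathers only the routed tokens per expert and scatters the weighted results back. The toy experts, the `(expert, weight)` route lists, and the helper names are assumptions for illustration; none of them come from `SharedProjectionMoE`.

```python
def dense_dispatch(x, experts, routes):
    """Fixed-shape path: every expert sees all N tokens; unselected tokens get weight 0."""
    N, D = len(x), len(x[0])
    out = [[0.0] * D for _ in range(N)]
    for e, expert in enumerate(experts):
        w = [dict(routes[t]).get(e, 0.0) for t in range(N)]  # 0.0 if token t skipped e
        y = [expert(row) for row in x]  # always N rows -> stable shape
        for t in range(N):
            for d in range(D):
                out[t][d] += w[t] * y[t][d]
    return out


def sparse_dispatch(x, experts, routes):
    """Dynamic-shape path: each expert sees only the tokens routed to it."""
    N, D = len(x), len(x[0])
    out = [[0.0] * D for _ in range(N)]
    for e, expert in enumerate(experts):
        picked = [(t, w) for t in range(N) for ei, w in routes[t] if ei == e]
        y = [expert(x[t]) for t, _ in picked]  # row count varies per expert
        for (t, w), row in zip(picked, y):
            for d in range(D):
                out[t][d] += w * row[d]
    return out


def moe_dispatch(x, experts, routes, dense_dispatch_max_tokens=128):
    """Choose the path by token count, mirroring the cutoff described above."""
    if len(x) <= dense_dispatch_max_tokens:
        return dense_dispatch(x, experts, routes)
    return sparse_dispatch(x, experts, routes)


# Toy setup: 3 tokens, D = 2, 3 experts, top-2 routing as (expert, weight) pairs.
experts = [lambda v: [2 * a for a in v], lambda v: [a + 1 for a in v], lambda v: [-a for a in v]]
x = [[1.0, 2.0], [3.0, -1.0], [0.5, 0.5]]
routes = [[(0, 0.75), (1, 0.25)], [(2, 0.7), (0, 0.3)], [(1, 0.6), (0, 0.4)]]
print(moe_dispatch(x, experts, routes)[0])  # [2.0, 3.75]
```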
Default cutoff:
```text
dense_dispatch_max_tokens = 128
```
This keeps the sparse path available for larger runs while making strict-mode smoke and small-batch training use stable shapes.
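
To make the shape-churn argument concrete, the sketch below simulates random top-k routing over many steps and counts how many distinct per-expert input shapes each path would launch. It assumes uniform random routing, 8 experts, and top-2 selection, none of which are stated above. Whether each new shape actually triggers a recompile depends on how the Triton kernels are specialized and autotuned, so treat the count as an upper bound on shape churn rather than a compile count.

```python
import random


def count_launch_shapes(N, num_experts, top_k, steps, dense, D=512, seed=0):
    """Count the distinct [rows, D] expert-kernel input shapes seen over `steps` routing steps."""
    rng = random.Random(seed)
    shapes = set()
    for _ in range(steps):
        counts = [0] * num_experts
        for _ in range(N):
            # Each token picks top_k distinct experts (uniform random routing, an assumption).
            for e in rng.sample(range(num_experts), top_k):
                counts[e] += 1
        for c in counts:
            if dense:
                shapes.add((N, D))  # dense path: every expert always gets the full [N, D] block
            elif c > 0:
                shapes.add((c, D))  # sparse path: rows = tokens routed to this expert this step
    return len(shapes)


# N = B * L = 2 * 8 = 16 tokens; 8 experts and top-2 routing are assumed values.
print(count_launch_shapes(16, num_experts=8, top_k=2, steps=100, dense=True))   # 1
print(count_launch_shapes(16, num_experts=8, top_k=2, steps=100, dense=False))
```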
Small CUDA timing on `B=2, L=8, D=512`:
```text
dense_fixed_shape: first call 0.3286s, then 0.0047, 0.0037, 0.0037, 0.0036
sparse_dynamic_shape: first call 0.0415s, then 0.0077, 0.0088, 0.0076, 0.0080
```
The first call still includes compilation/setup, and it is slower on the dense path. In steady state, however, the fixed-shape path is roughly 2x faster for this strict small-batch shape (about 3.9 ms vs 8.0 ms per call) and avoids per-expert token-count shape churn.
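
The timings above reduce to two numbers: the steady-state speedup, and how many steps it takes for the dense path's larger first-call cost to be paid back. This treats the first call as a one-time cost, which only holds if `N` stays fixed across steps. With only four steady-state samples per path, both figures are rough.

```python
# Reported timings in seconds, copied from the run above.
dense_first, dense_steady = 0.3286, [0.0047, 0.0037, 0.0037, 0.0036]
sparse_first, sparse_steady = 0.0415, [0.0077, 0.0088, 0.0076, 0.0080]

dense_mean = sum(dense_steady) / len(dense_steady)     # ~0.0039 s
sparse_mean = sum(sparse_steady) / len(sparse_steady)  # ~0.0080 s
speedup = sparse_mean / dense_mean                      # ~2.04x

# Steps before the dense path's extra first-call cost is recovered by faster steps.
break_even_steps = (dense_first - sparse_first) / (sparse_mean - dense_mean)  # ~70

print(f"steady-state speedup {speedup:.2f}x, break-even after ~{break_even_steps:.0f} steps")
```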
## Graph Hardening
REFACTOR6 already added a Triton Graph aggregation kernel:
```text
projected messages + ternary edge weighting + target aggregation
```
This replaces the previous CUDA path, which materialized `messages` and then called `scatter_add_`. The Graph path still runs a Python hop loop and still calls the ternary projections once per hop. A full graph-hop fusion is left for a separate kernel because it needs packed ternary decode, edge aggregation, update projection, residual, and hop LoRA scheduling in a single launch.
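
The difference between the old and new CUDA paths can be sketched in plain Python. The old-style path materializes one message row per edge and then scatter-adds the rows into their targets. The fused-style path accumulates each ternary-weighted source projection straight into its target without storing a `messages` buffer. The edge-list layout `(src, dst, weight)`, the precomputed node projections `p`, and the helper names are assumptions for illustration, not the Triton kernel's interface.

```python
def scatter_add_rows(out, index, src):
    """Row-wise equivalent of out.scatter_add_(0, index, src): out[index[i]] += src[i]."""
    for i, row in zip(index, src):
        for d, v in enumerate(row):
            out[i][d] += v
    return out


def aggregate_materialized(p, edges, num_nodes):
    """Old-style path: build an [E, D] messages buffer, then scatter-add into targets."""
    D = len(p[0])
    messages = [[w * v for v in p[src]] for src, _, w in edges]
    targets = [dst for _, dst, _ in edges]
    return scatter_add_rows([[0.0] * D for _ in range(num_nodes)], targets, messages)


def aggregate_fused(p, edges, num_nodes):
    """Fused-style path: accumulate weighted projections directly, with no messages buffer."""
    D = len(p[0])
    out = [[0.0] * D for _ in range(num_nodes)]
    for src, dst, w in edges:
        if w == 0:
            continue  # zero-weight edges contribute nothing
        sign = 1.0 if w > 0 else -1.0  # a ternary weight only adds or subtracts
        for d, v in enumerate(p[src]):
            out[dst][d] += sign * v
    return out


# p: already-projected node features; edges: (src, dst, ternary weight in {-1, 0, +1}).
p = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
edges = [(0, 1, 1), (1, 2, -1), (2, 1, 1), (2, 0, 0)]
print(aggregate_fused(p, edges, num_nodes=3))  # [[0.0, 0.0], [2.0, 1.0], [0.0, -2.0]]
```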
## Learning Hardening
Fixed `ByteEmbedding.ternary_step()` to honor `_t_accum_step`.
Before this fix, `TernaryScaleTensor` and `TernaryEmbeddingTable` consumed the loss-scaled integer step, but `ByteEmbedding` still did:
```text
T_accum += sign(grad)
```
Now it does:
```text
T_accum += sign(grad) * t_accum_step
```
This matters because text byte embeddings are part of the strict model's trainable ternary state. Without this, the embedding trits could lag behind the rest of the model and require more updates to reach threshold.
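
A small simulation of why the missing factor matters. With a constant gradient sign, an accumulator that adds only `sign(grad)` needs `threshold` updates before the trit flips. One that adds `sign(grad) * t_accum_step` needs only `ceil(threshold / t_accum_step)` updates. The threshold value, the flip direction, the clamp to `{-1, 0, +1}`, and the reset-on-flip rule are all assumptions for illustration. The real `ByteEmbedding.ternary_step()` may differ in any of them.

```python
def ternary_step(trit, T_accum, grad_sign, threshold, t_accum_step=1):
    """One hypothetical trit update; flip direction and reset are assumed, not taken from the code."""
    T_accum += grad_sign * t_accum_step  # old ByteEmbedding behavior == t_accum_step of 1
    if abs(T_accum) >= threshold:
        # Move one level against the accumulated gradient, clamped to {-1, 0, +1}.
        trit = max(-1, min(1, trit - (1 if T_accum > 0 else -1)))
        T_accum = 0  # assumed reset after a flip
    return trit, T_accum


def steps_to_first_flip(threshold, t_accum_step, grad_sign=1, max_steps=10_000):
    """Count updates with a constant gradient sign until the trit first changes."""
    trit, T_accum = 0, 0
    for step in range(1, max_steps + 1):
        new_trit, T_accum = ternary_step(trit, T_accum, grad_sign, threshold, t_accum_step)
        if new_trit != trit:
            return step
    return None


print(steps_to_first_flip(threshold=16, t_accum_step=1))  # 16 (sign(grad) only)
print(steps_to_first_flip(threshold=16, t_accum_step=4))  # 4 (loss-scaled step)
```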
## Verification
- `python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py`
- `python -m pytest -q testing/test_tscale.py -k "cuda_triton_correctness_update_E or cuda_triton_tscale_path"`: `2 passed`
- `moe_dense_dispatch_cuda_ok`: forward/backward passed.
- `byte_embedding_t_step_ok`: embedding accumulator consumed the loss-scaled step and flipped at threshold.
- Full CUDA model smoke with VQ, Graph, Memory, and MoE enabled passed forward, backward, and `_ternary_update_memory()`.
- Strict one-step train smoke remained zero-float and ran the training step at about `1.91 it/s` before eval/checkpoint overhead.
## Remaining Kernel Work
1. Full fused MoE expert kernel:
   - packed ternary expert decode
   - top-k route weights
   - per-expert low-rank projection
   - shared hidden multiply
   - weighted output accumulation
2. Full fused Graph hop kernel:
   - packed ternary projection decode
   - edge aggregation
   - update projection
   - residual/hop-LoRA composition
3. Scale precision extension:
   - Currently `S = 2^E`.
   - Exact values such as `S = 99.9` need a low-overhead integer/fixed-point mantissa or a residual scale lattice.
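
To see why `S = 2^E` cannot get close to a value like `99.9`, and what a fixed-point mantissa would buy, the sketch below compares the nearest power of two with the best `S = m * 2^E` for an unsigned integer mantissa `m` of a given bit width. This is one possible lattice chosen for illustration. The mantissa widths and exponent range are assumptions, not a settled design. Because `99.9` is not a dyadic rational, any binary mantissa only approximates it. An exact value would need a decimal or other non-power-of-two lattice.

```python
import math


def nearest_pow2(target):
    """Best S = 2^E with integer E, rounding in log2 space."""
    E = round(math.log2(target))
    return 2.0 ** E, E


def nearest_mantissa_scale(target, bits, exp_range=range(-16, 17)):
    """Best S = m * 2^E with integer m in [1, 2^bits - 1] and integer E in exp_range."""
    best = None
    for E in exp_range:
        m = round(target / 2.0 ** E)
        if not 1 <= m < 2 ** bits:
            continue  # mantissa does not fit in the assumed width
        S = m * 2.0 ** E
        if best is None or abs(S - target) < abs(best[0] - target):
            best = (S, m, E)
    return best


for label, S in [("2^E", nearest_pow2(99.9)[0]),
                 ("8-bit m * 2^E", nearest_mantissa_scale(99.9, bits=8)[0]),
                 ("12-bit m * 2^E", nearest_mantissa_scale(99.9, bits=12)[0])]:
    print(f"{label:>15}: S = {S:.5f}, error = {abs(S - 99.9):.5f}")
```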