# True Ternary Refactor 6 — Architecture Ternarization And Accumulator Hardening
## Scope
This pass moves the non-imported MORPH architecture toward persistent ternary state throughout. ViT and Whisper remain imported, frozen encoders, as requested. The internal trainable and storage components are now ternary or integer buffers rather than floating-point (FP) parameters.
## Architecture Ternarization
Converted internal float trainable components:
- `ImageSequencer.patch_proj`: `nn.Linear` -> `TernaryScaleTensor`
- `AudioSequencer.frame_proj`: `nn.Linear` -> `TernaryScaleTensor`
- `ModalityGate.weights`: float parameter -> int8 buffer
- `GNNLoRAAdapter.B`: float parameter -> `TernaryScaleTensor` up projection
- `GNNLoRAAdapter.scale`: `nn.Embedding` -> `TernaryEmbeddingTable`
- `MemGram.struct_emb` / `conv_emb`: float `ParameterList` -> `TernaryEmbeddingTable` modules
- `MemGram` strength/decay logits: float parameters -> int8 buffers
- `FocusGate.boundary_embed`: `nn.Embedding` -> `TernaryEmbeddingTable`
- `FocusGate.reset_fc` / `dampen_fc`: `nn.Linear` -> `TernaryScaleTensor`
- `ConversationLSTM.focus_cell` / `topic_cell`: `nn.LSTMCell` -> `TernaryLSTMCell`
- `ConversationLSTM.topic_gate_fc`: `nn.Linear` -> `TernaryScaleTensor`
- `GraphMoEGate.query`: float parameter -> `TernaryScaleTensor` query projection
- `TernaryGraph.edge_attr`: float parameter -> int8 ternary edge buffer
- `VQAdapter`: `FlashVQCodebook` float buffers -> `TernaryVQCodebook` with `TernaryEmbeddingTable`
- `ConvVQCodebook.embed`: float buffer -> `TernaryEmbeddingTable`
- `ConvVQCodebook` strength/decay logits: float parameters -> int8 buffers
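
To make the `nn.Linear` -> `TernaryScaleTensor` conversions above concrete, here is a minimal pure-Python sketch of a ternary projection. It assumes one int8 base-2 exponent per output row and, for turning an existing float row into trits, an absmean threshold. Neither the grouping nor the quantizer is taken from `TernaryScaleTensor`; `ternarize_row` and `ternary_linear` are hypothetical names.

```python
import math

# Hypothetical sketch: weights stored as int8 trits T in {-1, 0, +1} plus one int8
# base-2 exponent E per output row. Row-wise grouping and the absmean quantizer
# below are assumptions, not necessarily what TernaryScaleTensor does.


def ternarize_row(row):
    """Quantize one float row to (trits, exponent) with an absmean threshold (assumed)."""
    mean_abs = sum(abs(w) for w in row) / len(row)
    if mean_abs == 0.0:
        return [0] * len(row), 0
    exponent = round(math.log2(mean_abs))  # nearest power-of-two scale in log space
    threshold = 0.5 * mean_abs  # small weights collapse to 0
    trits = [0 if abs(w) < threshold else (1 if w > 0 else -1) for w in row]
    return trits, exponent


def ternary_linear(x, trits, exponents):
    """y[o] = 2**E[o] * sum_i T[o][i] * x[i]; no float weight matrix is ever stored."""
    out = []
    for row_t, e in zip(trits, exponents):
        acc = 0.0
        for t, xi in zip(row_t, x):
            if t:  # skip zeros; +1 / -1 only add or subtract the input
                acc += xi if t > 0 else -xi
        out.append(acc * 2.0 ** e)
    return out


t0, e0 = ternarize_row([0.9, -1.1, 0.05, 1.0])
t1, e1 = ternarize_row([4.0, -4.0, 0.0, 0.0])
print(ternary_linear([1.0, 2.0, 3.0, 4.0], [t0, t1], [e0, e1]))  # -> [3.0, -2.0]
```

The inner loop only adds, subtracts, or skips, and the scale is applied once per row, which is why the persistent state can stay integer-only.
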
New reusable modules:
- `TernaryEmbeddingTable`: packed ternary lookup table with int8 `E`, int8 `E_accum`, and int8 `T_accum`.
- `TernaryLSTMCell`: LSTM-style gate cell using one ternary projection over `[x, h]`.
- `TernaryVQCodebook`: VQ lookup against a ternary embedding table with integer cluster counters.
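
The `TernaryLSTMCell` description ("one ternary projection over `[x, h]`") can be sketched as follows. This is an illustration under assumptions, not the module's code: per-row int8 exponents, the PyTorch-style gate order (input, forget, candidate, output), and no bias term are all assumed, and the function names are hypothetical.

```python
import math

# Hypothetical sketch of a TernaryLSTMCell-style step: ONE ternary projection over the
# concatenation [x, h] produces all four gate pre-activations. Per-row int8 exponents,
# the gate order (input, forget, candidate, output) and the absence of a bias are assumptions.


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def ternary_matvec(v, trits, exponents):
    """Row o of the result is 2**E[o] * sum_i T[o][i] * v[i]."""
    return [2.0 ** e * sum(t * vi for t, vi in zip(row, v)) for row, e in zip(trits, exponents)]


def ternary_lstm_step(x, h, c, trits, exponents):
    """trits has 4*H rows, each of length len(x) + H; returns (h_new, c_new)."""
    hidden = len(h)
    z = ternary_matvec(x + h, trits, exponents)  # list concatenation builds [x, h]
    i = [sigmoid(v) for v in z[0:hidden]]
    f = [sigmoid(v) for v in z[hidden:2 * hidden]]
    g = [math.tanh(v) for v in z[2 * hidden:3 * hidden]]
    o = [sigmoid(v) for v in z[3 * hidden:4 * hidden]]
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(ck) for ok, ck in zip(o, c_new)]
    return h_new, c_new


# H = 1, one input feature: each row is [w_x, w_h] for the i, f, g, o gates.
trits = [[1, 0], [0, 1], [1, -1], [0, 0]]
h, c = ternary_lstm_step([1.0], [0.0], [0.5], trits, [0, 0, 0, 0])
```

Sharing one projection means a step needs a single ternary matmul over `[x, h]` rather than separate input and hidden projections as in `nn.LSTMCell`.
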
## Audit Results
Text/internal full architecture without ViT/Whisper:
```text
logical ternary weights: 23,887,936
ternary training state: 31.22 MB
trainable float params: 0 tensors, 0.00 MB
frozen float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
```
Full model with ViT and Whisper enabled:
```text
logical ternary weights: 25,560,128
ternary training state: 33.40 MB
trainable float params: 0 tensors, 0.00 MB
frozen float params: ViT/Whisper only
non-imported trainable float params: 0
non-imported float buffers: 0
```
## Loss / Accumulator Hardening
The previous strict update could fail to flip `T` because `T_accum` moved by only `1` per update, and many gradients changed sign before the accumulator reached the flip threshold of `3`.
Added loss-strength integer accumulator stepping:
```text
loss_signal -> t_step in {1, 2, 3, 4}
T_accum += sign(grad) * t_step
```
This keeps `T_accum` int8 but lets high-loss updates reach the threshold faster without adding float optimizer state. The Triton ternary-step kernels now accept a `T_ACCUM_STEP` argument, and `_ternary_update_memory()` sets `_t_accum_step` from the current loss on each update.
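
The stepping rule can be sketched in plain Python for a single trit. Only `T_accum += sign(grad) * t_step`, the int8 range, and the threshold `3` come from the text; the loss bucket edges in `loss_to_t_step`, the flip direction (one step against the gradient, clamped to `{-1, 0, +1}`), and resetting the accumulator to zero after a flip are assumptions, and both function names are hypothetical.

```python
INT8_MIN, INT8_MAX = -128, 127
FLIP_THRESHOLD = 3  # the strict-update threshold quoted above


def loss_to_t_step(loss, edges=(2.0, 4.0, 6.0)):
    """Map a scalar loss to an integer step in {1, 2, 3, 4} (bucket edges are illustrative)."""
    return 1 + sum(loss >= e for e in edges)


def sign(v):
    return (v > 0) - (v < 0)


def accumulate_and_flip(t, t_accum, grad, t_step):
    """Update one trit t and its int8 accumulator from a gradient sign.

    T_accum += sign(grad) * t_step, saturated to int8. When it reaches
    +/-FLIP_THRESHOLD, t moves one step against the gradient (descent),
    clamped to {-1, 0, +1}, and the accumulator resets to zero (assumed).
    """
    t_accum = max(INT8_MIN, min(INT8_MAX, t_accum + sign(grad) * t_step))
    if t_accum >= FLIP_THRESHOLD:
        t, t_accum = max(-1, t - 1), 0
    elif t_accum <= -FLIP_THRESHOLD:
        t, t_accum = min(1, t + 1), 0
    return t, t_accum


# Alternating gradient signs never reach the threshold with a step of 1 ...
t, acc = 0, 0
for grad in [0.3, -0.1, 0.4, -0.2]:
    t, acc = accumulate_and_flip(t, acc, grad, t_step=1)
print(t, acc)  # -> 0 0
# ... while a single high-loss step flips the trit immediately.
print(accumulate_and_flip(0, 0, 0.3, loss_to_t_step(8.2)))  # -> (-1, 0)
```

All state stays integer: the loss only selects an integer step size and never becomes a stored float.
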
## Scale Semantics
`E` remains an int8 base-2 (logarithmic) exponent, and the group scale `S = 2^E` remains derived rather than stored:
```text
W = T * 2^E
```
The effective weight values are therefore not limited to `{-1, 0, +1}`; for a group with scale `S` they are:
```text
{-S, 0, +S}
```
For example, a scale of `S = 99.9` would give effective group values `{-99.9, 0, +99.9}`. The current implementation, however, uses base-2 integer exponents, so `S` is restricted to powers of two; representing a non-power-of-two value such as `99.9` exactly would require either a mantissa/residual scale field or a different logarithmic base or lattice. The current approach keeps persistent state integer-only with low overhead.
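
A short worked example of the `99.9` case. Since `log2(99.9) ≈ 6.64`, the nearest base-2 scale is `2^7 = 128`. The `mantissa_scale` helper is a hypothetical illustration of one possible mantissa/residual lattice, `S = (1 + m / 2^k) * 2^E`, not an existing feature; with a 4-bit mantissa it lands on `100.0`.

```python
import math


def effective_values(t_values, exponent):
    """Decode trits with a base-2 exponent: W = T * 2**E, so each value is in {-S, 0, +S}."""
    s = 2.0 ** exponent
    return [t * s for t in t_values]


def nearest_pow2_exponent(target):
    """Closest integer E in log space; representable scales are ..., 32, 64, 128, ..."""
    return round(math.log2(target))


def mantissa_scale(target, mantissa_bits=4):
    """Hypothetical mantissa/residual field: S = (1 + m / 2**k) * 2**E with integer m."""
    e = math.floor(math.log2(target))
    m = round((target / 2.0 ** e - 1.0) * 2 ** mantissa_bits)
    if m == 2 ** mantissa_bits:  # rounding overflowed into the next octave
        e, m = e + 1, 0
    return e, m, (1.0 + m / 2 ** mantissa_bits) * 2.0 ** e


print(effective_values([-1, 0, 1], 7))  # -> [-128.0, 0.0, 128.0]
print(nearest_pow2_exponent(99.9))      # -> 7
print(mantissa_scale(99.9))             # -> (6, 9, 100.0)
```

Such a field would cost `k` extra bits per group and still keep the stored state integer-only.
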
## Kernel Status
The packed ternary linear, embedding, RMSNorm, `E` update, and `T_accum` update paths are Triton-backed. Graph edge weighting plus target aggregation is now also Triton-backed on CUDA, with a custom backward for projected message gradients.
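
As a reference for what the fused graph step computes, here is a pure-Python sketch of ternary edge weighting plus target aggregation. The parallel `src`/`dst`/trit edge lists and the single shared `edge_exp` exponent are assumptions; `ternary_edge_aggregate` is a hypothetical name, not the Triton kernel's interface.

```python
def ternary_edge_aggregate(x, src, dst, edge_trit, num_nodes, edge_exp=0):
    """out[d] += (T_e * 2**edge_exp) * x[s] for each edge e = (s -> d).

    An unfused version builds a per-edge `messages` list and then scatter-adds it;
    here each weighted message is added straight into its target row, which is the
    work a single fused launch can do without that intermediate buffer.
    """
    dim = len(x[0])
    scale = 2.0 ** edge_exp
    out = [[0.0] * dim for _ in range(num_nodes)]
    for s, d, t in zip(src, dst, edge_trit):
        if t == 0:
            continue  # zero-weight edges contribute nothing
        w = t * scale
        for k in range(dim):
            out[d][k] += w * x[s][k]
    return out


x = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(ternary_edge_aggregate(x, src=[0, 1, 2, 0], dst=[2, 2, 0, 1],
                             edge_trit=[1, -1, 0, 1], num_nodes=3, edge_exp=1))
# -> [[0.0, 0.0], [2.0, 4.0], [-4.0, -4.0]]
```

Because `out` is linear in `x`, a backward pass can walk the same edge list in reverse (`grad_x[s] += w * grad_out[d]`), which is one way a custom backward can avoid materializing per-edge messages.
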
MoE and Graph still contain Python-level control flow around multiple ternary kernels:
- MoE loops over top-k/expert routing and calls ternary projections per expert.
- Graph still loops over hops and calls GNN/update projections per hop, but no hop materializes `messages` or calls `scatter_add_` any more; ternary edge weighting and aggregation now run in a single Triton launch.
This pass does not collapse the full MoE or Graph computation into a single monolithic Triton kernel. Doing that correctly requires a dedicated packed-ternary fused expert dispatch kernel and a fused graph message-passing kernel that decode packed weights, route/scatter tokens, and update outputs inside one launch. The architecture is now sufficiently ternary that this kernel work can be the next, isolated performance phase.
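
For reference, the Python-level top-k routing that a fused dispatch kernel would absorb looks roughly like the sketch below. Experts are plain callables standing in for ternary projections, and normalizing gates with a softmax over the `k` selected logits is an assumption; `moe_topk_dispatch` is a hypothetical name.

```python
import math


def softmax(v):
    m = max(v)
    ex = [math.exp(a - m) for a in v]
    total = sum(ex)
    return [a / total for a in ex]


def moe_topk_dispatch(tokens, router_logits, experts, k=2):
    dim = len(tokens[0])
    out = [[0.0] * dim for _ in tokens]
    # 1) Route: pick the top-k experts per token and normalize their logits into gates.
    buckets = {}  # expert id -> [(token index, gate weight), ...]
    for n, logits in enumerate(router_logits):
        top = sorted(range(len(logits)), key=lambda j: logits[j], reverse=True)[:k]
        for j, gate in zip(top, softmax([logits[j] for j in top])):
            buckets.setdefault(j, []).append((n, gate))
    # 2) Dispatch: one batched call per expert (separate kernel launches in the model).
    for j, routed in buckets.items():
        ys = experts[j]([tokens[n] for n, _ in routed])
        # 3) Combine: scatter gate-weighted expert outputs back to their token rows.
        for (n, gate), y in zip(routed, ys):
            out[n] = [o + gate * yi for o, yi in zip(out[n], y)]
    return out


experts = [
    lambda batch: [[2.0 * v for v in t] for t in batch],  # expert 0: doubles
    lambda batch: [[-v for v in t] for t in batch],       # expert 1: negates
    lambda batch: [[0.0 for _ in t] for t in batch],      # expert 2: zeros
]
print(moe_topk_dispatch([[1.0, 2.0]], [[3.0, 3.0, -1.0]], experts, k=2))  # -> [[0.5, 1.0]]
```

A fused kernel would perform the bucketing, the per-expert ternary projections, and the gate-weighted scatter inside one launch instead of this per-expert loop.
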
## Verification
- `python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py`
- `PASS test_cuda_triton_correctness_update_E`
- `PASS test_cuda_triton_tscale_path`
- `PASS graph_aggregate_cuda_ok`
- Full text/internal audit: zero float params and zero float buffers.
- Strict train construction now passes `enable_audio=False` as well as `enable_image=False`, so strict mode no longer instantiates Whisper.
- Strict train-style audit with image/audio/VQ/graph/memory disabled and MoE enabled:
```text
logical ternary weights: 14,011,904
ternary training state: 18.27 MB
trainable float params: 0 tensors, 0.00 MB
frozen float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB
```
- The current strict-mode train smoke test (with audio disabled) ran 3 steps with zero float params/buffers; training loss went `8.2048 -> 9.7809 -> 7.7685`, and the final eval loss was `6.4239`.
- CUDA full-path smoke with VQ, graph, memory, and MoE enabled passed forward, backward, and `_ternary_update_memory()`.
## Remaining Work
1. Build fused MoE Triton dispatch kernel for top-k expert routing and expert projection scheduling.
2. Extend the Graph Triton aggregation kernel into a full fused message-passing/hop-update kernel.
3. Add component-specific ternary backward routing so LossComponents can update selected ternary module groups separately, not only through weighted total loss.
4. Consider a low-overhead mantissa/residual scale lattice if exact non-power-of-two scale values such as `99.9` become required.