
True Ternary Refactor 6 — Architecture Ternarization And Accumulator Hardening

Scope

This pass moves the non-imported MORPH architecture toward persistent ternary state throughout. ViT and Whisper remain imported frozen encoders, as requested. The internal trainable and storage components are now ternary or integer buffers rather than floating-point parameters.

Architecture Ternarization

Converted internal float trainable components:

  • ImageSequencer.patch_proj: nn.Linear -> TernaryScaleTensor
  • AudioSequencer.frame_proj: nn.Linear -> TernaryScaleTensor
  • ModalityGate.weights: float parameter -> int8 buffer
  • GNNLoRAAdapter.B: float parameter -> TernaryScaleTensor up projection
  • GNNLoRAAdapter.scale: nn.Embedding -> TernaryEmbeddingTable
  • MemGram.struct_emb / conv_emb: float ParameterList -> TernaryEmbeddingTable modules
  • MemGram strength/decay logits: float parameters -> int8 buffers
  • FocusGate.boundary_embed: nn.Embedding -> TernaryEmbeddingTable
  • FocusGate.reset_fc / dampen_fc: nn.Linear -> TernaryScaleTensor
  • ConversationLSTM.focus_cell / topic_cell: nn.LSTMCell -> TernaryLSTMCell
  • ConversationLSTM.topic_gate_fc: nn.Linear -> TernaryScaleTensor
  • GraphMoEGate.query: float parameter -> TernaryScaleTensor query projection
  • TernaryGraph.edge_attr: float parameter -> int8 ternary edge buffer
  • VQAdapter: FlashVQCodebook float buffers -> TernaryVQCodebook with TernaryEmbeddingTable
  • ConvVQCodebook.embed: float buffer -> TernaryEmbeddingTable
  • ConvVQCodebook strength/decay logits: float parameters -> int8 buffers

New reusable modules:

  • TernaryEmbeddingTable: packed ternary lookup table with int8 E, int8 E_accum, and int8 T_accum.
  • TernaryLSTMCell: LSTM-style gate cell using one ternary projection over [x, h].
  • TernaryVQCodebook: VQ lookup against a ternary embedding table with integer cluster counters.
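
As an illustration of the "one ternary projection over [x, h]" idea behind TernaryLSTMCell, here is a minimal pure-Python sketch. It is not the module's actual code: the function name `ternary_lstm_step`, the gate order (i, f, g, o), the row-major layout of `T`, and the use of a single exponent `E` for the whole matrix are all assumptions made for the example.

```python
import math


def sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


def ternary_lstm_step(x, h, c, T, E):
    """One LSTM-style step using a single ternary projection over [x, h].

    x, h, c : lists of floats (input, hidden state, cell state)
    T       : 4*H rows of len(x)+len(h) entries in {-1, 0, +1} (assumed layout)
    E       : integer exponent; effective weights are T * 2**E
              (one shared exponent is a simplification for this sketch)
    """
    H = len(h)
    xh = list(x) + list(h)            # concatenate [x, h]
    scale = 2.0 ** E                  # derived scale, never stored
    # Single projection producing all four gate pre-activations at once.
    z = [scale * sum(t * v for t, v in zip(row, xh)) for row in T]
    i = [sigmoid(v) for v in z[0:H]]            # input gate (assumed order)
    f = [sigmoid(v) for v in z[H:2 * H]]        # forget gate
    g = [math.tanh(v) for v in z[2 * H:3 * H]]  # candidate cell values
    o = [sigmoid(v) for v in z[3 * H:4 * H]]    # output gate
    c_new = [fk * ck + ik * gk for fk, ck, ik, gk in zip(f, c, i, g)]
    h_new = [ok * math.tanh(cn) for ok, cn in zip(o, c_new)]
    return h_new, c_new
```

With an all-zero `T`, every gate pre-activation is 0, so the sigmoid gates sit at 0.5 and the candidate is 0; the cell state is simply halved. That makes a convenient sanity check for the wiring.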

Audit Results

Text/internal full architecture without ViT/Whisper:

logical ternary weights: 23,887,936
ternary training state: 31.22 MB
trainable float params: 0 tensors, 0.00 MB
frozen float params: 0 tensors, 0.00 MB
float buffers: 0 tensors, 0.00 MB

Full model with ViT and Whisper enabled:

logical ternary weights: 25,560,128
ternary training state: 33.40 MB
trainable float params: 0 tensors, 0.00 MB
frozen float params: ViT/Whisper only
non-imported trainable float params: 0
non-imported float buffers: 0

Loss / Accumulator Hardening

The previous strict update could fail to flip T: T_accum moved by only 1 per update, and many gradients changed direction before the accumulator reached the flip threshold of 3.

Added loss-strength integer accumulator stepping:

loss_signal -> t_step in {1, 2, 3, 4}
T_accum += sign(grad) * t_step

This keeps T_accum int8, but lets high-loss updates reach threshold faster without adding float optimizer state. The Triton ternary-step kernels now accept T_ACCUM_STEP, and _ternary_update_memory() sets _t_accum_step per update from the current loss.
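
The stepping rule above can be sketched in plain Python. This is a hedged illustration, not the Triton kernel: the loss cutoffs in `loss_to_t_step`, the saturating clamp to the int8 range, and the helper names are assumptions. Only the step set {1, 2, 3, 4}, the update `T_accum += sign(grad) * t_step`, and the threshold of 3 come from the text. What happens to T once the threshold is crossed is left out because the text does not specify it.

```python
INT8_MIN, INT8_MAX = -128, 127
FLIP_THRESHOLD = 3  # threshold quoted in the text


def loss_to_t_step(loss, cutoffs=(2.0, 4.0, 8.0)):
    """Map a scalar loss to t_step in {1, 2, 3, 4}.

    The cutoff values are hypothetical; any monotone mapping would do.
    """
    return 1 + sum(loss >= c for c in cutoffs)


def sign(v):
    return (v > 0) - (v < 0)


def accumulate(t_accum, grad, t_step):
    """T_accum += sign(grad) * t_step, clamped to int8 (clamp is an assumption)."""
    return [
        max(INT8_MIN, min(INT8_MAX, a + sign(g) * t_step))
        for a, g in zip(t_accum, grad)
    ]


def crossed_threshold(t_accum):
    """Report which accumulators have reached the flip threshold."""
    return [abs(a) >= FLIP_THRESHOLD for a in t_accum]
```

For example, with gradient signs +, +, -, + and a fixed step of 1, an accumulator visits 1, 2, 1, 2 and never reaches 3, which is the failure mode described above. With `t_step = 3` the first update alone crosses the threshold.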

Scale Semantics

E remains an int8 logarithmic exponent and S remains derived, not stored:

W = T * 2^E

The effective weight values are not limited to {-1, 0, +1}. They are:

{-S, 0, +S}

So if the scale path represented S = 99.9, the effective group values would be {-99.9, 0, +99.9}. The current implementation uses base-2 integer exponent scales (S = 2^E), so representing non-power-of-two values like 99.9 exactly would require either a mantissa/residual scale field or a different logarithmic base/lattice. The current approach keeps persistent state integer-only and low overhead.
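
To make the power-of-two limitation concrete, the following sketch computes the effective weights for one group and finds the integer exponent closest to a target scale of 99.9. The helper names are hypothetical, and rounding in the log2 domain is just one plausible way to pick E; the source does not say how E is chosen.

```python
import math


def effective_weights(T, E):
    """W = T * 2**E for one group; T entries are in {-1, 0, +1}, E is an integer."""
    S = 2.0 ** E  # S is derived from E, not stored
    return [t * S for t in T]


def nearest_pow2_exponent(target):
    """Integer E whose 2**E is closest to `target` in the log2 domain (assumed rule)."""
    return round(math.log2(target))


E = nearest_pow2_exponent(99.9)   # log2(99.9) is about 6.64
S = 2.0 ** E
rel_err = abs(S - 99.9) / 99.9    # relative gap between 2**E and the target
W = effective_weights([-1, 0, 1], E)
```

Under this rounding rule the closest representable scale is 2^7 = 128, roughly 28% away from 99.9, which is the gap a mantissa or residual field would have to close.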

Kernel Status

The packed ternary linear, embedding, RMSNorm, E update, and T_accum update paths are Triton-backed. Graph edge weighting plus target aggregation is now also Triton-backed on CUDA, with a custom backward for projected message gradients.

MoE and Graph still contain Python-level control flow around multiple ternary kernels:

  • MoE loops over top-k/expert routing and calls ternary projections per expert.
  • Graph still loops over hops and calls GNN/update projections per hop, but each hop no longer materializes messages or calls scatter_add_; ternary edge weighting and aggregation now run in a single Triton launch.

This pass did not collapse the full MoE or Graph computation into a single monolithic Triton kernel. Doing that correctly requires a dedicated packed-ternary fused expert dispatch kernel and a fused graph message-passing kernel, each of which decodes packed weights, routes/scatters tokens, and updates outputs inside one launch. The architecture is now ternary enough for that kernel work to be the next isolated performance phase.
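
For reference, the per-hop edge weighting and target aggregation described above has simple semantics that can be written as a plain Python loop. This is a sketch of what the fused launch is understood to compute, not the kernel itself: the function name, the argument layout, and the assumption that `x` already holds projected per-node messages are illustrative, and any per-edge scale factor is omitted.

```python
def aggregate(x, src, dst, edge_t, num_nodes):
    """Reference semantics: out[dst[e]] += edge_t[e] * x[src[e]] for every edge e.

    x       : list of per-node feature vectors (assumed already projected)
    src/dst : source and target node index per edge
    edge_t  : ternary edge value per edge, in {-1, 0, +1}
    """
    dim = len(x[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for s, d, t in zip(src, dst, edge_t):
        if t == 0:
            continue  # zero edges contribute nothing
        for k in range(dim):
            # Weight and accumulate in place, with no intermediate message list.
            out[d][k] += t * x[s][k]
    return out
```

The point of fusing is visible even here: the weighted message `t * x[s]` is consumed as soon as it is formed, so no per-edge message tensor needs to exist before the scatter.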

Verification

  • python -m py_compile trigram.py tscale.py benchmark_true_ternary.py train.py ternary_audit.py testing/test_tscale.py
  • PASS test_cuda_triton_correctness_update_E
  • PASS test_cuda_triton_tscale_path
  • PASS graph_aggregate_cuda_ok
  • Full text/internal audit: zero float params and zero float buffers.
  • Strict train construction now passes enable_audio=False as well as enable_image=False, so strict mode no longer instantiates Whisper.
  • Strict train-style audit with image/audio/VQ/graph/memory disabled and MoE enabled:
    logical ternary weights: 14,011,904
    ternary training state: 18.27 MB
    trainable float params: 0 tensors, 0.00 MB
    frozen float params: 0 tensors, 0.00 MB
    float buffers: 0 tensors, 0.00 MB
  • The current strict-mode training smoke test (with audio disabled) ran 3 steps with zero float params/buffers; training loss moved 8.2048 -> 9.7809 -> 7.7685, and the final eval loss was 6.4239.
  • CUDA full-path smoke with VQ, graph, memory, and MoE enabled passed forward, backward, and _ternary_update_memory().

Remaining Work

  1. Build a fused MoE Triton dispatch kernel for top-k expert routing and expert projection scheduling.
  2. Extend the Graph Triton aggregation kernel into a full fused message-passing/hop-update kernel.
  3. Add component-specific ternary backward routing so LossComponents can update selected ternary module groups separately, not only through weighted total loss.
  4. Consider a low-overhead mantissa/residual scale lattice if exact non-power-of-two scale values such as 99.9 become required.