
True Ternary Refactor 9 — Platform Components And Output Bridge

Scope

The codebase has moved into the arbitor/ package. This pass focuses only on the newly added platform components:

  • output bridge heads: OutputRouter, VideoHead, TalkerHead
  • custom audio training encoder: AudioVQEncoder
  • imported sidecars: pig_vae, Moonshine audio encoder, ViT/DINO vision encoder
  • new loop-heavy output paths

AudioVQEncoder Ternarization

arbitor/encoders/audio.py was still a custom trainable float module:

nn.Conv1d -> nn.Conv1d -> nn.Linear -> nn.Embedding -> nn.Linear

Converted it to persistent ternary state:

  • Added TernaryConv1d, implemented as unfold + TernaryScaleTensor.
  • Replaced all conv blocks with TernaryConv1d.
  • Replaced proj and out_proj with TernaryScaleTensor.
  • Replaced the VQ codebook nn.Embedding with TernaryEmbeddingTable.
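
The unfold-based formulation in the first bullet can be illustrated without any framework. The sketch below is plain Python and is not the arbitor implementation: unfold_1d, ternary_conv1d, and the single shared scale are hypothetical names and assumptions, and the real TernaryConv1d and TernaryScaleTensor may store weights and scales differently. It only shows that a 1-D convolution reduces to extracting sliding windows and taking dot products with weights restricted to {-1, 0, +1}.

```python
def unfold_1d(signal, kernel_size, stride=1):
    """Return the sliding windows of a 1-D signal (no padding)."""
    last_start = len(signal) - kernel_size
    return [signal[i:i + kernel_size] for i in range(0, last_start + 1, stride)]


def ternary_conv1d(signal, ternary_kernels, scale, stride=1):
    """Single-input-channel conv1d as unfold + ternary dot products.

    ternary_kernels: list of kernels, each a list of values in {-1, 0, +1}.
    scale: one float shared by all kernels (an assumption of this sketch).
    Returns one output row per kernel (output channel).
    """
    for kernel in ternary_kernels:
        assert all(w in (-1, 0, 1) for w in kernel), "weights must be ternary"
    kernel_size = len(ternary_kernels[0])
    windows = unfold_1d(signal, kernel_size, stride)
    return [
        [scale * sum(w * x for w, x in zip(kernel, window)) for window in windows]
        for kernel in ternary_kernels
    ]


signal = [1.0, 2.0, 3.0, 4.0, 5.0]
kernels = [[1, 0, -1], [1, 1, 1]]  # two output channels, kernel size 3
out = ternary_conv1d(signal, kernels, scale=0.5)
print(out)  # [[-1.0, -1.0, -1.0], [3.0, 4.5, 6.0]]
```

Because every weight is -1, 0, or +1, the inner product needs only additions and subtractions; the one multiplication per output is the scale.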

Focused audit:

AudioVQEncoder logical ternary weights: 404,864
trainable float params: 0
frozen float params: 0
float buffers: 0

Focused smoke:

audio_vq_encoder_ok logits=(1, 4, 289), indices=(1, 4)

Output Bridge Ternarization

VideoHead.noise_embed was a hidden float nn.Embedding.

Changed:

nn.Embedding(max_steps, TRIGRAM_DIM)

to:

TernaryEmbeddingTable(max_steps, TRIGRAM_DIM)

Focused audit for VideoHead:

logical ternary weights: 17,040,896
trainable float params: 0
frozen float params: 0
float buffers: 0

TalkerHead.forward() had a nested Python loop:

for token:
    for stride:
        logits = head(state)
        append argmax token

Replaced it with a single ternary head call over all conditioning tokens, followed by repeat_interleave along the token axis. Stride, padding, and truncation behavior are unchanged.
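
The equivalence between the two forms can be checked with a small framework-free sketch. Everything here is a hypothetical stand-in rather than the TalkerHead code: head is a toy function, fit_length models the pad/truncate step, and the sketch assumes the state fed to the head does not change between stride repetitions of the same token. That is the condition under which repeating the per-token result matches the nested loop.

```python
def head(state):
    """Toy stand-in for the ternary output head: maps a state to logits."""
    return [state * w for w in (0.5, -1.0, 2.0)]


def argmax(values):
    return max(range(len(values)), key=values.__getitem__)


def fit_length(tokens, target_len, pad_id=0):
    """Pad with pad_id or truncate so the output has exactly target_len tokens."""
    return (tokens + [pad_id] * target_len)[:target_len]


def talker_loop(states, stride, target_len):
    """Nested form: one head call per (token, stride) pair."""
    tokens = []
    for state in states:
        for _ in range(stride):
            tokens.append(argmax(head(state)))
    return fit_length(tokens, target_len)


def talker_batched(states, stride, target_len):
    """One head call per conditioning token, then repeat each result stride times."""
    per_token = [argmax(head(state)) for state in states]
    repeated = [tok for tok in per_token for _ in range(stride)]  # repeat_interleave
    return fit_length(repeated, target_len)


states = [1.0, -2.0, 0.0, 3.0]
print(talker_batched(states, stride=3, target_len=10))
# [2, 2, 2, 1, 1, 1, 0, 0, 0, 2]
```

With tensors, the list comprehension marked repeat_interleave would correspond to torch.repeat_interleave applied along the token dimension, and the head call would process all conditioning tokens in one batch.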

Focused smoke:

video_head_ok latents=(1, 16, 1, 32, 32)
talker_head_ok tokens=(1, 10)

Imported Sidecars

pig_vae now explicitly freezes all parameters after optional int8 quantization:

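from optimum.quanto import freeze, qint8, quantize
quantize(vae, weights=qint8)  # optional int8 weight quantization
freeze(vae)                   # replace float weights with their quantized form
for p in vae.parameters():
    p.requires_grad = False   # explicitly freeze every parameter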

Moonshine audio and ViT/DINO vision already default to quantize_weights='int8' through optimum.quanto and then freeze their parameters. If optimum.quanto is unavailable, they fall back to frozen BF16. That fallback is not strict ternary, but it remains frozen imported sidecar state rather than trainable model state.
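
The fallback decision described above can be sketched with the standard library alone. The function names and the returned labels "int8" and "bf16" are hypothetical and do not come from the arbitor code; the sketch only shows one way to detect whether optimum.quanto is importable before choosing a frozen precision.

```python
import importlib.util


def quanto_available():
    """True when optimum.quanto can be imported in this environment."""
    try:
        return importlib.util.find_spec("optimum.quanto") is not None
    except ImportError:
        # find_spec imports the parent package first; a missing `optimum`
        # raises ModuleNotFoundError, which is a subclass of ImportError.
        return False


def sidecar_precision(quantize_weights="int8"):
    """Pick the frozen precision for an imported sidecar.

    Returns "int8" when quantization is requested and optimum.quanto is
    importable; otherwise "bf16" (frozen, but not ternary).
    """
    if quantize_weights == "int8" and quanto_available():
        return "int8"
    return "bf16"


print(sidecar_precision())
```

Either branch leaves the sidecar frozen; only the storage precision differs.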

New Kernel Support

Added a Triton denoise-step kernel for VideoHead:

latent = (latent - (1 - alpha) * pred_noise) / sqrt(alpha)

Forward and backward are Triton-backed on CUDA. The ACT-style diffusion loop remains because it controls halting and repeated shared-weight denoising, but the per-step latent update is now one custom kernel.
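
As a reference for what the kernel computes, the update and its gradients can be written out elementwise. This is a plain-Python sketch, not the Triton kernel: the function names are hypothetical, and alpha is assumed to be a scalar constant that receives no gradient. Under that assumption the gradient with respect to latent is 1 / sqrt(alpha) and the gradient with respect to pred_noise is -(1 - alpha) / sqrt(alpha), each multiplied by the upstream gradient.

```python
import math


def denoise_step(latent, pred_noise, alpha):
    """Elementwise update: (latent - (1 - alpha) * pred_noise) / sqrt(alpha)."""
    inv_sqrt = 1.0 / math.sqrt(alpha)
    return [(x - (1.0 - alpha) * n) * inv_sqrt for x, n in zip(latent, pred_noise)]


def denoise_step_backward(grad_out, alpha):
    """Gradients w.r.t. latent and pred_noise given the upstream gradient."""
    inv_sqrt = 1.0 / math.sqrt(alpha)
    grad_latent = [g * inv_sqrt for g in grad_out]
    grad_pred = [-g * (1.0 - alpha) * inv_sqrt for g in grad_out]
    return grad_latent, grad_pred


latent = [0.5, -1.0, 2.0]
pred_noise = [0.1, 0.2, -0.3]
alpha = 0.64

out = denoise_step(latent, pred_noise, alpha)
grad_latent, grad_pred = denoise_step_backward([1.0, 1.0, 1.0], alpha)

# Central finite difference on the first latent element as a sanity check.
eps = 1e-6
fd = (denoise_step([0.5 + eps], [0.1], alpha)[0]
      - denoise_step([0.5 - eps], [0.1], alpha)[0]) / (2 * eps)
assert abs(fd - grad_latent[0]) < 1e-6
```

Since the update is linear in both inputs, the backward pass needs no saved activations beyond alpha, which is part of why it suits a single fused elementwise kernel.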

Correctness against the PyTorch reference (maximum absolute difference):

video_denoise_fwd_maxdiff:         7.15e-07
video_denoise_grad_latent_maxdiff: 4.77e-07
video_denoise_grad_pred_maxdiff:   1.79e-07

Model-Level Verification

Package compile:

python -m py_compile arbitor/components.py arbitor/sequencers.py arbitor/encoders/audio.py arbitor/encoders/pig_vae.py arbitor/main.py arbitor/vq.py arbitor/kernel/ternary_scale.py arbitor/kernel/ternary_audit.py

ARBModel with image/audio imports disabled, VQ/Graph/Memory/MoE/output heads enabled:

logical ternary weights: 41,087,552
ternary training state: 53.65 MB
trainable float params: 0
frozen float params: 0
float buffers: 0

Smokes:

arb_model_cpu_forward_ok logits=(2, 8, 297), indices=(2, 8)
arb_model_cuda_train_smoke_ok logits=(2, 8, 297), targets=(2, 7), loss=12.1709

The CUDA smoke completed forward, backward, and _ternary_update_memory().

Remaining Work

  1. Add a strict sidecar audit mode that reports imported quantized sidecars separately from core ternary state.
  2. Add tests that instantiate Moonshine/ViT only when cached locally, to avoid network-dependent CI.
  3. Consider a true ternary transposed-conv replacement if TinyNeuralCodec is promoted from lazy frozen sidecar to trainable core model component.
  4. The VideoHead diffusion control loop is still Python-level. Full fusion would require a fixed-step, no-break kernel variant or a persistent CUDA kernel, which is a larger design change.
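
To make the fixed-step, no-break idea in item 4 concrete, the sketch below contrasts a control loop that breaks on a halting condition with a variant that always runs max_steps and masks updates after halting. It is a scalar plain-Python illustration with hypothetical names (step_fn, halt_fn, run_with_break, run_fixed_steps), not the VideoHead loop; in a batched setting the active flag would be a per-sample mask.

```python
def run_with_break(latent, step_fn, halt_fn, max_steps):
    """Control loop that stops at the first step where halt_fn fires."""
    steps_used = 0
    for step in range(max_steps):
        latent = step_fn(latent, step)
        steps_used += 1
        if halt_fn(latent, step):
            break
    return latent, steps_used


def run_fixed_steps(latent, step_fn, halt_fn, max_steps):
    """Same result with no data-dependent break: always runs max_steps.

    After halting, the `active` flag masks further updates, so the loop has a
    fixed trip count at the cost of wasted step_fn evaluations.
    """
    active = True
    steps_used = 0
    for step in range(max_steps):
        candidate = step_fn(latent, step)
        latent = candidate if active else latent
        steps_used += 1 if active else 0
        active = active and not halt_fn(latent, step)
    return latent, steps_used


def halve(x, step):
    return x * 0.5


def small_enough(x, step):
    return abs(x) < 0.2


print(run_with_break(1.0, halve, small_enough, max_steps=8))   # (0.125, 3)
print(run_fixed_steps(1.0, halve, small_enough, max_steps=8))  # (0.125, 3)
```

The fixed-step form trades extra compute after halting for a loop whose trip count is known ahead of time, which is the property a fused kernel variant would rely on.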