# True Ternary Refactor 9 — Platform Components And Output Bridge
## Scope
The codebase has moved into the `arbitor/` package. This pass focuses only on the newly added platform components:
- output bridge heads: `OutputRouter`, `VideoHead`, `TalkerHead`
- custom audio training encoder: `AudioVQEncoder`
- imported sidecars: `pig_vae`, Moonshine audio encoder, ViT/DINO vision encoder
- new loop-heavy output paths (the `TalkerHead` token loop and the `VideoHead` diffusion loop)
## AudioVQEncoder Ternarization
`arbitor/encoders/audio.py` was still a custom trainable float module:
```text
nn.Conv1d -> nn.Conv1d -> nn.Linear -> nn.Embedding -> nn.Linear
```
Converted it to persistent ternary state:
- Added `TernaryConv1d`, implemented as `unfold + TernaryScaleTensor`.
- Replaced all conv blocks with `TernaryConv1d`.
- Replaced `proj` and `out_proj` with `TernaryScaleTensor`.
- Replaced the VQ codebook `nn.Embedding` with `TernaryEmbeddingTable`.
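To make the `unfold + TernaryScaleTensor` construction concrete, here is a minimal pure-Python sketch of a 1-D convolution computed as im2col followed by a ternary matrix product. It assumes a ternary weight is `{-1, 0, +1}` codes times a single float scale; that layout and the helper names `unfold_1d` and `ternary_conv1d` are illustrative assumptions, not the actual `arbitor` classes.

```python
from typing import List

Matrix = List[List[float]]


def unfold_1d(x: Matrix, kernel_size: int, stride: int = 1) -> Matrix:
    """im2col for a 1-D input shaped [in_channels][length].

    Returns one row per output position. Each row is the flattened window,
    channel-major, matching how a [out][in][k] conv weight flattens to [out][in*k].
    """
    in_ch, length = len(x), len(x[0])
    n_out = (length - kernel_size) // stride + 1
    return [
        [x[c][t * stride + k] for c in range(in_ch) for k in range(kernel_size)]
        for t in range(n_out)
    ]


def ternary_conv1d(x: Matrix, codes: List[List[int]], scale: float,
                   kernel_size: int, stride: int = 1) -> Matrix:
    """Conv1d (no padding, no bias) as unfold followed by a ternary matmul.

    codes: [out_channels][in_channels * kernel_size], every entry in {-1, 0, 1}.
    scale: one float applied after accumulation (assumed layout).
    Returns [out_channels][n_out].
    """
    if any(w not in (-1, 0, 1) for row in codes for w in row):
        raise ValueError("weights must be ternary")
    cols = unfold_1d(x, kernel_size, stride)
    # With ternary codes each inner product is only signed additions of inputs.
    return [[scale * sum(w * v for w, v in zip(row, col)) for col in cols]
            for row in codes]


x = [[1.0, 2.0, 3.0, 4.0]]   # 1 input channel, length 4
codes = [[1, 0, -1]]         # 1 output channel, kernel size 3
print(ternary_conv1d(x, codes, scale=0.5, kernel_size=3))  # [[-1.0, -1.0]]
```

Since every weight is -1, 0, or +1, the accumulation needs no multiplies until the single scale is applied, which is what lets the unfolded form reuse the same ternary matmul path as `proj` and `out_proj`.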
Focused audit:
```text
AudioVQEncoder logical ternary weights: 404,864
trainable float params: 0
frozen float params: 0
float buffers: 0
```
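One way to read the four counters is as a classification of every tensor the module owns. The sketch below uses a hypothetical `TensorInfo` record instead of real `torch` tensors and is not the `ternary_audit.py` implementation; it only illustrates how a single leaked float tensor would break an all-zero float result.

```python
from dataclasses import dataclass
from typing import Dict, Iterable


@dataclass
class TensorInfo:
    """Hypothetical per-tensor record. A real audit would build these from a
    module's named_parameters()/named_buffers() plus its ternary module types."""
    name: str
    numel: int
    is_float: bool
    is_parameter: bool
    requires_grad: bool = False
    logical_ternary: int = 0  # how many {-1, 0, +1} weights this tensor encodes


def audit(tensors: Iterable[TensorInfo]) -> Dict[str, int]:
    report = {
        "logical ternary weights": 0,
        "trainable float params": 0,
        "frozen float params": 0,
        "float buffers": 0,
    }
    for t in tensors:
        report["logical ternary weights"] += t.logical_ternary
        if not t.is_float:
            continue  # integer/packed storage never counts as float state
        if t.is_parameter and t.requires_grad:
            report["trainable float params"] += t.numel
        elif t.is_parameter:
            report["frozen float params"] += t.numel
        else:
            report["float buffers"] += t.numel
    return report


def is_strict_ternary(report: Dict[str, int]) -> bool:
    # Strict ternary means all three float counters are zero.
    return all(report[k] == 0 for k in ("trainable float params",
                                         "frozen float params", "float buffers"))


ternary_only = [TensorInfo("proj.codes", numel=4096, is_float=False,
                           is_parameter=False, logical_ternary=4096)]
with_leak = ternary_only + [TensorInfo("noise_embed.weight", numel=64, is_float=True,
                                       is_parameter=True, requires_grad=True)]
print(is_strict_ternary(audit(ternary_only)), is_strict_ternary(audit(with_leak)))  # True False
```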
Focused smoke:
```text
audio_vq_encoder_ok logits=(1, 4, 289), indices=(1, 4)
```
## Output Bridge Ternarization
`VideoHead.noise_embed` was a hidden float `nn.Embedding`.
Changed:
```text
nn.Embedding(max_steps, TRIGRAM_DIM)
```
to:
```text
TernaryEmbeddingTable(max_steps, TRIGRAM_DIM)
```
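Swapping `nn.Embedding` for a ternary table changes what is stored per row, not how rows are indexed. The sketch below assumes each row is kept as `{-1, 0, +1}` codes plus one per-row scale, built with a per-row variant of the absmean rounding used in BitNet b1.58. Both the storage layout and the quantizer are assumptions for illustration, and `TernaryTableSketch` is a hypothetical name, not `TernaryEmbeddingTable` itself.

```python
from typing import List, Sequence, Tuple


def ternarize_row(row: Sequence[float], eps: float = 1e-8) -> Tuple[List[int], float]:
    """Per-row absmean ternarization: {-1, 0, 1} codes plus one scale."""
    scale = sum(abs(v) for v in row) / len(row)
    codes = [max(-1, min(1, round(v / (scale + eps)))) for v in row]
    return codes, scale


class TernaryTableSketch:
    """Hypothetical stand-in for a ternary embedding table.

    Row i is reconstructed as scales[i] * codes[i]; indexing works like nn.Embedding.
    """

    def __init__(self, float_rows: Sequence[Sequence[float]]):
        pairs = [ternarize_row(r) for r in float_rows]
        self.codes = [c for c, _ in pairs]
        self.scales = [s for _, s in pairs]

    def __call__(self, indices: Sequence[int]) -> List[List[float]]:
        # One output row per index, just like an embedding lookup.
        return [[self.scales[i] * c for c in self.codes[i]] for i in indices]


table = TernaryTableSketch([[0.9, -0.1, -1.1], [0.0, 0.4, 0.0]])
print(table.codes)  # [[1, 0, -1], [0, 1, 0]]
```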
Focused audit for `VideoHead`:
```text
logical ternary weights: 17,040,896
trainable float params: 0
frozen float params: 0
float buffers: 0
```
`TalkerHead.forward()` had a nested Python loop:
```text
for token:
    for stride:
        logits = head(state)
        append argmax token
```
Replaced it with a single ternary head call over all conditioning tokens followed by `repeat_interleave`, keeping the same stride/pad/truncate behavior.
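The rewrite is valid because, with a deterministic head, calling it `stride` times on the same conditioning state returns the same argmax token every time, so the nested loop collapses into one pass plus `repeat_interleave`. The pure-Python sketch below checks that equivalence on toy data. `toy_head` is a hypothetical deterministic stand-in, and the pad/truncate rule in `fit_length` (right-pad with `pad_id`, or cut to `target_len`) is an assumption about the existing behavior.

```python
from typing import Callable, List, Sequence

Head = Callable[[Sequence[float]], List[float]]


def argmax(xs: Sequence[float]) -> int:
    return max(range(len(xs)), key=lambda i: xs[i])


def fit_length(tokens: List[int], target_len: int, pad_id: int) -> List[int]:
    # Assumed pad/truncate rule: right-pad with pad_id, or cut down to target_len.
    return (tokens + [pad_id] * target_len)[:target_len]


def talker_loop(states, head: Head, stride: int, target_len: int, pad_id: int) -> List[int]:
    """Old shape: one head call per (conditioning token, stride slot)."""
    out = []
    for state in states:
        for _ in range(stride):
            out.append(argmax(head(state)))
    return fit_length(out, target_len, pad_id)


def talker_batched(states, head: Head, stride: int, target_len: int, pad_id: int) -> List[int]:
    """New shape: one head pass over all states, then repeat_interleave by stride."""
    per_token = [argmax(logits) for logits in map(head, states)]  # stands in for one batched call
    repeated = [tok for tok in per_token for _ in range(stride)]  # repeat_interleave(stride)
    return fit_length(repeated, target_len, pad_id)


def toy_head(state: Sequence[float]) -> List[float]:
    # Deterministic toy logits over a 3-token vocabulary.
    return [state[0], -state[0], 0.5]


states = [[1.0], [-2.0], [0.1]]
print(talker_batched(states, toy_head, stride=3, target_len=10, pad_id=-1))
# [0, 0, 0, 1, 1, 1, 2, 2, 2, -1]
```

The head is called once per conditioning token instead of `stride` times, so the saving grows linearly with the stride.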
Focused smoke:
```text
video_head_ok latents=(1, 16, 1, 32, 32)
talker_head_ok tokens=(1, 10)
```
## Imported Sidecars
`pig_vae` now explicitly freezes all parameters after optional int8 quantization:
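```python
from optimum.quanto import freeze, qint8, quantize

quantize(vae, weights=qint8)  # optional int8 weight quantization
freeze(vae)                   # swap float weights for their quantized versions
for p in vae.parameters():    # then explicitly disable gradients
    p.requires_grad = False
```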
Moonshine audio and ViT/DINO vision already default to `quantize_weights='int8'` through `optimum.quanto`, then freeze parameters. If `optimum.quanto` is unavailable, they fall back to frozen BF16; that fallback is not strict ternary, but it is frozen imported sidecar state rather than trainable model state.
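The loading policy described above reduces to a small decision: int8 through `optimum.quanto` when requested and importable, frozen BF16 otherwise. The sketch captures only that decision; `quanto_available` and `select_sidecar_precision` are hypothetical helpers, not functions in `arbitor`, and the check is simply whether `optimum.quanto` can be imported.

```python
import importlib
from typing import Optional


def quanto_available() -> bool:
    """True if optimum.quanto can be imported in this environment."""
    try:
        importlib.import_module("optimum.quanto")
    except ImportError:
        return False
    return True


def select_sidecar_precision(quantize_weights: Optional[str] = "int8") -> str:
    """Storage precision a frozen imported sidecar would end up in (hypothetical policy)."""
    if quantize_weights == "int8" and quanto_available():
        return "int8"  # quantize(...) + freeze(...) through optimum.quanto
    return "bf16"      # frozen BF16 fallback: not ternary, but not trainable either


print(select_sidecar_precision("int8"))
```

Either branch ends with frozen parameters, which is why the fallback affects storage precision but not the trainable-state audit.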
## New Kernel Support
Added a Triton denoise-step kernel for `VideoHead`:
```text
latent = (latent - (1 - alpha) * pred_noise) / sqrt(alpha)
```
Forward and backward are Triton-backed on CUDA. The ACT-style diffusion loop remains because it controls halting and repeated shared-weight denoising, but the per-step latent update is now one custom kernel.
Correctness against PyTorch:
```text
video_denoise_fwd_maxdiff: 7.15e-07
video_denoise_grad_latent_maxdiff: 4.77e-07
video_denoise_grad_pred_maxdiff: 1.79e-07
```
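Because the step is elementwise and linear in both inputs, its backward pass needs only two factors per element: `d out / d latent = 1 / sqrt(alpha)` and `d out / d pred_noise = -(1 - alpha) / sqrt(alpha)`. The plain-Python reference below implements that forward/backward pair and compares it with central finite differences, in the same spirit as the max-diff numbers above. It is not the Triton kernel, and treating `alpha` as a scalar constant that receives no gradient is an assumption.

```python
import math
from typing import List, Tuple


def denoise_step(latent: List[float], pred_noise: List[float], alpha: float) -> List[float]:
    """Forward: (latent - (1 - alpha) * pred_noise) / sqrt(alpha), elementwise."""
    inv_sqrt = 1.0 / math.sqrt(alpha)
    return [(x - (1.0 - alpha) * e) * inv_sqrt for x, e in zip(latent, pred_noise)]


def denoise_step_backward(grad_out: List[float], alpha: float) -> Tuple[List[float], List[float]]:
    """Backward w.r.t. latent and pred_noise, treating alpha as a constant.

    d out / d latent     =  1 / sqrt(alpha)
    d out / d pred_noise = -(1 - alpha) / sqrt(alpha)
    """
    inv_sqrt = 1.0 / math.sqrt(alpha)
    grad_latent = [g * inv_sqrt for g in grad_out]
    grad_pred = [-g * (1.0 - alpha) * inv_sqrt for g in grad_out]
    return grad_latent, grad_pred


def max_grad_error(latent, pred_noise, alpha, grad_out, h=1e-6) -> float:
    """Central finite differences on loss = sum(grad_out * out) vs. the analytic backward."""
    def loss(lat, pred):
        return sum(g * o for g, o in zip(grad_out, denoise_step(lat, pred, alpha)))

    g_lat, g_pred = denoise_step_backward(grad_out, alpha)
    worst = 0.0
    for i in range(len(latent)):
        for which, analytic in ((0, g_lat), (1, g_pred)):
            args_p = [list(latent), list(pred_noise)]
            args_m = [list(latent), list(pred_noise)]
            args_p[which][i] += h
            args_m[which][i] -= h
            numeric = (loss(*args_p) - loss(*args_m)) / (2 * h)
            worst = max(worst, abs(numeric - analytic[i]))
    return worst


print(max_grad_error([0.3, -1.2, 2.0], [0.5, 0.1, -0.7], alpha=0.9, grad_out=[1.0, -2.0, 0.5]))
```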
## Model-Level Verification
Package compile:
```text
python -m py_compile arbitor/components.py arbitor/sequencers.py arbitor/encoders/audio.py arbitor/encoders/pig_vae.py arbitor/main.py arbitor/vq.py arbitor/kernel/ternary_scale.py arbitor/kernel/ternary_audit.py
```
`ARBModel` with image/audio imports disabled and VQ/Graph/Memory/MoE/output heads enabled:
```text
logical ternary weights: 41,087,552
ternary training state: 53.65 MB
trainable float params: 0
frozen float params: 0
float buffers: 0
```
Smokes:
```text
arb_model_cpu_forward_ok logits=(2, 8, 297), indices=(2, 8)
arb_model_cuda_train_smoke_ok logits=(2, 8, 297), targets=(2, 7), loss=12.1709
```
The CUDA smoke completed forward, backward, and `_ternary_update_memory()`.
## Remaining Work
1. Add a strict sidecar audit mode that reports imported quantized sidecars separately from core ternary state.
2. Add tests that instantiate Moonshine/ViT only when cached locally, to avoid network-dependent CI.
3. Consider a true ternary transposed-conv replacement if `TinyNeuralCodec` is promoted from a lazy, frozen sidecar to a trainable core model component.
4. The VideoHead diffusion control loop is still Python-level. Full fusion would require a fixed-step, no-break kernel variant or a persistent CUDA kernel, which is a larger design change.
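Item 4's fixed-step, no-break option can be reasoned about before writing any kernel: run every sample for `max_steps` iterations and use a per-sample `halted` mask to hold finished samples constant, so the trip count no longer depends on the data. The sketch below compares that against a break-based loop on toy scalars. The halting rule (cumulative halting probability reaching a threshold, as in Adaptive Computation Time) and the toy `shrink`/`halt_prob` functions are assumptions, not the `VideoHead` logic.

```python
from typing import Callable, List

Step = Callable[[float], float]


def act_loop(x: float, step: Step, halt_prob: Step, max_steps: int, threshold: float) -> float:
    """Break-based loop: data-dependent trip count per sample."""
    cum = 0.0
    for _ in range(max_steps):
        x = step(x)
        cum += halt_prob(x)
        if cum >= threshold:
            break
    return x


def fixed_step_masked(xs: List[float], step: Step, halt_prob: Step,
                      max_steps: int, threshold: float) -> List[float]:
    """No-break variant: always max_steps iterations, halted samples frozen by a mask."""
    xs = list(xs)
    cum = [0.0] * len(xs)
    halted = [False] * len(xs)
    for _ in range(max_steps):  # fixed trip count, no break
        for i, x in enumerate(xs):
            new_x = step(x)  # computed even for halted samples (the cost of not branching)
            # A Python branch here, but it mirrors a select such as tl.where in Triton.
            xs[i] = x if halted[i] else new_x
            cum[i] += 0.0 if halted[i] else halt_prob(new_x)
            halted[i] = halted[i] or cum[i] >= threshold
    return xs


def shrink(x: float) -> float:
    return 0.5 * x  # toy "denoise" step


def halt_prob(x: float) -> float:
    return 0.25 + 0.5 * abs(x)  # toy halting probability


xs = [4.0, 1.0, 0.0]
print([act_loop(x, shrink, halt_prob, 6, 1.0) for x in xs])  # [2.0, 0.125, 0.0]
print(fixed_step_masked(xs, shrink, halt_prob, 6, 1.0))      # [2.0, 0.125, 0.0]
```

The trade-off shows up in `fixed_step_masked`: samples that halt early still pay for every remaining step, which is the price of removing the data-dependent break.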