
True Ternary Refactoring

Architecture Contract

| Component | Type | Storage | Role |
|---|---|---|---|
| T | {-1, 0, +1} | 5-trit/byte packed (1.6 BPW) | Weight values |
| E | int8 | 1 per group (log-space exponent) | Scale memory: S = 2^E |
| T_accum | int8 | 1 per weight | Gradient sign accumulator for T flips |
| S | ephemeral | derived from E in forward | S = 2^E, log-space block scale |

Key design decisions

  • No IEEE float anywhere in weight state.
  • S is NOT stored; only E (int8) is stored. S = 2^E is ephemeral in the computation graph.
  • Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
  • Bias is stored as int32 buffer, cast to float ephemerally during forward.

Log-Space Representation (Option B)

Scales use log-space storage as recommended by agents:

S = 2^E       where E = int8 (log₂ of scale factor)
W_eff = T * 2^E
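Here is a quick numeric check of these identities in plain Python. It is a sketch, not the project's code, and the helper name `effective_weight` is hypothetical. `math.ldexp(t, e)` computes t · 2^e exactly, and multiplying two scales reduces to adding their exponents:

```python
import math

def effective_weight(t: int, e: int) -> float:
    """W_eff = T * 2^E for one ternary weight t and one int8 exponent e."""
    assert t in (-1, 0, 1), "T must be ternary"
    assert -128 <= e <= 127, "E must fit in int8"
    return math.ldexp(t, e)  # exact t * 2**e, no float multiply by a stored scale

# Scale x scale: 2^E1 * 2^E2 == 2^(E1 + E2), so only the exponents are added.
e1, e2 = -3, 5
assert math.ldexp(1, e1) * math.ldexp(1, e2) == math.ldexp(1, e1 + e2)

print(effective_weight(-1, -3))  # -0.125
print(effective_weight(1, 4))    # 16.0
```

The sum of two int8 exponents can fall outside the int8 range. Code that adds exponents therefore has to widen the result or clamp it.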

Log-space replaces float multiply with integer shift:

| Operation | Float version | Log-space version |
|---|---|---|
| Scale × scale | S1 * S2 (float mul) | E1 + E2 (int add) |
| Scale × ternary | S * T (float mul) | T << E for E ≥ 0, T >> (-E) for E < 0 (int shift) |
| Dequant in kernel | sign * scale (fp16 mul) | sign << exp (int shift) |

The TileLang kernel (tilelang/kernels/dequant_gemm.py) uses integer shift directly:

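```python
# Per-element dequant in TileLang:
if sign_val == 0:       dequant_int = 0
elif sign_val > 0:      dequant_int = 1 << exp_val   if exp_val >= 0 else 1 >> (-exp_val)
else:                   dequant_int = -(1 << exp_val) if exp_val >= 0 else -(1 >> (-exp_val))
# Note: with plain integers, 1 >> (-exp_val) is 0 for every exp_val < 0,
# so this path is exact only for non-negative exponents.
```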
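With plain integers, `1 >> (-exp_val)` truncates to 0 for negative exponents, so an integer-only path needs somewhere to keep fractional bits. One common approach is fixed-point dequantization. It is sketched below as an assumption, not as what dequant_gemm.py does. `FRAC_BITS` and both helper names are hypothetical. The result carries `FRAC_BITS` fractional bits, so exponents down to -FRAC_BITS remain exact integer shifts:

```python
FRAC_BITS = 16  # hypothetical fixed-point precision: real value = integer / 2**FRAC_BITS

def dequant_fixed(sign_val: int, exp_val: int) -> int:
    """Return sign_val * 2**exp_val as a fixed-point integer scaled by 2**FRAC_BITS."""
    if sign_val == 0:
        return 0
    shift = exp_val + FRAC_BITS
    if shift < 0:
        # Below the representable precision: the magnitude underflows to 0.
        return 0
    magnitude = 1 << shift
    return magnitude if sign_val > 0 else -magnitude

def to_float(q: int) -> float:
    """Convert a fixed-point integer back to a float (for checking only)."""
    return q / (1 << FRAC_BITS)

print(to_float(dequant_fixed(1, -3)))  # 0.125
print(to_float(dequant_fixed(-1, 2)))  # -4.0
```

A GPU kernel would also need to bound the exponent range so that the shifted value fits the width of its accumulator. Python integers have no such limit.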

T scale type → block sizing

TILE_SIZE = 384. TScaleType determines group size and thus E's granularity:

| Type | Group size | E entries per 384-dim row |
|---|---|---|
| T64 | 6 | 64 |
| T32 | 12 | 32 |
| T16 | 24 | 16 |
| T8 | 48 | 8 |
| T6 | 64 | 6 |
| T4 | 96 | 4 |

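At model scale=10 the block grows to 3840 (10 × 384) and group_count scales proportionally. Each TScaleType keeps its group size and has 10× as many E entries per row (for example, T32 gives 320 groups of 12).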
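The table follows from group_size = TILE_SIZE / entries. The sketch below reproduces it. The helper `group_layout` and the `T_SCALE_GROUPS` mapping are hypothetical names, not the project's TScaleType API. The sketch assumes that "scales proportionally" means the group size stays fixed while the number of groups grows with the row:

```python
TILE_SIZE = 384

# E entries per 384-dim row for each TScaleType, copied from the table above.
T_SCALE_GROUPS = {"T64": 64, "T32": 32, "T16": 16, "T8": 8, "T6": 6, "T4": 4}

def group_layout(scale_type: str, model_scale: int = 1) -> tuple[int, int]:
    """Return (group_size, E entries per row) for a row of TILE_SIZE * model_scale."""
    row_dim = TILE_SIZE * model_scale
    group_size = TILE_SIZE // T_SCALE_GROUPS[scale_type]
    # Assumption: group size is fixed per type; only the group count grows with scale.
    return group_size, row_dim // group_size

print(group_layout("T32"))      # (12, 32)
print(group_layout("T32", 10))  # (12, 320)
```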

Training State: what lives where

During training, each TernaryScaleTensor stores:

| Buffer | Shape | dtype | Bits/weight |
|---|---|---|---|
| T_packed | flat (5-trit/byte) | uint8 | 1.6 |
| E | (n_groups,) | int8 | 8 / group_size |
| T_accum | (out, in) | int8 | 8 |
| bias | (out,) | int32 | 32 / in_dim (negligible) |

Total stored = 1.6 + 8 + 8/group_size bits/weight.

At group_size=12: 1.6 + 8 + 0.67 = 10.27 bits/weight (~1.28 bytes/weight) during training.

For inference/storage (no T_accum needed): 1.6 + 0.67 = 2.27 bits/weight (~0.28 bytes/weight).
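Both totals can be reproduced with a few lines of arithmetic. `bits_per_weight` is a hypothetical helper, and like the formula above it ignores the bias term:

```python
def bits_per_weight(group_size: int, training: bool = True) -> float:
    """Stored bits per weight; the bias (32 / in_dim bits) is ignored as negligible."""
    t_packed = 8 / 5                 # 5 trits per uint8 byte -> 1.6 bits
    e_bits = 8 / group_size          # one int8 exponent shared by each group
    t_accum = 8 if training else 0   # int8 flip accumulator, needed only for training
    return t_packed + e_bits + t_accum

train_bits = bits_per_weight(12)
infer_bits = bits_per_weight(12, training=False)
print(f"{train_bits:.2f} bits ({train_bits / 8:.2f} bytes)")  # 10.27 bits (1.28 bytes)
print(f"{infer_bits:.2f} bits ({infer_bits / 8:.2f} bytes)")  # 2.27 bits (0.28 bytes)
```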

Pipeline (forward)

x (ephemeral float from prev layer)
  → unpack T_packed → T ∈ {-1,0,+1}
  → expand E → 2^E as ephemeral float
  → w_eff = (2^E) * T  (ephemeral float)
  → y = F.linear(x, w_eff)  → ternarize activation → next layer
  → register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
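The pipeline can be traced with a dependency-free sketch on nested lists. It stands in for what `F.linear` and the autograd hook compute; it is not the tscale.py code. The sketch makes three assumptions: T is already unpacked, each E entry covers `group_size` consecutive weights in row-major order, and the activation-ternarization step is left out:

```python
import math

def forward(x, T, E, group_size):
    """y = x @ w_eff^T with w_eff = T * 2^E, mirroring F.linear without a bias.

    x: (batch, in) floats; T: (out, in) ternary ints; E: flat int8 exponents,
    one per group of `group_size` consecutive weights (assumed row-major layout).
    """
    out_dim, in_dim = len(T), len(T[0])
    w_eff = [
        [math.ldexp(T[o][i], E[(o * in_dim + i) // group_size]) for i in range(in_dim)]
        for o in range(out_dim)
    ]
    y = [[sum(xb[i] * w_eff[o][i] for i in range(in_dim)) for o in range(out_dim)] for xb in x]
    return y, w_eff

def grad_weight(grad_y, x):
    """grad_w = grad_y^T @ x, shape (out, in): what the backward hook captures."""
    out_dim, in_dim = len(grad_y[0]), len(x[0])
    return [
        [sum(gy[o] * xb[i] for gy, xb in zip(grad_y, x)) for i in range(in_dim)]
        for o in range(out_dim)
    ]

x = [[1.0, 2.0, 3.0, 4.0]]
T = [[1, 0, -1, 1], [0, -1, 1, 1]]
y, w_eff = forward(x, T, E=[-1, 2], group_size=4)
print(y)                              # [[1.0, 20.0]]
print(grad_weight([[1.0, -1.0]], x))  # [[1.0, 2.0, 3.0, 4.0], [-1.0, -2.0, -3.0, -4.0]]
```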

Update Rule

After backward:

1. S (E) update (called by model._ternary_update_memory):
   grad_E = -sign(mean over group of grad_w * T)
   E = clamp(E + grad_E, -128, 127)

2. T flip (called by model._ternary_update_memory):
   T_accum = clamp(T_accum + sign(grad_w), -128, 127)
   if |T_accum| > threshold (default 3):
       flip T at that position
       reset T_accum at that position to 0
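The sketch below shows both update steps on flat lists. The free functions are stand-ins named after TernaryScaleTensor.update_E() and ternary_step(); their signatures are hypothetical. E groups are assumed to be contiguous runs of `group_size` weights. The document does not define "flip" for a three-valued T, so the sketch assumes that T moves one level against the accumulated gradient sign:

```python
def sign(v: float) -> int:
    return (v > 0) - (v < 0)

def clamp_int8(v: int) -> int:
    return max(-128, min(127, v))

def update_E(E, T_flat, grad_flat, group_size):
    """Step 1: E[g] += -sign(mean over group of grad_w * T), clamped to int8."""
    for g in range(len(E)):
        lo, hi = g * group_size, (g + 1) * group_size
        mean = sum(grad_flat[k] * T_flat[k] for k in range(lo, hi)) / group_size
        E[g] = clamp_int8(E[g] - sign(mean))

def ternary_step(T_flat, T_accum, grad_flat, threshold=3):
    """Step 2: accumulate sign(grad_w); change T once |T_accum| exceeds threshold."""
    for k, g in enumerate(grad_flat):
        T_accum[k] = clamp_int8(T_accum[k] + sign(g))
        if abs(T_accum[k]) > threshold:
            # Assumption: "flip" moves T one level against the accumulated
            # gradient sign (e.g. +1 -> 0, 0 -> -1), staying inside {-1, 0, +1}.
            T_flat[k] = max(-1, min(1, T_flat[k] - sign(T_accum[k])))
            T_accum[k] = 0

T = [1, 0, -1, 1]
T_accum = [0, 0, 0, 0]
grad = [0.5, -0.2, 0.0, 0.1]
E = [0]
update_E(E, T, grad, group_size=4)
for _ in range(4):
    ternary_step(T, T_accum, grad)
print(E, T, T_accum)  # [-1] [0, 1, -1, 0] [0, 0, 0, 0]
```

With the default threshold of 3, a position changes only after four consecutive same-sign gradients. Zero gradients leave its accumulator untouched.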

Files Changed

Core layers: tscale.py

  • TernaryScaleTensor: complete rewrite
    • Removed: self.weight (FP32 master weight), _compute_T, _compute_S
    • Added: T_packed, E (int8 log-exponent), T_accum (int8 gradient counter)
    • Added: ternary_step(), update_E() for per-step updates
    • forward: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
    • _hook_T, _hook_x captured per-forward-call via closure
  • TernaryRMSNorm: same T+E+accum scheme

Model architecture: trigram.py

  • ByteEmbedding: same T+E+accum scheme, removed self.weight
  • StickyZoneSTE: kept (used by TernaryGNNLayer for edge_attr ternarization)
  • ScaledTernaryLinear: removed (thin wrapper, no longer meaningful)
  • TernaryFFN: removed (dead code, never used by model)
  • T_GRAPH_N_LAYERS: removed (unused parameter)
  • MORPHTernaryModel._ternary_update_memory(): new method, iterates all modules calling ternary_step() + update_E()
  • Remaining float nn.Parameter instances (ViT frozen, MemGram embeddings, LSTM cell, edge_attr, router) kept as-is, since they are small and either frozen or not weight matrices

Serialization: convert_to_ternary.py

  • New file with pack_ternary() and unpack_ternary(): 5-trit-per-byte base-3 encoding
  • Independent of convert_to_ternary8.py (no circular import)
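The following sketch shows one way to implement 5-trit-per-byte base-3 packing. The function names match those in convert_to_ternary.py, but the signatures and the digit order (first trit as the least-significant base-3 digit) are assumptions; the real file may lay out its bytes differently. Five trits fit in one byte because 3^5 = 243 ≤ 256:

```python
def pack_ternary(trits):
    """Pack values in {-1, 0, +1} five per byte as base-3 digits (3**5 = 243 <= 256)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        byte = 0
        # Reverse so the first trit of each chunk becomes the least-significant digit.
        for t in reversed(trits[start:start + 5]):
            byte = byte * 3 + (t + 1)  # map -1, 0, +1 -> 0, 1, 2
        out.append(byte)
    return bytes(out)

def unpack_ternary(packed, n):
    """Inverse of pack_ternary; n is the original trit count (drops padding)."""
    trits = []
    for byte in packed:
        for _ in range(5):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:n]

data = [1, -1, 0, 0, 1, -1, -1, 1]
packed = pack_ternary(data)
print(len(packed), unpack_ternary(packed, len(data)) == data)  # 2 True
```

The original length has to be stored alongside the packed bytes. A short final chunk is decoded as extra -1 digits that are discarded only by the `n` cutoff.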

Training loop: train.py

  • Removed `from convert_to_ternary import save_model` (no longer needed)
  • Checkpoint save uses model.state_dict() directly (all buffers + float params)
  • After optimizer.step(), call model._ternary_update_memory(accum_threshold=3)

Deleted: ht.py

  • Commented-out planning notes, superseded by the Phase 7 implementation

Memory Projection for 3B Parameter Model

| Component | Size |
|---|---|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| Total training | ~7.4 GB |
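The arithmetic behind the projection is shown below, in decimal gigabytes. The ~3 GB gradient row implies one byte per weight, so the sketch assumes the sign-only gradients are held as int8. Bias and the remaining float parameters are ignored. The upper end of the range matches the ~7.4 GB total:

```python
PARAMS = 3e9
GB = 1e9  # decimal gigabytes

components = {
    "T (packed, 1.6 bits/weight)": PARAMS * 1.6 / 8,
    "E (int8, group=12)": PARAMS / 12,
    "T_accum (int8 per weight)": PARAMS * 1,
    "Gradients (assumed int8 sign per weight)": PARAMS * 1,
}
fixed = sum(components.values()) / GB

for name, size in components.items():
    print(f"{name}: {size / GB:.2f} GB")
# Activations add 0.2-0.5 GB on top of the fixed buffers.
print(f"Total: {fixed + 0.2:.2f}-{fixed + 0.5:.2f} GB")  # Total: 7.05-7.35 GB
```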

Test Results (140/140 passing)

| Suite | Pass | Fail | Notes |
|---|---|---|---|
| test_morph.py | 119 | 0 | All Phase 1-7 tests |
| test_tscale.py | 21 | 0 | Core ternary + SignSGD tests |

Pending Work

  1. Remaining float params (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
  2. TileLang GEMM kernel: rewrite dequant_gemm.py to use int8 E (log-space) instead of float16 scales
  3. Activation ternarization: optional optimization, clamp inter-layer activations to {-1,0,+1}
  4. Generate with memory: generate() currently passes memory_state but uses only basic sampling