# True Ternary Refactoring

## Architecture Contract

| Component | Type | Storage | Role |
|-----------|------|---------|------|
| **T** | {-1, 0, +1} | 5-trit/byte packed (1.6 BPW) | Weight values |
| **E** | int8 | 1 per group (log-space exponent) | Scale memory: `S = 2^E` |
| **T_accum** | int8 | 1 per weight | Gradient sign accumulator for T flips |
| **S** | ephemeral | derived from E in forward | `S = 2^E`, log-space block scale |

### Key design decisions

- **No IEEE float anywhere in weight state.**
- S is NOT stored; only E (int8) is stored. `S = 2^E` is ephemeral in the computation graph.
- Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
- Bias is stored as an int32 buffer and cast to float ephemerally during forward.

### Log-Space Representation (Option B)

Scales use **log-space storage** as recommended by agents:

```
S = 2^E    where E = int8 (log₂ of scale factor)
W_eff = T * 2^E
```

Log-space replaces float multiply with integer shift:

| Operation | Float version | Log-space version |
|-----------|---------------|-------------------|
| Scale × scale | `S1 * S2` (float mul) | `E1 + E2` (int add) |
| Scale × ternary | `S * T` (float mul) | `T << E` if E ≥ 0, else `T >> (-E)` (int shift) |
| Dequant in kernel | `sign * scale` (fp16 mul) | `sign << exp` (int shift) |

The **TileLang kernel** (`tilelang/kernels/dequant_gemm.py`) uses integer shift directly:

```python
# Per-element dequant in TileLang:
# Note: 1 >> n is 0 for any n >= 1, so exp_val < 0 yields 0 in integer arithmetic.
if sign_val == 0:
    dequant_int = 0
elif sign_val > 0:
    dequant_int = (1 << exp_val) if exp_val >= 0 else (1 >> (-exp_val))
else:
    dequant_int = -(1 << exp_val) if exp_val >= 0 else -(1 >> (-exp_val))
```

### T scale type → block sizing

`TILE_SIZE = 384`.
`TScaleType` determines group size and thus E's granularity:

| Type | Group size | E entries per 384-dim row |
|------|-----------|--------------------------|
| T64 | 6 | 64 |
| T32 | 12 | 32 |
| T16 | 24 | 16 |
| T8 | 48 | 8 |
| T6 | 64 | 6 |
| T4 | 96 | 4 |

At model scale=10: block = 3840, group_count scales proportionally.

## Training State: what lives where

During training, each `TernaryScaleTensor` stores:

| Buffer | Shape | dtype | Bits/weight |
|--------|-------|-------|-------------|
| `T_packed` | flat (5-trit/byte) | uint8 | 1.6 |
| `E` | (n_groups,) | int8 | 8 / group_size |
| `T_accum` | (out, in) | int8 | 8 |
| `bias` | (out,) | int32 | 32 / out_dim (negligible) |

Total stored = `1.6 + 8 + 8/group_size` bits/weight, ignoring the bias.
At group_size=12: **1.6 + 8 + 0.67 = 10.27 bits/weight** (~1.28 bytes/weight) during training.

For inference/storage (no T_accum needed): **1.6 + 0.67 = 2.27 bits/weight** (~0.28 bytes/weight).

## Pipeline (forward)

```
x (ephemeral float from prev layer)
  → unpack T_packed → T ∈ {-1,0,+1}
  → expand E → 2^E as ephemeral float
  → w_eff = (2^E) * T (ephemeral float)
  → y = F.linear(x, w_eff)
  → ternarize activation → next layer
  → register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
```

## Update Rule

```
After backward:
  1. S (E) update (called by model._ternary_update_memory):
     grad_E = -sign(mean over group of grad_w * T)
     E = clamp(E + grad_E, -128, 127)

  2. T flip (called by model._ternary_update_memory):
     T_accum = clamp(T_accum + sign(grad_w), -128, 127)
     if |T_accum| > threshold (default 3):
       flip T at that position
       reset T_accum at that position to 0
```

## Files Changed

### Core layers: `tscale.py`

- `TernaryScaleTensor`: complete rewrite
  - Removed: `self.weight` (FP32 master weight), `_compute_T`, `_compute_S`
  - Added: `T_packed`, `E` (int8 log-exponent), `T_accum` (int8 gradient counter)
  - Added: `ternary_step()`, `update_E()` for per-step updates
  - `forward`: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
  - `_hook_T`, `_hook_x` captured per-forward-call via closure
- `TernaryRMSNorm`: same T+E+accum scheme

### Model architecture: `trigram.py`

- `ByteEmbedding`: same T+E+accum scheme, removed `self.weight`
- `StickyZoneSTE`: kept (used by `TernaryGNNLayer` for edge_attr ternarization)
- `ScaledTernaryLinear`: removed (thin wrapper, no longer meaningful)
- `TernaryFFN`: removed (dead code, never used by model)
- `T_GRAPH_N_LAYERS`: removed (unused parameter)
- `MORPHTernaryModel._ternary_update_memory()`: new method, iterates all modules calling `ternary_step()` + `update_E()`
- Remaining float `nn.Parameter` instances (ViT frozen, MemGram embeddings, LSTM cell, edge_attr, router) kept as-is: they are small and either frozen or not weight matrices

### Serialization: `convert_to_ternary.py`

- New file with `pack_ternary()` and `unpack_ternary()`: 5-trit-per-byte base-3 encoding
- Independent of `convert_to_ternary8.py` (no circular import)

### Training loop: `train.py`

- Removed `from convert_to_ternary import save_model` (no longer needed)
- Checkpoint save uses `model.state_dict()` directly (all buffers + float params)
- After `optimizer.step()` → `model._ternary_update_memory(accum_threshold=3)`

### Deleted: `ht.py`

- Commented-out planning notes, superseded by Phase 7 implementation

## Memory Projection for 3B Parameter Model

| Component | Size |
|-----------|------|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| **Total training** | **~7.1–7.4 GB** |

## Test Results (140/140 passing)

| Suite | Pass | Fail | Notes |
|-------|------|------|-------|
| `test_morph.py` | 119 | 0 | All Phase 1-7 tests |
| `test_tscale.py` | 21 | 0 | Core ternary + SignSGD tests |

## Pending Work

1. **Remaining float params** (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
2. **TileLang GEMM kernel**: rewrite `dequant_gemm.py` to use int8 E (log-space) instead of float16 scales
3. **Activation ternarization**: optional optimization, clamp inter-layer activations to {-1,0,+1}
4. **Generate with memory**: `generate()` currently passes memory_state but basic sampling
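
For reference, the 5-trit-per-byte base-3 encoding named under Serialization can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the code in `convert_to_ternary.py`: the names `pack_trits` and `unpack_trits`, the digit order (first trit is the least significant base-3 digit), and padding the tail with zero trits are all hypothetical choices made for this example.

```python
from typing import Iterable, List

TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, so five trits fit in one byte


def pack_trits(trits: Iterable[int]) -> bytes:
    """Pack values in {-1, 0, +1} into bytes, five trits per byte (base 3)."""
    digits = []
    for t in trits:
        if t not in (-1, 0, 1):
            raise ValueError(f"not a trit: {t!r}")
        digits.append(t + 1)  # map {-1, 0, +1} -> {0, 1, 2}

    # Assumption: pad the tail with the digit for trit 0 up to a multiple of five.
    pad = (-len(digits)) % TRITS_PER_BYTE
    digits.extend([1] * pad)

    out = bytearray()
    for i in range(0, len(digits), TRITS_PER_BYTE):
        value = 0
        # Assumption: the first trit of each group is the least significant digit.
        for d in reversed(digits[i:i + TRITS_PER_BYTE]):
            value = value * 3 + d
        out.append(value)  # always in 0..242
    return bytes(out)


def unpack_trits(packed: bytes, count: int) -> List[int]:
    """Recover the first `count` trits from bytes produced by pack_trits."""
    if count > len(packed) * TRITS_PER_BYTE:
        raise ValueError("count exceeds the number of packed trits")
    trits = []
    for byte in packed:
        for _ in range(TRITS_PER_BYTE):
            byte, d = divmod(byte, 3)  # peel off the least significant digit
            trits.append(d - 1)        # map {0, 1, 2} -> {-1, 0, +1}
    return trits[:count]


weights = [1, -1, 0, 0, 1, -1, 1]
packed = pack_trits(weights)
restored = unpack_trits(packed, len(weights))
```

Five trits use 243 of the 256 values a byte can hold, which is where the 8 / 5 = 1.6 bits per weight figure comes from. The caller has to keep the original element count, since the padding trits are indistinguishable from real zeros.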