# True Ternary Refactoring

## Architecture Contract

| Component | Type | Storage | Role |
|-----------|------|---------|------|
| **T** | {-1, 0, +1} | 5 trits/byte, packed (1.6 bits/weight) | Weight values |
| **E** | int8 | 1 per group (log-space exponent) | Scale memory: `S = 2^E` |
| **T_accum** | int8 | 1 per weight | Gradient-sign accumulator for T flips |
| **S** | ephemeral | derived from E in forward | `S = 2^E`, log-space block scale |

### Key design decisions

- **No IEEE float anywhere in weight state.**
- S is NOT stored; only E (int8) is stored. `S = 2^E` exists only ephemerally in the computation graph.
- Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
- Bias is stored as an int32 buffer and cast to float ephemerally during forward.

### Log-Space Representation (Option B)

Scales use **log-space storage**, as recommended by agents:

```
S = 2^E    where E is int8 (log₂ of the scale factor)
W_eff = T * 2^E
```

Log-space replaces float multiplies with integer adds and shifts:

| Operation | Float version | Log-space version |
|-----------|---------------|-------------------|
| Scale × scale | `S1 * S2` (float mul) | `E1 + E2` (int add) |
| Scale × ternary | `S * T` (float mul) | `T << E` (int shift; E < 0 needs a fixed-point offset) |
| Dequant in kernel | `sign * scale` (fp16 mul) | `sign << exp` (int shift, fixed point) |

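The **TileLang kernel** (`tilelang/kernels/dequant_gemm.py`) is designed to apply the exponent as an integer shift (rewriting it to consume int8 E is still listed under Pending Work):

```python
# Per-element dequant in TileLang (sign_val in {-1, 0, +1}, exp_val = int8 E).
# A raw right shift cannot represent E < 0 (1 >> k == 0 for k >= 1), so values are
# kept in fixed point with FRAC_BITS fractional bits (FRAC_BITS is an assumed constant).
shift = exp_val + FRAC_BITS
magnitude = (1 << shift) if shift >= 0 else 0  # exponents below -FRAC_BITS underflow to 0
dequant_int = sign_val * magnitude             # real value = dequant_int * 2**-FRAC_BITS
```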
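The log-space identities in the table can be checked with plain Python integers. This is an illustrative sketch, not code from the repository: `scale`, `dequant_fixed`, and `FRAC_BITS` are hypothetical names, and the fixed-point offset is an assumed way to keep negative exponents representable with shifts alone.

```python
FRAC_BITS = 16  # hypothetical number of fixed-point fraction bits


def scale(e: int) -> float:
    """Ephemeral float scale S = 2^E (used here only to check the identities)."""
    return 2.0 ** e


def dequant_fixed(t: int, e: int) -> int:
    """Integer-only T * 2^E in fixed point: result / 2**FRAC_BITS == t * 2**e."""
    shift = e + FRAC_BITS
    if shift < 0:
        return 0  # below the fixed-point resolution: underflows to zero
    return t * (1 << shift)


# Scale x scale: multiplying scales equals adding exponents.
e1, e2 = 3, -5
assert scale(e1) * scale(e2) == scale(e1 + e2)

# Scale x ternary via shifts, including negative exponents.
for t in (-1, 0, 1):
    for e in (-4, 0, 6):
        assert dequant_fixed(t, e) / (1 << FRAC_BITS) == t * scale(e)

# Without the offset, a raw right shift drops negative exponents entirely.
assert (1 >> 3) == 0  # not 0.125
```

The cost of the offset is range: with int32 accumulators, exponents must stay within roughly `[-FRAC_BITS, 31 - FRAC_BITS]`, so the full int8 range of E cannot be used as-is.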

### T scale type: block sizing

`TILE_SIZE = 384`. `TScaleType` determines the group size and thus the granularity of E:

| Type | Group size | E entries per 384-dim row |
|------|-----------|--------------------------|
| T64 | 6 | 64 |
| T32 | 12 | 32 |
| T16 | 24 | 16 |
| T8 | 48 | 8 |
| T6 | 64 | 6 |
| T4 | 96 | 4 |

At model scale = 10, the block is 3840 (10 × 384). With the group size unchanged, the number of E entries per row grows proportionally (e.g. T32 gives 3840 / 12 = 320).

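A small helper makes the table's arithmetic explicit. It is a hypothetical sketch: `GROUP_SIZE` and `e_entries_per_row` are illustrative names, and it assumes, as the text implies, that the group size per type stays fixed when the model scale grows.

```python
TILE_SIZE = 384

# Group size per TScaleType, copied from the block-sizing table.
GROUP_SIZE = {"T64": 6, "T32": 12, "T16": 24, "T8": 48, "T6": 64, "T4": 96}


def e_entries_per_row(scale_type: str, model_scale: int = 1) -> int:
    """Number of int8 exponents E per row of width TILE_SIZE * model_scale."""
    row = TILE_SIZE * model_scale
    group = GROUP_SIZE[scale_type]
    if row % group:
        raise ValueError(f"row width {row} is not divisible by group size {group}")
    return row // group


for name in GROUP_SIZE:
    print(name, e_entries_per_row(name), e_entries_per_row(name, model_scale=10))
```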
## Training State: what lives where

During training, each `TernaryScaleTensor` stores:

| Buffer | Shape | dtype | Bits/weight |
|--------|-------|-------|-------------|
| `T_packed` | flat (5 trits/byte) | uint8 | 1.6 |
| `E` | (n_groups,) | int8 | 8 / group_size |
| `T_accum` | (out, in) | int8 | 8 |
| `bias` | (out,) | int32 | 32 / in_dim (negligible) |

Total stored = `1.6 + 8 + 8/group_size` bits/weight (bias excluded).

At group_size=12: **1.6 + 8 + 0.67 = 10.27 bits/weight** (~1.28 bytes/weight) during training.

For inference/storage (no `T_accum` needed): **1.6 + 0.67 = 2.27 bits/weight** (~0.28 bytes/weight).

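The per-weight totals follow from three terms. This sketch uses a hypothetical `bits_per_weight` helper and ignores bias, as the totals above do.

```python
def bits_per_weight(group_size: int, training: bool = True) -> float:
    """Stored bits per weight: packed T, optional T_accum, one int8 E per group.

    Bias (int32 per output row) adds only 32 / in_dim bits per weight and is ignored.
    """
    t_bits = 8 / 5                       # 5 trits per byte -> 1.6 bits
    accum_bits = 8 if training else 0    # int8 T_accum, one per weight
    e_bits = 8 / group_size              # one int8 exponent per group
    return t_bits + accum_bits + e_bits


train = bits_per_weight(12)
infer = bits_per_weight(12, training=False)
print(f"training:  {train:.2f} bits/weight ({train / 8:.2f} bytes)")  # 10.27, 1.28
print(f"inference: {infer:.2f} bits/weight ({infer / 8:.2f} bytes)")  # 2.27, 0.28
```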
## Pipeline (forward)

```
x (ephemeral float from prev layer)
  → unpack T_packed → T ∈ {-1, 0, +1}
  → expand E → 2^E as ephemeral float
  → w_eff = (2^E) * T (ephemeral float)
  → y = F.linear(x, w_eff) → ternarize activation → next layer
  → register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
```
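The forward steps can be traced with plain Python lists, without PyTorch. This is a sketch under stated assumptions: T is already unpacked into a row-major (out, in) matrix, E is laid out with groups running along each row (`group_size` divides `in_dim`), and no bias is applied. The helper names are hypothetical. The real layer uses `F.linear` and a backward hook, and `weight_grad` below computes by hand what that hook captures.

```python
from typing import List

Matrix = List[List[float]]


def effective_weight(T: List[List[int]], E: List[int], group_size: int) -> Matrix:
    """w_eff[o][i] = 2^E[g] * T[o][i], groups taken along the row-major layout."""
    groups_per_row = len(T[0]) // group_size
    return [
        [(2.0 ** E[o * groups_per_row + i // group_size]) * t for i, t in enumerate(row)]
        for o, row in enumerate(T)
    ]


def linear(x: Matrix, w: Matrix) -> Matrix:
    """y = x @ w^T, the same contract as F.linear without bias."""
    return [[sum(xi * wi for xi, wi in zip(xrow, wrow)) for wrow in w] for xrow in x]


def weight_grad(grad_y: Matrix, x: Matrix) -> Matrix:
    """grad_w = grad_y^T @ x, shape (out, in), summed over the batch."""
    out_dim, in_dim = len(grad_y[0]), len(x[0])
    return [
        [sum(gy[o] * xr[i] for gy, xr in zip(grad_y, x)) for i in range(in_dim)]
        for o in range(out_dim)
    ]


# Tiny example: out=2, in=4, group_size=2 -> 2 groups per row, 4 exponents total.
T = [[1, 0, -1, 1], [0, -1, 1, 0]]
E = [0, -1, 1, 0]
x = [[1.0, 2.0, 3.0, 4.0]]
w_eff = effective_weight(T, E, group_size=2)  # [[1, 0, -0.5, 0.5], [0, -2, 1, 0]]
y = linear(x, w_eff)                          # [[1.5, -1.0]]
grad_w = weight_grad([[1.0, -1.0]], x)        # what the hook would hand to the update
```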

## Update Rule

```
After backward:

1. S (E) update (called by model._ternary_update_memory):
   grad_E = -sign(mean over group of grad_w * T)    # already the descent step
   E = clamp(E + grad_E, -128, 127)
   # Since W = 2^E * T, dL/dE = ln(2) * 2^E * sum(grad_w * T), which has the
   # same sign as the group mean, so E moves against that sign.

2. T flip (called by model._ternary_update_memory):
   T_accum = clamp(T_accum + sign(grad_w), -128, 127)
   if |T_accum| > threshold (default 3):
       flip T at that position
       reset T_accum at that position to 0
```
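Both update steps can be sketched on flattened (row-major) lists. The functions reuse the names `update_E` and `ternary_step` for readability, but they are hypothetical plain-Python stand-ins, not the real methods or their signatures. One assumption goes beyond the rule above: a "flip" moves T one ternary step against the accumulated gradient sign, clamped to {-1, 0, +1}.

```python
from typing import List


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def clamp(v: int, lo: int, hi: int) -> int:
    return max(lo, min(hi, v))


def update_E(E: List[int], T: List[int], grad_w: List[float], group_size: int) -> None:
    """Step 1: each E moves one unit against sign(mean over group of grad_w * T)."""
    for g in range(len(E)):
        lo, hi = g * group_size, (g + 1) * group_size
        mean = sum(gw * t for gw, t in zip(grad_w[lo:hi], T[lo:hi])) / group_size
        E[g] = clamp(E[g] - sign(mean), -128, 127)


def ternary_step(T: List[int], T_accum: List[int], grad_w: List[float],
                 threshold: int = 3) -> None:
    """Step 2: accumulate gradient signs; flip T once |T_accum| exceeds threshold.

    ASSUMPTION: a flip moves T one step against the accumulated sign, clamped.
    """
    for k, gw in enumerate(grad_w):
        T_accum[k] = clamp(T_accum[k] + sign(gw), -128, 127)
        if abs(T_accum[k]) > threshold:
            T[k] = clamp(T[k] - sign(T_accum[k]), -1, 1)
            T_accum[k] = 0


T, E, T_accum = [1, 0, -1, 1], [0, 0], [0, 0, 0, 0]
grad_w = [0.5, -0.2, 0.1, 0.0]
update_E(E, T, grad_w, group_size=2)  # E becomes [-1, 1]
for _ in range(4):                    # the 4th step pushes |T_accum| past 3
    ternary_step(T, T_accum, grad_w, threshold=3)
# T is now [0, 1, -1, 1]; entry 2 was already -1, so the clamp keeps it there
```

Because the threshold test is strict, a constant gradient sign needs `threshold + 1` steps before a flip. The accumulator reset means a saturated entry (as in position 2) also needs that many steps before it can be pushed again.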

## Files Changed

### Core layers: `tscale.py`
- `TernaryScaleTensor`: complete rewrite
  - Removed: `self.weight` (FP32 master weight), `_compute_T`, `_compute_S`
  - Added: `T_packed`, `E` (int8 log-exponent), `T_accum` (int8 gradient counter)
  - Added: `ternary_step()` and `update_E()` for per-step updates
  - `forward`: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
  - `_hook_T` and `_hook_x` are captured per forward call via a closure
- `TernaryRMSNorm`: same T + E + accum scheme

### Model architecture: `trigram.py`
- `ByteEmbedding`: same T + E + accum scheme; `self.weight` removed
- `StickyZoneSTE`: kept (used by `TernaryGNNLayer` for edge_attr ternarization)
- `ScaledTernaryLinear`: removed (thin wrapper, no longer meaningful)
- `TernaryFFN`: removed (dead code, never used by the model)
- `T_GRAPH_N_LAYERS`: removed (unused parameter)
- `MORPHTernaryModel._ternary_update_memory()`: new method; iterates over all modules, calling `ternary_step()` and `update_E()`
- Remaining float `nn.Parameter` instances (frozen ViT, MemGram embeddings, LSTM cell, edge_attr, router) are kept as-is: they are small and either frozen or not weight matrices

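The iteration in `_ternary_update_memory()` can be sketched with duck typing. This is a hypothetical stand-in: the function name, the `hasattr` test, and the assumption that `ternary_step` takes the threshold are illustrative, and the E-then-T order follows the Update Rule. In PyTorch the iterable would typically be `model.modules()`.

```python
from typing import Iterable


def ternary_update_memory(modules: Iterable[object], accum_threshold: int = 3) -> int:
    """Apply the E update, then the T flip, to every module that supports both.

    Returns the number of modules updated. The hasattr check is an assumed
    stand-in for however the real method recognises ternary modules.
    """
    updated = 0
    for m in modules:
        if hasattr(m, "update_E") and hasattr(m, "ternary_step"):
            m.update_E()
            m.ternary_step(accum_threshold)
            updated += 1
    return updated


class FakeTernary:
    """Minimal fake module that records the calls it receives."""

    def __init__(self) -> None:
        self.calls = []

    def update_E(self) -> None:
        self.calls.append("update_E")

    def ternary_step(self, accum_threshold: int) -> None:
        self.calls.append(("ternary_step", accum_threshold))


mods = [FakeTernary(), object(), FakeTernary()]
n = ternary_update_memory(mods, accum_threshold=3)  # the plain object() is skipped
```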
### Serialization: `convert_to_ternary.py`
- New file with `pack_ternary()` and `unpack_ternary()`: 5-trit-per-byte base-3 encoding
- Independent of `convert_to_ternary8.py` (no circular import)

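A base-3 packing with 5 trits per byte fits because 3^5 = 243 ≤ 256. The sketch below borrows the names `pack_ternary`/`unpack_ternary`, but it is an illustration only: the digit order (first trit least significant) and zero padding are assumptions, and the actual `convert_to_ternary.py` may lay out bits differently.

```python
from typing import List

TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, so 8 / 5 = 1.6 bits per trit


def pack_ternary(trits: List[int]) -> bytes:
    """Pack values in {-1, 0, +1} into bytes, 5 base-3 digits per byte.

    Each trit t is stored as digit t + 1 in {0, 1, 2}; the first trit of each
    group is the least significant digit. The tail is padded with 0 trits.
    """
    out = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        chunk = trits[start:start + TRITS_PER_BYTE]
        chunk = chunk + [0] * (TRITS_PER_BYTE - len(chunk))
        value = 0
        for t in reversed(chunk):  # Horner's rule, most significant digit first
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            value = value * 3 + (t + 1)
        out.append(value)  # at most 242
    return bytes(out)


def unpack_ternary(data: bytes, n: int) -> List[int]:
    """Inverse of pack_ternary; n is the original number of trits."""
    trits: List[int] = []
    for value in data:
        for _ in range(TRITS_PER_BYTE):
            value, digit = divmod(value, 3)
            trits.append(digit - 1)
    return trits[:n]


weights = [1, -1, 0, 0, 1, -1]
packed = pack_ternary(weights)  # 2 bytes for 6 trits
assert unpack_ternary(packed, len(weights)) == weights
```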
### Training loop: `train.py`
- Removed `from convert_to_ternary import save_model` (no longer needed)
- Checkpoint save uses `model.state_dict()` directly (all buffers + float params)
- After `optimizer.step()`, call `model._ternary_update_memory(accum_threshold=3)`

### Deleted: `ht.py`
- Commented-out planning notes, superseded by the Phase 7 implementation

## Memory Projection for 3B Parameter Model

| Component | Size |
|-----------|------|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral, sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| **Total training** | **~7.1–7.4 GB** |

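The projection can be reproduced from per-weight costs. In this sketch (the `training_memory_gb` helper is hypothetical), sizes are decimal GB, gradients are counted at 1 byte per weight as the table implies, and the activation term is an input. The components sum to roughly 7.05 to 7.35 GB across the 0.2 to 0.5 GB activation range.

```python
GB = 1e9  # decimal gigabytes


def training_memory_gb(n_params: float, group_size: int = 12,
                       activations_gb: float = 0.2) -> dict:
    """Per-component training memory in GB for a ternary model."""
    sizes = {
        "T (packed 5-trit/byte)": n_params * (8 / 5) / 8 / GB,  # 1.6 bits per weight
        "E (int8)": n_params / group_size / GB,                 # 1 byte per group
        "T_accum (int8)": n_params / GB,                        # 1 byte per weight
        "Gradients (sign-only)": n_params / GB,                 # 1 byte per weight
        "Activations": activations_gb,
    }
    sizes["Total"] = sum(sizes.values())
    return sizes


for act in (0.2, 0.5):
    print(act, round(training_memory_gb(3e9, activations_gb=act)["Total"], 2))
```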
## Test Results (140/140 passing)

| Suite | Pass | Fail | Notes |
|-------|------|------|-------|
| `test_morph.py` | 119 | 0 | All Phase 1-7 tests |
| `test_tscale.py` | 21 | 0 | Core ternary + SignSGD tests |

## Pending Work

1. **Remaining float params** (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
2. **TileLang GEMM kernel**: rewrite `dequant_gemm.py` to use int8 E (log-space) instead of float16 scales
3. **Activation ternarization**: optional optimization; quantize inter-layer activations to {-1, 0, +1}
4. **Generate with memory**: `generate()` currently passes `memory_state` but uses only basic sampling