# True Ternary Refactoring
## Architecture Contract
| Component | Type | Storage | Role |
|-----------|------|---------|------|
| **T** | {-1, 0, +1} | 5-trit/byte packed (1.6 BPW) | Weight values |
| **E** | int8 | 1 per group (log-space exponent) | Scale memory: `S = 2^E` |
| **T_accum** | int8 | 1 per weight | Gradient sign accumulator for T flips |
| **S** | ephemeral | derived from E in forward | `S = 2^E`, log-space block scale |
### Key design decisions
- **No IEEE float anywhere in weight state.**
- S is NOT stored; only E (int8) is stored. `S = 2^E` is ephemeral in the computation graph.
- Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
- Bias is stored as an int32 buffer and cast to float ephemerally during forward.
### Log-Space Representation (Option B)
Scales use **log-space storage** as recommended by agents:
```
S = 2^E where E = int8 (log₂ of scale factor)
W_eff = T * 2^E
```
Log-space replaces float multiply with integer shift:
| Operation | Float version | Log-space version |
|-----------|---------------|-------------------|
| Scale × scale | `S1 * S2` (float mul) | `E1 + E2` (int add) |
| Scale × ternary | `S * T` (float mul) | `T << E` if `E >= 0`, else `T >> (-E)` (int shift) |
| Dequant in kernel | `sign * scale` (fp16 mul) | `sign << exp` (int shift) |
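The table rows can be checked with a few lines of plain Python. This is a minimal sketch, not the project's kernel code: `shift_scale` and `compose_exponents` are hypothetical helper names, and the right shift is applied to the magnitude so that the result matches the `-(1 >> n)` form rather than Python's floor-style shift on negative integers.

```python
def shift_scale(t: int, e: int) -> int:
    """Apply the power-of-two scale 2**e to an integer t using shifts only."""
    if e >= 0:
        return t << e
    # Shift the magnitude, then restore the sign. Python's >> on a negative
    # int floors toward -infinity, so (-1 >> n) would be -1, not 0.
    magnitude = abs(t) >> (-e)
    return magnitude if t >= 0 else -magnitude


def compose_exponents(e1: int, e2: int) -> int:
    """Scale x scale in log space: 2**e1 * 2**e2 == 2**(e1 + e2)."""
    return e1 + e2


# For non-negative exponents the shift agrees with the multiply.
for t in (-1, 0, 1):
    for e in (0, 3, 7):
        assert shift_scale(t, e) == t * 2 ** e

# Composing two scales is a single integer add.
assert 2 ** compose_exponents(3, 4) == 2 ** 3 * 2 ** 4

# With a ternary input, any negative exponent truncates to 0 in integer math.
assert shift_scale(1, -2) == 0
```

One consequence worth keeping in mind: because T is in {-1, 0, +1}, a pure integer shift with a negative exponent always yields 0, so fractional scales only survive where the scale is applied in floating point or with a fixed-point offset.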
The **TileLang kernel** (`tilelang/kernels/dequant_gemm.py`) uses integer shift directly:
```python
# Per-element dequant in TileLang:
if sign_val == 0: dequant_int = 0
elif sign_val > 0: dequant_int = 1 << exp_val if exp_val >= 0 else 1 >> (-exp_val)
else: dequant_int = -(1 << exp_val) if exp_val >= 0 else -(1 >> (-exp_val))
```
### T scale type: block sizing
`TILE_SIZE = 384`. TScaleType determines group size and thus E's granularity:
| Type | Group size | E entries per 384-dim row |
|------|-----------|--------------------------|
| T64 | 6 | 64 |
| T32 | 12 | 32 |
| T16 | 24 | 16 |
| T8 | 48 | 8 |
| T6 | 64 | 6 |
| T4 | 96 | 4 |
At model scale=10: block = 3840, group_count scales proportionally.
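The relationship between scale type, group size and group count can be sketched as follows. The dictionary mirrors the table above; the function names (`group_size`, `group_count`, `expand_exponents`) are illustrative and are not taken from `tscale.py`. The sketch assumes the group size stays fixed as the row dimension grows, which is how "group_count scales proportionally" is read here.

```python
TILE_SIZE = 384  # from the text

# E entries per 384-dim row for each scale type, as listed in the table.
E_ENTRIES_PER_ROW = {"T64": 64, "T32": 32, "T16": 16, "T8": 8, "T6": 6, "T4": 4}


def group_size(scale_type: str) -> int:
    """Number of weights that share one exponent E."""
    return TILE_SIZE // E_ENTRIES_PER_ROW[scale_type]


def group_count(scale_type: str, row_dim: int) -> int:
    """Number of E entries for a row of row_dim weights (assumed divisible)."""
    size = group_size(scale_type)
    if row_dim % size:
        raise ValueError(f"row_dim={row_dim} is not a multiple of group size {size}")
    return row_dim // size


def expand_exponents(e_row: list[int], size: int) -> list[int]:
    """Repeat each group exponent `size` times: one exponent per weight."""
    return [e for e in e_row for _ in range(size)]


print(group_size("T32"))          # 12
print(group_count("T32", 384))    # 32
print(group_count("T32", 3840))   # 320 at model scale 10
print(expand_exponents([1, -2], 3))
```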
## Training State: what lives where
During training, each TernaryScaleTensor stores:
| Buffer | Shape | dtype | Bits/weight |
|--------|-------|-------|-------------|
| `T_packed` | flat (5-trit/byte) | uint8 | 1.6 |
| `E` | (n_groups,) | int8 | 8 / group_size |
| `T_accum` | (out, in) | int8 | 8 |
| `bias` | (out,) | int32 | 32 / out_dim (negligible) |
Total stored = `1.6 + 8 + 8/group_size` bits/weight.
At group_size=12: **1.6 + 8 + 0.67 = 10.27 bits/weight** (~1.28 bytes/weight) during training.
For inference/storage (no T_accum needed): **1.6 + 0.67 = 2.27 bits/weight** (~0.28 bytes/weight).
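The per-weight figures follow directly from the buffer table. A small sketch of the arithmetic, with `bits_per_weight` as a hypothetical helper and the bias term left out because it is negligible:

```python
T_BITS = 8 / 5      # five trits per byte -> 1.6 bits per weight
E_BITS = 8          # one int8 exponent per group
T_ACCUM_BITS = 8    # one int8 accumulator per weight (training only)


def bits_per_weight(group_size: int, training: bool) -> float:
    """Stored bits per weight, ignoring the int32 bias."""
    bits = T_BITS + E_BITS / group_size
    if training:
        bits += T_ACCUM_BITS
    return bits


for size in (6, 12, 24, 48, 64, 96):
    train = bits_per_weight(size, training=True)
    infer = bits_per_weight(size, training=False)
    print(f"group_size={size:>2}: train={train:.2f} b/w, inference={infer:.2f} b/w")
```

Larger groups reduce the exponent overhead, but during training the int8 accumulator dominates regardless of group size.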
## Pipeline (forward)
```
x (ephemeral float from prev layer)
→ unpack T_packed → T ∈ {-1,0,+1}
→ expand E → 2^E as ephemeral float
→ w_eff = (2^E) * T (ephemeral float)
→ y = F.linear(x, w_eff) → ternarize activation → next layer
→ register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
```
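The core of the pipeline (expand E, build `w_eff`, apply the linear map) can be illustrated without PyTorch. The sketch below is a pure-Python stand-in for `F.linear` on a single input vector. It assumes T is already unpacked and that E is laid out as one list of group exponents per output row; the real buffer is described above as flat `(n_groups,)`, so that layout is an assumption. Activation ternarization and the gradient hook are omitted.

```python
def forward_sketch(x, T, E, group_size):
    """Compute y = w_eff @ x with w_eff = 2**E * T built on the fly.

    x: list of floats, length in_dim
    T: list of rows, each a list of in_dim values in {-1, 0, +1}
    E: list of rows, each a list of in_dim // group_size int exponents
       (per-row layout is an assumption made for this sketch)
    """
    y = []
    for t_row, e_row in zip(T, E):
        acc = 0.0
        for j, (t, xj) in enumerate(zip(t_row, x)):
            scale = 2.0 ** e_row[j // group_size]  # ephemeral S = 2^E
            acc += scale * t * xj                  # ephemeral w_eff = S * T
        y.append(acc)
    return y


x = [1.0, 2.0, 3.0, 4.0]
T = [[1, 0, -1, 1],
     [-1, -1, 0, 0]]
E = [[1, -1],
     [0, 2]]
print(forward_sketch(x, T, E, group_size=2))  # [2.5, -3.0]
```

Nothing float-valued is kept after the call returns, which is the property the pipeline relies on: only `T` and `E` persist.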
## Update Rule
```
After backward:
1. S (E) update (called by model._ternary_update_memory):
grad_E = -sign(mean over group of grad_w * T)
E = clamp(E + grad_E, -128, 127)
2. T flip (called by model._ternary_update_memory):
T_accum = clamp(T_accum + sign(grad_w), -128, 127)
if |T_accum| > threshold (default 3):
flip T at that position
reset T_accum at that position to 0
```
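The two update steps can be sketched on flat Python lists. The function names carry a `_sketch` suffix to make clear they are not the methods in `tscale.py`. One detail is an assumption: the rule above says "flip T" without defining the flip, so this sketch moves T one step against the sign of the accumulated gradient and clamps to {-1, 0, +1}.

```python
def sign(v):
    """Return -1, 0 or +1."""
    return (v > 0) - (v < 0)


def clamp(v, lo, hi):
    return max(lo, min(hi, v))


def update_E_sketch(E, T, grad_w, group_size):
    """E: flat list of group exponents; T and grad_w: flat per-weight lists."""
    new_E = []
    for g, e in enumerate(E):
        lo, hi = g * group_size, (g + 1) * group_size
        products = [gw * t for gw, t in zip(grad_w[lo:hi], T[lo:hi])]
        grad_E = -sign(sum(products) / len(products))  # -sign(group mean)
        new_E.append(clamp(e + grad_E, -128, 127))
    return new_E


def ternary_step_sketch(T, T_accum, grad_w, threshold=3):
    """Accumulate gradient signs and flip T where the count exceeds threshold."""
    new_T, new_accum = [], []
    for t, a, gw in zip(T, T_accum, grad_w):
        a = clamp(a + sign(gw), -128, 127)
        if abs(a) > threshold:
            # ASSUMPTION: "flip" means one step against the accumulated sign.
            t = clamp(t - sign(a), -1, 1)
            a = 0
        new_T.append(t)
        new_accum.append(a)
    return new_T, new_accum


print(update_E_sketch([0, 5], [1, 1, -1, 0], [0.5, 0.2, 0.3, 0.0], group_size=2))
print(ternary_step_sketch([1, 0], [3, -3], [0.1, -0.2]))
```

With the default threshold of 3, a weight needs four consecutive same-sign gradient steps from a zero accumulator before it changes, which acts as a simple noise filter.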
## Files Changed
### Core layers: `tscale.py`
- `TernaryScaleTensor`: complete rewrite
  - Removed: `self.weight` (FP32 master weight), `_compute_T`, `_compute_S`
  - Added: `T_packed`, `E` (int8 log-exponent), `T_accum` (int8 gradient counter)
  - Added: `ternary_step()`, `update_E()` for per-step updates
  - `forward`: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
  - `_hook_T`, `_hook_x` captured per-forward-call via closure
- `TernaryRMSNorm`: same T+E+accum scheme
### Model architecture: `trigram.py`
- `ByteEmbedding`: same T+E+accum scheme, removed `self.weight`
- `StickyZoneSTE`: kept (used by `TernaryGNNLayer` for edge_attr ternarization)
- `ScaledTernaryLinear`: removed (thin wrapper, no longer meaningful)
- `TernaryFFN`: removed (dead code, never used by model)
- `T_GRAPH_N_LAYERS`: removed (unused parameter)
- `MORPHTernaryModel._ternary_update_memory()`: new method that iterates over all modules, calling `ternary_step()` and `update_E()`
- Remaining float `nn.Parameter` instances (ViT frozen, MemGram embeddings, LSTM cell, edge_attr, router) kept as-is: they are small and either frozen or not weight matrices
### Serialization: `convert_to_ternary.py`
- New file with `pack_ternary()` and `unpack_ternary()`: 5-trit-per-byte base-3 encoding
- Independent of `convert_to_ternary8.py` (no circular import)
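Five base-3 digits fit in one byte because 3^5 = 243 ≤ 256. The sketch below shows one way such an encoding can work. It is not the code in `convert_to_ternary.py`: the function names, the digit order (first trit least significant) and the padding value are all assumptions made for illustration.

```python
TRITS_PER_BYTE = 5  # 3**5 == 243 fits in one uint8


def pack_trits_sketch(trits):
    """Pack values in {-1, 0, +1} into bytes, five base-3 digits per byte."""
    if any(t not in (-1, 0, 1) for t in trits):
        raise ValueError("trits must be -1, 0 or +1")
    digits = [t + 1 for t in trits]                   # {-1,0,+1} -> {0,1,2}
    digits += [1] * (-len(digits) % TRITS_PER_BYTE)   # pad with the digit for 0
    out = bytearray()
    for i in range(0, len(digits), TRITS_PER_BYTE):
        byte = 0
        for d in reversed(digits[i:i + TRITS_PER_BYTE]):
            byte = byte * 3 + d                       # first trit is least significant
        out.append(byte)
    return bytes(out)


def unpack_trits_sketch(packed, n):
    """Inverse of pack_trits_sketch; n is the original number of trits."""
    trits = []
    for byte in packed:
        for _ in range(TRITS_PER_BYTE):
            byte, d = divmod(byte, 3)
            trits.append(d - 1)
    return trits[:n]


weights = [-1, 0, 1, 1, -1, 0, 1]
packed = pack_trits_sketch(weights)
assert len(packed) == 2                       # 7 trits -> 2 bytes
assert unpack_trits_sketch(packed, len(weights)) == weights
```

The caller has to remember the original length, since padding trits are indistinguishable from real zeros.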
### Training loop: `train.py`
- Removed `from convert_to_ternary import save_model`: no longer needed
- Checkpoint save uses `model.state_dict()` directly (all buffers + float params)
- After `optimizer.step()`, call `model._ternary_update_memory(accum_threshold=3)`
### Deleted: `ht.py`
- Commented-out planning notes, superseded by Phase 7 implementation
## Memory Projection for 3B Parameter Model
| Component | Size |
|-----------|------|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| **Total training** | **~7.1–7.4 GB** |
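The rows in this table can be reproduced from the per-weight storage costs. A rough sketch, assuming 1 GB = 10^9 bytes, one byte per weight for the ephemeral gradients (which is what the ~3 GB row implies), and the activation range taken as given; `training_memory_gb` is a hypothetical helper:

```python
def training_memory_gb(n_params, group_size=12, activations_gb=(0.2, 0.5)):
    """Estimate training memory in GB (10**9 bytes) per component."""
    gb = 1e9
    sizes = {
        "T_packed": n_params * (8 / 5) / 8 / gb,  # 1.6 bits per weight
        "E": n_params / group_size / gb,          # one int8 per group
        "T_accum": n_params / gb,                 # one int8 per weight
        "gradients": n_params / gb,               # assumed 1 byte per weight
    }
    fixed = sum(sizes.values())
    sizes["total_range"] = (fixed + activations_gb[0], fixed + activations_gb[1])
    return sizes


for name, value in training_memory_gb(3e9).items():
    print(name, value)
```

For 3B parameters the fixed components sum to about 6.85 GB, so the total lands between roughly 7.05 and 7.35 GB depending on activation memory.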
## Test Results (140/140 passing)
| Suite | Pass | Fail | Notes |
|-------|------|------|-------|
| `test_morph.py` | 119 | 0 | All Phase 1-7 tests |
| `test_tscale.py` | 21 | 0 | Core ternary + SignSGD tests |
## Pending Work
1. **Remaining float params** (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
2. **TileLang GEMM kernel**: rewrite `dequant_gemm.py` to use int8 E (log-space) instead of float16 scales
3. **Activation ternarization**: optional optimization, clamp inter-layer activations to {-1,0,+1}
4. **Generate with memory**: `generate()` currently passes `memory_state` but uses only basic sampling