# True Ternary Refactoring
## Architecture Contract
| Component | Type | Storage | Role |
|-----------|------|---------|------|
| **T** | {-1, 0, +1} | 5-trit/byte packed (1.6 BPW) | Weight values |
| **E** | int8 | 1 per group (log-space exponent) | Scale memory: `S = 2^E` |
| **T_accum** | int8 | 1 per weight | Gradient sign accumulator for T flips |
| **S** | ephemeral | derived from E in forward | `S = 2^E`, log-space block scale |
### Key design decisions
- **No IEEE float anywhere in weight state.**
- S is NOT stored; only E (int8) is stored. `S = 2^E` is ephemeral in the computation graph.
- Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
- Bias is stored as an int32 buffer and cast to float ephemerally during forward.
### Log-Space Representation (Option B)
Scales use **log-space storage** as recommended by agents:
```
S = 2^E where E = int8 (log₂ of scale factor)
W_eff = T * 2^E
```
Log-space turns float multiplies into integer adds and shifts:
| Operation | Float version | Log-space version |
|-----------|---------------|-------------------|
| Scale × scale | `S1 * S2` (float mul) | `E1 + E2` (int add) |
| Scale × ternary | `S * T` (float mul) | `T << E` for E ≥ 0, `T >> -E` for E < 0 (int shift; the right shift truncates, so it is exact only for E ≥ 0) |
| Dequant in kernel | `sign * scale` (fp16 mul) | `sign << exp` (int shift) |
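A minimal plain-Python check of the identities in the table. `scale_from_exp`, `mul_scales` and `ternary_times_scale` are hypothetical helper names for illustration, not functions from the repo; `math.ldexp(x, e)` computes `x * 2**e` exactly for these small values.

```python
import math


def scale_from_exp(e: int) -> float:
    """Ephemeral scale S = 2**E; math.ldexp(1.0, e) is exactly 1.0 * 2**e."""
    return math.ldexp(1.0, e)


def mul_scales(e1: int, e2: int) -> int:
    """Scale x scale in log space: 2**e1 * 2**e2 == 2**(e1 + e2), an int add."""
    return e1 + e2


def ternary_times_scale(t: int, e: int) -> int:
    """Scale x ternary as an int shift. Exact only for e >= 0: for e < 0 the
    right shift truncates (1 >> 1 == 0), so 2**e is not representable."""
    return t << e if e >= 0 else t >> (-e)


# Log-space add agrees with the float product of the two scales.
assert scale_from_exp(mul_scales(3, -1)) == scale_from_exp(3) * scale_from_exp(-1)
# For non-negative exponents the shift matches T * 2**E exactly.
for t in (-1, 0, 1):
    assert ternary_times_scale(t, 4) == t * 16
```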
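The **TileLang kernel** (`tilelang/kernels/dequant_gemm.py`) is intended to dequantize with an integer shift; switching it from float16 scales to int8 E is still listed under Pending Work. The per-element logic, rendered as plain Python:
```python
# Per-element dequant (plain-Python rendering of the TileLang logic).
# Caveat: for exp_val < 0, 1 >> (-exp_val) is 0, so every negative exponent
# dequantizes to 0 unless the result is interpreted as fixed point.
def dequant_int(sign_val: int, exp_val: int) -> int:
    if sign_val == 0:
        return 0
    mag = 1 << exp_val if exp_val >= 0 else 1 >> (-exp_val)
    return mag if sign_val > 0 else -mag
```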
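In the kernel logic above, `1 >> (-exp_val)` is 0 for every negative exponent. One way to keep negative exponents in pure integer arithmetic is to read the shifted value as fixed point with a fixed number of fractional bits. This is a sketch under that assumption only: `FRAC_BITS`, `dequant_fixed` and `to_float` are hypothetical, and nothing in this document says `dequant_gemm.py` uses this scheme.

```python
import math

FRAC_BITS = 8  # hypothetical: fractional bits in the fixed-point result


def dequant_fixed(sign_val: int, exp_val: int, frac_bits: int = FRAC_BITS) -> int:
    """Return sign * 2**exp_val as an integer scaled by 2**frac_bits.

    Exact while exp_val >= -frac_bits; smaller exponents flush to 0.
    """
    if sign_val == 0:
        return 0
    shift = exp_val + frac_bits
    mag = 1 << shift if shift >= 0 else 0
    return mag if sign_val > 0 else -mag


def to_float(q: int, frac_bits: int = FRAC_BITS) -> float:
    """Convert a fixed-point integer back to float (for checking only)."""
    return math.ldexp(q, -frac_bits)


assert to_float(dequant_fixed(1, -3)) == 0.125
assert to_float(dequant_fixed(-1, 2)) == -4.0
```

In a GEMM the integer accumulator would sum such products and shift the result right by `FRAC_BITS` once at the end. Because E's int8 range [-128, 127] is far wider than any practical fixed-width window, a real kernel would also need to bound the exponents it accepts.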
### T scale type → block sizing
`TILE_SIZE = 384`. TScaleType determines group size and thus E's granularity:
| Type | Group size | E entries per 384-dim row |
|------|-----------|--------------------------|
| T64 | 6 | 64 |
| T32 | 12 | 32 |
| T16 | 24 | 16 |
| T8 | 48 | 8 |
| T6 | 64 | 6 |
| T4 | 96 | 4 |
At model scale=10 the block is 384 × 10 = 3840, and group_count scales proportionally.
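The table and the scale=10 note reduce to simple division. A sketch, assuming that "group_count scales proportionally" means the group size stays fixed while the number of groups per row grows; the `E_ENTRIES` dict and both helpers are hypothetical stand-ins for the `TScaleType` machinery:

```python
TILE_SIZE = 384

# E entries per 384-wide tile for each TScaleType (values from the table above).
# This dict is a hypothetical stand-in for the TScaleType enum.
E_ENTRIES = {"T64": 64, "T32": 32, "T16": 16, "T8": 8, "T6": 6, "T4": 4}


def group_size(entries_per_tile: int, tile: int = TILE_SIZE) -> int:
    """Number of weights that share one E exponent."""
    if tile % entries_per_tile:
        raise ValueError("tile must split evenly into groups")
    return tile // entries_per_tile


def groups_per_row(row_dim: int, entries_per_tile: int) -> int:
    """E entries in a row of width row_dim, holding the group size fixed."""
    return row_dim // group_size(entries_per_tile)


sizes = {name: group_size(n) for name, n in E_ENTRIES.items()}
# {'T64': 6, 'T32': 12, 'T16': 24, 'T8': 48, 'T6': 64, 'T4': 96}

scaled_block = TILE_SIZE * 10                                # 3840 at model scale=10
t32_groups = groups_per_row(scaled_block, E_ENTRIES["T32"])  # 320 = 32 * 10
```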
## Training State: what lives where
During training, each TernaryScaleTensor stores:
| Buffer | Shape | dtype | Bits/weight |
|--------|-------|-------|-------------|
| `T_packed` | flat (5-trit/byte) | uint8 | 1.6 |
| `E` | (n_groups,) | int8 | 8 / group_size |
| `T_accum` | (out, in) | int8 | 8 |
| `bias` | (out,) | int32 | 32 / in_dim (negligible) |
Total stored = `1.6 + 8 + 8/group_size` bits/weight.
At group_size=12: **1.6 + 8 + 0.67 = 10.27 bits/weight** (~1.28 bytes/weight) during training.
For inference/storage (no T_accum needed): **1.6 + 0.67 = 2.27 bits/weight** (~0.28 bytes/weight).
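The per-weight figures follow from the buffer table: five trits per byte gives 8/5 = 1.6 bits, one int8 exponent is shared by `group_size` weights, and training adds one int8 `T_accum` per weight. A small sketch of that arithmetic; `bits_per_weight` is a hypothetical helper, and bias is ignored as negligible:

```python
def bits_per_weight(group_size: int, training: bool = True) -> float:
    """Stored bits per weight for one TernaryScaleTensor (bias ignored)."""
    bits = 8 / 5                 # T_packed: five trits per byte
    bits += 8 / group_size       # E: one int8 exponent per group
    if training:
        bits += 8                # T_accum: one int8 counter per weight
    return bits


train_bits = bits_per_weight(12)                   # ≈ 10.27 bits ≈ 1.28 bytes
infer_bits = bits_per_weight(12, training=False)   # ≈ 2.27 bits ≈ 0.28 bytes
```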
## Pipeline (forward)
```
x (ephemeral float from prev layer)
→ unpack T_packed → T ∈ {-1,0,+1}
→ expand E → 2^E as ephemeral float
→ w_eff = (2^E) * T (ephemeral float)
→ y = F.linear(x, w_eff) → ternarize activation → next layer
→ register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
```
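A plain-Python sketch of the arithmetic in the forward steps, for a single input vector. It assumes T is already unpacked and that each E entry covers a consecutive run of `group_size` weights in row-major order; the helper names are hypothetical. The layer itself uses `F.linear` on tensors, and the activation ternarization and hook registration are not shown.

```python
import math
from typing import List

Matrix = List[List[float]]


def effective_weight(T: List[List[int]], E: List[int], group_size: int) -> Matrix:
    """w_eff[o][i] = 2**E[g] * T[o][i], where g is the group of weight (o, i).

    Assumes groups are consecutive runs of group_size weights, row-major.
    """
    n_in = len(T[0])
    return [
        [math.ldexp(t, E[(o * n_in + i) // group_size]) for i, t in enumerate(row)]
        for o, row in enumerate(T)
    ]


def linear(x: List[float], w: Matrix, bias: List[int]) -> List[float]:
    """y = w @ x + bias; the int32 bias is cast to float only here."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + float(b) for row, b in zip(w, bias)]


def weight_grad(grad_y: List[float], x: List[float]) -> Matrix:
    """grad_w = grad_y^T @ x; for a single sample this is an outer product."""
    return [[gy * xi for xi in x] for gy in grad_y]


T = [[1, 0, -1], [-1, 1, 0]]
E = [1, -1]                                 # row 0 scaled by 2, row 1 by 0.5
w = effective_weight(T, E, group_size=3)    # [[2.0, 0.0, -2.0], [-0.5, 0.5, 0.0]]
y = linear([1.0, 2.0, 3.0], w, bias=[0, 1]) # [-4.0, 1.5]
```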
## Update Rule
```
After backward:
1. S (E) update (called by model._ternary_update_memory):
grad_E = -sign(mean over group of grad_w * T)
E = clamp(E + grad_E, -128, 127)
2. T flip (called by model._ternary_update_memory):
T_accum = clamp(T_accum + sign(grad_w), -128, 127)
if |T_accum| > threshold (default 3):
flip T at that position
reset T_accum at that position to 0
```
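The two update steps can be sketched in plain Python on flattened (row-major) lists, with group `g` covering indices `[g * group_size, (g + 1) * group_size)`. Note that `grad_E` in the rule above is the signed *step* (the negative of the gradient's sign). The rule does not spell out what "flip T" does, so this sketch assumes it moves T one level against the accumulated sign, clamped to {-1, 0, +1}. `update_exponents` and `accumulate_and_flip` are hypothetical names, not the repo's `update_E()` / `ternary_step()`.

```python
from typing import List


def sign(v: float) -> int:
    return (v > 0) - (v < 0)


def clamp(v: int, lo: int, hi: int) -> int:
    return max(lo, min(hi, v))


def update_exponents(E: List[int], T: List[int], grad_w: List[float], group_size: int) -> None:
    """E[g] += -sign(mean(grad_w * T) over group g), clamped to int8. In place."""
    for g in range(len(E)):
        lo, hi = g * group_size, (g + 1) * group_size
        mean = sum(grad_w[k] * T[k] for k in range(lo, hi)) / group_size
        E[g] = clamp(E[g] - sign(mean), -128, 127)


def accumulate_and_flip(T: List[int], T_accum: List[int], grad_w: List[float],
                        threshold: int = 3) -> None:
    """Accumulate gradient signs; flip T where |T_accum| > threshold. In place.

    Assumption: a flip steps T one level against the accumulated sign.
    """
    for k, g in enumerate(grad_w):
        T_accum[k] = clamp(T_accum[k] + sign(g), -128, 127)
        if abs(T_accum[k]) > threshold:
            T[k] = clamp(T[k] - sign(T_accum[k]), -1, 1)
            T_accum[k] = 0


E = [0]
update_exponents(E, [1, -1, 0, 1], [0.5, 0.5, 0.5, 0.5], group_size=4)  # E == [-1]
```

Positive `grad_w * T` means a larger scale would raise the loss, so E decreases; likewise a positive accumulated gradient pushes T toward smaller values.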
## Files Changed
### Core layers: `tscale.py`
- `TernaryScaleTensor`: complete rewrite
- Removed: `self.weight` (FP32 master weight), `_compute_T`, `_compute_S`
- Added: `T_packed`, `E` (int8 log-exponent), `T_accum` (int8 gradient counter)
- Added: `ternary_step()`, `update_E()` for per-step updates
- `forward`: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
- `_hook_T`, `_hook_x` captured per forward call via closure
- `TernaryRMSNorm`: same T+E+accum scheme
### Model architecture: `trigram.py`
- `ByteEmbedding`: same T+E+accum scheme, `self.weight` removed
- `StickyZoneSTE`: kept (used by `TernaryGNNLayer` for edge_attr ternarization)
- `ScaledTernaryLinear`: removed (thin wrapper, no longer meaningful)
- `TernaryFFN`: removed (dead code, never used by the model)
- `T_GRAPH_N_LAYERS`: removed (unused parameter)
- `MORPHTernaryModel._ternary_update_memory()`: new method that iterates over all modules, calling `ternary_step()` and `update_E()`
- Remaining float `nn.Parameter` instances (ViT frozen, MemGram embeddings, LSTM cell, edge_attr, router): kept as-is, since they are small and either frozen or not weight matrices
### Serialization: `convert_to_ternary.py`
- New file with `pack_ternary()` and `unpack_ternary()`: 5-trit-per-byte base-3 encoding
- Independent of `convert_to_ternary8.py` (no circular import)
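A minimal sketch of 5-trit-per-byte base-3 packing. The digit order (first trit least significant), the `t + 1` digit mapping and zero-padding of the tail are assumptions for illustration; the repo's `pack_ternary()` / `unpack_ternary()` may differ in any of these, and `pack_trits` / `unpack_trits` are hypothetical names.

```python
from typing import List

TRITS_PER_BYTE = 5  # 3**5 = 243 <= 256, so five trits fit in one byte


def pack_trits(trits: List[int]) -> bytes:
    """Pack values in {-1, 0, +1} five per byte as base-3 digits (digit = t + 1).

    The first trit of each chunk is the least-significant digit; a short
    final chunk is padded with 0 trits.
    """
    out = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        chunk = trits[start:start + TRITS_PER_BYTE]
        chunk = chunk + [0] * (TRITS_PER_BYTE - len(chunk))
        value = 0
        for t in reversed(chunk):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            value = value * 3 + (t + 1)
        out.append(value)  # at most 242
    return bytes(out)


def unpack_trits(packed: bytes, n: int) -> List[int]:
    """Inverse of pack_trits; n is the original number of trits."""
    trits: List[int] = []
    for byte in packed:
        for _ in range(TRITS_PER_BYTE):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:n]


data = [1, -1, 0, 0, 1, -1, 1]
packed = pack_trits(data)  # 2 bytes for 7 trits
assert unpack_trits(packed, len(data)) == data
```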
### Training loop: `train.py`
- Removed `from convert_to_ternary import save_model` (no longer needed)
- Checkpoint save uses `model.state_dict()` directly (all buffers + float params)
- After `optimizer.step()`, call `model._ternary_update_memory(accum_threshold=3)`
### Deleted: `ht.py`
- Commented-out planning notes, superseded by Phase 7 implementation
## Memory Projection for 3B Parameter Model
| Component | Size |
|-----------|------|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| **Total training** | **~7.1–7.4 GB** |
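The rows follow from per-weight costs. A sketch of the arithmetic, assuming N = 3 × 10^9 weights, decimal GB, and one byte per weight for the sign-only gradients (which is what the ~3 GB row implies); like the table, it leaves out bias and the remaining float parameters.

```python
N = 3e9     # weights
GB = 1e9    # decimal gigabytes

t_packed = N * 1.6 / 8 / GB   # ≈ 0.6 (1.6 bits per weight)
e_scales = N / 12 / GB        # ≈ 0.25 (one int8 per group of 12)
t_accum = N * 1 / GB          # ≈ 3.0 (one int8 per weight)
grads = N * 1 / GB            # ≈ 3.0 (assumed one int8 sign per weight)
activations = (0.2, 0.5)      # range taken from the table, not derived here

weights_and_grads = t_packed + e_scales + t_accum + grads   # ≈ 6.85
total_low = weights_and_grads + activations[0]              # ≈ 7.05
total_high = weights_and_grads + activations[1]             # ≈ 7.35
```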
## Test Results (140/140 passing)
| Suite | Pass | Fail | Notes |
|-------|------|------|-------|
| `test_morph.py` | 119 | 0 | All Phase 1-7 tests |
| `test_tscale.py` | 21 | 0 | Core ternary + SignSGD tests |
## Pending Work
1. **Remaining float params** (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
2. **TileLang GEMM kernel**: rewrite `dequant_gemm.py` to use int8 E (log-space) instead of float16 scales
3. **Activation ternarization**: optional optimization that clamps inter-layer activations to {-1,0,+1}
4. **Generate with memory**: `generate()` currently passes `memory_state` but uses only basic sampling