True Ternary Refactoring
Architecture Contract
| Component | Type | Storage | Role |
|---|---|---|---|
| T | {-1, 0, +1} | 5-trit/byte packed (1.6 BPW) | Weight values |
| E | int8 | 1 per group (log-space exponent) | Scale memory: S = 2^E |
| T_accum | int8 | 1 per weight | Gradient sign accumulator for T flips |
| S | ephemeral | derived from E in forward | S = 2^E, log-space block scale |
Key design decisions
- No IEEE float anywhere in weight state.
- S is NOT stored; only E (int8) is stored. S = 2^E is ephemeral in the computation graph.
- Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
- Bias is stored as an int32 buffer and cast to float ephemerally during the forward pass.
Log-Space Representation (Option B)
Scales use log-space storage as recommended by agents:
S = 2^E where E = int8 (log₂ of scale factor)
W_eff = T * 2^E
Log-space storage replaces float multiplies with integer adds and shifts:
| Operation | Float version | Log-space version |
|---|---|---|
| Scale × scale | S1 * S2 (float mul) | E1 + E2 (int add) |
| Scale × ternary | S * T (float mul) | T << E (int shift; exact only for E ≥ 0) |
| Dequant in kernel | sign * scale (fp16 mul) | sign << exp (int shift) |
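To make the table concrete, here is a minimal stdlib sketch of the three identities, assuming T is an int in {-1, 0, +1} and E an int in the int8 range. `math.ldexp(t, e)` returns t * 2**e exactly, so it serves as the float reference for W_eff. The helper names are hypothetical, not functions from the codebase.

```python
import math

def w_eff(t: int, e: int) -> float:
    """Float reference for W_eff = T * 2^E; ldexp scales by a power of two exactly."""
    return math.ldexp(t, e)

def scale_product_exponent(e1: int, e2: int) -> int:
    """Scale x scale in log space: 2^E1 * 2^E2 = 2^(E1 + E2), a plain int add."""
    return e1 + e2

def shift_dequant(t: int, e: int) -> int:
    """Scale x ternary as an integer shift; only exact for e >= 0."""
    if e < 0:
        # 1 >> k == 0 for k > 0, so a bare right shift would lose the weight.
        raise ValueError("negative exponent needs a fixed-point offset")
    return t << e  # for t == -1 this gives -(2**e)

# The log-space results match the float products they replace.
for e1, e2 in [(3, -5), (0, 7), (-2, -2)]:
    lhs = math.ldexp(1.0, e1) * math.ldexp(1.0, e2)
    assert lhs == math.ldexp(1.0, scale_product_exponent(e1, e2))
for t in (-1, 0, 1):
    for e in range(8):
        assert shift_dequant(t, e) == w_eff(t, e)
```

The shift path refuses negative exponents deliberately: with plain integers some fixed-point offset is needed before shifts can cover the full int8 range of E.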
The TileLang kernel (tilelang/kernels/dequant_gemm.py) uses integer shift directly:
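```python
# Per-element dequant in TileLang (Python-style). Since 1 >> k == 0 for k > 0, a plain
# right shift cannot represent negative exponents; this assumes a fixed-point result with
# FRAC_BITS fractional bits (assumed constant), i.e. real value = dequant_fx / 2**FRAC_BITS.
shift = exp_val + FRAC_BITS  # requires exp_val >= -FRAC_BITS
if sign_val == 0: dequant_fx = 0
elif sign_val > 0: dequant_fx = 1 << shift
else: dequant_fx = -(1 << shift)
```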
T scale type: block sizing
TILE_SIZE = 384. TScaleType determines group size and thus E's granularity:
| Type | Group size | E entries per 384-dim row |
|---|---|---|
| T64 | 6 | 64 |
| T32 | 12 | 32 |
| T16 | 24 | 16 |
| T8 | 48 | 8 |
| T6 | 64 | 6 |
| T4 | 96 | 4 |
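At model scale = 10 the block is 3840 wide (10 × 384). The group size stays fixed by the TScaleType, so the number of groups (E entries per row) grows proportionally, e.g. 320 for T32.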
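The sizing rule can be checked with a small helper, assuming (as stated above) that the group size is fixed by the TScaleType and only the number of groups grows with model scale. `E_PER_TILE` and `group_layout` are hypothetical names, not the project's TScaleType API.

```python
TILE_SIZE = 384

# E entries per 384-wide row for each TScaleType (values from the table).
E_PER_TILE = {"T64": 64, "T32": 32, "T16": 16, "T8": 8, "T6": 6, "T4": 4}

def group_layout(scale_type: str, model_scale: int = 1) -> tuple[int, int, int]:
    """Return (block width, group size, E entries per row) for a TScaleType."""
    group_size = TILE_SIZE // E_PER_TILE[scale_type]  # fixed by the type
    block = TILE_SIZE * model_scale                   # 3840 at model scale 10
    return block, group_size, block // group_size     # group count scales with block

print(group_layout("T32", model_scale=10))  # (3840, 12, 320)
```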
Training State: what lives where
During training, each TernaryScaleTensor stores:
| Buffer | Shape | dtype | Bits/weight |
|---|---|---|---|
| T_packed | flat (5 trits/byte) | uint8 | 1.6 |
| E | (n_groups,) | int8 | 8 / group_size |
| T_accum | (out, in) | int8 | 8 |
| bias | (out,) | int32 | 32 / in_dim (negligible) |
Total stored = 1.6 + 8 + 8/group_size bits/weight.
At group_size=12: 1.6 + 8 + 0.67 = 10.27 bits/weight (~1.28 bytes/weight) during training.
For inference/storage (no T_accum needed): 1.6 + 0.67 = 2.27 bits/weight (~0.28 bytes/weight).
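The same bits-per-weight arithmetic as a function; like the total formula above, this sketch leaves the int32 bias out as negligible. The function name is hypothetical.

```python
def bits_per_weight(group_size: int, training: bool = True) -> float:
    """Stored bits per weight for one TernaryScaleTensor (int32 bias ignored)."""
    t_packed = 8 / 5                 # 5 trits per byte -> 1.6 bits
    e_bits = 8 / group_size          # one int8 exponent per group
    t_accum = 8 if training else 0   # one int8 counter per weight, training only
    return t_packed + e_bits + t_accum

for training in (True, False):
    bits = bits_per_weight(12, training)
    print(f"training={training}: {bits:.2f} bits/weight, {bits / 8:.2f} bytes/weight")
```

At group_size=96 (T4) the E overhead drops to about 0.08 bits/weight, so the training cost is dominated by T_accum.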
Pipeline (forward)
```
x (ephemeral float from prev layer)
  → unpack T_packed → T ∈ {-1, 0, +1}
  → expand E → 2^E as ephemeral float
  → w_eff = (2^E) * T (ephemeral float)
  → y = F.linear(x, w_eff) → ternarize activation → next layer
  → register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
```
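The forward pipeline and the captured weight gradient can be sketched with nested lists in place of tensors, so it runs with the standard library alone. The row-major grouping of E (each exponent covering `group_size` consecutive weights of the flattened matrix) is an assumption, and the function names are hypothetical; the real layer uses `F.linear` and an autograd hook.

```python
import math

def expand_weight(T, E, group_size):
    """Ephemeral w_eff = 2^E * T.

    T: out x in nested lists with entries in {-1, 0, +1}.
    E: flat list of int exponents, one per `group_size` consecutive weights
       of the row-major flattened matrix (assumed layout).
    """
    n_in = len(T[0])
    return [[math.ldexp(t, E[(o * n_in + i) // group_size]) for i, t in enumerate(row)]
            for o, row in enumerate(T)]

def linear(x, w):
    """y = x @ w^T, the contraction F.linear performs (bias omitted)."""
    return [[sum(xi * wi for xi, wi in zip(x_row, w_row)) for w_row in w] for x_row in x]

def weight_grad(grad_y, x):
    """grad_w = grad_y^T @ x with shape (out, in): what the hook hands to T_accum/E."""
    n_out, n_in = len(grad_y[0]), len(x[0])
    return [[sum(gy[o] * xr[i] for gy, xr in zip(grad_y, x)) for i in range(n_in)]
            for o in range(n_out)]

T = [[1, 0, -1], [-1, 1, 0]]
w = expand_weight(T, E=[1, -1], group_size=3)   # row 0 scaled by 2, row 1 by 0.5
y = linear([[1.0, 2.0, 3.0]], w)
```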
Update Rule
After backward:
1. S (E) update (called by `model._ternary_update_memory`):
   ```
   grad_E = -sign(mean over group of grad_w * T)
   E = clamp(E + grad_E, -128, 127)
   ```
2. T flip (called by `model._ternary_update_memory`):
   ```
   T_accum = clamp(T_accum + sign(grad_w), -128, 127)
   if |T_accum| > threshold (default 3):
       flip T at that position
       reset T_accum at that position to 0
   ```
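Both update rules can be written as plain Python over flattened lists. The grouping of E is assumed to be contiguous runs of `group_size` weights, and because "flip" is not defined for a three-valued T, this sketch assumes it means stepping T one level against the accumulated gradient sign (kept inside {-1, 0, +1}); the actual `ternary_step()` and `update_E()` may behave differently. The function names here are hypothetical.

```python
def sign(v: float) -> int:
    return (v > 0) - (v < 0)

def clamp(v: int, lo: int, hi: int) -> int:
    return max(lo, min(hi, v))

def step_exponents(E, T_flat, grad_flat, group_size):
    """E update: E + (-sign(mean over group of grad_w * T)), clamped to int8."""
    new_E = []
    for g, e in enumerate(E):
        sl = slice(g * group_size, (g + 1) * group_size)   # assumed contiguous groups
        prods = [gw * t for gw, t in zip(grad_flat[sl], T_flat[sl])]
        grad_E = -sign(sum(prods) / len(prods))
        new_E.append(clamp(e + grad_E, -128, 127))
    return new_E

def step_ternary(T_flat, T_accum, grad_flat, threshold=3):
    """T update: accumulate sign(grad_w); past the threshold, move T and reset."""
    T_new, acc = list(T_flat), list(T_accum)
    for k, gw in enumerate(grad_flat):
        acc[k] = clamp(acc[k] + sign(gw), -128, 127)
        if abs(acc[k]) > threshold:
            # Assumed meaning of "flip": one level against the accumulated
            # gradient sign, kept inside {-1, 0, +1}.
            T_new[k] = clamp(T_new[k] - sign(acc[k]), -1, 1)
            acc[k] = 0
    return T_new, acc
```

Because only signs are used, each step moves E by at most one (a factor of 2 in S), and T changes only after more than `threshold` net same-sign gradients.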
Files Changed
Core layers: tscale.py
- `TernaryScaleTensor`: complete rewrite
  - Removed: `self.weight` (FP32 master weight), `_compute_T`, `_compute_S`
  - Added: `T_packed`, `E` (int8 log-exponent), `T_accum` (int8 gradient counter)
  - Added: `ternary_step()`, `update_E()` for per-step updates
  - `forward`: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
  - `_hook_T`, `_hook_x` captured per forward call via closure
- Removed:
- `TernaryRMSNorm`: same T+E+accum scheme
Model architecture: trigram.py
- `ByteEmbedding`: same T+E+accum scheme; removed `self.weight`
- `StickyZoneSTE`: kept (used by `TernaryGNNLayer` for edge_attr ternarization)
- `ScaledTernaryLinear`: removed (thin wrapper, no longer meaningful)
- `TernaryFFN`: removed (dead code, never used by the model)
- `T_GRAPH_N_LAYERS`: removed (unused parameter)
- `MORPHTernaryModel._ternary_update_memory()`: new method that iterates over all modules, calling `ternary_step()` + `update_E()`
- Remaining float `nn.Parameter` instances (ViT frozen, MemGram embeddings, LSTM cell, edge_attr, router) kept as-is: they are small and either frozen or not weight matrices
Serialization: convert_to_ternary.py
- New file with `pack_ternary()` and `unpack_ternary()`: 5-trit-per-byte base-3 encoding
- Independent of `convert_to_ternary8.py` (no circular import)
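A 5-trit-per-byte base-3 encoding works because 3**5 = 243 fits in a byte, giving 8/5 = 1.6 bits per weight. The following is a hypothetical stand-in for `pack_ternary()`/`unpack_ternary()`, assuming the digit mapping -1/0/+1 → 0/1/2 and least-significant-first ordering; the real on-disk layout may differ.

```python
def pack_trits(trits):
    """Pack values in {-1, 0, +1} five per byte as base-3 digits (3**5 = 243 <= 256)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        byte = 0
        for t in reversed(trits[start:start + 5]):  # first trit ends up least significant
            byte = byte * 3 + (t + 1)                # map -1/0/+1 to digits 0/1/2
        out.append(byte)
    return bytes(out)

def unpack_trits(packed, n):
    """Inverse of pack_trits; n is the original trit count (the last byte may be partial)."""
    trits = []
    for byte in packed:
        for _ in range(5):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:n]
```

Because the last byte can hold fewer than five trits, the unpacker needs the original count `n` to drop the padding digits.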
Training loop: train.py
- Removed `from convert_to_ternary import save_model` (no longer needed)
- Checkpoint save uses `model.state_dict()` directly (all buffers + float params)
- After `optimizer.step()`, calls `model._ternary_update_memory(accum_threshold=3)`
Deleted: ht.py
- Commented-out planning notes, superseded by Phase 7 implementation
Memory Projection for 3B Parameter Model
| Component | Size |
|---|---|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| Total training | ~7.4 GB |
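The projection can be reproduced from the per-weight sizes, taking N = 3 × 10^9 parameters, decimal gigabytes, and (as the ~3 GB figure implies) one int8 per weight for the sign-only gradients. Variable names are illustrative.

```python
N = 3_000_000_000   # parameters
GB = 1e9            # decimal gigabytes, as in the table

components = {
    "T (packed, 1.6 bits)": N * 1.6 / 8,
    "E (int8, group=12)": N / 12,
    "T_accum (int8)": N * 1,
    "Gradients (int8 signs, assumed)": N * 1,
}
fixed = sum(components.values()) / GB
low, high = fixed + 0.2, fixed + 0.5   # add the activation range
for name, nbytes in components.items():
    print(f"{name}: {nbytes / GB:.2f} GB")
print(f"Total: {low:.2f}-{high:.2f} GB")
```

The fixed components sum to 6.85 GB, so the total spans about 7.05 to 7.35 GB depending on activation memory; the table's ~7.4 GB corresponds to the upper end.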
Test Results (140/140 passing)
| Suite | Pass | Fail | Notes |
|---|---|---|---|
| test_morph.py | 119 | 0 | All Phase 1-7 tests |
| test_tscale.py | 21 | 0 | Core ternary + SignSGD tests |
Pending Work
- Remaining float params (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
- TileLang GEMM kernel: rewrite `dequant_gemm.py` to use int8 E (log-space) instead of float16 scales
- Activation ternarization: optional optimization that quantizes inter-layer activations to {-1, 0, +1}
- Generate with memory: `generate()` currently passes `memory_state` but only does basic sampling