# True Ternary Refactoring

## Architecture Contract

| Component | Type | Storage | Role |
|-----------|------|---------|------|
| **T** | {-1, 0, +1} | 5-trit/byte packed (1.6 BPW) | Weight values |
| **E** | int8 | 1 per group (log-space exponent) | Scale memory: `S = 2^E` |
| **T_accum** | int8 | 1 per weight | Gradient sign accumulator for T flips |
| **S** | ephemeral | derived from E in forward | `S = 2^E`, log-space block scale |

### Key design decisions

- **No IEEE float anywhere in weight state.**
- S is NOT stored; only E (int8) is stored. `S = 2^E` is ephemeral in the computation graph.
- Ephemeral float values exist only in autograd's computation graph (forward/backward pass), never in persistent state.
- Bias is stored as an int32 buffer and cast to float ephemerally during forward.

### Log-Space Representation (Option B)

Scales use **log-space storage** as recommended by agents:

```
S = 2^E       where E = int8 (log₂ of scale factor)
W_eff = T * 2^E
```

Log-space replaces float multiplies with integer adds and shifts:

| Operation | Float version | Log-space version |
|-----------|---------------|-------------------|
| Scale × scale | `S1 * S2` (float mul) | `E1 + E2` (int add) |
| Scale × ternary | `S * T` (float mul) | `T << E` for E ≥ 0, `T >> (-E)` for E < 0 (int shift; lossy for E < 0) |
| Dequant in kernel | `sign * scale` (fp16 mul) | `sign << exp` (int shift) |
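A minimal pure-Python sketch of the identities in this table, assuming E is a plain int in the int8 range and T is a single trit. The helper names `scale_product_exponent` and `ternary_times_scale` are hypothetical, not functions from the codebase. The sketch also makes the E < 0 caveat concrete: Python's `>>` on ints floors, so `1 >> 2` is 0 and `-1 >> 2` is -1 instead of ±0.25.

```python
def scale_product_exponent(e1: int, e2: int) -> int:
    """Hypothetical helper: 2**e1 * 2**e2 == 2**(e1 + e2), so a scale multiply is an int add."""
    return e1 + e2


def ternary_times_scale(t: int, e: int) -> int:
    """Hypothetical helper: integer T * 2**E using shifts only.

    Exact for e >= 0. For e < 0 the arithmetic right shift floors the result,
    so +1 becomes 0 and -1 stays -1: sub-unit scales are not representable.
    """
    if t not in (-1, 0, 1):
        raise ValueError(f"not a trit: {t}")
    return t << e if e >= 0 else t >> -e


if __name__ == "__main__":
    assert 2.0 ** scale_product_exponent(3, -1) == (2.0 ** 3) * (2.0 ** -1)
    print(ternary_times_scale(-1, 4))   # -16
    print(ternary_times_scale(1, -2))   # 0, although 1 * 2**-2 = 0.25
    print(ternary_times_scale(-1, -2))  # -1, although -1 * 2**-2 = -0.25
```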

The **TileLang kernel** (`tilelang/kernels/dequant_gemm.py`) uses integer shift directly:

```python
# Per-element dequant in TileLang (sign_val in {-1, 0, +1}, exp_val = int8 E):
if sign_val == 0:
    dequant_int = 0
elif exp_val >= 0:
    dequant_int = sign_val * (1 << exp_val)
else:
    # Integer >> truncates: 1 >> k == 0 for k > 0, so every E < 0 dequantizes to 0.
    dequant_int = sign_val * (1 >> (-exp_val))
```
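If sub-unit scales are needed, one common workaround is to dequantize into fixed point. You offset every exponent by a fractional-bit budget so the shift is always a left shift, then rescale the integer accumulator once at the end. Below is a minimal sketch that assumes a hypothetical `FRAC_BITS = 8` and that E is clamped so `exp_val + FRAC_BITS` fits the accumulator width. Python ints do not overflow, but an int32 kernel would. Neither the constant nor the helper names come from `dequant_gemm.py`.

```python
FRAC_BITS = 8  # hypothetical fixed-point budget; not a parameter of dequant_gemm.py


def dequant_fixed_point(sign_val: int, exp_val: int, frac_bits: int = FRAC_BITS) -> int:
    """Hypothetical helper: sign_val * 2**exp_val as an integer in units of 2**-frac_bits.

    Exact whenever exp_val >= -frac_bits; smaller exponents flush to zero.
    """
    shift = exp_val + frac_bits
    if sign_val == 0 or shift < 0:
        return 0
    return sign_val * (1 << shift)


def to_float(q: int, frac_bits: int = FRAC_BITS) -> float:
    """Convert a fixed-point value back to float (for checking only)."""
    return q / (1 << frac_bits)
```

With `FRAC_BITS = 8`, `dequant_fixed_point(1, -2)` returns 64, which represents 0.25 in units of 2^-8. The integer-shift snippet above yields 0 for the same input.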

### T scale type → block sizing

`TILE_SIZE = 384`. TScaleType determines group size and thus E's granularity:

| Type | Group size | E entries per 384-dim row |
|------|-----------|--------------------------|
| T64  | 6         | 64 |
| T32  | 12        | 32 |
| T16  | 24        | 16 |
| T8   | 48        | 8  |
| T6   | 64        | 6  |
| T4   | 96        | 4  |

At model scale = 10, the block is 3840 (10 × 384) and the group count scales proportionally.
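A small sketch of the sizing arithmetic. It reads the table as "E entries per 384-wide tile" and assumes that at larger model scales the group size stays fixed while the group count grows. `E_ENTRIES_PER_TILE`, `group_size`, and `e_entries_per_row` are hypothetical names, not the project's `TScaleType` API.

```python
TILE_SIZE = 384

# Hypothetical lookup mirroring the table: E entries per 384-dim row.
E_ENTRIES_PER_TILE = {"T64": 64, "T32": 32, "T16": 16, "T8": 8, "T6": 6, "T4": 4}


def group_size(scale_type: str) -> int:
    """Number of weights sharing one E entry: TILE_SIZE / entries per tile."""
    entries = E_ENTRIES_PER_TILE[scale_type]
    if TILE_SIZE % entries != 0:
        raise ValueError("entries must divide TILE_SIZE")
    return TILE_SIZE // entries


def e_entries_per_row(scale_type: str, model_scale: int = 1) -> int:
    """E entries per row when the block grows to TILE_SIZE * model_scale.

    ASSUMPTION: group size is fixed, so the group count scales linearly.
    """
    return (TILE_SIZE * model_scale) // group_size(scale_type)
```

Under that assumption, `e_entries_per_row("T32", 10)` gives 320 exponents for a 3840-wide row.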

## Training State: what lives where

During training, each `TernaryScaleTensor` stores:

| Buffer | Shape | dtype | Bits/weight |
|--------|-------|-------|-------------|
| `T_packed` | flat (5-trit/byte) | uint8 | 1.6 |
| `E` | (n_groups,) | int8 | 8 / group_size |
| `T_accum` | (out, in) | int8 | 8 |
| `bias` | (out,) | int32 | 32 / in_dim (negligible) |

Total stored = `1.6 + 8 + 8/group_size` bits/weight.

At group_size=12: **1.6 + 8 + 0.67 = 10.27 bits/weight** (~1.28 bytes/weight) during training.

For inference/storage (no T_accum needed): **1.6 + 0.67 = 2.27 bits/weight** (~0.28 bytes/weight).
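The per-weight totals can be checked in a few lines. `bits_per_weight` is a hypothetical helper that counts only `T_packed`, `E`, and (while training) `T_accum`, leaving out the bias as the text does.

```python
def bits_per_weight(group_size: int, training: bool = True) -> float:
    """Hypothetical helper: stored bits per weight for the ternary scheme (bias ignored)."""
    t_packed = 8 / 5                 # 5 trits per uint8 byte -> 1.6 bits
    e_bits = 8 / group_size          # one int8 exponent per group
    t_accum = 8 if training else 0   # one int8 counter per weight, training only
    return t_packed + e_bits + t_accum
```

For example, `bits_per_weight(12)` is about 10.27 and `bits_per_weight(12, training=False)` is about 2.27. Dividing by 8 gives the byte figures.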

## Pipeline (forward)

```
x (ephemeral float from prev layer)
  → unpack T_packed → T ∈ {-1,0,+1}
  → expand E → 2^E as ephemeral float
  → w_eff = (2^E) * T  (ephemeral float)
  → y = F.linear(x, w_eff)  → ternarize activation → next layer
  → register hook to capture grad_w = grad_y^T @ x for T_accum/E updates
```
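A dependency-free sketch of the weight path and the gradient the hook captures, using nested lists instead of tensors. It assumes `T` is already unpacked to an `(out, in)` list of trits. It also assumes the flat `E` is laid out row-major, with groups that do not straddle rows. That layout and the helper names are assumptions, not the actual `TernaryScaleTensor` internals. The activation-ternarization step is omitted.

```python
import math
from typing import List

Matrix = List[List[float]]


def effective_weight(T: List[List[int]], E: List[int], group_size: int) -> Matrix:
    """Hypothetical helper: w_eff = T * 2^E with one exponent per contiguous group.

    ASSUMPTION: row-major flattening and in_dim % group_size == 0.
    """
    in_dim = len(T[0])
    if in_dim % group_size != 0:
        raise ValueError("groups must not straddle rows")
    return [
        [math.ldexp(float(t), E[(o * in_dim + i) // group_size]) for i, t in enumerate(row)]
        for o, row in enumerate(T)
    ]


def linear(x: Matrix, w: Matrix) -> Matrix:
    """y = x @ w^T, the bias-free equivalent of F.linear(x, w)."""
    return [[sum(a * b for a, b in zip(x_row, w_row)) for w_row in w] for x_row in x]


def weight_grad(grad_y: Matrix, x: Matrix) -> Matrix:
    """grad_w = grad_y^T @ x, the quantity captured for the T_accum/E updates."""
    out_dim, in_dim = len(grad_y[0]), len(x[0])
    return [
        [sum(grad_y[b][o] * x[b][i] for b in range(len(x))) for i in range(in_dim)]
        for o in range(out_dim)
    ]
```

Take `T = [[1, 0, -1, 1], [0, -1, 1, 0]]`, `E = [0, 1, -1, 2]`, and `group_size = 2`. The effective weights are `[[1.0, 0.0, -2.0, 2.0], [0.0, -0.5, 4.0, 0.0]]`, and `x = [[1.0, 2.0, 3.0, 4.0]]` maps to `y = [[3.0, 11.0]]`.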

## Update Rule

```
After backward:

1. S (E) update (called by model._ternary_update_memory):
   delta_E = -sign(mean over group of grad_w * T)
   E = clamp(E + delta_E, -128, 127)

2. T flip (called by model._ternary_update_memory):
   T_accum = clamp(T_accum + sign(grad_w), -128, 127)
   if |T_accum| > threshold (default 3):
       flip T at that position
       reset T_accum at that position to 0
```
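Below is a list-based sketch of both rules on flattened weights. The names `e_update` and `t_flip_update` are hypothetical stand-ins for `update_E()` and `ternary_step()`. "Flip" is not defined precisely above, so the sketch assumes it means moving T one level against the accumulated gradient sign (for example +1 → 0 when `T_accum` is positive), clamped to {-1, 0, +1}. The real method may instead negate T. The sketch also assumes E groups are contiguous in flattened order and that the E update runs before the flip, as listed.

```python
from typing import List


def _sign(v: float) -> int:
    """Return -1, 0, or +1."""
    return (v > 0) - (v < 0)


def _clamp_int8(v: int) -> int:
    return max(-128, min(127, v))


def e_update(E: List[int], T: List[int], grad_w: List[float], group_size: int) -> None:
    """Rule 1 (hypothetical sketch): E += -sign(mean(grad_w * T)) per group, clamped to int8."""
    for g in range(len(E)):
        lo = g * group_size
        mean = sum(grad_w[k] * T[k] for k in range(lo, lo + group_size)) / group_size
        E[g] = _clamp_int8(E[g] - _sign(mean))


def t_flip_update(T: List[int], T_accum: List[int], grad_w: List[float], threshold: int = 3) -> None:
    """Rule 2 (hypothetical sketch): accumulate sign(grad_w); past the threshold, step T and reset."""
    for k, g in enumerate(grad_w):
        T_accum[k] = _clamp_int8(T_accum[k] + _sign(g))
        if abs(T_accum[k]) > threshold:
            # ASSUMPTION: "flip" = one ternary level against the accumulated gradient sign.
            T[k] = max(-1, min(1, T[k] - _sign(T_accum[k])))
            T_accum[k] = 0
```

Both functions mutate their list arguments in place, mirroring how the buffers are updated in place after `optimizer.step()`.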

## Files Changed

### Core layers: `tscale.py`
- `TernaryScaleTensor`: complete rewrite
  - Removed: `self.weight` (FP32 master weight), `_compute_T`, `_compute_S`
  - Added: `T_packed`, `E` (int8 log-exponent), `T_accum` (int8 gradient counter)
  - Added: `ternary_step()`, `update_E()` for per-step updates
  - `forward`: unpack T → expand E → `w_eff = 2^E * T` → linear → capture gradient
  - `_hook_T`, `_hook_x` captured per-forward-call via closure
- `TernaryRMSNorm`: same T+E+accum scheme

### Model architecture: `trigram.py`
- `ByteEmbedding`: same T+E+accum scheme, removed `self.weight`
- `StickyZoneSTE`: kept (used by `TernaryGNNLayer` for edge_attr ternarization)
- `ScaledTernaryLinear`: removed (thin wrapper, no longer meaningful)
- `TernaryFFN`: removed (dead code, never used by the model)
- `T_GRAPH_N_LAYERS`: removed (unused parameter)
- `MORPHTernaryModel._ternary_update_memory()`: new method that iterates over all modules, calling `ternary_step()` + `update_E()`
- Remaining float `nn.Parameter` instances (ViT frozen, MemGram embeddings, LSTM cell, edge_attr, router) kept as-is: they are small and either frozen or not weight matrices

### Serialization: `convert_to_ternary.py`
- New file with `pack_ternary()` and `unpack_ternary()` (5-trit-per-byte base-3 encoding)
- Independent of `convert_to_ternary8.py` (no circular import)
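Here is a minimal sketch of 5-trit-per-byte base-3 packing. Each trit maps to a digit 0..2, and five digits fit in one byte because 3^5 = 243 ≤ 256. `pack_trits` and `unpack_trits` are hypothetical names. The digit order, the padding of the final partial byte, and the signatures are assumptions and may differ from `pack_ternary()` / `unpack_ternary()`.

```python
from typing import List


def pack_trits(trits: List[int]) -> bytes:
    """Hypothetical helper: pack values in {-1, 0, +1}, 5 trits per byte (base 3)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        value = 0
        # Horner's rule over the reversed chunk: the first trit becomes the least significant digit.
        for t in reversed(trits[start:start + 5]):
            if t not in (-1, 0, 1):
                raise ValueError(f"not a trit: {t}")
            value = value * 3 + (t + 1)  # map -1, 0, +1 -> digits 0, 1, 2
        out.append(value)  # at most 2 * (1 + 3 + 9 + 27 + 81) = 242
    return bytes(out)


def unpack_trits(data: bytes, n: int) -> List[int]:
    """Inverse of pack_trits; n is the original trit count (the last byte may be partial)."""
    trits: List[int] = []
    for byte in data:
        for _ in range(5):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:n]
```

Seven trits pack into two bytes, and `unpack_trits(pack_trits(t), len(t))` returns `t`. The caller must keep the trit count because the padding digits in the last byte decode as -1.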

### Training loop: `train.py`
- Removed `from convert_to_ternary import save_model` (no longer needed)
- Checkpoint save uses `model.state_dict()` directly (all buffers + float params)
- After `optimizer.step()`, call `model._ternary_update_memory(accum_threshold=3)`

### Deleted: `ht.py`
- Commented-out planning notes, superseded by Phase 7 implementation

## Memory Projection for 3B Parameter Model

| Component | Size |
|-----------|------|
| T (packed 5-trit/byte) | ~0.6 GB |
| E (int8, group=12) | ~0.25 GB |
| T_accum (int8, 1 per weight) | ~3 GB |
| Gradients (ephemeral sign-only) | ~3 GB |
| Activations (ternary, checkpointed) | ~0.2–0.5 GB |
| **Total training** | **~7.1–7.4 GB** |
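The projection can be reproduced from the per-weight costs. `training_memory_gb` is a hypothetical helper. It assumes decimal gigabytes and 1 byte per weight for the sign-only gradients, which is what the ~3 GB figure implies. The fixed components sum to 6.85 GB, so with 0.2 to 0.5 GB of activations the total lands between roughly 7.05 and 7.35 GB.

```python
from typing import Dict, Tuple


def training_memory_gb(
    n_params: float = 3e9,
    group_size: int = 12,
    grad_bytes_per_weight: float = 1.0,  # ASSUMPTION: sign-only grads held as int8
    activations_gb: Tuple[float, float] = (0.2, 0.5),
) -> Dict[str, float]:
    """Hypothetical helper: training memory in decimal GB for the ternary scheme."""
    gb = 1e9
    parts = {
        "T_packed": n_params / 5 / gb,        # 5 trits per byte
        "E": n_params / group_size / gb,      # one int8 per group
        "T_accum": n_params / gb,             # one int8 per weight
        "grads": n_params * grad_bytes_per_weight / gb,
    }
    fixed = sum(parts.values())
    parts["total_low"] = fixed + activations_gb[0]
    parts["total_high"] = fixed + activations_gb[1]
    return parts
```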

## Test Results (140/140 passing)

| Suite | Pass | Fail | Notes |
|-------|------|------|-------|
| `test_morph.py` | 119 | 0 | All Phase 1-7 tests |
| `test_tscale.py` | 21 | 0 | Core ternary + SignSGD tests |

## Pending Work

1. **Remaining float params** (MemGram embeddings, LSTM cell weights, MoE router): ternarize these for full compliance
2. **TileLang GEMM kernel**: rewrite `dequant_gemm.py` to use int8 E (log-space) instead of float16 scales
3. **Activation ternarization**: optional optimization that quantizes inter-layer activations to {-1,0,+1}
4. **Generate with memory**: `generate()` currently passes `memory_state` but uses only basic sampling