# Factorized Scaled Ternary — Config E Equations

## The Core Identity

```
W = S ⊙ T  where  S = |W|,  T = sign(W)
```

This identity holds for every real number: |w| × sign(w) = w.
No approximation is involved. Information is lost only once a
threshold θ zeroes small weights (see Config E).
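
A quick numerical check of the identity, as a minimal NumPy sketch. The random weights, shape, and seed are illustrative assumptions, not values from the experiments.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(4, 6)).astype(np.float32)  # synthetic weights
W[0, 0] = 0.0  # include an exact zero: |0| * sign(0) = 0

S = np.abs(W)    # magnitude
T = np.sign(W)   # direction, values in {-1, 0, +1}

# Multiplying by -1, 0, or +1 never rounds, so the product is bit-exact
assert np.array_equal(S * T, W)
```

The equality is exact rather than approximate because the sign factor only flips or clears a value; it never rescales it.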

---

## Per-Config Forward Pass Equations

### Config A (FP32 Baseline)

```
y = xW + b
```

Standard linear. No ternarization. Full precision weights.

### Config C (Learned-S)

```
T = sign(W) · (|W| > θ)     ← ternarize with hard threshold
S = learned scalar            ← nn.Parameter, trained via Adam
W_eff = S · T                 ← all surviving weights share one scale
y = xW_eff + b
```

Problem: one scalar S cannot capture per-element magnitude variation.
At θ=0.05, only ~30% of weights survive, and they all get the same
scale regardless of actual magnitude. S converges to ~0.30.

### Config D (Computed-S per-layer / BitNet absmean)

```
T = sign(W) · (|W| > θ)     ← same ternarize
S = mean(|W|)                 ← single scalar computed from weight stats
W_eff = S · T                 ← uniform scale like BitNet
y = xW_eff + b
```

S tracks the actual weight magnitude dynamically. This beats Config C
because S is always computed from the current weight distribution
instead of being learned. It is still one scalar per layer, so the
granularity stays coarse.
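
To make the coarse-granularity point concrete, here is a minimal NumPy sketch of the Config D effective weights. The `ternarize` helper name, the Gaussian weights, and the layer shape are assumptions for illustration; Config C would differ only in S being a trained scalar instead of the absmean.

```python
import numpy as np

def ternarize(W, theta):
    """T = sign(W) * (|W| > theta), values in {-1, 0, +1}."""
    return np.sign(W) * (np.abs(W) > theta)

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(64, 64))  # synthetic weights, not the trained model
theta = 0.05

T = ternarize(W, theta)
S_d = np.mean(np.abs(W))   # Config D: one absmean scalar for the whole layer
W_eff_d = S_d * T          # every surviving weight gets the same magnitude

active = T != 0
# |W| varies across active positions, S_d does not: this is the uniform-scale error
err_active = np.abs(W_eff_d - W)[active]
print(f"S_d = {S_d:.4f}, mean |error| at active positions = {err_active.mean():.4f}")
```

Every active entry of `W_eff_d` has magnitude `S_d`, so any spread in |W| among surviving weights shows up directly as reconstruction error.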

### Config E (Factorized — per-element)

```
T = sign(W) · (|W| > θ)     ← ternary mask: {-1, 0, +1}
S = |W|                      ← per-element magnitude tensor (same shape as W)
W_eff = S ⊙ T                ← element-wise product
y = xW_eff + b
```

At every active position (where |W[i,j]| > θ):

```
W_eff[i,j] = |W[i,j]| × sign(W[i,j]) = W[i,j]
```

Exact reconstruction. The only information loss is at positions where
|W[i,j]| ≤ θ — these are zeroed out. Those positions had near-zero
contribution anyway. Measured gap: ~5% vs FP32 at 5000 steps.
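
The exact-reconstruction claim can be checked directly. This is a sketch on synthetic Gaussian weights (an assumption); the zeroed fraction it prints says nothing about the trained model.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(scale=0.1, size=(64, 64))  # synthetic weights (assumption)
theta = 0.05

mask = np.abs(W) > theta
T = np.sign(W) * mask     # ternary pattern
S = np.abs(W)             # per-element magnitude
W_eff = S * T             # Config E effective weights

# Active positions reproduce W exactly; everything else is zero
assert np.array_equal(W_eff[mask], W[mask])
assert np.all(W_eff[~mask] == 0)

# Equivalent view: Config E is magnitude thresholding of W
assert np.array_equal(W_eff, np.where(mask, W, 0.0))
print(f"fraction zeroed: {1 - mask.mean():.2%}")
```

Seen this way, the forward pass of Config E is W with its small entries removed; T and S are a factorization of that thresholded matrix.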

---

## STE Backward Pass

The TernarizeSTE function:

```
Forward:  T = sign(W) · (|W| > θ)
Backward: ∂L/∂W = ∂L/∂W_eff · ∂W_eff/∂W
```

The STE approximation treats the forward pass as the identity inside
the active zone (|W| > θ) and blocks the gradient everywhere else:

```
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
```

For Config E specifically, W_eff = |W| ⊙ T equals W at active
positions, so the same masked gradient applies:

```
∂L/∂W ≈ ∂L/∂W_eff · mask(|W| > θ)
```

The gradient flows through to W directly. The optimizer updates W
via standard addition (Adam: W ← W - lr × update). The magnitude
|W| and direction sign(W) emerge naturally from this process.

No separate S gradient needed. S = |W| is a deterministic function
of W, not a learned parameter.
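
The masked gradient can be sketched with hand-written forward and backward functions in NumPy. This is not the TernarizeSTE function itself; the function names, shapes, and the loss L = sum(y) are assumptions chosen to keep the example small.

```python
import numpy as np

def forward(x, W, b, theta):
    mask = np.abs(W) > theta
    W_eff = np.abs(W) * np.sign(W) * mask   # S ⊙ T
    return x @ W_eff + b, mask

def backward(x, grad_y, mask):
    grad_W_eff = x.T @ grad_y               # dL/dW_eff for y = x W_eff + b
    grad_W = grad_W_eff * mask              # STE: gradient passes only where |W| > theta
    grad_b = grad_y.sum(axis=0)
    return grad_W, grad_b

rng = np.random.default_rng(0)
x = rng.normal(size=(8, 16))
W = rng.normal(scale=0.1, size=(16, 4))
b = np.zeros(4)
theta = 0.05

y, mask = forward(x, W, b, theta)
grad_y = np.ones_like(y)                    # dL/dy for the assumed loss L = y.sum()
grad_W, grad_b = backward(x, grad_y, mask)

# Sub-threshold weights receive no gradient at all
assert np.all(grad_W[~mask] == 0)
```

There is no gradient with respect to S anywhere in this sketch: the only trainable tensors are W and b, matching the point that S is derived from W.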

---

## Compute Comparison

| Operation              | FP32            | Config C          | Config D          | Config E              |
|------------------------|-----------------|-------------------|-------------------|-----------------------|
| Weight storage         | full FP32       | FP32 + 1 scalar   | FP32 only         | FP32 only             |
| Effective weights      | W               | S·T (uniform)     | S·T (uniform)     | S⊙T (per-element)     |
| Forward matmul         | x × W           | x × (S·T)         | x × (S·T)         | x × (S⊙T)             |
| STE backward           | N/A             | mask · ∂L/∂W_eff  | mask · ∂L/∂W_eff  | mask · ∂L/∂W_eff      |
| Extra compute / step   | 0               | 1 scalar mul      | 1 absmean         | 1 abs + element mul   |
| BPW (inference)        | 32.0            | 1.58 + overhead   | 1.58              | 1.58 (grouped S)      |
| Measured val loss      | 2.165           | 2.895             | 2.515             | 2.273                 |
| vs FP32                | 1.000×          | 1.337×            | 1.162×            | 1.050×                |
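
The last row follows from the measured validation losses by simple division. A short sketch, using only the loss values from the table:

```python
val_loss = {"FP32": 2.165, "Config C": 2.895, "Config D": 2.515, "Config E": 2.273}

baseline = val_loss["FP32"]
ratios = {name: loss / baseline for name, loss in val_loss.items()}

for name, r in ratios.items():
    print(f"{name:9s} {r:.3f}x  ({(r - 1) * 100:+.1f}% vs FP32)")
```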

---

## Why Config E Wins

1. **Identity reconstruction**: At active positions, W_eff = W exactly.
   No quantization error. No information bottleneck from uniform scaling.

2. **No extra parameter**: S = |W| is computed, not learned. Fewer
   parameters means less BPW overhead and simpler architecture.

3. **Natural magnitude evolution**: As training progresses, weights
   that matter grow in magnitude (larger S), weights that don't
   shrink toward zero (absorbed by threshold). This is standard
   gradient descent — we just observe it through the S ⊙ T lens.

4. **Sparsity emerges organically**: With θ = 0.05, ~38% of weights
   in fc1 and ~63% in fc2 fall below the threshold by step 5000. These
   are exact zeros in W_eff, so the sparsity carries over to inference.

---

## Implications for Inference Serialization

At inference time, Config E's effective weights can be decomposed
at any granularity:

| Granularity    | Format                   | BPW    | Accuracy    |
|----------------|--------------------------|--------|-------------|
| Per-layer S    | T (2-bit) + 1 FP16 S    | ~1.58  | Like D      |
| Per-group S    | T (2-bit) + group FP16  | ~1.58  | Better      |
| Per-element S  | T (2-bit) + per-elem FP | ~17    | Full E      |
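
The BPW column can be approximated as the per-weight code size plus the scale bits amortized over the group. The sketch below is a back-of-envelope calculation; the layer size and the group size of 128 are assumptions, since the table does not state them.

```python
import math

def bpw(code_bits, scale_bits, group_size):
    """Bits per weight: per-weight code plus scale bits amortized over the group."""
    return code_bits + scale_bits / group_size

ideal = math.log2(3)   # ~1.585 bits: one ternary symbol with equally likely states
packed = 2.0           # plain 2-bit packing, one of the four codes unused

n_layer = 512 * 512    # hypothetical layer size
print(f"per-layer   S, ideal code : {bpw(ideal, 16, n_layer):.3f}")
print(f"per-group   S, ideal code : {bpw(ideal, 16, 128):.3f}")
print(f"per-element S, ideal code : {bpw(ideal, 16, 1):.3f}")
print(f"per-group   S, 2-bit pack : {bpw(packed, 16, 128):.3f}")
```

The scale overhead shrinks as the group grows, which is why per-layer and large-group formats stay close to the ternary floor while per-element scales cost a full FP16 value per weight.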

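The I2_S format (BitNet.cpp) packs 16 ternary weights in 32 bits
(2 bits each) with one FP16 group scale. This is the natural
serialization target for Config E: group the magnitudes, pack
the ternary patterns. The ~1.58 figures above correspond to
log2(3) ≈ 1.585 bits per ternary weight; plain 2-bit packing
stores 2.0 bits per weight before scale overhead.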

Training uses full FP32 W with STE. Inference uses packed T + S.
The bridge between them is the fact that W_eff at active positions
equals W — so any group quantization of |W| is just standard
weight quantization applied to the surviving weights.
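
One way to picture the train-to-inference bridge is to quantize a single group of 16 weights into a packed word plus a scale. This sketch does not reproduce the actual I2_S bit layout: the code mapping (-1→0, 0→1, +1→2), the scale rule (mean magnitude of the surviving weights), and all function names are hypothetical.

```python
import numpy as np

GROUP = 16  # ternary values per 32-bit word at 2 bits each

def pack_ternary(t):
    """Pack 16 values from {-1, 0, +1} into one integer below 2**32."""
    word = 0
    for k, v in enumerate(t):
        word |= (int(v) + 1) << (2 * k)   # hypothetical code: -1->0, 0->1, +1->2
    return word

def unpack_ternary(word):
    return np.array([((word >> (2 * k)) & 0b11) - 1 for k in range(GROUP)])

def quantize_group(w, theta):
    mask = np.abs(w) > theta
    t = (np.sign(w) * mask).astype(int)
    # Assumed scale rule: mean magnitude of the surviving weights, stored as FP16
    scale = np.float16(np.abs(w[mask]).mean()) if mask.any() else np.float16(0)
    return pack_ternary(t), scale

rng = np.random.default_rng(0)
w = rng.normal(scale=0.1, size=GROUP)     # one synthetic group
theta = 0.05

word, scale = quantize_group(w, theta)
t = unpack_ternary(word)
w_hat = t * float(scale)                  # dequantized group

# The ternary pattern survives the round trip exactly; only magnitudes are approximated
assert np.array_equal(t, (np.sign(w) * (np.abs(w) > theta)).astype(int))
```

The sign pattern and zero positions are preserved exactly, so all inference-time error comes from replacing the per-element magnitudes with one shared scale per group.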

---

## The Question: Is Ternary Actually Applied?

Config E's effective forward weight is W_eff = |W| ⊙ T.

At active positions this equals W itself. So is ternarization
even happening? Yes, but subtly:

1. **Forward pass**: The STE ternarizes W into T ∈ {-1, 0, +1},
   then re-scales by |W|. The result at active positions is W.
   But the **sparsity mask** (which positions are zero) IS the
   ternary structure. T determines shape, S determines magnitude.

2. **Gradient pass**: The STE mask blocks gradient to small weights.
   This IS ternary training — the gradient is zeroed for positions
   below threshold. This is the mechanism that creates sparsity.

3. **Inference packing**: T (the ternary pattern) can be packed at
   2 bits per weight. S (magnitudes) can be quantized per-group.
   This IS ternary inference — just with group scales attached.

4. **What Config E is NOT**: It is not a ternary-only matmul in
   training. During training, the matmul uses full-precision
   W_eff values (which equal W at active positions). The ternary
   advantage is realized at inference, not during training.

**The ternary IS the sparsity mask.** The magnitude IS the weight.
Training maintains both in one parameter. Inference separates them.