# SignGSD vs AdamW — Optimizer Analysis for Ternary Training

## 1. Algorithm Comparison

### AdamW (Loshchilov & Hutter, 2019)

```
exp_avg = β1·exp_avg + (1-β1)·grad          # first moment (momentum)
exp_avg_sq = β2·exp_avg_sq + (1-β2)·grad²    # second moment (RMS)
m_hat = exp_avg / (1-β1^t)                    # bias correction
v_hat = exp_avg_sq / (1-β2^t)
p -= lr · m_hat / (√v_hat + ε) + wd·lr·p     # decoupled weight decay
```

State per parameter:
- `exp_avg` (float32): 4 bytes
- `exp_avg_sq` (float32): 4 bytes  
- Step counter `t` (int): negligible
- **Total optimizer state: 8 bytes/param**
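For reference, the update above can be written as a runnable sketch on plain Python lists. This is a minimal illustration, not PyTorch's `torch.optim.AdamW`; the function name `adamw_step` and the default hyperparameters are assumptions. Python floats are 64-bit, so the two state lists here only stand in for the two float32 buffers counted above.

```python
import math


def adamw_step(p, grad, exp_avg, exp_avg_sq, t, lr=1e-3,
               betas=(0.9, 0.999), eps=1e-8, wd=0.01):
    """One AdamW step, updating p, exp_avg and exp_avg_sq in place.

    t is the 1-based step counter used for bias correction.
    """
    beta1, beta2 = betas
    bias1 = 1 - beta1 ** t
    bias2 = 1 - beta2 ** t
    for i, g in enumerate(grad):
        exp_avg[i] = beta1 * exp_avg[i] + (1 - beta1) * g           # first moment
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g  # second moment
        m_hat = exp_avg[i] / bias1
        v_hat = exp_avg_sq[i] / bias2
        # Adam step plus decoupled weight decay (both use the pre-update p[i])
        p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps) + wd * lr * p[i]


# Two state lists per parameter list: the 8 bytes/param counted above.
p = [1.0, -2.0]
exp_avg, exp_avg_sq = [0.0, 0.0], [0.0, 0.0]
adamw_step(p, [0.5, -4.0], exp_avg, exp_avg_sq, t=1, lr=0.1, wd=0.0)
```

On the first step the bias-corrected ratio is `g / (|g| + eps)`, so both weights move by almost exactly `lr` despite an 8× difference in gradient size, and `p` ends up at roughly `[0.9, -1.9]`.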

### SignGSD

```
update = sign(grad) + wd · sign(p)           # gradient sign + weight decay toward zero
p -= lr · update                             # uniform step
```

State per parameter: **0 bytes** (no state whatsoever)
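A matching sketch of the SignGSD rule, again on plain lists. The helper `sign` follows the `torch.sign` convention (sign(0) = 0); the names and defaults are illustrative assumptions. The function keeps nothing between calls, which is the 0 bytes/param above.

```python
def sign(x):
    """Return -1, 0 or +1; sign(0) == 0."""
    return (x > 0) - (x < 0)


def signgsd_step(p, grad, lr=1e-3, wd=0.0):
    """One SignGSD step on a list of floats, in place. No optimizer state."""
    for i, g in enumerate(grad):
        update = sign(g) + wd * sign(p[i])  # gradient sign plus sign-based decay
        p[i] -= lr * update                 # step size never depends on |g|


p = [0.5, -0.5, 0.2]
signgsd_step(p, [3.0, 3.0, 0.0], lr=0.1, wd=0.5)
# p[0]: update 1.5 -> 0.35; p[1]: update 0.5 -> -0.55; p[2]: zero grad, decay only -> 0.15
```

The gradient magnitude (3.0) never enters the step; only its sign and the sign of the weight do.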

### Key differences table

| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | `wd * lr * p` (decoupled) | `wd * sign(p)` (coupled to update) |
| LR schedule | Commonly used (cosine, linear) | Optional; step size depends only on lr and wd, never on gradient magnitude |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |

---

## 2. What is Ternary Training?

Ternary weights are constrained to **{-1, 0, +1}**. This gives three benefits:
- **Storage**: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
- **Compute**: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
- **Memory bandwidth**: ~1/20 of float32 weights with base-3 packing (1.6 vs 32 bits; 1/16 with 2-bit packing), before scale overhead
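The base-3 packing can be sketched as follows: each group of five trits is read as a base-3 number (digit = trit + 1), and the 3^5 = 243 possible codes fit in one byte. The least-significant-digit-first order and the names `pack_trits` / `unpack_trits` are assumptions for illustration; a real kernel may lay trits out differently.

```python
def pack_trits(trits):
    """Pack values in {-1, 0, +1} into bytes, 5 trits per byte (3**5 = 243 <= 256)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        code = 0
        for t in reversed(trits[start:start + 5]):
            code = code * 3 + (t + 1)   # map -1, 0, +1 to digits 0, 1, 2
        out.append(code)                # first trit ends up as the lowest digit
    return bytes(out)


def unpack_trits(data, n):
    """Inverse of pack_trits; n is the original number of trits."""
    trits = []
    for code in data:
        for _ in range(5):
            trits.append(code % 3 - 1)  # lowest digit back to a trit
            code //= 3
    return trits[:n]                    # drop padding from a short last group


w = [-1, 0, 1, 1, -1, 0, 0, 1]
packed = pack_trits(w)                  # 8 trits -> 2 bytes
```

For full groups this is 8 bits per 5 trits, i.e. the 1.6 bits/weight quoted above; a short final group still costs a whole byte.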

The fundamental challenge: **continuous optimizers cannot directly update ternary weights**.

```
w ∈ {-1, 0, +1}
w -= lr * grad   →  0.7, -0.3, 1.2 ... NOT ternary!
```

Every optimizer step must eventually produce a flip decision (±1 change or stay),
not a continuous displacement.
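To see why an intermediate layer is needed, consider the most naive alternative (not something proposed here): take a continuous step, then round back onto the ternary grid. With a small learning rate the rounding undoes every step, so the weights never change even when the gradient is perfectly consistent. The function name and numbers are illustrative.

```python
def naive_ternary_step(w, grad, lr):
    """Continuous step, then round and clip back onto {-1, 0, +1}."""
    out = []
    for wi, gi in zip(w, grad):
        x = wi - lr * gi                        # e.g. 0.997: not ternary
        out.append(max(-1, min(1, round(x))))   # snap back to the grid
    return out


w = [1, 0, -1]
grad = [0.3, -0.2, 0.4]   # the same gradient on every step
for _ in range(1000):
    w = naive_ternary_step(w, grad, lr=0.01)
# w is still [1, 0, -1]: each 0.002 to 0.004 nudge is rounded away,
# although 1000 steps of lr * grad add up to displacements of 2 to 4.
```

An accumulator keeps exactly these sub-trit contributions instead of discarding them.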

---

## 3. The Two-Layer Architecture of Ternary Training

All practical ternary training uses a **latent/accumulator layer** between the
optimizer and the actual ternary weights:

```
Optimizer (SignGSD or AdamW)
    ↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: one int8 vote counter per weight)
    ↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
```

### Flow detail

1. **Forward**: unpack T → matmul(W = S × T, x) → loss
   - T is ternary, S is a per-weight or per-group scale
2. **Backward**: grad w.r.t. W flows to both S and T paths
   - T is discrete and not differentiable, so the gradient w.r.t. W is used as a proxy for it (a straight-through-style estimate; see step 4)
3. **Optimizer step on S** (if per-param): standard optimizer (AdamW / SignGSD)
   - S is continuous (float/int) — no special treatment needed
4. **Accumulator update on T_accum**: gradient sign votes accumulate
   ```
   T_accum[i] -= sign(grad_w[i])     # vote against the gradient: -1, 0, or +1
   ```
5. **Threshold-based T flip**:
   ```
   if T_accum[i] > threshold:  T[i] = min(T[i] + 1, +1);  T_accum[i] = 0  (carry reset)
   if T_accum[i] < -threshold: T[i] = max(T[i] - 1, -1);  T_accum[i] = 0
   ```
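The five steps can be sketched end to end on plain Python lists. This is a hedged illustration, not the ARBS code: it assumes a single scale per matrix, votes in the descent direction (`-sign(grad)`, so a positive gradient pushes the weight down), a flip that moves `T` by one trit clamped to [-1, +1], and int8 saturation emulated by clipping to ±127. `ternary_matvec` and `accumulate_and_flip` are hypothetical names.

```python
def sign(x):
    """Return -1, 0 or +1; sign(0) == 0."""
    return (x > 0) - (x < 0)


def ternary_matvec(T, x, scale):
    """Step 1: y = (scale * T) @ x using only adds/subtracts inside each row."""
    y = []
    for row in T:
        acc = 0.0
        for t, xi in zip(row, x):
            if t == 1:
                acc += xi
            elif t == -1:
                acc -= xi          # t == 0 contributes nothing
        y.append(scale * acc)      # one multiply per row for the scale
    return y


def accumulate_and_flip(T, T_accum, grad_W, threshold, clip=127):
    """Steps 4-5: accumulate sign votes, move T one trit past the threshold."""
    for t_row, a_row, g_row in zip(T, T_accum, grad_W):
        for c, g in enumerate(g_row):
            a = max(-clip, min(clip, a_row[c] - sign(g)))  # descent vote, int8-style saturation
            if a > threshold:
                t_row[c] = min(t_row[c] + 1, 1)
                a = 0                                      # carry reset
            elif a < -threshold:
                t_row[c] = max(t_row[c] - 1, -1)
                a = 0
            a_row[c] = a


T = [[0, 1, -1]]
T_accum = [[0, 0, 0]]
grad_W = [[1.0, -1.0, -1.0]]   # same gradient for three steps
for _ in range(3):
    accumulate_and_flip(T, T_accum, grad_W, threshold=2)
# T becomes [[-1, 1, 0]]: the first weight steps down, the second is already
# at +1, the third steps up from -1 to 0; all counters are reset to 0.
```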

---

## 4. What Each Optimizer Contributes

### What SignGSD contributes to ternary training

1. **Sign-only updates** — exactly what T_accum needs.
   - `grad.sign()` produces {-1, 0, +1} which maps 1-to-1 to flip votes.
   - No magnitude information is computed only to be discarded, as it would be if an AdamW step were thresholded into a vote.

2. **Zero optimizer state** — critical at 3B+ parameter scale.
   - AdamW's 8 bytes/param × 3.1B = **24.8 GB** (≈23.1 GiB) of optimizer state.
   - That is about 6× the 4 GB budget for the entire training state.
   - SignGSD adds 0 bytes — all budget goes to T_accum (which is the actual
     accumulator, not optimizer state).
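The memory figures above can be checked with a few lines of arithmetic. The 3.1B parameter count and the 4 GB budget come from the text; 1 byte/weight for `T_accum` and 1.6 bits/weight for packed trits come from the earlier sections; GB here means 10^9 bytes. Activations, gradients and the per-group scales are ignored.

```python
PARAMS = 3.1e9      # model size quoted above
BUDGET_GB = 4.0     # training-state budget quoted above


def gigabytes(bytes_per_param, params=PARAMS):
    """Total size of a per-parameter buffer, in decimal GB."""
    return bytes_per_param * params / 1e9


adamw_state = gigabytes(8)          # exp_avg + exp_avg_sq, float32
t_accum = gigabytes(1)              # one int8 vote counter per weight
packed_trits = gigabytes(1.6 / 8)   # base-3 packing, 5 trits per byte

print(f"AdamW state: {adamw_state:.1f} GB ({adamw_state / BUDGET_GB:.1f}x budget)")
print(f"T_accum:     {t_accum:.1f} GB")
print(f"packed T:    {packed_trits:.2f} GB")
```

`T_accum` plus the packed weights come to about 3.7 GB, which fits the 4 GB budget, whereas the AdamW state alone is roughly 6× over it.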

3. **Uniform LR** — matches the discrete flip domain.
   - Ternary flip decisions are uniform: every flip changes weight by ±1 trit.
   - AdamW's adaptive LR would modulate this uniform vote-counting process,
     adding complexity without clear benefit.

4. **Weight decay via sign(p)** — naturally pushes ternary weights toward zero.
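   - With `wd = 1`: when `sign(grad) == sign(p)`, update = ±2 and the weight is pushed toward zero;
     when `sign(grad) != sign(p)`, update = 0, a dead zone with no movement.
   - For general `wd`, the step is `lr·(1 + wd)` toward zero in the first case and `lr·(1 - wd)`
     away from zero in the second.
   - This creates a natural sparsity bias: if a nonzero weight's gradient sign is pure noise, the
     expected update is `wd · sign(p)`, so the weight drifts to zero. On ternary values `wd * p`
     already equals `wd * sign(p)` (0 or ±wd), so the two forms differ only on continuous latents:
     `wd * sign(p)` is a constant-magnitude (L1-style) pull, while `wd * p` shrinks proportionally
     (L2-style) and fades near zero. Either way, a ±wd nudge is far smaller than one trit, so it
     only takes effect through an accumulator.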
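The case analysis can be checked directly, along with the noise-driven drift toward zero. This sketch acts on a continuous latent value `p` (not on packed trits), and the coin-flip gradient model in `drift` is an assumption chosen to represent pure noise; names and constants are illustrative.

```python
import random


def sign(x):
    """Return -1, 0 or +1; sign(0) == 0."""
    return (x > 0) - (x < 0)


def signgsd_update(g, p, wd):
    """The SignGSD update term: sign(grad) + wd * sign(p)."""
    return sign(g) + wd * sign(p)


def drift(p, wd, lr, steps, seed=0):
    """Run SignGSD on one latent weight whose gradient sign is a fair coin flip."""
    rng = random.Random(seed)
    for _ in range(steps):
        g = rng.choice((-1.0, 1.0))          # pure noise: no preferred direction
        p -= lr * signgsd_update(g, p, wd)
    return p


# p > 0 with wd = 1: same-sign gradient gives 2 (toward zero), opposite sign gives 0
print(signgsd_update(1.0, 0.8, wd=1.0), signgsd_update(-1.0, 0.8, wd=1.0))  # 2.0 0.0
print(drift(1.0, wd=0.5, lr=0.01, steps=2000))
```

With `wd = 0.5` the expected step is `lr·wd` toward zero, so the initial distance of 1.0 is covered in about 200 steps on average; over 2000 steps the latent weight ends up hovering around zero instead of wandering off.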

### What AdamW contributes (and why it usually fails here)

1. **Momentum** — could smooth noisy sign gradients.
   - Sign gradients are -1 or +1 per element per step (0 only for an exactly zero gradient). This is extremely noisy.
   - Momentum would smooth this into something like `running_mean ≈ E[sign(grad)]`.
   - **But**: momentum state adds 4 bytes/param (exp_avg), 12.4 GB at 3.1B params, which alone is
     about 3× the 4 GB training-state budget.

2. **Adaptive LR** — could help rarely-updated weights.
   - Some weights might never flip because gradient sign oscillates.
   - RMS-normalized gradients would amplify small-magnitude gradients.
   - **But**: second moment adds another 4 bytes/param, and the adaptive scaling
     is designed for continuous domains, not discrete flip decisions.

3. **The core mismatch**: AdamW's output is a continuous step (how much to move).
   Ternary training needs a discrete vote (which direction to flip, if at all).
   Converting continuous → discrete by thresholding (`if abs(step) > 0.5`) loses
   all the careful adaptive scaling AdamW computes.
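A momentum-style middle ground can be sketched as an exponential moving average of `sign(grad)`, thresholded into a ternary decision. This is a hypothetical variant, not something the text adopts; it needs one float per weight (the `exp_avg` cost noted above), and `beta` and `tau` are assumed values. The threshold is exactly the continuous-to-discrete conversion from point 3: the smoothed value's magnitude is thrown away.

```python
def sign(x):
    """Return -1, 0 or +1; sign(0) == 0."""
    return (x > 0) - (x < 0)


def ema_sign_decisions(grads, beta=0.9, tau=0.5):
    """Smooth sign(grad) with an EMA, then threshold it to -1, 0 or +1.

    The result is the smoothed gradient sign; the weight moves opposite to it.
    """
    m = 0.0
    decisions = []
    for g in grads:
        m = beta * m + (1 - beta) * sign(g)                # running estimate of E[sign(grad)]
        decisions.append(sign(m) if abs(m) > tau else 0)   # |m| is discarded here
    return decisions


steady = ema_sign_decisions([1.0] * 10)
noisy = ema_sign_decisions([1.0, -1.0] * 5)
# steady: 0 for the first 6 steps, then 1 once m = 1 - 0.9**n exceeds 0.5
# noisy: all 0, since alternating signs keep |m| at or below about 0.1
```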

### Current ARBS setup: AdamW only on the continuous scales

For the ARBS project's current architecture:
- T_accum (int8, per weight) and E (int8, per group) are both updated from gradients
- The E scale is continuous and CAN use AdamW
- The T_accum votes come from `sign(grad)` of the weight, not from AdamW

The current hybrid:
```
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
```

Here AdamW operates only on the continuous E scales (small, 0.86% of params), while the
sign-based accumulator handles the ternary T weights.
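For the E path, the chain rule is simple when each weight is `W[i] = E[g] * T[i]` for its group `g`: the gradient of E is the ternary-weighted sum of `grad_W` over the group. The sketch below assumes a flat weight vector split into contiguous groups of `group_size`; the `E_accum` stage mentioned above is not modeled, and the function name is hypothetical. Its output is what an AdamW step on E would consume.

```python
def grad_E_from_grad_W(grad_W, T, group_size):
    """dL/dE[g] = sum over i in group g of dL/dW[i] * T[i], since W[i] = E[g] * T[i]."""
    n_groups = (len(T) + group_size - 1) // group_size   # last group may be short
    grad_E = [0.0] * n_groups
    for i, (gw, t) in enumerate(zip(grad_W, T)):
        grad_E[i // group_size] += gw * t   # zero trits contribute nothing
    return grad_E


grad_E = grad_E_from_grad_W([0.5, -1.0, 2.0, 3.0], T=[1, -1, 0, 1], group_size=2)
# group 0: 0.5*1 + (-1.0)*(-1) = 1.5; group 1: 2.0*0 + 3.0*1 = 3.0
```

Because only nonzero trits contribute, a mostly-zero group receives its scale gradient from just a few weights.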

---

## 5. Practical Guidelines

### When to use SignGSD
- Target: ternary/binary weight training
- Memory is the primary constraint (< 8 GB)
- You already have an accumulator pipeline (T_accum + threshold flips)
- Weight updates are naturally uniform (±1 trit flips)

### When to use AdamW
- Target: float32 latent weights or per-group scales (E)
- You need adaptive scaling for rarely-activated weights
- Memory budget is generous (> 32 GB)
- Standard training where weights are continuous

### When to use both (the ARBS approach)
- **T_accum** votes: SignGSD-like `sign(grad)` accumulated as int8
- **E** (per-group scale): AdamW on float32 (continuous, small — 249 MB)
- **T** ternary weights: threshold flips driven by T_accum
- This hybrid splits the problem: discrete path via sign votes, continuous
  path via AdamW on scales

---

## 6. Summary

| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | full grad magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | wd × sign(p) → L1-style pull, sparsity bias | wd × p → decoupled, proportional (L2-style) shrinkage |
| Suitable for T_accum | Direct, 1-to-1 | Needs threshold → loses info |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |