# SignGSD vs AdamW: Optimizer Analysis for Ternary Training
## 1. Algorithm Comparison
### AdamW (Loshchilov & Hutter, 2019)
```
exp_avg = β1·exp_avg + (1-β1)·grad # first moment (momentum)
exp_avg_sq = β2·exp_avg_sq + (1-β2)·grad² # second moment (RMS)
m_hat = exp_avg / (1-β1^t) # bias correction
v_hat = exp_avg_sq / (1-β2^t)
p -= lr · m_hat / (√v_hat + ε) + wd·lr·p # decoupled weight decay
```
State per parameter:
- `exp_avg` (float32): 4 bytes
- `exp_avg_sq` (float32): 4 bytes
- Step counter `t` (int): negligible
- **Total optimizer state: 8 bytes/param**
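To check the update rule and the 8 bytes/param figure, here is a minimal single-tensor AdamW step in plain Python. It is a sketch, not PyTorch's `torch.optim.AdamW`. The helper name `adamw_step` and the default hyperparameters are illustrative assumptions. The two moment buffers are float32 `array('f')` objects, so their per-element size can be measured directly.

```python
from array import array
import math

def adamw_step(p, grad, exp_avg, exp_avg_sq, t,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One AdamW step (hypothetical helper); updates p and both moments in place.

    t is the 1-based step counter used for bias correction.
    """
    bc1 = 1 - beta1 ** t  # bias correction for the first moment
    bc2 = 1 - beta2 ** t  # bias correction for the second moment
    for i, g in enumerate(grad):
        exp_avg[i] = beta1 * exp_avg[i] + (1 - beta1) * g
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g
        m_hat = exp_avg[i] / bc1
        v_hat = exp_avg_sq[i] / bc2
        # Adaptive step plus decoupled weight decay; the right-hand side uses the old p[i].
        p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps) + wd * lr * p[i]

n = 4
p = array('f', [0.5, -0.2, 0.0, 1.0])
exp_avg = array('f', [0.0] * n)     # float32 first moment: 4 bytes/param
exp_avg_sq = array('f', [0.0] * n)  # float32 second moment: 4 bytes/param
adamw_step(p, [0.1, -0.3, 0.0, 2.0], exp_avg, exp_avg_sq, t=1)

state_bytes = exp_avg.itemsize + exp_avg_sq.itemsize  # optimizer state per parameter
print(state_bytes)  # 8
```

On the first step the bias-corrected ratio `m_hat / sqrt(v_hat)` is about `sign(g)`, so every nonzero-gradient weight moves by roughly `lr` regardless of gradient size; the moments only start to differentiate weights over later steps.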
### SignGSD
```
update = sign(grad) + wd · sign(p) # sign of gradient + sign of weight
p -= lr · update # uniform step
```
State per parameter: **0 bytes** (no state whatsoever)
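A minimal plain-Python sketch of the SignGSD rule above. The helper names `sign` and `signgsd_step` and the default `lr`/`wd` values are illustrative assumptions. Nothing is stored between calls, which is where the 0 bytes/param figure comes from.

```python
def sign(x):
    """Return -1, 0, or +1 (sign(0) == 0)."""
    return (x > 0) - (x < 0)

def signgsd_step(p, grad, lr=0.01, wd=0.1):
    """One SignGSD step, in place. No optimizer state is read or written."""
    for i, g in enumerate(grad):
        update = sign(g) + wd * sign(p[i])  # sign of gradient + sign of weight
        p[i] -= lr * update                 # fixed-size step; gradient magnitude is ignored

p = [0.5, -0.5, 0.0, 0.3]
signgsd_step(p, [2.0, 1e-6, -3.0, -0.4])
print(p)
```

The gradient `1e-6` on the second weight contributes exactly as much to its step as `2.0` does on the first; only the weight-decay term makes the two step sizes differ (0.009 vs 0.011).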
### Key differences table
| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | `wd * lr * p` (decoupled) | `wd * sign(p)` (coupled to update) |
| LR schedule | Typically used (cosine, linear) | Optional (step is ±lr per element, apart from decay) |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |
---
## 2. What is Ternary Training?
Ternary weights are constrained to **{-1, 0, +1}**. This gives three benefits:
- **Storage**: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
- **Compute**: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
- **Memory bandwidth**: ~1/20 of float32 weight traffic (1.6 vs 32 bits per weight; ~1/16 with simpler 2-bit packing)
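The 5-trits-per-byte figure follows from 3^5 = 243 ≤ 256: five ternary digits fit in one byte as a base-3 number, giving 8/5 = 1.6 bits per weight. The sketch below packs and unpacks trits that way and computes a dot product using only additions and subtractions. The function names and the digit order (first trit least significant) are assumptions for illustration, not a description of any specific packed format.

```python
def pack_trits(trits):
    """Pack {-1, 0, +1} values 5 per byte (base 3, first trit least significant)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        value = 0
        for t in reversed(trits[start:start + 5]):
            value = value * 3 + (t + 1)  # map -1, 0, +1 to digits 0, 1, 2
        out.append(value)  # at most 3**5 - 1 = 242, so it fits in a byte
    return bytes(out)

def unpack_trits(packed, n):
    """Inverse of pack_trits; n is the original number of trits."""
    trits = []
    for byte in packed:
        for _ in range(5):
            trits.append(byte % 3 - 1)
            byte //= 3
    return trits[:n]  # drop padding digits from a partial last byte

def ternary_dot(trits, x):
    """Dot product with ternary weights: add, subtract, or skip. No multiplications."""
    acc = 0.0
    for t, xi in zip(trits, x):
        if t == 1:
            acc += xi
        elif t == -1:
            acc -= xi
    return acc

w = [1, -1, 0, 0, 1, -1, 1]
packed = pack_trits(w)
print(len(packed))  # 2 bytes for 7 trits
print(unpack_trits(packed, len(w)) == w)  # True
print(ternary_dot(w, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]))  # 5.0
```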
The fundamental challenge: **continuous optimizers cannot directly update ternary weights**.
```
w ∈ {-1, 0, +1}
w -= lr * grad → 0.7, -0.3, 1.2 ... NOT ternary!
```
Every optimizer step must eventually produce a flip decision (±1 change or stay),
not a continuous displacement.
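A quick way to see the problem: the snippet below applies a plain continuous step and then projects back onto the ternary set by rounding and clipping. The `to_ternary` helper is hypothetical and stands only for this naive workaround. With a small `lr`, each displacement either stays within the 0.5 rounding margin or pushes past ±1, so the projected weights come back unchanged.

```python
def to_ternary(x):
    """Project a real value onto {-1, 0, +1} by rounding and clipping."""
    return max(-1, min(1, round(x)))

w = [1, 0, -1, 1]
grad = [0.7, -0.3, 2.0, -1.2]
lr = 0.1

continuous = [wi - lr * gi for wi, gi in zip(w, grad)]
print(continuous)                           # off the ternary grid
print([to_ternary(c) for c in continuous])  # [1, 0, -1, 1]: identical to w
```

Because the rounded result equals the starting weight, the step is lost entirely; something has to remember the sub-threshold evidence across steps.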
---
## 3. The Two-Layer Architecture of Ternary Training
All practical ternary training uses a **latent/accumulator layer** between the
optimizer and the actual ternary weights:
```
Optimizer (SignGSD or AdamW)
↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: int8 per-weight vote counter)
↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
```
### Flow detail
1. **Forward**: unpack T → matmul(W = S × T, x) → loss
   - T is ternary; S is a per-weight or per-group scale.
2. **Backward**: the gradient w.r.t. W flows to both the S and T paths.
   - The gradient through T needs special handling (T is discrete, so not differentiable).
3. **Optimizer step on S** (if per-param): standard optimizer (AdamW / SignGSD).
   - S is continuous (float/int), so no special treatment is needed.
4. **Accumulator update on T_accum**: gradient sign votes accumulate in the descent direction (against the gradient, matching SignGSD's `p -= lr · sign(grad)`):
   ```
   T_accum[i] -= sign(grad_w[i])   # vote: +1 or -1 (0 if the gradient is exactly 0)
   ```
5. **Threshold-based T flip**:
   ```
   if T_accum[i] > threshold:  T[i] = +1; T_accum[i] = 0   (carry reset)
   if T_accum[i] < -threshold: T[i] = -1; T_accum[i] = 0
   ```
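Steps 4 and 5 can be sketched in plain Python as follows. Assumptions not fixed by the text: `T_accum` saturates at the int8 range [-128, 127], and `threshold` is a free hyperparameter (8 here). The vote is `-sign(grad)`, so a persistently positive gradient drives a weight toward -1, which is the gradient-descent direction. The helper name `accumulate_and_flip` is hypothetical.

```python
INT8_MIN, INT8_MAX = -128, 127

def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def accumulate_and_flip(T, T_accum, grad_w, threshold=8):
    """One accumulator step over a flat list of ternary weights.

    T and T_accum are updated in place; returns the number of threshold crossings.
    """
    flips = 0
    for i, g in enumerate(grad_w):
        # Step 4: one vote against the gradient, clamped to the int8 range.
        T_accum[i] = max(INT8_MIN, min(INT8_MAX, T_accum[i] - sign(g)))
        # Step 5: threshold-based flip with carry reset.
        if T_accum[i] > threshold:
            T[i], T_accum[i] = 1, 0
            flips += 1
        elif T_accum[i] < -threshold:
            T[i], T_accum[i] = -1, 0
            flips += 1
    return flips

T = [0, 0, 1]
T_accum = [0, 0, 0]
total = 0
for step in range(9):
    # Weight 0 always sees a negative gradient, weight 1 a positive one,
    # weight 2 a gradient whose sign alternates every step.
    total += accumulate_and_flip(T, T_accum, [-0.5, 0.2, (-1) ** step])
print(T, T_accum, total)  # [1, -1, 1] [0, 0, -1] 2
```

The third weight never flips: its alternating votes keep cancelling in the counter, which is the oscillating-gradient case discussed in Section 4.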
---
## 4. What Each Optimizer Contributes
### What SignGSD contributes to ternary training
1. **Sign-only updates**: exactly what T_accum needs.
   - `grad.sign()` produces {-1, 0, +1}, which maps 1-to-1 to flip votes.
   - No magnitude information has to be discarded afterwards, as it would be when thresholding AdamW's continuous update.
2. **Zero optimizer state**: critical at 3B+ parameter scale.
   - AdamW's 8 bytes/param × 3.1B params ≈ **24.8 GB** of optimizer state.
   - That is roughly 6× the entire training-state budget (4 GB).
   - SignGSD adds 0 bytes, so the whole budget goes to T_accum (which is the actual accumulator, not optimizer state).
3. **Uniform LR**: matches the discrete flip domain.
   - Ternary flip decisions are uniform: every flip changes a weight by ±1 trit.
   - AdamW's adaptive LR would modulate this uniform vote-counting process, adding complexity without clear benefit.
4. **Weight decay via sign(p)**: naturally pushes ternary weights toward zero.
   - When `sign(grad) == sign(p)`, both terms point the same way: |update| = 1 + wd (= 2 for wd = 1), and the weight is pushed toward zero, or past it if lr is large enough.
   - When `sign(grad) != sign(p)`, the terms partly cancel: |update| = 1 - wd (for 0 ≤ wd ≤ 1), which is exactly 0 for wd = 1, a dead zone with no movement.
   - This creates a natural sparsity bias: noisy gradients on a nonzero weight move it toward zero whenever they agree with its sign and barely move it otherwise, so over time they push it to zero. Standard weight decay `wd * p` loses its meaning on {-1, 0, +1}: it takes only the values 0, -wd, or +wd, i.e. it collapses to `wd * sign(p)` and is no longer proportional to a weight magnitude.
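To make point 4 concrete, the snippet below enumerates the SignGSD update `sign(grad) + wd * sign(p)` for each sign combination. The two `wd` values are arbitrary choices for illustration, and `signgsd_update` is a hypothetical helper. With `wd = 1` the disagreeing case cancels exactly (the dead zone); with `wd = 0.5` it only shrinks to a step of magnitude 0.5.

```python
def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def signgsd_update(grad, p, wd):
    """Per-element SignGSD update; the applied step is p -= lr * update."""
    return sign(grad) + wd * sign(p)

for wd in (1.0, 0.5):
    print(f"wd = {wd}")
    for p in (1.0, -1.0):
        for grad in (0.3, -0.3):
            u = signgsd_update(grad, p, wd)
            print(f"  sign(p)={sign(p):+d}  sign(grad)={sign(grad):+d}  update={u:+.1f}")
```

Since the step is `p -= lr * update`, the ±2 cases (for `wd = 1`) always move the weight toward zero, and a weight that is already zero (`sign(0) = 0`) receives no decay at all.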
### What AdamW contributes (and why it usually fails here)
1. **Momentum**: could smooth noisy sign gradients.
   - Sign gradients are +1 or -1 per element per step, which is extremely noisy.
   - Momentum would smooth this into something like `running_mean ≈ E[sign(grad)]`.
   - **But**: the momentum buffer (`exp_avg`) adds 4 bytes/param, about 12.4 GB at 3.1B params, already more than 3× the 4 GB training-state budget.
2. **Adaptive LR**: could help rarely-updated weights.
   - Some weights might never flip because their gradient sign oscillates.
   - RMS-normalized gradients would amplify small-magnitude gradients.
   - **But**: the second moment adds another 4 bytes/param, and the adaptive scaling is designed for continuous domains, not discrete flip decisions.
3. **The core mismatch**: AdamW's output is a continuous step (how much to move), while ternary training needs a discrete vote (which direction to flip, if at all). Converting continuous → discrete by thresholding (`if abs(step) > 0.5`) discards the careful adaptive scaling AdamW computes, keeping only whether the step crossed the threshold.
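The momentum point can be illustrated with a deterministic stand-in for a noisy vote stream in which 7 of every 10 sign gradients are +1, so E[sign(grad)] = 0.4. An exponential moving average (the `exp_avg` recurrence, with `beta1 = 0.99` chosen for illustration) settles near 0.4, but it has to be stored as a float per weight.

```python
from array import array

# 7 of every 10 votes are +1, so E[sign(grad)] = (7 - 3) / 10 = 0.4.
votes = [1, -1, 1, 1, -1, 1, 1, -1, 1, 1] * 200  # 2000 steps
beta1 = 0.99

exp_avg = 0.0  # momentum buffer for a single weight
vote_sum = 0   # plain running sum of votes (no threshold reset here)
for v in votes:
    exp_avg = beta1 * exp_avg + (1 - beta1) * v
    vote_sum += v

print(exp_avg)              # settles close to 0.4
print(vote_sum)             # 800
print(array('f').itemsize)  # 4: bytes per weight for a float32 momentum buffer
```

The EMA recovers the mean vote, but the plain integer sum already shows the same positive drift; in the accumulator pipeline, the threshold reset is what keeps that sum inside int8.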
### Currently: AdamW only for the float32 latent scales
For the ARBS project's current architecture:
- Both T_accum (int8, per weight) and E (int8, per group) are updated from gradients
- The E scale is continuous and CAN use AdamW
- The T_accum votes come from `sign(grad)` of the weight, not from AdamW
The current hybrid:
```
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
```
Here AdamW operates only on the continuous E scales (small, 0.86% of params),
while the sign-based accumulator handles the ternary T weights.
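The `grad_E from chain rule` step can be made concrete. Assume the effective weight is `W[i] = E[g(i)] * T[i]`, where `g(i)` is the group of weight `i`; the contiguous grouping and tiny group size below are illustrative assumptions. The chain rule then gives `dL/dE[g] = sum over i in g of grad_W[i] * T[i]`, while the ternary path keeps only one sign vote per weight (taken against the gradient, assuming E > 0). The hypothetical `hybrid_grads` helper below computes both halves of the split; it stops short of the AdamW step itself.

```python
def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def hybrid_grads(grad_W, T, group_size):
    """Split grad_W into (ternary votes, per-group scale gradients).

    Assumes W[i] = E[i // group_size] * T[i] with contiguous groups.
    """
    # Discrete path: one sign vote per weight, in the descent direction.
    votes = [-sign(g) for g in grad_W]
    # Continuous path: chain rule through W = E * T, summed within each group.
    n_groups = (len(grad_W) + group_size - 1) // group_size
    grad_E = [0.0] * n_groups
    for i, (g, t) in enumerate(zip(grad_W, T)):
        grad_E[i // group_size] += g * t
    return votes, grad_E

grad_W = [0.5, -0.25, 1.0, 0.0, -2.0, 0.75]
T = [1, -1, 0, 1, -1, 1]
votes, grad_E = hybrid_grads(grad_W, T, group_size=3)
print(votes)   # [-1, 1, -1, 0, 1, -1]
print(grad_E)  # [0.75, 2.75]
```

Only `grad_E` would be handed to AdamW, so its 8 bytes/param of state scales with the number of groups rather than the number of weights.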
---
## 5. Practical Guidelines
### When to use SignGSD
- Target: ternary/binary weight training
- Memory is the primary constraint (< 8 GB)
- You already have an accumulator pipeline (T_accum + threshold flips)
- Weight updates are naturally uniform (±1 trit flips)
### When to use AdamW
- Target: float32 latent weights or per-group scales (E)
- You need adaptive scaling for rarely-activated weights
- Memory budget is generous (> 32 GB)
- Standard training where weights are continuous
### When to use both (the ARBS approach)
- **T_accum** votes: SignGSD-like `sign(grad)` accumulated as int8
- **E** (per-group scale): AdamW on float32 (continuous, small — 249 MB)
- **T** ternary weights: threshold flips driven by T_accum
- This hybrid splits the problem: discrete path via sign votes, continuous
path via AdamW on scales
---
## 6. Summary
| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | full grad magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | sign(p) → sparsity bias | wd × p → decoupled decay (proportional shrinkage) |
| Suitable for T_accum | Direct, 1-to-1 | Needs threshold → loses info |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |