# SignGSD vs AdamW: Optimizer Analysis for Ternary Training
## 1. Algorithm Comparison
### AdamW (Loshchilov & Hutter, 2019)
```
exp_avg = β1·exp_avg + (1-β1)·grad # first moment (momentum)
exp_avg_sq = β2·exp_avg_sq + (1-β2)·grad² # second moment (RMS)
m_hat = exp_avg / (1-β1^t) # bias correction
v_hat = exp_avg_sq / (1-β2^t)
p -= lr·m_hat / (√v_hat + ε) + wd·lr·p # decoupled weight decay
```
State per parameter:
- `exp_avg` (float32): 4 bytes
- `exp_avg_sq` (float32): 4 bytes
- Step counter `t` (int): negligible
- **Total optimizer state: 8 bytes/param**
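To make the update above concrete, here is a minimal pure-Python sketch of one AdamW step over flat lists of floats. The function name `adamw_step` and its default hyperparameters are illustrative assumptions, not taken from ARBS or from any library.

```python
import math


def adamw_step(p, grad, exp_avg, exp_avg_sq, t, lr=1e-3,
               beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One AdamW step over flat lists; mutates p, exp_avg and exp_avg_sq.

    t is the 1-based step count used for bias correction.
    """
    bc1 = 1 - beta1 ** t
    bc2 = 1 - beta2 ** t
    for i, g in enumerate(grad):
        exp_avg[i] = beta1 * exp_avg[i] + (1 - beta1) * g            # first moment
        exp_avg_sq[i] = beta2 * exp_avg_sq[i] + (1 - beta2) * g * g  # second moment
        m_hat = exp_avg[i] / bc1
        v_hat = exp_avg_sq[i] / bc2
        # Adaptive step plus decoupled decay (both computed from the old p[i]).
        p[i] -= lr * m_hat / (math.sqrt(v_hat) + eps) + wd * lr * p[i]


p = [1.0, -2.0]
exp_avg, exp_avg_sq = [0.0, 0.0], [0.0, 0.0]   # the 8 bytes/param of state
adamw_step(p, [0.5, -0.1], exp_avg, exp_avg_sq, t=1)
```

On the first step (`t=1`) bias correction makes `m_hat == grad` and `v_hat == grad²`, so each parameter moves by almost exactly `lr` in the direction of `-sign(grad)`, plus the decay term, whatever the gradient's scale.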
### SignGSD
```
update = sign(grad) + wd·sign(p) # sign of gradient + sign of weight
p -= lr·update # uniform step
```
State per parameter: **0 bytes** (no state whatsoever)
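The same step for SignGSD, sketched under the same assumptions (flat lists, illustrative names `sign` and `signgsd_step`): nothing besides `p` itself is read or written between steps.

```python
def sign(x):
    """Return -1, 0 or +1 for a number x."""
    return (x > 0) - (x < 0)


def signgsd_step(p, grad, lr=1e-3, wd=0.0):
    """One SignGSD step over a flat list; mutates p and keeps no other state."""
    for i, g in enumerate(grad):
        update = sign(g) + wd * sign(p[i])   # direction only, plus sign-based decay
        p[i] -= lr * update                  # same step size for every element


p = [0.5, -0.5, 0.0]
signgsd_step(p, [3.0, 1e-6, -2.0], lr=0.1)   # wildly different gradient scales
```

Because only `sign(g)` enters the update, gradients of 3.0 and 1e-6 produce identical steps of size `lr`.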
### Key differences table
| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | `wd * lr * p` (decoupled) | `wd * sign(p)` (coupled to update) |
| LR schedule | Typically used (cosine, linear) | Not needed (step size = ±lr per element, plus decay) |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |
---
## 2. What is Ternary Training?
Ternary weights are constrained to **{-1, 0, +1}**. This gives three benefits:
- **Storage**: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
- **Compute**: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
- **Memory bandwidth**: ~1/20 of float32 weights at 1.6 bits/weight (~1/16 with simpler 2-bit packing)
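The 5-trits-per-byte packing works because 3^5 = 243 ≤ 256. Below is a minimal sketch of one possible encoding; the digit order and the -1/0/+1 → 0/1/2 mapping are assumptions, and ARBS may lay trits out differently.

```python
def pack_trits(trits):
    """Pack values in {-1, 0, +1} into bytes, 5 trits per byte (3**5 = 243 <= 256)."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        value = 0
        # Reversed so the first trit of each chunk is the least significant digit.
        for t in reversed(trits[start:start + 5]):
            value = value * 3 + (t + 1)   # map -1/0/+1 to base-3 digit 0/1/2
        out.append(value)
    return bytes(out)


def unpack_trits(data, n):
    """Inverse of pack_trits; n is the original number of trits."""
    trits = []
    for byte in data:
        for _ in range(5):
            trits.append(byte % 3 - 1)    # digit 0/1/2 back to -1/0/+1
            byte //= 3
    return trits[:n]                      # drop padding from a short final chunk


weights = [1, 0, -1, -1, 1, 0, 1]
packed = pack_trits(weights)              # 7 trits -> 2 bytes
```

At 8 bits per 5 trits this is 1.6 bits/weight, within about 1% of the log2(3) ≈ 1.585 bits a trit actually carries.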
The fundamental challenge: **continuous optimizers cannot directly update ternary weights**.
```
w ∈ {-1, 0, +1}
w -= lr * grad → 0.7, -0.3, 1.2 ... NOT ternary!
```
Every optimizer step must eventually produce a flip decision (±1 change or stay),
not a continuous displacement.
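One way to see the problem: if each continuous step is rounded straight back to {-1, 0, +1}, any step smaller than 0.5 is undone. A minimal sketch, assuming round-to-nearest as the projection and a constant gradient for simplicity (`project_to_ternary` is an illustrative name):

```python
def project_to_ternary(w):
    """Round to the nearest value in {-1, 0, +1}."""
    return max(-1, min(1, round(w)))


w = [1, 0, -1]
grad = [0.3, -0.8, 0.5]   # constant gradient, for illustration only
lr = 0.01

for _ in range(1000):
    # Continuous step, then projection back onto the ternary grid.
    w = [project_to_ternary(wi - lr * gi) for wi, gi in zip(w, grad)]

print(w)  # [1, 0, -1]: no step ever exceeds 0.5, so every update is rounded away
```

Accumulating the steps elsewhere and flipping only when the total crosses a threshold, as the accumulator layer in Section 3 does, lets many small steps add up instead of vanishing.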
---
## 3. The Two-Layer Architecture of Ternary Training
All practical ternary training uses a **latent/accumulator layer** between the
optimizer and the actual ternary weights:
```
Optimizer (SignGSD or AdamW)
↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: int8 per-weight vote counter)
↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
```
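The forward path through this stack needs no multiplications inside the dot product itself. A minimal sketch, assuming one scale per output row (one possible grouping; the text allows per-weight or per-group scales) and T stored unpacked as Python ints; the function names are illustrative:

```python
def ternary_dot(t_row, x):
    """Dot product of a ternary row with x using only additions and subtractions."""
    acc = 0.0
    for t, xi in zip(t_row, x):
        if t == 1:
            acc += xi
        elif t == -1:
            acc -= xi
        # t == 0 contributes nothing
    return acc


def ternary_matvec(T, S, x):
    """y = (S x T) @ x with one scale per row: add/sub first, one multiply per row."""
    return [s * ternary_dot(t_row, x) for t_row, s in zip(T, S)]


T = [[1, 0, -1],
     [-1, 1, 1]]
S = [0.5, 2.0]
y = ternary_matvec(T, S, [2.0, 3.0, 4.0])
```

Factoring the scale out of the row means a row of length n costs n additions/subtractions plus a single multiply.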
### Flow detail
1. **Forward**: unpack T → matmul(W = S × T, x) → loss
- T is ternary, S is a per-weight or per-group scale
2. **Backward**: grad w.r.t. W flows to both S and T paths
- Gradient through T needs special handling (discrete, not differentiable)
3. **Optimizer step on S** (if per-param): standard optimizer (AdamW / SignGSD)
- S is continuous (float/int), so it needs no special treatment
4. **Accumulator update on T_accum**: gradient sign votes accumulate in the
descent direction (opposite to the gradient)
```
T_accum[i] -= sign(grad_w[i]) # vote: -1, 0 or +1
```
5. **Threshold-based T flip** (one trit per flip, then carry reset):
```
if T_accum[i] > threshold:  T[i] = min(T[i] + 1, +1); T_accum[i] = 0
if T_accum[i] < -threshold: T[i] = max(T[i] - 1, -1); T_accum[i] = 0
```
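Steps 4 and 5 together can be sketched as follows. Assumptions beyond the text: the counter saturates at the int8 range, votes are cast in the descent direction (`-sign(grad)`) so that a persistently positive gradient lowers the weight, and each flip moves T by one trit, matching the ±1-trit flips described in Section 4. The function name `accumulate_and_flip` is hypothetical.

```python
INT8_MIN, INT8_MAX = -128, 127


def sign(x):
    """Return -1, 0 or +1 for a number x."""
    return (x > 0) - (x < 0)


def accumulate_and_flip(T, T_accum, grad_w, threshold):
    """One vote-and-flip step over flat lists; mutates T and T_accum in place."""
    for i, g in enumerate(grad_w):
        # Descent-direction vote, saturated so the counter stays within int8.
        T_accum[i] = max(INT8_MIN, min(INT8_MAX, T_accum[i] - sign(g)))
        if T_accum[i] > threshold:
            T[i] = min(T[i] + 1, 1)     # one trit up: -1 -> 0 -> +1
            T_accum[i] = 0              # carry reset
        elif T_accum[i] < -threshold:
            T[i] = max(T[i] - 1, -1)    # one trit down: +1 -> 0 -> -1
            T_accum[i] = 0


T = [1, -1, 0]
T_accum = [0, 0, 0]
grad_w = [0.2, -0.7, 0.0]   # the same gradient every step, for illustration
for _ in range(4):
    accumulate_and_flip(T, T_accum, grad_w, threshold=3)
```

After four consistent votes the first two weights cross the threshold once and step to 0; four more votes carry them on to -1 and +1, while the weight with a zero gradient never moves.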
---
## 4. What Each Optimizer Contributes
### What SignGSD contributes to ternary training
1. **Sign-only updates**: exactly what T_accum needs.
- `grad.sign()` produces {-1, 0, +1}, which maps 1-to-1 to flip votes.
- There is no magnitude to discard, unlike AdamW's continuous update.
2. **Zero optimizer state**: critical at 3B+ parameter scale.
- AdamW's 8 bytes/param × 3.1B ≈ **24.8 GB** of optimizer state.
- That alone is about 6× the entire 4 GB training-state budget.
- SignGSD adds 0 bytes, so the whole budget goes to T_accum (which is the actual
accumulator, not optimizer state).
3. **Uniform LR**: matches the discrete flip domain.
- Ternary flip decisions are uniform: every flip changes a weight by ±1 trit.
- AdamW's adaptive LR would modulate this uniform vote-counting process,
adding complexity without clear benefit.
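4. **Weight decay via sign(p)**: naturally pushes ternary weights toward zero.
- With wd = 1 and p ≠ 0: when `sign(grad) == sign(p)`, update = ±2, so
`p -= lr·update` moves the weight toward (or past) zero.
- With wd = 1 and p ≠ 0: when `sign(grad) != sign(p)`, update = 0, a dead zone
with no vote toward a larger magnitude.
- This creates a natural sparsity bias: for nonzero weights, noisy gradient votes
toward zero count double while votes away from zero cancel, so such weights
eventually drift to zero. Standard weight decay `wd * p` adds nothing on
{-1,0,+1}: there `wd * p` takes only the values 0, -wd and +wd, i.e. it equals
`wd * sign(p)`, and a fractional ±wd step cannot be stored in T directly.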
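The nonzero-weight cases in point 4 can be checked directly. This sketch evaluates the SignGSD update `sign(grad) + wd·sign(p)` with wd = 1, the value the ±2 / 0 cases assume (the function name `signgsd_update` is illustrative):

```python
def sign(x):
    """Return -1, 0 or +1 for a number x."""
    return (x > 0) - (x < 0)


def signgsd_update(grad, p, wd=1.0):
    """SignGSD update before scaling by lr: sign(grad) + wd * sign(p)."""
    return sign(grad) + wd * sign(p)


for p in (-1, 0, 1):
    for grad in (-0.5, 0.5):
        print(f"p={p:+d}  grad={grad:+.1f}  update={signgsd_update(grad, p):+.0f}")
```

Since the step is `p -= lr·update`, the ±2 cases always shrink |p| and the 0 cases leave it alone; only p = 0 follows the gradient freely with update ±1.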
### What AdamW contributes (and why it usually fails here)
1. **Momentum**: could smooth noisy sign gradients.
- Sign gradients are +1 or -1 (rarely 0) per element per step, which is extremely noisy.
- Momentum would smooth this into something like `running_mean ≈ E[sign(grad)]`.
- **But**: momentum state adds 4 bytes/param (exp_avg), about 12.4 GB at 3.1B
parameters, which alone is roughly 3× the 4 GB training-state budget.
2. **Adaptive LR**: could help rarely-updated weights.
- Some weights might never flip because their gradient sign oscillates.
- RMS-normalized gradients would amplify small-magnitude gradients.
- **But**: second moment adds another 4 bytes/param, and the adaptive scaling
is designed for continuous domains, not discrete flip decisions.
3. **The core mismatch**: AdamW's output is a continuous step (how much to move).
Ternary training needs a discrete vote (which direction to flip, if at all).
Converting continuous → discrete by thresholding (`if abs(step) > 0.5`) keeps
only which side of the threshold each step lands on, discarding the adaptive
scaling AdamW computes.
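The momentum argument in point 1 can be illustrated without randomness: feed an EMA (the same recurrence as AdamW's `exp_avg`) a gradient stream whose sign is +1 three steps in four, so E[sign(grad)] = 0.5. The gradient values, `beta`, and the function name are illustrative.

```python
def sign(x):
    """Return -1, 0 or +1 for a number x."""
    return (x > 0) - (x < 0)


def ema_of_signs(grads, beta=0.9):
    """EMA of sign(grad): the exp_avg recurrence applied to sign votes."""
    m = 0.0
    for g in grads:
        m = beta * m + (1 - beta) * sign(g)
    return m


# Deterministic stand-in for noise: the sign is +1 three steps out of four,
# so each raw vote is +/-1 while E[sign(grad)] = 0.5.
grads = [0.3, 0.1, -0.2, 0.4] * 250
smoothed = ema_of_signs(grads)
print(f"smoothed vote: {smoothed:.2f}")
```

The smoothed value stays close to 0.5 even though every individual vote is ±1, which is the benefit point 1 describes; the price is one extra float (4 bytes) per parameter to hold `m`.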
### Current ARBS setup: AdamW only for the float32 latent scales
For the ARBS project's current architecture:
- Two quantities receive gradient-driven updates: T_accum (int8, per weight) and E (int8, per group)
- The E scale is continuous and CAN use AdamW
- The T_accum votes come from `sign(grad)` of the weight, not from AdamW
The current hybrid:
```
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
```
Here AdamW operates only on the continuous E scales (small: 0.86% of params),
while the sign-based accumulator handles the ternary T weights.
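A sketch of one hybrid step for a single group, following the three-line flow above. Assumptions: the group's weights are W[i] = E·T[i] with E a single float (so the chain rule gives dL/dE = Σ grad_W[i]·T[i]), the E_accum stage and weight decay on E are omitted, and votes use the descent direction `-sign(grad)`. All names are hypothetical.

```python
import math


def sign(x):
    """Return -1, 0 or +1 for a number x."""
    return (x > 0) - (x < 0)


def hybrid_step(T, T_accum, E, E_state, grad_W, t, lr=1e-3,
                beta1=0.9, beta2=0.999, eps=1e-8):
    """One hybrid update for a group with W[i] = E * T[i].

    Mutates T_accum (int8 votes) and E_state ([m, v] Adam moments) in place;
    returns the updated scale E. t is the 1-based step count.
    """
    # Discrete path: descent-direction sign votes, saturated to int8.
    for i, g in enumerate(grad_W):
        T_accum[i] = max(-128, min(127, T_accum[i] - sign(g)))
    # Continuous path: chain rule through W = E * T.
    grad_E = sum(g * t_i for g, t_i in zip(grad_W, T))
    m, v = E_state
    m = beta1 * m + (1 - beta1) * grad_E
    v = beta2 * v + (1 - beta2) * grad_E ** 2
    E_state[:] = [m, v]
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    return E - lr * m_hat / (math.sqrt(v_hat) + eps)


T = [1, -1, 0, 1]
T_accum = [0, 0, 0, 0]
E_state = [0.0, 0.0]   # Adam moments for this group's single scale
E = hybrid_step(T, T_accum, 1.0, E_state, [0.5, 0.2, -0.1, 0.1], t=1)
```

Only the group's single E carries Adam state here, so the 8 bytes of optimizer state are paid per group rather than per weight.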
---
## 5. Practical Guidelines
### When to use SignGSD
- Target: ternary/binary weight training
- Memory is the primary constraint (< 8 GB)
- You already have an accumulator pipeline (T_accum + threshold flips)
- Weight updates are naturally uniform (±1 trit flips)
### When to use AdamW
- Target: float32 latent weights or per-group scales (E)
- You need adaptive scaling for rarely-activated weights
- Memory budget is generous (> 32 GB)
- Standard training where weights are continuous
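The memory thresholds in these guidelines follow from simple arithmetic. A sketch using the 3.1B-parameter count from Section 4 (decimal GB; the T_accum row is the int8 accumulator, listed for comparison even though it is not optimizer state; `state_gb` is an illustrative name):

```python
def state_gb(n_params, bytes_per_param):
    """Size in decimal gigabytes (1 GB = 1e9 bytes)."""
    return n_params * bytes_per_param / 1e9


N_PARAMS = 3.1e9   # parameter count quoted in Section 4
rows = [
    ("AdamW: exp_avg + exp_avg_sq", 8),
    ("momentum only: exp_avg", 4),
    ("SignGSD", 0),
    ("T_accum int8 (accumulator)", 1),
]
for name, bytes_per_param in rows:
    print(f"{name:30s} {state_gb(N_PARAMS, bytes_per_param):5.1f} GB")
```

AdamW's state alone (about 24.8 GB) accounts for the generous-budget requirement, while the int8 T_accum (about 3.1 GB) fits under the 4 GB training-state budget.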
### When to use both (the ARBS approach)
- **T_accum** votes: SignGSD-like `sign(grad)` accumulated as int8
- **E** (per-group scale): AdamW on float32 (continuous and small, 249 MB)
- **T** ternary weights: threshold flips driven by T_accum
- This hybrid splits the problem: discrete path via sign votes, continuous
path via AdamW on scales
---
## 6. Summary
| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | full grad magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | sign(p) → sparsity bias | decoupled wd × p → shrinks weights toward 0 |
| Suitable for T_accum | Direct, 1-to-1 | Needs threshold → loses info |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |