SignGSD vs AdamW — Optimizer Analysis for Ternary Training
1. Algorithm Comparison
AdamW (Loshchilov & Hutter, 2019)
exp_avg = β1·exp_avg + (1-β1)·grad # first moment (momentum)
exp_avg_sq = β2·exp_avg_sq + (1-β2)·grad² # second moment (RMS)
m_hat = exp_avg / (1-β1^t) # bias correction
v_hat = exp_avg_sq / (1-β2^t)
p -= lr · m_hat / (√v_hat + ε) + wd·lr·p # decoupled weight decay
State per parameter:
- exp_avg (float32): 4 bytes
- exp_avg_sq (float32): 4 bytes
- Step counter t (int): negligible
- Total optimizer state: 8 bytes/param
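The update rule above can be traced by hand for a single scalar parameter. The following is a minimal sketch, not the PyTorch implementation; the function name `adamw_step` and the default hyperparameters are illustrative assumptions.

```python
import math

def adamw_step(p, grad, exp_avg, exp_avg_sq, t,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """One AdamW update for a scalar parameter; t is the 1-based step count."""
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad               # first moment
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad  # second moment
    m_hat = exp_avg / (1 - beta1 ** t)                           # bias correction
    v_hat = exp_avg_sq / (1 - beta2 ** t)
    # Adaptive step plus decoupled weight decay, both subtracted from p.
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps) - lr * wd * p
    return p, exp_avg, exp_avg_sq

# The two returned moments are the 8 bytes/param of state listed above.
p, m, v = adamw_step(p=1.0, grad=0.5, exp_avg=0.0, exp_avg_sq=0.0, t=1)
```

On the first step the bias correction makes m_hat equal to the raw gradient and √v_hat equal to its absolute value, so the adaptive part of the step has magnitude close to lr regardless of the gradient scale.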
SignGSD
update = sign(grad) + wd · sign(p) # sign of gradient + sign of weight
p -= lr · update # uniform step
State per parameter: 0 bytes (no state whatsoever)
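For comparison, the stateless update can be sketched in a few lines. This is an illustration under assumptions: the helper names `sign` and `signgsd_step` and the default values are hypothetical, and parameters are plain Python floats rather than tensors.

```python
def sign(x):
    """Return -1, 0, or +1 according to the sign of x."""
    return (x > 0) - (x < 0)

def signgsd_step(params, grads, lr=0.01, wd=0.0):
    """Stateless sign update applied elementwise to lists of floats."""
    return [p - lr * (sign(g) + wd * sign(p)) for p, g in zip(params, grads)]

# Gradient magnitudes 3.0 and 0.001 produce steps of the same size.
new_params = signgsd_step([0.5, -0.2, 0.0], [3.0, -0.001, 0.0], lr=0.1)
```

Nothing is carried between calls, which is what the 0 bytes/param figure refers to.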
Key differences table
| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | wd * lr * p (decoupled) | wd * sign(p) (coupled to update) |
| LR schedule | Commonly used (cosine, linear) | Often omitted (step size is a fixed multiple of lr) |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |
2. What is Ternary Training?
Ternary weights are constrained to {-1, 0, +1}. This gives three benefits:
- Storage: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
- Compute: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
- Memory bandwidth: ~1/20× vs float32 weights (1.6 bits vs 32 bits per weight)
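The 1.6 bits/weight figure follows from 3^5 = 243 ≤ 256: five trits fit in one byte, and 8 / 5 = 1.6. A minimal sketch of such a packing is shown here; the function names and the digit order (first trit in the least significant base-3 digit) are assumptions, not a description of any particular on-disk format.

```python
def pack_trits(trits):
    """Pack 5 ternary values from {-1, 0, +1} into one byte (0..242)."""
    assert len(trits) == 5
    value = 0
    for t in reversed(trits):
        value = value * 3 + (t + 1)   # map {-1, 0, +1} to base-3 digits {0, 1, 2}
    return value

def unpack_trits(byte):
    """Recover the 5 ternary values from a packed byte."""
    trits = []
    for _ in range(5):
        byte, digit = divmod(byte, 3)  # peel off the least significant digit
        trits.append(digit - 1)
    return trits

packed = pack_trits([1, 0, -1, 1, 1])
restored = unpack_trits(packed)
```

Byte values 243 to 255 are unused in this scheme, which is the small gap between 1.6 bits and the information-theoretic log2(3) ≈ 1.585 bits per trit.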
The fundamental challenge: continuous optimizers cannot directly update ternary weights.
w ∈ {-1, 0, +1}
w -= lr * grad → 0.7, -0.3, 1.2 ... NOT ternary!
Every optimizer step must eventually produce a flip decision (±1 change or stay), not a continuous displacement.
3. The Two-Layer Architecture of Ternary Training
All practical ternary training uses a latent/accumulator layer between the optimizer and the actual ternary weights:
Optimizer (SignGSD or AdamW)
↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: int8 per-weight vote counter)
↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
Flow detail
- Forward: unpack T → matmul(W = S × T, x) → loss
  - T is ternary, S is a per-weight or per-group scale
- Backward: grad w.r.t. W flows to both S and T paths
  - Gradient through T needs special handling (discrete, not differentiable)
- Optimizer step on S (if per-param): standard optimizer (AdamW / SignGSD)
  - S is continuous (float/int), so no special treatment is needed
- Accumulator update on T_accum: gradient sign votes accumulate
  - T_accum[i] += sign(grad_w[i])  # vote: +1 or -1
- Threshold-based T flip:
  - if T_accum[i] > threshold: T[i] = +1; T_accum[i] = 0 (carry reset)
  - if T_accum[i] < -threshold: T[i] = -1; T_accum[i] = 0
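The accumulate-then-flip steps can be sketched as a single function. It mirrors the pseudocode above literally; the function name, the list-based storage, and the `votes` argument are assumptions. In particular, how each vote is derived from the gradient sign (including whether it is negated to follow the descent direction) is left to the caller.

```python
def accumulate_and_flip(T, T_accum, votes, threshold):
    """Add votes to the accumulators, then flip trits whose accumulator
    crosses +threshold or -threshold.

    T, T_accum, and votes are equal-length lists of ints; T and T_accum are
    modified in place. Accumulators are assumed to stay within int8 range.
    """
    for i, vote in enumerate(votes):
        T_accum[i] += vote
        if T_accum[i] > threshold:
            T[i] = +1
            T_accum[i] = 0      # carry reset
        elif T_accum[i] < -threshold:
            T[i] = -1
            T_accum[i] = 0
    return T, T_accum

T, T_accum = accumulate_and_flip(
    T=[0, 1, -1], T_accum=[2, -2, 0], votes=[1, -1, 1], threshold=2
)
```

In this example the first two weights cross the threshold and flip, while the third only records its vote and keeps its current value.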
4. What Each Optimizer Contributes
What SignGSD contributes to ternary training
Sign-only updates — exactly what T_accum needs.
- grad.sign() produces {-1, 0, +1}, which maps 1-to-1 to flip votes.
- No need to discard magnitude from AdamW's continuous gradient.
Zero optimizer state — critical at 3B+ parameter scale.
- AdamW's 8 bytes/param × 3.1B = 24.8 GB of optimizer state.
- That exceeds the entire model's storage budget (4 GB training state).
- SignGSD adds 0 bytes — all budget goes to T_accum (which is the actual accumulator, not optimizer state).
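The memory comparison is plain arithmetic and can be reproduced directly. The sketch below uses the parameter count and budget quoted in the text and assumes decimal gigabytes (1 GB = 10^9 bytes); the constant names are illustrative.

```python
PARAMS = 3.1e9               # parameter count quoted in the text
ADAMW_BYTES_PER_PARAM = 8    # exp_avg + exp_avg_sq, both float32
SIGNGSD_BYTES_PER_PARAM = 0  # no optimizer state
BUDGET_GB = 4                # training-state budget quoted in the text

adamw_state_gb = PARAMS * ADAMW_BYTES_PER_PARAM / 1e9
signgsd_state_gb = PARAMS * SIGNGSD_BYTES_PER_PARAM / 1e9

print(f"AdamW state:   {adamw_state_gb:.1f} GB")
print(f"SignGSD state: {signgsd_state_gb:.1f} GB")
print(f"AdamW state / budget: {adamw_state_gb / BUDGET_GB:.1f}x")
```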
Uniform LR — matches the discrete flip domain.
- Ternary flip decisions are uniform: every flip changes weight by ±1 trit.
- AdamW's adaptive LR would modulate this uniform vote-counting process, adding complexity without clear benefit.
Weight decay via sign(p) — naturally pushes ternary weights toward zero.
- When sign(grad) == sign(p), update = ±(1 + wd), i.e. ±2 for wd = 1 → weight pushed toward (and potentially past) zero.
- When sign(grad) != sign(p), update = ±(1 - wd), i.e. 0 for wd = 1 → dead zone, no flip.
- This creates a natural sparsity bias: noisy gradients for nonzero weights eventually push them to zero.
- Standard weight decay wd * p can't apply directly to {-1, 0, +1} because wd * p is only ever 0, -wd, or +wd: a fixed-size push, not a proportional one.
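The cases above can be enumerated explicitly. This sketch assumes wd = 1, which is the value that yields the ±2 and 0 updates; the helper names are hypothetical.

```python
def sign(x):
    """Return -1, 0, or +1 according to the sign of x."""
    return (x > 0) - (x < 0)

def signgsd_update(grad, p, wd=1.0):
    """update = sign(grad) + wd * sign(p); the weight then moves by -lr * update."""
    return sign(grad) + wd * sign(p)

for g in (+1, -1):
    for p in (+1, 0, -1):
        print(f"sign(grad)={g:+d}  p={p:+d}  update={signgsd_update(g, p):+.1f}")
```

With any other wd the dead zone disappears: opposing signs give an update of ±(1 - wd) instead of exactly 0.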
What AdamW contributes (and why it usually fails here)
Momentum — could smooth noisy sign gradients.
- Sign gradients are +1 or -1 per element per step. This is extremely noisy.
- Momentum would smooth this into something like running_mean ≈ E[sign(grad)].
- But: momentum state adds 4 bytes/param (exp_avg), halving the available model size budget.
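To illustrate what such smoothing would compute, here is a small sketch of an exponential moving average over a ±1 sequence with Adam-style bias correction. The function name, the beta value, and the synthetic sign stream (+1 with probability 0.7, so E[sign(grad)] = 0.4) are all assumptions made for the example.

```python
import random

def ema_of_signs(signs, beta=0.9):
    """Bias-corrected exponential moving average of a non-empty ±1 sequence."""
    avg = 0.0
    for s in signs:
        avg = beta * avg + (1 - beta) * s
    return avg / (1 - beta ** len(signs))

random.seed(0)
# Hypothetical noisy gradient sign for one weight.
signs = [1 if random.random() < 0.7 else -1 for _ in range(2000)]
smoothed = ema_of_signs(signs, beta=0.99)
print(f"smoothed sign estimate: {smoothed:+.3f}")
```

The estimate fluctuates around the true mean rather than matching it exactly, and it costs one float of state per weight, which is the 4 bytes/param noted above.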
Adaptive LR — could help rarely-updated weights.
- Some weights might never flip because gradient sign oscillates.
- RMS-normalized gradients would amplify small-magnitude gradients.
- But: second moment adds another 4 bytes/param, and the adaptive scaling is designed for continuous domains, not discrete flip decisions.
The core mismatch: AdamW's output is a continuous step (how much to move). Ternary training needs a discrete vote (which direction to flip, if at all). Converting continuous → discrete by thresholding (if abs(step) > 0.5) loses all the careful adaptive scaling AdamW computes.
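The information loss is easy to see on a few made-up values. In this sketch the function name and the example steps are hypothetical; the 0.5 cutoff is the one mentioned above.

```python
def step_to_vote(step, cutoff=0.5):
    """Collapse a continuous step into a flip vote in {-1, 0, +1}."""
    if abs(step) > cutoff:
        return 1 if step > 0 else -1
    return 0

steps = [0.6, 5.0, -0.51, 0.49, -0.01]   # hypothetical optimizer outputs
votes = [step_to_vote(s) for s in steps]
print(votes)
```

Steps of 0.6 and 5.0 become the same vote, and 0.49 and -0.01 both become no vote at all, so the relative scaling between them is gone after thresholding.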
Currently: only AdamW works for the float32 latent scalar case
For the ARBS project's current architecture:
- Both T_accum (int8 per-weight) and E (int8 per-group) are updated from gradients
- The E scale is continuous and CAN use AdamW
- The T_accum votes come from sign(grad) of the weight, not from AdamW
The current hybrid:
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
Here AdamW operates on the continuous E scales (small, 0.86% of params), and the sign-based accumulator handles the ternary T weights.
5. Practical Guidelines
When to use SignGSD
- Target: ternary/binary weight training
- Memory is the primary constraint (< 8 GB)
- You already have an accumulator pipeline (T_accum + threshold flips)
- Weight updates are naturally uniform (±1 trit flips)
When to use AdamW
- Target: float32 latent weights or per-group scales (E)
- You need adaptive scaling for rarely-activated weights
- Memory budget is generous (> 32 GB)
- Standard training where weights are continuous
When to use both (the ARBS approach)
- T_accum votes: SignGSD-like sign(grad) accumulated as int8
- E (per-group scale): AdamW on float32 (continuous, small: 249 MB)
- T ternary weights: threshold flips driven by T_accum
- This hybrid splits the problem: discrete path via sign votes, continuous path via AdamW on scales
6. Summary
| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | full grad magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | sign(p) → sparsity bias | wd × p → proportional shrinkage (decoupled) |
| Suitable for T_accum | Direct, 1-to-1 | Needs threshold → loses info |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |