# SignGSD vs AdamW: Optimizer Analysis for Ternary Training

## 1. Algorithm Comparison

### AdamW (Loshchilov & Hutter, 2019)

```
exp_avg    = β1·exp_avg    + (1-β1)·grad     # first moment (momentum)
exp_avg_sq = β2·exp_avg_sq + (1-β2)·grad²    # second moment (RMS)
m_hat = exp_avg    / (1-β1^t)                # bias correction
v_hat = exp_avg_sq / (1-β2^t)
p -= lr · (m_hat / (√v_hat + ε) + wd·p)      # decoupled weight decay
```

State per parameter:

- `exp_avg` (float32): 4 bytes
- `exp_avg_sq` (float32): 4 bytes
- Step counter `t` (int): negligible
- **Total optimizer state: 8 bytes/param**

### SignGSD

```
update = sign(grad) + wd · sign(p)    # sign of gradient + sign of weight
p -= lr · update                      # uniform step
```

State per parameter: **0 bytes** (no state whatsoever)

### Key differences table

| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | `wd * lr * p` (decoupled) | `wd * lr * sign(p)` (folded into the sign update) |
| LR schedule | Required (cosine, linear) | Not needed here (step size is set by lr alone, independent of gradient magnitude) |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |

---

## 2. What is Ternary Training?

Ternary weights are constrained to **{-1, 0, +1}**. This gives three benefits:

- **Storage**: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
- **Compute**: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
- **Memory bandwidth**: ~1/20 of float32 weights at 1.6 bits/weight (1/16 with a simpler 2-bit packing)

The fundamental challenge: **continuous optimizers cannot directly update ternary weights**.

```
w ∈ {-1, 0, +1}
w -= lr * grad   →  0.7, -0.3, 1.2 ...  NOT ternary!
```

Every optimizer step must eventually produce a flip decision (±1 change or stay), not a continuous displacement.

---

## 3. The Two-Layer Architecture of Ternary Training

All practical ternary training uses a **latent/accumulator layer** between the optimizer and the actual ternary weights:

```
Optimizer (SignGSD or AdamW)
    ↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: int8 per weight)
    ↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
```

### Flow detail

1. **Forward**: unpack T → matmul(W = S × T, x) → loss
   - T is ternary, S is a per-weight or per-group scale
2. **Backward**: grad w.r.t. W flows to both S and T paths
   - Gradient through T needs special handling (discrete, not differentiable)
3. **Optimizer step on S** (if per-param): standard optimizer (AdamW / SignGSD)
   - S is continuous (float/int), so no special treatment is needed
4. **Accumulator update on T_accum**: gradient sign votes accumulate

   ```
   T_accum[i] += sign(grad_w[i])    # vote: +1 or -1
   ```

5. **Threshold-based T flip**:

   ```
   if T_accum[i] >  threshold: T[i] = +1; T_accum[i] = 0    # carry reset
   if T_accum[i] < -threshold: T[i] = -1; T_accum[i] = 0
   ```

---

## 4. What Each Optimizer Contributes

### What SignGSD contributes to ternary training

1. **Sign-only updates**: exactly what T_accum needs.
   - `grad.sign()` produces {-1, 0, +1}, which maps 1-to-1 to flip votes.
   - No need to discard magnitude from AdamW's continuous gradient.
2. **Zero optimizer state**: critical at 3B+ parameter scale.
   - AdamW's 8 bytes/param × 3.1B = **24.8 GB** of optimizer state.
   - That exceeds the entire model's storage budget (4 GB training state).
   - SignGSD adds 0 bytes, so all budget goes to T_accum (which is the actual accumulator, not optimizer state).
3. **Uniform LR**: matches the discrete flip domain.
   - Ternary flip decisions are uniform: every flip changes weight by ±1 trit.
   - AdamW's adaptive LR would modulate this uniform vote-counting process, adding complexity without clear benefit.
4. **Weight decay via sign(p)**: pushes nonzero ternary weights toward zero.
   - When `sign(grad) == sign(p)`, update = ±(1 + wd), i.e. ±2 at wd = 1. Because `p -= lr · update`, the weight moves toward zero, and past it if the step is large enough.
   - When `sign(grad)` and `sign(p)` are opposite, update = ±(1 - wd), which is exactly 0 at wd = 1: a dead zone with no flip.
   - This creates a natural sparsity bias: noisy gradients for nonzero weights eventually push them to zero.

   Standard weight decay `wd * p` adds nothing on {-1, 0, +1}: `wd * p` is just 0, -wd, or +wd, the same fixed-size nudge as `wd * sign(p)`, with no proportional shrinkage.

### What AdamW contributes (and why it usually fails here)

1. **Momentum**: could smooth noisy sign gradients.
   - Sign gradients are +1 or -1 per element per step. This is extremely noisy.
   - Momentum would smooth this into something like `running_mean ≈ E[sign(grad)]`.
   - **But**: momentum state adds 4 bytes/param (exp_avg), halving the available model size budget.
2. **Adaptive LR**: could help rarely-updated weights.
   - Some weights might never flip because gradient sign oscillates.
   - RMS-normalized gradients would amplify small-magnitude gradients.
   - **But**: second moment adds another 4 bytes/param, and the adaptive scaling is designed for continuous domains, not discrete flip decisions.
3. **The core mismatch**: AdamW's output is a continuous step (how much to move). Ternary training needs a discrete vote (which direction to flip, if at all). Converting continuous → discrete by thresholding (`if abs(step) > 0.5`) loses all the careful adaptive scaling AdamW computes.

### Currently: only AdamW works for the float32 latent scalar case

For the ARBS project's current architecture:

- Both T_accum (int8 per-weight) and E (int8 per-group) receive gradient updates
- The E scale is continuous and CAN use AdamW
- The T_accum votes come from `sign(grad)` of the weight, not from AdamW

The current hybrid:

```
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
```

Here AdamW operates on the continuous E scales (small, 0.86% of params), while the sign-based accumulator handles the ternary T weights.

---

## 5. Practical Guidelines

### When to use SignGSD

- Target: ternary/binary weight training
- Memory is the primary constraint (< 8 GB)
- You already have an accumulator pipeline (T_accum + threshold flips)
- Weight updates are naturally uniform (±1 trit flips)

### When to use AdamW

- Target: float32 latent weights or per-group scales (E)
- You need adaptive scaling for rarely-activated weights
- Memory budget is generous (> 32 GB)
- Standard training where weights are continuous

### When to use both (the ARBS approach)

- **T_accum** votes: SignGSD-like `sign(grad)` accumulated as int8
- **E** (per-group scale): AdamW on float32 (continuous, small: 249 MB)
- **T** ternary weights: threshold flips driven by T_accum
- This hybrid splits the problem: discrete path via sign votes, continuous path via AdamW on scales

---

## 6. Summary

| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | Direction and magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | sign(p) → sparsity bias | wd × p → L2 regularization |
| Suitable for T_accum | Direct, 1-to-1 | Needs threshold → loses info |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |
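
As a concrete illustration of the sign-vote path compared in the summary, the accumulator update and threshold flip from Section 3 can be sketched in plain Python. This is a minimal sketch, not the ARBS implementation: the function names, the list-based storage, and the saturation of `T_accum` at the int8 range are assumptions made for illustration. It follows the pseudocode literally, including its sign convention (a positive accumulated vote sets the trit to +1); a descent-style variant would negate the vote.

```python
def sign(x: float) -> int:
    """Return -1, 0, or +1 according to the sign of x."""
    return (x > 0) - (x < 0)


def accumulate_and_flip(T, T_accum, grad_w, threshold):
    """Add one round of sign votes, then flip trits whose votes pass the threshold.

    T and T_accum are updated in place. Returns how many trits changed value.
    """
    flips = 0
    for i, g in enumerate(grad_w):
        # Vote with the gradient sign; saturate at the int8 range (assumption).
        T_accum[i] = max(-128, min(127, T_accum[i] + sign(g)))
        if T_accum[i] > threshold:
            new = 1
        elif T_accum[i] < -threshold:
            new = -1
        else:
            continue  # not enough agreement yet: keep the trit as is
        flips += int(T[i] != new)
        T[i] = new
        T_accum[i] = 0  # carry reset, as in the pseudocode
    return flips


# Hypothetical example: four weights, three steps of made-up gradients.
T = [0, 1, -1, 0]
T_accum = [0, 0, 0, 0]
grads = [
    [0.3, -0.1, 0.2, 0.0],
    [0.5, -0.4, -0.2, 0.0],
    [0.1, -0.9, 0.7, 0.0],
]
total_flips = sum(accumulate_and_flip(T, T_accum, g, threshold=2) for g in grads)
print(T, T_accum, total_flips)  # [1, -1, -1, 0] [0, 0, 1, 0] 2
```

Only weights whose gradient sign is consistent across steps flip; the third weight, whose votes partly cancel, stays put. Note that the pseudocode as written never sets a trit to 0, so any path back to zero would have to come from the weight-decay term or a different flip rule.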
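
The storage figure from Section 2 (5 trits per byte, about 1.6 bits/weight) relies on 3^5 = 243 fitting in one byte. The sketch below shows one possible base-3 packing; the function names and the little-endian digit order are assumptions for illustration, since the document does not specify the actual packed layout.

```python
TRITS_PER_BYTE = 5  # 3**5 = 243 distinct values, which fits in 0..255


def pack_trits(trits):
    """Pack trits in {-1, 0, +1} into bytes, five base-3 digits per byte."""
    out = bytearray()
    for start in range(0, len(trits), TRITS_PER_BYTE):
        chunk = trits[start:start + TRITS_PER_BYTE]
        value = 0
        # Build the byte so that the first trit is the least significant digit.
        for t in reversed(chunk):
            value = value * 3 + (t + 1)  # map {-1, 0, +1} -> {0, 1, 2}
        out.append(value)
    return bytes(out)


def unpack_trits(data, count):
    """Inverse of pack_trits; count drops the padding in the last byte."""
    trits = []
    for byte in data:
        for _ in range(TRITS_PER_BYTE):
            trits.append(byte % 3 - 1)  # map {0, 1, 2} -> {-1, 0, +1}
            byte //= 3
    return trits[:count]


weights = [1, -1, 0, 0, 1, -1, -1]
packed = pack_trits(weights)
restored = unpack_trits(packed, len(weights))
bits_per_weight = 8 / TRITS_PER_BYTE  # 1.6 when every byte is full
```

Seven trits occupy two bytes here, and the round trip recovers the original list. A 2-bit-per-trit layout would be simpler to decode but costs 2 bits/weight instead of 1.6.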