# SignGSD vs AdamW: Optimizer Analysis for Ternary Training

## 1. Algorithm Comparison

### AdamW (Loshchilov & Hutter, 2019)

```
exp_avg    = β1·exp_avg + (1-β1)·grad         # first moment (momentum)
exp_avg_sq = β2·exp_avg_sq + (1-β2)·grad²     # second moment (RMS)
m_hat = exp_avg / (1-β1^t)                    # bias correction
v_hat = exp_avg_sq / (1-β2^t)
p -= lr · m_hat / (√v_hat + ε) + wd·lr·p      # decoupled weight decay
```

State per parameter:
- `exp_avg` (float32): 4 bytes
- `exp_avg_sq` (float32): 4 bytes
- Step counter `t` (int): negligible
- **Total optimizer state: 8 bytes/param**
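The update above can be sketched in plain Python for a flat list of parameters. This is a minimal illustration, not a library implementation: the function name `adamw_step` and the default hyperparameters are assumptions chosen for the example.

```python
import math

def adamw_step(p, grad, exp_avg, exp_avg_sq, t,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
    """Apply one AdamW step to flat lists of floats (illustrative sketch).

    t is the 1-based step count. Returns new (p, exp_avg, exp_avg_sq) lists.
    """
    new_p, new_m, new_v = [], [], []
    for w, g, m, v in zip(p, grad, exp_avg, exp_avg_sq):
        m = beta1 * m + (1 - beta1) * g        # first moment (momentum)
        v = beta2 * v + (1 - beta2) * g * g    # second moment
        m_hat = m / (1 - beta1 ** t)           # bias correction
        v_hat = v / (1 - beta2 ** t)
        # Adam step plus decoupled weight decay, both scaled by lr
        w = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + wd * w)
        new_p.append(w)
        new_m.append(m)
        new_v.append(v)
    return new_p, new_m, new_v
```

On the first step `m_hat` equals `grad` and `v_hat` equals `grad²`, so the step is close to `lr · sign(grad)`. The two returned state lists are the 8 bytes/param counted above.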

### SignGSD

```
update = sign(grad) + wd · sign(p)   # sign of gradient + sign of weight
p -= lr · update                     # uniform step
```

State per parameter: **0 bytes** (no state whatsoever)
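A minimal Python sketch of the same two lines follows. The names `sign` and `signgsd_step` are illustrative, and treating `sign(0)` as 0 is an assumption (it matches the usual convention of `grad.sign()`).

```python
def sign(x):
    """Return -1, 0, or +1 (assumed convention: sign(0) == 0)."""
    return (x > 0) - (x < 0)

def signgsd_step(p, grad, lr=0.01, wd=0.0):
    """One stateless sign-based step on flat lists; returns the new parameters."""
    return [w - lr * (sign(g) + wd * sign(w)) for w, g in zip(p, grad)]
```

Because the function keeps nothing between calls, the only memory it touches is the parameters and the gradient themselves.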

### Key differences table

| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | `wd * lr * p` (decoupled) | `wd * sign(p)` (coupled to update) |
| LR schedule | Commonly used (cosine, linear) | Not strictly needed (per-element step is ±lr when wd = 0) |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |
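The State/param row translates directly into a memory estimate. The helper below is a small sketch; `optimizer_state_gb` is a hypothetical name, and sizes are reported in decimal gigabytes (1 GB = 1e9 bytes).

```python
ADAMW_STATE_BYTES = 8     # exp_avg + exp_avg_sq, float32 each
SIGNGSD_STATE_BYTES = 0   # stateless

def optimizer_state_gb(num_params, bytes_per_param):
    """Optimizer state size in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9

# Example with an illustrative parameter count of 3.1 billion.
n_params = 3_100_000_000
adamw_gb = optimizer_state_gb(n_params, ADAMW_STATE_BYTES)
signgsd_gb = optimizer_state_gb(n_params, SIGNGSD_STATE_BYTES)
```

The estimate covers optimizer state only. Weights, gradients, and any accumulator used for ternary flips are counted separately.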

---

## 2. What is Ternary Training?

Ternary weights are constrained to **{-1, 0, +1}**. This gives three benefits:
- **Storage**: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
- **Compute**: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
- **Memory bandwidth**: ~1/20× vs float32 weights at 1.6 bits/weight (~1/16× with plain 2-bit packing)
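The 1.6 bits/weight figure comes from the fact that 3⁵ = 243 fits in one byte. A sketch of such a base-3 packing is shown below; the function names and the digit order (first trit in the least significant position) are assumptions for illustration, not a description of any particular storage format.

```python
def pack_trits(trits):
    """Pack values from {-1, 0, +1} into bytes, 5 trits per byte (3**5 = 243 <= 256)."""
    out = bytearray()
    for i in range(0, len(trits), 5):
        value = 0
        # Build the base-3 number so the first trit is the least significant digit.
        for t in reversed(trits[i:i + 5]):
            value = value * 3 + (t + 1)   # map -1, 0, +1 to digits 0, 1, 2
        out.append(value)
    return bytes(out)

def unpack_trits(data, count):
    """Inverse of pack_trits; count is the number of trits originally packed."""
    trits = []
    for byte in data:
        for _ in range(5):
            trits.append(byte % 3 - 1)    # map digits 0, 1, 2 back to -1, 0, +1
            byte //= 3
    return trits[:count]
```

A partial final group is padded implicitly with digit 0, which is why `unpack_trits` needs the original count to trim the tail.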

The fundamental challenge: **continuous optimizers cannot directly update ternary weights**.

```
w ∈ {-1, 0, +1}
w -= lr * grad   →   0.7, -0.3, 1.2 ...   NOT ternary!
```

Every optimizer step must eventually produce a flip decision (±1 change or stay),
not a continuous displacement.

---

## 3. The Two-Layer Architecture of Ternary Training

Practical ternary training typically uses a **latent/accumulator layer** between the
optimizer and the actual ternary weights:

```
Optimizer (SignGSD or AdamW)
    ↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: int8 per weight)
    ↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
```

### Flow detail

1. **Forward**: unpack T → matmul(W = S × T, x) → loss
   - T is ternary, S is a per-weight or per-group scale
2. **Backward**: grad w.r.t. W flows to both S and T paths
   - Gradient through T needs special handling (discrete, not differentiable)
3. **Optimizer step on S** (if per-param): standard optimizer (AdamW / SignGSD)
   - S is continuous (float/int), so no special treatment is needed
4. **Accumulator update on T_accum**: gradient sign votes accumulate
   ```
   T_accum[i] += sign(grad_w[i])   # vote: +1 or -1 (0 when the gradient is exactly 0)
   ```
5. **Threshold-based T flip**:
   ```
   if T_accum[i] > threshold:  T[i] = +1; T_accum[i] = 0   # carry reset
   if T_accum[i] < -threshold: T[i] = -1; T_accum[i] = 0   # carry reset
   ```
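Steps 4 and 5 can be sketched together as one in-place routine. This follows the pseudocode above literally, including its sign convention; whether a vote should be `+sign(grad)` or `-sign(grad)` depends on how `grad_w` is defined in the surrounding pipeline, which is not specified here. The function name and the int8 saturation clamp are assumptions.

```python
def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def accumulate_and_flip(T, T_accum, grad_w, threshold):
    """Accumulate sign votes and apply threshold flips, in place.

    T: list of ternary weights in {-1, 0, +1}
    T_accum: list of integer vote counters (int8 range assumed)
    """
    for i, g in enumerate(grad_w):
        T_accum[i] += sign(g)                           # step 4: one vote per step
        T_accum[i] = max(-127, min(127, T_accum[i]))    # assumed int8 saturation
        if T_accum[i] > threshold:                      # step 5: flip and reset
            T[i] = 1
            T_accum[i] = 0
        elif T_accum[i] < -threshold:
            T[i] = -1
            T_accum[i] = 0
    return T, T_accum
```

With `threshold = 2`, a weight flips only after three net votes in the same direction, so a gradient whose sign alternates never triggers a flip.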

---

## 4. What Each Optimizer Contributes

### What SignGSD contributes to ternary training

1. **Sign-only updates**: exactly what T_accum needs.
   - `grad.sign()` produces {-1, 0, +1} which maps 1-to-1 to flip votes.
   - No need to discard magnitude from AdamW's continuous gradient.

2. **Zero optimizer state**: critical at 3B+ parameter scale.
   - AdamW's 8 bytes/param × 3.1B = **24.8 GB** of optimizer state.
   - That exceeds the entire model's storage budget (4 GB training state).
   - SignGSD adds 0 bytes, so all budget goes to T_accum (which is the actual
     accumulator, not optimizer state).

3. **Uniform LR**: matches the discrete flip domain.
   - Ternary flip decisions are uniform: every flip changes weight by ±1 trit.
   - AdamW's adaptive LR would modulate this uniform vote-counting process,
     adding complexity without clear benefit.

4. **Weight decay via sign(p)**: naturally pushes ternary weights toward zero.
   - When `sign(grad) == sign(p)`, update = ±(1 + wd), i.e. ±2 at wd = 1 → weight pushed toward zero.
   - When `sign(grad) != sign(p)`, update = ±(1 - wd), i.e. 0 at wd = 1 → dead zone, no flip.
   - This creates a natural sparsity bias: noisy gradients for nonzero weights
     eventually push them to zero. On ternary values, standard weight decay
     `wd * p` takes only the values 0, -wd, or +wd, so it acts as a fixed-size
     push identical to `wd * sign(p)`, not a proportional shrink.
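The cases in item 4 can be checked by evaluating the update term for each sign combination. This is a small sketch; `signgsd_update` is an illustrative name, and `wd = 1` is used only because it reproduces the ±2 and 0 values quoted above.

```python
def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def signgsd_update(grad, p, wd):
    """Update term sign(grad) + wd * sign(p); the step applied is p -= lr * update."""
    return sign(grad) + wd * sign(p)

# Tabulate the update for every combination of gradient sign and ternary weight.
cases = {
    (g, p): signgsd_update(g, p, wd=1.0)
    for g in (-1.0, 1.0)
    for p in (-1, 0, 1)
}
```

For smaller `wd` the dead zone becomes a reduced step of size `1 - wd` instead of an exact zero, so the sparsity bias weakens as `wd` shrinks.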

### What AdamW contributes (and why it usually fails here)

1. **Momentum**: could smooth noisy sign gradients.
   - Sign gradients are +1 or -1 per element per step. This is extremely noisy.
   - Momentum would smooth this into something like `running_mean ≈ E[sign(grad)]`.
   - **But**: momentum state adds 4 bytes/param (exp_avg), halving the available
     model size budget.

2. **Adaptive LR**: could help rarely-updated weights.
   - Some weights might never flip because gradient sign oscillates.
   - RMS-normalized gradients would amplify small-magnitude gradients.
   - **But**: second moment adds another 4 bytes/param, and the adaptive scaling
     is designed for continuous domains, not discrete flip decisions.

3. **The core mismatch**: AdamW's output is a continuous step (how much to move).
   Ternary training needs a discrete vote (which direction to flip, if at all).
   Converting continuous → discrete by thresholding (`if abs(step) > 0.5`) loses
   all the careful adaptive scaling AdamW computes.
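To make the momentum point in item 1 concrete, the sketch below runs an exponential moving average over sign votes. The name `ema_of_signs` and the choice `beta = 0.9` are assumptions for illustration; the float it returns is the per-parameter state that the memory argument above counts against AdamW.

```python
def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def ema_of_signs(grads, beta=0.9):
    """Exponential moving average of sign(grad), roughly tracking E[sign(grad)]."""
    m = 0.0
    for g in grads:
        m = beta * m + (1 - beta) * sign(g)
    return m

# A consistently positive gradient drives the average toward +1.
steady = ema_of_signs([0.3] * 100)

# A gradient whose sign alternates every step keeps the average near 0.
oscillating = ema_of_signs([0.3 if i % 2 == 0 else -0.3 for i in range(200)])
```

An integer vote accumulator captures a similar signal (net agreement over time) without a float per weight, which is the trade the hybrid design relies on.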

### Currently: only AdamW works for the float32 latent scalar case

For the ARBS project's current architecture:
- Two quantities receive gradient-driven updates: T_accum (int8 per-weight) and E (int8 per-group)
- The E scale is continuous and CAN use AdamW
- The T_accum votes come from `sign(grad)` of the weight, not from AdamW

The current hybrid:
```
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
```

Here AdamW operates on the continuous E scales (small, 0.86% of params),
and the sign-based accumulator handles the ternary T weights.
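The split of `grad_W` into the two paths can be sketched as follows. This assumes the effective weight is `W[i] = E[g] * T[i]` with contiguous groups of `group_size` weights, so the chain rule gives `grad_E[g] = Σ grad_W[i] * T[i]` over the group. The layout, the function name `split_hybrid_grads`, and the omission of the `E_accum` and AdamW stages are all simplifications, not a description of the ARBS code.

```python
def sign(x):
    """Return -1, 0, or +1."""
    return (x > 0) - (x < 0)

def split_hybrid_grads(grad_W, T, group_size):
    """Route grad_W into sign votes (for T_accum) and per-group scale gradients.

    Assumes W[i] = E[i // group_size] * T[i] (hypothetical contiguous layout).
    """
    # Discrete path: one vote per weight.
    votes = [sign(g) for g in grad_W]
    # Continuous path: dL/dE[g] = sum over the group of grad_W[i] * T[i].
    grad_E = [
        sum(g * t for g, t in zip(grad_W[s:s + group_size], T[s:s + group_size]))
        for s in range(0, len(T), group_size)
    ]
    return votes, grad_E
```

The `votes` list would feed the int8 accumulator, while `grad_E` has one entry per group and is small enough to hand to a stateful optimizer such as AdamW.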

---

## 5. Practical Guidelines

### When to use SignGSD
- Target: ternary/binary weight training
- Memory is the primary constraint (< 8 GB)
- You already have an accumulator pipeline (T_accum + threshold flips)
- Weight updates are naturally uniform (±1 trit flips)

### When to use AdamW
- Target: float32 latent weights or per-group scales (E)
- You need adaptive scaling for rarely-activated weights
- Memory budget is generous (> 32 GB)
- Standard training where weights are continuous

### When to use both (the ARBS approach)
- **T_accum** votes: SignGSD-like `sign(grad)` accumulated as int8
- **E** (per-group scale): AdamW on float32 (continuous and small, 249 MB)
- **T** ternary weights: threshold flips driven by T_accum
- This hybrid splits the problem: discrete path via sign votes, continuous
  path via AdamW on scales

---

## 6. Summary

| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | full grad magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | sign(p) → sparsity bias | wd × p → decoupled proportional shrinkage |
| Suitable for T_accum | Direct, 1-to-1 | Needs threshold → loses info |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |