
SignGSD vs AdamW — Optimizer Analysis for Ternary Training

1. Algorithm Comparison

AdamW (Loshchilov & Hutter, 2019)

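```python
import numpy as np
# One AdamW step; p, grad, exp_avg, exp_avg_sq are float32 NumPy arrays and t is the 1-based step count
exp_avg = beta1 * exp_avg + (1 - beta1) * grad             # first moment (momentum)
exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad**2    # second moment (RMS)
m_hat = exp_avg / (1 - beta1**t)                           # bias correction
v_hat = exp_avg_sq / (1 - beta2**t)
p -= lr * m_hat / (np.sqrt(v_hat) + eps) + wd * lr * p     # decoupled weight decay (uses the old p)
```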

State per parameter:

  • exp_avg (float32): 4 bytes
  • exp_avg_sq (float32): 4 bytes
  • Step counter t (int): negligible
  • Total optimizer state: 8 bytes/param
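To make the update rule concrete, here is a minimal pure-Python sketch of one AdamW step on a single scalar parameter. `adamw_step` is a hypothetical helper written for this note, not the PyTorch `torch.optim.AdamW` API, and its default hyperparameters are assumptions. The usage lines show a property worth keeping in mind for the comparison below. Starting from zero state, the bias-corrected first step is lr · grad / (|grad| + ε), which is almost exactly lr · sign(grad) whatever the gradient's scale.

```python
import math


def adamw_step(p, grad, exp_avg, exp_avg_sq, t,
               lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.0):
    """Hypothetical helper: one AdamW step for a scalar p; t is the 1-based step count.

    Returns the new (p, exp_avg, exp_avg_sq).
    """
    exp_avg = beta1 * exp_avg + (1 - beta1) * grad               # first moment
    exp_avg_sq = beta2 * exp_avg_sq + (1 - beta2) * grad * grad  # second moment
    m_hat = exp_avg / (1 - beta1 ** t)                           # bias correction
    v_hat = exp_avg_sq / (1 - beta2 ** t)
    # Adaptive step plus decoupled weight decay, both computed from the old p.
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps) - wd * lr * p
    return p, exp_avg, exp_avg_sq


# First step from zero state: gradients of very different scale give
# (almost) the same step size, lr * sign(grad).
for g in (100.0, 0.01, -3.0):
    p_new, _, _ = adamw_step(0.0, g, 0.0, 0.0, t=1)
    print(f"grad={g:>7}: p moves by {p_new:+.6f}")
```

With beta1 = beta2 = 0, every step would reduce to lr · grad / (|grad| + ε), which is essentially a sign update. The nonzero betas are what make AdamW's steps depend on gradient history.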

SignGSD

```python
import numpy as np
update = np.sign(grad) + wd * np.sign(p)   # sign of gradient + sign of weight
p -= lr * update                           # uniform step: size set by lr and wd only
```

State per parameter: 0 bytes (no state whatsoever)
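Here is a minimal scalar sketch of the SignGSD rule above, written in plain Python so the update magnitudes are easy to check. `signgsd_step` and the scalar `sign` helper are hypothetical names for this note, and the `lr` and `wd` values in the usage lines are arbitrary. The per-element update is 1 + wd when the gradient sign agrees with the weight sign and 1 − wd when it disagrees, so it is exactly 0 when wd = 1.

```python
def sign(x):
    """Scalar sign: -1, 0 or +1 (same convention as numpy.sign / torch.sign)."""
    return (x > 0) - (x < 0)


def signgsd_step(p, grad, lr, wd):
    """Hypothetical helper: one SignGSD step on a scalar latent weight p. Keeps no state."""
    update = sign(grad) + wd * sign(p)  # sign of gradient + sign of weight
    return p - lr * update              # step size depends only on lr and wd


p = 0.5
print(signgsd_step(p, grad=+0.3, lr=0.1, wd=0.1))  # signs agree: moves toward zero by lr * 1.1
print(signgsd_step(p, grad=-0.3, lr=0.1, wd=0.1))  # signs disagree: moves away by lr * 0.9
print(signgsd_step(p, grad=-0.3, lr=0.1, wd=1.0))  # wd = 1: dead zone, p unchanged
```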

Key differences table

| Property | AdamW | SignGSD |
|---|---|---|
| Gradient uses | Direction + magnitude | Direction only (sign) |
| Adaptive LR | Per-param via RMS | None (uniform LR) |
| Momentum | Exponentially smoothed | None |
| Weight decay | wd * lr * p (decoupled) | wd * sign(p) (added inside the update, then scaled by lr) |
| LR schedule | Usually used (cosine, linear) | Not needed (per-element step size depends only on lr and wd, not on gradient magnitude) |
| State/param | 8 bytes | 0 bytes |
| Suitable for | Full-precision training | Low-precision / binary / ternary |

2. What is Ternary Training?

Ternary weights are constrained to {-1, 0, +1}. This gives three benefits:

  • Storage: ~1.6 bits/weight (base-3 packing: 5 trits per byte)
  • Compute: matmuls with {-1,0,+1} weights use add/sub only (no multiplications)
  • Memory bandwidth: ~1/20 of float32 weights (1.6 vs 32 bits/weight; 1/16 with simpler 2-bit packing)
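The storage and compute bullets can be illustrated with a short pure-Python sketch. `pack_trits` and `unpack_trits` store 5 trits per byte as base-3 digits: 3^5 = 243 ≤ 256, which gives 8/5 = 1.6 bits per weight. `ternary_matvec` computes a matrix-vector product using only additions and subtractions inside each dot product, plus one multiply per output for a scale. The function names, the digit order, the −1/0/+1 → 0/1/2 mapping, and the single scalar scale are all assumptions for illustration, not the ARBS on-disk layout.

```python
def pack_trits(trits):
    """Hypothetical helper: pack values in {-1, 0, +1} into bytes, 5 trits per byte."""
    out = bytearray()
    for start in range(0, len(trits), 5):
        code = 0
        # Base-3 digits, first trit in the lowest position; -1/0/+1 -> 0/1/2.
        for t in reversed(trits[start:start + 5]):
            code = code * 3 + (t + 1)
        out.append(code)  # at most 2 * (1 + 3 + 9 + 27 + 81) = 242
    return bytes(out)


def unpack_trits(packed, n):
    """Inverse of pack_trits; n is the original number of trits."""
    trits = []
    for code in packed:
        for _ in range(5):
            trits.append(code % 3 - 1)
            code //= 3
    return trits[:n]


def ternary_matvec(rows, x, scale=1.0):
    """y = scale * (T @ x) for T with entries in {-1, 0, +1}.

    The inner loop only adds or subtracts; the one multiply per row applies the scale.
    """
    y = []
    for row in rows:
        acc = 0.0
        for t, xi in zip(row, x):
            if t == 1:
                acc += xi
            elif t == -1:
                acc -= xi
        y.append(scale * acc)
    return y


weights = [1, 0, -1, -1, 1, 0, 0]
packed = pack_trits(weights)
print(len(weights), "trits ->", len(packed), "bytes")
print(unpack_trits(packed, len(weights)) == weights)
print(ternary_matvec([[1, 0, -1], [-1, 1, 1]], [2.0, 3.0, 5.0], scale=0.5))
```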

The fundamental challenge: continuous optimizers cannot directly update ternary weights.

```text
w ∈ {-1, 0, +1}
w -= lr * grad   →  0.7, -0.3, 1.2 ... NOT ternary!
```

Every optimizer step must eventually produce a flip decision (±1 change or stay), not a continuous displacement.


3. The Two-Layer Architecture of Ternary Training

Practical ternary training schemes typically place a latent/accumulator layer between the optimizer and the actual ternary weights:

```text
Optimizer (SignGSD or AdamW)
    ↓ writes continuous updates
Latent / Accumulator (e.g., T_accum: int8 per-weight vote counter)
    ↓ threshold-based flips
Ternary Weights T ∈ {-1, 0, +1} (packed as base-3 trits)
```

Flow detail

  1. Forward: unpack T → matmul(W = S × T, x) → loss
    • T is ternary, S is a per-weight or per-group scale
  2. Backward: grad w.r.t. W flows to both S and T paths
    • Gradient through T needs special handling (discrete, not differentiable)
  3. Optimizer step on S (if per-param): standard optimizer (AdamW / SignGSD)
    • S is continuous (float/int) — no special treatment needed
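  4. Accumulator update on T_accum: gradient-sign votes accumulate in the descent direction (assuming S > 0, a positive gradient is a vote to decrease the weight)
    T_accum[i] -= sign(grad_w[i])     # vote: +1 or -1 (0 if the gradient is exactly 0)

  5. Threshold-based T flip (one trit per flip, saturating at ±1):
    if T_accum[i] > threshold:  T[i] = min(T[i] + 1, +1);  T_accum[i] = 0  (carry reset)
    if T_accum[i] < -threshold: T[i] = max(T[i] - 1, -1);  T_accum[i] = 0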
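Steps 4 and 5 can be sketched in plain Python as follows. This is a minimal sketch under stated assumptions, not the ARBS implementation. `accumulate_and_flip` is a hypothetical name. Votes are cast in the descent direction (−sign(grad), which assumes a positive scale S). Each flip moves T by one trit and saturates at ±1, and the int8 range is enforced by clamping.

```python
def sign(x):
    """Scalar sign: -1, 0 or +1."""
    return (x > 0) - (x < 0)


def clamp_int8(v):
    """Saturate to the int8 range [-128, 127]."""
    return max(-128, min(127, v))


def accumulate_and_flip(T, T_accum, grad_w, threshold):
    """Hypothetical helper: one round of sign voting plus threshold flips, in place.

    T       -- ternary weights in {-1, 0, +1}
    T_accum -- per-weight vote counters (kept in int8 range)
    grad_w  -- dL/dW for the effective weights W = S * T (S assumed > 0)
    """
    for i, g in enumerate(grad_w):
        # Descent-direction vote: a positive gradient is a vote to decrease the weight.
        T_accum[i] = clamp_int8(T_accum[i] - sign(g))
        if T_accum[i] > threshold:
            T[i] = min(T[i] + 1, 1)    # up one trit, saturating at +1
            T_accum[i] = 0             # carry reset
        elif T_accum[i] < -threshold:
            T[i] = max(T[i] - 1, -1)   # down one trit, saturating at -1
            T_accum[i] = 0


T = [0, 1, -1]
T_accum = [0, 0, 0]
for step in range(8):
    # Constant gradients: weight 0 wants to grow, weight 1 to shrink, weight 2 gets no signal.
    accumulate_and_flip(T, T_accum, grad_w=[-0.5, 0.2, 0.0], threshold=3)
    if step in (3, 7):
        print(step + 1, T, T_accum)
```

With a threshold of 3, a weight needs 4 consecutive agreeing votes per trit step. So weight 1 goes +1 → 0 → −1 over 8 steps instead of jumping directly, which is what lets weights settle at 0.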

4. What Each Optimizer Contributes

What SignGSD contributes to ternary training

  1. Sign-only updates — exactly what T_accum needs.

    • grad.sign() produces {-1, 0, +1} which maps 1-to-1 to flip votes.
    • No magnitude information has to be discarded, as it would be when thresholding AdamW's continuous step.
  2. Zero optimizer state — critical at 3B+ parameter scale.

    • AdamW's 8 bytes/param × 3.1B params ≈ 24.8 GB (≈ 23.1 GiB) of optimizer state.
    • That exceeds the entire model's storage budget (4 GB training state).
    • SignGSD adds 0 bytes — all budget goes to T_accum (which is the actual accumulator, not optimizer state).
  3. Uniform LR — matches the discrete flip domain.

    • Ternary flip decisions are uniform: every flip changes weight by ±1 trit.
    • AdamW's adaptive LR would modulate this uniform vote-counting process, adding complexity without clear benefit.
  4. Weight decay via sign(p) — naturally pushes ternary weights toward zero.

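    • With wd = 1: when sign(grad) == sign(p), update = ±2, so the step toward zero is doubled (in general its size is 1 + wd).
    • With wd = 1: when sign(grad) != sign(p), update = 0. This is a dead zone in which the gradient cannot push the weight away from zero (in general the step away from zero shrinks to 1 − wd).
    • This creates a sparsity bias. When gradient signs are noisy and roughly balanced, they cancel on average, while wd · sign(p) always points toward zero, so nonzero weights drift to zero. On ternary values sign(p) = p, so wd · sign(p) coincides with standard weight decay wd · p. The two differ only on a continuous latent, where wd · p shrinks proportionally (L2-like) and wd · sign(p) subtracts a constant amount (L1-like).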
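The sparsity argument can be checked numerically with a small pure-Python sketch. The helper names are hypothetical and the `lr` and `wd` values are arbitrary. It shows three things:
- On ternary values, the sign-based and proportional decay terms coincide.
- On a continuous latent, the sign-based term removes a constant amount lr · wd per step (L1-like), while the proportional term removes lr · wd · p (L2-like).
- With perfectly balanced gradient signs, SignGSD drifts a positive latent weight toward zero by lr · wd per step.

```python
def sign(x):
    """Scalar sign: -1, 0 or +1."""
    return (x > 0) - (x < 0)


def decay_sign(p, lr, wd):
    """Hypothetical helper: SignGSD-style decay term alone, a constant-size shrink (L1-like)."""
    return p - lr * wd * sign(p)


def decay_proportional(p, lr, wd):
    """Hypothetical helper: standard decoupled decay term alone, proportional to p (L2-like)."""
    return p - lr * wd * p


# 1) On ternary values sign(p) == p, so both forms give the same result.
for t in (-1, 0, 1):
    print(t, decay_sign(t, 0.1, 0.5), decay_proportional(t, 0.1, 0.5))

# 2) On a continuous latent they differ: large weights shrink slowly under
#    sign decay, and small weights can overshoot past zero.
for p in (10.0, 0.01):
    print(p, decay_sign(p, 0.1, 0.5), decay_proportional(p, 0.1, 0.5))

# 3) Balanced noisy gradient signs (+, -, +, -, ...): the gradient parts cancel
#    and only the decay term survives, so p drifts toward zero by lr * wd per step.
p, lr, wd = 1.0, 0.01, 0.1
for step in range(400):
    g = 1.0 if step % 2 == 0 else -1.0
    p -= lr * (sign(g) + wd * sign(p))
print(p)
```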

What AdamW contributes (and why it usually fails here)

  1. Momentum — could smooth noisy sign gradients.

    • Sign gradients are +1 or -1 per element per step. This is extremely noisy.
    • Momentum would smooth this into something like running_mean ≈ E[sign(grad)].
    • But: momentum state adds 4 bytes/param (exp_avg), i.e. 12.4 GB at 3.1B params, roughly three times the 4 GB training-state budget on its own.
  2. Adaptive LR — could help rarely-updated weights.

    • Some weights might never flip because gradient sign oscillates.
    • RMS-normalized gradients would amplify small-magnitude gradients.
    • But: second moment adds another 4 bytes/param, and the adaptive scaling is designed for continuous domains, not discrete flip decisions.
  3. The core mismatch: AdamW's output is a continuous step (how much to move), while ternary training needs a discrete vote (which direction to flip, if at all). Converting continuous → discrete by thresholding (e.g. if abs(step) > 0.5) keeps only which side of the threshold the step lands on. It discards most of the adaptive scaling AdamW computes.
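The following pure-Python sketch illustrates points 1 and 3 using a made-up vote stream (not measured data). An exponential moving average of sign(grad), like AdamW's first moment without bias correction, settles near E[sign(grad)]. If that smoothed value is then thresholded at 0.5, a persistent but weak bias never produces a flip, even though a plain integer vote counter sees a steady net surplus. `ema` is a hypothetical helper.

```python
def ema(values, beta=0.9):
    """Hypothetical helper: exponential moving average, as in AdamW's exp_avg (no bias correction)."""
    m, history = 0.0, []
    for v in values:
        m = beta * m + (1 - beta) * v
        history.append(m)
    return history


# Made-up vote stream: the gradient sign is +1 on two steps out of three,
# so E[sign(grad)] = (1 + 1 - 1) / 3 = 1/3.
votes = [1, 1, -1] * 100
smoothed = ema(votes)

print("last smoothed values:", [round(m, 3) for m in smoothed[-3:]])
print("thresholded flips (|m| > 0.5):", sum(abs(m) > 0.5 for m in smoothed))
print("net integer votes:", sum(votes))
```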

Current setup: AdamW is used only for the float32 latent scales

For the ARBS project's current architecture:

  • Two quantities are updated from gradients: T_accum (int8, per weight) and E (int8, per group)
  • The E scale is continuous and CAN use AdamW
  • The T_accum votes come from sign(grad) of the weight, not from AdamW

The current hybrid:

```text
grad_W = autograd computed in float32
sign(grad_W) → T_accum accumulator (int8 votes)
grad_E from chain rule → E_accum → AdamW on E (per-group scale)
```

Here AdamW operates on the continuous E scales (small: 0.86% of params), while the sign-based accumulator handles the ternary T weights.
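A minimal sketch of the hybrid step follows, under loudly stated assumptions:
- E is treated here as a positive float32 multiplicative scale `S` per group, so W[i] = S[g] · T[i]. The actual encoding of E in ARBS may differ.
- AdamW on the scales uses no weight decay.
- `hybrid_step` is a hypothetical helper, and int8 saturation is omitted.

Two chain-rule facts drive the split. First, dL/dS_g = Σ_{i∈g} dL/dW_i · T_i feeds AdamW. Second, because S_g > 0, dL/dT_i = S_g · dL/dW_i has the same sign as dL/dW_i, so sign(grad_W) is a valid vote for T.

```python
import math


def sign(x):
    """Scalar sign: -1, 0 or +1."""
    return (x > 0) - (x < 0)


def hybrid_step(T, T_accum, S, adam_state, grad_W, group_size, t,
                threshold=3, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """Hypothetical hybrid update, in place.

    T, T_accum -- per-weight ternary values and vote counters
    S          -- per-group positive scales, W[i] = S[i // group_size] * T[i]
    adam_state -- one [exp_avg, exp_avg_sq] pair per group
    grad_W     -- dL/dW from autograd
    t          -- 1-based AdamW step count
    """
    # Continuous path: chain rule for each group scale, then AdamW (no weight decay).
    for g in range(len(S)):
        lo = g * group_size
        grad_S = sum(grad_W[i] * T[i] for i in range(lo, lo + group_size))
        m, v = adam_state[g]
        m = beta1 * m + (1 - beta1) * grad_S
        v = beta2 * v + (1 - beta2) * grad_S * grad_S
        adam_state[g] = [m, v]
        m_hat = m / (1 - beta1 ** t)
        v_hat = v / (1 - beta2 ** t)
        S[g] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    # Discrete path: descent-direction sign votes, then one-trit threshold flips.
    # sign(dL/dT_i) == sign(dL/dW_i) because dL/dT_i = S_g * dL/dW_i with S_g > 0.
    for i, gw in enumerate(grad_W):
        T_accum[i] -= sign(gw)
        if T_accum[i] > threshold:
            T[i], T_accum[i] = min(T[i] + 1, 1), 0
        elif T_accum[i] < -threshold:
            T[i], T_accum[i] = max(T[i] - 1, -1), 0


T = [1, -1, 0, 1]
T_accum = [0, 0, 0, 0]
S = [0.5, 2.0]
adam_state = [[0.0, 0.0], [0.0, 0.0]]
hybrid_step(T, T_accum, S, adam_state, grad_W=[0.2, 0.1, -0.3, 0.4], group_size=2, t=1)
print(T, T_accum, S)
```

The scale gradients are computed with the pre-flip T, so both paths see the same forward pass.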


5. Practical Guidelines

When to use SignGSD

  • Target: ternary/binary weight training
  • Memory is the primary constraint (< 8 GB)
  • You already have an accumulator pipeline (T_accum + threshold flips)
  • Weight updates are naturally uniform (±1 trit flips)

When to use AdamW

  • Target: float32 latent weights or per-group scales (E)
  • You need adaptive scaling for rarely-activated weights
  • Memory budget is generous (> 32 GB)
  • Standard training where weights are continuous

When to use both (the ARBS approach)

  • T_accum votes: SignGSD-like sign(grad) accumulated as int8
  • E (per-group scale): AdamW on float32 (continuous, small — 249 MB)
  • T ternary weights: threshold flips driven by T_accum
  • This hybrid splits the problem: discrete path via sign votes, continuous path via AdamW on scales

6. Summary

| Concern | SignGSD | AdamW |
|---|---|---|
| Memory overhead | 0 bytes/param | 8 bytes/param |
| Gradient uses | sign(grad) only | Full gradient magnitude |
| Per-param adaptivity | None | RMS normalization |
| Natural domain | Discrete (flip votes) | Continuous (float updates) |
| Weight decay | wd × sign(p) → L1-like sparsity bias | wd × p → decoupled, L2-like shrinkage |
| Suitable for T_accum | Direct, 1-to-1 | Needs a threshold → loses information |
| Suitable for E scales | Possibly (if small) | Yes (continuous, small) |