# ARBS / testing / sign_gsd.py
# Source: CLIWorks, "Upload folder using huggingface_hub" (commit d8bc908, verified).
"""
SignGSD — Sign Gradient-Sign Descent optimizer.
A minimal optimizer for low-precision (ternary/binary) training.
Key property: discards all magnitude information. Only signs matter.
This aligns with ternary weight domains where weights are {-1, 0, +1}
and updates are discrete flips rather than continuous steps.
Memory: zero optimizer state (no momentum buffers). Only stores what
torch already tracks (params + grad). 0 bytes overhead vs AdamW's
8 bytes/param (2× float32).
"""
import torch
from torch.optim import Optimizer
class ScaledOptum(Optimizer):
    """
    Sign Gradient-Sign Descent (SignGSD).

    Update rule (elementwise):
        p += -lr * (sign(grad) + wd * sign(p))

    Compared to AdamW:
    - No first/second moment estimates (no exp_avg, exp_avg_sq)
    - No adaptive per-parameter learning rate
    - Weight decay acts on sign(p), not on p itself
    - Uniform LR across all parameters of a param group

    Why this is intended for ternary training:
    Ternary weights live in {-1, 0, +1}. Continuous updates like
    p -= lr * grad immediately leave the ternary domain. SignGSD only
    votes on direction. Per the original design notes, the actual flip
    decision is meant to be accumulated elsewhere (e.g. T_accum counting
    sign votes and flipping at a threshold). That logic is not part of
    this class.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        """
        Args:
            params: iterable of parameters or param groups.
            lr: uniform step size per sign vote (same for all params in a
                group, no adaptive scaling). Must be >= 0.
            weight_decay: decay coefficient applied as wd * sign(p) rather
                than wd * p. Must be >= 0; 0 disables decay. For a nonzero
                element the combined update magnitude is:
                - (1 + wd) when sign(grad) == sign(p), a stronger push
                  toward zero;
                - |1 - wd| when the signs disagree, a weaker step (the two
                  terms cancel exactly only when wd == 1);
                - wd when grad is exactly zero (decay only).

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative.
        """
        # Mirror torch.optim's validation. Without it a negative weight_decay
        # would be silently ignored by the `wd > 0` check in step().
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        Flow, for each parameter that has a gradient:
        1. update = grad.sign(): elementwise sign in {-1, 0, +1}. All
           gradient magnitude is discarded. AdamW, by contrast, rescales
           by a running RMS of the gradient.
        2. If weight_decay > 0: update += wd * p.sign(). Every nonzero
           element gets the same decay magnitude regardless of |p|.
           Standard decay (wd * p) penalizes large weights more.
        3. p -= lr * update.

        Memory: no persistent optimizer state is written (self.state is
        never touched). Each step allocates transient temporaries for the
        sign tensors, plus a dense copy of any sparse gradient.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The loss returned by ``closure`` if one was given, else None.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    # Densify so the elementwise ops and p.add_ below act on a
                    # regular tensor. Missing entries become zero, i.e. no vote.
                    grad = grad.to_dense()

                # sign(grad) ∈ {-1, 0, +1}. A zero gradient contributes no
                # vote, though weight decay below can still move p.
                # grad.sign() returns a fresh tensor, so in-place ops on it
                # never touch p.grad.
                update = grad.sign()
                if wd > 0:
                    # Decay on sign(p). For ternary p ∈ {-1, 0, +1},
                    # sign(p) == p exactly (including 0).
                    # Agreeing signs give |update| = 1 + wd (pull toward zero).
                    # Disagreeing signs give |update| = |1 - wd|.
                    update.add_(p.sign(), alpha=wd)

                # p -= lr * update, without materializing a scaled temporary.
                # NOTE(review): per the ARBS notes, this is meant to act on
                # latent/accumulator tensors rather than the ternary weights
                # themselves. The flips happen in prepare_ternary_backward /
                # _ternary_update_memory in the training loop. This class does
                # not enforce that; it updates whatever params it is given.
                p.add_(update, alpha=-lr)

        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """
        Return the total storage of the given parameter tensors, in MB
        (MiB, i.e. 1024 * 1024 bytes).

        Args:
            params: iterable of tensors to measure. Defaults to every
                parameter in this optimizer's param groups.

        Returns:
            sum(numel * element_size) / 2**20. Only the parameter tensors
            themselves are counted; gradients are not included. This
            optimizer keeps no per-parameter state, whereas AdamW would add
            8 bytes/param for its two float32 moment buffers.
        """
        if params is None:
            params = [p for group in self.param_groups for p in group["params"]]
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)