"""
SignGSD — Sign Gradient-Sign Descent optimizer (implemented by ``ScaledOptum``).
A minimal optimizer for low-precision (ternary/binary) training.
Key property: it discards all gradient magnitude information; only signs matter.
This aligns with ternary weight domains, where weights are in {-1, 0, +1}
and updates are discrete flips rather than continuous steps.
Memory: zero persistent optimizer state (no momentum buffers). It relies only
on what torch already tracks (params + grad), so there are 0 bytes/param of
overhead versus AdamW's 8 bytes/param (two float32 moment buffers). step()
allocates only transient per-parameter temporaries.
"""
import torch
from torch.optim import Optimizer
class ScaledOptum(Optimizer):
    """
    Sign Gradient-Sign Descent (SignGSD).

    Update rule, per parameter tensor ``p`` with gradient ``g``::

        p <- p - lr * (sign(g) + wd * sign(p))

    Compared to AdamW:
      - No first/second moment estimates (no exp_avg, exp_avg_sq).
      - No adaptive per-parameter learning rate.
      - Weight decay acts on sign(p), not on p itself. ``wd * sign(p)`` is the
        subgradient of the L1 penalty ``wd * |p|``, so this is L1-style,
        magnitude-independent shrinkage rather than L2-style decay.
      - A uniform LR for every element within a param group.

    Why this fits ternary training:
      Ternary weights live in {-1, 0, +1}. Continuous updates like
      ``p -= lr * grad`` immediately leave the ternary domain. SignGSD only
      votes on direction. The discrete flip decision can be made elsewhere,
      e.g. by an external accumulator (such as T_accum) that counts sign
      votes and flips at a threshold. That logic is not part of this class,
      which simply applies an in-place step to whatever tensors it is given.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        """
        Args:
            params: iterable of parameters or param groups.
            lr: uniform step size, the same for every element (no adaptive
                scaling). Must be non-negative.
            weight_decay: coefficient ``wd`` of the sign-based decay term
                ``wd * sign(p)`` (L1-style). Must be non-negative; 0 disables
                it. Combined with the gradient sign, the per-element update
                magnitude is:
                  - ``1 + wd`` when sign(grad) == sign(p) (both terms push p
                    toward zero);
                  - ``|1 - wd|`` when the signs disagree (the terms cancel
                    exactly only when ``wd == 1``).
                Elements with p == 0 receive no decay.

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative (or NaN).
        """
        # Same validation style as torch's built-in optimizers. Without it a
        # negative lr silently performs gradient *ascent*, and a negative
        # weight_decay is silently ignored by the `wd > 0` check in step().
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Perform a single optimization step.

        For every parameter that has a gradient:
          1. ``update = sign(grad)``: direction only, each element in
             {-1, 0, +1}. All gradient magnitude is discarded, unlike AdamW,
             which rescales by an RMS estimate. A zero gradient element
             gives a zero update.
          2. If ``weight_decay > 0``: ``update += wd * sign(p)``. Every
             nonzero element gets the same decay regardless of its magnitude.
          3. ``p -= lr * update``, applied in place.

        No optimizer state is created (``self.state`` is never written). The
        update is built in one temporary tensor per parameter.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled before
                the update is applied.

        Returns:
            The loss returned by ``closure``, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    # Entries absent from the sparse gradient become 0, so
                    # they get sign 0, i.e. no gradient vote.
                    grad = grad.to_dense()

                # sign() returns a fresh tensor, so the in-place ops below
                # never modify p.grad.
                update = grad.sign()
                if wd > 0:
                    # For ternary p in {-1, 0, +1}, sign(p) == p (including 0).
                    update.add_(p.sign(), alpha=wd)

                # NOTE(review): in the ternary pipeline the tensors passed
                # here are presumably latent/accumulator values rather than
                # the ternary weights themselves. The original author points
                # to prepare_ternary_backward / _ternary_update_memory in the
                # ARBS training loop, which is not visible here; confirm.
                p.add_(update, alpha=-lr)
        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """
        Return the total storage of parameter tensors in MiB (1024 * 1024 bytes).

        AdamW keeps two extra state tensors per parameter (8 bytes/param for
        float32 params). This optimizer keeps none, so the figure is just the
        parameters' own storage: ``numel * element_size``.

        Args:
            params: optional iterable of tensors to measure. Defaults to every
                parameter in ``self.param_groups``. A tensor object listed
                more than once is counted once.

        Returns:
            Size in MiB as a float.
        """
        if params is None:
            params = (p for group in self.param_groups for p in group["params"])
        # Deduplicate by identity so a tensor passed twice isn't double-counted.
        unique = {id(p): p for p in params}
        total_bytes = sum(p.numel() * p.element_size() for p in unique.values())
        return total_bytes / (1024 * 1024)