| """ |
| SignGSD — Sign Gradient-Sign Descent optimizer. |
| |
| A minimal optimizer for low-precision (ternary/binary) training. |
| Key property: discards all magnitude information. Only signs matter. |
| This aligns with ternary weight domains where weights are {-1, 0, +1} |
| and updates are discrete flips rather than continuous steps. |
| |
| Memory: zero optimizer state (no momentum buffers). Only stores what |
| torch already tracks (params + grad). 0 bytes overhead vs AdamW's |
| 8 bytes/param (2× float32). |
| """ |
| import torch |
| from torch.optim import Optimizer |
|
|
|
|
class ScaledOptum(Optimizer):
    """Sign Gradient-Sign Descent (SignGSD): a stateless, sign-only optimizer.

    Update rule, applied element-wise to every parameter that has a gradient::

        p <- p - lr * (sign(grad) + weight_decay * sign(p))

    Compared to AdamW:
        - No first/second moment estimates (no ``exp_avg`` / ``exp_avg_sq``).
        - No adaptive per-parameter learning rate; every element moves by
          the same ``lr``-sized step.
        - Weight decay acts on ``sign(p)`` rather than on ``p`` itself.

    Why this suits ternary training:
        Ternary weights live in {-1, 0, +1}. A magnitude-dependent update
        such as ``p -= lr * grad`` carries information a ternary weight
        cannot use. SignGSD reduces each gradient element to a direction
        vote (+1, 0 or -1). Note that this class still applies the vote as a
        continuous ``lr``-sized step to ``p``; turning votes into discrete
        flips (for example with an external accumulator that flips a weight
        once enough votes agree) is not implemented here.

    Memory:
        No per-parameter optimizer state is created. ``step`` only allocates
        short-lived sign tensors, which are scaled in place.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        """Create the optimizer.

        Args:
            params: iterable of parameters or param-group dicts.
            lr: uniform step size. Every element with a nonzero gradient
                moves by exactly ``lr`` (plus the decay term, if enabled).
            weight_decay: coefficient of the ``sign(p)`` decay term. This is
                L1-style decay: ``weight_decay * sign(p)`` is the gradient
                of ``weight_decay * |p|``. Every nonzero weight is pulled
                toward zero by the same amount ``lr * weight_decay``
                regardless of its magnitude, and weights that are exactly
                zero receive no decay. When ``sign(grad) == sign(p)`` the two
                terms reinforce each other (step ``lr * (1 + weight_decay)``
                toward zero). When they disagree they partially cancel
                (step ``lr * (1 - weight_decay)``), cancelling fully only at
                ``weight_decay == 1``.

        Raises:
            ValueError: if ``lr`` or ``weight_decay`` is negative.
        """
        # A negative lr would silently turn descent into ascent, and a
        # negative weight_decay would be silently ignored by step(). Reject
        # both up front, as the built-in torch optimizers do.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        For every parameter with a gradient:
            1. Take ``grad.sign()``, which keeps the direction (+1 / 0 / -1)
               per element and discards all magnitude.
            2. If the group's ``weight_decay`` is positive, add
               ``weight_decay * p.sign()``.
            3. Apply ``p += -lr * update`` in place.

        Parameters whose ``.grad`` is ``None`` are skipped. Sparse gradients
        are densified before taking the sign. ``p.grad`` is never modified.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss. It is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on
            # so it can call backward().
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]

            for p in group["params"]:
                grad = p.grad
                if grad is None:
                    continue
                if grad.is_sparse:
                    grad = grad.to_dense()

                # sign() returns a new tensor, so the in-place operations
                # below never touch p.grad or p until the final add_.
                update = grad.sign()

                if wd > 0:
                    # p.sign() is also a fresh tensor: scale it in place and
                    # fold it into `update` without further temporaries.
                    update.add_(p.sign().mul_(wd))

                p.add_(update.mul_(-lr))

        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """Return the total size of the given parameter tensors in MiB.

        Because this optimizer keeps no per-parameter state, the figure
        reported is just ``numel * element_size`` summed over the parameter
        tensors themselves.

        Args:
            params: iterable of tensors to measure. When ``None``, all
                parameters in this optimizer's param groups are measured.

        Returns:
            Total bytes divided by ``1024 * 1024``.
        """
        if params is None:
            params = (p for group in self.param_groups for p in group["params"])
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)
|
|