import torch
from torch.optim import Optimizer


class SignSGD(Optimizer):
    """Sign-based SGD: every parameter moves by ``lr`` times the sign of its gradient.

    Per step, for each parameter ``p`` that has a gradient::

        update = sign(grad) + weight_decay * sign(p)   # decay term only if weight_decay > 0
        p      = p - lr * update

    No per-parameter state is kept between steps.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Step size; must be non-negative.
        weight_decay: Coefficient applied to ``sign(p)``; must be non-negative.

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative (or NaN).
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        # "not 0.0 <= x" (rather than "x < 0.0") also rejects NaN.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    # Densify so the sign update touches every element of p.
                    grad = grad.to_dense()
                update = grad.sign()
                if wd > 0:
                    # Decay term uses sign(p), not p itself.
                    update = update + wd * p.sign()
                p.add_(-lr * update)
        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """Return the size of the given parameter tensors in MiB.

        The size is ``numel() * element_size()`` summed over the tensors,
        divided by ``1024 * 1024``.

        Args:
            params: Iterable of tensors to measure. When ``None``, all
                parameters from every param group of this optimizer are used.

        Returns:
            Total size in mebibytes as a float.
        """
        if params is None:
            params = [p for group in self.param_groups for p in group["params"]]
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)