| import torch | |
| from torch.optim import Optimizer | |
class SignSGD(Optimizer):
    """Sign-based SGD: every parameter moves by ``lr`` times the sign of its update.

    The update for a parameter ``p`` is ``sign(p.grad)``; when ``weight_decay``
    is positive, ``weight_decay * sign(p)`` is added to that update (the
    gradient of an L1 penalty on the parameter). The optimizer keeps no
    per-parameter state.

    Args:
        params: Iterable of parameters, or of dicts defining parameter groups.
        lr: Step size applied to the signed update. Must be non-negative.
        weight_decay: Coefficient of the ``sign(p)`` decay term. Must be
            non-negative; ``0`` disables the term.

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        # Reject values that would silently flip the update direction (lr) or
        # be silently ignored by the ``wd > 0`` check in ``step``.
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad, so gradients must be switched
            # back on for the closure's forward/backward pass.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    # Densify first so the result has the same layout as ``p``.
                    grad = grad.to_dense()
                update = grad.sign()
                if wd > 0:
                    update = update + wd * p.sign()
                # In-place update; legal on leaf parameters only because the
                # whole method runs under torch.no_grad().
                p.add_(-lr * update)
        return loss

    def get_memory_mb(self, params=None) -> float:
        """Return the memory occupied by parameter tensors, in MiB.

        Args:
            params: Optional iterable of tensors to measure. When ``None``,
                every parameter in every parameter group is measured.

        Returns:
            Sum of ``numel() * element_size()`` over the tensors, divided by
            ``1024 * 1024``.
        """
        if params is None:
            params = [p for group in self.param_groups for p in group["params"]]
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)