File size: 1,242 Bytes
d8bc908 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | import torch
from torch.optim import Optimizer
class SignSGD(Optimizer):
    """Sign-based SGD: each step moves every parameter by ``lr`` times the
    element-wise sign of its gradient.

    When ``weight_decay`` is positive, ``weight_decay * sign(p)`` is added to
    the update direction before scaling by ``lr`` (the decay term uses the
    sign of the parameter, not the parameter value itself).

    Args:
        params: Iterable of parameters or parameter-group dicts, forwarded to
            ``torch.optim.Optimizer``.
        lr: Step size. Must be non-negative.
        weight_decay: Coefficient of the sign-of-parameter decay term. Must be
            non-negative; ``0.0`` disables it.

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative.
    """

    def __init__(self, params, lr=1e-2, weight_decay=0.0):
        # Reject negative hyperparameters up front: a negative weight_decay
        # would otherwise be silently ignored by the `wd > 0` check in step().
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns
                the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            # step() runs under no_grad; the closure needs autograd back on.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    grad = grad.to_dense()
                # sign() allocates a fresh tensor, so in-place edits below
                # never touch p.grad.
                update = grad.sign()
                if wd > 0:
                    update.add_(p.sign(), alpha=wd)
                # alpha=-lr avoids materialising a scaled temporary tensor.
                p.add_(update, alpha=-lr)
        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """Return the storage size of parameter tensors in MiB.

        The size is ``numel() * element_size()`` summed over the tensors; only
        the parameters themselves are counted.

        Args:
            params: Iterable of tensors to measure. When ``None``, every
                parameter in ``self.param_groups`` is measured.

        Returns:
            Total size in mebibytes (bytes / 1024**2).
        """
        if params is None:
            params = [p for group in self.param_groups for p in group["params"]]
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)
|