ARBS / arbitor /optim /sign_sgd.py
CLIWorks's picture
Upload folder using huggingface_hub
d8bc908 verified
import torch
from torch.optim import Optimizer
class SignSGD(Optimizer):
    """Stateless sign-based SGD optimizer.

    Every parameter is moved by a fixed step of size ``lr`` in the direction
    opposite to the sign of its gradient::

        p <- p - lr * (sign(grad) + weight_decay * sign(p))

    The weight-decay term uses ``sign(p)`` rather than ``p`` itself, i.e. it
    has the form of the (sub)gradient of an L1 penalty on the parameters.
    No per-parameter state is stored by this optimizer.

    Args:
        params: Iterable of parameters or parameter-group dicts, forwarded
            to ``torch.optim.Optimizer``.
        lr: Step size. Must be non-negative.
        weight_decay: Coefficient of the ``sign(p)`` decay term. Must be
            non-negative; ``0`` disables the term.

    Raises:
        ValueError: If ``lr`` or ``weight_decay`` is negative or NaN.
    """

    def __init__(self, params, lr: float = 1e-2, weight_decay: float = 0.0):
        # ``not 0.0 <= x`` also rejects NaN, which a plain ``x < 0`` would miss.
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and
                returns the loss. It is invoked with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure
            was given.
        """
        loss = None
        if closure is not None:
            # The method runs under no_grad; the closure needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group["lr"]
            wd = group["weight_decay"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    # Densify so the update covers every element of p.
                    grad = grad.to_dense()
                update = grad.sign()
                if wd > 0:
                    update = update + wd * p.sign()
                # Scale inside the in-place add instead of materialising
                # a separate ``-lr * update`` tensor.
                p.add_(update, alpha=-lr)
        return loss

    @torch.no_grad()
    def get_memory_mb(self, params=None) -> float:
        """Return the storage size of parameter tensors in MiB.

        Only the parameter tensors themselves are counted
        (``numel() * element_size()``); gradients are not included.

        Args:
            params: Optional iterable of tensors to measure. When ``None``,
                all parameters in this optimizer's param groups are used.

        Returns:
            Total size in mebibytes (bytes / 1024**2).
        """
        if params is None:
            params = [p for group in self.param_groups for p in group["params"]]
        total_bytes = sum(p.numel() * p.element_size() for p in params)
        return total_bytes / (1024 * 1024)