import torch


class LARS(torch.optim.Optimizer):
    """
    Layer-wise Adaptive Rate Scaling (LARS) optimizer.

    No trust-ratio scaling or weight decay is applied to parameters with
    <= 1 dimension (biases and normalization gains/shifts); those fall
    back to plain SGD with momentum.
    """

    def __init__(self, params, lr=0, weight_decay=0, momentum=0.9, trust_coefficient=0.001):
        defaults = dict(lr=lr, weight_decay=weight_decay, momentum=momentum,
                        trust_coefficient=trust_coefficient)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self):
        for g in self.param_groups:
            for p in g['params']:
                dp = p.grad

                if dp is None:
                    continue

                # Only multi-dimensional parameters (weight matrices, conv
                # kernels) get weight decay and layer-wise rate scaling.
                if p.ndim > 1:
                    dp = dp.add(p, alpha=g['weight_decay'])
                    param_norm = torch.norm(p)
                    update_norm = torch.norm(dp)
                    one = torch.ones_like(param_norm)
                    # Trust ratio q = trust_coefficient * ||p|| / ||dp||,
                    # guarded to fall back to 1 when either norm is zero.
                    q = torch.where(param_norm > 0.,
                                    torch.where(update_norm > 0,
                                                (g['trust_coefficient'] * param_norm / update_norm), one),
                                    one)
                    dp = dp.mul(q)

                # Heavy-ball momentum buffer, created lazily per parameter.
                param_state = self.state[p]
                if 'mu' not in param_state:
                    param_state['mu'] = torch.zeros_like(p)
                mu = param_state['mu']
                mu.mul_(g['momentum']).add_(dp)
                p.add_(mu, alpha=-g['lr'])
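

# Minimal usage sketch: wires the optimizer into a single training step.
# The toy model, data, and hyperparameter values below are illustrative
# assumptions, not part of the original file.
if __name__ == "__main__":
    torch.manual_seed(0)  # assumed seed, for reproducibility of the demo only
    model = torch.nn.Sequential(
        torch.nn.Linear(32, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 10),
    )
    # 2D weight matrices receive LARS trust-ratio scaling and weight decay;
    # 1D biases are updated with plain SGD + momentum.
    optimizer = LARS(model.parameters(), lr=0.1, weight_decay=1e-6)

    x = torch.randn(8, 32)            # hypothetical batch of inputs
    y = torch.randint(0, 10, (8,))    # hypothetical class targets
    loss = torch.nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print(f"loss before update: {loss.item():.4f}")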