import torch
from torch.optim import Optimizer


class SnooC(Optimizer):
| """ |
| @DominikKallusky, @vishal9-team, @vinaysrao |
| |
| Sparse Nesterov Outer Optimizer (Snoo) is a momentum-based wrapper to any optimizer that can |
| improve the stability and smoothness of the optimization process and thus the quality |
| of large language models (LLM) and other models. Snoo implicitly adds temporal regularization |
| to the parameters, thus smoothing the training trajectory and instilling a bias towards flatter |
| minima and lower parameter norms. Snoo is computationally efficient, incurring minimal overhead |
| in compute and moderate memory usage. |
| """ |
|
|
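    # Minimal usage sketch (``model`` and ``loader`` are hypothetical names, not
    # part of this module); any torch.optim optimizer can serve as the inner one:
    #
    #     inner = torch.optim.AdamW(model.parameters(), lr=3e-4)
    #     opt = SnooC(inner, lr=0.67, momentum=0.67, k=20)
    #     for batch in loader:
    #         loss = model(batch)
    #         loss.backward()
    #         opt.step()
    #         opt.zero_grad(set_to_none=True)
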
    @torch.no_grad()
    def __init__(self, optimizer, lr: float = 0.67, momentum: float = 0.67, k: int = 20) -> None:
        # Note: ``super().__init__`` is not called; this class is a thin wrapper
        # that delegates all parameter bookkeeping to the inner optimizer.
        self.optimizer = optimizer
        self.lr = lr
        self.momentum = momentum
        self.k = k
        self.current_step = 0
        self.model_params = None
        self.outer_buf = None
        self.outer_optimizer = None

        # Mirror the inner optimizer's param groups so that LR schedulers and
        # other utilities that inspect ``param_groups`` keep working.
        if self.optimizer.param_groups:
            self.param_groups = self.optimizer.param_groups

    @torch.no_grad()
    def _initialize_outer_optimizer(self):
        # Flatten every tensor out of the inner optimizer's param groups. Using
        # ``extend`` (rather than reassigning ``params``) ensures parameters from
        # single-parameter groups do not overwrite previously collected ones.
        params = []
        for pg in self.optimizer.param_groups:
            params.extend(p for p in pg['params'] if isinstance(p, torch.Tensor))

        if not params:
            return

        self.model_params = params
        # Snapshot of the parameters at the last outer step; used both to form
        # the pseudo-gradient and to rewind the parameters before the outer update.
        self.outer_buf = [p.clone() for p in self.model_params]
        # Outer Nesterov-SGD over the live parameters. ``fused=True`` requires a
        # recent PyTorch build and floating-point parameters on a supported device.
        self.outer_optimizer = torch.optim.SGD(
            self.model_params,
            lr=self.lr,
            momentum=self.momentum,
            nesterov=True,
            fused=True,
        )
        self.param_groups = self.optimizer.param_groups

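    # Outer-step mechanics: every ``k`` inner steps, the parameter displacement
    # accumulated since the last snapshot is treated as a pseudo-gradient,
    # grad = theta_old - theta_new; the parameters are rewound to theta_old, a
    # single Nesterov-SGD step is applied, and the snapshot is then refreshed.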
    @torch.no_grad()
    def step(self, closure=None):
        # Lazily build the outer optimizer on the first step, once the inner
        # optimizer's param groups have been populated.
        if self.outer_optimizer is None:
            if self.optimizer.param_groups:
                self._initialize_outer_optimizer()
            else:
                # No parameters registered yet; just run the inner step.
                return self.optimizer.step(closure)

        loss = self.optimizer.step(closure)
        self.current_step += 1
        # Take one outer step every ``k`` inner steps. Incrementing before the
        # check avoids a spurious outer step immediately after initialization.
        if self.outer_optimizer is not None and self.current_step % self.k == 0:
            # Treat the displacement since the last snapshot as the outer
            # pseudo-gradient and rewind the parameters to the snapshot.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_new.grad = p_old.data - p_new.data
                p_new.copy_(p_old, non_blocking=True)

            self.outer_optimizer.step()

            # Refresh the snapshot with the post-outer-step parameters.
            for p_new, p_old in zip(self.model_params, self.outer_buf):
                p_old.copy_(p_new, non_blocking=True)
        return loss

    def zero_grad(self, set_to_none: bool = False):
        # Gradients live on the inner optimizer's parameters; delegate to it.
        self.optimizer.zero_grad(set_to_none=set_to_none)

    def state_dict(self):
        # Only the inner optimizer's state is checkpointed; the outer momentum
        # buffer and step counter are rebuilt from scratch after a restart.
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        self.optimizer.load_state_dict(state_dict)