from __future__ import annotations

import torch
import torch.nn as nn

from chimera.quantization import BitLinear


class MeZOOptimizer:
| """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO).""" |
|
|
| def __init__( |
        self,
        model: nn.Module,
        lr: float = 1e-4,
        eps: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
        direction: str = "rademacher",
    ):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.momentum = float(momentum)
        if direction not in ("rademacher", "gaussian"):
            raise ValueError(f"unknown direction: {direction!r}")
        self.direction = direction

        # BitLinear weights (and biases) are handled by the masked walk below,
        # so they must not be registered again as plain dense parameters.
        self._bitlinear_modules: list[tuple[str, BitLinear]] = []
        self._dense_params: list[tuple[str, torch.Tensor]] = []
        seen: set[int] = set()

        for name, module in model.named_modules():
            if isinstance(module, BitLinear):
                self._bitlinear_modules.append((name, module))
                seen.add(id(module.weight))
                if module.bias is not None:
                    seen.add(id(module.bias))

        # Every remaining trainable parameter is updated without a mask.
        for name, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._dense_params.append((name, param))
                seen.add(id(param))

        # Optional momentum buffers for dense parameters and BitLinear weights.
        self._momentum: dict[int, torch.Tensor] = {}
        if self.momentum > 0:
            for _, param in self._dense_params:
                self._momentum[id(param)] = torch.zeros_like(param.data)
            for _, module in self._bitlinear_modules:
                self._momentum[id(module.weight)] = torch.zeros_like(module.weight.data)

        # Ternary masks for the current step, keyed by id(weight); set in step().
        self._step_masks: dict[int, torch.Tensor] = {}

    def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
        """Regenerate the random direction for ``p`` deterministically from ``seed``."""
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
        if self.direction == "gaussian":
            return torch.randn(p.shape, dtype=p.dtype, device="cpu", generator=gen).to(p.device)
        # Rademacher direction: entries are +1 or -1 with equal probability.
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(p.device)

    def _walk_params(self):
        """Yield ``(offset, tensor, mask)`` for every tensor the optimiser touches.

        Offsets are deterministic across calls so that ``_perturb`` and
        ``_update`` regenerate identical directions for a given step seed.
        """
        offset = 0
        for _, module in self._bitlinear_modules:
            yield offset, module.weight.data, self._step_masks.get(id(module.weight))
            offset += 1
            if module.bias is not None:
                yield offset, module.bias.data, None
                offset += 1
        for _, param in self._dense_params:
            yield offset, param.data, None
            offset += 1

    def _perturb(self, base_seed: int, scale: float) -> None:
        # Add scale * z in place; the same seed later reproduces the same z.
        for off, param, mask in self._walk_params():
            z = self._direction(param, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            param.add_(z, alpha=scale)
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    def _update(self, base_seed: int, projected_grad: float) -> None:
        for off, param, mask in self._walk_params():
            z = self._direction(param, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            buf = self._momentum.get(id(param))
            if buf is not None:
                # Heavy-ball momentum on the pseudo-gradient projected_grad * z.
                buf.mul_(self.momentum).add_(z, alpha=projected_grad)
                param.add_(buf, alpha=-self.lr)
            else:
                param.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                # Decoupled weight decay, applied after the MeZO update.
                param.mul_(1 - self.lr * self.wd)
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step and return the mean of the two probed losses."""
        seed = int(torch.randint(0, 2**31, (1,)).item())
        self._step_masks = {
            id(m.weight): m.ternary_nonzero_mask().detach()
            for _, m in self._bitlinear_modules
        }
        self._perturb(seed, +self.eps)        # theta + eps * z
        loss_pos = float(loss_fn(batch).item())
        self._perturb(seed, -2.0 * self.eps)  # theta - eps * z
        loss_neg = float(loss_fn(batch).item())
        self._perturb(seed, +self.eps)        # restore theta
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update(seed, projected_grad)
        self._step_masks = {}
        return 0.5 * (loss_pos + loss_neg)
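

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the optimiser).  It
# drives MeZOOptimizer for a few steps on a tiny dense model, so only the
# dense-parameter path is exercised; a real run would use a model containing
# BitLinear layers from chimera.quantization.  The model, batch, and loss
# below are made-up placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    torch.manual_seed(0)
    demo_model = nn.Sequential(nn.Linear(4, 16), nn.Tanh(), nn.Linear(16, 1))
    optimizer = MeZOOptimizer(demo_model, lr=1e-2, eps=1e-3, momentum=0.9)

    x = torch.randn(32, 4)
    y = torch.randn(32, 1)

    def mse_loss(batch):
        inputs, targets = batch
        return nn.functional.mse_loss(demo_model(inputs), targets)

    for step_idx in range(10):
        loss = optimizer.step(mse_loss, (x, y))
        print(f"step {step_idx}: loss={loss:.4f}")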