from __future__ import annotations

import torch
import torch.nn as nn

from chimera.quantization import BitLinear


class MeZOOptimizer:
    """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO).

    Each ``step`` evaluates the loss twice, with the parameters shifted by
    ``+eps`` and ``-eps`` along one random direction, and uses the finite
    difference as a scalar "projected gradient" along that direction.  The
    direction is never stored: it is regenerated from a seed whenever it is
    needed, so no gradient-sized tensors are kept (apart from the optional
    momentum buffers).

    ``BitLinear`` modules are handled separately from ordinary parameters:
    their weight direction is multiplied by the module's
    ``ternary_nonzero_mask()`` captured at the start of the step, and
    ``invalidate_packed()`` is called on them after every in-place change.
    """

    # Per-tensor seed spacing so every tensor draws from its own stream.
    _SEED_STRIDE = 1_000_003

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-4,
        eps: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
        direction: str = "rademacher",
    ):
        """Collect the tensors to optimise and allocate optional state.

        Args:
            model: Module whose parameters are updated in place.
            lr: Step size.
            eps: Perturbation magnitude for the finite difference; must be
                non-zero.
            weight_decay: When > 0, every optimised tensor is multiplied by
                ``1 - lr * weight_decay`` after its update.
            momentum: When > 0, a momentum buffer is kept for every dense
                parameter and every ``BitLinear`` weight.  ``BitLinear``
                biases have no buffer and always receive the plain update.
            direction: ``"rademacher"`` (entries are +1/-1) or
                ``"gaussian"`` (standard normal entries).

        Raises:
            ValueError: If ``direction`` is unknown or ``eps`` is zero.
        """
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.momentum = float(momentum)
        if direction not in ("rademacher", "gaussian"):
            raise ValueError(f"unknown direction: {direction!r}")
        if self.eps == 0.0:
            # step() divides by 2 * eps; fail here instead of mid-training.
            raise ValueError("eps must be non-zero")
        self.direction = direction

        self._bitlinear_modules: list[tuple[str, BitLinear]] = []
        self._dense_params: list[tuple[str, torch.Tensor]] = []

        # ids of parameters already claimed, so a BitLinear weight/bias is
        # not also registered as a dense parameter.
        seen: set[int] = set()
        for name, module in model.named_modules():
            if isinstance(module, BitLinear):
                self._bitlinear_modules.append((name, module))
                seen.add(id(module.weight))
                if module.bias is not None:
                    seen.add(id(module.bias))
        for name, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._dense_params.append((name, param))
                seen.add(id(param))

        # Momentum buffers, keyed by id() of the owning parameter object
        # (NOT of its ``.data`` view, which is a new object on every access).
        self._momentum: dict[int, torch.Tensor] = {}
        if self.momentum > 0:
            for _, param in self._dense_params:
                self._momentum[id(param)] = torch.zeros_like(param.data)
            for _, module in self._bitlinear_modules:
                self._momentum[id(module.weight)] = torch.zeros_like(
                    module.weight.data
                )

        # BitLinear weight masks, keyed by id(module.weight); populated only
        # for the duration of a step() call.
        self._step_masks: dict[int, torch.Tensor] = {}

    def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
        """Return the random direction for ``p`` determined by ``seed``.

        The draw always happens on a CPU generator and is then moved to
        ``p``'s device, so the result depends only on the seed, shape,
        dtype and direction kind, not on where ``p`` lives.
        """
        gen = torch.Generator(device="cpu")
        # Keep the seed a non-negative 63-bit value.
        gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
        if self.direction == "gaussian":
            return torch.randn(
                p.shape, dtype=p.dtype, device="cpu", generator=gen
            ).to(p.device)
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        # Bernoulli {0, 1} mapped to {-1, +1}.
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(p.device)

    def _masked_direction(
        self, tensor: torch.Tensor, seed: int, mask: torch.Tensor | None
    ) -> torch.Tensor:
        """Return the direction for ``tensor``, multiplied by ``mask`` if given."""
        z = self._direction(tensor, seed)
        if mask is not None:
            z = z * mask.to(dtype=z.dtype, device=z.device)
        return z

    def _iter_params(self):
        """Yield ``(offset, key, tensor, mask)`` for every optimised tensor.

        ``offset`` is a running index used to derive the per-tensor seed,
        ``key`` is ``id()`` of the owning parameter object (the key used by
        ``self._momentum``), ``tensor`` is the parameter's ``.data`` and
        ``mask`` is the step mask for BitLinear weights, else ``None``.
        Order: each BitLinear weight, then its bias if present, then the
        dense parameters.
        """
        offset = 0
        for _, module in self._bitlinear_modules:
            weight = module.weight
            key = id(weight)
            yield offset, key, weight.data, self._step_masks.get(key)
            offset += 1
            bias = module.bias
            if bias is not None:
                yield offset, id(bias), bias.data, None
                offset += 1
        for _, param in self._dense_params:
            yield offset, id(param), param.data, None
            offset += 1

    def _walk_params(self):
        """Yield ``(offset, tensor, mask)`` for every optimised tensor."""
        for offset, _, tensor, mask in self._iter_params():
            yield offset, tensor, mask

    def _perturb(self, base_seed: int, scale: float) -> None:
        """Add ``scale * z`` in place to every tensor, z regenerated from the seed."""
        for off, param, mask in self._walk_params():
            z = self._masked_direction(
                param, base_seed + off * self._SEED_STRIDE, mask
            )
            param.add_(z, alpha=scale)
        # Weights changed in place; let BitLinear know.
        # NOTE(review): presumably drops a cached packed copy - confirm
        # against BitLinear.
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    def _update(self, base_seed: int, projected_grad: float) -> None:
        """Apply the descent step along the direction used for perturbing.

        With a momentum buffer: ``buf = momentum * buf + projected_grad * z``
        and the tensor moves by ``-lr * buf``.  Without one, it moves by
        ``-lr * projected_grad * z``.  Weight decay, when enabled, is then
        applied multiplicatively.
        """
        for off, key, param, mask in self._iter_params():
            z = self._masked_direction(
                param, base_seed + off * self._SEED_STRIDE, mask
            )
            # Look up by the parameter's id; id(param) here would be the id
            # of a throw-away ``.data`` tensor and would never match.
            buf = self._momentum.get(key)
            if buf is not None:
                buf.mul_(self.momentum).add_(z, alpha=projected_grad)
                param.add_(buf, alpha=-self.lr)
            else:
                param.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                param.mul_(1 - self.lr * self.wd)
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step and return the mean of the two probed losses.

        Args:
            loss_fn: Callable invoked as ``loss_fn(batch)``; its result must
                provide ``.item()``.  It is called twice, once per
                perturbation sign.
            batch: Passed through to ``loss_fn`` unchanged.

        Returns:
            ``0.5 * (loss_pos + loss_neg)``.

        If ``loss_fn`` raises, the perturbation applied so far is undone
        before the exception propagates, so the model is not left shifted.
        """
        seed = int(torch.randint(0, 2**31, (1,)).item())
        # Freeze the masks now so all perturbations and the update in this
        # step use the same mask even though the weights move in between.
        self._step_masks = {
            id(m.weight): m.ternary_nonzero_mask().detach()
            for _, m in self._bitlinear_modules
        }
        # Net multiple of z currently added to the parameters.
        displaced = 0.0
        try:
            self._perturb(seed, +self.eps)
            displaced = self.eps
            loss_pos = float(loss_fn(batch).item())

            self._perturb(seed, -2.0 * self.eps)
            displaced = -self.eps
            loss_neg = float(loss_fn(batch).item())

            self._perturb(seed, +self.eps)
            displaced = 0.0

            projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
            self._update(seed, projected_grad)
        finally:
            if displaced != 0.0:
                # An evaluation failed mid-step: restore the parameters.
                self._perturb(seed, -displaced)
            self._step_masks = {}
        return 0.5 * (loss_pos + loss_neg)