# chimera/training/optimizers.py
# Uploaded via huggingface_hub by Lgr54HFi (revision 11c11f8, verified).
from __future__ import annotations
import torch
import torch.nn as nn
from chimera.quantization import BitLinear
class MeZOOptimizer:
    """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO).

    Estimates the gradient from two forward passes: every tracked parameter
    is perturbed in place by ``+eps`` and then ``-eps`` along a shared random
    direction ``z``, and the finite difference of the two losses gives the
    gradient projected onto ``z``.  Directions are never stored — they are
    regenerated deterministically from a per-step seed — so memory overhead
    is O(1) beyond the optional momentum buffers.

    Args:
        model: module whose trainable parameters are updated in place.
        lr: learning rate applied to the projected gradient.
        eps: perturbation scale for the two-point finite difference.
        weight_decay: decoupled multiplicative decay applied after each step.
        momentum: heavy-ball momentum factor; ``0`` disables the buffers.
        direction: ``"rademacher"`` (entries in {-1, +1}) or ``"gaussian"``.

    Raises:
        ValueError: if ``direction`` is not one of the two supported names.
    """

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-4,
        eps: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
        direction: str = "rademacher",
    ):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.momentum = float(momentum)
        if direction not in ("rademacher", "gaussian"):
            raise ValueError(f"unknown direction: {direction!r}")
        self.direction = direction

        # BitLinear layers are tracked separately so their ternary sparsity
        # masks can gate perturbations of their weight matrices.
        self._bitlinear_modules: list[tuple[str, BitLinear]] = []
        self._dense_params: list[tuple[str, torch.Tensor]] = []
        seen: set[int] = set()
        for name, module in model.named_modules():
            if isinstance(module, BitLinear):
                self._bitlinear_modules.append((name, module))
                seen.add(id(module.weight))
                if module.bias is not None:
                    seen.add(id(module.bias))
        for name, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._dense_params.append((name, param))
                seen.add(id(param))

        # Momentum buffers keyed by id(parameter).  Note: only dense params
        # and BitLinear weights get buffers — BitLinear biases are excluded
        # (presumably deliberate; they fall through to the plain SGD branch).
        self._momentum: dict[int, torch.Tensor] = {}
        if self.momentum > 0:
            for _, param in self._dense_params:
                self._momentum[id(param)] = torch.zeros_like(param.data)
            for _, module in self._bitlinear_modules:
                self._momentum[id(module.weight)] = torch.zeros_like(module.weight.data)

        # Per-step ternary masks for BitLinear weights; populated in step().
        self._step_masks: dict[int, torch.Tensor] = {}

    def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
        """Regenerate the random direction for one tensor from *seed*.

        Sampling always happens on CPU with an explicit generator so a given
        seed yields the same direction regardless of ``p``'s device.
        """
        gen = torch.Generator(device="cpu")
        # manual_seed requires a non-negative value; mask to 63 bits.
        gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
        if self.direction == "gaussian":
            return torch.randn(p.shape, dtype=p.dtype, device="cpu", generator=gen).to(p.device)
        # Rademacher: ±1 with equal probability.
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(p.device)

    def _walk_params(self):
        """Yield ``(offset, parameter, mask_or_None)`` in a fixed order.

        The offset doubles as a per-tensor seed displacement, so the
        iteration order must be identical between _perturb() and _update().

        BUG FIX: yield the Parameter objects themselves, not ``.data`` —
        ``Tensor.data`` returns a fresh tensor object on every access, so
        ``id()``-keyed lookups (momentum buffers, step masks) built from the
        Parameter ids would never match and momentum silently never applied.
        """
        offset = 0
        for _, module in self._bitlinear_modules:
            yield offset, module.weight, self._step_masks.get(id(module.weight))
            offset += 1
            if module.bias is not None:
                yield offset, module.bias, None
                offset += 1
        for _, param in self._dense_params:
            yield offset, param, None
            offset += 1

    def _perturb(self, base_seed: int, scale: float) -> None:
        """In-place add ``scale * z`` to every tracked parameter."""
        for off, param, mask in self._walk_params():
            # Large odd stride decorrelates per-tensor seeds.
            z = self._direction(param, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            param.data.add_(z, alpha=scale)
        # Weights changed in place — cached packed representations are stale.
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    def _update(self, base_seed: int, projected_grad: float) -> None:
        """Apply the SGD(+momentum) step along the regenerated directions."""
        for off, param, mask in self._walk_params():
            z = self._direction(param, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            # Lookup now matches the id(parameter) keys set in __init__.
            buf = self._momentum.get(id(param))
            if buf is not None:
                buf.mul_(self.momentum).add_(z, alpha=projected_grad)
                param.data.add_(buf, alpha=-self.lr)
            else:
                param.data.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                # Decoupled (AdamW-style) weight decay, applied post-step.
                param.data.mul_(1 - self.lr * self.wd)
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step: two perturbed forward passes, then the update.

        Args:
            loss_fn: callable mapping *batch* to a scalar loss tensor.
            batch: opaque batch object forwarded to *loss_fn*.

        Returns:
            The mean of the two perturbed losses (a float).
        """
        seed = int(torch.randint(0, 2**31, (1,)).item())
        # Freeze the ternary sparsity pattern for the duration of this step
        # so +eps / -eps / restore all touch exactly the same entries.
        self._step_masks = {id(m.weight): m.ternary_nonzero_mask().detach() for _, m in self._bitlinear_modules}
        self._perturb(seed, +self.eps)
        loss_pos = float(loss_fn(batch).item())
        self._perturb(seed, -2.0 * self.eps)  # net -eps from the origin
        loss_neg = float(loss_fn(batch).item())
        self._perturb(seed, +self.eps)  # restore original parameters
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update(seed, projected_grad)
        self._step_masks = {}
        return 0.5 * (loss_pos + loss_neg)