from __future__ import annotations

import torch
import torch.nn as nn

from chimera.quantization import BitLinear


class MeZOOptimizer:
    """Memory-Efficient Zeroth-Order optimiser (Princeton MeZO)."""

    def __init__(
        self,
        model: nn.Module,
        lr: float = 1e-4,
        eps: float = 1e-3,
        weight_decay: float = 0.0,
        momentum: float = 0.0,
        direction: str = "rademacher",
    ):
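        """Args:
            model: network to train; BitLinear submodules are special-cased.
            lr: learning rate for the projected-gradient update.
            eps: perturbation scale for the finite difference.
            weight_decay: decoupled weight-decay factor (0 disables it).
            momentum: heavy-ball coefficient (0 disables the buffers).
            direction: ``"rademacher"`` (±1 entries) or ``"gaussian"``.
        """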
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.momentum = float(momentum)
        if direction not in ("rademacher", "gaussian"):
            raise ValueError(f"unknown direction: {direction!r}")
        self.direction = direction

        # Partition trainable tensors: BitLinear layers are tracked as whole
        # modules (their weights get masked perturbations and their packed
        # caches need invalidating); everything else is a plain dense tensor.
        self._bitlinear_modules: list[tuple[str, BitLinear]] = []
        self._dense_params: list[tuple[str, torch.Tensor]] = []
        seen: set[int] = set()

        for name, module in model.named_modules():
            if isinstance(module, BitLinear):
                self._bitlinear_modules.append((name, module))
                seen.add(id(module.weight))
                if module.bias is not None:
                    seen.add(id(module.bias))

        for name, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._dense_params.append((name, param))
                seen.add(id(param))

        # Momentum buffers keyed by parameter identity (the id of the
        # Parameter object, matching what _walk_params yields). Note that
        # BitLinear biases get no buffer and take plain SGD updates.
        self._momentum: dict[int, torch.Tensor] = {}
        if self.momentum > 0:
            for _, param in self._dense_params:
                self._momentum[id(param)] = torch.zeros_like(param.data)
            for _, module in self._bitlinear_modules:
                self._momentum[id(module.weight)] = torch.zeros_like(module.weight.data)

        # Per-step ternary-support masks for BitLinear weights, rebuilt at
        # the start of every call to ``step``.
        self._step_masks: dict[int, torch.Tensor] = {}

    def _direction(self, p: torch.Tensor, seed: int) -> torch.Tensor:
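        """Regenerate the random direction for ``p`` from ``seed``.

        Sampling happens on CPU with a dedicated, seeded Generator so the
        draw is reproducible regardless of the parameter's device; this is
        what lets ``_update`` replay ``z`` instead of storing it.
        """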
        gen = torch.Generator(device="cpu")
        gen.manual_seed(int(seed) & 0x7FFF_FFFF_FFFF_FFFF)
        if self.direction == "gaussian":
            return torch.randn(p.shape, dtype=p.dtype, device="cpu", generator=gen).to(p.device)
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
        return z.to(p.device)

    def _walk_params(self):
        """Yield ``(offset, parameter, mask)`` triples in a fixed order.

        The Parameter objects themselves are yielded (not ``.data`` views,
        which are fresh Python objects on every access) so that ``id(param)``
        lines up with the momentum-buffer and mask keys built in ``__init__``.
        """
        offset = 0
        for _, module in self._bitlinear_modules:
            yield offset, module.weight, self._step_masks.get(id(module.weight))
            offset += 1
            if module.bias is not None:
                yield offset, module.bias, None
                offset += 1
        for _, param in self._dense_params:
            yield offset, param, None
            offset += 1

    def _perturb(self, base_seed: int, scale: float) -> None:
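        """Add ``scale * z`` to every parameter, regenerating each ``z``
        from ``base_seed`` plus a per-parameter offset.

        Because the directions are reproducible, calling this again with the
        opposite scale undoes the perturbation exactly (up to floating-point
        error), so no copy of the weights is ever kept.
        """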
        for off, param, mask in self._walk_params():
            z = self._direction(param, base_seed + off * 1_000_003)
            if mask is not None:
                # Restrict the perturbation to the ternary nonzero support.
                z = z * mask.to(dtype=z.dtype, device=z.device)
            param.data.add_(z, alpha=scale)
        # In-place edits invalidate any cached packed weights.
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    def _update(self, base_seed: int, projected_grad: float) -> None:
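        """Apply the MeZO update ``param -= lr * projected_grad * z`` (with
        optional momentum and decoupled weight decay), replaying the same
        directions ``z`` that were used for the perturbations.
        """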
        for off, param, mask in self._walk_params():
            z = self._direction(param, base_seed + off * 1_000_003)
            if mask is not None:
                z = z * mask.to(dtype=z.dtype, device=z.device)
            # ``param`` is the Parameter itself, so this lookup actually
            # hits the buffers created in __init__; keying by the transient
            # ``id(param.data)`` would silently disable momentum.
            buf = self._momentum.get(id(param))
            if buf is not None:
                # Heavy-ball: buf = momentum * buf + projected_grad * z.
                buf.mul_(self.momentum).add_(z, alpha=projected_grad)
                param.data.add_(buf, alpha=-self.lr)
            else:
                param.data.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                # Decoupled (AdamW-style) weight decay.
                param.data.mul_(1 - self.lr * self.wd)
        for _, module in self._bitlinear_modules:
            module.invalidate_packed()

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
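        """Run one MeZO step and return the average of the two probe losses.

        ``loss_fn(batch)`` must run a full forward pass and return a scalar
        loss tensor; no backward pass is ever taken.
        """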
        seed = int(torch.randint(0, 2**31, (1,)).item())
        # Freeze each BitLinear weight's ternary support for the whole step
        # so the perturbations and the update all see the same mask.
        self._step_masks = {
            id(m.weight): m.ternary_nonzero_mask().detach()
            for _, m in self._bitlinear_modules
        }
        self._perturb(seed, +self.eps)          # theta + eps * z
        loss_pos = float(loss_fn(batch).item())
        self._perturb(seed, -2.0 * self.eps)    # theta - eps * z
        loss_neg = float(loss_fn(batch).item())
        self._perturb(seed, +self.eps)          # restore theta
        # Central-difference estimate of the directional derivative along z.
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update(seed, projected_grad)
        self._step_masks = {}
        return 0.5 * (loss_pos + loss_neg)
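

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only, not part of the library):
    # fit a tiny regression model using forward passes alone. Plain nn.Linear
    # layers are used here, so the BitLinear-specific paths (ternary masks,
    # invalidate_packed) are not exercised.
    torch.manual_seed(0)
    model = nn.Sequential(nn.Linear(8, 16), nn.Tanh(), nn.Linear(16, 1))
    opt = MeZOOptimizer(model, lr=1e-3, eps=1e-3, momentum=0.9)
    x, y = torch.randn(256, 8), torch.randn(256, 1)

    def loss_fn(batch):
        xb, yb = batch
        return nn.functional.mse_loss(model(xb), yb)

    for i in range(200):
        loss = opt.step(loss_fn, (x, y))
        if i % 50 == 0:
            print(f"step {i:3d}  loss {loss:.4f}")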