from __future__ import annotations

import torch
import torch.nn as nn


class GrowLengthScheduler:
    """Map a training step to a target sequence length.

    ``stages`` is a sequence of ``(seq_len, fraction)`` pairs; the fractions
    are normalized, so they do not need to sum to 1.
    """

    def __init__(self, stages, total_steps):
        total_frac = sum(frac for _, frac in stages) or 1.0
        cumulative = 0
        self._boundaries = []
        for seq_len, frac in stages:
            # Each stage ends at its cumulative share of total_steps.
            cumulative += int(total_steps * frac / total_frac)
            self._boundaries.append((cumulative, int(seq_len)))

    def get_seq_len(self, step: int) -> int:
        for boundary, seq_len in self._boundaries:
            if step < boundary:
                return seq_len
        # Steps past the last boundary (possible after integer truncation)
        # keep the final stage's sequence length.
        return self._boundaries[-1][1]
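

# Hedged usage sketch (illustration only): one way GrowLengthScheduler could
# drive a sequence-length curriculum. ``make_batch`` and ``train_step`` are
# hypothetical placeholders, not part of this module.
def _example_grow_length(make_batch, train_step, total_steps: int = 5000) -> None:
    scheduler = GrowLengthScheduler(
        stages=[(256, 0.3), (512, 0.3), (1024, 0.4)],
        total_steps=total_steps,
    )
    for step in range(total_steps):
        seq_len = scheduler.get_seq_len(step)
        train_step(make_batch(seq_len))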


def apply_reservoir_freezing(model) -> int:
    """Turn recognised projection weights into fixed random reservoirs.

    Matching weights are re-initialized with ternary {-1, 0, 1} values,
    rescaled so their spectral norm is at most 1, and frozen. Returns the
    total number of parameters frozen.
    """
    frozen = 0
    for _, module in model.named_modules():
        targets = []
        if hasattr(module, "a_proj") and hasattr(module, "b_proj"):
            targets.extend(["a_proj", "b_proj"])
        if hasattr(module, "fgate") and hasattr(module, "igate"):
            targets.append("fgate")
        if hasattr(module, "alpha_proj") and hasattr(module, "eta_proj"):
            targets.append("alpha_proj")
        for attr in targets:
            proj = getattr(module, attr, None)
            if proj is None:
                continue
            weight = getattr(proj, "weight", None)
            if weight is None or not isinstance(weight, nn.Parameter):
                continue
            with torch.no_grad():
                # Random ternary reservoir, scaled to spectral norm <= 1
                # (assumes a 2-D weight matrix, e.g. from nn.Linear).
                weight.data = torch.randint(-1, 2, weight.shape, dtype=weight.dtype, device=weight.device)
                norm = torch.linalg.matrix_norm(weight.data.float(), ord=2).clamp(min=1.0)
                weight.data.div_(norm)
            weight.requires_grad = False
            frozen += weight.numel()
    return frozen
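

# Hedged usage sketch: freeze the reservoirs before constructing the optimizer
# so only the remaining trainable parameters receive updates. ``model`` is any
# module exposing the projection attributes checked above; the AdamW learning
# rate is an illustrative value.
def _example_reservoir_freeze(model: nn.Module) -> torch.optim.Optimizer:
    frozen = apply_reservoir_freezing(model)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"frozen {frozen} reservoir params, {trainable} still trainable")
    return torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad), lr=3e-4
    )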


class SeedReplayMeZO:
    """MeZO-style zeroth-order optimizer with seed-replayed perturbations.

    Rademacher (+/-1) noise is regenerated from a stored seed on every use,
    so no per-parameter noise tensors have to be kept alive between the two
    perturbed forward passes and the parameter update.
    """

    def __init__(self, model, *, lr=1e-4, eps=1e-3, weight_decay=0.0, momentum=0.9):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.mom = float(momentum)
        self._params = []
        seen = set()
        for _, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._params.append(param)
                seen.add(id(param))
        self._momentum = [torch.zeros_like(param.data) for param in self._params] if self.mom > 0 else None

    def _rademacher(self, seed: int, index: int, param: torch.Tensor) -> torch.Tensor:
        # Sample on CPU with a seeded generator so the same (seed, index) pair
        # always replays the same noise; bernoulli_ on a CUDA tensor cannot be
        # driven by a CPU generator, so sample first and move afterwards.
        gen = torch.Generator(device="cpu")
        gen.manual_seed((seed + index * 999983) & 0x7FFFFFFFFFFFFFFF)
        z = torch.empty(param.shape, dtype=torch.float32)
        z.bernoulli_(0.5, generator=gen).mul_(2.0).sub_(1.0)
        return z.to(device=param.device, dtype=param.dtype)

    def _perturb_inplace(self, seed: int, scale: float) -> None:
        for i, param in enumerate(self._params):
            param.data.add_(self._rademacher(seed, i, param), alpha=scale)

    def _update_inplace(self, seed: int, projected_grad: float) -> None:
        for i, param in enumerate(self._params):
            z = self._rademacher(seed, i, param)
            # Undo the remaining -eps perturbation, restoring the original weights.
            param.data.add_(z, alpha=self.eps)
            if self._momentum is not None:
                buf = self._momentum[i]
                buf.mul_(self.mom).add_(z, alpha=projected_grad)
                param.data.add_(buf, alpha=-self.lr)
            else:
                param.data.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                # Decoupled weight decay, applied after the gradient step.
                param.data.mul_(1 - self.lr * self.wd)

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        seed = int(torch.randint(0, 2**31, (1,)).item())
        self._perturb_inplace(seed, +self.eps)        # theta + eps * z
        loss_pos = float(loss_fn(batch).item())
        self._perturb_inplace(seed, -2.0 * self.eps)  # theta - eps * z
        loss_neg = float(loss_fn(batch).item())
        # Two-point estimate of the directional derivative along z.
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update_inplace(seed, projected_grad)
        return 0.5 * (loss_pos + loss_neg)
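

# Hedged usage sketch: driving SeedReplayMeZO with a loss closure that re-runs
# the forward pass (it is evaluated twice per step, under the +eps and -eps
# perturbations). The HuggingFace-style ``model(**batch).loss`` call and the
# ``batches`` iterable are assumptions, not part of this module.
def _example_mezo_training(model: nn.Module, batches) -> None:
    optimizer = SeedReplayMeZO(model, lr=1e-4, eps=1e-3, momentum=0.9)

    def loss_fn(batch):
        return model(**batch).loss

    for step, batch in enumerate(batches):
        loss = optimizer.step(loss_fn, batch)
        if step % 100 == 0:
            print(f"step {step}: loss {loss:.4f}")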


class ProgressiveUnfreezer:
    """Unfreeze ``model.layers`` from the top down, one block per stage."""

    def __init__(self, model, total_steps, n_stages=4):
        self._layers = model.layers
        self._n = len(self._layers)
        self._total = total_steps
        self._stages = n_stages
        # Ceil division so every layer is unfrozen by the final stage even when
        # the layer count is not divisible by n_stages.
        self._block = max(1, (self._n + n_stages - 1) // n_stages)
        self._current = self._n
        self.update(0)

    def update(self, step: int) -> int:
        stage = min(step * self._stages // max(1, self._total), self._stages - 1)
        target = max(0, self._n - (stage + 1) * self._block)
        if target != self._current:
            self._current = target
            for i, layer in enumerate(self._layers):
                requires_grad = i >= self._current
                for param in layer.parameters():
                    param.requires_grad = requires_grad
        return self._current
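

# Hedged usage sketch: call update() every step so deeper layers thaw as
# training progresses. ``model`` must expose a ``.layers`` list of blocks, and
# ``train_step`` is a hypothetical placeholder.
def _example_progressive_unfreeze(model, train_step, total_steps: int = 5000) -> None:
    unfreezer = ProgressiveUnfreezer(model, total_steps, n_stages=4)
    for step in range(total_steps):
        first_trainable = unfreezer.update(step)
        train_step(step)
        if step % 1000 == 0:
            print(f"step {step}: layers {first_trainable} and above are trainable")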


class ProgressiveLoopScheduler:
    """Gradually increase Parcae loop depth during training.

    With STE+AdamW (not MeZO), multi-loop training is affordable, and a
    progressive schedule avoids the instability of running deep loops early on.

    FIX: the old schedule (1→2→3 at 20%/60%/100%) was too aggressive: with
    5000 steps, loops=2 kicked in at step 1000 while the model was still at
    loss=10. Now loops=1 for the first 50% (stabilize), loops=2 for the next
    30%, and loops=3 for the final 20%, giving the model time to learn the
    basics before iterating.
    """

    def __init__(self, total_steps: int, max_loops: int = 3):
        self._total = total_steps
        self._max_loops = max_loops
        # (fraction-of-training threshold, loop count); the final threshold is
        # above 1.0 so it always covers the last steps.
        self._schedule = [
            (0.50, 1),
            (0.80, min(2, max_loops)),
            (1.01, min(3, max_loops)),
        ]

    def get_loops(self, step: int) -> int:
        frac = step / max(1, self._total)
        for threshold, loops in self._schedule:
            if frac < threshold:
                return loops
        return self._schedule[-1][1]


def patch_training_loops(model, num_loops=1) -> None:
    """Set initial loop config. Use ProgressiveLoopScheduler to change during training."""
    if hasattr(model, "loop_controller"):
        model.loop_controller.loop_default = num_loops
        model.loop_controller.loop_min = 1
        model.loop_controller.loop_max = max(num_loops, 3)

    if hasattr(model, "evo_every_n_layers"):
        model.evo_every_n_layers = max(model.evo_every_n_layers, 28)
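

# Hedged usage sketch: wiring patch_training_loops and ProgressiveLoopScheduler
# together so loop depth follows the 50/30/20 schedule above. ``model`` is
# assumed to expose the ``loop_controller`` attribute checked in
# patch_training_loops; ``train_step`` is a hypothetical placeholder.
def _example_progressive_loops(model, train_step, total_steps: int = 5000) -> None:
    patch_training_loops(model, num_loops=1)
    loop_scheduler = ProgressiveLoopScheduler(total_steps, max_loops=3)
    for step in range(total_steps):
        if hasattr(model, "loop_controller"):
            model.loop_controller.loop_default = loop_scheduler.get_loops(step)
        train_step(step)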