from __future__ import annotations
import torch
import torch.nn as nn

class GrowLengthScheduler:
    """Grow the training sequence length in stages.

    `stages` is a list of (seq_len, fraction) pairs; fractions are normalized,
    so they need not sum to 1. Each stage occupies its share of `total_steps`.
    """

    def __init__(self, stages, total_steps):
        total_frac = sum(frac for _, frac in stages) or 1.0
        cumulative = 0
        self._boundaries = []  # (step boundary, seq_len), in schedule order
        for seq_len, frac in stages:
            cumulative += int(total_steps * frac / total_frac)
            self._boundaries.append((cumulative, int(seq_len)))

    def get_seq_len(self, step: int) -> int:
        for boundary, seq_len in self._boundaries:
            if step < boundary:
                return seq_len
        return self._boundaries[-1][1]
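
# Usage sketch (illustrative; the stage mix and `train_step` are assumptions,
# not part of this module):
#
#     sched = GrowLengthScheduler(stages=[(512, 0.3), (1024, 0.3), (2048, 0.4)],
#                                 total_steps=10_000)
#     for step in range(10_000):
#         seq_len = sched.get_seq_len(step)  # 512 until step 3000, then 1024, ...
#         train_step(batch[:, :seq_len])
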
def apply_reservoir_freezing(model) -> int:
    """Re-initialize matching projections as fixed random reservoirs.

    Matching weights are set to random {-1, 0, +1} values, rescaled so their
    spectral norm is at most 1, and frozen. Returns the number of frozen
    weight elements.
    """
    frozen = 0
    for _, module in model.named_modules():
        # The paired hasattr checks fingerprint the module type; note that for
        # the gate pairs only the first projection of each pair is frozen.
        targets = []
        if hasattr(module, "a_proj") and hasattr(module, "b_proj"):
            targets.extend(["a_proj", "b_proj"])
        if hasattr(module, "fgate") and hasattr(module, "igate"):
            targets.append("fgate")
        if hasattr(module, "alpha_proj") and hasattr(module, "eta_proj"):
            targets.append("alpha_proj")
        for attr in targets:
            proj = getattr(module, attr, None)
            if proj is None:
                continue
            weight = getattr(proj, "weight", None)
            if weight is None or not isinstance(weight, nn.Parameter):
                continue
            with torch.no_grad():
                # Ternary {-1, 0, +1} init, then shrink so the spectral norm
                # is <= 1 (assumes 2-D Linear-style weights).
                weight.data = torch.randint(-1, 2, weight.shape, dtype=weight.dtype, device=weight.device)
                norm = torch.linalg.matrix_norm(weight.data.float(), ord=2).clamp(min=1.0)
                weight.data.div_(norm)
            weight.requires_grad = False
            frozen += weight.numel()
    return frozen
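
# Usage sketch (assumes `model` contains reservoir-style modules exposing the
# projection names checked above):
#
#     n = apply_reservoir_freezing(model)
#     print(f"froze {n} weights as random ternary reservoirs (spectral norm <= 1)")
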
class SeedReplayMeZO:
    """Zeroth-order (MeZO-style) optimizer that replays noise from seeds.

    Each step's Rademacher perturbation is regenerated from a single integer
    seed rather than stored, so memory overhead beyond the parameters is just
    the optional momentum buffers.
    """

    def __init__(self, model, *, lr=1e-4, eps=1e-3, weight_decay=0.0, momentum=0.9):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)
        self.wd = float(weight_decay)
        self.mom = float(momentum)
        # Collect unique trainable params (id() de-dup guards against tied weights).
        self._params = []
        seen = set()
        for _, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._params.append(param)
                seen.add(id(param))
        self._momentum = [torch.zeros_like(param.data) for param in self._params] if self.mom > 0 else None

    def _perturb_inplace(self, seed: int, scale: float) -> None:
        for i, param in enumerate(self._params):
            # Create the generator on the param's device: a CPU generator cannot
            # drive bernoulli_ on a CUDA tensor. Reseeding per param makes the
            # noise replayable from (seed, index) alone.
            gen = torch.Generator(device=param.device)
            gen.manual_seed((seed + i * 999983) & 0x7FFFFFFFFFFFFFFF)
            z = torch.empty_like(param.data)
            z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)  # Rademacher +/-1
            param.data.add_(z, alpha=scale)

    def _update_inplace(self, seed: int, projected_grad: float) -> None:
        for i, param in enumerate(self._params):
            gen = torch.Generator(device=param.device)
            gen.manual_seed((seed + i * 999983) & 0x7FFFFFFFFFFFFFFF)
            z = torch.empty_like(param.data)
            z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
            # Undo the residual -eps*z left by step(), restoring theta.
            param.data.add_(z, alpha=self.eps)
            if self._momentum is not None:
                buf = self._momentum[i]
                buf.mul_(self.mom).add_(z, alpha=projected_grad)
                param.data.add_(buf, alpha=-self.lr)
            else:
                param.data.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                param.data.mul_(1 - self.lr * self.wd)  # decoupled weight decay

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        # Two-point (SPSA-style) gradient estimate: the same seed replays the
        # same z for both perturbations and for the final update.
        seed = int(torch.randint(0, 2**31, (1,)).item())
        self._perturb_inplace(seed, +self.eps)        # theta + eps*z
        loss_pos = float(loss_fn(batch).item())
        self._perturb_inplace(seed, -2.0 * self.eps)  # theta - eps*z
        loss_neg = float(loss_fn(batch).item())
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update_inplace(seed, projected_grad)    # restores theta, then updates
        return 0.5 * (loss_pos + loss_neg)
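
# Usage sketch (hedged: `model`, `loader`, and the loss interface are
# assumptions made for illustration; step() runs two forward passes and no
# backward, so it also trains models whose reservoir weights have
# requires_grad=False):
#
#     opt = SeedReplayMeZO(model, lr=1e-4, eps=1e-3, momentum=0.9)
#     loss_fn = lambda batch: model(batch["input_ids"], labels=batch["labels"]).loss
#     for batch in loader:
#         avg_loss = opt.step(loss_fn, batch)
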
class ProgressiveUnfreezer:
    """Unfreeze `model.layers` top-down in `n_stages` blocks during training."""

    def __init__(self, model, total_steps, n_stages=4):
        self._layers = model.layers
        self._n = len(self._layers)
        self._total = total_steps
        self._stages = n_stages
        # Ceil division so the final stage always reaches layer 0; floor could
        # leave the bottom layers frozen forever when n % n_stages != 0.
        self._block = max(1, (self._n + n_stages - 1) // n_stages)
        self._current = self._n
        self.update(0)

    def update(self, step: int) -> int:
        stage = min(step * self._stages // max(1, self._total), self._stages - 1)
        target = max(0, self._n - (stage + 1) * self._block)
        if target != self._current:
            self._current = target
            for i, layer in enumerate(self._layers):
                requires_grad = i >= self._current
                for param in layer.parameters():
                    param.requires_grad = requires_grad
        return self._current
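
# Usage sketch (assumes `model.layers` is an ordered list of transformer
# blocks, as __init__ above expects, and a hypothetical `train_step`):
#
#     unfreezer = ProgressiveUnfreezer(model, total_steps=5000, n_stages=4)
#     for step in range(5000):
#         first_trainable = unfreezer.update(step)  # layers >= this index train
#         train_step(batch)
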
class ProgressiveLoopScheduler:
"""Gradually increase Parcae loop depth during training.
With STE+AdamW (not MeZO), multi-loop training is affordable.
Progressive schedule avoids instability from deep loops early on.
FIX: Old schedule (1→2→3 at 20%/60%/100%) was too aggressive —
with 5000 steps, loops=2 at step 1000 while the model is still at
loss=10. Now: loops=1 for 50% (stabilize), loops=2 for 30%, loops=3
for 20%. This gives the model time to learn basics before iterating.
"""
def __init__(self, total_steps: int, max_loops: int = 3):
self._total = total_steps
self._max_loops = max_loops
        self._schedule = [
            (0.50, 1),                  # first 50%: stabilize with a single pass
            (0.80, min(2, max_loops)),  # next 30%: learn to iterate
            (1.01, min(3, max_loops)),  # last 20%: deep refinement
        ]

def get_loops(self, step: int) -> int:
frac = step / max(1, self._total)
for threshold, loops in self._schedule:
if frac < threshold:
return loops
return self._schedule[-1][1]
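
# Usage sketch (assumes the model exposes `loop_controller`, as
# patch_training_loops below expects):
#
#     loop_sched = ProgressiveLoopScheduler(total_steps=5000, max_loops=3)
#     for step in range(5000):
#         model.loop_controller.loop_default = loop_sched.get_loops(step)  # 1, then 2, then 3
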
def patch_training_loops(model, num_loops=1) -> None:
"""Set initial loop config. Use ProgressiveLoopScheduler to change during training."""
if hasattr(model, "loop_controller"):
model.loop_controller.loop_default = num_loops
model.loop_controller.loop_min = 1
model.loop_controller.loop_max = max(num_loops, 3)
# FIX: Evolution modulation is very expensive on CPU (HDC projections,
# Hamming distance queries over 50K entries, episodic retrieval).
# With evo_every_n_layers=4 and 28 layers, that's 7 calls per forward.
# Set to 28 → evolution fires once per full pass (at layer 0 only),
# which is enough for the memory to modulate the input embedding.
if hasattr(model, "evo_every_n_layers"):
model.evo_every_n_layers = max(model.evo_every_n_layers, 28)
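
# Usage sketch (hedged: attribute names follow the hasattr guards above):
#
#     patch_training_loops(model, num_loops=1)
#     assert model.evo_every_n_layers >= 28  # evolution fires once per forward pass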