# chimera/training/hyper.py — training hyper-schedulers and MeZO utilities.
from __future__ import annotations
import torch
import torch.nn as nn
class GrowLengthScheduler:
    """Curriculum schedule that grows the training sequence length in stages.

    ``stages`` is an iterable of ``(seq_len, fraction)`` pairs; each stage
    occupies its (normalized) fraction of ``total_steps``. Fractions need not
    sum to 1 — they are renormalized against their own total.
    """

    def __init__(self, stages, total_steps):
        # Normalization denominator; ``or 1.0`` guards an all-zero total.
        denom = sum(weight for _, weight in stages) or 1.0
        acc = 0
        self._boundaries = []
        for length, weight in stages:
            acc += int(total_steps * weight / denom)
            self._boundaries.append((acc, int(length)))

    def get_seq_len(self, step: int) -> int:
        """Return the sequence length active at ``step``."""
        # First boundary strictly above ``step`` wins; past the final
        # boundary (int-rounding slack) we stay at the last stage's length.
        return next(
            (length for upper, length in self._boundaries if step < upper),
            self._boundaries[-1][1],
        )
def apply_reservoir_freezing(model) -> int:
    """Re-initialize matched projection weights as fixed random reservoirs.

    For every module exposing a known projection pair, each target ``weight``
    is resampled from {-1, 0, +1}, rescaled so its spectral norm is at most 1
    (echo-state stability: the clamp only ever scales down), and frozen.

    Returns the total number of frozen weight elements.
    """
    frozen = 0
    for _, module in model.named_modules():
        targets = []
        if hasattr(module, "a_proj") and hasattr(module, "b_proj"):
            targets.extend(["a_proj", "b_proj"])
        if hasattr(module, "fgate") and hasattr(module, "igate"):
            # FIX: the guard requires both gates, but previously only
            # "fgate" was frozen; freeze the pair, mirroring a_proj/b_proj.
            targets.extend(["fgate", "igate"])
        if hasattr(module, "alpha_proj") and hasattr(module, "eta_proj"):
            # FIX: same asymmetry — freeze both paired projections.
            targets.extend(["alpha_proj", "eta_proj"])
        for attr in targets:
            proj = getattr(module, attr, None)
            if proj is None:
                continue
            weight = getattr(proj, "weight", None)
            if weight is None or not isinstance(weight, nn.Parameter):
                continue
            with torch.no_grad():
                # Ternary {-1, 0, +1} reservoir init (randint high is exclusive).
                weight.data = torch.randint(
                    -1, 2, weight.shape, dtype=weight.dtype, device=weight.device
                )
                # Spectral norm computed in float32; clamp(min=1) => never scale up.
                norm = torch.linalg.matrix_norm(weight.data.float(), ord=2).clamp(min=1.0)
                weight.data.div_(norm)
            weight.requires_grad = False
            frozen += weight.numel()
    return frozen
class SeedReplayMeZO:
    """Memory-efficient zeroth-order (MeZO-style) optimizer.

    Estimates a directional derivative from two forward passes under a shared
    Rademacher perturbation, then *replays* the same perturbation by
    re-seeding the RNG to apply the update — no per-parameter gradient
    storage is ever allocated.
    """

    def __init__(self, model, *, lr=1e-4, eps=1e-3, weight_decay=0.0, momentum=0.9):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)  # finite-difference perturbation scale
        self.wd = float(weight_decay)
        self.mom = float(momentum)
        self._params = []
        seen = set()
        # Collect trainable parameters, deduplicating tied/shared tensors by id.
        for _, param in model.named_parameters():
            if param.requires_grad and id(param) not in seen:
                self._params.append(param)
                seen.add(id(param))
        # One momentum buffer per parameter, only when momentum is enabled.
        self._momentum = [torch.zeros_like(param.data) for param in self._params] if self.mom > 0 else None

    def _perturb_inplace(self, seed: int, scale: float) -> None:
        # Adds scale * z_i to each parameter, where z_i is a +/-1 vector
        # deterministically derived from (seed, i) — so the identical z_i can
        # be regenerated later in _update_inplace.
        # NOTE(review): the generator is CPU-only; this assumes parameters
        # live on CPU (bernoulli_ rejects a CPU generator for CUDA tensors) —
        # confirm against the training setup.
        gen = torch.Generator(device="cpu")
        for i, param in enumerate(self._params):
            # Per-parameter seed offset keeps perturbations distinct; mask to
            # stay within the 63-bit non-negative seed range.
            gen.manual_seed((seed + i * 999983) & 0x7FFFFFFFFFFFFFFF)
            z = torch.empty_like(param.data)
            z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)  # {0,1} -> {-1,+1}
            param.data.add_(z, alpha=scale)

    def _update_inplace(self, seed: int, projected_grad: float) -> None:
        # Regenerates the exact same z_i as _perturb_inplace (identical
        # seeding scheme), then applies the estimated gradient step.
        gen = torch.Generator(device="cpu")
        for i, param in enumerate(self._params):
            gen.manual_seed((seed + i * 999983) & 0x7FFFFFFFFFFFFFFF)
            z = torch.empty_like(param.data)
            z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)
            # Undo the net -eps*z offset left by step()'s two perturbations
            # (+eps followed by -2*eps), restoring the original weights.
            param.data.add_(z, alpha=self.eps)
            if self._momentum is not None:
                buf = self._momentum[i]
                # Momentum on the zeroth-order gradient estimate g * z.
                buf.mul_(self.mom).add_(z, alpha=projected_grad)
                param.data.add_(buf, alpha=-self.lr)
            else:
                param.data.add_(z, alpha=-self.lr * projected_grad)
            if self.wd > 0:
                # Decoupled weight decay, applied after the gradient step.
                param.data.mul_(1 - self.lr * self.wd)

    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """Run one MeZO step; returns the mean of the two probe losses."""
        seed = int(torch.randint(0, 2**31, (1,)).item())
        self._perturb_inplace(seed, +self.eps)        # theta + eps*z
        loss_pos = float(loss_fn(batch).item())
        self._perturb_inplace(seed, -2.0 * self.eps)  # theta - eps*z
        loss_neg = float(loss_fn(batch).item())
        # Central-difference estimate of the directional derivative along z.
        projected_grad = (loss_pos - loss_neg) / (2.0 * self.eps)
        self._update_inplace(seed, projected_grad)    # restore weights + update
        return 0.5 * (loss_pos + loss_neg)
class ProgressiveUnfreezer:
    """Unfreeze ``model.layers`` top-down in stages over training.

    Starts with only the top block of layers trainable and lowers the freeze
    boundary as training progresses until every layer trains.
    """

    def __init__(self, model, total_steps, n_stages=4):
        self._layers = model.layers
        self._n = len(self._layers)
        self._total = total_steps
        self._stages = n_stages
        # FIX: use ceil division for the per-stage block size. With floor
        # (n // n_stages), a non-divisible layer count — e.g. 10 layers over
        # 4 stages (block=2) — left the bottom 10 - 4*2 = 2 layers frozen
        # forever; ceil guarantees the final stage reaches layer 0.
        self._block = max(1, -(-self._n // n_stages))
        # Freeze boundary: layers with index >= _current are trainable.
        self._current = self._n
        self.update(0)

    def update(self, step: int) -> int:
        """Apply the freeze boundary for ``step``; returns the boundary index."""
        stage = min(step * self._stages // max(1, self._total), self._stages - 1)
        target = max(0, self._n - (stage + 1) * self._block)
        # Only touch parameters when the boundary actually moves.
        if target != self._current:
            self._current = target
            for i, layer in enumerate(self._layers):
                requires_grad = i >= self._current
                for param in layer.parameters():
                    param.requires_grad = requires_grad
        return self._current
class ProgressiveLoopScheduler:
    """Gradually increase Parcae loop depth during training.

    With STE+AdamW (not MeZO), multi-loop training is affordable, and a
    progressive schedule avoids instability from deep loops early on:
    loops=1 for the first 50% of steps (stabilize), loops=2 for the next
    30% (learn to iterate), loops=3 for the final 20% (deep refinement).
    """

    def __init__(self, total_steps: int, max_loops: int = 3):
        self._total = total_steps
        self._max_loops = max_loops
        # FIX: clamp every stage to max_loops. Previously only the final
        # stage was clamped, so max_loops=1 still yielded loops=2 during
        # the middle 30% of training.
        self._schedule = [
            (0.50, min(1, max_loops)),  # first 50%: single pass
            (0.80, min(2, max_loops)),  # next 30%: learn to iterate
            (1.01, min(3, max_loops)),  # last 20%: deep refinement
        ]

    def get_loops(self, step: int) -> int:
        """Return the loop count for ``step`` (saturates at the final stage)."""
        frac = step / max(1, self._total)
        for threshold, loops in self._schedule:
            if frac < threshold:
                return loops
        # Past total_steps: stay at the final stage's depth.
        return self._schedule[-1][1]
def patch_training_loops(model, num_loops=1) -> None:
    """Set the initial loop configuration on ``model``.

    Use ProgressiveLoopScheduler to change the loop count during training.
    Both patches are best-effort: attributes are only set when present.
    """
    if hasattr(model, "loop_controller"):
        controller = model.loop_controller
        controller.loop_default = num_loops
        controller.loop_min = 1
        controller.loop_max = max(num_loops, 3)
    # FIX: Evolution modulation is very expensive on CPU (HDC projections,
    # Hamming distance queries over 50K entries, episodic retrieval).
    # With evo_every_n_layers=4 and 28 layers, that's 7 calls per forward.
    # Raising it to 28 makes evolution fire once per full pass (at layer 0
    # only), which is enough for the memory to modulate the input embedding.
    if hasattr(model, "evo_every_n_layers"):
        model.evo_every_n_layers = max(model.evo_every_n_layers, 28)