# chimera/hyper.py
"""
Chimera 5.3 β€” HYPER Paradigm Engine for 10,000+ tok/s CPU Training
===================================================================
Seven orthogonal paradigms that stack multiplicatively:
P1 GrowLength Curriculum β€” Start seq=16, grow to target. Short seqs =
huge batch = way more tok/s early on.
(arxiv:2310.00576)
P2 Reservoir Freezing (GRC) β€” Freeze ~50 % of recurrent gate matrices as
random ternary. No grad for those params β‡’
2Γ— fewer FLOPs in recurrent layers.
(arxiv:2512.23145)
P3 Sparse MeZO β€” Perturb only top-K % most-sensitive params
(by magnitude). ZO signal quality ∝
β€–maskβŠ™βˆ‡fβ€–Β²/β€–βˆ‡fβ€–Β²; masking raises it.
(arxiv:2406.02913)
P4 Blockwise Pipeline β€” Pin layer-groups to core-groups; overlap
block N on batch t with block N-1 on t+1.
P5 Fused Ternary Cache β€” Pre-materialise dense ternary weights once
per step; reuse for both MeZO forwards.
P6 Aggressive Token Packing β€” Zero padding waste; pack documents
back-to-back with EOS separators.
P7 Progressive Layer Unfreeze β€” Train only top ~25 % of layers first; un-
freeze downward as training proceeds.
Expected combined multiplier (tiny-35 M on 8-core CPU):
P1 (4-8Γ—) Γ— P2 (1.5-2Γ—) Γ— P3 (3-5Γ—) Γ— P5 (1.3Γ—) Γ— P7 (1.5-2Γ—)
β‰ˆ 35-260Γ— β‡’ 50-200 tok/s baseline β†’ **1 750-52 000 tok/s**
"""
from __future__ import annotations
import math
import time
from typing import Dict, List, Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from .quantization import BitLinear
# ═══════════════════════════════════════════════════════════════════════════
# P1 β€” GrowLength Curriculum
# ═══════════════════════════════════════════════════════════════════════════
class GrowLengthDataset(Dataset):
    """Flat token buffer re-chunked on-the-fly when ``set_seq_len`` is called.

    Each sample is a contiguous slice of ``seq_len + 1`` tokens, split into a
    shifted (``input_ids``, ``labels``) pair for next-token prediction.
    Because chunks are contiguous slices, ``set_seq_len`` is O(1).
    """

    def __init__(self, all_ids: torch.Tensor, seq_len: int = 16):
        self.all_ids = all_ids
        self._seq_len = 0
        self._n = 0
        self.set_seq_len(seq_len)

    # ── public API ───────────────────────────────────────────────────────
    def set_seq_len(self, seq_len: int) -> None:
        """Re-chunk the buffer into ``seq_len``-token samples.

        Raises:
            ValueError: if ``seq_len < 1`` — a sample needs at least one
                input token plus one label token.
        """
        seq_len = int(seq_len)
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        self._seq_len = seq_len
        # Each sample consumes seq_len + 1 tokens (inputs + shifted labels).
        self._n = self.all_ids.numel() // (self._seq_len + 1)

    @property
    def seq_len(self) -> int:
        return self._seq_len

    def __len__(self) -> int:
        return self._n

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        start = idx * (self._seq_len + 1)
        chunk = self.all_ids[start: start + self._seq_len + 1]
        # labels are inputs shifted left by one token
        return {"input_ids": chunk[:-1], "labels": chunk[1:]}
class GrowLengthScheduler:
    """Maps a global step to the current target sequence length.

    ``stages`` is a list of ``(seq_len, fraction_of_total_steps)`` tuples.
    Fractions are normalised internally so they need not sum to 1.

    Raises:
        ValueError: if ``stages`` is empty (the old code deferred this to an
            IndexError inside ``get_seq_len``).
    """

    def __init__(self, stages: List[Tuple[int, float]], total_steps: int):
        if not stages:
            raise ValueError("stages must contain at least one (seq_len, frac) entry")
        total_frac = sum(f for _, f in stages) or 1.0
        self._boundaries: List[Tuple[int, int]] = []
        # Accumulate fractions first and truncate once per boundary, instead
        # of truncating each stage's step count: per-stage int() rounding
        # compounds and drifts all later boundaries earlier than intended.
        cum_frac = 0.0
        for seq_len, frac in stages:
            cum_frac += frac
            boundary = int(total_steps * cum_frac / total_frac)
            self._boundaries.append((boundary, int(seq_len)))

    def get_seq_len(self, step: int) -> int:
        """Return the target sequence length for global ``step``."""
        for boundary, seq_len in self._boundaries:
            if step < boundary:
                return seq_len
        # Past the final boundary (incl. rounding leftovers): stay at the
        # last stage's length.
        return self._boundaries[-1][1]
# ═══════════════════════════════════════════════════════════════════════════
# P2 β€” Reservoir Freezing (GRC-inspired, arxiv:2512.23145)
# ═══════════════════════════════════════════════════════════════════════════
def _freeze_as_ternary_reservoir(proj) -> int:
    """Re-initialise ``proj.weight`` as a frozen random ternary reservoir.

    The weight is replaced by i.i.d. values from {-1, 0, +1}, scaled so its
    spectral norm is at most 1 (stable reservoir), and excluded from gradient
    computation.  Returns the number of scalars frozen (0 when ``proj`` is
    ``None`` or carries no ``weight`` Parameter).
    """
    if proj is None:
        return 0
    w = getattr(proj, "weight", None)
    if w is None or not isinstance(w, nn.Parameter):
        return 0
    with torch.no_grad():
        ternary = torch.randint(-1, 2, w.shape, device=w.device).float()
        # clamp(min=1) only ever scales *down*: already-contractive matrices
        # are left untouched.
        norm = torch.linalg.matrix_norm(ternary, ord=2).clamp(min=1.0)
        # Cast back to the parameter's own dtype so mixed-precision models
        # are not silently upcast to float32 (the previous mLSTM / TitansMAC
        # branches did exactly that via a stray ``.float()``).
        w.data = (ternary / norm).to(dtype=w.dtype)
    w.requires_grad = False
    return w.numel()


def apply_reservoir_freezing(model: nn.Module,
                             freeze_ratio: float = 0.5) -> int:
    """Freeze gate / forget projections in recurrent layers as random ternary
    reservoirs.  Returns the number of frozen scalar parameters.

    Targets (detected structurally by attribute pairs):
        • GatedDeltaNet → a_proj, b_proj (alpha / beta gates)
        • mLSTM         → fgate         (forget gate)
        • TitansMAC     → alpha_proj    (forgetting gate)

    The frozen weights are re-initialised to unit-spectral-radius ternary
    matrices so every layer starts with a stable reservoir.

    NOTE(review): ``freeze_ratio`` is currently unused (kept for interface
    stability) — every targeted projection is frozen in full.
    """
    frozen = 0
    for _name, module in model.named_modules():
        # ── GatedDeltaNet gates ──────────────────────────────────────
        if hasattr(module, "a_proj") and hasattr(module, "b_proj"):
            frozen += _freeze_as_ternary_reservoir(getattr(module, "a_proj", None))
            frozen += _freeze_as_ternary_reservoir(getattr(module, "b_proj", None))
        # ── mLSTM forget gate ────────────────────────────────────────
        if hasattr(module, "fgate") and hasattr(module, "igate"):
            frozen += _freeze_as_ternary_reservoir(module.fgate)
        # ── TitansMAC forgetting ─────────────────────────────────────
        if hasattr(module, "alpha_proj") and hasattr(module, "eta_proj"):
            frozen += _freeze_as_ternary_reservoir(module.alpha_proj)
    return frozen
# ═══════════════════════════════════════════════════════════════════════════
# P3 β€” Sparse MeZO (arxiv:2406.02913)
# ═══════════════════════════════════════════════════════════════════════════
class SparseMeZOOptimizer:
    """Zeroth-order optimiser that perturbs only the top-K % most-sensitive
    parameters (ranked by weight magnitude as a cheap proxy for gradient
    magnitude).
    Combined with **Paradigm 5** (fused ternary cache): before each dual-
    forward the caller should invoke ``precompute_ternary_cache(model)``
    once so that both forward passes reuse the same dense-weight buffers.
    """

    def __init__(self, model: nn.Module, *,
                 lr: float = 1e-4,
                 eps: float = 1e-3,
                 sparsity: float = 0.01,
                 weight_decay: float = 0.0,
                 momentum: float = 0.0,
                 mask_refresh_interval: int = 50):
        self.model = model
        self.lr = float(lr)
        self.eps = float(eps)            # perturbation radius of the dual forward
        self.sparsity = float(sparsity)  # fraction of params kept in the active mask
        self.wd = float(weight_decay)    # decoupled decay, applied to all trainables
        self.momentum_coeff = float(momentum)
        self.mask_refresh = int(mask_refresh_interval)  # steps between mask rebuilds
        # Deduplicated trainable params (tied/shared weights appear once;
        # keyed by id() throughout since nn.Parameter is unhashable by value).
        self._params: List[Tuple[str, nn.Parameter]] = []
        seen: set = set()
        for name, p in model.named_parameters():
            if p.requires_grad and id(p) not in seen:
                self._params.append((name, p))
                seen.add(id(p))
        self._total = sum(p.numel() for _, p in self._params)
        self._k = max(1, int(self._total * self.sparsity))  # global top-k size
        self._masks: Dict[int, torch.Tensor] = {}      # id(param) -> bool mask
        self._momentum: Dict[int, torch.Tensor] = {}   # id(param) -> velocity buffer
        if self.momentum_coeff > 0:
            for _, p in self._params:
                self._momentum[id(p)] = torch.zeros_like(p.data)
        self._step = 0
        self._refresh_masks()

    # ── mask computation ─────────────────────────────────────────────
    def _refresh_masks(self) -> None:
        # Global (not per-tensor) top-k over |w| across every trainable
        # parameter; slices maps each tensor back into the concatenated view.
        slices, offset = [], 0
        mags = []
        for _, p in self._params:
            flat = p.data.abs().flatten()
            mags.append(flat)
            slices.append((offset, offset + flat.numel()))
            offset += flat.numel()
        all_mag = torch.cat(mags)
        if self._k < all_mag.numel():
            # Threshold = smallest magnitude inside the top-k set; ties at
            # the threshold may admit slightly more than k entries.
            thr = torch.topk(all_mag, self._k, sorted=False).values.min()
        else:
            thr = torch.tensor(0.0)  # k covers everything -> all-ones masks
        for i, (_, p) in enumerate(self._params):
            s, e = slices[i]
            self._masks[id(p)] = (all_mag[s:e] >= thr).view(p.shape)

    # ── perturbation helpers ─────────────────────────────────────────
    def _direction(self, p: torch.Tensor, seed: int,
                   mask: torch.Tensor) -> torch.Tensor:
        # Masked Rademacher (+/-1) direction, regenerated *deterministically*
        # from `seed` so step() can re-create the identical z without ever
        # storing it — the core MeZO memory trick.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed & 0x7FFF_FFFF_FFFF_FFFF)  # keep seed in int64 range
        z = torch.empty(p.shape, dtype=p.dtype, device="cpu")
        z.bernoulli_(0.5, generator=gen).mul_(2).sub_(1)  # {0,1} -> {-1,+1}
        return z * mask.to(z.dtype)

    def _perturb(self, seed: int, scale: float) -> None:
        # In-place masked perturbation theta += scale * z; each tensor gets a
        # distinct seed offset (i * large prime) so directions are independent.
        for i, (_, p) in enumerate(self._params):
            z = self._direction(p.data, seed + i * 1_000_003,
                                self._masks.get(id(p),
                                                torch.ones_like(p.data)))
            p.data.add_(z, alpha=scale)
        # Weights changed in place -> cached ternary/packed buffers are stale.
        _invalidate_bitlinear(self.model)

    # ── step ─────────────────────────────────────────────────────────
    @torch.no_grad()
    def step(self, loss_fn, batch) -> float:
        """One MeZO update: dual forward, SPSA gradient estimate, SGD step.

        ``loss_fn(batch)`` must return a scalar loss tensor.  Returns the
        mean of the two perturbed losses (a cheap loss estimate for logging).
        """
        self._step += 1
        if self._step % self.mask_refresh == 0:
            self._refresh_masks()
        seed = int(torch.randint(0, 2 ** 31, (1,)).item())
        # Three-phase symmetric perturbation: theta+eps*z, theta-eps*z, theta.
        # The order and scales must stay exactly as written, or the restore
        # no longer cancels and the weights drift.
        self._perturb(seed, +self.eps)
        loss_pos = float(loss_fn(batch).item())
        self._perturb(seed, -2.0 * self.eps)
        loss_neg = float(loss_fn(batch).item())
        self._perturb(seed, +self.eps)  # restore original weights
        # SPSA projected-gradient estimate along z.
        proj = (loss_pos - loss_neg) / (2.0 * self.eps)
        for i, (_, p) in enumerate(self._params):
            mask = self._masks.get(id(p), torch.ones_like(p.data))
            # Identical seed arithmetic as _perturb -> identical z.
            z = self._direction(p.data, seed + i * 1_000_003, mask)
            if self.momentum_coeff > 0:
                buf = self._momentum[id(p)]
                buf.mul_(self.momentum_coeff).add_(z, alpha=proj)
                p.data.add_(buf, alpha=-self.lr)
            else:
                p.data.add_(z, alpha=-self.lr * proj)
            if self.wd > 0:
                # Decoupled weight decay (AdamW-style), applied even to
                # params outside the current mask.
                p.data.mul_(1 - self.lr * self.wd)
        _invalidate_bitlinear(self.model)
        return 0.5 * (loss_pos + loss_neg)
# ═══════════════════════════════════════════════════════════════════════════
# P5 β€” Fused Ternary Cache
# ═══════════════════════════════════════════════════════════════════════════
def precompute_ternary_cache(model: nn.Module) -> None:
    """Warm every BitLinear cache (packed + dense fp32) ahead of time so the
    next forward pass allocates nothing.  Call once per MeZO dual-forward."""
    bitlinears = (m for m in model.modules() if isinstance(m, BitLinear))
    for layer in bitlinears:
        layer._ensure_packed()
        layer._ensure_dense()
def _invalidate_bitlinear(model: nn.Module) -> None:
    """Drop stale BitLinear caches after an in-place weight mutation."""
    for module in model.modules():
        if not isinstance(module, BitLinear):
            continue
        module.invalidate_packed()
# ═══════════════════════════════════════════════════════════════════════════
# P6 β€” Aggressive Token Packing
# ═══════════════════════════════════════════════════════════════════════════
def pack_documents(raw_ids: torch.Tensor, eos_id: int,
                   max_tokens: int) -> torch.Tensor:
    """Return a contiguous 1-D ``LongTensor`` of at most ``max_tokens`` tokens
    where individual documents are separated by ``eos_id`` and there is
    **zero** padding.

    Already-tokenised documents should be concatenated (with EOS separators
    inserted) in ``raw_ids``; this function simply truncates to
    ``max_tokens``.  ``eos_id`` is accepted for API symmetry but not used —
    separators are expected to be present in ``raw_ids`` already.

    A non-positive ``max_tokens`` yields an empty tensor; previously a
    negative value became a negative slice bound and silently dropped
    trailing tokens instead.
    """
    n = max(0, min(raw_ids.numel(), int(max_tokens)))
    return raw_ids[:n].contiguous()
# ═══════════════════════════════════════════════════════════════════════════
# P7 β€” Progressive Layer Unfreezing
# ═══════════════════════════════════════════════════════════════════════════
class ProgressiveUnfreezer:
    """Freeze all but the top *k* layers initially; unfreeze downward as
    training advances.

    ``n_stages`` = number of unfreeze events spread evenly across
    ``total_steps``.  At each event one more block of layers becomes
    trainable (starting from the output end).

    Fixes over the previous version:
      • block size uses *ceiling* division — with floor division the bottom
        ``n_layers % n_stages`` layers were never unfrozen at all;
      • ``n_stages`` is clamped to >= 1 instead of raising ZeroDivisionError.
    """

    def __init__(self, model: nn.Module, total_steps: int,
                 n_stages: int = 4):
        self._layers = model.layers  # nn.ModuleList
        self._n = len(self._layers)
        self._total = int(total_steps)
        self._stages = max(1, int(n_stages))
        # Ceiling division: ceil(n / stages) without importing math here.
        self._block = max(1, -(-self._n // self._stages))
        self._current_from = self._n  # everything frozen initially
        # Immediately unfreeze the first block (top layers)
        self.update(0)

    def update(self, step: int) -> int:
        """Call every step.  Returns the index of the first trainable layer."""
        stage = min(step * self._stages // max(1, self._total),
                    self._stages - 1)
        target = max(0, self._n - (stage + 1) * self._block)
        if target != self._current_from:
            self._current_from = target
            # Re-apply requires_grad over *all* layers so both the newly
            # unfrozen block and the still-frozen bottom are consistent.
            for i, layer in enumerate(self._layers):
                req = i >= self._current_from
                for p in layer.parameters():
                    p.requires_grad = req
        return self._current_from
# ═══════════════════════════════════════════════════════════════════════════
# Cosine LR helper (shared)
# ═══════════════════════════════════════════════════════════════════════════
def cosine_lr(step: int, warmup: int, total: int,
              max_lr: float, min_lr: float) -> float:
    """Cosine-annealed learning rate with a linear warmup phase.

    Ramps linearly from ``max_lr / warmup`` up to ``max_lr`` over the first
    ``warmup`` steps, then follows a half-cosine from ``max_lr`` down to
    ``min_lr`` at ``total``; any step beyond ``total`` stays at ``min_lr``.
    """
    if warmup > 0 and step < warmup:
        return max_lr * (step + 1) / warmup
    if step >= total:
        return min_lr
    progress = (step - warmup) / max(1, total - warmup)
    cosine_term = 1.0 + math.cos(math.pi * progress)
    return min_lr + 0.5 * (max_lr - min_lr) * cosine_term
# ═══════════════════════════════════════════════════════════════════════════
# Public surface
# ═══════════════════════════════════════════════════════════════════════════
# Public API: the names exported by ``from chimera.hyper import *``.
__all__ = [
    "GrowLengthDataset",
    "GrowLengthScheduler",
    "apply_reservoir_freezing",
    "SparseMeZOOptimizer",
    "precompute_ternary_cache",
    "pack_documents",
    "ProgressiveUnfreezer",
    "cosine_lr",
]