from __future__ import annotations

import copy
import json
import os
import time

import torch
from torch.utils.data import DataLoader, Dataset

from chimera.quantization import BitLinear

from .common import build_model_from_args
from .datasets import GrowLengthDataset, build_token_buffer
from .hyper import (
    GrowLengthScheduler,
    ProgressiveUnfreezer,
    SeedReplayMeZO,
    apply_reservoir_freezing,
    patch_training_loops,
)


def run_baseline(model, token_buf, args):
    """Naive MeZO baseline: materialized perturbations, packed-weight invalidation."""
    model.train()
    seq = args.seq_len
    n_chunks = token_buf.numel() // (seq + 1)
    chunks = token_buf[: n_chunks * (seq + 1)].view(n_chunks, seq + 1)

    class _Dataset(Dataset):
        def __len__(self):
            return chunks.size(0)

        def __getitem__(self, i):
            c = chunks[i]
            return {"input_ids": c[:-1], "labels": c[1:]}

    loader = DataLoader(
        _Dataset(), batch_size=args.batch_size, shuffle=True,
        num_workers=0, drop_last=True,
    )
    params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    eps = 1e-3

    def loss_fn(batch):
        return model(batch["input_ids"], labels=batch["labels"]).loss

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
    di = iter(loader)
    for _ in range(args.max_steps):
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)

        # Central-difference MeZO step: reseeding the generator regenerates the
        # same perturbation z for each of the three passes below.
        seed = int(torch.randint(0, 2**31, (1,)).item())
        gen = torch.Generator(device="cpu")

        # Pass 1: evaluate loss at theta + eps * z.
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen), alpha=eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            lp = float(loss_fn(batch).item())

        # Pass 2: evaluate loss at theta - eps * z (subtract 2*eps*z).
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen), alpha=-2 * eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            ln = float(loss_fn(batch).item())

        # Projected gradient estimate, then restore and update in one fused add:
        # (theta - eps*z) + (eps - lr*g)*z = theta - lr*g*z.
        g = (lp - ln) / (2 * eps)
        gen.manual_seed(seed)
        for _, p in params:
            z = torch.randn(p.shape, generator=gen)
            p.data.add_(z, alpha=eps - args.lr * g)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()

        total_toks += batch["input_ids"].numel()
        total_loss += 0.5 * (lp + ln)
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt


def run_hyper(model, token_buf, args):
    """Optimized path: seed-replay MeZO, single loop, GrowLength, reservoir freezing."""
    model.train()
    patch_training_loops(model, num_loops=1)
    if args.reservoir:
        apply_reservoir_freezing(model)
    unfreezer = (
        ProgressiveUnfreezer(model, args.max_steps, args.unfreeze_stages)
        if args.progressive_unfreeze else None
    )

    # GrowLength curriculum: (sequence length, fraction of total steps) per stage.
    stages = [
        (max(8, args.seq_len // 4), 0.30),
        (max(16, args.seq_len // 2), 0.30),
        (args.seq_len, 0.40),
    ]
    grow = GrowLengthScheduler(stages, args.max_steps) if args.growlength else None
    cur_seq = stages[0][0] if grow else args.seq_len
    dataset = GrowLengthDataset(token_buf, cur_seq)
    opt = SeedReplayMeZO(
        model, lr=args.lr * 0.01, eps=args.mezo_eps,
        weight_decay=0.1, momentum=0.9,
    )

    def loss_fn(batch):
        if args.bf16:
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return model(batch["input_ids"], labels=batch["labels"]).loss
        return model(batch["input_ids"], labels=batch["labels"]).loss

    def make_loader(batch_size):
        return DataLoader(
            dataset, batch_size=batch_size, shuffle=True,
            num_workers=0, drop_last=True,
        )

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
    # Shorter sequences get proportionally larger batches so the token count
    # per step stays roughly constant across curriculum stages.
    eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    loader = make_loader(eff_batch)
    di = iter(loader)
    for step in range(args.max_steps):
        if grow:
            ns = grow.get_seq_len(step)
            if ns != cur_seq:
                cur_seq = ns
                dataset.set_seq_len(cur_seq)
                eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                loader = make_loader(eff_batch)
                di = iter(loader)
        if unfreezer:
            unfreezer.update(step)
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)
        loss_val = opt.step(loss_fn, batch)
        total_toks += batch["input_ids"].numel()
        total_loss += loss_val
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt
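
# --- Illustration (not used by the benchmark) --------------------------------
# A minimal sketch of the seed-replay idea behind SeedReplayMeZO: rather than
# storing the perturbation tensor z, only the integer RNG seed is kept, and z
# is regenerated bit-exactly on demand. This mirrors what run_baseline does
# inline with gen.manual_seed(seed). The helper name below is ours, purely for
# illustration; it is not part of the chimera API.
def _seed_replay_sketch(param: torch.Tensor) -> bool:
    seed = int(torch.randint(0, 2**31, (1,)).item())
    gen = torch.Generator(device="cpu")
    gen.manual_seed(seed)
    z = torch.randn(param.shape, generator=gen)  # perturbation; can be discarded
    gen.manual_seed(seed)                        # replay: same seed -> same stream
    z_replayed = torch.randn(param.shape, generator=gen)
    return torch.equal(z, z_replayed)            # True: z reconstructed from the seed
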
def benchmark_hyper(args):
    print("=" * 65)
    print("CHIMERA 5.3 HYPER v3 — BENCHMARK (full arch, all features)")
    print("=" * 65)

    model_a, cfg = build_model_from_args(args)
    model_b = copy.deepcopy(model_a)
    c = model_a.count_parameters()
    print(f"Model: {c['total']:,} params, {cfg['num_hidden_layers']} layers")
    print(
        f"Features: looping={model_a.looping_enabled}"
        f" evolution={model_a.evolution is not None}"
        f" span={model_a.span_engine is not None}"
    )

    tok_budget = max(500_000, args.max_steps * args.batch_size * (args.seq_len + 1) * 8)
    token_buf = build_token_buffer(
        args.dataset_name, args.dataset_split, args.text_column,
        tok_budget, args.cache_dir,
    )
    print(f"Tokens: {token_buf.numel():,}\n")

    print("-" * 65)
    print("BASELINE (randn MeZO, invalidate_packed, loop=2, full evo)")
    print("-" * 65)
    bt, bl, bd = run_baseline(model_a, token_buf, args)
    print(f" -> {bt:,.0f} tok/s  loss={bl:.4f}  time={bd:.1f}s\n")

    print("-" * 65)
    print("HYPER (seed-replay MeZO, STE path, loop=1, GrowLength, Reservoir)")
    print("-" * 65)
    ht, hl, hd = run_hyper(model_b, token_buf, args)
    print(f" -> {ht:,.0f} tok/s  loss={hl:.4f}  time={hd:.1f}s\n")

    sp = ht / bt if bt > 0 else float("inf")
    print("=" * 65)
    print(f" Baseline : {bt:>10,.0f} tok/s  loss {bl:.4f}")
    print(f" Hyper    : {ht:>10,.0f} tok/s  loss {hl:.4f}")
    print(f" Speedup  : {sp:>10.1f}x")
    print("=" * 65)

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "benchmark.json"), "w") as f:
        json.dump(
            {"baseline_tps": round(bt), "hyper_tps": round(ht), "speedup": round(sp, 2)},
            f, indent=2,
        )
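
# --- Example invocation (a sketch; the real CLI wiring lives elsewhere) -------
# The attribute names below match what run_baseline/run_hyper/benchmark_hyper
# read off `args`; the concrete values are illustrative assumptions, not
# recommended settings. Because this module uses relative imports, it must be
# run as a module within its package (python -m ...), not as a standalone file.
if __name__ == "__main__":
    from types import SimpleNamespace

    demo_args = SimpleNamespace(
        seq_len=64, batch_size=4, max_steps=50, lr=1e-4, mezo_eps=1e-3,
        bf16=False, reservoir=True, progressive_unfreeze=True,
        unfreeze_stages=4, growlength=True,
        dataset_name="wikitext", dataset_split="train", text_column="text",
        cache_dir=None, output_dir="benchmark_out",
    )
    benchmark_hyper(demo_args)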