from __future__ import annotations

import copy
import json
import os
import time

import torch
from torch.utils.data import DataLoader, Dataset

from chimera.quantization import BitLinear

from .common import build_model_from_args
from .datasets import GrowLengthDataset, build_token_buffer
from .hyper import (
    GrowLengthScheduler,
    ProgressiveUnfreezer,
    SeedReplayMeZO,
    apply_reservoir_freezing,
    patch_training_loops,
)


def run_baseline(model, token_buf, args):
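    """Vanilla MeZO baseline: two forward passes per step at theta +/- eps*z.

    The projected gradient is g = (L+ - L-) / (2 * eps) and the update is
    theta <- theta - lr * g * z. A single CPU seed is replayed so z can be
    regenerated instead of stored, and BitLinear packed weights are
    invalidated after every in-place parameter change.
    """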
    model.train()
    seq = args.seq_len
    # Pack the flat token buffer into (seq + 1)-long rows: the first seq
    # tokens are inputs, the last seq tokens are the shifted labels.
    n_chunks = token_buf.numel() // (seq + 1)
    chunks = token_buf[: n_chunks * (seq + 1)].view(n_chunks, seq + 1)

    class _Dataset(Dataset):
        def __len__(self):
            return chunks.size(0)

        def __getitem__(self, i):
            c = chunks[i]
            return {"input_ids": c[:-1], "labels": c[1:]}

    loader = DataLoader(_Dataset(), batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    eps = 1e-3  # MeZO perturbation scale

    def loss_fn(batch):
        return model(batch["input_ids"], labels=batch["labels"]).loss

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
    di = iter(loader)
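    # Each step replays one random seed three times so the probe direction z
    # never has to be materialized across phases:
    #   1) theta + eps*z           -> forward -> L+
    #   2) theta - eps*z           -> forward -> L-  (add -2*eps*z to phase 1)
    #   3) theta + (eps - lr*g)*z  (restore and SGD update in one add)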
    for _ in range(args.max_steps):
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)
        seed = int(torch.randint(0, 2**31, (1,)).item())
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        # Phase 1: perturb to theta + eps*z. Sampling in the parameter's own
        # dtype keeps the in-place add_ valid for non-fp32 parameters.
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen, dtype=p.dtype), alpha=eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            lp = float(loss_fn(batch).item())
        # Phase 2: replay the seed and move to theta - eps*z.
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen, dtype=p.dtype), alpha=-2 * eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            ln = float(loss_fn(batch).item())
        # Central-difference estimate of the loss slope along z.
        g = (lp - ln) / (2 * eps)
        # Phase 3: replay once more, restore theta, and apply the SGD update.
        gen.manual_seed(seed)
        for _, p in params:
            z = torch.randn(p.shape, generator=gen, dtype=p.dtype)
            p.data.add_(z, alpha=eps - args.lr * g)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        total_toks += batch["input_ids"].numel()
        total_loss += 0.5 * (lp + ln)  # report the midpoint of the two probes
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt


def run_hyper(model, token_buf, args):
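    """HYPER variant: seed-replay MeZO optimizer plus scheduling tricks.

    Runs the model with a single training loop, optional reservoir freezing
    and progressive unfreezing, and a GrowLength curriculum that starts at a
    quarter of the target sequence length and grows it over training.
    """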
    model.train()
    patch_training_loops(model, num_loops=1)  # single pass instead of looped depth
    if args.reservoir:
        apply_reservoir_freezing(model)
    unfreezer = ProgressiveUnfreezer(model, args.max_steps, args.unfreeze_stages) if args.progressive_unfreeze else None
    # GrowLength curriculum stages: (sequence length, fraction of total steps).
    stages = [
        (max(8, args.seq_len // 4), 0.30),
        (max(16, args.seq_len // 2), 0.30),
        (args.seq_len, 0.40),
    ]
    grow = GrowLengthScheduler(stages, args.max_steps) if args.growlength else None
    cur_seq = stages[0][0] if grow else args.seq_len
    dataset = GrowLengthDataset(token_buf, cur_seq)
    opt = SeedReplayMeZO(model, lr=args.lr * 0.01, eps=args.mezo_eps, weight_decay=0.1, momentum=0.9)

    def loss_fn(batch):
        if args.bf16:
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return model(batch["input_ids"], labels=batch["labels"]).loss
        return model(batch["input_ids"], labels=batch["labels"]).loss
|
| total_toks, total_loss = 0, 0.0 |
| t0 = time.time() |
| eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq)) |
| loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True) |
| di = iter(loader) |
| for step in range(args.max_steps): |
| if grow: |
| ns = grow.get_seq_len(step) |
| if ns != cur_seq: |
| cur_seq = ns |
| dataset.set_seq_len(cur_seq) |
| eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq)) |
| loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True) |
| di = iter(loader) |
| if unfreezer: |
| unfreezer.update(step) |
| try: |
| batch = next(di) |
| except StopIteration: |
| di = iter(loader) |
| batch = next(di) |
| loss_val = opt.step(loss_fn, batch) |
| total_toks += batch["input_ids"].numel() |
| total_loss += loss_val |
| dt = time.time() - t0 |
| return total_toks / dt, total_loss / args.max_steps, dt |


def benchmark_hyper(args):
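    """Run baseline and HYPER training on identical model copies, report
    tokens/sec, mean loss, and wall time, then write benchmark.json."""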
    print("=" * 65)
    print("CHIMERA 5.3 HYPER v3 — BENCHMARK (full arch, all features)")
    print("=" * 65)
    model_a, cfg = build_model_from_args(args)
    model_b = copy.deepcopy(model_a)  # identical weights for a fair comparison
    c = model_a.count_parameters()
    print(f"Model: {c['total']:,} params, {cfg['num_hidden_layers']} layers")
    print(f"Features: looping={model_a.looping_enabled} evolution={model_a.evolution is not None} span={model_a.span_engine is not None}")

    # Over-provision the token buffer (8x the step budget) so neither run
    # exhausts it mid-benchmark.
    tok_budget = max(500_000, args.max_steps * args.batch_size * (args.seq_len + 1) * 8)
    token_buf = build_token_buffer(args.dataset_name, args.dataset_split, args.text_column, tok_budget, args.cache_dir)
    print(f"Tokens: {token_buf.numel():,}\n")
|
|
| print("-" * 65) |
| print("BASELINE (randn MeZO, invalidate_packed, loop=2, full evo)") |
| print("-" * 65) |
| bt, bl, bd = run_baseline(model_a, token_buf, args) |
| print(f" -> {bt:,.0f} tok/s loss={bl:.4f} time={bd:.1f}s\n") |
|
|
| print("-" * 65) |
| print("HYPER (seed-replay MeZO, STE path, loop=1, GrowLength, Reservoir)") |
| print("-" * 65) |
| ht, hl, hd = run_hyper(model_b, token_buf, args) |
| print(f" -> {ht:,.0f} tok/s loss={hl:.4f} time={hd:.1f}s\n") |
|
|
| sp = ht / bt if bt > 0 else float("inf") |
| print("=" * 65) |
| print(f" Baseline : {bt:>10,.0f} tok/s loss {bl:.4f}") |
| print(f" Hyper : {ht:>10,.0f} tok/s loss {hl:.4f}") |
| print(f" Speedup : {sp:>10.1f}x") |
| print("=" * 65) |
|
|
| os.makedirs(args.output_dir, exist_ok=True) |
| with open(os.path.join(args.output_dir, "benchmark.json"), "w") as f: |
| json.dump({"baseline_tps": round(bt), "hyper_tps": round(ht), "speedup": round(sp, 2)}, f, indent=2) |
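

# Minimal entry-point sketch (hypothetical: the repo's real CLI may live
# elsewhere, and build_model_from_args may read fields beyond those listed).
# The flags and defaults below only mirror the attributes this module itself
# accesses, so treat them as placeholder values.
if __name__ == "__main__":
    import argparse

    p = argparse.ArgumentParser(description="Chimera HYPER benchmark (sketch)")
    p.add_argument("--seq_len", type=int, default=128)
    p.add_argument("--batch_size", type=int, default=4)
    p.add_argument("--max_steps", type=int, default=50)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--mezo_eps", type=float, default=1e-3)
    p.add_argument("--unfreeze_stages", type=int, default=4)
    p.add_argument("--reservoir", action="store_true")
    p.add_argument("--progressive_unfreeze", action="store_true")
    p.add_argument("--growlength", action="store_true")
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--dataset_name", default="wikitext")
    p.add_argument("--dataset_split", default="train")
    p.add_argument("--text_column", default="text")
    p.add_argument("--cache_dir", default=None)
    p.add_argument("--output_dir", default="./benchmark_out")
    benchmark_hyper(p.parse_args())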