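# Benchmark harness: runs the same model and token buffer through two training
# paths (a baseline randn MeZO loop vs. the HYPER seed-replay path with
# GrowLength, reservoir freezing, and progressive unfreezing) and reports
# tokens/sec, mean loss, and wall-clock speedup.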
from __future__ import annotations
import copy
import json
import os
import time
import torch
from torch.utils.data import DataLoader, Dataset
from chimera.quantization import BitLinear
from .common import build_model_from_args
from .datasets import GrowLengthDataset, build_token_buffer
from .hyper import (
    GrowLengthScheduler,
    ProgressiveUnfreezer,
    SeedReplayMeZO,
    apply_reservoir_freezing,
    patch_training_loops,
)


def run_baseline(model, token_buf, args):
    model.train()
    seq = args.seq_len
    n = token_buf.numel() // (seq + 1)
    # Slice the flat token buffer into rows of seq+1 tokens: the first seq
    # tokens are inputs, the last seq are the next-token labels.
    chunks = token_buf[: n * (seq + 1)].view(n, seq + 1)

    class _Dataset(Dataset):
        def __len__(self):
            return chunks.size(0)

        def __getitem__(self, i):
            c = chunks[i]
            return {"input_ids": c[:-1], "labels": c[1:]}

    loader = DataLoader(_Dataset(), batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    params = [(name, p) for name, p in model.named_parameters() if p.requires_grad]
    eps = 1e-3

    def loss_fn(batch):
        return model(batch["input_ids"], labels=batch["labels"]).loss

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
    di = iter(loader)
    for _ in range(args.max_steps):
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)
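
        # Two-sided MeZO probe: draw a fresh seed, perturb every trainable
        # tensor by +eps*z and then -2*eps*z, replaying the same z from the
        # seed instead of storing it. BitLinear packed buffers are invalidated
        # after each in-place edit so the next forward repacks the weights.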
        seed = int(torch.randint(0, 2**31, (1,)).item())
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen), alpha=eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            lp = float(loss_fn(batch).item())
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen), alpha=-2 * eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            ln = float(loss_fn(batch).item())
        g = (lp - ln) / (2 * eps)
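        # Replay z a third time: alpha = eps - lr*g both undoes the remaining
        # -eps*z offset and applies the MeZO update -lr*g*z in a single pass.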
        gen.manual_seed(seed)
        for _, p in params:
            z = torch.randn(p.shape, generator=gen)
            p.data.add_(z, alpha=eps - args.lr * g)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        total_toks += batch["input_ids"].numel()
        total_loss += 0.5 * (lp + ln)
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt


def run_hyper(model, token_buf, args):
    model.train()
    patch_training_loops(model, num_loops=1)
    if args.reservoir:
        apply_reservoir_freezing(model)
    unfreezer = ProgressiveUnfreezer(model, args.max_steps, args.unfreeze_stages) if args.progressive_unfreeze else None
    stages = [
        (max(8, args.seq_len // 4), 0.30),
        (max(16, args.seq_len // 2), 0.30),
        (args.seq_len, 0.40),
    ]
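    # Curriculum stages as (sequence length, share of max_steps); the shares
    # sum to 1.0, so training warms up on short sequences and spends the final
    # 40% of steps at the full seq_len.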
    grow = GrowLengthScheduler(stages, args.max_steps) if args.growlength else None
    cur_seq = stages[0][0] if grow else args.seq_len
    dataset = GrowLengthDataset(token_buf, cur_seq)
    opt = SeedReplayMeZO(model, lr=args.lr * 0.01, eps=args.mezo_eps, weight_decay=0.1, momentum=0.9)

    def loss_fn(batch):
        if args.bf16:
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return model(batch["input_ids"], labels=batch["labels"]).loss
        return model(batch["input_ids"], labels=batch["labels"]).loss

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
    # Shorter sequences get proportionally larger batches so the token count
    # per step stays roughly constant across curriculum stages.
    eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
    di = iter(loader)
    for step in range(args.max_steps):
        if grow:
            ns = grow.get_seq_len(step)
            if ns != cur_seq:
                cur_seq = ns
                dataset.set_seq_len(cur_seq)
                eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
                di = iter(loader)
        if unfreezer:
            unfreezer.update(step)
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)
        loss_val = opt.step(loss_fn, batch)
        total_toks += batch["input_ids"].numel()
        total_loss += loss_val
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt


def benchmark_hyper(args):
    print("=" * 65)
    print("CHIMERA 5.3 HYPER v3 — BENCHMARK (full arch, all features)")
    print("=" * 65)
    model_a, cfg = build_model_from_args(args)
    model_b = copy.deepcopy(model_a)
    c = model_a.count_parameters()
    print(f"Model: {c['total']:,} params, {cfg['num_hidden_layers']} layers")
    print(f"Features: looping={model_a.looping_enabled} evolution={model_a.evolution is not None} span={model_a.span_engine is not None}")
    tok_budget = max(500_000, args.max_steps * args.batch_size * (args.seq_len + 1) * 8)
    token_buf = build_token_buffer(args.dataset_name, args.dataset_split, args.text_column, tok_budget, args.cache_dir)
    print(f"Tokens: {token_buf.numel():,}\n")
    print("-" * 65)
    print("BASELINE (randn MeZO, invalidate_packed, loop=2, full evo)")
    print("-" * 65)
    bt, bl, bd = run_baseline(model_a, token_buf, args)
    print(f" -> {bt:,.0f} tok/s loss={bl:.4f} time={bd:.1f}s\n")
    print("-" * 65)
    print("HYPER (seed-replay MeZO, STE path, loop=1, GrowLength, Reservoir)")
    print("-" * 65)
    ht, hl, hd = run_hyper(model_b, token_buf, args)
    print(f" -> {ht:,.0f} tok/s loss={hl:.4f} time={hd:.1f}s\n")
    sp = ht / bt if bt > 0 else float("inf")
    print("=" * 65)
    print(f" Baseline : {bt:>10,.0f} tok/s loss {bl:.4f}")
    print(f" Hyper    : {ht:>10,.0f} tok/s loss {hl:.4f}")
    print(f" Speedup  : {sp:>10.1f}x")
    print("=" * 65)
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "benchmark.json"), "w") as f:
        json.dump({"baseline_tps": round(bt), "hyper_tps": round(ht), "speedup": round(sp, 2)}, f, indent=2)
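

# A minimal sketch of how this benchmark might be invoked standalone. The flag
# names and defaults below are illustrative assumptions, not the package's real
# CLI, and build_model_from_args will likely require additional model-config
# attributes on `args` beyond the ones this file itself reads (seq_len,
# batch_size, max_steps, lr, mezo_eps, bf16, growlength, reservoir,
# progressive_unfreeze, unfreeze_stages, dataset_name, dataset_split,
# text_column, cache_dir, output_dir).
if __name__ == "__main__":
    import argparse

    p = argparse.ArgumentParser(description="Chimera HYPER benchmark (sketch)")
    p.add_argument("--seq_len", type=int, default=128)
    p.add_argument("--batch_size", type=int, default=8)
    p.add_argument("--max_steps", type=int, default=50)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--mezo_eps", type=float, default=1e-3)
    p.add_argument("--bf16", action="store_true")
    p.add_argument("--growlength", action="store_true")
    p.add_argument("--reservoir", action="store_true")
    p.add_argument("--progressive_unfreeze", action="store_true")
    p.add_argument("--unfreeze_stages", type=int, default=4)
    p.add_argument("--dataset_name", default="wikitext")
    p.add_argument("--dataset_split", default="train")
    p.add_argument("--text_column", default="text")
    p.add_argument("--cache_dir", default=None)
    p.add_argument("--output_dir", default="benchmark_out")
    benchmark_hyper(p.parse_args())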