from __future__ import annotations
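# Benchmark harness: runs the baseline MeZO training path and the HYPER path
# (SeedReplayMeZO, grow-length curriculum, reservoir freezing, progressive
# unfreezing) on the same token buffer and reports tokens/sec, loss and speedup.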
import copy
import json
import os
import time
import torch
from torch.utils.data import DataLoader, Dataset
from chimera.quantization import BitLinear
from .common import build_model_from_args
from .datasets import GrowLengthDataset, build_token_buffer
from .hyper import (
    GrowLengthScheduler,
    ProgressiveUnfreezer,
    SeedReplayMeZO,
    apply_reservoir_freezing,
    patch_training_loops,
)


def run_baseline(model, token_buf, args):
    model.train()
    seq = args.seq_len
    n = token_buf.numel() // (seq + 1)
    chunks = token_buf[: n * (seq + 1)].view(n, seq + 1)
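    # Minimal map-style dataset over fixed (seq + 1)-token chunks: inputs are
    # c[:-1], labels are c[1:] (next-token prediction).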
    class _Dataset(Dataset):
        def __len__(self):
            return chunks.size(0)

        def __getitem__(self, i):
            c = chunks[i]
            return {"input_ids": c[:-1], "labels": c[1:]}

    loader = DataLoader(_Dataset(), batch_size=args.batch_size, shuffle=True, num_workers=0, drop_last=True)
    params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
    eps = 1e-3

    def loss_fn(batch):
        return model(batch["input_ids"], labels=batch["labels"]).loss

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
    di = iter(loader)
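    # Plain MeZO step: evaluate the loss at theta + eps*z and theta - eps*z using
    # the same Gaussian z (replayed from a saved seed), take the central-difference
    # estimate of the directional derivative, then fold the update into the restore
    # pass below. No backward pass, no optimizer state.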
    for _ in range(args.max_steps):
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)
        seed = int(torch.randint(0, 2**31, (1,)).item())
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen), alpha=eps)
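        # Invalidate packed BitLinear weights after the in-place perturbation so the
        # next forward pass sees the new values.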
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            lp = float(loss_fn(batch).item())
        gen.manual_seed(seed)
        for _, p in params:
            p.data.add_(torch.randn(p.shape, generator=gen), alpha=-2 * eps)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        with torch.no_grad():
            ln = float(loss_fn(batch).item())
        g = (lp - ln) / (2 * eps)
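        # Third pass replays the same z: alpha = eps - lr*g simultaneously undoes the
        # -eps shift and applies the MeZO update, leaving theta - lr*g*z.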
        gen.manual_seed(seed)
        for _, p in params:
            z = torch.randn(p.shape, generator=gen)
            p.data.add_(z, alpha=eps - args.lr * g)
        for m in model.modules():
            if isinstance(m, BitLinear):
                m.invalidate_packed()
        total_toks += batch["input_ids"].numel()
        total_loss += 0.5 * (lp + ln)
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt


def run_hyper(model, token_buf, args):
    model.train()
    patch_training_loops(model, num_loops=1)
    if args.reservoir:
        apply_reservoir_freezing(model)
    unfreezer = ProgressiveUnfreezer(model, args.max_steps, args.unfreeze_stages) if args.progressive_unfreeze else None
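    # Grow-length curriculum: each stage is (sequence length, fraction of max_steps),
    # starting at seq_len/4 and spending the final 40% of steps at the full length.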
    stages = [
        (max(8, args.seq_len // 4), 0.30),
        (max(16, args.seq_len // 2), 0.30),
        (args.seq_len, 0.40),
    ]
    grow = GrowLengthScheduler(stages, args.max_steps) if args.growlength else None
    cur_seq = stages[0][0] if grow else args.seq_len
    dataset = GrowLengthDataset(token_buf, cur_seq)
    opt = SeedReplayMeZO(model, lr=args.lr * 0.01, eps=args.mezo_eps, weight_decay=0.1, momentum=0.9)
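    # SeedReplayMeZO: MeZO optimizer that, as the name suggests, regenerates its
    # perturbations from stored RNG seeds; runs with momentum, weight decay and a
    # learning rate scaled to 1% of args.lr.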
    def loss_fn(batch):
        if args.bf16:
            with torch.autocast("cpu", dtype=torch.bfloat16):
                return model(batch["input_ids"], labels=batch["labels"]).loss
        return model(batch["input_ids"], labels=batch["labels"]).loss

    total_toks, total_loss = 0, 0.0
    t0 = time.time()
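    # Keep tokens per batch roughly constant: shorter sequences get proportionally
    # larger batches (about args.batch_size * args.seq_len tokens per batch).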
    eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
    di = iter(loader)
    for step in range(args.max_steps):
        if grow:
            ns = grow.get_seq_len(step)
            if ns != cur_seq:
                cur_seq = ns
                dataset.set_seq_len(cur_seq)
                eff_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                loader = DataLoader(dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
                di = iter(loader)
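        # Gradually unfreeze parameter groups as training progresses.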
        if unfreezer:
            unfreezer.update(step)
        try:
            batch = next(di)
        except StopIteration:
            di = iter(loader)
            batch = next(di)
        loss_val = opt.step(loss_fn, batch)
        total_toks += batch["input_ids"].numel()
        total_loss += loss_val
    dt = time.time() - t0
    return total_toks / dt, total_loss / args.max_steps, dt


def benchmark_hyper(args):
    print("=" * 65)
    print("CHIMERA 5.3 HYPER v3 — BENCHMARK (full arch, all features)")
    print("=" * 65)
    model_a, cfg = build_model_from_args(args)
    model_b = copy.deepcopy(model_a)
    c = model_a.count_parameters()
    print(f"Model: {c['total']:,} params, {cfg['num_hidden_layers']} layers")
    print(f"Features: looping={model_a.looping_enabled} evolution={model_a.evolution is not None} span={model_a.span_engine is not None}")
    tok_budget = max(500_000, args.max_steps * args.batch_size * (args.seq_len + 1) * 8)
    token_buf = build_token_buffer(args.dataset_name, args.dataset_split, args.text_column, tok_budget, args.cache_dir)
    print(f"Tokens: {token_buf.numel():,}\n")

    print("-" * 65)
    print("BASELINE (randn MeZO, invalidate_packed, loop=2, full evo)")
    print("-" * 65)
    bt, bl, bd = run_baseline(model_a, token_buf, args)
    print(f" -> {bt:,.0f} tok/s loss={bl:.4f} time={bd:.1f}s\n")

    print("-" * 65)
    print("HYPER (seed-replay MeZO, STE path, loop=1, GrowLength, Reservoir)")
    print("-" * 65)
    ht, hl, hd = run_hyper(model_b, token_buf, args)
    print(f" -> {ht:,.0f} tok/s loss={hl:.4f} time={hd:.1f}s\n")

    sp = ht / bt if bt > 0 else float("inf")
    print("=" * 65)
    print(f" Baseline : {bt:>10,.0f} tok/s loss {bl:.4f}")
    print(f" Hyper : {ht:>10,.0f} tok/s loss {hl:.4f}")
    print(f" Speedup : {sp:>10.1f}x")
    print("=" * 65)

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "benchmark.json"), "w") as f:
        json.dump({"baseline_tps": round(bt), "hyper_tps": round(ht), "speedup": round(sp, 2)}, f, indent=2)