# chimera/training/loops.py
from __future__ import annotations
import json
import math
import os
import sys
import time
import torch
import chimera_turbo
from .common import save_final_checkpoint, save_training_checkpoint
from .hyper import ProgressiveLoopScheduler
def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
max_logits_gb: float = 2.0) -> int:
"""Cap batch size so the logits tensor fits in memory.
Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
Backward doubles this. Must stay well under 32 GB total.
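    Example: _safe_batch(256, 16, 200073, 2.0) -> 167  (2 GiB // ~12.8 MB/sample).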
"""
bytes_per_sample = seq_len * vocab_size * 4 # fp32 logits
max_bytes = int(max_logits_gb * 1024**3)
max_batch = max(1, max_bytes // bytes_per_sample)
capped = min(desired_batch, max_batch)
if capped < desired_batch:
print(f" [MEM] Batch {desired_batch} β†’ {capped} (logits would be "
f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
sys.stdout.flush()
return capped
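
# Worked example (hypothetical numbers): at seq_len=2048 the same 2 GB cap
# admits only _safe_batch(256, 2048, 200073) == 1, which is why the loops below
# recompute the batch size whenever the sequence length changes.
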
def train_fast_loop(args, model, config, loader, compute_loss) -> str:
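    # AdamW with betas=(0.9, 0.98): a common transformer-LM recipe (lower beta2
    # than the 0.999 default).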
optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
os.makedirs(args.output_dir, exist_ok=True)
model.train()
step, total_loss, best_loss, toks = 0, 0.0, float("inf"), 0
t0 = time.time()
data_iter = iter(loader)
while step < args.max_steps:
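        # Cycle the data loader: restart the iterator when an epoch ends.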
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(loader)
batch = next(data_iter)
loss = compute_loss(batch)
loss.backward()
        total_loss += loss.item()  # .item() already returns a Python float
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
optimizer.step()
        optimizer.zero_grad(set_to_none=True)  # set_to_none frees grad tensors instead of zeroing them
toks += batch["input_ids"].numel()
step += 1
if step % args.log_every == 0:
dt = time.time() - t0
avg = total_loss / args.log_every
            ppl = math.exp(min(avg, 20))  # clamp the exponent so early high losses stay printable
tps = toks / dt if dt > 0 else 0
print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:.0f} tok/s")
best_loss = min(best_loss, avg)
total_loss, toks, t0 = 0.0, 0, time.time()
save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
return os.path.join(args.output_dir, "final")

def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo):
    """Standard training loop (with optional MeZO). Not implemented yet."""
    raise NotImplementedError("train_standard_loop is a stub")

def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
use_compile = getattr(args, "compile", False)
vocab_size = int(config.get("vocab_size", 200073))
# ── Muon LR for ternary BitLinear ──
muon_lr = 0.012
muon_warmup = 30
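    # Passed to chimera_turbo.apply below, which presumably constructs the Muon
    # optimizer and its warmup schedule from these values.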
model, optimizer, scheduler, extras = chimera_turbo.apply(
model,
max_steps=args.max_steps,
lr=muon_lr,
weight_decay=0.02,
warmup_steps=muon_warmup,
use_compile=use_compile,
mtp_heads=0,
llrd_decay=0.90,
grokfast_alpha=0.95,
grokfast_lambda=1.5,
)
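    # llrd_decay applies layer-wise LR decay; the grokfast_* values look like
    # EMA gradient-filter settings (Grokfast), presumably consumed by apply().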
model.train()
# ── Gradient checkpointing: saves ~60% activation memory ──
    raw_model = getattr(model, "_orig_mod", model)  # unwrap the torch.compile wrapper if present
if hasattr(raw_model, "enable_gradient_checkpointing"):
raw_model.enable_gradient_checkpointing()
print(f"[OPT] Gradient checkpointing: ON")
# ── Looping: force loops=1 ──
cur_loops = 1
if hasattr(raw_model, "loop_controller"):
raw_model.loop_controller.loop_default = 1
raw_model.loop_controller.loop_min = 1
raw_model.loop_controller.loop_max = 1
use_bf16 = bool(args.bf16)
os.makedirs(args.output_dir, exist_ok=True)
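    # JSONL metrics log, flushed after every record so it can be tailed live.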
log_f = open(os.path.join(args.output_dir, "log_hyper.jsonl"), "w")
step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
t0 = time.time()
t_start = t0
cur_seq = initial_seq
# ── Memory-safe batch size ──
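    # Keep tokens per batch roughly constant: shorter sequences get a
    # proportionally larger batch, then _safe_batch caps it against the
    # logits-memory budget.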
desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
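    # num_workers=0 keeps loading in-process; drop_last=True keeps batch shapes uniform.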
loader = torch.utils.data.DataLoader(
dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
data_iter = iter(loader)
print(f"\n{'=' * 65}")
print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}")
print(f"Starting first step (may take 30-60s on CPU with 227M params)...")
print(f"{'=' * 65}")
sys.stdout.flush()
while step < args.max_steps:
if grow:
ns = grow.get_seq_len(step)
if ns != cur_seq:
cur_seq = ns
dataset.set_seq_len(cur_seq)
desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
loader = torch.utils.data.DataLoader(
dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
data_iter = iter(loader)
print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}")
sys.stdout.flush()
if unfreezer:
unfreezer.update(step)
try:
batch = next(data_iter)
except StopIteration:
data_iter = iter(loader)
batch = next(data_iter)
step_t0 = time.time()
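        # chimera_turbo.training_step presumably runs forward/backward plus the
        # optimizer and scheduler updates for one batch, returning a float loss.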
loss_val = chimera_turbo.training_step(
model, batch, optimizer, scheduler,
extras=extras, grad_accum_steps=1, step=step,
autocast_dtype=torch.bfloat16 if use_bf16 else None,
)
step_dt = time.time() - step_t0
cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)
if math.isfinite(loss_val):
total_loss += loss_val
valid_count += 1
step_toks = batch["input_ids"].numel()
toks += step_toks
step += 1
# Print every step for the first 5 steps, then every log_every
should_log = (step <= 5) or (step % args.log_every == 0)
if step == 1:
step_tps = step_toks / step_dt if step_dt > 0 else 0
print(f" βœ“ Step 1 completed in {step_dt:.1f}s "
f"({step_tps:.0f} tok/s, loss={loss_val:.4f})")
sys.stdout.flush()
if should_log:
dt = time.time() - t0
if valid_count > 0:
avg = total_loss / valid_count
ppl = math.exp(min(avg, 20)) if math.isfinite(avg) else float("nan")
else:
avg = float("nan")
ppl = float("nan")
tps = toks / dt if dt > 0 else 0
elapsed = time.time() - t_start
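            # ETA from the mean time per step so far (includes slow warmup steps).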
eta_s = (args.max_steps - step) * (elapsed / max(1, step))
log_f.write(json.dumps({
"step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
"ppl": round(ppl, 2) if math.isfinite(ppl) else None,
"lr": round(cur_lr, 6), "tok/s": round(tps),
"seq": cur_seq, "loops": cur_loops,
"step_time": round(step_dt, 2),
}) + "\n")
log_f.flush()
print(
f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
f"| ETA {eta_s / 60:.0f}m"
)
sys.stdout.flush()
if step > 5:
# Reset counters for clean averages
best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()
if step % args.save_every == 0:
d = save_training_checkpoint(model, config, step,
os.path.join(args.output_dir, f"ckpt-{step}"))
print(f" [SAVE] {d}")
sys.stdout.flush()
d = save_final_checkpoint(model, config, step, best_loss,
os.path.join(args.output_dir, "final"))
log_f.close()
total_time = time.time() - t_start
print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}"
f" total time {total_time / 60:.1f}m")
sys.stdout.flush()
return d