from __future__ import annotations

import json
import math
import os
import sys
import time

import torch

import chimera_turbo

from .common import save_final_checkpoint, save_training_checkpoint
from .hyper import ProgressiveLoopScheduler


def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
                max_logits_gb: float = 2.0) -> int:
    """Cap batch size so the logits tensor fits in memory.

    Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
    With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
    Backward doubles this. Must stay well under 32 GB total.
    """
    bytes_per_sample = seq_len * vocab_size * 4  # fp32 logits
    max_bytes = int(max_logits_gb * 1024**3)
    max_batch = max(1, max_bytes // bytes_per_sample)
    capped = min(desired_batch, max_batch)
    if capped < desired_batch:
        print(f" [MEM] Batch {desired_batch} → {capped} (logits would be "
              f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
        sys.stdout.flush()
    return capped


def train_fast_loop(args, model, config, loader, compute_loss) -> str:
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
    os.makedirs(args.output_dir, exist_ok=True)
    model.train()

    step, total_loss, best_loss, toks = 0, 0.0, float("inf"), 0
    t0 = time.time()
    data_iter = iter(loader)

    while step < args.max_steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)

        loss = compute_loss(batch)
        loss.backward()
        total_loss += float(loss.item())

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        toks += batch["input_ids"].numel()
        step += 1

        if step % args.log_every == 0:
            dt = time.time() - t0
            avg = total_loss / args.log_every
            ppl = math.exp(min(avg, 20))
            tps = toks / dt if dt > 0 else 0
            print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:.0f} tok/s")
            best_loss = min(best_loss, avg)
            total_loss, toks, t0 = 0.0, 0, time.time()

    save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
    return os.path.join(args.output_dir, "final")


def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo):
    pass


def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
    use_compile = getattr(args, "compile", False)
    vocab_size = int(config.get("vocab_size", 200073))

    # ── Muon LR for ternary BitLinear ──
    muon_lr = 0.012
    muon_warmup = 30

    model, optimizer, scheduler, extras = chimera_turbo.apply(
        model,
        max_steps=args.max_steps,
        lr=muon_lr,
        weight_decay=0.02,
        warmup_steps=muon_warmup,
        use_compile=use_compile,
        mtp_heads=0,
        llrd_decay=0.90,
        grokfast_alpha=0.95,
        grokfast_lambda=1.5,
    )
    model.train()

    # ── Gradient checkpointing: saves ~60% activation memory ──
    raw_model = getattr(model, "_orig_mod", model)
    if hasattr(raw_model, "enable_gradient_checkpointing"):
        raw_model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing: ON")

    # ── Looping: force loops=1 ──
    cur_loops = 1
    if hasattr(raw_model, "loop_controller"):
        raw_model.loop_controller.loop_default = 1
        raw_model.loop_controller.loop_min = 1
        raw_model.loop_controller.loop_max = 1

    use_bf16 = bool(args.bf16)
    os.makedirs(args.output_dir, exist_ok=True)
    log_f = open(os.path.join(args.output_dir, "log_hyper.jsonl"), "w")

    step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
    t0 = time.time()
    t_start = t0
    cur_seq = initial_seq

    # ── Memory-safe batch size ──
    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
    data_iter = iter(loader)

    print(f"\n{'=' * 65}")
    print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}")
    print("Starting first step (may take 30-60s on CPU with 227M params)...")
    print(f"{'=' * 65}")
    sys.stdout.flush()

    while step < args.max_steps:
        if grow:
            ns = grow.get_seq_len(step)
            if ns != cur_seq:
                cur_seq = ns
                dataset.set_seq_len(cur_seq)
                desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
                loader = torch.utils.data.DataLoader(
                    dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
                data_iter = iter(loader)
                print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}")
                sys.stdout.flush()

        if unfreezer:
            unfreezer.update(step)

        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)

        step_t0 = time.time()
        loss_val = chimera_turbo.training_step(
            model, batch, optimizer, scheduler,
            extras=extras,
            grad_accum_steps=1,
            step=step,
            autocast_dtype=torch.bfloat16 if use_bf16 else None,
        )
        step_dt = time.time() - step_t0

        cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)

        if math.isfinite(loss_val):
            total_loss += loss_val
            valid_count += 1

        step_toks = batch["input_ids"].numel()
        toks += step_toks
        step += 1

        # Print every step for the first 5 steps, then every log_every
        should_log = (step <= 5) or (step % args.log_every == 0)

        if step == 1:
            step_tps = step_toks / step_dt if step_dt > 0 else 0
            print(f" ✓ Step 1 completed in {step_dt:.1f}s "
                  f"({step_tps:.0f} tok/s, loss={loss_val:.4f})")
            sys.stdout.flush()

        if should_log:
            dt = time.time() - t0
            if valid_count > 0:
                avg = total_loss / valid_count
                ppl = math.exp(min(avg, 20)) if math.isfinite(avg) else float("nan")
            else:
                avg = float("nan")
                ppl = float("nan")
            tps = toks / dt if dt > 0 else 0
            elapsed = time.time() - t_start
            eta_s = (args.max_steps - step) * (elapsed / max(1, step))

            log_f.write(json.dumps({
                "step": step,
                "loss": round(avg, 4) if math.isfinite(avg) else None,
                "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                "lr": round(cur_lr, 6),
                "tok/s": round(tps),
                "seq": cur_seq,
                "loops": cur_loops,
                "step_time": round(step_dt, 2),
            }) + "\n")
            log_f.flush()

            print(
                f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
                f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
                f"| ETA {eta_s / 60:.0f}m"
            )
            sys.stdout.flush()

            if step > 5:
                # Reset counters for clean averages
                best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
                total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()

        if step % args.save_every == 0:
            d = save_training_checkpoint(model, config, step,
                                         os.path.join(args.output_dir, f"ckpt-{step}"))
            print(f" [SAVE] {d}")
            sys.stdout.flush()

    d = save_final_checkpoint(model, config, step, best_loss, os.path.join(args.output_dir, "final"))
    log_f.close()
    total_time = time.time() - t_start
    print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}"
          f" total time {total_time / 60:.1f}m")
    sys.stdout.flush()
    return d
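

# ---------------------------------------------------------------------------
# Illustrative sketch only — not called anywhere in the training paths above.
# It restates the arithmetic from _safe_batch's docstring as runnable code so
# the default 2.0 GB logits cap is easy to sanity-check; the numbers below
# (vocab=200073, seq=16, desired batch=256) come from that docstring, and the
# helper name here is ours, not part of the module's API.
def _example_logits_budget(seq_len: int = 16, vocab_size: int = 200073,
                           desired_batch: int = 256, cap_gb: float = 2.0) -> int:
    """Return the capped batch for the docstring's example numbers.

    One sample's fp32 logits take seq_len * vocab_size * 4 bytes, about
    12.8 MB here, so a 2 GiB cap admits 167 samples and a requested batch
    of 256 is reduced to 167.
    """
    bytes_per_sample = seq_len * vocab_size * 4             # fp32 logits per sample
    max_batch = int(cap_gb * 1024**3) // bytes_per_sample   # 167 for these numbers
    return min(desired_batch, max_batch)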