from __future__ import annotations

import json
import math
import os
import sys
import time

import torch

import chimera_turbo

from .common import save_final_checkpoint, save_training_checkpoint
from .hyper import ProgressiveLoopScheduler


def _safe_batch(desired_batch: int, seq_len: int, vocab_size: int,
                max_logits_gb: float = 2.0) -> int:
    """Cap batch size so the logits tensor fits in memory.

    Logits shape: [batch, seq, vocab] at fp32 = batch * seq * vocab * 4 bytes.
    With vocab=200073, batch=256, seq=16: 3.28 GB just for logits.
    Backward doubles this. Must stay well under 32 GB total.
    """
    bytes_per_sample = seq_len * vocab_size * 4
    max_bytes = int(max_logits_gb * 1024**3)
    max_batch = max(1, max_bytes // bytes_per_sample)
    capped = min(desired_batch, max_batch)
    if capped < desired_batch:
        print(f" [MEM] Batch {desired_batch} -> {capped} (logits would be "
              f"{desired_batch * seq_len * vocab_size * 4 / 1e9:.1f} GB, cap={max_logits_gb} GB)")
        sys.stdout.flush()
    return capped


def train_fast_loop(args, model, config, loader, compute_loss) -> str:
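    """Minimal AdamW loop: forward, backward, clip grads at 1.0, step.

    Logs a running average every `args.log_every` steps, writes a final
    checkpoint, and returns the checkpoint directory path.
    """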
    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.98))
    os.makedirs(args.output_dir, exist_ok=True)
    model.train()
    step, total_loss, best_loss, toks = 0, 0.0, float("inf"), 0
    t0 = time.time()
    data_iter = iter(loader)
    while step < args.max_steps:
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)
        loss = compute_loss(batch)
        loss.backward()
        total_loss += float(loss.item())
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        toks += batch["input_ids"].numel()
        step += 1
        if step % args.log_every == 0:
            dt = time.time() - t0
            avg = total_loss / args.log_every
            ppl = math.exp(min(avg, 20))
            tps = toks / dt if dt > 0 else 0
            print(f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} | {tps:.0f} tok/s")
            best_loss = min(best_loss, avg)
            total_loss, toks, t0 = 0.0, 0, time.time()
    out_dir = os.path.join(args.output_dir, "final")
    save_final_checkpoint(model, config, step, best_loss, out_dir)
    return out_dir


def train_standard_loop(args, model, config, loader, compute_loss, optimizer, use_mezo):
    """Standard training loop (stub): fail loudly rather than return None."""
    raise NotImplementedError("train_standard_loop is not implemented yet")


def train_hyper_loop(args, model, config, dataset, initial_seq, grow, unfreezer):
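    """Progressive training loop built around chimera_turbo.

    Optionally grows the sequence length on a schedule (`grow`), rescales the
    batch to keep tokens per step roughly constant, progressively unfreezes
    layers (`unfreezer`), logs JSONL metrics to log_hyper.jsonl, and returns
    the final checkpoint directory.
    """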
    use_compile = getattr(args, "compile", False)
    vocab_size = int(config.get("vocab_size", 200073))

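    # Optimizer hyperparameters handed to chimera_turbo.apply below.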
    muon_lr = 0.012
    muon_warmup = 30

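    # chimera_turbo.apply is expected to return the (optionally compiled)
    # model along with its optimizer, LR scheduler, and an `extras` bundle
    # consumed by training_step below. llrd_decay reads as a layer-wise LR
    # decay factor and grokfast_alpha/lambda as Grokfast-style gradient-EMA
    # settings -- inferred from the argument names; see chimera_turbo for
    # the exact semantics.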
    model, optimizer, scheduler, extras = chimera_turbo.apply(
        model,
        max_steps=args.max_steps,
        lr=muon_lr,
        weight_decay=0.02,
        warmup_steps=muon_warmup,
        use_compile=use_compile,
        mtp_heads=0,
        llrd_decay=0.90,
        grokfast_alpha=0.95,
        grokfast_lambda=1.5,
    )
    model.train()

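    # torch.compile wraps the model and exposes the underlying module as
    # _orig_mod; fall back to the model itself when compilation is off.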
    raw_model = getattr(model, "_orig_mod", model)
    if hasattr(raw_model, "enable_gradient_checkpointing"):
        raw_model.enable_gradient_checkpointing()
        print("[OPT] Gradient checkpointing: ON")

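    # Pin the model's loop controller, when present, to exactly one pass:
    # loop_default, loop_min, and loop_max all stay at 1 for this run.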
    cur_loops = 1
    if hasattr(raw_model, "loop_controller"):
        raw_model.loop_controller.loop_default = 1
        raw_model.loop_controller.loop_min = 1
        raw_model.loop_controller.loop_max = 1

    use_bf16 = bool(args.bf16)

    os.makedirs(args.output_dir, exist_ok=True)
    log_f = open(os.path.join(args.output_dir, "log_hyper.jsonl"), "w")
    step, total_loss, valid_count, best_loss, toks = 0, 0.0, 0, float("inf"), 0
    t0 = time.time()
    t_start = t0
    cur_seq = initial_seq

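    # Keep tokens per step roughly constant: when cur_seq is below the target
    # args.seq_len, scale the batch up by the ratio, then cap it by the
    # logits-memory budget in _safe_batch.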
    desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
    eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
    data_iter = iter(loader)

| print(f"\n{'=' * 65}") |
| print(f"Training batch={eff_batch} seq={cur_seq} loops={cur_loops}") |
| print(f"Starting first step (may take 30-60s on CPU with 227M params)...") |
| print(f"{'=' * 65}") |
| sys.stdout.flush() |
|
|
    while step < args.max_steps:
        # Rebuild the loader whenever the scheduled sequence length changes.
        if grow:
            ns = grow.get_seq_len(step)
            if ns != cur_seq:
                cur_seq = ns
                dataset.set_seq_len(cur_seq)
                desired_batch = args.batch_size * max(1, args.seq_len // max(1, cur_seq))
                eff_batch = _safe_batch(desired_batch, cur_seq, vocab_size)
                loader = torch.utils.data.DataLoader(
                    dataset, batch_size=eff_batch, shuffle=True, num_workers=0, drop_last=True)
                data_iter = iter(loader)
                print(f" [P1] seq -> {cur_seq} batch -> {eff_batch}")
                sys.stdout.flush()

        if unfreezer:
            unfreezer.update(step)

        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(loader)
            batch = next(data_iter)

        step_t0 = time.time()

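        # training_step is assumed to run the forward/backward pass plus the
        # optimizer and scheduler updates, returning the scalar loss (it is
        # treated as a plain float below).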
        loss_val = chimera_turbo.training_step(
            model, batch, optimizer, scheduler,
            extras=extras, grad_accum_steps=1, step=step,
            autocast_dtype=torch.bfloat16 if use_bf16 else None,
        )
        step_dt = time.time() - step_t0

        # Effective LR: the group lr times an optional per-group scale
        # (e.g. from layer-wise LR decay).
        cur_lr = optimizer.param_groups[0]["lr"] * optimizer.param_groups[0].get("lr_scale", 1.0)
        if math.isfinite(loss_val):
            total_loss += loss_val
            valid_count += 1
        step_toks = batch["input_ids"].numel()
        toks += step_toks
        step += 1

        # Log the first five steps eagerly, then every log_every steps.
        should_log = (step <= 5) or (step % args.log_every == 0)

        if step == 1:
            step_tps = step_toks / step_dt if step_dt > 0 else 0
            print(f" ✓ Step 1 completed in {step_dt:.1f}s "
                  f"({step_tps:.0f} tok/s, loss={loss_val:.4f})")
            sys.stdout.flush()

        if should_log:
            dt = time.time() - t0
            if valid_count > 0:
                avg = total_loss / valid_count
                ppl = math.exp(min(avg, 20)) if math.isfinite(avg) else float("nan")
            else:
                avg = float("nan")
                ppl = float("nan")
            tps = toks / dt if dt > 0 else 0
            elapsed = time.time() - t_start
            eta_s = (args.max_steps - step) * (elapsed / max(1, step))
            log_f.write(json.dumps({
                "step": step, "loss": round(avg, 4) if math.isfinite(avg) else None,
                "ppl": round(ppl, 2) if math.isfinite(ppl) else None,
                "lr": round(cur_lr, 6), "tok/s": round(tps),
                "seq": cur_seq, "loops": cur_loops,
                "step_time": round(step_dt, 2),
            }) + "\n")
            log_f.flush()
            print(
                f" step {step:>6}/{args.max_steps} | loss {avg:.4f} | ppl {ppl:>8.2f} "
                f"| {tps:,.0f} tok/s | {step_dt:.1f}s/step | seq {cur_seq} "
                f"| ETA {eta_s / 60:.0f}m"
            )
            sys.stdout.flush()

            if step > 5:
                # Don't reset the stats during the first five eager logs, so
                # the first full window still averages steps 1..log_every.
                best_loss = min(best_loss, avg) if math.isfinite(avg) else best_loss
                total_loss, valid_count, toks, t0 = 0.0, 0, 0, time.time()

        if step % args.save_every == 0:
            d = save_training_checkpoint(model, config, step,
                                         os.path.join(args.output_dir, f"ckpt-{step}"))
            print(f" [SAVE] {d}")
            sys.stdout.flush()

    d = save_final_checkpoint(model, config, step, best_loss,
                              os.path.join(args.output_dir, "final"))
    log_f.close()
    total_time = time.time() - t_start
    print(f"\nDONE -- best loss {best_loss:.4f} ppl {math.exp(min(best_loss, 20)):.2f}"
          f" total time {total_time / 60:.1f}m")
    sys.stdout.flush()
    return d