| """ |
| IRIS Production Training Script. |
| Usage: python -m iris.train_production --config iris-small --num_steps 5000 |
| """ |
|
|
| import argparse, json, os, time, math |
| import torch |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from .model import IRIS |
| from .configs import get_model_config, CONFIGS |
| from .flow_matching import flow_matching_loss |
| from .train import SyntheticLatentDataset, CosineWarmupScheduler |
|
|
|
|
def parse_args():
    """Parse the command-line options for an IRIS training run.

    Returns:
        argparse.Namespace: one attribute per flag registered below, named
        after the flag without its leading dashes.
    """
    parser = argparse.ArgumentParser(description="Train IRIS")

    # The model preset is the only option restricted to a fixed set of
    # values: the keys of the CONFIGS registry.
    parser.add_argument(
        "--config",
        type=str,
        default="iris-small",
        choices=list(CONFIGS.keys()),
    )

    # Every remaining option is a plain typed value with a default, so they
    # are registered from a (flag, type, default) table. Row order is the
    # order in which the options appear in ``--help``.
    typed_options = (
        ("--num_steps", int, 5000),
        ("--batch_size", int, 32),
        ("--lr", float, 2e-4),
        ("--weight_decay", float, 0.01),
        ("--warmup_steps", int, 500),
        ("--grad_clip", float, 1.0),
        ("--num_iterations", int, 4),
        ("--num_workers", int, 4),
        ("--log_every", int, 50),
        ("--save_every", int, 1000),
        ("--output_dir", str, "./iris_output"),
        ("--num_samples", int, 50000),
        ("--seed", int, 42),
        ("--resume", str, None),
    )
    for flag, value_type, default_value in typed_options:
        parser.add_argument(flag, type=value_type, default=default_value)

    return parser.parse_args()
|
|
|
|
def _checkpoint_payload(step, model, optimizer, scaler, loss_history, model_cfg, config_name):
    """Build the dict written to disk for both periodic and final checkpoints."""
    return {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        # Empty dict when the scaler is disabled; restored only if non-empty.
        "scaler_state_dict": scaler.state_dict(),
        "loss_history": loss_history,
        "config": model_cfg,
        "config_name": config_name,
    }


def _mean_or_nan(values):
    """Return the arithmetic mean of *values*, or NaN when *values* is empty."""
    return sum(values) / len(values) if values else float("nan")


def main():
    """Train an IRIS model on synthetic latents and write checkpoints.

    Reads its settings from ``parse_args()``, builds the model, dataset,
    optimizer, LR scheduler and gradient scaler, optionally resumes from
    ``--resume``, then runs optimisation steps until ``--num_steps`` is
    reached. A ``config.json`` describing the run, periodic checkpoints
    (every ``--save_every`` steps) and a final checkpoint are written to
    ``--output_dir``.

    A non-positive ``--log_every`` or ``--save_every`` disables logging or
    periodic checkpointing respectively.

    Raises:
        ValueError: if training still has steps to run but the data loader
            yields no batches (dataset smaller than one batch while
            ``drop_last=True``), which would otherwise loop forever.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"

    if use_amp:
        # Query the device actually in use rather than hard-coding index 0.
        # Compute capability major version below 8 selects float16,
        # otherwise bfloat16.
        major, _minor = torch.cuda.get_device_capability(device)
        amp_dtype = torch.float16 if major < 8 else torch.bfloat16
    else:
        amp_dtype = torch.float32

    print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
    model_cfg = get_model_config(args.config)
    model = IRIS(gradient_checkpointing=True, **model_cfg).to(device)
    counts = model.count_params()
    print(f"Model: {counts['total']:,} params ({counts['total']/1e6:.1f}M)")

    dataset = SyntheticLatentDataset(
        num_samples=args.num_samples,
        latent_channels=model_cfg["latent_channels"],
        latent_size=16,
        text_dim=model_cfg["dim"],
        text_length=32,
        seed=args.seed,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        drop_last=True,
        persistent_workers=(args.num_workers > 0),
    )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
        betas=(0.9, 0.999),
    )
    scheduler = CosineWarmupScheduler(optimizer, args.warmup_steps, args.num_steps)
    # Loss scaling is only enabled for float16 autocast.
    scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))

    start_step = 0
    loss_history = []
    if args.resume:
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only resume from checkpoint files you trust.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        # Older checkpoints have no scaler state, and a disabled scaler saves
        # an empty dict that an enabled scaler refuses to load.
        scaler_state = ckpt.get("scaler_state_dict")
        if scaler.is_enabled() and scaler_state:
            scaler.load_state_dict(scaler_state)
        start_step = ckpt["step"]
        loss_history = ckpt.get("loss_history", [])
        # The scheduler state is not checkpointed; replay it to start_step.
        for _ in range(start_step):
            scheduler.step()

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump(
            {"model_config": args.config, "model_params": model_cfg, "training": vars(args)},
            f,
            indent=2,
        )

    model.train()
    step = start_step
    if step < args.num_steps and len(loader) == 0:
        raise ValueError(
            f"DataLoader yields no batches: num_samples={args.num_samples} is smaller "
            f"than batch_size={args.batch_size} with drop_last=True"
        )

    # Loss accumulated since the last log line, and how many steps it covers.
    # Counting steps keeps the average right after resuming mid-window.
    window_loss, window_steps = 0.0, 0
    start_time = time.time()

    while step < args.num_steps:
        for batch in loader:
            if step >= args.num_steps:
                break
            latent = batch["latent"].to(device, non_blocking=True)
            text_embed = batch["text_embed"].to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                losses = flow_matching_loss(
                    model,
                    latent,
                    text_embed,
                    num_iterations=args.num_iterations,
                    timestep_sampling="logit_normal",
                )
                loss = losses["loss"]

            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                optimizer.step()

            scheduler.step()
            step += 1
            # Read the loss once; each .item() call forces a host sync.
            loss_value = loss.item()
            window_loss += loss_value
            window_steps += 1
            loss_history.append(loss_value)

            if args.log_every > 0 and step % args.log_every == 0:
                avg = window_loss / window_steps
                gn = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                print(f"step={step:>6d} | loss={avg:.4f} | grad={gn:.3f} | lr={scheduler.get_lr()[0]:.2e} | {time.time()-start_time:.0f}s")
                window_loss, window_steps = 0.0, 0

            if args.save_every > 0 and step % args.save_every == 0:
                p = os.path.join(args.output_dir, f"iris_{args.config}_step{step}.pt")
                torch.save(
                    _checkpoint_payload(step, model, optimizer, scaler, loss_history, model_cfg, args.config),
                    p,
                )
                print(f" Saved: {p}")

    final = os.path.join(args.output_dir, f"iris_{args.config}_final.pt")
    torch.save(
        _checkpoint_payload(step, model, optimizer, scaler, loss_history, model_cfg, args.config),
        final,
    )

    # Compare the mean of the first and last (up to) 100 recorded losses.
    # Both are NaN when nothing was recorded, and the reduction is NaN when
    # the first mean is zero, so the summary never divides by zero.
    first_mean = _mean_or_nan(loss_history[:100])
    last_mean = _mean_or_nan(loss_history[-100:])
    reduction = (1 - last_mean / first_mean) * 100 if first_mean else float("nan")
    print(f"Done: {step} steps, loss {first_mean:.4f} -> {last_mean:.4f} ({reduction:.1f}% reduction). Saved: {final}")
|
|
# Run training only when this file is executed as a script (for example via
# ``python -m``), not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|