Upload iris/train_production.py
Browse files- iris/train_production.py +121 -0
iris/train_production.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
IRIS Production Training Script.
|
| 3 |
+
Usage: python -m iris.train_production --config iris-small --num_steps 5000
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import argparse, json, os, time, math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch.utils.data import DataLoader
|
| 10 |
+
from .model import IRIS
|
| 11 |
+
from .configs import get_model_config, CONFIGS
|
| 12 |
+
from .flow_matching import flow_matching_loss
|
| 13 |
+
from .train import SyntheticLatentDataset, CosineWarmupScheduler
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def parse_args():
    """Build the training CLI and parse ``sys.argv`` into a Namespace.

    Defaults mirror the iris-small recipe; ``--config`` is restricted to
    the keys registered in ``CONFIGS``.
    """
    parser = argparse.ArgumentParser(description="Train IRIS")
    parser.add_argument(
        "--config", type=str, default="iris-small", choices=list(CONFIGS.keys())
    )
    # Optimization schedule.
    parser.add_argument("--num_steps", type=int, default=5000)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--lr", type=float, default=2e-4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--warmup_steps", type=int, default=500)
    parser.add_argument("--grad_clip", type=float, default=1.0)
    parser.add_argument("--num_iterations", type=int, default=4)
    # Data loading.
    parser.add_argument("--num_workers", type=int, default=4)
    # Logging / checkpointing cadence and output location.
    parser.add_argument("--log_every", type=int, default=50)
    parser.add_argument("--save_every", type=int, default=1000)
    parser.add_argument("--output_dir", type=str, default="./iris_output")
    # Dataset size, reproducibility, and resume path.
    parser.add_argument("--num_samples", type=int, default=50000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--resume", type=str, default=None)
    return parser.parse_args()
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def main():
    """Entry point: train an IRIS model with flow matching on synthetic latents.

    Builds model/dataset/optimizer from CLI args, optionally resumes from a
    checkpoint, runs the AMP training loop with gradient clipping, and writes
    periodic and final checkpoints plus a config.json to ``args.output_dir``.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    # Prefer bf16 where supported (no grad scaling needed); fall back to fp16
    # on older GPUs, fp32 on CPU.
    amp_dtype = torch.bfloat16 if (use_amp and torch.cuda.is_bf16_supported()) else torch.float16 if use_amp else torch.float32

    print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
    model_cfg = get_model_config(args.config)
    model = IRIS(gradient_checkpointing=True, **model_cfg).to(device)
    counts = model.count_params()
    print(f"Model: {counts['total']:,} params ({counts['total']/1e6:.1f}M)")

    dataset = SyntheticLatentDataset(
        num_samples=args.num_samples,
        latent_channels=model_cfg["latent_channels"],
        latent_size=16,
        text_dim=model_cfg["dim"],
        text_length=32,
        seed=args.seed,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        drop_last=True,
        persistent_workers=(args.num_workers > 0),
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.999))
    scheduler = CosineWarmupScheduler(optimizer, args.warmup_steps, args.num_steps)
    # Loss scaling is only required for fp16; bf16/fp32 run with it disabled.
    scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))

    start_step = 0
    loss_history = []
    if args.resume:
        # weights_only=False: the checkpoint holds plain dicts/lists written
        # by this script itself. NOTE(review): only resume from trusted files.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_step = ckpt["step"]
        loss_history = ckpt.get("loss_history", [])
        # Scheduler state is not checkpointed; replay it to the resume step.
        for _ in range(start_step):
            scheduler.step()

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump({"model_config": args.config, "model_params": model_cfg, "training": vars(args)}, f, indent=2)

    def _save_checkpoint(path):
        # One schema for every checkpoint. FIX: periodic checkpoints used to
        # omit "config_name", which only the final checkpoint carried.
        torch.save({
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss_history": loss_history,
            "config": model_cfg,
            "config_name": args.config,
        }, path)

    model.train()
    step, epoch, best_loss = start_step, 0, float("inf")
    # FIX: track the actual window size so the first logged average after a
    # resume at a non-multiple of log_every is not diluted by /log_every.
    running_loss, steps_since_log = 0.0, 0
    start_time = time.time()

    while step < args.num_steps:
        epoch += 1
        for batch in loader:
            if step >= args.num_steps:
                break
            latent = batch["latent"].to(device, non_blocking=True)
            text_embed = batch["text_embed"].to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                losses = flow_matching_loss(model, latent, text_embed, num_iterations=args.num_iterations, timestep_sampling="logit_normal")
                loss = losses["loss"]

            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                # Unscale before clipping so the clip threshold applies to
                # true gradient magnitudes.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                optimizer.step()

            scheduler.step()
            step += 1
            # FIX: a single .item() (one host sync) instead of two per step.
            loss_val = loss.item()
            running_loss += loss_val
            loss_history.append(loss_val)
            steps_since_log += 1

            if step % args.log_every == 0:
                avg = running_loss / steps_since_log
                gn = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                print(f"step={step:>6d} | loss={avg:.4f} | grad={gn:.3f} | lr={scheduler.get_lr()[0]:.2e} | {time.time()-start_time:.0f}s")
                if avg < best_loss:
                    best_loss = avg
                running_loss, steps_since_log = 0.0, 0

            if step % args.save_every == 0:
                p = os.path.join(args.output_dir, f"iris_{args.config}_step{step}.pt")
                _save_checkpoint(p)
                print(f" Saved: {p}")

    final = os.path.join(args.output_dir, f"iris_{args.config}_final.pt")
    _save_checkpoint(final)
    # FIX: guard the summary against an empty history (e.g. resumed exactly at
    # num_steps) and against division by zero when the first-window mean is 0 —
    # the original crashed with ZeroDivisionError in both cases.
    if loss_history:
        f10 = sum(loss_history[:100]) / min(100, len(loss_history))
        l10 = sum(loss_history[-100:]) / min(100, len(loss_history))
        reduction = (1 - l10 / f10) * 100 if f10 else 0.0
        print(f"Done: {step} steps, loss {f10:.4f} -> {l10:.4f} ({reduction:.1f}% reduction). Saved: {final}")
    else:
        print(f"Done: {step} steps. Saved: {final}")
|
| 119 |
+
|
| 120 |
+
# Standard script-entry guard: run training only when executed directly
# (e.g. `python -m iris.train_production`), not when imported as a module.
if __name__ == "__main__":
    main()
|