asdf98 committed on
Commit
5a3f8af
·
verified ·
1 Parent(s): fcd6bec

Upload iris/train_production.py

Browse files
Files changed (1) hide show
  1. iris/train_production.py +121 -0
iris/train_production.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IRIS Production Training Script.
3
+ Usage: python -m iris.train_production --config iris-small --num_steps 5000
4
+ """
5
+
6
+ import argparse, json, os, time, math
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from torch.utils.data import DataLoader
10
+ from .model import IRIS
11
+ from .configs import get_model_config, CONFIGS
12
+ from .flow_matching import flow_matching_loss
13
+ from .train import SyntheticLatentDataset, CosineWarmupScheduler
14
+
15
+
16
def parse_args():
    """Parse and return the command-line options for production training."""
    parser = argparse.ArgumentParser(description="Train IRIS")
    # --config is special: its choices come from the registered model configs.
    parser.add_argument("--config", type=str, default="iris-small", choices=list(CONFIGS.keys()))
    # Remaining flags are uniform (name, type, default) triples.
    specs = [
        ("--num_steps", int, 5000),
        ("--batch_size", int, 32),
        ("--lr", float, 2e-4),
        ("--weight_decay", float, 0.01),
        ("--warmup_steps", int, 500),
        ("--grad_clip", float, 1.0),
        ("--num_iterations", int, 4),
        ("--num_workers", int, 4),
        ("--log_every", int, 50),
        ("--save_every", int, 1000),
        ("--output_dir", str, "./iris_output"),
        ("--num_samples", int, 50000),
        ("--seed", int, 42),
        ("--resume", str, None),
    ]
    for flag, flag_type, default in specs:
        parser.add_argument(flag, type=flag_type, default=default)
    return parser.parse_args()
34
+
35
+
36
def main():
    """Run the IRIS production training loop.

    Builds the model, synthetic dataset, optimizer, and LR schedule from the
    CLI arguments, optionally resumes from a checkpoint, trains for
    ``--num_steps`` optimizer steps with AMP and gradient clipping, and writes
    periodic + final checkpoints plus a ``config.json`` into ``--output_dir``.

    Fixes vs. previous version:
    - Guard the final-summary averages against an empty ``loss_history``
      (e.g. ``--num_steps 0``) and against division by zero when the
      first-window average is 0.0 — both previously raised ZeroDivisionError.
    - Save and restore the fp16 GradScaler state across resume; previously it
      was silently reinitialized, which can destabilize resumed fp16 runs.
    - Checkpoint payload is built in one helper instead of being duplicated.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    # Prefer bf16 when available (no loss scaling needed); otherwise fp16.
    if use_amp:
        amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
    else:
        amp_dtype = torch.float32

    print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
    model_cfg = get_model_config(args.config)
    model = IRIS(gradient_checkpointing=True, **model_cfg).to(device)
    counts = model.count_params()
    print(f"Model: {counts['total']:,} params ({counts['total']/1e6:.1f}M)")

    dataset = SyntheticLatentDataset(
        num_samples=args.num_samples,
        latent_channels=model_cfg["latent_channels"],
        latent_size=16,
        text_dim=model_cfg["dim"],
        text_length=32,
        seed=args.seed,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        drop_last=True,
        persistent_workers=(args.num_workers > 0),
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.999))
    scheduler = CosineWarmupScheduler(optimizer, args.warmup_steps, args.num_steps)
    # GradScaler is only required for fp16; bf16/fp32 run unscaled.
    scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))

    start_step = 0
    loss_history = []
    if args.resume:
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        start_step = ckpt["step"]
        loss_history = ckpt.get("loss_history", [])
        # Restore loss-scaler state when present (older checkpoints lack it).
        if scaler.is_enabled() and "scaler_state_dict" in ckpt:
            scaler.load_state_dict(ckpt["scaler_state_dict"])
        # Fast-forward the LR schedule to the resumed step.
        for _ in range(start_step):
            scheduler.step()

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump({"model_config": args.config, "model_params": model_cfg, "training": vars(args)}, f, indent=2)

    def _save_checkpoint(path, step, extra=None):
        # Single definition of the checkpoint schema (periodic and final).
        state = {
            "step": step,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "loss_history": loss_history,
            "config": model_cfg,
        }
        if extra:
            state.update(extra)
        torch.save(state, path)

    model.train()
    step, epoch, running_loss, best_loss = start_step, 0, 0.0, float("inf")
    start_time = time.time()

    while step < args.num_steps:
        epoch += 1
        for batch in loader:
            if step >= args.num_steps:
                break
            latent = batch["latent"].to(device, non_blocking=True)
            text_embed = batch["text_embed"].to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                losses = flow_matching_loss(model, latent, text_embed, num_iterations=args.num_iterations, timestep_sampling="logit_normal")
                loss = losses["loss"]

            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to the
                # true (unscaled) gradients.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                optimizer.step()

            scheduler.step()
            step += 1
            running_loss += loss.item()
            loss_history.append(loss.item())

            if step % args.log_every == 0:
                avg = running_loss / args.log_every
                gn = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                print(f"step={step:>6d} | loss={avg:.4f} | grad={gn:.3f} | lr={scheduler.get_lr()[0]:.2e} | {time.time()-start_time:.0f}s")
                if avg < best_loss:
                    best_loss = avg
                running_loss = 0.0

            if step % args.save_every == 0:
                p = os.path.join(args.output_dir, f"iris_{args.config}_step{step}.pt")
                _save_checkpoint(p, step)
                print(f" Saved: {p}")

    final = os.path.join(args.output_dir, f"iris_{args.config}_final.pt")
    _save_checkpoint(final, step, extra={"config_name": args.config})
    if loss_history:
        # Average the first/last 100 recorded losses (fewer if history is short).
        window = min(100, len(loss_history))
        f10 = sum(loss_history[:100]) / window
        l10 = sum(loss_history[-100:]) / window
        reduction = (1 - l10 / f10) * 100 if f10 else 0.0
        print(f"Done: {step} steps, loss {f10:.4f} -> {l10:.4f} ({reduction:.1f}% reduction). Saved: {final}")
    else:
        # No steps were run (e.g. --num_steps 0): nothing to summarize.
        print(f"Done: {step} steps. Saved: {final}")
119
+
120
# Script entry point — supports `python -m iris.train_production ...` per the
# module docstring's usage line.
if __name__ == "__main__":
    main()