File size: 6,208 Bytes
5a3f8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4aebda2
 
 
 
 
 
 
5a3f8af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
IRIS Production Training Script.
Usage: python -m iris.train_production --config iris-small --num_steps 5000
"""

import argparse, json, os, time, math
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from .model import IRIS
from .configs import get_model_config, CONFIGS
from .flow_matching import flow_matching_loss
from .train import SyntheticLatentDataset, CosineWarmupScheduler


def parse_args():
    """Parse the command line for a production IRIS training run.

    Returns:
        argparse.Namespace holding the model preset (``config``), optimiser
        hyper-parameters, data-loader settings, logging/checkpoint cadence,
        the output directory, the RNG seed and an optional checkpoint path
        to resume from.
    """
    parser = argparse.ArgumentParser(description="Train IRIS")
    parser.add_argument("--config", type=str, default="iris-small", choices=list(CONFIGS.keys()))

    # (flag, type, default). Registration order is kept as-is: it fixes the
    # key order of vars(args), which main() dumps to config.json.
    option_specs = (
        ("--num_steps", int, 5000),
        ("--batch_size", int, 32),
        ("--lr", float, 2e-4),
        ("--weight_decay", float, 0.01),
        ("--warmup_steps", int, 500),
        ("--grad_clip", float, 1.0),
        ("--num_iterations", int, 4),
        ("--num_workers", int, 4),
        ("--log_every", int, 50),
        ("--save_every", int, 1000),
        ("--output_dir", str, "./iris_output"),
        ("--num_samples", int, 50000),
        ("--seed", int, 42),
        ("--resume", str, None),
    )
    for flag, converter, fallback in option_specs:
        parser.add_argument(flag, type=converter, default=fallback)

    return parser.parse_args()


def _select_amp(device):
    """Choose the autocast configuration for *device*.

    Returns:
        ``(use_amp, amp_dtype)``. Autocast is enabled only on CUDA; elsewhere
        it is disabled and the dtype is ``torch.float32``.
    """
    if device.type != "cuda":
        return False, torch.float32
    # T4 (compute cap 7.5) reports bf16 supported but cuDNN conv kernels crash.
    # Force fp16 on GPUs below Ampere (compute cap < 8.0).
    major, _minor = torch.cuda.get_device_capability(device)
    return True, (torch.float16 if major < 8 else torch.bfloat16)


def _make_checkpoint(step, model, optimizer, scaler, loss_history, model_cfg, config_name):
    """Assemble the state written to both periodic and final checkpoints.

    Both kinds share one schema, so either can be passed to ``--resume``.
    ``scaler_state_dict`` is an empty dict when the GradScaler is disabled
    (bf16 or CPU runs).
    """
    return {
        "step": step,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "scaler_state_dict": scaler.state_dict(),
        "loss_history": loss_history,
        "config": model_cfg,
        "config_name": config_name,
    }


def main():
    """Train an IRIS model with flow matching on synthetic latents.

    Options come from ``parse_args()``. The run can resume from ``--resume``.
    It writes ``config.json``, periodic ``iris_<config>_step<N>.pt`` files and
    a final ``iris_<config>_final.pt`` into ``--output_dir``, and prints the
    running loss, gradient norm and learning rate every ``--log_every`` steps.

    Raises:
        ValueError: if ``--log_every`` or ``--save_every`` is not positive, or
            if the data loader would yield no batches.
    """
    args = parse_args()
    if args.log_every <= 0 or args.save_every <= 0:
        raise ValueError("--log_every and --save_every must be positive integers")

    torch.manual_seed(args.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp, amp_dtype = _select_amp(device)

    print(f"IRIS Training - {args.config} | Device: {device}, AMP: {amp_dtype}")
    model_cfg = get_model_config(args.config)
    model = IRIS(gradient_checkpointing=True, **model_cfg).to(device)
    counts = model.count_params()
    print(f"Model: {counts['total']:,} params ({counts['total']/1e6:.1f}M)")

    dataset = SyntheticLatentDataset(
        num_samples=args.num_samples,
        latent_channels=model_cfg["latent_channels"],
        latent_size=16,
        text_dim=model_cfg["dim"],
        text_length=32,
        seed=args.seed,
    )
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        drop_last=True,
        persistent_workers=(args.num_workers > 0),
    )
    # With drop_last=True, a dataset smaller than one batch yields no batches,
    # and the `while step < num_steps` loop below would never terminate.
    if len(loader) == 0:
        raise ValueError(
            f"DataLoader yields no batches (num_samples={args.num_samples}, "
            f"batch_size={args.batch_size}, drop_last=True)"
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.999))
    scheduler = CosineWarmupScheduler(optimizer, args.warmup_steps, args.num_steps)
    # Loss scaling is only needed for fp16; bf16 has fp32's exponent range.
    scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))

    start_step = 0
    loss_history = []
    if args.resume:
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        # Only resume from checkpoints you produced yourself.
        ckpt = torch.load(args.resume, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        # Older checkpoints carry no scaler state, and an enabled GradScaler
        # rejects an empty state dict, so restore only when something was saved.
        scaler_state = ckpt.get("scaler_state_dict")
        if scaler_state and scaler.is_enabled():
            scaler.load_state_dict(scaler_state)
        start_step = ckpt["step"]
        loss_history = ckpt.get("loss_history", [])
        # Fast-forward the LR schedule to where the checkpoint left off.
        for _ in range(start_step):
            scheduler.step()

    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "config.json"), "w") as f:
        json.dump({"model_config": args.config, "model_params": model_cfg, "training": vars(args)}, f, indent=2)

    model.train()
    step = start_step
    running_loss, window_steps = 0.0, 0
    start_time = time.time()

    while step < args.num_steps:
        for batch in loader:
            if step >= args.num_steps:
                break
            latent = batch["latent"].to(device, non_blocking=True)
            text_embed = batch["text_embed"].to(device, non_blocking=True)

            with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
                losses = flow_matching_loss(model, latent, text_embed, num_iterations=args.num_iterations, timestep_sampling="logit_normal")
                loss = losses["loss"]

            optimizer.zero_grad(set_to_none=True)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true gradients.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
                optimizer.step()

            scheduler.step()
            step += 1
            loss_value = loss.item()  # one device->host sync per step instead of two
            running_loss += loss_value
            window_steps += 1
            loss_history.append(loss_value)

            if step % args.log_every == 0:
                # Divide by the steps actually accumulated: after a resume the
                # first window can be shorter than log_every.
                avg = running_loss / window_steps
                gn = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
                # The LR the optimizer holds is what the next step will use.
                current_lr = optimizer.param_groups[0]["lr"]
                print(f"step={step:>6d} | loss={avg:.4f} | grad={gn:.3f} | lr={current_lr:.2e} | {time.time()-start_time:.0f}s")
                running_loss, window_steps = 0.0, 0

            if step % args.save_every == 0:
                p = os.path.join(args.output_dir, f"iris_{args.config}_step{step}.pt")
                torch.save(_make_checkpoint(step, model, optimizer, scaler, loss_history, model_cfg, args.config), p)
                print(f"  Saved: {p}")

    final = os.path.join(args.output_dir, f"iris_{args.config}_final.pt")
    torch.save(_make_checkpoint(step, model, optimizer, scaler, loss_history, model_cfg, args.config), final)

    # No steps ran (e.g. --num_steps 0, or resuming at/after num_steps):
    # there is nothing to summarise, and averaging would divide by zero.
    if not loss_history:
        print(f"Done: {step} steps, no losses recorded. Saved: {final}")
        return

    window = min(100, len(loss_history))
    first_avg = sum(loss_history[:window]) / window
    last_avg = sum(loss_history[-window:]) / window
    if first_avg != 0:
        reduction = f"{(1 - last_avg / first_avg) * 100:.1f}% reduction"
    else:
        reduction = "reduction n/a"
    print(f"Done: {step} steps, loss {first_avg:.4f} -> {last_avg:.4f} ({reduction}). Saved: {final}")


if __name__ == "__main__":
    main()