| """ |
| IRIS Training Script |
| ===================== |
| End-to-end training pipeline for IRIS (Iterative Recurrent Image Synthesis). |
| |
| Supports: |
| - Stage 1: Wavelet VAE pre-training (reconstruction) |
| - Stage 2: Class-conditional pretraining (ImageNet) |
| - Stage 3: Text-image alignment (CLIP-conditioned) |
| - Stage 4: Aesthetic fine-tuning |
| |
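Usage:
python train_iris.py --stage 1 --model-size tiny --epochs 50
python train_iris.py --stage 3 --model-size small --epochs 100 --vae-path ./checkpoints/vae_checkpoint.pt

Note: only a synthetic dataset is wired in so far (there is no --dataset
flag), and stages 2-4 currently share the same generator training loop.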
| |
| Designed to run on Colab/Kaggle (single GPU, T4/A100). |
| """ |
|
|
| import os |
| import math |
| import argparse |
| import time |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torch.cuda.amp import autocast, GradScaler |
| from pathlib import Path |
|
|
| from iris_model import ( |
| IRIS, IRISConfig, WaveletVAE, |
| create_iris_small, create_iris_tiny, create_iris_base, |
| count_parameters, estimate_memory_mb, |
| ) |
|
|
|
|
| |
| |
| |
|
|
class SyntheticImageTextDataset(Dataset):
    """Random (image, text-embedding) pairs used to exercise the training loop.

    Every access draws fresh Gaussian noise, so the same index yields
    different tensors on each call; ``idx`` is not used.
    """

    def __init__(self, num_samples=1000, image_size=256, text_dim=768, text_len=77):
        """
        Args:
            num_samples: Value reported by ``len()``.
            image_size: Height and width of every generated image.
            text_dim: Feature size of each text token vector.
            text_len: Number of text token vectors per sample.
        """
        self.num_samples, self.image_size = num_samples, image_size
        self.text_dim, self.text_len = text_dim, text_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(image, text)`` of shapes ``(3, S, S)`` and ``(text_len, text_dim)``."""
        side = self.image_size
        # Tuple elements evaluate left to right: the image is drawn before the text.
        return (
            torch.randn(3, side, side),
            torch.randn(self.text_len, self.text_dim),
        )
|
|
|
|
| |
| |
| |
|
|
def train_vae(config: IRISConfig, args):
    """Train the Wavelet VAE for image reconstruction.

    The loss is pixel-space MSE + 0.001 * KL(posterior || N(0, I)) + 0.1 * L1
    between the Haar wavelet transforms of the reconstruction and the target.

    Args:
        config: Model configuration; ``vae_channels`` and ``latent_spatial``
            determine the synthetic input image size.
        args: Parsed CLI namespace. Reads ``num_samples``, ``batch_size``,
            ``epochs``, ``fp16`` and ``output_dir``.

    Returns:
        The trained ``WaveletVAE``. Its ``state_dict`` is also written to
        ``<output_dir>/vae_checkpoint.pt``.
    """
    # Local import kept from the original; only this stage needs the DWT.
    from iris_model import HaarDWT2D

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training VAE on {device}")
    # torch.cuda.amp only does anything on CUDA; on CPU it just warns and
    # disables itself, so enable mixed precision only where it applies.
    use_amp = bool(args.fp16) and device.type == 'cuda'

    vae = WaveletVAE(config).to(device)
    print(f"VAE params: {sum(p.numel() for p in vae.parameters()):,}")

    optimizer = torch.optim.AdamW(vae.parameters(), lr=1e-4, weight_decay=0.05)
    scaler = GradScaler() if use_amp else None

    # Build the wavelet transform once (previously rebuilt every step) and
    # keep any buffers it may own on the same device as the images.
    dwt = HaarDWT2D()
    if isinstance(dwt, nn.Module):
        dwt = dwt.to(device)

    # Total spatial downsampling: 2 ** (len(vae_channels) - 1) from the encoder
    # stages times an extra factor of 2 (presumably the Haar DWT -- confirm
    # against WaveletVAE).
    num_downsamples = len(config.vae_channels) - 1
    total_downsample = 2 * (2 ** num_downsamples)
    input_size = config.latent_spatial * total_downsample

    dataset = SyntheticImageTextDataset(
        num_samples=args.num_samples,
        image_size=input_size,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=2, pin_memory=device.type == 'cuda')

    print(f"Input image size: {input_size}×{input_size}")
    print(f"Latent size: {config.latent_spatial}×{config.latent_spatial}×{config.latent_channels}")

    vae.train()
    for epoch in range(args.epochs):
        total_loss = 0.0
        t0 = time.time()

        for batch_idx, (images, _) in enumerate(loader):
            images = images.to(device)

            with autocast(enabled=use_amp, dtype=torch.float16):
                x_recon, mean, logvar = vae(images)

                recon_loss = F.mse_loss(x_recon, images)

                # Closed-form KL to a standard normal, averaged over all elements.
                kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()

                # Frequency-domain L1 term on the wavelet sub-bands.
                freq_loss = F.l1_loss(dwt(x_recon), dwt(images))

                loss = recon_loss + 0.001 * kl_loss + 0.1 * freq_loss

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
                optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"  Step {batch_idx}: loss={loss.item():.4f} "
                      f"(recon={recon_loss.item():.4f}, kl={kl_loss.item():.4f}, "
                      f"freq={freq_loss.item():.4f})")

        # max(..., 1) guards against an empty loader (e.g. --num-samples 0).
        avg_loss = total_loss / max(len(loader), 1)
        dt = time.time() - t0
        print(f"Epoch {epoch+1}/{args.epochs}: avg_loss={avg_loss:.4f}, time={dt:.1f}s")

    save_path = Path(args.output_dir) / "vae_checkpoint.pt"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(vae.state_dict(), save_path)
    print(f"VAE saved to {save_path}")
    return vae
|
|
|
|
| |
| |
| |
|
|
def train_generator(config: IRISConfig, args, vae_path=None):
    """Train the IRIS generator with rectified flow, keeping the VAE frozen.

    Args:
        config: Model configuration used to build ``IRIS`` and to size the
            synthetic inputs (``vae_channels``, ``latent_spatial``, ``text_dim``).
        args: Parsed CLI namespace. Reads ``lr``, ``epochs``, ``num_samples``,
            ``batch_size``, ``fp16``, ``log_every``, ``save_every`` and
            ``output_dir``.
        vae_path: Optional path to a VAE ``state_dict`` as written by
            ``train_vae``. If the file does not exist, a warning is printed
            and the VAE keeps the weights ``IRIS(config)`` built it with.

    Side effects:
        Writes ``iris_epoch{N}.pt`` every ``save_every`` epochs and
        ``iris_final.pt`` at the end, all under ``output_dir``.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Training Generator on {device}")
    # torch.cuda.amp only does anything on CUDA; on CPU it just warns and
    # disables itself, so enable mixed precision only where it applies.
    use_amp = bool(args.fp16) and device.type == 'cuda'

    model = IRIS(config).to(device)

    if vae_path:
        if os.path.exists(vae_path):
            model.vae.load_state_dict(torch.load(vae_path, map_location=device))
            print(f"Loaded VAE from {vae_path}")
        else:
            # Previously ignored silently, which hid typos in --vae-path.
            print(f"WARNING: VAE checkpoint not found at {vae_path}; "
                  f"continuing without pretrained VAE weights")

    # Freeze the VAE; only generator parameters are given to the optimizer.
    for p in model.vae.parameters():
        p.requires_grad = False

    counts = count_parameters(model.generator)
    print(f"Generator params: {counts['total']:,}")
    print(f"Generator memory: {estimate_memory_mb(model.generator):.1f} MB (fp32)")

    optimizer = torch.optim.AdamW(
        model.generator.parameters(),
        lr=args.lr,
        weight_decay=0.03,
        betas=(0.9, 0.95),
    )

    # Total spatial downsampling: 2 ** (len(vae_channels) - 1) from the encoder
    # stages times an extra factor of 2 (presumably the Haar DWT -- confirm
    # against WaveletVAE).
    num_downsamples = len(config.vae_channels) - 1
    total_downsample = 2 * (2 ** num_downsamples)
    input_size = config.latent_spatial * total_downsample

    dataset = SyntheticImageTextDataset(
        num_samples=args.num_samples,
        image_size=input_size,
        text_dim=config.text_dim,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=2, pin_memory=device.type == 'cuda')

    # The scheduler steps once per batch. len(loader) counts the final partial
    # batch (drop_last=False) that num_samples // batch_size misses; that
    # undercount pushed progress past 1 and made the cosine LR rise again.
    total_steps = args.epochs * len(loader)
    warmup_steps = min(5000, total_steps // 10)
    # At least 1 so a degenerate schedule cannot divide by zero.
    decay_steps = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        """LR multiplier: linear warmup from 0, then cosine decay to 0."""
        if step < warmup_steps:
            return step / warmup_steps
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    scaler = GradScaler() if use_amp else None

    print(f"Input size: {input_size}×{input_size}")
    print(f"Training for {args.epochs} epochs ({total_steps} steps)")
    print(f"Warmup: {warmup_steps} steps")

    # Create the output directory up front so the final save also works when
    # no periodic checkpoint was written (e.g. epochs < save_every).
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    global_step = 0
    model.train()
    # model.train() also switched the VAE to train mode; put the frozen VAE back.
    model.vae.eval()

    for epoch in range(args.epochs):
        epoch_loss = 0.0
        t0 = time.time()

        for batch_idx, (images, text_tokens) in enumerate(loader):
            images = images.to(device)
            text_tokens = text_tokens.to(device)

            with autocast(enabled=use_amp, dtype=torch.float16):
                result = model.train_step(images, text_tokens)
                loss = result['loss']

            optimizer.zero_grad()
            if scaler is not None:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.generator.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.generator.parameters(), 1.0)
                optimizer.step()

            scheduler.step()
            global_step += 1
            epoch_loss += loss.item()

            if global_step % args.log_every == 0:
                lr = optimizer.param_groups[0]['lr']
                print(f"  Step {global_step}: loss={loss.item():.4f} "
                      f"(vel={result['velocity_loss']:.4f}, kl={result['kl_loss']:.4f}) "
                      f"lr={lr:.2e}")

        # max(..., 1) guards against an empty loader (e.g. --num-samples 0).
        avg_loss = epoch_loss / max(len(loader), 1)
        dt = time.time() - t0
        print(f"Epoch {epoch+1}/{args.epochs}: avg_loss={avg_loss:.4f}, time={dt:.1f}s")

        if (epoch + 1) % args.save_every == 0:
            save_path = output_dir / f"iris_epoch{epoch+1}.pt"
            torch.save({
                'epoch': epoch + 1,
                'global_step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
            }, save_path)
            print(f"Checkpoint saved to {save_path}")

    save_path = output_dir / "iris_final.pt"
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
    }, save_path)
    print(f"Final model saved to {save_path}")
|
|
|
|
| |
| |
| |
|
|
def main():
    """Parse CLI arguments, build the requested model size, and run one stage.

    Stage 1 trains the Wavelet VAE; stages 2, 3 and 4 all call
    ``train_generator`` with the same arguments.
    """
    parser = argparse.ArgumentParser(description="IRIS Training Pipeline")
    parser.add_argument('--stage', type=int, default=1, choices=[1, 2, 3, 4],
                        help='Training stage: 1=VAE, 2=class-cond, 3=text-image, 4=aesthetic')
    parser.add_argument('--model-size', type=str, default='tiny', choices=['tiny', 'small', 'base'],
                        help='Model size variant')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-4)
    parser.add_argument('--fp16', action='store_true', default=True,
                        help='Use mixed precision (default: on)')
    # --fp16 defaults to True, so store_true alone can never turn it off;
    # --no-fp16 writes False to the same destination.
    parser.add_argument('--no-fp16', dest='fp16', action='store_false',
                        help='Disable mixed precision')
    parser.add_argument('--num-samples', type=int, default=1000,
                        help='Number of training samples (for synthetic data)')
    parser.add_argument('--output-dir', type=str, default='./checkpoints')
    parser.add_argument('--vae-path', type=str, default=None,
                        help='Path to pretrained VAE checkpoint')
    parser.add_argument('--log-every', type=int, default=10)
    parser.add_argument('--save-every', type=int, default=5)
    args = parser.parse_args()

    factories = {
        'tiny': create_iris_tiny,
        'small': create_iris_small,
        'base': create_iris_base,
    }
    model = factories[args.model_size]()
    config = model.config
    # The train_* functions build their own modules from ``config``; release
    # this instance so its weights are not held for the whole run.
    del model

    print(f"{'='*60}")
    print(f"IRIS Training — Stage {args.stage} — {args.model_size}")
    print(f"{'='*60}")

    if args.stage == 1:
        train_vae(config, args)
    else:
        train_generator(config, args, vae_path=args.vae_path)
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|