asdf98 committed on
Commit
5a5cffa
·
verified ·
1 Parent(s): 774e194

Add train_iris.py

Browse files
Files changed (1) hide show
  1. train_iris.py +315 -0
train_iris.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ IRIS Training Script
3
+ =====================
4
+ End-to-end training pipeline for IRIS (Iterative Recurrent Image Synthesis).
5
+
6
+ Supports:
7
+ - Stage 1: Wavelet VAE pre-training (reconstruction)
8
+ - Stage 2: Class-conditional pretraining (ImageNet)
9
+ - Stage 3: Text-image alignment (CLIP-conditioned)
10
+ - Stage 4: Aesthetic fine-tuning
11
+
12
+ Usage:
13
+ python train_iris.py --stage 1 --epochs 50
14
+ python train_iris.py --stage 3 --epochs 100
15
+
16
+ Designed to run on Colab/Kaggle (single GPU, T4/A100).
17
+ """
18
+
19
+ import os
20
+ import math
21
+ import argparse
22
+ import time
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from torch.utils.data import DataLoader, Dataset
27
+ from torch.cuda.amp import autocast, GradScaler
28
+ from pathlib import Path
29
+
30
+ from iris_model import (
31
+ IRIS, IRISConfig, WaveletVAE,
32
+ create_iris_small, create_iris_tiny, create_iris_base,
33
+ count_parameters, estimate_memory_mb,
34
+ )
35
+
36
+
37
+ # ============================================================================
38
+ # Synthetic Dataset (for testing; replace with real dataset loaders)
39
+ # ============================================================================
40
+
41
class SyntheticImageTextDataset(Dataset):
    """Random-tensor stand-in dataset for smoke-testing the training pipeline.

    Every item is an ``(image, text)`` pair of freshly sampled standard-normal
    tensors: ``image`` has shape ``(3, image_size, image_size)`` and ``text``
    has shape ``(text_len, text_dim)``.
    """

    def __init__(self, num_samples=1000, image_size=256, text_dim=768, text_len=77):
        self.num_samples, self.image_size = num_samples, image_size
        self.text_dim, self.text_len = text_dim, text_len

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # ``idx`` is ignored: a new random sample is drawn on every access.
        side = self.image_size
        image = torch.randn((3, side, side))
        text = torch.randn((self.text_len, self.text_dim))
        return image, text
56
+
57
+
58
+ # ============================================================================
59
+ # VAE Training (Stage 1)
60
+ # ============================================================================
61
+
62
def train_vae(config: IRISConfig, args):
    """Train the Wavelet VAE for image reconstruction (stage 1).

    The loss is ``MSE + 0.001 * KL + 0.1 * L1(DWT(recon), DWT(target))``.
    Training uses the synthetic dataset; the weights are written to
    ``<args.output_dir>/vae_checkpoint.pt`` when all epochs have finished.

    Args:
        config: Model configuration. Reads ``vae_channels``,
            ``latent_spatial`` and ``latent_channels``.
        args: Parsed CLI namespace. Reads ``fp16``, ``num_samples``,
            ``batch_size``, ``epochs``, ``output_dir`` and, if present,
            ``lr`` (falls back to ``1e-4``).

    Returns:
        The trained ``WaveletVAE`` module (left in train mode).
    """
    # Function-scoped import as in the original, but resolved once instead of
    # once per batch.
    from iris_model import HaarDWT2D

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = device.type == 'cuda'
    print(f"Training VAE on {device}")

    vae = WaveletVAE(config).to(device)
    print(f"VAE params: {sum(p.numel() for p in vae.parameters()):,}")

    # Honour --lr (same 1e-4 default) instead of silently ignoring it.
    lr = getattr(args, 'lr', 1e-4)
    optimizer = torch.optim.AdamW(vae.parameters(), lr=lr, weight_decay=0.05)

    # torch.cuda.amp is CUDA-only; on CPU it would just warn and disable
    # itself, so gate mixed precision explicitly.
    use_amp = bool(args.fp16) and on_cuda
    scaler = GradScaler() if use_amp else None

    # Input size depends on VAE architecture: DWT(2×) + down_blocks
    num_downsamples = len(config.vae_channels) - 1
    total_downsample = 2 * (2 ** num_downsamples)  # DWT + conv downsamples
    input_size = config.latent_spatial * total_downsample

    dataset = SyntheticImageTextDataset(
        num_samples=args.num_samples,
        image_size=input_size,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=2, pin_memory=on_cuda)

    print(f"Input image size: {input_size}×{input_size}")
    print(f"Latent size: {config.latent_spatial}×{config.latent_spatial}×{config.latent_channels}")

    # Loop-invariant: the original rebuilt this transform on every batch.
    dwt = HaarDWT2D()

    vae.train()
    for epoch in range(args.epochs):
        total_loss = 0
        t0 = time.time()

        for batch_idx, (images, _) in enumerate(loader):
            images = images.to(device)

            with autocast(enabled=use_amp, dtype=torch.float16):
                x_recon, mean, logvar = vae(images)

                # Pixel-space reconstruction loss (plain MSE)
                recon_loss = F.mse_loss(x_recon, images)

                # KL divergence to a unit Gaussian, averaged over all elements
                kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()

                # Wavelet frequency loss (enforce high-freq detail preservation)
                recon_wavelet = dwt(x_recon)
                target_wavelet = dwt(images)
                freq_loss = F.l1_loss(recon_wavelet, target_wavelet)

                loss = recon_loss + 0.001 * kl_loss + 0.1 * freq_loss

            optimizer.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true grads.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)
                optimizer.step()

            total_loss += loss.item()

            if batch_idx % 10 == 0:
                print(f"  Step {batch_idx}: loss={loss.item():.4f} "
                      f"(recon={recon_loss.item():.4f}, kl={kl_loss.item():.4f}, "
                      f"freq={freq_loss.item():.4f})")

        # max(..., 1) keeps an empty loader (num_samples == 0) from raising.
        avg_loss = total_loss / max(len(loader), 1)
        dt = time.time() - t0
        print(f"Epoch {epoch+1}/{args.epochs}: avg_loss={avg_loss:.4f}, time={dt:.1f}s")

    # Save
    save_path = Path(args.output_dir) / "vae_checkpoint.pt"
    save_path.parent.mkdir(parents=True, exist_ok=True)
    torch.save(vae.state_dict(), save_path)
    print(f"VAE saved to {save_path}")
    return vae
143
+
144
+
145
+ # ============================================================================
146
+ # Generator Training (Stages 2-4)
147
+ # ============================================================================
148
+
149
def train_generator(config: IRISConfig, args, vae_path=None):
    """Train the IRIS generator with rectified flow (stages 2-4).

    Builds an ``IRIS`` model, optionally loads VAE weights, freezes the VAE,
    and optimises only ``model.generator`` with AdamW under a linear-warmup +
    cosine-decay learning-rate schedule. Training uses the synthetic dataset.
    Periodic checkpoints go to ``<args.output_dir>/iris_epoch{N}.pt`` and the
    final weights to ``<args.output_dir>/iris_final.pt``.

    Args:
        config: Model configuration. Reads ``vae_channels``,
            ``latent_spatial`` and ``text_dim``.
        args: Parsed CLI namespace. Reads ``lr``, ``epochs``, ``num_samples``,
            ``batch_size``, ``fp16``, ``log_every``, ``save_every`` and
            ``output_dir``.
        vae_path: Optional path to a VAE ``state_dict`` checkpoint. If it is
            given but the file does not exist, a warning is printed and the
            VAE keeps its freshly initialised weights.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    on_cuda = device.type == 'cuda'
    print(f"Training Generator on {device}")

    model = IRIS(config).to(device)

    # Load pretrained VAE if available
    if vae_path and os.path.exists(vae_path):
        model.vae.load_state_dict(torch.load(vae_path, map_location=device))
        print(f"Loaded VAE from {vae_path}")
    elif vae_path:
        # Previously skipped silently, which hides a mistyped path.
        print(f"WARNING: VAE checkpoint not found at {vae_path}; "
              f"using randomly initialised VAE")

    # Freeze VAE during generator training
    for p in model.vae.parameters():
        p.requires_grad = False

    counts = count_parameters(model.generator)
    print(f"Generator params: {counts['total']:,}")
    print(f"Generator memory: {estimate_memory_mb(model.generator):.1f} MB (fp32)")

    # Optimizer (AdamW with cosine schedule)
    optimizer = torch.optim.AdamW(
        model.generator.parameters(),
        lr=args.lr,
        weight_decay=0.03,
        betas=(0.9, 0.95),
    )

    # Dataset (built before the schedule so the real batch count is known)
    num_downsamples = len(config.vae_channels) - 1
    total_downsample = 2 * (2 ** num_downsamples)
    input_size = config.latent_spatial * total_downsample

    dataset = SyntheticImageTextDataset(
        num_samples=args.num_samples,
        image_size=input_size,
        text_dim=config.text_dim,
    )
    loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
                        num_workers=2, pin_memory=on_cuda)

    # Cosine LR schedule with warmup.
    # len(loader) counts the final partial batch; num_samples // batch_size
    # does not, which let the schedule run past its end.
    total_steps = args.epochs * len(loader)
    warmup_steps = min(5000, total_steps // 10)
    decay_steps = max(1, total_steps - warmup_steps)  # never divide by zero

    def lr_lambda(step):
        if step < warmup_steps:
            return step / warmup_steps
        # Clamp so any step beyond total_steps stays at the cosine minimum
        # instead of climbing back up.
        progress = min(1.0, (step - warmup_steps) / decay_steps)
        return 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    # torch.cuda.amp is CUDA-only; gate mixed precision explicitly.
    use_amp = bool(args.fp16) and on_cuda
    scaler = GradScaler() if use_amp else None

    # Create the output directory up front: the final save used to fail when
    # no periodic checkpoint had created it (epochs < save_every).
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"Input size: {input_size}×{input_size}")
    print(f"Training for {args.epochs} epochs ({total_steps} steps)")
    print(f"Warmup: {warmup_steps} steps")

    # Training loop
    global_step = 0
    model.train()
    model.vae.eval()  # frozen VAE stays in eval mode

    for epoch in range(args.epochs):
        epoch_loss = 0
        t0 = time.time()

        for batch_idx, (images, text_tokens) in enumerate(loader):
            images = images.to(device)
            text_tokens = text_tokens.to(device)

            with autocast(enabled=use_amp, dtype=torch.float16):
                # train_step returns a mapping with at least 'loss',
                # 'velocity_loss' and 'kl_loss' (all read below).
                result = model.train_step(images, text_tokens)
                loss = result['loss']

            optimizer.zero_grad()
            if scaler:
                scaler.scale(loss).backward()
                # Unscale before clipping so the threshold applies to true grads.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.generator.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.generator.parameters(), 1.0)
                optimizer.step()

            scheduler.step()
            global_step += 1
            epoch_loss += loss.item()

            if global_step % args.log_every == 0:
                lr = optimizer.param_groups[0]['lr']
                print(f"  Step {global_step}: loss={loss.item():.4f} "
                      f"(vel={result['velocity_loss']:.4f}, kl={result['kl_loss']:.4f}) "
                      f"lr={lr:.2e}")

        # max(..., 1) keeps an empty loader (num_samples == 0) from raising.
        avg_loss = epoch_loss / max(len(loader), 1)
        dt = time.time() - t0
        print(f"Epoch {epoch+1}/{args.epochs}: avg_loss={avg_loss:.4f}, time={dt:.1f}s")

        # Save checkpoint
        if (epoch + 1) % args.save_every == 0:
            save_path = output_dir / f"iris_epoch{epoch+1}.pt"
            torch.save({
                'epoch': epoch + 1,
                'global_step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'config': config,
            }, save_path)
            print(f"Checkpoint saved to {save_path}")

    # Final save
    save_path = output_dir / "iris_final.pt"
    torch.save({
        'model_state_dict': model.state_dict(),
        'config': config,
    }, save_path)
    print(f"Final model saved to {save_path}")
270
+
271
+
272
+ # ============================================================================
273
+ # Main
274
+ # ============================================================================
275
+
276
def main():
    """Parse CLI arguments and dispatch to the selected training stage.

    Stage 1 runs ``train_vae``; stages 2-4 all run ``train_generator``
    (optionally seeded with ``--vae-path``). Mixed precision is on by default
    and can be turned off with ``--no-fp16``.
    """
    parser = argparse.ArgumentParser(description="IRIS Training Pipeline")
    parser.add_argument('--stage', type=int, default=1, choices=[1, 2, 3, 4],
                        help='Training stage: 1=VAE, 2=class-cond, 3=text-image, 4=aesthetic')
    parser.add_argument('--model-size', type=str, default='tiny', choices=['tiny', 'small', 'base'],
                        help='Model size variant')
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=1e-4)
    # --fp16 defaults to True, so a bare store_true flag could never be
    # switched off; --no-fp16 writes False to the same destination.
    parser.add_argument('--fp16', dest='fp16', action='store_true', default=True,
                        help='Use fp16 mixed precision (default: enabled)')
    parser.add_argument('--no-fp16', dest='fp16', action='store_false',
                        help='Disable fp16 mixed precision')
    parser.add_argument('--num-samples', type=int, default=1000,
                        help='Number of training samples (for synthetic data)')
    parser.add_argument('--output-dir', type=str, default='./checkpoints')
    parser.add_argument('--vae-path', type=str, default=None,
                        help='Path to pretrained VAE checkpoint')
    parser.add_argument('--log-every', type=int, default=10)
    parser.add_argument('--save-every', type=int, default=5)
    args = parser.parse_args()

    # Create config based on model size. The factories return a full model;
    # only its config is kept, so the throwaway model is not held in memory
    # for the whole training run.
    factories = {
        'tiny': create_iris_tiny,
        'small': create_iris_small,
        'base': create_iris_base,
    }
    config = factories[args.model_size]().config

    print(f"{'='*60}")
    print(f"IRIS Training — Stage {args.stage} — {args.model_size}")
    print(f"{'='*60}")

    if args.stage == 1:
        train_vae(config, args)
    else:
        train_generator(config, args, vae_path=args.vae_path)


if __name__ == "__main__":
    main()