krystv committed on
Commit
658087c
·
verified ·
1 Parent(s): f8a7028

Upload train.py

Browse files
Files changed (1) hide show
  1. train.py +335 -0
train.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiquidFlow Trainer — Complete training pipeline.
3
+
4
+ Usage:
5
+ python train.py --dataset cifar10 --image_size 128 --variant small --batch_size 32 --epochs 100
6
+
7
+ Features:
8
+ - Automatic VAE loading (TAESD by default)
9
+ - Physics-informed regularization
10
+ - Mixed precision training (AMP)
11
+ - Checkpoint saving
12
+ - Sample generation during training
13
+ - Colab/Kaggle compatible (T4 GPU, 15GB VRAM)
14
+
15
+ Requirements:
16
+ pip install torch torchvision diffusers tqdm pillow numpy
17
+ """
18
+
19
+ import os
20
+ import sys
21
+ import math
22
+ import argparse
23
+ import json
24
+ from datetime import datetime
25
+ from pathlib import Path
26
+
27
+ import torch
28
+ import torch.nn as nn
29
+ import torch.nn.functional as F
30
+ from torch.utils.data import DataLoader
31
+ from torchvision import datasets, transforms
32
+ from torchvision.utils import save_image
33
+ import numpy as np
34
+ from tqdm import tqdm
35
+
36
+ # Add parent to path
37
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
38
+
39
+ from liquid_flow.generator import LiquidFlowGenerator, create_liquidflow
40
+ from liquid_flow.vae_wrapper import TAESDWrapper
41
+
42
+
43
def get_dataloader(dataset_name, image_size, batch_size, data_dir='./data'):
    """Build the shuffled training ``DataLoader`` for a supported dataset.

    Args:
        dataset_name: One of ``'cifar10'``, ``'cifar100'``, ``'stl10'``,
            ``'celeba'``, ``'lsun'`` or ``'imagenet'``.
        image_size: Side length (pixels) of the square images produced.
        batch_size: Number of images per batch.
        data_dir: Root directory the datasets are read from / downloaded to.

    Returns:
        A ``DataLoader`` over the training split that shuffles, pins memory
        and drops the last incomplete batch.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    square = (image_size, image_size)

    def standard_transform():
        # Resize -> tensor in [0, 1] -> rescale to [-1, 1].
        return transforms.Compose([
            transforms.Resize(square),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def imagenet_transform():
        # Same pipeline plus crop/flip augmentation. The crop size equals the
        # resized size, so the crop itself never removes pixels.
        return transforms.Compose([
            transforms.Resize(square),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    # Each entry is a zero-argument factory so that only the requested
    # dataset is instantiated (and, where applicable, downloaded).
    factories = {
        'cifar10': lambda: datasets.CIFAR10(
            root=data_dir, train=True, download=True,
            transform=standard_transform(),
        ),
        'cifar100': lambda: datasets.CIFAR100(
            root=data_dir, train=True, download=True,
            transform=standard_transform(),
        ),
        'stl10': lambda: datasets.STL10(
            root=data_dir, split='train', download=True,
            transform=standard_transform(),
        ),
        'celeba': lambda: datasets.CelebA(
            root=data_dir, split='train', download=True,
            transform=standard_transform(),
        ),
        'lsun': lambda: datasets.LSUN(
            root=data_dir, classes='bedroom_train',
            transform=standard_transform(),
        ),
        'imagenet': lambda: datasets.ImageFolder(
            root=f'{data_dir}/imagenet/train',
            transform=imagenet_transform(),
        ),
    }

    make_dataset = factories.get(dataset_name)
    if make_dataset is None:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    worker_count = min(4, os.cpu_count() or 1)
    return DataLoader(
        make_dataset(),
        batch_size=batch_size,
        shuffle=True,
        num_workers=worker_count,
        pin_memory=True,
        drop_last=True,
    )
95
+
96
+
97
# Number of diffusion timesteps the fixed-noise sampler strides over.
# NOTE(review): assumes ``model.alphas_cumprod`` has 1000 entries — confirm
# against the LiquidFlow generator's noise schedule.
_NUM_TRAIN_TIMESTEPS = 1000

# Images per saved sample grid (written with nrow=4, i.e. a 4 x 4 grid).
_NUM_TRACKING_SAMPLES = 16


def _build_scheduler(optimizer, schedule, total_steps):
    """Return the per-step LR scheduler for ``schedule``, or ``None`` if disabled."""
    if schedule == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=total_steps
        )
    if schedule == 'cosine_restart':
        # T_0 must be a positive integer; very short runs would floor to 0.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=max(1, total_steps // 3),
        )
    return None


def _is_due(epoch, every, total_epochs):
    """Return True on every ``every``-th epoch and always on the final epoch."""
    if epoch == total_epochs - 1:
        return True
    # ``every <= 0`` means "only at the end" instead of a modulo-by-zero crash.
    return every > 0 and (epoch + 1) % every == 0


def _ddim_from_noise(model, noise, sample_steps):
    """Run deterministic DDIM (eta = 0) starting from ``noise``.

    The same starting noise therefore always maps to the same latents for a
    given set of weights, which is what makes the "fixed" sample grid usable
    for tracking progress across epochs. Must be called under ``no_grad``.
    """
    # Clamp so that sample_steps > 1000 (or <= 0) cannot yield a zero stride.
    skip = max(1, _NUM_TRAIN_TIMESTEPS // max(1, sample_steps))
    x = noise.clone()
    final_alpha_bar = torch.tensor(1.0, device=x.device)

    for i in reversed(range(0, _NUM_TRAIN_TIMESTEPS, skip)):
        t = torch.full((x.shape[0],), i, device=x.device, dtype=torch.long)
        noise_pred = model(x, t)
        alpha_bar = model.alphas_cumprod[i]
        alpha_bar_prev = (
            model.alphas_cumprod[i - skip] if i >= skip else final_alpha_bar
        )
        x0_pred = (x - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)
        x0_pred = torch.clamp(x0_pred, -1, 1)
        # DDIM update: re-use the *predicted* noise as the direction towards
        # x_t. Drawing fresh Gaussian noise here would make the trajectory
        # stochastic and defeat the purpose of a fixed starting noise.
        x = (
            torch.sqrt(alpha_bar_prev) * x0_pred
            + torch.sqrt(1 - alpha_bar_prev) * noise_pred
        )
    return x


def train(args):
    """Train a LiquidFlow model in the latent space of a frozen TAESD VAE.

    Creates ``samples/`` and ``checkpoints/`` under ``args.output_dir``,
    trains for ``args.epochs`` epochs, periodically writes sample grids and
    checkpoints, and keeps ``best_model.pt`` for the lowest average epoch loss.

    Args:
        args: Namespace providing ``output_dir``, ``image_size``, ``variant``,
            ``batch_size``, ``dataset``, ``data_dir``, ``lr``,
            ``weight_decay``, ``lr_schedule``, ``epochs``, ``amp``,
            ``sample_every``, ``sample_steps`` and ``save_every``.

    Returns:
        The trained model.

    Raises:
        ValueError: If the dataloader yields no batches (dataset smaller than
            ``batch_size`` while incomplete batches are dropped).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Create output directory
    samples_dir = os.path.join(args.output_dir, 'samples')
    checkpoints_dir = os.path.join(args.output_dir, 'checkpoints')
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    latent_size = args.image_size // 8

    # Load VAE
    print("Loading VAE...")
    vae = TAESDWrapper.load(device)
    print(f"VAE loaded. Latent size: {latent_size}x{latent_size}")

    # Create model
    print(f"Creating LiquidFlow model (variant={args.variant})...")
    model = create_liquidflow(
        variant=args.variant,
        image_size=args.image_size,
    )
    model = model.to(device)

    n_params = model.count_parameters()
    print(f"Model parameters: {n_params:,} (~{n_params/1e6:.1f}M)")

    # Rough size of one latent: H * W * 4 channels * 4 bytes, in MB.
    mem_per_sample = latent_size * latent_size * 4 * 4 / (1024**2)
    print(f"Estimated memory per sample: {mem_per_sample:.1f} MB")
    print(f"Estimated batch memory: {mem_per_sample * args.batch_size:.1f} MB")

    # Dataset
    print(f"Loading dataset: {args.dataset}")
    dataloader = get_dataloader(
        args.dataset, args.image_size, args.batch_size, args.data_dir
    )
    n_batches = len(dataloader)
    print(f"Dataset size: {len(dataloader.dataset)} images, {n_batches} batches")
    if n_batches == 0:
        # Fail early with a clear message instead of a ZeroDivisionError when
        # the per-epoch averages are computed.
        raise ValueError(
            f"Dataset '{args.dataset}' yields no full batches at "
            f"batch_size={args.batch_size}; use a smaller batch size."
        )

    # Optimizer (AdamW, following DiT/DiMSUM convention)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
    )

    # Learning rate scheduler (stepped once per batch, not per epoch)
    scheduler = _build_scheduler(
        optimizer, args.lr_schedule, args.epochs * n_batches
    )

    # AMP is only meaningful on CUDA.
    use_amp = args.amp and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    # Fixed noise for sample generation (track progress)
    sample_noise = torch.randn(
        _NUM_TRACKING_SAMPLES, 4, latent_size, latent_size, device=device
    )

    # Training state
    global_step = 0
    best_loss = float('inf')

    print(f"\n{'='*60}")
    print(f"Starting training: {args.epochs} epochs, {args.batch_size} batch size")
    print(f"LR: {args.lr}, Weight Decay: {args.weight_decay}")
    print(f"AMP: {use_amp}, LR Schedule: {args.lr_schedule}")
    print(f"{'='*60}\n")

    for epoch in range(args.epochs):
        model.train()
        epoch_losses = {'total': 0, 'diffusion': 0, 'physics': 0}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")

        for images, _ in pbar:
            images = images.to(device)

            # Encode to latent space; the VAE is frozen, so no gradients.
            with torch.no_grad():
                latents = TAESDWrapper.encode(vae, images)

            # Training step
            # NOTE(review): presumably this runs forward/backward and the
            # optimizer step, returning a dict of scalar losses — confirm
            # in liquid_flow.generator.
            loss_dict = model.training_step(latents, optimizer, scaler, use_amp)

            # Update scheduler
            if scheduler is not None:
                scheduler.step()

            # Track losses
            for k in epoch_losses:
                epoch_losses[k] += loss_dict.get(k, 0)

            global_step += 1

            # Update progress bar
            pbar.set_postfix({
                'loss': f"{loss_dict.get('total', 0):.4f}",
                'diff': f"{loss_dict.get('diffusion', 0):.4f}",
                'phys': f"{loss_dict.get('physics', 0):.4f}",
                'lr': f"{optimizer.param_groups[0]['lr']:.2e}",
            })

        # Epoch summary
        avg_losses = {k: v / n_batches for k, v in epoch_losses.items()}

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Total Loss: {avg_losses['total']:.4f}")
        print(f"  Diffusion Loss: {avg_losses['diffusion']:.4f}")
        print(f"  Physics Loss: {avg_losses['physics']:.4f}")

        # Generate samples
        if _is_due(epoch, args.sample_every, args.epochs):
            print("Generating samples...")
            model.eval()

            with torch.no_grad():
                # Fresh-noise samples via the model's own DDIM sampler.
                latents_gen = model.sample(
                    batch_size=_NUM_TRACKING_SAMPLES,
                    steps=args.sample_steps,
                    ddim=True,
                    progress=False,
                )
                images_gen = TAESDWrapper.decode(vae, latents_gen)

                # Same starting noise every time, for epoch-to-epoch comparison.
                x_fixed = _ddim_from_noise(model, sample_noise, args.sample_steps)
                images_fixed = TAESDWrapper.decode(vae, x_fixed)

            # Save samples
            sample_path = os.path.join(samples_dir, f'epoch_{epoch+1:03d}.png')
            save_image(images_gen, sample_path, nrow=4, normalize=True, value_range=(-1, 1))

            fixed_path = os.path.join(samples_dir, f'fixed_{epoch+1:03d}.png')
            save_image(images_fixed, fixed_path, nrow=4, normalize=True, value_range=(-1, 1))

            print(f"  Samples saved to {sample_path}")

        # Save checkpoint
        if _is_due(epoch, args.save_every, args.epochs):
            checkpoint_path = os.path.join(checkpoints_dir, f'epoch_{epoch+1:03d}.pt')
            torch.save({
                'epoch': epoch + 1,
                'global_step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                # Scheduler/scaler state is needed to resume a run faithfully;
                # stored as extra keys so existing loaders keep working.
                'scheduler_state_dict': (
                    scheduler.state_dict() if scheduler is not None else None
                ),
                'scaler_state_dict': (
                    scaler.state_dict() if scaler is not None else None
                ),
                'loss': avg_losses['total'],
                'args': vars(args),
            }, checkpoint_path)
            print(f"  Checkpoint saved to {checkpoint_path}")

        # Save best model (lowest average training loss so far)
        if avg_losses['total'] < best_loss:
            best_loss = avg_losses['total']
            best_path = os.path.join(checkpoints_dir, 'best_model.pt')
            torch.save(model.state_dict(), best_path)
            print(f"  Best model saved (loss={best_loss:.4f})")

        print()

    print(f"\n{'='*60}")
    print("Training complete!")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Model saved to: {args.output_dir}/checkpoints/")
    print(f"{'='*60}")

    return model
283
+
284
+
285
def main():
    """Parse command-line arguments and start training.

    Builds the argument parser for dataset, model, optimisation, sampling
    and output options, then passes the parsed namespace to ``train``.
    """
    parser = argparse.ArgumentParser(description='LiquidFlow Generator Training')

    # Dataset
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'stl10', 'celeba', 'lsun', 'imagenet'],
                        help='Training dataset')
    parser.add_argument('--data_dir', type=str, default='./data',
                        help='Data directory')
    parser.add_argument('--image_size', type=int, default=128,
                        choices=[64, 128, 256, 512],
                        help='Image size (will be VAE-encoded)')

    # Model
    parser.add_argument('--variant', type=str, default='small',
                        choices=['tiny', 'small', 'base'],
                        help='Model size variant')

    # Training
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--lr_schedule', type=str, default='cosine',
                        choices=['cosine', 'cosine_restart', 'none'],
                        help='LR schedule')
    # AMP is on by default. ``--amp`` is kept so existing command lines still
    # parse; ``--no_amp`` writes to the same destination and is the only way
    # to actually switch it off.
    parser.add_argument('--amp', dest='amp', action='store_true', default=True,
                        help='Use automatic mixed precision (enabled by default)')
    parser.add_argument('--no_amp', dest='amp', action='store_false',
                        help='Disable automatic mixed precision')

    # Generation
    parser.add_argument('--sample_every', type=int, default=5,
                        help='Generate samples every N epochs')
    parser.add_argument('--sample_steps', type=int, default=50,
                        help='DDIM sampling steps')

    # IO
    parser.add_argument('--output_dir', type=str, default='./outputs',
                        help='Output directory')
    parser.add_argument('--save_every', type=int, default=10,
                        help='Save checkpoint every N epochs')

    args = parser.parse_args()
    train(args)


# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()