| """ |
| LiquidFlow Trainer — Complete training pipeline. |
| |
| Usage: |
| python train.py --dataset cifar10 --image_size 128 --variant small --batch_size 32 --epochs 100 |
| |
| Features: |
| - Automatic VAE loading (TAESD by default) |
| - Physics-informed regularization |
| - Mixed precision training (AMP) |
| - Checkpoint saving |
| - Sample generation during training |
| - Colab/Kaggle compatible (T4 GPU, 15GB VRAM) |
| |
| Requirements: |
| pip install torch torchvision diffusers tqdm pillow numpy |
| """ |
|
|
| import os |
| import sys |
| import math |
| import argparse |
| import json |
| from datetime import datetime |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
| from torchvision import datasets, transforms |
| from torchvision.utils import save_image |
| import numpy as np |
| from tqdm import tqdm |
|
|
| |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) |
|
|
| from liquid_flow.generator import LiquidFlowGenerator, create_liquidflow |
| from liquid_flow.vae_wrapper import TAESDWrapper |
|
|
|
|
def get_dataloader(dataset_name, image_size, batch_size, data_dir='./data'):
    """Build a shuffled training DataLoader for one of the supported datasets.

    Args:
        dataset_name: One of 'cifar10', 'cifar100', 'stl10', 'celeba',
            'lsun' or 'imagenet'.
        image_size: Target side length; images are resized to
            (image_size, image_size).
        batch_size: Number of samples per batch. The trailing incomplete
            batch is dropped (drop_last=True).
        data_dir: Root directory where datasets are stored / downloaded.

    Returns:
        A torch.utils.data.DataLoader over the training split.

    Raises:
        ValueError: If dataset_name is not one of the supported names.
    """
    supported = ('cifar10', 'cifar100', 'stl10', 'celeba', 'lsun', 'imagenet')
    # Fail fast, before any transform or dataset object is constructed.
    if dataset_name not in supported:
        raise ValueError(f"Unknown dataset: {dataset_name}")

    # Resize, convert to a tensor, then normalize with mean 0.5 / std 0.5
    # (maps ToTensor's [0, 1] output to [-1, 1]).
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])

    if dataset_name == 'cifar10':
        dataset = datasets.CIFAR10(
            root=data_dir, train=True, download=True, transform=transform
        )
    elif dataset_name == 'cifar100':
        dataset = datasets.CIFAR100(
            root=data_dir, train=True, download=True, transform=transform
        )
    elif dataset_name == 'stl10':
        dataset = datasets.STL10(
            root=data_dir, split='train', download=True, transform=transform
        )
    elif dataset_name == 'celeba':
        dataset = datasets.CelebA(
            root=data_dir, split='train', download=True, transform=transform
        )
    elif dataset_name == 'lsun':
        # torchvision's LSUN takes either a split name ('train'/'val'/'test')
        # or a *list* of category names. A bare 'bedroom_train' string is
        # iterated character by character and rejected, so pass a list.
        dataset = datasets.LSUN(
            root=data_dir, classes=['bedroom_train'], transform=transform
        )
    else:  # 'imagenet'
        # Same pipeline plus a random horizontal flip. RandomCrop at the
        # already-resized size is a no-op; it is kept so the pipeline is
        # unchanged.
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.RandomCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        dataset = datasets.ImageFolder(
            root=os.path.join(data_dir, 'imagenet', 'train'), transform=transform
        )

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # Cap worker processes at 4; os.cpu_count() may return None.
        num_workers=min(4, os.cpu_count() or 1),
        pin_memory=True,
        drop_last=True,
    )
|
|
|
|
def _build_lr_scheduler(optimizer, schedule, total_steps):
    """Return the per-step LR scheduler named by `schedule`, or None for any other value."""
    if schedule == 'cosine':
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_steps)
    if schedule == 'cosine_restart':
        # T_0 must be a positive int; a very short run would otherwise give 0.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer, T_0=max(1, total_steps // 3),
        )
    return None


@torch.no_grad()
def _sample_from_fixed_noise(model, noise, sample_steps, num_train_timesteps=1000):
    """Denoise `noise` with a deterministic DDIM-style loop (eta = 0).

    The model output is treated as a noise prediction and
    `model.alphas_cumprod` is indexed by integer timestep, as in the original
    inline loop. No fresh noise is drawn, so the same starting noise and
    weights always give the same result, which makes previews comparable
    across epochs.
    """
    # Guard against sample_steps > num_train_timesteps, which would make the
    # range step 0 and raise.
    skip = max(1, num_train_timesteps // sample_steps)
    x = noise.clone()
    batch = x.shape[0]
    one = torch.tensor(1.0, device=x.device)

    for i in reversed(range(0, num_train_timesteps, skip)):
        t = torch.full((batch,), i, device=x.device, dtype=torch.long)
        noise_pred = model(x, t)
        alpha_bar = model.alphas_cumprod[i]
        # Past the last step the target is the clean sample (alpha_bar = 1).
        alpha_bar_prev = model.alphas_cumprod[i - skip] if i >= skip else one
        x0_pred = (x - torch.sqrt(1 - alpha_bar) * noise_pred) / torch.sqrt(alpha_bar)
        x0_pred = torch.clamp(x0_pred, -1, 1)
        # Deterministic update: reuse the predicted noise as the direction
        # towards x_t instead of sampling new Gaussian noise.
        x = torch.sqrt(alpha_bar_prev) * x0_pred + torch.sqrt(1 - alpha_bar_prev) * noise_pred
    return x


def train(args):
    """Run the full training loop and return the trained model.

    Images are encoded to latents with the TAESD wrapper and the model is
    trained on those latents through `model.training_step`. Sample grids and
    checkpoints are written periodically under `args.output_dir`.

    Args:
        args: Namespace providing output_dir, image_size, variant, batch_size,
            dataset, data_dir, lr, weight_decay, lr_schedule, epochs, amp,
            sample_every, sample_steps and save_every.

    Returns:
        The trained model (on the selected device).

    Raises:
        ValueError: If the dataloader yields no batches (dataset smaller than
            the batch size, since incomplete batches are dropped).
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    # Output layout: <output_dir>/samples and <output_dir>/checkpoints.
    samples_dir = os.path.join(args.output_dir, 'samples')
    checkpoints_dir = os.path.join(args.output_dir, 'checkpoints')
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(checkpoints_dir, exist_ok=True)

    latent_size = args.image_size // 8

    print("Loading VAE...")
    vae = TAESDWrapper.load(device)
    print(f"VAE loaded. Latent size: {latent_size}x{latent_size}")

    print(f"Creating LiquidFlow model (variant={args.variant})...")
    model = create_liquidflow(
        variant=args.variant,
        image_size=args.image_size,
    )
    model = model.to(device)

    n_params = model.count_parameters()
    print(f"Model parameters: {n_params:,} (~{n_params/1e6:.1f}M)")

    # Rough size of one latent in MB; presumably 4 channels x 4 bytes (float32).
    mem_per_sample = latent_size * latent_size * 4 * 4 / (1024**2)
    print(f"Estimated memory per sample: {mem_per_sample:.1f} MB")
    print(f"Estimated batch memory: {mem_per_sample * args.batch_size:.1f} MB")

    print(f"Loading dataset: {args.dataset}")
    dataloader = get_dataloader(args.dataset, args.image_size, args.batch_size, args.data_dir)
    print(f"Dataset size: {len(dataloader.dataset)} images, {len(dataloader)} batches")

    n_batches = len(dataloader)
    if n_batches == 0:
        # Without this the run would only fail at the end of the first epoch
        # with a ZeroDivisionError when averaging the losses.
        raise ValueError(
            f"Dataloader produced no batches: the dataset has fewer than "
            f"batch_size={args.batch_size} samples."
        )

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(0.9, 0.999),
        weight_decay=args.weight_decay,
    )

    # The scheduler is stepped once per batch, so its horizon is in batches.
    scheduler = _build_lr_scheduler(optimizer, args.lr_schedule, args.epochs * n_batches)

    # AMP is only used on CUDA.
    use_amp = args.amp and device.type == 'cuda'
    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    # Fixed starting noise, reused for every "fixed_*.png" preview grid.
    sample_noise = torch.randn(16, 4, latent_size, latent_size, device=device)

    global_step = 0
    best_loss = float('inf')

    print(f"\n{'='*60}")
    print(f"Starting training: {args.epochs} epochs, {args.batch_size} batch size")
    print(f"LR: {args.lr}, Weight Decay: {args.weight_decay}")
    print(f"AMP: {use_amp}, LR Schedule: {args.lr_schedule}")
    print(f"{'='*60}\n")

    for epoch in range(args.epochs):
        model.train()
        epoch_losses = {'total': 0.0, 'diffusion': 0.0, 'physics': 0.0}

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{args.epochs}")

        for images, _ in pbar:
            images = images.to(device)

            # The VAE is only used as an encoder here; no gradients needed.
            with torch.no_grad():
                latents = TAESDWrapper.encode(vae, images)

            loss_dict = model.training_step(latents, optimizer, scaler, use_amp)

            if scheduler is not None:
                scheduler.step()

            # float() accepts both Python numbers and 0-dim tensors and keeps
            # the running totals as plain floats.
            for k in epoch_losses:
                epoch_losses[k] += float(loss_dict.get(k, 0))

            global_step += 1

            pbar.set_postfix({
                'loss': f"{loss_dict.get('total', 0):.4f}",
                'diff': f"{loss_dict.get('diffusion', 0):.4f}",
                'phys': f"{loss_dict.get('physics', 0):.4f}",
                'lr': f"{optimizer.param_groups[0]['lr']:.2e}",
            })

        avg_losses = {k: v / n_batches for k, v in epoch_losses.items()}

        print(f"\nEpoch {epoch+1} Summary:")
        print(f"  Total Loss: {avg_losses['total']:.4f}")
        print(f"  Diffusion Loss: {avg_losses['diffusion']:.4f}")
        print(f"  Physics Loss: {avg_losses['physics']:.4f}")

        is_last_epoch = epoch == args.epochs - 1

        # Periodic sample grids (always on the final epoch).
        if (epoch + 1) % args.sample_every == 0 or is_last_epoch:
            print("Generating samples...")
            model.eval()

            with torch.no_grad():
                # Fresh samples from the model's own sampler.
                latents_gen = model.sample(
                    batch_size=16,
                    steps=args.sample_steps,
                    ddim=True,
                    progress=False,
                )
                images_gen = TAESDWrapper.decode(vae, latents_gen)

                # Samples from the fixed noise, to compare across epochs.
                x_fixed = _sample_from_fixed_noise(model, sample_noise, args.sample_steps)
                images_fixed = TAESDWrapper.decode(vae, x_fixed)

            sample_path = os.path.join(samples_dir, f'epoch_{epoch+1:03d}.png')
            save_image(images_gen, sample_path, nrow=4, normalize=True, value_range=(-1, 1))

            fixed_path = os.path.join(samples_dir, f'fixed_{epoch+1:03d}.png')
            save_image(images_fixed, fixed_path, nrow=4, normalize=True, value_range=(-1, 1))

            print(f"  Samples saved to {sample_path}")

        # Periodic full checkpoints (always on the final epoch).
        if (epoch + 1) % args.save_every == 0 or is_last_epoch:
            checkpoint_path = os.path.join(checkpoints_dir, f'epoch_{epoch+1:03d}.pt')
            torch.save({
                'epoch': epoch + 1,
                'global_step': global_step,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_losses['total'],
                'args': vars(args),
            }, checkpoint_path)
            print(f"  Checkpoint saved to {checkpoint_path}")

        # Keep the weights with the lowest average training loss so far.
        if avg_losses['total'] < best_loss:
            best_loss = avg_losses['total']
            best_path = os.path.join(checkpoints_dir, 'best_model.pt')
            torch.save(model.state_dict(), best_path)
            print(f"  Best model saved (loss={best_loss:.4f})")

        print()

    print(f"\n{'='*60}")
    print("Training complete!")
    print(f"Best loss: {best_loss:.4f}")
    print(f"Model saved to: {args.output_dir}/checkpoints/")
    print(f"{'='*60}")

    return model
|
|
|
|
def main():
    """Parse command-line options and start training.

    Mixed precision is enabled by default; pass --no_amp to turn it off.
    Interval and size options that must be positive are validated here so a
    bad value is reported as a usage error instead of failing mid-training.
    """
    parser = argparse.ArgumentParser(description='LiquidFlow Generator Training')

    # Data
    parser.add_argument('--dataset', type=str, default='cifar10',
                        choices=['cifar10', 'cifar100', 'stl10', 'celeba', 'lsun', 'imagenet'],
                        help='Training dataset')
    parser.add_argument('--data_dir', type=str, default='./data',
                        help='Data directory')
    parser.add_argument('--image_size', type=int, default=128,
                        choices=[64, 128, 256, 512],
                        help='Image size (will be VAE-encoded)')

    # Model
    parser.add_argument('--variant', type=str, default='small',
                        choices=['tiny', 'small', 'base'],
                        help='Model size variant')

    # Optimization
    parser.add_argument('--batch_size', type=int, default=32,
                        help='Batch size')
    parser.add_argument('--epochs', type=int, default=100,
                        help='Number of epochs')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='Learning rate')
    parser.add_argument('--weight_decay', type=float, default=1e-4,
                        help='Weight decay')
    parser.add_argument('--lr_schedule', type=str, default='cosine',
                        choices=['cosine', 'cosine_restart', 'none'],
                        help='LR schedule')
    # --amp is on by default, so on its own it could never be disabled;
    # --no_amp writes False to the same destination.
    parser.add_argument('--amp', dest='amp', action='store_true', default=True,
                        help='Use automatic mixed precision')
    parser.add_argument('--no_amp', dest='amp', action='store_false',
                        help='Disable automatic mixed precision')

    # Sampling
    parser.add_argument('--sample_every', type=int, default=5,
                        help='Generate samples every N epochs')
    parser.add_argument('--sample_steps', type=int, default=50,
                        help='DDIM sampling steps')

    # Output
    parser.add_argument('--output_dir', type=str, default='./outputs',
                        help='Output directory')
    parser.add_argument('--save_every', type=int, default=10,
                        help='Save checkpoint every N epochs')

    args = parser.parse_args()

    # These values are used as divisors / moduli or as a batch size during
    # training; zero or negative values would only crash later.
    for name in ('batch_size', 'sample_every', 'sample_steps', 'save_every'):
        if getattr(args, name) < 1:
            parser.error(f"--{name} must be a positive integer")

    train(args)
|
|
|
|
# Script entry point: run the CLI only when this file is executed directly;
# importing the module does not start training.
if __name__ == "__main__":
    main()
|
|