""" LiquidGen Training Pipeline Flow Matching training objective (velocity prediction): - Forward: x_t = (1 - t) * x_0 + t * ε (linear interpolation) - Target: v = ε - x_0 (velocity) - Loss: MSE(model(x_t, t), v) At inference: solve ODE from t=1 (noise) to t=0 (clean) using Euler steps. """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from torch.amp import autocast, GradScaler import math import os import json import time from pathlib import Path from typing import Optional, Dict, Any from dataclasses import dataclass, field, asdict @dataclass class TrainConfig: """Training configuration with sensible defaults for Colab free tier.""" # Model model_size: str = "small" num_classes: int = 0 class_drop_prob: float = 0.1 # Data image_size: int = 256 dataset_name: str = "huggan/wikiart" dataset_config: str = "" image_column: str = "image" label_column: str = "" # VAE vae_id: str = "black-forest-labs/FLUX.1-schnell" vae_subfolder: str = "vae" vae_dtype: str = "float16" vae_scaling_factor: float = 0.3611 vae_shift_factor: float = 0.1159 # Training batch_size: int = 8 gradient_accumulation_steps: int = 4 learning_rate: float = 1e-4 weight_decay: float = 0.01 max_grad_norm: float = 2.0 num_epochs: int = 100 warmup_steps: int = 1000 ema_decay: float = 0.9999 mixed_precision: bool = True # Flow matching min_timestep: float = 0.001 max_timestep: float = 0.999 # Saving output_dir: str = "./outputs" save_every_n_steps: int = 5000 sample_every_n_steps: int = 1000 log_every_n_steps: int = 50 # Sampling num_sample_steps: int = 50 cfg_scale: float = 1.5 num_samples: int = 4 # System seed: int = 42 num_workers: int = 2 pin_memory: bool = True compile_model: bool = False # Hub push_to_hub: bool = False hub_model_id: str = "" def get_model_config(size: str, num_classes: int = 0, class_drop_prob: float = 0.1) -> dict: """Get model kwargs for a given size preset.""" configs = { "small": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0), "base": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0), "large": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0), } cfg = configs[size] cfg["num_classes"] = num_classes cfg["class_drop_prob"] = class_drop_prob cfg["use_zigzag"] = True return cfg class EMAModel: """Exponential Moving Average of model parameters.""" def __init__(self, model: nn.Module, decay: float = 0.9999): self.decay = decay self.shadow = {name: p.clone().detach() for name, p in model.named_parameters() if p.requires_grad} @torch.no_grad() def update(self, model: nn.Module): for name, p in model.named_parameters(): if p.requires_grad and name in self.shadow: self.shadow[name].mul_(self.decay).add_(p.data, alpha=1 - self.decay) def apply(self, model: nn.Module): self.backup = {name: p.data.clone() for name, p in model.named_parameters() if p.requires_grad} for name, p in model.named_parameters(): if p.requires_grad and name in self.shadow: p.data.copy_(self.shadow[name]) def restore(self, model: nn.Module): for name, p in model.named_parameters(): if p.requires_grad and name in self.backup: p.data.copy_(self.backup[name]) self.backup = {} def state_dict(self): return self.shadow def load_state_dict(self, state_dict): self.shadow = state_dict class FlowMatchingScheduler: """ Flow Matching scheduler for training and sampling. Training: x_t = (1-t)*x_0 + t*ε, v_target = ε - x_0 Sampling: Euler ODE from t=1 (noise) to t=0 (clean) """ def __init__(self, min_t: float = 0.001, max_t: float = 0.999): self.min_t = min_t self.max_t = max_t def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor: return torch.rand(batch_size, device=device) * (self.max_t - self.min_t) + self.min_t def add_noise(self, x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor: t_expand = t.view(-1, 1, 1, 1) return (1 - t_expand) * x0 + t_expand * noise def get_velocity_target(self, x0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: return noise - x0 @torch.no_grad() def sample( self, model: nn.Module, shape: tuple, device: torch.device, num_steps: int = 50, class_labels: Optional[torch.Tensor] = None, cfg_scale: float = 1.0, dtype: torch.dtype = torch.float32, ) -> torch.Tensor: model.eval() x = torch.randn(shape, device=device, dtype=dtype) dt = 1.0 / num_steps times = torch.linspace(1.0, dt, num_steps, device=device) for t_val in times: t = torch.full((shape[0],), t_val.item(), device=device, dtype=dtype) if cfg_scale > 1.0 and class_labels is not None: with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)): v_cond = model(x, t, class_labels) v_uncond = model(x, t, torch.zeros_like(class_labels)) v = v_uncond + cfg_scale * (v_cond - v_uncond) else: with torch.amp.autocast('cuda', enabled=(dtype != torch.float32)): v = model(x, t, class_labels) x = x - dt * v return x def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps): """Cosine LR schedule with linear warmup.""" def lr_lambda(current_step): if current_step < warmup_steps: return float(current_step) / float(max(1, warmup_steps)) progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) @torch.no_grad() def encode_images_with_vae(images, vae, scaling_factor, shift_factor): """Encode pixel images to VAE latents.""" images = images * 2.0 - 1.0 latents = vae.encode(images).latent_dist.sample() latents = (latents - shift_factor) * scaling_factor return latents @torch.no_grad() def decode_latents_with_vae(latents, vae, scaling_factor, shift_factor): """Decode VAE latents to pixel images.""" latents = latents / scaling_factor + shift_factor images = vae.decode(latents).sample images = (images + 1.0) / 2.0 return images.clamp(0, 1) def train(config: TrainConfig): """Main training loop.""" from model import LiquidGen torch.manual_seed(config.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") os.makedirs(config.output_dir, exist_ok=True) os.makedirs(os.path.join(config.output_dir, "samples"), exist_ok=True) os.makedirs(os.path.join(config.output_dir, "checkpoints"), exist_ok=True) with open(os.path.join(config.output_dir, "config.json"), "w") as f: json.dump(asdict(config), f, indent=2) # Load VAE print("Loading VAE...") from diffusers import AutoencoderKL vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16 vae = AutoencoderKL.from_pretrained( config.vae_id, subfolder=config.vae_subfolder, torch_dtype=vae_dtype ).to(device).eval() for p in vae.parameters(): p.requires_grad_(False) # Load Dataset print(f"Loading dataset: {config.dataset_name}") from datasets import load_dataset from torchvision import transforms ds_kwargs = {} if config.dataset_config: ds_kwargs["name"] = config.dataset_config dataset = load_dataset(config.dataset_name, split="train", **ds_kwargs) transform = transforms.Compose([ transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS), transforms.CenterCrop(config.image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), ]) class ImageDataset(Dataset): def __init__(self, hf_dataset, transform, image_col, label_col=""): self.dataset = hf_dataset self.transform = transform self.image_col = image_col self.label_col = label_col def __len__(self): return len(self.dataset) def __getitem__(self, idx): item = self.dataset[idx] img = item[self.image_col] if img.mode != "RGB": img = img.convert("RGB") img = self.transform(img) label = item[self.label_col] if self.label_col and self.label_col in item else -1 return img, label train_dataset = ImageDataset(dataset, transform, config.image_column, config.label_column) train_loader = DataLoader( train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=config.pin_memory, drop_last=True, ) # Create Model model_kwargs = get_model_config(config.model_size, config.num_classes, config.class_drop_prob) model = LiquidGen(**model_kwargs).to(device) print(f"LiquidGen-{config.model_size}: {model.count_params() / 1e6:.1f}M params") if config.compile_model and hasattr(torch, "compile"): model = torch.compile(model) optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay, betas=(0.9, 0.999)) total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, total_steps) ema = EMAModel(model, decay=config.ema_decay) scaler = GradScaler('cuda', enabled=config.mixed_precision) fm = FlowMatchingScheduler(min_t=config.min_timestep, max_t=config.max_timestep) print(f"\nTraining: {total_steps} steps, effective batch {config.batch_size * config.gradient_accumulation_steps}") global_step = 0 loss_accum = 0.0 for epoch in range(config.num_epochs): model.train() t_start = time.time() for batch_idx, (images, labels) in enumerate(train_loader): images = images.to(device) labels = labels.to(device) if config.num_classes > 0 else None with torch.no_grad(): latents = encode_images_with_vae( images.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor ).float() t = fm.sample_timesteps(latents.shape[0], device) noise = torch.randn_like(latents) x_t = fm.add_noise(latents, noise, t) v_target = fm.get_velocity_target(latents, noise) with autocast('cuda', enabled=config.mixed_precision): v_pred = model(x_t, t, labels) loss = F.mse_loss(v_pred, v_target) / config.gradient_accumulation_steps scaler.scale(loss).backward() loss_accum += loss.item() if (batch_idx + 1) % config.gradient_accumulation_steps == 0: scaler.unscale_(optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) scaler.step(optimizer) scaler.update() optimizer.zero_grad() scheduler.step() ema.update(model) global_step += 1 if global_step % config.log_every_n_steps == 0: avg_loss = loss_accum / config.log_every_n_steps lr = optimizer.param_groups[0]["lr"] print(f"step={global_step} | epoch={epoch} | loss={avg_loss:.4f} | " f"grad_norm={grad_norm:.2f} | lr={lr:.2e}") loss_accum = 0.0 if math.isnan(avg_loss) or avg_loss > 100: print("⚠️ Training diverged!") return if global_step % config.sample_every_n_steps == 0: ema.apply(model) model.eval() latent_size = config.image_size // 8 sample_labels = None if config.num_classes > 0: sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device) sampled = fm.sample(model, (config.num_samples, 16, latent_size, latent_size), device, config.num_sample_steps, sample_labels, config.cfg_scale) sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor).float() from torchvision.utils import save_image save_image(sample_imgs, os.path.join(config.output_dir, "samples", f"step_{global_step:07d}.png"), nrow=2) ema.restore(model) model.train() if global_step % config.save_every_n_steps == 0: torch.save({ "model": model.state_dict(), "ema": ema.state_dict(), "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(), "global_step": global_step, "epoch": epoch, "config": asdict(config), }, os.path.join(config.output_dir, "checkpoints", f"step_{global_step:07d}.pt")) print(f"Epoch {epoch} complete | time={time.time()-t_start:.0f}s") torch.save({"model": model.state_dict(), "ema": ema.state_dict(), "config": asdict(config), "global_step": global_step}, os.path.join(config.output_dir, "checkpoints", "final.pt")) print(f"Training complete! Final model saved.") if __name__ == "__main__": config = TrainConfig(model_size="small", image_size=256, batch_size=4, num_epochs=2) train(config)