""" LiquidGen Training Pipeline v2 Optimized for Colab free tier: - Latent pre-caching: encode images with VAE once, save to disk, train on pure tensors - No VAE needed during training loop → saves ~1GB VRAM + faster iterations - Streaming support for large datasets - Multiple small dataset presets Flow Matching training objective (velocity prediction): - Forward: x_t = (1 - t) * x_0 + t * ε - Target: v = ε - x_0 - Loss: MSE(model(x_t, t), v) """ import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from torch.amp import autocast, GradScaler import math import os import json import time from typing import Optional from dataclasses import dataclass, asdict # ============================================================================= # Dataset Presets (all verified, fast to download) # ============================================================================= DATASET_PRESETS = { "paintings_mini": { "name": "keremberke/painting-style-classification", "config": "mini", "image_column": "image", "label_column": "labels", "num_classes": 27, "description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test", }, "paintings": { "name": "keremberke/painting-style-classification", "config": "full", "image_column": "image", "label_column": "labels", "num_classes": 27, "description": "~8K paintings, 27 styles, 204MB — best for style-conditional training", }, "cartoon": { "name": "Norod78/cartoon-blip-captions", "config": "", "image_column": "image", "label_column": "", "num_classes": 0, "description": "~2.5K cartoon/anime, unconditional, 181MB", }, "flowers": { "name": "huggan/flowers-102-categories", "config": "", "image_column": "image", "label_column": "", "num_classes": 0, "description": "~8K flower photos, unconditional, 331MB", }, "wikiart_stream": { "name": "huggan/wikiart", "config": "", "image_column": "image", "label_column": "style", "num_classes": 27, "streaming": True, "description": "~80K paintings, 27 styles, STREAMING (0 disk) — use max_images to limit", }, } @dataclass class TrainConfig: """Training configuration optimized for Colab free tier (T4 16GB).""" # Model model_size: str = "small" # small (~55M), base (~140M), large (~280M) num_classes: int = 27 class_drop_prob: float = 0.1 # Data dataset_preset: str = "paintings" # key from DATASET_PRESETS image_size: int = 256 # 256 or 512 max_images: int = 0 # 0 = use all, >0 = limit (for streaming/testing) # VAE (for pre-caching only — NOT loaded during training) vae_id: str = "black-forest-labs/FLUX.1-schnell" vae_subfolder: str = "vae" vae_scaling_factor: float = 0.3611 vae_shift_factor: float = 0.1159 # Training batch_size: int = 32 # Can be large since training on cached tensors! gradient_accumulation_steps: int = 1 learning_rate: float = 1e-4 weight_decay: float = 0.01 max_grad_norm: float = 2.0 # Critical for stability (ZigMa paper) num_epochs: int = 100 warmup_steps: int = 500 ema_decay: float = 0.9999 mixed_precision: bool = True # Flow matching min_timestep: float = 0.001 max_timestep: float = 0.999 # Saving output_dir: str = "./outputs" save_every_n_steps: int = 2000 sample_every_n_steps: int = 500 log_every_n_steps: int = 25 # Sampling num_sample_steps: int = 50 cfg_scale: float = 2.0 num_samples: int = 4 # System seed: int = 42 num_workers: int = 2 compile_model: bool = False # Hub push_to_hub: bool = False hub_model_id: str = "" def get_model_config(size, num_classes=0, class_drop_prob=0.1): configs = { "small": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0), "base": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0), "large": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0), } cfg = configs[size] cfg["num_classes"] = num_classes cfg["class_drop_prob"] = class_drop_prob cfg["use_zigzag"] = True return cfg # ============================================================================= # Latent Pre-Caching (the key optimization for Colab) # ============================================================================= class CachedLatentDataset(Dataset): """Training dataset from pre-encoded VAE latents on disk.""" def __init__(self, cache_path): data = torch.load(cache_path, map_location="cpu", weights_only=True) self.latents = data["latents"] self.labels = data.get("labels", None) print(f"Loaded {len(self.latents)} cached latents from {cache_path}") print(f" Shape: {self.latents.shape}, dtype: {self.latents.dtype}") if self.labels is not None: print(f" Labels: unique={self.labels.unique().shape[0]}") def __len__(self): return len(self.latents) def __getitem__(self, idx): lat = self.latents[idx] label = self.labels[idx] if self.labels is not None else -1 return lat, label def precache_latents(config, cache_path=None): """ Encode all images to VAE latents once, save to disk. After caching: - VAE unloaded → frees ~1GB VRAM - Training loads pure tensors → much faster iterations - Larger batch sizes possible (no VAE memory overhead) Returns path to cache file. """ if cache_path is None: cache_path = os.path.join(config.output_dir, "cached_latents.pt") if os.path.exists(cache_path): print(f"✅ Cache exists: {cache_path}") data = torch.load(cache_path, map_location="cpu", weights_only=True) print(f" {data['latents'].shape[0]} latents, shape {data['latents'].shape[1:]}") return cache_path os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load VAE temporarily print("Loading VAE for encoding...") from diffusers import AutoencoderKL vae = AutoencoderKL.from_pretrained( config.vae_id, subfolder=config.vae_subfolder, torch_dtype=torch.float16 ).to(device).eval() for p in vae.parameters(): p.requires_grad_(False) # Load dataset preset = DATASET_PRESETS[config.dataset_preset] print(f"Loading dataset: {preset['name']} ({preset['description']})") from datasets import load_dataset from torchvision import transforms is_streaming = preset.get("streaming", False) ds_kwargs = {"split": "train"} if preset["config"]: ds_kwargs["name"] = preset["config"] if is_streaming: ds_kwargs["streaming"] = True dataset = load_dataset(preset["name"], **ds_kwargs) transform = transforms.Compose([ transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS), transforms.CenterCrop(config.image_size), transforms.ToTensor(), ]) all_latents = [] all_labels = [] batch_pixels = [] batch_labels = [] encode_bs = 16 count = 0 max_imgs = config.max_images if config.max_images > 0 else float("inf") img_col = preset["image_column"] lbl_col = preset["label_column"] print(f"Encoding images to latents...") t0 = time.time() for item in dataset: if count >= max_imgs: break img = item[img_col] if img.mode != "RGB": img = img.convert("RGB") batch_pixels.append(transform(img)) if lbl_col and lbl_col in item: batch_labels.append(item[lbl_col]) else: batch_labels.append(-1) count += 1 if len(batch_pixels) >= encode_bs: with torch.no_grad(): px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1 lat = vae.encode(px).latent_dist.sample() lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor all_latents.append(lat.cpu().float()) all_labels.extend(batch_labels) batch_pixels, batch_labels = [], [] if count % 500 == 0: print(f" {count} images encoded ({time.time()-t0:.0f}s)") if batch_pixels: with torch.no_grad(): px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1 lat = vae.encode(px).latent_dist.sample() lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor all_latents.append(lat.cpu().float()) all_labels.extend(batch_labels) all_latents = torch.cat(all_latents, dim=0) all_labels = torch.tensor(all_labels, dtype=torch.long) torch.save({"latents": all_latents, "labels": all_labels}, cache_path) elapsed = time.time() - t0 mb = os.path.getsize(cache_path) / 1024**2 print(f"\n✅ Cached {count} latents → {cache_path}") print(f" Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s") del vae if torch.cuda.is_available(): torch.cuda.empty_cache() print(" VAE unloaded, VRAM freed\n") return cache_path # ============================================================================= # EMA, FlowMatching, Scheduler # ============================================================================= class EMAModel: def __init__(self, model, decay=0.9999): self.decay = decay self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad} @torch.no_grad() def update(self, model): for n, p in model.named_parameters(): if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1 - self.decay) def apply(self, model): self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad} for n, p in model.named_parameters(): if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n]) def restore(self, model): for n, p in model.named_parameters(): if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n]) self.backup = {} class FlowMatchingScheduler: def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t def sample_timesteps(self, bs, dev): return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t def add_noise(self, x0, noise, t): t = t.view(-1, 1, 1, 1); return (1 - t) * x0 + t * noise def get_velocity_target(self, x0, noise): return noise - x0 @torch.no_grad() def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0): model.eval(); x = torch.randn(shape, device=dev) dt = 1.0 / num_steps for tv in torch.linspace(1.0, dt, num_steps, device=dev): t = torch.full((shape[0],), tv.item(), device=dev) with torch.amp.autocast("cuda"): if cfg > 1.0 and labels is not None: vc = model(x, t, labels); vu = model(x, t, torch.zeros_like(labels)) v = vu + cfg * (vc - vu) else: v = model(x, t, labels) x = x - dt * v.float() return x def cosine_schedule(opt, warmup, total): def lr(s): if s < warmup: return s / max(1, warmup) return max(0, 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(1, total - warmup)))) return torch.optim.lr_scheduler.LambdaLR(opt, lr) # ============================================================================= # Main Training Loop # ============================================================================= def train(config): from model import LiquidGen torch.manual_seed(config.seed) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)} " f"({torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB)") os.makedirs(config.output_dir, exist_ok=True) os.makedirs(f"{config.output_dir}/samples", exist_ok=True) os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True) with open(f"{config.output_dir}/config.json", "w") as f: json.dump(asdict(config), f, indent=2) # Step 1: Pre-cache latents cache_path = precache_latents(config) # Step 2: Dataset from cache train_ds = CachedLatentDataset(cache_path) train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True, num_workers=config.num_workers, pin_memory=True, drop_last=True) # Step 3: Model mcfg = get_model_config(config.model_size, config.num_classes, config.class_drop_prob) model = LiquidGen(**mcfg).to(device) print(f"LiquidGen-{config.model_size}: {model.count_params()/1e6:.1f}M params") if config.compile_model and hasattr(torch, "compile"): model = torch.compile(model) # Step 4: Training setup opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay, betas=(0.9, 0.999)) total_steps = len(train_dl) * config.num_epochs // config.gradient_accumulation_steps sched = cosine_schedule(opt, config.warmup_steps, total_steps) ema = EMAModel(model, config.ema_decay) scaler = GradScaler("cuda", enabled=config.mixed_precision and torch.cuda.is_available()) fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep) lat_size = config.image_size // 8 print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}×{config.gradient_accumulation_steps}") print(f"No VAE during training → max VRAM for model") if torch.cuda.is_available(): print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / " f"{torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB") # Step 5: Train! gs = 0; la = 0.0; vae = None; vae_loaded = False print(f"\n{'='*60}\n🚀 Training!\n{'='*60}\n") t_start = time.time() for epoch in range(config.num_epochs): model.train(); et = time.time() for bi, (lats, lbls) in enumerate(train_dl): lats = lats.to(device) lbls = lbls.to(device) if config.num_classes > 0 else None t = fm.sample_timesteps(lats.shape[0], device) noise = torch.randn_like(lats) xt = fm.add_noise(lats, noise, t) vtgt = fm.get_velocity_target(lats, noise) with autocast("cuda", enabled=config.mixed_precision and torch.cuda.is_available()): vp = model(xt, t, lbls) loss = F.mse_loss(vp, vtgt) / config.gradient_accumulation_steps scaler.scale(loss).backward() la += loss.item() if (bi + 1) % config.gradient_accumulation_steps == 0: scaler.unscale_(opt) gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm) scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step() ema.update(model); gs += 1 if gs % config.log_every_n_steps == 0: al = la / config.log_every_n_steps lr = opt.param_groups[0]["lr"] vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0 sps = gs / max(time.time() - t_start, 1) print(f"step={gs:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | " f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s") la = 0.0 if math.isnan(al) or al > 50: print("💥 Diverged!"); return if gs % config.sample_every_n_steps == 0: if not vae_loaded: from diffusers import AutoencoderKL vae = AutoencoderKL.from_pretrained( config.vae_id, subfolder=config.vae_subfolder, torch_dtype=torch.float16).to(device).eval() for p in vae.parameters(): p.requires_grad_(False) vae_loaded = True ema.apply(model); model.eval() sl = torch.randint(0, max(1, config.num_classes), (config.num_samples,), device=device) if config.num_classes > 0 else None samp = fm.sample(model, (config.num_samples, 16, lat_size, lat_size), device, config.num_sample_steps, sl, config.cfg_scale) with torch.no_grad(): dec = samp.half() / config.vae_scaling_factor + config.vae_shift_factor imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float() from torchvision.utils import save_image sp = f"{config.output_dir}/samples/step_{gs:07d}.png" save_image(imgs, sp, nrow=2); print(f" 📸 {sp}") ema.restore(model); model.train() if gs % config.save_every_n_steps == 0: cp = f"{config.output_dir}/checkpoints/step_{gs:07d}.pt" torch.save({"model": model.state_dict(), "ema": ema.shadow, "optimizer": opt.state_dict(), "scheduler": sched.state_dict(), "step": gs, "epoch": epoch, "model_config": mcfg}, cp) print(f" 💾 {cp}") print(f"Epoch {epoch} | {time.time()-et:.0f}s\n") final = f"{config.output_dir}/checkpoints/final.pt" torch.save({"model": model.state_dict(), "ema": ema.shadow, "model_config": mcfg, "step": gs}, final) print(f"\n🎉 Done! {gs} steps, {(time.time()-t_start)/60:.1f}min → {final}") if __name__ == "__main__": config = TrainConfig( model_size="small", dataset_preset="paintings_mini", image_size=256, batch_size=8, num_epochs=5, log_every_n_steps=5, sample_every_n_steps=99999, ) train(config)