| """ |
LiquidGen Training Pipeline v2

Optimized for the Colab free tier:
- Latent pre-caching: encode images with the VAE once, save the latents to disk, and train on pure tensors
- No VAE needed during the training loop → saves ~1GB VRAM and speeds up iterations
- Streaming support for large datasets
- Multiple small dataset presets

Flow Matching training objective (velocity prediction):
- Forward: x_t = (1 - t) * x_0 + t * ε
- Target: v = ε - x_0
- Loss: MSE(model(x_t, t), v)
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torch.amp import autocast, GradScaler |
| import math |
| import os |
| import json |
| import time |
| from typing import Optional |
| from dataclasses import dataclass, asdict |
|
|
|
|
| |
| |
| |
|
|
# Hugging Face dataset presets, keyed by the name used in TrainConfig.dataset_preset.
# Each entry holds:
#   name          - Hub repository id passed to datasets.load_dataset
#   config        - dataset config/subset name ("" means none)
#   image_column  - column holding the PIL image
#   label_column  - column holding the class label ("" means unconditional data)
#   num_classes   - number of label classes (0 for unconditional data)
#   streaming     - optional; True loads the dataset in streaming mode
#   description   - human-readable summary printed when the dataset is loaded
DATASET_PRESETS = {
    "paintings_mini": {
        "name": "keremberke/painting-style-classification",
        "config": "mini",
        "image_column": "image",
        "label_column": "labels",
        "num_classes": 27,
        "description": "~200 painting samples, 27 styles, 1.7MB — instant smoke test",
    },
    "paintings": {
        "name": "keremberke/painting-style-classification",
        "config": "full",
        "image_column": "image",
        "label_column": "labels",
        "num_classes": 27,
        "description": "~8K paintings, 27 styles, 204MB — best for style-conditional training",
    },
    "cartoon": {
        "name": "Norod78/cartoon-blip-captions",
        "config": "",
        "image_column": "image",
        "label_column": "",
        "num_classes": 0,
        "description": "~2.5K cartoon/anime, unconditional, 181MB",
    },
    "flowers": {
        "name": "huggan/flowers-102-categories",
        "config": "",
        "image_column": "image",
        "label_column": "",
        "num_classes": 0,
        "description": "~8K flower photos, unconditional, 331MB",
    },
    "wikiart_stream": {
        "name": "huggan/wikiart",
        "config": "",
        "image_column": "image",
        "label_column": "style",
        "num_classes": 27,
        "streaming": True,
        "description": "~80K paintings, 27 styles, STREAMING (0 disk) — use max_images to limit",
    },
}
|
|
|
|
@dataclass
class TrainConfig:
    """Settings for one LiquidGen training run.

    Defaults are tuned for the Colab free tier (a single 16GB T4). Fields are
    grouped by concern below; every value is plain data, so the whole object
    can be serialized with ``dataclasses.asdict``.
    """

    # ---- model ----
    model_size: str = "small"  # "small" | "base" | "large" (see get_model_config)
    num_classes: int = 27  # class-label count; 0 disables label conditioning
    class_drop_prob: float = 0.1  # passed through to the model config unchanged

    # ---- data ----
    dataset_preset: str = "paintings"  # key of DATASET_PRESETS
    image_size: int = 256  # images are resized and center-cropped to this square size
    max_images: int = 0  # stop encoding after this many images; 0 = no limit

    # ---- VAE (FLUX.1 autoencoder) ----
    vae_id: str = "black-forest-labs/FLUX.1-schnell"
    vae_subfolder: str = "vae"
    vae_scaling_factor: float = 0.3611  # cached latent = (z - shift) * scale
    vae_shift_factor: float = 0.1159

    # ---- optimisation ----
    batch_size: int = 32
    gradient_accumulation_steps: int = 1  # micro-batches per optimizer step
    learning_rate: float = 1e-4  # peak LR, reached at the end of warmup
    weight_decay: float = 0.01  # AdamW weight decay
    max_grad_norm: float = 2.0  # gradient-norm clipping threshold
    num_epochs: int = 100
    warmup_steps: int = 500  # linear warmup length before cosine decay
    ema_decay: float = 0.9999
    mixed_precision: bool = True  # CUDA autocast + GradScaler; has no effect without CUDA

    # ---- flow-matching timesteps ----
    min_timestep: float = 0.001  # training t is drawn uniformly from [min, max)
    max_timestep: float = 0.999

    # ---- output & logging (intervals count optimizer steps) ----
    output_dir: str = "./outputs"
    save_every_n_steps: int = 2000
    sample_every_n_steps: int = 500
    log_every_n_steps: int = 25

    # ---- preview sampling ----
    num_sample_steps: int = 50  # Euler steps per preview batch
    cfg_scale: float = 2.0  # classifier-free guidance; only applied when > 1
    num_samples: int = 4

    # ---- runtime ----
    seed: int = 42
    num_workers: int = 2  # DataLoader worker processes
    compile_model: bool = False  # wrap the model with torch.compile when available

    # ---- Hugging Face Hub (not read anywhere in this file) ----
    push_to_hub: bool = False
    hub_model_id: str = ""
|
|
|
|
def get_model_config(size, num_classes=0, class_drop_prob=0.1):
    """Build the keyword arguments for constructing a LiquidGen model.

    Args:
        size: Architecture preset name: "small", "base" or "large". Any other
            value raises KeyError.
        num_classes: Copied into the result as ``num_classes``.
        class_drop_prob: Copied into the result as ``class_drop_prob``.

    Returns:
        A freshly built dict: the preset's architecture hyper-parameters,
        followed by ``num_classes``, ``class_drop_prob`` and ``use_zigzag=True``.
    """
    architectures = {
        "small": {
            "embed_dim": 512, "depth": 12, "spatial_kernel": 7,
            "scan_kernel": 31, "expand_ratio": 2.0, "mlp_ratio": 3.0,
        },
        "base": {
            "embed_dim": 640, "depth": 18, "spatial_kernel": 7,
            "scan_kernel": 31, "expand_ratio": 2.0, "mlp_ratio": 4.0,
        },
        "large": {
            "embed_dim": 768, "depth": 24, "spatial_kernel": 7,
            "scan_kernel": 31, "expand_ratio": 2.5, "mlp_ratio": 4.0,
        },
    }
    return {
        **architectures[size],
        "num_classes": num_classes,
        "class_drop_prob": class_drop_prob,
        "use_zigzag": True,
    }
|
|
|
|
| |
| |
| |
|
|
class CachedLatentDataset(Dataset):
    """Dataset that serves pre-encoded VAE latents from a ``torch.save`` file.

    The file must hold a dict with a ``"latents"`` tensor and, optionally, a
    ``"labels"`` tensor indexed in parallel along dimension 0. Items are
    ``(latent, label)`` pairs; the label is ``-1`` when the file has no labels.
    """

    def __init__(self, cache_path):
        payload = torch.load(cache_path, map_location="cpu", weights_only=True)
        self.latents = payload["latents"]
        self.labels = payload.get("labels", None)

        print(f"Loaded {len(self.latents)} cached latents from {cache_path}")
        print(f" Shape: {self.latents.shape}, dtype: {self.latents.dtype}")
        if self.labels is None:
            return
        print(f" Labels: unique={self.labels.unique().shape[0]}")

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        latent = self.latents[idx]
        if self.labels is None:
            return latent, -1
        return latent, self.labels[idx]
|
|
|
|
| def precache_latents(config, cache_path=None): |
| """ |
| Encode all images to VAE latents once, save to disk. |
| |
| After caching: |
| - VAE unloaded โ frees ~1GB VRAM |
| - Training loads pure tensors โ much faster iterations |
| - Larger batch sizes possible (no VAE memory overhead) |
| |
| Returns path to cache file. |
| """ |
| if cache_path is None: |
| cache_path = os.path.join(config.output_dir, "cached_latents.pt") |
|
|
| if os.path.exists(cache_path): |
| print(f"โ
Cache exists: {cache_path}") |
| data = torch.load(cache_path, map_location="cpu", weights_only=True) |
| print(f" {data['latents'].shape[0]} latents, shape {data['latents'].shape[1:]}") |
| return cache_path |
|
|
| os.makedirs(os.path.dirname(cache_path) if os.path.dirname(cache_path) else ".", exist_ok=True) |
| device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
| |
| print("Loading VAE for encoding...") |
| from diffusers import AutoencoderKL |
| vae = AutoencoderKL.from_pretrained( |
| config.vae_id, subfolder=config.vae_subfolder, torch_dtype=torch.float16 |
| ).to(device).eval() |
| for p in vae.parameters(): |
| p.requires_grad_(False) |
|
|
| |
| preset = DATASET_PRESETS[config.dataset_preset] |
| print(f"Loading dataset: {preset['name']} ({preset['description']})") |
|
|
| from datasets import load_dataset |
| from torchvision import transforms |
|
|
| is_streaming = preset.get("streaming", False) |
| ds_kwargs = {"split": "train"} |
| if preset["config"]: |
| ds_kwargs["name"] = preset["config"] |
| if is_streaming: |
| ds_kwargs["streaming"] = True |
|
|
| dataset = load_dataset(preset["name"], **ds_kwargs) |
|
|
| transform = transforms.Compose([ |
| transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS), |
| transforms.CenterCrop(config.image_size), |
| transforms.ToTensor(), |
| ]) |
|
|
| all_latents = [] |
| all_labels = [] |
| batch_pixels = [] |
| batch_labels = [] |
| encode_bs = 16 |
| count = 0 |
| max_imgs = config.max_images if config.max_images > 0 else float("inf") |
| img_col = preset["image_column"] |
| lbl_col = preset["label_column"] |
|
|
| print(f"Encoding images to latents...") |
| t0 = time.time() |
|
|
| for item in dataset: |
| if count >= max_imgs: |
| break |
| img = item[img_col] |
| if img.mode != "RGB": |
| img = img.convert("RGB") |
| batch_pixels.append(transform(img)) |
| if lbl_col and lbl_col in item: |
| batch_labels.append(item[lbl_col]) |
| else: |
| batch_labels.append(-1) |
| count += 1 |
|
|
| if len(batch_pixels) >= encode_bs: |
| with torch.no_grad(): |
| px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1 |
| lat = vae.encode(px).latent_dist.sample() |
| lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor |
| all_latents.append(lat.cpu().float()) |
| all_labels.extend(batch_labels) |
| batch_pixels, batch_labels = [], [] |
| if count % 500 == 0: |
| print(f" {count} images encoded ({time.time()-t0:.0f}s)") |
|
|
| if batch_pixels: |
| with torch.no_grad(): |
| px = torch.stack(batch_pixels).to(device, dtype=torch.float16) * 2 - 1 |
| lat = vae.encode(px).latent_dist.sample() |
| lat = (lat - config.vae_shift_factor) * config.vae_scaling_factor |
| all_latents.append(lat.cpu().float()) |
| all_labels.extend(batch_labels) |
|
|
| all_latents = torch.cat(all_latents, dim=0) |
| all_labels = torch.tensor(all_labels, dtype=torch.long) |
| torch.save({"latents": all_latents, "labels": all_labels}, cache_path) |
|
|
| elapsed = time.time() - t0 |
| mb = os.path.getsize(cache_path) / 1024**2 |
| print(f"\nโ
Cached {count} latents โ {cache_path}") |
| print(f" Shape: {all_latents.shape}, Size: {mb:.1f}MB, Time: {elapsed:.0f}s") |
|
|
| del vae |
| if torch.cuda.is_available(): |
| torch.cuda.empty_cache() |
| print(" VAE unloaded, VRAM freed\n") |
| return cache_path |
|
|
|
|
| |
| |
| |
|
|
class EMAModel:
    """Exponential moving average (EMA) of a model's trainable parameters.

    ``update`` blends the live weights into the shadow copy. ``apply`` swaps
    the shadow weights into the model, for example for sampling, and
    ``restore`` puts the live weights back.
    """

    def __init__(self, model, decay=0.9999):
        self.decay = decay
        # Detached copies of every parameter with requires_grad=True, keyed by name.
        self.shadow = {n: p.clone().detach() for n, p in model.named_parameters() if p.requires_grad}
        # Live weights saved by apply(). Initialised here so that restore()
        # before apply() is a no-op instead of an AttributeError.
        self.backup = {}

    @torch.no_grad()
    def update(self, model):
        """Set shadow = decay * shadow + (1 - decay) * param for each tracked parameter."""
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                self.shadow[n].mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    @torch.no_grad()
    def apply(self, model):
        """Copy the EMA weights into ``model``, saving the live weights for restore().

        Calling apply() again before restore() keeps the first backup. Live
        training weights are therefore never overwritten by a backup of EMA weights.
        """
        if not self.backup:
            self.backup = {n: p.data.clone() for n, p in model.named_parameters() if p.requires_grad}
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.shadow:
                p.data.copy_(self.shadow[n])

    @torch.no_grad()
    def restore(self, model):
        """Copy the weights saved by apply() back into ``model`` and clear the backup."""
        for n, p in model.named_parameters():
            if p.requires_grad and n in self.backup:
                p.data.copy_(self.backup[n])
        self.backup = {}
|
|
|
|
class FlowMatchingScheduler:
    """Flow-matching helper: noising, velocity targets and Euler sampling.

    Uses the linear path x_t = (1 - t) * x_0 + t * noise, so t = 1 is pure
    noise and the velocity target is noise - x_0.
    """

    def __init__(self, min_t=0.001, max_t=0.999):
        self.min_t, self.max_t = min_t, max_t

    def sample_timesteps(self, bs, dev):
        """Draw ``bs`` training timesteps uniformly from [min_t, max_t)."""
        return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t

    def add_noise(self, x0, noise, t):
        """Interpolate data and noise; ``t`` (one value per sample) broadcasts over 4-D inputs."""
        t = t.view(-1, 1, 1, 1)
        return (1 - t) * x0 + t * noise

    def get_velocity_target(self, x0, noise):
        """Velocity of the linear path, d x_t / d t = noise - x0."""
        return noise - x0

    @torch.no_grad()
    def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):
        """Generate latents by Euler-integrating the predicted velocity from t=1 to t=0.

        Args:
            model: Called as ``model(x, t, labels)``; expected to return a
                velocity broadcastable to ``x``. It is left in eval mode.
            shape: Shape of the latent batch; ``shape[0]`` is the batch size.
            dev: Device (``torch.device`` or string) to sample on.
            num_steps: Number of uniform Euler steps.
            labels: Optional class labels passed to the model.
            cfg: Classifier-free guidance scale; guidance is applied only when
                ``cfg > 1`` and ``labels`` is given.

        Returns:
            The final float32 latent tensor.
        """
        model.eval()
        x = torch.randn(shape, device=dev)
        dt = 1.0 / num_steps
        # Request CUDA autocast only when actually sampling on a CUDA device.
        # On CPU it would merely warn that CUDA is unavailable.
        use_amp = torch.device(dev).type == "cuda"
        for tv in torch.linspace(1.0, dt, num_steps, device=dev):
            t = torch.full((shape[0],), tv.item(), device=dev)
            with torch.amp.autocast("cuda", enabled=use_amp):
                if cfg > 1.0 and labels is not None:
                    # NOTE(review): the "unconditional" branch is conditioned on
                    # label 0 (torch.zeros_like(labels)). Unless LiquidGen reserves
                    # label 0 as its null class, this guides relative to class 0
                    # rather than the unconditional model. Confirm against model.py.
                    v_cond = model(x, t, labels)
                    v_uncond = model(x, t, torch.zeros_like(labels))
                    v = v_uncond + cfg * (v_cond - v_uncond)
                else:
                    v = model(x, t, labels)
            x = x - dt * v.float()
        return x
|
|
|
|
def cosine_schedule(opt, warmup, total):
    """Linear warmup followed by cosine decay to zero, as a LambdaLR multiplier.

    Args:
        opt: Optimizer whose base learning rates are scaled.
        warmup: Number of steps over which the multiplier ramps linearly from 0 to 1.
        total: Step at which the multiplier reaches 0. It stays at 0 afterwards.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``opt``.
    """
    def lr(s):
        if s < warmup:
            return s / max(1, warmup)
        span = max(1, total - warmup)
        # Clamp progress at ``span``. Otherwise, past ``total`` (or whenever
        # total <= warmup), the multiplier would follow the cosine back up.
        elapsed = min(s - warmup, span)
        return max(0, 0.5 * (1 + math.cos(math.pi * elapsed / span)))
    return torch.optim.lr_scheduler.LambdaLR(opt, lr)
|
|
|
|
| |
| |
| |
|
|
def _load_sampling_vae(config, device):
    """Load the frozen fp16 VAE used only to decode preview samples."""
    from diffusers import AutoencoderKL

    vae = AutoencoderKL.from_pretrained(
        config.vae_id, subfolder=config.vae_subfolder, torch_dtype=torch.float16
    ).to(device).eval()
    for p in vae.parameters():
        p.requires_grad_(False)
    return vae


def _effective_num_classes(config, labels):
    """Return the class count to train with, given the cached label tensor.

    precache_latents stores -1 for images without a label. If there are no
    labels, or every cached label is -1, the data carries no class information.
    Training then falls back to unconditional (0 classes) instead of feeding
    -1 labels to a class-conditional model.
    """
    if config.num_classes <= 0:
        return config.num_classes
    if labels is None or labels.numel() == 0 or bool((labels < 0).all()):
        print(f"Warning: num_classes={config.num_classes} but the cached dataset "
              f"has no labels; training unconditionally.")
        return 0
    return config.num_classes


def _save_preview_samples(model, raw_model, ema, fm, vae, config, device,
                          num_classes, latent_shape, step):
    """Sample with the EMA weights, decode with ``vae`` and save a PNG grid.

    The live training weights are restored, and the model put back in train
    mode, even if sampling or decoding raises.
    """
    from torchvision.utils import save_image

    ema.apply(raw_model)
    try:
        model.eval()
        sample_labels = (
            torch.randint(0, num_classes, (config.num_samples,), device=device)
            if num_classes > 0 else None
        )
        samp = fm.sample(model, (config.num_samples, *latent_shape), device,
                         config.num_sample_steps, sample_labels, config.cfg_scale)
        with torch.no_grad():
            # Undo the latent normalization applied in precache_latents.
            dec = samp.half() / config.vae_scaling_factor + config.vae_shift_factor
            imgs = ((vae.decode(dec).sample + 1) / 2).clamp(0, 1).float()
        sp = f"{config.output_dir}/samples/step_{step:07d}.png"
        save_image(imgs, sp, nrow=2)
        print(f" 📸 {sp}")
    finally:
        ema.restore(raw_model)
        model.train()


def train(config):
    """Train LiquidGen with the flow-matching objective on cached VAE latents.

    Pre-caches latents (the VAE is used once, then freed), then builds the
    model, AdamW optimizer, cosine LR schedule, EMA and grad scaler. It then
    trains for ``config.num_epochs`` epochs. Every ``log_every_n_steps`` /
    ``sample_every_n_steps`` / ``save_every_n_steps`` optimizer steps it logs,
    writes EMA preview samples, or saves a checkpoint, respectively. A
    ``final.pt`` checkpoint is written at the end. Returns early, without a
    final checkpoint, if the logged loss is NaN or above 50.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` < 1, or the dataset is
            too small to produce a single optimizer step per epoch.
    """
    from model import LiquidGen

    torch.manual_seed(config.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_cuda = device.type == "cuda"
    print(f"Device: {device}")
    if use_cuda:
        # The device-properties field is ``total_memory``; ``total_mem`` does not exist.
        total_vram = torch.cuda.get_device_properties(0).total_memory / 1024**3
        print(f"GPU: {torch.cuda.get_device_name(0)} ({total_vram:.1f} GB)")

    accum = config.gradient_accumulation_steps
    if accum < 1:
        raise ValueError(f"gradient_accumulation_steps must be >= 1, got {accum}")

    os.makedirs(config.output_dir, exist_ok=True)
    os.makedirs(f"{config.output_dir}/samples", exist_ok=True)
    os.makedirs(f"{config.output_dir}/checkpoints", exist_ok=True)

    with open(f"{config.output_dir}/config.json", "w") as f:
        json.dump(asdict(config), f, indent=2)

    # Encode the dataset once; no VAE is kept in memory during training.
    cache_path = precache_latents(config)

    train_ds = CachedLatentDataset(cache_path)
    train_dl = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True,
                          num_workers=config.num_workers, pin_memory=use_cuda,
                          drop_last=True)
    num_classes = _effective_num_classes(config, train_ds.labels)
    # Per-sample latent shape taken from the cache itself, used for preview sampling.
    latent_shape = tuple(train_ds.latents.shape[1:])

    # Only whole accumulation windows become optimizer steps. Trailing
    # micro-batches of an epoch are skipped so their gradients don't leak
    # into the next epoch, and the schedule length is exact.
    steps_per_epoch = len(train_dl) // accum
    if steps_per_epoch == 0:
        raise ValueError(
            f"Dataset has {len(train_ds)} latents, which is not enough for one "
            f"optimizer step (batch_size={config.batch_size}, "
            f"gradient_accumulation_steps={accum}, drop_last=True)."
        )
    micro_batches_per_epoch = steps_per_epoch * accum
    total_steps = steps_per_epoch * config.num_epochs

    mcfg = get_model_config(config.model_size, num_classes, config.class_drop_prob)
    raw_model = LiquidGen(**mcfg).to(device)
    print(f"LiquidGen-{config.model_size}: {raw_model.count_params()/1e6:.1f}M params")

    # torch.compile returns a wrapper whose parameter names and state_dict keys
    # carry an "_orig_mod." prefix. EMA and checkpoints use the unwrapped module
    # so their names stay stable; both share the same parameter tensors.
    model = raw_model
    if config.compile_model and hasattr(torch, "compile"):
        model = torch.compile(raw_model)

    opt = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                            weight_decay=config.weight_decay, betas=(0.9, 0.999))
    sched = cosine_schedule(opt, config.warmup_steps, total_steps)
    ema = EMAModel(raw_model, config.ema_decay)
    use_amp = config.mixed_precision and use_cuda
    scaler = GradScaler("cuda", enabled=use_amp)
    fm = FlowMatchingScheduler(config.min_timestep, config.max_timestep)

    print(f"\nTotal steps: {total_steps}, Batch: {config.batch_size}×{accum}")
    print("No VAE during training → max VRAM for model")
    if use_cuda:
        print(f"VRAM: {torch.cuda.memory_allocated()/1024**3:.1f} / {total_vram:.1f} GB")

    gs = 0     # optimizer steps taken
    la = 0.0   # loss accumulated since the last log line
    vae = None  # loaded lazily the first time preview samples are requested
    print(f"\n{'='*60}\n🚀 Training!\n{'='*60}\n")
    t_start = time.time()

    for epoch in range(config.num_epochs):
        model.train()
        et = time.time()
        for bi, (lats, lbls) in enumerate(train_dl):
            if bi >= micro_batches_per_epoch:
                break
            lats = lats.to(device)
            lbls = lbls.to(device) if num_classes > 0 else None

            t = fm.sample_timesteps(lats.shape[0], device)
            noise = torch.randn_like(lats)
            xt = fm.add_noise(lats, noise, t)
            vtgt = fm.get_velocity_target(lats, noise)

            with autocast("cuda", enabled=use_amp):
                vp = model(xt, t, lbls)
                loss = F.mse_loss(vp, vtgt) / accum

            scaler.scale(loss).backward()
            la += loss.item()

            if (bi + 1) % accum != 0:
                continue  # keep accumulating gradients

            scaler.unscale_(opt)  # clip the true (unscaled) gradients
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
            sched.step()
            ema.update(raw_model)
            gs += 1

            if gs % config.log_every_n_steps == 0:
                al = la / config.log_every_n_steps
                lr = opt.param_groups[0]["lr"]
                vram = torch.cuda.memory_allocated() / 1024**3 if use_cuda else 0
                sps = gs / max(time.time() - t_start, 1)
                print(f"step={gs:>6d} | ep={epoch} | loss={al:.4f} | gn={float(gn):.2f} | "
                      f"lr={lr:.2e} | vram={vram:.1f}G | {sps:.1f} st/s")
                la = 0.0
                if math.isnan(al) or al > 50:
                    print("💥 Diverged!")
                    return

            if gs % config.sample_every_n_steps == 0:
                if vae is None:
                    vae = _load_sampling_vae(config, device)
                _save_preview_samples(model, raw_model, ema, fm, vae, config, device,
                                      num_classes, latent_shape, gs)

            if gs % config.save_every_n_steps == 0:
                cp = f"{config.output_dir}/checkpoints/step_{gs:07d}.pt"
                torch.save({"model": raw_model.state_dict(), "ema": ema.shadow,
                            "optimizer": opt.state_dict(), "scheduler": sched.state_dict(),
                            "step": gs, "epoch": epoch, "model_config": mcfg}, cp)
                print(f" 💾 {cp}")

        print(f"Epoch {epoch} | {time.time()-et:.0f}s\n")

    final = f"{config.output_dir}/checkpoints/final.pt"
    torch.save({"model": raw_model.state_dict(), "ema": ema.shadow,
                "model_config": mcfg, "step": gs}, final)
    print(f"\n🎉 Done! {gs} steps, {(time.time()-t_start)/60:.1f}min → {final}")
|
|
|
|
if __name__ == "__main__":
    # Smoke-test run: the ~200-image "paintings_mini" preset, a handful of
    # epochs, frequent logging and (effectively) no intermediate sampling.
    smoke_config = TrainConfig(
        model_size="small",
        dataset_preset="paintings_mini",
        image_size=256,
        batch_size=8,
        num_epochs=5,
        log_every_n_steps=5,
        sample_every_n_steps=99999,
    )
    train(smoke_config)
|
|