| """ |
| LiquidGen Training Pipeline |
| |
| Flow Matching training objective (velocity prediction): |
| - Forward: x_t = (1 - t) * x_0 + t * ε (linear interpolation) |
| - Target: v = ε - x_0 (velocity) |
| - Loss: MSE(model(x_t, t), v) |
| |
| At inference: solve ODE from t=1 (noise) to t=0 (clean) using Euler steps. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torch.amp import autocast, GradScaler |
| import math |
| import os |
| import json |
| import time |
| from pathlib import Path |
| from typing import Optional, Dict, Any |
| from dataclasses import dataclass, field, asdict |
|
|
|
|
@dataclass
class TrainConfig:
    """Training configuration with sensible defaults for Colab free tier."""

    # -- Model --
    model_size: str = "small"        # preset key for get_model_config: "small" | "base" | "large"
    num_classes: int = 0             # 0 => unconditional training (labels ignored in train())
    class_drop_prob: float = 0.1     # label dropout prob, for classifier-free guidance training

    # -- Data --
    image_size: int = 256            # square resize/center-crop size in pixels
    dataset_name: str = "huggan/wikiart"  # HuggingFace datasets repo id
    dataset_config: str = ""         # optional datasets config ("name" kwarg); "" = default config
    image_column: str = "image"      # dataset column holding PIL images
    label_column: str = ""           # "" = no labels; missing labels become -1 in the dataset wrapper

    # -- VAE (frozen image <-> latent codec) --
    vae_id: str = "black-forest-labs/FLUX.1-schnell"
    vae_subfolder: str = "vae"
    vae_dtype: str = "float16"       # "float16"; any other value selects bfloat16 (see train())
    vae_scaling_factor: float = 0.3611   # presumably the FLUX VAE's published scaling factor — verify
    vae_shift_factor: float = 0.1159     # presumably the FLUX VAE's published shift factor — verify

    # -- Optimization --
    batch_size: int = 8
    gradient_accumulation_steps: int = 4  # effective batch = batch_size * this
    learning_rate: float = 1e-4
    weight_decay: float = 0.01
    max_grad_norm: float = 2.0       # gradient-clipping threshold (applied after unscaling)
    num_epochs: int = 100
    warmup_steps: int = 1000         # linear LR warmup, counted in optimizer steps
    ema_decay: float = 0.9999
    mixed_precision: bool = True     # enables autocast + GradScaler in train()

    # -- Flow-matching timestep range (avoids exact t=0 and t=1) --
    min_timestep: float = 0.001
    max_timestep: float = 0.999

    # -- Logging / checkpointing (all counted in optimizer steps) --
    output_dir: str = "./outputs"
    save_every_n_steps: int = 5000
    sample_every_n_steps: int = 1000
    log_every_n_steps: int = 50

    # -- Periodic sampling during training --
    num_sample_steps: int = 50       # Euler ODE steps per generated sample
    cfg_scale: float = 1.5           # classifier-free guidance scale (>1 enables CFG)
    num_samples: int = 4

    # -- Runtime --
    seed: int = 42
    num_workers: int = 2
    pin_memory: bool = True
    compile_model: bool = False      # use torch.compile when available

    # -- Hub upload (not referenced by train() in this file) --
    push_to_hub: bool = False
    hub_model_id: str = ""
|
|
|
|
def get_model_config(size: str, num_classes: int = 0, class_drop_prob: float = 0.1) -> dict:
    """Return LiquidGen constructor kwargs for a size preset.

    Args:
        size: One of "small", "base", "large".
        num_classes: Number of class labels (0 = unconditional model).
        class_drop_prob: Label dropout probability for classifier-free guidance.

    Returns:
        dict of keyword arguments suitable for ``LiquidGen(**kwargs)``.

    Raises:
        ValueError: If ``size`` is not a known preset (previously this
            surfaced as an uninformative ``KeyError``).
    """
    configs = {
        "small": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,
                      expand_ratio=2.0, mlp_ratio=3.0),
        "base": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,
                     expand_ratio=2.0, mlp_ratio=4.0),
        "large": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,
                      expand_ratio=2.5, mlp_ratio=4.0),
    }
    if size not in configs:
        raise ValueError(f"Unknown model size {size!r}; expected one of {sorted(configs)}")
    cfg = configs[size]
    cfg["num_classes"] = num_classes
    cfg["class_drop_prob"] = class_drop_prob
    cfg["use_zigzag"] = True
    return cfg
|
|
|
|
class EMAModel:
    """Exponential Moving Average (EMA) of a model's trainable parameters.

    Keeps a detached ``shadow`` copy of every parameter with
    ``requires_grad=True`` and blends in the live weights after each
    optimizer step:

        shadow = decay * shadow + (1 - decay) * param

    ``apply()``/``restore()`` temporarily swap the EMA weights into the
    model (e.g. for sampling/eval) and back again.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        self.decay = decay
        # Detached clones so training updates never touch the shadow copies.
        self.shadow = {name: p.clone().detach()
                       for name, p in model.named_parameters() if p.requires_grad}
        # Initialized here so restore() is a safe no-op before apply()
        # (previously restore() before apply() raised AttributeError).
        self.backup: Dict[str, torch.Tensor] = {}

    @torch.no_grad()
    def update(self, model: nn.Module):
        """Fold the current parameters into the shadow copies."""
        for name, p in model.named_parameters():
            if p.requires_grad and name in self.shadow:
                self.shadow[name].mul_(self.decay).add_(p.data, alpha=1 - self.decay)

    def apply(self, model: nn.Module):
        """Swap EMA weights into the model, backing up the live weights first."""
        self.backup = {name: p.data.clone() for name, p in model.named_parameters() if p.requires_grad}
        for name, p in model.named_parameters():
            if p.requires_grad and name in self.shadow:
                p.data.copy_(self.shadow[name])

    def restore(self, model: nn.Module):
        """Undo apply(): put the backed-up live weights back.

        A no-op when apply() has not been called since the last restore().
        """
        for name, p in model.named_parameters():
            if p.requires_grad and name in self.backup:
                p.data.copy_(self.backup[name])
        self.backup = {}

    def state_dict(self):
        return self.shadow

    def load_state_dict(self, state_dict):
        # Clone so subsequent in-place EMA updates do not mutate the
        # caller's checkpoint dict (previously the dict was aliased).
        self.shadow = {k: v.clone() for k, v in state_dict.items()}
|
|
|
|
class FlowMatchingScheduler:
    """
    Flow Matching scheduler with a linear interpolant.

    Training: x_t = (1-t)*x_0 + t*ε, with constant velocity target v = ε - x_0.
    Sampling: Euler integration of dx/dt = v from t=1 (pure noise) down to t=0.
    """

    def __init__(self, min_t: float = 0.001, max_t: float = 0.999):
        self.min_t = min_t
        self.max_t = max_t

    def sample_timesteps(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Draw one uniform timestep in [min_t, max_t) per batch element."""
        uniform = torch.rand(batch_size, device=device)
        return uniform * (self.max_t - self.min_t) + self.min_t

    def add_noise(self, x0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Interpolate x_t = (1-t)*x0 + t*noise, broadcasting t over C/H/W."""
        tb = t.view(-1, 1, 1, 1)
        return (1 - tb) * x0 + tb * noise

    def get_velocity_target(self, x0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
        """Constant velocity of the linear path: ε - x0."""
        return noise - x0

    @torch.no_grad()
    def sample(
        self, model: nn.Module, shape: tuple, device: torch.device,
        num_steps: int = 50, class_labels: Optional[torch.Tensor] = None,
        cfg_scale: float = 1.0, dtype: torch.dtype = torch.float32,
    ) -> torch.Tensor:
        """Generate samples by Euler-stepping x' = v from t=1 to t=0.

        With cfg_scale > 1 and labels given, classifier-free guidance mixes
        conditional and zero-label velocity predictions.
        """
        model.eval()
        x = torch.randn(shape, device=device, dtype=dtype)
        step_size = 1.0 / num_steps
        # Schedule hits t = 1.0, 1-dt, ..., dt; spacing is exactly 1/num_steps,
        # so the final Euler step lands precisely on t = 0.
        schedule = torch.linspace(1.0, step_size, num_steps, device=device)
        use_cfg = cfg_scale > 1.0 and class_labels is not None
        amp_enabled = dtype != torch.float32

        for t_now in schedule.tolist():
            t_batch = torch.full((shape[0],), t_now, device=device, dtype=dtype)

            with torch.amp.autocast('cuda', enabled=amp_enabled):
                if use_cfg:
                    v_cond = model(x, t_batch, class_labels)
                    # NOTE(review): zero labels stand in for "unconditional" —
                    # confirm the model treats label 0 as the null class.
                    v_uncond = model(x, t_batch, torch.zeros_like(class_labels))
                    v = v_uncond + cfg_scale * (v_cond - v_uncond)
                else:
                    v = model(x, t_batch, class_labels)

            x = x - step_size * v

        return x
|
|
|
|
def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
    """Build a LambdaLR: linear warmup to the base LR, then cosine decay to 0."""
    def scale_at(step):
        # Warmup phase: LR ramps linearly from 0 to 1x over warmup_steps.
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        # Cosine phase: half-cosine from 1 down to 0 over the remaining steps.
        frac = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * frac)))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, scale_at)
|
|
|
|
@torch.no_grad()
def encode_images_with_vae(images, vae, scaling_factor, shift_factor):
    """Encode [0,1] pixel images into shifted/scaled VAE latents."""
    pixels = images * 2.0 - 1.0  # [0,1] -> [-1,1], the range the VAE expects
    raw_latents = vae.encode(pixels).latent_dist.sample()
    return (raw_latents - shift_factor) * scaling_factor
|
|
|
|
@torch.no_grad()
def decode_latents_with_vae(latents, vae, scaling_factor, shift_factor):
    """Invert the latent scaling and decode back to [0,1] pixel images."""
    raw_latents = latents / scaling_factor + shift_factor
    decoded = vae.decode(raw_latents).sample
    # VAE output is in [-1,1]; map to [0,1] and clamp stray values.
    return ((decoded + 1.0) / 2.0).clamp(0, 1)
|
|
|
|
def train(config: TrainConfig):
    """Main training loop.

    Per batch: encode images to VAE latents, draw flow-matching timesteps and
    noise, predict velocity, MSE loss with gradient accumulation, AMP scaling,
    EMA tracking, cosine LR with warmup; periodic sampling and checkpointing.
    """
    from model import LiquidGen

    torch.manual_seed(config.seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    # Output layout: config.json + samples/ + checkpoints/ under output_dir.
    os.makedirs(config.output_dir, exist_ok=True)
    os.makedirs(os.path.join(config.output_dir, "samples"), exist_ok=True)
    os.makedirs(os.path.join(config.output_dir, "checkpoints"), exist_ok=True)

    with open(os.path.join(config.output_dir, "config.json"), "w") as f:
        json.dump(asdict(config), f, indent=2)

    # Frozen VAE used only as an image <-> latent codec (half precision, no grads).
    print("Loading VAE...")
    from diffusers import AutoencoderKL
    vae_dtype = torch.float16 if config.vae_dtype == "float16" else torch.bfloat16
    vae = AutoencoderKL.from_pretrained(
        config.vae_id, subfolder=config.vae_subfolder, torch_dtype=vae_dtype
    ).to(device).eval()
    for p in vae.parameters():
        p.requires_grad_(False)

    print(f"Loading dataset: {config.dataset_name}")
    from datasets import load_dataset
    from torchvision import transforms

    ds_kwargs = {}
    if config.dataset_config:
        ds_kwargs["name"] = config.dataset_config
    dataset = load_dataset(config.dataset_name, split="train", **ds_kwargs)

    # Resize the shorter side, center-crop to a square, random hflip, -> [0,1] tensor.
    transform = transforms.Compose([
        transforms.Resize(config.image_size, interpolation=transforms.InterpolationMode.LANCZOS),
        transforms.CenterCrop(config.image_size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
    ])

    class ImageDataset(Dataset):
        """Wraps a HF dataset row into (transformed RGB tensor, label or -1)."""

        def __init__(self, hf_dataset, transform, image_col, label_col=""):
            self.dataset = hf_dataset
            self.transform = transform
            self.image_col = image_col
            self.label_col = label_col

        def __len__(self):
            return len(self.dataset)

        def __getitem__(self, idx):
            item = self.dataset[idx]
            img = item[self.image_col]
            if img.mode != "RGB":
                img = img.convert("RGB")
            img = self.transform(img)
            # NOTE(review): a missing label column yields -1; if num_classes > 0
            # those -1 labels would reach the class conditioning — confirm labeled
            # runs always set label_column to a real column.
            label = item[self.label_col] if self.label_col and self.label_col in item else -1
            return img, label

    train_dataset = ImageDataset(dataset, transform, config.image_column, config.label_column)
    train_loader = DataLoader(
        train_dataset, batch_size=config.batch_size, shuffle=True,
        num_workers=config.num_workers, pin_memory=config.pin_memory, drop_last=True,
    )

    model_kwargs = get_model_config(config.model_size, config.num_classes, config.class_drop_prob)
    model = LiquidGen(**model_kwargs).to(device)
    print(f"LiquidGen-{config.model_size}: {model.count_params() / 1e6:.1f}M params")

    if config.compile_model and hasattr(torch, "compile"):
        # NOTE(review): after torch.compile, state_dict keys gain an "_orig_mod."
        # prefix — confirm saved checkpoints load back into an uncompiled model.
        model = torch.compile(model)

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
                                  weight_decay=config.weight_decay, betas=(0.9, 0.999))
    # total_steps counts optimizer steps, not batches: one step per accumulation window.
    total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, config.warmup_steps, total_steps)
    ema = EMAModel(model, decay=config.ema_decay)
    scaler = GradScaler('cuda', enabled=config.mixed_precision)
    fm = FlowMatchingScheduler(min_t=config.min_timestep, max_t=config.max_timestep)

    print(f"\nTraining: {total_steps} steps, effective batch {config.batch_size * config.gradient_accumulation_steps}")

    global_step = 0   # optimizer steps taken (drives logging/sampling/saving cadence)
    loss_accum = 0.0  # sum of per-batch (already accumulation-divided) losses since last log

    for epoch in range(config.num_epochs):
        model.train()
        t_start = time.time()

        for batch_idx, (images, labels) in enumerate(train_loader):
            images = images.to(device)
            labels = labels.to(device) if config.num_classes > 0 else None

            # Encode in the VAE's dtype, then train the flow model in fp32.
            with torch.no_grad():
                latents = encode_images_with_vae(
                    images.to(vae_dtype), vae, config.vae_scaling_factor, config.vae_shift_factor
                ).float()

            # Flow-matching objective: predict v = ε - x0 at a random t.
            t = fm.sample_timesteps(latents.shape[0], device)
            noise = torch.randn_like(latents)
            x_t = fm.add_noise(latents, noise, t)
            v_target = fm.get_velocity_target(latents, noise)

            with autocast('cuda', enabled=config.mixed_precision):
                v_pred = model(x_t, t, labels)
                # Pre-divide so accumulated gradients equal the mean over the window.
                loss = F.mse_loss(v_pred, v_target) / config.gradient_accumulation_steps

            scaler.scale(loss).backward()
            loss_accum += loss.item()

            # Optimizer step once per accumulation window.
            if (batch_idx + 1) % config.gradient_accumulation_steps == 0:
                # Unscale before clipping so the threshold applies to real gradients.
                scaler.unscale_(optimizer)
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                ema.update(model)
                global_step += 1

                if global_step % config.log_every_n_steps == 0:
                    # loss_accum holds log_every_n_steps * accum batches of
                    # (loss/accum), so this division yields the mean batch loss.
                    avg_loss = loss_accum / config.log_every_n_steps
                    lr = optimizer.param_groups[0]["lr"]
                    print(f"step={global_step} | epoch={epoch} | loss={avg_loss:.4f} | "
                          f"grad_norm={grad_norm:.2f} | lr={lr:.2e}")
                    loss_accum = 0.0

                    # Divergence guard: abort instead of wasting compute.
                    if math.isnan(avg_loss) or avg_loss > 100:
                        print("⚠️ Training diverged!")
                        return

                if global_step % config.sample_every_n_steps == 0:
                    # Sample with EMA weights, then restore the live weights.
                    ema.apply(model)
                    model.eval()
                    # NOTE(review): assumes an 8x-downsampling, 16-channel latent
                    # space (FLUX-style VAE) — confirm against the configured vae_id.
                    latent_size = config.image_size // 8
                    sample_labels = None
                    if config.num_classes > 0:
                        sample_labels = torch.randint(0, config.num_classes, (config.num_samples,), device=device)
                    # NOTE(review): fm.sample uses zeros_like(labels) as the CFG
                    # unconditional branch, while randint above can also draw 0 as
                    # a real label — confirm label 0 is the model's null class.
                    sampled = fm.sample(model, (config.num_samples, 16, latent_size, latent_size),
                                        device, config.num_sample_steps, sample_labels, config.cfg_scale)
                    sample_imgs = decode_latents_with_vae(sampled.to(vae_dtype), vae,
                                                          config.vae_scaling_factor, config.vae_shift_factor).float()
                    from torchvision.utils import save_image
                    save_image(sample_imgs, os.path.join(config.output_dir, "samples", f"step_{global_step:07d}.png"), nrow=2)
                    ema.restore(model)
                    model.train()

                if global_step % config.save_every_n_steps == 0:
                    # Full resumable checkpoint: model + EMA + optimizer + scheduler.
                    torch.save({
                        "model": model.state_dict(), "ema": ema.state_dict(),
                        "optimizer": optimizer.state_dict(), "scheduler": scheduler.state_dict(),
                        "global_step": global_step, "epoch": epoch, "config": asdict(config),
                    }, os.path.join(config.output_dir, "checkpoints", f"step_{global_step:07d}.pt"))

        print(f"Epoch {epoch} complete | time={time.time()-t_start:.0f}s")

    # Final lightweight checkpoint (no optimizer/scheduler state).
    torch.save({"model": model.state_dict(), "ema": ema.state_dict(), "config": asdict(config),
                "global_step": global_step}, os.path.join(config.output_dir, "checkpoints", "final.pt"))
    print(f"Training complete! Final model saved.")
|
|
|
|
if __name__ == "__main__":
    # Smoke-test defaults: small model, tiny batch, 2 epochs. Adjust TrainConfig
    # fields here (or construct one programmatically) for real runs.
    config = TrainConfig(model_size="small", image_size=256, batch_size=4, num_epochs=2)
    train(config)
|
|