""" Rectified Flow Training for LiquidDiffusion Training Objective (Rectified Flow): x_t = (1-t)*x0 + t*x1, t ~ U[0,1], x1 ~ N(0,I) v_target = x1 - x0 (constant velocity) L = E[||v_θ(x_t, t) - v_target||²] (simple MSE) Sampling (Euler ODE): Start from x_1 ~ N(0,I), integrate backward: x_{t-dt} = x_t - v_θ(x_t, t) * dt References: [1] Liu et al., "Flow Straight and Fast: Rectified Flow", ICLR 2023 [2] Lee et al., "Improving the Training of Rectified Flows", 2024 """ import math import copy import os import time import json import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import DataLoader, Dataset from torchvision import transforms from torchvision.utils import save_image, make_grid class RectifiedFlowTrainer: """Trainer for LiquidDiffusion using Rectified Flow objective.""" def __init__(self, model, optimizer=None, lr=1e-4, weight_decay=0.01, ema_decay=0.9999, grad_clip=1.0, time_sampling="logit_normal", logit_normal_mean=0.0, logit_normal_std=1.0, device="cuda", use_amp=True, amp_dtype="float16"): self.model = model.to(device) self.device = device self.ema_decay = ema_decay self.grad_clip = grad_clip self.time_sampling = time_sampling self.logit_normal_mean = logit_normal_mean self.logit_normal_std = logit_normal_std self.use_amp = use_amp and device == "cuda" self.amp_dtype = getattr(torch, amp_dtype) if self.use_amp else torch.float32 if optimizer is None: self.optimizer = torch.optim.AdamW( model.parameters(), lr=lr, weight_decay=weight_decay, betas=(0.9, 0.999)) else: self.optimizer = optimizer self.scaler = torch.amp.GradScaler("cuda", enabled=(self.use_amp and amp_dtype == "float16")) self.ema_model = self._build_ema() self.step = 0 self.losses = [] def _build_ema(self): ema = copy.deepcopy(self.model) ema.eval() for p in ema.parameters(): p.requires_grad_(False) return ema @torch.no_grad() def _update_ema(self): for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()): ema_p.data.mul_(self.ema_decay).add_(model_p.data, alpha=1 - self.ema_decay) def _sample_time(self, batch_size): eps = 1e-5 if self.time_sampling == "uniform": return torch.rand(batch_size, device=self.device) * (1 - 2*eps) + eps elif self.time_sampling == "logit_normal": u = torch.randn(batch_size, device=self.device) * self.logit_normal_std + self.logit_normal_mean return torch.sigmoid(u).clamp(eps, 1 - eps) raise ValueError(f"Unknown time_sampling: {self.time_sampling}") def train_step(self, x0): self.model.train() x1 = torch.randn_like(x0) t = self._sample_time(x0.shape[0]) t_expand = t[:, None, None, None] x_t = (1 - t_expand) * x0 + t_expand * x1 v_target = x1 - x0 with torch.amp.autocast(self.device, dtype=self.amp_dtype, enabled=self.use_amp): v_pred = self.model(x_t, t) loss = F.mse_loss(v_pred, v_target) self.optimizer.zero_grad(set_to_none=True) self.scaler.scale(loss).backward() if self.grad_clip > 0: self.scaler.unscale_(self.optimizer) grad_norm = torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) else: grad_norm = torch.tensor(0.0) self.scaler.step(self.optimizer) self.scaler.update() self._update_ema() self.step += 1 loss_val = loss.item() self.losses.append(loss_val) return {'loss': loss_val, 'grad_norm': grad_norm.item() if torch.is_tensor(grad_norm) else grad_norm, 'step': self.step} @torch.no_grad() def sample(self, batch_size=4, image_size=256, channels=3, num_steps=50, use_ema=True): model = self.ema_model if use_ema else self.model model.eval() z = torch.randn(batch_size, channels, image_size, image_size, device=self.device) dt = 1.0 / num_steps for i in range(num_steps, 0, -1): t = torch.full((batch_size,), i / num_steps, device=self.device) with torch.amp.autocast(self.device, dtype=self.amp_dtype, enabled=self.use_amp): v = model(z, t) if self.use_amp and self.amp_dtype == torch.float16: v = v.float() z = z - v * dt return z.clamp(-1, 1) def save_checkpoint(self, path, extra=None): ckpt = {'model': self.model.state_dict(), 'ema_model': self.ema_model.state_dict(), 'optimizer': self.optimizer.state_dict(), 'scaler': self.scaler.state_dict(), 'step': self.step, 'losses': self.losses[-1000:]} if extra: ckpt.update(extra) os.makedirs(os.path.dirname(path) if os.path.dirname(path) else '.', exist_ok=True) torch.save(ckpt, path) def load_checkpoint(self, path): ckpt = torch.load(path, map_location=self.device, weights_only=False) self.model.load_state_dict(ckpt['model']) self.ema_model.load_state_dict(ckpt['ema_model']) self.optimizer.load_state_dict(ckpt['optimizer']) if 'scaler' in ckpt: self.scaler.load_state_dict(ckpt['scaler']) self.step = ckpt.get('step', 0) self.losses = ckpt.get('losses', []) print(f"Resumed from step {self.step}") class ImageDataset(Dataset): """Image dataset from local folder or HuggingFace Hub.""" def __init__(self, source, image_size=256, split="train", image_column="image", max_samples=None, hf_dataset=None): self.image_size = image_size self.image_column = image_column self.transform = transforms.Compose([ transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS), transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ]) if hf_dataset is not None: self.data = hf_dataset self.mode = "hf" elif source and os.path.isdir(source): from glob import glob self.files = [] for ext in ['*.png', '*.jpg', '*.jpeg', '*.webp', '*.bmp']: self.files.extend(glob(os.path.join(source, '**', ext), recursive=True)) self.files.sort() if max_samples: self.files = self.files[:max_samples] self.mode = "folder" else: from datasets import load_dataset self.data = load_dataset(source, split=split) if max_samples: self.data = self.data.select(range(min(max_samples, len(self.data)))) self.mode = "hf" def __len__(self): return len(self.files) if self.mode == "folder" else len(self.data) def __getitem__(self, idx): if self.mode == "folder": from PIL import Image img = Image.open(self.files[idx]).convert("RGB") else: img = self.data[idx][self.image_column] if not hasattr(img, 'convert'): from PIL import Image as PILImage img = PILImage.fromarray(img) img = img.convert("RGB") return self.transform(img) def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps): def lr_lambda(step): if step < num_warmup_steps: return float(step) / float(max(1, num_warmup_steps)) progress = float(step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps)) return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress))) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)