""" LiRA Training Pipeline Training Strategy: ================== 1. Flow Matching with v-prediction (from SANA/SD3) - More stable than epsilon prediction near t=T - Better gradients throughout the diffusion process 2. Laplace Noise Schedule (from "Improved Noise Schedule for Diffusion") - Concentrates sampling around logSNR=0 - Better FID than cosine/linear schedules 3. Progressive Resolution Training (from SANA) - Start at 256px → 512px → 1024px - Each stage uses the previous as initialization 4. Curriculum Learning (from "Curriculum Learning for Diffusion") - Easy timesteps first (high noise), hard timesteps later (low noise) 5. EMA with post-hoc tuning (from EDM2) - EMA decay 0.9999 during training - Post-hoc search for optimal EMA length Training Stability: =================== - Gradient clipping (max_norm=1.0) - AdamW with weight decay 0.01 - Warmup + cosine decay learning rate - AdaLN-Zero initialization (network acts as identity at start) - Loss scaling: velocity prediction is naturally bounded - Mixed precision (bf16) with gradient scaling """ import torch import torch.nn as nn import torch.nn.functional as F import math import os from typing import Optional, Dict, Tuple from dataclasses import dataclass, field @dataclass class LiRATrainingConfig: """Training configuration with sensible defaults for Colab-friendly training""" # Model model_config: str = 'tiny' # Start small for testing latent_channels: int = 4 # SD1.x/SDXL VAE spatial_compression: int = 8 d_text: int = 768 patch_size: int = 2 # 2x2 patches for f8 VAE (128x128 → 64x64 tokens) # Training batch_size: int = 8 learning_rate: float = 1e-4 weight_decay: float = 0.01 warmup_steps: int = 1000 max_steps: int = 100000 grad_clip: float = 1.0 # EMA ema_decay: float = 0.9999 # Flow matching prediction_target: str = 'velocity' # 'velocity' or 'epsilon' noise_schedule: str = 'laplace' # 'laplace', 'logit_normal', or 'uniform' # Progressive resolution progressive_stages: list = field(default_factory=lambda: [ {'resolution': 256, 'steps': 50000}, {'resolution': 512, 'steps': 30000}, {'resolution': 1024, 'steps': 20000}, ]) # Curriculum use_curriculum: bool = True curriculum_warmup: int = 10000 # Steps before full timestep range # Logging log_every: int = 100 save_every: int = 5000 sample_every: int = 2500 # Hardware mixed_precision: str = 'bf16' # 'bf16', 'fp16', or 'no' compile_model: bool = False # torch.compile (if available) # Data dataset_name: str = '' num_workers: int = 4 # Output output_dir: str = './lira_output' hub_model_id: str = '' push_to_hub: bool = True class FlowMatchingScheduler: """ Flow Matching noise scheduler with Laplace distribution. Flow matching interpolation: z_t = (1 - t) * z_0 + t * ε where ε ~ N(0, I) v_t = ε - z_0 (velocity) Laplace noise schedule (from "Improved Noise Schedule"): t ~ Laplace(μ=0, b=1), mapped to [0, 1] via CDF This concentrates samples around t=0.5 where learning is most effective. """ def __init__(self, schedule: str = 'laplace', shift: float = 1.0): self.schedule = schedule self.shift = shift # For resolution-dependent shifting (from SD3) def sample_timesteps(self, batch_size: int, device: torch.device, curriculum_progress: float = 1.0) -> torch.Tensor: """ Sample timesteps from the noise schedule. curriculum_progress: 0→1 over training. At 0, only easy timesteps (near 1.0). At 1.0, full range. 
""" if self.schedule == 'laplace': # Laplace distribution centered at 0, mapped to [0,1] u = torch.rand(batch_size, device=device) # Laplace CDF inverse: t = μ - b * sign(u-0.5) * log(1 - 2|u-0.5|) t = 0.5 - torch.sign(u - 0.5) * torch.log(1 - 2 * torch.abs(u - 0.5) + 1e-8) # Map from (-inf, inf) to (0, 1) via sigmoid t = torch.sigmoid(t) elif self.schedule == 'logit_normal': # Logit-normal (from SD3): sample from N(0,1) then sigmoid t = torch.sigmoid(torch.randn(batch_size, device=device)) else: # uniform t = torch.rand(batch_size, device=device) # Apply resolution-dependent shift (from SD3) # Higher shift → more weight on higher noise levels if self.shift != 1.0: t = t * self.shift / (1 + (self.shift - 1) * t) # Curriculum: restrict to easier timesteps early in training if curriculum_progress < 1.0: min_t = 0.5 * (1 - curriculum_progress) # Start from t>0.5, expand to t>0 t = min_t + t * (1 - min_t) # Clamp for numerical stability t = t.clamp(1e-5, 1 - 1e-5) return t def add_noise(self, z_0: torch.Tensor, t: torch.Tensor, noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Flow matching interpolation: z_t = (1-t)*z_0 + t*ε Returns: (z_t, noise) """ if noise is None: noise = torch.randn_like(z_0) t = t.view(-1, 1, 1, 1) # Broadcast over spatial dims z_t = (1 - t) * z_0 + t * noise return z_t, noise def get_velocity(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: """Compute velocity target: v = ε - z_0""" return noise - z_0 def predict_z0(self, z_t: torch.Tensor, v_pred: torch.Tensor, t: torch.Tensor) -> torch.Tensor: """Recover z_0 from z_t and predicted velocity""" t = t.view(-1, 1, 1, 1) # z_t = (1-t)*z_0 + t*ε # v = ε - z_0 # z_0 = z_t - t*v / (1-t+t) ... simplified: # z_0 = z_t - t * v_pred ... wait let me derive properly # z_t = (1-t)*z_0 + t*(z_0 + v) = z_0 + t*v # z_0 = z_t - t * v_pred return z_t - t * v_pred class EMAModel: """Exponential Moving Average of model parameters""" def __init__(self, model: nn.Module, decay: float = 0.9999): self.decay = decay self.shadow = {} self.backup = {} for name, param in model.named_parameters(): if param.requires_grad: self.shadow[name] = param.data.clone() @torch.no_grad() def update(self, model: nn.Module): for name, param in model.named_parameters(): if param.requires_grad and name in self.shadow: self.shadow[name] = ( self.decay * self.shadow[name] + (1 - self.decay) * param.data ) def apply_shadow(self, model: nn.Module): """Replace model params with EMA params""" for name, param in model.named_parameters(): if param.requires_grad and name in self.shadow: self.backup[name] = param.data param.data = self.shadow[name] def restore(self, model: nn.Module): """Restore original model params""" for name, param in model.named_parameters(): if param.requires_grad and name in self.backup: param.data = self.backup[name] self.backup = {} def state_dict(self): return self.shadow def load_state_dict(self, state_dict): self.shadow = state_dict def compute_loss( model: nn.Module, z_0: torch.Tensor, text_features: torch.Tensor, scheduler: FlowMatchingScheduler, config: LiRATrainingConfig, global_step: int = 0, text_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, Dict]: """ Compute training loss. 


def compute_loss(
    model: nn.Module,
    z_0: torch.Tensor,
    text_features: torch.Tensor,
    scheduler: FlowMatchingScheduler,
    config: LiRATrainingConfig,
    global_step: int = 0,
    text_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict]:
    """
    Compute training loss.

    Loss = ||v_pred - v_target||^2  (MSE on velocity prediction)

    With optional:
    - Reasoning regularization (encourage adaptive compute)
    - Frequency-weighted loss (higher weight on low-frequency errors)
    """
    device = z_0.device
    B = z_0.shape[0]

    # Curriculum progress
    if config.use_curriculum:
        curriculum_progress = min(1.0, global_step / config.curriculum_warmup)
    else:
        curriculum_progress = 1.0

    # Sample timesteps
    t = scheduler.sample_timesteps(B, device, curriculum_progress)

    # Add noise
    z_t, noise = scheduler.add_noise(z_0, t)

    # Get velocity target
    v_target = scheduler.get_velocity(z_0, noise)

    # Forward pass
    v_pred, reason_info = model(z_t, t, text_features, text_mask)

    # MSE loss on velocity
    loss = F.mse_loss(v_pred, v_target)

    # Reasoning regularization: encourage variable thinking steps
    # Small penalty to discourage always using max steps
    if reason_info.get('total_steps', 0) > 0 and len(reason_info.get('stop_values', [])) > 0:
        avg_stop = sum(reason_info['stop_values']) / len(reason_info['stop_values'])
        # Encourage the stop gate to actually stop sometimes
        reason_reg = 0.01 * (1.0 - avg_stop)  # Small penalty
        loss = loss + reason_reg

    info = {
        'loss': loss.item(),
        'mse_loss': F.mse_loss(v_pred, v_target).item(),
        'reason_steps': reason_info.get('total_steps', 0),
    }
    return loss, info


def get_lr_scheduler(optimizer, config: LiRATrainingConfig):
    """Warmup + cosine decay learning rate schedule"""
    def lr_lambda(step):
        if step < config.warmup_steps:
            return step / config.warmup_steps
        else:
            progress = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
            return 0.5 * (1 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)


# ============================================================================
# DPM-Solver for fast sampling (from SANA's Flow-DPM-Solver)
# ============================================================================

class FlowDPMSolver:
    """
    DPM-Solver adapted for flow matching.

    Standard Euler:  z_{t-dt} = z_t - dt * v(z_t, t)
    DPM-Solver-2:    Second-order correction for fewer steps

    From SANA: "Flow-DPM-Solver converges at 14-20 steps vs 28-50 for Euler"
    """

    def __init__(self, num_steps: int = 20, order: int = 2):
        self.num_steps = num_steps
        self.order = min(order, 2)

    @torch.no_grad()
    def sample(
        self,
        model: nn.Module,
        shape: Tuple[int, ...],
        text_features: torch.Tensor,
        text_mask: Optional[torch.Tensor] = None,
        cfg_scale: float = 4.0,
        device: torch.device = torch.device('cpu'),
    ) -> torch.Tensor:
        """
        Generate samples using DPM-Solver.

        Args:
            model: LiRA model
            shape: (B, C, H, W) latent shape
            text_features: (B, M, D) text features
            cfg_scale: classifier-free guidance scale
        """
        B = shape[0]

        # Start from pure noise (t=1)
        z = torch.randn(shape, device=device)

        # Time steps from 1 → 0
        timesteps = torch.linspace(1, 0, self.num_steps + 1, device=device)

        prev_v = None
        for i in range(self.num_steps):
            t_cur = timesteps[i]
            t_next = timesteps[i + 1]
            dt = t_next - t_cur  # Negative (going from 1 to 0)

            t_batch = t_cur.expand(B)

            # Predict velocity (with CFG if scale > 1)
            if cfg_scale > 1.0:
                v_pred = self._cfg_predict(model, z, t_batch, text_features,
                                           text_mask, cfg_scale)
            else:
                v_pred, _ = model(z, t_batch, text_features, text_mask)

            if self.order == 1 or prev_v is None:
                # Euler step
                z = z + dt * v_pred
            else:
                # DPM-Solver-2 (second-order correction)
                # Uses previous velocity for better approximation
                z = z + dt * (1.5 * v_pred - 0.5 * prev_v)

            prev_v = v_pred

        return z

    def _cfg_predict(self, model, z, t, text_features, text_mask, cfg_scale):
        """Classifier-free guidance"""
        # Unconditional prediction (zero text)
        null_text = torch.zeros_like(text_features)
        v_uncond, _ = model(z, t, null_text, text_mask)

        # Conditional prediction
        v_cond, _ = model(z, t, text_features, text_mask)

        # CFG
        return v_uncond + cfg_scale * (v_cond - v_uncond)
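

# ============================================================================
# Illustrative training-step sketch (not part of the original pipeline)
# ============================================================================
# Assumptions: a LiRA model is constructed elsewhere in the repo, and
# `dataloader` yields batches of pre-encoded VAE latents `z_0` plus text
# features and masks; neither name is an API defined in this file. The sketch
# only shows how the pieces above (FlowMatchingScheduler, compute_loss,
# get_lr_scheduler, EMAModel) combine with the stability measures from the
# module docstring: AdamW + weight decay, warmup + cosine LR, gradient
# clipping, bf16 autocast, and EMA updates. fp16 with a GradScaler is omitted
# for brevity.
def train(model: nn.Module, dataloader, config: LiRATrainingConfig) -> None:
    device = next(model.parameters()).device
    scheduler = FlowMatchingScheduler(schedule=config.noise_schedule)
    ema = EMAModel(model, decay=config.ema_decay)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    lr_sched = get_lr_scheduler(optimizer, config)
    use_bf16 = config.mixed_precision == 'bf16'

    global_step = 0
    for z_0, text_features, text_mask in dataloader:
        z_0 = z_0.to(device)
        text_features = text_features.to(device)
        text_mask = text_mask.to(device) if text_mask is not None else None

        with torch.autocast(device_type=device.type, dtype=torch.bfloat16,
                            enabled=use_bf16):
            loss, info = compute_loss(model, z_0, text_features, scheduler,
                                      config, global_step, text_mask)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
        lr_sched.step()
        ema.update(model)

        if global_step % config.log_every == 0:
            print(f"step {global_step}: loss={info['loss']:.4f}")

        global_step += 1
        if global_step >= config.max_steps:
            break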