| """ |
| LiRA Training Pipeline |
| |
| Training Strategy: |
| ================== |
| 1. Flow Matching with v-prediction (from SANA/SD3) |
   - More stable than epsilon prediction at high noise levels (t near 1)
| - Better gradients throughout the diffusion process |
| |
| 2. Laplace Noise Schedule (from "Improved Noise Schedule for Diffusion") |
| - Concentrates sampling around logSNR=0 |
| - Better FID than cosine/linear schedules |
| |
| 3. Progressive Resolution Training (from SANA) |
| - Start at 256px → 512px → 1024px |
| - Each stage uses the previous as initialization |
| |
| 4. Curriculum Learning (from "Curriculum Learning for Diffusion") |
| - Easy timesteps first (high noise), hard timesteps later (low noise) |
| |
| 5. EMA with post-hoc tuning (from EDM2) |
| - EMA decay 0.9999 during training |
| - Post-hoc search for optimal EMA length |
| |
| Training Stability: |
| =================== |
| - Gradient clipping (max_norm=1.0) |
| - AdamW with weight decay 0.01 |
| - Warmup + cosine decay learning rate |
| - AdaLN-Zero initialization (network acts as identity at start) |
| - Loss scaling: velocity prediction is naturally bounded |
- Mixed precision (bf16; unlike fp16, no gradient scaler is needed)
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import math |
| import os |
| from typing import Optional, Dict, Tuple |
| from dataclasses import dataclass, field |
|
|
|
|
| @dataclass |
| class LiRATrainingConfig: |
| """Training configuration with sensible defaults for Colab-friendly training""" |
| |
    # Model / architecture
| model_config: str = 'tiny' |
| latent_channels: int = 4 |
| spatial_compression: int = 8 |
| d_text: int = 768 |
| patch_size: int = 2 |
| |
    # Optimization
| batch_size: int = 8 |
| learning_rate: float = 1e-4 |
| weight_decay: float = 0.01 |
| warmup_steps: int = 1000 |
| max_steps: int = 100000 |
| grad_clip: float = 1.0 |
| |
    # EMA
| ema_decay: float = 0.9999 |
| |
    # Flow matching objective
| prediction_target: str = 'velocity' |
| noise_schedule: str = 'laplace' |
| |
    # Progressive resolution training
| progressive_stages: list = field(default_factory=lambda: [ |
| {'resolution': 256, 'steps': 50000}, |
| {'resolution': 512, 'steps': 30000}, |
| {'resolution': 1024, 'steps': 20000}, |
| ]) |
| |
    # Curriculum over timesteps
| use_curriculum: bool = True |
| curriculum_warmup: int = 10000 |
| |
    # Logging and checkpointing
| log_every: int = 100 |
| save_every: int = 5000 |
| sample_every: int = 2500 |
| |
    # Performance
| mixed_precision: str = 'bf16' |
| compile_model: bool = False |
| |
    # Data
| dataset_name: str = '' |
| num_workers: int = 4 |
| |
    # Output and Hugging Face Hub
| output_dir: str = './lira_output' |
| hub_model_id: str = '' |
| push_to_hub: bool = True |
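

# Progressive resolution training (stage 3 in the module docstring) is only
# described by `progressive_stages`; the stage lookup itself is not implemented
# in this file. A minimal sketch (`resolution_for_step` is a helper name
# introduced here, not part of the original pipeline):
def resolution_for_step(config: LiRATrainingConfig, global_step: int) -> int:
    """Return the training resolution for a given global step."""
    boundary = 0
    for stage in config.progressive_stages:
        boundary += stage['steps']
        if global_step < boundary:
            return stage['resolution']
    # Past all stages: stay at the final (highest) resolution
    return config.progressive_stages[-1]['resolution']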
|
|
|
|
| class FlowMatchingScheduler: |
| """ |
| Flow Matching noise scheduler with Laplace distribution. |
| |
| Flow matching interpolation: |
| z_t = (1 - t) * z_0 + t * ε where ε ~ N(0, I) |
| v_t = ε - z_0 (velocity) |
| |
    Laplace noise schedule (from "Improved Noise Schedule"):
        t = sigmoid(x) with x ~ Laplace(μ=0, b=1)
    This concentrates samples around t=0.5 (logSNR ≈ 0), where learning is
    most effective.
| """ |
| |
| def __init__(self, schedule: str = 'laplace', shift: float = 1.0): |
| self.schedule = schedule |
| self.shift = shift |
| |
| def sample_timesteps(self, batch_size: int, device: torch.device, |
| curriculum_progress: float = 1.0) -> torch.Tensor: |
| """ |
| Sample timesteps from the noise schedule. |
| |
        curriculum_progress: ramps 0 → 1 over training. At 0, only easy
            high-noise timesteps (t in [0.5, 1]) are sampled; at 1, the full range.
| """ |
        if self.schedule == 'laplace':
            # Inverse-CDF sample from Laplace(μ=0, b=1); the 1e-8 guards log(0)
            u = torch.rand(batch_size, device=device)
            t = -torch.sign(u - 0.5) * torch.log(1 - 2 * torch.abs(u - 0.5) + 1e-8)
            # Squash to (0, 1); mass concentrates around t = 0.5 (logSNR ≈ 0)
            t = torch.sigmoid(t)
| |
        elif self.schedule == 'logit_normal':
            # Logit-normal schedule (SD3): sigmoid of a standard Gaussian
            t = torch.sigmoid(torch.randn(batch_size, device=device))

        else:
            # Uniform fallback
            t = torch.rand(batch_size, device=device)
| |
        # SD3-style timestep shift: shift > 1 biases sampling toward higher
        # noise levels, which helps at larger resolutions
        if self.shift != 1.0:
            t = t * self.shift / (1 + (self.shift - 1) * t)
| |
        # Curriculum: early in training, restrict to the easy high-noise half
        # (t in [0.5, 1]); the lower bound relaxes to 0 as progress reaches 1
        if curriculum_progress < 1.0:
            min_t = 0.5 * (1 - curriculum_progress)
            t = min_t + t * (1 - min_t)

        # Avoid the degenerate endpoints t = 0 and t = 1
        t = t.clamp(1e-5, 1 - 1e-5)
| |
| return t |
| |
| def add_noise(self, z_0: torch.Tensor, t: torch.Tensor, |
| noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]: |
| """ |
| Flow matching interpolation: z_t = (1-t)*z_0 + t*ε |
| |
| Returns: (z_t, noise) |
| """ |
| if noise is None: |
| noise = torch.randn_like(z_0) |
| |
| t = t.view(-1, 1, 1, 1) |
| z_t = (1 - t) * z_0 + t * noise |
| |
| return z_t, noise |
| |
| def get_velocity(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor: |
| """Compute velocity target: v = ε - z_0""" |
| return noise - z_0 |
| |
| def predict_z0(self, z_t: torch.Tensor, v_pred: torch.Tensor, |
| t: torch.Tensor) -> torch.Tensor: |
| """Recover z_0 from z_t and predicted velocity""" |
        t = t.view(-1, 1, 1, 1)
        # With z_t = (1 - t) * z_0 + t * ε and v = ε - z_0:
        #     z_t = z_0 + t * v
        # so the clean latent is recovered as z_0 = z_t - t * v
        return z_t - t * v_pred
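

# Quick sanity check (illustrative, not part of the training pipeline): with
# the true velocity v = ε - z_0, predict_z0 must recover z_0 exactly at any t.
def _check_flow_identities():
    sched = FlowMatchingScheduler(schedule='laplace')
    z_0 = torch.randn(2, 4, 32, 32)
    t = sched.sample_timesteps(2, torch.device('cpu'))
    z_t, noise = sched.add_noise(z_0, t)
    v = sched.get_velocity(z_0, noise)
    assert torch.allclose(sched.predict_z0(z_t, v, t), z_0, atol=1e-5)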
|
|
|
|
| class EMAModel: |
| """Exponential Moving Average of model parameters""" |
| |
| def __init__(self, model: nn.Module, decay: float = 0.9999): |
| self.decay = decay |
| self.shadow = {} |
| self.backup = {} |
| |
| for name, param in model.named_parameters(): |
| if param.requires_grad: |
| self.shadow[name] = param.data.clone() |
| |
| @torch.no_grad() |
| def update(self, model: nn.Module): |
| for name, param in model.named_parameters(): |
| if param.requires_grad and name in self.shadow: |
| self.shadow[name] = ( |
| self.decay * self.shadow[name] + (1 - self.decay) * param.data |
| ) |
| |
| def apply_shadow(self, model: nn.Module): |
| """Replace model params with EMA params""" |
| for name, param in model.named_parameters(): |
| if param.requires_grad and name in self.shadow: |
| self.backup[name] = param.data |
| param.data = self.shadow[name] |
| |
| def restore(self, model: nn.Module): |
| """Restore original model params""" |
| for name, param in model.named_parameters(): |
| if param.requires_grad and name in self.backup: |
| param.data = self.backup[name] |
| self.backup = {} |
| |
| def state_dict(self): |
| return self.shadow |
| |
| def load_state_dict(self, state_dict): |
| self.shadow = state_dict |
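

# The module docstring promises EDM2-style post-hoc EMA tuning, which is not
# implemented above. A minimal stand-in (an assumption, simpler than EDM2's
# power-function EMA reconstruction): keep several EMA copies at different
# decays during training, then pick the best one by a validation metric that
# you supply via the hypothetical callable `eval_fn`.
class MultiEMA:
    def __init__(self, model: nn.Module, decays=(0.999, 0.9995, 0.9999)):
        self.emas = {d: EMAModel(model, decay=d) for d in decays}

    def update(self, model: nn.Module):
        for ema in self.emas.values():
            ema.update(model)

    def select_best(self, model: nn.Module, eval_fn) -> float:
        """Return the decay whose EMA weights minimize eval_fn(model)."""
        best_decay, best_loss = None, float('inf')
        for decay, ema in self.emas.items():
            ema.apply_shadow(model)
            loss = eval_fn(model)
            ema.restore(model)
            if loss < best_loss:
                best_decay, best_loss = decay, loss
        return best_decay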
|
|
|
|
| def compute_loss( |
| model: nn.Module, |
| z_0: torch.Tensor, |
| text_features: torch.Tensor, |
| scheduler: FlowMatchingScheduler, |
| config: LiRATrainingConfig, |
| global_step: int = 0, |
| text_mask: Optional[torch.Tensor] = None, |
| ) -> Tuple[torch.Tensor, Dict]: |
| """ |
| Compute training loss. |
| |
| Loss = ||v_pred - v_target||^2 (MSE on velocity prediction) |
| |
    With optional reasoning regularization (rewards high stop probabilities,
    encouraging the model to use fewer adaptive-compute steps).
| """ |
| device = z_0.device |
| B = z_0.shape[0] |
| |
    # Curriculum progress ramps linearly from 0 to 1 over `curriculum_warmup` steps
    if config.use_curriculum:
        curriculum_progress = min(1.0, global_step / config.curriculum_warmup)
    else:
        curriculum_progress = 1.0

    # Sample timesteps and build the noisy interpolant z_t
    t = scheduler.sample_timesteps(B, device, curriculum_progress)
    z_t, noise = scheduler.add_noise(z_0, t)

    # Velocity target: v = ε - z_0
    v_target = scheduler.get_velocity(z_0, noise)

    # Forward pass; the model also reports adaptive-compute statistics
    v_pred, reason_info = model(z_t, t, text_features, text_mask)

    # Primary objective: MSE on the predicted velocity
    mse_loss = F.mse_loss(v_pred, v_target)
    loss = mse_loss

    # Reasoning regularization: a small penalty on low stop probabilities,
    # nudging the model to halt its adaptive-compute loop earlier
    if reason_info.get('total_steps', 0) > 0 and len(reason_info.get('stop_values', [])) > 0:
        avg_stop = sum(reason_info['stop_values']) / len(reason_info['stop_values'])
        reason_reg = 0.01 * (1.0 - avg_stop)
        loss = loss + reason_reg

    info = {
        'loss': loss.item(),
        'mse_loss': mse_loss.item(),
        'reason_steps': reason_info.get('total_steps', 0),
    }
| |
| return loss, info |
|
|
|
|
| def get_lr_scheduler(optimizer, config: LiRATrainingConfig): |
| """Warmup + cosine decay learning rate schedule""" |
| |
    def lr_lambda(step):
        if step < config.warmup_steps:
            # Linear warmup from 0 to the base learning rate
            return step / max(1, config.warmup_steps)
        # Cosine decay to 0 over the remaining steps (clamped past max_steps)
        progress = (step - config.warmup_steps) / max(1, config.max_steps - config.warmup_steps)
        return 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))
| |
| return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) |
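

# A minimal sketch of how the stability pieces wire together (AdamW + warmup/
# cosine schedule + gradient clipping + EMA + bf16 autocast). The batch layout
# is an assumption, and `train_step` is a name introduced here, not part of the
# original pipeline. Construct the optimizer per the config, e.g.:
#   optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate,
#                                 weight_decay=config.weight_decay)
def train_step(model, batch, optimizer, lr_sched, ema, scheduler, config, step):
    z_0, text_features, text_mask = batch
    # bf16 autocast; unlike fp16, bf16 needs no GradScaler
    with torch.autocast('cuda', dtype=torch.bfloat16,
                        enabled=(config.mixed_precision == 'bf16')):
        loss, info = compute_loss(model, z_0, text_features, scheduler,
                                  config, global_step=step, text_mask=text_mask)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    # Gradient clipping for stability (max_norm = config.grad_clip = 1.0)
    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
    optimizer.step()
    lr_sched.step()
    ema.update(model)
    return info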
|
|
|
|
| |
| |
| |
|
|
| class FlowDPMSolver: |
| """ |
| DPM-Solver adapted for flow matching. |
| |
    First-order (Euler) step:  z_{t+dt} = z_t + dt * v(z_t, t), with dt < 0
    Second-order step: a two-step Adams-Bashforth update that reuses the
        previous velocity, allowing convergence in fewer sampling steps
| |
| From SANA: "Flow-DPM-Solver converges at 14-20 steps vs 28-50 for Euler" |
| """ |
| |
| def __init__(self, num_steps: int = 20, order: int = 2): |
| self.num_steps = num_steps |
        self.order = min(order, 2)  # only first- and second-order updates are implemented
| |
| @torch.no_grad() |
| def sample( |
| self, |
| model: nn.Module, |
| shape: Tuple[int, ...], |
| text_features: torch.Tensor, |
| text_mask: Optional[torch.Tensor] = None, |
| cfg_scale: float = 4.0, |
| device: torch.device = torch.device('cpu'), |
| ) -> torch.Tensor: |
| """ |
| Generate samples using DPM-Solver. |
| |
        Args:
            model: LiRA model
            shape: (B, C, H, W) latent shape
            text_features: (B, M, D) text features
            text_mask: optional (B, M) text attention mask
            cfg_scale: classifier-free guidance scale (> 1 enables CFG)

        Returns:
            (B, C, H, W) sampled latents
        """
| B = shape[0] |
| |
        # Start from pure noise at t = 1
        z = torch.randn(shape, device=device)

        # Uniform time grid from t = 1 (noise) down to t = 0 (data)
        timesteps = torch.linspace(1, 0, self.num_steps + 1, device=device)
| |
| prev_v = None |
| |
| for i in range(self.num_steps): |
| t_cur = timesteps[i] |
| t_next = timesteps[i + 1] |
| dt = t_next - t_cur |
| |
| t_batch = t_cur.expand(B) |
| |
            # Velocity prediction, with classifier-free guidance if requested
            if cfg_scale > 1.0:
                v_pred = self._cfg_predict(model, z, t_batch, text_features, text_mask, cfg_scale)
            else:
                v_pred, _ = model(z, t_batch, text_features, text_mask)

            if self.order == 1 or prev_v is None:
                # First-order (Euler) step; also bootstraps the multistep update
                z = z + dt * v_pred
            else:
                # Second-order Adams-Bashforth step on the uniform grid:
                # extrapolate the velocity using the previous evaluation
                z = z + dt * (1.5 * v_pred - 0.5 * prev_v)

            prev_v = v_pred
| |
| return z |
| |
    def _cfg_predict(self, model, z, t, text_features, text_mask, cfg_scale):
        """Classifier-free guidance"""
        # Unconditional pass: zeroed text features stand in for the null prompt
        null_text = torch.zeros_like(text_features)
        v_uncond, _ = model(z, t, null_text, text_mask)

        # Conditional pass
        v_cond, _ = model(z, t, text_features, text_mask)

        # Guided velocity: move the conditional prediction further away from
        # the unconditional one by the guidance scale
        return v_uncond + cfg_scale * (v_cond - v_uncond)
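

# Usage sketch (illustrative): sample a batch of latents with CFG. `model` and
# `text_features` are assumed to come from the LiRA model and its text encoder;
# the latent shape follows the config defaults (4 channels, 8x compression, so
# 512px images correspond to 64x64 latents).
#
#   solver = FlowDPMSolver(num_steps=20, order=2)
#   latents = solver.sample(
#       model, shape=(1, 4, 64, 64),
#       text_features=text_features, cfg_scale=4.0,
#       device=torch.device('cuda'),
#   )
#   # Decode the latents with the paired VAE to obtain images.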
|
|