# lira/training.py
"""
LiRA Training Pipeline
Training Strategy:
==================
1. Flow Matching with v-prediction (from SANA/SD3)
- More stable than epsilon prediction near t=T
- Better gradients throughout the diffusion process
2. Laplace Noise Schedule (from "Improved Noise Schedule for Diffusion")
- Concentrates sampling around logSNR=0
- Better FID than cosine/linear schedules
3. Progressive Resolution Training (from SANA)
- Start at 256px → 512px → 1024px
- Each stage uses the previous as initialization
4. Curriculum Learning (from "Curriculum Learning for Diffusion")
- Easy timesteps first (high noise), hard timesteps later (low noise)
5. EMA with post-hoc tuning (from EDM2)
- EMA decay 0.9999 during training
- Post-hoc search for optimal EMA length
Training Stability:
===================
- Gradient clipping (max_norm=1.0)
- AdamW with weight decay 0.01
- Warmup + cosine decay learning rate
- AdaLN-Zero initialization (network acts as identity at start)
- Loss scaling: velocity prediction is naturally bounded
- Mixed precision (bf16 by default; gradient scaling only when falling back to fp16)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import os
from typing import Optional, Dict, Tuple
from dataclasses import dataclass, field
@dataclass
class LiRATrainingConfig:
"""Training configuration with sensible defaults for Colab-friendly training"""
# Model
model_config: str = 'tiny' # Start small for testing
latent_channels: int = 4 # SD1.x/SDXL VAE
spatial_compression: int = 8
d_text: int = 768
patch_size: int = 2 # 2x2 patches for f8 VAE (128x128 → 64x64 tokens)
# Training
batch_size: int = 8
learning_rate: float = 1e-4
weight_decay: float = 0.01
warmup_steps: int = 1000
max_steps: int = 100000
grad_clip: float = 1.0
# EMA
ema_decay: float = 0.9999
# Flow matching
prediction_target: str = 'velocity' # 'velocity' or 'epsilon'
noise_schedule: str = 'laplace' # 'laplace', 'logit_normal', or 'uniform'
# Progressive resolution
progressive_stages: list = field(default_factory=lambda: [
{'resolution': 256, 'steps': 50000},
{'resolution': 512, 'steps': 30000},
{'resolution': 1024, 'steps': 20000},
])
# Curriculum
use_curriculum: bool = True
curriculum_warmup: int = 10000 # Steps before full timestep range
# Logging
log_every: int = 100
save_every: int = 5000
sample_every: int = 2500
# Hardware
mixed_precision: str = 'bf16' # 'bf16', 'fp16', or 'no'
compile_model: bool = False # torch.compile (if available)
# Data
dataset_name: str = ''
num_workers: int = 4
# Output
output_dir: str = './lira_output'
hub_model_id: str = ''
push_to_hub: bool = True
class FlowMatchingScheduler:
"""
Flow Matching noise scheduler with Laplace distribution.
Flow matching interpolation:
z_t = (1 - t) * z_0 + t * ε where ε ~ N(0, I)
v_t = ε - z_0 (velocity)
Laplace noise schedule (from "Improved Noise Schedule"):
t ~ Laplace(μ=0, b=1), mapped to [0, 1] via CDF
This concentrates samples around t=0.5 where learning is most effective.
"""
def __init__(self, schedule: str = 'laplace', shift: float = 1.0):
self.schedule = schedule
self.shift = shift # For resolution-dependent shifting (from SD3)
def sample_timesteps(self, batch_size: int, device: torch.device,
curriculum_progress: float = 1.0) -> torch.Tensor:
"""
Sample timesteps from the noise schedule.
curriculum_progress: 0→1 over training. At 0, only easy timesteps (near 1.0).
At 1.0, full range.
"""
        if self.schedule == 'laplace':
            # Laplace distribution centered at 0, mapped to (0, 1)
            u = torch.rand(batch_size, device=device)
            # Inverse Laplace CDF (μ=0, b=1): t = -sign(u - 0.5) * log(1 - 2|u - 0.5|)
            t = -torch.sign(u - 0.5) * torch.log(1 - 2 * torch.abs(u - 0.5) + 1e-8)
            # Map from (-inf, inf) to (0, 1) via sigmoid; μ=0 keeps the mode at t=0.5
            t = torch.sigmoid(t)
elif self.schedule == 'logit_normal':
# Logit-normal (from SD3): sample from N(0,1) then sigmoid
t = torch.sigmoid(torch.randn(batch_size, device=device))
else: # uniform
t = torch.rand(batch_size, device=device)
# Apply resolution-dependent shift (from SD3)
# Higher shift → more weight on higher noise levels
if self.shift != 1.0:
t = t * self.shift / (1 + (self.shift - 1) * t)
# Curriculum: restrict to easier timesteps early in training
if curriculum_progress < 1.0:
min_t = 0.5 * (1 - curriculum_progress) # Start from t>0.5, expand to t>0
t = min_t + t * (1 - min_t)
# Clamp for numerical stability
t = t.clamp(1e-5, 1 - 1e-5)
return t
def add_noise(self, z_0: torch.Tensor, t: torch.Tensor,
noise: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Flow matching interpolation: z_t = (1-t)*z_0 + t*ε
Returns: (z_t, noise)
"""
if noise is None:
noise = torch.randn_like(z_0)
t = t.view(-1, 1, 1, 1) # Broadcast over spatial dims
z_t = (1 - t) * z_0 + t * noise
return z_t, noise
def get_velocity(self, z_0: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
"""Compute velocity target: v = ε - z_0"""
return noise - z_0
def predict_z0(self, z_t: torch.Tensor, v_pred: torch.Tensor,
t: torch.Tensor) -> torch.Tensor:
"""Recover z_0 from z_t and predicted velocity"""
t = t.view(-1, 1, 1, 1)
        # With z_t = (1-t)*z_0 + t*ε and v = ε - z_0:
        #   z_t = (1-t)*z_0 + t*(z_0 + v) = z_0 + t*v
        # so z_0 = z_t - t*v_pred
return z_t - t * v_pred
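
# The algebra above is easy to check numerically. The helper below is a minimal
# sketch (not used by the training pipeline) that verifies add_noise, get_velocity
# and predict_z0 are mutually consistent on random tensors.
def _sanity_check_flow_matching(batch_size: int = 2) -> None:
    """Check that z_0 is exactly recovered from (z_t, v) under the linear interpolation."""
    scheduler = FlowMatchingScheduler(schedule='laplace')
    z_0 = torch.randn(batch_size, 4, 32, 32)
    t = scheduler.sample_timesteps(batch_size, z_0.device)
    z_t, noise = scheduler.add_noise(z_0, t)
    v = scheduler.get_velocity(z_0, noise)
    z_0_rec = scheduler.predict_z0(z_t, v, t)
    assert torch.allclose(z_0, z_0_rec, atol=1e-5), "flow matching identities violated"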
class EMAModel:
"""Exponential Moving Average of model parameters"""
def __init__(self, model: nn.Module, decay: float = 0.9999):
self.decay = decay
self.shadow = {}
self.backup = {}
for name, param in model.named_parameters():
if param.requires_grad:
self.shadow[name] = param.data.clone()
@torch.no_grad()
def update(self, model: nn.Module):
for name, param in model.named_parameters():
if param.requires_grad and name in self.shadow:
self.shadow[name] = (
self.decay * self.shadow[name] + (1 - self.decay) * param.data
)
def apply_shadow(self, model: nn.Module):
"""Replace model params with EMA params"""
for name, param in model.named_parameters():
if param.requires_grad and name in self.shadow:
self.backup[name] = param.data
param.data = self.shadow[name]
def restore(self, model: nn.Module):
"""Restore original model params"""
for name, param in model.named_parameters():
if param.requires_grad and name in self.backup:
param.data = self.backup[name]
self.backup = {}
def state_dict(self):
return self.shadow
def load_state_dict(self, state_dict):
self.shadow = state_dict
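
# A small usage sketch for EMAModel (hypothetical helper, not referenced elsewhere
# in this file): run a callable with the EMA weights temporarily swapped in, then
# restore the live training weights even if the callable raises.
def _with_ema_weights(model: nn.Module, ema: EMAModel, fn):
    """Run `fn(model)` with EMA weights applied, restoring the original weights afterwards."""
    ema.apply_shadow(model)
    try:
        return fn(model)
    finally:
        ema.restore(model)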
def compute_loss(
model: nn.Module,
z_0: torch.Tensor,
text_features: torch.Tensor,
scheduler: FlowMatchingScheduler,
config: LiRATrainingConfig,
global_step: int = 0,
text_mask: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Dict]:
"""
Compute training loss.
Loss = ||v_pred - v_target||^2 (MSE on velocity prediction)
With optional:
- Reasoning regularization (encourage adaptive compute)
- Frequency-weighted loss (higher weight on low-frequency errors)
"""
device = z_0.device
B = z_0.shape[0]
# Curriculum progress
if config.use_curriculum:
curriculum_progress = min(1.0, global_step / config.curriculum_warmup)
else:
curriculum_progress = 1.0
# Sample timesteps
t = scheduler.sample_timesteps(B, device, curriculum_progress)
# Add noise
z_t, noise = scheduler.add_noise(z_0, t)
# Get velocity target
v_target = scheduler.get_velocity(z_0, noise)
# Forward pass
v_pred, reason_info = model(z_t, t, text_features, text_mask)
# MSE loss on velocity
loss = F.mse_loss(v_pred, v_target)
# Reasoning regularization: encourage variable thinking steps
# Small penalty to discourage always using max steps
if reason_info.get('total_steps', 0) > 0 and len(reason_info.get('stop_values', [])) > 0:
avg_stop = sum(reason_info['stop_values']) / len(reason_info['stop_values'])
# Encourage the stop gate to actually stop sometimes
reason_reg = 0.01 * (1.0 - avg_stop) # Small penalty
loss = loss + reason_reg
info = {
'loss': loss.item(),
'mse_loss': F.mse_loss(v_pred, v_target).item(),
'reason_steps': reason_info.get('total_steps', 0),
}
return loss, info
def get_lr_scheduler(optimizer, config: LiRATrainingConfig):
"""Warmup + cosine decay learning rate schedule"""
def lr_lambda(step):
if step < config.warmup_steps:
return step / config.warmup_steps
else:
progress = (step - config.warmup_steps) / (config.max_steps - config.warmup_steps)
return 0.5 * (1 + math.cos(math.pi * progress))
return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
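
# The pieces above compose into a single optimization step roughly as sketched below.
# This is a simplified illustration rather than the full trainer: it assumes the caller
# has built `model` (matching the model(z_t, t, text_features, text_mask) -> (v_pred,
# reason_info) interface used by compute_loss), an AdamW optimizer with
# config.learning_rate / config.weight_decay, the LR scheduler from get_lr_scheduler,
# and an EMAModel. Data loading, logging, checkpointing, fp16 gradient scaling and
# Hub pushes are omitted.
def _training_step_sketch(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    lr_scheduler,
    ema: EMAModel,
    scheduler: FlowMatchingScheduler,
    config: LiRATrainingConfig,
    z_0: torch.Tensor,
    text_features: torch.Tensor,
    global_step: int,
    text_mask: Optional[torch.Tensor] = None,
) -> Dict:
    """One optimization step: autocast forward, velocity loss, clipping, EMA update."""
    model.train()
    optimizer.zero_grad(set_to_none=True)
    if config.mixed_precision == 'bf16':
        autocast_dtype = torch.bfloat16
    elif config.mixed_precision == 'fp16':
        autocast_dtype = torch.float16  # a real fp16 setup would also use a GradScaler
    else:
        autocast_dtype = torch.float32
    with torch.autocast(device_type=z_0.device.type, dtype=autocast_dtype,
                        enabled=config.mixed_precision != 'no'):
        loss, info = compute_loss(model, z_0, text_features, scheduler, config,
                                  global_step=global_step, text_mask=text_mask)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
    optimizer.step()
    lr_scheduler.step()
    ema.update(model)
    return info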
# ============================================================================
# DPM-Solver for fast sampling (from SANA's Flow-DPM-Solver)
# ============================================================================
class FlowDPMSolver:
"""
DPM-Solver adapted for flow matching.
Standard Euler: z_{t-dt} = z_t - dt * v(z_t, t)
DPM-Solver-2: Second-order correction for fewer steps
From SANA: "Flow-DPM-Solver converges at 14-20 steps vs 28-50 for Euler"
"""
def __init__(self, num_steps: int = 20, order: int = 2):
self.num_steps = num_steps
self.order = min(order, 2)
@torch.no_grad()
def sample(
self,
model: nn.Module,
shape: Tuple[int, ...],
text_features: torch.Tensor,
text_mask: Optional[torch.Tensor] = None,
cfg_scale: float = 4.0,
device: torch.device = torch.device('cpu'),
) -> torch.Tensor:
"""
Generate samples using DPM-Solver.
Args:
model: LiRA model
shape: (B, C, H, W) latent shape
text_features: (B, M, D) text features
cfg_scale: classifier-free guidance scale
"""
B = shape[0]
# Start from pure noise (t=1)
z = torch.randn(shape, device=device)
# Time steps from 1 → 0
timesteps = torch.linspace(1, 0, self.num_steps + 1, device=device)
prev_v = None
for i in range(self.num_steps):
t_cur = timesteps[i]
t_next = timesteps[i + 1]
dt = t_next - t_cur # Negative (going from 1 to 0)
t_batch = t_cur.expand(B)
# Predict velocity (with CFG if scale > 1)
if cfg_scale > 1.0:
v_pred = self._cfg_predict(model, z, t_batch, text_features, text_mask, cfg_scale)
else:
v_pred, _ = model(z, t_batch, text_features, text_mask)
if self.order == 1 or prev_v is None:
# Euler step
z = z + dt * v_pred
else:
                # Second-order multistep update (Adams-Bashforth style, as in
                # multistep DPM-Solver): extrapolate the velocity using the
                # previous step's prediction for a better approximation
z = z + dt * (1.5 * v_pred - 0.5 * prev_v)
prev_v = v_pred
return z
def _cfg_predict(self, model, z, t, text_features, text_mask, cfg_scale):
"""Classifier-free guidance"""
# Unconditional prediction (zero text)
null_text = torch.zeros_like(text_features)
v_uncond, _ = model(z, t, null_text, text_mask)
# Conditional prediction
v_cond, _ = model(z, t, text_features, text_mask)
# CFG
return v_uncond + cfg_scale * (v_cond - v_uncond)
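
# Sampling with a trained model is then a thin wrapper around FlowDPMSolver. The
# helper below is a minimal sketch, assuming precomputed text features of shape
# (B, M, d_text) and that VAE decoding of the returned latents happens elsewhere.
def _sample_sketch(
    model: nn.Module,
    ema: EMAModel,
    text_features: torch.Tensor,
    config: LiRATrainingConfig,
    resolution: int = 256,
    num_steps: int = 20,
    cfg_scale: float = 4.0,
) -> torch.Tensor:
    """Generate latents with the EMA weights swapped in, restoring the training weights after."""
    device = text_features.device
    latent_size = resolution // config.spatial_compression
    shape = (text_features.shape[0], config.latent_channels, latent_size, latent_size)
    solver = FlowDPMSolver(num_steps=num_steps, order=2)
    ema.apply_shadow(model)
    try:
        model.eval()
        latents = solver.sample(model, shape, text_features,
                                cfg_scale=cfg_scale, device=device)
    finally:
        ema.restore(model)
    return latents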