File size: 12,129 Bytes
f8a7028 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 | """
LiquidFlow Generator — Main diffusion model.
Combines:
- LiquidFlowBackbone (CfC + Mamba-2 SSD) as the noise predictor
- DDPM/DDIM diffusion process
- Physics-informed regularization
Supports:
- Training on 128×128 and 512×512 images
- TAESD VAE (lightweight, Colab/Kaggle compatible)
- SD VAE (higher quality)
- Both DDPM and DDIM sampling
The model is designed to be:
- Trainable on Google Colab free tier / Kaggle (T4 GPU, 15GB)
- Exportable to ONNX/CoreML for mobile deployment
- Pure PyTorch — no CUDA kernels needed (Mamba-2 SSD runs on CPU too)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from tqdm import tqdm
from typing import Optional, Dict, Tuple
from .liquid_flow_block import LiquidFlowBackbone
from .physics_loss import PhysicsRegularizer, DDIMEstimator
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Return the linear DDPM noise schedule.

    Args:
        timesteps: Number of betas to produce.
        beta_start: First beta value.
        beta_end: Last beta value.

    Returns:
        1-D tensor of ``timesteps`` betas evenly spaced from ``beta_start``
        to ``beta_end`` (both inclusive).
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule
def cosine_beta_schedule(timesteps, s=0.008, min_beta=0.0001, max_beta=0.9999):
    """Cosine noise schedule (Improved DDPM).

    alpha_bar(t) follows a squared cosine of ``(t / T + s) / (1 + s)``,
    normalised so that alpha_bar(0) == 1. Each beta is
    ``1 - alpha_bar(t) / alpha_bar(t - 1)``.

    Args:
        timesteps: Number of betas to produce (T).
        s: Small offset that keeps the first betas away from zero.
        min_beta: Lower clip bound applied to every beta.
        max_beta: Upper clip bound applied to every beta. The final beta is
            ~1 before clipping, so this bound controls how extreme the last
            reverse step is. Improved DDPM uses 0.999; the default 0.9999
            preserves this module's historical behaviour.

    Returns:
        1-D tensor of ``timesteps`` betas clipped to ``[min_beta, max_beta]``.
    """
    steps = timesteps + 1
    x = torch.linspace(0, timesteps, steps)
    alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
    return torch.clip(betas, min_beta, max_beta)
class LiquidFlowGenerator(nn.Module):
    """
    LiquidFlow Generator: Liquid Neural Network + Mamba-2 SSD diffusion model.

    Uses LiquidFlowBackbone as the noise predictor in a DDPM/DDIM framework.

    Architecture:
        Noise predictor = LiquidFlowBackbone (CfC + Mamba-2 SSD)
        Diffusion       = DDPM (forward) + DDIM (sampling)
        Regularizer     = Physics-informed losses (TV, spectral, conservation)

    Args:
        in_channels: Latent channels from the VAE (default 4).
        hidden_dim: Hidden dimension in the backbone.
        num_stages: Number of LiquidFlow stages.
        blocks_per_stage: Blocks per stage.
        image_size: Target image size in pixels; sampled latents are
            ``image_size // 8`` on each side.
        beta_schedule: ``'linear'`` selects the linear schedule; any other
            value selects the cosine schedule.
        timesteps: Number of diffusion timesteps.
        physics_weights: Keyword arguments forwarded to PhysicsRegularizer.
            Defaults to ``{'tv': 0.01, 'cons': 0.001, 'spec': 0.01,
            'grad': 0.001}``.
    """

    def __init__(
        self,
        in_channels=4,
        hidden_dim=256,
        num_stages=4,
        blocks_per_stage=4,
        image_size=128,
        beta_schedule='cosine',
        timesteps=1000,
        physics_weights=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.image_size = image_size  # sampled latent side = image_size // 8
        self.timesteps = timesteps

        # Noise predictor (backbone).
        self.backbone = LiquidFlowBackbone(
            in_channels=in_channels,
            hidden_dim=hidden_dim,
            num_stages=num_stages,
            blocks_per_stage=blocks_per_stage,
            d_state=16,
            expand=2,
            dropout=0.0,
        )

        # Diffusion schedule, stored as buffers so it follows the module's
        # device and is saved with the state dict without being trained.
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        else:
            betas = cosine_beta_schedule(timesteps)
        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1.0 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))
        # alpha_bar_{t-1}, with alpha_bar_{-1} defined as 1.
        self.register_buffer('alphas_cumprod_prev', F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0))
        # Precomputed square roots used by q_sample.
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))

        # Physics regularizer (None -> defaults; avoids a shared mutable default).
        if physics_weights is None:
            physics_weights = {'tv': 0.01, 'cons': 0.001, 'spec': 0.01, 'grad': 0.001}
        self.physics = PhysicsRegularizer(**physics_weights)
        self.ddim_estimator = DDIMEstimator()

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: sample x_t ~ q(x_t | x_0).

        x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps

        Args:
            x0: Clean latents; the per-sample coefficients are reshaped to
                (-1, 1, 1, 1), so x0 is expected to be 4-D [B, C, H, W].
            t: Integer timestep indices, one per batch element.
            noise: Optional eps; drawn from N(0, I) when omitted.

        Returns:
            Tuple (x_t, noise).
        """
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_alpha_bar = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        return sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise, noise

    def forward(self, x, t):
        """Predict the noise in ``x`` at timesteps ``t`` with the backbone."""
        return self.backbone(x, t)

    def _compute_losses(self, x0, xt, t, noise):
        """Predict noise for ``xt`` and return (total, diffusion, physics, physics_terms)."""
        noise_pred = self.forward(xt, t)
        # Base diffusion loss: MSE between predicted and true noise.
        diffusion_loss = F.mse_loss(noise_pred, noise)
        # Physics regularization acts on the x0 estimate implied by noise_pred.
        x0_hat = self.ddim_estimator.estimate_x0(
            xt, noise_pred, self.alphas_cumprod[t]
        )
        phys_loss, phys_dict = self.physics(x0_hat, x0)
        total_loss = diffusion_loss + phys_loss
        return total_loss, diffusion_loss, phys_loss, phys_dict

    def training_step(self, x0, optimizer, scaler=None, use_amp=False):
        """
        Run one optimisation step with physics regularization.

        Timesteps are drawn uniformly from [0, timesteps), x0 is noised with
        q_sample, and the loss is noise-MSE plus the physics regularizer.
        Gradients are clipped to a global norm of 1.0 before stepping.

        Args:
            x0: Clean latents [B, C, H, W].
            optimizer: Optimizer over this module's parameters.
            scaler: Optional GradScaler. When given, it scales the backward
                pass and drives the optimizer step.
            use_amp: Run the forward pass under CUDA autocast. Only takes
                effect when ``scaler`` is also given.

        Returns:
            Dict of Python floats with keys 'total', 'diffusion', 'physics'
            and one 'phys_<name>' entry per term returned by the regularizer.
        """
        batch_size = x0.shape[0]
        t = torch.randint(0, self.timesteps, (batch_size,), device=x0.device)
        noise = torch.randn_like(x0)
        xt, noise = self.q_sample(x0, t, noise)

        if use_amp and scaler is not None:
            # torch.autocast(device_type='cuda') is the non-deprecated
            # spelling of torch.cuda.amp.autocast().
            with torch.autocast(device_type='cuda'):
                total_loss, diffusion_loss, phys_loss, phys_dict = self._compute_losses(x0, xt, t, noise)
        else:
            total_loss, diffusion_loss, phys_loss, phys_dict = self._compute_losses(x0, xt, t, noise)

        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(total_loss).backward()
            # Unscale first so clipping sees the true gradient magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            optimizer.step()

        return {
            'total': total_loss.item(),
            'diffusion': diffusion_loss.item(),
            'physics': phys_loss.item(),
            **{f'phys_{k}': v.item() for k, v in phys_dict.items()},
        }

    @torch.no_grad()
    def sample(self, batch_size=4, steps=50, ddim=True, eta=0.0, progress=True):
        """
        Generate latents with DDIM or DDPM sampling.

        Args:
            batch_size: Number of samples.
            steps: DDIM sampling steps (may be much smaller than timesteps);
                ignored for DDPM, which always runs all timesteps.
            ddim: Use DDIM sampling (faster) instead of DDPM.
            eta: DDIM stochasticity (0 = deterministic).
            progress: Show a progress bar.

        Returns:
            Generated latents [B, in_channels, image_size // 8, image_size // 8].
        """
        device = next(self.parameters()).device
        latent_size = self.image_size // 8
        # Start from pure Gaussian noise.
        x = torch.randn(batch_size, self.in_channels, latent_size, latent_size, device=device)
        if ddim:
            return self._ddim_sample(x, steps, eta, progress)
        return self._ddpm_sample(x, progress)

    @torch.no_grad()
    def _ddpm_sample(self, x, progress=True):
        """Ancestral DDPM sampling over every timestep, with sigma_t^2 = beta_t."""
        device = x.device
        iterator = tqdm(
            reversed(range(0, self.timesteps)),
            desc='DDPM Sampling',
            total=self.timesteps,
            disable=not progress,
        )
        for t_idx in iterator:
            t = torch.full((x.shape[0],), t_idx, device=device, dtype=torch.long)
            noise_pred = self.forward(x, t)
            alpha = self.alphas[t_idx]
            alpha_bar = self.alphas_cumprod[t_idx]
            beta = self.betas[t_idx]
            # Posterior mean: (x - beta / sqrt(1 - alpha_bar) * eps) / sqrt(alpha).
            mean = (1 / torch.sqrt(alpha)) * (
                x - (beta / torch.sqrt(1 - alpha_bar)) * noise_pred
            )
            if t_idx > 0:
                x = mean + torch.sqrt(beta) * torch.randn_like(x)
            else:
                # No noise is added on the final step.
                x = mean
        return x

    @torch.no_grad()
    def _ddim_sample(self, x, steps=50, eta=0.0, progress=True):
        """
        DDIM sampling over an evenly strided subset of timesteps.

        DDIM can produce good samples in 20-50 steps instead of the full
        DDPM chain. The stride is ``max(timesteps // steps, 1)``, so the loop
        may run slightly more than ``steps`` iterations when ``steps`` does
        not divide ``timesteps``, and runs every timestep when
        ``steps > timesteps``.

        Args:
            x: Starting noise.
            steps: Requested number of sampling steps (>= 1).
            eta: Stochasticity; 0 is deterministic DDIM, 1 matches the DDPM
                posterior variance.
            progress: Show a progress bar.

        Raises:
            ValueError: If ``steps < 1``.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        device = x.device

        # Timestep spacing; a stride of at least 1 keeps range() valid.
        skip = max(self.timesteps // steps, 1)
        seq = list(range(0, self.timesteps, skip))
        seq_next = [-1] + seq[:-1]
        iterator = tqdm(
            zip(reversed(seq), reversed(seq_next)),
            desc='DDIM Sampling',
            total=len(seq),
            disable=not progress,
        )
        alpha_bar_final = torch.tensor(1.0, device=device)

        for i, j in iterator:
            t = torch.full((x.shape[0],), i, device=device, dtype=torch.long)
            noise_pred = self.forward(x, t)
            alpha_bar_i = self.alphas_cumprod[i]
            # j == -1 marks the last step, which targets alpha_bar = 1 (clean x0).
            alpha_bar_j = self.alphas_cumprod[j] if j >= 0 else alpha_bar_final

            # Predicted x0.
            x0_pred = (x - torch.sqrt(1 - alpha_bar_i) * noise_pred) / torch.sqrt(alpha_bar_i)
            # NOTE(review): clamping to [-1, 1] assumes latents lie in that
            # range; confirm against the VAE latent scaling in use.
            x0_pred = torch.clamp(x0_pred, -1, 1)

            # sigma_t(eta) from DDIM; the same sigma must appear in the
            # direction term, otherwise 1 - alpha_bar_j - sigma^2 can go
            # negative and produce NaNs for eta > 0.
            if eta > 0:
                sigma = eta * torch.sqrt(
                    (1 - alpha_bar_j) / (1 - alpha_bar_i) * (1 - alpha_bar_i / alpha_bar_j)
                )
            else:
                sigma = 0.0

            # Direction pointing to x_t (clamped against tiny negative round-off).
            dir_xt = torch.sqrt(torch.clamp(1 - alpha_bar_j - sigma ** 2, min=0.0)) * noise_pred
            x = torch.sqrt(alpha_bar_j) * x0_pred + dir_xt
            if eta > 0:
                x = x + sigma * torch.randn_like(x)
        return x

    def count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
def create_liquidflow(
    variant='small',
    image_size=128,
    **kwargs,
):
    """
    Create a LiquidFlow model from a preset configuration.

    Variants:
        - 'tiny':  ~2M params, 2 stages, 2 blocks each, hidden_dim=128
        - 'small': ~8M params, 4 stages, 4 blocks each, hidden_dim=256
        - 'base':  ~30M params, 6 stages, 6 blocks each, hidden_dim=384

    All are designed to run on a T4 (15GB) with batch_size >= 16 at 128x128.

    Args:
        variant: Preset name. Unknown names fall back to 'small' and emit a
            warning.
        image_size: Target image size forwarded to LiquidFlowGenerator.
        **kwargs: Any LiquidFlowGenerator keyword argument; overrides the
            preset. ``in_channels`` defaults to 4 (VAE latent channels) but
            may be overridden here as well.

    Returns:
        A LiquidFlowGenerator instance.
    """
    import warnings

    configs = {
        'tiny': {
            'hidden_dim': 128,
            'num_stages': 2,
            'blocks_per_stage': 2,
        },
        'small': {
            'hidden_dim': 256,
            'num_stages': 4,
            'blocks_per_stage': 4,
        },
        'base': {
            'hidden_dim': 384,
            'num_stages': 6,
            'blocks_per_stage': 6,
        },
    }
    if variant not in configs:
        warnings.warn(
            f"Unknown LiquidFlow variant {variant!r}; falling back to 'small'.",
            stacklevel=2,
        )
    # Merge defaults, preset and caller overrides into one dict so that an
    # explicit in_channels in kwargs overrides the default instead of raising
    # a duplicate-keyword TypeError.
    config = {
        'in_channels': 4,  # VAE latent channels
        **configs.get(variant, configs['small']),
        **kwargs,
    }
    model = LiquidFlowGenerator(
        image_size=image_size,
        **config,
    )
    return model
|