krystv commited on
Commit
f8a7028
·
verified ·
1 Parent(s): 1a3345c

Upload liquid_flow/generator.py

Browse files
Files changed (1) hide show
  1. liquid_flow/generator.py +363 -0
liquid_flow/generator.py ADDED
@@ -0,0 +1,363 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiquidFlow Generator — Main diffusion model.
3
+
4
+ Combines:
5
+ - LiquidFlowBackbone (CfC + Mamba-2 SSD) as the noise predictor
6
+ - DDPM/DDIM diffusion process
7
+ - Physics-informed regularization
8
+
9
+ Supports:
10
+ - Training on 128×128 and 512×512 images
11
+ - TAESD VAE (lightweight, Colab/Kaggle compatible)
12
+ - SD VAE (higher quality)
13
+ - Both DDPM and DDIM sampling
14
+
15
+ The model is designed to be:
16
+ - Trainable on Google Colab free tier / Kaggle (T4 GPU, 15GB)
17
+ - Exportable to ONNX/CoreML for mobile deployment
18
+ - Pure PyTorch — no CUDA kernels needed (Mamba-2 SSD runs on CPU too)
19
+ """
20
+
21
+ import torch
22
+ import torch.nn as nn
23
+ import torch.nn.functional as F
24
+ import math
25
+ import numpy as np
26
+ from tqdm import tqdm
27
+ from typing import Optional, Dict, Tuple
28
+
29
+ from .liquid_flow_block import LiquidFlowBackbone
30
+ from .physics_loss import PhysicsRegularizer, DDIMEstimator
31
+
32
+
33
+ def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
34
+ """Linear noise schedule (DDPM)."""
35
+ return torch.linspace(beta_start, beta_end, timesteps)
36
+
37
+
38
def cosine_beta_schedule(timesteps, s=0.008):
    """Cosine ᾱ schedule (Improved DDPM, Nichol & Dhariwal 2021).

    Builds ᾱ_t from a squared-cosine curve with small offset *s*, then
    recovers β_t = 1 - ᾱ_t / ᾱ_{t-1}, clipped away from 0 and 1 for
    numerical stability.
    """
    grid = torch.linspace(0, timesteps, timesteps + 1)
    alpha_bar = torch.cos(((grid / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2
    alpha_bar = alpha_bar / alpha_bar[0]  # normalize so ᾱ_0 == 1
    betas = 1 - (alpha_bar[1:] / alpha_bar[:-1])
    return betas.clip(0.0001, 0.9999)
46
+
47
+
48
class LiquidFlowGenerator(nn.Module):
    """
    LiquidFlow Generator: Liquid Neural Network + Mamba-2 SSD Diffusion Model.

    Uses LiquidFlowBackbone as the noise predictor in a DDPM/DDIM framework.

    Architecture:
        Noise Predictor = LiquidFlowBackbone (CfC + Mamba-2 SSD)
        Diffusion       = DDPM (forward) + DDPM/DDIM (sampling)
        Regularizer     = Physics-Informed Losses (TV, spectral, conservation)

    Args:
        in_channels: Latent channels from the VAE (default 4).
        hidden_dim: Hidden dimension in the backbone.
        num_stages: Number of LiquidFlow stages.
        blocks_per_stage: Blocks per stage.
        image_size: Target image size; latent spatial size is image_size // 8.
        beta_schedule: 'linear' or 'cosine' (any other value falls back to cosine).
        timesteps: Number of diffusion timesteps T.
        physics_weights: Optional dict of physics-regularizer weights
            (keys: 'tv', 'cons', 'spec', 'grad').
    """

    def __init__(
        self,
        in_channels=4,
        hidden_dim=256,
        num_stages=4,
        blocks_per_stage=4,
        image_size=128,
        beta_schedule='cosine',
        timesteps=1000,
        physics_weights=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_dim = hidden_dim
        self.image_size = image_size  # Latent space size = image_size / 8
        self.timesteps = timesteps

        # Noise predictor (backbone)
        self.backbone = LiquidFlowBackbone(
            in_channels=in_channels,
            hidden_dim=hidden_dim,
            num_stages=num_stages,
            blocks_per_stage=blocks_per_stage,
            d_state=16,
            expand=2,
            dropout=0.0,
        )

        # Diffusion schedule — registered as buffers so .to(device) / state_dict
        # carry them along without making them trainable parameters.
        if beta_schedule == 'linear':
            betas = linear_beta_schedule(timesteps)
        else:
            betas = cosine_beta_schedule(timesteps)

        self.register_buffer('betas', betas)
        self.register_buffer('alphas', 1.0 - betas)
        self.register_buffer('alphas_cumprod', torch.cumprod(self.alphas, dim=0))
        # ᾱ_{t-1} with the convention ᾱ_{-1} = 1 (left-pad with 1.0).
        self.register_buffer('alphas_cumprod_prev', F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0))

        # Precomputed square roots used by q_sample and DDIM sampling.
        self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(self.alphas_cumprod))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1.0 - self.alphas_cumprod))

        # Physics regularizer applied to the DDIM-estimated clean latent.
        if physics_weights is None:
            physics_weights = {'tv': 0.01, 'cons': 0.001, 'spec': 0.01, 'grad': 0.001}
        self.physics = PhysicsRegularizer(**physics_weights)
        self.ddim_estimator = DDIMEstimator()

    def q_sample(self, x0, t, noise=None):
        """
        Forward diffusion: q(x_t | x_0).

        x_t = √(ᾱ_t) * x_0 + √(1 - ᾱ_t) * ε

        Args:
            x0: Clean latents [B, C, H, W].
            t: Integer timesteps [B].
            noise: Optional pre-sampled ε (sampled here if None).

        Returns:
            (x_t, noise) — the noised latents and the ε actually used.
        """
        if noise is None:
            noise = torch.randn_like(x0)

        # Reshape per-sample scalars to [B, 1, 1, 1] for broadcasting.
        sqrt_alpha_bar = self.sqrt_alphas_cumprod[t].reshape(-1, 1, 1, 1)
        sqrt_one_minus_alpha_bar = self.sqrt_one_minus_alphas_cumprod[t].reshape(-1, 1, 1, 1)

        return sqrt_alpha_bar * x0 + sqrt_one_minus_alpha_bar * noise, noise

    def forward(self, x, t):
        """Predict noise ε̂ from noisy input x_t at timesteps t."""
        return self.backbone(x, t)

    def _loss_terms(self, x0, xt, t, noise):
        """Compute all training losses for one batch (shared by AMP / non-AMP paths).

        Returns:
            (total_loss, diffusion_loss, phys_loss, phys_dict)
        """
        noise_pred = self.forward(xt, t)

        # Base diffusion loss (L2 on predicted noise).
        diffusion_loss = F.mse_loss(noise_pred, noise)

        # Physics regularization on the DDIM-estimated x0.
        x0_hat = self.ddim_estimator.estimate_x0(
            xt, noise_pred, self.alphas_cumprod[t]
        )
        phys_loss, phys_dict = self.physics(x0_hat, x0)

        return diffusion_loss + phys_loss, diffusion_loss, phys_loss, phys_dict

    def training_step(self, x0, optimizer, scaler=None, use_amp=False):
        """
        Single training step with physics regularization.

        Args:
            x0: Clean latents [B, C, H, W]
            optimizer: Optimizer
            scaler: Optional GradScaler for AMP
            use_amp: Whether to use automatic mixed precision

        Returns:
            loss_dict: Dictionary of scalar losses (floats)
        """
        B = x0.shape[0]
        device = x0.device

        # Sample timesteps uniformly in [0, T).
        t = torch.randint(0, self.timesteps, (B,), device=device)

        # Forward diffusion
        noise = torch.randn_like(x0)
        xt, noise = self.q_sample(x0, t, noise)

        if use_amp and scaler is not None:
            with torch.cuda.amp.autocast():
                total_loss, diffusion_loss, phys_loss, phys_dict = self._loss_terms(x0, xt, t, noise)
        else:
            total_loss, diffusion_loss, phys_loss, phys_dict = self._loss_terms(x0, xt, t, noise)

        # Backward (with gradient clipping at norm 1.0 in both paths).
        optimizer.zero_grad()
        if scaler is not None:
            scaler.scale(total_loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
        else:
            total_loss.backward()
            torch.nn.utils.clip_grad_norm_(self.parameters(), 1.0)
            optimizer.step()

        return {
            'total': total_loss.item(),
            'diffusion': diffusion_loss.item(),
            'physics': phys_loss.item(),
            **{f'phys_{k}': v.item() for k, v in phys_dict.items()},
        }

    @torch.no_grad()
    def sample(self, batch_size=4, steps=50, ddim=True, eta=0.0, progress=True):
        """
        Generate latents using DDPM or DDIM sampling.

        Args:
            batch_size: Number of images
            steps: Sampling steps (for DDIM: can be << timesteps; ignored by DDPM)
            ddim: Use DDIM sampling (faster)
            eta: DDIM stochasticity (0 = deterministic)
            progress: Show progress bar

        Returns:
            Generated latents [B, C, H, W]
        """
        device = next(self.parameters()).device
        latent_size = self.image_size // 8  # VAE downsamples by 8x

        # Start from pure noise
        x = torch.randn(batch_size, self.in_channels, latent_size, latent_size, device=device)

        if ddim:
            return self._ddim_sample(x, steps, eta, progress)
        return self._ddpm_sample(x, progress)

    @torch.no_grad()
    def _ddpm_sample(self, x, progress=True):
        """DDPM ancestral sampling over all `self.timesteps` steps (slow but exact)."""
        device = x.device

        iterator = tqdm(
            reversed(range(0, self.timesteps)),
            desc='DDPM Sampling',
            total=self.timesteps,
            disable=not progress,
        )

        for t_idx in iterator:
            t = torch.full((x.shape[0],), t_idx, device=device, dtype=torch.long)

            noise_pred = self.forward(x, t)

            alpha = self.alphas[t_idx]
            alpha_bar = self.alphas_cumprod[t_idx]
            beta = self.betas[t_idx]

            # No noise is added on the final step (t = 0).
            if t_idx > 0:
                noise = torch.randn_like(x)
            else:
                noise = 0

            # DDPM posterior mean, with σ_t = √β_t (fixed-variance choice).
            x = (1 / torch.sqrt(alpha)) * (
                x - (beta / torch.sqrt(1 - alpha_bar)) * noise_pred
            ) + torch.sqrt(beta) * noise

        return x

    @torch.no_grad()
    def _ddim_sample(self, x, steps=50, eta=0.0, progress=True):
        """
        DDIM sampling with fewer steps (Song et al. 2021, Eq. 12).

        DDIM can produce good samples in 20-50 steps instead of
        the full DDPM trajectory.

        Fix vs. earlier revision: the direction term now uses
        √(1 - ᾱ_prev - σ²) with the *same* σ as the noise term.
        Previously the direction term dropped the (1 - ᾱ_t/ᾱ_prev)
        factor from σ², making the two terms inconsistent for eta > 0.
        """
        device = x.device

        # Evenly spaced timestep subsequence; seq_next[k] is the step we jump to.
        skip = self.timesteps // steps
        seq = list(range(0, self.timesteps, skip))
        seq_next = [-1] + seq[:-1]

        iterator = tqdm(
            zip(reversed(seq), reversed(seq_next)),
            desc='DDIM Sampling',
            total=len(seq),
            disable=not progress,
        )

        for i, j in iterator:
            t = torch.full((x.shape[0],), i, device=device, dtype=torch.long)

            noise_pred = self.forward(x, t)

            alpha_bar_i = self.alphas_cumprod[i]
            # j == -1 marks the final jump to x_0, where ᾱ_{-1} := 1.
            alpha_bar_j = self.alphas_cumprod[j] if j >= 0 else torch.tensor(1.0, device=device)

            # Predicted x0
            x0_pred = (x - torch.sqrt(1 - alpha_bar_i) * noise_pred) / torch.sqrt(alpha_bar_i)
            x0_pred = torch.clamp(x0_pred, -1, 1)  # Prevent outliers

            # DDIM variance: σ² = η² (1-ᾱ_prev)/(1-ᾱ_t) (1 - ᾱ_t/ᾱ_prev).
            sigma = eta * torch.sqrt(
                (1 - alpha_bar_j) / (1 - alpha_bar_i) * (1 - alpha_bar_i / alpha_bar_j)
            )

            # Direction pointing to x_t; clamp guards against tiny negative
            # values from floating-point roundoff.
            dir_xt = torch.sqrt(torch.clamp(1 - alpha_bar_j - sigma ** 2, min=0.0)) * noise_pred

            x = torch.sqrt(alpha_bar_j) * x0_pred + dir_xt
            if eta > 0:
                x = x + sigma * torch.randn_like(x)

        return x

    def count_parameters(self):
        """Count trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
319
+
320
+
321
def create_liquidflow(
    variant='small',
    image_size=128,
    **kwargs,
):
    """
    Build a LiquidFlow model from a named preset configuration.

    Variants:
        - 'tiny': ~2M params, 2 stages, 2 blocks each, hidden_dim=128
        - 'small': ~8M params, 4 stages, 4 blocks each, hidden_dim=256
        - 'base': ~30M params, 6 stages, 6 blocks each, hidden_dim=384

    Unknown variants silently fall back to 'small'. Extra keyword
    arguments override the preset values.

    All designed to run on T4 (15GB) with batch_size >= 16 at 128×128.
    """
    presets = {
        'tiny': {'hidden_dim': 128, 'num_stages': 2, 'blocks_per_stage': 2},
        'small': {'hidden_dim': 256, 'num_stages': 4, 'blocks_per_stage': 4},
        'base': {'hidden_dim': 384, 'num_stages': 6, 'blocks_per_stage': 6},
    }

    # Merge preset with caller overrides (overrides win).
    settings = presets.get(variant, presets['small'])
    settings.update(kwargs)

    return LiquidFlowGenerator(
        in_channels=4,  # VAE latent channels
        image_size=image_size,
        **settings,
    )