asdf98 commited on
Commit
4f46baa
·
verified ·
1 Parent(s): fe0d9c3

Add model architecture

Browse files
Files changed (1) hide show
  1. model.py +474 -0
model.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LiquidGen: A Novel Liquid Neural Network Image Generation Model
3
+
4
+ Architecture Overview:
5
+ - Frozen VAE encoder/decoder (FLUX.1-schnell, 16ch latent, 8x compression)
6
+ - Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)
7
+ - Flow matching training objective (velocity prediction)
8
+
9
+ Key Innovation: Replaces attention with Liquid Neural Network dynamics:
10
+ - CfC-inspired closed-form update: x_new = α·x + (1-α)·h(x)
11
+ - Per-channel learnable decay rates (liquid time constants)
12
+ - Depthwise + pointwise convolutions for spatial context (no attention needed)
13
+ - Zigzag spatial scanning for global receptive field
14
+ - Gated stimulus with biologically-inspired sign constraints
15
+ - U-Net style long skip connections from shallow to deep blocks
16
+
17
+ Math Foundation (from Hasani et al., CfC paper):
18
+ x_{t+1} = exp(-Δt/τ_t) · x_t + (1 - exp(-Δt/τ_t)) · h(x_t, u_t)
19
+
20
+ Our parallelizable adaptation (inspired by LiquidTAD):
21
+ α = exp(-softplus(ρ)) [per-channel learnable decay]
22
+ h = gate · stimulus [gated depthwise conv output]
23
+ out = α · x + (1 - α) · h [liquid relaxation blend]
24
+
25
+ This removes the input-dependent τ (which requires sequential computation)
26
+ and replaces it with a per-channel learned decay — making it fully parallel
27
+ while preserving the liquid dynamics' ability to blend old state with new input.
28
+
29
+ Design for 16GB VRAM (Colab free tier):
30
+ - VAE frozen: ~1GB
31
+ - Backbone: ~55-280M params (~100-550MB in fp16)
32
+ - Training overhead (grads + optimizer): ~3-8GB
33
+ - Batch of latents: ~1-2GB
34
+ - Total: fits comfortably in 16GB
35
+
36
+ References:
37
+ - Hasani et al., "Liquid Time-constant Networks" (NeurIPS 2020)
38
+ - Hasani et al., "Closed-form Continuous-depth Models" (Nature Machine Intelligence 2022)
39
+ - Lechner et al., "Neural Circuit Policies" (Nature Machine Intelligence 2020)
40
+ - LiquidTAD (2025) - Parallelized liquid dynamics
41
+ - ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion
42
+ - DiMSUM (NeurIPS 2024) - Attention-free diffusion
43
+ """
44
+
45
+ import torch
46
+ import torch.nn as nn
47
+ import torch.nn.functional as F
48
+ import math
49
+ from typing import Optional, Tuple
50
+
51
+
52
+ # =============================================================================
53
+ # Building Blocks
54
+ # =============================================================================
55
+
56
+ class LiquidTimeConstant(nn.Module):
57
+ """
58
+ Core liquid time-constant module.
59
+
60
+ Implements the CfC closed-form dynamics in a fully parallelizable way:
61
+ out = α · x + (1 - α) · stimulus
62
+
63
+ where α = exp(-softplus(ρ)) is a learnable per-channel decay rate,
64
+ derived from the liquid time constant τ = 1/softplus(ρ).
65
+
66
+ This preserves the key property of Liquid Neural Networks:
67
+ - Exponential relaxation toward a target (stimulus)
68
+ - Rate controlled by τ (how fast to adapt)
69
+ - No sequential ODE solving required
70
+
71
+ Stability guarantee (from LTC Theorem 1):
72
+ τ_sys ∈ [τ/(1+τW), τ] — time constants NEVER explode
73
+ """
74
+ def __init__(self, channels: int):
75
+ super().__init__()
76
+ # ρ parameterizes the decay: λ = softplus(ρ), α = exp(-λ)
77
+ # Initialize ρ=0 → λ≈0.693 → α≈0.5 (equal blend of old and new)
78
+ self.rho = nn.Parameter(torch.zeros(channels))
79
+
80
+ def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:
81
+ """
82
+ x: [B, C, H, W] - current state (residual path)
83
+ stimulus: [B, C, H, W] - computed target from context
84
+ returns: [B, C, H, W] - liquid-blended output
85
+ """
86
+ lam = F.softplus(self.rho) + 1e-5
87
+ alpha = torch.exp(-lam).view(1, -1, 1, 1)
88
+ return alpha * x + (1.0 - alpha) * stimulus
89
+
90
+
91
+ class GatedDepthwiseStimulusConv(nn.Module):
92
+ """
93
+ Computes the spatial stimulus using depthwise-separable convolutions
94
+ with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs).
95
+
96
+ This replaces attention for capturing local spatial context:
97
+ - Depthwise conv: captures local spatial patterns per channel
98
+ - Pointwise conv: mixes channel information
99
+ - Sigmoid gate: controls information flow (like synaptic gating in NCP)
100
+
101
+ Two parallel paths (inspired by NCP inter→command split):
102
+ 1. Stimulus path: DW-conv → PW-conv → GELU → project back
103
+ 2. Gate path: DW-conv → PW-conv → sigmoid
104
+ Output = stimulus * gate
105
+ """
106
+ def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):
107
+ super().__init__()
108
+ hidden = int(channels * expand_ratio)
109
+
110
+ self.stim_dw = nn.Conv2d(channels, channels, kernel_size,
111
+ padding=kernel_size // 2, groups=channels, bias=False)
112
+ self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)
113
+ self.stim_act = nn.GELU()
114
+ self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)
115
+
116
+ self.gate_dw = nn.Conv2d(channels, channels, kernel_size,
117
+ padding=kernel_size // 2, groups=channels, bias=False)
118
+ self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)
119
+
120
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
121
+ stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x))))
122
+ gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))
123
+ return stim * gate
124
+
125
+
126
+ class ChannelMixMLP(nn.Module):
127
+ """Channel mixing MLP with GELU activation (command neuron processing in NCP)."""
128
+ def __init__(self, channels: int, expand_ratio: float = 4.0):
129
+ super().__init__()
130
+ hidden = int(channels * expand_ratio)
131
+ self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)
132
+ self.act = nn.GELU()
133
+ self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)
134
+
135
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
136
+ return self.fc2(self.act(self.fc1(x)))
137
+
138
+
139
+ class AdaptiveGroupNorm(nn.Module):
140
+ """
141
+ Adaptive Group Normalization conditioned on timestep embedding.
142
+ Applies: out = (1 + scale) * GroupNorm(x) + shift
143
+ """
144
+ def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):
145
+ super().__init__()
146
+ self.norm = nn.GroupNorm(num_groups, channels, affine=False)
147
+ self.proj = nn.Linear(cond_dim, channels * 2)
148
+ nn.init.zeros_(self.proj.weight)
149
+ nn.init.zeros_(self.proj.bias)
150
+
151
+ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
152
+ h = self.norm(x)
153
+ params = self.proj(cond)
154
+ scale, shift = params.chunk(2, dim=-1)
155
+ return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1)
156
+
157
+
158
+ class ZigzagScan1D(nn.Module):
159
+ """
160
+ 1D global mixing via zigzag-scanned depthwise conv.
161
+
162
+ Gives quasi-global receptive field without attention's O(n²) cost.
163
+ Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024).
164
+ """
165
+ def __init__(self, channels: int, kernel_size: int = 31):
166
+ super().__init__()
167
+ self.conv1d = nn.Conv1d(channels, channels, kernel_size,
168
+ padding=kernel_size // 2, groups=channels, bias=False)
169
+ self.pw = nn.Conv1d(channels, channels, 1, bias=True)
170
+ self.act = nn.GELU()
171
+
172
+ def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:
173
+ indices = []
174
+ for i in range(H):
175
+ row = list(range(i * W, (i + 1) * W))
176
+ if i % 2 == 1:
177
+ row = row[::-1]
178
+ indices.extend(row)
179
+ return torch.tensor(indices, device=device, dtype=torch.long)
180
+
181
+ def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:
182
+ fwd = self._zigzag_indices(H, W, device)
183
+ inv = torch.empty_like(fwd)
184
+ inv[fwd] = torch.arange(H * W, device=device)
185
+ return inv
186
+
187
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
188
+ B, C, H, W = x.shape
189
+ zz_idx = self._zigzag_indices(H, W, x.device)
190
+ inv_idx = self._inverse_zigzag_indices(H, W, x.device)
191
+ x_flat = x.reshape(B, C, H * W)
192
+ x_zz = x_flat[:, :, zz_idx]
193
+ x_mixed = self.pw(self.act(self.conv1d(x_zz)))
194
+ x_restored = x_mixed[:, :, inv_idx]
195
+ return x_restored.reshape(B, C, H, W)
196
+
197
+
198
+ # =============================================================================
199
+ # Liquid Block: The core building block
200
+ # =============================================================================
201
+
202
+ class LiquidBlock(nn.Module):
203
+ """
204
+ A single Liquid Neural Network block for image denoising.
205
+
206
+ Architecture (maps to NCP hierarchy):
207
+ 1. [SENSORY] AdaGN conditioning → spatial context extraction
208
+ 2. [INTER] Zigzag 1D scan for global mixing
209
+ 3. [COMMAND] Liquid time-constant blend (CfC dynamics)
210
+ 4. [MOTOR] Channel mixing MLP for output projection
211
+
212
+ All operations are fully parallelizable — no sequential dependencies.
213
+ """
214
+ def __init__(
215
+ self, channels: int, cond_dim: int, spatial_kernel: int = 7,
216
+ scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,
217
+ drop_rate: float = 0.0, use_zigzag: bool = True,
218
+ ):
219
+ super().__init__()
220
+ self.norm1 = AdaptiveGroupNorm(channels, cond_dim)
221
+ self.norm2 = AdaptiveGroupNorm(channels, cond_dim)
222
+ self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)
223
+ self.use_zigzag = use_zigzag
224
+ if use_zigzag:
225
+ self.zigzag = ZigzagScan1D(channels, scan_kernel)
226
+ self.zigzag_gate = nn.Parameter(torch.zeros(1))
227
+ self.liquid = LiquidTimeConstant(channels)
228
+ self.channel_mix = ChannelMixMLP(channels, mlp_ratio)
229
+ self.liquid2 = LiquidTimeConstant(channels)
230
+ self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()
231
+
232
+ def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
233
+ h = self.norm1(x, cond)
234
+ stim = self.spatial_stim(h)
235
+ if self.use_zigzag:
236
+ zz = self.zigzag(h)
237
+ stim = stim + torch.sigmoid(self.zigzag_gate) * zz
238
+ stim = self.drop(stim)
239
+ x = self.liquid(x, stim)
240
+ h2 = self.norm2(x, cond)
241
+ ch_out = self.drop(self.channel_mix(h2))
242
+ x = self.liquid2(x, ch_out)
243
+ return x
244
+
245
+
246
+ # =============================================================================
247
+ # Timestep and Class Embeddings
248
+ # =============================================================================
249
+
250
+ class TimestepEmbedding(nn.Module):
251
+ """Sinusoidal timestep embedding followed by MLP projection."""
252
+ def __init__(self, dim: int, freq_dim: int = 256):
253
+ super().__init__()
254
+ self.freq_dim = freq_dim
255
+ self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
256
+
257
+ def forward(self, t: torch.Tensor) -> torch.Tensor:
258
+ half = self.freq_dim // 2
259
+ freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)
260
+ args = t.unsqueeze(-1) * freqs.unsqueeze(0)
261
+ emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
262
+ return self.mlp(emb)
263
+
264
+
265
+ class ClassEmbedding(nn.Module):
266
+ """Optional class-conditional embedding with CFG null embedding."""
267
+ def __init__(self, num_classes: int, dim: int):
268
+ super().__init__()
269
+ self.embed = nn.Embedding(num_classes, dim)
270
+ self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)
271
+
272
+ def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:
273
+ emb = self.embed(labels)
274
+ if self.training and drop_prob > 0:
275
+ mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob
276
+ emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb)
277
+ return emb
278
+
279
+
280
+ # =============================================================================
281
+ # LiquidGen: Full Model
282
+ # =============================================================================
283
+
284
+ class LiquidGen(nn.Module):
285
+ """
286
+ LiquidGen: Liquid Neural Network Image Generator
287
+
288
+ A novel attention-free diffusion model that uses Liquid Neural Network
289
+ dynamics (CfC closed-form continuous-depth) for image generation.
290
+
291
+ Features:
292
+ - NO self-attention anywhere — O(n) complexity
293
+ - NO sequential ODE solving — fully parallelizable
294
+ - Liquid time constants for adaptive information blending
295
+ - Zigzag scanning for global context
296
+ - Depthwise convolutions for local spatial structure
297
+ - Gated stimulus (biologically-inspired from NCP)
298
+ - U-Net long skip connections (from U-ViT/DiM)
299
+
300
+ Config Presets:
301
+ - LiquidGen-S: ~55M params (256px, fast training)
302
+ - LiquidGen-B: ~140M params (256/512px, balanced)
303
+ - LiquidGen-L: ~280M params (512px, high quality)
304
+ """
305
+
306
+ def __init__(
307
+ self,
308
+ in_channels: int = 16,
309
+ patch_size: int = 2,
310
+ embed_dim: int = 512,
311
+ depth: int = 16,
312
+ spatial_kernel: int = 7,
313
+ scan_kernel: int = 31,
314
+ expand_ratio: float = 2.0,
315
+ mlp_ratio: float = 4.0,
316
+ drop_rate: float = 0.0,
317
+ num_classes: int = 0,
318
+ class_drop_prob: float = 0.1,
319
+ use_zigzag: bool = True,
320
+ ):
321
+ super().__init__()
322
+ self.in_channels = in_channels
323
+ self.patch_size = patch_size
324
+ self.embed_dim = embed_dim
325
+ self.depth = depth
326
+ self.num_classes = num_classes
327
+ self.class_drop_prob = class_drop_prob
328
+
329
+ cond_dim = embed_dim
330
+
331
+ self.time_embed = TimestepEmbedding(cond_dim)
332
+ self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None
333
+
334
+ self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)
335
+
336
+ self.pos_embed_size = 32
337
+ self.pos_embed = nn.Parameter(
338
+ torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02
339
+ )
340
+
341
+ self.input_proj = nn.Sequential(
342
+ nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),
343
+ nn.Conv2d(embed_dim, embed_dim, 1, bias=True),
344
+ nn.GELU(),
345
+ )
346
+
347
+ self.blocks = nn.ModuleList([
348
+ LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,
349
+ expand_ratio, mlp_ratio, drop_rate, use_zigzag)
350
+ for _ in range(depth)
351
+ ])
352
+
353
+ self.final_norm = nn.GroupNorm(32, embed_dim)
354
+ self.final_proj = nn.Sequential(
355
+ nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),
356
+ nn.GELU(),
357
+ )
358
+
359
+ self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)
360
+ nn.init.zeros_(self.unpatch.weight)
361
+ nn.init.zeros_(self.unpatch.bias)
362
+
363
+ self.apply(self._init_weights)
364
+
365
+ def _init_weights(self, m):
366
+ if isinstance(m, nn.Conv2d):
367
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
368
+ if m.bias is not None:
369
+ nn.init.zeros_(m.bias)
370
+ elif isinstance(m, nn.Linear):
371
+ nn.init.xavier_uniform_(m.weight)
372
+ if m.bias is not None:
373
+ nn.init.zeros_(m.bias)
374
+ elif isinstance(m, nn.Embedding):
375
+ nn.init.normal_(m.weight, std=0.02)
376
+
377
+ def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:
378
+ if H == self.pos_embed_size and W == self.pos_embed_size:
379
+ return self.pos_embed
380
+ return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)
381
+
382
+ def forward(
383
+ self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,
384
+ ) -> torch.Tensor:
385
+ """
386
+ Predict velocity field for flow matching.
387
+ Args:
388
+ x: [B, C, H, W] noisy latent (C=16 for Flux VAE)
389
+ t: [B] timestep in [0, 1]
390
+ class_labels: [B] optional class labels
391
+ Returns:
392
+ v: [B, C, H, W] predicted velocity
393
+ """
394
+ cond = self.time_embed(t)
395
+ if self.class_embed is not None and class_labels is not None:
396
+ drop_p = self.class_drop_prob if self.training else 0.0
397
+ cond = cond + self.class_embed(class_labels, drop_prob=drop_p)
398
+
399
+ h = self.patch_embed(x)
400
+ B, C, H_p, W_p = h.shape
401
+ h = h + self._interpolate_pos_embed(H_p, W_p)
402
+ h = self.input_proj(h)
403
+
404
+ # U-Net style long skip connections
405
+ skip_connections = []
406
+ mid = self.depth // 2
407
+ for i, block in enumerate(self.blocks):
408
+ if i < mid:
409
+ skip_connections.append(h)
410
+ elif i >= mid and len(skip_connections) > 0:
411
+ skip = skip_connections.pop()
412
+ h = h + skip
413
+ h = block(h, cond)
414
+
415
+ h = self.final_norm(h)
416
+ h = self.final_proj(h)
417
+ v = self.unpatch(h)
418
+ return v
419
+
420
+ def count_params(self) -> int:
421
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
422
+
423
+
424
+ # =============================================================================
425
+ # Model Presets
426
+ # =============================================================================
427
+
428
+ def liquidgen_small(**kwargs) -> LiquidGen:
429
+ """~55M params - for 256px, fast training/testing"""
430
+ defaults = dict(
431
+ embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,
432
+ expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,
433
+ )
434
+ defaults.update(kwargs)
435
+ return LiquidGen(**defaults)
436
+
437
+ def liquidgen_base(**kwargs) -> LiquidGen:
438
+ """~140M params - for 256/512px, balanced (fits T4 16GB easily)"""
439
+ defaults = dict(
440
+ embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,
441
+ expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,
442
+ )
443
+ defaults.update(kwargs)
444
+ return LiquidGen(**defaults)
445
+
446
+ def liquidgen_large(**kwargs) -> LiquidGen:
447
+ """~280M params - for 512px, high quality (fits T4 16GB with small batch)"""
448
+ defaults = dict(
449
+ embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,
450
+ expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,
451
+ )
452
+ defaults.update(kwargs)
453
+ return LiquidGen(**defaults)
454
+
455
+
456
+ if __name__ == "__main__":
457
+ device = "cpu"
458
+ for name, factory in [("Small", liquidgen_small), ("Base", liquidgen_base), ("Large", liquidgen_large)]:
459
+ model = factory(num_classes=27).to(device)
460
+ print(f"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params")
461
+
462
+ x = torch.randn(2, 16, 32, 32, device=device)
463
+ t = torch.rand(2, device=device)
464
+ labels = torch.randint(0, 27, (2,), device=device)
465
+ v = model(x, t, labels)
466
+ assert v.shape == x.shape
467
+
468
+ x512 = torch.randn(1, 16, 64, 64, device=device)
469
+ v512 = model(x512, t[:1], labels[:1])
470
+ assert v512.shape == x512.shape
471
+ print(f" 256px ✅ 512px ✅")
472
+ del model
473
+
474
+ print("\n✅ All tests passed!")