{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# \ud83e\uddea LiquidGen: Liquid Neural Network Image Generator\n",
"\n",
"**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
"\n",
"### Key Features:\n",
"- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
"- **Fully Parallelizable** \u2014 No sequential ODE solving\n",
"- **Flow Matching** \u2014 Modern velocity-prediction training\n",
"- **Frozen Flux VAE** \u2014 16-channel latent space\n",
"- **Fits 16GB VRAM** \u2014 Designed for Colab free tier\n",
"\n",
"Based on: Liquid Time-constant Networks (NeurIPS 2020), CfC (Nature MI 2022), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udce6 Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub Pillow"
]
},
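{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before diving in, a minimal standalone sketch of the core liquid update the whole model is built on: `out = \u03b1\u00b7x + (1-\u03b1)\u00b7h` with a learnable per-channel decay `\u03b1 = exp(-softplus(\u03c1))`. Shapes and names here are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch of the parallel liquid blend (illustrative shapes, not the full model)\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"rho = torch.zeros(4)                 # one learnable decay parameter per channel\n",
"alpha = torch.exp(-F.softplus(rho))  # rho=0 -> alpha ~= 0.5 (equal blend of old and new)\n",
"x = torch.randn(1, 4, 8, 8)          # old state (residual path)\n",
"h = torch.randn(1, 4, 8, 8)          # new stimulus\n",
"a = alpha.view(1, -1, 1, 1)\n",
"out = a * x + (1 - a) * h            # exponential relaxation from x toward h\n",
"print(alpha)"
]
},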
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udd27 Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\n",
"IMAGE_SIZE = 256 # 256 or 512\n",
"DATASET_NAME = \"huggan/wikiart\"\n",
"IMAGE_COLUMN = \"image\"\n",
"LABEL_COLUMN = \"style\" # \"style\" (27), \"genre\" (11), \"\" for unconditional\n",
"NUM_CLASSES = 27\n",
"BATCH_SIZE = 8\n",
"GRADIENT_ACCUMULATION = 4\n",
"LEARNING_RATE = 1e-4\n",
"WEIGHT_DECAY = 0.01\n",
"MAX_GRAD_NORM = 2.0\n",
"NUM_EPOCHS = 50\n",
"WARMUP_STEPS = 500\n",
"EMA_DECAY = 0.9999\n",
"NUM_SAMPLE_STEPS = 50\n",
"CFG_SCALE = 2.0\n",
"OUTPUT_DIR = \"/content/liquidgen_outputs\"\n",
"SAVE_EVERY = 2000\n",
"SAMPLE_EVERY = 500\n",
"LOG_EVERY = 50\n",
"PUSH_TO_HUB = False\n",
"HUB_MODEL_ID = \"\"\n",
"VAE_ID = \"black-forest-labs/FLUX.1-schnell\"\n",
"VAE_SUBFOLDER = \"vae\"\n",
"\n",
"import torch\n",
"if torch.cuda.is_available():\n",
" gpu = torch.cuda.get_device_name(0)\n",
" mem = torch.cuda.get_device_properties(0).total_mem / 1024**3\n",
" print(f\"GPU: {gpu} ({mem:.1f} GB)\")\n",
" if mem < 12: print(\"\u26a0\ufe0f Low VRAM! Use small model, 256px, bs=4\")\n",
" elif mem < 20: print(\"\u2705 T4 detected. Good for base model, 256px\")\n",
" else: print(\"\ud83d\ude80 Large GPU! Can run large model, 512px\")\n",
"else:\n",
" print(\"\u26a0\ufe0f No GPU! Go to Runtime \u2192 Change runtime type \u2192 GPU\")"
]
},
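{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sketch of the shapes implied by these settings (assuming the Flux VAE's 8\u00d7 spatial compression with 16 latent channels, and the model's default `patch_size=2`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Derived quantities for the settings above (assumes 8x VAE compression,\n",
"# 16 latent channels, and the default patch_size=2 used by the model below).\n",
"LATENT_SIZE = IMAGE_SIZE // 8\n",
"print(f\"Latent batch shape: ({BATCH_SIZE}, 16, {LATENT_SIZE}, {LATENT_SIZE})\")\n",
"print(f\"Internal grid after patch embedding: {LATENT_SIZE//2}x{LATENT_SIZE//2}\")\n",
"print(f\"Effective batch size: {BATCH_SIZE * GRADIENT_ACCUMULATION}\")"
]
},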
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83c\udfd7\ufe0f Model Architecture"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\nLiquidGen: A Novel Liquid Neural Network Image Generation Model\n\nArchitecture Overview:\n- Frozen VAE encoder/decoder (FLUX.1-schnell, 16ch latent, 8x compression)\n- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)\n- Flow matching training objective (velocity prediction)\n\nKey Innovation: Replaces attention with Liquid Neural Network dynamics:\n- CfC-inspired closed-form update: x_new = \u03b1\u00b7x + (1-\u03b1)\u00b7h(x)\n- Per-channel learnable decay rates (liquid time constants)\n- Depthwise + pointwise convolutions for spatial context (no attention needed)\n- Zigzag spatial scanning for global receptive field\n- Gated stimulus with biologically-inspired sign constraints\n- U-Net style long skip connections from shallow to deep blocks\n\nMath Foundation (from Hasani et al., CfC paper):\n x_{t+1} = exp(-\u0394t/\u03c4_t) \u00b7 x_t + (1 - exp(-\u0394t/\u03c4_t)) \u00b7 h(x_t, u_t)\n \nOur parallelizable adaptation (inspired by LiquidTAD):\n \u03b1 = exp(-softplus(\u03c1)) [per-channel learnable decay]\n h = gate \u00b7 stimulus [gated depthwise conv output] \n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 h [liquid relaxation blend]\n\nThis removes the input-dependent \u03c4 (which requires sequential computation)\nand replaces it with a per-channel learned decay \u2014 making it fully parallel\nwhile preserving the liquid dynamics' ability to blend old state with new input.\n\nDesign for 16GB VRAM (Colab free tier):\n- VAE frozen: ~1GB\n- Backbone: ~55-280M params (~100-550MB in fp16) \n- Training overhead (grads + optimizer): ~3-8GB\n- Batch of latents: ~1-2GB\n- Total: fits comfortably in 16GB\n\nReferences:\n- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature Machine Intelligence 2022)\n- Lechner et al., \"Neural Circuit Policies\" (Nature Machine Intelligence 2020)\n- LiquidTAD (2025) - Parallelized liquid dynamics\n- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion\n- DiMSUM (NeurIPS 2024) - Attention-free diffusion\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Optional, Tuple\n\n\n# =============================================================================\n# Building Blocks\n# =============================================================================\n\nclass LiquidTimeConstant(nn.Module):\n \"\"\"\n Core liquid time-constant module.\n \n Implements the CfC closed-form dynamics in a fully parallelizable way:\n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 stimulus\n \n where \u03b1 = exp(-softplus(\u03c1)) is a learnable per-channel decay rate,\n derived from the liquid time constant \u03c4 = 1/softplus(\u03c1).\n \n This preserves the key property of Liquid Neural Networks:\n - Exponential relaxation toward a target (stimulus)\n - Rate controlled by \u03c4 (how fast to adapt)\n - No sequential ODE solving required\n \n Stability guarantee (from LTC Theorem 1):\n \u03c4_sys \u2208 [\u03c4/(1+\u03c4W), \u03c4] \u2014 time constants NEVER explode\n \"\"\"\n def __init__(self, channels: int):\n super().__init__()\n # \u03c1 parameterizes the decay: \u03bb = softplus(\u03c1), \u03b1 = exp(-\u03bb)\n # Initialize \u03c1=0 \u2192 \u03bb\u22480.693 \u2192 \u03b1\u22480.5 (equal blend of old and new)\n self.rho = nn.Parameter(torch.zeros(channels))\n \n def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:\n \"\"\"\n x: [B, C, H, W] - current state (residual 
path)\n stimulus: [B, C, H, W] - computed target from context\n returns: [B, C, H, W] - liquid-blended output\n \"\"\"\n lam = F.softplus(self.rho) + 1e-5\n alpha = torch.exp(-lam).view(1, -1, 1, 1)\n return alpha * x + (1.0 - alpha) * stimulus\n\n\nclass GatedDepthwiseStimulusConv(nn.Module):\n \"\"\"\n Computes the spatial stimulus using depthwise-separable convolutions\n with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs).\n \n This replaces attention for capturing local spatial context:\n - Depthwise conv: captures local spatial patterns per channel\n - Pointwise conv: mixes channel information\n - Sigmoid gate: controls information flow (like synaptic gating in NCP)\n \n Two parallel paths (inspired by NCP inter\u2192command split):\n 1. Stimulus path: DW-conv \u2192 PW-conv \u2192 GELU \u2192 project back\n 2. Gate path: DW-conv \u2192 PW-conv \u2192 sigmoid\n Output = stimulus * gate\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n \n self.stim_dw = nn.Conv2d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)\n self.stim_act = nn.GELU()\n self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)\n \n self.gate_dw = nn.Conv2d(channels, channels, kernel_size,\n padding=kernel_size // 2, groups=channels, bias=False)\n self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x))))\n gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))\n return stim * gate\n\n\nclass ChannelMixMLP(nn.Module):\n \"\"\"Channel mixing MLP with GELU activation (command neuron processing in NCP).\"\"\"\n def __init__(self, channels: int, expand_ratio: float = 4.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)\n self.act = nn.GELU()\n self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return self.fc2(self.act(self.fc1(x)))\n\n\nclass AdaptiveGroupNorm(nn.Module):\n \"\"\"\n Adaptive Group Normalization conditioned on timestep embedding.\n Applies: out = (1 + scale) * GroupNorm(x) + shift\n \"\"\"\n def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):\n super().__init__()\n self.norm = nn.GroupNorm(num_groups, channels, affine=False)\n self.proj = nn.Linear(cond_dim, channels * 2)\n nn.init.zeros_(self.proj.weight)\n nn.init.zeros_(self.proj.bias)\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm(x)\n params = self.proj(cond)\n scale, shift = params.chunk(2, dim=-1)\n return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1)\n\n\nclass ZigzagScan1D(nn.Module):\n \"\"\"\n 1D global mixing via zigzag-scanned depthwise conv.\n \n Gives quasi-global receptive field without attention's O(n\u00b2) cost.\n Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024).\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 31):\n super().__init__()\n self.conv1d = nn.Conv1d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.pw = nn.Conv1d(channels, channels, 1, bias=True)\n self.act = nn.GELU()\n \n def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n indices = []\n for i 
in range(H):\n row = list(range(i * W, (i + 1) * W))\n if i % 2 == 1:\n row = row[::-1]\n indices.extend(row)\n return torch.tensor(indices, device=device, dtype=torch.long)\n \n def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n fwd = self._zigzag_indices(H, W, device)\n inv = torch.empty_like(fwd)\n inv[fwd] = torch.arange(H * W, device=device)\n return inv\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n B, C, H, W = x.shape\n zz_idx = self._zigzag_indices(H, W, x.device)\n inv_idx = self._inverse_zigzag_indices(H, W, x.device)\n x_flat = x.reshape(B, C, H * W)\n x_zz = x_flat[:, :, zz_idx]\n x_mixed = self.pw(self.act(self.conv1d(x_zz)))\n x_restored = x_mixed[:, :, inv_idx]\n return x_restored.reshape(B, C, H, W)\n\n\n# =============================================================================\n# Liquid Block: The core building block\n# =============================================================================\n\nclass LiquidBlock(nn.Module):\n \"\"\"\n A single Liquid Neural Network block for image denoising.\n \n Architecture (maps to NCP hierarchy):\n 1. [SENSORY] AdaGN conditioning \u2192 spatial context extraction\n 2. [INTER] Zigzag 1D scan for global mixing\n 3. [COMMAND] Liquid time-constant blend (CfC dynamics)\n 4. [MOTOR] Channel mixing MLP for output projection\n \n All operations are fully parallelizable \u2014 no sequential dependencies.\n \"\"\"\n def __init__(\n self, channels: int, cond_dim: int, spatial_kernel: int = 7,\n scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,\n drop_rate: float = 0.0, use_zigzag: bool = True,\n ):\n super().__init__()\n self.norm1 = AdaptiveGroupNorm(channels, cond_dim)\n self.norm2 = AdaptiveGroupNorm(channels, cond_dim)\n self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)\n self.use_zigzag = use_zigzag\n if use_zigzag:\n self.zigzag = ZigzagScan1D(channels, scan_kernel)\n self.zigzag_gate = nn.Parameter(torch.zeros(1))\n self.liquid = LiquidTimeConstant(channels)\n self.channel_mix = ChannelMixMLP(channels, mlp_ratio)\n self.liquid2 = LiquidTimeConstant(channels)\n self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm1(x, cond)\n stim = self.spatial_stim(h)\n if self.use_zigzag:\n zz = self.zigzag(h)\n stim = stim + torch.sigmoid(self.zigzag_gate) * zz\n stim = self.drop(stim)\n x = self.liquid(x, stim)\n h2 = self.norm2(x, cond)\n ch_out = self.drop(self.channel_mix(h2))\n x = self.liquid2(x, ch_out)\n return x\n\n\n# =============================================================================\n# Timestep and Class Embeddings\n# =============================================================================\n\nclass TimestepEmbedding(nn.Module):\n \"\"\"Sinusoidal timestep embedding followed by MLP projection.\"\"\"\n def __init__(self, dim: int, freq_dim: int = 256):\n super().__init__()\n self.freq_dim = freq_dim\n self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n \n def forward(self, t: torch.Tensor) -> torch.Tensor:\n half = self.freq_dim // 2\n freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t.unsqueeze(-1) * freqs.unsqueeze(0)\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n return self.mlp(emb)\n\n\nclass ClassEmbedding(nn.Module):\n \"\"\"Optional class-conditional embedding with CFG null 
embedding.\"\"\"\n def __init__(self, num_classes: int, dim: int):\n super().__init__()\n self.embed = nn.Embedding(num_classes, dim)\n self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)\n \n def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:\n emb = self.embed(labels)\n if self.training and drop_prob > 0:\n mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob\n emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb)\n return emb\n\n\n# =============================================================================\n# LiquidGen: Full Model\n# =============================================================================\n\nclass LiquidGen(nn.Module):\n \"\"\"\n LiquidGen: Liquid Neural Network Image Generator\n \n A novel attention-free diffusion model that uses Liquid Neural Network\n dynamics (CfC closed-form continuous-depth) for image generation.\n \n Features:\n - NO self-attention anywhere \u2014 O(n) complexity\n - NO sequential ODE solving \u2014 fully parallelizable\n - Liquid time constants for adaptive information blending\n - Zigzag scanning for global context\n - Depthwise convolutions for local spatial structure\n - Gated stimulus (biologically-inspired from NCP)\n - U-Net long skip connections (from U-ViT/DiM)\n \n Config Presets:\n - LiquidGen-S: ~55M params (256px, fast training)\n - LiquidGen-B: ~140M params (256/512px, balanced)\n - LiquidGen-L: ~280M params (512px, high quality)\n \"\"\"\n \n def __init__(\n self,\n in_channels: int = 16,\n patch_size: int = 2,\n embed_dim: int = 512,\n depth: int = 16,\n spatial_kernel: int = 7,\n scan_kernel: int = 31,\n expand_ratio: float = 2.0,\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n num_classes: int = 0,\n class_drop_prob: float = 0.1,\n use_zigzag: bool = True,\n ):\n super().__init__()\n self.in_channels = in_channels\n self.patch_size = patch_size\n self.embed_dim = embed_dim\n self.depth = depth\n self.num_classes = num_classes\n self.class_drop_prob = class_drop_prob\n \n cond_dim = embed_dim\n \n self.time_embed = TimestepEmbedding(cond_dim)\n self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None\n \n self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)\n \n self.pos_embed_size = 32\n self.pos_embed = nn.Parameter(\n torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02\n )\n \n self.input_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),\n nn.Conv2d(embed_dim, embed_dim, 1, bias=True),\n nn.GELU(),\n )\n \n self.blocks = nn.ModuleList([\n LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,\n expand_ratio, mlp_ratio, drop_rate, use_zigzag)\n for _ in range(depth)\n ])\n \n self.final_norm = nn.GroupNorm(32, embed_dim)\n self.final_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),\n nn.GELU(),\n )\n \n self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)\n nn.init.zeros_(self.unpatch.weight)\n nn.init.zeros_(self.unpatch.bias)\n \n self.apply(self._init_weights)\n \n def _init_weights(self, m):\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Linear):\n nn.init.xavier_uniform_(m.weight)\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Embedding):\n nn.init.normal_(m.weight, 
std=0.02)\n \n def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:\n if H == self.pos_embed_size and W == self.pos_embed_size:\n return self.pos_embed\n return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)\n \n def forward(\n self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Predict velocity field for flow matching.\n Args:\n x: [B, C, H, W] noisy latent (C=16 for Flux VAE)\n t: [B] timestep in [0, 1]\n class_labels: [B] optional class labels\n Returns:\n v: [B, C, H, W] predicted velocity\n \"\"\"\n cond = self.time_embed(t)\n if self.class_embed is not None and class_labels is not None:\n drop_p = self.class_drop_prob if self.training else 0.0\n cond = cond + self.class_embed(class_labels, drop_prob=drop_p)\n \n h = self.patch_embed(x)\n B, C, H_p, W_p = h.shape\n h = h + self._interpolate_pos_embed(H_p, W_p)\n h = self.input_proj(h)\n \n # U-Net style long skip connections\n skip_connections = []\n mid = self.depth // 2\n for i, block in enumerate(self.blocks):\n if i < mid:\n skip_connections.append(h)\n elif i >= mid and len(skip_connections) > 0:\n skip = skip_connections.pop()\n h = h + skip\n h = block(h, cond)\n \n h = self.final_norm(h)\n h = self.final_proj(h)\n v = self.unpatch(h)\n return v\n \n def count_params(self) -> int:\n return sum(p.numel() for p in self.parameters() if p.requires_grad)\n\n\n# =============================================================================\n# Model Presets\n# =============================================================================\n\ndef liquidgen_small(**kwargs) -> LiquidGen:\n \"\"\"~55M params - for 256px, fast training/testing\"\"\"\n defaults = dict(\n embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_base(**kwargs) -> LiquidGen:\n \"\"\"~140M params - for 256/512px, balanced (fits T4 16GB easily)\"\"\"\n defaults = dict(\n embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_large(**kwargs) -> LiquidGen:\n \"\"\"~280M params - for 512px, high quality (fits T4 16GB with small batch)\"\"\"\n defaults = dict(\n embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\n\nif __name__ == \"__main__\":\n device = \"cpu\"\n for name, factory in [(\"Small\", liquidgen_small), (\"Base\", liquidgen_base), (\"Large\", liquidgen_large)]:\n model = factory(num_classes=27).to(device)\n print(f\"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params\")\n \n x = torch.randn(2, 16, 32, 32, device=device)\n t = torch.rand(2, device=device)\n labels = torch.randint(0, 27, (2,), device=device)\n v = model(x, t, labels)\n assert v.shape == x.shape\n \n x512 = torch.randn(1, 16, 64, 64, device=device)\n v512 = model(x512, t[:1], labels[:1])\n assert v512.shape == x512.shape\n print(f\" 256px \u2705 512px \u2705\")\n del model\n \n print(\"\\n\u2705 All tests passed!\")\n"
]
},
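{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick illustration (toy 3\u00d74 grid, illustrative only) of the zigzag scan order used by `ZigzagScan1D`: even rows run left-to-right, odd rows right-to-left, so consecutive scan positions stay spatially adjacent across row boundaries."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy demo of the zigzag scan (not part of training).\n",
"zz = ZigzagScan1D(channels=4, kernel_size=3)\n",
"idx = zz._zigzag_indices(3, 4, torch.device(\"cpu\"))\n",
"print(idx.tolist())        # [0, 1, 2, 3, 7, 6, 5, 4, 8, 9, 10, 11]\n",
"print(zz(torch.randn(1, 4, 3, 4)).shape)  # shape is preserved: [1, 4, 3, 4]"
]
},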
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udd04 Training Utilities"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, json, time, math\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torch.amp import autocast, GradScaler\n",
"from torchvision import transforms\n",
"from torchvision.utils import save_image\n",
"from PIL import Image\n",
"\n",
"class FlowMatchingScheduler:\n",
" def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n",
" def sample_timesteps(self, bs, dev): return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t\n",
" def add_noise(self, x0, noise, t): t = t.view(-1,1,1,1); return (1-t)*x0 + t*noise\n",
" def get_velocity_target(self, x0, noise): return noise - x0\n",
" @torch.no_grad()\n",
" def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):\n",
" model.eval(); x = torch.randn(shape, device=dev)\n",
" dt = 1.0 / num_steps\n",
" for t_val in torch.linspace(1.0, dt, num_steps, device=dev):\n",
" t = torch.full((shape[0],), t_val.item(), device=dev)\n",
" with torch.amp.autocast(\"cuda\"):\n",
" if cfg > 1.0 and labels is not None:\n",
" vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n",
" v = vu + cfg * (vc - vu)\n",
" else: v = model(x, t, labels)\n",
" x = x - dt * v.float()\n",
" return x\n",
"\n",
"class EMAModel:\n",
" def __init__(self, model, decay=0.9999):\n",
" self.decay = decay\n",
" self.shadow = {n: p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n",
" @torch.no_grad()\n",
" def update(self, model):\n",
" for n,p in model.named_parameters():\n",
" if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n",
" def apply(self, model):\n",
" self.backup = {n: p.data.clone() for n,p in model.named_parameters() if p.requires_grad}\n",
" for n,p in model.named_parameters():\n",
" if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n",
" def restore(self, model):\n",
" for n,p in model.named_parameters():\n",
" if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n])\n",
"\n",
"class ImageDataset(Dataset):\n",
" def __init__(self, ds, tf, img_col, lbl_col=\"\"): self.ds, self.tf, self.ic, self.lc = ds, tf, img_col, lbl_col\n",
" def __len__(self): return len(self.ds)\n",
" def __getitem__(self, i):\n",
" item = self.ds[i]; img = item[self.ic]\n",
" if img.mode != \"RGB\": img = img.convert(\"RGB\")\n",
" label = item[self.lc] if self.lc and self.lc in item else -1\n",
" return self.tf(img), label\n",
"\n",
"def cosine_sched(opt, warmup, total):\n",
" def lr(s):\n",
" if s < warmup: return s / max(1, warmup)\n",
" return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n",
" return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n",
"\n",
"MODEL_CONFIGS = {\n",
" \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n",
" \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n",
" \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n",
"}\n",
"print(\"\u2705 Training utilities ready!\")"
]
},
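{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (a minimal sketch with toy shapes): the interpolation `x_t = (1-t)\u00b7x\u2080 + t\u00b7\u03b5` hits the data at `t=0` and pure noise at `t=1`, and the velocity target `\u03b5 - x\u2080` is exactly `dx_t/dt`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify the flow-matching identities implemented above (toy shapes).\n",
"_fm = FlowMatchingScheduler()\n",
"_x0, _eps = torch.randn(2, 16, 8, 8), torch.randn(2, 16, 8, 8)\n",
"assert torch.allclose(_fm.add_noise(_x0, _eps, torch.zeros(2)), _x0)   # t=0 -> data\n",
"assert torch.allclose(_fm.add_noise(_x0, _eps, torch.ones(2)), _eps)   # t=1 -> noise\n",
"_t, _dt = torch.rand(2), 1e-3\n",
"_num = (_fm.add_noise(_x0, _eps, _t + _dt) - _fm.add_noise(_x0, _eps, _t)) / _dt\n",
"assert torch.allclose(_num, _fm.get_velocity_target(_x0, _eps), atol=1e-2)  # d(x_t)/dt = eps - x0\n",
"print(\"\u2705 Flow matching identities hold\")"
]
},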
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udcca Load Dataset & VAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from diffusers import AutoencoderKL\n",
"\n",
"print(f\"Loading dataset: {DATASET_NAME}...\")\n",
"dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
"print(f\" {len(dataset)} images\")\n",
"\n",
"transform = transforms.Compose([\n",
" transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
" transforms.CenterCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip(), transforms.ToTensor(),\n",
"])\n",
"train_ds = ImageDataset(dataset, transform, IMAGE_COLUMN, LABEL_COLUMN)\n",
"train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"vae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\n",
"for p in vae.parameters(): p.requires_grad_(False)\n",
"print(f\"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\n",
"SCALE, SHIFT = 0.3611, 0.1159\n",
"print(\"\u2705 Ready!\")"
]
},
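{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional sanity check (a minimal sketch; pulls one batch from the loader): encode to the 16-channel latent space, apply the SCALE/SHIFT normalization used in training, then invert it and decode \u2014 the round trip should come back close to the input images."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Verify the VAE round trip and the latent normalization.\n",
"imgs_chk, _ = next(iter(train_loader))\n",
"imgs_chk = imgs_chk.to(device)\n",
"with torch.no_grad():\n",
"    lats_chk = vae.encode(imgs_chk.half()*2-1).latent_dist.sample()\n",
"    normed = (lats_chk - SHIFT) * SCALE                   # normalization used in training\n",
"    recon = vae.decode(normed/SCALE + SHIFT).sample       # exact inverse, then decode\n",
"    recon = ((recon + 1)/2).clamp(0,1).float()\n",
"print(f\"Latents: {tuple(lats_chk.shape)} ({IMAGE_SIZE}px -> {IMAGE_SIZE//8}px, 16ch)\")\n",
"print(f\"Normalized latent std: {normed.std().item():.3f}\")\n",
"print(f\"Round-trip MAE: {(recon - imgs_chk).abs().mean().item():.4f}\")"
]
},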
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83c\udfcb\ufe0f Create Model & Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\n",
"cfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\n",
"model = LiquidGen(**cfg).to(device)\n",
"print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n",
"total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION\n",
"scheduler = cosine_sched(optimizer, WARMUP_STEPS, total_steps)\n",
"ema = EMAModel(model, EMA_DECAY)\n",
"scaler = GradScaler(\"cuda\")\n",
"fm = FlowMatchingScheduler()\n",
"os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n",
"os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n",
"print(f\"Total steps: {total_steps}, Effective batch: {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n",
"if torch.cuda.is_available(): print(f\"VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB used\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"global_step = 0; loss_accum = 0.0; log_losses = []\n",
"print(\"\ud83d\ude80 Training!\n\")\n",
"t0 = time.time()\n",
"for epoch in range(NUM_EPOCHS):\n",
" model.train(); ep_loss = 0; ep_steps = 0; ep_t = time.time()\n",
" for bi, (imgs, lbls) in enumerate(train_loader):\n",
" imgs = imgs.to(device)\n",
" lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n",
" with torch.no_grad():\n",
" lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n",
" lats = ((lats - SHIFT) * SCALE).float()\n",
" t = fm.sample_timesteps(lats.shape[0], device)\n",
" noise = torch.randn_like(lats)\n",
" xt = fm.add_noise(lats, noise, t)\n",
" vtgt = fm.get_velocity_target(lats, noise)\n",
" with autocast(\"cuda\"): loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n",
" scaler.scale(loss).backward()\n",
" loss_accum += loss.item()\n",
" if (bi+1) % GRADIENT_ACCUMULATION == 0:\n",
" scaler.unscale_(optimizer)\n",
" gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
" scaler.step(optimizer); scaler.update(); optimizer.zero_grad(); scheduler.step()\n",
" ema.update(model); global_step += 1\n",
" if global_step % LOG_EVERY == 0:\n",
" al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n",
" vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n",
" print(f\"step={global_step:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G\")\n",
" log_losses.append(al); loss_accum = 0\n",
" if math.isnan(al) or al > 50: print(\"\ud83d\udca5 Diverged!\"); break\n",
" if global_step % SAMPLE_EVERY == 0:\n",
" ema.apply(model); model.eval()\n",
" ls = IMAGE_SIZE // 8\n",
" sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n",
" samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n",
" with torch.no_grad(): si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
" save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n",
" ema.restore(model); model.train()\n",
" if global_step % SAVE_EVERY == 0:\n",
" torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n",
" f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n",
" ep_loss += loss.item()*GRADIENT_ACCUMULATION; ep_steps += 1\n",
" print(f\"Epoch {epoch} | loss={ep_loss/max(ep_steps,1):.4f} | {time.time()-ep_t:.0f}s\")\n",
"torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n",
"print(f\"\ud83c\udf89 Done! {global_step} steps in {(time.time()-t0)/60:.1f} min\")"
]
},
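{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the runtime restarts, a checkpoint saved above can be reloaded as sketched below (adjust `ckpt_path` to the step you want; the keys match the `torch.save` calls in the training cell)."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Optional: reload a saved checkpoint (model weights + EMA shadow + config).\n",
"ckpt_path = f\"{OUTPUT_DIR}/checkpoints/final.pt\"  # example path\n",
"if os.path.exists(ckpt_path):\n",
"    ckpt = torch.load(ckpt_path, map_location=device)\n",
"    model = LiquidGen(**ckpt[\"cfg\"]).to(device)\n",
"    model.load_state_dict(ckpt[\"model\"])\n",
"    ema = EMAModel(model, EMA_DECAY)\n",
"    ema.shadow = {k: v.to(device) for k, v in ckpt[\"ema\"].items()}\n",
"    print(f\"Loaded checkpoint from step {ckpt.get('step', 0)}\")\n",
"else:\n",
"    print(f\"No checkpoint at {ckpt_path} \u2014 run training first.\")"
]
},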
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udcc8 Training Loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"if log_losses:\n",
" plt.figure(figsize=(10,4)); plt.plot(log_losses); plt.xlabel(f\"Steps (\u00d7{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n",
" plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3); plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n",
" print(f\"Min loss: {min(log_losses):.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83c\udfa8 Generate Images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ema.apply(model); model.eval()\n",
"N, STEPS, CFG = 8, 50, 2.5\n",
"ls = IMAGE_SIZE // 8\n",
"STYLES = [\"Abstract Expressionism\",\"Baroque\",\"Cubism\",\"Expressionism\",\"Impressionism\",\n",
" \"Pop Art\",\"Realism\",\"Romanticism\",\"Symbolism\",\"Ukiyo-e\"]\n",
"if NUM_CLASSES > 0:\n",
" for ci in range(min(NUM_CLASSES, 8)):\n",
" l = torch.full((N,), ci, device=device, dtype=torch.long)\n",
" s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, CFG)\n",
" with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
" nm = STYLES[ci] if ci < len(STYLES) else f\"Class_{ci}\"\n",
" save_image(i, f\"{OUTPUT_DIR}/gen_{nm.replace(chr(32),chr(95))}.png\", nrow=4)\n",
" print(f\"Generated: {nm}\")\n",
"else:\n",
" s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n",
" with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
" save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\n",
"ema.restore(model)\n",
"print(f\"\u2705 Saved to {OUTPUT_DIR}/\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udce4 Display Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"import glob\n",
"for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n",
"    print(os.path.basename(f)); display(Image.open(f))\n",
"for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n",
"    print(os.path.basename(f)); display(Image.open(f))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udcdd Architecture Reference\n",
"\n",
"### Core Equation (CfC Liquid Dynamics)\n",
"\n",
"\n",
"### Flow Matching\n",
"\n",
"\n",
"### Sampling (Euler ODE)\n",
"\n",
"\n",
"### References\n",
"- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n",
"- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature MI 2022)\n",
"- Lechner et al., \"Neural Circuit Policies\" (Nature MI 2020)\n",
"- ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n",
"- Lipman et al., \"Flow Matching\" (2023), SiT (2024)\n"
]
}
]
} |