krystv committed · Commit 6820907 · verified · 1 Parent(s): a9fae37

Upload LiquidDiffusion_Training.ipynb

Files changed (1)
  1. LiquidDiffusion_Training.ipynb +675 -873
LiquidDiffusion_Training.ipynb CHANGED
@@ -1,876 +1,678 @@
1
  {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": [],
7
- "gpuType": "T4"
8
  },
9
- "kernelspec": {
10
- "name": "python3",
11
- "display_name": "Python 3"
12
- },
13
- "accelerator": "GPU"
14
- },
15
- "cells": [
16
- {
17
- "cell_type": "markdown",
18
- "metadata": {},
19
- "source": [
20
- "# 🌊 LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n",
21
- "\n",
22
- "**A novel image generation model** combining:\n",
23
- "- **Liquid Neural Networks** (CfC — Closed-form Continuous-depth) for adaptive, time-aware processing\n",
24
- "- **Rectified Flow** for simple, stable training (MSE velocity prediction)\n",
25
- "- **Zero attention** — fully convolutional with multi-scale spatial mixing\n",
26
- "- **Fully parallelizable** — no sequential ODE loops or recurrence\n",
27
- "\n",
28
- "### Key Innovation\n",
29
- "The diffusion timestep serves as the **liquid time constant** — the CfC gate `σ(-f·t)` naturally adapts the network's behavior based on noise level, giving input-dependent processing without attention.\n",
30
- "\n",
31
- "### References\n",
32
- "- CfC Networks: [Hasani et al., Nature MI 2022](https://arxiv.org/abs/2106.13898)\n",
33
- "- LiquidTAD (parallel CfC): [arxiv 2604.18274](https://arxiv.org/abs/2604.18274)\n",
34
- "- USM (U-Shape Mamba): [arxiv 2504.13499](https://arxiv.org/abs/2504.13499)\n",
35
- "- Rectified Flow: [Liu et al., ICLR 2023](https://arxiv.org/abs/2209.03003)\n",
36
- "\n",
37
- "**Model repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)\n",
38
- "\n",
39
- "---"
40
- ]
41
- },
42
- {
43
- "cell_type": "markdown",
44
- "metadata": {},
45
- "source": [
46
- "## ⚙️ Configuration\n",
47
- "\n",
48
- "Choose your settings here. Everything is configurable from this one cell."
49
- ]
50
- },
51
- {
52
- "cell_type": "code",
53
- "execution_count": null,
54
- "metadata": {},
55
- "outputs": [],
56
- "source": [
57
- "# ============================================================================\n",
58
- "# CONFIGURATION — Edit this cell to customize your training\n",
59
- "# ============================================================================\n",
60
- "\n",
61
- "# --- Model Size ---\n",
62
- "# 'tiny' = ~23M params, best for 256px, fits easily in T4 16GB\n",
63
- "# 'small' = ~69M params, better quality 256px, tight fit on T4\n",
64
- "# 'base' = ~154M params, for 512px (needs A100 or reduce batch size)\n",
65
- "# 'custom' = define your own channels/blocks below\n",
66
- "MODEL_SIZE = 'tiny' # @param ['tiny', 'small', 'base', 'custom']\n",
67
- "\n",
68
- "# Custom model config (only used if MODEL_SIZE='custom')\n",
69
- "CUSTOM_CHANNELS = [48, 96, 192] # channel dims per stage\n",
70
- "CUSTOM_BLOCKS = [1, 2, 3] # blocks per stage\n",
71
- "CUSTOM_T_DIM = 192 # time embedding dimension\n",
72
- "\n",
73
- "# --- Image Resolution ---\n",
74
- "IMAGE_SIZE = 256 # @param [64, 128, 256, 512] {type:\"integer\"}\n",
75
- "\n",
76
- "# --- Dataset ---\n",
77
- "# Options:\n",
78
- "# 'huggan/CelebA-HQ' - 30K celebrity faces (256px native)\n",
79
- "# 'huggan/flowers-102-categories' - Flowers dataset\n",
80
- "# 'huggan/anime-faces' - Anime faces\n",
81
- "# 'lambdalabs/naruto-blip-captions' - Naruto illustrations\n",
82
- "# 'jlbaker361/CelebA-HQ-256' - CelebA-HQ at 256px\n",
83
- "# Or any HF dataset with an 'image' column\n",
84
- "# Or a local folder path with images\n",
85
- "DATASET = 'huggan/CelebA-HQ' # @param {type:\"string\"}\n",
86
- "IMAGE_COLUMN = 'image' # column name in HF dataset containing images\n",
87
- "MAX_SAMPLES = None # Set to e.g. 1000 for quick testing, None for full dataset\n",
88
- "\n",
89
- "# --- Training ---\n",
90
- "BATCH_SIZE = 8 # @param {type:\"integer\"}\n",
91
- "LEARNING_RATE = 1e-4 # @param {type:\"number\"}\n",
92
- "WEIGHT_DECAY = 0.01 # @param {type:\"number\"}\n",
93
- "NUM_EPOCHS = 100 # @param {type:\"integer\"}\n",
94
- "GRAD_CLIP = 1.0 # @param {type:\"number\"}\n",
95
- "EMA_DECAY = 0.9999 # @param {type:\"number\"}\n",
96
- "NUM_WORKERS = 2 # DataLoader workers\n",
97
- "\n",
98
- "# --- Time Sampling ---\n",
99
- "# 'logit_normal' (from SD3) = more weight on intermediate timesteps\n",
100
- "# 'uniform' = standard\n",
101
- "TIME_SAMPLING = 'logit_normal' # @param ['uniform', 'logit_normal']\n",
102
- "\n",
103
- "# --- Mixed Precision ---\n",
104
- "USE_AMP = True # @param {type:\"boolean\"}\n",
105
- "AMP_DTYPE = 'float16' # @param ['float16', 'bfloat16']\n",
106
- "\n",
107
- "# --- Sampling ---\n",
108
- "SAMPLE_EVERY = 500 # Generate samples every N steps\n",
109
- "NUM_SAMPLE_IMAGES = 8 # Images to generate per sample\n",
110
- "NUM_EULER_STEPS = 50 # Euler ODE steps (more = better quality)\n",
111
- "\n",
112
- "# --- Checkpointing ---\n",
113
- "SAVE_EVERY = 2000 # Save checkpoint every N steps\n",
114
- "OUTPUT_DIR = './outputs' # Where to save everything\n",
115
- "RESUME_FROM = None # Path to checkpoint to resume from, or None\n",
116
- "\n",
117
- "# --- Logging ---\n",
118
- "LOG_EVERY = 50 # Print loss every N steps\n",
119
- "\n",
120
- "print(f\"Config: {MODEL_SIZE} model, {IMAGE_SIZE}px, dataset={DATASET}\")\n",
121
- "print(f\"Training: bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}\")\n",
122
- "print(f\"AMP: {USE_AMP} ({AMP_DTYPE}), Time sampling: {TIME_SAMPLING}\")"
123
- ]
124
- },
125
- {
126
- "cell_type": "markdown",
127
- "metadata": {},
128
- "source": [
129
- "## 📦 Install Dependencies"
130
- ]
131
- },
132
- {
133
- "cell_type": "code",
134
- "execution_count": null,
135
- "metadata": {},
136
- "outputs": [],
137
- "source": [
138
- "!pip install -q datasets huggingface_hub Pillow matplotlib\n",
139
- "\n",
140
- "import torch\n",
141
- "print(f\"PyTorch: {torch.__version__}\")\n",
142
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
143
- "if torch.cuda.is_available():\n",
144
- " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
145
- " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")"
146
- ]
147
- },
148
- {
149
- "cell_type": "markdown",
150
- "metadata": {},
151
- "source": [
152
- "## 🏗️ Model Architecture\n",
153
- "\n",
154
- "The complete LiquidDiffusion model — defined inline so you can inspect and modify everything."
155
- ]
156
- },
157
- {
158
- "cell_type": "code",
159
- "execution_count": null,
160
- "metadata": {},
161
- "outputs": [],
162
- "source": [
163
- "import math\n",
164
- "import copy\n",
165
- "import os\n",
166
- "import time\n",
167
- "import json\n",
168
- "from glob import glob\n",
169
- "\n",
170
- "import torch\n",
171
- "import torch.nn as nn\n",
172
- "import torch.nn.functional as F\n",
173
- "from torch.utils.data import DataLoader, Dataset\n",
174
- "from torchvision import transforms\n",
175
- "from torchvision.utils import save_image, make_grid\n",
176
- "\n",
177
- "\n",
178
- "# ========================= TIME EMBEDDING =========================\n",
179
- "\n",
180
- "class SinusoidalTimeEmbedding(nn.Module):\n",
181
- " \"\"\"Sinusoidal position encoding + MLP for timestep embedding.\"\"\"\n",
182
- " def __init__(self, dim, max_period=10000):\n",
183
- " super().__init__()\n",
184
- " self.dim = dim\n",
185
- " self.max_period = max_period\n",
186
- " self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n",
187
- "\n",
188
- " def forward(self, t):\n",
189
- " half = self.dim // 2\n",
190
- " freqs = torch.exp(-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n",
191
- " args = t[:, None] * freqs[None, :]\n",
192
- " emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n",
193
- " if self.dim % 2: emb = F.pad(emb, (0, 1))\n",
194
- " return self.mlp(emb)\n",
195
- "\n",
196
- "\n",
197
- "# ========================= ADAPTIVE LAYER NORM =========================\n",
198
- "\n",
199
- "class AdaLN(nn.Module):\n",
200
- " \"\"\"Adaptive LayerNorm: norm(x) * (1+scale(t)) + shift(t)\"\"\"\n",
201
- " def __init__(self, dim, cond_dim):\n",
202
- " super().__init__()\n",
203
- " ng = min(32, dim)\n",
204
- " while dim % ng != 0: ng -= 1\n",
205
- " self.norm = nn.GroupNorm(ng, dim, affine=False)\n",
206
- " self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))\n",
207
- "\n",
208
- " def forward(self, x, t_emb):\n",
209
- " s, sh = self.proj(t_emb).chunk(2, dim=1)\n",
210
- " return self.norm(x) * (1 + s[:,:,None,None]) + sh[:,:,None,None]\n",
211
- "\n",
212
- "\n",
213
- "# ========================= PARALLEL CfC BLOCK =========================\n",
214
- "\n",
215
- "class ParallelCfCBlock(nn.Module):\n",
216
- " \"\"\"\n",
217
- " Parallel Closed-form Continuous-depth (CfC) block.\n",
218
- " \n",
219
- " CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1-σ(-f·t)) ⊙ h\n",
220
- " \n",
221
- " • f/g/h heads operate on 2D feature maps (depthwise conv)\n",
222
- " • Diffusion timestep t IS the liquid time constant\n",
223
- " • No recurrence, no ODE solver — fully parallel\n",
224
- " • Liquid relaxation: α·residual + (1-α)·CfC_out, α=exp(-λ·t)\n",
225
- " \"\"\"\n",
226
- " def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n",
227
- " super().__init__()\n",
228
- " hidden = int(dim * expand_ratio)\n",
229
- " self.backbone_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)\n",
230
- " self.backbone_pw = nn.Conv2d(dim, hidden, 1)\n",
231
- " self.backbone_act = nn.SiLU()\n",
232
- " self.f_head = nn.Conv2d(hidden, dim, 1)\n",
233
- " self.g_head = nn.Sequential(\n",
234
- " nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size//2, groups=hidden),\n",
235
- " nn.SiLU(), nn.Conv2d(hidden, dim, 1))\n",
236
- " self.h_head = nn.Sequential(\n",
237
- " nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size//2, groups=hidden),\n",
238
- " nn.SiLU(), nn.Conv2d(hidden, dim, 1))\n",
239
- " self.time_a = nn.Linear(t_dim, dim)\n",
240
- " self.time_b = nn.Linear(t_dim, dim)\n",
241
- " self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))\n",
242
- " self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))\n",
243
- " self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n",
244
- "\n",
245
- " def forward(self, x, t_emb):\n",
246
- " residual = x\n",
247
- " backbone = self.backbone_act(self.backbone_pw(self.backbone_dw(x)))\n",
248
- " f, g, h = self.f_head(backbone), self.g_head(backbone), self.h_head(backbone)\n",
249
- " ta = self.time_a(t_emb)[:,:,None,None]\n",
250
- " tb = self.time_b(t_emb)[:,:,None,None]\n",
251
- " gate = torch.sigmoid(ta * f - tb)\n",
252
- " cfc_out = self.dropout(gate * g + (1.0 - gate) * h)\n",
253
- " t_scalar = t_emb.mean(dim=1, keepdim=True)[:,:,None,None]\n",
254
- " lam = F.softplus(self.rho) + 1e-6\n",
255
- " alpha = torch.exp(-lam * t_scalar.abs().clamp(min=0.01))\n",
256
- " out = alpha * residual + (1.0 - alpha) * cfc_out\n",
257
- " return out * torch.sigmoid(self.output_gate(t_emb))[:,:,None,None]\n",
258
- "\n",
259
- "\n",
260
- "# ========================= MULTI-SCALE SPATIAL MIXING =========================\n",
261
- "\n",
262
- "class MultiScaleSpatialMix(nn.Module):\n",
263
- " \"\"\"Multi-scale depthwise conv + global pooling (replaces attention).\"\"\"\n",
264
- " def __init__(self, dim, t_dim):\n",
265
- " super().__init__()\n",
266
- " self.dw3 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n",
267
- " self.dw5 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n",
268
- " self.dw7 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n",
269
- " self.global_pool = nn.AdaptiveAvgPool2d(1)\n",
270
- " self.global_proj = nn.Conv2d(dim, dim, 1)\n",
271
- " self.merge = nn.Conv2d(dim*4, dim, 1)\n",
272
- " self.act = nn.SiLU()\n",
273
- " self.adaln = AdaLN(dim, t_dim)\n",
274
- "\n",
275
- " def forward(self, x, t_emb):\n",
276
- " xn = self.adaln(x, t_emb)\n",
277
- " return x + self.act(self.merge(torch.cat([\n",
278
- " self.dw3(xn), self.dw5(xn), self.dw7(xn),\n",
279
- " self.global_proj(self.global_pool(xn)).expand_as(xn)], dim=1)))\n",
280
- "\n",
281
- "\n",
282
- "# ========================= LIQUID DIFFUSION BLOCK =========================\n",
283
- "\n",
284
- "class LiquidDiffusionBlock(nn.Module):\n",
285
- " \"\"\"AdaLN → CfC → SpatialMix → FF with residual scaling.\"\"\"\n",
286
- " def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n",
287
- " super().__init__()\n",
288
- " self.adaln1 = AdaLN(dim, t_dim)\n",
289
- " self.cfc = ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)\n",
290
- " self.spatial_mix = MultiScaleSpatialMix(dim, t_dim)\n",
291
- " self.adaln2 = AdaLN(dim, t_dim)\n",
292
- " ff_dim = int(dim * expand_ratio)\n",
293
- " self.ff = nn.Sequential(nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1))\n",
294
- " self.res_scale = nn.Parameter(torch.ones(1) * 0.1)\n",
295
- "\n",
296
- " def forward(self, x, t_emb):\n",
297
- " x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb)\n",
298
- " x = self.spatial_mix(x, t_emb)\n",
299
- " x = x + self.res_scale * self.ff(self.adaln2(x, t_emb))\n",
300
- " return x\n",
301
- "\n",
302
- "\n",
303
- "# ========================= SPATIAL OPS =========================\n",
304
- "\n",
305
- "class DownSample(nn.Module):\n",
306
- " def __init__(self, in_d, out_d):\n",
307
- " super().__init__()\n",
308
- " self.conv = nn.Conv2d(in_d, out_d, 3, stride=2, padding=1)\n",
309
- " def forward(self, x): return self.conv(x)\n",
310
- "\n",
311
- "class UpSample(nn.Module):\n",
312
- " def __init__(self, in_d, out_d):\n",
313
- " super().__init__()\n",
314
- " self.conv = nn.Conv2d(in_d, out_d, 3, padding=1)\n",
315
- " def forward(self, x): return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))\n",
316
- "\n",
317
- "class SkipFusion(nn.Module):\n",
318
- " def __init__(self, dim, t_dim):\n",
319
- " super().__init__()\n",
320
- " self.proj = nn.Conv2d(dim*2, dim, 1)\n",
321
- " self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())\n",
322
- " def forward(self, x, skip, t_emb):\n",
323
- " m = self.proj(torch.cat([x, skip], dim=1))\n",
324
- " g = self.gate(t_emb)[:,:,None,None]\n",
325
- " return m * g + x * (1 - g)\n",
326
- "\n",
327
- "\n",
328
- "# ========================= LIQUID DIFFUSION U-NET =========================\n",
329
- "\n",
330
- "class LiquidDiffusionUNet(nn.Module):\n",
331
- " \"\"\"\n",
332
- " LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks.\n",
333
- " U-Net with Parallel CfC blocks. Diffusion timestep = liquid time constant.\n",
334
- " \"\"\"\n",
335
- " def __init__(self, in_channels=3, channels=None, blocks_per_stage=None,\n",
336
- " t_dim=256, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n",
337
- " super().__init__()\n",
338
- " channels = channels or [64, 128, 256]\n",
339
- " blocks_per_stage = blocks_per_stage or [2, 2, 4]\n",
340
- " assert len(channels) == len(blocks_per_stage)\n",
341
- " self.channels, self.num_stages = channels, len(channels)\n",
342
- " \n",
343
- " self.time_embed = SinusoidalTimeEmbedding(t_dim)\n",
344
- " self.stem = nn.Sequential(\n",
345
- " nn.Conv2d(in_channels, channels[0], 3, padding=1), nn.SiLU(),\n",
346
- " nn.Conv2d(channels[0], channels[0], 3, padding=1))\n",
347
- " \n",
348
- " # Encoder\n",
349
- " self.encoder_blocks = nn.ModuleList()\n",
350
- " self.downsamplers = nn.ModuleList()\n",
351
- " for i in range(self.num_stages):\n",
352
- " stage = nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout)\n",
353
- " for _ in range(blocks_per_stage[i])])\n",
354
- " self.encoder_blocks.append(stage)\n",
355
- " if i < self.num_stages - 1:\n",
356
- " self.downsamplers.append(DownSample(channels[i], channels[i+1]))\n",
357
- " \n",
358
- " # Bottleneck\n",
359
- " self.bottleneck = nn.ModuleList([\n",
360
- " LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout),\n",
361
- " LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout)])\n",
362
- " \n",
363
- " # Decoder\n",
364
- " self.decoder_blocks, self.upsamplers, self.skip_fusions = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()\n",
365
- " for i in range(self.num_stages-1, -1, -1):\n",
366
- " if i < self.num_stages - 1:\n",
367
- " self.upsamplers.append(UpSample(channels[i+1], channels[i]))\n",
368
- " self.skip_fusions.append(SkipFusion(channels[i], t_dim))\n",
369
- " self.decoder_blocks.append(nn.ModuleList([\n",
370
- " LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout)\n",
371
- " for _ in range(blocks_per_stage[i])]))\n",
372
- " \n",
373
- " hg = min(32, channels[0])\n",
374
- " while channels[0] % hg != 0: hg -= 1\n",
375
- " self.head = nn.Sequential(nn.GroupNorm(hg, channels[0]), nn.SiLU(),\n",
376
- " nn.Conv2d(channels[0], in_channels, 3, padding=1))\n",
377
- " nn.init.zeros_(self.head[-1].weight); nn.init.zeros_(self.head[-1].bias)\n",
378
- "\n",
379
- " def forward(self, x, t):\n",
380
- " t_emb = self.time_embed(t)\n",
381
- " h = self.stem(x)\n",
382
- " skips = []\n",
383
- " for i in range(self.num_stages):\n",
384
- " for blk in self.encoder_blocks[i]: h = blk(h, t_emb)\n",
385
- " skips.append(h)\n",
386
- " if i < self.num_stages - 1: h = self.downsamplers[i](h)\n",
387
- " for blk in self.bottleneck: h = blk(h, t_emb)\n",
388
- " up_idx = 0\n",
389
- " for dec_i in range(self.num_stages):\n",
390
- " si = self.num_stages - 1 - dec_i\n",
391
- " if dec_i > 0:\n",
392
- " h = self.upsamplers[up_idx](h)\n",
393
- " h = self.skip_fusions[up_idx](h, skips[si], t_emb)\n",
394
- " up_idx += 1\n",
395
- " for blk in self.decoder_blocks[dec_i]: h = blk(h, t_emb)\n",
396
- " return self.head(h)\n",
397
- "\n",
398
- " def count_params(self):\n",
399
- " return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
400
- "\n",
401
- "\n",
402
- "print(\"✅ Model architecture defined.\")"
403
- ]
404
- },
405
- {
406
- "cell_type": "markdown",
407
- "metadata": {},
408
- "source": [
409
- "## 🔧 Build Model"
410
- ]
411
- },
412
- {
413
- "cell_type": "code",
414
- "execution_count": null,
415
- "metadata": {},
416
- "outputs": [],
417
- "source": [
418
- "# Build model based on config\n",
419
- "MODEL_CONFIGS = {\n",
420
- " 'tiny': dict(channels=[64, 128, 256], blocks_per_stage=[2, 2, 4], t_dim=256),\n",
421
- " 'small': dict(channels=[96, 192, 384], blocks_per_stage=[2, 3, 6], t_dim=384),\n",
422
- " 'base': dict(channels=[128, 256, 512], blocks_per_stage=[2, 4, 8], t_dim=512),\n",
423
- "}\n",
424
- "\n",
425
- "if MODEL_SIZE == 'custom':\n",
426
- " config = dict(channels=CUSTOM_CHANNELS, blocks_per_stage=CUSTOM_BLOCKS, t_dim=CUSTOM_T_DIM)\n",
427
- "else:\n",
428
- " config = MODEL_CONFIGS[MODEL_SIZE]\n",
429
- "\n",
430
- "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
431
- "model = LiquidDiffusionUNet(**config).to(device)\n",
432
- "total_params, trainable_params = model.count_params()\n",
433
- "\n",
434
- "print(f\"Model: {MODEL_SIZE}\")\n",
435
- "print(f\" Channels: {config['channels']}\")\n",
436
- "print(f\" Blocks: {config['blocks_per_stage']}\")\n",
437
- "print(f\" t_dim: {config['t_dim']}\")\n",
438
- "print(f\" Total parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n",
439
- "print(f\" Device: {device}\")\n",
440
- "\n",
441
- "# Quick forward pass test\n",
442
- "with torch.no_grad():\n",
443
- " test_x = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)\n",
444
- " test_t = torch.tensor([0.5], device=device)\n",
445
- " test_out = model(test_x, test_t)\n",
446
- " print(f\" Forward pass OK: {test_x.shape} → {test_out.shape}\")\n",
447
- " del test_x, test_out\n",
448
- " if device == 'cuda':\n",
449
- " torch.cuda.empty_cache()\n",
450
- " print(f\" VRAM after test: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
451
- ]
452
- },
453
- {
454
- "cell_type": "markdown",
455
- "metadata": {},
456
- "source": [
457
- "## 📊 Load Dataset"
458
- ]
459
- },
460
- {
461
- "cell_type": "code",
462
- "execution_count": null,
463
- "metadata": {},
464
- "outputs": [],
465
- "source": [
466
- "from PIL import Image\n",
467
- "\n",
468
- "class ImageDataset(Dataset):\n",
469
- " def __init__(self, source, image_size=256, image_column='image', max_samples=None):\n",
470
- " self.image_column = image_column\n",
471
- " self.transform = transforms.Compose([\n",
472
- " transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n",
473
- " transforms.CenterCrop(image_size),\n",
474
- " transforms.RandomHorizontalFlip(),\n",
475
- " transforms.ToTensor(),\n",
476
- " transforms.Normalize([0.5], [0.5]),\n",
477
- " ])\n",
478
- " if os.path.isdir(source):\n",
479
- " self.files = sorted(sum([glob(os.path.join(source, '**', f'*.{e}'), recursive=True)\n",
480
- " for e in ['png','jpg','jpeg','webp','bmp']], []))\n",
481
- " if max_samples: self.files = self.files[:max_samples]\n",
482
- " self.mode = 'folder'\n",
483
- " else:\n",
484
- " from datasets import load_dataset\n",
485
- " self.data = load_dataset(source, split='train')\n",
486
- " if max_samples: self.data = self.data.select(range(min(max_samples, len(self.data))))\n",
487
- " self.mode = 'hf'\n",
488
- " \n",
489
- " def __len__(self):\n",
490
- " return len(self.files) if self.mode == 'folder' else len(self.data)\n",
491
- " \n",
492
- " def __getitem__(self, idx):\n",
493
- " if self.mode == 'folder':\n",
494
- " img = Image.open(self.files[idx]).convert('RGB')\n",
495
- " else:\n",
496
- " img = self.data[idx][self.image_column]\n",
497
- " if not hasattr(img, 'convert'): img = Image.fromarray(img)\n",
498
- " img = img.convert('RGB')\n",
499
- " return self.transform(img)\n",
500
- "\n",
501
- "# Load dataset\n",
502
- "print(f\"Loading dataset: {DATASET}\")\n",
503
- "dataset = ImageDataset(DATASET, IMAGE_SIZE, IMAGE_COLUMN, MAX_SAMPLES)\n",
504
- "dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
505
- " num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n",
506
- "\n",
507
- "print(f\"Dataset size: {len(dataset):,} images\")\n",
508
- "print(f\"Steps per epoch: {len(dataloader):,}\")\n",
509
- "print(f\"Total steps: ~{len(dataloader) * NUM_EPOCHS:,}\")\n",
510
- "\n",
511
- "# Show sample\n",
512
- "import matplotlib.pyplot as plt\n",
513
- "sample_batch = next(iter(dataloader))\n",
514
- "fig, axes = plt.subplots(1, min(8, BATCH_SIZE), figsize=(16, 2))\n",
515
- "for i, ax in enumerate(axes):\n",
516
- " img = (sample_batch[i].permute(1,2,0) * 0.5 + 0.5).clamp(0,1)\n",
517
- " ax.imshow(img); ax.axis('off')\n",
518
- "plt.suptitle(f'Training samples ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()"
519
- ]
520
- },
521
- {
522
- "cell_type": "markdown",
523
- "metadata": {},
524
- "source": [
525
- "## 🚀 Training Loop"
526
- ]
527
- },
528
- {
529
- "cell_type": "code",
530
- "execution_count": null,
531
- "metadata": {},
532
- "outputs": [],
533
- "source": [
534
- "import matplotlib.pyplot as plt\n",
535
- "from IPython.display import clear_output, display\n",
536
- "\n",
537
- "# Setup\n",
538
- "os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
539
- "os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\n",
540
- "os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n",
541
- "\n",
542
- "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,\n",
543
- " weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))\n",
544
- "\n",
545
- "# Cosine LR schedule with warmup\n",
546
- "total_steps = len(dataloader) * NUM_EPOCHS\n",
547
- "warmup_steps = min(1000, total_steps // 10)\n",
548
- "\n",
549
- "def lr_lambda(step):\n",
550
- " if step < warmup_steps:\n",
551
- " return float(step) / float(max(1, warmup_steps))\n",
552
- " progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))\n",
553
- " return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
554
- "\n",
555
- "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n",
556
- "\n",
557
- "# EMA model\n",
558
- "ema_model = copy.deepcopy(model).eval()\n",
559
- "for p in ema_model.parameters(): p.requires_grad_(False)\n",
560
- "\n",
561
- "# AMP\n",
562
- "scaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device == 'cuda'))\n",
563
- "amp_dtype = getattr(torch, AMP_DTYPE) if USE_AMP and device == 'cuda' else torch.float32\n",
564
- "\n",
565
- "# Time sampling\n",
566
- "def sample_time(bs):\n",
567
- " eps = 1e-5\n",
568
- " if TIME_SAMPLING == 'uniform':\n",
569
- " return torch.rand(bs, device=device) * (1 - 2*eps) + eps\n",
570
- " else: # logit_normal\n",
571
- " return torch.sigmoid(torch.randn(bs, device=device)).clamp(eps, 1-eps)\n",
572
- "\n",
573
- "# Resume if requested\n",
574
- "global_step = 0\n",
575
- "start_epoch = 0\n",
576
- "all_losses = []\n",
577
- "\n",
578
- "if RESUME_FROM and os.path.exists(RESUME_FROM):\n",
579
- " ckpt = torch.load(RESUME_FROM, map_location=device, weights_only=False)\n",
580
- " model.load_state_dict(ckpt['model'])\n",
581
- " ema_model.load_state_dict(ckpt['ema_model'])\n",
582
- " optimizer.load_state_dict(ckpt['optimizer'])\n",
583
- " global_step = ckpt.get('step', 0)\n",
584
- " start_epoch = ckpt.get('epoch', 0)\n",
585
- " all_losses = ckpt.get('losses', [])\n",
586
- " print(f\"Resumed from step {global_step}, epoch {start_epoch}\")\n",
587
- "\n",
588
- "\n",
589
- "@torch.no_grad()\n",
590
- "def generate_samples(step):\n",
591
- " \"\"\"Generate and save sample images.\"\"\"\n",
592
- " ema_model.eval()\n",
593
- " z = torch.randn(NUM_SAMPLE_IMAGES, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)\n",
594
- " dt = 1.0 / NUM_EULER_STEPS\n",
595
- " for i in range(NUM_EULER_STEPS, 0, -1):\n",
596
- " t = torch.full((NUM_SAMPLE_IMAGES,), i / NUM_EULER_STEPS, device=device)\n",
597
- " with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n",
598
- " v = ema_model(z, t)\n",
599
- " if USE_AMP and amp_dtype == torch.float16: v = v.float()\n",
600
- " z = z - v * dt\n",
601
- " z = z.clamp(-1, 1)\n",
602
- " grid = make_grid(z * 0.5 + 0.5, nrow=int(math.sqrt(NUM_SAMPLE_IMAGES)), padding=2)\n",
603
- " save_image(grid, f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n",
604
- " return z\n",
605
- "\n",
606
- "\n",
607
- "# ========== TRAINING LOOP ==========\n",
608
- "print(f\"\\n{'='*60}\")\n",
609
- "print(f\"Starting training: {NUM_EPOCHS} epochs, {total_steps:,} total steps\")\n",
610
- "print(f\"Warmup: {warmup_steps} steps, LR: {LEARNING_RATE}\")\n",
611
- "print(f\"{'='*60}\\n\")\n",
612
- "\n",
613
- "train_start = time.time()\n",
614
- "epoch_losses = []\n",
615
- "\n",
616
- "for epoch in range(start_epoch, NUM_EPOCHS):\n",
617
- " model.train()\n",
618
- " epoch_loss = 0\n",
619
- " \n",
620
- " for batch_idx, x0 in enumerate(dataloader):\n",
621
- " x0 = x0.to(device, non_blocking=True)\n",
622
- " \n",
623
- " # Rectified Flow: x_t = (1-t)*x0 + t*x1, target = x1 - x0\n",
624
- " x1 = torch.randn_like(x0)\n",
625
- " t = sample_time(x0.shape[0])\n",
626
- " te = t[:, None, None, None]\n",
627
- " x_t = (1 - te) * x0 + te * x1\n",
628
- " v_target = x1 - x0\n",
629
- " \n",
630
- " with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n",
631
- " v_pred = model(x_t, t)\n",
632
- " loss = F.mse_loss(v_pred, v_target)\n",
633
- " \n",
634
- " optimizer.zero_grad(set_to_none=True)\n",
635
- " scaler.scale(loss).backward()\n",
636
- " if GRAD_CLIP > 0:\n",
637
- " scaler.unscale_(optimizer)\n",
638
- " torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
639
- " scaler.step(optimizer)\n",
640
- " scaler.update()\n",
641
- " scheduler.step()\n",
642
- " \n",
643
- " # EMA update\n",
644
- " with torch.no_grad():\n",
645
- " for ep, mp in zip(ema_model.parameters(), model.parameters()):\n",
646
- " ep.data.mul_(EMA_DECAY).add_(mp.data, alpha=1-EMA_DECAY)\n",
647
- " \n",
648
- " global_step += 1\n",
649
- " loss_val = loss.item()\n",
650
- " all_losses.append(loss_val)\n",
651
- " epoch_loss += loss_val\n",
652
- " \n",
653
- " # Logging\n",
654
- " if global_step % LOG_EVERY == 0:\n",
655
- " avg_loss = sum(all_losses[-LOG_EVERY:]) / LOG_EVERY\n",
656
- " lr = scheduler.get_last_lr()[0]\n",
657
- " elapsed = time.time() - train_start\n",
658
- " steps_per_sec = global_step / elapsed\n",
659
- " eta = (total_steps - global_step) / max(steps_per_sec, 1e-8)\n",
660
- " if device == 'cuda':\n",
661
- " vram = torch.cuda.max_memory_allocated() / 1e9\n",
662
- " print(f\"Step {global_step:6d}/{total_steps} | Loss: {avg_loss:.4f} | LR: {lr:.2e} | {steps_per_sec:.1f} it/s | ETA: {eta/60:.0f}m | VRAM: {vram:.1f}GB\")\n",
663
- " else:\n",
664
- " print(f\"Step {global_step:6d}/{total_steps} | Loss: {avg_loss:.4f} | LR: {lr:.2e} | {steps_per_sec:.1f} it/s | ETA: {eta/60:.0f}m\")\n",
665
- " \n",
666
- " # Generate samples\n",
667
- " if global_step % SAMPLE_EVERY == 0:\n",
668
- " print(f\"\\n 📸 Generating samples at step {global_step}...\")\n",
669
- " samples = generate_samples(global_step)\n",
670
- " \n",
671
- " # Display in notebook\n",
672
- " fig, axes = plt.subplots(1, min(8, NUM_SAMPLE_IMAGES), figsize=(16, 2.5))\n",
673
- " if NUM_SAMPLE_IMAGES == 1: axes = [axes]\n",
674
- " for i, ax in enumerate(axes):\n",
675
- " if i < samples.shape[0]:\n",
676
- " img = (samples[i].cpu().permute(1,2,0) * 0.5 + 0.5).clamp(0,1)\n",
677
- " ax.imshow(img); ax.axis('off')\n",
678
- " plt.suptitle(f'Step {global_step} | Loss: {loss_val:.4f}'); plt.tight_layout(); plt.show()\n",
679
- " \n",
680
- " # Save checkpoint\n",
681
- " if global_step % SAVE_EVERY == 0:\n",
682
- " ckpt_path = f'{OUTPUT_DIR}/checkpoints/step_{global_step:06d}.pt'\n",
683
- " torch.save({\n",
684
- " 'model': model.state_dict(),\n",
685
- " 'ema_model': ema_model.state_dict(),\n",
686
- " 'optimizer': optimizer.state_dict(),\n",
687
- " 'step': global_step,\n",
688
- " 'epoch': epoch,\n",
689
- " 'losses': all_losses[-2000:],\n",
690
- " 'config': config,\n",
691
- " }, ckpt_path)\n",
692
- " print(f\" 💾 Saved checkpoint: {ckpt_path}\")\n",
693
- " \n",
694
- " # Epoch summary\n",
695
- " avg_epoch_loss = epoch_loss / len(dataloader)\n",
696
- " epoch_losses.append(avg_epoch_loss)\n",
697
- " print(f\"\\n Epoch {epoch+1}/{NUM_EPOCHS} complete | Avg loss: {avg_epoch_loss:.4f}\")\n",
698
- "\n",
699
- "# Final save\n",
700
- "final_path = f'{OUTPUT_DIR}/checkpoints/final.pt'\n",
701
- "torch.save({\n",
702
- " 'model': model.state_dict(),\n",
703
- " 'ema_model': ema_model.state_dict(),\n",
704
- " 'step': global_step,\n",
705
- " 'config': config,\n",
706
- " 'losses': all_losses[-2000:],\n",
707
- "}, final_path)\n",
708
- "print(f\"\\n✅ Training complete! Final checkpoint: {final_path}\")\n",
709
- "print(f\"Total time: {(time.time()-train_start)/3600:.1f} hours\")"
710
- ]
711
- },
712
- {
713
- "cell_type": "markdown",
714
- "metadata": {},
715
- "source": [
716
- "## 📈 Training Curves"
717
- ]
718
- },
719
- {
720
- "cell_type": "code",
721
- "execution_count": null,
722
- "metadata": {},
723
- "outputs": [],
724
- "source": [
725
- "import matplotlib.pyplot as plt\n",
726
- "import numpy as np\n",
727
- "\n",
728
- "fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
729
- "\n",
730
- "# Raw loss\n",
731
- "ax1.plot(all_losses, alpha=0.3, color='blue', linewidth=0.5)\n",
732
- "# Smoothed loss\n",
733
- "window = min(200, len(all_losses)//5)\n",
734
- "if window > 1:\n",
735
- " smoothed = np.convolve(all_losses, np.ones(window)/window, mode='valid')\n",
736
- " ax1.plot(range(window-1, len(all_losses)), smoothed, color='red', linewidth=2, label=f'Smoothed (w={window})')\n",
737
- "ax1.set_xlabel('Step'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss')\n",
738
- "ax1.legend(); ax1.grid(True, alpha=0.3)\n",
739
- "\n",
740
- "# Epoch loss\n",
741
- "if epoch_losses:\n",
742
- " ax2.plot(range(1, len(epoch_losses)+1), epoch_losses, 'o-', color='green')\n",
743
- " ax2.set_xlabel('Epoch'); ax2.set_ylabel('Avg Loss'); ax2.set_title('Loss per Epoch')\n",
744
- " ax2.grid(True, alpha=0.3)\n",
745
- "\n",
746
- "plt.tight_layout(); plt.show()"
747
- ]
748
- },
749
- {
750
- "cell_type": "markdown",
751
- "metadata": {},
752
- "source": [
753
- "## 🎨 Generate Images"
754
- ]
755
- },
756
- {
757
- "cell_type": "code",
758
- "execution_count": null,
759
- "metadata": {},
760
- "outputs": [],
761
- "source": [
762
- "# Generate a batch of images\n",
763
- "NUM_GENERATE = 16 # @param {type:\"integer\"}\n",
764
- "EULER_STEPS = 50 # @param {type:\"integer\"}\n",
765
- "\n",
766
- "print(f\"Generating {NUM_GENERATE} images with {EULER_STEPS} Euler steps...\")\n",
767
- "ema_model.eval()\n",
768
- "\n",
769
- "with torch.no_grad():\n",
770
- " z = torch.randn(NUM_GENERATE, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)\n",
771
- " dt = 1.0 / EULER_STEPS\n",
772
- " for i in range(EULER_STEPS, 0, -1):\n",
773
- " t = torch.full((NUM_GENERATE,), i / EULER_STEPS, device=device)\n",
774
- " with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n",
775
- " v = ema_model(z, t)\n",
776
- " if USE_AMP and amp_dtype == torch.float16: v = v.float()\n",
777
- " z = z - v * dt\n",
778
- " generated = z.clamp(-1, 1)\n",
779
- "\n",
780
- "# Display\n",
781
- "nrow = int(math.ceil(math.sqrt(NUM_GENERATE)))\n",
782
- "fig, axes = plt.subplots(nrow, nrow, figsize=(2.5*nrow, 2.5*nrow))\n",
783
- "axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]\n",
784
- "for i, ax in enumerate(axes):\n",
785
- " if i < NUM_GENERATE:\n",
786
- " img = (generated[i].cpu().permute(1,2,0) * 0.5 + 0.5).clamp(0,1)\n",
787
- " ax.imshow(img)\n",
788
- " ax.axis('off')\n",
789
- "plt.suptitle(f'LiquidDiffusion Samples ({IMAGE_SIZE}px, {EULER_STEPS} steps)', fontsize=14)\n",
790
- "plt.tight_layout(); plt.show()\n",
791
- "\n",
792
- "# Save\n",
793
- "grid = make_grid(generated * 0.5 + 0.5, nrow=nrow, padding=2)\n",
794
- "save_image(grid, f'{OUTPUT_DIR}/final_samples.png')\n",
795
- "print(f\"Saved to {OUTPUT_DIR}/final_samples.png\")"
796
- ]
797
- },
798
- {
799
- "cell_type": "markdown",
800
- "metadata": {},
801
- "source": [
802
- "## 💾 Save / Load Model"
803
- ]
804
- },
805
- {
806
- "cell_type": "code",
807
- "execution_count": null,
808
- "metadata": {},
809
- "outputs": [],
810
- "source": [
811
- "# Save to HuggingFace Hub (optional)\n",
812
- "PUSH_TO_HUB = False # @param {type:\"boolean\"}\n",
813
- "HUB_MODEL_ID = 'your-username/liquid-diffusion-celebahq-256' # @param {type:\"string\"}\n",
814
- "\n",
815
- "if PUSH_TO_HUB:\n",
816
- " from huggingface_hub import HfApi\n",
817
- " api = HfApi()\n",
818
- " api.create_repo(HUB_MODEL_ID, exist_ok=True)\n",
819
- " api.upload_file(\n",
820
- " path_or_fileobj=final_path,\n",
821
- " path_in_repo='model.pt',\n",
822
- " repo_id=HUB_MODEL_ID,\n",
823
- " )\n",
824
- " print(f\"Pushed to https://huggingface.co/{HUB_MODEL_ID}\")"
825
- ]
826
- },
827
- {
828
- "cell_type": "markdown",
829
- "metadata": {},
830
- "source": [
831
- "---\n",
832
- "\n",
833
- "## 📖 Architecture Deep Dive\n",
834
- "\n",
835
- "### What makes LiquidDiffusion special?\n",
836
- "\n",
837
- "**1. CfC Time-Gating (the \"liquid\" part)**\n",
838
- "```\n",
839
- "gate = σ(time_a(t_emb) · f(features) - time_b(t_emb))\n",
840
- "output = gate · g(features) + (1 - gate) · h(features)\n",
841
- "```\n",
842
- "- `f` = time-constant head (controls gate sensitivity)\n",
843
- "- `g` = \"from\" state (what features look like at short time)\n",
844
- "- `h` = \"to\" state (attractor for long time)\n",
845
- "- The gate adapts **per-channel, per-spatial-position** based on both the input features AND the noise level\n",
846
- "\n",
847
- "**2. Liquid Relaxation Residual**\n",
848
- "```\n",
849
- "α = exp(-λ · |t_emb_mean|)\n",
850
- "out = α · input + (1-α) · CfC_output\n",
851
- "```\n",
852
- "- When noise is high (large t): α→0, rely on CfC output (needs heavy processing)\n",
853
- "- When noise is low (small t): α→1, preserve input (just refine details)\n",
854
- "- λ is learned per-channel — each feature dimension decides its own decay rate\n",
855
- "\n",
856
- "**3. Multi-Scale Spatial Mixing**\n",
857
- "- 3×3 + 5×5 + 7×7 depthwise convolutions + global average pooling\n",
858
- "- Gives effective global receptive field without O(n²) attention\n",
859
- "- All parallel, all efficient\n",
860
- "\n",
861
- "### Why no attention?\n",
862
- "- Self-attention is O(n²) in spatial tokens — at 256px that's 65K tokens\n",
863
- "- Depthwise convolutions + global pooling give global context at O(n) cost\n",
864
- "- The CfC time-gating provides the \"adaptive routing\" that attention normally gives\n",
865
- "- Result: **same expressivity, 10× less memory, 3× faster**\n",
866
- "\n",
867
- "### Parameter counts\n",
868
- "| Config | Params | 256px VRAM | 512px VRAM |\n",
869
- "|--------|--------|------------|------------|\n",
870
- "| tiny | ~23M | ~6 GB | ~12 GB |\n",
871
- "| small | ~69M | ~10 GB | ~20 GB |\n",
872
- "| base | ~154M | ~16 GB | ~30 GB |"
873
- ]
874
- }
875
- ]
876
  }
 
1
  {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# 🌊 LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n",
8
+ "\n",
9
+ "**A novel image generation architecture** that replaces attention with Parallel CfC (Closed-form Continuous-depth) blocks from Liquid Neural Networks.\n",
10
+ "\n",
11
+ "## Key Innovations\n",
12
+ "- **No attention mechanism** — all spatial mixing via multi-scale depthwise convolutions\n",
13
+ "- **Fully parallelizable** — no sequential ODE solving loops (unlike original LTC/Neural ODE)\n",
14
+ "- **Diffusion timestep IS the liquid time constant** — natural CfC-diffusion bridge\n",
15
+ "- **Liquid relaxation residuals** — time-aware skip connections that adapt to noise level\n",
16
+ "- **Fits in 16GB VRAM** — designed for Colab free tier (T4 GPU)\n",
17
+ "\n",
18
+ "## Architecture Based On\n",
19
+ "- [CfC Networks](https://arxiv.org/abs/2106.13898) (Hasani et al., Nature Machine Intelligence 2022)\n",
20
+ "- [LiquidTAD](https://arxiv.org/abs/2604.18274) — parallel liquid relaxation\n",
21
+ "- [USM](https://arxiv.org/abs/2504.13499) — U-Shape architecture for diffusion\n",
22
+ "- [Rectified Flow](https://arxiv.org/abs/2209.03003) — simplest flow matching objective\n",
23
+ "\n",
24
+ "## Training: Rectified Flow\n",
25
+ "```\n",
26
+ "x_t = (1-t)*x0 + t*noise, t ~ U[0,1]\n",
27
+ "Loss = MSE(model(x_t, t), noise - x0) # velocity prediction\n",
28
+ "```\n",
29
+ "That's it — no noise schedule, no variance, just MSE on a straight-line velocity."
30
+ ]
31
+ },
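+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A minimal, self-contained sketch of this objective (illustrative only — `model` here stands for any `(x_t, t) → velocity` network, like the one trained below):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "def rectified_flow_loss(model, x0):\n",
+ "    t = torch.rand(x0.shape[0], device=x0.device)  # t ~ U[0,1]\n",
+ "    noise = torch.randn_like(x0)\n",
+ "    te = t[:, None, None, None]\n",
+ "    x_t = (1 - te) * x0 + te * noise               # straight-line interpolation\n",
+ "    v_target = noise - x0                          # constant velocity along the line\n",
+ "    return F.mse_loss(model(x_t, t), v_target)\n",
+ "```"
+ ]
+ },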
32
+ {
33
+ "cell_type": "markdown",
34
+ "metadata": {},
35
+ "source": [
36
+ "## 🔧 Setup"
37
+ ]
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "execution_count": null,
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "# Install dependencies\n",
46
+ "!pip install -q torch torchvision datasets Pillow matplotlib tqdm accelerate"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": null,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "# Clone the repo\n",
56
+ "!git clone https://huggingface.co/krystv/liquid-diffusion\n",
57
+ "%cd liquid-diffusion"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "import torch\n",
67
+ "print(f'PyTorch: {torch.__version__}')\n",
68
+ "print(f'CUDA available: {torch.cuda.is_available()}')\n",
69
+ "if torch.cuda.is_available():\n",
70
+ " print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
71
+ " print(f'VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB')"
72
+ ]
73
+ },
74
+ {
75
+ "cell_type": "markdown",
76
+ "metadata": {},
77
+ "source": [
78
+ "## 📐 Architecture Overview\n",
79
+ "\n",
80
+ "The core innovation is the **ParallelCfCBlock** — a parallelized version of CfC (Closed-form Continuous-depth) networks adapted for 2D image features:\n",
81
+ "\n",
82
+ "```\n",
83
+ "CfC Equation (Hasani et al. 2022, Eq. 10):\n",
84
+ " x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h\n",
85
+ "\n",
86
+ "Our adaptation for image generation:\n",
87
+ " backbone = SiLU(PointwiseConv(DepthwiseConv(features))) # shared spatial context\n",
88
+ " f = Conv1x1(backbone) # time-constant gate\n",
89
+ " g = DWConv→SiLU→Conv1x1(backbone) # \"from\" state\n",
90
+ " h = DWConv→SiLU→Conv1x1(backbone) # \"to\" state (attractor)\n",
91
+ " gate = σ(time_a(t_emb) · f - time_b(t_emb)) # liquid time gate\n",
92
+ " cfc_out = gate · g + (1 - gate) · h # CfC interpolation\n",
93
+ " \n",
94
+ " # Liquid relaxation (from LiquidTAD):\n",
95
+ " α = exp(-softplus(ρ) · |t|) # time-aware residual weight\n",
96
+ " output = α · input + (1 - α) · cfc_out # adapts to noise level\n",
97
+ "```\n",
98
+ "\n",
99
+ "The **diffusion timestep t** serves double duty:\n",
100
+ "1. Standard: conditions the denoiser via AdaLN scale/shift\n",
101
+ "2. Novel: acts as the CfC time parameter — controls interpolation between g and h\n",
102
+ "\n",
103
+ "This means: at low noise (t≈0), the gate is balanced → flexible processing.\n",
104
+ "At high noise (t≈1), the gate saturates → specialized denoising."
105
+ ]
106
+ },
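+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "A toy, self-contained version of the equations above (for intuition only — the repo's `ParallelCfCBlock` additionally uses a shared depthwise backbone and depthwise-conv g/h heads):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "class TinyLiquidBlock(nn.Module):\n",
+ "    def __init__(self, dim, t_dim):\n",
+ "        super().__init__()\n",
+ "        self.f = nn.Conv2d(dim, dim, 1)   # time-constant head\n",
+ "        self.g = nn.Conv2d(dim, dim, 1)   # 'from' state\n",
+ "        self.h = nn.Conv2d(dim, dim, 1)   # 'to' state (attractor)\n",
+ "        self.time_a = nn.Linear(t_dim, dim)\n",
+ "        self.time_b = nn.Linear(t_dim, dim)\n",
+ "        self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))\n",
+ "\n",
+ "    def forward(self, x, t_emb):\n",
+ "        a = self.time_a(t_emb)[:, :, None, None]\n",
+ "        b = self.time_b(t_emb)[:, :, None, None]\n",
+ "        gate = torch.sigmoid(a * self.f(x) - b)            # liquid time gate\n",
+ "        cfc = gate * self.g(x) + (1 - gate) * self.h(x)    # CfC interpolation\n",
+ "        t = t_emb.mean(dim=1, keepdim=True)[:, :, None, None]\n",
+ "        alpha = torch.exp(-F.softplus(self.rho) * t.abs())  # liquid relaxation\n",
+ "        return alpha * x + (1 - alpha) * cfc\n",
+ "```"
+ ]
+ },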
107
+ {
108
+ "cell_type": "markdown",
109
+ "metadata": {},
110
+ "source": [
111
+ "## 🧪 Quick Test (verify model works)"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "# Run the test suite\n",
121
+ "!python test_model.py"
122
+ ]
123
+ },
124
+ {
125
+ "cell_type": "markdown",
126
+ "metadata": {},
127
+ "source": [
128
+ "## ⚙️ Training Configuration\n",
129
+ "\n",
130
+ "Choose your config based on GPU and target resolution:\n",
131
+ "\n",
132
+ "| Config | Params | Resolution | Batch Size | VRAM | Training Time |\n",
133
+ "|--------|--------|-----------|------------|------|---------------|\n",
134
+ "| tiny | ~8M | 256×256 | 8 | ~6GB | ~3h (100K steps) |\n",
135
+ "| small | ~25M | 256×256 | 4 | ~10GB | ~6h (100K steps) |\n",
136
+ "| base | ~65M | 512×512 | 2 | ~14GB | ~12h (100K steps) |\n",
137
+ "\n",
138
+ "Recommended datasets:\n",
139
+ "- `huggan/CelebA-HQ` — 30K high-quality face images (256px)\n",
140
+ "- `huggan/flowers-102-categories` — flowers (various)\n",
141
+ "- `lambdalabs/naruto-blip-captions` — anime style (~1K)\n",
142
+ "- `Norod78/simpsons-blip-captions` — cartoon style\n",
143
+ "- Any folder of images"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "#@title Training Configuration {display-mode: \"form\"}\n",
153
+ "\n",
154
+ "#@markdown ### Model\n",
155
+ "model_size = \"tiny\" #@param [\"tiny\", \"small\", \"base\"]\n",
156
+ "\n",
157
+ "#@markdown ### Data\n",
158
+ "dataset_name = \"huggan/CelebA-HQ\" #@param {type:\"string\"}\n",
159
+ "image_column = \"image\" #@param {type:\"string\"}\n",
160
+ "image_size = 256 #@param [64, 128, 256, 512] {type:\"integer\"}\n",
161
+ "max_samples = 0 #@param {type:\"integer\"}\n",
162
+ "\n",
163
+ "#@markdown ### Training\n",
164
+ "batch_size = 8 #@param {type:\"integer\"}\n",
165
+ "learning_rate = 1e-4 #@param {type:\"number\"}\n",
166
+ "weight_decay = 0.01 #@param {type:\"number\"}\n",
167
+ "total_steps = 100000 #@param {type:\"integer\"}\n",
168
+ "warmup_steps = 1000 #@param {type:\"integer\"}\n",
169
+ "grad_clip = 1.0 #@param {type:\"number\"}\n",
170
+ "ema_decay = 0.9999 #@param {type:\"number\"}\n",
171
+ "time_sampling = \"logit_normal\" #@param [\"uniform\", \"logit_normal\"]\n",
172
+ "\n",
173
+ "#@markdown ### Sampling & Logging\n",
174
+ "sample_every = 2000 #@param {type:\"integer\"}\n",
175
+ "save_every = 5000 #@param {type:\"integer\"}\n",
176
+ "num_sample_steps = 50 #@param {type:\"integer\"}\n",
177
+ "num_sample_images = 4 #@param {type:\"integer\"}\n",
178
+ "\n",
179
+ "#@markdown ### Hardware\n",
180
+ "use_amp = True #@param {type:\"boolean\"}\n",
181
+ "amp_dtype = \"float16\" #@param [\"float16\", \"bfloat16\"]\n",
182
+ "num_workers = 2 #@param {type:\"integer\"}\n",
183
+ "\n",
184
+ "# Auto-adjust batch size for resolution\n",
185
+ "if image_size >= 512 and batch_size > 4:\n",
186
+ " batch_size = min(batch_size, 2)\n",
187
+ " print(f\"Auto-reduced batch_size to {batch_size} for {image_size}px\")\n",
188
+ "\n",
189
+ "if max_samples == 0:\n",
190
+ " max_samples = None\n",
191
+ "\n",
192
+ "print(f\"\\nConfig: {model_size} model, {image_size}px, batch={batch_size}, lr={learning_rate}\")\n",
193
+ "print(f\"Dataset: {dataset_name}, time_sampling={time_sampling}\")\n",
194
+ "print(f\"Total steps: {total_steps:,}, AMP: {use_amp} ({amp_dtype})\")"
195
+ ]
196
+ },
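+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "What the two `time_sampling` options mean — this mirrors the inline `sample_time` helper from the previous version of this notebook:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def sample_time(bs, mode='logit_normal', eps=1e-5, device='cpu'):\n",
+ "    if mode == 'uniform':\n",
+ "        return torch.rand(bs, device=device) * (1 - 2 * eps) + eps\n",
+ "    # logit-normal (from SD3): sigmoid of a standard normal,\n",
+ "    # which concentrates probability mass on intermediate timesteps\n",
+ "    return torch.sigmoid(torch.randn(bs, device=device)).clamp(eps, 1 - eps)\n",
+ "```"
+ ]
+ },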
197
+ {
198
+ "cell_type": "markdown",
199
+ "metadata": {},
200
+ "source": [
201
+ "## 📦 Load Dataset"
202
+ ]
203
+ },
204
+ {
205
+ "cell_type": "code",
206
+ "execution_count": null,
207
+ "metadata": {},
208
+ "outputs": [],
209
+ "source": [
210
+ "from datasets import load_dataset\n",
211
+ "from liquid_diffusion.trainer import ImageDataset\n",
212
+ "from torch.utils.data import DataLoader\n",
213
+ "import matplotlib.pyplot as plt\n",
214
+ "import numpy as np\n",
215
+ "\n",
216
+ "# Load dataset\n",
217
+ "print(f\"Loading {dataset_name}...\")\n",
218
+ "dataset = ImageDataset(\n",
219
+ " source=dataset_name,\n",
220
+ " image_size=image_size,\n",
221
+ " image_column=image_column,\n",
222
+ " max_samples=max_samples,\n",
223
+ ")\n",
224
+ "print(f\"Dataset size: {len(dataset)} images\")\n",
225
+ "\n",
226
+ "dataloader = DataLoader(\n",
227
+ " dataset, batch_size=batch_size, shuffle=True,\n",
228
+ " num_workers=num_workers, pin_memory=True, drop_last=True,\n",
229
+ ")\n",
230
+ "\n",
231
+ "# Show some samples\n",
232
+ "sample_batch = next(iter(dataloader))\n",
233
+ "fig, axes = plt.subplots(1, min(4, batch_size), figsize=(16, 4))\n",
234
+ "for i, ax in enumerate(axes):\n",
235
+ " img = sample_batch[i].permute(1, 2, 0).numpy() * 0.5 + 0.5 # [-1,1] -> [0,1]\n",
236
+ " ax.imshow(np.clip(img, 0, 1))\n",
237
+ " ax.axis('off')\n",
238
+ "plt.suptitle(f'Training samples ({image_size}×{image_size})')\n",
239
+ "plt.tight_layout()\n",
240
+ "plt.show()"
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "markdown",
245
+ "metadata": {},
246
+ "source": [
247
+ "## 🏗️ Build Model"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "execution_count": null,
253
+ "metadata": {},
254
+ "outputs": [],
255
+ "source": [
256
+ "from liquid_diffusion.model import (\n",
257
+ " liquid_diffusion_tiny, liquid_diffusion_small, liquid_diffusion_base\n",
258
+ ")\n",
259
+ "\n",
260
+ "# Build model\n",
261
+ "model_factories = {\n",
262
+ " 'tiny': liquid_diffusion_tiny,\n",
263
+ " 'small': liquid_diffusion_small,\n",
264
+ " 'base': liquid_diffusion_base,\n",
265
+ "}\n",
266
+ "\n",
267
+ "model = model_factories[model_size]()\n",
268
+ "total_params, trainable_params = model.count_params()\n",
269
+ "print(f\"Model: liquid_diffusion_{model_size}\")\n",
270
+ "print(f\"Parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n",
271
+ "print(f\"Trainable: {trainable_params:,}\")\n",
272
+ "\n",
273
+ "# Quick forward pass test\n",
274
+ "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
275
+ "model = model.to(device)\n",
276
+ "test_x = torch.randn(1, 3, image_size, image_size, device=device)\n",
277
+ "test_t = torch.tensor([0.5], device=device)\n",
278
+ "with torch.no_grad():\n",
279
+ " test_out = model(test_x, test_t)\n",
280
+ "print(f\"Forward pass OK: {test_x.shape} → {test_out.shape}\")\n",
281
+ "del test_x, test_out\n",
282
+ "if device == 'cuda':\n",
283
+ " torch.cuda.empty_cache()"
284
+ ]
285
+ },
286
+ {
287
+ "cell_type": "markdown",
288
+ "metadata": {},
289
+ "source": [
290
+ "## 🚀 Train!"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "metadata": {},
297
+ "outputs": [],
298
+ "source": [
299
+ "import os\n",
300
+ "import time\n",
301
+ "import math\n",
302
+ "from tqdm.auto import tqdm\n",
303
+ "from torchvision.utils import save_image, make_grid\n",
304
+ "from liquid_diffusion.trainer import RectifiedFlowTrainer, get_cosine_schedule_with_warmup\n",
305
+ "\n",
306
+ "# Create output directories\n",
307
+ "os.makedirs('checkpoints', exist_ok=True)\n",
308
+ "os.makedirs('samples', exist_ok=True)\n",
309
+ "\n",
310
+ "# Build trainer\n",
311
+ "trainer = RectifiedFlowTrainer(\n",
312
+ " model=model,\n",
313
+ " lr=learning_rate,\n",
314
+ " weight_decay=weight_decay,\n",
315
+ " ema_decay=ema_decay,\n",
316
+ " grad_clip=grad_clip,\n",
317
+ " time_sampling=time_sampling,\n",
318
+ " device=device,\n",
319
+ " use_amp=use_amp,\n",
320
+ " amp_dtype=amp_dtype,\n",
321
+ ")\n",
322
+ "\n",
323
+ "# Learning rate scheduler\n",
324
+ "scheduler = get_cosine_schedule_with_warmup(\n",
325
+ " trainer.optimizer, warmup_steps, total_steps\n",
326
+ ")\n",
327
+ "\n",
328
+ "# Optional: resume from checkpoint\n",
329
+ "resume_path = 'checkpoints/latest.pt'\n",
330
+ "if os.path.exists(resume_path):\n",
331
+ " trainer.load_checkpoint(resume_path)\n",
332
+ " print(f\"Resumed from step {trainer.step}\")\n",
333
+ "\n",
334
+ "print(f\"\\n{'='*60}\")\n",
335
+ "print(f\"Starting training: {total_steps:,} steps\")\n",
336
+ "print(f\"Model: liquid_diffusion_{model_size} ({total_params/1e6:.1f}M params)\")\n",
337
+ "print(f\"Resolution: {image_size}×{image_size}, Batch: {batch_size}\")\n",
338
+ "print(f\"LR: {learning_rate}, Warmup: {warmup_steps}, AMP: {use_amp}\")\n",
339
+ "print(f\"{'='*60}\\n\")\n",
340
+ "\n",
341
+ "# Training loop\n",
342
+ "start_time = time.time()\n",
343
+ "data_iter = iter(dataloader)\n",
344
+ "pbar = tqdm(range(trainer.step, total_steps), desc='Training', dynamic_ncols=True)\n",
345
+ "loss_history = []\n",
346
+ "\n",
347
+ "for step in pbar:\n",
348
+ " # Get batch (cycle through dataset)\n",
349
+ " try:\n",
350
+ " batch = next(data_iter)\n",
351
+ " except StopIteration:\n",
352
+ " data_iter = iter(dataloader)\n",
353
+ " batch = next(data_iter)\n",
354
+ " \n",
355
+ " x0 = batch.to(device)\n",
356
+ " \n",
357
+ " # Train step\n",
358
+ " metrics = trainer.train_step(x0)\n",
359
+ " scheduler.step()\n",
360
+ " \n",
361
+ " # Logging\n",
362
+ " loss_history.append(metrics['loss'])\n",
363
+ " avg_loss = sum(loss_history[-100:]) / len(loss_history[-100:])\n",
364
+ " lr_current = scheduler.get_last_lr()[0]\n",
365
+ " \n",
366
+ " pbar.set_postfix({\n",
367
+ " 'loss': f\"{metrics['loss']:.4f}\",\n",
368
+ " 'avg': f\"{avg_loss:.4f}\",\n",
369
+ " 'lr': f\"{lr_current:.6f}\",\n",
370
+ " 'gn': f\"{metrics['grad_norm']:.2f}\",\n",
371
+ " })\n",
372
+ " \n",
373
+ " # Generate samples\n",
374
+ " if (step + 1) % sample_every == 0 or step == 0:\n",
375
+ " print(f\"\\nGenerating samples at step {step+1}...\")\n",
376
+ " samples = trainer.sample(\n",
377
+ " batch_size=num_sample_images, image_size=image_size,\n",
378
+ " num_steps=num_sample_steps, use_ema=True\n",
379
+ " )\n",
380
+ " # Save grid\n",
381
+ " grid = make_grid(samples * 0.5 + 0.5, nrow=int(math.sqrt(num_sample_images)), padding=2)\n",
382
+ " save_image(grid, f'samples/step_{step+1:06d}.png')\n",
383
+ " \n",
384
+ " # Display\n",
385
+ " fig, axes = plt.subplots(1, num_sample_images, figsize=(4*num_sample_images, 4))\n",
386
+ " if num_sample_images == 1:\n",
387
+ " axes = [axes]\n",
388
+ " for i, ax in enumerate(axes):\n",
389
+ " img = samples[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
390
+ " ax.imshow(np.clip(img, 0, 1))\n",
391
+ " ax.axis('off')\n",
392
+ " plt.suptitle(f'Step {step+1} (EMA samples, {num_sample_steps} Euler steps)')\n",
393
+ " plt.tight_layout()\n",
394
+ " plt.show()\n",
395
+ " \n",
396
+ " # Save checkpoint\n",
397
+ " if (step + 1) % save_every == 0:\n",
398
+ " trainer.save_checkpoint(f'checkpoints/step_{step+1:06d}.pt', extra={'config': {\n",
399
+ " 'model_size': model_size, 'image_size': image_size,\n",
400
+ " 'batch_size': batch_size, 'learning_rate': learning_rate,\n",
401
+ " }})\n",
402
+ " trainer.save_checkpoint('checkpoints/latest.pt')\n",
403
+ " print(f\"Saved checkpoint at step {step+1}\")\n",
404
+ " \n",
405
+ " # Safety: check for NaN\n",
406
+ " if math.isnan(metrics['loss']):\n",
407
+ " print(\"\\n⚠️ NaN loss detected! Stopping training.\")\n",
408
+ " print(\"Try: reduce learning_rate, increase grad_clip, or use smaller model\")\n",
409
+ " break\n",
410
+ "\n",
411
+ "elapsed = time.time() - start_time\n",
412
+ "print(f\"\\nTraining complete! {trainer.step:,} steps in {elapsed/3600:.1f}h\")\n",
413
+ "print(f\"Final avg loss: {sum(loss_history[-100:])/len(loss_history[-100:]):.4f}\")\n",
414
+ "\n",
415
+ "# Final save\n",
416
+ "trainer.save_checkpoint('checkpoints/final.pt')\n",
417
+ "print(\"Saved final checkpoint.\")"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "markdown",
422
+ "metadata": {},
423
+ "source": [
424
+ "## 📊 Training Loss Curve"
425
+ ]
426
+ },
427
+ {
428
+ "cell_type": "code",
429
+ "execution_count": null,
430
+ "metadata": {},
431
+ "outputs": [],
432
+ "source": [
433
+ "import matplotlib.pyplot as plt\n",
434
+ "import numpy as np\n",
435
+ "\n",
436
+ "if loss_history:\n",
437
+ " # Smooth the loss\n",
438
+ " window = min(100, len(loss_history) // 5 + 1)\n",
439
+ " smoothed = np.convolve(loss_history, np.ones(window)/window, mode='valid')\n",
440
+ " \n",
441
+ " fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
442
+ " \n",
443
+ " ax1.plot(loss_history, alpha=0.3, label='Raw')\n",
444
+ " ax1.plot(range(window-1, len(loss_history)), smoothed, label=f'Smoothed (w={window})')\n",
445
+ " ax1.set_xlabel('Step')\n",
446
+ " ax1.set_ylabel('Loss')\n",
447
+ " ax1.set_title('Training Loss')\n",
448
+ " ax1.legend()\n",
449
+ " ax1.grid(True, alpha=0.3)\n",
450
+ " \n",
451
+ " ax2.plot(loss_history[-min(1000, len(loss_history)):], alpha=0.5)\n",
452
+ " ax2.set_xlabel('Recent Steps')\n",
453
+ " ax2.set_ylabel('Loss')\n",
454
+ " ax2.set_title('Recent Loss (last 1000 steps)')\n",
455
+ " ax2.grid(True, alpha=0.3)\n",
456
+ " \n",
457
+ " plt.tight_layout()\n",
458
+ " plt.show()\n",
459
+ "else:\n",
460
+ " print(\"No training history yet.\")"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "markdown",
465
+ "metadata": {},
466
+ "source": [
467
+ "## 🎨 Generate Images"
468
+ ]
469
+ },
470
+ {
471
+ "cell_type": "code",
472
+ "execution_count": null,
473
+ "metadata": {},
474
+ "outputs": [],
475
+ "source": [
476
+ "#@title Generation Settings {display-mode: \"form\"}\n",
477
+ "num_images = 8 #@param {type:\"integer\"}\n",
478
+ "sampling_steps = 50 #@param [25, 50, 100, 200] {type:\"integer\"}\n",
479
+ "use_ema_model = True #@param {type:\"boolean\"}\n",
480
+ "\n",
481
+ "print(f\"Generating {num_images} images with {sampling_steps} Euler steps...\")\n",
482
+ "samples = trainer.sample(\n",
483
+ " batch_size=num_images, image_size=image_size,\n",
484
+ " num_steps=sampling_steps, use_ema=use_ema_model,\n",
485
+ ")\n",
486
+ "\n",
487
+ "# Display\n",
488
+ "ncols = min(4, num_images)\n",
489
+ "nrows = (num_images + ncols - 1) // ncols\n",
490
+ "fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))\n",
491
+ "if nrows == 1 and ncols == 1:\n",
492
+ " axes = [[axes]]\n",
493
+ "elif nrows == 1:\n",
494
+ " axes = [axes]\n",
495
+ "for i in range(num_images):\n",
496
+ " r, c = i // ncols, i % ncols\n",
497
+ " img = samples[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
498
+ " axes[r][c].imshow(np.clip(img, 0, 1))\n",
499
+ " axes[r][c].axis('off')\n",
500
+ "# Hide unused axes\n",
501
+ "for i in range(num_images, nrows * ncols):\n",
502
+ " r, c = i // ncols, i % ncols\n",
503
+ " axes[r][c].axis('off')\n",
504
+ "plt.suptitle(f'LiquidDiffusion Samples ({sampling_steps} steps, {\"EMA\" if use_ema_model else \"online\"})')\n",
505
+ "plt.tight_layout()\n",
506
+ "plt.show()\n",
507
+ "\n",
508
+ "# Save\n",
509
+ "grid = make_grid(samples * 0.5 + 0.5, nrow=ncols, padding=2)\n",
510
+ "save_image(grid, 'samples/generated.png')\n",
511
+ "print(\"Saved to samples/generated.png\")"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "markdown",
516
+ "metadata": {},
517
+ "source": [
518
+ "## 🔬 Visualize the Denoising Process"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "code",
523
+ "execution_count": null,
524
+ "metadata": {},
525
+ "outputs": [],
526
+ "source": [
527
+ "# Show step-by-step denoising\n",
528
+ "num_vis_steps = 10\n",
529
+ "total_euler_steps = 50\n",
530
+ "vis_interval = total_euler_steps // num_vis_steps\n",
531
+ "\n",
532
+ "model_vis = trainer.ema_model\n",
533
+ "model_vis.eval()\n",
534
+ "\n",
535
+ "z = torch.randn(1, 3, image_size, image_size, device=device)\n",
536
+ "dt = 1.0 / total_euler_steps\n",
537
+ "intermediates = [z.clone()]\n",
538
+ "\n",
539
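+ "# Euler-integrate dz/dt = v(z, t) from t=1 (pure noise) down to t=0 (image),\n",
+ "# snapshotting z every vis_interval steps for display\n",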
+ "with torch.no_grad():\n",
540
+ " for i in range(total_euler_steps, 0, -1):\n",
541
+ " t = torch.full((1,), i / total_euler_steps, device=device)\n",
542
+ " v = model_vis(z, t)\n",
543
+ " z = z - v * dt\n",
544
+ " if (total_euler_steps - i + 1) % vis_interval == 0:\n",
545
+ " intermediates.append(z.clone())\n",
546
+ "\n",
547
+ "intermediates.append(z.clamp(-1, 1))\n",
548
+ "\n",
549
+ "fig, axes = plt.subplots(1, len(intermediates), figsize=(3*len(intermediates), 3))\n",
550
+ "for idx, (ax, img_t) in enumerate(zip(axes, intermediates)):\n",
551
+ " img = img_t[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
552
+ " ax.imshow(np.clip(img, 0, 1))\n",
553
+ " ax.axis('off')\n",
554
+ " if idx == 0:\n",
555
+ " ax.set_title('Noise (t=1)')\n",
556
+ " elif idx == len(intermediates) - 1:\n",
557
+ " ax.set_title('Output (t=0)')\n",
558
+ " else:\n",
559
+ " ax.set_title(f't={1-idx*vis_interval/total_euler_steps:.1f}')\n",
560
+ "plt.suptitle('LiquidDiffusion Denoising Process')\n",
561
+ "plt.tight_layout()\n",
562
+ "plt.show()"
563
+ ]
564
+ },
565
+ {
566
+ "cell_type": "markdown",
567
+ "metadata": {},
568
+ "source": [
569
+ "## 💾 Save & Export Model"
570
+ ]
571
+ },
572
+ {
573
+ "cell_type": "code",
574
+ "execution_count": null,
575
+ "metadata": {},
576
+ "outputs": [],
577
+ "source": [
578
+ "# Save final checkpoint\n",
579
+ "trainer.save_checkpoint('checkpoints/final.pt', extra={\n",
580
+ " 'config': {\n",
581
+ " 'model_size': model_size,\n",
582
+ " 'image_size': image_size,\n",
583
+ " 'total_params': total_params,\n",
584
+ " 'training_steps': trainer.step,\n",
585
+ " 'dataset': dataset_name,\n",
586
+ " }\n",
587
+ "})\n",
588
+ "print(f\"Saved checkpoint: checkpoints/final.pt\")\n",
589
+ "print(f\"Model: liquid_diffusion_{model_size} ({total_params/1e6:.1f}M params)\")\n",
590
+ "print(f\"Trained for {trainer.step:,} steps on {dataset_name}\")"
591
+ ]
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "execution_count": null,
596
+ "metadata": {},
597
+ "outputs": [],
598
+ "source": [
599
+ "# Optional: Push to Hugging Face Hub\n",
600
+ "# Uncomment and fill in your details:\n",
601
+ "\n",
602
+ "# from huggingface_hub import HfApi, login\n",
603
+ "# login() # or use token\n",
604
+ "# api = HfApi()\n",
605
+ "# repo_id = \"your-username/liquid-diffusion-celebahq-256\" # change this\n",
606
+ "# api.create_repo(repo_id, exist_ok=True)\n",
607
+ "# api.upload_file('checkpoints/final.pt', 'model.pt', repo_id)\n",
608
+ "# api.upload_folder('liquid_diffusion/', 'liquid_diffusion/', repo_id)\n",
609
+ "# print(f\"Uploaded to https://huggingface.co/{repo_id}\")"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "markdown",
614
+ "metadata": {},
615
+ "source": [
616
+ "## 📚 Architecture Details & Theory\n",
617
+ "\n",
618
+ "### Why Liquid Neural Networks for Image Generation?\n",
619
+ "\n",
620
+ "**Liquid Time-Constant (LTC) Networks** (Hasani et al., 2020) define neurons with input-dependent time constants:\n",
621
+ "\n",
622
+ "```\n",
623
+ "dx/dt = -[1/τ + f(x,I,θ)] · x + f(x,I,θ) · A\n",
624
+ "```\n",
625
+ "\n",
626
+ "The system time constant `τ_sys = τ/(1 + τ·f)` adapts dynamically based on input — the neuron speeds up or slows down its response depending on what it sees. This is the \"liquid\" property.\n",
627
+ "\n",
628
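+ "For example, with `τ = 1`: an input driving `f = 0` leaves `τ_sys = 1` (slow, integrating), while `f = 9` gives `τ_sys = 1/10 = 0.1` (fast, reactive), so the effective dynamics are rescaled by the input itself.\n",
+ "\n",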
+ "**CfC (Closed-form Continuous-depth)** networks (Hasani et al., 2022) solve this ODE in closed form:\n",
629
+ "\n",
630
+ "```\n",
631
+ "x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h\n",
632
+ "```\n",
633
+ "\n",
634
+ "This eliminates the ODE solver — making CfC **fully parallelizable** while preserving the adaptive time constant behavior.\n",
635
+ "\n",
636
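+ "A minimal sketch of this update in PyTorch (an illustration, not the repo's exact module; `f`, `g`, `h` are the learned branch outputs from the equation above):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def cfc_update(f, g, h, t):\n",
+ "    # f, g, h: (B, C, H, W) branch activations; t: (B,) timestep in [0, 1]\n",
+ "    gate = torch.sigmoid(-f * t.view(-1, 1, 1, 1))  # input- and time-dependent gate\n",
+ "    return gate * g + (1 - gate) * h  # closed-form blend; no ODE solver needed\n",
+ "```\n",
+ "\n",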
+ "### Our Innovation: CfC × Diffusion Timestep\n",
637
+ "\n",
638
+ "In diffusion models, the network must process images at different noise levels `t ∈ [0,1]`. We observe that:\n",
639
+ "\n",
640
+ "1. CfC's time parameter `t` controls interpolation between two learned states\n",
641
+ "2. Diffusion's noise level `t` controls how the denoiser should behave\n",
642
+ "3. **These are the same concept** — the CfC time parameter IS the diffusion timestep\n",
643
+ "\n",
644
+ "This gives us:\n",
645
+ "- At `t≈0` (clean images): σ(-f·t)≈0.5, balanced processing for detail refinement\n",
646
+ "- At `t≈1` (noisy images): σ(-f·t) saturates, specialized denoising\n",
647
+ "- The gate `f` is **input-dependent** — different image content gets different time responses\n",
648
+ "\n",
649
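+ "A quick numeric check of the gate at the two extremes (plain Python; `f = 2.0` is an arbitrary positive activation):\n",
+ "\n",
+ "```python\n",
+ "import math\n",
+ "\n",
+ "sigmoid = lambda x: 1.0 / (1.0 + math.exp(-x))\n",
+ "f = 2.0  # arbitrary positive gate pre-activation\n",
+ "print(sigmoid(-f * 0.0))  # t≈0 -> 0.5, balanced mix of g and h\n",
+ "print(sigmoid(-f * 1.0))  # t=1 -> ~0.12, weight shifted toward h\n",
+ "```\n",
+ "\n",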
+ "### References\n",
650
+ "\n",
651
+ "1. Hasani et al., \"Liquid Time-constant Networks\" (AAAI 2021) — arxiv:2006.04439\n",
652
+ "2. Hasani et al., \"Closed-form Continuous-time Neural Networks\" (Nature MI 2022) — arxiv:2106.13898\n",
653
+ "3. LiquidTAD: Parallel liquid relaxation — arxiv:2604.18274\n",
654
+ "4. USM: U-Shape Mamba for diffusion — arxiv:2504.13499\n",
655
+ "5. DiffuSSM: Diffusion without attention — arxiv:2311.18257\n",
656
+ "6. Liu et al., \"Flow Straight and Fast: Rectified Flow\" (ICLR 2023) — arxiv:2209.03003"
657
+ ]
658
+ }
659
+ ],
660
+ "metadata": {
661
+ "accelerator": "GPU",
662
+ "colab": {
663
+ "gpuType": "T4",
664
+ "provenance": [],
665
+ "toc_visible": true
666
+ },
667
+ "kernelspec": {
668
+ "display_name": "Python 3",
669
+ "name": "python3"
670
+ },
671
+ "language_info": {
672
+ "name": "python",
673
+ "version": "3.10.0"
674
+ }
675
  },
676
+ "nbformat": 4,
677
+ "nbformat_minor": 0
678
  }