krystv committed
Commit 9d75008 (verified)
Parent(s): 658087c

Upload LiquidFlow_Colab.ipynb

Files changed (1)
  1. LiquidFlow_Colab.ipynb +428 -0
LiquidFlow_Colab.ipynb ADDED
@@ -0,0 +1,428 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# LiquidFlow: Liquid Neural Network + Mamba-2 SSD Image Generator\n",
+ "\n",
+ "**Train on Google Colab Free Tier (T4 GPU) | Export for Mobile Deployment**\n",
+ "\n",
+ "LiquidFlow combines:\n",
+ "- **CfC (Closed-form Continuous-time)** Liquid Neural Networks — adaptive time gates\n",
+ "- **Mamba-2 SSD** — linear-time attention replacement, fully parallelizable\n",
+ "- **Physics-Informed Regularization** — TV loss, spectral constraints\n",
+ "- **TAESD VAE** — Tiny AutoEncoder (~1M params) for fast encoding\n",
+ "\n",
+ "Based on:\n",
+ "- CfC: Hasani et al., Nature MI 2022\n",
+ "- Mamba-2: Dao & Gu, ICML 2024\n",
+ "- PINN Diffusion: Bastek & Sun, ICLR 2025\n",
+ "- DiMSUM: NeurIPS 2024\n",
+ "\n",
+ "---\n",
+ "## Quick Start\n",
+ "1. Runtime → Change runtime type → GPU (T4)\n",
+ "2. Run all cells in order\n",
+ "3. Training starts automatically on CIFAR-10\n",
+ "4. Check samples in `./outputs/samples/`"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 1. Install Dependencies (~2 min)\n",
+ "!pip install -q torch torchvision tqdm pillow numpy\n",
+ "!pip install -q git+https://github.com/huggingface/diffusers.git  # diffusers from source\n",
+ "\n",
+ "import torch\n",
+ "print(f\"PyTorch: {torch.__version__}\")\n",
+ "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
+ "if torch.cuda.is_available():\n",
+ "    print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
+ "    print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
+ ],
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 2. Clone LiquidFlow Repository\n",
+ "!git clone https://huggingface.co/krystv/LiquidFlow-Gen /content/LiquidFlow\n",
+ "%cd /content/LiquidFlow\n",
+ "\n",
+ "import sys\n",
+ "sys.path.insert(0, '/content/LiquidFlow')\n",
+ "\n",
+ "from liquid_flow.generator import create_liquidflow\n",
+ "from liquid_flow.vae_wrapper import TAESDWrapper\n",
+ "import torch.nn as nn\n",
+ "import torch.nn.functional as F"
+ ],
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 3. Configuration — Adjust these settings!\n",
+ "\n",
+ "# Model size: 'tiny' (~2M), 'small' (~8M), 'base' (~30M)\n",
+ "MODEL_VARIANT = 'small' # @param ['tiny', 'small', 'base']\n",
+ "\n",
+ "# Image size: 128 recommended for T4; 512 needs more VRAM\n",
+ "IMAGE_SIZE = 128 # @param [64, 128, 256, 512]\n",
+ "\n",
+ "# Training\n",
+ "BATCH_SIZE = 32 # @param [8, 16, 32, 64]\n",
+ "EPOCHS = 50 # @param [10, 25, 50, 100]\n",
+ "LEARNING_RATE = 2e-4 # @param [1e-4, 2e-4, 5e-4, 1e-3]\n",
+ "\n",
+ "# Dataset\n",
+ "DATASET = 'cifar10' # @param ['cifar10', 'cifar100', 'stl10']\n",
+ "\n",
+ "# Sampling: cadence (in epochs) and DDIM steps\n",
+ "SAMPLE_EVERY = 5 # @param [1, 5, 10]\n",
+ "SAMPLE_STEPS = 50 # @param [20, 50, 100]\n",
+ "\n",
+ "# Physics regularization weights\n",
+ "PHYSICS_TV_WEIGHT = 0.01\n",
+ "PHYSICS_SPEC_WEIGHT = 0.01\n",
+ "PHYSICS_GRAD_WEIGHT = 0.001\n",
+ "\n",
+ "print(f\"Config: {MODEL_VARIANT} model, {IMAGE_SIZE}px, batch={BATCH_SIZE}, epochs={EPOCHS}, lr={LEARNING_RATE}\")\n",
+ "print(f\"Physics loss: TV={PHYSICS_TV_WEIGHT}, Spec={PHYSICS_SPEC_WEIGHT}, Grad={PHYSICS_GRAD_WEIGHT}\")"
+ ],
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 4. Load VAE & Create Model\n",
+ "import torch\n",
+ "from torchvision import datasets, transforms\n",
+ "from torch.utils.data import DataLoader\n",
+ "import os\n",
+ "from tqdm import tqdm\n",
+ "\n",
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
+ "\n",
+ "# Load TAESD (Tiny AutoEncoder)\n",
+ "print(\"Loading TAESD VAE...\")\n",
+ "vae = TAESDWrapper.load(device)\n",
+ "print(f\"VAE loaded! Latent compression: {IMAGE_SIZE}x{IMAGE_SIZE} → {IMAGE_SIZE//8}x{IMAGE_SIZE//8}\")\n",
+ "\n",
+ "# Create LiquidFlow model\n",
+ "print(f\"Creating {MODEL_VARIANT} LiquidFlow model...\")\n",
+ "model = create_liquidflow(\n",
+ "    variant=MODEL_VARIANT,\n",
+ "    image_size=IMAGE_SIZE,\n",
+ "    physics_weights={\n",
+ "        'tv': PHYSICS_TV_WEIGHT,\n",
+ "        'cons': 0.001,\n",
+ "        'spec': PHYSICS_SPEC_WEIGHT,\n",
+ "        'grad': PHYSICS_GRAD_WEIGHT,\n",
+ "    },\n",
+ ")\n",
+ "model = model.to(device)\n",
+ "\n",
+ "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
+ "print(f\"Model: {n_params:,} parameters ({n_params/1e6:.1f}M)\")\n",
+ "\n",
+ "# Memory estimate (latents only; activations and weights dominate real usage)\n",
+ "latent_size = IMAGE_SIZE // 8\n",
+ "mem_per_sample = latent_size * latent_size * 4 * 4 / 1e6  # 4 channels x 4 bytes (fp32), in MB\n",
+ "print(f\"Latent memory per sample: {mem_per_sample:.2f} MB\")\n",
+ "print(f\"Estimated batch latent memory: {mem_per_sample * BATCH_SIZE:.2f} MB\")\n",
+ "print(\"T4 VRAM (~15 GB): latents fit easily\" if mem_per_sample * BATCH_SIZE < 10 else \"Watch memory!\")"
+ ],
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 5. Load Dataset\n",
+ "\n",
+ "transform = transforms.Compose([\n",
+ "    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n",
+ "    transforms.ToTensor(),\n",
+ "    transforms.Normalize([0.5], [0.5]),\n",
+ "])\n",
+ "\n",
+ "if DATASET == 'cifar10':\n",
+ "    dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n",
+ "elif DATASET == 'cifar100':\n",
+ "    dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)\n",
+ "elif DATASET == 'stl10':\n",
+ "    dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)\n",
+ "else:\n",
+ "    raise ValueError(f\"Unknown dataset: {DATASET}\")\n",
+ "\n",
+ "dataloader = DataLoader(\n",
+ "    dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
+ "    num_workers=min(4, os.cpu_count() or 1),\n",
+ "    pin_memory=True, drop_last=True,\n",
+ ")\n",
+ "\n",
+ "print(f\"Dataset: {DATASET}\")\n",
+ "print(f\"Images: {len(dataset):,}, Batches per epoch: {len(dataloader)}\")"
+ ],
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 6. Training Loop\n",
+ "\n",
+ "from torchvision.utils import save_image\n",
+ "import math\n",
+ "\n",
+ "os.makedirs('./outputs/samples', exist_ok=True)\n",
+ "os.makedirs('./outputs/checkpoints', exist_ok=True)\n",
+ "\n",
+ "# Optimizer\n",
+ "optimizer = torch.optim.AdamW(\n",
+ "    model.parameters(),\n",
+ "    lr=LEARNING_RATE,\n",
+ "    betas=(0.9, 0.999),\n",
+ "    weight_decay=1e-4,\n",
+ ")\n",
+ "\n",
+ "# Cosine LR scheduler (stepped once per batch)\n",
+ "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
+ "    optimizer, T_max=EPOCHS * len(dataloader)\n",
+ ")\n",
+ "\n",
+ "# AMP\n",
+ "use_amp = device.type == 'cuda'\n",
+ "scaler = torch.cuda.amp.GradScaler() if use_amp else None\n",
+ "\n",
+ "print(f\"Training: {EPOCHS} epochs, LR={LEARNING_RATE}, AMP={use_amp}\")\n",
+ "print(\"=\"*60)\n",
+ "\n",
+ "global_step = 0\n",
+ "best_loss = float('inf')\n",
+ "\n",
+ "for epoch in range(EPOCHS):\n",
+ "    model.train()\n",
+ "    epoch_total = 0\n",
+ "\n",
+ "    pbar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{EPOCHS}\")\n",
+ "\n",
+ "    for images, _ in pbar:\n",
+ "        images = images.to(device)\n",
+ "\n",
+ "        # Encode to latent\n",
+ "        with torch.no_grad():\n",
+ "            latents = TAESDWrapper.encode(vae, images)\n",
+ "\n",
+ "        # Training step with physics regularization\n",
+ "        loss_dict = model.training_step(latents, optimizer, scaler, use_amp)\n",
+ "\n",
+ "        # Track\n",
+ "        total_loss = loss_dict['total']\n",
+ "        epoch_total += total_loss\n",
+ "\n",
+ "        # Update scheduler\n",
+ "        scheduler.step()\n",
+ "\n",
+ "        # Progress bar\n",
+ "        pbar.set_postfix({\n",
+ "            'loss': f\"{total_loss:.4f}\",\n",
+ "            'diff': f\"{loss_dict.get('diffusion', 0):.4f}\",\n",
+ "            'phys': f\"{loss_dict.get('physics', 0):.4f}\",\n",
+ "            'lr': f\"{optimizer.param_groups[0]['lr']:.2e}\",\n",
+ "        })\n",
+ "\n",
+ "        global_step += 1\n",
+ "\n",
+ "    avg_loss = epoch_total / len(dataloader)\n",
+ "    print(f\"Epoch {epoch+1}: avg_loss={avg_loss:.4f}\")\n",
+ "\n",
+ "    # Generate samples\n",
+ "    if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == EPOCHS - 1:\n",
+ "        print(\"  Generating samples...\")\n",
+ "        model.eval()\n",
+ "        with torch.no_grad():\n",
+ "            latents_gen = model.sample(\n",
+ "                batch_size=16,\n",
+ "                steps=SAMPLE_STEPS,\n",
+ "                ddim=True,\n",
+ "                progress=False,\n",
+ "            )\n",
+ "            images_gen = TAESDWrapper.decode(vae, latents_gen)\n",
+ "\n",
+ "        save_image(\n",
+ "            images_gen, f'./outputs/samples/epoch_{epoch+1:03d}.png',\n",
+ "            nrow=4, normalize=True, value_range=(-1, 1)\n",
+ "        )\n",
+ "        print(f\"  Saved to ./outputs/samples/epoch_{epoch+1:03d}.png\")\n",
+ "\n",
+ "    # Save checkpoint\n",
+ "    if (epoch + 1) % 10 == 0 or epoch == EPOCHS - 1:\n",
+ "        torch.save({\n",
+ "            'epoch': epoch + 1,\n",
+ "            'model_state_dict': model.state_dict(),\n",
+ "            'optimizer_state_dict': optimizer.state_dict(),\n",
+ "            'loss': avg_loss,\n",
+ "        }, f'./outputs/checkpoints/epoch_{epoch+1:03d}.pt')\n",
+ "\n",
+ "    if avg_loss < best_loss:\n",
+ "        best_loss = avg_loss\n",
+ "        torch.save(model.state_dict(), './outputs/checkpoints/best_model.pt')\n",
+ "\n",
+ "print(\"=\"*60)\n",
+ "print(f\"Training complete! Best loss: {best_loss:.4f}\")\n",
+ "print(\"Checkpoints saved to ./outputs/checkpoints/\")\n",
+ "print(\"Samples saved to ./outputs/samples/\")"
+ ],
+ "outputs": []
+ },
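+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### What `training_step` does (sketch)\n",
+ "\n",
+ "The loop above delegates the objective to `model.training_step`, which lives in the repo's `liquid_flow` package. A minimal sketch of the pattern it implements, assuming a standard DDPM noise-prediction loss plus a weighted total-variation term; `training_step_sketch`, `alphas_cumprod`, and `total_variation` are illustrative names, not the repo's API:\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "T = 1000\n",
+ "betas = torch.linspace(1e-4, 0.02, T)\n",
+ "alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)\n",
+ "\n",
+ "def total_variation(x):\n",
+ "    # anisotropic L1 TV over the spatial dims of (B, C, H, W)\n",
+ "    dx = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()\n",
+ "    dy = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()\n",
+ "    return dx + dy\n",
+ "\n",
+ "def training_step_sketch(model, latents, optimizer, scaler, use_amp, tv_weight=0.01):\n",
+ "    # DDPM forward process: noise the latents at a random timestep\n",
+ "    t = torch.randint(0, T, (latents.size(0),), device=latents.device)\n",
+ "    a = alphas_cumprod.to(latents.device)[t].view(-1, 1, 1, 1)\n",
+ "    noise = torch.randn_like(latents)\n",
+ "    noisy = a.sqrt() * latents + (1 - a).sqrt() * noise\n",
+ "\n",
+ "    with torch.autocast('cuda', enabled=use_amp):\n",
+ "        pred = model(noisy, t)                         # predict the added noise\n",
+ "        diff_loss = F.mse_loss(pred, noise)\n",
+ "        phys_loss = tv_weight * total_variation(pred)  # training-only regularizer\n",
+ "        loss = diff_loss + phys_loss\n",
+ "\n",
+ "    optimizer.zero_grad(set_to_none=True)\n",
+ "    if scaler is not None:\n",
+ "        scaler.scale(loss).backward()\n",
+ "        scaler.step(optimizer)\n",
+ "        scaler.update()\n",
+ "    else:\n",
+ "        loss.backward()\n",
+ "        optimizer.step()\n",
+ "    return {'total': loss.item(), 'diffusion': diff_loss.item(), 'physics': phys_loss.item()}\n",
+ "```"
+ ]
+ },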
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 7. Generate & Display Samples\n",
+ "\n",
+ "import matplotlib.pyplot as plt\n",
+ "from PIL import Image\n",
+ "import glob\n",
+ "\n",
+ "# Load the most recent sample grid\n",
+ "sample_files = sorted(glob.glob('./outputs/samples/epoch_*.png'))\n",
+ "if sample_files:\n",
+ "    latest = sample_files[-1]\n",
+ "    img = Image.open(latest)\n",
+ "    plt.figure(figsize=(12, 12))\n",
+ "    plt.imshow(img)\n",
+ "    plt.title(f'LiquidFlow Samples — {MODEL_VARIANT} model, {IMAGE_SIZE}px')\n",
+ "    plt.axis('off')\n",
+ "    plt.show()\n",
+ "else:\n",
+ "    print(\"No samples generated yet. Train for more epochs!\")"
+ ],
+ "outputs": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "source": [
+ "# @title 8. Export Model for Mobile (ONNX)\n",
+ "\n",
+ "# LiquidFlow can be exported to ONNX for mobile deployment\n",
+ "# since it uses pure PyTorch (no custom CUDA kernels)\n",
+ "\n",
+ "import os\n",
+ "\n",
+ "def export_to_onnx(model, output_path='liquidflow_model.onnx', image_size=128):\n",
+ "    \"\"\"Export LiquidFlow to ONNX for mobile deployment.\"\"\"\n",
+ "    model = model.cpu()\n",
+ "    model.eval()\n",
+ "\n",
+ "    latent_size = image_size // 8\n",
+ "\n",
+ "    # Dummy inputs\n",
+ "    x = torch.randn(1, 4, latent_size, latent_size)\n",
+ "    t = torch.tensor([500], dtype=torch.long)\n",
+ "\n",
+ "    # Export\n",
+ "    torch.onnx.export(\n",
+ "        model,\n",
+ "        (x, t),\n",
+ "        output_path,\n",
+ "        input_names=['noisy_latent', 'timestep'],\n",
+ "        output_names=['predicted_noise'],\n",
+ "        dynamic_axes={\n",
+ "            'noisy_latent': {0: 'batch'},\n",
+ "            'predicted_noise': {0: 'batch'},\n",
+ "        },\n",
+ "        opset_version=14,\n",
+ "    )\n",
+ "\n",
+ "    size_mb = os.path.getsize(output_path) / 1e6\n",
+ "    print(f\"ONNX model exported to {output_path} ({size_mb:.1f} MB)\")\n",
+ "    return output_path\n",
+ "\n",
+ "# Load best model and export\n",
+ "best_model_path = './outputs/checkpoints/best_model.pt'\n",
+ "if os.path.exists(best_model_path):\n",
+ "    model.load_state_dict(torch.load(best_model_path, map_location='cpu'))\n",
+ "    export_to_onnx(model, 'liquidflow_128.onnx', IMAGE_SIZE)\n",
+ "    print(\"Ready for mobile deployment!\")\n",
+ "else:\n",
+ "    print(\"Train model first before exporting.\")"
+ ],
+ "outputs": []
+ },
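+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Sanity-checking the ONNX export\n",
+ "\n",
+ "A quick way to verify the exported graph runs outside PyTorch, assuming `onnxruntime` is installed (`!pip install -q onnxruntime`); the input names match those passed to `torch.onnx.export` above:\n",
+ "\n",
+ "```python\n",
+ "import numpy as np\n",
+ "import onnxruntime as ort\n",
+ "\n",
+ "sess = ort.InferenceSession('liquidflow_128.onnx')\n",
+ "x = np.random.randn(1, 4, 16, 16).astype(np.float32)  # noisy latent: 128px -> 16x16\n",
+ "t = np.array([500], dtype=np.int64)                    # diffusion timestep\n",
+ "(pred,) = sess.run(None, {'noisy_latent': x, 'timestep': t})\n",
+ "print(pred.shape)  # expect (1, 4, 16, 16)\n",
+ "```"
+ ]
+ },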
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Architecture Details\n",
+ "\n",
+ "### LiquidFlow Block Architecture\n",
+ "```\n",
+ "Input → [CfC Gate → Mamba-2 SSD → CfC Gate] → Output\n",
+ "             ↑                        ↑\n",
+ "     Adaptive time gate          Gated output\n",
+ "```\n",
+ "\n",
+ "### CfC (Closed-form Continuous-time) Cell\n",
+ "```\n",
+ "h(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1-σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)\n",
+ "```\n",
+ "- **No ODE solver needed** — 100x faster than Neural ODEs\n",
+ "- Time-continuous gating adaptively controls information flow\n",
+ "- Closed-form solution → stable gradients (see the sketch below)\n",
+ "\n",
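+ "A minimal PyTorch sketch of this gate (illustrative only; the real layer lives in the repo's `liquid_flow` package):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "import torch.nn as nn\n",
+ "\n",
+ "class CfCGate(nn.Module):\n",
+ "    # h(t) = sigma(-f(x)*t) * g(x) + (1 - sigma(-f(x)*t)) * h(x), elementwise\n",
+ "    def __init__(self, dim):\n",
+ "        super().__init__()\n",
+ "        self.f = nn.Linear(dim, dim)  # sets the effective time constant\n",
+ "        self.g = nn.Linear(dim, dim)  # short-horizon branch\n",
+ "        self.h = nn.Linear(dim, dim)  # long-horizon branch\n",
+ "\n",
+ "    def forward(self, x, t):\n",
+ "        # x: (B, L, dim); t: scalar or broadcastable elapsed-time tensor\n",
+ "        gate = torch.sigmoid(-self.f(x) * t)\n",
+ "        return gate * torch.tanh(self.g(x)) + (1 - gate) * torch.tanh(self.h(x))\n",
+ "```\n",
+ "\n",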
+ "### Mamba-2 SSD (State Space Duality)\n",
+ "```\n",
+ "h_t = A_t * h_{t-1} + B_t * x_t\n",
+ "y_t = C_t^T * h_t\n",
+ "```\n",
+ "- **O(N) linear complexity** vs Transformers' O(N²)\n",
+ "- **Parallelizable** via associative scan (Blelloch)\n",
+ "- **Scalar-A** formulation enables chunk-scan optimization\n",
+ "- Pure PyTorch — no CUDA kernels needed (reference scan below)\n",
+ "\n",
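+ "The recurrence above, written as a plain sequential scan for reference (illustrative shapes; `ssd_scan` is not the repo's API, and the trained model uses an equivalent parallel chunked scan):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def ssd_scan(A, B, C, x):\n",
+ "    # A: (Bsz, L) scalar decay per step; B, C: (Bsz, L, N); x: (Bsz, L, D)\n",
+ "    Bsz, L, N = B.shape\n",
+ "    h = x.new_zeros(Bsz, N, x.shape[-1])  # state h_t, shape (Bsz, N, D)\n",
+ "    ys = []\n",
+ "    for t in range(L):\n",
+ "        # h_t = A_t * h_{t-1} + B_t x_t^T ;  y_t = C_t^T h_t\n",
+ "        h = A[:, t, None, None] * h + B[:, t, :, None] * x[:, t, None, :]\n",
+ "        ys.append(torch.einsum('bn,bnd->bd', C[:, t], h))\n",
+ "    return torch.stack(ys, dim=1)  # (Bsz, L, D)\n",
+ "```\n",
+ "\n",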
+ "### Physics-Informed Regularization\n",
+ "- **Total Variation**: `L_TV = ||∇_x x̂||₁ + ||∇_y x̂||₁`\n",
+ "- **Spectral**: penalize high-frequency artifacts\n",
+ "- **Gradient**: Sobolev norm for stable training\n",
+ "- Pattern from Bastek & Sun (ICLR 2025): physics loss as a training-only regularizer (TV sketch below)\n",
+ "\n",
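+ "The TV term is the simplest of the three; a sketch matching the formula above (`tv_loss` is an illustrative name, not the repo's API):\n",
+ "\n",
+ "```python\n",
+ "import torch\n",
+ "\n",
+ "def tv_loss(x):\n",
+ "    # x: (B, C, H, W); anisotropic L1 total variation\n",
+ "    dx = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().mean()\n",
+ "    dy = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().mean()\n",
+ "    return dx + dy\n",
+ "```\n",
+ "\n",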
+ "### Model Variants\n",
+ "| Variant | Params | Hidden Dim | Stages | Blocks | T4 VRAM |\n",
+ "|---------|--------|------------|--------|--------|---------|\n",
+ "| Tiny | ~2M | 128 | 2 | 2 | < 2 GB |\n",
+ "| Small | ~8M | 256 | 4 | 4 | ~4 GB |\n",
+ "| Base | ~30M | 384 | 6 | 6 | ~8 GB |"
+ ]
+ }
+ ],
+ "metadata": {
+ "colab": {
+ "name": "LiquidFlow: LiquidNN + Mamba-2 SSD Image Generator",
+ "provenance": []
+ },
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+ }