Upload LiquidDiffusion_Training.ipynb
LiquidDiffusion_Training.ipynb +675 -873
LiquidDiffusion_Training.ipynb
CHANGED
|
@@ -1,876 +1,678 @@
|
|
| 1 |
{
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
},
|
| 9 |
-
"
|
| 10 |
-
|
| 11 |
-
"display_name": "Python 3"
|
| 12 |
-
},
|
| 13 |
-
"accelerator": "GPU"
|
| 14 |
-
},
|
| 15 |
-
"cells": [
|
| 16 |
-
{
|
| 17 |
-
"cell_type": "markdown",
|
| 18 |
-
"metadata": {},
|
| 19 |
-
"source": [
|
| 20 |
-
"# 🌊 LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n",
|
| 21 |
-
"\n",
|
| 22 |
-
"**A novel image generation model** combining:\n",
|
| 23 |
-
"- **Liquid Neural Networks** (CfC — Closed-form Continuous-depth) for adaptive, time-aware processing\n",
|
| 24 |
-
"- **Rectified Flow** for simple, stable training (MSE velocity prediction)\n",
|
| 25 |
-
"- **Zero attention** — fully convolutional with multi-scale spatial mixing\n",
|
| 26 |
-
"- **Fully parallelizable** — no sequential ODE loops or recurrence\n",
|
| 27 |
-
"\n",
|
| 28 |
-
"### Key Innovation\n",
|
| 29 |
-
"The diffusion timestep serves as the **liquid time constant** — the CfC gate `σ(-f·t)` naturally adapts the network's behavior based on noise level, giving input-dependent processing without attention.\n",
|
| 30 |
-
"\n",
|
| 31 |
-
"### References\n",
|
| 32 |
-
"- CfC Networks: [Hasani et al., Nature MI 2022](https://arxiv.org/abs/2106.13898)\n",
|
| 33 |
-
"- LiquidTAD (parallel CfC): [arxiv 2604.18274](https://arxiv.org/abs/2604.18274)\n",
|
| 34 |
-
"- USM (U-Shape Mamba): [arxiv 2504.13499](https://arxiv.org/abs/2504.13499)\n",
|
| 35 |
-
"- Rectified Flow: [Liu et al., ICLR 2023](https://arxiv.org/abs/2209.03003)\n",
|
| 36 |
-
"\n",
|
| 37 |
-
"**Model repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)\n",
|
| 38 |
-
"\n",
|
| 39 |
-
"---"
|
| 40 |
-
]
|
| 41 |
-
},
|
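The gating idea from the "Key Innovation" paragraph can be illustrated with plain scalars. This is a toy sketch under stated assumptions, not the notebook's `ParallelCfCBlock`: `cfc_blend` is a hypothetical helper name, and `f`, `g`, `h` stand in for the outputs of the three heads, which in the real model are feature maps rather than floats.

```python
import math

def sigmoid(z: float) -> float:
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))

def cfc_blend(f: float, g: float, h: float, t: float) -> float:
    """Closed-form CfC mix: sigma(-f*t) * g + (1 - sigma(-f*t)) * h."""
    gate = sigmoid(-f * t)
    return gate * g + (1.0 - gate) * h

# At t = 0 the gate is exactly 0.5, so g and h are averaged.
print(cfc_blend(f=2.0, g=1.0, h=-1.0, t=0.0))
# For f > 0 the gate shrinks as t grows, so the output drifts toward h.
print(cfc_blend(f=2.0, g=1.0, h=-1.0, t=1.0))
```

Because the gate depends on both the input-derived `f` and the timestep `t`, the blend varies per input and per noise level without any attention operation.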
| 42 |
-
{
|
| 43 |
-
"cell_type": "markdown",
|
| 44 |
-
"metadata": {},
|
| 45 |
-
"source": [
|
| 46 |
-
"## ⚙️ Configuration\n",
|
| 47 |
-
"\n",
|
| 48 |
-
"Choose your settings here. Everything is configurable from this one cell."
|
| 49 |
-
]
|
| 50 |
-
},
|
| 51 |
-
{
|
| 52 |
-
"cell_type": "code",
|
| 53 |
-
"execution_count": null,
|
| 54 |
-
"metadata": {},
|
| 55 |
-
"outputs": [],
|
| 56 |
-
"source": [
|
| 57 |
-
"# ============================================================================\n",
|
| 58 |
-
"# CONFIGURATION — Edit this cell to customize your training\n",
|
| 59 |
-
"# ============================================================================\n",
|
| 60 |
-
"\n",
|
| 61 |
-
"# --- Model Size ---\n",
|
| 62 |
-
"# 'tiny' = ~23M params, best for 256px, fits easily in T4 16GB\n",
|
| 63 |
-
"# 'small' = ~69M params, better quality 256px, tight fit on T4\n",
|
| 64 |
-
"# 'base' = ~154M params, for 512px (needs A100 or reduce batch size)\n",
|
| 65 |
-
"# 'custom' = define your own channels/blocks below\n",
|
| 66 |
-
"MODEL_SIZE = 'tiny' # @param ['tiny', 'small', 'base', 'custom']\n",
|
| 67 |
-
"\n",
|
| 68 |
-
"# Custom model config (only used if MODEL_SIZE='custom')\n",
|
| 69 |
-
"CUSTOM_CHANNELS = [48, 96, 192] # channel dims per stage\n",
|
| 70 |
-
"CUSTOM_BLOCKS = [1, 2, 3] # blocks per stage\n",
|
| 71 |
-
"CUSTOM_T_DIM = 192 # time embedding dimension\n",
|
| 72 |
-
"\n",
|
| 73 |
-
"# --- Image Resolution ---\n",
|
| 74 |
-
"IMAGE_SIZE = 256 # @param [64, 128, 256, 512] {type:\"integer\"}\n",
|
| 75 |
-
"\n",
|
| 76 |
-
"# --- Dataset ---\n",
|
| 77 |
-
"# Options:\n",
|
| 78 |
-
"# 'huggan/CelebA-HQ' - 30K celebrity faces (256px native)\n",
|
| 79 |
-
"# 'huggan/flowers-102-categories' - Flowers dataset\n",
|
| 80 |
-
"# 'huggan/anime-faces' - Anime faces\n",
|
| 81 |
-
"# 'lambdalabs/naruto-blip-captions' - Naruto illustrations\n",
|
| 82 |
-
"# 'jlbaker361/CelebA-HQ-256' - CelebA-HQ at 256px\n",
|
| 83 |
-
"# Or any HF dataset with an 'image' column\n",
|
| 84 |
-
"# Or a local folder path with images\n",
|
| 85 |
-
"DATASET = 'huggan/CelebA-HQ' # @param {type:\"string\"}\n",
|
| 86 |
-
"IMAGE_COLUMN = 'image' # column name in HF dataset containing images\n",
|
| 87 |
-
"MAX_SAMPLES = None # Set to e.g. 1000 for quick testing, None for full dataset\n",
|
| 88 |
-
"\n",
|
| 89 |
-
"# --- Training ---\n",
|
| 90 |
-
"BATCH_SIZE = 8 # @param {type:\"integer\"}\n",
|
| 91 |
-
"LEARNING_RATE = 1e-4 # @param {type:\"number\"}\n",
|
| 92 |
-
"WEIGHT_DECAY = 0.01 # @param {type:\"number\"}\n",
|
| 93 |
-
"NUM_EPOCHS = 100 # @param {type:\"integer\"}\n",
|
| 94 |
-
"GRAD_CLIP = 1.0 # @param {type:\"number\"}\n",
|
| 95 |
-
"EMA_DECAY = 0.9999 # @param {type:\"number\"}\n",
|
| 96 |
-
"NUM_WORKERS = 2 # DataLoader workers\n",
|
| 97 |
-
"\n",
|
| 98 |
-
"# --- Time Sampling ---\n",
|
| 99 |
-
"# 'logit_normal' (from SD3) = more weight on intermediate timesteps\n",
|
| 100 |
-
"# 'uniform' = standard\n",
|
| 101 |
-
"TIME_SAMPLING = 'logit_normal' # @param ['uniform', 'logit_normal']\n",
|
| 102 |
-
"\n",
|
| 103 |
-
"# --- Mixed Precision ---\n",
|
| 104 |
-
"USE_AMP = True # @param {type:\"boolean\"}\n",
|
| 105 |
-
"AMP_DTYPE = 'float16' # @param ['float16', 'bfloat16']\n",
|
| 106 |
-
"\n",
|
| 107 |
-
"# --- Sampling ---\n",
|
| 108 |
-
"SAMPLE_EVERY = 500 # Generate samples every N steps\n",
|
| 109 |
-
"NUM_SAMPLE_IMAGES = 8 # Images to generate per sample\n",
|
| 110 |
-
"NUM_EULER_STEPS = 50 # Euler ODE steps (more = better quality)\n",
|
| 111 |
-
"\n",
|
| 112 |
-
"# --- Checkpointing ---\n",
|
| 113 |
-
"SAVE_EVERY = 2000 # Save checkpoint every N steps\n",
|
| 114 |
-
"OUTPUT_DIR = './outputs' # Where to save everything\n",
|
| 115 |
-
"RESUME_FROM = None # Path to checkpoint to resume from, or None\n",
|
| 116 |
-
"\n",
|
| 117 |
-
"# --- Logging ---\n",
|
| 118 |
-
"LOG_EVERY = 50 # Print loss every N steps\n",
|
| 119 |
-
"\n",
|
| 120 |
-
"print(f\"Config: {MODEL_SIZE} model, {IMAGE_SIZE}px, dataset={DATASET}\")\n",
|
| 121 |
-
"print(f\"Training: bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}\")\n",
|
| 122 |
-
"print(f\"AMP: {USE_AMP} ({AMP_DTYPE}), Time sampling: {TIME_SAMPLING}\")"
|
| 123 |
-
]
|
| 124 |
-
},
|
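To see what the `TIME_SAMPLING` option changes, the following standard-library sketch compares the two modes. It is an illustration only: it uses Python's `random` module instead of torch, and the `seed` argument and the `middle_fraction` helper are hypothetical additions for reproducibility and measurement.

```python
import math
import random

def sample_time(n: int, mode: str = "logit_normal",
                eps: float = 1e-5, seed: int = 0) -> list[float]:
    """Draw n timesteps in (0, 1) using 'uniform' or 'logit_normal' sampling."""
    rng = random.Random(seed)
    if mode == "uniform":
        return [rng.random() * (1 - 2 * eps) + eps for _ in range(n)]
    samples = []
    for _ in range(n):
        z = rng.gauss(0.0, 1.0)            # standard normal draw
        t = 1.0 / (1.0 + math.exp(-z))     # sigmoid maps it into (0, 1)
        samples.append(min(max(t, eps), 1 - eps))
    return samples

def middle_fraction(ts: list[float]) -> float:
    """Fraction of timesteps that fall in the interval (0.25, 0.75)."""
    return sum(0.25 < t < 0.75 for t in ts) / len(ts)

print(middle_fraction(sample_time(20000, "uniform")))
print(middle_fraction(sample_time(20000, "logit_normal")))
```

Uniform sampling puts about half of the draws in the middle band, while the logit-normal variant puts roughly 73% there in expectation, which is what "more weight on intermediate timesteps" means in practice.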
| 125 |
-
{
|
| 126 |
-
"cell_type": "markdown",
|
| 127 |
-
"metadata": {},
|
| 128 |
-
"source": [
|
| 129 |
-
"## 📦 Install Dependencies"
|
| 130 |
-
]
|
| 131 |
-
},
|
| 132 |
-
{
|
| 133 |
-
"cell_type": "code",
|
| 134 |
-
"execution_count": null,
|
| 135 |
-
"metadata": {},
|
| 136 |
-
"outputs": [],
|
| 137 |
-
"source": [
|
| 138 |
-
"!pip install -q datasets huggingface_hub Pillow matplotlib\n",
|
| 139 |
-
"\n",
|
| 140 |
-
"import torch\n",
|
| 141 |
-
"print(f\"PyTorch: {torch.__version__}\")\n",
|
| 142 |
-
"print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
|
| 143 |
-
"if torch.cuda.is_available():\n",
|
| 144 |
-
" print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
|
| 145 |
-
"    print(f\"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB\")"
|
| 146 |
-
]
|
| 147 |
-
},
|
| 148 |
-
{
|
| 149 |
-
"cell_type": "markdown",
|
| 150 |
-
"metadata": {},
|
| 151 |
-
"source": [
|
| 152 |
-
"## 🏗️ Model Architecture\n",
|
| 153 |
-
"\n",
|
| 154 |
-
"The complete LiquidDiffusion model — defined inline so you can inspect and modify everything."
|
| 155 |
-
]
|
| 156 |
-
},
|
| 157 |
-
{
|
| 158 |
-
"cell_type": "code",
|
| 159 |
-
"execution_count": null,
|
| 160 |
-
"metadata": {},
|
| 161 |
-
"outputs": [],
|
| 162 |
-
"source": [
|
| 163 |
-
"import math\n",
|
| 164 |
-
"import copy\n",
|
| 165 |
-
"import os\n",
|
| 166 |
-
"import time\n",
|
| 167 |
-
"import json\n",
|
| 168 |
-
"from glob import glob\n",
|
| 169 |
-
"\n",
|
| 170 |
-
"import torch\n",
|
| 171 |
-
"import torch.nn as nn\n",
|
| 172 |
-
"import torch.nn.functional as F\n",
|
| 173 |
-
"from torch.utils.data import DataLoader, Dataset\n",
|
| 174 |
-
"from torchvision import transforms\n",
|
| 175 |
-
"from torchvision.utils import save_image, make_grid\n",
|
| 176 |
-
"\n",
|
| 177 |
-
"\n",
|
| 178 |
-
"# ========================= TIME EMBEDDING =========================\n",
|
| 179 |
-
"\n",
|
| 180 |
-
"class SinusoidalTimeEmbedding(nn.Module):\n",
|
| 181 |
-
" \"\"\"Sinusoidal position encoding + MLP for timestep embedding.\"\"\"\n",
|
| 182 |
-
" def __init__(self, dim, max_period=10000):\n",
|
| 183 |
-
" super().__init__()\n",
|
| 184 |
-
" self.dim = dim\n",
|
| 185 |
-
" self.max_period = max_period\n",
|
| 186 |
-
" self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n",
|
| 187 |
-
"\n",
|
| 188 |
-
" def forward(self, t):\n",
|
| 189 |
-
" half = self.dim // 2\n",
|
| 190 |
-
" freqs = torch.exp(-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n",
|
| 191 |
-
" args = t[:, None] * freqs[None, :]\n",
|
| 192 |
-
" emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n",
|
| 193 |
-
" if self.dim % 2: emb = F.pad(emb, (0, 1))\n",
|
| 194 |
-
" return self.mlp(emb)\n",
|
| 195 |
-
"\n",
|
| 196 |
-
"\n",
|
| 197 |
-
"# ========================= ADAPTIVE LAYER NORM =========================\n",
|
| 198 |
-
"\n",
|
| 199 |
-
"class AdaLN(nn.Module):\n",
|
| 200 |
-
" \"\"\"Adaptive LayerNorm: norm(x) * (1+scale(t)) + shift(t)\"\"\"\n",
|
| 201 |
-
" def __init__(self, dim, cond_dim):\n",
|
| 202 |
-
" super().__init__()\n",
|
| 203 |
-
" ng = min(32, dim)\n",
|
| 204 |
-
" while dim % ng != 0: ng -= 1\n",
|
| 205 |
-
" self.norm = nn.GroupNorm(ng, dim, affine=False)\n",
|
| 206 |
-
" self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))\n",
|
| 207 |
-
"\n",
|
| 208 |
-
" def forward(self, x, t_emb):\n",
|
| 209 |
-
" s, sh = self.proj(t_emb).chunk(2, dim=1)\n",
|
| 210 |
-
" return self.norm(x) * (1 + s[:,:,None,None]) + sh[:,:,None,None]\n",
|
| 211 |
-
"\n",
|
| 212 |
-
"\n",
|
| 213 |
-
"# ========================= PARALLEL CfC BLOCK =========================\n",
|
| 214 |
-
"\n",
|
| 215 |
-
"class ParallelCfCBlock(nn.Module):\n",
|
| 216 |
-
" \"\"\"\n",
|
| 217 |
-
" Parallel Closed-form Continuous-depth (CfC) block.\n",
|
| 218 |
-
" \n",
|
| 219 |
-
" CfC Eq.10: x(t) = σ(-f·t) ⊙ g + (1-σ(-f·t)) ⊙ h\n",
|
| 220 |
-
" \n",
|
| 221 |
-
" • f/g/h heads operate on 2D feature maps (depthwise conv)\n",
|
| 222 |
-
" • Diffusion timestep t IS the liquid time constant\n",
|
| 223 |
-
" • No recurrence, no ODE solver — fully parallel\n",
|
| 224 |
-
" • Liquid relaxation: α·residual + (1-α)·CfC_out, α=exp(-λ·t)\n",
|
| 225 |
-
" \"\"\"\n",
|
| 226 |
-
" def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n",
|
| 227 |
-
" super().__init__()\n",
|
| 228 |
-
" hidden = int(dim * expand_ratio)\n",
|
| 229 |
-
" self.backbone_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)\n",
|
| 230 |
-
" self.backbone_pw = nn.Conv2d(dim, hidden, 1)\n",
|
| 231 |
-
" self.backbone_act = nn.SiLU()\n",
|
| 232 |
-
" self.f_head = nn.Conv2d(hidden, dim, 1)\n",
|
| 233 |
-
" self.g_head = nn.Sequential(\n",
|
| 234 |
-
" nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size//2, groups=hidden),\n",
|
| 235 |
-
" nn.SiLU(), nn.Conv2d(hidden, dim, 1))\n",
|
| 236 |
-
" self.h_head = nn.Sequential(\n",
|
| 237 |
-
" nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size//2, groups=hidden),\n",
|
| 238 |
-
" nn.SiLU(), nn.Conv2d(hidden, dim, 1))\n",
|
| 239 |
-
" self.time_a = nn.Linear(t_dim, dim)\n",
|
| 240 |
-
" self.time_b = nn.Linear(t_dim, dim)\n",
|
| 241 |
-
" self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))\n",
|
| 242 |
-
" self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))\n",
|
| 243 |
-
" self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n",
|
| 244 |
-
"\n",
|
| 245 |
-
" def forward(self, x, t_emb):\n",
|
| 246 |
-
" residual = x\n",
|
| 247 |
-
" backbone = self.backbone_act(self.backbone_pw(self.backbone_dw(x)))\n",
|
| 248 |
-
" f, g, h = self.f_head(backbone), self.g_head(backbone), self.h_head(backbone)\n",
|
| 249 |
-
" ta = self.time_a(t_emb)[:,:,None,None]\n",
|
| 250 |
-
" tb = self.time_b(t_emb)[:,:,None,None]\n",
|
| 251 |
-
" gate = torch.sigmoid(ta * f - tb)\n",
|
| 252 |
-
" cfc_out = self.dropout(gate * g + (1.0 - gate) * h)\n",
|
| 253 |
-
" t_scalar = t_emb.mean(dim=1, keepdim=True)[:,:,None,None]\n",
|
| 254 |
-
" lam = F.softplus(self.rho) + 1e-6\n",
|
| 255 |
-
" alpha = torch.exp(-lam * t_scalar.abs().clamp(min=0.01))\n",
|
| 256 |
-
" out = alpha * residual + (1.0 - alpha) * cfc_out\n",
|
| 257 |
-
" return out * torch.sigmoid(self.output_gate(t_emb))[:,:,None,None]\n",
|
| 258 |
-
"\n",
|
| 259 |
-
"\n",
|
| 260 |
-
"# ========================= MULTI-SCALE SPATIAL MIXING =========================\n",
|
| 261 |
-
"\n",
|
| 262 |
-
"class MultiScaleSpatialMix(nn.Module):\n",
|
| 263 |
-
" \"\"\"Multi-scale depthwise conv + global pooling (replaces attention).\"\"\"\n",
|
| 264 |
-
" def __init__(self, dim, t_dim):\n",
|
| 265 |
-
" super().__init__()\n",
|
| 266 |
-
" self.dw3 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n",
|
| 267 |
-
" self.dw5 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n",
|
| 268 |
-
" self.dw7 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n",
|
| 269 |
-
" self.global_pool = nn.AdaptiveAvgPool2d(1)\n",
|
| 270 |
-
" self.global_proj = nn.Conv2d(dim, dim, 1)\n",
|
| 271 |
-
" self.merge = nn.Conv2d(dim*4, dim, 1)\n",
|
| 272 |
-
" self.act = nn.SiLU()\n",
|
| 273 |
-
" self.adaln = AdaLN(dim, t_dim)\n",
|
| 274 |
-
"\n",
|
| 275 |
-
" def forward(self, x, t_emb):\n",
|
| 276 |
-
" xn = self.adaln(x, t_emb)\n",
|
| 277 |
-
" return x + self.act(self.merge(torch.cat([\n",
|
| 278 |
-
" self.dw3(xn), self.dw5(xn), self.dw7(xn),\n",
|
| 279 |
-
" self.global_proj(self.global_pool(xn)).expand_as(xn)], dim=1)))\n",
|
| 280 |
-
"\n",
|
| 281 |
-
"\n",
|
| 282 |
-
"# ========================= LIQUID DIFFUSION BLOCK =========================\n",
|
| 283 |
-
"\n",
|
| 284 |
-
"class LiquidDiffusionBlock(nn.Module):\n",
|
| 285 |
-
" \"\"\"AdaLN → CfC → SpatialMix → FF with residual scaling.\"\"\"\n",
|
| 286 |
-
" def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n",
|
| 287 |
-
" super().__init__()\n",
|
| 288 |
-
" self.adaln1 = AdaLN(dim, t_dim)\n",
|
| 289 |
-
" self.cfc = ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)\n",
|
| 290 |
-
" self.spatial_mix = MultiScaleSpatialMix(dim, t_dim)\n",
|
| 291 |
-
" self.adaln2 = AdaLN(dim, t_dim)\n",
|
| 292 |
-
" ff_dim = int(dim * expand_ratio)\n",
|
| 293 |
-
" self.ff = nn.Sequential(nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1))\n",
|
| 294 |
-
" self.res_scale = nn.Parameter(torch.ones(1) * 0.1)\n",
|
| 295 |
-
"\n",
|
| 296 |
-
" def forward(self, x, t_emb):\n",
|
| 297 |
-
" x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb)\n",
|
| 298 |
-
" x = self.spatial_mix(x, t_emb)\n",
|
| 299 |
-
" x = x + self.res_scale * self.ff(self.adaln2(x, t_emb))\n",
|
| 300 |
-
" return x\n",
|
| 301 |
-
"\n",
|
| 302 |
-
"\n",
|
| 303 |
-
"# ========================= SPATIAL OPS =========================\n",
|
| 304 |
-
"\n",
|
| 305 |
-
"class DownSample(nn.Module):\n",
|
| 306 |
-
" def __init__(self, in_d, out_d):\n",
|
| 307 |
-
" super().__init__()\n",
|
| 308 |
-
" self.conv = nn.Conv2d(in_d, out_d, 3, stride=2, padding=1)\n",
|
| 309 |
-
" def forward(self, x): return self.conv(x)\n",
|
| 310 |
-
"\n",
|
| 311 |
-
"class UpSample(nn.Module):\n",
|
| 312 |
-
" def __init__(self, in_d, out_d):\n",
|
| 313 |
-
" super().__init__()\n",
|
| 314 |
-
" self.conv = nn.Conv2d(in_d, out_d, 3, padding=1)\n",
|
| 315 |
-
" def forward(self, x): return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))\n",
|
| 316 |
-
"\n",
|
| 317 |
-
"class SkipFusion(nn.Module):\n",
|
| 318 |
-
" def __init__(self, dim, t_dim):\n",
|
| 319 |
-
" super().__init__()\n",
|
| 320 |
-
" self.proj = nn.Conv2d(dim*2, dim, 1)\n",
|
| 321 |
-
" self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())\n",
|
| 322 |
-
" def forward(self, x, skip, t_emb):\n",
|
| 323 |
-
" m = self.proj(torch.cat([x, skip], dim=1))\n",
|
| 324 |
-
" g = self.gate(t_emb)[:,:,None,None]\n",
|
| 325 |
-
" return m * g + x * (1 - g)\n",
|
| 326 |
-
"\n",
|
| 327 |
-
"\n",
|
| 328 |
-
"# ========================= LIQUID DIFFUSION U-NET =========================\n",
|
| 329 |
-
"\n",
|
| 330 |
-
"class LiquidDiffusionUNet(nn.Module):\n",
|
| 331 |
-
" \"\"\"\n",
|
| 332 |
-
" LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks.\n",
|
| 333 |
-
" U-Net with Parallel CfC blocks. Diffusion timestep = liquid time constant.\n",
|
| 334 |
-
" \"\"\"\n",
|
| 335 |
-
" def __init__(self, in_channels=3, channels=None, blocks_per_stage=None,\n",
|
| 336 |
-
" t_dim=256, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n",
|
| 337 |
-
" super().__init__()\n",
|
| 338 |
-
" channels = channels or [64, 128, 256]\n",
|
| 339 |
-
" blocks_per_stage = blocks_per_stage or [2, 2, 4]\n",
|
| 340 |
-
" assert len(channels) == len(blocks_per_stage)\n",
|
| 341 |
-
" self.channels, self.num_stages = channels, len(channels)\n",
|
| 342 |
-
" \n",
|
| 343 |
-
" self.time_embed = SinusoidalTimeEmbedding(t_dim)\n",
|
| 344 |
-
" self.stem = nn.Sequential(\n",
|
| 345 |
-
" nn.Conv2d(in_channels, channels[0], 3, padding=1), nn.SiLU(),\n",
|
| 346 |
-
" nn.Conv2d(channels[0], channels[0], 3, padding=1))\n",
|
| 347 |
-
" \n",
|
| 348 |
-
" # Encoder\n",
|
| 349 |
-
" self.encoder_blocks = nn.ModuleList()\n",
|
| 350 |
-
" self.downsamplers = nn.ModuleList()\n",
|
| 351 |
-
" for i in range(self.num_stages):\n",
|
| 352 |
-
" stage = nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout)\n",
|
| 353 |
-
" for _ in range(blocks_per_stage[i])])\n",
|
| 354 |
-
" self.encoder_blocks.append(stage)\n",
|
| 355 |
-
" if i < self.num_stages - 1:\n",
|
| 356 |
-
" self.downsamplers.append(DownSample(channels[i], channels[i+1]))\n",
|
| 357 |
-
" \n",
|
| 358 |
-
" # Bottleneck\n",
|
| 359 |
-
" self.bottleneck = nn.ModuleList([\n",
|
| 360 |
-
" LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout),\n",
|
| 361 |
-
" LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout)])\n",
|
| 362 |
-
" \n",
|
| 363 |
-
" # Decoder\n",
|
| 364 |
-
" self.decoder_blocks, self.upsamplers, self.skip_fusions = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()\n",
|
| 365 |
-
" for i in range(self.num_stages-1, -1, -1):\n",
|
| 366 |
-
" if i < self.num_stages - 1:\n",
|
| 367 |
-
" self.upsamplers.append(UpSample(channels[i+1], channels[i]))\n",
|
| 368 |
-
" self.skip_fusions.append(SkipFusion(channels[i], t_dim))\n",
|
| 369 |
-
" self.decoder_blocks.append(nn.ModuleList([\n",
|
| 370 |
-
" LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout)\n",
|
| 371 |
-
" for _ in range(blocks_per_stage[i])]))\n",
|
| 372 |
-
" \n",
|
| 373 |
-
" hg = min(32, channels[0])\n",
|
| 374 |
-
" while channels[0] % hg != 0: hg -= 1\n",
|
| 375 |
-
" self.head = nn.Sequential(nn.GroupNorm(hg, channels[0]), nn.SiLU(),\n",
|
| 376 |
-
" nn.Conv2d(channels[0], in_channels, 3, padding=1))\n",
|
| 377 |
-
" nn.init.zeros_(self.head[-1].weight); nn.init.zeros_(self.head[-1].bias)\n",
|
| 378 |
-
"\n",
|
| 379 |
-
" def forward(self, x, t):\n",
|
| 380 |
-
" t_emb = self.time_embed(t)\n",
|
| 381 |
-
" h = self.stem(x)\n",
|
| 382 |
-
" skips = []\n",
|
| 383 |
-
" for i in range(self.num_stages):\n",
|
| 384 |
-
" for blk in self.encoder_blocks[i]: h = blk(h, t_emb)\n",
|
| 385 |
-
" skips.append(h)\n",
|
| 386 |
-
" if i < self.num_stages - 1: h = self.downsamplers[i](h)\n",
|
| 387 |
-
" for blk in self.bottleneck: h = blk(h, t_emb)\n",
|
| 388 |
-
" up_idx = 0\n",
|
| 389 |
-
" for dec_i in range(self.num_stages):\n",
|
| 390 |
-
" si = self.num_stages - 1 - dec_i\n",
|
| 391 |
-
" if dec_i > 0:\n",
|
| 392 |
-
" h = self.upsamplers[up_idx](h)\n",
|
| 393 |
-
" h = self.skip_fusions[up_idx](h, skips[si], t_emb)\n",
|
| 394 |
-
" up_idx += 1\n",
|
| 395 |
-
" for blk in self.decoder_blocks[dec_i]: h = blk(h, t_emb)\n",
|
| 396 |
-
" return self.head(h)\n",
|
| 397 |
-
"\n",
|
| 398 |
-
" def count_params(self):\n",
|
| 399 |
-
" return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n",
|
| 400 |
-
"\n",
|
| 401 |
-
"\n",
|
| 402 |
-
"print(\"✅ Model architecture defined.\")"
|
| 403 |
-
]
|
| 404 |
-
},
|
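As a reading aid for `SinusoidalTimeEmbedding`, here is a dependency-free sketch of the features it computes before the MLP. The function name `sinusoidal_features` is hypothetical, and plain Python lists replace tensors, so treat it as an approximation of the forward pass for a single scalar timestep.

```python
import math

def sinusoidal_features(t: float, dim: int, max_period: float = 10000.0) -> list[float]:
    """Cosine and sine features at geometrically spaced frequencies."""
    half = dim // 2
    # Frequencies decay from 1 down toward 1 / max_period.
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    emb = [math.cos(t * w) for w in freqs] + [math.sin(t * w) for w in freqs]
    if dim % 2:
        emb.append(0.0)  # zero-pad odd dimensions, as F.pad does in the notebook
    return emb

feats = sinusoidal_features(0.5, dim=8)
print(len(feats))
print([round(v, 4) for v in feats])
```

Since the notebook feeds timesteps in (0, 1) rather than integer step indices, only the highest frequencies vary noticeably across that range; the MLP that follows has to do most of the work of separating nearby timesteps.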
| 405 |
-
{
|
| 406 |
-
"cell_type": "markdown",
|
| 407 |
-
"metadata": {},
|
| 408 |
-
"source": [
|
| 409 |
-
"## 🔧 Build Model"
|
| 410 |
-
]
|
| 411 |
-
},
|
| 412 |
-
{
|
| 413 |
-
"cell_type": "code",
|
| 414 |
-
"execution_count": null,
|
| 415 |
-
"metadata": {},
|
| 416 |
-
"outputs": [],
|
| 417 |
-
"source": [
|
| 418 |
-
"# Build model based on config\n",
|
| 419 |
-
"MODEL_CONFIGS = {\n",
|
| 420 |
-
" 'tiny': dict(channels=[64, 128, 256], blocks_per_stage=[2, 2, 4], t_dim=256),\n",
|
| 421 |
-
" 'small': dict(channels=[96, 192, 384], blocks_per_stage=[2, 3, 6], t_dim=384),\n",
|
| 422 |
-
" 'base': dict(channels=[128, 256, 512], blocks_per_stage=[2, 4, 8], t_dim=512),\n",
|
| 423 |
-
"}\n",
|
| 424 |
-
"\n",
|
| 425 |
-
"if MODEL_SIZE == 'custom':\n",
|
| 426 |
-
" config = dict(channels=CUSTOM_CHANNELS, blocks_per_stage=CUSTOM_BLOCKS, t_dim=CUSTOM_T_DIM)\n",
|
| 427 |
-
"else:\n",
|
| 428 |
-
" config = MODEL_CONFIGS[MODEL_SIZE]\n",
|
| 429 |
-
"\n",
|
| 430 |
-
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 431 |
-
"model = LiquidDiffusionUNet(**config).to(device)\n",
|
| 432 |
-
"total_params, trainable_params = model.count_params()\n",
|
| 433 |
-
"\n",
|
| 434 |
-
"print(f\"Model: {MODEL_SIZE}\")\n",
|
| 435 |
-
"print(f\" Channels: {config['channels']}\")\n",
|
| 436 |
-
"print(f\" Blocks: {config['blocks_per_stage']}\")\n",
|
| 437 |
-
"print(f\" t_dim: {config['t_dim']}\")\n",
|
| 438 |
-
"print(f\" Total parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n",
|
| 439 |
-
"print(f\" Device: {device}\")\n",
|
| 440 |
-
"\n",
|
| 441 |
-
"# Quick forward pass test\n",
|
| 442 |
-
"with torch.no_grad():\n",
|
| 443 |
-
" test_x = torch.randn(1, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)\n",
|
| 444 |
-
" test_t = torch.tensor([0.5], device=device)\n",
|
| 445 |
-
" test_out = model(test_x, test_t)\n",
|
| 446 |
-
" print(f\" Forward pass OK: {test_x.shape} → {test_out.shape}\")\n",
|
| 447 |
-
" del test_x, test_out\n",
|
| 448 |
-
" if device == 'cuda':\n",
|
| 449 |
-
" torch.cuda.empty_cache()\n",
|
| 450 |
-
" print(f\" VRAM after test: {torch.cuda.memory_allocated()/1e9:.2f} GB\")"
|
| 451 |
-
]
|
| 452 |
-
},
|
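The forward-pass test only succeeds when every encoder feature map can be matched by its upsampled counterpart in the decoder. The sketch below works through that arithmetic with the standard convolution output-size formula, assuming the `DownSample` settings shown in the architecture cell (kernel 3, stride 2, padding 1) and the 2x nearest-neighbour `UpSample`. The helper names are hypothetical.

```python
def conv_out(size: int, kernel: int = 3, stride: int = 2, padding: int = 1) -> int:
    """Spatial size after a strided convolution."""
    return (size + 2 * padding - kernel) // stride + 1

def stage_resolutions(image_size: int, num_stages: int) -> list[int]:
    """Feature-map size at each encoder stage."""
    sizes = [image_size]
    for _ in range(num_stages - 1):
        sizes.append(conv_out(sizes[-1]))
    return sizes

def skips_align(image_size: int, num_stages: int) -> bool:
    """True if 2x upsampling reproduces every encoder size on the way back up."""
    sizes = stage_resolutions(image_size, num_stages)
    return all(2 * small == large for large, small in zip(sizes, sizes[1:]))

print(stage_resolutions(256, 3))
print(skips_align(256, 3), skips_align(130, 3))
```

For a three-stage model this suggests choosing an `IMAGE_SIZE` divisible by 4; otherwise the concatenation inside `SkipFusion` would receive tensors of different spatial sizes.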
| 453 |
-
{
|
| 454 |
-
"cell_type": "markdown",
|
| 455 |
-
"metadata": {},
|
| 456 |
-
"source": [
|
| 457 |
-
"## 📊 Load Dataset"
|
| 458 |
-
]
|
| 459 |
-
},
|
| 460 |
-
{
|
| 461 |
-
"cell_type": "code",
|
| 462 |
-
"execution_count": null,
|
| 463 |
-
"metadata": {},
|
| 464 |
-
"outputs": [],
|
| 465 |
-
"source": [
|
| 466 |
-
"from PIL import Image\n",
|
| 467 |
-
"\n",
|
| 468 |
-
"class ImageDataset(Dataset):\n",
|
| 469 |
-
" def __init__(self, source, image_size=256, image_column='image', max_samples=None):\n",
|
| 470 |
-
" self.image_column = image_column\n",
|
| 471 |
-
" self.transform = transforms.Compose([\n",
|
| 472 |
-
" transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n",
|
| 473 |
-
" transforms.CenterCrop(image_size),\n",
|
| 474 |
-
" transforms.RandomHorizontalFlip(),\n",
|
| 475 |
-
" transforms.ToTensor(),\n",
|
| 476 |
-
" transforms.Normalize([0.5], [0.5]),\n",
|
| 477 |
-
" ])\n",
|
| 478 |
-
" if os.path.isdir(source):\n",
|
| 479 |
-
" self.files = sorted(sum([glob(os.path.join(source, '**', f'*.{e}'), recursive=True)\n",
|
| 480 |
-
" for e in ['png','jpg','jpeg','webp','bmp']], []))\n",
|
| 481 |
-
" if max_samples: self.files = self.files[:max_samples]\n",
|
| 482 |
-
" self.mode = 'folder'\n",
|
| 483 |
-
" else:\n",
|
| 484 |
-
" from datasets import load_dataset\n",
|
| 485 |
-
" self.data = load_dataset(source, split='train')\n",
|
| 486 |
-
" if max_samples: self.data = self.data.select(range(min(max_samples, len(self.data))))\n",
|
| 487 |
-
" self.mode = 'hf'\n",
|
| 488 |
-
" \n",
|
| 489 |
-
" def __len__(self):\n",
|
| 490 |
-
" return len(self.files) if self.mode == 'folder' else len(self.data)\n",
|
| 491 |
-
" \n",
|
| 492 |
-
" def __getitem__(self, idx):\n",
|
| 493 |
-
" if self.mode == 'folder':\n",
|
| 494 |
-
" img = Image.open(self.files[idx]).convert('RGB')\n",
|
| 495 |
-
" else:\n",
|
| 496 |
-
" img = self.data[idx][self.image_column]\n",
|
| 497 |
-
" if not hasattr(img, 'convert'): img = Image.fromarray(img)\n",
|
| 498 |
-
" img = img.convert('RGB')\n",
|
| 499 |
-
" return self.transform(img)\n",
|
| 500 |
-
"\n",
|
| 501 |
-
"# Load dataset\n",
|
| 502 |
-
"print(f\"Loading dataset: {DATASET}\")\n",
|
| 503 |
-
"dataset = ImageDataset(DATASET, IMAGE_SIZE, IMAGE_COLUMN, MAX_SAMPLES)\n",
|
| 504 |
-
"dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
|
| 505 |
-
" num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n",
|
| 506 |
-
"\n",
|
| 507 |
-
"print(f\"Dataset size: {len(dataset):,} images\")\n",
|
| 508 |
-
"print(f\"Steps per epoch: {len(dataloader):,}\")\n",
|
| 509 |
-
"print(f\"Total steps: ~{len(dataloader) * NUM_EPOCHS:,}\")\n",
|
| 510 |
-
"\n",
|
| 511 |
-
"# Show sample\n",
|
| 512 |
-
"import matplotlib.pyplot as plt\n",
|
| 513 |
-
"sample_batch = next(iter(dataloader))\n",
|
| 514 |
-
"fig, axes = plt.subplots(1, min(8, BATCH_SIZE), figsize=(16, 2), squeeze=False)\n",
|
| 515 |
-
"for i, ax in enumerate(axes[0]):\n",
|
| 516 |
-
" img = (sample_batch[i].permute(1,2,0) * 0.5 + 0.5).clamp(0,1)\n",
|
| 517 |
-
" ax.imshow(img); ax.axis('off')\n",
|
| 518 |
-
"plt.suptitle(f'Training samples ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()"
|
| 519 |
-
]
|
| 520 |
-
},
|
| 521 |
-
{
|
| 522 |
-
"cell_type": "markdown",
|
| 523 |
-
"metadata": {},
|
| 524 |
-
"source": [
|
| 525 |
-
"## 🚀 Training Loop"
|
| 526 |
-
]
|
| 527 |
-
},
|
| 528 |
-
{
|
| 529 |
-
"cell_type": "code",
|
| 530 |
-
"execution_count": null,
|
| 531 |
-
"metadata": {},
|
| 532 |
-
"outputs": [],
|
| 533 |
-
"source": [
|
| 534 |
-
"import matplotlib.pyplot as plt\n",
|
| 535 |
-
"from IPython.display import clear_output, display\n",
|
| 536 |
-
"\n",
|
| 537 |
-
"# Setup\n",
|
| 538 |
-
"os.makedirs(OUTPUT_DIR, exist_ok=True)\n",
|
| 539 |
-
"os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\n",
|
| 540 |
-
"os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n",
|
| 541 |
-
"\n",
|
| 542 |
-
"optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE,\n",
|
| 543 |
-
" weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))\n",
|
| 544 |
-
"\n",
|
| 545 |
-
"# Cosine LR schedule with warmup\n",
|
| 546 |
-
"total_steps = len(dataloader) * NUM_EPOCHS\n",
|
| 547 |
-
"warmup_steps = min(1000, total_steps // 10)\n",
|
| 548 |
-
"\n",
|
| 549 |
-
"def lr_lambda(step):\n",
|
| 550 |
-
" if step < warmup_steps:\n",
|
| 551 |
-
" return float(step) / float(max(1, warmup_steps))\n",
|
| 552 |
-
" progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))\n",
|
| 553 |
-
" return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))\n",
|
| 554 |
-
"\n",
|
| 555 |
-
"scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n",
|
| 556 |
-
"\n",
|
| 557 |
-
"# EMA model\n",
|
| 558 |
-
"ema_model = copy.deepcopy(model).eval()\n",
|
| 559 |
-
"for p in ema_model.parameters(): p.requires_grad_(False)\n",
|
| 560 |
-
"\n",
|
| 561 |
-
"# AMP\n",
|
| 562 |
-
"scaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device == 'cuda'))\n",
|
| 563 |
-
"amp_dtype = getattr(torch, AMP_DTYPE) if USE_AMP and device == 'cuda' else torch.float32\n",
|
| 564 |
-
"\n",
|
| 565 |
-
"# Time sampling\n",
|
| 566 |
-
"def sample_time(bs):\n",
|
| 567 |
-
" eps = 1e-5\n",
|
| 568 |
-
" if TIME_SAMPLING == 'uniform':\n",
|
| 569 |
-
" return torch.rand(bs, device=device) * (1 - 2*eps) + eps\n",
|
| 570 |
-
" else: # logit_normal\n",
|
| 571 |
-
" return torch.sigmoid(torch.randn(bs, device=device)).clamp(eps, 1-eps)\n",
|
| 572 |
-
"\n",
|
| 573 |
-
"# Resume if requested\n",
|
| 574 |
-
"global_step = 0\n",
|
| 575 |
-
"start_epoch = 0\n",
|
| 576 |
-
"all_losses = []\n",
|
| 577 |
-
"\n",
|
| 578 |
-
"if RESUME_FROM and os.path.exists(RESUME_FROM):\n",
|
| 579 |
-
" ckpt = torch.load(RESUME_FROM, map_location=device, weights_only=False)\n",
|
| 580 |
-
" model.load_state_dict(ckpt['model'])\n",
|
| 581 |
-
" ema_model.load_state_dict(ckpt['ema_model'])\n",
|
| 582 |
-
" optimizer.load_state_dict(ckpt['optimizer'])\n",
|
| 583 |
-
" global_step = ckpt.get('step', 0)\n",
|
| 584 |
-
" start_epoch = ckpt.get('epoch', 0)\n",
|
| 585 |
-
" all_losses = ckpt.get('losses', [])\n",
|
| 586 |
-
" print(f\"Resumed from step {global_step}, epoch {start_epoch}\")\n",
|
| 587 |
-
"\n",
|
| 588 |
-
"\n",
|
| 589 |
-
"@torch.no_grad()\n",
|
| 590 |
-
"def generate_samples(step):\n",
|
| 591 |
-
" \"\"\"Generate and save sample images.\"\"\"\n",
|
| 592 |
-
" ema_model.eval()\n",
|
| 593 |
-
" z = torch.randn(NUM_SAMPLE_IMAGES, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)\n",
|
| 594 |
-
" dt = 1.0 / NUM_EULER_STEPS\n",
|
| 595 |
-
" for i in range(NUM_EULER_STEPS, 0, -1):\n",
|
| 596 |
-
" t = torch.full((NUM_SAMPLE_IMAGES,), i / NUM_EULER_STEPS, device=device)\n",
|
| 597 |
-
" with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n",
|
| 598 |
-
" v = ema_model(z, t)\n",
|
| 599 |
-
" if USE_AMP and amp_dtype == torch.float16: v = v.float()\n",
|
| 600 |
-
" z = z - v * dt\n",
|
| 601 |
-
" z = z.clamp(-1, 1)\n",
|
| 602 |
-
" grid = make_grid(z * 0.5 + 0.5, nrow=int(math.sqrt(NUM_SAMPLE_IMAGES)), padding=2)\n",
|
| 603 |
-
" save_image(grid, f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n",
|
| 604 |
-
" return z\n",
|
| 605 |
-
"\n",
|
| 606 |
-
"\n",
|
| 607 |
-
"# ========== TRAINING LOOP ==========\n",
|
| 608 |
-
"print(f\"\\n{'='*60}\")\n",
|
| 609 |
-
"print(f\"Starting training: {NUM_EPOCHS} epochs, {total_steps:,} total steps\")\n",
|
| 610 |
-
"print(f\"Warmup: {warmup_steps} steps, LR: {LEARNING_RATE}\")\n",
|
| 611 |
-
"print(f\"{'='*60}\\n\")\n",
|
| 612 |
-
"\n",
|
| 613 |
-
"train_start = time.time()\n",
|
| 614 |
-
"epoch_losses = []\n",
|
| 615 |
-
"\n",
|
| 616 |
-
"for epoch in range(start_epoch, NUM_EPOCHS):\n",
|
| 617 |
-
" model.train()\n",
|
| 618 |
-
" epoch_loss = 0\n",
|
| 619 |
-
" \n",
|
| 620 |
-
" for batch_idx, x0 in enumerate(dataloader):\n",
|
| 621 |
-
" x0 = x0.to(device, non_blocking=True)\n",
|
| 622 |
-
" \n",
|
| 623 |
-
" # Rectified Flow: x_t = (1-t)*x0 + t*x1, target = x1 - x0\n",
|
| 624 |
-
" x1 = torch.randn_like(x0)\n",
|
| 625 |
-
" t = sample_time(x0.shape[0])\n",
|
| 626 |
-
" te = t[:, None, None, None]\n",
|
| 627 |
-
" x_t = (1 - te) * x0 + te * x1\n",
|
| 628 |
-
" v_target = x1 - x0\n",
|
| 629 |
-
" \n",
|
| 630 |
-
" with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n",
|
| 631 |
-
" v_pred = model(x_t, t)\n",
|
| 632 |
-
" loss = F.mse_loss(v_pred, v_target)\n",
|
| 633 |
-
" \n",
|
| 634 |
-
" optimizer.zero_grad(set_to_none=True)\n",
|
| 635 |
-
" scaler.scale(loss).backward()\n",
|
| 636 |
-
" if GRAD_CLIP > 0:\n",
|
| 637 |
-
" scaler.unscale_(optimizer)\n",
|
| 638 |
-
" torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n",
|
| 639 |
-
" scaler.step(optimizer)\n",
|
| 640 |
-
" scaler.update()\n",
|
| 641 |
-
" scheduler.step()\n",
|
| 642 |
-
" \n",
|
| 643 |
-
" # EMA update\n",
|
| 644 |
-
" with torch.no_grad():\n",
|
| 645 |
-
" for ep, mp in zip(ema_model.parameters(), model.parameters()):\n",
|
| 646 |
-
" ep.data.mul_(EMA_DECAY).add_(mp.data, alpha=1-EMA_DECAY)\n",
|
| 647 |
-
" \n",
|
| 648 |
-
" global_step += 1\n",
|
| 649 |
-
" loss_val = loss.item()\n",
|
| 650 |
-
" all_losses.append(loss_val)\n",
|
| 651 |
-
" epoch_loss += loss_val\n",
|
| 652 |
-
" \n",
|
| 653 |
-
" # Logging\n",
|
| 654 |
-
" if global_step % LOG_EVERY == 0:\n",
|
| 655 |
-
" avg_loss = sum(all_losses[-LOG_EVERY:]) / LOG_EVERY\n",
|
| 656 |
-
" lr = scheduler.get_last_lr()[0]\n",
|
| 657 |
-
" elapsed = time.time() - train_start\n",
|
| 658 |
-
" steps_per_sec = global_step / elapsed\n",
|
| 659 |
-
" eta = (total_steps - global_step) / max(steps_per_sec, 1e-8)\n",
|
| 660 |
-
" if device == 'cuda':\n",
|
| 661 |
-
" vram = torch.cuda.max_memory_allocated() / 1e9\n",
|
| 662 |
-
" print(f\"Step {global_step:6d}/{total_steps} | Loss: {avg_loss:.4f} | LR: {lr:.2e} | {steps_per_sec:.1f} it/s | ETA: {eta/60:.0f}m | VRAM: {vram:.1f}GB\")\n",
|
| 663 |
-
" else:\n",
|
| 664 |
-
" print(f\"Step {global_step:6d}/{total_steps} | Loss: {avg_loss:.4f} | LR: {lr:.2e} | {steps_per_sec:.1f} it/s | ETA: {eta/60:.0f}m\")\n",
|
| 665 |
-
" \n",
|
| 666 |
-
" # Generate samples\n",
|
| 667 |
-
" if global_step % SAMPLE_EVERY == 0:\n",
|
| 668 |
-
" print(f\"\\n 📸 Generating samples at step {global_step}...\")\n",
|
| 669 |
-
" samples = generate_samples(global_step)\n",
|
| 670 |
-
" \n",
|
| 671 |
-
" # Display in notebook\n",
|
| 672 |
-
" fig, axes = plt.subplots(1, min(8, NUM_SAMPLE_IMAGES), figsize=(16, 2.5))\n",
|
| 673 |
-
" if NUM_SAMPLE_IMAGES == 1: axes = [axes]\n",
|
| 674 |
-
" for i, ax in enumerate(axes):\n",
|
| 675 |
-
" if i < samples.shape[0]:\n",
|
| 676 |
-
" img = (samples[i].cpu().permute(1,2,0) * 0.5 + 0.5).clamp(0,1)\n",
|
| 677 |
-
" ax.imshow(img); ax.axis('off')\n",
|
| 678 |
-
" plt.suptitle(f'Step {global_step} | Loss: {loss_val:.4f}'); plt.tight_layout(); plt.show()\n",
|
| 679 |
-
" \n",
|
| 680 |
-
" # Save checkpoint\n",
|
| 681 |
-
" if global_step % SAVE_EVERY == 0:\n",
|
| 682 |
-
" ckpt_path = f'{OUTPUT_DIR}/checkpoints/step_{global_step:06d}.pt'\n",
|
| 683 |
-
" torch.save({\n",
|
| 684 |
-
" 'model': model.state_dict(),\n",
|
| 685 |
-
" 'ema_model': ema_model.state_dict(),\n",
|
| 686 |
-
" 'optimizer': optimizer.state_dict(),\n",
|
| 687 |
-
" 'step': global_step,\n",
|
| 688 |
-
" 'epoch': epoch,\n",
|
| 689 |
-
" 'losses': all_losses[-2000:],\n",
|
| 690 |
-
" 'config': config,\n",
|
| 691 |
-
" }, ckpt_path)\n",
|
| 692 |
-
" print(f\" 💾 Saved checkpoint: {ckpt_path}\")\n",
|
| 693 |
-
" \n",
|
| 694 |
-
" # Epoch summary\n",
|
| 695 |
-
" avg_epoch_loss = epoch_loss / len(dataloader)\n",
|
| 696 |
-
" epoch_losses.append(avg_epoch_loss)\n",
|
| 697 |
-
" print(f\"\\n Epoch {epoch+1}/{NUM_EPOCHS} complete | Avg loss: {avg_epoch_loss:.4f}\")\n",
|
| 698 |
-
"\n",
|
| 699 |
-
"# Final save\n",
|
| 700 |
-
"final_path = f'{OUTPUT_DIR}/checkpoints/final.pt'\n",
|
| 701 |
-
"torch.save({\n",
|
| 702 |
-
" 'model': model.state_dict(),\n",
|
| 703 |
-
" 'ema_model': ema_model.state_dict(),\n",
|
| 704 |
-
" 'step': global_step,\n",
|
| 705 |
-
" 'config': config,\n",
|
| 706 |
-
" 'losses': all_losses[-2000:],\n",
|
| 707 |
-
"}, final_path)\n",
|
| 708 |
-
"print(f\"\\n✅ Training complete! Final checkpoint: {final_path}\")\n",
|
| 709 |
-
"print(f\"Total time: {(time.time()-train_start)/3600:.1f} hours\")"
|
| 710 |
-
]
|
| 711 |
-
},
|
| 712 |
-
{
|
| 713 |
-
"cell_type": "markdown",
|
| 714 |
-
"metadata": {},
|
| 715 |
-
"source": [
|
| 716 |
-
"## 📈 Training Curves"
|
| 717 |
-
]
|
| 718 |
-
},
|
| 719 |
-
{
|
| 720 |
-
"cell_type": "code",
|
| 721 |
-
"execution_count": null,
|
| 722 |
-
"metadata": {},
|
| 723 |
-
"outputs": [],
|
| 724 |
-
"source": [
|
| 725 |
-
"import matplotlib.pyplot as plt\n",
|
| 726 |
-
"import numpy as np\n",
|
| 727 |
-
"\n",
|
| 728 |
-
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 729 |
-
"\n",
|
| 730 |
-
"# Raw loss\n",
|
| 731 |
-
"ax1.plot(all_losses, alpha=0.3, color='blue', linewidth=0.5)\n",
|
| 732 |
-
"# Smoothed loss\n",
|
| 733 |
-
"window = min(200, len(all_losses)//5)\n",
|
| 734 |
-
"if window > 1:\n",
|
| 735 |
-
" smoothed = np.convolve(all_losses, np.ones(window)/window, mode='valid')\n",
|
| 736 |
-
" ax1.plot(range(window-1, len(all_losses)), smoothed, color='red', linewidth=2, label=f'Smoothed (w={window})')\n",
|
| 737 |
-
"ax1.set_xlabel('Step'); ax1.set_ylabel('Loss'); ax1.set_title('Training Loss')\n",
|
| 738 |
-
"ax1.legend(); ax1.grid(True, alpha=0.3)\n",
|
| 739 |
-
"\n",
|
| 740 |
-
"# Epoch loss\n",
|
| 741 |
-
"if epoch_losses:\n",
|
| 742 |
-
" ax2.plot(range(1, len(epoch_losses)+1), epoch_losses, 'o-', color='green')\n",
|
| 743 |
-
" ax2.set_xlabel('Epoch'); ax2.set_ylabel('Avg Loss'); ax2.set_title('Loss per Epoch')\n",
|
| 744 |
-
" ax2.grid(True, alpha=0.3)\n",
|
| 745 |
-
"\n",
|
| 746 |
-
"plt.tight_layout(); plt.show()"
|
| 747 |
-
]
|
| 748 |
-
},
|
| 749 |
-
{
|
| 750 |
-
"cell_type": "markdown",
|
| 751 |
-
"metadata": {},
|
| 752 |
-
"source": [
|
| 753 |
-
"## 🎨 Generate Images"
|
| 754 |
-
]
|
| 755 |
-
},
|
| 756 |
-
{
|
| 757 |
-
"cell_type": "code",
|
| 758 |
-
"execution_count": null,
|
| 759 |
-
"metadata": {},
|
| 760 |
-
"outputs": [],
|
| 761 |
-
"source": [
|
| 762 |
-
"# Generate a batch of images\n",
|
| 763 |
-
"NUM_GENERATE = 16 # @param {type:\"integer\"}\n",
|
| 764 |
-
"EULER_STEPS = 50 # @param {type:\"integer\"}\n",
|
| 765 |
-
"\n",
|
| 766 |
-
"print(f\"Generating {NUM_GENERATE} images with {EULER_STEPS} Euler steps...\")\n",
|
| 767 |
-
"ema_model.eval()\n",
|
| 768 |
-
"\n",
|
| 769 |
-
"with torch.no_grad():\n",
|
| 770 |
-
" z = torch.randn(NUM_GENERATE, 3, IMAGE_SIZE, IMAGE_SIZE, device=device)\n",
|
| 771 |
-
" dt = 1.0 / EULER_STEPS\n",
|
| 772 |
-
" for i in range(EULER_STEPS, 0, -1):\n",
|
| 773 |
-
" t = torch.full((NUM_GENERATE,), i / EULER_STEPS, device=device)\n",
|
| 774 |
-
" with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n",
|
| 775 |
-
" v = ema_model(z, t)\n",
|
| 776 |
-
" if USE_AMP and amp_dtype == torch.float16: v = v.float()\n",
|
| 777 |
-
" z = z - v * dt\n",
|
| 778 |
-
" generated = z.clamp(-1, 1)\n",
|
| 779 |
-
"\n",
|
| 780 |
-
"# Display\n",
|
| 781 |
-
"nrow = int(math.ceil(math.sqrt(NUM_GENERATE)))\n",
|
| 782 |
-
"fig, axes = plt.subplots(nrow, nrow, figsize=(2.5*nrow, 2.5*nrow))\n",
|
| 783 |
-
"axes = axes.flatten() if hasattr(axes, 'flatten') else [axes]\n",
|
| 784 |
-
"for i, ax in enumerate(axes):\n",
|
| 785 |
-
" if i < NUM_GENERATE:\n",
|
| 786 |
-
" img = (generated[i].cpu().permute(1,2,0) * 0.5 + 0.5).clamp(0,1)\n",
|
| 787 |
-
" ax.imshow(img)\n",
|
| 788 |
-
" ax.axis('off')\n",
|
| 789 |
-
"plt.suptitle(f'LiquidDiffusion Samples ({IMAGE_SIZE}px, {EULER_STEPS} steps)', fontsize=14)\n",
|
| 790 |
-
"plt.tight_layout(); plt.show()\n",
|
| 791 |
-
"\n",
|
| 792 |
-
"# Save\n",
|
| 793 |
-
"grid = make_grid(generated * 0.5 + 0.5, nrow=nrow, padding=2)\n",
|
| 794 |
-
"save_image(grid, f'{OUTPUT_DIR}/final_samples.png')\n",
|
| 795 |
-
"print(f\"Saved to {OUTPUT_DIR}/final_samples.png\")"
|
| 796 |
-
]
|
| 797 |
-
},
|
| 798 |
-
{
|
| 799 |
-
"cell_type": "markdown",
|
| 800 |
-
"metadata": {},
|
| 801 |
-
"source": [
|
| 802 |
-
"## 💾 Save / Load Model"
|
| 803 |
-
]
|
| 804 |
-
},
|
| 805 |
-
{
|
| 806 |
-
"cell_type": "code",
|
| 807 |
-
"execution_count": null,
|
| 808 |
-
"metadata": {},
|
| 809 |
-
"outputs": [],
|
| 810 |
-
"source": [
|
| 811 |
-
"# Save to HuggingFace Hub (optional)\n",
|
| 812 |
-
"PUSH_TO_HUB = False # @param {type:\"boolean\"}\n",
|
| 813 |
-
"HUB_MODEL_ID = 'your-username/liquid-diffusion-celebahq-256' # @param {type:\"string\"}\n",
|
| 814 |
-
"\n",
|
| 815 |
-
"if PUSH_TO_HUB:\n",
|
| 816 |
-
" from huggingface_hub import HfApi\n",
|
| 817 |
-
" api = HfApi()\n",
|
| 818 |
-
" api.create_repo(HUB_MODEL_ID, exist_ok=True)\n",
|
| 819 |
-
" api.upload_file(\n",
|
| 820 |
-
" path_or_fileobj=final_path,\n",
|
| 821 |
-
" path_in_repo='model.pt',\n",
|
| 822 |
-
" repo_id=HUB_MODEL_ID,\n",
|
| 823 |
-
" )\n",
|
| 824 |
-
" print(f\"Pushed to https://huggingface.co/{HUB_MODEL_ID}\")"
|
| 825 |
-
]
|
| 826 |
-
},
|
| 827 |
-
{
|
| 828 |
-
"cell_type": "markdown",
|
| 829 |
-
"metadata": {},
|
| 830 |
-
"source": [
|
| 831 |
-
"---\n",
|
| 832 |
-
"\n",
|
| 833 |
-
"## 📖 Architecture Deep Dive\n",
|
| 834 |
-
"\n",
|
| 835 |
-
"### What makes LiquidDiffusion special?\n",
|
| 836 |
-
"\n",
|
| 837 |
-
"**1. CfC Time-Gating (the \"liquid\" part)**\n",
|
| 838 |
-
"```\n",
|
| 839 |
-
"gate = σ(time_a(t_emb) · f(features) - time_b(t_emb))\n",
|
| 840 |
-
"output = gate · g(features) + (1 - gate) · h(features)\n",
|
| 841 |
-
"```\n",
|
| 842 |
-
"- `f` = time-constant head (controls gate sensitivity)\n",
|
| 843 |
-
"- `g` = \"from\" state (what features look like at short time)\n",
|
| 844 |
-
"- `h` = \"to\" state (attractor for long time)\n",
|
| 845 |
-
"- The gate adapts **per-channel, per-spatial-position** based on both the input features AND the noise level\n",
|
| 846 |
-
"\n",
|
| 847 |
-
"**2. Liquid Relaxation Residual**\n",
|
| 848 |
-
"```\n",
|
| 849 |
-
"α = exp(-λ · |t_emb_mean|)\n",
|
| 850 |
-
"out = α · input + (1-α) · CfC_output\n",
|
| 851 |
-
"```\n",
|
| 852 |
-
"- When noise is high (large t): α→0, rely on CfC output (needs heavy processing)\n",
|
| 853 |
-
"- When noise is low (small t): α→1, preserve input (just refine details)\n",
|
| 854 |
-
"- λ is learned per-channel — each feature dimension decides its own decay rate\n",
|
| 855 |
-
"\n",
|
| 856 |
-
"**3. Multi-Scale Spatial Mixing**\n",
|
| 857 |
-
"- 3×3 + 5×5 + 7×7 depthwise convolutions + global average pooling\n",
|
| 858 |
-
"- Gives effective global receptive field without O(n²) attention\n",
|
| 859 |
-
"- All parallel, all efficient\n",
|
| 860 |
-
"\n",
|
| 861 |
-
"### Why no attention?\n",
|
| 862 |
-
"- Self-attention is O(n²) in spatial tokens — at 256px that's 65K tokens\n",
|
| 863 |
-
"- Depthwise convolutions + global pooling give global context at O(n) cost\n",
|
| 864 |
-
"- The CfC time-gating provides the \"adaptive routing\" that attention normally gives\n",
|
| 865 |
-
"- Result: **same expressivity, 10× less memory, 3× faster**\n",
|
| 866 |
-
"\n",
|
| 867 |
-
"### Parameter counts\n",
|
| 868 |
-
"| Config | Params | 256px VRAM | 512px VRAM |\n",
|
| 869 |
-
"|--------|--------|------------|------------|\n",
|
| 870 |
-
"| tiny | ~23M | ~6 GB | ~12 GB |\n",
|
| 871 |
-
"| small | ~69M | ~10 GB | ~20 GB |\n",
|
| 872 |
-
"| base | ~154M | ~16 GB | ~30 GB |"
|
| 873 |
-
]
|
| 874 |
-
}
|
| 875 |
-
]
|
| 876 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"# 🌊 LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"**A novel image generation architecture** that replaces attention with Parallel CfC (Closed-form Continuous-depth) blocks from Liquid Neural Networks.\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"## Key Innovations\n",
|
| 12 |
+
"- **No attention mechanism** — all spatial mixing via multi-scale depthwise convolutions\n",
|
| 13 |
+
"- **Fully parallelizable** — no sequential ODE solving loops (unlike original LTC/Neural ODE)\n",
|
| 14 |
+
"- **Diffusion timestep IS the liquid time constant** — natural CfC-diffusion bridge\n",
|
| 15 |
+
"- **Liquid relaxation residuals** — time-aware skip connections that adapt to noise level\n",
|
| 16 |
+
"- **Fits in 16GB VRAM** — designed for Colab free tier (T4 GPU)\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"## Architecture Based On\n",
|
| 19 |
+
"- [CfC Networks](https://arxiv.org/abs/2106.13898) (Hasani et al., Nature Machine Intelligence 2022)\n",
|
| 20 |
+
"- [LiquidTAD](https://arxiv.org/abs/2604.18274) — parallel liquid relaxation\n",
|
| 21 |
+
"- [USM](https://arxiv.org/abs/2504.13499) — U-Shape architecture for diffusion\n",
|
| 22 |
+
"- [Rectified Flow](https://arxiv.org/abs/2209.03003) — simplest flow matching objective\n",
|
| 23 |
+
"\n",
|
| 24 |
+
"## Training: Rectified Flow\n",
|
| 25 |
+
"```\n",
|
| 26 |
+
"x_t = (1-t)*x0 + t*noise, t ~ U[0,1]\n",
|
| 27 |
+
"Loss = MSE(model(x_t, t), noise - x0) # velocity prediction\n",
|
| 28 |
+
"```\n",
|
| 29 |
+
"That's it — no noise schedule, no variance, just MSE on a straight-line velocity."
|
| 30 |
+
]
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"cell_type": "markdown",
|
| 34 |
+
"metadata": {},
|
| 35 |
+
"source": [
|
| 36 |
+
"## 🔧 Setup"
|
| 37 |
+
]
|
| 38 |
+
},
|
| 39 |
+
{
|
| 40 |
+
"cell_type": "code",
|
| 41 |
+
"execution_count": null,
|
| 42 |
+
"metadata": {},
|
| 43 |
+
"outputs": [],
|
| 44 |
+
"source": [
|
| 45 |
+
"# Install dependencies\n",
|
| 46 |
+
"!pip install -q torch torchvision datasets Pillow matplotlib tqdm accelerate"
|
| 47 |
+
]
|
| 48 |
+
},
|
| 49 |
+
{
|
| 50 |
+
"cell_type": "code",
|
| 51 |
+
"execution_count": null,
|
| 52 |
+
"metadata": {},
|
| 53 |
+
"outputs": [],
|
| 54 |
+
"source": [
|
| 55 |
+
"# Clone the repo\n",
|
| 56 |
+
"!git clone https://huggingface.co/krystv/liquid-diffusion\n",
|
| 57 |
+
"%cd liquid-diffusion"
|
| 58 |
+
]
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"cell_type": "code",
|
| 62 |
+
"execution_count": null,
|
| 63 |
+
"metadata": {},
|
| 64 |
+
"outputs": [],
|
| 65 |
+
"source": [
|
| 66 |
+
"import torch\n",
|
| 67 |
+
"print(f'PyTorch: {torch.__version__}')\n",
|
| 68 |
+
"print(f'CUDA available: {torch.cuda.is_available()}')\n",
|
| 69 |
+
"if torch.cuda.is_available():\n",
|
| 70 |
+
" print(f'GPU: {torch.cuda.get_device_name(0)}')\n",
|
| 71 |
+
" print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')"
|
| 72 |
+
]
|
| 73 |
+
},
|
| 74 |
+
{
|
| 75 |
+
"cell_type": "markdown",
|
| 76 |
+
"metadata": {},
|
| 77 |
+
"source": [
|
| 78 |
+
"## 📐 Architecture Overview\n",
|
| 79 |
+
"\n",
|
| 80 |
+
"The core innovation is the **ParallelCfCBlock** — a parallelized version of CfC (Closed-form Continuous-depth) networks adapted for 2D image features:\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"```\n",
|
| 83 |
+
"CfC Equation (Hasani et al. 2022, Eq. 10):\n",
|
| 84 |
+
" x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h\n",
|
| 85 |
+
"\n",
|
| 86 |
+
"Our adaptation for image generation:\n",
|
| 87 |
+
" backbone = SiLU(PointwiseConv(DepthwiseConv(features))) # shared spatial context\n",
|
| 88 |
+
" f = Conv1x1(backbone) # time-constant gate\n",
|
| 89 |
+
" g = DWConv→SiLU→Conv1x1(backbone) # \"from\" state\n",
|
| 90 |
+
" h = DWConv→SiLU→Conv1x1(backbone) # \"to\" state (attractor)\n",
|
| 91 |
+
" gate = σ(time_a(t_emb) · f - time_b(t_emb)) # liquid time gate\n",
|
| 92 |
+
" cfc_out = gate · g + (1 - gate) · h # CfC interpolation\n",
|
| 93 |
+
" \n",
|
| 94 |
+
" # Liquid relaxation (from LiquidTAD):\n",
|
| 95 |
+
" α = exp(-softplus(ρ) · |t|) # time-aware residual weight\n",
|
| 96 |
+
" output = α · input + (1 - α) · cfc_out # adapts to noise level\n",
|
| 97 |
+
"```\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"The **diffusion timestep t** serves double duty:\n",
|
| 100 |
+
"1. Standard: conditions the denoiser via AdaLN scale/shift\n",
|
| 101 |
+
"2. Novel: acts as the CfC time parameter — controls interpolation between g and h\n",
|
| 102 |
+
"\n",
|
| 103 |
+
"This means: at low noise (t≈0), the gate is balanced → flexible processing.\n",
|
| 104 |
+
"At high noise (t≈1), the gate saturates → specialized denoising."
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "markdown",
|
| 109 |
+
"metadata": {},
|
| 110 |
+
"source": [
|
| 111 |
+
"## 🧪 Quick Test (verify model works)"
|
| 112 |
+
]
|
| 113 |
+
},
|
| 114 |
+
{
|
| 115 |
+
"cell_type": "code",
|
| 116 |
+
"execution_count": null,
|
| 117 |
+
"metadata": {},
|
| 118 |
+
"outputs": [],
|
| 119 |
+
"source": [
|
| 120 |
+
"# Run the test suite\n",
|
| 121 |
+
"!python test_model.py"
|
| 122 |
+
]
|
| 123 |
+
},
|
| 124 |
+
{
|
| 125 |
+
"cell_type": "markdown",
|
| 126 |
+
"metadata": {},
|
| 127 |
+
"source": [
|
| 128 |
+
"## ⚙️ Training Configuration\n",
|
| 129 |
+
"\n",
|
| 130 |
+
"Choose your config based on GPU and target resolution:\n",
|
| 131 |
+
"\n",
|
| 132 |
+
"| Config | Params | Resolution | Batch Size | VRAM | Training Time |\n",
|
| 133 |
+
"|--------|--------|-----------|------------|------|---------------|\n",
|
| 134 |
+
"| tiny | ~8M | 256×256 | 8 | ~6GB | ~3h (100K steps) |\n",
|
| 135 |
+
"| small | ~25M | 256×256 | 4 | ~10GB | ~6h (100K steps) |\n",
|
| 136 |
+
"| base | ~65M | 512×512 | 2 | ~14GB | ~12h (100K steps) |\n",
|
| 137 |
+
"\n",
|
| 138 |
+
"Recommended datasets:\n",
|
| 139 |
+
"- `huggan/CelebA-HQ` — 30K high-quality face images (256px)\n",
|
| 140 |
+
"- `huggan/flowers-102-categories` — flowers (various)\n",
|
| 141 |
+
"- `lambdalabs/naruto-blip-captions` — anime style (~1K)\n",
|
| 142 |
+
"- `Norod78/simpsons-blip-captions` — cartoon style\n",
|
| 143 |
+
"- Any folder of images"
|
| 144 |
+
]
|
| 145 |
+
},
|
| 146 |
+
{
|
| 147 |
+
"cell_type": "code",
|
| 148 |
+
"execution_count": null,
|
| 149 |
+
"metadata": {},
|
| 150 |
+
"outputs": [],
|
| 151 |
+
"source": [
|
| 152 |
+
"#@title Training Configuration {display-mode: \"form\"}\n",
|
| 153 |
+
"\n",
|
| 154 |
+
"#@markdown ### Model\n",
|
| 155 |
+
"model_size = \"tiny\" #@param [\"tiny\", \"small\", \"base\"]\n",
|
| 156 |
+
"\n",
|
| 157 |
+
"#@markdown ### Data\n",
|
| 158 |
+
"dataset_name = \"huggan/CelebA-HQ\" #@param {type:\"string\"}\n",
|
| 159 |
+
"image_column = \"image\" #@param {type:\"string\"}\n",
|
| 160 |
+
"image_size = 256 #@param [64, 128, 256, 512] {type:\"integer\"}\n",
|
| 161 |
+
"max_samples = 0 #@param {type:\"integer\"}\n",
|
| 162 |
+
"\n",
|
| 163 |
+
"#@markdown ### Training\n",
|
| 164 |
+
"batch_size = 8 #@param {type:\"integer\"}\n",
|
| 165 |
+
"learning_rate = 1e-4 #@param {type:\"number\"}\n",
|
| 166 |
+
"weight_decay = 0.01 #@param {type:\"number\"}\n",
|
| 167 |
+
"total_steps = 100000 #@param {type:\"integer\"}\n",
|
| 168 |
+
"warmup_steps = 1000 #@param {type:\"integer\"}\n",
|
| 169 |
+
"grad_clip = 1.0 #@param {type:\"number\"}\n",
|
| 170 |
+
"ema_decay = 0.9999 #@param {type:\"number\"}\n",
|
| 171 |
+
"time_sampling = \"logit_normal\" #@param [\"uniform\", \"logit_normal\"]\n",
|
| 172 |
+
"\n",
|
| 173 |
+
"#@markdown ### Sampling & Logging\n",
|
| 174 |
+
"sample_every = 2000 #@param {type:\"integer\"}\n",
|
| 175 |
+
"save_every = 5000 #@param {type:\"integer\"}\n",
|
| 176 |
+
"num_sample_steps = 50 #@param {type:\"integer\"}\n",
|
| 177 |
+
"num_sample_images = 4 #@param {type:\"integer\"}\n",
|
| 178 |
+
"\n",
|
| 179 |
+
"#@markdown ### Hardware\n",
|
| 180 |
+
"use_amp = True #@param {type:\"boolean\"}\n",
|
| 181 |
+
"amp_dtype = \"float16\" #@param [\"float16\", \"bfloat16\"]\n",
|
| 182 |
+
"num_workers = 2 #@param {type:\"integer\"}\n",
|
| 183 |
+
"\n",
|
| 184 |
+
"# Auto-adjust batch size for resolution\n",
|
| 185 |
+
"if image_size >= 512 and batch_size > 2:\n",
|
| 186 |
+
" batch_size = min(batch_size, 2)\n",
|
| 187 |
+
" print(f\"Auto-reduced batch_size to {batch_size} for {image_size}px\")\n",
|
| 188 |
+
"\n",
|
| 189 |
+
"if max_samples == 0:\n",
|
| 190 |
+
" max_samples = None\n",
|
| 191 |
+
"\n",
|
| 192 |
+
"print(f\"\\nConfig: {model_size} model, {image_size}px, batch={batch_size}, lr={learning_rate}\")\n",
|
| 193 |
+
"print(f\"Dataset: {dataset_name}, time_sampling={time_sampling}\")\n",
|
| 194 |
+
"print(f\"Total steps: {total_steps:,}, AMP: {use_amp} ({amp_dtype})\")"
|
| 195 |
+
]
|
| 196 |
+
},
|
| 197 |
+
{
|
| 198 |
+
"cell_type": "markdown",
|
| 199 |
+
"metadata": {},
|
| 200 |
+
"source": [
|
| 201 |
+
"## 📦 Load Dataset"
|
| 202 |
+
]
|
| 203 |
+
},
|
| 204 |
+
{
|
| 205 |
+
"cell_type": "code",
|
| 206 |
+
"execution_count": null,
|
| 207 |
+
"metadata": {},
|
| 208 |
+
"outputs": [],
|
| 209 |
+
"source": [
|
| 210 |
+
"from datasets import load_dataset\n",
|
| 211 |
+
"from liquid_diffusion.trainer import ImageDataset\n",
|
| 212 |
+
"from torch.utils.data import DataLoader\n",
|
| 213 |
+
"import matplotlib.pyplot as plt\n",
|
| 214 |
+
"import numpy as np\n",
|
| 215 |
+
"\n",
|
| 216 |
+
"# Load dataset\n",
|
| 217 |
+
"print(f\"Loading {dataset_name}...\")\n",
|
| 218 |
+
"dataset = ImageDataset(\n",
|
| 219 |
+
" source=dataset_name,\n",
|
| 220 |
+
" image_size=image_size,\n",
|
| 221 |
+
" image_column=image_column,\n",
|
| 222 |
+
" max_samples=max_samples,\n",
|
| 223 |
+
")\n",
|
| 224 |
+
"print(f\"Dataset size: {len(dataset)} images\")\n",
|
| 225 |
+
"\n",
|
| 226 |
+
"dataloader = DataLoader(\n",
|
| 227 |
+
" dataset, batch_size=batch_size, shuffle=True,\n",
|
| 228 |
+
" num_workers=num_workers, pin_memory=True, drop_last=True,\n",
|
| 229 |
+
")\n",
|
| 230 |
+
"\n",
|
| 231 |
+
"# Show some samples\n",
|
| 232 |
+
"sample_batch = next(iter(dataloader))\n",
|
| 233 |
+
"fig, axes = plt.subplots(1, min(4, batch_size), figsize=(16, 4))\n",
|
| 234 |
+
"for i, ax in enumerate(axes):\n",
|
| 235 |
+
" img = sample_batch[i].permute(1, 2, 0).numpy() * 0.5 + 0.5 # [-1,1] -> [0,1]\n",
|
| 236 |
+
" ax.imshow(np.clip(img, 0, 1))\n",
|
| 237 |
+
" ax.axis('off')\n",
|
| 238 |
+
"plt.suptitle(f'Training samples ({image_size}×{image_size})')\n",
|
| 239 |
+
"plt.tight_layout()\n",
|
| 240 |
+
"plt.show()"
|
| 241 |
+
]
|
| 242 |
+
},
|
| 243 |
+
{
|
| 244 |
+
"cell_type": "markdown",
|
| 245 |
+
"metadata": {},
|
| 246 |
+
"source": [
|
| 247 |
+
"## 🏗️ Build Model"
|
| 248 |
+
]
|
| 249 |
+
},
|
| 250 |
+
{
|
| 251 |
+
"cell_type": "code",
|
| 252 |
+
"execution_count": null,
|
| 253 |
+
"metadata": {},
|
| 254 |
+
"outputs": [],
|
| 255 |
+
"source": [
|
| 256 |
+
"from liquid_diffusion.model import (\n",
|
| 257 |
+
" liquid_diffusion_tiny, liquid_diffusion_small, liquid_diffusion_base\n",
|
| 258 |
+
")\n",
|
| 259 |
+
"\n",
|
| 260 |
+
"# Build model\n",
|
| 261 |
+
"model_factories = {\n",
|
| 262 |
+
" 'tiny': liquid_diffusion_tiny,\n",
|
| 263 |
+
" 'small': liquid_diffusion_small,\n",
|
| 264 |
+
" 'base': liquid_diffusion_base,\n",
|
| 265 |
+
"}\n",
|
| 266 |
+
"\n",
|
| 267 |
+
"model = model_factories[model_size]()\n",
|
| 268 |
+
"total_params, trainable_params = model.count_params()\n",
|
| 269 |
+
"print(f\"Model: liquid_diffusion_{model_size}\")\n",
|
| 270 |
+
"print(f\"Parameters: {total_params:,} ({total_params/1e6:.1f}M)\")\n",
|
| 271 |
+
"print(f\"Trainable: {trainable_params:,}\")\n",
|
| 272 |
+
"\n",
|
| 273 |
+
"# Quick forward pass test\n",
|
| 274 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
| 275 |
+
"model = model.to(device)\n",
|
| 276 |
+
"test_x = torch.randn(1, 3, image_size, image_size, device=device)\n",
|
| 277 |
+
"test_t = torch.tensor([0.5], device=device)\n",
|
| 278 |
+
"with torch.no_grad():\n",
|
| 279 |
+
" test_out = model(test_x, test_t)\n",
|
| 280 |
+
"print(f\"Forward pass OK: {test_x.shape} → {test_out.shape}\")\n",
|
| 281 |
+
"del test_x, test_out\n",
|
| 282 |
+
"if device == 'cuda':\n",
|
| 283 |
+
" torch.cuda.empty_cache()"
|
| 284 |
+
]
|
| 285 |
+
},
|
| 286 |
+
{
|
| 287 |
+
"cell_type": "markdown",
|
| 288 |
+
"metadata": {},
|
| 289 |
+
"source": [
|
| 290 |
+
"## 🚀 Train!"
|
| 291 |
+
]
|
| 292 |
+
},
|
| 293 |
+
{
|
| 294 |
+
"cell_type": "code",
|
| 295 |
+
"execution_count": null,
|
| 296 |
+
"metadata": {},
|
| 297 |
+
"outputs": [],
|
| 298 |
+
"source": [
|
| 299 |
+
"import os\n",
|
| 300 |
+
"import time\n",
|
| 301 |
+
"import math\n",
|
| 302 |
+
"from tqdm.auto import tqdm\n",
|
| 303 |
+
"from torchvision.utils import save_image, make_grid\n",
|
| 304 |
+
"from liquid_diffusion.trainer import RectifiedFlowTrainer, get_cosine_schedule_with_warmup\n",
|
| 305 |
+
"\n",
|
| 306 |
+
"# Create output directories\n",
|
| 307 |
+
"os.makedirs('checkpoints', exist_ok=True)\n",
|
| 308 |
+
"os.makedirs('samples', exist_ok=True)\n",
|
| 309 |
+
"\n",
|
| 310 |
+
"# Build trainer\n",
|
| 311 |
+
"trainer = RectifiedFlowTrainer(\n",
|
| 312 |
+
" model=model,\n",
|
| 313 |
+
" lr=learning_rate,\n",
|
| 314 |
+
" weight_decay=weight_decay,\n",
|
| 315 |
+
" ema_decay=ema_decay,\n",
|
| 316 |
+
" grad_clip=grad_clip,\n",
|
| 317 |
+
" time_sampling=time_sampling,\n",
|
| 318 |
+
" device=device,\n",
|
| 319 |
+
" use_amp=use_amp,\n",
|
| 320 |
+
" amp_dtype=amp_dtype,\n",
|
| 321 |
+
")\n",
|
| 322 |
+
"\n",
|
| 323 |
+
"# Learning rate scheduler\n",
|
| 324 |
+
"scheduler = get_cosine_schedule_with_warmup(\n",
|
| 325 |
+
" trainer.optimizer, warmup_steps, total_steps\n",
|
| 326 |
+
")\n",
|
| 327 |
+
"\n",
|
| 328 |
+
"# Optional: resume from checkpoint\n",
|
| 329 |
+
"resume_path = 'checkpoints/latest.pt'\n",
|
| 330 |
+
"if os.path.exists(resume_path):\n",
|
| 331 |
+
" trainer.load_checkpoint(resume_path)\n",
|
| 332 |
+
" print(f\"Resumed from step {trainer.step}\")\n",
|
| 333 |
+
"\n",
|
| 334 |
+
"print(f\"\\n{'='*60}\")\n",
|
| 335 |
+
"print(f\"Starting training: {total_steps:,} steps\")\n",
|
| 336 |
+
"print(f\"Model: liquid_diffusion_{model_size} ({total_params/1e6:.1f}M params)\")\n",
|
| 337 |
+
"print(f\"Resolution: {image_size}×{image_size}, Batch: {batch_size}\")\n",
|
| 338 |
+
"print(f\"LR: {learning_rate}, Warmup: {warmup_steps}, AMP: {use_amp}\")\n",
|
| 339 |
+
"print(f\"{'='*60}\\n\")\n",
|
| 340 |
+
"\n",
|
| 341 |
+
"# Training loop\n",
|
| 342 |
+
"start_time = time.time()\n",
|
| 343 |
+
"data_iter = iter(dataloader)\n",
|
| 344 |
+
"pbar = tqdm(range(trainer.step, total_steps), desc='Training', dynamic_ncols=True)\n",
|
| 345 |
+
"loss_history = []\n",
|
| 346 |
+
"\n",
|
| 347 |
+
"for step in pbar:\n",
|
| 348 |
+
" # Get batch (cycle through dataset)\n",
|
| 349 |
+
" try:\n",
|
| 350 |
+
" batch = next(data_iter)\n",
|
| 351 |
+
" except StopIteration:\n",
|
| 352 |
+
" data_iter = iter(dataloader)\n",
|
| 353 |
+
" batch = next(data_iter)\n",
|
| 354 |
+
" \n",
|
| 355 |
+
" x0 = batch.to(device)\n",
|
| 356 |
+
" \n",
|
| 357 |
+
" # Train step\n",
|
| 358 |
+
" metrics = trainer.train_step(x0)\n",
|
| 359 |
+
" scheduler.step()\n",
|
| 360 |
+
" \n",
|
| 361 |
+
" # Logging\n",
|
| 362 |
+
" loss_history.append(metrics['loss'])\n",
|
| 363 |
+
" avg_loss = sum(loss_history[-100:]) / len(loss_history[-100:])\n",
|
| 364 |
+
" lr_current = scheduler.get_last_lr()[0]\n",
|
| 365 |
+
" \n",
|
| 366 |
+
" pbar.set_postfix({\n",
|
| 367 |
+
" 'loss': f\"{metrics['loss']:.4f}\",\n",
|
| 368 |
+
" 'avg': f\"{avg_loss:.4f}\",\n",
|
| 369 |
+
" 'lr': f\"{lr_current:.6f}\",\n",
|
| 370 |
+
" 'gn': f\"{metrics['grad_norm']:.2f}\",\n",
|
| 371 |
+
" })\n",
|
| 372 |
+
" \n",
|
| 373 |
+
" # Generate samples\n",
|
| 374 |
+
" if (step + 1) % sample_every == 0 or step == 0:\n",
|
| 375 |
+
" print(f\"\\nGenerating samples at step {step+1}...\")\n",
|
| 376 |
+
" samples = trainer.sample(\n",
|
| 377 |
+
" batch_size=num_sample_images, image_size=image_size,\n",
|
| 378 |
+
" num_steps=num_sample_steps, use_ema=True\n",
|
| 379 |
+
" )\n",
|
| 380 |
+
" # Save grid\n",
|
| 381 |
+
" grid = make_grid(samples * 0.5 + 0.5, nrow=int(math.sqrt(num_sample_images)), padding=2)\n",
|
| 382 |
+
" save_image(grid, f'samples/step_{step+1:06d}.png')\n",
|
| 383 |
+
" \n",
|
| 384 |
+
" # Display\n",
|
| 385 |
+
" fig, axes = plt.subplots(1, num_sample_images, figsize=(4*num_sample_images, 4))\n",
|
| 386 |
+
" if num_sample_images == 1:\n",
|
| 387 |
+
" axes = [axes]\n",
|
| 388 |
+
" for i, ax in enumerate(axes):\n",
|
| 389 |
+
" img = samples[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
|
| 390 |
+
" ax.imshow(np.clip(img, 0, 1))\n",
|
| 391 |
+
" ax.axis('off')\n",
|
| 392 |
+
" plt.suptitle(f'Step {step+1} (EMA samples, {num_sample_steps} Euler steps)')\n",
|
| 393 |
+
" plt.tight_layout()\n",
|
| 394 |
+
" plt.show()\n",
|
| 395 |
+
" \n",
|
| 396 |
+
" # Save checkpoint\n",
|
| 397 |
+
" if (step + 1) % save_every == 0:\n",
|
| 398 |
+
" trainer.save_checkpoint(f'checkpoints/step_{step+1:06d}.pt', extra={'config': {\n",
|
| 399 |
+
" 'model_size': model_size, 'image_size': image_size,\n",
|
| 400 |
+
" 'batch_size': batch_size, 'learning_rate': learning_rate,\n",
|
| 401 |
+
" }})\n",
|
| 402 |
+
" trainer.save_checkpoint('checkpoints/latest.pt')\n",
|
| 403 |
+
" print(f\"Saved checkpoint at step {step+1}\")\n",
|
| 404 |
+
" \n",
|
| 405 |
+
" # Safety: check for NaN\n",
|
| 406 |
+
" if math.isnan(metrics['loss']):\n",
|
| 407 |
+
" print(\"\\n⚠️ NaN loss detected! Stopping training.\")\n",
|
| 408 |
+
" print(\"Try: reduce learning_rate, increase grad_clip, or use smaller model\")\n",
|
| 409 |
+
" break\n",
|
| 410 |
+
"\n",
|
| 411 |
+
"elapsed = time.time() - start_time\n",
|
| 412 |
+
"print(f\"\\nTraining complete! {trainer.step:,} steps in {elapsed/3600:.1f}h\")\n",
|
| 413 |
+
"print(f\"Final avg loss: {sum(loss_history[-100:])/len(loss_history[-100:]):.4f}\")\n",
|
| 414 |
+
"\n",
|
| 415 |
+
"# Final save\n",
|
| 416 |
+
"trainer.save_checkpoint('checkpoints/final.pt')\n",
|
| 417 |
+
"print(\"Saved final checkpoint.\")"
|
| 418 |
+
]
|
| 419 |
+
},
|
| 420 |
+
{
|
| 421 |
+
"cell_type": "markdown",
|
| 422 |
+
"metadata": {},
|
| 423 |
+
"source": [
|
| 424 |
+
"## 📊 Training Loss Curve"
|
| 425 |
+
]
|
| 426 |
+
},
|
| 427 |
+
{
|
| 428 |
+
"cell_type": "code",
|
| 429 |
+
"execution_count": null,
|
| 430 |
+
"metadata": {},
|
| 431 |
+
"outputs": [],
|
| 432 |
+
"source": [
|
| 433 |
+
"import matplotlib.pyplot as plt\n",
|
| 434 |
+
"import numpy as np\n",
|
| 435 |
+
"\n",
|
| 436 |
+
"if loss_history:\n",
|
| 437 |
+
" # Smooth the loss\n",
|
| 438 |
+
" window = min(100, len(loss_history) // 5 + 1)\n",
|
| 439 |
+
" smoothed = np.convolve(loss_history, np.ones(window)/window, mode='valid')\n",
|
| 440 |
+
" \n",
|
| 441 |
+
" fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))\n",
|
| 442 |
+
" \n",
|
| 443 |
+
" ax1.plot(loss_history, alpha=0.3, label='Raw')\n",
|
| 444 |
+
" ax1.plot(range(window-1, len(loss_history)), smoothed, label=f'Smoothed (w={window})')\n",
|
| 445 |
+
" ax1.set_xlabel('Step')\n",
|
| 446 |
+
" ax1.set_ylabel('Loss')\n",
|
| 447 |
+
" ax1.set_title('Training Loss')\n",
|
| 448 |
+
" ax1.legend()\n",
|
| 449 |
+
" ax1.grid(True, alpha=0.3)\n",
|
| 450 |
+
" \n",
|
| 451 |
+
" ax2.plot(loss_history[-min(1000, len(loss_history)):], alpha=0.5)\n",
|
| 452 |
+
" ax2.set_xlabel('Recent Steps')\n",
|
| 453 |
+
" ax2.set_ylabel('Loss')\n",
|
| 454 |
+
" ax2.set_title('Recent Loss (last 1000 steps)')\n",
|
| 455 |
+
" ax2.grid(True, alpha=0.3)\n",
|
| 456 |
+
" \n",
|
| 457 |
+
" plt.tight_layout()\n",
|
| 458 |
+
" plt.show()\n",
|
| 459 |
+
"else:\n",
|
| 460 |
+
" print(\"No training history yet.\")"
|
| 461 |
+
]
|
| 462 |
+
},
|
| 463 |
+
{
|
| 464 |
+
"cell_type": "markdown",
|
| 465 |
+
"metadata": {},
|
| 466 |
+
"source": [
|
| 467 |
+
"## 🎨 Generate Images"
|
| 468 |
+
]
|
| 469 |
+
},
|
| 470 |
+
{
|
| 471 |
+
"cell_type": "code",
|
| 472 |
+
"execution_count": null,
|
| 473 |
+
"metadata": {},
|
| 474 |
+
"outputs": [],
|
| 475 |
+
"source": [
|
| 476 |
+
"#@title Generation Settings {display-mode: \"form\"}\n",
|
| 477 |
+
"num_images = 8 #@param {type:\"integer\"}\n",
|
| 478 |
+
"sampling_steps = 50 #@param [25, 50, 100, 200] {type:\"integer\"}\n",
|
| 479 |
+
"use_ema_model = True #@param {type:\"boolean\"}\n",
|
| 480 |
+
"\n",
|
| 481 |
+
"print(f\"Generating {num_images} images with {sampling_steps} Euler steps...\")\n",
|
| 482 |
+
"samples = trainer.sample(\n",
|
| 483 |
+
" batch_size=num_images, image_size=image_size,\n",
|
| 484 |
+
" num_steps=sampling_steps, use_ema=use_ema_model,\n",
|
| 485 |
+
")\n",
|
| 486 |
+
"\n",
|
| 487 |
+
"# Display\n",
|
| 488 |
+
"ncols = min(4, num_images)\n",
|
| 489 |
+
"nrows = (num_images + ncols - 1) // ncols\n",
|
| 490 |
+
"fig, axes = plt.subplots(nrows, ncols, figsize=(4*ncols, 4*nrows))\n",
|
| 491 |
+
"if nrows == 1 and ncols == 1:\n",
|
| 492 |
+
" axes = [[axes]]\n",
|
| 493 |
+
"elif nrows == 1:\n",
|
| 494 |
+
" axes = [axes]\n",
|
| 495 |
+
"for i in range(num_images):\n",
|
| 496 |
+
" r, c = i // ncols, i % ncols\n",
|
| 497 |
+
" img = samples[i].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
|
| 498 |
+
" axes[r][c].imshow(np.clip(img, 0, 1))\n",
|
| 499 |
+
" axes[r][c].axis('off')\n",
|
| 500 |
+
"# Hide unused axes\n",
|
| 501 |
+
"for i in range(num_images, nrows * ncols):\n",
|
| 502 |
+
" r, c = i // ncols, i % ncols\n",
|
| 503 |
+
" axes[r][c].axis('off')\n",
|
| 504 |
+
"plt.suptitle(f'LiquidDiffusion Samples ({sampling_steps} steps, {\"EMA\" if use_ema_model else \"online\"})')\n",
|
| 505 |
+
"plt.tight_layout()\n",
|
| 506 |
+
"plt.show()\n",
|
| 507 |
+
"\n",
|
| 508 |
+
"# Save\n",
|
| 509 |
+
"grid = make_grid(samples * 0.5 + 0.5, nrow=ncols, padding=2)\n",
|
| 510 |
+
"save_image(grid, 'samples/generated.png')\n",
|
| 511 |
+
"print(\"Saved to samples/generated.png\")"
|
| 512 |
+
]
|
| 513 |
+
},
|
| 514 |
+
{
|
| 515 |
+
"cell_type": "markdown",
|
| 516 |
+
"metadata": {},
|
| 517 |
+
"source": [
|
| 518 |
+
"## 🔬 Visualize the Denoising Process"
|
| 519 |
+
]
|
| 520 |
+
},
|
| 521 |
+
{
|
| 522 |
+
"cell_type": "code",
|
| 523 |
+
"execution_count": null,
|
| 524 |
+
"metadata": {},
|
| 525 |
+
"outputs": [],
|
| 526 |
+
"source": [
|
| 527 |
+
"# Show step-by-step denoising\n",
|
| 528 |
+
"num_vis_steps = 10\n",
|
| 529 |
+
"total_euler_steps = 50\n",
|
| 530 |
+
"vis_interval = total_euler_steps // num_vis_steps\n",
|
| 531 |
+
"\n",
|
| 532 |
+
"model_vis = trainer.ema_model\n",
|
| 533 |
+
"model_vis.eval()\n",
|
| 534 |
+
"\n",
|
| 535 |
+
"z = torch.randn(1, 3, image_size, image_size, device=device)\n",
|
| 536 |
+
"dt = 1.0 / total_euler_steps\n",
|
| 537 |
+
"intermediates = [z.clone()]\n",
|
| 538 |
+
"\n",
|
| 539 |
+
"with torch.no_grad():\n",
|
| 540 |
+
" for i in range(total_euler_steps, 0, -1):\n",
|
| 541 |
+
" t = torch.full((1,), i / total_euler_steps, device=device)\n",
|
| 542 |
+
" v = model_vis(z, t)\n",
|
| 543 |
+
" z = z - v * dt\n",
|
| 544 |
+
" if (total_euler_steps - i + 1) % vis_interval == 0:\n",
|
| 545 |
+
" intermediates.append(z.clone())\n",
|
| 546 |
+
"\n",
|
| 547 |
+
"intermediates.append(z.clamp(-1, 1))\n",
|
| 548 |
+
"\n",
|
| 549 |
+
"fig, axes = plt.subplots(1, len(intermediates), figsize=(3*len(intermediates), 3))\n",
|
| 550 |
+
"for idx, (ax, img_t) in enumerate(zip(axes, intermediates)):\n",
|
| 551 |
+
" img = img_t[0].cpu().permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
|
| 552 |
+
" ax.imshow(np.clip(img, 0, 1))\n",
|
| 553 |
+
" ax.axis('off')\n",
|
| 554 |
+
" if idx == 0:\n",
|
| 555 |
+
" ax.set_title('Noise (t=1)')\n",
|
| 556 |
+
" elif idx == len(intermediates) - 1:\n",
|
| 557 |
+
" ax.set_title('Output (t=0)')\n",
|
| 558 |
+
" else:\n",
|
| 559 |
+
" ax.set_title(f't={1-idx*vis_interval/total_euler_steps:.1f}')\n",
|
| 560 |
+
"plt.suptitle('LiquidDiffusion Denoising Process')\n",
|
| 561 |
+
"plt.tight_layout()\n",
|
| 562 |
+
"plt.show()"
|
| 563 |
+
]
|
| 564 |
+
},
|
| 565 |
+
{
|
| 566 |
+
"cell_type": "markdown",
|
| 567 |
+
"metadata": {},
|
| 568 |
+
"source": [
|
| 569 |
+
"## 💾 Save & Export Model"
|
| 570 |
+
]
|
| 571 |
+
},
|
| 572 |
+
{
|
| 573 |
+
"cell_type": "code",
|
| 574 |
+
"execution_count": null,
|
| 575 |
+
"metadata": {},
|
| 576 |
+
"outputs": [],
|
| 577 |
+
"source": [
|
| 578 |
+
"# Save final checkpoint\n",
|
| 579 |
+
"trainer.save_checkpoint('checkpoints/final.pt', extra={\n",
|
| 580 |
+
" 'config': {\n",
|
| 581 |
+
" 'model_size': model_size,\n",
|
| 582 |
+
" 'image_size': image_size,\n",
|
| 583 |
+
" 'total_params': total_params,\n",
|
| 584 |
+
" 'training_steps': trainer.step,\n",
|
| 585 |
+
" 'dataset': dataset_name,\n",
|
| 586 |
+
" }\n",
|
| 587 |
+
"})\n",
|
| 588 |
+
"print(f\"Saved checkpoint: checkpoints/final.pt\")\n",
|
| 589 |
+
"print(f\"Model: liquid_diffusion_{model_size} ({total_params/1e6:.1f}M params)\")\n",
|
| 590 |
+
"print(f\"Trained for {trainer.step:,} steps on {dataset_name}\")"
|
| 591 |
+
]
|
| 592 |
+
},
|
| 593 |
+
{
|
| 594 |
+
"cell_type": "code",
|
| 595 |
+
"execution_count": null,
|
| 596 |
+
"metadata": {},
|
| 597 |
+
"outputs": [],
|
| 598 |
+
"source": [
|
| 599 |
+
"# Optional: Push to Hugging Face Hub\n",
|
| 600 |
+
"# Uncomment and fill in your details:\n",
|
| 601 |
+
"\n",
|
| 602 |
+
"# from huggingface_hub import HfApi, login\n",
|
| 603 |
+
"# login() # or use token\n",
|
| 604 |
+
"# api = HfApi()\n",
|
| 605 |
+
"# repo_id = \"your-username/liquid-diffusion-celebahq-256\" # change this\n",
|
| 606 |
+
"# api.create_repo(repo_id, exist_ok=True)\n",
|
| 607 |
+
"# api.upload_file('checkpoints/final.pt', 'model.pt', repo_id)\n",
|
| 608 |
+
"# api.upload_folder('liquid_diffusion/', 'liquid_diffusion/', repo_id)\n",
|
| 609 |
+
"# print(f\"Uploaded to https://huggingface.co/{repo_id}\")"
|
| 610 |
+
]
|
| 611 |
+
},
|
| 612 |
+
{
|
| 613 |
+
"cell_type": "markdown",
|
| 614 |
+
"metadata": {},
|
| 615 |
+
"source": [
|
| 616 |
+
"## 📚 Architecture Details & Theory\n",
|
| 617 |
+
"\n",
|
| 618 |
+
"### Why Liquid Neural Networks for Image Generation?\n",
|
| 619 |
+
"\n",
|
| 620 |
+
"**Liquid Time-Constant (LTC) Networks** (Hasani et al., 2020) define neurons with input-dependent time constants:\n",
|
| 621 |
+
"\n",
|
| 622 |
+
"```\n",
|
| 623 |
+
"dx/dt = -[1/τ + f(x,I,θ)] · x + f(x,I,θ) · A\n",
|
| 624 |
+
"```\n",
|
| 625 |
+
"\n",
|
| 626 |
+
"The system time constant `τ_sys = τ/(1 + τ·f)` adapts dynamically based on input — the neuron speeds up or slows down its response depending on what it sees. This is the \"liquid\" property.\n",
|
| 627 |
+
"\n",
|
| 628 |
+
"**CfC (Closed-form Continuous-depth)** networks (Hasani et al., 2022) solve this ODE in closed form:\n",
|
| 629 |
+
"\n",
|
| 630 |
+
"```\n",
|
| 631 |
+
"x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h\n",
|
| 632 |
+
"```\n",
|
| 633 |
+
"\n",
|
| 634 |
+
"This eliminates the ODE solver — making CfC **fully parallelizable** while preserving the adaptive time constant behavior.\n",
|
| 635 |
+
"\n",
|
| 636 |
+
"### Our Innovation: CfC × Diffusion Timestep\n",
|
| 637 |
+
"\n",
|
| 638 |
+
"In diffusion models, the network must process images at different noise levels `t ∈ [0,1]`. We observe that:\n",
|
| 639 |
+
"\n",
|
| 640 |
+
"1. CfC's time parameter `t` controls interpolation between two learned states\n",
|
| 641 |
+
"2. Diffusion's noise level `t` controls how the denoiser should behave\n",
|
| 642 |
+
"3. **These are the same concept** — the CfC time parameter IS the diffusion timestep\n",
|
| 643 |
+
"\n",
|
| 644 |
+
"This gives us:\n",
|
| 645 |
+
"- At `t≈0` (clean images): σ(-f·t)≈0.5, balanced processing for detail refinement\n",
|
| 646 |
+
"- At `t≈1` (noisy images): σ(-f·t) saturates, specialized denoising\n",
|
| 647 |
+
"- The gate `f` is **input-dependent** — different image content gets different time responses\n",
|
| 648 |
+
"\n",
|
| 649 |
+
"### References\n",
|
| 650 |
+
"\n",
|
| 651 |
+
"1. Hasani et al., \"Liquid Time-constant Networks\" (AAAI 2021) — arxiv:2006.04439\n",
|
| 652 |
+
"2. Hasani et al., \"Closed-form Continuous-time Neural Networks\" (Nature MI 2022) — arxiv:2106.13898\n",
|
| 653 |
+
"3. LiquidTAD: Parallel liquid relaxation — arxiv:2604.18274\n",
|
| 654 |
+
"4. USM: U-Shape Mamba for diffusion — arxiv:2504.13499\n",
|
| 655 |
+
"5. DiffuSSM: Diffusion without attention — arxiv:2311.18257\n",
|
| 656 |
+
"6. Liu et al., \"Flow Straight and Fast: Rectified Flow\" (ICLR 2023) — arxiv:2209.03003"
|
| 657 |
+
]
|
| 658 |
+
}
|
| 659 |
+
],
|
| 660 |
+
"metadata": {
|
| 661 |
+
"accelerator": "GPU",
|
| 662 |
+
"colab": {
|
| 663 |
+
"gpuType": "T4",
|
| 664 |
+
"provenance": [],
|
| 665 |
+
"toc_visible": true
|
| 666 |
+
},
|
| 667 |
+
"kernelspec": {
|
| 668 |
+
"display_name": "Python 3",
|
| 669 |
+
"name": "python3"
|
| 670 |
+
},
|
| 671 |
+
"language_info": {
|
| 672 |
+
"name": "python",
|
| 673 |
+
"version": "3.10.0"
|
| 674 |
+
}
|
| 675 |
},
|
| 676 |
+
"nbformat": 4,
|
| 677 |
+
"nbformat_minor": 0
|
| 678 |
}
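The CfC closed form quoted in the notebook's theory cell, `x(t) = σ(-f·t) ⊙ g + (1 - σ(-f·t)) ⊙ h`, can be checked numerically. The following is a minimal NumPy sketch, not the notebook's actual layer: `cfc_blend` and the arrays `f`, `g`, `h` are hypothetical stand-ins for the outputs of the three learned heads. It shows that the gate is exactly 0.5 at `t=0` and that it only saturates at `t=1` where `|f|` is large.

```python
import numpy as np


def sigmoid(z):
    """Logistic function, applied elementwise."""
    return 1.0 / (1.0 + np.exp(-z))


def cfc_blend(f, g, h, t):
    """Closed-form gate: x(t) = sigma(-f*t) * g + (1 - sigma(-f*t)) * h."""
    gate = sigmoid(-f * t)
    return gate * g + (1.0 - gate) * h


# Hypothetical stand-ins for the outputs of the three learned heads.
f = np.array([-50.0, -1.0, 1.0, 50.0])
g = np.ones(4)
h = -np.ones(4)

x_clean = cfc_blend(f, g, h, t=0.0)  # gate is 0.5 everywhere: mean of g and h
x_noisy = cfc_blend(f, g, h, t=1.0)  # gate depends on the sign and size of f
print(x_clean)
print(x_noisy)
```

With `g = 1` and `h = -1`, the `t=0` output is the midpoint of the two branches, while at `t=1` the entries with large `|f|` land almost entirely on one branch and the entries with small `|f|` stay in between.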
|
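The denoising visualization cell integrates from `t=1` (noise) to `t=0` (image) with the update `z = z - v * dt`. The sketch below reproduces that loop in NumPy under stated assumptions: `euler_sample`, `toy_velocity`, and `x0` are hypothetical names, and the toy velocity field stands in for the trained network. It assumes the straight-line path `z_t = (1 - t) * x0 + t * noise` used by rectified flow, for which the velocity `noise - x0` equals `(z_t - x0) / t`.

```python
import numpy as np


def euler_sample(velocity_fn, z, num_steps):
    """Integrate dz/dt = v(z, t) backwards from t=1 to t=0 with Euler steps."""
    dt = 1.0 / num_steps
    for i in range(num_steps, 0, -1):
        t = i / num_steps          # time at the start of this step
        z = z - velocity_fn(z, t) * dt
    return z


# Toy target playing the role of a clean image (an assumption for illustration).
x0 = np.array([0.25, -0.5, 0.75])


def toy_velocity(z, t):
    """Exact velocity for the straight path towards x0; replaces the network."""
    return (z - x0) / t


rng = np.random.default_rng(0)
noise = rng.standard_normal(x0.shape)
result = euler_sample(toy_velocity, noise, num_steps=50)
print(result)
```

Because the toy field is exact for a single target, the Euler loop lands on `x0` up to floating-point rounding. A trained network only approximates this field, so real samples depend on the number of steps.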