Add Colab training notebook
LiquidGen_Colab_Notebook.ipynb (ADDED, +411 -0)
@@ -0,0 +1,411 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# \ud83e\uddea LiquidGen: Liquid Neural Network Image Generator\n",
"\n",
"**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
"\n",
"### Key Features:\n",
"- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
"- **Fully Parallelizable** \u2014 No sequential ODE solving\n",
"- **Flow Matching** \u2014 Modern velocity-prediction training\n",
"- **Frozen Flux VAE** \u2014 16-channel latent space\n",
"- **Fits 16GB VRAM** \u2014 Designed for Colab free tier\n",
"\n",
"Based on: Liquid Time-constant Networks (NeurIPS 2020), CfC (Nature MI 2022), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udce6 Install Dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub Pillow"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udd27 Configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"MODEL_SIZE = \"small\"  # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\n",
"IMAGE_SIZE = 256  # 256 or 512\n",
"DATASET_NAME = \"huggan/wikiart\"\n",
"IMAGE_COLUMN = \"image\"\n",
"LABEL_COLUMN = \"style\"  # \"style\" (27 classes), \"genre\" (11 classes), \"\" for unconditional\n",
"NUM_CLASSES = 27\n",
"BATCH_SIZE = 8\n",
"GRADIENT_ACCUMULATION = 4\n",
"LEARNING_RATE = 1e-4\n",
"WEIGHT_DECAY = 0.01\n",
"MAX_GRAD_NORM = 2.0\n",
"NUM_EPOCHS = 50\n",
"WARMUP_STEPS = 500\n",
"EMA_DECAY = 0.9999\n",
"NUM_SAMPLE_STEPS = 50\n",
"CFG_SCALE = 2.0\n",
"OUTPUT_DIR = \"/content/liquidgen_outputs\"\n",
"SAVE_EVERY = 2000\n",
"SAMPLE_EVERY = 500\n",
"LOG_EVERY = 50\n",
"PUSH_TO_HUB = False\n",
"HUB_MODEL_ID = \"\"\n",
"VAE_ID = \"black-forest-labs/FLUX.1-schnell\"\n",
"VAE_SUBFOLDER = \"vae\"\n",
"\n",
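"# Illustrative sanity math (assumes the Flux VAE's 8x spatial compression and 16 latent channels):\n",
"LATENT_SIZE = IMAGE_SIZE // 8  # 256px -> 32x32 latents, 512px -> 64x64\n",
"print(f\"Each image becomes a 16x{LATENT_SIZE}x{LATENT_SIZE} latent\")\n",
"\n",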
"import torch\n",
"if torch.cuda.is_available():\n",
"    gpu = torch.cuda.get_device_name(0)\n",
"    mem = torch.cuda.get_device_properties(0).total_memory / 1024**3\n",
"    print(f\"GPU: {gpu} ({mem:.1f} GB)\")\n",
"    if mem < 12: print(\"\u26a0\ufe0f Low VRAM! Use small model, 256px, bs=4\")\n",
"    elif mem < 20: print(\"\u2705 T4 detected. Good for base model, 256px\")\n",
"    else: print(\"\ud83d\ude80 Large GPU! Can run large model, 512px\")\n",
"else:\n",
"    print(\"\u26a0\ufe0f No GPU! Go to Runtime \u2192 Change runtime type \u2192 GPU\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83c\udfd7\ufe0f Model Architecture"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\"\"\"\nLiquidGen: A Novel Liquid Neural Network Image Generation Model\n\nArchitecture Overview:\n- Frozen VAE encoder/decoder (FLUX.1-schnell, 16ch latent, 8x compression)\n- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)\n- Flow matching training objective (velocity prediction)\n\nKey Innovation: Replaces attention with Liquid Neural Network dynamics:\n- CfC-inspired closed-form update: x_new = \u03b1\u00b7x + (1-\u03b1)\u00b7h(x)\n- Per-channel learnable decay rates (liquid time constants)\n- Depthwise + pointwise convolutions for spatial context (no attention needed)\n- Zigzag spatial scanning for global receptive field\n- Gated stimulus with biologically-inspired sign constraints\n- U-Net style long skip connections from shallow to deep blocks\n\nMath Foundation (from Hasani et al., CfC paper):\n x_{t+1} = exp(-\u0394t/\u03c4_t) \u00b7 x_t + (1 - exp(-\u0394t/\u03c4_t)) \u00b7 h(x_t, u_t)\n \nOur parallelizable adaptation (inspired by LiquidTAD):\n \u03b1 = exp(-softplus(\u03c1)) [per-channel learnable decay]\n h = gate \u00b7 stimulus [gated depthwise conv output] \n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 h [liquid relaxation blend]\n\nThis removes the input-dependent \u03c4 (which requires sequential computation)\nand replaces it with a per-channel learned decay \u2014 making it fully parallel\nwhile preserving the liquid dynamics' ability to blend old state with new input.\n\nDesign for 16GB VRAM (Colab free tier):\n- VAE frozen: ~1GB\n- Backbone: ~55-280M params (~100-550MB in fp16) \n- Training overhead (grads + optimizer): ~3-8GB\n- Batch of latents: ~1-2GB\n- Total: fits comfortably in 16GB\n\nReferences:\n- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature Machine Intelligence 2022)\n- Lechner et al., \"Neural Circuit Policies\" (Nature Machine Intelligence 2020)\n- LiquidTAD (2025) - Parallelized liquid dynamics\n- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion\n- DiMSUM (NeurIPS 2024) - Attention-free diffusion\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Optional, Tuple\n\n\n# =============================================================================\n# Building Blocks\n# =============================================================================\n\nclass LiquidTimeConstant(nn.Module):\n \"\"\"\n Core liquid time-constant module.\n \n Implements the CfC closed-form dynamics in a fully parallelizable way:\n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 stimulus\n \n where \u03b1 = exp(-softplus(\u03c1)) is a learnable per-channel decay rate,\n derived from the liquid time constant \u03c4 = 1/softplus(\u03c1).\n \n This preserves the key property of Liquid Neural Networks:\n - Exponential relaxation toward a target (stimulus)\n - Rate controlled by \u03c4 (how fast to adapt)\n - No sequential ODE solving required\n \n Stability guarantee (from LTC Theorem 1):\n \u03c4_sys \u2208 [\u03c4/(1+\u03c4W), \u03c4] \u2014 time constants NEVER explode\n \"\"\"\n def __init__(self, channels: int):\n super().__init__()\n # \u03c1 parameterizes the decay: \u03bb = softplus(\u03c1), \u03b1 = exp(-\u03bb)\n # Initialize \u03c1=0 \u2192 \u03bb\u22480.693 \u2192 \u03b1\u22480.5 (equal blend of old and new)\n self.rho = nn.Parameter(torch.zeros(channels))\n \n def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:\n \"\"\"\n x: [B, C, H, W] - current state (residual 
path)\n stimulus: [B, C, H, W] - computed target from context\n returns: [B, C, H, W] - liquid-blended output\n \"\"\"\n lam = F.softplus(self.rho) + 1e-5\n alpha = torch.exp(-lam).view(1, -1, 1, 1)\n return alpha * x + (1.0 - alpha) * stimulus\n\n\nclass GatedDepthwiseStimulusConv(nn.Module):\n \"\"\"\n Computes the spatial stimulus using depthwise-separable convolutions\n with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs).\n \n This replaces attention for capturing local spatial context:\n - Depthwise conv: captures local spatial patterns per channel\n - Pointwise conv: mixes channel information\n - Sigmoid gate: controls information flow (like synaptic gating in NCP)\n \n Two parallel paths (inspired by NCP inter\u2192command split):\n 1. Stimulus path: DW-conv \u2192 PW-conv \u2192 GELU \u2192 project back\n 2. Gate path: DW-conv \u2192 PW-conv \u2192 sigmoid\n Output = stimulus * gate\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n \n self.stim_dw = nn.Conv2d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)\n self.stim_act = nn.GELU()\n self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)\n \n self.gate_dw = nn.Conv2d(channels, channels, kernel_size,\n padding=kernel_size // 2, groups=channels, bias=False)\n self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x))))\n gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))\n return stim * gate\n\n\nclass ChannelMixMLP(nn.Module):\n \"\"\"Channel mixing MLP with GELU activation (command neuron processing in NCP).\"\"\"\n def __init__(self, channels: int, expand_ratio: float = 4.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)\n self.act = nn.GELU()\n self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return self.fc2(self.act(self.fc1(x)))\n\n\nclass AdaptiveGroupNorm(nn.Module):\n \"\"\"\n Adaptive Group Normalization conditioned on timestep embedding.\n Applies: out = (1 + scale) * GroupNorm(x) + shift\n \"\"\"\n def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):\n super().__init__()\n self.norm = nn.GroupNorm(num_groups, channels, affine=False)\n self.proj = nn.Linear(cond_dim, channels * 2)\n nn.init.zeros_(self.proj.weight)\n nn.init.zeros_(self.proj.bias)\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm(x)\n params = self.proj(cond)\n scale, shift = params.chunk(2, dim=-1)\n return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1)\n\n\nclass ZigzagScan1D(nn.Module):\n \"\"\"\n 1D global mixing via zigzag-scanned depthwise conv.\n \n Gives quasi-global receptive field without attention's O(n\u00b2) cost.\n Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024).\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 31):\n super().__init__()\n self.conv1d = nn.Conv1d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.pw = nn.Conv1d(channels, channels, 1, bias=True)\n self.act = nn.GELU()\n \n def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n indices = []\n for i 
in range(H):\n row = list(range(i * W, (i + 1) * W))\n if i % 2 == 1:\n row = row[::-1]\n indices.extend(row)\n return torch.tensor(indices, device=device, dtype=torch.long)\n \n def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n fwd = self._zigzag_indices(H, W, device)\n inv = torch.empty_like(fwd)\n inv[fwd] = torch.arange(H * W, device=device)\n return inv\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n B, C, H, W = x.shape\n zz_idx = self._zigzag_indices(H, W, x.device)\n inv_idx = self._inverse_zigzag_indices(H, W, x.device)\n x_flat = x.reshape(B, C, H * W)\n x_zz = x_flat[:, :, zz_idx]\n x_mixed = self.pw(self.act(self.conv1d(x_zz)))\n x_restored = x_mixed[:, :, inv_idx]\n return x_restored.reshape(B, C, H, W)\n\n\n# =============================================================================\n# Liquid Block: The core building block\n# =============================================================================\n\nclass LiquidBlock(nn.Module):\n \"\"\"\n A single Liquid Neural Network block for image denoising.\n \n Architecture (maps to NCP hierarchy):\n 1. [SENSORY] AdaGN conditioning \u2192 spatial context extraction\n 2. [INTER] Zigzag 1D scan for global mixing\n 3. [COMMAND] Liquid time-constant blend (CfC dynamics)\n 4. [MOTOR] Channel mixing MLP for output projection\n \n All operations are fully parallelizable \u2014 no sequential dependencies.\n \"\"\"\n def __init__(\n self, channels: int, cond_dim: int, spatial_kernel: int = 7,\n scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,\n drop_rate: float = 0.0, use_zigzag: bool = True,\n ):\n super().__init__()\n self.norm1 = AdaptiveGroupNorm(channels, cond_dim)\n self.norm2 = AdaptiveGroupNorm(channels, cond_dim)\n self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)\n self.use_zigzag = use_zigzag\n if use_zigzag:\n self.zigzag = ZigzagScan1D(channels, scan_kernel)\n self.zigzag_gate = nn.Parameter(torch.zeros(1))\n self.liquid = LiquidTimeConstant(channels)\n self.channel_mix = ChannelMixMLP(channels, mlp_ratio)\n self.liquid2 = LiquidTimeConstant(channels)\n self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm1(x, cond)\n stim = self.spatial_stim(h)\n if self.use_zigzag:\n zz = self.zigzag(h)\n stim = stim + torch.sigmoid(self.zigzag_gate) * zz\n stim = self.drop(stim)\n x = self.liquid(x, stim)\n h2 = self.norm2(x, cond)\n ch_out = self.drop(self.channel_mix(h2))\n x = self.liquid2(x, ch_out)\n return x\n\n\n# =============================================================================\n# Timestep and Class Embeddings\n# =============================================================================\n\nclass TimestepEmbedding(nn.Module):\n \"\"\"Sinusoidal timestep embedding followed by MLP projection.\"\"\"\n def __init__(self, dim: int, freq_dim: int = 256):\n super().__init__()\n self.freq_dim = freq_dim\n self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n \n def forward(self, t: torch.Tensor) -> torch.Tensor:\n half = self.freq_dim // 2\n freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t.unsqueeze(-1) * freqs.unsqueeze(0)\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n return self.mlp(emb)\n\n\nclass ClassEmbedding(nn.Module):\n \"\"\"Optional class-conditional embedding with CFG null 
embedding.\"\"\"\n def __init__(self, num_classes: int, dim: int):\n super().__init__()\n self.embed = nn.Embedding(num_classes, dim)\n self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)\n \n def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:\n emb = self.embed(labels)\n if self.training and drop_prob > 0:\n mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob\n emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb)\n return emb\n\n\n# =============================================================================\n# LiquidGen: Full Model\n# =============================================================================\n\nclass LiquidGen(nn.Module):\n \"\"\"\n LiquidGen: Liquid Neural Network Image Generator\n \n A novel attention-free diffusion model that uses Liquid Neural Network\n dynamics (CfC closed-form continuous-depth) for image generation.\n \n Features:\n - NO self-attention anywhere \u2014 O(n) complexity\n - NO sequential ODE solving \u2014 fully parallelizable\n - Liquid time constants for adaptive information blending\n - Zigzag scanning for global context\n - Depthwise convolutions for local spatial structure\n - Gated stimulus (biologically-inspired from NCP)\n - U-Net long skip connections (from U-ViT/DiM)\n \n Config Presets:\n - LiquidGen-S: ~55M params (256px, fast training)\n - LiquidGen-B: ~140M params (256/512px, balanced)\n - LiquidGen-L: ~280M params (512px, high quality)\n \"\"\"\n \n def __init__(\n self,\n in_channels: int = 16,\n patch_size: int = 2,\n embed_dim: int = 512,\n depth: int = 16,\n spatial_kernel: int = 7,\n scan_kernel: int = 31,\n expand_ratio: float = 2.0,\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n num_classes: int = 0,\n class_drop_prob: float = 0.1,\n use_zigzag: bool = True,\n ):\n super().__init__()\n self.in_channels = in_channels\n self.patch_size = patch_size\n self.embed_dim = embed_dim\n self.depth = depth\n self.num_classes = num_classes\n self.class_drop_prob = class_drop_prob\n \n cond_dim = embed_dim\n \n self.time_embed = TimestepEmbedding(cond_dim)\n self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None\n \n self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)\n \n self.pos_embed_size = 32\n self.pos_embed = nn.Parameter(\n torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02\n )\n \n self.input_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),\n nn.Conv2d(embed_dim, embed_dim, 1, bias=True),\n nn.GELU(),\n )\n \n self.blocks = nn.ModuleList([\n LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,\n expand_ratio, mlp_ratio, drop_rate, use_zigzag)\n for _ in range(depth)\n ])\n \n self.final_norm = nn.GroupNorm(32, embed_dim)\n self.final_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),\n nn.GELU(),\n )\n \n self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)\n nn.init.zeros_(self.unpatch.weight)\n nn.init.zeros_(self.unpatch.bias)\n \n self.apply(self._init_weights)\n \n def _init_weights(self, m):\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Linear):\n nn.init.xavier_uniform_(m.weight)\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Embedding):\n nn.init.normal_(m.weight, 
std=0.02)\n \n def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:\n if H == self.pos_embed_size and W == self.pos_embed_size:\n return self.pos_embed\n return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)\n \n def forward(\n self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Predict velocity field for flow matching.\n Args:\n x: [B, C, H, W] noisy latent (C=16 for Flux VAE)\n t: [B] timestep in [0, 1]\n class_labels: [B] optional class labels\n Returns:\n v: [B, C, H, W] predicted velocity\n \"\"\"\n cond = self.time_embed(t)\n if self.class_embed is not None and class_labels is not None:\n drop_p = self.class_drop_prob if self.training else 0.0\n cond = cond + self.class_embed(class_labels, drop_prob=drop_p)\n \n h = self.patch_embed(x)\n B, C, H_p, W_p = h.shape\n h = h + self._interpolate_pos_embed(H_p, W_p)\n h = self.input_proj(h)\n \n # U-Net style long skip connections\n skip_connections = []\n mid = self.depth // 2\n for i, block in enumerate(self.blocks):\n if i < mid:\n skip_connections.append(h)\n elif i >= mid and len(skip_connections) > 0:\n skip = skip_connections.pop()\n h = h + skip\n h = block(h, cond)\n \n h = self.final_norm(h)\n h = self.final_proj(h)\n v = self.unpatch(h)\n return v\n \n def count_params(self) -> int:\n return sum(p.numel() for p in self.parameters() if p.requires_grad)\n\n\n# =============================================================================\n# Model Presets\n# =============================================================================\n\ndef liquidgen_small(**kwargs) -> LiquidGen:\n \"\"\"~55M params - for 256px, fast training/testing\"\"\"\n defaults = dict(\n embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_base(**kwargs) -> LiquidGen:\n \"\"\"~140M params - for 256/512px, balanced (fits T4 16GB easily)\"\"\"\n defaults = dict(\n embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_large(**kwargs) -> LiquidGen:\n \"\"\"~280M params - for 512px, high quality (fits T4 16GB with small batch)\"\"\"\n defaults = dict(\n embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\n\nif __name__ == \"__main__\":\n device = \"cpu\"\n for name, factory in [(\"Small\", liquidgen_small), (\"Base\", liquidgen_base), (\"Large\", liquidgen_large)]:\n model = factory(num_classes=27).to(device)\n print(f\"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params\")\n \n x = torch.randn(2, 16, 32, 32, device=device)\n t = torch.rand(2, device=device)\n labels = torch.randint(0, 27, (2,), device=device)\n v = model(x, t, labels)\n assert v.shape == x.shape\n \n x512 = torch.randn(1, 16, 64, 64, device=device)\n v512 = model(x512, t[:1], labels[:1])\n assert v512.shape == x512.shape\n print(f\" 256px \u2705 512px \u2705\")\n del model\n \n print(\"\\n\u2705 All tests passed!\")\n"
]
},
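{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check (added for clarity, not part of the model): the liquid blend\n",
"# out = alpha*x + (1 - alpha)*h, with alpha = exp(-softplus(rho)), interpolates state and stimulus.\n",
"_ltc = LiquidTimeConstant(4)\n",
"_x, _h = torch.zeros(1, 4, 2, 2), torch.ones(1, 4, 2, 2)\n",
"print(_ltc(_x, _h).flatten()[0].item())  # ~0.5 at init: rho=0 -> alpha = exp(-ln 2) ~ 0.5\n"
]
},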
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udd04 Training Utilities"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os, json, time, math\n",
"import numpy as np\n",
"from torch.utils.data import DataLoader, Dataset\n",
"from torch.amp import autocast, GradScaler\n",
"from torchvision import transforms\n",
"from torchvision.utils import save_image\n",
"from PIL import Image\n",
"\n",
"class FlowMatchingScheduler:\n",
"    def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n",
"    def sample_timesteps(self, bs, dev): return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t\n",
"    def add_noise(self, x0, noise, t): t = t.view(-1,1,1,1); return (1-t)*x0 + t*noise\n",
"    def get_velocity_target(self, x0, noise): return noise - x0\n",
"    @torch.no_grad()\n",
"    def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):\n",
"        model.eval(); x = torch.randn(shape, device=dev)\n",
"        dt = 1.0 / num_steps\n",
"        for t_val in torch.linspace(1.0, dt, num_steps, device=dev):\n",
"            t = torch.full((shape[0],), t_val.item(), device=dev)\n",
"            with torch.amp.autocast(\"cuda\"):\n",
"                if cfg > 1.0 and labels is not None:\n",
"                    # CFG: class 0 serves as a stand-in unconditional branch here\n",
"                    vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n",
"                    v = vu + cfg * (vc - vu)\n",
"                else: v = model(x, t, labels)\n",
"            x = x - dt * v.float()\n",
"        return x\n",
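"\n",
"# Flow matching recap (matches the code above): x_t = (1-t)*x0 + t*eps and target v = eps - x0,\n",
"# so the Euler update x <- x - dt*v walks from noise (t=1) back to data (t=0). Quick shape check:\n",
"_fm = FlowMatchingScheduler(); _x0 = torch.randn(2, 16, 4, 4); _eps = torch.randn_like(_x0)\n",
"_t = _fm.sample_timesteps(2, \"cpu\")\n",
"assert _fm.add_noise(_x0, _eps, _t).shape == _fm.get_velocity_target(_x0, _eps).shape == _x0.shape\n",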
"\n",
"class EMAModel:\n",
"    def __init__(self, model, decay=0.9999):\n",
"        self.decay = decay\n",
"        self.shadow = {n: p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n",
"    @torch.no_grad()\n",
"    def update(self, model):\n",
"        for n,p in model.named_parameters():\n",
"            if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n",
"    def apply(self, model):\n",
"        self.backup = {n: p.data.clone() for n,p in model.named_parameters() if p.requires_grad}\n",
"        for n,p in model.named_parameters():\n",
"            if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n",
"    def restore(self, model):\n",
"        for n,p in model.named_parameters():\n",
"            if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n])\n",
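"\n",
"# EMA rule: shadow = decay*shadow + (1-decay)*param after each optimizer step.\n",
"# Usage pattern (as in the cells below): ema.apply(model) -> sample -> ema.restore(model).\n",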
"\n",
"class ImageDataset(Dataset):\n",
"    def __init__(self, ds, tf, img_col, lbl_col=\"\"): self.ds, self.tf, self.ic, self.lc = ds, tf, img_col, lbl_col\n",
"    def __len__(self): return len(self.ds)\n",
"    def __getitem__(self, i):\n",
"        item = self.ds[i]; img = item[self.ic]\n",
"        if img.mode != \"RGB\": img = img.convert(\"RGB\")\n",
"        label = item[self.lc] if self.lc and self.lc in item else -1\n",
"        return self.tf(img), label\n",
"\n",
"def cosine_sched(opt, warmup, total):\n",
"    def lr(s):\n",
"        if s < warmup: return s / max(1, warmup)\n",
"        return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n",
"    return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n",
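"# lr multiplier ramps 0 -> 1 over `warmup` steps, then cosine-decays toward 0 at `total` steps\n",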
"\n",
"MODEL_CONFIGS = {\n",
"    \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n",
"    \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n",
"    \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n",
"}\n",
"print(\"\u2705 Training utilities ready!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udcca Load Dataset & VAE"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from diffusers import AutoencoderKL\n",
"\n",
"print(f\"Loading dataset: {DATASET_NAME}...\")\n",
"dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
"print(f\" {len(dataset)} images\")\n",
"\n",
"transform = transforms.Compose([\n",
"    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
"    transforms.CenterCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip(), transforms.ToTensor(),\n",
"])\n",
"train_ds = ImageDataset(dataset, transform, IMAGE_COLUMN, LABEL_COLUMN)\n",
"train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n",
"\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"vae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\n",
"for p in vae.parameters(): p.requires_grad_(False)\n",
"print(f\"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\n",
"SCALE, SHIFT = 0.3611, 0.1159\n",
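"# 0.3611 / 0.1159 are the Flux VAE's scaling_factor / shift_factor: latents are normalized as\n",
"# z_norm = (z - SHIFT) * SCALE for training and inverted via z = z_norm / SCALE + SHIFT before decoding.\n",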
"print(\"\u2705 Ready!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83c\udfcb\ufe0f Create Model & Train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\n",
"cfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\n",
"model = LiquidGen(**cfg).to(device)\n",
"print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n",
"\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n",
"total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION\n",
"scheduler = cosine_sched(optimizer, WARMUP_STEPS, total_steps)\n",
"ema = EMAModel(model, EMA_DECAY)\n",
"scaler = GradScaler(\"cuda\")\n",
"fm = FlowMatchingScheduler()\n",
"os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n",
"os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n",
"print(f\"Total steps: {total_steps}, Effective batch: {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n",
"if torch.cuda.is_available(): print(f\"VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB used\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"global_step = 0; loss_accum = 0.0; log_losses = []\n",
"print(\"\ud83d\ude80 Training!\\n\")\n",
"t0 = time.time()\n",
"for epoch in range(NUM_EPOCHS):\n",
"    model.train(); ep_loss = 0; ep_steps = 0; ep_t = time.time()\n",
"    for bi, (imgs, lbls) in enumerate(train_loader):\n",
"        imgs = imgs.to(device)\n",
"        lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n",
"        with torch.no_grad():\n",
"            lats = vae.encode(imgs.half()*2-1).latent_dist.sample()  # pixels [0,1] -> [-1,1] before encoding\n",
"            lats = ((lats - SHIFT) * SCALE).float()\n",
"        t = fm.sample_timesteps(lats.shape[0], device)\n",
"        noise = torch.randn_like(lats)\n",
"        xt = fm.add_noise(lats, noise, t)\n",
"        vtgt = fm.get_velocity_target(lats, noise)\n",
"        with autocast(\"cuda\"): loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n",
"        scaler.scale(loss).backward()\n",
"        loss_accum += loss.item()\n",
"        if (bi+1) % GRADIENT_ACCUMULATION == 0:\n",
"            scaler.unscale_(optimizer)\n",
"            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
"            scaler.step(optimizer); scaler.update(); optimizer.zero_grad(); scheduler.step()\n",
"            ema.update(model); global_step += 1\n",
"            if global_step % LOG_EVERY == 0:\n",
"                al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n",
"                vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n",
"                print(f\"step={global_step:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G\")\n",
"                log_losses.append(al); loss_accum = 0\n",
"                if math.isnan(al) or al > 50: print(\"\ud83d\udca5 Diverged!\"); break  # exits this epoch's batch loop\n",
"            if global_step % SAMPLE_EVERY == 0:\n",
"                ema.apply(model); model.eval()\n",
"                ls = IMAGE_SIZE // 8\n",
"                sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n",
"                samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n",
"                with torch.no_grad(): si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
"                save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n",
"                ema.restore(model); model.train()\n",
"            if global_step % SAVE_EVERY == 0:\n",
"                torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n",
"                           f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n",
"        ep_loss += loss.item()*GRADIENT_ACCUMULATION; ep_steps += 1\n",
"    print(f\"Epoch {epoch} | loss={ep_loss/max(ep_steps,1):.4f} | {time.time()-ep_t:.0f}s\")\n",
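"\n",
"# Resuming later (sketch; matches the checkpoint dict saved below):\n",
"#   ckpt = torch.load(f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n",
"#   model.load_state_dict(ckpt[\"model\"]); ema.shadow = ckpt[\"ema\"]\n",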
"torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n",
"print(f\"\ud83c\udf89 Done! {global_step} steps in {(time.time()-t0)/60:.1f} min\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udcc8 Training Loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"if log_losses:\n",
"    plt.figure(figsize=(10,4)); plt.plot(log_losses); plt.xlabel(f\"Steps (\u00d7{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n",
"    plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3); plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n",
"    print(f\"Min loss: {min(log_losses):.4f}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83c\udfa8 Generate Images"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ema.apply(model); model.eval()\n",
"N, STEPS, CFG = 8, 50, 2.5\n",
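"# CFG here deliberately overrides CFG_SCALE from the config cell; guidance is applied as\n",
"# v = v_uncond + CFG*(v_cond - v_uncond) inside FlowMatchingScheduler.sample.\n",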
"ls = IMAGE_SIZE // 8\n",
"STYLES = [\"Abstract Expressionism\",\"Baroque\",\"Cubism\",\"Expressionism\",\"Impressionism\",\n",
"          \"Pop Art\",\"Realism\",\"Romanticism\",\"Symbolism\",\"Ukiyo-e\"]\n",
"if NUM_CLASSES > 0:\n",
"    for ci in range(min(NUM_CLASSES, 8)):\n",
"        l = torch.full((N,), ci, device=device, dtype=torch.long)\n",
"        s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, CFG)\n",
"        with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
"        nm = STYLES[ci] if ci < len(STYLES) else f\"Class_{ci}\"\n",
"        save_image(i, f\"{OUTPUT_DIR}/gen_{nm.replace(' ', '_')}.png\", nrow=4)\n",
"        print(f\"Generated: {nm}\")\n",
"else:\n",
"    s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n",
"    with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
"    save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\n",
"ema.restore(model)\n",
"print(f\"\u2705 Saved to {OUTPUT_DIR}/\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udce4 Display Results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import display\n",
"import glob\n",
"for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n",
"    print(os.path.basename(f)); display(Image.open(f))\n",
"for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n",
"    print(os.path.basename(f)); display(Image.open(f))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## \ud83d\udcdd Architecture Reference\n",
"\n",
"### Core Equation (CfC Liquid Dynamics)\n",
"\n",
"$x_{t+1} = \\alpha \\cdot x_t + (1 - \\alpha) \\cdot h(x_t)$, with per-channel decay $\\alpha = e^{-\\mathrm{softplus}(\\rho)}$\n",
"\n",
"### Flow Matching\n",
"\n",
"$x_t = (1-t)\\,x_0 + t\\,\\epsilon$, velocity target $v = \\epsilon - x_0$, trained with MSE on $v_\\theta(x_t, t)$\n",
"\n",
"### Sampling (Euler ODE)\n",
"\n",
"$x \\leftarrow x - \\Delta t \\cdot v_\\theta(x, t)$, integrating from $t=1$ (noise) to $t=0$ (data)\n",
"\n",
"### References\n",
"- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n",
"- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature MI 2022)\n",
"- Lechner et al., \"Neural Circuit Policies\" (Nature MI 2020)\n",
"- ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n",
"- Lipman et al., \"Flow Matching\" (2023), SiT (2024)\n"
]
}
]
}