asdf98 committed
Commit 4ad2cc3 · verified · 1 Parent(s): c4858e4

Add Colab training notebook

Files changed (1)
  1. LiquidGen_Colab_Notebook.ipynb +411 -0
LiquidGen_Colab_Notebook.ipynb ADDED
@@ -0,0 +1,411 @@
+ {
+ "nbformat": 4,
+ "nbformat_minor": 0,
+ "metadata": {
+ "colab": {
+ "provenance": [],
+ "gpuType": "T4"
+ },
+ "kernelspec": {
+ "name": "python3",
+ "display_name": "Python 3"
+ },
+ "accelerator": "GPU"
+ },
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# \ud83e\uddea LiquidGen: Liquid Neural Network Image Generator\n",
+ "\n",
+ "**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
+ "\n",
+ "### Key Features:\n",
+ "- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
+ "- **Fully Parallelizable** \u2014 No sequential ODE solving\n",
+ "- **Flow Matching** \u2014 Modern velocity-prediction training\n",
+ "- **Frozen Flux VAE** \u2014 16-channel latent space\n",
+ "- **Fits 16GB VRAM** \u2014 Designed for Colab free tier\n",
+ "\n",
+ "Based on: Liquid Time-constant Networks (NeurIPS 2020), CfC (Nature MI 2022), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udce6 Install Dependencies"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub Pillow"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udd27 Configuration"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "MODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\n",
+ "IMAGE_SIZE = 256 # 256 or 512\n",
+ "DATASET_NAME = \"huggan/wikiart\"\n",
+ "IMAGE_COLUMN = \"image\"\n",
+ "LABEL_COLUMN = \"style\" # \"style\" (27), \"genre\" (11), \"\" for unconditional\n",
+ "NUM_CLASSES = 27\n",
+ "BATCH_SIZE = 8\n",
+ "GRADIENT_ACCUMULATION = 4\n",
+ "LEARNING_RATE = 1e-4\n",
+ "WEIGHT_DECAY = 0.01\n",
+ "MAX_GRAD_NORM = 2.0\n",
+ "NUM_EPOCHS = 50\n",
+ "WARMUP_STEPS = 500\n",
+ "EMA_DECAY = 0.9999\n",
+ "NUM_SAMPLE_STEPS = 50\n",
+ "CFG_SCALE = 2.0\n",
+ "OUTPUT_DIR = \"/content/liquidgen_outputs\"\n",
+ "SAVE_EVERY = 2000\n",
+ "SAMPLE_EVERY = 500\n",
+ "LOG_EVERY = 50\n",
+ "PUSH_TO_HUB = False\n",
+ "HUB_MODEL_ID = \"\"\n",
+ "VAE_ID = \"black-forest-labs/FLUX.1-schnell\"\n",
+ "VAE_SUBFOLDER = \"vae\"\n",
+ "\n",
+ "import torch\n",
+ "if torch.cuda.is_available():\n",
+ " gpu = torch.cuda.get_device_name(0)\n",
+ " mem = torch.cuda.get_device_properties(0).total_memory / 1024**3\n",
+ " print(f\"GPU: {gpu} ({mem:.1f} GB)\")\n",
+ " if mem < 12: print(\"\u26a0\ufe0f Low VRAM! Use small model, 256px, bs=4\")\n",
+ " elif mem < 20: print(\"\u2705 T4 detected. Good for base model, 256px\")\n",
+ " else: print(\"\ud83d\ude80 Large GPU! Can run large model, 512px\")\n",
+ "else:\n",
+ " print(\"\u26a0\ufe0f No GPU! Go to Runtime \u2192 Change runtime type \u2192 GPU\")"
+ ]
+ },
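The configuration above implies an effective batch of 32 and a step budget that depends on the dataset size. A minimal sketch of that arithmetic, mirroring the formulas the later training cells use (`num_images` is a placeholder; the real count comes from `len(dataset)` in the dataset cell):

```python
# Sanity-check sketch: derive the effective batch size and optimizer-step
# counts implied by the configuration above. `num_images` is hypothetical;
# the real value is len(dataset) once the dataset cell has run.
BATCH_SIZE = 8
GRADIENT_ACCUMULATION = 4
NUM_EPOCHS = 50

num_images = 80_000  # placeholder dataset size, for illustration only

effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION          # samples per optimizer step
batches_per_epoch = num_images // BATCH_SIZE                  # DataLoader batches (drop_last=True)
steps_per_epoch = batches_per_epoch // GRADIENT_ACCUMULATION  # optimizer steps per epoch
total_steps = steps_per_epoch * NUM_EPOCHS

print(f"effective batch: {effective_batch}")
print(f"optimizer steps/epoch: {steps_per_epoch}, total: {total_steps}")
```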
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83c\udfd7\ufe0f Model Architecture"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "\"\"\"\nLiquidGen: A Novel Liquid Neural Network Image Generation Model\n\nArchitecture Overview:\n- Frozen VAE encoder/decoder (FLUX.1-schnell, 16ch latent, 8x compression)\n- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)\n- Flow matching training objective (velocity prediction)\n\nKey Innovation: Replaces attention with Liquid Neural Network dynamics:\n- CfC-inspired closed-form update: x_new = \u03b1\u00b7x + (1-\u03b1)\u00b7h(x)\n- Per-channel learnable decay rates (liquid time constants)\n- Depthwise + pointwise convolutions for spatial context (no attention needed)\n- Zigzag spatial scanning for global receptive field\n- Gated stimulus with biologically-inspired sign constraints\n- U-Net style long skip connections from shallow to deep blocks\n\nMath Foundation (from Hasani et al., CfC paper):\n x_{t+1} = exp(-\u0394t/\u03c4_t) \u00b7 x_t + (1 - exp(-\u0394t/\u03c4_t)) \u00b7 h(x_t, u_t)\n \nOur parallelizable adaptation (inspired by LiquidTAD):\n \u03b1 = exp(-softplus(\u03c1)) [per-channel learnable decay]\n h = gate \u00b7 stimulus [gated depthwise conv output] \n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 h [liquid relaxation blend]\n\nThis removes the input-dependent \u03c4 (which requires sequential computation)\nand replaces it with a per-channel learned decay \u2014 making it fully parallel\nwhile preserving the liquid dynamics' ability to blend old state with new input.\n\nDesign for 16GB VRAM (Colab free tier):\n- VAE frozen: ~1GB\n- Backbone: ~55-280M params (~100-550MB in fp16) \n- Training overhead (grads + optimizer): ~3-8GB\n- Batch of latents: ~1-2GB\n- Total: fits comfortably in 16GB\n\nReferences:\n- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature Machine Intelligence 2022)\n- Lechner et al., \"Neural Circuit Policies\" (Nature Machine Intelligence 2020)\n- LiquidTAD (2025) - Parallelized liquid dynamics\n- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion\n- DiMSUM (NeurIPS 2024) - Attention-free diffusion\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Optional, Tuple\n\n\n# =============================================================================\n# Building Blocks\n# =============================================================================\n\nclass LiquidTimeConstant(nn.Module):\n \"\"\"\n Core liquid time-constant module.\n \n Implements the CfC closed-form dynamics in a fully parallelizable way:\n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 stimulus\n \n where \u03b1 = exp(-softplus(\u03c1)) is a learnable per-channel decay rate,\n derived from the liquid time constant \u03c4 = 1/softplus(\u03c1).\n \n This preserves the key property of Liquid Neural Networks:\n - Exponential relaxation toward a target (stimulus)\n - Rate controlled by \u03c4 (how fast to adapt)\n - No sequential ODE solving required\n \n Stability guarantee (from LTC Theorem 1):\n \u03c4_sys \u2208 [\u03c4/(1+\u03c4W), \u03c4] \u2014 time constants NEVER explode\n \"\"\"\n def __init__(self, channels: int):\n super().__init__()\n # \u03c1 parameterizes the decay: \u03bb = softplus(\u03c1), \u03b1 = exp(-\u03bb)\n # Initialize \u03c1=0 \u2192 \u03bb\u22480.693 \u2192 \u03b1\u22480.5 (equal blend of old and new)\n self.rho = nn.Parameter(torch.zeros(channels))\n \n def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:\n \"\"\"\n x: [B, C, H, W] - current state (residual path)\n stimulus: [B, C, H, W] - computed target from context\n returns: [B, C, H, W] - liquid-blended output\n \"\"\"\n lam = F.softplus(self.rho) + 1e-5\n alpha = torch.exp(-lam).view(1, -1, 1, 1)\n return alpha * x + (1.0 - alpha) * stimulus\n\n\nclass GatedDepthwiseStimulusConv(nn.Module):\n \"\"\"\n Computes the spatial stimulus using depthwise-separable convolutions\n with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs).\n \n This replaces attention for capturing local spatial context:\n - Depthwise conv: captures local spatial patterns per channel\n - Pointwise conv: mixes channel information\n - Sigmoid gate: controls information flow (like synaptic gating in NCP)\n \n Two parallel paths (inspired by NCP inter\u2192command split):\n 1. Stimulus path: DW-conv \u2192 PW-conv \u2192 GELU \u2192 project back\n 2. Gate path: DW-conv \u2192 PW-conv \u2192 sigmoid\n Output = stimulus * gate\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n \n self.stim_dw = nn.Conv2d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)\n self.stim_act = nn.GELU()\n self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)\n \n self.gate_dw = nn.Conv2d(channels, channels, kernel_size,\n padding=kernel_size // 2, groups=channels, bias=False)\n self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x))))\n gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))\n return stim * gate\n\n\nclass ChannelMixMLP(nn.Module):\n \"\"\"Channel mixing MLP with GELU activation (command neuron processing in NCP).\"\"\"\n def __init__(self, channels: int, expand_ratio: float = 4.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)\n self.act = nn.GELU()\n self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return self.fc2(self.act(self.fc1(x)))\n\n\nclass AdaptiveGroupNorm(nn.Module):\n \"\"\"\n Adaptive Group Normalization conditioned on timestep embedding.\n Applies: out = (1 + scale) * GroupNorm(x) + shift\n \"\"\"\n def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):\n super().__init__()\n self.norm = nn.GroupNorm(num_groups, channels, affine=False)\n self.proj = nn.Linear(cond_dim, channels * 2)\n nn.init.zeros_(self.proj.weight)\n nn.init.zeros_(self.proj.bias)\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm(x)\n params = self.proj(cond)\n scale, shift = params.chunk(2, dim=-1)\n return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1)\n\n\nclass ZigzagScan1D(nn.Module):\n \"\"\"\n 1D global mixing via zigzag-scanned depthwise conv.\n \n Gives quasi-global receptive field without attention's O(n\u00b2) cost.\n Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024).\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 31):\n super().__init__()\n self.conv1d = nn.Conv1d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.pw = nn.Conv1d(channels, channels, 1, bias=True)\n self.act = nn.GELU()\n \n def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n indices = []\n for i in range(H):\n row = list(range(i * W, (i + 1) * W))\n if i % 2 == 1:\n row = row[::-1]\n indices.extend(row)\n return torch.tensor(indices, device=device, dtype=torch.long)\n \n def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n fwd = self._zigzag_indices(H, W, device)\n inv = torch.empty_like(fwd)\n inv[fwd] = torch.arange(H * W, device=device)\n return inv\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n B, C, H, W = x.shape\n zz_idx = self._zigzag_indices(H, W, x.device)\n inv_idx = self._inverse_zigzag_indices(H, W, x.device)\n x_flat = x.reshape(B, C, H * W)\n x_zz = x_flat[:, :, zz_idx]\n x_mixed = self.pw(self.act(self.conv1d(x_zz)))\n x_restored = x_mixed[:, :, inv_idx]\n return x_restored.reshape(B, C, H, W)\n\n\n# =============================================================================\n# Liquid Block: The core building block\n# =============================================================================\n\nclass LiquidBlock(nn.Module):\n \"\"\"\n A single Liquid Neural Network block for image denoising.\n \n Architecture (maps to NCP hierarchy):\n 1. [SENSORY] AdaGN conditioning \u2192 spatial context extraction\n 2. [INTER] Zigzag 1D scan for global mixing\n 3. [COMMAND] Liquid time-constant blend (CfC dynamics)\n 4. [MOTOR] Channel mixing MLP for output projection\n \n All operations are fully parallelizable \u2014 no sequential dependencies.\n \"\"\"\n def __init__(\n self, channels: int, cond_dim: int, spatial_kernel: int = 7,\n scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,\n drop_rate: float = 0.0, use_zigzag: bool = True,\n ):\n super().__init__()\n self.norm1 = AdaptiveGroupNorm(channels, cond_dim)\n self.norm2 = AdaptiveGroupNorm(channels, cond_dim)\n self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)\n self.use_zigzag = use_zigzag\n if use_zigzag:\n self.zigzag = ZigzagScan1D(channels, scan_kernel)\n self.zigzag_gate = nn.Parameter(torch.zeros(1))\n self.liquid = LiquidTimeConstant(channels)\n self.channel_mix = ChannelMixMLP(channels, mlp_ratio)\n self.liquid2 = LiquidTimeConstant(channels)\n self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm1(x, cond)\n stim = self.spatial_stim(h)\n if self.use_zigzag:\n zz = self.zigzag(h)\n stim = stim + torch.sigmoid(self.zigzag_gate) * zz\n stim = self.drop(stim)\n x = self.liquid(x, stim)\n h2 = self.norm2(x, cond)\n ch_out = self.drop(self.channel_mix(h2))\n x = self.liquid2(x, ch_out)\n return x\n\n\n# =============================================================================\n# Timestep and Class Embeddings\n# =============================================================================\n\nclass TimestepEmbedding(nn.Module):\n \"\"\"Sinusoidal timestep embedding followed by MLP projection.\"\"\"\n def __init__(self, dim: int, freq_dim: int = 256):\n super().__init__()\n self.freq_dim = freq_dim\n self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n \n def forward(self, t: torch.Tensor) -> torch.Tensor:\n half = self.freq_dim // 2\n freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t.unsqueeze(-1) * freqs.unsqueeze(0)\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n return self.mlp(emb)\n\n\nclass ClassEmbedding(nn.Module):\n \"\"\"Optional class-conditional embedding with CFG null embedding.\"\"\"\n def __init__(self, num_classes: int, dim: int):\n super().__init__()\n self.embed = nn.Embedding(num_classes, dim)\n self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)\n \n def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:\n emb = self.embed(labels)\n if self.training and drop_prob > 0:\n mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob\n emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb)\n return emb\n\n\n# =============================================================================\n# LiquidGen: Full Model\n# =============================================================================\n\nclass LiquidGen(nn.Module):\n \"\"\"\n LiquidGen: Liquid Neural Network Image Generator\n \n A novel attention-free diffusion model that uses Liquid Neural Network\n dynamics (CfC closed-form continuous-depth) for image generation.\n \n Features:\n - NO self-attention anywhere \u2014 O(n) complexity\n - NO sequential ODE solving \u2014 fully parallelizable\n - Liquid time constants for adaptive information blending\n - Zigzag scanning for global context\n - Depthwise convolutions for local spatial structure\n - Gated stimulus (biologically-inspired from NCP)\n - U-Net long skip connections (from U-ViT/DiM)\n \n Config Presets:\n - LiquidGen-S: ~55M params (256px, fast training)\n - LiquidGen-B: ~140M params (256/512px, balanced)\n - LiquidGen-L: ~280M params (512px, high quality)\n \"\"\"\n \n def __init__(\n self,\n in_channels: int = 16,\n patch_size: int = 2,\n embed_dim: int = 512,\n depth: int = 16,\n spatial_kernel: int = 7,\n scan_kernel: int = 31,\n expand_ratio: float = 2.0,\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n num_classes: int = 0,\n class_drop_prob: float = 0.1,\n use_zigzag: bool = True,\n ):\n super().__init__()\n self.in_channels = in_channels\n self.patch_size = patch_size\n self.embed_dim = embed_dim\n self.depth = depth\n self.num_classes = num_classes\n self.class_drop_prob = class_drop_prob\n \n cond_dim = embed_dim\n \n self.time_embed = TimestepEmbedding(cond_dim)\n self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None\n \n self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)\n \n self.pos_embed_size = 32\n self.pos_embed = nn.Parameter(\n torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02\n )\n \n self.input_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),\n nn.Conv2d(embed_dim, embed_dim, 1, bias=True),\n nn.GELU(),\n )\n \n self.blocks = nn.ModuleList([\n LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,\n expand_ratio, mlp_ratio, drop_rate, use_zigzag)\n for _ in range(depth)\n ])\n \n self.final_norm = nn.GroupNorm(32, embed_dim)\n self.final_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),\n nn.GELU(),\n )\n \n self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)\n nn.init.zeros_(self.unpatch.weight)\n nn.init.zeros_(self.unpatch.bias)\n \n self.apply(self._init_weights)\n # self.apply above re-initializes the zero-initialized AdaGN projections;\n # restore them so conditioning starts as an identity modulation\n for m in self.modules():\n if isinstance(m, AdaptiveGroupNorm):\n nn.init.zeros_(m.proj.weight)\n nn.init.zeros_(m.proj.bias)\n \n def _init_weights(self, m):\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Linear):\n nn.init.xavier_uniform_(m.weight)\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Embedding):\n nn.init.normal_(m.weight, std=0.02)\n \n def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:\n if H == self.pos_embed_size and W == self.pos_embed_size:\n return self.pos_embed\n return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)\n \n def forward(\n self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Predict velocity field for flow matching.\n Args:\n x: [B, C, H, W] noisy latent (C=16 for Flux VAE)\n t: [B] timestep in [0, 1]\n class_labels: [B] optional class labels\n Returns:\n v: [B, C, H, W] predicted velocity\n \"\"\"\n cond = self.time_embed(t)\n if self.class_embed is not None and class_labels is not None:\n drop_p = self.class_drop_prob if self.training else 0.0\n cond = cond + self.class_embed(class_labels, drop_prob=drop_p)\n \n h = self.patch_embed(x)\n B, C, H_p, W_p = h.shape\n h = h + self._interpolate_pos_embed(H_p, W_p)\n h = self.input_proj(h)\n \n # U-Net style long skip connections\n skip_connections = []\n mid = self.depth // 2\n for i, block in enumerate(self.blocks):\n if i < mid:\n skip_connections.append(h)\n elif i >= mid and len(skip_connections) > 0:\n skip = skip_connections.pop()\n h = h + skip\n h = block(h, cond)\n \n h = self.final_norm(h)\n h = self.final_proj(h)\n v = self.unpatch(h)\n return v\n \n def count_params(self) -> int:\n return sum(p.numel() for p in self.parameters() if p.requires_grad)\n\n\n# =============================================================================\n# Model Presets\n# =============================================================================\n\ndef liquidgen_small(**kwargs) -> LiquidGen:\n \"\"\"~55M params - for 256px, fast training/testing\"\"\"\n defaults = dict(\n embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_base(**kwargs) -> LiquidGen:\n \"\"\"~140M params - for 256/512px, balanced (fits T4 16GB easily)\"\"\"\n defaults = dict(\n embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_large(**kwargs) -> LiquidGen:\n \"\"\"~280M params - for 512px, high quality (fits T4 16GB with small batch)\"\"\"\n defaults = dict(\n embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\n\nif __name__ == \"__main__\":\n device = \"cpu\"\n for name, factory in [(\"Small\", liquidgen_small), (\"Base\", liquidgen_base), (\"Large\", liquidgen_large)]:\n model = factory(num_classes=27).to(device)\n print(f\"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params\")\n \n x = torch.randn(2, 16, 32, 32, device=device)\n t = torch.rand(2, device=device)\n labels = torch.randint(0, 27, (2,), device=device)\n v = model(x, t, labels)\n assert v.shape == x.shape\n \n x512 = torch.randn(1, 16, 64, 64, device=device)\n v512 = model(x512, t[:1], labels[:1])\n assert v512.shape == x512.shape\n print(f\" 256px \u2705 512px \u2705\")\n del model\n \n print(\"\\n\u2705 All tests passed!\")\n"
+ ]
+ },
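The core of the architecture is the `LiquidTimeConstant` blend. A standalone sketch that mirrors (rather than imports) the module above, showing how `alpha = exp(-softplus(rho))` starts near 0.5 at `rho = 0` and interpolates between the current state and the stimulus:

```python
# Standalone sketch of the liquid time-constant update used by LiquidTimeConstant:
# out = alpha * x + (1 - alpha) * stimulus, with alpha = exp(-softplus(rho)).
# Mirrors the module in the cell above; shapes reduced for readability.
import torch
import torch.nn.functional as F

rho = torch.zeros(3)                       # per-channel parameter, init 0
alpha = torch.exp(-(F.softplus(rho) + 1e-5))
print(alpha)                               # ~0.5 each: equal blend of old and new

x = torch.ones(1, 3, 2, 2)                 # current state (residual path)
stimulus = torch.zeros(1, 3, 2, 2)         # computed target from context
a = alpha.view(1, -1, 1, 1)
out = a * x + (1 - a) * stimulus
print(out[0, :, 0, 0])                     # ~0.5: state relaxes halfway toward target

# As rho grows, softplus(rho) grows, alpha -> 0, and the block follows the
# stimulus almost entirely; as rho -> -inf, alpha -> 1 and the state persists.
```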
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udd04 Training Utilities"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os, json, time, math\n",
+ "import numpy as np\n",
+ "from torch.utils.data import DataLoader, Dataset\n",
+ "from torch.amp import autocast, GradScaler\n",
+ "from torchvision import transforms\n",
+ "from torchvision.utils import save_image\n",
+ "from PIL import Image\n",
+ "\n",
+ "class FlowMatchingScheduler:\n",
+ " def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n",
+ " def sample_timesteps(self, bs, dev): return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t\n",
+ " def add_noise(self, x0, noise, t): t = t.view(-1,1,1,1); return (1-t)*x0 + t*noise\n",
+ " def get_velocity_target(self, x0, noise): return noise - x0\n",
+ " @torch.no_grad()\n",
+ " def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):\n",
+ " model.eval(); x = torch.randn(shape, device=dev)\n",
+ " dt = 1.0 / num_steps\n",
+ " for t_val in torch.linspace(1.0, dt, num_steps, device=dev):\n",
+ " t = torch.full((shape[0],), t_val.item(), device=dev)\n",
+ " with torch.amp.autocast(\"cuda\"):\n",
+ " if cfg > 1.0 and labels is not None:\n",
+ " # NOTE: class 0 stands in for the unconditional branch; the learned null embedding is only reachable via training-time dropout\n",
+ " vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n",
+ " v = vu + cfg * (vc - vu)\n",
+ " else: v = model(x, t, labels)\n",
+ " x = x - dt * v.float()\n",
+ " return x\n",
+ "\n",
+ "class EMAModel:\n",
+ " def __init__(self, model, decay=0.9999):\n",
+ " self.decay = decay\n",
+ " self.shadow = {n: p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n",
+ " @torch.no_grad()\n",
+ " def update(self, model):\n",
+ " for n,p in model.named_parameters():\n",
+ " if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n",
+ " def apply(self, model):\n",
+ " self.backup = {n: p.data.clone() for n,p in model.named_parameters() if p.requires_grad}\n",
+ " for n,p in model.named_parameters():\n",
+ " if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n",
+ " def restore(self, model):\n",
+ " for n,p in model.named_parameters():\n",
+ " if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n])\n",
+ "\n",
+ "class ImageDataset(Dataset):\n",
+ " def __init__(self, ds, tf, img_col, lbl_col=\"\"): self.ds, self.tf, self.ic, self.lc = ds, tf, img_col, lbl_col\n",
+ " def __len__(self): return len(self.ds)\n",
+ " def __getitem__(self, i):\n",
+ " item = self.ds[i]; img = item[self.ic]\n",
+ " if img.mode != \"RGB\": img = img.convert(\"RGB\")\n",
+ " label = item[self.lc] if self.lc and self.lc in item else -1\n",
+ " return self.tf(img), label\n",
+ "\n",
+ "def cosine_sched(opt, warmup, total):\n",
+ " def lr(s):\n",
+ " if s < warmup: return s / max(1, warmup)\n",
+ " return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n",
+ " return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n",
+ "\n",
+ "MODEL_CONFIGS = {\n",
+ " \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n",
+ " \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n",
+ " \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n",
+ "}\n",
+ "print(\"\u2705 Training utilities ready!\")"
+ ]
+ },
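`FlowMatchingScheduler` trains on the straight-line path x_t = (1-t)*x0 + t*noise, whose exact velocity is noise - x0. A quick numeric check that Euler integration of that true velocity (the same update rule used in `sample`) recovers x0:

```python
# Numeric check of the flow-matching objective defined above:
# the path x_t = (1-t)*x0 + t*noise has d x_t / dt = noise - x0, which is
# exactly the velocity target. Integrating the true velocity from t=1 to t=0
# with Euler steps therefore lands back on x0 (up to float error).
import torch

torch.manual_seed(0)
x0 = torch.randn(2, 16, 4, 4)
noise = torch.randn_like(x0)
v_true = noise - x0                  # get_velocity_target(x0, noise)

num_steps = 50
dt = 1.0 / num_steps
x = noise                            # x_1 = pure noise
for _ in range(num_steps):
    x = x - dt * v_true              # same Euler update as fm.sample
print((x - x0).abs().max())          # ~1e-6: the straight path integrates exactly
```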
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udcca Load Dataset & VAE"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from datasets import load_dataset\n",
+ "from diffusers import AutoencoderKL\n",
+ "\n",
+ "print(f\"Loading dataset: {DATASET_NAME}...\")\n",
+ "dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
+ "print(f\" {len(dataset)} images\")\n",
+ "\n",
+ "transform = transforms.Compose([\n",
+ " transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
+ " transforms.CenterCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip(), transforms.ToTensor(),\n",
+ "])\n",
+ "train_ds = ImageDataset(dataset, transform, IMAGE_COLUMN, LABEL_COLUMN)\n",
+ "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n",
+ "\n",
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
+ "vae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\n",
+ "for p in vae.parameters(): p.requires_grad_(False)\n",
+ "print(f\"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\n",
+ "SCALE, SHIFT = 0.3611, 0.1159\n",
+ "print(\"\u2705 Ready!\")"
+ ]
+ },
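`SCALE` and `SHIFT` follow the diffusers convention for the Flux VAE: latents are normalized as `(z - SHIFT) * SCALE` before training and inverted as `z / SCALE + SHIFT` before decoding. A round-trip sketch confirming the two transforms used in the later cells are exact inverses:

```python
# Round-trip sketch for the latent normalization used with the Flux VAE:
# encode side:  lat = (z - SHIFT) * SCALE   (as in the training loop)
# decode side:  z   = lat / SCALE + SHIFT   (as in the sampling cells)
import torch

SCALE, SHIFT = 0.3611, 0.1159
z = torch.randn(1, 16, 32, 32)       # raw 16-channel VAE latent

lat = (z - SHIFT) * SCALE            # normalized latent the model trains on
z_back = lat / SCALE + SHIFT         # inverse applied before vae.decode
print((z - z_back).abs().max())      # ~1e-7: exact inverses up to float error
```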
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83c\udfcb\ufe0f Create Model & Train"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\n",
+ "cfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\n",
+ "model = LiquidGen(**cfg).to(device)\n",
+ "print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n",
+ "\n",
+ "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n",
+ "total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION\n",
+ "scheduler = cosine_sched(optimizer, WARMUP_STEPS, total_steps)\n",
+ "ema = EMAModel(model, EMA_DECAY)\n",
+ "scaler = GradScaler(\"cuda\")\n",
+ "fm = FlowMatchingScheduler()\n",
+ "os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n",
+ "os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n",
+ "print(f\"Total steps: {total_steps}, Effective batch: {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n",
+ "if torch.cuda.is_available(): print(f\"VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB used\")"
+ ]
+ },
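`cosine_sched` ramps the learning rate linearly over `WARMUP_STEPS`, then decays it with a cosine to zero at `total_steps`. A standalone sketch of the multiplier it produces (re-deriving the lambda rather than touching the notebook's optimizer; the step values are illustrative):

```python
# Standalone sketch of the warmup + cosine multiplier produced by cosine_sched:
# linear ramp over `warmup` steps, then cosine decay to zero at `total` steps.
import math

def lr_multiplier(step, warmup=500, total=10_000):
    if step < warmup:
        return step / max(1, warmup)
    return max(0.0, 0.5 * (1 + math.cos(math.pi * (step - warmup) / max(1, total - warmup))))

for s in (0, 250, 500, 5_250, 10_000):
    print(s, round(lr_multiplier(s), 3))
# 0 -> 0.0, 250 -> 0.5, 500 -> 1.0 (warmup done), 5250 -> 0.5 (halfway), 10000 -> 0.0
```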
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "global_step = 0; loss_accum = 0.0; log_losses = []\n",
+ "print(\"\ud83d\ude80 Training!\\n\")\n",
+ "t0 = time.time()\n",
+ "for epoch in range(NUM_EPOCHS):\n",
+ " model.train(); ep_loss = 0; ep_steps = 0; ep_t = time.time()\n",
+ " for bi, (imgs, lbls) in enumerate(train_loader):\n",
+ " imgs = imgs.to(device)\n",
+ " lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n",
+ " with torch.no_grad():\n",
+ " lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n",
+ " lats = ((lats - SHIFT) * SCALE).float()\n",
+ " t = fm.sample_timesteps(lats.shape[0], device)\n",
+ " noise = torch.randn_like(lats)\n",
+ " xt = fm.add_noise(lats, noise, t)\n",
+ " vtgt = fm.get_velocity_target(lats, noise)\n",
+ " with autocast(\"cuda\"): loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n",
+ " scaler.scale(loss).backward()\n",
+ " loss_accum += loss.item()\n",
+ " if (bi+1) % GRADIENT_ACCUMULATION == 0:\n",
+ " scaler.unscale_(optimizer)\n",
+ " gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
+ " scaler.step(optimizer); scaler.update(); optimizer.zero_grad(); scheduler.step()\n",
+ " ema.update(model); global_step += 1\n",
+ " if global_step % LOG_EVERY == 0:\n",
+ " al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n",
+ " vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n",
+ " print(f\"step={global_step:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G\")\n",
+ " log_losses.append(al); loss_accum = 0\n",
+ " if math.isnan(al) or al > 50: print(\"\ud83d\udca5 Diverged!\"); break\n",
+ " if global_step % SAMPLE_EVERY == 0:\n",
+ " ema.apply(model); model.eval()\n",
+ " ls = IMAGE_SIZE // 8\n",
+ " sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n",
+ " samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n",
+ " with torch.no_grad(): si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
+ " save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n",
+ " ema.restore(model); model.train()\n",
+ " if global_step % SAVE_EVERY == 0:\n",
+ " torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n",
+ " f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n",
+ " ep_loss += loss.item()*GRADIENT_ACCUMULATION; ep_steps += 1\n",
+ " print(f\"Epoch {epoch} | loss={ep_loss/max(ep_steps,1):.4f} | {time.time()-ep_t:.0f}s\")\n",
+ "torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n",
+ "print(f\"\ud83c\udf89 Done! {global_step} steps in {(time.time()-t0)/60:.1f} min\")"
+ ]
+ },
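Checkpoints store both the raw and EMA weights plus the config dict. A sketch of how `final.pt` could be reloaded for inference, assuming the notebook's `LiquidGen`, `OUTPUT_DIR`, and `device` are in scope (the EMA weights are usually the ones to sample with; `strict=False` is used because `ema.shadow` only tracks trainable parameters):

```python
# Sketch: reloading a saved checkpoint for inference. The dict keys match
# what the training loop above saves: "model", "ema", "cfg", "step".
import torch

ckpt = torch.load(f"{OUTPUT_DIR}/checkpoints/final.pt", map_location="cpu")
model = LiquidGen(**ckpt["cfg"])
model.load_state_dict(ckpt["model"])

# For sampling, prefer the EMA weights; ema.shadow maps parameter names to
# tensors, so it loads with strict=False (it covers only trainable params).
model.load_state_dict(ckpt["ema"], strict=False)
model.to(device).eval()
print(f"resumed from step {ckpt['step']}")
```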
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udcc8 Training Loss"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "if log_losses:\n",
+ " plt.figure(figsize=(10,4)); plt.plot(log_losses); plt.xlabel(f\"Steps (\u00d7{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n",
+ " plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3); plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n",
+ " print(f\"Min loss: {min(log_losses):.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83c\udfa8 Generate Images"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ema.apply(model); model.eval()\n",
+ "N, STEPS, CFG = 8, 50, 2.5\n",
+ "ls = IMAGE_SIZE // 8\n",
+ "STYLES = [\"Abstract Expressionism\",\"Baroque\",\"Cubism\",\"Expressionism\",\"Impressionism\",\n",
+ " \"Pop Art\",\"Realism\",\"Romanticism\",\"Symbolism\",\"Ukiyo-e\"]\n",
+ "if NUM_CLASSES > 0:\n",
+ " for ci in range(min(NUM_CLASSES, 8)):\n",
+ " l = torch.full((N,), ci, device=device, dtype=torch.long)\n",
+ " s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, CFG)\n",
+ " with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
+ " nm = STYLES[ci] if ci < len(STYLES) else f\"Class_{ci}\"\n",
+ " save_image(i, f\"{OUTPUT_DIR}/gen_{nm.replace(chr(32),chr(95))}.png\", nrow=4)\n",
+ " print(f\"Generated: {nm}\")\n",
+ "else:\n",
+ " s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n",
+ " with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
+ " save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\n",
+ "ema.restore(model)\n",
+ "print(f\"\u2705 Saved to {OUTPUT_DIR}/\")"
+ ]
+ },
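The guidance rule inside `FlowMatchingScheduler.sample` is `v = vu + cfg * (vc - vu)`. A tiny numeric sketch of how the `CFG` value above extrapolates past the conditional prediction:

```python
# Numeric sketch of the classifier-free guidance combination used in
# FlowMatchingScheduler.sample: v = v_uncond + cfg * (v_cond - v_uncond).
import torch

v_cond = torch.tensor([1.0, 2.0])    # velocity predicted with class labels
v_uncond = torch.tensor([0.5, 1.0])  # velocity from the unconditional branch

for cfg in (1.0, 2.5):
    v = v_uncond + cfg * (v_cond - v_uncond)
    print(cfg, v)
# cfg=1.0 -> [1.00, 2.00]: reduces to the plain conditional prediction
# cfg=2.5 -> [1.75, 3.50]: extrapolates past conditional, away from unconditional
```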
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udce4 Display Results"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from IPython.display import display\n",
+ "import glob\n",
+ "for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n",
+ " print(os.path.basename(f)); display(Image.open(f))\n",
+ "for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n",
+ " print(os.path.basename(f)); display(Image.open(f))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## \ud83d\udcdd Architecture Reference\n",
+ "\n",
+ "### Core Equation (CfC Liquid Dynamics)\n",
+ "\n",
+ "$$x_\\text{out} = \\alpha \\, x + (1 - \\alpha) \\, h(x), \\qquad \\alpha = e^{-\\operatorname{softplus}(\\rho)}$$\n",
+ "\n",
+ "### Flow Matching\n",
+ "\n",
+ "$$x_t = (1-t)\\,x_0 + t\\,\\epsilon, \\qquad v = \\epsilon - x_0, \\qquad \\mathcal{L} = \\mathbb{E}\\,\\|v_\\theta(x_t, t) - (\\epsilon - x_0)\\|^2$$\n",
+ "\n",
+ "### Sampling (Euler ODE)\n",
+ "\n",
+ "$$x_{t - \\Delta t} = x_t - \\Delta t \\cdot v_\\theta(x_t, t), \\qquad t: 1 \\to 0$$\n",
+ "\n",
+ "### References\n",
+ "- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n",
+ "- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature MI 2022)\n",
+ "- Lechner et al., \"Neural Circuit Policies\" (Nature MI 2020)\n",
+ "- ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n",
+ "- Lipman et al., \"Flow Matching\" (2023), SiT (2024)\n"
+ ]
+ }
+ ]
+ }