Fix: SDXL VAE (no login) + streaming dataset in notebook
LiquidGen_Colab_Notebook.ipynb CHANGED
@@ -22,7 +22,7 @@
"**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
"\n",
"- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
-
"- **
+
"- **No Login Required** \u2014 Uses open SDXL VAE (MIT license)\n",
"- **Streaming Dataset** \u2014 No full download, starts training immediately\n",
"- **Fits 16GB VRAM** \u2014 Designed for Colab free tier T4\n"
]
@@ -49,7 +49,6 @@
"source": [
"## \ud83d\udd27 Step 2: Configuration\n",
"\n",
-
"**Dataset options:**\n",
"| Dataset | Size | Download | Type |\n",
"|---------|------|----------|------|\n",
"| `huggan/wikiart` | ~80K | **Streaming** (no download!) | Art, 27 styles |\n",
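The streaming mode the table advertises amounts to a lazy `load_dataset` call. A minimal sketch, assuming the Hugging Face `datasets` library and the `image`/`style` columns the config cell below relies on:

```python
# Streaming sketch: nothing is downloaded up front; samples arrive lazily
# over HTTP and are shuffled within a small in-memory buffer.
from datasets import load_dataset

ds = load_dataset("huggan/wikiart", split="train", streaming=True)
ds = ds.shuffle(seed=42, buffer_size=1000)

sample = next(iter(ds))        # first sample, fetched on demand
print(sample["image"].size)    # PIL.Image, e.g. (1024, 768)
print(sample["style"])         # integer style label, 0..26
```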
@@ -63,7 +62,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"# ============================================================================\n# CONFIGURATION\n# ============================================================================\n\nMODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\nIMAGE_SIZE = 256 # 256 or 512\n\n# --- Dataset (Option A: WikiArt streaming \u2014 NO download) ---\nDATASET_NAME = \"huggan/wikiart\"\nIMAGE_COLUMN = \"image\"\nLABEL_COLUMN = \"style\" # \"style\"(27), \"genre\"(11), \"\" for unconditional\nNUM_CLASSES = 27\nUSE_STREAMING = True # KEY: no full download!\n\n# --- Dataset (Option B: Pokemon \u2014 small, fast
+
"# ============================================================================\n# CONFIGURATION\n# ============================================================================\n\nMODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\nIMAGE_SIZE = 256 # 256 or 512\n\n# --- Dataset (Option A: WikiArt streaming \u2014 NO download) ---\nDATASET_NAME = \"huggan/wikiart\"\nIMAGE_COLUMN = \"image\"\nLABEL_COLUMN = \"style\" # \"style\"(27), \"genre\"(11), \"\" for unconditional\nNUM_CLASSES = 27\nUSE_STREAMING = True # KEY: no full download!\n\n# --- Dataset (Option B: Pokemon \u2014 small, fast, good for testing) ---\n# DATASET_NAME = \"reach-vb/pokemon-blip-captions\"\n# IMAGE_COLUMN = \"image\"; LABEL_COLUMN = \"\"; NUM_CLASSES = 0; USE_STREAMING = False\n\n# --- Training ---\nBATCH_SIZE = 8; GRADIENT_ACCUMULATION = 4\nLEARNING_RATE = 1e-4; WEIGHT_DECAY = 0.01; MAX_GRAD_NORM = 2.0\nMAX_STEPS = 20000; WARMUP_STEPS = 500; EMA_DECAY = 0.9999\nNUM_SAMPLE_STEPS = 50; CFG_SCALE = 2.0\n\n# --- Saving ---\nOUTPUT_DIR = \"/content/liquidgen_outputs\"\nSAVE_EVERY = 5000; SAMPLE_EVERY = 500; LOG_EVERY = 50\n\n# --- VAE (SDXL VAE - open, no login needed, fp16-safe) ---\nVAE_ID = \"madebyollin/sdxl-vae-fp16-fix\"\nSCALE = 0.13025 # SDXL VAE scaling factor (no shift needed)\n\nimport torch\nif torch.cuda.is_available():\n gpu = torch.cuda.get_device_name(0)\n mem = torch.cuda.get_device_properties(0).total_memory / 1024**3\n print(f\"GPU: {gpu} ({mem:.1f} GB)\")\nelse:\n print(\"No GPU! Go to Runtime > Change runtime type > GPU\")\n"
]
},
{
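The config pins `SCALE = 0.13025` and notes "no shift needed": SDXL latents are normalized by a single multiplicative factor, unlike VAEs that also subtract a shift. A toy round trip with the constant from the cell above:

```python
# Scale-only latent normalization: multiply after encode, divide before
# decode. With no shift term the round trip is exact.
SCALE = 0.13025

def normalize(z):
    return z * SCALE      # applied to vae.encode(...).latent_dist.sample()

def denormalize(z):
    return z / SCALE      # applied before vae.decode(...)

assert abs(denormalize(normalize(3.7)) - 3.7) < 1e-9
```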
@@ -79,7 +78,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"\"\"\"\nLiquidGen: A Novel Liquid Neural Network Image Generation Model\n\nArchitecture Overview:\n- Frozen VAE encoder/decoder (FLUX.1-schnell, 16ch latent, 8x compression)\n- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)\n- Flow matching training objective (velocity prediction)\n\nKey Innovation: Replaces attention with Liquid Neural Network dynamics:\n- CfC-inspired closed-form update: x_new = \u03b1\u00b7x + (1-\u03b1)\u00b7h(x)\n- Per-channel learnable decay rates (liquid time constants)\n- Depthwise + pointwise convolutions for spatial context (no attention needed)\n- Zigzag spatial scanning for global receptive field\n- Gated stimulus with biologically-inspired sign constraints\n- U-Net style long skip connections from shallow to deep blocks\n\nMath Foundation (from Hasani et al., CfC paper):\n x_{t+1} = exp(-\u0394t/\u03c4_t) \u00b7 x_t + (1 - exp(-\u0394t/\u03c4_t)) \u00b7 h(x_t, u_t)\n \nOur parallelizable adaptation (inspired by LiquidTAD):\n \u03b1 = exp(-softplus(\u03c1)) [per-channel learnable decay]\n h = gate \u00b7 stimulus [gated depthwise conv output] \n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 h [liquid relaxation blend]\n\nThis removes the input-dependent \u03c4 (which requires sequential computation)\nand replaces it with a per-channel learned decay \u2014 making it fully parallel\nwhile preserving the liquid dynamics' ability to blend old state with new input.\n\nDesign for 16GB VRAM (Colab free tier):\n- VAE frozen: ~1GB\n- Backbone: ~55-280M params (~100-550MB in fp16) \n- Training overhead (grads + optimizer): ~3-8GB\n- Batch of latents: ~1-2GB\n- Total: fits comfortably in 16GB\n\nReferences:\n- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature Machine Intelligence 2022)\n- Lechner et al., \"Neural Circuit Policies\" (Nature Machine Intelligence 2020)\n- LiquidTAD (2025) - Parallelized liquid dynamics\n- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion\n- DiMSUM (NeurIPS 2024) - Attention-free diffusion\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Optional, Tuple\n\n\n# =============================================================================\n# Building Blocks\n# =============================================================================\n\nclass LiquidTimeConstant(nn.Module):\n \"\"\"\n Core liquid time-constant module.\n \n Implements the CfC closed-form dynamics in a fully parallelizable way:\n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 stimulus\n \n where \u03b1 = exp(-softplus(\u03c1)) is a learnable per-channel decay rate,\n derived from the liquid time constant \u03c4 = 1/softplus(\u03c1).\n \n This preserves the key property of Liquid Neural Networks:\n - Exponential relaxation toward a target (stimulus)\n - Rate controlled by \u03c4 (how fast to adapt)\n - No sequential ODE solving required\n \n Stability guarantee (from LTC Theorem 1):\n \u03c4_sys \u2208 [\u03c4/(1+\u03c4W), \u03c4] \u2014 time constants NEVER explode\n \"\"\"\n def __init__(self, channels: int):\n super().__init__()\n # \u03c1 parameterizes the decay: \u03bb = softplus(\u03c1), \u03b1 = exp(-\u03bb)\n # Initialize \u03c1=0 \u2192 \u03bb\u22480.693 \u2192 \u03b1\u22480.5 (equal blend of old and new)\n self.rho = nn.Parameter(torch.zeros(channels))\n \n def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:\n \"\"\"\n x: [B, C, H, W] - current state (residual 
path)\n stimulus: [B, C, H, W] - computed target from context\n returns: [B, C, H, W] - liquid-blended output\n \"\"\"\n lam = F.softplus(self.rho) + 1e-5\n alpha = torch.exp(-lam).view(1, -1, 1, 1)\n return alpha * x + (1.0 - alpha) * stimulus\n\n\nclass GatedDepthwiseStimulusConv(nn.Module):\n \"\"\"\n Computes the spatial stimulus using depthwise-separable convolutions\n with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs).\n \n This replaces attention for capturing local spatial context:\n - Depthwise conv: captures local spatial patterns per channel\n - Pointwise conv: mixes channel information\n - Sigmoid gate: controls information flow (like synaptic gating in NCP)\n \n Two parallel paths (inspired by NCP inter\u2192command split):\n 1. Stimulus path: DW-conv \u2192 PW-conv \u2192 GELU \u2192 project back\n 2. Gate path: DW-conv \u2192 PW-conv \u2192 sigmoid\n Output = stimulus * gate\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n \n self.stim_dw = nn.Conv2d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)\n self.stim_act = nn.GELU()\n self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)\n \n self.gate_dw = nn.Conv2d(channels, channels, kernel_size,\n padding=kernel_size // 2, groups=channels, bias=False)\n self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x))))\n gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))\n return stim * gate\n\n\nclass ChannelMixMLP(nn.Module):\n \"\"\"Channel mixing MLP with GELU activation (command neuron processing in NCP).\"\"\"\n def __init__(self, channels: int, expand_ratio: float = 4.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)\n self.act = nn.GELU()\n self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return self.fc2(self.act(self.fc1(x)))\n\n\nclass AdaptiveGroupNorm(nn.Module):\n \"\"\"\n Adaptive Group Normalization conditioned on timestep embedding.\n Applies: out = (1 + scale) * GroupNorm(x) + shift\n \"\"\"\n def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):\n super().__init__()\n self.norm = nn.GroupNorm(num_groups, channels, affine=False)\n self.proj = nn.Linear(cond_dim, channels * 2)\n nn.init.zeros_(self.proj.weight)\n nn.init.zeros_(self.proj.bias)\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm(x)\n params = self.proj(cond)\n scale, shift = params.chunk(2, dim=-1)\n return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1)\n\n\nclass ZigzagScan1D(nn.Module):\n \"\"\"\n 1D global mixing via zigzag-scanned depthwise conv.\n \n Gives quasi-global receptive field without attention's O(n\u00b2) cost.\n Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024).\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 31):\n super().__init__()\n self.conv1d = nn.Conv1d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.pw = nn.Conv1d(channels, channels, 1, bias=True)\n self.act = nn.GELU()\n \n def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n indices = []\n for i 
in range(H):\n row = list(range(i * W, (i + 1) * W))\n if i % 2 == 1:\n row = row[::-1]\n indices.extend(row)\n return torch.tensor(indices, device=device, dtype=torch.long)\n \n def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n fwd = self._zigzag_indices(H, W, device)\n inv = torch.empty_like(fwd)\n inv[fwd] = torch.arange(H * W, device=device)\n return inv\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n B, C, H, W = x.shape\n zz_idx = self._zigzag_indices(H, W, x.device)\n inv_idx = self._inverse_zigzag_indices(H, W, x.device)\n x_flat = x.reshape(B, C, H * W)\n x_zz = x_flat[:, :, zz_idx]\n x_mixed = self.pw(self.act(self.conv1d(x_zz)))\n x_restored = x_mixed[:, :, inv_idx]\n return x_restored.reshape(B, C, H, W)\n\n\n# =============================================================================\n# Liquid Block: The core building block\n# =============================================================================\n\nclass LiquidBlock(nn.Module):\n \"\"\"\n A single Liquid Neural Network block for image denoising.\n \n Architecture (maps to NCP hierarchy):\n 1. [SENSORY] AdaGN conditioning \u2192 spatial context extraction\n 2. [INTER] Zigzag 1D scan for global mixing\n 3. [COMMAND] Liquid time-constant blend (CfC dynamics)\n 4. [MOTOR] Channel mixing MLP for output projection\n \n All operations are fully parallelizable \u2014 no sequential dependencies.\n \"\"\"\n def __init__(\n self, channels: int, cond_dim: int, spatial_kernel: int = 7,\n scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,\n drop_rate: float = 0.0, use_zigzag: bool = True,\n ):\n super().__init__()\n self.norm1 = AdaptiveGroupNorm(channels, cond_dim)\n self.norm2 = AdaptiveGroupNorm(channels, cond_dim)\n self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)\n self.use_zigzag = use_zigzag\n if use_zigzag:\n self.zigzag = ZigzagScan1D(channels, scan_kernel)\n self.zigzag_gate = nn.Parameter(torch.zeros(1))\n self.liquid = LiquidTimeConstant(channels)\n self.channel_mix = ChannelMixMLP(channels, mlp_ratio)\n self.liquid2 = LiquidTimeConstant(channels)\n self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm1(x, cond)\n stim = self.spatial_stim(h)\n if self.use_zigzag:\n zz = self.zigzag(h)\n stim = stim + torch.sigmoid(self.zigzag_gate) * zz\n stim = self.drop(stim)\n x = self.liquid(x, stim)\n h2 = self.norm2(x, cond)\n ch_out = self.drop(self.channel_mix(h2))\n x = self.liquid2(x, ch_out)\n return x\n\n\n# =============================================================================\n# Timestep and Class Embeddings\n# =============================================================================\n\nclass TimestepEmbedding(nn.Module):\n \"\"\"Sinusoidal timestep embedding followed by MLP projection.\"\"\"\n def __init__(self, dim: int, freq_dim: int = 256):\n super().__init__()\n self.freq_dim = freq_dim\n self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n \n def forward(self, t: torch.Tensor) -> torch.Tensor:\n half = self.freq_dim // 2\n freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t.unsqueeze(-1) * freqs.unsqueeze(0)\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n return self.mlp(emb)\n\n\nclass ClassEmbedding(nn.Module):\n \"\"\"Optional class-conditional embedding with CFG null 
embedding.\"\"\"\n def __init__(self, num_classes: int, dim: int):\n super().__init__()\n self.embed = nn.Embedding(num_classes, dim)\n self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)\n \n def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:\n emb = self.embed(labels)\n if self.training and drop_prob > 0:\n mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob\n emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb)\n return emb\n\n\n# =============================================================================\n# LiquidGen: Full Model\n# =============================================================================\n\nclass LiquidGen(nn.Module):\n \"\"\"\n LiquidGen: Liquid Neural Network Image Generator\n \n A novel attention-free diffusion model that uses Liquid Neural Network\n dynamics (CfC closed-form continuous-depth) for image generation.\n \n Features:\n - NO self-attention anywhere \u2014 O(n) complexity\n - NO sequential ODE solving \u2014 fully parallelizable\n - Liquid time constants for adaptive information blending\n - Zigzag scanning for global context\n - Depthwise convolutions for local spatial structure\n - Gated stimulus (biologically-inspired from NCP)\n - U-Net long skip connections (from U-ViT/DiM)\n \n Config Presets:\n - LiquidGen-S: ~55M params (256px, fast training)\n - LiquidGen-B: ~140M params (256/512px, balanced)\n - LiquidGen-L: ~280M params (512px, high quality)\n \"\"\"\n \n def __init__(\n self,\n in_channels: int = 16,\n patch_size: int = 2,\n embed_dim: int = 512,\n depth: int = 16,\n spatial_kernel: int = 7,\n scan_kernel: int = 31,\n expand_ratio: float = 2.0,\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n num_classes: int = 0,\n class_drop_prob: float = 0.1,\n use_zigzag: bool = True,\n ):\n super().__init__()\n self.in_channels = in_channels\n self.patch_size = patch_size\n self.embed_dim = embed_dim\n self.depth = depth\n self.num_classes = num_classes\n self.class_drop_prob = class_drop_prob\n \n cond_dim = embed_dim\n \n self.time_embed = TimestepEmbedding(cond_dim)\n self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None\n \n self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)\n \n self.pos_embed_size = 32\n self.pos_embed = nn.Parameter(\n torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02\n )\n \n self.input_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),\n nn.Conv2d(embed_dim, embed_dim, 1, bias=True),\n nn.GELU(),\n )\n \n self.blocks = nn.ModuleList([\n LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,\n expand_ratio, mlp_ratio, drop_rate, use_zigzag)\n for _ in range(depth)\n ])\n \n self.final_norm = nn.GroupNorm(32, embed_dim)\n self.final_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),\n nn.GELU(),\n )\n \n self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)\n nn.init.zeros_(self.unpatch.weight)\n nn.init.zeros_(self.unpatch.bias)\n \n self.apply(self._init_weights)\n \n def _init_weights(self, m):\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Linear):\n nn.init.xavier_uniform_(m.weight)\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Embedding):\n nn.init.normal_(m.weight, 
std=0.02)\n \n def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:\n if H == self.pos_embed_size and W == self.pos_embed_size:\n return self.pos_embed\n return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)\n \n def forward(\n self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Predict velocity field for flow matching.\n Args:\n x: [B, C, H, W] noisy latent (C=16 for Flux VAE)\n t: [B] timestep in [0, 1]\n class_labels: [B] optional class labels\n Returns:\n v: [B, C, H, W] predicted velocity\n \"\"\"\n cond = self.time_embed(t)\n if self.class_embed is not None and class_labels is not None:\n drop_p = self.class_drop_prob if self.training else 0.0\n cond = cond + self.class_embed(class_labels, drop_prob=drop_p)\n \n h = self.patch_embed(x)\n B, C, H_p, W_p = h.shape\n h = h + self._interpolate_pos_embed(H_p, W_p)\n h = self.input_proj(h)\n \n # U-Net style long skip connections\n skip_connections = []\n mid = self.depth // 2\n for i, block in enumerate(self.blocks):\n if i < mid:\n skip_connections.append(h)\n elif i >= mid and len(skip_connections) > 0:\n skip = skip_connections.pop()\n h = h + skip\n h = block(h, cond)\n \n h = self.final_norm(h)\n h = self.final_proj(h)\n v = self.unpatch(h)\n return v\n \n def count_params(self) -> int:\n return sum(p.numel() for p in self.parameters() if p.requires_grad)\n\n\n# =============================================================================\n# Model Presets\n# =============================================================================\n\ndef liquidgen_small(**kwargs) -> LiquidGen:\n \"\"\"~55M params - for 256px, fast training/testing\"\"\"\n defaults = dict(\n embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_base(**kwargs) -> LiquidGen:\n \"\"\"~140M params - for 256/512px, balanced (fits T4 16GB easily)\"\"\"\n defaults = dict(\n embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_large(**kwargs) -> LiquidGen:\n \"\"\"~280M params - for 512px, high quality (fits T4 16GB with small batch)\"\"\"\n defaults = dict(\n embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\n\nif __name__ == \"__main__\":\n device = \"cpu\"\n for name, factory in [(\"Small\", liquidgen_small), (\"Base\", liquidgen_base), (\"Large\", liquidgen_large)]:\n model = factory(num_classes=27).to(device)\n print(f\"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params\")\n \n x = torch.randn(2, 16, 32, 32, device=device)\n t = torch.rand(2, device=device)\n labels = torch.randint(0, 27, (2,), device=device)\n v = model(x, t, labels)\n assert v.shape == x.shape\n \n x512 = torch.randn(1, 16, 64, 64, device=device)\n v512 = model(x512, t[:1], labels[:1])\n assert v512.shape == x512.shape\n print(f\" 256px \u2705 512px \u2705\")\n del model\n \n print(\"\\n\u2705 All tests passed!\")\n"
+
"\"\"\"\nLiquidGen: A Novel Liquid Neural Network Image Generation Model\n\nArchitecture Overview:\n- Frozen VAE encoder/decoder (SDXL VAE, 4ch latent, 8x compression, no login needed)\n- Liquid backbone for denoising (fully parallelizable, no attention, no sequential ODE)\n- Flow matching training objective (velocity prediction)\n\nKey Innovation: Replaces attention with Liquid Neural Network dynamics:\n- CfC-inspired closed-form update: x_new = \u03b1\u00b7x + (1-\u03b1)\u00b7h(x)\n- Per-channel learnable decay rates (liquid time constants)\n- Depthwise + pointwise convolutions for spatial context (no attention needed)\n- Zigzag spatial scanning for global receptive field\n- Gated stimulus with biologically-inspired sign constraints\n- U-Net style long skip connections from shallow to deep blocks\n\nMath Foundation (from Hasani et al., CfC paper):\n x_{t+1} = exp(-\u0394t/\u03c4_t) \u00b7 x_t + (1 - exp(-\u0394t/\u03c4_t)) \u00b7 h(x_t, u_t)\n \nOur parallelizable adaptation (inspired by LiquidTAD):\n \u03b1 = exp(-softplus(\u03c1)) [per-channel learnable decay]\n h = gate \u00b7 stimulus [gated depthwise conv output] \n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 h [liquid relaxation blend]\n\nThis removes the input-dependent \u03c4 (which requires sequential computation)\nand replaces it with a per-channel learned decay \u2014 making it fully parallel\nwhile preserving the liquid dynamics' ability to blend old state with new input.\n\nDesign for 16GB VRAM (Colab free tier):\n- VAE frozen: ~1GB\n- Backbone: ~55-280M params (~100-550MB in fp16) \n- Training overhead (grads + optimizer): ~3-8GB\n- Batch of latents: ~1-2GB\n- Total: fits comfortably in 16GB\n\nReferences:\n- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature Machine Intelligence 2022)\n- Lechner et al., \"Neural Circuit Policies\" (Nature Machine Intelligence 2020)\n- LiquidTAD (2025) - Parallelized liquid dynamics\n- ZigMa (ECCV 2024) - Zigzag scanning for SSM-based diffusion\n- DiMSUM (NeurIPS 2024) - Attention-free diffusion\n\"\"\"\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport math\nfrom typing import Optional, Tuple\n\n\n# =============================================================================\n# Building Blocks\n# =============================================================================\n\nclass LiquidTimeConstant(nn.Module):\n \"\"\"\n Core liquid time-constant module.\n \n Implements the CfC closed-form dynamics in a fully parallelizable way:\n out = \u03b1 \u00b7 x + (1 - \u03b1) \u00b7 stimulus\n \n where \u03b1 = exp(-softplus(\u03c1)) is a learnable per-channel decay rate,\n derived from the liquid time constant \u03c4 = 1/softplus(\u03c1).\n \n This preserves the key property of Liquid Neural Networks:\n - Exponential relaxation toward a target (stimulus)\n - Rate controlled by \u03c4 (how fast to adapt)\n - No sequential ODE solving required\n \n Stability guarantee (from LTC Theorem 1):\n \u03c4_sys \u2208 [\u03c4/(1+\u03c4W), \u03c4] \u2014 time constants NEVER explode\n \"\"\"\n def __init__(self, channels: int):\n super().__init__()\n # \u03c1 parameterizes the decay: \u03bb = softplus(\u03c1), \u03b1 = exp(-\u03bb)\n # Initialize \u03c1=0 \u2192 \u03bb\u22480.693 \u2192 \u03b1\u22480.5 (equal blend of old and new)\n self.rho = nn.Parameter(torch.zeros(channels))\n \n def forward(self, x: torch.Tensor, stimulus: torch.Tensor) -> torch.Tensor:\n \"\"\"\n x: [B, C, H, W] - current state 
(residual path)\n stimulus: [B, C, H, W] - computed target from context\n returns: [B, C, H, W] - liquid-blended output\n \"\"\"\n lam = F.softplus(self.rho) + 1e-5\n alpha = torch.exp(-lam).view(1, -1, 1, 1)\n return alpha * x + (1.0 - alpha) * stimulus\n\n\nclass GatedDepthwiseStimulusConv(nn.Module):\n \"\"\"\n Computes the spatial stimulus using depthwise-separable convolutions\n with a sigmoid gate (inspired by GLU / gated mechanisms in SSMs).\n \n This replaces attention for capturing local spatial context:\n - Depthwise conv: captures local spatial patterns per channel\n - Pointwise conv: mixes channel information\n - Sigmoid gate: controls information flow (like synaptic gating in NCP)\n \n Two parallel paths (inspired by NCP inter\u2192command split):\n 1. Stimulus path: DW-conv \u2192 PW-conv \u2192 GELU \u2192 project back\n 2. Gate path: DW-conv \u2192 PW-conv \u2192 sigmoid\n Output = stimulus * gate\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 7, expand_ratio: float = 2.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n \n self.stim_dw = nn.Conv2d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.stim_pw = nn.Conv2d(channels, hidden, 1, bias=False)\n self.stim_act = nn.GELU()\n self.stim_proj = nn.Conv2d(hidden, channels, 1, bias=False)\n \n self.gate_dw = nn.Conv2d(channels, channels, kernel_size,\n padding=kernel_size // 2, groups=channels, bias=False)\n self.gate_pw = nn.Conv2d(channels, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n stim = self.stim_proj(self.stim_act(self.stim_pw(self.stim_dw(x))))\n gate = torch.sigmoid(self.gate_pw(self.gate_dw(x)))\n return stim * gate\n\n\nclass ChannelMixMLP(nn.Module):\n \"\"\"Channel mixing MLP with GELU activation (command neuron processing in NCP).\"\"\"\n def __init__(self, channels: int, expand_ratio: float = 4.0):\n super().__init__()\n hidden = int(channels * expand_ratio)\n self.fc1 = nn.Conv2d(channels, hidden, 1, bias=True)\n self.act = nn.GELU()\n self.fc2 = nn.Conv2d(hidden, channels, 1, bias=True)\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n return self.fc2(self.act(self.fc1(x)))\n\n\nclass AdaptiveGroupNorm(nn.Module):\n \"\"\"\n Adaptive Group Normalization conditioned on timestep embedding.\n Applies: out = (1 + scale) * GroupNorm(x) + shift\n \"\"\"\n def __init__(self, channels: int, cond_dim: int, num_groups: int = 32):\n super().__init__()\n self.norm = nn.GroupNorm(num_groups, channels, affine=False)\n self.proj = nn.Linear(cond_dim, channels * 2)\n nn.init.zeros_(self.proj.weight)\n nn.init.zeros_(self.proj.bias)\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm(x)\n params = self.proj(cond)\n scale, shift = params.chunk(2, dim=-1)\n return h * (1.0 + scale.unsqueeze(-1).unsqueeze(-1)) + shift.unsqueeze(-1).unsqueeze(-1)\n\n\nclass ZigzagScan1D(nn.Module):\n \"\"\"\n 1D global mixing via zigzag-scanned depthwise conv.\n \n Gives quasi-global receptive field without attention's O(n\u00b2) cost.\n Zigzag scan preserves spatial continuity (from ZigMa, ECCV 2024).\n \"\"\"\n def __init__(self, channels: int, kernel_size: int = 31):\n super().__init__()\n self.conv1d = nn.Conv1d(channels, channels, kernel_size, \n padding=kernel_size // 2, groups=channels, bias=False)\n self.pw = nn.Conv1d(channels, channels, 1, bias=True)\n self.act = nn.GELU()\n \n def _zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n indices = 
[]\n for i in range(H):\n row = list(range(i * W, (i + 1) * W))\n if i % 2 == 1:\n row = row[::-1]\n indices.extend(row)\n return torch.tensor(indices, device=device, dtype=torch.long)\n \n def _inverse_zigzag_indices(self, H: int, W: int, device: torch.device) -> torch.Tensor:\n fwd = self._zigzag_indices(H, W, device)\n inv = torch.empty_like(fwd)\n inv[fwd] = torch.arange(H * W, device=device)\n return inv\n \n def forward(self, x: torch.Tensor) -> torch.Tensor:\n B, C, H, W = x.shape\n zz_idx = self._zigzag_indices(H, W, x.device)\n inv_idx = self._inverse_zigzag_indices(H, W, x.device)\n x_flat = x.reshape(B, C, H * W)\n x_zz = x_flat[:, :, zz_idx]\n x_mixed = self.pw(self.act(self.conv1d(x_zz)))\n x_restored = x_mixed[:, :, inv_idx]\n return x_restored.reshape(B, C, H, W)\n\n\n# =============================================================================\n# Liquid Block: The core building block\n# =============================================================================\n\nclass LiquidBlock(nn.Module):\n \"\"\"\n A single Liquid Neural Network block for image denoising.\n \n Architecture (maps to NCP hierarchy):\n 1. [SENSORY] AdaGN conditioning \u2192 spatial context extraction\n 2. [INTER] Zigzag 1D scan for global mixing\n 3. [COMMAND] Liquid time-constant blend (CfC dynamics)\n 4. [MOTOR] Channel mixing MLP for output projection\n \n All operations are fully parallelizable \u2014 no sequential dependencies.\n \"\"\"\n def __init__(\n self, channels: int, cond_dim: int, spatial_kernel: int = 7,\n scan_kernel: int = 31, expand_ratio: float = 2.0, mlp_ratio: float = 4.0,\n drop_rate: float = 0.0, use_zigzag: bool = True,\n ):\n super().__init__()\n self.norm1 = AdaptiveGroupNorm(channels, cond_dim)\n self.norm2 = AdaptiveGroupNorm(channels, cond_dim)\n self.spatial_stim = GatedDepthwiseStimulusConv(channels, spatial_kernel, expand_ratio)\n self.use_zigzag = use_zigzag\n if use_zigzag:\n self.zigzag = ZigzagScan1D(channels, scan_kernel)\n self.zigzag_gate = nn.Parameter(torch.zeros(1))\n self.liquid = LiquidTimeConstant(channels)\n self.channel_mix = ChannelMixMLP(channels, mlp_ratio)\n self.liquid2 = LiquidTimeConstant(channels)\n self.drop = nn.Dropout2d(drop_rate) if drop_rate > 0 else nn.Identity()\n \n def forward(self, x: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:\n h = self.norm1(x, cond)\n stim = self.spatial_stim(h)\n if self.use_zigzag:\n zz = self.zigzag(h)\n stim = stim + torch.sigmoid(self.zigzag_gate) * zz\n stim = self.drop(stim)\n x = self.liquid(x, stim)\n h2 = self.norm2(x, cond)\n ch_out = self.drop(self.channel_mix(h2))\n x = self.liquid2(x, ch_out)\n return x\n\n\n# =============================================================================\n# Timestep and Class Embeddings\n# =============================================================================\n\nclass TimestepEmbedding(nn.Module):\n \"\"\"Sinusoidal timestep embedding followed by MLP projection.\"\"\"\n def __init__(self, dim: int, freq_dim: int = 256):\n super().__init__()\n self.freq_dim = freq_dim\n self.mlp = nn.Sequential(nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim))\n \n def forward(self, t: torch.Tensor) -> torch.Tensor:\n half = self.freq_dim // 2\n freqs = torch.exp(-math.log(10000.0) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t.unsqueeze(-1) * freqs.unsqueeze(0)\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n return self.mlp(emb)\n\n\nclass ClassEmbedding(nn.Module):\n \"\"\"Optional class-conditional embedding with CFG null 
embedding.\"\"\"\n def __init__(self, num_classes: int, dim: int):\n super().__init__()\n self.embed = nn.Embedding(num_classes, dim)\n self.null_embed = nn.Parameter(torch.randn(dim) * 0.02)\n \n def forward(self, labels: torch.Tensor, drop_prob: float = 0.0) -> torch.Tensor:\n emb = self.embed(labels)\n if self.training and drop_prob > 0:\n mask = torch.rand(labels.shape[0], 1, device=labels.device) < drop_prob\n emb = torch.where(mask, self.null_embed.unsqueeze(0).expand_as(emb), emb)\n return emb\n\n\n# =============================================================================\n# LiquidGen: Full Model\n# =============================================================================\n\nclass LiquidGen(nn.Module):\n \"\"\"\n LiquidGen: Liquid Neural Network Image Generator\n \n A novel attention-free diffusion model that uses Liquid Neural Network\n dynamics (CfC closed-form continuous-depth) for image generation.\n \n Features:\n - NO self-attention anywhere \u2014 O(n) complexity\n - NO sequential ODE solving \u2014 fully parallelizable\n - Liquid time constants for adaptive information blending\n - Zigzag scanning for global context\n - Depthwise convolutions for local spatial structure\n - Gated stimulus (biologically-inspired from NCP)\n - U-Net long skip connections (from U-ViT/DiM)\n \n Config Presets:\n - LiquidGen-S: ~55M params (256px, fast training)\n - LiquidGen-B: ~140M params (256/512px, balanced)\n - LiquidGen-L: ~280M params (512px, high quality)\n \"\"\"\n \n def __init__(\n self,\n in_channels: int = 4, # 4 for SDXL VAE\n patch_size: int = 2,\n embed_dim: int = 512,\n depth: int = 16,\n spatial_kernel: int = 7,\n scan_kernel: int = 31,\n expand_ratio: float = 2.0,\n mlp_ratio: float = 4.0,\n drop_rate: float = 0.0,\n num_classes: int = 0,\n class_drop_prob: float = 0.1,\n use_zigzag: bool = True,\n ):\n super().__init__()\n self.in_channels = in_channels\n self.patch_size = patch_size\n self.embed_dim = embed_dim\n self.depth = depth\n self.num_classes = num_classes\n self.class_drop_prob = class_drop_prob\n \n cond_dim = embed_dim\n \n self.time_embed = TimestepEmbedding(cond_dim)\n self.class_embed = ClassEmbedding(num_classes, cond_dim) if num_classes > 0 else None\n \n self.patch_embed = nn.Conv2d(in_channels, embed_dim, patch_size, stride=patch_size)\n \n self.pos_embed_size = 32\n self.pos_embed = nn.Parameter(\n torch.randn(1, embed_dim, self.pos_embed_size, self.pos_embed_size) * 0.02\n )\n \n self.input_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, groups=embed_dim, bias=False),\n nn.Conv2d(embed_dim, embed_dim, 1, bias=True),\n nn.GELU(),\n )\n \n self.blocks = nn.ModuleList([\n LiquidBlock(embed_dim, cond_dim, spatial_kernel, scan_kernel,\n expand_ratio, mlp_ratio, drop_rate, use_zigzag)\n for _ in range(depth)\n ])\n \n self.final_norm = nn.GroupNorm(32, embed_dim)\n self.final_proj = nn.Sequential(\n nn.Conv2d(embed_dim, embed_dim, 3, padding=1, bias=True),\n nn.GELU(),\n )\n \n self.unpatch = nn.ConvTranspose2d(embed_dim, in_channels, patch_size, stride=patch_size)\n nn.init.zeros_(self.unpatch.weight)\n nn.init.zeros_(self.unpatch.bias)\n \n self.apply(self._init_weights)\n \n def _init_weights(self, m):\n if isinstance(m, nn.Conv2d):\n nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Linear):\n nn.init.xavier_uniform_(m.weight)\n if m.bias is not None:\n nn.init.zeros_(m.bias)\n elif isinstance(m, nn.Embedding):\n 
nn.init.normal_(m.weight, std=0.02)\n \n def _interpolate_pos_embed(self, H: int, W: int) -> torch.Tensor:\n if H == self.pos_embed_size and W == self.pos_embed_size:\n return self.pos_embed\n return F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)\n \n def forward(\n self, x: torch.Tensor, t: torch.Tensor, class_labels: Optional[torch.Tensor] = None,\n ) -> torch.Tensor:\n \"\"\"\n Predict velocity field for flow matching.\n Args:\n x: [B, C, H, W] noisy latent (C=4 for SDXL VAE)\n t: [B] timestep in [0, 1]\n class_labels: [B] optional class labels\n Returns:\n v: [B, C, H, W] predicted velocity\n \"\"\"\n cond = self.time_embed(t)\n if self.class_embed is not None and class_labels is not None:\n drop_p = self.class_drop_prob if self.training else 0.0\n cond = cond + self.class_embed(class_labels, drop_prob=drop_p)\n \n h = self.patch_embed(x)\n B, C, H_p, W_p = h.shape\n h = h + self._interpolate_pos_embed(H_p, W_p)\n h = self.input_proj(h)\n \n # U-Net style long skip connections\n skip_connections = []\n mid = self.depth // 2\n for i, block in enumerate(self.blocks):\n if i < mid:\n skip_connections.append(h)\n elif i >= mid and len(skip_connections) > 0:\n skip = skip_connections.pop()\n h = h + skip\n h = block(h, cond)\n \n h = self.final_norm(h)\n h = self.final_proj(h)\n v = self.unpatch(h)\n return v\n \n def count_params(self) -> int:\n return sum(p.numel() for p in self.parameters() if p.requires_grad)\n\n\n# =============================================================================\n# Model Presets\n# =============================================================================\n\ndef liquidgen_small(**kwargs) -> LiquidGen:\n \"\"\"~55M params - for 256px, fast training/testing\"\"\"\n defaults = dict(\n embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=3.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_base(**kwargs) -> LiquidGen:\n \"\"\"~140M params - for 256/512px, balanced (fits T4 16GB easily)\"\"\"\n defaults = dict(\n embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.0, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\ndef liquidgen_large(**kwargs) -> LiquidGen:\n \"\"\"~280M params - for 512px, high quality (fits T4 16GB with small batch)\"\"\"\n defaults = dict(\n embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31,\n expand_ratio=2.5, mlp_ratio=4.0, use_zigzag=True,\n )\n defaults.update(kwargs)\n return LiquidGen(**defaults)\n\n\nif __name__ == \"__main__\":\n device = \"cpu\"\n for name, factory in [(\"Small\", liquidgen_small), (\"Base\", liquidgen_base), (\"Large\", liquidgen_large)]:\n model = factory(num_classes=27).to(device)\n print(f\"LiquidGen-{name}: {model.count_params() / 1e6:.1f}M params\")\n \n # 256px: image/8 = 32x32 latent, 4 channels (SDXL VAE)\n x = torch.randn(2, 4, 32, 32, device=device)\n t = torch.rand(2, device=device)\n labels = torch.randint(0, 27, (2,), device=device)\n v = model(x, t, labels)\n assert v.shape == x.shape\n \n # 512px: image/8 = 64x64 latent\n x512 = torch.randn(1, 4, 64, 64, device=device)\n v512 = model(x512, t[:1], labels[:1])\n assert v512.shape == x512.shape\n print(f\" 256px \u2705 512px \u2705\")\n del model\n \n print(\"\\n\u2705 All tests passed!\")\n"
]
},
{
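The docstrings in the cell above reduce the CfC update to `out = α·x + (1-α)·stimulus` with `α = exp(-softplus(ρ))`. A standalone check of the initialization claim (ρ=0 gives an equal blend), mirroring `LiquidTimeConstant.forward`:

```python
# rho = 0  ->  lambda = softplus(0) = ln 2  ->  alpha = exp(-ln 2) ~ 0.5,
# so a fresh block blends old state and new stimulus equally.
import torch
import torch.nn.functional as F

rho = torch.zeros(4)                            # per-channel decay params
alpha = torch.exp(-(F.softplus(rho) + 1e-5))    # ~0.5 per channel
x = torch.ones(1, 4, 2, 2)                      # current state
stim = torch.zeros(1, 4, 2, 2)                  # target stimulus
a = alpha.view(1, -1, 1, 1)
out = a * x + (1.0 - a) * stim
print(alpha[0].item(), out.mean().item())       # ~0.4999, ~0.4999
```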
@@ -95,7 +94,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"import os, time, math\nimport numpy as np\nfrom torch.utils.data import DataLoader, IterableDataset, Dataset\nfrom torch.amp import autocast, GradScaler\nfrom torchvision import transforms\nfrom torchvision.utils import save_image\nfrom PIL import Image\n\nclass StreamingImageDataset(IterableDataset):\n \"\"\"Streaming
+
"import os, time, math\nimport numpy as np\nfrom torch.utils.data import DataLoader, IterableDataset, Dataset\nfrom torch.amp import autocast, GradScaler\nfrom torchvision import transforms\nfrom torchvision.utils import save_image\nfrom PIL import Image\n\nclass StreamingImageDataset(IterableDataset):\n \"\"\"Streaming \u2014 NO full download. Images load on-the-fly.\"\"\"\n def __init__(self, name, img_col=\"image\", lbl_col=\"\", img_size=256,\n split=\"train\", config=\"\", buffer=1000, seed=42):\n super().__init__()\n self.name, self.img_col, self.lbl_col = name, img_col, lbl_col\n self.split, self.config, self.buffer, self.seed = split, config, buffer, seed\n self.tf = transforms.Compose([\n transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n\n def __iter__(self):\n from datasets import load_dataset\n kw = {\"name\": self.config} if self.config else {}\n ds = load_dataset(self.name, split=self.split, streaming=True, **kw)\n ds = ds.shuffle(seed=self.seed, buffer_size=self.buffer)\n for item in ds:\n try:\n img = item[self.img_col]\n if img.mode != \"RGB\": img = img.convert(\"RGB\")\n lbl = item[self.lbl_col] if self.lbl_col and self.lbl_col in item else -1\n yield self.tf(img), lbl\n except: continue\n\nclass MapImageDataset(Dataset):\n \"\"\"For small datasets (<500MB) \u2014 downloads once.\"\"\"\n def __init__(self, name, img_col=\"image\", lbl_col=\"\", img_size=256, split=\"train\"):\n from datasets import load_dataset\n print(f\"Downloading {name}...\")\n self.ds = load_dataset(name, split=split)\n self.img_col, self.lbl_col = img_col, lbl_col\n self.tf = transforms.Compose([\n transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n print(f\" {len(self.ds)} images\")\n def __len__(self): return len(self.ds)\n def __getitem__(self, i):\n item = self.ds[i]; img = item[self.img_col]\n if img.mode != \"RGB\": img = img.convert(\"RGB\")\n lbl = item[self.lbl_col] if self.lbl_col and self.lbl_col in item else -1\n return self.tf(img), lbl\n\nclass FlowMatchingScheduler:\n def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n def sample_t(self, bs, dev): return torch.rand(bs, device=dev)*(self.max_t-self.min_t)+self.min_t\n def add_noise(self, x0, noise, t): return (1-t.view(-1,1,1,1))*x0 + t.view(-1,1,1,1)*noise\n def velocity(self, x0, noise): return noise - x0\n @torch.no_grad()\n def sample(self, model, shape, dev, steps=50, labels=None, cfg=1.0):\n model.eval(); x = torch.randn(shape, device=dev); dt = 1.0/steps\n for tv in torch.linspace(1.0, dt, steps, device=dev):\n t = torch.full((shape[0],), tv.item(), device=dev)\n with torch.amp.autocast(\"cuda\"):\n if cfg > 1.0 and labels is not None:\n vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n v = vu + cfg*(vc-vu)\n else: v = model(x,t,labels)\n x = x - dt * v.float()\n return x\n\nclass EMAModel:\n def __init__(self, model, decay=0.9999):\n self.decay = decay\n self.shadow = {n:p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n @torch.no_grad()\n def update(self, m):\n for n,p in m.named_parameters():\n if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n def apply(self, m):\n self.bk = {n:p.data.clone() for n,p in m.named_parameters() if p.requires_grad}\n 
for n,p in m.named_parameters():\n if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n def restore(self, m):\n for n,p in m.named_parameters():\n if p.requires_grad and n in self.bk: p.data.copy_(self.bk[n])\n\ndef cosine_sched(opt, warmup, total):\n def lr(s):\n if s < warmup: return s/max(1,warmup)\n return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n\nMODEL_CONFIGS = {\n \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n}\nprint(\"Training utilities ready!\")\n"
]
},
{
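`FlowMatchingScheduler` above defines the straight-line interpolant `x_t = (1-t)·x0 + t·noise` and target velocity `v = noise - x0`; the sampler then walks `x ← x - dt·v` from `t=1` down to `t=dt`. A numeric check that an Euler step with the true velocity lands back on the data:

```python
# With the exact velocity, one Euler step of size t from x_t reaches x0:
# x_t - t*(noise - x0) = (1-t)*x0 + t*noise - t*noise + t*x0 = x0.
import torch

x0 = torch.randn(2, 4, 8, 8)            # "data" latents
noise = torch.randn_like(x0)
t = torch.full((2, 1, 1, 1), 0.5)

xt = (1 - t) * x0 + t * noise           # add_noise
v = noise - x0                          # velocity target
print(torch.allclose(xt - 0.5 * v, x0, atol=1e-6))   # True
```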
@@ -111,7 +110,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"from diffusers import AutoencoderKL\n\nif USE_STREAMING:\n print(f\"Loading {DATASET_NAME} in STREAMING mode (no full download)...\")\n train_ds = StreamingImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE, buffer=1000)\n train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)\n print(\" Streaming ready!
+
"from diffusers import AutoencoderKL\n\nif USE_STREAMING:\n print(f\"Loading {DATASET_NAME} in STREAMING mode (no full download)...\")\n train_ds = StreamingImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE, buffer=1000)\n train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)\n print(\" Streaming ready!\")\nelse:\n train_ds = MapImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE)\n train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Loading VAE: {VAE_ID} (no login needed)...\")\nvae = AutoencoderKL.from_pretrained(VAE_ID, torch_dtype=torch.float16).to(device).eval()\nfor p in vae.parameters(): p.requires_grad_(False)\nprint(f\" VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\nLAT_CH = vae.config.latent_channels # 4 for SDXL\nprint(f\" Latent: {LAT_CH} channels, scaling={SCALE}\")\nprint(\"Ready!\")\n"
]
},
{
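A shape-level sketch of what the frozen VAE loaded above does (run on CPU in fp32 here just to illustrate; the notebook keeps it in fp16 on CUDA). Weights download on first use, no token required:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix").eval()
img = torch.randn(1, 3, 256, 256)                 # stand-in image in [-1, 1]
with torch.no_grad():
    lat = vae.encode(img).latent_dist.sample()    # [1, 4, 32, 32]: 4ch, 1/8 res
    rec = vae.decode(lat).sample                  # [1, 3, 256, 256]
print(lat.shape, rec.shape)
```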
@@ -127,7 +126,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\ncfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\nmodel = LiquidGen(**cfg).to(device)\nprint(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\nscheduler = cosine_sched(optimizer, WARMUP_STEPS, MAX_STEPS)\nema = EMAModel(model, EMA_DECAY)\nscaler = GradScaler(\"cuda\")\nfm = FlowMatchingScheduler()\nos.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\nos.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\nprint(f\"Training: {MAX_STEPS} steps, effective batch {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n"
+
"cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\ncfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\ncfg[\"in_channels\"] = LAT_CH # Match VAE latent channels\nmodel = LiquidGen(**cfg).to(device)\nprint(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\nscheduler = cosine_sched(optimizer, WARMUP_STEPS, MAX_STEPS)\nema = EMAModel(model, EMA_DECAY)\nscaler = GradScaler(\"cuda\")\nfm = FlowMatchingScheduler()\nos.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\nos.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\nprint(f\"Training: {MAX_STEPS} steps, effective batch {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\nif torch.cuda.is_available():\n print(f\"VRAM used: {torch.cuda.memory_allocated()/1024**3:.2f} GB\")\n"
]
},
{
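`cosine_sched` above yields a multiplier on the base LR: linear warmup to 1.0 over `WARMUP_STEPS`, then cosine decay to 0 at `MAX_STEPS`. Spot values with the notebook's defaults (warmup 500, total 20000):

```python
import math

def lr_mult(s, warmup=500, total=20000):   # same formula as cosine_sched
    if s < warmup:
        return s / max(1, warmup)
    return max(0, 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(1, total - warmup))))

for s in (0, 250, 500, 10250, 20000):
    print(s, round(lr_mult(s), 3))         # 0.0, 0.5, 1.0, 0.5, 0.0
```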
@@ -143,7 +142,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"global_step = 0; loss_accum = 0.0; log_losses = []; accum_count = 0\nprint(\"Training started!\\n\")\nt0 = time.time(); model.train()\n\nwhile global_step < MAX_STEPS:\n for imgs, lbls in train_loader:\n if global_step >= MAX_STEPS: break\n imgs = imgs.to(device)\n lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n\n with torch.no_grad():\n lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n lats = (
+
"global_step = 0; loss_accum = 0.0; log_losses = []; accum_count = 0\nprint(\"Training started!\\n\")\nt0 = time.time(); model.train()\n\nwhile global_step < MAX_STEPS:\n for imgs, lbls in train_loader:\n if global_step >= MAX_STEPS: break\n imgs = imgs.to(device)\n lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n\n # Encode with frozen VAE (SDXL: 4 channels, scale only, no shift)\n with torch.no_grad():\n lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n lats = (lats * SCALE).float()\n\n t = fm.sample_t(lats.shape[0], device)\n noise = torch.randn_like(lats)\n xt = fm.add_noise(lats, noise, t)\n vtgt = fm.velocity(lats, noise)\n\n with autocast(\"cuda\"):\n loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n scaler.scale(loss).backward()\n loss_accum += loss.item()\n accum_count += 1\n\n if accum_count % GRADIENT_ACCUMULATION == 0:\n scaler.unscale_(optimizer)\n gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n scaler.step(optimizer); scaler.update(); optimizer.zero_grad()\n scheduler.step(); ema.update(model); global_step += 1\n\n if global_step % LOG_EVERY == 0:\n al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n print(f\"step={global_step:>6d} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G | {time.time()-t0:.0f}s\")\n log_losses.append(al); loss_accum = 0\n if math.isnan(al) or al > 50: print(\"Diverged!\"); break\n\n if global_step % SAMPLE_EVERY == 0:\n ema.apply(model); model.eval()\n ls = IMAGE_SIZE // 8\n sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n samp = fm.sample(model, (4, LAT_CH, ls, ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n with torch.no_grad():\n si = ((vae.decode(samp.half()/SCALE).sample+1)/2).clamp(0,1).float()\n save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n print(f\" Saved samples\")\n ema.restore(model); model.train()\n\n if global_step % SAVE_EVERY == 0:\n torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n print(f\" Checkpoint saved\")\n\ntorch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},\n f\"{OUTPUT_DIR}/checkpoints/final.pt\")\nprint(f\"\\nDone! {global_step} steps in {(time.time()-t0)/60:.1f} min\")\n"
]
},
{
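The loop above checkpoints dicts keyed `model`, `ema`, `step`, `cfg`. A minimal reload sketch for inference, assuming the notebook's `LiquidGen` and `OUTPUT_DIR` are in scope; copying the EMA shadow into the parameters usually samples better than the raw weights:

```python
import torch

ckpt = torch.load(f"{OUTPUT_DIR}/checkpoints/final.pt", map_location="cpu")
model = LiquidGen(**ckpt["cfg"])                # rebuild with saved config
model.load_state_dict(ckpt["model"])

with torch.no_grad():                           # swap in EMA weights
    for name, p in model.named_parameters():
        if name in ckpt["ema"]:
            p.copy_(ckpt["ema"][name])
model.eval()
print(f"Restored step {ckpt['step']}")
```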
@@ -175,7 +174,7 @@
"metadata": {},
"outputs": [],
"source": [
-
"ema.apply(model); model.eval()\nN, STEPS, G = 8, 50, 2.5; ls = IMAGE_SIZE // 8\nif NUM_CLASSES > 0:\n for ci in range(min(NUM_CLASSES, 6)):\n l = torch.full((N,), ci, device=device, dtype=torch.long)\n s = fm.sample(model, (N,
+
"ema.apply(model); model.eval()\nN, STEPS, G = 8, 50, 2.5; ls = IMAGE_SIZE // 8\nif NUM_CLASSES > 0:\n for ci in range(min(NUM_CLASSES, 6)):\n l = torch.full((N,), ci, device=device, dtype=torch.long)\n s = fm.sample(model, (N, LAT_CH, ls, ls), device, STEPS, l, G)\n with torch.no_grad(): i = ((vae.decode(s.half()/SCALE).sample+1)/2).clamp(0,1).float()\n save_image(i, f\"{OUTPUT_DIR}/gen_class{ci}.png\", nrow=4)\n print(f\"Generated class {ci}\")\nelse:\n s = fm.sample(model, (N, LAT_CH, ls, ls), device, STEPS)\n with torch.no_grad(): i = ((vae.decode(s.half()/SCALE).sample+1)/2).clamp(0,1).float()\n save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\nema.restore(model)\nprint(f\"Saved to {OUTPUT_DIR}/\")\n"
]
},
{
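For reference, the guidance line inside `fm.sample` is plain extrapolation, `v = vu + G·(vc - vu)`; with `G = 2.5` the guided velocity overshoots the conditional prediction:

```python
vu, vc, G = 1.0, 2.0, 2.5    # unconditional, conditional, guidance scale
v = vu + G * (vc - vu)
print(v)                     # 3.5: pushed past vc along the vc - vu direction
```

Note the unconditional pass here is `model(x, t, torch.zeros_like(labels))`, so class 0 stands in for the null condition; the learned `null_embed` in `ClassEmbedding` only enters through training-time label dropout.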