{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# LiquidFlow: Liquid Neural Network + Mamba-2 SSD Image Generator\n", "\n", "**Train on Google Colab Free Tier (T4 GPU) | Export for Mobile Deployment**\n", "\n", "LiquidFlow combines:\n", "- **CfC (Closed-form Continuous-time)** Liquid Neural Networks — adaptive time gates\n", "- **Mamba-2 SSD** — linear-time attention replacement, fully parallelizable\n", "- **Physics-Informed Regularization** — TV loss, spectral constraints\n", "- **TAESD VAE** — Tiny AutoEncoder (< 1M params) for fast encoding\n", "\n", "Based on:\n", "- CfC: Hasani et al., Nature MI 2022\n", "- Mamba-2: Dao & Gu, 2024 \n", "- PINN Diffusion: Bastek & Sun, ICLR 2025\n", "- DiMSUM: NeurIPS 2024\n", "\n", "---\n", "## Quick Start\n", "1. Runtime → Change runtime type → GPU (T4)\n", "2. Run all cells in order\n", "3. Training starts automatically on CIFAR-10\n", "4. Check samples in `./outputs/samples/`" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 1. Install Dependencies (~2 min)\n", "!pip install -q torch torchvision diffusers tqdm pillow numpy\n", "!pip install -q git+https://github.com/huggingface/diffusers.git\n", "\n", "import torch\n", "print(f\"PyTorch: {torch.__version__}\")\n", "print(f\"CUDA available: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n", " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB\")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 2. Clone LiquidFlow Repository\n", "!git clone https://huggingface.co/krystv/LiquidFlow-Gen /content/LiquidFlow\n", "%cd /content/LiquidFlow\n", "\n", "import sys\n", "sys.path.insert(0, '/content/LiquidFlow')\n", "\n", "from liquid_flow.generator import create_liquidflow\n", "from liquid_flow.vae_wrapper import TAESDWrapper\n", "import torch.nn as nn\n", "import torch.nn.functional as F" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 3. Configuration — Adjust these settings!\n", "\n", "# Model size: 'tiny' (~2M), 'small' (~8M), 'base' (~30M)\n", "MODEL_VARIANT = 'small' # @param ['tiny', 'small', 'base']\n", "\n", "# Image size: 128 recommended for T4, 512 needs more VRAM\n", "IMAGE_SIZE = 128 # @param [64, 128, 256, 512]\n", "\n", "# Training\n", "BATCH_SIZE = 32 # @param [8, 16, 32, 64]\n", "EPOCHS = 50 # @param [10, 25, 50, 100]\n", "LEARNING_RATE = 2e-4 # @param [1e-4, 2e-4, 5e-4, 1e-3]\n", "\n", "# Dataset\n", "DATASET = 'cifar10' # @param ['cifar10', 'cifar100', 'stl10']\n", "\n", "# Sampling (DDIM steps)\n", "SAMPLE_EVERY = 5 # @param [1, 5, 10]\n", "SAMPLE_STEPS = 50 # @param [20, 50, 100]\n", "\n", "# Physics regularization weights\n", "PHYSICS_TV_WEIGHT = 0.01\n", "PHYSICS_SPEC_WEIGHT = 0.01\n", "PHYSICS_GRAD_WEIGHT = 0.001\n", "\n", "print(f\"Config: {MODEL_VARIANT} model, {IMAGE_SIZE}px, batch={BATCH_SIZE}, epochs={EPOCHS}, lr={LEARNING_RATE}\")\n", "print(f\"Physics loss: TV={PHYSICS_TV_WEIGHT}, Spec={PHYSICS_SPEC_WEIGHT}, Grad={PHYSICS_GRAD_WEIGHT}\")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 4. 
Load VAE & Create Model\n", "import torch\n", "from torchvision import datasets, transforms\n", "from torch.utils.data import DataLoader\n", "import os\n", "from tqdm import tqdm\n", "\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "\n", "# Load TAESD (Tiny AutoEncoder)\n", "print(\"Loading TAESD VAE...\")\n", "vae = TAESDWrapper.load(device)\n", "print(f\"VAE loaded! Latent compression: {IMAGE_SIZE}x{IMAGE_SIZE} → {IMAGE_SIZE//8}x{IMAGE_SIZE//8}\")\n", "\n", "# Create LiquidFlow model\n", "print(f\"Creating {MODEL_VARIANT} LiquidFlow model...\")\n", "model = create_liquidflow(\n", " variant=MODEL_VARIANT,\n", " image_size=IMAGE_SIZE,\n", " physics_weights={\n", " 'tv': PHYSICS_TV_WEIGHT,\n", " 'cons': 0.001,\n", " 'spec': PHYSICS_SPEC_WEIGHT,\n", " 'grad': PHYSICS_GRAD_WEIGHT,\n", " },\n", ")\n", "model = model.to(device)\n", "\n", "n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n", "print(f\"Model: {n_params:,} parameters ({n_params/1e6:.1f}M)\")\n", "\n", "# Memory estimate\n", "latent_size = IMAGE_SIZE // 8\n", "mem_per_sample = latent_size * latent_size * 4 * 4 / 1e6 # MB\n", "print(f\"Memory per sample: {mem_per_sample:.1f} MB\")\n", "print(f\"Estimated batch memory: {mem_per_sample * BATCH_SIZE:.1f} MB\")\n", "print(f\"T4 VRAM: 15 GB — should fit!\" if mem_per_sample * BATCH_SIZE < 10 else \"Watch memory!\")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 5. Load Dataset\n", "\n", "transform = transforms.Compose([\n", " transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.5], [0.5]),\n", "])\n", "\n", "if DATASET == 'cifar10':\n", " dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)\n", "elif DATASET == 'cifar100':\n", " dataset = datasets.CIFAR100(root='./data', train=True, download=True, transform=transform)\n", "elif DATASET == 'stl10':\n", " dataset = datasets.STL10(root='./data', split='train', download=True, transform=transform)\n", "else:\n", " raise ValueError(f\"Unknown dataset: {DATASET}\")\n", "\n", "dataloader = DataLoader(\n", " dataset, batch_size=BATCH_SIZE, shuffle=True,\n", " num_workers=min(4, os.cpu_count() or 1),\n", " pin_memory=True, drop_last=True,\n", ")\n", "\n", "print(f\"Dataset: {DATASET}\")\n", "print(f\"Images: {len(dataset):,}, Batches per epoch: {len(dataloader)}\")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 6. 
Training Loop\n", "\n", "from torchvision.utils import save_image\n", "import math\n", "\n", "os.makedirs('./outputs/samples', exist_ok=True)\n", "os.makedirs('./outputs/checkpoints', exist_ok=True)\n", "\n", "# Optimizer\n", "optimizer = torch.optim.AdamW(\n", " model.parameters(),\n", " lr=LEARNING_RATE,\n", " betas=(0.9, 0.999),\n", " weight_decay=1e-4,\n", ")\n", "\n", "# Cosine LR scheduler\n", "scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n", " optimizer, T_max=EPOCHS * len(dataloader)\n", ")\n", "\n", "# AMP\n", "use_amp = device.type == 'cuda'\n", "scaler = torch.cuda.amp.GradScaler() if use_amp else None\n", "\n", "print(f\"Training: {EPOCHS} epochs, LR={LEARNING_RATE}, AMP={use_amp}\")\n", "print(\"=\"*60)\n", "\n", "global_step = 0\n", "best_loss = float('inf')\n", "\n", "for epoch in range(EPOCHS):\n", " model.train()\n", " epoch_total = 0\n", " \n", " pbar = tqdm(dataloader, desc=f\"Epoch {epoch+1}/{EPOCHS}\")\n", " \n", " for images, _ in pbar:\n", " images = images.to(device)\n", " \n", " # Encode to latent\n", " with torch.no_grad():\n", " latents = TAESDWrapper.encode(vae, images)\n", " \n", " # Training step with physics regularization\n", " loss_dict = model.training_step(latents, optimizer, scaler, use_amp)\n", " \n", " # Track\n", " total_loss = loss_dict['total']\n", " epoch_total += total_loss\n", " \n", " # Update scheduler\n", " scheduler.step()\n", " \n", " # Progress bar\n", " pbar.set_postfix({\n", " 'loss': f\"{total_loss:.4f}\",\n", " 'diff': f\"{loss_dict.get('diffusion', 0):.4f}\",\n", " 'phys': f\"{loss_dict.get('physics', 0):.4f}\",\n", " 'lr': f\"{optimizer.param_groups[0]['lr']:.2e}\",\n", " })\n", " \n", " global_step += 1\n", " \n", " avg_loss = epoch_total / len(dataloader)\n", " print(f\"Epoch {epoch+1}: avg_loss={avg_loss:.4f}\")\n", " \n", " # Generate samples\n", " if (epoch + 1) % SAMPLE_EVERY == 0 or epoch == EPOCHS - 1:\n", " print(\" Generating samples...\")\n", " model.eval()\n", " with torch.no_grad():\n", " latents_gen = model.sample(\n", " batch_size=16,\n", " steps=SAMPLE_STEPS,\n", " ddim=True,\n", " progress=False,\n", " )\n", " images_gen = TAESDWrapper.decode(vae, latents_gen)\n", " \n", " save_image(\n", " images_gen, f'./outputs/samples/epoch_{epoch+1:03d}.png',\n", " nrow=4, normalize=True, value_range=(-1, 1)\n", " )\n", " print(f\" Saved to ./outputs/samples/epoch_{epoch+1:03d}.png\")\n", " \n", " # Save checkpoint\n", " if (epoch + 1) % 10 == 0 or epoch == EPOCHS - 1:\n", " torch.save({\n", " 'epoch': epoch + 1,\n", " 'model_state_dict': model.state_dict(),\n", " 'optimizer_state_dict': optimizer.state_dict(),\n", " 'loss': avg_loss,\n", " }, f'./outputs/checkpoints/epoch_{epoch+1:03d}.pt')\n", " \n", " if avg_loss < best_loss:\n", " best_loss = avg_loss\n", " torch.save(model.state_dict(), './outputs/checkpoints/best_model.pt')\n", "\n", "print(\"=\"*60)\n", "print(f\"Training complete! Best loss: {best_loss:.4f}\")\n", "print(f\"Checkpoints saved to ./outputs/checkpoints/\")\n", "print(f\"Samples saved to ./outputs/samples/\")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 7. 
Generate & Display Samples\n", "\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import glob\n", "\n", "# Load latest sample\n", "sample_files = sorted(glob.glob('./outputs/samples/epoch_*.png'))\n", "if sample_files:\n", " latest = sample_files[-1]\n", " img = Image.open(latest)\n", " plt.figure(figsize=(12, 12))\n", " plt.imshow(img)\n", " plt.title(f'LiquidFlow Samples — {MODEL_VARIANT} model, {IMAGE_SIZE}px')\n", " plt.axis('off')\n", " plt.show()\n", "else:\n", " print(\"No samples generated yet. Train for more epochs!\")" ], "outputs": [] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# @title 8. Export Model for Mobile (ONNX)\n", "\n", "# LiquidFlow can be exported to ONNX for mobile deployment\n", "# since it uses pure PyTorch (no custom CUDA kernels)\n", "\n", "def export_to_onnx(model, output_path='liquidflow_model.onnx', image_size=128):\n", " \"\"\"Export LiquidFlow to ONNX for mobile deployment.\"\"\"\n", " model = model.cpu()\n", " model.eval()\n", " \n", " latent_size = image_size // 8\n", " \n", " # Dummy inputs\n", " x = torch.randn(1, 4, latent_size, latent_size)\n", " t = torch.tensor([500], dtype=torch.long)\n", " \n", " # Export\n", " torch.onnx.export(\n", " model,\n", " (x, t),\n", " output_path,\n", " input_names=['noisy_latent', 'timestep'],\n", " output_names=['predicted_noise'],\n", " dynamic_axes={\n", " 'noisy_latent': {0: 'batch'},\n", " 'predicted_noise': {0: 'batch'},\n", " },\n", " opset_version=14,\n", " )\n", " \n", " import os\n", " size_mb = os.path.getsize(output_path) / 1e6\n", " print(f\"ONNX model exported to {output_path} ({size_mb:.1f} MB)\")\n", " return output_path\n", "\n", "# Load best model and export\n", "best_model_path = './outputs/checkpoints/best_model.pt'\n", "if os.path.exists(best_model_path):\n", " model.load_state_dict(torch.load(best_model_path, map_location='cpu'))\n", " export_to_onnx(model, 'liquidflow_128.onnx', IMAGE_SIZE)\n", " print(\"Ready for mobile deployment!\")\n", "else:\n", " print(\"Train model first before exporting.\")" ], "outputs": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Architecture Details\n", "\n", "### LiquidFlow Block Architecture\n", "```\n", "Input → [CfC Gate → Mamba-2 SSD → CfC Gate] → Output\n", " ↑ ↑\n", " Adaptive time gate Gated output\n", "```\n", "\n", "### CfC (Closed-form Continuous-time) Cell\n", "```\n", "h(t) = σ(-f(x,I;θ_f)·t) ⊙ g(x,I;θ_g) + (1-σ(-f(x,I;θ_f)·t)) ⊙ h(x,I;θ_h)\n", "```\n", "- **No ODE solver needed** — 100x faster than Neural ODEs\n", "- Time-continuous gating adaptively controls information flow\n", "- Closed-form solution → stable gradients\n", "\n", "### Mamba-2 SSD (State Space Duality)\n", "```\n", "h_t = A_t * h_{t-1} + B_t * x_t\n", "y_t = C_t^T * h_t\n", "```\n", "- **O(N) linear complexity** vs Transformers O(N²)\n", "- **Parallelizable** via associative scan (Blelloch)\n", "- **Scalar-A** formulation enables chunk-scan optimization\n", "- Pure PyTorch — no CUDA kernels needed\n", "\n", "### Physics-Informed Regularization\n", "- **Total Variation**: `L_TV = ||∇_x x̂||₁ + ||∇_y x̂||₁`\n", "- **Spectral**: Penalize high-frequency artifacts\n", "- **Gradient**: Sobolev norm for stable training\n", "- Pattern from Bastek & Sun (ICLR 2025): physics loss as training-only regularizer\n", "\n", "### Model Variants\n", "| Variant | Params | Hidden Dim | Stages | Blocks | T4 VRAM |\n", "|---------|--------|------------|--------|--------|---------|\n", "| Tiny | ~2M | 128 | 2 | 2 | < 2 GB |\n", "| 
Small | ~8M | 256 | 4 | 4 | ~4 GB |\n", "| Base | ~30M | 384 | 6 | 6 | ~8 GB |" ] },
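{ "cell_type": "markdown", "metadata": {}, "source": [ "## Appendix: Standalone Reference Sketches\n", "\n", "The three cells below are minimal, self-contained sketches of the ideas summarized above (the CfC gate, the scalar-A SSD recurrence, and the physics-style regularizers), written in plain PyTorch so they run on their own. They are illustrative only and are **not** the `liquid_flow` package's implementation; every name in them (`CfCGate`, `ssd_scan_reference`, `tv_loss`, `grad_loss`, `spectral_loss`) is invented for the example." ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Appendix A. Minimal CfC-style gate (illustrative sketch, not the liquid_flow implementation).\n", "# Implements the closed-form gate written above:\n", "#   h(t) = sigma(-f(x) * t) * g(x) + (1 - sigma(-f(x) * t)) * h(x)\n", "# In the real model the f/g/h heads would be richer; plain Linear layers keep the sketch short.\n", "import torch\n", "import torch.nn as nn\n", "\n", "class CfCGate(nn.Module):\n", "    def __init__(self, dim):\n", "        super().__init__()\n", "        self.f = nn.Linear(dim, dim)  # sets how fast the gate moves with time\n", "        self.g = nn.Linear(dim, dim)  # branch weighted by the gate\n", "        self.h = nn.Linear(dim, dim)  # branch weighted by (1 - gate)\n", "\n", "    def forward(self, x, t):\n", "        # x: (batch, dim) features; t: (batch, 1) continuous time, e.g. a normalized diffusion step\n", "        gate = torch.sigmoid(-self.f(x) * t)\n", "        return gate * self.g(x) + (1.0 - gate) * self.h(x)\n", "\n", "gate_demo = CfCGate(64)\n", "print(gate_demo(torch.randn(2, 64), torch.rand(2, 1)).shape)  # torch.Size([2, 64])" ], "outputs": [] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Appendix B. Scalar-A SSM recurrence, the update behind Mamba-2 SSD (sequential reference; sketch only).\n", "#   h_t = a_t * h_{t-1} + b_t * x_t\n", "#   y_t = c_t . h_t\n", "# The real layer evaluates this with a parallel / chunked scan; the O(L) loop below is the\n", "# reference you would use to sanity-check such an implementation.\n", "import torch\n", "\n", "def ssd_scan_reference(a, b, c, x):\n", "    # a: (B, L) per-step scalar decay in (0, 1); b, c: (B, L, D); x: (B, L) scalar input channel\n", "    B, L, D = b.shape\n", "    h = torch.zeros(B, D)\n", "    ys = []\n", "    for t in range(L):\n", "        h = a[:, t:t+1] * h + b[:, t] * x[:, t:t+1]  # state update\n", "        ys.append((c[:, t] * h).sum(-1))             # readout y_t\n", "    return torch.stack(ys, dim=1)                    # (B, L)\n", "\n", "y = ssd_scan_reference(torch.rand(2, 8) * 0.9, torch.randn(2, 8, 16), torch.randn(2, 8, 16), torch.randn(2, 8))\n", "print(y.shape)  # torch.Size([2, 8])" ], "outputs": [] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "source": [ "# Appendix C. Physics-style regularizers in the spirit of the losses listed above\n", "# (sketch only; the exact definitions and weights inside liquid_flow may differ).\n", "import torch\n", "import torch.nn.functional as F\n", "\n", "def tv_loss(x):\n", "    # Anisotropic total variation, averaged over pixels\n", "    dx = (x[..., :, 1:] - x[..., :, :-1]).abs().mean()\n", "    dy = (x[..., 1:, :] - x[..., :-1, :]).abs().mean()\n", "    return dx + dy\n", "\n", "def grad_loss(x):\n", "    # Squared finite differences: a discrete Sobolev-style smoothness penalty\n", "    dx = x[..., :, 1:] - x[..., :, :-1]\n", "    dy = x[..., 1:, :] - x[..., :-1, :]\n", "    return dx.pow(2).mean() + dy.pow(2).mean()\n", "\n", "def spectral_loss(pred, target):\n", "    # One common variant: match log-magnitude spectra to discourage high-frequency artifacts\n", "    fp = torch.fft.rfft2(pred).abs().clamp_min(1e-8).log()\n", "    ft = torch.fft.rfft2(target).abs().clamp_min(1e-8).log()\n", "    return F.l1_loss(fp, ft)\n", "\n", "x0 = torch.randn(4, 4, 16, 16)  # e.g. predicted latents (B, C, H, W)\n", "xt = torch.randn(4, 4, 16, 16)  # e.g. reference latents\n", "print(tv_loss(x0).item(), grad_loss(x0).item(), spectral_loss(x0, xt).item())" ], "outputs": [] } ], "metadata": { "colab": { "name": "LiquidFlow: LiquidNN + Mamba-2 SSD Image Generator", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 }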