krystv committed
Commit c2b4760 · verified · 1 Parent(s): 1a08b06

v2: Add VAE latent training, fix datasets, streaming support

Files changed (1)
  1. LiquidDiffusion_Training.ipynb +13 -15
LiquidDiffusion_Training.ipynb CHANGED
@@ -7,27 +7,25 @@
7
  "accelerator": "GPU"
8
  },
9
  "cells": [
10
- {"cell_type": "markdown", "metadata": {}, "source": ["# \ud83c\udf0a LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n", "\n", "A **novel** image generation model combining:\n", "- **Liquid Neural Networks** (CfC \u2014 Closed-form Continuous-depth) for adaptive processing\n", "- **Rectified Flow** for simple, stable training (MSE velocity prediction)\n", "- **Pretrained SD-VAE** for efficient latent-space training (4ch, 8\u00d7 downscale)\n", "- **Zero attention** \u2014 fully convolutional + multi-scale spatial mixing\n", "- **Fully parallelizable** \u2014 no ODE loops, no recurrence\n", "\n", "### Key Innovation\n", "Diffusion timestep = liquid time constant. CfC gate `\u03c3(-f\u00b7t)` adapts behavior to noise level.\n", "\n", "### References\n", "- [CfC Networks (Nature MI 2022)](https://arxiv.org/abs/2106.13898)\n", "- [LiquidTAD (2024)](https://arxiv.org/abs/2604.18274) | [USM (CVPR 2025)](https://arxiv.org/abs/2504.13499)\n", "- [Rectified Flow (ICLR 2023)](https://arxiv.org/abs/2209.03003)\n", "- **Repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)"]},
11
  {"cell_type": "markdown", "metadata": {}, "source": ["## \u2699\ufe0f Configuration"]},
12
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["#@title \u2699\ufe0f Configuration { display-mode: \"form\" }\n", "\n", "#@markdown ### Model\n", "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'custom']\n", "#@markdown > `tiny`=23M (256px, ~6GB) | `small`=69M (256px, ~10GB)\n", "CUSTOM_CHANNELS = [48, 96, 192]\n", "CUSTOM_BLOCKS = [1, 2, 3]\n", "CUSTOM_T_DIM = 192\n", "\n", "#@markdown ### Resolution\n", "IMAGE_SIZE = 256 #@param [128, 256, 512] {type:\"raw\"}\n", "\n", "#@markdown ### VAE (Latent Space)\n", "USE_VAE = True #@param {type:\"boolean\"}\n", "#@markdown > Pretrained SD-VAE encodes images to 4ch latents (8\u00d7 smaller). **Highly recommended.**\n", "VAE_MODEL = 'stabilityai/sd-vae-ft-mse' #@param ['stabilityai/sd-vae-ft-mse', 'madebyollin/sdxl-vae-fp16-fix']\n", "PRECACHE_LATENTS = True #@param {type:\"boolean\"}\n", "#@markdown > Pre-encode all images once. Frees ~160MB VAE VRAM during training.\n", "\n", "#@markdown ### Dataset\n", "DATASET = 'nielsr/CelebA-faces' #@param ['nielsr/CelebA-faces', 'huggan/flowers-102-categories', 'reach-vb/pokemon-blip-captions', 'huggan/anime-faces', 'huggan/AFHQv2', 'Norod78/cartoon-blip-captions']\n", "#@markdown > All verified \u2713 | CelebA=202K faces | flowers=8K | pokemon=833 | anime=21K | AFHQ=16K animals | cartoon=2K\n", "IMAGE_COLUMN = 'image'\n", "MAX_SAMPLES = None # e.g. 5000 for quick test, None=full\n", "\n", "#@markdown ### Training\n", "BATCH_SIZE = 8 #@param {type:\"integer\"}\n", "LEARNING_RATE = 1e-4 #@param {type:\"number\"}\n", "WEIGHT_DECAY = 0.01 #@param {type:\"number\"}\n", "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n", "GRAD_CLIP = 1.0 #@param {type:\"number\"}\n", "EMA_DECAY = 0.9999 #@param {type:\"number\"}\n", "NUM_WORKERS = 2\n", "TIME_SAMPLING = 'logit_normal' #@param ['uniform', 'logit_normal']\n", "USE_AMP = True #@param {type:\"boolean\"}\n", "AMP_DTYPE = 'float16' #@param ['float16', 'bfloat16']\n", "\n", "#@markdown ### Sampling & Logging\n", "SAMPLE_EVERY = 500 #@param {type:\"integer\"}\n", "NUM_SAMPLE_IMAGES = 8 #@param {type:\"integer\"}\n", "NUM_EULER_STEPS = 50 #@param {type:\"integer\"}\n", "SAVE_EVERY = 2000 #@param {type:\"integer\"}\n", "OUTPUT_DIR = './outputs'\n", "RESUME_FROM = None\n", "LOG_EVERY = 50\n", "\n", "LATENT_SIZE = IMAGE_SIZE // 8 if USE_VAE else IMAGE_SIZE\n", "IN_CHANNELS = 4 if USE_VAE else 3\n", "print(f\"Config: {MODEL_SIZE} | {IMAGE_SIZE}px {'(latent '+str(LATENT_SIZE)+'px)' if USE_VAE else '(pixel)'} | {DATASET}\")\n", "print(f\"Training: bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, AMP={USE_AMP}\")"]},
13
- {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udce6 Install & Check GPU"]},
14
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["!pip install -q datasets diffusers accelerate huggingface_hub Pillow matplotlib transformers\n", "import torch\n", "print(f\"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_mem/1e9:.1f}GB\")\n", "else:\n", " print(\"\u26a0\ufe0f No GPU! Enable via Runtime \u2192 Change runtime type.\")"]},
15
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfd7\ufe0f Model Architecture"]},
16
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import math, copy, os, time\nfrom glob import glob\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset\nfrom torchvision import transforms\nfrom torchvision.utils import save_image, make_grid\n\nclass SinusoidalTimeEmbedding(nn.Module):\n def __init__(self, dim, max_period=10000):\n super().__init__()\n self.dim, self.mp = dim, max_period\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n def forward(self, t):\n h = self.dim // 2\n f = torch.exp(-math.log(self.mp)*torch.arange(h, device=t.device, dtype=t.dtype)/h)\n e = torch.cat([torch.cos(t[:,None]*f[None]), torch.sin(t[:,None]*f[None])], -1)\n if self.dim%2: e = F.pad(e,(0,1))\n return self.mlp(e)\n\nclass AdaLN(nn.Module):\n def __init__(self, dim, cd):\n super().__init__()\n ng = min(32, dim)\n while dim%ng!=0: ng-=1\n self.norm = nn.GroupNorm(ng, dim, affine=False)\n self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cd, dim*2))\n def forward(self, x, te):\n s, sh = self.proj(te).chunk(2,1)\n return self.norm(x)*(1+s[:,:,None,None])+sh[:,:,None,None]\n\nclass ParallelCfCBlock(nn.Module):\n def __init__(self, dim, td, er=2.0, ks=7, dr=0.0):\n super().__init__()\n hid = int(dim*er)\n self.bdw = nn.Conv2d(dim, dim, ks, padding=ks//2, groups=dim)\n self.bpw = nn.Conv2d(dim, hid, 1)\n self.ba = nn.SiLU()\n self.fh = nn.Conv2d(hid, dim, 1)\n self.gh = nn.Sequential(nn.Conv2d(hid,hid,ks,padding=ks//2,groups=hid),nn.SiLU(),nn.Conv2d(hid,dim,1))\n self.hh = nn.Sequential(nn.Conv2d(hid,hid,ks,padding=ks//2,groups=hid),nn.SiLU(),nn.Conv2d(hid,dim,1))\n self.ta, self.tb = nn.Linear(td, dim), nn.Linear(td, dim)\n self.rho = nn.Parameter(torch.zeros(1,dim,1,1))\n self.og = nn.Sequential(nn.SiLU(), nn.Linear(td, dim))\n self.do = nn.Dropout(dr) if dr>0 else nn.Identity()\n def forward(self, x, te):\n res = x\n bb = self.ba(self.bpw(self.bdw(x)))\n f,g,h = self.fh(bb), self.gh(bb), self.hh(bb)\n gt = torch.sigmoid(self.ta(te)[:,:,None,None]*f - self.tb(te)[:,:,None,None])\n co = self.do(gt*g + (1-gt)*h)\n lam = F.softplus(self.rho)+1e-6\n al = torch.exp(-lam*te.mean(1,keepdim=True)[:,:,None,None].abs().clamp(min=0.01))\n return (al*res+(1-al)*co)*torch.sigmoid(self.og(te))[:,:,None,None]\n\nclass MultiScaleSpatialMix(nn.Module):\n def __init__(self, dim, td):\n super().__init__()\n self.d3=nn.Conv2d(dim,dim,3,padding=1,groups=dim)\n self.d5=nn.Conv2d(dim,dim,5,padding=2,groups=dim)\n self.d7=nn.Conv2d(dim,dim,7,padding=3,groups=dim)\n self.gp=nn.AdaptiveAvgPool2d(1); self.gpj=nn.Conv2d(dim,dim,1)\n self.mg=nn.Conv2d(dim*4,dim,1); self.ac=nn.SiLU(); self.an=AdaLN(dim,td)\n def forward(self, x, te):\n xn=self.an(x,te)\n return x+self.ac(self.mg(torch.cat([self.d3(xn),self.d5(xn),self.d7(xn),self.gpj(self.gp(xn)).expand_as(xn)],1)))\n\nclass LiquidDiffusionBlock(nn.Module):\n def __init__(self, dim, td, er=2.0, ks=7, dr=0.0):\n super().__init__()\n self.a1=AdaLN(dim,td); self.cfc=ParallelCfCBlock(dim,td,er,ks,dr)\n self.sm=MultiScaleSpatialMix(dim,td); self.a2=AdaLN(dim,td)\n ff=int(dim*er); self.ff=nn.Sequential(nn.Conv2d(dim,ff,1),nn.SiLU(),nn.Conv2d(ff,dim,1))\n self.rs=nn.Parameter(torch.ones(1)*0.1)\n def forward(self, x, te):\n x=x+self.rs*self.cfc(self.a1(x,te),te); x=self.sm(x,te)\n return x+self.rs*self.ff(self.a2(x,te))\n\nclass DS(nn.Module):\n def __init__(self,i,o): super().__init__(); self.c=nn.Conv2d(i,o,3,stride=2,padding=1)\n def forward(self,x): 
return self.c(x)\nclass US(nn.Module):\n def __init__(self,i,o): super().__init__(); self.c=nn.Conv2d(i,o,3,padding=1)\n def forward(self,x): return self.c(F.interpolate(x,scale_factor=2,mode='nearest'))\nclass SF(nn.Module):\n def __init__(self,d,td): super().__init__(); self.p=nn.Conv2d(d*2,d,1); self.g=nn.Sequential(nn.SiLU(),nn.Linear(td,d),nn.Sigmoid())\n def forward(self,x,sk,te): m=self.p(torch.cat([x,sk],1)); g=self.g(te)[:,:,None,None]; return m*g+x*(1-g)\n\nclass LiquidDiffusionUNet(nn.Module):\n def __init__(self, in_ch=3, chs=None, bps=None, td=256, er=2.0, ks=7, dr=0.0):\n super().__init__()\n chs=chs or [64,128,256]; bps=bps or [2,2,4]\n assert len(chs)==len(bps)\n self.chs,self.ns=chs,len(chs)\n self.te=SinusoidalTimeEmbedding(td)\n self.st=nn.Sequential(nn.Conv2d(in_ch,chs[0],3,padding=1),nn.SiLU(),nn.Conv2d(chs[0],chs[0],3,padding=1))\n self.enc,self.dn=nn.ModuleList(),nn.ModuleList()\n for i in range(self.ns):\n self.enc.append(nn.ModuleList([LiquidDiffusionBlock(chs[i],td,er,ks,dr) for _ in range(bps[i])]))\n if i<self.ns-1: self.dn.append(DS(chs[i],chs[i+1]))\n self.bot=nn.ModuleList([LiquidDiffusionBlock(chs[-1],td,er,ks,dr) for _ in range(2)])\n self.dec,self.up_,self.sf_=nn.ModuleList(),nn.ModuleList(),nn.ModuleList()\n for i in range(self.ns-1,-1,-1):\n if i<self.ns-1: self.up_.append(US(chs[i+1],chs[i])); self.sf_.append(SF(chs[i],td))\n self.dec.append(nn.ModuleList([LiquidDiffusionBlock(chs[i],td,er,ks,dr) for _ in range(bps[i])]))\n hg=min(32,chs[0])\n while chs[0]%hg!=0: hg-=1\n self.hd=nn.Sequential(nn.GroupNorm(hg,chs[0]),nn.SiLU(),nn.Conv2d(chs[0],in_ch,3,padding=1))\n nn.init.zeros_(self.hd[-1].weight); nn.init.zeros_(self.hd[-1].bias)\n def forward(self, x, t):\n te=self.te(t); h=self.st(x); sk=[]\n for i in range(self.ns):\n for b in self.enc[i]: h=b(h,te)\n sk.append(h)\n if i<self.ns-1: h=self.dn[i](h)\n for b in self.bot: h=b(h,te)\n ui=0\n for di in range(self.ns):\n si=self.ns-1-di\n if di>0: h=self.up_[ui](h); h=self.sf_[ui](h,sk[si],te); ui+=1\n for b in self.dec[di]: h=b(h,te)\n return self.hd(h)\n def count_params(self): return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n\nprint('\u2705 Model architecture defined.')"]},
17
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udd27 Build Model + Load VAE"]},
18
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["device = 'cuda' if torch.cuda.is_available() else 'cpu'\nCFGS = {'tiny': dict(chs=[64,128,256], bps=[2,2,4], td=256), 'small': dict(chs=[96,192,384], bps=[2,3,6], td=384)}\nif MODEL_SIZE=='custom': cfg=dict(chs=CUSTOM_CHANNELS,bps=CUSTOM_BLOCKS,td=CUSTOM_T_DIM)\nelse: cfg=CFGS[MODEL_SIZE]\nmodel = LiquidDiffusionUNet(in_ch=IN_CHANNELS, **cfg).to(device)\ntp,_=model.count_params()\nprint(f'Model: {MODEL_SIZE} | {tp:,} params ({tp/1e6:.1f}M) | in_ch={IN_CHANNELS}')\n\nvae=None; vae_scale=1.0\nif USE_VAE:\n from diffusers import AutoencoderKL\n print(f'Loading VAE: {VAE_MODEL}...')\n vae = AutoencoderKL.from_pretrained(VAE_MODEL, torch_dtype=torch.float16 if device=='cuda' else torch.float32)\n vae = vae.to(device).eval(); vae.requires_grad_(False)\n vae_scale = vae.config.scaling_factor\n print(f'VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params, latent_ch={vae.config.latent_channels}, scale={vae_scale}')\n print(f' {IMAGE_SIZE}px \u2192 {LATENT_SIZE}px latent (8\u00d7 downsample)')\n\nwith torch.no_grad():\n tx=torch.randn(1,IN_CHANNELS,LATENT_SIZE,LATENT_SIZE,device=device)\n assert model(tx, torch.tensor([0.5],device=device)).shape==tx.shape\n print(f'Forward OK: {tx.shape}')\n del tx\nif device=='cuda': torch.cuda.empty_cache(); print(f'VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB')"]},
19
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcca Load Dataset"]},
20
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["from PIL import Image\nfrom datasets import load_dataset\nimport matplotlib.pyplot as plt\n\nclass ImageDS(Dataset):\n def __init__(self, ds, sz, col='image'):\n self.ds, self.col = ds, col\n self.tf = transforms.Compose([transforms.Resize(sz, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(sz), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __len__(self): return len(self.ds)\n def __getitem__(self, i):\n img = self.ds[i][self.col]\n if not hasattr(img,'convert'): img=Image.fromarray(img)\n return self.tf(img.convert('RGB'))\n\nprint(f'Loading: {DATASET}')\nraw = load_dataset(DATASET, split='train')\nif MAX_SAMPLES: raw = raw.select(range(min(MAX_SAMPLES, len(raw))))\nprint(f' {len(raw):,} images')\ndataset = ImageDS(raw, IMAGE_SIZE, IMAGE_COLUMN)\ndata_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True, persistent_workers=True)\nprint(f' {len(data_loader):,} steps/epoch | ~{len(data_loader)*NUM_EPOCHS:,} total steps')\n\nsb=next(iter(data_loader))\nfig,axes=plt.subplots(1,min(8,BATCH_SIZE),figsize=(16,2.5))\nfor i,ax in enumerate(axes if hasattr(axes,'__len__') else [axes]): ax.imshow((sb[i].permute(1,2,0)*0.5+0.5).clamp(0,1)); ax.axis('off')\nplt.suptitle(f'Training samples ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()"]},
21
- {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\uddc3\ufe0f Pre-cache Latents (if VAE enabled)"]},
22
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["cached_latents = None\ntrain_loader = data_loader\n\nif USE_VAE and PRECACHE_LATENTS:\n print(f'Pre-encoding {len(dataset):,} images...')\n cl = DataLoader(dataset, batch_size=BATCH_SIZE*2, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)\n all_z = []\n vd = torch.float16 if device=='cuda' else torch.float32\n t0 = time.time()\n with torch.no_grad():\n for bi, imgs in enumerate(cl):\n z = vae.encode(imgs.to(device, dtype=vd)).latent_dist.sample() * vae_scale\n all_z.append(z.cpu().float())\n if (bi+1)%50==0: print(f' {(bi+1)*BATCH_SIZE*2:,}/{len(dataset):,}')\n cached_latents = torch.cat(all_z)\n print(f' Done in {time.time()-t0:.0f}s | Shape: {cached_latents.shape} | {cached_latents.numel()*4/1e9:.2f}GB')\n vae = vae.cpu()\n if device=='cuda': torch.cuda.empty_cache(); print(f' VAE \u2192 CPU. GPU VRAM: {torch.cuda.memory_allocated()/1e9:.2f}GB')\n class LatDS(Dataset):\n def __init__(self,z): self.z=z\n def __len__(self): return len(self.z)\n def __getitem__(self,i): return self.z[i]\n train_loader = DataLoader(LatDS(cached_latents), batch_size=BATCH_SIZE, shuffle=True, num_workers=0, pin_memory=True, drop_last=True)\n print(f' Latent loader: {len(train_loader)} steps/epoch')\nelif USE_VAE:\n print('Online VAE encoding (VAE stays on GPU)')\nelse:\n print('Pixel-space training (no VAE)')"]},
23
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\ude80 Training"]},
24
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import matplotlib.pyplot as plt\nos.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True); os.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=(0.9,0.999))\ntotal_steps = len(train_loader)*NUM_EPOCHS\nwarmup = min(1000, total_steps//10)\ndef lrl(step):\n if step<warmup: return step/max(1,warmup)\n return max(0.0, 0.5*(1+math.cos(math.pi*(step-warmup)/max(1,total_steps-warmup))))\nsched = torch.optim.lr_scheduler.LambdaLR(optimizer, lrl)\n\nema = copy.deepcopy(model).eval()\nfor p in ema.parameters(): p.requires_grad_(False)\nscaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device=='cuda'))\namp_dt = getattr(torch, AMP_DTYPE) if USE_AMP and device=='cuda' else torch.float32\n\ndef st(bs):\n e=1e-5\n if TIME_SAMPLING=='uniform': return torch.rand(bs,device=device)*(1-2*e)+e\n return torch.sigmoid(torch.randn(bs,device=device)).clamp(e,1-e)\n\ngstep,start_ep,all_losses,ep_losses=0,0,[],[]\nif RESUME_FROM and os.path.exists(RESUME_FROM):\n ck=torch.load(RESUME_FROM,map_location=device,weights_only=False)\n model.load_state_dict(ck['model']); ema.load_state_dict(ck['ema_model']); optimizer.load_state_dict(ck['optimizer'])\n gstep=ck.get('step',0); start_ep=ck.get('epoch',0); all_losses=ck.get('losses',[])\n print(f'Resumed from step {gstep}')\n\n@torch.no_grad()\ndef gen_samples(step):\n ema.eval()\n z=torch.randn(NUM_SAMPLE_IMAGES,IN_CHANNELS,LATENT_SIZE,LATENT_SIZE,device=device)\n dt=1.0/NUM_EULER_STEPS\n for i in range(NUM_EULER_STEPS,0,-1):\n t=torch.full((NUM_SAMPLE_IMAGES,),i/NUM_EULER_STEPS,device=device)\n with torch.amp.autocast(device,dtype=amp_dt,enabled=USE_AMP and device=='cuda'): v=ema(z,t)\n if USE_AMP and amp_dt==torch.float16: v=v.float()\n z=z-v*dt\n z=z.clamp(-3,3)\n if USE_VAE:\n _v=vae.to(device); vd=torch.float16 if device=='cuda' else torch.float32\n imgs=_v.decode(z.to(vd)/vae_scale).sample.float()\n if PRECACHE_LATENTS: vae.cpu()\n else: imgs=z\n imgs=imgs.clamp(-1,1)\n save_image(make_grid(imgs*0.5+0.5,nrow=int(math.ceil(math.sqrt(NUM_SAMPLE_IMAGES))),padding=2),f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n return imgs\n\nprint(f'\\n{\"=\"*60}\\nTraining: {NUM_EPOCHS} epochs, {total_steps:,} steps\\n{\"=\"*60}\\n')\nt_start=time.time(); online_vae=USE_VAE and not PRECACHE_LATENTS; vd=torch.float16 if device=='cuda' else torch.float32\n\nfor epoch in range(start_ep, NUM_EPOCHS):\n model.train(); el=0\n for batch in train_loader:\n if online_vae:\n with torch.no_grad(): x0=vae.encode(batch.to(device,dtype=vd)).latent_dist.sample()*vae_scale; x0=x0.float()\n else: x0=batch.to(device)\n x1=torch.randn_like(x0); t=st(x0.shape[0]); te=t[:,None,None,None]\n xt=(1-te)*x0+te*x1; vt=x1-x0\n with torch.amp.autocast(device,dtype=amp_dt,enabled=USE_AMP and device=='cuda'):\n vp=model(xt,t); loss=F.mse_loss(vp,vt)\n optimizer.zero_grad(set_to_none=True); scaler.scale(loss).backward()\n if GRAD_CLIP>0: scaler.unscale_(optimizer); torch.nn.utils.clip_grad_norm_(model.parameters(),GRAD_CLIP)\n scaler.step(optimizer); scaler.update(); sched.step()\n with torch.no_grad():\n for ep,mp in zip(ema.parameters(),model.parameters()): ep.data.mul_(EMA_DECAY).add_(mp.data,alpha=1-EMA_DECAY)\n gstep+=1; lv=loss.item(); all_losses.append(lv); el+=lv\n if gstep%LOG_EVERY==0:\n avg=sum(all_losses[-LOG_EVERY:])/LOG_EVERY; lr=sched.get_last_lr()[0]\n sps=gstep/(time.time()-t_start); 
eta=(total_steps-gstep)/max(sps,1e-8)\n vm=f' | VRAM:{torch.cuda.max_memory_allocated()/1e9:.1f}GB' if device=='cuda' else ''\n print(f'Step {gstep:6d}/{total_steps} | Loss:{avg:.4f} | LR:{lr:.2e} | {sps:.1f}it/s | ETA:{eta/60:.0f}m{vm}')\n if gstep%SAMPLE_EVERY==0:\n print(' \\U0001f4f8 Generating...'); samps=gen_samples(gstep)\n fig,axes=plt.subplots(1,min(8,NUM_SAMPLE_IMAGES),figsize=(16,2.5))\n if not hasattr(axes,'__len__'): axes=[axes]\n for i,ax in enumerate(axes):\n if i<samps.shape[0]: ax.imshow((samps[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\n plt.suptitle(f'Step {gstep} | Loss:{lv:.4f}'); plt.tight_layout(); plt.show()\n if gstep%SAVE_EVERY==0:\n cp=f'{OUTPUT_DIR}/checkpoints/step_{gstep:06d}.pt'\n torch.save({'model':model.state_dict(),'ema_model':ema.state_dict(),'optimizer':optimizer.state_dict(),'step':gstep,'epoch':epoch,'losses':all_losses[-2000:],'config':cfg},cp)\n print(f' \\U0001f4be Saved: {cp}')\n ep_losses.append(el/len(train_loader))\n print(f' Epoch {epoch+1}/{NUM_EPOCHS} | Avg loss:{ep_losses[-1]:.4f}')\n\nfp=f'{OUTPUT_DIR}/checkpoints/final.pt'\ntorch.save({'model':model.state_dict(),'ema_model':ema.state_dict(),'step':gstep,'config':cfg,'losses':all_losses[-2000:]},fp)\nprint(f'\\n\\u2705 Done! {fp} | {(time.time()-t_start)/3600:.1f}h')"]},
25
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcc8 Training Curves"]},
26
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import numpy as np\nfig,(a1,a2)=plt.subplots(1,2,figsize=(14,5))\na1.plot(all_losses,alpha=0.3,color='blue',lw=0.5)\nw=min(200,len(all_losses)//5)\nif w>1:\n sm=np.convolve(all_losses,np.ones(w)/w,mode='valid')\n a1.plot(range(w-1,len(all_losses)),sm,color='red',lw=2,label=f'Smooth(w={w})')\na1.set_xlabel('Step');a1.set_ylabel('Loss');a1.set_title('Training Loss');a1.legend();a1.grid(True,alpha=0.3)\nif ep_losses: a2.plot(range(1,len(ep_losses)+1),ep_losses,'o-',color='green'); a2.set_xlabel('Epoch');a2.set_ylabel('Loss');a2.set_title('Per Epoch');a2.grid(True,alpha=0.3)\nplt.tight_layout();plt.show()"]},
27
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfa8 Generate Images"]},
28
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["N_GEN=16; STEPS=50\nprint(f'Generating {N_GEN} images ({STEPS} steps)...')\nema.eval()\nif USE_VAE: vae=vae.to(device)\nwith torch.no_grad():\n z=torch.randn(N_GEN,IN_CHANNELS,LATENT_SIZE,LATENT_SIZE,device=device)\n dt=1.0/STEPS\n for i in range(STEPS,0,-1):\n t=torch.full((N_GEN,),i/STEPS,device=device)\n with torch.amp.autocast(device,dtype=amp_dt,enabled=USE_AMP and device=='cuda'): v=ema(z,t)\n if USE_AMP and amp_dt==torch.float16: v=v.float()\n z=z-v*dt\n if USE_VAE: vdd=torch.float16 if device=='cuda' else torch.float32; gen=vae.decode(z.clamp(-3,3).to(vdd)/vae_scale).sample.float()\n else: gen=z\n gen=gen.clamp(-1,1)\nnr=int(math.ceil(math.sqrt(N_GEN)))\nfig,axes=plt.subplots(nr,nr,figsize=(2.5*nr,2.5*nr))\naxes=axes.flatten() if hasattr(axes,'flatten') else [axes]\nfor i,ax in enumerate(axes):\n if i<N_GEN: ax.imshow((gen[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\nplt.suptitle(f'LiquidDiffusion ({IMAGE_SIZE}px, {STEPS} steps)',fontsize=14);plt.tight_layout();plt.show()\nsave_image(make_grid(gen*0.5+0.5,nrow=nr,padding=2),f'{OUTPUT_DIR}/final_samples.png')\nprint(f'Saved: {OUTPUT_DIR}/final_samples.png')"]},
29
- {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcbe Push to Hub"]},
30
- {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["PUSH=False #@param {type:\"boolean\"}\nHUB_ID='your-username/liquid-diffusion-256' #@param {type:\"string\"}\nif PUSH:\n from huggingface_hub import HfApi\n api=HfApi(); api.create_repo(HUB_ID,exist_ok=True)\n api.upload_file(path_or_fileobj=fp,path_in_repo='model.pt',repo_id=HUB_ID)\n print(f'Pushed: https://huggingface.co/{HUB_ID}')"]},
31
- {"cell_type": "markdown", "metadata": {}, "source": ["---\n", "## \ud83d\udcd6 Architecture Reference\n", "\n", "### CfC Time-Gating\n", "```\n", "gate = \u03c3(time_a(t) \u00b7 f(features) - time_b(t))\n", "out = gate \u00b7 g + (1-gate) \u00b7 h\n", "```\n", "### Liquid Relaxation\n", "```\n", "\u03b1 = exp(-\u03bb\u00b7|t|), out = \u03b1\u00b7input + (1-\u03b1)\u00b7CfC_out\n", "```\n", "High noise \u2192 \u03b1\u22480 \u2192 heavy processing. Low noise \u2192 \u03b1\u22481 \u2192 preserve.\n", "\n", "### VAE: `stabilityai/sd-vae-ft-mse`\n", "83M params, 4ch latents, 8\u00d7 downscale. 256px\u219232\u00d732\u00d74 latent.\n", "\n", "### Verified Datasets\n", "| Dataset | Size | Content |\n", "|---------|------|---------|\n", "| `nielsr/CelebA-faces` | 202K | Celebrity faces |\n", "| `huggan/flowers-102-categories` | 8K | Flowers |\n", "| `reach-vb/pokemon-blip-captions` | 833 | Pokemon art |\n", "| `huggan/anime-faces` | 21K | Anime faces |\n", "| `huggan/AFHQv2` | 16K | Cat/dog/wild |\n", "| `Norod78/cartoon-blip-captions` | 2K | Cartoon characters |"]}
32
  ]
33
  }
 
7
  "accelerator": "GPU"
8
  },
9
  "cells": [
10
+ {"cell_type": "markdown", "metadata": {}, "source": ["# \ud83c\udf0a LiquidDiffusion: Attention-Free Image Generation with Liquid Neural Networks\n", "\n", "A **novel image generation model** combining:\n", "- **Liquid Neural Networks** (CfC) for adaptive, time-aware processing\n", "- **Rectified Flow** for simple, stable training\n", "- **Pretrained SD-VAE** for efficient latent-space training\n", "- **Zero attention** \u2014 fully convolutional\n", "- **Fully parallelizable** \u2014 no sequential ODE loops\n", "\n", "**Repo**: [krystv/liquid-diffusion](https://huggingface.co/krystv/liquid-diffusion)"]},
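> Editor's note: the "Rectified Flow" objective listed above is what the training cell further down optimizes. As a quick orientation for reviewers of this diff, here is a minimal sketch of that objective under the notebook's convention that `t=0` is data and `t=1` is pure noise; the standalone `rectified_flow_loss` helper is illustrative only, the real loop inlines these lines.

```python
import torch
import torch.nn.functional as F

def rectified_flow_loss(model, x0, sample_time):
    """MSE on the straight-line velocity from data x0 (t=0) to Gaussian noise x1 (t=1)."""
    x1 = torch.randn_like(x0)                  # noise endpoint
    t = sample_time(x0.shape[0])               # per-sample times in (0, 1)
    te = t[:, None, None, None]
    x_t = (1 - te) * x0 + te * x1              # straight-line interpolation between endpoints
    v_target = x1 - x0                         # constant velocity along that line
    return F.mse_loss(model(x_t, t), v_target)
```

> Sampling then simply integrates the learned velocity backwards with Euler steps, `z = z - v * dt`, as the `generate_samples` function in the training cell does.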
11
  {"cell_type": "markdown", "metadata": {}, "source": ["## \u2699\ufe0f Configuration"]},
12
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["#@title \u2699\ufe0f Training Configuration\n", "\n", "# === MODEL ===\n", "MODEL_SIZE = 'tiny' #@param ['tiny', 'small', 'base', 'custom']\n", "CUSTOM_CHANNELS = [48, 96, 192]\n", "CUSTOM_BLOCKS = [1, 2, 3]\n", "CUSTOM_T_DIM = 192\n", "\n", "# === TRAINING MODE ===\n", "TRAINING_MODE = 'latent' #@param ['latent', 'pixel']\n", "# latent = train in VAE latent space (4ch, 8x smaller) - RECOMMENDED\n", "# pixel = train directly on RGB pixels (3ch, full res)\n", "\n", "# === IMAGE RESOLUTION ===\n", "IMAGE_SIZE = 256 #@param [128, 256, 512] {type:\"integer\"}\n", "\n", "# === DATASET ===\n", "DATASET = 'huggan/AFHQv2' #@param ['huggan/AFHQv2', 'nielsr/CelebA-faces', 'huggan/flowers-102-categories', 'reach-vb/pokemon-blip-captions', 'huggan/anime-faces', 'Norod78/cartoon-blip-captions']\n", "# huggan/AFHQv2 \u2192 16K animal faces (512px native)\n", "# nielsr/CelebA-faces \u2192 202K celebrity faces\n", "# huggan/flowers-102-categories \u2192 8K flower photos\n", "# reach-vb/pokemon-blip-captions \u2192 833 pokemon illustrations\n", "# huggan/anime-faces \u2192 63K anime faces (64px native)\n", "# Norod78/cartoon-blip-captions \u2192 ~3K cartoon characters\n", "IMAGE_COLUMN = 'image'\n", "USE_STREAMING = False #@param {type:\"boolean\"}\n", "MAX_SAMPLES = None # Set to e.g. 1000 for quick test\n", "\n", "# === TRAINING ===\n", "BATCH_SIZE = 8 #@param {type:\"integer\"}\n", "LEARNING_RATE = 1e-4 #@param {type:\"number\"}\n", "WEIGHT_DECAY = 0.01\n", "NUM_EPOCHS = 100 #@param {type:\"integer\"}\n", "GRAD_CLIP = 1.0\n", "EMA_DECAY = 0.9999\n", "NUM_WORKERS = 2\n", "TIME_SAMPLING = 'logit_normal' #@param ['logit_normal', 'uniform']\n", "USE_AMP = True #@param {type:\"boolean\"}\n", "AMP_DTYPE = 'float16'\n", "\n", "# === SAMPLING & CHECKPOINTS ===\n", "SAMPLE_EVERY = 500 #@param {type:\"integer\"}\n", "NUM_SAMPLE_IMAGES = 8\n", "NUM_EULER_STEPS = 50\n", "SAVE_EVERY = 2000 #@param {type:\"integer\"}\n", "OUTPUT_DIR = './outputs'\n", "RESUME_FROM = None\n", "LOG_EVERY = 50\n", "\n", "print(f'\u2705 Config: {MODEL_SIZE} model, {IMAGE_SIZE}px, mode={TRAINING_MODE}')\n", "print(f' Dataset: {DATASET}')\n", "print(f' bs={BATCH_SIZE}, lr={LEARNING_RATE}, epochs={NUM_EPOCHS}, AMP={USE_AMP}')"]},
13
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udce6 Install Dependencies"]},
14
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["!pip install -q datasets diffusers accelerate huggingface_hub Pillow matplotlib\n", "import torch\n", "print(f'PyTorch {torch.__version__}, CUDA: {torch.cuda.is_available()}')\n", "if torch.cuda.is_available():\n", " print(f'GPU: {torch.cuda.get_device_name(0)}, VRAM: {torch.cuda.get_device_properties(0).total_memory/1e9:.1f}GB')"]},
15
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfd7\ufe0f Model Architecture"]},
16
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import math, copy, os, time\nimport torch, torch.nn as nn, torch.nn.functional as F\nfrom torch.utils.data import DataLoader, Dataset, IterableDataset\nfrom torchvision import transforms\nfrom torchvision.utils import save_image, make_grid\n\nclass SinusoidalTimeEmbedding(nn.Module):\n def __init__(self, dim, max_period=10000):\n super().__init__()\n self.dim, self.max_period = dim, max_period\n self.mlp = nn.Sequential(nn.Linear(dim, dim*4), nn.SiLU(), nn.Linear(dim*4, dim))\n def forward(self, t):\n half = self.dim // 2\n freqs = torch.exp(-math.log(self.max_period) * torch.arange(half, device=t.device, dtype=t.dtype) / half)\n args = t[:, None] * freqs[None, :]\n emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)\n if self.dim % 2: emb = F.pad(emb, (0, 1))\n return self.mlp(emb)\n\nclass AdaLN(nn.Module):\n def __init__(self, dim, cond_dim):\n super().__init__()\n ng = min(32, dim)\n while dim % ng != 0: ng -= 1\n self.norm = nn.GroupNorm(ng, dim, affine=False)\n self.proj = nn.Sequential(nn.SiLU(), nn.Linear(cond_dim, dim * 2))\n def forward(self, x, t_emb):\n s, sh = self.proj(t_emb).chunk(2, dim=1)\n return self.norm(x) * (1 + s[:,:,None,None]) + sh[:,:,None,None]\n\nclass ParallelCfCBlock(nn.Module):\n \"\"\"CfC Eq.10: x(t) = \\u03c3(-f\\u00b7t)\\u2299g + (1-\\u03c3(-f\\u00b7t))\\u2299h \\u2014 fully parallel.\"\"\"\n def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n super().__init__()\n hidden = int(dim * expand_ratio)\n self.backbone_dw = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size//2, groups=dim)\n self.backbone_pw = nn.Conv2d(dim, hidden, 1)\n self.backbone_act = nn.SiLU()\n self.f_head = nn.Conv2d(hidden, dim, 1)\n self.g_head = nn.Sequential(nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size//2, groups=hidden), nn.SiLU(), nn.Conv2d(hidden, dim, 1))\n self.h_head = nn.Sequential(nn.Conv2d(hidden, hidden, kernel_size, padding=kernel_size//2, groups=hidden), nn.SiLU(), nn.Conv2d(hidden, dim, 1))\n self.time_a, self.time_b = nn.Linear(t_dim, dim), nn.Linear(t_dim, dim)\n self.rho = nn.Parameter(torch.zeros(1, dim, 1, 1))\n self.output_gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim))\n self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n def forward(self, x, t_emb):\n residual = x\n bb = self.backbone_act(self.backbone_pw(self.backbone_dw(x)))\n f, g, h = self.f_head(bb), self.g_head(bb), self.h_head(bb)\n ta, tb = self.time_a(t_emb)[:,:,None,None], self.time_b(t_emb)[:,:,None,None]\n gate = torch.sigmoid(ta * f - tb)\n cfc_out = self.dropout(gate * g + (1.0 - gate) * h)\n t_sc = t_emb.mean(dim=1, keepdim=True)[:,:,None,None]\n alpha = torch.exp(-(F.softplus(self.rho) + 1e-6) * t_sc.abs().clamp(min=0.01))\n out = alpha * residual + (1.0 - alpha) * cfc_out\n return out * torch.sigmoid(self.output_gate(t_emb))[:,:,None,None]\n\nclass MultiScaleSpatialMix(nn.Module):\n def __init__(self, dim, t_dim):\n super().__init__()\n self.dw3 = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n self.dw5 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n self.dw7 = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)\n self.global_pool, self.global_proj = nn.AdaptiveAvgPool2d(1), nn.Conv2d(dim, dim, 1)\n self.merge, self.act, self.adaln = nn.Conv2d(dim*4, dim, 1), nn.SiLU(), AdaLN(dim, t_dim)\n def forward(self, x, t_emb):\n xn = self.adaln(x, t_emb)\n return x + self.act(self.merge(torch.cat([self.dw3(xn), self.dw5(xn), self.dw7(xn), 
self.global_proj(self.global_pool(xn)).expand_as(xn)], dim=1)))\n\nclass LiquidDiffusionBlock(nn.Module):\n def __init__(self, dim, t_dim, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n super().__init__()\n self.adaln1, self.cfc = AdaLN(dim, t_dim), ParallelCfCBlock(dim, t_dim, expand_ratio, kernel_size, dropout)\n self.spatial_mix, self.adaln2 = MultiScaleSpatialMix(dim, t_dim), AdaLN(dim, t_dim)\n ff_dim = int(dim * expand_ratio)\n self.ff = nn.Sequential(nn.Conv2d(dim, ff_dim, 1), nn.SiLU(), nn.Conv2d(ff_dim, dim, 1))\n self.res_scale = nn.Parameter(torch.ones(1) * 0.1)\n def forward(self, x, t_emb):\n x = x + self.res_scale * self.cfc(self.adaln1(x, t_emb), t_emb)\n x = self.spatial_mix(x, t_emb)\n return x + self.res_scale * self.ff(self.adaln2(x, t_emb))\n\nclass DownSample(nn.Module):\n def __init__(self, i, o): super().__init__(); self.conv = nn.Conv2d(i, o, 3, stride=2, padding=1)\n def forward(self, x): return self.conv(x)\nclass UpSample(nn.Module):\n def __init__(self, i, o): super().__init__(); self.conv = nn.Conv2d(i, o, 3, padding=1)\n def forward(self, x): return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))\nclass SkipFusion(nn.Module):\n def __init__(self, dim, t_dim):\n super().__init__()\n self.proj = nn.Conv2d(dim*2, dim, 1)\n self.gate = nn.Sequential(nn.SiLU(), nn.Linear(t_dim, dim), nn.Sigmoid())\n def forward(self, x, skip, t_emb):\n m = self.proj(torch.cat([x, skip], dim=1)); g = self.gate(t_emb)[:,:,None,None]\n return m * g + x * (1 - g)\n\nclass LiquidDiffusionUNet(nn.Module):\n def __init__(self, in_channels=3, channels=None, blocks_per_stage=None, t_dim=256, expand_ratio=2.0, kernel_size=7, dropout=0.0):\n super().__init__()\n channels = channels or [64,128,256]; blocks_per_stage = blocks_per_stage or [2,2,4]\n assert len(channels) == len(blocks_per_stage)\n self.channels, self.num_stages, self.in_channels = channels, len(channels), in_channels\n self.time_embed = SinusoidalTimeEmbedding(t_dim)\n self.stem = nn.Sequential(nn.Conv2d(in_channels, channels[0], 3, padding=1), nn.SiLU(), nn.Conv2d(channels[0], channels[0], 3, padding=1))\n self.encoder_blocks, self.downsamplers = nn.ModuleList(), nn.ModuleList()\n for i in range(self.num_stages):\n self.encoder_blocks.append(nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout) for _ in range(blocks_per_stage[i])]))\n if i < self.num_stages - 1: self.downsamplers.append(DownSample(channels[i], channels[i+1]))\n self.bottleneck = nn.ModuleList([LiquidDiffusionBlock(channels[-1], t_dim, expand_ratio, kernel_size, dropout) for _ in range(2)])\n self.decoder_blocks, self.upsamplers, self.skip_fusions = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()\n for i in range(self.num_stages-1, -1, -1):\n if i < self.num_stages - 1:\n self.upsamplers.append(UpSample(channels[i+1], channels[i])); self.skip_fusions.append(SkipFusion(channels[i], t_dim))\n self.decoder_blocks.append(nn.ModuleList([LiquidDiffusionBlock(channels[i], t_dim, expand_ratio, kernel_size, dropout) for _ in range(blocks_per_stage[i])]))\n hg = min(32, channels[0])\n while channels[0] % hg != 0: hg -= 1\n self.head = nn.Sequential(nn.GroupNorm(hg, channels[0]), nn.SiLU(), nn.Conv2d(channels[0], in_channels, 3, padding=1))\n nn.init.zeros_(self.head[-1].weight); nn.init.zeros_(self.head[-1].bias)\n def forward(self, x, t):\n t_emb, h = self.time_embed(t), self.stem(x)\n skips = []\n for i in range(self.num_stages):\n for blk in self.encoder_blocks[i]: h = blk(h, t_emb)\n skips.append(h)\n if i < 
self.num_stages - 1: h = self.downsamplers[i](h)\n for blk in self.bottleneck: h = blk(h, t_emb)\n up_idx = 0\n for di in range(self.num_stages):\n si = self.num_stages - 1 - di\n if di > 0: h = self.upsamplers[up_idx](h); h = self.skip_fusions[up_idx](h, skips[si], t_emb); up_idx += 1\n for blk in self.decoder_blocks[di]: h = blk(h, t_emb)\n return self.head(h)\n def count_params(self): return sum(p.numel() for p in self.parameters()), sum(p.numel() for p in self.parameters() if p.requires_grad)\n\nprint('\u2705 LiquidDiffusion architecture defined.')"]},
17
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udd27 Build Model + Load VAE"]},
18
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["device = 'cuda' if torch.cuda.is_available() else 'cpu'\nvae, vae_scale, model_in_channels = None, 1.0, 3\n\nif TRAINING_MODE == 'latent':\n from diffusers import AutoencoderKL\n print('Loading pretrained SD-VAE (stabilityai/sd-vae-ft-mse)...')\n vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse',\n torch_dtype=torch.float16 if (USE_AMP and device=='cuda') else torch.float32\n ).to(device).eval()\n vae.requires_grad_(False)\n vae_scale = vae.config.scaling_factor # 0.18215\n model_in_channels = vae.config.latent_channels # 4\n latent_size = IMAGE_SIZE // 8\n print(f' VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)')\n print(f' Latent: {IMAGE_SIZE}px \\u2192 {latent_size}x{latent_size}x{model_in_channels}')\n if device == 'cuda': print(f' VAE VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')\nelse:\n latent_size = IMAGE_SIZE\n print('Pixel mode: no VAE')\n\nMODEL_CONFIGS = {\n 'tiny': dict(channels=[64,128,256], blocks_per_stage=[2,2,4], t_dim=256),\n 'small': dict(channels=[96,192,384], blocks_per_stage=[2,3,6], t_dim=384),\n 'base': dict(channels=[128,256,512], blocks_per_stage=[2,4,8], t_dim=512),\n}\ncfg = MODEL_CONFIGS.get(MODEL_SIZE, dict(channels=CUSTOM_CHANNELS, blocks_per_stage=CUSTOM_BLOCKS, t_dim=CUSTOM_T_DIM))\ncfg['in_channels'] = model_in_channels\n\nmodel = LiquidDiffusionUNet(**cfg).to(device)\ntotal_p, _ = model.count_params()\nprint(f'\\nLiquidDiffusion [{MODEL_SIZE}]: {total_p:,} ({total_p/1e6:.1f}M) params')\nprint(f' in_ch={model_in_channels}, channels={cfg[\"channels\"]}, blocks={cfg[\"blocks_per_stage\"]}')\nwith torch.no_grad():\n tx = torch.randn(1, model_in_channels, latent_size, latent_size, device=device)\n to = model(tx, torch.tensor([0.5], device=device))\n print(f' Forward: {tx.shape} \\u2192 {to.shape} \\u2713'); del tx, to\nif device == 'cuda': torch.cuda.empty_cache(); print(f' Total VRAM: {torch.cuda.memory_allocated()/1e9:.2f} GB')"]},
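> Editor's note: one detail worth calling out from the cell above is `vae_scale = vae.config.scaling_factor` (0.18215 for `sd-vae-ft-mse`). Raw SD-VAE latents have a standard deviation of roughly 1/0.18215, so multiplying by the scaling factor brings them to roughly unit variance, matching the N(0,1) noise endpoint of the flow; decoding divides it back out. A small round-trip sketch under that convention (the random `pixels` tensor is just a stand-in for a 256px batch):

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained('stabilityai/sd-vae-ft-mse').eval()
scale = vae.config.scaling_factor  # 0.18215 for this checkpoint

with torch.no_grad():
    pixels = torch.rand(1, 3, 256, 256) * 2 - 1                 # stand-in image batch in [-1, 1]
    latents = vae.encode(pixels).latent_dist.sample() * scale   # (1, 4, 32, 32), roughly unit std
    recon = vae.decode(latents / scale).sample                  # undo the scaling before decoding
print(latents.shape, recon.shape)
```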
19
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcca Load Dataset"]},
20
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["from PIL import Image\nfrom datasets import load_dataset\nimport matplotlib.pyplot as plt\n\nclass HFImageDataset(Dataset):\n def __init__(self, hf_data, image_size, image_column='image'):\n self.data, self.col = hf_data, image_column\n self.transform = transforms.Compose([\n transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),\n transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __len__(self): return len(self.data)\n def __getitem__(self, idx):\n img = self.data[idx][self.col]\n if not hasattr(img, 'convert'): img = Image.fromarray(img)\n return self.transform(img.convert('RGB'))\n\nclass StreamingImageDataset(IterableDataset):\n def __init__(self, name, image_size, image_column='image'):\n self.ds, self.col = load_dataset(name, split='train', streaming=True), image_column\n self.transform = transforms.Compose([\n transforms.Resize(image_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(image_size), transforms.RandomHorizontalFlip(),\n transforms.ToTensor(), transforms.Normalize([0.5],[0.5])])\n def __iter__(self):\n for s in self.ds:\n img = s[self.col]\n if not hasattr(img, 'convert'): img = Image.fromarray(img)\n yield self.transform(img.convert('RGB'))\n\nprint(f'Loading: {DATASET}')\nif USE_STREAMING:\n dataset = StreamingImageDataset(DATASET, IMAGE_SIZE, IMAGE_COLUMN)\n dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)\n print(' Streaming mode')\nelse:\n hf_data = load_dataset(DATASET, split='train')\n if MAX_SAMPLES: hf_data = hf_data.select(range(min(MAX_SAMPLES, len(hf_data))))\n dataset = HFImageDataset(hf_data, IMAGE_SIZE, IMAGE_COLUMN)\n dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, drop_last=True)\n print(f' {len(dataset):,} images, {len(dataloader):,} steps/epoch')\n\n# Preview\nsb = next(iter(dataloader))\nfig, axes = plt.subplots(1, min(8, sb.shape[0]), figsize=(16, 2.5))\nif not hasattr(axes, '__len__'): axes = [axes]\nfor i, ax in enumerate(axes): ax.imshow((sb[i].permute(1,2,0)*0.5+0.5).clamp(0,1)); ax.axis('off')\nplt.suptitle(f'{DATASET} ({IMAGE_SIZE}px)'); plt.tight_layout(); plt.show()\n\nif vae is not None:\n with torch.no_grad():\n ti = sb[:4].to(device, dtype=vae.dtype)\n lat = vae.encode(ti).latent_dist.sample() * vae_scale\n dec = vae.decode(lat / vae_scale).sample\n print(f'\\n VAE: {ti.shape} \\u2192 {lat.shape} \\u2192 {dec.shape}')\n print(f' Latent: mean={lat.mean():.4f}, std={lat.std():.4f}')\n fig, axes = plt.subplots(2, 4, figsize=(12, 6))\n for i in range(4):\n axes[0,i].imshow((ti[i].cpu().float().permute(1,2,0)*0.5+0.5).clamp(0,1)); axes[0,i].set_title('Original'); axes[0,i].axis('off')\n axes[1,i].imshow((dec[i].cpu().float().permute(1,2,0)*0.5+0.5).clamp(0,1)); axes[1,i].set_title('VAE Recon'); axes[1,i].axis('off')\n plt.suptitle('VAE Quality Check'); plt.tight_layout(); plt.show()"]},

21
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\ude80 Training"]},
22
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["os.makedirs(f'{OUTPUT_DIR}/samples', exist_ok=True)\nos.makedirs(f'{OUTPUT_DIR}/checkpoints', exist_ok=True)\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY, betas=(0.9, 0.999))\ntotal_steps = len(dataloader) * NUM_EPOCHS if not USE_STREAMING else SAMPLE_EVERY * 200\nwarmup_steps = min(1000, total_steps // 10)\ndef lr_lambda(step):\n if step < warmup_steps: return float(step) / max(1, warmup_steps)\n return max(0.0, 0.5 * (1.0 + math.cos(math.pi * (step - warmup_steps) / max(1, total_steps - warmup_steps))))\nscheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)\n\nema_model = copy.deepcopy(model).eval()\nfor p in ema_model.parameters(): p.requires_grad_(False)\nscaler = torch.amp.GradScaler('cuda', enabled=(USE_AMP and device=='cuda'))\namp_dtype = getattr(torch, AMP_DTYPE) if (USE_AMP and device=='cuda') else torch.float32\n\ndef sample_time(bs):\n eps = 1e-5\n if TIME_SAMPLING == 'uniform': return torch.rand(bs, device=device)*(1-2*eps)+eps\n return torch.sigmoid(torch.randn(bs, device=device)).clamp(eps, 1-eps)\n\nglobal_step, start_epoch, all_losses = 0, 0, []\nif RESUME_FROM and os.path.exists(RESUME_FROM):\n ckpt = torch.load(RESUME_FROM, map_location=device, weights_only=False)\n model.load_state_dict(ckpt['model']); ema_model.load_state_dict(ckpt['ema_model'])\n optimizer.load_state_dict(ckpt['optimizer'])\n global_step, start_epoch = ckpt.get('step',0), ckpt.get('epoch',0)\n all_losses = ckpt.get('losses',[]); print(f'Resumed from step {global_step}')\n\n@torch.no_grad()\ndef generate_samples(step):\n ema_model.eval()\n z = torch.randn(NUM_SAMPLE_IMAGES, model_in_channels, latent_size, latent_size, device=device)\n dt = 1.0 / NUM_EULER_STEPS\n for i in range(NUM_EULER_STEPS, 0, -1):\n t = torch.full((NUM_SAMPLE_IMAGES,), i/NUM_EULER_STEPS, device=device)\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'): v = ema_model(z, t)\n if USE_AMP and amp_dtype == torch.float16: v = v.float()\n z = z - v * dt\n if vae is not None: pixels = vae.decode((z / vae_scale).to(vae.dtype)).sample.float()\n else: pixels = z\n pixels = pixels.clamp(-1, 1)\n save_image(make_grid(pixels*0.5+0.5, nrow=int(math.sqrt(NUM_SAMPLE_IMAGES)), padding=2), f'{OUTPUT_DIR}/samples/step_{step:06d}.png')\n return pixels\n\nprint(f'\\n{\"=\"*60}')\nprint(f'Training: {NUM_EPOCHS} epochs, mode={TRAINING_MODE}, latent={latent_size}x{latent_size}x{model_in_channels}')\nprint(f'{\"=\"*60}\\n')\ntrain_start = time.time()\nepoch_losses = []\n\nfor epoch in range(start_epoch, NUM_EPOCHS):\n model.train(); epoch_loss, nb = 0, 0\n for batch_idx, pixel_batch in enumerate(dataloader):\n pixel_batch = pixel_batch.to(device, non_blocking=True)\n if vae is not None:\n with torch.no_grad(): x0 = vae.encode(pixel_batch.to(vae.dtype)).latent_dist.sample().float() * vae_scale\n else: x0 = pixel_batch\n x1 = torch.randn_like(x0); t = sample_time(x0.shape[0]); te = t[:,None,None,None]\n x_t = (1-te)*x0 + te*x1; v_target = x1 - x0\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'):\n loss = F.mse_loss(model(x_t, t), v_target)\n optimizer.zero_grad(set_to_none=True); scaler.scale(loss).backward()\n if GRAD_CLIP > 0: scaler.unscale_(optimizer); torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)\n scaler.step(optimizer); scaler.update(); scheduler.step()\n with torch.no_grad():\n for ep, mp in 
zip(ema_model.parameters(), model.parameters()): ep.data.mul_(EMA_DECAY).add_(mp.data, alpha=1-EMA_DECAY)\n global_step += 1; nb += 1; lv = loss.item(); all_losses.append(lv); epoch_loss += lv\n if global_step % LOG_EVERY == 0:\n avg = sum(all_losses[-LOG_EVERY:])/LOG_EVERY; lr = scheduler.get_last_lr()[0]\n sps = global_step/(time.time()-train_start)\n vr = f', VRAM:{torch.cuda.max_memory_allocated()/1e9:.1f}GB' if device=='cuda' else ''\n print(f'Step {global_step:6d} | Loss:{avg:.4f} | LR:{lr:.2e} | {sps:.1f}it/s{vr}')\n if global_step % SAMPLE_EVERY == 0:\n print(' \\ud83d\\udcf8 Generating...'); samples = generate_samples(global_step)\n fig, axes = plt.subplots(1, min(8, NUM_SAMPLE_IMAGES), figsize=(16, 2.5))\n if not hasattr(axes, '__len__'): axes = [axes]\n for i, ax in enumerate(axes):\n if i < samples.shape[0]: ax.imshow((samples[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\n plt.suptitle(f'Step {global_step} | Loss:{lv:.4f}'); plt.tight_layout(); plt.show()\n if global_step % SAVE_EVERY == 0:\n torch.save({'model':model.state_dict(),'ema_model':ema_model.state_dict(),'optimizer':optimizer.state_dict(),'step':global_step,'epoch':epoch,'losses':all_losses[-2000:],'config':cfg}, f'{OUTPUT_DIR}/checkpoints/step_{global_step:06d}.pt')\n print(' \\ud83d\\udcbe Checkpoint saved')\n if nb > 0: epoch_losses.append(epoch_loss/nb); print(f' Epoch {epoch+1}/{NUM_EPOCHS} | Loss:{epoch_losses[-1]:.4f}')\n\ntorch.save({'model':model.state_dict(),'ema_model':ema_model.state_dict(),'step':global_step,'config':cfg,'losses':all_losses[-2000:]}, f'{OUTPUT_DIR}/checkpoints/final.pt')\nprint(f'\\n\\u2705 Done! {(time.time()-train_start)/3600:.1f}h')"]},
23
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcc8 Training Curves"]},
24
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["import numpy as np\nfig, (a1, a2) = plt.subplots(1, 2, figsize=(14, 5))\na1.plot(all_losses, alpha=0.3, color='blue', linewidth=0.5)\nw = min(200, max(1, len(all_losses)//5))\nif w > 1 and len(all_losses) > w:\n sm = np.convolve(all_losses, np.ones(w)/w, mode='valid')\n a1.plot(range(w-1, len(all_losses)), sm, color='red', linewidth=2, label=f'Smooth(w={w})')\na1.set_xlabel('Step'); a1.set_ylabel('Loss'); a1.set_title('Training Loss'); a1.legend(); a1.grid(True, alpha=0.3)\nif epoch_losses:\n a2.plot(range(1, len(epoch_losses)+1), epoch_losses, 'o-', color='green')\n a2.set_xlabel('Epoch'); a2.set_ylabel('Loss'); a2.set_title('Per Epoch'); a2.grid(True, alpha=0.3)\nplt.tight_layout(); plt.show()"]},
25
  {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83c\udfa8 Generate Images"]},
26
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["NUM_GENERATE = 16 #@param {type:\"integer\"}\nEULER_STEPS = 50 #@param {type:\"integer\"}\n\nprint(f'Generating {NUM_GENERATE} images ({EULER_STEPS} steps)...')\nema_model.eval()\nwith torch.no_grad():\n z = torch.randn(NUM_GENERATE, model_in_channels, latent_size, latent_size, device=device)\n dt = 1.0 / EULER_STEPS\n for i in range(EULER_STEPS, 0, -1):\n t = torch.full((NUM_GENERATE,), i/EULER_STEPS, device=device)\n with torch.amp.autocast(device, dtype=amp_dtype, enabled=USE_AMP and device=='cuda'): v = ema_model(z, t)\n if USE_AMP and amp_dtype == torch.float16: v = v.float()\n z = z - v * dt\n if vae is not None: generated = vae.decode((z/vae_scale).to(vae.dtype)).sample.float().clamp(-1,1)\n else: generated = z.clamp(-1,1)\n\nnr = int(math.ceil(math.sqrt(NUM_GENERATE)))\nfig, axes = plt.subplots(nr, nr, figsize=(2.5*nr, 2.5*nr))\naxes = axes.flatten() if hasattr(axes, 'flatten') else [axes]\nfor i, ax in enumerate(axes):\n if i < NUM_GENERATE: ax.imshow((generated[i].cpu().permute(1,2,0)*0.5+0.5).clamp(0,1))\n ax.axis('off')\nplt.suptitle(f'LiquidDiffusion ({IMAGE_SIZE}px)', fontsize=14); plt.tight_layout(); plt.show()\nsave_image(make_grid(generated*0.5+0.5, nrow=nr, padding=2), f'{OUTPUT_DIR}/final_samples.png')\nprint(f'Saved to {OUTPUT_DIR}/final_samples.png')"]},
27
+ {"cell_type": "markdown", "metadata": {}, "source": ["## \ud83d\udcbe Save to Hub"]},
28
+ {"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": ["PUSH_TO_HUB = False #@param {type:\"boolean\"}\nHUB_MODEL_ID = 'your-username/liquid-diffusion-model' #@param {type:\"string\"}\nif PUSH_TO_HUB:\n from huggingface_hub import HfApi\n api = HfApi(); api.create_repo(HUB_MODEL_ID, exist_ok=True)\n api.upload_file(path_or_fileobj=f'{OUTPUT_DIR}/checkpoints/final.pt', path_in_repo='model.pt', repo_id=HUB_MODEL_ID)\n print(f'Pushed to https://huggingface.co/{HUB_MODEL_ID}')"]},
29
+ {"cell_type": "markdown", "metadata": {}, "source": ["---\n", "## \ud83d\udcd6 Architecture\n", "\n", "### CfC Time-Gating\n", "```\n", "gate = \u03c3(time_a(t) \u00b7 f(features) - time_b(t))\n", "out = gate \u00b7 g + (1-gate) \u00b7 h\n", "\u03b1 = exp(-\u03bb\u00b7|t|) \u2192 time-aware residual\n", "```\n", "\n", "### Latent Training Pipeline\n", "```\n", "pixel (3\u00d7256\u00d7256) \u2192 [SD-VAE encode] \u2192 latent (4\u00d732\u00d732) \u2192 [LiquidDiffusion] \u2192 [SD-VAE decode] \u2192 pixel\n", "```\n", "\n", "### References\n", "- [CfC (Nature MI 2022)](https://arxiv.org/abs/2106.13898)\n", "- [LiquidTAD](https://arxiv.org/abs/2604.18274)\n", "- [Rectified Flow (ICLR 2023)](https://arxiv.org/abs/2209.03003)\n", "- [SD-VAE ft-MSE](https://huggingface.co/stabilityai/sd-vae-ft-mse)"]}
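> Editor's note: for completeness, a tensor-level sketch of the time-gating equations above, with random tensors standing in for the projections that `ParallelCfCBlock` actually computes (`f`, `g`, `h` from the shared conv backbone, `time_a`/`time_b` from the time embedding). The real block additionally applies a time-conditioned output gate, omitted here for brevity.

```python
import torch
import torch.nn.functional as F

B, C, H, W = 2, 8, 4, 4
f, g, h = (torch.randn(B, C, H, W) for _ in range(3))     # stand-ins for the f/g/h conv heads
time_a, time_b = torch.randn(B, C), torch.randn(B, C)     # stand-ins for the time-embedding projections
residual = torch.randn(B, C, H, W)                         # block input
lam = F.softplus(torch.zeros(1, C, 1, 1)) + 1e-6           # relaxation rate (the block's rho parameter)
t = torch.rand(B, 1, 1, 1)                                 # stand-in for the time-embedding statistic

gate = torch.sigmoid(time_a[:, :, None, None] * f - time_b[:, :, None, None])
cfc_out = gate * g + (1 - gate) * h                        # closed-form CfC mixing, no ODE solve
alpha = torch.exp(-lam * t.abs().clamp(min=0.01))
out = alpha * residual + (1 - alpha) * cfc_out             # small alpha (high noise) -> heavy processing
```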
30
  ]
31
  }