asdf98 committed
Commit 49c59f1 · verified · 1 Parent(s): 0b46772

Fix notebook: use pure-parquet datasets only, cartoon default

Files changed (1)
1. LiquidGen_Colab_Notebook.ipynb +1 -1
LiquidGen_Colab_Notebook.ipynb CHANGED
@@ -1 +1 @@
- {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{},"source":["# LiquidGen: Liquid Neural Network Image Generator\n","\n","A novel **attention-free** diffusion model using CfC Liquid Neural Network dynamics.\n","\n","**Optimized for Colab free tier (T4 16GB):**\n","- Latent pre-caching: encode images once with VAE, then train on pure tensors\n","- No VAE loaded during training = more VRAM for model + larger batches\n","- Small curated datasets that download in seconds\n","- Uses **open SDXL VAE** (no login/auth needed)\n","\n","**Dataset presets:**\n","| Preset | Images | Size | Type |\n","|--------|--------|------|------|\n","| `paintings_mini` | ~200 | 1.7MB | 27 painting styles (smoke test) |\n","| `paintings` | ~8K | 204MB | 27 painting styles (recommended) |\n","| `cartoon` | ~2.5K | 181MB | Cartoon/anime (unconditional) |\n","| `flowers` | ~8K | 331MB | Flower photos (unconditional) |\n","| `wikiart_stream` | ~80K | streaming | Full WikiArt (use max_images) |\n"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Install"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub"]},{"cell_type":"markdown","metadata":{},"source":["## 2. Configuration"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# ===== EDIT THESE =====\n","MODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\n","IMAGE_SIZE = 256 # 256 or 512\n","DATASET_PRESET = \"paintings\" # See table above\n","MAX_IMAGES = 0 # 0=all, >0 to limit\n","BATCH_SIZE = 32 # Large OK - training on cached tensors!\n","GRAD_ACCUM = 1\n","LEARNING_RATE = 1e-4\n","NUM_EPOCHS = 100\n","WARMUP_STEPS = 500\n","SAMPLE_EVERY = 500\n","SAMPLE_STEPS = 50\n","CFG_SCALE = 2.0\n","OUTPUT_DIR = \"/content/liquidgen\"\n","SAVE_EVERY = 2000\n","LOG_EVERY = 25\n","\n","# VAE — fully open, no login needed\n","VAE_ID = \"madebyollin/sdxl-vae-fp16-fix\"\n","VAE_SCALE = 0.13025 # SDXL VAE scaling factor\n","LATENT_CH = 4 # SDXL VAE has 4 latent channels\n","\n","import torch\n","if torch.cuda.is_available():\n"," g = torch.cuda.get_device_name(0)\n"," m = torch.cuda.get_device_properties(0).total_mem/1024**3\n"," print(f\"GPU: {g} ({m:.1f}GB)\")\n","else:\n"," print(\"No GPU! Runtime > Change runtime type > GPU\")"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Download Model Code"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!wget -q https://huggingface.co/asdf98/LiquidGen/resolve/main/model.py\n","!wget -q https://huggingface.co/asdf98/LiquidGen/resolve/main/train.py\n","from model import LiquidGen, liquidgen_small, liquidgen_base, liquidgen_large\n","from train import (TrainConfig, DATASET_PRESETS, get_model_config,\n"," precache_latents, CachedLatentDataset, FlowMatchingScheduler,\n"," EMAModel, cosine_schedule)\n","print(\"Code loaded!\")\n","for n,f in [(\"Small\",liquidgen_small),(\"Base\",liquidgen_base),(\"Large\",liquidgen_large)]:\n"," m=f(num_classes=27); print(f\" LiquidGen-{n}: {m.count_params()/1e6:.1f}M\"); del m"]},{"cell_type":"markdown","metadata":{},"source":["## 4. 
Pre-Cache Latents (one-time)\n","\n","Encodes all images to VAE latents, saves to disk, then unloads VAE to free VRAM.\n","Uses `madebyollin/sdxl-vae-fp16-fix` — fully open, no login needed."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os, time\n","os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n","os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n","\n","preset = DATASET_PRESETS[DATASET_PRESET]\n","NUM_CLASSES = preset[\"num_classes\"]\n","print(f\"Dataset: {preset['description']}\")\n","\n","config = TrainConfig(\n"," model_size=MODEL_SIZE, num_classes=NUM_CLASSES,\n"," dataset_preset=DATASET_PRESET, image_size=IMAGE_SIZE,\n"," max_images=MAX_IMAGES, batch_size=BATCH_SIZE,\n"," gradient_accumulation_steps=GRAD_ACCUM,\n"," learning_rate=LEARNING_RATE, num_epochs=NUM_EPOCHS,\n"," warmup_steps=WARMUP_STEPS, output_dir=OUTPUT_DIR,\n"," save_every_n_steps=SAVE_EVERY, sample_every_n_steps=SAMPLE_EVERY,\n"," log_every_n_steps=LOG_EVERY, num_sample_steps=SAMPLE_STEPS,\n"," cfg_scale=CFG_SCALE, vae_id=VAE_ID,\n"," vae_scaling_factor=VAE_SCALE, latent_channels=LATENT_CH,\n",")\n","cache_path = precache_latents(config)"]},{"cell_type":"markdown","metadata":{},"source":["## 5. Train!"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch.nn.functional as F\n","from torch.utils.data import DataLoader\n","from torch.amp import autocast, GradScaler\n","import math\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","train_ds = CachedLatentDataset(cache_path)\n","train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n"," num_workers=2, pin_memory=True, drop_last=True)\n","\n","mcfg = get_model_config(MODEL_SIZE, NUM_CLASSES)\n","mcfg[\"in_channels\"] = LATENT_CH\n","model = LiquidGen(**mcfg).to(device)\n","print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n","\n","opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n","total_steps = len(train_dl) * NUM_EPOCHS // GRAD_ACCUM\n","sched = cosine_schedule(opt, WARMUP_STEPS, total_steps)\n","ema = EMAModel(model, 0.9999)\n","scaler = GradScaler(\"cuda\")\n","fm = FlowMatchingScheduler()\n","lat_size = IMAGE_SIZE // 8\n","print(f\"Steps: {total_steps}, Batch: {BATCH_SIZE}\")\n","print(f\"Latent: [{BATCH_SIZE}, {LATENT_CH}, {lat_size}, {lat_size}]\")\n","\n","gs=0; la=0; log_losses=[]; vae=None\n","print(\"\\nTraining!\\n\")\n","t0 = time.time()\n","\n","for epoch in range(NUM_EPOCHS):\n"," model.train(); et=time.time()\n"," for bi,(lats,lbls) in enumerate(train_dl):\n"," lats=lats.to(device)\n"," lbls=lbls.to(device) if NUM_CLASSES>0 else None\n"," t=fm.sample_timesteps(lats.shape[0],device)\n"," noise=torch.randn_like(lats)\n"," xt=fm.add_noise(lats,noise,t)\n"," vtgt=fm.get_velocity_target(lats,noise)\n"," with autocast(\"cuda\"):\n"," loss=F.mse_loss(model(xt,t,lbls),vtgt)/GRAD_ACCUM\n"," scaler.scale(loss).backward(); la+=loss.item()\n"," if (bi+1)%GRAD_ACCUM==0:\n"," scaler.unscale_(opt)\n"," gn=torch.nn.utils.clip_grad_norm_(model.parameters(),2.0)\n"," scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step()\n"," ema.update(model); gs+=1\n"," if gs%LOG_EVERY==0:\n"," al=la/LOG_EVERY\n"," vram=torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n"," print(f\"step={gs:>5d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | \"\n"," f\"lr={opt.param_groups[0]['lr']:.2e} | vram={vram:.1f}G\")\n"," log_losses.append(al); la=0\n"," 
if math.isnan(al): print(\"Diverged!\"); break\n"," if gs%SAMPLE_EVERY==0:\n"," if vae is None:\n"," from diffusers import AutoencoderKL\n"," vae=AutoencoderKL.from_pretrained(VAE_ID,\n"," torch_dtype=torch.float16).to(device).eval()\n"," for p in vae.parameters(): p.requires_grad_(False)\n"," ema.apply(model); model.eval()\n"," sl=torch.randint(0,max(1,NUM_CLASSES),(4,),device=device) if NUM_CLASSES>0 else None\n"," samp=fm.sample(model,(4,LATENT_CH,lat_size,lat_size),device,SAMPLE_STEPS,sl,CFG_SCALE)\n"," with torch.no_grad():\n"," imgs=((vae.decode(samp.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n"," from torchvision.utils import save_image\n"," save_image(imgs,f\"{OUTPUT_DIR}/samples/step_{gs:07d}.png\",nrow=2)\n"," print(f\" Saved samples\"); ema.restore(model); model.train()\n"," if gs%SAVE_EVERY==0:\n"," torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":gs,\"cfg\":mcfg},\n"," f\"{OUTPUT_DIR}/checkpoints/step_{gs:07d}.pt\")\n"," print(f\"Epoch {epoch} | {time.time()-et:.0f}s\")\n","\n","torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":gs,\"cfg\":mcfg},\n"," f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n","print(f\"\\nDone! {gs} steps, {(time.time()-t0)/60:.1f}min\")"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Loss Curve"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import matplotlib.pyplot as plt\n","if log_losses:\n"," plt.figure(figsize=(10,4)); plt.plot(log_losses)\n"," plt.xlabel(f\"Steps (x{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n"," plt.title(\"Training Loss\"); plt.grid(True,alpha=0.3)\n"," plt.savefig(f\"{OUTPUT_DIR}/loss.png\",dpi=150); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Generate Images"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if vae is None:\n"," from diffusers import AutoencoderKL\n"," vae=AutoencoderKL.from_pretrained(VAE_ID,\n"," torch_dtype=torch.float16).to(device).eval()\n"," for p in vae.parameters(): p.requires_grad_(False)\n","ema.apply(model); model.eval()\n","ls=IMAGE_SIZE//8\n","if NUM_CLASSES>0:\n"," for ci in range(min(NUM_CLASSES,6)):\n"," l=torch.full((8,),ci,device=device,dtype=torch.long)\n"," s=fm.sample(model,(8,LATENT_CH,ls,ls),device,50,l,2.5)\n"," with torch.no_grad():\n"," i=((vae.decode(s.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n"," from torchvision.utils import save_image\n"," save_image(i,f\"{OUTPUT_DIR}/gen_class{ci}.png\",nrow=4)\n"," print(f\"Generated class {ci}\")\n","else:\n"," s=fm.sample(model,(8,LATENT_CH,ls,ls),device,50)\n"," with torch.no_grad():\n"," i=((vae.decode(s.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n"," from torchvision.utils import save_image\n"," save_image(i,f\"{OUTPUT_DIR}/gen_uncond.png\",nrow=4)\n","ema.restore(model); print(\"Done!\")"]},{"cell_type":"markdown","metadata":{},"source":["## 8. 
Display"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from IPython.display import display\n","from PIL import Image\n","import glob\n","for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n"," print(os.path.basename(f)); display(Image.open(f))\n","for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n"," print(os.path.basename(f)); display(Image.open(f))"]},{"cell_type":"markdown","metadata":{},"source":["## Architecture Reference\n","\n","**VAE:** madebyollin/sdxl-vae-fp16-fix (open, 4ch latent, 8x compression, no auth)\n","\n","**Liquid Time Constant (CfC):** `alpha = exp(-softplus(rho))`, `out = alpha*state + (1-alpha)*stimulus`\n","\n","**Flow Matching:** `x_t = (1-t)*x0 + t*eps`, target `v = eps - x0`\n","\n","**Sampling:** `x_{t-dt} = x_t - dt * model(x_t, t)` from t=1 to t=0\n","\n","Papers: LTC (NeurIPS 2020), CfC (Nature MI 2022), NCP (Nature MI 2020), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)"]}]}
 
+ {"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"gpuType":"T4"},"kernelspec":{"name":"python3","display_name":"Python 3"},"accelerator":"GPU"},"cells":[{"cell_type":"markdown","metadata":{},"source":["# LiquidGen: Liquid Neural Network Image Generator\n","\n","A novel **attention-free** diffusion model using CfC Liquid Neural Network dynamics.\n","\n","**Optimized for Colab free tier (T4 16GB):**\n","- Latent pre-caching: encode images once, train on pure tensors\n","- Open SDXL VAE — no login needed\n","- Pure parquet datasets — no legacy script errors\n","\n","**Dataset presets:**\n","| Preset | Images | Size | Type |\n","|--------|--------|------|------|\n","| `cartoon` | ~2.5K | 181MB | Cartoon/anime (recommended start) |\n","| `flowers` | ~8K | 331MB | Flower photos |\n","| `art_painting` | ~6K | 511MB | Art paintings |\n","| `wikiart` | ~105K | 1.6GB | Full WikiArt with styles (use max_images!) |\n"]},{"cell_type":"markdown","metadata":{},"source":["## 1. Install"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub"]},{"cell_type":"markdown","metadata":{},"source":["## 2. Configuration"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["# ===== EDIT THESE =====\n","MODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\n","IMAGE_SIZE = 256 # 256 or 512\n","DATASET_PRESET = \"cartoon\" # See table above\n","MAX_IMAGES = 0 # 0=all, >0 to limit (use with wikiart!)\n","BATCH_SIZE = 32 # Large OK - training on cached tensors!\n","GRAD_ACCUM = 1\n","LEARNING_RATE = 1e-4\n","NUM_EPOCHS = 100\n","WARMUP_STEPS = 500\n","SAMPLE_EVERY = 500\n","SAMPLE_STEPS = 50\n","CFG_SCALE = 2.0\n","OUTPUT_DIR = \"/content/liquidgen\"\n","SAVE_EVERY = 2000\n","LOG_EVERY = 25\n","\n","# VAE — fully open, no login needed\n","VAE_ID = \"madebyollin/sdxl-vae-fp16-fix\"\n","VAE_SCALE = 0.13025\n","LATENT_CH = 4\n","\n","import torch\n","if torch.cuda.is_available():\n"," g = torch.cuda.get_device_name(0)\n"," m = torch.cuda.get_device_properties(0).total_mem/1024**3\n"," print(f\"GPU: {g} ({m:.1f}GB)\")\n","else:\n"," print(\"No GPU! Runtime > Change runtime type > GPU\")"]},{"cell_type":"markdown","metadata":{},"source":["## 3. Download Model Code"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!wget -q -O model.py https://huggingface.co/asdf98/LiquidGen/resolve/main/model.py\n","!wget -q -O train.py https://huggingface.co/asdf98/LiquidGen/resolve/main/train.py\n","from model import LiquidGen, liquidgen_small, liquidgen_base, liquidgen_large\n","from train import (TrainConfig, DATASET_PRESETS, get_model_config,\n"," precache_latents, CachedLatentDataset, FlowMatchingScheduler,\n"," EMAModel, cosine_schedule)\n","print(\"Code loaded!\")\n","for n,f in [(\"Small\",liquidgen_small),(\"Base\",liquidgen_base),(\"Large\",liquidgen_large)]:\n"," m=f(num_classes=0); print(f\" LiquidGen-{n}: {m.count_params()/1e6:.1f}M\"); del m"]},{"cell_type":"markdown","metadata":{},"source":["## 4. 
Pre-Cache Latents\n","Encodes all images with VAE once, saves to disk, unloads VAE."]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import os, time\n","os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n","os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n","\n","preset = DATASET_PRESETS[DATASET_PRESET]\n","NUM_CLASSES = preset[\"num_classes\"]\n","print(f\"Dataset: {preset['description']}\")\n","\n","config = TrainConfig(\n"," model_size=MODEL_SIZE, num_classes=NUM_CLASSES,\n"," dataset_preset=DATASET_PRESET, image_size=IMAGE_SIZE,\n"," max_images=MAX_IMAGES, batch_size=BATCH_SIZE,\n"," gradient_accumulation_steps=GRAD_ACCUM,\n"," learning_rate=LEARNING_RATE, num_epochs=NUM_EPOCHS,\n"," warmup_steps=WARMUP_STEPS, output_dir=OUTPUT_DIR,\n"," save_every_n_steps=SAVE_EVERY, sample_every_n_steps=SAMPLE_EVERY,\n"," log_every_n_steps=LOG_EVERY, num_sample_steps=SAMPLE_STEPS,\n"," cfg_scale=CFG_SCALE, vae_id=VAE_ID,\n"," vae_scaling_factor=VAE_SCALE, latent_channels=LATENT_CH,\n",")\n","cache_path = precache_latents(config)"]},{"cell_type":"markdown","metadata":{},"source":["## 5. Train"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import torch.nn.functional as F\n","from torch.utils.data import DataLoader\n","from torch.amp import autocast, GradScaler\n","import math\n","\n","device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n","train_ds = CachedLatentDataset(cache_path)\n","train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,\n"," num_workers=2, pin_memory=True, drop_last=True)\n","\n","mcfg = get_model_config(MODEL_SIZE, NUM_CLASSES)\n","mcfg[\"in_channels\"] = LATENT_CH\n","model = LiquidGen(**mcfg).to(device)\n","print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n","\n","opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=0.01)\n","total_steps = len(train_dl) * NUM_EPOCHS // GRAD_ACCUM\n","sched = cosine_schedule(opt, WARMUP_STEPS, total_steps)\n","ema = EMAModel(model, 0.9999)\n","scaler = GradScaler(\"cuda\")\n","fm = FlowMatchingScheduler()\n","lat_size = IMAGE_SIZE // 8\n","print(f\"Steps: {total_steps}, Latent: [{BATCH_SIZE},{LATENT_CH},{lat_size},{lat_size}]\")\n","\n","gs=0; la=0; log_losses=[]; vae=None\n","print(\"\\nTraining!\\n\")\n","t0 = time.time()\n","\n","for epoch in range(NUM_EPOCHS):\n"," model.train(); et=time.time()\n"," for bi,(lats,lbls) in enumerate(train_dl):\n"," lats=lats.to(device)\n"," lbls=lbls.to(device) if NUM_CLASSES>0 else None\n"," t=fm.sample_timesteps(lats.shape[0],device)\n"," noise=torch.randn_like(lats)\n"," xt=fm.add_noise(lats,noise,t)\n"," vtgt=fm.get_velocity_target(lats,noise)\n"," with autocast(\"cuda\"):\n"," loss=F.mse_loss(model(xt,t,lbls),vtgt)/GRAD_ACCUM\n"," scaler.scale(loss).backward(); la+=loss.item()\n"," if (bi+1)%GRAD_ACCUM==0:\n"," scaler.unscale_(opt)\n"," gn=torch.nn.utils.clip_grad_norm_(model.parameters(),2.0)\n"," scaler.step(opt); scaler.update(); opt.zero_grad(); sched.step()\n"," ema.update(model); gs+=1\n"," if gs%LOG_EVERY==0:\n"," al=la/LOG_EVERY\n"," vram=torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n"," print(f\"step={gs:>5d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | \"\n"," f\"lr={opt.param_groups[0]['lr']:.2e} | vram={vram:.1f}G\")\n"," log_losses.append(al); la=0\n"," if math.isnan(al): print(\"Diverged!\"); break\n"," if gs%SAMPLE_EVERY==0:\n"," if vae is None:\n"," from diffusers import AutoencoderKL\n"," 
vae=AutoencoderKL.from_pretrained(VAE_ID,\n"," torch_dtype=torch.float16).to(device).eval()\n"," for p in vae.parameters(): p.requires_grad_(False)\n"," ema.apply(model); model.eval()\n"," sl=torch.randint(0,max(1,NUM_CLASSES),(4,),device=device) if NUM_CLASSES>0 else None\n"," samp=fm.sample(model,(4,LATENT_CH,lat_size,lat_size),device,SAMPLE_STEPS,sl,CFG_SCALE)\n"," with torch.no_grad():\n"," imgs=((vae.decode(samp.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n"," from torchvision.utils import save_image\n"," save_image(imgs,f\"{OUTPUT_DIR}/samples/step_{gs:07d}.png\",nrow=2)\n"," print(f\" Saved samples\"); ema.restore(model); model.train()\n"," if gs%SAVE_EVERY==0:\n"," torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":gs,\"cfg\":mcfg},\n"," f\"{OUTPUT_DIR}/checkpoints/step_{gs:07d}.pt\")\n"," print(f\"Epoch {epoch} | {time.time()-et:.0f}s\")\n","\n","torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":gs,\"cfg\":mcfg},\n"," f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n","print(f\"\\nDone! {gs} steps, {(time.time()-t0)/60:.1f}min\")"]},{"cell_type":"markdown","metadata":{},"source":["## 6. Loss Curve"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["import matplotlib.pyplot as plt\n","if log_losses:\n"," plt.figure(figsize=(10,4)); plt.plot(log_losses)\n"," plt.xlabel(f\"Steps (x{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n"," plt.title(\"Training Loss\"); plt.grid(True,alpha=0.3)\n"," plt.savefig(f\"{OUTPUT_DIR}/loss.png\",dpi=150); plt.show()"]},{"cell_type":"markdown","metadata":{},"source":["## 7. Generate"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if vae is None:\n"," from diffusers import AutoencoderKL\n"," vae=AutoencoderKL.from_pretrained(VAE_ID,\n"," torch_dtype=torch.float16).to(device).eval()\n"," for p in vae.parameters(): p.requires_grad_(False)\n","ema.apply(model); model.eval()\n","ls=IMAGE_SIZE//8\n","s=fm.sample(model,(8,LATENT_CH,ls,ls),device,50)\n","with torch.no_grad():\n"," i=((vae.decode(s.half()/VAE_SCALE).sample+1)/2).clamp(0,1).float()\n","from torchvision.utils import save_image\n","save_image(i,f\"{OUTPUT_DIR}/generated.png\",nrow=4)\n","ema.restore(model); print(f\"Saved to {OUTPUT_DIR}/generated.png\")"]},{"cell_type":"markdown","metadata":{},"source":["## 8. Display"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["from IPython.display import display\n","from PIL import Image\n","import glob\n","for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n"," print(os.path.basename(f)); display(Image.open(f))\n","for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen*.png\")):\n"," print(os.path.basename(f)); display(Image.open(f))"]},{"cell_type":"markdown","metadata":{},"source":["## Architecture\n","\n","**VAE:** madebyollin/sdxl-vae-fp16-fix (open, 4ch, 8x compression)\n","\n","**Liquid Time Constant:** `alpha = exp(-softplus(rho))`, `out = alpha*x + (1-alpha)*stimulus`\n","\n","**Flow Matching:** `x_t = (1-t)*x0 + t*eps`, target `v = eps - x0`\n","\n","**Papers:** LTC (NeurIPS 2020), CfC (Nature MI 2022), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)"]}]}