Fix: streaming dataset in notebook (no full download on Colab)
LiquidGen_Colab_Notebook.ipynb  +34 -247  CHANGED
@@ -21,21 +21,17 @@
    "\n",
    "**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
    "\n",
-   "### Key Features:\n",
    "- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
    "- **Fully Parallelizable** \u2014 No sequential ODE solving\n",
-   "- **
-   "- **
-   "- **Fits 16GB VRAM** \u2014 Designed for Colab free tier\n",
-   "\n",
-   "Based on: Liquid Time-constant Networks (NeurIPS 2020), CfC (Nature MI 2022), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n"
+   "- **Streaming Dataset** \u2014 No full download, starts training immediately\n",
+   "- **Fits 16GB VRAM** \u2014 Designed for Colab free tier T4\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83d\udce6
+   "## \ud83d\udce6 Step 1: Install"
   ]
 },
 {
@@ -44,14 +40,21 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "!pip install -q torch torchvision diffusers datasets accelerate
+   "!pip install -q torch torchvision diffusers datasets accelerate Pillow"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83d\udd27 Configuration"
+   "## \ud83d\udd27 Step 2: Configuration\n",
+   "\n",
+   "**Dataset options:**\n",
+   "| Dataset | Size | Download | Type |\n",
+   "|---------|------|----------|------|\n",
+   "| `huggan/wikiart` | ~80K | **Streaming** (no download!) | Art, 27 styles |\n",
+   "| `reach-vb/pokemon-blip-captions` | 833 | 95MB (fast) | Pokemon |\n",
+   "| `huggan/flowers-102-categories` | 8K | 330MB | Flowers |\n"
  ]
 },
 {
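Note: the "Streaming" column above refers to the `datasets` library's streaming mode. A minimal sketch of the mechanism, using the `huggan/wikiart` id from the table (the exact field names are whatever the dataset actually exposes):

```python
# load_dataset(..., streaming=True) returns an IterableDataset that yields
# examples on the fly instead of downloading the full archive to disk first.
from datasets import load_dataset

ds = load_dataset("huggan/wikiart", split="train", streaming=True)
example = next(iter(ds))      # fetches only the bytes this example needs
print(list(example.keys()))   # expect fields like 'image' and 'style'
```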
@@ -60,48 +63,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "
-   "IMAGE_SIZE = 256 # 256 or 512\n",
-   "DATASET_NAME = \"huggan/wikiart\"\n",
-   "IMAGE_COLUMN = \"image\"\n",
-   "LABEL_COLUMN = \"style\" # \"style\" (27), \"genre\" (11), \"\" for unconditional\n",
-   "NUM_CLASSES = 27\n",
-   "BATCH_SIZE = 8\n",
-   "GRADIENT_ACCUMULATION = 4\n",
-   "LEARNING_RATE = 1e-4\n",
-   "WEIGHT_DECAY = 0.01\n",
-   "MAX_GRAD_NORM = 2.0\n",
-   "NUM_EPOCHS = 50\n",
-   "WARMUP_STEPS = 500\n",
-   "EMA_DECAY = 0.9999\n",
-   "NUM_SAMPLE_STEPS = 50\n",
-   "CFG_SCALE = 2.0\n",
-   "OUTPUT_DIR = \"/content/liquidgen_outputs\"\n",
-   "SAVE_EVERY = 2000\n",
-   "SAMPLE_EVERY = 500\n",
-   "LOG_EVERY = 50\n",
-   "PUSH_TO_HUB = False\n",
-   "HUB_MODEL_ID = \"\"\n",
-   "VAE_ID = \"black-forest-labs/FLUX.1-schnell\"\n",
-   "VAE_SUBFOLDER = \"vae\"\n",
-   "\n",
-   "import torch\n",
-   "if torch.cuda.is_available():\n",
-   "    gpu = torch.cuda.get_device_name(0)\n",
-   "    mem = torch.cuda.get_device_properties(0).total_mem / 1024**3\n",
-   "    print(f\"GPU: {gpu} ({mem:.1f} GB)\")\n",
-   "    if mem < 12: print(\"\u26a0\ufe0f Low VRAM! Use small model, 256px, bs=4\")\n",
-   "    elif mem < 20: print(\"\u2705 T4 detected. Good for base model, 256px\")\n",
-   "    else: print(\"\ud83d\ude80 Large GPU! Can run large model, 512px\")\n",
-   "else:\n",
-   "    print(\"\u26a0\ufe0f No GPU! Go to Runtime \u2192 Change runtime type \u2192 GPU\")"
+   "# ============================================================================\n# CONFIGURATION\n# ============================================================================\n\nMODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\nIMAGE_SIZE = 256 # 256 or 512\n\n# --- Dataset (Option A: WikiArt streaming \u2014 NO download) ---\nDATASET_NAME = \"huggan/wikiart\"\nIMAGE_COLUMN = \"image\"\nLABEL_COLUMN = \"style\" # \"style\"(27), \"genre\"(11), \"\" for unconditional\nNUM_CLASSES = 27\nUSE_STREAMING = True # KEY: no full download!\n\n# --- Dataset (Option B: Pokemon \u2014 small, fast download, good for testing) ---\n# DATASET_NAME = \"reach-vb/pokemon-blip-captions\"\n# IMAGE_COLUMN = \"image\"; LABEL_COLUMN = \"\"; NUM_CLASSES = 0; USE_STREAMING = False\n\n# --- Training ---\nBATCH_SIZE = 8; GRADIENT_ACCUMULATION = 4\nLEARNING_RATE = 1e-4; WEIGHT_DECAY = 0.01; MAX_GRAD_NORM = 2.0\nMAX_STEPS = 20000; WARMUP_STEPS = 500; EMA_DECAY = 0.9999\nNUM_SAMPLE_STEPS = 50; CFG_SCALE = 2.0\n\n# --- Saving ---\nOUTPUT_DIR = \"/content/liquidgen_outputs\"\nSAVE_EVERY = 5000; SAMPLE_EVERY = 500; LOG_EVERY = 50\n\n# --- VAE ---\nVAE_ID = \"black-forest-labs/FLUX.1-schnell\"\nVAE_SUBFOLDER = \"vae\"\nSCALE, SHIFT = 0.3611, 0.1159\n\nimport torch\nif torch.cuda.is_available():\n    gpu = torch.cuda.get_device_name(0)\n    mem = torch.cuda.get_device_properties(0).total_memory / 1024**3\n    print(f\"GPU: {gpu} ({mem:.1f} GB)\")\nelse:\n    print(\"No GPU! Go to Runtime > Change runtime type > GPU\")\n"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83c\udfd7\ufe0f Model Architecture"
+   "## \ud83c\udfd7\ufe0f Step 3: Model Architecture"
  ]
 },
 {
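Back-of-envelope check on the new config (a sketch using only values from the cell above): with `BATCH_SIZE = 8` and `GRADIENT_ACCUMULATION = 4` the effective batch is 32, so `MAX_STEPS = 20000` optimizer steps consume 640,000 images, roughly 8 passes over the ~80K WikiArt images from the table.

```python
# Worked arithmetic for the config values above.
BATCH_SIZE, GRADIENT_ACCUMULATION, MAX_STEPS = 8, 4, 20_000
effective_batch = BATCH_SIZE * GRADIENT_ACCUMULATION   # 32
images_seen = effective_batch * MAX_STEPS              # 640,000
print(images_seen / 80_000)                            # ~8 epochs over ~80K images
```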
@@ -117,7 +86,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83d\udd04 Training Utilities"
+   "## \ud83d\udd04 Step 4: Training Utilities"
  ]
 },
 {
@@ -126,77 +95,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "import os,
-   "import numpy as np\n",
-   "from torch.utils.data import DataLoader, Dataset\n",
-   "from torch.amp import autocast, GradScaler\n",
-   "from torchvision import transforms\n",
-   "from torchvision.utils import save_image\n",
-   "from PIL import Image\n",
-   "\n",
-   "class FlowMatchingScheduler:\n",
-   "    def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n",
-   "    def sample_timesteps(self, bs, dev): return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t\n",
-   "    def add_noise(self, x0, noise, t): t = t.view(-1,1,1,1); return (1-t)*x0 + t*noise\n",
-   "    def get_velocity_target(self, x0, noise): return noise - x0\n",
-   "    @torch.no_grad()\n",
-   "    def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):\n",
-   "        model.eval(); x = torch.randn(shape, device=dev)\n",
-   "        dt = 1.0 / num_steps\n",
-   "        for t_val in torch.linspace(1.0, dt, num_steps, device=dev):\n",
-   "            t = torch.full((shape[0],), t_val.item(), device=dev)\n",
-   "            with torch.amp.autocast(\"cuda\"):\n",
-   "                if cfg > 1.0 and labels is not None:\n",
-   "                    vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n",
-   "                    v = vu + cfg * (vc - vu)\n",
-   "                else: v = model(x, t, labels)\n",
-   "            x = x - dt * v.float()\n",
-   "        return x\n",
-   "\n",
-   "class EMAModel:\n",
-   "    def __init__(self, model, decay=0.9999):\n",
-   "        self.decay = decay\n",
-   "        self.shadow = {n: p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n",
-   "    @torch.no_grad()\n",
-   "    def update(self, model):\n",
-   "        for n,p in model.named_parameters():\n",
-   "            if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n",
-   "    def apply(self, model):\n",
-   "        self.backup = {n: p.data.clone() for n,p in model.named_parameters() if p.requires_grad}\n",
-   "        for n,p in model.named_parameters():\n",
-   "            if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n",
-   "    def restore(self, model):\n",
-   "        for n,p in model.named_parameters():\n",
-   "            if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n])\n",
-   "\n",
-   "class ImageDataset(Dataset):\n",
-   "    def __init__(self, ds, tf, img_col, lbl_col=\"\"): self.ds, self.tf, self.ic, self.lc = ds, tf, img_col, lbl_col\n",
-   "    def __len__(self): return len(self.ds)\n",
-   "    def __getitem__(self, i):\n",
-   "        item = self.ds[i]; img = item[self.ic]\n",
-   "        if img.mode != \"RGB\": img = img.convert(\"RGB\")\n",
-   "        label = item[self.lc] if self.lc and self.lc in item else -1\n",
-   "        return self.tf(img), label\n",
-   "\n",
-   "def cosine_sched(opt, warmup, total):\n",
-   "    def lr(s):\n",
-   "        if s < warmup: return s / max(1, warmup)\n",
-   "        return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n",
-   "    return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n",
-   "\n",
-   "MODEL_CONFIGS = {\n",
-   "    \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n",
-   "    \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n",
-   "    \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n",
-   "}\n",
-   "print(\"\u2705 Training utilities ready!\")"
+   "import os, time, math\nimport numpy as np\nfrom torch.utils.data import DataLoader, IterableDataset, Dataset\nfrom torch.amp import autocast, GradScaler\nfrom torchvision import transforms\nfrom torchvision.utils import save_image\nfrom PIL import Image\n\nclass StreamingImageDataset(IterableDataset):\n    \"\"\"Streaming dataset \u2014 NO full download. Images load on-the-fly.\"\"\"\n    def __init__(self, name, img_col=\"image\", lbl_col=\"\", img_size=256,\n                 split=\"train\", config=\"\", buffer=1000, seed=42):\n        super().__init__()\n        self.name, self.img_col, self.lbl_col = name, img_col, lbl_col\n        self.split, self.config, self.buffer, self.seed = split, config, buffer, seed\n        self.tf = transforms.Compose([\n            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),\n            transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n\n    def __iter__(self):\n        from datasets import load_dataset\n        kw = {\"name\": self.config} if self.config else {}\n        ds = load_dataset(self.name, split=self.split, streaming=True, **kw)\n        ds = ds.shuffle(seed=self.seed, buffer_size=self.buffer)\n        for item in ds:\n            try:\n                img = item[self.img_col]\n                if img.mode != \"RGB\": img = img.convert(\"RGB\")\n                lbl = item[self.lbl_col] if self.lbl_col and self.lbl_col in item else -1\n                yield self.tf(img), lbl\n            except Exception: continue\n\nclass MapImageDataset(Dataset):\n    \"\"\"For small datasets (<500MB) \u2014 downloads once.\"\"\"\n    def __init__(self, name, img_col=\"image\", lbl_col=\"\", img_size=256, split=\"train\"):\n        from datasets import load_dataset\n        print(f\"Downloading {name}...\")\n        self.ds = load_dataset(name, split=split)\n        self.img_col, self.lbl_col = img_col, lbl_col\n        self.tf = transforms.Compose([\n            transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),\n            transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n        print(f\"  {len(self.ds)} images\")\n\n    def __len__(self): return len(self.ds)\n    def __getitem__(self, i):\n        item = self.ds[i]; img = item[self.img_col]\n        if img.mode != \"RGB\": img = img.convert(\"RGB\")\n        lbl = item[self.lbl_col] if self.lbl_col and self.lbl_col in item else -1\n        return self.tf(img), lbl\n\nclass FlowMatchingScheduler:\n    def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n    def sample_t(self, bs, dev): return torch.rand(bs, device=dev)*(self.max_t-self.min_t)+self.min_t\n    def add_noise(self, x0, noise, t): return (1-t.view(-1,1,1,1))*x0 + t.view(-1,1,1,1)*noise\n    def velocity(self, x0, noise): return noise - x0\n    @torch.no_grad()\n    def sample(self, model, shape, dev, steps=50, labels=None, cfg=1.0):\n        model.eval(); x = torch.randn(shape, device=dev); dt = 1.0/steps\n        for tv in torch.linspace(1.0, dt, steps, device=dev):\n            t = torch.full((shape[0],), tv.item(), device=dev)\n            with torch.amp.autocast(\"cuda\"):\n                if cfg > 1.0 and labels is not None:\n                    vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n                    v = vu + cfg*(vc-vu)\n                else: v = model(x,t,labels)\n            x = x - dt * v.float()\n        return x\n\nclass EMAModel:\n    def __init__(self, model, decay=0.9999):\n        self.decay = decay\n        self.shadow = {n:p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n    @torch.no_grad()\n    def update(self, m):\n        for n,p in m.named_parameters():\n            if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n    def apply(self, m):\n        self.bk = {n:p.data.clone() for n,p in m.named_parameters() if p.requires_grad}\n        for n,p in m.named_parameters():\n            if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n    def restore(self, m):\n        for n,p in m.named_parameters():\n            if p.requires_grad and n in self.bk: p.data.copy_(self.bk[n])\n\ndef cosine_sched(opt, warmup, total):\n    def lr(s):\n        if s < warmup: return s/max(1,warmup)\n        return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n    return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n\nMODEL_CONFIGS = {\n    \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n    \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n    \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n}\nprint(\"Training utilities ready!\")\n"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83d\udcca Load Dataset & VAE"
+   "## \ud83d\udcca Step 5: Load Dataset & VAE"
  ]
 },
 {
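Note on `FlowMatchingScheduler`: it implements the linear (rectified-flow) path x_t = (1 - t) * x0 + t * noise with velocity target v = noise - x0, and `sample` integrates that velocity backwards from t = 1 with Euler steps. A self-contained sanity check under those definitions (no model involved):

```python
import torch

# Along the linear path the true velocity is constant (noise - x0), so Euler
# integration from pure noise at t=1 down to t=0 recovers x0 exactly.
x0, noise = torch.randn(4), torch.randn(4)
x, steps = noise.clone(), 50
for _ in range(steps):
    x = x - (1.0 / steps) * (noise - x0)   # x <- x - dt * v, as in sample()
print(torch.allclose(x, x0, atol=1e-5))    # True
```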
@@ -205,33 +111,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from
-   "from diffusers import AutoencoderKL\n",
-   "\n",
-   "print(f\"Loading dataset: {DATASET_NAME}...\")\n",
-   "dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
-   "print(f\"  {len(dataset)} images\")\n",
-   "\n",
-   "transform = transforms.Compose([\n",
-   "    transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
-   "    transforms.CenterCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip(), transforms.ToTensor(),\n",
-   "])\n",
-   "train_ds = ImageDataset(dataset, transform, IMAGE_COLUMN, LABEL_COLUMN)\n",
-   "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n",
-   "\n",
-   "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
-   "vae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\n",
-   "for p in vae.parameters(): p.requires_grad_(False)\n",
-   "print(f\"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\n",
-   "SCALE, SHIFT = 0.3611, 0.1159\n",
-   "print(\"\u2705 Ready!\")"
+   "from diffusers import AutoencoderKL\n\nif USE_STREAMING:\n    print(f\"Loading {DATASET_NAME} in STREAMING mode (no full download)...\")\n    train_ds = StreamingImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE, buffer=1000)\n    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)\n    print(\"  Streaming ready! Images load on-the-fly.\")\nelse:\n    train_ds = MapImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE)\n    train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Loading VAE to {device}...\")\nvae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\nfor p in vae.parameters(): p.requires_grad_(False)\nprint(f\"  VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\nprint(\"Ready!\")\n"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83c\udfcb\ufe0f
+   "## \ud83c\udfcb\ufe0f Step 6: Create Model"
  ]
 },
 {
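Note: `SCALE, SHIFT = 0.3611, 0.1159` normalizes the FLUX VAE latents before training, and the decode calls later invert it with `samp / SCALE + SHIFT`. A quick round-trip check (latent shape taken from the `(N, 16, ls, ls)` sampling calls in this notebook):

```python
import torch

SCALE, SHIFT = 0.3611, 0.1159
raw = torch.randn(2, 16, 32, 32)      # 16-channel latents; 256px -> 32x32
normed = (raw - SHIFT) * SCALE        # forward, as in the training loop
print(torch.allclose(normed / SCALE + SHIFT, raw, atol=1e-6))  # True
```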
@@ -240,21 +127,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\n",
-
-
-
-
-
-
-   "
-   "ema = EMAModel(model, EMA_DECAY)\n",
-   "scaler = GradScaler(\"cuda\")\n",
-   "fm = FlowMatchingScheduler()\n",
-   "os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n",
-   "os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n",
-   "print(f\"Total steps: {total_steps}, Effective batch: {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n",
-   "if torch.cuda.is_available(): print(f\"VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB used\")"
+   "cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\ncfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\nmodel = LiquidGen(**cfg).to(device)\nprint(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\nscheduler = cosine_sched(optimizer, WARMUP_STEPS, MAX_STEPS)\nema = EMAModel(model, EMA_DECAY)\nscaler = GradScaler(\"cuda\")\nfm = FlowMatchingScheduler()\nos.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\nos.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\nprint(f\"Training: {MAX_STEPS} steps, effective batch {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n"
+  ]
+ },
+ {
+  "cell_type": "markdown",
+  "metadata": {},
+  "source": [
+   "## \ud83d\ude80 Step 7: Train!"
  ]
 },
 {
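Note: `cosine_sched` above gives a linear warmup to the base LR followed by cosine decay to zero. A standalone sketch of the multiplier it produces, with the notebook's `WARMUP_STEPS = 500` and `MAX_STEPS = 20000`:

```python
import math

def lr_mult(s, warmup=500, total=20_000):
    # Same shape as the LambdaLR lambda inside cosine_sched above.
    if s < warmup:
        return s / max(1, warmup)
    return max(0, 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(1, total - warmup))))

for s in (0, 250, 500, 10_250, 20_000):
    print(s, round(lr_mult(s), 3))   # 0.0, 0.5, 1.0, 0.5, 0.0
```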
@@ -263,57 +143,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "global_step = 0; loss_accum = 0.0; log_losses = []\n",
-   "print(\"\ud83d\ude80 Training!\\n\")\n",
-   "t0 = time.time()\n",
-   "for epoch in range(NUM_EPOCHS):\n",
-   "    model.train(); ep_loss = 0; ep_steps = 0; ep_t = time.time()\n",
-   "    for bi, (imgs, lbls) in enumerate(train_loader):\n",
-   "        imgs = imgs.to(device)\n",
-   "        lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n",
-   "        with torch.no_grad():\n",
-   "            lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n",
-   "            lats = ((lats - SHIFT) * SCALE).float()\n",
-   "        t = fm.sample_timesteps(lats.shape[0], device)\n",
-   "        noise = torch.randn_like(lats)\n",
-   "        xt = fm.add_noise(lats, noise, t)\n",
-   "        vtgt = fm.get_velocity_target(lats, noise)\n",
-   "        with autocast(\"cuda\"): loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n",
-   "        scaler.scale(loss).backward()\n",
-   "        loss_accum += loss.item()\n",
-   "        if (bi+1) % GRADIENT_ACCUMULATION == 0:\n",
-   "            scaler.unscale_(optimizer)\n",
-   "            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
-   "            scaler.step(optimizer); scaler.update(); optimizer.zero_grad(); scheduler.step()\n",
-   "            ema.update(model); global_step += 1\n",
-   "            if global_step % LOG_EVERY == 0:\n",
-   "                al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n",
-   "                vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n",
-   "                print(f\"step={global_step:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G\")\n",
-   "                log_losses.append(al); loss_accum = 0\n",
-   "                if math.isnan(al) or al > 50: print(\"\ud83d\udca5 Diverged!\"); break\n",
-   "            if global_step % SAMPLE_EVERY == 0:\n",
-   "                ema.apply(model); model.eval()\n",
-   "                ls = IMAGE_SIZE // 8\n",
-   "                sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n",
-   "                samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n",
-   "                with torch.no_grad(): si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
-   "                save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n",
-   "                ema.restore(model); model.train()\n",
-   "            if global_step % SAVE_EVERY == 0:\n",
-   "                torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n",
-   "                           f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n",
-   "        ep_loss += loss.item()*GRADIENT_ACCUMULATION; ep_steps += 1\n",
-   "    print(f\"Epoch {epoch} | loss={ep_loss/max(ep_steps,1):.4f} | {time.time()-ep_t:.0f}s\")\n",
-   "torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n",
-   "print(f\"\ud83c\udf89 Done! {global_step} steps in {(time.time()-t0)/60:.1f} min\")"
+   "global_step = 0; loss_accum = 0.0; log_losses = []; accum_count = 0\nprint(\"Training started!\\n\")\nt0 = time.time(); model.train()\n\nwhile global_step < MAX_STEPS:\n    for imgs, lbls in train_loader:\n        if global_step >= MAX_STEPS: break\n        imgs = imgs.to(device)\n        lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n\n        with torch.no_grad():\n            lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n            lats = ((lats - SHIFT) * SCALE).float()\n\n        t = fm.sample_t(lats.shape[0], device)\n        noise = torch.randn_like(lats)\n        xt = fm.add_noise(lats, noise, t)\n        vtgt = fm.velocity(lats, noise)\n\n        with autocast(\"cuda\"):\n            loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n        scaler.scale(loss).backward()\n        loss_accum += loss.item()\n        accum_count += 1\n\n        if accum_count % GRADIENT_ACCUMULATION == 0:\n            scaler.unscale_(optimizer)\n            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n            scaler.step(optimizer); scaler.update(); optimizer.zero_grad()\n            scheduler.step(); ema.update(model); global_step += 1\n\n            if global_step % LOG_EVERY == 0:\n                al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n                vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n                print(f\"step={global_step:>6d} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G | {time.time()-t0:.0f}s\")\n                log_losses.append(al); loss_accum = 0\n                if math.isnan(al) or al > 50: print(\"Diverged!\"); break\n\n            if global_step % SAMPLE_EVERY == 0:\n                ema.apply(model); model.eval()\n                ls = IMAGE_SIZE // 8\n                sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n                samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n                with torch.no_grad():\n                    si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n                save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n                print(\"  Saved samples\")\n                ema.restore(model); model.train()\n\n            if global_step % SAVE_EVERY == 0:\n                torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n                           f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n                print(\"  Checkpoint saved\")\n\ntorch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},\n           f\"{OUTPUT_DIR}/checkpoints/final.pt\")\nprint(f\"\\nDone! {global_step} steps in {(time.time()-t0)/60:.1f} min\")\n"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83d\udcc8
+   "## \ud83d\udcc8 Step 8: Loss Curve"
  ]
 },
 {
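Note on the `loss / GRADIENT_ACCUMULATION` division in the loop: gradients from successive `backward()` calls add up, so scaling each micro-batch loss by 1/K makes the accumulated gradient equal the gradient of the mean loss over K micro-batches. A small check of that identity:

```python
import torch

w = torch.ones(3, requires_grad=True)
micro = [torch.randn(3) for _ in range(4)]    # K = 4 micro-batches
for x in micro:
    ((w * x).sum() / 4).backward()            # scaled, as in the loop above
full = torch.autograd.grad((w * torch.stack(micro).mean(0)).sum(), w)[0]
print(torch.allclose(w.grad, full))           # True
```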
@@ -322,18 +159,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "import matplotlib.pyplot as plt\n",
-   "if log_losses:\n",
-   "    plt.figure(figsize=(10,4)); plt.plot(log_losses); plt.xlabel(f\"Steps (\u00d7{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n",
-   "    plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3); plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n",
-   "    print(f\"Min loss: {min(log_losses):.4f}\")"
+   "import matplotlib.pyplot as plt\nif log_losses:\n    plt.figure(figsize=(10,4)); plt.plot(log_losses)\n    plt.xlabel(f\"Steps (x{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n    plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3)\n    plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n    print(f\"Min loss: {min(log_losses):.4f}\")\n"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83c\udfa8
+   "## \ud83c\udfa8 Step 9: Generate"
  ]
 },
 {
@@ -342,32 +175,14 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "ema.apply(model); model.eval()\n",
-   "N, STEPS, CFG = 8, 50, 2.5\n",
-   "ls = IMAGE_SIZE // 8\n",
-   "STYLES = [\"Abstract Expressionism\",\"Baroque\",\"Cubism\",\"Expressionism\",\"Impressionism\",\n",
-   "          \"Pop Art\",\"Realism\",\"Romanticism\",\"Symbolism\",\"Ukiyo-e\"]\n",
-   "if NUM_CLASSES > 0:\n",
-   "    for ci in range(min(NUM_CLASSES, 8)):\n",
-   "        l = torch.full((N,), ci, device=device, dtype=torch.long)\n",
-   "        s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, CFG)\n",
-   "        with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
-   "        nm = STYLES[ci] if ci < len(STYLES) else f\"Class_{ci}\"\n",
-   "        save_image(i, f\"{OUTPUT_DIR}/gen_{nm.replace(chr(32),chr(95))}.png\", nrow=4)\n",
-   "        print(f\"Generated: {nm}\")\n",
-   "else:\n",
-   "    s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n",
-   "    with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
-   "    save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\n",
-   "ema.restore(model)\n",
-   "print(f\"\u2705 Saved to {OUTPUT_DIR}/\")"
+   "ema.apply(model); model.eval()\nN, STEPS, G = 8, 50, 2.5; ls = IMAGE_SIZE // 8\nif NUM_CLASSES > 0:\n    for ci in range(min(NUM_CLASSES, 6)):\n        l = torch.full((N,), ci, device=device, dtype=torch.long)\n        s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, G)\n        with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n        save_image(i, f\"{OUTPUT_DIR}/gen_class{ci}.png\", nrow=4)\n        print(f\"Class {ci}\")\nelse:\n    s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n    with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n    save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\nema.restore(model)\nprint(f\"Saved to {OUTPUT_DIR}/\")\n"
  ]
 },
 {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "## \ud83d\udce4
+   "## \ud83d\udce4 Step 10: Display"
  ]
 },
 {
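Note: the `G = 2.5` here is the classifier-free guidance scale used inside `fm.sample`: v = v_uncond + G * (v_cond - v_uncond), extrapolating past the conditional velocity (G = 1 recovers it exactly; the unconditional branch reuses label 0 via `torch.zeros_like(labels)`). The blend, in isolation:

```python
import torch

# CFG blend on dummy velocity predictions.
v_uncond, v_cond = torch.randn(4), torch.randn(4)
guided = v_uncond + 2.5 * (v_cond - v_uncond)                        # G = 2.5, as here
assert torch.allclose(v_uncond + 1.0 * (v_cond - v_uncond), v_cond)  # G = 1 -> conditional
print(guided)
```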
@@ -376,35 +191,7 @@
   "metadata": {},
   "outputs": [],
   "source": [
-   "from IPython.display import display\n",
-   "import glob\n",
-   "for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n",
-   "    print(os.path.basename(f)); display(Image.open(f))\n",
-   "for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n",
-   "    print(os.path.basename(f)); display(Image.open(f))"
-  ]
- },
- {
-  "cell_type": "markdown",
-  "metadata": {},
-  "source": [
-   "## \ud83d\udcdd Architecture Reference\n",
-   "\n",
-   "### Core Equation (CfC Liquid Dynamics)\n",
-   "\n",
-   "\n",
-   "### Flow Matching\n",
-   "\n",
-   "\n",
-   "### Sampling (Euler ODE)\n",
-   "\n",
-   "\n",
-   "### References\n",
-   "- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n",
-   "- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature MI 2022)\n",
-   "- Lechner et al., \"Neural Circuit Policies\" (Nature MI 2020)\n",
-   "- ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n",
-   "- Lipman et al., \"Flow Matching\" (2023), SiT (2024)\n"
+   "from IPython.display import display\nimport glob\nfor f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n    print(os.path.basename(f)); display(Image.open(f))\nfor f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n    print(os.path.basename(f)); display(Image.open(f))\n"
  ]
 }
]