asdf98 committed on
Commit 1373ccf · verified · 1 Parent(s): a1ff09a

Fix: streaming dataset in notebook (no full download on Colab)

Files changed (1)
  1. LiquidGen_Colab_Notebook.ipynb +34 -247
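The change replaces the map-style `load_dataset` call with Hugging Face `datasets` streaming mode, so WikiArt images are fetched lazily instead of being downloaded in full before training. A minimal sketch of the idea (dataset name, seed, and buffer size taken from the notebook; the rest is illustrative):

```python
from datasets import load_dataset

# streaming=True returns an IterableDataset: shards are fetched lazily over
# the network, so training can begin without materializing the archive on disk.
ds = load_dataset("huggan/wikiart", split="train", streaming=True)

# Approximate shuffling via a reservoir buffer (an exact shuffle would need the full set).
ds = ds.shuffle(seed=42, buffer_size=1000)

# Pull a couple of examples to confirm the stream works before training.
for example in ds.take(2):
    print(type(example["image"]), example.get("style"))
```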
LiquidGen_Colab_Notebook.ipynb CHANGED
@@ -21,21 +21,17 @@
21
  "\n",
22
  "**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
23
  "\n",
24
- "### Key Features:\n",
25
  "- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
26
  "- **Fully Parallelizable** \u2014 No sequential ODE solving\n",
27
- "- **Flow Matching** \u2014 Modern velocity-prediction training\n",
28
- "- **Frozen Flux VAE** \u2014 16-channel latent space\n",
29
- "- **Fits 16GB VRAM** \u2014 Designed for Colab free tier\n",
30
- "\n",
31
- "Based on: Liquid Time-constant Networks (NeurIPS 2020), CfC (Nature MI 2022), ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n"
32
  ]
33
  },
34
  {
35
  "cell_type": "markdown",
36
  "metadata": {},
37
  "source": [
38
- "## \ud83d\udce6 Install Dependencies"
39
  ]
40
  },
41
  {
@@ -44,14 +40,21 @@
44
  "metadata": {},
45
  "outputs": [],
46
  "source": [
47
- "!pip install -q torch torchvision diffusers datasets accelerate huggingface_hub Pillow"
48
  ]
49
  },
50
  {
51
  "cell_type": "markdown",
52
  "metadata": {},
53
  "source": [
54
- "## \ud83d\udd27 Configuration"
55
  ]
56
  },
57
  {
@@ -60,48 +63,14 @@
60
  "metadata": {},
61
  "outputs": [],
62
  "source": [
63
- "MODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\n",
64
- "IMAGE_SIZE = 256 # 256 or 512\n",
65
- "DATASET_NAME = \"huggan/wikiart\"\n",
66
- "IMAGE_COLUMN = \"image\"\n",
67
- "LABEL_COLUMN = \"style\" # \"style\" (27), \"genre\" (11), \"\" for unconditional\n",
68
- "NUM_CLASSES = 27\n",
69
- "BATCH_SIZE = 8\n",
70
- "GRADIENT_ACCUMULATION = 4\n",
71
- "LEARNING_RATE = 1e-4\n",
72
- "WEIGHT_DECAY = 0.01\n",
73
- "MAX_GRAD_NORM = 2.0\n",
74
- "NUM_EPOCHS = 50\n",
75
- "WARMUP_STEPS = 500\n",
76
- "EMA_DECAY = 0.9999\n",
77
- "NUM_SAMPLE_STEPS = 50\n",
78
- "CFG_SCALE = 2.0\n",
79
- "OUTPUT_DIR = \"/content/liquidgen_outputs\"\n",
80
- "SAVE_EVERY = 2000\n",
81
- "SAMPLE_EVERY = 500\n",
82
- "LOG_EVERY = 50\n",
83
- "PUSH_TO_HUB = False\n",
84
- "HUB_MODEL_ID = \"\"\n",
85
- "VAE_ID = \"black-forest-labs/FLUX.1-schnell\"\n",
86
- "VAE_SUBFOLDER = \"vae\"\n",
87
- "\n",
88
- "import torch\n",
89
- "if torch.cuda.is_available():\n",
90
- " gpu = torch.cuda.get_device_name(0)\n",
91
- " mem = torch.cuda.get_device_properties(0).total_mem / 1024**3\n",
92
- " print(f\"GPU: {gpu} ({mem:.1f} GB)\")\n",
93
- " if mem < 12: print(\"\u26a0\ufe0f Low VRAM! Use small model, 256px, bs=4\")\n",
94
- " elif mem < 20: print(\"\u2705 T4 detected. Good for base model, 256px\")\n",
95
- " else: print(\"\ud83d\ude80 Large GPU! Can run large model, 512px\")\n",
96
- "else:\n",
97
- " print(\"\u26a0\ufe0f No GPU! Go to Runtime \u2192 Change runtime type \u2192 GPU\")"
98
  ]
99
  },
100
  {
101
  "cell_type": "markdown",
102
  "metadata": {},
103
  "source": [
104
- "## \ud83c\udfd7\ufe0f Model Architecture"
105
  ]
106
  },
107
  {
@@ -117,7 +86,7 @@
117
  "cell_type": "markdown",
118
  "metadata": {},
119
  "source": [
120
- "## \ud83d\udd04 Training Utilities"
121
  ]
122
  },
123
  {
@@ -126,77 +95,14 @@
126
  "metadata": {},
127
  "outputs": [],
128
  "source": [
129
- "import os, json, time, math\n",
130
- "import numpy as np\n",
131
- "from torch.utils.data import DataLoader, Dataset\n",
132
- "from torch.amp import autocast, GradScaler\n",
133
- "from torchvision import transforms\n",
134
- "from torchvision.utils import save_image\n",
135
- "from PIL import Image\n",
136
- "\n",
137
- "class FlowMatchingScheduler:\n",
138
- " def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n",
139
- " def sample_timesteps(self, bs, dev): return torch.rand(bs, device=dev) * (self.max_t - self.min_t) + self.min_t\n",
140
- " def add_noise(self, x0, noise, t): t = t.view(-1,1,1,1); return (1-t)*x0 + t*noise\n",
141
- " def get_velocity_target(self, x0, noise): return noise - x0\n",
142
- " @torch.no_grad()\n",
143
- " def sample(self, model, shape, dev, num_steps=50, labels=None, cfg=1.0):\n",
144
- " model.eval(); x = torch.randn(shape, device=dev)\n",
145
- " dt = 1.0 / num_steps\n",
146
- " for t_val in torch.linspace(1.0, dt, num_steps, device=dev):\n",
147
- " t = torch.full((shape[0],), t_val.item(), device=dev)\n",
148
- " with torch.amp.autocast(\"cuda\"):\n",
149
- " if cfg > 1.0 and labels is not None:\n",
150
- " vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n",
151
- " v = vu + cfg * (vc - vu)\n",
152
- " else: v = model(x, t, labels)\n",
153
- " x = x - dt * v.float()\n",
154
- " return x\n",
155
- "\n",
156
- "class EMAModel:\n",
157
- " def __init__(self, model, decay=0.9999):\n",
158
- " self.decay = decay\n",
159
- " self.shadow = {n: p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n",
160
- " @torch.no_grad()\n",
161
- " def update(self, model):\n",
162
- " for n,p in model.named_parameters():\n",
163
- " if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n",
164
- " def apply(self, model):\n",
165
- " self.backup = {n: p.data.clone() for n,p in model.named_parameters() if p.requires_grad}\n",
166
- " for n,p in model.named_parameters():\n",
167
- " if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n",
168
- " def restore(self, model):\n",
169
- " for n,p in model.named_parameters():\n",
170
- " if p.requires_grad and n in self.backup: p.data.copy_(self.backup[n])\n",
171
- "\n",
172
- "class ImageDataset(Dataset):\n",
173
- " def __init__(self, ds, tf, img_col, lbl_col=\"\"): self.ds, self.tf, self.ic, self.lc = ds, tf, img_col, lbl_col\n",
174
- " def __len__(self): return len(self.ds)\n",
175
- " def __getitem__(self, i):\n",
176
- " item = self.ds[i]; img = item[self.ic]\n",
177
- " if img.mode != \"RGB\": img = img.convert(\"RGB\")\n",
178
- " label = item[self.lc] if self.lc and self.lc in item else -1\n",
179
- " return self.tf(img), label\n",
180
- "\n",
181
- "def cosine_sched(opt, warmup, total):\n",
182
- " def lr(s):\n",
183
- " if s < warmup: return s / max(1, warmup)\n",
184
- " return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n",
185
- " return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n",
186
- "\n",
187
- "MODEL_CONFIGS = {\n",
188
- " \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n",
189
- " \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n",
190
- " \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n",
191
- "}\n",
192
- "print(\"\u2705 Training utilities ready!\")"
193
  ]
194
  },
195
  {
196
  "cell_type": "markdown",
197
  "metadata": {},
198
  "source": [
199
- "## \ud83d\udcca Load Dataset & VAE"
200
  ]
201
  },
202
  {
@@ -205,33 +111,14 @@
205
  "metadata": {},
206
  "outputs": [],
207
  "source": [
208
- "from datasets import load_dataset\n",
209
- "from diffusers import AutoencoderKL\n",
210
- "\n",
211
- "print(f\"Loading dataset: {DATASET_NAME}...\")\n",
212
- "dataset = load_dataset(DATASET_NAME, split=\"train\")\n",
213
- "print(f\" {len(dataset)} images\")\n",
214
- "\n",
215
- "transform = transforms.Compose([\n",
216
- " transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
217
- " transforms.CenterCrop(IMAGE_SIZE), transforms.RandomHorizontalFlip(), transforms.ToTensor(),\n",
218
- "])\n",
219
- "train_ds = ImageDataset(dataset, transform, IMAGE_COLUMN, LABEL_COLUMN)\n",
220
- "train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n",
221
- "\n",
222
- "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
223
- "vae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\n",
224
- "for p in vae.parameters(): p.requires_grad_(False)\n",
225
- "print(f\"VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\n",
226
- "SCALE, SHIFT = 0.3611, 0.1159\n",
227
- "print(\"\u2705 Ready!\")"
228
  ]
229
  },
230
  {
231
  "cell_type": "markdown",
232
  "metadata": {},
233
  "source": [
234
- "## \ud83c\udfcb\ufe0f Create Model & Train"
235
  ]
236
  },
237
  {
@@ -240,21 +127,14 @@
240
  "metadata": {},
241
  "outputs": [],
242
  "source": [
243
- "cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\n",
244
- "cfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\n",
245
- "model = LiquidGen(**cfg).to(device)\n",
246
- "print(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n",
247
- "\n",
248
- "optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\n",
249
- "total_steps = len(train_loader) * NUM_EPOCHS // GRADIENT_ACCUMULATION\n",
250
- "scheduler = cosine_sched(optimizer, WARMUP_STEPS, total_steps)\n",
251
- "ema = EMAModel(model, EMA_DECAY)\n",
252
- "scaler = GradScaler(\"cuda\")\n",
253
- "fm = FlowMatchingScheduler()\n",
254
- "os.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\n",
255
- "os.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\n",
256
- "print(f\"Total steps: {total_steps}, Effective batch: {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n",
257
- "if torch.cuda.is_available(): print(f\"VRAM: {torch.cuda.memory_allocated()/1024**3:.2f} GB used\")"
258
  ]
259
  },
260
  {
@@ -263,57 +143,14 @@
263
  "metadata": {},
264
  "outputs": [],
265
  "source": [
266
- "global_step = 0; loss_accum = 0.0; log_losses = []\n",
267
- "print(\"\ud83d\ude80 Training!\n\")\n",
268
- "t0 = time.time()\n",
269
- "for epoch in range(NUM_EPOCHS):\n",
270
- " model.train(); ep_loss = 0; ep_steps = 0; ep_t = time.time()\n",
271
- " for bi, (imgs, lbls) in enumerate(train_loader):\n",
272
- " imgs = imgs.to(device)\n",
273
- " lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n",
274
- " with torch.no_grad():\n",
275
- " lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n",
276
- " lats = ((lats - SHIFT) * SCALE).float()\n",
277
- " t = fm.sample_timesteps(lats.shape[0], device)\n",
278
- " noise = torch.randn_like(lats)\n",
279
- " xt = fm.add_noise(lats, noise, t)\n",
280
- " vtgt = fm.get_velocity_target(lats, noise)\n",
281
- " with autocast(\"cuda\"): loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n",
282
- " scaler.scale(loss).backward()\n",
283
- " loss_accum += loss.item()\n",
284
- " if (bi+1) % GRADIENT_ACCUMULATION == 0:\n",
285
- " scaler.unscale_(optimizer)\n",
286
- " gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n",
287
- " scaler.step(optimizer); scaler.update(); optimizer.zero_grad(); scheduler.step()\n",
288
- " ema.update(model); global_step += 1\n",
289
- " if global_step % LOG_EVERY == 0:\n",
290
- " al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n",
291
- " vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n",
292
- " print(f\"step={global_step:>6d} | ep={epoch} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G\")\n",
293
- " log_losses.append(al); loss_accum = 0\n",
294
- " if math.isnan(al) or al > 50: print(\"\ud83d\udca5 Diverged!\"); break\n",
295
- " if global_step % SAMPLE_EVERY == 0:\n",
296
- " ema.apply(model); model.eval()\n",
297
- " ls = IMAGE_SIZE // 8\n",
298
- " sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n",
299
- " samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n",
300
- " with torch.no_grad(): si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
301
- " save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n",
302
- " ema.restore(model); model.train()\n",
303
- " if global_step % SAVE_EVERY == 0:\n",
304
- " torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n",
305
- " f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n",
306
- " ep_loss += loss.item()*GRADIENT_ACCUMULATION; ep_steps += 1\n",
307
- " print(f\"Epoch {epoch} | loss={ep_loss/max(ep_steps,1):.4f} | {time.time()-ep_t:.0f}s\")\n",
308
- "torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},f\"{OUTPUT_DIR}/checkpoints/final.pt\")\n",
309
- "print(f\"\ud83c\udf89 Done! {global_step} steps in {(time.time()-t0)/60:.1f} min\")"
310
  ]
311
  },
312
  {
313
  "cell_type": "markdown",
314
  "metadata": {},
315
  "source": [
316
- "## \ud83d\udcc8 Training Loss"
317
  ]
318
  },
319
  {
@@ -322,18 +159,14 @@
322
  "metadata": {},
323
  "outputs": [],
324
  "source": [
325
- "import matplotlib.pyplot as plt\n",
326
- "if log_losses:\n",
327
- " plt.figure(figsize=(10,4)); plt.plot(log_losses); plt.xlabel(f\"Steps (\u00d7{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n",
328
- " plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3); plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n",
329
- " print(f\"Min loss: {min(log_losses):.4f}\")"
330
  ]
331
  },
332
  {
333
  "cell_type": "markdown",
334
  "metadata": {},
335
  "source": [
336
- "## \ud83c\udfa8 Generate Images"
337
  ]
338
  },
339
  {
@@ -342,32 +175,14 @@
342
  "metadata": {},
343
  "outputs": [],
344
  "source": [
345
- "ema.apply(model); model.eval()\n",
346
- "N, STEPS, CFG = 8, 50, 2.5\n",
347
- "ls = IMAGE_SIZE // 8\n",
348
- "STYLES = [\"Abstract Expressionism\",\"Baroque\",\"Cubism\",\"Expressionism\",\"Impressionism\",\n",
349
- " \"Pop Art\",\"Realism\",\"Romanticism\",\"Symbolism\",\"Ukiyo-e\"]\n",
350
- "if NUM_CLASSES > 0:\n",
351
- " for ci in range(min(NUM_CLASSES, 8)):\n",
352
- " l = torch.full((N,), ci, device=device, dtype=torch.long)\n",
353
- " s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, CFG)\n",
354
- " with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
355
- " nm = STYLES[ci] if ci < len(STYLES) else f\"Class_{ci}\"\n",
356
- " save_image(i, f\"{OUTPUT_DIR}/gen_{nm.replace(chr(32),chr(95))}.png\", nrow=4)\n",
357
- " print(f\"Generated: {nm}\")\n",
358
- "else:\n",
359
- " s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n",
360
- " with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n",
361
- " save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\n",
362
- "ema.restore(model)\n",
363
- "print(f\"\u2705 Saved to {OUTPUT_DIR}/\")"
364
  ]
365
  },
366
  {
367
  "cell_type": "markdown",
368
  "metadata": {},
369
  "source": [
370
- "## \ud83d\udce4 Display Results"
371
  ]
372
  },
373
  {
@@ -376,35 +191,7 @@
376
  "metadata": {},
377
  "outputs": [],
378
  "source": [
379
- "from IPython.display import display\n",
380
- "import glob\n",
381
- "for f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n",
382
- " print(os.path.basename(f)); display(Image.open(f))\n",
383
- "for f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n",
384
- " print(os.path.basename(f)); display(Image.open(f))"
385
- ]
386
- },
387
- {
388
- "cell_type": "markdown",
389
- "metadata": {},
390
- "source": [
391
- "## \ud83d\udcdd Architecture Reference\n",
392
- "\n",
393
- "### Core Equation (CfC Liquid Dynamics)\n",
394
- "\n",
395
- "\n",
396
- "### Flow Matching\n",
397
- "\n",
398
- "\n",
399
- "### Sampling (Euler ODE)\n",
400
- "\n",
401
- "\n",
402
- "### References\n",
403
- "- Hasani et al., \"Liquid Time-constant Networks\" (NeurIPS 2020)\n",
404
- "- Hasani et al., \"Closed-form Continuous-depth Models\" (Nature MI 2022)\n",
405
- "- Lechner et al., \"Neural Circuit Policies\" (Nature MI 2020)\n",
406
- "- ZigMa (ECCV 2024), DiMSUM (NeurIPS 2024)\n",
407
- "- Lipman et al., \"Flow Matching\" (2023), SiT (2024)\n"
408
  ]
409
  }
410
  ]
 
21
  "\n",
22
  "**A novel attention-free diffusion model using CfC Liquid Neural Network dynamics.**\n",
23
  "\n",
24
  "- **No Attention** \u2014 O(n) complexity using liquid time constants\n",
25
  "- **Fully Parallelizable** \u2014 No sequential ODE solving\n",
26
+ "- **Streaming Dataset** \u2014 No full download, starts training immediately\n",
27
+ "- **Fits 16GB VRAM** \u2014 Designed for Colab free tier T4\n"
28
  ]
29
  },
30
  {
31
  "cell_type": "markdown",
32
  "metadata": {},
33
  "source": [
34
+ "## \ud83d\udce6 Step 1: Install"
35
  ]
36
  },
37
  {
 
40
  "metadata": {},
41
  "outputs": [],
42
  "source": [
43
+ "!pip install -q torch torchvision diffusers datasets accelerate Pillow"
44
  ]
45
  },
46
  {
47
  "cell_type": "markdown",
48
  "metadata": {},
49
  "source": [
50
+ "## \ud83d\udd27 Step 2: Configuration\n",
51
+ "\n",
52
+ "**Dataset options:**\n",
53
+ "| Dataset | Size | Download | Type |\n",
54
+ "|---------|------|----------|------|\n",
55
+ "| `huggan/wikiart` | ~80K | **Streaming** (no download!) | Art, 27 styles |\n",
56
+ "| `reach-vb/pokemon-blip-captions` | 833 | 95MB (fast) | Pokemon |\n",
57
+ "| `huggan/flowers-102-categories` | 8K | 330MB | Flowers |\n"
58
  ]
59
  },
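If you swap in a different dataset, the column names and class count must match the configuration cell below. A quick, hedged way to check without downloading anything, assuming the dataset exposes standard `datasets` features (`style` here is the WikiArt label column):

```python
from datasets import load_dataset

ds = load_dataset("huggan/wikiart", split="train", streaming=True)
sample = next(iter(ds))            # fetches a single example over the network
print(list(sample.keys()))         # column names -> pick IMAGE_COLUMN / LABEL_COLUMN
print(type(sample["image"]))       # should be a PIL image

# If the label column is a ClassLabel feature, its size gives NUM_CLASSES.
feat = ds.features.get("style") if ds.features else None
print(getattr(feat, "num_classes", "unknown"))
```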
60
  {
 
63
  "metadata": {},
64
  "outputs": [],
65
  "source": [
66
+ "# ============================================================================\n# CONFIGURATION\n# ============================================================================\n\nMODEL_SIZE = \"small\" # \"small\" (~55M), \"base\" (~140M), \"large\" (~280M)\nIMAGE_SIZE = 256 # 256 or 512\n\n# --- Dataset (Option A: WikiArt streaming \u2014 NO download) ---\nDATASET_NAME = \"huggan/wikiart\"\nIMAGE_COLUMN = \"image\"\nLABEL_COLUMN = \"style\" # \"style\"(27), \"genre\"(11), \"\" for unconditional\nNUM_CLASSES = 27\nUSE_STREAMING = True # KEY: no full download!\n\n# --- Dataset (Option B: Pokemon \u2014 small, fast download, good for testing) ---\n# DATASET_NAME = \"reach-vb/pokemon-blip-captions\"\n# IMAGE_COLUMN = \"image\"; LABEL_COLUMN = \"\"; NUM_CLASSES = 0; USE_STREAMING = False\n\n# --- Training ---\nBATCH_SIZE = 8; GRADIENT_ACCUMULATION = 4\nLEARNING_RATE = 1e-4; WEIGHT_DECAY = 0.01; MAX_GRAD_NORM = 2.0\nMAX_STEPS = 20000; WARMUP_STEPS = 500; EMA_DECAY = 0.9999\nNUM_SAMPLE_STEPS = 50; CFG_SCALE = 2.0\n\n# --- Saving ---\nOUTPUT_DIR = \"/content/liquidgen_outputs\"\nSAVE_EVERY = 5000; SAMPLE_EVERY = 500; LOG_EVERY = 50\n\n# --- VAE ---\nVAE_ID = \"black-forest-labs/FLUX.1-schnell\"\nVAE_SUBFOLDER = \"vae\"\nSCALE, SHIFT = 0.3611, 0.1159\n\nimport torch\nif torch.cuda.is_available():\n gpu = torch.cuda.get_device_name(0)\n mem = torch.cuda.get_device_properties(0).total_mem / 1024**3\n print(f\"GPU: {gpu} ({mem:.1f} GB)\")\nelse:\n print(\"No GPU! Go to Runtime > Change runtime type > GPU\")\n"
67
  ]
68
  },
69
  {
70
  "cell_type": "markdown",
71
  "metadata": {},
72
  "source": [
73
+ "## \ud83c\udfd7\ufe0f Step 3: Model Architecture"
74
  ]
75
  },
76
  {
 
86
  "cell_type": "markdown",
87
  "metadata": {},
88
  "source": [
89
+ "## \ud83d\udd04 Step 4: Training Utilities"
90
  ]
91
  },
92
  {
 
95
  "metadata": {},
96
  "outputs": [],
97
  "source": [
98
+ "import os, time, math\nimport numpy as np\nfrom torch.utils.data import DataLoader, IterableDataset, Dataset\nfrom torch.amp import autocast, GradScaler\nfrom torchvision import transforms\nfrom torchvision.utils import save_image\nfrom PIL import Image\n\nclass StreamingImageDataset(IterableDataset):\n \"\"\"Streaming dataset \u2014 NO full download. Images load on-the-fly.\"\"\"\n def __init__(self, name, img_col=\"image\", lbl_col=\"\", img_size=256,\n split=\"train\", config=\"\", buffer=1000, seed=42):\n super().__init__()\n self.name, self.img_col, self.lbl_col = name, img_col, lbl_col\n self.split, self.config, self.buffer, self.seed = split, config, buffer, seed\n self.tf = transforms.Compose([\n transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n\n def __iter__(self):\n from datasets import load_dataset\n kw = {\"name\": self.config} if self.config else {}\n ds = load_dataset(self.name, split=self.split, streaming=True, **kw)\n ds = ds.shuffle(seed=self.seed, buffer_size=self.buffer)\n for item in ds:\n try:\n img = item[self.img_col]\n if img.mode != \"RGB\": img = img.convert(\"RGB\")\n lbl = item[self.lbl_col] if self.lbl_col and self.lbl_col in item else -1\n yield self.tf(img), lbl\n except: continue\n\nclass MapImageDataset(Dataset):\n \"\"\"For small datasets (<500MB) \u2014 downloads once.\"\"\"\n def __init__(self, name, img_col=\"image\", lbl_col=\"\", img_size=256, split=\"train\"):\n from datasets import load_dataset\n print(f\"Downloading {name}...\")\n self.ds = load_dataset(name, split=split)\n self.img_col, self.lbl_col = img_col, lbl_col\n self.tf = transforms.Compose([\n transforms.Resize(img_size, interpolation=transforms.InterpolationMode.LANCZOS),\n transforms.CenterCrop(img_size), transforms.RandomHorizontalFlip(), transforms.ToTensor()])\n print(f\" {len(self.ds)} images\")\n\n def __len__(self): return len(self.ds)\n def __getitem__(self, i):\n item = self.ds[i]; img = item[self.img_col]\n if img.mode != \"RGB\": img = img.convert(\"RGB\")\n lbl = item[self.lbl_col] if self.lbl_col and self.lbl_col in item else -1\n return self.tf(img), lbl\n\nclass FlowMatchingScheduler:\n def __init__(self, min_t=0.001, max_t=0.999): self.min_t, self.max_t = min_t, max_t\n def sample_t(self, bs, dev): return torch.rand(bs, device=dev)*(self.max_t-self.min_t)+self.min_t\n def add_noise(self, x0, noise, t): return (1-t.view(-1,1,1,1))*x0 + t.view(-1,1,1,1)*noise\n def velocity(self, x0, noise): return noise - x0\n @torch.no_grad()\n def sample(self, model, shape, dev, steps=50, labels=None, cfg=1.0):\n model.eval(); x = torch.randn(shape, device=dev); dt = 1.0/steps\n for tv in torch.linspace(1.0, dt, steps, device=dev):\n t = torch.full((shape[0],), tv.item(), device=dev)\n with torch.amp.autocast(\"cuda\"):\n if cfg > 1.0 and labels is not None:\n vc = model(x,t,labels); vu = model(x,t,torch.zeros_like(labels))\n v = vu + cfg*(vc-vu)\n else: v = model(x,t,labels)\n x = x - dt * v.float()\n return x\n\nclass EMAModel:\n def __init__(self, model, decay=0.9999):\n self.decay = decay\n self.shadow = {n:p.clone().detach() for n,p in model.named_parameters() if p.requires_grad}\n @torch.no_grad()\n def update(self, m):\n for n,p in m.named_parameters():\n if p.requires_grad and n in self.shadow: self.shadow[n].mul_(self.decay).add_(p.data, alpha=1-self.decay)\n def apply(self, m):\n self.bk = {n:p.data.clone() for n,p in m.named_parameters() if 
p.requires_grad}\n for n,p in m.named_parameters():\n if p.requires_grad and n in self.shadow: p.data.copy_(self.shadow[n])\n def restore(self, m):\n for n,p in m.named_parameters():\n if p.requires_grad and n in self.bk: p.data.copy_(self.bk[n])\n\ndef cosine_sched(opt, warmup, total):\n def lr(s):\n if s < warmup: return s/max(1,warmup)\n return max(0, 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total-warmup))))\n return torch.optim.lr_scheduler.LambdaLR(opt, lr)\n\nMODEL_CONFIGS = {\n \"small\": dict(embed_dim=512, depth=12, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=3.0),\n \"base\": dict(embed_dim=640, depth=18, spatial_kernel=7, scan_kernel=31, expand_ratio=2.0, mlp_ratio=4.0),\n \"large\": dict(embed_dim=768, depth=24, spatial_kernel=7, scan_kernel=31, expand_ratio=2.5, mlp_ratio=4.0),\n}\nprint(\"Training utilities ready!\")\n"
99
  ]
100
  },
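One property of `FlowMatchingScheduler` worth making explicit: along the linear path `x_t = (1-t)*x0 + t*noise`, the target velocity `noise - x0` is constant in `t`, so Euler integration from pure noise with the *true* velocity recovers `x0` exactly, regardless of step count. A small standalone check (no model involved; shapes are illustrative):

```python
import torch

x0 = torch.randn(2, 16, 32, 32)       # stand-in for clean latents
noise = torch.randn_like(x0)
v = noise - x0                         # the training target; constant in t

x, steps = noise.clone(), 50           # start at t=1 (pure noise)
dt = 1.0 / steps
for _ in range(steps):
    x = x - dt * v                     # same Euler update as sample()
print(torch.allclose(x, x0, atol=1e-5))  # True: the -dt*v steps sum to -(noise - x0)
```

The trained model only approximates this velocity field, which is why more sampling steps still matter in practice.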
101
  {
102
  "cell_type": "markdown",
103
  "metadata": {},
104
  "source": [
105
+ "## \ud83d\udcca Step 5: Load Dataset & VAE"
106
  ]
107
  },
108
  {
 
111
  "metadata": {},
112
  "outputs": [],
113
  "source": [
114
+ "from diffusers import AutoencoderKL\n\nif USE_STREAMING:\n print(f\"Loading {DATASET_NAME} in STREAMING mode (no full download)...\")\n train_ds = StreamingImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE, buffer=1000)\n train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, num_workers=0, pin_memory=True)\n print(\" Streaming ready! Images load on-the-fly.\")\nelse:\n train_ds = MapImageDataset(DATASET_NAME, IMAGE_COLUMN, LABEL_COLUMN, IMAGE_SIZE)\n train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True, drop_last=True)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\nprint(f\"Loading VAE to {device}...\")\nvae = AutoencoderKL.from_pretrained(VAE_ID, subfolder=VAE_SUBFOLDER, torch_dtype=torch.float16).to(device).eval()\nfor p in vae.parameters(): p.requires_grad_(False)\nprint(f\" VAE: {sum(p.numel() for p in vae.parameters())/1e6:.1f}M params (frozen)\")\nprint(\"Ready!\")\n"
115
  ]
116
  },
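The `SCALE`/`SHIFT` constants from the configuration cell normalize the Flux VAE latents before training; sampling must invert them before decoding, as the training and generation cells do. A sketch of the round trip, just to pin down the direction of each mapping:

```python
import torch

SCALE, SHIFT = 0.3611, 0.1159
lat = torch.randn(1, 16, 32, 32)           # stand-in for a raw VAE latent

normed = (lat - SHIFT) * SCALE             # training side: (lats - SHIFT) * SCALE
restored = normed / SCALE + SHIFT          # sampling side: samp / SCALE + SHIFT
print(torch.allclose(restored, lat, atol=1e-6))  # True
```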
117
  {
118
  "cell_type": "markdown",
119
  "metadata": {},
120
  "source": [
121
+ "## \ud83c\udfcb\ufe0f Step 6: Create Model"
122
  ]
123
  },
124
  {
 
127
  "metadata": {},
128
  "outputs": [],
129
  "source": [
130
+ "cfg = MODEL_CONFIGS[MODEL_SIZE].copy()\ncfg[\"num_classes\"] = NUM_CLASSES; cfg[\"class_drop_prob\"] = 0.1; cfg[\"use_zigzag\"] = True\nmodel = LiquidGen(**cfg).to(device)\nprint(f\"LiquidGen-{MODEL_SIZE}: {model.count_params()/1e6:.1f}M params\")\n\noptimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)\nscheduler = cosine_sched(optimizer, WARMUP_STEPS, MAX_STEPS)\nema = EMAModel(model, EMA_DECAY)\nscaler = GradScaler(\"cuda\")\nfm = FlowMatchingScheduler()\nos.makedirs(f\"{OUTPUT_DIR}/samples\", exist_ok=True)\nos.makedirs(f\"{OUTPUT_DIR}/checkpoints\", exist_ok=True)\nprint(f\"Training: {MAX_STEPS} steps, effective batch {BATCH_SIZE*GRADIENT_ACCUMULATION}\")\n"
131
+ ]
132
+ },
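`cosine_sched` ramps the learning-rate multiplier linearly to 1.0 over `WARMUP_STEPS`, then decays it along a cosine to 0 at `MAX_STEPS`. A quick numeric check, reimplementing the same `lr(s)` closure standalone with the notebook's defaults:

```python
import math

def lr_mult(s, warmup=500, total=20000):
    if s < warmup:
        return s / max(1, warmup)       # linear warmup
    return max(0, 0.5 * (1 + math.cos(math.pi * (s - warmup) / max(1, total - warmup))))

for s in (0, 250, 500, 10250, 20000):
    print(s, round(lr_mult(s), 3))      # 0.0, 0.5, 1.0, 0.5, 0.0
```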
133
+ {
134
+ "cell_type": "markdown",
135
+ "metadata": {},
136
+ "source": [
137
+ "## \ud83d\ude80 Step 7: Train!"
138
  ]
139
  },
140
  {
 
143
  "metadata": {},
144
  "outputs": [],
145
  "source": [
146
+ "global_step = 0; loss_accum = 0.0; log_losses = []; accum_count = 0\nprint(\"Training started!\\n\")\nt0 = time.time(); model.train()\n\nwhile global_step < MAX_STEPS:\n for imgs, lbls in train_loader:\n if global_step >= MAX_STEPS: break\n imgs = imgs.to(device)\n lbls = lbls.to(device) if NUM_CLASSES > 0 else None\n\n with torch.no_grad():\n lats = vae.encode(imgs.half()*2-1).latent_dist.sample()\n lats = ((lats - SHIFT) * SCALE).float()\n\n t = fm.sample_t(lats.shape[0], device)\n noise = torch.randn_like(lats)\n xt = fm.add_noise(lats, noise, t)\n vtgt = fm.velocity(lats, noise)\n\n with autocast(\"cuda\"):\n loss = F.mse_loss(model(xt, t, lbls), vtgt) / GRADIENT_ACCUMULATION\n scaler.scale(loss).backward()\n loss_accum += loss.item()\n accum_count += 1\n\n if accum_count % GRADIENT_ACCUMULATION == 0:\n scaler.unscale_(optimizer)\n gn = torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)\n scaler.step(optimizer); scaler.update(); optimizer.zero_grad()\n scheduler.step(); ema.update(model); global_step += 1\n\n if global_step % LOG_EVERY == 0:\n al = loss_accum / LOG_EVERY; lr = optimizer.param_groups[0][\"lr\"]\n vram = torch.cuda.memory_allocated()/1024**3 if torch.cuda.is_available() else 0\n print(f\"step={global_step:>6d} | loss={al:.4f} | gn={gn:.2f} | lr={lr:.2e} | vram={vram:.1f}G | {time.time()-t0:.0f}s\")\n log_losses.append(al); loss_accum = 0\n if math.isnan(al) or al > 50: print(\"Diverged!\"); break\n\n if global_step % SAMPLE_EVERY == 0:\n ema.apply(model); model.eval()\n ls = IMAGE_SIZE // 8\n sl = torch.randint(0, max(1,NUM_CLASSES), (4,), device=device) if NUM_CLASSES > 0 else None\n samp = fm.sample(model, (4,16,ls,ls), device, NUM_SAMPLE_STEPS, sl, CFG_SCALE)\n with torch.no_grad():\n si = ((vae.decode(samp.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n save_image(si, f\"{OUTPUT_DIR}/samples/step_{global_step:07d}.png\", nrow=2)\n print(f\" Saved samples\")\n ema.restore(model); model.train()\n\n if global_step % SAVE_EVERY == 0:\n torch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"step\":global_step,\"cfg\":cfg},\n f\"{OUTPUT_DIR}/checkpoints/step_{global_step:07d}.pt\")\n print(f\" Checkpoint saved\")\n\ntorch.save({\"model\":model.state_dict(),\"ema\":ema.shadow,\"cfg\":cfg,\"step\":global_step},\n f\"{OUTPUT_DIR}/checkpoints/final.pt\")\nprint(f\"\\nDone! {global_step} steps in {(time.time()-t0)/60:.1f} min\")\n"
147
  ]
148
  },
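The loop divides each micro-batch loss by `GRADIENT_ACCUMULATION` and steps the optimizer every fourth backward pass, which makes the accumulated gradient equal to that of one large batch of `BATCH_SIZE * GRADIENT_ACCUMULATION`. A toy demonstration of that equivalence (no AMP; hypothetical linear model):

```python
import torch

torch.manual_seed(0)
net = torch.nn.Linear(8, 1)
x, y = torch.randn(8, 8), torch.randn(8, 1)

net.zero_grad()                                    # one big batch of 8
torch.nn.functional.mse_loss(net(x), y).backward()
big = net.weight.grad.clone()

net.zero_grad()                                    # two micro-batches of 4, each loss / 2
for xb, yb in ((x[:4], y[:4]), (x[4:], y[4:])):
    (torch.nn.functional.mse_loss(net(xb), yb) / 2).backward()
print(torch.allclose(big, net.weight.grad, atol=1e-6))  # True
```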
149
  {
150
  "cell_type": "markdown",
151
  "metadata": {},
152
  "source": [
153
+ "## \ud83d\udcc8 Step 8: Loss Curve"
154
  ]
155
  },
156
  {
 
159
  "metadata": {},
160
  "outputs": [],
161
  "source": [
162
+ "import matplotlib.pyplot as plt\nif log_losses:\n plt.figure(figsize=(10,4)); plt.plot(log_losses)\n plt.xlabel(f\"Steps (x{LOG_EVERY})\"); plt.ylabel(\"Loss\")\n plt.title(\"Training Loss\"); plt.grid(True, alpha=0.3)\n plt.savefig(f\"{OUTPUT_DIR}/loss.png\", dpi=150); plt.show()\n print(f\"Min loss: {min(log_losses):.4f}\")\n"
163
  ]
164
  },
165
  {
166
  "cell_type": "markdown",
167
  "metadata": {},
168
  "source": [
169
+ "## \ud83c\udfa8 Step 9: Generate"
170
  ]
171
  },
172
  {
 
175
  "metadata": {},
176
  "outputs": [],
177
  "source": [
178
+ "ema.apply(model); model.eval()\nN, STEPS, G = 8, 50, 2.5; ls = IMAGE_SIZE // 8\nif NUM_CLASSES > 0:\n for ci in range(min(NUM_CLASSES, 6)):\n l = torch.full((N,), ci, device=device, dtype=torch.long)\n s = fm.sample(model, (N,16,ls,ls), device, STEPS, l, G)\n with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n save_image(i, f\"{OUTPUT_DIR}/gen_class{ci}.png\", nrow=4)\n print(f\"Class {ci}\")\nelse:\n s = fm.sample(model, (N,16,ls,ls), device, STEPS)\n with torch.no_grad(): i = ((vae.decode(s.half()/SCALE+SHIFT).sample+1)/2).clamp(0,1).float()\n save_image(i, f\"{OUTPUT_DIR}/gen_uncond.png\", nrow=4)\nema.restore(model)\nprint(f\"Saved to {OUTPUT_DIR}/\")\n"
179
  ]
180
  },
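Generation relies on the classifier-free guidance branch of `FlowMatchingScheduler.sample`: the model is queried with the class labels and with `torch.zeros_like(labels)` as the notebook's unconditional stand-in, and the two velocities are extrapolated. The guidance arithmetic in isolation (tensors are stand-ins for the two model calls):

```python
import torch

cfg = 2.5                                 # the G used in this cell
v_cond = torch.randn(4, 16, 32, 32)       # stand-in for model(x, t, labels)
v_uncond = torch.randn_like(v_cond)       # stand-in for model(x, t, zeros)

v = v_uncond + cfg * (v_cond - v_uncond)  # push past the unconditional velocity
print(v.shape)                            # torch.Size([4, 16, 32, 32])
```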
181
  {
182
  "cell_type": "markdown",
183
  "metadata": {},
184
  "source": [
185
+ "## \ud83d\udce4 Step 10: Display"
186
  ]
187
  },
188
  {
 
191
  "metadata": {},
192
  "outputs": [],
193
  "source": [
194
+ "from IPython.display import display\nimport glob\nfor f in sorted(glob.glob(f\"{OUTPUT_DIR}/samples/*.png\"))[-3:]:\n print(os.path.basename(f)); display(Image.open(f))\nfor f in sorted(glob.glob(f\"{OUTPUT_DIR}/gen_*.png\")):\n print(os.path.basename(f)); display(Image.open(f))\n"
195
  ]
196
  }
197
  ]