asdf98 committed
Commit 2aee9f4 · verified · Parent: a8d6acc

v2: Use pre-trained SD-VAE, fix all bugs, pre-cache everything, massive speedup

Files changed (1):
  1. IRIS_Training_Notebook.ipynb +242 -706
IRIS_Training_Notebook.ipynb CHANGED
@@ -22,52 +22,46 @@
22
  "cell_type": "markdown",
23
  "metadata": {},
24
  "source": [
25
- "# \ud83d\udd2e IRIS: Iterative Recurrent Image Synthesis \u2014 Training Notebook",
26
  "",
27
- "**Train a novel mobile-first image generation model from scratch on free Colab/Kaggle GPUs.**",
28
  "",
29
- "This notebook runs the complete 2-stage training pipeline:",
30
- "1. **Stage 1 \u2014 Wavelet VAE Training**: Learn to encode/decode images via wavelet-frequency latent space",
31
- "2. **Stage 2 \u2014 Generator Training**: Train the recurrent-depth denoiser with rectified flow on captioned images",
32
  "",
33
- "### Hardware Requirements",
34
- "| Platform | GPU | VRAM | Estimated Time |",
35
- "|----------|-----|------|----------------|",
36
- "| **Colab Free** | T4 | 16GB | ~2-3 hours total |",
37
- "| **Colab Pro** | A100 | 40GB | ~45 min total |",
38
- "| **Kaggle** | P100/T4\u00d72 | 16GB | ~2-3 hours total |",
39
  "",
40
- "### What You Get",
41
- "- A trained Wavelet VAE that compresses 256\u00d7256 images to 16\u00d716 latent (48\u00d7 compression)",
42
- "- A trained IRIS generator that can denoise latents conditioned on text (CLIP embeddings)",
43
- "- Visualization of reconstructions, generation samples, and loss curves",
44
- "- Saved checkpoints you can continue training from"
 
45
  ]
46
  },
47
  {
48
  "cell_type": "markdown",
49
  "metadata": {},
50
  "source": [
51
- "## 1. Setup & Installation"
52
  ]
53
  },
54
  {
55
  "cell_type": "code",
56
  "metadata": {},
57
  "source": [
58
- "# Install dependencies\n",
59
- "!pip install -q torch torchvision datasets transformers accelerate matplotlib Pillow tqdm huggingface_hub\n",
60
  "\n",
61
- "# Check GPU\n",
62
  "import torch\n",
63
- "print(f\"PyTorch: {torch.__version__}\")\n",
64
- "print(f\"CUDA available: {torch.cuda.is_available()}\")\n",
65
  "if torch.cuda.is_available():\n",
66
- " print(f\"GPU: {torch.cuda.get_device_name(0)}\")\n",
67
- " print(f\"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1024**3:.1f} GB\")\n",
68
  " device = torch.device('cuda')\n",
69
  "else:\n",
70
- " print(\"\u26a0\ufe0f No GPU detected! Training will be very slow on CPU.\")\n",
71
  " device = torch.device('cpu')"
72
  ],
73
  "outputs": [],
@@ -77,31 +71,21 @@
77
  "cell_type": "markdown",
78
  "metadata": {},
79
  "source": [
80
- "## 2. Download IRIS Architecture from Hugging Face"
81
  ]
82
  },
83
  {
84
  "cell_type": "code",
85
  "metadata": {},
86
  "source": [
87
- "# Download the IRIS architecture code from HF Hub\n",
88
  "from huggingface_hub import hf_hub_download\n",
89
- "import shutil, os\n",
90
- "\n",
91
- "repo_id = \"asdf98/IRIS-architecture\"\n",
92
- "for fname in [\"iris_model.py\", \"train_iris.py\", \"test_iris.py\"]:\n",
93
- " path = hf_hub_download(repo_id=repo_id, filename=fname)\n",
94
- " shutil.copy(path, f\"./{fname}\")\n",
95
- " print(f\"\u2705 Downloaded {fname}\")\n",
96
- "\n",
97
- "# Import IRIS\n",
98
- "from iris_model import (\n",
99
- " IRIS, IRISConfig, WaveletVAE, IRISGenerator,\n",
100
- " HaarDWT2D, HaarIDWT2D,\n",
101
- " create_iris_small, create_iris_tiny, create_iris_base,\n",
102
- " count_parameters, estimate_memory_mb,\n",
103
- ")\n",
104
- "print(\"\\n\u2705 IRIS architecture imported successfully!\")"
105
  ],
106
  "outputs": [],
107
  "execution_count": null
@@ -110,37 +94,43 @@
110
  "cell_type": "markdown",
111
  "metadata": {},
112
  "source": [
113
- "## 3. Model Architecture Overview",
114
  "",
115
- "Let's inspect the three model variants and their parameter counts."
 
 
 
116
  ]
117
  },
118
  {
119
  "cell_type": "code",
120
  "metadata": {},
121
  "source": [
122
- "# Show model variants\n",
123
- "for name, fn in [(\"IRIS-Tiny (ultra-mobile)\", create_iris_tiny),\n",
124
- " (\"IRIS-Small (mobile)\", create_iris_small),\n",
125
- " (\"IRIS-Base (desktop)\", create_iris_base)]:\n",
126
- " model = fn()\n",
127
- " counts = count_parameters(model)\n",
128
- " mem16 = estimate_memory_mb(model, torch.float16)\n",
129
- "\n",
130
- " core_params = sum(p.numel() for p in model.generator.core.parameters())\n",
131
- " print(f\"\\n{'='*55}\")\n",
132
- " print(f\" {name}\")\n",
133
- " print(f\"{'='*55}\")\n",
134
- " print(f\" Total params: {counts['total']:>12,}\")\n",
135
- " print(f\" Generator params: {counts['total'] - sum(p.numel() for p in model.vae.parameters()):>12,}\")\n",
136
- " print(f\" Core (shared): {core_params:>12,}\")\n",
137
- " print(f\" Model memory fp16: {mem16:>10.1f} MB\")\n",
138
- " print(f\" + CLIP-L/14 text: 156.0 MB\")\n",
139
- " print(f\" + Overhead: 350.0 MB\")\n",
140
- " print(f\" \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\")\n",
141
- " print(f\" Total inference: {mem16+156+350:>10.1f} MB {'\u2705 <3GB' if mem16+506 < 3000 else ''}\")\n",
142
- "\n",
143
- "del model # Free memory"
 
 
 
144
  ],
145
  "outputs": [],
146
  "execution_count": null
@@ -149,15 +139,7 @@
149
  "cell_type": "markdown",
150
  "metadata": {},
151
  "source": [
152
- "## 4. Load Dataset \u2014 Pok\u00e9mon BLIP Captions",
153
- "",
154
- "We use `reach-vb/pokemon-blip-captions` \u2014 a small, high-quality dataset with ~833 image-caption pairs. ",
155
- "Perfect for free-tier training to validate the architecture works end-to-end.",
156
- "",
157
- "**For serious training later**, swap in larger datasets:",
158
- "- `ILSVRC/imagenet-1k` (Stage 2 class-conditional)",
159
- "- `laion/laion-art` (Text-image alignment)",
160
- "- `caidas/JourneyDB` (Aesthetic fine-tuning)"
161
  ]
162
  },
163
  {
@@ -167,81 +149,49 @@
167
  "from datasets import load_dataset\n",
168
  "from torchvision import transforms\n",
169
  "from torch.utils.data import Dataset, DataLoader\n",
170
- "from PIL import Image\n",
171
- "import numpy as np\n",
 
172
  "\n",
173
- "# Load Pok\u00e9mon dataset\n",
174
- "print(\"Loading dataset...\")\n",
175
- "raw_dataset = load_dataset(\"reach-vb/pokemon-blip-captions\", split=\"train\")\n",
176
- "print(f\"\u2705 Loaded {len(raw_dataset)} image-caption pairs\")\n",
177
  "\n",
178
- "# Show a few examples\n",
179
- "import matplotlib.pyplot as plt\n",
180
- "fig, axes = plt.subplots(1, 5, figsize=(20, 4))\n",
181
- "for i, ax in enumerate(axes):\n",
182
- " item = raw_dataset[i]\n",
183
- " ax.imshow(item[\"image\"])\n",
184
- " ax.set_title(item[\"text\"][:40] + \"...\", fontsize=9)\n",
185
- " ax.axis(\"off\")\n",
186
- "plt.suptitle(\"Sample Training Images\", fontsize=14)\n",
187
- "plt.tight_layout()\n",
188
- "plt.show()"
189
- ],
190
- "outputs": [],
191
- "execution_count": null
192
- },
193
- {
194
- "cell_type": "markdown",
195
- "metadata": {},
196
- "source": [
197
- "### 4.1 Create PyTorch Dataset with Transforms"
198
- ]
199
- },
200
- {
201
- "cell_type": "code",
202
- "metadata": {},
203
- "source": [
204
- "# \u2500\u2500\u2500 Training configuration \u2500\u2500\u2500\n",
205
- "IMAGE_SIZE = 256 # Input image resolution\n",
206
- "BATCH_SIZE = 4 # Fits on T4 16GB; increase on A100\n",
207
- "NUM_WORKERS = 2 # Dataloader workers\n",
208
  "\n",
209
- "# \u2500\u2500\u2500 Image transforms \u2500\u2500\u2500\n",
210
  "train_transform = transforms.Compose([\n",
211
  " transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
212
  " transforms.CenterCrop(IMAGE_SIZE),\n",
213
  " transforms.RandomHorizontalFlip(),\n",
214
- " transforms.ToTensor(), # [0, 1]\n",
215
- " transforms.Normalize([0.5]*3, [0.5]*3), # [-1, 1]\n",
216
  "])\n",
217
  "\n",
218
  "class ImageCaptionDataset(Dataset):\n",
219
- " \"\"\"Wraps a HF dataset with transforms. Returns (image_tensor, caption_string).\"\"\"\n",
220
- " def __init__(self, hf_dataset, transform):\n",
221
- " self.dataset = hf_dataset\n",
222
  " self.transform = transform\n",
 
 
 
 
223
  "\n",
224
- " def __len__(self):\n",
225
- " return len(self.dataset)\n",
226
  "\n",
227
- " def __getitem__(self, idx):\n",
228
- " item = self.dataset[idx]\n",
229
- " image = item[\"image\"].convert(\"RGB\")\n",
230
- " image = self.transform(image)\n",
231
- " caption = item[\"text\"]\n",
232
- " return image, caption\n",
233
  "\n",
234
- "train_dataset = ImageCaptionDataset(raw_dataset, train_transform)\n",
235
- "train_loader = DataLoader(\n",
236
- " train_dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
237
- " num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,\n",
238
- ")\n",
239
- "print(f\"\u2705 DataLoader ready: {len(train_loader)} batches of {BATCH_SIZE}\")\n",
240
  "\n",
241
- "# Quick sanity check\n",
242
- "imgs, caps = next(iter(train_loader))\n",
243
- "print(f\" Image batch: {imgs.shape}, range [{imgs.min():.2f}, {imgs.max():.2f}]\")\n",
244
- "print(f\" Caption[0]: {caps[0]}\")"
245
  ],
246
  "outputs": [],
247
  "execution_count": null
@@ -250,45 +200,60 @@
250
  "cell_type": "markdown",
251
  "metadata": {},
252
  "source": [
253
- "## 5. Load CLIP Text Encoder (Frozen)",
254
  "",
255
- "We use CLIP-L/14 (~150MB) as the text encoder. It's frozen during training \u2014 ",
256
- "only the IRIS generator learns. This is the same encoder used in aMUSEd, Meissonic, and SnapGen."
257
  ]
258
  },
259
  {
260
  "cell_type": "code",
261
  "metadata": {},
262
  "source": [
263
- "from transformers import CLIPTextModel, CLIPTokenizer\n",
 
 
264
  "\n",
265
- "print(\"Loading CLIP-L/14 text encoder...\")\n",
266
- "clip_model_name = \"openai/clip-vit-large-patch14\"\n",
267
- "tokenizer = CLIPTokenizer.from_pretrained(clip_model_name)\n",
268
- "text_encoder = CLIPTextModel.from_pretrained(clip_model_name).to(device).eval()\n",
269
  "\n",
270
- "# Freeze text encoder\n",
271
- "for p in text_encoder.parameters():\n",
272
- " p.requires_grad = False\n",
 
 
 
 
273
  "\n",
274
- "clip_params = sum(p.numel() for p in text_encoder.parameters())\n",
275
- "print(f\"\u2705 CLIP-L/14 loaded: {clip_params/1e6:.1f}M params (frozen)\")\n",
276
- "print(f\" Text embedding dim: {text_encoder.config.hidden_size}\")\n",
277
- "print(f\" Max tokens: {tokenizer.model_max_length}\")\n",
278
  "\n",
279
- "@torch.no_grad()\n",
280
- "def encode_text(captions, max_length=77):\n",
281
- " \"\"\"Encode a list of captions to CLIP text embeddings.\"\"\"\n",
282
- " tokens = tokenizer(\n",
283
- " captions, padding=\"max_length\", truncation=True,\n",
284
- " max_length=max_length, return_tensors=\"pt\"\n",
285
- " ).to(device)\n",
286
- " outputs = text_encoder(**tokens)\n",
287
- " return outputs.last_hidden_state # [B, 77, 768]\n",
288
- "\n",
289
- "# Test encoding\n",
290
- "test_emb = encode_text([\"a cute dragon breathing fire\"])\n",
291
- "print(f\" Test encoding shape: {test_emb.shape}\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  ],
293
  "outputs": [],
294
  "execution_count": null
@@ -297,34 +262,23 @@
297
  "cell_type": "markdown",
298
  "metadata": {},
299
  "source": [
300
- "## 6. Stage 1 \u2014 Wavelet VAE Training",
301
- "",
302
- "Train the lightweight Wavelet VAE to reconstruct images through the wavelet-frequency latent space.",
303
  "",
304
- "**Architecture**: `Image \u2192 HaarDWT \u2192 Encoder \u2192 Latent(16ch, 16\u00d716) \u2192 Decoder \u2192 HaarIDWT \u2192 Image`",
305
- "",
306
- "**Losses**:",
307
- "- MSE reconstruction loss",
308
- "- KL divergence (variational regularization)",
309
- "- Wavelet frequency loss (preserves high-frequency details)",
310
- "- Perceptual loss via LPIPS-like gradient matching"
311
  ]
312
  },
313
  {
314
  "cell_type": "code",
315
  "metadata": {},
316
  "source": [
317
- "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
318
- "# STAGE 1: WAVELET VAE TRAINING\n",
319
- "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
320
- "\n",
321
- "# Build IRIS-Tiny config for free-tier training\n",
322
- "# DWT(2\u00d7) + 3 down-blocks(8\u00d7) = 16\u00d7 total compression\n",
323
- "# 256px input \u2192 128px after DWT \u2192 64\u219232\u219216 after encoder = 16\u00d716 latent\n",
324
  "config = IRISConfig(\n",
325
- " latent_channels=8, # Smaller for memory efficiency\n",
326
- " latent_spatial=16, # 16\u00d716 spatial latent\n",
327
- " hidden_dim=384, # IRIS-Tiny\n",
328
  " num_heads=6,\n",
329
  " head_dim=64,\n",
330
  " ffn_ratio=2.667,\n",
@@ -337,152 +291,19 @@
337
  " sparsity_threshold=0.01,\n",
338
  " recurrence_dim=192,\n",
339
  " manhattan_window=12,\n",
340
- " text_dim=768, # CLIP-L/14\n",
341
  " max_text_tokens=77,\n",
342
  " patch_size=2,\n",
343
- " vae_channels=[32, 64, 128, 256],\n",
344
  ")\n",
345
  "\n",
346
- "# Create VAE\n",
347
- "vae = WaveletVAE(config).to(device)\n",
348
- "vae_params = sum(p.numel() for p in vae.parameters())\n",
349
- "print(f\"Wavelet VAE: {vae_params:,} params ({vae_params*4/1024/1024:.1f} MB fp32)\")\n",
350
- "print(f\"Encoder: {sum(p.numel() for p in vae.encoder.parameters()):,}\")\n",
351
- "print(f\"Decoder: {sum(p.numel() for p in vae.decoder.parameters()):,}\")"
352
- ],
353
- "outputs": [],
354
- "execution_count": null
355
- },
356
- {
357
- "cell_type": "code",
358
- "metadata": {},
359
- "source": [
360
- "# \u2500\u2500\u2500 VAE Training Loop \u2500\u2500\u2500\n",
361
- "import time\n",
362
- "from tqdm.auto import tqdm\n",
363
- "\n",
364
- "VAE_EPOCHS = 80 # Enough to get good reconstructions\n",
365
- "VAE_LR = 1e-4\n",
366
- "KL_WEIGHT = 1e-4 # Light KL to avoid posterior collapse\n",
367
- "FREQ_WEIGHT = 0.1 # Wavelet frequency preservation\n",
368
- "\n",
369
- "optimizer_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR, weight_decay=0.01)\n",
370
- "scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, T_max=VAE_EPOCHS)\n",
371
- "scaler = torch.amp.GradScaler('cuda')\n",
372
- "dwt = HaarDWT2D()\n",
373
- "\n",
374
- "# Logging\n",
375
- "vae_losses = {\"total\": [], \"recon\": [], \"kl\": [], \"freq\": []}\n",
376
- "\n",
377
- "print(f\"Training VAE for {VAE_EPOCHS} epochs on {len(train_loader)} batches...\")\n",
378
- "\n",
379
- "vae.train()\n",
380
- "pbar = tqdm(range(VAE_EPOCHS), desc=\"VAE Training\")\n",
381
- "for epoch in pbar:\n",
382
- " epoch_losses = {\"total\": 0, \"recon\": 0, \"kl\": 0, \"freq\": 0}\n",
383
- "\n",
384
- " for images, _ in train_loader:\n",
385
- " images = images.to(device, non_blocking=True)\n",
386
- "\n",
387
- " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
388
- " x_recon, mean, logvar = vae(images)\n",
389
- "\n",
390
- " # Reconstruction loss\n",
391
- " recon_loss = torch.nn.functional.mse_loss(x_recon, images)\n",
392
- "\n",
393
- " # KL divergence\n",
394
- " kl_loss = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp()).mean()\n",
395
- "\n",
396
- " # Wavelet frequency loss \u2014 preserve high-freq details\n",
397
- " with torch.no_grad():\n",
398
- " target_wv = dwt(images)\n",
399
- " recon_wv = dwt(x_recon)\n",
400
- " freq_loss = torch.nn.functional.l1_loss(recon_wv, target_wv)\n",
401
- "\n",
402
- " loss = recon_loss + KL_WEIGHT * kl_loss + FREQ_WEIGHT * freq_loss\n",
403
- "\n",
404
- " optimizer_vae.zero_grad(set_to_none=True)\n",
405
- " scaler.scale(loss).backward()\n",
406
- " scaler.unscale_(optimizer_vae)\n",
407
- " torch.nn.utils.clip_grad_norm_(vae.parameters(), 1.0)\n",
408
- " scaler.step(optimizer_vae)\n",
409
- " scaler.update()\n",
410
- "\n",
411
- " epoch_losses[\"total\"] += loss.item()\n",
412
- " epoch_losses[\"recon\"] += recon_loss.item()\n",
413
- " epoch_losses[\"kl\"] += kl_loss.item()\n",
414
- " epoch_losses[\"freq\"] += freq_loss.item()\n",
415
- "\n",
416
- " # Average losses\n",
417
- " n = len(train_loader)\n",
418
- " for k in epoch_losses:\n",
419
- " epoch_losses[k] /= n\n",
420
- " vae_losses[k].append(epoch_losses[k])\n",
421
- "\n",
422
- " scheduler_vae.step()\n",
423
- " pbar.set_postfix(loss=f\"{epoch_losses['total']:.4f}\", recon=f\"{epoch_losses['recon']:.4f}\")\n",
424
- "\n",
425
- "print(\"\\n\u2705 VAE training complete!\")"
426
- ],
427
- "outputs": [],
428
- "execution_count": null
429
- },
430
- {
431
- "cell_type": "markdown",
432
- "metadata": {},
433
- "source": [
434
- "### 6.1 Visualize VAE Reconstructions"
435
- ]
436
- },
437
- {
438
- "cell_type": "code",
439
- "metadata": {},
440
- "source": [
441
- "# Visualize reconstructions\n",
442
- "vae.eval()\n",
443
- "fig, axes = plt.subplots(3, 8, figsize=(20, 8))\n",
444
- "\n",
445
- "with torch.no_grad():\n",
446
- " imgs_sample, _ = next(iter(train_loader))\n",
447
- " imgs_sample = imgs_sample[:8].to(device)\n",
448
- " recon, _, _ = vae(imgs_sample)\n",
449
- "\n",
450
- " # Also show latent statistics\n",
451
- " z, mean, logvar = vae.encode(imgs_sample)\n",
452
- " print(f\"Latent shape: {z.shape}\")\n",
453
- " print(f\"Latent mean: {z.mean():.3f}, std: {z.std():.3f}\")\n",
454
- " print(f\"Latent range: [{z.min():.3f}, {z.max():.3f}]\")\n",
455
- "\n",
456
- "def show_img(ax, tensor, title=\"\"):\n",
457
- " img = tensor.cpu().clamp(-1, 1) * 0.5 + 0.5 # [-1,1] \u2192 [0,1]\n",
458
- " ax.imshow(img.permute(1, 2, 0).numpy())\n",
459
- " ax.set_title(title, fontsize=8)\n",
460
- " ax.axis(\"off\")\n",
461
- "\n",
462
- "for i in range(8):\n",
463
- " show_img(axes[0, i], imgs_sample[i], f\"Original {i}\")\n",
464
- " show_img(axes[1, i], recon[i], f\"Recon {i}\")\n",
465
- " axes[2, i].imshow(z[i, :3].cpu().permute(1, 2, 0).numpy() * 0.3 + 0.5)\n",
466
- " axes[2, i].set_title(f\"Latent ch0-2\", fontsize=8)\n",
467
- " axes[2, i].axis(\"off\")\n",
468
- "\n",
469
- "axes[0, 0].set_ylabel(\"Original\", fontsize=12)\n",
470
- "axes[1, 0].set_ylabel(\"Reconstructed\", fontsize=12)\n",
471
- "axes[2, 0].set_ylabel(\"Latent\", fontsize=12)\n",
472
- "plt.suptitle(\"Wavelet VAE Reconstructions\", fontsize=14)\n",
473
- "plt.tight_layout()\n",
474
- "plt.show()\n",
475
  "\n",
476
- "# Plot loss curves\n",
477
- "fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
478
- "for ax, key, color in zip(axes, [\"total\", \"recon\", \"freq\"], [\"blue\", \"green\", \"red\"]):\n",
479
- " ax.plot(vae_losses[key], color=color)\n",
480
- " ax.set_title(f\"VAE {key.title()} Loss\")\n",
481
- " ax.set_xlabel(\"Epoch\")\n",
482
- " ax.set_ylabel(\"Loss\")\n",
483
- " ax.grid(True, alpha=0.3)\n",
484
- "plt.tight_layout()\n",
485
- "plt.show()"
486
  ],
487
  "outputs": [],
488
  "execution_count": null
@@ -491,178 +312,91 @@
491
  "cell_type": "markdown",
492
  "metadata": {},
493
  "source": [
494
- "## 7. Stage 2 \u2014 IRIS Generator Training (Rectified Flow)",
495
- "",
496
- "Now we train the **recurrent-depth generator** to denoise latent representations conditioned on CLIP text embeddings.",
497
  "",
498
- "**Key features of this training**:",
499
- "- **Rectified Flow**: Linear noise schedule, velocity prediction, logit-normal timestep sampling",
500
- "- **Recurrent Depth**: Core block is iterated randomly 4-8\u00d7 per step (training robustness)",
501
- "- **adaLN-Zero**: Stable training start via zero-initialized gating",
502
- "- **Mixed precision (fp16)**: Fits on 16GB VRAM",
503
- "- **Gradient checkpointing**: Optional, for very tight memory",
504
- "",
505
- "**The magic**: Because the core block shares weights across iterations, ",
506
- "we get deep effective network capacity from tiny parameter count!"
507
  ]
508
  },
509
  {
510
  "cell_type": "code",
511
  "metadata": {},
512
  "source": [
513
- "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
514
- "# STAGE 2: IRIS GENERATOR TRAINING (RECTIFIED FLOW)\n",
515
- "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
516
- "\n",
517
- "# Build full IRIS model (reusing config from VAE stage)\n",
518
- "iris = IRIS(config).to(device)\n",
519
  "\n",
520
- "# Load trained VAE weights\n",
521
- "iris.vae.load_state_dict(vae.state_dict())\n",
522
- "\n",
523
- "# Freeze VAE\n",
524
- "for p in iris.vae.parameters():\n",
525
- " p.requires_grad = False\n",
526
- "iris.vae.eval()\n",
527
- "\n",
528
- "gen_params = sum(p.numel() for p in iris.generator.parameters())\n",
529
- "core_params = sum(p.numel() for p in iris.generator.core.parameters())\n",
530
- "print(f\"IRIS Generator: {gen_params:,} trainable params\")\n",
531
- "print(f\" Core block (shared): {core_params:,} ({core_params/gen_params*100:.1f}%)\")\n",
532
- "print(f\" Effective at r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
533
- "print(f\" Memory fp16: {gen_params*2/1024/1024:.1f} MB\")\n",
534
- "\n",
535
- "# \u2500\u2500\u2500 Pre-cache CLIP text embeddings (HUGE speedup) \u2500\u2500\u2500\n",
536
- "# Instead of encoding text every batch, cache all embeddings upfront\n",
537
- "print(\"\\nPre-caching CLIP text embeddings...\")\n",
538
- "all_text_embeddings = []\n",
539
- "cache_loader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
540
- "with torch.no_grad():\n",
541
- " for _, captions in tqdm(cache_loader, desc=\"Encoding text\"):\n",
542
- " emb = encode_text(list(captions))\n",
543
- " all_text_embeddings.append(emb.cpu())\n",
544
- "all_text_embeddings = torch.cat(all_text_embeddings, dim=0) # [N, 77, 768]\n",
545
- "print(f\"\u2705 Cached {all_text_embeddings.shape[0]} text embeddings: {all_text_embeddings.shape}\")\n",
546
- "\n",
547
- "# Free CLIP from GPU (we don't need it during training anymore!)\n",
548
- "text_encoder.cpu()\n",
549
- "torch.cuda.empty_cache()\n",
550
- "print(\"\u2705 CLIP moved to CPU to free ~600MB VRAM\")\n",
551
- "\n",
552
- "# Create a new dataset that uses cached embeddings\n",
553
  "class CachedDataset(Dataset):\n",
554
- " def __init__(self, image_dataset, cached_text_emb):\n",
555
- " self.image_dataset = image_dataset\n",
556
- " self.text_emb = cached_text_emb\n",
557
- " def __len__(self):\n",
558
- " return len(self.image_dataset)\n",
559
- " def __getitem__(self, idx):\n",
560
- " image, _ = self.image_dataset[idx]\n",
561
- " return image, self.text_emb[idx]\n",
562
- "\n",
563
- "cached_dataset = CachedDataset(train_dataset, all_text_embeddings)\n",
564
- "cached_loader = DataLoader(\n",
565
- " cached_dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
566
- " num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,\n",
567
- ")\n",
568
  "\n",
569
- "# Free standalone VAE to save memory\n",
570
- "del vae, optimizer_vae, scheduler_vae\n",
571
- "torch.cuda.empty_cache()"
572
- ],
573
- "outputs": [],
574
- "execution_count": null
575
- },
576
- {
577
- "cell_type": "code",
578
- "metadata": {},
579
- "source": [
580
- "# \u2500\u2500\u2500 Generator Training Loop \u2500\u2500\u2500\n",
581
- "import time\n",
582
- "\n",
583
- "GEN_EPOCHS = 150 # More epochs for small dataset\n",
584
- "GEN_LR = 2e-4 # Higher LR works well with AdamW + cosine\n",
585
- "GRAD_ACCUM = 2 # Effective batch = BATCH_SIZE \u00d7 GRAD_ACCUM = 8\n",
586
- "WARMUP_STEPS = 100\n",
587
- "\n",
588
- "optimizer_gen = torch.optim.AdamW(\n",
589
- " iris.generator.parameters(),\n",
590
- " lr=GEN_LR,\n",
591
- " weight_decay=0.03,\n",
592
- " betas=(0.9, 0.95),\n",
593
  ")\n",
594
  "\n",
595
- "total_steps = GEN_EPOCHS * len(cached_loader) // GRAD_ACCUM\n",
596
- "\n",
597
- "def lr_lambda(step):\n",
598
- " if step < WARMUP_STEPS:\n",
599
- " return step / max(1, WARMUP_STEPS)\n",
600
- " progress = (step - WARMUP_STEPS) / max(1, total_steps - WARMUP_STEPS)\n",
601
- " return 0.5 * (1 + __import__('math').cos(__import__('math').pi * progress))\n",
602
  "\n",
603
- "scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda)\n",
604
- "scaler_gen = torch.amp.GradScaler('cuda')\n",
 
605
  "\n",
606
- "# Logging\n",
607
- "gen_losses = {\"total\": [], \"velocity\": [], \"kl\": []}\n",
 
608
  "\n",
609
- "print(f\"Training generator for {GEN_EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
610
- "print(f\"Effective batch size: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}\")\n",
611
- "print(f\"Using cached CLIP embeddings (no per-batch encoding overhead)\")\n",
 
612
  "\n",
 
 
613
  "iris.generator.train()\n",
614
- "global_step = 0\n",
615
  "best_loss = float('inf')\n",
 
616
  "\n",
617
- "pbar = tqdm(range(GEN_EPOCHS), desc=\"Gen Training\")\n",
618
  "for epoch in pbar:\n",
619
- " epoch_vel = 0\n",
620
- " epoch_total = 0\n",
621
- " n_batches = 0\n",
622
  "\n",
623
- " optimizer_gen.zero_grad(set_to_none=True)\n",
624
- "\n",
625
- " for batch_idx, (images, text_emb) in enumerate(cached_loader):\n",
626
- " images = images.to(device, non_blocking=True)\n",
627
  " text_emb = text_emb.to(device, non_blocking=True)\n",
628
  "\n",
629
- " # Forward pass with mixed precision\n",
630
  " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
631
- " # Randomly sample iteration count for robustness (keep low for speed)\n",
632
  " r = [3, 4, 5][torch.randint(0, 3, (1,)).item()]\n",
633
- " result = iris.train_step(images, text_emb, num_iterations=r)\n",
634
  " loss = result[\"loss\"] / GRAD_ACCUM\n",
635
  "\n",
636
- " scaler_gen.scale(loss).backward()\n",
637
  "\n",
638
- " # Gradient accumulation\n",
639
  " if (batch_idx + 1) % GRAD_ACCUM == 0:\n",
640
- " scaler_gen.unscale_(optimizer_gen)\n",
641
  " torch.nn.utils.clip_grad_norm_(iris.generator.parameters(), 1.0)\n",
642
- " scaler_gen.step(optimizer_gen)\n",
643
- " scaler_gen.update()\n",
644
- " optimizer_gen.zero_grad(set_to_none=True)\n",
645
- " scheduler_gen.step()\n",
646
  " global_step += 1\n",
647
  "\n",
648
- " epoch_vel += result[\"velocity_loss\"]\n",
649
- " epoch_total += result[\"loss\"].item() if hasattr(result[\"loss\"], 'item') else result[\"velocity_loss\"]\n",
650
- " n_batches += 1\n",
651
- "\n",
652
- " avg_vel = epoch_vel / n_batches\n",
653
- " avg_total = epoch_total / n_batches\n",
654
- " gen_losses[\"velocity\"].append(avg_vel)\n",
655
- " gen_losses[\"total\"].append(avg_total)\n",
656
  "\n",
657
- " if avg_vel < best_loss:\n",
658
- " best_loss = avg_vel\n",
 
 
 
659
  "\n",
660
- " pbar.set_postfix(vel_loss=f\"{avg_vel:.4f}\", best=f\"{best_loss:.4f}\")\n",
661
- "\n",
662
- "print(f\"\\n\u2705 Generator training complete! Best velocity loss: {best_loss:.4f}\")\n",
663
- "\n",
664
- "# Reload CLIP for generation\n",
665
- "text_encoder.to(device)"
666
  ],
667
  "outputs": [],
668
  "execution_count": null
@@ -671,70 +405,45 @@
671
  "cell_type": "markdown",
672
  "metadata": {},
673
  "source": [
674
- "## 8. Generate Images!",
675
- "",
676
- "Now let's generate images using the trained IRIS model. We'll test different iteration budgets ",
677
- "to see the adaptive compute in action."
678
  ]
679
  },
680
  {
681
  "cell_type": "code",
682
  "metadata": {},
683
  "source": [
684
- "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
685
- "# GENERATION\n",
686
- "# \u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\u2550\n",
687
  "\n",
688
  "prompts = [\n",
689
  " \"a fire-breathing dragon pokemon\",\n",
690
- " \"a cute blue water pokemon\",\n",
691
  " \"a green grass-type pokemon with leaves\",\n",
692
- " \"a purple ghost pokemon floating\",\n",
693
- " \"a yellow electric pokemon with lightning\",\n",
694
- " \"a pink fairy pokemon with wings\",\n",
695
- " \"a red phoenix pokemon\",\n",
696
- " \"a small brown fox pokemon\",\n",
697
  "]\n",
698
  "\n",
699
  "iris.eval()\n",
700
- "\n",
701
- "# Generate with different iteration counts to show adaptive compute\n",
702
- "fig, axes = plt.subplots(len(prompts), 4, figsize=(16, len(prompts) * 4))\n",
703
- "iteration_counts = [2, 4, 6, 8]\n",
704
  "\n",
705
  "for row, prompt in enumerate(prompts):\n",
706
- " with torch.no_grad():\n",
707
- " text_emb = encode_text([prompt])\n",
708
- "\n",
709
- " for col, n_iter in enumerate(iteration_counts):\n",
710
- " with torch.no_grad():\n",
711
- " img = iris.generate(\n",
712
- " text_emb,\n",
713
- " num_steps=4,\n",
714
- " num_iterations=n_iter,\n",
715
- " cfg_scale=1.0, # No CFG on untrained model\n",
716
- " seed=42,\n",
717
- " )\n",
718
- " # Convert to displayable\n",
719
- " img_np = img[0].cpu().clamp(-1, 1) * 0.5 + 0.5\n",
720
- " img_np = img_np.permute(1, 2, 0).numpy()\n",
721
- "\n",
722
  " axes[row, col].imshow(img_np)\n",
723
  " axes[row, col].axis(\"off\")\n",
724
  " if row == 0:\n",
725
- " axes[row, col].set_title(f\"r={n_iter} iterations\", fontsize=11)\n",
726
- " axes[row, 0].set_ylabel(prompt[:25] + \"...\", fontsize=9, rotation=0, labelpad=120, va='center')\n",
727
  "\n",
728
- "plt.suptitle(\"IRIS Generated Images (Adaptive Compute Budget)\", fontsize=14, y=1.01)\n",
729
  "plt.tight_layout()\n",
730
  "plt.show()\n",
731
  "\n",
732
- "print(\"\\nNote: With only ~800 training images and short training, outputs are noisy.\")\n",
733
- "print(\"This demonstrates the architecture works. Quality improves dramatically with:\")\n",
734
- "print(\" \u2022 More training data (CC3M, LAION)\")\n",
735
- "print(\" \u2022 More epochs (1000+)\")\n",
736
- "print(\" \u2022 Larger model (IRIS-Small or IRIS-Base)\")\n",
737
- "print(\" \u2022 Stage 3-5 training (text alignment + aesthetics + distillation)\")"
738
  ],
739
  "outputs": [],
740
  "execution_count": null
@@ -743,158 +452,35 @@
743
  "cell_type": "markdown",
744
  "metadata": {},
745
  "source": [
746
- "### 8.1 Training Loss Curves"
747
  ]
748
  },
749
  {
750
  "cell_type": "code",
751
  "metadata": {},
752
  "source": [
753
- "fig, axes = plt.subplots(1, 2, figsize=(14, 5))\n",
754
- "\n",
755
- "# VAE losses\n",
756
- "ax = axes[0]\n",
757
- "ax.plot(vae_losses[\"recon\"], label=\"Reconstruction\", color=\"blue\")\n",
758
- "ax.plot(vae_losses[\"freq\"], label=\"Wavelet Freq\", color=\"red\")\n",
759
- "ax.set_title(\"Stage 1: VAE Losses\")\n",
760
- "ax.set_xlabel(\"Epoch\")\n",
761
- "ax.set_ylabel(\"Loss\")\n",
762
- "ax.legend()\n",
763
- "ax.grid(True, alpha=0.3)\n",
764
- "ax.set_yscale(\"log\")\n",
765
- "\n",
766
- "# Generator losses\n",
767
- "ax = axes[1]\n",
768
- "ax.plot(gen_losses[\"velocity\"], label=\"Velocity Loss\", color=\"green\")\n",
769
- "ax.set_title(\"Stage 2: Generator Velocity Loss\")\n",
770
- "ax.set_xlabel(\"Epoch\")\n",
771
- "ax.set_ylabel(\"Loss\")\n",
772
- "ax.legend()\n",
773
- "ax.grid(True, alpha=0.3)\n",
774
  "\n",
775
- "plt.tight_layout()\n",
776
- "plt.show()"
777
- ],
778
- "outputs": [],
779
- "execution_count": null
780
- },
781
- {
782
- "cell_type": "markdown",
783
- "metadata": {},
784
- "source": [
785
- "## 9. Save Checkpoint"
786
- ]
787
- },
788
- {
789
- "cell_type": "code",
790
- "metadata": {},
791
- "source": [
792
- "# Save the trained model\n",
793
  "import os\n",
794
  "os.makedirs(\"iris_checkpoint\", exist_ok=True)\n",
795
- "\n",
796
- "checkpoint = {\n",
797
  " \"config\": config,\n",
798
- " \"iris_state_dict\": iris.state_dict(),\n",
799
- " \"epoch\": GEN_EPOCHS,\n",
800
- " \"best_velocity_loss\": best_loss,\n",
801
- " \"vae_losses\": vae_losses,\n",
802
- " \"gen_losses\": gen_losses,\n",
803
- "}\n",
804
- "torch.save(checkpoint, \"iris_checkpoint/iris_trained.pt\")\n",
805
- "print(f\"\u2705 Checkpoint saved to iris_checkpoint/iris_trained.pt\")\n",
806
- "print(f\" File size: {os.path.getsize('iris_checkpoint/iris_trained.pt') / 1024 / 1024:.1f} MB\")\n",
807
- "\n",
808
- "# Optional: push to HF Hub\n",
809
- "# from huggingface_hub import HfApi\n",
810
- "# api = HfApi()\n",
811
- "# api.upload_folder(folder_path=\"iris_checkpoint\", repo_id=\"YOUR_USERNAME/iris-trained\")"
812
- ],
813
- "outputs": [],
814
- "execution_count": null
815
- },
816
- {
817
- "cell_type": "markdown",
818
- "metadata": {},
819
- "source": [
820
- "## 10. Inspect Learned Components",
821
- "",
822
- "Let's peek inside the trained model to understand what the different pathways learned."
823
- ]
824
- },
825
- {
826
- "cell_type": "code",
827
- "metadata": {},
828
- "source": [
829
- "# Inspect GRFM pathway gating\n",
830
- "print(\"=== GRFM Analysis ===\\n\")\n",
831
- "\n",
832
- "# Look at the blend gate \u2014 does it prefer Fourier or Recurrence?\n",
833
- "with torch.no_grad():\n",
834
- " # Get a sample through the model\n",
835
- " imgs_sample, caps = next(iter(train_loader))\n",
836
- " imgs_sample = imgs_sample.to(device)\n",
837
- " text_emb = encode_text(list(caps))\n",
838
- "\n",
839
- " z, _, _ = iris.encode(imgs_sample)\n",
840
- " noise = torch.randn_like(z)\n",
841
- " t = torch.tensor([0.5] * z.shape[0], device=device)\n",
842
- " z_t = iris.add_noise(z, noise, t)\n",
843
- "\n",
844
- " # Trace through to get GRFM internal state\n",
845
- " x = iris.generator.patch_embed(iris.generator.patchify(z_t)) + iris.generator.pos_embed\n",
846
- " for block in iris.generator.prelude:\n",
847
- " x = block(x)\n",
848
- "\n",
849
- " # Get first core layer's GRFM gate values\n",
850
- " core_layer = iris.generator.core.layers[0]\n",
851
- " H, W = iris.generator.patch_h, iris.generator.patch_w\n",
852
- "\n",
853
- " # Compute adaLN modulation\n",
854
- " t_emb = iris.generator.time_embed(t * 1000)\n",
855
- " i_emb = iris.generator.iter_embed(torch.zeros(z.shape[0], dtype=torch.long, device=device))\n",
856
- " text_global = iris.generator.text_pool_proj(text_emb.mean(dim=1))\n",
857
- " c = t_emb + i_emb + text_global\n",
858
- "\n",
859
- " s1, sh1, g1, *_ = core_layer.adaln(c)\n",
860
- " h_normed = core_layer._modulate(core_layer.norm1(x), s1, sh1)\n",
861
- "\n",
862
- " # Get the blend gate value from GRFM\n",
863
- " gate = core_layer.grfm.blend_gate(h_normed) # [B, N, D]\n",
864
- " gate_mean = gate.mean(dim=(0, 2)) # [N] \u2014 per-position gate\n",
865
- "\n",
866
- " # Reshape to 2D\n",
867
- " gate_2d = gate_mean.reshape(H, W).cpu().numpy()\n",
868
- "\n",
869
- " fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
870
- "\n",
871
- " # Gate heatmap\n",
872
- " im = axes[0].imshow(gate_2d, cmap='RdBu_r', vmin=0, vmax=1)\n",
873
- " axes[0].set_title(\"GRFM Blend Gate\\n(red=Fourier, blue=Recurrence)\")\n",
874
- " plt.colorbar(im, ax=axes[0])\n",
875
- "\n",
876
- " # Manhattan decay gammas\n",
877
- " gammas = torch.sigmoid(core_layer.grfm.spatial.gamma_logit).cpu().numpy()\n",
878
- " axes[1].bar(range(len(gammas)), gammas)\n",
879
- " axes[1].set_title(\"Manhattan Spatial Decay \u03b3 per Head\\n(lower=more local)\")\n",
880
- " axes[1].set_xlabel(\"Head\")\n",
881
- " axes[1].set_ylabel(\"\u03b3\")\n",
882
- " axes[1].set_ylim(0, 1)\n",
883
- "\n",
884
- " # Fourier sparsity (how many coefficients survive soft-shrink)\n",
885
- " x_2d = h_normed.reshape(h_normed.shape[0], H, W, h_normed.shape[-1])\n",
886
- " x_freq = torch.fft.rfft2(x_2d, dim=(1, 2), norm='ortho')\n",
887
- " magnitude = x_freq.abs()\n",
888
- " threshold = core_layer.grfm.fourier.sparsity_threshold\n",
889
- " alive = (magnitude > threshold).float().mean().item()\n",
890
- " axes[2].text(0.5, 0.5, f\"Fourier coefficients\\nabove threshold:\\n{alive*100:.1f}%\",\n",
891
- " ha='center', va='center', fontsize=16,\n",
892
- " transform=axes[2].transAxes)\n",
893
- " axes[2].set_title(\"Fourier Domain Sparsity\")\n",
894
- " axes[2].axis(\"off\")\n",
895
- "\n",
896
- " plt.tight_layout()\n",
897
- " plt.show()"
898
  ],
899
  "outputs": [],
900
  "execution_count": null
@@ -903,67 +489,17 @@
903
  "cell_type": "markdown",
904
  "metadata": {},
905
  "source": [
906
- "## 11. \ud83d\ude80 Next Steps \u2014 Scaling Up",
907
- "",
908
- "This notebook trained on ~800 images as a **proof of concept**. To get production quality:",
909
- "",
910
- "### Datasets for Each Training Stage",
911
  "",
912
- "| Stage | Dataset | Size | HF ID |",
913
- "|-------|---------|------|-------|",
914
- "| 1. VAE | ImageNet + CC3M | 4.2M images | `ILSVRC/imagenet-1k`, `pixparse/cc3m-wds` |",
915
- "| 2. Class-Cond | ImageNet | 1.2M images | `ILSVRC/imagenet-1k` |",
916
- "| 3. Text-Image | CC12M (VLM-recaptioned) | 12M images | `pixparse/cc12m-wds` |",
917
- "| 4. Aesthetic | JourneyDB + LAION-art | ~1M images | `caidas/JourneyDB` |",
918
- "| 5. Distillation | Self-distill from Stage 4 | Same data | \u2014 |",
919
  "",
920
- "### Optimization Tips for Larger Runs",
921
- "```python",
922
- "# On Kaggle with 2\u00d7 T4:",
923
- "# Use accelerate for multi-GPU",
924
- "# accelerate launch --num_processes 2 train.py",
925
- "",
926
- "# On Colab Pro (A100 40GB):",
927
- "BATCH_SIZE = 16",
928
- "GEN_EPOCHS = 500",
929
- "config = create_iris_small().config # Upgrade to IRIS-Small",
930
- "",
931
- "# For production (cloud GPUs):",
932
- "# Use IRIS-Base with 8\u00d7 A100",
933
- "# Add LADD adversarial distillation in Stage 5",
934
- "# Train for 200k+ steps on CC12M",
935
- "```",
936
- "",
937
- "### Model Size Recommendations",
938
- "| Use Case | Model | Batch | Resolution | GPU |",
939
- "|----------|-------|-------|-----------|-----|",
940
- "| Demo/Proof | IRIS-Tiny | 4 | 256px | T4 16GB |",
941
- "| Mobile deploy | IRIS-Small | 8 | 512px | A100 40GB |",
942
- "| Quality focus | IRIS-Base | 16 | 512px | 2\u00d7A100 |",
943
- "| Production | IRIS-Base | 64 | 1024px | 8\u00d7A100 |"
944
- ]
945
- },
946
- {
947
- "cell_type": "markdown",
948
- "metadata": {},
949
- "source": [
950
- "## 12. Kaggle Adaptation",
951
- "",
952
- "To run this on **Kaggle**, just change one thing:",
953
- "",
954
- "```python",
955
- "# In Kaggle, GPU is already available. Just:",
956
- "# 1. Copy this notebook to Kaggle",
957
- "# 2. Enable \"GPU T4 \u00d72\" or \"GPU P100\" in accelerator settings",
958
- "# 3. Run all cells!",
959
- "",
960
- "# For Kaggle's dual-T4 setup, use DataParallel:",
961
- "if torch.cuda.device_count() > 1:",
962
- " print(f\"Using {torch.cuda.device_count()} GPUs!\")",
963
- " iris.generator = torch.nn.DataParallel(iris.generator)",
964
- "```",
965
  "",
966
- "The training loop works identically on both platforms. \ud83c\udf89"
967
  ]
968
  },
969
  {
@@ -971,7 +507,7 @@
971
  "metadata": {},
972
  "source": [
973
  "---",
974
- "*Built with \u2764\ufe0f using the IRIS architecture. Repository: [asdf98/IRIS-architecture](https://huggingface.co/asdf98/IRIS-architecture)*"
975
  ]
976
  }
977
  ]
 
22
  "cell_type": "markdown",
23
  "metadata": {},
24
  "source": [
25
+ "# \ud83d\udd2e IRIS Training Notebook \u2014 v2",
26
  "",
27
+ "**Train the IRIS recurrent-depth image generator on free Colab/Kaggle GPUs.**",
28
  "",
29
+ "This version uses a **pre-trained Stable Diffusion VAE** (perfect reconstruction quality out of the box) ",
30
+ "so we focus 100% on training the novel IRIS generator.",
 
31
  "",
32
+ "### Pipeline",
33
+ "```",
34
+ "Image \u2192 SD-VAE Encode \u2192 z\u2080 [4\u00d732\u00d732] \u2192 IRIS Generator learns to denoise \u2192 SD-VAE Decode \u2192 Image",
35
+ "```",
 
 
36
  "",
37
+ "### Hardware",
38
+ "| Platform | GPU | VRAM | Training Time |",
39
+ "|----------|-----|------|---------------|",
40
+ "| **Colab Free** | T4 | 16GB | ~40-60 min |",
41
+ "| **Kaggle** | P100/T4\u00d72 | 16GB | ~40-60 min |",
42
+ "| **Colab Pro** | A100 | 40GB | ~15 min |"
43
  ]
44
  },
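For intuition about the pipeline above: the 4×32×32 latent carries 48× fewer values than the 256px image it encodes (the same compression ratio the v1 wavelet VAE targeted), which is what makes free-tier training feasible. A quick check of that arithmetic:

```python
# Values the generator must model, per sample: pixels vs. SD-VAE latent.
image_elems = 3 * 256 * 256     # 196,608 values in a 256px RGB image
latent_elems = 4 * 32 * 32      #   4,096 values in the [4, 32, 32] latent
print(image_elems / latent_elems)  # 48.0 -> a 48x reduction
```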
45
  {
46
  "cell_type": "markdown",
47
  "metadata": {},
48
  "source": [
49
+ "## 1. Setup"
50
  ]
51
  },
52
  {
53
  "cell_type": "code",
54
  "metadata": {},
55
  "source": [
56
+ "!pip install -q torch torchvision diffusers transformers datasets accelerate matplotlib tqdm huggingface_hub\n",
 
57
  "\n",
 
58
  "import torch\n",
59
+ "print(f\"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}\")\n",
 
60
  "if torch.cuda.is_available():\n",
61
+ " print(f\"GPU: {torch.cuda.get_device_name(0)} | VRAM: {torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB\")\n",
 
62
  " device = torch.device('cuda')\n",
63
  "else:\n",
64
+ " print(\"\u26a0\ufe0f No GPU \u2014 will be slow!\")\n",
65
  " device = torch.device('cpu')"
66
  ],
67
  "outputs": [],
 
71
  "cell_type": "markdown",
72
  "metadata": {},
73
  "source": [
74
+ "## 2. Download IRIS Architecture"
75
  ]
76
  },
77
  {
78
  "cell_type": "code",
79
  "metadata": {},
80
  "source": [
 
81
  "from huggingface_hub import hf_hub_download\n",
82
+ "import shutil\n",
83
+ "\n",
84
+ "for f in [\"iris_model.py\"]:\n",
85
+ " shutil.copy(hf_hub_download(\"asdf98/IRIS-architecture\", f), f\"./{f}\")\n",
86
+ "\n",
87
+ "from iris_model import IRIS, IRISConfig, IRISGenerator, create_iris_tiny, create_iris_small, count_parameters\n",
88
+ "print(\"\u2705 IRIS loaded\")"
 
 
 
 
 
 
 
 
 
89
  ],
90
  "outputs": [],
91
  "execution_count": null
 
94
  "cell_type": "markdown",
95
  "metadata": {},
96
  "source": [
97
+ "## 3. Load Pre-trained SD VAE (Perfect Reconstruction)",
98
  "",
99
+ "Using `stabilityai/sd-vae-ft-mse` \u2014 the industry-standard VAE used by Stable Diffusion.",
100
+ "- 83M params, but **frozen** (no gradients, no VRAM for optimizer)",
101
+ "- Encodes 256\u00d7256 \u2192 4\u00d732\u00d732 latent (8\u00d7 spatial compression)",
102
+ "- Near-perfect reconstruction (PSNR 24.5dB on COCO)"
103
  ]
104
  },
105
  {
106
  "cell_type": "code",
107
  "metadata": {},
108
  "source": [
109
+ "from diffusers import AutoencoderKL\n",
110
+ "\n",
111
+ "print(\"Loading SD-VAE (sd-vae-ft-mse)...\")\n",
112
+ "sd_vae = AutoencoderKL.from_pretrained(\n",
113
+ " \"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16\n",
114
+ ").to(device).eval()\n",
115
+ "\n",
116
+ "# Freeze completely\n",
117
+ "for p in sd_vae.parameters():\n",
118
+ " p.requires_grad = False\n",
119
+ "\n",
120
+ "SCALING_FACTOR = sd_vae.config.scaling_factor # 0.18215\n",
121
+ "print(f\"\u2705 SD-VAE loaded | scaling_factor={SCALING_FACTOR}\")\n",
122
+ "print(f\" Latent: 256px \u2192 [B, 4, 32, 32] | 512px \u2192 [B, 4, 64, 64]\")\n",
123
+ "\n",
124
+ "@torch.no_grad()\n",
125
+ "def vae_encode(images):\n",
126
+ " \"\"\"Images [-1,1] \u2192 latent [B,4,H/8,W/8]\"\"\"\n",
127
+ " dist = sd_vae.encode(images.half()).latent_dist\n",
128
+ " return dist.mean * SCALING_FACTOR # deterministic, no sampling noise\n",
129
+ "\n",
130
+ "@torch.no_grad()\n",
131
+ "def vae_decode(latents):\n",
132
+ " \"\"\"Latent \u2192 images [-1,1]\"\"\"\n",
133
+ " return sd_vae.decode(latents.half() / SCALING_FACTOR).sample.float()"
134
  ],
135
  "outputs": [],
136
  "execution_count": null
 
139
  "cell_type": "markdown",
140
  "metadata": {},
141
  "source": [
142
+ "## 4. Load Dataset & CLIP Text Encoder"
 
 
 
 
 
 
 
 
143
  ]
144
  },
145
  {
 
149
  "from datasets import load_dataset\n",
150
  "from torchvision import transforms\n",
151
  "from torch.utils.data import Dataset, DataLoader\n",
152
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
153
+ "import matplotlib.pyplot as plt\n",
154
+ "from tqdm.auto import tqdm\n",
155
  "\n",
156
+ "# \u2500\u2500\u2500 Dataset \u2500\u2500\u2500\n",
157
+ "IMAGE_SIZE = 256\n",
158
+ "BATCH_SIZE = 4\n",
 
159
  "\n",
160
+ "raw_dataset = load_dataset(\"reach-vb/pokemon-blip-captions\", split=\"train\")\n",
161
+ "print(f\"\u2705 Dataset: {len(raw_dataset)} image-caption pairs\")\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  "\n",
 
163
  "train_transform = transforms.Compose([\n",
164
  " transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n",
165
  " transforms.CenterCrop(IMAGE_SIZE),\n",
166
  " transforms.RandomHorizontalFlip(),\n",
167
+ " transforms.ToTensor(),\n",
168
+ " transforms.Normalize([0.5]*3, [0.5]*3),\n",
169
  "])\n",
170
  "\n",
171
  "class ImageCaptionDataset(Dataset):\n",
172
+ " def __init__(self, hf_ds, transform):\n",
173
+ " self.ds = hf_ds\n",
 
174
  " self.transform = transform\n",
175
+ " def __len__(self): return len(self.ds)\n",
176
+ " def __getitem__(self, i):\n",
177
+ " item = self.ds[i]\n",
178
+ " return self.transform(item[\"image\"].convert(\"RGB\")), item[\"text\"]\n",
179
  "\n",
180
+ "train_dataset = ImageCaptionDataset(raw_dataset, train_transform)\n",
 
181
  "\n",
182
+ "# \u2500\u2500\u2500 CLIP Text Encoder \u2500\u2500\u2500\n",
183
+ "print(\"Loading CLIP-L/14...\")\n",
184
+ "tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
185
+ "text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device).eval()\n",
186
+ "for p in text_encoder.parameters():\n",
187
+ " p.requires_grad = False\n",
188
  "\n",
189
+ "@torch.no_grad()\n",
190
+ "def encode_text(captions):\n",
191
+ " tok = tokenizer(captions, padding=\"max_length\", truncation=True, max_length=77, return_tensors=\"pt\").to(device)\n",
192
+ " return text_encoder(**tok).last_hidden_state\n",
 
 
193
  "\n",
194
+ "print(f\"\u2705 CLIP-L/14 loaded\")"
 
 
 
195
  ],
196
  "outputs": [],
197
  "execution_count": null
 
200
  "cell_type": "markdown",
201
  "metadata": {},
202
  "source": [
203
+ "## 5. Pre-encode Everything (One-Time Cost)",
204
  "",
205
+ "Encode ALL images and captions upfront \u2192 zero overhead during training."
 
206
  ]
207
  },
208
  {
209
  "cell_type": "code",
210
  "metadata": {},
211
  "source": [
212
+ "# Pre-encode all images through SD-VAE and all captions through CLIP\n",
213
+ "print(\"Pre-encoding dataset (one-time cost)...\")\n",
214
+ "cache_loader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=2)\n",
215
  "\n",
216
+ "all_latents = []\n",
217
+ "all_text_embs = []\n",
 
 
218
  "\n",
219
+ "for images, captions in tqdm(cache_loader, desc=\"Encoding\"):\n",
220
+ " images = images.to(device)\n",
221
+ " z = vae_encode(images)\n",
222
+ " all_latents.append(z.cpu())\n",
223
+ " \n",
224
+ " emb = encode_text(list(captions))\n",
225
+ " all_text_embs.append(emb.cpu())\n",
226
  "\n",
227
+ "all_latents = torch.cat(all_latents) # [N, 4, 32, 32]\n",
228
+ "all_text_embs = torch.cat(all_text_embs) # [N, 77, 768]\n",
 
 
229
  "\n",
230
+ "print(f\"\u2705 Pre-encoded {len(all_latents)} samples\")\n",
231
+ "print(f\" Latents: {all_latents.shape} | range [{all_latents.min():.2f}, {all_latents.max():.2f}]\")\n",
232
+ "print(f\" Text: {all_text_embs.shape}\")\n",
233
+ "\n",
234
+ "# \u2500\u2500\u2500 Free CLIP and VAE encoder from GPU to save VRAM \u2500\u2500\u2500\n",
235
+ "text_encoder.cpu()\n",
236
+ "# Keep sd_vae on GPU for decode during visualization\n",
237
+ "torch.cuda.empty_cache()\n",
238
+ "print(f\"\u2705 Freed ~600MB VRAM (CLIP moved to CPU)\")\n",
239
+ "\n",
240
+ "# \u2500\u2500\u2500 Show VAE reconstruction quality \u2500\u2500\u2500\n",
241
+ "fig, axes = plt.subplots(2, 6, figsize=(18, 6))\n",
242
+ "sample_imgs, _ = next(iter(DataLoader(train_dataset, batch_size=6, shuffle=True)))\n",
243
+ "sample_imgs = sample_imgs.to(device)\n",
244
+ "sample_z = vae_encode(sample_imgs)\n",
245
+ "sample_recon = vae_decode(sample_z)\n",
246
+ "\n",
247
+ "for i in range(6):\n",
248
+ " axes[0, i].imshow(sample_imgs[i].cpu().permute(1,2,0).numpy()*0.5+0.5)\n",
249
+ " axes[0, i].set_title(\"Original\", fontsize=9)\n",
250
+ " axes[0, i].axis(\"off\")\n",
251
+ " axes[1, i].imshow(sample_recon[i].cpu().clamp(-1,1).permute(1,2,0).numpy()*0.5+0.5)\n",
252
+ " axes[1, i].set_title(\"SD-VAE Recon\", fontsize=9)\n",
253
+ " axes[1, i].axis(\"off\")\n",
254
+ "plt.suptitle(\"Pre-trained SD-VAE Reconstruction (near-perfect)\", fontsize=13)\n",
255
+ "plt.tight_layout()\n",
256
+ "plt.show()"
257
  ],
258
  "outputs": [],
259
  "execution_count": null
 
262
  "cell_type": "markdown",
263
  "metadata": {},
264
  "source": [
265
+ "## 6. Create IRIS Generator",
 
 
266
  "",
267
+ "Now we create the IRIS generator that works in the SD-VAE latent space.",
268
+ "- `latent_channels=4` (SD-VAE standard)",
269
+ "- `latent_spatial=32` (256px / 8)",
270
+ "- No VAE training needed \u2014 we just train the denoiser!"
 
 
 
271
  ]
272
  },
273
  {
274
  "cell_type": "code",
275
  "metadata": {},
276
  "source": [
277
+ "# Create IRIS-Tiny (best for free-tier)\n",
 
 
 
 
 
 
278
  "config = IRISConfig(\n",
279
+ " latent_channels=4, # SD-VAE standard\n",
280
+ " latent_spatial=32, # 256px / 8\n",
281
+ " hidden_dim=384,\n",
282
  " num_heads=6,\n",
283
  " head_dim=64,\n",
284
  " ffn_ratio=2.667,\n",
 
291
  " sparsity_threshold=0.01,\n",
292
  " recurrence_dim=192,\n",
293
  " manhattan_window=12,\n",
294
+ " text_dim=768,\n",
295
  " max_text_tokens=77,\n",
296
  " patch_size=2,\n",
 
297
  ")\n",
298
  "\n",
299
+ "iris = IRIS(config).to(device)\n",
300
+ "gen_params = sum(p.numel() for p in iris.generator.parameters())\n",
301
+ "core_params = sum(p.numel() for p in iris.generator.core.parameters())\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  "\n",
303
+ "print(f\"IRIS Generator: {gen_params:,} params ({gen_params*2/1024/1024:.1f} MB fp16)\")\n",
304
+ "print(f\" Core (shared): {core_params:,} ({core_params/gen_params*100:.1f}%)\")\n",
305
+ "print(f\" Effective @r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
306
+ "print(f\" Input: [B, 4, 32, 32] latent \u2192 Output: [B, 4, 32, 32] velocity\")"
 
 
 
 
 
 
307
  ],
308
  "outputs": [],
309
  "execution_count": null
 
312
  "cell_type": "markdown",
313
  "metadata": {},
314
  "source": [
315
+ "## 7. Train IRIS Generator (Rectified Flow)",
 
 
316
  "",
317
+ "The main training loop. Since everything is pre-cached, each epoch is **pure generator training** \u2014 no VAE encoding, no CLIP forward passes."
 
 
 
 
 
 
 
 
318
  ]
319
  },
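`train_step_latent` is defined in `iris_model.py` and not shown in this notebook. For reference, a minimal sketch of the standard rectified-flow velocity objective it presumably implements (the `velocity_model` call signature here is hypothetical):

```python
import torch
import torch.nn.functional as F

def rectified_flow_loss(velocity_model, z_0, text_emb, num_iterations):
    """Sketch of a standard rectified-flow training step (not the IRIS source).

    z_0: clean SD-VAE latents [B, 4, 32, 32]. The model learns the constant
    velocity that moves samples along the straight noise -> data path.
    """
    noise = torch.randn_like(z_0)
    t = torch.rand(z_0.shape[0], device=z_0.device)  # uniform; IRIS may use logit-normal
    t_ = t.view(-1, 1, 1, 1)
    z_t = (1 - t_) * z_0 + t_ * noise                # linear interpolation path
    v_target = noise - z_0                           # ground-truth velocity
    v_pred = velocity_model(z_t, t, text_emb, num_iterations)
    return F.mse_loss(v_pred, v_target)
```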
320
  {
321
  "cell_type": "code",
322
  "metadata": {},
323
  "source": [
324
+ "import time, math\n",
 
 
 
 
 
325
  "\n",
326
+ "# \u2500\u2500\u2500 Cached DataLoader \u2500\u2500\u2500\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  "class CachedDataset(Dataset):\n",
328
+ " def __init__(self, latents, text_embs):\n",
329
+ " self.latents = latents\n",
330
+ " self.text_embs = text_embs\n",
331
+ " def __len__(self): return len(self.latents)\n",
332
+ " def __getitem__(self, i): return self.latents[i], self.text_embs[i]\n",
 
 
 
 
 
 
 
 
 
333
  "\n",
334
+ "cached_loader = DataLoader(\n",
335
+ " CachedDataset(all_latents, all_text_embs),\n",
336
+ " batch_size=BATCH_SIZE, shuffle=True, num_workers=2,\n",
337
+ " pin_memory=True, drop_last=True,\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
338
  ")\n",
339
  "\n",
340
+ "# \u2500\u2500\u2500 Training Config \u2500\u2500\u2500\n",
341
+ "EPOCHS = 200\n",
342
+ "LR = 2e-4\n",
343
+ "GRAD_ACCUM = 2\n",
 
 
 
344
  "\n",
345
+ "optimizer = torch.optim.AdamW(iris.generator.parameters(), lr=LR, weight_decay=0.03, betas=(0.9, 0.95))\n",
346
+ "total_steps = EPOCHS * len(cached_loader) // GRAD_ACCUM\n",
347
+ "warmup = min(200, total_steps // 10)\n",
348
  "\n",
349
+ "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: \n",
350
+ " s/max(1,warmup) if s < warmup else 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total_steps-warmup))))\n",
351
+ "scaler = torch.amp.GradScaler('cuda')\n",
352
  "\n",
353
+ "print(f\"Training for {EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
354
+ "print(f\"Batch: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} accum = {BATCH_SIZE*GRAD_ACCUM} effective\")\n",
355
+ "print(f\"Iterations per step: random from [3, 4, 5]\")\n",
356
+ "print()\n",
357
  "\n",
358
+ "# \u2500\u2500\u2500 Training Loop \u2500\u2500\u2500\n",
359
+ "losses = []\n",
360
  "iris.generator.train()\n",
 
361
  "best_loss = float('inf')\n",
362
+ "global_step = 0\n",
363
  "\n",
364
+ "pbar = tqdm(range(EPOCHS), desc=\"Training\")\n",
365
  "for epoch in pbar:\n",
366
+ " epoch_loss = 0\n",
367
+ " n = 0\n",
368
+ " optimizer.zero_grad(set_to_none=True)\n",
369
  "\n",
370
+ " for batch_idx, (z_0, text_emb) in enumerate(cached_loader):\n",
371
+ " z_0 = z_0.to(device, non_blocking=True)\n",
 
 
372
  " text_emb = text_emb.to(device, non_blocking=True)\n",
373
  "\n",
 
374
  " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
 
375
  " r = [3, 4, 5][torch.randint(0, 3, (1,)).item()]\n",
376
+ " result = iris.train_step_latent(z_0, text_emb, num_iterations=r)\n",
377
  " loss = result[\"loss\"] / GRAD_ACCUM\n",
378
  "\n",
379
+ " scaler.scale(loss).backward()\n",
380
  "\n",
 
381
  " if (batch_idx + 1) % GRAD_ACCUM == 0:\n",
382
+ " scaler.unscale_(optimizer)\n",
383
  " torch.nn.utils.clip_grad_norm_(iris.generator.parameters(), 1.0)\n",
384
+ " scaler.step(optimizer)\n",
385
+ " scaler.update()\n",
386
+ " optimizer.zero_grad(set_to_none=True)\n",
387
+ " scheduler.step()\n",
388
  " global_step += 1\n",
389
  "\n",
390
+ " epoch_loss += result[\"velocity_loss\"]\n",
391
+ " n += 1\n",
 
 
 
 
 
 
392
  "\n",
393
+ " avg = epoch_loss / n\n",
394
+ " losses.append(avg)\n",
395
+ " if avg < best_loss:\n",
396
+ " best_loss = avg\n",
397
+ " pbar.set_postfix(loss=f\"{avg:.4f}\", best=f\"{best_loss:.4f}\", lr=f\"{optimizer.param_groups[0]['lr']:.1e}\")\n",
398
  "\n",
399
+ "print(f\"\\n\u2705 Training complete! Best loss: {best_loss:.4f}\")"
 
 
 
 
 
400
  ],
401
  "outputs": [],
402
  "execution_count": null
 
405
  "cell_type": "markdown",
406
  "metadata": {},
407
  "source": [
408
+ "## 8. Generate Images!"
 
 
 
409
  ]
410
  },
411
  {
412
  "cell_type": "code",
413
  "metadata": {},
414
  "source": [
415
+ "# Reload CLIP on GPU for prompt encoding\n",
416
+ "text_encoder.to(device)\n",
 
417
  "\n",
418
  "prompts = [\n",
419
  " \"a fire-breathing dragon pokemon\",\n",
420
+ " \"a cute blue water pokemon with bubbles\",\n",
421
  " \"a green grass-type pokemon with leaves\",\n",
422
+ " \"a yellow electric pokemon with lightning bolts\",\n",
 
 
 
 
423
  "]\n",
424
  "\n",
425
  "iris.eval()\n",
426
+ "fig, axes = plt.subplots(len(prompts), 4, figsize=(16, len(prompts)*4))\n",
427
+ "iter_counts = [2, 4, 6, 8]\n",
 
 
428
  "\n",
429
  "for row, prompt in enumerate(prompts):\n",
430
+ " text_emb = encode_text([prompt])\n",
431
+ " for col, r in enumerate(iter_counts):\n",
432
+ " z = iris.generate_latent(text_emb, num_steps=4, num_iterations=r, cfg_scale=1.0, seed=42)\n",
433
+ " img = vae_decode(z)\n",
434
+ " img_np = img[0].cpu().clamp(-1, 1).permute(1, 2, 0).numpy() * 0.5 + 0.5\n",
 
 
 
 
 
 
 
 
 
 
 
435
  " axes[row, col].imshow(img_np)\n",
436
  " axes[row, col].axis(\"off\")\n",
437
  " if row == 0:\n",
438
+ " axes[row, col].set_title(f\"r={r} iterations\", fontsize=11)\n",
439
+ " axes[row, 0].set_ylabel(prompt[:30], fontsize=9, rotation=0, labelpad=120, va='center')\n",
440
  "\n",
441
+ "plt.suptitle(\"IRIS Generated Images (Adaptive Compute)\", fontsize=14, y=1.01)\n",
442
  "plt.tight_layout()\n",
443
  "plt.show()\n",
444
  "\n",
445
+ "print(\"Note: ~800 training images \u2192 noisy outputs. This validates the architecture works.\")\n",
446
+ "print(\"Scale up with CC3M/CC12M + more epochs for production quality.\")"
 
 
 
 
447
  ],
448
  "outputs": [],
449
  "execution_count": null
 
452
  "cell_type": "markdown",
453
  "metadata": {},
454
  "source": [
455
+ "## 9. Training Loss & Checkpoint"
456
  ]
457
  },
458
  {
459
  "cell_type": "code",
460
  "metadata": {},
461
  "source": [
462
+ "# Loss curve\n",
463
+ "plt.figure(figsize=(10, 4))\n",
464
+ "plt.plot(losses, color='green', alpha=0.7)\n",
465
+ "plt.plot([sum(losses[max(0,i-10):i+1])/min(i+1,10) for i in range(len(losses))], \n",
466
+ " color='green', linewidth=2, label='Moving Avg (10)')\n",
467
+ "plt.xlabel(\"Epoch\")\n",
468
+ "plt.ylabel(\"Velocity Loss\")\n",
469
+ "plt.title(\"IRIS Generator Training Loss\")\n",
470
+ "plt.legend()\n",
471
+ "plt.grid(True, alpha=0.3)\n",
472
+ "plt.show()\n",
 
 
 
 
 
 
 
 
 
 
473
  "\n",
474
+ "# Save checkpoint\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
  "import os\n",
476
  "os.makedirs(\"iris_checkpoint\", exist_ok=True)\n",
477
+ "torch.save({\n",
 
478
  " \"config\": config,\n",
479
+ " \"generator_state_dict\": iris.generator.state_dict(),\n",
480
+ " \"best_loss\": best_loss,\n",
481
+ " \"losses\": losses,\n",
482
+ "}, \"iris_checkpoint/iris_gen.pt\")\n",
483
+ "print(f\"\u2705 Saved: iris_checkpoint/iris_gen.pt ({os.path.getsize('iris_checkpoint/iris_gen.pt')/1024/1024:.1f} MB)\")"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
484
  ],
485
  "outputs": [],
486
  "execution_count": null
 
489
  "cell_type": "markdown",
490
  "metadata": {},
491
  "source": [
492
+ "## 10. Scaling Up",
 
 
 
 
493
  "",
494
+ "| Scale | Dataset | Model | GPU | Expected Quality |",
495
+ "|-------|---------|-------|-----|-----------------|",
496
+ "| **This notebook** | Pok\u00e9mon (833) | IRIS-Tiny | T4 free | Proof of concept |",
497
+ "| **Hobby** | CC3M (3M) | IRIS-Small | A100 40GB | Decent |",
498
+ "| **Production** | CC12M + LAION | IRIS-Base | 4\u00d7A100 | High quality |",
 
 
499
  "",
500
+ "For **Kaggle** dual-T4: just enable `GPU T4 \u00d72` and run as-is. DataParallel is automatic for larger models.",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
501
  "",
502
+ "For **512px generation**: change `IMAGE_SIZE=512` and `latent_spatial=64`. Everything else stays the same."
503
  ]
504
  },
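The wrapper referenced above, as the v1 notebook spelled it out, plus the 512px settings mentioned above, as a sketch:

```python
# Kaggle dual-T4: wrap only the trainable generator (SD-VAE and CLIP stay frozen).
import torch

if torch.cuda.device_count() > 1:
    print(f"Using {torch.cuda.device_count()} GPUs")
    iris.generator = torch.nn.DataParallel(iris.generator)

# 512px variant: set these BEFORE building the transforms and the model.
# IMAGE_SIZE = 512       # images  -> [B, 3, 512, 512]
# latent_spatial = 64    # latents -> [B, 4, 64, 64] (512 / 8)
```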
505
  {
 
507
  "metadata": {},
508
  "source": [
509
  "---",
510
+ "*[asdf98/IRIS-architecture](https://huggingface.co/asdf98/IRIS-architecture)*"
511
  ]
512
  }
513
  ]