{ "nbformat": 4, "nbformat_minor": 5, "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.0" }, "accelerator": "GPU", "colab": { "provenance": [], "gpuType": "T4" } }, "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# \ud83d\udd2e IRIS Training Notebook \u2014 v2", "", "**Train the IRIS recurrent-depth image generator on free Colab/Kaggle GPUs.**", "", "This version uses a **pre-trained Stable Diffusion VAE** (perfect reconstruction quality out of the box) ", "so we focus 100% on training the novel IRIS generator.", "", "### Pipeline", "```", "Image \u2192 SD-VAE Encode \u2192 z\u2080 [4\u00d732\u00d732] \u2192 IRIS Generator learns to denoise \u2192 SD-VAE Decode \u2192 Image", "```", "", "### Hardware", "| Platform | GPU | VRAM | Training Time |", "|----------|-----|------|---------------|", "| **Colab Free** | T4 | 16GB | ~40-60 min |", "| **Kaggle** | P100/T4\u00d72 | 16GB | ~40-60 min |", "| **Colab Pro** | A100 | 40GB | ~15 min |" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Setup" ] }, { "cell_type": "code", "metadata": {}, "source": [ "!pip install -q torch torchvision diffusers transformers datasets accelerate matplotlib tqdm huggingface_hub\n", "\n", "import torch\n", "print(f\"PyTorch {torch.__version__} | CUDA: {torch.cuda.is_available()}\")\n", "if torch.cuda.is_available():\n", " print(f\"GPU: {torch.cuda.get_device_name(0)} | VRAM: {torch.cuda.get_device_properties(0).total_mem/1024**3:.1f} GB\")\n", " device = torch.device('cuda')\n", "else:\n", " print(\"\u26a0\ufe0f No GPU \u2014 will be slow!\")\n", " device = torch.device('cpu')" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Download IRIS Architecture" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from huggingface_hub import hf_hub_download\n", "import shutil\n", "\n", "# Force fresh download (bypass cache) to get latest version\n", "for f in [\"iris_model.py\"]:\n", " path = hf_hub_download(\"asdf98/IRIS-architecture\", f, force_download=True)\n", " shutil.copy(path, f\"./{f}\")\n", "\n", "from iris_model import IRIS, IRISConfig, IRISGenerator, create_iris_tiny, create_iris_small, count_parameters\n", "print(\"\u2705 IRIS loaded\")\n", "\n", "# Quick verification that train_step_latent exists\n", "assert hasattr(IRIS, 'train_step_latent'), \"ERROR: Old iris_model.py cached! Restart runtime and re-run.\"\n", "assert hasattr(IRIS, 'generate_latent'), \"ERROR: Old iris_model.py cached! Restart runtime and re-run.\"\n", "print(\"\u2705 Verified: train_step_latent and generate_latent available\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Load Pre-trained SD VAE (Perfect Reconstruction)", "", "Using `stabilityai/sd-vae-ft-mse` \u2014 the industry-standard VAE used by Stable Diffusion.", "- 83M params, but **frozen** (no gradients, no VRAM for optimizer)", "- Encodes 256\u00d7256 \u2192 4\u00d732\u00d732 latent (8\u00d7 spatial compression)", "- Near-perfect reconstruction (PSNR 24.5dB on COCO)" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from diffusers import AutoencoderKL\n", "\n", "print(\"Loading SD-VAE (sd-vae-ft-mse)...\")\n", "sd_vae = AutoencoderKL.from_pretrained(\n", " \"stabilityai/sd-vae-ft-mse\", torch_dtype=torch.float16\n", ").to(device).eval()\n", "\n", "# Freeze completely\n", "for p in sd_vae.parameters():\n", " p.requires_grad = False\n", "\n", "SCALING_FACTOR = sd_vae.config.scaling_factor # 0.18215\n", "print(f\"\u2705 SD-VAE loaded | scaling_factor={SCALING_FACTOR}\")\n", "print(f\" Latent: 256px \u2192 [B, 4, 32, 32] | 512px \u2192 [B, 4, 64, 64]\")\n", "\n", "@torch.no_grad()\n", "def vae_encode(images):\n", " \"\"\"Images [-1,1] \u2192 latent [B,4,H/8,W/8]\"\"\"\n", " dist = sd_vae.encode(images.half()).latent_dist\n", " return dist.mean * SCALING_FACTOR # deterministic, no sampling noise\n", "\n", "@torch.no_grad()\n", "def vae_decode(latents):\n", " \"\"\"Latent \u2192 images [-1,1]\"\"\"\n", " return sd_vae.decode(latents.half() / SCALING_FACTOR).sample.float()" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Load Dataset & CLIP Text Encoder" ] }, { "cell_type": "code", "metadata": {}, "source": [ "from datasets import load_dataset\n", "from torchvision import transforms\n", "from torch.utils.data import Dataset, DataLoader\n", "from transformers import CLIPTextModel, CLIPTokenizer\n", "import matplotlib.pyplot as plt\n", "from tqdm.auto import tqdm\n", "\n", "# \u2500\u2500\u2500 Dataset \u2500\u2500\u2500\n", "IMAGE_SIZE = 256\n", "BATCH_SIZE = 4\n", "\n", "raw_dataset = load_dataset(\"reach-vb/pokemon-blip-captions\", split=\"train\")\n", "print(f\"\u2705 Dataset: {len(raw_dataset)} image-caption pairs\")\n", "\n", "train_transform = transforms.Compose([\n", " transforms.Resize(IMAGE_SIZE, interpolation=transforms.InterpolationMode.LANCZOS),\n", " transforms.CenterCrop(IMAGE_SIZE),\n", " transforms.RandomHorizontalFlip(),\n", " transforms.ToTensor(),\n", " transforms.Normalize([0.5]*3, [0.5]*3),\n", "])\n", "\n", "class ImageCaptionDataset(Dataset):\n", " def __init__(self, hf_ds, transform):\n", " self.ds = hf_ds\n", " self.transform = transform\n", " def __len__(self): return len(self.ds)\n", " def __getitem__(self, i):\n", " item = self.ds[i]\n", " return self.transform(item[\"image\"].convert(\"RGB\")), item[\"text\"]\n", "\n", "train_dataset = ImageCaptionDataset(raw_dataset, train_transform)\n", "\n", "# \u2500\u2500\u2500 CLIP Text Encoder \u2500\u2500\u2500\n", "print(\"Loading CLIP-L/14...\")\n", "tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n", "text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device).eval()\n", "for p in text_encoder.parameters():\n", " p.requires_grad = False\n", "\n", "@torch.no_grad()\n", "def encode_text(captions):\n", " tok = tokenizer(captions, padding=\"max_length\", truncation=True, max_length=77, return_tensors=\"pt\").to(device)\n", " return text_encoder(**tok).last_hidden_state\n", "\n", "print(f\"\u2705 CLIP-L/14 loaded\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": 
[ "## 5. Pre-encode Everything (One-Time Cost)", "", "Encode ALL images and captions upfront \u2192 zero overhead during training." ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Pre-encode all images through SD-VAE and all captions through CLIP\n", "print(\"Pre-encoding dataset (one-time cost)...\")\n", "cache_loader = DataLoader(train_dataset, batch_size=16, shuffle=False, num_workers=2)\n", "\n", "all_latents = []\n", "all_text_embs = []\n", "\n", "for images, captions in tqdm(cache_loader, desc=\"Encoding\"):\n", " images = images.to(device)\n", " z = vae_encode(images)\n", " all_latents.append(z.cpu())\n", " \n", " emb = encode_text(list(captions))\n", " all_text_embs.append(emb.cpu())\n", "\n", "all_latents = torch.cat(all_latents) # [N, 4, 32, 32]\n", "all_text_embs = torch.cat(all_text_embs) # [N, 77, 768]\n", "\n", "print(f\"\u2705 Pre-encoded {len(all_latents)} samples\")\n", "print(f\" Latents: {all_latents.shape} | range [{all_latents.min():.2f}, {all_latents.max():.2f}]\")\n", "print(f\" Text: {all_text_embs.shape}\")\n", "\n", "# \u2500\u2500\u2500 Free CLIP and VAE encoder from GPU to save VRAM \u2500\u2500\u2500\n", "text_encoder.cpu()\n", "# Keep sd_vae on GPU for decode during visualization\n", "torch.cuda.empty_cache()\n", "print(f\"\u2705 Freed ~600MB VRAM (CLIP moved to CPU)\")\n", "\n", "# \u2500\u2500\u2500 Show VAE reconstruction quality \u2500\u2500\u2500\n", "fig, axes = plt.subplots(2, 6, figsize=(18, 6))\n", "sample_imgs, _ = next(iter(DataLoader(train_dataset, batch_size=6, shuffle=True)))\n", "sample_imgs = sample_imgs.to(device)\n", "sample_z = vae_encode(sample_imgs)\n", "sample_recon = vae_decode(sample_z)\n", "\n", "for i in range(6):\n", " axes[0, i].imshow(sample_imgs[i].cpu().permute(1,2,0).numpy()*0.5+0.5)\n", " axes[0, i].set_title(\"Original\", fontsize=9)\n", " axes[0, i].axis(\"off\")\n", " axes[1, i].imshow(sample_recon[i].cpu().clamp(-1,1).permute(1,2,0).numpy()*0.5+0.5)\n", " axes[1, i].set_title(\"SD-VAE Recon\", fontsize=9)\n", " axes[1, i].axis(\"off\")\n", "plt.suptitle(\"Pre-trained SD-VAE Reconstruction (near-perfect)\", fontsize=13)\n", "plt.tight_layout()\n", "plt.show()" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 6. Create IRIS Generator", "", "Now we create the IRIS generator that works in the SD-VAE latent space.", "- `latent_channels=4` (SD-VAE standard)", "- `latent_spatial=32` (256px / 8)", "- No VAE training needed \u2014 we just train the denoiser!" 
] }, { "cell_type": "code", "metadata": {}, "source": [ "# Create IRIS-Tiny (best for free-tier)\n", "# patch_size=4 reduces tokens from 256 to 64 \u2192 4\u00d7 faster training\n", "config = IRISConfig(\n", " latent_channels=4, # SD-VAE standard\n", " latent_spatial=32, # 256px / 8\n", " hidden_dim=384,\n", " num_heads=6,\n", " head_dim=64,\n", " ffn_ratio=2.667,\n", " num_prelude_blocks=1,\n", " num_core_layers=2, # 2 layers (speed vs quality tradeoff for demo)\n", " num_coda_blocks=1,\n", " default_iterations=4,\n", " max_iterations=16,\n", " fourier_num_blocks=6,\n", " sparsity_threshold=0.01,\n", " recurrence_dim=192,\n", " manhattan_window=8,\n", " text_dim=768,\n", " max_text_tokens=77,\n", " patch_size=4, # 4\u00d7 larger patches \u2192 64 tokens instead of 256\n", ")\n", "\n", "iris = IRIS(config).to(device)\n", "gen_params = sum(p.numel() for p in iris.generator.parameters())\n", "core_params = sum(p.numel() for p in iris.generator.core.parameters())\n", "\n", "print(f\"IRIS Generator: {gen_params:,} params ({gen_params*2/1024/1024:.1f} MB fp16)\")\n", "print(f\" Core (shared): {core_params:,} ({core_params/gen_params*100:.1f}%)\")\n", "print(f\" Tokens: {config.num_patches} (from {config.latent_spatial}\u00d7{config.latent_spatial} latent, patch_size={config.patch_size})\")\n", "print(f\" Input: [B, 4, 32, 32] latent \u2192 Output: [B, 4, 32, 32] velocity\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 7. Train IRIS Generator (Rectified Flow)", "", "The main training loop. Since everything is pre-cached, each epoch is **pure generator training** \u2014 no VAE encoding, no CLIP forward passes." ] }, { "cell_type": "code", "metadata": {}, "source": [ "import time, math\n", "\n", "# \u2500\u2500\u2500 Cached DataLoader \u2500\u2500\u2500\n", "class CachedDataset(Dataset):\n", " def __init__(self, latents, text_embs):\n", " self.latents = latents\n", " self.text_embs = text_embs\n", " def __len__(self): return len(self.latents)\n", " def __getitem__(self, i): return self.latents[i], self.text_embs[i]\n", "\n", "cached_loader = DataLoader(\n", " CachedDataset(all_latents, all_text_embs),\n", " batch_size=BATCH_SIZE, shuffle=True, num_workers=2,\n", " pin_memory=True, drop_last=True,\n", ")\n", "\n", "# \u2500\u2500\u2500 Training Config \u2500\u2500\u2500\n", "EPOCHS = 200\n", "LR = 2e-4\n", "GRAD_ACCUM = 2\n", "\n", "optimizer = torch.optim.AdamW(iris.generator.parameters(), lr=LR, weight_decay=0.03, betas=(0.9, 0.95))\n", "total_steps = EPOCHS * len(cached_loader) // GRAD_ACCUM\n", "warmup = min(200, total_steps // 10)\n", "\n", "scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda s: \n", " s/max(1,warmup) if s < warmup else 0.5*(1+math.cos(math.pi*(s-warmup)/max(1,total_steps-warmup))))\n", "scaler = torch.amp.GradScaler('cuda')\n", "\n", "print(f\"Training for {EPOCHS} epochs ({total_steps} optimizer steps)\")\n", "print(f\"Batch: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} accum = {BATCH_SIZE*GRAD_ACCUM} effective\")\n", "print(f\"Iterations per step: random from [2, 3, 4]\")\n", "print()\n", "\n", "# \u2500\u2500\u2500 Training Loop \u2500\u2500\u2500\n", "losses = []\n", "iris.generator.train()\n", "best_loss = float('inf')\n", "global_step = 0\n", "\n", "pbar = tqdm(range(EPOCHS), desc=\"Training\")\n", "for epoch in pbar:\n", " epoch_loss = 0\n", " n = 0\n", " optimizer.zero_grad(set_to_none=True)\n", "\n", " for batch_idx, (z_0, text_emb) in enumerate(cached_loader):\n", " z_0 = z_0.to(device, 
non_blocking=True)\n", " text_emb = text_emb.to(device, non_blocking=True)\n", "\n", " with torch.amp.autocast('cuda', dtype=torch.float16):\n", " r = [2, 3, 4][torch.randint(0, 3, (1,)).item()]\n", " result = iris.train_step_latent(z_0, text_emb, num_iterations=r)\n", " loss = result[\"loss\"] / GRAD_ACCUM\n", "\n", " scaler.scale(loss).backward()\n", "\n", " if (batch_idx + 1) % GRAD_ACCUM == 0:\n", " scaler.unscale_(optimizer)\n", " torch.nn.utils.clip_grad_norm_(iris.generator.parameters(), 1.0)\n", " scaler.step(optimizer)\n", " scaler.update()\n", " optimizer.zero_grad(set_to_none=True)\n", " scheduler.step()\n", " global_step += 1\n", "\n", " epoch_loss += result[\"velocity_loss\"]\n", " n += 1\n", "\n", " avg = epoch_loss / n\n", " losses.append(avg)\n", " if avg < best_loss:\n", " best_loss = avg\n", " pbar.set_postfix(loss=f\"{avg:.4f}\", best=f\"{best_loss:.4f}\", lr=f\"{optimizer.param_groups[0]['lr']:.1e}\")\n", "\n", "print(f\"\\n\u2705 Training complete! Best loss: {best_loss:.4f}\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 8. Generate Images!" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Reload CLIP on GPU for prompt encoding\n", "text_encoder.to(device)\n", "\n", "prompts = [\n", " \"a fire-breathing dragon pokemon\",\n", " \"a cute blue water pokemon with bubbles\",\n", " \"a green grass-type pokemon with leaves\",\n", " \"a yellow electric pokemon with lightning bolts\",\n", "]\n", "\n", "iris.eval()\n", "fig, axes = plt.subplots(len(prompts), 4, figsize=(16, len(prompts)*4))\n", "iter_counts = [2, 3, 4, 6]\n", "\n", "for row, prompt in enumerate(prompts):\n", " text_emb = encode_text([prompt])\n", " for col, r in enumerate(iter_counts):\n", " z = iris.generate_latent(text_emb, num_steps=4, num_iterations=r, cfg_scale=1.0, seed=42)\n", " img = vae_decode(z)\n", " img_np = img[0].cpu().clamp(-1, 1).permute(1, 2, 0).numpy() * 0.5 + 0.5\n", " axes[row, col].imshow(img_np)\n", " axes[row, col].axis(\"off\")\n", " if row == 0:\n", " axes[row, col].set_title(f\"r={r} iterations\", fontsize=11)\n", " # axis('off') hides ylabel text, so draw the prompt next to the first column instead\n", " axes[row, 0].text(-0.05, 0.5, prompt[:30], transform=axes[row, 0].transAxes, fontsize=9, ha='right', va='center')\n", "\n", "plt.suptitle(\"IRIS Generated Images (Adaptive Compute)\", fontsize=14, y=1.01)\n", "plt.tight_layout()\n", "plt.show()\n", "\n", "print(\"Note: ~800 training images \u2192 noisy outputs. This validates that the architecture works.\")\n", "print(\"Scale up with CC3M/CC12M + more epochs for production quality.\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 9. 
Training Loss & Checkpoint" ] }, { "cell_type": "code", "metadata": {}, "source": [ "# Loss curve\n", "plt.figure(figsize=(10, 4))\n", "plt.plot(losses, color='green', alpha=0.7)\n", "plt.plot([sum(losses[max(0,i-10):i+1])/min(i+1,10) for i in range(len(losses))], \n", " color='green', linewidth=2, label='Moving Avg (10)')\n", "plt.xlabel(\"Epoch\")\n", "plt.ylabel(\"Velocity Loss\")\n", "plt.title(\"IRIS Generator Training Loss\")\n", "plt.legend()\n", "plt.grid(True, alpha=0.3)\n", "plt.show()\n", "\n", "# Save checkpoint\n", "import os\n", "os.makedirs(\"iris_checkpoint\", exist_ok=True)\n", "torch.save({\n", " \"config\": config,\n", " \"generator_state_dict\": iris.generator.state_dict(),\n", " \"best_loss\": best_loss,\n", " \"losses\": losses,\n", "}, \"iris_checkpoint/iris_gen.pt\")\n", "print(f\"\u2705 Saved: iris_checkpoint/iris_gen.pt ({os.path.getsize('iris_checkpoint/iris_gen.pt')/1024/1024:.1f} MB)\")" ], "outputs": [], "execution_count": null }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 10. Scaling Up", "", "| Scale | Dataset | Model | GPU | Expected Quality |", "|-------|---------|-------|-----|-----------------|", "| **This notebook** | Pok\u00e9mon (833) | IRIS-Tiny | T4 free | Proof of concept |", "| **Hobby** | CC3M (3M) | IRIS-Small | A100 40GB | Decent |", "| **Production** | CC12M + LAION | IRIS-Base | 4\u00d7A100 | High quality |", "", "For **Kaggle** dual-T4: just enable `GPU T4 \u00d72` and run as-is. DataParallel is automatic for larger models.", "", "For **512px generation**: change `IMAGE_SIZE=512` and `latent_spatial=64`. Everything else stays the same." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---", "*[asdf98/IRIS-architecture](https://huggingface.co/asdf98/IRIS-architecture)*" ] } ] }