#!/usr/bin/env python3
"""
IRIS Colab Training — One-Click, Real Dataset, Real Learning
=============================================================

Copy-paste into Google Colab (free tier T4) and run all cells.
Trains IRIS on Pokemon BLIP Captions (833 images + text).

Colab free tier specs (2025):
- GPU: NVIDIA T4 (16 GB VRAM)
- System RAM: ~12.7 GB
- Disk: ~78 GB
- PyTorch: 2.5+ preinstalled
- Runtime: ~12 hours max session

What this script does:
1. Installs dependencies (~30s)
2. Downloads IRIS source from HF Hub
3. Downloads DC-AE encoder (1.2 GB) + text encoder (87 MB)
4. Encodes all 833 Pokemon images to latents (~2 min on T4)
5. Encodes all captions to text embeddings (~5s)
6. Frees encoder VRAM
7. Trains IRIS-Small (40M params) for 3000 steps (~15 min on T4)
8. Generates sample images from trained model
9. Saves checkpoint

Total wall time: ~20 minutes for a trained model.
"""

# ============================================================
# CELL 1: Install dependencies
# ============================================================
print("Installing dependencies...")
import subprocess, sys

subprocess.check_call([
    sys.executable, "-m", "pip", "install", "-q",
    "diffusers>=0.32.0",
    "sentence-transformers",
    "datasets",
    "accelerate",
    "huggingface_hub",
])
print("Done.")

# ============================================================
# CELL 2: Download IRIS source code
# ============================================================
print("Downloading IRIS architecture from HF Hub...")
from huggingface_hub import snapshot_download
import os, shutil

iris_path = snapshot_download(
    repo_id="asdf98/iris-image-gen",
    allow_patterns=["iris/*.py"],
    local_dir="./iris_repo",
)

# Add to Python path
sys.path.insert(0, iris_path)
print(f"IRIS source at: {iris_path}")

# Verify import
from iris import IRIS, get_model_config, flow_matching_loss, euler_sample
from iris.flow_matching import DCAE_F32C32_SCALE
print("IRIS imported successfully.")
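# ------------------------------------------------------------
# OPTIONAL: what "flow matching" means here (illustrative only)
# ------------------------------------------------------------
# A minimal, self-contained sketch of one common rectified-flow
# objective: interpolate between clean latents x0 and noise x1,
# then regress the constant velocity x1 - x0. The real
# iris.flow_matching.flow_matching_loss may differ (it adds
# logit-normal timestep sampling and multi-iteration refinement);
# the `velocity_model` callable below is a toy placeholder.
import torch
import torch.nn.functional as F

def toy_flow_matching_loss(velocity_model, x0):
    """x0: clean latents (B, C, H, W). Returns a scalar MSE loss."""
    noise = torch.randn_like(x0)                            # x1 ~ N(0, I)
    t = torch.rand(x0.shape[0], 1, 1, 1, device=x0.device)  # uniform timesteps
    xt = (1 - t) * x0 + t * noise                           # linear interpolation
    target_v = noise - x0                                   # constant target velocity
    return F.mse_loss(velocity_model(xt, t), target_v)

# e.g.: toy_flow_matching_loss(lambda x, t: torch.zeros_like(x),
#                              torch.randn(4, 32, 16, 16))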
# ============================================================
# CELL 3: Detect hardware
# ============================================================
import torch
import gc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
else:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("In Colab: Runtime -> Change runtime type -> T4 GPU")

use_amp = device.type == "cuda"

# T4 (compute capability 7.5) reports bf16 as supported, but its cuDNN
# conv2d kernels lack bf16 engines and crash at runtime. Force fp16,
# which T4 natively supports.
if use_amp:
    cc = torch.cuda.get_device_capability(0)
    if cc[0] < 8:  # Ampere (8.0+) has native bf16; Turing (7.5) does not
        amp_dtype = torch.float16
        print(f"GPU compute capability {cc[0]}.{cc[1]} — using fp16 (bf16 conv kernels unavailable)")
    else:
        amp_dtype = torch.bfloat16
        print(f"GPU compute capability {cc[0]}.{cc[1]} — using bf16")
else:
    amp_dtype = torch.float32
print(f"AMP dtype: {amp_dtype}")

# ============================================================
# CELL 4: Load dataset
# ============================================================
print("\nLoading Pokemon BLIP Captions dataset...")
from datasets import load_dataset

ds = load_dataset("reach-vb/pokemon-blip-captions", split="train")
print(f"Loaded {len(ds)} images with captions.")
print(f"Example: '{ds[0]['text']}'")

# ============================================================
# CELL 5: Encode all images to DC-AE latents
# ============================================================
print("\nLoading DC-AE encoder (~1.2 GB)...")
from diffusers import AutoencoderDC
import torchvision.transforms as T

# Use float16 to save VRAM — stable for inference
ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
).to(device).eval()
ae.requires_grad_(False)
SCALE = ae.config.scaling_factor  # 0.41407

transform = T.Compose([
    T.Resize(512, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.CenterCrop(512),
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

print("Encoding images to latents...")
all_latents = []
import time
t0 = time.time()

batch_imgs = []
for i, example in enumerate(ds):
    img = example["image"].convert("RGB")
    tensor = transform(img)
    batch_imgs.append(tensor)

    # Process in batches of 8
    if len(batch_imgs) == 8 or i == len(ds) - 1:
        batch = torch.stack(batch_imgs).to(device, dtype=torch.float16)
        with torch.no_grad():
            latent = ae.encode(batch).latent.float()  # encode in fp16, store in fp32
        all_latents.append(latent.cpu())
        batch_imgs = []

    if (i + 1) % 100 == 0 or i == len(ds) - 1:
        print(f"  Encoded {i+1}/{len(ds)} images ({time.time()-t0:.1f}s)")

all_latents = torch.cat(all_latents, dim=0)  # (N, 32, 16, 16)
print(f"All latents: {all_latents.shape}, took {time.time()-t0:.1f}s")
print(f"Latent stats: mean={all_latents.mean():.3f}, std={all_latents.std():.3f}")

# Free DC-AE VRAM
del ae
torch.cuda.empty_cache()
gc.collect()
print("DC-AE encoder freed from VRAM.")

# ============================================================
# CELL 6: Encode all captions to text embeddings
# ============================================================
print("\nLoading text encoder (~87 MB)...")
from sentence_transformers import SentenceTransformer

text_encoder = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)
text_encoder.eval()

captions = [ex["text"] for ex in ds]
print(f"Encoding {len(captions)} captions...")
with torch.no_grad():
    all_text_embs = text_encoder.encode(
        captions,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=128,
        show_progress_bar=True,
    )

# Expand to sequence format: (N, 1, 384).
# The model projects 384 -> model_dim via its registered context_proj.
all_text_embs = all_text_embs.unsqueeze(1).cpu()  # (N, 1, 384)
print(f"Text embeddings: {all_text_embs.shape}")

# Free text encoder VRAM
del text_encoder
torch.cuda.empty_cache()
gc.collect()
print("Text encoder freed from VRAM.")
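# ------------------------------------------------------------
# OPTIONAL: cache precomputed features to disk (illustrative)
# ------------------------------------------------------------
# A minimal sketch for skipping Cells 5-6 after a Colab runtime
# restart; CACHE is an arbitrary path chosen for this example.
CACHE = "./pokemon_features.pt"
if not os.path.exists(CACHE):
    torch.save({"latents": all_latents, "text_embs": all_text_embs}, CACHE)
    print(f"Cached features to {CACHE}")
# In a fresh session, load instead of re-encoding:
#   feats = torch.load(CACHE)
#   all_latents, all_text_embs = feats["latents"], feats["text_embs"]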
# ============================================================
# CELL 7: Create dataset from precomputed features
# ============================================================
from torch.utils.data import Dataset, DataLoader

class PrecomputedLatentDataset(Dataset):
    """All latents and text embeddings precomputed — zero I/O during training."""

    def __init__(self, latents, text_embs):
        self.latents = latents
        self.text_embs = text_embs

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        return {
            "latent": self.latents[idx],
            "text_embed": self.text_embs[idx],
        }

train_ds = PrecomputedLatentDataset(all_latents, all_text_embs)
print(f"Training dataset: {len(train_ds)} samples")
print(f"  Latent: {train_ds[0]['latent'].shape}")
print(f"  Text: {train_ds[0]['text_embed'].shape}")

# ============================================================
# CELL 8: Create IRIS model
# ============================================================
print("\nCreating IRIS-Small model...")

model = IRIS(
    **get_model_config("iris-small"),
    gradient_checkpointing=True,
    text_dim=384,  # all-MiniLM-L6-v2 output dim — registered as a proper nn.Module
).to(device)

counts = model.count_params()
print(f"Parameters: {counts['total']:,} ({counts['total']/1e6:.1f}M)")
print(f"  Core: {counts['core']:,}")
print(f"  Decoder: {counts['tiny_decoder']:,}")

if device.type == "cuda":
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB / "
          f"{torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

# ============================================================
# CELL 9: Train!
# ============================================================
import math
from iris.train import CosineWarmupScheduler
from iris.flow_matching import flow_matching_loss

# Training config — tuned for Colab T4 with 833 Pokemon images
NUM_STEPS = 3000     # ~15 min on T4
BATCH_SIZE = 16      # fits T4 with IRIS-Small + gradient checkpointing
LR = 3e-4            # slightly higher LR for a small dataset
WARMUP_STEPS = 200
GRAD_CLIP = 1.0
NUM_ITERS = 3        # refinement iterations (3 is a good speed/quality tradeoff)
LOG_EVERY = 50
SAVE_EVERY = 1000

loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=0.01,
    betas=(0.9, 0.999),
)
scheduler = CosineWarmupScheduler(optimizer, WARMUP_STEPS, NUM_STEPS, min_lr_ratio=0.05)
scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))
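# ------------------------------------------------------------
# OPTIONAL: one-batch smoke test (illustrative)
# ------------------------------------------------------------
# A hedged sanity check before the full run: one forward pass on
# 4 samples, using the same flow_matching_loss call as the training
# loop below, to surface shape/dtype/OOM errors early.
with torch.no_grad():
    _lat = torch.stack([train_ds[j]["latent"] for j in range(4)]).to(device)
    _txt = torch.stack([train_ds[j]["text_embed"] for j in range(4)]).to(device)
    with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
        _losses = flow_matching_loss(
            model, _lat, _txt,
            num_iterations=NUM_ITERS,
            timestep_sampling="logit_normal",
            scale_factor=SCALE,
        )
print(f"Smoke-test loss: {_losses['loss'].item():.4f}")
del _lat, _txt, _losses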
model.train()
step = 0
epoch = 0
running_loss = 0.0
loss_history = []
best_loss = float("inf")
t_start = time.time()

print(f"\n{'='*60}")
print(f"Training IRIS-Small on Pokemon BLIP Captions")
print(f"  {len(train_ds)} images, {NUM_STEPS} steps, BS={BATCH_SIZE}, R={NUM_ITERS}")
print(f"  LR={LR}, warmup={WARMUP_STEPS}, AMP={amp_dtype}")
print(f"{'='*60}\n")

while step < NUM_STEPS:
    epoch += 1
    for batch in loader:
        if step >= NUM_STEPS:
            break

        latent = batch["latent"].to(device, non_blocking=True)
        text_embed = batch["text_embed"].to(device, non_blocking=True)

        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            losses = flow_matching_loss(
                model,
                latent,
                text_embed,
                num_iterations=NUM_ITERS,
                timestep_sampling="logit_normal",
                scale_factor=SCALE,
            )
            loss = losses["loss"]

        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()
        scheduler.step()
        step += 1

        lv = loss.item()
        running_loss += lv
        loss_history.append(lv)

        if step % LOG_EVERY == 0:
            avg = running_loss / LOG_EVERY
            elapsed = time.time() - t_start
            sps = step / elapsed
            eta = (NUM_STEPS - step) / sps
            lr = scheduler.get_lr()[0]
            gn_val = gn.item() if isinstance(gn, torch.Tensor) else gn
            tag = "OK" if not (math.isnan(avg) or math.isinf(avg)) else "!!"
            print(
                f"[{tag}] step {step:>5d}/{NUM_STEPS} | "
                f"loss={avg:.4f} | "
                f"grad={gn_val:.3f} | "
                f"lr={lr:.1e} | "
                f"{sps:.1f} steps/s | "
                f"ETA {eta/60:.0f}min"
            )
            if avg < best_loss:
                best_loss = avg
            running_loss = 0.0

        if step % SAVE_EVERY == 0:
            os.makedirs("./iris_checkpoints", exist_ok=True)
            p = f"./iris_checkpoints/iris_pokemon_step{step}.pt"
            torch.save({
                "step": step,
                "model_state_dict": model.state_dict(),
                "loss_history": loss_history,
                "config": get_model_config("iris-small"),
            }, p)
            print(f"  Saved: {p}")

# Final save
os.makedirs("./iris_checkpoints", exist_ok=True)
final_path = "./iris_checkpoints/iris_pokemon_final.pt"
torch.save({
    "step": step,
    "model_state_dict": model.state_dict(),
    "loss_history": loss_history,
    "config": get_model_config("iris-small"),
}, final_path)

total_time = time.time() - t_start
f50 = sum(loss_history[:50]) / min(50, len(loss_history))
l50 = sum(loss_history[-50:]) / min(50, len(loss_history))
print(f"\n{'='*60}")
print(f"Training complete!")
print(f"  {step} steps in {total_time/60:.1f} min ({step/total_time:.1f} steps/s)")
print(f"  Loss: {f50:.4f} -> {l50:.4f} ({(1-l50/f50)*100:.1f}% reduction)")
print(f"  Best: {best_loss:.4f}")
print(f"  Saved: {final_path}")
print(f"{'='*60}")

# ============================================================
# CELL 10: Plot training loss
# ============================================================
try:
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(loss_history, alpha=0.3, color="blue", linewidth=0.5)
    window = 50
    if len(loss_history) > window:
        # Trailing moving average (slice fixed so each window has exactly
        # min(i+1, window) elements, matching the divisor)
        smoothed = [
            sum(loss_history[max(0, i - window + 1):i + 1]) / min(i + 1, window)
            for i in range(len(loss_history))
        ]
        ax1.plot(smoothed, color="red", linewidth=2, label=f"Smoothed (w={window})")
        ax1.legend()
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Flow Matching Loss")
    ax1.set_title("Training Loss")
    ax1.grid(True, alpha=0.3)

    chunks = [loss_history[i:i+100] for i in range(0, len(loss_history), 100)]
    if len(chunks) > 1:
        ax2.boxplot(chunks, positions=list(range(len(chunks))))
        ax2.set_xlabel("Step (x100)")
        ax2.set_ylabel("Loss")
        ax2.set_title("Loss Distribution Over Time")
        ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("./iris_checkpoints/training_loss.png", dpi=100)
    plt.show()
    print("Loss plot saved.")
except ImportError:
    print("matplotlib not available, skipping loss plot")
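# ------------------------------------------------------------
# OPTIONAL: resuming from a saved checkpoint (illustrative)
# ------------------------------------------------------------
# A minimal sketch, assuming Cells 1-3 and 8 were re-run in a
# fresh session; the keys match those written by Cell 9 above.
#
#   ckpt = torch.load("./iris_checkpoints/iris_pokemon_final.pt",
#                     map_location=device)
#   model.load_state_dict(ckpt["model_state_dict"])
#   print(f"Resumed from step {ckpt['step']}")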
# ============================================================
# CELL 11: Generate sample images from trained model
# ============================================================
print("\nGenerating sample images from trained model...")
from torchvision.utils import save_image

# Reload DC-AE decoder for visualization
ae_decoder = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
).to(device).eval()
ae_decoder.requires_grad_(False)

# Reload text encoder for new prompts
text_enc = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)

model.eval()

sample_prompts = [
    "a blue water pokemon with fins",
    "a fire dragon pokemon with wings",
    "a cute pink pokemon with big eyes",
    "a green grass pokemon",
]

for i, prompt in enumerate(sample_prompts):
    with torch.no_grad():
        txt_emb = text_enc.encode(
            [prompt], convert_to_tensor=True, normalize_embeddings=True
        ).unsqueeze(1).to(device)  # (1, 1, 384)

        noise = torch.randn(1, 32, 16, 16, device=device)
        z_pred = euler_sample(
            model,
            noise,
            txt_emb,
            num_steps=20,
            num_iterations=NUM_ITERS,
            cfg_scale=1.0,
            scale_factor=SCALE,
        )
        img = ae_decoder.decode(z_pred.half()).sample
        img = img.float().clamp(-1, 1) * 0.5 + 0.5

    fname = f"./iris_checkpoints/sample_{i}_{prompt[:20].replace(' ', '_')}.png"
    save_image(img, fname)
    print(f"  Sample {i}: '{prompt}' -> {fname}")

print("\nAll samples saved to ./iris_checkpoints/")
print("NOTE: Trained on 833 images for 3000 steps — quality improves with more data + steps.")
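# ------------------------------------------------------------
# OPTIONAL: download results from Colab (illustrative)
# ------------------------------------------------------------
# A minimal sketch; google.colab exists only inside Colab, so
# this degrades gracefully elsewhere.
try:
    from google.colab import files  # Colab-only helper
    shutil.make_archive("iris_run", "zip", "./iris_checkpoints")
    files.download("iris_run.zip")
except ImportError:
    print("Not running in Colab; outputs are in ./iris_checkpoints/")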