"""
IRIS Colab Training — One-Click, Real Dataset, Real Learning
=============================================================

Copy-paste into Google Colab (free tier T4) and run all cells.
Trains IRIS on Pokemon BLIP Captions (833 images + text).

Colab free tier specs (2025):
- GPU: NVIDIA T4 (16 GB VRAM)
- System RAM: ~12.7 GB
- Disk: ~78 GB
- PyTorch: 2.5+ preinstalled
- Runtime: ~12 hours max session

What this script does:
1. Installs dependencies (~30s)
2. Downloads IRIS source from HF Hub
3. Downloads DC-AE encoder (1.2 GB) + text encoder (87 MB)
4. Encodes all 833 Pokemon images to latents (~2 min on T4)
5. Encodes all captions to text embeddings (~5s)
6. Frees encoder VRAM
7. Trains IRIS-Small (40M params) for 3000 steps (~15 min on T4)
8. Generates sample images from the trained model
9. Saves a checkpoint

Total wall time: ~20 minutes for a trained model.
"""
|
|
| print("Installing dependencies...") |
| import subprocess, sys |
|
|
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
    "diffusers>=0.32.0",
    "sentence-transformers",
    "datasets",
    "accelerate",
    "huggingface_hub",
])
print("Done.")
|
|
| print("Downloading IRIS architecture from HF Hub...") |
| from huggingface_hub import snapshot_download |
| import os, shutil |
|
|
iris_path = snapshot_download(
    repo_id="asdf98/iris-image-gen",
    allow_patterns=["iris/*.py"],
    local_dir="./iris_repo",
)

sys.path.insert(0, iris_path)
print(f"IRIS source at: {iris_path}")
|
|
from iris import IRIS, get_model_config, flow_matching_loss, euler_sample
print("IRIS imported successfully.")
|
|
import torch
import gc
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
else:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("In Colab: Runtime -> Change runtime type -> T4 GPU")
|
|
use_amp = device.type == "cuda"

if use_amp:
    cc = torch.cuda.get_device_capability(0)
    if cc[0] < 8:
        amp_dtype = torch.float16
        print(f"GPU compute capability {cc[0]}.{cc[1]} — using fp16 (no native bf16 support before Ampere)")
    else:
        amp_dtype = torch.bfloat16
        print(f"GPU compute capability {cc[0]}.{cc[1]} — using bf16")
else:
    amp_dtype = torch.float32
print(f"AMP dtype: {amp_dtype}")
|
|
| print("\nLoading Pokemon BLIP Captions dataset...") |
| from datasets import load_dataset |
|
|
ds = load_dataset("reach-vb/pokemon-blip-captions", split="train")
print(f"Loaded {len(ds)} images with captions.")
print(f"Example: '{ds[0]['text']}'")
|
|
| print("\nLoading DC-AE encoder (~1.2 GB)...") |
| from diffusers import AutoencoderDC |
| import torchvision.transforms as T |
|
|
ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
).to(device).eval()
ae.requires_grad_(False)
|
|
SCALE = ae.config.scaling_factor
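# scaling_factor comes from the DC-AE config; for diffusers autoencoders it is
# chosen so that SCALE * latent has roughly unit variance. We hand it to
# flow_matching_loss / euler_sample below, which presumably apply it internally.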
|
|
transform = T.Compose([
    T.Resize(512, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.CenterCrop(512),
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
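# ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], the range
# image autoencoders in diffusers conventionally expect. The decode step at the
# end of this script applies the inverse (clamp(-1, 1) * 0.5 + 0.5) to get back
# to [0, 1] for saving.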
|
|
| print("Encoding images to latents...") |
| all_latents = [] |
| import time |
| t0 = time.time() |
|
|
batch_imgs = []
for i, example in enumerate(ds):
    img = example["image"].convert("RGB")
    tensor = transform(img)
    batch_imgs.append(tensor)

    # Encode in batches of 8 (and flush the remainder on the last image).
    if len(batch_imgs) == 8 or i == len(ds) - 1:
        batch = torch.stack(batch_imgs).to(device, dtype=torch.float16)
        with torch.no_grad():
            latent = ae.encode(batch).latent.float()
        all_latents.append(latent.cpu())
        batch_imgs = []

    if (i + 1) % 100 == 0 or i == len(ds) - 1:
        print(f"  Encoded {i+1}/{len(ds)} images ({time.time()-t0:.1f}s)")

all_latents = torch.cat(all_latents, dim=0)
print(f"All latents: {all_latents.shape}, took {time.time()-t0:.1f}s")
print(f"Latent stats: mean={all_latents.mean():.3f}, std={all_latents.std():.3f}")
|
|
del ae
torch.cuda.empty_cache()
gc.collect()
print("DC-AE encoder freed from VRAM.")
|
|
| print("\nLoading text encoder (~87 MB)...") |
| from sentence_transformers import SentenceTransformer |
|
|
text_encoder = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)
text_encoder.eval()
|
|
captions = [ex["text"] for ex in ds]
print(f"Encoding {len(captions)} captions...")
|
|
with torch.no_grad():
    all_text_embs = text_encoder.encode(
        captions,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=128,
        show_progress_bar=True,
    )
|
|
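# encode() returns (N, 384); unsqueeze(1) makes it (N, 1, 384), i.e. a
# length-1 token sequence, which is presumably the cross-attention context
# shape IRIS expects (it matches text_dim=384 passed to the model below).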
all_text_embs = all_text_embs.unsqueeze(1).cpu()
print(f"Text embeddings: {all_text_embs.shape}")
|
|
del text_encoder
torch.cuda.empty_cache()
gc.collect()
print("Text encoder freed from VRAM.")
|
|
from torch.utils.data import Dataset, DataLoader
|
|
class PrecomputedLatentDataset(Dataset):
    """All latents and text embeddings precomputed — zero I/O during training."""

    def __init__(self, latents, text_embs):
        self.latents = latents
        self.text_embs = text_embs

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        return {
            "latent": self.latents[idx],
            "text_embed": self.text_embs[idx],
        }


train_ds = PrecomputedLatentDataset(all_latents, all_text_embs)
print(f"Training dataset: {len(train_ds)} samples")
print(f"  Latent: {train_ds[0]['latent'].shape}")
print(f"  Text: {train_ds[0]['text_embed'].shape}")
|
|
| print("\nCreating IRIS-Small model...") |
|
|
model = IRIS(
    **get_model_config("iris-small"),
    gradient_checkpointing=True,
    text_dim=384,
).to(device)
|
|
counts = model.count_params()
print(f"Parameters: {counts['total']:,} ({counts['total']/1e6:.1f}M)")
print(f"  Core: {counts['core']:,}")
print(f"  Decoder: {counts['tiny_decoder']:,}")
|
|
if device.type == "cuda":
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
|
|
import math
from iris.train import CosineWarmupScheduler
from iris.flow_matching import flow_matching_loss
|
|
NUM_STEPS = 3000
BATCH_SIZE = 16
LR = 3e-4
WARMUP_STEPS = 200
GRAD_CLIP = 1.0
NUM_ITERS = 3  # passed as num_iterations to flow_matching_loss / euler_sample
LOG_EVERY = 50
SAVE_EVERY = 1000
|
|
loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
)
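# With the whole dataset in RAM, the workers only collate batches.
# persistent_workers=True avoids respawning them at every epoch boundary,
# which matters here: 833 samples / batch 16 (drop_last) is only ~52 steps
# per epoch, so epochs turn over quickly.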
|
|
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=0.01,
    betas=(0.9, 0.999),
)
scheduler = CosineWarmupScheduler(optimizer, WARMUP_STEPS, NUM_STEPS, min_lr_ratio=0.05)
scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))
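# GradScaler is only enabled for fp16: its narrow exponent range underflows
# small gradients, so the loss is scaled up before backward and gradients are
# unscaled before the optimizer step. bf16 shares fp32's exponent range and
# needs no scaling. The loop below calls scaler.unscale_() before
# clip_grad_norm_ so clipping sees true gradient magnitudes.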
|
|
model.train()
step = 0
epoch = 0
running_loss = 0.0
loss_history = []
best_loss = float("inf")
t_start = time.time()
|
|
| print(f"\n{'='*60}") |
| print(f"Training IRIS-Small on Pokemon BLIP Captions") |
| print(f" {len(train_ds)} images, {NUM_STEPS} steps, BS={BATCH_SIZE}, R={NUM_ITERS}") |
| print(f" LR={LR}, warmup={WARMUP_STEPS}, AMP={amp_dtype}") |
| print(f"{'='*60}\n") |
|
|
while step < NUM_STEPS:
    epoch += 1
    for batch in loader:
        if step >= NUM_STEPS:
            break

        latent = batch["latent"].to(device, non_blocking=True)
        text_embed = batch["text_embed"].to(device, non_blocking=True)

        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            losses = flow_matching_loss(
                model, latent, text_embed,
                num_iterations=NUM_ITERS,
                timestep_sampling="logit_normal",
                scale_factor=SCALE,
            )
            loss = losses["loss"]

        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()

        scheduler.step()
        step += 1
        lv = loss.item()
        running_loss += lv
        loss_history.append(lv)

        if step % LOG_EVERY == 0:
            avg = running_loss / LOG_EVERY
            elapsed = time.time() - t_start
            sps = step / elapsed
            eta = (NUM_STEPS - step) / sps
            lr = scheduler.get_lr()[0]
            gn_val = gn.item() if isinstance(gn, torch.Tensor) else gn
            tag = "OK" if math.isfinite(avg) else "!!"

            print(
                f"[{tag}] step {step:>5d}/{NUM_STEPS} | "
                f"loss={avg:.4f} | "
                f"grad={gn_val:.3f} | "
                f"lr={lr:.1e} | "
                f"{sps:.1f} steps/s | "
                f"ETA {eta/60:.0f}min"
            )

            if avg < best_loss:
                best_loss = avg
            running_loss = 0.0

        if step % SAVE_EVERY == 0:
            os.makedirs("./iris_checkpoints", exist_ok=True)
            p = f"./iris_checkpoints/iris_pokemon_step{step}.pt"
            torch.save({
                "step": step,
                "model_state_dict": model.state_dict(),
                "loss_history": loss_history,
                "config": get_model_config("iris-small"),
            }, p)
            print(f"  Saved: {p}")
|
|
os.makedirs("./iris_checkpoints", exist_ok=True)
final_path = "./iris_checkpoints/iris_pokemon_final.pt"
torch.save({
    "step": step,
    "model_state_dict": model.state_dict(),
    "loss_history": loss_history,
    "config": get_model_config("iris-small"),
}, final_path)
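# To reload later (sketch; assumes the iris package is importable again and
# that text_dim=384 is re-passed, since the saved config may not include it):
#   ckpt = torch.load(final_path, map_location="cpu")
#   model = IRIS(**ckpt["config"], text_dim=384)
#   model.load_state_dict(ckpt["model_state_dict"])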
|
|
total_time = time.time() - t_start
# Compare mean loss over the first vs. last 50 logged steps.
f50 = sum(loss_history[:50]) / min(50, len(loss_history))
l50 = sum(loss_history[-50:]) / min(50, len(loss_history))
print(f"\n{'='*60}")
print("Training complete!")
print(f"  {step} steps in {total_time/60:.1f} min ({step/total_time:.1f} steps/s)")
print(f"  Loss: {f50:.4f} -> {l50:.4f} ({(1-l50/f50)*100:.1f}% reduction)")
print(f"  Best: {best_loss:.4f}")
print(f"  Saved: {final_path}")
print(f"{'='*60}")
|
|
try:
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(loss_history, alpha=0.3, color="blue", linewidth=0.5)
    window = 50
    if len(loss_history) > window:
        # Trailing moving average over exactly min(i+1, window) points
        # (slice starts at i - window + 1 so the count matches the divisor).
        smoothed = [
            sum(loss_history[max(0, i - window + 1):i + 1]) / min(i + 1, window)
            for i in range(len(loss_history))
        ]
        ax1.plot(smoothed, color="red", linewidth=2, label=f"Smoothed (w={window})")
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Flow Matching Loss")
    ax1.set_title("Training Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    chunks = [loss_history[i:i + 100] for i in range(0, len(loss_history), 100)]
    if len(chunks) > 1:
        ax2.boxplot(chunks, positions=list(range(len(chunks))))
    ax2.set_xlabel("Step (x100)")
    ax2.set_ylabel("Loss")
    ax2.set_title("Loss Distribution Over Time")
    ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("./iris_checkpoints/training_loss.png", dpi=100)
    plt.show()
    print("Loss plot saved.")
except ImportError:
    print("matplotlib not available, skipping loss plot")
|
|
| print("\nGenerating sample images from trained model...") |
|
|
ae_decoder = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
).to(device).eval()
ae_decoder.requires_grad_(False)
|
|
text_enc = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)
|
|
model.eval()
|
|
sample_prompts = [
    "a blue water pokemon with fins",
    "a fire dragon pokemon with wings",
    "a cute pink pokemon with big eyes",
    "a green grass pokemon",
]
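# euler_sample presumably integrates the learned velocity field from pure
# noise (t=0) to data (t=1) with 20 uniform Euler steps, z <- z + v(z, t) * dt;
# cfg_scale=1.0 means classifier-free guidance is effectively off. See
# iris.flow_matching for the authoritative implementation.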
|
|
from torchvision.utils import save_image

for i, prompt in enumerate(sample_prompts):
    with torch.no_grad():
        txt_emb = text_enc.encode(
            [prompt], convert_to_tensor=True, normalize_embeddings=True
        ).unsqueeze(1).to(device)

    # Latent noise: 32 channels at 16x16 (512 px / 32x downsampling).
    noise = torch.randn(1, 32, 16, 16, device=device)

    with torch.no_grad():
        z_pred = euler_sample(
            model, noise, txt_emb,
            num_steps=20,
            num_iterations=NUM_ITERS,
            cfg_scale=1.0,
            scale_factor=SCALE,
        )
        img = ae_decoder.decode(z_pred.half()).sample
        img = img.float().clamp(-1, 1) * 0.5 + 0.5

    fname = f"./iris_checkpoints/sample_{i}_{prompt[:20].replace(' ', '_')}.png"
    save_image(img, fname)
    print(f"  Sample {i}: '{prompt}' -> {fname}")
|
|
| print("\nAll samples saved to ./iris_checkpoints/") |
| print("NOTE: Trained on 833 images for 3000 steps — quality improves with more data + steps.") |
|
|