# iris-image-gen / colab_train_iris.py
# asdf98's picture
# Fix conv2d bf16 crash on T4: colab_train_iris.py
# e90110a verified
#!/usr/bin/env python3
"""
IRIS Colab Training — One-Click, Real Dataset, Real Learning
=============================================================
Copy-paste into Google Colab (free tier T4) and run all cells.
Trains IRIS on Pokemon BLIP Captions (833 images + text).
Colab free tier specs (2025):
- GPU: NVIDIA T4 (16 GB VRAM)
- System RAM: ~12.7 GB
- Disk: ~78 GB
- PyTorch: 2.5+ preinstalled
- Runtime: ~12 hours max session
What this script does:
1. Installs dependencies (~30s)
2. Downloads IRIS source from HF Hub
3. Downloads DC-AE encoder (1.2 GB) + text encoder (87 MB)
4. Encodes all 833 Pokemon images to latents (~2 min on T4)
5. Encodes all captions to text embeddings (~5s)
6. Frees encoder VRAM
7. Trains IRIS-Small (40M params) for 3000 steps (~15 min on T4)
8. Generates sample images from trained model
9. Saves checkpoint
Total wall time: ~20 minutes for a trained model.
"""
# ============================================================
# CELL 1: Install dependencies
# ============================================================
print("Installing dependencies...")
import subprocess
import sys

# Third-party packages imported by the later cells.
_pip_packages = [
    "diffusers>=0.32.0",
    "sentence-transformers",
    "datasets",
    "accelerate",
    "huggingface_hub",
]
# Run pip through the current interpreter so the packages are installed into
# the environment this script runs in. A non-zero pip exit status raises
# subprocess.CalledProcessError and stops the script here.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "-q", *_pip_packages]
)
print("Done.")
# ============================================================
# CELL 2: Download IRIS source code
# ============================================================
print("Downloading IRIS architecture from HF Hub...")
import os
import shutil

from huggingface_hub import snapshot_download

# Fetch only the Python sources under iris/ into ./iris_repo.
iris_path = snapshot_download(
    repo_id="asdf98/iris-image-gen",
    local_dir="./iris_repo",
    allow_patterns=["iris/*.py"],
)

# Put the download root first on sys.path so that `import iris` resolves to
# the files that were just fetched.
sys.path.insert(0, iris_path)
print(f"IRIS source at: {iris_path}")

# Import right away so a broken or incomplete download fails here rather
# than in a later cell.
from iris import IRIS, get_model_config, flow_matching_loss, euler_sample
from iris.flow_matching import DCAE_F32C32_SCALE

print("IRIS imported successfully.")
# ============================================================
# CELL 3: Detect hardware
# ============================================================
import gc

import torch

# Pick the compute device; mixed precision is only used on CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_amp = device.type == "cuda"

if not use_amp:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("In Colab: Runtime -> Change runtime type -> T4 GPU")
else:
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")

# Autocast dtype. float32 is the CPU fallback (autocast stays disabled there).
# T4 (compute capability 7.5) reports bf16 as supported, but the cuDNN conv2d
# kernels lack bf16 engines and crash at runtime, so every GPU below compute
# capability 8.0 is forced to fp16, which T4 supports natively.
amp_dtype = torch.float32
if use_amp:
    cc = torch.cuda.get_device_capability(0)
    if cc[0] >= 8:  # Ampere (8.0+) has native bf16
        amp_dtype = torch.bfloat16
        print(f"GPU compute capability {cc[0]}.{cc[1]} — using bf16")
    else:  # Turing (7.5) and older do not
        amp_dtype = torch.float16
        print(f"GPU compute capability {cc[0]}.{cc[1]} — using fp16 (bf16 conv kernels unavailable)")
print(f"AMP dtype: {amp_dtype}")
# ============================================================
# CELL 4: Load dataset
# ============================================================
print("\nLoading Pokemon BLIP Captions dataset...")
from datasets import load_dataset

# Only the "train" split is requested; each row is read below through its
# "image" and "text" fields.
_dataset_repo = "reach-vb/pokemon-blip-captions"
ds = load_dataset(_dataset_repo, split="train")
print(f"Loaded {len(ds)} images with captions.")
print(f"Example: '{ds[0]['text']}'")
# ============================================================
# CELL 5: Encode all images to DC-AE latents
# ============================================================
print("\nLoading DC-AE encoder (~1.2 GB)...")
import torchvision.transforms as T
from diffusers import AutoencoderDC

# Load the autoencoder in float16 to save VRAM — stable for inference.
ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
)
ae = ae.to(device).eval()
ae.requires_grad_(False)  # encoder is frozen; it is only used to produce latents
SCALE = ae.config.scaling_factor  # 0.41407

# Image preprocessing: resize the short side to 512, center-crop to 512x512,
# convert to a tensor and normalize each channel with mean 0.5 / std 0.5.
_half = [0.5, 0.5, 0.5]
transform = T.Compose([
    T.Resize(512, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.CenterCrop(512),
    T.ToTensor(),
    T.Normalize(_half, _half),
])
print("Encoding images to latents...")
import time

all_latents = []
batch_imgs = []
t0 = time.time()

for i, example in enumerate(ds):
    # Preprocess one image and queue it for the next encoder batch.
    batch_imgs.append(transform(example["image"].convert("RGB")))
    is_last = i == len(ds) - 1

    # Flush every 8 images, and once more for the final (possibly short) batch.
    if is_last or len(batch_imgs) == 8:
        batch = torch.stack(batch_imgs).to(device, dtype=torch.float16)
        with torch.no_grad():
            latent = ae.encode(batch).latent.float()  # encode in fp16, store in fp32
        all_latents.append(latent.cpu())  # keep results on the CPU side
        batch_imgs = []

    # Progress report every 100 images and at the very end.
    if is_last or (i + 1) % 100 == 0:
        print(f" Encoded {i+1}/{len(ds)} images ({time.time()-t0:.1f}s)")

all_latents = torch.cat(all_latents, dim=0)  # (N, 32, 16, 16)
print(f"All latents: {all_latents.shape}, took {time.time()-t0:.1f}s")
print(f"Latent stats: mean={all_latents.mean():.3f}, std={all_latents.std():.3f}")
# Free DC-AE VRAM.
del ae
# Collect first: any reference cycles still holding parameter tensors are
# only broken by the garbage collector, and their CUDA blocks must be back
# in PyTorch's caching allocator before empty_cache() can hand them to the
# driver. Calling empty_cache() first would leave that memory cached.
gc.collect()
torch.cuda.empty_cache()
print("DC-AE encoder freed from VRAM.")
# ============================================================
# CELL 6: Encode all captions to text embeddings
# ============================================================
print("\nLoading text encoder (~87 MB)...")
from sentence_transformers import SentenceTransformer

text_encoder = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)
text_encoder.eval()

# Read the caption column directly. Iterating `ds` row by row would also
# materialize the "image" field of every example just to throw it away.
captions = list(ds["text"])
print(f"Encoding {len(captions)} captions...")
with torch.no_grad():
    all_text_embs = text_encoder.encode(
        captions,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=128,
        show_progress_bar=True,
    )

# Expand to sequence format: (N, 1, 384)
# The model projects 384 -> model_dim via registered context_proj
all_text_embs = all_text_embs.unsqueeze(1).cpu()  # (N, 1, 384)
print(f"Text embeddings: {all_text_embs.shape}")
# Free text encoder VRAM.
del text_encoder
# Collect first: memory held by unreachable reference cycles only returns to
# PyTorch's caching allocator once the garbage collector has run, and
# empty_cache() can only release blocks that are already back in the cache.
gc.collect()
torch.cuda.empty_cache()
print("Text encoder freed from VRAM.")
# ============================================================
# CELL 7: Create dataset from precomputed features
# ============================================================
from torch.utils.data import Dataset, DataLoader
class PrecomputedLatentDataset(Dataset):
    """Pairs precomputed image latents with precomputed text embeddings.

    All latents and text embeddings are precomputed — zero I/O during
    training. Sample ``i`` is ``latents[i]`` together with ``text_embs[i]``.

    Args:
        latents: Indexable collection of latents (supports ``len`` and
            integer indexing, e.g. a tensor stacked along dim 0).
        text_embs: Indexable collection of text embeddings, aligned
            one-to-one with ``latents``.

    Raises:
        ValueError: If ``latents`` and ``text_embs`` differ in length.
    """

    def __init__(self, latents, text_embs):
        # Fail fast on misaligned inputs: otherwise the mismatch only shows
        # up mid-training as an IndexError, or trailing captions are
        # silently never used.
        if len(latents) != len(text_embs):
            raise ValueError(
                f"latents and text_embs must have the same length, "
                f"got {len(latents)} and {len(text_embs)}"
            )
        self.latents = latents
        self.text_embs = text_embs

    def __len__(self):
        """Return the number of (latent, text embedding) pairs."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict with keys "latent" and "text_embed"."""
        return {
            "latent": self.latents[idx],
            "text_embed": self.text_embs[idx],
        }
# Wrap the precomputed tensors and show the shape of one sample as a sanity check.
train_ds = PrecomputedLatentDataset(all_latents, all_text_embs)
print(f"Training dataset: {len(train_ds)} samples")
_first_sample = train_ds[0]
print(f" Latent: {_first_sample['latent'].shape}")
print(f" Text: {_first_sample['text_embed'].shape}")
# ============================================================
# CELL 8: Create IRIS model
# ============================================================
print("\nCreating IRIS-Small model...")
# Build the model from the "iris-small" preset plus two script-specific options.
model = IRIS(
    **get_model_config("iris-small"),
    gradient_checkpointing=True,
    text_dim=384,  # all-MiniLM-L6-v2 output dim — registered as proper nn.Module
)
model = model.to(device)

# Parameter summary as reported by the model itself.
counts = model.count_params()
print(f"Parameters: {counts['total']:,} ({counts['total']/1e6:.1f}M)")
print(f" Core: {counts['core']:,}")
print(f" Decoder: {counts['tiny_decoder']:,}")

# Memory report only makes sense on a CUDA device.
if device.type == "cuda":
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")
# ============================================================
# CELL 9: Train!
# ============================================================
import math

from iris.train import CosineWarmupScheduler
from iris.flow_matching import flow_matching_loss

# Training config — tuned for Colab T4 with 833 Pokemon images
NUM_STEPS = 3000     # ~15 min on T4
BATCH_SIZE = 16      # fits T4 with IRIS-Small + grad checkpoint
LR = 3e-4            # slightly higher LR for small dataset
WARMUP_STEPS = 200
GRAD_CLIP = 1.0
NUM_ITERS = 3        # refinement iterations (3 is good for speed/quality)
LOG_EVERY = 50
SAVE_EVERY = 1000

# Shuffled batches over the precomputed features; drop_last keeps every
# batch at exactly BATCH_SIZE samples.
loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
    num_workers=2,
    persistent_workers=True,
    pin_memory=True,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    betas=(0.9, 0.999),
    weight_decay=0.01,
)
scheduler = CosineWarmupScheduler(optimizer, WARMUP_STEPS, NUM_STEPS, min_lr_ratio=0.05)

# Loss scaling is enabled only when autocast runs in float16.
_needs_loss_scaling = use_amp and amp_dtype == torch.float16
scaler = torch.amp.GradScaler(enabled=_needs_loss_scaling)

# Mutable training state read and updated by the loop below.
model.train()
step = 0
epoch = 0
running_loss = 0.0   # sum of losses since the last log line
loss_history = []    # one entry per optimizer step
best_loss = float("inf")
t_start = time.time()

print(f"\n{'='*60}")
print(f"Training IRIS-Small on Pokemon BLIP Captions")
print(f" {len(train_ds)} images, {NUM_STEPS} steps, BS={BATCH_SIZE}, R={NUM_ITERS}")
print(f" LR={LR}, warmup={WARMUP_STEPS}, AMP={amp_dtype}")
print(f"{'='*60}\n")
# One optimizer update per batch. The DataLoader is re-iterated (one "epoch"
# per pass) until NUM_STEPS updates have been made.
while step < NUM_STEPS:
    epoch += 1
    for batch in loader:
        if step >= NUM_STEPS:
            break

        latent = batch["latent"].to(device, non_blocking=True)
        text_embed = batch["text_embed"].to(device, non_blocking=True)

        # Forward pass; autocast is disabled when use_amp is False.
        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            losses = flow_matching_loss(
                model, latent, text_embed,
                num_iterations=NUM_ITERS,
                timestep_sampling="logit_normal",
                scale_factor=SCALE,
            )
        loss = losses["loss"]

        optimizer.zero_grad(set_to_none=True)
        if not scaler.is_enabled():
            # Plain path: backward, clip, step.
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()
        else:
            # Loss-scaled path: gradients are unscaled before clipping so
            # that GRAD_CLIP applies to the true gradient norm.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        scheduler.step()
        step += 1

        lv = loss.item()
        loss_history.append(lv)
        running_loss += lv

        if step % LOG_EVERY == 0:
            avg = running_loss / LOG_EVERY
            elapsed = time.time() - t_start
            sps = step / elapsed
            eta = (NUM_STEPS - step) / sps
            lr = scheduler.get_lr()[0]
            gn_val = gn.item() if isinstance(gn, torch.Tensor) else gn
            # "!!" flags a NaN/inf average loss over the logging window.
            tag = "OK" if math.isfinite(avg) else "!!"
            print(
                f"[{tag}] step {step:>5d}/{NUM_STEPS} | "
                f"loss={avg:.4f} | "
                f"grad={gn_val:.3f} | "
                f"lr={lr:.1e} | "
                f"{sps:.1f} steps/s | "
                f"ETA {eta/60:.0f}min"
            )
            # min() keeps best_loss unchanged when avg is NaN.
            best_loss = min(best_loss, avg)
            running_loss = 0.0

        if step % SAVE_EVERY == 0:
            os.makedirs("./iris_checkpoints", exist_ok=True)
            p = f"./iris_checkpoints/iris_pokemon_step{step}.pt"
            checkpoint = {
                "step": step,
                "model_state_dict": model.state_dict(),
                "loss_history": loss_history,
                "config": get_model_config("iris-small"),
            }
            torch.save(checkpoint, p)
            print(f" Saved: {p}")
# Final save: same payload layout as the periodic checkpoints.
os.makedirs("./iris_checkpoints", exist_ok=True)
final_path = "./iris_checkpoints/iris_pokemon_final.pt"
_final_checkpoint = {
    "step": step,
    "model_state_dict": model.state_dict(),
    "loss_history": loss_history,
    "config": get_model_config("iris-small"),
}
torch.save(_final_checkpoint, final_path)

total_time = time.time() - t_start

# Mean loss over the first and the last (up to) 50 steps, used to report
# the relative improvement over the run.
_n_avg = min(50, len(loss_history))
f50 = sum(loss_history[:50]) / _n_avg
l50 = sum(loss_history[-50:]) / _n_avg

print(f"\n{'='*60}")
print(f"Training complete!")
print(f" {step} steps in {total_time/60:.1f} min ({step/total_time:.1f} steps/s)")
print(f" Loss: {f50:.4f} -> {l50:.4f} ({(1-l50/f50)*100:.1f}% reduction)")
print(f" Best: {best_loss:.4f}")
print(f" Saved: {final_path}")
print(f"{'='*60}")
# ============================================================
# CELL 10: Plot training loss
# ============================================================
try:
import matplotlib.pyplot as plt
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
ax1.plot(loss_history, alpha=0.3, color="blue", linewidth=0.5)
window = 50
if len(loss_history) > window:
smoothed = [sum(loss_history[max(0,i-window):i+1])/min(i+1, window) for i in range(len(loss_history))]
ax1.plot(smoothed, color="red", linewidth=2, label=f"Smoothed (w={window})")
ax1.set_xlabel("Step")
ax1.set_ylabel("Flow Matching Loss")
ax1.set_title("Training Loss")
ax1.legend()
ax1.grid(True, alpha=0.3)
chunks = [loss_history[i:i+100] for i in range(0, len(loss_history), 100)]
if len(chunks) > 1:
ax2.boxplot([c for c in chunks], positions=list(range(len(chunks))))
ax2.set_xlabel("Step (x100)")
ax2.set_ylabel("Loss")
ax2.set_title("Loss Distribution Over Time")
ax2.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("./iris_checkpoints/training_loss.png", dpi=100)
plt.show()
print("Loss plot saved.")
except ImportError:
print("matplotlib not available, skipping loss plot")
# ============================================================
# CELL 11: Generate sample images from trained model
# ============================================================
print("\nGenerating sample images from trained model...")

# Reload DC-AE decoder for visualization (the encoder copy was deleted
# after the latents were computed).
ae_decoder = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
)
ae_decoder = ae_decoder.to(device).eval()
ae_decoder.requires_grad_(False)

# Reload text encoder for new prompts
text_enc = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)

# Switch the trained model to inference mode before sampling.
model.eval()

# Prompts to render, one image each.
sample_prompts = [
    "a blue water pokemon with fins",
    "a fire dragon pokemon with wings",
    "a cute pink pokemon with big eyes",
    "a green grass pokemon",
]
from torchvision.utils import save_image

# Render one image per prompt: embed the text, sample a latent from noise,
# decode it with the DC-AE decoder and write a PNG.
for i, prompt in enumerate(sample_prompts):
    with torch.no_grad():
        txt_emb = text_enc.encode(
            [prompt], convert_to_tensor=True, normalize_embeddings=True
        )
        txt_emb = txt_emb.unsqueeze(1).to(device)  # (1, 1, 384)
        noise = torch.randn(1, 32, 16, 16, device=device)
        z_pred = euler_sample(
            model, noise, txt_emb,
            num_steps=20,
            num_iterations=NUM_ITERS,
            cfg_scale=1.0,
            scale_factor=SCALE,
        )
        # The decoder was loaded in float16, so cast the latent to match.
        img = ae_decoder.decode(z_pred.half()).sample

    # Clamp to [-1, 1] and rescale to [0, 1] for saving.
    img = (img.float().clamp(-1, 1) * 0.5 + 0.5)
    fname = f"./iris_checkpoints/sample_{i}_{prompt[:20].replace(' ','_')}.png"
    save_image(img, fname)
    print(f" Sample {i}: '{prompt}' -> {fname}")

print("\nAll samples saved to ./iris_checkpoints/")
print("NOTE: Trained on 833 images for 3000 steps — quality improves with more data + steps.")