asdf98 committed
Commit 3bbd6a9 · verified · 1 Parent(s): c98929a

Add one-click Colab training notebook with real Pokemon dataset

Files changed (1): colab_train_iris.py +455 -0
colab_train_iris.py ADDED
#!/usr/bin/env python3
"""
IRIS Colab Training — One-Click, Real Dataset, Real Learning
=============================================================

Copy-paste into Google Colab (free tier T4) and run all cells.
Trains IRIS on Pokemon BLIP Captions (833 images + text).

Colab free tier specs (2025):
- GPU: NVIDIA T4 (16 GB VRAM)
- System RAM: ~12.7 GB
- Disk: ~78 GB
- PyTorch: 2.5+ preinstalled
- Runtime: ~12 hours max session

What this script does:
1. Installs dependencies (~30s)
2. Downloads IRIS source from HF Hub
3. Downloads DC-AE encoder (1.2 GB) + text encoder (87 MB)
4. Encodes all 833 Pokemon images to latents (~2 min on T4)
5. Encodes all captions to text embeddings (~5s)
6. Frees encoder VRAM
7. Trains IRIS-Small (40M params) for 3000 steps (~15 min on T4)
8. Generates sample images from trained model
9. Saves checkpoint

Total wall time: ~20 minutes for a trained model.
"""

# ============================================================
# CELL 1: Install dependencies
# ============================================================
print("Installing dependencies...")
import subprocess, sys

subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                       "diffusers>=0.32.0",
                       "sentence-transformers",
                       "datasets",
                       "accelerate",
                       "huggingface_hub",
                       ])
print("Done.")

# ============================================================
# CELL 2: Download IRIS source code
# ============================================================
print("Downloading IRIS architecture from HF Hub...")
from huggingface_hub import snapshot_download
import os

iris_path = snapshot_download(
    repo_id="asdf98/iris-image-gen",
    allow_patterns=["iris/*.py"],
    local_dir="./iris_repo",
)
# Add to Python path
sys.path.insert(0, iris_path)
print(f"IRIS source at: {iris_path}")

# Verify import
from iris import IRIS, get_model_config, flow_matching_loss, euler_sample
from iris.flow_matching import DCAE_F32C32_SCALE
print("IRIS imported successfully.")

# ============================================================
# CELL 3: Detect hardware
# ============================================================
import torch
import gc

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"GPU: {gpu_name} ({gpu_mem:.1f} GB)")
else:
    print("WARNING: No GPU detected. Training will be very slow.")
    print("In Colab: Runtime -> Change runtime type -> T4 GPU")

use_amp = device.type == "cuda"
if use_amp and torch.cuda.is_bf16_supported():
    amp_dtype = torch.bfloat16
elif use_amp:
    amp_dtype = torch.float16
else:
    amp_dtype = torch.float32
print(f"AMP dtype: {amp_dtype}")

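# Note: the T4 does not support bfloat16, so the free tier lands on float16
# plus a GradScaler (created in Cell 9); on A100/L4 runtimes bf16 is selected
# and the scaler stays disabled.
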
# ============================================================
# CELL 4: Load dataset
# ============================================================
print("\nLoading Pokemon BLIP Captions dataset...")
from datasets import load_dataset

ds = load_dataset("reach-vb/pokemon-blip-captions", split="train")
print(f"Loaded {len(ds)} images with captions.")
print(f"Example: '{ds[0]['text']}'")

# ============================================================
# CELL 5: Encode all images to DC-AE latents
# ============================================================
print("\nLoading DC-AE encoder (~1.2 GB)...")
from diffusers import AutoencoderDC
import torchvision.transforms as T

# Use float16 to save VRAM — stable for inference
ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
).to(device).eval()
ae.requires_grad_(False)

SCALE = ae.config.scaling_factor  # 0.41407

transform = T.Compose([
    T.Resize(512, interpolation=T.InterpolationMode.BICUBIC, antialias=True),
    T.CenterCrop(512),
    T.ToTensor(),
    T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

print("Encoding images to latents...")
all_latents = []
import time
t0 = time.time()

batch_imgs = []
for i, example in enumerate(ds):
    img = example["image"].convert("RGB")
    tensor = transform(img)
    batch_imgs.append(tensor)

    # Process in batches of 8
    if len(batch_imgs) == 8 or i == len(ds) - 1:
        batch = torch.stack(batch_imgs).to(device, dtype=torch.float16)
        with torch.no_grad():
            latent = ae.encode(batch).latent.float()  # encode in fp16, store in fp32
        all_latents.append(latent.cpu())
        batch_imgs = []

    if (i + 1) % 100 == 0 or i == len(ds) - 1:
        print(f"  Encoded {i+1}/{len(ds)} images ({time.time()-t0:.1f}s)")

all_latents = torch.cat(all_latents, dim=0)  # (N, 32, 16, 16)
print(f"All latents: {all_latents.shape}, took {time.time()-t0:.1f}s")
print(f"Latent stats: mean={all_latents.mean():.3f}, std={all_latents.std():.3f}")

# Free DC-AE VRAM
del ae
torch.cuda.empty_cache()
gc.collect()
print("DC-AE encoder freed from VRAM.")

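# Optional sanity check (a quick sketch): flow matching works best on roughly
# unit-variance inputs, and scale_factor=SCALE is what rescales the raw DC-AE
# latents inside flow_matching_loss, so the scaled std should land near 1.0.
print(f"Scaled latent std: {(all_latents * SCALE).std().item():.3f} (expect ~1.0)")
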
# ============================================================
# CELL 6: Encode all captions to text embeddings
# ============================================================
print("\nLoading text encoder (~87 MB)...")
from sentence_transformers import SentenceTransformer

text_encoder = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)
text_encoder.eval()

captions = [ex["text"] for ex in ds]
print(f"Encoding {len(captions)} captions...")

with torch.no_grad():
    all_text_embs = text_encoder.encode(
        captions,
        convert_to_tensor=True,
        normalize_embeddings=True,
        batch_size=128,
        show_progress_bar=True,
    )

# Expand to sequence format: (N, 1, 384)
# The model projects 384 -> model_dim via registered context_proj
all_text_embs = all_text_embs.unsqueeze(1).cpu()  # (N, 1, 384)
print(f"Text embeddings: {all_text_embs.shape}")

# Free text encoder VRAM
del text_encoder
torch.cuda.empty_cache()
gc.collect()
print("Text encoder freed from VRAM.")

# ============================================================
# CELL 7: Create dataset from precomputed features
# ============================================================
from torch.utils.data import Dataset, DataLoader

class PrecomputedLatentDataset(Dataset):
    """All latents and text embeddings precomputed — zero I/O during training."""
    def __init__(self, latents, text_embs):
        self.latents = latents
        self.text_embs = text_embs

    def __len__(self):
        return len(self.latents)

    def __getitem__(self, idx):
        return {
            "latent": self.latents[idx],
            "text_embed": self.text_embs[idx],
        }

train_ds = PrecomputedLatentDataset(all_latents, all_text_embs)
print(f"Training dataset: {len(train_ds)} samples")
print(f"  Latent: {train_ds[0]['latent'].shape}")
print(f"  Text: {train_ds[0]['text_embed'].shape}")

# ============================================================
# CELL 8: Create IRIS model
# ============================================================
print("\nCreating IRIS-Small model...")

model = IRIS(
    **get_model_config("iris-small"),
    gradient_checkpointing=True,
    text_dim=384,  # all-MiniLM-L6-v2 output dim — registered as proper nn.Module
).to(device)

counts = model.count_params()
print(f"Parameters: {counts['total']:,} ({counts['total']/1e6:.1f}M)")
print(f"  Core: {counts['core']:,}")
print(f"  Decoder: {counts['tiny_decoder']:,}")

if device.type == "cuda":
    print(f"VRAM used: {torch.cuda.memory_allocated()/1e9:.2f} GB / {torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB")

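# gradient_checkpointing recomputes activations during the backward pass
# instead of storing them, trading extra compute for activation memory;
# that headroom is what lets BATCH_SIZE=16 fit IRIS-Small on a 16 GB T4.
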
# ============================================================
# CELL 9: Train!
# ============================================================
import math
from iris.train import CosineWarmupScheduler
from iris.flow_matching import flow_matching_loss

# Training config — tuned for Colab T4 with 833 Pokemon images
NUM_STEPS = 3000    # ~15 min on T4
BATCH_SIZE = 16     # fits T4 with IRIS-Small + grad checkpoint
LR = 3e-4           # slightly higher LR for small dataset
WARMUP_STEPS = 200
GRAD_CLIP = 1.0
NUM_ITERS = 3       # refinement iterations (3 is good for speed/quality)
LOG_EVERY = 50
SAVE_EVERY = 1000

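# With 833 images, BATCH_SIZE=16, and drop_last=True, one epoch is
# 833 // 16 = 52 steps, so 3000 steps is about 58 passes over the dataset.
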
loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
    pin_memory=True,
    drop_last=True,
    persistent_workers=True,
)

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=LR,
    weight_decay=0.01,
    betas=(0.9, 0.999),
)
scheduler = CosineWarmupScheduler(optimizer, WARMUP_STEPS, NUM_STEPS, min_lr_ratio=0.05)
scaler = torch.amp.GradScaler(enabled=(use_amp and amp_dtype == torch.float16))

model.train()
step = 0
epoch = 0
running_loss = 0.0
loss_history = []
best_loss = float("inf")
t_start = time.time()

print(f"\n{'='*60}")
print(f"Training IRIS-Small on Pokemon BLIP Captions")
print(f"  {len(train_ds)} images, {NUM_STEPS} steps, BS={BATCH_SIZE}, R={NUM_ITERS}")
print(f"  LR={LR}, warmup={WARMUP_STEPS}, AMP={amp_dtype}")
print(f"{'='*60}\n")

while step < NUM_STEPS:
    epoch += 1
    for batch in loader:
        if step >= NUM_STEPS:
            break

        latent = batch["latent"].to(device, non_blocking=True)
        text_embed = batch["text_embed"].to(device, non_blocking=True)

        with torch.amp.autocast(device_type=device.type, dtype=amp_dtype, enabled=use_amp):
            losses = flow_matching_loss(
                model, latent, text_embed,
                num_iterations=NUM_ITERS,
                timestep_sampling="logit_normal",
                scale_factor=SCALE,
            )
            loss = losses["loss"]

        optimizer.zero_grad(set_to_none=True)
        if scaler.is_enabled():
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            gn = torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            optimizer.step()

        scheduler.step()
        step += 1
        lv = loss.item()
        running_loss += lv
        loss_history.append(lv)

        if step % LOG_EVERY == 0:
            avg = running_loss / LOG_EVERY
            elapsed = time.time() - t_start
            sps = step / elapsed
            eta = (NUM_STEPS - step) / sps
            lr = scheduler.get_lr()[0]
            gn_val = gn.item() if isinstance(gn, torch.Tensor) else gn
            tag = "OK" if not (math.isnan(avg) or math.isinf(avg)) else "!!"

            print(
                f"[{tag}] step {step:>5d}/{NUM_STEPS} | "
                f"loss={avg:.4f} | "
                f"grad={gn_val:.3f} | "
                f"lr={lr:.1e} | "
                f"{sps:.1f} steps/s | "
                f"ETA {eta/60:.0f}min"
            )

            if avg < best_loss:
                best_loss = avg
            running_loss = 0.0

        if step % SAVE_EVERY == 0:
            os.makedirs("./iris_checkpoints", exist_ok=True)
            p = f"./iris_checkpoints/iris_pokemon_step{step}.pt"
            torch.save({
                "step": step,
                "model_state_dict": model.state_dict(),
                "loss_history": loss_history,
                "config": get_model_config("iris-small"),
            }, p)
            print(f"  Saved: {p}")

# Final save
os.makedirs("./iris_checkpoints", exist_ok=True)
final_path = "./iris_checkpoints/iris_pokemon_final.pt"
torch.save({
    "step": step,
    "model_state_dict": model.state_dict(),
    "loss_history": loss_history,
    "config": get_model_config("iris-small"),
}, final_path)

total_time = time.time() - t_start
f50 = sum(loss_history[:50]) / min(50, len(loss_history))
l50 = sum(loss_history[-50:]) / min(50, len(loss_history))
print(f"\n{'='*60}")
print(f"Training complete!")
print(f"  {step} steps in {total_time/60:.1f} min ({step/total_time:.1f} steps/s)")
print(f"  Loss: {f50:.4f} -> {l50:.4f} ({(1-l50/f50)*100:.1f}% reduction)")
print(f"  Best: {best_loss:.4f}")
print(f"  Saved: {final_path}")
print(f"{'='*60}")

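# To resume or evaluate later, a checkpoint can be reloaded like this
# (a sketch: it assumes the IRIS constructor accepts the saved config dict
# plus the same text_dim used above):
#   ckpt = torch.load(final_path, map_location=device)
#   model = IRIS(**ckpt["config"], text_dim=384).to(device)
#   model.load_state_dict(ckpt["model_state_dict"])
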
# ============================================================
# CELL 10: Plot training loss
# ============================================================
try:
    import matplotlib.pyplot as plt

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

    ax1.plot(loss_history, alpha=0.3, color="blue", linewidth=0.5)
    window = 50
    if len(loss_history) > window:
        # Trailing moving average over at most `window` values
        smoothed = []
        for i in range(len(loss_history)):
            w = loss_history[max(0, i - window + 1):i + 1]
            smoothed.append(sum(w) / len(w))
        ax1.plot(smoothed, color="red", linewidth=2, label=f"Smoothed (w={window})")
    ax1.set_xlabel("Step")
    ax1.set_ylabel("Flow Matching Loss")
    ax1.set_title("Training Loss")
    ax1.legend()
    ax1.grid(True, alpha=0.3)

    chunks = [loss_history[i:i+100] for i in range(0, len(loss_history), 100)]
    if len(chunks) > 1:
        ax2.boxplot(chunks, positions=list(range(len(chunks))))
        ax2.set_xlabel("Step (x100)")
        ax2.set_ylabel("Loss")
        ax2.set_title("Loss Distribution Over Time")
        ax2.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig("./iris_checkpoints/training_loss.png", dpi=100)
    plt.show()
    print("Loss plot saved.")
except ImportError:
    print("matplotlib not available, skipping loss plot")

# ============================================================
# CELL 11: Generate sample images from trained model
# ============================================================
print("\nGenerating sample images from trained model...")
from torchvision.utils import save_image

# Reload DC-AE decoder for visualization
ae_decoder = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
    torch_dtype=torch.float16,
).to(device).eval()
ae_decoder.requires_grad_(False)

# Reload text encoder for new prompts
text_enc = SentenceTransformer(
    "sentence-transformers/all-MiniLM-L6-v2",
    device=str(device),
)

model.eval()

sample_prompts = [
    "a blue water pokemon with fins",
    "a fire dragon pokemon with wings",
    "a cute pink pokemon with big eyes",
    "a green grass pokemon",
]

for i, prompt in enumerate(sample_prompts):
    with torch.no_grad():
        txt_emb = text_enc.encode(
            [prompt], convert_to_tensor=True, normalize_embeddings=True
        ).unsqueeze(1).to(device)  # (1, 1, 384)

        noise = torch.randn(1, 32, 16, 16, device=device)

        z_pred = euler_sample(
            model, noise, txt_emb,
            num_steps=20,
            num_iterations=NUM_ITERS,
            cfg_scale=1.0,
            scale_factor=SCALE,
        )
        img = ae_decoder.decode(z_pred.half()).sample
        img = (img.float().clamp(-1, 1) * 0.5 + 0.5)

    fname = f"./iris_checkpoints/sample_{i}_{prompt[:20].replace(' ','_')}.png"
    save_image(img, fname)
    print(f"  Sample {i}: '{prompt}' -> {fname}")

print("\nAll samples saved to ./iris_checkpoints/")
print("NOTE: Trained on 833 images for 3000 steps — quality improves with more data + steps.")