Perf fix: cache Manhattan dist, optimize scan, pre-cache CLIP, fix deprecated AMP API
Browse files- IRIS_Training_Notebook.ipynb +59 -37
IRIS_Training_Notebook.ipynb
CHANGED
|
@@ -359,7 +359,7 @@
|
|
| 359 |
"source": [
|
| 360 |
"# \u2500\u2500\u2500 VAE Training Loop \u2500\u2500\u2500\n",
|
| 361 |
"import time\n",
|
| 362 |
-
"from
|
| 363 |
"\n",
|
| 364 |
"VAE_EPOCHS = 80 # Enough to get good reconstructions\n",
|
| 365 |
"VAE_LR = 1e-4\n",
|
|
@@ -368,25 +368,23 @@
|
|
| 368 |
"\n",
|
| 369 |
"optimizer_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR, weight_decay=0.01)\n",
|
| 370 |
"scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, T_max=VAE_EPOCHS)\n",
|
| 371 |
-
"scaler = GradScaler()\n",
|
| 372 |
"dwt = HaarDWT2D()\n",
|
| 373 |
"\n",
|
| 374 |
"# Logging\n",
|
| 375 |
"vae_losses = {\"total\": [], \"recon\": [], \"kl\": [], \"freq\": []}\n",
|
| 376 |
"\n",
|
| 377 |
"print(f\"Training VAE for {VAE_EPOCHS} epochs on {len(train_loader)} batches...\")\n",
|
| 378 |
-
"print(f\"{'Epoch':>6} {'Loss':>10} {'Recon':>10} {'KL':>10} {'Freq':>10} {'LR':>10} {'Time':>8}\")\n",
|
| 379 |
-
"print(\"\u2500\" * 70)\n",
|
| 380 |
"\n",
|
| 381 |
"vae.train()\n",
|
| 382 |
-
"
|
|
|
|
| 383 |
" epoch_losses = {\"total\": 0, \"recon\": 0, \"kl\": 0, \"freq\": 0}\n",
|
| 384 |
-
" t0 = time.time()\n",
|
| 385 |
"\n",
|
| 386 |
" for images, _ in train_loader:\n",
|
| 387 |
-
" images = images.to(device)\n",
|
| 388 |
"\n",
|
| 389 |
-
" with autocast(dtype=torch.float16):\n",
|
| 390 |
" x_recon, mean, logvar = vae(images)\n",
|
| 391 |
"\n",
|
| 392 |
" # Reconstruction loss\n",
|
|
@@ -422,12 +420,7 @@
|
|
| 422 |
" vae_losses[k].append(epoch_losses[k])\n",
|
| 423 |
"\n",
|
| 424 |
" scheduler_vae.step()\n",
|
| 425 |
-
"
|
| 426 |
-
"\n",
|
| 427 |
-
" if (epoch + 1) % 10 == 0 or epoch == 0:\n",
|
| 428 |
-
" lr = optimizer_vae.param_groups[0][\"lr\"]\n",
|
| 429 |
-
" print(f\"{epoch+1:>6} {epoch_losses['total']:>10.4f} {epoch_losses['recon']:>10.4f} \"\n",
|
| 430 |
-
" f\"{epoch_losses['kl']:>10.4f} {epoch_losses['freq']:>10.4f} {lr:>10.2e} {dt:>7.1f}s\")\n",
|
| 431 |
"\n",
|
| 432 |
"print(\"\\n\u2705 VAE training complete!\")"
|
| 433 |
],
|
|
@@ -539,6 +532,40 @@
|
|
| 539 |
"print(f\" Effective at r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
|
| 540 |
"print(f\" Memory fp16: {gen_params*2/1024/1024:.1f} MB\")\n",
|
| 541 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 542 |
"# Free standalone VAE to save memory\n",
|
| 543 |
"del vae, optimizer_vae, scheduler_vae\n",
|
| 544 |
"torch.cuda.empty_cache()"
|
|
@@ -551,6 +578,8 @@
|
|
| 551 |
"metadata": {},
|
| 552 |
"source": [
|
| 553 |
"# \u2500\u2500\u2500 Generator Training Loop \u2500\u2500\u2500\n",
|
|
|
|
|
|
|
| 554 |
"GEN_EPOCHS = 150 # More epochs for small dataset\n",
|
| 555 |
"GEN_LR = 2e-4 # Higher LR works well with AdamW + cosine\n",
|
| 556 |
"GRAD_ACCUM = 2 # Effective batch = BATCH_SIZE \u00d7 GRAD_ACCUM = 8\n",
|
|
@@ -563,7 +592,7 @@
|
|
| 563 |
" betas=(0.9, 0.95),\n",
|
| 564 |
")\n",
|
| 565 |
"\n",
|
| 566 |
-
"total_steps = GEN_EPOCHS * len(
|
| 567 |
"\n",
|
| 568 |
"def lr_lambda(step):\n",
|
| 569 |
" if step < WARMUP_STEPS:\n",
|
|
@@ -572,41 +601,35 @@
|
|
| 572 |
" return 0.5 * (1 + __import__('math').cos(__import__('math').pi * progress))\n",
|
| 573 |
"\n",
|
| 574 |
"scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda)\n",
|
| 575 |
-
"scaler_gen = GradScaler()\n",
|
| 576 |
"\n",
|
| 577 |
"# Logging\n",
|
| 578 |
"gen_losses = {\"total\": [], \"velocity\": [], \"kl\": []}\n",
|
| 579 |
"\n",
|
| 580 |
"print(f\"Training generator for {GEN_EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
|
| 581 |
"print(f\"Effective batch size: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}\")\n",
|
| 582 |
-
"print(f\"
|
| 583 |
-
"print()\n",
|
| 584 |
-
"print(f\"{'Epoch':>6} {'Loss':>10} {'VelLoss':>10} {'MeanT':>8} {'LR':>10} {'Time':>8}\")\n",
|
| 585 |
-
"print(\"\u2500\" * 60)\n",
|
| 586 |
"\n",
|
| 587 |
"iris.generator.train()\n",
|
| 588 |
"global_step = 0\n",
|
| 589 |
"best_loss = float('inf')\n",
|
| 590 |
"\n",
|
| 591 |
-
"
|
|
|
|
| 592 |
" epoch_vel = 0\n",
|
| 593 |
" epoch_total = 0\n",
|
| 594 |
" n_batches = 0\n",
|
| 595 |
-
" t0 = time.time()\n",
|
| 596 |
"\n",
|
| 597 |
" optimizer_gen.zero_grad(set_to_none=True)\n",
|
| 598 |
"\n",
|
| 599 |
-
" for batch_idx, (images,
|
| 600 |
-
" images = images.to(device)\n",
|
| 601 |
-
"\n",
|
| 602 |
-
" # Encode text with CLIP\n",
|
| 603 |
-
" with torch.no_grad():\n",
|
| 604 |
-
" text_emb = encode_text(list(captions)) # [B, 77, 768]\n",
|
| 605 |
"\n",
|
| 606 |
" # Forward pass with mixed precision\n",
|
| 607 |
-
" with autocast(dtype=torch.float16):\n",
|
| 608 |
-
" # Randomly sample iteration count for robustness\n",
|
| 609 |
-
" r = [
|
| 610 |
" result = iris.train_step(images, text_emb, num_iterations=r)\n",
|
| 611 |
" loss = result[\"loss\"] / GRAD_ACCUM\n",
|
| 612 |
"\n",
|
|
@@ -630,17 +653,16 @@
|
|
| 630 |
" avg_total = epoch_total / n_batches\n",
|
| 631 |
" gen_losses[\"velocity\"].append(avg_vel)\n",
|
| 632 |
" gen_losses[\"total\"].append(avg_total)\n",
|
| 633 |
-
" dt = time.time() - t0\n",
|
| 634 |
"\n",
|
| 635 |
" if avg_vel < best_loss:\n",
|
| 636 |
" best_loss = avg_vel\n",
|
| 637 |
"\n",
|
| 638 |
-
"
|
| 639 |
-
"
|
| 640 |
-
"
|
| 641 |
-
" f\"{result['mean_t']:>8.3f} {lr:>10.2e} {dt:>7.1f}s\")\n",
|
| 642 |
"\n",
|
| 643 |
-
"
|
|
|
|
| 644 |
],
|
| 645 |
"outputs": [],
|
| 646 |
"execution_count": null
|
|
|
|
| 359 |
"source": [
|
| 360 |
"# \u2500\u2500\u2500 VAE Training Loop \u2500\u2500\u2500\n",
|
| 361 |
"import time\n",
|
| 362 |
+
"from tqdm.auto import tqdm\n",
|
| 363 |
"\n",
|
| 364 |
"VAE_EPOCHS = 80 # Enough to get good reconstructions\n",
|
| 365 |
"VAE_LR = 1e-4\n",
|
|
|
|
| 368 |
"\n",
|
| 369 |
"optimizer_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR, weight_decay=0.01)\n",
|
| 370 |
"scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, T_max=VAE_EPOCHS)\n",
|
| 371 |
+
"scaler = torch.amp.GradScaler('cuda')\n",
|
| 372 |
"dwt = HaarDWT2D()\n",
|
| 373 |
"\n",
|
| 374 |
"# Logging\n",
|
| 375 |
"vae_losses = {\"total\": [], \"recon\": [], \"kl\": [], \"freq\": []}\n",
|
| 376 |
"\n",
|
| 377 |
"print(f\"Training VAE for {VAE_EPOCHS} epochs on {len(train_loader)} batches...\")\n",
|
|
|
|
|
|
|
| 378 |
"\n",
|
| 379 |
"vae.train()\n",
|
| 380 |
+
"pbar = tqdm(range(VAE_EPOCHS), desc=\"VAE Training\")\n",
|
| 381 |
+
"for epoch in pbar:\n",
|
| 382 |
" epoch_losses = {\"total\": 0, \"recon\": 0, \"kl\": 0, \"freq\": 0}\n",
|
|
|
|
| 383 |
"\n",
|
| 384 |
" for images, _ in train_loader:\n",
|
| 385 |
+
" images = images.to(device, non_blocking=True)\n",
|
| 386 |
"\n",
|
| 387 |
+
" with torch.amp.autocast('cuda', dtype=torch.float16):\n",
|
| 388 |
" x_recon, mean, logvar = vae(images)\n",
|
| 389 |
"\n",
|
| 390 |
" # Reconstruction loss\n",
|
|
|
|
| 420 |
" vae_losses[k].append(epoch_losses[k])\n",
|
| 421 |
"\n",
|
| 422 |
" scheduler_vae.step()\n",
|
| 423 |
+
" pbar.set_postfix(loss=f\"{epoch_losses['total']:.4f}\", recon=f\"{epoch_losses['recon']:.4f}\")\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 424 |
"\n",
|
| 425 |
"print(\"\\n\u2705 VAE training complete!\")"
|
| 426 |
],
|
|
|
|
| 532 |
"print(f\" Effective at r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
|
| 533 |
"print(f\" Memory fp16: {gen_params*2/1024/1024:.1f} MB\")\n",
|
| 534 |
"\n",
|
| 535 |
+
"# \u2500\u2500\u2500 Pre-cache CLIP text embeddings (HUGE speedup) \u2500\u2500\u2500\n",
|
| 536 |
+
"# Instead of encoding text every batch, cache all embeddings upfront\n",
|
| 537 |
+
"print(\"\\nPre-caching CLIP text embeddings...\")\n",
|
| 538 |
+
"all_text_embeddings = []\n",
|
| 539 |
+
"cache_loader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
|
| 540 |
+
"with torch.no_grad():\n",
|
| 541 |
+
" for _, captions in tqdm(cache_loader, desc=\"Encoding text\"):\n",
|
| 542 |
+
" emb = encode_text(list(captions))\n",
|
| 543 |
+
" all_text_embeddings.append(emb.cpu())\n",
|
| 544 |
+
"all_text_embeddings = torch.cat(all_text_embeddings, dim=0) # [N, 77, 768]\n",
|
| 545 |
+
"print(f\"\u2705 Cached {all_text_embeddings.shape[0]} text embeddings: {all_text_embeddings.shape}\")\n",
|
| 546 |
+
"\n",
|
| 547 |
+
"# Free CLIP from GPU (we don't need it during training anymore!)\n",
|
| 548 |
+
"text_encoder.cpu()\n",
|
| 549 |
+
"torch.cuda.empty_cache()\n",
|
| 550 |
+
"print(\"\u2705 CLIP moved to CPU to free ~600MB VRAM\")\n",
|
| 551 |
+
"\n",
|
| 552 |
+
"# Create a new dataset that uses cached embeddings\n",
|
| 553 |
+
"class CachedDataset(Dataset):\n",
|
| 554 |
+
" def __init__(self, image_dataset, cached_text_emb):\n",
|
| 555 |
+
" self.image_dataset = image_dataset\n",
|
| 556 |
+
" self.text_emb = cached_text_emb\n",
|
| 557 |
+
" def __len__(self):\n",
|
| 558 |
+
" return len(self.image_dataset)\n",
|
| 559 |
+
" def __getitem__(self, idx):\n",
|
| 560 |
+
" image, _ = self.image_dataset[idx]\n",
|
| 561 |
+
" return image, self.text_emb[idx]\n",
|
| 562 |
+
"\n",
|
| 563 |
+
"cached_dataset = CachedDataset(train_dataset, all_text_embeddings)\n",
|
| 564 |
+
"cached_loader = DataLoader(\n",
|
| 565 |
+
" cached_dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
|
| 566 |
+
" num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,\n",
|
| 567 |
+
")\n",
|
| 568 |
+
"\n",
|
| 569 |
"# Free standalone VAE to save memory\n",
|
| 570 |
"del vae, optimizer_vae, scheduler_vae\n",
|
| 571 |
"torch.cuda.empty_cache()"
|
|
|
|
| 578 |
"metadata": {},
|
| 579 |
"source": [
|
| 580 |
"# \u2500\u2500\u2500 Generator Training Loop \u2500\u2500\u2500\n",
|
| 581 |
+
"import time\n",
|
| 582 |
+
"\n",
|
| 583 |
"GEN_EPOCHS = 150 # More epochs for small dataset\n",
|
| 584 |
"GEN_LR = 2e-4 # Higher LR works well with AdamW + cosine\n",
|
| 585 |
"GRAD_ACCUM = 2 # Effective batch = BATCH_SIZE \u00d7 GRAD_ACCUM = 8\n",
|
|
|
|
| 592 |
" betas=(0.9, 0.95),\n",
|
| 593 |
")\n",
|
| 594 |
"\n",
|
| 595 |
+
"total_steps = GEN_EPOCHS * len(cached_loader) // GRAD_ACCUM\n",
|
| 596 |
"\n",
|
| 597 |
"def lr_lambda(step):\n",
|
| 598 |
" if step < WARMUP_STEPS:\n",
|
|
|
|
| 601 |
" return 0.5 * (1 + __import__('math').cos(__import__('math').pi * progress))\n",
|
| 602 |
"\n",
|
| 603 |
"scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda)\n",
|
| 604 |
+
"scaler_gen = torch.amp.GradScaler('cuda')\n",
|
| 605 |
"\n",
|
| 606 |
"# Logging\n",
|
| 607 |
"gen_losses = {\"total\": [], \"velocity\": [], \"kl\": []}\n",
|
| 608 |
"\n",
|
| 609 |
"print(f\"Training generator for {GEN_EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
|
| 610 |
"print(f\"Effective batch size: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}\")\n",
|
| 611 |
+
"print(f\"Using cached CLIP embeddings (no per-batch encoding overhead)\")\n",
|
|
|
|
|
|
|
|
|
|
| 612 |
"\n",
|
| 613 |
"iris.generator.train()\n",
|
| 614 |
"global_step = 0\n",
|
| 615 |
"best_loss = float('inf')\n",
|
| 616 |
"\n",
|
| 617 |
+
"pbar = tqdm(range(GEN_EPOCHS), desc=\"Gen Training\")\n",
|
| 618 |
+
"for epoch in pbar:\n",
|
| 619 |
" epoch_vel = 0\n",
|
| 620 |
" epoch_total = 0\n",
|
| 621 |
" n_batches = 0\n",
|
|
|
|
| 622 |
"\n",
|
| 623 |
" optimizer_gen.zero_grad(set_to_none=True)\n",
|
| 624 |
"\n",
|
| 625 |
+
" for batch_idx, (images, text_emb) in enumerate(cached_loader):\n",
|
| 626 |
+
" images = images.to(device, non_blocking=True)\n",
|
| 627 |
+
" text_emb = text_emb.to(device, non_blocking=True)\n",
|
|
|
|
|
|
|
|
|
|
| 628 |
"\n",
|
| 629 |
" # Forward pass with mixed precision\n",
|
| 630 |
+
" with torch.amp.autocast('cuda', dtype=torch.float16):\n",
|
| 631 |
+
" # Randomly sample iteration count for robustness (keep low for speed)\n",
|
| 632 |
+
" r = [3, 4, 5][torch.randint(0, 3, (1,)).item()]\n",
|
| 633 |
" result = iris.train_step(images, text_emb, num_iterations=r)\n",
|
| 634 |
" loss = result[\"loss\"] / GRAD_ACCUM\n",
|
| 635 |
"\n",
|
|
|
|
| 653 |
" avg_total = epoch_total / n_batches\n",
|
| 654 |
" gen_losses[\"velocity\"].append(avg_vel)\n",
|
| 655 |
" gen_losses[\"total\"].append(avg_total)\n",
|
|
|
|
| 656 |
"\n",
|
| 657 |
" if avg_vel < best_loss:\n",
|
| 658 |
" best_loss = avg_vel\n",
|
| 659 |
"\n",
|
| 660 |
+
" pbar.set_postfix(vel_loss=f\"{avg_vel:.4f}\", best=f\"{best_loss:.4f}\")\n",
|
| 661 |
+
"\n",
|
| 662 |
+
"print(f\"\\n\u2705 Generator training complete! Best velocity loss: {best_loss:.4f}\")\n",
|
|
|
|
| 663 |
"\n",
|
| 664 |
+
"# Reload CLIP for generation\n",
|
| 665 |
+
"text_encoder.to(device)"
|
| 666 |
],
|
| 667 |
"outputs": [],
|
| 668 |
"execution_count": null
|