asdf98 committed on
Commit
99a5f54
·
verified ·
1 Parent(s): 3e17403

Perf fix: cache Manhattan dist, optimize scan, pre-cache CLIP, fix deprecated AMP API

Browse files
Files changed (1) hide show
  1. IRIS_Training_Notebook.ipynb +59 -37
IRIS_Training_Notebook.ipynb CHANGED
@@ -359,7 +359,7 @@
359
  "source": [
360
  "# \u2500\u2500\u2500 VAE Training Loop \u2500\u2500\u2500\n",
361
  "import time\n",
362
- "from torch.cuda.amp import autocast, GradScaler\n",
363
  "\n",
364
  "VAE_EPOCHS = 80 # Enough to get good reconstructions\n",
365
  "VAE_LR = 1e-4\n",
@@ -368,25 +368,23 @@
368
  "\n",
369
  "optimizer_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR, weight_decay=0.01)\n",
370
  "scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, T_max=VAE_EPOCHS)\n",
371
- "scaler = GradScaler()\n",
372
  "dwt = HaarDWT2D()\n",
373
  "\n",
374
  "# Logging\n",
375
  "vae_losses = {\"total\": [], \"recon\": [], \"kl\": [], \"freq\": []}\n",
376
  "\n",
377
  "print(f\"Training VAE for {VAE_EPOCHS} epochs on {len(train_loader)} batches...\")\n",
378
- "print(f\"{'Epoch':>6} {'Loss':>10} {'Recon':>10} {'KL':>10} {'Freq':>10} {'LR':>10} {'Time':>8}\")\n",
379
- "print(\"\u2500\" * 70)\n",
380
  "\n",
381
  "vae.train()\n",
382
- "for epoch in range(VAE_EPOCHS):\n",
 
383
  " epoch_losses = {\"total\": 0, \"recon\": 0, \"kl\": 0, \"freq\": 0}\n",
384
- " t0 = time.time()\n",
385
  "\n",
386
  " for images, _ in train_loader:\n",
387
- " images = images.to(device)\n",
388
  "\n",
389
- " with autocast(dtype=torch.float16):\n",
390
  " x_recon, mean, logvar = vae(images)\n",
391
  "\n",
392
  " # Reconstruction loss\n",
@@ -422,12 +420,7 @@
422
  " vae_losses[k].append(epoch_losses[k])\n",
423
  "\n",
424
  " scheduler_vae.step()\n",
425
- " dt = time.time() - t0\n",
426
- "\n",
427
- " if (epoch + 1) % 10 == 0 or epoch == 0:\n",
428
- " lr = optimizer_vae.param_groups[0][\"lr\"]\n",
429
- " print(f\"{epoch+1:>6} {epoch_losses['total']:>10.4f} {epoch_losses['recon']:>10.4f} \"\n",
430
- " f\"{epoch_losses['kl']:>10.4f} {epoch_losses['freq']:>10.4f} {lr:>10.2e} {dt:>7.1f}s\")\n",
431
  "\n",
432
  "print(\"\\n\u2705 VAE training complete!\")"
433
  ],
@@ -539,6 +532,40 @@
539
  "print(f\" Effective at r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
540
  "print(f\" Memory fp16: {gen_params*2/1024/1024:.1f} MB\")\n",
541
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
  "# Free standalone VAE to save memory\n",
543
  "del vae, optimizer_vae, scheduler_vae\n",
544
  "torch.cuda.empty_cache()"
@@ -551,6 +578,8 @@
551
  "metadata": {},
552
  "source": [
553
  "# \u2500\u2500\u2500 Generator Training Loop \u2500\u2500\u2500\n",
 
 
554
  "GEN_EPOCHS = 150 # More epochs for small dataset\n",
555
  "GEN_LR = 2e-4 # Higher LR works well with AdamW + cosine\n",
556
  "GRAD_ACCUM = 2 # Effective batch = BATCH_SIZE \u00d7 GRAD_ACCUM = 8\n",
@@ -563,7 +592,7 @@
563
  " betas=(0.9, 0.95),\n",
564
  ")\n",
565
  "\n",
566
- "total_steps = GEN_EPOCHS * len(train_loader) // GRAD_ACCUM\n",
567
  "\n",
568
  "def lr_lambda(step):\n",
569
  " if step < WARMUP_STEPS:\n",
@@ -572,41 +601,35 @@
572
  " return 0.5 * (1 + __import__('math').cos(__import__('math').pi * progress))\n",
573
  "\n",
574
  "scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda)\n",
575
- "scaler_gen = GradScaler()\n",
576
  "\n",
577
  "# Logging\n",
578
  "gen_losses = {\"total\": [], \"velocity\": [], \"kl\": []}\n",
579
  "\n",
580
  "print(f\"Training generator for {GEN_EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
581
  "print(f\"Effective batch size: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}\")\n",
582
- "print(f\"Warmup: {WARMUP_STEPS} steps, then cosine decay to 0\")\n",
583
- "print()\n",
584
- "print(f\"{'Epoch':>6} {'Loss':>10} {'VelLoss':>10} {'MeanT':>8} {'LR':>10} {'Time':>8}\")\n",
585
- "print(\"\u2500\" * 60)\n",
586
  "\n",
587
  "iris.generator.train()\n",
588
  "global_step = 0\n",
589
  "best_loss = float('inf')\n",
590
  "\n",
591
- "for epoch in range(GEN_EPOCHS):\n",
 
592
  " epoch_vel = 0\n",
593
  " epoch_total = 0\n",
594
  " n_batches = 0\n",
595
- " t0 = time.time()\n",
596
  "\n",
597
  " optimizer_gen.zero_grad(set_to_none=True)\n",
598
  "\n",
599
- " for batch_idx, (images, captions) in enumerate(train_loader):\n",
600
- " images = images.to(device)\n",
601
- "\n",
602
- " # Encode text with CLIP\n",
603
- " with torch.no_grad():\n",
604
- " text_emb = encode_text(list(captions)) # [B, 77, 768]\n",
605
  "\n",
606
  " # Forward pass with mixed precision\n",
607
- " with autocast(dtype=torch.float16):\n",
608
- " # Randomly sample iteration count for robustness\n",
609
- " r = [4, 5, 6, 7, 8][torch.randint(0, 5, (1,)).item()]\n",
610
  " result = iris.train_step(images, text_emb, num_iterations=r)\n",
611
  " loss = result[\"loss\"] / GRAD_ACCUM\n",
612
  "\n",
@@ -630,17 +653,16 @@
630
  " avg_total = epoch_total / n_batches\n",
631
  " gen_losses[\"velocity\"].append(avg_vel)\n",
632
  " gen_losses[\"total\"].append(avg_total)\n",
633
- " dt = time.time() - t0\n",
634
  "\n",
635
  " if avg_vel < best_loss:\n",
636
  " best_loss = avg_vel\n",
637
  "\n",
638
- " if (epoch + 1) % 10 == 0 or epoch == 0:\n",
639
- " lr = optimizer_gen.param_groups[0][\"lr\"]\n",
640
- " print(f\"{epoch+1:>6} {avg_total:>10.4f} {avg_vel:>10.4f} \"\n",
641
- " f\"{result['mean_t']:>8.3f} {lr:>10.2e} {dt:>7.1f}s\")\n",
642
  "\n",
643
- "print(f\"\\n\u2705 Generator training complete! Best velocity loss: {best_loss:.4f}\")"
 
644
  ],
645
  "outputs": [],
646
  "execution_count": null
 
359
  "source": [
360
  "# \u2500\u2500\u2500 VAE Training Loop \u2500\u2500\u2500\n",
361
  "import time\n",
362
+ "from tqdm.auto import tqdm\n",
363
  "\n",
364
  "VAE_EPOCHS = 80 # Enough to get good reconstructions\n",
365
  "VAE_LR = 1e-4\n",
 
368
  "\n",
369
  "optimizer_vae = torch.optim.AdamW(vae.parameters(), lr=VAE_LR, weight_decay=0.01)\n",
370
  "scheduler_vae = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_vae, T_max=VAE_EPOCHS)\n",
371
+ "scaler = torch.amp.GradScaler('cuda')\n",
372
  "dwt = HaarDWT2D()\n",
373
  "\n",
374
  "# Logging\n",
375
  "vae_losses = {\"total\": [], \"recon\": [], \"kl\": [], \"freq\": []}\n",
376
  "\n",
377
  "print(f\"Training VAE for {VAE_EPOCHS} epochs on {len(train_loader)} batches...\")\n",
 
 
378
  "\n",
379
  "vae.train()\n",
380
+ "pbar = tqdm(range(VAE_EPOCHS), desc=\"VAE Training\")\n",
381
+ "for epoch in pbar:\n",
382
  " epoch_losses = {\"total\": 0, \"recon\": 0, \"kl\": 0, \"freq\": 0}\n",
 
383
  "\n",
384
  " for images, _ in train_loader:\n",
385
+ " images = images.to(device, non_blocking=True)\n",
386
  "\n",
387
+ " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
388
  " x_recon, mean, logvar = vae(images)\n",
389
  "\n",
390
  " # Reconstruction loss\n",
 
420
  " vae_losses[k].append(epoch_losses[k])\n",
421
  "\n",
422
  " scheduler_vae.step()\n",
423
+ " pbar.set_postfix(loss=f\"{epoch_losses['total']:.4f}\", recon=f\"{epoch_losses['recon']:.4f}\")\n",
 
 
 
 
 
424
  "\n",
425
  "print(\"\\n\u2705 VAE training complete!\")"
426
  ],
 
532
  "print(f\" Effective at r=6: ~{gen_params + 5*core_params:,} effective params\")\n",
533
  "print(f\" Memory fp16: {gen_params*2/1024/1024:.1f} MB\")\n",
534
  "\n",
535
+ "# \u2500\u2500\u2500 Pre-cache CLIP text embeddings (HUGE speedup) \u2500\u2500\u2500\n",
536
+ "# Instead of encoding text every batch, cache all embeddings upfront\n",
537
+ "print(\"\\nPre-caching CLIP text embeddings...\")\n",
538
+ "all_text_embeddings = []\n",
539
+ "cache_loader = DataLoader(train_dataset, batch_size=32, shuffle=False, num_workers=0)\n",
540
+ "with torch.no_grad():\n",
541
+ " for _, captions in tqdm(cache_loader, desc=\"Encoding text\"):\n",
542
+ " emb = encode_text(list(captions))\n",
543
+ " all_text_embeddings.append(emb.cpu())\n",
544
+ "all_text_embeddings = torch.cat(all_text_embeddings, dim=0) # [N, 77, 768]\n",
545
+ "print(f\"\u2705 Cached {all_text_embeddings.shape[0]} text embeddings: {all_text_embeddings.shape}\")\n",
546
+ "\n",
547
+ "# Free CLIP from GPU (we don't need it during training anymore!)\n",
548
+ "text_encoder.cpu()\n",
549
+ "torch.cuda.empty_cache()\n",
550
+ "print(\"\u2705 CLIP moved to CPU to free ~600MB VRAM\")\n",
551
+ "\n",
552
+ "# Create a new dataset that uses cached embeddings\n",
553
+ "class CachedDataset(Dataset):\n",
554
+ " def __init__(self, image_dataset, cached_text_emb):\n",
555
+ " self.image_dataset = image_dataset\n",
556
+ " self.text_emb = cached_text_emb\n",
557
+ " def __len__(self):\n",
558
+ " return len(self.image_dataset)\n",
559
+ " def __getitem__(self, idx):\n",
560
+ " image, _ = self.image_dataset[idx]\n",
561
+ " return image, self.text_emb[idx]\n",
562
+ "\n",
563
+ "cached_dataset = CachedDataset(train_dataset, all_text_embeddings)\n",
564
+ "cached_loader = DataLoader(\n",
565
+ " cached_dataset, batch_size=BATCH_SIZE, shuffle=True,\n",
566
+ " num_workers=NUM_WORKERS, pin_memory=True, drop_last=True,\n",
567
+ ")\n",
568
+ "\n",
569
  "# Free standalone VAE to save memory\n",
570
  "del vae, optimizer_vae, scheduler_vae\n",
571
  "torch.cuda.empty_cache()"
 
578
  "metadata": {},
579
  "source": [
580
  "# \u2500\u2500\u2500 Generator Training Loop \u2500\u2500\u2500\n",
581
+ "import time\n",
582
+ "\n",
583
  "GEN_EPOCHS = 150 # More epochs for small dataset\n",
584
  "GEN_LR = 2e-4 # Higher LR works well with AdamW + cosine\n",
585
  "GRAD_ACCUM = 2 # Effective batch = BATCH_SIZE \u00d7 GRAD_ACCUM = 8\n",
 
592
  " betas=(0.9, 0.95),\n",
593
  ")\n",
594
  "\n",
595
+ "total_steps = GEN_EPOCHS * len(cached_loader) // GRAD_ACCUM\n",
596
  "\n",
597
  "def lr_lambda(step):\n",
598
  " if step < WARMUP_STEPS:\n",
 
601
  " return 0.5 * (1 + __import__('math').cos(__import__('math').pi * progress))\n",
602
  "\n",
603
  "scheduler_gen = torch.optim.lr_scheduler.LambdaLR(optimizer_gen, lr_lambda)\n",
604
+ "scaler_gen = torch.amp.GradScaler('cuda')\n",
605
  "\n",
606
  "# Logging\n",
607
  "gen_losses = {\"total\": [], \"velocity\": [], \"kl\": []}\n",
608
  "\n",
609
  "print(f\"Training generator for {GEN_EPOCHS} epochs ({total_steps} optimizer steps)\")\n",
610
  "print(f\"Effective batch size: {BATCH_SIZE} \u00d7 {GRAD_ACCUM} = {BATCH_SIZE * GRAD_ACCUM}\")\n",
611
+ "print(f\"Using cached CLIP embeddings (no per-batch encoding overhead)\")\n",
 
 
 
612
  "\n",
613
  "iris.generator.train()\n",
614
  "global_step = 0\n",
615
  "best_loss = float('inf')\n",
616
  "\n",
617
+ "pbar = tqdm(range(GEN_EPOCHS), desc=\"Gen Training\")\n",
618
+ "for epoch in pbar:\n",
619
  " epoch_vel = 0\n",
620
  " epoch_total = 0\n",
621
  " n_batches = 0\n",
 
622
  "\n",
623
  " optimizer_gen.zero_grad(set_to_none=True)\n",
624
  "\n",
625
+ " for batch_idx, (images, text_emb) in enumerate(cached_loader):\n",
626
+ " images = images.to(device, non_blocking=True)\n",
627
+ " text_emb = text_emb.to(device, non_blocking=True)\n",
 
 
 
628
  "\n",
629
  " # Forward pass with mixed precision\n",
630
+ " with torch.amp.autocast('cuda', dtype=torch.float16):\n",
631
+ " # Randomly sample iteration count for robustness (keep low for speed)\n",
632
+ " r = [3, 4, 5][torch.randint(0, 3, (1,)).item()]\n",
633
  " result = iris.train_step(images, text_emb, num_iterations=r)\n",
634
  " loss = result[\"loss\"] / GRAD_ACCUM\n",
635
  "\n",
 
653
  " avg_total = epoch_total / n_batches\n",
654
  " gen_losses[\"velocity\"].append(avg_vel)\n",
655
  " gen_losses[\"total\"].append(avg_total)\n",
 
656
  "\n",
657
  " if avg_vel < best_loss:\n",
658
  " best_loss = avg_vel\n",
659
  "\n",
660
+ " pbar.set_postfix(vel_loss=f\"{avg_vel:.4f}\", best=f\"{best_loss:.4f}\")\n",
661
+ "\n",
662
+ "print(f\"\\n\u2705 Generator training complete! Best velocity loss: {best_loss:.4f}\")\n",
 
663
  "\n",
664
+ "# Reload CLIP for generation\n",
665
+ "text_encoder.to(device)"
666
  ],
667
  "outputs": [],
668
  "execution_count": null