# LEARNING.md — Lessons from NSGF++ Reproduction

All concrete mistakes, war stories, code examples, and derived principles from reproducing arXiv:2401.14069 on Kaggle T4 GPUs. Rules and procedures live in [SKILL.md](SKILL.md).

---

## Mistake Catalog

### #1 — geomloss tensor shape bug (CRITICAL)

**What**: `SamplesLoss` in geomloss requires `(N, D)` or `(B, N, D)` tensors. The image experiments passed `(N, C, H, W)`.

**Impact**: MNIST and CIFAR-10 crash immediately with `ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.` The 2D experiment works fine with shape `(256, 2)`, which hid the bug completely.

**Root cause**: Only tested the 2D code path. Never ran a 10-line test with image-shaped tensors.

**Prevention**: Before building any training loop, test the library with the EXACT tensor shapes of every experiment:

```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D: OK
F, G = loss_fn(torch.randn(256, 2, requires_grad=True), torch.randn(256, 2))

# MNIST: CRASHES unless flattened
x = torch.randn(128, 1, 28, 28)
# loss_fn(x, torch.randn(128, 1, 28, 28))  # ValueError: expects (N,D) or (B,N,D)
x_flat = x.view(128, -1).requires_grad_(True)  # (128, 784)
F, G = loss_fn(x_flat, torch.randn(128, 784))  # ✅
```

**Fix applied**: `compute_velocity()` now flattens `(N, C, H, W) → (N, D)` before calling geomloss and reshapes the gradients back afterwards. The same pattern applies to all OT libraries (geomloss, POT, ott-jax).

---

### #2 — TrajectoryPool re-concatenation every step (MODERATE)

**What**: `torch.cat(self.x_pool, dim=0)` was called on the entire pool (512K entries) at every training step.

**Impact**: Each step spent ~0.5 s on concatenation vs ~0.05 s on the actual forward/backward pass. Training ran 10× slower than necessary.

**Root cause**: Didn't profile. The pool was built from a list of tensors and never consolidated.

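Even a crude timer around the sampling step would have exposed this. The sketch below assumes a synthetic CPU pool of 2,000 chunks of 256 2D points. That gives 512K rows, matching the pool size above, but the dimensionality is a toy one, so absolute timings will not match the figures above. `sample_recat` and `sample_finalized` are hypothetical names.

```python
import time

import torch

# Hypothetical pool: 2,000 chunks of 256 2D points (512K rows), kept on CPU.
chunks = [torch.randn(256, 2) for _ in range(2000)]


def time_per_call(fn, reps=20):
    """Average wall-clock seconds per call of fn."""
    start = time.perf_counter()
    for _ in range(reps):
        fn()
    return (time.perf_counter() - start) / reps


def sample_recat(batch_size=256):
    pool = torch.cat(chunks, dim=0)  # the bug: rebuilds the whole pool every step
    idx = torch.randint(0, pool.shape[0], (batch_size,))
    return pool[idx]


all_x = torch.cat(chunks, dim=0)  # the fix: concatenate once


def sample_finalized(batch_size=256):
    idx = torch.randint(0, all_x.shape[0], (batch_size,))
    return all_x[idx]


t_recat = time_per_call(sample_recat)
t_final = time_per_call(sample_finalized)
print(f"re-concatenate: {t_recat * 1e3:.2f} ms/step, finalized: {t_final * 1e3:.3f} ms/step")
```

Timing the sampling step on its own, separately from the forward/backward pass, is what points at the concatenation rather than the model.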
**Prevention**: Pre-concatenate once, right after the pool is built:

```python
import torch

class TrajectoryPool:
    # ... pool building appends tensor chunks to self.x_pool ...

    def finalize(self):
        self._all_x = torch.cat(self.x_pool, dim=0)  # concatenate exactly once
        self.x_pool = None  # free the list of chunks

    def sample(self, batch_size, device):
        idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
        return self._all_x[idx].to(device)  # per-step cost depends on batch_size, not pool size
```

---

### #3 — Incomplete experiment testing (CRITICAL)

**What**: Tested the 2D experiments thoroughly on CPU. Shipped MNIST/CIFAR completely untested.

**Impact**: The user's first GPU run (MNIST) crashed immediately on the geomloss bug. The second run (CIFAR-10) also crashed immediately, on the geomloss bug plus an OOM. Two Kaggle sessions were wasted.

**Root cause**: 2D success gave false confidence. Different experiment types exercise different code paths: 2D uses `(N, 2)` tensors, while images use `(N, C, H, W)`.

**Prevention**: Test EVERY experiment type with `--pool-batches 2 --train-iters 5` before declaring the code ready. This takes under 60 seconds on CPU.

---

### #4 — No checkpointing (MODERATE → CRITICAL at scale)

**What**: No intermediate checkpoints during training and no phase-level saves.

**Impact**: A full MNIST run takes ~7 hours on a T4, and Kaggle gives 9 hours per session. Any interruption (timeout, accidental Ctrl+C, OOM partway through) loses everything. CIFAR-10 cannot finish in one session without resume.

**Root cause**: Built for "run once and done". Didn't anticipate multi-session training.

**Prevention**: Checkpoint after every phase and every N steps within phases. Implement `--resume-phase`. Test that resume actually loads and skips correctly.

---

### #5 — UNet forward pass fragility (LOW-MODERATE)

**What**: `_get_num_res_blocks()` infers the block count by dividing the module list length by the number of levels. This is fragile: any change to what the list contains silently changes the inferred count.

**Prevention**: Store `self.num_res_blocks = num_res_blocks` at init and use it directly in `forward()`.

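The fix for #5 can be sketched with a hypothetical, heavily simplified encoder. Plain conv blocks stand in for the real residual blocks, and `TinyEncoder` and its defaults are illustrative only, not the reproduction's UNet. The point is that `forward()` indexes the module list using the count stored at init.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical, simplified UNet encoder illustrating the #5 fix.

    Plain conv blocks stand in for residual blocks; only the indexing matters here.
    """

    def __init__(self, in_ch=1, channels=(16, 32), num_res_blocks=2):
        super().__init__()
        self.num_res_blocks = num_res_blocks  # stored at init, never inferred
        self.num_levels = len(channels)
        blocks = []
        ch = in_ch
        for out_ch in channels:
            for _ in range(num_res_blocks):
                blocks.append(nn.Conv2d(ch, out_ch, kernel_size=3, padding=1))
                ch = out_ch
        self.blocks = nn.ModuleList(blocks)

    def forward(self, x):
        skips = []
        for level in range(self.num_levels):
            for i in range(self.num_res_blocks):
                x = torch.relu(self.blocks[level * self.num_res_blocks + i](x))
            skips.append(x)  # one skip connection per level
        return x, skips


model = TinyEncoder()
out, skips = model(torch.randn(2, 1, 28, 28))
print(out.shape, len(skips))  # torch.Size([2, 32, 28, 28]) 2
```

Suppose a later edit appends one extra module per level to `self.blocks`, such as an attention layer. An inferred count like `len(self.blocks) // self.num_levels` would jump from 4 // 2 = 2 to 6 // 2 = 3. The stored `num_res_blocks` would still be 2.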

---

### #6 — DataLoader batch size mismatch across phases (CRITICAL)

**What**: `DatasetLoader.sample_target()` lazily creates a DataLoader on its first call and caches it. Phase 1 calls it with `n=256`, so the DataLoader is built with `batch_size=256, drop_last=True`. Phase 2 calls it with `n=128`, but the cached DataLoader still yields batches of 256, which crashes:

```
RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0
```

**Impact**: Phase 2 (NSF) crashes immediately, even after Phase 1 completes successfully. The error message doesn't reveal the caching problem; it looks like a model shape issue.

**Root cause**: Lazy initialization without invalidation, a classic stale-cache bug.

**Fix applied**: Track `self._batch_size` and recreate the DataLoader whenever it changes:

```python
def sample_target(self, n, device="cpu"):
    # Rebuild the cached loader on first use or when the requested batch size changes.
    if not hasattr(self, "_loader") or self._batch_size != n:
        self._batch_size = n
        self._loader = get_image_dataloader(self.dataset_name, batch_size=n)
        self._iter = iter(self._loader)
    # ... draw the next batch from self._iter and move it to `device` ...
```

---

### #7 — CLI flag not overriding all phases (LOW)

**What**: `--train-iters` overrode `nsgf_training.num_iterations` and `nsf_training.num_iterations` but missed `time_predictor.num_iterations` (default 40,000).

**Impact**: A smoke test with `--train-iters 5` completes Phases 1 and 2 in seconds. It then appears to hang for hours while Phase 3 runs its full 40,000 iterations.

**Prevention**: Grep the config for ALL instances of the parameter: `grep -n num_iterations config.yaml`. That finds 3, so the CLI must override all 3.

---

### #8 — CIFAR-10 Sinkhorn OOM on T4 (CRITICAL)

**What**: The paper uses `sinkhorn.batch_size=128` for CIFAR-10. The Sinkhorn `tensorized` backend computes a 128×128 cost matrix over 3072-dim vectors (flattened 3×32×32). The calls use `potentials=True` + `autograd.grad`. There are 2 calls per flow step and 5 steps per batch, so each pool batch makes 10 Sinkhorn calls. Total: 8+ GB just for Sinkhorn, plus the 38M-param UNet → OOM on a 16 GB T4.

**Impact**: CIFAR-10 crashes during pool building, before training even starts.

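This OOM only showed up on the user's hardware. One cheap guard is to run a single representative pool batch on the target GPU and read PyTorch's peak-allocation counter before launching the full job. The sketch below uses the standard `torch.cuda` memory APIs. `preflight_peak_mb`, `fits_on_device`, the 0.8 headroom factor, and the `step_fn` callable are illustrative assumptions, not part of the reproduction code.

```python
import torch


def preflight_peak_mb(step_fn, device=0):
    """Run one representative step and return peak allocated CUDA memory in MB.

    Returns None when CUDA is unavailable (e.g. a CPU sandbox), so the check is a no-op there.
    """
    if not torch.cuda.is_available():
        return None
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    step_fn()
    torch.cuda.synchronize(device)
    return torch.cuda.max_memory_allocated(device) / 2**20


def fits_on_device(peak_mb, device=0, headroom=0.8):
    """Check a measured peak against a fraction of the device's total memory."""
    total_mb = torch.cuda.get_device_properties(device).total_memory / 2**20
    return peak_mb <= headroom * total_mb
```

Usage might look like `preflight_peak_mb(lambda: build_pool_batch())`, where `build_pool_batch` is a hypothetical stand-in for one pool-building batch.

- If the batch does not fit at all, the call itself raises a CUDA out-of-memory error within seconds.
- If it does fit, the headroom factor leaves room for allocator overhead and fragmentation, which `max_memory_allocated` does not count.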
**Root cause**: Used the paper's hyperparameters verbatim without estimating VRAM for the target hardware. The paper authors likely used an A100 80GB.

**VRAM math that should have been done upfront**:

| Component | Approx VRAM |
|-----------|-------------|
| Sinkhorn call (128×3072, tensorized) | ~600 MB |
| × 10 calls per pool batch (with autograd) | ~6+ GB peak |
| UNet (38M params, fp32) | ~150 MB |
| Gradients + optimizer states | ~450 MB |
| **Total** | **~7+ GB** (doesn't account for fragmentation) |

**Fix applied**: Reduced `sinkhorn.batch_size` 128→32 and increased `pool.num_batches` 2500→10000. Total pool entries are unchanged at 1.6M: 128 × 2500 = 32 × 10000 = 320K samples, × 5 flow steps. Added a `--sinkhorn-batch` CLI flag.

---

### #9 — No GPU memory freed between phases (MODERATE)

**What**: After pool building finishes, the CUDA allocations from the Sinkhorn computation graphs stay cached by PyTorch's allocator. Training starts with reduced available VRAM.

**Fix applied**: `torch.cuda.empty_cache()` after pool building + `del pool` after finalization.

---

### #10 — Multi-GPU assumption (LOW)

**What**: The user has T4×2 on Kaggle. The code is single-GPU, so the second GPU sits idle.

**Prevention**: Document explicitly: "Single-GPU only. T4×2 wastes the second GPU. Use a single T4 instead, or add DDP."

---

## Principles

Derived from the mistakes above. Each traces back to one or more specific failures.

1. **Read the appendix first.** The appendix is the recipe; the main paper is the story. *(from: all mistakes; the hyperparameters were in the appendix)*
2. **Test the boundaries, not just the happy path.** The bug is always in the path you didn't test. *(from: #1, #3)*
3. **Library APIs are opaque until tested.** Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script. *(from: #1)*
4. **Pre-concatenate, don't re-concatenate.** Any data structure built once and sampled many times should be finalized into a single tensor. *(from: #2)*
5. **The user's time is more expensive than your time.** A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. *(from: #1, #3, #8)*
6. **Flatten for OT libraries.** geomloss, POT, and ott-jax all expect `(N, D)` point clouds, so images must be flattened. This is the most common gotcha in OT-based generative models. *(from: #1)*
7. **Store training state on CPU, compute on GPU.** Keep pools and replay buffers on CPU; move only the minibatch to GPU. *(from: #8, #9)*
8. **Multi-phase training = multiple separate trainers.** Each phase gets its own optimizer. Put the previous phase's model in `eval()` mode. *(from: #6)*
9. **Shared objects across phases are landmines.** Cached DataLoaders, iterators, batch sizes: any phase-specific parameter can silently break later phases. *(from: #6)*
10. **CLI overrides must be exhaustive.** If a parameter appears N times in the config, the CLI must override all N. *(from: #7)*
11. **Paper hyperparameters assume paper hardware.** Re-derive batch sizes from VRAM constraints, keeping the total number of samples seen constant. *(from: #8)*
12. **Estimate VRAM before running, not after OOM.** Sinkhorn: O(N² × D). Model: params × 4 bytes for fp32 weights, plus about 3× that again for gradients and optimizer states (as in the #8 table). Write it down before the first GPU run. *(from: #8)*
13. **Checkpoint at phase boundaries, not just step boundaries.** Phase-level checkpoints enable `--resume-phase`. Step-level checkpoints within long phases are a bonus. *(from: #4)*
14. **Free GPU memory between phases.** Between phases with different memory patterns, `del` large objects first and then call `torch.cuda.empty_cache()`. The allocator can only release cached blocks that nothing still references. *(from: #9)*
15. **Document what your code does NOT support.** Single-GPU? No mixed precision? Say so explicitly. *(from: #10)*

---

## Sinkhorn VRAM Reference Table

For quick VRAM estimation during paper reproduction planning. Assumes the `tensorized` backend, fp32, and a single `SamplesLoss` call with `potentials=True`.


| N (batch) | D (dim) | Example | Approx VRAM/call |
|-----------|---------|---------|------------------|
| 256 | 2 | 2D points | ~1 MB |
| 256 | 784 | MNIST 28×28 | ~200 MB |
| 128 | 3072 | CIFAR 3×32×32 | ~600 MB |
| 32 | 3072 | CIFAR (T4-safe) | ~40 MB |

Pool building makes ~10 calls per batch (2 potentials × 5 flow steps), so multiply accordingly.

---

## Timeline of Bugs

| When | What broke | How discovered | Sessions wasted |
|------|-----------|----------------|-----------------|
| Initial ship | geomloss shape (MNIST/CIFAR) | User's first GPU run | 1 |
| After geomloss fix | DataLoader batch mismatch (Phase 2) | My CPU test | 0 (caught in sandbox) |
| After DataLoader fix | `--train-iters` missed Phase 3 | My CPU test | 0 (caught in sandbox) |
| User's full MNIST run | No checkpointing, interrupted | User's Kaggle timeout | 1 |
| User's CIFAR run | Sinkhorn OOM on T4 | User's Kaggle log | 1 |
| After OOM fix | No `empty_cache()` between phases | Analysis of VRAM patterns | 0 (preventive) |

**Total Kaggle sessions wasted by user: 3.** Bugs caught in the sandbox wasted 0 sessions, while bugs caught on user hardware wasted 3. That contrast reinforces principle #5: test everything before shipping.
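Here is a minimal sketch of that rule applied to the code path behind bug #1. It is a shape-only smoke test that pushes each experiment's exact sample shape through the flatten and unflatten steps `compute_velocity()` performs around geomloss. `flatten_for_ot`, `unflatten_like`, and `EXPERIMENT_SHAPES` are hypothetical names. A zero tensor stands in for the potential's gradient, so the check runs without geomloss or a GPU.

```python
import torch


def flatten_for_ot(x):
    """Flatten (N, C, H, W) or (N, D) samples into the (N, D) point cloud OT libraries expect."""
    return x.reshape(x.shape[0], -1)


def unflatten_like(g, like):
    """Reshape a flat (N, D) gradient back to the original sample shape."""
    return g.reshape(like.shape)


# Exact per-experiment shapes: 2D points, MNIST, and CIFAR-10 at the T4-safe batch of 32.
EXPERIMENT_SHAPES = {
    "2d": (256, 2),
    "mnist": (128, 1, 28, 28),
    "cifar10": (32, 3, 32, 32),
}


def smoke_test_shapes():
    results = {}
    for name, shape in EXPERIMENT_SHAPES.items():
        x = torch.randn(*shape)
        flat = flatten_for_ot(x)
        assert flat.dim() == 2, f"{name}: OT input must be (N, D), got {tuple(flat.shape)}"
        grad_flat = torch.zeros_like(flat)  # stand-in for the gradient of a Sinkhorn potential
        back = unflatten_like(grad_flat, x)
        assert back.shape == x.shape, f"{name}: flatten/unflatten round trip changed the shape"
        results[name] = tuple(flat.shape)
    return results


print(smoke_test_shapes())
# {'2d': (256, 2), 'mnist': (128, 784), 'cifar10': (32, 3072)}
```

A fuller version could also call the real `SamplesLoss` on `flat` inside the loop. Either way, every experiment's shape goes through the check, not just `(256, 2)`.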