LEARNING.md – Lessons from NSGF++ Reproduction
All concrete mistakes, war stories, code examples, and derived principles from reproducing arXiv:2401.14069 on Kaggle T4 GPUs. Rules and procedures live in SKILL.md.
Mistake Catalog
#1 – geomloss tensor shape bug (CRITICAL)
What: SamplesLoss in geomloss requires (N, D) or (B, N, D) tensors. Image experiments passed (N, C, H, W).
Impact: MNIST and CIFAR-10 crash immediately with ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors. The 2D experiment works fine with shape (256, 2), hiding the bug completely.
Root cause: Only tested the 2D code path. Never ran a 10-line test with image-shaped tensors.
Prevention: Before building any training loop, test the library with EXACT tensor shapes for every experiment:
```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D: OK
F, G = loss_fn(torch.randn(256, 2, requires_grad=True), torch.randn(256, 2))

# MNIST: CRASHES unless flattened
x = torch.randn(128, 1, 28, 28)
x_flat = x.view(128, -1).requires_grad_(True)  # (128, 784)
F, G = loss_fn(x_flat, torch.randn(128, 784))  # works
```
Fix applied: compute_velocity() now flattens (N,C,H,W) → (N,D) before calling geomloss and reshapes the gradients back afterwards. This pattern applies to all OT libraries (geomloss, POT, ott-jax).
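A minimal sketch of that flatten/reshape pattern (the compute_velocity body below is illustrative, not the repo's verbatim code, and the sign/scale of the returned field depends on the method):
```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5)

def compute_velocity(x_img, y_img):
    """x_img, y_img: (N, C, H, W) batches; returns a field shaped like x_img."""
    n = x_img.shape[0]
    x_flat = x_img.reshape(n, -1).detach().requires_grad_(True)  # (N, D)
    y_flat = y_img.reshape(y_img.shape[0], -1)                   # (M, D)
    loss = loss_fn(x_flat, y_flat)             # OT libraries want point clouds
    (grad,) = torch.autograd.grad(loss, x_flat)
    return grad.reshape(x_img.shape)           # back to (N, C, H, W)
```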
#2 – TrajectoryPool re-concatenation every step (MODERATE)
What: torch.cat(self.x_pool, dim=0) called on the entire pool (512K entries) every single training step.
Impact: ~0.5s per step wasted on concatenation vs ~0.05s for the actual forward/backward. Training 10× slower than necessary.
Root cause: Didn't profile. The pool was built from a list of tensors and never consolidated.
Prevention: Pre-concatenate once after pool building:
```python
def finalize(self):
    # One-time consolidation of the list of pool tensors into a single tensor.
    self._all_x = torch.cat(self.x_pool, dim=0)
    self.x_pool = None  # free the list memory

def sample(self, batch_size, device):
    idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
    return self._all_x[idx].to(device)  # cheap indexing, no per-step re-concatenation
```
#3 – Incomplete experiment testing (CRITICAL)
What: Tested 2D experiments thoroughly on CPU. Shipped MNIST/CIFAR completely untested.
Impact: User's first GPU run of MNIST crashes immediately (geomloss bug). User's second run of CIFAR crashes immediately (also geomloss bug + OOM). Two Kaggle sessions wasted.
Root cause: 2D success gave false confidence. Different experiment types exercise different code paths: 2D uses (N, 2) tensors, images use (N, C, H, W).
Prevention: Test EVERY experiment type with --pool-batches 2 --train-iters 5 before declaring code ready. This takes <60 seconds on CPU.
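A hedged sketch of that pre-ship smoke test, assuming a train.py entry point with an --experiment flag (the script name and that flag are assumptions; --pool-batches and --train-iters are the flags above):
```python
import subprocess
import sys

# Run every experiment type for a handful of iterations on CPU before shipping.
for exp in ["2d", "mnist", "cifar10"]:
    cmd = [sys.executable, "train.py", "--experiment", exp,
           "--pool-batches", "2", "--train-iters", "5"]
    print("smoke test:", " ".join(cmd))
    subprocess.run(cmd, check=True)  # fail fast on the first broken code path
```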
#4 – No checkpointing (MODERATE → CRITICAL at scale)
What: No intermediate checkpoints during training. No phase-level saves.
Impact: MNIST full run is ~7 hours on T4. Kaggle gives 9 hours per session. Any interruption (timeout, accidental Ctrl+C, OOM partway through) loses everything. CIFAR-10 is impossible in one session without resume.
Root cause: Built for "run once and done"; didn't anticipate multi-session training.
Prevention: Checkpoint after every phase + every N steps within phases. Implement --resume-phase. Test that resume actually loads and skips correctly.
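A minimal sketch of phase-level checkpointing, assuming one model/optimizer pair per phase (file layout and dict keys are illustrative):
```python
import os
import torch

def save_phase_checkpoint(phase, step, model, optimizer, directory="checkpoints"):
    os.makedirs(directory, exist_ok=True)
    torch.save({"phase": phase, "step": step,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict()},
               f"{directory}/phase{phase}_step{step}.pt")

def load_phase_checkpoint(path, model, optimizer):
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["phase"], ckpt["step"]  # lets --resume-phase skip completed phases
```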
#5 – UNet forward pass fragility (LOW-MODERATE)
What: _get_num_res_blocks() infers block count by dividing module list length by number of levels. Fragile if architecture varies.
Prevention: Store self.num_res_blocks = num_res_blocks at init. Use it directly in forward().
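A toy sketch of the robust version (the class and attribute names are illustrative, not the repo's UNet):
```python
import torch.nn as nn

class TinyUNetStage(nn.Module):
    def __init__(self, channels, num_res_blocks, num_levels):
        super().__init__()
        self.num_res_blocks = num_res_blocks  # stored once at construction
        self.blocks = nn.ModuleList(
            nn.Conv2d(channels, channels, 3, padding=1)
            for _ in range(num_res_blocks * num_levels)
        )

    def forward(self, x):
        # Group blocks per level using the stored count instead of re-deriving it
        # from len(self.blocks) divided by the number of levels.
        for start in range(0, len(self.blocks), self.num_res_blocks):
            for block in self.blocks[start:start + self.num_res_blocks]:
                x = x + block(x)
        return x
```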
#6 – DataLoader batch size mismatch across phases (CRITICAL)
What: DatasetLoader.sample_target() lazily creates a DataLoader on first call and caches it. Phase 1 calls it with n=256, so the DataLoader is built with batch_size=256, drop_last=True. Phase 2 calls it with n=128, but the cached DataLoader still yields 256-sample batches and the run crashes:
RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0
Impact: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully. The error message doesn't reveal the caching problem; it looks like a model shape issue.
Root cause: Lazy initialization without invalidation. Classic stale cache bug.
Fix applied: Track the cached batch size (self._batch_size below) and recreate the DataLoader when it changes:
```python
def sample_target(self, n, device="cpu"):
    # Recreate the cached DataLoader whenever the requested batch size changes.
    if not hasattr(self, "_loader") or self._batch_size != n:
        self._batch_size = n
        self._loader = get_image_dataloader(self.dataset_name, batch_size=n)
        self._iter = iter(self._loader)
    batch, _ = next(self._iter)  # (restart the iterator on StopIteration in real code)
    return batch.to(device)
```
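A small regression test for this failure mode; the import path and the "mnist" argument are assumptions:
```python
# from nsgf.data import DatasetLoader   # hypothetical import path

def test_sample_target_tracks_requested_batch_size():
    loader = DatasetLoader("mnist")
    assert loader.sample_target(256).shape[0] == 256
    # With the stale cached DataLoader, this second call still returned 256 samples.
    assert loader.sample_target(128).shape[0] == 128
```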
#7 – CLI flag not overriding all phases (LOW)
What: --train-iters overrode nsgf_training.num_iterations and nsf_training.num_iterations but missed time_predictor.num_iterations (default 40,000).
Impact: A smoke test with --train-iters 5 completes Phases 1 and 2 in seconds, then hangs for hours on Phase 3.
Prevention: Grep the config for ALL instances of the parameter: grep -n num_iterations config.yaml. It appears 3 times; all 3 must be overridden.
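A sketch of an exhaustive override, assuming a nested dict loaded from config.yaml with the three sections named above (the helper and its name are hypothetical):
```python
def apply_train_iters_override(config, train_iters):
    # Touch every num_iterations the grep found, not just the first two.
    if train_iters is None:
        return config
    for section in ("nsgf_training", "nsf_training", "time_predictor"):
        if "num_iterations" in config.get(section, {}):
            config[section]["num_iterations"] = train_iters
    return config
```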
#8 – CIFAR-10 Sinkhorn OOM on T4 (CRITICAL)
What: The paper uses sinkhorn.batch_size=128 for CIFAR-10. The Sinkhorn tensorized backend builds a 128×128 cost matrix over 3072-dim vectors (flattened 3×32×32). With potentials=True + autograd.grad, and 2 calls per flow step × 5 steps per batch = 10 Sinkhorn calls per pool batch. Total: 8+ GB for Sinkhorn alone, plus the 38M-param UNet, which OOMs a 16 GB T4.
Impact: CIFAR-10 crashes during pool building, before training even starts.
Root cause: Used paper hyperparameters verbatim without estimating VRAM for target hardware. Paper authors likely used A100 80GB.
VRAM math that should have been done upfront:
| Component | Approx VRAM |
|---|---|
| Sinkhorn call (128×3072, tensorized) | ~600 MB |
| × 10 calls per pool batch (with autograd) | ~6+ GB peak |
| UNet (38M params, fp32) | ~150 MB |
| Gradients + optimizer states | ~450 MB |
| Total | ~7+ GB (doesn't account for fragmentation) |
Fix applied: Reduce sinkhorn.batch_size from 128 to 32 and increase pool.num_batches from 2500 to 10000, so total pool entries are unchanged (1.6M). Added a --sinkhorn-batch CLI flag.
#9 – No GPU memory freed between phases (MODERATE)
What: After pool building finishes, Sinkhorn computation graph CUDA allocations stay cached. Training starts with reduced available VRAM.
Fix applied: torch.cuda.empty_cache() after pool building + del pool after finalization.
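A minimal sketch of that cleanup, assuming it runs right after pool building and before training starts (the helper name is ours):
```python
import gc
import torch

def end_of_phase_cleanup():
    gc.collect()  # reclaim dropped Python objects (e.g. the pool's consumed list)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()              # return cached CUDA blocks to the driver
        torch.cuda.reset_peak_memory_stats()  # optional: fresh peak stats for the next phase

# Usage (illustrative): del any pool-building temporaries first, then call
# end_of_phase_cleanup() before building the trainer for the next phase.
```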
#10 – Multi-GPU assumption (LOW)
What: User has T4×2 on Kaggle. Code is single-GPU. The second GPU sits idle.
Prevention: Document explicitly: "Single-GPU only. T4×2 wastes the second GPU. Use a single T4 instead, or add DDP."
Principles
Derived from the mistakes above. Each traces back to one or more specific failures.
Read the appendix first. The appendix is the recipe; the main paper is the story. (from: all mistakes; the hyperparameters were in the appendix)
Test the boundaries, not just the happy path. The bug is always in the path you didn't test. (from: #1, #3)
Library APIs are opaque until tested. Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script. (from: #1)
Pre-concatenate, don't re-concatenate. Any data structure built once and sampled many times should be finalized into a single tensor. (from: #2)
The user's time is more expensive than your time. A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. (from: #1, #3, #8)
Flatten for OT libraries. geomloss, POT, and ott-jax all expect (N, D) point clouds; images must be flattened. The #1 gotcha in OT-based generative models. (from: #1)
Store training state on CPU, compute on GPU. Pools and replay buffers live on CPU; only the minibatch goes to the GPU. (from: #8, #9)
Multi-phase training = multiple separate trainers. Each phase gets its own optimizer. The previous phase's model goes to eval(). (from: #6)
Shared objects across phases are landmines. Cached DataLoaders, iterators, batch sizes: any phase-specific parameter can silently break later phases. (from: #6)
CLI overrides must be exhaustive. N copies of a parameter in config β CLI must touch all N. (from: #7)
Paper hyperparameters assume paper hardware. Re-derive batch sizes from VRAM constraints. Keep total samples seen constant. (from: #8)
Estimate VRAM before running, not after OOM. Sinkhorn: O(N² × D). Model: params × 4 bytes × 3. Write it down before the first GPU run. (from: #8)
Checkpoint at phase boundaries, not just step boundaries. Phase-level checkpoints enable --resume-phase. Step-level checkpoints within long phases are a bonus. (from: #4)
Free GPU memory between phases. torch.cuda.empty_cache() + del large objects between phases with different memory patterns. (from: #9)
Document what your code does NOT support. Single-GPU? No mixed precision? Say so explicitly. (from: #10)
Sinkhorn VRAM Reference Table
For quick VRAM estimation during paper reproduction planning. tensorized backend, fp32, single SamplesLoss call with potentials=True.
| N (batch) | D (dim) | Example | Approx VRAM/call |
|---|---|---|---|
| 256 | 2 | 2D points | ~1 MB |
| 256 | 784 | MNIST 28×28 | ~200 MB |
| 128 | 3072 | CIFAR 3×32×32 | ~600 MB |
| 32 | 3072 | CIFAR (T4-safe) | ~40 MB |
Pool building does ~10 calls per batch (2 potentials × 5 flow steps). Multiply accordingly.
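For quick planning, the dominant term scales as N² × D (the (N, N, D)-sized intermediates the tensorized backend can materialize under autograd). The constant factors below are rough assumptions, so treat the outputs as order-of-magnitude only:
```python
def sinkhorn_vram_mb(n, d, fp_bytes=4, buffers=3):
    # buffers ~ how many (N, N, D)-sized intermediates coexist; a rough assumption.
    return buffers * n * n * d * fp_bytes / 2**20

def model_vram_mb(params, fp_bytes=4, copies=3):
    # weights + gradients + optimizer state, per the "params x 4 bytes x 3" rule above
    return copies * params * fp_bytes / 2**20

print(sinkhorn_vram_mb(128, 3072))  # ~600 MB: CIFAR-sized call
print(sinkhorn_vram_mb(32, 3072))   # ~40 MB: T4-safe setting
print(model_vram_mb(38e6))          # ~435 MB following the rule of thumb above
```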
Timeline of Bugs
| When | What broke | How discovered | Sessions wasted |
|---|---|---|---|
| Initial ship | geomloss shape (MNIST/CIFAR) | User's first GPU run | 1 |
| After geomloss fix | DataLoader batch mismatch (Phase 2) | My CPU test | 0 (caught in sandbox) |
| After DataLoader fix | --train-iters missed Phase 3 | My CPU test | 0 (caught in sandbox) |
| User's full MNIST run | No checkpointing, interrupted | User's Kaggle timeout | 1 |
| User's CIFAR run | Sinkhorn OOM on T4 | User's Kaggle log | 1 |
| After OOM fix | No empty_cache() between phases | Analysis of VRAM patterns | 0 (preventive) |
Total Kaggle sessions wasted by the user: 3. The bugs caught in the sandbox (0 sessions wasted) versus the bugs caught on the user's hardware (3 sessions) reinforce the principle above: the user's time is more expensive than yours, so test everything before shipping.