LEARNING.md: Lessons from NSGF++ Reproduction

All concrete mistakes, war stories, code examples, and derived principles from reproducing arXiv:2401.14069 on Kaggle T4 GPUs. Rules and procedures live in SKILL.md.


Mistake Catalog

#1: geomloss tensor shape bug (CRITICAL)

What: SamplesLoss in geomloss requires (N, D) or (B, N, D) tensors. Image experiments passed (N, C, H, W).

Impact: MNIST and CIFAR-10 crash immediately with ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors. The 2D experiment works fine with shape (256, 2), hiding the bug completely.

Root cause: Only tested the 2D code path. Never ran a 10-line test with image-shaped tensors.

Prevention: Before building any training loop, test the library with EXACT tensor shapes for every experiment:

import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D: OK, inputs are already (N, D)
F, G = loss_fn(torch.randn(256, 2, requires_grad=True), torch.randn(256, 2))

# MNIST: crashes unless flattened to (N, D) first
x = torch.randn(128, 1, 28, 28)
x_flat = x.view(128, -1).requires_grad_(True)  # (128, 784)
F, G = loss_fn(x_flat, torch.randn(128, 784))  # works

Fix applied: compute_velocity() now flattens (N,C,H,W) → (N,D) before geomloss and reshapes gradients back afterwards. This pattern applies to all OT libraries (geomloss, POT, ott-jax).
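The flatten-then-reshape pattern can be sketched as follows. This is a minimal illustration, not the project's compute_velocity(): the helper name grad_in_image_shape and the quadratic toy_cost are hypothetical, and toy_cost stands in for the geomloss call so the snippet runs with only PyTorch installed.

```python
import torch


def toy_cost(x_flat: torch.Tensor, y_flat: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for an OT loss; accepts only (N, D) inputs."""
    if x_flat.dim() != 2 or y_flat.dim() != 2:
        raise ValueError("expected (N, D) tensors")
    return 0.5 * ((x_flat - y_flat) ** 2).sum()


def grad_in_image_shape(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Flatten (N, C, H, W) to (N, D), differentiate, reshape back."""
    n = x.shape[0]
    x_flat = x.detach().reshape(n, -1).requires_grad_(True)
    y_flat = y.detach().reshape(n, -1)
    cost = toy_cost(x_flat, y_flat)
    (grad_flat,) = torch.autograd.grad(cost, x_flat)
    return grad_flat.reshape(x.shape)  # same shape as the input images


x = torch.randn(8, 1, 28, 28)
y = torch.randn(8, 1, 28, 28)
g = grad_in_image_shape(x, y)
print(g.shape)
```

Keeping the reshape inside one helper means the training loop only ever sees image-shaped tensors, so the (N, D) requirement cannot leak into other code paths.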


#2: TrajectoryPool re-concatenation every step (MODERATE)

What: torch.cat(self.x_pool, dim=0) called on the entire pool (512K entries) every single training step.

Impact: ~0.5s per step wasted on concatenation vs ~0.05s for the actual forward/backward. Training roughly 10× slower than necessary.

Root cause: Didn't profile. The pool was built from a list of tensors and never consolidated.

Prevention: Pre-concatenate once after pool building:

def finalize(self):
    # Concatenate once after pool building instead of on every sample() call
    self._all_x = torch.cat(self.x_pool, dim=0)
    self.x_pool = None  # free list memory

def sample(self, batch_size, device):
    idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
    return self._all_x[idx].to(device)  # cost scales with batch_size, not pool size

#3: Incomplete experiment testing (CRITICAL)

What: Tested 2D experiments thoroughly on CPU. Shipped MNIST/CIFAR completely untested.

Impact: User's first GPU run of MNIST crashes immediately (geomloss bug). User's second run of CIFAR crashes immediately (also geomloss bug + OOM). Two Kaggle sessions wasted.

Root cause: 2D success gave false confidence. Different experiment types exercise different code paths: 2D uses (N, 2) tensors, images use (N, C, H, W).

Prevention: Test EVERY experiment type with --pool-batches 2 --train-iters 5 before declaring code ready. This takes <60 seconds on CPU.
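One way to make this check routine is to generate the smoke-test command for every experiment type from a single list. The sketch below only builds the command lines. The script name train.py, the --experiment flag, and the experiment names are assumptions for illustration; --pool-batches and --train-iters are the flags named above.

```python
import shlex
import sys
from typing import List

# Assumed experiment identifiers; replace with the names the project uses.
EXPERIMENTS = ["2d", "mnist", "cifar10"]


def smoke_command(experiment: str) -> List[str]:
    """Build the argv for a tiny end-to-end run of one experiment type."""
    return [
        sys.executable, "train.py",      # hypothetical entry point
        "--experiment", experiment,      # hypothetical flag
        "--pool-batches", "2",
        "--train-iters", "5",
    ]


commands = [smoke_command(name) for name in EXPERIMENTS]
for argv in commands:
    print(shlex.join(argv[1:]))
```

Each argv can be passed to subprocess.run(argv, check=True), which raises on a non-zero exit code, so a crash in any experiment type fails the whole check.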


#4: No checkpointing (MODERATE → CRITICAL at scale)

What: No intermediate checkpoints during training. No phase-level saves.

Impact: MNIST full run is ~7 hours on T4. Kaggle gives 9 hours per session. Any interruption (timeout, accidental Ctrl+C, OOM partway through) loses everything. CIFAR-10 is impossible in one session without resume.

Root cause: Built for "run once and done" β€” didn't anticipate multi-session training.

Prevention: Checkpoint after every phase + every N steps within phases. Implement --resume-phase. Test that resume actually loads and skips correctly.
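The resume logic can be sketched independently of the model code. The example below is an illustration under assumptions: the phase names, the state.json marker file, and the helpers mark_done and phases_to_run are hypothetical, and a real run would also save model and optimizer state at each boundary.

```python
import json
import tempfile
from pathlib import Path
from typing import List, Optional

# Assumed phase order, named after the three training stages.
PHASES = ["nsgf", "nsf", "time_predictor"]


def mark_done(ckpt_dir: Path, phase: str) -> None:
    """Record a finished phase in a small JSON marker file."""
    marker = ckpt_dir / "state.json"
    done = json.loads(marker.read_text()) if marker.exists() else []
    if phase not in done:
        done.append(phase)
    marker.write_text(json.dumps(done))


def phases_to_run(ckpt_dir: Path, resume_phase: Optional[str] = None) -> List[str]:
    """Return the phases still to run, honouring an explicit resume point."""
    if resume_phase is not None:
        return PHASES[PHASES.index(resume_phase):]
    marker = ckpt_dir / "state.json"
    done = json.loads(marker.read_text()) if marker.exists() else []
    return [p for p in PHASES if p not in done]


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp)
    mark_done(ckpt_dir, "nsgf")                # Phase 1 finished, then interrupted
    remaining = phases_to_run(ckpt_dir)        # skips the finished phase
    explicit = phases_to_run(ckpt_dir, "nsf")  # same result via an explicit resume point
    print(remaining, explicit)
```

Testing this skip logic on its own is cheap, which addresses the last point above: a resume path that has never been exercised is likely to fail the first time it is needed.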


#5: UNet forward pass fragility (LOW-MODERATE)

What: _get_num_res_blocks() infers block count by dividing module list length by number of levels. Fragile if architecture varies.

Prevention: Store self.num_res_blocks = num_res_blocks at init. Use it directly in forward().


#6: DataLoader batch size mismatch across phases (CRITICAL)

What: DatasetLoader.sample_target() lazily creates a DataLoader on first call, caching it. Phase 1 calls with n=256 → DataLoader has batch_size=256, drop_last=True. Phase 2 calls with n=128 → cached DataLoader still yields 256 → crash:

RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0

Impact: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully. The error message doesn't reveal the caching problem β€” it looks like a model shape issue.

Root cause: Lazy initialization without invalidation. Classic stale cache bug.

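Fix applied: Track the batch size the cached DataLoader was built with and recreate the DataLoader when the requested size changes:

def sample_target(self, n, device="cpu"):
    # Rebuild the cached loader whenever the requested batch size changes
    if not hasattr(self, "_loader") or self._batch_size != n:
        self._batch_size = n
        self._loader = get_image_dataloader(self.dataset_name, batch_size=n)
        self._iter = iter(self._loader)
    # ... then draw the next batch from self._iter as before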

#7: CLI flag not overriding all phases (LOW)

What: --train-iters overrode nsgf_training.num_iterations and nsf_training.num_iterations but missed time_predictor.num_iterations (default 40,000).

Impact: Smoke test with --train-iters 5 completes Phase 1 and 2 in seconds, then hangs for hours on Phase 3.

Prevention: Grep config for ALL instances of the parameter: grep -n num_iterations config.yaml. Found 3 β€” must override all 3.
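The same guarantee can be enforced in code rather than by grep. The sketch below walks a nested config and sets every occurrence of a key. The function name override_everywhere is hypothetical, the section names come from the text, and the first two iteration counts are placeholders rather than the project's values.

```python
def override_everywhere(config: dict, key: str, value) -> int:
    """Set every occurrence of `key` in a nested dict; return how many were changed."""
    count = 0
    for name, item in config.items():
        if name == key:
            config[name] = value
            count += 1
        elif isinstance(item, dict):
            count += override_everywhere(item, key, value)
    return count


# Illustrative config: only the 40000 default is taken from the text.
config = {
    "nsgf_training": {"num_iterations": 20000},   # placeholder value
    "nsf_training": {"num_iterations": 20000},    # placeholder value
    "time_predictor": {"num_iterations": 40000},
}
changed = override_everywhere(config, "num_iterations", 5)
print(changed, config)
```

Returning the count lets the CLI assert that the number of overridden entries matches the number expected, so a newly added phase cannot be missed silently.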


#8: CIFAR-10 Sinkhorn OOM on T4 (CRITICAL)

What: The paper uses sinkhorn.batch_size=128 for CIFAR-10. The Sinkhorn tensorized backend computes a 128×128 cost matrix over 3072-dim vectors (flattened 3×32×32), with potentials=True and autograd.grad on top. At 2 calls per flow step × 5 steps per batch, that is 10 Sinkhorn calls per pool batch. Total: 8+ GB just for Sinkhorn, plus the 38M-param UNet → OOM on T4 16GB.

Impact: CIFAR-10 crashes during pool building, before training even starts.

Root cause: Used paper hyperparameters verbatim without estimating VRAM for target hardware. Paper authors likely used A100 80GB.

VRAM math that should have been done upfront:

| Component | Approx VRAM |
|---|---|
| Sinkhorn call (128×3072, tensorized) | ~600 MB |
| × 10 calls per pool batch (with autograd) | ~6+ GB peak |
| UNet (38M params, fp32) | ~150 MB |
| Gradients + optimizer states | ~450 MB |
| Total | ~7+ GB (doesn't account for fragmentation) |

Fix applied: Reduce sinkhorn.batch_size 128 → 32, increase pool.num_batches 2500 → 10000. Total pool entries unchanged (1.6M). Added --sinkhorn-batch CLI flag.
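The "entries unchanged" claim is easy to verify with a few lines of arithmetic. The sketch assumes each of the 5 flow steps of each pool batch contributes one entry per sample, which is consistent with the 1.6M figure; the function name pool_entries is hypothetical.

```python
FLOW_STEPS = 5  # flow steps per pool batch, as stated in the text


def pool_entries(sinkhorn_batch: int, num_batches: int, flow_steps: int = FLOW_STEPS) -> int:
    """Total pool entries under the one-entry-per-sample-per-step assumption."""
    return sinkhorn_batch * num_batches * flow_steps


paper = pool_entries(128, 2500)     # paper hyperparameters
t4_safe = pool_entries(32, 10000)   # reduced batch, more batches
print(paper, t4_safe)
```

Both settings give 1,600,000 entries, so the smaller batch changes peak memory per Sinkhorn call without changing how much data the pool holds.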


#9: No GPU memory freed between phases (MODERATE)

What: After pool building finishes, Sinkhorn computation graph CUDA allocations stay cached. Training starts with reduced available VRAM.

Fix applied: torch.cuda.empty_cache() after pool building + del pool after finalization.
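The order of operations matters here: torch.cuda.empty_cache() cannot release memory that live tensors still reference, so references have to be dropped first. A minimal sketch, with a hypothetical helper name and a small CPU tensor standing in for the pool-building intermediates so it also runs without a GPU:

```python
import gc

import torch


def release_cached_gpu_memory() -> dict:
    """Collect garbage, then return cached CUDA blocks to the driver."""
    gc.collect()  # drop unreachable Python objects that still hold tensors
    if not torch.cuda.is_available():
        return {"allocated": 0, "reserved": 0}
    torch.cuda.empty_cache()  # frees only blocks no live tensor is using
    return {
        "allocated": torch.cuda.memory_allocated(),
        "reserved": torch.cuda.memory_reserved(),
    }


scratch = torch.zeros(1024, 1024)  # stands in for pool-building intermediates
del scratch                        # drop the reference first
stats = release_cached_gpu_memory()
print(stats)
```

Logging the returned numbers at each phase boundary gives a record of how much VRAM each phase starts with.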


#10: Multi-GPU assumption (LOW)

What: User has T4×2 on Kaggle. Code is single-GPU. Second GPU sits idle.

Prevention: Document explicitly: "Single-GPU only. T4×2 wastes the second GPU. Use single T4 instead, or add DDP."


Principles

Derived from the mistakes above. Each traces back to one or more specific failures.

  1. Read the appendix first. The appendix is the recipe; the main paper is the story. (from: all mistakes; hyperparams were in appendix)

  2. Test the boundaries, not just the happy path. The bug is always in the path you didn't test. (from: #1, #3)

  3. Library APIs are opaque until tested. Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script. (from: #1)

  4. Pre-concatenate, don't re-concatenate. Any data structure built once and sampled many times should be finalized into a single tensor. (from: #2)

  5. The user's time is more expensive than your time. A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. (from: #1, #3, #8)

  6. Flatten for OT libraries. geomloss, POT, ott-jax all expect (N, D) point clouds. Images must be flattened. #1 gotcha in OT-based generative models. (from: #1)

  7. Store training state on CPU, compute on GPU. Pools and replay buffers on CPU; only minibatch goes to GPU. (from: #8, #9)

  8. Multi-phase training = multiple separate trainers. Each phase gets its own optimizer. Previous phase's model → eval(). (from: #6)

  9. Shared objects across phases are landmines. Cached DataLoaders, iterators, batch sizes β€” any phase-specific param can silently break later phases. (from: #6)

  10. CLI overrides must be exhaustive. N copies of a parameter in config β†’ CLI must touch all N. (from: #7)

  11. Paper hyperparameters assume paper hardware. Re-derive batch sizes from VRAM constraints. Keep total samples seen constant. (from: #8)

  12. Estimate VRAM before running, not after OOM. Sinkhorn: O(N² × D). Model: params × 4 bytes × 3. Write it down before the first GPU run. (from: #8)

  13. Checkpoint at phase boundaries, not just step boundaries. Phase-level checkpoints enable --resume-phase. Step-level checkpoints within long phases are a bonus. (from: #4)

  14. Free GPU memory between phases. torch.cuda.empty_cache() + del large objects between phases with different memory patterns. (from: #9)

  15. Document what your code does NOT support. Single-GPU? No mixed precision? Say so explicitly. (from: #10)


Sinkhorn VRAM Reference Table

For quick VRAM estimation during paper reproduction planning. Figures assume the tensorized backend, fp32, and a single SamplesLoss call with potentials=True.

| N (batch) | D (dim) | Example | Approx VRAM/call |
|---|---|---|---|
| 256 | 2 | 2D points | ~1 MB |
| 256 | 784 | MNIST 28×28 | ~200 MB |
| 128 | 3072 | CIFAR 3×32×32 | ~600 MB |
| 32 | 3072 | CIFAR (T4-safe) | ~40 MB |

Pool building does ~10 calls per batch (2 potentials × 5 flow steps). Multiply accordingly.
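For batch sizes not in the table, a rough figure can be derived from the O(N² × D) scaling. The sketch below assumes the dominant cost is a single dense (N, N, D) fp32 buffer; that is a modeling assumption, not a statement about geomloss internals, and the function names are hypothetical. The MNIST row matches this bound closely, while the CIFAR rows are roughly three times larger, so treat the result as a lower bound per call.

```python
def pairwise_bytes(n: int, d: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one dense (N, N, D) intermediate at the given precision."""
    return n * n * d * bytes_per_elem


def to_mb(num_bytes: int) -> float:
    """Convert bytes to megabytes (10**6 bytes)."""
    return num_bytes / 1e6


rows = [(256, 2), (256, 784), (128, 3072), (32, 3072)]
estimates = {(n, d): round(to_mb(pairwise_bytes(n, d)), 1) for n, d in rows}
for (n, d), mb in estimates.items():
    print(f"N={n:<4} D={d:<5} ~{mb} MB per (N, N, D) buffer")
```

Because the estimate is quadratic in N, halving the batch size cuts the per-call figure by a factor of four, which is why dropping from 128 to 32 has such a large effect.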


Timeline of Bugs

| When | What broke | How discovered | Sessions wasted |
|---|---|---|---|
| Initial ship | geomloss shape (MNIST/CIFAR) | User's first GPU run | 1 |
| After geomloss fix | DataLoader batch mismatch (Phase 2) | My CPU test | 0 (caught in sandbox) |
| After DataLoader fix | --train-iters missed Phase 3 | My CPU test | 0 (caught in sandbox) |
| User's full MNIST run | No checkpointing, interrupted | User's Kaggle timeout | 1 |
| User's CIFAR run | Sinkhorn OOM on T4 | User's Kaggle log | 1 |
| After OOM fix | No empty_cache() between phases | Analysis of VRAM patterns | 0 (preventive) |

Total Kaggle sessions wasted by user: 3. The bugs caught in sandbox (0 sessions wasted) vs bugs caught on user hardware (3 sessions) reinforce principle #5: test everything before shipping.