# LEARNING.md: Lessons from NSGF++ Reproduction

All concrete mistakes, war stories, code examples, and derived principles from reproducing arXiv:2401.14069 on Kaggle T4 GPUs. Rules and procedures live in [SKILL.md](SKILL.md).

---

## Mistake Catalog

### #1: geomloss tensor shape bug (CRITICAL)

**What**: geomloss's `SamplesLoss` requires `(N, D)` or `(B, N, D)` tensors. The image experiments passed `(N, C, H, W)`.

**Impact**: MNIST and CIFAR-10 crash immediately with `ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.` The 2D experiment works fine with shape `(256, 2)`, which hides the bug completely.

**Root cause**: Only tested the 2D code path. Never ran a 10-line test with image-shaped tensors.

**Prevention**: Before building any training loop, test the library with the EXACT tensor shapes of every experiment:

```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D: already (N, D), works as-is
F, G = loss_fn(torch.randn(256, 2, requires_grad=True), torch.randn(256, 2))

# MNIST: (N, C, H, W) raises ValueError unless flattened
x = torch.randn(128, 1, 28, 28)
x_flat = x.view(x.shape[0], -1).requires_grad_(True)  # (128, 784)
F, G = loss_fn(x_flat, torch.randn(128, 784))  # works
```

**Fix applied**: `compute_velocity()` now flattens `(N, C, H, W) → (N, D)` before calling geomloss and reshapes the gradients back afterwards. The same pattern applies to other OT libraries that take point clouds (POT, ott-jax).
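
A shape preflight like this can run before any GPU time is spent. Below is a minimal sketch in plain Python. It assumes each experiment's batch shape is known up front. `EXPERIMENT_SHAPES`, `to_point_cloud_shape`, and `preflight` are hypothetical names, not part of the project:

```python
from math import prod

# Hypothetical batch shapes per experiment, as the training loop would see them
EXPERIMENT_SHAPES = {
    "2d": (256, 2),
    "mnist": (128, 1, 28, 28),
    "cifar10": (128, 3, 32, 32),
}


def to_point_cloud_shape(shape):
    """Collapse an (N, ...) batch shape into the (N, D) layout point-cloud OT APIs expect."""
    n, *feature_dims = shape
    return (n, prod(feature_dims))


def preflight(shapes):
    """Map each experiment to (needs_flattening, flattened (N, D) shape)."""
    report = {}
    for name, shape in shapes.items():
        # Anything that is not already 2-D must be flattened;
        # image batches are never meant to be read as (B, N, D)
        report[name] = (len(shape) != 2, to_point_cloud_shape(shape))
    return report


for name, (needs_flat, flat_shape) in preflight(EXPERIMENT_SHAPES).items():
    print(f"{name}: flatten={needs_flat} -> {flat_shape}")
```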

---

### #2: TrajectoryPool re-concatenation every step (MODERATE)

**What**: `torch.cat(self.x_pool, dim=0)` was called on the entire pool (512K entries) at every training step.

**Impact**: ~0.5 s per step was wasted on concatenation, against ~0.05 s for the actual forward/backward pass. Training ran ~10× slower than necessary.

**Root cause**: Didn't profile. The pool was built as a list of tensors and never consolidated.

**Prevention**: Pre-concatenate once, after the pool is built:

```python
import torch


class TrajectoryPool:
    def finalize(self):
        # Concatenate the per-batch tensors once, right after pool building
        self._all_x = torch.cat(self.x_pool, dim=0)
        self.x_pool = None  # drop the list so its tensors can be freed

    def sample(self, batch_size, device):
        # Cost scales with batch_size, not pool size: no per-step concatenation
        idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
        return self._all_x[idx].to(device)
```
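
One profiling run would have caught this. Below is a minimal sketch of a per-section wall-clock timer in plain Python. `SectionTimer` and the section names are hypothetical, and `time.sleep` stands in for the real work at the same 10:1 ratio. When timing GPU code, call `torch.cuda.synchronize()` before reading the clock. CUDA kernels launch asynchronously, so without it the timer measures only the launch.

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class SectionTimer:
    """Accumulate wall-clock time per named section of a training step."""

    def __init__(self):
        self.totals = defaultdict(float)
        self.counts = defaultdict(int)

    @contextmanager
    def section(self, name):
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start
            self.counts[name] += 1

    def report(self):
        """Return {section: mean seconds per call}, slowest first."""
        means = {k: self.totals[k] / self.counts[k] for k in self.totals}
        return dict(sorted(means.items(), key=lambda kv: kv[1], reverse=True))


timer = SectionTimer()
for step in range(3):
    with timer.section("sample_pool"):
        time.sleep(0.05)  # stand-in for the per-step concatenation + sampling
    with timer.section("forward_backward"):
        time.sleep(0.005)  # stand-in for the model update
print(timer.report())
```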

---

### #3: Incomplete experiment testing (CRITICAL)

**What**: Tested the 2D experiments thoroughly on CPU. Shipped MNIST/CIFAR completely untested.

**Impact**: The user's first GPU run (MNIST) crashes immediately because of the geomloss bug. The user's second run (CIFAR) also crashes immediately, from the same geomloss bug plus an OOM. Two Kaggle sessions wasted.

**Root cause**: 2D success gave false confidence. Different experiment types exercise different code paths. 2D uses `(N, 2)` tensors, while images use `(N, C, H, W)`.

**Prevention**: Test EVERY experiment type with `--pool-batches 2 --train-iters 5` before declaring the code ready. This takes under 60 seconds on CPU.
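
Below is a sketch of a smoke-test driver that runs every experiment type with those tiny settings and reports each failure. `train.py` and `--experiment` are hypothetical placeholders for the real entry point. Only `--pool-batches` and `--train-iters` come from this document. The timeout also catches a phase that ignores the override and keeps running.

```python
import subprocess
import sys

EXPERIMENTS = ["2d", "mnist", "cifar10"]  # hypothetical experiment names
SMOKE_FLAGS = ["--pool-batches", "2", "--train-iters", "5"]


def smoke_commands(script="train.py", experiments=EXPERIMENTS):
    """Build one command per experiment type so every code path is exercised."""
    return [
        [sys.executable, script, "--experiment", name, *SMOKE_FLAGS]
        for name in experiments
    ]


def run_smoke_tests(commands, timeout_s=120):
    """Run each command; return a list of (command, return code or 'timeout') failures."""
    failures = []
    for cmd in commands:
        try:
            result = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout_s)
            code = result.returncode
        except subprocess.TimeoutExpired:
            code = "timeout"  # e.g. a phase that ignored --train-iters
        if code != 0:
            failures.append((cmd, code))
    return failures
```

Usage: `failures = run_smoke_tests(smoke_commands())`. An empty list means every experiment type got through all phases.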

---

### #4: No checkpointing (MODERATE → CRITICAL at scale)

**What**: No intermediate checkpoints during training and no phase-level saves.

**Impact**: A full MNIST run takes ~7 hours on a T4, and Kaggle gives 9 hours per session. Any interruption (timeout, accidental Ctrl+C, OOM partway through) loses everything. CIFAR-10 cannot finish in one session without resume.

**Root cause**: Built for "run once and done". Didn't anticipate multi-session training.

**Prevention**: Checkpoint after every phase and every N steps within phases. Implement `--resume-phase`. Test that resume actually loads the saved state and skips the completed phases.
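
Below is a minimal sketch of phase-level checkpointing with resume, in plain Python, with JSON standing in for `torch.save`/`torch.load` of state dicts. The phase names follow the three training phases this document mentions. `run_pipeline`, `save_atomic`, the file layout, and the `run_phase` callback are hypothetical. The sketch writes to a temporary file and then calls `os.replace`, so an interrupted save cannot corrupt the previous checkpoint.

```python
import json
import os

PHASES = ["nsgf", "nsf", "time_predictor"]  # hypothetical phase names


def save_atomic(state, path):
    """Write to a temp file, then rename, so a crash never leaves a half-written checkpoint."""
    tmp = path + ".tmp"
    with open(tmp, "w") as f:
        json.dump(state, f)
    os.replace(tmp, path)


def run_pipeline(run_phase, ckpt_dir, resume_phase=None):
    """Run phases in order. With resume_phase, load the previous phase's output and skip ahead.

    run_phase(name, prev_state) -> state must return a JSON-serializable dict.
    Returns the names of the phases actually executed.
    """
    start = PHASES.index(resume_phase) if resume_phase else 0
    prev_state = None
    if start > 0:
        prev_path = os.path.join(ckpt_dir, f"{PHASES[start - 1]}.json")
        with open(prev_path) as f:  # fails loudly if the earlier phase never finished
            prev_state = json.load(f)
    executed = []
    for name in PHASES[start:]:
        prev_state = run_phase(name, prev_state)
        save_atomic(prev_state, os.path.join(ckpt_dir, f"{name}.json"))
        executed.append(name)
    return executed
```

A resume test can use a fake `run_phase` that records its arguments. Such a test checks both halves of the requirement: the skipped phases never run, and the resumed phase receives the saved state.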

---

### #5: UNet forward pass fragility (LOW-MODERATE)

**What**: `_get_num_res_blocks()` infers the block count by dividing the module list length by the number of levels. That quotient is only correct if every level holds exactly `num_res_blocks` modules and nothing else. Any architecture variation silently breaks it.

**Prevention**: Store `self.num_res_blocks = num_res_blocks` at init and use it directly in `forward()`.
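
Below is a toy illustration in plain Python (no PyTorch) of how the inference breaks. Suppose some levels also hold attention blocks. Then `len(modules) // num_levels` stops matching the configured value, while the stored attribute stays correct. `ToyDownPath` and its layout are hypothetical and do not reproduce the project's UNet.

```python
class ToyDownPath:
    """Stand-in for a UNet encoder: a flat module list built level by level."""

    def __init__(self, num_levels, num_res_blocks, attn_levels=()):
        self.num_levels = num_levels
        self.num_res_blocks = num_res_blocks  # stored once, used directly in forward()
        self.modules = []
        for level in range(num_levels):
            for i in range(num_res_blocks):
                self.modules.append(f"res{level}.{i}")
                if level in attn_levels:
                    self.modules.append(f"attn{level}.{i}")

    def inferred_num_res_blocks(self):
        # The fragile approach: assumes the list holds only residual blocks
        return len(self.modules) // self.num_levels


plain = ToyDownPath(num_levels=3, num_res_blocks=2)
with_attn = ToyDownPath(num_levels=3, num_res_blocks=2, attn_levels=(1, 2))
print(plain.inferred_num_res_blocks(), with_attn.inferred_num_res_blocks())  # 2 3
```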

---

### #6: DataLoader batch size mismatch across phases (CRITICAL)

**What**: `DatasetLoader.sample_target()` lazily creates a DataLoader on its first call and caches it. Phase 1 calls it with `n=256`, so the DataLoader is built with `batch_size=256, drop_last=True`. Phase 2 calls it with `n=128`, but the cached DataLoader still yields batches of 256, and training crashes:

```
RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0
```

**Impact**: Phase 2 (NSF) crashes immediately, even after Phase 1 completes successfully. The error message doesn't reveal the caching problem. It looks like a model shape issue.

**Root cause**: Lazy initialization without invalidation. A classic stale-cache bug.

**Fix applied**: Track `self._image_batch_size` and recreate the DataLoader whenever it changes:

```python
def sample_target(self, n, device="cpu"):
    # Rebuild the cached loader whenever the requested batch size changes
    if getattr(self, "_loader", None) is None or self._image_batch_size != n:
        self._image_batch_size = n
        self._loader = get_image_dataloader(self.dataset_name, batch_size=n)
        self._iter = iter(self._loader)
    try:
        batch = next(self._iter)
    except StopIteration:  # epoch exhausted: start a new pass
        self._iter = iter(self._loader)
        batch = next(self._iter)
    # Assumes torchvision-style (images, labels) batches; keep only the images
    images = batch[0] if isinstance(batch, (list, tuple)) else batch
    return images.to(device)
```
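
The same fix generalizes to any cached object shared across phases. Key the cache on every parameter that shaped it, and rebuild the object when the key changes. Below is a plain-Python sketch. `KeyedCache` and `fake_loader` are hypothetical stand-ins for the DataLoader logic in `sample_target()`.

```python
class KeyedCache:
    """Cache one object built by factory(**params); rebuild whenever params change."""

    def __init__(self, factory):
        self._factory = factory
        self._key = None
        self._value = None
        self.builds = 0

    def get(self, **params):
        key = tuple(sorted(params.items()))  # every parameter is part of the key
        if key != self._key:
            self._value = self._factory(**params)
            self._key = key
            self.builds += 1
        return self._value


def fake_loader(dataset, batch_size):
    # Stand-in for get_image_dataloader: records what it was built with
    return {"dataset": dataset, "batch_size": batch_size}


cache = KeyedCache(fake_loader)
phase1 = cache.get(dataset="mnist", batch_size=256)
phase2 = cache.get(dataset="mnist", batch_size=128)  # new key: rebuilt, not reused
assert phase2["batch_size"] == 128
```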

---

### #7: CLI flag not overriding all phases (LOW)

**What**: `--train-iters` overrode `nsgf_training.num_iterations` and `nsf_training.num_iterations` but missed `time_predictor.num_iterations` (default 40,000).

**Impact**: A smoke test with `--train-iters 5` completes Phases 1 and 2 in seconds, then runs for hours in Phase 3.

**Prevention**: Grep the config for ALL instances of the parameter: `grep -n num_iterations config.yaml`. It finds 3, so the CLI must override all 3.
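
The grep check can also be enforced in code. The approach: walk the parsed config, override every key with the given name, and return the paths touched so a test can assert the count. Below is a plain-Python sketch over a nested dict, which is what a YAML loader returns. `override_all` is hypothetical. The config keeps only the three keys named in this entry, and the two training values are placeholders.

```python
def override_all(config, key, value, path=()):
    """Set `key` to `value` at every nesting level; return the dotted paths changed."""
    touched = []
    for k, v in config.items():
        if k == key:
            config[k] = value
            touched.append(".".join(path + (k,)))
        elif isinstance(v, dict):
            touched.extend(override_all(v, key, value, path + (k,)))
    return touched


config = {
    "nsgf_training": {"num_iterations": 1000},  # placeholder value
    "nsf_training": {"num_iterations": 1000},   # placeholder value
    "time_predictor": {"num_iterations": 40000},
}
touched = override_all(config, "num_iterations", 5)
assert len(touched) == 3, touched  # one per phase, including the one --train-iters missed
```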

---

### #8: CIFAR-10 Sinkhorn OOM on T4 (CRITICAL)

**What**: The paper uses `sinkhorn.batch_size=128` for CIFAR-10. The `tensorized` Sinkhorn backend computes a 128×128 cost matrix over 3072-dim vectors (flattened 3×32×32). The setup uses `potentials=True` + `autograd.grad`, with 2 calls per flow step × 5 steps per batch, giving 10 Sinkhorn calls per pool batch. Estimated total: 8+ GB for Sinkhorn alone. Add the 38M-param UNet, and the run OOMs on the 16 GB T4.

**Impact**: CIFAR-10 crashes during pool building, before training even starts.

**Root cause**: Used the paper's hyperparameters verbatim without estimating VRAM for the target hardware. The paper authors likely used an A100 80GB.

**VRAM math that should have been done upfront**:

| Component | Approx VRAM |
|-----------|-------------|
| Sinkhorn call (128×3072, tensorized) | ~600 MB |
| × 10 calls per pool batch (with autograd) | ~6+ GB peak |
| UNet (38M params, fp32) | ~150 MB |
| Gradients + optimizer states | ~450 MB |
| **Total** | **~7+ GB** (doesn't account for fragmentation) |

**Fix applied**: Reduced `sinkhorn.batch_size` 128 → 32 and increased `pool.num_batches` 2500 → 10000. Total pool entries are unchanged: 128 × 2500 = 32 × 10000 = 320K samples, × 5 flow steps = 1.6M. Added a `--sinkhorn-batch` CLI flag.
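
The rescaling rule is plain arithmetic. Shrink the per-call batch, grow the number of pool batches by the same factor, and total entries stay fixed. Below is a sketch with this entry's numbers (5 flow steps per pool batch). `rescale_pool` is a hypothetical helper. The cost matrix grows with N², so cutting N by 4× shrinks it by 16×. That is in line with the reference table's ~600 MB → ~40 MB per call.

```python
def rescale_pool(batch_size, num_batches, new_batch_size):
    """Return the num_batches that keeps batch_size * num_batches constant."""
    total = batch_size * num_batches
    if total % new_batch_size:
        raise ValueError(f"{total} samples do not split evenly into batches of {new_batch_size}")
    return total // new_batch_size


FLOW_STEPS = 5  # flow steps per pool batch, as described in this entry
old_batch, old_num = 128, 2500
new_batch = 32
new_num = rescale_pool(old_batch, old_num, new_batch)
print(new_num, new_batch * new_num * FLOW_STEPS)  # 10000 1600000
```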

---

### #9: No GPU memory freed between phases (MODERATE)

**What**: After pool building finishes, the CUDA memory used by the Sinkhorn computation graphs stays cached by PyTorch's allocator. Training starts with less available VRAM.

**Fix applied**: `torch.cuda.empty_cache()` after pool building, plus `del pool` after finalization. Drop references before calling `empty_cache()`, because it can only release cached blocks that no live tensor still uses.

---

### #10: Multi-GPU assumption (LOW)

**What**: User has T4×2 on Kaggle. Code is single-GPU. Second GPU sits idle.

**Prevention**: Document explicitly: "Single-GPU only. T4×2 wastes the second GPU. Use single T4 instead, or add DDP."

---

## Principles

Derived from the mistakes above. Each traces back to one or more specific failures.

1. **Read the appendix first.** The appendix is the recipe; the main paper is the story. *(from: all mistakes; the hyperparameters were in the appendix)*

2. **Test the boundaries, not just the happy path.** The bug is always in the path you didn't test. *(from: #1, #3)*

3. **Library APIs are opaque until tested.** Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script. *(from: #1)*

4. **Pre-concatenate, don't re-concatenate.** Any data structure built once and sampled many times should be finalized into a single tensor. *(from: #2)*

5. **The user's time is more expensive than your time.** A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. *(from: #1, #3, #8)*

6. **Flatten for OT libraries.** geomloss, POT, and ott-jax all expect `(N, D)` point clouds, so images must be flattened. This is the number-one gotcha in OT-based generative models. *(from: #1)*

7. **Store training state on CPU, compute on GPU.** Keep pools and replay buffers on CPU; only the minibatch goes to the GPU. *(from: #8, #9)*

8. **Multi-phase training = multiple separate trainers.** Each phase gets its own optimizer. Switch the previous phase's model to `eval()`. *(from: #6)*

9. **Shared objects across phases are landmines.** Cached DataLoaders, iterators, batch sizes: any phase-specific parameter can silently break later phases. *(from: #6)*

10. **CLI overrides must be exhaustive.** If a parameter appears N times in the config, the CLI must touch all N. *(from: #7)*

11. **Paper hyperparameters assume paper hardware.** Re-derive batch sizes from VRAM constraints. Keep the total number of samples seen constant. *(from: #8)*

12. **Estimate VRAM before running, not after OOM.** Sinkhorn: O(N² × D) per call as a conservative bound. Model: params × 4 bytes × 4 for fp32 training with Adam (weights, gradients, two moment buffers). Write it down before the first GPU run. *(from: #8)*

13. **Checkpoint at phase boundaries, not just step boundaries.** Phase-level checkpoints enable `--resume-phase`. Step-level checkpoints within long phases are a bonus. *(from: #4)*

14. **Free GPU memory between phases.** `del` large objects, then call `torch.cuda.empty_cache()`, between phases with different memory patterns. *(from: #9)*

15. **Document what your code does NOT support.** Single-GPU? No mixed precision? Say so explicitly. *(from: #10)*
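
The model-side term of the VRAM estimate is a one-liner. Below is a sketch assuming fp32 and an Adam-style optimizer that keeps two moment buffers per parameter. With weights and gradients, that makes four copies. Activations and allocator fragmentation are deliberately left out. `training_state_mb` is a hypothetical helper.

```python
def training_state_mb(num_params, bytes_per_value=4, optimizer_buffers=2):
    """Weights + gradients + optimizer buffers, in MB (1 MB = 1e6 bytes).

    optimizer_buffers=2 matches Adam (first and second moments);
    use 0 for SGD without momentum. Activations are not included.
    """
    copies = 2 + optimizer_buffers  # weights and gradients, plus optimizer state
    return num_params * bytes_per_value * copies / 1e6


# 38M-parameter UNet from the CIFAR-10 entry, fp32, Adam
print(training_state_mb(38e6))  # 608.0
```

The result matches the ~150 MB + ~450 MB split in the CIFAR-10 breakdown.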

---

## Sinkhorn VRAM Reference Table

For quick VRAM estimation during paper reproduction planning. Assumes the `tensorized` backend, fp32, and a single `SamplesLoss` call with `potentials=True`.

| N (batch) | D (dim) | Example | Approx VRAM/call |
|-----------|---------|---------|------------------|
| 256 | 2 | 2D points | ~1 MB |
| 256 | 784 | MNIST 28×28 | ~200 MB |
| 128 | 3072 | CIFAR 3×32×32 | ~600 MB |
| 32 | 3072 | CIFAR (T4-safe) | ~40 MB |

Pool building does ~10 calls per batch (2 Sinkhorn calls per flow step × 5 flow steps). Multiply accordingly.
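
The multiplication is easy to keep next to the table. Below is a sketch that copies the table's approximate per-call figures into a dict and applies the 2 × 5 multiplier. `pool_batch_peak_mb` is a hypothetical helper. It assumes the worst case, where every call's graph is alive at once. Treat the results as rough. The CIFAR-10 breakdown estimated ~6+ GB for this step, yet the run still OOMed on a 16 GB T4. Measure with `torch.cuda.max_memory_allocated()` before relying on these numbers.

```python
# Approximate MB per tensorized Sinkhorn call, copied from the reference table
SINKHORN_MB_PER_CALL = {
    (256, 2): 1,       # 2D points
    (256, 784): 200,   # MNIST 28x28
    (128, 3072): 600,  # CIFAR 3x32x32
    (32, 3072): 40,    # CIFAR (T4-safe)
}


def pool_batch_peak_mb(n, d, calls_per_step=2, flow_steps=5):
    """Worst case: all Sinkhorn calls of one pool batch are alive at the same time."""
    return SINKHORN_MB_PER_CALL[(n, d)] * calls_per_step * flow_steps


for n, d in SINKHORN_MB_PER_CALL:
    print(f"N={n:>3}  D={d:>4}: ~{pool_batch_peak_mb(n, d)} MB per pool batch")
```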

---

## Timeline of Bugs

| When | What broke | How discovered | Sessions wasted |
|------|------------|----------------|-----------------|
| Initial ship | geomloss shape (MNIST/CIFAR) | User's first GPU run | 1 |
| After geomloss fix | DataLoader batch mismatch (Phase 2) | My CPU test | 0 (caught in sandbox) |
| After DataLoader fix | `--train-iters` missed Phase 3 | My CPU test | 0 (caught in sandbox) |
| User's full MNIST run | No checkpointing, interrupted | User's Kaggle timeout | 1 |
| User's CIFAR run | Sinkhorn OOM on T4 | User's Kaggle log | 1 |
| After OOM fix | No `empty_cache()` between phases | Analysis of VRAM patterns | 0 (preventive) |

**Total Kaggle sessions wasted by the user: 3.** Bugs caught in the sandbox cost 0 sessions, while bugs caught on user hardware cost 3. That contrast reinforces principle 5: test everything before shipping.