# LEARNING.md: Lessons from NSGF++ Reproduction
All concrete mistakes, war stories, code examples, and derived principles from reproducing arXiv:2401.14069 on Kaggle T4 GPUs. Rules and procedures live in [SKILL.md](SKILL.md).
---
## Mistake Catalog
### #1: geomloss tensor shape bug (CRITICAL)
**What**: `SamplesLoss` in geomloss requires `(N, D)` or `(B, N, D)` tensors. Image experiments passed `(N, C, H, W)`.
**Impact**: MNIST and CIFAR-10 crash immediately with `ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.` The 2D experiment works fine with shape `(256, 2)`, hiding the bug completely.
**Root cause**: Only tested the 2D code path. Never ran a 10-line test with image-shaped tensors.
**Prevention**: Before building any training loop, test the library with EXACT tensor shapes for every experiment:
```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D: OK, points are already (N, D)
F, G = loss_fn(torch.randn(256, 2, requires_grad=True), torch.randn(256, 2))

# MNIST: (N, C, H, W) CRASHES; flatten to (N, C*H*W) first
x = torch.randn(128, 1, 28, 28)
x_flat = x.reshape(x.shape[0], -1).requires_grad_(True)  # (128, 784)
F, G = loss_fn(x_flat, torch.randn(128, 784))  # ✅
```
**Fix applied**: `compute_velocity()` now flattens `(N,C,H,W)→(N,D)` before calling geomloss and reshapes the gradients back afterwards. This pattern applies to all OT libraries (geomloss, POT, ott-jax).
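
A minimal sketch of the flatten-then-reshape pattern. `grad_through_flat` is a hypothetical helper, not the project's `compute_velocity()`: it takes any function that maps an `(N, D)` batch to per-sample scalars (in the real code this would wrap the geomloss potentials; here a toy quadratic stands in so the block runs without geomloss), differentiates it on the flattened input, and returns the gradient in the original image shape.

```python
import torch

def grad_through_flat(fn, x):
    """Apply an (N, D)-only function to x of shape (N, ...) and return
    d(sum of fn)/dx reshaped back to x.shape. `fn` stands in for an OT call."""
    x_flat = x.detach().reshape(x.shape[0], -1).requires_grad_(True)  # (N, C*H*W)
    values = fn(x_flat)  # per-sample scalars, shape (N,)
    (grad_flat,) = torch.autograd.grad(values.sum(), x_flat)
    return grad_flat.reshape(x.shape)  # back to (N, C, H, W)

# Toy stand-in for a potential: f(z) = ||z||^2 per sample, so the gradient is 2z.
x = torch.randn(4, 1, 28, 28)
g = grad_through_flat(lambda z: (z ** 2).sum(dim=1), x)
print(g.shape)  # torch.Size([4, 1, 28, 28])
```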
---
### #2: TrajectoryPool re-concatenation every step (MODERATE)
**What**: `torch.cat(self.x_pool, dim=0)` called on the entire pool (512K entries) every single training step.
**Impact**: ~0.5s per step wasted on concatenation vs ~0.05s for the actual forward/backward. Training ran ~10× slower than necessary.
**Root cause**: Didn't profile. The pool was built from a list of tensors and never consolidated.
**Prevention**: Pre-concatenate once after pool building:
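```python
import torch

class TrajectoryPool:
    def __init__(self):
        self.x_pool = []  # per-batch tensors appended during pool building (kept on CPU)

    def finalize(self):
        self._all_x = torch.cat(self.x_pool, dim=0)  # concatenate ONCE, not per step
        self.x_pool = None  # free list memory

    def sample(self, batch_size, device):
        idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
        # Cost scales with batch_size, not pool size; only the minibatch moves to device
        return self._all_x[idx].to(device)
```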
---
### #3: Incomplete experiment testing (CRITICAL)
**What**: Tested 2D experiments thoroughly on CPU. Shipped MNIST/CIFAR completely untested.
**Impact**: User's first GPU run of MNIST crashes immediately (geomloss bug). User's second run of CIFAR crashes immediately (also geomloss bug + OOM). Two Kaggle sessions wasted.
**Root cause**: 2D success gave false confidence. Different experiment types exercise different code paths β€” 2D uses `(N, 2)` tensors, images use `(N, C, H, W)`.
**Prevention**: Test EVERY experiment type with `--pool-batches 2 --train-iters 5` before declaring code ready. This takes <60 seconds on CPU.
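
One way to make this routine is a tiny driver that runs every experiment with the smoke-test flags. This is a sketch under assumptions: the entry point `train.py`, the `--experiment` flag, the experiment names and the 120 s timeout are all hypothetical; only `--pool-batches` and `--train-iters` come from this document. Calling `run_smoke_tests()` returns the experiments that crashed or hung, and an empty list is the bar for declaring the code ready.

```python
import subprocess
import sys

EXPERIMENTS = ["2d", "mnist", "cifar10"]  # hypothetical experiment names
SMOKE_FLAGS = ["--pool-batches", "2", "--train-iters", "5"]

def smoke_commands(script="train.py"):
    """Build one smoke-test command per experiment type (script name is hypothetical)."""
    return [[sys.executable, script, "--experiment", exp, *SMOKE_FLAGS]
            for exp in EXPERIMENTS]

def run_smoke_tests(script="train.py", timeout_s=120):
    """Run every command; return experiments that exited non-zero or timed out."""
    failures = []
    for cmd in smoke_commands(script):
        exp = cmd[cmd.index("--experiment") + 1]
        try:
            subprocess.run(cmd, check=True, timeout=timeout_s)
        except (subprocess.CalledProcessError, subprocess.TimeoutExpired):
            failures.append(exp)
    return failures
```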
---
### #4: No checkpointing (MODERATE → CRITICAL at scale)
**What**: No intermediate checkpoints during training. No phase-level saves.
**Impact**: MNIST full run is ~7 hours on T4. Kaggle gives 9 hours per session. Any interruption (timeout, accidental Ctrl+C, OOM partway through) loses everything. CIFAR-10 is impossible in one session without resume.
**Root cause**: Built for "run once and done" β€” didn't anticipate multi-session training.
**Prevention**: Checkpoint after every phase + every N steps within phases. Implement `--resume-phase`. Test that resume actually loads and skips correctly.
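
A minimal sketch of phase-level checkpointing with a `--resume-phase`-style entry point. The phase names mirror the three config sections mentioned in this document (NSGF, NSF, time predictor); the `run_phases` helper, the one-file-per-phase layout and the use of `pickle` are assumptions chosen so the block is self-contained (real code would save model and optimizer `state_dict()`s with `torch.save`). Step-level checkpoints inside long phases would sit inside each `train_fns[phase]` and are not shown. The returned `ran` list is what a "resume actually skips" test should assert on.

```python
import pickle
from pathlib import Path

PHASES = ["nsgf", "nsf", "time_predictor"]  # later phases consume earlier outputs

def run_phases(train_fns, ckpt_dir, resume_phase=None):
    """Run phases in order, checkpointing after each one.

    train_fns maps phase name -> fn(state) -> new state. With resume_phase set,
    earlier phases are skipped and the previous phase's checkpoint is loaded.
    """
    ckpt_dir = Path(ckpt_dir)
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    start = PHASES.index(resume_phase) if resume_phase else 0
    state = {}
    if start > 0:
        # Load the output of the phase just before the resume point.
        with open(ckpt_dir / f"{PHASES[start - 1]}.pkl", "rb") as f:
            state = pickle.load(f)
    ran = []
    for phase in PHASES[start:]:
        state = train_fns[phase](state)
        with open(ckpt_dir / f"{phase}.pkl", "wb") as f:
            pickle.dump(state, f)  # phase-boundary checkpoint
        ran.append(phase)
    return state, ran
```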
---
### #5: UNet forward pass fragility (LOW-MODERATE)
**What**: `_get_num_res_blocks()` infers block count by dividing module list length by number of levels. Fragile if architecture varies.
**Prevention**: Store `self.num_res_blocks = num_res_blocks` at init. Use it directly in `forward()`.
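
A stripped-down sketch of the fix. `TinyEncoder`, `res_blocks` and the plain conv blocks are hypothetical stand-ins, not the project's UNet. The point is that `forward()` reads the stored `self.num_res_blocks` instead of re-deriving it from `len(self.res_blocks)`, so appending an extra module to the list later cannot silently change the loop bounds.

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, channels=8, num_levels=2, num_res_blocks=2):
        super().__init__()
        self.num_levels = num_levels
        self.num_res_blocks = num_res_blocks  # stored once at init, never inferred
        self.stem = nn.Conv2d(1, channels, 3, padding=1)
        # Flat list holding num_res_blocks blocks per level, level-major order.
        self.res_blocks = nn.ModuleList(
            nn.Conv2d(channels, channels, 3, padding=1)
            for _ in range(num_levels * num_res_blocks)
        )

    def forward(self, x):
        h = self.stem(x)
        for level in range(self.num_levels):
            for i in range(self.num_res_blocks):  # uses the stored count directly
                block = self.res_blocks[level * self.num_res_blocks + i]
                h = h + torch.relu(block(h))  # residual update
        return h

model = TinyEncoder()
out = model(torch.randn(2, 1, 28, 28))
```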
---
### #6: DataLoader batch size mismatch across phases (CRITICAL)
**What**: `DatasetLoader.sample_target()` lazily creates a DataLoader on first call and caches it. Phase 1 calls with `n=256` → DataLoader has `batch_size=256, drop_last=True`. Phase 2 calls with `n=128` → cached DataLoader still yields 256 → crash:
```
RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0
```
**Impact**: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully. The error message doesn't reveal the caching problem β€” it looks like a model shape issue.
**Root cause**: Lazy initialization without invalidation. Classic stale cache bug.
**Fix applied**: Track `self._image_batch_size` and recreate DataLoader when it changes:
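```python
def sample_target(self, n, device="cpu"):
    # Rebuild the cached DataLoader whenever the requested batch size changes
    # (Phase 1 asks for n=256, Phase 2 for n=128).
    if getattr(self, "_image_batch_size", None) != n:
        self._image_batch_size = n
        self._loader = get_image_dataloader(self.dataset_name, batch_size=n)
        self._iter = iter(self._loader)
    # ... then draw the next batch from self._iter as before
```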
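
The stale-cache failure and its fix can be reproduced without PyTorch. In this self-contained sketch, a hypothetical `make_loader` generator stands in for `get_image_dataloader` and mimics `drop_last=True` by yielding only full batches; `TargetSampler` is likewise a hypothetical stand-in for `DatasetLoader`. The invalidation check is the same one used in the fix above, and exhausting an epoch simply starts a new pass at the current batch size.

```python
def make_loader(data, batch_size):
    """Stand-in for a DataLoader with drop_last=True: yields full batches only."""
    for start in range(0, len(data) - batch_size + 1, batch_size):
        yield data[start:start + batch_size]

class TargetSampler:
    def __init__(self, data):
        self.data = data
        self._image_batch_size = None

    def sample_target(self, n):
        # Invalidate the cached iterator whenever the requested size changes.
        if self._image_batch_size != n:
            self._image_batch_size = n
            self._iter = make_loader(self.data, n)
        try:
            return next(self._iter)
        except StopIteration:  # epoch exhausted: start a new pass at size n
            self._iter = make_loader(self.data, n)
            return next(self._iter)
```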
---
### #7: CLI flag not overriding all phases (LOW)
**What**: `--train-iters` overrode `nsgf_training.num_iterations` and `nsf_training.num_iterations` but missed `time_predictor.num_iterations` (default 40,000).
**Impact**: Smoke test with `--train-iters 5` completes Phase 1 and 2 in seconds, then hangs for hours on Phase 3.
**Prevention**: Grep config for ALL instances of the parameter: `grep -n num_iterations config.yaml`. Found 3 β€” must override all 3.
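
Rather than listing each path by hand, the override can walk the whole config. A sketch over a plain nested `dict` (for example, the result of loading `config.yaml`); the helper `override_everywhere` is hypothetical, the three section names mirror the ones above, and the two `1000` values are placeholders (only the time predictor's 40,000 default is given in this document). Asserting on the length of the returned list turns the `grep` check into a test.

```python
def override_everywhere(cfg, key, value, path=""):
    """Set every occurrence of `key` in a nested dict; return the dotted paths changed."""
    changed = []
    for k, v in cfg.items():
        here = f"{path}.{k}" if path else k
        if k == key:
            cfg[k] = value  # replacing a value (not adding keys) is safe mid-iteration
            changed.append(here)
        elif isinstance(v, dict):
            changed += override_everywhere(v, key, value, here)
    return changed

config = {
    "nsgf_training": {"num_iterations": 1000},   # placeholder value
    "nsf_training": {"num_iterations": 1000},    # placeholder value
    "time_predictor": {"num_iterations": 40000},  # default cited above
}
print(override_everywhere(config, "num_iterations", 5))
# ['nsgf_training.num_iterations', 'nsf_training.num_iterations', 'time_predictor.num_iterations']
```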
---
### #8: CIFAR-10 Sinkhorn OOM on T4 (CRITICAL)
**What**: The paper uses `sinkhorn.batch_size=128` for CIFAR-10. The Sinkhorn `tensorized` backend computes a 128×128 cost matrix over 3072-dim vectors (flattened 3×32×32). With `potentials=True` + `autograd.grad`, there are 2 calls per flow step × 5 steps per batch = 10 Sinkhorn calls per pool batch. Total: 8+ GB just for Sinkhorn, plus the 38M-param UNet → OOM on a 16 GB T4.
**Impact**: CIFAR-10 crashes during pool building, before training even starts.
**Root cause**: Used paper hyperparameters verbatim without estimating VRAM for target hardware. Paper authors likely used A100 80GB.
**VRAM math that should have been done upfront**:
| Component | Approx VRAM |
|-----------|------------|
| Sinkhorn call (128×3072, tensorized) | ~600 MB |
| × 10 calls per pool batch (with autograd) | ~6+ GB peak |
| UNet (38M params, fp32) | ~150 MB |
| Gradients + optimizer states | ~450 MB |
| **Total** | **~7+ GB** (doesn't account for fragmentation) |
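
The table's arithmetic as a reusable back-of-envelope function. This sketch assumes, as the table does, that all ~10 Sinkhorn calls in a pool batch can be live at once, fp32 weights, and an Adam-style optimizer with two state tensors per parameter (which is how 150 MB of weights becomes ~450 MB of gradients plus optimizer state). The per-call Sinkhorn figure is an input read from the reference table, not derived. `estimate_vram_mb` is a hypothetical helper, and its result is a floor, since activations, the CUDA context and fragmentation are ignored.

```python
def estimate_vram_mb(num_params, sinkhorn_mb_per_call, sinkhorn_calls=10,
                     bytes_per_param=4, optimizer_states=2):
    """Rough peak-VRAM estimate in MB (1 MB = 1e6 bytes). Ignores activations,
    CUDA context and allocator fragmentation, so treat the total as a floor."""
    weights = num_params * bytes_per_param / 1e6
    grads_and_opt = weights * (1 + optimizer_states)  # gradients + optimizer state
    sinkhorn = sinkhorn_mb_per_call * sinkhorn_calls
    return {
        "weights": weights,
        "grads_and_optimizer": grads_and_opt,
        "sinkhorn": sinkhorn,
        "total": weights + grads_and_opt + sinkhorn,
    }

# CIFAR-10 at sinkhorn.batch_size=128: 38M-param UNet, ~600 MB per Sinkhorn call
est = estimate_vram_mb(38_000_000, 600)
print(est["total"])  # 6608.0 MB, the table's "~7+ GB"
```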

**Fix applied**: Reduce `sinkhorn.batch_size` 128→32 and increase `pool.num_batches` 2500→10000, so the total number of pool entries is unchanged (1.6M). Added a `--sinkhorn-batch` CLI flag.
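
The "keep total samples constant" rule behind this fix can be written as a one-line rescaling. A sketch with a hypothetical helper `rescale_num_batches`; it assumes each sample contributes one pool entry per flow step (5 steps), which is how 128 × 2500 = 320K samples yields the 1.6M entries quoted above.

```python
def rescale_num_batches(old_batch, old_num_batches, new_batch):
    """Return the num_batches that keeps old_batch * old_num_batches samples."""
    total = old_batch * old_num_batches
    if total % new_batch:
        raise ValueError(f"{total} samples not divisible by batch size {new_batch}")
    return total // new_batch

FLOW_STEPS = 5  # assumed: one pool entry per sample per flow step
new_batches = rescale_num_batches(128, 2500, 32)
print(new_batches, 32 * new_batches * FLOW_STEPS)  # 10000 1600000
```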
---
### #9: No GPU memory freed between phases (MODERATE)
**What**: After pool building finishes, Sinkhorn computation graph CUDA allocations stay cached. Training starts with reduced available VRAM.
**Fix applied**: `torch.cuda.empty_cache()` after pool building + `del pool` after finalization.
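
A sketch of a phase-boundary cleanup helper (`release_gpu_memory` and the `scratch` tensors are hypothetical stand-ins for pool-building leftovers). Order matters: `torch.cuda.empty_cache()` only releases cached blocks that no live tensor occupies, so the last references are dropped with `del` and cycles collected before the cache is cleared. On a CPU-only machine the helper simply reports 0.0.

```python
import gc
import torch

def release_gpu_memory():
    """Call after `del`-ing large objects from the previous phase."""
    gc.collect()  # reclaim objects kept alive only by reference cycles
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # release unoccupied blocks held by the caching allocator
        return torch.cuda.memory_allocated() / 1e6  # MB still held by live tensors
    return 0.0

device = "cuda" if torch.cuda.is_available() else "cpu"
scratch = [torch.randn(1024, 1024, device=device) for _ in range(4)]  # stand-in leftovers
del scratch  # drop the last reference first...
still_allocated_mb = release_gpu_memory()  # ...then clear the cache
```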
---
### #10: Multi-GPU assumption (LOW)
**What**: User has T4×2 on Kaggle. Code is single-GPU, so the second GPU sits idle.
**Prevention**: Document explicitly: "Single-GPU only. T4×2 wastes the second GPU. Use a single T4 instead, or add DDP."
---
## Principles
Derived from the mistakes above. Each traces back to one or more specific failures.
1. **Read the appendix first.** The appendix is the recipe; the main paper is the story. *(from: all mistakes; the hyperparameters were in the appendix)*
2. **Test the boundaries, not just the happy path.** The bug is always in the path you didn't test. *(from: #1, #3)*
3. **Library APIs are opaque until tested.** Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script. *(from: #1)*
4. **Pre-concatenate, don't re-concatenate.** Any data structure built once and sampled many times should be finalized into a single tensor. *(from: #2)*
5. **The user's time is more expensive than your time.** A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. *(from: #1, #3, #8)*
6. **Flatten for OT libraries.** geomloss, POT, and ott-jax all expect `(N, D)` point clouds, so images must be flattened. This is the most common gotcha in OT-based generative models. *(from: #1)*
7. **Store training state on CPU, compute on GPU.** Pools and replay buffers on CPU; only minibatch goes to GPU. *(from: #8, #9)*
8. **Multi-phase training = multiple separate trainers.** Each phase gets its own optimizer. Previous phase's model → `eval()`. *(from: #6)*
9. **Shared objects across phases are landmines.** Cached DataLoaders, iterators, batch sizes β€” any phase-specific param can silently break later phases. *(from: #6)*
10. **CLI overrides must be exhaustive.** N copies of a parameter in config β†’ CLI must touch all N. *(from: #7)*
11. **Paper hyperparameters assume paper hardware.** Re-derive batch sizes from VRAM constraints. Keep total samples seen constant. *(from: #8)*
12. **Estimate VRAM before running, not after OOM.** Sinkhorn: O(N² × D). Model: params × 4 bytes × 4 (fp32 weights, gradients, and two Adam-style optimizer states, matching the #8 table). Write it down before the first GPU run. *(from: #8)*
13. **Checkpoint at phase boundaries, not just step boundaries.** Phase-level checkpoints enable `--resume-phase`. Step-level checkpoints within long phases are a bonus. *(from: #4)*
14. **Free GPU memory between phases.** `torch.cuda.empty_cache()` + `del` large objects between phases with different memory patterns. *(from: #9)*
15. **Document what your code does NOT support.** Single-GPU? No mixed precision? Say so explicitly. *(from: #10)*
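
A sketch of principle 8, with toy `nn.Linear` models standing in for the real phase networks (all names here are hypothetical). Each phase builds its own optimizer over only its own parameters, and the previous phase's model is switched to `eval()` with gradients disabled before its output feeds the next phase.

```python
import torch
import torch.nn as nn

def make_phase_optimizer(model, lr=1e-3):
    """Fresh optimizer per phase: no optimizer state leaks across phases."""
    return torch.optim.Adam(model.parameters(), lr=lr)

phase1_model = nn.Linear(4, 4)
opt1 = make_phase_optimizer(phase1_model)
# ... phase 1 training loop would use opt1 here ...

# Phase 2: freeze the phase-1 model and train a new one with its own optimizer.
phase1_model.eval()
phase1_model.requires_grad_(False)
phase2_model = nn.Linear(4, 4)
opt2 = make_phase_optimizer(phase2_model)

x = torch.randn(8, 4)
with torch.no_grad():
    features = phase1_model(x)  # frozen phase-1 output feeds phase 2
loss = phase2_model(features).pow(2).mean()
opt2.zero_grad()
loss.backward()
opt2.step()
```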
---
## Sinkhorn VRAM Reference Table
For quick VRAM estimation during paper reproduction planning. `tensorized` backend, fp32, single `SamplesLoss` call with `potentials=True`.
| N (batch) | D (dim) | Example | Approx VRAM/call |
|-----------|---------|---------|-----------------|
| 256 | 2 | 2D points | ~1 MB |
| 256 | 784 | MNIST 28×28 | ~200 MB |
| 128 | 3072 | CIFAR 3×32×32 | ~600 MB |
| 32 | 3072 | CIFAR (T4-safe) | ~40 MB |

Pool building makes ~10 calls per batch (2 potentials × 5 flow steps), so multiply the per-call figure accordingly.
---
## Timeline of Bugs
| When | What broke | How discovered | Sessions wasted |
|------|-----------|----------------|-----------------|
| Initial ship | geomloss shape (MNIST/CIFAR) | User's first GPU run | 1 |
| After geomloss fix | DataLoader batch mismatch (Phase 2) | My CPU test | 0 (caught in sandbox) |
| After DataLoader fix | `--train-iters` missed Phase 3 | My CPU test | 0 (caught in sandbox) |
| User's full MNIST run | No checkpointing, interrupted | User's Kaggle timeout | 1 |
| User's CIFAR run | Sinkhorn OOM on T4 | User's Kaggle log | 1 |
| After OOM fix | No `empty_cache()` between phases | Analysis of VRAM patterns | 0 (preventive) |

**Total Kaggle sessions wasted by the user: 3.** Bugs caught in the sandbox cost 0 sessions; bugs caught on the user's hardware cost 3. That contrast reinforces principle #5: test everything before shipping.