rogermt committed on
Commit c95af57 · verified · 1 Parent(s): d6ef77d

Add LEARNING.md — all mistakes, war stories, examples, and principles from NSGF++ reproduction

Files changed (1): LEARNING.md (+223 −0)
# LEARNING.md — Lessons from NSGF++ Reproduction

All concrete mistakes, war stories, code examples, and derived principles from reproducing arXiv:2401.14069 on Kaggle T4 GPUs. Rules and procedures live in [SKILL.md](SKILL.md).

---

## Mistake Catalog

### #1 — geomloss tensor shape bug (CRITICAL)

**What**: `SamplesLoss` in geomloss requires `(N, D)` or `(B, N, D)` tensors. The image experiments passed `(N, C, H, W)`.

**Impact**: MNIST and CIFAR-10 crash immediately with `ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.` The 2D experiment works fine with shape `(256, 2)`, hiding the bug completely.

**Root cause**: Only the 2D code path was tested; a 10-line test with image-shaped tensors was never run.

**Prevention**: Before building any training loop, test the library with the EXACT tensor shapes of every experiment:

```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D — OK
F, G = loss_fn(torch.randn(256, 2, requires_grad=True), torch.randn(256, 2))

# MNIST — CRASHES unless flattened
x = torch.randn(128, 1, 28, 28)
x_flat = x.view(128, -1).requires_grad_(True)  # (128, 784)
F, G = loss_fn(x_flat, torch.randn(128, 784))  # ✅
```

**Fix applied**: `compute_velocity()` now flattens `(N, C, H, W) → (N, D)` before calling geomloss and reshapes the gradients back afterwards. This pattern applies to all OT libraries (geomloss, POT, ott-jax).
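The same flatten → compute → reshape bookkeeping can be sketched framework-agnostically. A minimal NumPy illustration — the function name and the toy `grad_fn` are stand-ins for exposition, not the repo's API:

```python
import numpy as np

def compute_velocity_flat(x_images, grad_fn):
    """Flatten (N, C, H, W) -> (N, D), apply a point-cloud gradient
    function, then reshape the result back to image shape."""
    orig_shape = x_images.shape
    x_flat = x_images.reshape(orig_shape[0], -1)  # (N, D) point cloud
    grad_flat = grad_fn(x_flat)                   # (N, D) gradient
    return grad_flat.reshape(orig_shape)          # back to (N, C, H, W)

# Toy stand-in for the OT gradient (the real code uses autograd on geomloss):
x = np.random.randn(8, 1, 28, 28)
v = compute_velocity_flat(x, lambda p: 2.0 * p)
print(v.shape)  # (8, 1, 28, 28)
```

The key invariant: whatever shape goes in comes back out, so the caller never sees the flattened representation.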
---

### #2 — TrajectoryPool re-concatenation every step (MODERATE)

**What**: `torch.cat(self.x_pool, dim=0)` was called on the entire pool (512K entries) every single training step.

**Impact**: ~0.5 s per step wasted on concatenation vs ~0.05 s for the actual forward/backward. Training was 10× slower than necessary.

**Root cause**: Didn't profile. The pool was built from a list of tensors and never consolidated.

**Prevention**: Pre-concatenate once after pool building:

```python
def finalize(self):
    self._all_x = torch.cat(self.x_pool, dim=0)
    self.x_pool = None  # free the list memory

def sample(self, batch_size, device):
    idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
    return self._all_x[idx].to(device)  # O(1) indexing, no concatenation
```

---

### #3 — Incomplete experiment testing (CRITICAL)

**What**: Tested the 2D experiments thoroughly on CPU; shipped MNIST/CIFAR completely untested.

**Impact**: The user's first GPU run of MNIST crashed immediately (geomloss bug). The second run, CIFAR, also crashed immediately (same geomloss bug plus OOM). Two Kaggle sessions wasted.

**Root cause**: 2D success gave false confidence. Different experiment types exercise different code paths — 2D uses `(N, 2)` tensors, images use `(N, C, H, W)`.

**Prevention**: Test EVERY experiment type with `--pool-batches 2 --train-iters 5` before declaring the code ready. This takes under 60 seconds on CPU.

---

### #4 — No checkpointing (MODERATE → CRITICAL at scale)

**What**: No intermediate checkpoints during training. No phase-level saves.

**Impact**: The full MNIST run takes ~7 hours on a T4; Kaggle allows 9 hours per session. Any interruption (timeout, accidental Ctrl+C, OOM partway through) loses everything. CIFAR-10 is impossible in one session without resume.

**Root cause**: Built for "run once and done" — multi-session training was never anticipated.

**Prevention**: Checkpoint after every phase, plus every N steps within phases. Implement `--resume-phase`. Test that resume actually loads state and skips completed phases.
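A minimal sketch of what phase-boundary checkpointing with resume can look like, using pickle files and hypothetical phase names (the real code would save torch state dicts instead):

```python
import os
import pickle
import tempfile

# Hypothetical phase names for illustration only.
PHASES = ["pool", "nsgf", "nsf", "time_predictor"]

def save_phase(ckpt_dir, phase, state):
    """Atomically write a phase checkpoint (write-then-rename: no torn files)."""
    path = os.path.join(ckpt_dir, f"{phase}.ckpt")
    tmp = path + ".tmp"
    with open(tmp, "wb") as f:
        pickle.dump(state, f)
    os.replace(tmp, path)

def resume_point(ckpt_dir):
    """Return the first phase without a completed checkpoint, or None if done."""
    for phase in PHASES:
        if not os.path.exists(os.path.join(ckpt_dir, f"{phase}.ckpt")):
            return phase
    return None

d = tempfile.mkdtemp()
save_phase(d, "pool", {"step": 2500})
print(resume_point(d))  # nsgf  (pool done, resume at phase 2)
```

The write-then-rename step matters on Kaggle: a timeout mid-write otherwise leaves a corrupt checkpoint that poisons the next session's resume.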
---

### #5 — UNet forward pass fragility (LOW-MODERATE)

**What**: `_get_num_res_blocks()` infers the block count by dividing the module-list length by the number of levels. Fragile if the architecture varies.

**Prevention**: Store `self.num_res_blocks = num_res_blocks` at init. Use it directly in `forward()`.
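A sketch of the robust pattern with an illustrative stand-in class (not the actual UNet): keep the constructor argument instead of re-deriving it from list lengths.

```python
class UNetSketch:
    """Illustrative stand-in. Inferring blocks-per-level as
    len(self.blocks) // num_levels breaks the moment the module list
    gains attention blocks or varies per level; the stored argument
    cannot drift."""

    def __init__(self, num_levels=3, num_res_blocks=2):
        self.num_levels = num_levels
        self.num_res_blocks = num_res_blocks  # stored at init, used in forward()
        self.blocks = [f"res{i}" for i in range(num_levels * num_res_blocks)]

    def blocks_per_level(self):
        return self.num_res_blocks  # no fragile division

net = UNetSketch(num_levels=3, num_res_blocks=2)
print(net.blocks_per_level())  # 2
```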
---

### #6 — DataLoader batch size mismatch across phases (CRITICAL)

**What**: `DatasetLoader.sample_target()` lazily creates a DataLoader on first call and caches it. Phase 1 calls with `n=256`, so the DataLoader is built with `batch_size=256, drop_last=True`. Phase 2 calls with `n=128`, but the cached DataLoader still yields 256 → crash:

```
RuntimeError: The size of tensor a (128) must match the size of tensor b (256) at non-singleton dimension 0
```

**Impact**: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully. The error message doesn't reveal the caching problem — it looks like a model shape issue.

**Root cause**: Lazy initialization without invalidation. A classic stale-cache bug.

**Fix applied**: Track the requested batch size and recreate the DataLoader whenever it changes:

```python
def sample_target(self, n, device="cpu"):
    if not hasattr(self, "_loader") or self._batch_size != n:
        self._batch_size = n
        self._loader = get_image_dataloader(self.dataset_name, batch_size=n)
        self._iter = iter(self._loader)
```

---
### #7 — CLI flag not overriding all phases (LOW)

**What**: `--train-iters` overrode `nsgf_training.num_iterations` and `nsf_training.num_iterations` but missed `time_predictor.num_iterations` (default 40,000).

**Impact**: A smoke test with `--train-iters 5` completes Phases 1 and 2 in seconds, then hangs for hours on Phase 3.

**Prevention**: Grep the config for ALL instances of the parameter: `grep -n num_iterations config.yaml`. It appears 3 times — all 3 must be overridden.
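The grep step can also be enforced in code. A sketch of an exhaustive override that walks a nested config dict — the helper name is hypothetical, and the keys below mirror the three occurrences described above:

```python
def override_all(cfg, key, value):
    """Recursively set every occurrence of `key` in a nested config dict.
    Returns the number of occurrences touched, so the caller can assert
    the override was exhaustive."""
    count = 0
    for k, v in cfg.items():
        if k == key:
            cfg[k] = value
            count += 1
        elif isinstance(v, dict):
            count += override_all(v, key, value)
    return count

cfg = {
    "nsgf_training": {"num_iterations": 20000},
    "nsf_training": {"num_iterations": 20000},
    "time_predictor": {"num_iterations": 40000},  # the copy that was missed
}
n = override_all(cfg, "num_iterations", 5)
print(n)  # 3
```

Returning the count lets the CLI layer fail loudly (`assert n == 3`) if the config ever grows a fourth copy.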
---

### #8 — CIFAR-10 Sinkhorn OOM on T4 (CRITICAL)

**What**: The paper uses `sinkhorn.batch_size=128` for CIFAR-10. The Sinkhorn `tensorized` backend builds a 128×128 cost matrix over 3072-dim vectors (flattened 3×32×32). With `potentials=True` + `autograd.grad`, there are 2 calls per flow step × 5 steps per batch = 10 Sinkhorn calls per pool batch. That peaks at 6+ GB for Sinkhorn alone; add the 38M-param UNet and the T4's 16 GB runs out.

**Impact**: CIFAR-10 crashes during pool building, before training even starts.

**Root cause**: Used the paper's hyperparameters verbatim without estimating VRAM for the target hardware. The paper's authors likely used an A100 80GB.

**VRAM math that should have been done upfront**:

| Component | Approx VRAM |
|-----------|-------------|
| Sinkhorn call (128×3072, tensorized) | ~600 MB |
| × 10 calls per pool batch (with autograd) | ~6+ GB peak |
| UNet (38M params, fp32) | ~150 MB |
| Gradients + optimizer states | ~450 MB |
| **Total** | **~7+ GB** (not counting fragmentation) |

**Fix applied**: Reduce `sinkhorn.batch_size` 128→32 and increase `pool.num_batches` 2500→10000, leaving the total pool entries unchanged (1.6M). Added a `--sinkhorn-batch` CLI flag.
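The rescaling arithmetic behind the fix, as a small sketch (the helper name is hypothetical):

```python
def rescale_batches(old_batch, old_num_batches, new_batch):
    """Shrink the per-call batch to fit VRAM while keeping the total
    number of samples (batch × num_batches) exactly constant."""
    total = old_batch * old_num_batches
    assert total % new_batch == 0, "pick a new batch size that divides the total"
    return total // new_batch

# The #8 fix: 128 × 2500 = 32 × 10000 = 320,000 batch-entries;
# × 5 flow steps = 1.6M pool entries either way.
print(rescale_batches(128, 2500, 32))  # 10000
```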
---

### #9 — No GPU memory freed between phases (MODERATE)

**What**: After pool building finishes, the Sinkhorn computation graph's CUDA allocations stay cached, so training starts with reduced available VRAM.

**Fix applied**: `torch.cuda.empty_cache()` after pool building, plus `del pool` after finalization.

---

### #10 — Multi-GPU assumption (LOW)

**What**: The user has T4×2 on Kaggle, but the code is single-GPU — the second GPU sits idle.

**Prevention**: Document it explicitly: "Single-GPU only. T4×2 wastes the second GPU. Use a single T4 instead, or add DDP."

---

## Principles

Derived from the mistakes above. Each traces back to one or more specific failures.

1. **Read the appendix first.** The appendix is the recipe; the main paper is the story. *(from: all mistakes — the hyperparameters were in the appendix)*

2. **Test the boundaries, not just the happy path.** The bug is always in the path you didn't test. *(from: #1, #3)*

3. **Library APIs are opaque until tested.** Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script. *(from: #1)*

4. **Pre-concatenate, don't re-concatenate.** Any data structure built once and sampled many times should be finalized into a single tensor. *(from: #2)*

5. **The user's time is more expensive than your time.** A crash on their GPU after 5 minutes of setup is worse than spending 30 extra minutes testing. *(from: #1, #3, #8)*

6. **Flatten for OT libraries.** geomloss, POT, and ott-jax all expect `(N, D)` point clouds; images must be flattened. This is the #1 gotcha in OT-based generative models. *(from: #1)*

7. **Store training state on CPU, compute on GPU.** Keep pools and replay buffers on CPU; only the minibatch goes to the GPU. *(from: #8, #9)*

8. **Multi-phase training = multiple separate trainers.** Each phase gets its own optimizer; the previous phase's model goes to `eval()`. *(from: #6)*

9. **Shared objects across phases are landmines.** Cached DataLoaders, iterators, batch sizes — any phase-specific parameter can silently break later phases. *(from: #6)*

10. **CLI overrides must be exhaustive.** If a parameter appears N times in the config, the CLI must touch all N. *(from: #7)*

11. **Paper hyperparameters assume paper hardware.** Re-derive batch sizes from VRAM constraints; keep the total samples seen constant. *(from: #8)*

12. **Estimate VRAM before running, not after OOM.** Sinkhorn: O(N² × D). Model: params × 4 bytes × 3. Write it down before the first GPU run. *(from: #8)*

13. **Checkpoint at phase boundaries, not just step boundaries.** Phase-level checkpoints enable `--resume-phase`; step-level checkpoints within long phases are a bonus. *(from: #4)*

14. **Free GPU memory between phases.** `torch.cuda.empty_cache()` + `del` large objects between phases with different memory patterns. *(from: #9)*

15. **Document what your code does NOT support.** Single-GPU only? No mixed precision? Say so explicitly. *(from: #10)*

---

## Sinkhorn VRAM Reference Table

For quick VRAM estimation during paper-reproduction planning. `tensorized` backend, fp32, single `SamplesLoss` call with `potentials=True`.

| N (batch) | D (dim) | Example | Approx VRAM/call |
|-----------|---------|---------|------------------|
| 256 | 2 | 2D points | ~1 MB |
| 256 | 784 | MNIST 28×28 | ~200 MB |
| 128 | 3072 | CIFAR 3×32×32 | ~600 MB |
| 32 | 3072 | CIFAR (T4-safe) | ~40 MB |

Pool building does ~10 calls per batch (2 potentials × 5 flow steps). Multiply accordingly.
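These rows can be roughly cross-checked with a back-of-envelope model: assume the tensorized backend's peak allocation is the N×N×D pairwise-difference tensor (this is my modeling assumption, not a geomloss-documented figure):

```python
def sinkhorn_vram_mb(n, d, bytes_per=4):
    """Raw size in MiB of an N x N x D fp32 pairwise-difference tensor:
    one crude model for a single tensorized Sinkhorn call. Multiply by
    roughly 3x when autograd/potentials buffers are held alive."""
    return n * n * d * bytes_per / 2**20

for n, d in [(256, 2), (256, 784), (128, 3072), (32, 3072)]:
    print(f"N={n:4d} D={d:4d}: ~{sinkhorn_vram_mb(n, d):.1f} MB raw")
```

Under this model, the MNIST row (~200 MB) matches the raw tensor (196 MB), while the two CIFAR rows sit near 3× the raw tensor (192 → ~600 MB, 12 → ~40 MB), consistent with autograd buffers held during `potentials=True` calls. Treat it as order-of-magnitude planning only.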

---

## Timeline of Bugs

| When | What broke | How discovered | Sessions wasted |
|------|-----------|----------------|-----------------|
| Initial ship | geomloss shape (MNIST/CIFAR) | User's first GPU run | 1 |
| After geomloss fix | DataLoader batch mismatch (Phase 2) | My CPU test | 0 (caught in sandbox) |
| After DataLoader fix | `--train-iters` missed Phase 3 | My CPU test | 0 (caught in sandbox) |
| User's full MNIST run | No checkpointing, interrupted | User's Kaggle timeout | 1 |
| User's CIFAR run | Sinkhorn OOM on T4 | User's Kaggle log | 1 |
| After OOM fix | No `empty_cache()` between phases | Analysis of VRAM patterns | 0 (preventive) |

**Total Kaggle sessions wasted by the user: 3.** The contrast between bugs caught in the sandbox (0 sessions wasted) and bugs caught on user hardware (3 sessions) reinforces principle #5: test everything before shipping.