rogermt committed on
Commit 66d3632 · verified
1 Parent(s): 80b1d4b

SKILL.md: Add VRAM estimation, checkpointing, multi-session, multi-GPU lessons from CIFAR OOM


New mistakes #8-10, new principles #11-15, expanded Phase 4 (VRAM), new Phase 7 (checkpointing),
updated pre-flight checklist, updated error table.

Files changed (1)
  1. SKILL.md +193 -113
SKILL.md CHANGED
@@ -1,6 +1,6 @@
1
  ---
2
  name: paper-reproduction
3
- description: "Skill for reproducing ML research papers from scratch when no official code exists. Use this whenever a user asks to implement, reproduce, or replicate a paper β€” especially papers involving novel loss functions, custom training loops, or non-standard architectures that aren't covered by existing HF trainers. Also use when the user mentions 'paper reproduction', 'implement this paper', 'no official code', or describes a method from a specific arxiv paper. Covers: reading papers systematically, extracting hyperparameters, building custom training pipelines, handling library-specific gotchas (geomloss, POT, custom UNets), and iterating on GPU results."
4
  ---
5
 
6
  # Paper Reproduction Skill
@@ -32,6 +32,8 @@ Most reproduction failures trace back to incomplete paper reading. Don't skim
32
  β–‘ Evaluation protocol β€” metrics, number of samples, any special setup
33
  β–‘ Hyperparameters per experiment β€” papers often have different configs per dataset
34
  β–‘ Algorithm pseudocode β€” if provided, follow it exactly before improvising
 
 
35
  ```
36
 
37
  ### Mistake I made: Incomplete appendix reading
@@ -82,39 +84,7 @@ The `SamplesLoss` in geomloss requires inputs as `(N, D)` or `(B, N, D)` tensors
82
  ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
83
  ```
84
 
85
- **The fix**: Flatten images before passing to geomloss, reshape gradients back after:
86
-
87
- ```python
88
- def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
89
- original_shape = X.shape
90
- is_image = X.dim() == 4
91
-
92
- if is_image:
93
- B = X.shape[0]
94
- X_flat = X.detach().clone().view(B, -1).requires_grad_(True)
95
- Y_flat = Y.detach().view(B, -1)
96
- else:
97
- X_flat = X.detach().clone().requires_grad_(True)
98
- Y_flat = Y.detach()
99
-
100
- # Self-potential
101
- F_self, _ = self.loss_fn(X_flat, X_flat.detach().clone())
102
- grad_self = torch.autograd.grad(F_self.sum(), X_flat)[0]
103
-
104
- # Cross-potential
105
- X_flat2 = X.detach().clone().view(B, -1).requires_grad_(True) if is_image else X.detach().clone().requires_grad_(True)
106
- F_cross, _ = self.loss_fn(X_flat2, Y_flat)
107
- grad_cross = torch.autograd.grad(F_cross.sum(), X_flat2)[0]
108
-
109
- velocity = grad_self.detach() - grad_cross.detach()
110
-
111
- if is_image:
112
- velocity = velocity.view(original_shape)
113
-
114
- return velocity
115
- ```
116
-
117
- This pattern β€” flatten before library call, reshape after β€” applies to many optimal transport libraries (POT, geomloss, ott-jax).
118
 
119
  ---
120
 
@@ -131,76 +101,120 @@ When building a UNet from scratch (rather than importing from guided-diffusion),
131
 
132
  **Mistake pattern**: Using a helper like `_get_num_res_blocks()` that infers block count from module list lengths. This is fragile β€” if the number of levels or blocks per level varies, the inference breaks.
133
 
134
- **Better approach**: Store `num_res_blocks` as an instance variable at init time and use it directly:
135
 
136
- ```python
137
- def __init__(self, ..., num_res_blocks=2, channel_mult=[1,2,2,2], ...):
138
- self.num_res_blocks = num_res_blocks
139
- self.num_levels = len(channel_mult)
140
- # ... build layers ...
141
 
142
- def forward(self, x, t):
143
- # Use self.num_res_blocks directly, not a computed value
144
  ```
145
 
146
- ### GroupNorm channel requirements
147
 
148
- `nn.GroupNorm(32, channels)` requires `channels` to be divisible by 32. For small models (e.g., MNIST with `model_channels=32`), this is fine at the first level but may break at deeper levels if `channel_mult` creates channels not divisible by 32.
 
 
149
 
150
- **Safety check**: At init time, verify all channel counts are divisible by the group count:
151
  ```python
152
- for level, mult in enumerate(channel_mult):
153
- ch = model_channels * mult
154
- assert ch % 32 == 0, f"Level {level}: channels={ch} not divisible by 32"
155
  ```
156

157
  ---
158
 
159
- ## Phase 4: Training Loop Patterns for Custom Pipelines
160
 
161
- ### Trajectory pool memory management
162
 
163
- When building trajectory pools (storing (x, v, t) tuples from gradient flow), memory can explode:
164
- - MNIST: 256 samples Γ— 1500 batches Γ— 5 steps Γ— 784 dims Γ— 4 bytes β‰ˆ 2.4 GB (manageable)
165
- - CIFAR-10: 128 samples Γ— 2500 batches Γ— 5 steps Γ— 3072 dims Γ— 4 bytes β‰ˆ 19 GB (tight on T4)
166
 
167
- **The fix**: Store pool tensors on CPU, transfer to GPU only during sampling:
 
 
168
 
169
- ```python
170
- class TrajectoryPool:
171
- def sample(self, batch_size, device="cpu"):
172
- # Concatenate on CPU, index, THEN move to device
173
- all_x = torch.cat(self.x_pool, dim=0) # stays on CPU
174
- idx = torch.randint(0, all_x.shape[0], (batch_size,))
175
- return all_x[idx].to(device), ... # only batch moves to GPU
176
  ```
177
 
178
- **Mistake I made**: The pool sampling code calls `torch.cat` on the entire pool every training step, which is O(pool_size) per step. For 512K entries this is slow. Better: pre-concatenate once after pool building, then just index:
179
 
180
- ```python
181
- def finalize(self):
182
- """Call once after pool is fully built."""
183
- self._all_x = torch.cat(self.x_pool, dim=0)
184
- self._all_v = torch.cat(self.v_pool, dim=0)
185
- self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
186
- # Free the lists
187
- self.x_pool = self.v_pool = self.t_pool = None
188
-
189
- def sample(self, batch_size, device="cpu"):
190
- idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
191
- return self._all_x[idx].to(device), self._all_v[idx].to(device), self._all_t[idx].to(device)
192
  ```
193
 
194
- ### Multi-phase training (NSGF++)
195
 
196
- NSGF++ has 3 sequential training phases:
197
- 1. **NSGF**: Build trajectory pool β†’ train velocity field
198
- 2. **NSF**: Use trained NSGF to generate P0 samples β†’ train straight flow
199
- 3. **Phase predictor**: Train CNN to predict transition time
200
 
201
- **Key insight**: Each phase depends on the previous one being fully trained. Don't try to interleave them. The NSGF model must be in `eval()` mode when used as a sample generator in phases 2 and 3.
 
 
202
 
203
- ### Shared state across training phases β€” the DataLoader trap
204
 
205
  When a single `DatasetLoader` object is shared across multiple training phases, **lazy-initialized internal state** (like a cached DataLoader) will silently break subsequent phases.
206
 
@@ -225,40 +239,69 @@ def sample_target(self, n, device="cpu"):
225
 
226
  ---
227
 
228
- ## Phase 5: Testing Strategy
229
 
230
- ### Always test on CPU first with tiny configs
231
 
232
- Before any GPU run, verify the full pipeline works end-to-end:
233
 
234
- ```bash
235
- # Tiny run β€” should complete in <30 seconds
236
- python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 5 --train-iters 100
237
 
238
- # Slightly larger β€” should complete in <5 minutes
239
- python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 20 --train-iters 2000
240
  ```
241
 
242
- ### Test image experiments separately with minimal configs
243
 
244
  ```bash
245
- # MNIST smoke test β€” 2 pool batches, 5 training iters per phase
246
- python main.py --experiment mnist --pool-batches 2 --train-iters 5
247
 
248
- # If this crashes, fix before scaling up
249
  ```
250
 
251
- **Mistake I made**: I tested 2D experiments thoroughly on CPU (both tiny and medium runs worked) but shipped the image experiments without testing them at all. The geomloss tensor shape bug affected ONLY the image path, so 2D success gave false confidence. The first GPU test of MNIST crashed immediately.
252
 
253
- **Rule**: Test EVERY experiment type, not just the simplest one. If you have `{2d, mnist, cifar10}` experiments, test all three with minimal configs before declaring the code ready.
254
 
255
- ### Test all training phases, not just the first one
256
 
257
- Even after fixing Phase 1, Phase 2 can still crash due to shared state (see DataLoader trap above). Run with `--train-iters 5 --pool-batches 2` to verify all 3 phases complete without errors. This takes <60 seconds on CPU for MNIST.
 
 
 
258
 
259
  ---
260
 
261
- ## Phase 6: Debugging GPU Runs
262
 
263
  ### Common error patterns
264
 
@@ -267,20 +310,24 @@ Even after fixing Phase 1, Phase 2 can still crash due to shared state (see Data
267
  | `ValueError: (N,D) or (B,N,D)` | Library expects flat tensors, got images | Flatten before library call |
268
  | `RuntimeError: size of tensor a (X) must match size of tensor b (Y)` | Shared DataLoader with wrong batch size | Recreate DataLoader when batch size changes |
269
  | `RuntimeError: shape mismatch` in UNet | Skip connection count wrong | Count pushes and pops manually |
270
- | `CUDA OOM` during pool building | Pool too large for GPU | Build pool on CPU, sample to GPU |
271
- | `CUDA OOM` during training | Batch too large or model too big | Reduce batch β†’ increase grad accum |
 
272
  | Training loss plateaus high | Pool too small or too few iterations | Increase pool batches, more iters |
273
  | W2 distance too high | Undertrained model | Full paper config: 200 batches, 20k iters |
274
- | `KeyboardInterrupt` during training | Training takes too long at scale | Expected β€” full 2D takes ~20min on T4 |
 
275
 
276
  ### When the user runs on their hardware
277
 
278
  If you're developing code that the user will run on their own GPU (Kaggle, Colab, local):
279
 
280
  1. **Provide exact commands** β€” don't make them figure out args
281
- 2. **Warn about expected runtimes** β€” "2D full run: ~20min on T4, MNIST: ~2-4 hours, CIFAR-10: ~8-12 hours"
282
  3. **Include checkpoint saving** β€” so partial runs aren't wasted
283
- 4. **Test the exact commands yourself** β€” if you can't run on GPU, at least verify the command parses correctly on CPU
 
 
284
 
285
  ---
286
 
@@ -306,10 +353,10 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
306
  - **Root cause**: False confidence from 2D success. Assumed same code path.
307
  - **Prevention**: Test EVERY experiment type with minimal configs. Different experiment types often exercise different code paths.
308
 
309
- 4. **No checkpoint saving** (MODERATE)
310
  - **What**: No intermediate checkpoints during long training runs.
311
- - **Impact**: If training is interrupted (Kaggle timeout, OOM), all progress is lost.
312
- - **Prevention**: Save checkpoints every N iterations. Implement `--resume` flag.
313
 
314
  5. **UNet forward pass fragility** (LOW-MODERATE)
315
  - **What**: `_get_num_res_blocks()` infers block count from module list length division.
@@ -318,15 +365,33 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
318
 
319
  6. **DataLoader batch size mismatch across phases** (CRITICAL)
320
  - **What**: Shared `DatasetLoader` caches a DataLoader with batch_size=256 from Phase 1. Phase 2 requests batch_size=128 but gets 256 back β†’ tensor dimension mismatch crash.
321
- - **Impact**: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully. The error message (`size of tensor a (128) must match size of tensor b (256)`) doesn't make the DataLoader caching obvious.
322
- - **Root cause**: Lazy initialization pattern without invalidation. The `_image_loader` was created once and never checked for batch size changes.
323
- - **Prevention**: When sharing stateful objects across consumers with different configs, either (a) track all cached parameters and invalidate on change, or (b) don't cache at all. For DataLoaders: recreate when batch_size changes.
324
 
325
  7. **CLI flag not overriding all training phases** (LOW)
326
  - **What**: `--train-iters` flag overrode NSGF and NSF iterations but NOT the phase predictor iterations (40,000 default). Smoke tests would hang on Phase 3 even with `--train-iters 5`.
327
- - **Impact**: Tests take much longer than expected. User thinks something is broken.
328
  - **Root cause**: Forgot that 3-phase training means 3 iteration counts to override.
329
- - **Prevention**: When adding a CLI override, grep the config for ALL fields it should affect. If a config has `nsgf_training.num_iterations`, `nsf_training.num_iterations`, AND `time_predictor.num_iterations`, the override must touch all three.
330
 
331
  ---
332
 
@@ -338,11 +403,16 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
338
  β–‘ Third-party library APIs tested with exact tensor shapes per experiment
339
  β–‘ Shared state across phases verified (DataLoaders, iterators, caches)
340
  β–‘ CLI flags override ALL relevant config values (not just some)
 
 
 
341
  β–‘ Training loop profiled β€” no O(N) operations per step where O(1) suffices
342
  β–‘ Memory estimated per experiment (pool size Γ— data dim Γ— 4 bytes)
343
- β–‘ Checkpointing implemented for runs >10 minutes
344
- β–‘ Clear CLI with sensible defaults and override flags
 
345
  β–‘ Expected runtimes documented per hardware tier
 
346
  β–‘ Error messages are clear (not just stack traces)
347
  β–‘ Results directory created automatically
348
  β–‘ Requirements.txt includes ALL dependencies with minimum versions
@@ -371,3 +441,13 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
371
  9. **Shared objects across phases are landmines.** When a DataLoader, iterator, or cache is shared across training phases, any phase-specific parameter (batch size, number of workers, shuffle mode) can silently break later phases. Either don't share, or implement proper invalidation. Test by running all phases sequentially with different configs per phase.
372
 
373
  10. **CLI overrides must be exhaustive.** If your config has N copies of a parameter (one per training phase), your CLI override must touch all N. Grep the config file for the parameter name to find all instances.
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  name: paper-reproduction
3
+ description: "Skill for reproducing ML research papers from scratch when no official code exists. Use this whenever a user asks to implement, reproduce, or replicate a paper β€” especially papers involving novel loss functions, custom training loops, or non-standard architectures that aren't covered by existing HF trainers. Also use when the user mentions 'paper reproduction', 'implement this paper', 'no official code', or describes a method from a specific arxiv paper. Covers: reading papers systematically, extracting hyperparameters, building custom training pipelines, handling library-specific gotchas (geomloss, POT, custom UNets), VRAM estimation, checkpointing for multi-session training, and iterating on GPU results."
4
  ---
5
 
6
  # Paper Reproduction Skill
 
32
  β–‘ Evaluation protocol β€” metrics, number of samples, any special setup
33
  β–‘ Hyperparameters per experiment β€” papers often have different configs per dataset
34
  β–‘ Algorithm pseudocode β€” if provided, follow it exactly before improvising
35
+ β–‘ GPU hardware used β€” what the authors trained on (often buried in appendix)
36
+ β–‘ Training time β€” how long did the authors' runs take?
37
  ```
38
 
39
  ### Mistake I made: Incomplete appendix reading
 
84
  ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
85
  ```
86
 
87
+ **The fix**: Flatten images before passing to geomloss, reshape gradients back after. This pattern β€” flatten before library call, reshape after β€” applies to many optimal transport libraries (POT, geomloss, ott-jax).
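A minimal sketch of the pattern, assuming geomloss's `SamplesLoss` (the loss settings shown are illustrative, not the paper's exact configuration):

```python
import torch
from geomloss import SamplesLoss

loss_fn = SamplesLoss("sinkhorn", p=2, blur=0.05)  # illustrative settings

def sinkhorn_on_images(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Flatten (B, C, H, W) image batches to (B, D) before the geomloss call."""
    if x.dim() == 4:
        x = x.reshape(x.shape[0], -1)  # (B, C*H*W): each image becomes one D-dim sample
        y = y.reshape(y.shape[0], -1)
    return loss_fn(x, y)
```

If you need gradients with respect to the images (as in the velocity computation), take them on the flattened tensor and `.view()` the result back to the original image shape before using it.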
88
 
89
  ---
90
 
 
101
 
102
  **Mistake pattern**: Using a helper like `_get_num_res_blocks()` that infers block count from module list lengths. This is fragile β€” if the number of levels or blocks per level varies, the inference breaks.
103
 
104
+ **Better approach**: Store `num_res_blocks` as an instance variable at init time and use it directly.
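A minimal sketch of what that looks like, with unrelated constructor arguments omitted:

```python
import torch.nn as nn

class UNet(nn.Module):
    def __init__(self, model_channels=32, num_res_blocks=2, channel_mult=(1, 2, 2, 2)):
        super().__init__()
        # Record the config once at init; never re-derive it from module-list lengths.
        self.num_res_blocks = num_res_blocks
        self.num_levels = len(channel_mult)
        # ... build down/up blocks using these values ...

    def forward(self, x, t):
        # Loop with self.num_res_blocks / self.num_levels directly,
        # not with a count inferred from len(self.down_blocks).
        ...
```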
105
 
106
+ ### GroupNorm channel requirements
107
+
108
+ `nn.GroupNorm(32, channels)` requires `channels` to be divisible by 32. For small models (e.g., MNIST with `model_channels=32`), this is fine at the first level but may break at deeper levels if `channel_mult` creates channels not divisible by 32.
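A minimal init-time check (a sketch; `groups=32` matches the `nn.GroupNorm(32, channels)` usage above):

```python
def check_groupnorm_channels(model_channels, channel_mult, groups=32):
    """Fail fast at init if any level's channel count breaks nn.GroupNorm(groups, ch)."""
    for level, mult in enumerate(channel_mult):
        ch = model_channels * mult
        assert ch % groups == 0, f"Level {level}: channels={ch} not divisible by {groups}"

check_groupnorm_channels(model_channels=32, channel_mult=[1, 2, 2, 2])
```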
109
+
110
+ ---
111
+
112
+ ## Phase 4: VRAM Estimation and Memory Management
113
+
114
+ ### Estimate VRAM BEFORE running β€” not after OOM
115
+
116
+ Papers report batch sizes that worked on their hardware (often A100 80GB or 8Γ—V100). If your user has a T4 (16GB) or even a T4Γ—2 (16GB per GPU, but single-GPU code only uses one), you must recalculate whether the paper's configs will fit.
117
+
118
+ ### The Sinkhorn VRAM trap
119
+
120
+ The `tensorized` backend in geomloss computes a full NΓ—N cost matrix. For N samples of dimension D:
121
+ - Memory β‰ˆ O(NΒ² Γ— D) for the cost matrix + intermediate Sinkhorn iterations
122
+ - With `potentials=True` and `autograd.grad`, add another O(N Γ— D) for gradient storage
123
+
124
+ **Concrete examples (fp32, single Sinkhorn call)**:
125
+ | N (batch) | D (flattened dim) | Approx VRAM per call |
126
+ |-----------|-------------------|---------------------|
127
+ | 256 | 2 (2D points) | ~1 MB |
128
+ | 256 | 784 (MNIST 28Γ—28) | ~200 MB |
129
+ | 128 | 3072 (CIFAR 3Γ—32Γ—32) | ~600 MB |
130
 
131
+ But pool building calls Sinkhorn **twice per step** (self-potential + cross-potential) Γ— **5 flow steps per batch** = 10 Sinkhorn calls per pool batch. With autograd overhead, 128Γ—3072 easily eats 8+ GB β€” leaving no room for the 38M-param UNet on a 16GB T4.
132
+
133
+ **Mistake I made**: Used the paper's `sinkhorn.batch_size=128` for CIFAR-10. This OOMed immediately on T4. The paper's authors likely used A100s.
134
+
135
+ **The fix**: Reduce Sinkhorn batch size for smaller GPUs and increase pool batches to compensate:
136
+ ```yaml
137
+ # Paper config (A100 80GB):
138
+ sinkhorn.batch_size: 128
139
+ pool.num_batches: 2500
140
+ # Total pool entries: 128 Γ— 2500 Γ— 5 = 1.6M
141
+
142
+ # T4 16GB config:
143
+ sinkhorn.batch_size: 32
144
+ pool.num_batches: 10000
145
+ # Total pool entries: 32 Γ— 10000 Γ— 5 = 1.6M (same!)
146
  ```
147
 
148
+ Add a CLI override (`--sinkhorn-batch`) so users can tune without editing config files.
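A back-of-the-envelope estimator for the numbers above (a sketch; the ~10x overhead factor is a rough allowance for Sinkhorn intermediates and autograd, not a measured constant):

```python
def sinkhorn_vram_gb(n: int, d: int, bytes_per_el: int = 4, overhead: float = 10.0) -> float:
    """Rough VRAM for one tensorized Sinkhorn call: O(N^2 * D) cost computation plus slack."""
    cost = n * n * d * bytes_per_el        # dominant N x N x D term
    potentials = 2 * n * d * bytes_per_el  # potential / gradient storage
    return (cost + potentials) * overhead / 1e9

print(sinkhorn_vram_gb(128, 3072))  # paper's CIFAR batch: ~2 GB per call
print(sinkhorn_vram_gb(32, 3072))   # T4-friendly batch: ~0.13 GB per call
```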
149
 
150
+ ### Always call `torch.cuda.empty_cache()` between phases
151
+
152
+ Pool building uses GPU for Sinkhorn computation. Training uses GPU for the neural network. These are different memory patterns. After pool building, the Sinkhorn computation graph is no longer needed β€” but PyTorch's CUDA allocator may still hold that memory. Explicitly free it:
153
 
 
154
  ```python
155
+ def build_trajectory_pool(self, ...):
156
+ # ... build pool ...
157
+ if self.device != "cpu":
158
+ torch.cuda.empty_cache() # Free Sinkhorn memory before training
159
+ self.pool.finalize()
160
+ ```
161
+
162
+ ### Multi-GPU β‰  automatic parallelism
163
+
164
+ If the user has a T4Γ—2 on Kaggle, your single-GPU code will only use ONE of the two GPUs. The second sits idle. Using both requires PyTorch DDP or model parallelism β€” which is a significant code change.
165
+
166
+ **Don't silently assume multi-GPU works.** Document this:
167
+ ```
168
+ NOTE: This code uses a single GPU. If you have T4Γ—2, only one GPU is used.
169
+ A single T4 (16GB) is sufficient β€” the second GPU is wasted without DDP.
170
  ```
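One way to make that limitation visible at runtime rather than only in the README (a sketch):

```python
import torch

# Warn loudly at startup instead of silently using one of two GPUs.
if torch.cuda.device_count() > 1:
    print(f"WARNING: {torch.cuda.device_count()} GPUs visible, but this code is single-GPU; "
          "only cuda:0 will be used (no DDP).")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
```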
171
 
172
+ ### Trajectory pool memory on CPU vs GPU
173
+
174
+ The trajectory pool stores ALL flow trajectories for the entire training. For image experiments this is gigabytes:
175
+ - MNIST: 1.92M entries Γ— 784 dims Γ— 4 bytes = **6 GB** on CPU
176
+ - CIFAR: 1.6M entries Γ— 3072 dims Γ— 4 bytes = **19.6 GB** on CPU
177
+
178
+ The pool MUST live on CPU. Only the sampled minibatch (128-256 samples) goes to GPU per training step. This is already how the code works (trajectories stored as CPU tensors, `.to(device)` in `sample()`), but it's worth being explicit about why.
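For reference, the relevant part of that pattern looks roughly like this (a sketch of the finalized pool; attribute names are illustrative):

```python
import torch

class TrajectoryPool:
    def __init__(self, all_x, all_v, all_t):
        # Finalized pool tensors stay on CPU for the whole run.
        self._all_x, self._all_v, self._all_t = all_x, all_v, all_t

    def sample(self, batch_size, device="cpu"):
        idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
        # Only the indexed minibatch is copied to the GPU.
        return (self._all_x[idx].to(device),
                self._all_v[idx].to(device),
                self._all_t[idx].to(device))
```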
179
+
180
  ---
181
 
182
+ ## Phase 5: Testing Strategy
183
 
184
+ ### Always test on CPU first with tiny configs
185
 
186
+ Before any GPU run, verify the full pipeline works end-to-end:
 
 
187
 
188
+ ```bash
189
+ # Tiny run β€” should complete in <30 seconds
190
+ python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 5 --train-iters 100
191
 
192
+ # Slightly larger β€” should complete in <5 minutes
193
+ python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 20 --train-iters 2000
194
  ```
195
 
196
+ ### Test image experiments separately with minimal configs
197
 
198
+ ```bash
199
+ # MNIST smoke test β€” 2 pool batches, 5 training iters per phase
200
+ python main.py --experiment mnist --pool-batches 2 --train-iters 5
201
+
202
+ # If this crashes, fix before scaling up
203
  ```
204
 
205
+ **Mistake I made**: I tested 2D experiments thoroughly on CPU (both tiny and medium runs worked) but shipped the image experiments without testing them at all. The geomloss tensor shape bug affected ONLY the image path, so 2D success gave false confidence. The first GPU test of MNIST crashed immediately.
206
+
207
+ **Rule**: Test EVERY experiment type, not just the simplest one. If you have `{2d, mnist, cifar10}` experiments, test all three with minimal configs before declaring the code ready.
208
+
209
+ ### Test all training phases, not just the first one
210
 
211
+ Even after fixing Phase 1, Phase 2 can still crash due to shared state (see DataLoader trap in Phase 6). Run with `--train-iters 5 --pool-batches 2` to verify all 3 phases complete without errors. This takes <60 seconds on CPU for MNIST.
 
 
 
212
 
213
+ ---
214
+
215
+ ## Phase 6: Shared State Across Training Phases
216
 
217
+ ### The DataLoader trap
218
 
219
  When a single `DatasetLoader` object is shared across multiple training phases, **lazy-initialized internal state** (like a cached DataLoader) will silently break subsequent phases.
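The fix described in the post-mortem (recreate the loader whenever the requested batch size differs from the cached one) looks roughly like this sketch; the class and attribute names are illustrative:

```python
from torch.utils.data import DataLoader

class DatasetLoader:
    def __init__(self, dataset):
        self.dataset = dataset
        self._image_loader = None
        self._cached_batch_size = None

    def get_image_loader(self, batch_size):
        # Invalidate the lazily created loader when any cached parameter changes.
        if self._image_loader is None or batch_size != self._cached_batch_size:
            self._image_loader = DataLoader(self.dataset, batch_size=batch_size, shuffle=True)
            self._cached_batch_size = batch_size
        return self._image_loader
```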
220
 
 
239
 
240
  ---
241
 
242
+ ## Phase 7: Checkpointing and Multi-Session Training
243
 
244
+ ### Why this matters
245
 
246
+ Paper reproduction often requires training runs that exceed a single GPU session. Kaggle gives 9 hours per T4 session. MNIST NSGF++ with full paper config (100K+100K+40K iters) needs ~7-8 hours on T4 β€” tight. CIFAR-10 (200K+200K+40K) is impossible in one session.
247
 
248
+ Without checkpointing, a Kaggle timeout = all progress lost.
 
 
249
 
250
+ ### Phase-level checkpointing
251
+
252
+ For multi-phase training, save a checkpoint after EACH phase completes:
253
+
254
+ ```python
255
+ # After Phase 1 completes:
256
+ torch.save({
257
+ "nsgf_model_state": nsgf_model.state_dict(),
258
+ "phase": 1,
259
+ }, "checkpoints/phase1_complete.pt")
260
+
261
+ # After Phase 2 completes:
262
+ torch.save({
263
+ "nsgf_model_state": nsgf_model.state_dict(),
264
+ "nsf_model_state": nsf_model.state_dict(),
265
+ "phase": 2,
266
+ }, "checkpoints/phase2_complete.pt")
267
  ```
268
 
269
+ Then implement `--resume-phase N` that loads the phase N-1 checkpoint and skips completed phases:
270
 
271
  ```bash
272
+ # Session 1: Run Phase 1 (gets interrupted or completes)
273
+ python main.py --experiment mnist
274
 
275
+ # Session 2: Skip Phase 1, start Phase 2
276
+ python main.py --experiment mnist --resume-phase 2
277
+
278
+ # Session 3: Skip Phases 1+2, run Phase 3 + inference
279
+ python main.py --experiment mnist --resume-phase 3
280
  ```
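A sketch of what the `--resume-phase` handling can look like in `main.py` (the `train_phase*` callables are placeholders for your existing phase trainers; the checkpoint keys follow the snippets above):

```python
import torch

def run_pipeline(args, nsgf_model, nsf_model, train_phase1, train_phase2, train_phase3):
    if args.resume_phase <= 1:
        train_phase1(nsgf_model)
        torch.save({"nsgf_model_state": nsgf_model.state_dict(), "phase": 1},
                   "checkpoints/phase1_complete.pt")
    else:
        ckpt = torch.load("checkpoints/phase1_complete.pt", map_location="cpu")
        nsgf_model.load_state_dict(ckpt["nsgf_model_state"])

    if args.resume_phase <= 2:
        train_phase2(nsf_model, nsgf_model)
        torch.save({"nsgf_model_state": nsgf_model.state_dict(),
                    "nsf_model_state": nsf_model.state_dict(), "phase": 2},
                   "checkpoints/phase2_complete.pt")
    else:
        ckpt = torch.load("checkpoints/phase2_complete.pt", map_location="cpu")
        nsf_model.load_state_dict(ckpt["nsf_model_state"])

    train_phase3(nsgf_model, nsf_model)  # phase predictor always runs last
```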
281
 
282
+ ### Step-level checkpointing within phases
283
 
284
+ For long phases (100K+ steps), also save within the phase every N steps:
285
 
286
+ ```python
287
+ if (step + 1) % checkpoint_every == 0:
288
+ torch.save({
289
+ "model_state": model.state_dict(),
290
+ "optimizer_state": optimizer.state_dict(),
291
+ "step": step + 1,
292
+ }, "checkpoints/nsgf_checkpoint.pt")
293
+ ```
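And the matching load side, so an interrupted phase can pick up where it stopped (a sketch using the keys saved above):

```python
import os
import torch

def maybe_resume(model, optimizer, path="checkpoints/nsgf_checkpoint.pt"):
    """Return the step to resume from (0 if no checkpoint exists)."""
    if not os.path.exists(path):
        return 0
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model_state"])
    optimizer.load_state_dict(ckpt["optimizer_state"])
    return ckpt["step"]

# start_step = maybe_resume(model, optimizer)
# for step in range(start_step, num_iterations): ...
```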
294
+
295
+ ### Important: checkpoint persistence on Kaggle
296
 
297
+ Kaggle notebooks persist `/kaggle/working/` across cells within the same session, but NOT across sessions. To carry checkpoints between sessions:
298
+ 1. Save checkpoints to `/kaggle/working/nsgf-plusplus/checkpoints/`
299
+ 2. Before session ends, commit the notebook output or copy checkpoints to a dataset
300
+ 3. In the new session, restore checkpoints before running `--resume-phase`
301
 
302
  ---
303
 
304
+ ## Phase 8: Debugging GPU Runs
305
 
306
  ### Common error patterns
307
 
 
310
  | `ValueError: (N,D) or (B,N,D)` | Library expects flat tensors, got images | Flatten before library call |
311
  | `RuntimeError: size of tensor a (X) must match size of tensor b (Y)` | Shared DataLoader with wrong batch size | Recreate DataLoader when batch size changes |
312
  | `RuntimeError: shape mismatch` in UNet | Skip connection count wrong | Count pushes and pops manually |
313
+ | `CUDA OOM` during pool building (Sinkhorn) | Sinkhorn batch too large for GPU | Reduce `--sinkhorn-batch` (e.g. 128β†’32) |
314
+ | `CUDA OOM` during training | Training batch too large or model too big | Reduce training batch, increase grad accum |
315
+ | `CUDA OOM` at phase transition | Memory not freed between phases | Add `torch.cuda.empty_cache()` + `del pool` |
316
  | Training loss plateaus high | Pool too small or too few iterations | Increase pool batches, more iters |
317
  | W2 distance too high | Undertrained model | Full paper config: 200 batches, 20k iters |
318
+ | Only 1 of 2 GPUs used | Code is single-GPU, no DDP | Expected β€” use single GPU or add DDP |
319
+ | `KeyboardInterrupt` mid-training | Training too long at scale | Check `checkpoints/` for latest save |
320
 
321
  ### When the user runs on their hardware
322
 
323
  If you're developing code that the user will run on their own GPU (Kaggle, Colab, local):
324
 
325
  1. **Provide exact commands** β€” don't make them figure out args
326
+ 2. **Warn about expected runtimes** β€” "2D full run: ~20min on T4, MNIST: ~2-4 hours per phase, CIFAR-10: ~4+ hours per phase"
327
  3. **Include checkpoint saving** β€” so partial runs aren't wasted
328
+ 4. **Document GPU requirements** β€” "MNIST fits on T4 16GB, CIFAR-10 needs `--sinkhorn-batch 32`"
329
+ 5. **Document multi-GPU limitations** β€” "Single-GPU only. T4Γ—2 wastes the second GPU."
330
+ 6. **Test the exact commands yourself** β€” if you can't run on GPU, at least verify the command parses correctly on CPU
331
 
332
  ---
333
 
 
353
  - **Root cause**: False confidence from 2D success. Assumed same code path.
354
  - **Prevention**: Test EVERY experiment type with minimal configs. Different experiment types often exercise different code paths.
355
 
356
+ 4. **No checkpoint saving** (MODERATE β†’ became CRITICAL at scale)
357
  - **What**: No intermediate checkpoints during long training runs.
358
+ - **Impact**: If training is interrupted (Kaggle timeout, OOM, accidental Ctrl+C), all progress is lost. MNIST full run is ~7 hours β€” losing that is devastating.
359
+ - **Prevention**: Save checkpoints every N iterations. Save after each phase. Implement `--resume-phase` flag. Test resume actually works.
360
 
361
  5. **UNet forward pass fragility** (LOW-MODERATE)
362
  - **What**: `_get_num_res_blocks()` infers block count from module list length division.
 
365
 
366
  6. **DataLoader batch size mismatch across phases** (CRITICAL)
367
  - **What**: Shared `DatasetLoader` caches a DataLoader with batch_size=256 from Phase 1. Phase 2 requests batch_size=128 but gets 256 back β†’ tensor dimension mismatch crash.
368
+ - **Impact**: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully.
369
+ - **Root cause**: Lazy initialization pattern without invalidation.
370
+ - **Prevention**: When sharing stateful objects across consumers with different configs, track all cached parameters and invalidate on change.
371
 
372
  7. **CLI flag not overriding all training phases** (LOW)
373
  - **What**: `--train-iters` flag overrode NSGF and NSF iterations but NOT the phase predictor iterations (40,000 default). Smoke tests would hang on Phase 3 even with `--train-iters 5`.
374
+ - **Impact**: Tests take much longer than expected.
375
  - **Root cause**: Forgot that 3-phase training means 3 iteration counts to override.
376
+ - **Prevention**: When adding a CLI override, grep the config for ALL fields it should affect.
377
+
378
+ 8. **CIFAR-10 Sinkhorn OOM on T4** (CRITICAL)
379
+ - **What**: Paper uses `sinkhorn.batch_size=128` for CIFAR. Sinkhorn on 128 Γ— 3072-dim (flattened 3Γ—32Γ—32) with `tensorized` backend computes a 128Γ—128 cost matrix with 3072-dim vectors, plus autograd for potentials. This OOMs on T4 16GB during pool building.
380
+ - **Impact**: CIFAR-10 experiment crashes before even starting training. User loses their Kaggle session.
381
+ - **Root cause**: Used paper's hyperparameters without estimating VRAM for target hardware. Paper authors likely used A100 80GB.
382
+ - **Prevention**: ALWAYS estimate VRAM before running. Sinkhorn with the `tensorized` backend is O(N² × D). For CIFAR: 128² × 3072 × 4 bytes × ~10 (overhead) ≈ 2+ GB per call, ×10 calls per pool batch = too much. Reducing N from 128 to 32 makes the N² term 16× cheaper. Add a `--sinkhorn-batch` CLI flag so users can tune without editing config.
383
+
384
+ 9. **No GPU memory freed between phases** (MODERATE)
385
+ - **What**: After pool building, the Sinkhorn computation graph's CUDA allocations remain cached even though they're no longer needed. Training then starts with less available VRAM.
386
+ - **Impact**: Training phase might OOM even though pool building finished.
387
+ - **Root cause**: PyTorch's CUDA allocator doesn't automatically return memory to the OS.
388
+ - **Prevention**: `torch.cuda.empty_cache()` after pool building completes. Also `del pool` if the pool data was already finalized to separate tensors.
389
+
390
+ 10. **Multi-GPU assumption** (LOW)
391
+ - **What**: User has T4Γ—2 on Kaggle. Code is single-GPU. Second GPU sits idle.
392
+ - **Impact**: User pays for 2 GPUs but only uses 1. They might think the code is broken.
393
+ - **Root cause**: Didn't document single-GPU limitation.
394
+ - **Prevention**: Document GPU requirements explicitly. If multi-GPU is needed, implement DDP β€” but that's a significant scope change, so discuss with user first.
395
 
396
  ---
397
 
 
403
  β–‘ Third-party library APIs tested with exact tensor shapes per experiment
404
  β–‘ Shared state across phases verified (DataLoaders, iterators, caches)
405
  β–‘ CLI flags override ALL relevant config values (not just some)
406
+ β–‘ VRAM estimated for target hardware β€” will Sinkhorn/model/pool fit?
407
+ β–‘ Sinkhorn batch size appropriate for target GPU (not just paper's GPU)
408
+ β–‘ torch.cuda.empty_cache() called between memory-intensive phases
409
  β–‘ Training loop profiled β€” no O(N) operations per step where O(1) suffices
410
  β–‘ Memory estimated per experiment (pool size Γ— data dim Γ— 4 bytes)
411
+ β–‘ Checkpointing implemented: every N steps + after each phase
412
+ β–‘ --resume-phase tested and working (load checkpoint β†’ skip phases β†’ continue)
413
+ β–‘ Clear CLI with sensible defaults and override flags for GPU-sensitive params
414
  β–‘ Expected runtimes documented per hardware tier
415
+ β–‘ Multi-GPU limitations documented
416
  β–‘ Error messages are clear (not just stack traces)
417
  β–‘ Results directory created automatically
418
  β–‘ Requirements.txt includes ALL dependencies with minimum versions
 
441
  9. **Shared objects across phases are landmines.** When a DataLoader, iterator, or cache is shared across training phases, any phase-specific parameter (batch size, number of workers, shuffle mode) can silently break later phases. Either don't share, or implement proper invalidation. Test by running all phases sequentially with different configs per phase.
442
 
443
  10. **CLI overrides must be exhaustive.** If your config has N copies of a parameter (one per training phase), your CLI override must touch all N. Grep the config file for the parameter name to find all instances.
444
+
445
+ 11. **Paper hyperparameters assume paper hardware.** If a paper reports batch_size=128 and trained on A100 80GB, that batch size may OOM on your T4 16GB. Always re-derive batch sizes from VRAM constraints, keeping the total samples seen (batch Γ— iterations) the same.
446
+
447
+ 12. **Estimate VRAM before running, not after OOM.** For Sinkhorn: O(N² × D). For model: count parameters × 4 bytes (fp32) × 3 (params + gradients + optimizer). For pool: stored on CPU but sampled minibatch goes to GPU. Write this down before your first GPU run (a quick helper is sketched after this list).
448
+
449
+ 13. **Checkpoint at phase boundaries, not just step boundaries.** Phase-level checkpoints enable `--resume-phase` which is the minimum viable recovery. Step-level checkpoints within long phases are a bonus. Both together make multi-session training actually work.
450
+
451
+ 14. **Free GPU memory between phases.** `torch.cuda.empty_cache()` after pool building or any phase that uses different GPU memory patterns than the next phase. Also `del` large objects (pools, computation graphs) that won't be needed again.
452
+
453
+ 15. **Document what your code does NOT support.** Single-GPU only? No mixed precision? No gradient accumulation? Say so. Users with multi-GPU setups will waste time wondering why only one GPU is active if you don't tell them.
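A quick helper for the rule of thumb in principle 12 (a sketch; the factor of 3 covers params + gradients + optimizer state, and Adam's two moment buffers push it closer to 4):

```python
def model_vram_gb(model, bytes_per_param=4, factor=3):
    n_params = sum(p.numel() for p in model.parameters())
    return n_params * bytes_per_param * factor / 1e9

# e.g. a 38M-parameter UNet: 38e6 * 4 * 3 / 1e9 ~= 0.46 GB before activations
```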