SKILL.md: Add VRAM estimation, checkpointing, multi-session, multi-GPU lessons from CIFAR OOM
Browse filesNew mistakes #8-10, new principles #11-15, expanded Phase 4 (VRAM), new Phase 7 (checkpointing),
updated pre-flight checklist, updated error table."
SKILL.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
---
|
| 2 |
name: paper-reproduction
|
| 3 |
-
description: "Skill for reproducing ML research papers from scratch when no official code exists. Use this whenever a user asks to implement, reproduce, or replicate a paper β especially papers involving novel loss functions, custom training loops, or non-standard architectures that aren't covered by existing HF trainers. Also use when the user mentions 'paper reproduction', 'implement this paper', 'no official code', or describes a method from a specific arxiv paper. Covers: reading papers systematically, extracting hyperparameters, building custom training pipelines, handling library-specific gotchas (geomloss, POT, custom UNets), and iterating on GPU results."
|
| 4 |
---
|
| 5 |
|
| 6 |
# Paper Reproduction Skill
|
|
@@ -32,6 +32,8 @@ Most reproduction failures trace back to incomplete paper reading. Don't skim
|
|
| 32 |
β‘ Evaluation protocol β metrics, number of samples, any special setup
|
| 33 |
β‘ Hyperparameters per experiment β papers often have different configs per dataset
|
| 34 |
β‘ Algorithm pseudocode β if provided, follow it exactly before improvising
|
|
|
|
|
|
|
| 35 |
```
|
| 36 |
|
| 37 |
### Mistake I made: Incomplete appendix reading
|
|
@@ -82,39 +84,7 @@ The `SamplesLoss` in geomloss requires inputs as `(N, D)` or `(B, N, D)` tensors
|
|
| 82 |
ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
|
| 83 |
```
|
| 84 |
|
| 85 |
-
**The fix**: Flatten images before passing to geomloss, reshape gradients back after
|
| 86 |
-
|
| 87 |
-
```python
|
| 88 |
-
def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
|
| 89 |
-
original_shape = X.shape
|
| 90 |
-
is_image = X.dim() == 4
|
| 91 |
-
|
| 92 |
-
if is_image:
|
| 93 |
-
B = X.shape[0]
|
| 94 |
-
X_flat = X.detach().clone().view(B, -1).requires_grad_(True)
|
| 95 |
-
Y_flat = Y.detach().view(B, -1)
|
| 96 |
-
else:
|
| 97 |
-
X_flat = X.detach().clone().requires_grad_(True)
|
| 98 |
-
Y_flat = Y.detach()
|
| 99 |
-
|
| 100 |
-
# Self-potential
|
| 101 |
-
F_self, _ = self.loss_fn(X_flat, X_flat.detach().clone())
|
| 102 |
-
grad_self = torch.autograd.grad(F_self.sum(), X_flat)[0]
|
| 103 |
-
|
| 104 |
-
# Cross-potential
|
| 105 |
-
X_flat2 = X.detach().clone().view(B, -1).requires_grad_(True) if is_image else X.detach().clone().requires_grad_(True)
|
| 106 |
-
F_cross, _ = self.loss_fn(X_flat2, Y_flat)
|
| 107 |
-
grad_cross = torch.autograd.grad(F_cross.sum(), X_flat2)[0]
|
| 108 |
-
|
| 109 |
-
velocity = grad_self.detach() - grad_cross.detach()
|
| 110 |
-
|
| 111 |
-
if is_image:
|
| 112 |
-
velocity = velocity.view(original_shape)
|
| 113 |
-
|
| 114 |
-
return velocity
|
| 115 |
-
```
|
| 116 |
-
|
| 117 |
-
This pattern β flatten before library call, reshape after β applies to many optimal transport libraries (POT, geomloss, ott-jax).
|
| 118 |
|
| 119 |
---
|
| 120 |
|
|
@@ -131,76 +101,120 @@ When building a UNet from scratch (rather than importing from guided-diffusion),
|
|
| 131 |
|
| 132 |
**Mistake pattern**: Using a helper like `_get_num_res_blocks()` that infers block count from module list lengths. This is fragile β if the number of levels or blocks per level varies, the inference breaks.
|
| 133 |
|
| 134 |
-
**Better approach**: Store `num_res_blocks` as an instance variable at init time and use it directly
|
| 135 |
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
-
|
| 143 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
```
|
| 145 |
|
| 146 |
-
|
| 147 |
|
| 148 |
-
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
**Safety check**: At init time, verify all channel counts are divisible by the group count:
|
| 151 |
```python
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
```
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
---
|
| 158 |
|
| 159 |
-
## Phase
|
| 160 |
|
| 161 |
-
###
|
| 162 |
|
| 163 |
-
|
| 164 |
-
- MNIST: 256 samples Γ 1500 batches Γ 5 steps Γ 784 dims Γ 4 bytes β 2.4 GB (manageable)
|
| 165 |
-
- CIFAR-10: 128 samples Γ 2500 batches Γ 5 steps Γ 3072 dims Γ 4 bytes β 19 GB (tight on T4)
|
| 166 |
|
| 167 |
-
|
|
|
|
|
|
|
| 168 |
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
def sample(self, batch_size, device="cpu"):
|
| 172 |
-
# Concatenate on CPU, index, THEN move to device
|
| 173 |
-
all_x = torch.cat(self.x_pool, dim=0) # stays on CPU
|
| 174 |
-
idx = torch.randint(0, all_x.shape[0], (batch_size,))
|
| 175 |
-
return all_x[idx].to(device), ... # only batch moves to GPU
|
| 176 |
```
|
| 177 |
|
| 178 |
-
|
| 179 |
|
| 180 |
-
```
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
|
| 186 |
-
# Free the lists
|
| 187 |
-
self.x_pool = self.v_pool = self.t_pool = None
|
| 188 |
-
|
| 189 |
-
def sample(self, batch_size, device="cpu"):
|
| 190 |
-
idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
|
| 191 |
-
return self._all_x[idx].to(device), self._all_v[idx].to(device), self._all_t[idx].to(device)
|
| 192 |
```
|
| 193 |
|
| 194 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
-
|
| 197 |
-
1. **NSGF**: Build trajectory pool β train velocity field
|
| 198 |
-
2. **NSF**: Use trained NSGF to generate P0 samples β train straight flow
|
| 199 |
-
3. **Phase predictor**: Train CNN to predict transition time
|
| 200 |
|
| 201 |
-
|
|
|
|
|
|
|
| 202 |
|
| 203 |
-
###
|
| 204 |
|
| 205 |
When a single `DatasetLoader` object is shared across multiple training phases, **lazy-initialized internal state** (like a cached DataLoader) will silently break subsequent phases.
|
| 206 |
|
|
@@ -225,40 +239,69 @@ def sample_target(self, n, device="cpu"):
|
|
| 225 |
|
| 226 |
---
|
| 227 |
|
| 228 |
-
## Phase
|
| 229 |
|
| 230 |
-
###
|
| 231 |
|
| 232 |
-
|
| 233 |
|
| 234 |
-
|
| 235 |
-
# Tiny run β should complete in <30 seconds
|
| 236 |
-
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 5 --train-iters 100
|
| 237 |
|
| 238 |
-
#
|
| 239 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
```
|
| 241 |
|
| 242 |
-
|
| 243 |
|
| 244 |
```bash
|
| 245 |
-
#
|
| 246 |
-
python main.py --experiment mnist
|
| 247 |
|
| 248 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
| 249 |
```
|
| 250 |
|
| 251 |
-
|
| 252 |
|
| 253 |
-
|
| 254 |
|
| 255 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
|
| 257 |
-
|
|
|
|
|
|
|
|
|
|
| 258 |
|
| 259 |
---
|
| 260 |
|
| 261 |
-
## Phase
|
| 262 |
|
| 263 |
### Common error patterns
|
| 264 |
|
|
@@ -267,20 +310,24 @@ Even after fixing Phase 1, Phase 2 can still crash due to shared state (see Data
|
|
| 267 |
| `ValueError: (N,D) or (B,N,D)` | Library expects flat tensors, got images | Flatten before library call |
|
| 268 |
| `RuntimeError: size of tensor a (X) must match size of tensor b (Y)` | Shared DataLoader with wrong batch size | Recreate DataLoader when batch size changes |
|
| 269 |
| `RuntimeError: shape mismatch` in UNet | Skip connection count wrong | Count pushes and pops manually |
|
| 270 |
-
| `CUDA OOM` during pool building |
|
| 271 |
-
| `CUDA OOM` during training |
|
|
|
|
| 272 |
| Training loss plateaus high | Pool too small or too few iterations | Increase pool batches, more iters |
|
| 273 |
| W2 distance too high | Undertrained model | Full paper config: 200 batches, 20k iters |
|
| 274 |
-
|
|
|
|
|
| 275 |
|
| 276 |
### When the user runs on their hardware
|
| 277 |
|
| 278 |
If you're developing code that the user will run on their own GPU (Kaggle, Colab, local):
|
| 279 |
|
| 280 |
1. **Provide exact commands** β don't make them figure out args
|
| 281 |
-
2. **Warn about expected runtimes** β "2D full run: ~20min on T4, MNIST: ~2-4 hours, CIFAR-10: ~
|
| 282 |
3. **Include checkpoint saving** β so partial runs aren't wasted
|
| 283 |
-
4. **
|
|
|
|
|
|
|
| 284 |
|
| 285 |
---
|
| 286 |
|
|
@@ -306,10 +353,10 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
|
|
| 306 |
- **Root cause**: False confidence from 2D success. Assumed same code path.
|
| 307 |
- **Prevention**: Test EVERY experiment type with minimal configs. Different experiment types often exercise different code paths.
|
| 308 |
|
| 309 |
-
4. **No checkpoint saving** (MODERATE)
|
| 310 |
- **What**: No intermediate checkpoints during long training runs.
|
| 311 |
-
- **Impact**: If training is interrupted (Kaggle timeout, OOM), all progress is lost.
|
| 312 |
-
- **Prevention**: Save checkpoints every N iterations. Implement `--resume` flag.
|
| 313 |
|
| 314 |
5. **UNet forward pass fragility** (LOW-MODERATE)
|
| 315 |
- **What**: `_get_num_res_blocks()` infers block count from module list length division.
|
|
@@ -318,15 +365,33 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
|
|
| 318 |
|
| 319 |
6. **DataLoader batch size mismatch across phases** (CRITICAL)
|
| 320 |
- **What**: Shared `DatasetLoader` caches a DataLoader with batch_size=256 from Phase 1. Phase 2 requests batch_size=128 but gets 256 back β tensor dimension mismatch crash.
|
| 321 |
-
- **Impact**: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully.
|
| 322 |
-
- **Root cause**: Lazy initialization pattern without invalidation.
|
| 323 |
-
- **Prevention**: When sharing stateful objects across consumers with different configs,
|
| 324 |
|
| 325 |
7. **CLI flag not overriding all training phases** (LOW)
|
| 326 |
- **What**: `--train-iters` flag overrode NSGF and NSF iterations but NOT the phase predictor iterations (40,000 default). Smoke tests would hang on Phase 3 even with `--train-iters 5`.
|
| 327 |
-
- **Impact**: Tests take much longer than expected.
|
| 328 |
- **Root cause**: Forgot that 3-phase training means 3 iteration counts to override.
|
| 329 |
-
- **Prevention**: When adding a CLI override, grep the config for ALL fields it should affect.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
---
|
| 332 |
|
|
@@ -338,11 +403,16 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
|
|
| 338 |
β‘ Third-party library APIs tested with exact tensor shapes per experiment
|
| 339 |
β‘ Shared state across phases verified (DataLoaders, iterators, caches)
|
| 340 |
β‘ CLI flags override ALL relevant config values (not just some)
|
|
|
|
|
|
|
|
|
|
| 341 |
β‘ Training loop profiled β no O(N) operations per step where O(1) suffices
|
| 342 |
β‘ Memory estimated per experiment (pool size Γ data dim Γ 4 bytes)
|
| 343 |
-
β‘ Checkpointing implemented
|
| 344 |
-
β‘
|
|
|
|
| 345 |
β‘ Expected runtimes documented per hardware tier
|
|
|
|
| 346 |
β‘ Error messages are clear (not just stack traces)
|
| 347 |
β‘ Results directory created automatically
|
| 348 |
β‘ Requirements.txt includes ALL dependencies with minimum versions
|
|
@@ -371,3 +441,13 @@ If you're developing code that the user will run on their own GPU (Kaggle, Colab
|
|
| 371 |
9. **Shared objects across phases are landmines.** When a DataLoader, iterator, or cache is shared across training phases, any phase-specific parameter (batch size, number of workers, shuffle mode) can silently break later phases. Either don't share, or implement proper invalidation. Test by running all phases sequentially with different configs per phase.
|
| 372 |
|
| 373 |
10. **CLI overrides must be exhaustive.** If your config has N copies of a parameter (one per training phase), your CLI override must touch all N. Grep the config file for the parameter name to find all instances.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
name: paper-reproduction
|
| 3 |
+
description: "Skill for reproducing ML research papers from scratch when no official code exists. Use this whenever a user asks to implement, reproduce, or replicate a paper β especially papers involving novel loss functions, custom training loops, or non-standard architectures that aren't covered by existing HF trainers. Also use when the user mentions 'paper reproduction', 'implement this paper', 'no official code', or describes a method from a specific arxiv paper. Covers: reading papers systematically, extracting hyperparameters, building custom training pipelines, handling library-specific gotchas (geomloss, POT, custom UNets), VRAM estimation, checkpointing for multi-session training, and iterating on GPU results."
|
| 4 |
---
|
| 5 |
|
| 6 |
# Paper Reproduction Skill
|
|
|
|
| 32 |
β‘ Evaluation protocol β metrics, number of samples, any special setup
|
| 33 |
β‘ Hyperparameters per experiment β papers often have different configs per dataset
|
| 34 |
β‘ Algorithm pseudocode β if provided, follow it exactly before improvising
|
| 35 |
+
β‘ GPU hardware used β what the authors trained on (often buried in appendix)
|
| 36 |
+
β‘ Training time β how long did the authors' runs take?
|
| 37 |
```
|
| 38 |
|
| 39 |
### Mistake I made: Incomplete appendix reading
|
|
|
|
| 84 |
ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
|
| 85 |
```
|
| 86 |
|
| 87 |
+
**The fix**: Flatten images before passing to geomloss, reshape gradients back after. This pattern β flatten before library call, reshape after β applies to many optimal transport libraries (POT, geomloss, ott-jax).
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
---
|
| 90 |
|
|
|
|
| 101 |
|
| 102 |
**Mistake pattern**: Using a helper like `_get_num_res_blocks()` that infers block count from module list lengths. This is fragile β if the number of levels or blocks per level varies, the inference breaks.
|
| 103 |
|
| 104 |
+
**Better approach**: Store `num_res_blocks` as an instance variable at init time and use it directly.
|
| 105 |
|
| 106 |
+
### GroupNorm channel requirements
|
| 107 |
+
|
| 108 |
+
`nn.GroupNorm(32, channels)` requires `channels` to be divisible by 32. For small models (e.g., MNIST with `model_channels=32`), this is fine at the first level but may break at deeper levels if `channel_mult` creates channels not divisible by 32.
|
| 109 |
+
|
| 110 |
+
---
|
| 111 |
+
|
| 112 |
+
## Phase 4: VRAM Estimation and Memory Management
|
| 113 |
+
|
| 114 |
+
### Estimate VRAM BEFORE running β not after OOM
|
| 115 |
+
|
| 116 |
+
Papers report batch sizes that worked on their hardware (often A100 80GB or 8ΓV100). If your user has a T4 (16GB) or even a T4Γ2 (16GB per GPU, but single-GPU code only uses one), you must recalculate whether the paper's configs will fit.
|
| 117 |
+
|
| 118 |
+
### The Sinkhorn VRAM trap
|
| 119 |
+
|
| 120 |
+
The `tensorized` backend in geomloss computes a full NΓN cost matrix. For N samples of dimension D:
|
| 121 |
+
- Memory β O(NΒ² Γ D) for the cost matrix + intermediate Sinkhorn iterations
|
| 122 |
+
- With `potentials=True` and `autograd.grad`, add another O(N Γ D) for gradient storage
|
| 123 |
+
|
| 124 |
+
**Concrete examples (fp32, single Sinkhorn call)**:
|
| 125 |
+
| N (batch) | D (flattened dim) | Approx VRAM per call |
|
| 126 |
+
|-----------|-------------------|---------------------|
|
| 127 |
+
| 256 | 2 (2D points) | ~1 MB |
|
| 128 |
+
| 256 | 784 (MNIST 28Γ28) | ~200 MB |
|
| 129 |
+
| 128 | 3072 (CIFAR 3Γ32Γ32) | ~600 MB |
|
| 130 |
|
| 131 |
+
But pool building calls Sinkhorn **twice per step** (self-potential + cross-potential) Γ **5 flow steps per batch** = 10 Sinkhorn calls per pool batch. With autograd overhead, 128Γ3072 easily eats 8+ GB β leaving no room for the 38M-param UNet on a 16GB T4.
|
| 132 |
+
|
| 133 |
+
**Mistake I made**: Used the paper's `sinkhorn.batch_size=128` for CIFAR-10. This OOMed immediately on T4. The paper's authors likely used A100s.
|
| 134 |
+
|
| 135 |
+
**The fix**: Reduce Sinkhorn batch size for smaller GPUs and increase pool batches to compensate:
|
| 136 |
+
```yaml
|
| 137 |
+
# Paper config (A100 80GB):
|
| 138 |
+
sinkhorn.batch_size: 128
|
| 139 |
+
pool.num_batches: 2500
|
| 140 |
+
# Total pool entries: 128 Γ 2500 Γ 5 = 1.6M
|
| 141 |
+
|
| 142 |
+
# T4 16GB config:
|
| 143 |
+
sinkhorn.batch_size: 32
|
| 144 |
+
pool.num_batches: 10000
|
| 145 |
+
# Total pool entries: 32 Γ 10000 Γ 5 = 1.6M (same!)
|
| 146 |
```
|
| 147 |
|
| 148 |
+
Add a CLI override (`--sinkhorn-batch`) so users can tune without editing config files.
|
| 149 |
|
| 150 |
+
### Always call `torch.cuda.empty_cache()` between phases
|
| 151 |
+
|
| 152 |
+
Pool building uses GPU for Sinkhorn computation. Training uses GPU for the neural network. These are different memory patterns. After pool building, the Sinkhorn computation graph is no longer needed β but PyTorch's CUDA allocator may still hold that memory. Explicitly free it:
|
| 153 |
|
|
|
|
| 154 |
```python
|
| 155 |
+
def build_trajectory_pool(self, ...):
|
| 156 |
+
# ... build pool ...
|
| 157 |
+
if self.device != "cpu":
|
| 158 |
+
torch.cuda.empty_cache() # Free Sinkhorn memory before training
|
| 159 |
+
self.pool.finalize()
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
### Multi-GPU β automatic parallelism
|
| 163 |
+
|
| 164 |
+
If the user has a T4Γ2 on Kaggle, your single-GPU code will only use ONE of the two GPUs. The second sits idle. Using both requires PyTorch DDP or model parallelism β which is a significant code change.
|
| 165 |
+
|
| 166 |
+
**Don't silently assume multi-GPU works.** Document this:
|
| 167 |
+
```
|
| 168 |
+
NOTE: This code uses a single GPU. If you have T4Γ2, only one GPU is used.
|
| 169 |
+
A single T4 (16GB) is sufficient β the second GPU is wasted without DDP.
|
| 170 |
```
|
| 171 |
|
| 172 |
+
### Trajectory pool memory on CPU vs GPU
|
| 173 |
+
|
| 174 |
+
The trajectory pool stores ALL flow trajectories for the entire training. For image experiments this is gigabytes:
|
| 175 |
+
- MNIST: 1.92M entries Γ 784 dims Γ 4 bytes = **6 GB** on CPU
|
| 176 |
+
- CIFAR: 1.6M entries Γ 3072 dims Γ 4 bytes = **19.6 GB** on CPU
|
| 177 |
+
|
| 178 |
+
The pool MUST live on CPU. Only the sampled minibatch (128-256 samples) goes to GPU per training step. This is already how the code works (trajectories stored as CPU tensors, `.to(device)` in `sample()`), but it's worth being explicit about why.
|
| 179 |
+
|
| 180 |
---
|
| 181 |
|
| 182 |
+
## Phase 5: Testing Strategy
|
| 183 |
|
| 184 |
+
### Always test on CPU first with tiny configs
|
| 185 |
|
| 186 |
+
Before any GPU run, verify the full pipeline works end-to-end:
|
|
|
|
|
|
|
| 187 |
|
| 188 |
+
```bash
|
| 189 |
+
# Tiny run β should complete in <30 seconds
|
| 190 |
+
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 5 --train-iters 100
|
| 191 |
|
| 192 |
+
# Slightly larger β should complete in <5 minutes
|
| 193 |
+
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 20 --train-iters 2000
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
```
|
| 195 |
|
| 196 |
+
### Test image experiments separately with minimal configs
|
| 197 |
|
| 198 |
+
```bash
|
| 199 |
+
# MNIST smoke test β 2 pool batches, 5 training iters per phase
|
| 200 |
+
python main.py --experiment mnist --pool-batches 2 --train-iters 5
|
| 201 |
+
|
| 202 |
+
# If this crashes, fix before scaling up
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
```
|
| 204 |
|
| 205 |
+
**Mistake I made**: I tested 2D experiments thoroughly on CPU (both tiny and medium runs worked) but shipped the image experiments without testing them at all. The geomloss tensor shape bug affected ONLY the image path, so 2D success gave false confidence. The first GPU test of MNIST crashed immediately.
|
| 206 |
+
|
| 207 |
+
**Rule**: Test EVERY experiment type, not just the simplest one. If you have `{2d, mnist, cifar10}` experiments, test all three with minimal configs before declaring the code ready.
|
| 208 |
+
|
| 209 |
+
### Test all training phases, not just the first one
|
| 210 |
|
| 211 |
+
Even after fixing Phase 1, Phase 2 can still crash due to shared state (see DataLoader trap in Phase 6). Run with `--train-iters 5 --pool-batches 2` to verify all 3 phases complete without errors. This takes <60 seconds on CPU for MNIST.
|
|
|
|
|
|
|
|
|
|
| 212 |
|
| 213 |
+
---
|
| 214 |
+
|
| 215 |
+
## Phase 6: Shared State Across Training Phases
|
| 216 |
|
| 217 |
+
### The DataLoader trap
|
| 218 |
|
| 219 |
When a single `DatasetLoader` object is shared across multiple training phases, **lazy-initialized internal state** (like a cached DataLoader) will silently break subsequent phases.
|
| 220 |
|
|
|
|
| 239 |
|
| 240 |
---
|
| 241 |
|
| 242 |
+
## Phase 7: Checkpointing and Multi-Session Training
|
| 243 |
|
| 244 |
+
### Why this matters
|
| 245 |
|
| 246 |
+
Paper reproduction often requires training runs that exceed a single GPU session. Kaggle gives 9 hours per T4 session. MNIST NSGF++ with full paper config (100K+100K+40K iters) needs ~7-8 hours on T4 β tight. CIFAR-10 (200K+200K+40K) is impossible in one session.
|
| 247 |
|
| 248 |
+
Without checkpointing, a Kaggle timeout = all progress lost.
|
|
|
|
|
|
|
| 249 |
|
| 250 |
+
### Phase-level checkpointing
|
| 251 |
+
|
| 252 |
+
For multi-phase training, save a checkpoint after EACH phase completes:
|
| 253 |
+
|
| 254 |
+
```python
|
| 255 |
+
# After Phase 1 completes:
|
| 256 |
+
torch.save({
|
| 257 |
+
"nsgf_model_state": nsgf_model.state_dict(),
|
| 258 |
+
"phase": 1,
|
| 259 |
+
}, "checkpoints/phase1_complete.pt")
|
| 260 |
+
|
| 261 |
+
# After Phase 2 completes:
|
| 262 |
+
torch.save({
|
| 263 |
+
"nsgf_model_state": nsgf_model.state_dict(),
|
| 264 |
+
"nsf_model_state": nsf_model.state_dict(),
|
| 265 |
+
"phase": 2,
|
| 266 |
+
}, "checkpoints/phase2_complete.pt")
|
| 267 |
```
|
| 268 |
|
| 269 |
+
Then implement `--resume-phase N` that loads the phase N-1 checkpoint and skips completed phases:
|
| 270 |
|
| 271 |
```bash
|
| 272 |
+
# Session 1: Run Phase 1 (gets interrupted or completes)
|
| 273 |
+
python main.py --experiment mnist
|
| 274 |
|
| 275 |
+
# Session 2: Skip Phase 1, start Phase 2
|
| 276 |
+
python main.py --experiment mnist --resume-phase 2
|
| 277 |
+
|
| 278 |
+
# Session 3: Skip Phases 1+2, run Phase 3 + inference
|
| 279 |
+
python main.py --experiment mnist --resume-phase 3
|
| 280 |
```
|
| 281 |
|
| 282 |
+
### Step-level checkpointing within phases
|
| 283 |
|
| 284 |
+
For long phases (100K+ steps), also save within the phase every N steps:
|
| 285 |
|
| 286 |
+
```python
|
| 287 |
+
if (step + 1) % checkpoint_every == 0:
|
| 288 |
+
torch.save({
|
| 289 |
+
"model_state": model.state_dict(),
|
| 290 |
+
"optimizer_state": optimizer.state_dict(),
|
| 291 |
+
"step": step + 1,
|
| 292 |
+
}, "checkpoints/nsgf_checkpoint.pt")
|
| 293 |
+
```
|
| 294 |
+
|
| 295 |
+
### Important: checkpoint persistence on Kaggle
|
| 296 |
|
| 297 |
+
Kaggle notebooks persist `/kaggle/working/` across cells within the same session, but NOT across sessions. To carry checkpoints between sessions:
|
| 298 |
+
1. Save checkpoints to `/kaggle/working/nsgf-plusplus/checkpoints/`
|
| 299 |
+
2. Before session ends, commit the notebook output or copy checkpoints to a dataset
|
| 300 |
+
3. In the new session, restore checkpoints before running `--resume-phase`
|
| 301 |
|
| 302 |
---
|
| 303 |
|
| 304 |
+
## Phase 8: Debugging GPU Runs
|
| 305 |
|
| 306 |
### Common error patterns
|
| 307 |
|
|
|
|
| 310 |
| `ValueError: (N,D) or (B,N,D)` | Library expects flat tensors, got images | Flatten before library call |
|
| 311 |
| `RuntimeError: size of tensor a (X) must match size of tensor b (Y)` | Shared DataLoader with wrong batch size | Recreate DataLoader when batch size changes |
|
| 312 |
| `RuntimeError: shape mismatch` in UNet | Skip connection count wrong | Count pushes and pops manually |
|
| 313 |
+
| `CUDA OOM` during pool building (Sinkhorn) | Sinkhorn batch too large for GPU | Reduce `--sinkhorn-batch` (e.g. 128β32) |
|
| 314 |
+
| `CUDA OOM` during training | Training batch too large or model too big | Reduce training batch, increase grad accum |
|
| 315 |
+
| `CUDA OOM` at phase transition | Memory not freed between phases | Add `torch.cuda.empty_cache()` + `del pool` |
|
| 316 |
| Training loss plateaus high | Pool too small or too few iterations | Increase pool batches, more iters |
|
| 317 |
| W2 distance too high | Undertrained model | Full paper config: 200 batches, 20k iters |
|
| 318 |
+
| Only 1 of 2 GPUs used | Code is single-GPU, no DDP | Expected β use single GPU or add DDP |
|
| 319 |
+
| `KeyboardInterrupt` mid-training | Training too long at scale | Check `checkpoints/` for latest save |
|
| 320 |
|
| 321 |
### When the user runs on their hardware
|
| 322 |
|
| 323 |
If you're developing code that the user will run on their own GPU (Kaggle, Colab, local):
|
| 324 |
|
| 325 |
1. **Provide exact commands** β don't make them figure out args
|
| 326 |
+
2. **Warn about expected runtimes** β "2D full run: ~20min on T4, MNIST: ~2-4 hours per phase, CIFAR-10: ~4+ hours per phase"
|
| 327 |
3. **Include checkpoint saving** β so partial runs aren't wasted
|
| 328 |
+
4. **Document GPU requirements** β "MNIST fits on T4 16GB, CIFAR-10 needs `--sinkhorn-batch 32`"
|
| 329 |
+
5. **Document multi-GPU limitations** β "Single-GPU only. T4Γ2 wastes the second GPU."
|
| 330 |
+
6. **Test the exact commands yourself** β if you can't run on GPU, at least verify the command parses correctly on CPU
|
| 331 |
|
| 332 |
---
|
| 333 |
|
|
|
|
| 353 |
- **Root cause**: False confidence from 2D success. Assumed same code path.
|
| 354 |
- **Prevention**: Test EVERY experiment type with minimal configs. Different experiment types often exercise different code paths.
|
| 355 |
|
| 356 |
+
4. **No checkpoint saving** (MODERATE β became CRITICAL at scale)
|
| 357 |
- **What**: No intermediate checkpoints during long training runs.
|
| 358 |
+
- **Impact**: If training is interrupted (Kaggle timeout, OOM, accidental Ctrl+C), all progress is lost. MNIST full run is ~7 hours β losing that is devastating.
|
| 359 |
+
- **Prevention**: Save checkpoints every N iterations. Save after each phase. Implement `--resume-phase` flag. Test resume actually works.
|
| 360 |
|
| 361 |
5. **UNet forward pass fragility** (LOW-MODERATE)
|
| 362 |
- **What**: `_get_num_res_blocks()` infers block count from module list length division.
|
|
|
|
| 365 |
|
| 366 |
6. **DataLoader batch size mismatch across phases** (CRITICAL)
|
| 367 |
- **What**: Shared `DatasetLoader` caches a DataLoader with batch_size=256 from Phase 1. Phase 2 requests batch_size=128 but gets 256 back β tensor dimension mismatch crash.
|
| 368 |
+
- **Impact**: Phase 2 (NSF) crashes immediately even after Phase 1 completes successfully.
|
| 369 |
+
- **Root cause**: Lazy initialization pattern without invalidation.
|
| 370 |
+
- **Prevention**: When sharing stateful objects across consumers with different configs, track all cached parameters and invalidate on change.
|
| 371 |
|
| 372 |
7. **CLI flag not overriding all training phases** (LOW)
|
| 373 |
- **What**: `--train-iters` flag overrode NSGF and NSF iterations but NOT the phase predictor iterations (40,000 default). Smoke tests would hang on Phase 3 even with `--train-iters 5`.
|
| 374 |
+
- **Impact**: Tests take much longer than expected.
|
| 375 |
- **Root cause**: Forgot that 3-phase training means 3 iteration counts to override.
|
| 376 |
+
- **Prevention**: When adding a CLI override, grep the config for ALL fields it should affect.
|
| 377 |
+
|
| 378 |
+
8. **CIFAR-10 Sinkhorn OOM on T4** (CRITICAL)
|
| 379 |
+
- **What**: Paper uses `sinkhorn.batch_size=128` for CIFAR. Sinkhorn on 128 Γ 3072-dim (flattened 3Γ32Γ32) with `tensorized` backend computes a 128Γ128 cost matrix with 3072-dim vectors, plus autograd for potentials. This OOMs on T4 16GB during pool building.
|
| 380 |
+
- **Impact**: CIFAR-10 experiment crashes before even starting training. User loses their Kaggle session.
|
| 381 |
+
- **Root cause**: Used paper's hyperparameters without estimating VRAM for target hardware. Paper authors likely used A100 80GB.
|
| 382 |
+
- **Prevention**: ALWAYS estimate VRAM before running. Sinkhorn with `tensorized` backend is O(NΒ² Γ D). For CIFAR: 128Β² Γ 3072 Γ 4 bytes Γ ~10 (overhead) β 2+ GB per call, Γ10 calls per pool batch = too much. Reduce N: 32Β² Γ 3072 is 4Γ cheaper. Add `--sinkhorn-batch` CLI flag so users can tune without editing config.
|
| 383 |
+
|
| 384 |
+
9. **No GPU memory freed between phases** (MODERATE)
|
| 385 |
+
- **What**: After pool building, the Sinkhorn computation graph's CUDA allocations remain cached even though they're no longer needed. Training then starts with less available VRAM.
|
| 386 |
+
- **Impact**: Training phase might OOM even though pool building finished.
|
| 387 |
+
- **Root cause**: PyTorch's CUDA allocator doesn't automatically return memory to the OS.
|
| 388 |
+
- **Prevention**: `torch.cuda.empty_cache()` after pool building completes. Also `del pool` if the pool data was already finalized to separate tensors.
|
| 389 |
+
|
| 390 |
+
10. **Multi-GPU assumption** (LOW)
|
| 391 |
+
- **What**: User has T4Γ2 on Kaggle. Code is single-GPU. Second GPU sits idle.
|
| 392 |
+
- **Impact**: User pays for 2 GPUs but only uses 1. They might think the code is broken.
|
| 393 |
+
- **Root cause**: Didn't document single-GPU limitation.
|
| 394 |
+
- **Prevention**: Document GPU requirements explicitly. If multi-GPU is needed, implement DDP β but that's a significant scope change, so discuss with user first.
|
| 395 |
|
| 396 |
---
|
| 397 |
|
|
|
|
| 403 |
β‘ Third-party library APIs tested with exact tensor shapes per experiment
|
| 404 |
β‘ Shared state across phases verified (DataLoaders, iterators, caches)
|
| 405 |
β‘ CLI flags override ALL relevant config values (not just some)
|
| 406 |
+
β‘ VRAM estimated for target hardware β will Sinkhorn/model/pool fit?
|
| 407 |
+
β‘ Sinkhorn batch size appropriate for target GPU (not just paper's GPU)
|
| 408 |
+
β‘ torch.cuda.empty_cache() called between memory-intensive phases
|
| 409 |
β‘ Training loop profiled β no O(N) operations per step where O(1) suffices
|
| 410 |
β‘ Memory estimated per experiment (pool size Γ data dim Γ 4 bytes)
|
| 411 |
+
β‘ Checkpointing implemented: every N steps + after each phase
|
| 412 |
+
β‘ --resume-phase tested and working (load checkpoint β skip phases β continue)
|
| 413 |
+
β‘ Clear CLI with sensible defaults and override flags for GPU-sensitive params
|
| 414 |
β‘ Expected runtimes documented per hardware tier
|
| 415 |
+
β‘ Multi-GPU limitations documented
|
| 416 |
β‘ Error messages are clear (not just stack traces)
|
| 417 |
β‘ Results directory created automatically
|
| 418 |
β‘ Requirements.txt includes ALL dependencies with minimum versions
|
|
|
|
| 441 |
9. **Shared objects across phases are landmines.** When a DataLoader, iterator, or cache is shared across training phases, any phase-specific parameter (batch size, number of workers, shuffle mode) can silently break later phases. Either don't share, or implement proper invalidation. Test by running all phases sequentially with different configs per phase.
|
| 442 |
|
| 443 |
10. **CLI overrides must be exhaustive.** If your config has N copies of a parameter (one per training phase), your CLI override must touch all N. Grep the config file for the parameter name to find all instances.
|
| 444 |
+
|
| 445 |
+
11. **Paper hyperparameters assume paper hardware.** If a paper reports batch_size=128 and trained on A100 80GB, that batch size may OOM on your T4 16GB. Always re-derive batch sizes from VRAM constraints, keeping the total samples seen (batch Γ iterations) the same.
|
| 446 |
+
|
| 447 |
+
12. **Estimate VRAM before running, not after OOM.** For Sinkhorn: O(NΒ² Γ D). For model: count parameters Γ 4 bytes (fp32) Γ 3 (params + gradients + optimizer). For pool: stored on CPU but sampled minibatch goes to GPU. Write this down before your first GPU run.
|
| 448 |
+
|
| 449 |
+
13. **Checkpoint at phase boundaries, not just step boundaries.** Phase-level checkpoints enable `--resume-phase` which is the minimum viable recovery. Step-level checkpoints within long phases are a bonus. Both together make multi-session training actually work.
|
| 450 |
+
|
| 451 |
+
14. **Free GPU memory between phases.** `torch.cuda.empty_cache()` after pool building or any phase that uses different GPU memory patterns than the next phase. Also `del` large objects (pools, computation graphs) that won't be needed again.
|
| 452 |
+
|
| 453 |
+
15. **Document what your code does NOT support.** Single-GPU only? No mixed precision? No gradient accumulation? Say so. Users with multi-GPU setups will waste time wondering why only one GPU is active if you don't tell them.
|