rogermt committed
Commit 88f3058 · verified · 1 Parent(s): 74841f0

Add SKILL.md — paper reproduction skill with lessons from NSGF++ implementation

Files changed (1): SKILL.md (+326, -0)

SKILL.md ADDED

---
name: paper-reproduction
description: "Skill for reproducing ML research papers from scratch when no official code exists. Use this whenever a user asks to implement, reproduce, or replicate a paper — especially papers involving novel loss functions, custom training loops, or non-standard architectures that aren't covered by existing HF trainers. Also use when the user mentions 'paper reproduction', 'implement this paper', 'no official code', or describes a method from a specific arXiv paper. Covers: reading papers systematically, extracting hyperparameters, building custom training pipelines, handling library-specific gotchas (geomloss, POT, custom UNets), and iterating on GPU results."
---

# Paper Reproduction Skill

A skill for reproducing ML research papers from scratch, learned through the experience of reproducing NSGF++ (arXiv:2401.14069) — a Neural Sinkhorn Gradient Flow paper with no official implementation.

## When to use this skill

- User wants to reproduce/implement an ML paper
- No official code repository exists
- The paper uses custom training loops, novel losses, or non-standard architectures
- The method doesn't fit neatly into existing HF Trainer abstractions (SFT, DPO, GRPO)

---

## Phase 1: Read the Paper Properly

Most reproduction failures trace back to incomplete paper reading. Don't skim — read the methodology sections (typically Sections 3-5) line by line, and read ALL appendices.

### What to extract (checklist)

```
□ Loss function — exact mathematical form, every symbol defined
□ Architecture — layer counts, hidden dims, activation functions, normalization
□ Optimizer — type, learning rate, betas, weight decay, scheduler
□ Batch size — for each phase/component separately
□ Training iterations — for each phase/component
□ Dataset preprocessing — normalization range, image size, augmentation
□ Evaluation protocol — metrics, number of samples, any special setup
□ Hyperparameters per experiment — papers often have different configs per dataset
□ Algorithm pseudocode — if provided, follow it exactly before improvising
```
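
A lightweight way to keep this extraction honest is to transcribe every value into a single config object, with a note on where in the paper each one came from. A minimal sketch (the dataclass, field names, and the example values below are illustrative, not the actual NSGF++ numbers):

```python
from dataclasses import dataclass, field

@dataclass
class PaperConfig:
    """One config per experiment; every field should be traceable to the paper."""
    experiment: str                               # e.g. "2d-8gaussians", "mnist"
    lr: float                                     # optimizer section / appendix table
    batch_size: int
    train_iters: int
    data_range: tuple = (-1.0, 1.0)               # preprocessing details are easy to miss
    notes: dict = field(default_factory=dict)     # field name -> where it appears in the paper

# Hypothetical example: fill in values from the paper, not from memory.
cfg = PaperConfig(
    experiment="2d-8gaussians",
    lr=1e-4,
    batch_size=256,
    train_iters=20_000,
    notes={"lr": "appendix hyperparameter table", "batch_size": "Sec. 5.1"},
)
```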

### Mistake I made: Incomplete appendix reading

I extracted most hyperparameters correctly from the NSGF++ paper but missed a critical detail about how geomloss handles image tensors. The paper says "GeomLoss package" but doesn't spell out that images must be flattened to (N, D) format for the `SamplesLoss` API. This caused the MNIST and CIFAR-10 experiments to crash immediately on GPU.

**Lesson**: When a paper references a specific library, read that library's documentation and test its API with the exact tensor shapes you'll use BEFORE writing the full pipeline.

---

## Phase 2: Library API Verification

### CRITICAL: Test third-party library APIs with your actual tensor shapes

This is the single biggest mistake pattern in paper reproduction. You read the paper, understand the math, implement everything — then it crashes because a library function expects `(N, D)` but you passed `(N, C, H, W)`.

**The rule**: Before building ANY training loop that uses a third-party library (geomloss, POT, torchsde, torchdiffeq, etc.), write a 10-line test script:

```python
import torch
from geomloss import SamplesLoss

# Test with EXACT shapes you'll use in training
loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)

# 2D case — works fine
x_2d = torch.randn(256, 2, requires_grad=True)
y_2d = torch.randn(256, 2)
F, G = loss_fn(x_2d, y_2d)  # ✅ OK

# Image case — THIS CRASHES
x_img = torch.randn(128, 1, 28, 28, requires_grad=True)
y_img = torch.randn(128, 1, 28, 28)
F, G = loss_fn(x_img, y_img)  # ❌ ValueError: must be (N,D) or (B,N,D)

# Image case — FIXED by flattening
B = x_img.shape[0]
x_flat = x_img.detach().view(B, -1).requires_grad_(True)  # detach() first: a flattened view is not a leaf
y_flat = y_img.view(B, -1)
F, G = loss_fn(x_flat, y_flat)  # ✅ OK
```

### Mistake I made: geomloss tensor shape assumption

The `SamplesLoss` in geomloss requires inputs as `(N, D)` or `(B, N, D)` tensors. For 2D experiments with shape `(256, 2)` this works perfectly. For images with shape `(128, 1, 28, 28)` it crashes with:

```
ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
```

**The fix**: Flatten images before passing to geomloss, reshape gradients back after:

```python
def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
    original_shape = X.shape
    is_image = X.dim() == 4

    if is_image:
        B = X.shape[0]
        X_flat = X.detach().clone().view(B, -1).requires_grad_(True)
        Y_flat = Y.detach().view(B, -1)
    else:
        X_flat = X.detach().clone().requires_grad_(True)
        Y_flat = Y.detach()

    # Self-potential
    F_self, _ = self.loss_fn(X_flat, X_flat.detach().clone())
    grad_self = torch.autograd.grad(F_self.sum(), X_flat)[0]

    # Cross-potential
    if is_image:
        X_flat2 = X.detach().clone().view(B, -1).requires_grad_(True)
    else:
        X_flat2 = X.detach().clone().requires_grad_(True)
    F_cross, _ = self.loss_fn(X_flat2, Y_flat)
    grad_cross = torch.autograd.grad(F_cross.sum(), X_flat2)[0]

    velocity = grad_self.detach() - grad_cross.detach()

    if is_image:
        velocity = velocity.view(original_shape)

    return velocity
```

This pattern — flatten before library call, reshape after — applies to many optimal transport libraries (POT, geomloss, ott-jax).
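
The same flatten/restore bookkeeping can be pulled into a small helper so that every OT call sees `(N, D)` point clouds. A generic sketch in plain PyTorch (the helper name and the assumption that `ot_loss_fn` returns a scalar are mine, not from the paper or any particular library):

```python
import torch

def ot_grad_wrt_samples(ot_loss_fn, x, y):
    """Flatten image batches to (N, D), take the gradient of a scalar OT loss
    with respect to the samples, and return that gradient in the original layout."""
    shape = x.shape
    if x.dim() > 2:                          # (N, C, H, W) -> (N, C*H*W)
        x = x.reshape(shape[0], -1)
        y = y.reshape(y.shape[0], -1)
    x = x.detach().requires_grad_(True)      # fresh leaf so autograd.grad works
    y = y.detach()
    loss = ot_loss_fn(x, y)                  # assumed to return a scalar
    (grad,) = torch.autograd.grad(loss, x)
    return grad.reshape(shape)
```

With geomloss, `ot_loss_fn` could be something like `lambda a, b: SamplesLoss('sinkhorn', blur=0.5)(a, b)`, which returns a scalar when `potentials=False` (the default).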

---

## Phase 3: Architecture Gotchas

### UNet skip connections

When building a UNet from scratch (rather than importing from guided-diffusion), the skip connection bookkeeping is the #1 source of shape mismatch errors.

**The pattern that works**:
1. During the downward pass, push every intermediate activation onto a `skips` list
2. During the upward pass, pop from `skips` and concatenate
3. The number of pops must EXACTLY equal the number of pushes

**Mistake pattern**: Using a helper like `_get_num_res_blocks()` that infers block count from module list lengths. This is fragile — if the number of levels or blocks per level varies, the inference breaks.

**Better approach**: Store `num_res_blocks` as an instance variable at init time and use it directly:

```python
def __init__(self, num_res_blocks=2, channel_mult=(1, 2, 2, 2), **kwargs):  # other args elided
    self.num_res_blocks = num_res_blocks
    self.num_levels = len(channel_mult)
    # ... build layers ...

def forward(self, x, t):
    # Use self.num_res_blocks directly, not a value inferred from module list lengths
    ...
```
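
To make the push/pop symmetry explicit, here is a minimal forward-pass sketch. The module names (`input_proj`, `down_blocks`, `middle_block`, `up_blocks`, `out_proj`) are hypothetical placeholders for whatever `__init__` builds; the point is that every push has exactly one matching pop, and an assert catches any mismatch early:

```python
def forward(self, x, t):
    skips = []
    h = self.input_proj(x)
    for block in self.down_blocks:               # hypothetical ModuleList built in __init__
        h = block(h, t)
        skips.append(h)                          # one push per down block

    h = self.middle_block(h, t)

    for block in self.up_blocks:                 # same length as self.down_blocks
        h = torch.cat([h, skips.pop()], dim=1)   # one pop per up block
        h = block(h, t)

    assert not skips, f"{len(skips)} skip activations were never consumed"
    return self.out_proj(h)
```

A real UNet also has to down/upsample between levels so that the popped activation matches `h` spatially; the sketch only shows the counting discipline.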

### GroupNorm channel requirements

`nn.GroupNorm(32, channels)` requires `channels` to be divisible by 32. For small models (e.g., MNIST with `model_channels=32`), this is fine at the first level but may break at deeper levels if `channel_mult` creates channels not divisible by 32.

**Safety check**: At init time, verify all channel counts are divisible by the group count:
```python
for level, mult in enumerate(channel_mult):
    ch = model_channels * mult
    assert ch % 32 == 0, f"Level {level}: channels={ch} not divisible by 32"
```
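
If you would rather adapt than assert, a common workaround (my addition, not something the paper specifies) is to shrink the group count to the largest divisor of the channel count:

```python
import torch.nn as nn

def make_group_norm(channels, max_groups=32):
    # Pick the largest group count <= max_groups that evenly divides `channels`.
    groups = max_groups
    while groups > 1 and channels % groups != 0:
        groups -= 1
    return nn.GroupNorm(groups, channels)
```

Keep in mind this silently changes the normalization granularity relative to the reference architecture, so prefer the assert when you are trying to match a paper exactly.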

---

## Phase 4: Training Loop Patterns for Custom Pipelines

### Trajectory pool memory management

When building trajectory pools (storing (x, v, t) tuples from gradient flow), memory can explode, so estimate the size up front (a small helper sketch follows the list):
- MNIST: 256 samples × 1500 batches × 5 steps × 784 dims × 4 bytes ≈ 6 GB per stored tensor (manageable)
- CIFAR-10: 128 samples × 2500 batches × 5 steps × 3072 dims × 4 bytes ≈ 19.7 GB per stored tensor (tight on T4)
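
A quick back-of-the-envelope helper for that estimate (my own sketch; multiply by the number of tensors you actually store, e.g. ×2 for x and v):

```python
def pool_size_gb(samples, batches, steps, dims, bytes_per_elem=4):
    """Approximate float32 memory for one pooled tensor, in GB."""
    return samples * batches * steps * dims * bytes_per_elem / 1e9

print(pool_size_gb(256, 1500, 5, 784))    # ~6.0  (MNIST)
print(pool_size_gb(128, 2500, 5, 3072))   # ~19.7 (CIFAR-10)
```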

**The fix**: Store pool tensors on CPU, transfer to GPU only during sampling:

```python
class TrajectoryPool:
    def sample(self, batch_size, device="cpu"):
        # Concatenate on CPU, index, THEN move to device
        all_x = torch.cat(self.x_pool, dim=0)  # stays on CPU
        idx = torch.randint(0, all_x.shape[0], (batch_size,))
        return all_x[idx].to(device), ...  # only batch moves to GPU
```

**Mistake I made**: The pool sampling code calls `torch.cat` on the entire pool every training step, which is O(pool_size) per step. For 512K entries this is slow. Better: pre-concatenate once after pool building, then just index:

```python
def finalize(self):
    """Call once after pool is fully built."""
    self._all_x = torch.cat(self.x_pool, dim=0)
    self._all_v = torch.cat(self.v_pool, dim=0)
    self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
    # Free the lists
    self.x_pool = self.v_pool = self.t_pool = None

def sample(self, batch_size, device="cpu"):
    idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
    return self._all_x[idx].to(device), self._all_v[idx].to(device), self._all_t[idx].to(device)
```

### Multi-phase training (NSGF++)

NSGF++ has 3 sequential training phases:
1. **NSGF**: Build trajectory pool → train velocity field
2. **NSF**: Use trained NSGF to generate P0 samples → train straight flow
3. **Phase predictor**: Train CNN to predict transition time

**Key insight**: Each phase depends on the previous one being fully trained. Don't try to interleave them. The NSGF model must be in `eval()` mode when used as a sample generator in phases 2 and 3.
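
The "previous phase in `eval()`" rule translates into a small amount of boilerplate at each hand-off. A sketch of the transition from phase 1 to phase 2 (the `sample_with_nsgf` helper is hypothetical; the freezing calls are standard PyTorch):

```python
# Freeze the phase-1 model before phases 2 and 3 use it as a generator.
nsgf_model.eval()                          # fixes dropout / norm-layer behaviour
for p in nsgf_model.parameters():
    p.requires_grad_(False)                # no gradients flow back into phase 1

with torch.no_grad():                      # generation only, no graph needed
    p0_samples = sample_with_nsgf(nsgf_model, n_samples=512)   # hypothetical sampler
```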

---

## Phase 5: Testing Strategy

### Always test on CPU first with tiny configs

Before any GPU run, verify the full pipeline works end-to-end:

```bash
# Tiny run — should complete in <30 seconds
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 5 --train-iters 100

# Slightly larger — should complete in <5 minutes
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 20 --train-iters 2000
```

### Test image experiments separately with minimal configs

```bash
# MNIST smoke test — 2 pool batches, 50 training iters
python main.py --experiment mnist --pool-batches 2 --train-iters 50

# If this crashes, fix before scaling up
```

**Mistake I made**: I tested 2D experiments thoroughly on CPU (both tiny and medium runs worked) but shipped the image experiments without testing them at all. The geomloss tensor shape bug affected ONLY the image path, so 2D success gave false confidence. The first GPU test of MNIST crashed immediately.

**Rule**: Test EVERY experiment type, not just the simplest one. If you have `{2d, mnist, cifar10}` experiments, test all three with minimal configs before declaring the code ready.

---

## Phase 6: Debugging GPU Runs

### Common error patterns

| Error | Cause | Fix |
|-------|-------|-----|
| `ValueError: (N,D) or (B,N,D)` | Library expects flat tensors, got images | Flatten before library call |
| `RuntimeError: shape mismatch` in UNet | Skip connection count wrong | Count pushes and pops manually |
| `CUDA OOM` during pool building | Pool too large for GPU | Build pool on CPU, sample to GPU |
| `CUDA OOM` during training | Batch too large or model too big | Reduce batch → increase grad accum |
| Training loss plateaus high | Pool too small or too few iterations | Increase pool batches, more iters |
| W2 distance too high | Undertrained model | Full paper config: 200 batches, 20k iters |
| `KeyboardInterrupt` during training | Training takes too long at scale | Expected — full 2D takes ~20min on T4 |
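
For the OOM-during-training row, "reduce batch → increase grad accum" means keeping the effective batch size while shrinking what is resident on the GPU. A generic accumulation pattern (names like `accum_steps` and `loader` are placeholders, not the repo's variables):

```python
accum_steps = 4                              # effective batch = accum_steps * micro-batch size
optimizer.zero_grad(set_to_none=True)
for i, (x, v, t) in enumerate(loader):
    pred = model(x, t)
    loss = loss_fn(pred, v) / accum_steps    # scale so the summed gradient matches the large batch
    loss.backward()
    if (i + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
```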

### When the user runs on their hardware

If you're developing code that the user will run on their own GPU (Kaggle, Colab, local):

1. **Provide exact commands** — don't make them figure out args
2. **Warn about expected runtimes** — "2D full run: ~20min on T4, MNIST: ~2-4 hours, CIFAR-10: ~8-12 hours"
3. **Include checkpoint saving** — so partial runs aren't wasted (see the sketch after this list)
4. **Test the exact commands yourself** — if you can't run on GPU, at least verify the command parses correctly on CPU
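
A minimal checkpoint/resume sketch for item 3 (the path, save interval, and `--resume` wiring are illustrative):

```python
import torch

def save_checkpoint(path, model, optimizer, step):
    torch.save({"model": model.state_dict(), "opt": optimizer.state_dict(), "step": step}, path)

def load_checkpoint(path, model, optimizer):
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["opt"])
    return state["step"]

# In the training loop: save every N iterations, and start from the saved step when --resume is passed.
# if step % 1000 == 0:
#     save_checkpoint("results/ckpt_latest.pt", model, optimizer, step)
```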

---

## Mistake Catalog

### Mistakes made during NSGF++ reproduction

1. **geomloss tensor shape bug** (CRITICAL)
   - **What**: `SamplesLoss` requires `(N,D)` tensors. Image experiments passed `(N,C,H,W)`.
   - **Impact**: MNIST and CIFAR-10 experiments crash immediately. 2D works fine, hiding the bug.
   - **Root cause**: Only tested 2D path. Didn't verify library API with image tensor shapes.
   - **Prevention**: Write a standalone API test script for every third-party library, testing with ALL tensor shapes you'll use.

2. **TrajectoryPool sampling performance** (MODERATE)
   - **What**: `torch.cat` called on entire pool every training step.
   - **Impact**: Training slower than necessary. At 512K pool entries, the cat+index is the bottleneck (~0.5s per step vs ~0.05s for the actual forward/backward).
   - **Root cause**: Didn't profile the training loop.
   - **Prevention**: Pre-concatenate the pool after building it. Profile before shipping.

3. **Incomplete experiment testing** (CRITICAL)
   - **What**: Tested 2D experiments only. Shipped MNIST/CIFAR untested.
   - **Impact**: User's first GPU run crashes. Wasted their Kaggle session time.
   - **Root cause**: False confidence from 2D success. Assumed same code path.
   - **Prevention**: Test EVERY experiment type with minimal configs. Different experiment types often exercise different code paths.

4. **No checkpoint saving** (MODERATE)
   - **What**: No intermediate checkpoints during long training runs.
   - **Impact**: If training is interrupted (Kaggle timeout, OOM), all progress is lost.
   - **Prevention**: Save checkpoints every N iterations. Implement `--resume` flag.

5. **UNet forward pass fragility** (LOW-MODERATE)
   - **What**: `_get_num_res_blocks()` infers block count from module list length division.
   - **Impact**: Could break silently with non-standard configs.
   - **Prevention**: Store config values as instance variables, don't infer from module counts.

---

## Pre-flight Checklist (before declaring code ready)

```
□ All experiment types tested with minimal configs (not just the easiest one)
□ Third-party library APIs tested with exact tensor shapes per experiment
□ Training loop profiled — no O(N) operations per step where O(1) suffices
□ Memory estimated per experiment (pool size × data dim × 4 bytes)
□ Checkpointing implemented for runs >10 minutes
□ Clear CLI with sensible defaults and override flags
□ Expected runtimes documented per hardware tier
□ Error messages are clear (not just stack traces)
□ Results directory created automatically
□ requirements.txt includes ALL dependencies with minimum versions
```

---

## General Principles for Paper Reproduction

1. **Read the appendix first.** The appendix contains the actual implementation details. The main paper is the story; the appendix is the recipe.

2. **Test the boundaries, not just the happy path.** If your code handles 2D, MNIST, and CIFAR-10, test all three. The bug is always in the path you didn't test.

3. **Library APIs are opaque until tested.** Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script.

4. **Pre-concatenate, don't re-concatenate.** Any data structure that's built once and sampled many times should be finalized into a single tensor after building.

5. **The user's time is more expensive than your time.** A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. Ship code that works on first run.

6. **Flatten for OT libraries.** Optimal transport libraries (geomloss, POT, ott-jax) almost universally expect `(N, D)` point clouds. Images must be flattened. This is the #1 gotcha in OT-based generative models.

7. **Store training state on CPU, compute on GPU.** Trajectory pools, replay buffers, and other large data structures should live on CPU. Only the current minibatch goes to GPU.

8. **Multi-phase training = multiple separate trainers.** Don't try to be clever with a single training loop that switches phases. Each phase is a distinct trainer with its own optimizer. The previous phase's model goes to `eval()`.