Add SKILL.md — paper reproduction skill with lessons from NSGF++ implementation
Browse files
SKILL.md
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: paper-reproduction
|
| 3 |
+
description: "Skill for reproducing ML research papers from scratch when no official code exists. Use this whenever a user asks to implement, reproduce, or replicate a paper — especially papers involving novel loss functions, custom training loops, or non-standard architectures that aren't covered by existing HF trainers. Also use when the user mentions 'paper reproduction', 'implement this paper', 'no official code', or describes a method from a specific arxiv paper. Covers: reading papers systematically, extracting hyperparameters, building custom training pipelines, handling library-specific gotchas (geomloss, POT, custom UNets), and iterating on GPU results."
|
| 4 |
+
---
|
| 5 |
+
|
| 6 |
+
# Paper Reproduction Skill
|
| 7 |
+
|
| 8 |
+
A skill for reproducing ML research papers from scratch, learned through the experience of reproducing NSGF++ (arXiv:2401.14069) — a Neural Sinkhorn Gradient Flow paper with no official implementation.
|
| 9 |
+
|
| 10 |
+
## When to use this skill
|
| 11 |
+
|
| 12 |
+
- User wants to reproduce/implement an ML paper
|
| 13 |
+
- No official code repository exists
|
| 14 |
+
- The paper uses custom training loops, novel losses, or non-standard architectures
|
| 15 |
+
- The method doesn't fit neatly into existing HF Trainer abstractions (SFT, DPO, GRPO)
|
| 16 |
+
|
| 17 |
+
---
|
| 18 |
+
|
| 19 |
+
## Phase 1: Read the Paper Properly
|
| 20 |
+
|
| 21 |
+
Most reproduction failures trace back to incomplete paper reading. Don't skim — read methodology sections (3, 4, 5) line by line, and read ALL appendices.
|
| 22 |
+
|
| 23 |
+
### What to extract (checklist)
|
| 24 |
+
|
| 25 |
+
```
|
| 26 |
+
□ Loss function — exact mathematical form, every symbol defined
|
| 27 |
+
□ Architecture — layer counts, hidden dims, activation functions, normalization
|
| 28 |
+
□ Optimizer — type, learning rate, betas, weight decay, scheduler
|
| 29 |
+
□ Batch size — for each phase/component separately
|
| 30 |
+
□ Training iterations — for each phase/component
|
| 31 |
+
□ Dataset preprocessing — normalization range, image size, augmentation
|
| 32 |
+
□ Evaluation protocol — metrics, number of samples, any special setup
|
| 33 |
+
□ Hyperparameters per experiment — papers often have different configs per dataset
|
| 34 |
+
□ Algorithm pseudocode — if provided, follow it exactly before improvising
|
| 35 |
+
```
|
| 36 |
+
|
| 37 |
+
### Mistake I made: Incomplete appendix reading
|
| 38 |
+
|
| 39 |
+
I extracted most hyperparameters correctly from the NSGF++ paper but missed a critical detail about how geomloss handles image tensors. The paper says "GeomLoss package" but doesn't spell out that images must be flattened to (N, D) format for the `SamplesLoss` API. This caused the MNIST and CIFAR-10 experiments to crash immediately on GPU.
|
| 40 |
+
|
| 41 |
+
**Lesson**: When a paper references a specific library, read that library's documentation and test its API with the exact tensor shapes you'll use BEFORE writing the full pipeline.
|
| 42 |
+
|
| 43 |
+
---
|
| 44 |
+
|
| 45 |
+
## Phase 2: Library API Verification
|
| 46 |
+
|
| 47 |
+
### CRITICAL: Test third-party library APIs with your actual tensor shapes
|
| 48 |
+
|
| 49 |
+
This is the single biggest mistake pattern in paper reproduction. You read the paper, understand the math, implement everything — then it crashes because a library function expects `(N, D)` but you passed `(N, C, H, W)`.
|
| 50 |
+
|
| 51 |
+
**The rule**: Before building ANY training loop that uses a third-party library (geomloss, POT, torchsde, torchdiffeq, etc.), write a 10-line test script:
|
| 52 |
+
|
| 53 |
+
```python
|
| 54 |
+
import torch
|
| 55 |
+
from geomloss import SamplesLoss
|
| 56 |
+
|
| 57 |
+
# Test with EXACT shapes you'll use in training
|
| 58 |
+
loss_fn = SamplesLoss(loss="sinkhorn", p=2, blur=0.5, potentials=True)
|
| 59 |
+
|
| 60 |
+
# 2D case — works fine
|
| 61 |
+
x_2d = torch.randn(256, 2, requires_grad=True)
|
| 62 |
+
y_2d = torch.randn(256, 2)
|
| 63 |
+
F, G = loss_fn(x_2d, y_2d) # ✅ OK
|
| 64 |
+
|
| 65 |
+
# Image case — THIS CRASHES
|
| 66 |
+
x_img = torch.randn(128, 1, 28, 28, requires_grad=True)
|
| 67 |
+
y_img = torch.randn(128, 1, 28, 28)
|
| 68 |
+
F, G = loss_fn(x_img, y_img) # ❌ ValueError: must be (N,D) or (B,N,D)
|
| 69 |
+
|
| 70 |
+
# Image case — FIXED by flattening
|
| 71 |
+
B = x_img.shape[0]
|
| 72 |
+
x_flat = x_img.view(B, -1).requires_grad_(True)
|
| 73 |
+
y_flat = y_img.view(B, -1)
|
| 74 |
+
F, G = loss_fn(x_flat, y_flat) # ✅ OK
|
| 75 |
+
```
|
| 76 |
+
|
| 77 |
+
### Mistake I made: geomloss tensor shape assumption
|
| 78 |
+
|
| 79 |
+
The `SamplesLoss` in geomloss requires inputs as `(N, D)` or `(B, N, D)` tensors. For 2D experiments with shape `(256, 2)` this works perfectly. For images with shape `(128, 1, 28, 28)` it crashes with:
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
ValueError: Input samples 'x' and 'y' should be encoded as (N,D) or (B,N,D) (batch) tensors.
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
**The fix**: Flatten images before passing to geomloss, reshape gradients back after:
|
| 86 |
+
|
| 87 |
+
```python
|
| 88 |
+
def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
|
| 89 |
+
original_shape = X.shape
|
| 90 |
+
is_image = X.dim() == 4
|
| 91 |
+
|
| 92 |
+
if is_image:
|
| 93 |
+
B = X.shape[0]
|
| 94 |
+
X_flat = X.detach().clone().view(B, -1).requires_grad_(True)
|
| 95 |
+
Y_flat = Y.detach().view(B, -1)
|
| 96 |
+
else:
|
| 97 |
+
X_flat = X.detach().clone().requires_grad_(True)
|
| 98 |
+
Y_flat = Y.detach()
|
| 99 |
+
|
| 100 |
+
# Self-potential
|
| 101 |
+
F_self, _ = self.loss_fn(X_flat, X_flat.detach().clone())
|
| 102 |
+
grad_self = torch.autograd.grad(F_self.sum(), X_flat)[0]
|
| 103 |
+
|
| 104 |
+
# Cross-potential
|
| 105 |
+
X_flat2 = X.detach().clone().view(B, -1).requires_grad_(True) if is_image else X.detach().clone().requires_grad_(True)
|
| 106 |
+
F_cross, _ = self.loss_fn(X_flat2, Y_flat)
|
| 107 |
+
grad_cross = torch.autograd.grad(F_cross.sum(), X_flat2)[0]
|
| 108 |
+
|
| 109 |
+
velocity = grad_self.detach() - grad_cross.detach()
|
| 110 |
+
|
| 111 |
+
if is_image:
|
| 112 |
+
velocity = velocity.view(original_shape)
|
| 113 |
+
|
| 114 |
+
return velocity
|
| 115 |
+
```
|
| 116 |
+
|
| 117 |
+
This pattern — flatten before library call, reshape after — applies to many optimal transport libraries (POT, geomloss, ott-jax).
|
| 118 |
+
|
| 119 |
+
---
|
| 120 |
+
|
| 121 |
+
## Phase 3: Architecture Gotchas
|
| 122 |
+
|
| 123 |
+
### UNet skip connections
|
| 124 |
+
|
| 125 |
+
When building a UNet from scratch (rather than importing from guided-diffusion), the skip connection bookkeeping is the #1 source of shape mismatch errors.
|
| 126 |
+
|
| 127 |
+
**The pattern that works**:
|
| 128 |
+
1. During the downward pass, push every intermediate activation onto a `skips` list
|
| 129 |
+
2. During the upward pass, pop from `skips` and concatenate
|
| 130 |
+
3. The number of pops must EXACTLY equal the number of pushes
|
| 131 |
+
|
| 132 |
+
**Mistake pattern**: Using a helper like `_get_num_res_blocks()` that infers block count from module list lengths. This is fragile — if the number of levels or blocks per level varies, the inference breaks.
|
| 133 |
+
|
| 134 |
+
**Better approach**: Store `num_res_blocks` as an instance variable at init time and use it directly:
|
| 135 |
+
|
| 136 |
+
```python
|
| 137 |
+
def __init__(self, ..., num_res_blocks=2, channel_mult=[1,2,2,2], ...):
|
| 138 |
+
self.num_res_blocks = num_res_blocks
|
| 139 |
+
self.num_levels = len(channel_mult)
|
| 140 |
+
# ... build layers ...
|
| 141 |
+
|
| 142 |
+
def forward(self, x, t):
|
| 143 |
+
# Use self.num_res_blocks directly, not a computed value
|
| 144 |
+
```
|
| 145 |
+
|
| 146 |
+
### GroupNorm channel requirements
|
| 147 |
+
|
| 148 |
+
`nn.GroupNorm(32, channels)` requires `channels` to be divisible by 32. For small models (e.g., MNIST with `model_channels=32`), this is fine at the first level but may break at deeper levels if `channel_mult` creates channels not divisible by 32.
|
| 149 |
+
|
| 150 |
+
**Safety check**: At init time, verify all channel counts are divisible by the group count:
|
| 151 |
+
```python
|
| 152 |
+
for level, mult in enumerate(channel_mult):
|
| 153 |
+
ch = model_channels * mult
|
| 154 |
+
assert ch % 32 == 0, f"Level {level}: channels={ch} not divisible by 32"
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
---
|
| 158 |
+
|
| 159 |
+
## Phase 4: Training Loop Patterns for Custom Pipelines
|
| 160 |
+
|
| 161 |
+
### Trajectory pool memory management
|
| 162 |
+
|
| 163 |
+
When building trajectory pools (storing (x, v, t) tuples from gradient flow), memory can explode:
|
| 164 |
+
- MNIST: 256 samples × 1500 batches × 5 steps × 784 dims × 4 bytes ≈ 2.4 GB (manageable)
|
| 165 |
+
- CIFAR-10: 128 samples × 2500 batches × 5 steps × 3072 dims × 4 bytes ≈ 19 GB (tight on T4)
|
| 166 |
+
|
| 167 |
+
**The fix**: Store pool tensors on CPU, transfer to GPU only during sampling:
|
| 168 |
+
|
| 169 |
+
```python
|
| 170 |
+
class TrajectoryPool:
|
| 171 |
+
def sample(self, batch_size, device="cpu"):
|
| 172 |
+
# Concatenate on CPU, index, THEN move to device
|
| 173 |
+
all_x = torch.cat(self.x_pool, dim=0) # stays on CPU
|
| 174 |
+
idx = torch.randint(0, all_x.shape[0], (batch_size,))
|
| 175 |
+
return all_x[idx].to(device), ... # only batch moves to GPU
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
**Mistake I made**: The pool sampling code calls `torch.cat` on the entire pool every training step, which is O(pool_size) per step. For 512K entries this is slow. Better: pre-concatenate once after pool building, then just index:
|
| 179 |
+
|
| 180 |
+
```python
|
| 181 |
+
def finalize(self):
|
| 182 |
+
"""Call once after pool is fully built."""
|
| 183 |
+
self._all_x = torch.cat(self.x_pool, dim=0)
|
| 184 |
+
self._all_v = torch.cat(self.v_pool, dim=0)
|
| 185 |
+
self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
|
| 186 |
+
# Free the lists
|
| 187 |
+
self.x_pool = self.v_pool = self.t_pool = None
|
| 188 |
+
|
| 189 |
+
def sample(self, batch_size, device="cpu"):
|
| 190 |
+
idx = torch.randint(0, self._all_x.shape[0], (batch_size,))
|
| 191 |
+
return self._all_x[idx].to(device), self._all_v[idx].to(device), self._all_t[idx].to(device)
|
| 192 |
+
```
|
| 193 |
+
|
| 194 |
+
### Multi-phase training (NSGF++)
|
| 195 |
+
|
| 196 |
+
NSGF++ has 3 sequential training phases:
|
| 197 |
+
1. **NSGF**: Build trajectory pool → train velocity field
|
| 198 |
+
2. **NSF**: Use trained NSGF to generate P0 samples → train straight flow
|
| 199 |
+
3. **Phase predictor**: Train CNN to predict transition time
|
| 200 |
+
|
| 201 |
+
**Key insight**: Each phase depends on the previous one being fully trained. Don't try to interleave them. The NSGF model must be in `eval()` mode when used as a sample generator in phases 2 and 3.
|
| 202 |
+
|
| 203 |
+
---
|
| 204 |
+
|
| 205 |
+
## Phase 5: Testing Strategy
|
| 206 |
+
|
| 207 |
+
### Always test on CPU first with tiny configs
|
| 208 |
+
|
| 209 |
+
Before any GPU run, verify the full pipeline works end-to-end:
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
# Tiny run — should complete in <30 seconds
|
| 213 |
+
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 5 --train-iters 100
|
| 214 |
+
|
| 215 |
+
# Slightly larger — should complete in <5 minutes
|
| 216 |
+
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 20 --train-iters 2000
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+
### Test image experiments separately with minimal configs
|
| 220 |
+
|
| 221 |
+
```bash
|
| 222 |
+
# MNIST smoke test — 2 pool batches, 50 training iters
|
| 223 |
+
python main.py --experiment mnist --pool-batches 2 --train-iters 50
|
| 224 |
+
|
| 225 |
+
# If this crashes, fix before scaling up
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
**Mistake I made**: I tested 2D experiments thoroughly on CPU (both tiny and medium runs worked) but shipped the image experiments without testing them at all. The geomloss tensor shape bug affected ONLY the image path, so 2D success gave false confidence. The first GPU test of MNIST crashed immediately.
|
| 229 |
+
|
| 230 |
+
**Rule**: Test EVERY experiment type, not just the simplest one. If you have `{2d, mnist, cifar10}` experiments, test all three with minimal configs before declaring the code ready.
|
| 231 |
+
|
| 232 |
+
---
|
| 233 |
+
|
| 234 |
+
## Phase 6: Debugging GPU Runs
|
| 235 |
+
|
| 236 |
+
### Common error patterns
|
| 237 |
+
|
| 238 |
+
| Error | Cause | Fix |
|
| 239 |
+
|-------|-------|-----|
|
| 240 |
+
| `ValueError: (N,D) or (B,N,D)` | Library expects flat tensors, got images | Flatten before library call |
|
| 241 |
+
| `RuntimeError: shape mismatch` in UNet | Skip connection count wrong | Count pushes and pops manually |
|
| 242 |
+
| `CUDA OOM` during pool building | Pool too large for GPU | Build pool on CPU, sample to GPU |
|
| 243 |
+
| `CUDA OOM` during training | Batch too large or model too big | Reduce batch → increase grad accum |
|
| 244 |
+
| Training loss plateaus high | Pool too small or too few iterations | Increase pool batches, more iters |
|
| 245 |
+
| W2 distance too high | Undertrained model | Full paper config: 200 batches, 20k iters |
|
| 246 |
+
| `KeyboardInterrupt` during training | Training takes too long at scale | Expected — full 2D takes ~20min on T4 |
|
| 247 |
+
|
| 248 |
+
### When the user runs on their hardware
|
| 249 |
+
|
| 250 |
+
If you're developing code that the user will run on their own GPU (Kaggle, Colab, local):
|
| 251 |
+
|
| 252 |
+
1. **Provide exact commands** — don't make them figure out args
|
| 253 |
+
2. **Warn about expected runtimes** — "2D full run: ~20min on T4, MNIST: ~2-4 hours, CIFAR-10: ~8-12 hours"
|
| 254 |
+
3. **Include checkpoint saving** — so partial runs aren't wasted
|
| 255 |
+
4. **Test the exact commands yourself** — if you can't run on GPU, at least verify the command parses correctly on CPU
|
| 256 |
+
|
| 257 |
+
---
|
| 258 |
+
|
| 259 |
+
## Mistake Catalog
|
| 260 |
+
|
| 261 |
+
### Mistakes made during NSGF++ reproduction
|
| 262 |
+
|
| 263 |
+
1. **geomloss tensor shape bug** (CRITICAL)
|
| 264 |
+
- **What**: `SamplesLoss` requires `(N,D)` tensors. Image experiments passed `(N,C,H,W)`.
|
| 265 |
+
- **Impact**: MNIST and CIFAR-10 experiments crash immediately. 2D works fine, hiding the bug.
|
| 266 |
+
- **Root cause**: Only tested 2D path. Didn't verify library API with image tensor shapes.
|
| 267 |
+
- **Prevention**: Write a standalone API test script for every third-party library, testing with ALL tensor shapes you'll use.
|
| 268 |
+
|
| 269 |
+
2. **TrajectoryPool sampling performance** (MODERATE)
|
| 270 |
+
- **What**: `torch.cat` called on entire pool every training step.
|
| 271 |
+
- **Impact**: Training slower than necessary. At 512K pool entries, the cat+index is the bottleneck (~0.5s per step vs ~0.05s for the actual forward/backward).
|
| 272 |
+
- **Root cause**: Didn't profile the training loop.
|
| 273 |
+
- **Prevention**: Pre-concatenate the pool after building it. Profile before shipping.
|
| 274 |
+
|
| 275 |
+
3. **Incomplete experiment testing** (CRITICAL)
|
| 276 |
+
- **What**: Tested 2D experiments only. Shipped MNIST/CIFAR untested.
|
| 277 |
+
- **Impact**: User's first GPU run crashes. Wasted their Kaggle session time.
|
| 278 |
+
- **Root cause**: False confidence from 2D success. Assumed same code path.
|
| 279 |
+
- **Prevention**: Test EVERY experiment type with minimal configs. Different experiment types often exercise different code paths.
|
| 280 |
+
|
| 281 |
+
4. **No checkpoint saving** (MODERATE)
|
| 282 |
+
- **What**: No intermediate checkpoints during long training runs.
|
| 283 |
+
- **Impact**: If training is interrupted (Kaggle timeout, OOM), all progress is lost.
|
| 284 |
+
- **Prevention**: Save checkpoints every N iterations. Implement `--resume` flag.
|
| 285 |
+
|
| 286 |
+
5. **UNet forward pass fragility** (LOW-MODERATE)
|
| 287 |
+
- **What**: `_get_num_res_blocks()` infers block count from module list length division.
|
| 288 |
+
- **Impact**: Could break silently with non-standard configs.
|
| 289 |
+
- **Prevention**: Store config values as instance variables, don't infer from module counts.
|
| 290 |
+
|
| 291 |
+
---
|
| 292 |
+
|
| 293 |
+
## Pre-flight Checklist (before declaring code ready)
|
| 294 |
+
|
| 295 |
+
```
|
| 296 |
+
□ All experiment types tested with minimal configs (not just the easiest one)
|
| 297 |
+
□ Third-party library APIs tested with exact tensor shapes per experiment
|
| 298 |
+
□ Training loop profiled — no O(N) operations per step where O(1) suffices
|
| 299 |
+
□ Memory estimated per experiment (pool size × data dim × 4 bytes)
|
| 300 |
+
□ Checkpointing implemented for runs >10 minutes
|
| 301 |
+
□ Clear CLI with sensible defaults and override flags
|
| 302 |
+
□ Expected runtimes documented per hardware tier
|
| 303 |
+
□ Error messages are clear (not just stack traces)
|
| 304 |
+
□ Results directory created automatically
|
| 305 |
+
□ Requirements.txt includes ALL dependencies with minimum versions
|
| 306 |
+
```
|
| 307 |
+
|
| 308 |
+
---
|
| 309 |
+
|
| 310 |
+
## General Principles for Paper Reproduction
|
| 311 |
+
|
| 312 |
+
1. **Read the appendix first.** The appendix contains the actual implementation details. The main paper is the story; the appendix is the recipe.
|
| 313 |
+
|
| 314 |
+
2. **Test the boundaries, not just the happy path.** If your code handles 2D, MNIST, and CIFAR-10, test all three. The bug is always in the path you didn't test.
|
| 315 |
+
|
| 316 |
+
3. **Library APIs are opaque until tested.** Don't assume a function accepts your tensor shape just because it "makes sense." Write a 10-line test script.
|
| 317 |
+
|
| 318 |
+
4. **Pre-concatenate, don't re-concatenate.** Any data structure that's built once and sampled many times should be finalized into a single tensor after building.
|
| 319 |
+
|
| 320 |
+
5. **The user's time is more expensive than your time.** A crash on their GPU after 5 minutes of setup is worse than you spending 30 extra minutes testing. Ship code that works on first run.
|
| 321 |
+
|
| 322 |
+
6. **Flatten for OT libraries.** Optimal transport libraries (geomloss, POT, ott-jax) almost universally expect `(N, D)` point clouds. Images must be flattened. This is the #1 gotcha in OT-based generative models.
|
| 323 |
+
|
| 324 |
+
7. **Store training state on CPU, compute on GPU.** Trajectory pools, replay buffers, and other large data structures should live on CPU. Only the current minibatch goes to GPU.
|
| 325 |
+
|
| 326 |
+
8. **Multi-phase training = multiple separate trainers.** Don't try to be clever with a single training loop that switches phases. Each phase is a distinct trainer with its own optimizer. The previous phase's model goes to `eval()`.
|