# TODO.md — Next Steps for NSGF++ Reproduction

## Current Status

| Experiment | Pool Building | Phase 1 (NSGF) | Phase 2 (NSF) | Phase 3 (Predictor) | Inference | Eval |
|-----------|:---:|:---:|:---:|:---:|:---:|:---:|
| **2D 8gaussians** | ✅ | ✅ | — | — | ✅ | ✅ W2=2.04 (small run) |
| **MNIST** | ✅ | 🔶 runs, loss converging (~0.03), interrupted at 9.5K/100K | untested on GPU | untested on GPU | untested | untested |
| **CIFAR-10** | 🔶 OOM fixed (batch 128→32), untested on GPU | untested | untested | untested | untested | untested |

✅ = verified working 🔶 = partially done ❌ = blocked

---

## Immediate — Run Full Experiments

### 1. MNIST full run on T4

The most important next step. All known code bugs are fixed; what remains is a clean Kaggle run.

```bash
cd /kaggle/working/ && rm -rf nsgf-plusplus
git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus && pip install -r requirements.txt

# Phase 1: pool (~7 min) + NSGF training (100K steps, ~2.5 hrs)
python main.py --experiment mnist

# If the session runs out, next session:
python main.py --experiment mnist --resume-phase 2

# If Phase 2 is done:
python main.py --experiment mnist --resume-phase 3
```

**Expected runtimes on T4:**

- Pool building (1500 batches): ~7 min
- Phase 1 NSGF (100K steps): ~2.5 hours
- Phase 2 NSF (100K steps): ~3-4 hours (each step does NSGF inference + NSF forward/backward)
- Phase 3 Predictor (40K steps): ~1.5 hours
- **Total: ~7-8 hours** — tight for one 9-hour Kaggle session

**Alternative:** use `--train-iters 50000` for Phases 1 and 2 to fit in one session, accepting lower quality.

**Paper target: FID ≈ 3.8 at NFE=60**

---

### 2. CIFAR-10 first test on T4

After MNIST works, test CIFAR-10 with the reduced Sinkhorn batch.

```bash
# Smoke test first (should run ~2 min)
python main.py --experiment cifar10 --pool-batches 10 --train-iters 50

# If the smoke test passes, real Phase 1:
python main.py --experiment cifar10 --train-iters 50000

# Subsequent sessions:
python main.py --experiment cifar10 --resume-phase 2 --train-iters 50000
python main.py --experiment cifar10 --resume-phase 3
```

**If it still OOMs:** try `--sinkhorn-batch 16 --pool-batches 20000`.

**Paper target: FID ≈ 5.55, IS ≈ 8.86 at NFE=59**

---

### 3. 2D full-scale run

A quick win to validate against the paper's numbers. Should take ~20 min on T4.

```bash
python main.py --experiment 2d --dataset 8gaussians --steps 10
```

**Paper target: W2 ≈ 0.285 for 8gaussians**

The current small-run W2=2.04 is expected — it used only 10 pool batches + 1000 iterations. The full run (200 batches, 20K iterations) should drop it dramatically.

Also run the other 2D datasets:

```bash
python main.py --experiment 2d --dataset moons --steps 10
python main.py --experiment 2d --dataset scurve --steps 10
python main.py --experiment 2d --dataset checkerboard --steps 10
```

---

## Medium-term — Code Improvements

### 4. Step-level resume within phases

The current `--resume-phase` skips completed phases but restarts the current phase from step 0. For 100K-step phases, a mid-phase interruption still loses progress. Needed:

- Load `nsgf_checkpoint.pt` / `nsf_checkpoint.pt` / `predictor_checkpoint.pt`
- Restore optimizer state + step counter
- Continue from the last checkpointed step

### 5. EMA (Exponential Moving Average) for image models

The paper uses EMA for MNIST and CIFAR-10 (standard in diffusion/flow models). The current code does not implement it, which likely affects FID significantly. (A minimal sketch is at the end of this file.)

### 6. Learning rate scheduler

The paper may use cosine decay or warmup; the code currently uses a constant learning rate. Check whether this matters for convergence.
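If we add a scheduler, a minimal sketch could look like the following, assuming a standard PyTorch training loop. `make_warmup_cosine`, `warmup_steps`, and `total_steps` are hypothetical names, not identifiers from this repo, and the warmup/decay shape is a common default rather than a paper-confirmed recipe.

```python
# Hypothetical sketch: linear warmup + cosine decay via torch's LambdaLR.
# None of these names come from the repo; adapt to main.py's training loop.
import math
import torch

def make_warmup_cosine(optimizer, warmup_steps, total_steps):
    def lr_lambda(step):
        if step < warmup_steps:
            # Linear warmup from 0 to the base learning rate.
            return step / max(1, warmup_steps)
        # Cosine decay from the base learning rate down to 0.
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# Usage inside the training loop (illustrative):
#   scheduler = make_warmup_cosine(optimizer, warmup_steps=1_000, total_steps=100_000)
#   loss.backward(); optimizer.step(); scheduler.step()
```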
### 7. FID evaluation correctness

Verify that `evaluation.py`'s FID computation matches the standard protocol:

- InceptionV3 features from the `pool3` layer (2048-dim)
- 10K generated vs 10K test samples
- Proper image preprocessing (resize to 299×299 for Inception)
- Compare against `pytorch-fid` or `clean-fid` as a sanity check

### 8. Inception Score evaluation

Implement it properly for CIFAR-10 if not already correct. The paper reports IS = 8.86.

---

## Longer-term — Towards Paper Numbers

### 9. Full paper hyperparameters

Once the code is stable, run with the exact paper configs (no iteration reduction):

- MNIST: 100K + 100K + 40K iterations
- CIFAR-10: 200K + 200K + 40K iterations
- This requires an A100 or multiple Kaggle sessions with checkpointing

### 10. Ablation: NSGF vs NSGF++

Run NSGF-only (Phase 1 only, no straight flow) and compare FID/W2 against NSGF++ to verify that the two-phase approach actually helps. The paper shows a clear improvement.

### 11. NFE sweep

The paper reports results at various NFE (number of function evaluations). Test:

- MNIST: NFE = 10, 20, 40, 60
- CIFAR-10: NFE = 10, 20, 40, 59
- Compare the FID-vs-NFE curve against the paper's Figure 3

### 12. pykeops for faster Sinkhorn

Install `pykeops` to enable geomloss's `online` backend. This avoids materializing the full N×N cost matrix and should be much faster with lower VRAM for the image experiments — possibly enough to use the paper's original batch_size=128 on a T4.

```bash
pip install pykeops
# Then in config or code:
# backend: "online" instead of "tensorized"
```

---

## Known Limitations

- **Single-GPU only** — no DDP; on T4×2 one GPU sits idle
- **No EMA** — standard in flow/diffusion models, likely hurts FID (see the sketch below)
- **No mixed precision** — fp32 only; fp16/bf16 could roughly halve VRAM
- **No gradient accumulation** — batch size is hard-limited by VRAM
- **Kaggle checkpoint persistence** — checkpoints are lost between sessions unless saved manually
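For the EMA gap flagged in item 5 and above, a minimal parameter-averaging sketch could look like this. The `EMA` class, `decay=0.999`, and the usage names are assumptions for illustration, not code from this repo or values confirmed by the paper; a full version would also track buffers (e.g. BatchNorm running stats).

```python
# Hypothetical sketch: exponential moving average of model parameters.
# Not from this repo; decay=0.999 is a common flow/diffusion default.
import copy
import torch

class EMA:
    def __init__(self, model, decay=0.999):
        self.decay = decay
        # Frozen copy holding the averaged weights.
        self.shadow = copy.deepcopy(model).eval()
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for ema_p, p in zip(self.shadow.parameters(), model.parameters()):
            ema_p.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

# Usage (illustrative): ema = EMA(nsgf_model); call ema.update(nsgf_model)
# after each optimizer.step(), and compute FID with ema.shadow.
```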