# TODO.md - Next Steps for NSGF++ Reproduction

## Current Status

| Experiment | Pool Building | Phase 1 (NSGF) | Phase 2 (NSF) | Phase 3 (Predictor) | Inference | Eval |
|-----------|:---:|:---:|:---:|:---:|:---:|:---:|
| **2D 8gaussians** | ✅ | ✅ | - | - | ✅ | ✅ W2=2.04 (small run) |
| **MNIST** | ✅ | 🔶 runs, loss converging (~0.03), interrupted at 9.5K/100K steps | untested on GPU | untested on GPU | untested | untested |
| **CIFAR-10** | 🔶 OOM fixed (batch 128→32), untested on GPU | untested | untested | untested | untested | untested |

✅ = verified working · 🔶 = partially done · ❌ = blocked

---

## Immediate - Run Full Experiments

### 1. MNIST full run on T4

The most important next step. All known code bugs are fixed; what's needed now is a clean Kaggle run.

```bash
cd /kaggle/working/ && rm -rf nsgf-plusplus
git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus && pip install -r requirements.txt

# Phase 1: pool (~7 min) + NSGF training (100K steps, ~2.5 hrs)
python main.py --experiment mnist

# If session runs out, next session:
python main.py --experiment mnist --resume-phase 2

# If Phase 2 done:
python main.py --experiment mnist --resume-phase 3
```

**Expected runtimes on T4:**
- Pool building (1500 batches): ~7 min
- Phase 1 NSGF (100K steps): ~2.5 hours
- Phase 2 NSF (100K steps): ~3-4 hours (each step does NSGF inference + NSF forward/backward)
- Phase 3 Predictor (40K steps): ~1.5 hours
- **Total: ~7-8 hours**, which is tight for one 9-hour Kaggle session

**Alternative: use `--train-iters 50000` for Phases 1+2 to fit in one session, accepting lower quality.**

**Paper target: FID ≈ 3.8 at NFE=60**

---

### 2. CIFAR-10 first test on T4

After MNIST works, test CIFAR with reduced Sinkhorn batch.

```bash
# Smoke test first (should run ~2 min)
python main.py --experiment cifar10 --pool-batches 10 --train-iters 50

# If smoke test passes, real Phase 1:
python main.py --experiment cifar10 --train-iters 50000

# Subsequent sessions:
python main.py --experiment cifar10 --resume-phase 2 --train-iters 50000
python main.py --experiment cifar10 --resume-phase 3
```

**If it still OOMs**: try `--sinkhorn-batch 16 --pool-batches 20000`

**Paper target: FID ≈ 5.55, IS ≈ 8.86 at NFE=59**

---

### 3. 2D full-scale run

Quick win to validate against paper numbers. Should take ~20 min on T4.

```bash
python main.py --experiment 2d --dataset 8gaussians --steps 10
```

**Paper target: W2 ≈ 0.285 for 8gaussians**

The current small-run W2=2.04 is expected: it used only 10 pool batches + 1000 iters. A full run (200 batches, 20K iters) should drop it dramatically.

Also run other 2D datasets:
```bash
python main.py --experiment 2d --dataset moons --steps 10
python main.py --experiment 2d --dataset scurve --steps 10
python main.py --experiment 2d --dataset checkerboard --steps 10
```

---

## Medium-term - Code Improvements

### 4. Step-level resume within phases

Current `--resume-phase` skips completed phases but restarts the current phase from step 0, so for 100K-step phases a mid-phase interruption still loses progress. Needed (see the sketch after this list):
- Load `nsgf_checkpoint.pt` / `nsf_checkpoint.pt` / `predictor_checkpoint.pt`
- Resume optimizer state + step counter
- Continue from last checkpoint step
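
A minimal sketch of what this could look like, assuming hypothetical helper names and a 1000-step save interval (the checkpoint filenames come from the list above):

```python
import os
import torch

def save_ckpt(path, step, model, opt):
    # Persist weights, optimizer state, and the step counter together.
    torch.save({"step": step,
                "model": model.state_dict(),
                "opt": opt.state_dict()}, path)

def load_ckpt(path, model, opt):
    """Return the step to resume from (0 if no checkpoint exists)."""
    if not os.path.exists(path):
        return 0
    state = torch.load(path, map_location="cpu")
    model.load_state_dict(state["model"])
    opt.load_state_dict(state["opt"])
    return state["step"] + 1

# In each phase's training loop:
# start = load_ckpt("nsgf_checkpoint.pt", model, opt)
# for step in range(start, total_iters):
#     ...train...
#     if step % 1000 == 0:
#         save_ckpt("nsgf_checkpoint.pt", step, model, opt)
```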

### 5. EMA (Exponential Moving Average) for image models

The paper uses EMA for MNIST and CIFAR-10 (standard practice in diffusion/flow models). The current code doesn't implement it, which likely hurts FID significantly.
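
A minimal EMA sketch, assuming a decay of 0.9999 (a common choice; the paper's value would need checking). Buffers such as BN running stats would also need copying if the models use them:

```python
import copy
import torch

class EMA:
    def __init__(self, model, decay=0.9999):
        self.decay = decay
        self.shadow = copy.deepcopy(model).eval()  # frozen EMA copy
        for p in self.shadow.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # shadow <- decay * shadow + (1 - decay) * current weights
        for s, p in zip(self.shadow.parameters(), model.parameters()):
            s.mul_(self.decay).add_(p, alpha=1.0 - self.decay)

# usage: call ema.update(model) after every opt.step();
# sample and compute FID with ema.shadow instead of model.
```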

### 6. Learning rate scheduler

The paper may use cosine decay or warmup; the current code uses a constant lr. Check whether this matters for convergence.
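
If we add one, linear warmup followed by cosine decay is the usual default. A sketch using torch's `LambdaLR` (warmup length and total steps are placeholders, not paper values):

```python
import math
from torch.optim.lr_scheduler import LambdaLR

def warmup_cosine(opt, warmup_steps=5_000, total_steps=100_000):
    def factor(step):
        if step < warmup_steps:                       # linear warmup
            return step / max(1, warmup_steps)
        t = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * t))    # cosine decay to 0
    return LambdaLR(opt, lr_lambda=factor)

# usage: sched = warmup_cosine(opt); call sched.step() once per training step
```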

### 7. FID evaluation correctness

Verify that `evaluation.py`'s FID computation matches the standard protocol:
- InceptionV3 features from `pool3` layer (2048-dim)
- 10K generated vs 10K test samples
- Proper image preprocessing (resize to 299Γ—299 for Inception)
- Compare against `pytorch-fid` or `clean-fid` as a sanity check (see the sketch below)
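
For the sanity check, the simplest route is to dump both sample sets to disk and run `clean-fid` on the two folders (directory paths below are placeholders):

```python
# pip install clean-fid
from cleanfid import fid

score = fid.compute_fid("samples/generated", "samples/real")
print(f"clean-fid: {score:.2f}")  # should roughly match evaluation.py's FID
```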

### 8. Inception Score evaluation

Verify the CIFAR-10 implementation (or implement it if missing). The paper reports IS = 8.86.
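
For reference, the standard IS is exp(E_x[KL(p(y|x) || p(y))]) over InceptionV3 class probabilities, usually reported as mean ± std over 10 splits. A sketch of the score itself, assuming `probs` is an (N, 1000) tensor of softmax outputs for N generated images (feature extraction omitted):

```python
import torch

def inception_score(probs, splits=10):
    scores = []
    for chunk in probs.chunk(splits):
        p_y = chunk.mean(dim=0, keepdim=True)                # marginal p(y)
        kl = (chunk * (chunk.log() - p_y.log())).sum(dim=1)  # KL(p(y|x) || p(y))
        scores.append(kl.mean().exp())
    s = torch.stack(scores)
    return s.mean().item(), s.std().item()
```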

---

## Longer-term - Towards Paper Numbers

### 9. Full paper hyperparameters

Once the code is stable, run with the exact paper configs (no iteration reduction):
- MNIST: 100K + 100K + 40K iterations
- CIFAR-10: 200K + 200K + 40K iterations
- This requires an A100 or multiple Kaggle sessions with checkpointing

### 10. Ablation: NSGF vs NSGF++

Run NSGF-only (Phase 1 only, no straight flow) and compare FID/W2 against NSGF++ to verify that the two-phase approach actually helps. The paper shows a clear improvement.

### 11. NFE sweep

The paper reports results at various NFE (number of function evaluations). Test:
- MNIST: NFE = 10, 20, 40, 60
- CIFAR: NFE = 10, 20, 40, 59
- Compare the FID-vs-NFE curve against the paper's Figure 3

### 12. pykeops for faster Sinkhorn

Install `pykeops` to enable geomloss's `online` backend. This avoids materializing the full N×N cost matrix and should be much faster with lower VRAM for the image experiments, possibly enough to allow the paper's original batch_size=128 on a T4. See the Python sketch after the commands below.

```bash
pip install pykeops
# Then in config or code:
# backend: "online"  instead of "tensorized"
```
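
What the switch would look like in code (the blur/p values are placeholders; match the repo's existing Sinkhorn config):

```python
import torch
from geomloss import SamplesLoss

# "online" streams the cost computation through KeOps instead of
# building the full N x N matrix ("tensorized").
sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=0.05, backend="online")

x = torch.randn(4096, 3 * 32 * 32, device="cuda")  # e.g. flattened CIFAR batch
y = torch.randn(4096, 3 * 32 * 32, device="cuda")
print(sinkhorn(x, y).item())
```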

---

## Known Limitations

- **Single-GPU only**: no DDP, so T4×2 wastes one GPU
- **No EMA**: standard in flow/diffusion models, likely hurts FID (see item 5)
- **No mixed precision**: fp32 only; could roughly halve VRAM with fp16/bf16 (see the sketch below)
- **No gradient accumulation**: batch size is hard-limited by VRAM (see the sketch below)
- **Kaggle checkpoint persistence**: checkpoints are lost between sessions unless saved manually
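
A sketch of how the mixed-precision and gradient-accumulation items could be addressed together: fp16 autocast with loss scaling (the T4 has fast fp16 but no native bf16) plus an accumulation factor of 4. All names and values here are assumptions, not existing repo code:

```python
import torch

ACCUM = 4  # effective batch = micro-batch size * ACCUM
scaler = torch.cuda.amp.GradScaler()

def train_step(model, opt, micro_batches, loss_fn):
    opt.zero_grad(set_to_none=True)
    for mb in micro_batches:  # list of ACCUM micro-batches
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            loss = loss_fn(model, mb) / ACCUM  # average over micro-batches
        scaler.scale(loss).backward()          # scaled to avoid fp16 underflow
    scaler.step(opt)
    scaler.update()
```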