# NSGF++ — Neural Sinkhorn Gradient Flow

Reproduction of [arXiv:2401.14069](https://arxiv.org/abs/2401.14069)

## Setup

```bash
git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus
pip install torch torchvision numpy scipy scikit-learn matplotlib geomloss pot tqdm pyyaml
# For GPU acceleration of Sinkhorn: pip install pykeops
```
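Several pip package names differ from their import names (`scikit-learn` is `sklearn`, `pot` is `ot`, `pyyaml` is `yaml`). The helper below is a small sketch, not part of the repo, that reports which of the packages from the install line can be found. It does not import them.

```python
import importlib.util

# pip package name -> import name, for the packages in the install line above
DEPENDENCIES = {
    "torch": "torch",
    "torchvision": "torchvision",
    "numpy": "numpy",
    "scipy": "scipy",
    "scikit-learn": "sklearn",
    "matplotlib": "matplotlib",
    "geomloss": "geomloss",
    "pot": "ot",
    "tqdm": "tqdm",
    "pyyaml": "yaml",
    "pykeops": "pykeops",  # optional: only needed for GPU-accelerated Sinkhorn
}


def check_dependencies(deps=DEPENDENCIES):
    """Map each pip name to True if its top-level module can be found, else False."""
    return {pip: importlib.util.find_spec(mod) is not None for pip, mod in deps.items()}


if __name__ == "__main__":
    for pip_name, ok in check_dependencies().items():
        print(f"{pip_name:12s} {'ok' if ok else 'MISSING'}")
```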

## Quick start — 2D experiments

```bash
# Full-scale 8gaussians (paper Table 1, ~10 min on GPU)
python main.py --experiment 2d --dataset 8gaussians --steps 10

# Quick test (< 1 min)
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 10 --train-iters 1000

# All 2D datasets
for ds in 8gaussians moons scurve checkerboard; do
  python main.py --experiment 2d --dataset $ds --steps 10
  python main.py --experiment 2d --dataset $ds --steps 100
done
```
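For orientation, `8gaussians` is usually defined as a mixture of eight isotropic Gaussians whose means are evenly spaced on a circle. The sampler below is a hypothetical sketch of that construction. The `radius` and `std` defaults are assumptions and may not match what `dataset_loader.py` uses.

```python
import numpy as np


def sample_8gaussians(n, radius=2.0, std=0.1, seed=None):
    """Hypothetical 8gaussians sampler.

    Draws n points from 8 equally weighted isotropic Gaussians whose means lie
    evenly spaced on a circle. radius and std are assumed values.
    Returns (points of shape (n, 2), component index of each point).
    """
    rng = np.random.default_rng(seed)
    angles = 2 * np.pi * np.arange(8) / 8
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (8, 2)
    idx = rng.integers(0, 8, size=n)  # mixture component for each sample
    points = centers[idx] + std * rng.standard_normal((n, 2))
    return points, idx
```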

## Image experiments (NSGF++)

```bash
# MNIST (paper: FID=3.8, NFE=60)
python main.py --experiment mnist

# CIFAR-10 (paper: FID=5.55, IS=8.86, NFE=59)
python main.py --experiment cifar10
```
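NSGF++ is a two-phase sampler: a few NSGF steps bring samples close to the image manifold, a phase-transition predictor estimates where each sample sits, and a straight flow (NSF) finishes the trajectory. The function below is a rough NumPy sketch of that control flow under stated assumptions. `nsgf_velocity`, `phase_predictor` and `nsf_velocity` are hypothetical stand-ins for the trained networks, and the step counts are illustrative defaults. It is not the code in `inference.py`. Its NFE counter counts every callable evaluation, and the paper's NFE accounting may differ.

```python
import numpy as np


def nsgf_pp_sample(x0, nsgf_velocity, phase_predictor, nsf_velocity,
                   nsgf_steps=5, nsgf_step_size=1.0, nsf_steps=50):
    """Two-phase NSGF++-style sampling sketch (all callables are hypothetical).

    nsgf_velocity(x)   -> velocity of the learned Sinkhorn flow at x
    phase_predictor(x) -> per-sample transition time, expected in [0, 1]
    nsf_velocity(x, t) -> straight-flow velocity at x and per-sample time t
    Returns (samples, number of callable evaluations).
    """
    x = np.array(x0, dtype=float)
    nfe = 0

    # Phase 1: a few explicit Euler steps along the Sinkhorn gradient flow.
    for _ in range(nsgf_steps):
        x = x + nsgf_step_size * nsgf_velocity(x)
        nfe += 1

    # Predict each sample's time on the straight path.
    t = np.clip(phase_predictor(x), 0.0, 1.0)
    nfe += 1

    # Phase 2: Euler integration of the straight flow from t to 1.
    dt = (1.0 - t) / nsf_steps
    shape = (-1,) + (1,) * (x.ndim - 1)  # broadcast per-sample dt over feature dims
    for k in range(nsf_steps):
        x = x + dt.reshape(shape) * nsf_velocity(x, t + k * dt)
        nfe += 1
    return x, nfe
```

With a straight flow, Euler integration is exact when the velocity points along a fixed line, which is part of why the refinement phase can use relatively few steps.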

## Files

| File | Description |
|------|-------------|
| `config.yaml` | All hyperparameters from the paper |
| `main.py` | CLI entry point |
| `dataset_loader.py` | 2D synthetic + MNIST/CIFAR-10 loaders |
| `sinkhorn_flow.py` | Sinkhorn potentials (GeomLoss), gradient flow, trajectory pool |
| `model.py` | VelocityMLP (2D), VelocityUNet (images), PhaseTransitionPredictor |
| `trainer.py` | NSGF, NSF, phase predictor, and NSGF++ trainers |
| `inference.py` | NSGF and NSGF++ samplers |
| `evaluation.py` | W2 distance, FID, IS, visualization |
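NSGF learns the velocity field of the Wasserstein gradient flow of the Sinkhorn divergence toward the target. For cost c(x, y) = |x − y|²/2, the gradient of a Sinkhorn potential has a closed form, ∇f(x) = x − T(x), where T(x) is a softmax-weighted average of the other measure's points. The velocity ∇f_{μ,μ}(x) − ∇f_{μ,μ*}(x) therefore reduces to T_{μ*}(x) − T_{μ}(x). The sketch below computes this with a plain log-domain Sinkhorn in NumPy. It is an illustration under assumed values (`eps`, `n_iters`, uniform weights), not the GeomLoss-based code in `sinkhorn_flow.py`.

```python
import numpy as np


def _logsumexp(a, axis):
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def _cost(x, y):
    # c(x, y) = |x - y|^2 / 2, shape (len(x), len(y))
    return 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(axis=-1)


def sinkhorn_potentials(x, y, eps=0.1, n_iters=200):
    """Log-domain Sinkhorn between uniform empirical measures on x and y.
    Returns the dual potentials (f, g) sampled on x and y."""
    C = _cost(x, y)
    log_a = np.full(len(x), -np.log(len(x)))
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iters):
        f = -eps * _logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * _logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    return f, g


def entropic_map(x, y, g, eps):
    """T(x) = x - grad f(x), where f(x) = -eps * log sum_j b_j exp((g_j - c(x, y_j)) / eps).
    This is the softmax-weighted average of the y_j."""
    logits = (g[None, :] - _cost(x, y)) / eps  # uniform b_j cancels in the softmax
    logits -= logits.max(axis=1, keepdims=True)
    w = np.exp(logits)
    w /= w.sum(axis=1, keepdims=True)
    return w @ y


def sinkhorn_flow_velocity(x, target, eps=0.1, n_iters=200):
    """v(x) = grad f_{mu,mu}(x) - grad f_{mu,mu*}(x) = T_{mu*}(x) - T_{mu}(x),
    with mu the empirical measure on x and mu* the one on target."""
    _, g_xt = sinkhorn_potentials(x, target, eps, n_iters)
    _, g_xx = sinkhorn_potentials(x, x, eps, n_iters)
    return entropic_map(x, target, g_xt, eps) - entropic_map(x, x, g_xx, eps)
```

An explicit Euler step `x + step_size * sinkhorn_flow_velocity(x, target)` moves the particles along the flow. Pairs of positions and velocities gathered this way are the kind of data a trajectory pool can store for velocity-field matching.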

## Paper targets

| Experiment | Metric | Target |
|-----------|--------|--------|
| 8gaussians / 10 steps | W2 | 0.285 |
| MNIST | FID / NFE | 3.8 / 60 |
| CIFAR-10 | FID / IS / NFE | 5.55 / 8.86 / 59 |
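The W2 target compares generated and reference 2D samples. When both samples have the same size and uniform weights, the optimal coupling is a permutation, so exact W2 reduces to a linear assignment problem. The sketch below uses SciPy for this. The paper's evaluation protocol (sample size, W2 versus W2², solver) is not stated in this README, and `evaluation.py` may compute it differently, for example with POT.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment
from scipy.spatial.distance import cdist


def w2_empirical(x, y):
    """Exact W2 between two equally sized uniform point clouds.

    With equal weights, the optimal transport plan is a permutation, so W2^2
    is the mean squared distance under the optimal one-to-one assignment.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes equally sized samples")
    cost = cdist(x, y, metric="sqeuclidean")  # pairwise |x_i - y_j|^2
    rows, cols = linear_sum_assignment(cost)  # minimum-cost permutation
    return float(np.sqrt(cost[rows, cols].mean()))
```

A quick sanity check: translating a point cloud by a vector v gives W2 = |v|, because the centered clouds coincide.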