rogermt commited on
Commit
39fb7ad
·
verified ·
1 Parent(s): c881e93

Upload README.md

Browse files
Files changed (1) hide show
  1. README.md +59 -0
README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NSGF++ — Neural Sinkhorn Gradient Flow
2
+
3
+ Reproduction of [arXiv:2401.14069](https://arxiv.org/abs/2401.14069)
4
+
5
+ ## Setup
6
+
7
+ ```bash
8
+ git clone https://huggingface.co/rogermt/nsgf-plusplus
9
+ cd nsgf-plusplus
10
+ pip install torch torchvision numpy scipy scikit-learn matplotlib geomloss pot tqdm pyyaml
11
+ # For GPU acceleration of Sinkhorn: pip install pykeops
12
+ ```
13
+
14
+ ## Quick start — 2D experiments
15
+
16
+ ```bash
17
+ # Full-scale 8gaussians (paper Table 1, ~10 min on GPU)
18
+ python main.py --experiment 2d --dataset 8gaussians --steps 10
19
+
20
+ # Quick test (< 1 min)
21
+ python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 10 --train-iters 1000
22
+
23
+ # All 2D datasets
24
+ for ds in 8gaussians moons scurve checkerboard; do
25
+ python main.py --experiment 2d --dataset $ds --steps 10
26
+ python main.py --experiment 2d --dataset $ds --steps 100
27
+ done
28
+ ```
29
+
30
+ ## Image experiments (NSGF++)
31
+
32
+ ```bash
33
+ # MNIST (paper: FID=3.8, NFE=60)
34
+ python main.py --experiment mnist
35
+
36
+ # CIFAR-10 (paper: FID=5.55, IS=8.86, NFE=59)
37
+ python main.py --experiment cifar10
38
+ ```
39
+
40
+ ## Files
41
+
42
+ | File | Description |
43
+ |------|-------------|
44
+ | `config.yaml` | All hyperparameters from the paper |
45
+ | `main.py` | CLI entry point |
46
+ | `dataset_loader.py` | 2D synthetic + MNIST/CIFAR-10 loaders |
47
+ | `sinkhorn_flow.py` | Sinkhorn potentials (GeomLoss), gradient flow, trajectory pool |
48
+ | `model.py` | VelocityMLP (2D), VelocityUNet (images), PhaseTransitionPredictor |
49
+ | `trainer.py` | NSGF, NSF, phase predictor, and NSGF++ trainers |
50
+ | `inference.py` | NSGF and NSGF++ samplers |
51
+ | `evaluation.py` | W2 distance, FID, IS, visualization |
52
+
53
+ ## Paper targets
54
+
55
+ | Experiment | Metric | Target |
56
+ |-----------|--------|--------|
57
+ | 8gaussians / 10 steps | W2 | 0.285 |
58
+ | MNIST | FID / NFE | 3.8 / 60 |
59
+ | CIFAR-10 | FID / IS / NFE | 5.55 / 8.86 / 59 |