# NSGF++ — Neural Sinkhorn Gradient Flow

Reproduction of [arXiv:2401.14069](https://arxiv.org/abs/2401.14069).

## Setup

```bash
git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus
pip install torch torchvision numpy scipy scikit-learn matplotlib geomloss pot tqdm pyyaml
# Optional, GPU-accelerated Sinkhorn through GeomLoss's KeOps backend:
# pip install pykeops
```
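Several pip package names differ from the module you import: `pot` installs `ot`, `scikit-learn` installs `sklearn`, and `pyyaml` installs `yaml`. The following is a small, hypothetical helper (`missing_modules` is our own name, not part of the repository) for checking the environment before running `main.py`:

```python
import importlib.util

# pip package name -> module name used in `import` (they differ for some packages)
REQUIRED = {
    "torch": "torch",
    "torchvision": "torchvision",
    "numpy": "numpy",
    "scipy": "scipy",
    "scikit-learn": "sklearn",
    "matplotlib": "matplotlib",
    "geomloss": "geomloss",
    "pot": "ot",
    "tqdm": "tqdm",
    "pyyaml": "yaml",
}
OPTIONAL = {"pykeops": "pykeops"}


def missing_modules(packages):
    """Return the pip names whose import module cannot be found."""
    return [pip_name for pip_name, module in packages.items()
            if importlib.util.find_spec(module) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED)
    if missing:
        print("Missing, install with: pip install " + " ".join(missing))
    else:
        print("All required packages found.")
    if missing_modules(OPTIONAL):
        print("Optional: pip install pykeops (GPU Sinkhorn)")
```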

## Quick start — 2D experiments

```bash
# Full-scale 8gaussians (paper Table 1, ~10 min on a GPU)
python main.py --experiment 2d --dataset 8gaussians --steps 10

# Quick test (< 1 min)
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 10 --train-iters 1000

# All 2D datasets, with 10 and 100 flow steps
for ds in 8gaussians moons scurve checkerboard; do
  python main.py --experiment 2d --dataset "$ds" --steps 10
  python main.py --experiment 2d --dataset "$ds" --steps 100
done
```
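For orientation, `8gaussians` is the common toy target made of eight isotropic Gaussians whose centres are evenly spaced on a circle. A NumPy sketch of that convention follows; the radius `scale=4.0` and noise `std=0.5` are illustrative assumptions and may not match what `dataset_loader.py` uses.

```python
import numpy as np


def sample_8gaussians(n, scale=4.0, std=0.5, seed=None):
    """Draw n points from a mixture of 8 Gaussians centred on a circle.

    `scale` (circle radius) and `std` (per-component noise) are illustrative
    defaults, not necessarily those used by dataset_loader.py.
    """
    rng = np.random.default_rng(seed)
    angles = 2 * np.pi * np.arange(8) / 8
    centers = scale * np.stack([np.cos(angles), np.sin(angles)], axis=1)  # (8, 2)
    labels = rng.integers(0, 8, size=n)  # mixture component of each point
    points = centers[labels] + std * rng.standard_normal((n, 2))
    return points, labels


x, labels = sample_8gaussians(1000, seed=0)
print(x.shape)  # (1000, 2)
```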

## Image experiments (NSGF++)

```bash
# MNIST (paper: FID = 3.8, NFE = 60)
python main.py --experiment mnist

# CIFAR-10 (paper: FID = 5.55, IS = 8.86, NFE = 59)
python main.py --experiment cifar10
```
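Going by the paper's description and the `PhaseTransitionPredictor` in `model.py`, NSGF++ sampling presumably has two phases: a few steps along the learned Sinkhorn-flow velocity, then a straight-flow refinement from a predicted switch time `t` up to `t = 1`. The sketch below is hypothetical. The callables `nsgf_velocity`, `predict_t` and `nsf_velocity`, the explicit Euler integrator, the per-sample switch time and all step counts are our assumptions, not the API of `inference.py`. It mainly shows how NFE adds up; in this sketch the predictor call is counted as one evaluation.

```python
import numpy as np


def nsgf_plusplus_sample(x0, nsgf_velocity, predict_t, nsf_velocity,
                         nsgf_steps=5, nsgf_lr=1.0, nsf_steps=50):
    """Two-phase NSGF++-style sampler (illustrative sketch, hypothetical API).

    Phase 1: explicit Euler steps along the learned Sinkhorn-flow velocity.
    Phase 2: predict a per-sample time t in [0, 1], then integrate the
    straight-flow velocity from t to 1 with Euler steps.
    Returns the samples and the number of network evaluations (NFE).
    """
    x = np.asarray(x0, dtype=float)
    nfe = 0
    for _ in range(nsgf_steps):                 # phase 1: Sinkhorn flow
        x = x + nsgf_lr * nsgf_velocity(x)
        nfe += 1
    t = np.clip(predict_t(x), 0.0, 1.0)         # switch time per sample, shape (n,)
    nfe += 1
    dt = (1.0 - t) / nsf_steps                  # per-sample Euler step size
    bshape = (-1,) + (1,) * (x.ndim - 1)        # broadcast over (n, ...) samples
    for _ in range(nsf_steps):                  # phase 2: straight flow to t = 1
        x = x + dt.reshape(bshape) * nsf_velocity(x, t)
        t = t + dt
        nfe += 1
    return x, nfe
```

With `nsgf_steps=5` and `nsf_steps=50`, this sketch spends 5 + 1 + 50 = 56 evaluations per sample.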

## Files

| File | Description |
|------|-------------|
| `config.yaml` | All hyperparameters from the paper |
| `main.py` | CLI entry point |
| `dataset_loader.py` | 2D synthetic + MNIST/CIFAR-10 loaders |
| `sinkhorn_flow.py` | Sinkhorn potentials (GeomLoss), gradient flow, trajectory pool |
| `model.py` | VelocityMLP (2D), VelocityUNet (images), PhaseTransitionPredictor |
| `trainer.py` | NSGF, NSF, phase predictor, and NSGF++ trainers |
| `inference.py` | NSGF and NSGF++ samplers |
| `evaluation.py` | W2 distance, FID, IS, visualization |
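`sinkhorn_flow.py` computes Sinkhorn potentials with GeomLoss. To make one gradient-flow step concrete without GeomLoss, here is a self-contained NumPy sketch under stated assumptions: squared-Euclidean cost `C(x, y) = |x - y|^2 / 2`, uniform weights, plain alternating log-domain Sinkhorn (GeomLoss uses a more elaborate ε-scaling scheme), and the standard Sinkhorn-divergence flow whose velocity at `x` is `grad f_{mu,mu}(x) - grad f_{mu,nu}(x)`. For this cost each potential gradient equals `x` minus a softmax-weighted barycentre, so the `x` terms cancel. The function names and the `eps`, `lr`, `n_iter` values are illustrative.

```python
import numpy as np


def _logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along `axis`."""
    m = a.max(axis=axis, keepdims=True)
    return (m + np.log(np.exp(a - m).sum(axis=axis, keepdims=True))).squeeze(axis)


def _log_sinkhorn(x, y, eps, n_iter):
    """Log-domain Sinkhorn with uniform weights and cost |x - y|^2 / 2.

    Returns the dual potential g on y, the cost matrix and the log-weights of y.
    """
    C = 0.5 * ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # (n, m)
    log_a = np.full(len(x), -np.log(len(x)))
    log_b = np.full(len(y), -np.log(len(y)))
    f, g = np.zeros(len(x)), np.zeros(len(y))
    for _ in range(n_iter):
        f = -eps * _logsumexp(log_b[None, :] + (g[None, :] - C) / eps, axis=1)
        g = -eps * _logsumexp(log_a[:, None] + (f[:, None] - C) / eps, axis=0)
    return g, C, log_b


def barycentric_map(x, y, eps, n_iter):
    """x - grad f_{x,y}(x), i.e. the entropic barycentric projection of x onto y."""
    g, C, log_b = _log_sinkhorn(x, y, eps, n_iter)
    logits = log_b[None, :] + (g[None, :] - C) / eps
    w = np.exp(logits - _logsumexp(logits, axis=1)[:, None])  # each row sums to 1
    return w @ y


def sinkhorn_flow_step(x, y, eps=1.0, lr=0.5, n_iter=100):
    """One explicit Euler step of the Sinkhorn-divergence gradient flow.

    velocity(x) = grad f_{mu,mu}(x) - grad f_{mu,nu}(x)
                = B_nu(x) - B_mu(x)   for the cost |x - y|^2 / 2,
    where B is the barycentric map above.
    """
    v = barycentric_map(x, y, eps, n_iter) - barycentric_map(x, x, eps, n_iter)
    return x + lr * v


# Usage: flow a standard Gaussian cloud towards a shifted copy.
rng = np.random.default_rng(0)
source = rng.standard_normal((100, 2))
target = rng.standard_normal((100, 2)) + 4.0
particles = source.copy()
for _ in range(15):
    particles = sinkhorn_flow_step(particles, target)
```

In NSGF, (position, velocity) pairs collected along such trajectories are presumably what the trajectory pool stores as regression targets for `VelocityMLP` / `VelocityUNet`.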

## Paper targets

| Experiment | Metric | Target |
|------------|--------|--------|
| 8gaussians / 10 steps | W2 | 0.285 |
| MNIST | FID / NFE | 3.8 / 60 |
| CIFAR-10 | FID / IS / NFE | 5.55 / 8.86 / 59 |
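One way to estimate the W2 target between generated and reference 2D samples, assuming both sets have the same size and uniform weights (the repository's `evaluation.py` may compute it differently, for instance with POT or a different sample size). In that setting an optimal transport plan is a permutation, so W2 reduces to an assignment problem that SciPy solves exactly. The helper name `empirical_w2` is our own; it returns W2, not its square.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def empirical_w2(x, y):
    """Exact 2-Wasserstein distance between equal-size, uniformly weighted point clouds.

    With n points on each side and uniform weights, an optimal plan is a
    permutation, so W2^2 = min_sigma (1/n) * sum_i |x_i - y_sigma(i)|^2.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if x.shape != y.shape:
        raise ValueError("x and y must have the same shape (n, d)")
    cost = ((x[:, None, :] - y[None, :, :]) ** 2).sum(-1)  # squared Euclidean
    rows, cols = linear_sum_assignment(cost)
    return float(np.sqrt(cost[rows, cols].mean()))
```

A translated copy of a point cloud is a useful sanity check: the optimal matching is the identity, so W2 equals the length of the translation.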