NSGF++ — Neural Sinkhorn Gradient Flow
Reproduction of arXiv:2401.14069
Setup
git clone https://huggingface.co/rogermt/nsgf-plusplus
cd nsgf-plusplus
pip install torch torchvision numpy scipy scikit-learn matplotlib geomloss pot tqdm pyyaml
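Several of the pip package names above differ from the module names you import (for example `pot` installs `ot`, `pyyaml` installs `yaml`, `scikit-learn` installs `sklearn`). The snippet below is a small, hypothetical helper (not part of the repo) that reports which of the listed packages are not importable in the current environment. It only looks the modules up and does not import them.

```python
import importlib.util

# pip package name -> top-level module name it provides
REQUIREMENTS = {
    "torch": "torch",
    "torchvision": "torchvision",
    "numpy": "numpy",
    "scipy": "scipy",
    "scikit-learn": "sklearn",
    "matplotlib": "matplotlib",
    "geomloss": "geomloss",
    "pot": "ot",
    "tqdm": "tqdm",
    "pyyaml": "yaml",
}


def missing_packages():
    """Return the pip names whose modules cannot be found (without importing them)."""
    return [pkg for pkg, mod in REQUIREMENTS.items()
            if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    missing = missing_packages()
    print("missing:", ", ".join(missing) if missing else "none")
```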
Quick start — 2D experiments
# 8gaussians with 10 flow steps
python main.py --experiment 2d --dataset 8gaussians --steps 10
# 8gaussians with 5 flow steps, overriding --pool-batches and --train-iters
python main.py --experiment 2d --dataset 8gaussians --steps 5 --pool-batches 10 --train-iters 1000
# All four 2D datasets at 10 and 100 flow steps
for ds in 8gaussians moons scurve checkerboard; do
  python main.py --experiment 2d --dataset "$ds" --steps 10
  python main.py --experiment 2d --dataset "$ds" --steps 100
done
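For readers unfamiliar with the 2D targets, 8gaussians is conventionally a uniform mixture of eight isotropic Gaussians placed evenly on a circle. The sketch below illustrates that idea only; `sample_8gaussians`, the radius of 2.0 and the standard deviation of 0.1 are hypothetical choices, not the values used by dataset_loader.py or config.yaml.

```python
import numpy as np


def sample_8gaussians(n, radius=2.0, std=0.1, seed=None):
    """Draw n 2D points from a mixture of 8 Gaussians on a circle.

    radius and std are illustrative defaults (assumptions), not the repo's settings.
    """
    rng = np.random.default_rng(seed)
    angles = 2 * np.pi * np.arange(8) / 8               # 8 equally spaced directions
    centers = radius * np.stack([np.cos(angles), np.sin(angles)], axis=1)
    idx = rng.integers(0, 8, size=n)                    # mixture component per sample
    return centers[idx] + std * rng.standard_normal((n, 2))
```

Usage: `sample_8gaussians(4000, seed=0)` returns a `(4000, 2)` array whose points cluster around radius 2.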
Image experiments (NSGF++)
python main.py --experiment mnist
python main.py --experiment cifar10
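As we read the paper, NSGF++ is a two-phase sampler: it first follows the learned Sinkhorn flow for a few steps (at most 5 NFEs) to approach the image manifold, then a phase-transition predictor picks a time from which a straight flow (NSF) refines the samples. The sketch below shows only that control flow with plain callables. `nsgfpp_sample`, its arguments, the explicit Euler integrator, the scalar (per-batch) phase time and the NFE accounting (velocity-network calls only) are all assumptions for illustration; inference.py may differ.

```python
import numpy as np


def nsgfpp_sample(x0, nsgf_velocity, phase_predictor, nsf_velocity,
                  nsgf_steps=5, nsgf_step_size=1.0, nsf_steps=50):
    """Hypothetical two-phase NSGF++ sampler sketch.

    nsgf_velocity(x, k)  -> velocity of the learned Sinkhorn flow at step k
    phase_predictor(x)   -> scalar time t0 in [0, 1] at which to switch phases
    nsf_velocity(x, t)   -> velocity of the straight flow at time t
    Returns (samples, number of velocity evaluations).
    """
    x = np.asarray(x0, dtype=float)
    nfe = 0
    # Phase 1: a few explicit Euler steps along the Sinkhorn flow.
    for k in range(nsgf_steps):
        x = x + nsgf_step_size * nsgf_velocity(x, k)
        nfe += 1
    # Phase transition: predicted start time for the straight flow.
    t = float(np.clip(phase_predictor(x), 0.0, 1.0))
    dt = (1.0 - t) / nsf_steps
    # Phase 2: Euler integration of the straight flow from t to 1.
    for _ in range(nsf_steps):
        x = x + dt * nsf_velocity(x, t)
        t += dt
        nfe += 1
    return x, nfe
```

The toy velocities in a quick check can be anything shaped like `x`; for example a straight-flow velocity `(target - x) / (1 - t)` lands exactly on `target` under Euler integration.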
Files
| File | Description |
|---|---|
| config.yaml | All hyperparameters from the paper |
| main.py | CLI entry point |
| dataset_loader.py | 2D synthetic + MNIST/CIFAR-10 loaders |
| sinkhorn_flow.py | Sinkhorn potentials (GeomLoss), gradient flow, trajectory pool |
| model.py | VelocityMLP (2D), VelocityUNet (images), PhaseTransitionPredictor |
| trainer.py | NSGF, NSF, phase predictor, and NSGF++ trainers |
| inference.py | NSGF and NSGF++ samplers |
| evaluation.py | W2 distance, FID, IS, visualization |
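sinkhorn_flow.py builds the flow from Sinkhorn potentials computed with GeomLoss. To show what those potentials are used for, here is a NumPy-only sketch of one explicit Euler step of a debiased Sinkhorn gradient flow, moving particles x with velocity −∇(f_xy − f_xx), where f_xy and f_xx are the potentials of OT_eps(x, target) and OT_eps(x, x). Assumptions, all ours: uniform weights, cost |x − y|²/2, a fixed number of log-domain Sinkhorn iterations, and the hypothetical helper names below. It is not the repo's implementation, which uses GeomLoss and trains a velocity network on such flows.

```python
import numpy as np


def _logsumexp(z, axis):
    """Numerically stable log-sum-exp along one axis."""
    m = np.max(z, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(z - m), axis=axis))


def sinkhorn_potentials(x, y, eps=0.1, iters=100):
    """Log-domain Sinkhorn, uniform weights, cost |x - y|^2 / 2.

    Returns the dual potentials f (on x), g (on y) and the cost matrix.
    """
    n, m = len(x), len(y)
    C = 0.5 * np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a, log_b = -np.log(n), -np.log(m)
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(iters):
        f = -eps * _logsumexp(log_b + (g[None, :] - C) / eps, axis=1)
        g = -eps * _logsumexp(log_a + (f[:, None] - C) / eps, axis=0)
    return f, g, C


def potential_grad(x, y, g, C, eps):
    """Gradient at x of f(x) = -eps * log sum_j b_j exp((g_j - C(x, y_j)) / eps)."""
    logits = -np.log(len(y)) + (g[None, :] - C) / eps
    w = np.exp(logits - _logsumexp(logits, axis=1)[:, None])  # each row sums to 1
    return x - w @ y  # sum_j w_ij (x_i - y_j)


def sinkhorn_flow_step(x, target, step=0.5, eps=0.1, iters=100):
    """One Euler step of the debiased Sinkhorn gradient flow toward target."""
    _, g_xy, C_xy = sinkhorn_potentials(x, target, eps, iters)
    _, g_xx, C_xx = sinkhorn_potentials(x, x, eps, iters)
    velocity = -(potential_grad(x, target, g_xy, C_xy, eps)
                 - potential_grad(x, x, g_xx, C_xx, eps))
    return x + step * velocity
```

Because each gradient reduces to a barycentric projection, the velocity is simply "soft match in the target" minus "soft match in the current cloud", so the self-term cancels the entropic blur that plain OT_eps would introduce.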
Paper targets
| Experiment | Metric | Target |
|---|---|---|
| 8gaussians / 10 steps | W2 | 0.285 |
| MNIST | FID / NFE | 3.8 / 60 |
| CIFAR-10 | FID / IS / NFE | 5.55 / 8.86 / 59 |
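The W2 and FID targets above come from evaluation.py. As a reference point, here is a minimal sketch of both metrics under stated assumptions: W2 is computed exactly between two equal-size point clouds with uniform weights (where the optimal coupling is a permutation, so SciPy's assignment solver suffices), and the Fréchet distance takes precomputed feature means and covariances (for FID these would come from Inception features, which this sketch does not extract). Whether the repo reports W2 or squared W2, and which OT solver it uses, are not stated here; the function names are hypothetical.

```python
import numpy as np
from scipy.linalg import sqrtm
from scipy.optimize import linear_sum_assignment


def empirical_w2(x, y):
    """Exact 2-Wasserstein distance between equal-size uniform point clouds."""
    cost = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)  # squared Euclidean
    rows, cols = linear_sum_assignment(cost)                       # optimal permutation
    return float(np.sqrt(cost[rows, cols].mean()))


def frechet_distance(mu1, sigma1, mu2, sigma2):
    """||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 (sigma1 sigma2)^(1/2))."""
    covmean = sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):  # drop tiny imaginary parts from numerical error
        covmean = covmean.real
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
```

For example, translating a point cloud by (3, 4) gives `empirical_w2` equal to 5, since a pure translation is the optimal map for the squared Euclidean cost.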