## config.yaml
## Neural Sinkhorn Gradient Flow (NSGF++) Configuration
## Based on arXiv:2401.14069
# ============================================================
# 2D Synthetic Experiments (Section 5.1, Appendix E.1)
# ============================================================
experiment_2d:
  # Datasets: 8gaussians, moons, scurve, checkerboard, 8gaussians_moons
  dataset: "8gaussians"
  source: "gaussian"          # source distribution: standard Gaussian N(0, I)

  # MLP architecture (Appendix E.1: 3 hidden layers, 256 hidden units)
  model:
    input_dim: 2
    hidden_dim: 256
    num_hidden_layers: 3
    time_emb_dim: 64
    activation: "silu"

  # Sinkhorn gradient flow parameters
  sinkhorn:
    epsilon: 0.1              # regularization coefficient ε
    blur: 0.5                 # GeomLoss blur parameter (blur^p ~ ε)
    scaling: 0.80             # GeomLoss multiscale scaling
    eta: 1.0                  # gradient flow step size η
    num_steps: 10             # T: number of gradient flow time steps
    batch_size: 256           # n: minibatch size for the Sinkhorn flow
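  # Illustrative sketch (an assumption, not repo code; p=2 assumed) of the
  # descent step these values parameterize, via geomloss:
  #
  #   import torch
  #   from geomloss import SamplesLoss
  #   ot_loss = SamplesLoss("sinkhorn", p=2, blur=0.5, scaling=0.80)
  #   x = x.detach().requires_grad_(True)        # particles, shape (256, 2)
  #   g = torch.autograd.grad(ot_loss(x, y), x)[0]  # y: target minibatch
  #   x = x - 1.0 * len(x) * g                   # eta * n undoes the 1/n weights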
  # Trajectory pool
  pool:
    num_batches: 200          # number of Sinkhorn batches used to build the pool
    experience_replay: true
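  # How the pool gets filled, sketched (sample_source and sinkhorn_grad are
  # hypothetical helpers; y is a fresh data minibatch):
  #
  #   pool = []                                  # entries: (x_t, t, v_t)
  #   for _ in range(200):                       # num_batches
  #       x = sample_source(256)
  #       for t in range(10):                    # num_steps
  #           v = -len(x) * sinkhorn_grad(x, y)  # descent direction
  #           pool.append((x.detach(), t, v.detach()))
  #           x = x + 1.0 * v                    # eta = 1.0
  #
  # experience_replay: true means training batches are drawn uniformly from
  # this pool rather than only from the freshest trajectories.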
  # Velocity-field matching training
  training:
    num_iterations: 20000
    batch_size: 256
    learning_rate: 0.001
    optimizer: "adam"
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.0
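  # The regression objective these settings drive, sketched (assumption:
  # plain MSE of the field net onto pooled descent directions):
  #
  #   opt = torch.optim.Adam(net.parameters(), lr=1e-3,
  #                          betas=(0.9, 0.999), weight_decay=0.0)
  #   for _ in range(20000):
  #       x, t, v = sample_pool(256)             # hypothetical helper
  #       loss = ((net(x, t) - v) ** 2).mean()
  #       opt.zero_grad(); loss.backward(); opt.step()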
  # Inference / sampling
  inference:
    num_euler_steps: 10       # paper evaluates 10 or 100 uniform Euler steps
    num_samples: 1024         # samples drawn for evaluation
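  # Sampling with the uniform Euler schedule, sketched (assuming the field
  # net takes (x, t) with per-sample times):
  #
  #   x = torch.randn(1024, 2)                   # source N(0, I)
  #   K = 10                                     # num_euler_steps
  #   for k in range(K):
  #       t = torch.full((len(x),), k / K)
  #       x = x + (1.0 / K) * net(x, t)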
  # Evaluation
  evaluation:
    num_test_samples: 1024    # W2 computed against 1024 held-out test samples
    metric: "w2"              # 2-Wasserstein distance
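  # One way to compute the W2 metric (assumption: the POT library; the repo
  # may use a different estimator):
  #
  #   import numpy as np, ot
  #   a = np.full(len(gen), 1.0 / len(gen))
  #   b = np.full(len(test), 1.0 / len(test))
  #   w2 = np.sqrt(ot.emd2(a, b, ot.dist(gen, test)))  # exact 2-Wasserstein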
# ============================================================
# Image Benchmark Experiments (Section 5.2, Appendix E.2)
# ============================================================
experiment_mnist:
  dataset: "mnist"
  image_size: 28
  in_channels: 1

  # UNet architecture (Appendix E.2; Dhariwal & Nichol, 2021)
  unet:
    model_channels: 32        # base channels
    num_res_blocks: 1         # depth = 1
    channel_mult: [1, 2, 2]
    num_heads: 1
    num_head_channels: -1     # use num_heads instead
    attention_resolutions: [16]
    dropout: 0.0
    use_scale_shift_norm: true  # AdaGN
  # Sinkhorn gradient flow (Phase 1)
  sinkhorn:
    blur: 0.5
    scaling: 0.80
    eta: 1.0
    num_steps: 5              # T <= 5 for the NSGF phase
    batch_size: 256
  # Trajectory pool (Appendix E.2: batch size 256 x 1500 batches x 5 steps
  # keeps the pool under the 20 GB limit; see the estimate below)
  pool:
    num_batches: 1500
    storage_limit_gb: 20
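  # Back-of-envelope check (assumption: float32 positions and velocities):
  # 256 * 1500 * 5 = 1.92M entries; at 784 floats each for x and for v,
  # 1.92e6 * 784 * 4 B * 2 ≈ 12 GB, under the 20 GB cap.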
  # Velocity-field matching training (NSGF model)
  nsgf_training:
    num_iterations: 100000
    batch_size: 128
    learning_rate: 0.0001
    optimizer: "adam"
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.0

  # Neural Straight Flow (Phase 2)
  nsf_training:
    num_iterations: 100000
    batch_size: 128
    learning_rate: 0.0001
    optimizer: "adam"
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.0
  # Phase-transition time predictor (CNN)
  time_predictor:
    conv_channels: [32, 64, 128, 256]
    kernel_size: 3
    stride: 1
    padding: 1
    pool_size: 2
    num_iterations: 40000
    learning_rate: 0.0001
    batch_size: 128
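  # Sketch of a CNN matching these hyperparameters (activation and head are
  # assumptions beyond the listed values):
  #
  #   import torch.nn as nn
  #   layers, c_in = [], 1                       # 1 input channel for MNIST
  #   for c_out in [32, 64, 128, 256]:           # conv_channels
  #       layers += [nn.Conv2d(c_in, c_out, 3, stride=1, padding=1),
  #                  nn.SiLU(), nn.MaxPool2d(2)]
  #       c_in = c_out
  #   layers += [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(256, 1)]
  #   time_predictor = nn.Sequential(*layers)    # regresses the hand-off time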
  # Inference
  inference:
    nsgf_steps: 5             # 5-step Euler in the NSGF phase
    nsf_steps: 55             # remaining steps in the straight-flow phase
    total_nfe: 60             # total NFE = nsgf_steps + nsf_steps
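  # Two-phase sampling, sketched (net names and the hand-off-time
  # parameterization are assumptions):
  #
  #   x = torch.randn(n, 1, 28, 28)
  #   for k in range(5):                          # NSGF phase, 5 NFE
  #       x = x + (1.0 / 5) * nsgf_net(x, torch.full((n,), k / 5))
  #   s = time_predictor(x).squeeze(1)            # (n,) predicted hand-off times
  #   dt = (1.0 - s) / 55
  #   for k in range(55):                         # straight-flow phase, 55 NFE
  #       x = x + dt.view(-1, 1, 1, 1) * nsf_net(x, s + k * dt)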
  # Evaluation (Appendix E.2: FID between 10K generated samples and the test set)
  evaluation:
    num_generated: 10000
    metrics: ["fid"]

experiment_cifar10:
  dataset: "cifar10"
  image_size: 32
  in_channels: 3

  # UNet architecture (Appendix E.2)
  unet:
    model_channels: 128       # base channels
    num_res_blocks: 2         # depth = 2
    channel_mult: [1, 2, 2, 2]
    num_heads: 4
    num_head_channels: 64
    attention_resolutions: [16]
    dropout: 0.0
    use_scale_shift_norm: true
  # Sinkhorn gradient flow (Phase 1)
  # NOTE: batch_size reduced from the paper's 128 to 32 to fit a T4's 16 GB
  # VRAM. Sinkhorn on 3072-dim flattened vectors (3 x 32 x 32) with the
  # tensorized GeomLoss backend uses O(n^2 * d) memory; 128 samples OOM on a
  # T4, while 32 fits comfortably (rough check below). We compensate by
  # increasing the pool batches: 32 * 10000 = 320K samples = 128 * 2500.
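  # Rough memory check (assumption: autograd retains one n x n broadcast
  # tensor of d-dim differences per Sinkhorn iteration while differentiating):
  #   n=128: 128^2 * 3072 floats * 4 B ≈ 0.20 GB per retained iterate
  #   n=32:   32^2 * 3072 floats * 4 B ≈ 0.013 GB per retained iterate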
  sinkhorn:
    blur: 1.0
    scaling: 0.85
    eta: 1.0
    num_steps: 5
    batch_size: 32
  # Trajectory pool, adjusted for the smaller Sinkhorn batch:
  # 32 * 10000 batches * 5 steps = 1.6M entries, matching the paper's
  # 128 * 2500 * 5 (storage check below)
  pool:
    num_batches: 10000
    storage_limit_gb: 45
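  # Storage check (assumption: float32 x and v per entry):
  # 1.6e6 * 3072 floats * 4 B * 2 ≈ 39 GB, inside the 45 GB limit.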
  # Velocity-field matching training (NSGF model)
  nsgf_training:
    num_iterations: 200000
    batch_size: 128
    learning_rate: 0.0001
    optimizer: "adam"
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.0

  # Neural Straight Flow (Phase 2)
  nsf_training:
    num_iterations: 200000
    batch_size: 128
    learning_rate: 0.0001
    optimizer: "adam"
    beta1: 0.9
    beta2: 0.999
    weight_decay: 0.0
  # Phase-transition time predictor (same CNN architecture as MNIST)
  time_predictor:
    conv_channels: [32, 64, 128, 256]
    kernel_size: 3
    stride: 1
    padding: 1
    pool_size: 2
    num_iterations: 40000
    learning_rate: 0.0001
    batch_size: 128
  # Inference
  inference:
    nsgf_steps: 5
    nsf_steps: 54
    total_nfe: 59             # paper reports NFE = 59 for CIFAR-10
  # Evaluation
  evaluation:
    num_generated: 10000
    metrics: ["fid", "is"]
    # Paper targets: FID = 5.55, IS = 8.86
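  # One way to evaluate both metrics (assumption: torchmetrics; inputs are
  # uint8 NCHW image batches):
  #
  #   from torchmetrics.image.fid import FrechetInceptionDistance
  #   from torchmetrics.image.inception import InceptionScore
  #   fid, inception = FrechetInceptionDistance(), InceptionScore()
  #   fid.update(test_imgs, real=True); fid.update(gen_imgs, real=False)
  #   inception.update(gen_imgs)
  #   print(fid.compute(), inception.compute())  # targets: 5.55 / 8.86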