# nsgf-plusplus / main.py
# Commit a365009 (rogermt): add --resume-phase, --checkpoint-dir,
# --sinkhorn-batch flags.
"""main.py β€” Entry point for NSGF/NSGF++ experiments.
Orchestrates the full experiment pipeline:
1. Load configuration
2. Set up dataset, model, trainer
3. Train (build pool β†’ velocity matching β†’ [NSF β†’ phase predictor for NSGF++])
4. Generate samples
5. Evaluate (W2 for 2D, FID/IS for images)
6. Visualize results
Supports --resume-phase to continue from a checkpoint after interruption.
Usage:
python main.py --experiment 2d --dataset 8gaussians --steps 10
python main.py --experiment mnist --device cuda
python main.py --experiment mnist --resume-phase 2 # skip Phase 1, load checkpoint
python main.py --experiment cifar10 --device cuda
Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow)
"""
import argparse
import logging
import os
import sys
import time

import numpy as np
import torch
import yaml

from dataset_loader import DatasetLoader
from model import (
    VelocityMLP, VelocityUNet, PhaseTransitionPredictor,
    create_velocity_model_2d, create_velocity_unet, create_phase_predictor,
)
from trainer import NSGFTrainer, NSFTrainer, PhaseTransitionTrainer, NSGFPlusPlusTrainer
from inference import NSGFSampler, NSGFPlusPlusSampler
from evaluation import (
    Evaluation, compute_w2_distance,
    plot_2d_samples, plot_2d_trajectory, plot_image_grid,
)
# Configure root logging once at import time; all project modules that use
# logging.getLogger(__name__) inherit this format and level.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
# Module-level logger, per the standard logging convention.
logger = logging.getLogger(__name__)
def load_config(config_path: str = "config.yaml") -> dict:
    """Load experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML config file (default: "config.yaml").

    Returns:
        The parsed configuration as a dict (top-level mapping of
        per-experiment sections, e.g. "experiment_2d").

    Raises:
        FileNotFoundError: If the config file does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Explicit encoding avoids locale-dependent decoding (e.g. cp1252 on
    # Windows) when the YAML contains non-ASCII characters.
    with open(config_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)
def get_device(args) -> str:
    """Return the compute device name.

    An explicit ``--device`` choice from the CLI wins; otherwise CUDA is
    used when available, falling back to CPU.
    """
    # `or` covers both None (flag omitted) and any other falsy value.
    return args.device or ("cuda" if torch.cuda.is_available() else "cpu")
def run_2d_experiment(config: dict, args):
    """Run a 2D synthetic experiment (plain NSGF).

    Pipeline: build trajectory pool -> velocity-matching training ->
    Euler sampling -> evaluation (W2) -> plots and model weights under
    ``results/``.

    Args:
        config: Experiment sub-config (``full_config["experiment_2d"]``).
        args: Parsed CLI namespace; reads ``dataset``, ``steps``,
            ``pool_batches``, ``train_iters``, ``checkpoint_dir``, ``device``.

    Returns:
        dict of evaluation metrics produced by ``Evaluation.evaluate``.

    Reference: Section 5.1, Appendix E.1
    """
    device = get_device(args)
    # Apply CLI overrides BEFORE logging, so the reported dataset/steps are
    # the values actually used (previously logged pre-override values).
    if args.dataset:
        config["dataset"] = args.dataset
    if args.steps:
        # setdefault guards against configs that omit these sections.
        config.setdefault("sinkhorn", {})["num_steps"] = args.steps
        config.setdefault("inference", {})["num_euler_steps"] = args.steps
    if args.pool_batches:
        config.setdefault("pool", {})["num_batches"] = args.pool_batches
    if args.train_iters:
        config.setdefault("training", {})["num_iterations"] = args.train_iters
    logger.info(f"Running 2D experiment on {device}")
    logger.info(f"Dataset: {config['dataset']}, Steps: {config['sinkhorn']['num_steps']}")
    # ---- Setup ----
    data_loader = DatasetLoader(config)
    model = create_velocity_model_2d(config)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
    # ---- Training ----
    start_time = time.time()
    trainer = NSGFTrainer(
        model=model,
        data_loader=data_loader,
        config=config,
        device=device,
        checkpoint_dir=args.checkpoint_dir,
    )
    trainer.build_trajectory_pool()
    trainer.train()  # training history is not used here
    train_time = time.time() - start_time
    logger.info(f"Training completed in {train_time:.1f}s")
    # ---- Inference ----
    num_eval = config.get("evaluation", {}).get("num_test_samples", 1024)
    num_steps = config.get("inference", {}).get("num_euler_steps", 10)
    sampler = NSGFSampler(
        model=model,
        data_loader=data_loader,
        num_steps=num_steps,
        device=device,
    )
    samples = sampler.sample(num_eval)
    # Track at most 200 particles so the trajectory plot stays readable.
    trajectory = sampler.sample_trajectory(min(200, num_eval))
    # ---- Evaluation ----
    test_samples = data_loader.get_test_samples(num_eval, device)
    evaluator = Evaluation(config, device)
    metrics = evaluator.evaluate(samples, test_samples)
    logger.info(f"\n{'='*50}")
    logger.info(f"RESULTS — 2D {config['dataset']}, {num_steps} steps")
    logger.info(f"{'='*50}")
    for k, v in metrics.items():
        logger.info(f"  {k}: {v:.4f}")
    logger.info(f"  Training time: {train_time:.1f}s")
    # ---- Visualization ----
    os.makedirs("results", exist_ok=True)
    plot_2d_samples(
        samples, test_samples,
        title=f"NSGF — {config['dataset']} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}",
        save_path=f"results/nsgf_2d_{config['dataset']}_{num_steps}steps.png",
    )
    plot_2d_trajectory(
        trajectory, test_samples,
        title=f"NSGF Trajectory — {config['dataset']}",
        save_path=f"results/nsgf_trajectory_{config['dataset']}_{num_steps}steps.png",
    )
    torch.save(model.state_dict(), f"results/nsgf_2d_{config['dataset']}.pt")
    logger.info("Model saved.")
    return metrics
def run_image_experiment(config: dict, args, dataset_name: str):
    """Run an image-generation experiment with NSGF++ (MNIST or CIFAR-10).

    Pipeline: (optionally resume from a phase checkpoint) -> three-phase
    training via NSGFPlusPlusTrainer -> batched sampling -> evaluation ->
    sample-grid plot and final model weights under ``results/``.

    Args:
        config: Experiment sub-config (e.g. ``full_config["experiment_mnist"]``).
        args: Parsed CLI namespace; reads ``checkpoint_dir``, ``resume_phase``,
            ``pool_batches``, ``train_iters``, ``sinkhorn_batch``,
            ``checkpoint_every``, ``device``.
        dataset_name: "mnist" or "cifar10"; used for logging and file names.

    Returns:
        dict of evaluation metrics produced by ``Evaluation.evaluate``.

    Reference: Section 5.2, Appendix E.2
    """
    device = get_device(args)
    checkpoint_dir = args.checkpoint_dir
    resume_phase = args.resume_phase
    logger.info(f"Running {dataset_name.upper()} experiment on {device}")
    # CLI overrides take precedence over the YAML config.
    if args.pool_batches:
        config["pool"]["num_batches"] = args.pool_batches
    if args.train_iters:
        # One --train-iters value overrides all three training phases.
        config["nsgf_training"]["num_iterations"] = args.train_iters
        config["nsf_training"]["num_iterations"] = args.train_iters
        config["time_predictor"]["num_iterations"] = args.train_iters
    if args.sinkhorn_batch:
        config["sinkhorn"]["batch_size"] = args.sinkhorn_batch
    # Inject checkpoint_every into config for trainers to read
    config["checkpoint_every"] = args.checkpoint_every
    # Setup
    data_loader = DatasetLoader(config)
    # Create models: two velocity UNets (NSGF phase and NSF phase) plus a
    # predictor that decides the transition point between them.
    nsgf_model = create_velocity_unet(config)
    nsf_model = create_velocity_unet(config)
    phase_predictor = create_phase_predictor(config)
    # Load checkpoints if resuming: phase N requires the phase N-1 checkpoint.
    if resume_phase > 1:
        ckpt_path = os.path.join(checkpoint_dir, f"phase{resume_phase - 1}_complete.pt")
        if os.path.exists(ckpt_path):
            # weights_only=True restricts unpickling to tensors/containers.
            ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
            # Each state dict is optional; an earlier-phase checkpoint may
            # not contain the later models yet.
            if "nsgf_model_state" in ckpt:
                nsgf_model.load_state_dict(ckpt["nsgf_model_state"])
                logger.info(f"Loaded NSGF model from {ckpt_path}")
            if "nsf_model_state" in ckpt:
                nsf_model.load_state_dict(ckpt["nsf_model_state"])
                logger.info(f"Loaded NSF model from {ckpt_path}")
            if "predictor_state" in ckpt:
                phase_predictor.load_state_dict(ckpt["predictor_state"])
                logger.info(f"Loaded phase predictor from {ckpt_path}")
        else:
            logger.error(f"Checkpoint not found: {ckpt_path}")
            logger.error(f"Cannot resume from phase {resume_phase} without phase {resume_phase - 1} checkpoint.")
            sys.exit(1)
    logger.info(f"NSGF UNet params: {sum(p.numel() for p in nsgf_model.parameters()):,}")
    logger.info(f"NSF UNet params: {sum(p.numel() for p in nsf_model.parameters()):,}")
    logger.info(f"Phase predictor params: {sum(p.numel() for p in phase_predictor.parameters()):,}")
    # ---- Training ----
    start_time = time.time()
    pp_trainer = NSGFPlusPlusTrainer(
        nsgf_model=nsgf_model,
        nsf_model=nsf_model,
        phase_predictor=phase_predictor,
        data_loader=data_loader,
        config=config,
        device=device,
        checkpoint_dir=checkpoint_dir,
    )
    # train_all runs phases >= resume_phase (phase 1 skipped when resuming).
    results = pp_trainer.train_all(resume_phase=resume_phase)
    train_time = time.time() - start_time
    logger.info(f"Training completed in {train_time:.1f}s")
    # ---- Inference ----
    inference_cfg = config.get("inference", {})
    nsgf_steps = inference_cfg.get("nsgf_steps", 5)
    nsf_steps = inference_cfg.get("nsf_steps", 55)
    num_gen = config.get("evaluation", {}).get("num_generated", 10000)
    sampler = NSGFPlusPlusSampler(
        nsgf_model=nsgf_model,
        nsf_model=nsf_model,
        phase_predictor=phase_predictor,
        data_loader=data_loader,
        nsgf_steps=nsgf_steps,
        nsf_steps=nsf_steps,
        device=device,
    )
    logger.info(f"Generating {num_gen} samples...")
    # Generate in fixed-size batches and accumulate on CPU to bound GPU memory.
    batch_size = 128
    all_samples = []
    for i in range(0, num_gen, batch_size):
        n = min(batch_size, num_gen - i)  # last batch may be smaller
        samples = sampler.sample_simple(n)
        all_samples.append(samples.cpu())
    generated = torch.cat(all_samples, dim=0)
    # ---- Evaluation ----
    # Reference samples are kept on CPU to match the generated batch.
    # NOTE(review): assumes get_test_samples can return num_gen samples — confirm.
    test_samples = data_loader.get_test_samples(num_gen, device="cpu")
    evaluator = Evaluation(config, device)
    metrics = evaluator.evaluate(generated, test_samples)
    logger.info(f"\n{'='*50}")
    logger.info(f"RESULTS — NSGF++ on {dataset_name.upper()}")
    logger.info(f"{'='*50}")
    for k, v in metrics.items():
        logger.info(f"  {k}: {v:.4f}")
    # NFE = number of function evaluations (one per Euler step, both phases).
    logger.info(f"  NFE: {nsgf_steps + nsf_steps}")
    logger.info(f"  Training time: {train_time:.1f}s")
    # ---- Visualization ----
    os.makedirs("results", exist_ok=True)
    plot_image_grid(
        generated[:64],
        title=f"NSGF++ — {dataset_name.upper()}",
        save_path=f"results/nsgf_pp_{dataset_name}_samples.png",
    )
    # Save final models
    torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt")
    torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt")
    torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt")
    logger.info("Models saved.")
    return metrics
def main():
    """Parse CLI arguments, seed RNGs, and dispatch to the chosen experiment."""
    parser = argparse.ArgumentParser(description="NSGF/NSGF++ Experiments")
    parser.add_argument(
        "--experiment", type=str, default="2d",
        choices=["2d", "mnist", "cifar10"],
        help="Experiment type"
    )
    parser.add_argument("--dataset", type=str, default=None, help="2D dataset name")
    parser.add_argument("--steps", type=int, default=None, help="Number of flow steps")
    parser.add_argument("--pool-batches", type=int, default=None, help="Pool building batches")
    parser.add_argument("--train-iters", type=int, default=None, help="Training iterations")
    parser.add_argument("--sinkhorn-batch", type=int, default=None,
                        help="Sinkhorn batch size for pool building (reduce for OOM)")
    parser.add_argument("--config", type=str, default="config.yaml", help="Config file path")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--device", type=str, default=None,
                        choices=["cpu", "cuda"],
                        help="Force device (default: auto-detect)")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints",
                        help="Directory for saving/loading checkpoints")
    parser.add_argument("--checkpoint-every", type=int, default=5000,
                        help="Save checkpoint every N training steps")
    parser.add_argument("--resume-phase", type=int, default=1,
                        choices=[1, 2, 3],
                        help="Resume from phase N (loads phase N-1 checkpoint)")
    args = parser.parse_args()
    # Seed both RNG sources up front for reproducibility. numpy is imported
    # at module level (previously a function-local import).
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    # Load config and dispatch.
    full_config = load_config(args.config)
    if args.experiment == "2d":
        run_2d_experiment(full_config["experiment_2d"], args)
    elif args.experiment in ("mnist", "cifar10"):
        # Image experiments share one runner; the config section name mirrors
        # the dataset name (experiment_mnist / experiment_cifar10).
        run_image_experiment(full_config[f"experiment_{args.experiment}"], args, args.experiment)
    else:
        # Unreachable via argparse `choices`; kept as a defensive guard.
        logger.error(f"Unknown experiment: {args.experiment}")
        sys.exit(1)
# Standard entry guard: run the experiment only when executed as a script,
# not when imported as a module.
if __name__ == "__main__":
    main()