"""main.py — Entry point for NSGF/NSGF++ experiments. Orchestrates the full experiment pipeline: 1. Load configuration 2. Set up dataset, model, trainer 3. Train (build pool → velocity matching → [NSF → phase predictor for NSGF++]) 4. Generate samples 5. Evaluate (W2 for 2D, FID/IS for images) 6. Visualize results Supports --resume-phase to continue from a checkpoint after interruption. Usage: python main.py --experiment 2d --dataset 8gaussians --steps 10 python main.py --experiment mnist --device cuda python main.py --experiment mnist --resume-phase 2 # skip Phase 1, load checkpoint python main.py --experiment cifar10 --device cuda Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow) """ import os import sys import argparse import logging import yaml import torch import time from dataset_loader import DatasetLoader from model import ( VelocityMLP, VelocityUNet, PhaseTransitionPredictor, create_velocity_model_2d, create_velocity_unet, create_phase_predictor, ) from trainer import NSGFTrainer, NSFTrainer, PhaseTransitionTrainer, NSGFPlusPlusTrainer from inference import NSGFSampler, NSGFPlusPlusSampler from evaluation import ( Evaluation, compute_w2_distance, plot_2d_samples, plot_2d_trajectory, plot_image_grid, ) logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S", ) logger = logging.getLogger(__name__) def load_config(config_path: str = "config.yaml") -> dict: """Load configuration from YAML file.""" with open(config_path, "r") as f: return yaml.safe_load(f) def get_device(args) -> str: """Resolve device from CLI args or auto-detect.""" if args.device: return args.device return "cuda" if torch.cuda.is_available() else "cpu" def run_2d_experiment(config: dict, args): """Run 2D synthetic experiment (NSGF). Reference: Section 5.1, Appendix E.1 """ device = get_device(args) logger.info(f"Running 2D experiment on {device}") logger.info(f"Dataset: {config['dataset']}, Steps: {config['sinkhorn']['num_steps']}") # Override from args if args.dataset: config["dataset"] = args.dataset if args.steps: config["sinkhorn"]["num_steps"] = args.steps config["inference"]["num_euler_steps"] = args.steps if args.pool_batches: config["pool"]["num_batches"] = args.pool_batches if args.train_iters: config["training"]["num_iterations"] = args.train_iters # Setup data_loader = DatasetLoader(config) model = create_velocity_model_2d(config) logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}") # ---- Training ---- start_time = time.time() trainer = NSGFTrainer( model=model, data_loader=data_loader, config=config, device=device, checkpoint_dir=args.checkpoint_dir, ) trainer.build_trajectory_pool() history = trainer.train() train_time = time.time() - start_time logger.info(f"Training completed in {train_time:.1f}s") # ---- Inference ---- num_eval = config.get("evaluation", {}).get("num_test_samples", 1024) num_steps = config.get("inference", {}).get("num_euler_steps", 10) sampler = NSGFSampler( model=model, data_loader=data_loader, num_steps=num_steps, device=device, ) samples = sampler.sample(num_eval) trajectory = sampler.sample_trajectory(min(200, num_eval)) # ---- Evaluation ---- test_samples = data_loader.get_test_samples(num_eval, device) evaluator = Evaluation(config, device) metrics = evaluator.evaluate(samples, test_samples) logger.info(f"\n{'='*50}") logger.info(f"RESULTS — 2D {config['dataset']}, {num_steps} steps") logger.info(f"{'='*50}") for k, v in metrics.items(): logger.info(f" {k}: {v:.4f}") logger.info(f" Training time: {train_time:.1f}s") # ---- Visualization ---- os.makedirs("results", exist_ok=True) plot_2d_samples( samples, test_samples, title=f"NSGF — {config['dataset']} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}", save_path=f"results/nsgf_2d_{config['dataset']}_{num_steps}steps.png", ) plot_2d_trajectory( trajectory, test_samples, title=f"NSGF Trajectory — {config['dataset']}", save_path=f"results/nsgf_trajectory_{config['dataset']}_{num_steps}steps.png", ) torch.save(model.state_dict(), f"results/nsgf_2d_{config['dataset']}.pt") logger.info("Model saved.") return metrics def run_image_experiment(config: dict, args, dataset_name: str): """Run image experiment (NSGF++). Reference: Section 5.2, Appendix E.2 """ device = get_device(args) checkpoint_dir = args.checkpoint_dir resume_phase = args.resume_phase logger.info(f"Running {dataset_name.upper()} experiment on {device}") # Override from args if args.pool_batches: config["pool"]["num_batches"] = args.pool_batches if args.train_iters: config["nsgf_training"]["num_iterations"] = args.train_iters config["nsf_training"]["num_iterations"] = args.train_iters config["time_predictor"]["num_iterations"] = args.train_iters if args.sinkhorn_batch: config["sinkhorn"]["batch_size"] = args.sinkhorn_batch # Inject checkpoint_every into config for trainers to read config["checkpoint_every"] = args.checkpoint_every # Setup data_loader = DatasetLoader(config) # Create models nsgf_model = create_velocity_unet(config) nsf_model = create_velocity_unet(config) phase_predictor = create_phase_predictor(config) # Load checkpoints if resuming if resume_phase > 1: ckpt_path = os.path.join(checkpoint_dir, f"phase{resume_phase - 1}_complete.pt") if os.path.exists(ckpt_path): ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True) if "nsgf_model_state" in ckpt: nsgf_model.load_state_dict(ckpt["nsgf_model_state"]) logger.info(f"Loaded NSGF model from {ckpt_path}") if "nsf_model_state" in ckpt: nsf_model.load_state_dict(ckpt["nsf_model_state"]) logger.info(f"Loaded NSF model from {ckpt_path}") if "predictor_state" in ckpt: phase_predictor.load_state_dict(ckpt["predictor_state"]) logger.info(f"Loaded phase predictor from {ckpt_path}") else: logger.error(f"Checkpoint not found: {ckpt_path}") logger.error(f"Cannot resume from phase {resume_phase} without phase {resume_phase - 1} checkpoint.") sys.exit(1) logger.info(f"NSGF UNet params: {sum(p.numel() for p in nsgf_model.parameters()):,}") logger.info(f"NSF UNet params: {sum(p.numel() for p in nsf_model.parameters()):,}") logger.info(f"Phase predictor params: {sum(p.numel() for p in phase_predictor.parameters()):,}") # ---- Training ---- start_time = time.time() pp_trainer = NSGFPlusPlusTrainer( nsgf_model=nsgf_model, nsf_model=nsf_model, phase_predictor=phase_predictor, data_loader=data_loader, config=config, device=device, checkpoint_dir=checkpoint_dir, ) results = pp_trainer.train_all(resume_phase=resume_phase) train_time = time.time() - start_time logger.info(f"Training completed in {train_time:.1f}s") # ---- Inference ---- inference_cfg = config.get("inference", {}) nsgf_steps = inference_cfg.get("nsgf_steps", 5) nsf_steps = inference_cfg.get("nsf_steps", 55) num_gen = config.get("evaluation", {}).get("num_generated", 10000) sampler = NSGFPlusPlusSampler( nsgf_model=nsgf_model, nsf_model=nsf_model, phase_predictor=phase_predictor, data_loader=data_loader, nsgf_steps=nsgf_steps, nsf_steps=nsf_steps, device=device, ) logger.info(f"Generating {num_gen} samples...") batch_size = 128 all_samples = [] for i in range(0, num_gen, batch_size): n = min(batch_size, num_gen - i) samples = sampler.sample_simple(n) all_samples.append(samples.cpu()) generated = torch.cat(all_samples, dim=0) # ---- Evaluation ---- test_samples = data_loader.get_test_samples(num_gen, device="cpu") evaluator = Evaluation(config, device) metrics = evaluator.evaluate(generated, test_samples) logger.info(f"\n{'='*50}") logger.info(f"RESULTS — NSGF++ on {dataset_name.upper()}") logger.info(f"{'='*50}") for k, v in metrics.items(): logger.info(f" {k}: {v:.4f}") logger.info(f" NFE: {nsgf_steps + nsf_steps}") logger.info(f" Training time: {train_time:.1f}s") # ---- Visualization ---- os.makedirs("results", exist_ok=True) plot_image_grid( generated[:64], title=f"NSGF++ — {dataset_name.upper()}", save_path=f"results/nsgf_pp_{dataset_name}_samples.png", ) # Save final models torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt") torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt") torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt") logger.info("Models saved.") return metrics def main(): parser = argparse.ArgumentParser(description="NSGF/NSGF++ Experiments") parser.add_argument( "--experiment", type=str, default="2d", choices=["2d", "mnist", "cifar10"], help="Experiment type" ) parser.add_argument("--dataset", type=str, default=None, help="2D dataset name") parser.add_argument("--steps", type=int, default=None, help="Number of flow steps") parser.add_argument("--pool-batches", type=int, default=None, help="Pool building batches") parser.add_argument("--train-iters", type=int, default=None, help="Training iterations") parser.add_argument("--sinkhorn-batch", type=int, default=None, help="Sinkhorn batch size for pool building (reduce for OOM)") parser.add_argument("--config", type=str, default="config.yaml", help="Config file path") parser.add_argument("--seed", type=int, default=42, help="Random seed") parser.add_argument("--device", type=str, default=None, choices=["cpu", "cuda"], help="Force device (default: auto-detect)") parser.add_argument("--checkpoint-dir", type=str, default="checkpoints", help="Directory for saving/loading checkpoints") parser.add_argument("--checkpoint-every", type=int, default=5000, help="Save checkpoint every N training steps") parser.add_argument("--resume-phase", type=int, default=1, choices=[1, 2, 3], help="Resume from phase N (loads phase N-1 checkpoint)") args = parser.parse_args() # Set seed torch.manual_seed(args.seed) import numpy as np np.random.seed(args.seed) # Load config full_config = load_config(args.config) if args.experiment == "2d": config = full_config["experiment_2d"] run_2d_experiment(config, args) elif args.experiment == "mnist": config = full_config["experiment_mnist"] run_image_experiment(config, args, "mnist") elif args.experiment == "cifar10": config = full_config["experiment_cifar10"] run_image_experiment(config, args, "cifar10") else: logger.error(f"Unknown experiment: {args.experiment}") sys.exit(1) if __name__ == "__main__": main()