| """main.py — Entry point for NSGF/NSGF++ experiments. |
| |
| Orchestrates the full experiment pipeline: |
| 1. Load configuration |
| 2. Set up dataset, model, trainer |
| 3. Train (build pool → velocity matching → [NSF → phase predictor for NSGF++]) |
| 4. Generate samples |
| 5. Evaluate (W2 for 2D, FID/IS for images) |
| 6. Visualize results |
| |
| Supports --resume-phase to continue from a checkpoint after interruption. |
| |
| Usage: |
| python main.py --experiment 2d --dataset 8gaussians --steps 10 |
| python main.py --experiment mnist --device cuda |
| python main.py --experiment mnist --resume-phase 2 # skip Phase 1, load checkpoint |
| python main.py --experiment cifar10 --device cuda |
| |
| Reference: arXiv:2401.14069 (Neural Sinkhorn Gradient Flow) |
| """ |
|
|
| import os |
| import sys |
| import argparse |
| import logging |
| import yaml |
| import torch |
| import time |
|
|
| from dataset_loader import DatasetLoader |
| from model import ( |
| VelocityMLP, VelocityUNet, PhaseTransitionPredictor, |
| create_velocity_model_2d, create_velocity_unet, create_phase_predictor, |
| ) |
| from trainer import NSGFTrainer, NSFTrainer, PhaseTransitionTrainer, NSGFPlusPlusTrainer |
| from inference import NSGFSampler, NSGFPlusPlusSampler |
| from evaluation import ( |
| Evaluation, compute_w2_distance, |
| plot_2d_samples, plot_2d_trajectory, plot_image_grid, |
| ) |
|
|
# Configure the root logger once, at import time, with a timestamped
# "[LEVEL] logger-name: message" layout at INFO verbosity.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
def load_config(config_path: str = "config.yaml") -> dict:
    """Load the experiment configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed top-level YAML document. An empty file yields ``{}``
        instead of ``None``, so the result always matches the ``dict``
        annotation for the empty case.

    Raises:
        OSError: If the file cannot be opened (e.g. it does not exist).
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so parsing does not depend on the
    # platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    # safe_load returns None for an empty document.
    return {} if config is None else config
|
|
|
|
def get_device(args) -> str:
    """Pick the torch device string for this run.

    A truthy ``args.device`` (set via ``--device``) is returned unchanged.
    Otherwise ``"cuda"`` is chosen when ``torch.cuda.is_available()`` is
    true, and ``"cpu"`` when it is not.
    """
    forced = args.device
    if not forced:
        # No explicit choice: auto-detect.
        return "cuda" if torch.cuda.is_available() else "cpu"
    return forced
|
|
|
|
def run_2d_experiment(config: dict, args):
    """Run the 2D synthetic experiment (NSGF) end to end.

    The run proceeds in order:
    1. Apply CLI overrides to ``config``.
    2. Build the trajectory pool and train the velocity model.
    3. Draw samples and a plotting trajectory with ``NSGFSampler``.
    4. Evaluate the samples against held-out test points.
    5. Write plots and the model weights under ``results/``.

    Reference: Section 5.1, Appendix E.1

    Args:
        config: The 2D experiment section of the YAML config. It is mutated
            in place by the CLI overrides.
        args: Parsed CLI namespace. Reads ``device``, ``dataset``, ``steps``,
            ``pool_batches``, ``train_iters`` and ``checkpoint_dir``.

    Returns:
        The metrics dict returned by ``Evaluation.evaluate``.
    """
    device = get_device(args)

    # Apply CLI overrides *before* logging or using the config, so the log
    # lines below report the values this run actually uses.
    if args.dataset:
        config["dataset"] = args.dataset
    if args.steps:
        # Train and sample with the same number of flow steps.
        config["sinkhorn"]["num_steps"] = args.steps
        # The "inference" section is optional (it is read with .get below),
        # so create it on demand instead of raising KeyError.
        config.setdefault("inference", {})["num_euler_steps"] = args.steps
    if args.pool_batches:
        config["pool"]["num_batches"] = args.pool_batches
    if args.train_iters:
        config["training"]["num_iterations"] = args.train_iters

    dataset = config["dataset"]
    logger.info(f"Running 2D experiment on {device}")
    logger.info(f"Dataset: {dataset}, Steps: {config['sinkhorn']['num_steps']}")

    data_loader = DatasetLoader(config)
    model = create_velocity_model_2d(config)
    logger.info(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

    # --- Training: build the trajectory pool, then run velocity matching.
    start_time = time.time()
    trainer = NSGFTrainer(
        model=model,
        data_loader=data_loader,
        config=config,
        device=device,
        checkpoint_dir=args.checkpoint_dir,
    )
    trainer.build_trajectory_pool()
    trainer.train()
    train_time = time.time() - start_time
    logger.info(f"Training completed in {train_time:.1f}s")

    # --- Sampling.
    num_eval = config.get("evaluation", {}).get("num_test_samples", 1024)
    num_steps = config.get("inference", {}).get("num_euler_steps", 10)
    sampler = NSGFSampler(
        model=model,
        data_loader=data_loader,
        num_steps=num_steps,
        device=device,
    )
    samples = sampler.sample(num_eval)
    # The trajectory is only used for the trajectory plot, so cap it at 200.
    trajectory = sampler.sample_trajectory(min(200, num_eval))

    # --- Evaluation.
    test_samples = data_loader.get_test_samples(num_eval, device)
    metrics = Evaluation(config, device).evaluate(samples, test_samples)

    banner = "=" * 50
    logger.info(f"\n{banner}")
    logger.info(f"RESULTS — 2D {dataset}, {num_steps} steps")
    logger.info(banner)
    for k, v in metrics.items():
        logger.info(f" {k}: {v:.4f}")
    logger.info(f" Training time: {train_time:.1f}s")

    # --- Artifacts.
    os.makedirs("results", exist_ok=True)
    plot_2d_samples(
        samples, test_samples,
        title=f"NSGF — {dataset} ({num_steps} steps), W2={metrics.get('w2', 0):.4f}",
        save_path=f"results/nsgf_2d_{dataset}_{num_steps}steps.png",
    )
    plot_2d_trajectory(
        trajectory, test_samples,
        title=f"NSGF Trajectory — {dataset}",
        save_path=f"results/nsgf_trajectory_{dataset}_{num_steps}steps.png",
    )

    torch.save(model.state_dict(), f"results/nsgf_2d_{dataset}.pt")
    logger.info("Model saved.")
    return metrics
|
|
|
|
def _apply_image_overrides(config: dict, args) -> None:
    """Copy the image-experiment CLI overrides from ``args`` into ``config`` in place."""
    if args.pool_batches:
        config["pool"]["num_batches"] = args.pool_batches
    if args.train_iters:
        # A single --train-iters value is applied to all three training phases.
        for section in ("nsgf_training", "nsf_training", "time_predictor"):
            config[section]["num_iterations"] = args.train_iters
    if args.sinkhorn_batch:
        config["sinkhorn"]["batch_size"] = args.sinkhorn_batch
    config["checkpoint_every"] = args.checkpoint_every


def _load_resume_checkpoint(checkpoint_dir, resume_phase, nsgf_model, nsf_model, phase_predictor) -> None:
    """Load weights from ``phase{resume_phase - 1}_complete.pt``; exit with status 1 if it is missing."""
    ckpt_path = os.path.join(checkpoint_dir, f"phase{resume_phase - 1}_complete.pt")
    if not os.path.exists(ckpt_path):
        logger.error(f"Checkpoint not found: {ckpt_path}")
        logger.error(f"Cannot resume from phase {resume_phase} without phase {resume_phase - 1} checkpoint.")
        sys.exit(1)

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    loaded_any = False
    # Only the state dicts present in the checkpoint are restored.
    for key, module, label in (
        ("nsgf_model_state", nsgf_model, "NSGF model"),
        ("nsf_model_state", nsf_model, "NSF model"),
        ("predictor_state", phase_predictor, "phase predictor"),
    ):
        if key in ckpt:
            module.load_state_dict(ckpt[key])
            logger.info(f"Loaded {label} from {ckpt_path}")
            loaded_any = True
    if not loaded_any:
        logger.warning(f"Checkpoint {ckpt_path} contained no known model state; starting from fresh weights.")


def _generate_samples(sampler, num_gen: int, batch_size: int) -> torch.Tensor:
    """Draw ``num_gen`` samples in chunks of ``batch_size`` and return them concatenated on CPU.

    Raises:
        ValueError: If ``num_gen`` or ``batch_size`` is not positive.
    """
    if num_gen <= 0:
        raise ValueError(f"num_generated must be positive, got {num_gen}")
    if batch_size <= 0:
        raise ValueError(f"sample batch size must be positive, got {batch_size}")
    chunks = []
    for start in range(0, num_gen, batch_size):
        n = min(batch_size, num_gen - start)
        # detach() so that no autograd graph is kept alive across batches.
        # This is a no-op if the sampler already runs without gradients.
        chunks.append(sampler.sample_simple(n).detach().cpu())
    return torch.cat(chunks, dim=0)


def run_image_experiment(config: dict, args, dataset_name: str):
    """Run an image experiment (NSGF++) end to end.

    The run proceeds in order:
    1. Apply CLI overrides to ``config``.
    2. Build the NSGF UNet, NSF UNet and phase predictor. When
       ``--resume-phase`` > 1, load them from the previous phase's checkpoint.
    3. Train all phases via ``NSGFPlusPlusTrainer``.
    4. Generate samples in batches and evaluate them against test data.
    5. Save an image grid and the three models' weights under ``results/``.

    Reference: Section 5.2, Appendix E.2

    Args:
        config: The dataset's experiment section of the YAML config. It is
            mutated in place by the CLI overrides. The optional key
            ``evaluation.sample_batch_size`` (default 128) sets the
            generation chunk size.
        args: Parsed CLI namespace. Reads ``device``, ``checkpoint_dir``,
            ``resume_phase``, ``pool_batches``, ``train_iters``,
            ``sinkhorn_batch`` and ``checkpoint_every``.
        dataset_name: Dataset label used in log lines and output filenames.

    Returns:
        The metrics dict returned by ``Evaluation.evaluate``.

    Exits the process with status 1 if a resume checkpoint is required but
    missing.
    """
    device = get_device(args)
    checkpoint_dir = args.checkpoint_dir
    resume_phase = args.resume_phase
    logger.info(f"Running {dataset_name.upper()} experiment on {device}")

    _apply_image_overrides(config, args)
    data_loader = DatasetLoader(config)

    nsgf_model = create_velocity_unet(config)
    nsf_model = create_velocity_unet(config)
    phase_predictor = create_phase_predictor(config)

    if resume_phase > 1:
        _load_resume_checkpoint(checkpoint_dir, resume_phase, nsgf_model, nsf_model, phase_predictor)

    for label, module in (
        ("NSGF UNet", nsgf_model),
        ("NSF UNet", nsf_model),
        ("Phase predictor", phase_predictor),
    ):
        logger.info(f"{label} params: {sum(p.numel() for p in module.parameters()):,}")

    # --- Training (all phases; earlier phases are skipped per resume_phase).
    start_time = time.time()
    pp_trainer = NSGFPlusPlusTrainer(
        nsgf_model=nsgf_model,
        nsf_model=nsf_model,
        phase_predictor=phase_predictor,
        data_loader=data_loader,
        config=config,
        device=device,
        checkpoint_dir=checkpoint_dir,
    )
    pp_trainer.train_all(resume_phase=resume_phase)
    train_time = time.time() - start_time
    logger.info(f"Training completed in {train_time:.1f}s")

    # --- Sampling.
    inference_cfg = config.get("inference", {})
    eval_cfg = config.get("evaluation", {})
    nsgf_steps = inference_cfg.get("nsgf_steps", 5)
    nsf_steps = inference_cfg.get("nsf_steps", 55)
    num_gen = eval_cfg.get("num_generated", 10000)
    # Previously hard-coded to 128; now configurable with the same default.
    sample_batch_size = eval_cfg.get("sample_batch_size", 128)

    sampler = NSGFPlusPlusSampler(
        nsgf_model=nsgf_model,
        nsf_model=nsf_model,
        phase_predictor=phase_predictor,
        data_loader=data_loader,
        nsgf_steps=nsgf_steps,
        nsf_steps=nsf_steps,
        device=device,
    )
    logger.info(f"Generating {num_gen} samples...")
    generated = _generate_samples(sampler, num_gen, sample_batch_size)

    # --- Evaluation.
    test_samples = data_loader.get_test_samples(num_gen, device="cpu")
    metrics = Evaluation(config, device).evaluate(generated, test_samples)

    banner = "=" * 50
    logger.info(f"\n{banner}")
    logger.info(f"RESULTS — NSGF++ on {dataset_name.upper()}")
    logger.info(banner)
    for k, v in metrics.items():
        logger.info(f" {k}: {v:.4f}")
    logger.info(f" NFE: {nsgf_steps + nsf_steps}")
    logger.info(f" Training time: {train_time:.1f}s")

    # --- Artifacts.
    os.makedirs("results", exist_ok=True)
    plot_image_grid(
        generated[:64],
        title=f"NSGF++ — {dataset_name.upper()}",
        save_path=f"results/nsgf_pp_{dataset_name}_samples.png",
    )

    torch.save(nsgf_model.state_dict(), f"results/nsgf_{dataset_name}_nsgf.pt")
    torch.save(nsf_model.state_dict(), f"results/nsgf_{dataset_name}_nsf.pt")
    torch.save(phase_predictor.state_dict(), f"results/nsgf_{dataset_name}_predictor.pt")
    logger.info("Models saved.")

    return metrics
|
|
|
|
def main():
    """Parse CLI arguments, seed the RNGs, load the config and run one experiment.

    Exits with status 1 in two cases:
    - the selected experiment is unknown;
    - its config section is missing from the YAML file.
    """
    import random

    import numpy as np

    parser = argparse.ArgumentParser(description="NSGF/NSGF++ Experiments")
    parser.add_argument(
        "--experiment", type=str, default="2d",
        choices=["2d", "mnist", "cifar10"],
        help="Experiment type"
    )
    parser.add_argument("--dataset", type=str, default=None, help="2D dataset name")
    parser.add_argument("--steps", type=int, default=None, help="Number of flow steps")
    parser.add_argument("--pool-batches", type=int, default=None, help="Pool building batches")
    parser.add_argument("--train-iters", type=int, default=None, help="Training iterations")
    parser.add_argument("--sinkhorn-batch", type=int, default=None,
                        help="Sinkhorn batch size for pool building (reduce for OOM)")
    parser.add_argument("--config", type=str, default="config.yaml", help="Config file path")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--device", type=str, default=None,
                        choices=["cpu", "cuda"],
                        help="Force device (default: auto-detect)")
    parser.add_argument("--checkpoint-dir", type=str, default="checkpoints",
                        help="Directory for saving/loading checkpoints")
    parser.add_argument("--checkpoint-every", type=int, default=5000,
                        help="Save checkpoint every N training steps")
    parser.add_argument("--resume-phase", type=int, default=1,
                        choices=[1, 2, 3],
                        help="Resume from phase N (loads phase N-1 checkpoint)")

    args = parser.parse_args()

    # Seed Python's, NumPy's and torch's RNGs so runs are reproducible
    # whichever of them the data/training code draws from.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    full_config = load_config(args.config)

    # experiment name -> (config section, runner taking that section)
    experiments = {
        "2d": ("experiment_2d", lambda cfg: run_2d_experiment(cfg, args)),
        "mnist": ("experiment_mnist", lambda cfg: run_image_experiment(cfg, args, "mnist")),
        "cifar10": ("experiment_cifar10", lambda cfg: run_image_experiment(cfg, args, "cifar10")),
    }
    if args.experiment not in experiments:
        # Unreachable via argparse `choices`; kept as a defensive guard.
        logger.error(f"Unknown experiment: {args.experiment}")
        sys.exit(1)

    section, runner = experiments[args.experiment]
    if not isinstance(full_config, dict) or section not in full_config:
        logger.error(f"Config section '{section}' not found in {args.config}")
        sys.exit(1)

    if args.experiment == "2d" and args.resume_phase > 1:
        # Only run_image_experiment reads resume_phase.
        logger.warning("--resume-phase only applies to image experiments; ignoring it for 2d.")

    runner(full_config[section])
|
|
|
|
# Script entry point: run the pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|