| """trainer.py — Training procedures for NSGF and NSGF++. |
| |
| Implements: |
| 1. Trajectory pool construction (Phase 1: Sinkhorn gradient flow) |
| 2. NSGF velocity field matching training |
| 3. NSF (Neural Straight Flow) training for NSGF++ |
| 4. Phase-transition time predictor training |
| 5. End-to-end NSGF++ training pipeline |
| |
| Reference: arXiv:2401.14069, Section 4.2–4.4, Appendix D, E |
| """ |
|
|
| import os |
| import logging |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| from typing import Optional, Dict, Any, Tuple |
|
|
| from dataset_loader import DatasetLoader |
| from sinkhorn_flow import ( |
| SinkhornPotentialComputer, SinkhornGradientFlow, TrajectoryPool, |
| ) |
| from model import VelocityMLP, VelocityUNet, PhaseTransitionPredictor |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class NSGFTrainer:
    """Trainer for the Neural Sinkhorn Gradient Flow model (Phase 1).

    First populates a trajectory pool by running the Sinkhorn gradient flow
    on fresh source/target batches (``build_trajectory_pool``), then fits the
    velocity network by regression onto the pooled targets (``train``).

    Loss (Eq. 14): L(θ) = E_{(x,v,t) ~ pool} ||v_θ(x, t) - v̂(x)||²
    """
    def __init__(self, model: nn.Module, data_loader: DatasetLoader,
                 config: dict, device: str = "cpu"):
        # Core handles.
        self.model = model.to(device)
        self.data_loader = data_loader
        self.config = config
        self.device = device

        # Sinkhorn machinery used to generate training trajectories.
        sink_cfg = config.get("sinkhorn", {})
        self.potential_computer = SinkhornPotentialComputer(
            blur=sink_cfg.get("blur", 0.5),
            scaling=sink_cfg.get("scaling", 0.80),
        )
        self.gradient_flow = SinkhornGradientFlow(
            potential_computer=self.potential_computer,
            eta=sink_cfg.get("eta", 1.0),
            num_steps=sink_cfg.get("num_steps", 5),
        )
        self.pool = TrajectoryPool(max_size=5_000_000)

        # Optimisation hyper-parameters; "training" takes precedence over
        # the legacy "nsgf_training" section.
        train_cfg = config.get("training", config.get("nsgf_training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 20000)
        self.train_batch_size = train_cfg.get("batch_size", 256)
        self.lr = train_cfg.get("learning_rate", 1e-3)
        self.optimizer = optim.Adam(
            self.model.parameters(),
            lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )

    def build_trajectory_pool(self, num_batches: Optional[int] = None):
        """Run the Sinkhorn flow on fresh batches and store each step in the pool.

        Args:
            num_batches: number of source/target batches to flow through.
                Falls back to ``config["pool"]["num_batches"]`` (default 200).
        """
        if num_batches is None:
            num_batches = self.config.get("pool", {}).get("num_batches", 200)
        sink_batch_size = self.config.get("sinkhorn", {}).get("batch_size", 256)
        logger.info(
            f"Building trajectory pool: {num_batches} batches × "
            f"{sink_batch_size} samples × {self.gradient_flow.num_steps} steps"
        )
        report_every = max(1, num_batches // 10)  # ~10 progress lines total
        for done in range(1, num_batches + 1):
            src = self.data_loader.sample_source(sink_batch_size, self.device)
            tgt = self.data_loader.sample_target(sink_batch_size, self.device)
            _, trajectory = self.gradient_flow.run_flow(src, tgt, store_trajectory=True)
            self.pool.add_trajectory(trajectory)
            if done % report_every == 0:
                logger.info(f"  Pool building: {done}/{num_batches}, pool size: {len(self.pool)}")
        logger.info(f"Trajectory pool built. Total entries: {len(self.pool)}")
        # Pre-concatenate pooled tensors so each sample() call is a cheap index op.
        self.pool.finalize()
        logger.info("Trajectory pool finalized (pre-concatenated for fast sampling).")

    def train(self) -> Dict[str, list]:
        """Regress v_θ onto the pooled Sinkhorn velocities.

        Returns:
            dict with "loss" and "step" lists, recorded at step 1 and every
            500 steps thereafter.
        """
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSGF velocity field matching: {self.num_iterations} iterations")
        # Pool step indices are normalised to [0, 1] before entering the net.
        t_scale = max(self.gradient_flow.num_steps, 1.0)
        for it in range(self.num_iterations):
            xs, vs, ts = self.pool.sample(self.train_batch_size, self.device)
            preds = self.model(xs, ts / t_scale)
            loss = ((preds - vs) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if it == 0 or (it + 1) % 500 == 0:
                current = loss.item()
                history["loss"].append(current)
                history["step"].append(it + 1)
                logger.info(f"  Step {it + 1}/{self.num_iterations}, Loss: {current:.6f}")
        logger.info("NSGF training complete.")
        return history
|
|
|
|
class NSFTrainer:
    """Trainer for Neural Straight Flow (Phase 2 of NSGF++).

    P_0 is the push-forward of the source through the frozen Phase-1 NSGF
    model; P_1 is the target distribution.  The straight flow interpolates
    X_t = (1-t)*P_0 + t*P_1 with constant target velocity P_1 - P_0.
    """
    def __init__(self, model: nn.Module, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu"):
        self.model = model.to(device)
        # Frozen Phase-1 model, used only (in eval mode) to generate P_0.
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps

        # "nsf_training" takes precedence over the generic "training" section.
        train_cfg = config.get("nsf_training", config.get("training", {}))
        self.num_iterations = train_cfg.get("num_iterations", 100000)
        self.train_batch_size = train_cfg.get("batch_size", 128)
        self.lr = train_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(
            self.model.parameters(), lr=self.lr,
            betas=(train_cfg.get("beta1", 0.9), train_cfg.get("beta2", 0.999)),
            weight_decay=train_cfg.get("weight_decay", 0.0),
        )

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Draw n source samples and push them through the NSGF flow.

        Explicit Euler integration with nsgf_num_steps steps over t ∈ [0, 1).
        """
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self) -> Dict[str, list]:
        """Train the straight-flow velocity field.

        Returns:
            dict with "loss" and "step" lists, recorded at step 1 and every
            500 steps thereafter.
        """
        self.model.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting NSF training: {self.num_iterations} iterations")
        for step in range(self.num_iterations):
            P0 = self._generate_nsgf_samples(self.train_batch_size)
            P1 = self.data_loader.sample_target(self.train_batch_size, self.device)
            t = torch.rand(self.train_batch_size, device=self.device)
            # Broadcast t over every non-batch dim.  Generalizes the previous
            # 2-D/4-D-only branching, which broke on any other tensor rank.
            t_expand = t.view(-1, *([1] * (P0.dim() - 1)))
            X_t = (1 - t_expand) * P0 + t_expand * P1
            v_target = P1 - P0
            v_pred = self.model(X_t, t)
            loss = ((v_pred - v_target) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 500 == 0 or step == 0:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
        logger.info("NSF training complete.")
        return history
|
|
|
|
class PhaseTransitionTrainer:
    """Trainer for the phase-transition time predictor.

    The predictor learns to recover the interpolation time t from a sample
    X_t = (1-t)*P_0 + t*P_1 on the straight path between NSGF output (P_0)
    and the target (P_1).

    Loss: L(ϕ) = E_{t~U(0,1)} ||t - t_ϕ(X_t)||²
    """
    def __init__(self, predictor: PhaseTransitionPredictor, nsgf_model: nn.Module,
                 data_loader: DatasetLoader, config: dict,
                 nsgf_num_steps: int = 5, device: str = "cpu"):
        self.predictor = predictor.to(device)
        # Frozen Phase-1 model, used only (in eval mode) to generate P_0.
        self.nsgf_model = nsgf_model.to(device)
        self.nsgf_model.eval()
        self.data_loader = data_loader
        self.config = config
        self.device = device
        self.nsgf_num_steps = nsgf_num_steps
        tp_cfg = config.get("time_predictor", {})
        self.num_iterations = tp_cfg.get("num_iterations", 40000)
        self.batch_size = tp_cfg.get("batch_size", 128)
        self.lr = tp_cfg.get("learning_rate", 1e-4)
        self.optimizer = optim.Adam(self.predictor.parameters(), lr=self.lr, betas=(0.9, 0.999))

    @torch.no_grad()
    def _generate_nsgf_samples(self, n: int) -> torch.Tensor:
        """Draw n source samples and push them through the NSGF flow.

        Explicit Euler integration with nsgf_num_steps steps over t ∈ [0, 1).
        """
        X = self.data_loader.sample_source(n, self.device)
        dt = 1.0 / self.nsgf_num_steps
        for step in range(self.nsgf_num_steps):
            t = torch.full((n,), step * dt, device=self.device)
            v = self.nsgf_model(X, t)
            X = X + dt * v
        return X

    def train(self) -> Dict[str, list]:
        """Train the time predictor by regressing t_ϕ(X_t) onto t.

        Returns:
            dict with "loss" and "step" lists, recorded at step 1 and every
            1000 steps thereafter.
        """
        self.predictor.train()
        history = {"loss": [], "step": []}
        logger.info(f"Starting phase predictor training: {self.num_iterations} iterations")
        for step in range(self.num_iterations):
            P0 = self._generate_nsgf_samples(self.batch_size)
            P1 = self.data_loader.sample_target(self.batch_size, self.device)
            t = torch.rand(self.batch_size, device=self.device)
            # Broadcast t over every non-batch dim.  Generalizes the previous
            # 4-D/2-D-only branching (and matches NSFTrainer's interpolation),
            # which broke on any other tensor rank.
            t_expand = t.view(-1, *([1] * (P0.dim() - 1)))
            X_t = (1 - t_expand) * P0 + t_expand * P1
            t_pred = self.predictor(X_t)
            loss = ((t_pred - t) ** 2).mean()
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            if (step + 1) % 1000 == 0 or step == 0:
                loss_val = loss.item()
                history["loss"].append(loss_val)
                history["step"].append(step + 1)
                logger.info(f"  Step {step + 1}/{self.num_iterations}, Loss: {loss_val:.6f}")
        logger.info("Phase predictor training complete.")
        return history
|
|
|
|
class NSGFPlusPlusTrainer:
    """End-to-end NSGF++ trainer (Algorithm 3 / Appendix D).

    Orchestrates the three phases in order: NSGF velocity matching,
    Neural Straight Flow training, and phase-transition time prediction.
    """
    def __init__(self, nsgf_model: nn.Module, nsf_model: nn.Module,
                 phase_predictor: PhaseTransitionPredictor,
                 data_loader: DatasetLoader, config: dict, device: str = "cpu"):
        self.nsgf_model = nsgf_model
        self.nsf_model = nsf_model
        self.phase_predictor = phase_predictor
        self.data_loader = data_loader
        self.config = config
        self.device = device

    @staticmethod
    def _banner(title: str) -> None:
        # Visual separator around each phase heading in the log.
        logger.info("=" * 60)
        logger.info(title)
        logger.info("=" * 60)

    def train_all(self) -> Dict[str, Any]:
        """Run all three training phases and return their loss histories.

        Returns:
            dict with keys "nsgf", "nsf", "phase_predictor", each holding
            the corresponding trainer's history dict.
        """
        results: Dict[str, Any] = {}

        self._banner("Phase 1: Training NSGF model")
        phase1 = NSGFTrainer(
            model=self.nsgf_model, data_loader=self.data_loader,
            config=self.config, device=self.device,
        )
        phase1.build_trajectory_pool()
        results["nsgf"] = phase1.train()

        self._banner("Phase 2: Training NSF (Neural Straight Flow) model")
        # Phases 2 and 3 integrate the frozen NSGF model with this many steps.
        nsgf_steps = self.config.get("sinkhorn", {}).get("num_steps", 5)
        phase2 = NSFTrainer(
            model=self.nsf_model, nsgf_model=self.nsgf_model,
            data_loader=self.data_loader, config=self.config,
            nsgf_num_steps=nsgf_steps, device=self.device,
        )
        results["nsf"] = phase2.train()

        self._banner("Phase 3: Training phase-transition time predictor")
        phase3 = PhaseTransitionTrainer(
            predictor=self.phase_predictor, nsgf_model=self.nsgf_model,
            data_loader=self.data_loader, config=self.config,
            nsgf_num_steps=nsgf_steps, device=self.device,
        )
        results["phase_predictor"] = phase3.train()

        self._banner("NSGF++ training complete!")
        return results
|
|