"""sinkhorn_flow.py — Sinkhorn gradient flow and W_ε potential computation.

Core implementation of:
- Sinkhorn divergence computation via GeomLoss
- W_ε-potential gradients (∇f_{μ,μ} and ∇f_{μ,μ*})
- Velocity field: v(x) = ∇f_{μ,μ}(x) - ∇f_{μ,μ*}(x) (Theorem 1, Eq. 10)
- Euler discretization of the Sinkhorn WGF (Algorithm 1)
- Trajectory pool construction for velocity field matching

Reference: arXiv:2401.14069, Sections 4.1 and 4.3, and Appendix A.
"""
|
|
| import torch |
| import torch.nn as nn |
| from typing import List, Tuple, Optional |
| from geomloss import SamplesLoss |
|
|
|
|
class SinkhornPotentialComputer:
    """Computes W_ε-potentials and their gradients using GeomLoss.

    The velocity field of the Sinkhorn WGF is (Theorem 1):
        v(x) = ∇f_{μ,μ}(x) - ∇f_{μ,μ*}(x)

    IMPORTANT: GeomLoss SamplesLoss requires inputs as (N, D) or (B, N, D) tensors.
    For image data (N, C, H, W), we flatten to (N, C*H*W) before calling geomloss,
    then reshape gradients back to (N, C, H, W).

    Args:
        blur: GeomLoss blur parameter (related to ε: ε = blur^p).
        scaling: Multiscale scaling parameter for Sinkhorn iterations.
        p: Cost exponent (default 2 for squared Euclidean).
        backend: GeomLoss backend ('auto', 'tensorized', 'online').
    """

    def __init__(self, blur: float = 0.5, scaling: float = 0.80,
                 p: int = 2, backend: str = "tensorized"):
        self.blur = blur
        self.scaling = scaling
        self.p = p
        self.backend = backend

        # Returns a pair of dual potentials (F, G) instead of a scalar loss;
        # consumed by compute_velocity.
        # NOTE(review): `debias` is left at the SamplesLoss default. Confirm that
        # the potentials returned in that mode are the ones the velocity formula
        # in compute_velocity expects (raw f_{μ,ν} vs. debiased differences).
        self.loss_fn = SamplesLoss(
            loss="sinkhorn", p=p, blur=blur, scaling=scaling,
            backend=backend, potentials=True,
        )
        # Scalar Sinkhorn loss, used only by compute_sinkhorn_divergence.
        self.loss_monitor = SamplesLoss(
            loss="sinkhorn", p=p, blur=blur, scaling=scaling,
            backend=backend, potentials=False,
        )

    def _flatten_if_image(self, X: torch.Tensor) -> Tuple[torch.Tensor, bool, torch.Size]:
        """Flatten (N,C,H,W) → (N,C*H*W) for geomloss.

        ``flatten`` is used instead of ``view`` so that non-contiguous inputs
        (permuted or channels_last tensors) and empty batches (N == 0) work.
        A view is still returned whenever the memory layout allows it.

        Returns:
            (flat_tensor, was_image, original_shape). Tensors that are not 4-D
            are returned unchanged with was_image=False.
        """
        original_shape = X.shape
        if X.dim() == 4:
            return X.flatten(start_dim=1), True, original_shape
        return X, False, original_shape

    def _potential_grad(self, X_flat: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Return ∇_x Σ_i F(x_i), where F is the first output of loss_fn(x, target).

        Args:
            X_flat: Flattened source particles; never modified and never
                receives gradients (a fresh leaf copy is differentiated).
            target: Flattened target samples (treated as constants).

        Returns:
            Detached gradient with the same shape as X_flat.
        """
        X_leaf = X_flat.detach().clone().requires_grad_(True)
        # enable_grad: the velocity must be computable even when the caller
        # runs the flow under torch.no_grad() (e.g. inference / data generation).
        with torch.enable_grad():
            F, _ = self.loss_fn(X_leaf, target)
            (grad,) = torch.autograd.grad(F.sum(), X_leaf)
        return grad.detach()

    def compute_velocity(self, X: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Compute the Sinkhorn WGF velocity field at particles X.

        v(X_i) = ∇f_{μ,μ}(X_i) - ∇f_{μ,μ*}(X_i)

        Handles both point clouds (N,D) and images (N,C,H,W) by flattening
        images before the geomloss calls. Works inside ``torch.no_grad()``
        and never propagates gradients to X or Y.

        Args:
            X: Current particles (samples of μ).
            Y: Target samples (samples of μ*).

        Returns:
            Detached velocity with the same shape as X.
        """
        original_shape = X.shape
        X_flat, is_image, _ = self._flatten_if_image(X.detach())
        Y_flat, _, _ = self._flatten_if_image(Y.detach())

        # Self term: X is the source, a detached copy of X is the target.
        grad_self = self._potential_grad(X_flat, X_flat.clone())
        # Cross term: X is the source, the target samples Y are the target.
        grad_cross = self._potential_grad(X_flat, Y_flat)

        velocity = grad_self - grad_cross
        if is_image:
            velocity = velocity.reshape(original_shape)
        return velocity

    def compute_sinkhorn_divergence(self, X: torch.Tensor, Y: torch.Tensor) -> float:
        """Compute Sinkhorn divergence S_ε(μ, μ*) as a Python float. Handles image tensors."""
        with torch.no_grad():
            X_flat, _, _ = self._flatten_if_image(X)
            Y_flat, _, _ = self._flatten_if_image(Y)
            return self.loss_monitor(X_flat, Y_flat).item()
|
|
|
|
class SinkhornGradientFlow:
    """Implements the discrete Sinkhorn Wasserstein Gradient Flow.

    Evolves particles via explicit Euler steps:
        X^{t+1} = X^t + η * v(X^t)
    where v is supplied by ``potential_computer.compute_velocity(X, Y)``.

    Args:
        potential_computer: Object providing ``compute_velocity(X, Y)``.
        eta: Step size η.
        num_steps: Number of Euler steps performed by ``run_flow``.
    """

    def __init__(self, potential_computer: "SinkhornPotentialComputer",
                 eta: float = 1.0, num_steps: int = 5):
        self.potential_computer = potential_computer
        self.eta = eta
        self.num_steps = num_steps

    def run_flow(self, X0: torch.Tensor, Y: torch.Tensor,
                 store_trajectory: bool = True
                 ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor, int]]]:
        """Run ``num_steps`` Euler steps starting from X0.

        Args:
            X0: Initial particles. Not modified.
            Y: Target samples, passed through to ``compute_velocity``.
            store_trajectory: If True, record a CPU snapshot of every step.

        Returns:
            (X_T, trajectory). X_T holds the final particles, detached, on
            X0's device. trajectory is a list of ``(x_t, v_t, t)`` tuples, one
            per step: x_t is the state *before* step t and v_t is the velocity
            applied to it. The list is empty when ``store_trajectory`` is False.
        """
        trajectory: List[Tuple[torch.Tensor, torch.Tensor, int]] = []
        # Detach up front so the result never carries autograd history, even
        # when num_steps == 0 (each Euler step below works on detached tensors).
        X_t = X0.detach().clone()

        for t in range(self.num_steps):
            v_t = self.potential_computer.compute_velocity(X_t, Y).detach()
            if store_trajectory:
                # copy=True guarantees exactly one independent CPU copy per
                # tensor, whatever the source device (.cpu().clone() copies
                # GPU tensors twice).
                trajectory.append((
                    X_t.to("cpu", copy=True),
                    v_t.to("cpu", copy=True),
                    t,
                ))
            X_t = X_t + self.eta * v_t

        return X_t, trajectory

    def run_flow_no_store(self, X0: torch.Tensor, Y: torch.Tensor) -> torch.Tensor:
        """Run the flow without recording a trajectory and return only X_T."""
        X_T, _ = self.run_flow(X0, Y, store_trajectory=False)
        return X_T
|
|
|
|
class TrajectoryPool:
    """Stores (x, v, t) rows from Sinkhorn gradient flow trajectories.

    Rows are kept in insertion order. Once ``max_size`` rows would be
    exceeded, the oldest rows are evicted first.

    After building, call finalize() to pre-concatenate the tensors so that
    sampling cost no longer depends on the pool size. Without finalize(),
    every sample() call concatenates the whole pool (O(pool_size)).

    Args:
        max_size: Maximum number of rows (individual particles) retained.
    """

    def __init__(self, max_size: int = 1_000_000):
        self.max_size = max_size
        # Per-entry tensors plus one step index per row. finalize() sets all
        # three to None.
        self.x_pool: List[torch.Tensor] = []
        self.v_pool: List[torch.Tensor] = []
        self.t_pool: List[int] = []
        self._size = 0
        self._finalized = False
        # Concatenated copies, populated by finalize().
        self._all_x = None
        self._all_v = None
        self._all_t = None

    def add_trajectory(self, trajectory: List[Tuple[torch.Tensor, torch.Tensor, int]]):
        """Append the (x, v, t) entries of one flow trajectory.

        Each entry contributes ``x.shape[0]`` rows, all tagged with step index
        t. When the pool would exceed ``max_size`` rows, the oldest rows are
        evicted first. An entry that alone has more than ``max_size`` rows
        keeps only its last ``max_size`` rows.

        Raises:
            RuntimeError: If the pool has already been finalized.
        """
        if self._finalized:
            raise RuntimeError("Cannot add to a finalized pool. Create a new pool.")
        capacity = max(int(self.max_size), 0)
        for x, v, t in trajectory:
            n = x.shape[0]
            if n > capacity:
                # Only the newest `capacity` rows of an oversized entry could
                # ever survive eviction, so drop the rest before storing.
                x = x[n - capacity:]
                v = v[n - capacity:]
                n = capacity
            overflow = self._size + n - capacity
            if overflow > 0:
                self._drop_oldest(overflow)
            self.x_pool.append(x)
            self.v_pool.append(v)
            self.t_pool.extend([t] * n)
            self._size += n

    def _drop_oldest(self, n: int):
        """Evict the `n` oldest rows, or every row if fewer than `n` are stored."""
        removed = 0
        while removed < n and self.x_pool:
            batch_size = self.x_pool[0].shape[0]
            still_needed = n - removed
            if batch_size <= still_needed:
                # The whole oldest entry goes.
                self.x_pool.pop(0)
                self.v_pool.pop(0)
                removed += batch_size
            else:
                # Trim only the front of the oldest entry, then stop.
                self.x_pool[0] = self.x_pool[0][still_needed:]
                self.v_pool[0] = self.v_pool[0][still_needed:]
                removed = n
        # t_pool is flat (one entry per row), so trim it once rather than
        # re-slicing the whole list for every evicted entry.
        self.t_pool = self.t_pool[removed:]
        self._size -= removed

    def finalize(self):
        """Pre-concatenate all pool data so sampling cost is independent of pool size.

        Call this once after all trajectories have been added; further calls
        are no-ops. After finalization, sample() only does random indexing
        and add_trajectory() raises.

        Raises:
            RuntimeError: If the pool is empty.
        """
        if self._finalized:
            return
        if self._size == 0:
            raise RuntimeError("Cannot finalize an empty pool.")
        self._all_x = torch.cat(self.x_pool, dim=0)
        self._all_v = torch.cat(self.v_pool, dim=0)
        self._all_t = torch.tensor(self.t_pool, dtype=torch.float32)
        # Drop the pool's references to the per-entry lists.
        self.x_pool = None
        self.v_pool = None
        self.t_pool = None
        self._finalized = True

    def sample(self, batch_size: int, device: str = "cpu"
               ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Draw ``batch_size`` rows uniformly at random, with replacement.

        If finalize() was called, this only indexes the pre-concatenated
        tensors. Otherwise the whole pool is concatenated on every call.

        Returns:
            (x, v, t) on ``device``, with t as float32.

        Raises:
            RuntimeError: If the pool is empty.
        """
        if self._finalized:
            all_x, all_v, all_t = self._all_x, self._all_v, self._all_t
        else:
            if self._size == 0:
                raise RuntimeError("Cannot sample from an empty pool.")
            all_x = torch.cat(self.x_pool, dim=0)
            all_v = torch.cat(self.v_pool, dim=0)
            all_t = torch.tensor(self.t_pool, dtype=torch.float32)
        idx = torch.randint(0, all_x.shape[0], (batch_size,))
        return all_x[idx].to(device), all_v[idx].to(device), all_t[idx].to(device)

    @property
    def size(self) -> int:
        """Number of rows currently stored."""
        return self._size

    def __len__(self) -> int:
        return self._size
|
|