import math
import warnings
from functools import partial
from typing import Optional

import numpy as np
import ot as pot
import torch


class OTPlanSampler:
    """Sample coordinates according to a squared-L2 optimal transport plan.

    The plan between two minibatches is computed with one of several POT
    solvers, selected by ``method``:

    * ``"exact"``      -- ``ot.emd``
    * ``"sinkhorn"``   -- ``ot.sinkhorn`` with entropic regularisation ``reg``
    * ``"unbalanced"`` -- ``ot.unbalanced.sinkhorn_knopp_unbalanced`` with
      ``reg`` and marginal relaxation ``reg_m``
    * ``"partial"``    -- ``ot.partial.entropic_partial_wasserstein`` with ``reg``
    """

    def __init__(
        self,
        method: str,
        reg: float = 0.05,
        reg_m: float = 1.0,
        normalize_cost=False,
        **kwargs,
    ):
        """Select the OT solver.

        Parameters
        ----------
        method : str
            One of ``"exact"``, ``"sinkhorn"``, ``"unbalanced"``, ``"partial"``.
        reg : float
            Entropic regularisation (ignored by ``"exact"``).
        reg_m : float
            Marginal relaxation, only used by ``"unbalanced"``.
        normalize_cost : bool
            If true, divide the cost matrix by its maximum before solving.
        **kwargs
            Stored on ``self.kwargs``; not passed to the solver here.

        Raises
        ------
        ValueError
            If ``method`` is not one of the supported names.
        """
        # ot_fn should take (a, b, M) as arguments where a, b are marginals and
        # M is a cost matrix
        if method == "exact":
            self.ot_fn = pot.emd
        elif method == "sinkhorn":
            self.ot_fn = partial(pot.sinkhorn, reg=reg)
        elif method == "unbalanced":
            self.ot_fn = partial(pot.unbalanced.sinkhorn_knopp_unbalanced, reg=reg, reg_m=reg_m)
        elif method == "partial":
            self.ot_fn = partial(pot.partial.entropic_partial_wasserstein, reg=reg)
        else:
            raise ValueError(f"Unknown method: {method}")
        self.reg = reg
        self.reg_m = reg_m
        self.normalize_cost = normalize_cost
        self.kwargs = kwargs

    def get_map(self, x0, x1):
        """Compute the OT plan between the rows of ``x0`` and ``x1``.

        Both inputs are flattened to ``(batch, -1)`` and given uniform
        marginals; the cost is the pairwise squared Euclidean distance.

        Parameters
        ----------
        x0 : torch.Tensor
            Source samples; the first axis is the batch axis.
        x1 : torch.Tensor
            Target samples; the first axis is the batch axis.

        Returns
        -------
        numpy.ndarray
            Transport plan of shape ``(x0.shape[0], x1.shape[0])`` as returned
            by the configured solver. If it contains non-finite entries a
            ``RuntimeWarning`` is issued and the plan is returned unchanged.
        """
        a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
        # Flatten every sample to a vector so cdist sees (batch, features);
        # a no-op for 2-D inputs.
        x0 = x0.reshape(x0.shape[0], -1)
        x1 = x1.reshape(x1.shape[0], -1)
        M = torch.cdist(x0, x1) ** 2
        if self.normalize_cost:
            max_cost = M.max()
            # An all-zero cost matrix (e.g. identical batches) would otherwise
            # be turned into NaNs by the division.
            if max_cost > 0:
                M = M / max_cost
        p = self.ot_fn(a, b, M.detach().cpu().numpy())
        if not np.all(np.isfinite(p)):
            warnings.warn(
                "OT plan contains non-finite values "
                f"(cost mean {M.mean().item()}, max {M.max().item()})",
                RuntimeWarning,
                stacklevel=2,
            )
        return p

    def sample_map(self, pi, batch_size):
        """Draw ``batch_size`` index pairs ``(i, j)`` with probability ``pi[i, j]``.

        The plan is renormalised to sum to one and sampled with replacement.

        Returns
        -------
        tuple of numpy.ndarray
            Row indices ``i`` and column indices ``j``.
        """
        p = pi.flatten()
        p = p / p.sum()
        choices = np.random.choice(pi.shape[0] * pi.shape[1], p=p, size=batch_size)
        # Map flat indices back to (row, column) of the plan.
        return np.divmod(choices, pi.shape[1])

    def sample_plan(self, x0, x1):
        """Return ``x0.shape[0]`` pairs ``(x0[i], x1[j])`` drawn from the OT plan."""
        pi = self.get_map(x0, x1)
        i, j = self.sample_map(pi, x0.shape[0])
        return x0[i], x1[j]

    def sample_trajectory(self, X):
        """Chain OT plans between consecutive time slices into trajectories.

        Starting from every row at time 0, the index at each next time step is
        drawn from the corresponding row of the plan between slices ``t`` and
        ``t + 1``.

        Parameters
        ----------
        X : torch.Tensor
            Assumed to be ``[batch, times, dim]``.

        Returns
        -------
        numpy.ndarray
            The reordered slices stacked along axis 1.
        """
        # Assume X is [batch, times, dim]
        times = X.shape[1]
        pis = []
        for t in range(times - 1):
            pis.append(self.get_map(X[:, t], X[:, t + 1]))

        indices = [np.arange(X.shape[0])]
        for pi in pis:
            j = []
            for i in indices[-1]:
                j.append(np.random.choice(pi.shape[1], p=pi[i] / pi[i].sum()))
            indices.append(np.array(j))

        to_return = []
        for t in range(times):
            to_return.append(X[:, t][indices[t]])
        to_return = np.stack(to_return, axis=1)
        return to_return


def wasserstein(
    x0: torch.Tensor,
    x1: torch.Tensor,
    method: Optional[str] = None,
    reg: float = 0.05,
    power: int = 2,
    **kwargs,
) -> float:
    """Wasserstein distance between two empirical samples with uniform weights.

    Parameters
    ----------
    x0, x1 : torch.Tensor
        Samples; the first axis is the batch axis, remaining axes are flattened.
    method : str, optional
        ``"exact"`` / ``None`` uses ``ot.emd2``; ``"sinkhorn"`` uses
        ``ot.sinkhorn2`` with regularisation ``reg``.
    reg : float
        Entropic regularisation for ``"sinkhorn"``.
    power : int
        1 for W1 (Euclidean cost), 2 for W2 (squared cost, square-rooted result).
    **kwargs
        Currently unused.

    Raises
    ------
    AssertionError
        If ``power`` is not 1 or 2.
    ValueError
        If ``method`` is not recognised.
    """
    assert power == 1 or power == 2
    # ot_fn should take (a, b, M) as arguments where a, b are marginals and
    # M is a cost matrix
    if method == "exact" or method is None:
        ot_fn = pot.emd2
    elif method == "sinkhorn":
        ot_fn = partial(pot.sinkhorn2, reg=reg)
    else:
        raise ValueError(f"Unknown method: {method}")

    a, b = pot.unif(x0.shape[0]), pot.unif(x1.shape[0])
    x0 = x0.reshape(x0.shape[0], -1)
    x1 = x1.reshape(x1.shape[0], -1)
    M = torch.cdist(x0, x1)
    if power == 2:
        M = M**2
    # The iteration cap must be an int: the solvers use it as a loop bound
    # (e.g. range(numItermax)), which rejects floats such as 1e7.
    ret = ot_fn(a, b, M.detach().cpu().numpy(), numItermax=int(1e7))
    if power == 2:
        ret = math.sqrt(ret)
    return ret