| """Implements unbalanced sinkhorn knopp optimization for unbalanced ot.
|
|
|
| This is from the package python optimal transport but modified to take three regularization
|
| parameters instead of two. This is necessary to find growth rates of the source distribution that
|
| best match the target distribution or vis versa. by setting reg_m_1 to something low and reg_m_2 to
|
| something large we can compute an unbalanced optimal transport where all the scaling is done on the
|
| source distribution and none is done on the target distribution.
|
| """

import warnings

import numpy as np

def sinkhorn_knopp_unbalanced(
    a,
    b,
    M,
    reg,
    reg_m_1,
    reg_m_2,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    """Solve the entropic regularization unbalanced optimal transport problem.

    The function solves the following optimization problem:

    .. math::
        W = \\min_\\gamma <\\gamma, M>_F + reg \\cdot \\Omega(\\gamma) +
            reg_{m,1} \\cdot KL(\\gamma 1, a) + reg_{m,2} \\cdot KL(\\gamma^T 1, b)

        s.t.
            \\gamma \\geq 0

    where:

    - M is the (dim_a, dim_b) metric cost matrix
    - :math:`\\Omega` is the entropic regularization term
      :math:`\\Omega(\\gamma)=\\sum_{i,j} \\gamma_{i,j}\\log(\\gamma_{i,j})`
    - a and b are the source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in [10, 25]_.

    Parameters
    ----------
    a : np.ndarray (dim_a,)
        Unnormalized histogram of dimension dim_a
    b : np.ndarray (dim_b,) or np.ndarray (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension dim_b.
        If many, compute all the OT distances (a, b_i)
    M : np.ndarray (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m_1 : float
        Marginal relaxation term (> 0) for the source distribution a
    reg_m_2 : float
        Marginal relaxation term (> 0) for the target distribution b
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    if n_hists == 1:
        gamma : (dim_a x dim_b) ndarray
            Optimal transportation matrix for the given parameters
        log : dict
            log dictionary returned only if `log` is `True`
    else:
        ot_distance : (n_hists,) ndarray
            the OT distance between `a` and each of the histograms `b_i`
        log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> sinkhorn_knopp_unbalanced(a, b, M, 1., 1., 1.)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])
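
    With ``reg_m_1`` small and ``reg_m_2`` large, the rescaling is pushed
    onto the source distribution (a sketch; the values below are illustrative
    assumptions, not from the original package):

    >>> plan = sinkhorn_knopp_unbalanced(a, [1., 2.], M, 1., 1e-2, 1e3)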

    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.
    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss, Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """

    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    dim_a, dim_b = M.shape

    # Default to uniform histograms when none are provided.
    if len(a) == 0:
        a = np.ones(dim_a, dtype=np.float64) / dim_a
    if len(b) == 0:
        b = np.ones(dim_b, dtype=np.float64) / dim_b

    # b may hold several histograms as columns; if so, compute all the OT
    # distances (a, b_i) at once.
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if log:
        log = {"err": []}

    # Initialize the scaling vectors uniformly.
    if n_hists:
        u = np.ones((dim_a, 1)) / dim_a
        v = np.ones((dim_b, n_hists)) / dim_b
        a = a.reshape(dim_a, 1)
    else:
        u = np.ones(dim_a) / dim_a
        v = np.ones(dim_b) / dim_b

    # Gibbs kernel K = exp(-M / reg), computed in place.
    K = np.empty(M.shape, dtype=M.dtype)
    np.divide(M, -reg, out=K)
    np.exp(K, out=K)

    cpt = 0
    err = 1.0

    while err > stopThr and cpt < numItermax:
        uprev = u
        vprev = v

        # Generalized Sinkhorn-Knopp updates. Each exponent tends to 1 as its
        # marginal relaxation term grows (enforcing that marginal) and to 0
        # as it vanishes (leaving that marginal free), so reg_m_1 and reg_m_2
        # control how much rescaling falls on a and on b respectively.
        Kv = K.dot(v)
        u = (a / Kv) ** (reg_m_1 / (reg_m_1 + reg))
        Ktu = K.T.dot(u)
        v = (b / Ktu) ** (reg_m_2 / (reg_m_2 + reg))

        # On numerical overflow/underflow, roll back to the last stable
        # iterate and stop.
        if (
            np.any(Ktu == 0.0)
            or np.any(np.isnan(u))
            or np.any(np.isnan(v))
            or np.any(np.isinf(u))
            or np.any(np.isinf(v))
        ):
            warnings.warn("Numerical errors at iteration %s" % cpt)
            u = uprev
            v = vprev
            break
        # Check the relative change of the scaling vectors every 10
        # iterations.
        if cpt % 10 == 0:
            err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.0)
            err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.0)
            err = 0.5 * (err_u + err_v)
            if log:
                log["err"].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print(f"{cpt:5d}|{err:8e}|")
        cpt += 1

    if log:
        log["logu"] = np.log(u + 1e-16)
        log["logv"] = np.log(v + 1e-16)

    if n_hists:
        # Return the OT losses <gamma_k, M>_F, one per histogram b_k.
        res = np.einsum("ik,ij,jk,ij->k", u, K, v, M)
        if log:
            return res, log
        else:
            return res
    else:
        # Return the transport plan gamma = diag(u) @ K @ diag(v).
        if log:
            return u[:, None] * K * v[None, :], log
        else:
            return u[:, None] * K * v[None, :]
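

if __name__ == "__main__":
    # A minimal usage sketch of the asymmetric setting described in the
    # module docstring. The distributions and parameter values below are
    # illustrative assumptions, not part of the original package: reg_m_1 is
    # small, so the source marginal is free to rescale, while reg_m_2 is
    # large, so the target marginal is held close to b.
    a = np.array([0.5, 0.5])
    b = np.array([1.0, 2.0])
    M = np.array([[0.0, 1.0], [1.0, 0.0]])

    plan = sinkhorn_knopp_unbalanced(a, b, M, reg=0.1, reg_m_1=1e-2, reg_m_2=1e3)
    print("transport plan:\n", plan)
    # The column sums should track b closely, while the row sums show the
    # implied growth rates of the source distribution.
    print("source marginal (rescaled):", plan.sum(axis=1))
    print("target marginal (close to b):", plan.sum(axis=0))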