# Source path: jsflow/conditional-flow-matching/runner/src/models/components/sinkhorn_knopp_unbalanced.py
# (uploaded by xiangzai, commit 3e4f775: "Add files using upload-large-folder tool")
"""Implements unbalanced sinkhorn knopp optimization for unbalanced ot.
This is adapted from the Python Optimal Transport (POT) package, but modified to take three
regularization parameters instead of two (separate marginal penalties reg_m_1 and reg_m_2 in
addition to the entropic term reg). This is necessary to find growth rates of the source
distribution that best match the target distribution, or vice versa. By setting reg_m_1 to
something low and reg_m_2 to something large, we can compute an unbalanced optimal transport
where all the scaling is done on the source distribution and none is done on the target
distribution.
"""
import warnings
import numpy as np
def _marginal_exponent(reg_m, reg):
    """Return the scaling exponent ``reg_m / (reg_m + reg)`` for one marginal.

    A scalar ``reg_m == +inf`` (hard marginal constraint) maps to the limit value
    1.0, which turns that side's update into the balanced Sinkhorn update.
    Evaluating ``inf / (inf + reg)`` directly would give NaN instead.
    Non-scalar ``reg_m`` values are passed through the plain formula unchanged.
    """
    if np.ndim(reg_m) == 0 and np.isposinf(reg_m):
        return 1.0
    return reg_m / (reg_m + reg)


def sinkhorn_knopp_unbalanced(
    a,
    b,
    M,
    reg,
    reg_m_1,
    reg_m_2,
    numItermax=1000,
    stopThr=1e-6,
    verbose=False,
    log=False,
    **kwargs,
):
    r"""Solve the entropic-regularized unbalanced optimal transport problem.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \langle \gamma, M \rangle_F + \mathrm{reg} \cdot \Omega(\gamma)
            + \mathrm{reg\_m\_1} \cdot \mathrm{KL}(\gamma \mathbf{1}, a)
            + \mathrm{reg\_m\_2} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, b)

        \text{s.t.} \quad \gamma \geq 0

    where:

    - M is the (dim_a, dim_b) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})`
    - a and b are the source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    Two separate marginal penalties let each side be relaxed independently.
    For example, a small ``reg_m_1`` with a large ``reg_m_2`` puts (almost) all
    of the mass rescaling on the source side. Either penalty may be ``np.inf``
    to enforce that marginal exactly.

    The algorithm is the generalized Sinkhorn-Knopp matrix scaling algorithm
    proposed in [10]_, [25]_. With ``K = exp(-M / reg)``, each iteration applies
    ``u = (a / K v) ** (reg_m_1 / (reg_m_1 + reg))`` and then
    ``v = (b / K^T u) ** (reg_m_2 / (reg_m_2 + reg))``.

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized source histogram. If empty, the uniform histogram
        ``1 / dim_a`` is used.
    b : array-like (dim_b,) or (dim_b, n_hists)
        One or several unnormalized target histograms. If empty, the uniform
        histogram ``1 / dim_b`` is used. If 2-D, one problem is solved per
        column ``b[:, k]``.
    M : array-like (dim_a, dim_b)
        Cost matrix.
    reg : float
        Entropy regularization term (> 0).
    reg_m_1 : float
        Marginal relaxation term (> 0) for the source marginal ``a``.
        ``np.inf`` enforces the source marginal exactly.
    reg_m_2 : float
        Marginal relaxation term (> 0) for the target marginal ``b``.
        ``np.inf`` enforces the target marginal exactly.
    numItermax : int, optional
        Maximum number of iterations.
    stopThr : float, optional
        Stop threshold (> 0) on the relative change of the scaling vectors.
        It is checked every 10th iteration.
    verbose : bool, optional
        Print the error along iterations.
    log : bool, optional
        If True, also return a log dictionary.
    **kwargs
        Ignored. Accepted for signature compatibility with POT.

    Returns
    -------
    gamma : ndarray (dim_a, dim_b)
        If ``b`` is 1-D: the optimal transport plan for the given parameters.
    cost : ndarray (n_hists,)
        If ``b`` is 2-D: the transport cost ``sum_ij gamma_ij M_ij`` between
        ``a`` and each histogram ``b[:, k]``. The entropic and KL terms are
        not included.
    log : dict
        Returned only if ``log`` is True. It contains:

        - ``"err"``: the convergence error recorded every 10 iterations
        - ``"logu"`` / ``"logv"``: ``log(u + 1e-16)`` / ``log(v + 1e-16)``
          of the final scaling vectors

    Warns
    -----
    UserWarning
        Issued if ``K^T u`` contains a zero or the scalings become NaN/inf.
        The previous iterate is then kept and iteration stops.

    Examples
    --------
    >>> a = [.5, .5]
    >>> b = [.5, .5]
    >>> M = [[0., 1.], [1., 0.]]
    >>> sinkhorn_knopp_unbalanced(a, b, M, 1., 1., 1.)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])

    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.
    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss, Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced : POT version with a single reg_m
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)
    M = np.asarray(M, dtype=np.float64)

    dim_a, dim_b = M.shape

    # Empty marginals default to uniform histograms.
    if len(a) == 0:
        a = np.ones(dim_a, dtype=np.float64) / dim_a
    if len(b) == 0:
        b = np.ones(dim_b, dtype=np.float64) / dim_b

    # A 2-D ``b`` means several target histograms, handled column-wise.
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if log:
        log = {"err": []}

    if n_hists:
        u = np.ones((dim_a, 1)) / dim_a
        v = np.ones((dim_b, n_hists)) / dim_b
        # Column vector so ``a / Kv`` broadcasts against (dim_a, n_hists).
        a = a.reshape(dim_a, 1)
    else:
        u = np.ones(dim_a) / dim_a
        v = np.ones(dim_b) / dim_b

    # Equivalent to K = np.exp(-M / reg), computed in place to avoid a temporary.
    K = np.empty(M.shape, dtype=M.dtype)
    np.divide(M, -reg, out=K)
    np.exp(K, out=K)

    # The update exponents do not change across iterations. Infinite
    # penalties map to 1.0 (exact marginal) instead of NaN.
    fi_1 = _marginal_exponent(reg_m_1, reg)
    fi_2 = _marginal_exponent(reg_m_2, reg)

    cpt = 0
    err = 1.0

    while err > stopThr and cpt < numItermax:
        uprev = u
        vprev = v

        Kv = K.dot(v)
        u = (a / Kv) ** fi_1
        Ktu = K.T.dot(u)
        v = (b / Ktu) ** fi_2

        if (
            np.any(Ktu == 0.0)
            or np.any(np.isnan(u))
            or np.any(np.isnan(v))
            or np.any(np.isinf(u))
            or np.any(np.isinf(v))
        ):
            # We have reached machine precision:
            # restore the previous solution and leave the loop.
            warnings.warn("Numerical errors at iteration %s" % cpt)
            u = uprev
            v = vprev
            break

        if cpt % 10 == 0:
            # Checking the error only every 10th iteration saves work.
            err_u = abs(u - uprev).max() / max(abs(u).max(), abs(uprev).max(), 1.0)
            err_v = abs(v - vprev).max() / max(abs(v).max(), abs(vprev).max(), 1.0)
            err = 0.5 * (err_u + err_v)
            if log:
                log["err"].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print(f"{cpt:5d}|{err:8e}|")
        cpt += 1

    if log:
        log["logu"] = np.log(u + 1e-16)
        log["logv"] = np.log(v + 1e-16)

    if n_hists:
        # Return only the transport cost <gamma_k, M> for each histogram b[:, k].
        res = np.einsum("ik,ij,jk,ij->k", u, K, v, M)
        if log:
            return res, log
        return res

    # Return the transport plan diag(u) K diag(v).
    if log:
        return u[:, None] * K * v[None, :], log
    return u[:, None] * K * v[None, :]