|
|
|
|
|
|
| import torch
|
|
|
# Floor applied to the variance estimate before the square root in
# _mmd2_and_ratio, so the ratio never divides by zero or by the root of a
# negative estimate.
min_var_est = 1e-8
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def linear_mmd2(f_of_X, f_of_Y):
    """Return a linear-kernel MMD^2 estimate built from consecutive rows.

    The element-wise difference ``f_of_X - f_of_Y`` is taken, each row of
    that difference is multiplied with the row that follows it, the products
    are summed over dimension 1, and the per-pair values are averaged.

    Args:
        f_of_X: Feature tensor for the first sample set.
        f_of_Y: Feature tensor for the second sample set; must be
            subtractable from ``f_of_X``.
            NOTE(review): presumably both are (num_samples, num_features);
            ``sum(1)`` requires at least two dimensions -- confirm with callers.

    Returns:
        A scalar tensor with the mean of the consecutive-row inner products.
    """
    diff = f_of_X - f_of_Y
    # Pair row i with row i + 1 so every product uses two different rows.
    consecutive_dots = (diff[:-1] * diff[1:]).sum(1)
    return torch.mean(consecutive_dots)
|
|
|
|
|
|
|
|
|
|
|
|
|
def poly_mmd2(f_of_X, f_of_Y, d=2, alpha=1.0, c=2.0):
    """Return a polynomial-kernel MMD^2 estimate built from consecutive rows.

    For each ordered pair of inputs (X/X, Y/Y, X/Y, Y/X) the kernel value
    ``(alpha * <a_i, b_{i+1}> + c) ** d`` is evaluated between row ``i`` of
    the first tensor and row ``i + 1`` of the second, then averaged. The
    result is ``mean_XX + mean_YY - mean_XY - mean_YX``.

    Args:
        f_of_X: Feature tensor for the first sample set.
        f_of_Y: Feature tensor for the second sample set.
            NOTE(review): presumably both are (num_samples, num_features)
            with equal shapes -- confirm with callers.
        d: Exponent of the polynomial kernel.
        alpha: Scale applied to the inner product.
        c: Offset added before exponentiation.

    Returns:
        A scalar tensor holding the estimate.
    """

    def _mean_kernel(left, right):
        # Row i of `left` against row i + 1 of `right`.
        inner = (left[:-1] * right[1:]).sum(1)
        return torch.mean((alpha * inner + c).pow(d))

    mean_xx = _mean_kernel(f_of_X, f_of_X)
    mean_yy = _mean_kernel(f_of_Y, f_of_Y)
    mean_xy = _mean_kernel(f_of_X, f_of_Y)
    mean_yx = _mean_kernel(f_of_Y, f_of_X)

    return mean_xx + mean_yy - mean_xy - mean_yx
|
|
|
|
|
def _mix_rbf_kernel(X, Y, sigma_list):
    """Evaluate a sum of RBF kernels on the stacked samples and split it.

    ``X`` and ``Y`` are concatenated along dimension 0, pairwise squared
    Euclidean distances are derived from the Gram matrix of the stacked
    rows, and ``exp(-dist / (2 * sigma**2))`` is summed over every bandwidth
    in ``sigma_list``.

    Args:
        X: 2-D tensor of samples (``torch.mm`` requires 2-D input).
        Y: 2-D tensor with the same number of rows as ``X``.
        sigma_list: Iterable of kernel bandwidths. Any iterable is accepted,
            including one-shot generators; it must not be empty.

    Returns:
        Tuple ``(K_XX, K_XY, K_YY, n_sigmas)`` where the first three are the
        X/X, X/Y and Y/Y blocks of the summed kernel matrix and ``n_sigmas``
        is the number of bandwidths used.

    Raises:
        AssertionError: If ``X`` and ``Y`` have different numbers of rows.
        ValueError: If ``sigma_list`` is empty.
    """
    assert X.size(0) == Y.size(0)
    m = X.size(0)

    # Materialise once so the bandwidths can be both iterated and counted,
    # even when the caller passes a generator.
    sigmas = tuple(sigma_list)
    if not sigmas:
        # Without this check K would stay a Python float and the slicing
        # below would fail with an unhelpful TypeError.
        raise ValueError("sigma_list must contain at least one bandwidth")

    Z = torch.cat((X, Y), 0)
    ZZT = torch.mm(Z, Z.t())
    # ||z_i||^2 sits on the Gram diagonal; broadcast it along rows/columns
    # to get ||z_i||^2 - 2 <z_i, z_j> + ||z_j||^2.
    diag_ZZT = torch.diag(ZZT).unsqueeze(1)
    Z_norm_sqr = diag_ZZT.expand_as(ZZT)
    exponent = Z_norm_sqr - 2 * ZZT + Z_norm_sqr.t()

    K = 0.0
    for sigma in sigmas:
        gamma = 1.0 / (2 * sigma**2)
        K += torch.exp(-gamma * exponent)

    return K[:m, :m], K[:m, m:], K[m:, m:], len(sigmas)
|
|
|
|
|
def mix_rbf_mmd2(X, Y, sigma_list, biased=True):
    """Return the MMD^2 estimate between X and Y under a sum of RBF kernels.

    Args:
        X: Samples from the first distribution, passed to ``_mix_rbf_kernel``.
        Y: Samples from the second distribution, passed to ``_mix_rbf_kernel``.
        sigma_list: RBF bandwidths to sum over.
        biased: Forwarded to ``_mmd2`` to choose the biased or unbiased
            estimator.

    Returns:
        Whatever ``_mmd2`` returns for the three kernel blocks.
    """
    # The fourth element (number of bandwidths) is not needed here.
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)[:3]
    return _mmd2(*kernel_blocks, const_diagonal=False, biased=biased)
|
|
|
|
|
def mix_rbf_mmd2_and_ratio(X, Y, sigma_list, biased=True):
    """Return the MMD^2 ratio statistic for X and Y under a sum of RBF kernels.

    Args:
        X: Samples from the first distribution, passed to ``_mix_rbf_kernel``.
        Y: Samples from the second distribution, passed to ``_mix_rbf_kernel``.
        sigma_list: RBF bandwidths to sum over.
        biased: Forwarded to ``_mmd2_and_ratio`` to choose the biased or
            unbiased MMD^2 estimator.

    Returns:
        Whatever ``_mmd2_and_ratio`` returns for the three kernel blocks.
    """
    # The fourth element (number of bandwidths) is not needed here.
    kernel_blocks = _mix_rbf_kernel(X, Y, sigma_list)[:3]
    return _mmd2_and_ratio(*kernel_blocks, const_diagonal=False, biased=biased)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _mmd2(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Compute an MMD^2 estimate from precomputed kernel blocks.

    Args:
        K_XX: Square kernel matrix among the first sample set.
        K_XY: Kernel matrix between the first and second sample sets.
        K_YY: Square kernel matrix among the second sample set.
        const_diagonal: ``False`` to read the diagonals from ``K_XX`` and
            ``K_YY``; any other value is used as the common diagonal entry
            of both matrices.
        biased: If true, diagonal terms are included and every sum is
            divided by ``m * m``; otherwise the within-set sums exclude the
            diagonal and are divided by ``m * (m - 1)``.

    Returns:
        The MMD^2 estimate (a scalar tensor for tensor inputs), where ``m``
        is ``K_XX.size(0)``.
    """
    m = K_XX.size(0)

    if const_diagonal is False:
        diag_X = torch.diag(K_XX)
        diag_Y = torch.diag(K_YY)
        trace_X = torch.sum(diag_X)
        trace_Y = torch.sum(diag_Y)
    else:
        # Caller guarantees a constant diagonal, so skip extracting it.
        diag_X = diag_Y = const_diagonal
        trace_X = trace_Y = m * const_diagonal

    # Row sums with the diagonal removed, then collapsed to totals.
    off_diag_XX = (K_XX.sum(dim=1) - diag_X).sum()
    off_diag_YY = (K_YY.sum(dim=1) - diag_Y).sum()
    cross_total = K_XY.sum(dim=0).sum()

    if biased:
        return (
            (off_diag_XX + trace_X) / (m * m)
            + (off_diag_YY + trace_Y) / (m * m)
            - 2.0 * cross_total / (m * m)
        )

    return (
        off_diag_XX / (m * (m - 1))
        + off_diag_YY / (m * (m - 1))
        - 2.0 * cross_total / (m * m)
    )
|
|
|
|
|
def _mmd2_and_ratio(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Return MMD^2 divided by the square root of its variance estimate.

    Args:
        K_XX: Square kernel matrix among the first sample set.
        K_XY: Kernel matrix between the first and second sample sets.
        K_YY: Square kernel matrix among the second sample set.
        const_diagonal: Forwarded to ``_mmd2_and_variance``.
        biased: Forwarded to ``_mmd2_and_variance``.

    Returns:
        Tuple ``(ratio, mmd2, var_est)``. ``ratio`` is
        ``mmd2 / sqrt(max(var_est, min_var_est))``; the other two are the
        unmodified outputs of ``_mmd2_and_variance``.
    """
    mmd2, var_est = _mmd2_and_variance(
        K_XX, K_XY, K_YY, const_diagonal=const_diagonal, biased=biased
    )
    # Clamp from below so the square root and the division stay finite.
    floored_var = torch.clamp(var_est, min=min_var_est)
    ratio = mmd2 / torch.sqrt(floored_var)
    return ratio, mmd2, var_est
|
|
|
|
|
def _mmd2_and_variance(K_XX, K_XY, K_YY, const_diagonal=False, biased=False):
    """Compute an MMD^2 estimate together with a variance estimate.

    Args:
        K_XX: Square kernel matrix among the first sample set.
        K_XY: Kernel matrix between the first and second sample sets.
        K_YY: Square kernel matrix among the second sample set.
        const_diagonal: ``False`` to read the diagonals from ``K_XX`` and
            ``K_YY``; any other value is used as the common diagonal entry
            of both matrices.
        biased: Selects the biased or unbiased MMD^2 formula. The variance
            estimate is computed the same way in both cases.

    Returns:
        Tuple ``(mmd2, var_est)``.
        NOTE(review): the variance expression is reproduced term by term;
        its derivation is not shown in this file -- confirm against the
        reference the formula was taken from.
    """
    m = K_XX.size(0)

    if const_diagonal is False:
        diag_X = torch.diag(K_XX)
        diag_Y = torch.diag(K_YY)
        sum_diag_X = torch.sum(diag_X)
        sum_diag_Y = torch.sum(diag_Y)
        sum_diag2_X = diag_X.dot(diag_X)
        sum_diag2_Y = diag_Y.dot(diag_Y)
    else:
        # Caller guarantees a constant diagonal, so skip extracting it.
        diag_X = diag_Y = const_diagonal
        sum_diag_X = sum_diag_Y = m * const_diagonal
        sum_diag2_X = sum_diag2_Y = m * const_diagonal**2

    # Row/column sums; the "Kt" quantities have the diagonal removed.
    Kt_XX_sums = K_XX.sum(dim=1) - diag_X
    Kt_YY_sums = K_YY.sum(dim=1) - diag_Y
    K_XY_sums_0 = K_XY.sum(dim=0)
    K_XY_sums_1 = K_XY.sum(dim=1)

    Kt_XX_sum = Kt_XX_sums.sum()
    Kt_YY_sum = Kt_YY_sums.sum()
    K_XY_sum = K_XY_sums_0.sum()

    # Sums of squared entries, again without the diagonal for XX and YY.
    Kt_XX_2_sum = (K_XX**2).sum() - sum_diag2_X
    Kt_YY_2_sum = (K_YY**2).sum() - sum_diag2_Y
    K_XY_2_sum = (K_XY**2).sum()

    if biased:
        mmd2 = (
            (Kt_XX_sum + sum_diag_X) / (m * m)
            + (Kt_YY_sum + sum_diag_Y) / (m * m)
            - 2.0 * K_XY_sum / (m * m)
        )
    else:
        mmd2 = (
            Kt_XX_sum / (m * (m - 1))
            + Kt_YY_sum / (m * (m - 1))
            - 2.0 * K_XY_sum / (m * m)
        )

    # The variance estimate is a signed sum of six terms; each is named here
    # and they are combined below in the same left-to-right order.
    within_rows_term = 2.0 / (m**2 * (m - 1.0) ** 2) * (
        2 * Kt_XX_sums.dot(Kt_XX_sums)
        - Kt_XX_2_sum
        + 2 * Kt_YY_sums.dot(Kt_YY_sums)
        - Kt_YY_2_sum
    )
    within_total_term = (
        (4.0 * m - 6.0) / (m**3 * (m - 1.0) ** 3) * (Kt_XX_sum**2 + Kt_YY_sum**2)
    )
    cross_rows_term = (
        4.0
        * (m - 2.0)
        / (m**3 * (m - 1.0) ** 2)
        * (K_XY_sums_1.dot(K_XY_sums_1) + K_XY_sums_0.dot(K_XY_sums_0))
    )
    cross_squares_term = 4.0 * (m - 3.0) / (m**3 * (m - 1.0) ** 2) * (K_XY_2_sum)
    cross_total_term = (8 * m - 12) / (m**5 * (m - 1)) * K_XY_sum**2
    mixed_term = 8.0 / (m**3 * (m - 1.0)) * (
        1.0 / m * (Kt_XX_sum + Kt_YY_sum) * K_XY_sum
        - Kt_XX_sums.dot(K_XY_sums_1)
        - Kt_YY_sums.dot(K_XY_sums_0)
    )

    var_est = (
        within_rows_term
        - within_total_term
        + cross_rows_term
        - cross_squares_term
        - cross_total_term
        + mixed_term
    )
    return mmd2, var_est
|
|
|