| """Tests for Conditional Flow Matcher classes."""
|
|
|
|
|
|
|
| import math
|
|
|
| import numpy as np
|
| import pytest
|
| import torch
|
|
|
| from torchcfm.conditional_flow_matching import (
|
| ConditionalFlowMatcher,
|
| ExactOptimalTransportConditionalFlowMatcher,
|
| SchrodingerBridgeConditionalFlowMatcher,
|
| TargetConditionalFlowMatcher,
|
| VariancePreservingConditionalFlowMatcher,
|
| pad_t_like_x,
|
| )
|
| from torchcfm.optimal_transport import OTPlanSampler
|
|
|
# Seed fed to both the torch and numpy generators so draws can be replayed.
TEST_SEED = 1994

# Number of samples per batch used throughout the tests.
TEST_BATCH_SIZE = 128

# Per-method predicate on sigma; when it holds, the test expects building the
# flow matcher to raise ``ValueError`` instead of running the comparison.
SIGMA_CONDITION = {"sb_cfm": lambda sigma: sigma <= 0}
|
|
|
|
|
def random_samples(shape, batch_size=TEST_BATCH_SIZE):
    """Draw two independent standard-normal batches sharing one shape.

    Args:
        shape: Per-sample dimensions, either a single ``int`` or an iterable
            of ints.
        batch_size: Size of the leading (batch) dimension.

    Returns:
        A two-element list of tensors, each of shape ``(batch_size, *shape)``.
    """
    # A bare int is treated as a one-dimensional sample shape.
    dims = [shape] if isinstance(shape, int) else shape
    return [torch.randn(batch_size, *dims) for _ in range(2)]
|
|
|
|
|
def compute_xt_ut(method, x0, x1, t_given, sigma, epsilon):
    """Recompute the reference location ``x_t`` and target flow ``u_t``.

    Every branch builds ``x_t = mu_t + sigma_t * epsilon`` and then the
    method-specific ``u_t``. The arithmetic is deliberately kept in a fixed
    operation order because the caller compares the results with exact
    equality.

    Args:
        method: One of ``"vp_cfm"``, ``"t_cfm"``, ``"sb_cfm"``,
            ``"exact_ot_cfm"`` or ``"i_cfm"``.
        x0: Source samples.
        x1: Target samples.
        t_given: Time values; must be a tensor that broadcasts against
            ``x0``/``x1`` (it is passed to ``torch.cos``/``torch.sqrt``).
        sigma: Noise scale (scalar or something broadcastable with ``x0``).
        epsilon: Noise sample added to the mean, scaled by ``sigma_t``.

    Returns:
        Tuple ``(computed_xt, computed_ut)``.

    Raises:
        ValueError: If ``method`` is not one of the supported names.
    """
    if method == "vp_cfm":
        sigma_t = sigma
        mu_t = torch.cos(math.pi / 2 * t_given) * x0 + torch.sin(math.pi / 2 * t_given) * x1
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = (
            math.pi
            / 2
            * (torch.cos(math.pi / 2 * t_given) * x1 - torch.sin(math.pi / 2 * t_given) * x0)
        )
    elif method == "t_cfm":
        sigma_t = 1 - (1 - sigma) * t_given
        mu_t = t_given * x1
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = (x1 - (1 - sigma) * computed_xt) / sigma_t
    elif method == "sb_cfm":
        sigma_t = sigma * torch.sqrt(t_given * (1 - t_given))
        mu_t = t_given * x1 + (1 - t_given) * x0
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = (
            (1 - 2 * t_given)
            # The 1e-8 keeps the ratio finite when t is exactly 0 or 1.
            / (2 * t_given * (1 - t_given) + 1e-8)
            * (computed_xt - mu_t)
            + x1
            - x0
        )
    elif method in ["exact_ot_cfm", "i_cfm"]:
        sigma_t = sigma
        mu_t = t_given * x1 + (1 - t_given) * x0
        computed_xt = mu_t + sigma_t * epsilon
        computed_ut = x1 - x0
    else:
        # Fail loudly instead of hitting an UnboundLocalError at the return.
        raise ValueError(f"Unknown flow matching method: {method!r}")

    return computed_xt, computed_ut
|
|
|
|
|
def get_flow_matcher(method, sigma):
    """Build the flow matcher instance that corresponds to ``method``.

    Args:
        method: One of ``"vp_cfm"``, ``"t_cfm"``, ``"sb_cfm"``,
            ``"exact_ot_cfm"`` or ``"i_cfm"``.
        sigma: Noise scale forwarded to the matcher constructor.

    Returns:
        The constructed flow matcher.

    Raises:
        ValueError: If ``method`` is not one of the supported names. The
            matcher constructors themselves may also raise (the test relies
            on this for ``"sb_cfm"`` with a non-positive ``sigma``).
    """
    if method == "vp_cfm":
        return VariancePreservingConditionalFlowMatcher(sigma=sigma)
    if method == "t_cfm":
        return TargetConditionalFlowMatcher(sigma=sigma)
    if method == "sb_cfm":
        return SchrodingerBridgeConditionalFlowMatcher(sigma=sigma, ot_method="sinkhorn")
    if method == "exact_ot_cfm":
        return ExactOptimalTransportConditionalFlowMatcher(sigma=sigma)
    if method == "i_cfm":
        return ConditionalFlowMatcher(sigma=sigma)
    # Fail loudly instead of hitting an UnboundLocalError on an unknown name.
    raise ValueError(f"Unknown flow matching method: {method!r}")
|
|
|
|
|
def sample_plan(method, x0, x1, sigma):
    """Re-pair ``x0`` and ``x1`` through an OT plan for the OT-based methods.

    Args:
        method: Flow matching method name.
        x0: Source samples.
        x1: Target samples.
        sigma: Noise scale; only used by ``"sb_cfm"``, where the sampler is
            built with ``reg=2 * sigma**2``.

    Returns:
        Tuple ``(x0, x1)``: resampled by ``OTPlanSampler`` for ``"sb_cfm"``
        (sinkhorn) and ``"exact_ot_cfm"`` (exact), untouched for any other
        method.
    """
    if method == "sb_cfm":
        sampler = OTPlanSampler(method="sinkhorn", reg=2 * (sigma**2))
    elif method == "exact_ot_cfm":
        sampler = OTPlanSampler(method="exact")
    else:
        # Methods without an OT coupling keep the original pairing.
        return x0, x1

    paired_x0, paired_x1 = sampler.sample_plan(x0, x1)
    return paired_x0, paired_x1
|
|
|
|
|
@pytest.mark.parametrize("method", ["vp_cfm", "t_cfm", "sb_cfm", "exact_ot_cfm", "i_cfm"])
@pytest.mark.parametrize("sigma", [0.0, 5e-4, 0.5, 1.5, 0, 1])
@pytest.mark.parametrize("shape", [[1], [2], [1, 2], [3, 4, 5]])
def test_fm(method, sigma, shape):
    """Check each flow matcher against an independent recomputation.

    The matcher is sampled under a fixed seed; the same random draws are then
    replayed by re-seeding, and ``x_t``, ``u_t``, the noise and the sampled
    times are compared element-wise with exact equality.
    """
    batch_size = TEST_BATCH_SIZE

    # Some (method, sigma) pairs are invalid: construction itself must raise.
    if method in SIGMA_CONDITION and SIGMA_CONDITION[method](sigma):
        with pytest.raises(ValueError):
            get_flow_matcher(method, sigma)
        return

    flow_matcher = get_flow_matcher(method, sigma)
    x0, x1 = random_samples(shape, batch_size=batch_size)
    torch.manual_seed(TEST_SEED)
    np.random.seed(TEST_SEED)
    t, xt, ut, eps = flow_matcher.sample_location_and_conditional_flow(
        x0, x1, return_noise=True
    )
    # Smoke check only: compute_lambda must accept the sampled times.
    _ = flow_matcher.compute_lambda(t)

    if method in ["sb_cfm", "exact_ot_cfm"]:
        # Replay the plan sampling under the same seeds so the reference uses
        # the same (x0, x1) pairing the matcher used.
        torch.manual_seed(TEST_SEED)
        np.random.seed(TEST_SEED)
        x0, x1 = sample_plan(method, x0, x1, sigma)

    # Re-seed torch so the reference time and noise draws replay the matcher's.
    torch.manual_seed(TEST_SEED)
    t_given_init = torch.rand(batch_size)
    # Add trailing singleton dims so the times broadcast against x0.
    t_given = t_given_init.reshape(-1, *([1] * (x0.dim() - 1)))
    sigma_pad = pad_t_like_x(sigma, x0)
    epsilon = torch.randn_like(x0)
    computed_xt, computed_ut = compute_xt_ut(method, x0, x1, t_given, sigma_pad, epsilon)

    assert torch.all(ut.eq(computed_ut))
    assert torch.all(xt.eq(computed_xt))
    assert torch.all(eps.eq(epsilon))
    # Every sampled time must match, not merely one of them.
    assert torch.all(t.eq(t_given_init))
|
|
|