import copy
import math
import os
from typing import Any, List, Optional, Union
import numpy as np
import torch
from pytorch_lightning import LightningDataModule, LightningModule
from torch.distributions import MultivariateNormal
from torchdyn.core import NeuralODE
from torchvision import transforms
from .components.augmentation import (
AugmentationModule,
AugmentedVectorField,
Sequential,
)
from .components.distribution_distances import compute_distribution_distances
from .components.optimal_transport import OTPlanSampler
from .components.plotting import (
plot_samples,
plot_trajectory,
store_trajectories,
)
from .components.schedule import ConstantNoiseScheduler, NoiseScheduler
from .components.solver import FlowSolver
from .utils import get_wandb_logger
class CFMLitModule(LightningModule):
"""Conditional Flow Matching Module for training generative models and models over time."""
def __init__(
self,
net: Any,
optimizer: Any,
datamodule: LightningDataModule,
augmentations: AugmentationModule,
partial_solver: FlowSolver,
scheduler: Optional[Any] = None,
neural_ode: Optional[Any] = None,
ot_sampler: Optional[Union[str, Any]] = None,
sigma_min: float = 0.1,
avg_size: int = -1,
leaveout_timepoint: int = -1,
test_nfe: int = 100,
plot: bool = False,
nice_name: str = "CFM",
) -> None:
"""Initialize a conditional flow matching network either as a generative model or for a
sequence of timepoints.
Note: DDP does not currently work with NeuralODE objects from torchdyn
in the init so we initialize them every time we need to do a sampling
step.
Args:
            net: partial torch module representing dx/dt = f(t, x) for t in [1, T]; the dimension argument is supplied here.
            optimizer: partial torch.optim optimizer missing its parameters.
            datamodule: datamodule object; must expose "dim" (vector data) or "dims" (image data), and optionally "IS_TRAJECTORY".
            ot_sampler: OT sampler specified as an object or string. If None, no OT resampling is used within the minibatch.
            sigma_min: width of the Gaussian smoothing of the data and interpolations.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase.
            plot: if true, log intermediate plots during validation.
"""
super().__init__()
self.save_hyperparameters(
ignore=[
"net",
"optimizer",
"scheduler",
"datamodule",
"augmentations",
"partial_solver",
],
logger=False,
)
self.datamodule = datamodule
self.is_trajectory = False
if hasattr(datamodule, "IS_TRAJECTORY"):
self.is_trajectory = datamodule.IS_TRAJECTORY
        # The datamodule exposes either an integer `dim` (vector data) or a tuple
        # `dims` (image data); this decides whether batches are processed as
        # vectors or as images.
if hasattr(datamodule, "dim"):
self.dim = datamodule.dim
self.is_image = False
elif hasattr(datamodule, "dims"):
self.dim = datamodule.dims
self.is_image = True
else:
raise NotImplementedError("Datamodule must have either dim or dims")
self.net = net(dim=self.dim)
self.augmentations = augmentations
self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
self.val_augmentations = AugmentationModule(
# cnf_estimator=None,
l1_reg=1,
l2_reg=1,
squared_l2_reg=1,
)
self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
if neural_ode is not None:
self.aug_node = Sequential(
self.augmentations.augmenter,
neural_ode(self.aug_net),
)
self.partial_solver = partial_solver
self.optimizer = optimizer
self.scheduler = scheduler
self.ot_sampler = ot_sampler
if ot_sampler == "None":
self.ot_sampler = None
if isinstance(self.ot_sampler, str):
            # entropic regularization chosen so the minibatch OT plan matches the Schrodinger bridge coupling
self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)
self.criterion = torch.nn.MSELoss()
    def forward_integrate(self, batch: Any, t_span: torch.Tensor):
        """Forward pass with integration over t_span intervals.

        (t, x, t_span) -> [x_t_span].
        """
        X = self.unpack_batch(batch)
        X_start = X[:, t_span[0], :]
        # Construct the NeuralODE on the fly rather than caching it on the module
        # (see the DDP note in the class docstring).
        node = NeuralODE(self.net)
        traj = node.trajectory(X_start, t_span=t_span)
        return traj
def forward(self, t: torch.Tensor, x: torch.Tensor):
"""Forward pass (t, x) -> dx/dt."""
return self.net(t, x)
def unpack_batch(self, batch):
"""Unpacks a batch of data to a single tensor."""
if self.is_trajectory:
return torch.stack(batch, dim=1)
if not isinstance(self.dim, int):
# Assume this is an image classification dataset where we need to strip the targets
return batch[0]
return batch
def preprocess_batch(self, X, training=False):
"""Converts a batch of data into matched a random pair of (x0, x1)"""
t_select = torch.zeros(1, device=X.device)
if self.is_trajectory:
batch_size, times, dim = X.shape
if not hasattr(self.datamodule, "HAS_JOINT_PLANS"):
                # resample the OT plan
                # list of length (times - 1); each entry stacks a matched
                # (x_t, x_{t+1}) pair of tensors of shape [batch_size, dim]
tmp_ot_list = []
for t in range(times - 1):
if training and t + 1 == self.hparams.leaveout_timepoint:
tmp_ot = torch.stack((X[:, t], X[:, t + 2]))
else:
tmp_ot = torch.stack((X[:, t], X[:, t + 1]))
if (
training
and self.ot_sampler is not None
and t != self.hparams.leaveout_timepoint
):
tmp_ot = torch.stack(self.ot_sampler.sample_plan(tmp_ot[0], tmp_ot[1]))
tmp_ot_list.append(tmp_ot)
tmp_ot_list = torch.stack(tmp_ot_list)
# randomly sample a batch
if training and self.hparams.leaveout_timepoint > 0:
                # Select a random timepoint index, skipping the left-out timepoint
t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
t_select[t_select >= self.hparams.leaveout_timepoint] += 1
else:
t_select = torch.randint(times - 1, size=(batch_size,))
x0 = []
x1 = []
for i in range(batch_size):
ti = t_select[i]
ti_next = ti + 1
if training and ti_next == self.hparams.leaveout_timepoint:
ti_next += 1
if hasattr(self.datamodule, "HAS_JOINT_PLANS"):
x0.append(torch.tensor(self.datamodule.timepoint_data[ti][X[i, ti]]))
pi = self.datamodule.pi[ti]
if training and ti + 1 == self.hparams.leaveout_timepoint:
pi = self.datamodule.pi_leaveout[ti]
index_batch = X[i][ti]
i_next = np.random.choice(
pi.shape[1], p=pi[index_batch] / pi[index_batch].sum()
)
x1.append(torch.tensor(self.datamodule.timepoint_data[ti_next][i_next]))
else:
x0.append(tmp_ot_list[ti][0][i])
x1.append(tmp_ot_list[ti][1][i])
x0, x1 = torch.stack(x0), torch.stack(x1)
else:
batch_size = X.shape[0]
            # If not a trajectory dataset, generate x0 from a standard normal
x0 = torch.randn_like(X)
x1 = X
return x0, x1, t_select
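    # When avg_size > 0, the conditional targets below are averaged over a random
    # subset of conditioning pairs to approximate the marginal drift:
    #   u_t(x) ~= sum_j p_t(x | z_j) u_t(x | z_j) / sum_j p_t(x | z_j),
    # where p_t(x | z_j) is an isotropic Gaussian centered at mu_t(z_j). Note the
    # pairwise density matrix must be computed as torch.cdist(x, mu_t) in that
    # order, so that rows index the samples x.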
def average_ut(self, x, t, mu_t, sigma_t, ut):
pt = torch.exp(-0.5 * (torch.cdist(x, mu_t) ** 2) / (sigma_t**2))
batch_size = x.shape[0]
        ind = torch.randint(
            batch_size, size=(batch_size, self.hparams.avg_size - 1)
        )  # randomly sample avg_size - 1 indices (with replacement)
# always include self
ind = torch.cat([ind, torch.arange(batch_size)[:, None]], dim=1)
pt_sub = torch.stack([pt[i, ind[i]] for i in range(batch_size)])
ut_sub = torch.stack([ut[ind[i]] for i in range(batch_size)])
p_sum = torch.sum(pt_sub, dim=1, keepdim=True)
ut = torch.sum(pt_sub[:, :, None] * ut_sub, dim=1) / p_sum
# Reduce batch size because they are all the same
return x[:1], ut[:1], t[:1]
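    # Subclasses override calc_mu_sigma / calc_u to change the conditional
    # probability path N(mu_t(x0, x1), sigma_t^2 I) and its target drift
    # u_t(x | x0, x1); the defaults below give the straight-line path
    # mu_t = t * x1 + (1 - t) * x0 with constant width sigma_min.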
def calc_mu_sigma(self, x0, x1, t):
mu_t = t * x1 + (1 - t) * x0
sigma_t = self.hparams.sigma_min
return mu_t, sigma_t
def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
del x, t, mu_t, sigma_t
return x1 - x0
def calc_loc_and_target(self, x0, x1, t, t_select, training):
"""Computes the loss on a batch of data."""
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
eps_t = torch.randn_like(mu_t)
x = mu_t + sigma_t * eps_t
ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)
# if we are starting from right before the leaveout_timepoint then we
# divide the target by 2
if training and self.hparams.leaveout_timepoint > 0:
ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2
        # t that the network sees is offset by the starting timepoint
t = t + t_select.reshape(-1, *t.shape[1:])
return x, ut, t, mu_t, sigma_t, eps_t
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
# Either randomly sample a single T or sample a batch of T's
if self.hparams.avg_size > 0:
t = torch.rand(1).repeat(X.shape[0]).type_as(X)
else:
t = torch.rand(X.shape[0]).type_as(X)
# Resample the plan if we are using optimal transport
if self.ot_sampler is not None and not self.is_trajectory:
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
x, ut, t, mu_t, sigma_t, eps_t = self.calc_loc_and_target(x0, x1, t, t_select, training)
if self.hparams.avg_size > 0:
x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)
aug_x = self.aug_net(t, x, augmented_input=False)
reg, vt = self.augmentations(aug_x)
return torch.mean(reg), self.criterion(vt, ut)
def training_step(self, batch: Any, batch_idx: int):
reg, mse = self.step(batch, training=True)
loss = mse + reg
prefix = "train"
self.log_dict(
{f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg},
on_step=True,
on_epoch=False,
prog_bar=True,
)
return loss
def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
import os
from torchvision.utils import save_image
# val_augmentations = AugmentationModule(
# cnf_estimator="hutch",
# squared_l2_reg=1,
# )
# aug_dims = val_augmentations.aug_dims
# val_aug_net = AugmentedVectorField(self.net, val_augmentations.regs, self.dim)
# val_aug_node = Sequential(
# val_augmentations.augmenter,
# NeuralODE(val_aug_net, solver="euler", sensitivity="adjoint"),
# )
# t_span = torch.linspace(1, 0, 101)
# x = batch[0]
# os.makedirs("regularizations", exist_ok=True)
# for k in range(0):
# x_norm = cifar10_normalization()(x + (torch.rand_like(x) / 255))
# _, aug_traj = val_aug_node(x_norm, t_span)
# aug, traj = aug_traj[-1, :, :aug_dims], aug_traj[-1, :, aug_dims:]
# mn = MultivariateNormal(
# torch.zeros(prod(self.dim)).type_as(traj),
# torch.eye(prod(self.dim)).type_as(traj),
# )
# aug[:, 0] += mn.log_prob(traj.reshape(traj.shape[0], -1))
# np.save(
# f"regularizations/regs_{k}_{batch_idx}.npy",
# aug.detach().cpu().numpy(),
# )
solver = self.partial_solver(self.net, self.dim)
if isinstance(self.hparams.test_nfe, int):
t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1)
elif isinstance(self.hparams.test_nfe, str):
solver.ode_solver = "tsit5"
t_span = torch.linspace(0, 1, 2)
else:
raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}")
traj = solver.odeint(torch.randn(batch[0].shape[0], *self.dim).type_as(batch[0]), t_span)[
-1
]
os.makedirs("images", exist_ok=True)
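        # Undo the standard CIFAR-10 channel normalization (mean ~ [125.3, 123.0,
        # 113.9], std ~ [63.0, 62.1, 66.7] on the 0-255 scale) so that sampled
        # images land back in [0, 1] before saving.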
mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]]
std = [255.0 / x for x in [63.0, 62.1, 66.7]]
inv_normalize = transforms.Compose(
[
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std),
transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]),
]
)
traj = inv_normalize(traj)
traj = torch.clip(traj, min=0, max=1.0)
for i, image in enumerate(traj):
save_image(image, fp=f"images/{batch_idx}_{i}.png")
return {"x": batch[0]}
def eval_step(self, batch: Any, batch_idx: int, prefix: str):
if prefix == "test" and self.is_image:
self.image_eval_step(batch, batch_idx, prefix)
shapes = [b.shape[0] for b in batch]
if not self.is_image and prefix == "val" and shapes.count(shapes[0]) == len(shapes):
reg, mse = self.step(batch, training=False)
loss = mse + reg
self.log_dict(
{f"{prefix}/loss": loss, f"{prefix}/mse": mse, f"{prefix}/reg": reg},
on_step=False,
on_epoch=True,
sync_dist=True,
)
return {"loss": loss, "mse": mse, "reg": reg, "x": self.unpack_batch(batch)}
return {"x": batch}
def preprocess_epoch_end(self, outputs: List[Any], prefix: str):
"""Preprocess the outputs of the epoch end function."""
if self.is_trajectory and prefix == "test" and isinstance(outputs[0]["x"], list):
# x is jagged if doing a trajectory
x = outputs[0]["x"]
ts = len(x)
x0 = x[0]
x_rest = x[1:]
elif self.is_trajectory:
if hasattr(self.datamodule, "HAS_JOINT_PLANS"):
x = [torch.tensor(dd) for dd in self.datamodule.timepoint_data]
x0 = x[0]
x_rest = x[1:]
ts = len(x)
else:
v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]}
x = v["x"]
ts = x.shape[1]
x0 = x[:, 0, :]
x_rest = x[:, 1:]
else:
if isinstance(self.dim, int):
v = {k: torch.cat([d[k] for d in outputs]) for k in ["x"]}
x = v["x"]
else:
                x = outputs[0]["x"][0][:100]
# Sample some random points for the plotting function
rand = torch.randn_like(x)
# rand = torch.randn_like(x, generator=torch.Generator(device=x.device).manual_seed(42))
x = torch.stack([rand, x], dim=1)
ts = x.shape[1]
x0 = x[:, 0]
x_rest = x[:, 1:]
return ts, x, x0, x_rest
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101)
regs = []
trajs = []
full_trajs = []
solver = self.partial_solver(self.net, self.dim)
nfe = 0
x0_tmp = x0.clone()
if self.is_image:
traj = solver.odeint(x0, t_span)
full_trajs.append(traj)
trajs.append(traj[0])
trajs.append(traj[-1])
nfe += solver.nfe
if not self.is_image:
solver.augmentations = self.val_augmentations
for i in range(ts - 1):
traj, aug = solver.odeint(x0_tmp, t_span + i)
full_trajs.append(traj)
traj, aug = traj[-1], aug[-1]
x0_tmp = traj
regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
trajs.append(traj)
nfe += solver.nfe
full_trajs = torch.cat(full_trajs)
if not self.is_image:
regs = np.stack(regs).mean(axis=0)
names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
self.log_dict(dict(zip(names, regs)), sync_dist=True)
# Evaluate the fit
if (
self.is_trajectory
and prefix == "test"
and isinstance(outputs[0]["x"], list)
and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
):
# Redo the solver for each timepoint
trajs = []
full_trajs = []
nfe = 0
x0_tmp = x0
for i in range(ts - 1):
traj, _ = solver.odeint(x0_tmp, t_span + i)
traj = traj[-1]
x0_tmp = x_rest[i]
trajs.append(traj)
nfe += solver.nfe
names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
else:
names, dists = compute_distribution_distances(trajs, x_rest)
names = [f"{prefix}/{name}" for name in names]
d = dict(zip(names, dists))
if self.hparams.leaveout_timepoint >= 0:
to_add = {
f"{prefix}/t_out/{key.split('/')[-1]}": val
for key, val in d.items()
if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
}
d.update(to_add)
d[f"{prefix}/nfe"] = nfe
self.log_dict(d, sync_dist=True)
if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
solver.augmentations = None
# t_span = torch.linspace(0, 1, 101)
# traj = solver.odeint(x0, t_span)
# t_span = t_span[::5]
# traj = traj[::5]
t_span = torch.linspace(0, 1, 21)
traj = solver.odeint(x0, t_span)
assert traj.shape[0] == t_span.shape[0]
kls = [
self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
]
self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
return trajs, full_trajs
def eval_epoch_end(self, outputs: List[Any], prefix: str):
wandb_logger = get_wandb_logger(self.loggers)
if prefix == "test" and self.is_image:
os.makedirs("images", exist_ok=True)
if len(os.listdir("images")) > 0:
path = "/home/mila/a/alexander.tong/scratch/trajectory-inference/data/fid_stats_cifar10_train.npz"
from pytorch_fid import fid_score
fid = fid_score.calculate_fid_given_paths(["images", path], 256, "cuda", 2048, 0)
self.log(f"{prefix}/fid", fid)
ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
trajs, full_trajs = self.forward_eval_integrate(ts, x0, x_rest, outputs, prefix)
if self.hparams.plot:
if isinstance(self.dim, int):
plot_trajectory(
x,
full_trajs,
title=f"{self.current_epoch}_ode",
key="ode_path",
wandb_logger=wandb_logger,
)
else:
plot_samples(
trajs[-1],
title=f"{self.current_epoch}_samples",
wandb_logger=wandb_logger,
)
if prefix == "test" and not self.is_image:
store_trajectories(x, self.net)
def validation_step(self, batch: Any, batch_idx: int):
return self.eval_step(batch, batch_idx, "val")
def validation_epoch_end(self, outputs: List[Any]):
self.eval_epoch_end(outputs, "val")
def test_step(self, batch: Any, batch_idx: int):
return self.eval_step(batch, batch_idx, "test")
def test_epoch_end(self, outputs: List[Any]):
self.eval_epoch_end(outputs, "test")
def configure_optimizers(self):
"""Pass model parameters to optimizer."""
optimizer = self.optimizer(params=self.parameters())
if self.scheduler is None:
return optimizer
scheduler = self.scheduler(optimizer)
return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]
def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
scheduler.step(epoch=self.current_epoch)
class RectifiedFlowLitModule(CFMLitModule):
def __init__(
self,
net: Any,
optimizer: Any,
datamodule: LightningDataModule,
augmentations: AugmentationModule,
partial_solver: FlowSolver,
val_augmentations: Optional[AugmentationModule] = None,
scheduler: Optional[Any] = None,
neural_ode: Optional[Any] = None,
ot_sampler: Optional[Union[str, Any]] = None,
sigma_min: float = 0.1,
rectify_epochs: Optional[List[int]] = None,
test_nfe: int = 100,
avg_size: int = -1,
leaveout_timepoint: int = -1,
plot: bool = False,
nice_name: str = "Rect",
) -> None:
"""Initialize a conditional flow matching network either as a generative model or for a
sequence of timepoints.
Args:
            net: partial torch module representing dx/dt = f(t, x) for t in [1, T]; the dimension argument is supplied here.
            optimizer: partial torch.optim optimizer missing its parameters.
            datamodule: datamodule object; must expose "dim" (vector data) or "dims" (image data), and optionally "IS_TRAJECTORY".
            ot_sampler: OT sampler specified as an object or string. If None, no OT resampling is used within the minibatch.
            sigma_min: width of the Gaussian smoothing of the data and interpolations.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase.
            plot: if true, log intermediate plots during validation.
"""
super(CFMLitModule, self).__init__()
self.save_hyperparameters(
ignore=[
"net",
"optimizer",
"scheduler",
"datamodule",
"augmentations",
"val_augmentations",
"partial_solver",
],
logger=False,
)
self.datamodule = datamodule
self.is_trajectory = False
if hasattr(datamodule, "IS_TRAJECTORY"):
self.is_trajectory = datamodule.IS_TRAJECTORY
if hasattr(datamodule, "dim"):
self.dim = datamodule.dim
self.is_image = False
elif hasattr(datamodule, "dims"):
self.dim = datamodule.dims
self.is_image = True
else:
raise NotImplementedError("Datamodule must have either dim or dims")
self.net = net(dim=self.dim)
self.frozen_net = None
self.augmentations = augmentations
self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
self.val_augmentations = val_augmentations
if val_augmentations is None:
self.val_augmentations = AugmentationModule(
l1_reg=1,
l2_reg=1,
squared_l2_reg=1,
)
self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
if neural_ode is not None:
self.aug_node = Sequential(
self.augmentations.augmenter,
neural_ode(self.aug_net),
)
self.partial_solver = partial_solver
self.optimizer = optimizer
self.scheduler = scheduler
self.ot_sampler = ot_sampler
if ot_sampler == "None":
self.ot_sampler = None
if isinstance(self.ot_sampler, str):
            # entropic regularization chosen so the minibatch OT plan matches the Schrodinger bridge coupling
self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * sigma_min**2)
self.criterion = torch.nn.MSELoss()
def preprocess_batch(self, X, training=False):
"""Converts a batch of data into matched a random pair of (x0, x1)"""
t_select = torch.zeros(1, device=X.device)
if self.is_trajectory:
batch_size, times, dim = X.shape
if training and self.hparams.leaveout_timepoint > 0:
                # Select a random timepoint index, skipping the left-out timepoint
t_select = torch.randint(times - 2, size=(batch_size,), device=X.device)
t_select[t_select >= self.hparams.leaveout_timepoint] += 1
else:
t_select = torch.randint(times - 1, size=(batch_size,))
x0 = []
x1 = []
for i in range(batch_size):
ti = t_select[i]
ti_next = ti + 1
if training and ti_next == self.hparams.leaveout_timepoint:
ti_next += 1
x0.append(X[i, ti])
x1.append(X[i, ti_next])
x0, x1 = torch.stack(x0), torch.stack(x1)
else:
batch_size = X.shape[0]
            # If not a trajectory dataset, generate x0 from a standard normal
x0 = torch.randn_like(X)
x1 = X
if self.frozen_net is not None:
# Currently only works for 2 distributions
assert t_select[0] == 0
t_span = torch.linspace(0, 1, 100)
val_node = NeuralODE(self.frozen_net, solver="euler")
with torch.no_grad():
_, traj = val_node(x0, t_span)
x1 = traj[-1]
return x0, x1, t_select
def training_epoch_end(self, training_step_outputs):
if (
self.hparams.rectify_epochs is not None
and self.current_epoch in self.hparams.rectify_epochs
):
self.frozen_net = copy.deepcopy(self.net)
class ActionMatchingLitModule(CFMLitModule):
"""Implements Action Matching: Learning Stochastic Dynamics from Samples (Neklyudov et al.
2022)
Requires net to have a .energy function where net.energy(t, x): \\mathbb{R}^{d+1} \to
\\mathbb{R} and net.forward is equal to \nabla_x(net.energy).
"""
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
assert not self.is_trajectory
energy = self.net.energy
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
if self.ot_sampler is not None:
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
t = torch.rand(X.shape[0]).type_as(X)
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
xt = t_xshape * x1 + (1 - t_xshape) * x0
# t that network sees is incremented by first timepoint
t = t + t_select.reshape(-1, *t.shape[1:])
xt.requires_grad, t_xshape.requires_grad = True, True
with torch.set_grad_enabled(True):
st = torch.sum(energy(torch.cat([xt, t_xshape], dim=-1)))
dsdx, dsdt = torch.autograd.grad(st, (xt, t_xshape), create_graph=True)
xt.requires_grad, t_xshape.requires_grad = False, False
a0 = energy(torch.cat([x0, torch.zeros(x0.shape[0], 1)], dim=-1))
a1 = energy(torch.cat([x1, torch.ones(x1.shape[0], 1)], dim=-1))
            loss = a0 - a1 + 0.5 * (dsdx**2).sum(1, keepdim=True) + dsdt
loss = loss.mean()
aug_x = self.aug_net(t, xt, augmented_input=False)
reg, vt = self.augmentations(aug_x)
return torch.mean(reg), loss
class VariancePreservingCFM(CFMLitModule):
"""Implements a variance preserving time schedule as suggested in (Albergo et al.
2023) here we have an interpolation cos(t pi/2) x_0 + sin(t pi/2) x_1.
"""
def calc_mu_sigma(self, x0, x1, t):
assert not self.is_trajectory
mu_t = torch.cos(math.pi / 2 * t) * x0 + torch.sin(math.pi / 2 * t) * x1
sigma_t = self.hparams.sigma_min
return mu_t, sigma_t
def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
del x, mu_t, sigma_t
return math.pi / 2 * (torch.cos(math.pi / 2 * t) * x1 - torch.sin(math.pi / 2 * t) * x0)
class SBCFMLitModule(CFMLitModule):
"""Implements a Schrodinger Bridge based conditional flow matching model.
This is similar to the OTCFM loss, however with the variance varying with t*(1-t). This has
provably equal probability flow to the Schrodinger bridge solution when the transport is
computed with the squared Euclidean distance on R^d.
"""
def calc_mu_sigma(self, x0, x1, t):
assert not self.is_trajectory
mu_t = t * x1 + (1 - t) * x0
sigma_t = self.hparams.sigma_min * torch.sqrt(t - t**2)
return mu_t, sigma_t
def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
del sigma_t
sigma_t_prime_over_sigma_t = (1 - 2 * t) / (2 * t * (1 - t))
ut = sigma_t_prime_over_sigma_t * (x - mu_t) + x1 - x0
return ut
class SF2MLitModule(CFMLitModule):
def __init__(
self,
net: Any,
optimizer: Any,
datamodule: LightningDataModule,
augmentations: AugmentationModule,
partial_solver: FlowSolver,
score_net: Optional[Any] = None,
scheduler: Optional[Any] = None,
ot_sampler: Optional[Union[str, Any]] = None,
sigma: Optional[NoiseScheduler] = None,
sigma_min: float = 0.1,
outer_loop_epochs: Optional[int] = None,
score_weight: float = 1.0,
avg_size: int = -1,
leaveout_timepoint: int = -1,
test_nfe: int = 100,
test_sde: bool = False,
plot: bool = False,
nice_name: Optional[str] = "SF2M",
) -> None:
"""Initialize a conditional flow matching network either as a generative model or for a
sequence of timepoints.
Args:
            net: partial torch module representing dx/dt = f(t, x) for t in [1, T]; the dimension argument is supplied here.
            score_net: torch module representing the score function of the flow.
                If not supplied it is assumed that net outputs both the flow and
                the score, stacked along the feature dimension.
            optimizer: partial torch.optim optimizer missing its parameters.
            datamodule: datamodule object; must expose "dim" (vector data) or "dims" (image data), and optionally "IS_TRAJECTORY".
            ot_sampler: OT sampler specified as an object or string. If None, no OT resampling is used within the minibatch.
            sigma: noise scheduler determining the width of the Gaussian smoothing of the data and interpolations.
            leaveout_timepoint: which (if any) timepoint to leave out during the training phase.
            plot: if true, log intermediate plots during validation.
"""
super(CFMLitModule, self).__init__()
self.save_hyperparameters(
ignore=[
"net",
"optimizer",
"scheduler",
"datamodule",
"augmentations",
"sigma_scheduler",
"partial_solver",
],
logger=False,
)
self.datamodule = datamodule
self.is_trajectory = False
if hasattr(datamodule, "IS_TRAJECTORY"):
self.is_trajectory = datamodule.IS_TRAJECTORY
        # The datamodule exposes either an integer `dim` (vector data) or a tuple
        # `dims` (image data); this decides whether batches are processed as
        # vectors or as images.
if hasattr(datamodule, "dim"):
self.dim = datamodule.dim
self.is_image = False
elif hasattr(datamodule, "dims"):
self.dim = datamodule.dims
self.is_image = True
else:
raise NotImplementedError("Datamodule must have either dim or dims")
self.net = net(dim=self.dim)
self.separate_score = score_net is not None
self.score_net = score_net
if self.separate_score:
self.score_net = score_net(dim=self.dim)
self.partial_solver = partial_solver
self.augmentations = augmentations
self.aug_net = AugmentedVectorField(self.net, self.augmentations.regs, self.dim)
self.val_augmentations = AugmentationModule(
# cnf_estimator=None,
l1_reg=1,
l2_reg=1,
squared_l2_reg=1,
)
self.val_aug_net = AugmentedVectorField(self.net, self.val_augmentations.regs, self.dim)
self.optimizer = optimizer
self.scheduler = scheduler
self.sigma = sigma
if sigma is None:
self.sigma = ConstantNoiseScheduler(sigma_min)
self.ot_sampler = ot_sampler
if ot_sampler == "None":
self.ot_sampler = None
if isinstance(self.ot_sampler, str):
            # entropic regularization chosen so the minibatch OT plan matches the Schrodinger bridge coupling
self.ot_sampler = OTPlanSampler(method=ot_sampler, reg=2 * self.sigma.F(1))
self.criterion = torch.nn.MSELoss()
# If we are doing outer loops holds the current dataset
self.stored_data = None
self.tmp_stored_data = None
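    # Brownian-bridge interpolant under the noise schedule sigma(t): with
    # F(t) = int_0^t sigma(s)^2 ds, the bridge from x0 to x1 has
    #   mu_t      = x0 + (x1 - x0) * F(t) / F(1)
    #   sigma_t^2 = F(t) - F(t)^2 / F(1),
    # which reduces to sigma^2 * t * (1 - t) for a constant sigma.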
def calc_mu_sigma(self, x0, x1, t):
# assert not self.is_trajectory
ft = self.sigma.F(t)
fone = self.sigma.F(1)
mu_t = x0 + (x1 - x0) * ft / fone
# Note this is slightly different than the notebook. Which is correct?
sigma_t = torch.sqrt(ft - ft**2 / fone)
return mu_t, sigma_t
def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
ft = self.sigma.F(t)
fone = self.sigma.F(1)
sigma_t_prime = self.sigma(t) ** 2 - 2 * ft * self.sigma(t) ** 2 / fone
sigma_t_prime_over_sigma_t = sigma_t_prime / (sigma_t + 1e-8)
mu_t_prime = (x1 - x0) * self.sigma(t) ** 2 / fone
ut = sigma_t_prime_over_sigma_t * (x - mu_t) + mu_t_prime
return ut
def calc_loc_and_target(self, x0, x1, t, t_select, training):
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
mu_t, sigma_t = self.calc_mu_sigma(x0, x1, t_xshape)
eps_t = torch.randn_like(mu_t)
x = mu_t + sigma_t * eps_t
ut = self.calc_u(x0, x1, x, t_xshape, mu_t, sigma_t)
# if we are starting from right before the leaveout_timepoint then we
# divide the target by 2
if training and self.hparams.leaveout_timepoint > 0:
ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2
        # t that the network sees is offset by the starting timepoint
score_target = eps_t
# score_target = -eps_t * self.sigma(t_xshape) ** 2 / 2
t = t + t_select.reshape(-1, *t.shape[1:])
return x, ut, t, mu_t, sigma_t, score_target
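    # The score head st is trained (in step() below) so that
    # -2 * sigma_t * st / sigma(t)^2 matches eps_t; equivalently, st approximates
    # (sigma(t)^2 / 2) * grad_x log p_t(x | x0, x1) = -(sigma(t)^2 / 2) * eps_t / sigma_t.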
def forward_flow_and_score(self, t, x):
if self.separate_score:
reg, vt = self.augmentations(self.aug_net(t, x, augmented_input=False))
st = self.score_net(t, x)
return reg, vt, st
reg, vtst = self.augmentations(self.aug_net(t, x, augmented_input=False))
split_idx = vtst.shape[1] // 2
vt, st = vtst[:, :split_idx], vtst[:, split_idx:]
return reg, vt, st
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
# Either randomly sample a single T or sample a batch of T's
if self.hparams.avg_size > 0:
t = torch.rand(1).repeat(X.shape[0]).type_as(X)
else:
t = torch.rand(X.shape[0]).type_as(X)
# Resample the plan if we are using optimal transport
if self.ot_sampler is not None and self.stored_data is None:
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
t_orig = t.clone()
x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
x0, x1, t, t_select, training
)
if self.hparams.avg_size > 0:
x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)
reg, vt, st = self.forward_flow_and_score(t, x)
flow_loss = self.criterion(vt, ut)
score_loss = self.criterion(
-sigma_t * st / (self.sigma(t_orig.reshape(sigma_t.shape)) ** 2) * 2,
score_target,
)
return torch.mean(reg) + self.hparams.score_weight * score_loss, flow_loss
def forward_sde_eval(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 2)
solver = self.partial_solver(
self.net, self.dim, score_field=self.score_net, sigma=self.sigma
)
trajs = []
full_trajs = []
nfe = 0
kldiv_total = 0
x0_tmp = x0.clone()
for i in range(ts - 1):
traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
kldiv_total += torch.mean(kldiv[-1])
x0_tmp = traj[-1]
trajs.append(traj[-1])
full_trajs.append(traj)
nfe += solver.nfe
full_trajs = torch.cat(full_trajs)
if not self.is_image:
# Evaluate the fit
if (
self.is_trajectory
and prefix == "test"
and isinstance(outputs[0]["x"], list)
and not hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM")
):
trajs = []
full_trajs = []
nfe = 0
kldiv_total = 0
x0_tmp = x0.clone()
for i in range(ts - 1):
traj, kldiv = solver.sdeint(x0_tmp, t_span + i, logqp=True)
x0_tmp = x_rest[i]
kldiv_total += torch.mean(kldiv[-1])
trajs.append(traj[-1])
full_trajs.append(traj)
nfe += solver.nfe
names, dists = compute_distribution_distances(trajs[:-1], x_rest[:-1])
else:
names, dists = compute_distribution_distances(trajs, x_rest)
names = [f"{prefix}/sde/{name}" for name in names]
d = dict(zip(names, dists))
if self.hparams.leaveout_timepoint >= 0:
to_add = {
f"{prefix}/sde/t_out/{key.split('/')[-1]}": val
for key, val in d.items()
if key.startswith(f"{prefix}/sde/t{self.hparams.leaveout_timepoint}")
}
d.update(to_add)
d[f"{prefix}/sde/nfe"] = nfe
d[f"{prefix}/sde/kldiv"] = kldiv_total
self.log_dict(d, sync_dist=True)
if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
solver.augmentations = None
t_span = torch.linspace(0, 1, 21)
solver.dt = 0.05
# solver.dt = 0.01
traj = solver.sdeint(x0, t_span)
assert traj.shape[0] == t_span.shape[0]
kls = [
self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
]
self.log_dict(
{f"{prefix}/sde/kl/mean": torch.stack(kls).mean().item()},
sync_dist=True,
)
self.log_dict({f"{prefix}/sde/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
return trajs, full_trajs
def eval_epoch_end(self, outputs: List[Any], prefix: str):
super().eval_epoch_end(outputs, prefix)
wandb_logger = get_wandb_logger(self.loggers)
ts, x, x0, x_rest = self.preprocess_epoch_end(outputs, prefix)
if isinstance(self.dim, int):
traj, sde_traj = self.forward_sde_eval(ts, x0, x_rest, outputs, prefix)
if self.hparams.plot:
if isinstance(self.dim, int):
plot_trajectory(
x,
sde_traj,
title=f"{self.current_epoch}_sde_traj",
key="sde",
wandb_logger=wandb_logger,
)
def preprocess_batch(self, X, training=False):
"""Converts a batch of data into matched a random pair of (x0, x1)"""
if self.stored_data is not None and training:
# Randomly sample a batch from the stored data.
idx = torch.randint(self.stored_data.shape[0], size=(X.shape[0],))
X = self.stored_data[idx]
t_select = torch.zeros(1, device=X.device)
return X[:, 0], X[:, 1], t_select
return super().preprocess_batch(X, training)
def training_step(self, batch: Any, batch_idx: int):
        # If we are doing outer loops, resample and store forward and backward batches.
if (
self.hparams.outer_loop_epochs is not None
and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0
):
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training=True)
assert not torch.any(t_select) # resampling outerloop can only handle 2 timepoints
t_span = torch.linspace(0, 1, 2)
solver = self.partial_solver(
self.net, self.dim, score_field=self.score_net, sigma=self.sigma
)
batch_size = x0.shape[0]
with torch.no_grad():
forward_traj = solver.sdeint(x0[: batch_size // 2], t_span)
backward_traj = torch.flip(
solver.sdeint(x1[batch_size // 2 :], t_span, reverse=True), (0,)
)
stored_traj = torch.cat([forward_traj, backward_traj], dim=1)
stored_traj = stored_traj.transpose(0, 1)
if batch_idx == 0:
self.tmp_stored_data = []
self.tmp_stored_data.append(stored_traj)
return super().training_step(batch, batch_idx)
def training_epoch_end(self, training_step_outputs):
if (
self.hparams.outer_loop_epochs is not None
and (self.current_epoch + 1) % self.hparams.outer_loop_epochs == 0
):
self.stored_data = torch.cat(self.tmp_stored_data, dim=0).detach().clone()
def image_eval_step(self, batch: Any, batch_idx: int, prefix: str):
import os
from torchvision.utils import save_image
solver = self.partial_solver(self.net, self.dim)
if isinstance(self.hparams.test_nfe, int):
t_span = torch.linspace(0, 1, int(self.hparams.test_nfe) + 1)
elif isinstance(self.hparams.test_nfe, str):
solver.ode_solver = "tsit5"
t_span = torch.linspace(0, 1, 2).type_as(batch[0])
else:
raise NotImplementedError(f"Unknown test procedure {self.hparams.test_nfe}")
if self.hparams.test_sde:
solver = self.partial_solver(
self.net, self.dim, score_field=self.score_net, sigma=self.sigma
)
solver.dt = 1 / int(self.hparams.test_nfe)
t_span = torch.linspace(0, 1, 2).type_as(batch[0])
integrator = solver.sdeint
else:
integrator = solver.odeint
x0 = torch.randn(5 * batch[0].shape[0], *self.dim).type_as(batch[0])
traj = integrator(x0, t_span)[-1]
os.makedirs("images", exist_ok=True)
mean = [-x / 255.0 for x in [125.3, 123.0, 113.9]]
std = [255.0 / x for x in [63.0, 62.1, 66.7]]
inv_normalize = transforms.Compose(
[
transforms.Normalize(mean=[0.0, 0.0, 0.0], std=std),
transforms.Normalize(mean=mean, std=[1.0, 1.0, 1.0]),
]
)
traj = inv_normalize(traj)
traj = torch.clip(traj, min=0, max=1.0)
for i, image in enumerate(traj):
save_image(image, fp=f"images/{batch_idx}_{i}.png")
os.makedirs("compressed_images", exist_ok=True)
torch.save(traj.cpu(), f"compressed_images/{batch_idx}.pt")
return {"x": batch[0]}
class OneWaySF2MLitModule(SF2MLitModule):
def calc_loc_and_target(self, x0, x1, t, t_select, training):
x, ut, t, mu_t, sigma_t, score_target = super().calc_loc_and_target(
x0, x1, t, t_select, training
)
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
eps_t = -score_target * 2 / (self.sigma(t_xshape) ** 2)
forward_target = (
x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t
)
return x, forward_target, t, mu_t, sigma_t, None
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
# Either randomly sample a single T or sample a batch of T's
if self.hparams.avg_size > 0:
t = torch.rand(1).repeat(X.shape[0]).type_as(X)
else:
t = torch.rand(X.shape[0]).type_as(X)
# Resample the plan if we are using optimal transport
if self.ot_sampler is not None and self.stored_data is None:
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
x, forward_target, t, _, _, _ = self.calc_loc_and_target(x0, x1, t, t_select, training)
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1)))
forward_scaling = (1 + self.sigma(t_xshape) ** 2 * t_xshape / (1 - t_xshape + 1e-6)) ** -1
reg, vt, st = self.forward_flow_and_score(t, x)
forward_flow_loss = torch.mean(forward_scaling * (vt - forward_target) ** 2)
return torch.mean(reg), forward_flow_loss
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101).type_as(x0)
regs = []
trajs = []
full_trajs = []
solver = self.partial_solver(
self.net, self.dim, score_field=self.score_net, sigma=self.sigma
)
nfe = 0
x0_tmp = x0.clone()
for i in range(ts - 1):
if not self.is_image:
solver.augmentations = self.val_augmentations
traj, aug = solver.sdeint(x0_tmp, t_span + i)
aug = aug[-1]
regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
else:
traj = solver.sdeint(x0_tmp, t_span + i)
full_trajs.append(traj)
traj = traj[-1]
x0_tmp = traj
trajs.append(traj)
nfe += solver.nfe
if not self.is_image:
regs = np.stack(regs).mean(axis=0)
names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
self.log_dict(dict(zip(names, regs)), sync_dist=True)
# Evaluate the fit
names, dists = compute_distribution_distances(trajs, x_rest)
names = [f"{prefix}/{name}" for name in names]
d = dict(zip(names, dists))
if self.hparams.leaveout_timepoint >= 0:
to_add = {
f"{prefix}/t_out/{key.split('/')[-1]}": val
for key, val in d.items()
if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
}
d.update(to_add)
d[f"{prefix}/nfe"] = nfe
self.log_dict(d, sync_dist=True)
if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
solver.augmentations = None
t_span = torch.linspace(0, 1, 21) # 101
traj = solver.odeint(x0, t_span)
# t_span = t_span[::5]
# traj = traj[::5]
assert traj.shape[0] == t_span.shape[0]
kls = [
self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
]
# others = torch.stack([self.datamodule.detailed_evaluation(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)])
self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
full_trajs = torch.cat(full_trajs)
return trajs, full_trajs
class DSBMLitModule(SF2MLitModule):
"""Based on SF2M module except directly regresses against the target SDE drift rather than
separating the ODE and Score components."""
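    # Bridge-matching targets: for x_t = mu_t + sigma_t * eps on the Brownian
    # bridge (constant sigma), the drift of the bridge toward x1 is
    #   (x1 - x_t) / (1 - t) = x1 - x0 - sigma * sqrt(t / (1 - t)) * eps,
    # and symmetrically (x0 - x_t) / t gives the backward drift toward x0.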
def calc_loc_and_target(self, x0, x1, t, t_select, training):
t_xshape = t.reshape(-1, *([1] * (x0.dim() - 1))).clone()
x, ut, t_plus_t_select, mu_t, sigma_t, eps_t = super().calc_loc_and_target(
x0, x1, t, t_select, training
)
forward_target = (
x1 - x0 - (self.sigma(t_xshape) * torch.sqrt(t_xshape / (1 - t_xshape + 1e-6))) * eps_t
)
backward_target = (
x0
- x1
- (self.sigma(t_xshape) * torch.sqrt((1 - t_xshape) / (t_xshape + 1e-6))) * eps_t
)
return x, forward_target, t_plus_t_select, mu_t, sigma_t, backward_target
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
# Either randomly sample a single T or sample a batch of T's
if self.hparams.avg_size > 0:
t = torch.rand(1).repeat(X.shape[0]).type_as(X)
else:
t = torch.rand(X.shape[0]).type_as(X)
# Resample the plan if we are using optimal transport
if self.ot_sampler is not None and self.stored_data is None:
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
forward_scaling = (1 + self.sigma(t) ** 2 * t / (1 - t + 1e-6)) ** -1
backward_scaling = (1 + self.sigma(t) ** 2 * (1 - t) / (t + 1e-6)) ** -1
x, forward_target, t, _, _, backward_target = self.calc_loc_and_target(
x0, x1, t, t_select, training
)
# print(forward_target, backward_target, x0, x1, t, t_select)
reg, vt, st = self.forward_flow_and_score(t, x)
forward_flow_loss = torch.mean(forward_scaling[:, None] * (vt - forward_target) ** 2)
backward_flow_loss = torch.mean(backward_scaling[:, None] * (st - backward_target) ** 2)
if not torch.isfinite(forward_flow_loss) or not torch.isfinite(backward_flow_loss):
raise ValueError("Loss Not Finite")
return torch.mean(reg) + backward_flow_loss, forward_flow_loss
def forward_eval_integrate(self, ts, x0, x_rest, outputs, prefix):
# Build a trajectory
t_span = torch.linspace(0, 1, 101)
regs = []
trajs = []
full_trajs = []
solver = self.partial_solver(
self.net, self.dim, score_field=self.score_net, sigma=self.sigma
)
nfe = 0
x0_tmp = x0.clone()
for i in range(ts - 1):
if not self.is_image:
solver.augmentations = self.val_augmentations
traj, aug = solver.odeint(x0_tmp, t_span + i)
else:
traj = solver.odeint(x0_tmp, t_span + i)
full_trajs.append(traj)
if not self.is_image:
traj, aug = traj[-1], aug[-1]
else:
traj = traj[-1]
aug = torch.tensor(0.0)
x0_tmp = traj
regs.append(torch.mean(aug, dim=0).detach().cpu().numpy())
trajs.append(traj)
nfe += solver.nfe
if not self.is_image:
regs = np.stack(regs).mean(axis=0)
names = [f"{prefix}/{name}" for name in self.val_augmentations.names]
self.log_dict(dict(zip(names, regs)), sync_dist=True)
# Evaluate the fit
names, dists = compute_distribution_distances(trajs, x_rest)
names = [f"{prefix}/{name}" for name in names]
d = dict(zip(names, dists))
if self.hparams.leaveout_timepoint >= 0:
to_add = {
f"{prefix}/t_out/{key.split('/')[-1]}": val
for key, val in d.items()
if key.startswith(f"{prefix}/t{self.hparams.leaveout_timepoint}")
}
d.update(to_add)
d[f"{prefix}/nfe"] = nfe
self.log_dict(d, sync_dist=True)
if hasattr(self.datamodule, "GAUSSIAN_CLOSED_FORM"):
solver.augmentations = None
t_span = torch.linspace(0, 1, 21) # 101
traj = solver.odeint(x0, t_span)
# t_span = t_span[::5]
# traj = traj[::5]
assert traj.shape[0] == t_span.shape[0]
kls = [
self.datamodule.KL(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)
]
# others = torch.stack([self.datamodule.detailed_evaluation(xt, self.hparams.sigma_min, t) for t, xt in zip(t_span, traj)])
self.log_dict({f"{prefix}/kl/mean": torch.stack(kls).mean().item()}, sync_dist=True)
self.log_dict({f"{prefix}/kl/tp_{i}": kls[i] for i in range(21)}, sync_dist=True)
full_trajs = torch.cat(full_trajs)
return trajs, full_trajs
class DSBMSharedLitModule(SF2MLitModule):
"""Based on SF2M module except directly regresses against the target SDE drift rather than
separating the ODE and Score components."""
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
X = self.unpack_batch(batch)
x0, x1, t_select = self.preprocess_batch(X, training)
# Either randomly sample a single T or sample a batch of T's
if self.hparams.avg_size > 0:
t = torch.rand(1).repeat(X.shape[0]).type_as(X)
else:
t = torch.rand(X.shape[0]).type_as(X)
# Resample the plan if we are using optimal transport
if self.ot_sampler is not None:
x0, x1 = self.ot_sampler.sample_plan(x0, x1)
x, ut, t, mu_t, sigma_t, score_target = self.calc_loc_and_target(
x0, x1, t, t_select, training
)
if self.hparams.avg_size > 0:
x, ut, t = self.average_ut(x, t, mu_t, sigma_t, ut)
aug_x = self.aug_net(t, x, augmented_input=False)
reg, vt = self.augmentations(aug_x)
forward_flow_loss = self.criterion(vt + sigma_t * self.score_net(t, x), ut + score_target)
backward_flow_loss = self.criterion(
-vt + sigma_t * self.score_net(t, x), -ut + score_target
)
# flow_loss = self.criterion(vt + sigma_t * self.score_net, ut + score_target)
# score_loss = self.criterion(sigma_t * self.score_net(t, x), score_target)
return torch.mean(reg) + backward_flow_loss, forward_flow_loss
class FMLitModule(CFMLitModule):
"""Implements a Lipman et al.
2023 style flow matching loss. This maps the standard normal distribution to the data
distribution by using conditional flows that are the optimal transport flow from a narrow
Gaussian around a datapoint to a standard N(x | 0, 1).
"""
def calc_mu_sigma(self, x0, x1, t):
assert not self.is_trajectory
del x0
sigma_min = self.hparams.sigma_min
mu_t = t * x1
sigma_t = 1 - (1 - sigma_min) * t
return mu_t, sigma_t
def calc_u(self, x0, x1, x, t, mu_t, sigma_t):
del x0, mu_t, sigma_t
sigma_min = self.hparams.sigma_min
ut = (x1 - (1 - sigma_min) * x) / (1 - (1 - sigma_min) * t)
return ut
class SplineCFMLitModule(CFMLitModule):
"""Implements cubic spline version of OT-CFM."""
    def preprocess_batch(self, X, training=False):
        """Converts a batch of data into a matched random pair (x0, x1)."""
        from torchcubicspline import NaturalCubicSpline, natural_cubic_spline_coeffs
lotp = self.hparams.leaveout_timepoint
valid_times = torch.arange(X.shape[1]).type_as(X)
t_select = torch.zeros(1)
batch_size, times, dim = X.shape
# TODO handle leaveout case
if training and self.hparams.leaveout_timepoint > 0:
            # Select a random timepoint index, skipping the left-out timepoint
t_select = torch.randint(times - 2, size=(batch_size,))
X = torch.cat([X[:, :lotp], X[:, lotp + 1 :]], dim=1)
valid_times = valid_times[valid_times != lotp]
else:
t_select = torch.randint(times - 1, size=(batch_size,))
traj = torch.from_numpy(self.ot_sampler.sample_trajectory(X)).type_as(X)
x0 = []
x1 = []
for i in range(batch_size):
x0.append(traj[i, t_select[i]])
x1.append(traj[i, t_select[i] + 1])
x0, x1 = torch.stack(x0), torch.stack(x1)
if training and self.hparams.leaveout_timepoint > 0:
t_select[t_select >= self.hparams.leaveout_timepoint] += 1
coeffs = natural_cubic_spline_coeffs(valid_times, traj)
spline = NaturalCubicSpline(coeffs)
return x0, x1, t_select, spline
def step(self, batch: Any, training: bool = False):
"""Computes the loss on a batch of data."""
assert self.is_trajectory
X = self.unpack_batch(batch)
x0, x1, t_select, spline = self.preprocess_batch(X, training)
t = torch.rand(X.shape[0], 1)
# t [batch, 1]
# coeffs [batch, times, dims]
# t that network sees is incremented by first timepoint
t = t + t_select[:, None]
ut = torch.stack([spline.derivative(b[0])[i] for i, b in enumerate(t)], dim=0)
mu_t = torch.stack([spline.evaluate(b[0])[i] for i, b in enumerate(t)], dim=0)
sigma_t = self.hparams.sigma_min
# if we are starting from right before the leaveout_timepoint then we
# divide the target by 2
if training and self.hparams.leaveout_timepoint > 0:
ut[t_select + 1 == self.hparams.leaveout_timepoint] /= 2
t[t_select + 1 == self.hparams.leaveout_timepoint] *= 2
x = mu_t + sigma_t * torch.randn_like(x0)
aug_x = self.aug_net(t, x, augmented_input=False)
reg, vt = self.augmentations(aug_x)
return torch.mean(reg), self.criterion(vt, ut)
class CNFLitModule(CFMLitModule):
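    # Classical continuous normalizing flow trained by maximum likelihood: each
    # observation is integrated back to the Gaussian prior and scored with the
    # instantaneous change-of-variables formula,
    #   log p(x) = log p_prior(z) - delta_logprob,
    # where delta_logprob accumulates the divergence along the trajectory.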
def forward_integrate(self, batch: Any, t_span: torch.Tensor):
"""Forward pass with integration over t_span intervals.
(t, x, t_span) -> [x_t_span].
"""
return super().forward_integrate(batch, t_span + 1)
def step(self, batch: Any, training: bool = False):
obs = self.unpack_batch(batch)
if not self.is_trajectory:
obs = obs[:, None, :]
even_ts = torch.arange(obs.shape[1]).to(obs) + 1
self.prior = MultivariateNormal(
torch.zeros(self.dim).type_as(obs), torch.eye(self.dim).type_as(obs)
)
# Minimize the log likelihood by integrating all back to the initial timepoint
reversed_ts = torch.cat([torch.flip(even_ts, [0]), torch.tensor([0]).type_as(even_ts)])
# If only one timepoint then Gaussian is at t0, data t1
# If multiple timepoints then Gaussian is at t_{-1} data is at times 0 to T
if self.is_trajectory:
reversed_ts -= 1
losses = []
regs = []
for t in range(len(reversed_ts) - 1):
# When leaving out a timepoint simply skip it in the backwards integration
if self.hparams.leaveout_timepoint == t:
continue
ts, x = reversed_ts[t:], obs[:, len(even_ts) - t - 1, :]
# ts, x = self.aug(reversed_ts[t:], obs[:, len(even_ts) - t - 1, :])
_, x = self.aug_node(x, ts)
x = x[-1]
# Assume log prob is in zero spot
delta_logprob, reg, x = self.augmentations(x)
logprob = self.prior.log_prob(x).to(x) - delta_logprob
losses.append(-torch.mean(logprob))
# negative because we are integrating backwards
regs.append(-reg)
# Predicted locations
reg = torch.mean(torch.stack(regs))
loss = torch.mean(torch.stack(losses))
return reg, loss