flow-matching / src /stage2 /transport.py
sabertoaster's picture
Upload folder using huggingface_hub
4edc9aa verified
import enum
import torch as th
class ModelType(enum.Enum):
    """Quantity the network output is interpreted as.

    ``create_transport`` selects a member from its ``prediction`` string
    ("noise", "score" or "velocity").
    """

    # Explicit values match what ``enum.auto()`` assigns (1, 2, 3 in order).
    NOISE = 1
    SCORE = 2
    VELOCITY = 3
class PathType(enum.Enum):
    """Interpolation path family requested for the transport.

    ``Transport`` accepts ``LINEAR`` and ``GVP``; ``VP`` is declared here but
    rejected by ``Transport.__init__`` with ``NotImplementedError``.
    """

    # Explicit values match what ``enum.auto()`` assigns (1, 2, 3 in order).
    LINEAR = 1
    GVP = 2
    VP = 3
class WeightType(enum.Enum):
    """Loss weighting scheme.

    ``create_transport`` maps its ``loss_weight`` string ("velocity" or
    "likelihood") to a member and falls back to ``NONE`` for anything else.
    """

    # Explicit values match what ``enum.auto()`` assigns (1, 2, 3 in order).
    NONE = 1
    VELOCITY = 2
    LIKELIHOOD = 3
class LinearPath:
    """Straight-line interpolant: x_t = (1 - t) * x_0 + t * x_1."""

    def plan(self, t: th.Tensor, x0: th.Tensor, x1: th.Tensor):
        """Evaluate the interpolant and its velocity at time ``t``.

        Args:
            t: Time tensor. A 0-dim tensor is expanded to ``x1.shape[0]``
                entries before use.
            x0: Path endpoint at ``t = 0``.
            x1: Path endpoint at ``t = 1``.

        Returns:
            A tuple ``(t, x_t, u_t)``: the (possibly expanded) time tensor
            without any added trailing axes, the interpolated point, and the
            velocity ``x1 - x0``.
        """
        times = t.expand(x1.shape[0]) if t.ndim == 0 else t

        # Append singleton axes so the per-sample time broadcasts against x1.
        missing_axes = max(x1.ndim - times.ndim, 0)
        weight = times[(..., *([None] * missing_axes))]

        x_t = (1 - weight) * x0 + weight * x1
        u_t = x1 - x0
        return times, x_t, u_t
class Transport:
    """Flow matching transport with configurable path and time distribution."""

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        time_dist_type,
        time_dist_shift,
        train_eps,
        sample_eps,
    ):
        """Store the configuration and instantiate the path sampler.

        Args:
            model_type: ``ModelType`` member describing the model output.
            path_type: ``PathType`` member; only ``LINEAR`` and ``GVP`` are
                accepted.
            loss_type: ``WeightType`` member (stored, not used in this class).
            time_dist_type: Name of the training time distribution.
            time_dist_shift: Time-shift factor; values below 1 are clamped
                to 1.0.
            train_eps: Interval margin used when ``check_interval`` is called
                with ``eval=False``.
            sample_eps: Interval margin used when ``check_interval`` is called
                with ``eval=True``.

        Raises:
            NotImplementedError: If ``path_type`` has no sampler registered.
        """
        # NOTE(review): GVP is served by the same LinearPath sampler as
        # LINEAR -- confirm this is intended.
        supported_paths = {
            PathType.LINEAR: LinearPath,
            PathType.GVP: LinearPath,
        }
        sampler_cls = supported_paths.get(path_type)
        if sampler_cls is None:
            raise NotImplementedError(
                f"Path type '{path_type}' is not implemented in this Stage-2 transport."
            )

        self.path_sampler = sampler_cls()
        self.model_type = model_type
        self.loss_type = loss_type
        self.time_dist_type = time_dist_type
        # A shift of exactly 1.0 leaves the time schedule unchanged.
        self.time_dist_shift = max(float(time_dist_shift), 1.0)
        self.train_eps = train_eps
        self.sample_eps = sample_eps

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        sde=False,
        eval=False,
        reverse=False,
        last_step_size=0.0,
    ):
        """Return the ``(t0, t1)`` integration/sampling interval.

        The default interval is ``(0.0, 1.0 - 1 / 1000)``. For a linear path
        with a non-velocity model or with ``sde=True``, the endpoints are
        pulled in by ``eps`` (``sample_eps`` when ``eval`` is true, otherwise
        ``train_eps``) or by ``last_step_size``.

        Args:
            train_eps: Margin used when ``eval`` is false.
            sample_eps: Margin used when ``eval`` is true.
            sde: Whether the interval is for SDE sampling.
            eval: Selects ``sample_eps`` instead of ``train_eps``.
            reverse: If true, return ``(1 - t0, 1 - t1)``.
            last_step_size: With ``sde`` and a non-zero value, the upper end
                becomes ``1 - last_step_size``.

        Returns:
            Tuple ``(t0, t1)``.
        """
        eps = sample_eps if eval else train_eps
        start = 0.0
        end = 1.0 - 1 / 1000

        needs_margin = self.model_type != ModelType.VELOCITY or sde
        if needs_margin and isinstance(self.path_sampler, LinearPath):
            if sde:
                start = eps
            else:
                start = 0.0
            if sde and last_step_size != 0:
                end = 1.0 - last_step_size
            else:
                end = 1.0 - eps

        if reverse:
            return 1.0 - start, 1.0 - end
        return start, end

    def sample_timestep(self, x1: th.Tensor) -> th.Tensor:
        """Draw one training timestep per sample and apply the time shift.

        Args:
            x1: Batch tensor; supplies batch size, device and dtype.

        Returns:
            1-D tensor of ``x1.shape[0]`` shifted timesteps.

        Raises:
            NotImplementedError: If ``time_dist_type`` is not ``"uniform"``.
        """
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
        if self.time_dist_type != "uniform":
            raise NotImplementedError(
                f"Unknown time distribution: {self.time_dist_type}"
            )

        batch = x1.shape[0]
        uniform = th.rand((batch,), device=x1.device, dtype=x1.dtype)
        t = uniform * (t1 - t0) + t0

        # Shifted schedule: t -> s * t / (1 + (s - 1) * t); identity for s = 1.
        shift = self.time_dist_shift
        return shift * t / (1 + (shift - 1) * t)

    def get_drift(self):
        """Return the ODE drift callable ``f(x, t, model, **model_kwargs)``.

        The model output is used directly as the velocity; it must have the
        same shape as ``x`` or a ``ValueError`` is raised.
        """

        def body_fn(x, t, model, **model_kwargs):
            velocity = model(x, t, **model_kwargs)
            if velocity.shape == x.shape:
                return velocity
            raise ValueError(
                f"Output shape {velocity.shape} does not match input shape {x.shape}."
            )

        return body_fn
def _build_time_grid(
t0: float,
t1: float,
num_steps: int,
time_dist_shift: float,
device,
dtype,
) -> th.Tensor:
tau = th.linspace(0, 1, num_steps + 1, device=device, dtype=dtype)
if time_dist_shift != 1.0:
tau = time_dist_shift * tau / (1 + (time_dist_shift - 1.0) * tau)
return t0 + (t1 - t0) * tau
class Sampler:
    """Fixed-step ODE sampler for CSFM inference."""

    def __init__(self, transport: Transport):
        """Bind the sampler to ``transport`` and cache its drift function."""
        self.transport = transport
        self.drift = transport.get_drift()

    def sample_ode(
        self,
        *,
        sampling_method="euler",
        num_steps=50,
        reverse=False,
    ):
        """Build a fixed-step ODE sampling function.

        Args:
            sampling_method: One of ``"euler"``, ``"midpoint"`` or ``"heun"``
                (case-insensitive).
            num_steps: Number of integration steps.
            reverse: If true, traverse the same time grid as forward sampling
                but from its last point back to its first, so every step has
                a negative ``dt``.

        Returns:
            ``sample_fn(x_init, model, **model_kwargs)``, which returns the
            whole trajectory stacked on a new leading axis: ``num_steps + 1``
            states, the first being ``x_init``.

        Raises:
            ValueError: If ``sampling_method`` is not a supported method.
        """
        method = sampling_method.lower()
        if method not in {"euler", "midpoint", "heun"}:
            raise ValueError(
                "sampling_method must be one of: 'euler', 'midpoint', 'heun'."
            )

        def sample_fn(x_init: th.Tensor, model, **model_kwargs):
            # Always request the *forward* interval. Previously the interval
            # was requested with reverse=reverse and the grid was flipped as
            # well; the two reversals cancelled, so reverse=True still
            # integrated with increasing time (and reached t = 1.0, outside
            # the forward interval). The single flip below is the reversal.
            t0, t1 = self.transport.check_interval(
                self.transport.train_eps,
                self.transport.sample_eps,
                sde=False,
                eval=True,
                reverse=False,
                last_step_size=0.0,
            )
            times = _build_time_grid(
                t0=t0,
                t1=t1,
                num_steps=num_steps,
                time_dist_shift=self.transport.time_dist_shift,
                device=x_init.device,
                dtype=x_init.dtype,
            )
            if reverse:
                # Same time points as forward sampling, visited last-to-first.
                times = th.flip(times, dims=[0])

            traj = [x_init]
            x = x_init
            for idx in range(num_steps):
                t_cur = times[idx]
                t_next = times[idx + 1]
                dt = t_next - t_cur  # negative when sampling in reverse
                # The model receives one time value per batch element.
                t_batch = th.full(
                    (x.shape[0],),
                    t_cur,
                    device=x.device,
                    dtype=x.dtype,
                )

                if method == "euler":
                    k1 = self.drift(x, t_batch, model, **model_kwargs)
                    x = x + dt * k1
                elif method == "midpoint":
                    # Half Euler step, then a full step with the midpoint slope.
                    k1 = self.drift(x, t_batch, model, **model_kwargs)
                    t_mid = t_batch + 0.5 * dt
                    x_mid = x + 0.5 * dt * k1
                    k2 = self.drift(x_mid, t_mid, model, **model_kwargs)
                    x = x + dt * k2
                else:  # heun
                    # Average the slope at the start and at the Euler-predicted end.
                    k1 = self.drift(x, t_batch, model, **model_kwargs)
                    x_euler = x + dt * k1
                    t_end = t_batch + dt
                    k2 = self.drift(x_euler, t_end, model, **model_kwargs)
                    x = x + 0.5 * dt * (k1 + k2)

                traj.append(x)

            return th.stack(traj, dim=0)

        return sample_fn
def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    time_dist_type="uniform",
    time_dist_shift=1.0,
):
    """Create a Transport object for flow matching.

    Args:
        path_type: ``"Linear"``, ``"GVP"`` or ``"VP"`` (case-sensitive).
        prediction: ``"velocity"``, ``"noise"`` or ``"score"``.
        loss_weight: ``"velocity"`` or ``"likelihood"``; any other value
            (including ``None``) selects ``WeightType.NONE``.
        time_dist_type: Passed through to ``Transport``.
        time_dist_shift: Passed through to ``Transport``.

    Returns:
        A configured ``Transport`` instance.

    Raises:
        KeyError: If ``prediction`` or ``path_type`` is not a known key.
    """
    predictions = {
        "velocity": ModelType.VELOCITY,
        "noise": ModelType.NOISE,
        "score": ModelType.SCORE,
    }
    model_type = predictions[prediction]

    loss_weights = {
        "velocity": WeightType.VELOCITY,
        "likelihood": WeightType.LIKELIHOOD,
    }
    loss_type = loss_weights.get(loss_weight, WeightType.NONE)

    paths = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }
    path_enum = paths[path_type]

    # (train_eps, sample_eps): VP takes precedence, then velocity models get
    # no margin, and every other model type gets 1e-3 on both.
    if path_enum == PathType.VP:
        eps_pair = (1e-5, 1e-3)
    elif model_type == ModelType.VELOCITY:
        eps_pair = (0, 0)
    else:
        eps_pair = (1e-3, 1e-3)
    train_eps, sample_eps = eps_pair

    return Transport(
        model_type=model_type,
        path_type=path_enum,
        loss_type=loss_type,
        time_dist_type=time_dist_type,
        time_dist_shift=time_dist_shift,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )