| import enum |
|
|
| import torch as th |
|
|
|
|
class ModelType(enum.Enum):
    """Quantity the network is trained to output.

    ``create_transport`` maps its ``prediction`` argument ("noise", "score",
    "velocity") onto these members.
    """

    NOISE = enum.auto()
    SCORE = enum.auto()
    VELOCITY = enum.auto()
|
|
|
|
class PathType(enum.Enum):
    """Family of interpolation paths between ``x0`` and ``x1``.

    In this module ``Transport`` serves both LINEAR and GVP with
    ``LinearPath`` and raises ``NotImplementedError`` for VP.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()
|
|
|
|
class WeightType(enum.Enum):
    """Loss-weighting scheme chosen via ``create_transport(loss_weight=...)``.

    ``Transport`` stores the chosen member as ``loss_type``; the code in this
    module does not otherwise act on it.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()
|
|
|
|
class LinearPath:
    """Linear OT path: x_t = (1 - t) * x_0 + t * x_1."""

    def plan(self, t: th.Tensor, x0: th.Tensor, x1: th.Tensor):
        """Interpolate between ``x0`` and ``x1`` at time ``t``.

        Args:
            t: A 0-d tensor shared by every sample, or per-sample times
                (presumably shape ``(batch,)`` -- TODO confirm with callers).
            x0: Start point, broadcast-compatible with ``x1``.
            x1: End point; its first dimension is treated as the batch size.

        Returns:
            ``(t, x_t, u_t)``. ``t`` is the input time, with a 0-d value
            expanded to ``x1.shape[0]`` entries. ``x_t`` is the interpolant.
            ``u_t = x1 - x0`` is the path's time-independent velocity.
        """
        if t.ndim == 0:
            # One scalar time for the whole batch.
            t = t.expand(x1.shape[0])

        # Append trailing singleton axes so t broadcasts over x1's other dims.
        n_missing = x1.ndim - t.ndim
        t_b = t.reshape(tuple(t.shape) + (1,) * n_missing) if n_missing > 0 else t

        x_t = (1 - t_b) * x0 + t_b * x1
        u_t = x1 - x0
        return t, x_t, u_t
|
|
|
|
class Transport:
    """Flow matching transport with configurable path and time distribution.

    Holds the interpolation path, the training time-step sampler, and the
    integration-interval logic that ``Sampler`` queries.
    """

    def __init__(
        self,
        *,
        model_type,
        path_type,
        loss_type,
        time_dist_type,
        time_dist_shift,
        train_eps,
        sample_eps,
    ):
        """Validate the path type and store the transport configuration.

        Args:
            model_type: A ``ModelType`` member.
            path_type: A ``PathType`` member; only LINEAR and GVP are accepted.
            loss_type: Stored as-is on the instance.
            time_dist_type: Training time distribution; ``sample_timestep``
                accepts only ``"uniform"``.
            time_dist_shift: Time-warp factor; values below 1.0 are clamped
                to 1.0.
            train_eps: Interval margin used during training.
            sample_eps: Interval margin used when ``check_interval(eval=True)``.

        Raises:
            NotImplementedError: If ``path_type`` is not LINEAR or GVP.
        """
        # GVP is served by the same LinearPath as LINEAR in this implementation.
        path_cls = {
            PathType.LINEAR: LinearPath,
            PathType.GVP: LinearPath,
        }.get(path_type)
        if path_cls is None:
            raise NotImplementedError(
                f"Path type '{path_type}' is not implemented in this Stage-2 transport."
            )

        self.loss_type = loss_type
        self.model_type = model_type
        self.time_dist_type = time_dist_type
        # A shift of 1.0 makes the warp s*t / (1 + (s-1)*t) the identity.
        self.time_dist_shift = max(float(time_dist_shift), 1.0)
        self.path_sampler = path_cls()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        sde=False,
        eval=False,
        reverse=False,
        last_step_size=0.0,
    ):
        """Return the ``(t0, t1)`` time interval for training or sampling.

        The default interval is ``(0.0, 1 - 1/1000)``. When the path is a
        ``LinearPath`` and either ``sde`` is set or the model does not
        predict velocity, the interval changes as follows:

        * ``t0`` becomes ``eps`` for SDEs and 0.0 otherwise.
        * ``t1`` becomes ``1 - eps``, or ``1 - last_step_size`` for an SDE
          with a non-zero last step.

        ``eps`` is ``sample_eps`` when ``eval`` is true and ``train_eps``
        otherwise. ``reverse`` mirrors both bounds through ``t -> 1 - t``.
        """
        eps = sample_eps if eval else train_eps
        start, end = 0.0, 1.0 - 1 / 1000

        needs_margin = isinstance(self.path_sampler, LinearPath) and (
            sde or self.model_type != ModelType.VELOCITY
        )
        if needs_margin:
            start = eps if sde else 0.0
            if sde and last_step_size != 0:
                end = 1.0 - last_step_size
            else:
                end = 1.0 - eps

        if reverse:
            start, end = 1.0 - start, 1.0 - end
        return start, end

    def sample_timestep(self, x1: th.Tensor) -> th.Tensor:
        """Sample timesteps and apply CSFM shifted-uniform schedule.

        Draws one time per row of ``x1``. The result has shape
        ``(x1.shape[0],)`` and matches ``x1``'s device and dtype. Each time
        is uniform on the training interval from ``check_interval`` and then
        warped by ``s * t / (1 + (s - 1) * t)``, with ``s = time_dist_shift``.

        NOTE(review): the warp is applied after mapping into [t0, t1]. For
        ``s > 1``, samples can therefore land above t1. ``_build_time_grid``
        warps before mapping instead; confirm which convention is intended.

        Raises:
            NotImplementedError: If ``time_dist_type`` is not ``"uniform"``.
        """
        lo, hi = self.check_interval(self.train_eps, self.sample_eps)

        if self.time_dist_type != "uniform":
            raise NotImplementedError(
                f"Unknown time distribution: {self.time_dist_type}"
            )

        u = th.rand((x1.shape[0],), device=x1.device, dtype=x1.dtype)
        t = u * (hi - lo) + lo

        shift = self.time_dist_shift
        return shift * t / (1 + (shift - 1) * t)

    def get_drift(self):
        """Get ODE drift function for velocity models.

        The returned callable ``(x, t, model, **model_kwargs)`` returns
        ``model(x, t, **model_kwargs)`` unchanged.

        Raises (from the returned callable):
            ValueError: If the model output shape differs from ``x.shape``.
        """

        def body_fn(x, t, model, **model_kwargs):
            out = model(x, t, **model_kwargs)
            if out.shape != x.shape:
                raise ValueError(
                    f"Output shape {out.shape} does not match input shape {x.shape}."
                )
            return out

        return body_fn
|
|
|
|
| def _build_time_grid( |
| t0: float, |
| t1: float, |
| num_steps: int, |
| time_dist_shift: float, |
| device, |
| dtype, |
| ) -> th.Tensor: |
| tau = th.linspace(0, 1, num_steps + 1, device=device, dtype=dtype) |
| if time_dist_shift != 1.0: |
| tau = time_dist_shift * tau / (1 + (time_dist_shift - 1.0) * tau) |
| return t0 + (t1 - t0) * tau |
|
|
|
|
class Sampler:
    """Fixed-step ODE sampler for CSFM inference.

    Integrates the drift from ``Transport.get_drift`` over a time grid. The
    grid is built from ``Transport.check_interval`` (with ``eval=True``) and
    the transport's ``time_dist_shift``.
    """

    def __init__(self, transport: Transport):
        self.transport = transport
        # Callable (x, t, model, **kwargs) -> model output. It raises
        # ValueError when the output shape differs from x.
        self.drift = transport.get_drift()

    def sample_ode(
        self,
        *,
        sampling_method="euler",
        num_steps=50,
        reverse=False,
    ):
        """Build a fixed-step ODE integrator.

        Args:
            sampling_method: ``"euler"``, ``"midpoint"`` or ``"heun"``
                (case-insensitive).
            num_steps: Number of integration steps.
            reverse: If True, walk the same time grid in the opposite
                direction, from the interval's upper bound down to its lower
                bound.

        Returns:
            ``sample_fn(x_init, model, **model_kwargs)``. It returns a tensor
            that stacks ``x_init`` and the state after every step along
            dim 0, giving ``num_steps + 1`` entries.

        Raises:
            ValueError: If ``sampling_method`` is not supported.
        """
        method = sampling_method.lower()
        if method not in {"euler", "midpoint", "heun"}:
            raise ValueError(
                "sampling_method must be one of: 'euler', 'midpoint', 'heun'."
            )

        def sample_fn(x_init: th.Tensor, model, **model_kwargs):
            # Always request the forward interval and reverse exactly once,
            # by flipping the grid below. check_interval(reverse=True) already
            # mirrors the bounds (t -> 1 - t). Combining that with the flip
            # would reverse twice, and the integration would still run upward.
            t0, t1 = self.transport.check_interval(
                self.transport.train_eps,
                self.transport.sample_eps,
                sde=False,
                eval=True,
                reverse=False,
                last_step_size=0.0,
            )

            times = _build_time_grid(
                t0=t0,
                t1=t1,
                num_steps=num_steps,
                time_dist_shift=self.transport.time_dist_shift,
                device=x_init.device,
                dtype=x_init.dtype,
            )
            if reverse:
                times = th.flip(times, dims=[0])

            traj = [x_init]
            x = x_init
            for idx in range(num_steps):
                t_cur = times[idx]
                # dt is negative when integrating the flipped grid.
                dt = times[idx + 1] - t_cur
                t_batch = th.full(
                    (x.shape[0],),
                    t_cur,
                    device=x.device,
                    dtype=x.dtype,
                )

                # All three schemes start from the drift at the current state.
                k1 = self.drift(x, t_batch, model, **model_kwargs)
                if method == "euler":
                    x = x + dt * k1
                elif method == "midpoint":
                    x_mid = x + 0.5 * dt * k1
                    k2 = self.drift(x_mid, t_batch + 0.5 * dt, model, **model_kwargs)
                    x = x + dt * k2
                else:  # heun: Euler predictor, trapezoidal corrector
                    x_pred = x + dt * k1
                    k2 = self.drift(x_pred, t_batch + dt, model, **model_kwargs)
                    x = x + 0.5 * dt * (k1 + k2)

                traj.append(x)

            return th.stack(traj, dim=0)

        return sample_fn
|
|
|
|
def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    time_dist_type="uniform",
    time_dist_shift=1.0,
):
    """Create a Transport object for flow matching.

    Args:
        path_type: ``"Linear"``, ``"GVP"`` or ``"VP"``. VP is recognised
            here, but ``Transport`` rejects it with ``NotImplementedError``.
        prediction: ``"velocity"``, ``"noise"`` or ``"score"``.
        loss_weight: ``"velocity"`` or ``"likelihood"``. Any other value,
            including ``None``, selects ``WeightType.NONE``.
        time_dist_type: Forwarded to ``Transport``.
        time_dist_shift: Forwarded to ``Transport``.

    The interval margins ``(train_eps, sample_eps)`` are chosen as follows:

    * ``(1e-5, 1e-3)`` for VP.
    * ``(0, 0)`` for velocity prediction.
    * ``(1e-3, 1e-3)`` otherwise.

    Raises:
        KeyError: For an unknown ``prediction`` or ``path_type``.
    """
    prediction_to_model = {
        "velocity": ModelType.VELOCITY,
        "noise": ModelType.NOISE,
        "score": ModelType.SCORE,
    }
    weight_to_loss = {
        "velocity": WeightType.VELOCITY,
        "likelihood": WeightType.LIKELIHOOD,
    }
    name_to_path = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    model_type = prediction_to_model[prediction]
    loss_type = weight_to_loss.get(loss_weight, WeightType.NONE)
    path_enum = name_to_path[path_type]

    if path_enum == PathType.VP:
        train_eps, sample_eps = 1e-5, 1e-3
    elif model_type == ModelType.VELOCITY:
        train_eps, sample_eps = 0, 0
    else:
        train_eps, sample_eps = 1e-3, 1e-3

    return Transport(
        model_type=model_type,
        path_type=path_enum,
        loss_type=loss_type,
        time_dist_type=time_dist_type,
        time_dist_shift=time_dist_shift,
        train_eps=train_eps,
        sample_eps=sample_eps,
    )
|
|