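"""Toy conditional flow matching (OT-CFM) demo.

Learns a velocity field that transports a learned, label-conditional Gaussian
source to a 2-D 8-Gaussians target, then samples with a fixed-step Euler ODE
solver and plots trajectory snapshots.
"""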
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
from typing import Optional
from abc import ABC


def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Transformer-style sinusoidal time embedding. t: (B,) -> (B, dim)."""
    device = t.device
    half = dim // 2
    # Log-spaced frequencies from 1 down to 1/10000.
    freqs = torch.exp(-torch.arange(half, device=device) * (np.log(10000) / (half - 1)))
    args = t[:, None] * freqs[None]
    return torch.cat([args.sin(), args.cos()], dim=-1)
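
# Illustrative usage (comment only):
#   sinusoidal_pos_emb(torch.rand(4), dim=64)  # -> tensor of shape (4, 64)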


class SinusoidalPosEmb(nn.Module):
    """nn.Module wrapper so the embedding can live inside a model."""

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        # Accept (B,), (B, 1), ... by flattening to (B,).
        t = t.view(t.shape[0])
        return sinusoidal_pos_emb(t, self.dim)


class MLP(nn.Module):
    """1x1-conv MLP over (B, C, L) features with additive time conditioning."""

    def __init__(self, in_c, hidden_c, out_c, time_emb_dim):
        super().__init__()
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = nn.Sequential(nn.Conv1d(in_c, hidden_c, 1), nn.ReLU())
        self.net2 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.net3 = nn.Sequential(nn.Conv1d(hidden_c, hidden_c, 1), nn.ReLU())
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x, time_emb):
        h = self.net1(x)
        # Broadcast the time embedding over the length dimension.
        h = h + self.time_net(time_emb).unsqueeze(-1)
        h = self.net2(h)
        h = self.net3(h)
        return self.out(h)


class Decoder(nn.Module):
    """
    Lightweight MLP velocity estimator for toy 2-D flow matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
        x   : (B, feat_dim, L)
        mu  : (B, feat_dim, L)
        t   : (B,) | (B, 1) | (B, 1, 1)   # all accepted
        vel : (B, feat_dim, L)
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,
    ):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # The condition mu is concatenated channel-wise, hence in_c * 2.
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,
    ) -> torch.Tensor:
        # Normalize t to shape (B,) regardless of trailing singleton dims.
        t_flat = t.reshape(x.shape[0])
        t_emb = self.time_mlp(self.time_emb(t_flat))

        # Channel-wise concat of state and condition.
        xmu = torch.cat([x, mu], dim=1)
        return self.net(xmu, t_emb)
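
# Example call (illustrative): with feat_dim=2 and batch size B,
#   Decoder()(torch.zeros(B, 2, 1), torch.zeros(B, 2, 1), torch.rand(B))
# returns a velocity tensor of shape (B, 2, 1).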


class SourceGenerator(nn.Module):
    """Predicts a per-condition Gaussian source p0(z | mu) as (mean, log-variance)."""

    def __init__(self, feat_dim: int, hidden_dim: int = 64):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv1d(feat_dim, hidden_dim, 1),
            nn.Mish(),
            nn.Conv1d(hidden_dim, feat_dim * 2, 1),
        )

    def forward(self, mu: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        out = self.net(mu)
        mean_c, logvar_c = out.chunk(2, dim=1)
        return mean_c, logvar_c
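
# Sampling from this learned source uses the reparameterization trick,
# exactly as compute_loss does below:
#   mean_c, logvar_c = src_gen(mu)
#   z = mean_c + torch.exp(0.5 * logvar_c) * torch.randn_like(mean_c)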


class BASECFM(nn.Module, ABC):
    def __init__(self, feat_dim: int, cfm_params):
        super().__init__()
        self.feat_dim = feat_dim
        self.sigma_min = cfm_params.sigma_min
        self.estimator: Optional[nn.Module] = None
        self.src_gen: Optional[nn.Module] = None

    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,
        n_timesteps: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        # src_gen returns a (mean, log-variance) tuple; sample z via the
        # reparameterization trick, scaling only the noise by temperature.
        mean_c, logvar_c = self.src_gen(mu)
        z = mean_c + torch.exp(0.5 * logvar_c) * torch.randn_like(mean_c) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span, mu)

    def solve_euler(
        self,
        x: torch.Tensor,
        t_span: torch.Tensor,
        mu: torch.Tensor,
    ) -> torch.Tensor:
        t = t_span[0]
        dt = t_span[1] - t_span[0]
        B = x.shape[0]

        for step in range(1, len(t_span)):
            # Broadcast the scalar t to a batch; it already lives on x's device.
            t_batch = t.expand(B)
            dphi_dt = self.estimator(x, mu, t_batch)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return x
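
    # Euler is first-order; if tighter accuracy per step were needed, a
    # midpoint (RK2) update could be swapped into the loop, e.g. (sketch):
    #   x_mid = x + 0.5 * dt * self.estimator(x, mu, t_batch)
    #   x = x + dt * self.estimator(x_mid, mu, (t + 0.5 * dt).expand(B))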

    def compute_loss(
        self,
        x1: torch.Tensor,
        mu: torch.Tensor,
        lambda_var: float = 1.0,
        lambda_align: float = 0.0,
    ) -> tuple[torch.Tensor, dict]:
        B = x1.shape[0]

        # Uniform flow time, broadcastable against (B, C, L) tensors.
        t = torch.rand(B, 1, 1, device=mu.device, dtype=mu.dtype)

        # Sample z from the learned source via the reparameterization trick.
        mean_c, logvar_c = self.src_gen(mu)
        eps = torch.randn_like(mean_c)
        z = mean_c + torch.exp(0.5 * logvar_c) * eps

        # OT-CFM conditional path and its target velocity.
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        t_batch = t.reshape(B)
        pred = self.estimator(y, mu, t_batch)

        # Flow-matching regression loss.
        loss_fm = F.mse_loss(pred, u)

        # KL-style penalty keeping the source variance near 1.
        loss_var = 0.5 * (torch.exp(logvar_c) - logvar_c - 1).mean()

        # Optional alignment between the source sample and the target.
        sim = F.cosine_similarity(z.flatten(1), x1.flatten(1), dim=1)
        loss_align = (1.0 - sim).mean()

        loss_total = loss_fm + lambda_var * loss_var + lambda_align * loss_align

        loss_dict = {
            "fm": loss_fm.item(),
            "var": loss_var.item(),
            "align": loss_align.item(),
        }

        return loss_total, loss_dict
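
    # In formulas, with s = sigma_min, the conditional path and target are
    #   x_t = (1 - (1 - s) t) z + t x1
    #   u_t = d x_t / d t = x1 - (1 - s) z
    # which is exactly the pair (y, u) regressed above.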


class CFM(BASECFM):
    def __init__(
        self, feat_dim: int, cfm_params, decoder_params: dict, num_classes: int = 8
    ):
        super().__init__(feat_dim=feat_dim, cfm_params=cfm_params)
        self.estimator = Decoder(in_c=feat_dim, out_c=feat_dim, **decoder_params)
        # One learned embedding per cluster label, used as the condition mu.
        self.label_emb = nn.Embedding(num_classes, feat_dim)
        self.src_gen = SourceGenerator(feat_dim=feat_dim)


# Reproducibility.
np.random.seed(42)
torch.manual_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Target data: 8 Gaussians on a circle of radius `scale`.
n_samples = 4000
scale = 4.0
centers = np.array(
    [
        (np.cos(t) * scale, np.sin(t) * scale)
        for t in np.linspace(0, 2 * np.pi, 8, endpoint=False)
    ]
)
assignments = np.random.randint(0, 8, size=n_samples)
gaussians_x = centers[assignments] + np.random.randn(n_samples, 2) * 0.4

target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
# Standardize per coordinate so the target fits the fixed plot range.
goal_dist = (target_tensor - target_tensor.mean(0)) / target_tensor.std(0)

# Model and optimizer.
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
decoder_params = dict(hidden_dim=256, time_emb_dim=128, cond_dim=0)
model = CFM(feat_dim=2, cfm_params=cfm_params, decoder_params=decoder_params).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
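
# Smoke test (an added sanity check, not part of the original recipe): one
# estimator call on dummy tensors confirms the (B, feat, L) contract.
with torch.no_grad():
    _x = torch.zeros(4, 2, 1, device=device)
    assert model.estimator(_x, _x, torch.rand(4, device=device)).shape == (4, 2, 1)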

# Training loop.
epochs, batch_size = 3000, 512
losses = []

model.train()
for epoch in range(epochs):
    idx = torch.randint(0, n_samples, (batch_size,))
    x1 = goal_dist[idx].unsqueeze(-1)  # (B, 2, 1)

    # Condition on the ground-truth cluster label of each sample.
    labels = torch.as_tensor(assignments[idx.numpy()], dtype=torch.long, device=device)
    mu = model.label_emb(labels).unsqueeze(-1)  # (B, 2, 1)

    loss, loss_dict = model.compute_loss(x1, mu)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch + 1) % 1000 == 0:
        print(
            f"Epoch {epoch+1:5d} loss={loss.item():.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )

# Sampling: integrate the ODE with fixed-step Euler, snapshotting intermediate
# states for plotting.
model.eval()
n_eval = 1000
eval_labels = torch.arange(8, device=device).repeat_interleave(n_eval // 8 + 1)[
    :n_eval
]
mu_eval = model.label_emb(eval_labels).unsqueeze(-1).detach()
steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)

trajectories = []
with torch.no_grad():
    # Draw the initial state from the learned source p0(z | mu), matching
    # how z is sampled during training.
    mean_c, logvar_c = model.src_gen(mu_eval)
    x = mean_c + torch.exp(0.5 * logvar_c) * torch.randn_like(mean_c)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())

    t = t_span[0]
    dt = t_span[1] - t_span[0]

    snap_at = {20, 40, 60, 80, 100}
    for step in range(1, len(t_span)):
        t_batch = t.expand(n_eval)
        dphi_dt = model.estimator(x, mu_eval, t_batch)
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
        if step in snap_at:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())
            print(f"step {step:3d}: x in [{x.min().item():.3f}, {x.max().item():.3f}]")

# Plot snapshots of the flow next to the ground-truth target.
fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)

times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]

for ax, traj, label, c in zip(axes, trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_title(f"t = {label}", fontsize=10)
    ax.axis("off")

# Last panel: ground-truth target samples.
gt = goal_dist[:1000].cpu().numpy()
axes[-1].scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color="#00CC96", linewidths=0)
axes[-1].set_xlim(-3.5, 3.5)
axes[-1].set_ylim(-3.5, 3.5)
axes[-1].set_title("target", fontsize=10)
axes[-1].axis("off")

# Training-loss curve (smoothed).
fig2, ax2 = plt.subplots(figsize=(7, 3))
ax2.plot(
    np.convolve(losses, np.ones(50) / 50, mode="valid"), linewidth=1.2, color="#636EFA"
)
ax2.set_xlabel("Epoch")
ax2.set_ylabel("MSE Loss")
ax2.set_title("CFM Training Loss (50-epoch moving avg)")
ax2.spines[["top", "right"]].set_visible(False)

plt.tight_layout()
fig.savefig("cfm_trajectories.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories.png and cfm_loss.png")

# Optional diagnostics: parameter summary and target range.
from torchinfo import summary

print(summary(model))
print("goal_dist range:", goal_dist.min().item(), "to", goal_dist.max().item())