| import sys |
| from pathlib import Path |
| sys.path.append(str(Path(__file__).resolve().parent.parent)) |
|
|
| import torch |
| import torch.nn as nn |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from types import SimpleNamespace |
| import numpy as np |
| |
| from src.stage2.CFM import CFM |
|
|
| |
| |
| |
|
|
# Seed NumPy and PyTorch with the same value so the synthetic data and the
# training run are repeatable.
_SEED = 42
np.random.seed(_SEED)
torch.manual_seed(_SEED)

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")
|
|
# Synthetic target distribution: eight Gaussian blobs arranged on a circle.
n_samples = 4000
scale = 4.0

# Blob centres sit at eight equally spaced angles on a circle of radius `scale`.
angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
centers = np.array([(np.cos(t) * scale, np.sin(t) * scale) for t in angles])

# Draw order matters for reproducibility under the fixed seed:
# blob assignments first, then the isotropic noise around each centre.
assignments = np.random.randint(0, 8, size=n_samples)
noise = np.random.randn(n_samples, 2) * 0.4
gaussians_x = centers[assignments] + noise

# Move the samples to the compute device and standardise each coordinate
# (subtract the per-column mean, divide by the per-column std).
target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
column_mean = target_tensor.mean(0)
column_std = target_tensor.std(0)
goal_dist = (target_tensor - column_mean) / column_std
|
|
| |
# Settings handed to CFM; their exact meaning is defined in src.stage2.CFM,
# which is not visible from this script.
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
decoder_params = {"hidden_dim": 256, "time_emb_dim": 128, "cond_dim": 0}

# Two feature channels: one per coordinate of the 2-D toy data.
model = CFM(feat_dim=2, cfm_params=cfm_params, decoder_params=decoder_params)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
|
|
| |
# Optimise the model on random mini-batches of the standardised samples.
# Each "epoch" here is a single optimizer step on one mini-batch.
epochs, batch_size = 3000, 512
losses = []

model.train()
for epoch in range(epochs):
    # Sample batch indices with replacement. The index tensor is created on
    # the CPU and is also used to index the NumPy `assignments` array below.
    batch_idx = torch.randint(0, n_samples, (batch_size,))

    # Targets get a trailing singleton axis before being passed to the model.
    x_target = goal_dist[batch_idx].unsqueeze(-1)

    # The conditioning signal is the embedding of each sample's blob index.
    batch_labels = torch.tensor(
        assignments[batch_idx], dtype=torch.long, device=device
    )
    cond = model.label_emb(batch_labels).unsqueeze(-1)

    loss, loss_dict = model.compute_loss(x_target, cond)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    # Report progress only every 1000 iterations.
    if (epoch + 1) % 1000 != 0:
        continue
    print(
        f"Epoch {epoch+1:5d} loss={loss.item():.5f} | "
        f"FM={loss_dict['fm']:.5f} | "
        f"Var={loss_dict['var']:.5f} | "
        f"Align={loss_dict['align']:.5f}"
    )
|
|
| |
# Integrate the learned vector field from Gaussian noise towards the data with
# explicit Euler steps, keeping snapshots of the particle cloud along the way.
model.eval()
n_eval = 1000

# Ceil-divide rather than `n_eval // 8 + 1`: the latter over-allocates when
# n_eval is a multiple of 8, which leaves the last label short after the
# truncation to n_eval (126 samples for labels 0-6 but only 118 for label 7).
per_label = (n_eval + 7) // 8
eval_labels = torch.arange(8, device=device).repeat_interleave(per_label)[:n_eval]
mu_eval = model.label_emb(eval_labels).unsqueeze(-1).detach()

steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)

# Snapshot after each fifth of the integration (steps 20, 40, 60, 80, 100 for
# steps = 100). The initial state at t = 0 is stored separately before the
# loop, so step 0 does not need to be in this set.
snap_at = {round(k * steps / 5) for k in range(1, 6)}

trajectories = []
with torch.no_grad():
    # NOTE(review): the noise is drawn with mu_eval's shape, so the label
    # embedding width presumably equals the feature dimension - confirm
    # against the CFM implementation.
    x = torch.randn(mu_eval.size(), device=device)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())

    t = t_span[0]
    dt = t_span[1] - t_span[0]

    for step in range(1, len(t_span)):
        # The estimator receives one time value per sample.
        t_batch = t.expand(n_eval)
        dphi_dt = model.estimator(x, mu_eval, t_batch)
        x = x + dt * dphi_dt
        t = t + dt
        # Re-derive the next step size from the grid so accumulated rounding
        # in `t` does not drift away from t_span.
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
        if step in snap_at:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())
|
|
| |
def _draw_panel(ax, points, title, color, alpha):
    """Scatter one 2-D point cloud on `ax` with the styling shared by all panels."""
    ax.scatter(
        points[:, 0], points[:, 1], s=4, alpha=alpha, color=color, linewidths=0
    )
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_xlabel("X", fontsize=9)
    ax.set_ylabel("Y", fontsize=9)
    ax.set_title(title, fontsize=10)
    # Turning the axis off hides ticks, frame and the axis labels set above.
    ax.axis("off")


# One row of seven panels: six trajectory snapshots plus the target samples.
fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)

times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]

# zip() stops at the shortest input, so only as many panels as there are
# stored snapshots are drawn here; the final panel is filled in below.
for ax, traj, label, c in zip(axes, trajectories, times, colors):
    panel_title = f"t = {label}" if isinstance(label, float) else label
    _draw_panel(ax, traj, panel_title, c, 0.6)

# Last panel: the first 1000 standardised training samples for reference.
gt = goal_dist[:1000].cpu().numpy()
_draw_panel(axes[-1], gt, times[-1], colors[-1], 0.3)
|
|
| |
# Second figure: training loss smoothed with a 50-sample moving average.
fig2, ax2 = plt.subplots(figsize=(7, 3))
window = 50
kernel = np.ones(window) / window
# mode="valid" keeps only positions where the window fully overlaps the data.
smoothed = np.convolve(losses, kernel, mode="valid")
ax2.plot(smoothed, linewidth=1.2, color="#636EFA")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("MSE Loss")
ax2.set_title("CFM Training Loss (50-epoch moving avg)")
for side in ("top", "right"):
    ax2.spines[side].set_visible(False)

# pyplot's tight_layout acts on the current figure, which is fig2 here because
# it was created last; the trajectory figure relies on bbox_inches="tight".
plt.tight_layout()
fig.savefig("cfm_trajectories_imported.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss_imported.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories_imported.png and cfm_loss_imported.png")
|
|