"""Toy conditional flow-matching experiment: standard Gaussian -> 8 Gaussians.

Trains the project's ``CFM`` model to move N(0, I) noise onto a ring of eight
2-D Gaussian clusters (standardised per coordinate), conditioning on a learned
embedding of each sample's cluster label. It then integrates the learned
vector field with explicit Euler steps and plots snapshots of the samples.
"""

import sys
from pathlib import Path

# Put the directory two levels above this file on the import path so that
# ``src`` resolves when the script is run directly.
sys.path.append(str(Path(__file__).resolve().parent.parent))

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace

# Import CFM from src instead of defining locally
from src.stage2.CFM import CFM

# =============================================================================
# Experiment: Gaussian -> 8-Gaussians
# =============================================================================
np.random.seed(42)
torch.manual_seed(42)

# ---- GPU setup ------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# ---- target data: 8 Gaussian clusters evenly spaced on a circle -------------
n_samples = 4000
scale = 4.0
centers = np.array(
    [
        (np.cos(t) * scale, np.sin(t) * scale)
        for t in np.linspace(0, 2 * np.pi, 8, endpoint=False)
    ]
)
assignments = np.random.randint(0, 8, size=n_samples)  # cluster id per sample
gaussians_x = centers[assignments] + np.random.randn(n_samples, 2) * 0.4

target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
# Standardise each coordinate to zero mean / unit std.
goal_dist = (target_tensor - target_tensor.mean(0)) / target_tensor.std(0)
feat_dim = goal_dist.shape[1]  # == 2 for this toy problem
# Convert the cluster labels to a tensor once instead of on every step.
assignment_labels = torch.as_tensor(assignments, dtype=torch.long, device=device)

# ---- build model ------------------------------------------------------------
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
decoder_params = dict(hidden_dim=256, time_emb_dim=128, cond_dim=0)
model = CFM(
    feat_dim=feat_dim, cfm_params=cfm_params, decoder_params=decoder_params
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ---- training loop ----------------------------------------------------------
epochs, batch_size = 3000, 512
losses = []
model.train()
for epoch in range(epochs):
    # Indices come from the CPU generator, so the batch sequence is the same
    # whether the model runs on CPU or GPU. They are then moved next to the data.
    idx = torch.randint(0, n_samples, (batch_size,)).to(device)
    x1 = goal_dist[idx].unsqueeze(-1)  # (B, 2, 1)

    # Conditional -> cluster embedding conditioning
    labels = assignment_labels[idx]
    mu = model.label_emb(labels).unsqueeze(-1)

    loss, loss_dict = model.compute_loss(x1, mu)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

    if (epoch + 1) % 1000 == 0:
        print(
            f"Epoch {epoch+1:5d}  loss={loss.item():.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )

# ---- inference ---------------------------------------------------------------
model.eval()
n_eval = 1000
steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)
# Initial state plus one snapshot every `snapshot_every` steps gives the
# six panels t = 0.0, 0.2, ..., 1.0.
snapshot_every = steps // 5

trajectories = []
with torch.no_grad():
    # Labels 0..7 in contiguous runs, truncated to exactly n_eval samples.
    eval_labels = torch.arange(8, device=device).repeat_interleave(
        n_eval // 8 + 1
    )[:n_eval]
    mu_eval = model.label_emb(eval_labels).unsqueeze(-1)

    # The integrated state lives in data space (like x1 during training), so
    # noise takes the data's feature dimension, not the embedding's shape.
    x = torch.randn(n_eval, feat_dim, 1, device=device)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())

    # Explicit Euler: x_{k+1} = x_k + (t_{k+1} - t_k) * v(x_k, mu, t_k).
    for step, (t_prev, t_next) in enumerate(
        zip(t_span[:-1], t_span[1:]), start=1
    ):
        dphi_dt = model.estimator(x, mu_eval, t_prev.expand(n_eval))
        x = x + (t_next - t_prev) * dphi_dt
        if step % snapshot_every == 0:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())

# ---- plot ------------------------------------------------------------------
fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)
times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]

# zip stops at the six sampled snapshots; the 7th panel is filled below.
for ax, traj, label, c in zip(axes, trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_xlabel("X", fontsize=9)
    ax.set_ylabel("Y", fontsize=9)
    ax.set_title(f"t = {label}" if isinstance(label, float) else label, fontsize=10)
    ax.axis("off")

# last panel: overlay ground-truth
gt = goal_dist[:1000].cpu().numpy()
axes[-1].scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color="#00CC96", linewidths=0)
axes[-1].set_xlim(-3.5, 3.5)
axes[-1].set_ylim(-3.5, 3.5)
axes[-1].set_xlabel("X", fontsize=9)
axes[-1].set_ylabel("Y", fontsize=9)
axes[-1].set_title("target", fontsize=10)
axes[-1].axis("off")

# loss curve panel
# Moving average of the per-epoch loss. The window is clamped to the history
# length: np.convolve(..., mode="valid") swaps the roles of its inputs when the
# kernel is longer than the signal, and it raises on an empty signal.
window = min(50, len(losses))
fig2, ax2 = plt.subplots(figsize=(7, 3))
if window > 0:
    smoothed = np.convolve(losses, np.ones(window) / window, mode="valid")
    # The k-th valid average (0-indexed) covers epochs k+1 .. k+window, so plot
    # it at its last epoch rather than at index k (which shifted the curve).
    ax2.plot(
        np.arange(window, len(losses) + 1),
        smoothed,
        linewidth=1.2,
        color="#636EFA",
    )
ax2.set_xlabel("Epoch")
# NOTE(review): training also reports 'var' and 'align' terms; confirm the
# total returned by compute_loss is a pure MSE before trusting this label.
ax2.set_ylabel("MSE Loss")
ax2.set_title(f"CFM Training Loss ({window}-epoch moving avg)")
ax2.spines[["top", "right"]].set_visible(False)

# Only the loss figure gets tight_layout (the same figure plt.tight_layout()
# acted on, since fig2 is current). The trajectory figure relies on
# bbox_inches="tight" when saving, which keeps the raised suptitle in frame.
fig2.tight_layout()
fig.savefig("cfm_trajectories_imported.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss_imported.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories_imported.png and cfm_loss_imported.png")