# flow-matching/experiments/inspect_model.py
# Hugging Face page header (not part of the script): uploader "sabertoaster",
# commit message "Upload folder using huggingface_hub", commit 4edc9aa (verified).
import sys
from pathlib import Path
sys.path.append(str(Path(__file__).resolve().parent.parent))
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
import numpy as np
# Import CFM from src instead of defining locally
from src.stage2.CFM import CFM
# =============================================================================
# Experiment: Gaussian -> 8-Gaussians
# =============================================================================
# Seed both RNGs up front so the sampled dataset (and everything after it)
# is reproducible from run to run.
np.random.seed(42)
torch.manual_seed(42)

# ---- GPU setup ------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

n_samples = 4000
scale = 4.0

# Eight mode centres, evenly spaced on a circle of radius `scale`.
angles = np.linspace(0, 2 * np.pi, 8, endpoint=False)
centers = np.array([(np.cos(t) * scale, np.sin(t) * scale) for t in angles])

# Draw a mode index per sample, then jitter each chosen centre with
# isotropic Gaussian noise (std 0.4).  The order of the two random draws
# matters for reproducibility: labels first, noise second.
assignments = np.random.randint(0, 8, size=n_samples)
gaussians_x = centers[assignments] + np.random.randn(n_samples, 2) * 0.4

# Standardise every coordinate to zero mean / unit std; the result has shape
# (n_samples, 2) and lives on `device`.
target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
feat_mean = target_tensor.mean(0)
feat_std = target_tensor.std(0)
goal_dist = (target_tensor - feat_mean) / feat_std
# ---- build model ------------------------------------------------------------
# CFM is imported from src.stage2.CFM; it is configured here for 2-D features
# with an attribute-style config object plus a plain dict of decoder settings.
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
decoder_params = {
    "hidden_dim": 256,
    "time_emb_dim": 128,
    "cond_dim": 0,
}
model = CFM(
    feat_dim=2,
    cfm_params=cfm_params,
    decoder_params=decoder_params,
).to(device)

# Adam over all model parameters with a fixed learning rate.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ---- training loop ----------------------------------------------------------
# Each "epoch" is one optimisation step on a single random mini-batch.
epochs, batch_size = 3000, 512
log_every = 1000
losses = []

# Move the per-sample cluster labels to the device once, instead of indexing
# the NumPy array with a torch tensor and rebuilding a tensor every step.
labels_all = torch.tensor(assignments, dtype=torch.long, device=device)

model.train()
for epoch in range(epochs):
    # Sample on the default (CPU) generator, exactly as before, so the seeded
    # index stream is unchanged; then move the indices to the data's device.
    idx = torch.randint(0, n_samples, (batch_size,)).to(device)
    x1 = goal_dist[idx].unsqueeze(-1)  # (B, 2, 1)

    # Conditional -> cluster embedding conditioning
    labels = labels_all[idx]
    mu = model.label_emb(labels).unsqueeze(-1)

    # compute_loss returns the scalar to optimise plus a dict of components;
    # the keys 'fm', 'var' and 'align' are read for logging below.
    loss, loss_dict = model.compute_loss(x1, mu)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Read the scalar back once per step and reuse it for history + logging.
    loss_value = loss.item()
    losses.append(loss_value)

    if (epoch + 1) % log_every == 0:
        print(
            f"Epoch {epoch+1:5d} loss={loss_value:.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )
# ---- inference -------------------------------------------------------------
model.eval()
n_eval = 1000

# One cluster label per evaluation sample: labels 0..7 in contiguous runs,
# over-allocated by one per class and then trimmed to exactly n_eval entries.
per_class = n_eval // 8 + 1
eval_labels = torch.arange(8, device=device).repeat_interleave(per_class)[:n_eval]
mu_eval = model.label_emb(eval_labels).unsqueeze(-1).detach()

steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)
# Step indices after which the current state is recorded (the initial state
# is recorded separately before the loop starts).
snap_at = {0, 20, 40, 60, 80, 100}


def _snapshot(state):
    """Return a NumPy copy of `state` on the CPU with its last axis squeezed."""
    return state.squeeze(-1).cpu().numpy().copy()


trajectories = []
with torch.no_grad():
    # Start from standard-normal noise with the same shape as the conditioning.
    x = torch.randn(mu_eval.size(), device=device)
    trajectories.append(_snapshot(x))

    t = t_span[0]
    dt = t_span[1] - t_span[0]
    last_step = len(t_span) - 1

    # Fixed-grid explicit Euler integration of the estimator's output.
    for step in range(1, last_step + 1):
        t_batch = t.expand(n_eval)
        velocity = model.estimator(x, mu_eval, t_batch)
        x = x + dt * velocity
        t = t + dt

        # Re-derive the next step size from the grid relative to the current t.
        if step < last_step:
            dt = t_span[step + 1] - t

        if step in snap_at:
            trajectories.append(_snapshot(x))
# ---- plot ------------------------------------------------------------------
def _format_panel(ax, title):
    """Apply the shared limits, labels and title to one panel, then hide its axes."""
    ax.set_xlim(-3.5, 3.5)
    ax.set_ylim(-3.5, 3.5)
    ax.set_xlabel("X", fontsize=9)
    ax.set_ylabel("Y", fontsize=9)
    ax.set_title(title, fontsize=10)
    ax.axis("off")


fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)

times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]

# zip() stops at its shortest input, so this draws one panel per recorded
# snapshot; the final "target" panel is filled in separately below.
for ax, traj, label, c in zip(axes, trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    if isinstance(label, float):
        panel_title = f"t = {label}"
    else:
        panel_title = label
    _format_panel(ax, panel_title)

# last panel: overlay ground-truth
gt = goal_dist[:1000].cpu().numpy()
target_ax = axes[-1]
target_ax.scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color="#00CC96", linewidths=0)
_format_panel(target_ax, "target")
# loss curve panel
fig2, ax2 = plt.subplots(figsize=(7, 3))

# Smooth the per-step losses with a 50-sample box filter; mode="valid" keeps
# only positions where the window fully overlaps the loss history.
window = 50
kernel = np.ones(window) / window
smoothed = np.convolve(losses, kernel, mode="valid")
ax2.plot(smoothed, linewidth=1.2, color="#636EFA")

ax2.set_xlabel("Epoch")
ax2.set_ylabel("MSE Loss")
ax2.set_title("CFM Training Loss (50-epoch moving avg)")
for side in ("top", "right"):
    ax2.spines[side].set_visible(False)

# pyplot's tight_layout acts on the current figure, which is fig2 here
# because it was created most recently.
plt.tight_layout()

trajectory_path = "cfm_trajectories_imported.png"
loss_path = "cfm_loss_imported.png"
fig.savefig(trajectory_path, dpi=130, bbox_inches="tight")
fig2.savefig(loss_path, dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories_imported.png and cfm_loss_imported.png")