# Uploaded by sabertoaster using huggingface_hub ("Upload folder using huggingface_hub"), commit 4edc9aa (verified).
# %%
# =============================================================================
# CFM 2-D Toy Experiment — self-contained, no src path hacks
# =============================================================================
# Architecture contract (as implemented below):
# Decoder.forward(x, mu, t)
# x : (B, feat_dim, L)
# mu : (B, feat_dim, L)
# t : (B,) <-- scalar per sample, NOT (B,1,1)
# => out : (B, feat_dim, L)
#
# CFM.compute_loss(x1, mu)
#   x1 : (B, feat_dim, L)
#   mu : (B, feat_dim, L)
# Inside compute_loss, t is sampled as (B, 1, 1) so it broadcasts over the
# interpolant, and it is reshaped to (B,) before being passed to the estimator.
# Decoder.forward also normalises t itself, so (B,), (B,1) and (B,1,1) are
# all accepted.
# =============================================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from types import SimpleNamespace
from typing import List, Optional
from abc import ABC, abstractmethod
# ── helpers ──────────────────────────────────────────────────────────────────
def sinusoidal_pos_emb(t: torch.Tensor, dim: int) -> torch.Tensor:
    """Embed one scalar per sample with sine/cosine features.

    Args:
        t: 1-D tensor of shape (B,), one scalar (e.g. a timestep) per sample.
        dim: requested embedding width.

    Returns:
        Tensor of shape (B, dim): ``dim // 2`` sine features followed by
        ``dim // 2`` cosine features. When ``dim`` is odd a trailing zero
        column is appended so the width always equals ``dim``.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/10000. For dim in {2, 3}
    # (half == 1) the original denominator would be 0 and yield NaN, so clamp
    # it to 1 (a single frequency of 1).
    denom = max(half - 1, 1)
    freqs = torch.exp(-torch.arange(half, device=t.device) * (np.log(10000) / denom))
    args = t[:, None] * freqs[None]  # (B, half)
    emb = torch.cat([args.sin(), args.cos()], dim=-1)  # (B, 2 * half)
    if dim % 2:
        # keep the promised width so downstream Linear(dim, ...) layers line up
        emb = F.pad(emb, (0, 1))
    return emb
class SinusoidalPosEmb(nn.Module):
    """Module wrapper around ``sinusoidal_pos_emb``.

    Timesteps shaped (B,), (B, 1) or (B, 1, 1) are flattened to (B,) and
    embedded with ``sinusoidal_pos_emb(t, dim)``.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim

    def forward(self, t: torch.Tensor) -> torch.Tensor:
        batch = t.shape[0]
        # drop trailing singleton axes, e.g. (B, 1, 1) -> (B,)
        flat_t = t.view(batch)
        return sinusoidal_pos_emb(flat_t, self.dim)
# ── MLP block ─────────────────────────────────────────────────────────────────
# %%
class MLP(nn.Module):
    """Pointwise MLP over (B, C, L) inputs with additive time conditioning.

    Every layer is a kernel-size-1 ``Conv1d``, i.e. the same linear map applied
    at each position along L. The time embedding is projected to one bias per
    hidden channel and added after the first layer.
    """

    def __init__(self, in_c, hidden_c, out_c, time_emb_dim):
        super().__init__()

        def conv_relu(c_in, c_out):
            # kernel_size=1 -> per-position linear layer over channels
            return nn.Sequential(nn.Conv1d(c_in, c_out, 1), nn.ReLU())

        # Registration order is kept fixed so state_dict keys and seeded
        # parameter init do not change.
        self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
        self.net1 = conv_relu(in_c, hidden_c)
        self.net2 = conv_relu(hidden_c, hidden_c)
        self.net3 = conv_relu(hidden_c, hidden_c)
        self.out = nn.Conv1d(hidden_c, out_c, 1)

    def forward(self, x, time_emb):
        # x: (B, in_c, L); time_emb: (B, time_emb_dim)
        hidden = self.net1(x)
        # (B, hidden_c) -> (B, hidden_c, 1) so it broadcasts over L
        hidden = hidden + self.time_net(time_emb)[..., None]
        for block in (self.net2, self.net3):
            hidden = block(hidden)
        return self.out(hidden)
# class MLP(nn.Module):
# def __init__(self, in_c: int, hidden_c: int, out_c: int, time_emb_dim: int):
# super().__init__()
# self.time_net = nn.Sequential(nn.Linear(time_emb_dim, hidden_c), nn.Mish())
# self.net1 = nn.Sequential(nn.Linear(in_c, hidden_c), nn.ReLU())
# self.net2 = nn.Linear(hidden_c, out_c)
# def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
# # x : (B, in_c, L)
# # time_emb : (B, time_emb_dim)
# x_t = x.transpose(1, 2) # (B, L, in_c) for Linear
# out = self.net1(x_t) # (B, L, hidden_c)
# out = out + self.time_net(time_emb).unsqueeze(1) # broadcast over L
# out = self.net2(out) # (B, L, out_c)
# return out.transpose(1, 2) # (B, out_c, L)
# %%
# ── Decoder ───────────────────────────────────────────────────────────────────
class Decoder(nn.Module):
    """
    Lightweight MLP velocity estimator for toy 2-D flow-matching.

    Tensor contract
    ---------------
    forward(x, mu, t) -> vel
        x   : (B, in_c, L)
        mu  : (B, in_c, L)   # concatenated with x along channels, so x and mu
                             # together must have 2*in_c channels
        t   : (B,) | (B,1) | (B,1,1), or a single-element tensor that is
              shared by (broadcast to) every sample in the batch
        vel : (B, out_c, L)

    ``cond_dim`` and ``forward``'s ``cond`` argument are accepted for interface
    compatibility but are not used.
    """

    def __init__(
        self,
        in_c: int = 2,
        hidden_dim: int = 128,
        out_c: int = 2,
        time_emb_dim: int = 64,
        cond_dim: int = 0,
    ):
        super().__init__()
        self.time_emb = SinusoidalPosEmb(time_emb_dim)
        self.time_mlp = nn.Sequential(
            nn.Linear(time_emb_dim, time_emb_dim),
        )
        # concat(x, mu) along channel dim -> 2*in_c channels
        self.net = MLP(
            in_c=in_c * 2, hidden_c=hidden_dim, out_c=out_c, time_emb_dim=time_emb_dim
        )
        self._init_weights()

    def _init_weights(self):
        # Small N(0, 0.02) init for Linear layers only (the time-embedding
        # path); the Conv1d layers inside MLP keep PyTorch's default init.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0.0, 0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        mu: torch.Tensor,
        t: torch.Tensor,
        cond=None,
    ) -> torch.Tensor:
        batch = x.shape[0]
        if t.numel() == 1:
            # one shared time for the whole batch (e.g. a 0-d tensor from an
            # ODE-solver loop) -> one copy per sample
            t_flat = t.reshape(1).expand(batch)
        else:
            # normalise (B,), (B,1), (B,1,1) -> (B,)
            t_flat = t.reshape(batch)
        t_emb = self.time_mlp(self.time_emb(t_flat))  # (B, time_emb_dim)
        xmu = torch.cat([x, mu], dim=1)  # (B, 2*in_c, L)
        return self.net(xmu, t_emb)  # (B, out_c, L)
# -- SourceGenerator
class SourceGenerator(nn.Module):
    """Predict a per-position diagonal Gaussian source distribution from ``mu``.

    A two-layer pointwise (kernel-size-1 ``Conv1d``) network maps ``mu`` of
    shape (B, feat_dim, L) to 2*feat_dim channels. These are split into a mean
    and a log-variance, each of shape (B, feat_dim, L).
    """

    def __init__(self, feat_dim: int, hidden_dim: int = 64):
        super().__init__()
        layers = [
            nn.Conv1d(feat_dim, hidden_dim, 1),
            nn.Mish(),
            # first feat_dim channels = mean, last feat_dim = log-variance
            nn.Conv1d(hidden_dim, 2 * feat_dim, 1),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, mu: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        stats = self.net(mu)  # (B, 2*feat_dim, L)
        split = stats.shape[1] // 2
        return stats[:, :split], stats[:, split:]
# ── BASECFM ───────────────────────────────────────────────────────────────────
class BASECFM(nn.Module, ABC):
    """Base class for conditional flow matching along the OT-CFM path.

    After calling ``__init__``, subclasses must assign:

    * ``estimator(x, mu, t)``: velocity prediction with the same shape as
      ``x``; ``t`` is always passed as a (B,) tensor.
    * ``src_gen(mu)``: returns ``(mean_c, logvar_c)`` of a diagonal Gaussian
      source distribution. Source samples are
      ``mean_c + exp(0.5 * logvar_c) * eps`` with ``eps ~ N(0, I)``.

    Args:
        feat_dim: number of feature channels.
        cfm_params: object exposing ``sigma_min``, the weight the source sample
            keeps at t = 1 in the interpolant used by ``compute_loss``.
    """

    def __init__(self, feat_dim: int, cfm_params):
        super().__init__()
        self.feat_dim = feat_dim
        self.sigma_min = cfm_params.sigma_min
        self.estimator: Optional[nn.Module] = None
        self.src_gen: Optional[nn.Module] = None

    # ---- inference -----------------------------------------------------------
    @torch.inference_mode()
    def forward(
        self,
        mu: torch.Tensor,  # (B, feat_dim, L)
        n_timesteps: int,
        temperature: float = 1.0,
    ) -> torch.Tensor:
        """Sample by integrating the learned ODE from the learned source.

        Args:
            mu: conditioning tensor, (B, feat_dim, L).
            n_timesteps: number of Euler steps from t = 0 to t = 1.
            temperature: scale on the source standard deviation; 0 starts
                every trajectory at the source mean.

        Returns:
            The state at t = 1, with the same shape as the source sample.
        """
        # src_gen returns (mean, log-variance). Draw z the same way
        # compute_loss does, with the stochastic part scaled by temperature.
        mean_c, logvar_c = self.src_gen(mu)
        eps = torch.randn_like(mean_c)
        z = mean_c + temperature * torch.exp(0.5 * logvar_c) * eps
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device)
        return self.solve_euler(z, t_span, mu)

    def solve_euler(
        self,
        x: torch.Tensor,  # (B, feat_dim, L)
        t_span: torch.Tensor,  # (n_steps+1,)
        mu: torch.Tensor,  # (B, feat_dim, L)
    ) -> torch.Tensor:
        """Explicit Euler integration of dx/dt = estimator(x, mu, t) over t_span."""
        if len(t_span) < 2:
            # zero steps requested: nothing to integrate
            return x
        t = t_span[0]
        dt = t_span[1] - t_span[0]
        batch = x.shape[0]
        for step in range(1, len(t_span)):
            # expand() takes only sizes; t already lives on t_span's device
            t_batch = t.expand(batch)  # (B,)
            dphi_dt = self.estimator(x, mu, t_batch)
            x = x + dt * dphi_dt
            t = t + dt
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t
        return x

    # ---- training ------------------------------------------------------------
    def compute_loss(
        self,
        x1: torch.Tensor,  # (B, feat_dim, L)
        mu: torch.Tensor,  # (B, feat_dim, L)
        lambda_var: float = 1,  # Hyperparameters from the paper
        lambda_align: float = 0,
    ) -> tuple:
        """Flow-matching loss with source-variance and alignment regularisers.

        Returns:
            ``(loss_total, loss_dict)``. ``loss_total`` is the scalar tensor
            ``loss_fm + lambda_var * loss_var + lambda_align * loss_align``.
            ``loss_dict`` maps ``"fm"``, ``"var"`` and ``"align"`` to Python
            floats for logging.
        """
        B = x1.shape[0]
        # one time per sample, shaped (B,1,1) so it broadcasts over (C, L)
        t = torch.rand(B, 1, 1, device=mu.device, dtype=mu.dtype)
        # reparameterised sample from the learned source N(mean_c, sigma_c^2)
        mean_c, logvar_c = self.src_gen(mu)  # (B, C, L) each
        eps = torch.randn_like(mean_c)
        z = mean_c + torch.exp(0.5 * logvar_c) * eps
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1  # interpolant
        u = x1 - (1 - self.sigma_min) * z  # target velocity dy/dt
        # estimator expects t as (B,)
        t_batch = t.reshape(B)
        pred = self.estimator(y, mu, t_batch)
        # 1. Standard Flow Matching Loss
        loss_fm = F.mse_loss(pred, u)
        # 2. Variance Regularization Loss [Eq. 9 in paper]
        # D_KL( N(mu_c, sigma_c^2) || N(mu_c, I) ) = 0.5 * (sigma^2 - log(sigma^2) - 1)
        loss_var = 0.5 * (torch.exp(logvar_c) - logvar_c - 1).mean()
        # 3. Cosine Alignment Loss [Eq. 10 in paper]
        sim = F.cosine_similarity(z.flatten(1), x1.flatten(1), dim=1)
        loss_align = (1.0 - sim).mean()
        # 4. Total Loss [Eq. 11 in paper]
        loss_total = loss_fm + lambda_var * loss_var + lambda_align * loss_align
        # Return total loss, and a dictionary for logging
        loss_dict = {
            "fm": loss_fm.item(),
            "var": loss_var.item(),
            "align": loss_align.item(),
        }
        return loss_total, loss_dict
class CFM(BASECFM):
    """Concrete CFM used by the toy experiment.

    Registers, in this order (which fixes seeded parameter init):

    * a ``Decoder`` velocity estimator with ``in_c = out_c = feat_dim`` and the
      remaining kwargs taken from ``decoder_params``;
    * ``label_emb``, a learnable ``nn.Embedding`` with one ``feat_dim`` vector
      per class;
    * a ``SourceGenerator`` for the source distribution.
    """

    def __init__(
        self, feat_dim: int, cfm_params, decoder_params: dict, num_classes: int = 8
    ):
        super().__init__(feat_dim=feat_dim, cfm_params=cfm_params)
        channel_kwargs = {"in_c": feat_dim, "out_c": feat_dim}
        # duplicate keys between the two mappings raise TypeError at the call
        self.estimator = Decoder(**channel_kwargs, **decoder_params)
        self.label_emb = nn.Embedding(num_embeddings=num_classes, embedding_dim=feat_dim)
        self.src_gen = SourceGenerator(feat_dim=feat_dim)
# %%
# =============================================================================
# Experiment: Gaussian -> 8-Gaussians
# =============================================================================
# Seed both RNGs: numpy drives the synthetic dataset below; torch drives the
# parameter init and the training batches/time samples later in the file.
np.random.seed(42)
torch.manual_seed(42)
# ---- GPU setup ------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
n_samples = 4000
scale = 4.0
# 8 cluster centres evenly spaced on a circle of radius `scale`
centers = np.array(
    [
        (np.cos(t) * scale, np.sin(t) * scale)
        for t in np.linspace(0, 2 * np.pi, 8, endpoint=False)
    ]
)
# each sample picks one of the 8 clusters uniformly at random, then adds
# isotropic Gaussian noise with std 0.4; `assignments` doubles as the class label
assignments = np.random.randint(0, 8, size=n_samples)
gaussians_x = centers[assignments] + np.random.randn(n_samples, 2) * 0.4
target_tensor = torch.tensor(gaussians_x, dtype=torch.float32, device=device)
# per-axis standardisation to zero mean / unit std; these are the training targets
goal_dist = (target_tensor - target_tensor.mean(0)) / target_tensor.std(0)
# ---- build model ------------------------------------------------------------
# Only sigma_min is read by BASECFM; "solver" is carried along but not read in
# the code visible here.
cfm_params = SimpleNamespace(sigma_min=1e-4, solver="euler")
decoder_params = {"hidden_dim": 256, "time_emb_dim": 128, "cond_dim": 0}
model = CFM(
    feat_dim=2,
    cfm_params=cfm_params,
    decoder_params=decoder_params,
).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# ---- training loop ----------------------------------------------------------
epochs, batch_size = 3000, 512
losses = []
# Cluster labels as a device tensor, built once. Each step is then a device-side
# gather instead of numpy fancy-indexing with a torch tensor plus a fresh
# host->device copy of the labels.
assignments_t = torch.as_tensor(assignments, dtype=torch.long, device=device)
model.train()
for epoch in range(epochs):
    # batch indices drawn from torch's default (CPU) generator, then moved over
    idx = torch.randint(0, n_samples, (batch_size,))
    idx_dev = idx.to(device)
    x1 = goal_dist[idx_dev].unsqueeze(-1)  # (B, 2, 1)
    # Conditional -> cluster embedding conditioning
    labels = assignments_t[idx_dev]
    mu = model.label_emb(labels).unsqueeze(-1)  # (B, 2, 1)
    loss, loss_dict = model.compute_loss(x1, mu)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    loss_value = loss.item()  # single host sync per step
    losses.append(loss_value)
    if (epoch + 1) % 1000 == 0:
        print(
            f"Epoch {epoch+1:5d}  loss={loss_value:.5f} | "
            f"FM={loss_dict['fm']:.5f} | "
            f"Var={loss_dict['var']:.5f} | "
            f"Align={loss_dict['align']:.5f}"
        )
# %%
# ---- inference -------------------------------------------------------------
model.eval()
n_eval = 1000
n_classes = model.label_emb.num_embeddings
# Cycle through the labels so each class gets n_eval // n_classes samples, or
# one more. repeat_interleave(n_eval // 8 + 1)[:n_eval] left the last class short.
eval_labels = torch.arange(n_eval, device=device) % n_classes
steps = 100
t_span = torch.linspace(0, 1, steps + 1, device=device)
snap_at = {20, 40, 60, 80, 100}  # steps after which a snapshot is stored
trajectories = []
with torch.no_grad():
    mu_eval = model.label_emb(eval_labels).unsqueeze(-1)  # (n_eval, 2, 1)
    # Start from the learned conditional source, exactly as compute_loss draws
    # z during training. Starting from N(0, I) would put samples off the
    # interpolation paths the estimator was trained on.
    mean_c, logvar_c = model.src_gen(mu_eval)
    x = mean_c + torch.exp(0.5 * logvar_c) * torch.randn_like(mean_c)
    trajectories.append(x.squeeze(-1).cpu().numpy().copy())
    t = t_span[0]
    dt = t_span[1] - t_span[0]
    for step in range(1, len(t_span)):
        t_batch = t.expand(n_eval)
        dphi_dt = model.estimator(x, mu_eval, t_batch)
        x = x + dt * dphi_dt
        t = t + dt
        if step < len(t_span) - 1:
            dt = t_span[step + 1] - t
        if step in snap_at:
            trajectories.append(x.squeeze(-1).cpu().numpy().copy())
print(x.max(), " -- ", x.min())
# ---- plot ------------------------------------------------------------------
fig, axes = plt.subplots(1, 7, figsize=(21, 3))
fig.suptitle(
    "OT-CFM: Gaussian → 8 Gaussians (conditional on cluster label)",
    fontsize=13,
    y=1.04,
)
times = [0.0, 0.2, 0.4, 0.6, 0.8, 1.0, "target"]
colors = ["#636EFA", "#7A89FB", "#9BA4FC", "#BCBFFD", "#DDDAFE", "#EF553B", "#00CC96"]
lim = (-3.5, 3.5)
# One panel per stored snapshot. zip() stops at its shortest input; the last
# panel also receives the ground-truth overlay drawn below.
for ax, traj, label, c in zip(axes, trajectories, times, colors):
    ax.scatter(traj[:, 0], traj[:, 1], s=4, alpha=0.6, color=c, linewidths=0)
    ax.set_xlim(*lim)
    ax.set_ylim(*lim)
    ax.set_title(f"t = {label}" if isinstance(label, float) else label, fontsize=10)
    # axes are hidden, so x/y axis labels would never be shown
    ax.axis("off")
# last panel: overlay ground-truth
gt = goal_dist[:1000].cpu().numpy()
axes[-1].scatter(gt[:, 0], gt[:, 1], s=4, alpha=0.3, color="#00CC96", linewidths=0)
axes[-1].set_xlim(*lim)
axes[-1].set_ylim(*lim)
axes[-1].set_title("target", fontsize=10)
axes[-1].axis("off")
# loss curve panel; shrink the window on short runs so "valid" convolution
# still yields a sensible curve
window = max(1, min(50, len(losses)))
fig2, ax2 = plt.subplots(figsize=(7, 3))
ax2.plot(
    np.convolve(losses, np.ones(window) / window, mode="valid"),
    linewidth=1.2,
    color="#636EFA",
)
ax2.set_xlabel("Epoch")
# `losses` holds compute_loss's total (fm + lambda_var*var + lambda_align*align),
# not the plain MSE term
ax2.set_ylabel("Total loss")
ax2.set_title(f"CFM Training Loss ({window}-epoch moving avg)")
ax2.spines[["top", "right"]].set_visible(False)
plt.tight_layout()
fig.savefig("cfm_trajectories.png", dpi=130, bbox_inches="tight")
fig2.savefig("cfm_loss.png", dpi=130, bbox_inches="tight")
print("Saved cfm_trajectories.png and cfm_loss.png")
# %%
from torchinfo import summary

# Layer/parameter summary of the model as produced by torchinfo.
model_summary = summary(model)
print(model_summary)
# %%
# Range of the standardised targets (compare with the ±3.5 plot limits).
data_max, data_min = goal_dist.max(), goal_dist.min()
print(data_max, data_min)
# %%