# flow-matching / test / test.py
# Author: sabertoaster
# Uploaded via huggingface_hub (commit 4edc9aa, verified)
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
# --- 1. The Neural Vector Field ---
# A simple MLP that takes (x, y, t) and outputs velocity (vx, vy)
class VectorField(nn.Module):
    """Time-dependent velocity field v(x, t) parameterized by a small MLP."""

    def __init__(self):
        super().__init__()
        # Three linear layers with Tanh activations: (x, y, t) -> (vx, vy).
        layers = [
            nn.Linear(3, 64), nn.Tanh(),
            nn.Linear(64, 64), nn.Tanh(),
            nn.Linear(64, 2),  # Output: (vx, vy)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x, t):
        """Predict velocities for positions x of shape (B, 2) at times t.

        Accepts t as a scalar tensor, a (B,) vector, or a (B, 1) column;
        it is reshaped to (B, 1) before being concatenated with x.
        """
        if t.dim() == 0:
            t = t.expand(x.shape[0], 1)
        elif t.dim() == 1:
            t = t.view(-1, 1)
        return self.net(torch.cat([x, t], dim=1))
# --- 2. Setup Data and Training ---
# Instantiate the velocity network and its optimizer (Adam, lr 1e-3).
model = VectorField()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
# Target distribution: Two Gaussian blobs centered at (-2, -2) and (2, 2)
def sample_data(batch_size):
    """Draw `batch_size` samples from the two-blob target distribution."""
    # Choose one of the two blob centers uniformly per sample,
    # then perturb with isotropic Gaussian noise of std 0.5.
    means = torch.tensor([[-2.0, -2.0], [2.0, 2.0]])
    choice = torch.randint(0, 2, (batch_size,))
    jitter = 0.5 * torch.randn(batch_size, 2)
    return means[choice] + jitter
# Source distribution: Standard Gaussian centered at (0, 0)
def sample_source(batch_size):
    """Draw `batch_size` samples from the 2-D standard normal base distribution."""
    return torch.randn((batch_size, 2))
# --- 3. The Flow Matching Training Loop ---
# Regress the network onto the conditional velocity of straight-line paths
# between noise samples x0 and data samples x1.
print("Training Flow Matching Model...")
for step in range(2000):
    n = 256
    # Path endpoints: base noise and target data.
    x0 = sample_source(n)  # Noise
    x1 = sample_data(n)    # Data
    # Random interpolation times t ~ U[0, 1], one per sample.
    t = torch.rand(n, 1)
    # Position on the linear path: x_t = (1 - t) * x0 + t * x1.
    x_t = (1 - t) * x0 + t * x1
    # Conditional target velocity along the straight path: u_t = x1 - x0.
    target_v = x1 - x0
    pred_v = model(x_t, t)
    # Flow-matching objective: MSE between predicted and target velocity.
    loss = ((pred_v - target_v) ** 2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if step % 500 == 0:
        print(f"Step {step}: Loss = {loss.item():.4f}")
# --- 4. Inference: Solving the ODE ---
# Generate samples by integrating dx/dt = v(x, t) from t=0 to t=1
# with a fixed-step Euler solver (100 steps of size 0.01).
print("\nSampling (solving ODE)...")
with torch.no_grad():
    x = sample_source(1000)  # Start from base-distribution noise
    dt = 0.01
    for t_step in np.arange(0, 1, dt):
        # Broadcast the current time to every sample in the batch.
        t_tensor = torch.full((x.shape[0], 1), t_step)
        velocity = model(x, t_tensor)
        x = x + velocity * dt  # Euler update
# Visualization of the generated samples
final_samples = x.numpy()
plt.figure(figsize=(6, 6))
plt.scatter(final_samples[:, 0], final_samples[:, 1], s=10, alpha=0.6, label="Generated")
plt.title("Flow Matching Output (Approx. Data Dist.)")
# Fix: the scatter's label="Generated" was dead code without a legend call.
plt.legend()
plt.grid(True)
plt.tight_layout()
plt.savefig("flow_matching_output.png")
plt.close()