{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Model Plotting\n", "\n", "Here we compare different interpolants together on the same dataset from saved models." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import math\n", "import os\n", "import time\n", "\n", "import imageio\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import ot as pot\n", "import torch\n", "import torchdyn\n", "from torchdyn.core import DEFunc, NeuralODE\n", "from torchdyn.datasets import generate_moons\n", "from torchdyn.nn import Augmenter" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# Implement some helper functions\n", "\n", "\n", "def sample_normal(n):\n", " m = torch.distributions.multivariate_normal.MultivariateNormal(torch.zeros(2), torch.eye(2))\n", " return m.sample((n,))\n", "\n", "\n", "def log_normal_density(x):\n", " m = torch.distributions.multivariate_normal.MultivariateNormal(\n", " torch.zeros(x.shape[-1]), torch.eye(x.shape[-1])\n", " )\n", " return m.log_prob(x)\n", "\n", "\n", "def eight_normal_sample(n, dim, scale=1, var=1):\n", " m = torch.distributions.multivariate_normal.MultivariateNormal(\n", " torch.zeros(dim), math.sqrt(var) * torch.eye(dim)\n", " )\n", " centers = [\n", " (1, 0),\n", " (-1, 0),\n", " (0, 1),\n", " (0, -1),\n", " (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", " (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", " (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", " (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", " ]\n", " centers = torch.tensor(centers) * scale\n", " noise = m.sample((n,))\n", " multi = torch.multinomial(torch.ones(8), n, replacement=True)\n", " data = []\n", " for i in range(n):\n", " data.append(centers[multi[i]] + noise[i])\n", " data = torch.stack(data)\n", " return data\n", "\n", "\n", "def log_8gaussian_density(x, scale=5, var=0.1):\n", " centers = [\n", " (1, 0),\n", " (-1, 0),\n", " (0, 1),\n", " (0, -1),\n", " (1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", " (1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", " (-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)),\n", " (-1.0 / np.sqrt(2), -1.0 / np.sqrt(2)),\n", " ]\n", " centers = torch.tensor(centers) * scale\n", " centers = centers.T.reshape(1, 2, 8)\n", " # calculate shifted xs [batch, centers, dims]\n", " x = (x[:, :, None] - centers).mT\n", " m = torch.distributions.multivariate_normal.MultivariateNormal(\n", " torch.zeros(x.shape[-1]), math.sqrt(var) * torch.eye(x.shape[-1])\n", " )\n", " log_probs = m.log_prob(x)\n", " log_probs = torch.logsumexp(log_probs, -1)\n", " return log_probs\n", "\n", "\n", "def sample_moons(n):\n", " x0, _ = generate_moons(n, noise=0.2)\n", " return x0 * 3 - 1\n", "\n", "\n", "def sample_8gaussians(n):\n", " return eight_normal_sample(n, 2, scale=5, var=0.1).float()\n", "\n", "\n", "class MLP(torch.nn.Module):\n", " def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n", " super().__init__()\n", " self.time_varying = time_varying\n", " if out_dim is None:\n", " out_dim = dim\n", " self.net = torch.nn.Sequential(\n", " torch.nn.Linear(dim + (1 if time_varying else 0), w),\n", " torch.nn.SELU(),\n", " torch.nn.Linear(w, w),\n", " torch.nn.SELU(),\n", " torch.nn.Linear(w, w),\n", " torch.nn.SELU(),\n", " torch.nn.Linear(w, out_dim),\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x)\n", "\n", "\n", "class MLP2(torch.nn.Module):\n", " \"\"\"Change activations for Action Matching\"\"\"\n", "\n", " def __init__(self, dim, out_dim=None, w=64, time_varying=False):\n", " super().__init__()\n", " self.time_varying = time_varying\n", " if out_dim is None:\n", " out_dim = dim\n", " self.net = torch.nn.Sequential(\n", " torch.nn.Linear(dim + (1 if time_varying else 0), w),\n", " torch.nn.ReLU(),\n", " torch.nn.Linear(w, w),\n", " torch.nn.SiLU(),\n", " torch.nn.Linear(w, w),\n", " torch.nn.SiLU(),\n", " torch.nn.Linear(w, out_dim),\n", " )\n", "\n", " def forward(self, x):\n", " return self.net(x)\n", "\n", "\n", "class GradModel(torch.nn.Module):\n", " def __init__(self, action):\n", " super().__init__()\n", " self.action = action\n", "\n", " def forward(self, x):\n", " x = x.requires_grad_(True)\n", " grad = torch.autograd.grad(torch.sum(self.action(x)), x, create_graph=True)[0]\n", " return grad[:, :-1]\n", "\n", "\n", "class torch_wrapper(torch.nn.Module):\n", " \"\"\"Wraps model to torchdyn compatible format.\"\"\"\n", "\n", " def __init__(self, model):\n", " super().__init__()\n", " self.model = model\n", "\n", " def forward(self, t, x, *args, **kwargs):\n", " return model(torch.cat([x, t.repeat(x.shape[0])[:, None]], 1))\n", "\n", "\n", "def autograd_trace(x_out, x_in, **kwargs):\n", " \"\"\"Standard brute-force means of obtaining trace of the Jacobian, O(d) calls to autograd\"\"\"\n", " trJ = 0.0\n", " for i in range(x_in.shape[1]):\n", " trJ += torch.autograd.grad(x_out[:, i].sum(), x_in, allow_unused=False, create_graph=True)[\n", " 0\n", " ][:, i]\n", " return trJ\n", "\n", "\n", "class CNF(torch.nn.Module):\n", " def __init__(self, net, trace_estimator=None, noise_dist=None):\n", " super().__init__()\n", " self.net = net\n", " self.trace_estimator = trace_estimator if trace_estimator is not None else autograd_trace\n", " self.noise_dist, self.noise = noise_dist, None\n", "\n", " def forward(self, t, x, *args, **kwargs):\n", " with torch.set_grad_enabled(True):\n", " x_in = x[:, 1:].requires_grad_(\n", " True\n", " ) # first dimension reserved to divergence propagation\n", " # the neural network will handle the data-dynamics here\n", " x_out = self.net(\n", " torch.cat([x_in, t * torch.ones(x.shape[0], 1).type_as(x_in)], dim=-1)\n", " )\n", " trJ = self.trace_estimator(x_out, x_in, noise=self.noise)\n", " return (\n", " torch.cat([-trJ[:, None], x_out], 1) + 0 * x\n", " ) # `+ 0*x` has the only purpose of connecting x[:, 0] to autograd graph" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "savedir = \"models/8gaussian-moons\"\n", "models = {\n", " \"CFM\": torch.load(f\"{savedir}/cfm_v1.pt\"),\n", " \"OT-CFM (ours)\": torch.load(f\"{savedir}/otcfm_v1.pt\"),\n", " \"SB-CFM (ours)\": torch.load(f\"{savedir}/sbcfm_v1.pt\"),\n", " \"VP-CFM\": torch.load(f\"{savedir}/stochastic_interpolant_v1.pt\"),\n", " \"Action-Matching\": torch.load(f\"{savedir}/action_matching_v1.pt\"),\n", " \"Action-Matching (Swish)\": torch.load(f\"{savedir}/action_matching_swish_v1.pt\"),\n", "}" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "tags": [] }, "outputs": [], "source": [ "w = 7\n", "points = 100j\n", "points_real = 100\n", "device = \"cpu\"\n", "Y, X = np.mgrid[-w:w:points, -w:w:points]\n", "gridpoints = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1)).type(torch.float32)\n", "points_small = 20j\n", "points_real_small = 20\n", "Y_small, X_small = np.mgrid[-w:w:points_small, -w:w:points_small]\n", "gridpoints_small = torch.tensor(np.stack([X_small.flatten(), Y_small.flatten()], axis=1)).type(\n", " torch.float32\n", ")\n", "\n", "torch.manual_seed(42)\n", "sample = sample_8gaussians(1024)\n", "ts = torch.linspace(0, 1, 101)\n", "trajs = {}\n", "for name, model in models.items():\n", " nde = NeuralODE(DEFunc(torch_wrapper(model)), solver=\"euler\").to(device)\n", " # with torch.no_grad():\n", " traj = nde.trajectory(sample.to(device), t_span=ts.to(device)).detach().cpu().numpy()\n", " trajs[name] = traj\n", "for i, t in enumerate(ts):\n", " names = [\n", " \"CFM\",\n", " \"Action-Matching\",\n", " \"Action-Matching (Swish)\",\n", " \"VP-CFM\",\n", " \"SB-CFM (ours)\",\n", " \"OT-CFM (ours)\",\n", " ]\n", " fig, axes = plt.subplots(3, len(names), figsize=(6 * len(names), 6 * 3))\n", " for axis, name in zip(axes.T, names):\n", " model = models[name]\n", " cnf = DEFunc(CNF(model))\n", " nde = NeuralODE(cnf, solver=\"euler\", sensitivity=\"adjoint\")\n", " cnf_model = torch.nn.Sequential(Augmenter(augment_idx=1, augment_dims=1), nde)\n", " with torch.no_grad():\n", " if t > 0:\n", " aug_traj = (\n", " cnf_model[1]\n", " .to(device)\n", " .trajectory(\n", " Augmenter(1, 1)(gridpoints).to(device),\n", " t_span=torch.linspace(t, 0, 201).to(device),\n", " )\n", " )[-1].cpu()\n", " log_probs = log_8gaussian_density(aug_traj[:, 1:]) - aug_traj[:, 0]\n", " else:\n", " log_probs = log_8gaussian_density(gridpoints)\n", " log_probs = log_probs.reshape(Y.shape)\n", " ax = axis[0]\n", " ax.pcolormesh(X, Y, torch.exp(log_probs), vmax=1)\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlim(-w, w)\n", " ax.set_ylim(-w, w)\n", " ax.set_title(f\"{name}\", fontsize=30)\n", " # Quiver plot\n", " # with torch.no_grad():\n", " out = model(\n", " torch.cat(\n", " [gridpoints_small, torch.ones((gridpoints_small.shape[0], 1)) * t], dim=1\n", " ).to(device)\n", " )\n", " out = out.reshape([points_real_small, points_real_small, 2]).cpu().detach().numpy()\n", " ax = axis[1]\n", " ax.quiver(\n", " X_small,\n", " Y_small,\n", " out[:, :, 0],\n", " out[:, :, 1],\n", " np.sqrt(np.sum(out**2, axis=-1)),\n", " cmap=\"coolwarm\",\n", " scale=50.0,\n", " width=0.015,\n", " pivot=\"mid\",\n", " )\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlim(-w, w)\n", "\n", " ax = axis[2]\n", " sample_traj = trajs[name]\n", " ax.scatter(sample_traj[0, :, 0], sample_traj[0, :, 1], s=10, alpha=0.8, c=\"black\")\n", " ax.scatter(sample_traj[:i, :, 0], sample_traj[:i, :, 1], s=0.2, alpha=0.2, c=\"olive\")\n", " ax.scatter(sample_traj[i, :, 0], sample_traj[i, :, 1], s=4, alpha=1, c=\"blue\")\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlim(-w, w)\n", " ax.set_ylim(-w, w)\n", " plt.suptitle(f\"8gaussians to Moons T={t:0.2f}\", fontsize=40)\n", " os.makedirs(\"figures/trajectory/v3/\", exist_ok=True)\n", " plt.savefig(f\"figures/trajectory/v3/{t:0.2f}.png\", dpi=40)\n", " plt.close()" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_10858/4293379450.py:6: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n", " image = imageio.imread(filename)\n" ] } ], "source": [ "gif_name = \"8gaussians-to-moons\"\n", "with imageio.get_writer(f\"{gif_name}.gif\", mode=\"I\") as writer:\n", " for filename in [f\"figures/trajectory/v3/{t:0.2f}.png\" for t in ts] + [\n", " f\"figures/trajectory/v3/{ts[-1].item():0.2f}.png\"\n", " ] * 10:\n", " image = imageio.imread(filename)\n", " writer.append_data(image)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "models = {\n", " \"CFM\": torch.load(\"./models/8gaussian-moons/cfm_v1.pt\"),\n", " \"OT-CFM (ours)\": torch.load(\"models/8gaussian-moons/otcfm_v1.pt\"),\n", " \"SB-CFM (ours)\": torch.load(\"models/8gaussian-moons/sbcfm_v1.pt\"),\n", " \"VP-CFM\": torch.load(\"models/8gaussian-moons/stochastic_interpolant_v1.pt\"),\n", " # \"FM\": torch.load(\"models/8gaussian-moons/flow_matching_v1.pt\"),\n", " # \"VP-SDE\": torch.load(\"models/8gaussian-moons/vp_flow_v1.pt\"),\n", " \"Action-Matching\": torch.load(\"models/8gaussian-moons/action_matching_v1.pt\"),\n", " \"Action-Matching (Swish)\": torch.load(\"models/8gaussian-moons/action_matching_swish_v1.pt\"),\n", "}" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "w = 7\n", "points = 100j\n", "points_real = 100\n", "device = \"cpu\"\n", "Y, X = np.mgrid[-w:w:points, -w:w:points]\n", "gridpoints = torch.tensor(np.stack([X.flatten(), Y.flatten()], axis=1)).type(torch.float32)\n", "points_small = 20j\n", "points_real_small = 20\n", "Y_small, X_small = np.mgrid[-w:w:points_small, -w:w:points_small]\n", "gridpoints_small = torch.tensor(np.stack([X_small.flatten(), Y_small.flatten()], axis=1)).type(\n", " torch.float32\n", ")\n", "\n", "torch.manual_seed(42)\n", "sample = sample_normal(1024)\n", "ts = torch.linspace(0, 1, 101)\n", "trajs = {}\n", "for name, model in models.items():\n", " nde = NeuralODE(DEFunc(torch_wrapper(model)), solver=\"euler\").to(device)\n", " # with torch.no_grad():\n", " traj = nde.trajectory(sample.to(device), t_span=ts.to(device)).detach().cpu().numpy()\n", " trajs[name] = traj\n", "names = [\n", " # \"VP-SDE\",\n", " # \"FM\",\n", " \"CFM\",\n", " \"Action-Matching\",\n", " \"Action-Matching (Swish)\",\n", " \"VP-CFM\",\n", " \"SB-CFM (ours)\",\n", " \"OT-CFM (ours)\",\n", "]\n", "for i, t in enumerate(ts):\n", " fig, axes = plt.subplots(3, len(names), figsize=(len(names) * 6, len(names) * 3))\n", "\n", " for axis, name in zip(axes.T, names):\n", " model = models[name]\n", " cnf = DEFunc(CNF(model))\n", " nde = NeuralODE(cnf, solver=\"euler\")\n", " cnf_model = torch.nn.Sequential(Augmenter(augment_idx=1, augment_dims=1), nde)\n", " with torch.no_grad():\n", " if t > 0:\n", " aug_traj = (\n", " cnf_model[1]\n", " .to(device)\n", " .trajectory(\n", " Augmenter(1, 1)(gridpoints).to(device),\n", " t_span=torch.linspace(t, 0, 201).to(device),\n", " )\n", " )[-1].cpu()\n", " log_probs = log_normal_density(aug_traj[:, 1:]) - aug_traj[:, 0]\n", " else:\n", " log_probs = log_normal_density(gridpoints)\n", " log_probs = log_probs.reshape(Y.shape)\n", " ax = axis[0]\n", " ax.pcolormesh(X, Y, torch.exp(log_probs))\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlim(-w, w)\n", " ax.set_ylim(-w, w)\n", " ax.set_title(f\"{name}\", fontsize=30)\n", " # Quiver plot\n", " # with torch.no_grad():\n", " out = model(\n", " torch.cat(\n", " [gridpoints_small, torch.ones((gridpoints_small.shape[0], 1)) * t], dim=1\n", " ).to(device)\n", " )\n", " out = out.reshape([points_real_small, points_real_small, 2]).cpu().detach().numpy()\n", " ax = axis[1]\n", " ax.quiver(\n", " X_small,\n", " Y_small,\n", " out[:, :, 0],\n", " out[:, :, 1],\n", " np.sqrt(np.sum(out**2, axis=-1)),\n", " cmap=\"coolwarm\",\n", " scale=50.0,\n", " width=0.015,\n", " pivot=\"mid\",\n", " )\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlim(-w, w)\n", "\n", " ax = axis[2]\n", " sample_traj = trajs[name]\n", " ax.scatter(sample_traj[0, :, 0], sample_traj[0, :, 1], s=10, alpha=0.8, c=\"black\")\n", " ax.scatter(sample_traj[:i, :, 0], sample_traj[:i, :, 1], s=0.2, alpha=0.2, c=\"olive\")\n", " ax.scatter(sample_traj[i, :, 0], sample_traj[i, :, 1], s=4, alpha=1, c=\"blue\")\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", " ax.set_xlim(-w, w)\n", " ax.set_ylim(-w, w)\n", " plt.suptitle(f\"Gaussian to Moons T={t:0.2f}\", fontsize=40)\n", " os.makedirs(\"figures/trajectory2/v3/\", exist_ok=True)\n", " plt.savefig(f\"figures/trajectory2/v3/{t:0.2f}.png\", dpi=40)\n", " plt.close()" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/tmp/ipykernel_10858/1452409276.py:7: DeprecationWarning: Starting with ImageIO v3 the behavior of this function will switch to that of iio.v3.imread. To keep the current behavior (and make this warning disappear) use `import imageio.v2 as imageio` or call `imageio.v2.imread` directly.\n", " image = imageio.imread(filename)\n" ] } ], "source": [ "gif_name = \"gaussians-to-moons\"\n", "ts = torch.linspace(0, 1, 101)\n", "with imageio.get_writer(f\"{gif_name}.gif\", mode=\"I\") as writer:\n", " for filename in [f\"figures/trajectory2/v3/{t:0.2f}.png\" for t in ts] + [\n", " f\"figures/trajectory2/v3/{ts[-1].item():0.2f}.png\"\n", " ] * 10:\n", " image = imageio.imread(filename)\n", " writer.append_data(image)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "torchcfm2", "language": "python", "name": "torchcfm2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }