xiangzai's picture
Add files using upload-large-folder tool
3e4f775 verified
import os
from typing import Union
import matplotlib.pyplot as plt
import numpy as np
import scprep
import torch
def plot_scatter(obs, model, title="fig", wandb_logger=None):
    """Scatter all observations, coloured by time index, and save the figure.

    Args:
        obs: Tensor whose shape unpacks as ``(batch, time, dim)``. It is
            flattened to ``(batch * time, dim)`` before plotting.
        model: Not used by this function; kept so the signature matches the
            other plotting helpers in this file.
        title: File stem; the image is written to ``figs/{title}.png`` and
            also used as the logging key.
        wandb_logger: Optional logger exposing ``log_image(key=..., images=...)``.
    """
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    n_series, n_steps, n_feats = obs.shape
    flat_obs = obs.reshape(-1, n_feats).detach().cpu().numpy()
    # After flattening, rows cycle through the time steps of each series,
    # so the colour vector is 0..n_steps-1 repeated once per series.
    step_ids = np.tile(np.arange(n_steps), n_series)
    scprep.plot.scatter2d(flat_obs, c=step_ids, ax=ax)

    png_path = f"figs/{title}.png"
    os.makedirs("figs", exist_ok=True)
    plt.savefig(png_path)
    if wandb_logger:
        wandb_logger.log_image(key=title, images=[png_path])
    plt.close()
def plot_scatter_and_flow(obs, model, title="stream", wandb_logger=None):
    """Plot observations overlaid with the model's vector field at 7 time slices.

    The model is evaluated on a 50 x 50 spatial grid (padded 10% beyond the
    data range) at 7 evenly spaced times in ``[0, ts - 1]``. Each time slice
    gets its own subplot showing the observations and a streamplot of the
    first two output components, coloured by squared magnitude.

    Args:
        obs: Tensor whose shape unpacks as ``(batch, time, dim)``.
        model: Callable invoked as ``model(times, gridpoints)`` with ``times``
            of shape ``(N, 1)`` and ``gridpoints`` of shape ``(N, 2)``; its
            output is reshaped to ``(50, 50, 7, dim)``.
        title: File stem; the image is written to ``figs/{title}.png``.
        wandb_logger: Optional logger exposing ``log_image(key=..., images=...)``;
            the image is logged under the key ``"flow"``.
    """
    grid_n = 50  # grid points per spatial axis
    n_panels = 7  # number of time slices, one subplot each

    batch_size, ts, dim = obs.shape
    device = obs.device
    flat_obs = obs.reshape(-1, dim).detach().cpu().numpy()

    # Plotting window: the global data range, widened by 10% on each side.
    lo, hi = flat_obs.min(), flat_obs.max()
    span = hi - lo
    wmin = lo - span * 0.1
    wmax = hi + span * 0.1

    # A complex step tells np.mgrid to produce that many points, endpoints included.
    Y, X, T = np.mgrid[
        wmin : wmax : complex(0, grid_n),
        wmin : wmax : complex(0, grid_n),
        0 : ts - 1 : complex(0, n_panels),
    ]

    xy = np.stack([X.flatten(), Y.flatten()], axis=1)
    gridpoints = torch.tensor(xy, requires_grad=True, device=device).type(torch.float32)
    times = torch.tensor(T.flatten(), requires_grad=True, device=device).type(torch.float32)
    times = times[:, None]

    field = model(times, gridpoints)
    field = field.reshape([grid_n, grid_n, n_panels, dim])
    field = field.cpu().detach().numpy()

    # Stream over time
    fig, axes = plt.subplots(1, n_panels, figsize=(20, 4), sharey=True)
    axes = axes.flatten()
    point_times = np.tile(np.arange(ts), batch_size)
    panel_times = np.linspace(0, ts - 1, n_panels)
    # The spatial grid is identical for every time slice, so take slice 0.
    grid_x = X[:, :, 0]
    grid_y = Y[:, :, 0]
    for idx, (ax, t_val) in enumerate(zip(axes, panel_times)):
        scprep.plot.scatter2d(flat_obs, c=point_times, ax=ax)
        velocity = field[:, :, idx]
        ax.streamplot(
            grid_x,
            grid_y,
            velocity[:, :, 0],
            velocity[:, :, 1],
            color=np.sum(velocity**2, axis=-1),
        )
        ax.set_title(f"t = {t_val:0.2f}")

    png_path = f"figs/{title}.png"
    os.makedirs("figs", exist_ok=True)
    plt.savefig(png_path)
    plt.close()
    if wandb_logger:
        wandb_logger.log_image(key="flow", images=[png_path])
def store_trajectories(obs: Union[torch.Tensor, list], model, title="trajs", start_time=0):
    """Integrate ``model`` as a neural ODE from initial samples and save the result.

    Up to the first 2000 starting points are integrated over
    ``torch.linspace(0, ts - 1, 20 * (ts - 1))`` and the resulting array is
    written to ``figs/{title}.npy``.

    This function no longer draws anything. The previous list branch created
    a scatter plot that was never saved or closed, leaving an open pyplot
    figure behind on every call.

    Args:
        obs: Either a list of tensors (the first entry supplies the starting
            points and ``len(obs)`` is used as ``ts``), or a tensor whose
            shape unpacks as ``(batch, ts, dim)``.
        model: Vector field passed to ``torchdyn.core.NeuralODE``.
        title: File stem of the saved ``.npy`` file.
        start_time: Index along the second axis to take starting points from;
            only used when ``obs`` is a tensor.
    """
    # Imported lazily so torchdyn is only required when trajectories are computed.
    from torchdyn.core import NeuralODE

    max_starts = 2000
    if isinstance(obs, list):
        start = obs[0][:max_starts]
        ts = len(obs)
    else:
        # Unpacking also checks that the tensor is 3-dimensional.
        _, ts, _ = obs.shape
        start = obs[:max_starts, start_time, :]

    with torch.no_grad():
        node = NeuralODE(model)
        # For consistency with DSB
        t_span = torch.linspace(0, ts - 1, 20 * (ts - 1))
        traj = node.trajectory(start, t_span=t_span)
        traj = traj.cpu().detach().numpy()

    os.makedirs("figs", exist_ok=True)
    np.save(f"figs/{title}.npy", traj)
def plot_trajectory(
    obs: Union[torch.Tensor, list],
    traj: torch.Tensor,
    title="traj",
    key="traj",
    start_time=0,
    n=200,
    wandb_logger=None,
):
    """Plot observations together with precomputed trajectories and save the figure.

    The observations are scattered coloured by time index. On top of them the
    first ``n`` trajectories are drawn as faint black points, their last
    entries along axis 0 as purple crosses, and up to 20 trajectories as red
    lines.

    Args:
        obs: Either a list of tensors (one per time point, labelled by list
            position) or a tensor whose shape unpacks as ``(batch, ts, dim)``.
        traj: Array-like indexed as ``traj[step, path, coordinate]``; only
            coordinates 0 and 1 are plotted.
        title: File stem; the image is written to ``figs/{title}.png``.
        key: Key under which the image is logged.
        start_time: Not used by this function; kept for signature compatibility.
        n: Maximum number of trajectories drawn as points.
        wandb_logger: Optional logger exposing ``log_image(key=..., images=...)``.
    """
    if isinstance(obs, list):
        points = np.concatenate([xi.detach().cpu().numpy() for xi in obs], axis=0)
        colors = np.concatenate(
            [t * np.ones(xi.shape[0]) for t, xi in enumerate(obs)], axis=0
        )
    else:
        batch_size, ts, dim = obs.shape
        points = obs.reshape(-1, dim).detach().cpu().numpy()
        colors = np.tile(np.arange(ts), batch_size)

    png_path = f"figs/{title}.png"
    # Draw on an explicit axes: without ``ax=`` scprep opens its own figure,
    # which orphaned the 6x6 figure created here and ignored its size.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        scprep.plot.scatter2d(points, c=colors, ax=ax)
        ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.3, alpha=0.2, c="black", label="Flow")
        ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=6, alpha=1, c="purple", marker="x")
        # Clamp to the number of available paths so short inputs do not raise IndexError.
        for i in range(min(20, traj.shape[1])):
            ax.plot(traj[:, i, 0], traj[:, i, 1], c="red", alpha=0.5)
        os.makedirs("figs", exist_ok=True)
        fig.savefig(png_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    if wandb_logger:
        wandb_logger.log_image(key=key, images=[png_path])
def plot_paths(
    obs: Union[torch.Tensor, list],
    model,
    title="paths",
    start_time=0,
    n=200,
    wandb_logger=None,
):
    """Integrate ``model`` from initial samples and plot the paths over the observations.

    Up to ``n`` starting points are integrated with ``torchdyn.core.NeuralODE``
    over ``torch.linspace(0, ts - 1, max(20 * ts, 100))``. The observations are
    scattered coloured by time index, the integrated paths are drawn as faint
    black points, and their last entries along axis 0 as purple crosses.

    Args:
        obs: Either a list of tensors (the first entry supplies the starting
            points and ``len(obs)`` is used as ``ts``), or a tensor whose
            shape unpacks as ``(batch, ts, dim)``.
        model: Vector field passed to ``NeuralODE``.
        title: File stem; the image is written to ``figs/{title}.png``.
        start_time: Index along the second axis to take starting points from;
            only used when ``obs`` is a tensor.
        n: Maximum number of starting points / paths.
        wandb_logger: Optional logger exposing ``log_image(key=..., images=...)``;
            the image is logged under the key ``"paths"``.
    """
    # Imported lazily so torchdyn is only required when paths are computed.
    from torchdyn.core import NeuralODE

    if isinstance(obs, list):
        points = np.concatenate([xi.detach().cpu().numpy() for xi in obs], axis=0)
        colors = np.concatenate(
            [t * np.ones(xi.shape[0]) for t, xi in enumerate(obs)], axis=0
        )
        start = obs[0][:n]
        ts = len(obs)
    else:
        batch_size, ts, dim = obs.shape
        start = obs[:n, start_time, :]
        points = obs.reshape(-1, dim).detach().cpu().numpy()
        colors = np.tile(np.arange(ts), batch_size)

    # Integrate before opening a figure so a solver failure leaves nothing to clean up.
    with torch.no_grad():
        node = NeuralODE(model)
        t_span = torch.linspace(0, ts - 1, max(20 * ts, 100))
        traj = node.trajectory(start, t_span=t_span)
        traj = traj.cpu().detach().numpy()

    png_path = f"figs/{title}.png"
    # Draw on an explicit axes: without ``ax=`` scprep opens its own figure,
    # which orphaned the 6x6 figure created here and ignored its size.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        scprep.plot.scatter2d(points, c=colors, ax=ax)
        ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.3, alpha=0.2, c="black", label="Flow")
        ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=6, alpha=1, c="purple", marker="x")
        os.makedirs("figs", exist_ok=True)
        fig.savefig(png_path)
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)

    if wandb_logger:
        wandb_logger.log_image(key="paths", images=[png_path])
def plot_samples(trajs, title="samples", wandb_logger=None):
    """Save the first 100 entries of ``trajs`` as an image grid and optionally log it.

    The grid has 10 images per row, no padding, and is normalised by
    ``torchvision.utils.save_image``; it is written to ``figs/{title}.jpg``.

    Args:
        trajs: Sliceable batch of images accepted by ``save_image``.
        title: File stem of the saved ``.jpg`` file.
        wandb_logger: Optional logger exposing ``log_image(key=..., images=...)``.
            A ``PIL.UnidentifiedImageError`` raised while logging is reported
            with ``print`` and not re-raised.
    """
    import PIL
    from torchvision.utils import save_image

    batch = trajs[:100]
    jpg_path = f"figs/{title}.jpg"
    os.makedirs("figs", exist_ok=True)
    save_image(batch, fp=jpg_path, nrow=10, normalize=True, padding=0)

    if not wandb_logger:
        return
    try:
        # NOTE(review): the key is "paths", the same key plot_paths logs under;
        # confirm this is intended rather than a copy-paste leftover.
        wandb_logger.log_image(key="paths", images=[jpg_path])
    except PIL.UnidentifiedImageError:
        print(f"ERROR logging {title}")