| import os
|
| from typing import Union
|
|
|
| import matplotlib.pyplot as plt
|
| import numpy as np
|
| import scprep
|
| import torch
|
|
|
|
|
def plot_scatter(obs, model, title="fig", wandb_logger=None):
    """Scatter-plot all observations, coloured by time index, and save the image.

    Args:
        obs: Tensor whose shape unpacks as ``(batch_size, ts, dim)``. It is
            flattened to ``(batch_size * ts, dim)`` before plotting.
        model: Not used by this function; kept so the signature matches the
            other plotting helpers in this module.
        title: Base name of the output image (``figs/{title}.png``); also used
            as the key when logging the image.
        wandb_logger: Optional logger exposing
            ``log_image(key=..., images=[...])``. Skipped when falsy.
    """
    _, axis = plt.subplots(1, 1, figsize=(5, 5))

    n_batch, n_steps, n_dim = obs.shape
    flat_obs = obs.reshape(-1, n_dim).detach().cpu().numpy()
    # After flattening, the time index cycles 0..n_steps-1 once per batch item.
    time_labels = np.tile(np.arange(n_steps), n_batch)
    scprep.plot.scatter2d(flat_obs, c=time_labels, ax=axis)

    os.makedirs("figs", exist_ok=True)
    image_path = f"figs/{title}.png"
    plt.savefig(image_path)
    if wandb_logger:
        wandb_logger.log_image(key=title, images=[image_path])
    plt.close()
|
|
|
|
|
def plot_scatter_and_flow(obs, model, title="stream", wandb_logger=None):
    """Plot the observations with the model's vector field at 7 time points.

    The model is evaluated on a 50 x 50 spatial grid (spanning the data range
    plus a 10% margin on each side) at 7 evenly spaced times in
    ``[0, ts - 1]``. Each time gets one panel showing the observation scatter
    overlaid with a streamplot of the first two output components, coloured
    by squared magnitude.

    Args:
        obs: Tensor whose shape unpacks as ``(batch_size, ts, dim)``.
        model: Callable invoked as ``model(times, gridpoints)`` with ``times``
            of shape ``(N, 1)`` and ``gridpoints`` of shape ``(N, 2)``; its
            output must be reshapeable to ``(50, 50, 7, dim)``.
        title: Base name of the output image (``figs/{title}.png``).
        wandb_logger: Optional logger exposing
            ``log_image(key=..., images=[...])``; the image is logged under
            the key ``"flow"``.
    """
    n_grid = 50
    n_panels = 7

    n_batch, n_steps, n_dim = obs.shape
    target_device = obs.device
    flat_obs = obs.reshape(-1, n_dim).detach().cpu().numpy()

    # Square plotting window: global data range padded by 10% on each side.
    lo, hi = flat_obs.min(), flat_obs.max()
    margin = (hi - lo) * 0.1
    wmin = lo - margin
    wmax = hi + margin

    # An imaginary step asks np.mgrid for that many points, endpoints included.
    Y, X, T = np.mgrid[
        wmin : wmax : n_grid * 1j,
        wmin : wmax : n_grid * 1j,
        0 : n_steps - 1 : n_panels * 1j,
    ]

    xy = np.stack([X.flatten(), Y.flatten()], axis=1)
    gridpoints = torch.tensor(xy, requires_grad=True, device=target_device).type(
        torch.float32
    )
    times = torch.tensor(T.flatten(), requires_grad=True, device=target_device).type(
        torch.float32
    )
    times = times[:, None]

    field = model(times, gridpoints)
    field = field.reshape([n_grid, n_grid, n_panels, n_dim])
    field = field.cpu().detach().numpy()

    fig, axes = plt.subplots(1, n_panels, figsize=(20, 4), sharey=True)
    axes = axes.flatten()
    time_labels = np.tile(np.arange(n_steps), n_batch)
    panel_times = np.linspace(0, n_steps - 1, n_panels)
    for idx, panel in enumerate(axes):
        scprep.plot.scatter2d(flat_obs, c=time_labels, ax=panel)
        panel.streamplot(
            X[:, :, 0],
            Y[:, :, 0],
            field[:, :, idx, 0],
            field[:, :, idx, 1],
            color=np.sum(field[:, :, idx] ** 2, axis=-1),
        )
        panel.set_title(f"t = {panel_times[idx]:0.2f}")

    os.makedirs("figs", exist_ok=True)
    image_path = f"figs/{title}.png"
    plt.savefig(image_path)
    plt.close()
    if wandb_logger:
        wandb_logger.log_image(key="flow", images=[image_path])
|
|
|
|
|
def store_trajectories(obs: Union[torch.Tensor, list], model, title="trajs", start_time=0):
    """Integrate the model from the initial observations and save the result.

    Up to 2000 starting points are pushed through ``NeuralODE(model)`` over
    ``t_span = linspace(0, ts - 1, 20 * (ts - 1))`` and the resulting
    trajectory array is written to ``figs/{title}.npy``.

    This function only stores data. It deliberately draws nothing: an earlier
    version scatter-plotted list input onto a figure that was never saved or
    closed, leaking one matplotlib figure per call.

    Args:
        obs: Either a list of tensors (one per time point; the first entry
            supplies the starting points and ``ts = len(obs)``) or a tensor
            whose shape unpacks as ``(batch_size, ts, dim)``.
        model: Vector field wrapped in ``torchdyn.core.NeuralODE``.
        title: Base name of the output file (``figs/{title}.npy``).
        start_time: Index along the time axis from which starting points are
            taken. Only used for tensor input; list input always starts from
            ``obs[0]``.
    """
    max_starts = 2000
    if isinstance(obs, list):
        start = obs[0][:max_starts]
        ts = len(obs)
    else:
        # Unpacking also enforces that the tensor is 3-dimensional.
        _, ts, _ = obs.shape
        start = obs[:max_starts, start_time, :]

    # Imported lazily so the module can be loaded without torchdyn installed.
    from torchdyn.core import NeuralODE

    with torch.no_grad():
        node = NeuralODE(model)
        traj = node.trajectory(start, t_span=torch.linspace(0, ts - 1, 20 * (ts - 1)))
        traj = traj.cpu().detach().numpy()

    os.makedirs("figs", exist_ok=True)
    np.save(f"figs/{title}.npy", traj)
|
|
|
|
|
def plot_trajectory(
    obs: Union[torch.Tensor, list],
    traj: torch.Tensor,
    title="traj",
    key="traj",
    start_time=0,
    n=200,
    wandb_logger=None,
):
    """Plot observations together with precomputed trajectories and save it.

    The observations are scatter-plotted coloured by time index. On top of
    them, every stored point of the first ``n`` trajectories is drawn in
    black, their final positions are marked with purple crosses, and up to 20
    individual trajectories are traced as red lines.

    The axes are created here and passed to ``scprep.plot.scatter2d``
    explicitly, as ``plot_scatter`` does. Without ``ax`` scprep appears to
    draw on a figure of its own, which left the 6x6 figure unused and never
    closed.

    Args:
        obs: Either a list of tensors (one per time point) or a tensor whose
            shape unpacks as ``(batch_size, ts, dim)``.
        traj: Array-like indexed as ``traj[time, sample, coordinate]``; only
            coordinates 0 and 1 are drawn.
        title: Base name of the output image (``figs/{title}.png``).
        key: Key under which the image is logged.
        start_time: Unused; kept for signature compatibility.
        n: Maximum number of trajectories included in the scatter layers.
        wandb_logger: Optional logger exposing
            ``log_image(key=..., images=[...])``. Skipped when falsy.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    if isinstance(obs, list):
        data = np.concatenate([xi.detach().cpu().numpy() for xi in obs], axis=0)
        labels = np.concatenate(
            [t * np.ones(xi.shape[0]) for t, xi in enumerate(obs)], axis=0
        )
        scprep.plot.scatter2d(data, c=labels, ax=ax)
    else:
        batch_size, ts, dim = obs.shape
        flat_obs = obs.reshape(-1, dim).detach().cpu().numpy()
        # After flattening, the time index cycles 0..ts-1 once per batch item.
        tts = np.tile(np.arange(ts), batch_size)
        scprep.plot.scatter2d(flat_obs, c=tts, ax=ax)

    ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.3, alpha=0.2, c="black", label="Flow")
    ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=6, alpha=1, c="purple", marker="x")
    # Clamp so fewer than 20 available trajectories does not raise IndexError.
    for i in range(min(20, traj.shape[1])):
        ax.plot(traj[:, i, 0], traj[:, i, 1], c="red", alpha=0.5)

    os.makedirs("figs", exist_ok=True)
    image_path = f"figs/{title}.png"
    fig.savefig(image_path)
    plt.close(fig)
    if wandb_logger:
        wandb_logger.log_image(key=key, images=[image_path])
|
|
|
|
|
def plot_paths(
    obs: Union[torch.Tensor, list],
    model,
    title="paths",
    start_time=0,
    n=200,
    wandb_logger=None,
):
    """Integrate the model from the initial observations and plot the paths.

    The observations are scatter-plotted coloured by time index. Up to ``n``
    starting points are then pushed through ``NeuralODE(model)`` over
    ``t_span = linspace(0, ts - 1, max(20 * ts, 100))``. Every point along the
    resulting paths is drawn in black and the final positions are marked with
    purple crosses.

    The axes are created here and passed to ``scprep.plot.scatter2d``
    explicitly, as ``plot_scatter`` does. Without ``ax`` scprep appears to
    draw on a figure of its own, which left the 6x6 figure unused and never
    closed.

    Args:
        obs: Either a list of tensors (one per time point; the first entry
            supplies the starting points and ``ts = len(obs)``) or a tensor
            whose shape unpacks as ``(batch_size, ts, dim)``.
        model: Vector field wrapped in ``torchdyn.core.NeuralODE``.
        title: Base name of the output image (``figs/{title}.png``).
        start_time: Index along the time axis from which starting points are
            taken. Only used for tensor input; list input always starts from
            ``obs[0]``.
        n: Maximum number of starting points / paths.
        wandb_logger: Optional logger exposing
            ``log_image(key=..., images=[...])``; the image is logged under
            the key ``"paths"``.
    """
    fig, ax = plt.subplots(figsize=(6, 6))

    if isinstance(obs, list):
        data = np.concatenate([xi.detach().cpu().numpy() for xi in obs], axis=0)
        labels = np.concatenate(
            [t * np.ones(xi.shape[0]) for t, xi in enumerate(obs)], axis=0
        )
        scprep.plot.scatter2d(data, c=labels, ax=ax)
        start = obs[0][:n]
        ts = len(obs)
    else:
        batch_size, ts, dim = obs.shape
        # Take the starting points before the tensor is flattened for plotting.
        start = obs[:n, start_time, :]
        flat_obs = obs.reshape(-1, dim).detach().cpu().numpy()
        tts = np.tile(np.arange(ts), batch_size)
        scprep.plot.scatter2d(flat_obs, c=tts, ax=ax)

    # Imported lazily so the module can be loaded without torchdyn installed.
    from torchdyn.core import NeuralODE

    with torch.no_grad():
        node = NeuralODE(model)
        traj = node.trajectory(start, t_span=torch.linspace(0, ts - 1, max(20 * ts, 100)))
        traj = traj.cpu().detach().numpy()

    ax.scatter(traj[:, :n, 0], traj[:, :n, 1], s=0.3, alpha=0.2, c="black", label="Flow")
    ax.scatter(traj[-1, :n, 0], traj[-1, :n, 1], s=6, alpha=1, c="purple", marker="x")

    os.makedirs("figs", exist_ok=True)
    image_path = f"figs/{title}.png"
    fig.savefig(image_path)
    plt.close(fig)
    if wandb_logger:
        wandb_logger.log_image(key="paths", images=[image_path])
|
|
|
|
|
def plot_samples(trajs, title="samples", wandb_logger=None):
    """Save the first 100 entries of ``trajs`` as an image grid and log it.

    The grid has 10 images per row, no padding, and is normalised by
    ``torchvision.utils.save_image`` before being written to
    ``figs/{title}.jpg``.

    Args:
        trajs: Sliceable batch of images accepted by ``save_image``.
        title: Base name of the output image (``figs/{title}.jpg``).
        wandb_logger: Optional logger exposing
            ``log_image(key=..., images=[...])``; the image is logged under
            the key ``"paths"``. A ``PIL.UnidentifiedImageError`` raised while
            logging is reported on stdout rather than propagated.
    """
    import PIL
    from torchvision.utils import save_image

    grid_images = trajs[:100]
    os.makedirs("figs", exist_ok=True)
    grid_path = f"figs/{title}.jpg"
    save_image(grid_images, fp=grid_path, nrow=10, normalize=True, padding=0)

    if not wandb_logger:
        return
    try:
        wandb_logger.log_image(key="paths", images=[grid_path])
    except PIL.UnidentifiedImageError:
        print(f"ERROR logging {title}")
|
|
|