| import os
|
|
|
| import matplotlib.pyplot as plt
|
| import numpy as np
|
| import torch
|
|
|
|
|
def plot_trajectories(data, pred, graph, dataset, title=(1, 2.1)):
    """Plot data and predicted trajectories beside a graph heat map.

    The figure has three panels: lines from ``data`` (left), lines from
    ``pred`` (middle, carrying the title) and ``graph`` rendered with
    ``matshow`` plus a colorbar (right). It is written to
    ``figs/<title text>.png`` relative to the current working directory.

    Args:
        data: Array indexed as ``data[0, :, i]``; only index 0 of the first
            axis is drawn, one line per index of the last axis.
        pred: Array indexed the same way as ``data``; its last-axis size
            must equal that of ``data``.
        graph: Matrix passed to ``matshow``.
        dataset: Label placed at the start of the title and the file name.
        title: Two-item sequence read as ``(epoch, loss)``; the second item
            is formatted with ``1.3f``.

    Raises:
        AssertionError: If ``data`` and ``pred`` differ in last-axis size.
    """
    # Validate before the figure exists so a failed check cannot leak one.
    assert data.shape[-1] == pred.shape[-1]

    fig, axs = plt.subplots(1, 3, figsize=(10, 2.3))
    try:
        fig.tight_layout(pad=0.2, w_pad=2, h_pad=3)
        for i in range(data.shape[-1]):
            axs[0].plot(data[0, :, i].squeeze())
            axs[1].plot(pred[0, :, i].squeeze())

        # The same text is used as the panel title and as the file stem.
        caption = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}"
        axs[1].set_title(caption)

        heatmap = axs[2].matshow(graph)
        # Attach the colorbar to the heat-map panel explicitly.
        fig.colorbar(heatmap, ax=axs[2])

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("figs", exist_ok=True)
        fig.savefig(f"figs/{caption}.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def plot_graph_dist(graph_mu, graph_thresh, graph_std, ground_truth, path):
    """Draw four graph matrices side by side and save them under ``path``.

    Panels, left to right: "Ground Truth", "Graph means",
    "Graph post-threshold" and "Graph std". Each is rendered with
    ``matshow`` (viridis colormap) and gets its own colorbar. The figure is
    written to ``<path>/figs/graph_dist_plot.png``; the directory is created
    (including missing parents) when it does not exist.

    Args:
        graph_mu: Matrix shown in the "Graph means" panel; must expose
            ``.shape``.
        graph_thresh: Matrix shown in the "Graph post-threshold" panel.
        graph_std: Matrix shown in the "Graph std" panel.
        ground_truth: Matrix shown in the "Ground Truth" panel; must expose
            ``.shape``.
        path: Directory under which the ``figs`` folder is placed.
    """
    # Shape report kept from the original implementation (stdout unchanged).
    print(graph_mu.shape, ground_truth.shape)

    panels = (
        ("Ground Truth", ground_truth),
        ("Graph means", graph_mu),
        ("Graph post-threshold", graph_thresh),
        ("Graph std", graph_std),
    )

    fig, axs = plt.subplots(1, 4, figsize=(13, 4.5))
    try:
        for ax, (heading, matrix) in zip(axs, panels):
            ax.set_title(heading)
            pcm = ax.matshow(matrix, cmap="viridis")
            fig.colorbar(pcm, ax=ax)

        out_dir = f"{path}/figs"
        # makedirs(exist_ok=True) is race-free and also creates `path`
        # itself, which a bare mkdir() would fail on.
        os.makedirs(out_dir, exist_ok=True)
        fig.savefig(f"{out_dir}/graph_dist_plot.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def plot_traj_dist(data, pred, dataset, title=(1, 2.1)):
    """Plot data and predicted trajectories in two panels and save the figure.

    The left panel shows lines from ``data``, the right panel lines from
    ``pred`` together with the title. The figure is written to
    ``figs/<title text>.png`` relative to the current working directory.

    Args:
        data: Array indexed as ``data[0, :, i]``; only index 0 of the first
            axis is drawn, one line per index of the last axis.
        pred: Array indexed the same way as ``data``; its last-axis size
            must equal that of ``data``.
        dataset: Label placed at the start of the title and the file name.
        title: Two-item sequence read as ``(epoch, loss)``; the second item
            is formatted with ``1.3f``.

    Raises:
        AssertionError: If ``data`` and ``pred`` differ in last-axis size.
    """
    # Validate before the figure exists so a failed check cannot leak one.
    assert data.shape[-1] == pred.shape[-1]

    fig, axs = plt.subplots(1, 2, figsize=(10, 2.3))
    try:
        fig.tight_layout(pad=0.2, w_pad=2, h_pad=3)
        for i in range(data.shape[-1]):
            axs[0].plot(data[0, :, i].squeeze())
            axs[1].plot(pred[0, :, i].squeeze())

        # The same text is used as the panel title and as the file stem.
        caption = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}"
        axs[1].set_title(caption)

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("figs", exist_ok=True)
        fig.savefig(f"figs/{caption}.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def plot_cnf(data, traj, graph, dataset, title):
    """Scatter data and trajectory points beside a graph heat map, then save.

    Left panel, using coordinates 0 and 1 of the last axis:
      * ``data`` with its first two axes merged (legend "data"),
      * the last time index of ``traj`` (legend "Last Timepoint"),
      * every time index of ``traj`` (legend "Flow"),
      * the first time index of ``traj`` (legend "Posterior").
    At most the first 1000 entries of ``traj`` along axis 0 are drawn.

    Right panel: ``graph`` rendered with ``matshow`` plus a colorbar and the
    title. The figure is written to ``figs/<title text>.png`` relative to
    the current working directory.

    Args:
        data: Array with at least three axes; reshaped to
            ``(-1, *data.shape[2:])`` before plotting.
        traj: Array indexed as ``traj[sample, time, coord]``.
        graph: Matrix passed to ``matshow``.
        dataset: Label placed at the start of the title and the file name.
        title: Two-item sequence read as ``(epoch, loss)``; the second item
            is formatted with ``1.3f``.
    """
    max_traj = 1000  # cap on the number of trajectories drawn
    # Merge the first two axes so all data points form one scatter cloud.
    flat = data.reshape([-1, *data.shape[2:]])
    shown = traj[:max_traj]

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    try:
        left, right = axes[0], axes[1]

        # Draw order matters: legend labels are matched to artists in order.
        left.scatter(flat[:, 0], flat[:, 1], alpha=0.5)
        left.scatter(shown[:, -1, 0], shown[:, -1, 1], s=10, alpha=0.8, c="black")
        left.scatter(shown[:, :, 0], shown[:, :, 1], s=0.2, alpha=0.2, c="olive")
        left.scatter(shown[:, 0, 0], shown[:, 0, 1], s=4, alpha=1, c="blue")
        left.legend(["data", "Last Timepoint", "Flow", "Posterior"])

        heatmap = right.matshow(graph)
        # Attach the colorbar to the heat-map panel explicitly.
        fig.colorbar(heatmap, ax=right)

        # The same text is used as the panel title and as the file stem.
        caption = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}"
        right.set_title(caption)

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("figs", exist_ok=True)
        fig.savefig(f"figs/{caption}.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def plot_pca_traj(data, traj, graph, adata, dataset, title):
    """Plot data and trajectories in 2-D PCA space beside a graph heat map.

    Points are centred with ``adata.var["means"]`` and projected with the
    first two columns of ``adata.varm["PCs"]``. The left panel scatters the
    projected data per time index (legend "T=<t>"), then the last time index
    of the projected trajectories ("Last Timepoint"), all time indices
    ("Flow") and the first time index ("Posterior"). At most the first 1000
    trajectories are drawn. The right panel shows ``graph`` with a colorbar
    and the title.

    Outputs, relative to the current working directory:
      * ``figs_pca/<title text>.png`` - the figure,
      * ``figs_pca/<title text>.npy`` - ``graph`` saved with ``np.save``.

    Args:
        data: np.array [N, T, D]
        traj: np.array [N, T, D]
        graph: np.array [D, D]
        adata: Object providing ``var["means"].values`` and ``varm["PCs"]``
            (presumably an AnnData instance - confirm with the caller).
        dataset: Label placed at the start of the title and the file names.
        title: Two-item sequence read as ``(epoch, loss)``; the second item
            is formatted with ``1.3f``.
    """
    max_traj = 1000  # cap on the number of trajectories drawn

    def pca_transform(x, d=2):
        # Centre with the stored means, then project on the first d PCs.
        return (x - adata.var["means"].values) @ adata.varm["PCs"][:, :d]

    # Project before creating the figure so a failure here cannot leak one.
    shown = pca_transform(traj)[:max_traj]
    num_times = data.shape[1]

    fig, axes = plt.subplots(1, 2, figsize=(10, 5))
    try:
        left, right = axes[0], axes[1]

        # Draw order matters: legend labels are matched to artists in order.
        for t in range(num_times):
            projected = pca_transform(data[:, t])
            left.scatter(projected[:, 0], projected[:, 1], alpha=0.5)
        left.scatter(shown[:, -1, 0], shown[:, -1, 1], s=10, alpha=0.8, c="black")
        left.scatter(shown[:, :, 0], shown[:, :, 1], s=0.2, alpha=0.2, c="olive")
        left.scatter(shown[:, 0, 0], shown[:, 0, 1], s=4, alpha=1, c="blue")
        left.legend(
            [
                *[f"T={i}" for i in range(num_times)],
                "Last Timepoint",
                "Flow",
                "Posterior",
            ]
        )

        heatmap = right.matshow(graph)
        # Attach the colorbar to the heat-map panel explicitly.
        fig.colorbar(heatmap, ax=right)

        # The same text is used as the panel title and as the file stem.
        caption = f"{dataset}: Epoch = {title[0]}, Loss = {title[1]:1.3f}"
        right.set_title(caption)

        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs("figs_pca", exist_ok=True)
        fig.savefig(f"figs_pca/{caption}.png")
        np.save(f"figs_pca/{caption}.npy", graph)
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
def to_torch(arr):
    """Convert a list, NumPy array or NumPy scalar to a float32 tensor.

    Lists are first turned into a NumPy array; the result is then copied
    into a new tensor and cast with ``.float()``.

    Args:
        arr: A ``list``, ``np.ndarray`` or ``np.generic`` value.

    Returns:
        A ``torch.Tensor`` of dtype ``torch.float32``.

    Raises:
        NotImplementedError: If ``arr`` is of any other type.
    """
    if isinstance(arr, list):
        # Route lists through NumPy so both inputs share one conversion.
        arr = np.array(arr)
    elif not isinstance(arr, (np.ndarray, np.generic)):
        raise NotImplementedError(f"to_torch not implemented for type: {type(arr)}")
    return torch.tensor(arr).float()
|
|
|