| import torch |
| import joblib |
| import numpy as np |
| import mdtraj as md |
| import matplotlib.pyplot as plt |
| import pyemma.coordinates as coor |
|
|
| from .utils import compute_dihedral |
| from matplotlib.colors import LinearSegmentedColormap |
|
|
class Plot:
    """Plot sampled molecular paths over a precomputed energy landscape.

    For ``molecule == "aldp"`` the paths are drawn in dihedral (phi, psi)
    space over ``./data/aldp/landscape.dat``. For any other molecule they are
    projected with the TICA model in ``./data/<molecule>/tica_model.pkl`` and
    drawn over ``./data/<molecule>/pmf.npy``. The figure is written to
    ``<save_dir>/paths.png``.
    """

    # Drawing constants shared by both landscape styles.
    _ZORDER = 32
    _STATE_MARKER_SIZE = 500
    _TICA_CMAP_COLORS = ["#05009E", "#6B67EE", "#50B2D7", "#B0ADF1"]
    _ALDP_CMAP_COLORS = ["#05009E", "#6B67EE", "#50B2D7", "#F7EFFF"]
    # Atom indices passed to ``compute_dihedral`` for the vertical (psi) and
    # horizontal (phi) axes of the aldp plot.
    _PSI_ATOMS = [6, 8, 14, 16]
    _PHI_ATOMS = [1, 6, 8, 14]
    # Saddle points (phi, psi) that were once drawn as stars, kept for reference:
    # phi = [-0.035, -0.017], psi = [1.605, -0.535]

    def __init__(self, args, mds):
        """Copy plotting configuration from ``args`` and reference states from ``mds``.

        Args:
            args: namespace providing ``device``, ``save_dir``, ``molecule``,
                ``start_state`` and ``num_samples``.
            mds: object providing ``start_position``, ``target_position`` and
                ``energy_function``.
        """
        self.device = args.device
        self.save_dir = args.save_dir
        self.molecule = args.molecule
        self.start_state = args.start_state
        self.num_samples = args.num_samples
        self.start_position = mds.start_position
        self.target_position = mds.target_position
        self.energy_function = mds.energy_function

    def __call__(self):
        """Load the saved sample paths and plot them.

        Reads ``<save_dir>/positions/{i}.npy`` for ``i`` in
        ``range(num_samples)``. Each file is converted to a float32 tensor on
        ``self.device``, and the list is passed to :meth:`paths`.

        Returns:
            The matplotlib figure produced by :meth:`paths`.
        """
        # Potential energies are not evaluated here: nothing in the plot uses them.
        positions = [
            torch.from_numpy(
                np.load(f"{self.save_dir}/positions/{i}.npy").astype(np.float32)
            ).to(self.device)
            for i in range(self.num_samples)
        ]
        return self.paths(positions)

    def paths(self, positions):
        """Draw ``positions`` over the molecule's landscape, then save and show it.

        Args:
            positions: sequence of position tensors, one per path. Presumably
                each is ``(n_frames, n_atoms, 3)`` — TODO confirm.

        Returns:
            matplotlib.figure.Figure: the figure (already closed via pyplot).
        """
        if self.molecule == "aldp":
            fig = self._draw_aldp(positions)
        else:
            fig = self._draw_tica(positions)
        plt.tick_params(
            left=False,
            right=False,
            labelleft=False,
            labelbottom=False,
            bottom=False,
        )
        plt.tight_layout()
        plt.savefig(f"{self.save_dir}/paths.png", dpi=300, bbox_inches="tight")
        plt.show()
        plt.close()
        return fig

    # ------------------------------------------------------------------ aldp

    def _draw_aldp(self, positions):
        """Draw paths in (phi, psi) space over the tabulated aldp landscape."""
        # Discard any figure left over from earlier pyplot use.
        plt.clf()
        plt.close()
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
        plt.xlim([-np.pi, np.pi])
        plt.ylim([-np.pi, np.pi])

        with open("./data/aldp/landscape.dat", encoding="utf-8") as f:
            locations, data = self._parse_landscape(f.readlines())
        xs = np.arange(-np.pi, np.pi + 0.1, 0.1)
        ys = np.arange(-np.pi, np.pi + 0.1, 0.1)
        z = self._nearest_grid_values(locations, data, xs, ys)
        plt.contourf(
            xs, ys, z, levels=100, zorder=0,
            cmap=self._cmap(self._ALDP_CMAP_COLORS),
        )

        for position in positions:
            phi, psi = self._dihedrals(position)
            self._draw_path(ax, phi, psi)
        for state in (self.start_position, self.target_position):
            phi, psi = self._dihedrals(state)
            self._draw_state(ax, phi, psi)

        plt.xlabel("\u03A6", fontsize=35, fontweight="medium")
        plt.ylabel("\u03A8", fontsize=35, fontweight="medium")
        return fig

    def _dihedrals(self, position):
        """Return ``(phi, psi)`` numpy arrays computed by ``compute_dihedral``."""
        psi = compute_dihedral(position[:, self._PSI_ATOMS, :]).detach().cpu().numpy()
        phi = compute_dihedral(position[:, self._PHI_ATOMS, :]).detach().cpu().numpy()
        return phi, psi

    @staticmethod
    def _parse_landscape(lines, dims=(90, 90)):
        """Parse the lines of a ``landscape.dat`` file into grid tensors.

        The first line is skipped as a header. Blank lines are ignored. Each
        remaining line is split on whitespace: its first two fields are the
        ``(x, y)`` grid location and its last field is the value. Entries fill
        the grid in row-major order.

        Args:
            lines: sequence of text lines (e.g. from ``readlines()``).
            dims: ``(n_rows, n_cols)`` of the grid.

        Returns:
            ``(locations, data)``: float32 tensors of shape
            ``(n_rows, n_cols, 2)`` and ``(n_rows, n_cols)``.
        """
        n_rows, n_cols = int(dims[0]), int(dims[1])
        locations = torch.zeros((n_rows, n_cols, 2))
        data = torch.zeros((n_rows, n_cols))
        # ``split()`` (not ``line[:-1].split(" ")``) so a missing final newline
        # does not truncate the last value and tabs/runs of spaces are handled.
        rows = [fields for fields in (line.split() for line in lines[1:]) if fields]
        for k, fields in enumerate(rows):
            r, c = divmod(k, n_cols)
            locations[r, c, 0] = float(fields[0])
            locations[r, c, 1] = float(fields[1])
            data[r, c] = float(fields[-1])
        return locations, data

    @staticmethod
    def _nearest_grid_values(locations, data, xs, ys):
        """Resample ``data`` onto the ``xs`` x ``ys`` mesh by nearest neighbour.

        Args:
            locations: ``(n_rows, n_cols, 2)`` tensor of grid coordinates.
            data: ``(n_rows, n_cols)`` tensor of values at those coordinates.
            xs, ys: 1-D numpy arrays defining the output mesh.

        Returns:
            numpy array of shape ``(len(ys), len(xs))``.
        """
        x, y = np.meshgrid(xs, ys)
        query = torch.tensor(np.array([x, y])).view(2, -1).T
        grid = locations.reshape(-1, 2).double()
        index = torch.cdist(query, grid, p=2).argmin(dim=1)
        # Flat lookup matches the row-major flattening of ``grid``, so this is
        # correct for non-square grids too.
        return data.reshape(-1)[index].view(y.shape[0], y.shape[1]).numpy()

    # ------------------------------------------------------------------ TICA

    def _draw_tica(self, positions):
        """Draw paths projected onto the first two TICA components over the PMF."""
        fig = plt.figure(figsize=(7, 7))
        ax = fig.add_subplot(111)
        data_dir = f"./data/{self.molecule}"
        pmf = np.load(f"{data_dir}/pmf.npy")
        xs = np.load(f"{data_dir}/xs.npy")
        ys = np.load(f"{data_dir}/ys.npy")
        plt.pcolormesh(xs, ys, pmf.T, cmap=self._cmap(self._TICA_CMAP_COLORS))

        # NOTE: joblib.load unpickles; only point this at trusted data files.
        tica_model = joblib.load(f"{data_dir}/tica_model.pkl")
        pdb_path = f"{data_dir}/{self.start_state}.pdb"
        feat = coor.featurizer(pdb_path)
        feat.add_backbone_torsions(cossin=True)
        topology = md.load(pdb_path).topology  # loaded once, reused for every projection

        def project(coords):
            traj = md.Trajectory(coords.detach().cpu().numpy(), topology)
            return tica_model.transform(feat.transform(traj))

        for position in positions:
            tica = project(position)
            self._draw_path(ax, tica[:, 0], tica[:, 1])
        for state in (self.start_position, self.target_position):
            tica = project(state)
            self._draw_state(ax, tica[:, 0], tica[:, 1])

        plt.xlabel("TIC 1", fontsize=35, fontweight="medium")
        plt.ylabel("TIC 2", fontsize=35, fontweight="medium")
        plt.xlim(xs.min(), xs.max())
        plt.ylim(ys.min(), ys.max())
        return fig

    # --------------------------------------------------------------- drawing

    @staticmethod
    def _cmap(colors):
        """Build a linear colormap through ``colors``."""
        return LinearSegmentedColormap.from_list("my_cmap", colors)

    def _draw_path(self, ax, xs, ys):
        """Plot one path as small white dots and highlight its last point."""
        if len(xs) == 0:
            return
        ax.plot(
            xs,
            ys,
            marker="o",
            linestyle="None",
            markersize=2,
            alpha=1.0,
            markerfacecolor="white",
            markeredgecolor="none",
            markeredgewidth=0,
        )
        ax.scatter(
            [xs[-1]], [ys[-1]],
            s=70, c="#D577FF", edgecolors="w", linewidths=0.8,
            zorder=self._ZORDER + 1, marker="o",
        )

    def _draw_state(self, ax, xs, ys):
        """Mark a start/target state with a large star."""
        ax.scatter(
            xs,
            ys,
            edgecolors="w",
            c="#9793F8",
            zorder=self._ZORDER,
            s=self._STATE_MARKER_SIZE,
            marker="*",
        )
|
|