| import os |
| import argparse |
| import shutil |
|
|
| import pandas as pd |
| import MDAnalysis as mda |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import multiprocessing as mp |
|
|
| from scipy.stats import gaussian_kde |
| from MDAnalysis.analysis import align |
|
|
|
|
|
|
def _minmax_free_energy(density, k, T):
    """Convert per-sample KDE densities into a free-energy estimate rescaled to [0, 1].

    G_i = k*T*ln(max(density) / density_i) is min-max normalised. The k*T
    prefactor cancels in that normalisation (for k*T > 0), so its value does
    not change the result.

    Args:
        density: 1-D array of positive density values, one per sample.
        k, T: prefactor constants of the Boltzmann-inversion formula.

    Returns:
        1-D numpy array of energies in [0, 1]. If every sample has the same
        density, the range is zero and all energies are 0 (instead of NaN).
    """
    G = k * T * np.log(np.max(density) / density)
    G_min = np.min(G)
    G_range = np.max(G) - G_min
    if G_range == 0:
        return np.zeros_like(G)
    return (G - G_min) / G_range


def _plot_density_map(kde, rad_gyr, rmsd_ref, out_path, margin=0.25, n_grid=200):
    """Save a filled-contour plot of ``kde`` over the (rad_gyr, rmsd_ref) plane.

    The grid spans the data range padded by ``margin`` on each side, with
    ``n_grid`` points per axis.
    """
    x, y = np.meshgrid(
        np.linspace(min(rad_gyr) - margin, max(rad_gyr) + margin, n_grid),
        np.linspace(min(rmsd_ref) - margin, max(rmsd_ref) + margin, n_grid),
    )
    density_map = kde(np.vstack([x.ravel(), y.ravel()])).reshape(x.shape)
    peak = np.max(density_map)

    fig, ax = plt.subplots()
    try:
        # The lowest level sits at 5% of the peak, so near-empty regions are left unfilled.
        contour = ax.contourf(x, y, density_map, levels=np.arange(peak / 20, peak * 1.1, peak / 10))
        fig.colorbar(contour, ax=ax)
        ax.set_xlabel('rad_gyr')
        ax.set_ylabel('rmsd_ref')
        fig.savefig(out_path)
    finally:
        # Always release the figure; workers process many systems.
        plt.close(fig)


def cal_energy(para1):
    """Collect per-frame structural features for one MD system and estimate their energies.

    For the system ``file_md`` located in ``<dirpath>/<file_md>/`` this:

    1. loads ``<file_md>.pdb``; it is used both as the alignment reference and
       as the topology of the trajectories;
    2. iterates over ``<file_md>_R1.xtc`` .. ``<file_md>_R3.xtc``. For every
       frame it records the radius of gyration of the protein backbone
       (atoms N, CA, C) and the backbone RMSD returned by ``align.alignto``
       against the reference. It then writes the protein atoms to
       ``<file_md>_R<replica>_<frame>.pdb`` (frames numbered from 1);
    3. fits a 2-D Gaussian KDE over (rad_gyr, rmsd_ref) and turns each frame's
       density into a min-max normalised energy in [0, 1];
    4. saves the KDE contour plot as ``md.png`` and the per-frame table
       (columns rad_gyr, rmsd_ref, traj_filename, energy) as ``traj_info.csv``.

    Args:
        para1: ``(file_md, dirpath)`` tuple, packed into a single argument so
            the function can be used directly with ``Pool.map``.

    Returns:
        None. All results are written into ``<dirpath>/<file_md>/``.
    """
    file_md, dirpath = para1
    mdpath = os.path.join(dirpath, file_md)
    filename = file_md

    # Prefactor of the Boltzmann inversion. It cancels in the min-max
    # normalisation, so it does not affect the stored energies.
    k = 2.32 * 1e-4
    T = 298.15

    # The same PDB is the reference structure and the topology for every .xtc.
    pdb_filepath = os.path.join(mdpath, filename + ".pdb")

    u_ref = mda.Universe(pdb_filepath)
    protein_ref = u_ref.select_atoms('protein')
    bb_atom_ref = protein_ref.select_atoms('name CA or name C or name N')

    info = {
        'rad_gyr': [],
        'rmsd_ref': [],
        'traj_filename': [],
        'energy': [],
    }

    for xtc_idx in range(1, 4):
        trajectory_filepath = os.path.join(mdpath, filename + "_R" + str(xtc_idx) + ".xtc")
        u = mda.Universe(pdb_filepath, trajectory_filepath)
        protein = u.select_atoms('protein')
        bb_atom = protein.select_atoms('name CA or name C or name N')

        for count, _ in enumerate(u.trajectory, start=1):
            info['rad_gyr'].append(bb_atom.radius_of_gyration())
            # alignto returns an (RMSD before, RMSD after) pair; keep the post-superposition value.
            rmsd_ref = align.alignto(bb_atom, bb_atom_ref, select='all', match_atoms=False)[-1]
            info['rmsd_ref'].append(rmsd_ref)

            traj_filename = filename + '_R' + str(xtc_idx) + '_' + str(count) + ".pdb"
            info['traj_filename'].append(traj_filename)
            print(traj_filename)
            protein.write(os.path.join(mdpath, traj_filename))

    # KDE over the 2-D feature space, evaluated at every sampled frame.
    info_array = np.stack([info['rad_gyr'], info['rmsd_ref']], axis=0)
    kde = gaussian_kde(info_array)
    density = kde(info_array)
    info['energy'] = _minmax_free_energy(density, k, T).tolist()

    out_total = pd.DataFrame(info)
    _plot_density_map(kde, info['rad_gyr'], info['rmsd_ref'], os.path.join(mdpath, "md.png"))
    out_total.to_csv(os.path.join(mdpath, "traj_info.csv"), index=False)
|
|
|
|
def select_str(file, data_dir, output_dir, select_num=100):
    """Select a subset of frames for one system and copy their PDB files.

    Reads ``<data_dir>/<file>/traj_info.csv`` (the table ``cal_energy`` writes
    there) and sorts it by ascending energy. Sorted positions are then picked
    as ``(i / (select_num - 1)) ** (1/3) * (n_frames - 1)`` for
    ``i = 0 .. select_num - 1``, rounded and de-duplicated. The cube root makes
    picks sparse among the lowest-energy frames and dense towards the
    high-energy end. Each chosen frame's PDB is copied into ``output_dir``.

    Args:
        file: system name (subdirectory of ``data_dir``).
        data_dir: directory holding one subdirectory per system.
        output_dir: existing directory the selected PDB files are copied to.
        select_num: number of candidate positions. At most this many frames
            are returned, since rounding can map positions onto the same frame.

    Returns:
        Dict with lists under 'rad_gyr', 'rmsd_ref', 'traj_filename', 'energy'
        for the selected frames, in ascending-energy order. The lists are empty
        when the table has no rows or ``select_num`` is 0.

    Raises:
        ValueError: if ``select_num`` is negative (raised by ``np.linspace``).
    """
    info_total = {
        'rad_gyr': [],
        'rmsd_ref': [],
        'traj_filename': [],
        'energy': [],
    }

    print(f"Processing {file}")
    md_dir = os.path.join(data_dir, file)
    md_csv = pd.read_csv(os.path.join(md_dir, 'traj_info.csv'))
    md_csv = md_csv.sort_values('energy', ascending=True)

    n_frames = len(md_csv)
    if n_frames == 0 or select_num == 0:
        return info_total

    # Uniform fractions in [0, 1] warped by a cube root. Building them directly
    # (instead of dividing by the last index) avoids 0/0 when n_frames == 1
    # or select_num == 1.
    fractions = np.linspace(0.0, 1.0, select_num)
    idx_total = np.unique(np.round(fractions ** (1 / 3) * (n_frames - 1)).astype(int))

    for idx in idx_total:
        info = md_csv.iloc[idx]
        traj_filename = info['traj_filename']
        shutil.copy(os.path.join(md_dir, traj_filename), output_dir)

        info_total['traj_filename'].append(traj_filename)
        info_total['energy'].append(info['energy'])
        info_total['rad_gyr'].append(info['rad_gyr'])
        info_total['rmsd_ref'].append(info['rmsd_ref'])

    return info_total
|
|
|
|
# Command-line entry point: run cal_energy for every system listed in
# <dir_path>/<filename>, then select_str for each of them, and merge the
# selections into <select_dir>/traj_info_select.csv.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Estimate per-frame energies of MD trajectories and select a subset of frames."
    )
    parser.add_argument("--dir_path", type=str, default="./dataset/ATLAS",
                        help="directory with the system list file and one subdirectory per system")
    parser.add_argument("--filename", type=str, default="ATLAS_filename.txt",
                        help="file inside --dir_path listing one system name per line")
    parser.add_argument("--select_num", type=int, default=100,
                        help="maximum number of frames selected per system")
    parser.add_argument("--select_dir", type=str, default="./dataset/ATLAS/select",
                        help="output directory for selected PDB files and traj_info_select.csv")
    parser.add_argument("--num_processes", type=int, default=48,
                        help="number of worker processes")
    args = parser.parse_args()

    file_txt = os.path.join(args.dir_path, args.filename)
    os.makedirs(args.select_dir, exist_ok=True)

    # One system name per line. Blank lines (including the one a trailing
    # newline produces) are skipped: an empty name would make cal_energy look
    # for '<dir_path>/.pdb' and abort the whole pool. strip() also drops '\r'.
    with open(file_txt, 'r') as f:
        file_list = [line.strip() for line in f if line.strip()]

    para1_list = [(file, args.dir_path) for file in file_list]
    para2_list = [(file, args.dir_path, args.select_dir, args.select_num) for file in file_list]

    with mp.Pool(args.num_processes) as pool:
        # map() blocks until every traj_info.csv is written, which select_str then reads.
        pool.map(cal_energy, para1_list)
        results = pool.starmap(select_str, para2_list)

    columns = ('rad_gyr', 'rmsd_ref', 'traj_filename', 'energy')
    info_total_all = {
        column: [value for result in results for value in result[column]]
        for column in columns
    }

    df = pd.DataFrame(info_total_all)
    df.to_csv(os.path.join(args.select_dir, 'traj_info_select.csv'), index=False)