import argparse
import dataclasses
import functools as fn
import multiprocessing as mp
import os
import time

import esm
import numpy as np
import pandas as pd
import torch
import tree
from Bio import PDB

from data import errors
from data import parsers
from data import utils as du
from data.cal_trans_rotmats import cal_trans_rotmats
from data.ESMfold_pred import ESMFold_Pred
from data.repr import get_pre_repr
from openfold.data import data_transforms
from openfold.utils import rigid_utils
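
# Example invocation (script name here is hypothetical; the paths shown are
# this file's argparse defaults):
#   python process_dataset.py --pdb_dir ./dataset/ATLAS/select \
#       --write_dir ./dataset/ATLAS/select/pkl --num_processes 48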
|
|
|
|
| def process_file(file_path: str, write_dir: str): |
| """Processes protein file into usable, smaller pickles. |
| |
| Args: |
| file_path: Path to file to read. |
| write_dir: Directory to write pickles to. |
| |
| Returns: |
| Saves processed protein to pickle and returns metadata. |
| |
| Raises: |
| DataError if a known filtering rule is hit. |
| All other errors are unexpected and are propogated. |
| """ |
| metadata = {} |
| pdb_name = os.path.basename(file_path).replace('.pdb', '') |
    metadata['pdb_name'] = pdb_name

    processed_path = os.path.join(write_dir, f'{pdb_name}.pkl')
    metadata['processed_path'] = os.path.abspath(processed_path)
    metadata['raw_path'] = file_path
    parser = PDB.PDBParser(QUIET=True)
    structure = parser.get_structure(pdb_name, file_path)

    # Extract all chains, keyed by upper-case chain ID.
| struct_chains = { |
| chain.id.upper(): chain |
| for chain in structure.get_chains()} |
    metadata['num_chains'] = len(struct_chains)

    # Extract features per chain and classify the quaternary structure.
| struct_feats = [] |
| all_seqs = set() |
    for chain_id, chain in struct_chains.items():
        # Chain IDs are stored as integer features.
        chain_id = du.chain_str_to_int(chain_id)
| chain_prot = parsers.process_chain(chain, chain_id) |
| chain_dict = dataclasses.asdict(chain_prot) |
| chain_dict = du.parse_chain_feats(chain_dict) |
| all_seqs.add(tuple(chain_dict['aatype'])) |
| struct_feats.append(chain_dict) |
| if len(all_seqs) == 1: |
| metadata['quaternary_category'] = 'homomer' |
| else: |
| metadata['quaternary_category'] = 'heteromer' |
    complex_feats = du.concat_np_features(struct_feats, False)

    # Record the sequence length and the span of modeled (non-unknown) residues.
    complex_aatype = complex_feats['aatype']
    metadata['seq_len'] = len(complex_aatype)
    modeled_idx = np.where(complex_aatype != 20)[0]
    if len(modeled_idx) == 0:
        raise errors.LengthError('No modeled residues')
    min_modeled_idx = np.min(modeled_idx)
    max_modeled_idx = np.max(modeled_idx)
    metadata['modeled_seq_len'] = max_modeled_idx - min_modeled_idx + 1
    complex_feats['modeled_idx'] = modeled_idx

    # Write the features to a pickle.
    du.write_pkl(processed_path, complex_feats)

    return metadata
|
|
|
|
| def process_serially(all_paths, write_dir): |
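    """Processes files one at a time; returns metadata for the successes."""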
| all_metadata = [] |
| for i, file_path in enumerate(all_paths): |
| try: |
| start_time = time.time() |
| metadata = process_file( |
| file_path, |
| write_dir) |
| elapsed_time = time.time() - start_time |
| print(f'Finished {file_path} in {elapsed_time:2.2f}s') |
| all_metadata.append(metadata) |
| except errors.DataError as e: |
| print(f'Failed {file_path}: {e}') |
| return all_metadata |
|
|
|
|
| def process_fn( |
| file_path, |
| verbose=None, |
| write_dir=None): |
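    """Wrapper around process_file for multiprocessing; returns None on DataError."""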
| try: |
| start_time = time.time() |
| metadata = process_file( |
| file_path, |
| write_dir) |
| elapsed_time = time.time() - start_time |
| if verbose: |
| print(f'Finished {file_path} in {elapsed_time:2.2f}s') |
| return metadata |
| except errors.DataError as e: |
| if verbose: |
| print(f'Failed {file_path}: {e}') |
|
|
|
|
| def main(args): |
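    """Processes every .pdb file in args.pdb_dir into a pickle and writes a metadata CSV."""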
    pdb_dir = args.pdb_dir
    all_file_paths = [
        os.path.join(pdb_dir, x)
        for x in os.listdir(args.pdb_dir) if x.endswith('.pdb')]
    total_num_paths = len(all_file_paths)
    write_dir = args.write_dir
    os.makedirs(write_dir, exist_ok=True)
    if args.debug:
        metadata_file_name = 'metadata_debug.csv'
    else:
        # Use args.csv_name so the post-processing stages below read the same file.
        metadata_file_name = args.csv_name
    metadata_path = os.path.join(write_dir, metadata_file_name)
    print(f'Files will be written to {write_dir}')

    # Process files serially in debug mode, otherwise with a worker pool.
    if args.num_processes == 1 or args.debug:
        all_metadata = process_serially(
            all_file_paths,
            write_dir)
    else:
        _process_fn = fn.partial(
            process_fn,
            verbose=args.verbose,
            write_dir=write_dir)
        with mp.Pool(processes=args.num_processes) as pool:
            all_metadata = pool.map(_process_fn, all_file_paths)
    # process_fn returns None for files that failed with a DataError.
    all_metadata = [x for x in all_metadata if x is not None]
    metadata_df = pd.DataFrame(all_metadata)
    metadata_df.to_csv(metadata_path, index=False)
    succeeded = len(all_metadata)
    print(
        f'Finished processing {succeeded}/{total_num_paths} files')
|
|
|
|
| def cal_repr(processed_file_path, model_esm2, alphabet, batch_converter, esm_device): |
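    """Adds backbone frames and ESM-2 representations to a processed pickle.

    The pickle at processed_file_path is overwritten with the enriched features.
    """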
| print(f'cal_repr for {processed_file_path}') |
    processed_feats_org = du.read_pkl(processed_file_path)
    processed_feats = du.parse_chain_feats(processed_feats_org)

    # Trim the features to the modeled (resolved) region.
    modeled_idx = processed_feats['modeled_idx']
    min_idx = np.min(modeled_idx)
    max_idx = np.max(modeled_idx)
    processed_feats = tree.map_structure(
        lambda x: x[min_idx:(max_idx+1)], processed_feats)

    # Build OpenFold rigid-group frames from atom37 coordinates, then pull out
    # the backbone rotation matrices and translations.
    chain_feats = {
        'aatype': torch.tensor(processed_feats['aatype']).long(),
        'all_atom_positions': torch.tensor(processed_feats['atom_positions']).double(),
        'all_atom_mask': torch.tensor(processed_feats['atom_mask']).double()
    }
    chain_feats = data_transforms.atom37_to_frames(chain_feats)
    rigids_1 = rigid_utils.Rigid.from_tensor_4x4(chain_feats['rigidgroups_gt_frames'])[:, 0]
    rotmats_1 = rigids_1.get_rots().get_rot_mats()
    trans_1 = rigids_1.get_trans()

    # Precompute ESM-2 node and pair representations for the sequence.
    node_repr_pre, pair_repr_pre = get_pre_repr(
        chain_feats['aatype'], model_esm2, alphabet, batch_converter, device=esm_device)
    node_repr_pre = node_repr_pre[0].cpu()
    pair_repr_pre = pair_repr_pre[0].cpu()

    out = {
        'aatype': chain_feats['aatype'],
        'rotmats_1': rotmats_1,
        'trans_1': trans_1,
        'res_mask': torch.tensor(processed_feats['bb_mask']).int(),
        'bb_positions': processed_feats['bb_positions'],
        'all_atom_positions': chain_feats['all_atom_positions'],
        'node_repr_pre': node_repr_pre,
        'pair_repr_pre': pair_repr_pre,
    }

    # Overwrite the pickle with the enriched features.
    du.write_pkl(processed_file_path, out)
|
|
| def cal_static_structure(processed_file_path, raw_pdb_file, ESMFold): |
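    """Predicts a static structure with ESMFold and stores its frames in the pickle."""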
    output_total = du.read_pkl(processed_file_path)

    save_dir = os.path.join(os.path.dirname(raw_pdb_file), 'ESMFold_Pred_results')
    os.makedirs(save_dir, exist_ok=True)
    # Basenames are assumed to start with a 6-character identifier (PDB ID + chain).
    save_path = os.path.join(save_dir, os.path.basename(processed_file_path)[:6] + '_esmfold.pdb')
    # Skip prediction if a cached ESMFold structure already exists.
    if not os.path.exists(save_path):
        print(f'cal_static_structure for {processed_file_path}')
        ESMFold.predict_str(raw_pdb_file, save_path)
    trans, rotmats = cal_trans_rotmats(save_path)
    output_total['trans_esmfold'] = trans
    output_total['rotmats_esmfold'] = rotmats

    du.write_pkl(processed_file_path, output_total)
|
|
|
|
| def merge_pdb(metadata_path, traj_info_file, valid_seq_file, merged_output_file): |
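    """Merges processing metadata with trajectory energies and the train/valid split."""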
    df1 = pd.read_csv(metadata_path)
    df2 = pd.read_csv(traj_info_file)
    df3 = pd.read_csv(valid_seq_file)

    # Key the metadata on the trajectory file name.
    df1['traj_filename'] = [os.path.basename(i) for i in df1['raw_path']]

    # Attach per-trajectory energies; entries whose 6-character identifier
    # appears in the validation-sequence file are held out of the training set.
    merged = df1.merge(df2[['traj_filename', 'energy']], on='traj_filename', how='left')
    merged['is_trainset'] = ~merged['traj_filename'].str[:6].isin(df3['file'])

    merged.to_csv(merged_output_file, index=False)
    print('merge complete!')
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
|
|
| parser.add_argument("--pdb_dir", type=str, default="./dataset/ATLAS/select") |
| parser.add_argument("--write_dir", type=str, default="./dataset/ATLAS/select/pkl") |
| parser.add_argument("--csv_name", type=str, default="metadata.csv") |
| parser.add_argument("--debug", type=bool, default=False) |
| parser.add_argument("--num_processes", type=int, default=48) |
| parser.add_argument('--verbose', help='Whether to log everything.',action='store_true') |
|
|
| parser.add_argument("--esm_device", type=str, default='cuda') |
|
|
| parser.add_argument("--traj_info_file", type=str, default='./dataset/ATLAS/select/traj_info_select.csv') |
| parser.add_argument("--valid_seq_file", type=str, default='./inference/valid_seq.csv') |
| parser.add_argument("--merged_output_file", type=str, default='./dataset/ATLAS/select/pkl/metadata_merged.csv') |
|
|
| args = parser.parse_args() |
|
|
| |
    # Stage 1: parse PDB files into feature pickles and write metadata.
    main(args)
|
|
| |
    # Stage 2: add ESM-2 representations and backbone frames to each pickle.
    csv_path = os.path.join(args.write_dir, args.csv_name)
    pdb_csv = pd.read_csv(csv_path)
    pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)
    model_esm2, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
    batch_converter = alphabet.get_batch_converter()
    model_esm2.eval()
    model_esm2.requires_grad_(False)
    model_esm2.to(args.esm_device)
    for idx in range(len(pdb_csv)):
        cal_repr(pdb_csv.iloc[idx]['processed_path'], model_esm2, alphabet,
                 batch_converter, args.esm_device)
|
|
| |
    # Stage 3: add ESMFold static-structure frames to each pickle.
    csv_path = os.path.join(args.write_dir, args.csv_name)
    pdb_csv = pd.read_csv(csv_path)
    ESMFold = ESMFold_Pred(device=args.esm_device)
    for idx in range(len(pdb_csv)):
        cal_static_structure(pdb_csv.iloc[idx]['processed_path'],
                             pdb_csv.iloc[idx]['raw_path'], ESMFold)
|
|
| |
    # Stage 4: merge metadata with trajectory info and the train/valid split.
    csv_path = os.path.join(args.write_dir, args.csv_name)
    merge_pdb(csv_path, args.traj_info_file, args.valid_seq_file, args.merged_output_file)