| """pep-rec dataset""" |
| import os |
| import logging |
| import joblib |
| import pickle |
| import lmdb |
| from Bio import PDB |
| from Bio.PDB import PDBExceptions |
| from torch.utils.data import Dataset |
| from tqdm.auto import tqdm |
|
|
| from pepflow.modules.protein.parsers import parse_pdb |
| from pepflow.modules.common.geometry import * |
| from pepflow.modules.protein.constants import * |
| from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate |
| from torch.utils.data import DataLoader |
|
|
| from omegaconf import OmegaConf |
| from easydict import EasyDict |
|
|
| from torch.utils.data import DataLoader, Dataset |
| from torch.utils.data.distributed import DistributedSampler, dist |
|
|
| from pepflow.utils.misc import load_config |
| from pepflow.utils.train import recursive_to |
|
|
| from models_con.torsion import get_torsion_angle |
|
|
| import torch |
|
|
| from pepflow.modules.protein.writers import save_pdb |
|
|
| |
|
|
| |
| names = [] |
| with open('/datapool/data2/home/ruihan/data/jiahan/ResProj/PepDiff/pepflowww/Data/names.txt','r') as f: |
| for line in f: |
| names.append(line.strip()) |
| |
| def preprocess_structure(task): |
|
|
| try: |
| if task['id'] in names: |
| raise ValueError(f'{task["id"]} not in names') |
| pdb_path = task['pdb_path'] |
| |
| |
| pep = parse_pdb(os.path.join(pdb_path,'peptide.pdb'))[0] |
| center = torch.sum(pep['pos_heavyatom'][pep['mask_heavyatom'][:, BBHeavyAtom.CA], BBHeavyAtom.CA], dim=0) / (torch.sum(pep['mask_heavyatom'][:, BBHeavyAtom.CA]) + 1e-8) |
| pep['pos_heavyatom'] = pep['pos_heavyatom'] - center[None, None, :] |
| pep['torsion_angle'],pep['torsion_angle_mask'] = get_torsion_angle(pep['pos_heavyatom'],pep['aa']) |
| if len(pep['aa'])<3 or len(pep['aa'])>25: |
| raise ValueError('peptide length not in [3,25]') |
| |
| rec = parse_pdb(os.path.join(pdb_path,'pocket.pdb'))[0] |
| rec['pos_heavyatom'] = rec['pos_heavyatom'] - center[None, None, :] |
| rec['torsion_angle'],rec['torsion_angle_mask'] = get_torsion_angle(rec['pos_heavyatom'],rec['aa']) |
| rec['chain_nb'] += 1 |
| |
| data = {} |
| data['id'] = task['id'] |
| data['generate_mask'] = torch.cat([torch.zeros_like(rec['aa']), torch.ones_like(pep['aa'])], dim=0).bool() |
| for k in rec.keys(): |
| if isinstance(rec[k], torch.Tensor): |
| data[k] = torch.cat([rec[k], pep[k]], dim=0) |
| elif isinstance(rec[k], list): |
| data[k] = rec[k] + pep[k] |
| else: |
| raise ValueError(f'Unknown type of {rec[k]}') |
| return data |
| |
| except ( |
| PDBExceptions.PDBConstructionException, |
| KeyError, |
| ValueError, |
| TypeError |
| ) as e: |
| logging.warning('[{}] {}: {}'.format( |
| task['id'], |
| e.__class__.__name__, |
| str(e) |
| )) |
| return None |
|
|
|
|
| class PepDataset(Dataset): |
|
|
| MAP_SIZE = 32*(1024*1024*1024) |
|
|
| def __init__(self, structure_dir = "./Data/PepMerge_new/", dataset_dir = "./Data/", |
| name = 'pep', transform=None, reset=False): |
|
|
| super().__init__() |
| self.structure_dir = structure_dir |
| self.dataset_dir = dataset_dir |
| self.transform = transform |
| self.name = name |
|
|
| self.db_conn = None |
| self.db_ids = None |
| self._load_structures(reset) |
|
|
| @property |
| def _cache_db_path(self): |
| return os.path.join(self.dataset_dir, f'{self.name}_structure_cache.lmdb') |
|
|
| def _connect_db(self): |
| self._close_db() |
| self.db_conn = lmdb.open( |
| self._cache_db_path, |
| map_size=self.MAP_SIZE, |
| create=False, |
| subdir=False, |
| readonly=True, |
| lock=False, |
| readahead=False, |
| meminit=False, |
| ) |
| with self.db_conn.begin() as txn: |
| keys = [k.decode() for k in txn.cursor().iternext(values=False)] |
| self.db_ids = keys |
|
|
| def _close_db(self): |
| if self.db_conn is not None: |
| self.db_conn.close() |
| self.db_conn = None |
| self.db_ids = None |
|
|
| def _load_structures(self, reset): |
| all_pdbs = os.listdir(self.structure_dir) |
|
|
| if reset: |
| if os.path.exists(self._cache_db_path): |
| os.remove(self._cache_db_path) |
| lock_file = self._cache_db_path + "-lock" |
| if os.path.exists(lock_file): |
| os.remove(lock_file) |
| self._close_db() |
| todo_pdbs = all_pdbs |
| else: |
| if not os.path.exists(self._cache_db_path): |
| todo_pdbs = all_pdbs |
| else: |
| todo_pdbs = [] |
| |
| |
| |
| |
|
|
| if len(todo_pdbs) > 0: |
| self._preprocess_structures(todo_pdbs) |
| |
| def _preprocess_structures(self, pdb_list): |
| tasks = [] |
| for pdb_fname in pdb_list: |
| pdb_path = os.path.join(self.structure_dir, pdb_fname) |
| tasks.append({ |
| 'id': pdb_fname, |
| 'pdb_path': pdb_path, |
| }) |
|
|
| data_list = joblib.Parallel( |
| n_jobs = max(joblib.cpu_count() // 2, 1), |
| )( |
| joblib.delayed(preprocess_structure)(task) |
| for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess') |
| ) |
|
|
| db_conn = lmdb.open( |
| self._cache_db_path, |
| map_size = self.MAP_SIZE, |
| create=True, |
| subdir=False, |
| readonly=False, |
| ) |
| ids = [] |
| with db_conn.begin(write=True, buffers=True) as txn: |
| for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'): |
| if data is None: |
| continue |
| ids.append(data['id']) |
| txn.put(data['id'].encode('utf-8'), pickle.dumps(data)) |
|
|
| def __len__(self): |
| self._connect_db() |
| return len(self.db_ids) |
|
|
| def __getitem__(self, index): |
| self._connect_db() |
| id = self.db_ids[index] |
| with self.db_conn.begin() as txn: |
| data = pickle.loads(txn.get(id.encode())) |
| if self.transform is not None: |
| data = self.transform(data) |
| return data |
|
|
|
|
|
|
| if __name__ == '__main__': |
| device = 'cuda:1' |
| config,cfg_name = load_config("./configs/learn/learn_all.yaml") |
| dataset = PepDataset(structure_dir = "./Data/PepMerge_new/", dataset_dir = "/Data/Fixed Data", |
| name = 'pep_pocket_test', transform=None, reset=True) |
| print(len(dataset)) |
| print(dataset[0]) |
| |
| dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=PaddingCollate(eight=False)) |
|
|
| batch = next(iter(dataloader)) |
| print(batch['torsion_angle'].shape) |
| print(batch['torsion_angle_mask'].shape) |