# PepFlow / models_con/pep_dataloader.py
# Uploaded by Irwiny123 — commit ef423c5: "Add initial PepFlow model code"
"""pep-rec dataset"""
import os
import logging
import joblib
import pickle
import lmdb
from Bio import PDB
from Bio.PDB import PDBExceptions
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from pepflow.modules.protein.parsers import parse_pdb
from pepflow.modules.common.geometry import *
from pepflow.modules.protein.constants import *
from pepflow.utils.data import mask_select_data, find_longest_true_segment, PaddingCollate
from torch.utils.data import DataLoader
from omegaconf import OmegaConf
from easydict import EasyDict
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler, dist
from pepflow.utils.misc import load_config
from pepflow.utils.train import recursive_to
from models_con.torsion import get_torsion_angle
import torch
from pepflow.modules.protein.writers import save_pdb
# bind_dic = torch.load("/datapool/data2/home/jiahan/ResProj/PepDiff/frame-flow/misc/affinity_dict.pt")
# testset
# Complex ids reserved for the test split, one id per line.
names = []
with open('/datapool/data2/home/ruihan/data/jiahan/ResProj/PepDiff/pepflowww/Data/names.txt','r') as f:
    names = [line.strip() for line in f]
def preprocess_structure(task):
    """Parse one peptide/pocket complex into a single merged feature dict.

    Args:
        task: dict with
            'id'       -- complex identifier (directory name),
            'pdb_path' -- directory containing 'peptide.pdb' and 'pocket.pdb'.

    Returns:
        dict with receptor residues first and peptide residues appended
        (tensors concatenated along dim 0, lists concatenated), plus
        'id' and a boolean 'generate_mask' that is True on peptide
        positions; or None if parsing/validation fails.
    """
    try:
        # Keep only complexes listed in the test-split file `names`.
        # NOTE(review): the original raised when the id WAS in `names`,
        # contradicting its own error message ('not in names') and the
        # intent of building a test-set cache — fixed to match the
        # message; confirm against how names.txt is generated.
        if task['id'] not in names:
            raise ValueError(f'{task["id"]} not in names')
        pdb_path = task['pdb_path']

        # --- peptide: parse, center at CA center of mass ---
        pep = parse_pdb(os.path.join(pdb_path, 'peptide.pdb'))[0]
        # Mean over residues with a valid CA atom; epsilon guards an all-masked chain.
        ca_mask = pep['mask_heavyatom'][:, BBHeavyAtom.CA]
        center = torch.sum(pep['pos_heavyatom'][ca_mask, BBHeavyAtom.CA], dim=0) / (torch.sum(ca_mask) + 1e-8)
        pep['pos_heavyatom'] = pep['pos_heavyatom'] - center[None, None, :]
        # Torsions are position-derived, so compute them after translation.
        pep['torsion_angle'], pep['torsion_angle_mask'] = get_torsion_angle(pep['pos_heavyatom'], pep['aa'])
        if len(pep['aa']) < 3 or len(pep['aa']) > 25:
            raise ValueError('peptide length not in [3,25]')

        # --- receptor pocket: same translation so both share one frame ---
        rec = parse_pdb(os.path.join(pdb_path, 'pocket.pdb'))[0]
        rec['pos_heavyatom'] = rec['pos_heavyatom'] - center[None, None, :]
        rec['torsion_angle'], rec['torsion_angle_mask'] = get_torsion_angle(rec['pos_heavyatom'], rec['aa'])
        rec['chain_nb'] += 1  # keep receptor chain numbers disjoint from the peptide's

        # --- merge: receptor first, then peptide ---
        data = {}
        data['id'] = task['id']
        # generate_mask marks the (to-be-generated) peptide residues.
        data['generate_mask'] = torch.cat([torch.zeros_like(rec['aa']), torch.ones_like(pep['aa'])], dim=0).bool()
        for k in rec.keys():
            if isinstance(rec[k], torch.Tensor):
                data[k] = torch.cat([rec[k], pep[k]], dim=0)
            elif isinstance(rec[k], list):
                data[k] = rec[k] + pep[k]
            else:
                raise ValueError(f'Unknown type of {rec[k]}')
        return data
    except (
        PDBExceptions.PDBConstructionException,
        KeyError,
        ValueError,
        TypeError
    ) as e:
        # Best-effort pipeline: log and skip the structure instead of aborting
        # the whole preprocessing run.
        logging.warning('[{}] {}: {}'.format(
            task['id'],
            e.__class__.__name__,
            str(e)
        ))
        return None
class PepDataset(Dataset):
    """Peptide-receptor dataset backed by an LMDB cache.

    On first use (or with ``reset=True``), every subdirectory of
    ``structure_dir`` is parsed with ``preprocess_structure`` in parallel
    and pickled into one LMDB file under ``dataset_dir``; afterwards all
    reads go straight to the cache.

    Args:
        structure_dir: directory of per-complex subdirectories
            (each holding 'peptide.pdb' and 'pocket.pdb').
        dataset_dir: directory where the LMDB cache file is kept.
        name: cache-file name prefix.
        transform: optional callable applied to each item in __getitem__.
        reset: delete and rebuild the cache.
    """

    MAP_SIZE = 32*(1024*1024*1024)  # 32GB LMDB map size (sparse; not preallocated)

    def __init__(self, structure_dir = "./Data/PepMerge_new/", dataset_dir = "./Data/",
                 name = 'pep', transform=None, reset=False):
        super().__init__()
        self.structure_dir = structure_dir
        self.dataset_dir = dataset_dir
        self.transform = transform
        self.name = name
        self.db_conn = None  # read-only LMDB env, opened lazily (per worker process)
        self.db_ids = None   # list of entry keys, populated on connect
        self._load_structures(reset)

    @property
    def _cache_db_path(self):
        return os.path.join(self.dataset_dir, f'{self.name}_structure_cache.lmdb')

    def _connect_db(self):
        """Open the read-only LMDB environment if not already open.

        Idempotent: the original closed and reopened the environment on
        every call — and __len__/__getitem__ call it on every access —
        so each item fetch paid a full env open/close. Keeping a live
        handle per process is the standard LMDB+DataLoader pattern
        (lock=False, readonly=True makes this safe across workers).
        """
        if self.db_conn is not None:
            return
        self.db_conn = lmdb.open(
            self._cache_db_path,
            map_size=self.MAP_SIZE,
            create=False,
            subdir=False,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )
        with self.db_conn.begin() as txn:
            self.db_ids = [k.decode() for k in txn.cursor().iternext(values=False)]

    def _close_db(self):
        # No-op when already closed.
        if self.db_conn is not None:
            self.db_conn.close()
            self.db_conn = None
            self.db_ids = None

    def _load_structures(self, reset):
        """Decide which PDB directories still need preprocessing and run it."""
        all_pdbs = os.listdir(self.structure_dir)
        if reset:
            # Drop the cache (and its lock file) and rebuild from scratch.
            if os.path.exists(self._cache_db_path):
                os.remove(self._cache_db_path)
                lock_file = self._cache_db_path + "-lock"
                if os.path.exists(lock_file):
                    os.remove(lock_file)
            self._close_db()
            todo_pdbs = all_pdbs
        else:
            # Incremental refresh is disabled: an existing cache is trusted as-is.
            if not os.path.exists(self._cache_db_path):
                todo_pdbs = all_pdbs
            else:
                todo_pdbs = []

        if len(todo_pdbs) > 0:
            self._preprocess_structures(todo_pdbs)

    def _preprocess_structures(self, pdb_list):
        """Parse `pdb_list` in parallel and write the results into the LMDB cache."""
        tasks = []
        for pdb_fname in pdb_list:
            tasks.append({
                'id': pdb_fname,
                'pdb_path': os.path.join(self.structure_dir, pdb_fname),
            })

        # CPU-bound parsing fans out over half the cores.
        data_list = joblib.Parallel(
            n_jobs = max(joblib.cpu_count() // 2, 1),
        )(
            joblib.delayed(preprocess_structure)(task)
            for task in tqdm(tasks, dynamic_ncols=True, desc='Preprocess')
        )

        db_conn = lmdb.open(
            self._cache_db_path,
            map_size = self.MAP_SIZE,
            create=True,
            subdir=False,
            readonly=False,
        )
        try:
            n_written = 0
            with db_conn.begin(write=True, buffers=True) as txn:
                for data in tqdm(data_list, dynamic_ncols=True, desc='Write to LMDB'):
                    if data is None:
                        continue  # failed structures were logged and skipped
                    txn.put(data['id'].encode('utf-8'), pickle.dumps(data))
                    n_written += 1
            logging.info('Wrote %d/%d structures to %s', n_written, len(tasks), self._cache_db_path)
        finally:
            # The original leaked this environment; closing guarantees the
            # write transaction is flushed before readers reopen the file.
            db_conn.close()

    def __len__(self):
        self._connect_db()  # ensures db_ids is populated
        return len(self.db_ids)

    def __getitem__(self, index):
        self._connect_db()
        entry_id = self.db_ids[index]  # renamed from `id` (shadowed the builtin)
        with self.db_conn.begin() as txn:
            data = pickle.loads(txn.get(entry_id.encode()))
        if self.transform is not None:
            data = self.transform(data)
        return data
if __name__ == '__main__':
    # Smoke test: rebuild the test-split cache and sanity-check one batch.
    # (Removed unused `device = 'cuda:1'` — everything here runs on CPU.)
    # NOTE(review): config/cfg_name are not used below; the load is kept as a
    # config sanity check — confirm whether it can be dropped.
    config, cfg_name = load_config("./configs/learn/learn_all.yaml")
    dataset = PepDataset(structure_dir = "./Data/PepMerge_new/", dataset_dir = "/Data/Fixed Data",
                         name = 'pep_pocket_test', transform=None, reset=True)
    print(len(dataset))
    print(dataset[0])
    # PaddingCollate pads variable-length complexes into a dense batch.
    dataloader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=4, collate_fn=PaddingCollate(eight=False))
    batch = next(iter(dataloader))
    print(batch['torsion_angle'].shape)
    print(batch['torsion_angle_mask'].shape)