| """PDB data loader.""" |
| import math |
| import torch |
| import tree |
| import numpy as np |
| import torch |
| import pandas as pd |
| import logging |
| import os |
| import random |
| import esm |
| import copy |
|
|
| from data import utils as du |
| from data.repr import get_pre_repr |
| from openfold.data import data_transforms |
| from openfold.utils import rigid_utils |
| from data.residue_constants import restype_atom37_mask, order2restype_with_mask |
|
|
| from pytorch_lightning import LightningDataModule |
| from torch.utils.data import DataLoader, Dataset |
| from torch.utils.data.distributed import DistributedSampler, dist |
| from scipy.spatial.transform import Rotation as scipy_R |
|
|
|
|
|
|
|
|
class PdbDataModule(LightningDataModule):
    """Lightning data module that wraps train/valid ``PdbDataset`` instances.

    Expects ``data_cfg`` to expose ``loader`` (num_workers, prefetch_factor),
    ``dataset`` (forwarded to ``PdbDataset``) and ``sampler`` sub-configs.
    """

    def __init__(self, data_cfg):
        super().__init__()
        self.data_cfg = data_cfg
        self.loader_cfg = data_cfg.loader
        self.dataset_cfg = data_cfg.dataset
        self.sampler_cfg = data_cfg.sampler

    def setup(self, stage: str):
        """Build the train and validation datasets (called once per stage)."""
        self._train_dataset = PdbDataset(
            dataset_cfg=self.dataset_cfg,
            is_training=True,
        )
        self._valid_dataset = PdbDataset(
            dataset_cfg=self.dataset_cfg,
            is_training=False,
        )

    def train_dataloader(self, rank=None, num_replicas=None):
        """Training loader with a shuffling DistributedSampler.

        ``rank``/``num_replicas`` are accepted for interface compatibility;
        the DistributedSampler infers both from the process group.
        """
        num_workers = self.loader_cfg.num_workers
        return DataLoader(
            self._train_dataset,
            sampler=DistributedSampler(self._train_dataset, shuffle=True),
            num_workers=num_workers,
            # prefetch_factor must be None when there are no worker processes.
            prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
            persistent_workers=num_workers > 0,
        )

    def val_dataloader(self):
        """Validation loader with a non-shuffling DistributedSampler."""
        num_workers = self.loader_cfg.num_workers
        return DataLoader(
            self._valid_dataset,
            sampler=DistributedSampler(self._valid_dataset, shuffle=False),
            num_workers=num_workers,
            prefetch_factor=None if num_workers == 0 else self.loader_cfg.prefetch_factor,
            # Bug fix: hard-coded True raised ValueError when num_workers == 0.
            persistent_workers=num_workers > 0,
        )
|
|
|
|
class PdbDataset(Dataset):
    """Dataset of pre-processed PDB chains indexed by a metadata CSV.

    Each CSV row points at a pickled feature dict (``processed_path``) and
    carries per-chain metadata (``modeled_seq_len``, ``energy``,
    ``is_trainset``). Training and validation splits are derived from the
    ``is_trainset`` column plus length filters.
    """

    def __init__(
        self,
        *,
        dataset_cfg,
        is_training,
    ):
        """
        Args:
            dataset_cfg: config providing ``csv_path``, ``min_num_res``,
                ``max_num_res``, ``subset``, ``split_frac``, ``seed``,
                ``max_eval_length``, ``max_valid_num``, ``use_split``,
                ``split_len`` and ``use_rotate_enhance``.
            is_training: select the train split (True) or valid split (False).
        """
        self._log = logging.getLogger(__name__)
        self._is_training = is_training
        self._dataset_cfg = dataset_cfg
        self.split_frac = self._dataset_cfg.split_frac
        self.random_seed = self._dataset_cfg.seed
        self._init_metadata()

    @property
    def is_training(self):
        return self._is_training

    @property
    def dataset_cfg(self):
        return self._dataset_cfg

    def _init_metadata(self):
        """Load the metadata CSV, apply length/subset filters and select the
        train or validation split; the resulting split is also written next
        to the source CSV (train.csv / valid.csv)."""
        pdb_csv = pd.read_csv(self.dataset_cfg.csv_path)
        self.raw_csv = pdb_csv
        # Keep chains whose modeled length lies in [min_num_res, max_num_res].
        pdb_csv = pdb_csv[pdb_csv.modeled_seq_len <= self.dataset_cfg.max_num_res]
        pdb_csv = pdb_csv[pdb_csv.modeled_seq_len >= self.dataset_cfg.min_num_res]

        if self.dataset_cfg.subset is not None:
            pdb_csv = pdb_csv.iloc[:self.dataset_cfg.subset]
        pdb_csv = pdb_csv.sort_values('modeled_seq_len', ascending=False)

        out_dir = os.path.dirname(self.dataset_cfg.csv_path)
        if self.is_training:
            # BUG FIX: the is_trainset filter used to be immediately
            # overwritten by sampling from the *unfiltered* csv, which made
            # the column a no-op and let valid rows leak into training.
            # Chain the filter and the (shuffling) sample instead.
            self.csv = pdb_csv[pdb_csv['is_trainset']]
            self.csv = self.csv.sample(
                frac=self.split_frac, random_state=self.random_seed).reset_index()
            self.csv.to_csv(os.path.join(out_dir, "train.csv"), index=False)
            self._log.info(
                f"Training: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")
        else:
            self.csv = pdb_csv[~pdb_csv['is_trainset']]
            # BUG FIX: the eval-length filter was previously re-derived from
            # the full csv, discarding the ~is_trainset split above; filter
            # the split itself.
            self.csv = self.csv[self.csv.modeled_seq_len <= self.dataset_cfg.max_eval_length]
            self.csv.to_csv(os.path.join(out_dir, "valid.csv"), index=False)
            # Cap the number of validation examples (deterministic sample).
            self.csv = self.csv.sample(
                n=min(self.dataset_cfg.max_valid_num, len(self.csv)),
                random_state=self.random_seed).reset_index()
            self._log.info(
                f"Valid: {len(self.csv)} examples, len_range is {self.csv['modeled_seq_len'].min()}-{self.csv['modeled_seq_len'].max()}")

    def __len__(self):
        return len(self.csv)

    def __getitem__(self, idx):
        """Load the feature dict for row ``idx``; during training optionally
        crop a random contiguous segment and/or apply a random rotation,
        then recompute backbone rigid frames from the atom positions."""
        row = self.csv.iloc[idx]
        chain_feats = du.read_pkl(row['processed_path'])
        chain_feats['energy'] = torch.tensor(row['energy'], dtype=torch.float32)
        energy = chain_feats['energy']

        if self.is_training and self._dataset_cfg.use_split:
            n_res = chain_feats['aatype'].shape[0]
            split_len = random.randint(
                self.dataset_cfg.min_num_res,
                min(self._dataset_cfg.split_len, n_res))
            # Random crop start (renamed from the shadowed `idx`).
            start = random.randint(0, n_res - split_len)
            output_total = copy.deepcopy(chain_feats)
            # Replace the scalar energy with a per-residue tensor so the
            # uniform slicing below is valid; the scalar is restored after.
            output_total['energy'] = torch.ones(chain_feats['aatype'].shape)
            output_temp = tree.map_structure(
                lambda x: x[start:start + split_len], output_total)

            # Re-center coordinates on the cropped segment's backbone mean.
            bb_center = np.sum(output_temp['bb_positions'], axis=0) / (np.sum(output_temp['res_mask'].numpy()) + 1e-5)
            output_temp['trans_1'] = (output_temp['trans_1'] - torch.from_numpy(bb_center[None, :])).float()
            output_temp['bb_positions'] = output_temp['bb_positions'] - bb_center[None, :]
            output_temp['all_atom_positions'] = output_temp['all_atom_positions'] - torch.from_numpy(bb_center[None, None, :])
            # The pairwise representation must also be cropped on its
            # second residue axis (map_structure only sliced the first).
            output_temp['pair_repr_pre'] = output_temp['pair_repr_pre'][:, start:start + split_len]

            # Same re-centering for the ESMFold-predicted translations.
            bb_center_esmfold = torch.sum(output_temp['trans_esmfold'], dim=0) / (np.sum(output_temp['res_mask'].numpy()) + 1e-5)
            output_temp['trans_esmfold'] = (output_temp['trans_esmfold'] - bb_center_esmfold[None, :]).float()

            chain_feats = output_temp
            chain_feats['energy'] = energy

        if self._dataset_cfg.use_rotate_enhance:
            # Rotation augmentation; NOTE(review): rotvec components are drawn
            # from [0, 1), so the rotation angle is at most |(1,1,1)| rad.
            rot_vet = [random.random() for _ in range(3)]
            rot_mat = torch.tensor(scipy_R.from_rotvec(rot_vet).as_matrix())
            chain_feats['all_atom_positions'] = torch.einsum(
                'lij,kj->lik', chain_feats['all_atom_positions'],
                rot_mat.type(chain_feats['all_atom_positions'].dtype))

        # Recompute backbone frames from the (possibly cropped / rotated)
        # atom37 positions so trans_1 / rotmats_1 stay consistent with them.
        all_atom_mask = np.array([restype_atom37_mask[i] for i in chain_feats['aatype']])
        frame_feats = {
            'aatype': chain_feats['aatype'],
            'all_atom_positions': chain_feats['all_atom_positions'],
            'all_atom_mask': torch.tensor(all_atom_mask).double(),
        }
        frame_feats = data_transforms.atom37_to_frames(frame_feats)
        curr_rigid = rigid_utils.Rigid.from_tensor_4x4(frame_feats['rigidgroups_gt_frames'])[:, 0]
        chain_feats['trans_1'] = curr_rigid.get_trans()
        chain_feats['rotmats_1'] = curr_rigid.get_rots().get_rot_mats()
        chain_feats['bb_positions'] = (chain_feats['trans_1']).numpy().astype(chain_feats['bb_positions'].dtype)

        return chain_feats
|
|