|
|
| import os |
| from typing import Optional, Any |
|
|
| import numpy as np |
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| from utils import register as R |
| from utils.const import sidechain_atoms |
|
|
| from data.converter.list_blocks_to_pdb import list_blocks_to_pdb |
|
|
| from .format import VOCAB, Block, Atom |
| from .mmap_dataset import MMAPDataset |
| from .resample import ClusterResampler |
|
|
|
|
|
|
def calculate_covariance_matrix(point_cloud):
    """Compute the sample covariance matrix of a point cloud.

    Args:
        point_cloud: array-like of shape (n_points, n_dims), one point per row.

    Returns:
        np.ndarray: the (n_dims, n_dims) covariance matrix (default ddof=1).
    """
    # rowvar=False: rows are observations (points), columns are variables (dims).
    return np.cov(point_cloud, rowvar=False)
|
|
|
|
@R.register('CoDesignDataset')
class CoDesignDataset(MMAPDataset):
    """Memory-mapped dataset yielding receptor-pocket / ligand residue pairs.

    Each item packs residues into dense per-residue tensors: atom
    coordinates ``X``, residue-type indexes ``S``, 1-based ``position_ids``,
    a receptor(0)/ligand(1) ``mask`` and a per-atom validity ``atom_mask``.
    """

    # Fixed per-residue atom budget: 4 backbone atom slots plus up to
    # 10 side-chain atom slots (padded with a placeholder coordinate).
    MAX_N_ATOM = 14

    def __init__(
        self,
        mmap_dir: str,
        backbone_only: bool,
        specify_data: Optional[str] = None,
        specify_index: Optional[str] = None,
        padding_collate: bool = False,
        cluster: Optional[str] = None,
        use_covariance_matrix: bool = False
    ) -> None:
        """Initialize the dataset.

        Args:
            mmap_dir: directory holding the memory-mapped data/index files.
            backbone_only: if True, keep only the first 4 (backbone) atom slots.
            specify_data: optional explicit data file path (forwarded to MMAPDataset).
            specify_index: optional explicit index file path (forwarded to MMAPDataset).
            padding_collate: if True, ``collate_fn`` pads items to equal length;
                otherwise it concatenates them along dim 0.
            cluster: optional cluster file for ClusterResampler; when given,
                the item order is resampled each epoch via ``update_epoch``.
            use_covariance_matrix: if True, each item also carries 'L', the
                Cholesky factor of the covariance of receptor coordinates.
        """
        super().__init__(mmap_dir, specify_data, specify_index)
        self.mmap_dir = mmap_dir
        self.backbone_only = backbone_only
        # Item length = number of pocket residues (comma-separated indexes in
        # prop[-1]) + int(prop[1]) — presumably the ligand residue count;
        # TODO confirm against the mmap properties layout.
        self._lengths = [len(prop[-1].split(',')) + int(prop[1]) for prop in self._properties]
        self.padding_collate = padding_collate
        self.resampler = ClusterResampler(cluster) if cluster else None
        self.use_covariance_matrix = use_covariance_matrix

        # Identity permutation by default; overwritten per epoch when a
        # resampler is configured.
        self.dynamic_idxs = [i for i in range(len(self))]
        self.update_epoch()

    def update_epoch(self):
        """Resample the item order for a new epoch (no-op without a resampler)."""
        if self.resampler is not None:
            self.dynamic_idxs = self.resampler(len(self))

    def get_len(self, idx):
        """Return the residue count of the (resampled) item, for length-based batching."""
        return self._lengths[self.dynamic_idxs[idx]]

    def get_summary(self, idx: int):
        """Return ``(_id, reference pdb path, receptor chain, ligand chain)``.

        NOTE(review): this indexes raw storage order and does NOT go through
        ``dynamic_idxs`` like ``get_len``/``__getitem__`` — confirm callers
        intentionally pass raw indexes here.
        """
        props = self._properties[idx]
        _id = self._indexes[idx][0].split('.')[0]
        ref_pdb = os.path.join(self.mmap_dir, '..', 'pdbs', _id + '.pdb')
        rec_chain, lig_chain = props[4], props[5]
        return _id, ref_pdb, rec_chain, lig_chain

    def __getitem__(self, idx: int):
        """Build one training item: pocket receptor residues + ligand residues."""
        idx = self.dynamic_idxs[idx]
        rec_blocks, lig_blocks = super().__getitem__(idx)

        # Keep only the pocket residues of the receptor; prop[-1] stores their
        # indexes as a comma-separated string. Position ids are 1-based and
        # taken from the full (uncropped) receptor chain.
        pocket_idx = [int(i) for i in self._properties[idx][-1].split(',')]
        rec_position_ids = [i + 1 for i, _ in enumerate(rec_blocks)]
        rec_blocks = [rec_blocks[i] for i in pocket_idx]
        rec_position_ids = [rec_position_ids[i] for i in pocket_idx]
        rec_blocks = [Block.from_tuple(tup) for tup in rec_blocks]
        lig_blocks = [Block.from_tuple(tup) for tup in lig_blocks]

        # mask: 0 for receptor residues, 1 for ligand residues.
        mask = [0 for _ in rec_blocks] + [1 for _ in lig_blocks]
        position_ids = rec_position_ids + [i + 1 for i, _ in enumerate(lig_blocks)]
        X, S, atom_mask = [], [], []
        for block in rec_blocks + lig_blocks:
            symbol = VOCAB.abrv_to_symbol(block.abrv)
            atom2coord = { unit.name: unit.get_coord() for unit in block.units }
            # Placeholder coordinate for missing/padded atom slots: the mean
            # of the atoms actually present in this residue.
            bb_pos = np.mean(list(atom2coord.values()), axis=0).tolist()
            coords, coord_mask = [], []
            # Canonical atom order: backbone atoms first, then the residue's
            # side-chain atoms (empty list for unknown symbols).
            for atom_name in VOCAB.backbone_atoms + sidechain_atoms.get(symbol, []):
                if atom_name in atom2coord:
                    coords.append(atom2coord[atom_name])
                    coord_mask.append(1)
                else:
                    coords.append(bb_pos)
                    coord_mask.append(0)
            # Pad every residue to the fixed MAX_N_ATOM atom slots.
            n_pad = self.MAX_N_ATOM - len(coords)
            for _ in range(n_pad):
                coords.append(bb_pos)
                coord_mask.append(0)

            X.append(coords)
            S.append(VOCAB.symbol_to_idx(symbol))
            atom_mask.append(coord_mask)

        X, atom_mask = torch.tensor(X, dtype=torch.float), torch.tensor(atom_mask, dtype=torch.bool)
        mask = torch.tensor(mask, dtype=torch.bool)
        if self.backbone_only:
            # Keep only the 4 backbone atom slots.
            X, atom_mask = X[:, :4], atom_mask[:, :4]

        if self.use_covariance_matrix:
            # Covariance over valid receptor coordinates at atom slot 1
            # (presumably CA, the second entry of VOCAB.backbone_atoms —
            # TODO confirm the backbone atom order).
            cov = calculate_covariance_matrix(X[~mask][:, 1][atom_mask[~mask][:, 1]].numpy())
            # Jitter the diagonal so the Cholesky factorization is stable.
            eps = 1e-4
            cov = cov + eps * np.identity(cov.shape[0])
            L = torch.from_numpy(np.linalg.cholesky(cov)).float().unsqueeze(0)
        else:
            L = None

        item = {
            'X': X,
            'S': torch.tensor(S, dtype=torch.long),
            'position_ids': torch.tensor(position_ids, dtype=torch.long),
            'mask': mask,
            'atom_mask': atom_mask,
            'lengths': len(S),
        }
        if L is not None:
            item['L'] = L
        return item

    def collate_fn(self, batch):
        """Collate items into a batch.

        With ``padding_collate`` items are padded to equal length along dim 0
        (residue types 'S' use the vocabulary PAD index, everything else 0);
        otherwise tensors are concatenated. 'lengths' is always stacked into
        a long tensor; a leading None value passes through as None.
        """
        if self.padding_collate:
            results = {}
            pad_idx = VOCAB.symbol_to_idx(VOCAB.PAD)
            for key in batch[0]:
                values = [item[key] for item in batch]
                if values[0] is None:
                    results[key] = None
                    continue
                if key == 'lengths':
                    results[key] = torch.tensor(values, dtype=torch.long)
                elif key == 'S':
                    # Residue types get the vocabulary PAD index, not 0.
                    results[key] = pad_sequence(values, batch_first=True, padding_value=pad_idx)
                else:
                    results[key] = pad_sequence(values, batch_first=True, padding_value=0)
            return results
        else:
            results = {}
            for key in batch[0]:
                values = [item[key] for item in batch]
                if values[0] is None:
                    results[key] = None
                    continue
                if key == 'lengths':
                    results[key] = torch.tensor(values, dtype=torch.long)
                else:
                    results[key] = torch.cat(values, dim=0)
            return results
|
|
|
|
@R.register('ShapeDataset')
class ShapeDataset(CoDesignDataset):
    """Dataset reducing each residue to a 2-point "shape".

    Builds items exactly like CoDesignDataset (full-atom mode) and then
    rewrites ``item['X']`` to shape (n_residues, 2, 3): the CA coordinate
    plus the side-chain atom furthest from CA.
    """

    def __init__(
        self,
        mmap_dir: str,
        specify_data: Optional[str] = None,
        specify_index: Optional[str] = None,
        padding_collate: bool = False,
        cluster: Optional[str] = None
    ) -> None:
        """See ``CoDesignDataset.__init__``; ``backbone_only`` is forced to
        False so side-chain atom slots are available."""
        super().__init__(mmap_dir, False, specify_data, specify_index, padding_collate, cluster)
        self.ca_idx = VOCAB.backbone_atoms.index('CA')

    def __getitem__(self, idx: int):
        """Return the parent item with 'X' reduced to (CA, furthest side-chain atom)."""
        item = super().__getitem__(idx)

        X = item['X']
        atom_mask = item['atom_mask']
        ca_x = X[:, self.ca_idx].unsqueeze(1)
        sc_x = X[:, 4:]  # side-chain atom slots only
        dist = torch.norm(sc_x - ca_x, dim=-1)
        # BUGFIX: missing/padded atoms must be EXCLUDED from the argmax, so
        # push their distance to -1e10. The original filled them with +1e10,
        # which made argmax always select a padding slot (whose coordinate is
        # the residue's placeholder position) whenever any atom was absent.
        # Residues with no side-chain atoms at all still fall back to slot 0,
        # i.e. the placeholder coordinate, as before.
        dist = dist.masked_fill(~atom_mask[:, 4:], -1e10)
        furthest_atom_x = sc_x[torch.arange(sc_x.shape[0]), torch.argmax(dist, dim=-1)]
        X = torch.cat([ca_x, furthest_atom_x.unsqueeze(1)], dim=1)

        # NOTE(review): item['atom_mask'] still describes the full-atom X
        # layout (MAX_N_ATOM slots), not the reduced 2-slot X — confirm
        # downstream consumers expect this.
        item['X'] = X
        return item
|
|
|
|
if __name__ == '__main__':
    import sys

    # Smoke test: load a dataset from the directory given on the command
    # line and print its first item.
    if len(sys.argv) < 2:
        # Fail with a usage message instead of an IndexError on sys.argv[1].
        print(f'usage: {sys.argv[0]} <mmap_dir>')
        sys.exit(1)
    dataset = CoDesignDataset(sys.argv[1], backbone_only=True)
    print(dataset[0])
|
|