import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_from_disk
import pytorch_lightning as pl
from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
import re
from tqdm import tqdm
import os
|
|
|
|
class DynamicBatchingDataset(Dataset):
    """Wraps a pre-tokenized dataset in which each entry is expected to hold an
    already-grouped (dynamically sized) batch, so the DataLoaders below use
    batch_size=1 and the collate function simply unpacks the single entry."""

    def __init__(self, dataset_dict, tokenizer):
        print('Initializing dataset...')
        self.dataset_dict = {
            'attention_mask': [torch.tensor(item) for item in tqdm(dataset_dict['attention_mask'])],
            'input_ids': [torch.tensor(item) for item in dataset_dict['input_ids']],
            'labels': dataset_dict['labels']
        }
        self.tokenizer = tokenizer

    def __len__(self):
        return len(self.dataset_dict['attention_mask'])

    def __getitem__(self, idx):
        if isinstance(idx, int):
            return {
                'input_ids': self.dataset_dict['input_ids'][idx],
                'attention_mask': self.dataset_dict['attention_mask'][idx],
                'labels': self.dataset_dict['labels'][idx]
            }
        elif isinstance(idx, list):
            return {
                'input_ids': [self.dataset_dict['input_ids'][i] for i in idx],
                'attention_mask': [self.dataset_dict['attention_mask'][i] for i in idx],
                'labels': [self.dataset_dict['labels'][i] for i in idx]
            }
        else:
            raise ValueError(f"Expected idx to be int or list, but got {type(idx)}")
|
|
class CustomDataModule(pl.LightningDataModule):
    def __init__(self, dataset_path, tokenizer):
        super().__init__()
        self.dataset = load_from_disk(dataset_path)
        self.tokenizer = tokenizer
        self.dataset_path = dataset_path

    def peptide_bond_mask(self, smiles_list):
        """
        Returns a mask of shape (batch_size, seq_length) with 1s at the character
        positions of recognized bonds (peptide, ester, N-methyl) and 0s elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).

        Returns:
            torch.Tensor: A mask of shape (batch_size, seq_length) with 1s at bond positions.
        """
        batch_size = len(smiles_list)
        max_seq_length = 1035  # fixed maximum SMILES length, in characters
        mask = torch.zeros((batch_size, max_seq_length), dtype=torch.int)

        # Patterns are tried in priority order: more specific bonds (ester,
        # N-methylated, ring-numbered) claim their characters before the
        # generic peptide-bond patterns.
        bond_patterns = [
            (r'OC\(=O\)', 'ester'),
            (r'N\(C\)C\(=O\)', 'n_methyl'),
            (r'N[12]C\(=O\)', 'peptide'),
            (r'NC\(=O\)', 'peptide'),
            (r'C\(=O\)N\(C\)', 'n_methyl'),
            (r'C\(=O\)N[12]?', 'peptide')
        ]
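        # Worked example (hypothetical dipeptide, not taken from the data): for the
        # Ala-Ala SMILES 'CC(N)C(=O)NC(C)C(=O)O', only r'C\(=O\)N[12]?' matches,
        # at characters 5-10, so mask[i, 5:11] is set to 1.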
|
|
        for batch_idx, smiles in enumerate(smiles_list):
            positions = []
            used = set()

            # Collect non-overlapping matches; characters already claimed by an
            # earlier (higher-priority) pattern are never reassigned.
            for pattern, bond_type in bond_patterns:
                for match in re.finditer(pattern, smiles):
                    if not any(p in used for p in range(match.start(), match.end())):
                        positions.append({
                            'start': match.start(),
                            'end': match.end(),
                            'type': bond_type,
                            'pattern': match.group()
                        })
                        used.update(range(match.start(), match.end()))

            # Mark the character span of every recognized bond.
            for pos in positions:
                mask[batch_idx, pos['start']:pos['end']] = 1

        return mask
|
|
    def peptide_token_mask(self, smiles_list, token_lists):
        """
        Returns a mask of shape (batch_size, num_tokens) with 1 for tokens where
        any part of the token overlaps with a peptide bond, and 0 elsewhere.

        Args:
            smiles_list: List of peptide SMILES strings (batch of SMILES strings).
            token_lists: List of tokenized SMILES strings (split into tokens).

        Returns:
            torch.Tensor: A mask of shape (batch_size, num_tokens) with 1s for peptide bond tokens.
        """
        batch_size = len(smiles_list)
        token_seq_length = max(len(tokens) for tokens in token_lists)
        tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
        atomwise_masks = self.peptide_bond_mask(smiles_list)

        # Walk each token sequence, consuming the character-level mask token by
        # token; the first and last tokens are skipped (assumed to be special
        # start/end tokens that do not correspond to SMILES characters).
        for batch_idx, atomwise_mask in enumerate(atomwise_masks):
            token_seq = token_lists[batch_idx]
            atom_idx = 0

            for token_idx, token in enumerate(token_seq):
                if token_idx != 0 and token_idx != len(token_seq) - 1:
                    if torch.sum(atomwise_mask[atom_idx:atom_idx + len(token)]) >= 1:
                        tokenized_masks[batch_idx][token_idx] = 1
                    atom_idx += len(token)

        return tokenized_masks
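    # Illustration (hypothetical tokenization): for the Ala-Ala example above with
    # tokens ['[CLS]', 'CC(N)', 'C(=O)N', 'C(C)', 'C(=O)O', '[SEP]'], the character
    # mask covers positions 5-10, so only 'C(=O)N' is flagged and the resulting
    # token mask is [0, 0, 1, 0, 0, 0].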

    def collate_fn(self, batch):
        # batch_size=1 and each dataset entry is expected to already hold a grouped
        # batch, so the single element is unpacked directly.
        item = batch[0]

        token_array = self.tokenizer.get_token_split(item['input_ids'])
        bond_mask = self.peptide_token_mask(item['labels'], token_array)

        return {
            'input_ids': item['input_ids'],
            'attention_mask': item['attention_mask'],
            'bond_mask': bond_mask
        }

    def _train_dataset(self):
        return DynamicBatchingDataset(self.dataset['train'], tokenizer=self.tokenizer)

    def _val_dataset(self):
        return DynamicBatchingDataset(self.dataset['val'], tokenizer=self.tokenizer)

    def train_dataloader(self):
        train_dataset = self._train_dataset()
        return DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            shuffle=True,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = self._val_dataset()
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True
        )
|
|
class RectifyDataModule(pl.LightningDataModule):
    def __init__(self, dataset_path):
        super().__init__()
        self.dataset_path = dataset_path

    def collate_fn(self, batch):
        # batch_size=1, so unpack the single dataset row and convert its fields to tensors.
        return {
            'source_ids': torch.tensor(batch[0]['source_ids']),
            'target_ids': torch.tensor(batch[0]['target_ids']),
            'bond_mask': torch.tensor(batch[0]['bond_mask']),
        }

    def train_dataloader(self):
        train_dataset = load_from_disk(os.path.join(self.dataset_path, 'train'))
        return DataLoader(
            train_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=12,
            pin_memory=True
        )

    def val_dataloader(self):
        val_dataset = load_from_disk(os.path.join(self.dataset_path, 'validation'))
        return DataLoader(
            val_dataset,
            batch_size=1,
            collate_fn=self.collate_fn,
            num_workers=8,
            pin_memory=True
        )
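

# Minimal usage sketch (a sanity check, not part of the training pipeline). The
# tokenizer construction and every path below are assumptions: the local
# SMILES_SPE_Tokenizer is assumed to follow the SmilesPE-style constructor that
# takes a vocabulary file and an SPE merges file.
if __name__ == '__main__':
    tokenizer = SMILES_SPE_Tokenizer(vocab_file='spe_vocab.txt',   # hypothetical vocab file
                                     spe_file='spe_codes.txt')     # hypothetical SPE merges file
    datamodule = CustomDataModule('path/to/tokenized_dataset', tokenizer)  # hypothetical dataset path

    # Each loaded batch is a pre-grouped set of sequences plus a token-level bond mask.
    batch = next(iter(datamodule.train_dataloader()))
    print(batch['bond_mask'].shape)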