| """ |
| TD3B Data Utilities |
| Handles loading and preprocessing of TD3B_data.csv for both oracle training and finetuning. |
| """ |
|
|
| import pandas as pd |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
| from typing import Dict, List, Optional, Tuple |
| import sys |
|
|
| try: |
| from rdkit import Chem |
| except ImportError: |
| Chem = None |
|
|
| sys.path.append('..') |
|
|
| AA_SET = set("ACDEFGHIKLMNPQRSTVWY") |
|
|
|
|
def is_amino_acid_sequence(seq: str) -> bool:
    """Return True if seq is a non-empty string of canonical one-letter residue codes."""
    if not isinstance(seq, str):
        return False
    seq = seq.strip().upper()
    # Re-check after stripping so that whitespace-only input is rejected
    # (all() over an empty string would otherwise return True).
    if not seq:
        return False
    return all(ch in AA_SET for ch in seq)

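# Example: is_amino_acid_sequence("acde") is True (case-insensitive), while
# is_amino_acid_sequence("AC DE") is False because the space is not a residue code.

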
def aa_sequence_to_smiles(seq: str) -> Optional[str]:
    """Convert an amino acid sequence to isomeric SMILES via RDKit; return None on any failure."""
    if Chem is None or not is_amino_acid_sequence(seq):
        return None
    try:
        mol = Chem.MolFromSequence(seq)
    except Exception:
        return None
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, isomericSmiles=True)


def peptide_seq_to_smiles(seq: str) -> str:
    """Convert a peptide sequence to SMILES, falling back to the raw sequence on failure."""
    smiles = aa_sequence_to_smiles(seq)
    return smiles if smiles is not None else seq

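# Note: because peptide_seq_to_smiles() falls back to the raw sequence, its return
# value is always usable as model input; callers that need to distinguish converted
# from unconverted peptides should call aa_sequence_to_smiles() directly.

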
def smiles_token_length(smiles: str, tokenizer) -> int:
    """Length of a SMILES string in tokenizer tokens (character count if tokenizer is None)."""
    if tokenizer is None:
        return len(smiles)
    tokens = tokenizer(smiles, return_tensors="pt")["input_ids"][0]
    return int(tokens.numel())

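# Note: token count generally differs from len(smiles), since tokenizers can emit
# multi-character tokens (e.g. "Cl", "[C@@H]") and may add special tokens such as
# BOS/EOS; the two measures should not be used interchangeably.

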
class TD3BDataset(Dataset):
    """
    Dataset for TD3B that loads peptide-protein pairs with directional labels.

    Supports both:
    1. Oracle training: uses all pairs for training f_φ
    2. Finetuning: provides target proteins for conditioning during RL
    """

    def __init__(
        self,
        data_path: str,
        mode: str = 'oracle',
        peptide_tokenizer=None,
        protein_tokenizer=None,
        max_peptide_length: int = 200,
        max_protein_length: int = 1000,
        target_protein_id: Optional[str] = None,
        convert_peptide_to_smiles: bool = True,
    ):
        """
        Args:
            data_path: Path to TD3B_data.csv
            mode: 'oracle' for training f_φ, 'finetune' for RL conditioning
            peptide_tokenizer: Tokenizer for peptide sequences
            protein_tokenizer: Tokenizer for protein sequences (ESM-2)
            max_peptide_length: Maximum peptide sequence length
            max_protein_length: Maximum protein sequence length
            target_protein_id: UniProt ID for target protein (finetuning mode)
            convert_peptide_to_smiles: Convert peptides to SMILES before tokenization
        """
        self.mode = mode
        self.data_path = data_path
        self.peptide_tokenizer = peptide_tokenizer
        self.protein_tokenizer = protein_tokenizer
        self.max_peptide_length = max_peptide_length
        self.max_protein_length = max_protein_length
        self.convert_peptide_to_smiles = convert_peptide_to_smiles

        self.data = pd.read_csv(data_path)
        print(f"Loaded {len(self.data)} peptide-protein pairs from {data_path}")

        # In finetuning mode, keep only the pairs for the requested target.
        if mode == 'finetune' and target_protein_id is not None:
            self.data = self.data[self.data['Target_UniProt_ID'] == target_protein_id]
            print(f"Filtered to {len(self.data)} pairs for target {target_protein_id}")

        # Directional labels: agonist -> +1, antagonist -> -1, neutral -> 0.
        # Any other label string maps to NaN.
        self.label_map = {
            'agonist': 1.0,
            'antagonist': -1.0,
            'neutral': 0.0,
        }
        self.data['numeric_label'] = self.data['label'].map(self.label_map)

        # Heuristic confidence weight derived from the free-text Action column.
        self.data['confidence'] = self.data['Action'].apply(self._action_to_confidence)

    def _action_to_confidence(self, action: str) -> float:
        """
        Convert an action description to a confidence score.

        Full agonist/antagonist: 1.0
        Partial/Weak: 0.7
        Slows/Modulator: 0.5
        Anything else: 0.8
        """
        # str() guards against non-string values (e.g. NaN) in the Action column.
        action_lower = str(action).lower()

        if 'full' in action_lower:
            return 1.0
        elif 'partial' in action_lower or 'weak' in action_lower:
            return 0.7
        elif 'slows' in action_lower or 'modulator' in action_lower:
            return 0.5
        else:
            return 0.8

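    # Illustrative inputs for the rules above: "Full agonist" -> 1.0,
    # "Partial agonist" -> 0.7, "Slows inactivation" -> 0.5, plain "Agonist" -> 0.8.
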
    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.iloc[idx]

        peptide_seq = row['Ligand_Sequence']
        protein_seq = row['Target_Sequence']
        peptide_smiles = self._peptide_to_smiles(peptide_seq)
        peptide_smiles_length = smiles_token_length(peptide_smiles, self.peptide_tokenizer)

        # Tokenize the peptide (SMILES form); fall back to a zero tensor
        # when no peptide tokenizer was provided.
        if self.peptide_tokenizer is not None:
            peptide_tokens = self._tokenize_peptide(peptide_smiles)
        else:
            peptide_tokens = torch.zeros(self.max_peptide_length, dtype=torch.long)

        # Tokenize the protein; fall back to the character-level placeholder.
        if self.protein_tokenizer is not None:
            protein_tokens = self._tokenize_protein(protein_seq)
        else:
            protein_tokens = self._tokenize_protein_placeholder(protein_seq)

        label = torch.tensor(row['numeric_label'], dtype=torch.float32)
        confidence = torch.tensor(row['confidence'], dtype=torch.float32)

        return {
            'peptide_seq': peptide_seq,
            'peptide_smiles': peptide_smiles,
            'peptide_smiles_length': peptide_smiles_length,
            'protein_seq': protein_seq,
            'peptide_tokens': peptide_tokens,
            'protein_tokens': protein_tokens,
            'label': label,
            'confidence': confidence,
            'target_id': row['Target_UniProt_ID'],
            'ligand_id': row['Ligand_UniProt_ID'],
            'action': row['Action'],
        }

    def _peptide_to_smiles(self, peptide_seq: str) -> str:
        """Convert the peptide to SMILES unless conversion is disabled."""
        if not self.convert_peptide_to_smiles:
            return peptide_seq
        return peptide_seq_to_smiles(peptide_seq)

    def _tokenize_peptide(self, peptide_seq: str) -> torch.Tensor:
        """Tokenize a peptide (SMILES) string using the provided tokenizer."""
        tokens = self.peptide_tokenizer(
            peptide_seq,
            return_tensors='pt',
            padding='max_length',
            max_length=self.max_peptide_length,
            truncation=True
        )['input_ids'].squeeze(0)
        return tokens

    def _tokenize_protein_placeholder(self, protein_seq: str) -> torch.Tensor:
        """
        Placeholder protein tokenizer (character-level). Always returns a
        LongTensor of shape (max_protein_length,).

        NOTE: Replace with ESM-2 tokenizer in production:
            from esm import pretrained
            _, alphabet = pretrained.esm2_t33_650M_UR50D()
            batch_converter = alphabet.get_batch_converter()
            _, _, tokens = batch_converter([("protein", protein_seq)])
        """
        # Map the 20 canonical residues to 1..20; 0 is padding, 21 is unknown.
        aa_to_idx = {aa: i + 1 for i, aa in enumerate('ACDEFGHIKLMNPQRSTVWY')}
        aa_to_idx['<PAD>'] = 0
        aa_to_idx['<UNK>'] = 21

        indices = [aa_to_idx.get(aa, aa_to_idx['<UNK>']) for aa in protein_seq]

        # Truncate or right-pad to the fixed maximum length.
        if len(indices) > self.max_protein_length:
            indices = indices[:self.max_protein_length]
        else:
            indices += [0] * (self.max_protein_length - len(indices))

        return torch.tensor(indices, dtype=torch.long)

    def _tokenize_protein(self, protein_seq: str) -> torch.Tensor:
        """Tokenize the protein with the provided tokenizer, falling back to the placeholder."""
        if self.protein_tokenizer is None:
            return self._tokenize_protein_placeholder(protein_seq)

        # Assumes a Hugging Face-style tokenizer with the same call signature as
        # the peptide tokenizer above; if the tokenizer does not support that
        # interface, fall back to the character-level placeholder.
        try:
            return self.protein_tokenizer(
                protein_seq,
                return_tensors='pt',
                padding='max_length',
                max_length=self.max_protein_length,
                truncation=True
            )['input_ids'].squeeze(0)
        except (TypeError, KeyError):
            return self._tokenize_protein_placeholder(protein_seq)

    def get_target_proteins(self) -> Dict[str, str]:
        """
        Get dictionary of unique target proteins.

        Returns:
            dict: {UniProt_ID: Sequence}
        """
        unique_targets = self.data.drop_duplicates(subset=['Target_UniProt_ID'])
        return dict(zip(unique_targets['Target_UniProt_ID'], unique_targets['Target_Sequence']))

    def get_ligands_for_target(self, target_id: str) -> List[Dict]:
        """
        Get all ligands (peptides) for a specific target protein.

        Args:
            target_id: Target protein UniProt ID

        Returns:
            List of dicts with ligand info
        """
        target_data = self.data[self.data['Target_UniProt_ID'] == target_id]

        ligands = []
        for _, row in target_data.iterrows():
            ligands.append({
                'sequence': row['Ligand_Sequence'],
                'uniprot_id': row['Ligand_UniProt_ID'],
                'label': row['numeric_label'],
                'confidence': row['confidence'],
                'action': row['Action'],
            })

        return ligands


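# Usage sketch (illustrative): batching TD3BDataset for oracle training with a
# standard DataLoader. `my_tokenizer` is a stand-in for any Hugging Face-style
# tokenizer with the call signature used in the class above.
#
#     from torch.utils.data import DataLoader
#     dataset = TD3BDataset("TD3B_data.csv", mode='oracle', peptide_tokenizer=my_tokenizer)
#     loader = DataLoader(dataset, batch_size=32, shuffle=True)
#     batch = next(iter(loader))
#     # batch['peptide_tokens']: LongTensor of shape (32, max_peptide_length)
#     # batch['label'], batch['confidence']: FloatTensor of shape (32,)

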
def load_td3b_data(
    data_path: str,
    mode: str = 'oracle',
    target_protein_id: Optional[str] = None
) -> Tuple[pd.DataFrame, Dict]:
    """
    Load and summarize TD3B data.

    Args:
        data_path: Path to TD3B_data.csv
        mode: 'oracle' or 'finetune'
        target_protein_id: Filter by target protein (finetuning mode)

    Returns:
        data: Filtered DataFrame
        stats: Dictionary of statistics
    """
    data = pd.read_csv(data_path)

    # In finetuning mode, keep only the pairs for the requested target.
    if mode == 'finetune' and target_protein_id is not None:
        data = data[data['Target_UniProt_ID'] == target_protein_id]

    # Summary statistics over the (possibly filtered) data.
    stats = {
        'total_pairs': len(data),
        'unique_targets': data['Target_UniProt_ID'].nunique(),
        'unique_ligands': data['Ligand_UniProt_ID'].nunique(),
        'agonist_count': (data['label'] == 'agonist').sum(),
        'antagonist_count': (data['label'] == 'antagonist').sum(),
        'action_distribution': data['Action'].value_counts().to_dict()
    }

    return data, stats


def create_target_dataset_for_finetuning(
    data_path: str,
    target_protein_id: str,
    desired_direction: str = 'agonist'
) -> Dict:
    """
    Create a dataset for TD3B finetuning focused on a specific target.

    Args:
        data_path: Path to TD3B_data.csv
        target_protein_id: Target protein UniProt ID
        desired_direction: 'agonist' or 'antagonist'

    Returns:
        dict with target protein info and example ligands
    """
    if desired_direction not in ('agonist', 'antagonist'):
        raise ValueError(
            f"desired_direction must be 'agonist' or 'antagonist', got {desired_direction!r}"
        )

    data = pd.read_csv(data_path)
    target_data = data[data['Target_UniProt_ID'] == target_protein_id]

    if len(target_data) == 0:
        raise ValueError(f"No data found for target {target_protein_id}")

    # Assumes all rows for a target share the same sequence; take it from the first.
    protein_seq = target_data.iloc[0]['Target_Sequence']

    # Ligands acting in the desired direction, plus counterexamples in the opposite one.
    direction_ligands = target_data[target_data['label'] == desired_direction]
    opposite_direction = 'antagonist' if desired_direction == 'agonist' else 'agonist'
    opposite_ligands = target_data[target_data['label'] == opposite_direction]

    return {
        'target_protein_id': target_protein_id,
        'target_protein_seq': protein_seq,
        'desired_direction': desired_direction,
        'n_desired_examples': len(direction_ligands),
        'n_opposite_examples': len(opposite_ligands),
        'desired_ligands': direction_ligands[['Ligand_Sequence', 'Action', 'Ligand_UniProt_ID']].to_dict('records'),
        'opposite_ligands': opposite_ligands[['Ligand_Sequence', 'Action', 'Ligand_UniProt_ID']].to_dict('records')
    }


if __name__ == "__main__":
    data_path = "../TD3B_data.csv"

    print("=" * 80)
    print("TD3B Data Loading Example")
    print("=" * 80)

    # Load the full dataset and print summary statistics.
    data, stats = load_td3b_data(data_path, mode='oracle')

    print("\nDataset Statistics:")
    for key, value in stats.items():
        print(f"  {key}: {value}")

    print("\n" + "=" * 80)
    print("Oracle Training Dataset")
    print("=" * 80)

    dataset = TD3BDataset(data_path, mode='oracle')
    print(f"Dataset size: {len(dataset)}")

    # Inspect a single item.
    sample = dataset[0]
    print("\nSample item:")
    print(f"  Target: {sample['target_id']}")
    print(f"  Ligand: {sample['ligand_id']}")
    print(f"  Label: {sample['label'].item()}")
    print(f"  Confidence: {sample['confidence'].item()}")
    print(f"  Action: {sample['action']}")

    print("\n" + "=" * 80)
    print("Finetuning Dataset Example")
    print("=" * 80)

    # Pick the first target in the dataset as an example.
    targets = dataset.get_target_proteins()
    first_target_id = next(iter(targets))

    finetune_info = create_target_dataset_for_finetuning(
        data_path,
        first_target_id,
        desired_direction='agonist'
    )

    print(f"\nTarget: {finetune_info['target_protein_id']}")
    print(f"Desired direction: {finetune_info['desired_direction']}")
    print(f"Number of agonist examples: {finetune_info['n_desired_examples']}")
    print(f"Number of antagonist examples: {finetune_info['n_opposite_examples']}")
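    # Finetuning-mode dataset restricted to this target; exercises the
    # mode='finetune' filtering path of TD3BDataset demonstrated above.
    finetune_dataset = TD3BDataset(
        data_path,
        mode='finetune',
        target_protein_id=first_target_id,
    )
    print(f"\nFinetuning dataset size: {len(finetune_dataset)}")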
|