| import argparse |
| from pathlib import Path |
| import os |
| import re |
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
| from datasets import Dataset, concatenate_datasets |
| import pdb |
|
|
| |
| |
| from train import MDLMLightningModule |
| from peptide_analyzer import PeptideAnalyzer |
| from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
def peptide_bond_mask(smiles_list, min_seq_length=1035):
    """
    Mark, per character, which parts of each SMILES string belong to a recognised bond.

    Bond regexes are tried in the priority order listed below. A match is accepted
    only if none of its characters were already claimed by an earlier accepted
    match, so every character is covered by at most one bond.

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        min_seq_length: Minimum width of the returned mask. The mask is widened to
            the length of the longest SMILES in the batch, so bonds located past
            this column are no longer silently dropped. Defaults to 1035, the
            previous fixed width.

    Returns:
        torch.Tensor: int tensor of shape
        (batch_size, max(min_seq_length, longest SMILES length)) with 1 at character
        positions covered by a bond and 0 elsewhere.
    """
    # (regex, label) pairs in priority order; the label is informational only.
    bond_patterns = [
        (r'OC\(=O\)', 'ester'),
        (r'N\(C\)C\(=O\)', 'n_methyl'),
        (r'N[12]C\(=O\)', 'peptide'),
        (r'NC\(=O\)', 'peptide'),
        (r'C\(=O\)N\(C\)', 'n_methyl'),
        (r'C\(=O\)N[12]?', 'peptide'),
    ]

    batch_size = len(smiles_list)
    # SMILES strings can be much longer (in characters) than the token budget, so
    # size the mask to fit the longest string instead of clipping at a fixed width.
    longest = max((len(smiles) for smiles in smiles_list), default=0)
    mask = torch.zeros((batch_size, max(min_seq_length, longest)), dtype=torch.int)

    for batch_idx, smiles in enumerate(smiles_list):
        used = set()  # character positions already claimed by an accepted bond
        for pattern, _bond_type in bond_patterns:
            for match in re.finditer(pattern, smiles):
                span = range(match.start(), match.end())
                if used.isdisjoint(span):
                    used.update(span)
                    mask[batch_idx, match.start():match.end()] = 1

    return mask
|
|
def peptide_token_mask(smiles_list, token_lists):
    """
    Project the character-level bond mask from ``peptide_bond_mask`` onto tokens.

    A token gets 1 when any character it spans lies inside a recognised bond.
    The first and last entries of each token list are never marked and do not
    advance the character cursor (presumably special begin/end tokens — confirm
    against the tokenizer's ``get_token_split`` output).

    Args:
        smiles_list: List of peptide SMILES strings (batch of SMILES strings).
        token_lists: Per-sample lists of token strings, aligned with ``smiles_list``.
            Concatenating the inner tokens is assumed to reproduce the SMILES string.

    Returns:
        torch.Tensor: int tensor of shape (batch_size, longest token list) with 1
        for bond-overlapping tokens and 0 elsewhere; shorter rows are zero-padded.
        An empty batch yields a (0, 0) tensor instead of raising.
    """
    batch_size = len(smiles_list)
    # default=0 keeps an empty batch from raising ValueError in max().
    token_seq_length = max((len(tokens) for tokens in token_lists), default=0)
    tokenized_masks = torch.zeros((batch_size, token_seq_length), dtype=torch.int)
    atomwise_masks = peptide_bond_mask(smiles_list)

    for batch_idx, atomwise_mask in enumerate(atomwise_masks):
        token_seq = token_lists[batch_idx]
        atom_idx = 0  # character offset in the SMILES where the current token starts
        # Skip the first and last token positions entirely.
        for token_idx in range(1, len(token_seq) - 1):
            token_len = len(token_seq[token_idx])
            if torch.sum(atomwise_mask[atom_idx:atom_idx + token_len]) >= 1:
                tokenized_masks[batch_idx, token_idx] = 1
            atom_idx += token_len

    return tokenized_masks
|
|
def generate_and_filter_batch(model, tokenizer, peptide_analyzer, seq_len, batch_size, n_steps, temperature, device):
    """
    Generates a single batch of SMILES, filters them for validity, and returns the valid ones
    along with their original corresponding noise tensors (x0) and final token tensors (x1).

    Generation starts from uniformly random token ids (x0). At each step the model
    predicts clean tokens. With ``temperature > 0`` they are sampled from
    softmax(logits / temperature); with ``temperature == 0`` the argmax is taken.
    Except on the last step, every position is then independently replaced by a
    uniformly random token with probability ``1 - t_next``. The last step's
    prediction is the final x1.

    Args:
        model (MDLMLightningModule): The trained model; ``model(x, t)`` must return logits
            whose last dimension is the vocabulary (assumed (batch, seq_len, vocab)).
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        peptide_analyzer (PeptideAnalyzer): The analyzer to validate peptides.
        seq_len (int): The sequence length for this batch.
        batch_size (int): The number of samples to generate in this batch.
        n_steps (int): The number of steps for the flow matching process (must be >= 1).
        temperature (float): Sampling temperature; 0 means pure argmax, larger values
            give more diverse samples.
        device (str): The device to run generation on ('cuda' or 'cpu').

    Returns:
        tuple[list[str], list[torch.Tensor], list[torch.Tensor]]: A tuple containing:
            - A list of valid, generated peptide SMILES strings.
            - A list of the corresponding x0 tensors (noise).
            - A list of the corresponding x1 tensors (final generated tokens).

    Raises:
        ValueError: If ``n_steps`` < 1 (the noise itself would be returned as output).
    """
    if n_steps < 1:
        raise ValueError(f"n_steps must be >= 1, got {n_steps}")

    vocab_size = model.model.vocab_size
    x0 = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
    x = x0.clone()

    time_steps = torch.linspace(0.0, 1.0, n_steps + 1, device=device)

    with torch.no_grad():
        for i in range(n_steps):
            t_tensor = torch.full((batch_size,), time_steps[i], device=device)
            logits = model(x, t_tensor)

            if temperature > 0:
                # Dividing logits by a temperature does not change the argmax, so
                # the temperature only has an effect if we actually sample.
                pred_x1 = torch.distributions.Categorical(logits=logits / temperature).sample()
            else:
                pred_x1 = torch.argmax(logits, dim=-1)

            if i == n_steps - 1:
                x = pred_x1
                break

            # Re-noise each position with probability 1 - t_next (less noise as t -> 1).
            noise_prob = 1.0 - time_steps[i + 1]
            mask = torch.rand(x.shape, device=device) < noise_prob
            noise = torch.randint(0, vocab_size, x.shape, device=device)
            x = torch.where(mask, noise, pred_x1)

    generated_sequences = tokenizer.batch_decode(x)

    valid_smiles, valid_x0s, valid_x1s = [], [], []
    for i, seq in enumerate(generated_sequences):
        if peptide_analyzer.is_peptide(seq):
            valid_smiles.append(seq)
            valid_x0s.append(x0[i])
            valid_x1s.append(x[i])

    return valid_smiles, valid_x0s, valid_x1s
|
|
|
|
def _generate_length_chunks(model, tokenizer, peptide_analyzer, length, args, device):
    """Collect ``args.num_sequences_per_length`` valid peptides of one length, grouped into chunks.

    Each chunk holds at most ``args.max_tokens_in_batch // length`` sequences. Returns
    a list of ``(sources, targets, bonds)`` tuples, one per chunk. The last chunk may
    be only partially filled; an empty list is returned when the length does not fit
    the token budget at all.
    """
    max_batch_size = args.max_tokens_in_batch // length
    if max_batch_size == 0:
        print(f"Warning: Length {length} too long for token limit. Skipping.")
        return []

    chunks = []
    chunk_source, chunk_target, chunk_bond = [], [], []
    collected = 0

    with tqdm(total=args.num_sequences_per_length, desc=f"Length {length}") as pbar:
        # NOTE(review): loops until enough valid peptides are found; if the model
        # never yields a valid peptide at this length this will not terminate.
        while collected < args.num_sequences_per_length:
            num_needed = args.num_sequences_per_length - collected
            # Only generate what still fits in the current chunk, so a chunk never
            # exceeds max_batch_size (it is flushed as soon as it is full).
            actual_bsz = min(num_needed, max_batch_size - len(chunk_target))

            smiles, x0s, x1s = generate_and_filter_batch(
                model, tokenizer, peptide_analyzer, length, actual_bsz,
                args.n_steps, args.temperature, device,
            )

            if smiles:
                tokens = tokenizer.get_token_split(x1s)
                b_masks = peptide_token_mask(smiles, tokens)
                chunk_source.extend(x.tolist() for x in x0s)
                chunk_target.extend(x.tolist() for x in x1s)
                chunk_bond.extend(b_masks.tolist())
                collected += len(smiles)
                pbar.update(len(smiles))

            if len(chunk_target) == max_batch_size:
                chunks.append((chunk_source, chunk_target, chunk_bond))
                chunk_source, chunk_target, chunk_bond = [], [], []

    # Keep the final, partially filled chunk instead of discarding its sequences.
    if chunk_target:
        chunks.append((chunk_source, chunk_target, chunk_bond))
    return chunks


def _save_splits(all_data, output_dir):
    """Split dataset rows roughly 80/10/10 into train/validation/test and save each under output_dir."""
    train_val = all_data.train_test_split(test_size=0.1, seed=42)
    # 1/9 of the remaining 90% is ~10% of the total, used for validation.
    train_valid = train_val['train'].train_test_split(test_size=(1 / 9), seed=42)

    # Save the post-split train set so validation rows do not also appear in train.
    train_valid['train'].save_to_disk(os.path.join(output_dir, 'train'))
    train_valid['test'].save_to_disk(os.path.join(output_dir, 'validation'))
    train_val['test'].save_to_disk(os.path.join(output_dir, 'test'))


def main(args):
    """Generate (x0, x1, bond-mask) couplings for every length and save train/validation/test splits.

    Lengths are processed from ``args.max_length`` down to ``args.min_length``. Each
    dataset row is one chunk of same-length sequences: ``source_ids`` (noise x0),
    ``target_ids`` (generated x1) and ``bond_mask`` (token-level bond mask).

    Raises:
        ValueError: If ``args.min_length`` < 1 (length 0 would divide by zero).
        RuntimeError: If no valid sequences were generated for any length.
    """
    if args.min_length < 1:
        raise ValueError(f"--min_length must be >= 1, got {args.min_length}")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path, args=checkpoint["hyper_parameters"]["args"],
        tokenizer=tokenizer, strict=False,
    ).to(device).eval()
    pa = PeptideAnalyzer()

    all_sources, all_targets, all_bonds = [], [], []
    for length in range(args.max_length, args.min_length - 1, -1):
        print(f"\n--- Generating for length {length} ---")
        for sources, targets, bonds in _generate_length_chunks(model, tokenizer, pa, length, args, device):
            all_sources.append(sources)
            all_targets.append(targets)
            all_bonds.append(bonds)

    print("\n--- Combining all generated data chunks ---")
    if not all_targets:
        raise RuntimeError("No valid sequences were generated; nothing to save.")

    all_data = Dataset.from_dict({
        'source_ids': all_sources,
        'target_ids': all_targets,
        'bond_mask': all_bonds,
    })
    # Rows are chunks, so count the sequences inside them rather than the rows.
    num_sequences = sum(len(chunk) for chunk in all_targets)
    print(f"Total valid sequences collected: {num_sequences} ({len(all_data)} chunks)")

    print(f"Saving new rectified dataset to {args.output_dir}...")
    _save_splits(all_data, args.output_dir)

    print("\nDataset combination and saving complete.")
|
|
|
|
|
|
if __name__ == "__main__":
    cli = argparse.ArgumentParser(
        description="Generate rectified data couplings using a trained ReDi model for a range of lengths."
    )

    # Required input/output locations.
    cli.add_argument(
        "--checkpoint_path", type=str, required=True,
        help="Path to the model checkpoint (.ckpt file).",
    )
    cli.add_argument(
        "--output_dir", type=str, required=True,
        help="Directory to save the new rectified dataset.",
    )

    # Generation controls.
    cli.add_argument(
        "--num_sequences_per_length", type=int, default=100,
        help="Number of valid sequences to generate for each length.",
    )
    cli.add_argument(
        "--min_length", type=int, default=4,
        help="Minimum sequence length to generate.",
    )
    cli.add_argument(
        "--max_length", type=int, default=1035,
        help="Maximum sequence length to generate (and padding length).",
    )
    cli.add_argument(
        "--max_tokens_in_batch", type=int, default=5200,
        help="Maximum number of tokens in a single generation batch (batch_size * seq_len).",
    )
    cli.add_argument(
        "--n_steps", type=int, default=100,
        help="Number of steps for the flow matching process.",
    )
    cli.add_argument(
        "--temperature", type=float, default=1.0,
        help="Sampling temperature. Higher values increase diversity. Set to 0 for pure argmax.",
    )

    # Tokenizer resources (defaults are machine-specific paths).
    cli.add_argument(
        "--vocab_path", type=str,
        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt',
        help="Path to tokenizer vocabulary file.",
    )
    cli.add_argument(
        "--splits_path", type=str,
        default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt',
        help="Path to tokenizer splits file.",
    )

    main(cli.parse_args())
|
|
|
|