| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| from tqdm import tqdm |
|
|
| |
| from smiles_train import MDLMLightningModule, PeptideAnalyzer |
| from smiles_tokenizer.my_tokenizers import SMILES_SPE_Tokenizer |
|
|
| import pdb |
|
|
|
|
def generate_smiles(model, tokenizer, args):
    """
    Generates peptide SMILES strings using the trained MDLM model
    with a forward (t=0 to t=1) flow matching process.

    Args:
        model (MDLMLightningModule): The trained PyTorch Lightning model.
        tokenizer (SMILES_SPE_Tokenizer): The tokenizer used for training.
        args (argparse.Namespace): Command-line arguments containing sampling parameters
            (n_samples, seq_len, n_steps, temperature, device).

    Returns:
        list[str]: A list of generated SMILES strings that pass the peptide check.
        float: The validity rate of the generated SMILES.
    """
    print("Starting SMILES generation with forward flow matching (t=0 to t=1)...")
    model.eval()
    device = args.device

    # Start from pure token noise: uniform random ids over the vocabulary.
    x = torch.randint(
        0,
        model.model.vocab_size,
        (args.n_samples, args.seq_len),
        device=device
    )

    # n_steps + 1 time points so each step has a (t_curr, t_next) pair.
    time_steps = torch.linspace(0.0, 1.0, args.n_steps + 1, device=device)

    with torch.no_grad():
        for i in tqdm(range(args.n_steps), desc="Flow Matching Steps"):
            t_curr = time_steps[i]
            t_next = time_steps[i + 1]

            # The model expects one time value per sample in the batch.
            t_tensor = torch.full((args.n_samples,), t_curr, device=device)

            logits = model(x, t_tensor)

            # BUG FIX: the original divided logits by temperature and then took
            # argmax, which makes temperature a no-op (scaling by a positive
            # constant never changes the argmax). Sample from the tempered
            # categorical instead so temperature actually controls diversity;
            # temperature <= 0 falls back to greedy decoding (and avoids a
            # division by zero).
            if args.temperature > 0:
                probs = F.softmax(logits / args.temperature, dim=-1)
                pred_x1 = torch.multinomial(
                    probs.view(-1, probs.size(-1)), num_samples=1
                ).view(x.shape)
            else:
                pred_x1 = torch.argmax(logits, dim=-1)

            # Final step: commit to the predicted clean sequence.
            if i == args.n_steps - 1:
                x = pred_x1
                break

            # Re-noise toward t_next: each position is independently replaced
            # by a random token with probability (1 - t_next), so noise
            # decreases as t approaches 1.
            noise_prob = 1.0 - t_next
            mask = torch.rand(x.shape, device=device) < noise_prob

            noise = torch.randint(
                0,
                model.model.vocab_size,
                x.shape,
                device=device
            )

            x = torch.where(mask, noise, pred_x1)

    generated_sequences = tokenizer.batch_decode(x)

    # Keep only sequences that parse as peptides.
    peptide_analyzer = PeptideAnalyzer()
    valid_smiles = [seq for seq in generated_sequences if peptide_analyzer.is_peptide(seq)]
    valid_count = len(valid_smiles)

    # Guard against n_samples == 0 (original divided unconditionally).
    validity_rate = valid_count / len(generated_sequences) if generated_sequences else 0.0

    print(f"\nGeneration complete. Validity rate: {validity_rate:.2%}")
    return valid_smiles, validity_rate
|
|
|
|
def main():
    """Parse CLI arguments, load the checkpointed model, sample SMILES, and save them."""
    parser = argparse.ArgumentParser(description="Sample from a trained ReDi model.")

    parser.add_argument("--checkpoint_path", type=str, required=True, help="Path to the model checkpoint (.ckpt file).")

    parser.add_argument("--n_samples", type=int, default=16, help="Number of SMILES strings to generate.")
    parser.add_argument("--seq_len", type=int, default=256, help="Maximum sequence length for generated SMILES.")
    parser.add_argument("--n_steps", type=int, default=100, help="Number of denoising steps for sampling.")
    parser.add_argument("--temperature", type=float, default=1.0, help="Sampling temperature. Higher values increase diversity.")

    parser.add_argument("--vocab_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_vocab.txt', help="Path to tokenizer vocabulary file.")
    parser.add_argument("--splits_path", type=str, default='/scratch/pranamlab/tong/ReDi_discrete/smiles/smiles_tokenizer/new_splits.txt', help="Path to tokenizer splits file.")
    parser.add_argument("--output_file", type=str, default="generated_smiles.txt", help="File to save the valid generated SMILES.")

    args = parser.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Stash the device on args so generate_smiles can read it.
    args.device = device
    print(f"Using device: {device}")

    print("Loading tokenizer...")
    tokenizer = SMILES_SPE_Tokenizer(args.vocab_path, args.splits_path)

    print(f"Loading model from checkpoint: {args.checkpoint_path}")

    # NOTE(review): weights_only=False deserializes arbitrary pickled objects —
    # only load checkpoints from a trusted source.
    checkpoint = torch.load(args.checkpoint_path, map_location=device, weights_only=False)
    # Training-time hparams are needed to rebuild the model architecture.
    model_hparams = checkpoint["hyper_parameters"]["args"]

    model = MDLMLightningModule.load_from_checkpoint(
        args.checkpoint_path,
        args=model_hparams,
        tokenizer=tokenizer,
        map_location=device,
        strict=False
    )
    model.to(device)

    valid_smiles, validity_rate = generate_smiles(model, tokenizer, args)

    # BUG FIX: the original ignored --output_file and appended to a hard-coded
    # './v0_samples_200.csv'. Honor the CLI argument; append mode is kept so
    # repeated runs still accumulate samples into the same file.
    with open(args.output_file, 'a') as f:
        for smiles in valid_smiles:
            f.write(smiles + '\n')
    print(validity_rate)
|
|
| if __name__ == "__main__": |
| main() |