"""Guided peptide sampling script.

Loads the sampling solver and the score models selected by ``args.objectives``,
then draws ``args.n_batches`` samples with ``solver.multi_guidance_sample``. Each
sample and its per-objective scores are printed and appended to
``args.output_file``.
"""
import logging
import os
import warnings

# These must be set before the heavy ML libraries (transformers, torch, ...)
# are imported further down, because they are read at import time.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"  # silence TensorFlow C++ logging
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # no HF tokenizers thread pool / fork warning

# Blanket-ignore every warning. This single filter already covers UserWarning
# and FutureWarning, so no per-category "ignore" filters are needed for them.
warnings.filterwarnings("ignore")

# Kept explicit (and keeps sklearn's import here). NOTE(review): presumably
# triggered when the property classifiers unpickle models saved with another
# sklearn version.
from sklearn.exceptions import InconsistentVersionWarning
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)

# Only errors from the root logger and the chatty third-party loggers.
logging.getLogger().setLevel(logging.ERROR)
for _logger_name in (
    "lightning",
    "pytorch_lightning",
    "transformers",
    "absl",
    "lightning.fabric.utilities.seed",
    "pytorch_lightning.utilities.seed",
):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

# transformers also has its own verbosity switch and progress bars.
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()
|
|
import inspect
from pathlib import Path

import torch
from transformers import AutoTokenizer

# Project-local helpers: parse_motifs, load_solver, load_bindevaluator and the
# score-model classes all come in through this star import.
from models.peptiverse_classifiers import *
from utils.parsing import parse_guidance_args

args = parse_guidance_args()

# Fixed sampling settings.
device = 'cuda:0'
source_distribution = "uniform"
vocab_size = 24
n_samples = 1
step_size = 1 / 100

# Per-run settings taken from the command line.
target = args.target_protein
length = args.length

# `motifs` is only defined when motifs were supplied on the command line.
if args.motifs:
    motifs = parse_motifs(args.motifs).to(device)

# ESM-2 tokenizer, also used to encode the target protein for the motif model.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt').to(device)
|
|
| |
# Sampling solver (checkpoint path is fixed).
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

# 'Motif' and 'Specificity' share one MotifModelWT. 'Specificity' turns on its
# penalty term, which the sampler counts as one extra objective
# (num_objectives = len(score_models) + int(args.specificity)). Always set the
# flag so later code never hits a missing attribute.
args.specificity = 'Specificity' in args.objectives
needs_motif_model = args.specificity or 'Motif' in args.objectives
if needs_motif_model and not args.motifs:
    # `motifs` only exists when args.motifs was given; fail clearly here
    # rather than with a NameError while building the motif model.
    raise ValueError("The 'Motif' and 'Specificity' objectives require motifs (args.motifs) to be provided.")

# Constructors for the single-output score models, keyed by objective name.
score_model_factories = {
    'Hemolysis': HemolysisWT,
    'Non-Fouling': NonfoulingWT,
    'Solubility': Solubility,
    'Permeability': PermeabilityWT,
    'Half-Life': HalfLifeWT,
    'Affinity': lambda: AffinityWT(target),
}

# Build the models in the order the objectives were given. Downstream code
# pairs score_models[i] with args.objectives[i] (score transform, printed
# labels, output-file header), so a fixed construction order would silently
# mismatch scores and names whenever objectives are passed in another order.
score_models = []
motif_model = None
for objective in args.objectives:
    if objective in score_model_factories:
        score_models.append(score_model_factories[objective]())
    elif objective in ('Motif', 'Specificity'):
        if motif_model is None:  # one shared model serves both objectives
            bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
            motif_model = MotifModelWT(bindevaluator, target_sequence['input_ids'], motifs, tokenizer, device, penalty=args.specificity)
            score_models.append(motif_model)
    else:
        known = sorted(list(score_model_factories) + ['Motif', 'Specificity'])
        raise ValueError(f"Unknown objective {objective!r}; expected one of {known}")
|
|
# Header row of the output file: "Binder,<objective 1>,<objective 2>,...".
objective_line = "Binder," + ",".join(obj.replace(' ', '') for obj in args.objectives) + '\n'

# Keep an existing output file only if its first line is this exact header;
# sampled rows are appended to it later. Otherwise (missing, empty, or written
# for a different objective set) recreate it with just the header.
# NOTE: a file with a mismatching header is overwritten, discarding its rows.
# Only the first line is read, so a large results file is not loaded into memory.
output_path = Path(args.output_file)
first_line = ''
if output_path.exists():
    with open(output_path, 'r', encoding='utf-8') as f:
        first_line = f.readline()

if first_line != objective_line:
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(objective_line)
|
|
# ---- Loop-invariant inputs, computed once instead of once per batch. ----

# Positions the sampler must keep fixed. NOTE(review): parsed with the same
# helper as motifs — presumably yields a list of position indices.
if args.fixed_positions is not None:
    fixed_positions = parse_motifs(args.fixed_positions).tolist()
else:
    fixed_positions = []

# Token ids 0-3 are never sampled (the uniform init below also draws only ids >= 4).
invalid_tokens = torch.tensor([0, 1, 2, 3], device=device)

# Whether each score model's forward (or the model itself) takes a `t` argument.
takes_time_arg = [
    't' in inspect.signature(m.forward if hasattr(m, 'forward') else m).parameters
    for m in score_models
]

# Normally set while building the score models; default defensively.
use_specificity = getattr(args, 'specificity', False)

for _ in range(args.n_batches):
    # Body tokens only; the start (id 0) and end (id 2) tokens are added below.
    if args.starting_sequence:
        # add_special_tokens=False: the ESM tokenizer would otherwise add its own
        # <cls>/<eos>, which together with the ones added below would double them.
        x_init = tokenizer(args.starting_sequence, add_special_tokens=False, return_tensors='pt')['input_ids'].to(device)
    elif source_distribution == "uniform":
        x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
    else:
        raise NotImplementedError

    # Wrap every sequence as [0, body..., 2].
    zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
    twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
    x_init = torch.cat([zeros, x_init, twos], dim=1)

    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0-1e-3]),
                                       score_models=score_models,
                                       num_objectives=len(score_models) + int(use_specificity),
                                       weights=args.weights,
                                       tokenizer=tokenizer,
                                       fixed_positions=fixed_positions,
                                       invalid_tokens=invalid_tokens)

    # Decode the first sample, drop spaces, and strip the 5-character start/end
    # special tokens ("<cls>" / "<eos>").
    input_seqs = [tokenizer.batch_decode(x_1)[0].replace(' ', '')[5:-5]]

    scores = []
    for model_idx, (model, takes_t) in enumerate(zip(score_models, takes_time_arg)):
        candidate_scores = model(input_seqs, 1) if takes_t else model(input_seqs)

        # Per-objective rescaling. Assumes score_models[model_idx] belongs to
        # args.objectives[model_idx].
        objective = args.objectives[model_idx]
        if objective == 'Affinity':
            candidate_scores = 10 * candidate_scores
        elif objective == 'Hemolysis':
            candidate_scores = 1 - candidate_scores

        # A model may return several scores as a tuple; flatten them in order.
        if isinstance(candidate_scores, tuple):
            scores.extend(score.item() for score in candidate_scores)
        else:
            scores.append(candidate_scores.item())

    print(f"Sample: {input_seqs[0]}")
    print("Scores: ")
    for objective, score in zip(args.objectives, scores):
        print(f"{objective}: {score:.4f}")

    # One row per sample: "<sequence>,<score 1>,<score 2>,...".
    with open(args.output_file, 'a', encoding='utf-8') as f:
        f.write(input_seqs[0] + ''.join(f",{score}" for score in scores) + '\n')