import os
# --- Silence noisy third-party libraries ------------------------------------
# NOTE: the env vars and warning filters below are set BEFORE the libraries
# that read them are imported (transformers is imported further down).
import warnings
import logging
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"        # hide TensorFlow C++ log spam
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence HF tokenizers fork warning
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
from sklearn.exceptions import InconsistentVersionWarning
# Checkpoints pickled with an older sklearn trigger this on load.
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)
logging.getLogger().setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("transformers").setLevel(logging.ERROR)
logging.getLogger("absl").setLevel(logging.ERROR)
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()
logging.getLogger("lightning.fabric.utilities.seed").setLevel(logging.ERROR)
logging.getLogger("pytorch_lightning.utilities.seed").setLevel(logging.ERROR)
# --- Actual imports ---------------------------------------------------------
import torch
from transformers import AutoTokenizer
from pathlib import Path
import inspect
# from models.peptide_classifiers import *
from models.peptiverse_classifiers import *
from utils.parsing import parse_guidance_args
args = parse_guidance_args()

# --- MOO (multi-objective) sampling hyper-parameters ------------------------
step_size = 1 / 100            # integration step for the guided solver
n_samples = 1                  # one candidate sequence per batch
vocab_size = 24                # token ids 0-3 are specials; 4+ are residues (see sampling below)
source_distribution = "uniform"
device = 'cuda:0'              # NOTE(review): hard-coded GPU — script assumes CUDA is available
length = args.length
target = args.target_protein

# BUG FIX: `motifs` was previously bound only when --motifs was supplied, which
# caused a NameError later if Motif/Specificity objectives were requested
# without motifs. Default to None so the dependency is explicit.
motifs = parse_motifs(args.motifs).to(device) if args.motifs else None

# ESM-2 tokenizer; also used to encode the target protein for the motif model.
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt').to(device)
# --- Load the generative solver and the per-objective guidance models -------
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

# Guidance models are appended in this fixed canonical order, independent of
# the order the user listed the objectives in.
score_models = []
if 'Hemolysis' in args.objectives:
    hemolysis_model = HemolysisWT()
    score_models.append(hemolysis_model)
if 'Non-Fouling' in args.objectives:
    nonfouling_model = NonfoulingWT()
    score_models.append(nonfouling_model)
if 'Solubility' in args.objectives:
    solubility_model = Solubility()
    score_models.append(solubility_model)
if 'Permeability' in args.objectives:
    permeability_model = PermeabilityWT()
    score_models.append(permeability_model)
if 'Half-Life' in args.objectives:
    halflife_model = HalfLifeWT()
    score_models.append(halflife_model)
if 'Affinity' in args.objectives:
    affinity_model = AffinityWT(target)
    score_models.append(affinity_model)

# BUG FIX: args.specificity was previously assigned only inside the
# Motif/Specificity branch, yet it is read unconditionally later
# (num_objectives = len(score_models) + int(args.specificity)).
# Assign it for every objective combination; the value is unchanged
# whenever the branch below runs.
args.specificity = 'Specificity' in args.objectives
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
    bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
    motif_model = MotifModelWT(bindevaluator, target_sequence['input_ids'], motifs, tokenizer, device, penalty=args.specificity)
    score_models.append(motif_model)
# CSV header, e.g. "Binder,Hemolysis,Affinity\n".
# ",".join is equivalent to the old str(list)[1:-1].replace(...) dance.
objective_line = "Binder," + ",".join(args.objectives) + '\n'

# Write the header unless the file already starts with exactly this header.
# NOTE(review): a mismatched header truncates any previous results in the
# file ('w' mode) — confirm this reset-on-objective-change is intended.
write_header = True
if Path(args.output_file).exists():
    with open(args.output_file, 'r') as f:
        lines = f.readlines()
    write_header = len(lines) == 0 or lines[0] != objective_line
if write_header:
    with open(args.output_file, 'w') as f:
        f.write(objective_line)
# Loop-invariant sampling constraints, hoisted out of the batch loop.
fixed_positions = parse_motifs(args.fixed_positions).tolist() if args.fixed_positions is not None else []
invalid_tokens = torch.tensor([0, 1, 2, 3], device=device)  # cls/pad/eos/mask ids

for batch_idx in range(args.n_batches):
    # --- Build the initial sequence x_init ----------------------------------
    if args.starting_sequence:
        # Tokenizer output already carries the special boundary tokens.
        x_init = tokenizer(args.starting_sequence, return_tensors='pt')['input_ids'].to(device)
    else:
        if source_distribution == "uniform":
            # Sample only real residue ids (4..vocab_size-1); 0-3 are specials.
            x_init = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
        elif source_distribution == "mask":
            x_init = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
        else:
            raise NotImplementedError
        # Wrap the random core with start (0) / end (2) tokens.
        zeros = torch.zeros((n_samples, 1), dtype=x_init.dtype, device=x_init.device)
        twos = torch.full((n_samples, 1), 2, dtype=x_init.dtype, device=x_init.device)
        x_init = torch.cat([zeros, x_init, twos], dim=1)

    # --- Multi-objective guided sampling ------------------------------------
    # getattr is defensive: args.specificity may be unset when neither the
    # Motif nor Specificity objective was requested.
    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0 - 1e-3]),
                                       score_models=score_models,
                                       num_objectives=len(score_models) + int(getattr(args, 'specificity', False)),
                                       weights=args.weights,
                                       tokenizer=tokenizer,
                                       fixed_positions=fixed_positions,
                                       invalid_tokens=invalid_tokens)

    # Decode: drop spaces, then strip the 5-char '<cls>' / '<eos>' markers.
    input_seqs = [tokenizer.batch_decode(x_1)[0].replace(' ', '')[5:-5]]

    # --- Score the candidate with every guidance model ----------------------
    # NOTE(review): this assumes score_models[j] lines up with
    # args.objectives[j], which only holds if the user listed objectives in
    # the same canonical order they are constructed — confirm upstream.
    scores = []
    for obj_idx, model in enumerate(score_models):  # renamed: no longer shadows the batch index
        sig = inspect.signature(model.forward) if hasattr(model, 'forward') else inspect.signature(model)
        # Time-conditioned models take an extra t argument; score at t=1.
        if 't' in sig.parameters:
            candidate_scores = model(input_seqs, 1)
        else:
            candidate_scores = model(input_seqs)
        if args.objectives[obj_idx] == 'Affinity':
            candidate_scores = 10 * candidate_scores
        elif args.objectives[obj_idx] == 'Hemolysis':
            candidate_scores = 1 - candidate_scores  # invert so higher = safer
        # (dead `else: candidate_scores = candidate_scores` branch removed)
        if isinstance(candidate_scores, tuple):
            # A model may return multiple scores (e.g. motif + specificity).
            scores.extend(score.item() for score in candidate_scores)
        else:
            scores.append(candidate_scores.item())

    print(f"Sample: {input_seqs[0]}")
    print("Scores: ")
    for obj_idx, objective in enumerate(args.objectives):
        print(f"{objective}: {scores[obj_idx]:.4f}")

    # Append one CSV row: sequence followed by its scores.
    with open(args.output_file, 'a') as f:
        f.write(input_seqs[0])
        for score in scores:
            f.write(f",{score}")
        f.write('\n')