# moPPIt / moppit.py
# Last updated by: AlienChen
# Commit message: Update moppit.py
# Commit: 29de994 (verified)
import os
import warnings
import logging

# Quiet third-party output. These environment variables are set before the heavy library
# imports further down, presumably so the libraries pick them up when they are first loaded.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Suppress Python warnings. The category-less "ignore" filter already matches everything;
# the category-specific filters are kept so they remain in force if the catch-all is removed.
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
from sklearn.exceptions import InconsistentVersionWarning
warnings.filterwarnings("ignore", category=InconsistentVersionWarning)

# Only let ERROR-level records through for the root logger and the chattier libraries.
for _logger_name in (
    None,  # logging.getLogger(None) is the root logger
    "lightning",
    "pytorch_lightning",
    "transformers",
    "absl",
    "lightning.fabric.utilities.seed",
    "pytorch_lightning.utilities.seed",
):
    logging.getLogger(_logger_name).setLevel(logging.ERROR)

# Hugging Face's own verbosity switch and download/progress bars.
from transformers import logging as hf_logging
hf_logging.set_verbosity_error()
hf_logging.disable_progress_bar()
import torch
from transformers import AutoTokenizer
from pathlib import Path
import inspect
# from models.peptide_classifiers import *
from models.peptiverse_classifiers import *
from utils.parsing import parse_guidance_args
args = parse_guidance_args()

# MOO (multi-objective optimization) hyper-parameters
step_size = 1 / 100              # step size passed to solver.multi_guidance_sample
n_samples = 1                    # sequences generated per batch
vocab_size = 24                  # random init draws token ids from [4, vocab_size)
source_distribution = "uniform"  # initial-sequence scheme: "uniform" or "mask"

# Use the first GPU when one is available; fall back to CPU instead of failing on the
# first .to(device) call on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

length = args.length
target = args.target_protein

# Parsed motif specification (presumably target residue positions — confirm with parse_motifs).
# Always bound so later code never hits an unbound name; None means "no motifs given".
motifs = parse_motifs(args.motifs).to(device) if args.motifs else None

tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
target_sequence = tokenizer(target, return_tensors='pt').to(device)
# Load Models
# Sequence generator used for guided sampling in the loop below.
solver = load_solver('./ckpt/peptide/cnn_epoch200_lr0.0001_embed512_hidden256_loss3.1051.ckpt', vocab_size, device)

# Property predictors acting as guidance objectives. They are appended in this fixed order
# (Hemolysis, Non-Fouling, Solubility, Permeability, Half-Life, Affinity, motif model),
# NOT in the order of args.objectives.
score_models = []
if 'Hemolysis' in args.objectives:
    hemolysis_model = HemolysisWT()
    score_models.append(hemolysis_model)
if 'Non-Fouling' in args.objectives:
    nonfouling_model = NonfoulingWT()
    score_models.append(nonfouling_model)
if 'Solubility' in args.objectives:
    solubility_model = Solubility()
    score_models.append(solubility_model)
if 'Permeability' in args.objectives:
    permeability_model = PermeabilityWT()
    score_models.append(permeability_model)
if 'Half-Life' in args.objectives:
    halflife_model = HalfLifeWT()
    score_models.append(halflife_model)
if 'Affinity' in args.objectives:
    affinity_model = AffinityWT(target)
    score_models.append(affinity_model)

# The sampler call counts specificity as an extra objective
# (num_objectives = len(score_models) + int(args.specificity)), so the attribute must exist
# even when no motif-based objective was requested.
args.specificity = 'Specificity' in args.objectives
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
    # The motif model needs parsed motifs; fail early with a clear message rather than
    # on an unbound/None value further down.
    if not args.motifs:
        raise ValueError("The 'Motif' and 'Specificity' objectives require args.motifs to be set.")
    bindevaluator = load_bindevaluator('./classifier_ckpt/finetuned_BindEvaluator.ckpt', device)
    # penalty is switched on exactly when the Specificity objective is requested.
    motif_model = MotifModelWT(bindevaluator, target_sequence['input_ids'], motifs, tokenizer, device, penalty=args.specificity)
    score_models.append(motif_model)
# CSV header: a "Binder" column followed by one column per requested objective.
# join() writes the names verbatim instead of slicing and scrubbing the list's repr().
objective_line = "Binder," + ",".join(args.objectives) + '\n'

# Keep the results file only if it already starts with exactly this header (the sampling
# loop appends rows to it). If it is missing, empty, or has a different header, start it
# over. Only the first line is read, not the whole (growing) file.
existing_header = None
if Path(args.output_file).exists():
    with open(args.output_file, 'r') as f:
        existing_header = f.readline()
if existing_header != objective_line:
    with open(args.output_file, 'w') as f:
        f.write(objective_line)
# Objective name for each entry of score_models, in the order the models were appended
# (fixed order: Hemolysis, Non-Fouling, Solubility, Permeability, Half-Life, Affinity, then
# the motif model). Score transforms are keyed off this list rather than args.objectives[i],
# because the user may list objectives in a different order than the models were built.
model_objectives = [
    name
    for name in ('Hemolysis', 'Non-Fouling', 'Solubility', 'Permeability', 'Half-Life', 'Affinity')
    if name in args.objectives
]
if 'Motif' in args.objectives or 'Specificity' in args.objectives:
    model_objectives.append('Motif')


def _initial_tokens():
    """Return the token ids the sampler starts from for one batch.

    If ``args.starting_sequence`` is set, it is tokenized as-is. Otherwise ``n_samples``
    sequences of ``length`` tokens are drawn according to ``source_distribution`` and each
    is wrapped with token id 0 at the start and token id 2 at the end.

    Raises:
        NotImplementedError: for an unknown ``source_distribution``.
    """
    if args.starting_sequence:
        # NOTE(review): no start/end ids are added on this path; presumably the tokenizer
        # adds its own special tokens — confirm.
        return tokenizer(args.starting_sequence, return_tensors='pt')['input_ids'].to(device)

    if source_distribution == "uniform":
        # Ids 4..vocab_size-1, i.e. never one of the ids 0-3 that the sampler treats as invalid.
        body = torch.randint(low=4, high=vocab_size, size=(n_samples, length), device=device)
    elif source_distribution == "mask":
        # Every position starts as token id 3.
        body = (torch.zeros(size=(n_samples, length), device=device) + 3).long()
    else:
        raise NotImplementedError

    starts = torch.zeros((n_samples, 1), dtype=body.dtype, device=body.device)
    ends = torch.full((n_samples, 1), 2, dtype=body.dtype, device=body.device)
    return torch.cat([starts, body, ends], dim=1)


def _score_sequence(sequence):
    """Score one decoded peptide with every model in ``score_models``.

    Returns a flat list of Python numbers in score-model order; a model returning a tuple
    contributes one entry per element. Affinity scores are multiplied by 10 and hemolysis
    scores are reported as ``1 - score``.
    """
    scores = []
    for objective, model in zip(model_objectives, score_models):
        # Models whose call signature has a ``t`` parameter are called with t=1.
        sig = inspect.signature(model.forward) if hasattr(model, 'forward') else inspect.signature(model)
        if 't' in sig.parameters:
            result = model([sequence], 1)
        else:
            result = model([sequence])

        if objective == 'Affinity':
            result = 10 * result
        elif objective == 'Hemolysis':
            result = 1 - result

        if isinstance(result, tuple):
            scores.extend(score.item() for score in result)
        else:
            scores.append(result.item())
    return scores


for batch_idx in range(args.n_batches):
    x_init = _initial_tokens()

    if args.fixed_positions is not None:
        fixed_positions = parse_motifs(args.fixed_positions).tolist()
    else:
        fixed_positions = []
    # Token ids 0-3 must never be emitted by the sampler.
    invalid_tokens = torch.tensor([0, 1, 2, 3], device=device)

    x_1 = solver.multi_guidance_sample(args=args, x_init=x_init,
                                       step_size=step_size,
                                       verbose=True,
                                       time_grid=torch.tensor([0.0, 1.0-1e-3]),
                                       score_models=score_models,
                                       num_objectives=len(score_models) + int(args.specificity),
                                       weights=args.weights,
                                       tokenizer=tokenizer,
                                       fixed_positions=fixed_positions,
                                       invalid_tokens=invalid_tokens)

    # Decode only the first sample (n_samples is 1), drop the spaces between tokens, and
    # strip 5 characters from each end (presumably the '<cls>' / '<eos>' markers).
    sequence = tokenizer.batch_decode(x_1)[0].replace(' ', '')[5:-5]
    scores = _score_sequence(sequence)

    print(f"Sample: {sequence}")
    print(f"Scores: ")
    # Labels follow args.objectives while scores follow score-model order; they line up when
    # objectives are listed in model construction order. zip() avoids an IndexError when the
    # counts differ (e.g. an unrecognised objective name).
    for objective, score in zip(args.objectives, scores):
        print(f"{objective}: {score:.4f}")

    # Append one CSV row: the sequence followed by every score.
    with open(args.output_file, 'a') as f:
        f.write(sequence + "".join(f",{score}" for score in scores) + '\n')