import argparse
import csv
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer

from peptide_classifiers import (
    AffinityModel,
    HalfLifeModel,
    HemolysisModel,
    NonfoulingModel,
    SolubilityModel,
    load_affinity_predictor,
)

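# FiLM-style adaLN modulation: `shift` and `scale` are per-sample (batch, dim)
# vectors broadcast across the sequence dimension.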
def modulate(x, shift, scale):
    return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)


class TimestepEmbedder(nn.Module):
    """Embeds the scalar diffusion time t into a model_dim-sized vector with a small MLP."""

    def __init__(self, hidden_size):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(1, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )

    def forward(self, t):
        # t: (batch,) -> (batch, 1) so the first Linear receives a single scalar feature.
        return self.mlp(t.unsqueeze(-1))

class DiTBlock(nn.Module):
    """Transformer block with adaptive LayerNorm (adaLN) conditioning, in the style of DiT."""

    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True)
        )

    def forward(self, x, c):
        # The conditioning vector c yields shift/scale/gate for both the attention and MLP branches.
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
        x_norm1 = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_output, _ = self.attn(x_norm1, x_norm1, x_norm1)
        x = x + gate_msa.unsqueeze(1) * attn_output
        x_norm2 = modulate(self.norm2(x), shift_mlp, scale_mlp)
        mlp_output = self.mlp(x_norm2)
        x = x + gate_mlp.unsqueeze(1) * mlp_output
        return x

class MDLM(nn.Module):
    """Masked diffusion language model: token + learned positional embeddings feeding DiT blocks."""

    def __init__(self, vocab_size, seq_len, model_dim, n_heads, n_layers):
        super().__init__()
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.model_dim = model_dim
        self.mask_token_id = vocab_size  # the extra embedding row is reserved for the mask token
        self.token_embedder = nn.Embedding(vocab_size + 1, model_dim)
        self.pos_embedder = nn.Parameter(torch.randn(1, seq_len, model_dim))
        self.time_embedder = TimestepEmbedder(model_dim)
        self.transformer_blocks = nn.ModuleList([DiTBlock(model_dim, n_heads) for _ in range(n_layers)])
        self.final_norm = nn.LayerNorm(model_dim)
        self.lm_head = nn.Linear(model_dim, vocab_size)

    def forward(self, x, t):
        seq_len = x.shape[1]
        x_embed = self.token_embedder(x) + self.pos_embedder[:, :seq_len, :]
        t_embed = self.time_embedder(t)
        for block in self.transformer_blocks:
            x_embed = block(x_embed, t_embed)
        x_embed = self.final_norm(x_embed)
        logits = self.lm_head(x_embed)
        return logits

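# MOGGenerator performs single-site guided resampling: at each step it proposes
# replacements for one randomly chosen position from the denoiser's top-p
# candidates, reweights them by a Barker-balanced guidance ratio over the
# objective scores, and greedily accepts moves that improve the weighted sum.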
class MOGGenerator:
    """Single-site multi-objective guidance (MOG) sampler over a trained MDLM denoiser."""

    def __init__(self, model, device, objectives, args):
        self.model = model
        self.device = device
        self.objectives = objectives
        self.args = args
        self.num_objectives = len(objectives)

    def _get_scores(self, x_batch):
        """Returns the objective scores for a batch, stacked to shape (num_objectives, batch)."""
        scores = []
        for obj_func in self.objectives:
            scores.append(obj_func(x_batch.to(self.device)))
        return torch.stack(scores, dim=0)

    def _barker_g(self, u):
        """Barker balancing function g(u) = u / (1 + u)."""
        return u / (1 + u)
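    # g(u) = u / (1 + u) satisfies the local-balance condition g(u) = u * g(1/u),
    # so weighting the denoiser's proposal by g(w_proposal / w_current) biases
    # moves toward higher guidance weight while remaining a valid proposal kernel.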

    def generate(self):
        """Main generation loop: iteratively resamples one position at a time."""
        # Initialize with random amino-acid tokens (ids >= 4) framed by <cls> and <eos>.
        shape = (self.args.num_samples, self.args.gen_len + 2)
        x = torch.randint(4, self.model.vocab_size, shape, dtype=torch.long, device=self.device)
        x[:, 0] = 0   # <cls>
        x[:, -1] = 2  # <eos>

        if self.args.weights is None:
            weights = torch.full((self.num_objectives,), 1 / self.num_objectives, device=self.device).view(-1, 1)
        else:
            weights = torch.tensor(self.args.weights, device=self.device).view(-1, 1)
            if len(weights) != self.num_objectives:
                raise ValueError("Number of weights must match number of objectives.")
        print(f"Weights: {weights}")

        if self.args.min_threshold is not None:
            min_threshold = torch.tensor(self.args.min_threshold, device=self.device)
        else:
            min_threshold = None  # reserved for thresholded acceptance; unused by the current update rule

        total_optimization_steps = self.args.optimization_steps * self.args.gen_len

        with torch.no_grad():
            for t in tqdm(range(total_optimization_steps), desc="MOG Generation"):
                # Linearly anneal the guidance strength from eta_min to eta_max.
                eta_t = self.args.eta_min + (self.args.eta_max - self.args.eta_min) * (
                    t / max(total_optimization_steps - 1, 1)
                )

                # Pick one interior position to mutate (positions 0 and -1 hold <cls>/<eos>).
                mut_idx = random.randint(1, self.args.gen_len)

                generation_step = t % self.args.optimization_steps
                time_t = torch.full(
                    (self.args.num_samples,),
                    generation_step / self.args.optimization_steps,
                    device=self.device,
                )

                logits = self.model(x, time_t)
                probs = F.softmax(logits, dim=-1)
                pos_probs = probs[:, mut_idx, :]
                # Zero out each sample's own current token so the proposal must move
                # (per-row indexing; a plain slice would zero columns across the whole batch).
                pos_probs[torch.arange(self.args.num_samples, device=self.device), x[:, mut_idx]] = 0

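                # Nucleus (top-p) pruning: keep the smallest set of tokens whose
                # cumulative probability exceeds top_p; everything else is masked out.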
                sorted_probs, sorted_indices = torch.sort(pos_probs, descending=True)
                cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
                remove_mask = cumulative_probs > self.args.top_p
                remove_mask[..., 1:] = remove_mask[..., :-1].clone()
                remove_mask[..., 0] = 0

                candidate_tokens_list = []
                for i in range(self.args.num_samples):
                    sample_mask = remove_mask[i]
                    candidates = sorted_indices[i, ~sample_mask]
                    candidate_tokens_list.append(candidates)

                current_scores = self._get_scores(x)
                w_current = torch.exp(eta_t * torch.min(weights * current_scores, dim=0).values)
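                # Guidance weight via weighted-min (Tchebycheff-style) scalarization:
                #     w(x) = exp(eta_t * min_k omega_k * s_k(x))
                # so a candidate is only as good as its worst weighted objective.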

                final_proposal_tokens = []
                for i in range(self.args.num_samples):
                    # Drop special tokens (<cls>=0, <pad>=1, <eos>=2, <unk>=3) from the candidate set.
                    candidates = candidate_tokens_list[i]
                    candidates = candidates[candidates > 3]
                    num_candidates = len(candidates)
                    if num_candidates == 0:
                        final_proposal_tokens.append(x[i, mut_idx])
                        continue

                    # Score every candidate mutation of position mut_idx in one batch.
                    x_prop_batch = x[i].repeat(num_candidates, 1)
                    x_prop_batch[:, mut_idx] = candidates

                    proposal_scores = self._get_scores(x_prop_batch)
                    proposal_s_omega = torch.min(weights * proposal_scores, dim=0).values
                    w_proposal = torch.exp(eta_t * proposal_s_omega)

                    redi_probs = pos_probs[i, candidates]

                    # Locally balanced proposal: denoiser probability reweighted by the Barker ratio.
                    tilde_q = redi_probs * self._barker_g(w_proposal / w_current[i])
                    final_probs = tilde_q / (torch.sum(tilde_q) + 1e-9)

                    index = torch.multinomial(final_probs, 1).item()
                    # Greedy acceptance: keep the mutation only if the weighted score sum does not decrease.
                    if torch.sum(weights.squeeze(1) * proposal_scores[:, index]) >= torch.sum(
                        weights.squeeze(1) * current_scores[:, i]
                    ):
                        final_token = candidates[index]
                        print(f"Previous Weighted Sum: {torch.sum(weights.squeeze(1) * current_scores[:, i])}")
                        print(f"Previous Scores: {current_scores[:, i]}")
                        print(f"New Weighted Sum: {torch.sum(weights.squeeze(1) * proposal_scores[:, index])}")
                        print(f"New Scores: {proposal_scores[:, index]}")
                    else:
                        final_token = x[i, mut_idx]

                    final_proposal_tokens.append(final_token)

                x[torch.arange(self.args.num_samples, device=self.device), mut_idx] = torch.stack(final_proposal_tokens)

        return x


def main(args):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    target = args.target
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
    target_sequence = tokenizer(target, return_tensors='pt')['input_ids'].to(device)

    affinity_predictor = load_affinity_predictor('/scratch/pranamlab/tong/ReDi_discrete/peptides/classifier_ckpt/binding_affinity_unpooled.pt', device)
    affinity_model = AffinityModel(affinity_predictor, target_sequence)
    hemolysis_model = HemolysisModel(device=device)
    nonfouling_model = NonfoulingModel(device=device)
    solubility_model = SolubilityModel(device=device)
    halflife_model = HalfLifeModel(device=device)

    print(f"Loading checkpoint from {args.checkpoint}...")
    try:
        checkpoint = torch.load(args.checkpoint, map_location=device, weights_only=False)
        model_args = checkpoint['args']
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
        return

    print("Initializing model...")
    model = MDLM(
        vocab_size=model_args.vocab_size,
        seq_len=model_args.seq_len,
        model_dim=model_args.model_dim,
        n_heads=model_args.n_heads,
        n_layers=model_args.n_layers,
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    print("Model loaded successfully.")

    OBJECTIVE_FUNCTIONS = [hemolysis_model, nonfouling_model, solubility_model, halflife_model, affinity_model]

    mog_generator = MOGGenerator(model, device, OBJECTIVE_FUNCTIONS, args)

    for _ in range(args.num_batches):
        generated_tokens = mog_generator.generate()
        final_scores = mog_generator._get_scores(generated_tokens).detach().cpu().numpy()

        with open(args.output_file, 'a', newline='') as f:
            writer = csv.writer(f)
            for i in range(args.num_samples):
                sample_tokens = generated_tokens[i]
                # Decode, drop spaces, then strip the leading "<cls>" and trailing "<eos>" (5 chars each).
                sequence_str = tokenizer.decode(sample_tokens.tolist(), skip_special_tokens=False).replace(" ", "")[5:-5]
                scores = final_scores[:, i]
                writer.writerow([sequence_str] + scores.tolist())
                print([sequence_str] + scores.tolist())

    print("Generation complete.")


| if __name__ == "__main__": |
| parser = argparse.ArgumentParser(description="Multi-Objective Generation with LBP-MOG-ReDi (Single Mutation).") |
|
|
| parser.add_argument("--checkpoint", type=str, required=True, help="Path to the trained ReDi model checkpoint.") |
| parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to generate.") |
| parser.add_argument("--num_batches", type=int, default=10, help="Number of samples to generate.") |
| parser.add_argument("--output_file", type=str, default="./mog_peptides.txt", help="File to save the generated sequences.") |
| parser.add_argument("--gen_len", type=int, default=50, help="Length of the sequences to generate.") |
| parser.add_argument("--optimization_steps", type=int, default=16, help="Number of passes over the sequence.") |
| parser.add_argument("--weights", type=float, nargs='+', required=False, help="Weights for the objectives (e.g., 0.5 0.5).") |
| parser.add_argument("--min_threshold", type=float, nargs='+', required=False, help="minimum threshold for the objectives (e.g., 0.2 0.2).") |
| parser.add_argument("--eta_min", type=float, default=1.0, help="Minimum guidance strength for annealing.") |
| parser.add_argument("--eta_max", type=float, default=20.0, help="Maximum guidance strength for annealing.") |
| parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for pruning candidate tokens.") |
|
|
| parser.add_argument("--target", type=str, required=True) |
| args = parser.parse_args() |
| main(args) |
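# Example invocation (script name, checkpoint path, and target sequence are illustrative):
#   python mog_generate.py --checkpoint ./checkpoints/redi.pt --target "MKTAYIAKQR" \
#       --num_samples 10 --num_batches 2 --weights 0.2 0.2 0.2 0.2 0.2 \
#       --output_file ./mog_peptides.csv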