Shrey Goel commited on
Commit ·
94c2704
0
Parent(s):
initial commit
Browse files- README.md +26 -0
- __init__.py +0 -0
- configs/wt_pep.yaml +84 -0
- setup.py +10 -0
- src/__init__.py +0 -0
- src/madsbm/__init__.py +0 -0
- src/madsbm/wt_peptide/__init__.py +0 -0
- src/madsbm/wt_peptide/control_field.py +199 -0
- src/madsbm/wt_peptide/dataloader.py +93 -0
- src/madsbm/wt_peptide/main.py +96 -0
- src/madsbm/wt_peptide/sbm_module.py +320 -0
- src/sampling/diffusion_sampler.py +106 -0
- src/sampling/guided_sample.py +121 -0
- src/sampling/madsbm_sampler.py +198 -0
- src/sampling/path_tracer.py +46 -0
- src/utils/__init__.py +0 -0
- src/utils/eval_utils.py +27 -0
- src/utils/fbd_score_model.py +104 -0
- src/utils/generate_utils.py +118 -0
- src/utils/model_utils.py +42 -0
- src/utils/time_utils.py +41 -0
README.md
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Token-Level Guided Discrete Diffusion for Membrane Protein Design
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
arXiv preprint: ...
|
| 7 |
+
|
| 8 |
+
Reparameterized diffusion models (RDMs) have recently matched autoregressive methods in protein generation, motivating their use for challenging tasks such as designing membrane proteins, which possess interleaved soluble and transmembrane (TM) regions.
|
| 9 |
+
|
| 10 |
+
We introduce ***Membrane Diffusion Language Model (MemDLM)***, a fine-tuned RDM-based protein language model that enables controllable membrane protein sequence design. MemDLM-generated sequences recapitulate the TM residue density and structural features of natural membrane proteins, achieving comparable biological plausibility and outperforming state-of-the-art diffusion baselines in motif scaffolding tasks by producing:
|
| 11 |
+
|
| 12 |
+
- Lower perplexity
|
| 13 |
+
- Higher BLOSUM-62 scores
|
| 14 |
+
- Improved pLDDT confidence
|
| 15 |
+
|
| 16 |
+
To enhance controllability, we develop ***Per-Token Guidance (PET)***, a novel classifier-guided sampling strategy that selectively solubilizes residues while preserving conserved TM domains. This yields sequences with reduced TM density but intact functional cores.
|
| 17 |
+
|
| 18 |
+
Importantly, MemDLM designs validated in TOXCAT β-lactamase growth assays demonstrate successful TM insertion, distinguishing high-quality generated sequences from poor ones.
|
| 19 |
+
|
| 20 |
+
Together, our framework establishes the first experimentally validated diffusion-based model for rational membrane protein generation, integrating *de novo* design, motif scaffolding, and targeted property optimization.
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## **Repository Authors**
|
| 25 |
+
- <u>[Shrey Goel](https://shreygoel09.github.io/)</u> – undergraduate student at Duke University
|
| 26 |
+
- <u>[Pranam Chatterjee](mailto:pranam@seas.upenn.edu)</u> – Assistant Professor at University of Pennsylvania
|
__init__.py
ADDED
|
File without changes
|
configs/wt_pep.yaml
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
seed: 42
|
| 3 |
+
base_dir: /scratch/pranamlab/sgoel/MadSBM
|
| 4 |
+
|
| 5 |
+
training:
|
| 6 |
+
mode: test # train / test / resume_from_ckpt
|
| 7 |
+
n_unfrozen: 3
|
| 8 |
+
n_epochs: 50
|
| 9 |
+
log_every_n_steps: 50
|
| 10 |
+
num_sanity_val_steps: 2
|
| 11 |
+
val_check_interval:
|
| 12 |
+
enable_progress_bar: true
|
| 13 |
+
grad_clip_val: 10.0
|
| 14 |
+
accumulate_grad_batches: 16 # to workaround dynamic batching
|
| 15 |
+
devices: 1 # number of GPUs
|
| 16 |
+
|
| 17 |
+
model:
|
| 18 |
+
ablate: false
|
| 19 |
+
evoflow_model: fredzzp/EvoFlow-150M-afdbseq
|
| 20 |
+
esm_model: facebook/esm2_t33_650M_UR50D
|
| 21 |
+
n_layers: 2 #8
|
| 22 |
+
n_heads: 16 #8
|
| 23 |
+
hidden_dim: 1280
|
| 24 |
+
attn_drop: 0.0
|
| 25 |
+
resid_drop: 0.0
|
| 26 |
+
mlp_ratio: 4.0
|
| 27 |
+
beta1: 1e-6
|
| 28 |
+
beta2: 1e-6
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
time_embed:
|
| 32 |
+
time_dim: 512
|
| 33 |
+
fourier_dim: 64
|
| 34 |
+
fourier_scale: 30.0
|
| 35 |
+
time_schedule: uniform # linear / exponential / uniform
|
| 36 |
+
anneal_frac: 0.75
|
| 37 |
+
min_time: 1e-6
|
| 38 |
+
n_timesteps: 500
|
| 39 |
+
|
| 40 |
+
data:
|
| 41 |
+
batch_size: 1
|
| 42 |
+
#max_seq_len: 500
|
| 43 |
+
train: /scratch/pranamlab/tong/data/peptide/tokenized_peptide_batched/train
|
| 44 |
+
test: /scratch/pranamlab/tong/data/peptide/tokenized_peptide_batched/test
|
| 45 |
+
val: /scratch/pranamlab/tong/data/peptide/tokenized_peptide_batched/val
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
optim:
|
| 49 |
+
type: adamw
|
| 50 |
+
scheduler: cosine
|
| 51 |
+
lr: 1e-4
|
| 52 |
+
lr_end: 1e-5
|
| 53 |
+
warmup_init_lr: 1e-6
|
| 54 |
+
warmup_epochs: 2
|
| 55 |
+
weight_decay: 0.01
|
| 56 |
+
beta1: 0.9
|
| 57 |
+
beta2: 0.999
|
| 58 |
+
power: 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
wandb:
|
| 62 |
+
project: MadSBM_PEPTIDE
|
| 63 |
+
group: programmablebio
|
| 64 |
+
name: peptide_og-madsbm_esm_no-gclip_lr=1e-4_n-layers=2_n-heads=16_trainable-lm-head_logits-sum-SM_gate-esm
|
| 65 |
+
#name: peptide_og-madsbm_esm_no-gclip_lr=1e-4_n-layers=2_n-heads=16_trainable-lm-head_logits-sum-SM_ABLATE-gate-esm
|
| 66 |
+
id: ${.name}_${seed}
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
checkpointing:
|
| 70 |
+
save_every_n_epochs: 1
|
| 71 |
+
save_dir: ${base_dir}/checkpoints/wt_pep/${wandb.name}
|
| 72 |
+
resume_ckpt_path: ${checkpointing.save_dir}/last.ckpt
|
| 73 |
+
best_ckpt_path: ${checkpointing.save_dir}/best-model_epoch=41_step=106890.ckpt
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
sampling:
|
| 77 |
+
model_type: madsbm # madsbm / diffusion / dfm
|
| 78 |
+
n_steps: 32
|
| 79 |
+
top_p: 0.9
|
| 80 |
+
rate_scale: 0.01
|
| 81 |
+
jump_scale: 0.05
|
| 82 |
+
tau: 0.5
|
| 83 |
+
M: 16
|
| 84 |
+
beta: 2.0
|
setup.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name='MadSBM',
|
| 5 |
+
version='1.0',
|
| 6 |
+
packages=find_packages(),
|
| 7 |
+
install_requires=[],
|
| 8 |
+
author='Shrey Goel',
|
| 9 |
+
author_email='shrey.goel@duke.edu'
|
| 10 |
+
)
|
src/__init__.py
ADDED
|
File without changes
|
src/madsbm/__init__.py
ADDED
|
File without changes
|
src/madsbm/wt_peptide/__init__.py
ADDED
|
File without changes
|
src/madsbm/wt_peptide/control_field.py
ADDED
|
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer
|
| 7 |
+
|
| 8 |
+
from src.utils.time_utils import TimeEmbedding
|
| 9 |
+
from src.utils.model_utils import _print
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
# -------------------------
|
| 13 |
+
# DiT building blocks
|
| 14 |
+
# -------------------------
|
| 15 |
+
|
| 16 |
+
class MLP(nn.Module):
|
| 17 |
+
def __init__(self, dim, mlp_ratio, dropout):
|
| 18 |
+
super().__init__()
|
| 19 |
+
hidden_dim = int(dim * mlp_ratio)
|
| 20 |
+
self.fc1 = nn.Linear(dim, hidden_dim)
|
| 21 |
+
self.act = nn.GELU()
|
| 22 |
+
self.fc2 = nn.Linear(hidden_dim, dim)
|
| 23 |
+
self.dropout = nn.Dropout(dropout)
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
x = self.fc1(x)
|
| 27 |
+
x = self.act(x)
|
| 28 |
+
x = self.dropout(x)
|
| 29 |
+
x = self.fc2(x)
|
| 30 |
+
x = self.dropout(x)
|
| 31 |
+
return x
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class DiTBlock1D(nn.Module):
|
| 35 |
+
def __init__(self, cfg):
|
| 36 |
+
super().__init__()
|
| 37 |
+
self.cfg = cfg
|
| 38 |
+
self.hidden_dim = cfg.model.hidden_dim
|
| 39 |
+
self.time_dim = cfg.time_embed.time_dim
|
| 40 |
+
|
| 41 |
+
self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
|
| 42 |
+
self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
|
| 43 |
+
|
| 44 |
+
# time-conditioned scale & shift for both norms
|
| 45 |
+
self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim) # scale1, shift1
|
| 46 |
+
self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim) # scale2, shift2
|
| 47 |
+
|
| 48 |
+
self.attn = nn.MultiheadAttention(
|
| 49 |
+
embed_dim=self.hidden_dim,
|
| 50 |
+
num_heads=cfg.model.n_heads,
|
| 51 |
+
dropout=cfg.model.attn_drop,
|
| 52 |
+
batch_first=True
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
self.mlp = MLP(
|
| 56 |
+
self.hidden_dim,
|
| 57 |
+
mlp_ratio=cfg.model.mlp_ratio,
|
| 58 |
+
dropout=cfg.model.resid_drop
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def forward(self, x, t_emb, key_padding_mask=None):
|
| 62 |
+
# ----- Self-attention branch -----
|
| 63 |
+
# Adaptive LayerNorm (AdaLN) + FiLM from time embedding
|
| 64 |
+
scale1, shift1 = self.time_proj1(t_emb).chunk(2, dim=-1) # [B, D] and [B, D]
|
| 65 |
+
h = self.norm1(x)
|
| 66 |
+
h = h * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1) # [B, L, D]
|
| 67 |
+
|
| 68 |
+
attn_out, _ = self.attn(
|
| 69 |
+
h,
|
| 70 |
+
h,
|
| 71 |
+
h,
|
| 72 |
+
key_padding_mask=key_padding_mask, # True for pads
|
| 73 |
+
need_weights=False,
|
| 74 |
+
)
|
| 75 |
+
x = x + attn_out
|
| 76 |
+
|
| 77 |
+
# ----- MLP branch -----
|
| 78 |
+
scale2, shift2 = self.time_proj2(t_emb).chunk(2, dim=-1)
|
| 79 |
+
h2 = self.norm2(x)
|
| 80 |
+
h2 = h2 * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)
|
| 81 |
+
|
| 82 |
+
mlp_out = self.mlp(h2)
|
| 83 |
+
x = x + mlp_out
|
| 84 |
+
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class PeptideControlField(nn.Module):
|
| 89 |
+
def __init__(self, cfg):
|
| 90 |
+
super().__init__()
|
| 91 |
+
self.cfg = cfg
|
| 92 |
+
|
| 93 |
+
pth = cfg.model.esm_model
|
| 94 |
+
self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
|
| 95 |
+
self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)
|
| 96 |
+
|
| 97 |
+
# Freeze params
|
| 98 |
+
self.embed_model.eval()
|
| 99 |
+
for param in self.embed_model.parameters():
|
| 100 |
+
param.requires_grad = False
|
| 101 |
+
|
| 102 |
+
# # Unfreeze QKV in last few encoder layers
|
| 103 |
+
# encoder_layers = self.embed_model.esm.encoder.layer
|
| 104 |
+
# for layer in encoder_layers[-cfg.training.n_unfrozen:]:
|
| 105 |
+
# for param in layer.parameters():
|
| 106 |
+
# param.requires_grad = True
|
| 107 |
+
|
| 108 |
+
self.time_embed = TimeEmbedding(
|
| 109 |
+
hidden_dim=cfg.time_embed.time_dim,
|
| 110 |
+
fourier_dim=cfg.time_embed.fourier_dim,
|
| 111 |
+
scale=cfg.time_embed.fourier_scale
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
self.blocks = nn.ModuleList([
|
| 115 |
+
DiTBlock1D(self.cfg)
|
| 116 |
+
for _ in range(cfg.model.n_layers)
|
| 117 |
+
])
|
| 118 |
+
|
| 119 |
+
self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)
|
| 120 |
+
|
| 121 |
+
# self.output_proj = self.embed_model.lm_head
|
| 122 |
+
# for param in self.output_proj.parameters():
|
| 123 |
+
# param.requires_grad = False
|
| 124 |
+
|
| 125 |
+
self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
|
| 126 |
+
nn.init.zeros_(self.output_proj.weight)
|
| 127 |
+
nn.init.zeros_(self.output_proj.bias)
|
| 128 |
+
|
| 129 |
+
def forward(self, t, xt, attention_mask):
|
| 130 |
+
with torch.no_grad():
|
| 131 |
+
outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
|
| 132 |
+
|
| 133 |
+
gate = (1.0 - t).view(-1, 1, 1)
|
| 134 |
+
u_base = gate * outs.logits
|
| 135 |
+
|
| 136 |
+
h = outs.hidden_states[-1]
|
| 137 |
+
t_emb = self.time_embed(t) # [B, time_dim]
|
| 138 |
+
|
| 139 |
+
# Transformer head (key_padding_mask=True for pads)
|
| 140 |
+
key_padding_mask = (attention_mask == 0) # (B, L) bool
|
| 141 |
+
for dit_block in self.blocks:
|
| 142 |
+
h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
|
| 143 |
+
|
| 144 |
+
# Final norm + projection to vocab logits
|
| 145 |
+
h = self.final_norm(h) # [B, L, hidden_dim]
|
| 146 |
+
logits = self.output_proj(h) # [B, L, V]
|
| 147 |
+
|
| 148 |
+
return {
|
| 149 |
+
"esm": u_base,
|
| 150 |
+
"dit": logits,
|
| 151 |
+
"madsbm": u_base + logits
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
# def forward(self, t, xt, attention_mask):
|
| 158 |
+
# outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
|
| 159 |
+
# h = outs.hidden_states[-1]
|
| 160 |
+
# t_emb = self.time_embed(t) # [B, time_dim]
|
| 161 |
+
|
| 162 |
+
# # Transformer head (key_padding_mask=True for pads)
|
| 163 |
+
# key_padding_mask = (attention_mask == 0) # (B, L) bool
|
| 164 |
+
# for dit_block in self.blocks:
|
| 165 |
+
# h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
|
| 166 |
+
|
| 167 |
+
# # Final norm + projection to vocab logits
|
| 168 |
+
# h = self.final_norm(h) # [B, L, hidden_dim]
|
| 169 |
+
# logits = self.output_proj(h) # [B, L, V]
|
| 170 |
+
# return logits
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
# def forward(self, xt, attention_mask, t):
|
| 174 |
+
# with torch.no_grad():
|
| 175 |
+
# base_out = self.embed_model(
|
| 176 |
+
# input_ids=xt,
|
| 177 |
+
# attention_mask=attention_mask,
|
| 178 |
+
# output_hidden_states=True
|
| 179 |
+
# )
|
| 180 |
+
|
| 181 |
+
# logits_base = base_out.logits
|
| 182 |
+
# h_base = base_out.hidden_states[-1]
|
| 183 |
+
|
| 184 |
+
# norm = self.token_norm_sqrd.view(1,1,-1) # 1, 1, V
|
| 185 |
+
|
| 186 |
+
# log_R0 = (self.beta1 * logits_base) - (self.beta2 * norm)
|
| 187 |
+
|
| 188 |
+
# t_emb = self.time_embed(t) # [B, time_dim]
|
| 189 |
+
# key_padding_mask = (attention_mask == 0) # (B, L) bool
|
| 190 |
+
|
| 191 |
+
# h_ctrl = h_base
|
| 192 |
+
# for dit_block in self.blocks:
|
| 193 |
+
# h_ctrl = dit_block(h_ctrl, t_emb, key_padding_mask=key_padding_mask)
|
| 194 |
+
|
| 195 |
+
# h_ctrl = self.final_norm(h_ctrl)
|
| 196 |
+
# u_theta = self.output_proj(h_ctrl)
|
| 197 |
+
# tot_logits = log_R0 + u_theta
|
| 198 |
+
|
| 199 |
+
# return tot_logits, u_theta
|
src/madsbm/wt_peptide/dataloader.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import lightning.pytorch as pl
|
| 4 |
+
|
| 5 |
+
from omegaconf import OmegaConf
|
| 6 |
+
from datasets import load_from_disk
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
from torch.nn.utils.rnn import pad_sequence
|
| 9 |
+
from functools import partial
|
| 10 |
+
from src.utils.model_utils import _print
|
| 11 |
+
|
| 12 |
+
config = OmegaConf.load('/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml')
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# class DNADataset(Dataset):
|
| 16 |
+
# def __init__(self, config, data_path):
|
| 17 |
+
# self.config = config
|
| 18 |
+
# self.data = pd.read_csv(data_path)
|
| 19 |
+
# self.custom_tokenizer = CustomDNATokenizer(config.model.dna_model_path)
|
| 20 |
+
|
| 21 |
+
# def __len__(self):
|
| 22 |
+
# return len(self.data)
|
| 23 |
+
|
| 24 |
+
# def __getitem__(self, idx):
|
| 25 |
+
# sequence = self.data.iloc[idx]["Sequence"]
|
| 26 |
+
# seq = sequence.upper()
|
| 27 |
+
|
| 28 |
+
# tokenized = self.custom_tokenizer(seq, max_length=self.config.data.max_seq_len)
|
| 29 |
+
|
| 30 |
+
# return {
|
| 31 |
+
# "input_ids": tokenized["input_ids"].squeeze(0),
|
| 32 |
+
# "attention_mask": tokenized["attention_mask"].squeeze(0)
|
| 33 |
+
# }
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def collate_fn(batch, pad_id=None):
|
| 38 |
+
input_ids = torch.tensor(batch[0]['input_ids'])
|
| 39 |
+
attention_mask = torch.tensor(batch[0]['attention_mask'])
|
| 40 |
+
return {
|
| 41 |
+
'input_ids': input_ids,
|
| 42 |
+
'attention_mask': attention_mask
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class PeptideDataModule(pl.LightningDataModule):
|
| 47 |
+
def __init__(self, config, train_dataset, val_dataset, test_dataset, tokenizer, collate_fn=collate_fn):
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.train_dataset = train_dataset
|
| 50 |
+
self.val_dataset = val_dataset
|
| 51 |
+
self.test_dataset = test_dataset
|
| 52 |
+
self.tokenizer = tokenizer
|
| 53 |
+
self.collate_fn = collate_fn
|
| 54 |
+
self.batch_size = config.data.batch_size
|
| 55 |
+
assert self.batch_size == 1, f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'
|
| 56 |
+
|
| 57 |
+
def train_dataloader(self):
|
| 58 |
+
return DataLoader(self.train_dataset,
|
| 59 |
+
batch_size=self.batch_size,
|
| 60 |
+
collate_fn=partial(self.collate_fn),
|
| 61 |
+
num_workers=8,
|
| 62 |
+
shuffle=False,
|
| 63 |
+
pin_memory=True)
|
| 64 |
+
|
| 65 |
+
def val_dataloader(self):
|
| 66 |
+
return DataLoader(self.val_dataset,
|
| 67 |
+
batch_size=self.batch_size,
|
| 68 |
+
collate_fn=partial(self.collate_fn),
|
| 69 |
+
num_workers=8,
|
| 70 |
+
shuffle=False,
|
| 71 |
+
pin_memory=True)
|
| 72 |
+
|
| 73 |
+
def test_dataloader(self):
|
| 74 |
+
return DataLoader(self.test_dataset,
|
| 75 |
+
batch_size=self.batch_size,
|
| 76 |
+
collate_fn=partial(self.collate_fn),
|
| 77 |
+
num_workers=8,
|
| 78 |
+
shuffle=False,
|
| 79 |
+
pin_memory=True)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def get_datasets(config):
|
| 83 |
+
"""Helper method to grab datasets to quickly init data module in main.py"""
|
| 84 |
+
train_dataset = load_from_disk(config.data.train)
|
| 85 |
+
test_dataset = load_from_disk(config.data.test)
|
| 86 |
+
val_dataset = load_from_disk(config.data.val)
|
| 87 |
+
|
| 88 |
+
return {
|
| 89 |
+
"train": train_dataset,
|
| 90 |
+
"val": val_dataset,
|
| 91 |
+
"test": test_dataset
|
| 92 |
+
}
|
| 93 |
+
|
src/madsbm/wt_peptide/main.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/python3
|
| 2 |
+
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import wandb
|
| 7 |
+
import lightning.pytorch as pl
|
| 8 |
+
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
from lightning.pytorch.strategies import DDPStrategy
|
| 11 |
+
from lightning.pytorch.loggers import WandbLogger
|
| 12 |
+
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from src.madsbm.wt_peptide.sbm_module import MadSBM
|
| 16 |
+
from src.madsbm.wt_peptide.dataloader import PeptideDataModule, get_datasets
|
| 17 |
+
|
| 18 |
+
wandb.login(key='2b76a2fa2c1cdfddc5f443602c17b011fefb0a8f')
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Load yaml config
|
| 22 |
+
config = OmegaConf.load("/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml")
|
| 23 |
+
|
| 24 |
+
# Initialize WandB for logging
|
| 25 |
+
wandb.init(project=config.wandb.project, name=config.wandb.name)
|
| 26 |
+
wandb_logger = WandbLogger(**config.wandb)
|
| 27 |
+
|
| 28 |
+
# PL checkpoints
|
| 29 |
+
lr_monitor = LearningRateMonitor(logging_interval="step")
|
| 30 |
+
|
| 31 |
+
every_epoch_cb = ModelCheckpoint(
|
| 32 |
+
dirpath=config.checkpointing.save_dir,
|
| 33 |
+
filename="{epoch:02d}_{step}",
|
| 34 |
+
save_top_k=-1,
|
| 35 |
+
every_n_epochs=1,
|
| 36 |
+
save_on_train_epoch_end=True,
|
| 37 |
+
verbose=True,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
best_ckpt_cb = ModelCheckpoint(
|
| 41 |
+
monitor="val/loss",
|
| 42 |
+
dirpath=config.checkpointing.save_dir,
|
| 43 |
+
filename="best-model_{epoch:02d}_{step}",
|
| 44 |
+
save_top_k=1,
|
| 45 |
+
mode="min",
|
| 46 |
+
verbose=True,
|
| 47 |
+
save_last=False,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# PL trainer
|
| 51 |
+
trainer = pl.Trainer(
|
| 52 |
+
#max_steps=None, # Ensure training is based on epochs so we can compare with MOG-DFM and DirichletFM
|
| 53 |
+
max_epochs=config.training.n_epochs,
|
| 54 |
+
accelerator="cuda" if torch.cuda.is_available() else "cpu",
|
| 55 |
+
devices=config.training.devices if config.training.mode=='train' else [0],
|
| 56 |
+
strategy=DDPStrategy(find_unused_parameters=True),
|
| 57 |
+
callbacks=[every_epoch_cb, best_ckpt_cb, lr_monitor],
|
| 58 |
+
logger=wandb_logger
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
# Folder to save checkpoints
|
| 63 |
+
ckpt_path = config.checkpointing.save_dir
|
| 64 |
+
try: os.makedirs(ckpt_path, exist_ok=False)
|
| 65 |
+
except FileExistsError: pass
|
| 66 |
+
|
| 67 |
+
# PL Model for training
|
| 68 |
+
sbm_model = MadSBM(config)
|
| 69 |
+
sbm_model.validate_config()
|
| 70 |
+
|
| 71 |
+
# Get datasets
|
| 72 |
+
datasets = get_datasets(config)
|
| 73 |
+
data_module = PeptideDataModule(
|
| 74 |
+
config=config,
|
| 75 |
+
train_dataset=datasets['train'],
|
| 76 |
+
val_dataset=datasets['val'],
|
| 77 |
+
test_dataset=datasets['test'],
|
| 78 |
+
tokenizer=sbm_model.tokenizer,
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Start/resume training or evaluate the model
|
| 82 |
+
if config.training.mode == "train":
|
| 83 |
+
trainer.fit(sbm_model, datamodule=data_module)
|
| 84 |
+
|
| 85 |
+
elif config.training.mode == "test":
|
| 86 |
+
state_dict = sbm_model.get_state_dict(config.checkpointing.best_ckpt_path)
|
| 87 |
+
sbm_model.load_state_dict(state_dict)
|
| 88 |
+
trainer.test(sbm_model, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)
|
| 89 |
+
|
| 90 |
+
elif config.training.mode == "resume_from_checkpoint":
|
| 91 |
+
state_dict = sbm_model.get_state_dict(config.training.resume_ckpt_path)
|
| 92 |
+
sbm_model.load_state_dict(state_dict)
|
| 93 |
+
trainer.fit(sbm_model, datamodule=data_module, ckpt_path=ckpt_path)
|
| 94 |
+
|
| 95 |
+
wandb.finish()
|
| 96 |
+
|
src/madsbm/wt_peptide/sbm_module.py
ADDED
|
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import os
|
| 3 |
+
import math
|
| 4 |
+
from re import L
|
| 5 |
+
import torch
|
| 6 |
+
|
| 7 |
+
import lightning as pl
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
|
| 10 |
+
from transformers import AutoModel
|
| 11 |
+
|
| 12 |
+
from src.madsbm.wt_peptide.control_field import PeptideControlField
|
| 13 |
+
from src.PeptiVerse.inference import PeptiVersePredictor
|
| 14 |
+
from src.utils.model_utils import CosineWarmup, _print, compute_grad_norms
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class MadSBM(pl.LightningModule):
|
| 18 |
+
def __init__(self, config, guidance=None):
|
| 19 |
+
super().__init__()
|
| 20 |
+
|
| 21 |
+
self.config = config
|
| 22 |
+
self.model = PeptideControlField(config)
|
| 23 |
+
self.tokenizer = self.model.tokenizer
|
| 24 |
+
self.vocab_size = self.tokenizer.vocab_size
|
| 25 |
+
|
| 26 |
+
self.mask_id = self.tokenizer.mask_token_id
|
| 27 |
+
self.pad_id = self.tokenizer.pad_token_id
|
| 28 |
+
|
| 29 |
+
self.embed_model = AutoModel.from_pretrained(config.model.esm_model)
|
| 30 |
+
self.embed_model.eval()
|
| 31 |
+
for param in self.embed_model.parameters():
|
| 32 |
+
param.requires_grad = False
|
| 33 |
+
|
| 34 |
+
self.beta = 1.0 / self.config.model.hidden_dim
|
| 35 |
+
|
| 36 |
+
# self.L = config.data.max_seq_len
|
| 37 |
+
# self.V = self.vocab_size
|
| 38 |
+
# self.log_R0 = - math.log(self.L * self.V) # uninformed generator is constant
|
| 39 |
+
|
| 40 |
+
self.time_schedule = config.time_embed.time_schedule
|
| 41 |
+
self.anneal_frac = config.time_embed.anneal_frac
|
| 42 |
+
self.eps = float(config.time_embed.min_time)
|
| 43 |
+
self.t_max = 1.0 - self.eps
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
# -------# Forward Pass #-------- #
|
| 47 |
+
def forward(self, input_ids, attention_mask, t):
|
| 48 |
+
return self.model(xt=input_ids, attention_mask=attention_mask, t=t)
|
| 49 |
+
|
| 50 |
+
def step(self, batch):
|
| 51 |
+
x1 = batch['input_ids']
|
| 52 |
+
attn_mask = batch['attention_mask']
|
| 53 |
+
maskable = self.is_maskable(x1)
|
| 54 |
+
|
| 55 |
+
t = self.sample_t(x1)
|
| 56 |
+
xt = self.noise_seq(x1, t, maskable_mask=maskable)
|
| 57 |
+
|
| 58 |
+
outs = self.forward(xt, attn_mask, t)
|
| 59 |
+
if self.config.model.ablate:
|
| 60 |
+
logits = outs['dit']
|
| 61 |
+
else:
|
| 62 |
+
logits = outs['madsbm']
|
| 63 |
+
max_u_logit = outs['dit'].max().item()
|
| 64 |
+
max_esm_logit = outs['esm'].max().item()
|
| 65 |
+
|
| 66 |
+
loss_token = F.cross_entropy(
|
| 67 |
+
logits.view(-1, logits.size(-1)),
|
| 68 |
+
x1.view(-1),
|
| 69 |
+
reduction = 'none',
|
| 70 |
+
ignore_index=self.pad_id
|
| 71 |
+
)
|
| 72 |
+
loss_token = loss_token.view(x1.size(0), x1.size(1))
|
| 73 |
+
|
| 74 |
+
sample_loss = (loss_token * maskable.float()).sum(dim=1) / maskable.float().sum(dim=1).clamp(min=1.0)
|
| 75 |
+
|
| 76 |
+
loss = sample_loss.mean()
|
| 77 |
+
ppl = torch.exp(loss)
|
| 78 |
+
|
| 79 |
+
_print(f'loss: {loss}')
|
| 80 |
+
_print(f'ppl: {ppl}')
|
| 81 |
+
|
| 82 |
+
return loss, ppl, max_u_logit, max_esm_logit
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
# def step(self, batch):
|
| 87 |
+
# x1 = batch['input_ids']
|
| 88 |
+
# attn_mask = batch['attention_mask']
|
| 89 |
+
# maskable = self.is_maskable(x1)
|
| 90 |
+
|
| 91 |
+
# t = self.sample_t(x1)
|
| 92 |
+
# xt = self.noise_seq(x1, t, maskable_mask=maskable)
|
| 93 |
+
|
| 94 |
+
# u_theta = self.forward(xt, attn_mask, t)
|
| 95 |
+
# b, l, v_target = self.compute_target(x1, xt, t, maskable_mask=maskable)
|
| 96 |
+
# loss, ppl = self.compute_loss(u_theta, v_target, x1, b, l)
|
| 97 |
+
|
| 98 |
+
# _print(f'loss: {loss}')
|
| 99 |
+
# _print(f'ppl: {ppl}')
|
| 100 |
+
|
| 101 |
+
# return loss, ppl
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# -------# Main Training Logic #-------- #
|
| 105 |
+
def noise_seq(self, x1, t, maskable_mask):
|
| 106 |
+
B, L = x1.shape
|
| 107 |
+
t = t.unsqueeze(1) # B, 1
|
| 108 |
+
|
| 109 |
+
# reveal if u < t, mask if u >= t
|
| 110 |
+
u = torch.rand((B, L), device=x1.device)
|
| 111 |
+
masked = (u < t) & maskable_mask
|
| 112 |
+
|
| 113 |
+
xt = x1.clone()
|
| 114 |
+
xt = xt.masked_fill(masked, self.mask_id)
|
| 115 |
+
|
| 116 |
+
return xt
|
| 117 |
+
|
| 118 |
+
# def compute_target(self, x1, xt, t, maskable_mask):
|
| 119 |
+
# L = x1.size(1)
|
| 120 |
+
# V = self.vocab_size
|
| 121 |
+
# device = x1.device
|
| 122 |
+
|
| 123 |
+
# mask = (xt == self.mask_id) & maskable_mask
|
| 124 |
+
# b, l = torch.nonzero(mask, as_tuple=True)
|
| 125 |
+
|
| 126 |
+
# if b.numel() == 0:
|
| 127 |
+
# return b, l, torch.empty(0, device=device, dtype=torch.long)
|
| 128 |
+
|
| 129 |
+
# log_R0 = - math.log(L * V) # uniform generator with rates (1 / L*V)
|
| 130 |
+
# time = - torch.log(1 - t[b])
|
| 131 |
+
|
| 132 |
+
# v_target = time - log_R0 # log(1/1-t) - log(1/L*V)
|
| 133 |
+
# v_target = v_target.clamp(min=-100.0, max=100.0)
|
| 134 |
+
|
| 135 |
+
# return b, l, v_target
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
# def compute_loss(self, u_theta, v_target, x1, b, l):
|
| 139 |
+
# if b.numel() == 0:
|
| 140 |
+
# dummy_loss = 0.0 * u_theta.sum()
|
| 141 |
+
# return dummy_loss, torch.tensor(0.0, device=u_theta.device)
|
| 142 |
+
|
| 143 |
+
# true_toks = x1[b, l]
|
| 144 |
+
# u_pred = u_theta[b, l, :] # N_masks, V
|
| 145 |
+
|
| 146 |
+
# tgt = torch.zeros_like(u_pred)
|
| 147 |
+
# tgt.scatter_(1, true_toks.unsqueeze(1), v_target.unsqueeze(1))
|
| 148 |
+
|
| 149 |
+
# sse = F.mse_loss(u_pred, tgt, reduction='sum')
|
| 150 |
+
# loss = sse / b.numel() if b.numel != 0 else sse # normalize by number of masks
|
| 151 |
+
|
| 152 |
+
# with torch.no_grad():
|
| 153 |
+
# ppl = torch.exp(F.cross_entropy(u_pred, true_toks))
|
| 154 |
+
|
| 155 |
+
# return loss, ppl
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# -------# Time Schedules #-------- #
|
| 159 |
+
def sample_t(self, x1):
|
| 160 |
+
ts = self.time_schedule
|
| 161 |
+
if ts == 'linear':
|
| 162 |
+
return self.sample_linear_t(x1)
|
| 163 |
+
elif ts == 'exponential':
|
| 164 |
+
return self.sample_exp_t(x1)
|
| 165 |
+
elif ts == 'uniform':
|
| 166 |
+
return self.sample_uni_t(x1)
|
| 167 |
+
else:
|
| 168 |
+
raise ValueError(f"Unrecognized time scheduler type: {ts}")
|
| 169 |
+
|
| 170 |
+
def sample_uni_t(self, x1):
|
| 171 |
+
B = x1.size(0)
|
| 172 |
+
T = self.config.time_embed.n_timesteps
|
| 173 |
+
|
| 174 |
+
discrete_ts = torch.randint(1, T+1, (B,), device=x1.device)
|
| 175 |
+
timesteps = discrete_ts.float() / float(T)
|
| 176 |
+
_print(f'timesteps: {timesteps}')
|
| 177 |
+
return timesteps.clamp(min=self.eps, max=self.t_max)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def sample_linear_t(self, x1):
|
| 181 |
+
B = x1.size(0)
|
| 182 |
+
eps = self.eps
|
| 183 |
+
|
| 184 |
+
# fraction of total training steps completed
|
| 185 |
+
frac = float(self.global_step) / float(self.tot_steps)
|
| 186 |
+
t_max = 1.0 - eps
|
| 187 |
+
|
| 188 |
+
if frac < self.anneal_frac:
|
| 189 |
+
# normalize progress within the anneal window
|
| 190 |
+
prog = frac / max(1e-12, self.anneal_frac) # maps [0, anneal_frac) to [0,1)
|
| 191 |
+
t_min = eps + prog * (t_max - eps) # linear increase from eps to 1.0-eps
|
| 192 |
+
t = t_min + (t_max - t_min) * torch.rand(B, device=x1.device)
|
| 193 |
+
else:
|
| 194 |
+
# after anneal_frac of training steps completed, then uniform sample over entire range [eps, 1.0-eps]
|
| 195 |
+
t = eps + (t_max - eps) * torch.rand(B, device=x1.device)
|
| 196 |
+
|
| 197 |
+
return t.clamp(min=eps, max=t_max)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def sample_t_exponential(self, x1, t_min=1e-6, t_max=1.0-1e-6):
|
| 201 |
+
# TODO - FIX THIS METHOD IF NEEDED !!
|
| 202 |
+
"""
|
| 203 |
+
Exponentially anneal center of t from t_min to t_max over training.
|
| 204 |
+
|
| 205 |
+
Implement if linear schedule isn't expressive enough
|
| 206 |
+
But for annealing over training steps, which can be a very large quantity,
|
| 207 |
+
exponential approximates linear schedule
|
| 208 |
+
"""
|
| 209 |
+
# k controls how fast the curve rises.
|
| 210 |
+
k = self.config.training.exp_time_k
|
| 211 |
+
progress = self.trainer.step / self.tot_steps
|
| 212 |
+
frac = 1.0 - torch.exp(-k * torch.tensor(progress))
|
| 213 |
+
center = t_min + frac * (t_max - t_min)
|
| 214 |
+
|
| 215 |
+
# add small jitter so we don't collapse onto a distribution
|
| 216 |
+
t = torch.randn(x1.size(0)) * self.config.training.time_sigma + center
|
| 217 |
+
return t.clamp(min=t_min, max=t_max)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
# -------# Model Training / Evaluation #-------- #
|
| 222 |
+
def training_step(self, batch):
|
| 223 |
+
loss, ppl = self.step(batch)
|
| 224 |
+
self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
|
| 225 |
+
self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
|
| 226 |
+
return loss
|
| 227 |
+
|
| 228 |
+
def validation_step(self, batch):
|
| 229 |
+
loss, ppl = self.step(batch)
|
| 230 |
+
self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 231 |
+
self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
|
| 232 |
+
return loss
|
| 233 |
+
|
| 234 |
+
def test_step(self, batch):
|
| 235 |
+
loss, ppl, max_u, max_esm = self.step(batch)
|
| 236 |
+
self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
|
| 237 |
+
self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
|
| 238 |
+
self.log("test/max_madsbm_logit", max_u, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
|
| 239 |
+
self.log("test/max_esm_logit", max_esm, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
|
| 240 |
+
return loss
|
| 241 |
+
|
| 242 |
+
def on_after_backward(self):
|
| 243 |
+
pre_norm = compute_grad_norms(self.parameters())
|
| 244 |
+
self.log('train/grad_norm_PRE_clip', pre_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)
|
| 245 |
+
|
| 246 |
+
# torch.nn.utils.clip_grad_norm_(self.parameters(), float(self.config.training.grad_clip_val))
|
| 247 |
+
# post_norm = compute_grad_norms(self.parameters())
|
| 248 |
+
# self.log('train/grad_norm_POST_clip', post_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)
|
| 249 |
+
|
| 250 |
+
def configure_optimizers(self):
|
| 251 |
+
optimizer = torch.optim.AdamW(
|
| 252 |
+
params = self.model.parameters(),
|
| 253 |
+
lr = self.config.optim.lr,
|
| 254 |
+
weight_decay = self.config.optim.weight_decay,
|
| 255 |
+
betas = (self.config.optim.beta1, self.config.optim.beta2)
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
self.tot_steps = self.trainer.estimated_stepping_batches
|
| 259 |
+
warmup_steps = int(self.config.optim.warmup_epochs * self.tot_steps / self.config.training.n_epochs)
|
| 260 |
+
|
| 261 |
+
lr_scheduler = CosineWarmup(
|
| 262 |
+
optimizer = optimizer,
|
| 263 |
+
warmup_steps = warmup_steps,
|
| 264 |
+
total_steps = self.tot_steps
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
return {
|
| 268 |
+
"optimizer": optimizer,
|
| 269 |
+
"lr_scheduler": {
|
| 270 |
+
"scheduler": lr_scheduler,
|
| 271 |
+
"interval": "step",
|
| 272 |
+
"frequency": 1
|
| 273 |
+
}
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
def on_save_checkpoint(self, checkpoint: dict):
|
| 277 |
+
"""
|
| 278 |
+
Don't save the classifier model used for FBD calculation in the ckpt
|
| 279 |
+
"""
|
| 280 |
+
sd = checkpoint.get('state_dict', None)
|
| 281 |
+
if sd is None:
|
| 282 |
+
return
|
| 283 |
+
keys_to_remove = [k for k in sd.keys() if k.startswith("score_model.")]
|
| 284 |
+
for k in keys_to_remove:
|
| 285 |
+
sd.pop(k, None)
|
| 286 |
+
checkpoint['state_dict'] = sd
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
# -------# Helper methods #-------- #
|
| 290 |
+
def is_maskable(self, input_ids: torch.Tensor):
|
| 291 |
+
return (
|
| 292 |
+
(input_ids != self.tokenizer.pad_token_id)
|
| 293 |
+
& (input_ids != self.tokenizer.cls_token_id)
|
| 294 |
+
& (input_ids != self.tokenizer.eos_token_id)
|
| 295 |
+
)
|
| 296 |
+
|
| 297 |
+
def validate_config(self):
|
| 298 |
+
assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
|
| 299 |
+
assert self.config.model.hidden_dim % 2 == 0, 'odd value for embedding dim'
|
| 300 |
+
assert self.config.time_embed.time_dim % 2 == 0, 'odd value for time dim'
|
| 301 |
+
assert self.config.time_embed.fourier_dim % 2 == 0, 'odd value for fourier dim'
|
| 302 |
+
|
| 303 |
+
def get_state_dict(self, ckpt_path):
|
| 304 |
+
def remove_model_prefix(state_dict):
|
| 305 |
+
for k, v in state_dict.items():
|
| 306 |
+
if "model." in k:
|
| 307 |
+
k.replace('model.', '')
|
| 308 |
+
return state_dict
|
| 309 |
+
|
| 310 |
+
checkpoint = torch.load(ckpt_path, map_location='cuda:3' if torch.cuda.is_available() else 'cpu')
|
| 311 |
+
state_dict = checkpoint.get("state_dict", checkpoint)
|
| 312 |
+
|
| 313 |
+
if any(k.startswith("model.") for k in state_dict.keys()):
|
| 314 |
+
state_dict = remove_model_prefix(state_dict)
|
| 315 |
+
|
| 316 |
+
return state_dict
|
| 317 |
+
|
| 318 |
+
def cleanup(self):
|
| 319 |
+
torch.cuda.empty_cache()
|
| 320 |
+
gc.collect()
|
src/sampling/diffusion_sampler.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import torch
|
| 3 |
+
import random
|
| 4 |
+
import numpy as np
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
from src.utils.model_utils import _print
|
| 7 |
+
|
| 8 |
+
class DiffusionSampler:
|
| 9 |
+
def __init__(self, model, tokenizer):
|
| 10 |
+
self.model = model
|
| 11 |
+
self.tokenizer = tokenizer
|
| 12 |
+
|
| 13 |
+
self.device = self.model.device
|
| 14 |
+
self.mask_id = self.tokenizer.mask_token_id
|
| 15 |
+
self.seed_everything(seed=42)
|
| 16 |
+
|
| 17 |
+
@torch.inference_mode()
|
| 18 |
+
def sample_unconditional(self, xt, num_steps, tracer, tau=1.0, kappa_fn=lambda t: t, eta=1, alpha=1.):
|
| 19 |
+
"""
|
| 20 |
+
Stochastic remasking sampling method for iterative refinement of sequences.
|
| 21 |
+
Args:
|
| 22 |
+
xt (Tensor): Initial token tensor.
|
| 23 |
+
num_steps (int): Number of refinement steps.
|
| 24 |
+
tau (float): Temperature parameter for softmax sampling.
|
| 25 |
+
kappa_fn (callable): Function controlling the unmasking schedule.
|
| 26 |
+
eta (float): Scaling factor for score adjustments.
|
| 27 |
+
alpha (float): Weighting for confidence-based scoring.
|
| 28 |
+
Returns:
|
| 29 |
+
Tensor: Final sampled sequence tensor.
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
dt = 1 / num_steps
|
| 33 |
+
fix_mask = xt != self.mask_id # tokens to retain
|
| 34 |
+
attention_mask = torch.ones_like(xt).to(self.device)
|
| 35 |
+
|
| 36 |
+
if tracer:
|
| 37 |
+
tracer.log_step(xt=xt, step_idx = 0)
|
| 38 |
+
|
| 39 |
+
for i in range(1, num_steps + 1):
|
| 40 |
+
kappa_t = kappa_fn(i * dt)
|
| 41 |
+
logits = self.model(input_ids=xt, attention_mask=attention_mask).logits
|
| 42 |
+
last_mask = xt == self.mask_id # tokens currently masked
|
| 43 |
+
unmask_t = ~last_mask & ~fix_mask # unmasked and not fixed tokens - candidates for masking
|
| 44 |
+
|
| 45 |
+
x0, logp = self.stochastic_sample_from_categorical(logits, tau) # tokens, logprobs
|
| 46 |
+
|
| 47 |
+
# Confidence-based scoring
|
| 48 |
+
entropy = torch.distributions.Categorical(logits=logits).entropy()
|
| 49 |
+
score = alpha * logp + (1 - alpha) * -entropy # alpha = 1 --> score = logp
|
| 50 |
+
score = score.masked_fill(fix_mask, float('inf'))
|
| 51 |
+
|
| 52 |
+
score[unmask_t] = score[unmask_t] * eta
|
| 53 |
+
|
| 54 |
+
num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
|
| 55 |
+
lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)
|
| 56 |
+
|
| 57 |
+
xt[lowest_k_mask] = self.mask_id
|
| 58 |
+
mask_2_x0 = last_mask & ~lowest_k_mask
|
| 59 |
+
xt[mask_2_x0] = x0[mask_2_x0]
|
| 60 |
+
|
| 61 |
+
tracer.log_step(xt=xt, step_idx = i)
|
| 62 |
+
|
| 63 |
+
xt[xt == self.mask_id] = x0[xt == self.mask_id]
|
| 64 |
+
|
| 65 |
+
tracer.log_step(xt, num_steps + 1)
|
| 66 |
+
|
| 67 |
+
return xt
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0):
|
| 71 |
+
"""
|
| 72 |
+
Sample from a categorical distribution with optional temperature scaling and Gumbel noise.
|
| 73 |
+
"""
|
| 74 |
+
logits = logits.double()
|
| 75 |
+
if temperature != 0:
|
| 76 |
+
gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
|
| 77 |
+
logits = logits / temperature + noise_scale * gumbel_noise
|
| 78 |
+
scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)
|
| 79 |
+
|
| 80 |
+
return tokens, scores
|
| 81 |
+
|
| 82 |
+
def topk_lowest_masking(self, scores, cutoff_len):
|
| 83 |
+
"""
|
| 84 |
+
scores: [b, n]
|
| 85 |
+
cutoff_len: [b, 1]
|
| 86 |
+
returns:
|
| 87 |
+
mask: [b, n], with 1 if the token is in top-k lowest scores, 0 otherwise
|
| 88 |
+
"""
|
| 89 |
+
sorted_index = scores.sort(-1)[0]
|
| 90 |
+
cutoff = sorted_index.gather(dim=-1, index=cutoff_len)
|
| 91 |
+
return scores < cutoff
|
| 92 |
+
|
| 93 |
+
def seed_everything(self, seed):
|
| 94 |
+
"""
|
| 95 |
+
Set the seed for reproducibility across various libraries.
|
| 96 |
+
"""
|
| 97 |
+
if seed is None:
|
| 98 |
+
return
|
| 99 |
+
random.seed(seed)
|
| 100 |
+
np.random.seed(seed)
|
| 101 |
+
torch.manual_seed(seed)
|
| 102 |
+
if torch.cuda.is_available():
|
| 103 |
+
torch.cuda.manual_seed(seed)
|
| 104 |
+
torch.cuda.manual_seed_all(seed) # if using multi-GPU
|
| 105 |
+
torch.backends.cudnn.deterministic = True
|
| 106 |
+
torch.backends.cudnn.benchmark = False
|
src/sampling/guided_sample.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import random
|
| 5 |
+
import torch
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from omegaconf import OmegaConf
|
| 10 |
+
from datetime import datetime
|
| 11 |
+
|
| 12 |
+
from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| 13 |
+
|
| 14 |
+
from src.madsbm.wt_peptide.sbm_module import MadSBM
|
| 15 |
+
from src.sampling.madsbm_sampler import MadSBMSampler
|
| 16 |
+
|
| 17 |
+
from src.utils.generate_utils import calc_entropy, mask_for_de_novo, calc_ppl
|
| 18 |
+
from src.utils.model_utils import _print
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
|
| 22 |
+
os.chdir('/scratch/pranamlab/sgoel/MadSBM')
|
| 23 |
+
config = OmegaConf.load("./configs/wt_pep.yaml")
|
| 24 |
+
|
| 25 |
+
date = datetime.now().strftime("%Y-%m-%d")
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def generate_sequence(masked_seq, target_toks, tokenizer, generator, device):
|
| 29 |
+
input_ids = tokenizer(masked_seq, return_tensors="pt").to(device)['input_ids']
|
| 30 |
+
|
| 31 |
+
uncond_ids, uncond_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, target_toks=target_toks, guidance=False)
|
| 32 |
+
guided_ids, guided_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, target_toks=target_toks, guidance=True)
|
| 33 |
+
|
| 34 |
+
uncond_seq = tokenizer.decode(uncond_ids[0].squeeze())[5:-5].replace(" ", "") # bos/eos tokens & spaces between residues
|
| 35 |
+
guided_seq = tokenizer.decode(guided_ids[0].squeeze())[5:-5].replace(" ", "")
|
| 36 |
+
|
| 37 |
+
return uncond_seq, guided_seq, uncond_bind, guided_bind
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def main():
|
| 41 |
+
csv_save_path = f'./results/guided/'
|
| 42 |
+
|
| 43 |
+
try: os.makedirs(csv_save_path, exist_ok=False)
|
| 44 |
+
except FileExistsError: pass
|
| 45 |
+
|
| 46 |
+
# Load ESM model for eval
|
| 47 |
+
esm_pth = config.model.esm_model
|
| 48 |
+
esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
|
| 49 |
+
esm_model.eval()
|
| 50 |
+
|
| 51 |
+
# Load SBM model
|
| 52 |
+
gen_model = MadSBM(config)
|
| 53 |
+
state_dict = gen_model.get_state_dict(config.checkpointing.best_ckpt_path)
|
| 54 |
+
gen_model.load_state_dict(state_dict)
|
| 55 |
+
gen_model.to(device)
|
| 56 |
+
gen_model.eval()
|
| 57 |
+
tokenizer = gen_model.tokenizer
|
| 58 |
+
generator = MadSBMSampler(gen_model, config, device, guidance=True)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
tgt_name = "3HVE"
|
| 62 |
+
df = pd.read_csv("./data/wt_pep/targets.csv")
|
| 63 |
+
tgt_seq = df.loc[df['Target'] == tgt_name, 'Sequence'].iloc[0]
|
| 64 |
+
target_toks = tokenizer(tgt_seq, return_tensors='pt')['input_ids'].to(device)
|
| 65 |
+
|
| 66 |
+
existing_binder = df.loc[df['Target'] == tgt_name, 'Existing Binder'].iloc[0]
|
| 67 |
+
existing_binder_pred = generator.peptiverse.predict_binding_affinity(
|
| 68 |
+
mode = 'wt',
|
| 69 |
+
target_ids = target_toks,
|
| 70 |
+
binder_ids = tokenizer(existing_binder, return_tensors='pt')['input_ids'].to(device).detach()
|
| 71 |
+
)['affinity']
|
| 72 |
+
|
| 73 |
+
_print(f'EXISTING BINDER AFFINITY: {existing_binder_pred}')
|
| 74 |
+
|
| 75 |
+
seq_lengths = [length for length in [10, 15, 20] for _ in range(20)]
|
| 76 |
+
generation_results = []
|
| 77 |
+
|
| 78 |
+
for seq_len in tqdm(seq_lengths, desc=f"Generating sequences: "):
|
| 79 |
+
|
| 80 |
+
masked_seq = mask_for_de_novo(seq_len) # Sequence of all <MASK> tokens
|
| 81 |
+
uncond_seq, guided_seq, uncond_bind, guided_bind = generate_sequence(masked_seq, target_toks, tokenizer, generator, device)
|
| 82 |
+
|
| 83 |
+
uncond_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, [i for i in range(len(uncond_seq))], model_type='esm')
|
| 84 |
+
guided_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, [i for i in range(len(uncond_seq))], model_type='esm')
|
| 85 |
+
|
| 86 |
+
_print(f'uncond seq: {uncond_seq}')
|
| 87 |
+
_print(f'uncond ppl: {uncond_ppl}')
|
| 88 |
+
_print(f'uncond bind: {uncond_bind}')
|
| 89 |
+
|
| 90 |
+
_print(f'guided seq: {guided_seq}')
|
| 91 |
+
_print(f'guided ppl: {guided_ppl}')
|
| 92 |
+
_print(f'guided bind: {guided_bind}')
|
| 93 |
+
|
| 94 |
+
res_row = {
|
| 95 |
+
"Uncond Generated Sequence": uncond_seq,
|
| 96 |
+
"Guided Generated Sequence": guided_seq,
|
| 97 |
+
"Uncond PPL": uncond_ppl,
|
| 98 |
+
"Guided PPL": guided_ppl,
|
| 99 |
+
"Uncond Affinity": uncond_bind,
|
| 100 |
+
"Guided Affinity": guided_bind
|
| 101 |
+
}
|
| 102 |
+
|
| 103 |
+
generation_results.append(res_row)
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
df = pd.DataFrame(generation_results)
|
| 107 |
+
|
| 108 |
+
_print(f"Uncond PPL Res: {df['Uncond PPL'].mean()}, {df['Uncond PPL'].std()}")
|
| 109 |
+
_print(f"Guided PPL Res: {df['Guided PPL'].mean()}, {df['Guided PPL'].std()}")
|
| 110 |
+
|
| 111 |
+
_print(f"Uncond Affinity Res: {df['Uncond Affinity'].mean()}, {df['Uncond Affinity'].std()}")
|
| 112 |
+
_print(f"Guided Affinity Res: {df['Guided Affinity'].mean()}, {df['Guided Affinity'].std()}")
|
| 113 |
+
|
| 114 |
+
df.to_csv(
|
| 115 |
+
csv_save_path + f"/{tgt_name}/tau=0.5_topp=0.9_no-gumbel_rate=0.01_jump=0.05_ablate={config.model.ablate}_nsteps={config.sampling.n_steps}_seqs_with_ppl_{date}.csv",
|
| 116 |
+
index=False
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
if __name__ == "__main__":
|
| 121 |
+
main()
|
src/sampling/madsbm_sampler.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from src.PeptiVerse.inference import PeptiVersePredictor
|
| 9 |
+
from src.utils.model_utils import _print
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class MadSBMSampler:
|
| 13 |
+
def __init__(self, model, config, device, guidance=None):
|
| 14 |
+
self.config = config
|
| 15 |
+
self.device = device
|
| 16 |
+
self.model = model
|
| 17 |
+
self.tokenizer = model.tokenizer
|
| 18 |
+
self.mask_id = self.tokenizer.mask_token_id
|
| 19 |
+
self.eps = config.time_embed.min_time
|
| 20 |
+
self.seed_everything(seed=42)
|
| 21 |
+
|
| 22 |
+
if guidance:
|
| 23 |
+
self.guidance = guidance
|
| 24 |
+
self.peptiverse = PeptiVersePredictor(
|
| 25 |
+
manifest_path="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt",
|
| 26 |
+
classifier_weight_root="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse",
|
| 27 |
+
device=self.device
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
@torch.inference_mode()
|
| 32 |
+
def sample(self, xt, num_steps, tracer, target_toks=None, guidance=None):
|
| 33 |
+
xt = xt.clone()
|
| 34 |
+
B, L = xt.shape
|
| 35 |
+
assert B == 1, "Do only 1 sequence at a time"
|
| 36 |
+
|
| 37 |
+
t_max = 1.0 - self.eps
|
| 38 |
+
dt = 1.0 / num_steps
|
| 39 |
+
attn_mask = torch.ones_like(xt, device=self.device)
|
| 40 |
+
|
| 41 |
+
action_traj = {}
|
| 42 |
+
tot_action = 0.0
|
| 43 |
+
|
| 44 |
+
tracer.log_step(xt=xt, step_idx=0)
|
| 45 |
+
|
| 46 |
+
converge_idx = num_steps
|
| 47 |
+
converged = False
|
| 48 |
+
|
| 49 |
+
for k in range(num_steps):
|
| 50 |
+
# t decreases from 1 --> 0 as our model was trained that t=1 --> noise and t=0 --> clean
|
| 51 |
+
prog = (k + 1) / float(num_steps)
|
| 52 |
+
t_val = t_max - (t_max - self.eps) * prog
|
| 53 |
+
t = torch.full((B,), fill_value=float(t_val), device=self.device) # B = 1 during sampling
|
| 54 |
+
|
| 55 |
+
# predicted control field --> B, L, V
|
| 56 |
+
outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
|
| 57 |
+
|
| 58 |
+
u_tilt = outs['dit']
|
| 59 |
+
total_logits = outs['madsbm']
|
| 60 |
+
esm_logits = outs['esm']
|
| 61 |
+
|
| 62 |
+
if self.config.model.ablate:
|
| 63 |
+
actional = self.compute_action(u_tilt, esm_logits=None)
|
| 64 |
+
else:
|
| 65 |
+
actional = self.compute_action(u_tilt, esm_logits=esm_logits)
|
| 66 |
+
|
| 67 |
+
action_traj[f"action_step_{k+1}"] = actional
|
| 68 |
+
tot_action += (actional * dt)
|
| 69 |
+
|
| 70 |
+
# Compute jump rates and jump probs
|
| 71 |
+
# P(jump) = 1 - exp(-rate * dt)
|
| 72 |
+
r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
|
| 73 |
+
R_tot = r_theta.sum(dim=-1) # 1, L
|
| 74 |
+
rate = (- R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
|
| 75 |
+
jump_prob = 1.0 - torch.exp(rate)
|
| 76 |
+
|
| 77 |
+
# Scale and filter logits with nucleus sampling
|
| 78 |
+
logits = total_logits.clone()
|
| 79 |
+
logits /= self.config.sampling.tau
|
| 80 |
+
logits = self.top_p_filter(logits, self.config.sampling.top_p)
|
| 81 |
+
|
| 82 |
+
# Sample new tokens
|
| 83 |
+
probs = F.softmax(logits, dim=-1)
|
| 84 |
+
probs = probs.view(-1, probs.size(-1))
|
| 85 |
+
sample = torch.multinomial(probs, 1)
|
| 86 |
+
candidate_toks = sample.view(B, L)
|
| 87 |
+
|
| 88 |
+
# determine tokens we can change
|
| 89 |
+
rand = torch.rand(B, L, device=self.device)
|
| 90 |
+
can_jump = (rand < jump_prob)
|
| 91 |
+
updatable = can_jump & self.is_masked(xt)
|
| 92 |
+
|
| 93 |
+
# Update the sequence
|
| 94 |
+
if guidance:
|
| 95 |
+
chosen_candidate = self.binding_guidance(probs, target_toks, B, L)
|
| 96 |
+
xt[updatable] = chosen_candidate[updatable]
|
| 97 |
+
else:
|
| 98 |
+
xt[updatable] = candidate_toks[updatable]
|
| 99 |
+
|
| 100 |
+
tracer.log_step(xt=xt, step_idx = k+1)
|
| 101 |
+
|
| 102 |
+
if k == num_steps-1:
|
| 103 |
+
final_logits = total_logits
|
| 104 |
+
still_masked = self.is_masked(xt)
|
| 105 |
+
|
| 106 |
+
if not converged and not self.is_masked(xt).any():
|
| 107 |
+
converge_idx = k + 1
|
| 108 |
+
converged = True
|
| 109 |
+
|
| 110 |
+
# Copy over remaining tokens
|
| 111 |
+
if still_masked.any():
|
| 112 |
+
final_toks = final_logits.argmax(dim=-1)
|
| 113 |
+
xt[still_masked] = final_toks[still_masked]
|
| 114 |
+
|
| 115 |
+
tracer.log_step(xt, num_steps + 1)
|
| 116 |
+
|
| 117 |
+
binding_affin = self.peptiverse.predict_binding_affinity(
|
| 118 |
+
mode = 'wt',
|
| 119 |
+
target_ids = target_toks,
|
| 120 |
+
binder_ids = xt
|
| 121 |
+
)['affinity']
|
| 122 |
+
|
| 123 |
+
return xt, binding_affin
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def binding_guidance(self, probs, target_toks, B, L):
|
| 127 |
+
M = self.config.sampling.M
|
| 128 |
+
candidate_toks = []
|
| 129 |
+
affinities = []
|
| 130 |
+
|
| 131 |
+
for _ in range(M):
|
| 132 |
+
ith_sample = torch.multinomial(probs, 1).view(B, L)
|
| 133 |
+
candidate_toks.append(ith_sample)
|
| 134 |
+
|
| 135 |
+
for toks in candidate_toks:
|
| 136 |
+
pred = self.peptiverse.predict_binding_affinity(
|
| 137 |
+
mode = 'wt',
|
| 138 |
+
target_ids = target_toks,
|
| 139 |
+
binder_ids = toks.detach()
|
| 140 |
+
)['affinity']
|
| 141 |
+
affinities.append(pred)
|
| 142 |
+
|
| 143 |
+
affinities = torch.tensor(affinities, dtype=torch.float32)
|
| 144 |
+
weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
|
| 145 |
+
chosen_idx = torch.multinomial(weights, 1).item()
|
| 146 |
+
|
| 147 |
+
return candidate_toks[chosen_idx]
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def compute_action(self, u_tilt, esm_logits=None):
|
| 151 |
+
""" Computes the action functional for evals """
|
| 152 |
+
if esm_logits is not None:
|
| 153 |
+
R0 = torch.softmax(esm_logits, dim=-1)
|
| 154 |
+
else:
|
| 155 |
+
R0 = 1.0 / self.tokenizer.vocab_size
|
| 156 |
+
|
| 157 |
+
psi_u = torch.exp(u_tilt) - u_tilt - 1.0
|
| 158 |
+
action_per_tok = (R0 * psi_u).sum(dim=-1) # R0 goes to 1 in both cases
|
| 159 |
+
|
| 160 |
+
return action_per_tok.mean().item()
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def top_p_filter(self, logits, p_val):
|
| 164 |
+
"""
|
| 165 |
+
Implementation of nucleus / top-p sampling
|
| 166 |
+
Masks out tokens that contribute to the bottom (1 - p) cumulative probability
|
| 167 |
+
"""
|
| 168 |
+
# Sort logits and get cumulative probabilities
|
| 169 |
+
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
| 170 |
+
cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
|
| 171 |
+
|
| 172 |
+
# Remove tokens with cum prob > p-val thresh
|
| 173 |
+
sorted_idx_to_remove = cum_probs > p_val
|
| 174 |
+
|
| 175 |
+
# Shift the indices to the right to keep also the first token above the threshold
|
| 176 |
+
sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
|
| 177 |
+
sorted_idx_to_remove[..., 0] = 0
|
| 178 |
+
|
| 179 |
+
idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
|
| 180 |
+
logits[idx_to_remove] = float('-inf')
|
| 181 |
+
return logits
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def is_masked(self, xt):
|
| 185 |
+
return (xt == self.mask_id)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def seed_everything(self, seed):
|
| 189 |
+
if seed is None:
|
| 190 |
+
return
|
| 191 |
+
random.seed(seed)
|
| 192 |
+
np.random.seed(seed)
|
| 193 |
+
torch.manual_seed(seed)
|
| 194 |
+
if torch.cuda.is_available():
|
| 195 |
+
torch.cuda.manual_seed(seed)
|
| 196 |
+
torch.cuda.manual_seed_all(seed) # if using multi-GPU
|
| 197 |
+
torch.backends.cudnn.deterministic = True
|
| 198 |
+
torch.backends.cudnn.benchmark = False
|
src/sampling/path_tracer.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ProbabilityPathTracer:
|
| 7 |
+
def __init__(self, oracle_model, tokenizer, device):
|
| 8 |
+
self.oracle = oracle_model
|
| 9 |
+
self.tokenizer = tokenizer
|
| 10 |
+
self.device = device
|
| 11 |
+
self.mask_id = tokenizer.mask_token_id
|
| 12 |
+
self.history = {} # {nth_step: prob_score}
|
| 13 |
+
|
| 14 |
+
@torch.inference_mode()
|
| 15 |
+
def compute_loglikeli(self, xt):
|
| 16 |
+
is_revealed = (xt != self.mask_id)
|
| 17 |
+
|
| 18 |
+
if not is_revealed.any():
|
| 19 |
+
return 0.0
|
| 20 |
+
|
| 21 |
+
# esm forward pass
|
| 22 |
+
logits = self.oracle(
|
| 23 |
+
input_ids=xt,
|
| 24 |
+
attention_mask=torch.ones_like(xt, device=xt.device)
|
| 25 |
+
).logits
|
| 26 |
+
|
| 27 |
+
# Calculate CE loss only on unmasked tokens
|
| 28 |
+
nll = F.cross_entropy(
|
| 29 |
+
logits.view(-1, logits.size(-1)),
|
| 30 |
+
xt.view(-1),
|
| 31 |
+
reduction='none'
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
nll = nll.view(xt.shape)
|
| 35 |
+
|
| 36 |
+
# Lower NLL = better --> higher LL = better
|
| 37 |
+
avg_ll = -(nll * is_revealed.float()).sum(dim=1) / is_revealed.float().sum(dim=1).clamp(min=1)
|
| 38 |
+
|
| 39 |
+
return avg_ll.item()
|
| 40 |
+
|
| 41 |
+
def log_step(self, xt, step_idx):
|
| 42 |
+
score = self.compute_loglikeli(xt)
|
| 43 |
+
self.history[f"trace_step_{step_idx}"] = score
|
| 44 |
+
|
| 45 |
+
def get_trace(self):
|
| 46 |
+
return self.history
|
src/utils/__init__.py
ADDED
|
File without changes
|
src/utils/eval_utils.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.linalg import sqrtm
|
| 4 |
+
|
| 5 |
+
def dna_to_tensor(seq):
|
| 6 |
+
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
|
| 7 |
+
indices = [mapping[base] for base in seq]
|
| 8 |
+
return torch.tensor(indices, dtype=torch.long)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def compute_fbd(true_seqs, gen_seqs, score_model):
|
| 12 |
+
"""
|
| 13 |
+
The Frechet Biological Distance (FBD) is defined as the Wasserstein distance between Gaussian / true embeddings
|
| 14 |
+
"""
|
| 15 |
+
embeds1 = score_model()
|
| 16 |
+
embeds2 = score_model()
|
| 17 |
+
|
| 18 |
+
if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0:
|
| 19 |
+
return float('nan')
|
| 20 |
+
mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
|
| 21 |
+
mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
|
| 22 |
+
ssdiff = np.sum((mu1 - mu2) ** 2.0)
|
| 23 |
+
covmean = sqrtm(sigma1.dot(sigma2))
|
| 24 |
+
if np.iscomplexobj(covmean):
|
| 25 |
+
covmean = covmean.real
|
| 26 |
+
dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
|
| 27 |
+
return dist
|
src/utils/fbd_score_model.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import copy
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from torch import nn
|
| 6 |
+
from src.utils.time_utils import GaussianFourierProjection
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Dense(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
A fully connected layer that reshapes outputs to feature maps.
|
| 12 |
+
"""
|
| 13 |
+
def __init__(self, input_dim, output_dim):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.dense = nn.Linear(input_dim, output_dim)
|
| 16 |
+
|
| 17 |
+
def forward(self, x):
|
| 18 |
+
return self.dense(x)[...]
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class Swish(nn.Module):
|
| 22 |
+
def __init__(self):
|
| 23 |
+
super().__init__()
|
| 24 |
+
|
| 25 |
+
def forward(self, x):
|
| 26 |
+
return torch.sigmoid(x) * x
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class CNNClassifier(nn.Module):
|
| 30 |
+
def __init__(self, args, alphabet_size, num_cls, classifier=False):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.alphabet_size = alphabet_size
|
| 33 |
+
self.args = args
|
| 34 |
+
self.classifier = classifier
|
| 35 |
+
self.num_cls = num_cls
|
| 36 |
+
|
| 37 |
+
self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
|
| 38 |
+
|
| 39 |
+
self.num_layers = 5 * args.num_cnn_stacks
|
| 40 |
+
self.convs = [
|
| 41 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
|
| 42 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
|
| 43 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
|
| 44 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
|
| 45 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)
|
| 46 |
+
]
|
| 47 |
+
self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for i in range(args.num_cnn_stacks)])
|
| 48 |
+
self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
|
| 49 |
+
self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
|
| 50 |
+
self.final_conv = nn.Sequential(
|
| 51 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
|
| 52 |
+
nn.ReLU(),
|
| 53 |
+
nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1)
|
| 54 |
+
)
|
| 55 |
+
self.dropout = nn.Dropout(args.dropout)
|
| 56 |
+
|
| 57 |
+
if classifier:
|
| 58 |
+
self.cls_head = nn.Sequential(
|
| 59 |
+
nn.Linear(args.hidden_dim, args.hidden_dim),
|
| 60 |
+
nn.ReLU(),
|
| 61 |
+
nn.Linear(args.hidden_dim, self.num_cls)
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
if self.args.cls_free_guidance and not self.classifier:
|
| 65 |
+
self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
|
| 66 |
+
self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
|
| 67 |
+
|
| 68 |
+
def forward(self, seq, t, cls = None, return_embedding=False):
|
| 69 |
+
if self.args.clean_data:
|
| 70 |
+
feat = self.linear(seq)
|
| 71 |
+
feat = feat.permute(0, 2, 1)
|
| 72 |
+
else:
|
| 73 |
+
time_emb = F.relu(self.time_embedder(t))
|
| 74 |
+
feat = seq.permute(0, 2, 1)
|
| 75 |
+
feat = F.relu(self.linear(feat))
|
| 76 |
+
|
| 77 |
+
if self.args.cls_free_guidance and not self.classifier and cls is not None:
|
| 78 |
+
cls_emb = self.cls_embedder(cls)
|
| 79 |
+
|
| 80 |
+
for i in range(self.num_layers):
|
| 81 |
+
h = self.dropout(feat.clone())
|
| 82 |
+
if not self.args.clean_data:
|
| 83 |
+
h = h + self.time_layers[i](time_emb)[:, :, None]
|
| 84 |
+
if self.args.cls_free_guidance and not self.classifier and cls is not None:
|
| 85 |
+
h = h + self.cls_layers[i](cls_emb)[:, :, None]
|
| 86 |
+
h = self.norms[i]((h).permute(0, 2, 1))
|
| 87 |
+
h = F.relu(self.convs[i](h.permute(0, 2, 1)))
|
| 88 |
+
if h.shape == feat.shape:
|
| 89 |
+
feat = h + feat
|
| 90 |
+
else:
|
| 91 |
+
feat = h
|
| 92 |
+
|
| 93 |
+
feat = self.final_conv(feat)
|
| 94 |
+
feat = feat.permute(0, 2, 1)
|
| 95 |
+
|
| 96 |
+
if self.classifier:
|
| 97 |
+
feat = feat.mean(dim=1)
|
| 98 |
+
if return_embedding:
|
| 99 |
+
embedding = self.cls_head[:1](feat)
|
| 100 |
+
return self.cls_head[1:](embedding), embedding
|
| 101 |
+
else:
|
| 102 |
+
return self.cls_head(feat)
|
| 103 |
+
|
| 104 |
+
return feat
|
src/utils/generate_utils.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import sys
|
| 3 |
+
import torch
|
| 4 |
+
import math
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
|
| 9 |
+
from collections import Counter
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# -------# Masking #-------- #
|
| 17 |
+
def mask_for_de_novo(sequence_length):
|
| 18 |
+
return "<mask>" * sequence_length
|
| 19 |
+
|
| 20 |
+
def mask_for_scaffold(sequence, generate_type, mask_token):
|
| 21 |
+
if generate_type == "uppercase":
|
| 22 |
+
sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
|
| 23 |
+
elif generate_type == "lowercase":
|
| 24 |
+
sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
|
| 25 |
+
return sequence
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# -------# Generation #-------- #
|
| 29 |
+
def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
|
| 30 |
+
"""
|
| 31 |
+
Following the given evodiff example
|
| 32 |
+
https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
|
| 33 |
+
"""
|
| 34 |
+
# Manual masking of infilling sequence
|
| 35 |
+
motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq]) # Mask token is "#" in evodiff tokenizer
|
| 36 |
+
tkns = tokenizer.tokenize([motif_seq])
|
| 37 |
+
sample = torch.as_tensor(tkns).to(device)
|
| 38 |
+
|
| 39 |
+
# Create input motif + scaffold
|
| 40 |
+
loc = torch.arange(0, len(motif_seq)).to(device)[sample==tokenizer.mask_id].cpu().numpy()
|
| 41 |
+
np.random.shuffle(loc)
|
| 42 |
+
|
| 43 |
+
sample = sample.to(device).unsqueeze(0)
|
| 44 |
+
# og_sample = sample.clone()
|
| 45 |
+
|
| 46 |
+
with torch.no_grad():
|
| 47 |
+
for i in loc:
|
| 48 |
+
timestep = torch.tensor([0] * batch_size).to(device) # placeholder but not called in model
|
| 49 |
+
timestep = timestep.to(device)
|
| 50 |
+
prediction = model(sample, timestep)
|
| 51 |
+
p = prediction[:, i, :len(tokenizer.all_aas) - 6] # only canonical
|
| 52 |
+
p = F.softmax(p, dim=1) # softmax over logits
|
| 53 |
+
p_sample = torch.multinomial(p, num_samples=1) # sample from categorical distribution
|
| 54 |
+
sample[:, i] = p_sample.squeeze()
|
| 55 |
+
output = [tokenizer.untokenize(s) for s in sample]
|
| 56 |
+
return output[0] #if batch_size==1 else output, og_sample, loc
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def dplm_infill(masked_seq, tokenizer, model, device):
|
| 60 |
+
from src.lm.dplm.diffusion_module import DPLM
|
| 61 |
+
from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler
|
| 62 |
+
|
| 63 |
+
generator = DPLMUnconditionalSampler(tokenizer, model)
|
| 64 |
+
xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
|
| 65 |
+
denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
|
| 66 |
+
generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
|
| 67 |
+
return generated_sequence
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# -------# Metrics #-------- #
|
| 71 |
+
def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
|
| 72 |
+
"""Compute causal LM cross-entropy loss for a given sequence."""
|
| 73 |
+
with torch.no_grad():
|
| 74 |
+
with torch.cuda.amp.autocast(enabled=fp16):
|
| 75 |
+
logits = model(
|
| 76 |
+
input_ids = target,
|
| 77 |
+
attention_mask = torch.ones_like(target)
|
| 78 |
+
).logits
|
| 79 |
+
# Shift
|
| 80 |
+
logits = logits[:-1, ...]
|
| 81 |
+
target = target[1:]
|
| 82 |
+
loss = torch.nn.functional.cross_entropy(
|
| 83 |
+
input=logits,
|
| 84 |
+
target=target,
|
| 85 |
+
reduction='mean'
|
| 86 |
+
)
|
| 87 |
+
return torch.exp(loss).item()
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
|
| 91 |
+
total_loss = 0.0
|
| 92 |
+
tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
|
| 93 |
+
|
| 94 |
+
for i in mask_token_indices:
|
| 95 |
+
masked_input = tensor_input.clone()
|
| 96 |
+
masked_input[0, i] = tokenizer.mask_token_id
|
| 97 |
+
|
| 98 |
+
labels = torch.full(tensor_input.shape, -100).to(model.device)
|
| 99 |
+
labels[0, i] = tensor_input[0, i]
|
| 100 |
+
|
| 101 |
+
with torch.no_grad():
|
| 102 |
+
loss = model(masked_input, labels=labels).loss.item()
|
| 103 |
+
total_loss += loss
|
| 104 |
+
|
| 105 |
+
avg_loss = total_loss / len(generated_sequence)
|
| 106 |
+
perplexity = math.exp(avg_loss)
|
| 107 |
+
|
| 108 |
+
return perplexity
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def calc_entropy(seq):
|
| 112 |
+
counts = Counter(seq)
|
| 113 |
+
total_len = len(seq)
|
| 114 |
+
entropy = 0.0
|
| 115 |
+
for count in counts.values():
|
| 116 |
+
prob = count / total_len
|
| 117 |
+
entropy -= prob * math.log2(prob)
|
| 118 |
+
return entropy
|
src/utils/model_utils.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import math
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _print(s):
|
| 9 |
+
print(s)
|
| 10 |
+
sys.stdout.flush()
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def compute_grad_norms(params):
|
| 14 |
+
""" Compute the norms of a matrix of gradients """
|
| 15 |
+
sqrd_sum = 0.0
|
| 16 |
+
for p in params:
|
| 17 |
+
if p.grad != None:
|
| 18 |
+
sqrd_sum += p.grad.norm(2).item() ** 2
|
| 19 |
+
norm = sqrd_sum ** 0.5
|
| 20 |
+
return norm
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CosineWarmup(torch.optim.lr_scheduler._LRScheduler):
|
| 24 |
+
def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
|
| 25 |
+
self.warmup_steps = warmup_steps
|
| 26 |
+
self.total_steps = total_steps
|
| 27 |
+
self.eta_ratio = eta_ratio # The ratio of minimum to maximum learning rate
|
| 28 |
+
super(CosineWarmup, self).__init__(optimizer, last_epoch)
|
| 29 |
+
|
| 30 |
+
def get_lr(self):
|
| 31 |
+
step = self.last_epoch
|
| 32 |
+
if step < self.warmup_steps:
|
| 33 |
+
return [
|
| 34 |
+
base_lr * self.last_epoch / self.warmup_steps
|
| 35 |
+
for base_lr in self.base_lrs
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
| 39 |
+
cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
|
| 40 |
+
lr_mult = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
|
| 41 |
+
|
| 42 |
+
return [base_lr * lr_mult for base_lr in self.base_lrs]
|
src/utils/time_utils.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
# -------------------------
|
| 6 |
+
# Timestep embeddings
|
| 7 |
+
# -------------------------
|
| 8 |
+
|
| 9 |
+
class GaussianFourierProjection(nn.Module):
|
| 10 |
+
"""
|
| 11 |
+
Gaussian Fourier features for continuous time t in [0, 1].
|
| 12 |
+
Produces 2 * embed_dim features: [sin(W t), cos(W t)].
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, embed_dim, scale):
|
| 15 |
+
super().__init__()
|
| 16 |
+
assert embed_dim % 2 == 0, "embed_dim must be even."
|
| 17 |
+
self.embed_dim = embed_dim
|
| 18 |
+
self.register_buffer("W", torch.randn(embed_dim // 2) * scale, persistent=False) # Fixed random frequencies
|
| 19 |
+
|
| 20 |
+
def forward(self, t):
|
| 21 |
+
# Ensure float
|
| 22 |
+
t = t.float().unsqueeze(-1) # Broadcoast to [B, 1]
|
| 23 |
+
angles = t * self.W # B, embed_dim // 2
|
| 24 |
+
return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TimeEmbedding(nn.Module):
|
| 28 |
+
def __init__(self, hidden_dim, fourier_dim, scale):
|
| 29 |
+
super().__init__()
|
| 30 |
+
assert fourier_dim % 2 == 0, "fourier_dim must be even for sine/cosine pairs."
|
| 31 |
+
|
| 32 |
+
self.fourier = GaussianFourierProjection(fourier_dim, scale)
|
| 33 |
+
self.mlp = nn.Sequential(
|
| 34 |
+
nn.Linear(fourier_dim, hidden_dim),
|
| 35 |
+
nn.SiLU(),
|
| 36 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
def forward(self, t):
|
| 40 |
+
ft = self.fourier(t) # (B, fourier_dim)
|
| 41 |
+
return self.mlp(ft) # (B, hidden_dim)
|