Shrey Goel committed 94c2704 (0 parents): initial commit
README.md ADDED
# Token-Level Guided Discrete Diffusion for Membrane Protein Design

![MemDLM diagram](./memdlm_schematic.png)

arXiv preprint: ...

Reparameterized diffusion models (RDMs) have recently matched autoregressive methods in protein generation, motivating their use for challenging tasks such as designing membrane proteins, which possess interleaved soluble and transmembrane (TM) regions.

We introduce the ***Membrane Diffusion Language Model (MemDLM)***, a fine-tuned RDM-based protein language model that enables controllable membrane protein sequence design. MemDLM-generated sequences recapitulate the TM residue density and structural features of natural membrane proteins, achieving comparable biological plausibility and outperforming state-of-the-art diffusion baselines in motif scaffolding tasks by producing:

- Lower perplexity
- Higher BLOSUM-62 scores
- Improved pLDDT confidence

To enhance controllability, we develop ***Per-Token Guidance (PET)***, a novel classifier-guided sampling strategy that selectively solubilizes residues while preserving conserved TM domains. This yields sequences with reduced TM density but intact functional cores.
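As a rough illustration only (not the paper's exact PET formulation), classifier-guided sampling at a single position can be sketched as reweighting the diffusion model's token log-probabilities by a per-token classifier score and renormalizing; `guide_token_distribution` and the solubility scores below are hypothetical:

```python
import math

def guide_token_distribution(log_probs, solubility_scores, guidance_weight=1.0):
    """Reweight one position's token log-probabilities by a per-token
    classifier score (e.g. predicted solubility), then renormalize.
    Both arguments are dicts mapping token -> float."""
    guided = {tok: lp + guidance_weight * solubility_scores[tok]
              for tok, lp in log_probs.items()}
    z = math.log(sum(math.exp(v) for v in guided.values()))  # log-normalizer
    return {tok: v - z for tok, v in guided.items()}

# Toy 3-token vocabulary: guidance pushes probability mass toward "K",
# the residue the (hypothetical) classifier scores as most soluble.
log_probs = {"L": math.log(0.5), "K": math.log(0.3), "F": math.log(0.2)}
scores = {"L": -1.0, "K": 2.0, "F": -1.0}
guided = guide_token_distribution(log_probs, scores)
probs = {t: math.exp(v) for t, v in guided.items()}
```

Positions in conserved TM domains would simply be exempted from this reweighting, which is what lets guidance solubilize some residues while leaving the functional core untouched.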

Importantly, MemDLM designs validated in TOXCAT β-lactamase growth assays demonstrate successful TM insertion, distinguishing high-quality generated sequences from poor ones.

Together, our framework establishes the first experimentally validated diffusion-based model for rational membrane protein generation, integrating *de novo* design, motif scaffolding, and targeted property optimization.

## **Repository Authors**
- <u>[Shrey Goel](https://shreygoel09.github.io/)</u> – undergraduate student at Duke University
- <u>[Pranam Chatterjee](mailto:pranam@seas.upenn.edu)</u> – Assistant Professor at the University of Pennsylvania
__init__.py ADDED
File without changes
configs/wt_pep.yaml ADDED
seed: 42
base_dir: /scratch/pranamlab/sgoel/MadSBM

training:
  mode: test  # train / test / resume_from_ckpt
  n_unfrozen: 3
  n_epochs: 50
  log_every_n_steps: 50
  num_sanity_val_steps: 2
  val_check_interval:
  enable_progress_bar: true
  grad_clip_val: 10.0
  accumulate_grad_batches: 16  # workaround for dynamic batching
  devices: 1  # number of GPUs

model:
  ablate: false
  evoflow_model: fredzzp/EvoFlow-150M-afdbseq
  esm_model: facebook/esm2_t33_650M_UR50D
  n_layers: 2  # 8
  n_heads: 16  # 8
  hidden_dim: 1280
  attn_drop: 0.0
  resid_drop: 0.0
  mlp_ratio: 4.0
  beta1: 1e-6
  beta2: 1e-6

time_embed:
  time_dim: 512
  fourier_dim: 64
  fourier_scale: 30.0
  time_schedule: uniform  # linear / exponential / uniform
  anneal_frac: 0.75
  min_time: 1e-6
  n_timesteps: 500

data:
  batch_size: 1
  # max_seq_len: 500
  train: /scratch/pranamlab/tong/data/peptide/tokenized_peptide_batched/train
  test: /scratch/pranamlab/tong/data/peptide/tokenized_peptide_batched/test
  val: /scratch/pranamlab/tong/data/peptide/tokenized_peptide_batched/val

optim:
  type: adamw
  scheduler: cosine
  lr: 1e-4
  lr_end: 1e-5
  warmup_init_lr: 1e-6
  warmup_epochs: 2
  weight_decay: 0.01
  beta1: 0.9
  beta2: 0.999
  power: 1

wandb:
  project: MadSBM_PEPTIDE
  group: programmablebio
  name: peptide_og-madsbm_esm_no-gclip_lr=1e-4_n-layers=2_n-heads=16_trainable-lm-head_logits-sum-SM_gate-esm
  # name: peptide_og-madsbm_esm_no-gclip_lr=1e-4_n-layers=2_n-heads=16_trainable-lm-head_logits-sum-SM_ABLATE-gate-esm
  id: ${.name}_${seed}

checkpointing:
  save_every_n_epochs: 1
  save_dir: ${base_dir}/checkpoints/wt_pep/${wandb.name}
  resume_ckpt_path: ${checkpointing.save_dir}/last.ckpt
  best_ckpt_path: ${checkpointing.save_dir}/best-model_epoch=41_step=106890.ckpt

sampling:
  model_type: madsbm  # madsbm / diffusion / dfm
  n_steps: 32
  top_p: 0.9
  rate_scale: 0.01
  jump_scale: 0.05
  tau: 0.5
  M: 16
  beta: 2.0
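The `optim` block describes a cosine schedule with linear warmup (the repo's `CosineWarmup`, whose implementation is not shown in this commit). A minimal sketch of that schedule shape, using the config's `warmup_init_lr`, `lr`, and `lr_end` values, might look like:

```python
import math

def warmup_cosine_lr(step, total_steps, warmup_steps,
                     warmup_init_lr=1e-6, lr=1e-4, lr_end=1e-5):
    """Linear warmup from warmup_init_lr to lr, then cosine decay to lr_end.
    A sketch of the schedule described by the optim config; the repo's
    CosineWarmup class may differ in details."""
    if step < warmup_steps:
        frac = step / max(1, warmup_steps)
        return warmup_init_lr + frac * (lr - warmup_init_lr)
    # cosine decay over the remaining steps
    prog = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return lr_end + 0.5 * (lr - lr_end) * (1 + math.cos(math.pi * prog))

schedule = [warmup_cosine_lr(s, total_steps=1000, warmup_steps=100)
            for s in range(1001)]
```

With `warmup_epochs: 2` out of `n_epochs: 50`, `warmup_steps` is 4% of the total step budget, which matches the warmup-step computation in `configure_optimizers` below.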
setup.py ADDED
from setuptools import setup, find_packages

setup(
    name='MadSBM',
    version='1.0',
    packages=find_packages(),
    install_requires=[],
    author='Shrey Goel',
    author_email='shrey.goel@duke.edu'
)
src/__init__.py ADDED
File without changes
src/madsbm/__init__.py ADDED
File without changes
src/madsbm/wt_peptide/__init__.py ADDED
File without changes
src/madsbm/wt_peptide/control_field.py ADDED
import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer

from src.utils.time_utils import TimeEmbedding
from src.utils.model_utils import _print


# -------------------------
# DiT building blocks
# -------------------------

class MLP(nn.Module):
    def __init__(self, dim, mlp_ratio, dropout):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class DiTBlock1D(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.hidden_dim = cfg.model.hidden_dim
        self.time_dim = cfg.time_embed.time_dim

        self.norm1 = nn.LayerNorm(self.hidden_dim, eps=1e-6)
        self.norm2 = nn.LayerNorm(self.hidden_dim, eps=1e-6)

        # time-conditioned scale & shift for both norms
        self.time_proj1 = nn.Linear(self.time_dim, 2 * self.hidden_dim)  # scale1, shift1
        self.time_proj2 = nn.Linear(self.time_dim, 2 * self.hidden_dim)  # scale2, shift2

        self.attn = nn.MultiheadAttention(
            embed_dim=self.hidden_dim,
            num_heads=cfg.model.n_heads,
            dropout=cfg.model.attn_drop,
            batch_first=True
        )

        self.mlp = MLP(
            self.hidden_dim,
            mlp_ratio=cfg.model.mlp_ratio,
            dropout=cfg.model.resid_drop
        )

    def forward(self, x, t_emb, key_padding_mask=None):
        # ----- Self-attention branch -----
        # Adaptive LayerNorm (AdaLN) + FiLM from time embedding
        scale1, shift1 = self.time_proj1(t_emb).chunk(2, dim=-1)  # [B, D] and [B, D]
        h = self.norm1(x)
        h = h * (1 + scale1.unsqueeze(1)) + shift1.unsqueeze(1)  # [B, L, D]

        attn_out, _ = self.attn(
            h,
            h,
            h,
            key_padding_mask=key_padding_mask,  # True for pads
            need_weights=False,
        )
        x = x + attn_out

        # ----- MLP branch -----
        scale2, shift2 = self.time_proj2(t_emb).chunk(2, dim=-1)
        h2 = self.norm2(x)
        h2 = h2 * (1 + scale2.unsqueeze(1)) + shift2.unsqueeze(1)

        mlp_out = self.mlp(h2)
        x = x + mlp_out

        return x


class PeptideControlField(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        pth = cfg.model.esm_model
        self.embed_model = AutoModelForMaskedLM.from_pretrained(pth, trust_remote_code=True)
        self.tokenizer = AutoTokenizer.from_pretrained(pth, trust_remote_code=True)

        # Freeze params
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        # # Unfreeze QKV in last few encoder layers
        # encoder_layers = self.embed_model.esm.encoder.layer
        # for layer in encoder_layers[-cfg.training.n_unfrozen:]:
        #     for param in layer.parameters():
        #         param.requires_grad = True

        self.time_embed = TimeEmbedding(
            hidden_dim=cfg.time_embed.time_dim,
            fourier_dim=cfg.time_embed.fourier_dim,
            scale=cfg.time_embed.fourier_scale
        )

        self.blocks = nn.ModuleList([
            DiTBlock1D(self.cfg)
            for _ in range(cfg.model.n_layers)
        ])

        self.final_norm = nn.LayerNorm(cfg.model.hidden_dim, eps=1e-6)

        # self.output_proj = self.embed_model.lm_head
        # for param in self.output_proj.parameters():
        #     param.requires_grad = False

        self.output_proj = nn.Linear(cfg.model.hidden_dim, self.tokenizer.vocab_size)
        nn.init.zeros_(self.output_proj.weight)
        nn.init.zeros_(self.output_proj.bias)

    def forward(self, t, xt, attention_mask):
        with torch.no_grad():
            outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)

        gate = (1.0 - t).view(-1, 1, 1)
        u_base = gate * outs.logits

        h = outs.hidden_states[-1]
        t_emb = self.time_embed(t)  # [B, time_dim]

        # Transformer head (key_padding_mask=True for pads)
        key_padding_mask = (attention_mask == 0)  # (B, L) bool
        for dit_block in self.blocks:
            h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)

        # Final norm + projection to vocab logits
        h = self.final_norm(h)  # [B, L, hidden_dim]
        logits = self.output_proj(h)  # [B, L, V]

        return {
            "esm": u_base,
            "dit": logits,
            "madsbm": u_base + logits
        }


    # def forward(self, t, xt, attention_mask):
    #     outs = self.embed_model(input_ids=xt, attention_mask=attention_mask, output_hidden_states=True)
    #     h = outs.hidden_states[-1]
    #     t_emb = self.time_embed(t)  # [B, time_dim]
    #
    #     # Transformer head (key_padding_mask=True for pads)
    #     key_padding_mask = (attention_mask == 0)  # (B, L) bool
    #     for dit_block in self.blocks:
    #         h = dit_block(h, t_emb, key_padding_mask=key_padding_mask)
    #
    #     # Final norm + projection to vocab logits
    #     h = self.final_norm(h)  # [B, L, hidden_dim]
    #     logits = self.output_proj(h)  # [B, L, V]
    #     return logits


    # def forward(self, xt, attention_mask, t):
    #     with torch.no_grad():
    #         base_out = self.embed_model(
    #             input_ids=xt,
    #             attention_mask=attention_mask,
    #             output_hidden_states=True
    #         )
    #
    #     logits_base = base_out.logits
    #     h_base = base_out.hidden_states[-1]
    #
    #     norm = self.token_norm_sqrd.view(1, 1, -1)  # 1, 1, V
    #
    #     log_R0 = (self.beta1 * logits_base) - (self.beta2 * norm)
    #
    #     t_emb = self.time_embed(t)  # [B, time_dim]
    #     key_padding_mask = (attention_mask == 0)  # (B, L) bool
    #
    #     h_ctrl = h_base
    #     for dit_block in self.blocks:
    #         h_ctrl = dit_block(h_ctrl, t_emb, key_padding_mask=key_padding_mask)
    #
    #     h_ctrl = self.final_norm(h_ctrl)
    #     u_theta = self.output_proj(h_ctrl)
    #     tot_logits = log_R0 + u_theta
    #
    #     return tot_logits, u_theta
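The AdaLN/FiLM modulation inside `DiTBlock1D` (normalize, then apply a time-conditioned scale and shift) can be illustrated on plain Python lists; the numbers below are purely illustrative:

```python
def film_modulate(h, scale, shift):
    """Apply h * (1 + scale) + shift elementwise, mirroring the DiT block's
    time-conditioned modulation of the normalized hidden states."""
    return [x * (1 + s) + b for x, s, b in zip(h, scale, shift)]

# When scale and shift are both zero, the modulation reduces to the
# identity, so the block initially passes the frozen ESM features through
# unchanged and learns a correction on top.
h = [0.5, -1.0, 2.0]
identity = film_modulate(h, [0.0, 0.0, 0.0], [0.0, 0.0, 0.0])
modulated = film_modulate(h, [1.0, 0.0, -1.0], [0.1, 0.2, 0.3])
```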
src/madsbm/wt_peptide/dataloader.py ADDED
import torch
import pandas as pd
import lightning.pytorch as pl

from omegaconf import OmegaConf
from datasets import load_from_disk
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from functools import partial
from src.utils.model_utils import _print

config = OmegaConf.load('/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml')


# class DNADataset(Dataset):
#     def __init__(self, config, data_path):
#         self.config = config
#         self.data = pd.read_csv(data_path)
#         self.custom_tokenizer = CustomDNATokenizer(config.model.dna_model_path)
#
#     def __len__(self):
#         return len(self.data)
#
#     def __getitem__(self, idx):
#         sequence = self.data.iloc[idx]["Sequence"]
#         seq = sequence.upper()
#
#         tokenized = self.custom_tokenizer(seq, max_length=self.config.data.max_seq_len)
#
#         return {
#             "input_ids": tokenized["input_ids"].squeeze(0),
#             "attention_mask": tokenized["attention_mask"].squeeze(0)
#         }


def collate_fn(batch, pad_id=None):
    # Each stored example is already a pre-tokenized batch, so with
    # batch_size=1 we simply unwrap the single item.
    input_ids = torch.tensor(batch[0]['input_ids'])
    attention_mask = torch.tensor(batch[0]['attention_mask'])
    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask
    }


class PeptideDataModule(pl.LightningDataModule):
    def __init__(self, config, train_dataset, val_dataset, test_dataset, tokenizer, collate_fn=collate_fn):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.tokenizer = tokenizer
        self.collate_fn = collate_fn
        self.batch_size = config.data.batch_size
        assert self.batch_size == 1, f'Batch size = {self.batch_size}. Needs to be 1 for dynamic batching'

    def train_dataloader(self):
        return DataLoader(self.train_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn),
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn),
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)

    def test_dataloader(self):
        return DataLoader(self.test_dataset,
                          batch_size=self.batch_size,
                          collate_fn=partial(self.collate_fn),
                          num_workers=8,
                          shuffle=False,
                          pin_memory=True)


def get_datasets(config):
    """Helper method to grab datasets to quickly init the data module in main.py"""
    train_dataset = load_from_disk(config.data.train)
    test_dataset = load_from_disk(config.data.test)
    val_dataset = load_from_disk(config.data.val)

    return {
        "train": train_dataset,
        "val": val_dataset,
        "test": test_dataset
    }
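The `batch_size: 1` / `accumulate_grad_batches: 16` combination in the config emulates a larger batch: each dataset item is a pre-tokenized batch (hence `collate_fn` unwraps `batch[0]`), and per-item gradients are averaged before each optimizer step. The equivalence can be checked by hand for a scalar least-squares model:

```python
def grad_mse(w, x, y):
    """d/dw of 0.5 * (w*x - y)^2 for a scalar linear model."""
    return (w * x - y) * x

data = [(1.0, 2.0), (2.0, 1.0), (3.0, -1.0), (0.5, 0.0)]
w = 0.7

# Full-batch gradient: mean over all examples at once.
full = sum(grad_mse(w, x, y) for x, y in data) / len(data)

# Accumulated gradient: per-example grads summed with a 1/N weight,
# as gradient accumulation does across micro-batches.
acc = 0.0
for x, y in data:
    acc += grad_mse(w, x, y) / len(data)
```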
src/madsbm/wt_peptide/main.py ADDED
#!/usr/bin/python3

import os
import torch
import wandb
import lightning.pytorch as pl

from omegaconf import OmegaConf
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers import WandbLogger
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor


from src.madsbm.wt_peptide.sbm_module import MadSBM
from src.madsbm.wt_peptide.dataloader import PeptideDataModule, get_datasets

wandb.login()  # reads WANDB_API_KEY from the environment; never hardcode API keys


# Load yaml config
config = OmegaConf.load("/scratch/pranamlab/sgoel/MadSBM/configs/wt_pep.yaml")

# Initialize WandB for logging
wandb.init(project=config.wandb.project, name=config.wandb.name)
wandb_logger = WandbLogger(**config.wandb)

# PL checkpoints
lr_monitor = LearningRateMonitor(logging_interval="step")

every_epoch_cb = ModelCheckpoint(
    dirpath=config.checkpointing.save_dir,
    filename="{epoch:02d}_{step}",
    save_top_k=-1,
    every_n_epochs=1,
    save_on_train_epoch_end=True,
    verbose=True,
)

best_ckpt_cb = ModelCheckpoint(
    monitor="val/loss",
    dirpath=config.checkpointing.save_dir,
    filename="best-model_{epoch:02d}_{step}",
    save_top_k=1,
    mode="min",
    verbose=True,
    save_last=False,
)

# PL trainer
trainer = pl.Trainer(
    # max_steps=None,  # Ensure training is based on epochs so we can compare with MOG-DFM and DirichletFM
    max_epochs=config.training.n_epochs,
    accelerator="cuda" if torch.cuda.is_available() else "cpu",
    devices=config.training.devices if config.training.mode == 'train' else [0],
    strategy=DDPStrategy(find_unused_parameters=True),
    callbacks=[every_epoch_cb, best_ckpt_cb, lr_monitor],
    logger=wandb_logger
)


# Folder to save checkpoints
os.makedirs(config.checkpointing.save_dir, exist_ok=True)

# PL Model for training
sbm_model = MadSBM(config)
sbm_model.validate_config()

# Get datasets
datasets = get_datasets(config)
data_module = PeptideDataModule(
    config=config,
    train_dataset=datasets['train'],
    val_dataset=datasets['val'],
    test_dataset=datasets['test'],
    tokenizer=sbm_model.tokenizer,
)

# Start/resume training or evaluate the model
if config.training.mode == "train":
    trainer.fit(sbm_model, datamodule=data_module)

elif config.training.mode == "test":
    state_dict = sbm_model.get_state_dict(config.checkpointing.best_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    trainer.test(sbm_model, datamodule=data_module, ckpt_path=config.checkpointing.best_ckpt_path)

elif config.training.mode == "resume_from_ckpt":  # matches the mode options listed in the config
    state_dict = sbm_model.get_state_dict(config.checkpointing.resume_ckpt_path)
    sbm_model.load_state_dict(state_dict)
    trainer.fit(sbm_model, datamodule=data_module, ckpt_path=config.checkpointing.resume_ckpt_path)

wandb.finish()
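The scripts above rely on OmegaConf interpolations such as `${base_dir}/checkpoints/wt_pep/${wandb.name}`. A toy resolver illustrates the substitution semantics (OmegaConf itself handles this, including the relative `${.name}` form, which this sketch omits):

```python
import re

def resolve(template, cfg):
    """Substitute ${dotted.key} references from a nested dict; a toy
    stand-in for OmegaConf interpolation (no relative ${.key} support)."""
    def lookup(match):
        node = cfg
        for part in match.group(1).split("."):
            node = node[part]
        return str(node)
    return re.sub(r"\$\{([^}]+)\}", lookup, template)

cfg = {"base_dir": "/scratch/MadSBM", "wandb": {"name": "run1"}}
path = resolve("${base_dir}/checkpoints/wt_pep/${wandb.name}", cfg)
```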
src/madsbm/wt_peptide/sbm_module.py ADDED
import gc
import os
import math
import torch

import lightning as pl
import torch.nn.functional as F

from transformers import AutoModel

from src.madsbm.wt_peptide.control_field import PeptideControlField
from src.PeptiVerse.inference import PeptiVersePredictor
from src.utils.model_utils import CosineWarmup, _print, compute_grad_norms


class MadSBM(pl.LightningModule):
    def __init__(self, config, guidance=None):
        super().__init__()

        self.config = config
        self.model = PeptideControlField(config)
        self.tokenizer = self.model.tokenizer
        self.vocab_size = self.tokenizer.vocab_size

        self.mask_id = self.tokenizer.mask_token_id
        self.pad_id = self.tokenizer.pad_token_id

        self.embed_model = AutoModel.from_pretrained(config.model.esm_model)
        self.embed_model.eval()
        for param in self.embed_model.parameters():
            param.requires_grad = False

        self.beta = 1.0 / self.config.model.hidden_dim

        # self.L = config.data.max_seq_len
        # self.V = self.vocab_size
        # self.log_R0 = - math.log(self.L * self.V)  # uninformed generator is constant

        self.time_schedule = config.time_embed.time_schedule
        self.anneal_frac = config.time_embed.anneal_frac
        self.eps = float(config.time_embed.min_time)
        self.t_max = 1.0 - self.eps


    # -------# Forward Pass #-------- #
    def forward(self, input_ids, attention_mask, t):
        return self.model(xt=input_ids, attention_mask=attention_mask, t=t)

    def step(self, batch):
        x1 = batch['input_ids']
        attn_mask = batch['attention_mask']
        maskable = self.is_maskable(x1)

        t = self.sample_t(x1)
        xt = self.noise_seq(x1, t, maskable_mask=maskable)

        outs = self.forward(xt, attn_mask, t)
        if self.config.model.ablate:
            logits = outs['dit']
        else:
            logits = outs['madsbm']
        max_u_logit = outs['dit'].max().item()
        max_esm_logit = outs['esm'].max().item()

        loss_token = F.cross_entropy(
            logits.view(-1, logits.size(-1)),
            x1.view(-1),
            reduction='none',
            ignore_index=self.pad_id
        )
        loss_token = loss_token.view(x1.size(0), x1.size(1))

        # average per sample over maskable positions only
        sample_loss = (loss_token * maskable.float()).sum(dim=1) / maskable.float().sum(dim=1).clamp(min=1.0)

        loss = sample_loss.mean()
        ppl = torch.exp(loss)

        _print(f'loss: {loss}')
        _print(f'ppl: {ppl}')

        return loss, ppl, max_u_logit, max_esm_logit


    # def step(self, batch):
    #     x1 = batch['input_ids']
    #     attn_mask = batch['attention_mask']
    #     maskable = self.is_maskable(x1)
    #
    #     t = self.sample_t(x1)
    #     xt = self.noise_seq(x1, t, maskable_mask=maskable)
    #
    #     u_theta = self.forward(xt, attn_mask, t)
    #     b, l, v_target = self.compute_target(x1, xt, t, maskable_mask=maskable)
    #     loss, ppl = self.compute_loss(u_theta, v_target, x1, b, l)
    #
    #     _print(f'loss: {loss}')
    #     _print(f'ppl: {ppl}')
    #
    #     return loss, ppl


    # -------# Main Training Logic #-------- #
    def noise_seq(self, x1, t, maskable_mask):
        B, L = x1.shape
        t = t.unsqueeze(1)  # B, 1

        # mask where u < t (larger t => heavier masking), keep otherwise
        u = torch.rand((B, L), device=x1.device)
        masked = (u < t) & maskable_mask

        xt = x1.clone()
        xt = xt.masked_fill(masked, self.mask_id)

        return xt

    # def compute_target(self, x1, xt, t, maskable_mask):
    #     L = x1.size(1)
    #     V = self.vocab_size
    #     device = x1.device
    #
    #     mask = (xt == self.mask_id) & maskable_mask
    #     b, l = torch.nonzero(mask, as_tuple=True)
    #
    #     if b.numel() == 0:
    #         return b, l, torch.empty(0, device=device, dtype=torch.long)
    #
    #     log_R0 = - math.log(L * V)  # uniform generator with rates (1 / L*V)
    #     time = - torch.log(1 - t[b])
    #
    #     v_target = time - log_R0  # log(1/(1-t)) - log(1/(L*V))
    #     v_target = v_target.clamp(min=-100.0, max=100.0)
    #
    #     return b, l, v_target


    # def compute_loss(self, u_theta, v_target, x1, b, l):
    #     if b.numel() == 0:
    #         dummy_loss = 0.0 * u_theta.sum()
    #         return dummy_loss, torch.tensor(0.0, device=u_theta.device)
    #
    #     true_toks = x1[b, l]
    #     u_pred = u_theta[b, l, :]  # N_masks, V
    #
    #     tgt = torch.zeros_like(u_pred)
    #     tgt.scatter_(1, true_toks.unsqueeze(1), v_target.unsqueeze(1))
    #
    #     sse = F.mse_loss(u_pred, tgt, reduction='sum')
    #     loss = sse / b.numel()  # normalize by number of masks (b.numel() > 0 here)
    #
    #     with torch.no_grad():
    #         ppl = torch.exp(F.cross_entropy(u_pred, true_toks))
    #
    #     return loss, ppl


    # -------# Time Schedules #-------- #
    def sample_t(self, x1):
        ts = self.time_schedule
        if ts == 'linear':
            return self.sample_linear_t(x1)
        elif ts == 'exponential':
            return self.sample_exp_t(x1)
        elif ts == 'uniform':
            return self.sample_uni_t(x1)
        else:
            raise ValueError(f"Unrecognized time scheduler type: {ts}")

    def sample_uni_t(self, x1):
        B = x1.size(0)
        T = self.config.time_embed.n_timesteps

        discrete_ts = torch.randint(1, T + 1, (B,), device=x1.device)
        timesteps = discrete_ts.float() / float(T)
        _print(f'timesteps: {timesteps}')
        return timesteps.clamp(min=self.eps, max=self.t_max)


    def sample_linear_t(self, x1):
        B = x1.size(0)
        eps = self.eps

        # fraction of total training steps completed
        frac = float(self.global_step) / float(self.tot_steps)
        t_max = 1.0 - eps

        if frac < self.anneal_frac:
            # normalize progress within the anneal window
            prog = frac / max(1e-12, self.anneal_frac)  # maps [0, anneal_frac) to [0, 1)
            t_min = eps + prog * (t_max - eps)  # linear increase from eps to 1.0 - eps
            t = t_min + (t_max - t_min) * torch.rand(B, device=x1.device)
        else:
            # after anneal_frac of training steps, sample uniformly over the entire range [eps, 1.0 - eps]
            t = eps + (t_max - eps) * torch.rand(B, device=x1.device)

        return t.clamp(min=eps, max=t_max)


    def sample_exp_t(self, x1, t_min=1e-6, t_max=1.0 - 1e-6):
        """
        Exponentially anneal the center of t from t_min to t_max over training.

        Implement if the linear schedule isn't expressive enough.
        For annealing over training steps, which can be a very large quantity,
        the exponential approximates the linear schedule.
        """
        # k controls how fast the curve rises.
        k = self.config.training.exp_time_k
        progress = float(self.global_step) / float(self.tot_steps)
        frac = 1.0 - math.exp(-k * progress)
        center = t_min + frac * (t_max - t_min)

        # add small jitter so we don't collapse onto a single point
        t = torch.randn(x1.size(0), device=x1.device) * self.config.training.time_sigma + center
        return t.clamp(min=t_min, max=t_max)


    # -------# Model Training / Evaluation #-------- #
    def training_step(self, batch):
        loss, ppl, _, _ = self.step(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log("train/ppl", ppl, on_step=True, on_epoch=False, prog_bar=False)
        return loss

    def validation_step(self, batch):
        loss, ppl, _, _ = self.step(batch)
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("val/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def test_step(self, batch):
        loss, ppl, max_u, max_esm = self.step(batch)
        self.log('test/loss', loss, on_step=False, on_epoch=True, prog_bar=True, sync_dist=True)
        self.log("test/ppl", ppl, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_madsbm_logit", max_u, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        self.log("test/max_esm_logit", max_esm, on_step=False, on_epoch=True, prog_bar=False, sync_dist=True)
        return loss

    def on_after_backward(self):
        pre_norm = compute_grad_norms(self.parameters())
        self.log('train/grad_norm_PRE_clip', pre_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)

        # torch.nn.utils.clip_grad_norm_(self.parameters(), float(self.config.training.grad_clip_val))
        # post_norm = compute_grad_norms(self.parameters())
        # self.log('train/grad_norm_POST_clip', post_norm, on_step=True, on_epoch=False, prog_bar=False, sync_dist=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            params=self.model.parameters(),
            lr=self.config.optim.lr,
            weight_decay=self.config.optim.weight_decay,
            betas=(self.config.optim.beta1, self.config.optim.beta2)
        )

        self.tot_steps = self.trainer.estimated_stepping_batches
        warmup_steps = int(self.config.optim.warmup_epochs * self.tot_steps / self.config.training.n_epochs)

        lr_scheduler = CosineWarmup(
            optimizer=optimizer,
            warmup_steps=warmup_steps,
            total_steps=self.tot_steps
        )

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": lr_scheduler,
                "interval": "step",
                "frequency": 1
            }
        }

    def on_save_checkpoint(self, checkpoint: dict):
        """
        Don't save the classifier model used for FBD calculation in the ckpt
        """
        sd = checkpoint.get('state_dict', None)
        if sd is None:
            return
        keys_to_remove = [k for k in sd.keys() if k.startswith("score_model.")]
        for k in keys_to_remove:
            sd.pop(k, None)
        checkpoint['state_dict'] = sd


    # -------# Helper methods #-------- #
    def is_maskable(self, input_ids: torch.Tensor):
        return (
            (input_ids != self.tokenizer.pad_token_id)
            & (input_ids != self.tokenizer.cls_token_id)
            & (input_ids != self.tokenizer.eos_token_id)
        )

    def validate_config(self):
        assert os.path.isdir(self.config.checkpointing.save_dir), "invalid checkpointing path"
        assert self.config.model.hidden_dim % 2 == 0, 'odd value for embedding dim'
        assert self.config.time_embed.time_dim % 2 == 0, 'odd value for time dim'
        assert self.config.time_embed.fourier_dim % 2 == 0, 'odd value for fourier dim'

    def get_state_dict(self, ckpt_path):
        def remove_model_prefix(state_dict):
            # str.replace returns a new string, so the keys must be rebuilt
            # into a new dict for the prefix strip to take effect
            return {
                (k[len('model.'):] if k.startswith('model.') else k): v
                for k, v in state_dict.items()
            }

        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)

        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)

        return state_dict

    def cleanup(self):
        torch.cuda.empty_cache()
        gc.collect()
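The forward corruption in `noise_seq` can be mimicked with plain Python: each maskable position is masked independently with probability `t`, so special tokens always survive and larger `t` means heavier corruption. `MASK` and the special-token strings here are illustrative placeholders for the tokenizer's actual ids:

```python
import random

MASK = "<mask>"
SPECIALS = {"<cls>", "<eos>", "<pad>"}

def noise_seq(tokens, t, rng):
    """Mask each non-special token independently with probability t,
    mirroring the (u < t) & maskable_mask logic in MadSBM.noise_seq."""
    return [MASK if (tok not in SPECIALS and rng.random() < t) else tok
            for tok in tokens]

rng = random.Random(0)
seq = ["<cls>"] + list("MKTLLILAVL") + ["<eos>"]
fully_masked = noise_seq(seq, 1.0, rng)   # t=1: every maskable position masked
untouched = noise_seq(seq, 0.0, rng)      # t=0: sequence returned unchanged
```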
src/sampling/diffusion_sampler.py ADDED
+ import random
+
+ import torch
+ import numpy as np
+
+
+ class DiffusionSampler:
+     def __init__(self, model, tokenizer):
+         self.model = model
+         self.tokenizer = tokenizer
+
+         self.device = self.model.device
+         self.mask_id = self.tokenizer.mask_token_id
+         self.seed_everything(seed=42)
+
+     @torch.inference_mode()
+     def sample_unconditional(self, xt, num_steps, tracer=None, tau=1.0, kappa_fn=lambda t: t, eta=1, alpha=1.):
+         """
+         Stochastic remasking sampling method for iterative refinement of sequences.
+         Args:
+             xt (Tensor): Initial token tensor.
+             num_steps (int): Number of refinement steps.
+             tracer (optional): Logger for intermediate sampling states.
+             tau (float): Temperature parameter for softmax sampling.
+             kappa_fn (callable): Function controlling the unmasking schedule.
+             eta (float): Scaling factor for score adjustments.
+             alpha (float): Weighting for confidence-based scoring.
+         Returns:
+             Tensor: Final sampled sequence tensor.
+         """
+         dt = 1 / num_steps
+         fix_mask = xt != self.mask_id  # tokens to retain
+         attention_mask = torch.ones_like(xt).to(self.device)
+
+         if tracer:
+             tracer.log_step(xt=xt, step_idx=0)
+
+         for i in range(1, num_steps + 1):
+             kappa_t = kappa_fn(i * dt)
+             logits = self.model(input_ids=xt, attention_mask=attention_mask).logits
+             last_mask = xt == self.mask_id  # tokens currently masked
+             unmask_t = ~last_mask & ~fix_mask  # unmasked, non-fixed tokens - candidates for remasking
+
+             x0, logp = self.stochastic_sample_from_categorical(logits, tau)  # tokens, logprobs
+
+             # Confidence-based scoring
+             entropy = torch.distributions.Categorical(logits=logits).entropy()
+             score = alpha * logp + (1 - alpha) * -entropy  # alpha = 1 --> score = logp
+             score = score.masked_fill(fix_mask, float('inf'))
+             score[unmask_t] = score[unmask_t] * eta
+
+             num_to_mask = ((~fix_mask).sum(1, keepdim=True).float() * (1 - kappa_t)).long()
+             lowest_k_mask = self.topk_lowest_masking(score, num_to_mask)
+
+             xt[lowest_k_mask] = self.mask_id
+             mask_2_x0 = last_mask & ~lowest_k_mask
+             xt[mask_2_x0] = x0[mask_2_x0]
+
+             if tracer:
+                 tracer.log_step(xt=xt, step_idx=i)
+
+         xt[xt == self.mask_id] = x0[xt == self.mask_id]
+
+         if tracer:
+             tracer.log_step(xt, num_steps + 1)
+
+         return xt
+
+     def stochastic_sample_from_categorical(self, logits, temperature, noise_scale=1.0):
+         """
+         Sample from a categorical distribution with optional temperature scaling and Gumbel noise.
+         """
+         logits = logits.double()
+         if temperature != 0:
+             gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits) + 1e-8) + 1e-8)
+             logits = logits / temperature + noise_scale * gumbel_noise
+         scores, tokens = logits.log_softmax(dim=-1).max(dim=-1)
+         return tokens, scores
+
+     def topk_lowest_masking(self, scores, cutoff_len):
+         """
+         scores: [b, n]
+         cutoff_len: [b, 1]
+         returns:
+             mask: [b, n], True for the cutoff_len lowest-scoring tokens, False otherwise
+         """
+         sorted_scores = scores.sort(-1)[0]
+         cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)
+         return scores < cutoff
+
+     def seed_everything(self, seed):
+         """
+         Set the seed for reproducibility across various libraries.
+         """
+         if seed is None:
+             return
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed)
+             torch.cuda.manual_seed_all(seed)  # if using multi-GPU
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
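The remasking step above keeps only the highest-confidence tokens each iteration and re-masks the rest. A minimal standalone sketch of the `topk_lowest_masking` selection (toy scores assumed):

```python
import torch

def topk_lowest_masking(scores, cutoff_len):
    # True for the cutoff_len lowest-scoring positions in each row
    sorted_scores = scores.sort(-1)[0]
    cutoff = sorted_scores.gather(dim=-1, index=cutoff_len)
    return scores < cutoff

scores = torch.tensor([[0.9, 0.1, 0.5, 0.3]])
mask = topk_lowest_masking(scores, torch.tensor([[2]]))
# the two lowest scores (0.1 and 0.3) are selected for re-masking
```

Because fixed tokens are scored `inf` in the sampler, they can never fall into this lowest-k set.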
src/sampling/guided_sample.py ADDED
@@ -0,0 +1,121 @@
+ #!/usr/bin/env python3
+
+ import os
+ import torch
+ import pandas as pd
+
+ from tqdm import tqdm
+ from omegaconf import OmegaConf
+ from datetime import datetime
+
+ from transformers import AutoModelForMaskedLM
+
+ from src.madsbm.wt_peptide.sbm_module import MadSBM
+ from src.sampling.madsbm_sampler import MadSBMSampler
+
+ from src.utils.generate_utils import mask_for_de_novo, calc_ppl
+ from src.utils.model_utils import _print
+
+
+ device = torch.device('cuda:3' if torch.cuda.is_available() else 'cpu')
+ os.chdir('/scratch/pranamlab/sgoel/MadSBM')
+ config = OmegaConf.load("./configs/wt_pep.yaml")
+
+ date = datetime.now().strftime("%Y-%m-%d")
+
+
+ def generate_sequence(masked_seq, target_toks, tokenizer, generator, device):
+     input_ids = tokenizer(masked_seq, return_tensors="pt").to(device)['input_ids']
+
+     uncond_ids, uncond_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, tracer=None, target_toks=target_toks, guidance=False)
+     guided_ids, guided_bind = generator.sample(xt=input_ids, num_steps=config.sampling.n_steps, tracer=None, target_toks=target_toks, guidance=True)
+
+     # Strip bos/eos tokens and the spaces between residues
+     uncond_seq = tokenizer.decode(uncond_ids[0].squeeze())[5:-5].replace(" ", "")
+     guided_seq = tokenizer.decode(guided_ids[0].squeeze())[5:-5].replace(" ", "")
+
+     return uncond_seq, guided_seq, uncond_bind, guided_bind
+
+
+ def main():
+     csv_save_path = './results/guided/'
+     os.makedirs(csv_save_path, exist_ok=True)
+
+     # Load ESM model for eval
+     esm_pth = config.model.esm_model
+     esm_model = AutoModelForMaskedLM.from_pretrained(esm_pth).to(device)
+     esm_model.eval()
+
+     # Load SBM model
+     gen_model = MadSBM(config)
+     state_dict = gen_model.get_state_dict(config.checkpointing.best_ckpt_path)
+     gen_model.load_state_dict(state_dict)
+     gen_model.to(device)
+     gen_model.eval()
+     tokenizer = gen_model.tokenizer
+     generator = MadSBMSampler(gen_model, config, device, guidance=True)
+
+     tgt_name = "3HVE"
+     df = pd.read_csv("./data/wt_pep/targets.csv")
+     tgt_seq = df.loc[df['Target'] == tgt_name, 'Sequence'].iloc[0]
+     target_toks = tokenizer(tgt_seq, return_tensors='pt')['input_ids'].to(device)
+
+     existing_binder = df.loc[df['Target'] == tgt_name, 'Existing Binder'].iloc[0]
+     existing_binder_pred = generator.peptiverse.predict_binding_affinity(
+         mode='wt',
+         target_ids=target_toks,
+         binder_ids=tokenizer(existing_binder, return_tensors='pt')['input_ids'].to(device).detach()
+     )['affinity']
+
+     _print(f'EXISTING BINDER AFFINITY: {existing_binder_pred}')
+
+     seq_lengths = [length for length in [10, 15, 20] for _ in range(20)]
+     generation_results = []
+
+     for seq_len in tqdm(seq_lengths, desc="Generating sequences"):
+         masked_seq = mask_for_de_novo(seq_len)  # Sequence of all <mask> tokens
+         uncond_seq, guided_seq, uncond_bind, guided_bind = generate_sequence(masked_seq, target_toks, tokenizer, generator, device)
+
+         uncond_ppl = calc_ppl(esm_model, tokenizer, uncond_seq, list(range(len(uncond_seq))), model_type='esm')
+         guided_ppl = calc_ppl(esm_model, tokenizer, guided_seq, list(range(len(guided_seq))), model_type='esm')
+
+         _print(f'uncond seq: {uncond_seq}')
+         _print(f'uncond ppl: {uncond_ppl}')
+         _print(f'uncond bind: {uncond_bind}')
+
+         _print(f'guided seq: {guided_seq}')
+         _print(f'guided ppl: {guided_ppl}')
+         _print(f'guided bind: {guided_bind}')
+
+         res_row = {
+             "Uncond Generated Sequence": uncond_seq,
+             "Guided Generated Sequence": guided_seq,
+             "Uncond PPL": uncond_ppl,
+             "Guided PPL": guided_ppl,
+             "Uncond Affinity": uncond_bind,
+             "Guided Affinity": guided_bind
+         }
+         generation_results.append(res_row)
+
+     df = pd.DataFrame(generation_results)
+
+     _print(f"Uncond PPL Res: {df['Uncond PPL'].mean()}, {df['Uncond PPL'].std()}")
+     _print(f"Guided PPL Res: {df['Guided PPL'].mean()}, {df['Guided PPL'].std()}")
+
+     _print(f"Uncond Affinity Res: {df['Uncond Affinity'].mean()}, {df['Uncond Affinity'].std()}")
+     _print(f"Guided Affinity Res: {df['Guided Affinity'].mean()}, {df['Guided Affinity'].std()}")
+
+     os.makedirs(csv_save_path + tgt_name, exist_ok=True)
+     df.to_csv(
+         csv_save_path + f"{tgt_name}/tau=0.5_topp=0.9_no-gumbel_rate=0.01_jump=0.05_ablate={config.model.ablate}_nsteps={config.sampling.n_steps}_seqs_with_ppl_{date}.csv",
+         index=False
+     )
+
+
+ if __name__ == "__main__":
+     main()
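The guided branch of this script relies on the sampler's softmax-weighted selection among candidate binders: candidates with higher predicted affinity get exponentially more probability mass as the temperature shrinks. A toy sketch with made-up affinity values:

```python
import torch
import torch.nn.functional as F

# Hypothetical predicted affinities for M = 4 candidate binders (higher = better)
affinities = torch.tensor([0.2, 1.5, 0.9, 2.1])
tau = 0.5

weights = F.softmax(affinities / tau, dim=0)  # temperature-sharpened preference
best = torch.argmax(weights).item()           # the strongest candidate dominates
```

With `tau --> 0` this degenerates to greedy argmax selection; larger `tau` keeps more diversity among candidates.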
src/sampling/madsbm_sampler.py ADDED
@@ -0,0 +1,198 @@
+ import random
+
+ import torch
+ import numpy as np
+
+ import torch.nn.functional as F
+
+ from src.PeptiVerse.inference import PeptiVersePredictor
+
+
+ class MadSBMSampler:
+     def __init__(self, model, config, device, guidance=None):
+         self.config = config
+         self.device = device
+         self.model = model
+         self.tokenizer = model.tokenizer
+         self.mask_id = self.tokenizer.mask_token_id
+         self.eps = config.time_embed.min_time
+         self.guidance = guidance
+         self.seed_everything(seed=42)
+
+         if guidance:
+             self.peptiverse = PeptiVersePredictor(
+                 manifest_path="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse/best_models.txt",
+                 classifier_weight_root="/scratch/pranamlab/sgoel/MadSBM/src/PeptiVerse",
+                 device=self.device
+             )
+
+     @torch.inference_mode()
+     def sample(self, xt, num_steps, tracer=None, target_toks=None, guidance=None):
+         xt = xt.clone()
+         B, L = xt.shape
+         assert B == 1, "Sample only 1 sequence at a time"
+
+         t_max = 1.0 - self.eps
+         dt = 1.0 / num_steps
+         attn_mask = torch.ones_like(xt, device=self.device)
+
+         action_traj = {}
+         tot_action = 0.0
+
+         if tracer:
+             tracer.log_step(xt=xt, step_idx=0)
+
+         converge_idx = num_steps
+         converged = False
+
+         for k in range(num_steps):
+             # t decreases from 1 --> 0, as the model was trained with t=1 --> noise and t=0 --> clean
+             prog = (k + 1) / float(num_steps)
+             t_val = t_max - (t_max - self.eps) * prog
+             t = torch.full((B,), fill_value=float(t_val), device=self.device)  # B = 1 during sampling
+
+             # Predicted control field --> B, L, V
+             outs = self.model(input_ids=xt, attention_mask=attn_mask, t=t)
+
+             u_tilt = outs['dit']
+             total_logits = outs['madsbm']
+             esm_logits = outs['esm']
+
+             if self.config.model.ablate:
+                 action_val = self.compute_action(u_tilt, esm_logits=None)
+             else:
+                 action_val = self.compute_action(u_tilt, esm_logits=esm_logits)
+
+             action_traj[f"action_step_{k+1}"] = action_val
+             tot_action += (action_val * dt)
+
+             # Compute jump rates and jump probs: P(jump) = 1 - exp(-rate * dt)
+             r_theta = torch.exp(u_tilt * self.config.sampling.rate_scale)
+             R_tot = r_theta.sum(dim=-1)  # 1, L
+             rate = (-R_tot * self.config.sampling.jump_scale * dt).clamp(min=-40.0, max=0.0)
+             jump_prob = 1.0 - torch.exp(rate)
+
+             # Scale and filter logits with nucleus sampling
+             logits = total_logits.clone()
+             logits /= self.config.sampling.tau
+             logits = self.top_p_filter(logits, self.config.sampling.top_p)
+
+             # Sample new tokens
+             probs = F.softmax(logits, dim=-1)
+             probs = probs.view(-1, probs.size(-1))
+             sample = torch.multinomial(probs, 1)
+             candidate_toks = sample.view(B, L)
+
+             # Determine the tokens we can change
+             rand = torch.rand(B, L, device=self.device)
+             can_jump = (rand < jump_prob)
+             updatable = can_jump & self.is_masked(xt)
+
+             # Update the sequence
+             if guidance:
+                 chosen_candidate = self.binding_guidance(probs, target_toks, B, L)
+                 xt[updatable] = chosen_candidate[updatable]
+             else:
+                 xt[updatable] = candidate_toks[updatable]
+
+             if tracer:
+                 tracer.log_step(xt=xt, step_idx=k + 1)
+
+             if k == num_steps - 1:
+                 final_logits = total_logits
+                 still_masked = self.is_masked(xt)
+
+             if not converged and not self.is_masked(xt).any():
+                 converge_idx = k + 1
+                 converged = True
+
+         # Greedily fill any remaining masked tokens
+         if still_masked.any():
+             final_toks = final_logits.argmax(dim=-1)
+             xt[still_masked] = final_toks[still_masked]
+
+         if tracer:
+             tracer.log_step(xt, num_steps + 1)
+
+         binding_affin = self.peptiverse.predict_binding_affinity(
+             mode='wt',
+             target_ids=target_toks,
+             binder_ids=xt
+         )['affinity']
+
+         return xt, binding_affin
+
+     def binding_guidance(self, probs, target_toks, B, L):
+         M = self.config.sampling.M
+         candidate_toks = []
+         affinities = []
+
+         for _ in range(M):
+             ith_sample = torch.multinomial(probs, 1).view(B, L)
+             candidate_toks.append(ith_sample)
+
+         for toks in candidate_toks:
+             pred = self.peptiverse.predict_binding_affinity(
+                 mode='wt',
+                 target_ids=target_toks,
+                 binder_ids=toks.detach()
+             )['affinity']
+             affinities.append(pred)
+
+         affinities = torch.tensor(affinities, dtype=torch.float32)
+         weights = F.softmax(affinities / self.config.sampling.tau, dim=0)
+         chosen_idx = torch.multinomial(weights, 1).item()
+
+         return candidate_toks[chosen_idx]
+
+     def compute_action(self, u_tilt, esm_logits=None):
+         """ Computes the action functional for evals """
+         if esm_logits is not None:
+             R0 = torch.softmax(esm_logits, dim=-1)  # reference rates from ESM
+         else:
+             R0 = 1.0 / self.tokenizer.vocab_size  # uniform reference
+
+         psi_u = torch.exp(u_tilt) - u_tilt - 1.0
+         action_per_tok = (R0 * psi_u).sum(dim=-1)  # R0 sums to 1 over the vocab in both cases
+
+         return action_per_tok.mean().item()
+
+     def top_p_filter(self, logits, p_val):
+         """
+         Implementation of nucleus / top-p sampling.
+         Masks out tokens that contribute to the bottom (1 - p) cumulative probability.
+         """
+         # Sort logits and get cumulative probabilities
+         sorted_logits, sorted_indices = torch.sort(logits, descending=True)
+         cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
+
+         # Remove tokens with cumulative prob above the p-val threshold
+         sorted_idx_to_remove = cum_probs > p_val
+
+         # Shift the indices to the right to also keep the first token above the threshold
+         sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
+         sorted_idx_to_remove[..., 0] = 0
+
+         idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
+         logits[idx_to_remove] = float('-inf')
+         return logits
+
+     def is_masked(self, xt):
+         return (xt == self.mask_id)
+
+     def seed_everything(self, seed):
+         if seed is None:
+             return
+         random.seed(seed)
+         np.random.seed(seed)
+         torch.manual_seed(seed)
+         if torch.cuda.is_available():
+             torch.cuda.manual_seed(seed)
+             torch.cuda.manual_seed_all(seed)  # if using multi-GPU
+         torch.backends.cudnn.deterministic = True
+         torch.backends.cudnn.benchmark = False
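The nucleus-sampling filter above can be exercised standalone. A minimal sketch with a toy logit row, assuming the same shift-right trick to keep the first token that crosses the threshold:

```python
import torch
import torch.nn.functional as F

def top_p_filter(logits, p_val):
    # Mask tokens outside the smallest set whose cumulative probability exceeds p_val
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    sorted_idx_to_remove = cum_probs > p_val
    sorted_idx_to_remove[..., 1:] = sorted_idx_to_remove[..., :-1].clone()
    sorted_idx_to_remove[..., 0] = 0
    idx_to_remove = sorted_idx_to_remove.scatter(-1, sorted_indices, sorted_idx_to_remove)
    return logits.masked_fill(idx_to_remove, float('-inf'))

logits = torch.tensor([[2.0, 1.0, 0.1, -1.0]])
filtered = top_p_filter(logits, 0.9)
# only the lowest-probability token is pushed to -inf at p = 0.9
```

Tokens set to `-inf` receive zero probability after the subsequent softmax, so they can never be drawn by `torch.multinomial`.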
src/sampling/path_tracer.py ADDED
@@ -0,0 +1,46 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ class ProbabilityPathTracer:
+     def __init__(self, oracle_model, tokenizer, device):
+         self.oracle = oracle_model
+         self.tokenizer = tokenizer
+         self.device = device
+         self.mask_id = tokenizer.mask_token_id
+         self.history = {}  # {nth_step: prob_score}
+
+     @torch.inference_mode()
+     def compute_loglikeli(self, xt):
+         is_revealed = (xt != self.mask_id)
+
+         if not is_revealed.any():
+             return 0.0
+
+         # ESM forward pass
+         logits = self.oracle(
+             input_ids=xt,
+             attention_mask=torch.ones_like(xt, device=xt.device)
+         ).logits
+
+         # Calculate CE loss only on unmasked tokens
+         nll = F.cross_entropy(
+             logits.view(-1, logits.size(-1)),
+             xt.view(-1),
+             reduction='none'
+         )
+         nll = nll.view(xt.shape)
+
+         # Lower NLL = better --> higher LL = better
+         avg_ll = -(nll * is_revealed.float()).sum(dim=1) / is_revealed.float().sum(dim=1).clamp(min=1)
+
+         return avg_ll.item()
+
+     def log_step(self, xt, step_idx):
+         score = self.compute_loglikeli(xt)
+         self.history[f"trace_step_{step_idx}"] = score
+
+     def get_trace(self):
+         return self.history
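The tracer's masked log-likelihood reduces to a cross-entropy averaged over revealed positions only. A self-contained sketch with toy logits standing in for the oracle's output:

```python
import torch
import torch.nn.functional as F

# Toy logits for a 1 x 3 sequence over a 4-token vocab; position 2 is still masked
logits = torch.tensor([[[4.0, 0.0, 0.0, 0.0],
                        [0.0, 4.0, 0.0, 0.0],
                        [1.0, 1.0, 1.0, 1.0]]])
xt = torch.tensor([[0, 1, 3]])   # token id 3 plays the role of <mask> here
is_revealed = xt != 3

nll = F.cross_entropy(logits.view(-1, 4), xt.view(-1), reduction='none').view(xt.shape)
# Average log-likelihood over revealed positions only (masked position contributes nothing)
avg_ll = -(nll * is_revealed.float()).sum(dim=1) / is_revealed.float().sum(dim=1).clamp(min=1)
```

Since the two revealed positions are confidently predicted, the average log-likelihood sits just below zero.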
src/utils/__init__.py ADDED
File without changes
src/utils/eval_utils.py ADDED
@@ -0,0 +1,27 @@
+ import torch
+ import numpy as np
+ from scipy.linalg import sqrtm
+
+
+ def dna_to_tensor(seq):
+     mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
+     indices = [mapping[base] for base in seq]
+     return torch.tensor(indices, dtype=torch.long)
+
+
+ def compute_fbd(true_seqs, gen_seqs, score_model):
+     """
+     The Frechet Biological Distance (FBD) is the 2-Wasserstein distance between
+     Gaussians fit to the true and generated sequence embeddings.
+     """
+     embeds1 = score_model(true_seqs)  # embeddings of the true sequences
+     embeds2 = score_model(gen_seqs)   # embeddings of the generated sequences
+
+     if np.isnan(embeds2).any() or np.isnan(embeds1).any() or len(embeds1) == 0 or len(embeds2) == 0:
+         return float('nan')
+     mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
+     mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
+     ssdiff = np.sum((mu1 - mu2) ** 2.0)
+     covmean = sqrtm(sigma1.dot(sigma2))
+     if np.iscomplexobj(covmean):
+         covmean = covmean.real
+     dist = ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)
+     return dist
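The Frechet distance computation can be checked on synthetic embeddings: identical sets should score ~0, and a pure mean shift should be dominated by the squared mean difference. A sketch with randomly generated embeddings:

```python
import numpy as np
from scipy.linalg import sqrtm

def frechet_distance(embeds1, embeds2):
    # 2-Wasserstein distance between Gaussians fit to each embedding set
    mu1, sigma1 = embeds1.mean(axis=0), np.cov(embeds1, rowvar=False)
    mu2, sigma2 = embeds2.mean(axis=0), np.cov(embeds2, rowvar=False)
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 8))
d_same = frechet_distance(x, x)          # identical sets -> distance ~ 0
d_shift = frechet_distance(x, x + 5.0)   # mean shift of 5 per dim -> ssdiff ~ 8 * 25 = 200
```

The shift case shows why FBD is sensitive to systematic biases in generated sequences, not just per-sample quality.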
src/utils/fbd_score_model.py ADDED
@@ -0,0 +1,104 @@
+ import copy
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from src.utils.time_utils import GaussianFourierProjection
+
+
+ class Dense(nn.Module):
+     """
+     A fully connected layer that reshapes outputs to feature maps.
+     """
+     def __init__(self, input_dim, output_dim):
+         super().__init__()
+         self.dense = nn.Linear(input_dim, output_dim)
+
+     def forward(self, x):
+         return self.dense(x)
+
+
+ class Swish(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x):
+         return torch.sigmoid(x) * x
+
+
+ class CNNClassifier(nn.Module):
+     def __init__(self, args, alphabet_size, num_cls, classifier=False):
+         super().__init__()
+         self.alphabet_size = alphabet_size
+         self.args = args
+         self.classifier = classifier
+         self.num_cls = num_cls
+
+         if args.clean_data:
+             self.linear = nn.Embedding(self.alphabet_size, embedding_dim=args.hidden_dim)
+         else:
+             # Noisy inputs arrive as (B, L, V) probability vectors, so map channels with a conv
+             self.linear = nn.Conv1d(self.alphabet_size, args.hidden_dim, kernel_size=9, padding=4)
+             # NOTE: the time embedder was missing from the original; a Fourier
+             # projection sized to hidden_dim is assumed here (scale is a guess)
+             self.time_embedder = GaussianFourierProjection(args.hidden_dim, scale=30.0)
+
+         self.num_layers = 5 * args.num_cnn_stacks
+         self.convs = [
+             nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
+             nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, padding=4),
+             nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=4, padding=16),
+             nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=16, padding=64),
+             nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=9, dilation=64, padding=256)
+         ]
+         self.convs = nn.ModuleList([copy.deepcopy(layer) for layer in self.convs for _ in range(args.num_cnn_stacks)])
+         self.time_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
+         self.norms = nn.ModuleList([nn.LayerNorm(args.hidden_dim) for _ in range(self.num_layers)])
+         self.final_conv = nn.Sequential(
+             nn.Conv1d(args.hidden_dim, args.hidden_dim, kernel_size=1),
+             nn.ReLU(),
+             nn.Conv1d(args.hidden_dim, args.hidden_dim if classifier else self.alphabet_size, kernel_size=1)
+         )
+         self.dropout = nn.Dropout(args.dropout)
+
+         if classifier:
+             self.cls_head = nn.Sequential(
+                 nn.Linear(args.hidden_dim, args.hidden_dim),
+                 nn.ReLU(),
+                 nn.Linear(args.hidden_dim, self.num_cls)
+             )
+
+         if self.args.cls_free_guidance and not self.classifier:
+             self.cls_embedder = nn.Embedding(num_embeddings=self.num_cls + 1, embedding_dim=args.hidden_dim)
+             self.cls_layers = nn.ModuleList([Dense(args.hidden_dim, args.hidden_dim) for _ in range(self.num_layers)])
+
+     def forward(self, seq, t, cls=None, return_embedding=False):
+         if self.args.clean_data:
+             feat = self.linear(seq)
+             feat = feat.permute(0, 2, 1)
+         else:
+             time_emb = F.relu(self.time_embedder(t))
+             feat = seq.permute(0, 2, 1)
+             feat = F.relu(self.linear(feat))
+
+         if self.args.cls_free_guidance and not self.classifier and cls is not None:
+             cls_emb = self.cls_embedder(cls)
+
+         for i in range(self.num_layers):
+             h = self.dropout(feat.clone())
+             if not self.args.clean_data:
+                 h = h + self.time_layers[i](time_emb)[:, :, None]
+             if self.args.cls_free_guidance and not self.classifier and cls is not None:
+                 h = h + self.cls_layers[i](cls_emb)[:, :, None]
+             h = self.norms[i](h.permute(0, 2, 1))
+             h = F.relu(self.convs[i](h.permute(0, 2, 1)))
+             if h.shape == feat.shape:
+                 feat = h + feat  # residual connection when shapes match
+             else:
+                 feat = h
+
+         feat = self.final_conv(feat)
+         feat = feat.permute(0, 2, 1)
+
+         if self.classifier:
+             feat = feat.mean(dim=1)
+             if return_embedding:
+                 embedding = self.cls_head[:1](feat)
+                 return self.cls_head[1:](embedding), embedding
+             else:
+                 return self.cls_head(feat)
+
+         return feat
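Each dilated conv in the stack above uses `padding = dilation * (kernel_size - 1) / 2`, which is exactly what keeps the sequence length unchanged while the receptive field grows. A quick shape check (toy channel/length sizes assumed):

```python
import torch
import torch.nn as nn

# dilation=4, kernel_size=9 -> padding = 4 * (9 - 1) / 2 = 16 preserves length
conv = nn.Conv1d(16, 16, kernel_size=9, dilation=4, padding=16)
x = torch.randn(2, 16, 50)   # (batch, channels, length)
y = conv(x)
```

The same arithmetic explains the `padding=64` and `padding=256` choices for dilations 16 and 64.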
src/utils/generate_utils.py ADDED
@@ -0,0 +1,118 @@
+ import math
+ import torch
+
+ import numpy as np
+ import torch.nn.functional as F
+
+ from collections import Counter
+ from omegaconf import OmegaConf
+
+
+ config = OmegaConf.load("/scratch/pranamlab/sgoel/MeMDLM_v2/src/configs/lm.yaml")
+
+
+ # -------# Masking #-------- #
+ def mask_for_de_novo(sequence_length):
+     return "<mask>" * sequence_length
+
+
+ def mask_for_scaffold(sequence, generate_type, mask_token):
+     if generate_type == "uppercase":
+         sequence = ''.join([mask_token if residue.isupper() else residue.upper() for residue in sequence])
+     elif generate_type == "lowercase":
+         sequence = ''.join([mask_token if residue.islower() else residue for residue in sequence])
+     return sequence
+
+
+ # -------# Generation #-------- #
+ def evodiff_infill(motif_seq, tokenizer, model, device, batch_size=1):
+     """
+     Following the given evodiff example:
+     https://github.com/microsoft/evodiff/blob/main/examples/evodiff.ipynb
+     """
+     # Manually mask the infilling sequence (the mask token is "#" in the evodiff tokenizer)
+     motif_seq = ''.join(["#" if aa.islower() else aa for aa in motif_seq])
+     tkns = tokenizer.tokenize([motif_seq])
+     sample = torch.as_tensor(tkns).to(device)
+
+     # Create input motif + scaffold
+     loc = torch.arange(0, len(motif_seq)).to(device)[sample == tokenizer.mask_id].cpu().numpy()
+     np.random.shuffle(loc)
+
+     sample = sample.to(device).unsqueeze(0)
+
+     with torch.no_grad():
+         for i in loc:
+             timestep = torch.tensor([0] * batch_size).to(device)  # placeholder; not used by the model
+             prediction = model(sample, timestep)
+             p = prediction[:, i, :len(tokenizer.all_aas) - 6]  # only canonical amino acids
+             p = F.softmax(p, dim=1)  # softmax over logits
+             p_sample = torch.multinomial(p, num_samples=1)  # sample from categorical distribution
+             sample[:, i] = p_sample.squeeze()
+     output = [tokenizer.untokenize(s) for s in sample]
+     return output[0]
+
+
+ def dplm_infill(masked_seq, tokenizer, model, device):
+     from src.lm.dplm.unconditional_sampler import UnconditionalSampler as DPLMUnconditionalSampler
+
+     generator = DPLMUnconditionalSampler(tokenizer, model)
+     xt = tokenizer(masked_seq, return_tensors='pt')['input_ids'].to(model.device)
+     denoised_tokens = generator.sample_unconditional(xt, config.sampling.n_steps)[0].squeeze()
+     generated_sequence = tokenizer.decode(denoised_tokens).replace(" ", "")[5:-5]
+     return generated_sequence
+
+
+ # -------# Metrics #-------- #
+ def calc_progen_ppl(model, tokenizer, target, device, fp16=True):
+     """Compute causal LM perplexity for a given sequence."""
+     with torch.no_grad():
+         with torch.cuda.amp.autocast(enabled=fp16):
+             logits = model(
+                 input_ids=target,
+                 attention_mask=torch.ones_like(target)
+             ).logits
+
+             # Shift logits and targets for next-token prediction
+             logits = logits[:-1, ...]
+             target = target[1:]
+
+             loss = torch.nn.functional.cross_entropy(
+                 input=logits,
+                 target=target,
+                 reduction='mean'
+             )
+     return torch.exp(loss).item()
+
+
+ def calc_ppl(model, tokenizer, generated_sequence, mask_token_indices, model_type):
+     """Pseudo-perplexity: mask one position at a time and average the losses."""
+     total_loss = 0.0
+     tensor_input = tokenizer.encode(generated_sequence, return_tensors='pt').to(model.device)
+
+     for i in mask_token_indices:
+         masked_input = tensor_input.clone()
+         masked_input[0, i] = tokenizer.mask_token_id
+
+         labels = torch.full(tensor_input.shape, -100).to(model.device)
+         labels[0, i] = tensor_input[0, i]
+
+         with torch.no_grad():
+             loss = model(masked_input, labels=labels).loss.item()
+         total_loss += loss
+
+     avg_loss = total_loss / len(mask_token_indices)  # average over the masked positions
+     return math.exp(avg_loss)
+
+
+ def calc_entropy(seq):
+     counts = Counter(seq)
+     total_len = len(seq)
+     entropy = 0.0
+     for count in counts.values():
+         prob = count / total_len
+         entropy -= prob * math.log2(prob)
+     return entropy
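The Shannon entropy metric above quantifies residue diversity: a sequence of one repeated residue scores 0, while k equally frequent residues score log2(k). A quick standalone check:

```python
import math
from collections import Counter

def calc_entropy(seq):
    counts = Counter(seq)
    total_len = len(seq)
    entropy = 0.0
    for count in counts.values():
        prob = count / total_len
        entropy -= prob * math.log2(prob)
    return entropy

low = calc_entropy("AAAA")   # 0.0 -- a single repeated residue carries no information
high = calc_entropy("ACDE")  # 2.0 -- four equally likely residues = log2(4) bits
```

This is useful as a cheap degeneracy check on generated sequences (e.g. flagging poly-A collapse).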
src/utils/model_utils.py ADDED
@@ -0,0 +1,42 @@
+ import sys
+ import torch
+ import numpy as np
+
+
+ def _print(s):
+     print(s)
+     sys.stdout.flush()
+
+
+ def compute_grad_norms(params):
+     """ Compute the global L2 norm over a collection of parameter gradients """
+     sqrd_sum = 0.0
+     for p in params:
+         if p.grad is not None:
+             sqrd_sum += p.grad.norm(2).item() ** 2
+     return sqrd_sum ** 0.5
+
+
+ class CosineWarmup(torch.optim.lr_scheduler._LRScheduler):
+     def __init__(self, optimizer, warmup_steps, total_steps, eta_ratio=0.1, last_epoch=-1):
+         self.warmup_steps = warmup_steps
+         self.total_steps = total_steps
+         self.eta_ratio = eta_ratio  # The ratio of minimum to maximum learning rate
+         super(CosineWarmup, self).__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         step = self.last_epoch
+         if step < self.warmup_steps:
+             return [base_lr * step / self.warmup_steps for base_lr in self.base_lrs]
+
+         progress = (step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
+         cosine_decay = 0.5 * (1 + np.cos(np.pi * progress))
+         lr_mult = (1 - self.eta_ratio) * cosine_decay + self.eta_ratio
+
+         return [base_lr * lr_mult for base_lr in self.base_lrs]
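The warmup-plus-cosine schedule above can be expressed as a pure function of the step, which makes its shape easy to verify (toy hyperparameters assumed):

```python
import math

def cosine_warmup_lr(step, base_lr, warmup_steps, total_steps, eta_ratio=0.1):
    # Linear warmup, then cosine decay down to eta_ratio * base_lr
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    progress = (step - warmup_steps) / (total_steps - warmup_steps)
    cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
    return base_lr * ((1 - eta_ratio) * cosine_decay + eta_ratio)

mid_warmup = cosine_warmup_lr(50, 1e-3, 100, 1000)   # halfway through warmup ~ 5e-4
final_lr = cosine_warmup_lr(1000, 1e-3, 100, 1000)   # fully decayed ~ eta_ratio * base_lr
```

The `eta_ratio` floor keeps the learning rate from collapsing to zero at the end of training.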
src/utils/time_utils.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ import torch.nn as nn
+
+
+ # -------------------------
+ # Timestep embeddings
+ # -------------------------
+
+ class GaussianFourierProjection(nn.Module):
+     """
+     Gaussian Fourier features for continuous time t in [0, 1].
+     Produces embed_dim features: [sin(W t), cos(W t)] from embed_dim // 2 frequencies.
+     """
+     def __init__(self, embed_dim, scale):
+         super().__init__()
+         assert embed_dim % 2 == 0, "embed_dim must be even."
+         self.embed_dim = embed_dim
+         self.register_buffer("W", torch.randn(embed_dim // 2) * scale, persistent=False)  # Fixed random frequencies
+
+     def forward(self, t):
+         t = t.float().unsqueeze(-1)  # Broadcast to [B, 1]
+         angles = t * self.W  # [B, embed_dim // 2]
+         return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
+
+
+ class TimeEmbedding(nn.Module):
+     def __init__(self, hidden_dim, fourier_dim, scale):
+         super().__init__()
+         assert fourier_dim % 2 == 0, "fourier_dim must be even for sine/cosine pairs."
+
+         self.fourier = GaussianFourierProjection(fourier_dim, scale)
+         self.mlp = nn.Sequential(
+             nn.Linear(fourier_dim, hidden_dim),
+             nn.SiLU(),
+             nn.Linear(hidden_dim, hidden_dim),
+         )
+
+     def forward(self, t):
+         ft = self.fourier(t)  # (B, fourier_dim)
+         return self.mlp(ft)   # (B, hidden_dim)
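The Fourier projection maps a scalar timestep to `embed_dim` bounded features using `embed_dim // 2` fixed random frequencies. A minimal shape check (toy dimensions assumed):

```python
import torch

# embed_dim // 2 random frequencies -> embed_dim sin/cos features per timestep
embed_dim, scale = 8, 30.0
W = torch.randn(embed_dim // 2) * scale
t = torch.rand(4)                      # batch of 4 timesteps in [0, 1]
angles = t.unsqueeze(-1) * W           # (4, embed_dim // 2)
feats = torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)
```

Since every feature is a sine or cosine, all values lie in [-1, 1], giving the downstream MLP a well-conditioned input regardless of `scale`.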