yinuozhang committed on
Commit
8d63dc0
·
1 Parent(s): 64b4595

update model

Browse files
app.py CHANGED
@@ -20,7 +20,6 @@ from inference import (
20
  PeptiVersePredictor,
21
  read_best_manifest_csv,
22
  BestRow,
23
- canon_model,
24
  )
25
 
26
  try:
@@ -75,6 +74,74 @@ ASSETS_DATA = ASSETS / "training_data_cleaned"; ASSETS_DATA.mkdir(parents=True
75
  MODEL_REPO = "ChatterjeeLab/PeptiVerse" # model repo
76
  DATASET_REPO = "ChatterjeeLab/PeptiVerse" # dataset repo
77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  def fetch_models_and_data():
79
  snapshot_download(
80
  repo_id=MODEL_REPO,
@@ -94,8 +161,8 @@ def fetch_models_and_data():
94
  )
95
 
96
  fetch_models_and_data()
97
-
98
- BEST_TXT = Path("best_models.txt")
99
  TRAINING_ROOT = ASSETS_MODELS / "training_classifiers"
100
  TOKENIZER_DIR = ASSETS_MODELS / "tokenizer"
101
 
 
20
  PeptiVersePredictor,
21
  read_best_manifest_csv,
22
  BestRow,
 
23
  )
24
 
25
  try:
 
74
  MODEL_REPO = "ChatterjeeLab/PeptiVerse" # model repo
75
  DATASET_REPO = "ChatterjeeLab/PeptiVerse" # dataset repo
76
 
77
def canon_model(parsed) -> Optional[str]:
    """Return the bare, lowercase model name for a manifest entry.

    Args:
        parsed: ``None``, a ``(model, emb_tag)`` tuple, or a raw model label.

    Returns:
        The model name with surrounding whitespace removed and lowercased, or
        ``None`` when ``parsed`` is ``None``, is an empty tuple, or carries a
        missing/blank model name.
    """
    if parsed is None:
        return None
    if isinstance(parsed, tuple):
        # An empty tuple has no model slot; indexing it would raise IndexError.
        name = parsed[0] if parsed else None
        if not name:
            return None
    else:
        name = parsed
    # Strip before lowercasing so labels like " XGB " match on-disk names;
    # a blank label is reported as "no model" in both the tuple and raw paths.
    cleaned = str(name).strip().lower()
    return cleaned or None
84
+
85
def get_required_patterns(manifest_path: Path) -> List[str]:
    """Build ``allow_patterns`` from the manifest so only needed files download.

    Args:
        manifest_path: Path to the best-model manifest read by
            ``inference.read_best_manifest_csv``.

    Returns:
        A sorted list of glob patterns covering the tokenizer files, the
        cleaned training CSVs, and the ``best_model*`` artifacts of every
        model named in the manifest.
    """
    # Imported here, as in the original, so the names are resolved at call time.
    from inference import read_best_manifest_csv, EMB_TAG_TO_FOLDER_SUFFIX

    manifest = read_best_manifest_csv(manifest_path)

    # Files that are always required, independent of the manifest contents.
    patterns = {
        "tokenizer/new_vocab.txt",
        "tokenizer/new_splits.txt",
        "training_data_cleaned/**/*.csv",
    }

    for prop_key, row in manifest.items():
        # The manifest key "halflife" lives in the "half_life" folder on disk.
        disk_prop = "half_life" if prop_key == "halflife" else prop_key

        # Pair each manifest column with its default embedding tag by
        # position. Comparing `parsed == row.best_wt` by value mislabels the
        # SMILES entry as "wt" whenever both columns parse to equal tuples.
        for default_tag, parsed in (("wt", row.best_wt), ("smiles", row.best_smiles)):
            if parsed is None:
                continue
            model_name, emb_tag = parsed

            if prop_key == "binding_affinity":
                folder = model_name  # e.g. "wt_wt_pooled", "chemberta_smiles_pooled"
                patterns.add(f"training_classifiers/binding_affinity/{folder}/best_model*")
                continue

            # Fall back to the column's own tag when the manifest gives none.
            if emb_tag is None:
                emb_tag = default_tag

            # Unknown tags are used verbatim as the folder suffix.
            suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag)

            # Half-life WT models use dedicated folder names.
            if prop_key == "halflife" and emb_tag == "wt":
                if model_name in {"transformer"}:
                    for variant in ["transformer_wt_log", "transformer_wt"]:
                        patterns.add(f"training_classifiers/{disk_prop}/{variant}/best_model*")
                    continue
                if model_name in {"xgb", "xgb_reg"}:
                    patterns.add(f"training_classifiers/{disk_prop}/xgb_wt_log/best_model*")
                    continue

            # Cover both the suffixed and the bare folder layout.
            patterns.add(f"training_classifiers/{disk_prop}/{model_name}_{suffix}/best_model*")
            patterns.add(f"training_classifiers/{disk_prop}/{model_name}/best_model*")

    return sorted(patterns)
129
+
130
+
131
def fetch_models_and_data():
    """Download only the repository files that the manifest requires.

    Derives the allow-list from ``BEST_TXT`` via ``get_required_patterns``,
    echoes it to stdout, and hands it to ``snapshot_download`` with
    ``ASSETS_MODELS`` as the local target directory.
    """
    wanted = get_required_patterns(BEST_TXT)

    # Show exactly what will be requested before the download starts.
    print(f"Downloading {len(wanted)} targeted pattern(s):")
    for pattern in wanted:
        print(f" {pattern}")

    download_kwargs = dict(
        repo_id=MODEL_REPO,
        local_dir=str(ASSETS_MODELS),
        local_dir_use_symlinks=False,
        allow_patterns=wanted,
    )
    snapshot_download(**download_kwargs)
143
+
144
+ """
145
  def fetch_models_and_data():
146
  snapshot_download(
147
  repo_id=MODEL_REPO,
 
161
  )
162
 
163
  fetch_models_and_data()
164
+ """
165
+ BEST_TXT = Path("basic_models.txt")
166
  TRAINING_ROOT = ASSETS_MODELS / "training_classifiers"
167
  TOKENIZER_DIR = ASSETS_MODELS / "tokenizer"
168
 
basic_models.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
2
+ Hemolysis, XGB, CNN (chemberta), Classifier, 0.2801, 0.564,
3
+ Non-Fouling, Transformer, XGB (peptideclm), Classifier, 0.57, 0.3892,
4
+ Solubility, CNN, Transformer (peptideclm), Classifier, 0.377, 0.329,
5
+ Permeability (Penetrance), XGB, XGB (chemberta), Classifier, 0.4301, 0.5028,
6
+ Toxicity, -, CNN (chemberta), Classifier, -, 0.49,
7
+ Binding_affinity, wt_wt_pooled, chemberta_smiles_pooled, Regression, -, -,
8
+ Permeability_PAMPA, -, CNN (chemberta), Regression, -, -,
9
+ Permeability_CACO2, -, SVR (chemberta), Regression, -, -,
10
+ Halflife, Transformer, XGB (peptideclm), Regression, -, -,
best_models.txt DELETED
@@ -1,10 +0,0 @@
1
- Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
2
- Hemolysis, XGB, Transformer, Classifier, 0.2801, 0.4343,
3
- Non-Fouling, MLP, XGB, Classifier, 0.57, 0.3982,
4
- Solubility, CNN, -, Classifier, 0.377, -,
5
- Permeability (Penetrance), XGB, -, Classifier, 0.4301, -,
6
- Toxicity, -, Transformer, Classifier, -, 0.3401,
7
- Binding_affinity, unpooled, unpooled, Regression, -, -,
8
- Permeability_PAMPA, -, CNN, Regression, -, -,
9
- Permeability_CACO2, -, SVR, Regression, -, -,
10
- Halflife, xgb_wt_log, xgb_smiles, Regression, -, -,
 
 
 
 
 
 
 
 
 
 
 
description.md CHANGED
@@ -16,8 +16,8 @@
16
  |---|---:|---:|---:|---:|
17
  | Hemolysis | 4765 | 1311 | 4765 | 1311 |
18
  | Non-Fouling | 13580 | 3600 | 13580 | 3600 |
19
- | Solubility | 9668 | 8785 | | |
20
- | Permeability (Penetrance) | 1162 | 1162 | | |
21
  | Toxicity | – | – | 5518 | 5518 |
22
 
23
  #### Regression (total N)
@@ -27,7 +27,7 @@
27
  | Permeability (PAMPA) | – | 6869 |
28
  | Permeability (CACO2) | – | 606 |
29
  | Half-Life | 130 | 245 |
30
- | Binding Affinity | 1436 | 1597 |
31
 
32
 
33
  Our models are trained on curated datasets from multiple sources. For detailed data-cleaning procedures, please refer to our [paper](https://www.biorxiv.org/content/10.64898/2025.12.31.697180v1).
@@ -90,6 +90,7 @@ Higher scores indicate stronger non-fouling behavior, desirable for circulation
90
 
91
  ### Model Training and Weight Hosting
92
  - More instructions can be found at [PeptiVerse](https://huggingface.co/ChatterjeeLab/PeptiVerse)
 
93
 
94
  ### 🧪 Physicochemical Properties
95
 
 
16
  |---|---:|---:|---:|---:|
17
  | Hemolysis | 4765 | 1311 | 4765 | 1311 |
18
  | Non-Fouling | 13580 | 3600 | 13580 | 3600 |
19
+ | Solubility | 9668 | 8785 | 9668 | 8785 |
20
+ | Permeability (Penetrance) | 1162 | 1162 | 1162 | 1162 |
21
  | Toxicity | – | – | 5518 | 5518 |
22
 
23
  #### Regression (total N)
 
27
  | Permeability (PAMPA) | – | 6869 |
28
  | Permeability (CACO2) | – | 606 |
29
  | Half-Life | 130 | 245 |
30
+ | Binding Affinity | 1433 | 1702 |
31
 
32
 
33
  Our models are trained on curated datasets from multiple sources. For detailed data-cleaning procedures, please refer to our [paper](https://www.biorxiv.org/content/10.64898/2025.12.31.697180v1).
 
90
 
91
  ### Model Training and Weight Hosting
92
  - More instructions can be found at [PeptiVerse](https://huggingface.co/ChatterjeeLab/PeptiVerse)
93
+ - Model uncertainty prediction is not supported in the app version, but the code is available at [PeptiVerse](https://huggingface.co/ChatterjeeLab/PeptiVerse) for local deployment.
94
 
95
  ### 🧪 Physicochemical Properties
96
 
inference.py CHANGED
@@ -1,31 +1,46 @@
1
- # peptiverse_infer.py
2
  from __future__ import annotations
3
-
4
  import csv, re, json
5
  from dataclasses import dataclass
6
  from pathlib import Path
7
  from typing import Dict, Optional, Tuple, Any, List
8
-
9
  import numpy as np
10
  import torch
11
  import torch.nn as nn
12
  import joblib
13
  import xgboost as xgb
14
-
15
  from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
16
  from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
17
-
 
18
 
19
  # -----------------------------
20
  # Manifest
21
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  @dataclass(frozen=True)
23
  class BestRow:
24
  property_key: str
25
- best_wt: Optional[str]
26
- best_smiles: Optional[str]
27
- task_type: str # "Classifier" or "Regression"
28
- thr_wt: Optional[float]
29
  thr_smiles: Optional[float]
30
 
31
 
@@ -34,21 +49,16 @@ def _clean(s: str) -> str:
34
 
35
  def _none_if_dash(s: str) -> Optional[str]:
36
  s = _clean(s)
37
- if s in {"", "-", "", "NA", "N/A"}:
38
- return None
39
- return s
40
 
41
  def _float_or_none(s: str) -> Optional[float]:
42
  s = _clean(s)
43
- if s in {"", "-", "", "NA", "N/A"}:
44
- return None
45
- return float(s)
46
 
47
  def normalize_property_key(name: str) -> str:
48
  n = name.strip().lower()
49
  n = re.sub(r"\s*\(.*?\)\s*", "", n)
50
  n = n.replace("-", "_").replace(" ", "_")
51
-
52
  if "permeability" in n and "pampa" not in n and "caco" not in n:
53
  return "permeability_penetrance"
54
  if n == "binding_affinity":
@@ -60,11 +70,40 @@ def normalize_property_key(name: str) -> str:
60
  return n
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
64
- """
65
- Properties, Best_Model_WT, Best_Model_SMILES, Type, Threshold_WT, Threshold_SMILES,
66
- Hemolysis, SVM, SGB, Classifier, 0.2801, 0.2223,
67
- """
68
  p = Path(path)
69
  out: Dict[str, BestRow] = {}
70
 
@@ -90,10 +129,13 @@ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
90
  continue
91
  prop_key = normalize_property_key(prop_raw)
92
 
 
 
 
93
  row = BestRow(
94
  property_key=prop_key,
95
- best_wt=_none_if_dash(rec.get("Best_Model_WT", "")),
96
- best_smiles=_none_if_dash(rec.get("Best_Model_SMILES", "")),
97
  task_type=_clean(rec.get("Type", "Classifier")),
98
  thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
99
  thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
@@ -103,53 +145,32 @@ def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
103
  return out
104
 
105
 
106
- MODEL_ALIAS = {
107
- "SVM": "svm_gpu",
108
- "SVR": "svr",
109
- "ENET": "enet_gpu",
110
- "CNN": "cnn",
111
- "MLP": "mlp",
112
- "TRANSFORMER": "transformer",
113
- "XGB": "xgb",
114
- "XGB_REG": "xgb_reg",
115
- "POOLED": "pooled",
116
- "UNPOOLED": "unpooled",
117
- "TRANSFORMER_WT_LOG": "transformer_wt_log",
118
- }
119
- def canon_model(label: Optional[str]) -> Optional[str]:
120
- if label is None:
121
- return None
122
- k = label.strip().upper()
123
- return MODEL_ALIAS.get(k, label.strip().lower())
124
-
125
-
126
  # -----------------------------
127
  # Generic artifact loading
128
  # -----------------------------
129
  def find_best_artifact(model_dir: Path) -> Path:
130
- for pat in ["best_model.json", "best_model.pt", "best_model*.joblib"]:
 
131
  hits = sorted(model_dir.glob(pat))
132
  if hits:
133
  return hits[0]
 
 
 
134
  raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
135
 
136
  def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
137
  art = find_best_artifact(model_dir)
138
-
139
  if art.suffix == ".json":
140
  booster = xgb.Booster()
141
- print(str(art))
142
  booster.load_model(str(art))
143
  return "xgb", booster, art
144
-
145
  if art.suffix == ".joblib":
146
  obj = joblib.load(art)
147
  return "joblib", obj, art
148
-
149
  if art.suffix == ".pt":
150
  ckpt = torch.load(art, map_location=device, weights_only=False)
151
  return "torch_ckpt", ckpt, art
152
-
153
  raise ValueError(f"Unknown artifact type: {art}")
154
 
155
 
@@ -157,7 +178,7 @@ def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path
157
  # NN architectures
158
  # -----------------------------
159
  class MaskedMeanPool(nn.Module):
160
- def forward(self, X, M): # X:(B,L,H), M:(B,L)
161
  Mf = M.unsqueeze(-1).float()
162
  denom = Mf.sum(dim=1).clamp(min=1.0)
163
  return (X * Mf).sum(dim=1) / denom
@@ -167,34 +188,25 @@ class MLPHead(nn.Module):
167
  super().__init__()
168
  self.pool = MaskedMeanPool()
169
  self.net = nn.Sequential(
170
- nn.Linear(in_dim, hidden),
171
- nn.GELU(),
172
- nn.Dropout(dropout),
173
  nn.Linear(hidden, 1),
174
  )
175
  def forward(self, X, M):
176
- z = self.pool(X, M)
177
- return self.net(z).squeeze(-1)
178
 
179
  class CNNHead(nn.Module):
180
  def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
181
  super().__init__()
182
- blocks = []
183
- ch = in_ch
184
  for _ in range(layers):
185
- blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
186
- nn.GELU(),
187
- nn.Dropout(dropout)]
188
  ch = c
189
  self.conv = nn.Sequential(*blocks)
190
  self.head = nn.Linear(c, 1)
191
-
192
  def forward(self, X, M):
193
- Xc = X.transpose(1, 2) # (B,H,L)
194
- Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
195
  Mf = M.unsqueeze(-1).float()
196
- denom = Mf.sum(dim=1).clamp(min=1.0)
197
- pooled = (Y * Mf).sum(dim=1) / denom
198
  return self.head(pooled).squeeze(-1)
199
 
200
  class TransformerHead(nn.Module):
@@ -207,28 +219,44 @@ class TransformerHead(nn.Module):
207
  )
208
  self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
209
  self.head = nn.Linear(d_model, 1)
210
-
211
  def forward(self, X, M):
212
- pad_mask = ~M
213
- Z = self.proj(X)
214
- Z = self.enc(Z, src_key_padding_mask=pad_mask)
215
  Mf = M.unsqueeze(-1).float()
216
- denom = Mf.sum(dim=1).clamp(min=1.0)
217
- pooled = (Z * Mf).sum(dim=1) / denom
218
  return self.head(pooled).squeeze(-1)
219
 
220
  def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
221
- if model_name == "mlp":
222
- return int(sd["net.0.weight"].shape[1])
223
- if model_name == "cnn":
224
- return int(sd["conv.0.weight"].shape[1])
225
- if model_name == "transformer":
226
- return int(sd["proj.weight"].shape[1])
227
  raise ValueError(model_name)
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
230
  params = ckpt["best_params"]
231
- sd = ckpt["state_dict"]
232
  in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
233
  dropout = float(params.get("dropout", 0.1))
234
 
@@ -238,39 +266,132 @@ def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.devic
238
  model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
239
  layers=int(params["layers"]), dropout=dropout)
240
  elif model_name == "transformer":
241
- d_model = (
242
- params.get("d_model")
243
- or params.get("hidden")
244
- or params.get("hidden_dim")
245
- )
246
  if d_model is None:
247
- raise KeyError(
248
- f"Transformer checkpoint missing d_model/hidden. "
249
- f"Available keys: {list(params.keys())}"
 
 
 
 
 
 
 
 
 
 
 
 
250
  )
251
-
252
- model = TransformerHead(
253
- in_dim=in_dim,
254
- d_model=int(d_model),
255
- nhead=int(params["nhead"]),
256
- layers=int(params["layers"]),
257
- ff=int(params.get("ff", 4 * int(d_model))),
258
- dropout=dropout
259
- )
260
  else:
261
  raise ValueError(f"Unknown NN model_name={model_name}")
262
 
263
  model.load_state_dict(sd)
264
- model.to(device)
265
- model.eval()
266
  return model
267
 
268
 
269
  # -----------------------------
270
- # Binding affinity models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
271
  # -----------------------------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  def affinity_to_class(y: float) -> int:
273
- # 0=High(>=9), 1=Moderate(7-9), 2=Low(<7)
274
  if y >= 9.0: return 0
275
  if y < 7.0: return 2
276
  return 1
@@ -280,38 +401,31 @@ class CrossAttnPooled(nn.Module):
280
  super().__init__()
281
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
282
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
283
-
284
  self.layers = nn.ModuleList([])
285
  for _ in range(n_layers):
286
  self.layers.append(nn.ModuleDict({
287
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
288
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
289
- "n1t": nn.LayerNorm(hidden),
290
- "n2t": nn.LayerNorm(hidden),
291
- "n1b": nn.LayerNorm(hidden),
292
- "n2b": nn.LayerNorm(hidden),
293
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
294
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
295
  }))
296
-
297
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
298
  self.reg = nn.Linear(hidden, 1)
299
  self.cls = nn.Linear(hidden, 3)
300
 
301
  def forward(self, t_vec, b_vec):
302
- t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
303
- b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
304
  for L in self.layers:
305
  t_attn, _ = L["attn_tb"](t, b, b)
306
  t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
307
  t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
308
-
309
  b_attn, _ = L["attn_bt"](b, t, t)
310
  b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
311
  b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
312
-
313
- z = torch.cat([t[0], b[0]], dim=-1)
314
- h = self.shared(z)
315
  return self.reg(h).squeeze(-1), self.cls(h)
316
 
317
  class CrossAttnUnpooled(nn.Module):
@@ -319,344 +433,247 @@ class CrossAttnUnpooled(nn.Module):
319
  super().__init__()
320
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
321
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
322
-
323
  self.layers = nn.ModuleList([])
324
  for _ in range(n_layers):
325
  self.layers.append(nn.ModuleDict({
326
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
327
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
328
- "n1t": nn.LayerNorm(hidden),
329
- "n2t": nn.LayerNorm(hidden),
330
- "n1b": nn.LayerNorm(hidden),
331
- "n2b": nn.LayerNorm(hidden),
332
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
333
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
334
  }))
335
-
336
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
337
  self.reg = nn.Linear(hidden, 1)
338
  self.cls = nn.Linear(hidden, 3)
339
 
340
  def _masked_mean(self, X, M):
341
  Mf = M.unsqueeze(-1).float()
342
- denom = Mf.sum(dim=1).clamp(min=1.0)
343
- return (X * Mf).sum(dim=1) / denom
344
 
345
  def forward(self, T, Mt, B, Mb):
346
- T = self.t_proj(T)
347
- Bx = self.b_proj(B)
348
- kp_t = ~Mt
349
- kp_b = ~Mb
350
-
351
  for L in self.layers:
352
  T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
353
- T = L["n1t"](T + T_attn)
354
- T = L["n2t"](T + L["fft"](T))
355
-
356
  B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
357
- Bx = L["n1b"](Bx + B_attn)
358
- Bx = L["n2b"](Bx + L["ffb"](Bx))
359
-
360
- t_pool = self._masked_mean(T, Mt)
361
- b_pool = self._masked_mean(Bx, Mb)
362
- z = torch.cat([t_pool, b_pool], dim=-1)
363
- h = self.shared(z)
364
  return self.reg(h).squeeze(-1), self.cls(h)
365
 
366
  def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
367
  ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
368
  params = ckpt["best_params"]
369
- sd = ckpt["state_dict"]
370
-
371
- # infer Ht/Hb from projection weights
372
  Ht = int(sd["t_proj.0.weight"].shape[1])
373
  Hb = int(sd["b_proj.0.weight"].shape[1])
374
-
375
- common = dict(
376
- Ht=Ht, Hb=Hb,
377
- hidden=int(params["hidden_dim"]),
378
- n_heads=int(params["n_heads"]),
379
- n_layers=int(params["n_layers"]),
380
- dropout=float(params["dropout"]),
381
- )
382
-
383
- if pooled_or_unpooled == "pooled":
384
- model = CrossAttnPooled(**common)
385
- elif pooled_or_unpooled == "unpooled":
386
- model = CrossAttnUnpooled(**common)
387
- else:
388
- raise ValueError(pooled_or_unpooled)
389
-
390
  model.load_state_dict(sd)
391
- model.to(device).eval()
392
- return model
393
 
394
 
395
  # -----------------------------
396
  # Embedding generation
397
  # -----------------------------
398
  def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
399
- """
400
- Pytorch patch
401
- """
402
  if hasattr(torch, "isin"):
403
  return torch.isin(ids, test_ids)
404
- # Fallback: compare against each special id
405
- # (B,L,1) == (1,1,K) -> (B,L,K)
406
  return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
407
-
408
  class SMILESEmbedder:
409
- """
410
- PeptideCLM RoFormer embeddings for SMILES.
411
- - pooled(): mean over tokens where attention_mask==1 AND token_id not in SPECIAL_IDS
412
- - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
413
- plus a 1-mask of length Li (since already filtered).
414
- """
415
- def __init__(
416
- self,
417
- device: torch.device,
418
- vocab_path: str,
419
- splits_path: str,
420
- clm_name: str = "aaronfeller/PeptideCLM-23M-all",
421
- max_len: int = 512,
422
- use_cache: bool = True,
423
- ):
424
  self.device = device
425
  self.max_len = max_len
426
  self.use_cache = use_cache
427
-
428
  self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
429
  self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
430
-
431
  self.special_ids = self._get_special_ids(self.tokenizer)
432
  self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
433
- if len(self.special_ids) else None)
434
-
435
  self._cache_pooled: Dict[str, torch.Tensor] = {}
436
  self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
437
 
438
  @staticmethod
439
  def _get_special_ids(tokenizer) -> List[int]:
440
- cand = [
441
- getattr(tokenizer, "pad_token_id", None),
442
- getattr(tokenizer, "cls_token_id", None),
443
- getattr(tokenizer, "sep_token_id", None),
444
- getattr(tokenizer, "bos_token_id", None),
445
- getattr(tokenizer, "eos_token_id", None),
446
- getattr(tokenizer, "mask_token_id", None),
447
- ]
448
  return sorted({int(x) for x in cand if x is not None})
449
 
450
- def _tokenize(self, smiles_list: List[str]) -> Dict[str, torch.Tensor]:
451
- tok = self.tokenizer(
452
- smiles_list,
453
- return_tensors="pt",
454
- padding=True,
455
- truncation=True,
456
- max_length=self.max_len,
457
- )
458
- for k in tok:
459
- tok[k] = tok[k].to(self.device)
460
  if "attention_mask" not in tok:
461
  tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
462
  return tok
463
 
 
 
 
 
 
 
464
  @torch.no_grad()
465
  def pooled(self, smiles: str) -> torch.Tensor:
466
  s = smiles.strip()
467
- if self.use_cache and s in self._cache_pooled:
468
- return self._cache_pooled[s]
 
 
 
 
 
 
469
 
 
 
 
 
470
  tok = self._tokenize([s])
471
- ids = tok["input_ids"] # (1,L)
472
- attn = tok["attention_mask"].bool() # (1,L)
 
 
 
 
473
 
474
- out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
475
- h = out.last_hidden_state # (1,L,H)
476
 
477
- valid = attn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
478
  if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
479
  valid = valid & (~_safe_isin(ids, self.special_ids_t))
 
480
 
 
 
 
 
 
 
 
481
  vf = valid.unsqueeze(-1).float()
482
- summed = (h * vf).sum(dim=1) # (1,H)
483
- denom = vf.sum(dim=1).clamp(min=1e-9) # (1,1)
484
- pooled = summed / denom # (1,H)
485
-
486
- if self.use_cache:
487
- self._cache_pooled[s] = pooled
488
  return pooled
489
 
490
  @torch.no_grad()
491
  def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
492
- """
493
- Returns:
494
- X: (1, Li, H) float32 on device
495
- M: (1, Li) bool on device
496
- where Li excludes padding + special tokens.
497
- """
498
  s = smiles.strip()
499
- if self.use_cache and s in self._cache_unpooled:
500
- return self._cache_unpooled[s]
501
-
502
  tok = self._tokenize([s])
503
- ids = tok["input_ids"] # (1,L)
504
- attn = tok["attention_mask"].bool() # (1,L)
505
-
506
- out = self.model(input_ids=ids, attention_mask=tok["attention_mask"])
507
- h = out.last_hidden_state # (1,L,H)
508
-
509
- valid = attn
510
- if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
511
- valid = valid & (~_safe_isin(ids, self.special_ids_t))
512
-
513
- # filter valid tokens
514
- keep = valid[0] # (L,)
515
- X = h[:, keep, :] # (1,Li,H)
516
  M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
517
-
518
- if self.use_cache:
519
- self._cache_unpooled[s] = (X, M)
520
  return X, M
521
 
522
 
523
  class WTEmbedder:
524
- """
525
- ESM2 embeddings for AA sequences.
526
- - pooled(): mean over tokens where attention_mask==1 AND token_id not in {CLS, EOS, PAD,...}
527
- - unpooled(): returns token embeddings filtered to valid tokens (specials removed),
528
- plus a 1-mask of length Li (since already filtered).
529
- """
530
- def __init__(
531
- self,
532
- device: torch.device,
533
- esm_name: str = "facebook/esm2_t33_650M_UR50D",
534
- max_len: int = 1022,
535
- use_cache: bool = True,
536
- ):
537
  self.device = device
538
  self.max_len = max_len
539
  self.use_cache = use_cache
540
-
541
  self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
542
  self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
543
-
544
  self.special_ids = self._get_special_ids(self.tokenizer)
545
  self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
546
- if len(self.special_ids) else None)
547
-
548
  self._cache_pooled: Dict[str, torch.Tensor] = {}
549
  self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
550
 
551
  @staticmethod
552
  def _get_special_ids(tokenizer) -> List[int]:
553
- cand = [
554
- getattr(tokenizer, "pad_token_id", None),
555
- getattr(tokenizer, "cls_token_id", None),
556
- getattr(tokenizer, "sep_token_id", None),
557
- getattr(tokenizer, "bos_token_id", None),
558
- getattr(tokenizer, "eos_token_id", None),
559
- getattr(tokenizer, "mask_token_id", None),
560
- ]
561
  return sorted({int(x) for x in cand if x is not None})
562
 
563
- def _tokenize(self, seq_list: List[str]) -> Dict[str, torch.Tensor]:
564
- tok = self.tokenizer(
565
- seq_list,
566
- return_tensors="pt",
567
- padding=True,
568
- truncation=True,
569
- max_length=self.max_len,
570
- )
571
  tok = {k: v.to(self.device) for k, v in tok.items()}
572
  if "attention_mask" not in tok:
573
  tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
574
  return tok
575
 
 
 
 
 
 
 
576
  @torch.no_grad()
577
  def pooled(self, seq: str) -> torch.Tensor:
578
  s = seq.strip()
579
- if self.use_cache and s in self._cache_pooled:
580
- return self._cache_pooled[s]
581
-
582
  tok = self._tokenize([s])
583
- ids = tok["input_ids"] # (1,L)
584
- attn = tok["attention_mask"].bool() # (1,L)
585
-
586
- out = self.model(**tok)
587
- h = out.last_hidden_state # (1,L,H)
588
-
589
- valid = attn
590
- if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
591
- valid = valid & (~_safe_isin(ids, self.special_ids_t))
592
-
593
  vf = valid.unsqueeze(-1).float()
594
- summed = (h * vf).sum(dim=1) # (1,H)
595
- denom = vf.sum(dim=1).clamp(min=1e-9) # (1,1)
596
- pooled = summed / denom # (1,H)
597
-
598
- if self.use_cache:
599
- self._cache_pooled[s] = pooled
600
  return pooled
601
 
602
  @torch.no_grad()
603
  def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
604
- """
605
- Returns:
606
- X: (1, Li, H) float32 on device
607
- M: (1, Li) bool on device
608
- where Li excludes padding + special tokens.
609
- """
610
  s = seq.strip()
611
- if self.use_cache and s in self._cache_unpooled:
612
- return self._cache_unpooled[s]
613
-
614
  tok = self._tokenize([s])
615
- ids = tok["input_ids"] # (1,L)
616
- attn = tok["attention_mask"].bool() # (1,L)
617
-
618
- out = self.model(**tok)
619
- h = out.last_hidden_state # (1,L,H)
620
-
621
- valid = attn
622
- if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
623
- valid = valid & (~_safe_isin(ids, self.special_ids_t))
624
-
625
- keep = valid[0] # (L,)
626
- X = h[:, keep, :] # (1,Li,H)
627
  M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
628
-
629
- if self.use_cache:
630
- self._cache_unpooled[s] = (X, M)
631
  return X, M
632
 
633
- def _clean_state_dict(sd: dict) -> dict:
634
- # just for wt halflife transformer predictor
635
- out = {}
636
- for k, v in sd.items():
637
- if k.startswith("module."):
638
- k = k[len("module."):]
639
- if k.startswith("model."):
640
- k = k[len("model."):]
641
- out[k] = v
642
- return out
643
-
644
 
645
  # -----------------------------
646
  # Predictor
647
  # -----------------------------
 
648
  class PeptiVersePredictor:
649
- """
650
- - loads best models from training_classifiers/
651
- - computes embeddings as needed (pooled/unpooled)
652
- - supports: xgb, joblib(ENET/SVM/SVR), NN(mlp/cnn/transformer), binding pooled/unpooled.
653
- """
654
  def __init__(
655
  self,
656
  manifest_path: str | Path,
657
  classifier_weight_root: str | Path,
658
  esm_name="facebook/esm2_t33_650M_UR50D",
659
  clm_name="aaronfeller/PeptideCLM-23M-all",
 
660
  smiles_vocab="tokenizer/new_vocab.txt",
661
  smiles_splits="tokenizer/new_splits.txt",
662
  device: Optional[str] = None,
@@ -667,291 +684,398 @@ class PeptiVersePredictor:
667
 
668
  self.manifest = read_best_manifest_csv(manifest_path)
669
 
670
- self.wt_embedder = WTEmbedder(self.device)
671
- self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
672
- vocab_path=str(self.root / smiles_vocab),
673
- splits_path=str(self.root / smiles_splits))
 
674
 
675
- self.models: Dict[Tuple[str, str], Any] = {}
676
- self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
 
 
677
 
678
  self._load_all_best_models()
679
 
680
- def _resolve_dir(self, prop_key: str, model_name: str, mode: str) -> Path:
681
- # map halflife -> half_life folder on disk (common layout)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
682
  disk_prop = "half_life" if prop_key == "halflife" else prop_key
683
  base = self.training_root / disk_prop
684
 
685
- # special handling for halflife xgb_wt_log / xgb_smiles
686
- if prop_key == "halflife" and model_name in {"xgb_wt_log", "xgb_smiles"}:
687
- d = base / model_name
688
- if d.exists():
689
- return d
690
 
691
- if prop_key == "halflife" and model_name == "xgb":
692
- d = base / ("xgb_wt_log" if mode == "wt" else "xgb_smiles")
693
- if d.exists():
694
- return d
 
 
 
695
 
696
  candidates = [
697
- base / f"{model_name}_{mode}",
698
  base / model_name,
699
  ]
700
- if mode == "wt":
701
- candidates += [base / f"{model_name}_wt"]
702
- if mode == "smiles":
703
- candidates += [base / f"{model_name}_smiles"]
704
-
705
  for d in candidates:
706
- if d.exists():
707
- return d
708
 
709
  raise FileNotFoundError(
710
- f"Cannot find model directory for {prop_key} {model_name} {mode}. Tried: {candidates}"
711
  )
712
 
713
-
714
  def _load_all_best_models(self):
715
  for prop_key, row in self.manifest.items():
716
- for mode, label, thr in [
717
- ("wt", row.best_wt, row.thr_wt),
718
- ("smiles", row.best_smiles, row.thr_smiles),
719
  ]:
720
- m = canon_model(label)
721
- if m is None:
722
  continue
 
723
 
724
- # ---- binding affinity special ----
725
  if prop_key == "binding_affinity":
726
- # label is pooled/unpooled; mode chooses folder wt_wt_* vs wt_smiles_*
727
- pooled_or_unpooled = m # "pooled" or "unpooled"
728
- folder = f"wt_{mode}_{pooled_or_unpooled}" # wt_wt_pooled / wt_smiles_unpooled etc.
729
  model_dir = self.training_root / "binding_affinity" / folder
730
  art = find_best_artifact(model_dir)
731
- if art.suffix != ".pt":
732
- raise RuntimeError(f"Binding model expected best_model.pt, got {art}")
733
- model = load_binding_model(art, pooled_or_unpooled=pooled_or_unpooled, device=self.device)
734
- self.models[(prop_key, mode)] = model
735
- self.meta[(prop_key, mode)] = {
736
- "task_type": "Regression",
737
- "threshold": None,
738
- "artifact": str(art),
739
- "model_name": pooled_or_unpooled,
 
740
  }
 
 
 
 
 
 
 
 
 
 
741
  continue
742
 
743
- model_dir = self._resolve_dir(prop_key, m, mode)
 
 
 
 
744
  kind, obj, art = load_artifact(model_dir, self.device)
745
 
746
- if kind in {"xgb", "joblib"}:
747
- self.models[(prop_key, mode)] = obj
 
748
  else:
749
- # rebuild NN architecture
750
- arch = m
751
- if arch.startswith("transformer"):
752
- arch = "transformer"
753
- elif arch.startswith("mlp"):
754
- arch = "mlp"
755
- elif arch.startswith("cnn"):
756
- arch = "cnn"
757
- if prop_key == "halflife" and mode == "wt" and m == "transformer_wt_log":
758
- if isinstance(obj, dict) and "state_dict" in obj:
759
- obj = dict(obj)
760
- obj["state_dict"] = _clean_state_dict(obj["state_dict"])
761
-
762
- self.models[(prop_key, mode)] = build_torch_model_from_ckpt(arch, obj, self.device)
763
-
764
- self.meta[(prop_key, mode)] = {
765
- "task_type": row.task_type,
766
- "threshold": thr,
767
- "artifact": str(art),
768
- "model_name": m,
769
- "kind": kind,
770
- }
771
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
772
 
773
- def _get_features_for_model(self, prop_key: str, mode: str, input_str: str):
774
- """
775
- Returns either:
776
- - pooled np array shape (1,H) for xgb/joblib
777
- - unpooled torch tensors (X,M) for NN
778
- """
779
- model = self.models[(prop_key, mode)]
780
- meta = self.meta[(prop_key, mode)]
781
- kind = meta.get("kind", None)
782
- model_name = meta.get("model_name", "")
783
 
784
- if prop_key == "binding_affinity":
785
- raise RuntimeError("Use predict_binding_affinity().")
786
-
787
- # If torch NN: needs unpooled
 
 
 
 
 
 
 
 
788
  if kind == "torch_ckpt":
789
- if mode == "wt":
790
- X, M = self.wt_embedder.unpooled(input_str)
791
- else:
792
- X, M = self.smiles_embedder.unpooled(input_str)
793
- return X, M
794
-
795
- # Otherwise pooled vectors for xgb/joblib
796
- if mode == "wt":
797
- v = self.wt_embedder.pooled(input_str) # (1,H)
798
- else:
799
- v = self.smiles_embedder.pooled(input_str) # (1,H)
800
- feats = v.detach().cpu().numpy().astype(np.float32)
801
- feats = np.nan_to_num(feats, nan=0.0)
802
- feats = np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
803
- return feats
804
-
805
- def predict_property(self, prop_key: str, mode: str, input_str: str) -> Dict[str, Any]:
806
- """
807
- mode: "wt" for AA sequence input, "smiles" for SMILES input
808
- Returns dict with score + label if classifier threshold exists.
809
- """
810
- if (prop_key, mode) not in self.models:
811
- raise KeyError(f"No model loaded for ({prop_key}, {mode}). Check manifest and folders.")
812
-
813
- meta = self.meta[(prop_key, mode)]
814
- model = self.models[(prop_key, mode)]
815
- task_type = meta["task_type"].lower()
816
- thr = meta.get("threshold", None)
817
- kind = meta.get("kind", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
818
 
819
  if prop_key == "binding_affinity":
820
  raise RuntimeError("Use predict_binding_affinity().")
821
 
822
- # NN path (logits / regression)
823
  if kind == "torch_ckpt":
824
- X, M = self._get_features_for_model(prop_key, mode, input_str)
825
  with torch.no_grad():
826
- y = model(X, M).squeeze().float().cpu().item()
827
- # invert log1p(hours) ONLY for WT half-life log models
828
- model_name = meta.get("model_name", "")
829
- if (
830
- prop_key == "halflife"
831
- and mode == "wt"
832
- and model_name in {"xgb_wt_log", "transformer_wt_log"}
833
- ):
834
- y = float(np.expm1(y))
835
  if task_type == "classifier":
836
- prob = float(1.0 / (1.0 + np.exp(-y))) # sigmoid(logit)
837
- out = {"property": prop_key, "mode": mode, "score": prob}
 
838
  if thr is not None:
839
- out["label"] = int(prob >= float(thr))
840
- out["threshold"] = float(thr)
841
- return out
842
  else:
843
- return {"property": prop_key, "mode": mode, "score": float(y)}
844
-
845
- if kind == "xgb":
846
- feats = self._get_features_for_model(prop_key, mode, input_str)
847
- dmat = xgb.DMatrix(feats)
848
- pred = float(model.predict(dmat)[0])
849
-
850
- # invert log1p(hours) ONLY for WT half-life log models
851
- model_name = meta.get("model_name", "")
852
- if (
853
- prop_key == "halflife"
854
- and mode == "wt"
855
- and model_name in {"xgb_wt_log", "transformer_wt_log"}
856
- ):
857
  pred = float(np.expm1(pred))
858
-
859
- out = {"property": prop_key, "mode": mode, "score": pred}
860
-
861
- return out
862
-
863
- # joblib path (svm/enet/svr)
864
- if kind == "joblib":
865
- feats = self._get_features_for_model(prop_key, mode, input_str) # (1,H)
866
- # classifier vs regressor behavior differs by estimator
867
  if task_type == "classifier":
868
  if hasattr(model, "predict_proba"):
869
  pred = float(model.predict_proba(feats)[:, 1][0])
 
 
870
  else:
871
- if hasattr(model, "decision_function"):
872
- logit = float(model.decision_function(feats)[0])
873
- pred = float(1.0 / (1.0 + np.exp(-logit)))
874
- else:
875
- pred = float(model.predict(feats)[0])
876
- out = {"property": prop_key, "mode": mode, "score": pred}
877
  if thr is not None:
878
- out["label"] = int(pred >= float(thr))
879
- out["threshold"] = float(thr)
880
- return out
881
  else:
882
  pred = float(model.predict(feats)[0])
883
- return {"property": prop_key, "mode": mode, "score": pred}
884
-
885
- raise RuntimeError(f"Unknown model kind={kind}")
 
886
 
887
- def predict_binding_affinity(self, mode: str, target_seq: str, binder_str: str) -> Dict[str, Any]:
888
- """
889
- mode: "wt" (binder is AA sequence) -> wt_wt_(pooled|unpooled)
890
- "smiles" (binder is SMILES) -> wt_smiles_(pooled|unpooled)
891
- """
892
- prop_key = "binding_affinity"
893
- if (prop_key, mode) not in self.models:
894
- raise KeyError(f"No binding model loaded for ({prop_key}, {mode}).")
895
 
896
- model = self.models[(prop_key, mode)]
897
- pooled_or_unpooled = self.meta[(prop_key, mode)]["model_name"] # pooled/unpooled
898
 
899
- # target is always WT sequence (ESM)
900
- if pooled_or_unpooled == "pooled":
901
- t_vec = self.wt_embedder.pooled(target_seq) # (1,Ht)
902
- if mode == "wt":
903
- b_vec = self.wt_embedder.pooled(binder_str) # (1,Hb)
904
- else:
905
- b_vec = self.smiles_embedder.pooled(binder_str) # (1,Hb)
 
 
 
 
 
 
 
 
906
  with torch.no_grad():
907
  reg, logits = model(t_vec, b_vec)
908
- affinity = float(reg.squeeze().cpu().item())
909
- cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
910
- cls_thr = affinity_to_class(affinity)
911
  else:
912
  T, Mt = self.wt_embedder.unpooled(target_seq)
913
- if mode == "wt":
914
- B, Mb = self.wt_embedder.unpooled(binder_str)
915
- else:
916
- B, Mb = self.smiles_embedder.unpooled(binder_str)
917
  with torch.no_grad():
918
  reg, logits = model(T, Mt, B, Mb)
919
- affinity = float(reg.squeeze().cpu().item())
920
- cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
921
- cls_thr = affinity_to_class(affinity)
922
-
923
- names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}
924
- return {
925
- "property": "binding_affinity",
926
- "mode": mode,
927
- "affinity": affinity,
 
928
  "class_by_threshold": names[cls_thr],
929
- "class_by_logits": names[cls_logit],
930
- "binding_model": pooled_or_unpooled,
931
  }
932
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933
 
934
  if __name__ == "__main__":
935
- predictor = PeptiVersePredictor(
936
- manifest_path="best_models.txt",
937
- classifier_weight_root="./Classifier_Weight"
938
- )
939
- print(predictor.predict_property("hemolysis", "wt", "GIGAVLKVLTTGLPALISWIKRKRQQ"))
940
- print(predictor.predict_binding_affinity("wt", target_seq="...", binder_str="..."))
941
 
942
- # Test Embedding #
943
- """
944
- device = torch.device("cuda:0")
945
-
946
- wt = WTEmbedder(device)
947
- sm = SMILESEmbedder(device,
948
- vocab_path="./tokeizner/new_vocab.txt",
949
- splits_path="./tokenizer/new_splits.txt"
950
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
951
 
952
- p = wt.pooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,1280)
953
- X, M = wt.unpooled("GIGAVLKVLTTGLPALISWIKRKRQQ") # (1,Li,1280), (1,Li)
954
-
955
- p2 = sm.pooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,H_smiles)
956
- X2, M2 = sm.unpooled("NCC(=O)N[C@H](CS)C(=O)O") # (1,Li,H_smiles), (1,Li)
957
- """
 
 
1
  from __future__ import annotations
 
2
  import csv, re, json
3
  from dataclasses import dataclass
4
  from pathlib import Path
5
  from typing import Dict, Optional, Tuple, Any, List
 
6
  import numpy as np
7
  import torch
8
  import torch.nn as nn
9
  import joblib
10
  import xgboost as xgb
 
11
  from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
12
  from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
13
+ from lightning.pytorch import seed_everything
14
+ seed_everything(1986)
15
 
16
  # -----------------------------
17
  # Manifest
18
  # -----------------------------
19
+
20
# Manifest embedding tag -> folder-name suffix (per the constant's name; the
# consumers live outside this block).
EMB_TAG_TO_FOLDER_SUFFIX = dict(wt="wt", peptideclm="smiles", chemberta="chemberta")

# Manifest embedding tag -> runtime mode. Kept as a separate dict object so the
# two tables can diverge (or be mutated) independently.
EMB_TAG_TO_RUNTIME_MODE = dict(wt="wt", peptideclm="smiles", chemberta="chemberta")

# Canonical model names grouped by family.
MAPIE_REGRESSION_MODELS = {"enet_gpu", "svr"}
DNN_ARCHS = {"cnn", "mlp", "transformer"}
XGB_MODELS = {"xgb", "xgb_reg", "xgb_smiles", "xgb_wt_log"}
35
+
36
+
37
@dataclass(frozen=True)
class BestRow:
    """Immutable record for one manifest row (one property)."""

    # Normalized property identifier.
    property_key: str
    # (canonical model name, embedding tag or None); None when the cell is empty.
    best_wt: Optional[Tuple[str, Optional[str]]]
    best_smiles: Optional[Tuple[str, Optional[str]]]
    # Task type string as read from the manifest's "Type" column.
    task_type: str
    # Optional decision thresholds for the WT and SMILES models.
    thr_wt: Optional[float]
    thr_smiles: Optional[float]
45
 
46
 
 
49
 
50
def _none_if_dash(s: str) -> Optional[str]:
    """Return the cleaned cell text, or ``None`` when the cell is a placeholder.

    Placeholders are the empty string, a hyphen, an en dash, an em dash,
    ``"NA"`` and ``"N/A"``.
    """
    s = _clean(s)
    # The previous set listed "-" twice, which points to a Unicode dash that was
    # lost along the way; accept the common dash variants explicitly (escaped so
    # they cannot be silently normalized again).
    placeholders = {"", "-", "\u2013", "\u2014", "NA", "N/A"}
    return None if s in placeholders else s
 
 
53
 
54
def _float_or_none(s: str) -> Optional[float]:
    """Parse a manifest cell as a float, or return ``None`` for a placeholder.

    Placeholders are the empty string, a hyphen, an en dash, an em dash,
    ``"NA"`` and ``"N/A"``.

    Raises:
        ValueError: if the cleaned text is neither a placeholder nor a number.
    """
    s = _clean(s)
    # The previous set listed "-" twice (a lost Unicode dash); without the
    # en/em dash entries such a cell would reach float() and raise ValueError.
    if s in {"", "-", "\u2013", "\u2014", "NA", "N/A"}:
        return None
    return float(s)
 
 
57
 
58
  def normalize_property_key(name: str) -> str:
59
  n = name.strip().lower()
60
  n = re.sub(r"\s*\(.*?\)\s*", "", n)
61
  n = n.replace("-", "_").replace(" ", "_")
 
62
  if "permeability" in n and "pampa" not in n and "caco" not in n:
63
  return "permeability_penetrance"
64
  if n == "binding_affinity":
 
70
  return n
71
 
72
 
73
# Upper-cased manifest model label -> canonical model name.
# Most labels simply lower-case; the two exceptions are patched in afterwards
# (updating an existing key keeps its original position in the dict).
MODEL_ALIAS = {
    label: label.lower()
    for label in (
        "SVM",
        "SVR",
        "ENET",
        "CNN",
        "MLP",
        "TRANSFORMER",
        "XGB",
        "XGB_REG",
        "POOLED",
        "UNPOOLED",
        "TRANSFORMER_WT_LOG",
    )
}
MODEL_ALIAS.update(SVM="svm_gpu", ENET="enet_gpu")
86
+
87
def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]:
    """Split a manifest cell into ``(canonical_model, emb_tag)``.

    A cell of the form ``"<model> (<tag>)"`` yields the lower-cased tag; a cell
    without parentheses yields ``None`` as the tag. The model label is mapped
    through ``MODEL_ALIAS`` (case-insensitively) and falls back to its
    lower-cased form. Returns ``None`` for ``None`` or placeholder cells.
    """
    if raw is None:
        return None

    text = _clean(raw)
    if not text or text in {"-", "NA", "N/A"}:
        return None

    match = re.match(r"^(.+?)\s*\((.+?)\)\s*$", text)
    if match is None:
        model_raw, emb_tag = text, None
    else:
        model_raw = match.group(1).strip()
        emb_tag = match.group(2).strip().lower()

    return MODEL_ALIAS.get(model_raw.upper(), model_raw.lower()), emb_tag
104
+
105
+
106
  def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
 
 
 
 
107
  p = Path(path)
108
  out: Dict[str, BestRow] = {}
109
 
 
129
  continue
130
  prop_key = normalize_property_key(prop_raw)
131
 
132
+ best_wt = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", "")))
133
+ best_smiles = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", "")))
134
+
135
  row = BestRow(
136
  property_key=prop_key,
137
+ best_wt=best_wt,
138
+ best_smiles=best_smiles,
139
  task_type=_clean(rec.get("Type", "Classifier")),
140
  thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
141
  thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
 
145
  return out
146
 
147
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
148
  # -----------------------------
149
  # Generic artifact loading
150
  # -----------------------------
151
def find_best_artifact(model_dir: Path) -> Path:
    """Locate the model artifact inside ``model_dir``.

    File-name patterns are tried in priority order; within one pattern the
    lexicographically first match wins. If none match, ``seed_1986/model.pt``
    is used as a last resort.

    Raises:
        FileNotFoundError: if no artifact is found.
    """
    patterns = (
        "best_model.json",
        "best_model.pt",
        "best_model*.joblib",
        "model.json",
        "model.ubj",
        "final_model.json",
    )
    for pattern in patterns:
        matches = sorted(model_dir.glob(pattern))
        if matches:
            return matches[0]

    fallback = model_dir / "seed_1986" / "model.pt"
    if not fallback.exists():
        raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
    return fallback
161
 
162
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the best artifact found in ``model_dir``.

    Returns:
        ``(kind, obj, path)`` where ``kind`` is ``"xgb"`` (an ``xgb.Booster``),
        ``"joblib"`` (whatever was pickled) or ``"torch_ckpt"`` (the raw
        checkpoint object), and ``path`` is the file that was loaded.

    Raises:
        FileNotFoundError: propagated from ``find_best_artifact``.
        ValueError: if the artifact has an unsupported suffix.
    """
    art = find_best_artifact(model_dir)
    suffix = art.suffix

    # find_best_artifact may return "model.ubj"; XGBoost's load_model handles
    # both JSON and UBJSON, so route both suffixes to the booster loader.
    if suffix in {".json", ".ubj"}:
        booster = xgb.Booster()
        booster.load_model(str(art))
        return "xgb", booster, art

    # NOTE(security): joblib.load and torch.load(weights_only=False) unpickle
    # arbitrary objects; only point this at trusted model directories.
    if suffix == ".joblib":
        return "joblib", joblib.load(art), art

    if suffix == ".pt":
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art

    raise ValueError(f"Unknown artifact type: {art}")
175
 
176
 
 
178
  # NN architectures
179
  # -----------------------------
180
class MaskedMeanPool(nn.Module):
    """Mean over dim 1 of ``X``, counting only positions where ``M`` is set."""

    def forward(self, X, M):
        weights = M.unsqueeze(-1).float()
        total = (X * weights).sum(dim=1)
        # Clamp so an all-masked row divides by 1 instead of 0.
        count = weights.sum(dim=1).clamp(min=1.0)
        return total / count
 
188
  super().__init__()
189
  self.pool = MaskedMeanPool()
190
  self.net = nn.Sequential(
191
+ nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout),
 
 
192
  nn.Linear(hidden, 1),
193
  )
194
  def forward(self, X, M):
195
+ return self.net(self.pool(X, M)).squeeze(-1)
 
196
 
197
class CNNHead(nn.Module):
    """Stack of Conv1d/GELU/Dropout layers, masked mean pooling, linear head."""

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        stack = []
        for layer_idx in range(layers):
            # Only the first convolution sees the raw input width.
            width_in = in_ch if layer_idx == 0 else c
            stack.extend((
                nn.Conv1d(width_in, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ))
        self.conv = nn.Sequential(*stack)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d convolves over the last axis, so swap dims 1 and 2 around it.
        features = self.conv(X.transpose(1, 2)).transpose(1, 2)
        weights = M.unsqueeze(-1).float()
        summed = (features * weights).sum(dim=1)
        pooled = summed / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
211
 
212
  class TransformerHead(nn.Module):
 
219
  )
220
  self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
221
  self.head = nn.Linear(d_model, 1)
 
222
  def forward(self, X, M):
223
+ Z = self.enc(self.proj(X), src_key_padding_mask=~M)
 
 
224
  Mf = M.unsqueeze(-1).float()
225
+ pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
 
226
  return self.head(pooled).squeeze(-1)
227
 
228
  def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
229
+ if model_name == "mlp": return int(sd["net.0.weight"].shape[1])
230
+ if model_name == "cnn": return int(sd["conv.0.weight"].shape[1])
231
+ if model_name == "transformer": return int(sd["proj.weight"].shape[1])
 
 
 
232
  raise ValueError(model_name)
233
 
234
+ def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
235
+ idxs = set()
236
+ for k in sd.keys():
237
+ if k.startswith(prefix):
238
+ m = re.match(r"(\d+)\.", k[len(prefix):])
239
+ if m:
240
+ idxs.add(int(m.group(1)))
241
+ return (max(idxs) + 1) if idxs else 1
242
+
243
def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
    """Recover ``(d_model, layers, ff)`` for a transformer head from its weights.

    ``d_model`` is axis 0 of ``proj.weight``; ``ff`` is axis 0 of the first
    encoder layer's ``linear1`` weight, or ``4 * d_model`` when that key is
    absent.

    Raises:
        KeyError: if ``proj.weight`` is missing.
    """
    if "proj.weight" not in sd:
        raise KeyError("Missing proj.weight in state_dict")

    d_model = int(sd["proj.weight"].shape[0])

    ff_key = "enc.layers.0.linear1.weight"
    if ff_key in sd:
        ff = int(sd[ff_key].shape[0])
    else:
        ff = 4 * d_model

    layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
    return d_model, layers, ff
250
+
251
+ def _pick_nhead(d_model: int) -> int:
252
+ for h in (8, 6, 4, 3, 2, 1):
253
+ if d_model % h == 0:
254
+ return h
255
+ return 1
256
+
257
  def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
258
  params = ckpt["best_params"]
259
+ sd = ckpt["state_dict"]
260
  in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
261
  dropout = float(params.get("dropout", 0.1))
262
 
 
266
  model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
267
  layers=int(params["layers"]), dropout=dropout)
268
  elif model_name == "transformer":
269
+ d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
 
 
 
 
270
  if d_model is None:
271
+ d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
272
+ nhead_i = _pick_nhead(d_model_i)
273
+ model = TransformerHead(
274
+ in_dim=in_dim, d_model=int(d_model_i), nhead=int(params.get("nhead", nhead_i)),
275
+ layers=int(params.get("layers", layers_i)), ff=int(params.get("ff", ff_i)),
276
+ dropout=float(params.get("dropout", dropout)),
277
+ )
278
+ else:
279
+ d_model = int(d_model)
280
+ model = TransformerHead(
281
+ in_dim=in_dim, d_model=d_model,
282
+ nhead=int(params.get("nhead", _pick_nhead(d_model))),
283
+ layers=int(params.get("layers", 2)),
284
+ ff=int(params.get("ff", 4 * d_model)),
285
+ dropout=dropout,
286
  )
 
 
 
 
 
 
 
 
 
287
  else:
288
  raise ValueError(f"Unknown NN model_name={model_name}")
289
 
290
  model.load_state_dict(sd)
291
+ model.to(device).eval()
 
292
  return model
293
 
294
 
295
  # -----------------------------
296
+ # Wrappers
297
+ # -----------------------------
298
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
299
+
300
class PassthroughRegressor(BaseEstimator, RegressorMixin):
    """Estimator shim that replays precomputed predictions instead of fitting."""

    def __init__(self, preds: np.ndarray):
        self.preds = preds

    def fit(self, X, y):
        # Nothing to learn: the predictions were supplied up front.
        return self

    def predict(self, X):
        n_rows = len(X)
        return self.preds[:n_rows]
305
+
306
class PassthroughClassifier(BaseEstimator, ClassifierMixin):
    """Binary-classifier shim that replays precomputed positive-class scores."""

    def __init__(self, preds: np.ndarray):
        self.preds = preds
        self.classes_ = np.array([0, 1])

    def fit(self, X, y):
        # Nothing to learn: the scores were supplied up front.
        return self

    def predict(self, X):
        scores = self.preds[:len(X)]
        return (scores >= 0.5).astype(int)

    def predict_proba(self, X):
        positive = self.preds[:len(X)]
        negative = 1 - positive
        return np.stack([negative, positive], axis=1)
315
+
316
+
317
+ # -----------------------------
318
+ # Uncertainty helpers
319
  # -----------------------------
320
SEED_DIRS = ["seed_1986", "seed_42", "seed_0", "seed_123", "seed_12345"]


def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]:
    """Build one model per ``<model_dir>/<seed>/model.pt`` that exists.

    Seed folders are visited in ``SEED_DIRS`` order and missing ones are
    skipped, so the result may be empty.
    """
    checkpoint_paths = (model_dir / seed_name / "model.pt" for seed_name in SEED_DIRS)
    return [
        build_torch_model_from_ckpt(
            arch,
            torch.load(path, map_location=device, weights_only=False),
            device,
        )
        for path in checkpoint_paths
        if path.exists()
    ]
331
+
332
+ def _binary_entropy(p: float) -> float:
333
+ p = float(np.clip(p, 1e-9, 1 - 1e-9))
334
+ return float(-p * np.log(p) - (1 - p) * np.log(1 - p))
335
+
336
def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
    """Binary entropy of the ensemble's mean sigmoid probability for one input."""
    with torch.no_grad():
        logits = [member(X, M).squeeze().float().cpu().item() for member in ensemble]
    # Sigmoid each member's logit, then average in probability space.
    probabilities = [1.0 / (1.0 + np.exp(-z)) for z in logits]
    return _binary_entropy(float(np.mean(probabilities)))
343
+
344
+ def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
345
+ preds = []
346
+ with torch.no_grad():
347
+ for m in ensemble:
348
+ preds.append(m(X, M).squeeze().float().cpu().item())
349
+ return float(np.std(preds))
350
+
351
+ def _mapie_uncertainty(mapie_bundle: dict, score: float,
352
+ embedding: Optional[np.ndarray] = None) -> Tuple[float, float]:
353
+ """
354
+ Returns (ci_low, ci_high) from a conformal bundle.
355
+ - adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True}
356
+ Input-dependent: interval = score +/- q * sigma(embedding)
357
+ - plain_quantile: {"quantile": q, "alpha": ...}
358
+ Fixed-width: interval = score +/- q
359
+ """
360
+ # Adaptive format is input-dependent interval
361
+ if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
362
+ q = float(mapie_bundle["quantile"])
363
+ if embedding is not None:
364
+ # Adaptive interval: y_hat ± q * sigma_hat(x).
365
+ # Equivalent to MAPIE's get_estimation_distribution():
366
+ # y_pred + conformity_scores * r_pred
367
+ # where conformity_scores=q and r_pred=sigma_hat(x).
368
+ # (ResidualNormalisedScore, Cordier et al. 2023)
369
+ sigma_model = mapie_bundle["sigma_model"]
370
+ sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
371
+ sigma = max(sigma, 1e-6)
372
+ else:
373
+ # No embedding available - fall back to fixed interval with sigma=1
374
+ sigma = 1.0
375
+ return float(score - q * sigma), float(score + q * sigma)
376
+
377
+ # Plain quantile format
378
+ if "quantile" in mapie_bundle:
379
+ q = float(mapie_bundle["quantile"])
380
+ return float(score - q), float(score + q)
381
+
382
+ X_dummy = np.zeros((1, 1))
383
+ result = mapie.predict(X_dummy)
384
+ if isinstance(result, tuple):
385
+ intervals = np.asarray(result[1])
386
+ if intervals.ndim == 3:
387
+ return float(intervals[0, 0, 0]), float(intervals[0, 1, 0])
388
+ return float(intervals[0, 0]), float(intervals[0, 1])
389
+ raise RuntimeError(
390
+ f"Cannot extract intervals: unknown MAPIE bundle format. "
391
+ f"Bundle keys: {list(mapie_bundle.keys())}."
392
+ )
393
+
394
def affinity_to_class(y: float) -> int:
    """Bucket an affinity value: 0 if ``y >= 9``, 2 if ``y < 7``, otherwise 1."""
    if y < 7.0:
        return 2
    return 0 if y >= 9.0 else 1
 
401
  super().__init__()
402
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
403
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
 
404
  self.layers = nn.ModuleList([])
405
  for _ in range(n_layers):
406
  self.layers.append(nn.ModuleDict({
407
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
408
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
409
+ "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
410
+ "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
 
 
411
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
412
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
413
  }))
 
414
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
415
  self.reg = nn.Linear(hidden, 1)
416
  self.cls = nn.Linear(hidden, 3)
417
 
418
  def forward(self, t_vec, b_vec):
419
+ t = self.t_proj(t_vec).unsqueeze(0)
420
+ b = self.b_proj(b_vec).unsqueeze(0)
421
  for L in self.layers:
422
  t_attn, _ = L["attn_tb"](t, b, b)
423
  t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
424
  t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
 
425
  b_attn, _ = L["attn_bt"](b, t, t)
426
  b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
427
  b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
428
+ h = self.shared(torch.cat([t[0], b[0]], dim=-1))
 
 
429
  return self.reg(h).squeeze(-1), self.cls(h)
430
 
431
  class CrossAttnUnpooled(nn.Module):
 
433
  super().__init__()
434
  self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
435
  self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
 
436
  self.layers = nn.ModuleList([])
437
  for _ in range(n_layers):
438
  self.layers.append(nn.ModuleDict({
439
  "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
440
  "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
441
+ "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
442
+ "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
 
 
443
  "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
444
  "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
445
  }))
 
446
  self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
447
  self.reg = nn.Linear(hidden, 1)
448
  self.cls = nn.Linear(hidden, 3)
449
 
450
  def _masked_mean(self, X, M):
451
  Mf = M.unsqueeze(-1).float()
452
+ return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
 
453
 
454
  def forward(self, T, Mt, B, Mb):
455
+ T = self.t_proj(T); Bx = self.b_proj(B)
456
+ kp_t, kp_b = ~Mt, ~Mb
 
 
 
457
  for L in self.layers:
458
  T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
459
+ T = L["n1t"](T + T_attn); T = L["n2t"](T + L["fft"](T))
 
 
460
  B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
461
+ Bx = L["n1b"](Bx + B_attn); Bx = L["n2b"](Bx + L["ffb"](Bx))
462
+ h = self.shared(torch.cat([self._masked_mean(T, Mt), self._masked_mean(Bx, Mb)], dim=-1))
 
 
 
 
 
463
  return self.reg(h).squeeze(-1), self.cls(h)
464
 
465
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """Rebuild a cross-attention binding model from a checkpoint file.

    The checkpoint must provide ``"best_params"`` and ``"state_dict"``. Input
    widths are read from axis 1 of the two projection weights; ``"pooled"``
    selects ``CrossAttnPooled`` and any other value selects
    ``CrossAttnUnpooled``. The model is returned on ``device`` in eval mode.
    """
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params, sd = ckpt["best_params"], ckpt["state_dict"]

    if pooled_or_unpooled == "pooled":
        arch = CrossAttnPooled
    else:
        arch = CrossAttnUnpooled

    model = arch(
        Ht=int(sd["t_proj.0.weight"].shape[1]),
        Hb=int(sd["b_proj.0.weight"].shape[1]),
        hidden=int(params["hidden_dim"]),
        n_heads=int(params["n_heads"]),
        n_layers=int(params["n_layers"]),
        dropout=float(params["dropout"]),
    )
    model.load_state_dict(sd)
    return model.to(device).eval()
 
478
 
479
 
480
  # -----------------------------
481
  # Embedding generation
482
  # -----------------------------
483
  def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
 
 
 
484
  if hasattr(torch, "isin"):
485
  return torch.isin(ids, test_ids)
 
 
486
  return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
487
+
488
class SMILESEmbedder:
    """Embed SMILES strings with a pretrained masked-LM backbone.

    ``pooled`` returns one vector per string (mean over non-special tokens);
    ``unpooled`` returns the per-token states plus an all-true mask. Results
    are memoised per stripped input string when ``use_cache`` is set.
    """

    def __init__(self, device, vocab_path, splits_path,
                 clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache

        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        # Only the encoder (`.roformer`) of the masked-LM checkpoint is kept.
        masked_lm = AutoModelForMaskedLM.from_pretrained(clm_name)
        self.model = masked_lm.roformer.to(device).eval()

        self.special_ids = self._get_special_ids(self.tokenizer)
        if self.special_ids:
            self.special_ids_t = torch.tensor(self.special_ids, device=device, dtype=torch.long)
        else:
            self.special_ids_t = None

        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Collect the sorted, de-duplicated special-token ids the tokenizer defines."""
        found = set()
        for role in ("pad", "cls", "sep", "bos", "eos", "mask"):
            token_id = getattr(tokenizer, f"{role}_token_id", None)
            if token_id is not None:
                found.add(int(token_id))
        return sorted(found)

    def _tokenize(self, smiles_list):
        """Tokenize a batch, move it to the device, and ensure an attention mask."""
        tok = self.tokenizer(
            smiles_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        for key in tok:
            tok[key] = tok[key].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        """Attended positions that are not special tokens."""
        valid = attn.bool()
        if self.special_ids_t is None or self.special_ids_t.numel() == 0:
            return valid
        return valid & (~_safe_isin(ids, self.special_ids_t))

    def _hidden_and_valid(self, s):
        """Run the encoder on one string; return (last_hidden_state, valid mask)."""
        tok = self._tokenize([s])
        ids, attn = tok["input_ids"], tok["attention_mask"]
        hidden = self.model(input_ids=ids, attention_mask=attn).last_hidden_state
        return hidden, self._valid_mask(ids, attn)

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        """Mean of the hidden states over valid tokens, keeping the batch axis."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_pooled:
            return self._cache_pooled[s]

        hidden, valid = self._hidden_and_valid(s)
        weights = valid.unsqueeze(-1).float()
        # Clamp guards the division when no token survives the mask.
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        if self.use_cache:
            self._cache_pooled[s] = pooled
        return pooled

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Hidden states of the valid tokens and a matching all-true mask."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_unpooled:
            return self._cache_unpooled[s]

        hidden, valid = self._hidden_and_valid(s)
        # Single-item batch: select the kept positions with row 0 of the mask.
        X = hidden[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)

        if self.use_cache:
            self._cache_unpooled[s] = (X, M)
        return X, M
545
 
 
 
546
 
547
class ChemBERTaEmbedder:
    """SMILES embedder backed by a Hugging Face ``AutoModel`` checkpoint.

    Embeds one SMILES string at a time, either as a masked mean-pooled vector
    (``pooled``) or as per-token hidden states (``unpooled``). Results are
    optionally memoised per stripped input string.
    """

    def __init__(self, device, model_name="DeepChem/ChemBERTa-77M-MLM",
                 max_len=512, use_cache=True):
        """Load tokenizer and model and prepare the special-token id tensor.

        Args:
            device: torch device the model and tokenized inputs are moved to.
            model_name: identifier passed to ``from_pretrained``.
            max_len: ``max_length`` used when tokenizing (with truncation).
            use_cache: when true, embeddings are cached per input string.
        """
        from transformers import AutoModel, AutoTokenizer

        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        backbone = AutoModel.from_pretrained(model_name)
        self.model = backbone.to(device).eval()

        self.special_ids = self._get_special_ids(self.tokenizer)
        if self.special_ids:
            self.special_ids_t = torch.tensor(
                self.special_ids, device=device, dtype=torch.long
            )
        else:
            self.special_ids_t = None

        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Collect the tokenizer's defined special-token ids, sorted and de-duplicated."""
        found = set()
        for kind in ("pad", "cls", "sep", "bos", "eos", "mask"):
            token_id = getattr(tokenizer, f"{kind}_token_id", None)
            if token_id is not None:
                found.add(int(token_id))
        return sorted(found)

    def _tokenize(self, smiles_list):
        """Tokenize a list of SMILES and move every returned tensor to ``self.device``."""
        enc = self.tokenizer(
            smiles_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        for name in enc:
            enc[name] = enc[name].to(self.device)
        if "attention_mask" not in enc:
            # Tokenizer gave no mask: treat every position as attended.
            enc["attention_mask"] = torch.ones_like(
                enc["input_ids"], dtype=torch.long, device=self.device
            )
        return enc

    def _valid_mask(self, ids, attn):
        """Boolean mask of attended positions that are not special tokens."""
        specials = self.special_ids_t
        if specials is None or specials.numel() == 0:
            return attn.bool()
        return attn.bool() & (~_safe_isin(ids, specials))

    def _encode(self, text):
        """Run the model on one string; return (last hidden state, valid mask)."""
        enc = self._tokenize([text])
        ids, attn = enc["input_ids"], enc["attention_mask"]
        hidden = self.model(input_ids=ids, attention_mask=attn).last_hidden_state
        return hidden, self._valid_mask(ids, attn)

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        """Return the mean of the hidden states over valid positions.

        The divisor is clamped to 1e-9 so an input with no valid positions
        does not divide by zero.
        """
        key = smiles.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]

        hidden, keep = self._encode(key)
        weights = keep.unsqueeze(-1).float()
        mean_vec = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        if self.use_cache:
            self._cache_pooled[key] = mean_vec
        return mean_vec

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return hidden states at valid positions and an all-True mask for them."""
        key = smiles.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]

        hidden, keep = self._encode(key)
        X = hidden[:, keep[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)

        result = (X, M)
        if self.use_cache:
            self._cache_unpooled[key] = result
        return result
605
 
606
 
607
class WTEmbedder:
    """Amino-acid sequence embedder backed by an ESM checkpoint.

    Embeds one sequence at a time, either as a masked mean-pooled vector
    (``pooled``) or as per-token hidden states (``unpooled``). Results are
    optionally memoised per stripped input string.
    """

    def __init__(self, device, esm_name="facebook/esm2_t33_650M_UR50D", max_len=1022, use_cache=True):
        """Load the ESM tokenizer/model and prepare the special-token id tensor.

        Args:
            device: torch device the model and tokenized inputs are moved to.
            esm_name: identifier passed to ``from_pretrained``.
            max_len: ``max_length`` used when tokenizing (with truncation).
            use_cache: when true, embeddings are cached per input string.
        """
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache

        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        backbone = EsmModel.from_pretrained(esm_name, add_pooling_layer=False)
        self.model = backbone.to(device).eval()

        self.special_ids = self._get_special_ids(self.tokenizer)
        if self.special_ids:
            self.special_ids_t = torch.tensor(
                self.special_ids, device=device, dtype=torch.long
            )
        else:
            self.special_ids_t = None

        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        """Collect the tokenizer's defined special-token ids, sorted and de-duplicated."""
        found = set()
        for kind in ("pad", "cls", "sep", "bos", "eos", "mask"):
            token_id = getattr(tokenizer, f"{kind}_token_id", None)
            if token_id is not None:
                found.add(int(token_id))
        return sorted(found)

    def _tokenize(self, seq_list):
        """Tokenize a list of sequences into a plain dict of tensors on ``self.device``."""
        raw = self.tokenizer(
            seq_list,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_len,
        )
        enc = {name: tensor.to(self.device) for name, tensor in raw.items()}
        if "attention_mask" not in enc:
            # Tokenizer gave no mask: treat every position as attended.
            enc["attention_mask"] = torch.ones_like(
                enc["input_ids"], dtype=torch.long, device=self.device
            )
        return enc

    def _valid_mask(self, ids, attn):
        """Boolean mask of attended positions that are not special tokens."""
        specials = self.special_ids_t
        if specials is None or specials.numel() == 0:
            return attn.bool()
        return attn.bool() & (~_safe_isin(ids, specials))

    def _encode(self, text):
        """Run the model on one string; return (last hidden state, valid mask)."""
        enc = self._tokenize([text])
        hidden = self.model(**enc).last_hidden_state
        return hidden, self._valid_mask(enc["input_ids"], enc["attention_mask"])

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        """Return the mean of the hidden states over valid positions.

        The divisor is clamped to 1e-9 so an input with no valid positions
        does not divide by zero.
        """
        key = seq.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]

        hidden, keep = self._encode(key)
        weights = keep.unsqueeze(-1).float()
        mean_vec = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)

        if self.use_cache:
            self._cache_pooled[key] = mean_vec
        return mean_vec

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return hidden states at valid positions and an all-True mask for them."""
        key = seq.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]

        hidden, keep = self._encode(key)
        X = hidden[:, keep[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)

        result = (X, M)
        if self.use_cache:
            self._cache_unpooled[key] = result
        return result
663
 
 
 
 
 
 
 
 
 
 
 
 
664
 
665
  # -----------------------------
666
  # Predictor
667
  # -----------------------------
668
+
669
  class PeptiVersePredictor:
 
 
 
 
 
670
  def __init__(
671
  self,
672
  manifest_path: str | Path,
673
  classifier_weight_root: str | Path,
674
  esm_name="facebook/esm2_t33_650M_UR50D",
675
  clm_name="aaronfeller/PeptideCLM-23M-all",
676
+ chemberta_name="DeepChem/ChemBERTa-77M-MLM",
677
  smiles_vocab="tokenizer/new_vocab.txt",
678
  smiles_splits="tokenizer/new_splits.txt",
679
  device: Optional[str] = None,
 
684
 
685
  self.manifest = read_best_manifest_csv(manifest_path)
686
 
687
+ self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
688
+ self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
689
+ vocab_path=str(self.root / smiles_vocab),
690
+ splits_path=str(self.root / smiles_splits))
691
+ self.chemberta_embedder = ChemBERTaEmbedder(self.device, model_name=chemberta_name)
692
 
693
+ self.models: Dict[Tuple[str, str], Any] = {}
694
+ self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
695
+ self.mapie: Dict[Tuple[str, str], dict] = {}
696
+ self.ensembles: Dict[Tuple[str, str], List] = {}
697
 
698
  self._load_all_best_models()
699
 
700
+ def _get_embedder(self, emb_tag: str):
701
+ if emb_tag == "wt": return self.wt_embedder
702
+ if emb_tag == "peptideclm": return self.smiles_embedder
703
+ if emb_tag == "chemberta": return self.chemberta_embedder
704
+ raise ValueError(f"Unknown emb_tag={emb_tag!r}")
705
+
706
+ def _embed_pooled(self, emb_tag: str, input_str: str) -> np.ndarray:
707
+ v = self._get_embedder(emb_tag).pooled(input_str)
708
+ feats = v.detach().cpu().numpy().astype(np.float32)
709
+ feats = np.nan_to_num(feats, nan=0.0)
710
+ return np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
711
+
712
+ def _embed_unpooled(self, emb_tag: str, input_str: str) -> Tuple[torch.Tensor, torch.Tensor]:
713
+ return self._get_embedder(emb_tag).unpooled(input_str)
714
+
715
+ def _resolve_dir(self, prop_key: str, model_name: str, emb_tag: str) -> Path:
716
  disk_prop = "half_life" if prop_key == "halflife" else prop_key
717
  base = self.training_root / disk_prop
718
 
719
+ folder_suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag)
 
 
 
 
720
 
721
+ if prop_key == "halflife" and emb_tag == "wt":
722
+ if model_name == "transformer":
723
+ for d in [base / "transformer_wt_log", base / "transformer_wt"]:
724
+ if d.exists(): return d
725
+ if model_name in {"xgb", "xgb_reg"}:
726
+ d = base / "xgb_wt_log"
727
+ if d.exists(): return d
728
 
729
  candidates = [
730
+ base / f"{model_name}_{folder_suffix}",
731
  base / model_name,
732
  ]
 
 
 
 
 
733
  for d in candidates:
734
+ if d.exists(): return d
 
735
 
736
  raise FileNotFoundError(
737
+ f"Cannot find model dir for {prop_key}/{model_name}/{emb_tag}. Tried: {candidates}"
738
  )
739
 
 
740
  def _load_all_best_models(self):
741
  for prop_key, row in self.manifest.items():
742
+ for col, parsed, thr in [
743
+ ("wt", row.best_wt, row.thr_wt),
744
+ ("smiles", row.best_smiles, row.thr_smiles),
745
  ]:
746
+ if parsed is None:
 
747
  continue
748
+ model_name, emb_tag = parsed
749
 
750
+ # binding affinity
751
  if prop_key == "binding_affinity":
752
+ folder = model_name
753
+ pooled_or_unpooled = "unpooled" if "unpooled" in folder else "pooled"
 
754
  model_dir = self.training_root / "binding_affinity" / folder
755
  art = find_best_artifact(model_dir)
756
+ model = load_binding_model(art, pooled_or_unpooled, self.device)
757
+ self.models[(prop_key, col)] = model
758
+ self.meta[(prop_key, col)] = {
759
+ "task_type": "Regression",
760
+ "threshold": None,
761
+ "artifact": str(art),
762
+ "model_name": pooled_or_unpooled,
763
+ "emb_tag": emb_tag,
764
+ "folder": folder,
765
+ "kind": "binding",
766
  }
767
+ print(f" [LOAD] binding_affinity ({col}): folder={folder}, arch={pooled_or_unpooled}, emb_tag={emb_tag}, art={art.name}")
768
+ mapie_path = model_dir / "mapie_calibration.joblib"
769
+ if mapie_path.exists():
770
+ try:
771
+ self.mapie[(prop_key, col)] = joblib.load(mapie_path)
772
+ print(f" MAPIE loaded from {mapie_path.name}")
773
+ except Exception as e:
774
+ print(f" MAPIE load FAILED for ({prop_key}, {col}): {e}")
775
+ else:
776
+ print(f" No MAPIE bundle found (uncertainty will be unavailable)")
777
  continue
778
 
779
+ # infer emb_tag
780
+ if emb_tag is None:
781
+ emb_tag = col
782
+
783
+ model_dir = self._resolve_dir(prop_key, model_name, emb_tag)
784
  kind, obj, art = load_artifact(model_dir, self.device)
785
 
786
+ if kind == "torch_ckpt":
787
+ arch = self._base_arch(model_name)
788
+ model = build_torch_model_from_ckpt(arch, obj, self.device)
789
  else:
790
+ model = obj
791
+
792
+ self.models[(prop_key, col)] = model
793
+ self.meta[(prop_key, col)] = {
794
+ "task_type": row.task_type,
795
+ "threshold": thr,
796
+ "artifact": str(art),
797
+ "model_name": model_name,
798
+ "emb_tag": emb_tag,
799
+ "kind": kind,
800
+ }
801
+
802
+ print(f" [LOAD] ({prop_key}, {col}): kind={kind}, model={model_name}, emb={emb_tag}, task={row.task_type}, art={art.name}")
803
+
804
+ # MAPIE: SVR/ElasticNet, XGBoost regression, AND all regression torch_ckpt
805
+ is_regression = row.task_type.lower() == "regression"
806
+ wants_mapie = (
807
+ (model_name in MAPIE_REGRESSION_MODELS and is_regression)
808
+ or (kind == "xgb" and is_regression)
809
+ or (kind == "torch_ckpt" and is_regression)
810
+ )
811
+ if wants_mapie:
812
+ mapie_path = model_dir / "mapie_calibration.joblib"
813
+ if mapie_path.exists():
814
+ try:
815
+ self.mapie[(prop_key, col)] = joblib.load(mapie_path)
816
+ print(f" MAPIE loaded from {mapie_path.name}")
817
+ except Exception as e:
818
+ print(f" MAPIE load FAILED for ({prop_key}, {col}): {e}")
819
+ else:
820
+ print(f" No MAPIE bundle found at {mapie_path} (will fall back to ensemble if available)")
821
+
822
+ # Seed ensembles: DNN only, used when MAPIE not available
823
+ if kind == "torch_ckpt":
824
+ arch = self._base_arch(model_name)
825
+ ens = load_seed_ensemble(model_dir, arch, self.device)
826
+ if ens:
827
+ self.ensembles[(prop_key, col)] = ens
828
+ if (prop_key, col) in self.mapie:
829
+ print(f" Seed ensemble: {len(ens)} seeds loaded (MAPIE takes priority for regression)")
830
+ else:
831
+ unc_type = "ensemble_predictive_entropy" if row.task_type.lower() == "classifier" else "ensemble_std"
832
+ print(f" Seed ensemble: {len(ens)} seeds loaded uncertainty method: {unc_type}")
833
+ else:
834
+ if (prop_key, col) in self.mapie:
835
+ print(f" No seed ensemble (MAPIE covers uncertainty)")
836
+ else:
837
+ print(f" No seed ensemble found (checked: {SEED_DIRS}) - uncertainty unavailable")
838
 
839
+ # XGBoost/SVM classifiers: binary entropy
840
+ if kind in ("xgb", "joblib") and row.task_type.lower() == "classifier":
841
+ print(f" Uncertainty method: binary_predictive_entropy (computed at inference)")
 
 
 
 
 
 
 
842
 
843
+ @staticmethod
844
+ def _base_arch(model_name: str) -> str:
845
+ if model_name.startswith("transformer"): return "transformer"
846
+ if model_name.startswith("mlp"): return "mlp"
847
+ if model_name.startswith("cnn"): return "cnn"
848
+ return model_name
849
+
850
+ # Feature extraction
851
+ def _get_features(self, prop_key: str, col: str, input_str: str):
852
+ meta = self.meta[(prop_key, col)]
853
+ emb_tag = meta["emb_tag"]
854
+ kind = meta["kind"]
855
  if kind == "torch_ckpt":
856
+ return self._embed_unpooled(emb_tag, input_str)
857
+ return self._embed_pooled(emb_tag, input_str)
858
+
859
+ # Uncertainty
860
+ def _compute_uncertainty(self, prop_key: str, col: str, input_str: str,
861
+ score: float) -> Tuple[Any, str]:
862
+ meta = self.meta[(prop_key, col)]
863
+ kind = meta["kind"]
864
+ model_name = meta["model_name"]
865
+ task_type = meta["task_type"].lower()
866
+ emb_tag = meta["emb_tag"]
867
+
868
+ # Pooled embedding for adaptive MAPIE sigma model
869
+ def get_pooled_emb():
870
+ return self._embed_pooled(emb_tag, input_str) if emb_tag else None
871
+
872
+ # DNN
873
+ if kind == "torch_ckpt":
874
+ # Regression: prefer MAPIE if available
875
+ if task_type == "regression":
876
+ mapie_bundle = self.mapie.get((prop_key, col))
877
+ if mapie_bundle:
878
+ emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
879
+ lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
880
+ return (lo, hi), "conformal_prediction_interval"
881
+ # Fall back to seed ensemble std
882
+ ens = self.ensembles.get((prop_key, col))
883
+ if ens:
884
+ X, M = self._embed_unpooled(emb_tag, input_str)
885
+ return _ensemble_reg_uncertainty(ens, X, M), "ensemble_std"
886
+ return None, "unavailable (no MAPIE bundle and no seed ensemble)"
887
+ # Classifier: ensemble predictive entropy
888
+ ens = self.ensembles.get((prop_key, col))
889
+ if not ens:
890
+ return None, "unavailable (no seed ensemble found)"
891
+ X, M = self._embed_unpooled(emb_tag, input_str)
892
+ return _ensemble_clf_uncertainty(ens, X, M), "ensemble_predictive_entropy"
893
+
894
+ # XGBoost
895
+ if kind == "xgb":
896
+ if task_type == "classifier":
897
+ return _binary_entropy(score), "binary_predictive_entropy"
898
+ mapie_bundle = self.mapie.get((prop_key, col))
899
+ if mapie_bundle:
900
+ emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
901
+ lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
902
+ return (lo, hi), "conformal_prediction_interval"
903
+ return None, "unavailable (no MAPIE bundle for XGBoost regression)"
904
+
905
+ # SVR / ElasticNet regression: MAPIE
906
+ if kind == "joblib" and model_name in MAPIE_REGRESSION_MODELS and task_type == "regression":
907
+ mapie_bundle = self.mapie.get((prop_key, col))
908
+ if mapie_bundle:
909
+ emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
910
+ lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
911
+ return (lo, hi), "conformal_prediction_interval"
912
+ return None, "unavailable (MAPIE bundle not found)"
913
+
914
+ # joblib classifiers (SVM, ElasticNet used as classifier)
915
+ if kind == "joblib" and task_type == "classifier":
916
+ return _binary_entropy(score), "binary_predictive_entropy_single_model"
917
+
918
+ return None, "unavailable"
919
+
920
+ def predict_property(self, prop_key: str, col: str, input_str: str,
921
+ uncertainty: bool = False) -> Dict[str, Any]:
922
+ if (prop_key, col) not in self.models:
923
+ raise KeyError(f"No model loaded for ({prop_key}, {col}).")
924
+
925
+ meta = self.meta[(prop_key, col)]
926
+ model = self.models[(prop_key, col)]
927
+ task_type = meta["task_type"].lower()
928
+ thr = meta.get("threshold")
929
+ kind = meta["kind"]
930
+ model_name = meta["model_name"]
931
 
932
  if prop_key == "binding_affinity":
933
  raise RuntimeError("Use predict_binding_affinity().")
934
 
935
+ # DNN
936
  if kind == "torch_ckpt":
937
+ X, M = self._get_features(prop_key, col, input_str)
938
  with torch.no_grad():
939
+ raw = model(X, M).squeeze().float().cpu().item()
940
+
941
+ if prop_key == "halflife" and col == "wt" and "log" in model_name:
942
+ raw = float(np.expm1(raw))
943
+
 
 
 
 
944
  if task_type == "classifier":
945
+ score = float(1.0 / (1.0 + np.exp(-raw)))
946
+ out = {"property": prop_key, "col": col, "score": score,
947
+ "emb_tag": meta["emb_tag"]}
948
  if thr is not None:
949
+ out["label"] = int(score >= float(thr)); out["threshold"] = float(thr)
 
 
950
  else:
951
+ out = {"property": prop_key, "col": col, "score": float(raw),
952
+ "emb_tag": meta["emb_tag"]}
953
+
954
+ # XGBoost
955
+ elif kind == "xgb":
956
+ feats = self._get_features(prop_key, col, input_str)
957
+ pred = float(model.predict(xgb.DMatrix(feats))[0])
958
+ if prop_key == "halflife" and col == "wt" and "log" in model_name:
 
 
 
 
 
 
959
  pred = float(np.expm1(pred))
960
+ out = {"property": prop_key, "col": col, "score": pred,
961
+ "emb_tag": meta["emb_tag"]}
962
+ if task_type == "classifier" and thr is not None:
963
+ out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr)
964
+
965
+ # joblib (SVM / ElasticNet / SVR)
966
+ elif kind == "joblib":
967
+ feats = self._get_features(prop_key, col, input_str)
 
968
  if task_type == "classifier":
969
  if hasattr(model, "predict_proba"):
970
  pred = float(model.predict_proba(feats)[:, 1][0])
971
+ elif hasattr(model, "decision_function"):
972
+ pred = float(1.0 / (1.0 + np.exp(-model.decision_function(feats)[0])))
973
  else:
974
+ pred = float(model.predict(feats)[0])
975
+ out = {"property": prop_key, "col": col, "score": pred,
976
+ "emb_tag": meta["emb_tag"]}
 
 
 
977
  if thr is not None:
978
+ out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr)
 
 
979
  else:
980
  pred = float(model.predict(feats)[0])
981
+ out = {"property": prop_key, "col": col, "score": pred,
982
+ "emb_tag": meta["emb_tag"]}
983
+ else:
984
+ raise RuntimeError(f"Unknown kind={kind}")
985
 
986
+ if uncertainty:
987
+ u_val, u_type = self._compute_uncertainty(prop_key, col, input_str, out["score"])
988
+ out["uncertainty"] = u_val
989
+ out["uncertainty_type"] = u_type
 
 
 
 
990
 
991
+ return out
 
992
 
993
+ def predict_binding_affinity(self, col: str, target_seq: str, binder_str: str,
994
+ uncertainty: bool = False) -> Dict[str, Any]:
995
+ prop_key = "binding_affinity"
996
+ if (prop_key, col) not in self.models:
997
+ raise KeyError(f"No binding model loaded for ({prop_key}, {col}).")
998
+
999
+ model = self.models[(prop_key, col)]
1000
+ meta = self.meta[(prop_key, col)]
1001
+ arch = meta["model_name"]
1002
+ emb_tag = meta.get("emb_tag")
1003
+
1004
+ if arch == "pooled":
1005
+ t_vec = self.wt_embedder.pooled(target_seq)
1006
+ b_vec = self._get_embedder(emb_tag or col).pooled(binder_str) if emb_tag else \
1007
+ (self.wt_embedder.pooled(binder_str) if col == "wt" else self.smiles_embedder.pooled(binder_str))
1008
  with torch.no_grad():
1009
  reg, logits = model(t_vec, b_vec)
 
 
 
1010
  else:
1011
  T, Mt = self.wt_embedder.unpooled(target_seq)
1012
+ binder_emb = self._get_embedder(emb_tag or col) if emb_tag else \
1013
+ (self.wt_embedder if col == "wt" else self.smiles_embedder)
1014
+ B, Mb = binder_emb.unpooled(binder_str)
 
1015
  with torch.no_grad():
1016
  reg, logits = model(T, Mt, B, Mb)
1017
+
1018
+ affinity = float(reg.squeeze().cpu().item())
1019
+ cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
1020
+ cls_thr = affinity_to_class(affinity)
1021
+ names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}
1022
+
1023
+ out = {
1024
+ "property": "binding_affinity",
1025
+ "col": col,
1026
+ "affinity": affinity,
1027
  "class_by_threshold": names[cls_thr],
1028
+ "class_by_logits": names[cls_logit],
1029
+ "binding_model": arch,
1030
  }
1031
 
1032
+ if uncertainty:
1033
+ mapie_bundle = self.mapie.get((prop_key, col))
1034
+ if mapie_bundle:
1035
+ if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
1036
+ # Concatenate target + binder pooled embeddings for sigma model
1037
+ binder_emb_tag = mapie_bundle.get("emb_tag") or col
1038
+ target_emb_tag = mapie_bundle.get("target_emb_tag", "wt")
1039
+ t_vec = self.wt_embedder.pooled(target_seq).cpu().float().numpy()
1040
+ b_vec = self._get_embedder(binder_emb_tag).pooled(binder_str).cpu().float().numpy()
1041
+ emb = np.concatenate([t_vec, b_vec], axis=1)
1042
+ else:
1043
+ emb = None
1044
+ lo, hi = _mapie_uncertainty(mapie_bundle, affinity, emb)
1045
+ out["uncertainty"] = (lo, hi)
1046
+ out["uncertainty_type"] = "conformal_prediction_interval"
1047
+ else:
1048
+ out["uncertainty"] = None
1049
+ out["uncertainty_type"] = "unavailable (no MAPIE bundle found)"
1050
+
1051
+ return out
1052
 
1053
if __name__ == "__main__":
    # Smoke test: load every model from the manifest next to this script and
    # print a handful of example predictions.
    root = Path(__file__).resolve().parent  # current script folder

    predictor = PeptiVersePredictor(
        manifest_path=root / "best_models.txt",
        classifier_weight_root=root,
    )
    print(predictor.training_root)
    print("MAPIE keys:", list(predictor.mapie.keys()))
    print("Ensemble keys:", list(predictor.ensembles.keys()))

    seq = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    smiles = "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O"

    # Each call is evaluated and printed before the next one runs.
    demo_calls = (
        lambda: predictor.predict_property("hemolysis", "wt", seq),
        lambda: predictor.predict_property("hemolysis", "smiles", smiles, uncertainty=True),
        lambda: predictor.predict_property("nf", "wt", seq, uncertainty=True),
        lambda: predictor.predict_property("nf", "smiles", smiles, uncertainty=True),
        lambda: predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT"),
        lambda: predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT", uncertainty=True),
    )
    for run in demo_calls:
        print(run())

    seq1 = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    seq2 = "ACDEFGHIKLMNPQRSTVWY"

    # These three are all computed first and printed afterwards.
    results = [
        predictor.predict_binding_affinity("wt", target_seq=seq2, binder_str="GIGAVLKVLT", uncertainty=True),
        predictor.predict_property("nf", "wt", seq1, uncertainty=True),
        predictor.predict_property("nf", "wt", seq2, uncertainty=True),
    ]
    for result in results:
        print(result)
tokenizer/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (136 Bytes)
 
tokenizer/__pycache__/my_tokenizers.cpython-310.pyc DELETED
Binary file (16.2 kB)