Joblib
PeptiVerse / inference.py
AlienChen's picture
Update inference.py
dfe9aa7 verified
from __future__ import annotations
import csv, re, json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Optional, Tuple, Any, List
import numpy as np
import torch
import torch.nn as nn
import joblib
import xgboost as xgb
from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from lightning.pytorch import seed_everything
# Module-level side effect: importing this file seeds the process's random
# number generators via Lightning's seed_everything (see the Lightning docs for
# exactly which RNGs are covered). 1986 matches the "seed_1986" checkpoint
# directory used for the primary DNN model and its seed ensemble below.
seed_everything(1986)
# -----------------------------
# Manifest
# -----------------------------
# Manifest embedding tag -> suffix of the on-disk model folder
# (folders are named "<model>_<suffix>", e.g. "mlp_smiles").
EMB_TAG_TO_FOLDER_SUFFIX = dict(
    wt="wt",
    peptideclm="smiles",
    chemberta="chemberta",
)
# Manifest embedding tag -> runtime input-mode name.
EMB_TAG_TO_RUNTIME_MODE = dict(
    wt="wt",
    peptideclm="smiles",
    chemberta="chemberta",
)
# Canonical model names whose regression uncertainty comes from a MAPIE bundle.
MAPIE_REGRESSION_MODELS = {"svr", "enet_gpu"}
# Neural-network head architectures rebuilt from torch checkpoints.
DNN_ARCHS = {"mlp", "cnn", "transformer"}
# XGBoost model-folder / model names.
XGB_MODELS = {"xgb", "xgb_reg", "xgb_wt_log", "xgb_smiles"}
@dataclass(frozen=True)
class BestRow:
    """One parsed row of the best-model manifest (immutable).

    ``best_wt`` / ``best_smiles`` hold ``(canonical_model_name, emb_tag)`` pairs,
    where ``emb_tag`` may be ``None``. The whole pair is ``None`` when the
    manifest names no model for that column.
    """

    # Normalized property name (output of normalize_property_key).
    property_key: str
    # Best model for the "wt" column, or None.
    best_wt: Optional[Tuple[str, Optional[str]]]
    # Best model for the "smiles" column, or None.
    best_smiles: Optional[Tuple[str, Optional[str]]]
    # Task type string from the manifest's "Type" column (e.g. "Classifier").
    task_type: str
    # Decision threshold for the wt / smiles classifiers, or None.
    thr_wt: Optional[float]
    thr_smiles: Optional[float]
def _clean(s: str) -> str:
return (s or "").strip()
def _none_if_dash(s: str) -> Optional[str]:
s = _clean(s)
return None if s in {"", "-", "-", "NA", "N/A"} else s
def _float_or_none(s: str) -> Optional[float]:
s = _clean(s)
return None if s in {"", "-", "-", "NA", "N/A"} else float(s)
def normalize_property_key(name: str) -> str:
    """Map a manifest property label to the internal property key.

    The label is lower-cased, parenthesised qualifiers (and the whitespace around
    them) are removed, and hyphens/spaces become underscores. A few aliases are
    then applied: generic permeability -> ``"permeability_penetrance"``,
    ``half_life`` -> ``"halflife"``, ``non_fouling`` -> ``"nf"``.

    NOTE(review): qualifiers are stripped *before* the PAMPA/Caco test, so
    ``"Permeability (PAMPA)"`` maps to ``"permeability_penetrance"``. Only labels
    with PAMPA/Caco outside parentheses keep their own key. Confirm this matches
    the manifest's naming.
    """
    key = re.sub(r"\s*\(.*?\)\s*", "", name.strip().lower())
    key = key.replace("-", "_").replace(" ", "_")
    assay_specific = any(tag in key for tag in ("pampa", "caco"))
    if "permeability" in key and not assay_specific:
        return "permeability_penetrance"
    aliases = {"halflife": "halflife", "half_life": "halflife", "non_fouling": "nf"}
    return aliases.get(key, key)
# Upper-cased manifest model label -> canonical model / folder name.
# Labels not listed here are simply lower-cased by _parse_model_and_emb.
MODEL_ALIAS = dict(
    SVM="svm_gpu",
    SVR="svr",
    ENET="enet_gpu",
    CNN="cnn",
    MLP="mlp",
    TRANSFORMER="transformer",
    XGB="xgb",
    XGB_REG="xgb_reg",
    POOLED="pooled",
    UNPOOLED="unpooled",
    TRANSFORMER_WT_LOG="transformer_wt_log",
)
def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]:
    """Parse a manifest "best model" cell such as ``"XGB (peptideclm)"``.

    Returns:
        ``(canonical_model_name, emb_tag)``. The model label goes through
        ``MODEL_ALIAS`` (unknown labels are lower-cased). The optional
        parenthesised embedding tag is lower-cased, or ``None`` if absent.
        Returns ``None`` for ``None``, blank, or placeholder cells.
    """
    # Use the manifest's single definition of a "missing" cell rather than a
    # private copy of the placeholder set.
    cell = _none_if_dash(raw)
    if cell is None:
        return None
    m = re.match(r"^(.+?)\s*\((.+?)\)\s*$", cell)
    if m:
        model_raw, emb_tag = m.group(1).strip(), m.group(2).strip().lower()
    else:
        model_raw, emb_tag = cell, None
    canon = MODEL_ALIAS.get(model_raw.upper(), model_raw.lower())
    return canon, emb_tag
def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
    """Read the best-model manifest CSV into ``{property_key: BestRow}``.

    Columns are addressed by header name: ``Properties``, ``Best_Model_WT``,
    ``Best_Model_SMILES``, ``Type``, ``Threshold_WT`` and ``Threshold_SMILES``.
    The following rows are skipped:

    - blank rows;
    - rows without a property name.

    Trailing empty cells are dropped and short rows are padded. A row whose
    property normalises to an existing key replaces the earlier row.

    Raises:
        ValueError: if a threshold cell is neither a number nor a placeholder.
    """
    out: Dict[str, BestRow] = {}
    header: Optional[List[str]] = None
    # The "utf-8-sig" encoding strips a leading byte-order mark, which
    # spreadsheet exports often add. Without it the first header reads
    # "\ufeffProperties", the "Properties" lookup misses, and every row is
    # silently skipped.
    with Path(path).open("r", newline="", encoding="utf-8-sig") as f:
        for raw in csv.reader(f):
            # Drop trailing empty cells; a row that becomes empty was blank.
            while raw and _clean(raw[-1]) == "":
                raw = raw[:-1]
            if not raw:
                continue
            if header is None:
                header = [h.strip() for h in raw]
                continue
            if len(raw) < len(header):
                raw = raw + [""] * (len(header) - len(raw))
            rec = dict(zip(header, raw))
            prop_raw = _clean(rec.get("Properties", ""))
            if not prop_raw:
                continue
            prop_key = normalize_property_key(prop_raw)
            out[prop_key] = BestRow(
                property_key=prop_key,
                best_wt=_parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", ""))),
                best_smiles=_parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", ""))),
                # An empty Type cell (padded to "" above) gets the same default
                # as a missing Type column.
                task_type=_clean(rec.get("Type", "")) or "Classifier",
                thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
                thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
            )
    return out
# -----------------------------
# Generic artifact loading
# -----------------------------
def find_best_artifact(model_dir: Path) -> Path:
    """Return the preferred model artifact inside *model_dir*.

    Glob patterns are tried in priority order; within a pattern the
    lexicographically smallest match wins. If no pattern matches,
    ``seed_1986/model.pt`` is used when present.

    Raises:
        FileNotFoundError: if no suitable artifact exists.
    """
    priority = (
        "best_model.json",
        "best_model.pt",
        "best_model*.joblib",
        "model.json",
        "model.ubj",
        "final_model.json",
    )
    # Lazy: globbing stops at the first pattern that has matches.
    per_pattern = (list(model_dir.glob(pattern)) for pattern in priority)
    first_hit = next((min(matches) for matches in per_pattern if matches), None)
    if first_hit is not None:
        return first_hit
    seed_checkpoint = model_dir.joinpath("seed_1986", "model.pt")
    if seed_checkpoint.exists():
        return seed_checkpoint
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the artifact chosen by ``find_best_artifact`` for *model_dir*.

    Returns:
        ``(kind, obj, path)``, where ``kind`` and ``obj`` are one of:

        - ``"xgb"``: an ``xgb.Booster``, from ``.json`` or ``.ubj``;
        - ``"joblib"``: the unpickled object;
        - ``"torch_ckpt"``: the raw checkpoint returned by ``torch.load``.

    Raises:
        FileNotFoundError: if no artifact exists (from ``find_best_artifact``).
        ValueError: if the artifact has an unsupported suffix.
    """
    art = find_best_artifact(model_dir)
    suffix = art.suffix.lower()
    # find_best_artifact can return "model.ubj" (XGBoost's binary UBJSON
    # format). Booster.load_model reads it just like JSON.
    if suffix in (".json", ".ubj"):
        booster = xgb.Booster()
        booster.load_model(str(art))
        return "xgb", booster, art
    if suffix == ".joblib":
        # NOTE(security): joblib.load unpickles; only load trusted model dirs.
        obj = joblib.load(art)
        return "joblib", obj, art
    if suffix == ".pt":
        # NOTE(security): weights_only=False unpickles arbitrary objects.
        ckpt = torch.load(art, map_location=device, weights_only=False)
        return "torch_ckpt", ckpt, art
    raise ValueError(f"Unknown artifact type: {art}")
# -----------------------------
# NN architectures
# -----------------------------
class MaskedMeanPool(nn.Module):
    """Average ``X`` over dim 1, counting only positions where ``M`` is true.

    Rows with no true positions yield zeros (the count is clamped to at least 1).
    """

    def forward(self, X, M):
        weights = M.float().unsqueeze(-1)
        total = torch.sum(X * weights, dim=1)
        count = torch.clamp(weights.sum(dim=1), min=1.0)
        return total / count
class MLPHead(nn.Module):
    """Masked mean-pool over tokens, then Linear -> GELU -> Dropout -> Linear(1)."""

    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        self.pool = MaskedMeanPool()
        stages = [
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        ]
        # Weights live at "net.0.*" and "net.3.*" in saved state_dicts.
        self.net = nn.Sequential(*stages)

    def forward(self, X, M):
        pooled = self.pool(X, M)
        return self.net(pooled).squeeze(-1)
class CNNHead(nn.Module):
    """Stacked 1-D convolutions over the token axis, masked mean-pool, scalar head.

    Each conv block is Conv1d -> GELU -> Dropout, so conv weights sit at indices
    0, 3, 6, ... of ``self.conv``. ``padding=k // 2`` preserves sequence length
    only for odd ``k``. NOTE(review): an even kernel would lengthen the output
    and break the mask broadcast in ``forward``.
    """

    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        modules = []
        for i in range(layers):
            modules.extend([
                nn.Conv1d(in_ch if i == 0 else c, c, kernel_size=k, padding=k // 2),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
        self.conv = nn.Sequential(*modules)
        self.head = nn.Linear(c, 1)

    def forward(self, X, M):
        # Conv1d expects channels in dim 1, so swap around the convolution stack.
        feats = self.conv(X.transpose(1, 2)).transpose(1, 2)
        mask = M.unsqueeze(-1).float()
        pooled = (feats * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
class TransformerHead(nn.Module):
    """Linear projection, batch-first Transformer encoder, masked mean-pool, scalar head.

    ``M`` marks valid positions; its negation is the encoder's key-padding mask.
    """

    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        self.enc = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=ff,
                dropout=dropout,
                batch_first=True,
                activation="gelu",
            ),
            num_layers=layers,
        )
        self.head = nn.Linear(d_model, 1)

    def forward(self, X, M):
        padding = ~M  # True = position the encoder must ignore
        encoded = self.enc(self.proj(X), src_key_padding_mask=padding)
        w = M.unsqueeze(-1).float()
        pooled = (encoded * w).sum(dim=1) / w.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int:
if model_name == "mlp": return int(sd["net.0.weight"].shape[1])
if model_name == "cnn": return int(sd["conv.0.weight"].shape[1])
if model_name == "transformer": return int(sd["proj.weight"].shape[1])
raise ValueError(model_name)
def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int:
idxs = set()
for k in sd.keys():
if k.startswith(prefix):
m = re.match(r"(\d+)\.", k[len(prefix):])
if m:
idxs.add(int(m.group(1)))
return (max(idxs) + 1) if idxs else 1
def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
    """Recover ``(d_model, num_layers, dim_feedforward)`` of a TransformerHead.

    ``d_model`` is the output width of ``proj``. The layer count comes from the
    ``enc.layers.<i>`` keys. The feed-forward width comes from layer 0's
    ``linear1``, falling back to ``4 * d_model`` when that key is absent.

    Raises:
        KeyError: if ``proj.weight`` is missing.
    """
    if "proj.weight" not in sd:
        raise KeyError("Missing proj.weight in state_dict")
    width = int(sd["proj.weight"].shape[0])
    ff_key = "enc.layers.0.linear1.weight"
    ff_width = int(sd[ff_key].shape[0]) if ff_key in sd else 4 * width
    return width, _infer_num_layers_from_sd(sd, prefix="enc.layers."), ff_width
def _pick_nhead(d_model: int) -> int:
for h in (8, 6, 4, 3, 2, 1):
if d_model % h == 0:
return h
return 1
def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
    """Rebuild an MLP / CNN / Transformer head from a training checkpoint.

    Args:
        model_name: ``"mlp"``, ``"cnn"`` or ``"transformer"``.
        ckpt: dict with ``"best_params"`` (hyper-parameters), ``"state_dict"``
            and, optionally, ``"in_dim"``.
        device: device the model is moved to.

    Returns:
        The head with its state_dict loaded (strictly), on *device*, in eval mode.

    Raises:
        ValueError: for an unknown *model_name*.
        KeyError: if a required hyper-parameter or weight is missing.
    """
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    # Inspect the weights only when the checkpoint does not record in_dim.
    # A dict.get default would run the inference eagerly even when unused.
    in_dim = ckpt.get("in_dim")
    in_dim = _infer_in_dim_from_sd(sd, model_name) if in_dim is None else int(in_dim)
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                        layers=int(params["layers"]), dropout=dropout)
    elif model_name == "transformer":
        d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
        if d_model is None:
            # Width not recorded: recover width, depth and FFN size from weight shapes.
            d_model, sd_layers, sd_ff = _infer_transformer_arch_from_sd(sd)
        else:
            # Width recorded: still take unrecorded depth / FFN size from the
            # weights (same fallbacks as above), so the strict
            # load_state_dict below sees matching shapes.
            d_model = int(d_model)
            sd_layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
            ff_key = "enc.layers.0.linear1.weight"
            sd_ff = int(sd[ff_key].shape[0]) if ff_key in sd else 4 * d_model
        d_model = int(d_model)
        model = TransformerHead(
            in_dim=in_dim,
            d_model=d_model,
            nhead=int(params.get("nhead", _pick_nhead(d_model))),
            layers=int(params.get("layers", sd_layers)),
            ff=int(params.get("ff", sd_ff)),
            dropout=dropout,
        )
    else:
        raise ValueError(f"Unknown NN model_name={model_name}")
    model.load_state_dict(sd)
    model.to(device).eval()
    return model
# -----------------------------
# Wrappers
# -----------------------------
from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
class PassthroughRegressor(BaseEstimator, RegressorMixin):
    """scikit-learn-compatible regressor that replays precomputed predictions.

    ``fit`` does nothing. ``predict(X)`` returns the first ``len(X)`` stored
    values and ignores the contents of ``X``.
    """

    def __init__(self, preds: np.ndarray):
        # Attribute name must match the __init__ parameter for sklearn's get_params.
        self.preds = preds

    def fit(self, X, y):
        return self

    def predict(self, X):
        n_rows = len(X)
        return self.preds[:n_rows]
class PassthroughClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn-compatible binary classifier that replays stored probabilities.

    ``preds`` holds positive-class (label 1) probabilities. Every call uses the
    first ``len(X)`` of them and ignores the contents of ``X``. ``fit`` does
    nothing.
    """

    def __init__(self, preds: np.ndarray):
        self.preds = preds
        self.classes_ = np.array([0, 1])

    def fit(self, X, y):
        return self

    def predict(self, X):
        positive = self.preds[:len(X)]
        return (positive >= 0.5).astype(int)

    def predict_proba(self, X):
        positive = self.preds[:len(X)]
        negative = 1 - positive
        return np.stack([negative, positive], axis=1)
# -----------------------------
# Uncertainty helpers
# -----------------------------
# Per-seed training sub-directories searched for ensemble checkpoints, in order.
SEED_DIRS = [f"seed_{seed}" for seed in (1986, 42, 0, 123, 12345)]
def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]:
    """Build one *arch* head per ``<model_dir>/<seed_dir>/model.pt`` that exists.

    Seed directories are visited in ``SEED_DIRS`` order; missing ones are skipped.
    NOTE(security): checkpoints are unpickled (``weights_only=False``), so only
    load trusted directories.
    """
    checkpoint_paths = [model_dir / seed_dir / "model.pt" for seed_dir in SEED_DIRS]
    return [
        build_torch_model_from_ckpt(
            arch, torch.load(pt, map_location=device, weights_only=False), device
        )
        for pt in checkpoint_paths
        if pt.exists()
    ]
def _binary_entropy(p: float) -> float:
p = float(np.clip(p, 1e-9, 1 - 1e-9))
return float(-p * np.log(p) - (1 - p) * np.log(1 - p))
def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
    """Entropy of the ensemble's mean positive-class probability.

    Each member's output is read as a single logit and passed through a sigmoid.
    """
    with torch.no_grad():
        logits = [member(X, M).squeeze().float().cpu().item() for member in ensemble]
    probabilities = [1.0 / (1.0 + np.exp(-z)) for z in logits]
    return _binary_entropy(float(np.mean(probabilities)))
def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
preds = []
with torch.no_grad():
for m in ensemble:
preds.append(m(X, M).squeeze().float().cpu().item())
return float(np.std(preds))
def _mapie_uncertainty(mapie_bundle: dict, score: float,
embedding: Optional[np.ndarray] = None) -> Tuple[float, float]:
"""
Returns (ci_low, ci_high) from a conformal bundle.
- adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True}
Input-dependent: interval = score +/- q * sigma(embedding)
- plain_quantile: {"quantile": q, "alpha": ...}
Fixed-width: interval = score +/- q
"""
# Adaptive format is input-dependent interval
if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
q = float(mapie_bundle["quantile"])
if embedding is not None:
# Adaptive interval: y_hat ± q * sigma_hat(x).
# Equivalent to MAPIE's get_estimation_distribution():
# y_pred + conformity_scores * r_pred
# where conformity_scores=q and r_pred=sigma_hat(x).
# (ResidualNormalisedScore, Cordier et al. 2023)
sigma_model = mapie_bundle["sigma_model"]
sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0])
sigma = max(sigma, 1e-6)
else:
# No embedding available - fall back to fixed interval with sigma=1
sigma = 1.0
return float(score - q * sigma), float(score + q * sigma)
# Plain quantile format
if "quantile" in mapie_bundle:
q = float(mapie_bundle["quantile"])
return float(score - q), float(score + q)
X_dummy = np.zeros((1, 1))
result = mapie.predict(X_dummy)
if isinstance(result, tuple):
intervals = np.asarray(result[1])
if intervals.ndim == 3:
return float(intervals[0, 0, 0]), float(intervals[0, 1, 0])
return float(intervals[0, 0]), float(intervals[0, 1])
raise RuntimeError(
f"Cannot extract intervals: unknown MAPIE bundle format. "
f"Bundle keys: {list(mapie_bundle.keys())}."
)
def affinity_to_class(y: float) -> int:
    """Bin an affinity value: ``>= 9.0`` -> 0, ``< 7.0`` -> 2, otherwise (incl. NaN) -> 1."""
    if y < 7.0:
        return 2
    return 0 if y >= 9.0 else 1
class CrossAttnPooled(nn.Module):
    """Bidirectional cross-attention between one pooled target and one pooled binder vector.

    Both inputs are projected to ``hidden`` and given a leading length-1 sequence
    axis, since the attention layers use ``batch_first=False``. This presumably
    assumes ``t_vec`` is ``(batch, Ht)`` and ``b_vec`` is ``(batch, Hb)``.

    Each layer runs two steps, target first:

    1. The target attends to the binder.
    2. The binder then attends to the updated target.

    Each step uses residual + LayerNorm, then a residual 4x feed-forward +
    LayerNorm. A shared trunk produces a regression value and 3-class logits.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()

        def project(in_dim):
            return nn.Sequential(nn.Linear(in_dim, hidden), nn.LayerNorm(hidden))

        def attention():
            return nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False)

        def feed_forward():
            return nn.Sequential(
                nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(4 * hidden, hidden),
            )

        self.t_proj = project(Ht)
        self.b_proj = project(Hb)
        # Keys here define the "layers.<i>.<name>.*" state_dict layout.
        self.layers = nn.ModuleList(
            nn.ModuleDict({
                "attn_tb": attention(),
                "attn_bt": attention(),
                "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
                "fft": feed_forward(),
                "ffb": feed_forward(),
            })
            for _ in range(n_layers)
        )
        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def forward(self, t_vec, b_vec):
        def norm(layer, x):
            # Apply the LayerNorm with the first two axes swapped, then swap back.
            return layer(x.transpose(0, 1)).transpose(0, 1)

        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)
        for blk in self.layers:
            attended, _ = blk["attn_tb"](t, b, b)
            t = norm(blk["n1t"], t + attended)
            t = norm(blk["n2t"], t + blk["fft"](t))
            attended, _ = blk["attn_bt"](b, t, t)
            b = norm(blk["n1b"], b + attended)
            b = norm(blk["n2b"], b + blk["ffb"](b))
        h = self.shared(torch.cat([t[0], b[0]], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)
class CrossAttnUnpooled(nn.Module):
    """Bidirectional cross-attention between per-token target and binder sequences.

    ``Mt`` / ``Mb`` mark valid tokens; their negations are the key-padding masks.
    Each layer runs two steps, target first:

    1. The target attends to the binder.
    2. The binder then attends to the updated target.

    Each step uses residual + LayerNorm, then a residual 4x feed-forward +
    LayerNorm. The masked means of both sides feed a shared trunk that produces
    a regression value and 3-class logits.
    """

    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()

        def project(in_dim):
            return nn.Sequential(nn.Linear(in_dim, hidden), nn.LayerNorm(hidden))

        def attention():
            return nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True)

        def feed_forward():
            return nn.Sequential(
                nn.Linear(hidden, 4 * hidden), nn.GELU(), nn.Dropout(dropout),
                nn.Linear(4 * hidden, hidden),
            )

        self.t_proj = project(Ht)
        self.b_proj = project(Hb)
        # Keys here define the "layers.<i>.<name>.*" state_dict layout.
        self.layers = nn.ModuleList(
            nn.ModuleDict({
                "attn_tb": attention(),
                "attn_bt": attention(),
                "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
                "fft": feed_forward(),
                "ffb": feed_forward(),
            })
            for _ in range(n_layers)
        )
        self.shared = nn.Sequential(nn.Linear(2 * hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)
        self.cls = nn.Linear(hidden, 3)

    def _masked_mean(self, X, M):
        weights = M.unsqueeze(-1).float()
        return (X * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

    def forward(self, T, Mt, B, Mb):
        t_seq = self.t_proj(T)
        b_seq = self.b_proj(B)
        t_pad, b_pad = ~Mt, ~Mb
        for blk in self.layers:
            attended, _ = blk["attn_tb"](t_seq, b_seq, b_seq, key_padding_mask=b_pad)
            t_seq = blk["n1t"](t_seq + attended)
            t_seq = blk["n2t"](t_seq + blk["fft"](t_seq))
            attended, _ = blk["attn_bt"](b_seq, t_seq, t_seq, key_padding_mask=t_pad)
            b_seq = blk["n1b"](b_seq + attended)
            b_seq = blk["n2b"](b_seq + blk["ffb"](b_seq))
        summary = torch.cat([self._masked_mean(t_seq, Mt), self._masked_mean(b_seq, Mb)], dim=-1)
        h = self.shared(summary)
        return self.reg(h).squeeze(-1), self.cls(h)
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """Rebuild a binding-affinity cross-attention model from its checkpoint.

    ``"pooled"`` selects ``CrossAttnPooled``; any other value selects
    ``CrossAttnUnpooled``. Input widths come from the projection weights; the
    remaining sizes come from ``best_params``.
    NOTE(security): the checkpoint is unpickled (``weights_only=False``).
    """
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    hp = ckpt["best_params"]
    weights = ckpt["state_dict"]
    model_cls = CrossAttnPooled if pooled_or_unpooled == "pooled" else CrossAttnUnpooled
    model = model_cls(
        Ht=int(weights["t_proj.0.weight"].shape[1]),
        Hb=int(weights["b_proj.0.weight"].shape[1]),
        hidden=int(hp["hidden_dim"]),
        n_heads=int(hp["n_heads"]),
        n_layers=int(hp["n_layers"]),
        dropout=float(hp["dropout"]),
    )
    model.load_state_dict(weights)
    return model.to(device).eval()
# -----------------------------
# Embedding generation
# -----------------------------
def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor:
if hasattr(torch, "isin"):
return torch.isin(ids, test_ids)
return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1)
class SMILESEmbedder:
    """PeptideCLM (RoFormer encoder) embeddings for SMILES strings.

    ``pooled`` returns the mean of the per-token hidden states. ``unpooled``
    returns those states together with an all-true mask. Both methods:

    - ignore padding and the tokenizer's special tokens;
    - strip surrounding whitespace from the input;
    - run without gradients;
    - memoise results per string in unbounded dicts when ``use_cache`` is set.
    """

    def __init__(self, device, vocab_path, splits_path,
                 clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        # Only the `.roformer` encoder sub-module of the masked-LM model is kept.
        masked_lm = AutoModelForMaskedLM.from_pretrained(clm_name)
        self.model = masked_lm.roformer.to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = None
        if self.special_ids:
            self.special_ids_t = torch.tensor(self.special_ids, device=device, dtype=torch.long)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        names = ("pad", "cls", "sep", "bos", "eos", "mask")
        found = (getattr(tokenizer, f"{name}_token_id", None) for name in names)
        return sorted({int(tid) for tid in found if tid is not None})

    def _tokenize(self, smiles_list):
        tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        for name in tok:
            tok[name] = tok[name].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        mask = attn.bool()
        specials = self.special_ids_t
        if specials is None or specials.numel() == 0:
            return mask
        return mask & ~_safe_isin(ids, specials)

    def _forward_valid(self, text):
        # Encode one string; return (hidden states, mask of non-special, non-pad tokens).
        tok = self._tokenize([text])
        ids, attn = tok["input_ids"], tok["attention_mask"]
        hidden = self.model(input_ids=ids, attention_mask=attn).last_hidden_state
        return hidden, self._valid_mask(ids, attn)

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        key = smiles.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]
        hidden, valid = self._forward_valid(key)
        weights = valid.unsqueeze(-1).float()
        result = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        if self.use_cache:
            self._cache_pooled[key] = result
        return result

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        key = smiles.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]
        hidden, valid = self._forward_valid(key)
        X = hidden[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache:
            self._cache_unpooled[key] = (X, M)
        return X, M
class ChemBERTaEmbedder:
    """ChemBERTa embeddings for SMILES strings.

    ``pooled`` returns the mean of the per-token hidden states. ``unpooled``
    returns those states together with an all-true mask. Both methods:

    - ignore padding and the tokenizer's special tokens;
    - strip surrounding whitespace from the input;
    - run without gradients;
    - memoise results per string in unbounded dicts when ``use_cache`` is set.
    """

    def __init__(self, device, model_name="DeepChem/ChemBERTa-77M-MLM",
                 max_len=512, use_cache=True):
        from transformers import AutoTokenizer, AutoModel
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = None
        if self.special_ids:
            self.special_ids_t = torch.tensor(self.special_ids, device=device, dtype=torch.long)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        names = ("pad", "cls", "sep", "bos", "eos", "mask")
        found = (getattr(tokenizer, f"{name}_token_id", None) for name in names)
        return sorted({int(tid) for tid in found if tid is not None})

    def _tokenize(self, smiles_list):
        tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        for name in tok:
            tok[name] = tok[name].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        mask = attn.bool()
        specials = self.special_ids_t
        if specials is None or specials.numel() == 0:
            return mask
        return mask & ~_safe_isin(ids, specials)

    def _forward_valid(self, text):
        # Encode one string; return (hidden states, mask of non-special, non-pad tokens).
        tok = self._tokenize([text])
        ids, attn = tok["input_ids"], tok["attention_mask"]
        hidden = self.model(input_ids=ids, attention_mask=attn).last_hidden_state
        return hidden, self._valid_mask(ids, attn)

    @torch.no_grad()
    def pooled(self, smiles: str) -> torch.Tensor:
        key = smiles.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]
        hidden, valid = self._forward_valid(key)
        weights = valid.unsqueeze(-1).float()
        result = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        if self.use_cache:
            self._cache_pooled[key] = result
        return result

    @torch.no_grad()
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        key = smiles.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]
        hidden, valid = self._forward_valid(key)
        X = hidden[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache:
            self._cache_unpooled[key] = (X, M)
        return X, M
class WTEmbedder:
    """ESM-2 embeddings for sequence inputs (the "wt" column).

    ``pooled`` returns the mean of the per-token hidden states. ``unpooled``
    returns those states together with an all-true mask. Both methods:

    - ignore padding and the tokenizer's special tokens;
    - strip surrounding whitespace from the input;
    - run without gradients;
    - memoise results per string in unbounded dicts when ``use_cache`` is set.
    """

    def __init__(self, device, esm_name="facebook/esm2_t33_650M_UR50D", max_len=1022, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        # add_pooling_layer=False: only last_hidden_state is used below.
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = None
        if self.special_ids:
            self.special_ids_t = torch.tensor(self.special_ids, device=device, dtype=torch.long)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}

    @staticmethod
    def _get_special_ids(tokenizer) -> List[int]:
        names = ("pad", "cls", "sep", "bos", "eos", "mask")
        found = (getattr(tokenizer, f"{name}_token_id", None) for name in names)
        return sorted({int(tid) for tid in found if tid is not None})

    def _tokenize(self, seq_list):
        encoded = self.tokenizer(seq_list, return_tensors="pt", padding=True,
                                 truncation=True, max_length=self.max_len)
        tok = {name: tensor.to(self.device) for name, tensor in encoded.items()}
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok

    def _valid_mask(self, ids, attn):
        mask = attn.bool()
        specials = self.special_ids_t
        if specials is None or specials.numel() == 0:
            return mask
        return mask & ~_safe_isin(ids, specials)

    def _forward_valid(self, text):
        # Encode one sequence; return (hidden states, mask of non-special, non-pad tokens).
        tok = self._tokenize([text])
        hidden = self.model(**tok).last_hidden_state
        return hidden, self._valid_mask(tok["input_ids"], tok["attention_mask"])

    @torch.no_grad()
    def pooled(self, seq: str) -> torch.Tensor:
        key = seq.strip()
        if self.use_cache and key in self._cache_pooled:
            return self._cache_pooled[key]
        hidden, valid = self._forward_valid(key)
        weights = valid.unsqueeze(-1).float()
        result = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1e-9)
        if self.use_cache:
            self._cache_pooled[key] = result
        return result

    @torch.no_grad()
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        key = seq.strip()
        if self.use_cache and key in self._cache_unpooled:
            return self._cache_unpooled[key]
        hidden, valid = self._forward_valid(key)
        X = hidden[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache:
            self._cache_unpooled[key] = (X, M)
        return X, M
# -----------------------------
# Predictor
# -----------------------------
class PeptiVersePredictor:
def __init__(
    self,
    manifest_path: str | Path,
    classifier_weight_root: str | Path,
    esm_name="facebook/esm2_t33_650M_UR50D",
    clm_name="aaronfeller/PeptideCLM-23M-all",
    chemberta_name="DeepChem/ChemBERTa-77M-MLM",
    smiles_vocab="tokenizer/new_vocab.txt",
    smiles_splits="tokenizer/new_splits.txt",
    device: Optional[str] = None,
):
    """Load the manifest, the three embedding backbones, and every best model.

    Args:
        manifest_path: best-model manifest CSV (see ``read_best_manifest_csv``).
        classifier_weight_root: root directory. Model folders are looked up
            under ``<root>/training_classifiers``, and the SMILES tokenizer
            files are resolved relative to it.
        esm_name, clm_name, chemberta_name: names passed to ``from_pretrained``
            for the wt, PeptideCLM and ChemBERTa embedders.
        smiles_vocab, smiles_splits: SMILES tokenizer files, relative to the root.
        device: torch device string. When falsy, uses ``"cuda"`` if available,
            else ``"cpu"``.
    """
    root = Path(classifier_weight_root)
    self.root = root
    self.training_root = root / "training_classifiers"
    self.device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
    self.manifest = read_best_manifest_csv(manifest_path)
    self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
    self.smiles_embedder = SMILESEmbedder(
        self.device,
        clm_name=clm_name,
        vocab_path=str(root / smiles_vocab),
        splits_path=str(root / smiles_splits),
    )
    self.chemberta_embedder = ChemBERTaEmbedder(self.device, model_name=chemberta_name)
    # Registries keyed by (property_key, column), column in {"wt", "smiles"}.
    self.models: Dict[Tuple[str, str], Any] = {}
    self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}
    self.mapie: Dict[Tuple[str, str], dict] = {}
    self.ensembles: Dict[Tuple[str, str], List] = {}
    self._load_all_best_models()
def _get_embedder(self, emb_tag: str):
if emb_tag == "wt": return self.wt_embedder
if emb_tag == "peptideclm": return self.smiles_embedder
if emb_tag == "chemberta": return self.chemberta_embedder
raise ValueError(f"Unknown emb_tag={emb_tag!r}")
def _embed_pooled(self, emb_tag: str, input_str: str) -> np.ndarray:
v = self._get_embedder(emb_tag).pooled(input_str)
feats = v.detach().cpu().numpy().astype(np.float32)
feats = np.nan_to_num(feats, nan=0.0)
return np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max)
def _embed_unpooled(self, emb_tag: str, input_str: str) -> Tuple[torch.Tensor, torch.Tensor]:
return self._get_embedder(emb_tag).unpooled(input_str)
def _resolve_dir(self, prop_key: str, model_name: str, emb_tag: str) -> Path:
    """Locate the training folder for a (property, model, embedding) triple.

    The half-life property is stored on disk as ``half_life``. For its wt
    models, log-target folders are preferred: ``transformer_wt_log`` then
    ``transformer_wt`` for the transformer, and ``xgb_wt_log`` for xgb/xgb_reg.
    Otherwise ``<model>_<folder_suffix>`` is tried first, then ``<model>``.

    Raises:
        FileNotFoundError: if none of the generic candidates exists.
    """
    base = self.training_root / ("half_life" if prop_key == "halflife" else prop_key)
    suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag)
    preferred = []
    if prop_key == "halflife" and emb_tag == "wt":
        if model_name == "transformer":
            preferred = [base / "transformer_wt_log", base / "transformer_wt"]
        elif model_name in {"xgb", "xgb_reg"}:
            preferred = [base / "xgb_wt_log"]
    candidates = [
        base / f"{model_name}_{suffix}",
        base / model_name,
    ]
    found = next((d for d in preferred + candidates if d.exists()), None)
    if found is not None:
        return found
    raise FileNotFoundError(
        f"Cannot find model dir for {prop_key}/{model_name}/{emb_tag}. Tried: {candidates}"
    )
def _load_all_best_models(self):
    """Load every best model listed in the manifest.

    For each property and each manifest column (``"wt"``, ``"smiles"``) that
    names a model, this fills ``self.models[(prop_key, col)]`` and
    ``self.meta[(prop_key, col)]``. Where applicable it also fills
    ``self.mapie`` (conformal calibration bundle) and ``self.ensembles``
    (per-seed DNN heads). Progress is printed to stdout.

    Raises:
        FileNotFoundError: if a model directory or artifact cannot be found.
    """
    for prop_key, row in self.manifest.items():
        for col, parsed, thr in [
            ("wt", row.best_wt, row.thr_wt),
            ("smiles", row.best_smiles, row.thr_smiles),
        ]:
            if parsed is None:
                continue
            key = (prop_key, col)
            model_name, emb_tag = parsed
            # binding affinity: cross-attention model, always regression
            if prop_key == "binding_affinity":
                folder = model_name
                pooled_or_unpooled = "unpooled" if "unpooled" in folder else "pooled"
                model_dir = self.training_root / "binding_affinity" / folder
                art = find_best_artifact(model_dir)
                self.models[key] = load_binding_model(art, pooled_or_unpooled, self.device)
                self.meta[key] = {
                    "task_type": "Regression",
                    "threshold": None,
                    "artifact": str(art),
                    "model_name": pooled_or_unpooled,
                    "emb_tag": emb_tag,
                    "folder": folder,
                    "kind": "binding",
                }
                print(f" [LOAD] binding_affinity ({col}): folder={folder}, arch={pooled_or_unpooled}, emb_tag={emb_tag}, art={art.name}")
                if not self._try_load_mapie(key, model_dir / "mapie_calibration.joblib"):
                    print(f" No MAPIE bundle found (uncertainty will be unavailable)")
                continue
            # Untagged entries use their column's default embedder. The SMILES
            # column's default is "peptideclm" (folder suffix "smiles"). The
            # bare column name "smiles" is not a tag _get_embedder accepts,
            # so it would only fail later, at prediction time.
            if emb_tag is None:
                emb_tag = "wt" if col == "wt" else "peptideclm"
            model_dir = self._resolve_dir(prop_key, model_name, emb_tag)
            kind, obj, art = load_artifact(model_dir, self.device)
            if kind == "torch_ckpt":
                model = build_torch_model_from_ckpt(self._base_arch(model_name), obj, self.device)
            else:
                model = obj
            self.models[key] = model
            self.meta[key] = {
                "task_type": row.task_type,
                "threshold": thr,
                "artifact": str(art),
                "model_name": model_name,
                "emb_tag": emb_tag,
                "kind": kind,
            }
            print(f" [LOAD] ({prop_key}, {col}): kind={kind}, model={model_name}, emb={emb_tag}, task={row.task_type}, art={art.name}")
            task = row.task_type.lower()
            # MAPIE: SVR/ElasticNet, XGBoost regression, AND all regression torch_ckpt
            wants_mapie = task == "regression" and (
                model_name in MAPIE_REGRESSION_MODELS or kind in ("xgb", "torch_ckpt")
            )
            if wants_mapie:
                mapie_path = model_dir / "mapie_calibration.joblib"
                if not self._try_load_mapie(key, mapie_path):
                    print(f" No MAPIE bundle found at {mapie_path} (will fall back to ensemble if available)")
            # Seed ensembles: DNN only, used when MAPIE not available
            if kind == "torch_ckpt":
                ens = load_seed_ensemble(model_dir, self._base_arch(model_name), self.device)
                if ens:
                    self.ensembles[key] = ens
                    if key in self.mapie:
                        print(f" Seed ensemble: {len(ens)} seeds loaded (MAPIE takes priority for regression)")
                    else:
                        unc_type = "ensemble_predictive_entropy" if task == "classifier" else "ensemble_std"
                        print(f" Seed ensemble: {len(ens)} seeds loaded uncertainty method: {unc_type}")
                elif key in self.mapie:
                    print(f" No seed ensemble (MAPIE covers uncertainty)")
                else:
                    print(f" No seed ensemble found (checked: {SEED_DIRS}) - uncertainty unavailable")
            # XGBoost/SVM classifiers: binary entropy
            if kind in ("xgb", "joblib") and task == "classifier":
                print(f" Uncertainty method: binary_predictive_entropy (computed at inference)")


def _try_load_mapie(self, key, mapie_path: Path) -> bool:
    """Load the MAPIE bundle at *mapie_path* into ``self.mapie[key]``.

    Returns False when the file does not exist; otherwise True, even if loading
    failed. A load failure is printed and swallowed on purpose, so a broken
    bundle disables only that model's conformal intervals.
    """
    if not mapie_path.exists():
        return False
    try:
        self.mapie[key] = joblib.load(mapie_path)
        print(f" MAPIE loaded from {mapie_path.name}")
    except Exception as e:
        print(f" MAPIE load FAILED for ({key[0]}, {key[1]}): {e}")
    return True
@staticmethod
def _base_arch(model_name: str) -> str:
if model_name.startswith("transformer"): return "transformer"
if model_name.startswith("mlp"): return "mlp"
if model_name.startswith("cnn"): return "cnn"
return model_name
# Feature extraction
def _get_features(self, prop_key: str, col: str, input_str: str):
meta = self.meta[(prop_key, col)]
emb_tag = meta["emb_tag"]
kind = meta["kind"]
if kind == "torch_ckpt":
return self._embed_unpooled(emb_tag, input_str)
return self._embed_pooled(emb_tag, input_str)
# Uncertainty
def _compute_uncertainty(self, prop_key: str, col: str, input_str: str,
                         score: float) -> Tuple[Any, str]:
    """Return ``(uncertainty, method)`` for one prediction.

    ``uncertainty`` is one of:

    - a ``(low, high)`` conformal interval;
    - a float (ensemble std or binary predictive entropy);
    - ``None`` when unavailable.

    ``method`` names what was computed, or why nothing was.

    Dispatch:

    - torch_ckpt regression: MAPIE interval, else seed-ensemble std.
    - torch_ckpt classifier: seed-ensemble predictive entropy.
    - xgb classifier: entropy of *score*.
    - xgb regression: MAPIE interval.
    - joblib SVR/ElasticNet regression: MAPIE interval.
    - joblib classifier: entropy of *score*.
    """
    key = (prop_key, col)
    meta = self.meta[key]
    kind = meta["kind"]
    model_name = meta["model_name"]
    task = meta["task_type"].lower()
    emb_tag = meta["emb_tag"]

    def conformal():
        # ((lo, hi), label) from the loaded MAPIE bundle, or None when absent.
        bundle = self.mapie.get(key)
        if not bundle:
            return None
        # Adaptive bundles need the pooled embedding for their sigma model.
        emb = None
        if bundle.get("adaptive"):
            emb = self._embed_pooled(emb_tag, input_str) if emb_tag else None
        lo, hi = _mapie_uncertainty(bundle, score, emb)
        return (lo, hi), "conformal_prediction_interval"

    if kind == "torch_ckpt":
        ens = self.ensembles.get(key)
        if task == "regression":
            interval = conformal()
            if interval is not None:
                return interval
            if not ens:
                return None, "unavailable (no MAPIE bundle and no seed ensemble)"
            X, M = self._embed_unpooled(emb_tag, input_str)
            return _ensemble_reg_uncertainty(ens, X, M), "ensemble_std"
        if not ens:
            return None, "unavailable (no seed ensemble found)"
        X, M = self._embed_unpooled(emb_tag, input_str)
        return _ensemble_clf_uncertainty(ens, X, M), "ensemble_predictive_entropy"

    if kind == "xgb":
        if task == "classifier":
            return _binary_entropy(score), "binary_predictive_entropy"
        interval = conformal()
        if interval is None:
            return None, "unavailable (no MAPIE bundle for XGBoost regression)"
        return interval

    if kind == "joblib":
        if model_name in MAPIE_REGRESSION_MODELS and task == "regression":
            interval = conformal()
            if interval is None:
                return None, "unavailable (MAPIE bundle not found)"
            return interval
        if task == "classifier":
            return _binary_entropy(score), "binary_predictive_entropy_single_model"

    return None, "unavailable"
def predict_property(self, prop_key: str, col: str, input_str: str,
                     uncertainty: bool = False) -> Dict[str, Any]:
    """Predict one (non-binding) property for a single input string.

    Dispatches on the loaded model's ``kind`` ("torch_ckpt", "xgb" or
    "joblib") and returns a dict with keys ``property``, ``col``, ``score``
    and ``emb_tag``. Classifier results additionally carry ``label`` and
    ``threshold`` when the model metadata defines a threshold.

    Args:
        prop_key: Property key as used in ``self.models`` (e.g. "hemolysis").
        col: Input representation column (e.g. "wt" or "smiles").
        input_str: The sequence / SMILES string to score.
        uncertainty: If True, also attach ``uncertainty`` and
            ``uncertainty_type`` computed by ``self._compute_uncertainty``.

    Raises:
        KeyError: No model is loaded for ``(prop_key, col)``.
        RuntimeError: ``prop_key`` is "binding_affinity" (use
            ``predict_binding_affinity``) or the model kind is unknown.
    """
    key = (prop_key, col)
    if key not in self.models:
        raise KeyError(f"No model loaded for ({prop_key}, {col}).")
    # Reject binding models before touching their metadata:
    # predict_binding_affinity only reads model_name/emb_tag from it, so
    # the fields read below may be absent for binding entries.
    if prop_key == "binding_affinity":
        raise RuntimeError("Use predict_binding_affinity().")

    meta = self.meta[key]
    model = self.models[key]
    task_type = meta["task_type"].lower()
    thr = meta.get("threshold")
    kind = meta["kind"]
    model_name = meta["model_name"]
    is_classifier = task_type == "classifier"
    # wt half-life "log" models presumably predict on a log1p scale;
    # expm1 maps their output back to the original units.
    undo_log1p = prop_key == "halflife" and col == "wt" and "log" in model_name

    def _sigmoid(x) -> float:
        # Numerically stable logistic: never exponentiates a large positive
        # number, so extreme logits/margins saturate to 0.0/1.0 without
        # overflow warnings.
        x = float(x)
        if x >= 0:
            return float(1.0 / (1.0 + np.exp(-x)))
        z = float(np.exp(x))
        return z / (1.0 + z)

    if kind == "torch_ckpt":
        X, M = self._get_features(prop_key, col, input_str)
        with torch.no_grad():
            raw = model(X, M).squeeze().float().cpu().item()
        if undo_log1p:
            raw = float(np.expm1(raw))
        # DNN classifiers emit a logit; regressors emit the value directly.
        score = _sigmoid(raw) if is_classifier else float(raw)
    elif kind == "xgb":
        feats = self._get_features(prop_key, col, input_str)
        score = float(model.predict(xgb.DMatrix(feats))[0])
        if undo_log1p:
            score = float(np.expm1(score))
    elif kind == "joblib":
        # SVM / ElasticNet / SVR
        feats = self._get_features(prop_key, col, input_str)
        if not is_classifier:
            score = float(model.predict(feats)[0])
        elif hasattr(model, "predict_proba"):
            score = float(model.predict_proba(feats)[:, 1][0])
        elif hasattr(model, "decision_function"):
            # Squash the signed margin into (0, 1).
            score = _sigmoid(model.decision_function(feats)[0])
        else:
            # No probability-like API: fall back to the raw prediction.
            score = float(model.predict(feats)[0])
    else:
        raise RuntimeError(f"Unknown kind={kind}")

    out = {"property": prop_key, "col": col, "score": score,
           "emb_tag": meta["emb_tag"]}
    if is_classifier and thr is not None:
        out["label"] = int(score >= float(thr))
        out["threshold"] = float(thr)

    if uncertainty:
        u_val, u_type = self._compute_uncertainty(prop_key, col, input_str, out["score"])
        out["uncertainty"] = u_val
        out["uncertainty_type"] = u_type
    return out
def predict_binding_affinity(self, col: str, target_seq: str, binder_str: str,
                             uncertainty: bool = False) -> Dict[str, Any]:
    """Predict the binding affinity of ``binder_str`` against ``target_seq``.

    The target is always embedded with ``self.wt_embedder`` for prediction.
    The binder is embedded with ``self._get_embedder(emb_tag)`` when the
    model metadata carries an ``emb_tag``. Otherwise the embedder is chosen
    by ``col``: "wt" uses the wt embedder; any other value uses ChemBERTa for
    the "pooled" architecture and the SMILES embedder for the others.

    Returns a dict with the regressed ``affinity``, its class by threshold
    (``affinity_to_class``), the class from the classification head's
    argmax, and the architecture name. With ``uncertainty=True`` a conformal
    interval from the MAPIE bundle for ``("binding_affinity", col)`` is
    attached when one is loaded.

    Raises:
        KeyError: No binding model is loaded for ``col``.
    """
    prop_key = "binding_affinity"
    key = (prop_key, col)
    if key not in self.models:
        raise KeyError(f"No binding model loaded for ({prop_key}, {col}).")
    model = self.models[key]
    meta = self.meta[key]
    arch = meta["model_name"]
    emb_tag = meta.get("emb_tag")

    if arch == "pooled":
        t_vec = self.wt_embedder.pooled(target_seq)
        if emb_tag:
            b_vec = self._get_embedder(emb_tag).pooled(binder_str)
        elif col == "wt":
            b_vec = self.wt_embedder.pooled(binder_str)
        else:
            b_vec = self.chemberta_embedder.pooled(binder_str)
        with torch.no_grad():
            reg, logits = model(t_vec, b_vec)
    else:
        T, Mt = self.wt_embedder.unpooled(target_seq)
        if emb_tag:
            binder_embedder = self._get_embedder(emb_tag)
        elif col == "wt":
            binder_embedder = self.wt_embedder
        else:
            binder_embedder = self.smiles_embedder
        B, Mb = binder_embedder.unpooled(binder_str)
        with torch.no_grad():
            reg, logits = model(T, Mt, B, Mb)

    affinity = float(reg.squeeze().cpu().item())
    cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
    cls_thr = affinity_to_class(affinity)
    names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}
    out = {
        "property": "binding_affinity",
        "col": col,
        "affinity": affinity,
        "class_by_threshold": names[cls_thr],
        "class_by_logits": names[cls_logit],
        "binding_model": arch,
    }

    if uncertainty:
        mapie_bundle = self.mapie.get(key)
        if mapie_bundle:
            emb = None
            if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
                # The sigma model takes target and binder pooled embeddings
                # concatenated along axis 1.
                binder_emb_tag = mapie_bundle.get("emb_tag") or col
                # Honor the bundle's declared target embedding. A missing/None
                # tag or "wt" keeps self.wt_embedder, matching the
                # prediction path above.
                target_emb_tag = mapie_bundle.get("target_emb_tag") or "wt"
                target_embedder = (self.wt_embedder if target_emb_tag == "wt"
                                   else self._get_embedder(target_emb_tag))
                t_vec = target_embedder.pooled(target_seq).cpu().float().numpy()
                b_vec = self._get_embedder(binder_emb_tag).pooled(binder_str).cpu().float().numpy()
                emb = np.concatenate([t_vec, b_vec], axis=1)
            lo, hi = _mapie_uncertainty(mapie_bundle, affinity, emb)
            out["uncertainty"] = (lo, hi)
            out["uncertainty_type"] = "conformal_prediction_interval"
        else:
            out["uncertainty"] = None
            out["uncertainty_type"] = "unavailable (no MAPIE bundle found)"
    return out
if __name__ == "__main__":
    # Smoke-run the predictor against the manifest stored next to this script.
    script_dir = Path(__file__).resolve().parent
    predictor = PeptiVersePredictor(
        manifest_path=script_dir / "best_models.txt",
        classifier_weight_root=script_dir,
    )
    print(predictor.training_root)
    print("MAPIE keys:", list(predictor.mapie.keys()))
    print("Ensemble keys:", list(predictor.ensembles.keys()))

    peptide = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    cyclic_smiles = (
        "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)"
        "[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O"
    )
    short_binder = "GIGAVLKVLT"

    # Single-property predictions: one without uncertainty, then three with it.
    print(predictor.predict_property("hemolysis", "wt", peptide))
    for prop, column, text in (
        ("hemolysis", "smiles", cyclic_smiles),
        ("nf", "wt", peptide),
        ("nf", "smiles", cyclic_smiles),
    ):
        print(predictor.predict_property(prop, column, text, uncertainty=True))

    # Binding affinity, first without and then with conformal intervals.
    for want_uncertainty in (False, True):
        if want_uncertainty:
            print(predictor.predict_binding_affinity(
                "wt", target_seq=peptide, binder_str=short_binder, uncertainty=True))
        else:
            print(predictor.predict_binding_affinity(
                "wt", target_seq=peptide, binder_str=short_binder))

    # Compute a second batch first, then report it.
    other_seq = "ACDEFGHIKLMNPQRSTVWY"
    batch = [
        predictor.predict_binding_affinity(
            "wt", target_seq=other_seq, binder_str=short_binder, uncertainty=True),
        predictor.predict_property("nf", "wt", peptide, uncertainty=True),
        predictor.predict_property("nf", "wt", other_seq, uncertainty=True),
    ]
    for result in batch:
        print(result)