from __future__ import annotations import csv, re, json from dataclasses import dataclass from pathlib import Path from typing import Dict, Optional, Tuple, Any, List import numpy as np import torch import torch.nn as nn import joblib import xgboost as xgb from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer from lightning.pytorch import seed_everything seed_everything(1986) # ----------------------------- # Manifest # ----------------------------- EMB_TAG_TO_FOLDER_SUFFIX = { "wt": "wt", "peptideclm": "smiles", "chemberta": "chemberta", } EMB_TAG_TO_RUNTIME_MODE = { "wt": "wt", "peptideclm": "smiles", "chemberta": "chemberta", } MAPIE_REGRESSION_MODELS = {"svr", "enet_gpu"} DNN_ARCHS = {"mlp", "cnn", "transformer"} XGB_MODELS = {"xgb", "xgb_reg", "xgb_wt_log", "xgb_smiles"} @dataclass(frozen=True) class BestRow: property_key: str best_wt: Optional[Tuple[str, Optional[str]]] best_smiles: Optional[Tuple[str, Optional[str]]] task_type: str thr_wt: Optional[float] thr_smiles: Optional[float] def _clean(s: str) -> str: return (s or "").strip() def _none_if_dash(s: str) -> Optional[str]: s = _clean(s) return None if s in {"", "-", "-", "NA", "N/A"} else s def _float_or_none(s: str) -> Optional[float]: s = _clean(s) return None if s in {"", "-", "-", "NA", "N/A"} else float(s) def normalize_property_key(name: str) -> str: n = name.strip().lower() n = re.sub(r"\s*\(.*?\)\s*", "", n) n = n.replace("-", "_").replace(" ", "_") if "permeability" in n and "pampa" not in n and "caco" not in n: return "permeability_penetrance" if n == "binding_affinity": return "binding_affinity" if n in {"halflife", "half_life"}: return "halflife" if n == "non_fouling": return "nf" return n MODEL_ALIAS = { "SVM": "svm_gpu", "SVR": "svr", "ENET": "enet_gpu", "CNN": "cnn", "MLP": "mlp", "TRANSFORMER": "transformer", "XGB": "xgb", "XGB_REG": "xgb_reg", "POOLED": "pooled", "UNPOOLED": "unpooled", "TRANSFORMER_WT_LOG": "transformer_wt_log", } def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]: if raw is None: return None raw = _clean(raw) if not raw or raw in {"-", "-", "NA", "N/A"}: return None m = re.match(r"^(.+?)\s*\((.+?)\)\s*$", raw) if m: model_raw = m.group(1).strip() emb_tag = m.group(2).strip().lower() else: model_raw = raw emb_tag = None canon = MODEL_ALIAS.get(model_raw.upper(), model_raw.lower()) return canon, emb_tag def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]: p = Path(path) out: Dict[str, BestRow] = {} with p.open("r", newline="") as f: reader = csv.reader(f) header = None for raw in reader: if not raw or all(_clean(x) == "" for x in raw): continue while raw and _clean(raw[-1]) == "": raw = raw[:-1] if header is None: header = [h.strip() for h in raw] continue if len(raw) < len(header): raw = raw + [""] * (len(header) - len(raw)) rec = dict(zip(header, raw)) prop_raw = _clean(rec.get("Properties", "")) if not prop_raw: continue prop_key = normalize_property_key(prop_raw) best_wt = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", ""))) best_smiles = _parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", ""))) row = BestRow( property_key=prop_key, best_wt=best_wt, best_smiles=best_smiles, task_type=_clean(rec.get("Type", "Classifier")), thr_wt=_float_or_none(rec.get("Threshold_WT", "")), thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")), ) out[prop_key] = row return out # ----------------------------- # Generic artifact loading # ----------------------------- def find_best_artifact(model_dir: Path) -> Path: for pat in ["best_model.json", "best_model.pt", "best_model*.joblib", "model.json", "model.ubj", "final_model.json"]: hits = sorted(model_dir.glob(pat)) if hits: return hits[0] seed_pt = model_dir / "seed_1986" / "model.pt" if seed_pt.exists(): return seed_pt raise FileNotFoundError(f"No best_model artifact found in {model_dir}") def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]: art = find_best_artifact(model_dir) if art.suffix == ".json": booster = xgb.Booster() booster.load_model(str(art)) return "xgb", booster, art if art.suffix == ".joblib": obj = joblib.load(art) return "joblib", obj, art if art.suffix == ".pt": ckpt = torch.load(art, map_location=device, weights_only=False) return "torch_ckpt", ckpt, art raise ValueError(f"Unknown artifact type: {art}") # ----------------------------- # NN architectures # ----------------------------- class MaskedMeanPool(nn.Module): def forward(self, X, M): Mf = M.unsqueeze(-1).float() denom = Mf.sum(dim=1).clamp(min=1.0) return (X * Mf).sum(dim=1) / denom class MLPHead(nn.Module): def __init__(self, in_dim, hidden=512, dropout=0.1): super().__init__() self.pool = MaskedMeanPool() self.net = nn.Sequential( nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(hidden, 1), ) def forward(self, X, M): return self.net(self.pool(X, M)).squeeze(-1) class CNNHead(nn.Module): def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1): super().__init__() blocks, ch = [], in_ch for _ in range(layers): blocks += [nn.Conv1d(ch, c, kernel_size=k, padding=k//2), nn.GELU(), nn.Dropout(dropout)] ch = c self.conv = nn.Sequential(*blocks) self.head = nn.Linear(c, 1) def forward(self, X, M): Y = self.conv(X.transpose(1, 2)).transpose(1, 2) Mf = M.unsqueeze(-1).float() pooled = (Y * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0) return self.head(pooled).squeeze(-1) class TransformerHead(nn.Module): def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1): super().__init__() self.proj = nn.Linear(in_dim, d_model) enc_layer = nn.TransformerEncoderLayer( d_model=d_model, nhead=nhead, dim_feedforward=ff, dropout=dropout, batch_first=True, activation="gelu" ) self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers) self.head = nn.Linear(d_model, 1) def forward(self, X, M): Z = self.enc(self.proj(X), src_key_padding_mask=~M) Mf = M.unsqueeze(-1).float() pooled = (Z * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0) return self.head(pooled).squeeze(-1) def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int: if model_name == "mlp": return int(sd["net.0.weight"].shape[1]) if model_name == "cnn": return int(sd["conv.0.weight"].shape[1]) if model_name == "transformer": return int(sd["proj.weight"].shape[1]) raise ValueError(model_name) def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int: idxs = set() for k in sd.keys(): if k.startswith(prefix): m = re.match(r"(\d+)\.", k[len(prefix):]) if m: idxs.add(int(m.group(1))) return (max(idxs) + 1) if idxs else 1 def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]: if "proj.weight" not in sd: raise KeyError("Missing proj.weight in state_dict") d_model = int(sd["proj.weight"].shape[0]) layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.") ff = int(sd["enc.layers.0.linear1.weight"].shape[0]) if "enc.layers.0.linear1.weight" in sd else 4 * d_model return d_model, layers, ff def _pick_nhead(d_model: int) -> int: for h in (8, 6, 4, 3, 2, 1): if d_model % h == 0: return h return 1 def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module: params = ckpt["best_params"] sd = ckpt["state_dict"] in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name))) dropout = float(params.get("dropout", 0.1)) if model_name == "mlp": model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout) elif model_name == "cnn": model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]), layers=int(params["layers"]), dropout=dropout) elif model_name == "transformer": d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim") if d_model is None: d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd) nhead_i = _pick_nhead(d_model_i) model = TransformerHead( in_dim=in_dim, d_model=int(d_model_i), nhead=int(params.get("nhead", nhead_i)), layers=int(params.get("layers", layers_i)), ff=int(params.get("ff", ff_i)), dropout=float(params.get("dropout", dropout)), ) else: d_model = int(d_model) model = TransformerHead( in_dim=in_dim, d_model=d_model, nhead=int(params.get("nhead", _pick_nhead(d_model))), layers=int(params.get("layers", 2)), ff=int(params.get("ff", 4 * d_model)), dropout=dropout, ) else: raise ValueError(f"Unknown NN model_name={model_name}") model.load_state_dict(sd) model.to(device).eval() return model # ----------------------------- # Wrappers # ----------------------------- from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin class PassthroughRegressor(BaseEstimator, RegressorMixin): def __init__(self, preds: np.ndarray): self.preds = preds def fit(self, X, y): return self def predict(self, X): return self.preds[:len(X)] class PassthroughClassifier(BaseEstimator, ClassifierMixin): def __init__(self, preds: np.ndarray): self.preds = preds self.classes_ = np.array([0, 1]) def fit(self, X, y): return self def predict(self, X): return (self.preds[:len(X)] >= 0.5).astype(int) def predict_proba(self, X): p = self.preds[:len(X)] return np.stack([1 - p, p], axis=1) # ----------------------------- # Uncertainty helpers # ----------------------------- SEED_DIRS = ["seed_1986", "seed_42", "seed_0", "seed_123", "seed_12345"] def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]: ensemble = [] for sd_name in SEED_DIRS: pt = model_dir / sd_name / "model.pt" if not pt.exists(): continue ckpt = torch.load(pt, map_location=device, weights_only=False) ensemble.append(build_torch_model_from_ckpt(arch, ckpt, device)) return ensemble def _binary_entropy(p: float) -> float: p = float(np.clip(p, 1e-9, 1 - 1e-9)) return float(-p * np.log(p) - (1 - p) * np.log(1 - p)) def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float: probs = [] with torch.no_grad(): for m in ensemble: logit = m(X, M).squeeze().float().cpu().item() probs.append(1.0 / (1.0 + np.exp(-logit))) return _binary_entropy(float(np.mean(probs))) def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float: preds = [] with torch.no_grad(): for m in ensemble: preds.append(m(X, M).squeeze().float().cpu().item()) return float(np.std(preds)) def _mapie_uncertainty(mapie_bundle: dict, score: float, embedding: Optional[np.ndarray] = None) -> Tuple[float, float]: """ Returns (ci_low, ci_high) from a conformal bundle. - adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True} Input-dependent: interval = score +/- q * sigma(embedding) - plain_quantile: {"quantile": q, "alpha": ...} Fixed-width: interval = score +/- q """ # Adaptive format is input-dependent interval if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle: q = float(mapie_bundle["quantile"]) if embedding is not None: # Adaptive interval: y_hat ± q * sigma_hat(x). # Equivalent to MAPIE's get_estimation_distribution(): # y_pred + conformity_scores * r_pred # where conformity_scores=q and r_pred=sigma_hat(x). # (ResidualNormalisedScore, Cordier et al. 2023) sigma_model = mapie_bundle["sigma_model"] sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0]) sigma = max(sigma, 1e-6) else: # No embedding available - fall back to fixed interval with sigma=1 sigma = 1.0 return float(score - q * sigma), float(score + q * sigma) # Plain quantile format if "quantile" in mapie_bundle: q = float(mapie_bundle["quantile"]) return float(score - q), float(score + q) X_dummy = np.zeros((1, 1)) result = mapie.predict(X_dummy) if isinstance(result, tuple): intervals = np.asarray(result[1]) if intervals.ndim == 3: return float(intervals[0, 0, 0]), float(intervals[0, 1, 0]) return float(intervals[0, 0]), float(intervals[0, 1]) raise RuntimeError( f"Cannot extract intervals: unknown MAPIE bundle format. " f"Bundle keys: {list(mapie_bundle.keys())}." ) def affinity_to_class(y: float) -> int: if y >= 9.0: return 0 if y < 7.0: return 2 return 1 class CrossAttnPooled(nn.Module): def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): super().__init__() self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) self.layers = nn.ModuleList([]) for _ in range(n_layers): self.layers.append(nn.ModuleDict({ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False), "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden), "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden), "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), })) self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) self.reg = nn.Linear(hidden, 1) self.cls = nn.Linear(hidden, 3) def forward(self, t_vec, b_vec): t = self.t_proj(t_vec).unsqueeze(0) b = self.b_proj(b_vec).unsqueeze(0) for L in self.layers: t_attn, _ = L["attn_tb"](t, b, b) t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1) t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1) b_attn, _ = L["attn_bt"](b, t, t) b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1) b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1) h = self.shared(torch.cat([t[0], b[0]], dim=-1)) return self.reg(h).squeeze(-1), self.cls(h) class CrossAttnUnpooled(nn.Module): def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1): super().__init__() self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden)) self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden)) self.layers = nn.ModuleList([]) for _ in range(n_layers): self.layers.append(nn.ModuleDict({ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True), "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden), "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden), "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)), })) self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout)) self.reg = nn.Linear(hidden, 1) self.cls = nn.Linear(hidden, 3) def _masked_mean(self, X, M): Mf = M.unsqueeze(-1).float() return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0) def forward(self, T, Mt, B, Mb): T = self.t_proj(T); Bx = self.b_proj(B) kp_t, kp_b = ~Mt, ~Mb for L in self.layers: T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b) T = L["n1t"](T + T_attn); T = L["n2t"](T + L["fft"](T)) B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t) Bx = L["n1b"](Bx + B_attn); Bx = L["n2b"](Bx + L["ffb"](Bx)) h = self.shared(torch.cat([self._masked_mean(T, Mt), self._masked_mean(Bx, Mb)], dim=-1)) return self.reg(h).squeeze(-1), self.cls(h) def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module: ckpt = torch.load(best_model_pt, map_location=device, weights_only=False) params = ckpt["best_params"] sd = ckpt["state_dict"] Ht = int(sd["t_proj.0.weight"].shape[1]) Hb = int(sd["b_proj.0.weight"].shape[1]) common = dict(Ht=Ht, Hb=Hb, hidden=int(params["hidden_dim"]), n_heads=int(params["n_heads"]), n_layers=int(params["n_layers"]), dropout=float(params["dropout"])) cls = CrossAttnPooled if pooled_or_unpooled == "pooled" else CrossAttnUnpooled model = cls(**common) model.load_state_dict(sd) return model.to(device).eval() # ----------------------------- # Embedding generation # ----------------------------- def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor: if hasattr(torch, "isin"): return torch.isin(ids, test_ids) return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1) class SMILESEmbedder: def __init__(self, device, vocab_path, splits_path, clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True): self.device = device self.max_len = max_len self.use_cache = use_cache self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path) self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval() self.special_ids = self._get_special_ids(self.tokenizer) self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) if self.special_ids else None) self._cache_pooled: Dict[str, torch.Tensor] = {} self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} @staticmethod def _get_special_ids(tokenizer) -> List[int]: cand = [getattr(tokenizer, f"{x}_token_id", None) for x in ("pad", "cls", "sep", "bos", "eos", "mask")] return sorted({int(x) for x in cand if x is not None}) def _tokenize(self, smiles_list): tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len) for k in tok: tok[k] = tok[k].to(self.device) if "attention_mask" not in tok: tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) return tok def _valid_mask(self, ids, attn): valid = attn.bool() if self.special_ids_t is not None and self.special_ids_t.numel() > 0: valid = valid & (~_safe_isin(ids, self.special_ids_t)) return valid @torch.no_grad() def pooled(self, smiles: str) -> torch.Tensor: s = smiles.strip() if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s] tok = self._tokenize([s]) h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state valid = self._valid_mask(tok["input_ids"], tok["attention_mask"]) vf = valid.unsqueeze(-1).float() pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9) if self.use_cache: self._cache_pooled[s] = pooled return pooled @torch.no_grad() def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]: s = smiles.strip() if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s] tok = self._tokenize([s]) h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state valid = self._valid_mask(tok["input_ids"], tok["attention_mask"]) X = h[:, valid[0], :] M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) if self.use_cache: self._cache_unpooled[s] = (X, M) return X, M class ChemBERTaEmbedder: def __init__(self, device, model_name="DeepChem/ChemBERTa-77M-MLM", max_len=512, use_cache=True): from transformers import AutoTokenizer, AutoModel self.device = device self.max_len = max_len self.use_cache = use_cache self.tokenizer = AutoTokenizer.from_pretrained(model_name) self.model = AutoModel.from_pretrained(model_name).to(device).eval() self.special_ids = self._get_special_ids(self.tokenizer) self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) if self.special_ids else None) self._cache_pooled: Dict[str, torch.Tensor] = {} self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} @staticmethod def _get_special_ids(tokenizer) -> List[int]: cand = [getattr(tokenizer, f"{x}_token_id", None) for x in ("pad", "cls", "sep", "bos", "eos", "mask")] return sorted({int(x) for x in cand if x is not None}) def _tokenize(self, smiles_list): tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len) for k in tok: tok[k] = tok[k].to(self.device) if "attention_mask" not in tok: tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) return tok def _valid_mask(self, ids, attn): valid = attn.bool() if self.special_ids_t is not None and self.special_ids_t.numel() > 0: valid = valid & (~_safe_isin(ids, self.special_ids_t)) return valid @torch.no_grad() def pooled(self, smiles: str) -> torch.Tensor: s = smiles.strip() if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s] tok = self._tokenize([s]) h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state valid = self._valid_mask(tok["input_ids"], tok["attention_mask"]) vf = valid.unsqueeze(-1).float() pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9) if self.use_cache: self._cache_pooled[s] = pooled return pooled @torch.no_grad() def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]: s = smiles.strip() if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s] tok = self._tokenize([s]) h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state valid = self._valid_mask(tok["input_ids"], tok["attention_mask"]) X = h[:, valid[0], :] M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) if self.use_cache: self._cache_unpooled[s] = (X, M) return X, M class WTEmbedder: def __init__(self, device, esm_name="facebook/esm2_t33_650M_UR50D", max_len=1022, use_cache=True): self.device = device self.max_len = max_len self.use_cache = use_cache self.tokenizer = EsmTokenizer.from_pretrained(esm_name) self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval() self.special_ids = self._get_special_ids(self.tokenizer) self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long) if self.special_ids else None) self._cache_pooled: Dict[str, torch.Tensor] = {} self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {} @staticmethod def _get_special_ids(tokenizer) -> List[int]: cand = [getattr(tokenizer, f"{x}_token_id", None) for x in ("pad", "cls", "sep", "bos", "eos", "mask")] return sorted({int(x) for x in cand if x is not None}) def _tokenize(self, seq_list): tok = self.tokenizer(seq_list, return_tensors="pt", padding=True, truncation=True, max_length=self.max_len) tok = {k: v.to(self.device) for k, v in tok.items()} if "attention_mask" not in tok: tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device) return tok def _valid_mask(self, ids, attn): valid = attn.bool() if self.special_ids_t is not None and self.special_ids_t.numel() > 0: valid = valid & (~_safe_isin(ids, self.special_ids_t)) return valid @torch.no_grad() def pooled(self, seq: str) -> torch.Tensor: s = seq.strip() if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s] tok = self._tokenize([s]) h = self.model(**tok).last_hidden_state valid = self._valid_mask(tok["input_ids"], tok["attention_mask"]) vf = valid.unsqueeze(-1).float() pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9) if self.use_cache: self._cache_pooled[s] = pooled return pooled @torch.no_grad() def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]: s = seq.strip() if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s] tok = self._tokenize([s]) h = self.model(**tok).last_hidden_state valid = self._valid_mask(tok["input_ids"], tok["attention_mask"]) X = h[:, valid[0], :] M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device) if self.use_cache: self._cache_unpooled[s] = (X, M) return X, M # ----------------------------- # Predictor # ----------------------------- class PeptiVersePredictor: def __init__( self, manifest_path: str | Path, classifier_weight_root: str | Path, esm_name="facebook/esm2_t33_650M_UR50D", clm_name="aaronfeller/PeptideCLM-23M-all", chemberta_name="DeepChem/ChemBERTa-77M-MLM", smiles_vocab="tokenizer/new_vocab.txt", smiles_splits="tokenizer/new_splits.txt", device: Optional[str] = None, ): self.root = Path(classifier_weight_root) self.training_root = self.root / "training_classifiers" self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu")) self.manifest = read_best_manifest_csv(manifest_path) self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name) self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name, vocab_path=str(self.root / smiles_vocab), splits_path=str(self.root / smiles_splits)) self.chemberta_embedder = ChemBERTaEmbedder(self.device, model_name=chemberta_name) self.models: Dict[Tuple[str, str], Any] = {} self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {} self.mapie: Dict[Tuple[str, str], dict] = {} self.ensembles: Dict[Tuple[str, str], List] = {} self._load_all_best_models() def _get_embedder(self, emb_tag: str): if emb_tag == "wt": return self.wt_embedder if emb_tag == "peptideclm": return self.smiles_embedder if emb_tag == "chemberta": return self.chemberta_embedder raise ValueError(f"Unknown emb_tag={emb_tag!r}") def _embed_pooled(self, emb_tag: str, input_str: str) -> np.ndarray: v = self._get_embedder(emb_tag).pooled(input_str) feats = v.detach().cpu().numpy().astype(np.float32) feats = np.nan_to_num(feats, nan=0.0) return np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max) def _embed_unpooled(self, emb_tag: str, input_str: str) -> Tuple[torch.Tensor, torch.Tensor]: return self._get_embedder(emb_tag).unpooled(input_str) def _resolve_dir(self, prop_key: str, model_name: str, emb_tag: str) -> Path: disk_prop = "half_life" if prop_key == "halflife" else prop_key base = self.training_root / disk_prop folder_suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag) if prop_key == "halflife" and emb_tag == "wt": if model_name == "transformer": for d in [base / "transformer_wt_log", base / "transformer_wt"]: if d.exists(): return d if model_name in {"xgb", "xgb_reg"}: d = base / "xgb_wt_log" if d.exists(): return d candidates = [ base / f"{model_name}_{folder_suffix}", base / model_name, ] for d in candidates: if d.exists(): return d raise FileNotFoundError( f"Cannot find model dir for {prop_key}/{model_name}/{emb_tag}. Tried: {candidates}" ) def _load_all_best_models(self): for prop_key, row in self.manifest.items(): for col, parsed, thr in [ ("wt", row.best_wt, row.thr_wt), ("smiles", row.best_smiles, row.thr_smiles), ]: if parsed is None: continue model_name, emb_tag = parsed # binding affinity if prop_key == "binding_affinity": folder = model_name pooled_or_unpooled = "unpooled" if "unpooled" in folder else "pooled" model_dir = self.training_root / "binding_affinity" / folder art = find_best_artifact(model_dir) model = load_binding_model(art, pooled_or_unpooled, self.device) self.models[(prop_key, col)] = model self.meta[(prop_key, col)] = { "task_type": "Regression", "threshold": None, "artifact": str(art), "model_name": pooled_or_unpooled, "emb_tag": emb_tag, "folder": folder, "kind": "binding", } print(f" [LOAD] binding_affinity ({col}): folder={folder}, arch={pooled_or_unpooled}, emb_tag={emb_tag}, art={art.name}") mapie_path = model_dir / "mapie_calibration.joblib" if mapie_path.exists(): try: self.mapie[(prop_key, col)] = joblib.load(mapie_path) print(f" MAPIE loaded from {mapie_path.name}") except Exception as e: print(f" MAPIE load FAILED for ({prop_key}, {col}): {e}") else: print(f" No MAPIE bundle found (uncertainty will be unavailable)") continue # infer emb_tag if emb_tag is None: emb_tag = col model_dir = self._resolve_dir(prop_key, model_name, emb_tag) kind, obj, art = load_artifact(model_dir, self.device) if kind == "torch_ckpt": arch = self._base_arch(model_name) model = build_torch_model_from_ckpt(arch, obj, self.device) else: model = obj self.models[(prop_key, col)] = model self.meta[(prop_key, col)] = { "task_type": row.task_type, "threshold": thr, "artifact": str(art), "model_name": model_name, "emb_tag": emb_tag, "kind": kind, } print(f" [LOAD] ({prop_key}, {col}): kind={kind}, model={model_name}, emb={emb_tag}, task={row.task_type}, art={art.name}") # MAPIE: SVR/ElasticNet, XGBoost regression, AND all regression torch_ckpt is_regression = row.task_type.lower() == "regression" wants_mapie = ( (model_name in MAPIE_REGRESSION_MODELS and is_regression) or (kind == "xgb" and is_regression) or (kind == "torch_ckpt" and is_regression) ) if wants_mapie: mapie_path = model_dir / "mapie_calibration.joblib" if mapie_path.exists(): try: self.mapie[(prop_key, col)] = joblib.load(mapie_path) print(f" MAPIE loaded from {mapie_path.name}") except Exception as e: print(f" MAPIE load FAILED for ({prop_key}, {col}): {e}") else: print(f" No MAPIE bundle found at {mapie_path} (will fall back to ensemble if available)") # Seed ensembles: DNN only, used when MAPIE not available if kind == "torch_ckpt": arch = self._base_arch(model_name) ens = load_seed_ensemble(model_dir, arch, self.device) if ens: self.ensembles[(prop_key, col)] = ens if (prop_key, col) in self.mapie: print(f" Seed ensemble: {len(ens)} seeds loaded (MAPIE takes priority for regression)") else: unc_type = "ensemble_predictive_entropy" if row.task_type.lower() == "classifier" else "ensemble_std" print(f" Seed ensemble: {len(ens)} seeds loaded uncertainty method: {unc_type}") else: if (prop_key, col) in self.mapie: print(f" No seed ensemble (MAPIE covers uncertainty)") else: print(f" No seed ensemble found (checked: {SEED_DIRS}) - uncertainty unavailable") # XGBoost/SVM classifiers: binary entropy if kind in ("xgb", "joblib") and row.task_type.lower() == "classifier": print(f" Uncertainty method: binary_predictive_entropy (computed at inference)") @staticmethod def _base_arch(model_name: str) -> str: if model_name.startswith("transformer"): return "transformer" if model_name.startswith("mlp"): return "mlp" if model_name.startswith("cnn"): return "cnn" return model_name # Feature extraction def _get_features(self, prop_key: str, col: str, input_str: str): meta = self.meta[(prop_key, col)] emb_tag = meta["emb_tag"] kind = meta["kind"] if kind == "torch_ckpt": return self._embed_unpooled(emb_tag, input_str) return self._embed_pooled(emb_tag, input_str) # Uncertainty def _compute_uncertainty(self, prop_key: str, col: str, input_str: str, score: float) -> Tuple[Any, str]: meta = self.meta[(prop_key, col)] kind = meta["kind"] model_name = meta["model_name"] task_type = meta["task_type"].lower() emb_tag = meta["emb_tag"] # Pooled embedding for adaptive MAPIE sigma model def get_pooled_emb(): return self._embed_pooled(emb_tag, input_str) if emb_tag else None # DNN if kind == "torch_ckpt": # Regression: prefer MAPIE if available if task_type == "regression": mapie_bundle = self.mapie.get((prop_key, col)) if mapie_bundle: emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None lo, hi = _mapie_uncertainty(mapie_bundle, score, emb) return (lo, hi), "conformal_prediction_interval" # Fall back to seed ensemble std ens = self.ensembles.get((prop_key, col)) if ens: X, M = self._embed_unpooled(emb_tag, input_str) return _ensemble_reg_uncertainty(ens, X, M), "ensemble_std" return None, "unavailable (no MAPIE bundle and no seed ensemble)" # Classifier: ensemble predictive entropy ens = self.ensembles.get((prop_key, col)) if not ens: return None, "unavailable (no seed ensemble found)" X, M = self._embed_unpooled(emb_tag, input_str) return _ensemble_clf_uncertainty(ens, X, M), "ensemble_predictive_entropy" # XGBoost if kind == "xgb": if task_type == "classifier": return _binary_entropy(score), "binary_predictive_entropy" mapie_bundle = self.mapie.get((prop_key, col)) if mapie_bundle: emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None lo, hi = _mapie_uncertainty(mapie_bundle, score, emb) return (lo, hi), "conformal_prediction_interval" return None, "unavailable (no MAPIE bundle for XGBoost regression)" # SVR / ElasticNet regression: MAPIE if kind == "joblib" and model_name in MAPIE_REGRESSION_MODELS and task_type == "regression": mapie_bundle = self.mapie.get((prop_key, col)) if mapie_bundle: emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None lo, hi = _mapie_uncertainty(mapie_bundle, score, emb) return (lo, hi), "conformal_prediction_interval" return None, "unavailable (MAPIE bundle not found)" # joblib classifiers (SVM, ElasticNet used as classifier) if kind == "joblib" and task_type == "classifier": return _binary_entropy(score), "binary_predictive_entropy_single_model" return None, "unavailable" def predict_property(self, prop_key: str, col: str, input_str: str, uncertainty: bool = False) -> Dict[str, Any]: if (prop_key, col) not in self.models: raise KeyError(f"No model loaded for ({prop_key}, {col}).") meta = self.meta[(prop_key, col)] model = self.models[(prop_key, col)] task_type = meta["task_type"].lower() thr = meta.get("threshold") kind = meta["kind"] model_name = meta["model_name"] if prop_key == "binding_affinity": raise RuntimeError("Use predict_binding_affinity().") # DNN if kind == "torch_ckpt": X, M = self._get_features(prop_key, col, input_str) with torch.no_grad(): raw = model(X, M).squeeze().float().cpu().item() if prop_key == "halflife" and col == "wt" and "log" in model_name: raw = float(np.expm1(raw)) if task_type == "classifier": score = float(1.0 / (1.0 + np.exp(-raw))) out = {"property": prop_key, "col": col, "score": score, "emb_tag": meta["emb_tag"]} if thr is not None: out["label"] = int(score >= float(thr)); out["threshold"] = float(thr) else: out = {"property": prop_key, "col": col, "score": float(raw), "emb_tag": meta["emb_tag"]} # XGBoost elif kind == "xgb": feats = self._get_features(prop_key, col, input_str) pred = float(model.predict(xgb.DMatrix(feats))[0]) if prop_key == "halflife" and col == "wt" and "log" in model_name: pred = float(np.expm1(pred)) out = {"property": prop_key, "col": col, "score": pred, "emb_tag": meta["emb_tag"]} if task_type == "classifier" and thr is not None: out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr) # joblib (SVM / ElasticNet / SVR) elif kind == "joblib": feats = self._get_features(prop_key, col, input_str) if task_type == "classifier": if hasattr(model, "predict_proba"): pred = float(model.predict_proba(feats)[:, 1][0]) elif hasattr(model, "decision_function"): pred = float(1.0 / (1.0 + np.exp(-model.decision_function(feats)[0]))) else: pred = float(model.predict(feats)[0]) out = {"property": prop_key, "col": col, "score": pred, "emb_tag": meta["emb_tag"]} if thr is not None: out["label"] = int(pred >= float(thr)); out["threshold"] = float(thr) else: pred = float(model.predict(feats)[0]) out = {"property": prop_key, "col": col, "score": pred, "emb_tag": meta["emb_tag"]} else: raise RuntimeError(f"Unknown kind={kind}") if uncertainty: u_val, u_type = self._compute_uncertainty(prop_key, col, input_str, out["score"]) out["uncertainty"] = u_val out["uncertainty_type"] = u_type return out def predict_binding_affinity(self, col: str, target_seq: str, binder_str: str, uncertainty: bool = False) -> Dict[str, Any]: prop_key = "binding_affinity" if (prop_key, col) not in self.models: raise KeyError(f"No binding model loaded for ({prop_key}, {col}).") model = self.models[(prop_key, col)] meta = self.meta[(prop_key, col)] arch = meta["model_name"] emb_tag = meta.get("emb_tag") if arch == "pooled": t_vec = self.wt_embedder.pooled(target_seq) b_vec = self._get_embedder(emb_tag or col).pooled(binder_str) if emb_tag else \ (self.wt_embedder.pooled(binder_str) if col == "wt" else self.chemberta_embedder.pooled(binder_str)) with torch.no_grad(): reg, logits = model(t_vec, b_vec) else: T, Mt = self.wt_embedder.unpooled(target_seq) binder_emb = self._get_embedder(emb_tag or col) if emb_tag else \ (self.wt_embedder if col == "wt" else self.smiles_embedder) B, Mb = binder_emb.unpooled(binder_str) with torch.no_grad(): reg, logits = model(T, Mt, B, Mb) affinity = float(reg.squeeze().cpu().item()) cls_logit = int(torch.argmax(logits, dim=-1).cpu().item()) cls_thr = affinity_to_class(affinity) names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"} out = { "property": "binding_affinity", "col": col, "affinity": affinity, "class_by_threshold": names[cls_thr], "class_by_logits": names[cls_logit], "binding_model": arch, } if uncertainty: mapie_bundle = self.mapie.get((prop_key, col)) if mapie_bundle: if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle: # Concatenate target + binder pooled embeddings for sigma model binder_emb_tag = mapie_bundle.get("emb_tag") or col target_emb_tag = mapie_bundle.get("target_emb_tag", "wt") t_vec = self.wt_embedder.pooled(target_seq).cpu().float().numpy() b_vec = self._get_embedder(binder_emb_tag).pooled(binder_str).cpu().float().numpy() emb = np.concatenate([t_vec, b_vec], axis=1) else: emb = None lo, hi = _mapie_uncertainty(mapie_bundle, affinity, emb) out["uncertainty"] = (lo, hi) out["uncertainty_type"] = "conformal_prediction_interval" else: out["uncertainty"] = None out["uncertainty_type"] = "unavailable (no MAPIE bundle found)" return out if __name__ == "__main__": root = Path(__file__).resolve().parent # current script folder predictor = PeptiVersePredictor( manifest_path=root / "best_models.txt", classifier_weight_root=root ) print(predictor.training_root) print("MAPIE keys:", list(predictor.mapie.keys())) print("Ensemble keys:", list(predictor.ensembles.keys())) seq = "GIGAVLKVLTTGLPALISWIKRKRQQ" smiles = "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O" print(predictor.predict_property("hemolysis", "wt", seq)) print(predictor.predict_property("hemolysis", "smiles", smiles, uncertainty=True)) print(predictor.predict_property("nf", "wt", seq, uncertainty=True)) print(predictor.predict_property("nf", "smiles", smiles, uncertainty=True)) print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT")) print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT", uncertainty=True)) seq1 = "GIGAVLKVLTTGLPALISWIKRKRQQ" seq2 = "ACDEFGHIKLMNPQRSTVWY" r1 = predictor.predict_binding_affinity("wt", target_seq=seq2, binder_str="GIGAVLKVLT", uncertainty=True) r2 = predictor.predict_property("nf", "wt", seq1, uncertainty=True) r3 = predictor.predict_property("nf", "wt", seq2, uncertainty=True) print(r1) print(r2) print(r3)