Spaces:
Running
Running
| from __future__ import annotations | |
| import csv, re, json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, Optional, Tuple, Any, List | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import joblib | |
| import xgboost as xgb | |
| from transformers import EsmModel, EsmTokenizer, AutoModelForMaskedLM | |
| from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer | |
| from lightning.pytorch import seed_everything | |
| seed_everything(1986) | |
| # ----------------------------- | |
| # Manifest | |
| # ----------------------------- | |
# Maps a manifest embedding tag to the on-disk folder suffix used by trained
# model directories (e.g. "mlp" + "_" + "smiles" -> "mlp_smiles").
EMB_TAG_TO_FOLDER_SUFFIX = {
    "wt": "wt",
    "peptideclm": "smiles",
    "chemberta": "chemberta",
}
# Maps a manifest embedding tag to the runtime embedder mode name.
EMB_TAG_TO_RUNTIME_MODE = {
    "wt": "wt",
    "peptideclm": "smiles",
    "chemberta": "chemberta",
}
# Regression model families whose uncertainty comes from a MAPIE conformal bundle.
MAPIE_REGRESSION_MODELS = {"svr", "enet_gpu"}
# Torch DNN head architectures rebuilt via build_torch_model_from_ckpt.
DNN_ARCHS = {"mlp", "cnn", "transformer"}
# XGBoost model-name variants (loaded via xgb.Booster).
XGB_MODELS = {"xgb", "xgb_reg", "xgb_wt_log", "xgb_smiles"}
@dataclass
class BestRow:
    """One parsed row of the best-model manifest for a single property.

    BUG FIX: the class had field annotations but no @dataclass decorator,
    so the keyword construction in read_best_manifest_csv raised TypeError.
    """
    property_key: str                                 # normalized key (see normalize_property_key)
    best_wt: Optional[Tuple[str, Optional[str]]]      # (model_name, emb_tag) for WT-sequence input
    best_smiles: Optional[Tuple[str, Optional[str]]]  # (model_name, emb_tag) for SMILES input
    task_type: str                                    # e.g. "Classifier" or "Regression"
    thr_wt: Optional[float]                           # decision threshold for the WT model
    thr_smiles: Optional[float]                       # decision threshold for the SMILES model
| def _clean(s: str) -> str: | |
| return (s or "").strip() | |
def _none_if_dash(s: str) -> Optional[str]:
    """Return the cleaned string, or None when it encodes a missing value."""
    cleaned = _clean(s)
    if cleaned in {"", "-", "-", "NA", "N/A"}:
        return None
    return cleaned
def _float_or_none(s: str) -> Optional[float]:
    """Parse a float from a manifest cell; missing-value markers yield None."""
    cleaned = _clean(s)
    if cleaned in {"", "-", "-", "NA", "N/A"}:
        return None
    return float(cleaned)
def normalize_property_key(name: str) -> str:
    """Canonicalize a manifest property name into a stable lookup key."""
    key = name.strip().lower()
    key = re.sub(r"\s*\(.*?\)\s*", "", key)  # drop parenthesized qualifiers
    key = key.replace("-", "_").replace(" ", "_")
    # Generic "permeability" (not PAMPA/Caco-specific) collapses to one key.
    if "permeability" in key and "pampa" not in key and "caco" not in key:
        return "permeability_penetrance"
    aliases = {
        "binding_affinity": "binding_affinity",
        "halflife": "halflife",
        "half_life": "halflife",
        "non_fouling": "nf",
    }
    return aliases.get(key, key)
# Canonical model-name lookup: manifest spellings (upper-cased) -> internal
# identifiers. Unknown names fall back to lower-case in _parse_model_and_emb.
MODEL_ALIAS = {
    "SVM": "svm_gpu",
    "SVR": "svr",
    "ENET": "enet_gpu",
    "CNN": "cnn",
    "MLP": "mlp",
    "TRANSFORMER": "transformer",
    "XGB": "xgb",
    "XGB_REG": "xgb_reg",
    "POOLED": "pooled",
    "UNPOOLED": "unpooled",
    "TRANSFORMER_WT_LOG": "transformer_wt_log",
}
def _parse_model_and_emb(raw: Optional[str]) -> Optional[Tuple[str, Optional[str]]]:
    """Parse a manifest cell like "MLP (wt)" into (canonical_model, emb_tag).

    Returns None for missing/placeholder cells. The embedding tag (if any)
    is lowercased; the model name is canonicalized via MODEL_ALIAS, falling
    back to its lower-case form.
    """
    if raw is None:
        return None
    cell = _clean(raw)
    if not cell or cell in {"-", "-", "NA", "N/A"}:
        return None
    match = re.match(r"^(.+?)\s*\((.+?)\)\s*$", cell)
    if match:
        model_raw = match.group(1).strip()
        emb_tag = match.group(2).strip().lower()
    else:
        model_raw, emb_tag = cell, None
    return MODEL_ALIAS.get(model_raw.upper(), model_raw.lower()), emb_tag
def read_best_manifest_csv(path: str | Path) -> Dict[str, BestRow]:
    """Parse the best-model manifest CSV into {property_key: BestRow}.

    Tolerant of spreadsheet-export quirks: blank rows are skipped, trailing
    empty cells are trimmed, and short rows are padded to header length.
    The first non-empty row is taken as the header.
    """
    result: Dict[str, BestRow] = {}
    header: Optional[List[str]] = None
    with Path(path).open("r", newline="") as fh:
        for cells in csv.reader(fh):
            if not cells or all(_clean(c) == "" for c in cells):
                continue  # fully blank row
            # Trim trailing empty cells.
            while cells and _clean(cells[-1]) == "":
                cells = cells[:-1]
            if header is None:
                header = [h.strip() for h in cells]
                continue
            if len(cells) < len(header):
                cells = cells + [""] * (len(header) - len(cells))
            rec = dict(zip(header, cells))
            prop_raw = _clean(rec.get("Properties", ""))
            if not prop_raw:
                continue
            prop_key = normalize_property_key(prop_raw)
            result[prop_key] = BestRow(
                property_key=prop_key,
                best_wt=_parse_model_and_emb(_none_if_dash(rec.get("Best_Model_WT", ""))),
                best_smiles=_parse_model_and_emb(_none_if_dash(rec.get("Best_Model_SMILES", ""))),
                task_type=_clean(rec.get("Type", "Classifier")),
                thr_wt=_float_or_none(rec.get("Threshold_WT", "")),
                thr_smiles=_float_or_none(rec.get("Threshold_SMILES", "")),
            )
    return result
| # ----------------------------- | |
| # Generic artifact loading | |
| # ----------------------------- | |
def find_best_artifact(model_dir: Path) -> Path:
    """Locate the preferred saved-model file inside *model_dir*.

    Glob patterns are tried in priority order; the lexicographically first
    match wins. Falls back to seed_1986/model.pt, else raises
    FileNotFoundError.
    """
    patterns = ("best_model.json", "best_model.pt", "best_model*.joblib",
                "model.json", "model.ubj", "final_model.json")
    for pattern in patterns:
        matches = sorted(model_dir.glob(pattern))
        if matches:
            return matches[0]
    fallback = model_dir / "seed_1986" / "model.pt"
    if fallback.exists():
        return fallback
    raise FileNotFoundError(f"No best_model artifact found in {model_dir}")
def load_artifact(model_dir: Path, device: torch.device) -> Tuple[str, Any, Path]:
    """Load the best artifact in *model_dir*, dispatching on file suffix.

    Returns (kind, loaded_object, artifact_path), where kind is one of
    "xgb", "joblib", or "torch_ckpt".
    """
    artifact = find_best_artifact(model_dir)
    suffix = artifact.suffix
    if suffix == ".json":
        booster = xgb.Booster()
        booster.load_model(str(artifact))
        return "xgb", booster, artifact
    if suffix == ".joblib":
        return "joblib", joblib.load(artifact), artifact
    if suffix == ".pt":
        # weights_only=False: checkpoints carry plain dicts (params), not just tensors.
        checkpoint = torch.load(artifact, map_location=device, weights_only=False)
        return "torch_ckpt", checkpoint, artifact
    raise ValueError(f"Unknown artifact type: {artifact}")
| # ----------------------------- | |
| # NN architectures | |
| # ----------------------------- | |
class MaskedMeanPool(nn.Module):
    """Mean-pool token embeddings over the valid (mask == True) positions."""
    def forward(self, X, M):
        # (B, T) mask -> (B, T, 1) float weights; clamp avoids division by
        # zero when a row has no valid tokens.
        weights = M.unsqueeze(-1).float()
        totals = (X * weights).sum(dim=1)
        counts = weights.sum(dim=1).clamp(min=1.0)
        return totals / counts
class MLPHead(nn.Module):
    """Masked-mean pooling followed by a 2-layer MLP producing one scalar per row."""
    def __init__(self, in_dim, hidden=512, dropout=0.1):
        super().__init__()
        # NOTE: module creation order matches the checkpoint's state_dict keys.
        self.pool = MaskedMeanPool()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(hidden, 1),
        )
    def forward(self, X, M):
        pooled = self.pool(X, M)
        out = self.net(pooled)
        return out.squeeze(-1)
class CNNHead(nn.Module):
    """Stacked 1-D convolutions over the token axis, masked-mean pooled to one output."""
    def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
        super().__init__()
        stack = []
        prev = in_ch
        for _ in range(layers):
            stack.append(nn.Conv1d(prev, c, kernel_size=k, padding=k // 2))
            stack.append(nn.GELU())
            stack.append(nn.Dropout(dropout))
            prev = c
        self.conv = nn.Sequential(*stack)
        self.head = nn.Linear(c, 1)
    def forward(self, X, M):
        # Conv1d expects (B, C, T); transpose in and back out.
        feats = self.conv(X.transpose(1, 2)).transpose(1, 2)
        weights = M.unsqueeze(-1).float()
        pooled = (feats * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
class TransformerHead(nn.Module):
    """Linear projection + TransformerEncoder, masked-mean pooled to one output per row."""
    def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
        super().__init__()
        self.proj = nn.Linear(in_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=ff,
            dropout=dropout, batch_first=True, activation="gelu"
        )
        self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
        self.head = nn.Linear(d_model, 1)
    def forward(self, X, M):
        # src_key_padding_mask expects True at PAD positions, hence ~M.
        hidden = self.enc(self.proj(X), src_key_padding_mask=~M)
        weights = M.unsqueeze(-1).float()
        pooled = (hidden * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)
        return self.head(pooled).squeeze(-1)
| def _infer_in_dim_from_sd(sd: dict, model_name: str) -> int: | |
| if model_name == "mlp": return int(sd["net.0.weight"].shape[1]) | |
| if model_name == "cnn": return int(sd["conv.0.weight"].shape[1]) | |
| if model_name == "transformer": return int(sd["proj.weight"].shape[1]) | |
| raise ValueError(model_name) | |
| def _infer_num_layers_from_sd(sd: dict, prefix: str = "enc.layers.") -> int: | |
| idxs = set() | |
| for k in sd.keys(): | |
| if k.startswith(prefix): | |
| m = re.match(r"(\d+)\.", k[len(prefix):]) | |
| if m: | |
| idxs.add(int(m.group(1))) | |
| return (max(idxs) + 1) if idxs else 1 | |
def _infer_transformer_arch_from_sd(sd: dict) -> Tuple[int, int, int]:
    """Infer (d_model, num_layers, ff_dim) for a TransformerHead from its state_dict."""
    if "proj.weight" not in sd:
        raise KeyError("Missing proj.weight in state_dict")
    d_model = int(sd["proj.weight"].shape[0])
    n_layers = _infer_num_layers_from_sd(sd, prefix="enc.layers.")
    ff_key = "enc.layers.0.linear1.weight"
    # Fall back to the conventional 4x width when the FF weight is absent.
    ff = int(sd[ff_key].shape[0]) if ff_key in sd else 4 * d_model
    return d_model, n_layers, ff
| def _pick_nhead(d_model: int) -> int: | |
| for h in (8, 6, 4, 3, 2, 1): | |
| if d_model % h == 0: | |
| return h | |
| return 1 | |
def build_torch_model_from_ckpt(model_name: str, ckpt: dict, device: torch.device) -> nn.Module:
    """Rebuild an eval-mode head module from a training checkpoint dict.

    *ckpt* must contain "best_params" and "state_dict"; the input dim is
    taken from ckpt["in_dim"] when present, else inferred from weights.
    Raises ValueError for an unrecognized model name.
    """
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    in_dim = int(ckpt.get("in_dim", _infer_in_dim_from_sd(sd, model_name)))
    dropout = float(params.get("dropout", 0.1))
    if model_name == "mlp":
        model = MLPHead(in_dim=in_dim, hidden=int(params["hidden"]), dropout=dropout)
    elif model_name == "cnn":
        model = CNNHead(in_ch=in_dim, c=int(params["channels"]), k=int(params["kernel"]),
                        layers=int(params["layers"]), dropout=dropout)
    elif model_name == "transformer":
        # The width may be recorded under several names; falsy values fall through.
        d_model = params.get("d_model") or params.get("hidden") or params.get("hidden_dim")
        if d_model is None:
            # No recorded width: infer the architecture from the weights,
            # letting explicit params override inferred values.
            d_model_i, layers_i, ff_i = _infer_transformer_arch_from_sd(sd)
            model = TransformerHead(
                in_dim=in_dim, d_model=int(d_model_i),
                nhead=int(params.get("nhead", _pick_nhead(d_model_i))),
                layers=int(params.get("layers", layers_i)),
                ff=int(params.get("ff", ff_i)),
                dropout=float(params.get("dropout", dropout)),
            )
        else:
            d_model = int(d_model)
            model = TransformerHead(
                in_dim=in_dim, d_model=d_model,
                nhead=int(params.get("nhead", _pick_nhead(d_model))),
                layers=int(params.get("layers", 2)),
                ff=int(params.get("ff", 4 * d_model)),
                dropout=dropout,
            )
    else:
        raise ValueError(f"Unknown NN model_name={model_name}")
    model.load_state_dict(sd)
    model.to(device).eval()
    return model
| # ----------------------------- | |
| # Wrappers | |
| # ----------------------------- | |
| from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin | |
class PassthroughRegressor(BaseEstimator, RegressorMixin):
    """sklearn-compatible shim that serves precomputed predictions.

    Lets externally computed scores flow through sklearn/MAPIE tooling
    that expects an estimator interface.
    """
    def __init__(self, preds: np.ndarray):
        self.preds = preds
    def fit(self, X, y):
        # Nothing to fit: predictions are fixed at construction time.
        return self
    def predict(self, X):
        # Truncate so the output length matches the request.
        return self.preds[:len(X)]
class PassthroughClassifier(BaseEstimator, ClassifierMixin):
    """sklearn-compatible shim over precomputed positive-class probabilities."""
    def __init__(self, preds: np.ndarray):
        self.preds = preds
        self.classes_ = np.array([0, 1])
    def fit(self, X, y):
        # Nothing to fit: probabilities are fixed at construction time.
        return self
    def predict(self, X):
        # Hard labels at the conventional 0.5 threshold.
        return (self.preds[:len(X)] >= 0.5).astype(int)
    def predict_proba(self, X):
        pos = self.preds[:len(X)]
        # Column order matches classes_: [P(class 0), P(class 1)].
        return np.stack([1 - pos, pos], axis=1)
| # ----------------------------- | |
| # Uncertainty helpers | |
| # ----------------------------- | |
# Per-seed checkpoint subdirectories probed when assembling a DNN ensemble
# for uncertainty estimation (see load_seed_ensemble); missing dirs are skipped.
SEED_DIRS = ["seed_1986", "seed_42", "seed_0", "seed_123", "seed_12345"]
def load_seed_ensemble(model_dir: Path, arch: str, device: torch.device) -> List[nn.Module]:
    """Load every available per-seed checkpoint under *model_dir* as an ensemble.

    Missing seed directories are skipped; the returned list may be empty.
    """
    members: List[nn.Module] = []
    for seed_dir in SEED_DIRS:
        ckpt_path = model_dir / seed_dir / "model.pt"
        if not ckpt_path.exists():
            continue
        state = torch.load(ckpt_path, map_location=device, weights_only=False)
        members.append(build_torch_model_from_ckpt(arch, state, device))
    return members
| def _binary_entropy(p: float) -> float: | |
| p = float(np.clip(p, 1e-9, 1 - 1e-9)) | |
| return float(-p * np.log(p) - (1 - p) * np.log(1 - p)) | |
def _ensemble_clf_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float:
    """Predictive entropy of the seed ensemble's mean sigmoid probability."""
    with torch.no_grad():
        logits = [member(X, M).squeeze().float().cpu().item() for member in ensemble]
    probs = [1.0 / (1.0 + np.exp(-z)) for z in logits]
    return _binary_entropy(float(np.mean(probs)))
| def _ensemble_reg_uncertainty(ensemble: List[nn.Module], X: torch.Tensor, M: torch.Tensor) -> float: | |
| preds = [] | |
| with torch.no_grad(): | |
| for m in ensemble: | |
| preds.append(m(X, M).squeeze().float().cpu().item()) | |
| return float(np.std(preds)) | |
| def _mapie_uncertainty(mapie_bundle: dict, score: float, | |
| embedding: Optional[np.ndarray] = None) -> Tuple[float, float]: | |
| """ | |
| Returns (ci_low, ci_high) from a conformal bundle. | |
| - adaptive: {"quantile": q, "sigma_model": xgb, "emb_tag": ..., "adaptive": True} | |
| Input-dependent: interval = score +/- q * sigma(embedding) | |
| - plain_quantile: {"quantile": q, "alpha": ...} | |
| Fixed-width: interval = score +/- q | |
| """ | |
| # Adaptive format is input-dependent interval | |
| if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle: | |
| q = float(mapie_bundle["quantile"]) | |
| if embedding is not None: | |
| # Adaptive interval: y_hat ± q * sigma_hat(x). | |
| # Equivalent to MAPIE's get_estimation_distribution(): | |
| # y_pred + conformity_scores * r_pred | |
| # where conformity_scores=q and r_pred=sigma_hat(x). | |
| # (ResidualNormalisedScore, Cordier et al. 2023) | |
| sigma_model = mapie_bundle["sigma_model"] | |
| sigma = float(sigma_model.predict(xgb.DMatrix(embedding.reshape(1, -1)))[0]) | |
| sigma = max(sigma, 1e-6) | |
| else: | |
| # No embedding available - fall back to fixed interval with sigma=1 | |
| sigma = 1.0 | |
| return float(score - q * sigma), float(score + q * sigma) | |
| # Plain quantile format | |
| if "quantile" in mapie_bundle: | |
| q = float(mapie_bundle["quantile"]) | |
| return float(score - q), float(score + q) | |
| X_dummy = np.zeros((1, 1)) | |
| result = mapie.predict(X_dummy) | |
| if isinstance(result, tuple): | |
| intervals = np.asarray(result[1]) | |
| if intervals.ndim == 3: | |
| return float(intervals[0, 0, 0]), float(intervals[0, 1, 0]) | |
| return float(intervals[0, 0]), float(intervals[0, 1]) | |
| raise RuntimeError( | |
| f"Cannot extract intervals: unknown MAPIE bundle format. " | |
| f"Bundle keys: {list(mapie_bundle.keys())}." | |
| ) | |
def affinity_to_class(y: float) -> int:
    """Bucket an affinity value into 3 ordinal classes: 0 (>= 9), 1 ([7, 9)), 2 (< 7)."""
    if y >= 9.0:
        return 0
    if y < 7.0:
        return 2
    return 1
class CrossAttnPooled(nn.Module):
    """Bidirectional cross-attention over two pooled (per-sequence) vectors.

    Each input vector is projected to `hidden`, treated as a length-1
    sequence, and cross-attended against the other stream for n_layers.
    Returns (regression_scalar, 3-class_logits) — a multi-task head.
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        # Per-stream projections; Ht and Hb may differ (two embedder dims).
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            # attn_tb: t attends to b; attn_bt: b attends to t.
            # batch_first=False because tensors are shaped (seq=1, batch, hidden).
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
                "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))
        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)   # regression output
        self.cls = nn.Linear(hidden, 3)   # 3-way class logits (see affinity_to_class)
    def forward(self, t_vec, b_vec):
        # unsqueeze(0): (B, H) -> (1, B, H), i.e. a length-1 token sequence.
        t = self.t_proj(t_vec).unsqueeze(0)
        b = self.b_proj(b_vec).unsqueeze(0)
        for L in self.layers:
            # Post-norm residual blocks. The transpose(0,1) round-trips are
            # shape-neutral for LayerNorm over the last dim; kept as written
            # to match the trained checkpoints exactly.
            t_attn, _ = L["attn_tb"](t, b, b)
            t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
            t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
            b_attn, _ = L["attn_bt"](b, t, t)
            b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
            b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
        # Concatenate both streams' single tokens, then the shared trunk.
        h = self.shared(torch.cat([t[0], b[0]], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)
class CrossAttnUnpooled(nn.Module):
    """Bidirectional cross-attention over two token-level sequences.

    Token-resolution counterpart of CrossAttnPooled: inputs are (B, T, H)
    with boolean validity masks; output pooling is a masked mean per stream.
    Returns (regression_scalar, 3-class_logits).
    """
    def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
        super().__init__()
        self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
        self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
        self.layers = nn.ModuleList([])
        for _ in range(n_layers):
            # batch_first=True here (unlike CrossAttnPooled): inputs are (B, T, H).
            self.layers.append(nn.ModuleDict({
                "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
                "n1t": nn.LayerNorm(hidden), "n2t": nn.LayerNorm(hidden),
                "n1b": nn.LayerNorm(hidden), "n2b": nn.LayerNorm(hidden),
                "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
                "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
            }))
        self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
        self.reg = nn.Linear(hidden, 1)   # regression output
        self.cls = nn.Linear(hidden, 3)   # 3-way class logits
    def _masked_mean(self, X, M):
        # Mean over valid positions; clamp avoids divide-by-zero on empty masks.
        Mf = M.unsqueeze(-1).float()
        return (X * Mf).sum(dim=1) / Mf.sum(dim=1).clamp(min=1.0)
    def forward(self, T, Mt, B, Mb):
        T = self.t_proj(T); Bx = self.b_proj(B)
        # key_padding_mask wants True at PAD positions, hence the inversions.
        kp_t, kp_b = ~Mt, ~Mb
        for L in self.layers:
            T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
            T = L["n1t"](T + T_attn); T = L["n2t"](T + L["fft"](T))
            B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
            Bx = L["n1b"](Bx + B_attn); Bx = L["n2b"](Bx + L["ffb"](Bx))
        h = self.shared(torch.cat([self._masked_mean(T, Mt), self._masked_mean(Bx, Mb)], dim=-1))
        return self.reg(h).squeeze(-1), self.cls(h)
def load_binding_model(best_model_pt: Path, pooled_or_unpooled: str, device: torch.device) -> nn.Module:
    """Rebuild a binding-affinity cross-attention model from a checkpoint.

    Hyperparameters come from ckpt["best_params"]; the two input dims are
    read from the projection weight shapes in the state_dict.
    """
    ckpt = torch.load(best_model_pt, map_location=device, weights_only=False)
    params = ckpt["best_params"]
    sd = ckpt["state_dict"]
    arch_cls = CrossAttnPooled if pooled_or_unpooled == "pooled" else CrossAttnUnpooled
    model = arch_cls(
        Ht=int(sd["t_proj.0.weight"].shape[1]),
        Hb=int(sd["b_proj.0.weight"].shape[1]),
        hidden=int(params["hidden_dim"]),
        n_heads=int(params["n_heads"]),
        n_layers=int(params["n_layers"]),
        dropout=float(params["dropout"]),
    )
    model.load_state_dict(sd)
    return model.to(device).eval()
| # ----------------------------- | |
| # Embedding generation | |
| # ----------------------------- | |
| def _safe_isin(ids: torch.Tensor, test_ids: torch.Tensor) -> torch.Tensor: | |
| if hasattr(torch, "isin"): | |
| return torch.isin(ids, test_ids) | |
| return (ids.unsqueeze(-1) == test_ids.view(1, 1, -1)).any(dim=-1) | |
class SMILESEmbedder:
    """PeptideCLM embeddings for SMILES strings.

    `pooled` returns one masked-mean vector per molecule; `unpooled` returns
    per-token embeddings with special tokens removed. Results are memoized
    per input string when use_cache is True.
    """
    def __init__(self, device, vocab_path, splits_path,
                 clm_name="aaronfeller/PeptideCLM-23M-all", max_len=512, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = SMILES_SPE_Tokenizer(vocab_path, splits_path)
        # Keep only the RoFormer encoder; the MLM head is not needed.
        self.model = AutoModelForMaskedLM.from_pretrained(clm_name).roformer.to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    def _get_special_ids(self, tokenizer) -> List[int]:
        # BUG FIX: was declared without `self`, so the bound call in
        # __init__ passed two positional args and raised TypeError.
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})
    def _tokenize(self, smiles_list):
        """Tokenize to device tensors, synthesizing an all-ones attention mask if absent."""
        tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        for k in tok: tok[k] = tok[k].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok
    def _valid_mask(self, ids, attn):
        """Attention mask with special-token positions additionally zeroed out."""
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid
    def pooled(self, smiles: str) -> torch.Tensor:
        """Masked-mean embedding over non-special tokens, shape (1, H)."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Per-token embeddings (special tokens dropped) plus an all-True mask."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M
class ChemBERTaEmbedder:
    """ChemBERTa embeddings for SMILES strings (same contract as SMILESEmbedder).

    `pooled` returns one masked-mean vector per molecule; `unpooled` returns
    per-token embeddings with special tokens removed. Results are memoized
    per input string when use_cache is True.
    """
    def __init__(self, device, model_name="DeepChem/ChemBERTa-77M-MLM",
                 max_len=512, use_cache=True):
        from transformers import AutoTokenizer, AutoModel
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    def _get_special_ids(self, tokenizer) -> List[int]:
        # BUG FIX: was declared without `self`, so the bound call in
        # __init__ passed two positional args and raised TypeError.
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})
    def _tokenize(self, smiles_list):
        """Tokenize to device tensors, synthesizing an all-ones attention mask if absent."""
        tok = self.tokenizer(smiles_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        for k in tok: tok[k] = tok[k].to(self.device)
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok
    def _valid_mask(self, ids, attn):
        """Attention mask with special-token positions additionally zeroed out."""
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid
    def pooled(self, smiles: str) -> torch.Tensor:
        """Masked-mean embedding over non-special tokens, shape (1, H)."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled
    def unpooled(self, smiles: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Per-token embeddings (special tokens dropped) plus an all-True mask."""
        s = smiles.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(input_ids=tok["input_ids"], attention_mask=tok["attention_mask"]).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M
class WTEmbedder:
    """ESM-2 embeddings for wild-type amino-acid sequences.

    `pooled` returns one masked-mean vector per sequence; `unpooled` returns
    per-token embeddings with special tokens removed. Results are memoized
    per input string when use_cache is True.
    """
    def __init__(self, device, esm_name="facebook/esm2_t33_650M_UR50D", max_len=1022, use_cache=True):
        self.device = device
        self.max_len = max_len
        self.use_cache = use_cache
        self.tokenizer = EsmTokenizer.from_pretrained(esm_name)
        self.model = EsmModel.from_pretrained(esm_name, add_pooling_layer=False).to(device).eval()
        self.special_ids = self._get_special_ids(self.tokenizer)
        self.special_ids_t = (torch.tensor(self.special_ids, device=device, dtype=torch.long)
                              if self.special_ids else None)
        self._cache_pooled: Dict[str, torch.Tensor] = {}
        self._cache_unpooled: Dict[str, Tuple[torch.Tensor, torch.Tensor]] = {}
    def _get_special_ids(self, tokenizer) -> List[int]:
        # BUG FIX: was declared without `self`, so the bound call in
        # __init__ passed two positional args and raised TypeError.
        cand = [getattr(tokenizer, f"{x}_token_id", None)
                for x in ("pad", "cls", "sep", "bos", "eos", "mask")]
        return sorted({int(x) for x in cand if x is not None})
    def _tokenize(self, seq_list):
        """Tokenize to device tensors, synthesizing an all-ones attention mask if absent."""
        tok = self.tokenizer(seq_list, return_tensors="pt", padding=True,
                             truncation=True, max_length=self.max_len)
        tok = {k: v.to(self.device) for k, v in tok.items()}
        if "attention_mask" not in tok:
            tok["attention_mask"] = torch.ones_like(tok["input_ids"], dtype=torch.long, device=self.device)
        return tok
    def _valid_mask(self, ids, attn):
        """Attention mask with special-token positions additionally zeroed out."""
        valid = attn.bool()
        if self.special_ids_t is not None and self.special_ids_t.numel() > 0:
            valid = valid & (~_safe_isin(ids, self.special_ids_t))
        return valid
    def pooled(self, seq: str) -> torch.Tensor:
        """Masked-mean embedding over non-special tokens, shape (1, H)."""
        s = seq.strip()
        if self.use_cache and s in self._cache_pooled: return self._cache_pooled[s]
        tok = self._tokenize([s])
        h = self.model(**tok).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        vf = valid.unsqueeze(-1).float()
        pooled = (h * vf).sum(dim=1) / vf.sum(dim=1).clamp(min=1e-9)
        if self.use_cache: self._cache_pooled[s] = pooled
        return pooled
    def unpooled(self, seq: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Per-token embeddings (special tokens dropped) plus an all-True mask."""
        s = seq.strip()
        if self.use_cache and s in self._cache_unpooled: return self._cache_unpooled[s]
        tok = self._tokenize([s])
        h = self.model(**tok).last_hidden_state
        valid = self._valid_mask(tok["input_ids"], tok["attention_mask"])
        X = h[:, valid[0], :]
        M = torch.ones((1, X.shape[1]), dtype=torch.bool, device=self.device)
        if self.use_cache: self._cache_unpooled[s] = (X, M)
        return X, M
| # ----------------------------- | |
| # Predictor | |
| # ----------------------------- | |
| class PeptiVersePredictor: | |
    def __init__(
        self,
        manifest_path: str | Path,
        classifier_weight_root: str | Path,
        esm_name="facebook/esm2_t33_650M_UR50D",
        clm_name="aaronfeller/PeptideCLM-23M-all",
        chemberta_name="DeepChem/ChemBERTa-77M-MLM",
        smiles_vocab="tokenizer/new_vocab.txt",
        smiles_splits="tokenizer/new_splits.txt",
        device: Optional[str] = None,
    ):
        """Load the manifest, all three embedders, and every best model.

        manifest_path: CSV consumed by read_best_manifest_csv.
        classifier_weight_root: root dir; trained models live under
            <root>/training_classifiers, and the SMILES tokenizer vocab/splits
            are resolved relative to <root>.
        device: explicit torch device string; defaults to CUDA when available.
        """
        self.root = Path(classifier_weight_root)
        self.training_root = self.root / "training_classifiers"
        self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
        self.manifest = read_best_manifest_csv(manifest_path)
        # One embedder per supported embedding tag (see _get_embedder).
        self.wt_embedder = WTEmbedder(self.device, esm_name=esm_name)
        self.smiles_embedder = SMILESEmbedder(self.device, clm_name=clm_name,
                                              vocab_path=str(self.root / smiles_vocab),
                                              splits_path=str(self.root / smiles_splits))
        self.chemberta_embedder = ChemBERTaEmbedder(self.device, model_name=chemberta_name)
        # Keyed by (property_key, "wt"|"smiles"):
        self.models: Dict[Tuple[str, str], Any] = {}      # loaded model objects
        self.meta: Dict[Tuple[str, str], Dict[str, Any]] = {}  # task type, threshold, artifact path...
        self.mapie: Dict[Tuple[str, str], dict] = {}      # conformal-uncertainty bundles
        self.ensembles: Dict[Tuple[str, str], List] = {}  # per-seed DNN ensembles
        self._load_all_best_models()
| def _get_embedder(self, emb_tag: str): | |
| if emb_tag == "wt": return self.wt_embedder | |
| if emb_tag == "peptideclm": return self.smiles_embedder | |
| if emb_tag == "chemberta": return self.chemberta_embedder | |
| raise ValueError(f"Unknown emb_tag={emb_tag!r}") | |
| def _embed_pooled(self, emb_tag: str, input_str: str) -> np.ndarray: | |
| v = self._get_embedder(emb_tag).pooled(input_str) | |
| feats = v.detach().cpu().numpy().astype(np.float32) | |
| feats = np.nan_to_num(feats, nan=0.0) | |
| return np.clip(feats, np.finfo(np.float32).min, np.finfo(np.float32).max) | |
    def _embed_unpooled(self, emb_tag: str, input_str: str) -> Tuple[torch.Tensor, torch.Tensor]:
        """Per-token embeddings and validity mask from the tagged embedder (see its unpooled())."""
        return self._get_embedder(emb_tag).unpooled(input_str)
| def _resolve_dir(self, prop_key: str, model_name: str, emb_tag: str) -> Path: | |
| disk_prop = "half_life" if prop_key == "halflife" else prop_key | |
| base = self.training_root / disk_prop | |
| folder_suffix = EMB_TAG_TO_FOLDER_SUFFIX.get(emb_tag, emb_tag) | |
| if prop_key == "halflife" and emb_tag == "wt": | |
| if model_name == "transformer": | |
| for d in [base / "transformer_wt_log", base / "transformer_wt"]: | |
| if d.exists(): return d | |
| if model_name in {"xgb", "xgb_reg"}: | |
| d = base / "xgb_wt_log" | |
| if d.exists(): return d | |
| candidates = [ | |
| base / f"{model_name}_{folder_suffix}", | |
| base / model_name, | |
| ] | |
| for d in candidates: | |
| if d.exists(): return d | |
| raise FileNotFoundError( | |
| f"Cannot find model dir for {prop_key}/{model_name}/{emb_tag}. Tried: {candidates}" | |
| ) | |
    def _load_all_best_models(self):
        """Load every best model named in the manifest, for both input modalities.

        For each (property, column) pair this populates self.models and
        self.meta, plus (where available) self.mapie with conformal bundles
        and self.ensembles with per-seed DNN ensembles. Progress and
        uncertainty availability are reported via print.
        """
        for prop_key, row in self.manifest.items():
            for col, parsed, thr in [
                ("wt", row.best_wt, row.thr_wt),
                ("smiles", row.best_smiles, row.thr_smiles),
            ]:
                if parsed is None:
                    continue  # manifest had no model for this modality
                model_name, emb_tag = parsed
                # binding affinity
                # Binding affinity uses its own cross-attention models and
                # skips the generic loading path below.
                if prop_key == "binding_affinity":
                    folder = model_name
                    pooled_or_unpooled = "unpooled" if "unpooled" in folder else "pooled"
                    model_dir = self.training_root / "binding_affinity" / folder
                    art = find_best_artifact(model_dir)
                    model = load_binding_model(art, pooled_or_unpooled, self.device)
                    self.models[(prop_key, col)] = model
                    self.meta[(prop_key, col)] = {
                        "task_type": "Regression",
                        "threshold": None,
                        "artifact": str(art),
                        "model_name": pooled_or_unpooled,
                        "emb_tag": emb_tag,
                        "folder": folder,
                        "kind": "binding",
                    }
                    print(f"  [LOAD] binding_affinity ({col}): folder={folder}, arch={pooled_or_unpooled}, emb_tag={emb_tag}, art={art.name}")
                    # NOTE(review): this MAPIE-loading stanza duplicates the
                    # one below; a shared helper would remove the repetition.
                    mapie_path = model_dir / "mapie_calibration.joblib"
                    if mapie_path.exists():
                        try:
                            self.mapie[(prop_key, col)] = joblib.load(mapie_path)
                            print(f"    MAPIE loaded from {mapie_path.name}")
                        except Exception as e:
                            print(f"    MAPIE load FAILED for ({prop_key}, {col}): {e}")
                    else:
                        print(f"    No MAPIE bundle found (uncertainty will be unavailable)")
                    continue
                # infer emb_tag
                # Manifest cell had no "(tag)" suffix: default to the column name.
                if emb_tag is None:
                    emb_tag = col
                model_dir = self._resolve_dir(prop_key, model_name, emb_tag)
                kind, obj, art = load_artifact(model_dir, self.device)
                if kind == "torch_ckpt":
                    arch = self._base_arch(model_name)
                    model = build_torch_model_from_ckpt(arch, obj, self.device)
                else:
                    model = obj
                self.models[(prop_key, col)] = model
                self.meta[(prop_key, col)] = {
                    "task_type": row.task_type,
                    "threshold": thr,
                    "artifact": str(art),
                    "model_name": model_name,
                    "emb_tag": emb_tag,
                    "kind": kind,
                }
                print(f"  [LOAD] ({prop_key}, {col}): kind={kind}, model={model_name}, emb={emb_tag}, task={row.task_type}, art={art.name}")
                # MAPIE: SVR/ElasticNet, XGBoost regression, AND all regression torch_ckpt
                is_regression = row.task_type.lower() == "regression"
                wants_mapie = (
                    (model_name in MAPIE_REGRESSION_MODELS and is_regression)
                    or (kind == "xgb" and is_regression)
                    or (kind == "torch_ckpt" and is_regression)
                )
                if wants_mapie:
                    mapie_path = model_dir / "mapie_calibration.joblib"
                    if mapie_path.exists():
                        try:
                            self.mapie[(prop_key, col)] = joblib.load(mapie_path)
                            print(f"    MAPIE loaded from {mapie_path.name}")
                        except Exception as e:
                            print(f"    MAPIE load FAILED for ({prop_key}, {col}): {e}")
                    else:
                        print(f"    No MAPIE bundle found at {mapie_path} (will fall back to ensemble if available)")
                # Seed ensembles: DNN only, used when MAPIE not available
                if kind == "torch_ckpt":
                    arch = self._base_arch(model_name)
                    ens = load_seed_ensemble(model_dir, arch, self.device)
                    if ens:
                        self.ensembles[(prop_key, col)] = ens
                        if (prop_key, col) in self.mapie:
                            print(f"    Seed ensemble: {len(ens)} seeds loaded (MAPIE takes priority for regression)")
                        else:
                            unc_type = "ensemble_predictive_entropy" if row.task_type.lower() == "classifier" else "ensemble_std"
                            print(f"    Seed ensemble: {len(ens)} seeds loaded uncertainty method: {unc_type}")
                    else:
                        if (prop_key, col) in self.mapie:
                            print(f"    No seed ensemble (MAPIE covers uncertainty)")
                        else:
                            print(f"    No seed ensemble found (checked: {SEED_DIRS}) - uncertainty unavailable")
                # XGBoost/SVM classifiers: binary entropy
                if kind in ("xgb", "joblib") and row.task_type.lower() == "classifier":
                    print(f"    Uncertainty method: binary_predictive_entropy (computed at inference)")
| def _base_arch(model_name: str) -> str: | |
| if model_name.startswith("transformer"): return "transformer" | |
| if model_name.startswith("mlp"): return "mlp" | |
| if model_name.startswith("cnn"): return "cnn" | |
| return model_name | |
| # Feature extraction | |
| def _get_features(self, prop_key: str, col: str, input_str: str): | |
| meta = self.meta[(prop_key, col)] | |
| emb_tag = meta["emb_tag"] | |
| kind = meta["kind"] | |
| if kind == "torch_ckpt": | |
| return self._embed_unpooled(emb_tag, input_str) | |
| return self._embed_pooled(emb_tag, input_str) | |
# Uncertainty
def _compute_uncertainty(self, prop_key: str, col: str, input_str: str,
                         score: float) -> Tuple[Any, str]:
    """Estimate predictive uncertainty for one (property, column) prediction.

    Args:
        prop_key: normalized property key (e.g. "hemolysis").
        col: input column / representation tag ("wt" or "smiles").
        input_str: raw peptide sequence or SMILES string.
        score: the point prediction already computed for this input
            (probability for classifiers, regression value otherwise).

    Returns:
        ``(value, method_name)`` where *value* is a ``(lo, hi)`` conformal
        interval, a scalar entropy/std, or ``None`` when no uncertainty
        source is available; *method_name* describes which method was used
        (or why none was).

    Priority order: MAPIE conformal intervals for regression, then
    seed-ensemble statistics for DNNs, then binary entropy for single
    classifiers.
    """
    meta = self.meta[(prop_key, col)]
    kind = meta["kind"]
    model_name = meta["model_name"]
    task_type = meta["task_type"].lower()
    emb_tag = meta["emb_tag"]
    # Pooled embedding for adaptive MAPIE sigma model
    def get_pooled_emb():
        # Lazy: only computed when the MAPIE bundle is flagged "adaptive".
        return self._embed_pooled(emb_tag, input_str) if emb_tag else None
    # DNN
    if kind == "torch_ckpt":
        # Regression: prefer MAPIE if available
        if task_type == "regression":
            mapie_bundle = self.mapie.get((prop_key, col))
            if mapie_bundle:
                emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
                lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
                return (lo, hi), "conformal_prediction_interval"
            # Fall back to seed ensemble std
            ens = self.ensembles.get((prop_key, col))
            if ens:
                X, M = self._embed_unpooled(emb_tag, input_str)
                return _ensemble_reg_uncertainty(ens, X, M), "ensemble_std"
            return None, "unavailable (no MAPIE bundle and no seed ensemble)"
        # Classifier: ensemble predictive entropy
        ens = self.ensembles.get((prop_key, col))
        if not ens:
            return None, "unavailable (no seed ensemble found)"
        X, M = self._embed_unpooled(emb_tag, input_str)
        return _ensemble_clf_uncertainty(ens, X, M), "ensemble_predictive_entropy"
    # XGBoost
    if kind == "xgb":
        if task_type == "classifier":
            # Entropy computed directly from the predicted probability.
            return _binary_entropy(score), "binary_predictive_entropy"
        mapie_bundle = self.mapie.get((prop_key, col))
        if mapie_bundle:
            emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
            lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
            return (lo, hi), "conformal_prediction_interval"
        return None, "unavailable (no MAPIE bundle for XGBoost regression)"
    # SVR / ElasticNet regression: MAPIE
    if kind == "joblib" and model_name in MAPIE_REGRESSION_MODELS and task_type == "regression":
        mapie_bundle = self.mapie.get((prop_key, col))
        if mapie_bundle:
            emb = get_pooled_emb() if mapie_bundle.get("adaptive") else None
            lo, hi = _mapie_uncertainty(mapie_bundle, score, emb)
            return (lo, hi), "conformal_prediction_interval"
        return None, "unavailable (MAPIE bundle not found)"
    # joblib classifiers (SVM, ElasticNet used as classifier)
    if kind == "joblib" and task_type == "classifier":
        return _binary_entropy(score), "binary_predictive_entropy_single_model"
    return None, "unavailable"
def predict_property(self, prop_key: str, col: str, input_str: str,
                     uncertainty: bool = False) -> Dict[str, Any]:
    """Predict a single property for one input (peptide sequence or SMILES).

    Dispatches on the stored artifact kind — torch checkpoint, XGBoost
    booster, or joblib estimator — and returns a dict with at least
    'property', 'col', 'score', and 'emb_tag'. Classifiers with a manifest
    threshold also get 'label' and 'threshold'; when *uncertainty* is True,
    'uncertainty' and 'uncertainty_type' are appended.

    Raises:
        KeyError: no model was loaded for (prop_key, col).
        RuntimeError: binding affinity requested here, or unknown artifact kind.
    """
    key = (prop_key, col)
    if key not in self.models:
        raise KeyError(f"No model loaded for ({prop_key}, {col}).")
    meta = self.meta[key]
    model = self.models[key]
    task_type = meta["task_type"].lower()
    thr = meta.get("threshold")
    kind = meta["kind"]
    model_name = meta["model_name"]
    if prop_key == "binding_affinity":
        raise RuntimeError("Use predict_binding_affinity().")

    def result(score: float) -> Dict[str, Any]:
        # Common payload shape shared by every branch below.
        return {"property": prop_key, "col": col, "score": score,
                "emb_tag": meta["emb_tag"]}

    def attach_label(res: Dict[str, Any], score: float) -> None:
        # Threshold the probability only when the manifest provides a cutoff.
        res["label"] = int(score >= float(thr))
        res["threshold"] = float(thr)

    def undo_log(value: float) -> float:
        # WT half-life models trained on log1p targets are inverted here.
        if prop_key == "halflife" and col == "wt" and "log" in model_name:
            return float(np.expm1(value))
        return value

    if kind == "torch_ckpt":
        X, M = self._get_features(prop_key, col, input_str)
        with torch.no_grad():
            raw = model(X, M).squeeze().float().cpu().item()
        raw = undo_log(raw)
        if task_type == "classifier":
            prob = float(1.0 / (1.0 + np.exp(-raw)))
            out = result(prob)
            if thr is not None:
                attach_label(out, prob)
        else:
            out = result(float(raw))
    elif kind == "xgb":
        feats = self._get_features(prop_key, col, input_str)
        pred = undo_log(float(model.predict(xgb.DMatrix(feats))[0]))
        out = result(pred)
        if task_type == "classifier" and thr is not None:
            attach_label(out, pred)
    elif kind == "joblib":
        feats = self._get_features(prop_key, col, input_str)
        if task_type == "classifier":
            # Prefer calibrated probabilities; fall back to a sigmoid over the
            # decision function, then to raw predict.
            if hasattr(model, "predict_proba"):
                pred = float(model.predict_proba(feats)[:, 1][0])
            elif hasattr(model, "decision_function"):
                pred = float(1.0 / (1.0 + np.exp(-model.decision_function(feats)[0])))
            else:
                pred = float(model.predict(feats)[0])
            out = result(pred)
            if thr is not None:
                attach_label(out, pred)
        else:
            out = result(float(model.predict(feats)[0]))
    else:
        raise RuntimeError(f"Unknown kind={kind}")

    if uncertainty:
        u_val, u_type = self._compute_uncertainty(prop_key, col, input_str, out["score"])
        out["uncertainty"] = u_val
        out["uncertainty_type"] = u_type
    return out
def predict_binding_affinity(self, col: str, target_seq: str, binder_str: str,
                             uncertainty: bool = False) -> Dict[str, Any]:
    """Predict binding affinity of *binder_str* against *target_seq*.

    The target is always embedded with the WT (protein) embedder; the binder
    embedder is chosen from the MAPIE/meta ``emb_tag`` when present, otherwise
    from *col*. Returns a dict with the regression affinity, two class labels
    (one derived from the affinity threshold, one from the model's own
    classification head), and optionally a conformal uncertainty interval.

    Raises:
        KeyError: no binding model was loaded for ("binding_affinity", col).
    """
    prop_key = "binding_affinity"
    if (prop_key, col) not in self.models:
        raise KeyError(f"No binding model loaded for ({prop_key}, {col}).")
    model = self.models[(prop_key, col)]
    meta = self.meta[(prop_key, col)]
    arch = meta["model_name"]  # "pooled" or "unpooled" (set by the loader)
    emb_tag = meta.get("emb_tag")
    if arch == "pooled":
        # Pooled path: one vector per sequence.
        # NOTE(review): the fallback (no emb_tag) uses chemberta for non-wt
        # binders here but smiles_embedder in the unpooled branch below —
        # confirm this asymmetry is intentional.
        t_vec = self.wt_embedder.pooled(target_seq)
        b_vec = self._get_embedder(emb_tag or col).pooled(binder_str) if emb_tag else \
            (self.wt_embedder.pooled(binder_str) if col == "wt" else self.chemberta_embedder.pooled(binder_str))
        with torch.no_grad():
            reg, logits = model(t_vec, b_vec)
    else:
        # Unpooled path: token-level embeddings plus attention masks.
        T, Mt = self.wt_embedder.unpooled(target_seq)
        binder_emb = self._get_embedder(emb_tag or col) if emb_tag else \
            (self.wt_embedder if col == "wt" else self.smiles_embedder)
        B, Mb = binder_emb.unpooled(binder_str)
        with torch.no_grad():
            reg, logits = model(T, Mt, B, Mb)
    affinity = float(reg.squeeze().cpu().item())
    # cls_logit is actually an argmax class index (0/1/2), not a raw logit.
    cls_logit = int(torch.argmax(logits, dim=-1).cpu().item())
    # Class derived by thresholding the regression output.
    cls_thr = affinity_to_class(affinity)
    names = {0: "High (≥9)", 1: "Moderate (7-9)", 2: "Low (<7)"}
    out = {
        "property": "binding_affinity",
        "col": col,
        "affinity": affinity,
        "class_by_threshold": names[cls_thr],
        "class_by_logits": names[cls_logit],
        "binding_model": arch,
    }
    if uncertainty:
        mapie_bundle = self.mapie.get((prop_key, col))
        if mapie_bundle:
            if mapie_bundle.get("adaptive") and "sigma_model" in mapie_bundle:
                # Concatenate target + binder pooled embeddings for sigma model
                binder_emb_tag = mapie_bundle.get("emb_tag") or col
                target_emb_tag = mapie_bundle.get("target_emb_tag", "wt")
                t_vec = self.wt_embedder.pooled(target_seq).cpu().float().numpy()
                b_vec = self._get_embedder(binder_emb_tag).pooled(binder_str).cpu().float().numpy()
                emb = np.concatenate([t_vec, b_vec], axis=1)
            else:
                emb = None
            lo, hi = _mapie_uncertainty(mapie_bundle, affinity, emb)
            out["uncertainty"] = (lo, hi)
            out["uncertainty_type"] = "conformal_prediction_interval"
        else:
            out["uncertainty"] = None
            out["uncertainty_type"] = "unavailable (no MAPIE bundle found)"
    return out
if __name__ == "__main__":
    # Smoke test: build the manifest-driven predictor from the files next to
    # this script, then run a handful of example predictions.
    root = Path(__file__).resolve().parent  # current script folder
    predictor = PeptiVersePredictor(
        manifest_path=root / "best_models.txt",
        classifier_weight_root=root
    )
    print(predictor.training_root)
    print("MAPIE keys:", list(predictor.mapie.keys()))
    print("Ensemble keys:", list(predictor.ensembles.keys()))
    # Example inputs: a natural peptide sequence and a cyclic-peptide SMILES.
    seq = "GIGAVLKVLTTGLPALISWIKRKRQQ"
    smiles = "C(C)C[C@@H]1NC(=O)[C@@H]2CCCN2C(=O)[C@@H](CC(C)C)NC(=O)[C@@H](CC(C)C)N(C)C(=O)[C@H](C)NC(=O)[C@H](Cc2ccccc2)NC1=O"
    print(predictor.predict_property("hemolysis", "wt", seq))
    print(predictor.predict_property("hemolysis", "smiles", smiles, uncertainty=True))
    print(predictor.predict_property("nf", "wt", seq, uncertainty=True))
    print(predictor.predict_property("nf", "smiles", smiles, uncertainty=True))
    print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT"))
    print(predictor.predict_binding_affinity("wt", target_seq=seq, binder_str="GIGAVLKVLT", uncertainty=True))
    # Compare uncertainty behavior across two different target sequences.
    seq1 = "GIGAVLKVLTTGLPALISWIKRKRQQ"  # same literal as `seq` above
    seq2 = "ACDEFGHIKLMNPQRSTVWY"
    r1 = predictor.predict_binding_affinity("wt", target_seq=seq2, binder_str="GIGAVLKVLT", uncertainty=True)
    r2 = predictor.predict_property("nf", "wt", seq1, uncertainty=True)
    r3 = predictor.predict_property("nf", "wt", seq2, uncertainty=True)
    print(r1)
    print(r2)
    print(r3)