| """ |
| Bundle format: |
| { |
| "quantile": q, |
| "sigma_model": xgb_booster, |
| "emb_tag": "wt"|"peptideclm"|"chemberta", |
| "alpha": 0.1, |
| "adaptive": True, |
| } |
| |
| Binding affinity bundles additionally store "target_emb_tag": "wt" since |
| both binder and target embeddings are concatenated for the sigma model. |
| """ |
|
|
| import argparse |
| import sys |
| import numpy as np |
| import pandas as pd |
| import joblib |
| import xgboost as xgb |
| import torch |
| from pathlib import Path |
| from typing import Optional |
|
|
# Make sibling modules (e.g. inference_new) importable when run as a script.
sys.path.insert(0, str(Path(__file__).parent))


# Shared compute device for all embedders; prefer GPU when available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Directory containing this script; tokenizer assets are resolved against it.
WEIGHT_ROOT = Path(__file__).parent
|
|
| |
|
|
# Property folders that cannot be calibrated here (their OOF CSVs carry no
# sequences to embed).
SKIP_PROPS = {"half_life", "halflife"}


def should_skip(model_dir: Path) -> bool:
    """Return True when any path component of `model_dir` is a skipped property."""
    for component in model_dir.parts:
        if component in SKIP_PROPS:
            return True
    return False
|
|
| |
|
|
def infer_emb_tag(folder_name: str) -> Optional[str]:
    """Map a model folder name to the embedding backbone tag it implies.

    Returns "chemberta", "peptideclm", or "wt"; None when no known keyword
    appears in the name.
    """
    lowered = folder_name.lower()
    # Check order matters: "smiles" folders map to PeptideCLM embeddings,
    # and the bare "wt" check comes last since it can appear alongside others.
    keyword_to_tag = (
        ("chemberta", "chemberta"),
        ("peptideclm", "peptideclm"),
        ("smiles", "peptideclm"),
        ("wt", "wt"),
    )
    for keyword, tag in keyword_to_tag:
        if keyword in lowered:
            return tag
    return None
|
|
def is_binding_affinity(model_dir: Path) -> bool:
    """Return True when `model_dir` lies under a binding_affinity subtree."""
    components = model_dir.parts
    return "binding_affinity" in components
|
|
def infer_binding_emb_tags(folder_name: str):
    """Return (binder_emb_tag, target_emb_tag) for a binding-affinity folder.

    Folder naming convention: {target_emb}_{binder_emb}_{pooled|unpooled},
    e.g. wt_wt_unpooled, chemberta_smiles_unpooled, peptideclm_smiles_unpooled.
    The target side is always embedded with the WT model.
    """
    lowered = folder_name.lower()
    if "chemberta" in lowered:
        binder = "chemberta"
    elif "peptideclm" in lowered:
        binder = "peptideclm"
    else:
        binder = "wt"
    return binder, "wt"
|
|
# Candidate column names, checked case-insensitively and in priority order,
# used by resolve_col() to locate the relevant columns in prediction CSVs.
SEQ_CANDIDATES = ["sequence", "smiles", "seq", "peptide", "molecule"]
PRED_CANDIDATES = ["y_prob", "y_pred", "pred_prob", "pred_score", "score", "pred", "prediction"]
TRUE_CANDIDATES = ["y_true", "label", "true_label", "affinity", "y", "target"]
|
|
def resolve_col(df, candidates, label):
    """Return the actual DataFrame column matching any candidate name.

    Matching is case-insensitive and candidates are tried in order.
    Raises ValueError when no candidate is present.
    """
    lower_to_actual = {col.lower(): col for col in df.columns}
    for candidate in candidates:
        actual = lower_to_actual.get(candidate.lower())
        if actual is not None:
            return actual
    raise ValueError(f"Cannot find {label} column. Available: {list(df.columns)}")
|
|
# Lazily-constructed embedder instances, memoized per embedding tag.
_embedders = {}


def get_embedder(emb_tag: str):
    """Return a cached embedder for `emb_tag`, constructing it on first use.

    Construction is deferred because embedders load heavy models; each tag
    is built at most once per process.

    Raises ValueError for an unrecognized tag.
    """
    cached = _embedders.get(emb_tag)
    if cached is not None:
        return cached

    if emb_tag == "wt":
        from inference_new import WTEmbedder
        embedder = WTEmbedder(DEVICE)
    elif emb_tag == "peptideclm":
        from inference_new import SMILESEmbedder
        embedder = SMILESEmbedder(
            DEVICE,
            vocab_path=str(WEIGHT_ROOT / "tokenizer/new_vocab.txt"),
            splits_path=str(WEIGHT_ROOT / "tokenizer/new_splits.txt"),
        )
    elif emb_tag == "chemberta":
        from inference_new import ChemBERTaEmbedder
        embedder = ChemBERTaEmbedder(DEVICE)
    else:
        raise ValueError(f"Unknown emb_tag: {emb_tag}")

    _embedders[emb_tag] = embedder
    return embedder
|
|
@torch.no_grad()
def embed_sequences(sequences: list, emb_tag: str) -> np.ndarray:
    """Embed each (whitespace-stripped) sequence via the tagged model's pooled output.

    Returns a float32 array with one row per input sequence.
    """
    embedder = get_embedder(emb_tag)
    rows = [embedder.pooled(seq.strip()).cpu().float().numpy() for seq in sequences]
    return np.vstack(rows).astype(np.float32)
|
|
| |
|
|
def fit_sigma_model(X: np.ndarray, residuals: np.ndarray) -> xgb.Booster:
    """Train a boosted-tree regressor predicting |residual| from features.

    The fitted booster is the per-sample difficulty ("sigma") model used to
    scale conformal intervals adaptively.
    """
    hyperparams = {
        "objective": "reg:squarederror",
        "max_depth": 4,
        "eta": 0.05,
        "subsample": 0.8,
        "colsample_bytree": 0.3,
        "min_child_weight": 5,
        "tree_method": "hist",
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "seed": 1986,
    }
    train_matrix = xgb.DMatrix(X, label=residuals)
    return xgb.train(hyperparams, train_matrix, num_boost_round=200, verbose_eval=False)
|
|
| |
|
|
def fit_standard(model_dir: Path, alpha: float, dry_run: bool) -> str:
    """Fit an adaptive conformal calibration bundle for a standard property model.

    Reads held-out predictions (val CSV, falling back to OOF CSV), trains a
    sigma model on absolute residuals, computes the finite-sample conformal
    quantile at miscoverage `alpha`, and writes `mapie_calibration.joblib`
    into `model_dir`. Returns a short status string for logging.
    """
    csv_path = model_dir / "val_predictions.csv"
    if not csv_path.exists():
        csv_path = model_dir / "oof_predictions.csv"
        if not csv_path.exists():
            return "skip (no val/oof CSV)"

    emb_tag = infer_emb_tag(model_dir.name)
    if emb_tag is None:
        return "skip (cannot infer emb_tag)"

    try:
        df = pd.read_csv(csv_path)
        seq_col = resolve_col(df, SEQ_CANDIDATES, "sequence")
        pred_col = resolve_col(df, PRED_CANDIDATES, "pred")
        true_col = resolve_col(df, TRUE_CANDIDATES, "true")
    except Exception as e:
        return f"error: {e}"

    y_pred = df[pred_col].to_numpy(dtype=np.float64)
    y_true = df[true_col].to_numpy(dtype=np.float64)

    # Drop rows with NaN/inf in either column, keeping sequences aligned.
    keep = np.isfinite(y_pred) & np.isfinite(y_true)
    sequences = [s for s, ok in zip(df[seq_col].astype(str).tolist(), keep) if ok]
    y_pred = y_pred[keep]
    y_true = y_true[keep]
    n = len(y_pred)

    if n < 30:
        return f"skip (only {n} samples)"
    if dry_run:
        return f"would fit (n={n}, emb={emb_tag})"

    try:
        X = embed_sequences(sequences, emb_tag)
    except Exception as e:
        return f"error embedding: {e}"

    # Normalized conformal scores: |residual| scaled by predicted difficulty.
    residuals = np.abs(y_true - y_pred).astype(np.float32)
    sigma_model = fit_sigma_model(X, residuals)
    sigma_cal = sigma_model.predict(xgb.DMatrix(X)).astype(np.float64)
    sigma_cal = np.clip(sigma_cal, 1e-6, None)  # guard against division by ~0
    norm_scores = residuals / sigma_cal

    # Finite-sample corrected level: ceil((n+1)(1-alpha))/n, capped at 1.
    level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    q = float(np.quantile(norm_scores, level))
    lo = y_pred - q * sigma_cal
    hi = y_pred + q * sigma_cal
    coverage = float(np.mean((y_true >= lo) & (y_true <= hi)))
    avg_width = float(np.mean(hi - lo))

    bundle = {
        "quantile": q,
        "sigma_model": sigma_model,
        "emb_tag": emb_tag,
        "alpha": alpha,
        "adaptive": True,
    }
    joblib.dump(bundle, model_dir / "mapie_calibration.joblib")
    return f"ok (n={n}, emb={emb_tag}, q={q:.4f}, cov={coverage:.3f}, avg_width={avg_width:.3f})"
|
|
| |
|
|
def fit_binding_affinity(model_dir: Path, alpha: float, dry_run: bool) -> str:
    """Fit an adaptive conformal calibration bundle for a binding-affinity model.

    Like fit_standard(), but embeds binder and target sequences separately
    and concatenates [target, binder] embeddings as sigma-model features.
    The saved bundle additionally records "target_emb_tag".
    Returns a short status string for logging.
    """
    val_path = model_dir / "val_predictions.csv"
    if not val_path.exists():
        return "skip (no val_predictions.csv)"

    binder_emb, target_emb = infer_binding_emb_tags(model_dir.name)

    try:
        df = pd.read_csv(val_path)
    except Exception as e:
        return f"error reading CSV: {e}"

    # Binding CSVs must carry both a binder and a target sequence column.
    lower_cols = {c.lower(): c for c in df.columns}
    if "sequence" not in lower_cols or "target_sequence" not in lower_cols:
        return f"skip (missing sequence/target_sequence columns, have: {list(df.columns)})"

    try:
        pred_col = resolve_col(df, PRED_CANDIDATES, "pred")
        true_col = resolve_col(df, TRUE_CANDIDATES, "true")
    except Exception as e:
        return f"error: {e}"

    y_pred = df[pred_col].to_numpy(dtype=np.float64)
    y_true = df[true_col].to_numpy(dtype=np.float64)

    # Drop rows with NaN/inf in either column, keeping both sequence lists aligned.
    keep = np.isfinite(y_pred) & np.isfinite(y_true)
    binder_seqs = [s for s, ok in zip(df[lower_cols["sequence"]].astype(str).tolist(), keep) if ok]
    target_seqs = [s for s, ok in zip(df[lower_cols["target_sequence"]].astype(str).tolist(), keep) if ok]
    y_pred, y_true = y_pred[keep], y_true[keep]
    n = len(y_pred)

    if n < 30:
        return f"skip (only {n} samples)"
    if dry_run:
        return f"would fit (n={n}, binder_emb={binder_emb}, target_emb={target_emb})"

    try:
        X_binder = embed_sequences(binder_seqs, binder_emb)
        X_target = embed_sequences(target_seqs, target_emb)
        # Feature order is target-then-binder; inference-time code must match.
        X = np.concatenate([X_target, X_binder], axis=1)
    except Exception as e:
        return f"error embedding: {e}"

    # Normalized conformal scores: |residual| scaled by predicted difficulty.
    residuals = np.abs(y_true - y_pred).astype(np.float32)
    sigma_model = fit_sigma_model(X, residuals)
    sigma_cal = np.clip(sigma_model.predict(xgb.DMatrix(X)).astype(np.float64), 1e-6, None)
    norm_scores = residuals / sigma_cal

    # Finite-sample corrected level: ceil((n+1)(1-alpha))/n, capped at 1.
    level = min(1.0, np.ceil((n + 1) * (1 - alpha)) / n)
    q = float(np.quantile(norm_scores, level))
    lo, hi = y_pred - q * sigma_cal, y_pred + q * sigma_cal
    coverage = float(np.mean((y_true >= lo) & (y_true <= hi)))
    avg_width = float(np.mean(hi - lo))

    bundle = {
        "quantile": q,
        "sigma_model": sigma_model,
        "emb_tag": binder_emb,
        "target_emb_tag": target_emb,
        "alpha": alpha,
        "adaptive": True,
    }
    joblib.dump(bundle, model_dir / "mapie_calibration.joblib")
    return (f"ok (n={n}, binder={binder_emb}, target={target_emb}, "
            f"q={q:.4f}, cov={coverage:.3f}, avg_width={avg_width:.3f})")
|
|
# Glob patterns identifying model output directories: classic/deep model
# prefixes, plus the binding-affinity embedding-combination folder names
# described in infer_binding_emb_tags().
MODEL_PATTERNS = [
    "xgb_*", "enet_*", "svm_*", "svr_*", "mlp_*", "cnn_*", "transformer_*",
    "wt_wt_*", "wt_smiles_*", "peptideclm_smiles_*", "chemberta_smiles_*",
]
|
|
| def main(): |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--root", type=Path, required=True) |
| parser.add_argument("--alpha", type=float, default=0.1) |
| parser.add_argument("--prop", type=str, default=None, |
| help="Only process a specific property subfolder") |
| parser.add_argument("--dry-run", action="store_true") |
| parser.add_argument("--overwrite", action="store_true") |
| args = parser.parse_args() |
|
|
| search_root = args.root / args.prop if args.prop else args.root |
|
|
| model_dirs = [] |
| for pat in MODEL_PATTERNS: |
| model_dirs.extend(sorted(search_root.rglob(pat))) |
| model_dirs = [d for d in model_dirs if d.is_dir()] |
|
|
| print(f"Found {len(model_dirs)} model dirs under {search_root}") |
| if args.dry_run: |
| print("DRY RUN\n") |
|
|
| counts = {"ok": 0, "skip": 0, "error": 0} |
|
|
| for model_dir in model_dirs: |
| rel = model_dir.relative_to(args.root) |
|
|
| if should_skip(model_dir): |
| print(f" SKIP {rel} (halflife — no sequence in OOF CSV)") |
| counts["skip"] += 1 |
| continue |
|
|
| out = model_dir / "mapie_calibration.joblib" |
| if out.exists() and not args.overwrite: |
| try: |
| b = joblib.load(out) |
| if b.get("adaptive"): |
| print(f" OK {rel} (already adaptive)") |
| counts["ok"] += 1 |
| continue |
| except Exception: |
| pass |
|
|
| print(f" FITTING {rel} ...", end=" ", flush=True) |
| if is_binding_affinity(model_dir): |
| status = fit_binding_affinity(model_dir, args.alpha, args.dry_run) |
| else: |
| status = fit_standard(model_dir, args.alpha, args.dry_run) |
|
|
| tag = "ok" if status.startswith("ok") else ("skip" if status.startswith("skip") else "error") |
| counts[tag] += 1 |
| print(status) |
|
|
| print(f"\nDone. {counts}") |
|
|
|
|
| if __name__ == "__main__": |
| main() |