#!/usr/bin/env python3 from __future__ import annotations import os from pathlib import Path from typing import Dict, List, Optional, Tuple from huggingface_hub import snapshot_download from inference import ( PeptiVersePredictor, read_best_manifest_csv, canon_model, ) # ----------------------------- # Config # ----------------------------- root = Path(__file__).resolve().parent # current script folder MODEL_REPO = "ChatterjeeLab/PeptiVerse" DEFAULT_ASSETS_DIR = Path(root) # where downloaded models live DEFAULT_MANIFEST = Path("./basic_models.txt") BANNED_MODELS = {"svm", "enet", "svm_gpu", "enet_gpu"} def _norm_prop_disk(prop_key: str) -> str: return "half_life" if prop_key == "halflife" else prop_key def _resolve_expected_model_dir(prop_key: str, model_name: str, mode: str) -> str: disk_prop = _norm_prop_disk(prop_key) base = f"training_classifiers/{disk_prop}" # binding affinity is special: its label is pooled/unpooled and folder uses wt__ if prop_key == "binding_affinity": pooled_or_unpooled = model_name # "pooled" or "unpooled" return f"{base}/wt_{mode}_{pooled_or_unpooled}" # halflife special folders if prop_key == "halflife": if model_name in {"xgb_wt_log", "xgb_smiles"}: return f"{base}/{model_name}" if mode == "wt" and model_name == "transformer": return f"{base}/transformer_wt_log" if model_name == "xgb": return f"{base}/{'xgb_wt_log' if mode == 'wt' else 'xgb_smiles'}" return f"{base}/{model_name}_{mode}" def build_allow_patterns_from_manifest(manifest_path: Path) -> List[str]: best = read_best_manifest_csv(manifest_path) allow: List[str] = [] # For each property, fetch best artifacts for wt + smiles for prop_key, row in best.items(): for mode, label in [("wt", row.best_wt), ("smiles", row.best_smiles)]: m = canon_model(label) if m is None: continue if m in BANNED_MODELS: m = "xgb" model_dir = _resolve_expected_model_dir(prop_key, m, mode) # fetch only "basic" artifacts, not everything in the folder allow += [ f"{model_dir}/best_model.json", f"{model_dir}/best_model.pt", f"{model_dir}/best_model*.joblib", f"{model_dir}/best_model*.json", ] seen = set() out = [] for p in allow: if p not in seen: out.append(p) seen.add(p) return out def download_assets( repo_id: str, manifest_path: Path, out_dir: Path, ) -> Path: out_dir = out_dir.resolve() out_dir.mkdir(parents=True, exist_ok=True) allow_patterns = build_allow_patterns_from_manifest(manifest_path) snapshot_download( repo_id=repo_id, local_dir=str(out_dir), local_dir_use_symlinks=False, allow_patterns=allow_patterns, ) return out_dir # ----------------------------- # Main # ----------------------------- def main(): import argparse ap = argparse.ArgumentParser(description="Lightweight PeptiVerse inference with on-demand model download.") ap.add_argument("--repo", default=MODEL_REPO, help="HF repo id containing weights/assets.") ap.add_argument("--manifest", default=str(DEFAULT_MANIFEST), help="Path to best_models.txt") ap.add_argument("--assets", default=str(DEFAULT_ASSETS_DIR), help="Where to store downloaded assets") ap.add_argument("--device", default=None, help="cuda / cpu / cuda:0, etc") ap.add_argument("--property", default="hemolysis", help="Property key (e.g. hemolysis, solubility, ...)") ap.add_argument("--mode", default="wt", choices=["wt", "smiles"], help="Input type: wt=AA sequence, smiles=SMILES") ap.add_argument("--input", default="GIGAVLKVLTTGLPALISWIKRKRQQ", help="Sequence or SMILES string") ap.add_argument("--target_seq", default=None, help="Target WT sequence for binding_affinity") ap.add_argument("--binder", default=None, help="Binder string (AA or SMILES) for binding_affinity") args = ap.parse_args() manifest_path = Path(args.manifest) if not manifest_path.exists(): raise FileNotFoundError(f"Manifest not found: {manifest_path}") assets_dir = download_assets(args.repo, manifest_path=manifest_path, out_dir=Path(args.assets)) """ OPTIONAL TEST CODE predictor = PeptiVersePredictor( manifest_path="basic_models.txt", # use the downloaded copy to be consistent classifier_weight_root=str(assets_dir), device=args.device, ) if args.property == "binding_affinity": if not args.target_seq or not args.binder: raise ValueError("For binding_affinity, provide --target_seq and --binder.") out = predictor.predict_binding_affinity(args.mode, target_seq=args.target_seq, binder_str=args.binder) else: out = predictor.predict_property(args.property, args.mode, args.input) print(out) """ if __name__ == "__main__": main()