"""Graph model inference.""" from __future__ import annotations import pickle from pathlib import Path from app.models.graph.regimen_embedder import regimen_embedding from app.models.graph.hetero_encoder import encode_regimen from app.models.graph.pairwise_ddi_head import score_pair from app.models.graph.severe_alert_head import severe_alert_probability from app.models.graph.side_effect_head import predict_side_effects def _model_path() -> Path: return Path("outputs/models/graph_model.pkl") def infer_graph_risk(drugs: list[str], model_path: Path | None = None) -> dict: path = model_path or _model_path() base = { "regimen_embedding": regimen_embedding(drugs), "severe_alert_probability": severe_alert_probability(drugs), "side_effect_probs": predict_side_effects(drugs), "pairwise_ddi_severity": { f"{a}__{b}": score_pair(a, b) for i, a in enumerate(drugs) for b in drugs[i + 1 :] }, } if not path.exists(): return base try: with path.open("rb") as f: artifact = pickle.load(f) except Exception: return base encoded = encode_regimen(drugs).reshape(1, -1) severe_model = artifact.get("severe_model") side_model = artifact.get("side_model") mlb = artifact.get("mlb") if severe_model is not None and hasattr(severe_model, "predict_proba"): try: base["severe_alert_probability"] = float(severe_model.predict_proba(encoded)[0][1]) except Exception: pass if side_model is not None and mlb is not None: try: side_probs = side_model.predict_proba(encoded)[0] base["side_effect_probs"] = { str(label): float(prob) for label, prob in zip(mlb.classes_, side_probs) if float(prob) > 0.05 } except Exception: pass return base