| """Graph model inference.""" |
|
|
| from __future__ import annotations |
|
|
import logging
import pickle
from pathlib import Path
|
|
| from app.models.graph.regimen_embedder import regimen_embedding |
| from app.models.graph.hetero_encoder import encode_regimen |
| from app.models.graph.pairwise_ddi_head import score_pair |
| from app.models.graph.severe_alert_head import severe_alert_probability |
| from app.models.graph.side_effect_head import predict_side_effects |
|
|
|
|
| def _model_path() -> Path: |
| return Path("outputs/models/graph_model.pkl") |
|
|
|
|
_LOGGER = logging.getLogger(__name__)

# Default cut-off: model-predicted side effects with probability <= this are dropped.
_SIDE_EFFECT_THRESHOLD = 0.05


def _baseline_risk(drugs: list[str]) -> dict:
    """Compute the model-free outputs from the graph helper functions.

    These values are returned as-is when no trained artifact is usable, and
    serve as fallbacks for any model prediction that is missing or fails.
    """
    return {
        "regimen_embedding": regimen_embedding(drugs),
        "severe_alert_probability": severe_alert_probability(drugs),
        "side_effect_probs": predict_side_effects(drugs),
        # One entry per unordered pair, in input order, keyed "a__b".
        "pairwise_ddi_severity": {
            f"{a}__{b}": score_pair(a, b)
            for i, a in enumerate(drugs)
            for b in drugs[i + 1 :]
        },
    }


def _load_artifact(path: Path):
    """Unpickle the model artifact at *path*; return None if it is absent or unusable.

    A missing file is treated as the normal "no trained model" case and is not
    logged. Load failures and artifacts without a dict-like ``.get`` are logged
    and reported as None.
    """
    if not path.exists():
        return None
    try:
        # SECURITY: pickle.load can execute arbitrary code. Only point this at
        # artifacts produced by a trusted training pipeline.
        with path.open("rb") as f:
            artifact = pickle.load(f)
    except Exception:  # corrupt file, missing class definitions, version skew, ...
        _LOGGER.warning("Could not load graph model artifact from %s", path, exc_info=True)
        return None
    if not hasattr(artifact, "get"):
        _LOGGER.warning(
            "Graph model artifact at %s is a %s, expected a dict-like object; ignoring it",
            path,
            type(artifact).__name__,
        )
        return None
    return artifact


def infer_graph_risk(
    drugs: list[str],
    model_path: Path | str | None = None,
    *,
    side_effect_threshold: float = _SIDE_EFFECT_THRESHOLD,
) -> dict:
    """Predict graph-based risk outputs for a drug regimen.

    Baseline values are always computed from the graph helper functions. If a
    pickled artifact is available, its ``"severe_model"`` and/or
    ``"side_model"`` (+ ``"mlb"``) entries override ``severe_alert_probability``
    and ``side_effect_probs``. Any model that is missing or fails leaves the
    corresponding baseline value in place; the failure is logged.

    Args:
        drugs: Drug identifiers making up the regimen.
        model_path: Location of the pickled artifact. Falsy values use
            ``_model_path()``; a ``str`` is converted to ``Path``.
        side_effect_threshold: Model-predicted side effects whose probability
            is not strictly greater than this are omitted. It does not filter
            the baseline ``predict_side_effects`` output.

    Returns:
        A dict with keys ``"regimen_embedding"``, ``"severe_alert_probability"``,
        ``"side_effect_probs"`` and ``"pairwise_ddi_severity"``.

    Raises:
        Whatever the baseline helpers or ``encode_regimen`` raise; artifact
        loading and model prediction errors are logged, not raised.
    """
    path = Path(model_path) if model_path else _model_path()
    result = _baseline_risk(drugs)

    artifact = _load_artifact(path)
    if artifact is None:
        return result

    # Models are queried with a single-row feature matrix.
    encoded = encode_regimen(drugs).reshape(1, -1)

    severe_model = artifact.get("severe_model")
    if severe_model is not None and hasattr(severe_model, "predict_proba"):
        try:
            # Column 1 is taken as the positive ("severe") class probability —
            # assumes a binary classifier; TODO confirm against training code.
            result["severe_alert_probability"] = float(severe_model.predict_proba(encoded)[0][1])
        except Exception:
            _LOGGER.warning("severe_model prediction failed; keeping baseline value", exc_info=True)

    side_model = artifact.get("side_model")
    mlb = artifact.get("mlb")  # presumably a fitted multi-label binarizer; only .classes_ is used
    if side_model is not None and mlb is not None and hasattr(side_model, "predict_proba"):
        try:
            side_probs = side_model.predict_proba(encoded)[0]
            result["side_effect_probs"] = {
                str(label): float(prob)
                for label, prob in zip(mlb.classes_, side_probs)
                if float(prob) > side_effect_threshold
            }
        except Exception:
            _LOGGER.warning("side_model prediction failed; keeping baseline value", exc_info=True)

    return result
|
|