Spaces:
Running
Running
| """Graph model inference.""" | |
| from __future__ import annotations | |
| import pickle | |
| from pathlib import Path | |
| from app.models.graph.regimen_embedder import regimen_embedding | |
| from app.models.graph.hetero_encoder import encode_regimen | |
| from app.models.graph.pairwise_ddi_head import score_pair | |
| from app.models.graph.severe_alert_head import severe_alert_probability | |
| from app.models.graph.side_effect_head import predict_side_effects | |
| def _model_path() -> Path: | |
| return Path("outputs/models/graph_model.pkl") | |
def _positive_class_probs(raw) -> list[float]:
    """Return first-sample per-label probabilities from a multi-label ``predict_proba`` result.

    Two layouts are accepted:

    * an indexable 2-D result (samples x labels) -- row 0 is converted to floats;
    * a Python ``list`` with one (samples x classes) array per label, as some
      multi-output estimators return -- row 0, column 1 of each entry is used,
      assuming class index 1 is the positive class (TODO confirm for the
      trained artifact).
    """
    if isinstance(raw, list):
        return [float(per_label[0][1]) for per_label in raw]
    return [float(p) for p in raw[0]]


def infer_graph_risk(
    drugs: list[str],
    model_path: Path | None = None,
    *,
    side_effect_threshold: float = 0.05,
) -> dict:
    """Score a drug regimen with the graph heads, optionally refined by a trained artifact.

    Args:
        drugs: Drug identifiers making up the regimen. They are passed as-is
            to the graph helper functions.
        model_path: Location of the pickled model artifact. Defaults to
            ``_model_path()`` when ``None``.
        side_effect_threshold: Only side effects whose model probability is
            strictly greater than this value are kept (default 0.05). Applies
            only when the artifact supplies a side-effect model.

    Returns:
        A dict with these keys:

        * ``regimen_embedding`` -- from ``regimen_embedding``.
        * ``severe_alert_probability`` -- from ``severe_alert_probability``,
          replaced by the artifact's ``severe_model`` when one is present.
        * ``side_effect_probs`` -- from ``predict_side_effects``, replaced by
          ``{label: prob}`` from the artifact's ``side_model``/``mlb`` when
          both are present.
        * ``pairwise_ddi_severity`` -- ``score_pair(a, b)`` keyed by
          ``"a__b"`` for every pair with ``a`` before ``b`` in ``drugs``.

        When the artifact file does not exist, the helper-derived values are
        returned unchanged.
    """
    path = model_path if model_path is not None else _model_path()
    base = {
        "regimen_embedding": regimen_embedding(drugs),
        "severe_alert_probability": severe_alert_probability(drugs),
        "side_effect_probs": predict_side_effects(drugs),
        # Each unordered pair appears once, ordered by position in `drugs`.
        "pairwise_ddi_severity": {
            f"{a}__{b}": score_pair(a, b)
            for i, a in enumerate(drugs)
            for b in drugs[i + 1 :]
        },
    }
    if not path.exists():
        return base

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point `model_path` at artifacts produced by a trusted pipeline.
    # NOTE(review): the artifact is reloaded on every call; consider caching
    # it if this runs on a hot path.
    with path.open("rb") as f:
        artifact = pickle.load(f)

    # A single-sample feature row; assumes encode_regimen returns an array
    # supporting .reshape (TODO confirm: numpy).
    encoded = encode_regimen(drugs).reshape(1, -1)
    severe_model = artifact.get("severe_model")
    side_model = artifact.get("side_model")
    mlb = artifact.get("mlb")

    if severe_model is not None and hasattr(severe_model, "predict_proba"):
        # Column 1 is taken as the positive ("severe") class.
        base["severe_alert_probability"] = float(severe_model.predict_proba(encoded)[0][1])

    # Guarded the same way as severe_model so an artifact without
    # predict_proba falls back to the helper values instead of crashing.
    if side_model is not None and mlb is not None and hasattr(side_model, "predict_proba"):
        side_probs = _positive_class_probs(side_model.predict_proba(encoded))
        base["side_effect_probs"] = {
            str(label): prob
            for label, prob in zip(mlb.classes_, side_probs)
            if prob > side_effect_threshold
        }
    return base