Spaces:
Running
Running
File size: 1,748 Bytes
877add7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 | """Graph model inference."""
from __future__ import annotations

import pickle
from itertools import combinations
from pathlib import Path

from app.models.graph.regimen_embedder import regimen_embedding
from app.models.graph.hetero_encoder import encode_regimen
from app.models.graph.pairwise_ddi_head import score_pair
from app.models.graph.severe_alert_head import severe_alert_probability
from app.models.graph.side_effect_head import predict_side_effects
def _model_path() -> Path:
    """Return the default location of the pickled graph-model artifact.

    The path is relative, so it is resolved against the current working
    directory at the time it is used.
    """
    default_location = Path("outputs") / "models" / "graph_model.pkl"
    return default_location
def infer_graph_risk(
    drugs: list[str],
    model_path: Path | None = None,
    *,
    side_effect_threshold: float = 0.05,
) -> dict:
    """Score a drug regimen, optionally refining scores with a trained artifact.

    Baseline values always come from the graph helper functions
    (``regimen_embedding``, ``severe_alert_probability``,
    ``predict_side_effects``, ``score_pair``). If a pickled artifact file
    exists at ``model_path`` (default: ``_model_path()``), its
    ``"severe_model"`` entry overrides ``severe_alert_probability``. Its
    ``"side_model"`` + ``"mlb"`` entries override ``side_effect_probs``.

    Args:
        drugs: Drug identifiers making up the regimen.
        model_path: Location of the pickled artifact; ``None`` uses the
            default from ``_model_path()``.
        side_effect_threshold: When the trained side-effect model is used,
            labels whose probability is not strictly greater than this value
            are omitted. Has no effect on the baseline ``predict_side_effects``
            output.

    Returns:
        Dict with keys ``"regimen_embedding"``, ``"severe_alert_probability"``,
        ``"side_effect_probs"`` and ``"pairwise_ddi_severity"``. The pairwise
        dict is keyed ``"<a>__<b>"`` for every unordered pair of positions in
        ``drugs``, in input order.

    Raises:
        ValueError: If the artifact's side-effect model returns a different
            number of probabilities than ``mlb.classes_`` has labels.
    """
    path = _model_path() if model_path is None else model_path

    base = {
        "regimen_embedding": regimen_embedding(drugs),
        "severe_alert_probability": severe_alert_probability(drugs),
        "side_effect_probs": predict_side_effects(drugs),
        # combinations(drugs, 2) yields (drugs[i], drugs[j]) for i < j, in order.
        "pairwise_ddi_severity": {
            f"{a}__{b}": score_pair(a, b) for a, b in combinations(drugs, 2)
        },
    }

    # is_file() rather than exists(): a directory at this path should fall back
    # to the baseline instead of failing inside open().
    if not path.is_file():
        return base

    # SECURITY: pickle.load can execute arbitrary code from the file. Only point
    # model_path at artifacts produced by a trusted training pipeline.
    with path.open("rb") as f:
        artifact = pickle.load(f)

    # Reshape to a single-row 2-D batch for predict_proba. NOTE(review): assumes
    # encode_regimen returns an array-like with .reshape (e.g. numpy) — confirm.
    encoded = encode_regimen(drugs).reshape(1, -1)
    severe_model = artifact.get("severe_model")
    side_model = artifact.get("side_model")
    mlb = artifact.get("mlb")

    if severe_model is not None and hasattr(severe_model, "predict_proba"):
        # Column 1 of the first (only) row is taken as the alert probability.
        # NOTE(review): assumes a binary classifier whose positive class is
        # second in classes_ — confirm against the training code.
        base["severe_alert_probability"] = float(
            severe_model.predict_proba(encoded)[0][1]
        )

    if (
        side_model is not None
        and mlb is not None
        and hasattr(side_model, "predict_proba")
    ):
        side_probs = side_model.predict_proba(encoded)[0]
        labels = list(mlb.classes_)
        # zip() would silently truncate and pair probabilities with the wrong
        # labels if the model and label binarizer disagree.
        if len(labels) != len(side_probs):
            raise ValueError(
                f"side_model returned {len(side_probs)} probabilities but "
                f"mlb has {len(labels)} classes"
            )
        base["side_effect_probs"] = {
            str(label): float(prob)
            for label, prob in zip(labels, side_probs)
            if float(prob) > side_effect_threshold
        }

    return base
|