TheJackBright's picture
Deploy GitHub root master to Space
c296d62
"""Graph model inference."""
from __future__ import annotations

import logging
import pickle
from collections.abc import Mapping
from pathlib import Path

from app.models.graph.hetero_encoder import encode_regimen
from app.models.graph.pairwise_ddi_head import score_pair
from app.models.graph.regimen_embedder import regimen_embedding
from app.models.graph.severe_alert_head import severe_alert_probability
from app.models.graph.side_effect_head import predict_side_effects
def _model_path() -> Path:
    """Return the default location of the pickled graph-model artifact.

    The path is relative, so it is resolved against the process's current
    working directory at the time it is used.
    """
    return Path("outputs", "models", "graph_model.pkl")
_LOGGER = logging.getLogger(__name__)

# Default cut-off: learned side-effect probabilities at or below this value are
# omitted from the model-backed ``side_effect_probs`` output.
_SIDE_EFFECT_MIN_PROB = 0.05


def _heuristic_risk(drugs: list[str]) -> dict:
    """Build the fallback risk report from the non-learned graph helpers."""
    return {
        "regimen_embedding": regimen_embedding(drugs),
        "severe_alert_probability": severe_alert_probability(drugs),
        "side_effect_probs": predict_side_effects(drugs),
        # One entry per unordered pair, keyed "a__b" in input order.
        "pairwise_ddi_severity": {
            f"{a}__{b}": score_pair(a, b)
            for i, a in enumerate(drugs)
            for b in drugs[i + 1 :]
        },
    }


def _load_artifact(path: Path) -> Mapping | None:
    """Unpickle the trained-model artifact at *path*, or return None.

    None is returned when the file does not exist, cannot be unpickled, or
    does not hold a mapping (the only shape the caller knows how to read).
    """
    if not path.exists():
        return None
    # SECURITY: pickle.load can execute arbitrary code during unpickling.
    # Only point this at artifacts produced by a trusted training pipeline,
    # never at user-supplied files.
    try:
        with path.open("rb") as f:
            artifact = pickle.load(f)
    except Exception:
        # Unpickling can fail in many ways (corrupt file, missing classes,
        # version skew); treat all of them as "no trained model".
        _LOGGER.warning("Could not load graph model artifact from %s", path, exc_info=True)
        return None
    if not isinstance(artifact, Mapping):
        _LOGGER.warning(
            "Graph model artifact at %s is a %s, expected a mapping; ignoring it",
            path,
            type(artifact).__name__,
        )
        return None
    return artifact


def infer_graph_risk(
    drugs: list[str],
    model_path: Path | str | None = None,
    *,
    side_effect_threshold: float = _SIDE_EFFECT_MIN_PROB,
) -> dict:
    """Score a drug regimen with the graph model, falling back to heuristics.

    A baseline report is always built from the heuristic helpers. If a trained
    artifact can be loaded from *model_path*, its models override
    ``severe_alert_probability`` and/or ``side_effect_probs``. Any failure in
    the learned path is logged, and the heuristic value for that field is
    kept.

    Args:
        drugs: Drug identifiers making up the regimen.
        model_path: Location of the pickled artifact. A ``str`` is accepted
            as well as a ``Path``. Falsy values use the default path.
        side_effect_threshold: Learned side-effect probabilities must be
            strictly greater than this to appear in ``side_effect_probs``.

    Returns:
        A dict with keys ``regimen_embedding``, ``severe_alert_probability``,
        ``side_effect_probs`` and ``pairwise_ddi_severity``.
    """
    path = Path(model_path) if model_path else _model_path()
    result = _heuristic_risk(drugs)

    artifact = _load_artifact(path)
    if artifact is None:
        return result

    severe_model = artifact.get("severe_model")
    side_model = artifact.get("side_model")
    mlb = artifact.get("mlb")
    use_severe = severe_model is not None and hasattr(severe_model, "predict_proba")
    use_side = side_model is not None and mlb is not None
    if not (use_severe or use_side):
        return result

    try:
        # A single-sample feature row: the models receive shape (1, n_features).
        encoded = encode_regimen(drugs).reshape(1, -1)
    except Exception:
        _LOGGER.warning("encode_regimen failed; using heuristic graph risk", exc_info=True)
        return result

    if use_severe:
        try:
            # NOTE(review): assumes a binary classifier whose column 1 is the
            # "severe" class. Confirm against the training code.
            result["severe_alert_probability"] = float(
                severe_model.predict_proba(encoded)[0][1]
            )
        except Exception:
            _LOGGER.warning(
                "Severe-alert model failed; keeping heuristic probability", exc_info=True
            )

    if use_side:
        try:
            side_probs = side_model.predict_proba(encoded)[0]
            # NOTE(review): zip silently truncates if mlb.classes_ and the
            # probability row differ in length. Presumably they always match.
            scored = (
                (str(label), float(prob)) for label, prob in zip(mlb.classes_, side_probs)
            )
            result["side_effect_probs"] = {
                label: prob for label, prob in scored if prob > side_effect_threshold
            }
        except Exception:
            _LOGGER.warning(
                "Side-effect model failed; keeping heuristic predictions", exc_info=True
            )

    return result