File size: 1,748 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
"""Graph model inference."""

from __future__ import annotations

import pickle
from pathlib import Path

from app.models.graph.regimen_embedder import regimen_embedding
from app.models.graph.hetero_encoder import encode_regimen
from app.models.graph.pairwise_ddi_head import score_pair
from app.models.graph.severe_alert_head import severe_alert_probability
from app.models.graph.side_effect_head import predict_side_effects


def _model_path() -> Path:
    return Path("outputs/models/graph_model.pkl")


def infer_graph_risk(
    drugs: list[str],
    model_path: Path | str | None = None,
    *,
    side_effect_threshold: float = 0.05,
) -> dict:
    """Score a drug regimen, preferring a trained model artifact when one exists.

    Heuristic scores are always computed first from the graph helper
    functions. If a pickled artifact is found at ``model_path`` (or at the
    default ``_model_path()`` location), its ``severe_model`` and
    ``side_model`` predictions replace the heuristic severe-alert probability
    and side-effect probabilities respectively.

    Args:
        drugs: Drug identifiers making up the regimen.
        model_path: Location of the pickled artifact. ``None`` uses the
            default path; a ``str`` is converted to ``Path``.
        side_effect_threshold: Model side-effect probabilities that are not
            strictly greater than this value are omitted from
            ``side_effect_probs``. Only applied to model output, not to the
            heuristic ``predict_side_effects`` result.

    Returns:
        A dict with keys ``"regimen_embedding"``,
        ``"severe_alert_probability"``, ``"side_effect_probs"`` and
        ``"pairwise_ddi_severity"``. The last maps ``f"{a}__{b}"`` to
        ``score_pair(a, b)`` for every pair ``(drugs[i], drugs[j])`` with
        ``i < j``.

    Raises:
        Any exception raised by ``pickle.load`` for an unreadable artifact,
        or by the helper functions, propagates unchanged.
    """
    path = Path(model_path) if model_path is not None else _model_path()
    result = {
        "regimen_embedding": regimen_embedding(drugs),
        "severe_alert_probability": severe_alert_probability(drugs),
        "side_effect_probs": predict_side_effects(drugs),
        "pairwise_ddi_severity": {
            f"{a}__{b}": score_pair(a, b)
            for i, a in enumerate(drugs)
            for b in drugs[i + 1 :]
        },
    }
    # Only a regular file can be unpickled; a missing path (or a directory,
    # which open() would reject) falls back to the heuristic scores.
    if not path.is_file():
        return result

    # SECURITY: pickle.load can execute arbitrary code embedded in the file.
    # Only point model_path at artifacts produced by a trusted pipeline.
    with path.open("rb") as f:
        artifact = pickle.load(f)

    # A single-row feature matrix for the scikit-learn-style predict_proba API.
    encoded = encode_regimen(drugs).reshape(1, -1)
    severe_model = artifact.get("severe_model")
    side_model = artifact.get("side_model")
    mlb = artifact.get("mlb")

    if severe_model is not None and hasattr(severe_model, "predict_proba"):
        # NOTE(review): assumes a binary classifier whose positive class is
        # column 1 of predict_proba's output — confirm against training code.
        result["severe_alert_probability"] = float(
            severe_model.predict_proba(encoded)[0][1]
        )

    # Guard the side model the same way as the severe model, so an artifact
    # without predict_proba keeps the heuristic instead of crashing.
    if (
        side_model is not None
        and mlb is not None
        and hasattr(side_model, "predict_proba")
    ):
        # NOTE(review): assumes row 0 holds one probability per label, aligned
        # with mlb.classes_ (e.g. a one-vs-rest model) — confirm.
        side_probs = side_model.predict_proba(encoded)[0]
        kept: dict[str, float] = {}
        for label, raw_prob in zip(mlb.classes_, side_probs):
            prob = float(raw_prob)
            if prob > side_effect_threshold:
                kept[str(label)] = prob
        result["side_effect_probs"] = kept
    return result