File size: 886 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
"""Graph safety agent."""

from __future__ import annotations

from app.common.types import PolyGuardState
from app.knowledge.ddi_knowledge import top_risky_pairs
from app.models.graph.infer import infer_graph_risk


class GraphSafetyAgent:
    """Summarise graph-model interaction risk for a patient's medication list.

    Merges the dict returned by ``infer_graph_risk`` with the pairs returned
    by ``top_risky_pairs`` and a few derived summary fields.
    """

    name = "GraphSafetyAgent"

    # Upper bounds on the length of each list in the result. Defaults match
    # the original hard-coded values; subclasses/instances may override them.
    max_pairs = 5
    max_triples = 2
    max_mechanism_tags = 5

    def run(self, state: PolyGuardState) -> dict:
        """Compute graph-based safety findings for ``state.patient``.

        Args:
            state: Pipeline state. Only ``state.patient.medications`` is read,
                taking each medication's ``drug`` attribute in list order.

        Returns:
            A new dict with every key of the ``infer_graph_risk`` result plus
            the following keys (which override same-named keys from it):

            * ``top_dangerous_pairs``: the first ``max_pairs`` items returned
              by ``top_risky_pairs``.
            * ``top_dangerous_triples``: up to ``max_triples`` windows of three
              consecutive drugs, in medication-list order (empty when fewer
              than three drugs are present).
            * ``mechanism_tags``: the first ``max_mechanism_tags`` keys of the
              risk result's ``side_effect_probs`` entry (empty when that entry
              is missing or ``None``).
        """
        drugs = [m.drug for m in state.patient.medications]
        risk = infer_graph_risk(drugs)
        top_pairs = top_risky_pairs(drugs)

        # NOTE(review): these windows follow medication-list order and are not
        # ranked by any risk score, despite the "top_dangerous" key name.
        # range() of a non-positive count is empty, so < 3 drugs yields [].
        triples = [
            drugs[i : i + 3]
            for i in range(min(self.max_triples, len(drugs) - 2))
        ]

        # `or {}` also covers an explicit None value, which `.get(key, {})`
        # would pass through and then fail on when listing its keys.
        side_effect_probs = risk.get("side_effect_probs") or {}

        return {
            **risk,
            "top_dangerous_pairs": top_pairs[: self.max_pairs],
            "top_dangerous_triples": triples,
            "mechanism_tags": list(side_effect_probs)[: self.max_mechanism_tags],
        }