| """Graph safety agent.""" |
|
|
| from __future__ import annotations |
|
|
| from app.common.types import PolyGuardState |
| from app.knowledge.ddi_knowledge import top_risky_pairs |
| from app.models.graph.infer import infer_graph_risk |
|
|
|
|
class GraphSafetyAgent:
    """Agent that assembles a graph-based safety report for a medication list."""

    name = "GraphSafetyAgent"

    def run(self, state: PolyGuardState) -> dict:
        """Build the graph-risk report for the patient's medications.

        Drug names are read from ``state.patient.medications`` and passed to
        ``infer_graph_risk`` and ``top_risky_pairs``.

        Returns:
            A dict holding every entry returned by ``infer_graph_risk`` plus:

            * ``top_dangerous_pairs`` -- the first five items from
              ``top_risky_pairs``.
            * ``top_dangerous_triples`` -- at most two three-drug lists taken
              from consecutive positions in the medication list (empty when
              fewer than three drugs are present).
            * ``mechanism_tags`` -- the first five keys of the risk result's
              ``side_effect_probs`` entry, or an empty list when that entry
              is absent.
        """
        max_pairs = 5
        max_tags = 5

        drug_names = [medication.drug for medication in state.patient.medications]
        graph_risk = infer_graph_risk(drug_names)
        risky_pairs = top_risky_pairs(drug_names)

        # Sliding windows starting at positions 0 and 1 only. zip() stops at
        # the shortest slice, so lists shorter than three drugs give no
        # windows. The windows are positional; no risk ranking is applied.
        consecutive_triples = [
            [first, second, third]
            for first, second, third in zip(
                drug_names[:2], drug_names[1:3], drug_names[2:4]
            )
        ]

        side_effect_probs = graph_risk.get("side_effect_probs", {})
        tags = list(side_effect_probs.keys())[:max_tags]

        # Start from the raw risk output so the agent-level keys below take
        # precedence on any name collision.
        report = {**graph_risk}
        report.update(
            {
                "top_dangerous_pairs": risky_pairs[:max_pairs],
                "top_dangerous_triples": consecutive_triples,
                "mechanism_tags": tags,
            }
        )
        return report
|
|