| """Graph dataset builder.""" | |
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from app.knowledge.ddi_knowledge import top_risky_pairs | |
| from app.knowledge.side_effect_ontology import SIDE_EFFECT_TAGS | |
@dataclass
class GraphSample:
    """One training sample describing a single drug regimen.

    Declared as a dataclass so that instances can be created with keyword
    arguments (``GraphSample(drugs=..., side_effects=..., severe_alert=...)``)
    and get field-wise ``__eq__`` and ``__repr__`` for free.
    """

    # Drugs that make up the regimen, as supplied by the caller.
    drugs: list[str]
    # Side-effect tags associated with the regimen's drugs.
    side_effects: list[str]
    # Integer flag: 1 when the regimen is marked as a severe alert, else 0.
    severe_alert: int
def build_graph_samples(regimens: list[list[str]]) -> list[GraphSample]:
    """Build one ``GraphSample`` per regimen.

    For every regimen the side-effect tags of its drugs are looked up in
    ``SIDE_EFFECT_TAGS`` (drugs without an entry contribute nothing),
    de-duplicated and sorted. ``severe_alert`` is 1 when
    ``top_risky_pairs(regimen)`` returns a truthy value and 0 otherwise.

    Args:
        regimens: Regimens to convert, each a list of drug names.

    Returns:
        Samples in the same order as ``regimens``; an empty list when
        ``regimens`` is empty.
    """
    built: list[GraphSample] = []
    for drug_list in regimens:
        # A set comprehension collapses duplicate tags shared by several drugs.
        unique_tags = {
            tag
            for name in drug_list
            for tag in SIDE_EFFECT_TAGS.get(name, [])
        }
        sample = GraphSample(
            drugs=drug_list,
            side_effects=sorted(unique_tags),
            severe_alert=int(bool(top_risky_pairs(drug_list))),
        )
        built.append(sample)
    return built