Spaces:
Sleeping
Sleeping
File size: 4,226 Bytes
"""Load and cache CSV lookup data for the PolypharmacyEnv."""
from __future__ import annotations
import csv
from dataclasses import dataclass, field
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from .config import (
BEERS_CRITERIA_CSV,
DDI_RULES_CSV,
DRUG_METADATA_CSV,
PATIENTS_CSV,
)
# ── Row-level data classes ───────────────────────────────────────────────────
@dataclass(frozen=True)
class DrugMeta:
    """Immutable record describing one row of the drug-metadata CSV.

    Instances are built by ``load_drug_metadata``, which keys its result
    mapping by ``drug_id``. Being frozen, instances are hashable and their
    fields cannot be reassigned after construction.
    """

    # Identifier used as the lookup key for this drug.
    drug_id: str
    # Generic (non-brand) name, copied verbatim from the CSV.
    generic_name: str
    # Classification string from the "atc_class" column.
    atc_class: str
    # The loader sets this to True only when the CSV cell is exactly "1".
    is_high_risk_elderly: bool
    # Dose fields are parsed with float() by the loader; the unit is
    # presumably milligrams, going by the "_mg" suffix.
    default_dose_mg: float
    min_dose_mg: float
    max_dose_mg: float
@dataclass(frozen=True)
class DDIRule:
    """Immutable record for one drug-drug interaction rule.

    When built by ``load_ddi_rules`` the two drug ids are stored in sorted
    order (``drug_id_1 <= drug_id_2``), matching the key of the mapping that
    loader returns.
    """

    # The two interacting drugs.
    drug_id_1: str
    drug_id_2: str
    # Severity label copied verbatim from the CSV; the set of allowed values
    # is not visible in this module.
    severity: str
    # Free-text fields copied verbatim from the CSV.
    mechanism: str
    recommendation: str
    # Parsed with float() by the loader.
    base_risk_score: float
@dataclass(frozen=True)
class BeersCriterion:
    """Immutable record for one row of the Beers-criteria CSV."""

    # Drug the criterion applies to.
    drug_id: str
    # One of: avoid | caution | dose_adjust | avoid_in_condition
    criterion_type: str
    # ``load_beers_criteria`` stores None here when the CSV cell is blank.
    condition: Optional[str]
    # Free-text justification copied verbatim from the CSV.
    rationale: str
@dataclass
class PatientEpisode:
    """Mutable record for one patient episode row of the patients CSV.

    Unlike the other row classes in this module this dataclass is not
    frozen, so fields may be reassigned and instances are not hashable.
    """

    episode_id: str
    # Parsed with int() by ``load_patients``.
    age: int
    sex: str
    # ``load_patients`` builds this from a ";"-separated cell, dropping
    # empty items.
    conditions: List[str]
    # Category labels copied verbatim from the CSV.
    eGFR_category: str
    liver_function_category: str
    # Built from a ";"-separated cell in the same way as ``conditions``.
    medication_ids: List[str]
    # Parsed with float() by ``load_patients``.
    baseline_risk_score: float
    # Difficulty label; ``load_patients`` can filter on it.
    difficulty: str
# ── Loaders (cached) ─────────────────────────────────────────────────────────
def _read_csv(path: Path) -> List[Dict[str, str]]:
    """Read a CSV file with a header row into a list of per-row dicts.

    Args:
        path: Location of the CSV file.

    Returns:
        One ``{column_name: cell_text}`` dict per data row, in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly instead of relying on the platform locale. The
    # "utf-8-sig" codec also drops a leading byte-order mark (as written by
    # some spreadsheet exports), which would otherwise be fused into the
    # first column name and break lookups by header.
    # newline="" lets the csv module handle embedded line breaks itself.
    with open(path, newline="", encoding="utf-8-sig") as handle:
        return list(csv.DictReader(handle))
@lru_cache(maxsize=1)
def load_drug_metadata(path: Path = DRUG_METADATA_CSV) -> Dict[str, DrugMeta]:
    """Parse the drug-metadata CSV into a ``drug_id -> DrugMeta`` mapping.

    The result is memoised by ``lru_cache(maxsize=1)``, so repeated calls
    with the same argument return the same dict object; callers should
    treat it as read-only.

    Args:
        path: CSV file to read; defaults to ``DRUG_METADATA_CSV``.

    Returns:
        Mapping from drug id to its ``DrugMeta``. If an id appears on
        several rows, the last row wins.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If a dose cell cannot be parsed as a float.
    """
    catalogue: Dict[str, DrugMeta] = {}
    for record in _read_csv(path):
        identifier = record["drug_id"]
        catalogue[identifier] = DrugMeta(
            drug_id=identifier,
            generic_name=record["generic_name"],
            atc_class=record["atc_class"],
            # Only the exact text "1" counts as high risk.
            is_high_risk_elderly=(record["is_high_risk_elderly"] == "1"),
            default_dose_mg=float(record["default_dose_mg"]),
            min_dose_mg=float(record["min_dose_mg"]),
            max_dose_mg=float(record["max_dose_mg"]),
        )
    return catalogue
def _normalise_pair(a: str, b: str) -> Tuple[str, str]:
    """Return the two values as a tuple in ascending order.

    This makes ``(x, y)`` and ``(y, x)`` map to the same key.
    """
    if b <= a:
        return (b, a)
    return (a, b)
@lru_cache(maxsize=1)
def load_ddi_rules(path: Path = DDI_RULES_CSV) -> Dict[Tuple[str, str], DDIRule]:
    """Parse the drug-drug interaction CSV into a pair-keyed mapping.

    Each key is the two drug ids in sorted order (via ``_normalise_pair``),
    so a lookup must normalise its pair the same way. The result is memoised
    by ``lru_cache(maxsize=1)``; callers should treat it as read-only.

    Args:
        path: CSV file to read; defaults to ``DDI_RULES_CSV``.

    Returns:
        Mapping from the sorted ``(drug_id, drug_id)`` pair to its
        ``DDIRule``. If a pair appears on several rows, the last row wins.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If ``base_risk_score`` cannot be parsed as a float.
    """
    rules: Dict[Tuple[str, str], DDIRule] = {}
    for record in _read_csv(path):
        pair = _normalise_pair(record["drug_id_1"], record["drug_id_2"])
        first, second = pair
        rules[pair] = DDIRule(
            drug_id_1=first,
            drug_id_2=second,
            severity=record["severity"],
            mechanism=record["mechanism"],
            recommendation=record["recommendation"],
            base_risk_score=float(record["base_risk_score"]),
        )
    return rules
@lru_cache(maxsize=1)
def load_beers_criteria(path: Path = BEERS_CRITERIA_CSV) -> List[BeersCriterion]:
    """Parse the Beers-criteria CSV into a list of ``BeersCriterion`` rows.

    The result is memoised by ``lru_cache(maxsize=1)``, so repeated calls
    with the same argument return the same list object; callers should
    treat it as read-only.

    Args:
        path: CSV file to read; defaults to ``BEERS_CRITERIA_CSV``.

    Returns:
        One ``BeersCriterion`` per data row, in file order.

    Raises:
        KeyError: If a required column is missing.
    """
    criteria: List[BeersCriterion] = []
    for record in _read_csv(path):
        # A blank or whitespace-only condition cell is stored as None.
        condition_text = record["condition"].strip()
        criteria.append(
            BeersCriterion(
                drug_id=record["drug_id"],
                criterion_type=record["criterion_type"],
                condition=condition_text if condition_text else None,
                rationale=record["rationale"],
            )
        )
    return criteria
def load_patients(
    path: Path = PATIENTS_CSV,
    difficulty: Optional[str] = None,
) -> List[PatientEpisode]:
    """Load patient episodes from CSV, optionally keeping one difficulty.

    This loader is not cached: every call re-reads the file and returns
    freshly built ``PatientEpisode`` objects.

    Args:
        path: CSV file to read; defaults to ``PATIENTS_CSV``.
        difficulty: When truthy, keep only rows whose difficulty equals
            this value. ``None`` or ``""`` keeps every row.

    Returns:
        The matching episodes, in file order.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If ``age`` or ``baseline_risk_score`` is not numeric.
    """
    episodes: List[PatientEpisode] = []
    for row in _read_csv(path):
        # csv.DictReader yields None for cells missing from a short row and
        # "" for blank cells, so a plain .get(..., "medium") default would
        # only apply when the whole column is absent. Treat every
        # missing/blank form as "medium".
        level = (row.get("difficulty") or "").strip() or "medium"
        if difficulty and level != difficulty:
            continue
        episodes.append(PatientEpisode(
            episode_id=row["episode_id"],
            age=int(row["age"]),
            sex=row["sex"],
            # ";"-separated lists; empty items are dropped.
            conditions=[c.strip() for c in row["conditions"].split(";") if c.strip()],
            eGFR_category=row["eGFR_category"],
            liver_function_category=row["liver_function_category"],
            medication_ids=[m.strip() for m in row["medication_ids"].split(";") if m.strip()],
            baseline_risk_score=float(row["baseline_risk_score"]),
            difficulty=level,
        ))
    return episodes
|