# Source: polypharmacy-env/scripts/preprocess_data.py
# Last commit: "fix: monorepo" (e543908), author adithya9903
"""Synthetic data generator for the PolypharmacyEnv.
Generates:
- data/lookups/drug_metadata.csv
- data/lookups/ddi_rules.csv
- data/lookups/beers_criteria.csv
- data/processed/patients_polypharmacy.csv
"""
from __future__ import annotations
import csv
import random
import sys
from itertools import combinations
from pathlib import Path
# Project root: parents[1] of the resolved script path, i.e. the directory
# that contains this script's own directory.
ROOT = Path(__file__).resolve().parents[1]
# Destination directory for the three lookup tables written by main()
# (drug_metadata.csv, ddi_rules.csv, beers_criteria.csv).
LOOKUPS = ROOT / "data" / "lookups"
# Destination directory for the generated patient episodes
# (patients_polypharmacy.csv).
PROCESSED = ROOT / "data" / "processed"
# ── Drug catalogue ───────────────────────────────────────────────────────────
# One tuple per drug, in the same column order as the drug_metadata.csv header:
#   drug_id, generic_name, atc_class, is_high_risk_elderly (0/1),
#   default_dose_mg, min_dose_mg, max_dose_mg
# atc_class is the 5-character ATC (chemical subgroup) code. Celecoxib is a
# coxib (M01AH) and losartan is an angiotensin II receptor blocker (C09CA);
# they must not share the codes of the propionic-acid NSAIDs (M01AE) or the
# ACE inhibitors (C09AA).
DRUGS = [
    ("DRUG_WARFARIN", "warfarin", "B01AA", 1, 5.0, 1.0, 10.0),
    ("DRUG_APIXABAN", "apixaban", "B01AF", 1, 5.0, 2.5, 10.0),
    ("DRUG_METFORMIN", "metformin", "A10BA", 0, 1000, 500, 2000),
    ("DRUG_GLIPIZIDE", "glipizide", "A10BB", 1, 5.0, 2.5, 20.0),
    ("DRUG_LISINOPRIL", "lisinopril", "C09AA", 0, 10.0, 2.5, 40.0),
    ("DRUG_AMLODIPINE", "amlodipine", "C08CA", 0, 5.0, 2.5, 10.0),
    ("DRUG_METOPROLOL", "metoprolol", "C07AB", 0, 50.0, 25.0, 200.0),
    ("DRUG_DIGOXIN", "digoxin", "C01AA", 1, 0.25, 0.0625, 0.5),
    ("DRUG_FUROSEMIDE", "furosemide", "C03CA", 0, 40.0, 20.0, 160.0),
    ("DRUG_SPIRONOLACTONE", "spironolactone", "C03DA", 0, 25.0, 12.5, 50.0),
    ("DRUG_ATORVASTATIN", "atorvastatin", "C10AA", 0, 20.0, 10.0, 80.0),
    ("DRUG_SIMVASTATIN", "simvastatin", "C10AA", 0, 20.0, 10.0, 40.0),
    ("DRUG_OMEPRAZOLE", "omeprazole", "A02BC", 0, 20.0, 10.0, 40.0),
    ("DRUG_DIAZEPAM", "diazepam", "N05BA", 1, 5.0, 2.0, 10.0),
    ("DRUG_ALPRAZOLAM", "alprazolam", "N05BA", 1, 0.5, 0.25, 2.0),
    ("DRUG_AMITRIPTYLINE", "amitriptyline", "N06AA", 1, 25.0, 10.0, 75.0),
    ("DRUG_INSULIN_GLARGINE", "insulin glargine", "A10AE", 1, 20.0, 10.0, 60.0),
    ("DRUG_PREDNISONE", "prednisone", "H02AB", 0, 10.0, 5.0, 60.0),
    ("DRUG_NAPROXEN", "naproxen", "M01AE", 1, 500, 250, 1000),
    ("DRUG_IBUPROFEN", "ibuprofen", "M01AE", 1, 400, 200, 800),
    ("DRUG_CLOPIDOGREL", "clopidogrel", "B01AC", 0, 75.0, 75.0, 75.0),
    ("DRUG_ASPIRIN", "aspirin", "B01AC", 0, 81.0, 81.0, 325.0),
    ("DRUG_HYDROCHLOROTHIAZIDE", "HCTZ", "C03AA", 0, 25.0, 12.5, 50.0),
    ("DRUG_DONEPEZIL", "donepezil", "N06DA", 0, 5.0, 5.0, 10.0),
    ("DRUG_GABAPENTIN", "gabapentin", "N03AX", 0, 300, 100, 1200),
    ("DRUG_TRAMADOL", "tramadol", "N02AX", 1, 50.0, 25.0, 200.0),
    ("DRUG_FLUOXETINE", "fluoxetine", "N06AB", 0, 20.0, 10.0, 60.0),
    ("DRUG_SERTRALINE", "sertraline", "N06AB", 0, 50.0, 25.0, 200.0),
    ("DRUG_CIPROFLOXACIN", "ciprofloxacin", "J01MA", 0, 500, 250, 750),
    ("DRUG_TAMSULOSIN", "tamsulosin", "G04CA", 0, 0.4, 0.4, 0.8),
    ("DRUG_CELECOXIB", "celecoxib", "M01AH", 0, 200, 100, 400),
    ("DRUG_NORTRIPTYLINE", "nortriptyline", "N06AA", 0, 25.0, 10.0, 75.0),
    ("DRUG_LOSARTAN", "losartan", "C09CA", 0, 50.0, 25.0, 100.0),
]
# ── DDI rules ────────────────────────────────────────────────────────────────
# Drug-drug interaction rules. Each tuple is
#   (drug_id_1, drug_id_2, severity, mechanism, recommendation, base_risk_score).
# The two ids may appear in either order here; _gen_ddi_rules() sorts each pair
# with _normalise_pair() before writing it to ddi_rules.csv.
# Severity values used below: "severe", "moderate", "mild".
# NOTE(review): DRUG_AMIODARONE (digoxin rule below) has no entry in DRUGS, so
# this rule references a drug that is absent from drug_metadata.csv — confirm
# whether amiodarone should be added to the catalogue or the rule dropped.
DDI_PAIRS: list[tuple[str, str, str, str, str, float]] = [
    # id1, id2, severity, mechanism, recommendation, base_risk_score
    ("DRUG_WARFARIN", "DRUG_NAPROXEN", "severe", "Increased bleeding risk – NSAID inhibits platelet + anticoagulant", "avoid_combination", 0.90),
    ("DRUG_WARFARIN", "DRUG_IBUPROFEN", "severe", "Increased bleeding risk – NSAID + anticoagulant synergy", "avoid_combination", 0.88),
    ("DRUG_WARFARIN", "DRUG_ASPIRIN", "moderate", "Additive antiplatelet + anticoagulant bleeding risk", "monitor_closely", 0.55),
    ("DRUG_WARFARIN", "DRUG_FLUOXETINE", "moderate", "SSRI increases serotonin and may potentiate bleeding", "monitor_closely", 0.45),
    ("DRUG_WARFARIN", "DRUG_CIPROFLOXACIN","moderate","CYP1A2 inhibition raises warfarin levels", "dose_adjust", 0.50),
    ("DRUG_APIXABAN", "DRUG_NAPROXEN", "severe", "DOAC + NSAID – high bleeding risk", "avoid_combination", 0.85),
    ("DRUG_APIXABAN", "DRUG_ASPIRIN", "moderate", "Additive bleeding risk with antiplatelet", "monitor_closely", 0.50),
    ("DRUG_DIGOXIN", "DRUG_AMIODARONE", "severe", "Amiodarone increases digoxin levels – toxicity risk", "dose_adjust", 0.80),
    ("DRUG_DIGOXIN", "DRUG_SPIRONOLACTONE","moderate","Spironolactone may raise digoxin levels", "monitor_closely", 0.40),
    ("DRUG_METFORMIN", "DRUG_CIPROFLOXACIN","moderate","Fluoroquinolone may cause dysglycemia with metformin", "monitor_closely", 0.35),
    ("DRUG_DIAZEPAM", "DRUG_TRAMADOL", "severe", "CNS depression – benzodiazepine + opioid", "avoid_combination", 0.92),
    ("DRUG_ALPRAZOLAM", "DRUG_TRAMADOL", "severe", "CNS depression – benzodiazepine + opioid", "avoid_combination", 0.91),
    ("DRUG_LISINOPRIL", "DRUG_SPIRONOLACTONE","moderate","Hyperkalemia risk – ACE-I + K-sparing diuretic", "monitor_closely", 0.48),
    ("DRUG_LISINOPRIL", "DRUG_NAPROXEN", "moderate", "NSAID reduces ACE-I efficacy, renal risk", "monitor_closely", 0.42),
    ("DRUG_SIMVASTATIN","DRUG_AMLODIPINE", "moderate", "CYP3A4 interaction increases statin exposure", "dose_adjust", 0.38),
    ("DRUG_ATORVASTATIN","DRUG_CIPROFLOXACIN","mild", "Minor CYP interaction raising statin levels", "no_action", 0.15),
    ("DRUG_CLOPIDOGREL","DRUG_OMEPRAZOLE", "moderate", "PPI reduces clopidogrel activation via CYP2C19", "dose_adjust", 0.45),
    ("DRUG_INSULIN_GLARGINE","DRUG_GLIPIZIDE","moderate","Additive hypoglycemia risk", "monitor_closely", 0.50),
    ("DRUG_FLUOXETINE", "DRUG_TRAMADOL", "severe", "Serotonin syndrome risk – SSRI + serotonergic opioid", "avoid_combination", 0.82),
    ("DRUG_AMITRIPTYLINE","DRUG_TRAMADOL", "severe", "Serotonin syndrome + CNS depression", "avoid_combination", 0.85),
    ("DRUG_METOPROLOL", "DRUG_DIGOXIN", "moderate", "Additive bradycardia", "monitor_closely", 0.40),
    ("DRUG_FUROSEMIDE", "DRUG_DIGOXIN", "moderate", "Loop diuretic causes hypokalemia increasing digoxin toxicity risk", "monitor_closely", 0.45),
    ("DRUG_PREDNISONE", "DRUG_NAPROXEN", "moderate", "GI bleeding risk – corticosteroid + NSAID", "monitor_closely", 0.50),
    ("DRUG_PREDNISONE", "DRUG_WARFARIN", "mild", "Corticosteroid may alter INR", "monitor_closely", 0.25),
]
# ── Beers criteria ───────────────────────────────────────────────────────────
# Potentially-inappropriate-medication entries written to beers_criteria.csv.
# Each tuple is (drug_id, criterion_type, condition, rationale).
# condition is None when the entry is not tied to a specific condition;
# _gen_beers() writes None as an empty CSV field.
# A drug may appear more than once (e.g. naproxen: "avoid" in CKD plus a
# general "caution" entry).
BEERS_ENTRIES: list[tuple[str, str, str | None, str]] = [
    # drug_id, criterion_type, condition, rationale
    ("DRUG_DIAZEPAM", "avoid", None, "Long-acting benzodiazepine: falls, fractures, cognitive impairment in elderly"),
    ("DRUG_ALPRAZOLAM", "avoid", None, "Benzodiazepine: falls, fractures, cognitive impairment in elderly"),
    ("DRUG_AMITRIPTYLINE", "avoid", None, "Strongly anticholinergic TCA: sedation, confusion, urinary retention in elderly"),
    ("DRUG_GLIPIZIDE", "caution", None, "Sulfonylurea: hypoglycemia risk higher in elderly"),
    ("DRUG_NAPROXEN", "avoid", "CKD", "NSAID contraindicated in CKD – renal deterioration, fluid retention"),
    ("DRUG_IBUPROFEN", "avoid", "CKD", "NSAID contraindicated in CKD – renal deterioration, fluid retention"),
    ("DRUG_NAPROXEN", "caution", None, "NSAID: GI bleeding and renal risk in elderly"),
    ("DRUG_IBUPROFEN", "caution", None, "NSAID: GI bleeding and renal risk in elderly"),
    ("DRUG_DIGOXIN", "dose_adjust", None, "Avoid doses > 0.125 mg/day in elderly – toxicity risk"),
    ("DRUG_TRAMADOL", "avoid", None, "Opioid: CNS depression, falls, constipation in elderly"),
    ("DRUG_METFORMIN", "dose_adjust", "CKD", "Reduce dose or avoid if eGFR < 30 – lactic acidosis risk"),
    ("DRUG_INSULIN_GLARGINE","caution", None, "Tight glycemic control increases hypoglycemia risk in elderly"),
    ("DRUG_PREDNISONE", "avoid_in_condition", "DM", "Corticosteroid worsens glycemic control in diabetes"),
    ("DRUG_DONEPEZIL", "avoid_in_condition", "dementia", "Limited benefit, GI side effects; reassess regularly"),
    ("DRUG_CIPROFLOXACIN", "caution", None, "Fluoroquinolone: tendon rupture, QT prolongation risk in elderly"),
]
# ── Conditions pool & constraints ────────────────────────────────────────────
# Condition labels that _gen_patients() samples from (without replacement)
# when building an episode.
ALL_CONDITIONS = ["HTN", "DM", "HF", "CKD", "AF", "COPD", "OA", "depression", "dementia", "GERD", "BPH", "neuropathy"]
# Category labels for renal and hepatic function. _gen_patients() draws one of
# each per episode using per-difficulty weights given in the same order as
# these lists.
EGFR_CATS = ["normal", "mild", "moderate", "severe"]
LIVER_CATS = ["normal", "impaired"]
# Drugs that make clinical sense per condition
# _gen_patients() builds each patient's candidate drug pool from the entries of
# the sampled conditions. Catalogue drugs that appear in no list here (e.g.
# DRUG_ATORVASTATIN) can still be prescribed, but only through the random
# padding step that tops a medication list up to its target size.
CONDITION_DRUG_MAP: dict[str, list[str]] = {
    "HTN": ["DRUG_LISINOPRIL", "DRUG_AMLODIPINE", "DRUG_METOPROLOL", "DRUG_HYDROCHLOROTHIAZIDE", "DRUG_FUROSEMIDE"],
    "DM": ["DRUG_METFORMIN", "DRUG_GLIPIZIDE", "DRUG_INSULIN_GLARGINE"],
    "HF": ["DRUG_FUROSEMIDE", "DRUG_SPIRONOLACTONE", "DRUG_METOPROLOL", "DRUG_LISINOPRIL", "DRUG_DIGOXIN"],
    "CKD": ["DRUG_FUROSEMIDE", "DRUG_AMLODIPINE"],
    "AF": ["DRUG_WARFARIN", "DRUG_APIXABAN", "DRUG_METOPROLOL", "DRUG_DIGOXIN"],
    "COPD": ["DRUG_PREDNISONE"],
    "OA": ["DRUG_NAPROXEN", "DRUG_IBUPROFEN", "DRUG_TRAMADOL", "DRUG_GABAPENTIN"],
    "depression": ["DRUG_FLUOXETINE", "DRUG_SERTRALINE", "DRUG_AMITRIPTYLINE"],
    "dementia": ["DRUG_DONEPEZIL"],
    "GERD": ["DRUG_OMEPRAZOLE"],
    "BPH": ["DRUG_TAMSULOSIN"],
    "neuropathy": ["DRUG_GABAPENTIN", "DRUG_AMITRIPTYLINE"],
}
def _normalise_pair(a: str, b: str) -> tuple[str, str]:
return (a, b) if a < b else (b, a)
def _gen_drug_metadata(out: Path) -> None:
    """Write the drug catalogue (``DRUGS``) to *out* as a CSV file.

    The parent directory is created if it does not exist. The file gets a
    header row followed by one row per catalogue entry, in catalogue order.

    Args:
        out: Destination path of the CSV file; an existing file is overwritten.
    """
    header = [
        "drug_id", "generic_name", "atc_class", "is_high_risk_elderly",
        "default_dose_mg", "min_dose_mg", "max_dose_mg",
    ]
    out.parent.mkdir(parents=True, exist_ok=True)
    # Pin the encoding so the bytes written do not depend on the platform
    # locale; newline="" lets the csv module control line endings.
    with out.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        writer.writerows(DRUGS)
def _gen_ddi_rules(out: Path) -> None:
    """Write the drug-drug interaction rules (``DDI_PAIRS``) to *out* as CSV.

    The parent directory is created if it does not exist. Each rule's two drug
    ids are put into canonical order with ``_normalise_pair`` so that a pair
    is always stored the same way round.

    Args:
        out: Destination path of the CSV file; an existing file is overwritten.
    """
    header = [
        "drug_id_1", "drug_id_2", "severity", "mechanism",
        "recommendation", "base_risk_score",
    ]
    out.parent.mkdir(parents=True, exist_ok=True)
    # The mechanism texts contain non-ASCII characters (en dashes). Without an
    # explicit encoding the file's bytes would follow the platform locale and
    # the write could fail outright on an ASCII-only locale.
    with out.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(header)
        for rule in DDI_PAIRS:
            first, second = _normalise_pair(rule[0], rule[1])
            writer.writerow([first, second, rule[2], rule[3], rule[4], rule[5]])
def _gen_beers(out: Path) -> None:
    """Write the Beers-criteria entries (``BEERS_ENTRIES``) to *out* as CSV.

    The parent directory is created if it does not exist. An entry whose
    condition is ``None`` is written with an empty ``condition`` field.

    Args:
        out: Destination path of the CSV file; an existing file is overwritten.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    # The rationale texts contain non-ASCII characters (en dashes); pin the
    # encoding so the output does not depend on the platform locale.
    with out.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["drug_id", "criterion_type", "condition", "rationale"])
        for entry in BEERS_ENTRIES:
            writer.writerow([entry[0], entry[1], entry[2] or "", entry[3]])
def _gen_patients(out: Path, n_easy: int = 40, n_med: int = 40, n_hard: int = 40) -> None:
    """Generate synthetic patient episodes tagged by difficulty and write them to *out*.

    Episodes are produced in three tiers, in this order:

    * ``easy``   – 3-5 drugs, 1-3 conditions, at least one severe DDI pair.
    * ``medium`` – 6-10 drugs, 3-5 conditions.
    * ``hard``   – 10-15 drugs, 4-7 conditions, topped up with up to two of
      warfarin / insulin glargine / digoxin (never exceeding 15 drugs).

    A fixed RNG seed is used, and every random choice is made from an ordered
    sequence, so repeated runs produce the same file.

    Args:
        out: Destination path of the CSV file; an existing file is overwritten
            and the parent directory is created if needed.
        n_easy: Number of easy episodes.
        n_med: Number of medium episodes.
        n_hard: Number of hard episodes.
    """
    out.parent.mkdir(parents=True, exist_ok=True)
    rng = random.Random(42)
    drug_ids = [d[0] for d in DRUGS]
    catalogue = set(drug_ids)

    # Severe pairs kept in DDI_PAIRS order and limited to catalogue drugs.
    # rng.choice() needs an ordered sequence: choosing from list(some_set)
    # depends on string-hash iteration order and is not reproducible between
    # interpreter runs. The catalogue filter stops the forced-pair fallback
    # from prescribing a drug that has no metadata row.
    severe_choices: list[tuple[str, str]] = []
    for rule in DDI_PAIRS:
        if rule[2] != "severe":
            continue
        if rule[0] not in catalogue or rule[1] not in catalogue:
            continue
        pair = _normalise_pair(rule[0], rule[1])
        if pair not in severe_choices:
            severe_choices.append(pair)
    severe_pairs = set(severe_choices)  # O(1) membership for counting

    rows: list[list[str]] = []

    def _pick_conditions(n: int) -> list[str]:
        # Sample n distinct conditions (capped at the size of the pool).
        return rng.sample(ALL_CONDITIONS, min(n, len(ALL_CONDITIONS)))

    def _drugs_for_conditions(conds: list[str], target_n: int) -> list[str]:
        # Candidate pool: drugs indicated for the conditions, de-duplicated
        # while preserving first-seen order, then shuffled.
        pool: list[str] = []
        for c in conds:
            pool.extend(CONDITION_DRUG_MAP.get(c, []))
        pool = list(dict.fromkeys(pool))
        rng.shuffle(pool)
        selected = pool[:target_n]
        # Pad with random catalogue drugs if the conditions did not supply
        # enough candidates.
        remaining = [d for d in drug_ids if d not in selected]
        while len(selected) < target_n and remaining:
            pick = rng.choice(remaining)
            remaining.remove(pick)
            selected.append(pick)
        return selected

    def _count_severe(meds: list[str]) -> int:
        # Number of unordered medication pairs that form a severe DDI.
        count = 0
        for a, b in combinations(meds, 2):
            if _normalise_pair(a, b) in severe_pairs:
                count += 1
        return count

    def _baseline_risk(meds: list[str]) -> float:
        # Sum of base risk scores of every rule whose two drugs are both
        # present, averaged over the number of medications and capped at 1.0.
        present = set(meds)
        risk = 0.0
        for rule in DDI_PAIRS:
            if rule[0] in present and rule[1] in present:
                risk += rule[5]
        return min(risk / max(len(meds), 1), 1.0)

    def _add_episode(
        conds: list[str],
        meds: list[str],
        difficulty: str,
        age_range: tuple[int, int],
        egfr_weights: list[int],
        liver_weights: list[int],
    ) -> None:
        # Draw the demographic fields (age, sex, eGFR, liver — in that order)
        # and append the finished CSV row.
        age = rng.randint(age_range[0], age_range[1])
        sex = rng.choice(["M", "F"])
        egfr = rng.choices(EGFR_CATS, weights=egfr_weights)[0]
        liver = rng.choices(LIVER_CATS, weights=liver_weights)[0]
        br = round(_baseline_risk(meds), 4)
        rows.append([
            f"EP_{len(rows) + 1:04d}", str(age), sex, ";".join(conds),
            egfr, liver, ";".join(meds), str(br), difficulty,
        ])

    # Easy episodes: 3-5 drugs, at least one severe DDI
    for _ in range(n_easy):
        n_drugs = rng.randint(3, 5)
        conds = _pick_conditions(rng.randint(1, 3))
        # Retry the random draw until a severe DDI pair is present.
        for _attempt in range(50):
            meds = _drugs_for_conditions(conds, n_drugs)
            if _count_severe(meds) >= 1:
                break
        else:
            # No luck after 50 draws: force a known severe pair into the last
            # draw, keeping list order so the result is reproducible and the
            # episode still has exactly n_drugs medications.
            if severe_choices:
                first, second = rng.choice(severe_choices)
                kept = [m for m in meds if m not in (first, second)]
                meds = kept[: n_drugs - 2] + [first, second]
        _add_episode(conds, meds, "easy", (65, 90), [4, 3, 2, 1], [8, 2])

    # Medium episodes: 6-10 drugs, multiple DDIs
    for _ in range(n_med):
        n_drugs = rng.randint(6, 10)
        conds = _pick_conditions(rng.randint(3, 5))
        meds = _drugs_for_conditions(conds, n_drugs)
        _add_episode(conds, meds, "medium", (65, 92), [3, 3, 3, 1], [7, 3])

    # Hard episodes: 10-15 drugs, many issues, include critical drugs
    critical = ["DRUG_WARFARIN", "DRUG_INSULIN_GLARGINE", "DRUG_DIGOXIN"]
    for _ in range(n_hard):
        n_drugs = rng.randint(10, 15)
        conds = _pick_conditions(rng.randint(4, 7))
        meds = _drugs_for_conditions(conds, n_drugs)
        # Ensure some critical drugs are present, without exceeding 15 drugs.
        for cd in rng.sample(critical, min(2, len(critical))):
            if cd not in meds and len(meds) < 15:
                meds.append(cd)
        _add_episode(conds, meds, "hard", (70, 95), [2, 2, 3, 3], [6, 4])

    # Pin the encoding so the output does not depend on the platform locale.
    with out.open("w", newline="", encoding="utf-8") as fh:
        writer = csv.writer(fh)
        writer.writerow(["episode_id", "age", "sex", "conditions", "eGFR_category",
                         "liver_function_category", "medication_ids",
                         "baseline_risk_score", "difficulty"])
        writer.writerows(rows)
def main() -> None:
    """Regenerate every lookup table and the patient episode file.

    The files are produced one after another in a fixed order, with a progress
    message printed before each and a final "Done." once all are written.
    """
    steps = (
        ("Generating drug_metadata.csv …", _gen_drug_metadata, LOOKUPS / "drug_metadata.csv"),
        ("Generating ddi_rules.csv …", _gen_ddi_rules, LOOKUPS / "ddi_rules.csv"),
        ("Generating beers_criteria.csv …", _gen_beers, LOOKUPS / "beers_criteria.csv"),
        ("Generating patients_polypharmacy.csv …", _gen_patients, PROCESSED / "patients_polypharmacy.csv"),
    )
    for message, generate, target in steps:
        print(message)
        generate(target)
    print("Done.")


# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()