Spaces:
Sleeping
Sleeping
| """Synthetic data generator for the PolypharmacyEnv. | |
| Generates: | |
| - data/lookups/drug_metadata.csv | |
| - data/lookups/ddi_rules.csv | |
| - data/lookups/beers_criteria.csv | |
| - data/processed/patients_polypharmacy.csv | |
| """ | |
| from __future__ import annotations | |
| import csv | |
| import random | |
| import sys | |
| from itertools import combinations | |
| from pathlib import Path | |
| ROOT = Path(__file__).resolve().parents[1] | |
| LOOKUPS = ROOT / "data" / "lookups" | |
| PROCESSED = ROOT / "data" / "processed" | |
| # ββ Drug catalogue βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| DRUGS = [ | |
| # drug_id, generic_name, atc_class, high_risk, default, min, max | |
| ("DRUG_WARFARIN", "warfarin", "B01AA", 1, 5.0, 1.0, 10.0), | |
| ("DRUG_APIXABAN", "apixaban", "B01AF", 1, 5.0, 2.5, 10.0), | |
| ("DRUG_METFORMIN", "metformin", "A10BA", 0, 1000, 500, 2000), | |
| ("DRUG_GLIPIZIDE", "glipizide", "A10BB", 1, 5.0, 2.5, 20.0), | |
| ("DRUG_LISINOPRIL", "lisinopril", "C09AA", 0, 10.0, 2.5, 40.0), | |
| ("DRUG_AMLODIPINE", "amlodipine", "C08CA", 0, 5.0, 2.5, 10.0), | |
| ("DRUG_METOPROLOL", "metoprolol", "C07AB", 0, 50.0, 25.0,200.0), | |
| ("DRUG_DIGOXIN", "digoxin", "C01AA", 1, 0.25, 0.0625,0.5), | |
| ("DRUG_FUROSEMIDE", "furosemide", "C03CA", 0, 40.0, 20.0,160.0), | |
| ("DRUG_SPIRONOLACTONE", "spironolactone", "C03DA", 0, 25.0, 12.5, 50.0), | |
| ("DRUG_ATORVASTATIN", "atorvastatin", "C10AA", 0, 20.0, 10.0, 80.0), | |
| ("DRUG_SIMVASTATIN", "simvastatin", "C10AA", 0, 20.0, 10.0, 40.0), | |
| ("DRUG_OMEPRAZOLE", "omeprazole", "A02BC", 0, 20.0, 10.0, 40.0), | |
| ("DRUG_DIAZEPAM", "diazepam", "N05BA", 1, 5.0, 2.0, 10.0), | |
| ("DRUG_ALPRAZOLAM", "alprazolam", "N05BA", 1, 0.5, 0.25, 2.0), | |
| ("DRUG_AMITRIPTYLINE", "amitriptyline", "N06AA", 1, 25.0, 10.0, 75.0), | |
| ("DRUG_INSULIN_GLARGINE","insulin glargine", "A10AE", 1, 20.0, 10.0, 60.0), | |
| ("DRUG_PREDNISONE", "prednisone", "H02AB", 0, 10.0, 5.0, 60.0), | |
| ("DRUG_NAPROXEN", "naproxen", "M01AE", 1, 500, 250, 1000), | |
| ("DRUG_IBUPROFEN", "ibuprofen", "M01AE", 1, 400, 200, 800), | |
| ("DRUG_CLOPIDOGREL", "clopidogrel", "B01AC", 0, 75.0, 75.0, 75.0), | |
| ("DRUG_ASPIRIN", "aspirin", "B01AC", 0, 81.0, 81.0, 325.0), | |
| ("DRUG_HYDROCHLOROTHIAZIDE","HCTZ", "C03AA", 0, 25.0, 12.5, 50.0), | |
| ("DRUG_DONEPEZIL", "donepezil", "N06DA", 0, 5.0, 5.0, 10.0), | |
| ("DRUG_GABAPENTIN", "gabapentin", "N03AX", 0, 300, 100, 1200), | |
| ("DRUG_TRAMADOL", "tramadol", "N02AX", 1, 50.0, 25.0, 200.0), | |
| ("DRUG_FLUOXETINE", "fluoxetine", "N06AB", 0, 20.0, 10.0, 60.0), | |
| ("DRUG_SERTRALINE", "sertraline", "N06AB", 0, 50.0, 25.0, 200.0), | |
| ("DRUG_CIPROFLOXACIN", "ciprofloxacin", "J01MA", 0, 500, 250, 750), | |
| ("DRUG_TAMSULOSIN", "tamsulosin", "G04CA", 0, 0.4, 0.4, 0.8), | |
| ("DRUG_CELECOXIB", "celecoxib", "M01AE", 0, 200, 100, 400), | |
| ("DRUG_NORTRIPTYLINE", "nortriptyline", "N06AA", 0, 25.0, 10.0, 75.0), | |
| ("DRUG_LOSARTAN", "losartan", "C09AA", 0, 50.0, 25.0, 100.0), | |
| ] | |
| # ββ DDI rules ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| DDI_PAIRS: list[tuple[str, str, str, str, str, float]] = [ | |
| # id1, id2, severity, mechanism, recommendation, base_risk_score | |
| ("DRUG_WARFARIN", "DRUG_NAPROXEN", "severe", "Increased bleeding risk β NSAID inhibits platelet + anticoagulant", "avoid_combination", 0.90), | |
| ("DRUG_WARFARIN", "DRUG_IBUPROFEN", "severe", "Increased bleeding risk β NSAID + anticoagulant synergy", "avoid_combination", 0.88), | |
| ("DRUG_WARFARIN", "DRUG_ASPIRIN", "moderate", "Additive antiplatelet + anticoagulant bleeding risk", "monitor_closely", 0.55), | |
| ("DRUG_WARFARIN", "DRUG_FLUOXETINE", "moderate", "SSRI increases serotonin and may potentiate bleeding", "monitor_closely", 0.45), | |
| ("DRUG_WARFARIN", "DRUG_CIPROFLOXACIN","moderate","CYP1A2 inhibition raises warfarin levels", "dose_adjust", 0.50), | |
| ("DRUG_APIXABAN", "DRUG_NAPROXEN", "severe", "DOAC + NSAID β high bleeding risk", "avoid_combination", 0.85), | |
| ("DRUG_APIXABAN", "DRUG_ASPIRIN", "moderate", "Additive bleeding risk with antiplatelet", "monitor_closely", 0.50), | |
| ("DRUG_DIGOXIN", "DRUG_AMIODARONE", "severe", "Amiodarone increases digoxin levels β toxicity risk", "dose_adjust", 0.80), | |
| ("DRUG_DIGOXIN", "DRUG_SPIRONOLACTONE","moderate","Spironolactone may raise digoxin levels", "monitor_closely", 0.40), | |
| ("DRUG_METFORMIN", "DRUG_CIPROFLOXACIN","moderate","Fluoroquinolone may cause dysglycemia with metformin", "monitor_closely", 0.35), | |
| ("DRUG_DIAZEPAM", "DRUG_TRAMADOL", "severe", "CNS depression β benzodiazepine + opioid", "avoid_combination", 0.92), | |
| ("DRUG_ALPRAZOLAM", "DRUG_TRAMADOL", "severe", "CNS depression β benzodiazepine + opioid", "avoid_combination", 0.91), | |
| ("DRUG_LISINOPRIL", "DRUG_SPIRONOLACTONE","moderate","Hyperkalemia risk β ACE-I + K-sparing diuretic", "monitor_closely", 0.48), | |
| ("DRUG_LISINOPRIL", "DRUG_NAPROXEN", "moderate", "NSAID reduces ACE-I efficacy, renal risk", "monitor_closely", 0.42), | |
| ("DRUG_SIMVASTATIN","DRUG_AMLODIPINE", "moderate", "CYP3A4 interaction increases statin exposure", "dose_adjust", 0.38), | |
| ("DRUG_ATORVASTATIN","DRUG_CIPROFLOXACIN","mild", "Minor CYP interaction raising statin levels", "no_action", 0.15), | |
| ("DRUG_CLOPIDOGREL","DRUG_OMEPRAZOLE", "moderate", "PPI reduces clopidogrel activation via CYP2C19", "dose_adjust", 0.45), | |
| ("DRUG_INSULIN_GLARGINE","DRUG_GLIPIZIDE","moderate","Additive hypoglycemia risk", "monitor_closely", 0.50), | |
| ("DRUG_FLUOXETINE", "DRUG_TRAMADOL", "severe", "Serotonin syndrome risk β SSRI + serotonergic opioid", "avoid_combination", 0.82), | |
| ("DRUG_AMITRIPTYLINE","DRUG_TRAMADOL", "severe", "Serotonin syndrome + CNS depression", "avoid_combination", 0.85), | |
| ("DRUG_METOPROLOL", "DRUG_DIGOXIN", "moderate", "Additive bradycardia", "monitor_closely", 0.40), | |
| ("DRUG_FUROSEMIDE", "DRUG_DIGOXIN", "moderate", "Loop diuretic causes hypokalemia increasing digoxin toxicity risk", "monitor_closely", 0.45), | |
| ("DRUG_PREDNISONE", "DRUG_NAPROXEN", "moderate", "GI bleeding risk β corticosteroid + NSAID", "monitor_closely", 0.50), | |
| ("DRUG_PREDNISONE", "DRUG_WARFARIN", "mild", "Corticosteroid may alter INR", "monitor_closely", 0.25), | |
| ] | |
| # ββ Beers criteria βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| BEERS_ENTRIES: list[tuple[str, str, str | None, str]] = [ | |
| # drug_id, criterion_type, condition, rationale | |
| ("DRUG_DIAZEPAM", "avoid", None, "Long-acting benzodiazepine: falls, fractures, cognitive impairment in elderly"), | |
| ("DRUG_ALPRAZOLAM", "avoid", None, "Benzodiazepine: falls, fractures, cognitive impairment in elderly"), | |
| ("DRUG_AMITRIPTYLINE", "avoid", None, "Strongly anticholinergic TCA: sedation, confusion, urinary retention in elderly"), | |
| ("DRUG_GLIPIZIDE", "caution", None, "Sulfonylurea: hypoglycemia risk higher in elderly"), | |
| ("DRUG_NAPROXEN", "avoid", "CKD", "NSAID contraindicated in CKD β renal deterioration, fluid retention"), | |
| ("DRUG_IBUPROFEN", "avoid", "CKD", "NSAID contraindicated in CKD β renal deterioration, fluid retention"), | |
| ("DRUG_NAPROXEN", "caution", None, "NSAID: GI bleeding and renal risk in elderly"), | |
| ("DRUG_IBUPROFEN", "caution", None, "NSAID: GI bleeding and renal risk in elderly"), | |
| ("DRUG_DIGOXIN", "dose_adjust", None, "Avoid doses > 0.125 mg/day in elderly β toxicity risk"), | |
| ("DRUG_TRAMADOL", "avoid", None, "Opioid: CNS depression, falls, constipation in elderly"), | |
| ("DRUG_METFORMIN", "dose_adjust", "CKD", "Reduce dose or avoid if eGFR < 30 β lactic acidosis risk"), | |
| ("DRUG_INSULIN_GLARGINE","caution", None, "Tight glycemic control increases hypoglycemia risk in elderly"), | |
| ("DRUG_PREDNISONE", "avoid_in_condition", "DM", "Corticosteroid worsens glycemic control in diabetes"), | |
| ("DRUG_DONEPEZIL", "avoid_in_condition", "dementia", "Limited benefit, GI side effects; reassess regularly"), | |
| ("DRUG_CIPROFLOXACIN", "caution", None, "Fluoroquinolone: tendon rupture, QT prolongation risk in elderly"), | |
| ] | |
| # ββ Conditions pool & constraints ββββββββββββββββββββββββββββββββββββββββββββ | |
| ALL_CONDITIONS = ["HTN", "DM", "HF", "CKD", "AF", "COPD", "OA", "depression", "dementia", "GERD", "BPH", "neuropathy"] | |
| EGFR_CATS = ["normal", "mild", "moderate", "severe"] | |
| LIVER_CATS = ["normal", "impaired"] | |
| # Drugs that make clinical sense per condition | |
| CONDITION_DRUG_MAP: dict[str, list[str]] = { | |
| "HTN": ["DRUG_LISINOPRIL", "DRUG_AMLODIPINE", "DRUG_METOPROLOL", "DRUG_HYDROCHLOROTHIAZIDE", "DRUG_FUROSEMIDE"], | |
| "DM": ["DRUG_METFORMIN", "DRUG_GLIPIZIDE", "DRUG_INSULIN_GLARGINE"], | |
| "HF": ["DRUG_FUROSEMIDE", "DRUG_SPIRONOLACTONE", "DRUG_METOPROLOL", "DRUG_LISINOPRIL", "DRUG_DIGOXIN"], | |
| "CKD": ["DRUG_FUROSEMIDE", "DRUG_AMLODIPINE"], | |
| "AF": ["DRUG_WARFARIN", "DRUG_APIXABAN", "DRUG_METOPROLOL", "DRUG_DIGOXIN"], | |
| "COPD": ["DRUG_PREDNISONE"], | |
| "OA": ["DRUG_NAPROXEN", "DRUG_IBUPROFEN", "DRUG_TRAMADOL", "DRUG_GABAPENTIN"], | |
| "depression": ["DRUG_FLUOXETINE", "DRUG_SERTRALINE", "DRUG_AMITRIPTYLINE"], | |
| "dementia": ["DRUG_DONEPEZIL"], | |
| "GERD": ["DRUG_OMEPRAZOLE"], | |
| "BPH": ["DRUG_TAMSULOSIN"], | |
| "neuropathy": ["DRUG_GABAPENTIN", "DRUG_AMITRIPTYLINE"], | |
| } | |
| def _normalise_pair(a: str, b: str) -> tuple[str, str]: | |
| return (a, b) if a < b else (b, a) | |
| def _gen_drug_metadata(out: Path) -> None: | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| with open(out, "w", newline="") as f: | |
| w = csv.writer(f) | |
| w.writerow(["drug_id", "generic_name", "atc_class", "is_high_risk_elderly", | |
| "default_dose_mg", "min_dose_mg", "max_dose_mg"]) | |
| for row in DRUGS: | |
| w.writerow(row) | |
| def _gen_ddi_rules(out: Path) -> None: | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| with open(out, "w", newline="") as f: | |
| w = csv.writer(f) | |
| w.writerow(["drug_id_1", "drug_id_2", "severity", "mechanism", | |
| "recommendation", "base_risk_score"]) | |
| for pair in DDI_PAIRS: | |
| a, b = _normalise_pair(pair[0], pair[1]) | |
| w.writerow([a, b, pair[2], pair[3], pair[4], pair[5]]) | |
| def _gen_beers(out: Path) -> None: | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| with open(out, "w", newline="") as f: | |
| w = csv.writer(f) | |
| w.writerow(["drug_id", "criterion_type", "condition", "rationale"]) | |
| for row in BEERS_ENTRIES: | |
| w.writerow([row[0], row[1], row[2] or "", row[3]]) | |
| def _gen_patients(out: Path, n_easy: int = 40, n_med: int = 40, n_hard: int = 40) -> None: | |
| """Generate synthetic patient episodes tagged by difficulty.""" | |
| out.parent.mkdir(parents=True, exist_ok=True) | |
| rng = random.Random(42) | |
| drug_ids = [d[0] for d in DRUGS] | |
| # Build severity lookup for quick reference | |
| severe_pairs: set[tuple[str, str]] = set() | |
| for pair in DDI_PAIRS: | |
| if pair[2] == "severe": | |
| severe_pairs.add(_normalise_pair(pair[0], pair[1])) | |
| rows: list[list[str]] = [] | |
| ep_counter = 0 | |
| def _pick_conditions(n: int) -> list[str]: | |
| return rng.sample(ALL_CONDITIONS, min(n, len(ALL_CONDITIONS))) | |
| def _drugs_for_conditions(conds: list[str], target_n: int) -> list[str]: | |
| pool: list[str] = [] | |
| for c in conds: | |
| pool.extend(CONDITION_DRUG_MAP.get(c, [])) | |
| pool = list(dict.fromkeys(pool)) # deduplicate preserving order | |
| rng.shuffle(pool) | |
| selected = pool[:target_n] | |
| # Pad with random drugs if needed | |
| remaining = [d for d in drug_ids if d not in selected] | |
| while len(selected) < target_n and remaining: | |
| pick = rng.choice(remaining) | |
| remaining.remove(pick) | |
| selected.append(pick) | |
| return selected | |
| def _count_severe(meds: list[str]) -> int: | |
| count = 0 | |
| for a, b in combinations(meds, 2): | |
| if _normalise_pair(a, b) in severe_pairs: | |
| count += 1 | |
| return count | |
| def _baseline_risk(meds: list[str]) -> float: | |
| risk = 0.0 | |
| for pair in DDI_PAIRS: | |
| a, b = _normalise_pair(pair[0], pair[1]) | |
| if a in meds and b in meds: | |
| risk += pair[5] | |
| return min(risk / max(len(meds), 1), 1.0) | |
| # Easy episodes: 3-5 drugs, exactly 1 severe DDI | |
| for _ in range(n_easy): | |
| ep_counter += 1 | |
| n_drugs = rng.randint(3, 5) | |
| conds = _pick_conditions(rng.randint(1, 3)) | |
| # Ensure at least one severe DDI pair is present | |
| for attempt in range(50): | |
| meds = _drugs_for_conditions(conds, n_drugs) | |
| if _count_severe(meds) >= 1: | |
| break | |
| else: | |
| # Force a known severe pair | |
| sp = rng.choice(list(severe_pairs)) | |
| meds = list(set(meds[:n_drugs - 2]) | {sp[0], sp[1]})[:n_drugs] | |
| age = rng.randint(65, 90) | |
| sex = rng.choice(["M", "F"]) | |
| egfr = rng.choices(EGFR_CATS, weights=[4, 3, 2, 1])[0] | |
| liver = rng.choices(LIVER_CATS, weights=[8, 2])[0] | |
| br = round(_baseline_risk(meds), 4) | |
| rows.append([ | |
| f"EP_{ep_counter:04d}", str(age), sex, ";".join(conds), | |
| egfr, liver, ";".join(meds), str(br), "easy", | |
| ]) | |
| # Medium episodes: 6-10 drugs, multiple DDIs | |
| for _ in range(n_med): | |
| ep_counter += 1 | |
| n_drugs = rng.randint(6, 10) | |
| conds = _pick_conditions(rng.randint(3, 5)) | |
| meds = _drugs_for_conditions(conds, n_drugs) | |
| age = rng.randint(65, 92) | |
| sex = rng.choice(["M", "F"]) | |
| egfr = rng.choices(EGFR_CATS, weights=[3, 3, 3, 1])[0] | |
| liver = rng.choices(LIVER_CATS, weights=[7, 3])[0] | |
| br = round(_baseline_risk(meds), 4) | |
| rows.append([ | |
| f"EP_{ep_counter:04d}", str(age), sex, ";".join(conds), | |
| egfr, liver, ";".join(meds), str(br), "medium", | |
| ]) | |
| # Hard episodes: 10-15 drugs, many issues, include critical drugs | |
| for _ in range(n_hard): | |
| ep_counter += 1 | |
| n_drugs = rng.randint(10, 15) | |
| conds = _pick_conditions(rng.randint(4, 7)) | |
| meds = _drugs_for_conditions(conds, n_drugs) | |
| # Ensure some critical drugs are present | |
| critical = ["DRUG_WARFARIN", "DRUG_INSULIN_GLARGINE", "DRUG_DIGOXIN"] | |
| for cd in rng.sample(critical, min(2, len(critical))): | |
| if cd not in meds and len(meds) < 15: | |
| meds.append(cd) | |
| age = rng.randint(70, 95) | |
| sex = rng.choice(["M", "F"]) | |
| egfr = rng.choices(EGFR_CATS, weights=[2, 2, 3, 3])[0] | |
| liver = rng.choices(LIVER_CATS, weights=[6, 4])[0] | |
| br = round(_baseline_risk(meds), 4) | |
| rows.append([ | |
| f"EP_{ep_counter:04d}", str(age), sex, ";".join(conds), | |
| egfr, liver, ";".join(meds), str(br), "hard", | |
| ]) | |
| with open(out, "w", newline="") as f: | |
| w = csv.writer(f) | |
| w.writerow(["episode_id", "age", "sex", "conditions", "eGFR_category", | |
| "liver_function_category", "medication_ids", | |
| "baseline_risk_score", "difficulty"]) | |
| for r in rows: | |
| w.writerow(r) | |
| def main() -> None: | |
| print("Generating drug_metadata.csv β¦") | |
| _gen_drug_metadata(LOOKUPS / "drug_metadata.csv") | |
| print("Generating ddi_rules.csv β¦") | |
| _gen_ddi_rules(LOOKUPS / "ddi_rules.csv") | |
| print("Generating beers_criteria.csv β¦") | |
| _gen_beers(LOOKUPS / "beers_criteria.csv") | |
| print("Generating patients_polypharmacy.csv β¦") | |
| _gen_patients(PROCESSED / "patients_polypharmacy.csv") | |
| print("Done.") | |
| if __name__ == "__main__": | |
| main() | |