Spaces:
Sleeping
Sleeping
| """Deterministic graders for the three PolypharmacyEnv task difficulties.""" | |
| from __future__ import annotations | |
| from itertools import combinations | |
| from typing import Dict, List, Tuple | |
| from .data_loader import DDIRule | |
| from .config import CRITICAL_DRUG_IDS | |
| from .models import InterventionRecord | |
| _EPS = 1e-8 | |
| def _clip(x: float) -> float: | |
| return max(0.0, min(x, 1.0)) | |
| # ββ Easy: easy_screening βββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def grade_easy_screening( | |
| baseline_risk: float, | |
| final_risk: float, | |
| interventions: List[InterventionRecord], | |
| severe_ddi_drug_ids: List[Tuple[str, str]], | |
| ) -> float: | |
| """Score β [0, 1] for the easy task. | |
| 50 % risk reduction + 50 % targeted-intervention flag. | |
| """ | |
| risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS) | |
| targeted = 0.0 | |
| severe_drugs = set() | |
| for a, b in severe_ddi_drug_ids: | |
| severe_drugs.add(a) | |
| severe_drugs.add(b) | |
| for iv in interventions: | |
| if iv.target_drug_id in severe_drugs: | |
| targeted = 1.0 | |
| break | |
| return _clip(0.5 * risk_reduction + 0.5 * targeted) | |
| # ββ Medium: budgeted_screening βββββββββββββββββββββββββββββββββββββββββββββββ | |
| def grade_budgeted_screening( | |
| baseline_risk: float, | |
| final_risk: float, | |
| interventions: List[InterventionRecord], | |
| risk_deltas: List[float], | |
| num_queries: int, | |
| severe_moderate_discovered: int, | |
| ) -> float: | |
| """Score β [0, 1] for the medium task. | |
| 50 % risk reduction + 30 % intervention precision + 20 % query efficiency. | |
| """ | |
| risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS) | |
| # Intervention precision: fraction of interventions that reduced risk | |
| if interventions: | |
| good = sum(1 for d in risk_deltas if d > 0) | |
| precision = good / len(interventions) | |
| else: | |
| precision = 0.0 | |
| # Query efficiency | |
| if num_queries > 0: | |
| query_eff = min(severe_moderate_discovered / num_queries, 1.0) | |
| else: | |
| query_eff = 0.0 | |
| return _clip(0.5 * risk_reduction + 0.3 * precision + 0.2 * query_eff) | |
| # ββ Hard: complex_tradeoff βββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def grade_complex_tradeoff( | |
| baseline_risk: float, | |
| final_risk: float, | |
| interventions: List[InterventionRecord], | |
| total_drug_changes: int, | |
| critical_drugs_stopped_without_sub: int, | |
| ) -> float: | |
| """Score β [0, 1] for the hard task. | |
| Base = risk reduction; penalty for regimen disruption and critical-drug stops. | |
| """ | |
| risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS) | |
| # Regimen disruption: penalise excessive changes | |
| disruption = 0.05 * total_drug_changes | |
| critical_penalty = 0.20 * critical_drugs_stopped_without_sub | |
| return _clip(risk_reduction - disruption - critical_penalty) | |