Spaces:
Sleeping
Sleeping
File size: 3,187 Bytes
"""Deterministic graders for the three PolypharmacyEnv task difficulties."""
from __future__ import annotations
from itertools import combinations
from typing import Dict, List, Tuple
from .data_loader import DDIRule
from .config import CRITICAL_DRUG_IDS
from .models import InterventionRecord
# Lower bound for risk denominators so a zero baseline never divides by zero.
_EPS = 1e-8


def _clip(x: float) -> float:
    """Clamp *x* into the closed unit interval [0, 1].

    Values strictly above 1 become 1.0. Values strictly above 0 (and not
    above 1) are returned unchanged. Everything else, including 0, negative
    numbers and NaN, becomes 0.0.
    """
    if x > 1.0:
        return 1.0
    if x > 0.0:
        return x
    # Reached for x <= 0 and for NaN (all comparisons with NaN are False).
    return 0.0
# ── Easy: easy_screening ─────────────────────────────────────────────────────
def grade_easy_screening(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    severe_ddi_drug_ids: List[Tuple[str, str]],
) -> float:
    """Return the easy-task score, clamped to [0, 1].

    Equal-weight (50 % / 50 %) blend of two components:

    * relative risk reduction: ``max(0, baseline_risk - final_risk)`` divided
      by ``baseline_risk``. The denominator is floored at ``_EPS`` so a zero
      baseline cannot divide by zero. A risk increase contributes nothing.
    * a targeting flag: 1.0 when at least one intervention's
      ``target_drug_id`` appears in any pair of ``severe_ddi_drug_ids``,
      otherwise 0.0.
    """
    improvement = max(0.0, baseline_risk - final_risk)
    relative_gain = improvement / max(baseline_risk, _EPS)

    # Every drug id that takes part in at least one severe interaction pair.
    implicated = {
        drug
        for first, second in severe_ddi_drug_ids
        for drug in (first, second)
    }

    # A single hit earns the full flag; any() stops at the first match.
    hit_severe_drug = any(
        record.target_drug_id in implicated for record in interventions
    )
    targeted = 1.0 if hit_severe_drug else 0.0

    return _clip(0.5 * relative_gain + 0.5 * targeted)
# ── Medium: budgeted_screening ───────────────────────────────────────────────
def grade_budgeted_screening(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    risk_deltas: List[float],
    num_queries: int,
    severe_moderate_discovered: int,
) -> float:
    """Return the medium-task score, clamped to [0, 1].

    Weighted blend of three components, each bounded to [0, 1]:

    * 50 % relative risk reduction: ``max(0, baseline_risk - final_risk)``
      divided by ``baseline_risk``, with the denominator floored at ``_EPS``.
    * 30 % intervention precision: the number of strictly positive entries in
      ``risk_deltas`` divided by ``len(interventions)``, capped at 1.0. It is
      0.0 when there are no interventions.
    * 20 % query efficiency: ``severe_moderate_discovered / num_queries``,
      capped at 1.0. It is 0.0 when no queries were made.
    """
    risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS)

    # Intervention precision: fraction of interventions that reduced risk.
    # NOTE(review): a positive delta is taken to mean "risk went down", and
    # risk_deltas is presumably parallel to interventions. Confirm with the
    # caller.
    if interventions:
        good = sum(1 for d in risk_deltas if d > 0)
        # The numerator counts risk_deltas but the denominator counts
        # interventions. If there are ever more deltas than interventions,
        # the raw ratio exceeds 1 and the "30 %" term could outweigh its
        # stated share. Cap it the same way query_eff is capped below.
        precision = min(good / len(interventions), 1.0)
    else:
        precision = 0.0

    # Query efficiency: share of queries that surfaced a severe/moderate DDI.
    if num_queries > 0:
        query_eff = min(severe_moderate_discovered / num_queries, 1.0)
    else:
        query_eff = 0.0

    return _clip(0.5 * risk_reduction + 0.3 * precision + 0.2 * query_eff)
# ── Hard: complex_tradeoff ───────────────────────────────────────────────────
def grade_complex_tradeoff(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    total_drug_changes: int,
    critical_drugs_stopped_without_sub: int,
) -> float:
    """Return the hard-task score, clamped to [0, 1].

    The score starts from relative risk reduction,
    ``max(0, baseline_risk - final_risk)`` divided by ``baseline_risk``, with
    the denominator floored at ``_EPS``. Two linear penalties are then
    subtracted:

    * 0.05 for every drug change counted in ``total_drug_changes``;
    * 0.20 for every critical drug stopped without a substitute.

    ``interventions`` is accepted for signature parity with the other graders
    but is not read here.
    """
    improvement = max(0.0, baseline_risk - final_risk)
    relative_gain = improvement / max(baseline_risk, _EPS)

    # Each change costs the same flat amount. The penalty is linear in the
    # count, not restricted to changes beyond some threshold.
    change_cost = 0.05 * total_drug_changes
    critical_cost = 0.20 * critical_drugs_stopped_without_sub

    return _clip(relative_gain - change_cost - critical_cost)
|