adithya9903's picture
Flatten project to root for OpenEnv submission readiness.
fa51dd9
"""Deterministic graders for the three PolypharmacyEnv task difficulties."""
from __future__ import annotations
from itertools import combinations
from typing import Dict, List, Tuple
from .data_loader import DDIRule
from .config import CRITICAL_DRUG_IDS
from .models import InterventionRecord
_EPS = 1e-8
def _clip(x: float) -> float:
    """Clamp *x* into the closed interval [0.0, 1.0]."""
    # Apply the upper bound first, then the lower bound (same order as before).
    capped = min(x, 1.0)
    return max(0.0, capped)
# ── Easy: easy_screening ─────────────────────────────────────────────────────
def grade_easy_screening(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    severe_ddi_drug_ids: List[Tuple[str, str]],
) -> float:
    """Grade the easy task; the result is clipped to [0, 1].

    The score is an equal-weight blend of two components:

    * relative risk reduction: the drop from ``baseline_risk`` to
      ``final_risk`` (floored at 0) divided by ``baseline_risk`` (floored
      at ``_EPS`` so a zero baseline cannot divide by zero);
    * a targeted flag: 1.0 if at least one intervention's
      ``target_drug_id`` occurs in any pair of ``severe_ddi_drug_ids``,
      otherwise 0.0.
    """
    risk_drop = max(0.0, baseline_risk - final_risk)
    denominator = max(baseline_risk, _EPS)
    relative_reduction = risk_drop / denominator

    # Flatten the pairs into the set of every drug id that appears in one.
    flagged_ids = {
        drug_id
        for first, second in severe_ddi_drug_ids
        for drug_id in (first, second)
    }

    # any() stops at the first matching intervention.
    hit = any(record.target_drug_id in flagged_ids for record in interventions)
    targeted_flag = 1.0 if hit else 0.0

    return _clip(0.5 * relative_reduction + 0.5 * targeted_flag)
# ── Medium: budgeted_screening ───────────────────────────────────────────────
def grade_budgeted_screening(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    risk_deltas: List[float],
    num_queries: int,
    severe_moderate_discovered: int,
) -> float:
    """Score ∈ [0, 1] for the medium task.

    50 % risk reduction + 30 % intervention precision + 20 % query efficiency.

    * Risk reduction: the drop from ``baseline_risk`` to ``final_risk``
      (floored at 0) divided by ``baseline_risk`` (floored at ``_EPS``).
    * Intervention precision: the number of strictly positive entries in
      ``risk_deltas`` divided by ``len(interventions)``, capped at 1.0;
      0.0 when there are no interventions.
    * Query efficiency: ``severe_moderate_discovered / num_queries``,
      capped at 1.0; 0.0 when ``num_queries`` is not positive.
    """
    risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS)

    # Intervention precision: fraction of interventions that reduced risk.
    if interventions:
        good = sum(1 for d in risk_deltas if d > 0)
        # risk_deltas and interventions are independent arguments; if
        # risk_deltas is the longer one the raw ratio exceeds 1 and would
        # outweigh the component's 30 % share, so cap it like query_eff.
        precision = min(good / len(interventions), 1.0)
    else:
        precision = 0.0

    # Query efficiency: discoveries per query, capped at 1.0.
    if num_queries > 0:
        query_eff = min(severe_moderate_discovered / num_queries, 1.0)
    else:
        query_eff = 0.0

    return _clip(0.5 * risk_reduction + 0.3 * precision + 0.2 * query_eff)
# ── Hard: complex_tradeoff ───────────────────────────────────────────────────
def grade_complex_tradeoff(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    total_drug_changes: int,
    critical_drugs_stopped_without_sub: int,
) -> float:
    """Grade the hard task; the result is clipped to [0, 1].

    The base score is the relative risk reduction: the drop from
    ``baseline_risk`` to ``final_risk`` (floored at 0) divided by
    ``baseline_risk`` (floored at ``_EPS``). Two penalties are subtracted:

    * 0.05 for each unit of ``total_drug_changes`` (regimen disruption);
    * 0.20 for each unit of ``critical_drugs_stopped_without_sub``.

    ``interventions`` is part of the signature but is not read here.
    """
    # Per-unit penalty totals.
    change_cost = 0.05 * total_drug_changes
    critical_cost = 0.20 * critical_drugs_stopped_without_sub

    relative_reduction = (
        max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS)
    )

    # Subtract in the original order so the float result is bit-identical.
    score = relative_reduction - change_cost
    score = score - critical_cost
    return _clip(score)