File size: 3,187 Bytes
2043afa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""Deterministic graders for the three PolypharmacyEnv task difficulties."""

from __future__ import annotations

from itertools import combinations
from typing import Dict, List, Tuple

from .data_loader import DDIRule
from .config import CRITICAL_DRUG_IDS
from .models import InterventionRecord


# Lower bound for the baseline-risk denominator in the graders below, so a
# zero baseline yields a risk-reduction of 0 instead of a ZeroDivisionError.
_EPS = 1e-8


def _clip(x: float) -> float:
    return max(0.0, min(x, 1.0))


# ── Easy: easy_screening ─────────────────────────────────────────────────────

def grade_easy_screening(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    severe_ddi_drug_ids: List[Tuple[str, str]],
) -> float:
    """Grade the easy task; the result lies in [0, 1].

    The score is an equal-weight blend of two components:

    * relative risk reduction, ``(baseline - final) / baseline``, floored
      at 0 when the final risk is not lower than the baseline;
    * a binary flag that is 1.0 as soon as any intervention's
      ``target_drug_id`` appears in one of the severe interaction pairs.

    Args:
        baseline_risk: Risk before any intervention was applied.
        final_risk: Risk at the end of the episode.
        interventions: Interventions performed during the episode.
        severe_ddi_drug_ids: Pairs of drug ids involved in severe DDIs.

    Returns:
        ``0.5 * risk_reduction + 0.5 * targeted``, clipped to [0, 1].
    """
    achieved_drop = max(0.0, baseline_risk - final_risk)
    relative_drop = achieved_drop / max(baseline_risk, _EPS)

    # Flatten the pairs into the set of every drug involved in a severe DDI.
    flagged_ids = {
        drug_id
        for first, second in severe_ddi_drug_ids
        for drug_id in (first, second)
    }
    hit_severe = any(
        record.target_drug_id in flagged_ids for record in interventions
    )
    targeted_bonus = 1.0 if hit_severe else 0.0

    return _clip(0.5 * relative_drop + 0.5 * targeted_bonus)


# ── Medium: budgeted_screening ───────────────────────────────────────────────

def grade_budgeted_screening(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    risk_deltas: List[float],
    num_queries: int,
    severe_moderate_discovered: int,
) -> float:
    """Score ∈ [0, 1] for the medium task.

    50 % risk reduction + 30 % intervention precision + 20 % query efficiency.

    Args:
        baseline_risk: Risk before any intervention was applied.
        final_risk: Risk at the end of the episode.
        interventions: Interventions performed during the episode; only the
            count is used here.
        risk_deltas: Per-step risk changes; strictly positive entries are
            counted as risk-reducing interventions.
        num_queries: Number of queries issued during the episode.
        severe_moderate_discovered: Number of severe/moderate interactions
            found by those queries.

    Returns:
        The weighted sum of the three components, clipped to [0, 1]. Each
        component is itself at most 1.0, so none can exceed its stated weight.
    """
    risk_reduction = max(0.0, baseline_risk - final_risk) / max(baseline_risk, _EPS)

    # Intervention precision: fraction of interventions that reduced risk.
    # Capped at 1.0 because ``risk_deltas`` is not guaranteed to have the same
    # length as ``interventions``; without the cap, surplus positive deltas
    # would push this component past its 30 % weight.
    if interventions:
        good = sum(1 for d in risk_deltas if d > 0)
        precision = min(good / len(interventions), 1.0)
    else:
        precision = 0.0

    # Query efficiency: discoveries per query, capped at 1.0; no queries
    # means no efficiency credit.
    if num_queries > 0:
        query_eff = min(severe_moderate_discovered / num_queries, 1.0)
    else:
        query_eff = 0.0

    return _clip(0.5 * risk_reduction + 0.3 * precision + 0.2 * query_eff)


# ── Hard: complex_tradeoff ───────────────────────────────────────────────────

def grade_complex_tradeoff(
    baseline_risk: float,
    final_risk: float,
    interventions: List[InterventionRecord],
    total_drug_changes: int,
    critical_drugs_stopped_without_sub: int,
) -> float:
    """Grade the hard task; the result lies in [0, 1].

    Starts from the relative risk reduction and subtracts two penalties:
    0.05 per drug change (regimen disruption) and 0.20 per critical drug
    stopped without a substitute.

    Args:
        baseline_risk: Risk before any intervention was applied.
        final_risk: Risk at the end of the episode.
        interventions: Accepted for signature parity with the other graders;
            not read by this function.
        total_drug_changes: Number of changes made to the regimen.
        critical_drugs_stopped_without_sub: Number of critical drugs that
            were stopped without a substitute.

    Returns:
        Risk reduction minus both penalties, clipped to [0, 1].
    """
    achieved_drop = max(0.0, baseline_risk - final_risk)
    score = achieved_drop / max(baseline_risk, _EPS)

    # Subtract disruption first, then the critical-stop penalty.
    penalties = (
        0.05 * total_drug_changes,
        0.20 * critical_drugs_stopped_without_sub,
    )
    for penalty in penalties:
        score = score - penalty

    return _clip(score)