File size: 4,878 Bytes
dfc0f77
 
 
 
 
 
 
 
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc0f77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import sys
from pathlib import Path

import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from env import Action, PharmaVigilanceEnv
from tasks import (
    cluster_signal_medium_action_grader,
    cluster_signal_medium_grader,
    confounded_hard_action_grader,
    confounded_hard_grader,
    get_task,
    get_tasks,
    known_signal_easy_action_grader,
    known_signal_easy_grader,
)


def test_reset_loads_easy_task():
    """Resetting to ``known_signal_easy`` yields a step-0 observation with one report."""
    environment = PharmaVigilanceEnv()
    observation = environment.reset("known_signal_easy")

    # The observation must echo the requested task and start at step zero.
    assert observation.task_id == "known_signal_easy"
    assert observation.step_number == 0
    report_count = len(observation.reports)
    assert report_count == 1


def test_known_signal_grader_full_credit():
    """The easy-task action grader gives this action a total reward of 1.0.

    The action classifies the case as a known side effect of Lisinopril,
    rates it mild, and recommends logging and monitoring.
    """
    action = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known reaction pattern.",
    )
    reward = known_signal_easy_action_grader(action)
    # Compare with a tolerance: the total is a computed float (presumably a sum
    # of per-field scores -- confirm in tasks), so exact equality is brittle.
    assert reward.total == pytest.approx(1.0)


def test_medium_cluster_grader_partial_credit():
    """The medium-task action grader gives this action a total reward of 0.75.

    The action flags a new signal for Cardiovexa with moderate severity and
    recommends escalation; per the test name, this is a partial-credit answer.
    """
    action = Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="moderate",
        recommended_action="escalate",
        reasoning="A cluster is forming.",
    )
    reward = cluster_signal_medium_action_grader(action)
    # Compare with a tolerance: the total is a computed float, so exact
    # equality could fail on harmless rounding differences.
    assert reward.total == pytest.approx(0.75)


def test_hard_grader_reasoning_bonus():
    """The hard-task action grader reports a 0.15 reasoning bonus and a 1.0 total.

    The action names the Tacrolimus+Voriconazole combination as the suspect and
    its reasoning text mentions the drug interaction and toxic levels.
    """
    action = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="This looks like a tacrolimus-voriconazole drug interaction with toxic levels.",
    )
    reward = confounded_hard_action_grader(action)
    # Compare with a tolerance: both values are computed floats, so exact
    # equality could fail on harmless rounding differences.
    assert reward.total == pytest.approx(1.0)
    assert reward.breakdown["reasoning_bonus"] == pytest.approx(0.15)


def test_hard_grader_substring_suspect_match():
    """Naming only ``Tacrolimus`` as the suspect scores 0.25 on the suspect-drug component."""
    fields = {
        "classification": "new_signal",
        "suspect_drug": "Tacrolimus",
        "severity_assessment": "critical",
        "recommended_action": "escalate",
        "reasoning": "Voriconazole likely raised tacrolimus exposure.",
    }
    graded = confounded_hard_action_grader(Action(**fields))
    suspect_score = graded.breakdown["suspect_drug"]
    assert suspect_score == 0.25


def test_env_step_returns_done():
    """One step on ``confounded_hard`` ends the episode and reports a reward breakdown."""
    environment = PharmaVigilanceEnv()
    environment.reset("confounded_hard")

    submitted = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Tacrolimus toxicity from an azole interaction.",
    )
    step_result = environment.step(submitted)
    observation, reward, finished, details = step_result

    # The episode finishes after the first step, and the step counter advances.
    assert finished is True
    assert observation.step_number == 1
    assert "reward_breakdown" in details
    assert reward.total >= 0.85


def test_state_tracks_last_action():
    """After one step, ``state()`` exposes the step count and the last action's classification."""
    environment = PharmaVigilanceEnv()
    environment.reset("known_signal_easy")

    submitted = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known adverse effect.",
    )
    environment.step(submitted)

    snapshot = environment.state()
    assert snapshot["step_number"] == 1
    recorded_action = snapshot["last_action"]
    assert recorded_action["classification"] == "known_side_effect"


def test_all_tasks_available():
    """``get_tasks()`` exposes exactly the three expected task ids as keys."""
    expected_ids = {
        "known_signal_easy",
        "cluster_signal_medium",
        "confounded_hard",
    }
    registry = get_tasks()
    actual_ids = set(registry.keys())
    assert actual_ids == expected_ids


def test_get_task_returns_hard_truth():
    """The ``confounded_hard`` task's ground truth names the Tacrolimus+Voriconazole pair."""
    truth = get_task("confounded_hard").ground_truth
    assert truth.suspect_drug == "Tacrolimus+Voriconazole"


def test_public_graders_are_strictly_bounded():
    """The public graders return 0.99 and 0.01 rather than the raw 1.0, 0.0 or 1.5 inputs."""
    # A reward of 1.0 on the easy task comes back as 0.99.
    easy_score = known_signal_easy_grader({"rewards": [1.0]})
    assert easy_score == 0.99

    # A reward of 0.0 on the medium task comes back as 0.01.
    medium_score = cluster_signal_medium_grader({"rewards": [0.0]})
    assert medium_score == 0.01

    # An out-of-range "score" of 1.5 on the hard task comes back as 0.99.
    hard_score = confounded_hard_grader({"score": 1.5})
    assert hard_score == 0.99


def test_http_reset_then_step_roundtrip():
    """POST ``/reset`` then ``/step`` over HTTP and check the step payload.

    Skipped when the optional ``openenv`` package is not importable. The reset
    request sends an empty JSON body, so the task is whatever the server picks
    by default (presumably the easy Lisinopril task, given the expected full
    reward -- confirm in server.app).
    """
    pytest.importorskip("openenv")
    # Imported after the skip check so that collection does not fail when the
    # server's dependencies are missing.
    from fastapi.testclient import TestClient
    from server.app import app

    client = TestClient(app)

    reset_response = client.post("/reset", json={})
    assert reset_response.status_code == 200

    step_body = {
        "action": {
            "classification": "known_side_effect",
            "suspect_drug": "Lisinopril",
            "severity_assessment": "mild",
            "recommended_action": "log_and_monitor",
            "reasoning": "Known ACE inhibitor cough.",
        }
    }
    step_response = client.post("/step", json=step_body)
    assert step_response.status_code == 200

    payload = step_response.json()
    assert payload["done"] is True
    # Compare with a tolerance: the reward is a computed float that has been
    # through JSON, so exact equality is brittle.
    assert payload["reward"] == pytest.approx(1.0)