File size: 11,436 Bytes
dfc0f77
 
 
 
 
 
 
 
9ab33d8
 
f2beac3
 
 
 
9ab33d8
 
 
f2beac3
 
60c0453
 
 
 
 
 
 
f2beac3
 
60c0453
 
 
 
 
 
 
 
 
 
 
 
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c0453
 
 
 
 
 
 
 
 
 
 
 
f2beac3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c0453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9ab33d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60c0453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f2beac3
 
60c0453
 
f2beac3
 
 
60c0453
 
 
 
 
 
 
 
 
 
f2beac3
 
 
 
 
 
 
dfc0f77
 
 
 
 
 
9ab33d8
 
 
 
 
 
 
 
 
 
dfc0f77
 
 
 
 
 
 
 
 
 
60c0453
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dfc0f77
 
 
 
 
 
 
 
 
60c0453
dfc0f77
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
import sys
from pathlib import Path

import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from env import Action, PharmaVigilanceEnv
from tasks import (
    cluster_signal_medium_action_grader,
    cluster_signal_medium_grader,
    confounded_hard_action_grader,
    confounded_hard_grader,
    get_task,
    get_tasks,
    known_signal_easy_action_grader,
    known_signal_easy_grader,
)


def test_reset_loads_easy_task():
    """Resetting into the easy task yields a fresh single-report observation."""
    environment = PharmaVigilanceEnv()
    initial = environment.reset("known_signal_easy")

    # A freshly reset episode is at step 0 with a two-step budget and one report.
    assert initial.task_id == "known_signal_easy"
    assert initial.step_number == 0
    assert initial.max_steps == 2
    assert len(initial.reports) == 1


def test_known_signal_grader_full_credit():
    """A fully correct easy-task triage earns a total reward of 1.0."""
    reward = known_signal_easy_action_grader(
        Action(
            classification="known_side_effect",
            suspect_drug="Lisinopril",
            severity_assessment="mild",
            recommended_action="log_and_monitor",
            reasoning="Known reaction pattern.",
            confidence=91,
        )
    )
    # The total is a float built from component scores; compare with a
    # tolerance so summation order/rounding cannot cause a spurious failure.
    assert reward.total == pytest.approx(1.0)


def test_medium_cluster_grader_partial_credit():
    """A Cardiovexa new-signal escalation graded at moderate severity scores 0.75."""
    reward = cluster_signal_medium_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Cardiovexa",
            severity_assessment="moderate",
            recommended_action="escalate",
            reasoning="A cluster is forming.",
        )
    )
    # Float total: tolerate rounding instead of demanding bit-exact equality.
    assert reward.total == pytest.approx(0.75)


def test_hard_grader_reasoning_bonus():
    """Reasoning that names the drug interaction earns the reasoning bonus.

    The correct hard-task answer should reach a total of 1.0, with 0.05 of
    it attributed to ``reasoning_bonus`` in the breakdown.
    """
    reward = confounded_hard_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Tacrolimus+Voriconazole",
            severity_assessment="critical",
            recommended_action="escalate",
            reasoning="This looks like a tacrolimus-voriconazole drug interaction with toxic levels.",
        )
    )
    # Both values are floats; compare with a tolerance.
    assert reward.total == pytest.approx(1.0)
    assert reward.breakdown["reasoning_bonus"] == pytest.approx(0.05)


def test_hard_grader_substring_suspect_match():
    """Naming only Tacrolimus still earns the 0.25 suspect-drug component.

    The hard task's ground truth is the combination "Tacrolimus+Voriconazole";
    this pins that a partial (substring) match receives suspect-drug credit.
    """
    reward = confounded_hard_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Tacrolimus",
            severity_assessment="critical",
            recommended_action="escalate",
            reasoning="Voriconazole likely raised tacrolimus exposure.",
        )
    )
    # Float component score: compare with a tolerance.
    assert reward.breakdown["suspect_drug"] == pytest.approx(0.25)


def test_env_step_returns_done():
    """Two steps on the hard task: the first stays open, the second ends the episode."""
    env = PharmaVigilanceEnv()
    env.reset("confounded_hard")

    def make_interaction_call():
        # A fresh, correct interaction-based triage for each submission.
        return Action(
            classification="new_signal",
            suspect_drug="Tacrolimus+Voriconazole",
            severity_assessment="critical",
            recommended_action="escalate",
            reasoning="Tacrolimus toxicity from an azole interaction.",
        )

    # Step 1: non-terminal, with a breakdown in the info dict.
    first_obs, first_reward, first_done, first_info = env.step(make_interaction_call())
    assert first_done is False
    assert first_obs.step_number == 1
    assert "reward_breakdown" in first_info
    assert first_reward.total >= 0.20

    # Step 2: terminal, with a high reward for a consistently correct answer.
    final_obs, final_reward, final_done, _final_info = env.step(make_interaction_call())
    assert final_done is True
    assert final_obs.step_number == 2
    assert final_reward.total >= 0.85


def test_first_step_returns_partial_reward_and_review_feedback():
    """Step one on the medium task is non-terminal and carries reviewer feedback.

    Checks that the step reports phase "initial_triage", yields a positive
    reward, and that the observation feedback includes a senior review note.
    """
    env = PharmaVigilanceEnv()
    # The reset observation is not inspected here, so it is not bound.
    env.reset("cluster_signal_medium")

    obs, reward, done, info = env.step(
        Action(
            classification="new_signal",
            suspect_drug="Cardiovexa",
            severity_assessment="severe",
            recommended_action="escalate",
            reasoning="Clustered bradycardia on a newer therapy.",
            confidence=88,
        )
    )
    assert done is False
    assert obs.step_number == 1
    assert reward.total > 0.0
    assert info["phase"] == "initial_triage"
    assert "Senior review note" in obs.feedback


def test_final_step_awards_revision_bonus_when_agent_improves():
    """Revising a weak first answer into a correct one earns a 0.05 revision bonus."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    # Step 1: a confidently wrong dismissal.
    env.step(
        Action(
            classification="noise",
            suspect_drug="Unknown",
            severity_assessment="mild",
            recommended_action="dismiss",
            reasoning="Weak initial guess.",
            confidence=90,
        )
    )
    # Step 2: the corrected triage.
    _, reward, done, info = env.step(
        Action(
            classification="new_signal",
            suspect_drug="Cardiovexa",
            severity_assessment="severe",
            recommended_action="escalate",
            reasoning="Follow-up reports confirm a coherent bradycardia cluster.",
            confidence=82,
        )
    )
    assert done is True
    # Float component score: compare with a tolerance.
    assert reward.breakdown["revision_bonus"] == pytest.approx(0.05)
    assert info["phase"] == "final_review"


def test_final_step_applies_stubborn_penalty_for_repeating_weak_answer():
    """Resubmitting the same weak answer on the final step costs a -0.05 penalty."""
    env = PharmaVigilanceEnv()
    env.reset("confounded_hard")

    weak = Action(
        classification="noise",
        suspect_drug="Trimethoprim-sulfamethoxazole",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="Reporter probably overcalled it.",
        confidence=85,
    )
    env.step(weak)
    _, reward, done, _ = env.step(weak)
    assert done is True
    # Float component score: compare with a tolerance.
    assert reward.breakdown["stubborn_penalty"] == pytest.approx(-0.05)


def test_initial_step_can_return_negative_reward_for_unsafe_triage():
    """A confident dismissal on step one of the medium task yields a negative reward."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    unsafe_dismissal = Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="No obvious concern.",
        confidence=95,
    )
    step_result = env.step(unsafe_dismissal)
    _obs, reward, done, info = step_result

    # Still in the first (non-terminal) phase, but already penalised.
    assert done is False
    assert info["phase"] == "initial_triage"
    assert reward.total < 0.0


def test_single_step_action_grader_can_return_negative_total():
    """The medium-task action grader can score a confident dismissal below zero."""
    confident_dismissal = Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="Probably unrelated.",
        confidence=95,
    )
    graded_total = cluster_signal_medium_action_grader(confident_dismissal).total
    assert graded_total < 0.0


def test_overconfidence_penalty_applies_on_weak_single_step_grading():
    """A wrong answer submitted at 95% confidence gets a -0.10 confidence adjustment."""
    reward = cluster_signal_medium_action_grader(
        Action(
            classification="noise",
            suspect_drug="Unknown",
            severity_assessment="mild",
            recommended_action="dismiss",
            reasoning="This is probably nothing.",
            confidence=95,
        )
    )
    # Float component score: compare with a tolerance.
    assert reward.breakdown["confidence_adjustment"] == pytest.approx(-0.10)


def test_low_confidence_penalty_applies_on_strong_answer():
    """A correct answer submitted at 20% confidence gets a -0.03 confidence adjustment."""
    reward = known_signal_easy_action_grader(
        Action(
            classification="known_side_effect",
            suspect_drug="Lisinopril",
            severity_assessment="mild",
            recommended_action="log_and_monitor",
            reasoning="Known labeled ACE-inhibitor cough.",
            confidence=20,
        )
    )
    # Float component score: compare with a tolerance.
    assert reward.breakdown["confidence_adjustment"] == pytest.approx(-0.03)


def test_episode_rejects_third_step_after_completion():
    """Stepping past the end of a finished episode raises RuntimeError."""
    env = PharmaVigilanceEnv()
    env.reset("known_signal_easy")
    correct_triage = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known ACE-inhibitor cough.",
        confidence=90,
    )

    # Use up both steps of the episode.
    for _ in range(2):
        env.step(correct_triage)

    with pytest.raises(RuntimeError, match="Episode already complete"):
        env.step(correct_triage)


def test_state_tracks_last_action():
    """env.state() reports the step count and the most recently submitted action."""
    env = PharmaVigilanceEnv()
    env.reset("known_signal_easy")

    for _ in range(2):
        # A freshly built Action per step, i.e. two independent submissions.
        env.step(
            Action(
                classification="known_side_effect",
                suspect_drug="Lisinopril",
                severity_assessment="mild",
                recommended_action="log_and_monitor",
                reasoning="Known adverse effect.",
                confidence=90,
            )
        )

    snapshot = env.state()
    assert snapshot["step_number"] == 2
    assert snapshot["last_action"]["classification"] == "known_side_effect"


def test_all_tasks_available():
    """The flat task registry exposes exactly the three known task ids."""
    registered_ids = set(get_tasks().keys())
    expected_ids = {"known_signal_easy", "cluster_signal_medium", "confounded_hard"}
    assert registered_ids == expected_ids


def test_grouped_tasks_expose_easy_medium_hard_pools():
    """Grouped mode buckets tasks by difficulty, with a known task first in each pool."""
    pools = get_tasks(grouped=True)
    first_task_by_difficulty = {
        "easy": "known_signal_easy",
        "medium": "cluster_signal_medium",
        "hard": "confounded_hard",
    }

    assert set(pools.keys()) == set(first_task_by_difficulty)
    for difficulty, task_id in first_task_by_difficulty.items():
        assert pools[difficulty][0].task_id == task_id


def test_get_task_returns_hard_truth():
    """The hard task's ground truth names the two-drug combination as suspect."""
    truth = get_task("confounded_hard").ground_truth
    assert truth.suspect_drug == "Tacrolimus+Voriconazole"


def test_public_graders_are_strictly_bounded():
    """Public graders pin saturating inputs strictly inside the open interval (0, 1).

    A perfect reward maps to 0.99, a zero reward to 0.01, and an
    out-of-range score (1.5) is clamped back to 0.99.
    """
    top = known_signal_easy_grader({"rewards": [1.0]})
    bottom = cluster_signal_medium_grader({"rewards": [0.0]})
    overflow = confounded_hard_grader({"score": 1.5})

    # Expected clamp values (float compare with tolerance) ...
    assert top == pytest.approx(0.99)
    assert bottom == pytest.approx(0.01)
    assert overflow == pytest.approx(0.99)
    # ... and, as the test name promises, never the endpoints themselves.
    for score in (top, bottom, overflow):
        assert 0.0 < score < 1.0


def test_inference_final_score_uses_public_task_grader():
    """inference.final_score must agree with each task's public grader."""
    pytest.importorskip("openenv")
    from inference import final_score

    rewards = [0.4, 1.0]
    graders_by_task = (
        ("known_signal_easy", known_signal_easy_grader),
        ("cluster_signal_medium", cluster_signal_medium_grader),
        ("confounded_hard", confounded_hard_grader),
    )
    for task_id, grader in graders_by_task:
        assert final_score(task_id, rewards) == grader({"rewards": rewards})


def test_http_reset_then_step_roundtrip():
    """Drive a full episode over HTTP: POST /reset, then two POST /step calls.

    The first step must not finish the episode; the second must finish it
    with a reward of 1.0 for a correct Lisinopril triage.
    """
    pytest.importorskip("openenv")
    from fastapi.testclient import TestClient
    from server.app import app

    client = TestClient(app)

    # NOTE(review): an empty /reset body is assumed to load the easy
    # (Lisinopril) task — the final reward assertion depends on it.
    reset_response = client.post("/reset", json={})
    assert reset_response.status_code == 200

    # The same correct triage is submitted on both steps.
    step_body = {
        "action": {
            "classification": "known_side_effect",
            "suspect_drug": "Lisinopril",
            "severity_assessment": "mild",
            "recommended_action": "log_and_monitor",
            "reasoning": "Known ACE inhibitor cough.",
            "confidence": 90,
        }
    }

    first_step = client.post("/step", json=step_body)
    assert first_step.status_code == 200
    first_payload = first_step.json()
    assert first_payload["done"] is False

    step_response = client.post("/step", json=step_body)
    assert step_response.status_code == 200
    payload = step_response.json()
    assert payload["done"] is True
    # JSON float reward: compare with a tolerance rather than exact equality.
    assert payload["reward"] == pytest.approx(1.0)