Spaces:
Sleeping
Sleeping
File size: 4,878 Bytes
dfc0f77 f2beac3 dfc0f77 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 | import sys
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from env import Action, PharmaVigilanceEnv
from tasks import (
cluster_signal_medium_action_grader,
cluster_signal_medium_grader,
confounded_hard_action_grader,
confounded_hard_grader,
get_task,
get_tasks,
known_signal_easy_action_grader,
known_signal_easy_grader,
)
def test_reset_loads_easy_task():
    """Check the observation returned by resetting onto the easy task.

    The observation must echo the requested task id, start at step 0 and
    carry exactly one report.
    """
    task_id = "known_signal_easy"
    first_obs = PharmaVigilanceEnv().reset(task_id)

    assert first_obs.task_id == task_id
    assert first_obs.step_number == 0
    assert len(first_obs.reports) == 1
def test_known_signal_grader_full_credit():
    """The easy-task action grader awards a total of 1.0 for this action."""
    reward = known_signal_easy_action_grader(
        Action(
            classification="known_side_effect",
            suspect_drug="Lisinopril",
            severity_assessment="mild",
            recommended_action="log_and_monitor",
            reasoning="Known reaction pattern.",
        )
    )
    # Compare with a tolerance instead of exact float equality so that
    # float rounding inside the grader cannot cause a spurious failure.
    assert reward.total == pytest.approx(1.0)
def test_medium_cluster_grader_partial_credit():
    """The medium-task action grader awards a total of 0.75 for this action."""
    reward = cluster_signal_medium_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Cardiovexa",
            severity_assessment="moderate",
            recommended_action="escalate",
            reasoning="A cluster is forming.",
        )
    )
    # Compare with a tolerance instead of exact float equality so that
    # float rounding inside the grader cannot cause a spurious failure.
    assert reward.total == pytest.approx(0.75)
def test_hard_grader_reasoning_bonus():
    """The hard-task action grader reports a reasoning bonus and full total.

    For this action the reward total must be 1.0 and the
    ``"reasoning_bonus"`` entry of the breakdown must be 0.15.
    """
    reward = confounded_hard_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Tacrolimus+Voriconazole",
            severity_assessment="critical",
            recommended_action="escalate",
            reasoning="This looks like a tacrolimus-voriconazole drug interaction with toxic levels.",
        )
    )
    # Compare with a tolerance instead of exact float equality so that
    # float rounding inside the grader cannot cause a spurious failure.
    assert reward.total == pytest.approx(1.0)
    assert reward.breakdown["reasoning_bonus"] == pytest.approx(0.15)
def test_hard_grader_substring_suspect_match():
    """Naming only ``"Tacrolimus"`` still scores 0.25 on the suspect drug.

    The action gives a single drug name as the suspect; the
    ``"suspect_drug"`` entry of the reward breakdown must be 0.25.
    """
    reward = confounded_hard_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Tacrolimus",
            severity_assessment="critical",
            recommended_action="escalate",
            reasoning="Voriconazole likely raised tacrolimus exposure.",
        )
    )
    # Compare with a tolerance instead of exact float equality so that
    # float rounding inside the grader cannot cause a spurious failure.
    assert reward.breakdown["suspect_drug"] == pytest.approx(0.25)
def test_env_step_returns_done():
    """Step the hard task once and inspect the returned 4-tuple.

    After one step the episode must be flagged done, the observation must
    be at step 1, ``info`` must contain ``"reward_breakdown"`` and the
    reward total must be at least 0.85.
    """
    environment = PharmaVigilanceEnv()
    environment.reset("confounded_hard")

    submitted = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Tacrolimus toxicity from an azole interaction.",
    )
    next_obs, step_reward, finished, extras = environment.step(submitted)

    assert finished is True
    assert next_obs.step_number == 1
    assert "reward_breakdown" in extras
    assert step_reward.total >= 0.85
def test_state_tracks_last_action():
    """``env.state()`` reflects the step count and the most recent action.

    After a single step on the easy task, the state mapping must report
    step 1 and expose the submitted classification under ``"last_action"``.
    """
    environment = PharmaVigilanceEnv()
    environment.reset("known_signal_easy")

    submitted = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known adverse effect.",
    )
    environment.step(submitted)

    snapshot = environment.state()
    assert snapshot["step_number"] == 1
    assert snapshot["last_action"]["classification"] == "known_side_effect"
def test_all_tasks_available():
    """``get_tasks()`` exposes exactly the three expected task ids as keys."""
    expected_ids = {
        "known_signal_easy",
        "cluster_signal_medium",
        "confounded_hard",
    }
    registry = get_tasks()
    assert set(registry.keys()) == expected_ids
def test_get_task_returns_hard_truth():
    """The hard task's ground truth names the two-drug combination as suspect."""
    truth = get_task("confounded_hard").ground_truth
    assert truth.suspect_drug == "Tacrolimus+Voriconazole"
def test_public_graders_are_strictly_bounded():
    """Public graders return 0.99 / 0.01 for the extreme inputs used here.

    Each case is evaluated and checked in turn: a rewards list of ``[1.0]``
    and a score of ``1.5`` must both yield 0.99, a rewards list of ``[0.0]``
    must yield 0.01.
    """
    cases = (
        (known_signal_easy_grader, {"rewards": [1.0]}, 0.99),
        (cluster_signal_medium_grader, {"rewards": [0.0]}, 0.01),
        (confounded_hard_grader, {"score": 1.5}, 0.99),
    )
    for grader, payload, expected in cases:
        assert grader(payload) == expected
def test_http_reset_then_step_roundtrip():
    """POST ``/reset`` then ``/step`` through the HTTP app and check the reply.

    The test is skipped when the ``openenv`` package cannot be imported.
    Both requests must return HTTP 200; the step payload must report
    ``done`` as true and a reward of 1.0.
    """
    pytest.importorskip("openenv")
    # Imported lazily so that collecting this module does not require the
    # server dependencies when the test is going to be skipped.
    from fastapi.testclient import TestClient
    from server.app import app

    client = TestClient(app)
    reset_response = client.post("/reset", json={})
    assert reset_response.status_code == 200

    step_response = client.post(
        "/step",
        json={
            "action": {
                "classification": "known_side_effect",
                "suspect_drug": "Lisinopril",
                "severity_assessment": "mild",
                "recommended_action": "log_and_monitor",
                "reasoning": "Known ACE inhibitor cough.",
            }
        },
    )
    assert step_response.status_code == 200

    payload = step_response.json()
    assert payload["done"] is True
    # The reward comes back as a JSON-decoded float; compare with a tolerance
    # rather than exact equality so float rounding cannot fail the test.
    assert payload["reward"] == pytest.approx(1.0)
|