Spaces:
Sleeping
Sleeping
import sys
from pathlib import Path

import pytest

# Put the directory two levels above this file at the front of the import
# path so the project modules imported below (``env`` and ``tasks``) resolve.
# This has to run before those imports.
sys.path.insert(0, str(Path(__file__).resolve().parents[1]))

from env import Action, PharmaVigilanceEnv  # noqa: E402
from tasks import (  # noqa: E402
    cluster_signal_medium_action_grader,
    cluster_signal_medium_grader,
    confounded_hard_action_grader,
    confounded_hard_grader,
    get_task,
    get_tasks,
    known_signal_easy_action_grader,
    known_signal_easy_grader,
)
def test_reset_loads_easy_task():
    """Check the observation returned by resetting onto ``known_signal_easy``.

    The observation must carry the requested task id, start at step 0 and
    contain exactly one report.
    """
    environment = PharmaVigilanceEnv()
    initial_obs = environment.reset("known_signal_easy")

    assert initial_obs.task_id == "known_signal_easy"
    assert initial_obs.step_number == 0

    report_count = len(initial_obs.reports)
    assert report_count == 1
def test_known_signal_grader_full_credit():
    """Check that the easy-task action grader scores this action at 1.0.

    The submitted action names Lisinopril as a mild, known side effect that
    should be logged and monitored.
    """
    submitted = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known reaction pattern.",
    )

    reward = known_signal_easy_action_grader(submitted)

    # The total is a float score; compare with a tolerance so that rounding
    # in how the grader accumulates it cannot cause a spurious failure.
    assert reward.total == pytest.approx(1.0)
def test_medium_cluster_grader_partial_credit():
    """Check that the medium-task action grader scores this action at 0.75.

    The submitted action flags Cardiovexa as a new signal of moderate
    severity and recommends escalation.
    """
    submitted = Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="moderate",
        recommended_action="escalate",
        reasoning="A cluster is forming.",
    )

    reward = cluster_signal_medium_action_grader(submitted)

    # The total is a float score; compare with a tolerance rather than
    # bit-for-bit so float rounding cannot cause a spurious failure.
    assert reward.total == pytest.approx(0.75)
def test_hard_grader_reasoning_bonus():
    """Check the hard-task grader's total and its ``reasoning_bonus`` entry.

    The submitted action names the Tacrolimus+Voriconazole pair and its
    reasoning text mentions the drug interaction. The test expects a total of
    1.0 and a ``reasoning_bonus`` breakdown entry of 0.15.
    """
    submitted = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="This looks like a tacrolimus-voriconazole drug interaction with toxic levels.",
    )

    reward = confounded_hard_action_grader(submitted)

    # Both values are floats, and 0.15 has no exact binary representation.
    # Compare with a tolerance so rounding cannot cause a spurious failure.
    assert reward.total == pytest.approx(1.0)
    assert reward.breakdown["reasoning_bonus"] == pytest.approx(0.15)
def test_hard_grader_substring_suspect_match():
    """Check the ``suspect_drug`` score when only "Tacrolimus" is named.

    The action gives a single drug name rather than a combined pair; the
    test expects the ``suspect_drug`` breakdown entry to be 0.25.
    """
    submitted = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Voriconazole likely raised tacrolimus exposure.",
    )

    reward = confounded_hard_action_grader(submitted)

    # Float component score: compare with a tolerance, not bit-for-bit.
    assert reward.breakdown["suspect_drug"] == pytest.approx(0.25)
def test_env_step_returns_done():
    """Check the 4-tuple returned by one ``step`` on ``confounded_hard``.

    After a single step the episode must be flagged done, the observation
    must be at step 1, ``info`` must contain ``reward_breakdown`` and the
    reward total must be at least 0.85.
    """
    environment = PharmaVigilanceEnv()
    environment.reset("confounded_hard")

    submitted = Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning="Tacrolimus toxicity from an azole interaction.",
    )
    next_obs, step_reward, finished, extras = environment.step(submitted)

    assert finished is True
    assert next_obs.step_number == 1
    assert "reward_breakdown" in extras
    assert step_reward.total >= 0.85
def test_state_tracks_last_action():
    """Check that ``state()`` reflects the step count and last action.

    After one step on ``known_signal_easy`` the state mapping must report
    ``step_number`` 1 and the classification of the submitted action.
    """
    environment = PharmaVigilanceEnv()
    environment.reset("known_signal_easy")

    submitted = Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning="Known adverse effect.",
    )
    environment.step(submitted)

    snapshot = environment.state()
    assert snapshot["step_number"] == 1

    recorded_action = snapshot["last_action"]
    assert recorded_action["classification"] == "known_side_effect"
def test_all_tasks_available():
    """Check that ``get_tasks`` exposes exactly the three expected task ids."""
    registry = get_tasks()
    available_ids = set(registry.keys())

    assert available_ids == {
        "known_signal_easy",
        "cluster_signal_medium",
        "confounded_hard",
    }
def test_get_task_returns_hard_truth():
    """Check the ground-truth suspect drug stored on ``confounded_hard``."""
    hard_task = get_task("confounded_hard")
    truth = hard_task.ground_truth

    assert truth.suspect_drug == "Tacrolimus+Voriconazole"
def test_public_graders_are_strictly_bounded():
    """Check the values the public graders return at and beyond the extremes.

    A rewards list of ``[1.0]`` must give 0.99, a rewards list of ``[0.0]``
    must give 0.01, and an out-of-range ``score`` of 1.5 must give 0.99.
    """
    easy_score = known_signal_easy_grader({"rewards": [1.0]})
    assert easy_score == 0.99

    medium_score = cluster_signal_medium_grader({"rewards": [0.0]})
    assert medium_score == 0.01

    hard_score = confounded_hard_grader({"score": 1.5})
    assert hard_score == 0.99
def test_http_reset_then_step_roundtrip():
    """Exercise ``/reset`` then ``/step`` through the HTTP app.

    Skipped when the ``openenv`` package cannot be imported. Both requests
    must return HTTP 200, and the step payload must report ``done`` as True
    and a reward of 1.0.
    """
    pytest.importorskip("openenv")

    # Imported only after the skip check, so a missing optional dependency
    # skips this test instead of raising an import error.
    from fastapi.testclient import TestClient

    from server.app import app

    client = TestClient(app)

    reset_response = client.post("/reset", json={})
    assert reset_response.status_code == 200

    action_body = {
        "classification": "known_side_effect",
        "suspect_drug": "Lisinopril",
        "severity_assessment": "mild",
        "recommended_action": "log_and_monitor",
        "reasoning": "Known ACE inhibitor cough.",
    }
    step_response = client.post("/step", json={"action": action_body})
    assert step_response.status_code == 200

    payload = step_response.json()
    assert payload["done"] is True
    # The reward is a float decoded from JSON; compare with a tolerance
    # rather than bit-for-bit.
    assert payload["reward"] == pytest.approx(1.0)