"""Tests for the PharmaVigilance environment, its task registry and graders.

The module covers three areas:

* single-step ``*_action_grader`` functions (reward totals and breakdowns),
* the two-step ``PharmaVigilanceEnv`` episode flow (``reset`` / ``step`` /
  ``state``),
* the public episode-level graders and, when ``openenv`` is installed, the
  ``inference.final_score`` helper and the HTTP server round trip.
"""

import sys
from pathlib import Path

import pytest

# Put the directory above this file's directory on ``sys.path`` so the
# project modules imported below resolve regardless of the working directory.
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from env import Action, PharmaVigilanceEnv  # noqa: E402
from tasks import (  # noqa: E402
    cluster_signal_medium_action_grader,
    cluster_signal_medium_grader,
    confounded_hard_action_grader,
    confounded_hard_grader,
    get_task,
    get_tasks,
    known_signal_easy_action_grader,
    known_signal_easy_grader,
)

# Absolute tolerance used when comparing computed reward values. Rewards are
# floats, so exact ``==`` would fail on harmless rounding differences.
_ABS_TOL = 1e-9


def _approx(expected: float):
    """Return a ``pytest.approx`` matcher for *expected* using ``_ABS_TOL``."""
    return pytest.approx(expected, abs=_ABS_TOL)


def _lisinopril_known_effect_action(reasoning: str, *, confidence: int) -> Action:
    """Build a fresh "known side effect / Lisinopril / mild / monitor" action."""
    return Action(
        classification="known_side_effect",
        suspect_drug="Lisinopril",
        severity_assessment="mild",
        recommended_action="log_and_monitor",
        reasoning=reasoning,
        confidence=confidence,
    )


def _dismiss_as_noise_action(reasoning: str, *, confidence: int) -> Action:
    """Build a fresh "noise / Unknown / mild / dismiss" action."""
    return Action(
        classification="noise",
        suspect_drug="Unknown",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning=reasoning,
        confidence=confidence,
    )


def _cardiovexa_escalation_action(reasoning: str, *, confidence: int) -> Action:
    """Build a fresh "new signal / Cardiovexa / severe / escalate" action."""
    return Action(
        classification="new_signal",
        suspect_drug="Cardiovexa",
        severity_assessment="severe",
        recommended_action="escalate",
        reasoning=reasoning,
        confidence=confidence,
    )


def _tacrolimus_interaction_action(reasoning: str) -> Action:
    """Build a fresh "new signal / Tacrolimus+Voriconazole / critical" action.

    ``confidence`` is deliberately not passed so ``Action``'s own default
    applies, as in the tests that use this payload.
    """
    return Action(
        classification="new_signal",
        suspect_drug="Tacrolimus+Voriconazole",
        severity_assessment="critical",
        recommended_action="escalate",
        reasoning=reasoning,
    )


def test_reset_loads_easy_task():
    """``reset`` returns a step-0 observation for the requested easy task."""
    env = PharmaVigilanceEnv()
    obs = env.reset("known_signal_easy")

    assert obs.task_id == "known_signal_easy"
    assert obs.step_number == 0
    assert obs.max_steps == 2
    assert len(obs.reports) == 1


def test_known_signal_grader_full_credit():
    """The easy action grader awards a total of 1.0 for this answer."""
    reward = known_signal_easy_action_grader(
        _lisinopril_known_effect_action("Known reaction pattern.", confidence=91)
    )

    assert reward.total == _approx(1.0)


def test_medium_cluster_grader_partial_credit():
    """A "moderate" severity answer on the medium task totals 0.75."""
    reward = cluster_signal_medium_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Cardiovexa",
            severity_assessment="moderate",
            recommended_action="escalate",
            reasoning="A cluster is forming.",
        )
    )

    assert reward.total == _approx(0.75)


def test_hard_grader_reasoning_bonus():
    """The hard grader reports a 0.05 ``reasoning_bonus`` and a 1.0 total."""
    reward = confounded_hard_action_grader(
        _tacrolimus_interaction_action(
            "This looks like a tacrolimus-voriconazole drug interaction "
            "with toxic levels."
        )
    )

    assert reward.total == _approx(1.0)
    assert reward.breakdown["reasoning_bonus"] == _approx(0.05)


def test_hard_grader_substring_suspect_match():
    """Naming only "Tacrolimus" still scores 0.25 for ``suspect_drug``."""
    reward = confounded_hard_action_grader(
        Action(
            classification="new_signal",
            suspect_drug="Tacrolimus",
            severity_assessment="critical",
            recommended_action="escalate",
            reasoning="Voriconazole likely raised tacrolimus exposure.",
        )
    )

    assert reward.breakdown["suspect_drug"] == _approx(0.25)


def test_env_step_returns_done():
    """An episode is not done after step 1 and is done after step 2."""
    env = PharmaVigilanceEnv()
    env.reset("confounded_hard")
    reasoning = "Tacrolimus toxicity from an azole interaction."

    obs, reward, done, info = env.step(_tacrolimus_interaction_action(reasoning))
    assert done is False
    assert obs.step_number == 1
    assert "reward_breakdown" in info
    assert reward.total >= 0.20

    # A second, separately constructed but identical action ends the episode.
    obs, reward, done, _ = env.step(_tacrolimus_interaction_action(reasoning))
    assert done is True
    assert obs.step_number == 2
    assert reward.total >= 0.85


def test_first_step_returns_partial_reward_and_review_feedback():
    """Step 1 yields a positive reward, the triage phase, and review feedback."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    obs, reward, done, info = env.step(
        _cardiovexa_escalation_action(
            "Clustered bradycardia on a newer therapy.", confidence=88
        )
    )

    assert done is False
    assert obs.step_number == 1
    assert reward.total > 0.0
    assert info["phase"] == "initial_triage"
    assert "Senior review note" in obs.feedback


def test_final_step_awards_revision_bonus_when_agent_improves():
    """A dismissal followed by an escalation earns ``revision_bonus`` 0.05."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")
    env.step(_dismiss_as_noise_action("Weak initial guess.", confidence=90))

    _, reward, done, info = env.step(
        _cardiovexa_escalation_action(
            "Follow-up reports confirm a coherent bradycardia cluster.",
            confidence=82,
        )
    )

    assert done is True
    assert reward.breakdown["revision_bonus"] == _approx(0.05)
    assert info["phase"] == "final_review"


def test_final_step_applies_stubborn_penalty_for_repeating_weak_answer():
    """Submitting the same dismissal twice yields ``stubborn_penalty`` -0.05."""
    env = PharmaVigilanceEnv()
    env.reset("confounded_hard")
    weak = Action(
        classification="noise",
        suspect_drug="Trimethoprim-sulfamethoxazole",
        severity_assessment="mild",
        recommended_action="dismiss",
        reasoning="Reporter probably overcalled it.",
        confidence=85,
    )

    env.step(weak)
    _, reward, done, _ = env.step(weak)

    assert done is True
    assert reward.breakdown["stubborn_penalty"] == _approx(-0.05)


def test_initial_step_can_return_negative_reward_for_unsafe_triage():
    """A confident dismissal on the medium task gives a negative step-1 reward."""
    env = PharmaVigilanceEnv()
    env.reset("cluster_signal_medium")

    _, reward, done, info = env.step(
        _dismiss_as_noise_action("No obvious concern.", confidence=95)
    )

    assert done is False
    assert info["phase"] == "initial_triage"
    assert reward.total < 0.0


def test_single_step_action_grader_can_return_negative_total():
    """The medium action grader can return a total below zero."""
    reward = cluster_signal_medium_action_grader(
        _dismiss_as_noise_action("Probably unrelated.", confidence=95)
    )

    assert reward.total < 0.0


def test_overconfidence_penalty_applies_on_weak_single_step_grading():
    """A 95-confidence dismissal gets ``confidence_adjustment`` of -0.10."""
    reward = cluster_signal_medium_action_grader(
        _dismiss_as_noise_action("This is probably nothing.", confidence=95)
    )

    assert reward.breakdown["confidence_adjustment"] == _approx(-0.10)


def test_low_confidence_penalty_applies_on_strong_answer():
    """A 20-confidence answer on the easy task gets an adjustment of -0.03."""
    reward = known_signal_easy_action_grader(
        _lisinopril_known_effect_action(
            "Known labeled ACE-inhibitor cough.", confidence=20
        )
    )

    assert reward.breakdown["confidence_adjustment"] == _approx(-0.03)


def test_episode_rejects_third_step_after_completion():
    """A third ``step`` call raises ``RuntimeError`` once the episode is over."""
    env = PharmaVigilanceEnv()
    env.reset("known_signal_easy")
    good = _lisinopril_known_effect_action(
        "Known ACE-inhibitor cough.", confidence=90
    )

    env.step(good)
    env.step(good)

    with pytest.raises(RuntimeError, match="Episode already complete"):
        env.step(good)


def test_state_tracks_last_action():
    """``state()`` exposes the step count and the most recent action."""
    env = PharmaVigilanceEnv()
    env.reset("known_signal_easy")
    env.step(_lisinopril_known_effect_action("Known adverse effect.", confidence=90))
    env.step(_lisinopril_known_effect_action("Known adverse effect.", confidence=90))

    state = env.state()

    assert state["step_number"] == 2
    assert state["last_action"]["classification"] == "known_side_effect"


def test_all_tasks_available():
    """``get_tasks()`` is keyed by exactly the three known task ids."""
    tasks = get_tasks()

    assert set(tasks.keys()) == {
        "known_signal_easy",
        "cluster_signal_medium",
        "confounded_hard",
    }


def test_grouped_tasks_expose_easy_medium_hard_pools():
    """``get_tasks(grouped=True)`` groups tasks under easy/medium/hard."""
    grouped = get_tasks(grouped=True)

    assert set(grouped.keys()) == {"easy", "medium", "hard"}
    assert grouped["easy"][0].task_id == "known_signal_easy"
    assert grouped["medium"][0].task_id == "cluster_signal_medium"
    assert grouped["hard"][0].task_id == "confounded_hard"


def test_get_task_returns_hard_truth():
    """The hard task's ground truth names the combined suspect drug."""
    task = get_task("confounded_hard")

    assert task.ground_truth.suspect_drug == "Tacrolimus+Voriconazole"


def test_public_graders_are_strictly_bounded():
    """Public graders map 1.0 / 1.5 to 0.99 and 0.0 to 0.01."""
    assert known_signal_easy_grader({"rewards": [1.0]}) == 0.99
    assert cluster_signal_medium_grader({"rewards": [0.0]}) == 0.01
    assert confounded_hard_grader({"score": 1.5}) == 0.99


def test_inference_final_score_uses_public_task_grader():
    """``final_score`` agrees with each task's public grader on equal input."""
    pytest.importorskip("openenv")
    from inference import final_score

    rewards = [0.4, 1.0]
    public_graders = {
        "known_signal_easy": known_signal_easy_grader,
        "cluster_signal_medium": cluster_signal_medium_grader,
        "confounded_hard": confounded_hard_grader,
    }

    for task_id, grader in public_graders.items():
        assert final_score(task_id, rewards) == grader({"rewards": rewards})


def test_http_reset_then_step_roundtrip():
    """POST ``/reset`` then two ``/step`` calls complete an episode over HTTP."""
    pytest.importorskip("openenv")
    from fastapi.testclient import TestClient
    from server.app import app

    client = TestClient(app)
    # The same JSON body is sent for both steps.
    step_body = {
        "action": {
            "classification": "known_side_effect",
            "suspect_drug": "Lisinopril",
            "severity_assessment": "mild",
            "recommended_action": "log_and_monitor",
            "reasoning": "Known ACE inhibitor cough.",
            "confidence": 90,
        }
    }

    reset_response = client.post("/reset", json={})
    assert reset_response.status_code == 200

    first_step = client.post("/step", json=step_body)
    assert first_step.status_code == 200
    first_payload = first_step.json()
    assert first_payload["done"] is False

    step_response = client.post("/step", json=step_body)
    assert step_response.status_code == 200
    payload = step_response.json()
    assert payload["done"] is True
    assert payload["reward"] == _approx(1.0)