# Scraped page header (not code): commit fd0c71a by adithya9903,
# "Deploy PolyGuard HF training Space" (verified).
from fastapi.testclient import TestClient
from app.api import app
from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS
def _assert_reward_precision(value: float) -> None:
assert 0.001 <= value <= 0.999
assert value == round(value, 3)
def test_api_health() -> None:
    """GET /health answers 200 with a JSON body whose status is "ok"."""
    response = TestClient(app).get("/health")
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"
def test_api_env_reset() -> None:
    """POST /env/reset with an empty body answers 200 and includes "patient_summary"."""
    response = TestClient(app).post("/env/reset", json={})
    assert response.status_code == 200
    observation = response.json()
    assert "patient_summary" in observation
def test_api_uncertainty_endpoint() -> None:
    """GET /env/uncertainty after a reset answers 200 and includes "overall_uncertainty"."""
    client = TestClient(app)
    # The test resets first, presumably because the uncertainty endpoint depends
    # on an active episode. Check the reset itself so a failure there is not
    # misreported as an uncertainty-endpoint failure.
    reset = client.post("/env/reset", json={})
    assert reset.status_code == 200
    res = client.get("/env/uncertainty")
    assert res.status_code == 200
    assert "overall_uncertainty" in res.json()
def test_api_env_catalog_contains_adapter_presets() -> None:
    """GET /env/catalog exposes the reward contract, sub-environments and task presets.

    Each expected preset id must be present with the listed difficulty and
    sub-environment.
    """
    client = TestClient(app)
    response = client.get("/env/catalog")
    assert response.status_code == 200
    catalog = response.json()

    assert catalog["reward_range"] == [0.001, 0.999]
    assert catalog["reward_precision"] == 3

    expected_sub_environments = [
        "DDI",
        "BANDIT_MINING",
        "REGIMEN_RISK",
        "PRECISION_DOSING",
        "LONGITUDINAL_DEPRESCRIBING",
        "WEB_SEARCH_MISSING_DATA",
        "ALTERNATIVE_SUGGESTION",
        "NEW_DRUG_DECOMPOSITION",
    ]
    assert catalog["sub_environments"] == expected_sub_environments

    presets_by_id = {preset["id"]: preset for preset in catalog["task_presets"]}
    # preset id -> (difficulty, sub_environment), checked in this order.
    expected_presets = {
        "easy_screening": ("easy", "DDI"),
        "budgeted_screening": ("medium", "REGIMEN_RISK"),
        "complex_tradeoff": ("hard", "REGIMEN_RISK"),
        "bandit_mining": ("hard", "BANDIT_MINING"),
    }
    for preset_id, (difficulty, sub_environment) in expected_presets.items():
        preset = presets_by_id[preset_id]
        assert preset["difficulty"] == difficulty
        assert preset["sub_environment"] == sub_environment
def test_api_reset_accepts_task_presets() -> None:
    """Resetting with each preset id (seed 91) yields the expected contract.

    Each reset must report the preset's difficulty and sub-environment in
    ``deterministic_contract``.
    """
    client = TestClient(app)
    cases = [
        ("easy_screening", "easy", "DDI"),
        ("budgeted_screening", "medium", "REGIMEN_RISK"),
        ("complex_tradeoff", "hard", "REGIMEN_RISK"),
        ("bandit_mining", "hard", "BANDIT_MINING"),
    ]
    for task_id, want_difficulty, want_sub_environment in cases:
        response = client.post("/env/reset", json={"task_id": task_id, "seed": 91})
        assert response.status_code == 200
        contract = response.json()["deterministic_contract"]
        assert contract["difficulty"] == want_difficulty
        assert contract["sub_environment"] == want_sub_environment
def test_api_step_candidate_resolves_legal_candidate() -> None:
    """Stepping with the first offered candidate returns precisely-bounded rewards.

    The top-level reward must equal the breakdown's ``total_reward``. Every
    required breakdown key and every primary reward channel must pass
    ``_assert_reward_precision``.
    """
    client = TestClient(app)
    reset_response = client.post(
        "/env/reset", json={"task_id": "easy_screening", "seed": 42}
    )
    assert reset_response.status_code == 200
    first_candidate = reset_response.json()["candidate_action_set"][0]

    step_request = {
        "candidate_id": first_candidate["candidate_id"],
        "confidence": 0.750,
        "rationale_brief": "Selected from the candidate workbench.",
    }
    step_response = client.post("/env/step_candidate", json=step_request)
    assert step_response.status_code == 200

    result = step_response.json()
    _assert_reward_precision(result["reward"])
    breakdown = result["info"]["reward_breakdown"]
    assert breakdown["total_reward"] == result["reward"]
    for reward_key in REQUIRED_REWARD_KEYS:
        _assert_reward_precision(breakdown[reward_key])
    channels = result["info"]["primary_reward_channels"]
    for channel_key in PRIMARY_REWARD_KEYS:
        _assert_reward_precision(channels[channel_key])
def test_api_step_candidate_rejects_unknown_candidate() -> None:
    """Stepping with a candidate id that was never offered answers 404."""
    client = TestClient(app)
    # Check the reset explicitly. If it failed, a 404 from step_candidate could
    # come from that failure rather than from the unknown candidate id, and the
    # test would pass for the wrong reason.
    reset = client.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset.status_code == 200
    res = client.post(
        "/env/step_candidate",
        json={
            "candidate_id": "cand_missing",
            "confidence": 0.500,
            "rationale_brief": "This should not resolve.",
        },
    )
    assert res.status_code == 404
def test_api_policy_model_status() -> None:
    """GET /policy/model_status reports provider "transformers" plus artifact fields."""
    response = TestClient(app).get("/policy/model_status")
    assert response.status_code == 200
    status = response.json()
    assert status["provider"] == "transformers"
    for field in ("preferred_artifact", "availability"):
        assert field in status