| from fastapi.testclient import TestClient |
|
|
| from app.api import app |
| from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS |
|
|
|
|
| def _assert_reward_precision(value: float) -> None: |
| assert 0.001 <= value <= 0.999 |
| assert value == round(value, 3) |
|
|
|
|
def test_api_health() -> None:
    """GET /health answers 200 with a JSON body whose ``status`` is ``"ok"``."""
    response = TestClient(app).get("/health")
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"
|
|
|
|
def test_api_env_reset() -> None:
    """POST /env/reset with an empty body succeeds and returns a ``patient_summary``."""
    api = TestClient(app)
    response = api.post("/env/reset", json={})
    assert response.status_code == 200
    observation = response.json()
    assert "patient_summary" in observation
|
|
|
|
def test_api_uncertainty_endpoint() -> None:
    """GET /env/uncertainty after a reset reports an ``overall_uncertainty`` field.

    The reset response is asserted explicitly. ``app`` is imported once and
    shared by every test in this module, so an unchecked failed reset could let
    the uncertainty query observe whatever state (if any) an earlier test left
    behind, masking the failure.
    """
    client = TestClient(app)
    reset = client.post("/env/reset", json={})
    assert reset.status_code == 200

    res = client.get("/env/uncertainty")
    assert res.status_code == 200
    assert "overall_uncertainty" in res.json()
|
|
|
|
def test_api_env_catalog_contains_adapter_presets() -> None:
    """GET /env/catalog advertises the reward contract, sub-environments and task presets."""
    response = TestClient(app).get("/env/catalog")
    assert response.status_code == 200
    catalog = response.json()

    # Same closed range / precision that _assert_reward_precision enforces.
    assert catalog["reward_range"] == [0.001, 0.999]
    assert catalog["reward_precision"] == 3

    expected_sub_environments = [
        "DDI",
        "BANDIT_MINING",
        "REGIMEN_RISK",
        "PRECISION_DOSING",
        "LONGITUDINAL_DEPRESCRIBING",
        "WEB_SEARCH_MISSING_DATA",
        "ALTERNATIVE_SUGGESTION",
        "NEW_DRUG_DECOMPOSITION",
    ]
    assert catalog["sub_environments"] == expected_sub_environments

    by_id = {preset["id"]: preset for preset in catalog["task_presets"]}
    # (preset id, difficulty, sub_environment), checked in this order.
    expected_presets = (
        ("easy_screening", "easy", "DDI"),
        ("budgeted_screening", "medium", "REGIMEN_RISK"),
        ("complex_tradeoff", "hard", "REGIMEN_RISK"),
        ("bandit_mining", "hard", "BANDIT_MINING"),
    )
    for preset_id, difficulty, sub_environment in expected_presets:
        preset = by_id[preset_id]
        assert preset["difficulty"] == difficulty
        assert preset["sub_environment"] == sub_environment
|
|
|
|
def test_api_reset_accepts_task_presets() -> None:
    """Resetting with each preset id yields the preset's difficulty and sub-environment.

    Every reset uses the fixed seed 91; the result is read from the response's
    ``deterministic_contract`` object.
    """
    api = TestClient(app)
    cases = [
        ("easy_screening", "easy", "DDI"),
        ("budgeted_screening", "medium", "REGIMEN_RISK"),
        ("complex_tradeoff", "hard", "REGIMEN_RISK"),
        ("bandit_mining", "hard", "BANDIT_MINING"),
    ]

    for task_id, want_difficulty, want_sub_environment in cases:
        response = api.post("/env/reset", json={"task_id": task_id, "seed": 91})
        assert response.status_code == 200
        contract = response.json()["deterministic_contract"]
        assert contract["difficulty"] == want_difficulty
        assert contract["sub_environment"] == want_sub_environment
|
|
|
|
def test_api_step_candidate_resolves_legal_candidate() -> None:
    """Stepping with the first workbench candidate yields well-formed rewards.

    The scalar ``reward``, every ``REQUIRED_REWARD_KEYS`` entry of
    ``info.reward_breakdown`` and every ``PRIMARY_REWARD_KEYS`` entry of
    ``info.primary_reward_channels`` must pass ``_assert_reward_precision``.
    The breakdown's ``total_reward`` must equal the scalar reward.
    """
    api = TestClient(app)
    reset = api.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset.status_code == 200
    first_candidate_id = reset.json()["candidate_action_set"][0]["candidate_id"]

    step_request = {
        "candidate_id": first_candidate_id,
        "confidence": 0.750,
        "rationale_brief": "Selected from the candidate workbench.",
    }
    response = api.post("/env/step_candidate", json=step_request)
    assert response.status_code == 200
    result = response.json()

    reward = result["reward"]
    _assert_reward_precision(reward)

    info = result["info"]
    breakdown = info["reward_breakdown"]
    assert breakdown["total_reward"] == reward
    for key in REQUIRED_REWARD_KEYS:
        _assert_reward_precision(breakdown[key])
    for key in PRIMARY_REWARD_KEYS:
        _assert_reward_precision(info["primary_reward_channels"][key])
|
|
|
|
def test_api_step_candidate_rejects_unknown_candidate() -> None:
    """Stepping with a candidate id absent from the workbench is rejected with 404.

    Both preconditions are asserted before the step:
    - the reset succeeds;
    - the probe id really is absent from the returned ``candidate_action_set``.

    Without these checks, a 404 that stems from a failed reset rather than from
    the unknown id would let the test pass for the wrong reason. A
    coincidentally valid probe id would make it fail misleadingly.
    """
    client = TestClient(app)
    reset = client.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset.status_code == 200

    missing_id = "cand_missing"
    known_ids = {
        candidate["candidate_id"]
        for candidate in reset.json()["candidate_action_set"]
    }
    assert missing_id not in known_ids

    res = client.post(
        "/env/step_candidate",
        json={
            "candidate_id": missing_id,
            "confidence": 0.500,
            "rationale_brief": "This should not resolve.",
        },
    )
    assert res.status_code == 404
|
|
|
|
def test_api_policy_model_status() -> None:
    """GET /policy/model_status names the ``transformers`` provider.

    The response must also include ``preferred_artifact`` and ``availability``.
    """
    response = TestClient(app).get("/policy/model_status")
    assert response.status_code == 200
    status = response.json()
    assert status["provider"] == "transformers"
    for field in ("preferred_artifact", "availability"):
        assert field in status
|
|