File size: 4,082 Bytes
877add7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from fastapi.testclient import TestClient

from app.api import app
from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS


def _assert_reward_precision(value: float) -> None:
    """Assert that *value* lies in [0.001, 0.999] and has at most three decimals."""
    lower_bound, upper_bound = 0.001, 0.999
    assert lower_bound <= value
    assert value <= upper_bound
    # A value already expressed with three decimals is unchanged by rounding.
    rounded = round(value, 3)
    assert value == rounded


def test_api_health() -> None:
    """GET /health responds with 200 and a JSON status of "ok"."""
    response = TestClient(app).get("/health")
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"


def test_api_env_reset() -> None:
    """POST /env/reset with an empty JSON body returns 200 and a patient summary."""
    api = TestClient(app)
    response = api.post("/env/reset", json={})
    assert response.status_code == 200
    body = response.json()
    assert "patient_summary" in body


def test_api_uncertainty_endpoint() -> None:
    """GET /env/uncertainty returns 200 with an overall uncertainty after a reset."""
    client = TestClient(app)
    # Verify the setup call so a broken reset is reported as such instead of
    # surfacing as a confusing failure of the uncertainty endpoint.
    reset = client.post("/env/reset", json={})
    assert reset.status_code == 200

    res = client.get("/env/uncertainty")
    assert res.status_code == 200
    assert "overall_uncertainty" in res.json()


def test_api_env_catalog_contains_adapter_presets() -> None:
    """GET /env/catalog exposes the reward contract, sub-environments and task presets."""
    response = TestClient(app).get("/env/catalog")
    assert response.status_code == 200
    catalog = response.json()

    # Reward contract advertised by the catalog.
    assert catalog["reward_range"] == [0.001, 0.999]
    assert catalog["reward_precision"] == 3

    # Sub-environments must match exactly, including their order.
    expected_sub_environments = [
        "DDI",
        "BANDIT_MINING",
        "REGIMEN_RISK",
        "PRECISION_DOSING",
        "LONGITUDINAL_DEPRESCRIBING",
        "WEB_SEARCH_MISSING_DATA",
        "ALTERNATIVE_SUGGESTION",
        "NEW_DRUG_DECOMPOSITION",
    ]
    assert catalog["sub_environments"] == expected_sub_environments

    # Index the presets by id so each expectation is a direct lookup.
    presets_by_id = {}
    for preset in catalog["task_presets"]:
        presets_by_id[preset["id"]] = preset

    expected_presets = (
        ("easy_screening", "easy", "DDI"),
        ("budgeted_screening", "medium", "REGIMEN_RISK"),
        ("complex_tradeoff", "hard", "REGIMEN_RISK"),
        ("bandit_mining", "hard", "BANDIT_MINING"),
    )
    for preset_id, difficulty, sub_environment in expected_presets:
        assert presets_by_id[preset_id]["difficulty"] == difficulty
        assert presets_by_id[preset_id]["sub_environment"] == sub_environment


def test_api_reset_accepts_task_presets() -> None:
    """POST /env/reset with each task preset yields the matching deterministic contract."""
    # (task_id, expected difficulty, expected sub-environment)
    cases = (
        ("easy_screening", "easy", "DDI"),
        ("budgeted_screening", "medium", "REGIMEN_RISK"),
        ("complex_tradeoff", "hard", "REGIMEN_RISK"),
        ("bandit_mining", "hard", "BANDIT_MINING"),
    )
    api = TestClient(app)

    for task_id, difficulty, sub_environment in cases:
        request_body = {"task_id": task_id, "seed": 91}
        response = api.post("/env/reset", json=request_body)
        assert response.status_code == 200

        contract = response.json()["deterministic_contract"]
        assert contract["difficulty"] == difficulty
        assert contract["sub_environment"] == sub_environment


def test_api_step_candidate_resolves_legal_candidate() -> None:
    """Stepping with a candidate taken from the reset response returns well-formed rewards."""
    api = TestClient(app)

    reset_response = api.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset_response.status_code == 200
    first_candidate = reset_response.json()["candidate_action_set"][0]

    step_request = {
        "candidate_id": first_candidate["candidate_id"],
        "confidence": 0.750,
        "rationale_brief": "Selected from the candidate workbench.",
    }
    step_response = api.post("/env/step_candidate", json=step_request)
    assert step_response.status_code == 200

    result = step_response.json()
    _assert_reward_precision(result["reward"])

    # The top-level reward must equal the total recorded in the breakdown.
    info = result["info"]
    breakdown = info["reward_breakdown"]
    assert breakdown["total_reward"] == result["reward"]

    for reward_key in REQUIRED_REWARD_KEYS:
        _assert_reward_precision(breakdown[reward_key])
    for channel_key in PRIMARY_REWARD_KEYS:
        _assert_reward_precision(info["primary_reward_channels"][channel_key])


def test_api_step_candidate_rejects_unknown_candidate() -> None:
    """POST /env/step_candidate with an unknown candidate id returns 404."""
    client = TestClient(app)
    # Confirm the reset succeeded so the 404 below cannot be explained by a
    # failed setup call rather than by the unknown candidate id.
    reset = client.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset.status_code == 200

    res = client.post(
        "/env/step_candidate",
        json={
            "candidate_id": "cand_missing",
            "confidence": 0.500,
            "rationale_brief": "This should not resolve.",
        },
    )
    assert res.status_code == 404