File size: 4,390 Bytes
21c7db9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from fastapi.testclient import TestClient

from app.api import app
from app.common.constants import PRIMARY_REWARD_KEYS, REQUIRED_REWARD_KEYS


def _assert_reward_precision(value: float) -> None:
    assert 0.001 <= value <= 0.999
    assert value == round(value, 3)


def test_api_health() -> None:
    """``GET /health`` answers 200 with a JSON body whose ``status`` is ``"ok"``."""
    http = TestClient(app)
    response = http.get("/health")
    assert response.status_code == 200
    body = response.json()
    assert body["status"] == "ok"


def test_api_env_reset() -> None:
    """An empty-body ``POST /env/reset`` succeeds and returns a ``patient_summary``."""
    http = TestClient(app)
    response = http.post("/env/reset", json={})
    assert response.status_code == 200
    body = response.json()
    assert "patient_summary" in body


def test_api_uncertainty_endpoint() -> None:
    """``GET /env/uncertainty`` reports ``overall_uncertainty`` after a reset.

    The reset response is asserted first. A broken reset then fails on its own
    line instead of being misreported as an uncertainty-endpoint failure.
    """
    client = TestClient(app)
    reset = client.post("/env/reset", json={})
    assert reset.status_code == 200

    res = client.get("/env/uncertainty")
    assert res.status_code == 200
    assert "overall_uncertainty" in res.json()


def test_api_env_catalog_contains_adapter_presets() -> None:
    """``GET /env/catalog`` advertises the reward contract, sub-environments and presets.

    Checks the reward range and precision, and the exact ordered list of
    sub-environments. It also checks the difficulty and sub-environment of each
    expected task preset, looked up by its ``id``.
    """
    http = TestClient(app)
    response = http.get("/env/catalog")
    assert response.status_code == 200
    catalog = response.json()

    assert catalog["reward_range"] == [0.001, 0.999]
    assert catalog["reward_precision"] == 3

    expected_sub_environments = [
        "DDI",
        "BANDIT_MINING",
        "REGIMEN_RISK",
        "PRECISION_DOSING",
        "LONGITUDINAL_DEPRESCRIBING",
        "WEB_SEARCH_MISSING_DATA",
        "ALTERNATIVE_SUGGESTION",
        "NEW_DRUG_DECOMPOSITION",
    ]
    assert catalog["sub_environments"] == expected_sub_environments

    # Index presets by id; a later duplicate id would overwrite an earlier one.
    by_id = {}
    for preset in catalog["task_presets"]:
        by_id[preset["id"]] = preset

    expected_presets = {
        "easy_screening": ("easy", "DDI"),
        "budgeted_screening": ("medium", "REGIMEN_RISK"),
        "complex_tradeoff": ("hard", "REGIMEN_RISK"),
        "bandit_mining": ("hard", "BANDIT_MINING"),
    }
    for preset_id, (difficulty, sub_environment) in expected_presets.items():
        assert by_id[preset_id]["difficulty"] == difficulty
        assert by_id[preset_id]["sub_environment"] == sub_environment


def test_api_reset_accepts_task_presets() -> None:
    """Resetting with each preset id pins the matching difficulty and sub-environment.

    Every reset uses seed 91. The expected values are read back from the
    response's ``deterministic_contract``.
    """
    http = TestClient(app)
    cases = [
        ("easy_screening", "easy", "DDI"),
        ("budgeted_screening", "medium", "REGIMEN_RISK"),
        ("complex_tradeoff", "hard", "REGIMEN_RISK"),
        ("bandit_mining", "hard", "BANDIT_MINING"),
    ]

    for task_id, want_difficulty, want_sub_env in cases:
        response = http.post("/env/reset", json={"task_id": task_id, "seed": 91})
        assert response.status_code == 200
        contract = response.json()["deterministic_contract"]
        assert contract["difficulty"] == want_difficulty
        assert contract["sub_environment"] == want_sub_env


def test_api_step_candidate_resolves_legal_candidate() -> None:
    """Stepping with a candidate offered by reset yields well-formed rewards.

    The test proceeds in three steps:

    1. Reset the ``easy_screening`` preset with seed 42.
    2. Submit the first offered candidate to ``/env/step_candidate``.
    3. Check the rewards. The scalar reward, every ``REQUIRED_REWARD_KEYS``
       entry of the reward breakdown, and every ``PRIMARY_REWARD_KEYS`` entry
       of the primary reward channels must satisfy
       ``_assert_reward_precision``. The breakdown's ``total_reward`` must
       equal the scalar reward.
    """
    client = TestClient(app)
    reset = client.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset.status_code == 200
    candidates = reset.json()["candidate_action_set"]
    # Fail with an explicit message rather than an opaque IndexError when the
    # environment offers no candidates.
    assert candidates, "reset returned an empty candidate_action_set"
    candidate = candidates[0]

    res = client.post(
        "/env/step_candidate",
        json={
            "candidate_id": candidate["candidate_id"],
            "confidence": 0.750,
            "rationale_brief": "Selected from the candidate workbench.",
        },
    )
    assert res.status_code == 200
    payload = res.json()
    reward = payload["reward"]
    info = payload["info"]

    _assert_reward_precision(reward)
    assert info["reward_breakdown"]["total_reward"] == reward
    for key in REQUIRED_REWARD_KEYS:
        _assert_reward_precision(info["reward_breakdown"][key])
    for key in PRIMARY_REWARD_KEYS:
        _assert_reward_precision(info["primary_reward_channels"][key])


def test_api_step_candidate_rejects_unknown_candidate() -> None:
    """Stepping with an id absent from the candidate set is answered with HTTP 404.

    Two preconditions are asserted first: the reset succeeds, and the probe id
    is not among the candidate ids it returned. Neither a failed reset nor an
    accidental id collision can then make this test pass for the wrong reason.
    """
    client = TestClient(app)
    reset = client.post("/env/reset", json={"task_id": "easy_screening", "seed": 42})
    assert reset.status_code == 200

    missing_id = "cand_missing"
    known_ids = {item["candidate_id"] for item in reset.json()["candidate_action_set"]}
    assert missing_id not in known_ids

    res = client.post(
        "/env/step_candidate",
        json={
            "candidate_id": missing_id,
            "confidence": 0.500,
            "rationale_brief": "This should not resolve.",
        },
    )
    assert res.status_code == 404


def test_api_policy_model_status() -> None:
    """``GET /policy/model_status`` reports the provider and artifact availability.

    The provider must be ``"transformers"``. The body must include both
    ``preferred_artifact`` and ``availability`` keys.
    """
    http = TestClient(app)
    response = http.get("/policy/model_status")
    assert response.status_code == 200
    status = response.json()
    assert status["provider"] == "transformers"
    for required_key in ("preferred_artifact", "availability"):
        assert required_key in status