File size: 5,115 Bytes
d110f58
 
 
 
 
 
b42dbeb
 
 
d110f58
 
b42dbeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d110f58
 
b42dbeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d110f58
 
 
 
 
 
b42dbeb
 
d110f58
 
b42dbeb
d110f58
 
 
 
b42dbeb
 
d110f58
 
 
 
 
 
 
b42dbeb
 
 
d110f58
 
 
b42dbeb
 
d110f58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""Tests for the FastAPI HTTP server (OpenEnv create_app endpoints).

OpenEnv HTTP endpoints are *stateless*: each /reset and /step creates a
fresh environment instance.  Multi-step sessions only work via WebSocket.
These tests validate single-call behaviour and schema contracts.
"""

from __future__ import annotations

import json

import pytest
from fastapi.testclient import TestClient

from polypharmacy_env.api.server import app


@pytest.fixture
def client() -> TestClient:
    """Provide a ``TestClient`` wrapping the imported FastAPI ``app``.

    The fixture uses pytest's default (function) scope, so a new client
    object is constructed for each test that requests it.
    """
    test_client = TestClient(app)
    return test_client


class TestHealth:
    """Checks for the ``/health`` endpoint."""

    def test_health(self, client: TestClient) -> None:
        """GET /health answers 200 with ``status`` equal to ``"healthy"``."""
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"


class TestReset:
    """Single-call checks for the ``/reset`` endpoint."""

    def test_reset_default(self, client: TestClient) -> None:
        """An empty JSON body is accepted and the episode is not done."""
        request_body = {}
        response = client.post("/reset", json=request_body)
        assert response.status_code == 200
        payload = response.json()
        assert "observation" in payload
        assert payload["done"] is False

    def test_reset_with_task(self, client: TestClient) -> None:
        """The requested ``task_id`` is echoed back in the observation."""
        request_body = {"task_id": "easy_screening"}
        response = client.post("/reset", json=request_body)
        assert response.status_code == 200
        observation = response.json()["observation"]
        assert observation["task_id"] == "easy_screening"

    def test_reset_observation_has_medications(self, client: TestClient) -> None:
        """A seeded ``easy_screening`` reset lists at least three medications."""
        request_body = {"task_id": "easy_screening", "seed": 42}
        response = client.post("/reset", json=request_body)
        assert response.status_code == 200
        medications = response.json()["observation"]["current_medications"]
        assert len(medications) >= 3


class TestStep:
    """Checks for the ``/step`` endpoint; every call stands alone (stateless)."""

    def test_step_finish(self, client: TestClient) -> None:
        """A ``finish_review`` action is accepted and returns an observation."""
        request_body = {"action": {"action_type": "finish_review"}}
        response = client.post("/step", json=request_body)
        assert response.status_code == 200
        assert "observation" in response.json()

    def test_invalid_action_422(self, client: TestClient) -> None:
        """An unrecognised ``action_type`` is rejected with HTTP 422."""
        request_body = {"action": {"action_type": "invalid_type"}}
        response = client.post("/step", json=request_body)
        assert response.status_code == 422


class TestSchema:
    """Checks for the ``/schema`` endpoint."""

    def test_schema(self, client: TestClient) -> None:
        """GET /schema answers 200 and exposes action and observation entries."""
        response = client.get("/schema")
        assert response.status_code == 200
        schema = response.json()
        # OpenEnv schema endpoint returns keys: action, observation, state
        for required_key in ("action", "observation"):
            assert required_key in schema


class TestWebSocketSession:
    """Test multi-step sessions through the /ws WebSocket endpoint.

    OpenEnv WS protocol:
      Send:  {"type": "reset", "data": {"task_id": "...", "seed": ...}}
      Recv:  {"type": "observation", "data": {"observation": {...}, "reward": ..., "done": ...}}
      Send:  {"type": "step", "data": {"action_type": "...", ...}}
      Recv:  {"type": "observation", "data": {"observation": {...}, ...}}
      Send:  {"type": "state"}
      Recv:  {"type": "state", "data": {...state fields...}}
    """

    def test_ws_reset_step_finish(self, client: TestClient) -> None:
        """Run reset -> query_ddi -> finish_review over one WebSocket session.

        Verifies that the episode is not done after the reset and after the
        DDI query, and that it is done after ``finish_review``.
        """
        with client.websocket_connect("/ws") as ws:
            # Reset
            ws.send_json({
                "type": "reset",
                "data": {"task_id": "easy_screening", "seed": 42},
            })
            reset_resp = ws.receive_json()
            assert reset_resp["type"] == "observation"
            reset_data = reset_resp["data"]
            assert reset_data["done"] is False
            obs = reset_data["observation"]
            assert obs["task_id"] == "easy_screening"
            meds = obs["current_medications"]
            assert len(meds) >= 3

            # Step – query DDI for the first two medications.  The assertion
            # above already guarantees enough entries, so this step is
            # unconditional: it must never be silently skipped.
            ws.send_json({
                "type": "step",
                "data": {
                    "action_type": "query_ddi",
                    "drug_id_1": meds[0]["drug_id"],
                    "drug_id_2": meds[1]["drug_id"],
                },
            })
            step_resp = ws.receive_json()
            assert step_resp["type"] == "observation"
            assert step_resp["data"]["done"] is False

            # Finish
            ws.send_json({
                "type": "step",
                "data": {"action_type": "finish_review"},
            })
            finish_resp = ws.receive_json()
            assert finish_resp["type"] == "observation"
            assert finish_resp["data"]["done"] is True

    def test_ws_state(self, client: TestClient) -> None:
        """After a reset, a ``state`` request reports step 0 and the task id."""
        with client.websocket_connect("/ws") as ws:
            ws.send_json({
                "type": "reset",
                "data": {"task_id": "easy_screening", "seed": 0},
            })
            ws.receive_json()  # consume reset response

            ws.send_json({"type": "state"})
            state_resp = ws.receive_json()
            assert state_resp["type"] == "state"
            state_data = state_resp["data"]
            assert state_data["step_count"] == 0
            assert state_data["task_id"] == "easy_screening"