Spaces:
Sleeping
Sleeping
File size: 5,115 Bytes
2043afa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 | """Tests for the FastAPI HTTP server (OpenEnv create_app endpoints).
OpenEnv HTTP endpoints are *stateless*: each /reset and /step creates a
fresh environment instance. Multi-step sessions only work via WebSocket.
These tests validate single-call behaviour and schema contracts.
"""
from __future__ import annotations

import json
from collections.abc import Iterator

import pytest
from fastapi.testclient import TestClient

from polypharmacy_env.api.server import app
@pytest.fixture
def client() -> Iterator[TestClient]:
    """Yield a ``TestClient`` bound to the application under test.

    The client is entered as a context manager so that the app's
    lifespan (startup/shutdown) handlers run for each test and the
    client's underlying transport is closed during fixture teardown.
    """
    with TestClient(app) as test_client:
        yield test_client
class TestHealth:
    """Liveness probe: ``GET /health``."""

    def test_health(self, client: TestClient) -> None:
        """The health endpoint answers 200 and reports ``status == "healthy"``."""
        response = client.get("/health")
        assert response.status_code == 200
        assert response.json()["status"] == "healthy"
class TestReset:
    """Single-call ``POST /reset`` behaviour."""

    def test_reset_default(self, client: TestClient) -> None:
        """An empty body resets with defaults and yields a non-terminal result."""
        response = client.post("/reset", json={})
        assert response.status_code == 200
        body = response.json()
        assert "observation" in body
        assert body["done"] is False

    def test_reset_with_task(self, client: TestClient) -> None:
        """A requested ``task_id`` is echoed back in the observation."""
        response = client.post("/reset", json={"task_id": "easy_screening"})
        assert response.status_code == 200
        observation = response.json()["observation"]
        assert observation["task_id"] == "easy_screening"

    def test_reset_observation_has_medications(self, client: TestClient) -> None:
        """A seeded ``easy_screening`` reset lists at least three medications."""
        payload = {"task_id": "easy_screening", "seed": 42}
        response = client.post("/reset", json=payload)
        assert response.status_code == 200
        medications = response.json()["observation"]["current_medications"]
        assert len(medications) >= 3
class TestStep:
    """Single-call ``POST /step`` checks; each request hits a fresh environment."""

    def test_step_finish(self, client: TestClient) -> None:
        """A ``finish_review`` action is accepted and returns an observation."""
        payload = {"action": {"action_type": "finish_review"}}
        response = client.post("/step", json=payload)
        assert response.status_code == 200
        assert "observation" in response.json()

    def test_invalid_action_422(self, client: TestClient) -> None:
        """An unknown ``action_type`` is rejected with HTTP 422."""
        payload = {"action": {"action_type": "invalid_type"}}
        response = client.post("/step", json=payload)
        assert response.status_code == 422
class TestSchema:
    """``GET /schema`` contract."""

    def test_schema(self, client: TestClient) -> None:
        """The schema response exposes the action, observation and state schemas."""
        resp = client.get("/schema")
        assert resp.status_code == 200
        data = resp.json()
        # OpenEnv schema endpoint returns keys: action, observation, state.
        # Check all three so the test enforces the documented contract.
        for key in ("action", "observation", "state"):
            assert key in data, f"missing {key!r} in /schema response"
class TestWebSocketSession:
    """Test multi-step sessions through the /ws WebSocket endpoint.

    OpenEnv WS protocol:
        Send: {"type": "reset", "data": {"task_id": "...", "seed": ...}}
        Recv: {"type": "observation", "data": {"observation": {...}, "reward": ..., "done": ...}}
        Send: {"type": "step", "data": {"action_type": "...", ...}}
        Recv: {"type": "observation", "data": {"observation": {...}, ...}}
        Send: {"type": "state"}
        Recv: {"type": "state", "data": {...state fields...}}
    """

    def test_ws_reset_step_finish(self, client: TestClient) -> None:
        """Reset, run one DDI query, then finish the review in one WS session."""
        with client.websocket_connect("/ws") as ws:
            # Reset
            ws.send_json({
                "type": "reset",
                "data": {"task_id": "easy_screening", "seed": 42},
            })
            reset_resp = ws.receive_json()
            assert reset_resp["type"] == "observation"
            reset_data = reset_resp["data"]
            assert reset_data["done"] is False
            obs = reset_data["observation"]
            assert obs["task_id"] == "easy_screening"
            meds = obs["current_medications"]
            # This guarantees at least two drugs for the DDI query below, so
            # no extra length guard is needed before stepping.
            assert len(meds) >= 3

            # Step - query DDI between the first two medications; the
            # episode must still be running afterwards.
            ws.send_json({
                "type": "step",
                "data": {
                    "action_type": "query_ddi",
                    "drug_id_1": meds[0]["drug_id"],
                    "drug_id_2": meds[1]["drug_id"],
                },
            })
            step_resp = ws.receive_json()
            assert step_resp["type"] == "observation"
            assert step_resp["data"]["done"] is False

            # Finish - ending the review must mark the episode as done.
            ws.send_json({
                "type": "step",
                "data": {"action_type": "finish_review"},
            })
            finish_resp = ws.receive_json()
            assert finish_resp["type"] == "observation"
            assert finish_resp["data"]["done"] is True

    def test_ws_state(self, client: TestClient) -> None:
        """After a reset, ``state`` reports step 0 and the selected task."""
        with client.websocket_connect("/ws") as ws:
            ws.send_json({
                "type": "reset",
                "data": {"task_id": "easy_screening", "seed": 0},
            })
            # Check the reset reply so a failed reset is reported here rather
            # than as a confusing mismatch in the state assertions below.
            reset_resp = ws.receive_json()
            assert reset_resp["type"] == "observation"

            ws.send_json({"type": "state"})
            state_resp = ws.receive_json()
            assert state_resp["type"] == "state"
            state_data = state_resp["data"]
            assert state_data["step_count"] == 0
            assert state_data["task_id"] == "easy_screening"
|