Spaces:
Sleeping
Sleeping
Prasham.Jain
refactor(branch-a): A1 — adopt openenv-core MCPEnvironment + create_app for canonical OpenEnv compliance
4c1ad51 | """Phase A1 server tests against the OpenEnv-canonical wire format. | |
| OpenEnv's HTTP ``/reset`` and ``/step`` endpoints are stateless (a fresh env | |
| instance is constructed per request). State-preserving multi-step flows go | |
| over the WebSocket ``/ws`` endpoint, which is the canonical OpenEnv session | |
| path used by ``EnvClient``. Stateful behavior is therefore validated over | |
| ``/ws``; HTTP endpoints are validated for shape and error handling. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| from ci_triage_env.env.server import build_app | |
| from ci_triage_env.env.wire import CITriageAction | |
| from ci_triage_env.schemas.diagnosis import DiagnosisLabel | |
| def _assert_observation_envelope(payload: dict) -> None: | |
| assert "observation" in payload | |
| assert "reward" in payload | |
| assert "done" in payload | |
| # --------------------------------------------------------------------------- | |
| # A1.1 Server boot | |
| # --------------------------------------------------------------------------- | |
| def test_server_boots(env_factory): | |
| app = build_app(env_factory=env_factory) | |
| assert app.title # FastAPI app instantiated | |
| # --------------------------------------------------------------------------- | |
| # A1.2 / A1.3 HTTP /reset shape | |
| # --------------------------------------------------------------------------- | |
| def test_reset_returns_valid_observation(client): | |
| resp = client.post("/reset", json={}) | |
| assert resp.status_code == 200, resp.text | |
| body = resp.json() | |
| _assert_observation_envelope(body) | |
| assert body["done"] is False | |
| obs_payload = body["observation"]["payload"] | |
| assert obs_payload["failure_summary"] is not None | |
| def test_reset_with_specific_scenario_id(client, known_scenario_id): | |
| resp = client.post("/reset", json={"scenario_id": known_scenario_id}) | |
| assert resp.status_code == 200, resp.text | |
| body = resp.json() | |
| obs_payload = body["observation"]["payload"] | |
| assert obs_payload["episode_id"] | |
| # --------------------------------------------------------------------------- | |
| # A1.4 Unknown scenario surfaces as a server error | |
| # --------------------------------------------------------------------------- | |
| def test_reset_with_unknown_scenario_id_errors(app): | |
| from fastapi.testclient import TestClient | |
| nonraising = TestClient(app, raise_server_exceptions=False) | |
| resp = nonraising.post("/reset", json={"scenario_id": "does-not-exist"}) | |
| assert resp.status_code >= 400 | |
| # --------------------------------------------------------------------------- | |
| # A1.5–A1.8: Stateful flows over WebSocket /ws | |
| # --------------------------------------------------------------------------- | |
| def _ws_send_recv(ws, message: dict) -> dict: | |
| ws.send_text(json.dumps(message)) | |
| return json.loads(ws.receive_text()) | |
| def test_step_with_tool_call_returns_observation(client, known_scenario_id): | |
| with client.websocket_connect("/ws") as ws: | |
| _ws_send_recv(ws, {"type": "reset", "data": {"scenario_id": known_scenario_id}}) | |
| action = CITriageAction.from_tool_call("read_logs", {"scope": "test"}) | |
| resp = _ws_send_recv(ws, {"type": "step", "data": action.model_dump()}) | |
| assert resp["type"] == "observation" | |
| payload = resp["data"]["observation"]["payload"] | |
| assert payload["tool_response"] is not None | |
| assert payload["tool_response"]["tool_name"] == "read_logs" | |
| assert resp["data"]["done"] is False | |
| def test_step_with_terminal_action_marks_done(client, known_scenario_id): | |
| with client.websocket_connect("/ws") as ws: | |
| _ws_send_recv(ws, {"type": "reset", "data": {"scenario_id": known_scenario_id}}) | |
| terminal = CITriageAction.from_terminal(DiagnosisLabel.RACE_FLAKE, confidence=0.8) | |
| step = _ws_send_recv(ws, {"type": "step", "data": terminal.model_dump()}) | |
| assert step["data"]["done"] is True | |
| assert step["data"]["observation"]["payload"]["is_terminal"] is True | |
| state = _ws_send_recv(ws, {"type": "state"}) | |
| assert state["type"] == "state" | |
| assert state["data"]["payload"]["is_terminated"] is True | |
| assert state["data"]["payload"]["final_action"] is not None | |
| def test_step_after_terminal_returns_error(client, known_scenario_id): | |
| with client.websocket_connect("/ws") as ws: | |
| _ws_send_recv(ws, {"type": "reset", "data": {"scenario_id": known_scenario_id}}) | |
| terminal = CITriageAction.from_terminal(DiagnosisLabel.RACE_FLAKE, confidence=0.8) | |
| _ws_send_recv(ws, {"type": "step", "data": terminal.model_dump()}) | |
| again = _ws_send_recv(ws, {"type": "step", "data": terminal.model_dump()}) | |
| assert again["type"] == "error" | |
| def test_state_endpoint_returns_episode_state(client, known_scenario_id): | |
| with client.websocket_connect("/ws") as ws: | |
| _ws_send_recv(ws, {"type": "reset", "data": {"scenario_id": known_scenario_id}}) | |
| state = _ws_send_recv(ws, {"type": "state"}) | |
| assert state["type"] == "state" | |
| payload = state["data"]["payload"] | |
| assert payload is not None | |
| assert payload["scenario_id"] == known_scenario_id | |
| assert payload["is_terminated"] is False | |
| # --------------------------------------------------------------------------- | |
| # A1.9 Distinct episode_ids across concurrent sessions | |
| # --------------------------------------------------------------------------- | |
| def test_concurrent_ws_sessions_get_distinct_episode_ids(client, known_scenario_id): | |
| episode_ids: list[str] = [] | |
| with client.websocket_connect("/ws") as ws_a, client.websocket_connect("/ws") as ws_b: | |
| for ws in (ws_a, ws_b): | |
| obs = _ws_send_recv(ws, {"type": "reset", "data": {"scenario_id": known_scenario_id}}) | |
| episode_ids.append(obs["data"]["observation"]["payload"]["episode_id"]) | |
| assert episode_ids[0] != episode_ids[1] | |
| # --------------------------------------------------------------------------- | |
| # A1.10 MCP /mcp tools/list — canonical OpenEnv MCP discovery | |
| # --------------------------------------------------------------------------- | |
| EXPECTED_TOOL_NAMES = { | |
| "read_logs", | |
| "inspect_test_code", | |
| "run_diagnostic", | |
| "cluster_metrics", | |
| "query_flake_history", | |
| "recent_commits", | |
| "check_owner", | |
| "rerun_test", | |
| "quarantine_test", | |
| "file_bug", | |
| "ping_owner", | |
| } | |
| def test_mcp_endpoint_lists_all_11_tools(client): | |
| # Establish an MCP session first so subsequent tools/list shares it. | |
| create = client.post( | |
| "/mcp", | |
| json={"jsonrpc": "2.0", "id": "1", "method": "openenv/session/create", "params": {}}, | |
| ) | |
| assert create.status_code == 200, create.text | |
| session_id = create.json()["result"]["session_id"] | |
| resp = client.post( | |
| "/mcp", | |
| json={"jsonrpc": "2.0", "id": "2", "method": "tools/list", "params": {}}, | |
| headers={"X-OpenEnv-Session-Id": session_id}, | |
| ) | |
| assert resp.status_code == 200, resp.text | |
| body = resp.json() | |
| assert "result" in body, body | |
| tools = body["result"]["tools"] | |
| names = {t["name"] for t in tools} | |
| assert names == EXPECTED_TOOL_NAMES | |
| assert len(tools) == 11 | |
| # --------------------------------------------------------------------------- | |
| # A1.11 Deterministic seeding | |
| # --------------------------------------------------------------------------- | |
| def test_episode_seeding_deterministic(client, known_scenario_id): | |
| def run_one() -> dict: | |
| with client.websocket_connect("/ws") as ws: | |
| _ws_send_recv( | |
| ws, | |
| { | |
| "type": "reset", | |
| "data": {"scenario_id": known_scenario_id, "seed": 12345}, | |
| }, | |
| ) | |
| action = CITriageAction.from_tool_call("read_logs", {"scope": "test"}) | |
| _ws_send_recv(ws, {"type": "step", "data": action.model_dump()}) | |
| terminal = CITriageAction.from_terminal(DiagnosisLabel.RACE_FLAKE, confidence=0.6) | |
| _ws_send_recv(ws, {"type": "step", "data": terminal.model_dump()}) | |
| return _ws_send_recv(ws, {"type": "state"})["data"]["payload"] | |
| a = run_one() | |
| b = run_one() | |
| assert a["seed"] == b["seed"] == 12345 | |
| assert a["step"] == b["step"] | |
| assert a["is_terminated"] and b["is_terminated"] | |
| assert [r["action"] for r in a["history"]] == [r["action"] for r in b["history"]] | |
| assert [r["cost_charged"] for r in a["history"]] == [r["cost_charged"] for r in b["history"]] | |
| # --------------------------------------------------------------------------- | |
| # A1.12 /schema is exposed (OpenEnv standard surface) | |
| # --------------------------------------------------------------------------- | |
| def test_schema_endpoint_returns_action_observation_state(client): | |
| resp = client.get("/schema") | |
| assert resp.status_code == 200, resp.text | |
| body = resp.json() | |
| assert "action" in body | |
| assert "observation" in body | |