ronitraj commited on
Commit
9b64226
·
1 Parent(s): 1d6826f

refactor: remove unused routes and clean up API endpoints

Browse files
Files changed (2) hide show
  1. server/app.py +2 -51
  2. tests/test_api.py +12 -30
server/app.py CHANGED
@@ -2,9 +2,8 @@ from __future__ import annotations
2
 
3
  import os
4
  from pathlib import Path
5
- from typing import Iterable
6
 
7
- from fastapi import FastAPI, HTTPException, Query
8
  from fastapi.responses import RedirectResponse
9
  from openenv.core import create_fastapi_app
10
  from dotenv import load_dotenv
@@ -14,7 +13,7 @@ from llmserve_env.task_catalog import get_action_schema, get_task_catalog
14
  from server.baseline_inference import create_local_runner, run_baseline_suite
15
  from server.grader import GraderEngine
16
  from server.llmserve_environment import LLMServeEnvironment
17
- from server.schemas import GraderRequest, ResetRequest, StepRequest
18
  from server.session_manager import SessionManager
19
  from server.web_ui import create_web_app
20
 
@@ -38,20 +37,7 @@ def get_env() -> LLMServeEnvironment:
38
  return shared_env
39
 
40
 
41
- def _remove_routes(app: FastAPI, paths: Iterable[str]) -> None:
42
- blocked = set(paths)
43
- app.router.routes[:] = [route for route in app.router.routes if getattr(route, "path", None) not in blocked]
44
-
45
-
46
  def _register_extra_routes(app: FastAPI) -> FastAPI:
47
- def _resolve_env(session_id: str | None) -> LLMServeEnvironment:
48
- if not session_id:
49
- return shared_env
50
- try:
51
- return session_manager.get(session_id)
52
- except KeyError as exc:
53
- raise HTTPException(status_code=404, detail=str(exc)) from exc
54
-
55
  @app.get("/")
56
  def root() -> RedirectResponse:
57
  return RedirectResponse(url="/web", status_code=307)
@@ -69,40 +55,6 @@ def _register_extra_routes(app: FastAPI) -> FastAPI:
69
  "active_sessions": session_manager.count(),
70
  }
71
 
72
- @app.post("/reset")
73
- def reset(payload: ResetRequest) -> dict[str, object]:
74
- session_id, env = session_manager.create(
75
- task_id=payload.task_id,
76
- seed=payload.seed,
77
- episode_id=payload.episode_id,
78
- )
79
- observation = env.observations[-1]
80
-
81
- return {
82
- "session_id": session_id,
83
- "observation": observation.model_dump(mode="json"),
84
- "reward": observation.reward,
85
- "done": observation.done,
86
- "metadata": observation.metadata,
87
- }
88
-
89
- @app.post("/step")
90
- def step(payload: StepRequest) -> dict[str, object]:
91
- env = _resolve_env(payload.session_id)
92
- observation = env.step(payload.action)
93
- return {
94
- "session_id": payload.session_id or env.state.episode_id,
95
- "observation": observation.model_dump(mode="json"),
96
- "reward": observation.reward,
97
- "done": observation.done,
98
- "metadata": observation.metadata,
99
- }
100
-
101
- @app.get("/state")
102
- def state(session_id: str | None = Query(default=None)) -> dict[str, object]:
103
- env = _resolve_env(session_id)
104
- return env.state.model_dump(mode="json")
105
-
106
  @app.post("/grader")
107
  def grade(payload: GraderRequest | None = None) -> dict[str, object]:
108
  if payload and payload.episode_log is not None:
@@ -154,7 +106,6 @@ def create_application(enable_web: bool = True) -> FastAPI:
154
  ServeAction,
155
  ServeObservation,
156
  )
157
- _remove_routes(app, {"/reset", "/step", "/state"})
158
  if enable_web:
159
  app = create_web_app(app, session_manager, shared_env)
160
  return _register_extra_routes(app)
 
2
 
3
  import os
4
  from pathlib import Path
 
5
 
6
+ from fastapi import FastAPI, HTTPException
7
  from fastapi.responses import RedirectResponse
8
  from openenv.core import create_fastapi_app
9
  from dotenv import load_dotenv
 
13
  from server.baseline_inference import create_local_runner, run_baseline_suite
14
  from server.grader import GraderEngine
15
  from server.llmserve_environment import LLMServeEnvironment
16
+ from server.schemas import GraderRequest
17
  from server.session_manager import SessionManager
18
  from server.web_ui import create_web_app
19
 
 
37
  return shared_env
38
 
39
 
 
 
 
 
 
40
  def _register_extra_routes(app: FastAPI) -> FastAPI:
 
 
 
 
 
 
 
 
41
  @app.get("/")
42
  def root() -> RedirectResponse:
43
  return RedirectResponse(url="/web", status_code=307)
 
55
  "active_sessions": session_manager.count(),
56
  }
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  @app.post("/grader")
59
  def grade(payload: GraderRequest | None = None) -> dict[str, object]:
60
  if payload and payload.episode_log is not None:
 
106
  ServeAction,
107
  ServeObservation,
108
  )
 
109
  if enable_web:
110
  app = create_web_app(app, session_manager, shared_env)
111
  return _register_extra_routes(app)
tests/test_api.py CHANGED
@@ -1,11 +1,12 @@
1
  from __future__ import annotations
2
 
3
  import asyncio
 
4
  import pytest
5
  from fastapi import HTTPException
 
6
 
7
  from server.app import create_application, shared_env
8
- from server.schemas import ResetRequest, StepRequest
9
 
10
 
11
  def _route_map():
@@ -34,6 +35,16 @@ def test_session_routes_are_not_duplicated() -> None:
34
  assert paths.count("/state") == 1
35
 
36
 
 
 
 
 
 
 
 
 
 
 
37
  def test_health_endpoint_direct() -> None:
38
  data = _call(_route_map()["/health"])
39
  status = data["status"] if isinstance(data, dict) else data.status
@@ -64,32 +75,3 @@ def test_baseline_endpoint_direct() -> None:
64
  def test_demo_redirects_to_web() -> None:
65
  response = _call(_route_map()["/demo"])
66
  assert response.headers["location"] == "/web"
67
-
68
-
69
- def test_http_session_advances_across_multiple_steps() -> None:
70
- routes = _route_map()
71
- reset_endpoint = routes["/reset"]
72
- step_endpoint = routes["/step"]
73
- state_endpoint = routes["/state"]
74
-
75
- reset_payload = _call(reset_endpoint, ResetRequest(task_id="bursty_workload", seed=42))
76
- session_id = reset_payload["session_id"]
77
- assert session_id
78
- assert reset_payload["observation"]["step_index"] == 0
79
-
80
- action = {
81
- "batch_cap": 32,
82
- "kv_budget_fraction": 1.0,
83
- "speculation_depth": 0,
84
- "quantization_tier": "FP16",
85
- "prefill_decode_split": False,
86
- "priority_routing": False,
87
- }
88
- first_payload = _call(step_endpoint, StepRequest(session_id=session_id, action=action))
89
- second_payload = _call(step_endpoint, StepRequest(session_id=session_id, action=action))
90
- assert first_payload["observation"]["step_index"] == 1
91
- assert second_payload["observation"]["step_index"] == 2
92
- assert first_payload["reward"] != second_payload["reward"]
93
-
94
- state_payload = _call(state_endpoint, session_id=session_id)
95
- assert state_payload["step_count"] == 2
 
1
  from __future__ import annotations
2
 
3
  import asyncio
4
+ import os
5
  import pytest
6
  from fastapi import HTTPException
7
+ from fastapi.testclient import TestClient
8
 
9
  from server.app import create_application, shared_env
 
10
 
11
 
12
  def _route_map():
 
35
  assert paths.count("/state") == 1
36
 
37
 
38
+ def test_reset_endpoint_accepts_empty_json_body() -> None:
39
+ os.environ.setdefault("MPLCONFIGDIR", "/tmp")
40
+ client = TestClient(create_application(enable_web=False))
41
+ response = client.post("/reset", json={})
42
+ assert response.status_code == 200
43
+ payload = response.json()
44
+ assert payload["observation"]["task_id"] == "static_workload"
45
+ assert payload["observation"]["step_index"] == 0
46
+
47
+
48
  def test_health_endpoint_direct() -> None:
49
  data = _call(_route_map()["/health"])
50
  status = data["status"] if isinstance(data, dict) else data.status
 
75
  def test_demo_redirects_to_web() -> None:
76
  response = _call(_route_map()["/demo"])
77
  assert response.headers["location"] == "/web"