refactor: remove unused routes and clean up API endpoints
Browse files
- server/app.py +2 -51
- tests/test_api.py +12 -30
server/app.py
CHANGED
|
@@ -2,9 +2,8 @@ from __future__ import annotations
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
| 5 |
-
from typing import Iterable
|
| 6 |
|
| 7 |
-
from fastapi import FastAPI, HTTPException
|
| 8 |
from fastapi.responses import RedirectResponse
|
| 9 |
from openenv.core import create_fastapi_app
|
| 10 |
from dotenv import load_dotenv
|
|
@@ -14,7 +13,7 @@ from llmserve_env.task_catalog import get_action_schema, get_task_catalog
|
|
| 14 |
from server.baseline_inference import create_local_runner, run_baseline_suite
|
| 15 |
from server.grader import GraderEngine
|
| 16 |
from server.llmserve_environment import LLMServeEnvironment
|
| 17 |
-
from server.schemas import GraderRequest
|
| 18 |
from server.session_manager import SessionManager
|
| 19 |
from server.web_ui import create_web_app
|
| 20 |
|
|
@@ -38,20 +37,7 @@ def get_env() -> LLMServeEnvironment:
|
|
| 38 |
return shared_env
|
| 39 |
|
| 40 |
|
| 41 |
-
def _remove_routes(app: FastAPI, paths: Iterable[str]) -> None:
|
| 42 |
-
blocked = set(paths)
|
| 43 |
-
app.router.routes[:] = [route for route in app.router.routes if getattr(route, "path", None) not in blocked]
|
| 44 |
-
|
| 45 |
-
|
| 46 |
def _register_extra_routes(app: FastAPI) -> FastAPI:
|
| 47 |
-
def _resolve_env(session_id: str | None) -> LLMServeEnvironment:
|
| 48 |
-
if not session_id:
|
| 49 |
-
return shared_env
|
| 50 |
-
try:
|
| 51 |
-
return session_manager.get(session_id)
|
| 52 |
-
except KeyError as exc:
|
| 53 |
-
raise HTTPException(status_code=404, detail=str(exc)) from exc
|
| 54 |
-
|
| 55 |
@app.get("/")
|
| 56 |
def root() -> RedirectResponse:
|
| 57 |
return RedirectResponse(url="/web", status_code=307)
|
|
@@ -69,40 +55,6 @@ def _register_extra_routes(app: FastAPI) -> FastAPI:
|
|
| 69 |
"active_sessions": session_manager.count(),
|
| 70 |
}
|
| 71 |
|
| 72 |
-
@app.post("/reset")
|
| 73 |
-
def reset(payload: ResetRequest) -> dict[str, object]:
|
| 74 |
-
session_id, env = session_manager.create(
|
| 75 |
-
task_id=payload.task_id,
|
| 76 |
-
seed=payload.seed,
|
| 77 |
-
episode_id=payload.episode_id,
|
| 78 |
-
)
|
| 79 |
-
observation = env.observations[-1]
|
| 80 |
-
|
| 81 |
-
return {
|
| 82 |
-
"session_id": session_id,
|
| 83 |
-
"observation": observation.model_dump(mode="json"),
|
| 84 |
-
"reward": observation.reward,
|
| 85 |
-
"done": observation.done,
|
| 86 |
-
"metadata": observation.metadata,
|
| 87 |
-
}
|
| 88 |
-
|
| 89 |
-
@app.post("/step")
|
| 90 |
-
def step(payload: StepRequest) -> dict[str, object]:
|
| 91 |
-
env = _resolve_env(payload.session_id)
|
| 92 |
-
observation = env.step(payload.action)
|
| 93 |
-
return {
|
| 94 |
-
"session_id": payload.session_id or env.state.episode_id,
|
| 95 |
-
"observation": observation.model_dump(mode="json"),
|
| 96 |
-
"reward": observation.reward,
|
| 97 |
-
"done": observation.done,
|
| 98 |
-
"metadata": observation.metadata,
|
| 99 |
-
}
|
| 100 |
-
|
| 101 |
-
@app.get("/state")
|
| 102 |
-
def state(session_id: str | None = Query(default=None)) -> dict[str, object]:
|
| 103 |
-
env = _resolve_env(session_id)
|
| 104 |
-
return env.state.model_dump(mode="json")
|
| 105 |
-
|
| 106 |
@app.post("/grader")
|
| 107 |
def grade(payload: GraderRequest | None = None) -> dict[str, object]:
|
| 108 |
if payload and payload.episode_log is not None:
|
|
@@ -154,7 +106,6 @@ def create_application(enable_web: bool = True) -> FastAPI:
|
|
| 154 |
ServeAction,
|
| 155 |
ServeObservation,
|
| 156 |
)
|
| 157 |
-
_remove_routes(app, {"/reset", "/step", "/state"})
|
| 158 |
if enable_web:
|
| 159 |
app = create_web_app(app, session_manager, shared_env)
|
| 160 |
return _register_extra_routes(app)
|
|
|
|
| 2 |
|
| 3 |
import os
|
| 4 |
from pathlib import Path
|
|
|
|
| 5 |
|
| 6 |
+
from fastapi import FastAPI, HTTPException
|
| 7 |
from fastapi.responses import RedirectResponse
|
| 8 |
from openenv.core import create_fastapi_app
|
| 9 |
from dotenv import load_dotenv
|
|
|
|
| 13 |
from server.baseline_inference import create_local_runner, run_baseline_suite
|
| 14 |
from server.grader import GraderEngine
|
| 15 |
from server.llmserve_environment import LLMServeEnvironment
|
| 16 |
+
from server.schemas import GraderRequest
|
| 17 |
from server.session_manager import SessionManager
|
| 18 |
from server.web_ui import create_web_app
|
| 19 |
|
|
|
|
| 37 |
return shared_env
|
| 38 |
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
def _register_extra_routes(app: FastAPI) -> FastAPI:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
@app.get("/")
|
| 42 |
def root() -> RedirectResponse:
|
| 43 |
return RedirectResponse(url="/web", status_code=307)
|
|
|
|
| 55 |
"active_sessions": session_manager.count(),
|
| 56 |
}
|
| 57 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
@app.post("/grader")
|
| 59 |
def grade(payload: GraderRequest | None = None) -> dict[str, object]:
|
| 60 |
if payload and payload.episode_log is not None:
|
|
|
|
| 106 |
ServeAction,
|
| 107 |
ServeObservation,
|
| 108 |
)
|
|
|
|
| 109 |
if enable_web:
|
| 110 |
app = create_web_app(app, session_manager, shared_env)
|
| 111 |
return _register_extra_routes(app)
|
tests/test_api.py
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import asyncio
|
|
|
|
| 4 |
import pytest
|
| 5 |
from fastapi import HTTPException
|
|
|
|
| 6 |
|
| 7 |
from server.app import create_application, shared_env
|
| 8 |
-
from server.schemas import ResetRequest, StepRequest
|
| 9 |
|
| 10 |
|
| 11 |
def _route_map():
|
|
@@ -34,6 +35,16 @@ def test_session_routes_are_not_duplicated() -> None:
|
|
| 34 |
assert paths.count("/state") == 1
|
| 35 |
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
def test_health_endpoint_direct() -> None:
|
| 38 |
data = _call(_route_map()["/health"])
|
| 39 |
status = data["status"] if isinstance(data, dict) else data.status
|
|
@@ -64,32 +75,3 @@ def test_baseline_endpoint_direct() -> None:
|
|
| 64 |
def test_demo_redirects_to_web() -> None:
|
| 65 |
response = _call(_route_map()["/demo"])
|
| 66 |
assert response.headers["location"] == "/web"
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
def test_http_session_advances_across_multiple_steps() -> None:
|
| 70 |
-
routes = _route_map()
|
| 71 |
-
reset_endpoint = routes["/reset"]
|
| 72 |
-
step_endpoint = routes["/step"]
|
| 73 |
-
state_endpoint = routes["/state"]
|
| 74 |
-
|
| 75 |
-
reset_payload = _call(reset_endpoint, ResetRequest(task_id="bursty_workload", seed=42))
|
| 76 |
-
session_id = reset_payload["session_id"]
|
| 77 |
-
assert session_id
|
| 78 |
-
assert reset_payload["observation"]["step_index"] == 0
|
| 79 |
-
|
| 80 |
-
action = {
|
| 81 |
-
"batch_cap": 32,
|
| 82 |
-
"kv_budget_fraction": 1.0,
|
| 83 |
-
"speculation_depth": 0,
|
| 84 |
-
"quantization_tier": "FP16",
|
| 85 |
-
"prefill_decode_split": False,
|
| 86 |
-
"priority_routing": False,
|
| 87 |
-
}
|
| 88 |
-
first_payload = _call(step_endpoint, StepRequest(session_id=session_id, action=action))
|
| 89 |
-
second_payload = _call(step_endpoint, StepRequest(session_id=session_id, action=action))
|
| 90 |
-
assert first_payload["observation"]["step_index"] == 1
|
| 91 |
-
assert second_payload["observation"]["step_index"] == 2
|
| 92 |
-
assert first_payload["reward"] != second_payload["reward"]
|
| 93 |
-
|
| 94 |
-
state_payload = _call(state_endpoint, session_id=session_id)
|
| 95 |
-
assert state_payload["step_count"] == 2
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
import asyncio
|
| 4 |
+
import os
|
| 5 |
import pytest
|
| 6 |
from fastapi import HTTPException
|
| 7 |
+
from fastapi.testclient import TestClient
|
| 8 |
|
| 9 |
from server.app import create_application, shared_env
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
def _route_map():
|
|
|
|
| 35 |
assert paths.count("/state") == 1
|
| 36 |
|
| 37 |
|
| 38 |
+
def test_reset_endpoint_accepts_empty_json_body() -> None:
|
| 39 |
+
os.environ.setdefault("MPLCONFIGDIR", "/tmp")
|
| 40 |
+
client = TestClient(create_application(enable_web=False))
|
| 41 |
+
response = client.post("/reset", json={})
|
| 42 |
+
assert response.status_code == 200
|
| 43 |
+
payload = response.json()
|
| 44 |
+
assert payload["observation"]["task_id"] == "static_workload"
|
| 45 |
+
assert payload["observation"]["step_index"] == 0
|
| 46 |
+
|
| 47 |
+
|
| 48 |
def test_health_endpoint_direct() -> None:
|
| 49 |
data = _call(_route_map()["/health"])
|
| 50 |
status = data["status"] if isinstance(data, dict) else data.status
|
|
|
|
| 75 |
def test_demo_redirects_to_web() -> None:
|
| 76 |
response = _call(_route_map()["/demo"])
|
| 77 |
assert response.headers["location"] == "/web"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|