vegarl / server /app.py
ronitraj's picture
fix: accept empty-body reset requests
0d53691
from __future__ import annotations
import os
from pathlib import Path
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from openenv.core import create_fastapi_app
from dotenv import load_dotenv
from llmserve_env.models import ServeAction, ServeObservation
from llmserve_env.task_catalog import get_action_schema, get_task_catalog
from server.baseline_inference import create_local_runner, run_baseline_suite
from server.grader import GraderEngine
from server.llmserve_environment import LLMServeEnvironment
from server.schemas import GraderRequest, ResetRequest
from server.session_manager import SessionManager
from server.web_ui import create_web_app
ROOT_DIR = Path(__file__).resolve().parents[1]
load_dotenv(ROOT_DIR / ".env", override=False)
def _build_shared_env() -> LLMServeEnvironment:
seed = int(os.getenv("LLMSERVE_SEED", "42"))
mode = os.getenv("LLMSERVE_MODE")
return LLMServeEnvironment(seed=seed, mode=mode)
shared_env = _build_shared_env()
grader = GraderEngine()
session_manager = SessionManager()
def get_env() -> LLMServeEnvironment:
return shared_env
def _remove_routes(app: FastAPI, paths: set[str]) -> None:
app.router.routes[:] = [route for route in app.router.routes if getattr(route, "path", None) not in paths]
def _register_extra_routes(app: FastAPI) -> FastAPI:
@app.get("/")
def root() -> RedirectResponse:
return RedirectResponse(url="/web", status_code=307)
@app.post("/reset")
def reset(payload: ResetRequest | None = None) -> dict[str, object]:
payload = payload or ResetRequest()
session_id, env = session_manager.create(
task_id=payload.task_id,
seed=payload.seed,
episode_id=payload.episode_id,
)
observation = env.observations[-1]
return {
"session_id": session_id,
"observation": observation.model_dump(mode="json"),
"reward": observation.reward,
"done": observation.done,
"metadata": observation.metadata,
}
@app.get("/tasks")
def tasks() -> dict[str, object]:
return {"tasks": get_task_catalog(), "action_schema": get_action_schema()}
@app.get("/runtime")
def runtime() -> dict[str, object]:
return {
"mode": shared_env.backend.mode,
"backend": shared_env.backend.describe(),
"seed": shared_env.seed,
"active_sessions": session_manager.count(),
}
@app.post("/grader")
def grade(payload: GraderRequest | None = None) -> dict[str, object]:
if payload and payload.episode_log is not None:
if payload.task_id and payload.task_id != payload.episode_log.task_id:
raise HTTPException(status_code=400, detail="task_id does not match episode_log.task_id.")
return grader.grade(payload.episode_log, actions_taken=payload.actions_taken)
if not shared_env.observations:
raise HTTPException(status_code=400, detail="No active or completed episode is available to grade.")
current_log = shared_env.export_episode_log()
if payload and payload.task_id and payload.task_id != current_log.task_id:
raise HTTPException(status_code=400, detail="task_id does not match the active episode.")
return grader.grade(current_log, actions_taken=payload.actions_taken if payload else None)
@app.get("/baseline")
def baseline(
task_id: str | None = None,
use_openai: bool = False,
model: str = "gpt-4.1-mini",
seed: int = 42,
) -> dict[str, object]:
task_ids = [task_id] if task_id else [task["id"] for task in get_task_catalog()]
mode = "openai" if use_openai else "deterministic"
try:
runner_factory = (
(lambda: create_local_runner(seed=seed, mode=os.getenv("LLMSERVE_MODE", "sim")))
if use_openai
else (lambda: create_local_runner(seed=seed, mode="sim"))
)
return run_baseline_suite(
mode=mode,
task_ids=task_ids,
seed=seed,
model=model,
runner_factory=runner_factory,
)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
@app.get("/demo")
def demo() -> RedirectResponse:
return RedirectResponse(url="/web", status_code=307)
return app
def create_application(enable_web: bool = True) -> FastAPI:
app = create_fastapi_app(
get_env,
ServeAction,
ServeObservation,
)
_remove_routes(app, {"/reset"})
if enable_web:
app = create_web_app(app, session_manager, shared_env)
return _register_extra_routes(app)
def create_test_application() -> FastAPI:
return create_application(enable_web=False)
app = create_application(enable_web=True)
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
import uvicorn
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
main()