# HTTP entry point: builds the FastAPI application, wires the shared
# environment, grader and session manager, and registers the extra routes.
from __future__ import annotations

import os
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from openenv.core import create_fastapi_app
from dotenv import load_dotenv

from llmserve_env.models import ServeAction, ServeObservation
from llmserve_env.task_catalog import get_action_schema, get_task_catalog
from server.baseline_inference import create_local_runner, run_baseline_suite
from server.grader import GraderEngine
from server.llmserve_environment import LLMServeEnvironment
from server.schemas import GraderRequest, ResetRequest
from server.session_manager import SessionManager
from server.web_ui import create_web_app

# Directory one level above the folder that contains this file.
ROOT_DIR = Path(__file__).resolve().parents[1]

# Load "<ROOT_DIR>/.env"; override=False leaves already-set variables alone.
load_dotenv(ROOT_DIR / ".env", override=False)


def _build_shared_env() -> LLMServeEnvironment:
    """Construct the process-wide environment from environment variables.

    ``LLMSERVE_SEED`` (default ``"42"``) is parsed as an integer seed and
    ``LLMSERVE_MODE`` is passed through as-is (``None`` when unset).
    """
    return LLMServeEnvironment(
        seed=int(os.getenv("LLMSERVE_SEED", "42")),
        mode=os.getenv("LLMSERVE_MODE"),
    )


# Module-level singletons shared by every route handler below.
shared_env = _build_shared_env()
grader = GraderEngine()
session_manager = SessionManager()


def get_env() -> LLMServeEnvironment:
    """Return the shared environment (factory handed to ``create_fastapi_app``)."""
    return shared_env


def _remove_routes(app: FastAPI, paths: set[str]) -> None:
    """Drop, in place, every route of *app* whose ``path`` is in *paths*.

    Routes that have no ``path`` attribute are always kept.
    """
    kept = []
    for route in app.router.routes:
        if getattr(route, "path", None) in paths:
            continue
        kept.append(route)
    # Slice-assign so the router keeps using the very same list object.
    app.router.routes[:] = kept


def _register_extra_routes(app: FastAPI) -> FastAPI:
    """Attach the project-specific routes to *app* and return it.

    Handlers are documented with comments rather than docstrings because
    FastAPI publishes a handler's docstring as its OpenAPI description.
    """

    # GET / -> temporary redirect to the web UI.
    @app.get("/")
    def root() -> RedirectResponse:
        return RedirectResponse(url="/web", status_code=307)

    # POST /reset -> create a session and return its most recent observation.
    @app.post("/reset")
    def reset(payload: ResetRequest | None = None) -> dict[str, object]:
        request = payload if payload else ResetRequest()
        session_id, env = session_manager.create(
            task_id=request.task_id,
            seed=request.seed,
            episode_id=request.episode_id,
        )
        latest = env.observations[-1]
        return {
            "session_id": session_id,
            "observation": latest.model_dump(mode="json"),
            "reward": latest.reward,
            "done": latest.done,
            "metadata": latest.metadata,
        }

    # GET /tasks -> task catalog together with the action schema.
    @app.get("/tasks")
    def tasks() -> dict[str, object]:
        catalog = get_task_catalog()
        schema = get_action_schema()
        return {"tasks": catalog, "action_schema": schema}

    # GET /runtime -> backend details of the shared env plus session count.
    @app.get("/runtime")
    def runtime() -> dict[str, object]:
        return {
            "mode": shared_env.backend.mode,
            "backend": shared_env.backend.describe(),
            "seed": shared_env.seed,
            "active_sessions": session_manager.count(),
        }

    # POST /grader -> grade the supplied episode log, or, when none is
    # supplied, the log exported from the shared environment.
    @app.post("/grader")
    def grade(payload: GraderRequest | None = None) -> dict[str, object]:
        supplied_log = payload.episode_log if payload else None
        requested_task = payload.task_id if payload else None
        actions = payload.actions_taken if payload else None

        if supplied_log is not None:
            if requested_task and requested_task != supplied_log.task_id:
                raise HTTPException(
                    status_code=400,
                    detail="task_id does not match episode_log.task_id.",
                )
            return grader.grade(supplied_log, actions_taken=actions)

        if not shared_env.observations:
            raise HTTPException(
                status_code=400,
                detail="No active or completed episode is available to grade.",
            )

        active_log = shared_env.export_episode_log()
        if requested_task and requested_task != active_log.task_id:
            raise HTTPException(
                status_code=400,
                detail="task_id does not match the active episode.",
            )
        return grader.grade(active_log, actions_taken=actions)

    # GET /baseline -> run the baseline suite for one task or for every
    # task in the catalog; a RuntimeError is reported as HTTP 400.
    @app.get("/baseline")
    def baseline(
        task_id: str | None = None,
        use_openai: bool = False,
        model: str = "gpt-4.1-mini",
        seed: int = 42,
    ) -> dict[str, object]:
        if task_id:
            task_ids = [task_id]
        else:
            task_ids = [task["id"] for task in get_task_catalog()]

        mode = "openai" if use_openai else "deterministic"

        def runner_factory():
            # LLMSERVE_MODE is read each time a runner is built, and only
            # on the OpenAI path; the deterministic path always uses "sim".
            runner_mode = os.getenv("LLMSERVE_MODE", "sim") if use_openai else "sim"
            return create_local_runner(seed=seed, mode=runner_mode)

        try:
            return run_baseline_suite(
                mode=mode,
                task_ids=task_ids,
                seed=seed,
                model=model,
                runner_factory=runner_factory,
            )
        except RuntimeError as exc:
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # GET /demo -> temporary redirect to the web UI.
    @app.get("/demo")
    def demo() -> RedirectResponse:
        return RedirectResponse(url="/web", status_code=307)

    return app


def create_application(enable_web: bool = True) -> FastAPI:
    """Build the FastAPI application.

    Starts from ``create_fastapi_app``, removes any route at ``/reset`` so
    the session-based handler registered here takes its place, optionally
    wraps the app with the web UI, then registers the extra routes.
    """
    application = create_fastapi_app(
        get_env,
        ServeAction,
        ServeObservation,
    )
    _remove_routes(application, {"/reset"})
    if enable_web:
        application = create_web_app(application, session_manager, shared_env)
    return _register_extra_routes(application)


def create_test_application() -> FastAPI:
    """Build the application without the web UI."""
    return create_application(enable_web=False)


app = create_application(enable_web=True)


def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve ``app`` with uvicorn on *host*:*port*."""
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()