from __future__ import annotations
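
# FastAPI entry point for the LLMServe OpenEnv environment: wires the shared
# environment, session manager, grader, and web UI into a single app.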
import os
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from openenv.core import create_fastapi_app

from llmserve_env.models import ServeAction, ServeObservation
from llmserve_env.task_catalog import get_action_schema, get_task_catalog
from server.baseline_inference import create_local_runner, run_baseline_suite
from server.grader import GraderEngine
from server.llmserve_environment import LLMServeEnvironment
from server.schemas import GraderRequest, ResetRequest
from server.session_manager import SessionManager
from server.web_ui import create_web_app

ROOT_DIR = Path(__file__).resolve().parents[1]
load_dotenv(ROOT_DIR / ".env", override=False)
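
# Runtime configuration comes from the environment (optionally via the .env
# file loaded above). A sketch of plausible contents; "sim" is the fallback
# used by /baseline below, other LLMSERVE_MODE values are assumptions:
#
#   LLMSERVE_SEED=42
#   LLMSERVE_MODE=sim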
def _build_shared_env() -> LLMServeEnvironment:
    """Build the process-wide environment from LLMSERVE_* settings."""
    seed = int(os.getenv("LLMSERVE_SEED", "42"))
    mode = os.getenv("LLMSERVE_MODE")
    return LLMServeEnvironment(seed=seed, mode=mode)


# Singletons shared by the OpenEnv routes, the grader endpoint, and the web UI.
shared_env = _build_shared_env()
grader = GraderEngine()
session_manager = SessionManager()


def get_env() -> LLMServeEnvironment:
    return shared_env


def _remove_routes(app: FastAPI, paths: set[str]) -> None:
    """Drop every generated route whose path is in ``paths``."""
    app.router.routes[:] = [
        route for route in app.router.routes if getattr(route, "path", None) not in paths
    ]


def _register_extra_routes(app: FastAPI) -> FastAPI:
    @app.get("/")
    def root() -> RedirectResponse:
        return RedirectResponse(url="/web", status_code=307)
@app.post("/reset")
def reset(payload: ResetRequest | None = None) -> dict[str, object]:
payload = payload or ResetRequest()
session_id, env = session_manager.create(
task_id=payload.task_id,
seed=payload.seed,
episode_id=payload.episode_id,
)
observation = env.observations[-1]
return {
"session_id": session_id,
"observation": observation.model_dump(mode="json"),
"reward": observation.reward,
"done": observation.done,
"metadata": observation.metadata,
}
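
    # Read-only metadata: the task catalog plus the action schema clients
    # need to build ServeAction payloads.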
@app.get("/tasks")
def tasks() -> dict[str, object]:
return {"tasks": get_task_catalog(), "action_schema": get_action_schema()}
@app.get("/runtime")
def runtime() -> dict[str, object]:
return {
"mode": shared_env.backend.mode,
"backend": shared_env.backend.describe(),
"seed": shared_env.seed,
"active_sessions": session_manager.count(),
}
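
    # Grades either a client-supplied episode log or, when no log is given,
    # the episode currently held by the shared environment. A bare request
    # grades the current episode:
    #   curl -X POST http://localhost:7860/grader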
@app.post("/grader")
def grade(payload: GraderRequest | None = None) -> dict[str, object]:
if payload and payload.episode_log is not None:
if payload.task_id and payload.task_id != payload.episode_log.task_id:
raise HTTPException(status_code=400, detail="task_id does not match episode_log.task_id.")
return grader.grade(payload.episode_log, actions_taken=payload.actions_taken)
if not shared_env.observations:
raise HTTPException(status_code=400, detail="No active or completed episode is available to grade.")
current_log = shared_env.export_episode_log()
if payload and payload.task_id and payload.task_id != current_log.task_id:
raise HTTPException(status_code=400, detail="task_id does not match the active episode.")
return grader.grade(current_log, actions_taken=payload.actions_taken if payload else None)
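
    # Runs the scripted baseline suite over one task or the whole catalog.
    # Query parameters mirror the handler signature below, for example:
    #   curl "http://localhost:7860/baseline?use_openai=true&model=gpt-4.1-mini"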
@app.get("/baseline")
def baseline(
task_id: str | None = None,
use_openai: bool = False,
model: str = "gpt-4.1-mini",
seed: int = 42,
) -> dict[str, object]:
task_ids = [task_id] if task_id else [task["id"] for task in get_task_catalog()]
mode = "openai" if use_openai else "deterministic"
try:
runner_factory = (
(lambda: create_local_runner(seed=seed, mode=os.getenv("LLMSERVE_MODE", "sim")))
if use_openai
else (lambda: create_local_runner(seed=seed, mode="sim"))
)
return run_baseline_suite(
mode=mode,
task_ids=task_ids,
seed=seed,
model=model,
runner_factory=runner_factory,
)
except RuntimeError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
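
    # Alias: /demo and / both redirect to the web UI mounted at /web.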
@app.get("/demo")
def demo() -> RedirectResponse:
return RedirectResponse(url="/web", status_code=307)
return app


def create_application(enable_web: bool = True) -> FastAPI:
    """Assemble the OpenEnv app, then layer the custom routes on top."""
    app = create_fastapi_app(
        get_env,
        ServeAction,
        ServeObservation,
    )
    # Strip OpenEnv's stock /reset so the session-aware variant from
    # _register_extra_routes takes its place.
    _remove_routes(app, {"/reset"})
    if enable_web:
        app = create_web_app(app, session_manager, shared_env)
    return _register_extra_routes(app)


def create_test_application() -> FastAPI:
    return create_application(enable_web=False)
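

# A minimal smoke-test sketch, assuming FastAPI's TestClient; /tasks is
# registered above, but the payload shape is deliberately not asserted:
#
#   from fastapi.testclient import TestClient
#   client = TestClient(create_test_application())
#   assert client.get("/tasks").status_code == 200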
app = create_application(enable_web=True)


def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()
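
# Equivalent invocation without the __main__ guard (the module path here is
# an assumption; substitute the real import path of this file):
#   uvicorn server.app:app --host 0.0.0.0 --port 7860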