File size: 5,144 Bytes
4fbc241
 
 
 
 
9b64226
4fbc241
 
 
 
 
 
 
 
 
0d53691
8ec915d
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ec915d
4fbc241
 
 
 
 
 
0d53691
 
 
 
4fbc241
 
 
 
 
0d53691
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fbc241
 
 
 
 
 
 
 
 
 
8ec915d
 
 
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ec915d
 
 
 
 
0d53691
4fbc241
8ec915d
4fbc241
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
from __future__ import annotations

import os
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.responses import RedirectResponse
from openenv.core import create_fastapi_app
from dotenv import load_dotenv

from llmserve_env.models import ServeAction, ServeObservation
from llmserve_env.task_catalog import get_action_schema, get_task_catalog
from server.baseline_inference import create_local_runner, run_baseline_suite
from server.grader import GraderEngine
from server.llmserve_environment import LLMServeEnvironment
from server.schemas import GraderRequest, ResetRequest
from server.session_manager import SessionManager
from server.web_ui import create_web_app


# Repository root: one directory above the directory containing this file.
# .env values are loaded from there but never override variables already
# present in the process environment (override=False).
ROOT_DIR = Path(__file__).resolve().parents[1]
load_dotenv(ROOT_DIR / ".env", override=False)


def _build_shared_env() -> LLMServeEnvironment:
    """Construct the process-wide environment from LLMSERVE_* env vars.

    ``LLMSERVE_SEED`` defaults to "42" (parsed as int, so a malformed value
    raises ValueError at startup); ``LLMSERVE_MODE`` may be unset, in which
    case ``None`` is passed through to the environment.
    """
    raw_seed = os.getenv("LLMSERVE_SEED", "42")
    return LLMServeEnvironment(seed=int(raw_seed), mode=os.getenv("LLMSERVE_MODE"))


# Process-wide singletons shared by all request handlers.
shared_env = _build_shared_env()  # default environment, exposed via get_env()
grader = GraderEngine()  # scores episode logs for the /grader endpoint
session_manager = SessionManager()  # per-session environments for /reset


def get_env() -> LLMServeEnvironment:
    """Return the shared module-level environment (passed to create_fastapi_app as its env factory)."""
    return shared_env


def _remove_routes(app: FastAPI, paths: set[str]) -> None:
    app.router.routes[:] = [route for route in app.router.routes if getattr(route, "path", None) not in paths]


def _register_extra_routes(app: FastAPI) -> FastAPI:
    # Attaches the non-openenv HTTP endpoints (/, /reset, /tasks, /runtime,
    # /grader, /baseline, /demo) to *app* and returns the same app instance.
    # NOTE: handler docstrings are intentionally omitted — FastAPI would copy
    # them into the OpenAPI schema, changing the served /docs output.

    # Root redirects to the web UI; 307 preserves the request method.
    @app.get("/")
    def root() -> RedirectResponse:
        return RedirectResponse(url="/web", status_code=307)

    # Session-aware replacement for openenv's stock /reset: creates a fresh
    # per-session environment and returns its initial observation.
    @app.post("/reset")
    def reset(payload: ResetRequest | None = None) -> dict[str, object]:
        # A missing/empty body behaves like an all-defaults ResetRequest.
        payload = payload or ResetRequest()
        session_id, env = session_manager.create(
            task_id=payload.task_id,
            seed=payload.seed,
            episode_id=payload.episode_id,
        )
        # The newly created env is assumed to already hold at least one
        # observation after create() — TODO confirm against SessionManager.
        observation = env.observations[-1]
        return {
            "session_id": session_id,
            "observation": observation.model_dump(mode="json"),
            "reward": observation.reward,
            "done": observation.done,
            "metadata": observation.metadata,
        }

    # Static catalog of tasks plus the JSON schema for actions.
    @app.get("/tasks")
    def tasks() -> dict[str, object]:
        return {"tasks": get_task_catalog(), "action_schema": get_action_schema()}

    # Introspection endpoint: backend mode/description, seed, session count.
    @app.get("/runtime")
    def runtime() -> dict[str, object]:
        return {
            "mode": shared_env.backend.mode,
            "backend": shared_env.backend.describe(),
            "seed": shared_env.seed,
            "active_sessions": session_manager.count(),
        }

    # Grades either an episode log supplied in the request body or, as a
    # fallback, the shared environment's current/most recent episode.
    @app.post("/grader")
    def grade(payload: GraderRequest | None = None) -> dict[str, object]:
        if payload and payload.episode_log is not None:
            # Explicit log provided: task_id (if also given) must agree with it.
            if payload.task_id and payload.task_id != payload.episode_log.task_id:
                raise HTTPException(status_code=400, detail="task_id does not match episode_log.task_id.")
            return grader.grade(payload.episode_log, actions_taken=payload.actions_taken)
        # No log in the request: fall back to the shared env, which must have
        # recorded at least one observation to be gradable.
        if not shared_env.observations:
            raise HTTPException(status_code=400, detail="No active or completed episode is available to grade.")
        current_log = shared_env.export_episode_log()
        if payload and payload.task_id and payload.task_id != current_log.task_id:
            raise HTTPException(status_code=400, detail="task_id does not match the active episode.")
        return grader.grade(current_log, actions_taken=payload.actions_taken if payload else None)

    # Runs the baseline suite over one task (task_id given) or all tasks.
    @app.get("/baseline")
    def baseline(
        task_id: str | None = None,
        use_openai: bool = False,
        model: str = "gpt-4.1-mini",
        seed: int = 42,
    ) -> dict[str, object]:
        task_ids = [task_id] if task_id else [task["id"] for task in get_task_catalog()]
        mode = "openai" if use_openai else "deterministic"
        try:
            # Deterministic runs always use the "sim" backend; openai runs
            # honor LLMSERVE_MODE (defaulting to "sim" when unset).
            runner_factory = (
                (lambda: create_local_runner(seed=seed, mode=os.getenv("LLMSERVE_MODE", "sim")))
                if use_openai
                else (lambda: create_local_runner(seed=seed, mode="sim"))
            )
            return run_baseline_suite(
                mode=mode,
                task_ids=task_ids,
                seed=seed,
                model=model,
                runner_factory=runner_factory,
            )
        except RuntimeError as exc:
            # Surface backend configuration/runtime failures as client errors.
            raise HTTPException(status_code=400, detail=str(exc)) from exc

    # /demo is an alias for the web UI redirect.
    @app.get("/demo")
    def demo() -> RedirectResponse:
        return RedirectResponse(url="/web", status_code=307)

    return app


def create_application(enable_web: bool = True) -> FastAPI:
    """Assemble the complete FastAPI application.

    Builds the openenv base app, strips its stock /reset route (a
    session-aware replacement is registered by ``_register_extra_routes``),
    optionally layers the web UI on top, and attaches the extra endpoints.
    """
    application = create_fastapi_app(get_env, ServeAction, ServeObservation)
    # Our own /reset (registered below) supersedes openenv's default one.
    _remove_routes(application, {"/reset"})
    if enable_web:
        application = create_web_app(application, session_manager, shared_env)
    return _register_extra_routes(application)


def create_test_application() -> FastAPI:
    """Build the application without the web UI layer, for use in tests."""
    return create_application(enable_web=False)


# Module-level ASGI application (web UI enabled) for ASGI servers to import.
app = create_application(enable_web=True)


def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the module-level app with uvicorn; blocks until shutdown.

    uvicorn is imported lazily so merely importing this module does not
    require it to be installed.
    """
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()