"""FastAPI app factory for PolypharmacyEnv."""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, Optional
from dotenv import load_dotenv
from fastapi import HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from openenv.core.env_server.http_server import create_app
from starlette.responses import FileResponse
from ..env_core import PolypharmacyEnv
from ..models import PolypharmacyAction, PolypharmacyObservation
from .routes.agent import router as agent_router
from .routes.bandit import router as bandit_router
# Load variables from a .env file (if one is found) into the process
# environment at import time, before any app object is built below.
load_dotenv()
class SPAStaticFiles(StaticFiles):
    """Static file handler that falls back to the SPA's ``index.html``.

    Any path that the underlying ``StaticFiles`` lookup reports as a 404 is
    answered with ``index.html`` from the mounted directory (when it exists),
    so client-side routes keep working on deep links and hard refreshes.
    """

    async def get_response(self, path: str, scope):
        """Return the static file for *path*, or ``index.html`` on a 404.

        Args:
            path: Request path relative to the mounted directory.
            scope: ASGI connection scope, passed through to ``StaticFiles``.

        Returns:
            The response produced by ``StaticFiles``, or a ``FileResponse``
            for ``index.html`` when the lookup ended in a 404.

        Raises:
            HTTPException: 404 when neither the requested file nor
                ``index.html`` exists. Any non-404 error raised by
                ``StaticFiles`` propagates unchanged.
        """
        # Imported here so the block does not rely on a new module-level
        # import. fastapi.HTTPException subclasses this class, so catching
        # the Starlette base covers both.
        from starlette.exceptions import HTTPException as StarletteHTTPException

        try:
            response = await super().get_response(path, scope)
        except StarletteHTTPException as exc:
            # Current Starlette releases *raise* 404 for a missing file rather
            # than returning a 404 response; treat both the same way.
            if exc.status_code != 404:
                raise
        else:
            if response.status_code != 404:
                return response

        index_path = Path(self.directory) / "index.html"
        if index_path.exists():
            return FileResponse(index_path)
        raise HTTPException(status_code=404, detail="Not Found")
# ── Stateful singleton for HTTP-based inference ──────────────────────────────
# OpenEnv's built-in HTTP /reset and /step handlers are stateless (they create
# a fresh env per call). The WebSocket /ws endpoint handles stateful sessions
# for the frontend. For the inference.py script (and the evaluator), we need
# HTTP endpoints that maintain state across reset → step → step → ... calls.
# We override OpenEnv's default routes with stateful versions.
# Shared environment used by the stateful HTTP routes; stays None until the
# first call to _get_or_create_env().
_http_env: Optional[PolypharmacyEnv] = None


def _get_or_create_env() -> PolypharmacyEnv:
    """Return the shared ``PolypharmacyEnv``, constructing it on first use."""
    global _http_env
    if _http_env is not None:
        return _http_env
    _http_env = PolypharmacyEnv()
    return _http_env
def _serialize_obs(obs: PolypharmacyObservation) -> Dict[str, Any]:
"""Convert observation to JSON-serializable dict."""
return obs.model_dump() if hasattr(obs, "model_dump") else obs.dict()
def create_polypharmacy_app():
    """Build the app with stateful HTTP routes and the optional frontend mount.

    Starts from OpenEnv's ``create_app`` and then:

    1. drops every route whose path is ``/reset``, ``/step`` or ``/state``;
    2. registers replacements that all operate on the single environment
       returned by ``_get_or_create_env()``;
    3. adds CORS for the local dev frontend origins and includes the agent
       and bandit routers;
    4. mounts ``frontend/dist`` at ``/`` when that directory exists.

    Returns:
        The configured application object produced by ``create_app``.
    """
    app = create_app(
        PolypharmacyEnv,
        PolypharmacyAction,
        PolypharmacyObservation,
        env_name="polypharmacy_env",
    )

    # ── Override stateless HTTP routes with stateful ones ────────────────────
    # Remove the default /reset, /step and /state routes so the handlers
    # registered below are the only ones matching those paths.
    overridden_paths = ("/reset", "/step", "/state")
    app.routes[:] = [
        route
        for route in app.routes
        if getattr(route, "path", "") not in overridden_paths
    ]

    # Inclusive bounds applied to the reward reported by /step.
    reward_floor = 0.001
    reward_ceiling = 0.999

    @app.post("/reset")
    async def stateful_reset(body: Optional[Dict[str, Any]] = None):
        """Reset the shared environment and return its first observation.

        Recognised body keys: ``task_id`` (forwarded only when truthy),
        ``seed`` and ``episode_id``. A missing body is treated as ``{}``.
        """
        payload = body or {}
        env = _get_or_create_env()

        reset_kwargs = {}
        task_id = payload.get("task_id")
        if task_id:
            reset_kwargs["task_id"] = task_id

        obs = env.reset(
            seed=payload.get("seed"),
            episode_id=payload.get("episode_id"),
            **reset_kwargs,
        )
        return {
            "observation": _serialize_obs(obs),
            "reward": 0.0,
            "done": False,
        }

    @app.post("/step")
    async def stateful_step(body: Optional[Dict[str, Any]] = None):
        """Apply one action to the shared environment.

        The action fields are read from ``body["action"]`` when that key is
        present, otherwise from the body itself.

        Raises:
            HTTPException: 422 when the payload cannot be turned into a
                ``PolypharmacyAction``.
        """
        payload = body or {}
        env = _get_or_create_env()

        action_data = payload.get("action", payload)
        try:
            action = PolypharmacyAction(**action_data)
        except Exception as exc:
            # Boundary handler: any construction/validation failure is a
            # client error; keep the cause attached for server-side logs.
            raise HTTPException(status_code=422, detail=str(exc)) from exc

        obs_data = _serialize_obs(env.step(action))

        # Surface the observation's metadata as top-level "info".
        metadata = obs_data.get("metadata") or {}

        # The key may be present with a None value, in which case the .get()
        # default would not apply and float(None) would raise; treat None
        # like a missing reward.
        raw_reward = obs_data.get("shaped_reward")
        if raw_reward is None:
            raw_reward = reward_floor
        # Clamp to the inclusive range [0.001, 0.999].
        clamped_reward = max(reward_floor, min(reward_ceiling, float(raw_reward)))

        return {
            "observation": obs_data,
            "reward": clamped_reward,
            "done": obs_data.get("done", False),
            "info": metadata,
        }

    @app.get("/state")
    async def stateful_state():
        """Return the shared environment's ``state`` as a plain dict."""
        state = _get_or_create_env().state
        if hasattr(state, "model_dump"):
            return state.model_dump()
        return state.dict()

    # ── Middleware & extra routes ────────────────────────────────────────────
    app.add_middleware(
        CORSMiddleware,
        allow_origins=[
            "http://localhost:5173",
            "http://127.0.0.1:5173",
        ],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    app.include_router(agent_router)
    app.include_router(bandit_router)

    # In Docker Space deployment, serve built frontend from same container.
    # NOTE(review): parents[4] assumes this file sits at least five levels
    # below the project root — confirm against the deployed layout.
    project_root = Path(__file__).resolve().parents[4]
    frontend_dist = project_root / "frontend" / "dist"
    if frontend_dist.exists():
        app.mount("/", SPAStaticFiles(directory=frontend_dist, html=True), name="frontend")

    return app
# Module-level application instance, built once when this module is imported.
# NOTE(review): presumably the object an ASGI server is pointed at
# (e.g. "<module>:app") — confirm against the deployment entry point.
app = create_polypharmacy_app()
|