"""FastAPI wrapper for PolyGuardEnv (OpenEnv-style).

A single process-wide ``PolyGuardEnv`` instance is exposed over three
transports that all drive the same episode:

* REST endpoints under ``/env/*`` plus OpenEnv baseline aliases
  (``/reset``, ``/step``, ``/state``, ``/metadata``);
* a minimal JSON-RPC 2.0 endpoint at ``/mcp`` (``tools/list`` and
  ``tools/call``);
* a WebSocket at ``/ws`` exchanging ``{"type": ..., "data": ...}`` messages.
"""

from __future__ import annotations

import json
import os
from typing import Any, Optional

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ConfigDict

from app.common.config import load_project_env
from app.common.enums import Difficulty, SubEnvironment
from app.common.types import PolyGuardAction, PolyGuardObservation, PolyGuardState
from app.env.env_core import PolyGuardEnv

load_project_env()

app = FastAPI(title="POLYGUARD-RL Env Service", version="0.1.0")

# One environment instance shared by every HTTP request and WebSocket
# connection. NOTE(review): it is not guarded by a lock; FastAPI runs sync
# endpoints in a threadpool, so concurrent clients share (and race on) one
# episode -- confirm single-client use is intended.
_ENV = PolyGuardEnv()

# ``info["termination_reason"]`` values that mean the episode was cut off by
# an external limit rather than reaching a terminal state of the task.
_TRUNCATION_REASONS = frozenset({"wall_clock_timeout", "step_timeout", "step_budget_exhausted"})

# Reward reported alongside the observation returned by a reset on the
# OpenEnv alias and the WebSocket. NOTE(review): why 0.5 (rather than 0.0)
# is not visible from this file -- presumably a client/validator expectation;
# confirm before changing.
_RESET_REWARD = 0.5

# JSON-RPC 2.0 error codes used by the /mcp endpoint.
_JSONRPC_METHOD_NOT_FOUND = -32601
_JSONRPC_INVALID_PARAMS = -32602
_JSONRPC_SERVER_ERROR = -32000


class ResetRequest(BaseModel):
    """Optional knobs for ``PolyGuardEnv.reset``; unknown fields are rejected."""

    model_config = ConfigDict(extra="forbid")

    seed: Optional[int] = None
    difficulty: Optional[Difficulty] = None
    sub_environment: Optional[SubEnvironment] = None
    scenario_id: Optional[str] = None
    patient_id: Optional[str] = None


def _step_payload(observation: dict[str, Any], reward: float, done: bool, info: dict[str, Any]) -> dict[str, Any]:
    """Build the step response shared by REST, MCP and WebSocket transports.

    Args:
        observation: JSON-ready observation dict.
        reward: Scalar reward for the transition.
        done: Whether the episode has ended for any reason.
        info: Auxiliary info; its ``termination_reason`` decides truncation.

    Returns:
        Dict with ``observation``, ``reward``, ``done``, ``terminated``,
        ``truncated`` and ``info`` keys.
    """
    reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
    truncated = reason in _TRUNCATION_REASONS
    return {
        "observation": observation,
        "reward": reward,
        "done": done,
        # Gymnasium semantics: an episode ended by a time/step limit is
        # truncated, not terminated. Flagging it terminated as well would make
        # learners treat the cut-off state as terminal (no bootstrapping).
        "terminated": done and not truncated,
        "truncated": truncated,
        "info": info,
    }


@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe."""
    return {"status": "healthy"}


@app.post("/env/reset")
def env_reset(request: ResetRequest) -> dict[str, Any]:
    """Reset the shared environment and return ``{"observation": ...}``."""
    obs = _ENV.reset(
        seed=request.seed,
        difficulty=request.difficulty,
        sub_environment=request.sub_environment,
        scenario_id=request.scenario_id,
        patient_id=request.patient_id,
    )
    return {"observation": obs.model_dump(mode="json")}


@app.post("/env/step")
def env_step(action: dict[str, Any]) -> dict[str, Any]:
    """Apply ``action`` to the shared environment; return a step payload."""
    obs, reward, done, info = _ENV.step(action)
    return _step_payload(observation=obs.model_dump(mode="json"), reward=reward, done=done, info=info)


@app.get("/env/state")
def env_state() -> dict[str, Any]:
    """Return ``PolyGuardEnv.get_state()``."""
    return _ENV.get_state()


@app.get("/env/trace")
def env_trace() -> list[dict[str, Any]]:
    """Return ``PolyGuardEnv.get_trace()``."""
    return _ENV.get_trace()


@app.get("/env/legal_actions")
def env_legal_actions() -> list[dict[str, Any]]:
    """Return ``PolyGuardEnv.get_legal_actions()``."""
    return _ENV.get_legal_actions()


@app.get("/env/reward_breakdown")
def env_reward_breakdown() -> dict[str, Any]:
    """Return ``PolyGuardEnv.get_reward_breakdown()``."""
    return _ENV.get_reward_breakdown()


@app.get("/env/uncertainty")
def env_uncertainty() -> dict[str, Any]:
    """Return the environment's uncertainty report as JSON-ready data."""
    return _ENV.get_uncertainty_report().model_dump(mode="json")


@app.get("/env/metadata")
def env_metadata() -> dict[str, Any]:
    """Return ``PolyGuardEnv.get_metadata()``."""
    return _ENV.get_metadata()


@app.get("/schema")
def schema() -> dict[str, Any]:
    """Return JSON Schemas for the action, observation and state models."""
    return {
        "action": PolyGuardAction.model_json_schema(),
        "observation": PolyGuardObservation.model_json_schema(),
        "state": PolyGuardState.model_json_schema(),
    }


class _JsonRpcError(Exception):
    """Failure that maps to a specific JSON-RPC 2.0 error code."""

    def __init__(self, code: int, message: str) -> None:
        super().__init__(message)
        self.code = code


def _jsonrpc_error(request_id: Any, code: int, message: str) -> dict[str, Any]:
    """Build a JSON-RPC 2.0 error response object."""
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message},
    }


def _mcp_tool_catalog() -> list[dict[str, Any]]:
    """Return the tool descriptors advertised by ``tools/list``."""
    return [
        {
            "name": "env.reset",
            "description": "Reset environment and return initial observation payload.",
            "inputSchema": {
                "type": "object",
                "properties": {
                    "seed": {"type": "integer"},
                    "difficulty": {"type": "string"},
                    "sub_environment": {"type": "string"},
                    "scenario_id": {"type": "string"},
                    "patient_id": {"type": "string"},
                },
            },
        },
        {
            "name": "env.step",
            "description": "Execute a policy action.",
            "inputSchema": PolyGuardAction.model_json_schema(),
        },
        {
            "name": "env.state",
            "description": "Get current environment state.",
            "inputSchema": {"type": "object", "properties": {}},
        },
        {
            "name": "env.metadata",
            "description": "Get environment metadata.",
            "inputSchema": {"type": "object", "properties": {}},
        },
    ]


def _mcp_call_tool(params: dict[str, Any]) -> dict[str, Any]:
    """Dispatch a ``tools/call`` request to the matching REST handler.

    Raises:
        _JsonRpcError: ``-32602`` when ``params["name"]`` is not a known tool.
    """
    tool_name = str(params.get("name", "") or "")
    raw_arguments = params.get("arguments")
    arguments = raw_arguments if isinstance(raw_arguments, dict) else {}
    if tool_name == "env.reset":
        return env_reset(ResetRequest.model_validate(arguments))
    if tool_name == "env.step":
        return env_step(arguments)
    if tool_name == "env.state":
        return env_state()
    if tool_name == "env.metadata":
        return env_metadata()
    raise _JsonRpcError(_JSONRPC_INVALID_PARAMS, f"Unknown tool name: {tool_name}")


@app.post("/mcp")
def mcp(payload: dict[str, Any]) -> dict[str, Any]:
    """Serve a minimal JSON-RPC 2.0 (MCP-style) request.

    Supported ``method`` values:

    * ``"tools/list"``: describe the available tools;
    * ``"tools/call"``: run ``params["name"]`` with ``params["arguments"]``;
    * empty or missing: return a capabilities summary.

    Failures are returned in the JSON-RPC ``error`` member, not as HTTP
    errors: ``-32601`` for an unsupported method, ``-32602`` for an unknown
    tool, and ``-32000`` for anything else (validation or environment errors).
    """
    request_id = payload.get("id")
    method = str(payload.get("method", "") or "")
    raw_params = payload.get("params")
    params = raw_params if isinstance(raw_params, dict) else {}
    try:
        if method == "tools/list":
            result = {"tools": _mcp_tool_catalog()}
        elif method == "tools/call":
            result = _mcp_call_tool(params)
        elif not method:
            result = {"capabilities": {"tools": True, "ws": True}}
        else:
            raise _JsonRpcError(_JSONRPC_METHOD_NOT_FOUND, f"Unsupported method: {method}")
    except _JsonRpcError as exc:
        return _jsonrpc_error(request_id, exc.code, str(exc))
    except Exception as exc:  # noqa: BLE001 - every failure is reported to the caller
        return _jsonrpc_error(request_id, _JSONRPC_SERVER_ERROR, str(exc))
    return {"jsonrpc": "2.0", "id": request_id, "result": result}


def _reset_payload(request: ResetRequest) -> dict[str, Any]:
    """Reset the environment and wrap the observation in a step-style payload."""
    observation = env_reset(request)["observation"]
    return _step_payload(
        observation=observation,
        reward=_RESET_REWARD,
        done=False,
        info={"reset": True},
    )


# OpenEnv baseline compatibility aliases.
@app.post("/reset")
def reset_alias(request: ResetRequest) -> dict[str, Any]:
    """OpenEnv alias: reset and return a step-shaped payload."""
    return _reset_payload(request)


@app.post("/step")
def step_alias(action: dict[str, Any]) -> dict[str, Any]:
    """OpenEnv alias for ``/env/step``."""
    return env_step(action)


@app.get("/state")
def state_alias() -> dict[str, Any]:
    """OpenEnv alias for ``/env/state``."""
    return env_state()


@app.get("/metadata")
def metadata_alias() -> dict[str, Any]:
    """OpenEnv alias for ``/env/metadata``."""
    return env_metadata()


def _handle_ws_message(raw: str) -> Any:
    """Parse one WebSocket text frame and execute it against the environment.

    Raises:
        json.JSONDecodeError: ``raw`` is not valid JSON.
        ValueError: the frame is not a JSON object or its type is unsupported.
    """
    message = json.loads(raw)
    if not isinstance(message, dict):
        raise ValueError("WebSocket message must be a JSON object")
    msg_type = message.get("type")
    data = message.get("data", {}) or {}
    if msg_type == "reset":
        return _reset_payload(ResetRequest.model_validate(data))
    if msg_type == "step":
        return env_step(data)
    if msg_type == "state":
        return env_state()
    if msg_type == "metadata":
        return env_metadata()
    raise ValueError(f"Unsupported message type: {msg_type}")


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve the environment over a WebSocket.

    Each text frame must be a JSON object ``{"type": ..., "data": ...}``
    where ``type`` is ``reset``, ``step``, ``state`` or ``metadata``.
    Replies are ``{"type": "result", "data": ...}`` or
    ``{"type": "error", "data": {"code": "EXECUTION_ERROR", "message": ...}}``.
    A malformed frame produces an error reply; the connection stays open.

    NOTE(review): environment calls are synchronous and run on the event loop,
    so a slow step blocks other connections.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                # Parsing happens inside the per-message guard so a bad frame
                # yields an error reply instead of tearing down the socket.
                payload = _handle_ws_message(raw)
                await websocket.send_json({"type": "result", "data": payload})
            except WebSocketDisconnect:
                # The peer is gone; don't attempt to send an error reply.
                raise
            except Exception as exc:  # noqa: BLE001 - report every failure to the client
                await websocket.send_json(
                    {
                        "type": "error",
                        "data": {"code": "EXECUTION_ERROR", "message": str(exc)},
                    }
                )
    except WebSocketDisconnect:
        return


def main() -> None:
    """Run the service with uvicorn.

    Host and port come from ``POLYGUARD_ENV_HOST`` / ``POLYGUARD_ENV_PORT``
    (defaults ``127.0.0.1`` and ``8100``).
    """
    host = os.getenv("POLYGUARD_ENV_HOST", "127.0.0.1")
    port = int(os.getenv("POLYGUARD_ENV_PORT", "8100"))
    uvicorn.run("app.env.fastapi_app:app", host=host, port=port, reload=False)


if __name__ == "__main__":
    main()