Spaces:
Running
Running
| """FastAPI wrapper for PolyGuardEnv (OpenEnv-style).""" | |
| from __future__ import annotations | |
| import json | |
| import os | |
| from typing import Any, Optional | |
| import uvicorn | |
| from fastapi import FastAPI, WebSocket, WebSocketDisconnect | |
| from pydantic import BaseModel, ConfigDict | |
| from app.common.enums import Difficulty, SubEnvironment | |
| from app.common.types import PolyGuardAction, PolyGuardObservation, PolyGuardState | |
| from app.env.env_core import PolyGuardEnv | |
# ASGI application object; main() hands it to uvicorn by import string.
app = FastAPI(version="0.1.0", title="POLYGUARD-RL Env Service")

# Single module-level environment instance used by every handler in this file.
_ENV = PolyGuardEnv()
# Payload accepted by the reset entry points. Every field is optional and
# defaults to None; keys outside this set are rejected (extra="forbid").
class ResetRequest(BaseModel):
    model_config = ConfigDict(extra="forbid")

    # Each field is forwarded unchanged as a keyword argument to _ENV.reset().
    seed: Optional[int] = None
    difficulty: Optional[Difficulty] = None
    sub_environment: Optional[SubEnvironment] = None
    scenario_id: Optional[str] = None
    patient_id: Optional[str] = None
| def _step_payload(observation: dict[str, Any], reward: float, done: bool, info: dict[str, Any]) -> dict[str, Any]: | |
| reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else "" | |
| truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"} | |
| return { | |
| "observation": observation, | |
| "reward": reward, | |
| "done": done, | |
| "terminated": done, | |
| "truncated": truncated, | |
| "info": info, | |
| } | |
def health() -> dict[str, str]:
    """Return a constant payload reporting the service as healthy."""
    status_report = {"status": "healthy"}
    return status_report
def env_reset(request: ResetRequest) -> dict[str, Any]:
    """Reset the shared environment and return its first observation.

    Args:
        request: Reset options; each field is passed to ``_ENV.reset`` as the
            keyword argument of the same name.

    Returns:
        ``{"observation": ...}`` where the observation has been dumped in
        JSON mode.
    """
    reset_kwargs = {
        "seed": request.seed,
        "difficulty": request.difficulty,
        "sub_environment": request.sub_environment,
        "scenario_id": request.scenario_id,
        "patient_id": request.patient_id,
    }
    initial_observation = _ENV.reset(**reset_kwargs)
    return {"observation": initial_observation.model_dump(mode="json")}
def env_step(action: dict[str, Any]) -> dict[str, Any]:
    """Apply one action to the shared environment and package the outcome.

    Args:
        action: Action dictionary handed to ``_ENV.step`` as-is.

    Returns:
        The step payload built by ``_step_payload`` from the observation
        (dumped in JSON mode), reward, done flag and info returned by the
        environment.
    """
    step_result = _ENV.step(action)
    next_obs, reward, done, info = step_result
    observation_json = next_obs.model_dump(mode="json")
    return _step_payload(observation_json, reward, done, info)
def env_state() -> dict[str, Any]:
    """Return whatever ``_ENV.get_state()`` reports for the shared environment."""
    current_state = _ENV.get_state()
    return current_state
def env_trace() -> list[dict[str, Any]]:
    """Return the trace exposed by ``_ENV.get_trace()``."""
    trace_entries = _ENV.get_trace()
    return trace_entries
def env_legal_actions() -> list[dict[str, Any]]:
    """Return the legal actions reported by ``_ENV.get_legal_actions()``."""
    legal_actions = _ENV.get_legal_actions()
    return legal_actions
def env_reward_breakdown() -> dict[str, Any]:
    """Return the reward breakdown reported by ``_ENV.get_reward_breakdown()``."""
    breakdown = _ENV.get_reward_breakdown()
    return breakdown
def env_uncertainty() -> dict[str, Any]:
    """Return the environment's uncertainty report, dumped in JSON mode."""
    report = _ENV.get_uncertainty_report()
    return report.model_dump(mode="json")
def env_metadata() -> dict[str, Any]:
    """Return the metadata reported by ``_ENV.get_metadata()``."""
    metadata = _ENV.get_metadata()
    return metadata
def schema() -> dict[str, Any]:
    """Return the JSON schemas of the action, observation and state models.

    Returns:
        A dict keyed by ``"action"``, ``"observation"`` and ``"state"``, each
        mapped to the corresponding model's ``model_json_schema()`` output.
    """
    models_by_label = {
        "action": PolyGuardAction,
        "observation": PolyGuardObservation,
        "state": PolyGuardState,
    }
    return {
        label: model.model_json_schema()
        for label, model in models_by_label.items()
    }
def _mcp_tool_catalog() -> list[dict[str, Any]]:
    """Describe the tools advertised in response to ``tools/list``."""
    reset_tool = {
        "name": "env.reset",
        "description": "Reset environment and return initial observation payload.",
        "inputSchema": {
            "type": "object",
            "properties": {
                "seed": {"type": "integer"},
                "difficulty": {"type": "string"},
                "sub_environment": {"type": "string"},
                "scenario_id": {"type": "string"},
                "patient_id": {"type": "string"},
            },
        },
    }
    step_tool = {
        "name": "env.step",
        "description": "Execute a policy action.",
        "inputSchema": PolyGuardAction.model_json_schema(),
    }
    state_tool = {
        "name": "env.state",
        "description": "Get current environment state.",
        "inputSchema": {"type": "object", "properties": {}},
    }
    metadata_tool = {
        "name": "env.metadata",
        "description": "Get environment metadata.",
        "inputSchema": {"type": "object", "properties": {}},
    }
    return [reset_tool, step_tool, state_tool, metadata_tool]


def _mcp_call_tool(tool_name: str, arguments: dict[str, Any]) -> Any:
    """Run the named tool with ``arguments``; raise ``ValueError`` if unknown."""
    if tool_name == "env.reset":
        return env_reset(ResetRequest.model_validate(arguments))
    if tool_name == "env.step":
        return env_step(arguments)
    if tool_name == "env.state":
        return env_state()
    if tool_name == "env.metadata":
        return env_metadata()
    raise ValueError(f"Unknown tool name: {tool_name}")


def mcp(payload: dict[str, Any]) -> dict[str, Any]:
    """Handle one JSON-RPC 2.0 style request against the environment tools.

    Args:
        payload: Request object. ``id`` is echoed back, ``method`` selects the
            operation and ``params`` (used only when it is a dict) carries the
            tool ``name`` and ``arguments`` for ``tools/call``.

    Returns:
        ``{"jsonrpc": "2.0", "id": ..., "result": ...}`` on success. An empty
        or missing method yields a capabilities result. Any exception raised
        while handling the request is reported instead as
        ``{"jsonrpc": "2.0", "id": ..., "error": {"code": -32000, "message": ...}}``.
    """
    request_id = payload.get("id")
    method = str(payload.get("method", "") or "")
    raw_params = payload.get("params", {})
    params = raw_params if isinstance(raw_params, dict) else {}

    try:
        if not method:
            result = {"capabilities": {"tools": True, "ws": True}}
        elif method == "tools/list":
            result = {"tools": _mcp_tool_catalog()}
        elif method == "tools/call":
            tool_name = str(params.get("name", "") or "")
            raw_arguments = params.get("arguments")
            # Non-dict (or missing) arguments are treated as "no arguments".
            arguments = raw_arguments if isinstance(raw_arguments, dict) else {}
            result = _mcp_call_tool(tool_name, arguments)
        else:
            raise ValueError(f"Unsupported method: {method}")
    except Exception as exc:  # noqa: BLE001
        # Every failure is folded into the response body rather than raised.
        failure = {"code": -32000, "message": str(exc)}
        return {"jsonrpc": "2.0", "id": request_id, "error": failure}

    return {"jsonrpc": "2.0", "id": request_id, "result": result}
| # OpenEnv baseline compatibility aliases. | |
def reset_alias(request: ResetRequest) -> dict[str, Any]:
    """Reset via ``env_reset`` and wrap the observation in a step-style payload.

    The returned payload always carries ``reward=0.5``, ``done=False`` and
    ``info={"reset": True}``.
    """
    first_observation = env_reset(request)["observation"]
    return _step_payload(
        observation=first_observation,
        reward=0.5,
        done=False,
        info={"reset": True},
    )
def step_alias(action: dict[str, Any]) -> dict[str, Any]:
    """Alias that delegates to ``env_step`` with the same action."""
    step_response = env_step(action)
    return step_response
def state_alias() -> dict[str, Any]:
    """Alias that delegates to ``env_state``."""
    state_response = env_state()
    return state_response
def metadata_alias() -> dict[str, Any]:
    """Alias that delegates to ``env_metadata``."""
    metadata_response = env_metadata()
    return metadata_response
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve reset/step/state/metadata requests over one WebSocket connection.

    Each incoming text frame is expected to be a JSON object of the form
    ``{"type": <str>, "data": <object>}``. A successful request is answered
    with ``{"type": "result", "data": ...}``. Any failure, including a frame
    that is not valid JSON or not a JSON object, is answered with
    ``{"type": "error", "data": {"code": "EXECUTION_ERROR", "message": ...}}``
    and the connection stays open. The loop ends when the client disconnects.

    Args:
        websocket: The connection to accept and serve.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                # Parsing happens inside the guarded region so a malformed
                # frame produces an error reply instead of killing the socket.
                message = json.loads(raw)
                if not isinstance(message, dict):
                    raise ValueError("Message must be a JSON object")
                msg_type = message.get("type")
                data = message.get("data", {}) or {}

                if msg_type == "reset":
                    request = ResetRequest.model_validate(data)
                    obs = _ENV.reset(
                        seed=request.seed,
                        difficulty=request.difficulty,
                        sub_environment=request.sub_environment,
                        scenario_id=request.scenario_id,
                        patient_id=request.patient_id,
                    )
                    payload = _step_payload(
                        observation=obs.model_dump(mode="json"),
                        reward=0.5,
                        done=False,
                        info={"reset": True},
                    )
                elif msg_type == "step":
                    obs, reward, done, info = _ENV.step(data)
                    payload = _step_payload(
                        observation=obs.model_dump(mode="json"),
                        reward=reward,
                        done=done,
                        info=info,
                    )
                elif msg_type == "state":
                    payload = _ENV.get_state()
                elif msg_type == "metadata":
                    payload = _ENV.get_metadata()
                else:
                    raise ValueError(f"Unsupported message type: {msg_type}")
                await websocket.send_json({"type": "result", "data": payload})
            except WebSocketDisconnect:
                # The peer is gone; do not try to send an error frame to it.
                raise
            except Exception as exc:  # noqa: BLE001
                await websocket.send_json(
                    {
                        "type": "error",
                        "data": {"code": "EXECUTION_ERROR", "message": str(exc)},
                    }
                )
    except WebSocketDisconnect:
        return
def main() -> None:
    """Start uvicorn for this app using host/port taken from the environment.

    ``POLYGUARD_ENV_HOST`` (default ``127.0.0.1``) and ``POLYGUARD_ENV_PORT``
    (default ``8100``) select the bind address; auto-reload is disabled.
    """
    bind_host = os.environ.get("POLYGUARD_ENV_HOST", "127.0.0.1")
    bind_port = int(os.environ.get("POLYGUARD_ENV_PORT", "8100"))
    uvicorn.run(
        "app.env.fastapi_app:app",
        host=bind_host,
        port=bind_port,
        reload=False,
    )


if __name__ == "__main__":
    main()