| """FastAPI wrapper for PolyGuardEnv (OpenEnv-style).""" |
|
|
| from __future__ import annotations |
|
|
import json
import os
from typing import Any, Optional


import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ConfigDict, ValidationError


from app.common.enums import Difficulty, SubEnvironment
from app.common.types import PolyGuardAction, PolyGuardObservation, PolyGuardState
from app.env.env_core import PolyGuardEnv
|
|
# One process-wide environment instance. Every HTTP, MCP and WebSocket handler
# in this module drives and observes this same object.
app = FastAPI(version="0.1.0", title="POLYGUARD-RL Env Service")
_ENV = PolyGuardEnv()
|
|
|
|
class ResetRequest(BaseModel):
    # Request body for /env/reset, /reset, the MCP "env.reset" tool and the
    # WebSocket "reset" message. All fields are optional. Unset fields are
    # forwarded to PolyGuardEnv.reset as None.
    model_config = ConfigDict(extra="forbid")  # unknown keys are rejected, not silently dropped
    seed: Optional[int] = None  # presumably the episode RNG seed; forwarded unchanged
    difficulty: Optional[Difficulty] = None
    sub_environment: Optional[SubEnvironment] = None
    scenario_id: Optional[str] = None
    patient_id: Optional[str] = None
|
|
|
|
| def _step_payload(observation: dict[str, Any], reward: float, done: bool, info: dict[str, Any]) -> dict[str, Any]: |
| reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else "" |
| truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"} |
| return { |
| "observation": observation, |
| "reward": reward, |
| "done": done, |
| "terminated": done, |
| "truncated": truncated, |
| "info": info, |
| } |
|
|
|
|
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe; unconditionally reports the service as healthy."""
    status = "healthy"
    return dict(status=status)
|
|
|
|
@app.post("/env/reset")
def env_reset(request: ResetRequest) -> dict[str, Any]:
    """Start a new episode and return its initial observation.

    Each ``ResetRequest`` field is passed by keyword to ``PolyGuardEnv.reset``.
    The returned observation is serialised with ``model_dump(mode="json")``.
    """
    initial = _ENV.reset(
        patient_id=request.patient_id,
        scenario_id=request.scenario_id,
        sub_environment=request.sub_environment,
        difficulty=request.difficulty,
        seed=request.seed,
    )
    serialized = initial.model_dump(mode="json")
    return {"observation": serialized}
|
|
|
|
@app.post("/env/step")
def env_step(action: dict[str, Any]) -> dict[str, Any]:
    """Apply one action to the environment and return the step payload.

    ``action`` is forwarded to ``PolyGuardEnv.step`` as a raw dict; this
    handler performs no validation of its own.
    """
    outcome = _ENV.step(action)
    observation, reward, done, info = outcome
    serialized = observation.model_dump(mode="json")
    return _step_payload(
        observation=serialized,
        reward=reward,
        done=done,
        info=info,
    )
|
|
|
|
@app.get("/env/state")
def env_state() -> dict[str, Any]:
    """Return the environment's current state as produced by ``get_state``."""
    current = _ENV.get_state()
    return current
|
|
|
|
@app.get("/env/trace")
def env_trace() -> list[dict[str, Any]]:
    """Return the list of trace records reported by ``get_trace``."""
    records = _ENV.get_trace()
    return records
|
|
|
|
@app.get("/env/legal_actions")
def env_legal_actions() -> list[dict[str, Any]]:
    """Return the actions the environment currently reports as legal."""
    candidates = _ENV.get_legal_actions()
    return candidates
|
|
|
|
@app.get("/env/reward_breakdown")
def env_reward_breakdown() -> dict[str, Any]:
    """Return the reward breakdown reported by ``get_reward_breakdown``."""
    breakdown = _ENV.get_reward_breakdown()
    return breakdown
|
|
|
|
@app.get("/env/uncertainty")
def env_uncertainty() -> dict[str, Any]:
    """Return the environment's uncertainty report, serialised in JSON mode."""
    report = _ENV.get_uncertainty_report()
    return report.model_dump(mode="json")
|
|
|
|
@app.get("/env/metadata")
def env_metadata() -> dict[str, Any]:
    """Return the environment metadata reported by ``get_metadata``."""
    meta = _ENV.get_metadata()
    return meta
|
|
|
|
@app.get("/schema")
def schema() -> dict[str, Any]:
    """Expose the JSON Schemas of the action, observation and state models."""
    models = (
        ("action", PolyGuardAction),
        ("observation", PolyGuardObservation),
        ("state", PolyGuardState),
    )
    return {key: model.model_json_schema() for key, model in models}
|
|
|
|
# JSON-RPC 2.0 reserved error codes, plus one implementation-defined server error.
_JSONRPC_METHOD_NOT_FOUND = -32601
_JSONRPC_INVALID_PARAMS = -32602
_JSONRPC_SERVER_ERROR = -32000


class _JsonRpcError(Exception):
    """Protocol-level failure carrying an explicit JSON-RPC error code."""

    def __init__(self, code: int, message: str) -> None:
        super().__init__(message)
        self.code = code


@app.post("/mcp")
def mcp(payload: dict[str, Any]) -> dict[str, Any]:
    """Minimal JSON-RPC 2.0 / MCP-style endpoint exposing the environment as tools.

    Supported methods:
      * ``tools/list``: describe the available tools and their input schemas.
      * ``tools/call``: invoke ``env.reset``, ``env.step``, ``env.state`` or
        ``env.metadata`` with ``params["arguments"]``. A non-dict
        ``arguments`` value is treated as ``{}``.
      * empty or missing method: return a capabilities summary.

    Errors are returned in the JSON-RPC ``error`` member rather than raised.
    The codes are -32601 for an unsupported method, -32602 for an unknown
    tool or arguments that fail pydantic validation, and -32000 for any
    other failure.
    """
    request_id = payload.get("id")
    method = str(payload.get("method", "") or "")
    raw_params = payload.get("params", {})
    params = raw_params if isinstance(raw_params, dict) else {}

    try:
        if method == "tools/list":
            result = {
                "tools": [
                    {
                        "name": "env.reset",
                        "description": "Reset environment and return initial observation payload.",
                        # Derived from the model so enum values and
                        # extra="forbid" are advertised accurately.
                        "inputSchema": ResetRequest.model_json_schema(),
                    },
                    {
                        "name": "env.step",
                        "description": "Execute a policy action.",
                        "inputSchema": PolyGuardAction.model_json_schema(),
                    },
                    {
                        "name": "env.state",
                        "description": "Get current environment state.",
                        "inputSchema": {"type": "object", "properties": {}},
                    },
                    {
                        "name": "env.metadata",
                        "description": "Get environment metadata.",
                        "inputSchema": {"type": "object", "properties": {}},
                    },
                ]
            }
        elif method == "tools/call":
            tool_name = str(params.get("name", "") or "")
            arguments = params.get("arguments", {}) if isinstance(params.get("arguments"), dict) else {}
            # NOTE(review): tool results are returned raw, not wrapped in an
            # MCP "content" list. Confirm clients expect this shape.
            if tool_name == "env.reset":
                request = ResetRequest.model_validate(arguments)
                result = env_reset(request)
            elif tool_name == "env.step":
                result = env_step(arguments)
            elif tool_name == "env.state":
                result = env_state()
            elif tool_name == "env.metadata":
                result = env_metadata()
            else:
                raise _JsonRpcError(_JSONRPC_INVALID_PARAMS, f"Unknown tool name: {tool_name}")
        elif not method:
            result = {"capabilities": {"tools": True, "ws": True}}
        else:
            raise _JsonRpcError(_JSONRPC_METHOD_NOT_FOUND, f"Unsupported method: {method}")
    except _JsonRpcError as exc:
        code, message = exc.code, str(exc)
    except ValidationError as exc:
        # Arguments rejected by a pydantic model (e.g. ResetRequest).
        code, message = _JSONRPC_INVALID_PARAMS, str(exc)
    except Exception as exc:  # request boundary: report environment failures to the client
        code, message = _JSONRPC_SERVER_ERROR, str(exc)
    else:
        return {"jsonrpc": "2.0", "id": request_id, "result": result}
    return {
        "jsonrpc": "2.0",
        "id": request_id,
        "error": {"code": code, "message": message},
    }
|
|
|
|
| |
@app.post("/reset")
def reset_alias(request: ResetRequest) -> dict[str, Any]:
    """OpenEnv-style reset: same as /env/reset, wrapped in a step payload.

    The reset transition is reported with a fixed reward of 0.5,
    ``done=False`` and ``info={"reset": True}``.
    """
    # NOTE(review): 0.5 looks like a neutral placeholder reward for the reset
    # transition. Confirm what downstream clients expect.
    initial_observation = env_reset(request)["observation"]
    return _step_payload(
        observation=initial_observation,
        info={"reset": True},
        done=False,
        reward=0.5,
    )
|
|
|
|
@app.post("/step")
def step_alias(action: dict[str, Any]) -> dict[str, Any]:
    """OpenEnv-style alias that delegates to ``env_step``."""
    response = env_step(action)
    return response
|
|
|
|
@app.get("/state")
def state_alias() -> dict[str, Any]:
    """OpenEnv-style alias that delegates to ``env_state``."""
    response = env_state()
    return response
|
|
|
|
@app.get("/metadata")
def metadata_alias() -> dict[str, Any]:
    """OpenEnv-style alias that delegates to ``env_metadata``."""
    response = env_metadata()
    return response
|
|
|
|
async def _send_ws_error(websocket: WebSocket, code: str, message: str) -> None:
    """Send a ``{"type": "error"}`` frame carrying ``code`` and ``message``."""
    await websocket.send_json({"type": "error", "data": {"code": code, "message": message}})


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve the environment over a WebSocket using JSON text frames.

    Each client frame must be a JSON object ``{"type": ..., "data": {...}}``
    whose ``type`` is ``reset``, ``step``, ``state`` or ``metadata``. Every
    frame gets exactly one reply, either ``{"type": "result", "data": ...}``
    or ``{"type": "error", "data": {"code": ..., "message": ...}}``.

    Frames that are not valid JSON objects produce an ``INVALID_MESSAGE``
    error. Previously they crashed the handler and dropped the connection.
    Failures while handling a message produce ``EXECUTION_ERROR``. The loop
    ends when the client disconnects.
    """
    await websocket.accept()
    try:
        while True:
            raw = await websocket.receive_text()
            try:
                message = json.loads(raw)
            except json.JSONDecodeError as exc:
                await _send_ws_error(websocket, "INVALID_MESSAGE", str(exc))
                continue
            if not isinstance(message, dict):
                await _send_ws_error(websocket, "INVALID_MESSAGE", "WebSocket message must be a JSON object")
                continue

            msg_type = message.get("type")
            data = message.get("data", {}) or {}
            try:
                # Reuse the HTTP handlers so both transports return identical payloads.
                if msg_type == "reset":
                    payload = reset_alias(ResetRequest.model_validate(data))
                elif msg_type == "step":
                    payload = env_step(data)
                elif msg_type == "state":
                    payload = env_state()
                elif msg_type == "metadata":
                    payload = env_metadata()
                else:
                    raise ValueError(f"Unsupported message type: {msg_type}")
                await websocket.send_json({"type": "result", "data": payload})
            except WebSocketDisconnect:
                # The client went away mid-send; let the outer handler end the session.
                raise
            except Exception as exc:
                await _send_ws_error(websocket, "EXECUTION_ERROR", str(exc))
    except WebSocketDisconnect:
        return
|
|
|
|
def main() -> None:
    """Run the service under uvicorn.

    The bind address comes from ``POLYGUARD_ENV_HOST`` (default
    ``127.0.0.1``) and ``POLYGUARD_ENV_PORT`` (default ``8100``). Auto-reload
    is disabled.
    """
    environ = os.environ
    bind_host = environ.get("POLYGUARD_ENV_HOST", "127.0.0.1")
    bind_port = int(environ.get("POLYGUARD_ENV_PORT", "8100"))
    uvicorn.run("app.env.fastapi_app:app", host=bind_host, port=bind_port, reload=False)


if __name__ == "__main__":
    main()
|
|