# TheJackBright's picture
# Deploy GitHub root master to Space
# c296d62
"""FastAPI wrapper for PolyGuardEnv (OpenEnv-style)."""
from __future__ import annotations
import json
import os
from typing import Any, Optional
import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel, ConfigDict
from app.common.config import load_project_env
from app.common.enums import Difficulty, SubEnvironment
from app.common.types import PolyGuardAction, PolyGuardObservation, PolyGuardState
from app.env.env_core import PolyGuardEnv
# Run the project's environment/config loader (from app.common.config) before
# the application and the environment instance below are constructed.
load_project_env()
app = FastAPI(title="POLYGUARD-RL Env Service", version="0.1.0")
# Single module-level environment instance; every HTTP route and WebSocket
# handler in this file operates on this same object.
_ENV = PolyGuardEnv()
class ResetRequest(BaseModel):
    """Optional parameters accepted when resetting the environment.

    Every field defaults to ``None``; the reset handlers in this file forward
    the values unchanged as keyword arguments to ``_ENV.reset``.
    """

    # Reject request bodies that contain keys other than the fields below.
    model_config = ConfigDict(extra="forbid")
    seed: Optional[int] = None
    difficulty: Optional[Difficulty] = None
    sub_environment: Optional[SubEnvironment] = None
    scenario_id: Optional[str] = None
    patient_id: Optional[str] = None
def _step_payload(observation: dict[str, Any], reward: float, done: bool, info: dict[str, Any]) -> dict[str, Any]:
reason = str(info.get("termination_reason", "")) if isinstance(info, dict) else ""
truncated = reason in {"wall_clock_timeout", "step_timeout", "step_budget_exhausted"}
return {
"observation": observation,
"reward": reward,
"done": done,
"terminated": done,
"truncated": truncated,
"info": info,
}
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe: always answers with a fixed ``healthy`` status."""
    return dict(status="healthy")
@app.post("/env/reset")
def env_reset(request: ResetRequest) -> dict[str, Any]:
    """Reset the shared environment and return its first observation.

    Args:
        request: Validated reset options; each field is passed through to
            ``_ENV.reset`` under the same keyword name.

    Returns:
        ``{"observation": ...}`` where the value is the observation model
        dumped in JSON mode.
    """
    reset_kwargs = {
        "seed": request.seed,
        "difficulty": request.difficulty,
        "sub_environment": request.sub_environment,
        "scenario_id": request.scenario_id,
        "patient_id": request.patient_id,
    }
    initial_observation = _ENV.reset(**reset_kwargs)
    return {"observation": initial_observation.model_dump(mode="json")}
@app.post("/env/step")
def env_step(action: dict[str, Any]) -> dict[str, Any]:
    """Apply one action to the shared environment.

    Args:
        action: Raw action mapping, handed to ``_ENV.step`` as-is.

    Returns:
        The step payload built by ``_step_payload`` from the environment's
        ``(observation, reward, done, info)`` result.
    """
    next_observation, step_reward, is_done, step_info = _ENV.step(action)
    observation_json = next_observation.model_dump(mode="json")
    return _step_payload(
        observation=observation_json,
        reward=step_reward,
        done=is_done,
        info=step_info,
    )
@app.get("/env/state")
def env_state() -> dict[str, Any]:
    """Return the result of ``_ENV.get_state()`` for the shared environment."""
    current_state = _ENV.get_state()
    return current_state
@app.get("/env/trace")
def env_trace() -> list[dict[str, Any]]:
    """Return the result of ``_ENV.get_trace()`` for the shared environment."""
    trace_entries = _ENV.get_trace()
    return trace_entries
@app.get("/env/legal_actions")
def env_legal_actions() -> list[dict[str, Any]]:
    """Return the result of ``_ENV.get_legal_actions()`` for the shared environment."""
    legal_actions = _ENV.get_legal_actions()
    return legal_actions
@app.get("/env/reward_breakdown")
def env_reward_breakdown() -> dict[str, Any]:
    """Return the result of ``_ENV.get_reward_breakdown()`` for the shared environment."""
    breakdown = _ENV.get_reward_breakdown()
    return breakdown
@app.get("/env/uncertainty")
def env_uncertainty() -> dict[str, Any]:
    """Return the environment's uncertainty report dumped in JSON mode."""
    report = _ENV.get_uncertainty_report()
    return report.model_dump(mode="json")
@app.get("/env/metadata")
def env_metadata() -> dict[str, Any]:
    """Return the result of ``_ENV.get_metadata()`` for the shared environment."""
    metadata = _ENV.get_metadata()
    return metadata
@app.get("/schema")
def schema() -> dict[str, Any]:
    """Return the JSON schemas of the action, observation and state models.

    Returns:
        A dict keyed ``action``, ``observation`` and ``state`` (in that
        order), each holding the matching model's ``model_json_schema()``.
    """
    models_by_label = {
        "action": PolyGuardAction,
        "observation": PolyGuardObservation,
        "state": PolyGuardState,
    }
    return {label: model.model_json_schema() for label, model in models_by_label.items()}
@app.post("/mcp")
def mcp(payload: dict[str, Any]) -> dict[str, Any]:
    """Handle a JSON-RPC 2.0 style request for the environment tools.

    Supported ``method`` values:

    * ``"tools/list"``: describe the ``env.reset``, ``env.step``,
      ``env.state`` and ``env.metadata`` tools.
    * ``"tools/call"``: run the tool named in ``params["name"]`` with
      ``params["arguments"]``.
    * missing/empty method: report a small capabilities object.

    Args:
        payload: Request body; ``id``, ``method`` and ``params`` are read.
            A non-dict ``params`` (or ``arguments``) is treated as empty.

    Returns:
        ``{"jsonrpc": "2.0", "id": ..., "result": ...}`` on success, or the
        same envelope with an ``error`` object (code ``-32000`` and the
        exception text) when anything raises.
    """
    request_id = payload.get("id")
    method = str(payload.get("method", "") or "")
    raw_params = payload.get("params", {})
    params = raw_params if isinstance(raw_params, dict) else {}

    try:
        if method == "tools/list":
            reset_tool = {
                "name": "env.reset",
                "description": "Reset environment and return initial observation payload.",
                "inputSchema": {
                    "type": "object",
                    "properties": {
                        "seed": {"type": "integer"},
                        "difficulty": {"type": "string"},
                        "sub_environment": {"type": "string"},
                        "scenario_id": {"type": "string"},
                        "patient_id": {"type": "string"},
                    },
                },
            }
            step_tool = {
                "name": "env.step",
                "description": "Execute a policy action.",
                "inputSchema": PolyGuardAction.model_json_schema(),
            }
            state_tool = {
                "name": "env.state",
                "description": "Get current environment state.",
                "inputSchema": {"type": "object", "properties": {}},
            }
            metadata_tool = {
                "name": "env.metadata",
                "description": "Get environment metadata.",
                "inputSchema": {"type": "object", "properties": {}},
            }
            result = {"tools": [reset_tool, step_tool, state_tool, metadata_tool]}
        elif method == "tools/call":
            tool_name = str(params.get("name", "") or "")
            raw_arguments = params.get("arguments")
            arguments = raw_arguments if isinstance(raw_arguments, dict) else {}

            # Each entry is a zero-argument callable so validation errors are
            # raised inside this ``try`` and reported as a JSON-RPC error.
            tool_handlers = {
                "env.reset": lambda: env_reset(ResetRequest.model_validate(arguments)),
                "env.step": lambda: env_step(arguments),
                "env.state": env_state,
                "env.metadata": env_metadata,
            }
            handler = tool_handlers.get(tool_name)
            if handler is None:
                raise ValueError(f"Unknown tool name: {tool_name}")
            result = handler()
        elif method:
            raise ValueError(f"Unsupported method: {method}")
        else:
            # No method supplied: answer with the capabilities summary.
            result = {"capabilities": {"tools": True, "ws": True}}
    except Exception as exc:  # noqa: BLE001
        failure = {"code": -32000, "message": str(exc)}
        return {"jsonrpc": "2.0", "id": request_id, "error": failure}

    return {"jsonrpc": "2.0", "id": request_id, "result": result}
# OpenEnv baseline compatibility aliases.
@app.post("/reset")
def reset_alias(request: ResetRequest) -> dict[str, Any]:
    """Reset via ``env_reset`` and wrap the observation in a step-shaped payload.

    The payload always carries ``reward=0.5``, ``done=False`` and
    ``info={"reset": True}``.
    """
    reset_result = env_reset(request)
    initial_observation = reset_result["observation"]
    reset_info = {"reset": True}
    return _step_payload(
        observation=initial_observation,
        reward=0.5,
        done=False,
        info=reset_info,
    )
@app.post("/step")
def step_alias(action: dict[str, Any]) -> dict[str, Any]:
    """Alias of ``/env/step``: delegate to ``env_step`` with the same action."""
    step_result = env_step(action)
    return step_result
@app.get("/state")
def state_alias() -> dict[str, Any]:
    """Alias of ``/env/state``: delegate to ``env_state``."""
    current_state = env_state()
    return current_state
@app.get("/metadata")
def metadata_alias() -> dict[str, Any]:
    """Alias of ``/env/metadata``: delegate to ``env_metadata``."""
    metadata = env_metadata()
    return metadata
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Serve the shared environment over a WebSocket.

    Each incoming text frame must be a JSON object with a ``type`` of
    ``"reset"``, ``"step"``, ``"state"`` or ``"metadata"`` and an optional
    ``data`` value (a missing or falsy ``data`` is treated as ``{}``).

    A handled request is answered with ``{"type": "result", "data": ...}``.
    Any failure while handling a frame, including invalid JSON or a JSON
    value that is not an object, is answered with
    ``{"type": "error", "data": {"code": "EXECUTION_ERROR", "message": ...}}``
    and the connection stays open. The loop ends when the client disconnects.

    Args:
        websocket: The connection to accept and serve.
    """
    await websocket.accept()
    try:
        while True:
            # Receiving stays outside the per-message handler so that a client
            # disconnect ends the loop rather than being reported as an error.
            raw = await websocket.receive_text()
            try:
                # Parsing happens inside the handler: a malformed frame must
                # yield an error reply, not tear down the whole connection.
                message = json.loads(raw)
                if not isinstance(message, dict):
                    raise ValueError("Message must be a JSON object")
                msg_type = message.get("type")
                data = message.get("data", {}) or {}

                if msg_type == "reset":
                    request = ResetRequest.model_validate(data)
                    obs = _ENV.reset(
                        seed=request.seed,
                        difficulty=request.difficulty,
                        sub_environment=request.sub_environment,
                        scenario_id=request.scenario_id,
                        patient_id=request.patient_id,
                    )
                    # Same step-shaped reset payload as the HTTP ``/reset`` alias.
                    payload = _step_payload(
                        observation=obs.model_dump(mode="json"),
                        reward=0.5,
                        done=False,
                        info={"reset": True},
                    )
                elif msg_type == "step":
                    obs, reward, done, info = _ENV.step(data)
                    payload = _step_payload(
                        observation=obs.model_dump(mode="json"),
                        reward=reward,
                        done=done,
                        info=info,
                    )
                elif msg_type == "state":
                    payload = _ENV.get_state()
                elif msg_type == "metadata":
                    payload = _ENV.get_metadata()
                else:
                    raise ValueError(f"Unsupported message type: {msg_type}")
                await websocket.send_json({"type": "result", "data": payload})
            except WebSocketDisconnect:
                # The peer went away while we were replying; do not try to
                # send an error frame on a closed socket.
                raise
            except Exception as exc:  # noqa: BLE001
                await websocket.send_json(
                    {
                        "type": "error",
                        "data": {"code": "EXECUTION_ERROR", "message": str(exc)},
                    }
                )
    except WebSocketDisconnect:
        return
def main() -> None:
    """Run the service under uvicorn.

    The bind address comes from ``POLYGUARD_ENV_HOST`` (default
    ``127.0.0.1``) and ``POLYGUARD_ENV_PORT`` (default ``8100``); auto-reload
    is disabled.
    """
    bind_host = os.environ.get("POLYGUARD_ENV_HOST", "127.0.0.1")
    raw_port = os.environ.get("POLYGUARD_ENV_PORT", "8100")
    bind_port = int(raw_port)
    uvicorn.run(
        "app.env.fastapi_app:app",
        host=bind_host,
        port=bind_port,
        reload=False,
    )
# Start the server only when this file is executed directly as a script.
if __name__ == "__main__":
    main()