Spaces:
Runtime error
Runtime error
File size: 3,072 Bytes
ab65ac6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 | import os
import uuid
from collections import OrderedDict
from typing import Dict
from fastapi import FastAPI, Request
from pydantic import ValidationError
from aegis_env.environment import AEGISEnvironment
from aegis_env.models import AEGISAction
# Runtime settings, read once from the process environment at import time.
scenario_dir = os.environ.get("SCENARIO_DIR")  # None when the variable is unset
worker_mode = os.environ.get("WORKER_MODE", "scripted")
# Enabled only when MEMORY_ENABLED equals "true" (case-insensitive); the
# default value is "true".
memory_enabled = "true" == os.environ.get("MEMORY_ENABLED", "true").lower()
# A non-integer SEED raises ValueError here, i.e. at import.
seed = int(os.environ.get("SEED", "42"))
# Cap on the number of sessions held in the registry at once.
MAX_SESSIONS = 100
# Session registry: maps a session id to that session's own environment
# instance. An OrderedDict is used so entries can be reordered on access
# and removed from the front.
_sessions: OrderedDict[str, AEGISEnvironment] = OrderedDict()
def _get_or_create_env(session_id: str) -> AEGISEnvironment:
    """Return the environment registered for *session_id*, creating it if absent.

    A known id is moved to the end of the registry (most recently used) and
    its environment is returned. An unknown id gets a new
    ``AEGISEnvironment`` built from the module-level settings; after it is
    registered, the entry at the front of the registry is evicted if the
    registry holds more than ``MAX_SESSIONS`` entries.
    """
    if session_id not in _sessions:
        settings = {
            "scenario_dir": scenario_dir,
            "worker_mode": worker_mode,
            "memory_enabled": memory_enabled,
            "seed": seed,
        }
        fresh_env = AEGISEnvironment(**settings)
        _sessions[session_id] = fresh_env
        if MAX_SESSIONS < len(_sessions):
            # Front of the OrderedDict is the least recently used session.
            _sessions.popitem(last=False)
        return fresh_env

    # Known session: refresh its recency before handing it back.
    _sessions.move_to_end(session_id)
    return _sessions[session_id]
# Application object that the route decorators in this module register on.
app = FastAPI(
    title="AEGIS-Env",
    description="OpenEnv backend for RL model oversight.",
)
@app.get("/")
async def root():
    # Landing payload: service identity, a short index of the POST
    # endpoints, and pointers to the docs and health routes.
    endpoint_index = {
        "POST /reset": "Start a new episode (returns session_id)",
        "POST /step": "Execute an action (body: {session_id, decision, confidence, violation_type, explanation})",
    }
    service_info = {
        "name": "AEGIS-Env",
        "description": "OpenEnv backend for RL model oversight",
        "version": "1.0",
        "endpoints": endpoint_index,
        "docs": "/docs",
        "health": "/health",
    }
    return service_info
@app.get("/health")
async def health():
    # Liveness probe: responds with a fixed payload and touches no state.
    payload = {"status": "ok"}
    return payload
@app.post("/reset")
async def reset_env(request: Request):
    # Start a new episode.
    #
    # The JSON body is optional. If it is an object with a truthy
    # "session_id", that session's environment is reused (or created under
    # that id); otherwise a fresh UUID4 string is generated as the id.
    # Responds with the session id plus the observation and info returned
    # by env.reset().
    try:
        body = await request.json()
    except Exception:
        # Best-effort parse: a missing or malformed body is treated the
        # same as an empty object, since every field is optional here.
        body = {}
    if not isinstance(body, dict):
        # Valid JSON does not have to be an object (`null`, `[]`, `"x"`,
        # `3`). Such a body has no .get(); treat it as empty instead of
        # failing with AttributeError and a 500 response.
        body = {}

    session_id = body.get("session_id") or str(uuid.uuid4())
    env = _get_or_create_env(session_id)
    obs, info = env.reset()
    return {"session_id": session_id, "observation": obs, "info": info}
@app.post("/step")
async def step_env(request: Request):
    # Apply one action to a session's environment.
    #
    # The JSON body carries "session_id" (defaults to "default" when the
    # key is absent) and the action fields. The action fields are
    # validated through AEGISAction; when validation fails, a fixed
    # fallback action flagged "__valid__": False is stepped instead.
    # Responds with the session id and the observation, reward (as float),
    # done flag and info returned by env.step().
    try:
        body = await request.json()
    except Exception:
        # Missing or malformed body: continue with no fields, which ends
        # up on the fallback-action path below.
        body = {}
    if not isinstance(body, dict):
        # Valid JSON does not have to be an object (`null`, `[]`, `"x"`,
        # `3`). Such a body has no .get()/.items(); treat it as empty
        # instead of failing with AttributeError and a 500 response.
        body = {}

    session_id = body.get("session_id", "default")
    env = _get_or_create_env(session_id)

    # Everything except the session id is forwarded to the action model.
    action_fields = {k: v for k, v in body.items() if k != "session_id"}
    try:
        validated = AEGISAction(**action_fields)
        action_dict = validated.model_dump()
        action_dict["__valid__"] = True
    except (ValidationError, TypeError):
        # The request is not rejected: a fixed fallback action is stepped,
        # marked as invalid via the "__valid__" flag.
        # NOTE(review): how the environment uses "__valid__" is not
        # visible in this file - confirm against AEGISEnvironment.step.
        action_dict = {
            "decision": "ALLOW",
            "confidence": 0.5,
            "violation_type": "none",
            "explanation": "",
            "__valid__": False,
        }

    obs, reward, done, info = env.step(action_dict)
    return {
        "session_id": session_id,
        "observation": obs,
        "reward": float(reward),
        "done": done,
        "info": info,
    }
def main():
    """Serve ``app`` with uvicorn on host 0.0.0.0.

    The listening port is read from the ``PORT`` environment variable and
    falls back to 7860 when it is unset.
    """
    # Imported here rather than at module top; only needed when serving.
    import uvicorn

    listen_port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, port=listen_port, host="0.0.0.0")
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()
|