# GSD Bot
# feat: apply master fix to HF environment
# 7a529e5
import os
import uuid
from collections import OrderedDict
from typing import Dict
from fastapi import FastAPI, Request
from pydantic import ValidationError
from aegis_env.environment import AEGISEnvironment
from aegis_env.models import AEGISAction
# Settings passed to every AEGISEnvironment this module builds. They are read
# once, at import time, from the process environment.
scenario_dir = os.environ.get("SCENARIO_DIR")
worker_mode = os.environ.get("WORKER_MODE", "scripted")
# Only the string "true" (in any letter case) turns memory on.
memory_enabled = os.environ.get("MEMORY_ENABLED", "true").lower() == "true"
seed = int(os.environ.get("SEED", "42"))

# Cap on how many environments the registry below may hold at once.
MAX_SESSIONS = 100

# Session registry: one environment instance per client session id. An
# OrderedDict is used so the oldest entry can be popped from the front.
_sessions: OrderedDict[str, AEGISEnvironment] = OrderedDict()
def _get_or_create_env(session_id: str) -> AEGISEnvironment:
    """Return the environment registered under ``session_id``, creating it if absent.

    On a miss, a new ``AEGISEnvironment`` is built from the module-level
    settings and registered; if that pushes the registry past
    ``MAX_SESSIONS``, the entry at the front (least recently used) is dropped.
    On a hit, the session is moved to the back of the registry so it becomes
    the most recently used entry.
    """
    if session_id not in _sessions:
        fresh_env = AEGISEnvironment(
            scenario_dir=scenario_dir,
            worker_mode=worker_mode,
            memory_enabled=memory_enabled,
            seed=seed,
        )
        _sessions[session_id] = fresh_env
        # The new entry sits at the back, so popping the front never removes it.
        if len(_sessions) > MAX_SESSIONS:
            _sessions.popitem(last=False)
        return fresh_env

    # Known session: refresh its recency before handing it back.
    _sessions.move_to_end(session_id)
    return _sessions[session_id]
# Application object on which the route decorators in this module register
# their handlers.
app = FastAPI(
    title="AEGIS-Env",
    description="OpenEnv backend for RL model oversight.",
)
@app.get("/")
async def root():
    """Return static service metadata: name, description, version, endpoint summary, and docs/health paths."""
    # Human-readable summary of the POST endpoints, keyed by "METHOD path".
    endpoint_help = {
        "POST /reset": "Start a new episode (returns session_id)",
        "POST /step": "Execute an action (body: {session_id, decision, confidence, violation_type, explanation})",
    }
    service_info = {
        "name": "AEGIS-Env",
        "description": "OpenEnv backend for RL model oversight",
        "version": "1.0",
        "endpoints": endpoint_help,
        "docs": "/docs",
        "health": "/health",
    }
    return service_info
@app.get("/health")
async def health():
    """Return a constant ``{"status": "ok"}`` payload."""
    status_payload = {"status": "ok"}
    return status_payload
@app.post("/reset")
async def reset_env(request: Request):
    """Reset the environment for a session and return its initial observation.

    The request body is optional. When it is a JSON object with a truthy
    ``session_id``, that id is used to look up (or create) the session's
    environment; otherwise a new UUID4 string is generated as the session id.

    Args:
        request: Incoming HTTP request; its body is parsed as JSON if possible.

    Returns:
        A dict with ``session_id`` and the ``observation`` and ``info`` values
        returned by ``env.reset()``.
    """
    # Best-effort parse: a missing or malformed body is treated as "no options
    # supplied" rather than as an error, so a bare POST /reset still works.
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON is not necessarily an object (``null``, ``[]``, ``"x"``, ``3``).
    # Those values have no ``.get`` and would otherwise raise AttributeError,
    # surfacing as a 500; treat them like an empty body instead.
    if not isinstance(body, dict):
        body = {}

    session_id = body.get("session_id") or str(uuid.uuid4())
    env = _get_or_create_env(session_id)
    obs, info = env.reset()
    return {"session_id": session_id, "observation": obs, "info": info}
@app.post("/step")
async def step_env(request: Request):
    """Apply one action to a session's environment and return the step result.

    The JSON body carries ``session_id`` (defaulting to ``"default"``) and the
    action fields. Every key except ``session_id`` is passed to ``AEGISAction``
    for validation. If validation fails, a fixed fallback action is stepped
    instead and marked with ``"__valid__": False``; a validated action is
    marked with ``"__valid__": True``.

    Args:
        request: Incoming HTTP request; its body is parsed as JSON if possible.

    Returns:
        A dict with ``session_id``, ``observation``, ``reward`` (converted to
        ``float``), ``done`` and ``info`` taken from ``env.step(...)``.
    """
    # Best-effort parse: an unreadable body is handled as an empty object,
    # which then goes through the normal validation path below.
    try:
        body = await request.json()
    except Exception:
        body = {}
    # Valid JSON is not necessarily an object (``null``, ``[]``, ``"x"``, ``3``).
    # Those values have no ``.get``/``.items`` and would otherwise raise
    # AttributeError, surfacing as a 500; treat them like an empty body.
    if not isinstance(body, dict):
        body = {}

    session_id = body.get("session_id", "default")
    env = _get_or_create_env(session_id)

    # session_id addresses the session; it is not part of the action itself.
    action_fields = {k: v for k, v in body.items() if k != "session_id"}
    try:
        validated = AEGISAction(**action_fields)
        action_dict = validated.model_dump()
        action_dict["__valid__"] = True
    except (ValidationError, TypeError):
        # Invalid action: still advance the environment, using this fixed
        # fallback and flagging it so the environment can tell it apart.
        action_dict = {
            "decision": "ALLOW",
            "confidence": 0.5,
            "violation_type": "none",
            "explanation": "",
            "__valid__": False,
        }

    obs, reward, done, info = env.step(action_dict)
    return {
        "session_id": session_id,
        "observation": obs,
        "reward": float(reward),
        "done": done,
        "info": info,
    }
def main():
    """Run ``app`` under uvicorn on host 0.0.0.0, using the ``PORT`` env var (default 7860)."""
    # Local import: uvicorn is only loaded when the server is actually started.
    import uvicorn

    listen_port = int(os.environ.get("PORT", "7860"))
    server_options = {"host": "0.0.0.0", "port": listen_port}
    uvicorn.run(app, **server_options)


if __name__ == "__main__":
    main()