# it-support-triage / server.py
# Author: kevanthonyP
# Commit ab6a3f1: "Update server.py" (verified)
"""
server.py — FastAPI server exposing the OpenEnv HTTP API.
"""
import logging
import os
import traceback
from typing import Optional

from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel, ValidationError

from env_models import TriageAction, StepResult, EnvState, TicketObservation
from env_core import ITSupportEnv
# Application object. The title, description and version are surfaced in the
# generated OpenAPI schema and in the Swagger UI served at /docs.
app = FastAPI(
    title="IT Support Triage — OpenEnv",
    description=(
        "An OpenEnv-compliant RL environment for training and evaluating agents "
        "on IT helpdesk ticket triage tasks. Includes 3 tasks (easy → medium → hard) "
        "with deterministic graders and safety-aware reward functions."
    ),
    version="1.0.0",
)

# Permissive CORS so browser-based clients on any origin can call the API.
# allow_credentials is left at its default (False); that is what makes the
# "*" wildcards acceptable here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# One environment instance shared by every endpoint and every client: all
# callers of this server drive the same episode, so a /reset from one client
# replaces the episode another client may be stepping through.
env = ITSupportEnv()
class ResetRequest(BaseModel):
    """JSON body accepted by ``POST /reset``.

    Attributes are all optional: ``task_id`` names the task to start and
    falls back to ``"task_easy"`` when the field is omitted. Because the type
    is ``Optional``, an explicit ``null`` also validates.
    """

    # Identifier of the task to start (e.g. task_easy / task_medium / task_hard).
    task_id: Optional[str] = "task_easy"
class StepRequest(BaseModel):
    """JSON body accepted by ``POST /step``: a single wrapped agent action."""

    # The triage action to apply; validated against the TriageAction model.
    action: TriageAction
# ─── Endpoints ────────────────────────────────────────────────────────────────
@app.get("/")
def root():
    """Describe the service: its name, version, status and available routes."""
    route_help = {
        "POST /reset": "Start episode. Body: {task_id: task_easy|task_medium|task_hard}",
        "POST /step": "Submit action. Body: {action: {...}}",
        "GET /state": "Current environment state",
        "GET /tasks": "List all tasks",
        "GET /health": "Health check",
        "GET /docs": "Interactive API docs (Swagger UI)",
    }
    return {
        "environment": "it-support-triage",
        "version": "1.0.0",
        "status": "ok",
        "endpoints": route_help,
    }
@app.get("/health")
def health():
    """Health check: report that the server is up, with its name and version."""
    payload = dict(status="ok", environment="it-support-triage", version="1.0.0")
    return payload
@app.post("/reset", response_model=TicketObservation)
async def reset(request: Optional[ResetRequest] = Body(default=None)):
    """Start a new episode on the shared environment and return its first observation.

    The JSON body is optional. A missing body, a missing ``task_id`` or an
    explicit ``null`` all fall back to ``"task_easy"``.

    Raises:
        HTTPException(400): ``env.reset`` raised ``ValueError``; its message is
            returned to the client (presumably an unknown task_id — confirm in
            env_core).
        HTTPException(500): any other failure. The traceback is written to the
            server log rather than returned to the client.
    """
    task_id = (request.task_id if request else None) or "task_easy"
    # NOTE: env.reset is synchronous and is called directly from this async
    # handler, so it runs on the event loop thread and blocks it until done.
    try:
        return env.reset(task_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        # Stack traces expose file paths and source lines; keep them in the
        # server log and send the client only a short, non-sensitive message.
        logging.getLogger(__name__).exception(
            "Unhandled error in /reset (task_id=%r)", task_id
        )
        raise HTTPException(
            status_code=500,
            detail=f"Internal error while resetting the environment ({type(e).__name__}).",
        ) from e
@app.post("/step", response_model=StepResult)
def step(request: StepRequest):
    """Apply one agent action to the current episode and return the step result.

    Raises:
        HTTPException(400): ``env.step`` raised ``RuntimeError``; its message is
            returned to the client (presumably e.g. stepping with no active
            episode — confirm in env_core).
        HTTPException(422): ``env.step`` raised a pydantic ``ValidationError``.
        HTTPException(500): any other failure. The traceback is written to the
            server log rather than returned to the client.
    """
    try:
        return env.step(request.action)
    except RuntimeError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e
    except Exception as e:
        # Stack traces expose file paths and source lines; keep them in the
        # server log and send the client only a short, non-sensitive message.
        logging.getLogger(__name__).exception("Unhandled error in /step")
        raise HTTPException(
            status_code=500,
            detail=f"Internal error while stepping the environment ({type(e).__name__}).",
        ) from e
@app.get("/state", response_model=EnvState)
def state():
    """Return the environment's current state as reported by ``env.state()``."""
    current = env.state()
    return current
@app.get("/tasks")
def list_tasks():
    """Return the environment's task list, wrapped under a ``"tasks"`` key."""
    available = env.list_tasks()
    return dict(tasks=available)
if __name__ == "__main__":
    import uvicorn

    # Port comes from $PORT, defaulting to 7860. An empty value (e.g. `PORT=`)
    # would make int() raise, so it is treated the same as unset.
    port = int(os.environ.get("PORT") or 7860)
    # Pass the app object rather than the "server:app" import string. This file
    # is running as module "__main__", so the string form makes uvicorn import
    # it a second time as "server", re-running all module-level code: a second
    # FastAPI app and a second ITSupportEnv. The import string is only required
    # for reload/workers, and neither is used here.
    uvicorn.run(app, host="0.0.0.0", port=port, reload=False)