Dev Shah
fix: make task parsing robust to string IDs for openenv task validation
5107b13
"""
FastAPI server exposing the Email Triage environment via HTTP.
Endpoints mirror the OpenEnv spec.
"""
from fastapi import FastAPI, HTTPException, Body
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Optional, Union
import uvicorn
import os
import sys
# Ensure the root directory is in sys.path so environment.py can be imported
sys.path.insert(0, os.path.dirname(os.path.dirname(__file__)))
from environment import EmailTriageEnv, Action
app = FastAPI(version="1.0.0", title="Email Triage Environment")

# Wildcard CORS: any origin may call any endpoint with any method/header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Live environments keyed by numeric task ID. /reset creates or replaces the
# entry for a task; the other endpoints look it up via _get_env.
_envs: dict[int, EmailTriageEnv] = {}
def _parse_task(task: Union[int, str]) -> int:
if isinstance(task, str):
if task.startswith("task"):
return int(task[4:])
return int(task)
return task
class ResetRequest(BaseModel):
    """Body of ``POST /reset``.

    ``task`` may be an int or a string such as ``"task1"``; the handler
    converts it with ``_parse_task``.
    """

    task: Union[int, str] = Field(default=1)
class StepRequest(BaseModel):
    """Body of ``POST /step``: the target task plus the action to apply.

    ``task`` accepts the same int-or-string forms as ``ResetRequest.task``.
    """

    task: Union[int, str] = Field(default=1)
    action: Action
def _get_env(task: Union[int, str]) -> EmailTriageEnv:
    """Return the environment that ``/reset`` created for ``task``.

    Raises:
        HTTPException: 400 when no environment has been created yet for
            the task.
    """
    key = _parse_task(task)
    env = _envs.get(key)
    if env is None:
        raise HTTPException(
            status_code=400,
            detail=f"Task {key} not initialised. Call /reset first.",
        )
    return env
@app.get("/health")
def health():
    """Report that the server is up; the response is always ``{"status": "ok"}``."""
    return dict(status="ok")
@app.post("/reset")
def reset(req: Optional[ResetRequest] = Body(default=None)):
    """Create a fresh environment for the requested task and reset it.

    The request body is optional; without one, task 1 is used. The new
    environment is registered (replacing any previous one for that task)
    only after its ``reset()`` call returns.
    """
    raw_task = 1 if req is None else req.task
    task_id = _parse_task(raw_task)
    fresh_env = EmailTriageEnv(task=task_id)
    first_obs = fresh_env.reset()
    _envs[task_id] = fresh_env
    return {"observation": first_obs.model_dump(), "state": fresh_env.state()}
@app.post("/step")
def step(req: StepRequest):
    """Apply one action to the task's environment and report the outcome.

    The response contains the new observation, reward, done flag, info and
    the environment's score as read after the step.
    """
    env = _get_env(req.task)
    outcome = env.step(req.action)
    payload = {
        "observation": outcome.observation.model_dump(),
        "reward": outcome.reward,
        "done": outcome.done,
        "info": outcome.info,
    }
    payload["score"] = env.score()
    return payload
@app.get("/state")
def state(task: Union[int, str] = 1):
    """Return the current environment state and score for ``task``."""
    env = _get_env(task)
    current_state = env.state()
    current_score = env.score()
    return {"state": current_state, "score": current_score}
@app.get("/score")
def score(task: Union[int, str] = 1):
    """Return the current score for ``task`` along with its numeric task ID."""
    task_id = _parse_task(task)
    env = _get_env(task_id)
    return {"score": env.score(), "task": task_id}
def main(host: str = "0.0.0.0", port: int = 7860) -> None:
    """Serve the FastAPI app with uvicorn.

    Args:
        host: Interface to bind. The default "0.0.0.0" listens on all
            interfaces, as before.
        port: TCP port to listen on (default 7860, as before).
    """
    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    main()