Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI server exposing the Email Triage environment via HTTP. | |
| Endpoints mirror the OpenEnv spec. | |
| """ | |
| from fastapi import FastAPI, HTTPException, Body | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel, Field | |
| from typing import Optional, Union | |
| import uvicorn | |
| import os | |
| import sys | |
# The directory two levels above this file is the project root; putting it at
# the front of sys.path lets the ``environment`` import below resolve no
# matter which directory the server is launched from.
_PROJECT_ROOT = os.path.split(os.path.split(__file__)[0])[0]
sys.path.insert(0, _PROJECT_ROOT)
| from environment import EmailTriageEnv, Action | |
app = FastAPI(
    title="Email Triage Environment",
    version="1.0.0",
)

# Permit cross-origin requests from any origin, with any method and header.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)

# Live environments keyed by integer task id. Resetting a task creates (or
# replaces) its entry; the remaining handlers look environments up here.
_envs: dict[int, EmailTriageEnv] = dict()
| def _parse_task(task: Union[int, str]) -> int: | |
| if isinstance(task, str): | |
| if task.startswith("task"): | |
| return int(task[4:]) | |
| return int(task) | |
| return task | |
# Optional request body for the reset endpoint. ``task`` selects which
# environment to (re)create and defaults to 1; it may be an int or a string
# such as "task2".
class ResetRequest(BaseModel):
    task: Union[int, str] = Field(default=1)
# Request body for the step endpoint. ``task`` selects an environment that was
# previously reset (int, or a string such as "task2"); ``action`` is handed
# to that environment's ``step`` method.
class StepRequest(BaseModel):
    task: Union[int, str] = Field(default=1)
    action: Action
def _get_env(task: Union[int, str]) -> EmailTriageEnv:
    """Return the environment stored for ``task`` by the last reset.

    Raises:
        HTTPException: 400 if ``task`` is not a valid task identifier, or if
            no environment has been initialised for it yet.
    """
    try:
        task_int = _parse_task(task)
    except ValueError as exc:
        # A malformed id is a client error; left unhandled, FastAPI would
        # answer with a 500 Internal Server Error.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    env = _envs.get(task_int)
    if env is None:
        raise HTTPException(status_code=400, detail=f"Task {task_int} not initialised. Call /reset first.")
    return env
@app.get("/health")
def health():
    """Liveness probe; always returns ``{"status": "ok"}``."""
    # Without the route decorator this handler was never registered on `app`.
    return {"status": "ok"}
@app.post("/reset")
def reset(req: Optional[ResetRequest] = Body(default=None)):
    """Create a fresh environment for a task and return its first observation.

    The request body is optional; when it is omitted, task 1 is used. Any
    environment previously stored for the same task is replaced, but only
    after the new one has been reset successfully.

    Returns:
        dict with ``observation`` (the initial observation, dumped to a dict)
        and ``state`` (the value of ``env.state()``).

    Raises:
        HTTPException: 400 if the task identifier cannot be parsed.
    """
    raw_task = req.task if req is not None else 1
    try:
        task = _parse_task(raw_task)
    except ValueError as exc:
        # A malformed task id is a client error, not a server failure.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    env = EmailTriageEnv(task=task)
    obs = env.reset()
    _envs[task] = env
    return {"observation": obs.model_dump(), "state": env.state()}
@app.post("/step")
def step(req: StepRequest):
    """Apply one action to a previously reset task environment.

    Returns:
        dict with the resulting ``observation`` (dumped to a dict), the
        step's ``reward``, ``done`` flag and ``info``, plus the environment's
        current ``score``.

    Raises:
        HTTPException: 400 (from ``_get_env``) if the task was not reset first.
    """
    env = _get_env(req.task)
    result = env.step(req.action)
    return {
        "observation": result.observation.model_dump(),
        "reward": result.reward,
        "done": result.done,
        "info": result.info,
        "score": env.score(),
    }
@app.get("/state")
def state(task: Union[int, str] = 1):
    """Return the current state and score of a task's environment.

    ``task`` is read from the query string and defaults to 1.

    Raises:
        HTTPException: 400 (from ``_get_env``) if the task was not reset first.
    """
    env = _get_env(task)
    return {"state": env.state(), "score": env.score()}
@app.get("/score")
def score(task: Union[int, str] = 1):
    """Return the current score of a task's environment with its integer id.

    ``task`` is read from the query string and defaults to 1.

    Raises:
        HTTPException: 400 (from ``_get_env``) if the task was not reset first.
    """
    env = _get_env(task)
    # _get_env has already parsed `task` successfully, so this cannot fail.
    return {"score": env.score(), "task": _parse_task(task)}
def main():
    """Serve the app with uvicorn on all network interfaces.

    The port defaults to 7860 and can be overridden with the ``PORT``
    environment variable (a non-integer value raises ValueError at startup).
    """
    port = int(os.environ.get("PORT", "7860"))
    uvicorn.run(app, host="0.0.0.0", port=port)


if __name__ == "__main__":
    main()