"""HTTP API for the Data Cleaning OpenEnv benchmark.

Each call to ``POST /reset`` creates a new ``DataCleaningEnv`` and stores it in
the module-level ``sessions`` dict under a freshly generated UUID.  The
``/step`` and ``/state`` endpoints then look the environment up by that id.

Two flavours of the step/state endpoints exist:

* path-based (``/step/{session_id}``, ``/state/{session_id}``), where the
  session id is mandatory, and
* "compat" (``/step``, ``/state``), where the session id may come from the
  JSON payload or the query string, or be omitted entirely when exactly one
  session is active.
"""

from __future__ import annotations

import uuid
from typing import Any, Dict, Optional

from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, ValidationError

from env.environment import DataCleaningEnv
from env.models import Action, Observation, StepResult
from env.tasks import list_tasks as list_task_specs

app = FastAPI(
    title="Data Cleaning OpenEnv Benchmark",
    version="1.0.0",
    description="LLM agent benchmark for real-world data cleaning tasks.",
)

# CORS is fully open: any origin, method and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# In-memory registry of live environments, keyed by session id.  Entries are
# added by ``reset`` and removed only by ``delete_session``.
sessions: Dict[str, DataCleaningEnv] = {}


@app.get("/")
def root():
    """Return the service name, version, task list and an endpoint overview."""
    tasks = list_task_specs()
    return {
        "name": "Data Cleaning OpenEnv Benchmark",
        "version": "1.0.0",
        "tasks": tasks,
        "api": {
            "reset": "POST /reset",
            "step": "POST /step/{session_id}",
            "step_compat": "POST /step",
            "state": "GET /state/{session_id}",
            "state_compat": "GET /state?session_id=...",
            "metadata": "GET /metadata",
            "schema": "GET /schema",
            "mcp": "GET|POST /mcp",
            "health": "GET /health",
        },
    }


@app.get("/health")
def health():
    """Report liveness together with the number of sessions currently stored."""
    return {"status": "ok", "sessions_active": len(sessions)}


# Request body for ``POST /reset``.
# ``task_id`` is forwarded unchanged to ``DataCleaningEnv.reset``; it defaults
# to ``None`` when the client sends no body or omits the field.
# (Kept as comments rather than a class docstring so the JSON schema served by
# ``/schema`` does not gain a ``description`` entry.)
class ResetRequest(BaseModel):
    task_id: Optional[str] = None


@app.post("/reset")
def reset(body: ResetRequest = ResetRequest()):
    """Create a new environment, reset it, and register it as a session.

    Args:
        body: Optional request body carrying the ``task_id`` to reset with.

    Returns:
        A dict with the new ``session_id``, the initial observation, a zero
        reward, ``done=False`` and an ``info`` dict of counters read from the
        environment immediately after the reset.
    """
    session_id = str(uuid.uuid4())
    env = DataCleaningEnv()
    obs = env.reset(task_id=body.task_id)
    # Register only after a successful reset so a failing reset leaves no
    # half-initialised session behind.
    sessions[session_id] = env
    return {
        "session_id": session_id,
        "observation": obs.model_dump(),
        "reward": 0.0,
        "done": False,
        "info": {
            "error": None,
            "cumulative_reward": env.cumulative_reward,
            "raw_cumulative_reward": env.raw_cumulative_reward,
            "final_score": env.final_score,
            "step": env.step_count,
        },
    }


@app.post("/step")
def step_compat(
    payload: Dict[str, Any],
    session_id: Optional[str] = Query(default=None),
):
    """Apply one action to a session, with a lenient request format.

    The session id is taken from ``payload["session_id"]`` if truthy, else
    from the ``session_id`` query parameter, else inferred when exactly one
    session exists.  The action is ``payload["action"]`` when that key is
    present; otherwise the payload itself is treated as the action.

    Args:
        payload: JSON object containing the action (nested or flat).
        session_id: Optional session id supplied as a query parameter.

    Returns:
        The environment's step result, serialised with ``model_dump()``.

    Raises:
        HTTPException: 400 if no session id can be resolved, if the action is
            not an object, or if it lacks ``"type"``; 422 if the action fails
            ``Action`` validation; 404 if the session does not exist.
    """
    payload_session_id = payload.get("session_id")
    resolved_session_id = _resolve_session_id(payload_session_id or session_id)

    action_payload = payload.get("action", payload)
    if not isinstance(action_payload, dict):
        raise HTTPException(status_code=400, detail="Action payload must be an object")
    if "type" not in action_payload:
        raise HTTPException(status_code=400, detail="Action payload requires 'type'")

    # The model is built by hand here, so FastAPI's automatic request
    # validation does not apply.  Without this guard a malformed action would
    # escape as an unhandled exception (HTTP 500) instead of a client error.
    try:
        action = Action(**action_payload)
    except ValidationError as exc:
        raise HTTPException(
            status_code=422,
            # Keep only plain, JSON-serialisable fields of each error.
            detail=[
                {"loc": list(err["loc"]), "msg": err["msg"], "type": err["type"]}
                for err in exc.errors()
            ],
        ) from exc

    env = _get_session(resolved_session_id)
    result = env.step(action)
    return result.model_dump()


@app.post("/step/{session_id}")
def step(session_id: str, action: Action):
    """Apply ``action`` to the session named in the URL path.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    env = _get_session(session_id)
    result = env.step(action)
    return result.model_dump()


@app.get("/state")
def state_compat(session_id: Optional[str] = Query(default=None)):
    """Return ``env.state()`` for the given or the single active session.

    Raises:
        HTTPException: 400 if the id is omitted and there is not exactly one
            active session; 404 if the session does not exist.
    """
    env = _get_session(_resolve_session_id(session_id))
    return env.state()


@app.get("/state/{session_id}")
def state(session_id: str):
    """Return ``env.state()`` for the session named in the URL path.

    Raises:
        HTTPException: 404 if the session does not exist.
    """
    env = _get_session(session_id)
    return env.state()


@app.get("/metadata")
def metadata():
    """Return benchmark metadata: tasks, score range and entry-point paths."""
    return {
        "name": "data-cleaning-benchmark",
        "version": "1.0.0",
        "description": "LLM agent benchmark for real-world data cleaning tasks.",
        "tasks": list_task_specs(),
        "score_range": {
            "min": DataCleaningEnv.MIN_EPISODE_SCORE,
            "max": DataCleaningEnv.MAX_EPISODE_SCORE,
        },
        "entrypoints": {
            "reset": "/reset",
            "step": "/step",
            "state": "/state",
            "health": "/health",
            "tasks": "/tasks",
            "schema": "/schema",
            "mcp": "/mcp",
        },
    }


@app.get("/schema")
def schema():
    """Return the JSON schemas of the action, observation, step-result and reset models."""
    return {
        "action": Action.model_json_schema(),
        "observation": Observation.model_json_schema(),
        "step_result": StepResult.model_json_schema(),
        "reset_request": ResetRequest.model_json_schema(),
    }


@app.api_route("/mcp", methods=["GET", "POST"])
def mcp_metadata():
    """Answer GET/POST on ``/mcp`` with a fixed "not supported" payload."""
    return {
        "supported": False,
        "message": "This benchmark exposes simulation HTTP endpoints (reset/step/state).",
    }


@app.delete("/session/{session_id}")
def delete_session(session_id: str):
    """Remove a session from the registry.

    Unknown ids are ignored, so the same response is returned whether or not
    the session existed.
    """
    sessions.pop(session_id, None)
    return {"deleted": session_id}


@app.get("/tasks")
def list_tasks():
    """Return the available task specifications under the ``tasks`` key."""
    return {"tasks": list_task_specs()}


def _resolve_session_id(session_id: Optional[str]) -> str:
    """Return ``session_id``, or the only active session's id when it is falsy.

    Args:
        session_id: Explicit id supplied by the client, or ``None``/empty.

    Returns:
        The id to use for the session lookup.

    Raises:
        HTTPException: 400 if no id was given and the number of active
            sessions is not exactly one.
    """
    if session_id:
        return session_id
    if len(sessions) == 1:
        # Iterating a dict yields its keys; there is exactly one here.
        return next(iter(sessions))
    raise HTTPException(
        status_code=400,
        detail="session_id is required when there is not exactly one active session",
    )


def _get_session(session_id: str) -> DataCleaningEnv:
    """Look up the environment registered under ``session_id``.

    Raises:
        HTTPException: 404 if no such session is registered.
    """
    env = sessions.get(session_id)
    if env is None:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found")
    return env