Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI application exposing the OpenEnv-compatible HTTP API. | |
| Endpoints: GET /health, GET /metadata, GET /schema, | |
| POST /reset, POST /step, GET /state, POST /state, GET /docs | |
| """ | |
| from typing import Any, Dict, Optional | |
| from fastapi import Body, FastAPI, HTTPException | |
| from pydantic import BaseModel | |
| import uvicorn | |
| from models import DataCleaningAction, DataCleaningObservation, DataCleaningState | |
| from server.environment import DataCleaningEnvironment | |
# ASGI application object. The title, description and version given here
# are the API metadata passed to FastAPI.
app = FastAPI(
    version="0.1.0",
    title="Data Cleaning OpenEnv",
    description="A real-world data cleaning environment for AI agent training.",
)

# One environment object is created at import time and shared by every
# request handler in this module, which makes the server stateful.
env = DataCleaningEnvironment()
class ResetRequest(BaseModel):
    # Request body accepted by the reset handler.
    #
    # task_id: task to load. The value (including None, the default) is
    # forwarded unchanged to env.reset(), which decides what None means.
    task_id: Optional[int] = None
class StepResponse(BaseModel):
    # Response envelope built by both the reset and the step handlers.

    # Observation object returned by the environment call.
    observation: DataCleaningObservation
    # The handlers copy this from the observation's own ``reward`` field.
    reward: float
    # Episode-finished flag: reset always sends False, step sends obs.done.
    done: bool
    # Free-form extras; no handler in this module populates it.
    info: dict = {}
# ------------------------------------------------------------------
# Routes
# ------------------------------------------------------------------
@app.get("/health")
def health():
    """Liveness probe for GET /health.

    Returns:
        The constant payload ``{"status": "healthy"}``; no environment
        state is consulted.
    """
    return {"status": "healthy"}
@app.get("/metadata")
def metadata():
    """Describe the environment for GET /metadata.

    Returns:
        A static dict with the environment name, description, version,
        tags, and the list of available tasks (id, name, difficulty).
    """
    # NOTE(review): task ids here are strings ("task1", ...) while
    # ResetRequest.task_id is an integer -- confirm how clients are
    # expected to map one onto the other.
    tasks = [
        {"id": "task1", "name": "Fill Missing Values", "difficulty": "easy"},
        {"id": "task2", "name": "Fix Formats and Remove Duplicates", "difficulty": "medium"},
        {"id": "task3", "name": "Full Cleaning Pipeline", "difficulty": "hard"},
    ]
    return {
        "name": "data-cleaning-env",
        "description": (
            "A real-world data cleaning environment where an AI agent fixes "
            "missing values, duplicate rows, format inconsistencies, outliers, "
            "and dtype errors across three progressively harder tasks."
        ),
        # Keep in sync with the version passed to FastAPI(...) above.
        "version": "0.1.0",
        "tags": ["openenv", "data-cleaning", "rl", "real-world"],
        "tasks": tasks,
    }
@app.get("/schema")
def schema():
    """Return hand-written JSON-schema-style descriptions for GET /schema.

    Returns:
        A dict with three entries -- ``"action"``, ``"observation"`` and
        ``"state"`` -- each an object schema listing its properties.
        The schemas are static literals; they are not derived from the
        pydantic models, so they must be updated by hand when the
        models change.
    """
    # Operations an action may name; "operation" is the only required field.
    operations = [
        "fill_missing",
        "drop_duplicates",
        "fix_format",
        "replace_value",
        "drop_outliers",
        "fix_dtype",
    ]
    return {
        "action": {
            "type": "object",
            "properties": {
                "operation": {
                    "type": "string",
                    "enum": operations,
                },
                "column": {"type": "string", "nullable": True},
                "params": {"type": "object", "nullable": True},
            },
            "required": ["operation"],
        },
        "observation": {
            "type": "object",
            "properties": {
                "done": {"type": "boolean"},
                "reward": {"type": "number"},
                "data_preview": {"type": "string"},
                "data_shape": {"type": "array", "items": {"type": "integer"}},
                "missing_counts": {"type": "object"},
                "duplicate_count": {"type": "integer"},
                "dtype_issues": {"type": "object"},
                "task_description": {"type": "string"},
                "message": {"type": "string"},
                "step_count": {"type": "integer"},
                "current_score": {"type": "number"},
            },
        },
        "state": {
            "type": "object",
            "properties": {
                "episode_id": {"type": "string"},
                "task_id": {"type": "integer"},
                "step_count": {"type": "integer"},
                "max_steps": {"type": "integer"},
                "total_errors": {"type": "integer"},
                "errors_remaining": {"type": "integer"},
            },
        },
    }
@app.post("/reset")
def reset(req: Optional[ResetRequest] = None):
    """Start a new episode on the shared environment (POST /reset).

    Args:
        req: Optional request body. When the body is omitted, ``task_id``
            is passed to the environment as ``None``.

    Returns:
        StepResponse wrapping the first observation, its reward, and
        ``done=False``.

    Raises:
        HTTPException: 400 when ``env.reset`` raises ``ValueError``; the
            exception text is used as the error detail.
    """
    # A None default avoids building one ResetRequest at definition time
    # and sharing that single instance across every call.
    task_id = req.task_id if req is not None else None
    try:
        obs = env.reset(task_id=task_id)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    return StepResponse(observation=obs, reward=obs.reward, done=False)
@app.post("/step")
async def step(body: Dict[str, Any] = Body(...)):
    """Apply one action to the shared environment (POST /step).

    Accepts both the openenv-core wrapped format:
        {"action": {"operation": "...", ...}, "timeout_s": 15}
    and the direct format (for backward compatibility with our own
    client/inference):
        {"operation": "...", "column": "...", "params": {...}}

    In the wrapped format only the ``"action"`` entry is used; other
    top-level keys such as ``"timeout_s"`` are ignored.

    Args:
        body: Parsed JSON object from the request body.

    Returns:
        StepResponse with the new observation, its reward and done flag.

    Raises:
        HTTPException: 400 when the action payload is not a JSON object,
            cannot be turned into a ``DataCleaningAction``, or when
            ``env.step`` raises.
    """
    action_data = body.get("action", body)
    # Reject e.g. {"action": null} or {"action": "x"} explicitly rather
    # than surfacing the TypeError raised by ** unpacking.
    if not isinstance(action_data, dict):
        raise HTTPException(
            status_code=400,
            detail="'action' must be a JSON object describing the action",
        )
    try:
        action = DataCleaningAction(**action_data)
        obs = env.step(action)
    except Exception as e:
        # Deliberately broad: every failure while building or applying the
        # action is reported to the client as a 400, as before.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return StepResponse(observation=obs, reward=obs.reward, done=obs.done)
@app.get("/state")
def state_get():
    """GET /state — openenv-core spec.

    Returns:
        Whatever ``env.state()`` returns for the shared environment.
    """
    return env.state()
@app.post("/state")
def state_post():
    """POST /state — backward compatibility.

    Returns:
        Whatever ``env.state()`` returns for the shared environment;
        identical to the GET variant. No request body is read.
    """
    return env.state()
# ------------------------------------------------------------------
# Entry point (required by openenv-core and [project.scripts])
# ------------------------------------------------------------------
def main(host: str = "0.0.0.0", port: int = 8000):
    """Run the API under uvicorn.

    Args:
        host: Interface to bind. The default binds all interfaces.
        port: TCP port to listen on.

    Calling ``main()`` with no arguments, as a console-script entry
    point does, serves on 0.0.0.0:8000.
    """
    # The app is passed as an import string, so uvicorn imports
    # ``server.app`` itself instead of receiving the object directly.
    uvicorn.run("server.app:app", host=host, port=port)


if __name__ == "__main__":
    main()