OpenEnv_hack / server /app.py
Taniieeee83's picture
Fix step endpoint to accept openenv-core wrapped action format and add GET /state
498deff
raw
history blame
5.38 kB
"""
FastAPI application exposing the OpenEnv-compatible HTTP API.
Endpoints: GET /health, GET /metadata, GET /schema,
POST /reset, POST /step, GET /state, POST /state, GET /docs
"""
from typing import Any, Dict, Optional
from fastapi import Body, FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
from server.environment import DataCleaningEnvironment
# ASGI application object. The metadata passed here (title, description,
# version) is what FastAPI publishes in its generated API documentation.
app = FastAPI(
    version="0.1.0",
    title="Data Cleaning OpenEnv",
    description="A real-world data cleaning environment for AI agent training.",
)

# One module-level environment instance shared by every route handler below,
# which is what makes this a stateful server.
env = DataCleaningEnvironment()
class ResetRequest(BaseModel):
    # Request body for POST /reset.
    # task_id is optional; when the client omits it the value stays None and
    # is forwarded to env.reset() as-is.
    task_id: Optional[int] = None
class StepResponse(BaseModel):
    # Response envelope returned by both POST /reset and POST /step.

    # Observation object produced by the environment for this transition.
    observation: DataCleaningObservation
    # Reward value; the handlers copy it from the observation's `reward`.
    reward: float
    # Episode-finished flag (hard-coded False by /reset, obs.done for /step).
    done: bool
    # Free-form extra data; none of the handlers in this module populate it.
    info: dict = {}
# ------------------------------------------------------------------
# Routes
# ------------------------------------------------------------------
@app.get("/health")
def health():
    # Liveness probe: responds with the same fixed payload on every call and
    # never touches the shared environment.
    payload = {"status": "healthy"}
    return payload
@app.get("/metadata")
def metadata():
    # Static descriptor of the environment: name, human-readable summary,
    # version, tags and the catalogue of available tasks. Nothing here depends
    # on the current episode.
    summary = (
        "A real-world data cleaning environment where an AI agent fixes "
        "missing values, duplicate rows, format inconsistencies, outliers, "
        "and dtype errors across three progressively harder tasks."
    )

    # (id, name, difficulty) rows, expanded into dicts below.
    catalogue = (
        ("task1", "Fill Missing Values", "easy"),
        ("task2", "Fix Formats and Remove Duplicates", "medium"),
        ("task3", "Full Cleaning Pipeline", "hard"),
    )
    tasks = [
        {"id": task_key, "name": title, "difficulty": level}
        for task_key, title, level in catalogue
    ]

    return {
        "name": "data-cleaning-env",
        "description": summary,
        "version": "0.1.0",
        "tags": ["openenv", "data-cleaning", "rl", "real-world"],
        "tasks": tasks,
    }
@app.get("/schema")
def schema():
    # Hand-written, JSON-Schema-style description of the three payload kinds
    # this API exchanges: the action a client sends, the observation it gets
    # back, and the environment state. The dicts are static literals; they are
    # not derived from the pydantic models.

    # Operation names listed as the allowed values of action["operation"].
    operations = [
        "fill_missing",
        "drop_duplicates",
        "fix_format",
        "replace_value",
        "drop_outliers",
        "fix_dtype",
    ]

    action_schema = {
        "type": "object",
        "properties": {
            "operation": {"type": "string", "enum": operations},
            "column": {"type": "string", "nullable": True},
            "params": {"type": "object", "nullable": True},
        },
        # Only the operation name is mandatory.
        "required": ["operation"],
    }

    observation_fields = {
        "done": {"type": "boolean"},
        "reward": {"type": "number"},
        "data_preview": {"type": "string"},
        "data_shape": {"type": "array", "items": {"type": "integer"}},
        "missing_counts": {"type": "object"},
        "duplicate_count": {"type": "integer"},
        "dtype_issues": {"type": "object"},
        "task_description": {"type": "string"},
        "message": {"type": "string"},
        "step_count": {"type": "integer"},
        "current_score": {"type": "number"},
    }

    state_fields = {
        "episode_id": {"type": "string"},
        "task_id": {"type": "integer"},
        "step_count": {"type": "integer"},
        "max_steps": {"type": "integer"},
        "total_errors": {"type": "integer"},
        "errors_remaining": {"type": "integer"},
    }

    return {
        "action": action_schema,
        "observation": {"type": "object", "properties": observation_fields},
        "state": {"type": "object", "properties": state_fields},
    }
@app.post("/reset", response_model=StepResponse)
def reset(req: ResetRequest = ResetRequest()):
    # Start a new episode on the shared environment.
    #
    # The body is optional: the default ResetRequest() carries task_id=None,
    # which is passed through to env.reset() unchanged.
    #
    # A ValueError raised by env.reset() is reported to the client as
    # HTTP 400 with the exception text as the detail.
    try:
        first_obs = env.reset(task_id=req.task_id)
    except ValueError as err:
        raise HTTPException(status_code=400, detail=str(err))
    else:
        # A reset response always reports done=False.
        return StepResponse(
            observation=first_obs,
            reward=first_obs.reward,
            done=False,
        )
@app.post("/step", response_model=StepResponse)
async def step(body: Dict[str, Any] = Body(...)):
    """
    Accept both openenv-core wrapped format:
    {"action": {"operation": "...", ...}, "timeout_s": 15}
    and direct format (for backward compat with our own client/inference):
    {"operation": "...", "column": "...", "params": {...}}
    """
    # Parameters: body -- the raw JSON object of the request.
    # Returns:    StepResponse built from the observation env.step() returns.
    # Raises:     HTTPException(400) when the action payload is not a JSON
    #             object, fails DataCleaningAction validation, or env.step()
    #             raises.

    # Wrapped format carries the action under "action"; otherwise the whole
    # body is treated as the action itself.
    action_data = body.get("action", body)

    # Reject e.g. {"action": null} or {"action": [..]} explicitly instead of
    # surfacing the opaque TypeError that ** unpacking would produce.
    if not isinstance(action_data, dict):
        raise HTTPException(
            status_code=400,
            detail=(
                "'action' must be a JSON object, got "
                f"{type(action_data).__name__}"
            ),
        )

    try:
        action = DataCleaningAction(**action_data)
        obs = env.step(action)
    except Exception as e:
        # This handler is the HTTP boundary: every failure while building or
        # applying the action has always been reported as 400, so that
        # contract is kept. The cause is chained so it stays visible in
        # server-side tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e

    return StepResponse(observation=obs, reward=obs.reward, done=obs.done)
@app.get("/state", response_model=DataCleaningState)
def state_get():
    """GET /state — openenv-core spec."""
    # Hand back whatever the shared environment reports as its state;
    # FastAPI serialises it through the DataCleaningState response model.
    snapshot = env.state()
    return snapshot
@app.post("/state", response_model=DataCleaningState)
def state_post():
    """POST /state — backward compatibility."""
    # Same payload as the GET variant, exposed under POST as well.
    snapshot = env.state()
    return snapshot
# ------------------------------------------------------------------
# Entry point (required by openenv-core and [project.scripts])
# ------------------------------------------------------------------
def main(host: str = "0.0.0.0", port: int = 8000) -> None:
    """Serve the application with uvicorn.

    Called with no arguments (as the console-script entry point does) this
    behaves as before: it binds to all interfaces on port 8000.

    Args:
        host: Interface address passed to uvicorn.
        port: TCP port passed to uvicorn.
    """
    # The app is referenced by import string rather than by object, so uvicorn
    # imports "server.app" itself.
    uvicorn.run("server.app:app", host=host, port=port)


if __name__ == "__main__":
    main()