"""
FastAPI application exposing the OpenEnv-compatible HTTP API.
Endpoints: GET /health, GET /metadata, GET /schema,
POST /reset, POST /step, GET /state, POST /state, GET /docs
"""
from typing import Any, Dict, Optional
from fastapi import Body, FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
from server.environment import DataCleaningEnvironment
# ASGI application object; the keyword arguments are the application metadata
# handed to FastAPI.
app = FastAPI(
    title="Data Cleaning OpenEnv",
    version="0.1.0",
    description="A real-world data cleaning environment for AI agent training.",
)

# One environment instance is created at import time and shared by every
# route handler below, which makes this server stateful across requests.
env = DataCleaningEnvironment()
class ResetRequest(BaseModel):
    """Request body accepted by ``POST /reset``."""

    # Task selector; the default ``None`` is forwarded unchanged to
    # ``env.reset`` by the ``/reset`` handler.
    task_id: Optional[int] = None
class StepResponse(BaseModel):
    """Response envelope returned by ``POST /reset`` and ``POST /step``."""

    # Observation object produced by the environment for this call.
    observation: DataCleaningObservation
    # Reward value; the handlers copy it from ``observation.reward``.
    reward: float
    # ``obs.done`` for ``/step``; always ``False`` for ``/reset``.
    done: bool
    # Extra payload; the handlers in this module never populate it.
    info: dict = {}
# ------------------------------------------------------------------
# Routes
# ------------------------------------------------------------------
@app.get("/health")
def health():
    """Report liveness with a fixed ``{"status": "healthy"}`` payload."""
    status_payload = {"status": "healthy"}
    return status_payload
@app.get("/metadata")
def metadata():
    """Return the static descriptor of this environment.

    The payload carries the environment name, description, version, tags and
    the list of three tasks together with their difficulty labels.
    """
    summary = (
        "A real-world data cleaning environment where an AI agent fixes "
        "missing values, duplicate rows, format inconsistencies, outliers, "
        "and dtype errors across three progressively harder tasks."
    )
    # (id, display name, difficulty) for each task, in presentation order.
    task_rows = (
        ("task1", "Fill Missing Values", "easy"),
        ("task2", "Fix Formats and Remove Duplicates", "medium"),
        ("task3", "Full Cleaning Pipeline", "hard"),
    )
    task_catalog = [
        {"id": task_key, "name": title, "difficulty": level}
        for task_key, title, level in task_rows
    ]
    return {
        "name": "data-cleaning-env",
        "description": summary,
        "version": "0.1.0",
        "tags": ["openenv", "data-cleaning", "rl", "real-world"],
        "tasks": task_catalog,
    }
@app.get("/schema")
def schema():
    """Return static schema descriptions of the API payloads.

    The result has three top-level keys, ``action``, ``observation`` and
    ``state``, each describing an object and its properties.
    """
    supported_operations = [
        "fill_missing",
        "drop_duplicates",
        "fix_format",
        "replace_value",
        "drop_outliers",
        "fix_dtype",
    ]
    action_schema = {
        "type": "object",
        "properties": {
            "operation": {"type": "string", "enum": supported_operations},
            "column": {"type": "string", "nullable": True},
            "params": {"type": "object", "nullable": True},
        },
        "required": ["operation"],
    }

    # Field name -> schema type, in the order the fields are advertised.
    observation_types = {
        "done": "boolean",
        "reward": "number",
        "data_preview": "string",
        "data_shape": "array",
        "missing_counts": "object",
        "duplicate_count": "integer",
        "dtype_issues": "object",
        "task_description": "string",
        "message": "string",
        "step_count": "integer",
        "current_score": "number",
    }
    observation_props = {
        field: {"type": kind} for field, kind in observation_types.items()
    }
    # ``data_shape`` is the only property that also declares its item type.
    observation_props["data_shape"]["items"] = {"type": "integer"}

    state_types = {
        "episode_id": "string",
        "task_id": "integer",
        "step_count": "integer",
        "max_steps": "integer",
        "total_errors": "integer",
        "errors_remaining": "integer",
    }
    state_props = {field: {"type": kind} for field, kind in state_types.items()}

    return {
        "action": action_schema,
        "observation": {"type": "object", "properties": observation_props},
        "state": {"type": "object", "properties": state_props},
    }
@app.post("/reset", response_model=StepResponse)
def reset(req: ResetRequest = ResetRequest()):
    """Start a new episode on the shared environment.

    Args:
        req: Request body; ``req.task_id`` is forwarded to ``env.reset``.
            The parameter defaults to ``ResetRequest()``, i.e.
            ``task_id=None``.

    Returns:
        A ``StepResponse`` wrapping the observation returned by
        ``env.reset``, with ``reward`` copied from that observation and
        ``done`` fixed to ``False``.

    Raises:
        HTTPException: status 400 carrying the error text when
            ``env.reset`` raises ``ValueError``.
    """
    try:
        obs = env.reset(task_id=req.task_id)
    except ValueError as e:
        # Chain the original error so the root cause stays attached to the
        # HTTP error in tracebacks and logs.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return StepResponse(observation=obs, reward=obs.reward, done=False)
@app.post("/step", response_model=StepResponse)
async def step(body: Dict[str, Any] = Body(...)):
    """Apply one action to the shared environment.

    Accepts both the openenv-core wrapped format:
        {"action": {"operation": "...", ...}, "timeout_s": 15}
    and the direct format (for backward compat with our own client/inference):
        {"operation": "...", "column": "...", "params": {...}}

    Args:
        body: Parsed JSON object from the request. When it has an
            ``"action"`` key that value is used as the action payload;
            otherwise the whole body is.

    Returns:
        A ``StepResponse`` whose ``reward`` and ``done`` are copied from the
        observation returned by ``env.step``.

    Raises:
        HTTPException: status 400 when the action payload is not a JSON
            object, or when building the action or stepping the environment
            raises any exception (the exception text becomes the detail).
    """
    action_data = body.get("action", body)
    if not isinstance(action_data, dict):
        # Without this guard the ``**`` unpacking below fails with a cryptic
        # "argument after ** must be a mapping" message.
        raise HTTPException(
            status_code=400,
            detail="'action' must be a JSON object describing the action",
        )
    try:
        action = DataCleaningAction(**action_data)
        obs = env.step(action)
    except Exception as e:
        # Deliberately broad, as before: every failure while building or
        # applying the action is reported to the client as a 400.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return StepResponse(observation=obs, reward=obs.reward, done=obs.done)
@app.get("/state", response_model=DataCleaningState)
def state_get():
    """GET /state (openenv-core spec): return what ``env.state()`` reports."""
    current_state = env.state()
    return current_state
@app.post("/state", response_model=DataCleaningState)
def state_post():
    """POST /state (backward compatibility): return what ``env.state()`` reports."""
    snapshot = env.state()
    return snapshot
# ------------------------------------------------------------------
# Entry point (required by openenv-core and [project.scripts])
# ------------------------------------------------------------------
def main():
    """Serve ``server.app:app`` with uvicorn on host 0.0.0.0, port 8000."""
    bind_options = {"host": "0.0.0.0", "port": 8000}
    uvicorn.run("server.app:app", **bind_options)
# Start the server only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()