Upload 22 files
Browse filesNew space created, with the updated code
- .dockerignore +10 -0
- .env.example +11 -0
- Dockerfile +19 -0
- README.md +98 -6
- app.py +187 -0
- env/__init__.py +1 -0
- env/environment.py +271 -0
- env/graders.py +141 -0
- env/models.py +42 -0
- env/rewards.py +31 -0
- env/tasks.py +291 -0
- inference.py +149 -0
- openenv.yaml +117 -0
- pyproject.toml +28 -0
- requirements.txt +10 -0
- server/__init__.py +1 -0
- server/app.py +17 -0
- server/cli.py +11 -0
- tests/test_graders.py +70 -0
- tests/test_reset.py +50 -0
- tests/test_step.py +63 -0
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
.env.local
|
| 3 |
+
.env.*.local
|
| 4 |
+
conda-env/
|
| 5 |
+
.venv/
|
| 6 |
+
venv/
|
| 7 |
+
__pycache__/
|
| 8 |
+
.pytest_cache/
|
| 9 |
+
*.pyc
|
| 10 |
+
.DS_Store
|
.env.example
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
API_BASE_URL=https://api.groq.com/openai/v1
|
| 2 |
+
MODEL_NAME=meta-llama/llama-4-scout-17b-16e-instruct
|
| 3 |
+
HF_TOKEN=
|
| 4 |
+
# Available TASK_ID values:
|
| 5 |
+
# task1_easy
|
| 6 |
+
# task2_medium
|
| 7 |
+
# task3_hard
|
| 8 |
+
# task4_medium_alt
|
| 9 |
+
# task5_hard_alt
|
| 10 |
+
TASK_ID=task1_easy
|
| 11 |
+
MAX_STEPS=15
|
Dockerfile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# Non-root user required by Hugging Face Spaces
|
| 4 |
+
RUN useradd -m -u 1000 appuser
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# Install dependencies first (better layer caching)
|
| 9 |
+
COPY requirements.txt .
|
| 10 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 11 |
+
|
| 12 |
+
# Copy application code
|
| 13 |
+
COPY --chown=appuser:appuser . .
|
| 14 |
+
|
| 15 |
+
USER appuser
|
| 16 |
+
|
| 17 |
+
EXPOSE 7860
|
| 18 |
+
|
| 19 |
+
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
|
README.md
CHANGED
|
@@ -1,12 +1,104 @@
|
|
| 1 |
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
-
|
| 9 |
-
|
| 10 |
---
|
| 11 |
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
+
title: Data Cleaning OpenEnv Benchmark
|
| 3 |
+
emoji: 🧹
|
| 4 |
+
colorFrom: blue
|
| 5 |
colorTo: green
|
| 6 |
sdk: docker
|
| 7 |
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
---
|
| 11 |
|
| 12 |
+
# Data Cleaning OpenEnv Benchmark
|
| 13 |
+
|
| 14 |
+
A practical benchmark where LLM agents clean messy tabular datasets through a structured action API.
|
| 15 |
+
|
| 16 |
+
## Why This Matters
|
| 17 |
+
|
| 18 |
+
Data cleaning still takes a large share of real analytics work. This environment tests whether an agent can detect and correct common data quality problems such as duplicates, missing values, inconsistent formats, and outliers.
|
| 19 |
+
|
| 20 |
+
## Tasks
|
| 21 |
+
|
| 22 |
+
| ID | Difficulty | Description |
|
| 23 |
+
|----|-----------|-------------|
|
| 24 |
+
| `task1_easy` | Easy | Remove exact duplicates, fill missing emails and ages, standardise country names |
|
| 25 |
+
| `task2_medium` | Medium | Normalise mixed date formats, convert price strings to float, fix category typos |
|
| 26 |
+
| `task3_hard` | Hard | Resolve duplicate user IDs, clip session outliers, fix invalid bounce rates |
|
| 27 |
+
| `task4_medium_alt` | Medium | Alternate order-cleaning scenario that uses the same grader contract as `task2_medium` |
|
| 28 |
+
| `task5_hard_alt` | Hard | Alternate analytics-cleaning scenario that uses the same grader contract as `task3_hard` |
|
| 29 |
+
|
| 30 |
+
Each task is graded independently, and scores are clamped to the range [0.01, 0.99], so they are always strictly between 0 and 1.
|
| 31 |
+
|
| 32 |
+
## Action Space
|
| 33 |
+
|
| 34 |
+
| Action | Required Fields |
|
| 35 |
+
|--------|----------------|
|
| 36 |
+
| `fill_missing` | `column`, `strategy` (`mean`/`median`/`mode`/`constant`), `value` when needed |
|
| 37 |
+
| `standardize_values` | `column`, `mapping` |
|
| 38 |
+
| `remove_duplicates` | None |
|
| 39 |
+
| `remove_row` | `row_id` |
|
| 40 |
+
| `convert_type` | `column`, `target_type` |
|
| 41 |
+
| `clip_outliers` | `column`, `lower`, `upper` |
|
| 42 |
+
| `submit` | None |
|
| 43 |
+
|
| 44 |
+
## Observation Space
|
| 45 |
+
|
| 46 |
+
Each step the agent receives `table_preview`, `schema_info`, `issues_detected`, `cleaning_log`, `valid_actions`, `step`, and `max_steps`.
|
| 47 |
+
|
| 48 |
+
## Reward Design
|
| 49 |
+
|
| 50 |
+
Correct cleaning actions receive positive intermediate rewards, wasted actions receive small penalties, invalid actions receive larger penalties, and `submit` returns the final grader score.
|
| 51 |
+
|
| 52 |
+
## Setup & Local Run
|
| 53 |
+
|
| 54 |
+
```bash
|
| 55 |
+
git clone https://huggingface.co/spaces/AnkushRaheja/data-cleaning-benchmark
|
| 56 |
+
cd data-cleaning-benchmark
|
| 57 |
+
pip install -r requirements.txt
|
| 58 |
+
uvicorn app:app --port 7860
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Run Baseline
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
export API_BASE_URL="https://api.groq.com/openai/v1"
|
| 65 |
+
export MODEL_NAME="meta-llama/llama-4-scout-17b-16e-instruct"
|
| 66 |
+
export HF_TOKEN="$GROQ_API_KEY"
|
| 67 |
+
export TASK_ID="task1_easy"
|
| 68 |
+
python inference.py
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## Docker
|
| 72 |
+
|
| 73 |
+
```bash
|
| 74 |
+
docker build -t data-cleaning-benchmark .
|
| 75 |
+
docker run -p 7860:7860 \
|
| 76 |
+
-e API_BASE_URL="https://api.groq.com/openai/v1" \
|
| 77 |
+
-e MODEL_NAME="meta-llama/llama-4-scout-17b-16e-instruct" \
|
| 78 |
+
-e HF_TOKEN="$GROQ_API_KEY" \
|
| 79 |
+
data-cleaning-benchmark
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
## Baseline Scores
|
| 83 |
+
|
| 84 |
+
| Task | Score |
|
| 85 |
+
|------|-------|
|
| 86 |
+
| task1_easy | 0.99 |
|
| 87 |
+
| task2_medium | 0.99 |
|
| 88 |
+
| task3_hard | 0.97 |
|
| 89 |
+
| task4_medium_alt | 0.99 |
|
| 90 |
+
| task5_hard_alt | 0.97 |
|
| 91 |
+
|
| 92 |
+
## API Reference
|
| 93 |
+
|
| 94 |
+
| Method | Endpoint | Description |
|
| 95 |
+
|--------|----------|-------------|
|
| 96 |
+
| GET | `/health` | Health check |
|
| 97 |
+
| POST | `/reset` | Start new episode `{"task_id": "task1_easy"}` |
|
| 98 |
+
| POST | `/step` | Submit action and receive reward (compat route with `session_id` in body/query) |
|
| 99 |
+
| POST | `/step/{session_id}` | Legacy route for direct session addressing |
|
| 100 |
+
| GET | `/state` | Retrieve state by query (`session_id`) |
|
| 101 |
+
| GET | `/state/{session_id}` | Legacy route for direct session addressing |
|
| 102 |
+
| GET | `/tasks` | List all tasks |
|
| 103 |
+
| GET | `/metadata` | Benchmark metadata including task and score-range contract |
|
| 104 |
+
| GET | `/schema` | JSON schemas for action/observation/step response |
|
app.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import uuid
|
| 4 |
+
from typing import Any, Dict, Optional
|
| 5 |
+
|
| 6 |
+
from fastapi import FastAPI, HTTPException, Query
|
| 7 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 8 |
+
from pydantic import BaseModel
|
| 9 |
+
|
| 10 |
+
from env.environment import DataCleaningEnv
|
| 11 |
+
from env.models import Action, Observation, StepResult
|
| 12 |
+
from env.tasks import list_tasks as list_task_specs
|
| 13 |
+
|
| 14 |
+
app = FastAPI(
|
| 15 |
+
title="Data Cleaning OpenEnv Benchmark",
|
| 16 |
+
version="1.0.0",
|
| 17 |
+
description="LLM agent benchmark for real-world data cleaning tasks.",
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
app.add_middleware(
|
| 21 |
+
CORSMiddleware,
|
| 22 |
+
allow_origins=["*"],
|
| 23 |
+
allow_methods=["*"],
|
| 24 |
+
allow_headers=["*"],
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
sessions: Dict[str, DataCleaningEnv] = {}
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@app.get("/")
def root():
    """Landing payload: benchmark identity, task list, and a map of routes."""
    route_map = {
        "reset": "POST /reset",
        "step": "POST /step/{session_id}",
        "step_compat": "POST /step",
        "state": "GET /state/{session_id}",
        "state_compat": "GET /state?session_id=...",
        "metadata": "GET /metadata",
        "schema": "GET /schema",
        "mcp": "GET|POST /mcp",
        "health": "GET /health",
    }
    return {
        "name": "Data Cleaning OpenEnv Benchmark",
        "version": "1.0.0",
        "tasks": list_task_specs(),
        "api": route_map,
    }
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@app.get("/health")
def health():
    """Liveness probe; also reports how many episodes are held in memory."""
    active = len(sessions)
    return {"status": "ok", "sessions_active": active}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class ResetRequest(BaseModel):
    """Body for POST /reset; ``task_id`` selects the task (None -> env default)."""

    task_id: Optional[str] = None
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
@app.post("/reset")
def reset(body: ResetRequest = ResetRequest()):
    """Create a fresh episode, register it, and return id + initial observation."""
    env = DataCleaningEnv()
    first_obs = env.reset(task_id=body.task_id)
    new_id = str(uuid.uuid4())
    sessions[new_id] = env
    # Mirror the StepResult info shape so clients can treat reset like step 0.
    info = {
        "error": None,
        "cumulative_reward": env.cumulative_reward,
        "raw_cumulative_reward": env.raw_cumulative_reward,
        "final_score": env.final_score,
        "step": env.step_count,
    }
    return {
        "session_id": new_id,
        "observation": first_obs.model_dump(),
        "reward": 0.0,
        "done": False,
        "info": info,
    }
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@app.post("/step")
def step_compat(
    payload: Dict[str, Any],
    session_id: Optional[str] = Query(default=None),
):
    """Compat step route: the session id may arrive in the body or query string.

    The action may be nested under ``"action"`` or given as the top-level body.
    """
    # Body-provided session id wins over the query parameter.
    sid = _resolve_session_id(payload.get("session_id") or session_id)
    raw_action = payload.get("action", payload)

    if not isinstance(raw_action, dict):
        raise HTTPException(status_code=400, detail="Action payload must be an object")
    if "type" not in raw_action:
        raise HTTPException(status_code=400, detail="Action payload requires 'type'")

    parsed = Action(**raw_action)
    env = _get_session(sid)
    return env.step(parsed).model_dump()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
@app.post("/step/{session_id}")
def step(session_id: str, action: Action):
    """Legacy step route with the session id in the path."""
    return _get_session(session_id).step(action).model_dump()
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@app.get("/state")
def state_compat(session_id: Optional[str] = Query(default=None)):
    """Compat state route: id via query string, optional when unambiguous."""
    resolved = _resolve_session_id(session_id)
    return _get_session(resolved).state()
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
@app.get("/state/{session_id}")
def state(session_id: str):
    """Legacy state route with the session id in the path."""
    return _get_session(session_id).state()
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
@app.get("/metadata")
def metadata():
    """Machine-readable benchmark description: tasks, score contract, routes."""
    score_range = {
        "min": DataCleaningEnv.MIN_EPISODE_SCORE,
        "max": DataCleaningEnv.MAX_EPISODE_SCORE,
    }
    entrypoints = {
        "reset": "/reset",
        "step": "/step",
        "state": "/state",
        "health": "/health",
        "tasks": "/tasks",
        "schema": "/schema",
        "mcp": "/mcp",
    }
    return {
        "name": "data-cleaning-benchmark",
        "version": "1.0.0",
        "description": "LLM agent benchmark for real-world data cleaning tasks.",
        "tasks": list_task_specs(),
        "score_range": score_range,
        "entrypoints": entrypoints,
    }
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
@app.get("/schema")
def schema():
    """JSON Schemas for the wire types used by /reset and /step."""
    models = {
        "action": Action,
        "observation": Observation,
        "step_result": StepResult,
        "reset_request": ResetRequest,
    }
    return {key: model.model_json_schema() for key, model in models.items()}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@app.api_route("/mcp", methods=["GET", "POST"])
def mcp_metadata():
    """MCP is not implemented; advertise the plain HTTP simulation API instead."""
    return {
        "supported": False,
        "message": "This benchmark exposes simulation HTTP endpoints (reset/step/state).",
    }
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
@app.delete("/session/{session_id}")
def delete_session(session_id: str):
    """Drop a session if present; deleting an unknown id is a silent no-op."""
    sessions.pop(session_id, None)
    return {"deleted": session_id}
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@app.get("/tasks")
def list_tasks():
    """Enumerate the task specs this benchmark ships with."""
    return {"tasks": list_task_specs()}
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _resolve_session_id(session_id: Optional[str]) -> str:
    """Return the given id, or infer it when exactly one session is active."""
    if session_id:
        return session_id
    if len(sessions) == 1:
        # Sole active session: unpack its id directly.
        (only_id,) = sessions.keys()
        return only_id
    raise HTTPException(
        status_code=400,
        detail="session_id is required when there is not exactly one active session",
    )
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def _get_session(session_id: str) -> DataCleaningEnv:
    """Look up a session's environment or fail with HTTP 404."""
    try:
        return sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session '{session_id}' not found") from None
|
env/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
env/environment.py
ADDED
|
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Optional, Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from .graders import grade_task
|
| 9 |
+
from .models import Action, Observation, StepResult, TablePreview
|
| 10 |
+
from .rewards import compute_reward
|
| 11 |
+
from .tasks import TASK_IDS, get_task
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class DataCleaningEnv:
    """Single-episode data-cleaning environment.

    Holds a mutable copy of the selected task's "dirty" pandas DataFrame; the
    agent edits it through structured actions and ends the episode with
    ``submit``, which grades the current table via ``grade_task``.
    """

    # Episode auto-terminates (and auto-grades) after this many steps.
    MAX_STEPS: int = 20
    # Cumulative episode scores are clamped so they never touch 0 or 1.
    MIN_EPISODE_SCORE: float = 0.01
    MAX_EPISODE_SCORE: float = 0.99

    def __init__(self) -> None:
        # All episode state starts empty and is populated by reset().
        self.task_id: Optional[str] = None
        self._task_config: Optional[dict] = None
        self.original_df: Optional[pd.DataFrame] = None
        self.current_df: Optional[pd.DataFrame] = None
        self.step_count: int = 0
        self.cleaning_log: list = []
        self.action_history: list = []
        self.raw_cumulative_reward: float = 0.0
        self.cumulative_reward: float = 0.0
        self.done: bool = False
        self.final_score: float = 0.01

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """Start a new episode; defaults to the first registered task."""
        if task_id is None:
            task_id = TASK_IDS[0]
        self.task_id = task_id
        self._task_config = get_task(task_id)
        # Copy twice so in-place edits to current_df never touch the pristine copy.
        self.original_df = self._task_config["dirty_df"].copy()
        self.current_df = self._task_config["dirty_df"].copy()
        self.step_count = 0
        self.cleaning_log = []
        self.action_history = []
        self.raw_cumulative_reward = 0.0
        self.cumulative_reward = 0.0
        self.done = False
        self.final_score = 0.01
        return self._build_observation()

    def step(self, action: Action) -> StepResult:
        """Apply one action and return (observation, reward, done, info)."""
        if self.done:
            # Stepping a finished episode is a no-op that re-reports the outcome.
            return StepResult(
                observation=self._build_observation(),
                reward=self.final_score,
                done=True,
                info={
                    "error": "Episode already finished",
                    "cumulative_reward": self.cumulative_reward,
                    "raw_cumulative_reward": self.raw_cumulative_reward,
                    "final_score": self.final_score,
                    "step": self.step_count,
                },
            )

        error: Optional[str] = None
        reward: float = 0.0

        if action.type == "submit":
            # Submitting grades the current table and ends the episode.
            self.final_score = grade_task(self.task_id, self.current_df)
            reward = self.final_score
            self.cleaning_log.append(f"[SUBMIT] Final grade: {self.final_score:.4f}")
            self.done = True
        else:
            try:
                reward, log_msg = self._apply_action(action)
                self.cleaning_log.append(log_msg)
            except Exception as exc:
                # Malformed/invalid actions are penalised instead of propagated.
                error = str(exc)
                reward = -0.10
                self.cleaning_log.append(f"[ERROR] {error}")

        self.step_count += 1
        self.raw_cumulative_reward = round(self.raw_cumulative_reward + reward, 4)
        self.cumulative_reward = self._clamp_episode_score(self.raw_cumulative_reward)
        self.action_history.append(action.model_dump())

        # Hitting the step budget auto-grades the table without a submit.
        if not self.done and self.step_count >= self.MAX_STEPS:
            self.final_score = grade_task(self.task_id, self.current_df)
            self.done = True

        return StepResult(
            observation=self._build_observation(),
            reward=round(reward, 4),
            done=self.done,
            info={
                "error": error,
                "cumulative_reward": self.cumulative_reward,
                "raw_cumulative_reward": self.raw_cumulative_reward,
                "final_score": self.final_score,
                "step": self.step_count,
            },
        )

    def state(self) -> dict:
        """Full JSON-safe snapshot of the episode, including the current table."""
        return {
            "task_id": self.task_id,
            "step_count": self.step_count,
            "cumulative_reward": self.cumulative_reward,
            "raw_cumulative_reward": self.raw_cumulative_reward,
            "final_score": self.final_score,
            "done": self.done,
            "cleaning_log": self.cleaning_log,
            "action_history": self.action_history,
            "current_data": self._df_records_with_none(self.current_df) if self.current_df is not None else [],
        }

    @classmethod
    def _clamp_episode_score(cls, value: float) -> float:
        # Keep cumulative scores strictly inside (0, 1) per the benchmark contract.
        return round(min(max(value, cls.MIN_EPISODE_SCORE), cls.MAX_EPISODE_SCORE), 4)

    def _apply_action(self, action: Action) -> Tuple[float, str]:
        """Mutate ``current_df`` per *action*; return (reward, log message).

        Raises ValueError for malformed actions; step() converts that into a
        -0.10 penalty instead of surfacing the exception.
        """
        df = self.current_df

        if action.type == "fill_missing":
            col = self._require_column(action.column, df)
            missing_before = int(df[col].isna().sum())
            if missing_before == 0:
                # Nothing to fill: small penalty for a wasted step.
                return -0.05, f"[WARN] No missing values in '{col}' — wasted step"

            if action.strategy == "mean":
                df[col] = df[col].fillna(df[col].mean())
            elif action.strategy == "median":
                df[col] = df[col].fillna(df[col].median())
            elif action.strategy == "mode":
                # NOTE(review): mode() is empty for an all-NaN column, so .iloc[0]
                # would raise IndexError — caught by step() as an invalid action.
                df[col] = df[col].fillna(df[col].mode().iloc[0])
            elif action.strategy == "constant":
                df[col] = df[col].fillna(action.value)
            else:
                raise ValueError(f"Unknown fill strategy '{action.strategy}'")

            reward = compute_reward("fill_missing", {"filled": missing_before})
            return reward, f"Filled {missing_before} missing values in '{col}' via {action.strategy}"

        if action.type == "standardize_values":
            col = self._require_column(action.column, df)
            if not action.mapping:
                raise ValueError("'mapping' dict is required for standardize_values")
            # Count raw-value matches; the replacement below keys on str(x), so
            # for non-string columns the count may differ from actual replacements.
            replaced = int(df[col].isin(action.mapping.keys()).sum())
            df[col] = df[col].apply(lambda x: action.mapping.get(str(x), x) if pd.notna(x) else x)
            reward = compute_reward("standardize_values", {"replaced": replaced})
            return reward, f"Standardised {replaced} values in '{col}'"

        if action.type == "remove_duplicates":
            before = len(df)
            self.current_df = df.drop_duplicates().reset_index(drop=True)
            removed = before - len(self.current_df)
            if removed == 0:
                return -0.05, "[WARN] No exact duplicates found — wasted step"
            reward = compute_reward("remove_duplicates", {"removed": removed})
            return reward, f"Removed {removed} duplicate row(s)"

        if action.type == "remove_row":
            if action.row_id is None:
                raise ValueError("'row_id' is required for remove_row")
            if action.row_id not in df.index:
                raise ValueError(f"Row index {action.row_id} not found (valid range 0–{len(df)-1})")
            # reset_index keeps row ids contiguous for the next observation.
            self.current_df = df.drop(index=action.row_id).reset_index(drop=True)
            reward = compute_reward("remove_row", {})
            return reward, f"Removed row at index {action.row_id}"

        if action.type == "convert_type":
            col = self._require_column(action.column, df)
            tgt = action.target_type

            if tgt == "float":
                # Strip currency symbols/commas/whitespace before numeric parse;
                # unparseable values become NaN via errors="coerce".
                df[col] = (
                    df[col]
                    .astype(str)
                    .str.replace(r"[$,\s]", "", regex=True)
                    .replace("nan", np.nan)
                    .replace("None", np.nan)
                )
                df[col] = pd.to_numeric(df[col], errors="coerce")
            elif tgt == "int":
                # Nullable Int64 preserves NaN instead of failing the cast.
                df[col] = pd.to_numeric(df[col], errors="coerce").astype("Int64")
            elif tgt == "str":
                df[col] = df[col].astype(str)
            elif tgt == "datetime":
                # Normalise to ISO "YYYY-MM-DD" strings, not datetime dtype.
                parsed = pd.to_datetime(df[col], errors="coerce")
                df[col] = parsed.dt.strftime("%Y-%m-%d")
            else:
                raise ValueError(f"Unknown target_type '{tgt}'")

            reward = compute_reward("convert_type", {})
            return reward, f"Converted column '{col}' → {tgt}"

        if action.type == "clip_outliers":
            col = self._require_column(action.column, df)
            if action.lower is None and action.upper is None:
                raise ValueError("At least one of 'lower' or 'upper' must be set")

            series = pd.to_numeric(df[col], errors="coerce")
            clipped = 0
            if action.lower is not None:
                clipped += int((series < action.lower).sum())
            if action.upper is not None:
                clipped += int((series > action.upper).sum())

            df[col] = series.clip(lower=action.lower, upper=action.upper)
            reward = compute_reward("clip_outliers", {"clipped": clipped})
            return reward, f"Clipped '{col}' to [{action.lower}, {action.upper}] ({clipped} value(s) affected)"

        raise ValueError(f"Unknown action type '{action.type}'")

    @staticmethod
    def _require_column(col: Optional[str], df: pd.DataFrame) -> str:
        """Validate that *col* is provided and exists in *df*; return it."""
        if not col:
            raise ValueError("'column' field is required for this action")
        if col not in df.columns:
            raise ValueError(f"Column '{col}' not found. Available: {list(df.columns)}")
        return col

    @staticmethod
    def _df_records_with_none(df: pd.DataFrame) -> list[dict]:
        """Convert *df* to records with NaN mapped to None (JSON-safe)."""
        safe_df = df.astype(object).where(pd.notna(df), None)
        return safe_df.to_dict(orient="records")

    def _build_observation(self) -> Observation:
        """Assemble the agent-facing view: preview, schema, detected issues, log."""
        df = self.current_df
        issues: list = []

        if df is not None:
            # Surface per-column missing counts and exact-duplicate count as hints.
            for col in df.columns:
                miss = int(df[col].isna().sum())
                if miss > 0:
                    issues.append(f"Column '{col}' has {miss} missing value(s)")
            dup = int(df.duplicated().sum())
            if dup > 0:
                issues.append(f"{dup} exact duplicate row(s) detected")

            # First 10 rows, with the positional index exposed as "_row_id" so the
            # agent can target remove_row.
            head = df.head(10).copy()
            head.insert(0, "_row_id", head.index.tolist())
            preview_rows = self._df_records_with_none(head)
            schema_info = {c: str(df[c].dtype) for c in df.columns}
            shape = list(df.shape)
        else:
            preview_rows, schema_info, shape = [], {}, [0, 0]

        preview = TablePreview(
            columns=["_row_id"] + (list(df.columns) if df is not None else []),
            rows=preview_rows,
            shape=shape,
        )

        return Observation(
            task_id=self.task_id or "",
            task_description=(self._task_config["description"] if self._task_config else ""),
            table_preview=preview,
            schema_info=schema_info,
            valid_actions=[
                "fill_missing",
                "standardize_values",
                "remove_duplicates",
                "remove_row",
                "convert_type",
                "clip_outliers",
                "submit",
            ],
            step=self.step_count,
            max_steps=self.MAX_STEPS,
            cleaning_log=self.cleaning_log[-6:],
            issues_detected=issues,
        )
|
env/graders.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _strict_score(value: float) -> float:
|
| 9 |
+
try:
|
| 10 |
+
score = float(value)
|
| 11 |
+
except (TypeError, ValueError):
|
| 12 |
+
return 0.01
|
| 13 |
+
|
| 14 |
+
if not math.isfinite(score):
|
| 15 |
+
return 0.01
|
| 16 |
+
|
| 17 |
+
return round(min(max(score, 0.01), 0.99), 4)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def grade_task1(df: pd.DataFrame) -> float:
|
| 21 |
+
score = 0.0
|
| 22 |
+
|
| 23 |
+
if df.duplicated().sum() == 0:
|
| 24 |
+
score += 0.25
|
| 25 |
+
|
| 26 |
+
if "email" in df.columns and df["email"].isna().sum() == 0:
|
| 27 |
+
score += 0.25
|
| 28 |
+
|
| 29 |
+
if "age" in df.columns and df["age"].isna().sum() == 0:
|
| 30 |
+
score += 0.25
|
| 31 |
+
|
| 32 |
+
valid_countries = {"United States", "United Kingdom", "Canada", "Australia"}
|
| 33 |
+
if "country" in df.columns:
|
| 34 |
+
non_null = df["country"].dropna()
|
| 35 |
+
if len(non_null) == 0:
|
| 36 |
+
pass
|
| 37 |
+
elif set(non_null.unique()).issubset(valid_countries):
|
| 38 |
+
score += 0.25
|
| 39 |
+
else:
|
| 40 |
+
valid_n = non_null.isin(valid_countries).sum()
|
| 41 |
+
score += 0.25 * (valid_n / len(non_null))
|
| 42 |
+
|
| 43 |
+
return _strict_score(score)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def grade_task2(df: pd.DataFrame) -> float:
|
| 47 |
+
score = 0.0
|
| 48 |
+
n = len(df)
|
| 49 |
+
if n == 0:
|
| 50 |
+
return 0.01
|
| 51 |
+
|
| 52 |
+
if "date" in df.columns:
|
| 53 |
+
pattern = r"^\d{4}-\d{2}-\d{2}$"
|
| 54 |
+
valid = df["date"].astype(str).str.match(pattern).sum()
|
| 55 |
+
score += 0.25 * (valid / n)
|
| 56 |
+
|
| 57 |
+
if "price" in df.columns:
|
| 58 |
+
numeric = pd.to_numeric(df["price"], errors="coerce")
|
| 59 |
+
non_null = numeric.notna().sum()
|
| 60 |
+
score += 0.25 * (non_null / n)
|
| 61 |
+
|
| 62 |
+
valid_cats = {"Electronics", "Furniture"}
|
| 63 |
+
if "category" in df.columns:
|
| 64 |
+
non_null_cats = df["category"].dropna()
|
| 65 |
+
if len(non_null_cats) > 0:
|
| 66 |
+
valid_n = non_null_cats.isin(valid_cats).sum()
|
| 67 |
+
score += 0.25 * (valid_n / len(non_null_cats))
|
| 68 |
+
|
| 69 |
+
key_cols = [c for c in ["price", "category", "quantity"] if c in df.columns]
|
| 70 |
+
if key_cols:
|
| 71 |
+
total_cells = n * len(key_cols)
|
| 72 |
+
missing = sum(int(df[c].isna().sum()) for c in key_cols)
|
| 73 |
+
score += 0.25 * (1.0 - missing / total_cells)
|
| 74 |
+
|
| 75 |
+
return _strict_score(score)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def grade_task3(df: pd.DataFrame) -> float:
    """Grade the web-analytics table.

    Criteria: unique user_ids (0.34, partial credit), session durations
    capped at 1000s (0.33, half-ish credit up to 5000s), bounce rates in
    [0, 1] (0.165, partial credit), and fully populated page_views (0.165).
    Empty tables score 0.01.
    """
    rows = len(df)
    if rows == 0:
        return 0.01

    total = 0.0

    if "user_id" in df.columns:
        repeats = df["user_id"].duplicated().sum()
        total += 0.34 if repeats == 0 else 0.34 * (1.0 - repeats / rows)

    if "session_duration" in df.columns:
        # NOTE: an all-NaN column yields NaN here; both comparisons below
        # are then False, so no credit is awarded.
        longest = df["session_duration"].dropna().max() if rows > 0 else 0
        if longest <= 1000:
            total += 0.33
        elif longest <= 5000:
            total += 0.15

    if "bounce_rate" in df.columns:
        in_range = ((df["bounce_rate"] >= 0) & (df["bounce_rate"] <= 1)).sum()
        total += 0.165 * (in_range / rows)

    if "page_views" in df.columns and df["page_views"].isna().sum() == 0:
        total += 0.165

    return _strict_score(total)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def grade_task(task_id: str, df: pd.DataFrame) -> float:
    """Dispatch *df* to the grader registered for *task_id*.

    Unknown task ids score the floor value 0.01 instead of raising.
    """
    grader = TASK_GRADERS.get(task_id)
    return 0.01 if grader is None else grader(df)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def grade_task1_easy(df: pd.DataFrame) -> float:
    """Grader entrypoint alias for task1_easy; delegates to grade_task1."""
    return grade_task1(df)
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def grade_task2_medium(df: pd.DataFrame) -> float:
    """Grader entrypoint alias for task2_medium; delegates to grade_task2."""
    return grade_task2(df)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def grade_task3_hard(df: pd.DataFrame) -> float:
    """Grader entrypoint alias for task3_hard; delegates to grade_task3."""
    return grade_task3(df)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def grade_task4_medium_alt(df: pd.DataFrame) -> float:
    """Grader entrypoint for task4_medium_alt; reuses the task2 criteria."""
    return grade_task2(df)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def grade_task5_hard_alt(df: pd.DataFrame) -> float:
    """Grader entrypoint for task5_hard_alt; reuses the task3 criteria."""
    return grade_task3(df)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# Registry mapping task ids to grader callables; consumed by grade_task().
TASK_GRADERS = {
    "task1_easy": grade_task1_easy,
    "task2_medium": grade_task2_medium,
    "task3_hard": grade_task3_hard,
    "task4_medium_alt": grade_task4_medium_alt,
    "task5_hard_alt": grade_task5_hard_alt,
}
|
env/models.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict, List, Optional, Union
|
| 4 |
+
|
| 5 |
+
from pydantic import BaseModel
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Action(BaseModel):
    """One cleaning action requested by the agent.

    Only ``type`` is required; the other fields are interpreted per
    action type (see the environment's action space / system prompt).
    """

    # fill_missing | standardize_values | remove_duplicates | remove_row
    # | convert_type | clip_outliers | submit
    type: str
    column: Optional[str] = None
    row_id: Optional[int] = None
    strategy: Optional[str] = None  # mean | median | mode | constant (fill_missing)
    value: Optional[Union[str, float, int]] = None  # fill value when strategy == "constant"
    mapping: Optional[Dict[str, str]] = None  # old -> new values (standardize_values)
    target_type: Optional[str] = None  # float | int | str | datetime (convert_type)
    lower: Optional[float] = None  # lower bound for clip_outliers
    upper: Optional[float] = None  # upper bound for clip_outliers
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TablePreview(BaseModel):
    """A truncated, JSON-serializable view of the working DataFrame."""

    columns: List[str]
    rows: List[Dict[str, Any]]  # record-per-row dicts
    shape: List[int]  # [n_rows, n_cols] of the full table
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Observation(BaseModel):
    """Structured observation returned to the agent on reset and after each step."""

    task_id: str
    task_description: str
    table_preview: TablePreview
    # presumably column name -> dtype string; confirm in environment.py
    schema_info: Dict[str, str]
    valid_actions: List[str]
    step: int
    max_steps: int
    cleaning_log: List[str]  # record of actions applied so far
    issues_detected: List[str]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class StepResult(BaseModel):
    """Result of a single environment step (gym-style tuple as a model)."""

    observation: Observation
    reward: float
    done: bool
    info: Dict[str, Any]  # auxiliary data, e.g. an "error" message
|
env/rewards.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def compute_reward(action_type: str, context: dict) -> float:
    """
    Intermediate reward shaping.
    Final episode reward comes from the grader (called at submit).

    Count-based actions earn a capped per-unit reward; flat actions earn
    a fixed amount; unknown actions earn nothing.
    """
    # Flat-reward actions.
    if action_type == "remove_row":
        return 0.05
    if action_type == "convert_type":
        return 0.15

    # Count-based actions: (per-unit reward, cap, units affected).
    if action_type == "fill_missing":
        per_unit, cap, units = 0.08, 0.30, context.get("filled", 0)
    elif action_type == "standardize_values":
        per_unit, cap, units = 0.06, 0.25, context.get("replaced", 0)
    elif action_type == "remove_duplicates":
        per_unit, cap, units = 0.15, 0.30, context.get("removed", 0)
    elif action_type == "clip_outliers":
        # Clipping always earns at least one unit of credit.
        per_unit, cap, units = 0.10, 0.30, max(context.get("clipped", 0), 1)
    else:
        return 0.0

    return round(min(per_unit * units, cap), 4)
|
env/tasks.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
from typing import Any, Dict
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Deliberately dirty customer rows: one exact duplicate (Alice), missing
# email/ages, and inconsistent country spellings the agent must standardize.
TASK1_DIRTY = [
    {"name": "Alice Johnson", "email": "alice@email.com", "country": "USA", "age": 28.0},
    {"name": "Bob Smith", "email": "bob@email.com", "country": "United States", "age": None},
    {"name": "Carol White", "email": "carol@email.com", "country": "UK", "age": 35.0},
    {"name": "Alice Johnson", "email": "alice@email.com", "country": "USA", "age": 28.0},
    {"name": "Dave Brown", "email": None, "country": "US", "age": 42.0},
    {"name": "Eve Davis", "email": "eve@email.com", "country": "United Kingdom", "age": 31.0},
    {"name": "Frank Miller", "email": "frank@email.com", "country": "Canada", "age": None},
    {"name": "Grace Wilson", "email": "grace@email.com", "country": "CAN", "age": 25.0},
    {"name": "Henry Moore", "email": "henry@email.com", "country": "australia", "age": 38.0},
    {"name": "Iris Taylor", "email": "iris@email.com", "country": "AUS", "age": 29.0},
]
|
| 20 |
+
|
| 21 |
+
# Shown verbatim to the agent as the task prompt.
TASK1_DESCRIPTION = (
    "Clean a customer dataset. Issues to fix:\n"
    "1) Remove exact duplicate rows\n"
    "2) Fill missing emails using constant 'unknown@email.com'\n"
    "3) Fill missing ages using median\n"
    "4) Standardize country names to United States, United Kingdom, Canada, Australia"
)
|
| 28 |
+
|
| 29 |
+
# Deliberately dirty order rows: mixed date formats, "$"-prefixed price
# strings, category typos/casing ("Furnitre", "ELECTRONICS", "Electronix"),
# and missing price/category values.
TASK2_DIRTY = [
    {"order_id": 1, "date": "2023-01-15", "product": "Laptop", "category": "Electronics", "price": "$1200.00", "quantity": 2},
    {"order_id": 2, "date": "02/20/2023", "product": "Chair", "category": "Furniture", "price": "$250.50", "quantity": 1},
    {"order_id": 3, "date": "Mar 10, 2023", "product": "Headphones", "category": "Electronics", "price": "$89.99", "quantity": 3},
    {"order_id": 4, "date": "2023-04-05", "product": "Desk", "category": "Furnitre", "price": "$450.00", "quantity": 1},
    {"order_id": 5, "date": "05/12/2023", "product": "Monitor", "category": "Electronics", "price": "320.00", "quantity": 2},
    {"order_id": 6, "date": "2023-06-18", "product": "Keyboard", "category": None, "price": "$75.00", "quantity": 5},
    {"order_id": 7, "date": "July 22 2023", "product": "Mouse", "category": "Electronics", "price": "$35.00", "quantity": 4},
    {"order_id": 8, "date": "2023-08-30", "product": "Bookshelf", "category": "Furniture", "price": None, "quantity": 1},
    {"order_id": 9, "date": "09-14-2023", "product": "Webcam", "category": "ELECTRONICS", "price": "$65.00", "quantity": 2},
    {"order_id": 10, "date": "2023-10-01", "product": "Lamp", "category": "Furniture", "price": "$45.00", "quantity": 3},
    {"order_id": 11, "date": "11/15/2023", "product": "Tablet", "category": "Electronix", "price": "$599.00", "quantity": 1},
    {"order_id": 12, "date": "2023-12-20", "product": "Sofa", "category": "Furniture", "price": "$1100.00", "quantity": 1},
]
|
| 127 |
+
|
| 128 |
+
# Shown verbatim to the agent as the task prompt.
TASK2_DESCRIPTION = (
    "Clean an e-commerce orders dataset. Issues to fix:\n"
    "1) Normalise all dates to YYYY-MM-DD format using convert_type(date, datetime)\n"
    "2) Convert price column to float (strips $ signs automatically)\n"
    "3) Standardise category typos: Furnitre to Furniture, ELECTRONICS to Electronics, Electronix to Electronics\n"
    "4) Fill missing price with median; fill or remove missing category rows"
)
|
| 135 |
+
|
| 136 |
+
# Deliberately dirty analytics rows: duplicate user_ids (U001 near-dupe,
# U003 exact dupe), session_duration outliers (>> 1000s), an out-of-range
# bounce_rate (1.50), and a missing page_views value.
TASK3_DIRTY = [
    {"user_id": "U001", "name": "Alice Johnson", "page_views": 45, "session_duration": 320, "bounce_rate": 0.25},
    {"user_id": "U001", "name": "Alice J.", "page_views": 45, "session_duration": 315, "bounce_rate": 0.25},
    {"user_id": "U002", "name": "Bob Smith", "page_views": 12, "session_duration": 85000, "bounce_rate": 0.80},
    {"user_id": "U003", "name": "Carol White", "page_views": 67, "session_duration": 450, "bounce_rate": 0.15},
    {"user_id": "U004", "name": "Dave Brown", "page_views": 23, "session_duration": 190, "bounce_rate": 0.55},
    {"user_id": "U005", "name": "Eve Davis", "page_views": 89, "session_duration": 95000, "bounce_rate": 0.10},
    {"user_id": "U003", "name": "Carol White", "page_views": 67, "session_duration": 450, "bounce_rate": 0.15},
    {"user_id": "U006", "name": "Frank Miller", "page_views": None, "session_duration": 280, "bounce_rate": 0.45},
    {"user_id": "U007", "name": "Grace Wilson", "page_views": 34, "session_duration": 360, "bounce_rate": 1.50},
    {"user_id": "U008", "name": "Henry Moore", "page_views": 56, "session_duration": 420, "bounce_rate": 0.35},
    {"user_id": "U009", "name": "Iris Taylor", "page_views": 78, "session_duration": 78000, "bounce_rate": 0.20},
    {"user_id": "U010", "name": "Jack Wilson", "page_views": 19, "session_duration": 150, "bounce_rate": 0.70},
]
|
| 150 |
+
|
| 151 |
+
# Shown verbatim to the agent as the task prompt.
TASK3_DESCRIPTION = (
    "Clean a web analytics dataset. Issues to fix:\n"
    "1) Remove duplicate user_ids (exact + near-duplicates, keep first occurrence)\n"
    "2) Clip session_duration outliers to max 1000 seconds\n"
    "3) Clip bounce_rate to valid range [0.0, 1.0]\n"
    "4) Fill missing page_views with median"
)

# Alt tasks reuse the task2/task3 data and graders with a different prompt.
TASK4_DESCRIPTION = (
    "Alternative medium data-cleaning scenario based on e-commerce orders.\n"
    "Use the same cleaning operations as task2_medium and submit a clean table."
)

TASK5_DESCRIPTION = (
    "Alternative hard data-cleaning scenario based on analytics logs.\n"
    "Use the same cleaning operations as task3_hard and submit a clean table."
)
|
| 168 |
+
|
| 169 |
+
# Grader entrypoints in "module:function" form, for loaders that expect
# the setuptools-style colon separator.
TASK_GRADER_ENTRYPOINTS_COLON = {
    "task1_easy": "env.graders:grade_task1_easy",
    "task2_medium": "env.graders:grade_task2_medium",
    "task3_hard": "env.graders:grade_task3_hard",
    "task4_medium_alt": "env.graders:grade_task4_medium_alt",
    "task5_hard_alt": "env.graders:grade_task5_hard_alt",
}

# Same entrypoints in fully dotted form, for loaders that import by path.
TASK_GRADER_ENTRYPOINTS_DOTTED = {
    "task1_easy": "env.graders.grade_task1_easy",
    "task2_medium": "env.graders.grade_task2_medium",
    "task3_hard": "env.graders.grade_task3_hard",
    "task4_medium_alt": "env.graders.grade_task4_medium_alt",
    "task5_hard_alt": "env.graders.grade_task5_hard_alt",
}
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def get_task(task_id: str) -> Dict[str, Any]:
    """Return the configuration for one task.

    The returned dict contains: ``description``, ``dirty_df`` (a fresh
    DataFrame the caller may mutate), ``task_id``, ``difficulty``, and the
    grader entrypoints under ``grader`` (dotted), ``grader_fn`` and
    ``grader_path`` (both colon form, kept for loader compatibility).

    Raises:
        ValueError: if *task_id* is not a known task.

    Previously this built five near-identical dicts (and five DataFrames)
    on every call; the per-task differences are now a small spec table and
    only the requested task's DataFrame is constructed.
    """
    # task_id -> (description, raw rows, difficulty). The alt tasks reuse
    # the task2/task3 data with their own prompts.
    specs = {
        "task1_easy": (TASK1_DESCRIPTION, TASK1_DIRTY, "easy"),
        "task2_medium": (TASK2_DESCRIPTION, TASK2_DIRTY, "medium"),
        "task3_hard": (TASK3_DESCRIPTION, TASK3_DIRTY, "hard"),
        "task4_medium_alt": (TASK4_DESCRIPTION, TASK2_DIRTY, "medium"),
        "task5_hard_alt": (TASK5_DESCRIPTION, TASK3_DIRTY, "hard"),
    }
    if task_id not in specs:
        raise ValueError(f"Unknown task_id '{task_id}'. Choose from: {list(specs)}")

    description, rows, difficulty = specs[task_id]
    return {
        "description": description,
        # Built fresh per call, so callers can mutate it safely.
        "dirty_df": pd.DataFrame(rows),
        "task_id": task_id,
        "difficulty": difficulty,
        "grader": TASK_GRADER_ENTRYPOINTS_DOTTED[task_id],
        "grader_fn": TASK_GRADER_ENTRYPOINTS_COLON[task_id],
        "grader_path": TASK_GRADER_ENTRYPOINTS_COLON[task_id],
    }
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
TASK_IDS = ["task1_easy", "task2_medium", "task3_hard", "task4_medium_alt", "task5_hard_alt"]
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def list_tasks() -> list[dict[str, Any]]:
    """Return public metadata for every task (for the /tasks endpoint).

    Each entry carries the task id (under both ``id`` and ``task_id`` for
    compatibility), its difficulty, the shared ``max_steps`` budget, and
    the grader entrypoints in dotted (``grader``) and colon
    (``grader_fn`` / ``grader_path``) form.

    Previously this was five copy-pasted dict literals; it is now derived
    from TASK_IDS so adding a task cannot drift out of sync.
    """
    difficulties = {
        "task1_easy": "easy",
        "task2_medium": "medium",
        "task3_hard": "hard",
        "task4_medium_alt": "medium",
        "task5_hard_alt": "hard",
    }
    return [
        {
            "id": task_id,
            "task_id": task_id,
            "difficulty": difficulties[task_id],
            "max_steps": 20,
            "grader": TASK_GRADER_ENTRYPOINTS_DOTTED[task_id],
            "grader_fn": TASK_GRADER_ENTRYPOINTS_COLON[task_id],
            "grader_path": TASK_GRADER_ENTRYPOINTS_COLON[task_id],
        }
        for task_id in TASK_IDS
    ]
|
inference.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
|
| 7 |
+
from dotenv import load_dotenv
|
| 8 |
+
from openai import OpenAI
|
| 9 |
+
|
| 10 |
+
from env.environment import DataCleaningEnv
|
| 11 |
+
from env.models import Action
|
| 12 |
+
|
| 13 |
+
load_dotenv()
|
| 14 |
+
|
| 15 |
+
# Runtime configuration, overridable via environment variables / .env.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/llama-4-scout-17b-16e-instruct")
HF_TOKEN = os.getenv("HF_TOKEN")  # API key for the OpenAI-compatible endpoint
TASK_ID = os.getenv("TASK_ID", "task1_easy")
MAX_STEPS = int(os.getenv("MAX_STEPS", "15"))
ENV_NAME = "data-cleaning-benchmark"

# Fail fast at import time if the API key is missing.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")

client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
|
| 26 |
+
|
| 27 |
+
SYSTEM_PROMPT = """You are a data cleaning agent. Analyse the observation and choose ONE cleaning action.
|
| 28 |
+
|
| 29 |
+
Available action types and required fields:
|
| 30 |
+
fill_missing -> column (str), strategy (mean|median|mode|constant), value (if constant)
|
| 31 |
+
standardize_values -> column (str), mapping (dict old->new)
|
| 32 |
+
remove_duplicates -> (no extra fields)
|
| 33 |
+
remove_row -> row_id (int from _row_id column in preview)
|
| 34 |
+
convert_type -> column (str), target_type (float|int|str|datetime)
|
| 35 |
+
clip_outliers -> column (str), lower (float|null), upper (float|null)
|
| 36 |
+
submit -> (no extra fields; use when dataset is clean)
|
| 37 |
+
|
| 38 |
+
Rules:
|
| 39 |
+
- Respond with a SINGLE valid JSON object and NOTHING else.
|
| 40 |
+
- No markdown fences, no explanation.
|
| 41 |
+
- When no issues remain, always respond with: {"type": "submit"}
|
| 42 |
+
|
| 43 |
+
Examples:
|
| 44 |
+
{"type": "remove_duplicates"}
|
| 45 |
+
{"type": "fill_missing", "column": "age", "strategy": "median"}
|
| 46 |
+
{"type": "standardize_values", "column": "country", "mapping": {"USA": "United States", "US": "United States", "UK": "United Kingdom", "CAN": "Canada", "australia": "Australia", "AUS": "Australia"}}
|
| 47 |
+
{"type": "convert_type", "column": "date", "target_type": "datetime"}
|
| 48 |
+
{"type": "convert_type", "column": "price", "target_type": "float"}
|
| 49 |
+
{"type": "clip_outliers", "column": "session_duration", "lower": 0.0, "upper": 1000.0}
|
| 50 |
+
{"type": "submit"}
|
| 51 |
+
"""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def get_action(obs_dict: dict, history: list[dict]) -> dict:
    """Ask the LLM for the next cleaning action.

    Appends the observation (user turn) and the raw model reply
    (assistant turn) to *history* so the conversation accumulates
    across steps.

    Returns the parsed action dict. If the reply contains no parseable
    JSON object, falls back to {"type": "submit"}.
    """
    user_msg = {
        "role": "user",
        "content": (
            "Current observation:\n" + json.dumps(obs_dict, indent=2, default=str) + "\n\nNext action (JSON only):"
        ),
    }
    history.append(user_msg)

    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=[{"role": "system", "content": SYSTEM_PROMPT}] + history,
        max_tokens=256,
        temperature=0,
    )
    # content may be None (e.g. filtered/empty replies); treat as empty text
    # instead of crashing on .strip().
    raw = (response.choices[0].message.content or "").strip()
    history.append({"role": "assistant", "content": raw})

    # Strip accidental markdown fences before parsing.
    clean = re.sub(r"```[a-z]*\n?", "", raw).replace("```", "").strip()
    try:
        return json.loads(clean)
    except json.JSONDecodeError:
        # Last resort: take the first {...} span in the reply. If even that
        # is not valid JSON, honor the documented submit fallback instead
        # of letting the second decode error propagate.
        match = re.search(r"\{.*\}", clean, re.DOTALL)
        if match:
            try:
                return json.loads(match.group())
            except json.JSONDecodeError:
                return {"type": "submit"}
        return {"type": "submit"}
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def run_inference() -> None:
    """Run one agent episode against the configured task and print a log.

    Emits machine-parseable [START]/[STEP]/[END] lines. An episode ends
    when the environment reports done, or is force-submitted after
    MAX_STEPS. Success means the final grader score reached 0.5.
    """
    env = DataCleaningEnv()
    rewards: list[float] = []
    history: list[dict] = []  # chat transcript shared across get_action calls
    step = 0
    done = False
    success = False

    print(f"[START] task={TASK_ID} env={ENV_NAME} model={MODEL_NAME}", flush=True)

    try:
        obs = env.reset(task_id=TASK_ID)

        while not done and step < MAX_STEPS:
            try:
                action_dict = get_action(obs.model_dump(), history)
                action = Action(**action_dict)
            except Exception:
                # Any LLM/parse/validation failure degrades to a submit.
                action_dict = {"type": "submit"}
                action = Action(type="submit")

            result = env.step(action)
            obs = result.observation
            done = result.done
            reward = result.reward
            error = result.info.get("error")

            rewards.append(reward)
            step += 1

            action_str = json.dumps(action_dict, separators=(",", ":"), default=str)
            print(
                f"[STEP] step={step} action={action_str} "
                f"reward={reward:.2f} done={'true' if done else 'false'} "
                f"error={error if error else 'null'}",
                flush=True,
            )

        # Step budget exhausted without the agent submitting: force it so
        # the grader runs and final_score is populated.
        if not done:
            result = env.step(Action(type="submit"))
            rewards.append(result.reward)
            step += 1
            print(
                f"[STEP] step={step} action={{\"type\":\"submit\"}} "
                f"reward={result.reward:.2f} done=true error={result.info.get('error') or 'null'}",
                flush=True,
            )

        success = bool(env.final_score >= 0.5)
    except Exception:
        success = False
    finally:
        # Best-effort cleanup; close() may not exist on all env versions.
        try:
            if hasattr(env, "close"):
                env.close()
        except Exception:
            pass

    rewards_str = ",".join(f"{reward:.2f}" for reward in rewards)
    # NOTE(review): env.final_score is read here even on the failure path —
    # presumably DataCleaningEnv initializes it in __init__; confirm.
    print(
        f"[END] success={'true' if success else 'false'} "
        f"steps={step} score={env.final_score:.2f} rewards={rewards_str}",
        flush=True,
    )
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
run_inference()
|
openenv.yaml
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: data-cleaning-benchmark
|
| 2 |
+
version: "1.0.0"
|
| 3 |
+
description: >
|
| 4 |
+
A multi-task LLM agent benchmark for real-world tabular data cleaning.
|
| 5 |
+
The agent receives a dirty dataset and must apply structured cleaning
|
| 6 |
+
actions to fix duplicates, missing values, format issues, and outliers.
|
| 7 |
+
|
| 8 |
+
author: "Jayesh"
|
| 9 |
+
license: MIT
|
| 10 |
+
|
| 11 |
+
tasks:
|
| 12 |
+
- id: task1_easy
|
| 13 |
+
task_id: task1_easy
|
| 14 |
+
name: "Basic Customer Data Cleanup"
|
| 15 |
+
difficulty: easy
|
| 16 |
+
max_steps: 20
|
| 17 |
+
description: "Remove duplicates, fill missing values, standardise country names."
|
| 18 |
+
grader: env.graders.grade_task1_easy
|
| 19 |
+
grader_fn: env.graders:grade_task1_easy
|
| 20 |
+
grader_path: env.graders:grade_task1_easy
|
| 21 |
+
- id: task2_medium
|
| 22 |
+
task_id: task2_medium
|
| 23 |
+
name: "E-commerce Orders Normalisation"
|
| 24 |
+
difficulty: medium
|
| 25 |
+
max_steps: 20
|
| 26 |
+
description: "Fix mixed date formats, convert price strings, correct category typos."
|
| 27 |
+
grader: env.graders.grade_task2_medium
|
| 28 |
+
grader_fn: env.graders:grade_task2_medium
|
| 29 |
+
grader_path: env.graders:grade_task2_medium
|
| 30 |
+
- id: task3_hard
|
| 31 |
+
task_id: task3_hard
|
| 32 |
+
name: "Analytics Data Deep Clean"
|
| 33 |
+
difficulty: hard
|
| 34 |
+
max_steps: 20
|
| 35 |
+
description: "Resolve duplicate user IDs, clip session outliers, fix invalid bounce rates."
|
| 36 |
+
grader: env.graders.grade_task3_hard
|
| 37 |
+
grader_fn: env.graders:grade_task3_hard
|
| 38 |
+
grader_path: env.graders:grade_task3_hard
|
| 39 |
+
- id: task4_medium_alt
|
| 40 |
+
task_id: task4_medium_alt
|
| 41 |
+
name: "E-commerce Orders Cleanup (Alt)"
|
| 42 |
+
difficulty: medium
|
| 43 |
+
max_steps: 20
|
| 44 |
+
description: "Alternative medium scenario sharing the same grading criteria as task2_medium."
|
| 45 |
+
grader: env.graders.grade_task4_medium_alt
|
| 46 |
+
grader_fn: env.graders:grade_task4_medium_alt
|
| 47 |
+
grader_path: env.graders:grade_task4_medium_alt
|
| 48 |
+
- id: task5_hard_alt
|
| 49 |
+
task_id: task5_hard_alt
|
| 50 |
+
name: "Analytics Deep Clean (Alt)"
|
| 51 |
+
difficulty: hard
|
| 52 |
+
max_steps: 20
|
| 53 |
+
description: "Alternative hard scenario sharing the same grading criteria as task3_hard."
|
| 54 |
+
grader: env.graders.grade_task5_hard_alt
|
| 55 |
+
grader_fn: env.graders:grade_task5_hard_alt
|
| 56 |
+
grader_path: env.graders:grade_task5_hard_alt
|
| 57 |
+
|
| 58 |
+
observation_space:
|
| 59 |
+
type: structured_json
|
| 60 |
+
fields:
|
| 61 |
+
- task_id
|
| 62 |
+
- task_description
|
| 63 |
+
- table_preview
|
| 64 |
+
- schema_info
|
| 65 |
+
- valid_actions
|
| 66 |
+
- step / max_steps
|
| 67 |
+
- cleaning_log
|
| 68 |
+
- issues_detected
|
| 69 |
+
|
| 70 |
+
action_space:
|
| 71 |
+
type: structured_json
|
| 72 |
+
actions:
|
| 73 |
+
- name: fill_missing
|
| 74 |
+
params: ["column", "strategy(mean|median|mode|constant)", "value?"]
|
| 75 |
+
- name: standardize_values
|
| 76 |
+
params: ["column", "mapping(dict)"]
|
| 77 |
+
- name: remove_duplicates
|
| 78 |
+
params: []
|
| 79 |
+
- name: remove_row
|
| 80 |
+
params: ["row_id(int)"]
|
| 81 |
+
- name: convert_type
|
| 82 |
+
params: ["column", "target_type(float|int|str|datetime)"]
|
| 83 |
+
- name: clip_outliers
|
| 84 |
+
params: ["column", "lower?", "upper?"]
|
| 85 |
+
- name: submit
|
| 86 |
+
params: []
|
| 87 |
+
|
| 88 |
+
reward:
|
| 89 |
+
type: shaped
|
| 90 |
+
intermediate: true
|
| 91 |
+
range: [0.01, 0.99]
|
| 92 |
+
description: >
|
| 93 |
+
Positive rewards for correct cleaning steps; small penalties for
|
| 94 |
+
invalid or wasted actions; final grader score awarded on submit().
|
| 95 |
+
|
| 96 |
+
api:
|
| 97 |
+
base_path: "/"
|
| 98 |
+
endpoints:
|
| 99 |
+
reset: "POST /reset"
|
| 100 |
+
step: "POST /step"
|
| 101 |
+
state: "GET /state"
|
| 102 |
+
step_legacy: "POST /step/{session_id}"
|
| 103 |
+
state_legacy: "GET /state/{session_id}"
|
| 104 |
+
health: "GET /health"
|
| 105 |
+
tasks: "GET /tasks"
|
| 106 |
+
|
| 107 |
+
runtime:
|
| 108 |
+
language: python
|
| 109 |
+
version: "3.11"
|
| 110 |
+
port: 7860
|
| 111 |
+
framework: fastapi
|
| 112 |
+
|
| 113 |
+
tags:
|
| 114 |
+
- openenv
|
| 115 |
+
- data-cleaning
|
| 116 |
+
- llm-benchmark
|
| 117 |
+
- tabular
|
pyproject.toml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=68", "wheel"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "data-cleaning-benchmark"
|
| 7 |
+
version = "1.0.0"
|
| 8 |
+
description = "A multi-task OpenEnv benchmark for tabular data cleaning."
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.11"
|
| 11 |
+
dependencies = [
|
| 12 |
+
"fastapi==0.110.0",
|
| 13 |
+
"uvicorn==0.27.1",
|
| 14 |
+
"pydantic==2.6.3",
|
| 15 |
+
"pandas==2.2.1",
|
| 16 |
+
"numpy==1.26.4",
|
| 17 |
+
"openai>=2.7.2",
|
| 18 |
+
"openenv>=0.2.0",
|
| 19 |
+
"python-dotenv==1.0.1",
|
| 20 |
+
"httpx==0.27.0",
|
| 21 |
+
"pytest==8.1.1",
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[project.scripts]
|
| 25 |
+
server = "server.cli:main"
|
| 26 |
+
|
| 27 |
+
[tool.pytest.ini_options]
|
| 28 |
+
testpaths = ["tests"]
|
requirements.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
fastapi==0.110.0
uvicorn==0.27.1
pydantic==2.6.3
pandas==2.2.1
numpy==1.26.4
openai>=2.7.2
openenv>=0.2.0
python-dotenv==1.0.1
httpx==0.27.0
pytest==8.1.1
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
server/app.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import uvicorn
|
| 6 |
+
|
| 7 |
+
from app import app
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def main() -> None:
|
| 11 |
+
host = os.getenv("HOST", "0.0.0.0")
|
| 12 |
+
port = int(os.getenv("PORT", "7860"))
|
| 13 |
+
uvicorn.run("server.app:app", host=host, port=port)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
if __name__ == "__main__":
|
| 17 |
+
main()
|
server/cli.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import uvicorn
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def main() -> None:
|
| 9 |
+
host = os.getenv("HOST", "0.0.0.0")
|
| 10 |
+
port = int(os.getenv("PORT", "7860"))
|
| 11 |
+
uvicorn.run("server.app:app", host=host, port=port)
|
tests/test_graders.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the task graders: scores stay in range and respond to data quality."""
import pandas as pd

from env.graders import grade_task, grade_task1, grade_task2, grade_task3
from env.tasks import get_task


def test_grade_task1_dirty_is_low():
    # The untouched dirty table should earn a low but non-zero score.
    dirty = get_task("task1_easy")["dirty_df"]
    assert 0.0 < grade_task1(dirty) <= 0.5


def test_grade_task1_perfect_is_bounded():
    perfect = pd.DataFrame(
        {
            "name": ["Alice", "Bob", "Carol"],
            "email": ["a@x.com", "b@x.com", "c@x.com"],
            "country": ["United States", "United Kingdom", "Australia"],
            "age": [28.0, 35.0, 42.0],
        }
    )
    # A fully clean table scores exactly 0.99 -- the upper bound of the
    # reward range declared in openenv.yaml.
    assert grade_task1(perfect) == 0.99


def test_grade_task1_partial():
    partially_clean = pd.DataFrame(
        {
            "name": ["Alice", "Bob"],
            "email": ["a@x.com", "b@x.com"],
            "country": ["USA", "UK"],  # country names not yet normalised
            "age": [28.0, 35.0],
        }
    )
    score = grade_task1(partially_clean)
    assert 0.4 < score < 0.99


def test_grade_task2_score_range():
    assert 0.0 < grade_task2(get_task("task2_medium")["dirty_df"]) < 1.0


def test_grade_task3_score_range():
    assert 0.0 < grade_task3(get_task("task3_hard")["dirty_df"]) < 1.0


def test_grade_task_dispatcher():
    # The dispatcher must route every known task id to a working grader.
    task_ids = (
        "task1_easy",
        "task2_medium",
        "task3_hard",
        "task4_medium_alt",
        "task5_hard_alt",
    )
    for task_id in task_ids:
        score = grade_task(task_id, get_task(task_id)["dirty_df"])
        assert 0.0 < score < 1.0


def test_grader_not_constant():
    # A cleaned table must not score the same as the dirty one.
    dirty_score = grade_task1(get_task("task1_easy")["dirty_df"])
    cleaned = pd.DataFrame(
        {
            "name": ["Alice", "Bob"],
            "email": ["a@x.com", "b@x.com"],
            "country": ["United States", "Australia"],
            "age": [28.0, 35.0],
        }
    )
    assert grade_task1(cleaned) != dirty_score
|
tests/test_reset.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for DataCleaningEnv.reset(): task selection, initial observation, errors."""
import pytest

from env.environment import DataCleaningEnv


def _fresh_obs(task_id=None):
    # Helper: build a new environment and return its initial observation.
    env = DataCleaningEnv()
    return env.reset() if task_id is None else env.reset(task_id=task_id)


def test_reset_default():
    obs = _fresh_obs()
    assert obs.task_id == "task1_easy"
    assert obs.step == 0
    assert obs.max_steps == 20
    assert len(obs.table_preview.rows) > 0
    assert "remove_duplicates" in obs.valid_actions


def test_reset_task2():
    obs = _fresh_obs("task2_medium")
    assert obs.task_id == "task2_medium"
    assert obs.step == 0


def test_reset_task3():
    assert _fresh_obs("task3_hard").task_id == "task3_hard"


def test_reset_task4_alt():
    assert _fresh_obs("task4_medium_alt").task_id == "task4_medium_alt"


def test_reset_task5_alt():
    assert _fresh_obs("task5_hard_alt").task_id == "task5_hard_alt"


def test_reset_unknown_task():
    # Unknown task ids must be rejected loudly, not defaulted.
    with pytest.raises(ValueError):
        _fresh_obs("nonexistent_task")


def test_issues_detected_on_reset():
    obs = _fresh_obs("task1_easy")
    assert len(obs.issues_detected) > 0
|
tests/test_step.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for DataCleaningEnv.step(): rewards, penalties, termination, transforms."""
import re

from env.environment import DataCleaningEnv
from env.models import Action


def _env_for(task_id):
    # Helper: reset a fresh environment on the given task and return it.
    env = DataCleaningEnv()
    env.reset(task_id=task_id)
    return env


def test_remove_duplicates_gives_positive_reward():
    env = _env_for("task1_easy")
    result = env.step(Action(type="remove_duplicates"))
    assert result.reward > 0
    assert not result.done


def test_fill_missing_median():
    env = _env_for("task1_easy")
    result = env.step(Action(type="fill_missing", column="age", strategy="median"))
    assert result.reward >= 0
    # After the fill no NaNs may remain in the target column.
    assert env.current_df["age"].isna().sum() == 0


def test_invalid_action_penalised():
    env = _env_for("task1_easy")
    result = env.step(Action(type="fill_missing", column="nonexistent_col", strategy="mean"))
    assert result.reward < 0
    assert result.info["error"] is not None


def test_submit_ends_episode():
    env = _env_for("task1_easy")
    result = env.step(Action(type="submit"))
    assert result.done
    assert result.info["final_score"] >= 0.0


def test_step_after_done_is_no_op():
    # Once the episode is over, further steps keep reporting the final score.
    env = _env_for("task1_easy")
    env.step(Action(type="submit"))
    result = env.step(Action(type="remove_duplicates"))
    assert result.done
    assert 0.0 < result.reward < 1.0
    assert result.reward == result.info["final_score"]


def test_convert_type_datetime():
    env = _env_for("task2_medium")
    result = env.step(Action(type="convert_type", column="date", target_type="datetime"))
    assert result.reward > 0
    first_date = env.current_df["date"].dropna().iloc[0]
    # Converted values should render in ISO YYYY-MM-DD form.
    assert re.match(r"\d{4}-\d{2}-\d{2}", str(first_date))


def test_clip_outliers():
    env = _env_for("task3_hard")
    result = env.step(
        Action(type="clip_outliers", column="session_duration", lower=0.0, upper=1000.0)
    )
    assert result.reward > 0
    assert env.current_df["session_duration"].max() <= 1000.0
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|