Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
- .gitignore +12 -0
- Dockerfile +1 -9
- env/__init__.py +0 -1
- env/environment.py +19 -55
- env/models.py +21 -42
- env/tasks.py +43 -104
- inference.py +1 -40
- server.py +7 -22
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.pyc
|
| 3 |
+
*.pyo
|
| 4 |
+
.env
|
| 5 |
+
*.egg-info/
|
| 6 |
+
dist/
|
| 7 |
+
build/
|
| 8 |
+
.vscode/
|
| 9 |
+
.idea/
|
| 10 |
+
test_space.py
|
| 11 |
+
test_all_tasks.py
|
| 12 |
+
*.pdf
|
Dockerfile
CHANGED
|
@@ -2,24 +2,16 @@ FROM python:3.11-slim
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
-
|
| 6 |
-
RUN apt-get update && apt-get install -y --no-install-recommends \
|
| 7 |
-
curl \
|
| 8 |
-
&& rm -rf /var/lib/apt/lists/*
|
| 9 |
|
| 10 |
-
# Copy requirements first for caching
|
| 11 |
COPY requirements.txt .
|
| 12 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
|
| 14 |
-
# Copy application code
|
| 15 |
COPY . .
|
| 16 |
|
| 17 |
-
# Expose port
|
| 18 |
EXPOSE 8000
|
| 19 |
|
| 20 |
-
# Health check
|
| 21 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 22 |
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 23 |
|
| 24 |
-
# Run the server
|
| 25 |
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
|
|
|
| 2 |
|
| 3 |
WORKDIR /app
|
| 4 |
|
| 5 |
+
RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
|
|
|
|
|
|
|
|
|
|
| 6 |
|
|
|
|
| 7 |
COPY requirements.txt .
|
| 8 |
RUN pip install --no-cache-dir -r requirements.txt
|
| 9 |
|
|
|
|
| 10 |
COPY . .
|
| 11 |
|
|
|
|
| 12 |
EXPOSE 8000
|
| 13 |
|
|
|
|
| 14 |
HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
|
| 15 |
CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
|
| 16 |
|
|
|
|
| 17 |
CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
|
env/__init__.py
CHANGED
|
@@ -1,2 +1 @@
|
|
| 1 |
-
"""Data Validation Pipeline - OpenEnv Environment."""
|
| 2 |
from env.models import DataCleanAction, DataCleanObservation, DataCleanState
|
|
|
|
|
|
|
| 1 |
from env.models import DataCleanAction, DataCleanObservation, DataCleanState
|
env/environment.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
| 1 |
-
"""Core Environment implementation for the Data Validation Pipeline."""
|
| 2 |
-
|
| 3 |
import uuid
|
| 4 |
from typing import Any, Dict, List, Optional
|
| 5 |
|
|
@@ -8,44 +6,25 @@ from env.tasks import generate_task, get_task_names, grade_action
|
|
| 8 |
|
| 9 |
|
| 10 |
class DataValidationEnvironment:
|
| 11 |
-
|
| 12 |
-
Data Validation Pipeline Environment.
|
| 13 |
-
|
| 14 |
-
An RL environment where the agent must clean and validate structured datasets
|
| 15 |
-
by identifying and fixing errors (missing values, type mismatches, format violations,
|
| 16 |
-
range errors, and duplicates).
|
| 17 |
-
|
| 18 |
-
Follows OpenEnv Environment interface: reset(), step(), state().
|
| 19 |
-
"""
|
| 20 |
-
|
| 21 |
def __init__(self):
|
| 22 |
self._state = DataCleanState()
|
| 23 |
self._ground_truth: List[Dict[str, Any]] = []
|
| 24 |
self._errors: List[Dict[str, Any]] = []
|
| 25 |
self._task_info: Dict[str, Any] = {}
|
| 26 |
self._field_names: List[str] = []
|
| 27 |
-
|
| 28 |
def reset(self, task_name: Optional[str] = None, seed: int = 42, **kwargs) -> DataCleanObservation:
|
| 29 |
-
"""
|
| 30 |
-
Reset the environment with a new task.
|
| 31 |
-
|
| 32 |
-
Args:
|
| 33 |
-
task_name: Task to load ('easy_missing_values', 'medium_mixed_errors', 'hard_multi_constraint')
|
| 34 |
-
seed: Random seed for reproducibility
|
| 35 |
-
|
| 36 |
-
Returns:
|
| 37 |
-
Initial observation
|
| 38 |
-
"""
|
| 39 |
if task_name is None:
|
| 40 |
task_name = "easy_missing_values"
|
| 41 |
-
|
| 42 |
task = generate_task(task_name, seed)
|
| 43 |
-
|
| 44 |
self._ground_truth = task["ground_truth"]
|
| 45 |
self._errors = task["errors"]
|
| 46 |
self._task_info = task
|
| 47 |
self._field_names = task["field_names"]
|
| 48 |
-
|
| 49 |
self._state = DataCleanState(
|
| 50 |
episode_id=str(uuid.uuid4()),
|
| 51 |
task_name=task_name,
|
|
@@ -61,7 +40,7 @@ class DataValidationEnvironment:
|
|
| 61 |
total_errors=len(self._errors),
|
| 62 |
last_actions=[],
|
| 63 |
)
|
| 64 |
-
|
| 65 |
return DataCleanObservation(
|
| 66 |
task_name=task_name,
|
| 67 |
task_description=task["description"],
|
|
@@ -80,27 +59,17 @@ class DataValidationEnvironment:
|
|
| 80 |
progress_pct=0.0,
|
| 81 |
field_names=self._field_names,
|
| 82 |
)
|
| 83 |
-
|
| 84 |
def step(self, action: DataCleanAction) -> DataCleanObservation:
|
| 85 |
-
"""
|
| 86 |
-
Execute an action to fix a data error.
|
| 87 |
-
|
| 88 |
-
Args:
|
| 89 |
-
action: The action to take
|
| 90 |
-
|
| 91 |
-
Returns:
|
| 92 |
-
Updated observation with reward
|
| 93 |
-
"""
|
| 94 |
if self._state.done:
|
| 95 |
return self._make_observation(0.0, "Episode already done. Call reset().")
|
| 96 |
-
|
| 97 |
self._state.step_count += 1
|
| 98 |
-
|
| 99 |
-
# Check for repeated identical action
|
| 100 |
action_key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
|
| 101 |
is_repeat = action_key in self._state.last_actions
|
| 102 |
self._state.last_actions.append(action_key)
|
| 103 |
-
|
| 104 |
if is_repeat:
|
| 105 |
reward = -0.1
|
| 106 |
message = "Penalty: repeated identical action"
|
|
@@ -116,39 +85,34 @@ class DataValidationEnvironment:
|
|
| 116 |
)
|
| 117 |
if fixed:
|
| 118 |
self._state.errors_fixed += 1
|
| 119 |
-
|
| 120 |
self._state.cumulative_reward += reward
|
| 121 |
self._state.reward_history.append(reward)
|
| 122 |
-
|
| 123 |
-
# Check termination conditions
|
| 124 |
errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
|
| 125 |
-
|
| 126 |
if errors_remaining == 0:
|
| 127 |
self._state.done = True
|
| 128 |
message += " | All errors fixed! Episode complete."
|
| 129 |
elif self._state.step_count >= self._state.max_steps:
|
| 130 |
self._state.done = True
|
| 131 |
message += f" | Max steps reached. {errors_remaining} errors remaining."
|
| 132 |
-
|
| 133 |
return self._make_observation(reward, message)
|
| 134 |
-
|
| 135 |
def state(self) -> DataCleanState:
|
| 136 |
-
"""Return the current environment state."""
|
| 137 |
return self._state
|
| 138 |
-
|
| 139 |
def get_task_names(self) -> List[str]:
|
| 140 |
-
"""Return available task names."""
|
| 141 |
return get_task_names()
|
| 142 |
-
|
| 143 |
def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
|
| 144 |
-
"""Create an observation from current state."""
|
| 145 |
errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
|
| 146 |
total = self._state.total_errors if self._state.total_errors > 0 else 1
|
| 147 |
progress = (self._state.errors_fixed / total) * 100
|
| 148 |
-
|
| 149 |
-
# Only show unfixed errors
|
| 150 |
unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
|
| 151 |
-
|
| 152 |
return DataCleanObservation(
|
| 153 |
task_name=self._state.task_name,
|
| 154 |
task_description=self._task_info.get("description", ""),
|
|
|
|
|
|
|
|
|
|
| 1 |
import uuid
|
| 2 |
from typing import Any, Dict, List, Optional
|
| 3 |
|
|
|
|
| 6 |
|
| 7 |
|
| 8 |
class DataValidationEnvironment:
|
| 9 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
def __init__(self):
|
| 11 |
self._state = DataCleanState()
|
| 12 |
self._ground_truth: List[Dict[str, Any]] = []
|
| 13 |
self._errors: List[Dict[str, Any]] = []
|
| 14 |
self._task_info: Dict[str, Any] = {}
|
| 15 |
self._field_names: List[str] = []
|
| 16 |
+
|
| 17 |
def reset(self, task_name: Optional[str] = None, seed: int = 42, **kwargs) -> DataCleanObservation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
if task_name is None:
|
| 19 |
task_name = "easy_missing_values"
|
| 20 |
+
|
| 21 |
task = generate_task(task_name, seed)
|
| 22 |
+
|
| 23 |
self._ground_truth = task["ground_truth"]
|
| 24 |
self._errors = task["errors"]
|
| 25 |
self._task_info = task
|
| 26 |
self._field_names = task["field_names"]
|
| 27 |
+
|
| 28 |
self._state = DataCleanState(
|
| 29 |
episode_id=str(uuid.uuid4()),
|
| 30 |
task_name=task_name,
|
|
|
|
| 40 |
total_errors=len(self._errors),
|
| 41 |
last_actions=[],
|
| 42 |
)
|
| 43 |
+
|
| 44 |
return DataCleanObservation(
|
| 45 |
task_name=task_name,
|
| 46 |
task_description=task["description"],
|
|
|
|
| 59 |
progress_pct=0.0,
|
| 60 |
field_names=self._field_names,
|
| 61 |
)
|
| 62 |
+
|
| 63 |
def step(self, action: DataCleanAction) -> DataCleanObservation:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
if self._state.done:
|
| 65 |
return self._make_observation(0.0, "Episode already done. Call reset().")
|
| 66 |
+
|
| 67 |
self._state.step_count += 1
|
| 68 |
+
|
|
|
|
| 69 |
action_key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
|
| 70 |
is_repeat = action_key in self._state.last_actions
|
| 71 |
self._state.last_actions.append(action_key)
|
| 72 |
+
|
| 73 |
if is_repeat:
|
| 74 |
reward = -0.1
|
| 75 |
message = "Penalty: repeated identical action"
|
|
|
|
| 85 |
)
|
| 86 |
if fixed:
|
| 87 |
self._state.errors_fixed += 1
|
| 88 |
+
|
| 89 |
self._state.cumulative_reward += reward
|
| 90 |
self._state.reward_history.append(reward)
|
| 91 |
+
|
|
|
|
| 92 |
errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
|
| 93 |
+
|
| 94 |
if errors_remaining == 0:
|
| 95 |
self._state.done = True
|
| 96 |
message += " | All errors fixed! Episode complete."
|
| 97 |
elif self._state.step_count >= self._state.max_steps:
|
| 98 |
self._state.done = True
|
| 99 |
message += f" | Max steps reached. {errors_remaining} errors remaining."
|
| 100 |
+
|
| 101 |
return self._make_observation(reward, message)
|
| 102 |
+
|
| 103 |
def state(self) -> DataCleanState:
|
|
|
|
| 104 |
return self._state
|
| 105 |
+
|
| 106 |
def get_task_names(self) -> List[str]:
|
|
|
|
| 107 |
return get_task_names()
|
| 108 |
+
|
| 109 |
def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
|
|
|
|
| 110 |
errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
|
| 111 |
total = self._state.total_errors if self._state.total_errors > 0 else 1
|
| 112 |
progress = (self._state.errors_fixed / total) * 100
|
| 113 |
+
|
|
|
|
| 114 |
unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
|
| 115 |
+
|
| 116 |
return DataCleanObservation(
|
| 117 |
task_name=self._state.task_name,
|
| 118 |
task_description=self._task_info.get("description", ""),
|
env/models.py
CHANGED
|
@@ -1,61 +1,40 @@
|
|
| 1 |
-
"""Pydantic models for the Data Validation Pipeline environment."""
|
| 2 |
-
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
from pydantic import BaseModel, Field
|
| 5 |
|
| 6 |
|
| 7 |
class DataCleanAction(BaseModel):
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
)
|
| 13 |
-
target_field: str = Field(
|
| 14 |
-
default="",
|
| 15 |
-
description="The field/column name to apply the action to"
|
| 16 |
-
)
|
| 17 |
-
target_row: int = Field(
|
| 18 |
-
default=0,
|
| 19 |
-
description="The row index to apply the action to"
|
| 20 |
-
)
|
| 21 |
-
new_value: str = Field(
|
| 22 |
-
default="",
|
| 23 |
-
description="The new/corrected value to set"
|
| 24 |
-
)
|
| 25 |
|
| 26 |
|
| 27 |
class DataCleanObservation(BaseModel):
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
cumulative_reward: float = Field(default=0.0, description="Total reward accumulated")
|
| 43 |
-
done: bool = Field(default=False, description="Whether episode is finished")
|
| 44 |
-
last_action_result: str = Field(default="", description="Result of the last action")
|
| 45 |
-
task_hint: str = Field(default="", description="Hint for solving the task")
|
| 46 |
available_actions: List[str] = Field(
|
| 47 |
default_factory=lambda: [
|
| 48 |
"fix_missing", "fix_type", "fix_range", "fix_format",
|
| 49 |
"fix_duplicate", "validate", "skip"
|
| 50 |
-
]
|
| 51 |
-
description="Available action types"
|
| 52 |
)
|
| 53 |
-
progress_pct: float = Field(default=0.0
|
| 54 |
-
field_names: List[str] = Field(default_factory=list
|
| 55 |
|
| 56 |
|
| 57 |
class DataCleanState(BaseModel):
|
| 58 |
-
"""Full internal state of the environment."""
|
| 59 |
episode_id: str = Field(default="")
|
| 60 |
task_name: str = Field(default="")
|
| 61 |
step_count: int = Field(default=0)
|
|
|
|
|
|
|
|
|
|
| 1 |
from typing import Any, Dict, List, Optional
|
| 2 |
from pydantic import BaseModel, Field
|
| 3 |
|
| 4 |
|
| 5 |
class DataCleanAction(BaseModel):
|
| 6 |
+
action_type: str = Field(...)
|
| 7 |
+
target_field: str = Field(default="")
|
| 8 |
+
target_row: int = Field(default=0)
|
| 9 |
+
new_value: str = Field(default="")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
|
| 12 |
class DataCleanObservation(BaseModel):
|
| 13 |
+
task_name: str = Field(default="")
|
| 14 |
+
task_description: str = Field(default="")
|
| 15 |
+
dataset: List[Dict[str, Any]] = Field(default_factory=list)
|
| 16 |
+
errors_found: List[Dict[str, Any]] = Field(default_factory=list)
|
| 17 |
+
errors_remaining: int = Field(default=0)
|
| 18 |
+
errors_total: int = Field(default=0)
|
| 19 |
+
errors_fixed: int = Field(default=0)
|
| 20 |
+
step_count: int = Field(default=0)
|
| 21 |
+
max_steps: int = Field(default=20)
|
| 22 |
+
reward: float = Field(default=0.0)
|
| 23 |
+
cumulative_reward: float = Field(default=0.0)
|
| 24 |
+
done: bool = Field(default=False)
|
| 25 |
+
last_action_result: str = Field(default="")
|
| 26 |
+
task_hint: str = Field(default="")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
available_actions: List[str] = Field(
|
| 28 |
default_factory=lambda: [
|
| 29 |
"fix_missing", "fix_type", "fix_range", "fix_format",
|
| 30 |
"fix_duplicate", "validate", "skip"
|
| 31 |
+
]
|
|
|
|
| 32 |
)
|
| 33 |
+
progress_pct: float = Field(default=0.0)
|
| 34 |
+
field_names: List[str] = Field(default_factory=list)
|
| 35 |
|
| 36 |
|
| 37 |
class DataCleanState(BaseModel):
|
|
|
|
| 38 |
episode_id: str = Field(default="")
|
| 39 |
task_name: str = Field(default="")
|
| 40 |
step_count: int = Field(default=0)
|
env/tasks.py
CHANGED
|
@@ -1,24 +1,11 @@
|
|
| 1 |
-
"""Task registry and graders for the Data Validation Pipeline environment.
|
| 2 |
-
|
| 3 |
-
Each task provides:
|
| 4 |
-
- A dirty dataset with injected errors
|
| 5 |
-
- A ground truth clean dataset
|
| 6 |
-
- A grader that scores partial progress
|
| 7 |
-
"""
|
| 8 |
-
|
| 9 |
import copy
|
| 10 |
import random
|
| 11 |
from typing import Any, Dict, List, Tuple
|
| 12 |
|
| 13 |
|
| 14 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 15 |
-
# TASK 1 (Easy): Fix Missing Values — solvable in ≤5 steps
|
| 16 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 17 |
-
|
| 18 |
def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
| 19 |
-
"""Generate a small employee dataset with missing values only."""
|
| 20 |
rng = random.Random(seed)
|
| 21 |
-
|
| 22 |
ground_truth = [
|
| 23 |
{"id": 1, "name": "Alice Johnson", "email": "alice@example.com", "age": 30, "department": "Engineering"},
|
| 24 |
{"id": 2, "name": "Bob Smith", "email": "bob@example.com", "age": 25, "department": "Marketing"},
|
|
@@ -26,17 +13,16 @@ def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Di
|
|
| 26 |
{"id": 4, "name": "David Brown", "email": "david@example.com", "age": 28, "department": "Sales"},
|
| 27 |
{"id": 5, "name": "Eve Davis", "email": "eve@example.com", "age": 32, "department": "Marketing"},
|
| 28 |
]
|
| 29 |
-
|
| 30 |
dirty = copy.deepcopy(ground_truth)
|
| 31 |
errors = []
|
| 32 |
-
|
| 33 |
-
# Inject 3 missing value errors
|
| 34 |
missing_configs = [
|
| 35 |
(1, "email", ""),
|
| 36 |
(2, "department", ""),
|
| 37 |
(4, "name", ""),
|
| 38 |
]
|
| 39 |
-
|
| 40 |
for row_idx, field, replacement in missing_configs:
|
| 41 |
dirty[row_idx][field] = replacement
|
| 42 |
errors.append({
|
|
@@ -46,18 +32,13 @@ def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Di
|
|
| 46 |
"current_value": replacement,
|
| 47 |
"description": f"Row {row_idx}: '{field}' is missing/empty"
|
| 48 |
})
|
| 49 |
-
|
| 50 |
-
return dirty, ground_truth, errors
|
| 51 |
|
|
|
|
| 52 |
|
| 53 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 54 |
-
# TASK 2 (Medium): Fix Types & Formats — requires 2-3 stage reasoning
|
| 55 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 56 |
|
| 57 |
def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
| 58 |
-
"""Generate a product dataset with type, format, and missing errors."""
|
| 59 |
rng = random.Random(seed)
|
| 60 |
-
|
| 61 |
ground_truth = [
|
| 62 |
{"id": 1, "product": "Laptop Pro", "price": 999.99, "quantity": 50, "sku": "LP-001", "category": "Electronics"},
|
| 63 |
{"id": 2, "product": "Wireless Mouse", "price": 29.99, "quantity": 200, "sku": "WM-002", "category": "Accessories"},
|
|
@@ -67,11 +48,10 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 67 |
{"id": 6, "product": "Headphones", "price": 149.99, "quantity": 80, "sku": "HP-006", "category": "Audio"},
|
| 68 |
{"id": 7, "product": "Webcam HD", "price": 59.99, "quantity": 120, "sku": "WC-007", "category": "Electronics"},
|
| 69 |
]
|
| 70 |
-
|
| 71 |
dirty = copy.deepcopy(ground_truth)
|
| 72 |
errors = []
|
| 73 |
-
|
| 74 |
-
# Error 1: price stored as string
|
| 75 |
dirty[0]["price"] = "999.99"
|
| 76 |
errors.append({
|
| 77 |
"error_type": "type",
|
|
@@ -81,8 +61,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 81 |
"expected_type": "float",
|
| 82 |
"description": "Row 0: 'price' should be float, got string '999.99'"
|
| 83 |
})
|
| 84 |
-
|
| 85 |
-
# Error 2: quantity stored as string
|
| 86 |
dirty[2]["quantity"] = "five hundred"
|
| 87 |
errors.append({
|
| 88 |
"error_type": "type",
|
|
@@ -92,8 +71,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 92 |
"expected_type": "int",
|
| 93 |
"description": "Row 2: 'quantity' should be int, got string 'five hundred'"
|
| 94 |
})
|
| 95 |
-
|
| 96 |
-
# Error 3: SKU wrong format
|
| 97 |
dirty[3]["sku"] = "mn004"
|
| 98 |
errors.append({
|
| 99 |
"error_type": "format",
|
|
@@ -103,8 +81,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 103 |
"expected_format": "XX-NNN",
|
| 104 |
"description": "Row 3: 'sku' should match format 'XX-NNN', got 'mn004'"
|
| 105 |
})
|
| 106 |
-
|
| 107 |
-
# Error 4: missing category
|
| 108 |
dirty[5]["category"] = ""
|
| 109 |
errors.append({
|
| 110 |
"error_type": "missing",
|
|
@@ -113,8 +90,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 113 |
"current_value": "",
|
| 114 |
"description": "Row 5: 'category' is missing/empty"
|
| 115 |
})
|
| 116 |
-
|
| 117 |
-
# Error 5: negative price (range error)
|
| 118 |
dirty[4]["price"] = -79.99
|
| 119 |
errors.append({
|
| 120 |
"error_type": "range",
|
|
@@ -123,8 +99,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 123 |
"current_value": -79.99,
|
| 124 |
"description": "Row 4: 'price' is negative (-79.99), should be positive"
|
| 125 |
})
|
| 126 |
-
|
| 127 |
-
# Error 6: duplicate SKU
|
| 128 |
dirty[6]["sku"] = "WM-002"
|
| 129 |
errors.append({
|
| 130 |
"error_type": "duplicate",
|
|
@@ -133,18 +108,13 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
|
|
| 133 |
"current_value": "WM-002",
|
| 134 |
"description": "Row 6: 'sku' value 'WM-002' duplicates row 1"
|
| 135 |
})
|
| 136 |
-
|
| 137 |
-
return dirty, ground_truth, errors
|
| 138 |
|
|
|
|
| 139 |
|
| 140 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 141 |
-
# TASK 3 (Hard): Multi-constraint Optimization — requires planning
|
| 142 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 143 |
|
| 144 |
def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
| 145 |
-
"""Generate a complex customer orders dataset with multiple interrelated errors."""
|
| 146 |
rng = random.Random(seed)
|
| 147 |
-
|
| 148 |
ground_truth = [
|
| 149 |
{"id": 1, "customer": "Acme Corp", "email": "orders@acme.com", "amount": 1500.00, "currency": "USD", "status": "completed", "date": "2024-03-15", "region": "North America", "priority": "high"},
|
| 150 |
{"id": 2, "customer": "GlobalTech", "email": "sales@globaltech.io", "amount": 2300.50, "currency": "EUR", "status": "pending", "date": "2024-03-16", "region": "Europe", "priority": "medium"},
|
|
@@ -157,57 +127,43 @@ def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Di
|
|
| 157 |
{"id": 9, "customer": "EcoSmart", "email": "green@ecosmart.co", "amount": 1200.00, "currency": "AUD", "status": "shipped", "date": "2024-03-23", "region": "Asia Pacific", "priority": "medium"},
|
| 158 |
{"id": 10, "customer": "BlueOcean", "email": "info@blueocean.net", "amount": 980.75, "currency": "USD", "status": "pending", "date": "2024-03-24", "region": "North America", "priority": "low"},
|
| 159 |
]
|
| 160 |
-
|
| 161 |
dirty = copy.deepcopy(ground_truth)
|
| 162 |
errors = []
|
| 163 |
-
|
| 164 |
-
# Error 1: missing email
|
| 165 |
dirty[0]["email"] = ""
|
| 166 |
errors.append({"error_type": "missing", "row": 0, "field": "email", "current_value": "", "description": "Row 0: 'email' is missing"})
|
| 167 |
-
|
| 168 |
-
# Error 2: negative amount
|
| 169 |
dirty[1]["amount"] = -2300.50
|
| 170 |
errors.append({"error_type": "range", "row": 1, "field": "amount", "current_value": -2300.50, "description": "Row 1: 'amount' is negative"})
|
| 171 |
-
|
| 172 |
-
# Error 3: invalid date format
|
| 173 |
dirty[2]["date"] = "03/17/2024"
|
| 174 |
errors.append({"error_type": "format", "row": 2, "field": "date", "current_value": "03/17/2024", "expected_format": "YYYY-MM-DD", "description": "Row 2: 'date' wrong format, expected YYYY-MM-DD"})
|
| 175 |
-
|
| 176 |
-
# Error 4: amount as string
|
| 177 |
dirty[3]["amount"] = "4200"
|
| 178 |
errors.append({"error_type": "type", "row": 3, "field": "amount", "current_value": "4200", "expected_type": "float", "description": "Row 3: 'amount' should be float, got string"})
|
| 179 |
-
|
| 180 |
-
# Error 5: invalid status
|
| 181 |
dirty[4]["status"] = "in-progress"
|
| 182 |
errors.append({"error_type": "format", "row": 4, "field": "status", "current_value": "in-progress", "expected_format": "one of: completed, pending, shipped, cancelled", "description": "Row 4: 'status' invalid value 'in-progress'"})
|
| 183 |
-
|
| 184 |
-
# Error 6: missing region
|
| 185 |
dirty[5]["region"] = ""
|
| 186 |
errors.append({"error_type": "missing", "row": 5, "field": "region", "current_value": "", "description": "Row 5: 'region' is missing"})
|
| 187 |
-
|
| 188 |
-
# Error 7: duplicate customer
|
| 189 |
dirty[6]["customer"] = "Acme Corp"
|
| 190 |
dirty[6]["email"] = "orders@acme.com"
|
| 191 |
errors.append({"error_type": "duplicate", "row": 6, "field": "customer", "current_value": "Acme Corp", "description": "Row 6: 'customer' duplicates row 0"})
|
| 192 |
-
|
| 193 |
-
# Error 8: amount out of range (too high)
|
| 194 |
dirty[7]["amount"] = 99999.99
|
| 195 |
errors.append({"error_type": "range", "row": 7, "field": "amount", "current_value": 99999.99, "description": "Row 7: 'amount' exceeds maximum threshold (should be 6750.00)"})
|
| 196 |
-
|
| 197 |
-
# Error 9: invalid currency
|
| 198 |
dirty[8]["currency"] = "AUSD"
|
| 199 |
errors.append({"error_type": "format", "row": 8, "field": "currency", "current_value": "AUSD", "expected_format": "3-letter ISO code", "description": "Row 8: 'currency' invalid code 'AUSD'"})
|
| 200 |
-
|
| 201 |
-
# Error 10: missing priority + wrong type
|
| 202 |
dirty[9]["priority"] = ""
|
| 203 |
errors.append({"error_type": "missing", "row": 9, "field": "priority", "current_value": "", "description": "Row 9: 'priority' is missing"})
|
| 204 |
-
|
| 205 |
-
return dirty, ground_truth, errors
|
| 206 |
|
|
|
|
| 207 |
|
| 208 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 209 |
-
# Task Registry
|
| 210 |
-
# ──────────────────────────────────────────────────────────────────────
|
| 211 |
|
| 212 |
TASK_REGISTRY = {
|
| 213 |
"easy_missing_values": {
|
|
@@ -238,18 +194,16 @@ TASK_REGISTRY = {
|
|
| 238 |
|
| 239 |
|
| 240 |
def get_task_names() -> List[str]:
|
| 241 |
-
"""Return all registered task names."""
|
| 242 |
return list(TASK_REGISTRY.keys())
|
| 243 |
|
| 244 |
|
| 245 |
def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
|
| 246 |
-
"""Generate a task by name."""
|
| 247 |
if task_name not in TASK_REGISTRY:
|
| 248 |
raise ValueError(f"Unknown task: {task_name}. Available: {get_task_names()}")
|
| 249 |
-
|
| 250 |
task_info = TASK_REGISTRY[task_name]
|
| 251 |
dirty, ground_truth, errors = task_info["generator"](seed)
|
| 252 |
-
|
| 253 |
return {
|
| 254 |
"name": task_info["name"],
|
| 255 |
"description": task_info["description"],
|
|
@@ -263,29 +217,18 @@ def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
|
|
| 263 |
}
|
| 264 |
|
| 265 |
|
| 266 |
-
def grade_action(action_type: str, target_field: str, target_row: int,
|
| 267 |
-
new_value: str, dirty_dataset: List[Dict],
|
| 268 |
ground_truth: List[Dict], errors: List[Dict]) -> Tuple[float, str, bool]:
|
| 269 |
-
"""
|
| 270 |
-
Grade a single action. Returns (reward, message, error_fixed).
|
| 271 |
-
|
| 272 |
-
Reward strategy:
|
| 273 |
-
- Correct fix: +1.0 / total_errors (proportional)
|
| 274 |
-
- Wrong fix: -0.05
|
| 275 |
-
- Skip: 0.0
|
| 276 |
-
- Validate (check progress): 0.0
|
| 277 |
-
- Repeated identical action: -0.1
|
| 278 |
-
"""
|
| 279 |
total_errors = len(errors) if errors else 1
|
| 280 |
-
|
| 281 |
if action_type == "validate":
|
| 282 |
fixed = sum(1 for e in errors if e.get("fixed", False))
|
| 283 |
return 0.0, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False
|
| 284 |
-
|
| 285 |
if action_type == "skip":
|
| 286 |
return 0.0, "Skipped current action", False
|
| 287 |
-
|
| 288 |
-
# Find matching error
|
| 289 |
matching_error = None
|
| 290 |
for e in errors:
|
| 291 |
if e.get("fixed", False):
|
|
@@ -293,11 +236,10 @@ def grade_action(action_type: str, target_field: str, target_row: int,
|
|
| 293 |
if e["row"] == target_row and e["field"] == target_field:
|
| 294 |
matching_error = e
|
| 295 |
break
|
| 296 |
-
|
| 297 |
if matching_error is None:
|
| 298 |
return -0.05, f"No unfixed error at row {target_row}, field '{target_field}'", False
|
| 299 |
-
|
| 300 |
-
# Check if the action type matches the error type
|
| 301 |
action_to_error_map = {
|
| 302 |
"fix_missing": "missing",
|
| 303 |
"fix_type": "type",
|
|
@@ -305,15 +247,13 @@ def grade_action(action_type: str, target_field: str, target_row: int,
|
|
| 305 |
"fix_format": "format",
|
| 306 |
"fix_duplicate": "duplicate",
|
| 307 |
}
|
| 308 |
-
|
| 309 |
expected_error_type = action_to_error_map.get(action_type, "")
|
| 310 |
if expected_error_type != matching_error["error_type"]:
|
| 311 |
return -0.05, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False
|
| 312 |
-
|
| 313 |
-
# Check the new value against ground truth
|
| 314 |
gt_value = ground_truth[target_row][target_field]
|
| 315 |
-
|
| 316 |
-
# Flexible value comparison
|
| 317 |
is_correct = False
|
| 318 |
try:
|
| 319 |
if isinstance(gt_value, float):
|
|
@@ -324,18 +264,17 @@ def grade_action(action_type: str, target_field: str, target_row: int,
|
|
| 324 |
is_correct = str(new_value).strip() == str(gt_value).strip()
|
| 325 |
except (ValueError, TypeError):
|
| 326 |
is_correct = str(new_value).strip() == str(gt_value).strip()
|
| 327 |
-
|
| 328 |
if is_correct:
|
| 329 |
matching_error["fixed"] = True
|
| 330 |
-
# Update the dirty dataset
|
| 331 |
if isinstance(gt_value, float):
|
| 332 |
dirty_dataset[target_row][target_field] = float(new_value)
|
| 333 |
elif isinstance(gt_value, int):
|
| 334 |
dirty_dataset[target_row][target_field] = int(float(new_value))
|
| 335 |
else:
|
| 336 |
dirty_dataset[target_row][target_field] = new_value
|
| 337 |
-
|
| 338 |
reward = 1.0 / total_errors
|
| 339 |
-
return reward, f"
|
| 340 |
else:
|
| 341 |
-
return -0.05, f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import copy
|
| 2 |
import random
|
| 3 |
from typing import Any, Dict, List, Tuple
|
| 4 |
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
|
|
|
| 7 |
rng = random.Random(seed)
|
| 8 |
+
|
| 9 |
ground_truth = [
|
| 10 |
{"id": 1, "name": "Alice Johnson", "email": "alice@example.com", "age": 30, "department": "Engineering"},
|
| 11 |
{"id": 2, "name": "Bob Smith", "email": "bob@example.com", "age": 25, "department": "Marketing"},
|
|
|
|
| 13 |
{"id": 4, "name": "David Brown", "email": "david@example.com", "age": 28, "department": "Sales"},
|
| 14 |
{"id": 5, "name": "Eve Davis", "email": "eve@example.com", "age": 32, "department": "Marketing"},
|
| 15 |
]
|
| 16 |
+
|
| 17 |
dirty = copy.deepcopy(ground_truth)
|
| 18 |
errors = []
|
| 19 |
+
|
|
|
|
| 20 |
missing_configs = [
|
| 21 |
(1, "email", ""),
|
| 22 |
(2, "department", ""),
|
| 23 |
(4, "name", ""),
|
| 24 |
]
|
| 25 |
+
|
| 26 |
for row_idx, field, replacement in missing_configs:
|
| 27 |
dirty[row_idx][field] = replacement
|
| 28 |
errors.append({
|
|
|
|
| 32 |
"current_value": replacement,
|
| 33 |
"description": f"Row {row_idx}: '{field}' is missing/empty"
|
| 34 |
})
|
|
|
|
|
|
|
| 35 |
|
| 36 |
+
return dirty, ground_truth, errors
|
| 37 |
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
|
|
|
| 40 |
rng = random.Random(seed)
|
| 41 |
+
|
| 42 |
ground_truth = [
|
| 43 |
{"id": 1, "product": "Laptop Pro", "price": 999.99, "quantity": 50, "sku": "LP-001", "category": "Electronics"},
|
| 44 |
{"id": 2, "product": "Wireless Mouse", "price": 29.99, "quantity": 200, "sku": "WM-002", "category": "Accessories"},
|
|
|
|
| 48 |
{"id": 6, "product": "Headphones", "price": 149.99, "quantity": 80, "sku": "HP-006", "category": "Audio"},
|
| 49 |
{"id": 7, "product": "Webcam HD", "price": 59.99, "quantity": 120, "sku": "WC-007", "category": "Electronics"},
|
| 50 |
]
|
| 51 |
+
|
| 52 |
dirty = copy.deepcopy(ground_truth)
|
| 53 |
errors = []
|
| 54 |
+
|
|
|
|
| 55 |
dirty[0]["price"] = "999.99"
|
| 56 |
errors.append({
|
| 57 |
"error_type": "type",
|
|
|
|
| 61 |
"expected_type": "float",
|
| 62 |
"description": "Row 0: 'price' should be float, got string '999.99'"
|
| 63 |
})
|
| 64 |
+
|
|
|
|
| 65 |
dirty[2]["quantity"] = "five hundred"
|
| 66 |
errors.append({
|
| 67 |
"error_type": "type",
|
|
|
|
| 71 |
"expected_type": "int",
|
| 72 |
"description": "Row 2: 'quantity' should be int, got string 'five hundred'"
|
| 73 |
})
|
| 74 |
+
|
|
|
|
| 75 |
dirty[3]["sku"] = "mn004"
|
| 76 |
errors.append({
|
| 77 |
"error_type": "format",
|
|
|
|
| 81 |
"expected_format": "XX-NNN",
|
| 82 |
"description": "Row 3: 'sku' should match format 'XX-NNN', got 'mn004'"
|
| 83 |
})
|
| 84 |
+
|
|
|
|
| 85 |
dirty[5]["category"] = ""
|
| 86 |
errors.append({
|
| 87 |
"error_type": "missing",
|
|
|
|
| 90 |
"current_value": "",
|
| 91 |
"description": "Row 5: 'category' is missing/empty"
|
| 92 |
})
|
| 93 |
+
|
|
|
|
| 94 |
dirty[4]["price"] = -79.99
|
| 95 |
errors.append({
|
| 96 |
"error_type": "range",
|
|
|
|
| 99 |
"current_value": -79.99,
|
| 100 |
"description": "Row 4: 'price' is negative (-79.99), should be positive"
|
| 101 |
})
|
| 102 |
+
|
|
|
|
| 103 |
dirty[6]["sku"] = "WM-002"
|
| 104 |
errors.append({
|
| 105 |
"error_type": "duplicate",
|
|
|
|
| 108 |
"current_value": "WM-002",
|
| 109 |
"description": "Row 6: 'sku' value 'WM-002' duplicates row 1"
|
| 110 |
})
|
|
|
|
|
|
|
| 111 |
|
| 112 |
+
return dirty, ground_truth, errors
|
| 113 |
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
|
|
|
|
| 116 |
rng = random.Random(seed)
|
| 117 |
+
|
| 118 |
ground_truth = [
|
| 119 |
{"id": 1, "customer": "Acme Corp", "email": "orders@acme.com", "amount": 1500.00, "currency": "USD", "status": "completed", "date": "2024-03-15", "region": "North America", "priority": "high"},
|
| 120 |
{"id": 2, "customer": "GlobalTech", "email": "sales@globaltech.io", "amount": 2300.50, "currency": "EUR", "status": "pending", "date": "2024-03-16", "region": "Europe", "priority": "medium"},
|
|
|
|
| 127 |
{"id": 9, "customer": "EcoSmart", "email": "green@ecosmart.co", "amount": 1200.00, "currency": "AUD", "status": "shipped", "date": "2024-03-23", "region": "Asia Pacific", "priority": "medium"},
|
| 128 |
{"id": 10, "customer": "BlueOcean", "email": "info@blueocean.net", "amount": 980.75, "currency": "USD", "status": "pending", "date": "2024-03-24", "region": "North America", "priority": "low"},
|
| 129 |
]
|
| 130 |
+
|
| 131 |
dirty = copy.deepcopy(ground_truth)
|
| 132 |
errors = []
|
| 133 |
+
|
|
|
|
| 134 |
dirty[0]["email"] = ""
|
| 135 |
errors.append({"error_type": "missing", "row": 0, "field": "email", "current_value": "", "description": "Row 0: 'email' is missing"})
|
| 136 |
+
|
|
|
|
| 137 |
dirty[1]["amount"] = -2300.50
|
| 138 |
errors.append({"error_type": "range", "row": 1, "field": "amount", "current_value": -2300.50, "description": "Row 1: 'amount' is negative"})
|
| 139 |
+
|
|
|
|
| 140 |
dirty[2]["date"] = "03/17/2024"
|
| 141 |
errors.append({"error_type": "format", "row": 2, "field": "date", "current_value": "03/17/2024", "expected_format": "YYYY-MM-DD", "description": "Row 2: 'date' wrong format, expected YYYY-MM-DD"})
|
| 142 |
+
|
|
|
|
| 143 |
dirty[3]["amount"] = "4200"
|
| 144 |
errors.append({"error_type": "type", "row": 3, "field": "amount", "current_value": "4200", "expected_type": "float", "description": "Row 3: 'amount' should be float, got string"})
|
| 145 |
+
|
|
|
|
| 146 |
dirty[4]["status"] = "in-progress"
|
| 147 |
errors.append({"error_type": "format", "row": 4, "field": "status", "current_value": "in-progress", "expected_format": "one of: completed, pending, shipped, cancelled", "description": "Row 4: 'status' invalid value 'in-progress'"})
|
| 148 |
+
|
|
|
|
| 149 |
dirty[5]["region"] = ""
|
| 150 |
errors.append({"error_type": "missing", "row": 5, "field": "region", "current_value": "", "description": "Row 5: 'region' is missing"})
|
| 151 |
+
|
|
|
|
| 152 |
dirty[6]["customer"] = "Acme Corp"
|
| 153 |
dirty[6]["email"] = "orders@acme.com"
|
| 154 |
errors.append({"error_type": "duplicate", "row": 6, "field": "customer", "current_value": "Acme Corp", "description": "Row 6: 'customer' duplicates row 0"})
|
| 155 |
+
|
|
|
|
| 156 |
dirty[7]["amount"] = 99999.99
|
| 157 |
errors.append({"error_type": "range", "row": 7, "field": "amount", "current_value": 99999.99, "description": "Row 7: 'amount' exceeds maximum threshold (should be 6750.00)"})
|
| 158 |
+
|
|
|
|
| 159 |
dirty[8]["currency"] = "AUSD"
|
| 160 |
errors.append({"error_type": "format", "row": 8, "field": "currency", "current_value": "AUSD", "expected_format": "3-letter ISO code", "description": "Row 8: 'currency' invalid code 'AUSD'"})
|
| 161 |
+
|
|
|
|
| 162 |
dirty[9]["priority"] = ""
|
| 163 |
errors.append({"error_type": "missing", "row": 9, "field": "priority", "current_value": "", "description": "Row 9: 'priority' is missing"})
|
|
|
|
|
|
|
| 164 |
|
| 165 |
+
return dirty, ground_truth, errors
|
| 166 |
|
|
|
|
|
|
|
|
|
|
| 167 |
|
| 168 |
TASK_REGISTRY = {
|
| 169 |
"easy_missing_values": {
|
|
|
|
| 194 |
|
| 195 |
|
| 196 |
def get_task_names() -> List[str]:
|
|
|
|
| 197 |
return list(TASK_REGISTRY.keys())
|
| 198 |
|
| 199 |
|
| 200 |
def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
|
|
|
|
| 201 |
if task_name not in TASK_REGISTRY:
|
| 202 |
raise ValueError(f"Unknown task: {task_name}. Available: {get_task_names()}")
|
| 203 |
+
|
| 204 |
task_info = TASK_REGISTRY[task_name]
|
| 205 |
dirty, ground_truth, errors = task_info["generator"](seed)
|
| 206 |
+
|
| 207 |
return {
|
| 208 |
"name": task_info["name"],
|
| 209 |
"description": task_info["description"],
|
|
|
|
| 217 |
}
|
| 218 |
|
| 219 |
|
| 220 |
+
def grade_action(action_type: str, target_field: str, target_row: int,
|
| 221 |
+
new_value: str, dirty_dataset: List[Dict],
|
| 222 |
ground_truth: List[Dict], errors: List[Dict]) -> Tuple[float, str, bool]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
total_errors = len(errors) if errors else 1
|
| 224 |
+
|
| 225 |
if action_type == "validate":
|
| 226 |
fixed = sum(1 for e in errors if e.get("fixed", False))
|
| 227 |
return 0.0, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False
|
| 228 |
+
|
| 229 |
if action_type == "skip":
|
| 230 |
return 0.0, "Skipped current action", False
|
| 231 |
+
|
|
|
|
| 232 |
matching_error = None
|
| 233 |
for e in errors:
|
| 234 |
if e.get("fixed", False):
|
|
|
|
| 236 |
if e["row"] == target_row and e["field"] == target_field:
|
| 237 |
matching_error = e
|
| 238 |
break
|
| 239 |
+
|
| 240 |
if matching_error is None:
|
| 241 |
return -0.05, f"No unfixed error at row {target_row}, field '{target_field}'", False
|
| 242 |
+
|
|
|
|
| 243 |
action_to_error_map = {
|
| 244 |
"fix_missing": "missing",
|
| 245 |
"fix_type": "type",
|
|
|
|
| 247 |
"fix_format": "format",
|
| 248 |
"fix_duplicate": "duplicate",
|
| 249 |
}
|
| 250 |
+
|
| 251 |
expected_error_type = action_to_error_map.get(action_type, "")
|
| 252 |
if expected_error_type != matching_error["error_type"]:
|
| 253 |
return -0.05, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False
|
| 254 |
+
|
|
|
|
| 255 |
gt_value = ground_truth[target_row][target_field]
|
| 256 |
+
|
|
|
|
| 257 |
is_correct = False
|
| 258 |
try:
|
| 259 |
if isinstance(gt_value, float):
|
|
|
|
| 264 |
is_correct = str(new_value).strip() == str(gt_value).strip()
|
| 265 |
except (ValueError, TypeError):
|
| 266 |
is_correct = str(new_value).strip() == str(gt_value).strip()
|
| 267 |
+
|
| 268 |
if is_correct:
|
| 269 |
matching_error["fixed"] = True
|
|
|
|
| 270 |
if isinstance(gt_value, float):
|
| 271 |
dirty_dataset[target_row][target_field] = float(new_value)
|
| 272 |
elif isinstance(gt_value, int):
|
| 273 |
dirty_dataset[target_row][target_field] = int(float(new_value))
|
| 274 |
else:
|
| 275 |
dirty_dataset[target_row][target_field] = new_value
|
| 276 |
+
|
| 277 |
reward = 1.0 / total_errors
|
| 278 |
+
return reward, f"Fixed: row {target_row}, field '{target_field}' -> '{new_value}'", True
|
| 279 |
else:
|
| 280 |
+
return -0.05, f"Wrong value for row {target_row}, field '{target_field}'. Got '{new_value}', expected something else.", False
|
inference.py
CHANGED
|
@@ -1,18 +1,3 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Inference agent for the Data Validation Pipeline environment.
|
| 3 |
-
|
| 4 |
-
Uses OpenAI-compatible API to solve data cleaning tasks.
|
| 5 |
-
Reads environment variables:
|
| 6 |
-
- API_BASE_URL: Base URL for the OpenAI-compatible API (default: https://api.openai.com/v1)
|
| 7 |
-
- MODEL_NAME: Model to use (default: gpt-4.1-mini)
|
| 8 |
-
- HF_TOKEN: HuggingFace token (REQUIRED, no default)
|
| 9 |
-
|
| 10 |
-
Output format strictly follows OpenEnv spec:
|
| 11 |
-
[START] task=<name> env=<benchmark> model=<model_name>
|
| 12 |
-
[STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 13 |
-
[END] success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
import json
|
| 17 |
import os
|
| 18 |
import re
|
|
@@ -21,7 +6,6 @@ import time
|
|
| 21 |
import requests
|
| 22 |
from openai import OpenAI
|
| 23 |
|
| 24 |
-
# Read environment variables with defaults where required
|
| 25 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 26 |
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
|
| 27 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
@@ -29,16 +13,13 @@ HF_TOKEN = os.getenv("HF_TOKEN")
|
|
| 29 |
if HF_TOKEN is None:
|
| 30 |
raise ValueError("HF_TOKEN environment variable is required")
|
| 31 |
|
| 32 |
-
# Initialize OpenAI client
|
| 33 |
client = OpenAI(
|
| 34 |
base_url=API_BASE_URL,
|
| 35 |
api_key=HF_TOKEN,
|
| 36 |
)
|
| 37 |
|
| 38 |
-
# The HF Space URL where the environment is running
|
| 39 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://kush5699-data-validation-env.hf.space")
|
| 40 |
|
| 41 |
-
# All 3 tasks to run sequentially
|
| 42 |
TASKS = [
|
| 43 |
{"task_name": "easy_missing_values", "seed": 42},
|
| 44 |
{"task_name": "medium_mixed_errors", "seed": 42},
|
|
@@ -49,7 +30,6 @@ BENCHMARK_NAME = "data_validation_env"
|
|
| 49 |
|
| 50 |
|
| 51 |
def call_llm(messages: list) -> str:
|
| 52 |
-
"""Call the LLM via OpenAI-compatible API."""
|
| 53 |
try:
|
| 54 |
response = client.chat.completions.create(
|
| 55 |
model=MODEL_NAME,
|
|
@@ -68,7 +48,6 @@ def call_llm(messages: list) -> str:
|
|
| 68 |
|
| 69 |
|
| 70 |
def env_reset(task_name: str, seed: int = 42) -> dict:
|
| 71 |
-
"""Reset the environment."""
|
| 72 |
resp = requests.post(
|
| 73 |
f"{ENV_BASE_URL}/reset",
|
| 74 |
json={"task_name": task_name, "seed": seed},
|
|
@@ -79,7 +58,6 @@ def env_reset(task_name: str, seed: int = 42) -> dict:
|
|
| 79 |
|
| 80 |
|
| 81 |
def env_step(action: dict) -> dict:
|
| 82 |
-
"""Take a step in the environment."""
|
| 83 |
resp = requests.post(
|
| 84 |
f"{ENV_BASE_URL}/step",
|
| 85 |
json=action,
|
|
@@ -90,7 +68,6 @@ def env_step(action: dict) -> dict:
|
|
| 90 |
|
| 91 |
|
| 92 |
def build_system_prompt(obs: dict) -> str:
|
| 93 |
-
"""Build a system prompt for the LLM based on current observation."""
|
| 94 |
return f"""You are a data validation agent. Your task is to fix errors in a dataset.
|
| 95 |
|
| 96 |
TASK: {obs.get('task_description', '')}
|
|
@@ -116,13 +93,9 @@ RULES:
|
|
| 116 |
|
| 117 |
|
| 118 |
def build_user_prompt(obs: dict) -> str:
|
| 119 |
-
"""Build a user prompt showing current state."""
|
| 120 |
errors = obs.get("errors_found", [])
|
| 121 |
dataset = obs.get("dataset", [])
|
| 122 |
-
|
| 123 |
errors_text = json.dumps(errors, indent=2) if errors else "No errors remaining!"
|
| 124 |
-
|
| 125 |
-
# Show a compact view of dataset
|
| 126 |
dataset_compact = []
|
| 127 |
for i, row in enumerate(dataset):
|
| 128 |
dataset_compact.append(f"Row {i}: {json.dumps(row)}")
|
|
@@ -146,8 +119,6 @@ Respond with ONLY a JSON action object to fix the next error."""
|
|
| 146 |
|
| 147 |
|
| 148 |
def parse_llm_response(response: str) -> dict:
|
| 149 |
-
"""Parse the LLM response into a valid action."""
|
| 150 |
-
# Try to extract JSON from the response
|
| 151 |
try:
|
| 152 |
action = json.loads(response)
|
| 153 |
return {
|
|
@@ -159,7 +130,6 @@ def parse_llm_response(response: str) -> dict:
|
|
| 159 |
except json.JSONDecodeError:
|
| 160 |
pass
|
| 161 |
|
| 162 |
-
# Try to find JSON in the response
|
| 163 |
json_match = re.search(r'\{[^}]+\}', response)
|
| 164 |
if json_match:
|
| 165 |
try:
|
|
@@ -173,12 +143,10 @@ def parse_llm_response(response: str) -> dict:
|
|
| 173 |
except (json.JSONDecodeError, ValueError):
|
| 174 |
pass
|
| 175 |
|
| 176 |
-
# Fallback: skip
|
| 177 |
return {"action_type": "skip", "target_field": "", "target_row": 0, "new_value": ""}
|
| 178 |
|
| 179 |
|
| 180 |
def run_episode(task_config: dict) -> None:
|
| 181 |
-
"""Run a single episode for a task."""
|
| 182 |
task_name = task_config["task_name"]
|
| 183 |
seed = task_config.get("seed", 42)
|
| 184 |
rewards = []
|
|
@@ -188,7 +156,6 @@ def run_episode(task_config: dict) -> None:
|
|
| 188 |
print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")
|
| 189 |
|
| 190 |
try:
|
| 191 |
-
# Reset environment
|
| 192 |
obs = env_reset(task_name, seed)
|
| 193 |
max_steps = obs.get("max_steps", 20)
|
| 194 |
|
|
@@ -197,18 +164,14 @@ def run_episode(task_config: dict) -> None:
|
|
| 197 |
]
|
| 198 |
|
| 199 |
while not obs.get("done", False) and steps < max_steps:
|
| 200 |
-
# Build user prompt
|
| 201 |
user_msg = build_user_prompt(obs)
|
| 202 |
messages_for_call = messages + [{"role": "user", "content": user_msg}]
|
| 203 |
|
| 204 |
-
# Get LLM response
|
| 205 |
llm_response = call_llm(messages_for_call)
|
| 206 |
|
| 207 |
-
# Parse into action
|
| 208 |
action = parse_llm_response(llm_response)
|
| 209 |
action_str = json.dumps(action)
|
| 210 |
|
| 211 |
-
# Take step
|
| 212 |
error_msg = None
|
| 213 |
try:
|
| 214 |
obs = env_step(action)
|
|
@@ -227,7 +190,6 @@ def run_episode(task_config: dict) -> None:
|
|
| 227 |
if done:
|
| 228 |
break
|
| 229 |
|
| 230 |
-
# Calculate success based on cumulative reward
|
| 231 |
total_reward = sum(rewards)
|
| 232 |
success = total_reward > 0.5
|
| 233 |
|
|
@@ -243,10 +205,9 @@ def run_episode(task_config: dict) -> None:
|
|
| 243 |
|
| 244 |
|
| 245 |
def main():
|
| 246 |
-
"""Run all 3 tasks sequentially."""
|
| 247 |
for task_config in TASKS:
|
| 248 |
run_episode(task_config)
|
| 249 |
-
time.sleep(1)
|
| 250 |
|
| 251 |
|
| 252 |
if __name__ == "__main__":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import os
|
| 3 |
import re
|
|
|
|
| 6 |
import requests
|
| 7 |
from openai import OpenAI
|
| 8 |
|
|
|
|
| 9 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
|
| 10 |
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
|
| 11 |
HF_TOKEN = os.getenv("HF_TOKEN")
|
|
|
|
| 13 |
if HF_TOKEN is None:
|
| 14 |
raise ValueError("HF_TOKEN environment variable is required")
|
| 15 |
|
|
|
|
| 16 |
client = OpenAI(
|
| 17 |
base_url=API_BASE_URL,
|
| 18 |
api_key=HF_TOKEN,
|
| 19 |
)
|
| 20 |
|
|
|
|
| 21 |
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://kush5699-data-validation-env.hf.space")
|
| 22 |
|
|
|
|
| 23 |
TASKS = [
|
| 24 |
{"task_name": "easy_missing_values", "seed": 42},
|
| 25 |
{"task_name": "medium_mixed_errors", "seed": 42},
|
|
|
|
| 30 |
|
| 31 |
|
| 32 |
def call_llm(messages: list) -> str:
|
|
|
|
| 33 |
try:
|
| 34 |
response = client.chat.completions.create(
|
| 35 |
model=MODEL_NAME,
|
|
|
|
| 48 |
|
| 49 |
|
| 50 |
def env_reset(task_name: str, seed: int = 42) -> dict:
|
|
|
|
| 51 |
resp = requests.post(
|
| 52 |
f"{ENV_BASE_URL}/reset",
|
| 53 |
json={"task_name": task_name, "seed": seed},
|
|
|
|
| 58 |
|
| 59 |
|
| 60 |
def env_step(action: dict) -> dict:
|
|
|
|
| 61 |
resp = requests.post(
|
| 62 |
f"{ENV_BASE_URL}/step",
|
| 63 |
json=action,
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def build_system_prompt(obs: dict) -> str:
|
|
|
|
| 71 |
return f"""You are a data validation agent. Your task is to fix errors in a dataset.
|
| 72 |
|
| 73 |
TASK: {obs.get('task_description', '')}
|
|
|
|
| 93 |
|
| 94 |
|
| 95 |
def build_user_prompt(obs: dict) -> str:
|
|
|
|
| 96 |
errors = obs.get("errors_found", [])
|
| 97 |
dataset = obs.get("dataset", [])
|
|
|
|
| 98 |
errors_text = json.dumps(errors, indent=2) if errors else "No errors remaining!"
|
|
|
|
|
|
|
| 99 |
dataset_compact = []
|
| 100 |
for i, row in enumerate(dataset):
|
| 101 |
dataset_compact.append(f"Row {i}: {json.dumps(row)}")
|
|
|
|
| 119 |
|
| 120 |
|
| 121 |
def parse_llm_response(response: str) -> dict:
|
|
|
|
|
|
|
| 122 |
try:
|
| 123 |
action = json.loads(response)
|
| 124 |
return {
|
|
|
|
| 130 |
except json.JSONDecodeError:
|
| 131 |
pass
|
| 132 |
|
|
|
|
| 133 |
json_match = re.search(r'\{[^}]+\}', response)
|
| 134 |
if json_match:
|
| 135 |
try:
|
|
|
|
| 143 |
except (json.JSONDecodeError, ValueError):
|
| 144 |
pass
|
| 145 |
|
|
|
|
| 146 |
return {"action_type": "skip", "target_field": "", "target_row": 0, "new_value": ""}
|
| 147 |
|
| 148 |
|
| 149 |
def run_episode(task_config: dict) -> None:
|
|
|
|
| 150 |
task_name = task_config["task_name"]
|
| 151 |
seed = task_config.get("seed", 42)
|
| 152 |
rewards = []
|
|
|
|
| 156 |
print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")
|
| 157 |
|
| 158 |
try:
|
|
|
|
| 159 |
obs = env_reset(task_name, seed)
|
| 160 |
max_steps = obs.get("max_steps", 20)
|
| 161 |
|
|
|
|
| 164 |
]
|
| 165 |
|
| 166 |
while not obs.get("done", False) and steps < max_steps:
|
|
|
|
| 167 |
user_msg = build_user_prompt(obs)
|
| 168 |
messages_for_call = messages + [{"role": "user", "content": user_msg}]
|
| 169 |
|
|
|
|
| 170 |
llm_response = call_llm(messages_for_call)
|
| 171 |
|
|
|
|
| 172 |
action = parse_llm_response(llm_response)
|
| 173 |
action_str = json.dumps(action)
|
| 174 |
|
|
|
|
| 175 |
error_msg = None
|
| 176 |
try:
|
| 177 |
obs = env_step(action)
|
|
|
|
| 190 |
if done:
|
| 191 |
break
|
| 192 |
|
|
|
|
| 193 |
total_reward = sum(rewards)
|
| 194 |
success = total_reward > 0.5
|
| 195 |
|
|
|
|
| 205 |
|
| 206 |
|
| 207 |
def main():
|
|
|
|
| 208 |
for task_config in TASKS:
|
| 209 |
run_episode(task_config)
|
| 210 |
+
time.sleep(1)
|
| 211 |
|
| 212 |
|
| 213 |
if __name__ == "__main__":
|
server.py
CHANGED
|
@@ -1,27 +1,19 @@
|
|
| 1 |
-
"""FastAPI server for the Data Validation Pipeline environment.
|
| 2 |
-
|
| 3 |
-
Exposes OpenEnv-compatible HTTP endpoints: /reset, /step, /state, /health
|
| 4 |
-
"""
|
| 5 |
-
|
| 6 |
import json
|
| 7 |
import traceback
|
| 8 |
-
from typing import
|
| 9 |
|
| 10 |
-
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
| 11 |
-
from fastapi.responses import JSONResponse
|
| 12 |
from pydantic import BaseModel
|
| 13 |
|
| 14 |
from env.environment import DataValidationEnvironment
|
| 15 |
-
from env.models import DataCleanAction
|
| 16 |
from env.tasks import get_task_names
|
| 17 |
|
| 18 |
app = FastAPI(
|
| 19 |
title="Data Validation Pipeline - OpenEnv Environment",
|
| 20 |
-
description="An RL environment for training agents to clean and validate structured data",
|
| 21 |
version="1.0.0",
|
| 22 |
)
|
| 23 |
|
| 24 |
-
# Single shared environment instance
|
| 25 |
env = DataValidationEnvironment()
|
| 26 |
|
| 27 |
|
|
@@ -39,13 +31,11 @@ class StepRequest(BaseModel):
|
|
| 39 |
|
| 40 |
@app.get("/health")
|
| 41 |
async def health():
|
| 42 |
-
"""Health check endpoint."""
|
| 43 |
return {"status": "healthy", "service": "data-validation-env"}
|
| 44 |
|
| 45 |
|
| 46 |
@app.post("/reset")
|
| 47 |
async def reset(request: ResetRequest = None):
|
| 48 |
-
"""Reset the environment with a new task."""
|
| 49 |
if request is None:
|
| 50 |
request = ResetRequest()
|
| 51 |
try:
|
|
@@ -57,7 +47,6 @@ async def reset(request: ResetRequest = None):
|
|
| 57 |
|
| 58 |
@app.post("/step")
|
| 59 |
async def step(request: StepRequest):
|
| 60 |
-
"""Execute an action in the environment."""
|
| 61 |
try:
|
| 62 |
action = DataCleanAction(
|
| 63 |
action_type=request.action_type,
|
|
@@ -73,7 +62,6 @@ async def step(request: StepRequest):
|
|
| 73 |
|
| 74 |
@app.get("/state")
|
| 75 |
async def state():
|
| 76 |
-
"""Get the current environment state."""
|
| 77 |
try:
|
| 78 |
s = env.state()
|
| 79 |
return s.model_dump()
|
|
@@ -83,25 +71,22 @@ async def state():
|
|
| 83 |
|
| 84 |
@app.get("/tasks")
|
| 85 |
async def tasks():
|
| 86 |
-
"""List available tasks."""
|
| 87 |
return {"tasks": get_task_names()}
|
| 88 |
|
| 89 |
|
| 90 |
-
# WebSocket support for OpenEnv clients
|
| 91 |
@app.websocket("/ws")
|
| 92 |
async def websocket_endpoint(websocket: WebSocket):
|
| 93 |
-
"""WebSocket endpoint for persistent sessions."""
|
| 94 |
await websocket.accept()
|
| 95 |
ws_env = DataValidationEnvironment()
|
| 96 |
-
|
| 97 |
try:
|
| 98 |
while True:
|
| 99 |
data = await websocket.receive_text()
|
| 100 |
msg = json.loads(data)
|
| 101 |
-
|
| 102 |
method = msg.get("method", "")
|
| 103 |
params = msg.get("params", {})
|
| 104 |
-
|
| 105 |
try:
|
| 106 |
if method == "reset":
|
| 107 |
obs = ws_env.reset(
|
|
@@ -131,7 +116,7 @@ async def websocket_endpoint(websocket: WebSocket):
|
|
| 131 |
}
|
| 132 |
else:
|
| 133 |
response = {"error": f"Unknown method: {method}"}
|
| 134 |
-
|
| 135 |
await websocket.send_text(json.dumps(response))
|
| 136 |
except Exception as e:
|
| 137 |
await websocket.send_text(json.dumps({
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import json
|
| 2 |
import traceback
|
| 3 |
+
from typing import Optional
|
| 4 |
|
| 5 |
+
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|
|
|
| 6 |
from pydantic import BaseModel
|
| 7 |
|
| 8 |
from env.environment import DataValidationEnvironment
|
| 9 |
+
from env.models import DataCleanAction
|
| 10 |
from env.tasks import get_task_names
|
| 11 |
|
| 12 |
app = FastAPI(
|
| 13 |
title="Data Validation Pipeline - OpenEnv Environment",
|
|
|
|
| 14 |
version="1.0.0",
|
| 15 |
)
|
| 16 |
|
|
|
|
| 17 |
env = DataValidationEnvironment()
|
| 18 |
|
| 19 |
|
|
|
|
| 31 |
|
| 32 |
@app.get("/health")
|
| 33 |
async def health():
|
|
|
|
| 34 |
return {"status": "healthy", "service": "data-validation-env"}
|
| 35 |
|
| 36 |
|
| 37 |
@app.post("/reset")
|
| 38 |
async def reset(request: ResetRequest = None):
|
|
|
|
| 39 |
if request is None:
|
| 40 |
request = ResetRequest()
|
| 41 |
try:
|
|
|
|
| 47 |
|
| 48 |
@app.post("/step")
|
| 49 |
async def step(request: StepRequest):
|
|
|
|
| 50 |
try:
|
| 51 |
action = DataCleanAction(
|
| 52 |
action_type=request.action_type,
|
|
|
|
| 62 |
|
| 63 |
@app.get("/state")
|
| 64 |
async def state():
|
|
|
|
| 65 |
try:
|
| 66 |
s = env.state()
|
| 67 |
return s.model_dump()
|
|
|
|
| 71 |
|
| 72 |
@app.get("/tasks")
|
| 73 |
async def tasks():
|
|
|
|
| 74 |
return {"tasks": get_task_names()}
|
| 75 |
|
| 76 |
|
|
|
|
| 77 |
@app.websocket("/ws")
|
| 78 |
async def websocket_endpoint(websocket: WebSocket):
|
|
|
|
| 79 |
await websocket.accept()
|
| 80 |
ws_env = DataValidationEnvironment()
|
| 81 |
+
|
| 82 |
try:
|
| 83 |
while True:
|
| 84 |
data = await websocket.receive_text()
|
| 85 |
msg = json.loads(data)
|
| 86 |
+
|
| 87 |
method = msg.get("method", "")
|
| 88 |
params = msg.get("params", {})
|
| 89 |
+
|
| 90 |
try:
|
| 91 |
if method == "reset":
|
| 92 |
obs = ws_env.reset(
|
|
|
|
| 116 |
}
|
| 117 |
else:
|
| 118 |
response = {"error": f"Unknown method: {method}"}
|
| 119 |
+
|
| 120 |
await websocket.send_text(json.dumps(response))
|
| 121 |
except Exception as e:
|
| 122 |
await websocket.send_text(json.dumps({
|