kush5699 committed on
Commit
42757ca
·
verified ·
1 Parent(s): 0e5c7a8

Upload folder using huggingface_hub

Browse files
Files changed (8) hide show
  1. .gitignore +12 -0
  2. Dockerfile +1 -9
  3. env/__init__.py +0 -1
  4. env/environment.py +19 -55
  5. env/models.py +21 -42
  6. env/tasks.py +43 -104
  7. inference.py +1 -40
  8. server.py +7 -22
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .env
5
+ *.egg-info/
6
+ dist/
7
+ build/
8
+ .vscode/
9
+ .idea/
10
+ test_space.py
11
+ test_all_tasks.py
12
+ *.pdf
Dockerfile CHANGED
@@ -2,24 +2,16 @@ FROM python:3.11-slim
2
 
3
  WORKDIR /app
4
 
5
- # Install system dependencies
6
- RUN apt-get update && apt-get install -y --no-install-recommends \
7
- curl \
8
- && rm -rf /var/lib/apt/lists/*
9
 
10
- # Copy requirements first for caching
11
  COPY requirements.txt .
12
  RUN pip install --no-cache-dir -r requirements.txt
13
 
14
- # Copy application code
15
  COPY . .
16
 
17
- # Expose port
18
  EXPOSE 8000
19
 
20
- # Health check
21
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
22
  CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
23
 
24
- # Run the server
25
  CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
 
2
 
3
  WORKDIR /app
4
 
5
+ RUN apt-get update && apt-get install -y --no-install-recommends curl && rm -rf /var/lib/apt/lists/*
 
 
 
6
 
 
7
  COPY requirements.txt .
8
  RUN pip install --no-cache-dir -r requirements.txt
9
 
 
10
  COPY . .
11
 
 
12
  EXPOSE 8000
13
 
 
14
  HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
15
  CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
16
 
 
17
  CMD ["uvicorn", "server:app", "--host", "0.0.0.0", "--port", "8000"]
env/__init__.py CHANGED
@@ -1,2 +1 @@
1
- """Data Validation Pipeline - OpenEnv Environment."""
2
  from env.models import DataCleanAction, DataCleanObservation, DataCleanState
 
 
1
  from env.models import DataCleanAction, DataCleanObservation, DataCleanState
env/environment.py CHANGED
@@ -1,5 +1,3 @@
1
- """Core Environment implementation for the Data Validation Pipeline."""
2
-
3
  import uuid
4
  from typing import Any, Dict, List, Optional
5
 
@@ -8,44 +6,25 @@ from env.tasks import generate_task, get_task_names, grade_action
8
 
9
 
10
  class DataValidationEnvironment:
11
- """
12
- Data Validation Pipeline Environment.
13
-
14
- An RL environment where the agent must clean and validate structured datasets
15
- by identifying and fixing errors (missing values, type mismatches, format violations,
16
- range errors, and duplicates).
17
-
18
- Follows OpenEnv Environment interface: reset(), step(), state().
19
- """
20
-
21
  def __init__(self):
22
  self._state = DataCleanState()
23
  self._ground_truth: List[Dict[str, Any]] = []
24
  self._errors: List[Dict[str, Any]] = []
25
  self._task_info: Dict[str, Any] = {}
26
  self._field_names: List[str] = []
27
-
28
  def reset(self, task_name: Optional[str] = None, seed: int = 42, **kwargs) -> DataCleanObservation:
29
- """
30
- Reset the environment with a new task.
31
-
32
- Args:
33
- task_name: Task to load ('easy_missing_values', 'medium_mixed_errors', 'hard_multi_constraint')
34
- seed: Random seed for reproducibility
35
-
36
- Returns:
37
- Initial observation
38
- """
39
  if task_name is None:
40
  task_name = "easy_missing_values"
41
-
42
  task = generate_task(task_name, seed)
43
-
44
  self._ground_truth = task["ground_truth"]
45
  self._errors = task["errors"]
46
  self._task_info = task
47
  self._field_names = task["field_names"]
48
-
49
  self._state = DataCleanState(
50
  episode_id=str(uuid.uuid4()),
51
  task_name=task_name,
@@ -61,7 +40,7 @@ class DataValidationEnvironment:
61
  total_errors=len(self._errors),
62
  last_actions=[],
63
  )
64
-
65
  return DataCleanObservation(
66
  task_name=task_name,
67
  task_description=task["description"],
@@ -80,27 +59,17 @@ class DataValidationEnvironment:
80
  progress_pct=0.0,
81
  field_names=self._field_names,
82
  )
83
-
84
  def step(self, action: DataCleanAction) -> DataCleanObservation:
85
- """
86
- Execute an action to fix a data error.
87
-
88
- Args:
89
- action: The action to take
90
-
91
- Returns:
92
- Updated observation with reward
93
- """
94
  if self._state.done:
95
  return self._make_observation(0.0, "Episode already done. Call reset().")
96
-
97
  self._state.step_count += 1
98
-
99
- # Check for repeated identical action
100
  action_key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
101
  is_repeat = action_key in self._state.last_actions
102
  self._state.last_actions.append(action_key)
103
-
104
  if is_repeat:
105
  reward = -0.1
106
  message = "Penalty: repeated identical action"
@@ -116,39 +85,34 @@ class DataValidationEnvironment:
116
  )
117
  if fixed:
118
  self._state.errors_fixed += 1
119
-
120
  self._state.cumulative_reward += reward
121
  self._state.reward_history.append(reward)
122
-
123
- # Check termination conditions
124
  errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
125
-
126
  if errors_remaining == 0:
127
  self._state.done = True
128
  message += " | All errors fixed! Episode complete."
129
  elif self._state.step_count >= self._state.max_steps:
130
  self._state.done = True
131
  message += f" | Max steps reached. {errors_remaining} errors remaining."
132
-
133
  return self._make_observation(reward, message)
134
-
135
  def state(self) -> DataCleanState:
136
- """Return the current environment state."""
137
  return self._state
138
-
139
  def get_task_names(self) -> List[str]:
140
- """Return available task names."""
141
  return get_task_names()
142
-
143
  def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
144
- """Create an observation from current state."""
145
  errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
146
  total = self._state.total_errors if self._state.total_errors > 0 else 1
147
  progress = (self._state.errors_fixed / total) * 100
148
-
149
- # Only show unfixed errors
150
  unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
151
-
152
  return DataCleanObservation(
153
  task_name=self._state.task_name,
154
  task_description=self._task_info.get("description", ""),
 
 
 
1
  import uuid
2
  from typing import Any, Dict, List, Optional
3
 
 
6
 
7
 
8
  class DataValidationEnvironment:
9
+
 
 
 
 
 
 
 
 
 
10
  def __init__(self):
11
  self._state = DataCleanState()
12
  self._ground_truth: List[Dict[str, Any]] = []
13
  self._errors: List[Dict[str, Any]] = []
14
  self._task_info: Dict[str, Any] = {}
15
  self._field_names: List[str] = []
16
+
17
  def reset(self, task_name: Optional[str] = None, seed: int = 42, **kwargs) -> DataCleanObservation:
 
 
 
 
 
 
 
 
 
 
18
  if task_name is None:
19
  task_name = "easy_missing_values"
20
+
21
  task = generate_task(task_name, seed)
22
+
23
  self._ground_truth = task["ground_truth"]
24
  self._errors = task["errors"]
25
  self._task_info = task
26
  self._field_names = task["field_names"]
27
+
28
  self._state = DataCleanState(
29
  episode_id=str(uuid.uuid4()),
30
  task_name=task_name,
 
40
  total_errors=len(self._errors),
41
  last_actions=[],
42
  )
43
+
44
  return DataCleanObservation(
45
  task_name=task_name,
46
  task_description=task["description"],
 
59
  progress_pct=0.0,
60
  field_names=self._field_names,
61
  )
62
+
63
  def step(self, action: DataCleanAction) -> DataCleanObservation:
 
 
 
 
 
 
 
 
 
64
  if self._state.done:
65
  return self._make_observation(0.0, "Episode already done. Call reset().")
66
+
67
  self._state.step_count += 1
68
+
 
69
  action_key = f"{action.action_type}:{action.target_field}:{action.target_row}:{action.new_value}"
70
  is_repeat = action_key in self._state.last_actions
71
  self._state.last_actions.append(action_key)
72
+
73
  if is_repeat:
74
  reward = -0.1
75
  message = "Penalty: repeated identical action"
 
85
  )
86
  if fixed:
87
  self._state.errors_fixed += 1
88
+
89
  self._state.cumulative_reward += reward
90
  self._state.reward_history.append(reward)
91
+
 
92
  errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
93
+
94
  if errors_remaining == 0:
95
  self._state.done = True
96
  message += " | All errors fixed! Episode complete."
97
  elif self._state.step_count >= self._state.max_steps:
98
  self._state.done = True
99
  message += f" | Max steps reached. {errors_remaining} errors remaining."
100
+
101
  return self._make_observation(reward, message)
102
+
103
  def state(self) -> DataCleanState:
 
104
  return self._state
105
+
106
  def get_task_names(self) -> List[str]:
 
107
  return get_task_names()
108
+
109
  def _make_observation(self, reward: float, message: str) -> DataCleanObservation:
 
110
  errors_remaining = sum(1 for e in self._errors if not e.get("fixed", False))
111
  total = self._state.total_errors if self._state.total_errors > 0 else 1
112
  progress = (self._state.errors_fixed / total) * 100
113
+
 
114
  unfixed_errors = [e for e in self._errors if not e.get("fixed", False)]
115
+
116
  return DataCleanObservation(
117
  task_name=self._state.task_name,
118
  task_description=self._task_info.get("description", ""),
env/models.py CHANGED
@@ -1,61 +1,40 @@
1
- """Pydantic models for the Data Validation Pipeline environment."""
2
-
3
  from typing import Any, Dict, List, Optional
4
  from pydantic import BaseModel, Field
5
 
6
 
7
  class DataCleanAction(BaseModel):
8
- """Action that the agent can take to clean/fix data."""
9
- action_type: str = Field(
10
- ...,
11
- description="Type of action: fix_missing, fix_type, fix_range, fix_format, fix_duplicate, validate, skip"
12
- )
13
- target_field: str = Field(
14
- default="",
15
- description="The field/column name to apply the action to"
16
- )
17
- target_row: int = Field(
18
- default=0,
19
- description="The row index to apply the action to"
20
- )
21
- new_value: str = Field(
22
- default="",
23
- description="The new/corrected value to set"
24
- )
25
 
26
 
27
  class DataCleanObservation(BaseModel):
28
- """Observation returned by the environment."""
29
- task_name: str = Field(default="", description="Name of the current task")
30
- task_description: str = Field(default="", description="Description of what to do")
31
- dataset: List[Dict[str, Any]] = Field(default_factory=list, description="Current state of the dataset")
32
- errors_found: List[Dict[str, Any]] = Field(
33
- default_factory=list,
34
- description="List of errors detected in the dataset"
35
- )
36
- errors_remaining: int = Field(default=0, description="Number of errors left to fix")
37
- errors_total: int = Field(default=0, description="Total errors at start")
38
- errors_fixed: int = Field(default=0, description="Number of errors successfully fixed")
39
- step_count: int = Field(default=0, description="Current step number")
40
- max_steps: int = Field(default=20, description="Max steps allowed")
41
- reward: float = Field(default=0.0, description="Reward from last action")
42
- cumulative_reward: float = Field(default=0.0, description="Total reward accumulated")
43
- done: bool = Field(default=False, description="Whether episode is finished")
44
- last_action_result: str = Field(default="", description="Result of the last action")
45
- task_hint: str = Field(default="", description="Hint for solving the task")
46
  available_actions: List[str] = Field(
47
  default_factory=lambda: [
48
  "fix_missing", "fix_type", "fix_range", "fix_format",
49
  "fix_duplicate", "validate", "skip"
50
- ],
51
- description="Available action types"
52
  )
53
- progress_pct: float = Field(default=0.0, description="Progress percentage (0-100)")
54
- field_names: List[str] = Field(default_factory=list, description="Column names in the dataset")
55
 
56
 
57
  class DataCleanState(BaseModel):
58
- """Full internal state of the environment."""
59
  episode_id: str = Field(default="")
60
  task_name: str = Field(default="")
61
  step_count: int = Field(default=0)
 
 
 
1
  from typing import Any, Dict, List, Optional
2
  from pydantic import BaseModel, Field
3
 
4
 
5
  class DataCleanAction(BaseModel):
6
+ action_type: str = Field(...)
7
+ target_field: str = Field(default="")
8
+ target_row: int = Field(default=0)
9
+ new_value: str = Field(default="")
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  class DataCleanObservation(BaseModel):
13
+ task_name: str = Field(default="")
14
+ task_description: str = Field(default="")
15
+ dataset: List[Dict[str, Any]] = Field(default_factory=list)
16
+ errors_found: List[Dict[str, Any]] = Field(default_factory=list)
17
+ errors_remaining: int = Field(default=0)
18
+ errors_total: int = Field(default=0)
19
+ errors_fixed: int = Field(default=0)
20
+ step_count: int = Field(default=0)
21
+ max_steps: int = Field(default=20)
22
+ reward: float = Field(default=0.0)
23
+ cumulative_reward: float = Field(default=0.0)
24
+ done: bool = Field(default=False)
25
+ last_action_result: str = Field(default="")
26
+ task_hint: str = Field(default="")
 
 
 
 
27
  available_actions: List[str] = Field(
28
  default_factory=lambda: [
29
  "fix_missing", "fix_type", "fix_range", "fix_format",
30
  "fix_duplicate", "validate", "skip"
31
+ ]
 
32
  )
33
+ progress_pct: float = Field(default=0.0)
34
+ field_names: List[str] = Field(default_factory=list)
35
 
36
 
37
  class DataCleanState(BaseModel):
 
38
  episode_id: str = Field(default="")
39
  task_name: str = Field(default="")
40
  step_count: int = Field(default=0)
env/tasks.py CHANGED
@@ -1,24 +1,11 @@
1
- """Task registry and graders for the Data Validation Pipeline environment.
2
-
3
- Each task provides:
4
- - A dirty dataset with injected errors
5
- - A ground truth clean dataset
6
- - A grader that scores partial progress
7
- """
8
-
9
  import copy
10
  import random
11
  from typing import Any, Dict, List, Tuple
12
 
13
 
14
- # ──────────────────────────────────────────────────────────────────────
15
- # TASK 1 (Easy): Fix Missing Values — solvable in ≤5 steps
16
- # ──────────────────────────────────────────────────────────────────────
17
-
18
  def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
19
- """Generate a small employee dataset with missing values only."""
20
  rng = random.Random(seed)
21
-
22
  ground_truth = [
23
  {"id": 1, "name": "Alice Johnson", "email": "alice@example.com", "age": 30, "department": "Engineering"},
24
  {"id": 2, "name": "Bob Smith", "email": "bob@example.com", "age": 25, "department": "Marketing"},
@@ -26,17 +13,16 @@ def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Di
26
  {"id": 4, "name": "David Brown", "email": "david@example.com", "age": 28, "department": "Sales"},
27
  {"id": 5, "name": "Eve Davis", "email": "eve@example.com", "age": 32, "department": "Marketing"},
28
  ]
29
-
30
  dirty = copy.deepcopy(ground_truth)
31
  errors = []
32
-
33
- # Inject 3 missing value errors
34
  missing_configs = [
35
  (1, "email", ""),
36
  (2, "department", ""),
37
  (4, "name", ""),
38
  ]
39
-
40
  for row_idx, field, replacement in missing_configs:
41
  dirty[row_idx][field] = replacement
42
  errors.append({
@@ -46,18 +32,13 @@ def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Di
46
  "current_value": replacement,
47
  "description": f"Row {row_idx}: '{field}' is missing/empty"
48
  })
49
-
50
- return dirty, ground_truth, errors
51
 
 
52
 
53
- # ──────────────────────────────────────────────────────────────────────
54
- # TASK 2 (Medium): Fix Types & Formats — requires 2-3 stage reasoning
55
- # ──────────────────────────────────────────────────────────────────────
56
 
57
  def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
58
- """Generate a product dataset with type, format, and missing errors."""
59
  rng = random.Random(seed)
60
-
61
  ground_truth = [
62
  {"id": 1, "product": "Laptop Pro", "price": 999.99, "quantity": 50, "sku": "LP-001", "category": "Electronics"},
63
  {"id": 2, "product": "Wireless Mouse", "price": 29.99, "quantity": 200, "sku": "WM-002", "category": "Accessories"},
@@ -67,11 +48,10 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
67
  {"id": 6, "product": "Headphones", "price": 149.99, "quantity": 80, "sku": "HP-006", "category": "Audio"},
68
  {"id": 7, "product": "Webcam HD", "price": 59.99, "quantity": 120, "sku": "WC-007", "category": "Electronics"},
69
  ]
70
-
71
  dirty = copy.deepcopy(ground_truth)
72
  errors = []
73
-
74
- # Error 1: price stored as string
75
  dirty[0]["price"] = "999.99"
76
  errors.append({
77
  "error_type": "type",
@@ -81,8 +61,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
81
  "expected_type": "float",
82
  "description": "Row 0: 'price' should be float, got string '999.99'"
83
  })
84
-
85
- # Error 2: quantity stored as string
86
  dirty[2]["quantity"] = "five hundred"
87
  errors.append({
88
  "error_type": "type",
@@ -92,8 +71,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
92
  "expected_type": "int",
93
  "description": "Row 2: 'quantity' should be int, got string 'five hundred'"
94
  })
95
-
96
- # Error 3: SKU wrong format
97
  dirty[3]["sku"] = "mn004"
98
  errors.append({
99
  "error_type": "format",
@@ -103,8 +81,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
103
  "expected_format": "XX-NNN",
104
  "description": "Row 3: 'sku' should match format 'XX-NNN', got 'mn004'"
105
  })
106
-
107
- # Error 4: missing category
108
  dirty[5]["category"] = ""
109
  errors.append({
110
  "error_type": "missing",
@@ -113,8 +90,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
113
  "current_value": "",
114
  "description": "Row 5: 'category' is missing/empty"
115
  })
116
-
117
- # Error 5: negative price (range error)
118
  dirty[4]["price"] = -79.99
119
  errors.append({
120
  "error_type": "range",
@@ -123,8 +99,7 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
123
  "current_value": -79.99,
124
  "description": "Row 4: 'price' is negative (-79.99), should be positive"
125
  })
126
-
127
- # Error 6: duplicate SKU
128
  dirty[6]["sku"] = "WM-002"
129
  errors.append({
130
  "error_type": "duplicate",
@@ -133,18 +108,13 @@ def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[
133
  "current_value": "WM-002",
134
  "description": "Row 6: 'sku' value 'WM-002' duplicates row 1"
135
  })
136
-
137
- return dirty, ground_truth, errors
138
 
 
139
 
140
- # ──────────────────────────────────────────────────────────────────────
141
- # TASK 3 (Hard): Multi-constraint Optimization — requires planning
142
- # ──────────────────────────────────────────────────────────────────────
143
 
144
  def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
145
- """Generate a complex customer orders dataset with multiple interrelated errors."""
146
  rng = random.Random(seed)
147
-
148
  ground_truth = [
149
  {"id": 1, "customer": "Acme Corp", "email": "orders@acme.com", "amount": 1500.00, "currency": "USD", "status": "completed", "date": "2024-03-15", "region": "North America", "priority": "high"},
150
  {"id": 2, "customer": "GlobalTech", "email": "sales@globaltech.io", "amount": 2300.50, "currency": "EUR", "status": "pending", "date": "2024-03-16", "region": "Europe", "priority": "medium"},
@@ -157,57 +127,43 @@ def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Di
157
  {"id": 9, "customer": "EcoSmart", "email": "green@ecosmart.co", "amount": 1200.00, "currency": "AUD", "status": "shipped", "date": "2024-03-23", "region": "Asia Pacific", "priority": "medium"},
158
  {"id": 10, "customer": "BlueOcean", "email": "info@blueocean.net", "amount": 980.75, "currency": "USD", "status": "pending", "date": "2024-03-24", "region": "North America", "priority": "low"},
159
  ]
160
-
161
  dirty = copy.deepcopy(ground_truth)
162
  errors = []
163
-
164
- # Error 1: missing email
165
  dirty[0]["email"] = ""
166
  errors.append({"error_type": "missing", "row": 0, "field": "email", "current_value": "", "description": "Row 0: 'email' is missing"})
167
-
168
- # Error 2: negative amount
169
  dirty[1]["amount"] = -2300.50
170
  errors.append({"error_type": "range", "row": 1, "field": "amount", "current_value": -2300.50, "description": "Row 1: 'amount' is negative"})
171
-
172
- # Error 3: invalid date format
173
  dirty[2]["date"] = "03/17/2024"
174
  errors.append({"error_type": "format", "row": 2, "field": "date", "current_value": "03/17/2024", "expected_format": "YYYY-MM-DD", "description": "Row 2: 'date' wrong format, expected YYYY-MM-DD"})
175
-
176
- # Error 4: amount as string
177
  dirty[3]["amount"] = "4200"
178
  errors.append({"error_type": "type", "row": 3, "field": "amount", "current_value": "4200", "expected_type": "float", "description": "Row 3: 'amount' should be float, got string"})
179
-
180
- # Error 5: invalid status
181
  dirty[4]["status"] = "in-progress"
182
  errors.append({"error_type": "format", "row": 4, "field": "status", "current_value": "in-progress", "expected_format": "one of: completed, pending, shipped, cancelled", "description": "Row 4: 'status' invalid value 'in-progress'"})
183
-
184
- # Error 6: missing region
185
  dirty[5]["region"] = ""
186
  errors.append({"error_type": "missing", "row": 5, "field": "region", "current_value": "", "description": "Row 5: 'region' is missing"})
187
-
188
- # Error 7: duplicate customer
189
  dirty[6]["customer"] = "Acme Corp"
190
  dirty[6]["email"] = "orders@acme.com"
191
  errors.append({"error_type": "duplicate", "row": 6, "field": "customer", "current_value": "Acme Corp", "description": "Row 6: 'customer' duplicates row 0"})
192
-
193
- # Error 8: amount out of range (too high)
194
  dirty[7]["amount"] = 99999.99
195
  errors.append({"error_type": "range", "row": 7, "field": "amount", "current_value": 99999.99, "description": "Row 7: 'amount' exceeds maximum threshold (should be 6750.00)"})
196
-
197
- # Error 9: invalid currency
198
  dirty[8]["currency"] = "AUSD"
199
  errors.append({"error_type": "format", "row": 8, "field": "currency", "current_value": "AUSD", "expected_format": "3-letter ISO code", "description": "Row 8: 'currency' invalid code 'AUSD'"})
200
-
201
- # Error 10: missing priority + wrong type
202
  dirty[9]["priority"] = ""
203
  errors.append({"error_type": "missing", "row": 9, "field": "priority", "current_value": "", "description": "Row 9: 'priority' is missing"})
204
-
205
- return dirty, ground_truth, errors
206
 
 
207
 
208
- # ──────────────────────────────────────────────────────────────────────
209
- # Task Registry
210
- # ──────────────────────────────────────────────────────────────────────
211
 
212
  TASK_REGISTRY = {
213
  "easy_missing_values": {
@@ -238,18 +194,16 @@ TASK_REGISTRY = {
238
 
239
 
240
  def get_task_names() -> List[str]:
241
- """Return all registered task names."""
242
  return list(TASK_REGISTRY.keys())
243
 
244
 
245
  def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
246
- """Generate a task by name."""
247
  if task_name not in TASK_REGISTRY:
248
  raise ValueError(f"Unknown task: {task_name}. Available: {get_task_names()}")
249
-
250
  task_info = TASK_REGISTRY[task_name]
251
  dirty, ground_truth, errors = task_info["generator"](seed)
252
-
253
  return {
254
  "name": task_info["name"],
255
  "description": task_info["description"],
@@ -263,29 +217,18 @@ def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
263
  }
264
 
265
 
266
- def grade_action(action_type: str, target_field: str, target_row: int,
267
- new_value: str, dirty_dataset: List[Dict],
268
  ground_truth: List[Dict], errors: List[Dict]) -> Tuple[float, str, bool]:
269
- """
270
- Grade a single action. Returns (reward, message, error_fixed).
271
-
272
- Reward strategy:
273
- - Correct fix: +1.0 / total_errors (proportional)
274
- - Wrong fix: -0.05
275
- - Skip: 0.0
276
- - Validate (check progress): 0.0
277
- - Repeated identical action: -0.1
278
- """
279
  total_errors = len(errors) if errors else 1
280
-
281
  if action_type == "validate":
282
  fixed = sum(1 for e in errors if e.get("fixed", False))
283
  return 0.0, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False
284
-
285
  if action_type == "skip":
286
  return 0.0, "Skipped current action", False
287
-
288
- # Find matching error
289
  matching_error = None
290
  for e in errors:
291
  if e.get("fixed", False):
@@ -293,11 +236,10 @@ def grade_action(action_type: str, target_field: str, target_row: int,
293
  if e["row"] == target_row and e["field"] == target_field:
294
  matching_error = e
295
  break
296
-
297
  if matching_error is None:
298
  return -0.05, f"No unfixed error at row {target_row}, field '{target_field}'", False
299
-
300
- # Check if the action type matches the error type
301
  action_to_error_map = {
302
  "fix_missing": "missing",
303
  "fix_type": "type",
@@ -305,15 +247,13 @@ def grade_action(action_type: str, target_field: str, target_row: int,
305
  "fix_format": "format",
306
  "fix_duplicate": "duplicate",
307
  }
308
-
309
  expected_error_type = action_to_error_map.get(action_type, "")
310
  if expected_error_type != matching_error["error_type"]:
311
  return -0.05, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False
312
-
313
- # Check the new value against ground truth
314
  gt_value = ground_truth[target_row][target_field]
315
-
316
- # Flexible value comparison
317
  is_correct = False
318
  try:
319
  if isinstance(gt_value, float):
@@ -324,18 +264,17 @@ def grade_action(action_type: str, target_field: str, target_row: int,
324
  is_correct = str(new_value).strip() == str(gt_value).strip()
325
  except (ValueError, TypeError):
326
  is_correct = str(new_value).strip() == str(gt_value).strip()
327
-
328
  if is_correct:
329
  matching_error["fixed"] = True
330
- # Update the dirty dataset
331
  if isinstance(gt_value, float):
332
  dirty_dataset[target_row][target_field] = float(new_value)
333
  elif isinstance(gt_value, int):
334
  dirty_dataset[target_row][target_field] = int(float(new_value))
335
  else:
336
  dirty_dataset[target_row][target_field] = new_value
337
-
338
  reward = 1.0 / total_errors
339
- return reward, f"Fixed: row {target_row}, field '{target_field}' '{new_value}'", True
340
  else:
341
- return -0.05, f"Wrong value for row {target_row}, field '{target_field}'. Got '{new_value}', expected something else.", False
 
 
 
 
 
 
 
 
 
1
  import copy
2
  import random
3
  from typing import Any, Dict, List, Tuple
4
 
5
 
 
 
 
 
6
  def _generate_task_easy(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
 
7
  rng = random.Random(seed)
8
+
9
  ground_truth = [
10
  {"id": 1, "name": "Alice Johnson", "email": "alice@example.com", "age": 30, "department": "Engineering"},
11
  {"id": 2, "name": "Bob Smith", "email": "bob@example.com", "age": 25, "department": "Marketing"},
 
13
  {"id": 4, "name": "David Brown", "email": "david@example.com", "age": 28, "department": "Sales"},
14
  {"id": 5, "name": "Eve Davis", "email": "eve@example.com", "age": 32, "department": "Marketing"},
15
  ]
16
+
17
  dirty = copy.deepcopy(ground_truth)
18
  errors = []
19
+
 
20
  missing_configs = [
21
  (1, "email", ""),
22
  (2, "department", ""),
23
  (4, "name", ""),
24
  ]
25
+
26
  for row_idx, field, replacement in missing_configs:
27
  dirty[row_idx][field] = replacement
28
  errors.append({
 
32
  "current_value": replacement,
33
  "description": f"Row {row_idx}: '{field}' is missing/empty"
34
  })
 
 
35
 
36
+ return dirty, ground_truth, errors
37
 
 
 
 
38
 
39
  def _generate_task_medium(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
 
40
  rng = random.Random(seed)
41
+
42
  ground_truth = [
43
  {"id": 1, "product": "Laptop Pro", "price": 999.99, "quantity": 50, "sku": "LP-001", "category": "Electronics"},
44
  {"id": 2, "product": "Wireless Mouse", "price": 29.99, "quantity": 200, "sku": "WM-002", "category": "Accessories"},
 
48
  {"id": 6, "product": "Headphones", "price": 149.99, "quantity": 80, "sku": "HP-006", "category": "Audio"},
49
  {"id": 7, "product": "Webcam HD", "price": 59.99, "quantity": 120, "sku": "WC-007", "category": "Electronics"},
50
  ]
51
+
52
  dirty = copy.deepcopy(ground_truth)
53
  errors = []
54
+
 
55
  dirty[0]["price"] = "999.99"
56
  errors.append({
57
  "error_type": "type",
 
61
  "expected_type": "float",
62
  "description": "Row 0: 'price' should be float, got string '999.99'"
63
  })
64
+
 
65
  dirty[2]["quantity"] = "five hundred"
66
  errors.append({
67
  "error_type": "type",
 
71
  "expected_type": "int",
72
  "description": "Row 2: 'quantity' should be int, got string 'five hundred'"
73
  })
74
+
 
75
  dirty[3]["sku"] = "mn004"
76
  errors.append({
77
  "error_type": "format",
 
81
  "expected_format": "XX-NNN",
82
  "description": "Row 3: 'sku' should match format 'XX-NNN', got 'mn004'"
83
  })
84
+
 
85
  dirty[5]["category"] = ""
86
  errors.append({
87
  "error_type": "missing",
 
90
  "current_value": "",
91
  "description": "Row 5: 'category' is missing/empty"
92
  })
93
+
 
94
  dirty[4]["price"] = -79.99
95
  errors.append({
96
  "error_type": "range",
 
99
  "current_value": -79.99,
100
  "description": "Row 4: 'price' is negative (-79.99), should be positive"
101
  })
102
+
 
103
  dirty[6]["sku"] = "WM-002"
104
  errors.append({
105
  "error_type": "duplicate",
 
108
  "current_value": "WM-002",
109
  "description": "Row 6: 'sku' value 'WM-002' duplicates row 1"
110
  })
 
 
111
 
112
+ return dirty, ground_truth, errors
113
 
 
 
 
114
 
115
  def _generate_task_hard(seed: int = 42) -> Tuple[List[Dict], List[Dict], List[Dict]]:
 
116
  rng = random.Random(seed)
117
+
118
  ground_truth = [
119
  {"id": 1, "customer": "Acme Corp", "email": "orders@acme.com", "amount": 1500.00, "currency": "USD", "status": "completed", "date": "2024-03-15", "region": "North America", "priority": "high"},
120
  {"id": 2, "customer": "GlobalTech", "email": "sales@globaltech.io", "amount": 2300.50, "currency": "EUR", "status": "pending", "date": "2024-03-16", "region": "Europe", "priority": "medium"},
 
127
  {"id": 9, "customer": "EcoSmart", "email": "green@ecosmart.co", "amount": 1200.00, "currency": "AUD", "status": "shipped", "date": "2024-03-23", "region": "Asia Pacific", "priority": "medium"},
128
  {"id": 10, "customer": "BlueOcean", "email": "info@blueocean.net", "amount": 980.75, "currency": "USD", "status": "pending", "date": "2024-03-24", "region": "North America", "priority": "low"},
129
  ]
130
+
131
  dirty = copy.deepcopy(ground_truth)
132
  errors = []
133
+
 
134
  dirty[0]["email"] = ""
135
  errors.append({"error_type": "missing", "row": 0, "field": "email", "current_value": "", "description": "Row 0: 'email' is missing"})
136
+
 
137
  dirty[1]["amount"] = -2300.50
138
  errors.append({"error_type": "range", "row": 1, "field": "amount", "current_value": -2300.50, "description": "Row 1: 'amount' is negative"})
139
+
 
140
  dirty[2]["date"] = "03/17/2024"
141
  errors.append({"error_type": "format", "row": 2, "field": "date", "current_value": "03/17/2024", "expected_format": "YYYY-MM-DD", "description": "Row 2: 'date' wrong format, expected YYYY-MM-DD"})
142
+
 
143
  dirty[3]["amount"] = "4200"
144
  errors.append({"error_type": "type", "row": 3, "field": "amount", "current_value": "4200", "expected_type": "float", "description": "Row 3: 'amount' should be float, got string"})
145
+
 
146
  dirty[4]["status"] = "in-progress"
147
  errors.append({"error_type": "format", "row": 4, "field": "status", "current_value": "in-progress", "expected_format": "one of: completed, pending, shipped, cancelled", "description": "Row 4: 'status' invalid value 'in-progress'"})
148
+
 
149
  dirty[5]["region"] = ""
150
  errors.append({"error_type": "missing", "row": 5, "field": "region", "current_value": "", "description": "Row 5: 'region' is missing"})
151
+
 
152
  dirty[6]["customer"] = "Acme Corp"
153
  dirty[6]["email"] = "orders@acme.com"
154
  errors.append({"error_type": "duplicate", "row": 6, "field": "customer", "current_value": "Acme Corp", "description": "Row 6: 'customer' duplicates row 0"})
155
+
 
156
  dirty[7]["amount"] = 99999.99
157
  errors.append({"error_type": "range", "row": 7, "field": "amount", "current_value": 99999.99, "description": "Row 7: 'amount' exceeds maximum threshold (should be 6750.00)"})
158
+
 
159
  dirty[8]["currency"] = "AUSD"
160
  errors.append({"error_type": "format", "row": 8, "field": "currency", "current_value": "AUSD", "expected_format": "3-letter ISO code", "description": "Row 8: 'currency' invalid code 'AUSD'"})
161
+
 
162
  dirty[9]["priority"] = ""
163
  errors.append({"error_type": "missing", "row": 9, "field": "priority", "current_value": "", "description": "Row 9: 'priority' is missing"})
 
 
164
 
165
+ return dirty, ground_truth, errors
166
 
 
 
 
167
 
168
  TASK_REGISTRY = {
169
  "easy_missing_values": {
 
194
 
195
 
196
def get_task_names() -> List[str]:
    """Return the names of every registered task, in registry order."""
    return [task_name for task_name in TASK_REGISTRY]
198
 
199
 
200
  def generate_task(task_name: str, seed: int = 42) -> Dict[str, Any]:
 
201
  if task_name not in TASK_REGISTRY:
202
  raise ValueError(f"Unknown task: {task_name}. Available: {get_task_names()}")
203
+
204
  task_info = TASK_REGISTRY[task_name]
205
  dirty, ground_truth, errors = task_info["generator"](seed)
206
+
207
  return {
208
  "name": task_info["name"],
209
  "description": task_info["description"],
 
217
  }
218
 
219
 
220
+ def grade_action(action_type: str, target_field: str, target_row: int,
221
+ new_value: str, dirty_dataset: List[Dict],
222
  ground_truth: List[Dict], errors: List[Dict]) -> Tuple[float, str, bool]:
 
 
 
 
 
 
 
 
 
 
223
  total_errors = len(errors) if errors else 1
224
+
225
  if action_type == "validate":
226
  fixed = sum(1 for e in errors if e.get("fixed", False))
227
  return 0.0, f"Validation: {fixed}/{total_errors} errors fixed ({fixed/total_errors*100:.0f}%)", False
228
+
229
  if action_type == "skip":
230
  return 0.0, "Skipped current action", False
231
+
 
232
  matching_error = None
233
  for e in errors:
234
  if e.get("fixed", False):
 
236
  if e["row"] == target_row and e["field"] == target_field:
237
  matching_error = e
238
  break
239
+
240
  if matching_error is None:
241
  return -0.05, f"No unfixed error at row {target_row}, field '{target_field}'", False
242
+
 
243
  action_to_error_map = {
244
  "fix_missing": "missing",
245
  "fix_type": "type",
 
247
  "fix_format": "format",
248
  "fix_duplicate": "duplicate",
249
  }
250
+
251
  expected_error_type = action_to_error_map.get(action_type, "")
252
  if expected_error_type != matching_error["error_type"]:
253
  return -0.05, f"Wrong action type '{action_type}' for error type '{matching_error['error_type']}'", False
254
+
 
255
  gt_value = ground_truth[target_row][target_field]
256
+
 
257
  is_correct = False
258
  try:
259
  if isinstance(gt_value, float):
 
264
  is_correct = str(new_value).strip() == str(gt_value).strip()
265
  except (ValueError, TypeError):
266
  is_correct = str(new_value).strip() == str(gt_value).strip()
267
+
268
  if is_correct:
269
  matching_error["fixed"] = True
 
270
  if isinstance(gt_value, float):
271
  dirty_dataset[target_row][target_field] = float(new_value)
272
  elif isinstance(gt_value, int):
273
  dirty_dataset[target_row][target_field] = int(float(new_value))
274
  else:
275
  dirty_dataset[target_row][target_field] = new_value
276
+
277
  reward = 1.0 / total_errors
278
+ return reward, f"Fixed: row {target_row}, field '{target_field}' -> '{new_value}'", True
279
  else:
280
+ return -0.05, f"Wrong value for row {target_row}, field '{target_field}'. Got '{new_value}', expected something else.", False
inference.py CHANGED
@@ -1,18 +1,3 @@
1
- """
2
- Inference agent for the Data Validation Pipeline environment.
3
-
4
- Uses OpenAI-compatible API to solve data cleaning tasks.
5
- Reads environment variables:
6
- - API_BASE_URL: Base URL for the OpenAI-compatible API (default: https://api.openai.com/v1)
7
- - MODEL_NAME: Model to use (default: gpt-4.1-mini)
8
- - HF_TOKEN: HuggingFace token (REQUIRED, no default)
9
-
10
- Output format strictly follows OpenEnv spec:
11
- [START] task=<name> env=<benchmark> model=<model_name>
12
- [STEP] step=<n> action=<str> reward=<0.00> done=<true|false> error=<msg|null>
13
- [END] success=<true|false> steps=<n> rewards=<r1,r2,...,rn>
14
- """
15
-
16
  import json
17
  import os
18
  import re
@@ -21,7 +6,6 @@ import time
21
  import requests
22
  from openai import OpenAI
23
 
24
- # Read environment variables with defaults where required
25
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
26
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
27
  HF_TOKEN = os.getenv("HF_TOKEN")
@@ -29,16 +13,13 @@ HF_TOKEN = os.getenv("HF_TOKEN")
29
  if HF_TOKEN is None:
30
  raise ValueError("HF_TOKEN environment variable is required")
31
 
32
- # Initialize OpenAI client
33
  client = OpenAI(
34
  base_url=API_BASE_URL,
35
  api_key=HF_TOKEN,
36
  )
37
 
38
- # The HF Space URL where the environment is running
39
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://kush5699-data-validation-env.hf.space")
40
 
41
- # All 3 tasks to run sequentially
42
  TASKS = [
43
  {"task_name": "easy_missing_values", "seed": 42},
44
  {"task_name": "medium_mixed_errors", "seed": 42},
@@ -49,7 +30,6 @@ BENCHMARK_NAME = "data_validation_env"
49
 
50
 
51
  def call_llm(messages: list) -> str:
52
- """Call the LLM via OpenAI-compatible API."""
53
  try:
54
  response = client.chat.completions.create(
55
  model=MODEL_NAME,
@@ -68,7 +48,6 @@ def call_llm(messages: list) -> str:
68
 
69
 
70
  def env_reset(task_name: str, seed: int = 42) -> dict:
71
- """Reset the environment."""
72
  resp = requests.post(
73
  f"{ENV_BASE_URL}/reset",
74
  json={"task_name": task_name, "seed": seed},
@@ -79,7 +58,6 @@ def env_reset(task_name: str, seed: int = 42) -> dict:
79
 
80
 
81
  def env_step(action: dict) -> dict:
82
- """Take a step in the environment."""
83
  resp = requests.post(
84
  f"{ENV_BASE_URL}/step",
85
  json=action,
@@ -90,7 +68,6 @@ def env_step(action: dict) -> dict:
90
 
91
 
92
  def build_system_prompt(obs: dict) -> str:
93
- """Build a system prompt for the LLM based on current observation."""
94
  return f"""You are a data validation agent. Your task is to fix errors in a dataset.
95
 
96
  TASK: {obs.get('task_description', '')}
@@ -116,13 +93,9 @@ RULES:
116
 
117
 
118
  def build_user_prompt(obs: dict) -> str:
119
- """Build a user prompt showing current state."""
120
  errors = obs.get("errors_found", [])
121
  dataset = obs.get("dataset", [])
122
-
123
  errors_text = json.dumps(errors, indent=2) if errors else "No errors remaining!"
124
-
125
- # Show a compact view of dataset
126
  dataset_compact = []
127
  for i, row in enumerate(dataset):
128
  dataset_compact.append(f"Row {i}: {json.dumps(row)}")
@@ -146,8 +119,6 @@ Respond with ONLY a JSON action object to fix the next error."""
146
 
147
 
148
  def parse_llm_response(response: str) -> dict:
149
- """Parse the LLM response into a valid action."""
150
- # Try to extract JSON from the response
151
  try:
152
  action = json.loads(response)
153
  return {
@@ -159,7 +130,6 @@ def parse_llm_response(response: str) -> dict:
159
  except json.JSONDecodeError:
160
  pass
161
 
162
- # Try to find JSON in the response
163
  json_match = re.search(r'\{[^}]+\}', response)
164
  if json_match:
165
  try:
@@ -173,12 +143,10 @@ def parse_llm_response(response: str) -> dict:
173
  except (json.JSONDecodeError, ValueError):
174
  pass
175
 
176
- # Fallback: skip
177
  return {"action_type": "skip", "target_field": "", "target_row": 0, "new_value": ""}
178
 
179
 
180
  def run_episode(task_config: dict) -> None:
181
- """Run a single episode for a task."""
182
  task_name = task_config["task_name"]
183
  seed = task_config.get("seed", 42)
184
  rewards = []
@@ -188,7 +156,6 @@ def run_episode(task_config: dict) -> None:
188
  print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")
189
 
190
  try:
191
- # Reset environment
192
  obs = env_reset(task_name, seed)
193
  max_steps = obs.get("max_steps", 20)
194
 
@@ -197,18 +164,14 @@ def run_episode(task_config: dict) -> None:
197
  ]
198
 
199
  while not obs.get("done", False) and steps < max_steps:
200
- # Build user prompt
201
  user_msg = build_user_prompt(obs)
202
  messages_for_call = messages + [{"role": "user", "content": user_msg}]
203
 
204
- # Get LLM response
205
  llm_response = call_llm(messages_for_call)
206
 
207
- # Parse into action
208
  action = parse_llm_response(llm_response)
209
  action_str = json.dumps(action)
210
 
211
- # Take step
212
  error_msg = None
213
  try:
214
  obs = env_step(action)
@@ -227,7 +190,6 @@ def run_episode(task_config: dict) -> None:
227
  if done:
228
  break
229
 
230
- # Calculate success based on cumulative reward
231
  total_reward = sum(rewards)
232
  success = total_reward > 0.5
233
 
@@ -243,10 +205,9 @@ def run_episode(task_config: dict) -> None:
243
 
244
 
245
  def main():
246
- """Run all 3 tasks sequentially."""
247
  for task_config in TASKS:
248
  run_episode(task_config)
249
- time.sleep(1) # Small delay between tasks
250
 
251
 
252
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
2
  import os
3
  import re
 
6
  import requests
7
  from openai import OpenAI
8
 
 
9
  API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1")
10
  MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4.1-mini")
11
  HF_TOKEN = os.getenv("HF_TOKEN")
 
13
  if HF_TOKEN is None:
14
  raise ValueError("HF_TOKEN environment variable is required")
15
 
 
16
  client = OpenAI(
17
  base_url=API_BASE_URL,
18
  api_key=HF_TOKEN,
19
  )
20
 
 
21
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://kush5699-data-validation-env.hf.space")
22
 
 
23
  TASKS = [
24
  {"task_name": "easy_missing_values", "seed": 42},
25
  {"task_name": "medium_mixed_errors", "seed": 42},
 
30
 
31
 
32
  def call_llm(messages: list) -> str:
 
33
  try:
34
  response = client.chat.completions.create(
35
  model=MODEL_NAME,
 
48
 
49
 
50
  def env_reset(task_name: str, seed: int = 42) -> dict:
 
51
  resp = requests.post(
52
  f"{ENV_BASE_URL}/reset",
53
  json={"task_name": task_name, "seed": seed},
 
58
 
59
 
60
  def env_step(action: dict) -> dict:
 
61
  resp = requests.post(
62
  f"{ENV_BASE_URL}/step",
63
  json=action,
 
68
 
69
 
70
  def build_system_prompt(obs: dict) -> str:
 
71
  return f"""You are a data validation agent. Your task is to fix errors in a dataset.
72
 
73
  TASK: {obs.get('task_description', '')}
 
93
 
94
 
95
  def build_user_prompt(obs: dict) -> str:
 
96
  errors = obs.get("errors_found", [])
97
  dataset = obs.get("dataset", [])
 
98
  errors_text = json.dumps(errors, indent=2) if errors else "No errors remaining!"
 
 
99
  dataset_compact = []
100
  for i, row in enumerate(dataset):
101
  dataset_compact.append(f"Row {i}: {json.dumps(row)}")
 
119
 
120
 
121
  def parse_llm_response(response: str) -> dict:
 
 
122
  try:
123
  action = json.loads(response)
124
  return {
 
130
  except json.JSONDecodeError:
131
  pass
132
 
 
133
  json_match = re.search(r'\{[^}]+\}', response)
134
  if json_match:
135
  try:
 
143
  except (json.JSONDecodeError, ValueError):
144
  pass
145
 
 
146
  return {"action_type": "skip", "target_field": "", "target_row": 0, "new_value": ""}
147
 
148
 
149
  def run_episode(task_config: dict) -> None:
 
150
  task_name = task_config["task_name"]
151
  seed = task_config.get("seed", 42)
152
  rewards = []
 
156
  print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")
157
 
158
  try:
 
159
  obs = env_reset(task_name, seed)
160
  max_steps = obs.get("max_steps", 20)
161
 
 
164
  ]
165
 
166
  while not obs.get("done", False) and steps < max_steps:
 
167
  user_msg = build_user_prompt(obs)
168
  messages_for_call = messages + [{"role": "user", "content": user_msg}]
169
 
 
170
  llm_response = call_llm(messages_for_call)
171
 
 
172
  action = parse_llm_response(llm_response)
173
  action_str = json.dumps(action)
174
 
 
175
  error_msg = None
176
  try:
177
  obs = env_step(action)
 
190
  if done:
191
  break
192
 
 
193
  total_reward = sum(rewards)
194
  success = total_reward > 0.5
195
 
 
205
 
206
 
207
def main():
    """Run each configured task episode in sequence, pausing briefly between runs."""
    for cfg in TASKS:
        run_episode(cfg)
        time.sleep(1)
211
 
212
 
213
  if __name__ == "__main__":
server.py CHANGED
@@ -1,27 +1,19 @@
1
- """FastAPI server for the Data Validation Pipeline environment.
2
-
3
- Exposes OpenEnv-compatible HTTP endpoints: /reset, /step, /state, /health
4
- """
5
-
6
  import json
7
  import traceback
8
- from typing import Any, Dict, Optional
9
 
10
- from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect, Request
11
- from fastapi.responses import JSONResponse
12
  from pydantic import BaseModel
13
 
14
  from env.environment import DataValidationEnvironment
15
- from env.models import DataCleanAction, DataCleanObservation, DataCleanState
16
  from env.tasks import get_task_names
17
 
18
  app = FastAPI(
19
  title="Data Validation Pipeline - OpenEnv Environment",
20
- description="An RL environment for training agents to clean and validate structured data",
21
  version="1.0.0",
22
  )
23
 
24
- # Single shared environment instance
25
  env = DataValidationEnvironment()
26
 
27
 
@@ -39,13 +31,11 @@ class StepRequest(BaseModel):
39
 
40
  @app.get("/health")
41
  async def health():
42
- """Health check endpoint."""
43
  return {"status": "healthy", "service": "data-validation-env"}
44
 
45
 
46
  @app.post("/reset")
47
  async def reset(request: ResetRequest = None):
48
- """Reset the environment with a new task."""
49
  if request is None:
50
  request = ResetRequest()
51
  try:
@@ -57,7 +47,6 @@ async def reset(request: ResetRequest = None):
57
 
58
  @app.post("/step")
59
  async def step(request: StepRequest):
60
- """Execute an action in the environment."""
61
  try:
62
  action = DataCleanAction(
63
  action_type=request.action_type,
@@ -73,7 +62,6 @@ async def step(request: StepRequest):
73
 
74
  @app.get("/state")
75
  async def state():
76
- """Get the current environment state."""
77
  try:
78
  s = env.state()
79
  return s.model_dump()
@@ -83,25 +71,22 @@ async def state():
83
 
84
  @app.get("/tasks")
85
  async def tasks():
86
- """List available tasks."""
87
  return {"tasks": get_task_names()}
88
 
89
 
90
- # WebSocket support for OpenEnv clients
91
  @app.websocket("/ws")
92
  async def websocket_endpoint(websocket: WebSocket):
93
- """WebSocket endpoint for persistent sessions."""
94
  await websocket.accept()
95
  ws_env = DataValidationEnvironment()
96
-
97
  try:
98
  while True:
99
  data = await websocket.receive_text()
100
  msg = json.loads(data)
101
-
102
  method = msg.get("method", "")
103
  params = msg.get("params", {})
104
-
105
  try:
106
  if method == "reset":
107
  obs = ws_env.reset(
@@ -131,7 +116,7 @@ async def websocket_endpoint(websocket: WebSocket):
131
  }
132
  else:
133
  response = {"error": f"Unknown method: {method}"}
134
-
135
  await websocket.send_text(json.dumps(response))
136
  except Exception as e:
137
  await websocket.send_text(json.dumps({
 
 
 
 
 
 
1
  import json
2
  import traceback
3
+ from typing import Optional
4
 
5
+ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
 
6
  from pydantic import BaseModel
7
 
8
  from env.environment import DataValidationEnvironment
9
+ from env.models import DataCleanAction
10
  from env.tasks import get_task_names
11
 
12
  app = FastAPI(
13
  title="Data Validation Pipeline - OpenEnv Environment",
 
14
  version="1.0.0",
15
  )
16
 
 
17
  env = DataValidationEnvironment()
18
 
19
 
 
31
 
32
@app.get("/health")
async def health():
    """Liveness probe; the Dockerfile HEALTHCHECK polls this endpoint."""
    payload = {"status": "healthy", "service": "data-validation-env"}
    return payload
35
 
36
 
37
  @app.post("/reset")
38
  async def reset(request: ResetRequest = None):
 
39
  if request is None:
40
  request = ResetRequest()
41
  try:
 
47
 
48
  @app.post("/step")
49
  async def step(request: StepRequest):
 
50
  try:
51
  action = DataCleanAction(
52
  action_type=request.action_type,
 
62
 
63
  @app.get("/state")
64
  async def state():
 
65
  try:
66
  s = env.state()
67
  return s.model_dump()
 
71
 
72
@app.get("/tasks")
async def tasks():
    """Expose the list of task names this environment can run."""
    available = get_task_names()
    return {"tasks": available}
75
 
76
 
 
77
  @app.websocket("/ws")
78
  async def websocket_endpoint(websocket: WebSocket):
 
79
  await websocket.accept()
80
  ws_env = DataValidationEnvironment()
81
+
82
  try:
83
  while True:
84
  data = await websocket.receive_text()
85
  msg = json.loads(data)
86
+
87
  method = msg.get("method", "")
88
  params = msg.get("params", {})
89
+
90
  try:
91
  if method == "reset":
92
  obs = ws_env.reset(
 
116
  }
117
  else:
118
  response = {"error": f"Unknown method: {method}"}
119
+
120
  await websocket.send_text(json.dumps(response))
121
  except Exception as e:
122
  await websocket.send_text(json.dumps({