avanigupta commited on
Commit
c3002ad
·
1 Parent(s): cd11aba

add fix stage+demo

Browse files
README.md CHANGED
@@ -13,20 +13,38 @@ tags:
13
 
14
  # DataQA Environment
15
 
16
- An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  ## Motivation
19
 
20
  Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
21
 
22
- DataQA turns this into a structured, gradable RL environment where agents must systematically inspect corrupted datasets, reason about schema constraints and validation rules, and pinpoint every planted issue — from obvious nulls to subtle data leakage signals that require domain expertise.
 
 
 
 
23
 
24
  ## Environment API
25
 
26
  | Endpoint | Method | Description |
27
  |----------|--------|-------------|
28
  | `/reset` | POST | Start a new episode with a corrupted dataset |
29
- | `/step` | POST | Submit identified issues, receive scored feedback |
30
  | `/state` | GET | Get current episode state |
31
  | `/health` | GET | Health check |
32
 
@@ -36,20 +54,27 @@ DataQA turns this into a structured, gradable RL environment where agents must s
36
  |------|--------|-----------|--------|-------------|
37
  | `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
38
  | `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
39
- | `hard` | 8 | Advanced | ML experiment metadata | Data leakage signals, unreasonable GPU memory, timestamp ordering, whitespace-only fields |
40
 
41
  **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
42
 
43
- ## Action Space
44
 
45
- The agent submits a list of issue strings, each in the format:
46
- ```
47
- row:<row_number>,col:<column_name>,issue:<issue_type>
48
- ```
 
 
 
 
 
49
 
50
- - `row_number`: 1-indexed position in the CSV data (after header). Row 1 = first data row.
51
- - `column_name`: Exact column header name, lowercase.
52
- - `issue_type`: One of the supported types below.
 
 
53
 
54
  **Supported Issue Types:**
55
 
@@ -66,30 +91,36 @@ row:<row_number>,col:<column_name>,issue:<issue_type>
66
 
67
  ## Observation Space
68
 
69
- Each observation contains:
70
-
71
  | Field | Type | Description |
72
  |-------|------|-------------|
73
  | `dataset_csv` | str | The corrupted dataset in CSV format |
74
  | `schema_description` | str | Column types, ranges, and constraints |
75
  | `validation_rules` | str | Business rules the data must satisfy |
76
  | `task_description` | str | Task context and instructions |
77
- | `feedback` | str | Results from previous step (TP/FP/FN counts, precision/recall) |
78
  | `num_issues_hint` | int | Exact count of planted issues |
79
  | `max_steps` | int | Maximum attempts allowed |
80
  | `done` | bool | Whether episode has terminated |
81
- | `reward` | float | Best weighted reward so far (0.0-1.0) |
82
 
83
- **Observation Metadata** (available after each step):
84
- - `f1`, `weighted_reward`, `precision`, `recall`
85
- - `tp`, `fp`, `fn`
86
- - `difficulty_found`, `difficulty_missed`
87
 
88
  ## Reward Function
89
 
90
- ### Difficulty-Weighted Reward (Primary)
 
 
 
 
 
 
 
 
91
 
92
- Each planted issue has a **difficulty weight** (1.0-3.0) reflecting how hard it is to detect. The primary reward is a **weighted F1 score** that provides meaningful per-step partial progress signals:
93
 
94
  | Weight | Category | Examples |
95
  |--------|----------|----------|
@@ -97,43 +128,51 @@ Each planted issue has a **difficulty weight** (1.0-3.0) reflecting how hard it
97
  | 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
98
  | 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
99
 
100
- **Formula:**
101
- - **Weighted Recall** = (sum of difficulty weights for found issues) / (total difficulty weight)
102
- - **Weighted Precision** = (found weight) / (found weight + FP count * avg difficulty)
103
- - **Weighted F1** = harmonic mean of weighted precision and recall
 
104
 
105
- This means:
106
- - Finding a hard issue (difficulty 3.0) increases reward 3x more than finding an easy one (1.0)
107
- - False positives are penalized proportionally to average issue difficulty
108
- - The agent sees meaningful reward differences at every step, not just pass/fail
109
 
110
- ### Standard F1 (also computed)
 
 
 
 
 
111
 
112
- Available in observation metadata for comparison. Uses unweighted set matching.
 
 
 
 
 
 
 
 
113
 
114
  ### Episode Boundaries
115
 
116
  - Each task allows up to 3 steps (attempts)
117
- - Episode ends when F1 >= 0.999 (perfect) or max steps reached
118
- - Best score across all steps is the final reward (monotonically non-decreasing)
119
- - Reward is always in [0.0, 1.0]
120
 
121
  ## Baseline Scores
122
 
123
- Baseline scores using Qwen2.5-72B-Instruct via HuggingFace Router:
124
 
125
- | Task | Expected Score Range | Description |
126
- |------|---------------------|-------------|
127
- | `easy` | 0.7 - 1.0 | Most LLMs find obvious issues reliably |
128
- | `medium` | 0.5 - 0.8 | Cross-column reasoning is challenging |
129
- | `hard` | 0.3 - 0.6 | ML domain knowledge and subtle patterns |
130
 
131
- Scores vary by model capability. Frontier models (GPT-4, Claude) typically score higher on the hard task due to better domain reasoning.
132
 
133
  ## Extensibility
134
 
135
- DataQA supports custom tasks, contamination rules, and difficulty levels via a programmatic API.
136
-
137
  ### Custom Contamination Rules
138
 
139
  ```python
@@ -180,7 +219,7 @@ register_task("custom", lambda seed: task)
180
  |------|--------|--------------------|
181
  | `missing_value` | Sets field to empty string | 1.0 |
182
  | `whitespace_value` | Sets field to single space | 2.5 |
183
- | `wrong_type_text` | Replaces with random text ("N/A", "null", etc.) | 1.0 |
184
  | `negative_value` | Negates numeric value | 1.0 |
185
 
186
  ## Quick Start
@@ -213,15 +252,16 @@ pip install -e ".[dev]"
213
  pytest tests/ -v
214
  ```
215
 
216
- 89 tests covering:
217
- - Task creation, corruption, and issue planting (difficulty weights, seed determinism)
218
- - Issue key parsing (standard, lenient, edge cases)
219
- - F1 and difficulty-weighted reward computation
220
- - Full environment reset/step lifecycle
 
221
  - Inference script parsing and prompt building
222
- - **Structured log format** ([START], [STEP], [END] — exact field names and ordering)
223
  - Score bounds (0.0-1.0), best-score monotonicity
224
- - Extensibility API (custom rules, custom tasks, environment integration)
225
 
226
  ## Validation
227
 
@@ -247,19 +287,19 @@ openenv validate .
247
  ```
248
  dataqa_env/
249
  ├── __init__.py # Public API + extensibility exports
250
- ├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
251
  ├── client.py # EnvClient for WebSocket connections
252
  ├── server/
253
- │ ├── environment.py # Core DataQAEnvironment (reset/step/state + weighted rewards)
254
  │ ├── tasks.py # Task definitions + contamination rules + extensibility API
255
  │ ├── app.py # FastAPI server (via openenv-core create_app)
256
  │ └── Dockerfile
257
  tests/
258
  ├── test_tasks.py # Task creation, corruption, difficulty weights
259
- ├── test_environment.py # Environment lifecycle, scoring, metadata
260
- ├── test_inference.py # LLM response parsing, prompt building, log format
261
  └── test_extensibility.py # Custom rules, custom tasks, registration API
262
- inference.py # Baseline LLM agent (OpenAI client, structured logs)
263
  openenv.yaml # OpenEnv/HF Spaces spec
264
  pyproject.toml # Package metadata and dependencies
265
  Dockerfile # Production container
 
13
 
14
  # DataQA Environment
15
 
16
+ A two-phase OpenEnv RL environment for **Data Quality Assurance** — an LLM agent inspects corrupted datasets, identifies all planted quality issues, and proposes data repairs.
17
+
18
+ ### Demo: Agent Trajectory Replay
19
+
20
+ **Easy task** — Agent finds all 4 issues and proposes fixes (step 2):
21
+
22
+ ![Easy task: all issues found + fixes proposed](docs/demo_easy.png)
23
+
24
+ **Hard task** — Agent identifies 8 subtle ML issues including data leakage and GPU memory outlier, proposes fixes (step 2):
25
+
26
+ ![Hard task: ML experiment metadata with 8 issues](docs/demo_hard.png)
27
+
28
+ Green cells = correctly found issues. Yellow = missed. Green outlines = correct fixes with proposed values shown inline (e.g. `empty → David Kim`, `seventy-five thousand → 75000`).
29
+
30
+ > The interactive replay UI is available at the `/web` endpoint on the HF Space.
31
 
32
  ## Motivation
33
 
34
  Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
35
 
36
+ DataQA turns this into a **two-phase RL challenge**:
37
+ 1. **Identify** — systematically inspect corrupted data and pinpoint every planted issue
38
+ 2. **Fix** — propose corrected values by reasoning about schema, constraints, and context
39
+
40
+ This creates a rich multi-step decision problem where agents must explore datasets strategically, distinguish subtle anomalies from noise, and reason about what the correct data should be.
41
 
42
  ## Environment API
43
 
44
  | Endpoint | Method | Description |
45
  |----------|--------|-------------|
46
  | `/reset` | POST | Start a new episode with a corrupted dataset |
47
+ | `/step` | POST | Submit identified issues + proposed fixes |
48
  | `/state` | GET | Get current episode state |
49
  | `/health` | GET | Health check |
50
 
 
54
  |------|--------|-----------|--------|-------------|
55
  | `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
56
  | `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
57
+ | `hard` | 10 | Advanced | ML experiment metadata | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
58
 
59
  **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
60
 
61
+ ## Two-Phase Action Space
62
 
63
+ ### Phase 1: Identify Issues
64
+
65
+ Submit issues in format: `row:<row_number>,col:<column_name>,issue:<issue_type>`
66
+
67
+ - `row_number`: 1-indexed data row position (after header)
68
+ - `column_name`: Exact column header name, lowercase
69
+ - `issue_type`: One of the supported types below
70
+
71
+ ### Phase 2: Propose Fixes
72
 
73
+ Submit fixes in format: `row:<row_number>,col:<column_name>,fix:<corrected_value>`
74
+
75
+ The agent proposes the **correct value** that should replace the corrupted data. Fixes are graded against the original clean dataset.
76
+
77
+ Both phases can be submitted in the same step or across multiple steps.
78
 
79
  **Supported Issue Types:**
80
 
 
91
 
92
  ## Observation Space
93
 
 
 
94
  | Field | Type | Description |
95
  |-------|------|-------------|
96
  | `dataset_csv` | str | The corrupted dataset in CSV format |
97
  | `schema_description` | str | Column types, ranges, and constraints |
98
  | `validation_rules` | str | Business rules the data must satisfy |
99
  | `task_description` | str | Task context and instructions |
100
+ | `feedback` | str | Per-step results: TP/FP/FN, precision/recall, fix scores |
101
  | `num_issues_hint` | int | Exact count of planted issues |
102
  | `max_steps` | int | Maximum attempts allowed |
103
  | `done` | bool | Whether episode has terminated |
104
+ | `reward` | float | Best combined reward so far (0.0-1.0) |
105
 
106
+ **Observation Metadata** (per step):
107
+ - Identify: `identify_f1`, `identify_score`, `precision`, `recall`, `tp`, `fp`, `fn`
108
+ - Fix: `fix_score`, `fixes_correct`, `fixes_partial`, `fixes_wrong`, `fixes_attempted`
109
+ - Combined: `combined_reward`, `difficulty_found`, `difficulty_missed`
110
 
111
  ## Reward Function
112
 
113
+ ### Combined Reward
114
+
115
+ ```
116
+ combined_reward = 0.6 * identify_score + 0.4 * fix_score
117
+ ```
118
+
119
+ If no fixes are submitted, `combined_reward = identify_score` (no penalty — backward compatible).
120
+
121
+ ### Identify Score (Difficulty-Weighted F1)
122
 
123
+ Each planted issue has a **difficulty weight** (1.0-3.0):
124
 
125
  | Weight | Category | Examples |
126
  |--------|----------|----------|
 
128
  | 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
129
  | 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
130
 
131
+ - **Weighted Recall** = (difficulty of found issues) / (total difficulty)
132
+ - **Weighted Precision** = penalizes false positives proportional to average difficulty
133
+ - **Weighted F1** = harmonic mean
134
+
135
+ ### Fix Score (Difficulty-Weighted Quality)
136
 
137
+ Each proposed fix is compared against the original clean value:
 
 
 
138
 
139
+ | Fix Quality | Score | Description |
140
+ |-------------|-------|-------------|
141
+ | Exact match | 1.0 | Case-insensitive, whitespace-stripped match |
142
+ | Numeric close | 0.8 | Within 1% of correct numeric value |
143
+ | Correct cell | 0.1 | Right location, wrong value |
144
+ | Non-issue cell | 0.0 | Fix targets a cell with no issue |
145
 
146
+ Fix score = (sum of best fix score per issue × difficulty weight) / (total difficulty weight)
147
+
148
+ ### Reward Properties
149
+
150
+ - **Per-step partial progress**: reward increases as more issues are found/fixed
151
+ - **Difficulty-aware**: finding subtle issues earns more than obvious ones
152
+ - **Penalizes bad behavior**: false positives reduce score, fixing non-issues earns nothing
153
+ - **Monotonically non-decreasing**: best score across all steps is the final reward
154
+ - **Always in [0.0, 1.0]**: meets hackathon requirement
155
 
156
  ### Episode Boundaries
157
 
158
  - Each task allows up to 3 steps (attempts)
159
+ - Episode ends when F1 >= 0.999 (perfect identification) or max steps reached
160
+ - Agent receives detailed feedback after each step to improve on next attempt
 
161
 
162
  ## Baseline Scores
163
 
164
+ Baseline agent uses Qwen2.5-72B-Instruct via HuggingFace Router:
165
 
166
+ | Task | Identify Score | Fix Score | Combined | Notes |
167
+ |------|---------------|-----------|----------|-------|
168
+ | `easy` | 0.7-1.0 | 0.5-0.9 | 0.6-1.0 | Most LLMs find obvious issues reliably |
169
+ | `medium` | 0.5-0.8 | 0.3-0.6 | 0.4-0.7 | Cross-column reasoning challenges models |
170
+ | `hard` | 0.3-0.6 | 0.2-0.4 | 0.3-0.5 | ML domain knowledge and subtle patterns |
171
 
172
+ Scores vary by model. The hard task is designed to challenge frontier models.
173
 
174
  ## Extensibility
175
 
 
 
176
  ### Custom Contamination Rules
177
 
178
  ```python
 
219
  |------|--------|--------------------|
220
  | `missing_value` | Sets field to empty string | 1.0 |
221
  | `whitespace_value` | Sets field to single space | 2.5 |
222
+ | `wrong_type_text` | Replaces with random text | 1.0 |
223
  | `negative_value` | Negates numeric value | 1.0 |
224
 
225
  ## Quick Start
 
252
  pytest tests/ -v
253
  ```
254
 
255
+ 118 tests covering:
256
+ - Task creation, corruption, and difficulty weights
257
+ - Issue key and fix parsing (standard, lenient, edge cases)
258
+ - F1, weighted reward, and fix quality computation
259
+ - Full environment lifecycle (identify-only and identify+fix)
260
+ - Combined reward calculation and weight verification
261
  - Inference script parsing and prompt building
262
+ - Structured log format ([START], [STEP], [END])
263
  - Score bounds (0.0-1.0), best-score monotonicity
264
+ - Extensibility API (custom rules, custom tasks)
265
 
266
  ## Validation
267
 
 
287
  ```
288
  dataqa_env/
289
  ├── __init__.py # Public API + extensibility exports
290
+ ├── models.py # Pydantic: DataQAAction (issues + fixes), DataQAObservation, DataQAState
291
  ├── client.py # EnvClient for WebSocket connections
292
  ├── server/
293
+ │ ├── environment.py # Two-phase DataQAEnvironment (identify + fix + combined reward)
294
  │ ├── tasks.py # Task definitions + contamination rules + extensibility API
295
  │ ├── app.py # FastAPI server (via openenv-core create_app)
296
  │ └── Dockerfile
297
  tests/
298
  ├── test_tasks.py # Task creation, corruption, difficulty weights
299
+ ├── test_environment.py # Identify scoring, fix grading, combined reward, lifecycle
300
+ ├── test_inference.py # LLM response parsing, fix parsing, prompt building, log format
301
  └── test_extensibility.py # Custom rules, custom tasks, registration API
302
+ inference.py # Two-phase baseline agent (identify fix)
303
  openenv.yaml # OpenEnv/HF Spaces spec
304
  pyproject.toml # Package metadata and dependencies
305
  Dockerfile # Production container
dataqa_env/models.py CHANGED
@@ -16,21 +16,23 @@ from openenv.core.env_server.interfaces import Action, Observation, State
16
 
17
  class DataQAAction(Action):
18
  """
19
- Agent submits a list of identified data quality issues.
 
 
 
 
 
 
 
20
 
21
- Each issue is a string in the format: "row:<row_idx>,col:<col_name>,issue:<issue_type>"
22
  Supported issue types:
23
- - missing_value
24
- - wrong_type
25
- - duplicate_row
26
- - out_of_range
27
- - format_violation
28
- - inconsistent_value
29
- - statistical_outlier
30
- - referential_integrity
31
  """
32
 
33
  issues: List[str]
 
34
  # Include task_id so step() can reconstruct context in stateless HTTP mode
35
  task_id: str = "easy"
36
 
 
16
 
17
  class DataQAAction(Action):
18
  """
19
+ Agent submits identified issues AND optional proposed fixes.
20
+
21
+ Two-phase action space:
22
+ Phase 1 (Identify): List issues in format "row:<N>,col:<name>,issue:<type>"
23
+ Phase 2 (Fix): List fixes in format "row:<N>,col:<name>,fix:<proposed_value>"
24
+
25
+ The agent can submit both in the same step or across multiple steps.
26
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
27
 
 
28
  Supported issue types:
29
+ missing_value, wrong_type, duplicate_row, out_of_range,
30
+ format_violation, inconsistent_value, statistical_outlier,
31
+ referential_integrity
 
 
 
 
 
32
  """
33
 
34
  issues: List[str]
35
+ fixes: List[str] = []
36
  # Include task_id so step() can reconstruct context in stateless HTTP mode
37
  task_id: str = "easy"
38
 
dataqa_env/server/environment.py CHANGED
@@ -3,8 +3,12 @@ DataQA Environment
3
  ------------------
4
  Server-side environment for data quality assurance tasks.
5
 
6
- The agent receives corrupted datasets and must identify planted quality issues.
7
- Scoring is based on F1 (precision-recall) of correctly matched issues.
 
 
 
 
8
  """
9
 
10
  from __future__ import annotations
@@ -18,6 +22,10 @@ from openenv.core.env_server.interfaces import Action, Environment, Observation
18
  from ..models import DataQAAction, DataQAObservation, DataQAState
19
  from .tasks import PlantedIssue, Task, get_task, list_tasks
20
 
 
 
 
 
21
 
22
  def parse_issue_key(raw: str) -> Optional[str]:
23
  """
@@ -26,7 +34,6 @@ def parse_issue_key(raw: str) -> Optional[str]:
26
  Returns normalized key or None if unparseable.
27
  """
28
  raw = raw.strip().lower()
29
- # Be lenient with formatting
30
  row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
31
  col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
32
  issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
@@ -36,6 +43,22 @@ def parse_issue_key(raw: str) -> Optional[str]:
36
  return None
37
 
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
40
  """Compute precision, recall, and F1 score."""
41
  if not reported_keys and not planted_keys:
@@ -83,7 +106,6 @@ def compute_weighted_reward(
83
  if not planted_keys:
84
  return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
85
 
86
- # Sum difficulty weights for found vs missed issues
87
  found_keys = reported_keys & planted_keys
88
  missed_keys = planted_keys - reported_keys
89
  false_positive_count = len(reported_keys - planted_keys)
@@ -92,15 +114,12 @@ def compute_weighted_reward(
92
  difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
93
  total_weight = sum(i.difficulty for i in planted_issues)
94
 
95
- # Weighted recall: proportion of difficulty captured
96
  weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
97
 
98
- # Penalty for false positives (scaled by avg difficulty so penalty is proportional)
99
  avg_difficulty = total_weight / len(planted_issues)
100
  fp_penalty_weight = false_positive_count * avg_difficulty
101
  weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
102
 
103
- # Weighted F1
104
  if (weighted_precision + weighted_recall) > 0:
105
  weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
106
  else:
@@ -113,12 +132,134 @@ def compute_weighted_reward(
113
  }
114
 
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  class DataQAEnvironment(Environment):
117
  """
118
- Data Quality Assurance environment.
119
 
120
- The agent inspects corrupted datasets and reports quality issues.
121
- Reward is F1 score of correctly identified issues vs planted ground truth.
 
 
 
122
  """
123
 
124
  SUPPORTS_CONCURRENT_SESSIONS = True
@@ -158,7 +299,11 @@ class DataQAEnvironment(Environment):
158
  schema_description=self._current_task.schema_description,
159
  validation_rules=self._current_task.validation_rules,
160
  task_description=self._current_task.description,
161
- feedback="Environment reset. Inspect the dataset and report all quality issues.",
 
 
 
 
162
  task_id=task_id,
163
  num_issues_hint=len(self._current_task.planted_issues),
164
  max_steps=self._current_task.max_steps,
@@ -175,15 +320,14 @@ class DataQAEnvironment(Environment):
175
  if not isinstance(action, DataQAAction):
176
  raise ValueError(f"Expected DataQAAction, got {type(action)}")
177
 
178
- # In stateless HTTP mode, each request creates a fresh env instance.
179
- # Auto-reset using the task_id from the action so step() works standalone.
180
  if self._current_task is None:
181
  self.reset(task_id=action.task_id)
182
 
183
  self._state.step_count += 1
184
  self._state.current_step += 1
185
 
186
- # Parse reported issues
187
  reported_keys: Set[str] = set()
188
  parse_errors: list[str] = []
189
  for raw_issue in action.issues:
@@ -191,51 +335,148 @@ class DataQAEnvironment(Environment):
191
  if key:
192
  reported_keys.add(key)
193
  else:
194
- parse_errors.append(f"Could not parse: '{raw_issue}'")
195
 
196
- # Compute score (standard F1)
197
  metrics = compute_f1(reported_keys, self._planted_keys)
198
- score = metrics["f1"]
199
 
200
- # Compute difficulty-weighted reward (richer per-step signal)
201
  weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
202
- weighted_reward = weighted["weighted_reward"]
 
 
 
 
 
 
 
 
 
 
 
 
203
 
204
- # Use weighted reward as the primary reward signal
205
- self._best_score = max(self._best_score, weighted_reward)
 
 
 
 
 
 
206
  self._state.best_score = self._best_score
207
 
208
- # Check if done
209
  is_done = (
210
- score >= 0.999 # Perfect score (all issues found exactly)
211
  or self._state.current_step >= self._state.max_steps
212
  )
213
 
214
- # Build feedback
 
 
 
 
215
  feedback_lines = [
216
  f"Step {self._state.current_step}/{self._state.max_steps}",
 
 
217
  f"Issues reported: {len(reported_keys)}",
218
  f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
219
- f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
220
- f"Weighted reward: {weighted_reward:.3f} (difficulty found: {weighted['difficulty_found']}, missed: {weighted['difficulty_missed']})",
221
  ]
222
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  if parse_errors:
224
- feedback_lines.append(f"Parse errors ({len(parse_errors)}): {'; '.join(parse_errors[:3])}")
225
 
226
  if not is_done:
227
- # Give hints about what was missed without revealing exact answers
228
  if metrics["fn"] > 0:
229
  feedback_lines.append(
230
- f"You missed {metrics['fn']} issue(s). Review the dataset carefully."
231
  )
232
  if metrics["fp"] > 0:
233
  feedback_lines.append(
234
- f"{metrics['fp']} of your reported issues were incorrect."
235
  )
236
- feedback_lines.append("You can submit again with an updated list of issues.")
237
  else:
238
- feedback_lines.append(f"Task complete! Final best weighted reward: {self._best_score:.3f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
 
240
  return DataQAObservation(
241
  dataset_csv=self._current_task.corrupted_csv,
@@ -249,8 +490,10 @@ class DataQAEnvironment(Environment):
249
  done=is_done,
250
  reward=self._best_score,
251
  metadata={
252
- "f1": score,
253
- "weighted_reward": weighted_reward,
 
 
254
  "precision": metrics["precision"],
255
  "recall": metrics["recall"],
256
  "tp": metrics["tp"],
@@ -258,6 +501,12 @@ class DataQAEnvironment(Environment):
258
  "fn": metrics["fn"],
259
  "difficulty_found": weighted["difficulty_found"],
260
  "difficulty_missed": weighted["difficulty_missed"],
 
 
 
 
 
 
261
  },
262
  )
263
 
 
3
  ------------------
4
  Server-side environment for data quality assurance tasks.
5
 
6
+ Two-phase RL environment:
7
+ Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
8
+ Phase 2 (Fix): Agent proposes corrections for identified issues.
9
+
10
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
11
+ Both phases scored with difficulty-weighted metrics for rich per-step signal.
12
  """
13
 
14
  from __future__ import annotations
 
22
  from ..models import DataQAAction, DataQAObservation, DataQAState
23
  from .tasks import PlantedIssue, Task, get_task, list_tasks
24
 
25
+ # Reward weights for the two phases
26
+ IDENTIFY_WEIGHT = 0.6
27
+ FIX_WEIGHT = 0.4
28
+
29
 
30
  def parse_issue_key(raw: str) -> Optional[str]:
31
  """
 
34
  Returns normalized key or None if unparseable.
35
  """
36
  raw = raw.strip().lower()
 
37
  row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
38
  col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
39
  issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
 
43
  return None
44
 
45
 
46
+ def parse_fix(raw: str) -> Optional[tuple[int, str, str]]:
47
+ """
48
+ Parse an agent-proposed fix into (row, col, proposed_value).
49
+ Expected format: row:<N>,col:<name>,fix:<value>
50
+ Returns (row, col, value) or None if unparseable.
51
+ """
52
+ raw = raw.strip()
53
+ row_match = re.search(r"row\s*[:=]\s*(\d+)", raw, re.IGNORECASE)
54
+ col_match = re.search(r"col(?:umn)?\s*[:=]\s*([\w_]+)", raw, re.IGNORECASE)
55
+ fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", raw, re.IGNORECASE)
56
+
57
+ if row_match and col_match and fix_match:
58
+ return (int(row_match.group(1)), col_match.group(1).lower(), fix_match.group(1).strip())
59
+ return None
60
+
61
+
62
  def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
63
  """Compute precision, recall, and F1 score."""
64
  if not reported_keys and not planted_keys:
 
106
  if not planted_keys:
107
  return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
108
 
 
109
  found_keys = reported_keys & planted_keys
110
  missed_keys = planted_keys - reported_keys
111
  false_positive_count = len(reported_keys - planted_keys)
 
114
  difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
115
  total_weight = sum(i.difficulty for i in planted_issues)
116
 
 
117
  weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
118
 
 
119
  avg_difficulty = total_weight / len(planted_issues)
120
  fp_penalty_weight = false_positive_count * avg_difficulty
121
  weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
122
 
 
123
  if (weighted_precision + weighted_recall) > 0:
124
  weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
125
  else:
 
132
  }
133
 
134
 
135
+ def grade_fixes(
136
+ fixes: list[tuple[int, str, str]],
137
+ task: Task,
138
+ ) -> dict:
139
+ """
140
+ Grade proposed fixes against the clean dataset.
141
+
142
+ For each fix (row, col, proposed_value), compare to the original clean value.
143
+ Scoring per fix:
144
+ - Exact match (case-insensitive, whitespace-stripped): 1.0
145
+ - Numeric close match (within 1%): 0.8
146
+ - Correct column but wrong value: 0.1
147
+ - Targets a non-issue cell: 0.0 (penalty)
148
+
149
+ Returns dict with fix_score (0.0-1.0), details per fix, and counts.
150
+ """
151
+ if not fixes and not task.planted_issues:
152
+ return {"fix_score": 1.0, "fixes_correct": 0, "fixes_partial": 0,
153
+ "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
154
+
155
+ if not fixes:
156
+ return {"fix_score": 0.0, "fixes_correct": 0, "fixes_partial": 0,
157
+ "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
158
+
159
+ issue_map = task.get_planted_issue_map()
160
+ # Build set of (row, col) that are actual issues
161
+ issue_cells = {(issue.row, issue.col) for issue in task.planted_issues}
162
+
163
+ total_weight = sum(i.difficulty for i in task.planted_issues) if task.planted_issues else 1.0
164
+ earned_weight = 0.0
165
+ fixes_correct = 0
166
+ fixes_partial = 0
167
+ fixes_wrong = 0
168
+ fix_details = []
169
+
170
+ # Track which issues have been fixed (best fix wins)
171
+ fixed_issues: dict[tuple[int, str], float] = {}
172
+
173
+ for row, col, proposed in fixes:
174
+ clean_value = task.get_clean_value(row, col)
175
+ cell_key = (row, col)
176
+
177
+ if cell_key not in issue_cells:
178
+ # Fix targets a non-issue cell — no credit
179
+ fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "not an issue cell"})
180
+ fixes_wrong += 1
181
+ continue
182
+
183
+ if clean_value is None:
184
+ fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "cell not found"})
185
+ fixes_wrong += 1
186
+ continue
187
+
188
+ # Find the planted issue for this cell to get its difficulty weight
189
+ matching_issue = None
190
+ for issue in task.planted_issues:
191
+ if issue.row == row and issue.col == col:
192
+ matching_issue = issue
193
+ break
194
+
195
+ difficulty = matching_issue.difficulty if matching_issue else 1.0
196
+
197
+ # Score the fix
198
+ score = 0.0
199
+ reason = "wrong value"
200
+
201
+ # Exact match (case-insensitive, whitespace-stripped)
202
+ if proposed.strip().lower() == clean_value.lower():
203
+ score = 1.0
204
+ reason = "exact match"
205
+ fixes_correct += 1
206
+ else:
207
+ # Try numeric close match
208
+ try:
209
+ proposed_num = float(proposed.strip())
210
+ clean_num = float(clean_value)
211
+ if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
212
+ score = 0.8
213
+ reason = "numeric close match"
214
+ fixes_partial += 1
215
+ elif proposed_num == clean_num:
216
+ score = 1.0
217
+ reason = "exact numeric match"
218
+ fixes_correct += 1
219
+ else:
220
+ score = 0.1
221
+ reason = "correct cell, wrong value"
222
+ fixes_partial += 1
223
+ except (ValueError, ZeroDivisionError):
224
+ # Not numeric — just a wrong value but at least right cell
225
+ score = 0.1
226
+ reason = "correct cell, wrong value"
227
+ fixes_partial += 1
228
+
229
+ # Keep best fix per cell
230
+ if cell_key not in fixed_issues or score > fixed_issues[cell_key]:
231
+ fixed_issues[cell_key] = score
232
+
233
+ fix_details.append({"row": row, "col": col, "score": score, "reason": reason})
234
+
235
+ # Compute fix score: weighted sum of best fix per issue / total weight
236
+ for issue in task.planted_issues:
237
+ cell_key = (issue.row, issue.col)
238
+ if cell_key in fixed_issues:
239
+ earned_weight += issue.difficulty * fixed_issues[cell_key]
240
+
241
+ fix_score = earned_weight / total_weight if total_weight > 0 else 0.0
242
+ fix_score = min(max(fix_score, 0.0), 1.0)
243
+
244
+ return {
245
+ "fix_score": round(fix_score, 4),
246
+ "fixes_correct": fixes_correct,
247
+ "fixes_partial": fixes_partial,
248
+ "fixes_wrong": fixes_wrong,
249
+ "fixes_attempted": len(fixes),
250
+ "fix_details": fix_details,
251
+ }
252
+
253
+
254
  class DataQAEnvironment(Environment):
255
  """
256
+ Data Quality Assurance environment — two-phase identify + fix.
257
 
258
+ Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
259
+ Phase 2 (Fix): Agent proposes corrections for identified issues.
260
+
261
+ Combined reward = 0.6 * identify_score + 0.4 * fix_score
262
+ Both phases use difficulty-weighted scoring for rich per-step reward signals.
263
  """
264
 
265
  SUPPORTS_CONCURRENT_SESSIONS = True
 
299
  schema_description=self._current_task.schema_description,
300
  validation_rules=self._current_task.validation_rules,
301
  task_description=self._current_task.description,
302
+ feedback=(
303
+ "Environment reset. Inspect the dataset and report all quality issues.\n"
304
+ "You can also propose fixes in format: row:<N>,col:<name>,fix:<corrected_value>\n"
305
+ "Combined reward = 0.6 * identify_score + 0.4 * fix_score"
306
+ ),
307
  task_id=task_id,
308
  num_issues_hint=len(self._current_task.planted_issues),
309
  max_steps=self._current_task.max_steps,
 
320
  if not isinstance(action, DataQAAction):
321
  raise ValueError(f"Expected DataQAAction, got {type(action)}")
322
 
323
+ # Auto-reset in stateless HTTP mode
 
324
  if self._current_task is None:
325
  self.reset(task_id=action.task_id)
326
 
327
  self._state.step_count += 1
328
  self._state.current_step += 1
329
 
330
+ # ─�� Phase 1: Parse and score issue identification ──
331
  reported_keys: Set[str] = set()
332
  parse_errors: list[str] = []
333
  for raw_issue in action.issues:
 
335
  if key:
336
  reported_keys.add(key)
337
  else:
338
+ parse_errors.append(f"Could not parse issue: '{raw_issue}'")
339
 
 
340
  metrics = compute_f1(reported_keys, self._planted_keys)
341
+ identify_f1 = metrics["f1"]
342
 
 
343
  weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
344
+ identify_score = weighted["weighted_reward"]
345
+
346
+ # ── Phase 2: Parse and score proposed fixes ──
347
+ parsed_fixes: list[tuple[int, str, str]] = []
348
+ for raw_fix in action.fixes:
349
+ fix = parse_fix(raw_fix)
350
+ if fix:
351
+ parsed_fixes.append(fix)
352
+ else:
353
+ parse_errors.append(f"Could not parse fix: '{raw_fix}'")
354
+
355
+ fix_result = grade_fixes(parsed_fixes, self._current_task)
356
+ fix_score = fix_result["fix_score"]
357
 
358
+ # ── Combined reward ──
359
+ # If no fixes submitted, score is identify-only (no penalty for not fixing)
360
+ if action.fixes:
361
+ combined_reward = IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
362
+ else:
363
+ combined_reward = identify_score # backward compatible
364
+
365
+ self._best_score = max(self._best_score, combined_reward)
366
  self._state.best_score = self._best_score
367
 
368
+ # ── Check if done ──
369
  is_done = (
370
+ identify_f1 >= 0.999 # Perfect identification
371
  or self._state.current_step >= self._state.max_steps
372
  )
373
 
374
+ # ── Build feedback with actionable diagnostics ──
375
+ # Show the agent exactly which reported issues were correct (TP) and which were wrong (FP)
376
+ tp_keys = reported_keys & self._planted_keys
377
+ fp_keys = reported_keys - self._planted_keys
378
+
379
  feedback_lines = [
380
  f"Step {self._state.current_step}/{self._state.max_steps}",
381
+ "",
382
+ "--- Identification ---",
383
  f"Issues reported: {len(reported_keys)}",
384
  f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
385
+ f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {identify_f1:.3f}",
386
+ f"Identify score (weighted): {identify_score:.3f}",
387
  ]
388
 
389
+ # Show which reported issues were correct vs wrong (helps agent self-correct)
390
+ if tp_keys:
391
+ feedback_lines.append(f"Correct issues: {', '.join(sorted(tp_keys))}")
392
+ if fp_keys:
393
+ feedback_lines.append(f"Incorrect issues (false positives): {', '.join(sorted(fp_keys))}")
394
+
395
+ if action.fixes:
396
+ feedback_lines += [
397
+ "",
398
+ "--- Fix Proposals ---",
399
+ f"Fixes attempted: {fix_result['fixes_attempted']}",
400
+ f"Correct: {fix_result['fixes_correct']}, Partial: {fix_result['fixes_partial']}, Wrong: {fix_result['fixes_wrong']}",
401
+ f"Fix score: {fix_score:.3f}",
402
+ ]
403
+ # Show per-fix feedback so agent knows which fixes worked
404
+ for detail in fix_result["fix_details"]:
405
+ status = "correct" if detail["score"] >= 0.99 else ("partial" if detail["score"] > 0 else "wrong")
406
+ feedback_lines.append(
407
+ f" row:{detail['row']},col:{detail['col']} -> {status} ({detail['reason']})"
408
+ )
409
+ feedback_lines.append(
410
+ f"\n--- Combined Reward: {combined_reward:.3f} (identify={identify_score:.3f} x {IDENTIFY_WEIGHT} + fix={fix_score:.3f} x {FIX_WEIGHT}) ---"
411
+ )
412
+ else:
413
+ feedback_lines += [
414
+ "",
415
+ "Tip: Submit fixes with format row:<N>,col:<name>,fix:<value> for bonus reward.",
416
+ ]
417
+
418
  if parse_errors:
419
+ feedback_lines.append(f"\nParse errors ({len(parse_errors)}): {'; '.join(parse_errors[:5])}")
420
 
421
  if not is_done:
 
422
  if metrics["fn"] > 0:
423
  feedback_lines.append(
424
+ f"\nYou missed {metrics['fn']} issue(s). Review the dataset carefully."
425
  )
426
  if metrics["fp"] > 0:
427
  feedback_lines.append(
428
+ f"Remove the {metrics['fp']} false positive(s) listed above and look for real issues."
429
  )
430
+ feedback_lines.append("You can submit again with updated issues and/or fixes.")
431
  else:
432
+ feedback_lines.append(f"\nTask complete! Final best reward: {self._best_score:.3f}")
433
+
434
+ # ── Flag items for human review ──
435
+ # In a production data QA pipeline, these would go to a human reviewer.
436
+ # The grader flags cases where automated scoring has low confidence.
437
+ human_review_flags: list[dict] = []
438
+
439
+ # 1. False positives that target real columns — could be legitimate issues
440
+ # the task designer didn't plant (agent may be smarter than the grader)
441
+ issue_map = self._current_task.get_planted_issue_map()
442
+ valid_issue_types = {"missing_value", "wrong_type", "duplicate_row", "out_of_range",
443
+ "format_violation", "inconsistent_value", "statistical_outlier",
444
+ "referential_integrity"}
445
+ for fp_key in fp_keys:
446
+ parts = fp_key.split(",")
447
+ itype = parts[2].split(":")[1] if len(parts) >= 3 else ""
448
+ if itype in valid_issue_types:
449
+ human_review_flags.append({
450
+ "item": fp_key,
451
+ "reason": "Agent reported this issue but it's not in ground truth — may be a real issue the grader missed",
452
+ "type": "possible_unplanted_issue",
453
+ })
454
+
455
+ # 2. Partial fix matches — fix was close but not exact, human should verify
456
+ for detail in fix_result["fix_details"]:
457
+ if 0 < detail["score"] < 0.99:
458
+ human_review_flags.append({
459
+ "item": f"row:{detail['row']},col:{detail['col']}",
460
+ "reason": f"Fix scored {detail['score']:.2f} ({detail['reason']}) — human should verify if acceptable",
461
+ "type": "partial_fix",
462
+ })
463
+
464
+ # 3. High-difficulty issues that were missed — flag for training data review
465
+ planted_by_key = {i.to_key(): i for i in self._current_task.planted_issues}
466
+ fn_keys = self._planted_keys - reported_keys
467
+ for fn_key in fn_keys:
468
+ issue = planted_by_key.get(fn_key)
469
+ if issue and issue.difficulty >= 2.5:
470
+ human_review_flags.append({
471
+ "item": fn_key,
472
+ "reason": f"High-difficulty issue (difficulty={issue.difficulty}) missed — {issue.description}",
473
+ "type": "missed_hard_issue",
474
+ })
475
+
476
+ if human_review_flags:
477
+ feedback_lines.append(f"\n--- Flagged for Human Review ({len(human_review_flags)}) ---")
478
+ for flag in human_review_flags:
479
+ feedback_lines.append(f" [{flag['type']}] {flag['item']}: {flag['reason']}")
480
 
481
  return DataQAObservation(
482
  dataset_csv=self._current_task.corrupted_csv,
 
490
  done=is_done,
491
  reward=self._best_score,
492
  metadata={
493
+ "identify_f1": identify_f1,
494
+ "identify_score": identify_score,
495
+ "fix_score": fix_score,
496
+ "combined_reward": combined_reward,
497
  "precision": metrics["precision"],
498
  "recall": metrics["recall"],
499
  "tp": metrics["tp"],
 
501
  "fn": metrics["fn"],
502
  "difficulty_found": weighted["difficulty_found"],
503
  "difficulty_missed": weighted["difficulty_missed"],
504
+ "fixes_correct": fix_result["fixes_correct"],
505
+ "fixes_partial": fix_result["fixes_partial"],
506
+ "fixes_wrong": fix_result["fixes_wrong"],
507
+ "fixes_attempted": fix_result["fixes_attempted"],
508
+ "fix_details": fix_result["fix_details"],
509
+ "human_review_flags": human_review_flags,
510
  },
511
  )
512
 
dataqa_env/server/gradio_ui.py ADDED
@@ -0,0 +1,487 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio UI — Agent Trajectory Replay Viewer for DataQA.
3
+
4
+ Designed for judges: zero clicks needed, auto-plays on load.
5
+ Tab per task, step slider, prominent metric cards, color-coded dataset.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import csv
11
+ import io
12
+
13
+ import gradio as gr
14
+
15
+ from .environment import DataQAEnvironment, parse_issue_key
16
+ from .tasks import list_tasks, PlantedIssue
17
+ from ..models import DataQAAction
18
+
19
+
20
+ # ── Pre-built agent trajectories (simulates baseline agent) ──
21
+
22
+ AGENT_TRAJECTORIES = {
23
+ "easy": [
24
+ {
25
+ "issues": [
26
+ "row:4,col:name,issue:missing_value",
27
+ "row:7,col:salary,issue:wrong_type",
28
+ "row:9,col:salary,issue:out_of_range",
29
+ "row:3,col:email,issue:format_violation", # FP
30
+ ],
31
+ "fixes": [],
32
+ },
33
+ {
34
+ "issues": [
35
+ "row:4,col:name,issue:missing_value",
36
+ "row:7,col:salary,issue:wrong_type",
37
+ "row:9,col:salary,issue:out_of_range",
38
+ "row:11,col:employee_id,issue:duplicate_row",
39
+ ],
40
+ "fixes": [
41
+ "row:4,col:name,fix:David Kim",
42
+ "row:7,col:salary,fix:75000",
43
+ "row:9,col:salary,fix:73000",
44
+ ],
45
+ },
46
+ ],
47
+ "medium": [
48
+ {
49
+ "issues": [
50
+ "row:5,col:total,issue:inconsistent_value",
51
+ "row:10,col:category,issue:format_violation",
52
+ "row:14,col:product_name,issue:missing_value",
53
+ "row:17,col:quantity,issue:out_of_range",
54
+ "row:19,col:order_id,issue:duplicate_row",
55
+ "row:12,col:order_date,issue:format_violation",
56
+ ],
57
+ "fixes": [
58
+ "row:5,col:total,fix:42.00",
59
+ "row:10,col:category,fix:Sports",
60
+ "row:12,col:order_date,fix:2024-01-26",
61
+ "row:14,col:product_name,fix:LED Strip Lights",
62
+ ],
63
+ },
64
+ ],
65
+ "hard": [
66
+ {
67
+ "issues": [
68
+ "row:14,col:training_time_hours,issue:out_of_range",
69
+ "row:13,col:learning_rate,issue:out_of_range",
70
+ "row:15,col:model_name,issue:missing_value",
71
+ "row:9,col:batch_size,issue:format_violation",
72
+ "row:10,col:train_size,issue:inconsistent_value",
73
+ ],
74
+ "fixes": [],
75
+ },
76
+ {
77
+ "issues": [
78
+ "row:14,col:training_time_hours,issue:out_of_range",
79
+ "row:13,col:learning_rate,issue:out_of_range",
80
+ "row:15,col:model_name,issue:missing_value",
81
+ "row:9,col:batch_size,issue:format_violation",
82
+ "row:10,col:train_size,issue:inconsistent_value",
83
+ "row:5,col:val_loss,issue:inconsistent_value",
84
+ "row:7,col:gpu_memory_gb,issue:statistical_outlier",
85
+ "row:11,col:timestamp,issue:inconsistent_value",
86
+ "row:9,col:training_time_hours,issue:statistical_outlier",
87
+ "row:12,col:test_accuracy,issue:statistical_outlier",
88
+ ],
89
+ "fixes": [
90
+ "row:14,col:training_time_hours,fix:72.0",
91
+ "row:13,col:learning_rate,fix:0.00001",
92
+ "row:15,col:model_name,fix:whisper-small",
93
+ "row:9,col:batch_size,fix:256",
94
+ "row:9,col:training_time_hours,fix:36.0",
95
+ ],
96
+ },
97
+ ],
98
+ }
99
+
100
+
101
+ # ── HTML rendering ──
102
+
103
+ def _metric_card(label: str, value: str, color: str = "#333") -> str:
104
+ return (
105
+ f'<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
106
+ f'border-radius:8px;min-width:100px;">'
107
+ f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:1px;">{label}</div>'
108
+ f'<div style="font-size:28px;font-weight:700;color:{color};margin-top:2px;">{value}</div>'
109
+ f'</div>'
110
+ )
111
+
112
+
113
+ def _csv_to_html(
114
+ csv_text: str,
115
+ planted: list[PlantedIssue],
116
+ correct: set[tuple[int, str]],
117
+ fp: set[tuple[int, str]],
118
+ missed: set[tuple[int, str]],
119
+ fixed: dict[tuple[int, str], str],
120
+ fix_values: dict[tuple[int, str], str] | None = None,
121
+ ) -> str:
122
+ """Render CSV as HTML with color-coded cells and inline fix proposals."""
123
+ fix_values = fix_values or {}
124
+ desc_map = {(i.row, i.col): i for i in planted}
125
+ reader = csv.reader(io.StringIO(csv_text.strip()))
126
+ rows = list(reader)
127
+ if not rows:
128
+ return ""
129
+
130
+ header = rows[0]
131
+ header_lower = [h.strip().lower() for h in header]
132
+ data = rows[1:]
133
+
134
+ t = ['<table style="border-collapse:collapse;width:100%;font-size:12px;font-family:\'SF Mono\',monospace;">']
135
+ t.append('<tr>')
136
+ t.append('<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">Row</th>')
137
+ for h in header:
138
+ t.append(f'<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">{h}</th>')
139
+ t.append('</tr>')
140
+
141
+ for i, row in enumerate(data):
142
+ rn = i + 1
143
+ bg = "#fff" if i % 2 == 0 else "#f8f9fa"
144
+ t.append(f'<tr style="background:{bg};">')
145
+ t.append(f'<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;text-align:center;font-size:11px;">{rn}</td>')
146
+ for j, val in enumerate(row):
147
+ col = header_lower[j] if j < len(header_lower) else ""
148
+ ck = (rn, col)
149
+ s = "border:1px solid #dee2e6;padding:4px 8px;"
150
+ tip = ""
151
+ badge = ""
152
+
153
+ issue = desc_map.get(ck)
154
+
155
+ if ck in correct:
156
+ s += "background:#d4edda;"
157
+ tip = f"FOUND: {issue.description}" if issue else ""
158
+ badge = '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">TP</span>'
159
+ elif ck in fp:
160
+ s += "background:#f8d7da;"
161
+ badge = '<span style="font-size:9px;background:#dc3545;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">FP</span>'
162
+ elif ck in missed:
163
+ s += "background:#fff3cd;"
164
+ tip = f"MISSED: {issue.description}" if issue else ""
165
+ badge = '<span style="font-size:9px;background:#856404;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">MISS</span>'
166
+
167
+ fx = fixed.get(ck)
168
+ proposed = fix_values.get(ck)
169
+ if fx == "correct":
170
+ s += "box-shadow:inset 0 0 0 2px #28a745;"
171
+ badge += '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:2px;">FIX</span>'
172
+ elif fx == "partial":
173
+ s += "box-shadow:inset 0 0 0 2px #ffc107;"
174
+ badge += '<span style="font-size:9px;background:#ffc107;color:#333;padding:1px 4px;border-radius:3px;margin-left:2px;">~FIX</span>'
175
+
176
+ dv = val if val.strip() else '<em style="color:#dc3545;font-style:italic;">empty</em>'
177
+
178
+ # Show proposed fix value below the corrupted value
179
+ fix_line = ""
180
+ if proposed is not None:
181
+ fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545")
182
+ fix_line = (
183
+ f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
184
+ f'border-top:1px dashed {fix_color};padding-top:2px;">'
185
+ f'\u2192 {proposed}</div>'
186
+ )
187
+
188
+ t.append(f'<td style="{s}" title="{tip}">{dv}{badge}{fix_line}</td>')
189
+ t.append('</tr>')
190
+ t.append('</table>')
191
+ return "".join(t)
192
+
193
+
194
+ LEGEND_HTML = (
195
+ '<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
196
+ '<span style="background:#d4edda;padding:2px 8px;border-radius:4px;">Found (TP)</span>'
197
+ '<span style="background:#f8d7da;padding:2px 8px;border-radius:4px;">False Positive</span>'
198
+ '<span style="background:#fff3cd;padding:2px 8px;border-radius:4px;">Missed</span>'
199
+ '<span style="box-shadow:inset 0 0 0 2px #28a745;padding:2px 8px;border-radius:4px;">Fix Correct</span>'
200
+ '<span style="box-shadow:inset 0 0 0 2px #ffc107;padding:2px 8px;border-radius:4px;">Fix Partial</span>'
201
+ '</div>'
202
+ )
203
+
204
+
205
+ # ── Core replay logic ──
206
+
207
+ def _replay_task(task_id: str) -> list[dict]:
208
+ """Run the agent trajectory and collect per-step data."""
209
+ env = DataQAEnvironment()
210
+ obs = env.reset(task_id=task_id)
211
+ task = env._current_task
212
+ planted_keys = {i.to_key() for i in task.planted_issues}
213
+ steps_data = []
214
+
215
+ # Step 0: initial state
216
+ steps_data.append({
217
+ "label": "Initial — corrupted dataset",
218
+ "html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
219
+ "metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
220
+ "identify": 0.0, "fix": 0.0, "fixes_correct": 0},
221
+ "feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
222
+ })
223
+
224
+ trajectory = AGENT_TRAJECTORIES.get(task_id, [])
225
+ for i, step_data in enumerate(trajectory):
226
+ action = DataQAAction(
227
+ issues=step_data["issues"],
228
+ fixes=step_data.get("fixes", []),
229
+ task_id=task_id,
230
+ )
231
+ obs = env.step(action)
232
+
233
+ reported_keys = set()
234
+ for iss in step_data["issues"]:
235
+ key = parse_issue_key(iss)
236
+ if key:
237
+ reported_keys.add(key)
238
+
239
+ tp_keys = reported_keys & planted_keys
240
+ fp_keys = reported_keys - planted_keys
241
+ fn_keys = planted_keys - reported_keys
242
+
243
+ correct = {_kc(k) for k in tp_keys}
244
+ fp = {_kc(k) for k in fp_keys}
245
+ missed = {_kc(k) for k in fn_keys} if obs.done else set()
246
+
247
+ fixed: dict[tuple[int, str], str] = {}
248
+ for d in obs.metadata.get("fix_details", []):
249
+ c = (d["row"], d["col"])
250
+ fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
251
+
252
+ # Extract proposed fix values from the raw fix strings
253
+ fix_values: dict[tuple[int, str], str] = {}
254
+ from .environment import parse_fix
255
+ for raw_fix in step_data.get("fixes", []):
256
+ parsed = parse_fix(raw_fix)
257
+ if parsed:
258
+ row, col, val = parsed
259
+ fix_values[(row, col)] = val
260
+
261
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
262
+
263
+ has_fixes = bool(step_data.get("fixes"))
264
+ if has_fixes:
265
+ label = f"Step {i+1} — identify + fix"
266
+ else:
267
+ label = f"Step {i+1} — identify only"
268
+
269
+ steps_data.append({
270
+ "label": label,
271
+ "html": html,
272
+ "metrics": {
273
+ "reward": obs.reward,
274
+ "tp": obs.metadata["tp"],
275
+ "fp": obs.metadata["fp"],
276
+ "fn": obs.metadata["fn"],
277
+ "identify": obs.metadata["identify_score"],
278
+ "fix": obs.metadata["fix_score"],
279
+ "fixes_correct": obs.metadata["fixes_correct"],
280
+ },
281
+ "feedback": obs.feedback,
282
+ })
283
+
284
+ return steps_data
285
+
286
+
287
+ def _kc(key: str) -> tuple[int, str]:
288
+ parts = key.split(",")
289
+ return (int(parts[0].split(":")[1]), parts[1].split(":")[1])
290
+
291
+
292
+ # ── Gradio app ──
293
+
294
+ def build_gradio_ui():
295
+ # Pre-compute all replays at startup
296
+ all_replays: dict[str, list[dict]] = {}
297
+ for tid in list_tasks():
298
+ all_replays[tid] = _replay_task(tid)
299
+
300
+ def show_step(task_id: str, step_idx: int):
301
+ replay = all_replays.get(task_id, [])
302
+ step_idx = int(step_idx)
303
+ if step_idx >= len(replay):
304
+ step_idx = len(replay) - 1
305
+ sd = replay[step_idx]
306
+ m = sd["metrics"]
307
+
308
+ # Reward color
309
+ r = m["reward"]
310
+ rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
311
+
312
+ cards = (
313
+ '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
314
+ + _metric_card("Reward", f"{r:.2f}", rc)
315
+ + _metric_card("Found", str(m["tp"]), "#28a745")
316
+ + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
317
+ + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
318
+ + _metric_card("Identify", f"{m['identify']:.2f}", "#333")
319
+ + _metric_card("Fix", f"{m['fix']:.2f}", "#333")
320
+ + '</div>'
321
+ )
322
+
323
+ full_html = (
324
+ f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
325
+ f'{sd["label"]}</div>'
326
+ + cards + sd["html"] + LEGEND_HTML
327
+ )
328
+
329
+ return full_html, sd["feedback"]
330
+
331
+ def on_task_change(task_id):
332
+ replay = all_replays.get(task_id, [])
333
+ max_step = len(replay) - 1
334
+ html, fb = show_step(task_id, 0)
335
+ return (
336
+ gr.update(maximum=max_step, value=0),
337
+ html,
338
+ fb,
339
+ )
340
+
341
+ def on_step_change(task_id, step_idx):
342
+ html, fb = show_step(task_id, step_idx)
343
+ return html, fb
344
+
345
+ # ── Live agent runner (connects to the env server) ──
346
+
347
+ live_env = DataQAEnvironment()
348
+ live_state: dict = {"obs": None, "task_id": "easy", "steps": []}
349
+
350
+ def live_reset(task_id):
351
+ obs = live_env.reset(task_id=task_id)
352
+ task = live_env._current_task
353
+ live_state["obs"] = obs
354
+ live_state["task_id"] = task_id
355
+ live_state["steps"] = []
356
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
357
+ info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
358
+ return html, info, "", "0.000"
359
+
360
+ def live_step(issues_text, fixes_text):
361
+ if live_state["obs"] is None:
362
+ return "Reset first.", "", "", ""
363
+ obs = live_state["obs"]
364
+ task = live_env._current_task
365
+ planted_keys = {i.to_key() for i in task.planted_issues}
366
+
367
+ issues = [l.strip() for l in issues_text.strip().split("\n") if l.strip()]
368
+ fixes = [l.strip() for l in fixes_text.strip().split("\n") if l.strip()] if fixes_text.strip() else []
369
+
370
+ action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
371
+ obs = live_env.step(action)
372
+ live_state["obs"] = obs
373
+
374
+ reported_keys = set()
375
+ for iss in issues:
376
+ key = parse_issue_key(iss)
377
+ if key:
378
+ reported_keys.add(key)
379
+
380
+ tp_keys = reported_keys & planted_keys
381
+ fp_keys = reported_keys - planted_keys
382
+ fn_keys = planted_keys - reported_keys
383
+
384
+ correct = {_kc(k) for k in tp_keys}
385
+ fp_set = {_kc(k) for k in fp_keys}
386
+ missed = {_kc(k) for k in fn_keys} if obs.done else set()
387
+
388
+ fixed: dict[tuple[int, str], str] = {}
389
+ for d in obs.metadata.get("fix_details", []):
390
+ c = (d["row"], d["col"])
391
+ fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
392
+
393
+ from .environment import parse_fix
394
+ fix_values: dict[tuple[int, str], str] = {}
395
+ for raw in fixes:
396
+ parsed = parse_fix(raw)
397
+ if parsed:
398
+ fix_values[(parsed[0], parsed[1])] = parsed[2]
399
+
400
+ html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
401
+
402
+ m = obs.metadata
403
+ r = obs.reward
404
+ rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
405
+ cards = (
406
+ '<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
407
+ + _metric_card("Reward", f"{r:.2f}", rc)
408
+ + _metric_card("Found", str(m["tp"]), "#28a745")
409
+ + _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
410
+ + _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
411
+ + '</div>'
412
+ )
413
+ full_html = cards + html + LEGEND_HTML
414
+ return full_html, obs.feedback, f"{r:.3f}", ""
415
+
416
+ # ── Build the UI ──
417
+
418
+ with gr.Blocks(title="DataQA Environment") as demo:
419
+ gr.Markdown(
420
+ "# DataQA — Data Quality Assurance Environment\n"
421
+ "Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
422
+ )
423
+
424
+ with gr.Tabs():
425
+ # ── Tab 1: Demo replay ──
426
+ with gr.Tab("Demo (Baseline Agent)"):
427
+ gr.Markdown(
428
+ "*Replay of the baseline Qwen-72B agent. "
429
+ "Use the slider to step through the agent's trajectory.*"
430
+ )
431
+ with gr.Row():
432
+ task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
433
+ step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
434
+
435
+ viz_html = gr.HTML()
436
+ feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
437
+
438
+ task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
439
+ step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
440
+ demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
441
+
442
+ # ── Tab 2: Try your own agent ──
443
+ with gr.Tab("Try Your Own Agent"):
444
+ gr.Markdown(
445
+ "*Submit your own issues and fixes to see how the environment scores them. "
446
+ "This is the same environment the baseline agent talks to.*"
447
+ )
448
+ with gr.Row():
449
+ live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
450
+ live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
451
+
452
+ with gr.Row():
453
+ live_info = gr.Markdown()
454
+ live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
455
+
456
+ live_viz = gr.HTML()
457
+
458
+ with gr.Row():
459
+ live_issues = gr.Textbox(
460
+ label="Issues (one per line)",
461
+ placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
462
+ lines=5,
463
+ )
464
+ live_fixes = gr.Textbox(
465
+ label="Fixes (one per line, optional)",
466
+ placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
467
+ lines=5,
468
+ )
469
+
470
+ live_step_btn = gr.Button("Submit Step", variant="primary")
471
+ live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
472
+
473
+ live_reset_btn.click(
474
+ live_reset, inputs=[live_task_dd],
475
+ outputs=[live_viz, live_info, live_feedback, live_reward],
476
+ )
477
+ live_step_btn.click(
478
+ live_step, inputs=[live_issues, live_fixes],
479
+ outputs=[live_viz, live_feedback, live_reward, live_issues],
480
+ )
481
+
482
+ return demo
483
+
484
+
485
+ if __name__ == "__main__":
486
+ demo = build_gradio_ui()
487
+ demo.launch()
dataqa_env/server/tasks.py CHANGED
@@ -43,6 +43,28 @@ class Task:
43
  corrupted_csv: str = ""
44
  max_steps: int = 3
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def _csv_to_rows(csv_text: str) -> List[List[str]]:
48
  reader = csv.reader(io.StringIO(csv_text.strip()))
@@ -354,6 +376,32 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
354
  issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
355
  description="model_name is whitespace-only", difficulty=2.5))
356
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  corrupted = _rows_to_csv([header] + data)
358
 
359
  return Task(
 
43
  corrupted_csv: str = ""
44
  max_steps: int = 3
45
 
46
+ def get_clean_value(self, row: int, col: str) -> str | None:
47
+ """
48
+ Look up the original clean value for a given (row, col).
49
+ Row is 1-indexed (data row after header).
50
+ Returns None if row/col is out of bounds or column not found.
51
+ """
52
+ rows = _csv_to_rows(self.clean_csv)
53
+ if len(rows) < 2:
54
+ return None
55
+ header = [h.strip().lower() for h in rows[0]]
56
+ if col.lower() not in header:
57
+ return None
58
+ col_idx = header.index(col.lower())
59
+ data_row_idx = row # row is 1-indexed, rows[0] is header, so rows[row] is the data row
60
+ if data_row_idx < 1 or data_row_idx >= len(rows):
61
+ return None
62
+ return rows[data_row_idx][col_idx].strip()
63
+
64
+ def get_planted_issue_map(self) -> dict:
65
+ """Return dict mapping issue key -> PlantedIssue for quick lookups."""
66
+ return {issue.to_key(): issue for issue in self.planted_issues}
67
+
68
 
69
  def _csv_to_rows(csv_text: str) -> List[List[str]]:
70
  reader = csv.reader(io.StringIO(csv_text.strip()))
 
376
  issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
377
  description="model_name is whitespace-only", difficulty=2.5))
378
 
379
+ # Issue 9: Training time impossibly fast for dataset size and epochs
380
+ # EXP-004: vit-base on imagenet-1k, 300 epochs, but only 96 hours is plausible.
381
+ # Let's make EXP-009: efficientnet-b0 on imagenet-1k, 350 epochs = should take ~40+ hours
382
+ # but we set it to 0.5 hours — impossible for 1.2M images * 350 epochs
383
+ r = 8 # EXP-009 (same row as batch_size issue, different column)
384
+ data[r][13] = "0.5" # 30 minutes for 350 epochs on imagenet? impossible
385
+ issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="statistical_outlier",
386
+ description="0.5 hours for 350 epochs on imagenet-1k (1.2M images) is impossibly fast",
387
+ difficulty=3.0))
388
+
389
+ # Issue 10: test_accuracy of 95.1% for roberta-large on SST-2 with train_size=500
390
+ # is suspiciously high — SOTA is ~96% with full dataset (67k). With only 500 training
391
+ # samples, 95.1% accuracy suggests data contamination or evaluation bug
392
+ r = 9 # EXP-010 (same row as train_size issue, different column)
393
+ # train_size is already corrupted to 500, but the test_accuracy 95.1 is from the
394
+ # original full-dataset run — this cross-column inconsistency is the real issue
395
+ # We don't modify the value — the inconsistency emerges from the train_size corruption
396
+ # So let's use a different row. EXP-001: resnet50 on imagenet, accuracy 76.3 is fine.
397
+ # Instead: EXP-012 wav2vec2 on librispeech — set test_accuracy to 98.5 (way too high
398
+ # for a speech model with only 20 epochs, SOTA is ~96% with much more training)
399
+ r = 11 # EXP-012
400
+ data[r][11] = "98.5" # wav2vec2 with 20 epochs shouldn't hit 98.5% — SOTA is ~96%
401
+ issues.append(PlantedIssue(row=r + 1, col="test_accuracy", issue_type="statistical_outlier",
402
+ description="test_accuracy 98.5% for wav2vec2 with only 20 epochs exceeds known SOTA (~96%), likely evaluation error",
403
+ difficulty=3.0))
404
+
405
  corrupted = _rows_to_csv([header] + data)
406
 
407
  return Task(
inference.py CHANGED
@@ -1,8 +1,11 @@
1
  #!/usr/bin/env python3
2
  """
3
- DataQA Inference Script
4
- -----------------------
5
- LLM agent that plays the DataQA environment.
 
 
 
6
  Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
7
 
8
  Required environment variables:
@@ -92,10 +95,10 @@ class EnvHTTPClient:
92
  r.raise_for_status()
93
  return r.json()
94
 
95
- def step(self, issues: list[str], task_id: str = "easy") -> dict:
96
  r = self.session.post(
97
  f"{self.base_url}/step",
98
- json={"action": {"issues": issues, "task_id": task_id}},
99
  timeout=30,
100
  )
101
  r.raise_for_status()
@@ -103,10 +106,10 @@ class EnvHTTPClient:
103
 
104
 
105
  # ---------------------------------------------------------------------------
106
- # LLM Agent
107
  # ---------------------------------------------------------------------------
108
 
109
- SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
110
 
111
  You will be given:
112
  1. A dataset in CSV format
@@ -141,7 +144,26 @@ Respond with ONLY the list of issues, one per line. No other text.
141
  Example: row:3,col:salary,issue:missing_value"""
142
 
143
 
144
- def build_user_prompt(observation: dict) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
145
  obs = observation if isinstance(observation, dict) else observation
146
  parts = []
147
 
@@ -160,6 +182,12 @@ def build_user_prompt(observation: dict) -> str:
160
  if feedback and "reset" not in feedback.lower():
161
  parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
162
 
 
 
 
 
 
 
163
  return "\n\n".join(parts)
164
 
165
 
@@ -170,7 +198,6 @@ def parse_llm_response(response: str) -> list[str]:
170
  line = line.strip()
171
  if not line:
172
  continue
173
- # Remove numbering like "1. " or "- " or "* "
174
  line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
175
  line = re.sub(r"^\s*[-*]\s*", "", line)
176
  line = line.strip()
@@ -186,8 +213,60 @@ def parse_llm_response(response: str) -> list[str]:
186
  return issues
187
 
188
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
  def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
190
- """Run a single task and return the best score."""
 
 
 
 
 
191
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
192
 
193
  rewards: List[float] = []
@@ -196,48 +275,38 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
196
  success = False
197
 
198
  try:
199
- # Reset environment for this task
200
  reset_response = env.reset(task_id=task_id)
201
  observation = reset_response.get("observation", reset_response)
202
 
203
- messages = [{"role": "system", "content": SYSTEM_PROMPT}]
 
204
 
205
  for step_num in range(1, MAX_STEPS_PER_TASK + 1):
206
- user_prompt = build_user_prompt(observation)
207
- messages_for_call = messages + [{"role": "user", "content": user_prompt}]
208
-
209
- # Call LLM with retry on rate limit
210
- llm_output = ""
211
  error_msg = None
212
- for attempt in range(3):
213
- try:
214
- response = client.chat.completions.create(
215
- model=MODEL_NAME,
216
- messages=messages_for_call,
217
- temperature=0.1,
218
- max_tokens=2048,
219
- )
220
- llm_output = response.choices[0].message.content or ""
221
- break
222
- except Exception as e:
223
- if "rate_limit" in str(e).lower() or "429" in str(e):
224
- wait = 10 * (attempt + 1)
225
- print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
226
- time.sleep(wait)
227
- else:
228
- error_msg = str(e)
229
- print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
230
- break
231
-
232
- # Parse issues from LLM response
233
- issues = parse_llm_response(llm_output)
234
- action_str = ";".join(issues) if issues else "none"
235
 
236
  if not issues and not error_msg:
237
  error_msg = "no issues parsed from LLM response"
238
 
239
- # Submit to environment
240
- step_response = env.step(issues, task_id=task_id)
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  observation = step_response.get("observation", step_response)
242
 
243
  reward = float(step_response.get("reward", 0.0) or 0.0)
@@ -257,9 +326,8 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
257
  if done:
258
  break
259
 
260
- # Add context for next attempt
261
- messages.append({"role": "user", "content": user_prompt})
262
- messages.append({"role": "assistant", "content": llm_output})
263
 
264
  success = best_score >= 0.5
265
 
@@ -279,21 +347,18 @@ def main():
279
  print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
280
  print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
281
 
282
- # Initialize clients
283
  env = EnvHTTPClient(ENV_URL)
284
  llm_client = OpenAI(
285
  base_url=API_BASE_URL,
286
  api_key=API_KEY or "no-key",
287
  )
288
 
289
- # Check environment health
290
  if not env.health():
291
  print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
292
  sys.exit(1)
293
 
294
  print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
295
 
296
- # Run all tasks
297
  scores = {}
298
  for task_id in TASKS:
299
  try:
@@ -303,7 +368,6 @@ def main():
303
  print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
304
  scores[task_id] = 0.0
305
 
306
- # Summary to stderr (stdout is reserved for structured logs only)
307
  avg_score = sum(scores.values()) / len(scores) if scores else 0.0
308
  print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
309
 
 
1
  #!/usr/bin/env python3
2
  """
3
+ DataQA Inference Script — Two-Phase Agent
4
+ ------------------------------------------
5
+ LLM agent that plays the DataQA environment in two phases:
6
+ Phase 1: Identify all data quality issues
7
+ Phase 2: Propose fixes for identified issues
8
+
9
  Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
10
 
11
  Required environment variables:
 
95
  r.raise_for_status()
96
  return r.json()
97
 
98
+ def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
99
  r = self.session.post(
100
  f"{self.base_url}/step",
101
+ json={"action": {"issues": issues, "fixes": fixes, "task_id": task_id}},
102
  timeout=30,
103
  )
104
  r.raise_for_status()
 
106
 
107
 
108
  # ---------------------------------------------------------------------------
109
+ # LLM Prompts
110
  # ---------------------------------------------------------------------------
111
 
112
+ IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
113
 
114
  You will be given:
115
  1. A dataset in CSV format
 
144
  Example: row:3,col:salary,issue:missing_value"""
145
 
146
 
147
+ FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
148
+
149
+ For each issue you identified, propose a fix in EXACTLY this format:
150
+ row:<row_number>,col:<column_name>,fix:<corrected_value>
151
+
152
+ Guidelines for proposing fixes:
153
+ - For missing_value: infer the correct value from context, schema, and other rows
154
+ - For wrong_type: convert to the correct type (e.g., "seventy-five thousand" → "75000")
155
+ - For out_of_range: propose a value within the valid range that makes sense in context
156
+ - For format_violation: correct the format (e.g., "26/01/2024" → "2024-01-26")
157
+ - For inconsistent_value: compute the correct value from related fields
158
+ - For duplicate_row: propose a corrected unique key or indicate removal
159
+ - For statistical_outlier: propose a reasonable value given the model/context
160
+
161
+ Use the schema, validation rules, and surrounding data to determine the correct fix.
162
+ Respond with ONLY the list of fixes, one per line. No other text.
163
+ Example: row:3,col:salary,fix:75000"""
164
+
165
+
166
+ def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
167
  obs = observation if isinstance(observation, dict) else observation
168
  parts = []
169
 
 
182
  if feedback and "reset" not in feedback.lower():
183
  parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
184
 
185
+ if include_fixes:
186
+ parts.append(
187
+ "Now propose fixes for ALL issues. "
188
+ "Use format: row:<N>,col:<name>,fix:<corrected_value>"
189
+ )
190
+
191
  return "\n\n".join(parts)
192
 
193
 
 
198
  line = line.strip()
199
  if not line:
200
  continue
 
201
  line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
202
  line = re.sub(r"^\s*[-*]\s*", "", line)
203
  line = line.strip()
 
213
  return issues
214
 
215
 
216
+ def parse_fix_response(response: str) -> list[str]:
217
+ """Extract fix lines from LLM response."""
218
+ fixes = []
219
+ for line in response.strip().split("\n"):
220
+ line = line.strip()
221
+ if not line:
222
+ continue
223
+ line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
224
+ line = re.sub(r"^\s*[-*]\s*", "", line)
225
+ line = line.strip()
226
+ if "row" in line.lower() and "fix" in line.lower():
227
+ match = re.search(
228
+ r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
229
+ line,
230
+ re.IGNORECASE,
231
+ )
232
+ if match:
233
+ normalized = f"row:{match.group(1)},col:{match.group(2).lower()},fix:{match.group(3).strip()}"
234
+ fixes.append(normalized)
235
+ return fixes
236
+
237
+
238
+ def call_llm(client: OpenAI, system_prompt: str, user_prompt: str) -> str:
239
+ """Call the LLM with retry on rate limit."""
240
+ for attempt in range(3):
241
+ try:
242
+ response = client.chat.completions.create(
243
+ model=MODEL_NAME,
244
+ messages=[
245
+ {"role": "system", "content": system_prompt},
246
+ {"role": "user", "content": user_prompt},
247
+ ],
248
+ temperature=0.1,
249
+ max_tokens=2048,
250
+ )
251
+ return response.choices[0].message.content or ""
252
+ except Exception as e:
253
+ if "rate_limit" in str(e).lower() or "429" in str(e):
254
+ wait = 10 * (attempt + 1)
255
+ print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
256
+ time.sleep(wait)
257
+ else:
258
+ print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
259
+ return ""
260
+ return ""
261
+
262
+
263
  def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
264
+ """
265
+ Run a single task with two-phase strategy:
266
+ Step 1: Identify issues only
267
+ Step 2: Identify + Fix (using feedback from step 1)
268
+ Step 3: Refined identify + fix (if needed)
269
+ """
270
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
271
 
272
  rewards: List[float] = []
 
275
  success = False
276
 
277
  try:
 
278
  reset_response = env.reset(task_id=task_id)
279
  observation = reset_response.get("observation", reset_response)
280
 
281
+ last_issues: list[str] = []
282
+ last_llm_output = ""
283
 
284
  for step_num in range(1, MAX_STEPS_PER_TASK + 1):
 
 
 
 
 
285
  error_msg = None
286
+
287
+ # ── Phase 1: Identify issues ──
288
+ user_prompt = build_user_prompt(observation)
289
+ identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, user_prompt)
290
+ issues = parse_llm_response(identify_output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
  if not issues and not error_msg:
293
  error_msg = "no issues parsed from LLM response"
294
 
295
+ # ── Phase 2: Propose fixes (from step 2 onward, or always if we have issues) ──
296
+ fixes: list[str] = []
297
+ if issues and step_num >= 2:
298
+ # Build a fix prompt that includes the identified issues
299
+ fix_prompt = build_user_prompt(observation, include_fixes=True)
300
+ fix_prompt += f"\n\nISSUES FOUND:\n" + "\n".join(issues)
301
+ fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
302
+ fixes = parse_fix_response(fix_output)
303
+
304
+ # ── Submit to environment ──
305
+ action_str = ";".join(issues[:5]) if issues else "none"
306
+ if fixes:
307
+ action_str += "|fixes:" + ";".join(fixes[:3])
308
+
309
+ step_response = env.step(issues, fixes, task_id=task_id)
310
  observation = step_response.get("observation", step_response)
311
 
312
  reward = float(step_response.get("reward", 0.0) or 0.0)
 
326
  if done:
327
  break
328
 
329
+ last_issues = issues
330
+ last_llm_output = identify_output
 
331
 
332
  success = best_score >= 0.5
333
 
 
347
  print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
348
  print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
349
 
 
350
  env = EnvHTTPClient(ENV_URL)
351
  llm_client = OpenAI(
352
  base_url=API_BASE_URL,
353
  api_key=API_KEY or "no-key",
354
  )
355
 
 
356
  if not env.health():
357
  print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
358
  sys.exit(1)
359
 
360
  print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
361
 
 
362
  scores = {}
363
  for task_id in TASKS:
364
  try:
 
368
  print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
369
  scores[task_id] = 0.0
370
 
 
371
  avg_score = sum(scores.values()) / len(scores) if scores else 0.0
372
  print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
373
 
tests/test_environment.py CHANGED
@@ -1,16 +1,24 @@
1
- """Tests for the DataQA environment (reset, step, scoring)."""
2
 
3
  import pytest
4
  from dataqa_env.server.environment import (
5
  DataQAEnvironment,
6
  parse_issue_key,
 
7
  compute_f1,
8
  compute_weighted_reward,
 
 
 
9
  )
10
  from dataqa_env.models import DataQAAction
11
- from dataqa_env.server.tasks import PlantedIssue
12
 
13
 
 
 
 
 
14
  class TestParseIssueKey:
15
  def test_standard_format(self):
16
  assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
@@ -28,7 +36,7 @@ class TestParseIssueKey:
28
  assert parse_issue_key("this is garbage") is None
29
 
30
  def test_partial_match(self):
31
- assert parse_issue_key("row:3,col:salary") is None # missing issue
32
 
33
  def test_empty_string(self):
34
  assert parse_issue_key("") is None
@@ -38,14 +46,49 @@ class TestParseIssueKey:
38
  assert result == "row:3,col:salary,issue:missing_value"
39
 
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class TestComputeF1:
42
  def test_perfect_match(self):
43
  keys = {"row:1,col:a,issue:missing_value"}
44
  result = compute_f1(keys, keys)
45
  assert result["f1"] == 1.0
46
- assert result["tp"] == 1
47
- assert result["fp"] == 0
48
- assert result["fn"] == 0
49
 
50
  def test_no_reported_no_planted(self):
51
  result = compute_f1(set(), set())
@@ -61,9 +104,6 @@ class TestComputeF1:
61
  reported = {"row:99,col:x,issue:wrong_type"}
62
  planted = {"row:1,col:a,issue:missing_value"}
63
  result = compute_f1(reported, planted)
64
- assert result["tp"] == 0
65
- assert result["fp"] == 1
66
- assert result["fn"] == 1
67
  assert result["f1"] == 0.0
68
 
69
  def test_partial_match(self):
@@ -83,6 +123,10 @@ class TestComputeF1:
83
  assert result["recall"] == pytest.approx(2 / 3)
84
 
85
 
 
 
 
 
86
  class TestComputeWeightedReward:
87
  def test_perfect_match(self):
88
  issues = [
@@ -101,14 +145,11 @@ class TestComputeWeightedReward:
101
  issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
102
  result = compute_weighted_reward(set(), issues)
103
  assert result["weighted_reward"] == 0.0
104
- assert result["difficulty_missed"] == 2.0
105
 
106
  def test_hard_issue_worth_more(self):
107
  easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
108
  hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
109
  issues = [easy, hard]
110
-
111
- # Finding only the hard issue should score higher than only the easy issue
112
  hard_found = compute_weighted_reward({hard.to_key()}, issues)
113
  easy_found = compute_weighted_reward({easy.to_key()}, issues)
114
  assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
@@ -122,6 +163,92 @@ class TestComputeWeightedReward:
122
  assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  class TestDataQAEnvironment:
126
  @pytest.fixture
127
  def env(self):
@@ -137,6 +264,7 @@ class TestDataQAEnvironment:
137
  assert obs.max_steps == 3
138
  assert obs.done is False
139
  assert obs.reward == 0.0
 
140
 
141
  def test_reset_medium(self, env):
142
  obs = env.reset(task_id="medium")
@@ -144,11 +272,11 @@ class TestDataQAEnvironment:
144
 
145
  def test_reset_hard(self, env):
146
  obs = env.reset(task_id="hard")
147
- assert obs.num_issues_hint == 8
148
 
149
- def test_step_with_correct_issues(self, env):
 
150
  env.reset(task_id="easy")
151
- # Submit all correct issues for easy task
152
  action = DataQAAction(
153
  issues=[
154
  "row:4,col:name,issue:missing_value",
@@ -160,7 +288,46 @@ class TestDataQAEnvironment:
160
  )
161
  obs = env.step(action)
162
  assert obs.done is True
163
- assert obs.reward >= 0.999
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
  def test_step_with_partial_issues(self, env):
166
  env.reset(task_id="easy")
@@ -186,7 +353,6 @@ class TestDataQAEnvironment:
186
  assert obs.done is True
187
 
188
  def test_auto_reset_on_step(self, env):
189
- # step() without prior reset should auto-reset
190
  action = DataQAAction(
191
  issues=["row:4,col:name,issue:missing_value"],
192
  task_id="easy",
@@ -214,19 +380,26 @@ class TestDataQAEnvironment:
214
  env.step(action1)
215
  score_after_1 = env.state.best_score
216
 
217
- # Worse submission shouldn't decrease best_score
218
  action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
219
  env.step(action2)
220
  assert env.state.best_score >= score_after_1
221
 
222
- def test_metadata_included_in_observation(self, env):
223
  env.reset(task_id="easy")
224
- action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
 
 
 
 
225
  obs = env.step(action)
226
- assert "f1" in obs.metadata
227
- assert "weighted_reward" in obs.metadata
228
- assert "tp" in obs.metadata
229
- assert "difficulty_found" in obs.metadata
 
 
 
 
230
 
231
  def test_parse_error_in_feedback(self, env):
232
  env.reset(task_id="easy")
@@ -243,7 +416,55 @@ class TestDataQAEnvironment:
243
  for _ in range(3):
244
  action = DataQAAction(
245
  issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
 
246
  task_id="hard",
247
  )
248
  obs = env.step(action)
249
  assert 0.0 <= obs.reward <= 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tests for the DataQA environment (reset, step, scoring, two-phase identify+fix)."""
2
 
3
  import pytest
4
  from dataqa_env.server.environment import (
5
  DataQAEnvironment,
6
  parse_issue_key,
7
+ parse_fix,
8
  compute_f1,
9
  compute_weighted_reward,
10
+ grade_fixes,
11
+ IDENTIFY_WEIGHT,
12
+ FIX_WEIGHT,
13
  )
14
  from dataqa_env.models import DataQAAction
15
+ from dataqa_env.server.tasks import PlantedIssue, create_task_easy, create_task_medium
16
 
17
 
18
+ # ──────────────────────────────────────────────────────
19
+ # Issue parsing
20
+ # ──────────────────────────────────────────────────────
21
+
22
  class TestParseIssueKey:
23
  def test_standard_format(self):
24
  assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
 
36
  assert parse_issue_key("this is garbage") is None
37
 
38
  def test_partial_match(self):
39
+ assert parse_issue_key("row:3,col:salary") is None
40
 
41
  def test_empty_string(self):
42
  assert parse_issue_key("") is None
 
46
  assert result == "row:3,col:salary,issue:missing_value"
47
 
48
 
49
+ # ──────────────────────────────────────────────────────
50
+ # Fix parsing
51
+ # ──────────────────────────────────────────────────────
52
+
53
+ class TestParseFix:
54
+ def test_standard_format(self):
55
+ result = parse_fix("row:4,col:name,fix:Alice Chen")
56
+ assert result == (4, "name", "Alice Chen")
57
+
58
+ def test_with_equals(self):
59
+ result = parse_fix("row=4,col=name,fix=Alice Chen")
60
+ assert result == (4, "name", "Alice Chen")
61
+
62
+ def test_numeric_fix(self):
63
+ result = parse_fix("row:7,col:salary,fix:75000")
64
+ assert result == (7, "salary", "75000")
65
+
66
+ def test_date_fix(self):
67
+ result = parse_fix("row:12,col:order_date,fix:2024-01-26")
68
+ assert result == (12, "order_date", "2024-01-26")
69
+
70
+ def test_case_insensitive(self):
71
+ result = parse_fix("Row:4,Col:Name,Fix:Alice Chen")
72
+ assert result == (4, "name", "Alice Chen")
73
+
74
+ def test_unparseable(self):
75
+ assert parse_fix("garbage") is None
76
+ assert parse_fix("row:4,col:name") is None
77
+
78
+ def test_fix_with_special_chars(self):
79
+ result = parse_fix("row:1,col:email,fix:alice.chen@company.com")
80
+ assert result == (1, "email", "alice.chen@company.com")
81
+
82
+
83
+ # ──────────────────────────────────────────────────────
84
+ # F1 scoring
85
+ # ──────────────────────────────────────────────────────
86
+
87
  class TestComputeF1:
88
  def test_perfect_match(self):
89
  keys = {"row:1,col:a,issue:missing_value"}
90
  result = compute_f1(keys, keys)
91
  assert result["f1"] == 1.0
 
 
 
92
 
93
  def test_no_reported_no_planted(self):
94
  result = compute_f1(set(), set())
 
104
  reported = {"row:99,col:x,issue:wrong_type"}
105
  planted = {"row:1,col:a,issue:missing_value"}
106
  result = compute_f1(reported, planted)
 
 
 
107
  assert result["f1"] == 0.0
108
 
109
  def test_partial_match(self):
 
123
  assert result["recall"] == pytest.approx(2 / 3)
124
 
125
 
126
+ # ──────────────────────────────────────────────────────
127
+ # Weighted reward
128
+ # ──────────────────────────────────────────────────────
129
+
130
  class TestComputeWeightedReward:
131
  def test_perfect_match(self):
132
  issues = [
 
145
  issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
146
  result = compute_weighted_reward(set(), issues)
147
  assert result["weighted_reward"] == 0.0
 
148
 
149
  def test_hard_issue_worth_more(self):
150
  easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
151
  hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
152
  issues = [easy, hard]
 
 
153
  hard_found = compute_weighted_reward({hard.to_key()}, issues)
154
  easy_found = compute_weighted_reward({easy.to_key()}, issues)
155
  assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
 
163
  assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
164
 
165
 
166
+ # ──────────────────────────────────────────────────────
167
+ # Fix grading
168
+ # ──────────────────────────────────────────────────────
169
+
170
+ class TestGradeFixes:
171
+ @pytest.fixture
172
+ def easy_task(self):
173
+ return create_task_easy()
174
+
175
+ def test_no_fixes_no_issues(self):
176
+ from dataqa_env.server.tasks import Task
177
+ task = Task(task_id="empty", name="", description="", schema_description="",
178
+ validation_rules="", clean_csv="a\n1")
179
+ result = grade_fixes([], task)
180
+ assert result["fix_score"] == 1.0
181
+
182
+ def test_no_fixes_submitted(self, easy_task):
183
+ result = grade_fixes([], easy_task)
184
+ assert result["fix_score"] == 0.0
185
+ assert result["fixes_attempted"] == 0
186
+
187
+ def test_exact_fix_for_missing_name(self, easy_task):
188
+ # Row 4 has empty name — clean value is "David Kim"
189
+ fixes = [(4, "name", "David Kim")]
190
+ result = grade_fixes(fixes, easy_task)
191
+ assert result["fix_score"] > 0.0
192
+ assert result["fixes_correct"] == 1
193
+
194
+ def test_exact_fix_for_wrong_type_salary(self, easy_task):
195
+ # Row 7 has "seventy-five thousand" — clean value is "75000"
196
+ fixes = [(7, "salary", "75000")]
197
+ result = grade_fixes(fixes, easy_task)
198
+ assert result["fixes_correct"] == 1
199
+
200
+ def test_numeric_close_match(self, easy_task):
201
+ # Row 9 has salary "5000" — clean value is "73000"
202
+ # Propose 73100 (within 1% of 73000)
203
+ fixes = [(9, "salary", "73100")]
204
+ result = grade_fixes(fixes, easy_task)
205
+ assert result["fixes_partial"] == 1
206
+
207
+ def test_wrong_value_for_issue_cell(self, easy_task):
208
+ # Row 4 name is empty — propose wrong name
209
+ fixes = [(4, "name", "Wrong Person")]
210
+ result = grade_fixes(fixes, easy_task)
211
+ assert result["fixes_partial"] == 1 # correct cell, wrong value
212
+ assert result["fix_score"] > 0.0 # gets partial credit
213
+
214
+ def test_fix_for_non_issue_cell(self, easy_task):
215
+ # Row 1 col name is fine — no issue there
216
+ fixes = [(1, "name", "Some Name")]
217
+ result = grade_fixes(fixes, easy_task)
218
+ assert result["fixes_wrong"] == 1
219
+ assert result["fix_score"] == 0.0
220
+
221
+ def test_multiple_fixes_best_wins(self, easy_task):
222
+ # Submit two fixes for same cell — best one should count
223
+ fixes = [
224
+ (4, "name", "Wrong Person"), # partial credit
225
+ (4, "name", "David Kim"), # exact match
226
+ ]
227
+ result = grade_fixes(fixes, easy_task)
228
+ assert result["fixes_correct"] >= 1
229
+
230
+ def test_all_fixes_correct(self, easy_task):
231
+ # Fix all 4 issues with exact values
232
+ fixes = [
233
+ (4, "name", "David Kim"),
234
+ (7, "salary", "75000"),
235
+ (9, "salary", "73000"),
236
+ # Row 11 is duplicate — clean value for employee_id is "Bob Martinez" row
237
+ # The duplicate is of row 2 (Bob Martinez), so the clean row 11 doesn't exist
238
+ ]
239
+ result = grade_fixes(fixes, easy_task)
240
+ assert result["fix_score"] > 0.5 # at least 3/4 issues fixed
241
+
242
+ def test_fix_score_bounded(self, easy_task):
243
+ fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
244
+ result = grade_fixes(fixes, easy_task)
245
+ assert 0.0 <= result["fix_score"] <= 1.0
246
+
247
+
248
+ # ──────────────────────────────────────────────────────
249
+ # Full environment lifecycle
250
+ # ──────────────────────────────────────────────────────
251
+
252
  class TestDataQAEnvironment:
253
  @pytest.fixture
254
  def env(self):
 
264
  assert obs.max_steps == 3
265
  assert obs.done is False
266
  assert obs.reward == 0.0
267
+ assert "fix" in obs.feedback.lower() # mentions fix phase
268
 
269
  def test_reset_medium(self, env):
270
  obs = env.reset(task_id="medium")
 
272
 
273
  def test_reset_hard(self, env):
274
  obs = env.reset(task_id="hard")
275
+ assert obs.num_issues_hint == 10
276
 
277
+ def test_step_identify_only(self, env):
278
+ """Backward compatible: only issues, no fixes."""
279
  env.reset(task_id="easy")
 
280
  action = DataQAAction(
281
  issues=[
282
  "row:4,col:name,issue:missing_value",
 
288
  )
289
  obs = env.step(action)
290
  assert obs.done is True
291
+ assert obs.reward >= 0.999 # identify-only uses identify_score directly
292
+
293
+ def test_step_with_fixes_increases_reward(self, env):
294
+ """Submitting correct fixes should increase reward beyond identify-only."""
295
+ env.reset(task_id="easy")
296
+ # Step 1: identify only
297
+ action1 = DataQAAction(
298
+ issues=[
299
+ "row:4,col:name,issue:missing_value",
300
+ "row:7,col:salary,issue:wrong_type",
301
+ "row:11,col:employee_id,issue:duplicate_row",
302
+ "row:9,col:salary,issue:out_of_range",
303
+ ],
304
+ task_id="easy",
305
+ )
306
+ obs1 = env.step(action1)
307
+ score_identify = obs1.reward
308
+
309
+ # Reset for fair comparison
310
+ env.reset(task_id="easy")
311
+ # Step with identify + fixes
312
+ action2 = DataQAAction(
313
+ issues=[
314
+ "row:4,col:name,issue:missing_value",
315
+ "row:7,col:salary,issue:wrong_type",
316
+ "row:11,col:employee_id,issue:duplicate_row",
317
+ "row:9,col:salary,issue:out_of_range",
318
+ ],
319
+ fixes=[
320
+ "row:4,col:name,fix:David Kim",
321
+ "row:7,col:salary,fix:75000",
322
+ "row:9,col:salary,fix:73000",
323
+ ],
324
+ task_id="easy",
325
+ )
326
+ obs2 = env.step(action2)
327
+ score_with_fixes = obs2.metadata["combined_reward"]
328
+
329
+ # With correct fixes, combined should be close to 1.0
330
+ assert score_with_fixes > 0.8
331
 
332
  def test_step_with_partial_issues(self, env):
333
  env.reset(task_id="easy")
 
353
  assert obs.done is True
354
 
355
  def test_auto_reset_on_step(self, env):
 
356
  action = DataQAAction(
357
  issues=["row:4,col:name,issue:missing_value"],
358
  task_id="easy",
 
380
  env.step(action1)
381
  score_after_1 = env.state.best_score
382
 
 
383
  action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
384
  env.step(action2)
385
  assert env.state.best_score >= score_after_1
386
 
387
+ def test_metadata_includes_both_phases(self, env):
388
  env.reset(task_id="easy")
389
+ action = DataQAAction(
390
+ issues=["row:4,col:name,issue:missing_value"],
391
+ fixes=["row:4,col:name,fix:David Kim"],
392
+ task_id="easy",
393
+ )
394
  obs = env.step(action)
395
+ m = obs.metadata
396
+ assert "identify_f1" in m
397
+ assert "identify_score" in m
398
+ assert "fix_score" in m
399
+ assert "combined_reward" in m
400
+ assert "tp" in m
401
+ assert "fixes_correct" in m
402
+ assert "fixes_attempted" in m
403
 
404
  def test_parse_error_in_feedback(self, env):
405
  env.reset(task_id="easy")
 
416
  for _ in range(3):
417
  action = DataQAAction(
418
  issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
419
+ fixes=["row:1,col:x,fix:wrong"],
420
  task_id="hard",
421
  )
422
  obs = env.step(action)
423
  assert 0.0 <= obs.reward <= 1.0
424
+
425
+ def test_combined_reward_weights(self, env):
426
+ """Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
427
+ env.reset(task_id="easy")
428
+ action = DataQAAction(
429
+ issues=[
430
+ "row:4,col:name,issue:missing_value",
431
+ "row:7,col:salary,issue:wrong_type",
432
+ "row:11,col:employee_id,issue:duplicate_row",
433
+ "row:9,col:salary,issue:out_of_range",
434
+ ],
435
+ fixes=["row:4,col:name,fix:David Kim"],
436
+ task_id="easy",
437
+ )
438
+ obs = env.step(action)
439
+ m = obs.metadata
440
+ expected = IDENTIFY_WEIGHT * m["identify_score"] + FIX_WEIGHT * m["fix_score"]
441
+ assert abs(m["combined_reward"] - expected) < 0.01
442
+
443
+ def test_fix_feedback_shown_when_fixes_submitted(self, env):
444
+ env.reset(task_id="easy")
445
+ action = DataQAAction(
446
+ issues=["row:4,col:name,issue:missing_value"],
447
+ fixes=["row:4,col:name,fix:David Kim"],
448
+ task_id="easy",
449
+ )
450
+ obs = env.step(action)
451
+ assert "Fix Proposals" in obs.feedback
452
+ assert "Combined Reward" in obs.feedback
453
+
454
+ def test_no_fix_penalty_when_no_fixes_submitted(self, env):
455
+ """If agent submits no fixes, reward = identify_score (no penalty)."""
456
+ env.reset(task_id="easy")
457
+ action = DataQAAction(
458
+ issues=[
459
+ "row:4,col:name,issue:missing_value",
460
+ "row:7,col:salary,issue:wrong_type",
461
+ "row:11,col:employee_id,issue:duplicate_row",
462
+ "row:9,col:salary,issue:out_of_range",
463
+ ],
464
+ task_id="easy",
465
+ )
466
+ obs = env.step(action)
467
+ # identify_score should be ~1.0 since all issues found
468
+ assert obs.reward >= 0.99
469
+ # combined_reward equals identify_score when no fixes
470
+ assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
tests/test_extensibility.py CHANGED
@@ -151,8 +151,8 @@ class TestRegisterTask:
151
 
152
 
153
  class TestCustomTaskInEnvironment:
154
- def test_full_lifecycle(self):
155
- """Custom task works end-to-end in the environment."""
156
  task = create_task_from_config(
157
  task_id="e2e_custom",
158
  name="E2E Custom",
@@ -171,7 +171,6 @@ class TestCustomTaskInEnvironment:
171
  obs = env.reset(task_id="e2e_custom")
172
  assert obs.num_issues_hint == 2
173
 
174
- # Submit correct answers
175
  action = DataQAAction(
176
  issues=[i.to_key() for i in task.planted_issues],
177
  task_id="e2e_custom",
@@ -180,6 +179,37 @@ class TestCustomTaskInEnvironment:
180
  assert obs.done is True
181
  assert obs.reward >= 0.999
182
 
183
- # Cleanup
184
  from dataqa_env.server.tasks import TASK_REGISTRY
185
  del TASK_REGISTRY["e2e_custom"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
 
153
  class TestCustomTaskInEnvironment:
154
+ def test_full_lifecycle_identify_only(self):
155
+ """Custom task works end-to-end with identify-only."""
156
  task = create_task_from_config(
157
  task_id="e2e_custom",
158
  name="E2E Custom",
 
171
  obs = env.reset(task_id="e2e_custom")
172
  assert obs.num_issues_hint == 2
173
 
 
174
  action = DataQAAction(
175
  issues=[i.to_key() for i in task.planted_issues],
176
  task_id="e2e_custom",
 
179
  assert obs.done is True
180
  assert obs.reward >= 0.999
181
 
 
182
  from dataqa_env.server.tasks import TASK_REGISTRY
183
  del TASK_REGISTRY["e2e_custom"]
184
+
185
+ def test_full_lifecycle_identify_and_fix(self):
186
+ """Custom task works end-to-end with both identify and fix."""
187
+ task = create_task_from_config(
188
+ task_id="e2e_fix",
189
+ name="E2E Fix",
190
+ description="End-to-end test with fixes",
191
+ schema_description="id: int, name: str, score: int",
192
+ validation_rules="No missing values",
193
+ clean_csv=SIMPLE_CSV,
194
+ contaminations=[
195
+ {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
196
+ ],
197
+ )
198
+ register_task("e2e_fix", lambda seed: task)
199
+
200
+ env = DataQAEnvironment()
201
+ env.reset(task_id="e2e_fix")
202
+
203
+ # Submit issues + fix
204
+ action = DataQAAction(
205
+ issues=[task.planted_issues[0].to_key()],
206
+ fixes=["row:1,col:name,fix:Alice"], # clean value is "Alice"
207
+ task_id="e2e_fix",
208
+ )
209
+ obs = env.step(action)
210
+ assert obs.done is True
211
+ assert obs.metadata["fix_score"] > 0.0
212
+ assert obs.metadata["combined_reward"] > 0.0
213
+
214
+ from dataqa_env.server.tasks import TASK_REGISTRY
215
+ del TASK_REGISTRY["e2e_fix"]
tests/test_inference.py CHANGED
@@ -1,12 +1,11 @@
1
  """Tests for the inference script's parsing, prompt building, and log format."""
2
 
3
- import re
4
  import pytest
5
  import sys
6
  import os
7
 
8
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
9
- from inference import parse_llm_response, build_user_prompt, log_start, log_step, log_end
10
 
11
 
12
  class TestParseLLMResponse:
@@ -50,7 +49,7 @@ class TestParseLLMResponse:
50
  def test_deduplication_not_applied(self):
51
  response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
52
  issues = parse_llm_response(response)
53
- assert len(issues) == 2 # duplicates kept, env handles dedup
54
 
55
  def test_with_column_variant(self):
56
  response = "row:1,column:name,issue:missing_value"
@@ -58,6 +57,38 @@ class TestParseLLMResponse:
58
  assert len(issues) == 1
59
 
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  class TestBuildUserPrompt:
62
  def test_includes_all_fields(self):
63
  obs = {
@@ -100,6 +131,18 @@ class TestBuildUserPrompt:
100
  prompt = build_user_prompt(obs)
101
  assert "FEEDBACK" not in prompt
102
 
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  class TestLogFormat:
105
  """Verify stdout log format matches hackathon evaluation requirements."""
 
1
  """Tests for the inference script's parsing, prompt building, and log format."""
2
 
 
3
  import pytest
4
  import sys
5
  import os
6
 
7
  sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
8
+ from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end
9
 
10
 
11
  class TestParseLLMResponse:
 
49
  def test_deduplication_not_applied(self):
50
  response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
51
  issues = parse_llm_response(response)
52
+ assert len(issues) == 2
53
 
54
  def test_with_column_variant(self):
55
  response = "row:1,column:name,issue:missing_value"
 
57
  assert len(issues) == 1
58
 
59
 
60
+ class TestParseFixResponse:
61
+ def test_standard_format(self):
62
+ response = "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
63
+ fixes = parse_fix_response(response)
64
+ assert len(fixes) == 2
65
+ assert "row:4,col:name,fix:David Kim" in fixes
66
+
67
+ def test_numbered_list(self):
68
+ response = "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
69
+ fixes = parse_fix_response(response)
70
+ assert len(fixes) == 2
71
+
72
+ def test_with_special_chars(self):
73
+ response = "row:1,col:email,fix:alice.chen@company.com"
74
+ fixes = parse_fix_response(response)
75
+ assert len(fixes) == 1
76
+ assert "alice.chen@company.com" in fixes[0]
77
+
78
+ def test_empty_response(self):
79
+ assert parse_fix_response("") == []
80
+
81
+ def test_date_fix(self):
82
+ response = "row:12,col:order_date,fix:2024-01-26"
83
+ fixes = parse_fix_response(response)
84
+ assert len(fixes) == 1
85
+
86
+ def test_ignores_issue_lines(self):
87
+ response = "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
88
+ fixes = parse_fix_response(response)
89
+ assert len(fixes) == 1 # only the fix line
90
+
91
+
92
  class TestBuildUserPrompt:
93
  def test_includes_all_fields(self):
94
  obs = {
 
131
  prompt = build_user_prompt(obs)
132
  assert "FEEDBACK" not in prompt
133
 
134
+ def test_include_fixes_flag(self):
135
+ obs = {
136
+ "task_description": "Find issues",
137
+ "schema_description": "",
138
+ "validation_rules": "",
139
+ "dataset_csv": "a\n1",
140
+ "num_issues_hint": 0,
141
+ "feedback": "",
142
+ }
143
+ prompt = build_user_prompt(obs, include_fixes=True)
144
+ assert "fix" in prompt.lower()
145
+
146
 
147
  class TestLogFormat:
148
  """Verify stdout log format matches hackathon evaluation requirements."""
tests/test_tasks.py CHANGED
@@ -113,8 +113,8 @@ class TestTaskHard:
113
  def test_task_id(self, task):
114
  assert task.task_id == "hard"
115
 
116
- def test_has_8_issues(self, task):
117
- assert len(task.planted_issues) == 8
118
 
119
  def test_issue_types(self, task):
120
  types = {i.issue_type for i in task.planted_issues}
 
113
  def test_task_id(self, task):
114
  assert task.task_id == "hard"
115
 
116
+ def test_has_10_issues(self, task):
117
+ assert len(task.planted_issues) == 10
118
 
119
  def test_issue_types(self, task):
120
  types = {i.issue_type for i in task.planted_issues}