avanigupta committed on
Commit cd11aba · 1 Parent(s): 4c1a85d

fixes v1: add per step reward
README.md CHANGED
@@ -1,6 +1,6 @@
  ---
  title: DataQA Environment Server
- emoji: 🔍
  colorFrom: blue
  colorTo: gray
  sdk: docker
@@ -15,49 +15,173 @@ tags:
 
  An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
 
- ## Overview
 
- DataQA simulates the real-world task of validating datasets before they enter ML training pipelines or production databases. The agent receives a corrupted dataset along with its schema and validation rules, then must identify all planted data quality issues.
 
- ### Why Data QA?
-
- Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, inconsistencies, and subtle statistical anomalies. This environment turns that task into a structured, gradable challenge.
 
  ## Environment API
 
- | Endpoint | Description |
- |----------|-------------|
- | `reset(task_id)` | Start a new episode with a corrupted dataset |
- | `step(issues)` | Submit identified issues, receive F1-scored feedback |
- | `state()` | Get current episode state |
 
  ## Tasks
 
- | Task | Issues | Difficulty | Description |
- |------|--------|------------|-------------|
- | `easy` | 4 | Beginner | Employee directory: nulls, wrong types, duplicates, out-of-range |
- | `medium` | 6 | Intermediate | E-commerce orders: format violations, inconsistent totals, duplicate keys |
- | `hard` | 8 | Advanced | ML experiment metadata: data leakage signals, unreasonable GPU usage, timestamp ordering |
 
  ## Reward Function
 
- Scoring uses **F1 score** (harmonic mean of precision and recall):
 
- - **Precision**: What fraction of reported issues are real?
- - **Recall**: What fraction of planted issues did the agent find?
- - **F1**: `2 * precision * recall / (precision + recall)`
 
- Issues are matched by `row:<N>,col:<column>,issue:<type>` keys.
 
- The agent gets up to 3 attempts per task, with feedback on each attempt (true positives, false positives, missed count).
 
- ## Action/Observation Space
 
- **Action**: List of issue strings in the format `row:<row_number>,col:<column_name>,issue:<issue_type>`
 
- **Observation**: Dataset CSV + schema + validation rules + feedback from the previous attempt
 
- **Issue Types**: `missing_value`, `wrong_type`, `duplicate_row`, `out_of_range`, `format_violation`, `inconsistent_value`, `statistical_outlier`, `referential_integrity`
 
  ## Quick Start
 
@@ -68,42 +192,75 @@ pip install -e .
  # Run server locally
  uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
 
- # Run inference
- API_BASE_URL=https://api.groq.com/openai/v1 \
- MODEL_NAME=llama-3.3-70b-versatile \
- LLM_API_KEY=your-key \
  python inference.py
  ```
 
  ## Docker
 
  ```bash
- docker build -t dataqa-env -f dataqa_env/server/Dockerfile .
  docker run -p 8000:8000 dataqa-env
  ```
 
  ## Environment Variables
 
  | Variable | Description | Default |
  |----------|-------------|---------|
- | `API_BASE_URL` | LLM API endpoint | `https://api.groq.com/openai/v1` |
- | `MODEL_NAME` | Model identifier | `llama-3.3-70b-versatile` |
- | `HF_TOKEN` | HuggingFace token | - |
  | `ENV_URL` | Environment server URL | `http://localhost:8000` |
- | `LLM_API_KEY` | API key for LLM provider | Falls back to `HF_TOKEN` |
 
  ## Architecture
 
  ```
  dataqa_env/
  ├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
  ├── client.py # EnvClient for WebSocket connections
  ├── server/
- │   ├── environment.py # Core DataQAEnvironment (reset/step/state)
- │   ├── tasks.py # Task definitions + data corruption + grading
- │   ├── app.py # FastAPI server
  │   └── Dockerfile
- ├── openenv.yaml
- ├── pyproject.toml
- ├── inference.py # LLM agent using OpenAI client
  ```
 
  ---
  title: DataQA Environment Server
+ emoji: "\U0001F50D"
  colorFrom: blue
  colorTo: gray
  sdk: docker
 
  An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
 
+ ## Motivation
 
+ Every ML engineer and data scientist spends significant time debugging data quality issues (missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies) before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
 
+ DataQA turns this into a structured, gradable RL environment where agents must systematically inspect corrupted datasets, reason about schema constraints and validation rules, and pinpoint every planted issue — from obvious nulls to subtle data leakage signals that require domain expertise.
 
  ## Environment API
 
+ | Endpoint | Method | Description |
+ |----------|--------|-------------|
+ | `/reset` | POST | Start a new episode with a corrupted dataset |
+ | `/step` | POST | Submit identified issues, receive scored feedback |
+ | `/state` | GET | Get current episode state |
+ | `/health` | GET | Health check |
 
  ## Tasks
 
+ | Task | Issues | Difficulty | Domain | Description |
+ |------|--------|------------|--------|-------------|
+ | `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
+ | `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
+ | `hard` | 8 | Advanced | ML experiment metadata | Data leakage signals, unreasonable GPU memory, timestamp ordering, whitespace-only fields |
+
+ **Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set-membership checks. Hard issues require ML domain knowledge (val_loss < train_loss suggests data leakage) and multi-row temporal reasoning.
+
+ ## Action Space
+
+ The agent submits a list of issue strings, each in the format:
+ ```
+ row:<row_number>,col:<column_name>,issue:<issue_type>
+ ```
+
+ - `row_number`: 1-indexed position in the CSV data (after the header). Row 1 = first data row.
+ - `column_name`: Exact column header name, lowercase.
+ - `issue_type`: One of the supported types below.
+
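The format above can be sketched with two small helpers (illustrative only: `make_issue` and `parse_issue` are hypothetical names for this example, not part of the package API):

```python
# Hypothetical helpers for building and parsing issue strings in the
# documented format; not part of the dataqa_env API.
def make_issue(row: int, col: str, issue_type: str) -> str:
    return f"row:{row},col:{col},issue:{issue_type}"

def parse_issue(key: str) -> dict:
    # Split "row:4,col:name,issue:missing_value" into its three fields
    parts = dict(p.split(":", 1) for p in key.split(","))
    return {"row": int(parts["row"]), "col": parts["col"], "issue": parts["issue"]}

key = make_issue(4, "name", "missing_value")
print(key)                      # row:4,col:name,issue:missing_value
print(parse_issue(key)["row"])  # 4
```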
+ **Supported Issue Types:**
+
+ | Type | Description | Example |
+ |------|-------------|---------|
+ | `missing_value` | Null, empty, or whitespace-only | Empty name field |
+ | `wrong_type` | Value doesn't match the expected type | Salary as "seventy-five thousand" |
+ | `duplicate_row` | Exact duplicate or duplicate key | Two rows with the same employee_id |
+ | `out_of_range` | Value outside the valid range | Salary of 5000 when the minimum is 50000 |
+ | `format_violation` | Wrong format or invalid enum value | Date as DD/MM/YYYY instead of YYYY-MM-DD |
+ | `inconsistent_value` | Computed-field mismatch or logical inconsistency | total != qty * price |
+ | `statistical_outlier` | Unreasonable value given context | resnet18 using 42.5 GB of GPU memory |
+ | `referential_integrity` | Foreign-key violation | (available for custom tasks) |
+
+ ## Observation Space
+
+ Each observation contains:
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `dataset_csv` | str | The corrupted dataset in CSV format |
+ | `schema_description` | str | Column types, ranges, and constraints |
+ | `validation_rules` | str | Business rules the data must satisfy |
+ | `task_description` | str | Task context and instructions |
+ | `feedback` | str | Results from the previous step (TP/FP/FN counts, precision/recall) |
+ | `num_issues_hint` | int | Exact count of planted issues |
+ | `max_steps` | int | Maximum attempts allowed |
+ | `done` | bool | Whether the episode has terminated |
+ | `reward` | float | Best weighted reward so far (0.0-1.0) |
+
+ **Observation metadata** (available after each step):
+ - `f1`, `weighted_reward`, `precision`, `recall`
+ - `tp`, `fp`, `fn`
+ - `difficulty_found`, `difficulty_missed`
 
  ## Reward Function
 
+ ### Difficulty-Weighted Reward (Primary)
+
+ Each planted issue has a **difficulty weight** (1.0-3.0) reflecting how hard it is to detect. The primary reward is a **weighted F1 score** that provides meaningful per-step partial-progress signals:
+
+ | Weight | Category | Examples |
+ |--------|----------|----------|
+ | 1.0 | Easy | Missing values, obvious out-of-range, wrong type |
+ | 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
+ | 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only fields |
+
+ **Formula:**
+ - **Weighted recall** = (sum of difficulty weights of found issues) / (total difficulty weight)
+ - **Weighted precision** = (found weight) / (found weight + FP count * average difficulty)
+ - **Weighted F1** = harmonic mean of weighted precision and weighted recall
+
+ This means:
+ - Finding a hard issue (difficulty 3.0) increases reward 3x more than finding an easy one (1.0)
+ - False positives are penalized proportionally to the average issue difficulty
+ - The agent sees meaningful reward differences at every step, not just pass/fail
+
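The formula above can be sketched as follows (a simplified stand-in mirroring the stated definitions; the function name and signature are illustrative, not the package's actual implementation):

```python
# Sketch of the difficulty-weighted F1 described above. Inputs are the
# difficulty weights of found and missed issues plus a false-positive count.
def weighted_f1(found_difficulties, missed_difficulties, num_false_positives):
    total = sum(found_difficulties) + sum(missed_difficulties)
    if total == 0:
        return 0.0
    found = sum(found_difficulties)
    recall = found / total  # proportion of total difficulty captured
    # FP penalty is scaled by the average issue difficulty
    avg_difficulty = total / (len(found_difficulties) + len(missed_difficulties))
    denom = found + num_false_positives * avg_difficulty
    precision = found / denom if denom > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

# Finding the one hard issue (3.0) scores higher than the one easy issue (1.0):
print(round(weighted_f1([3.0], [1.0, 1.0], 0), 2))  # 0.75
print(round(weighted_f1([1.0], [3.0, 1.0], 0), 2))  # 0.33
```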
+ ### Standard F1 (also computed)
+
+ Available in observation metadata for comparison. Uses unweighted set matching.
+
+ ### Episode Boundaries
+
+ - Each task allows up to 3 steps (attempts)
+ - The episode ends when F1 >= 0.999 (perfect) or the step limit is reached
+ - The best score across all steps is the final reward (monotonically non-decreasing)
+ - Reward is always in [0.0, 1.0]
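The best-score bookkeeping above amounts to a running maximum (a sketch of the described behavior, not the package's code):

```python
# The reported reward is the best weighted score seen so far in the episode,
# so the reward sequence never decreases across steps.
def best_so_far(step_scores):
    best, history = 0.0, []
    for s in step_scores:
        best = max(best, s)
        history.append(best)
    return history

print(best_so_far([0.4, 0.7, 0.55]))  # [0.4, 0.7, 0.7]
```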
 
+ ## Baseline Scores
 
+ Baseline scores using Qwen2.5-72B-Instruct via the HuggingFace Router:
 
+ | Task | Expected Score Range | Description |
+ |------|----------------------|-------------|
+ | `easy` | 0.7 - 1.0 | Most LLMs find obvious issues reliably |
+ | `medium` | 0.5 - 0.8 | Cross-column reasoning is challenging |
+ | `hard` | 0.3 - 0.6 | Requires ML domain knowledge and subtle-pattern detection |
 
+ Scores vary by model capability. Frontier models (GPT-4, Claude) typically score higher on the hard task due to better domain reasoning.
 
+ ## Extensibility
 
+ DataQA supports custom tasks, contamination rules, and difficulty levels via a programmatic API.
 
+ ### Custom Contamination Rules
+
+ ```python
+ from dataqa_env import register_contamination_rule
+ from dataqa_env.server.tasks import PlantedIssue
+
+ def swap_digits(rows, header, col_idx, row_idx, rng):
+     val = rows[row_idx][col_idx]
+     corrupted = val[::-1]  # reverse the value, swapping the digit order
+     issue = PlantedIssue(
+         row=row_idx + 1, col=header[col_idx],
+         issue_type="format_violation",
+         description=f"Digits swapped in {header[col_idx]}",
+         difficulty=2.0,
+     )
+     return corrupted, issue
+
+ register_contamination_rule("swap_digits", swap_digits)
+ ```
+
+ ### Custom Tasks from Config
+
+ ```python
+ from dataqa_env import create_task_from_config, register_task
+
+ task = create_task_from_config(
+     task_id="custom",
+     name="Custom Validation",
+     description="Find quality issues in this dataset.",
+     schema_description="id: int, name: str, score: int (0-100)",
+     validation_rules="No missing values. Scores must be 0-100.",
+     clean_csv="id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92",
+     contaminations=[
+         {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
+         {"rule": "negative_value", "row": 2, "col": 2, "difficulty": 1.5},
+     ],
+ )
+ register_task("custom", lambda seed: task)
+ ```
+
+ ### Built-in Contamination Rules
+
+ | Rule | Effect | Default Difficulty |
+ |------|--------|--------------------|
+ | `missing_value` | Sets the field to an empty string | 1.0 |
+ | `whitespace_value` | Sets the field to a single space | 2.5 |
+ | `wrong_type_text` | Replaces the value with random text ("N/A", "null", etc.) | 1.0 |
+ | `negative_value` | Negates a numeric value | 1.0 |
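As an illustration, the documented effect of `negative_value` could be reproduced like this (a sketch of the behavior described in the table, not the package's actual rule implementation):

```python
# Illustrative re-implementation of the `negative_value` rule's documented
# effect: negate the numeric value at (row, col) in a list-of-rows table.
def negative_value(rows, row, col):
    rows[row][col] = str(-float(rows[row][col]))
    return rows

table = [["1", "Alice", "95"], ["2", "Bob", "87"]]
print(negative_value(table, 0, 2)[0][2])  # -95.0
```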
 
  ## Quick Start
 
  # Run server locally
  uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
 
+ # Run inference (set your API credentials)
+ API_BASE_URL=https://router.huggingface.co/v1 \
+ MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \
+ HF_TOKEN=your-token \
  python inference.py
  ```
 
  ## Docker
 
  ```bash
+ docker build -t dataqa-env .
  docker run -p 8000:8000 dataqa-env
  ```
 
+ ## Testing
+
+ ```bash
+ pip install -e ".[dev]"
+ pytest tests/ -v
+ ```
+
+ 89 tests covering:
+ - Task creation, corruption, and issue planting (difficulty weights, seed determinism)
+ - Issue-key parsing (standard, lenient, edge cases)
+ - F1 and difficulty-weighted reward computation
+ - Full environment reset/step lifecycle
+ - Inference-script parsing and prompt building
+ - **Structured log format** ([START], [STEP], [END] — exact field names and ordering)
+ - Score bounds (0.0-1.0), best-score monotonicity
+ - Extensibility API (custom rules, custom tasks, environment integration)
+
+ ## Validation
+
+ ```bash
+ # OpenEnv spec validation
+ openenv validate .
+
+ # Pre-submission validation (requires an HF Space URL)
+ ./prevalidation_script.sh https://your-space.hf.space
+ ```
+
  ## Environment Variables
 
  | Variable | Description | Default |
  |----------|-------------|---------|
+ | `API_BASE_URL` | LLM API endpoint | `https://router.huggingface.co/v1` |
+ | `MODEL_NAME` | Model identifier | `Qwen/Qwen2.5-72B-Instruct` |
+ | `HF_TOKEN` | HuggingFace token / API key | - |
  | `ENV_URL` | Environment server URL | `http://localhost:8000` |
 
  ## Architecture
 
  ```
  dataqa_env/
+ ├── __init__.py # Public API + extensibility exports
  ├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
  ├── client.py # EnvClient for WebSocket connections
  ├── server/
+ │   ├── environment.py # Core DataQAEnvironment (reset/step/state + weighted rewards)
+ │   ├── tasks.py # Task definitions + contamination rules + extensibility API
+ │   ├── app.py # FastAPI server (via openenv-core create_app)
  │   └── Dockerfile
+ tests/
+ ├── test_tasks.py # Task creation, corruption, difficulty weights
+ ├── test_environment.py # Environment lifecycle, scoring, metadata
+ ├── test_inference.py # LLM response parsing, prompt building, log format
+ └── test_extensibility.py # Custom rules, custom tasks, registration API
+ inference.py # Baseline LLM agent (OpenAI client, structured logs)
+ openenv.yaml # OpenEnv/HF Spaces spec
+ pyproject.toml # Package metadata and dependencies
+ Dockerfile # Production container
  ```
dataqa_env/__init__.py CHANGED
@@ -1,4 +1,19 @@
  from .client import DataQAEnv
  from .models import DataQAAction, DataQAObservation, DataQAState
 
- __all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
 
  from .client import DataQAEnv
  from .models import DataQAAction, DataQAObservation, DataQAState
+ from .server.tasks import (
+     create_task_from_config,
+     register_task,
+     register_contamination_rule,
+     CONTAMINATION_RULES,
+ )
 
+ __all__ = [
+     "DataQAEnv",
+     "DataQAAction",
+     "DataQAObservation",
+     "DataQAState",
+     "create_task_from_config",
+     "register_task",
+     "register_contamination_rule",
+     "CONTAMINATION_RULES",
+ ]
dataqa_env/server/environment.py CHANGED
@@ -58,6 +58,61 @@ def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
      return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
 
 
  class DataQAEnvironment(Environment):
      """
      Data Quality Assurance environment.
@@ -138,15 +193,21 @@ class DataQAEnvironment(Environment):
          else:
              parse_errors.append(f"Could not parse: '{raw_issue}'")
 
-         # Compute score
          metrics = compute_f1(reported_keys, self._planted_keys)
          score = metrics["f1"]
-         self._best_score = max(self._best_score, score)
          self._state.best_score = self._best_score
 
          # Check if done
          is_done = (
-             score >= 0.999  # Perfect score
              or self._state.current_step >= self._state.max_steps
          )
@@ -156,6 +217,7 @@
              f"Issues reported: {len(reported_keys)}",
              f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
              f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
          ]
 
          if parse_errors:
@@ -173,7 +235,7 @@
              )
              feedback_lines.append("You can submit again with an updated list of issues.")
          else:
-             feedback_lines.append(f"Task complete! Final best F1 score: {self._best_score:.3f}")
 
          return DataQAObservation(
              dataset_csv=self._current_task.corrupted_csv,
@@ -186,6 +248,17 @@
              max_steps=self._state.max_steps,
              done=is_done,
              reward=self._best_score,
          )
 
      @property
      return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
 
 
+ def compute_weighted_reward(
+     reported_keys: Set[str],
+     planted_issues: list,
+ ) -> dict:
+     """
+     Compute difficulty-weighted reward for a richer per-step signal.
+
+     Each planted issue has a difficulty weight (1.0-3.0). Finding harder issues
+     earns more reward. False positives incur a penalty scaled by average difficulty.
+
+     Returns a dict with weighted_reward (0.0-1.0) plus a per-issue breakdown.
+     """
+     if not planted_issues and not reported_keys:
+         return {"weighted_reward": 1.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
+
+     planted_by_key = {issue.to_key(): issue for issue in planted_issues}
+     planted_keys = set(planted_by_key.keys())
+
+     if not reported_keys:
+         total_weight = sum(i.difficulty for i in planted_issues)
+         return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": total_weight}
+
+     if not planted_keys:
+         return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
+
+     # Sum difficulty weights for found vs missed issues
+     found_keys = reported_keys & planted_keys
+     missed_keys = planted_keys - reported_keys
+     false_positive_count = len(reported_keys - planted_keys)
+
+     difficulty_found = sum(planted_by_key[k].difficulty for k in found_keys)
+     difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
+     total_weight = sum(i.difficulty for i in planted_issues)
+
+     # Weighted recall: proportion of total difficulty captured
+     weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
+
+     # Penalty for false positives (scaled by avg difficulty so the penalty is proportional)
+     avg_difficulty = total_weight / len(planted_issues)
+     fp_penalty_weight = false_positive_count * avg_difficulty
+     weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
+
+     # Weighted F1
+     if (weighted_precision + weighted_recall) > 0:
+         weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
+     else:
+         weighted_reward = 0.0
+
+     return {
+         "weighted_reward": round(weighted_reward, 4),
+         "difficulty_found": round(difficulty_found, 2),
+         "difficulty_missed": round(difficulty_missed, 2),
+     }
+
+
  class DataQAEnvironment(Environment):
      """
      Data Quality Assurance environment.
 
          else:
              parse_errors.append(f"Could not parse: '{raw_issue}'")
 
+         # Compute score (standard F1)
          metrics = compute_f1(reported_keys, self._planted_keys)
          score = metrics["f1"]
+
+         # Compute difficulty-weighted reward (richer per-step signal)
+         weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
+         weighted_reward = weighted["weighted_reward"]
+
+         # Use the weighted reward as the primary reward signal
+         self._best_score = max(self._best_score, weighted_reward)
          self._state.best_score = self._best_score
 
          # Check if done
          is_done = (
+             score >= 0.999  # Perfect score (all issues found exactly)
              or self._state.current_step >= self._state.max_steps
          )
 
              f"Issues reported: {len(reported_keys)}",
              f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
              f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
+             f"Weighted reward: {weighted_reward:.3f} (difficulty found: {weighted['difficulty_found']}, missed: {weighted['difficulty_missed']})",
          ]
 
          if parse_errors:
 
              )
              feedback_lines.append("You can submit again with an updated list of issues.")
          else:
+             feedback_lines.append(f"Task complete! Final best weighted reward: {self._best_score:.3f}")
 
          return DataQAObservation(
              dataset_csv=self._current_task.corrupted_csv,
 
              max_steps=self._state.max_steps,
              done=is_done,
              reward=self._best_score,
+             metadata={
+                 "f1": score,
+                 "weighted_reward": weighted_reward,
+                 "precision": metrics["precision"],
+                 "recall": metrics["recall"],
+                 "tp": metrics["tp"],
+                 "fp": metrics["fp"],
+                 "fn": metrics["fn"],
+                 "difficulty_found": weighted["difficulty_found"],
+                 "difficulty_missed": weighted["difficulty_missed"],
+             },
          )
 
      @property
dataqa_env/server/tasks.py CHANGED
@@ -25,6 +25,7 @@ class PlantedIssue:
      col: str
      issue_type: str
      description: str
 
      def to_key(self) -> str:
          return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
@@ -93,29 +94,29 @@ def create_task_easy(seed: int = 42) -> Task:
      data = rows[1:]
      issues: List[PlantedIssue] = []
 
-     # Issue 1: Missing value - null out a name
      r = 3  # row index in data (0-based), displayed as row 4 in CSV
      data[r][1] = ""
      issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
-                                description="Empty name field"))
 
-     # Issue 2: Wrong type - salary as text
      r = 6
      data[r][4] = "seventy-five thousand"
      issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
-                                description="Salary is text instead of integer"))
 
-     # Issue 3: Duplicate row
      dup_source = 1
      data.append(list(data[dup_source]))
      issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
-                                description=f"Exact duplicate of row {dup_source + 1}"))
 
-     # Issue 4: Out of range salary
      r = 8
      data[r][4] = "5000"
      issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
-                                description="Salary 5000 is below minimum 50000"))
 
      corrupted = _rows_to_csv([header] + data)
 
@@ -190,41 +191,41 @@ ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.
      data = rows[1:]
      issues: List[PlantedIssue] = []
 
-     # Issue 1: total doesn't match quantity * unit_price
      r = 4  # ORD-005
      data[r][9] = "84.00"  # should be 42.00 (qty=1, price=42.00)
      issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
-                                description="total (84.00) != quantity (1) * unit_price (42.00)"))
 
-     # Issue 2: Invalid category
      r = 9  # ORD-010
      data[r][3] = "Fitness"  # should be Sports
      issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
-                                description="'Fitness' is not in allowed categories"))
 
-     # Issue 3: Missing value in product_name
      r = 13  # ORD-014
      data[r][2] = ""
      issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
-                                description="Empty product_name"))
 
-     # Issue 4: Out of range quantity
      r = 16  # ORD-017
      data[r][4] = "-1"
      issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
-                                description="Negative quantity"))
 
-     # Issue 5: Duplicate order_id
      r = 18  # ORD-019
      data[r][0] = "ORD-003"
      issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
-                                description="Duplicate order_id ORD-003"))
 
-     # Issue 6: Wrong date format
      r = 11  # ORD-012
      data[r][6] = "26/01/2024"
      issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
-                                description="Date format DD/MM/YYYY instead of YYYY-MM-DD"))
 
      corrupted = _rows_to_csv([header] + data)
 
@@ -301,53 +302,57 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
      data = rows[1:]
      issues: List[PlantedIssue] = []
 
-     # Issue 1: Data leakage signal — val_loss much lower than train_loss
      r = 4  # EXP-005
      data[r][10] = "0.15"  # val_loss=0.15 but train_loss=0.28 → suspicious
      issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
-                                description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage"))
 
-     # Issue 2: Batch size not a power of 2
      r = 8  # EXP-009
      data[r][7] = "250"  # not a power of 2
      issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
-                                description="batch_size 250 is not a power of 2"))
 
-     # Issue 3: GPU memory unreasonable for the model
      r = 6  # EXP-007 resnet18 on cifar10
      data[r][12] = "42.5"  # resnet18 shouldn't need 42.5 GB
      issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
-                                description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable"))
 
-     # Issue 4: Timestamp out of order
      r = 10  # EXP-011
      data[r][14] = "2024-03-02T09:00:00"  # should be after EXP-010's timestamp
      issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
-                                description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11"))
 
-     # Issue 5: Train size smaller than test size
      r = 9  # EXP-010
      data[r][3] = "500"  # train_size=500 but test_size=1821
      issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
-                                description="train_size (500) is smaller than test_size (1821)"))
 
-     # Issue 6: Negative training time
      r = 13  # EXP-014
      data[r][13] = "-72.0"
      issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
-                                description="Negative training time"))
 
-     # Issue 7: Learning rate out of range
      r = 12  # EXP-013
      data[r][6] = "2.5"  # way too high
      issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
-                                description="Learning rate 2.5 exceeds maximum of 1.0"))
 
-     # Issue 8: Missing model name (subtle: single space instead of empty)
      r = 14  # EXP-015
      data[r][1] = " "
      issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
-                                description="model_name is whitespace-only"))
 
      corrupted = _rows_to_csv([header] + data)
 
@@ -370,6 +375,123 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
      )
 
 
  # ---------------------------------------------------------------------------
  # Task registry
  # ---------------------------------------------------------------------------
 
25
  col: str
26
  issue_type: str
27
  description: str
28
+ difficulty: float = 1.0 # 1.0=easy, 2.0=medium, 3.0=hard (for weighted reward)
29
 
30
  def to_key(self) -> str:
31
  return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
 
94
  data = rows[1:]
95
  issues: List[PlantedIssue] = []
96
 
97
+ # Issue 1: Missing value - null out a name (easy to spot)
98
  r = 3 # row index in data (0-based), displayed as row 4 in CSV
99
  data[r][1] = ""
100
  issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
101
+ description="Empty name field", difficulty=1.0))
102
 
103
+ # Issue 2: Wrong type - salary as text (easy to spot)
104
  r = 6
105
  data[r][4] = "seventy-five thousand"
106
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
107
+ description="Salary is text instead of integer", difficulty=1.0))
108
 
109
+ # Issue 3: Duplicate row (moderate — requires cross-row comparison)
110
  dup_source = 1
111
  data.append(list(data[dup_source]))
112
  issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
113
+ description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
114
 
115
+ # Issue 4: Out of range salary (easy to spot)
116
  r = 8
117
  data[r][4] = "5000"
118
  issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
119
+ description="Salary 5000 is below minimum 50000", difficulty=1.0))
120
 
121
  corrupted = _rows_to_csv([header] + data)
122
 
 
   data = rows[1:]
   issues: List[PlantedIssue] = []
 
+  # Issue 1: total doesn't match quantity * unit_price (requires cross-column check)
   r = 4  # ORD-005
   data[r][9] = "84.00"  # should be 42.00 (qty=1, price=42.00)
   issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
+                             description="total (84.00) != quantity (1) * unit_price (42.00)", difficulty=2.0))
 
+  # Issue 2: Invalid category (requires knowing the allowed set)
   r = 9  # ORD-010
   data[r][3] = "Fitness"  # should be Sports
   issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
+                             description="'Fitness' is not in allowed categories", difficulty=1.5))
 
+  # Issue 3: Missing value in product_name (easy to spot)
   r = 13  # ORD-014
   data[r][2] = ""
   issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
+                             description="Empty product_name", difficulty=1.0))
 
+  # Issue 4: Out of range quantity (easy to spot)
   r = 16  # ORD-017
   data[r][4] = "-1"
   issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
+                             description="Negative quantity", difficulty=1.0))
 
+  # Issue 5: Duplicate order_id (requires cross-row comparison)
   r = 18  # ORD-019
   data[r][0] = "ORD-003"
   issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
+                             description="Duplicate order_id ORD-003", difficulty=1.5))
 
+  # Issue 6: Wrong date format (moderate — format mismatch)
   r = 11  # ORD-012
   data[r][6] = "26/01/2024"
   issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
+                             description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
 
   corrupted = _rows_to_csv([header] + data)
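The inconsistent-total corruption planted above is the kind of cross-column check an agent has to run to earn the reward. A minimal, illustrative sketch; the column names follow the order schema used here, but `find_inconsistent_totals` and the `SAMPLE` data are hypothetical, not part of the environment:

```python
import csv
import io

SAMPLE = """order_id,quantity,unit_price,total
ORD-001,2,10.00,20.00
ORD-005,1,42.00,84.00
"""

def find_inconsistent_totals(csv_text: str) -> list[int]:
    """Return 1-based data-row numbers where total != quantity * unit_price."""
    reader = csv.DictReader(io.StringIO(csv_text))
    bad = []
    for i, row in enumerate(reader, start=1):  # row 1 = first data row
        expected = float(row["quantity"]) * float(row["unit_price"])
        if abs(float(row["total"]) - expected) > 1e-9:
            bad.append(i)
    return bad

print(find_inconsistent_totals(SAMPLE))  # [2] -> the ORD-005 row is flagged
```

Note the 1-based data-row convention, which matches the `row=r + 1` bookkeeping in the planting code.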
 
 
   data = rows[1:]
   issues: List[PlantedIssue] = []
 
+  # Issue 1: Data leakage signal — val_loss much lower than train_loss (hard — requires ML knowledge)
   r = 4  # EXP-005
   data[r][10] = "0.15"  # val_loss=0.15 but train_loss=0.28 → suspicious
   issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
+                             description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage",
+                             difficulty=3.0))
 
+  # Issue 2: Batch size not power of 2 (moderate — domain convention)
   r = 8  # EXP-009
   data[r][7] = "250"  # not a power of 2
   issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
+                             description="batch_size 250 is not a power of 2", difficulty=2.0))
 
+  # Issue 3: GPU memory unreasonable for model (hard — requires model size reasoning)
   r = 6  # EXP-007 resnet18 on cifar10
   data[r][12] = "42.5"  # resnet18 shouldn't need 42.5 GB
   issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
+                             description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable",
+                             difficulty=3.0))
 
+  # Issue 4: Timestamp out of order (moderate — requires sequential comparison)
   r = 10  # EXP-011
   data[r][14] = "2024-03-02T09:00:00"  # should be after EXP-010's timestamp
   issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
+                             description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11",
+                             difficulty=2.0))
 
+  # Issue 5: Train size smaller than test size (moderate — cross-column logic)
   r = 9  # EXP-010
   data[r][3] = "500"  # train_size=500 but test_size=1821
   issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
+                             description="train_size (500) is smaller than test_size (1821)",
+                             difficulty=2.0))
 
+  # Issue 6: Negative training time (easy to spot)
   r = 13  # EXP-014
   data[r][13] = "-72.0"
   issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
+                             description="Negative training time", difficulty=1.0))
 
+  # Issue 7: Learning rate out of range (easy to spot)
   r = 12  # EXP-013
   data[r][6] = "2.5"  # way too high
   issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
+                             description="Learning rate 2.5 exceeds maximum of 1.0", difficulty=1.5))
 
+  # Issue 8: Missing model name (hard — whitespace-only is subtle)
   r = 14  # EXP-015
   data[r][1] = " "
   issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
+                             description="model_name is whitespace-only", difficulty=2.5))
 
   corrupted = _rows_to_csv([header] + data)
 
 
   )
 
 
+# ---------------------------------------------------------------------------
+# Contamination rules for extensible task creation
+# ---------------------------------------------------------------------------
+
+# Each contamination rule is a callable: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
+# Users can define their own and register them.
+
+CONTAMINATION_RULES = {
+    "missing_value": lambda rows, header, col_idx, row_idx, rng: (
+        "",
+        PlantedIssue(
+            row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
+            description=f"Empty {header[col_idx]} field", difficulty=1.0,
+        ),
+    ),
+    "whitespace_value": lambda rows, header, col_idx, row_idx, rng: (
+        " ",
+        PlantedIssue(
+            row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
+            description=f"Whitespace-only {header[col_idx]} field", difficulty=2.5,
+        ),
+    ),
+    "wrong_type_text": lambda rows, header, col_idx, row_idx, rng: (
+        rng.choice(["not-a-number", "N/A", "null", "undefined"]),
+        PlantedIssue(
+            row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
+            description=f"{header[col_idx]} is text instead of expected type", difficulty=1.0,
+        ),
+    ),
+    "negative_value": lambda rows, header, col_idx, row_idx, rng: (
+        str(-abs(float(rows[row_idx][col_idx]) if rows[row_idx][col_idx] else 1)),
+        PlantedIssue(
+            row=row_idx + 1, col=header[col_idx], issue_type="out_of_range",
+            description=f"Negative {header[col_idx]}", difficulty=1.0,
+        ),
+    ),
+}
+
+
+def create_task_from_config(
+    task_id: str,
+    name: str,
+    description: str,
+    schema_description: str,
+    validation_rules: str,
+    clean_csv: str,
+    contaminations: List[dict],
+    max_steps: int = 3,
+    seed: int = 42,
+) -> Task:
+    """
+    Create a custom task from a configuration dict.
+
+    Each contamination entry should have:
+      - rule: str (key in CONTAMINATION_RULES) or callable
+      - row: int (0-based row index in data)
+      - col: int (column index in header)
+      - difficulty: float (optional, overrides rule default)
+
+    Example:
+        contaminations = [
+            {"rule": "missing_value", "row": 2, "col": 1, "difficulty": 1.5},
+            {"rule": "negative_value", "row": 5, "col": 4},
+        ]
+    """
+    rng = random.Random(seed)
+    rows = _csv_to_rows(clean_csv)
+    header = rows[0]
+    data = rows[1:]
+    issues: List[PlantedIssue] = []
+
+    for spec in contaminations:
+        rule = spec["rule"]
+        row_idx = spec["row"]
+        col_idx = spec["col"]
+
+        if callable(rule):
+            new_val, issue = rule(data, header, col_idx, row_idx, rng)
+        elif rule in CONTAMINATION_RULES:
+            new_val, issue = CONTAMINATION_RULES[rule](data, header, col_idx, row_idx, rng)
+        else:
+            raise ValueError(f"Unknown contamination rule: {rule}. Available: {list(CONTAMINATION_RULES.keys())}")
+
+        data[row_idx][col_idx] = new_val
+        if "difficulty" in spec:
+            issue.difficulty = spec["difficulty"]
+        issues.append(issue)
+
+    corrupted = _rows_to_csv([header] + data)
+
+    return Task(
+        task_id=task_id,
+        name=name,
+        description=description,
+        schema_description=schema_description,
+        validation_rules=validation_rules,
+        clean_csv=clean_csv,
+        planted_issues=issues,
+        corrupted_csv=corrupted,
+        max_steps=max_steps,
+    )
+
+
+def register_task(task_id: str, factory_fn):
+    """Register a custom task factory. Factory must accept (seed: int) -> Task."""
+    TASK_REGISTRY[task_id] = factory_fn
+
+
+def register_contamination_rule(name: str, rule_fn):
+    """
+    Register a custom contamination rule.
+
+    rule_fn signature: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
+    """
+    CONTAMINATION_RULES[name] = rule_fn
+
495
  # ---------------------------------------------------------------------------
496
  # Task registry
497
  # ---------------------------------------------------------------------------
inference.py CHANGED
@@ -6,21 +6,23 @@ LLM agent that plays the DataQA environment.
 Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
 
 Required environment variables:
-    API_BASE_URL - LLM API endpoint (e.g., https://api.groq.com/openai/v1)
-    MODEL_NAME   - Model identifier (e.g., llama-3.3-70b-versatile)
-    HF_TOKEN     - HuggingFace token (for HF Spaces access)
-
-Structured logging format: [START], [STEP], [END] tags for evaluation.
+    API_BASE_URL - LLM API endpoint (e.g., https://router.huggingface.co/v1)
+    MODEL_NAME   - Model identifier (e.g., Qwen/Qwen2.5-72B-Instruct)
+    HF_TOKEN     - HuggingFace token / API key
+
+STDOUT FORMAT (mandatory for evaluation):
+    [START] task=<task_name> env=<benchmark> model=<model_name>
+    [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+    [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
 """
 
 from __future__ import annotations
 
-import json
 import os
 import re
 import sys
 import time
+from typing import List, Optional
 
 import requests
 from openai import OpenAI
@@ -28,52 +30,43 @@ from openai import OpenAI
 # ---------------------------------------------------------------------------
 # Configuration
 # ---------------------------------------------------------------------------
-API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
-MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
-HF_TOKEN = os.environ.get("HF_TOKEN", "")
-ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
+API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
+API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
+ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
 
+BENCHMARK = "dataqa_env"
 TASKS = ["easy", "medium", "hard"]
 MAX_STEPS_PER_TASK = 3
 
 
 # ---------------------------------------------------------------------------
-# Logging helpers (structured stdout for evaluation)
+# Logging helpers (structured stdout, exact format required by evaluation)
 # ---------------------------------------------------------------------------
 
-def log_start(task_id: str, metadata: Optional[dict] = None):
-    entry = {"event": "START", "task_id": task_id, "timestamp": time.time()}
-    if metadata:
-        entry["metadata"] = metadata
-    print(f"[START] {json.dumps(entry)}", flush=True)
-
-
-def log_step(task_id: str, step: int, reward: float, details: Optional[dict] = None):
-    entry = {
-        "event": "STEP",
-        "task_id": task_id,
-        "step": step,
-        "reward": reward,
-        "timestamp": time.time(),
-    }
-    if details:
-        entry["details"] = details
-    print(f"[STEP] {json.dumps(entry)}", flush=True)
-
-
-def log_end(task_id: str, final_score: float, metadata: Optional[dict] = None):
-    entry = {
-        "event": "END",
-        "task_id": task_id,
-        "final_score": final_score,
-        "timestamp": time.time(),
-    }
-    if metadata:
-        entry["metadata"] = metadata
-    print(f"[END] {json.dumps(entry)}", flush=True)
+def log_start(task: str, env: str, model: str) -> None:
+    print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+    error_val = error if error else "null"
+    done_val = str(done).lower()
+    print(
+        f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
+        flush=True,
+    )
+
+
+def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+    print(
+        f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
+        flush=True,
+    )
 
 
 # ---------------------------------------------------------------------------
-# Environment HTTP client (simple, no WebSocket needed for inference)
+# Environment HTTP client
 # ---------------------------------------------------------------------------
 
 class EnvHTTPClient:
@@ -108,11 +101,6 @@ class EnvHTTPClient:
         r.raise_for_status()
         return r.json()
 
-    def state(self) -> dict:
-        r = self.session.get(f"{self.base_url}/state", timeout=10)
-        r.raise_for_status()
-        return r.json()
-
 
 # ---------------------------------------------------------------------------
 # LLM Agent
@@ -142,7 +130,6 @@ CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
 - Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
 - Row 1 = the FIRST data row after the header
 - Row 2 = the SECOND data row after the header
-- For example, if the CSV has header on line 1 and data starting on line 2, the data on line 2 is row 1, line 3 is row 2, etc.
 - DO NOT use the employee_id, order_id, or experiment_id as the row number
 - Column names must match exactly (use the CSV header names, lowercase)
 - Check EVERY row and EVERY column systematically
@@ -188,14 +175,12 @@ def parse_llm_response(response: str) -> list[str]:
         line = re.sub(r"^\s*[-*]\s*", "", line)
         line = line.strip()
         if "row" in line.lower() and "col" in line.lower():
-            # Lenient regex: accept : or = as delimiters, case-insensitive
             match = re.search(
                 r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
                 line,
                 re.IGNORECASE,
             )
             if match:
-                # Normalize to lowercase canonical format
                 normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
                 issues.append(normalized)
     return issues
@@ -203,68 +188,84 @@ def parse_llm_response(response: str) -> list[str]:
 
 def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
     """Run a single task and return the best score."""
-    log_start(task_id)
-
-    # Reset environment for this task
-    reset_response = env.reset(task_id=task_id)
-    observation = reset_response.get("observation", reset_response)
-
-    best_score = 0.0
-    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
-
-    for step_num in range(1, MAX_STEPS_PER_TASK + 1):
-        user_prompt = build_user_prompt(observation)
-        messages_for_call = messages + [{"role": "user", "content": user_prompt}]
-
-        # Call LLM with retry on rate limit
-        llm_output = ""
-        for attempt in range(3):
-            try:
-                response = client.chat.completions.create(
-                    model=MODEL_NAME,
-                    messages=messages_for_call,
-                    temperature=0.1,
-                    max_tokens=2048,
-                )
-                llm_output = response.choices[0].message.content or ""
-                break
-            except Exception as e:
-                if "rate_limit" in str(e).lower() or "429" in str(e):
-                    wait = 10 * (attempt + 1)
-                    print(f"[WARN] Rate limited, waiting {wait}s...", flush=True)
-                    time.sleep(wait)
-                else:
-                    print(f"[ERROR] LLM call failed: {e}", file=sys.stderr, flush=True)
-                    break
-
-        # Parse issues from LLM response
-        issues = parse_llm_response(llm_output)
-
-        if not issues:
-            print(f"[WARN] No issues parsed from LLM response for {task_id} step {step_num}", file=sys.stderr, flush=True)
-
-        # Submit to environment
-        step_response = env.step(issues, task_id=task_id)
-        observation = step_response.get("observation", step_response)
-
-        # reward and done are at the top level of the response, not inside observation
-        reward = float(step_response.get("reward", 0.0) or 0.0)
-        done = bool(step_response.get("done", False))
-        best_score = max(best_score, reward)
-
-        log_step(task_id, step_num, reward, {
-            "issues_reported": len(issues),
-            "feedback": observation.get("feedback", ""),
-        })
-
-        if done:
-            break
-
-        # Add context for next attempt
-        messages.append({"role": "user", "content": user_prompt})
-        messages.append({"role": "assistant", "content": llm_output})
-
-    log_end(task_id, best_score)
+    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
+
+    rewards: List[float] = []
+    steps_taken = 0
+    best_score = 0.0
+    success = False
+
+    try:
+        # Reset environment for this task
+        reset_response = env.reset(task_id=task_id)
+        observation = reset_response.get("observation", reset_response)
+
+        messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+
+        for step_num in range(1, MAX_STEPS_PER_TASK + 1):
+            user_prompt = build_user_prompt(observation)
+            messages_for_call = messages + [{"role": "user", "content": user_prompt}]
+
+            # Call LLM with retry on rate limit
+            llm_output = ""
+            error_msg = None
+            for attempt in range(3):
+                try:
+                    response = client.chat.completions.create(
+                        model=MODEL_NAME,
+                        messages=messages_for_call,
+                        temperature=0.1,
+                        max_tokens=2048,
+                    )
+                    llm_output = response.choices[0].message.content or ""
+                    break
+                except Exception as e:
+                    if "rate_limit" in str(e).lower() or "429" in str(e):
+                        wait = 10 * (attempt + 1)
+                        print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
+                        time.sleep(wait)
+                    else:
+                        error_msg = str(e)
+                        print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
+                        break
+
+            # Parse issues from LLM response
+            issues = parse_llm_response(llm_output)
+            action_str = ";".join(issues) if issues else "none"
+
+            if not issues and not error_msg:
+                error_msg = "no issues parsed from LLM response"
+
+            # Submit to environment
+            step_response = env.step(issues, task_id=task_id)
+            observation = step_response.get("observation", step_response)
+
+            # reward and done are at the top level of the response, not inside observation
+            reward = float(step_response.get("reward", 0.0) or 0.0)
+            done = bool(step_response.get("done", False))
+            best_score = max(best_score, reward)
+            rewards.append(reward)
+            steps_taken = step_num
+
+            log_step(
+                step=step_num,
+                action=action_str,
+                reward=reward,
+                done=done,
+                error=error_msg,
+            )
+
+            if done:
+                break
+
+            # Add context for next attempt
+            messages.append({"role": "user", "content": user_prompt})
+            messages.append({"role": "assistant", "content": llm_output})
+
+        success = best_score >= 0.5
+
+    finally:
+        log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
+
     return best_score
 
 
@@ -273,49 +274,38 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
 # ---------------------------------------------------------------------------
 
 def main():
-    print(f"[INFO] DataQA Inference starting", flush=True)
-    print(f"[INFO] ENV_URL={ENV_URL}", flush=True)
-    print(f"[INFO] API_BASE_URL={API_BASE_URL}", flush=True)
-    print(f"[INFO] MODEL_NAME={MODEL_NAME}", flush=True)
+    print(f"[DEBUG] DataQA Inference starting", file=sys.stderr, flush=True)
+    print(f"[DEBUG] ENV_URL={ENV_URL}", file=sys.stderr, flush=True)
+    print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
+    print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
 
     # Initialize clients
     env = EnvHTTPClient(ENV_URL)
     llm_client = OpenAI(
         base_url=API_BASE_URL,
-        api_key=os.environ.get("LLM_API_KEY", HF_TOKEN or "no-key"),
+        api_key=API_KEY or "no-key",
    )
 
     # Check environment health
     if not env.health():
-        print("[ERROR] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
+        print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
         sys.exit(1)
 
-    print(f"[INFO] Environment is healthy", flush=True)
+    print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
 
     # Run all tasks
     scores = {}
     for task_id in TASKS:
-        print(f"\n{'='*60}", flush=True)
-        print(f"[INFO] Starting task: {task_id}", flush=True)
-        print(f"{'='*60}", flush=True)
-
         try:
             score = run_task(llm_client, env, task_id)
             scores[task_id] = score
-            print(f"[INFO] Task {task_id} completed with score: {score:.3f}", flush=True)
         except Exception as e:
-            print(f"[ERROR] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
+            print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
             scores[task_id] = 0.0
 
-    # Summary
-    print(f"\n{'='*60}", flush=True)
-    print("[INFO] FINAL RESULTS", flush=True)
-    print(f"{'='*60}", flush=True)
-    for task_id, score in scores.items():
-        print(f"[INFO] {task_id}: {score:.3f}", flush=True)
-
+    # Summary to stderr (stdout is reserved for structured logs only)
     avg_score = sum(scores.values()) / len(scores) if scores else 0.0
-    print(f"[INFO] Average score: {avg_score:.3f}", flush=True)
+    print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
 
 
 if __name__ == "__main__":
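The lenient row/col/issue pattern used by `parse_llm_response` in the diff above can be exercised on its own. This standalone demo reuses the same regex and the same lowercase normalization:

```python
import re

# Same lenient pattern as parse_llm_response: accepts ':' or '=' as delimiters,
# "col" or "column", and is case-insensitive.
PATTERN = re.compile(
    r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
    re.IGNORECASE,
)

for line in [
    "Row: 3, Col: salary, Issue: wrong_type",
    "row=12; column=total; issue=Inconsistent_Value",
]:
    m = PATTERN.search(line)
    # Normalize to the canonical lowercase form the environment expects.
    print(f"row:{m.group(1)},col:{m.group(2).lower()},issue:{m.group(3).lower()}")
# row:3,col:salary,issue:wrong_type
# row:12,col:total,issue:inconsistent_value
```

Both delimiter styles collapse to the same canonical string, which is what keeps the parser robust to small formatting drift in LLM output.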
prevalidation_script.sh ADDED
@@ -0,0 +1,185 @@
+#!/usr/bin/env bash
+#
+# validate-submission.sh — OpenEnv Submission Validator
+#
+# Checks that your HF Space is live, your Docker image builds, and openenv validate passes.
+#
+# Prerequisites:
+#   - Docker: https://docs.docker.com/get-docker/
+#   - openenv-core: pip install openenv-core
+#   - curl (usually pre-installed)
+#
+# Run:
+#   curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
+#
+# Or download and run locally:
+#   chmod +x validate-submission.sh
+#   ./validate-submission.sh <ping_url> [repo_dir]
+#
+# Arguments:
+#   ping_url   Your HuggingFace Space URL (e.g. https://your-space.hf.space)
+#   repo_dir   Path to your repo (default: current directory)
+#
+# Examples:
+#   ./validate-submission.sh https://my-team.hf.space
+#   ./validate-submission.sh https://my-team.hf.space ./my-repo
+#
+
+set -uo pipefail
+
+DOCKER_BUILD_TIMEOUT=600
+if [ -t 1 ]; then
+  RED='\033[0;31m'
+  GREEN='\033[0;32m'
+  YELLOW='\033[1;33m'
+  BOLD='\033[1m'
+  NC='\033[0m'
+else
+  RED='' GREEN='' YELLOW='' BOLD='' NC=''
+fi
+
+run_with_timeout() {
+  local secs="$1"; shift
+  if command -v timeout &>/dev/null; then
+    timeout "$secs" "$@"
+  elif command -v gtimeout &>/dev/null; then
+    gtimeout "$secs" "$@"
+  else
+    "$@" &
+    local pid=$!
+    ( sleep "$secs" && kill "$pid" 2>/dev/null ) &
+    local watcher=$!
+    wait "$pid" 2>/dev/null
+    local rc=$?
+    kill "$watcher" 2>/dev/null
+    wait "$watcher" 2>/dev/null
+    return $rc
+  fi
+}
+
+portable_mktemp() {
+  local prefix="${1:-validate}"
+  mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
+}
+
+CLEANUP_FILES=()
+cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
+trap cleanup EXIT
+
+PING_URL="${1:-}"
+REPO_DIR="${2:-.}"
+
+if [ -z "$PING_URL" ]; then
+  printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
+  printf "\n"
+  printf "  ping_url   Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
+  printf "  repo_dir   Path to your repo (default: current directory)\n"
+  exit 1
+fi
+
+if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
+  printf "Error: directory '%s' not found\n" "${2:-.}"
+  exit 1
+fi
+PING_URL="${PING_URL%/}"
+export PING_URL
+PASS=0
+
+log()  { printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"; }
+pass() { log "${GREEN}PASSED${NC} -- $1"; PASS=$((PASS + 1)); }
+fail() { log "${RED}FAILED${NC} -- $1"; }
+hint() { printf "       ${YELLOW}Hint:${NC} %b\n" "$1"; }
+stop_at() {
+  printf "\n"
+  printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
+  exit 1
+}
+
+printf "\n"
+printf "${BOLD}========================================${NC}\n"
+printf "${BOLD}  OpenEnv Submission Validator${NC}\n"
+printf "${BOLD}========================================${NC}\n"
+log "Repo:     $REPO_DIR"
+log "Ping URL: $PING_URL"
+printf "\n"
+
+log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
+
+CURL_OUTPUT=$(portable_mktemp "validate-curl")
+CLEANUP_FILES+=("$CURL_OUTPUT")
+HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
+  -H "Content-Type: application/json" -d '{}' \
+  "$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
+
+if [ "$HTTP_CODE" = "200" ]; then
+  pass "HF Space is live and responds to /reset"
+elif [ "$HTTP_CODE" = "000" ]; then
+  fail "HF Space not reachable (connection failed or timed out)"
+  hint "Check your network connection and that the Space is running."
+  hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
+  stop_at "Step 1"
+else
+  fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
+  hint "Make sure your Space is running and the URL is correct."
+  hint "Try opening $PING_URL in your browser first."
+  stop_at "Step 1"
+fi
+
+log "${BOLD}Step 2/3: Running docker build${NC} ..."
+
+if ! command -v docker &>/dev/null; then
+  fail "docker command not found"
+  hint "Install Docker: https://docs.docker.com/get-docker/"
+  stop_at "Step 2"
+fi
+
+if [ -f "$REPO_DIR/Dockerfile" ]; then
+  DOCKER_CONTEXT="$REPO_DIR"
+elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
+  DOCKER_CONTEXT="$REPO_DIR/server"
+else
+  fail "No Dockerfile found in repo root or server/ directory"
+  stop_at "Step 2"
+fi
+
+log "  Found Dockerfile in $DOCKER_CONTEXT"
+
+BUILD_OK=false
+BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
+
+if [ "$BUILD_OK" = true ]; then
+  pass "Docker build succeeded"
+else
+  fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
+  printf "%s\n" "$BUILD_OUTPUT" | tail -20
+  stop_at "Step 2"
+fi
+
+log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
+
+if ! command -v openenv &>/dev/null; then
+  fail "openenv command not found"
+  hint "Install it: pip install openenv-core"
+  stop_at "Step 3"
+fi
+
+VALIDATE_OK=false
+VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
+
+if [ "$VALIDATE_OK" = true ]; then
+  pass "openenv validate passed"
+  [ -n "$VALIDATE_OUTPUT" ] && log "  $VALIDATE_OUTPUT"
+else
+  fail "openenv validate failed"
+  printf "%s\n" "$VALIDATE_OUTPUT"
+  stop_at "Step 3"
+fi
+
+printf "\n"
+printf "${BOLD}========================================${NC}\n"
+printf "${GREEN}${BOLD}  All 3/3 checks passed!${NC}\n"
+printf "${GREEN}${BOLD}  Your submission is ready to submit.${NC}\n"
+printf "${BOLD}========================================${NC}\n"
+printf "\n"
+
+exit 0
sample_inference_script.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ """
+ Inference Script Example
+ ===================================
+ MANDATORY
+ - Before submitting, ensure the following variables are defined in your environment configuration:
+     API_BASE_URL      The API endpoint for the LLM.
+     MODEL_NAME        The model identifier to use for inference.
+     HF_TOKEN          Your Hugging Face / API key.
+     LOCAL_IMAGE_NAME  The name of the local image to use for the environment, if you are using the
+                       from_docker_image() method.
+
+ - Defaults are set only for API_BASE_URL and MODEL_NAME
+   (and should reflect your active inference setup):
+     API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
+     MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
+
+ - The inference script must be named `inference.py` and placed in the root directory of the project.
+ - Participants must use the OpenAI client for all LLM calls, using the variables above.
+
+ STDOUT FORMAT
+ - The script must emit exactly three line types to stdout, in this order:
+
+     [START] task=<task_name> env=<benchmark> model=<model_name>
+     [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
+     [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
+
+ Rules:
+ - One [START] line at episode begin.
+ - One [STEP] line per step, immediately after env.step() returns.
+ - One [END] line after env.close(), always emitted (even on exception).
+ - reward and rewards are formatted to 2 decimal places.
+ - done and success are lowercase booleans: true or false.
+ - error is the raw last_action_error string, or null if none.
+ - All fields on a single line with no newlines within a line.
+ - Each task should return a score in [0, 1].
+
+ Example:
+     [START] task=click-test env=miniwob model=Qwen3-VL-30B
+     [STEP] step=1 action=click('123') reward=0.00 done=false error=null
+     [STEP] step=2 action=fill('456','text') reward=0.00 done=false error=null
+     [STEP] step=3 action=click('789') reward=1.00 done=true error=null
+     [END] success=true steps=3 score=1.00 rewards=0.00,0.00,1.00
+ """
+
+ import asyncio
+ import os
+ import textwrap
+ from typing import List, Optional
+
+ from openai import OpenAI
+
+ from my_env_v4 import MyEnvV4Action, MyEnvV4Env
+
+ IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")  # Set if you are using a local Docker image
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
+
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
+ MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
+ TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
+ BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
+ MAX_STEPS = 8
+ TEMPERATURE = 0.7
+ MAX_TOKENS = 150
+ SUCCESS_SCORE_THRESHOLD = 0.1  # normalized score in [0, 1]
+
+ # Approximate max possible reward: reward is 0.1 per character, with message
+ # length bounded by MAX_TOKENS per step, across all steps.
+ _MAX_REWARD_PER_STEP = MAX_TOKENS * 0.1
+ MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP
+
+ SYSTEM_PROMPT = textwrap.dedent(
+     """
+     You are interacting with a simple echo environment.
+     Each turn you must send a message. The environment will echo it back.
+     Reward is proportional to message length: reward = len(message) * 0.1
+     Your goal is to maximize total reward by sending meaningful, substantive messages.
+     Reply with exactly one message string — no quotes, no prefixes, just the message text.
+     """
+ ).strip()
+
+
+ def log_start(task: str, env: str, model: str) -> None:
+     print(f"[START] task={task} env={env} model={model}", flush=True)
+
+
+ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+     error_val = error if error else "null"
+     done_val = str(done).lower()
+     print(
+         f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
+
+
+ def build_user_prompt(step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
+     history_block = "\n".join(history[-4:]) if history else "None"
+     return textwrap.dedent(
+         f"""
+         Step: {step}
+         Last echoed message: {last_echoed!r}
+         Last reward: {last_reward:.2f}
+         Previous steps:
+         {history_block}
+         Send your next message.
+         """
+     ).strip()
+
+
+ def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
+     user_prompt = build_user_prompt(step, last_echoed, last_reward, history)
+     try:
+         completion = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": user_prompt},
+             ],
+             temperature=TEMPERATURE,
+             max_tokens=MAX_TOKENS,
+             stream=False,
+         )
+         text = (completion.choices[0].message.content or "").strip()
+         return text if text else "hello"
+     except Exception as exc:
+         print(f"[DEBUG] Model request failed: {exc}", flush=True)
+         return "hello"
+
+
+ async def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     env = await MyEnvV4Env.from_docker_image(IMAGE_NAME)
+
+     history: List[str] = []
+     rewards: List[float] = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
+
+     try:
+         result = await env.reset()  # OpenEnv reset()
+         last_echoed = result.observation.echoed_message
+         last_reward = 0.0
+
+         for step in range(1, MAX_STEPS + 1):
+             if result.done:
+                 break
+
+             message = get_model_message(client, step, last_echoed, last_reward, history)
+
+             result = await env.step(MyEnvV4Action(message=message))
+             obs = result.observation
+
+             reward = result.reward or 0.0
+             done = result.done
+             error = None
+
+             rewards.append(reward)
+             steps_taken = step
+             last_echoed = obs.echoed_message
+             last_reward = reward
+
+             log_step(step=step, action=message, reward=reward, done=done, error=error)
+
+             history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}")
+
+             if done:
+                 break
+
+         score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
+         score = min(max(score, 0.0), 1.0)  # clamp to [0, 1]
+         success = score >= SUCCESS_SCORE_THRESHOLD
+
+     finally:
+         try:
+             await env.close()
+         except Exception as e:
+             print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
+         log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
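
The `[STEP]` contract documented in the docstring above can be sanity-checked with a small parser sketch. This is not part of the submission; the helper name and regex are illustrative, assuming the exact field order and `null` convention described in the STDOUT FORMAT rules:

```python
import re
from typing import Dict, Optional

# Fields appear in a fixed order; action is matched non-greedily so it may
# contain spaces, and "null" maps to Python's None.
STEP_RE = re.compile(
    r"\[STEP\] step=(?P<step>\d+) action=(?P<action>.+?) "
    r"reward=(?P<reward>[\d.]+) done=(?P<done>true|false) error=(?P<error>.+)"
)

def parse_step_line(line: str) -> Dict[str, object]:
    """Parse one [STEP] line into typed fields; raise ValueError on mismatch."""
    m = STEP_RE.match(line.strip())
    if m is None:
        raise ValueError(f"not a [STEP] line: {line!r}")
    error: Optional[str] = None if m.group("error") == "null" else m.group("error")
    return {
        "step": int(m.group("step")),
        "action": m.group("action"),
        "reward": float(m.group("reward")),
        "done": m.group("done") == "true",
        "error": error,
    }
```

For example, `parse_step_line("[STEP] step=3 action=click('789') reward=1.00 done=true error=null")` yields step 3, reward 1.0, done True, error None.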
tests/__init__.py ADDED
File without changes
tests/test_environment.py ADDED
@@ -0,0 +1,249 @@
+ """Tests for the DataQA environment (reset, step, scoring)."""
+
+ import pytest
+ from dataqa_env.server.environment import (
+     DataQAEnvironment,
+     parse_issue_key,
+     compute_f1,
+     compute_weighted_reward,
+ )
+ from dataqa_env.models import DataQAAction
+ from dataqa_env.server.tasks import PlantedIssue
+
+
+ class TestParseIssueKey:
+     def test_standard_format(self):
+         assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
+
+     def test_with_equals(self):
+         assert parse_issue_key("row=3,col=salary,issue=missing_value") == "row:3,col:salary,issue:missing_value"
+
+     def test_case_insensitive(self):
+         assert parse_issue_key("Row:3,Col:Salary,Issue:Missing_Value") == "row:3,col:salary,issue:missing_value"
+
+     def test_with_spaces(self):
+         assert parse_issue_key("row: 3, col: salary, issue: missing_value") == "row:3,col:salary,issue:missing_value"
+
+     def test_unparseable(self):
+         assert parse_issue_key("this is garbage") is None
+
+     def test_partial_match(self):
+         assert parse_issue_key("row:3,col:salary") is None  # missing issue
+
+     def test_empty_string(self):
+         assert parse_issue_key("") is None
+
+     def test_semicolon_separator(self):
+         result = parse_issue_key("row:3;col:salary;issue:missing_value")
+         assert result == "row:3,col:salary,issue:missing_value"
+
+
+ class TestComputeF1:
+     def test_perfect_match(self):
+         keys = {"row:1,col:a,issue:missing_value"}
+         result = compute_f1(keys, keys)
+         assert result["f1"] == 1.0
+         assert result["tp"] == 1
+         assert result["fp"] == 0
+         assert result["fn"] == 0
+
+     def test_no_reported_no_planted(self):
+         result = compute_f1(set(), set())
+         assert result["f1"] == 1.0
+
+     def test_no_reported_some_planted(self):
+         planted = {"row:1,col:a,issue:missing_value"}
+         result = compute_f1(set(), planted)
+         assert result["f1"] == 0.0
+         assert result["fn"] == 1
+
+     def test_all_false_positives(self):
+         reported = {"row:99,col:x,issue:wrong_type"}
+         planted = {"row:1,col:a,issue:missing_value"}
+         result = compute_f1(reported, planted)
+         assert result["tp"] == 0
+         assert result["fp"] == 1
+         assert result["fn"] == 1
+         assert result["f1"] == 0.0
+
+     def test_partial_match(self):
+         reported = {"row:1,col:a,issue:missing_value", "row:2,col:b,issue:wrong_type"}
+         planted = {"row:1,col:a,issue:missing_value", "row:3,col:c,issue:duplicate_row"}
+         result = compute_f1(reported, planted)
+         assert result["tp"] == 1
+         assert result["fp"] == 1
+         assert result["fn"] == 1
+         assert 0 < result["f1"] < 1
+
+     def test_precision_recall_calculation(self):
+         reported = {"a", "b", "c"}
+         planted = {"a", "b", "d"}
+         result = compute_f1(reported, planted)
+         assert result["precision"] == pytest.approx(2 / 3)
+         assert result["recall"] == pytest.approx(2 / 3)
+
+
+ class TestComputeWeightedReward:
+     def test_perfect_match(self):
+         issues = [
+             PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0),
+             PlantedIssue(row=2, col="b", issue_type="wrong_type", description="", difficulty=3.0),
+         ]
+         reported = {i.to_key() for i in issues}
+         result = compute_weighted_reward(reported, issues)
+         assert result["weighted_reward"] == 1.0
+
+     def test_empty_both(self):
+         result = compute_weighted_reward(set(), [])
+         assert result["weighted_reward"] == 1.0
+
+     def test_no_reported(self):
+         issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
+         result = compute_weighted_reward(set(), issues)
+         assert result["weighted_reward"] == 0.0
+         assert result["difficulty_missed"] == 2.0
+
+     def test_hard_issue_worth_more(self):
+         easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
+         hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
+         issues = [easy, hard]
+
+         # Finding only the hard issue should score higher than only the easy issue
+         hard_found = compute_weighted_reward({hard.to_key()}, issues)
+         easy_found = compute_weighted_reward({easy.to_key()}, issues)
+         assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
+
+     def test_false_positives_reduce_reward(self):
+         issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)]
+         correct = {issues[0].to_key()}
+         with_fp = correct | {"row:99,col:x,issue:wrong_type"}
+         r_correct = compute_weighted_reward(correct, issues)
+         r_with_fp = compute_weighted_reward(with_fp, issues)
+         assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
+
+
+ class TestDataQAEnvironment:
+     @pytest.fixture
+     def env(self):
+         return DataQAEnvironment()
+
+     def test_reset_returns_observation(self, env):
+         obs = env.reset(task_id="easy")
+         assert obs.dataset_csv
+         assert obs.schema_description
+         assert obs.validation_rules
+         assert obs.task_description
+         assert obs.num_issues_hint == 4
+         assert obs.max_steps == 3
+         assert obs.done is False
+         assert obs.reward == 0.0
+
+     def test_reset_medium(self, env):
+         obs = env.reset(task_id="medium")
+         assert obs.num_issues_hint == 6
+
+     def test_reset_hard(self, env):
+         obs = env.reset(task_id="hard")
+         assert obs.num_issues_hint == 8
+
+     def test_step_with_correct_issues(self, env):
+         env.reset(task_id="easy")
+         # Submit all correct issues for easy task
+         action = DataQAAction(
+             issues=[
+                 "row:4,col:name,issue:missing_value",
+                 "row:7,col:salary,issue:wrong_type",
+                 "row:11,col:employee_id,issue:duplicate_row",
+                 "row:9,col:salary,issue:out_of_range",
+             ],
+             task_id="easy",
+         )
+         obs = env.step(action)
+         assert obs.done is True
+         assert obs.reward >= 0.999
+
+     def test_step_with_partial_issues(self, env):
+         env.reset(task_id="easy")
+         action = DataQAAction(
+             issues=["row:4,col:name,issue:missing_value"],
+             task_id="easy",
+         )
+         obs = env.step(action)
+         assert 0 < obs.reward < 1.0
+         assert obs.done is False
+
+     def test_step_with_no_issues(self, env):
+         env.reset(task_id="easy")
+         action = DataQAAction(issues=[], task_id="easy")
+         obs = env.step(action)
+         assert obs.reward == 0.0
+
+     def test_step_exhausts_max_steps(self, env):
+         env.reset(task_id="easy")
+         for _ in range(3):
+             action = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
+             obs = env.step(action)
+         assert obs.done is True
+
+     def test_auto_reset_on_step(self, env):
+         # step() without prior reset should auto-reset
+         action = DataQAAction(
+             issues=["row:4,col:name,issue:missing_value"],
+             task_id="easy",
+         )
+         obs = env.step(action)
+         assert obs.task_id == "easy"
+
+     def test_state_tracking(self, env):
+         env.reset(task_id="easy")
+         assert env.state.task_id == "easy"
+         assert env.state.current_step == 0
+         assert env.state.best_score == 0.0
+
+         action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
+         env.step(action)
+         assert env.state.current_step == 1
+         assert env.state.best_score > 0.0
+
+     def test_best_score_monotonic(self, env):
+         env.reset(task_id="easy")
+         action1 = DataQAAction(
+             issues=["row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type"],
+             task_id="easy",
+         )
+         env.step(action1)
+         score_after_1 = env.state.best_score
+
+         # Worse submission shouldn't decrease best_score
+         action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
+         env.step(action2)
+         assert env.state.best_score >= score_after_1
+
+     def test_metadata_included_in_observation(self, env):
+         env.reset(task_id="easy")
+         action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
+         obs = env.step(action)
+         assert "f1" in obs.metadata
+         assert "weighted_reward" in obs.metadata
+         assert "tp" in obs.metadata
+         assert "difficulty_found" in obs.metadata
+
+     def test_parse_error_in_feedback(self, env):
+         env.reset(task_id="easy")
+         action = DataQAAction(issues=["garbage input"], task_id="easy")
+         obs = env.step(action)
+         assert "Parse error" in obs.feedback
+
+     def test_concurrent_sessions_flag(self):
+         assert DataQAEnvironment.SUPPORTS_CONCURRENT_SESSIONS is True
+
+     def test_reward_between_0_and_1(self, env):
+         """Hackathon requirement: scores must be 0.0-1.0."""
+         env.reset(task_id="hard")
+         for _ in range(3):
+             action = DataQAAction(
+                 issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
+                 task_id="hard",
+             )
+             obs = env.step(action)
+             assert 0.0 <= obs.reward <= 1.0
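
The set-based F1 behavior these tests assert (empty/empty scores 1.0, symmetric 2/3 precision and recall) can be sketched as a standalone helper. This is a minimal illustration of the scoring rule, not the environment's actual `compute_f1` implementation:

```python
def f1_counts(reported: set, planted: set) -> dict:
    """Set-based precision/recall/F1 over issue keys."""
    tp = len(reported & planted)   # correctly reported issues
    fp = len(reported - planted)   # reported but not planted
    fn = len(planted - reported)   # planted but missed
    if not reported and not planted:
        # Convention: nothing planted, nothing reported is a perfect score.
        return {"tp": 0, "fp": 0, "fn": 0, "precision": 1.0, "recall": 1.0, "f1": 1.0}
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"tp": tp, "fp": fp, "fn": fn, "precision": precision, "recall": recall, "f1": f1}
```

With `reported = {"a", "b", "c"}` and `planted = {"a", "b", "d"}`, precision and recall are both 2/3, so F1 is also 2/3, matching `test_precision_recall_calculation` above.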
tests/test_extensibility.py ADDED
@@ -0,0 +1,185 @@
+ """Tests for the extensibility API — custom tasks and contamination rules."""
+
+ import pytest
+ from dataqa_env.server.tasks import (
+     PlantedIssue,
+     create_task_from_config,
+     register_task,
+     register_contamination_rule,
+     CONTAMINATION_RULES,
+     get_task,
+     list_tasks,
+ )
+ from dataqa_env.server.environment import DataQAEnvironment, compute_weighted_reward
+ from dataqa_env.models import DataQAAction
+
+
+ SIMPLE_CSV = "id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92\n4,Dave,78"
+
+
+ class TestCreateTaskFromConfig:
+     def test_basic_creation(self):
+         task = create_task_from_config(
+             task_id="test_custom",
+             name="Test Task",
+             description="Test",
+             schema_description="id: int, name: str, score: int",
+             validation_rules="No missing values",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[
+                 {"rule": "missing_value", "row": 0, "col": 1},
+             ],
+         )
+         assert task.task_id == "test_custom"
+         assert len(task.planted_issues) == 1
+         assert task.planted_issues[0].issue_type == "missing_value"
+         assert task.planted_issues[0].col == "name"
+
+     def test_multiple_contaminations(self):
+         task = create_task_from_config(
+             task_id="multi",
+             name="Multi",
+             description="Test",
+             schema_description="",
+             validation_rules="",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[
+                 {"rule": "missing_value", "row": 0, "col": 1},
+                 {"rule": "missing_value", "row": 2, "col": 1},
+             ],
+         )
+         assert len(task.planted_issues) == 2
+
+     def test_custom_difficulty_override(self):
+         task = create_task_from_config(
+             task_id="custom_diff",
+             name="Custom Difficulty",
+             description="Test",
+             schema_description="",
+             validation_rules="",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[
+                 {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 2.5},
+             ],
+         )
+         assert task.planted_issues[0].difficulty == 2.5
+
+     def test_callable_rule(self):
+         def custom_rule(rows, header, col_idx, row_idx, rng):
+             return "CORRUPTED", PlantedIssue(
+                 row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
+                 description="Custom corruption", difficulty=1.5,
+             )
+
+         task = create_task_from_config(
+             task_id="callable",
+             name="Callable Rule",
+             description="Test",
+             schema_description="",
+             validation_rules="",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[
+                 {"rule": custom_rule, "row": 1, "col": 2},
+             ],
+         )
+         assert task.planted_issues[0].issue_type == "wrong_type"
+         assert "CORRUPTED" in task.corrupted_csv
+
+     def test_unknown_rule_raises(self):
+         with pytest.raises(ValueError, match="Unknown contamination rule"):
+             create_task_from_config(
+                 task_id="bad",
+                 name="Bad",
+                 description="",
+                 schema_description="",
+                 validation_rules="",
+                 clean_csv=SIMPLE_CSV,
+                 contaminations=[{"rule": "nonexistent_rule", "row": 0, "col": 0}],
+             )
+
+
+ class TestRegisterContaminationRule:
+     def test_register_and_use(self):
+         def reverse_value(rows, header, col_idx, row_idx, rng):
+             val = rows[row_idx][col_idx]
+             return val[::-1], PlantedIssue(
+                 row=row_idx + 1, col=header[col_idx], issue_type="format_violation",
+                 description="Reversed value", difficulty=1.5,
+             )
+
+         register_contamination_rule("reverse", reverse_value)
+         assert "reverse" in CONTAMINATION_RULES
+
+         task = create_task_from_config(
+             task_id="rev_test",
+             name="Reverse Test",
+             description="",
+             schema_description="",
+             validation_rules="",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[{"rule": "reverse", "row": 0, "col": 1}],
+         )
+         assert task.planted_issues[0].issue_type == "format_violation"
+         # "Alice" reversed is "ecilA"
+         assert "ecilA" in task.corrupted_csv
+
+         # Cleanup
+         del CONTAMINATION_RULES["reverse"]
+
+
+ class TestRegisterTask:
+     def test_register_and_get(self):
+         task = create_task_from_config(
+             task_id="registered",
+             name="Registered Task",
+             description="Test registered task",
+             schema_description="id: int, name: str",
+             validation_rules="No missing values",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[{"rule": "missing_value", "row": 1, "col": 1}],
+         )
+         register_task("registered", lambda seed: task)
+         assert "registered" in list_tasks()
+
+         fetched = get_task("registered")
+         assert fetched.task_id == "registered"
+         assert len(fetched.planted_issues) == 1
+
+         # Cleanup
+         from dataqa_env.server.tasks import TASK_REGISTRY
+         del TASK_REGISTRY["registered"]
+
+
+ class TestCustomTaskInEnvironment:
+     def test_full_lifecycle(self):
+         """Custom task works end-to-end in the environment."""
+         task = create_task_from_config(
+             task_id="e2e_custom",
+             name="E2E Custom",
+             description="End-to-end test",
+             schema_description="id: int, name: str, score: int",
+             validation_rules="No missing values",
+             clean_csv=SIMPLE_CSV,
+             contaminations=[
+                 {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
+                 {"rule": "whitespace_value", "row": 2, "col": 1, "difficulty": 2.5},
+             ],
+         )
+         register_task("e2e_custom", lambda seed: task)
+
+         env = DataQAEnvironment()
+         obs = env.reset(task_id="e2e_custom")
+         assert obs.num_issues_hint == 2
+
+         # Submit correct answers
+         action = DataQAAction(
+             issues=[i.to_key() for i in task.planted_issues],
+             task_id="e2e_custom",
+         )
+         obs = env.step(action)
+         assert obs.done is True
+         assert obs.reward >= 0.999
+
+         # Cleanup
+         from dataqa_env.server.tasks import TASK_REGISTRY
+         del TASK_REGISTRY["e2e_custom"]
tests/test_inference.py ADDED
@@ -0,0 +1,148 @@
+ """Tests for the inference script's parsing, prompt building, and log format."""
+
+ import re
+ import pytest
+ import sys
+ import os
+
+ sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
+ from inference import parse_llm_response, build_user_prompt, log_start, log_step, log_end
+
+
+ class TestParseLLMResponse:
+     def test_standard_format(self):
+         response = "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
+         issues = parse_llm_response(response)
+         assert len(issues) == 2
+         assert "row:1,col:name,issue:missing_value" in issues
+
+     def test_numbered_list(self):
+         response = "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
+         issues = parse_llm_response(response)
+         assert len(issues) == 2
+
+     def test_bullet_list(self):
+         response = "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
+         issues = parse_llm_response(response)
+         assert len(issues) == 2
+
+     def test_equals_delimiter(self):
+         response = "row=1,col=name,issue=missing_value"
+         issues = parse_llm_response(response)
+         assert len(issues) == 1
+         assert issues[0] == "row:1,col:name,issue:missing_value"
+
+     def test_mixed_case(self):
+         response = "Row:1,Col:Name,Issue:Missing_Value"
+         issues = parse_llm_response(response)
+         assert len(issues) == 1
+         assert issues[0] == "row:1,col:name,issue:missing_value"
+
+     def test_empty_response(self):
+         assert parse_llm_response("") == []
+         assert parse_llm_response("   ") == []
+
+     def test_garbage_lines_skipped(self):
+         response = "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
+         issues = parse_llm_response(response)
+         assert len(issues) == 1
+
+     def test_deduplication_not_applied(self):
+         response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
+         issues = parse_llm_response(response)
+         assert len(issues) == 2  # duplicates kept, env handles dedup
+
+     def test_with_column_variant(self):
+         response = "row:1,column:name,issue:missing_value"
+         issues = parse_llm_response(response)
+         assert len(issues) == 1
+
+
+ class TestBuildUserPrompt:
+     def test_includes_all_fields(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "col: int",
+             "validation_rules": "no nulls",
+             "dataset_csv": "a,b\n1,2",
+             "num_issues_hint": 3,
+             "feedback": "",
+         }
+         prompt = build_user_prompt(obs)
+         assert "Find issues" in prompt
+         assert "col: int" in prompt
+         assert "no nulls" in prompt
+         assert "a,b" in prompt
+         assert "3 issues" in prompt
+
+     def test_includes_feedback_on_retry(self):
+         obs = {
+             "task_description": "Find issues",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "a\n1",
+             "num_issues_hint": 0,
+             "feedback": "Step 1/3: You missed 2 issues",
+         }
+         prompt = build_user_prompt(obs)
+         assert "FEEDBACK" in prompt
+         assert "missed 2" in prompt
+
+     def test_excludes_reset_feedback(self):
+         obs = {
+             "task_description": "",
+             "schema_description": "",
+             "validation_rules": "",
+             "dataset_csv": "",
+             "num_issues_hint": 0,
+             "feedback": "Environment reset. Start inspecting.",
+         }
+         prompt = build_user_prompt(obs)
+         assert "FEEDBACK" not in prompt
+
+
+ class TestLogFormat:
+     """Verify stdout log format matches hackathon evaluation requirements."""
+
+     def test_log_start_format(self, capsys):
+         log_start(task="easy", env="dataqa_env", model="test-model")
+         out = capsys.readouterr().out.strip()
+         assert out == "[START] task=easy env=dataqa_env model=test-model"
+
+     def test_log_step_format(self, capsys):
+         log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
+         out = capsys.readouterr().out.strip()
+         assert out == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"
+
+     def test_log_step_with_error(self, capsys):
+         log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
+         out = capsys.readouterr().out.strip()
+         assert "error=timeout" in out
+         assert "done=true" in out
+
+     def test_log_end_format(self, capsys):
+         log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
+         out = capsys.readouterr().out.strip()
+         assert out == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"
+
+     def test_log_end_failure(self, capsys):
+         log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+         out = capsys.readouterr().out.strip()
+         assert "success=false" in out
+         assert "score=0.000" in out
+
+     def test_reward_format_2_decimal(self, capsys):
+         log_step(step=1, action="test", reward=0.123456, done=False, error=None)
+         out = capsys.readouterr().out.strip()
+         assert "reward=0.12" in out
+
+     def test_no_newlines_within_line(self, capsys):
+         log_start(task="easy", env="dataqa_env", model="model")
+         log_step(step=1, action="act", reward=0.0, done=False, error=None)
+         log_end(success=False, steps=1, score=0.0, rewards=[0.0])
+         out = capsys.readouterr().out
+         lines = [l for l in out.split("\n") if l.strip()]
+         assert len(lines) == 3
+         assert lines[0].startswith("[START]")
+         assert lines[1].startswith("[STEP]")
+         assert lines[2].startswith("[END]")
tests/test_tasks.py ADDED
@@ -0,0 +1,161 @@
+ """Tests for task definitions, data corruption, and issue planting."""
+
+ import pytest
+ from dataqa_env.server.tasks import (
+     PlantedIssue,
+     Task,
+     create_task_easy,
+     create_task_medium,
+     create_task_hard,
+     get_task,
+     list_tasks,
+     _csv_to_rows,
+     _rows_to_csv,
+ )
+
+
+ class TestPlantedIssue:
+     def test_to_key(self):
+         issue = PlantedIssue(row=3, col="salary", issue_type="missing_value", description="test")
+         assert issue.to_key() == "row:3,col:salary,issue:missing_value"
+
+     def test_difficulty_default(self):
+         issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test")
+         assert issue.difficulty == 1.0
+
+     def test_difficulty_custom(self):
+         issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test", difficulty=3.0)
+         assert issue.difficulty == 3.0
+
+
+ class TestCSVHelpers:
+     def test_roundtrip(self):
+         csv_text = "a,b,c\n1,2,3\n4,5,6"
+         rows = _csv_to_rows(csv_text)
+         assert len(rows) == 3
+         result = _rows_to_csv(rows)
+         assert "1,2,3" in result
+
+     def test_empty_csv(self):
+         rows = _csv_to_rows("a,b\n")
+         assert len(rows) == 1  # header only
+
+
+ class TestTaskEasy:
+     @pytest.fixture
+     def task(self):
+         return create_task_easy()
+
+     def test_task_id(self, task):
+         assert task.task_id == "easy"
+
+     def test_has_4_issues(self, task):
+         assert len(task.planted_issues) == 4
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "missing_value" in types
+         assert "wrong_type" in types
+         assert "duplicate_row" in types
+         assert "out_of_range" in types
+
+     def test_corrupted_csv_differs_from_clean(self, task):
+         assert task.corrupted_csv != task.clean_csv
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+     def test_max_steps(self, task):
+         assert task.max_steps == 3
+
+     def test_corrupted_csv_has_more_rows(self, task):
+         clean_rows = _csv_to_rows(task.clean_csv)
+         corrupt_rows = _csv_to_rows(task.corrupted_csv)
+         assert len(corrupt_rows) > len(clean_rows)  # duplicate row added
+
+     def test_difficulty_weights(self, task):
+         for issue in task.planted_issues:
+             assert 1.0 <= issue.difficulty <= 3.0
+
+
+ class TestTaskMedium:
+     @pytest.fixture
+     def task(self):
+         return create_task_medium()
+
+     def test_task_id(self, task):
+         assert task.task_id == "medium"
+
+     def test_has_6_issues(self, task):
+         assert len(task.planted_issues) == 6
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "inconsistent_value" in types
+         assert "format_violation" in types
+         assert "missing_value" in types
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+     def test_difficulty_weights(self, task):
+         for issue in task.planted_issues:
+             assert 1.0 <= issue.difficulty <= 3.0
+
+
+ class TestTaskHard:
+     @pytest.fixture
+     def task(self):
+         return create_task_hard()
+
+     def test_task_id(self, task):
+         assert task.task_id == "hard"
+
+     def test_has_8_issues(self, task):
+         assert len(task.planted_issues) == 8
+
+     def test_issue_types(self, task):
+         types = {i.issue_type for i in task.planted_issues}
+         assert "inconsistent_value" in types
+         assert "format_violation" in types
+         assert "statistical_outlier" in types
+         assert "out_of_range" in types
+         assert "missing_value" in types
+
+     def test_has_high_difficulty_issues(self, task):
+         hard_issues = [i for i in task.planted_issues if i.difficulty >= 2.5]
+         assert len(hard_issues) >= 2  # data leakage, GPU outlier, whitespace
+
+     def test_issue_keys_unique(self, task):
+         keys = [i.to_key() for i in task.planted_issues]
+         assert len(keys) == len(set(keys))
+
+
+ class TestTaskRegistry:
+     def test_list_tasks(self):
+         tasks = list_tasks()
+         assert set(tasks) == {"easy", "medium", "hard"}
+
+     def test_get_task_easy(self):
+         task = get_task("easy")
+         assert task.task_id == "easy"
+
+     def test_get_task_medium(self):
+         task = get_task("medium")
+         assert task.task_id == "medium"
+
+     def test_get_task_hard(self):
+         task = get_task("hard")
+         assert task.task_id == "hard"
+
+     def test_get_task_unknown_raises(self):
+         with pytest.raises(ValueError, match="Unknown task"):
+             get_task("nonexistent")
+
+     def test_seed_determinism(self):
+         t1 = get_task("easy", seed=42)
+         t2 = get_task("easy", seed=42)
+         assert t1.corrupted_csv == t2.corrupted_csv
+         assert [i.to_key() for i in t1.planted_issues] == [i.to_key() for i in t2.planted_issues]