Spaces:
Sleeping
Sleeping
Commit ·
c3002ad
1
Parent(s): cd11aba
add fix stage+demo
Browse files- README.md +97 -57
- dataqa_env/models.py +12 -10
- dataqa_env/server/environment.py +283 -34
- dataqa_env/server/gradio_ui.py +487 -0
- dataqa_env/server/tasks.py +48 -0
- inference.py +113 -49
- tests/test_environment.py +245 -24
- tests/test_extensibility.py +34 -4
- tests/test_inference.py +46 -3
- tests/test_tasks.py +2 -2
README.md
CHANGED
|
@@ -13,20 +13,38 @@ tags:
|
|
| 13 |
|
| 14 |
# DataQA Environment
|
| 15 |
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
## Motivation
|
| 19 |
|
| 20 |
Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
|
| 21 |
|
| 22 |
-
DataQA turns this into a
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
|
| 24 |
## Environment API
|
| 25 |
|
| 26 |
| Endpoint | Method | Description |
|
| 27 |
|----------|--------|-------------|
|
| 28 |
| `/reset` | POST | Start a new episode with a corrupted dataset |
|
| 29 |
-
| `/step` | POST | Submit identified issues
|
| 30 |
| `/state` | GET | Get current episode state |
|
| 31 |
| `/health` | GET | Health check |
|
| 32 |
|
|
@@ -36,20 +54,27 @@ DataQA turns this into a structured, gradable RL environment where agents must s
|
|
| 36 |
|------|--------|-----------|--------|-------------|
|
| 37 |
| `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
|
| 38 |
| `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
|
| 39 |
-
| `hard` |
|
| 40 |
|
| 41 |
**Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
|
| 42 |
|
| 43 |
-
## Action Space
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
row:<row_number>,col:<column_name>,issue:<issue_type>
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
| 53 |
|
| 54 |
**Supported Issue Types:**
|
| 55 |
|
|
@@ -66,30 +91,36 @@ row:<row_number>,col:<column_name>,issue:<issue_type>
|
|
| 66 |
|
| 67 |
## Observation Space
|
| 68 |
|
| 69 |
-
Each observation contains:
|
| 70 |
-
|
| 71 |
| Field | Type | Description |
|
| 72 |
|-------|------|-------------|
|
| 73 |
| `dataset_csv` | str | The corrupted dataset in CSV format |
|
| 74 |
| `schema_description` | str | Column types, ranges, and constraints |
|
| 75 |
| `validation_rules` | str | Business rules the data must satisfy |
|
| 76 |
| `task_description` | str | Task context and instructions |
|
| 77 |
-
| `feedback` | str |
|
| 78 |
| `num_issues_hint` | int | Exact count of planted issues |
|
| 79 |
| `max_steps` | int | Maximum attempts allowed |
|
| 80 |
| `done` | bool | Whether episode has terminated |
|
| 81 |
-
| `reward` | float | Best
|
| 82 |
|
| 83 |
-
**Observation Metadata** (
|
| 84 |
-
- `
|
| 85 |
-
- `
|
| 86 |
-
- `difficulty_found`, `difficulty_missed`
|
| 87 |
|
| 88 |
## Reward Function
|
| 89 |
|
| 90 |
-
###
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
Each planted issue has a **difficulty weight** (1.0-3.0)
|
| 93 |
|
| 94 |
| Weight | Category | Examples |
|
| 95 |
|--------|----------|----------|
|
|
@@ -97,43 +128,51 @@ Each planted issue has a **difficulty weight** (1.0-3.0) reflecting how hard it
|
|
| 97 |
| 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
|
| 98 |
| 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
|
| 99 |
|
| 100 |
-
**
|
| 101 |
-
- **Weighted
|
| 102 |
-
- **Weighted
|
| 103 |
-
|
|
|
|
| 104 |
|
| 105 |
-
|
| 106 |
-
- Finding a hard issue (difficulty 3.0) increases reward 3x more than finding an easy one (1.0)
|
| 107 |
-
- False positives are penalized proportionally to average issue difficulty
|
| 108 |
-
- The agent sees meaningful reward differences at every step, not just pass/fail
|
| 109 |
|
| 110 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
|
| 112 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
### Episode Boundaries
|
| 115 |
|
| 116 |
- Each task allows up to 3 steps (attempts)
|
| 117 |
-
- Episode ends when F1 >= 0.999 (perfect) or max steps reached
|
| 118 |
-
-
|
| 119 |
-
- Reward is always in [0.0, 1.0]
|
| 120 |
|
| 121 |
## Baseline Scores
|
| 122 |
|
| 123 |
-
Baseline
|
| 124 |
|
| 125 |
-
| Task |
|
| 126 |
-
|------|---------------------|-------------|
|
| 127 |
-
| `easy` | 0.7 - 1.0 | Most LLMs find obvious issues reliably |
|
| 128 |
-
| `medium` | 0.5 - 0.
|
| 129 |
-
| `hard` | 0.3 - 0.
|
| 130 |
|
| 131 |
-
Scores vary by model
|
| 132 |
|
| 133 |
## Extensibility
|
| 134 |
|
| 135 |
-
DataQA supports custom tasks, contamination rules, and difficulty levels via a programmatic API.
|
| 136 |
-
|
| 137 |
### Custom Contamination Rules
|
| 138 |
|
| 139 |
```python
|
|
@@ -180,7 +219,7 @@ register_task("custom", lambda seed: task)
|
|
| 180 |
|------|--------|--------------------|
|
| 181 |
| `missing_value` | Sets field to empty string | 1.0 |
|
| 182 |
| `whitespace_value` | Sets field to single space | 2.5 |
|
| 183 |
-
| `wrong_type_text` | Replaces with random text
|
| 184 |
| `negative_value` | Negates numeric value | 1.0 |
|
| 185 |
|
| 186 |
## Quick Start
|
|
@@ -213,15 +252,16 @@ pip install -e ".[dev]"
|
|
| 213 |
pytest tests/ -v
|
| 214 |
```
|
| 215 |
|
| 216 |
-
|
| 217 |
-
- Task creation, corruption, and
|
| 218 |
-
- Issue key parsing (standard, lenient, edge cases)
|
| 219 |
-
- F1
|
| 220 |
-
- Full environment
|
|
|
|
| 221 |
- Inference script parsing and prompt building
|
| 222 |
-
-
|
| 223 |
- Score bounds (0.0-1.0), best-score monotonicity
|
| 224 |
-
- Extensibility API (custom rules, custom tasks
|
| 225 |
|
| 226 |
## Validation
|
| 227 |
|
|
@@ -247,19 +287,19 @@ openenv validate .
|
|
| 247 |
```
|
| 248 |
dataqa_env/
|
| 249 |
├── __init__.py # Public API + extensibility exports
|
| 250 |
-
├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
|
| 251 |
├── client.py # EnvClient for WebSocket connections
|
| 252 |
├── server/
|
| 253 |
-
│ ├── environment.py #
|
| 254 |
│ ├── tasks.py # Task definitions + contamination rules + extensibility API
|
| 255 |
│ ├── app.py # FastAPI server (via openenv-core create_app)
|
| 256 |
│ └── Dockerfile
|
| 257 |
tests/
|
| 258 |
├── test_tasks.py # Task creation, corruption, difficulty weights
|
| 259 |
-
├── test_environment.py #
|
| 260 |
-
├── test_inference.py # LLM response parsing, prompt building, log format
|
| 261 |
└── test_extensibility.py # Custom rules, custom tasks, registration API
|
| 262 |
-
inference.py #
|
| 263 |
openenv.yaml # OpenEnv/HF Spaces spec
|
| 264 |
pyproject.toml # Package metadata and dependencies
|
| 265 |
Dockerfile # Production container
|
|
|
|
| 13 |
|
| 14 |
# DataQA Environment
|
| 15 |
|
| 16 |
+
A two-phase OpenEnv RL environment for **Data Quality Assurance** — an LLM agent inspects corrupted datasets, identifies all planted quality issues, and proposes data repairs.
|
| 17 |
+
|
| 18 |
+
### Demo: Agent Trajectory Replay
|
| 19 |
+
|
| 20 |
+
**Easy task** — Agent finds all 4 issues and proposes fixes (step 2):
|
| 21 |
+
|
| 22 |
+

|
| 23 |
+
|
| 24 |
+
**Hard task** — Agent identifies 8 subtle ML issues including data leakage and GPU memory outlier, proposes fixes (step 2):
|
| 25 |
+
|
| 26 |
+

|
| 27 |
+
|
| 28 |
+
Green cells = correctly found issues. Yellow = missed. Green outlines = correct fixes with proposed values shown inline (e.g. `empty → David Kim`, `seventy-five thousand → 75000`).
|
| 29 |
+
|
| 30 |
+
> The interactive replay UI is available at the `/web` endpoint on the HF Space.
|
| 31 |
|
| 32 |
## Motivation
|
| 33 |
|
| 34 |
Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
|
| 35 |
|
| 36 |
+
DataQA turns this into a **two-phase RL challenge**:
|
| 37 |
+
1. **Identify** — systematically inspect corrupted data and pinpoint every planted issue
|
| 38 |
+
2. **Fix** — propose corrected values by reasoning about schema, constraints, and context
|
| 39 |
+
|
| 40 |
+
This creates a rich multi-step decision problem where agents must explore datasets strategically, distinguish subtle anomalies from noise, and reason about what the correct data should be.
|
| 41 |
|
| 42 |
## Environment API
|
| 43 |
|
| 44 |
| Endpoint | Method | Description |
|
| 45 |
|----------|--------|-------------|
|
| 46 |
| `/reset` | POST | Start a new episode with a corrupted dataset |
|
| 47 |
+
| `/step` | POST | Submit identified issues + proposed fixes |
|
| 48 |
| `/state` | GET | Get current episode state |
|
| 49 |
| `/health` | GET | Health check |
|
| 50 |
|
|
|
|
| 54 |
|------|--------|-----------|--------|-------------|
|
| 55 |
| `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
|
| 56 |
| `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
|
| 57 |
+
| `hard` | 10 | Advanced | ML experiment metadata | Data leakage signals, unreasonable GPU memory, impossibly fast training, SOTA-exceeding accuracy, timestamp ordering, whitespace-only fields |
|
| 58 |
|
| 59 |
**Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
|
| 60 |
|
| 61 |
+
## Two-Phase Action Space
|
| 62 |
|
| 63 |
+
### Phase 1: Identify Issues
|
| 64 |
+
|
| 65 |
+
Submit issues in format: `row:<row_number>,col:<column_name>,issue:<issue_type>`
|
| 66 |
+
|
| 67 |
+
- `row_number`: 1-indexed data row position (after header)
|
| 68 |
+
- `column_name`: Exact column header name, lowercase
|
| 69 |
+
- `issue_type`: One of the supported types below
|
| 70 |
+
|
| 71 |
+
### Phase 2: Propose Fixes
|
| 72 |
|
| 73 |
+
Submit fixes in format: `row:<row_number>,col:<column_name>,fix:<corrected_value>`
|
| 74 |
+
|
| 75 |
+
The agent proposes the **correct value** that should replace the corrupted data. Fixes are graded against the original clean dataset.
|
| 76 |
+
|
| 77 |
+
Both phases can be submitted in the same step or across multiple steps.
|
| 78 |
|
| 79 |
**Supported Issue Types:**
|
| 80 |
|
|
|
|
| 91 |
|
| 92 |
## Observation Space
|
| 93 |
|
|
|
|
|
|
|
| 94 |
| Field | Type | Description |
|
| 95 |
|-------|------|-------------|
|
| 96 |
| `dataset_csv` | str | The corrupted dataset in CSV format |
|
| 97 |
| `schema_description` | str | Column types, ranges, and constraints |
|
| 98 |
| `validation_rules` | str | Business rules the data must satisfy |
|
| 99 |
| `task_description` | str | Task context and instructions |
|
| 100 |
+
| `feedback` | str | Per-step results: TP/FP/FN, precision/recall, fix scores |
|
| 101 |
| `num_issues_hint` | int | Exact count of planted issues |
|
| 102 |
| `max_steps` | int | Maximum attempts allowed |
|
| 103 |
| `done` | bool | Whether episode has terminated |
|
| 104 |
+
| `reward` | float | Best combined reward so far (0.0-1.0) |
|
| 105 |
|
| 106 |
+
**Observation Metadata** (per step):
|
| 107 |
+
- Identify: `identify_f1`, `identify_score`, `precision`, `recall`, `tp`, `fp`, `fn`
|
| 108 |
+
- Fix: `fix_score`, `fixes_correct`, `fixes_partial`, `fixes_wrong`, `fixes_attempted`
|
| 109 |
+
- Combined: `combined_reward`, `difficulty_found`, `difficulty_missed`
|
| 110 |
|
| 111 |
## Reward Function
|
| 112 |
|
| 113 |
+
### Combined Reward
|
| 114 |
+
|
| 115 |
+
```
|
| 116 |
+
combined_reward = 0.6 * identify_score + 0.4 * fix_score
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
If no fixes are submitted, `combined_reward = identify_score` (no penalty — backward compatible).
|
| 120 |
+
|
| 121 |
+
### Identify Score (Difficulty-Weighted F1)
|
| 122 |
|
| 123 |
+
Each planted issue has a **difficulty weight** (1.0-3.0):
|
| 124 |
|
| 125 |
| Weight | Category | Examples |
|
| 126 |
|--------|----------|----------|
|
|
|
|
| 128 |
| 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
|
| 129 |
| 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
|
| 130 |
|
| 131 |
+
- **Weighted Recall** = (difficulty of found issues) / (total difficulty)
|
| 132 |
+
- **Weighted Precision** = penalizes false positives proportional to average difficulty
|
| 133 |
+
- **Weighted F1** = harmonic mean
|
| 134 |
+
|
| 135 |
+
### Fix Score (Difficulty-Weighted Quality)
|
| 136 |
|
| 137 |
+
Each proposed fix is compared against the original clean value:
|
|
|
|
|
|
|
|
|
|
| 138 |
|
| 139 |
+
| Fix Quality | Score | Description |
|
| 140 |
+
|-------------|-------|-------------|
|
| 141 |
+
| Exact match | 1.0 | Case-insensitive, whitespace-stripped match |
|
| 142 |
+
| Numeric close | 0.8 | Within 1% of correct numeric value |
|
| 143 |
+
| Correct cell | 0.1 | Right location, wrong value |
|
| 144 |
+
| Non-issue cell | 0.0 | Fix targets a cell with no issue |
|
| 145 |
|
| 146 |
+
Fix score = (sum of best fix score per issue × difficulty weight) / (total difficulty weight)
|
| 147 |
+
|
| 148 |
+
### Reward Properties
|
| 149 |
+
|
| 150 |
+
- **Per-step partial progress**: reward increases as more issues are found/fixed
|
| 151 |
+
- **Difficulty-aware**: finding subtle issues earns more than obvious ones
|
| 152 |
+
- **Penalizes bad behavior**: false positives reduce score, fixing non-issues earns nothing
|
| 153 |
+
- **Monotonically non-decreasing**: best score across all steps is the final reward
|
| 154 |
+
- **Always in [0.0, 1.0]**: meets hackathon requirement
|
| 155 |
|
| 156 |
### Episode Boundaries
|
| 157 |
|
| 158 |
- Each task allows up to 3 steps (attempts)
|
| 159 |
+
- Episode ends when F1 >= 0.999 (perfect identification) or max steps reached
|
| 160 |
+
- Agent receives detailed feedback after each step to improve on next attempt
|
|
|
|
| 161 |
|
| 162 |
## Baseline Scores
|
| 163 |
|
| 164 |
+
Baseline agent uses Qwen2.5-72B-Instruct via HuggingFace Router:
|
| 165 |
|
| 166 |
+
| Task | Identify Score | Fix Score | Combined | Notes |
|
| 167 |
+
|------|---------------|-----------|----------|-------|
|
| 168 |
+
| `easy` | 0.7-1.0 | 0.5-0.9 | 0.6-1.0 | Most LLMs find obvious issues reliably |
|
| 169 |
+
| `medium` | 0.5-0.8 | 0.3-0.6 | 0.4-0.7 | Cross-column reasoning challenges models |
|
| 170 |
+
| `hard` | 0.3-0.6 | 0.2-0.4 | 0.3-0.5 | ML domain knowledge and subtle patterns |
|
| 171 |
|
| 172 |
+
Scores vary by model. The hard task is designed to challenge frontier models.
|
| 173 |
|
| 174 |
## Extensibility
|
| 175 |
|
|
|
|
|
|
|
| 176 |
### Custom Contamination Rules
|
| 177 |
|
| 178 |
```python
|
|
|
|
| 219 |
|------|--------|--------------------|
|
| 220 |
| `missing_value` | Sets field to empty string | 1.0 |
|
| 221 |
| `whitespace_value` | Sets field to single space | 2.5 |
|
| 222 |
+
| `wrong_type_text` | Replaces with random text | 1.0 |
|
| 223 |
| `negative_value` | Negates numeric value | 1.0 |
|
| 224 |
|
| 225 |
## Quick Start
|
|
|
|
| 252 |
pytest tests/ -v
|
| 253 |
```
|
| 254 |
|
| 255 |
+
118 tests covering:
|
| 256 |
+
- Task creation, corruption, and difficulty weights
|
| 257 |
+
- Issue key and fix parsing (standard, lenient, edge cases)
|
| 258 |
+
- F1, weighted reward, and fix quality computation
|
| 259 |
+
- Full environment lifecycle (identify-only and identify+fix)
|
| 260 |
+
- Combined reward calculation and weight verification
|
| 261 |
- Inference script parsing and prompt building
|
| 262 |
+
- Structured log format ([START], [STEP], [END])
|
| 263 |
- Score bounds (0.0-1.0), best-score monotonicity
|
| 264 |
+
- Extensibility API (custom rules, custom tasks)
|
| 265 |
|
| 266 |
## Validation
|
| 267 |
|
|
|
|
| 287 |
```
|
| 288 |
dataqa_env/
|
| 289 |
├── __init__.py # Public API + extensibility exports
|
| 290 |
+
├── models.py # Pydantic: DataQAAction (issues + fixes), DataQAObservation, DataQAState
|
| 291 |
├── client.py # EnvClient for WebSocket connections
|
| 292 |
├── server/
|
| 293 |
+
│ ├── environment.py # Two-phase DataQAEnvironment (identify + fix + combined reward)
|
| 294 |
│ ├── tasks.py # Task definitions + contamination rules + extensibility API
|
| 295 |
│ ├── app.py # FastAPI server (via openenv-core create_app)
|
| 296 |
│ └── Dockerfile
|
| 297 |
tests/
|
| 298 |
├── test_tasks.py # Task creation, corruption, difficulty weights
|
| 299 |
+
├── test_environment.py # Identify scoring, fix grading, combined reward, lifecycle
|
| 300 |
+
├── test_inference.py # LLM response parsing, fix parsing, prompt building, log format
|
| 301 |
└── test_extensibility.py # Custom rules, custom tasks, registration API
|
| 302 |
+
inference.py # Two-phase baseline agent (identify → fix)
|
| 303 |
openenv.yaml # OpenEnv/HF Spaces spec
|
| 304 |
pyproject.toml # Package metadata and dependencies
|
| 305 |
Dockerfile # Production container
|
dataqa_env/models.py
CHANGED
|
@@ -16,21 +16,23 @@ from openenv.core.env_server.interfaces import Action, Observation, State
|
|
| 16 |
|
| 17 |
class DataQAAction(Action):
|
| 18 |
"""
|
| 19 |
-
Agent submits
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
Each issue is a string in the format: "row:<row_idx>,col:<col_name>,issue:<issue_type>"
|
| 22 |
Supported issue types:
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
- out_of_range
|
| 27 |
-
- format_violation
|
| 28 |
-
- inconsistent_value
|
| 29 |
-
- statistical_outlier
|
| 30 |
-
- referential_integrity
|
| 31 |
"""
|
| 32 |
|
| 33 |
issues: List[str]
|
|
|
|
| 34 |
# Include task_id so step() can reconstruct context in stateless HTTP mode
|
| 35 |
task_id: str = "easy"
|
| 36 |
|
|
|
|
| 16 |
|
| 17 |
class DataQAAction(Action):
|
| 18 |
"""
|
| 19 |
+
Agent submits identified issues AND optional proposed fixes.
|
| 20 |
+
|
| 21 |
+
Two-phase action space:
|
| 22 |
+
Phase 1 (Identify): List issues in format "row:<N>,col:<name>,issue:<type>"
|
| 23 |
+
Phase 2 (Fix): List fixes in format "row:<N>,col:<name>,fix:<proposed_value>"
|
| 24 |
+
|
| 25 |
+
The agent can submit both in the same step or across multiple steps.
|
| 26 |
+
Combined reward = 0.6 * identify_score + 0.4 * fix_score
|
| 27 |
|
|
|
|
| 28 |
Supported issue types:
|
| 29 |
+
missing_value, wrong_type, duplicate_row, out_of_range,
|
| 30 |
+
format_violation, inconsistent_value, statistical_outlier,
|
| 31 |
+
referential_integrity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"""
|
| 33 |
|
| 34 |
issues: List[str]
|
| 35 |
+
fixes: List[str] = []
|
| 36 |
# Include task_id so step() can reconstruct context in stateless HTTP mode
|
| 37 |
task_id: str = "easy"
|
| 38 |
|
dataqa_env/server/environment.py
CHANGED
|
@@ -3,8 +3,12 @@ DataQA Environment
|
|
| 3 |
------------------
|
| 4 |
Server-side environment for data quality assurance tasks.
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
"""
|
| 9 |
|
| 10 |
from __future__ import annotations
|
|
@@ -18,6 +22,10 @@ from openenv.core.env_server.interfaces import Action, Environment, Observation
|
|
| 18 |
from ..models import DataQAAction, DataQAObservation, DataQAState
|
| 19 |
from .tasks import PlantedIssue, Task, get_task, list_tasks
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
def parse_issue_key(raw: str) -> Optional[str]:
|
| 23 |
"""
|
|
@@ -26,7 +34,6 @@ def parse_issue_key(raw: str) -> Optional[str]:
|
|
| 26 |
Returns normalized key or None if unparseable.
|
| 27 |
"""
|
| 28 |
raw = raw.strip().lower()
|
| 29 |
-
# Be lenient with formatting
|
| 30 |
row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
|
| 31 |
col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
|
| 32 |
issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
|
|
@@ -36,6 +43,22 @@ def parse_issue_key(raw: str) -> Optional[str]:
|
|
| 36 |
return None
|
| 37 |
|
| 38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
|
| 40 |
"""Compute precision, recall, and F1 score."""
|
| 41 |
if not reported_keys and not planted_keys:
|
|
@@ -83,7 +106,6 @@ def compute_weighted_reward(
|
|
| 83 |
if not planted_keys:
|
| 84 |
return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
|
| 85 |
|
| 86 |
-
# Sum difficulty weights for found vs missed issues
|
| 87 |
found_keys = reported_keys & planted_keys
|
| 88 |
missed_keys = planted_keys - reported_keys
|
| 89 |
false_positive_count = len(reported_keys - planted_keys)
|
|
@@ -92,15 +114,12 @@ def compute_weighted_reward(
|
|
| 92 |
difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
|
| 93 |
total_weight = sum(i.difficulty for i in planted_issues)
|
| 94 |
|
| 95 |
-
# Weighted recall: proportion of difficulty captured
|
| 96 |
weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
|
| 97 |
|
| 98 |
-
# Penalty for false positives (scaled by avg difficulty so penalty is proportional)
|
| 99 |
avg_difficulty = total_weight / len(planted_issues)
|
| 100 |
fp_penalty_weight = false_positive_count * avg_difficulty
|
| 101 |
weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
|
| 102 |
|
| 103 |
-
# Weighted F1
|
| 104 |
if (weighted_precision + weighted_recall) > 0:
|
| 105 |
weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
|
| 106 |
else:
|
|
@@ -113,12 +132,134 @@ def compute_weighted_reward(
|
|
| 113 |
}
|
| 114 |
|
| 115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 116 |
class DataQAEnvironment(Environment):
|
| 117 |
"""
|
| 118 |
-
Data Quality Assurance environment.
|
| 119 |
|
| 120 |
-
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
| 122 |
"""
|
| 123 |
|
| 124 |
SUPPORTS_CONCURRENT_SESSIONS = True
|
|
@@ -158,7 +299,11 @@ class DataQAEnvironment(Environment):
|
|
| 158 |
schema_description=self._current_task.schema_description,
|
| 159 |
validation_rules=self._current_task.validation_rules,
|
| 160 |
task_description=self._current_task.description,
|
| 161 |
-
feedback=
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
task_id=task_id,
|
| 163 |
num_issues_hint=len(self._current_task.planted_issues),
|
| 164 |
max_steps=self._current_task.max_steps,
|
|
@@ -175,15 +320,14 @@ class DataQAEnvironment(Environment):
|
|
| 175 |
if not isinstance(action, DataQAAction):
|
| 176 |
raise ValueError(f"Expected DataQAAction, got {type(action)}")
|
| 177 |
|
| 178 |
-
#
|
| 179 |
-
# Auto-reset using the task_id from the action so step() works standalone.
|
| 180 |
if self._current_task is None:
|
| 181 |
self.reset(task_id=action.task_id)
|
| 182 |
|
| 183 |
self._state.step_count += 1
|
| 184 |
self._state.current_step += 1
|
| 185 |
|
| 186 |
-
# Parse
|
| 187 |
reported_keys: Set[str] = set()
|
| 188 |
parse_errors: list[str] = []
|
| 189 |
for raw_issue in action.issues:
|
|
@@ -191,51 +335,148 @@ class DataQAEnvironment(Environment):
|
|
| 191 |
if key:
|
| 192 |
reported_keys.add(key)
|
| 193 |
else:
|
| 194 |
-
parse_errors.append(f"Could not parse: '{raw_issue}'")
|
| 195 |
|
| 196 |
-
# Compute score (standard F1)
|
| 197 |
metrics = compute_f1(reported_keys, self._planted_keys)
|
| 198 |
-
|
| 199 |
|
| 200 |
-
# Compute difficulty-weighted reward (richer per-step signal)
|
| 201 |
weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
|
| 202 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
|
| 204 |
-
#
|
| 205 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 206 |
self._state.best_score = self._best_score
|
| 207 |
|
| 208 |
-
# Check if done
|
| 209 |
is_done = (
|
| 210 |
-
|
| 211 |
or self._state.current_step >= self._state.max_steps
|
| 212 |
)
|
| 213 |
|
| 214 |
-
# Build feedback
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
feedback_lines = [
|
| 216 |
f"Step {self._state.current_step}/{self._state.max_steps}",
|
|
|
|
|
|
|
| 217 |
f"Issues reported: {len(reported_keys)}",
|
| 218 |
f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
|
| 219 |
-
f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {
|
| 220 |
-
f"
|
| 221 |
]
|
| 222 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
if parse_errors:
|
| 224 |
-
feedback_lines.append(f"
|
| 225 |
|
| 226 |
if not is_done:
|
| 227 |
-
# Give hints about what was missed without revealing exact answers
|
| 228 |
if metrics["fn"] > 0:
|
| 229 |
feedback_lines.append(
|
| 230 |
-
f"
|
| 231 |
)
|
| 232 |
if metrics["fp"] > 0:
|
| 233 |
feedback_lines.append(
|
| 234 |
-
f"{metrics['fp']}
|
| 235 |
)
|
| 236 |
-
feedback_lines.append("You can submit again with
|
| 237 |
else:
|
| 238 |
-
feedback_lines.append(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
return DataQAObservation(
|
| 241 |
dataset_csv=self._current_task.corrupted_csv,
|
|
@@ -249,8 +490,10 @@ class DataQAEnvironment(Environment):
|
|
| 249 |
done=is_done,
|
| 250 |
reward=self._best_score,
|
| 251 |
metadata={
|
| 252 |
-
"
|
| 253 |
-
"
|
|
|
|
|
|
|
| 254 |
"precision": metrics["precision"],
|
| 255 |
"recall": metrics["recall"],
|
| 256 |
"tp": metrics["tp"],
|
|
@@ -258,6 +501,12 @@ class DataQAEnvironment(Environment):
|
|
| 258 |
"fn": metrics["fn"],
|
| 259 |
"difficulty_found": weighted["difficulty_found"],
|
| 260 |
"difficulty_missed": weighted["difficulty_missed"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
},
|
| 262 |
)
|
| 263 |
|
|
|
|
| 3 |
------------------
|
| 4 |
Server-side environment for data quality assurance tasks.
|
| 5 |
|
| 6 |
+
Two-phase RL environment:
|
| 7 |
+
Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
|
| 8 |
+
Phase 2 (Fix): Agent proposes corrections for identified issues.
|
| 9 |
+
|
| 10 |
+
Combined reward = 0.6 * identify_score + 0.4 * fix_score
|
| 11 |
+
Both phases scored with difficulty-weighted metrics for rich per-step signal.
|
| 12 |
"""
|
| 13 |
|
| 14 |
from __future__ import annotations
|
|
|
|
| 22 |
from ..models import DataQAAction, DataQAObservation, DataQAState
|
| 23 |
from .tasks import PlantedIssue, Task, get_task, list_tasks
|
| 24 |
|
| 25 |
+
# Reward weights for the two phases
|
| 26 |
+
IDENTIFY_WEIGHT = 0.6
|
| 27 |
+
FIX_WEIGHT = 0.4
|
| 28 |
+
|
| 29 |
|
| 30 |
def parse_issue_key(raw: str) -> Optional[str]:
|
| 31 |
"""
|
|
|
|
| 34 |
Returns normalized key or None if unparseable.
|
| 35 |
"""
|
| 36 |
raw = raw.strip().lower()
|
|
|
|
| 37 |
row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
|
| 38 |
col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
|
| 39 |
issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
|
|
|
|
| 43 |
return None
|
| 44 |
|
| 45 |
|
| 46 |
+
def parse_fix(raw: str) -> Optional[tuple[int, str, str]]:
|
| 47 |
+
"""
|
| 48 |
+
Parse an agent-proposed fix into (row, col, proposed_value).
|
| 49 |
+
Expected format: row:<N>,col:<name>,fix:<value>
|
| 50 |
+
Returns (row, col, value) or None if unparseable.
|
| 51 |
+
"""
|
| 52 |
+
raw = raw.strip()
|
| 53 |
+
row_match = re.search(r"row\s*[:=]\s*(\d+)", raw, re.IGNORECASE)
|
| 54 |
+
col_match = re.search(r"col(?:umn)?\s*[:=]\s*([\w_]+)", raw, re.IGNORECASE)
|
| 55 |
+
fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", raw, re.IGNORECASE)
|
| 56 |
+
|
| 57 |
+
if row_match and col_match and fix_match:
|
| 58 |
+
return (int(row_match.group(1)), col_match.group(1).lower(), fix_match.group(1).strip())
|
| 59 |
+
return None
|
| 60 |
+
|
| 61 |
+
|
| 62 |
def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
|
| 63 |
"""Compute precision, recall, and F1 score."""
|
| 64 |
if not reported_keys and not planted_keys:
|
|
|
|
| 106 |
if not planted_keys:
|
| 107 |
return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
|
| 108 |
|
|
|
|
| 109 |
found_keys = reported_keys & planted_keys
|
| 110 |
missed_keys = planted_keys - reported_keys
|
| 111 |
false_positive_count = len(reported_keys - planted_keys)
|
|
|
|
| 114 |
difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
|
| 115 |
total_weight = sum(i.difficulty for i in planted_issues)
|
| 116 |
|
|
|
|
| 117 |
weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
|
| 118 |
|
|
|
|
| 119 |
avg_difficulty = total_weight / len(planted_issues)
|
| 120 |
fp_penalty_weight = false_positive_count * avg_difficulty
|
| 121 |
weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
|
| 122 |
|
|
|
|
| 123 |
if (weighted_precision + weighted_recall) > 0:
|
| 124 |
weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
|
| 125 |
else:
|
|
|
|
| 132 |
}
|
| 133 |
|
| 134 |
|
| 135 |
+
def grade_fixes(
|
| 136 |
+
fixes: list[tuple[int, str, str]],
|
| 137 |
+
task: Task,
|
| 138 |
+
) -> dict:
|
| 139 |
+
"""
|
| 140 |
+
Grade proposed fixes against the clean dataset.
|
| 141 |
+
|
| 142 |
+
For each fix (row, col, proposed_value), compare to the original clean value.
|
| 143 |
+
Scoring per fix:
|
| 144 |
+
- Exact match (case-insensitive, whitespace-stripped): 1.0
|
| 145 |
+
- Numeric close match (within 1%): 0.8
|
| 146 |
+
- Correct column but wrong value: 0.1
|
| 147 |
+
- Targets a non-issue cell: 0.0 (penalty)
|
| 148 |
+
|
| 149 |
+
Returns dict with fix_score (0.0-1.0), details per fix, and counts.
|
| 150 |
+
"""
|
| 151 |
+
if not fixes and not task.planted_issues:
|
| 152 |
+
return {"fix_score": 1.0, "fixes_correct": 0, "fixes_partial": 0,
|
| 153 |
+
"fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
|
| 154 |
+
|
| 155 |
+
if not fixes:
|
| 156 |
+
return {"fix_score": 0.0, "fixes_correct": 0, "fixes_partial": 0,
|
| 157 |
+
"fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}
|
| 158 |
+
|
| 159 |
+
issue_map = task.get_planted_issue_map()
|
| 160 |
+
# Build set of (row, col) that are actual issues
|
| 161 |
+
issue_cells = {(issue.row, issue.col) for issue in task.planted_issues}
|
| 162 |
+
|
| 163 |
+
total_weight = sum(i.difficulty for i in task.planted_issues) if task.planted_issues else 1.0
|
| 164 |
+
earned_weight = 0.0
|
| 165 |
+
fixes_correct = 0
|
| 166 |
+
fixes_partial = 0
|
| 167 |
+
fixes_wrong = 0
|
| 168 |
+
fix_details = []
|
| 169 |
+
|
| 170 |
+
# Track which issues have been fixed (best fix wins)
|
| 171 |
+
fixed_issues: dict[tuple[int, str], float] = {}
|
| 172 |
+
|
| 173 |
+
for row, col, proposed in fixes:
|
| 174 |
+
clean_value = task.get_clean_value(row, col)
|
| 175 |
+
cell_key = (row, col)
|
| 176 |
+
|
| 177 |
+
if cell_key not in issue_cells:
|
| 178 |
+
# Fix targets a non-issue cell — no credit
|
| 179 |
+
fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "not an issue cell"})
|
| 180 |
+
fixes_wrong += 1
|
| 181 |
+
continue
|
| 182 |
+
|
| 183 |
+
if clean_value is None:
|
| 184 |
+
fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "cell not found"})
|
| 185 |
+
fixes_wrong += 1
|
| 186 |
+
continue
|
| 187 |
+
|
| 188 |
+
# Find the planted issue for this cell to get its difficulty weight
|
| 189 |
+
matching_issue = None
|
| 190 |
+
for issue in task.planted_issues:
|
| 191 |
+
if issue.row == row and issue.col == col:
|
| 192 |
+
matching_issue = issue
|
| 193 |
+
break
|
| 194 |
+
|
| 195 |
+
difficulty = matching_issue.difficulty if matching_issue else 1.0
|
| 196 |
+
|
| 197 |
+
# Score the fix
|
| 198 |
+
score = 0.0
|
| 199 |
+
reason = "wrong value"
|
| 200 |
+
|
| 201 |
+
# Exact match (case-insensitive, whitespace-stripped)
|
| 202 |
+
if proposed.strip().lower() == clean_value.lower():
|
| 203 |
+
score = 1.0
|
| 204 |
+
reason = "exact match"
|
| 205 |
+
fixes_correct += 1
|
| 206 |
+
else:
|
| 207 |
+
# Try numeric close match
|
| 208 |
+
try:
|
| 209 |
+
proposed_num = float(proposed.strip())
|
| 210 |
+
clean_num = float(clean_value)
|
| 211 |
+
if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
|
| 212 |
+
score = 0.8
|
| 213 |
+
reason = "numeric close match"
|
| 214 |
+
fixes_partial += 1
|
| 215 |
+
elif proposed_num == clean_num:
|
| 216 |
+
score = 1.0
|
| 217 |
+
reason = "exact numeric match"
|
| 218 |
+
fixes_correct += 1
|
| 219 |
+
else:
|
| 220 |
+
score = 0.1
|
| 221 |
+
reason = "correct cell, wrong value"
|
| 222 |
+
fixes_partial += 1
|
| 223 |
+
except (ValueError, ZeroDivisionError):
|
| 224 |
+
# Not numeric — just a wrong value but at least right cell
|
| 225 |
+
score = 0.1
|
| 226 |
+
reason = "correct cell, wrong value"
|
| 227 |
+
fixes_partial += 1
|
| 228 |
+
|
| 229 |
+
# Keep best fix per cell
|
| 230 |
+
if cell_key not in fixed_issues or score > fixed_issues[cell_key]:
|
| 231 |
+
fixed_issues[cell_key] = score
|
| 232 |
+
|
| 233 |
+
fix_details.append({"row": row, "col": col, "score": score, "reason": reason})
|
| 234 |
+
|
| 235 |
+
# Compute fix score: weighted sum of best fix per issue / total weight
|
| 236 |
+
for issue in task.planted_issues:
|
| 237 |
+
cell_key = (issue.row, issue.col)
|
| 238 |
+
if cell_key in fixed_issues:
|
| 239 |
+
earned_weight += issue.difficulty * fixed_issues[cell_key]
|
| 240 |
+
|
| 241 |
+
fix_score = earned_weight / total_weight if total_weight > 0 else 0.0
|
| 242 |
+
fix_score = min(max(fix_score, 0.0), 1.0)
|
| 243 |
+
|
| 244 |
+
return {
|
| 245 |
+
"fix_score": round(fix_score, 4),
|
| 246 |
+
"fixes_correct": fixes_correct,
|
| 247 |
+
"fixes_partial": fixes_partial,
|
| 248 |
+
"fixes_wrong": fixes_wrong,
|
| 249 |
+
"fixes_attempted": len(fixes),
|
| 250 |
+
"fix_details": fix_details,
|
| 251 |
+
}
|
| 252 |
+
|
| 253 |
+
|
| 254 |
class DataQAEnvironment(Environment):
|
| 255 |
"""
|
| 256 |
+
Data Quality Assurance environment — two-phase identify + fix.
|
| 257 |
|
| 258 |
+
Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues.
|
| 259 |
+
Phase 2 (Fix): Agent proposes corrections for identified issues.
|
| 260 |
+
|
| 261 |
+
Combined reward = 0.6 * identify_score + 0.4 * fix_score
|
| 262 |
+
Both phases use difficulty-weighted scoring for rich per-step reward signals.
|
| 263 |
"""
|
| 264 |
|
| 265 |
SUPPORTS_CONCURRENT_SESSIONS = True
|
|
|
|
| 299 |
schema_description=self._current_task.schema_description,
|
| 300 |
validation_rules=self._current_task.validation_rules,
|
| 301 |
task_description=self._current_task.description,
|
| 302 |
+
feedback=(
|
| 303 |
+
"Environment reset. Inspect the dataset and report all quality issues.\n"
|
| 304 |
+
"You can also propose fixes in format: row:<N>,col:<name>,fix:<corrected_value>\n"
|
| 305 |
+
"Combined reward = 0.6 * identify_score + 0.4 * fix_score"
|
| 306 |
+
),
|
| 307 |
task_id=task_id,
|
| 308 |
num_issues_hint=len(self._current_task.planted_issues),
|
| 309 |
max_steps=self._current_task.max_steps,
|
|
|
|
| 320 |
if not isinstance(action, DataQAAction):
|
| 321 |
raise ValueError(f"Expected DataQAAction, got {type(action)}")
|
| 322 |
|
| 323 |
+
# Auto-reset in stateless HTTP mode
|
|
|
|
| 324 |
if self._current_task is None:
|
| 325 |
self.reset(task_id=action.task_id)
|
| 326 |
|
| 327 |
self._state.step_count += 1
|
| 328 |
self._state.current_step += 1
|
| 329 |
|
| 330 |
+
# ─�� Phase 1: Parse and score issue identification ──
|
| 331 |
reported_keys: Set[str] = set()
|
| 332 |
parse_errors: list[str] = []
|
| 333 |
for raw_issue in action.issues:
|
|
|
|
| 335 |
if key:
|
| 336 |
reported_keys.add(key)
|
| 337 |
else:
|
| 338 |
+
parse_errors.append(f"Could not parse issue: '{raw_issue}'")
|
| 339 |
|
|
|
|
| 340 |
metrics = compute_f1(reported_keys, self._planted_keys)
|
| 341 |
+
identify_f1 = metrics["f1"]
|
| 342 |
|
|
|
|
| 343 |
weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
|
| 344 |
+
identify_score = weighted["weighted_reward"]
|
| 345 |
+
|
| 346 |
+
# ── Phase 2: Parse and score proposed fixes ──
|
| 347 |
+
parsed_fixes: list[tuple[int, str, str]] = []
|
| 348 |
+
for raw_fix in action.fixes:
|
| 349 |
+
fix = parse_fix(raw_fix)
|
| 350 |
+
if fix:
|
| 351 |
+
parsed_fixes.append(fix)
|
| 352 |
+
else:
|
| 353 |
+
parse_errors.append(f"Could not parse fix: '{raw_fix}'")
|
| 354 |
+
|
| 355 |
+
fix_result = grade_fixes(parsed_fixes, self._current_task)
|
| 356 |
+
fix_score = fix_result["fix_score"]
|
| 357 |
|
| 358 |
+
# ── Combined reward ──
|
| 359 |
+
# If no fixes submitted, score is identify-only (no penalty for not fixing)
|
| 360 |
+
if action.fixes:
|
| 361 |
+
combined_reward = IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
|
| 362 |
+
else:
|
| 363 |
+
combined_reward = identify_score # backward compatible
|
| 364 |
+
|
| 365 |
+
self._best_score = max(self._best_score, combined_reward)
|
| 366 |
self._state.best_score = self._best_score
|
| 367 |
|
| 368 |
+
# ── Check if done ──
|
| 369 |
is_done = (
|
| 370 |
+
identify_f1 >= 0.999 # Perfect identification
|
| 371 |
or self._state.current_step >= self._state.max_steps
|
| 372 |
)
|
| 373 |
|
| 374 |
+
# ── Build feedback with actionable diagnostics ──
|
| 375 |
+
# Show the agent exactly which reported issues were correct (TP) and which were wrong (FP)
|
| 376 |
+
tp_keys = reported_keys & self._planted_keys
|
| 377 |
+
fp_keys = reported_keys - self._planted_keys
|
| 378 |
+
|
| 379 |
feedback_lines = [
|
| 380 |
f"Step {self._state.current_step}/{self._state.max_steps}",
|
| 381 |
+
"",
|
| 382 |
+
"--- Identification ---",
|
| 383 |
f"Issues reported: {len(reported_keys)}",
|
| 384 |
f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
|
| 385 |
+
f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {identify_f1:.3f}",
|
| 386 |
+
f"Identify score (weighted): {identify_score:.3f}",
|
| 387 |
]
|
| 388 |
|
| 389 |
+
# Show which reported issues were correct vs wrong (helps agent self-correct)
|
| 390 |
+
if tp_keys:
|
| 391 |
+
feedback_lines.append(f"Correct issues: {', '.join(sorted(tp_keys))}")
|
| 392 |
+
if fp_keys:
|
| 393 |
+
feedback_lines.append(f"Incorrect issues (false positives): {', '.join(sorted(fp_keys))}")
|
| 394 |
+
|
| 395 |
+
if action.fixes:
|
| 396 |
+
feedback_lines += [
|
| 397 |
+
"",
|
| 398 |
+
"--- Fix Proposals ---",
|
| 399 |
+
f"Fixes attempted: {fix_result['fixes_attempted']}",
|
| 400 |
+
f"Correct: {fix_result['fixes_correct']}, Partial: {fix_result['fixes_partial']}, Wrong: {fix_result['fixes_wrong']}",
|
| 401 |
+
f"Fix score: {fix_score:.3f}",
|
| 402 |
+
]
|
| 403 |
+
# Show per-fix feedback so agent knows which fixes worked
|
| 404 |
+
for detail in fix_result["fix_details"]:
|
| 405 |
+
status = "correct" if detail["score"] >= 0.99 else ("partial" if detail["score"] > 0 else "wrong")
|
| 406 |
+
feedback_lines.append(
|
| 407 |
+
f" row:{detail['row']},col:{detail['col']} -> {status} ({detail['reason']})"
|
| 408 |
+
)
|
| 409 |
+
feedback_lines.append(
|
| 410 |
+
f"\n--- Combined Reward: {combined_reward:.3f} (identify={identify_score:.3f} x {IDENTIFY_WEIGHT} + fix={fix_score:.3f} x {FIX_WEIGHT}) ---"
|
| 411 |
+
)
|
| 412 |
+
else:
|
| 413 |
+
feedback_lines += [
|
| 414 |
+
"",
|
| 415 |
+
"Tip: Submit fixes with format row:<N>,col:<name>,fix:<value> for bonus reward.",
|
| 416 |
+
]
|
| 417 |
+
|
| 418 |
if parse_errors:
|
| 419 |
+
feedback_lines.append(f"\nParse errors ({len(parse_errors)}): {'; '.join(parse_errors[:5])}")
|
| 420 |
|
| 421 |
if not is_done:
|
|
|
|
| 422 |
if metrics["fn"] > 0:
|
| 423 |
feedback_lines.append(
|
| 424 |
+
f"\nYou missed {metrics['fn']} issue(s). Review the dataset carefully."
|
| 425 |
)
|
| 426 |
if metrics["fp"] > 0:
|
| 427 |
feedback_lines.append(
|
| 428 |
+
f"Remove the {metrics['fp']} false positive(s) listed above and look for real issues."
|
| 429 |
)
|
| 430 |
+
feedback_lines.append("You can submit again with updated issues and/or fixes.")
|
| 431 |
else:
|
| 432 |
+
feedback_lines.append(f"\nTask complete! Final best reward: {self._best_score:.3f}")
|
| 433 |
+
|
| 434 |
+
# ── Flag items for human review ──
|
| 435 |
+
# In a production data QA pipeline, these would go to a human reviewer.
|
| 436 |
+
# The grader flags cases where automated scoring has low confidence.
|
| 437 |
+
human_review_flags: list[dict] = []
|
| 438 |
+
|
| 439 |
+
# 1. False positives that target real columns — could be legitimate issues
|
| 440 |
+
# the task designer didn't plant (agent may be smarter than the grader)
|
| 441 |
+
issue_map = self._current_task.get_planted_issue_map()
|
| 442 |
+
valid_issue_types = {"missing_value", "wrong_type", "duplicate_row", "out_of_range",
|
| 443 |
+
"format_violation", "inconsistent_value", "statistical_outlier",
|
| 444 |
+
"referential_integrity"}
|
| 445 |
+
for fp_key in fp_keys:
|
| 446 |
+
parts = fp_key.split(",")
|
| 447 |
+
itype = parts[2].split(":")[1] if len(parts) >= 3 else ""
|
| 448 |
+
if itype in valid_issue_types:
|
| 449 |
+
human_review_flags.append({
|
| 450 |
+
"item": fp_key,
|
| 451 |
+
"reason": "Agent reported this issue but it's not in ground truth — may be a real issue the grader missed",
|
| 452 |
+
"type": "possible_unplanted_issue",
|
| 453 |
+
})
|
| 454 |
+
|
| 455 |
+
# 2. Partial fix matches — fix was close but not exact, human should verify
|
| 456 |
+
for detail in fix_result["fix_details"]:
|
| 457 |
+
if 0 < detail["score"] < 0.99:
|
| 458 |
+
human_review_flags.append({
|
| 459 |
+
"item": f"row:{detail['row']},col:{detail['col']}",
|
| 460 |
+
"reason": f"Fix scored {detail['score']:.2f} ({detail['reason']}) — human should verify if acceptable",
|
| 461 |
+
"type": "partial_fix",
|
| 462 |
+
})
|
| 463 |
+
|
| 464 |
+
# 3. High-difficulty issues that were missed — flag for training data review
|
| 465 |
+
planted_by_key = {i.to_key(): i for i in self._current_task.planted_issues}
|
| 466 |
+
fn_keys = self._planted_keys - reported_keys
|
| 467 |
+
for fn_key in fn_keys:
|
| 468 |
+
issue = planted_by_key.get(fn_key)
|
| 469 |
+
if issue and issue.difficulty >= 2.5:
|
| 470 |
+
human_review_flags.append({
|
| 471 |
+
"item": fn_key,
|
| 472 |
+
"reason": f"High-difficulty issue (difficulty={issue.difficulty}) missed — {issue.description}",
|
| 473 |
+
"type": "missed_hard_issue",
|
| 474 |
+
})
|
| 475 |
+
|
| 476 |
+
if human_review_flags:
|
| 477 |
+
feedback_lines.append(f"\n--- Flagged for Human Review ({len(human_review_flags)}) ---")
|
| 478 |
+
for flag in human_review_flags:
|
| 479 |
+
feedback_lines.append(f" [{flag['type']}] {flag['item']}: {flag['reason']}")
|
| 480 |
|
| 481 |
return DataQAObservation(
|
| 482 |
dataset_csv=self._current_task.corrupted_csv,
|
|
|
|
| 490 |
done=is_done,
|
| 491 |
reward=self._best_score,
|
| 492 |
metadata={
|
| 493 |
+
"identify_f1": identify_f1,
|
| 494 |
+
"identify_score": identify_score,
|
| 495 |
+
"fix_score": fix_score,
|
| 496 |
+
"combined_reward": combined_reward,
|
| 497 |
"precision": metrics["precision"],
|
| 498 |
"recall": metrics["recall"],
|
| 499 |
"tp": metrics["tp"],
|
|
|
|
| 501 |
"fn": metrics["fn"],
|
| 502 |
"difficulty_found": weighted["difficulty_found"],
|
| 503 |
"difficulty_missed": weighted["difficulty_missed"],
|
| 504 |
+
"fixes_correct": fix_result["fixes_correct"],
|
| 505 |
+
"fixes_partial": fix_result["fixes_partial"],
|
| 506 |
+
"fixes_wrong": fix_result["fixes_wrong"],
|
| 507 |
+
"fixes_attempted": fix_result["fixes_attempted"],
|
| 508 |
+
"fix_details": fix_result["fix_details"],
|
| 509 |
+
"human_review_flags": human_review_flags,
|
| 510 |
},
|
| 511 |
)
|
| 512 |
|
dataqa_env/server/gradio_ui.py
ADDED
|
@@ -0,0 +1,487 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio UI — Agent Trajectory Replay Viewer for DataQA.
|
| 3 |
+
|
| 4 |
+
Designed for judges: zero clicks needed, auto-plays on load.
|
| 5 |
+
Tab per task, step slider, prominent metric cards, color-coded dataset.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
import csv
|
| 11 |
+
import io
|
| 12 |
+
|
| 13 |
+
import gradio as gr
|
| 14 |
+
|
| 15 |
+
from .environment import DataQAEnvironment, parse_issue_key
|
| 16 |
+
from .tasks import list_tasks, PlantedIssue
|
| 17 |
+
from ..models import DataQAAction
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
# ── Pre-built agent trajectories (simulates baseline agent) ──
|
| 21 |
+
|
| 22 |
+
AGENT_TRAJECTORIES = {
|
| 23 |
+
"easy": [
|
| 24 |
+
{
|
| 25 |
+
"issues": [
|
| 26 |
+
"row:4,col:name,issue:missing_value",
|
| 27 |
+
"row:7,col:salary,issue:wrong_type",
|
| 28 |
+
"row:9,col:salary,issue:out_of_range",
|
| 29 |
+
"row:3,col:email,issue:format_violation", # FP
|
| 30 |
+
],
|
| 31 |
+
"fixes": [],
|
| 32 |
+
},
|
| 33 |
+
{
|
| 34 |
+
"issues": [
|
| 35 |
+
"row:4,col:name,issue:missing_value",
|
| 36 |
+
"row:7,col:salary,issue:wrong_type",
|
| 37 |
+
"row:9,col:salary,issue:out_of_range",
|
| 38 |
+
"row:11,col:employee_id,issue:duplicate_row",
|
| 39 |
+
],
|
| 40 |
+
"fixes": [
|
| 41 |
+
"row:4,col:name,fix:David Kim",
|
| 42 |
+
"row:7,col:salary,fix:75000",
|
| 43 |
+
"row:9,col:salary,fix:73000",
|
| 44 |
+
],
|
| 45 |
+
},
|
| 46 |
+
],
|
| 47 |
+
"medium": [
|
| 48 |
+
{
|
| 49 |
+
"issues": [
|
| 50 |
+
"row:5,col:total,issue:inconsistent_value",
|
| 51 |
+
"row:10,col:category,issue:format_violation",
|
| 52 |
+
"row:14,col:product_name,issue:missing_value",
|
| 53 |
+
"row:17,col:quantity,issue:out_of_range",
|
| 54 |
+
"row:19,col:order_id,issue:duplicate_row",
|
| 55 |
+
"row:12,col:order_date,issue:format_violation",
|
| 56 |
+
],
|
| 57 |
+
"fixes": [
|
| 58 |
+
"row:5,col:total,fix:42.00",
|
| 59 |
+
"row:10,col:category,fix:Sports",
|
| 60 |
+
"row:12,col:order_date,fix:2024-01-26",
|
| 61 |
+
"row:14,col:product_name,fix:LED Strip Lights",
|
| 62 |
+
],
|
| 63 |
+
},
|
| 64 |
+
],
|
| 65 |
+
"hard": [
|
| 66 |
+
{
|
| 67 |
+
"issues": [
|
| 68 |
+
"row:14,col:training_time_hours,issue:out_of_range",
|
| 69 |
+
"row:13,col:learning_rate,issue:out_of_range",
|
| 70 |
+
"row:15,col:model_name,issue:missing_value",
|
| 71 |
+
"row:9,col:batch_size,issue:format_violation",
|
| 72 |
+
"row:10,col:train_size,issue:inconsistent_value",
|
| 73 |
+
],
|
| 74 |
+
"fixes": [],
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"issues": [
|
| 78 |
+
"row:14,col:training_time_hours,issue:out_of_range",
|
| 79 |
+
"row:13,col:learning_rate,issue:out_of_range",
|
| 80 |
+
"row:15,col:model_name,issue:missing_value",
|
| 81 |
+
"row:9,col:batch_size,issue:format_violation",
|
| 82 |
+
"row:10,col:train_size,issue:inconsistent_value",
|
| 83 |
+
"row:5,col:val_loss,issue:inconsistent_value",
|
| 84 |
+
"row:7,col:gpu_memory_gb,issue:statistical_outlier",
|
| 85 |
+
"row:11,col:timestamp,issue:inconsistent_value",
|
| 86 |
+
"row:9,col:training_time_hours,issue:statistical_outlier",
|
| 87 |
+
"row:12,col:test_accuracy,issue:statistical_outlier",
|
| 88 |
+
],
|
| 89 |
+
"fixes": [
|
| 90 |
+
"row:14,col:training_time_hours,fix:72.0",
|
| 91 |
+
"row:13,col:learning_rate,fix:0.00001",
|
| 92 |
+
"row:15,col:model_name,fix:whisper-small",
|
| 93 |
+
"row:9,col:batch_size,fix:256",
|
| 94 |
+
"row:9,col:training_time_hours,fix:36.0",
|
| 95 |
+
],
|
| 96 |
+
},
|
| 97 |
+
],
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ── HTML rendering ──
|
| 102 |
+
|
| 103 |
+
def _metric_card(label: str, value: str, color: str = "#333") -> str:
|
| 104 |
+
return (
|
| 105 |
+
f'<div style="text-align:center;padding:12px 16px;background:#f8f9fa;'
|
| 106 |
+
f'border-radius:8px;min-width:100px;">'
|
| 107 |
+
f'<div style="font-size:11px;color:#666;text-transform:uppercase;letter-spacing:1px;">{label}</div>'
|
| 108 |
+
f'<div style="font-size:28px;font-weight:700;color:{color};margin-top:2px;">{value}</div>'
|
| 109 |
+
f'</div>'
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _csv_to_html(
|
| 114 |
+
csv_text: str,
|
| 115 |
+
planted: list[PlantedIssue],
|
| 116 |
+
correct: set[tuple[int, str]],
|
| 117 |
+
fp: set[tuple[int, str]],
|
| 118 |
+
missed: set[tuple[int, str]],
|
| 119 |
+
fixed: dict[tuple[int, str], str],
|
| 120 |
+
fix_values: dict[tuple[int, str], str] | None = None,
|
| 121 |
+
) -> str:
|
| 122 |
+
"""Render CSV as HTML with color-coded cells and inline fix proposals."""
|
| 123 |
+
fix_values = fix_values or {}
|
| 124 |
+
desc_map = {(i.row, i.col): i for i in planted}
|
| 125 |
+
reader = csv.reader(io.StringIO(csv_text.strip()))
|
| 126 |
+
rows = list(reader)
|
| 127 |
+
if not rows:
|
| 128 |
+
return ""
|
| 129 |
+
|
| 130 |
+
header = rows[0]
|
| 131 |
+
header_lower = [h.strip().lower() for h in header]
|
| 132 |
+
data = rows[1:]
|
| 133 |
+
|
| 134 |
+
t = ['<table style="border-collapse:collapse;width:100%;font-size:12px;font-family:\'SF Mono\',monospace;">']
|
| 135 |
+
t.append('<tr>')
|
| 136 |
+
t.append('<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">Row</th>')
|
| 137 |
+
for h in header:
|
| 138 |
+
t.append(f'<th style="border:1px solid #dee2e6;padding:6px 8px;background:#343a40;color:#fff;font-size:11px;">{h}</th>')
|
| 139 |
+
t.append('</tr>')
|
| 140 |
+
|
| 141 |
+
for i, row in enumerate(data):
|
| 142 |
+
rn = i + 1
|
| 143 |
+
bg = "#fff" if i % 2 == 0 else "#f8f9fa"
|
| 144 |
+
t.append(f'<tr style="background:{bg};">')
|
| 145 |
+
t.append(f'<td style="border:1px solid #dee2e6;padding:4px 8px;color:#adb5bd;text-align:center;font-size:11px;">{rn}</td>')
|
| 146 |
+
for j, val in enumerate(row):
|
| 147 |
+
col = header_lower[j] if j < len(header_lower) else ""
|
| 148 |
+
ck = (rn, col)
|
| 149 |
+
s = "border:1px solid #dee2e6;padding:4px 8px;"
|
| 150 |
+
tip = ""
|
| 151 |
+
badge = ""
|
| 152 |
+
|
| 153 |
+
issue = desc_map.get(ck)
|
| 154 |
+
|
| 155 |
+
if ck in correct:
|
| 156 |
+
s += "background:#d4edda;"
|
| 157 |
+
tip = f"FOUND: {issue.description}" if issue else ""
|
| 158 |
+
badge = '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">TP</span>'
|
| 159 |
+
elif ck in fp:
|
| 160 |
+
s += "background:#f8d7da;"
|
| 161 |
+
badge = '<span style="font-size:9px;background:#dc3545;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">FP</span>'
|
| 162 |
+
elif ck in missed:
|
| 163 |
+
s += "background:#fff3cd;"
|
| 164 |
+
tip = f"MISSED: {issue.description}" if issue else ""
|
| 165 |
+
badge = '<span style="font-size:9px;background:#856404;color:#fff;padding:1px 4px;border-radius:3px;margin-left:4px;">MISS</span>'
|
| 166 |
+
|
| 167 |
+
fx = fixed.get(ck)
|
| 168 |
+
proposed = fix_values.get(ck)
|
| 169 |
+
if fx == "correct":
|
| 170 |
+
s += "box-shadow:inset 0 0 0 2px #28a745;"
|
| 171 |
+
badge += '<span style="font-size:9px;background:#28a745;color:#fff;padding:1px 4px;border-radius:3px;margin-left:2px;">FIX</span>'
|
| 172 |
+
elif fx == "partial":
|
| 173 |
+
s += "box-shadow:inset 0 0 0 2px #ffc107;"
|
| 174 |
+
badge += '<span style="font-size:9px;background:#ffc107;color:#333;padding:1px 4px;border-radius:3px;margin-left:2px;">~FIX</span>'
|
| 175 |
+
|
| 176 |
+
dv = val if val.strip() else '<em style="color:#dc3545;font-style:italic;">empty</em>'
|
| 177 |
+
|
| 178 |
+
# Show proposed fix value below the corrupted value
|
| 179 |
+
fix_line = ""
|
| 180 |
+
if proposed is not None:
|
| 181 |
+
fix_color = "#28a745" if fx == "correct" else ("#b8860b" if fx == "partial" else "#dc3545")
|
| 182 |
+
fix_line = (
|
| 183 |
+
f'<div style="font-size:10px;color:{fix_color};margin-top:2px;'
|
| 184 |
+
f'border-top:1px dashed {fix_color};padding-top:2px;">'
|
| 185 |
+
f'\u2192 {proposed}</div>'
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
t.append(f'<td style="{s}" title="{tip}">{dv}{badge}{fix_line}</td>')
|
| 189 |
+
t.append('</tr>')
|
| 190 |
+
t.append('</table>')
|
| 191 |
+
return "".join(t)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
LEGEND_HTML = (
|
| 195 |
+
'<div style="display:flex;gap:12px;flex-wrap:wrap;margin-top:10px;font-size:11px;">'
|
| 196 |
+
'<span style="background:#d4edda;padding:2px 8px;border-radius:4px;">Found (TP)</span>'
|
| 197 |
+
'<span style="background:#f8d7da;padding:2px 8px;border-radius:4px;">False Positive</span>'
|
| 198 |
+
'<span style="background:#fff3cd;padding:2px 8px;border-radius:4px;">Missed</span>'
|
| 199 |
+
'<span style="box-shadow:inset 0 0 0 2px #28a745;padding:2px 8px;border-radius:4px;">Fix Correct</span>'
|
| 200 |
+
'<span style="box-shadow:inset 0 0 0 2px #ffc107;padding:2px 8px;border-radius:4px;">Fix Partial</span>'
|
| 201 |
+
'</div>'
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
# ── Core replay logic ──
|
| 206 |
+
|
| 207 |
+
def _replay_task(task_id: str) -> list[dict]:
|
| 208 |
+
"""Run the agent trajectory and collect per-step data."""
|
| 209 |
+
env = DataQAEnvironment()
|
| 210 |
+
obs = env.reset(task_id=task_id)
|
| 211 |
+
task = env._current_task
|
| 212 |
+
planted_keys = {i.to_key() for i in task.planted_issues}
|
| 213 |
+
steps_data = []
|
| 214 |
+
|
| 215 |
+
# Step 0: initial state
|
| 216 |
+
steps_data.append({
|
| 217 |
+
"label": "Initial — corrupted dataset",
|
| 218 |
+
"html": _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {}),
|
| 219 |
+
"metrics": {"reward": 0.0, "tp": 0, "fp": 0, "fn": len(task.planted_issues),
|
| 220 |
+
"identify": 0.0, "fix": 0.0, "fixes_correct": 0},
|
| 221 |
+
"feedback": f"Task: {task.name}\nIssues to find: {obs.num_issues_hint}\n\n{task.description}",
|
| 222 |
+
})
|
| 223 |
+
|
| 224 |
+
trajectory = AGENT_TRAJECTORIES.get(task_id, [])
|
| 225 |
+
for i, step_data in enumerate(trajectory):
|
| 226 |
+
action = DataQAAction(
|
| 227 |
+
issues=step_data["issues"],
|
| 228 |
+
fixes=step_data.get("fixes", []),
|
| 229 |
+
task_id=task_id,
|
| 230 |
+
)
|
| 231 |
+
obs = env.step(action)
|
| 232 |
+
|
| 233 |
+
reported_keys = set()
|
| 234 |
+
for iss in step_data["issues"]:
|
| 235 |
+
key = parse_issue_key(iss)
|
| 236 |
+
if key:
|
| 237 |
+
reported_keys.add(key)
|
| 238 |
+
|
| 239 |
+
tp_keys = reported_keys & planted_keys
|
| 240 |
+
fp_keys = reported_keys - planted_keys
|
| 241 |
+
fn_keys = planted_keys - reported_keys
|
| 242 |
+
|
| 243 |
+
correct = {_kc(k) for k in tp_keys}
|
| 244 |
+
fp = {_kc(k) for k in fp_keys}
|
| 245 |
+
missed = {_kc(k) for k in fn_keys} if obs.done else set()
|
| 246 |
+
|
| 247 |
+
fixed: dict[tuple[int, str], str] = {}
|
| 248 |
+
for d in obs.metadata.get("fix_details", []):
|
| 249 |
+
c = (d["row"], d["col"])
|
| 250 |
+
fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
|
| 251 |
+
|
| 252 |
+
# Extract proposed fix values from the raw fix strings
|
| 253 |
+
fix_values: dict[tuple[int, str], str] = {}
|
| 254 |
+
from .environment import parse_fix
|
| 255 |
+
for raw_fix in step_data.get("fixes", []):
|
| 256 |
+
parsed = parse_fix(raw_fix)
|
| 257 |
+
if parsed:
|
| 258 |
+
row, col, val = parsed
|
| 259 |
+
fix_values[(row, col)] = val
|
| 260 |
+
|
| 261 |
+
html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp, missed, fixed, fix_values)
|
| 262 |
+
|
| 263 |
+
has_fixes = bool(step_data.get("fixes"))
|
| 264 |
+
if has_fixes:
|
| 265 |
+
label = f"Step {i+1} — identify + fix"
|
| 266 |
+
else:
|
| 267 |
+
label = f"Step {i+1} — identify only"
|
| 268 |
+
|
| 269 |
+
steps_data.append({
|
| 270 |
+
"label": label,
|
| 271 |
+
"html": html,
|
| 272 |
+
"metrics": {
|
| 273 |
+
"reward": obs.reward,
|
| 274 |
+
"tp": obs.metadata["tp"],
|
| 275 |
+
"fp": obs.metadata["fp"],
|
| 276 |
+
"fn": obs.metadata["fn"],
|
| 277 |
+
"identify": obs.metadata["identify_score"],
|
| 278 |
+
"fix": obs.metadata["fix_score"],
|
| 279 |
+
"fixes_correct": obs.metadata["fixes_correct"],
|
| 280 |
+
},
|
| 281 |
+
"feedback": obs.feedback,
|
| 282 |
+
})
|
| 283 |
+
|
| 284 |
+
return steps_data
|
| 285 |
+
|
| 286 |
+
|
| 287 |
+
def _kc(key: str) -> tuple[int, str]:
|
| 288 |
+
parts = key.split(",")
|
| 289 |
+
return (int(parts[0].split(":")[1]), parts[1].split(":")[1])
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
# ── Gradio app ──
|
| 293 |
+
|
| 294 |
+
def build_gradio_ui():
|
| 295 |
+
# Pre-compute all replays at startup
|
| 296 |
+
all_replays: dict[str, list[dict]] = {}
|
| 297 |
+
for tid in list_tasks():
|
| 298 |
+
all_replays[tid] = _replay_task(tid)
|
| 299 |
+
|
| 300 |
+
def show_step(task_id: str, step_idx: int):
|
| 301 |
+
replay = all_replays.get(task_id, [])
|
| 302 |
+
step_idx = int(step_idx)
|
| 303 |
+
if step_idx >= len(replay):
|
| 304 |
+
step_idx = len(replay) - 1
|
| 305 |
+
sd = replay[step_idx]
|
| 306 |
+
m = sd["metrics"]
|
| 307 |
+
|
| 308 |
+
# Reward color
|
| 309 |
+
r = m["reward"]
|
| 310 |
+
rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
|
| 311 |
+
|
| 312 |
+
cards = (
|
| 313 |
+
'<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
|
| 314 |
+
+ _metric_card("Reward", f"{r:.2f}", rc)
|
| 315 |
+
+ _metric_card("Found", str(m["tp"]), "#28a745")
|
| 316 |
+
+ _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
|
| 317 |
+
+ _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
|
| 318 |
+
+ _metric_card("Identify", f"{m['identify']:.2f}", "#333")
|
| 319 |
+
+ _metric_card("Fix", f"{m['fix']:.2f}", "#333")
|
| 320 |
+
+ '</div>'
|
| 321 |
+
)
|
| 322 |
+
|
| 323 |
+
full_html = (
|
| 324 |
+
f'<div style="font-size:14px;font-weight:600;margin-bottom:8px;color:#495057;">'
|
| 325 |
+
f'{sd["label"]}</div>'
|
| 326 |
+
+ cards + sd["html"] + LEGEND_HTML
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
return full_html, sd["feedback"]
|
| 330 |
+
|
| 331 |
+
def on_task_change(task_id):
|
| 332 |
+
replay = all_replays.get(task_id, [])
|
| 333 |
+
max_step = len(replay) - 1
|
| 334 |
+
html, fb = show_step(task_id, 0)
|
| 335 |
+
return (
|
| 336 |
+
gr.update(maximum=max_step, value=0),
|
| 337 |
+
html,
|
| 338 |
+
fb,
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
def on_step_change(task_id, step_idx):
|
| 342 |
+
html, fb = show_step(task_id, step_idx)
|
| 343 |
+
return html, fb
|
| 344 |
+
|
| 345 |
+
# ── Live agent runner (connects to the env server) ──
|
| 346 |
+
|
| 347 |
+
live_env = DataQAEnvironment()
|
| 348 |
+
live_state: dict = {"obs": None, "task_id": "easy", "steps": []}
|
| 349 |
+
|
| 350 |
+
def live_reset(task_id):
|
| 351 |
+
obs = live_env.reset(task_id=task_id)
|
| 352 |
+
task = live_env._current_task
|
| 353 |
+
live_state["obs"] = obs
|
| 354 |
+
live_state["task_id"] = task_id
|
| 355 |
+
live_state["steps"] = []
|
| 356 |
+
html = _csv_to_html(obs.dataset_csv, task.planted_issues, set(), set(), set(), {})
|
| 357 |
+
info = f"**{task.name}** — {obs.num_issues_hint} issues to find, {obs.max_steps} steps max"
|
| 358 |
+
return html, info, "", "0.000"
|
| 359 |
+
|
| 360 |
+
def live_step(issues_text, fixes_text):
|
| 361 |
+
if live_state["obs"] is None:
|
| 362 |
+
return "Reset first.", "", "", ""
|
| 363 |
+
obs = live_state["obs"]
|
| 364 |
+
task = live_env._current_task
|
| 365 |
+
planted_keys = {i.to_key() for i in task.planted_issues}
|
| 366 |
+
|
| 367 |
+
issues = [l.strip() for l in issues_text.strip().split("\n") if l.strip()]
|
| 368 |
+
fixes = [l.strip() for l in fixes_text.strip().split("\n") if l.strip()] if fixes_text.strip() else []
|
| 369 |
+
|
| 370 |
+
action = DataQAAction(issues=issues, fixes=fixes, task_id=live_state["task_id"])
|
| 371 |
+
obs = live_env.step(action)
|
| 372 |
+
live_state["obs"] = obs
|
| 373 |
+
|
| 374 |
+
reported_keys = set()
|
| 375 |
+
for iss in issues:
|
| 376 |
+
key = parse_issue_key(iss)
|
| 377 |
+
if key:
|
| 378 |
+
reported_keys.add(key)
|
| 379 |
+
|
| 380 |
+
tp_keys = reported_keys & planted_keys
|
| 381 |
+
fp_keys = reported_keys - planted_keys
|
| 382 |
+
fn_keys = planted_keys - reported_keys
|
| 383 |
+
|
| 384 |
+
correct = {_kc(k) for k in tp_keys}
|
| 385 |
+
fp_set = {_kc(k) for k in fp_keys}
|
| 386 |
+
missed = {_kc(k) for k in fn_keys} if obs.done else set()
|
| 387 |
+
|
| 388 |
+
fixed: dict[tuple[int, str], str] = {}
|
| 389 |
+
for d in obs.metadata.get("fix_details", []):
|
| 390 |
+
c = (d["row"], d["col"])
|
| 391 |
+
fixed[c] = "correct" if d["score"] >= 0.99 else ("partial" if d["score"] > 0 else "wrong")
|
| 392 |
+
|
| 393 |
+
from .environment import parse_fix
|
| 394 |
+
fix_values: dict[tuple[int, str], str] = {}
|
| 395 |
+
for raw in fixes:
|
| 396 |
+
parsed = parse_fix(raw)
|
| 397 |
+
if parsed:
|
| 398 |
+
fix_values[(parsed[0], parsed[1])] = parsed[2]
|
| 399 |
+
|
| 400 |
+
html = _csv_to_html(obs.dataset_csv, task.planted_issues, correct, fp_set, missed, fixed, fix_values)
|
| 401 |
+
|
| 402 |
+
m = obs.metadata
|
| 403 |
+
r = obs.reward
|
| 404 |
+
rc = "#28a745" if r >= 0.8 else ("#ffc107" if r >= 0.4 else "#dc3545")
|
| 405 |
+
cards = (
|
| 406 |
+
'<div style="display:flex;gap:10px;flex-wrap:wrap;margin-bottom:12px;">'
|
| 407 |
+
+ _metric_card("Reward", f"{r:.2f}", rc)
|
| 408 |
+
+ _metric_card("Found", str(m["tp"]), "#28a745")
|
| 409 |
+
+ _metric_card("False Pos", str(m["fp"]), "#dc3545" if m["fp"] > 0 else "#28a745")
|
| 410 |
+
+ _metric_card("Missed", str(m["fn"]), "#dc3545" if m["fn"] > 0 else "#28a745")
|
| 411 |
+
+ '</div>'
|
| 412 |
+
)
|
| 413 |
+
full_html = cards + html + LEGEND_HTML
|
| 414 |
+
return full_html, obs.feedback, f"{r:.3f}", ""
|
| 415 |
+
|
| 416 |
+
# ── Build the UI ──
|
| 417 |
+
|
| 418 |
+
with gr.Blocks(title="DataQA Environment") as demo:
|
| 419 |
+
gr.Markdown(
|
| 420 |
+
"# DataQA — Data Quality Assurance Environment\n"
|
| 421 |
+
"Two-phase RL environment: **Identify** data quality issues, then **Fix** them."
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
with gr.Tabs():
|
| 425 |
+
# ── Tab 1: Demo replay ──
|
| 426 |
+
with gr.Tab("Demo (Baseline Agent)"):
|
| 427 |
+
gr.Markdown(
|
| 428 |
+
"*Replay of the baseline Qwen-72B agent. "
|
| 429 |
+
"Use the slider to step through the agent's trajectory.*"
|
| 430 |
+
)
|
| 431 |
+
with gr.Row():
|
| 432 |
+
task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
|
| 433 |
+
step_slider = gr.Slider(minimum=0, maximum=2, step=1, value=0, label="Step", scale=3)
|
| 434 |
+
|
| 435 |
+
viz_html = gr.HTML()
|
| 436 |
+
feedback_box = gr.Textbox(label="Agent Feedback", lines=10, interactive=False)
|
| 437 |
+
|
| 438 |
+
task_dd.change(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
|
| 439 |
+
step_slider.change(on_step_change, inputs=[task_dd, step_slider], outputs=[viz_html, feedback_box])
|
| 440 |
+
demo.load(on_task_change, inputs=[task_dd], outputs=[step_slider, viz_html, feedback_box])
|
| 441 |
+
|
| 442 |
+
# ── Tab 2: Try your own agent ──
|
| 443 |
+
with gr.Tab("Try Your Own Agent"):
|
| 444 |
+
gr.Markdown(
|
| 445 |
+
"*Submit your own issues and fixes to see how the environment scores them. "
|
| 446 |
+
"This is the same environment the baseline agent talks to.*"
|
| 447 |
+
)
|
| 448 |
+
with gr.Row():
|
| 449 |
+
live_task_dd = gr.Dropdown(choices=list_tasks(), value="easy", label="Task", scale=1)
|
| 450 |
+
live_reset_btn = gr.Button("Reset", variant="primary", scale=1)
|
| 451 |
+
|
| 452 |
+
with gr.Row():
|
| 453 |
+
live_info = gr.Markdown()
|
| 454 |
+
live_reward = gr.Textbox(label="Reward", interactive=False, scale=1)
|
| 455 |
+
|
| 456 |
+
live_viz = gr.HTML()
|
| 457 |
+
|
| 458 |
+
with gr.Row():
|
| 459 |
+
live_issues = gr.Textbox(
|
| 460 |
+
label="Issues (one per line)",
|
| 461 |
+
placeholder="row:4,col:name,issue:missing_value\nrow:7,col:salary,issue:wrong_type",
|
| 462 |
+
lines=5,
|
| 463 |
+
)
|
| 464 |
+
live_fixes = gr.Textbox(
|
| 465 |
+
label="Fixes (one per line, optional)",
|
| 466 |
+
placeholder="row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000",
|
| 467 |
+
lines=5,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
live_step_btn = gr.Button("Submit Step", variant="primary")
|
| 471 |
+
live_feedback = gr.Textbox(label="Feedback", lines=10, interactive=False)
|
| 472 |
+
|
| 473 |
+
live_reset_btn.click(
|
| 474 |
+
live_reset, inputs=[live_task_dd],
|
| 475 |
+
outputs=[live_viz, live_info, live_feedback, live_reward],
|
| 476 |
+
)
|
| 477 |
+
live_step_btn.click(
|
| 478 |
+
live_step, inputs=[live_issues, live_fixes],
|
| 479 |
+
outputs=[live_viz, live_feedback, live_reward, live_issues],
|
| 480 |
+
)
|
| 481 |
+
|
| 482 |
+
return demo
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
if __name__ == "__main__":
|
| 486 |
+
demo = build_gradio_ui()
|
| 487 |
+
demo.launch()
|
dataqa_env/server/tasks.py
CHANGED
|
@@ -43,6 +43,28 @@ class Task:
|
|
| 43 |
corrupted_csv: str = ""
|
| 44 |
max_steps: int = 3
|
| 45 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
|
| 47 |
def _csv_to_rows(csv_text: str) -> List[List[str]]:
|
| 48 |
reader = csv.reader(io.StringIO(csv_text.strip()))
|
|
@@ -354,6 +376,32 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
|
|
| 354 |
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
|
| 355 |
description="model_name is whitespace-only", difficulty=2.5))
|
| 356 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
corrupted = _rows_to_csv([header] + data)
|
| 358 |
|
| 359 |
return Task(
|
|
|
|
| 43 |
corrupted_csv: str = ""
|
| 44 |
max_steps: int = 3
|
| 45 |
|
| 46 |
+
def get_clean_value(self, row: int, col: str) -> str | None:
|
| 47 |
+
"""
|
| 48 |
+
Look up the original clean value for a given (row, col).
|
| 49 |
+
Row is 1-indexed (data row after header).
|
| 50 |
+
Returns None if row/col is out of bounds or column not found.
|
| 51 |
+
"""
|
| 52 |
+
rows = _csv_to_rows(self.clean_csv)
|
| 53 |
+
if len(rows) < 2:
|
| 54 |
+
return None
|
| 55 |
+
header = [h.strip().lower() for h in rows[0]]
|
| 56 |
+
if col.lower() not in header:
|
| 57 |
+
return None
|
| 58 |
+
col_idx = header.index(col.lower())
|
| 59 |
+
data_row_idx = row # row is 1-indexed, rows[0] is header, so rows[row] is the data row
|
| 60 |
+
if data_row_idx < 1 or data_row_idx >= len(rows):
|
| 61 |
+
return None
|
| 62 |
+
return rows[data_row_idx][col_idx].strip()
|
| 63 |
+
|
| 64 |
+
def get_planted_issue_map(self) -> dict:
|
| 65 |
+
"""Return dict mapping issue key -> PlantedIssue for quick lookups."""
|
| 66 |
+
return {issue.to_key(): issue for issue in self.planted_issues}
|
| 67 |
+
|
| 68 |
|
| 69 |
def _csv_to_rows(csv_text: str) -> List[List[str]]:
|
| 70 |
reader = csv.reader(io.StringIO(csv_text.strip()))
|
|
|
|
| 376 |
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
|
| 377 |
description="model_name is whitespace-only", difficulty=2.5))
|
| 378 |
|
| 379 |
+
# Issue 9: Training time impossibly fast for dataset size and epochs
|
| 380 |
+
# EXP-004: vit-base on imagenet-1k, 300 epochs, but only 96 hours is plausible.
|
| 381 |
+
# Let's make EXP-009: efficientnet-b0 on imagenet-1k, 350 epochs = should take ~40+ hours
|
| 382 |
+
# but we set it to 0.5 hours — impossible for 1.2M images * 350 epochs
|
| 383 |
+
r = 8 # EXP-009 (same row as batch_size issue, different column)
|
| 384 |
+
data[r][13] = "0.5" # 30 minutes for 350 epochs on imagenet? impossible
|
| 385 |
+
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="statistical_outlier",
|
| 386 |
+
description="0.5 hours for 350 epochs on imagenet-1k (1.2M images) is impossibly fast",
|
| 387 |
+
difficulty=3.0))
|
| 388 |
+
|
| 389 |
+
# Issue 10: test_accuracy of 95.1% for roberta-large on SST-2 with train_size=500
|
| 390 |
+
# is suspiciously high — SOTA is ~96% with full dataset (67k). With only 500 training
|
| 391 |
+
# samples, 95.1% accuracy suggests data contamination or evaluation bug
|
| 392 |
+
r = 9 # EXP-010 (same row as train_size issue, different column)
|
| 393 |
+
# train_size is already corrupted to 500, but the test_accuracy 95.1 is from the
|
| 394 |
+
# original full-dataset run — this cross-column inconsistency is the real issue
|
| 395 |
+
# We don't modify the value — the inconsistency emerges from the train_size corruption
|
| 396 |
+
# So let's use a different row. EXP-001: resnet50 on imagenet, accuracy 76.3 is fine.
|
| 397 |
+
# Instead: EXP-012 wav2vec2 on librispeech — set test_accuracy to 98.5 (way too high
|
| 398 |
+
# for a speech model with only 20 epochs, SOTA is ~96% with much more training)
|
| 399 |
+
r = 11 # EXP-012
|
| 400 |
+
data[r][11] = "98.5" # wav2vec2 with 20 epochs shouldn't hit 98.5% — SOTA is ~96%
|
| 401 |
+
issues.append(PlantedIssue(row=r + 1, col="test_accuracy", issue_type="statistical_outlier",
|
| 402 |
+
description="test_accuracy 98.5% for wav2vec2 with only 20 epochs exceeds known SOTA (~96%), likely evaluation error",
|
| 403 |
+
difficulty=3.0))
|
| 404 |
+
|
| 405 |
corrupted = _rows_to_csv([header] + data)
|
| 406 |
|
| 407 |
return Task(
|
inference.py
CHANGED
|
@@ -1,8 +1,11 @@
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
-
DataQA Inference Script
|
| 4 |
-
-----------------------
|
| 5 |
-
LLM agent that plays the DataQA environment
|
|
|
|
|
|
|
|
|
|
| 6 |
Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
|
| 7 |
|
| 8 |
Required environment variables:
|
|
@@ -92,10 +95,10 @@ class EnvHTTPClient:
|
|
| 92 |
r.raise_for_status()
|
| 93 |
return r.json()
|
| 94 |
|
| 95 |
-
def step(self, issues: list[str], task_id: str = "easy") -> dict:
|
| 96 |
r = self.session.post(
|
| 97 |
f"{self.base_url}/step",
|
| 98 |
-
json={"action": {"issues": issues, "task_id": task_id}},
|
| 99 |
timeout=30,
|
| 100 |
)
|
| 101 |
r.raise_for_status()
|
|
@@ -103,10 +106,10 @@ class EnvHTTPClient:
|
|
| 103 |
|
| 104 |
|
| 105 |
# ---------------------------------------------------------------------------
|
| 106 |
-
# LLM
|
| 107 |
# ---------------------------------------------------------------------------
|
| 108 |
|
| 109 |
-
|
| 110 |
|
| 111 |
You will be given:
|
| 112 |
1. A dataset in CSV format
|
|
@@ -141,7 +144,26 @@ Respond with ONLY the list of issues, one per line. No other text.
|
|
| 141 |
Example: row:3,col:salary,issue:missing_value"""
|
| 142 |
|
| 143 |
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
obs = observation if isinstance(observation, dict) else observation
|
| 146 |
parts = []
|
| 147 |
|
|
@@ -160,6 +182,12 @@ def build_user_prompt(observation: dict) -> str:
|
|
| 160 |
if feedback and "reset" not in feedback.lower():
|
| 161 |
parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
|
| 162 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
return "\n\n".join(parts)
|
| 164 |
|
| 165 |
|
|
@@ -170,7 +198,6 @@ def parse_llm_response(response: str) -> list[str]:
|
|
| 170 |
line = line.strip()
|
| 171 |
if not line:
|
| 172 |
continue
|
| 173 |
-
# Remove numbering like "1. " or "- " or "* "
|
| 174 |
line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
|
| 175 |
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 176 |
line = line.strip()
|
|
@@ -186,8 +213,60 @@ def parse_llm_response(response: str) -> list[str]:
|
|
| 186 |
return issues
|
| 187 |
|
| 188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
| 190 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 192 |
|
| 193 |
rewards: List[float] = []
|
|
@@ -196,48 +275,38 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
|
| 196 |
success = False
|
| 197 |
|
| 198 |
try:
|
| 199 |
-
# Reset environment for this task
|
| 200 |
reset_response = env.reset(task_id=task_id)
|
| 201 |
observation = reset_response.get("observation", reset_response)
|
| 202 |
|
| 203 |
-
|
|
|
|
| 204 |
|
| 205 |
for step_num in range(1, MAX_STEPS_PER_TASK + 1):
|
| 206 |
-
user_prompt = build_user_prompt(observation)
|
| 207 |
-
messages_for_call = messages + [{"role": "user", "content": user_prompt}]
|
| 208 |
-
|
| 209 |
-
# Call LLM with retry on rate limit
|
| 210 |
-
llm_output = ""
|
| 211 |
error_msg = None
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
temperature=0.1,
|
| 218 |
-
max_tokens=2048,
|
| 219 |
-
)
|
| 220 |
-
llm_output = response.choices[0].message.content or ""
|
| 221 |
-
break
|
| 222 |
-
except Exception as e:
|
| 223 |
-
if "rate_limit" in str(e).lower() or "429" in str(e):
|
| 224 |
-
wait = 10 * (attempt + 1)
|
| 225 |
-
print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
|
| 226 |
-
time.sleep(wait)
|
| 227 |
-
else:
|
| 228 |
-
error_msg = str(e)
|
| 229 |
-
print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
|
| 230 |
-
break
|
| 231 |
-
|
| 232 |
-
# Parse issues from LLM response
|
| 233 |
-
issues = parse_llm_response(llm_output)
|
| 234 |
-
action_str = ";".join(issues) if issues else "none"
|
| 235 |
|
| 236 |
if not issues and not error_msg:
|
| 237 |
error_msg = "no issues parsed from LLM response"
|
| 238 |
|
| 239 |
-
#
|
| 240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
observation = step_response.get("observation", step_response)
|
| 242 |
|
| 243 |
reward = float(step_response.get("reward", 0.0) or 0.0)
|
|
@@ -257,9 +326,8 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
|
| 257 |
if done:
|
| 258 |
break
|
| 259 |
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
messages.append({"role": "assistant", "content": llm_output})
|
| 263 |
|
| 264 |
success = best_score >= 0.5
|
| 265 |
|
|
@@ -279,21 +347,18 @@ def main():
|
|
| 279 |
print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
|
| 280 |
print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
|
| 281 |
|
| 282 |
-
# Initialize clients
|
| 283 |
env = EnvHTTPClient(ENV_URL)
|
| 284 |
llm_client = OpenAI(
|
| 285 |
base_url=API_BASE_URL,
|
| 286 |
api_key=API_KEY or "no-key",
|
| 287 |
)
|
| 288 |
|
| 289 |
-
# Check environment health
|
| 290 |
if not env.health():
|
| 291 |
print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
|
| 292 |
sys.exit(1)
|
| 293 |
|
| 294 |
print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
|
| 295 |
|
| 296 |
-
# Run all tasks
|
| 297 |
scores = {}
|
| 298 |
for task_id in TASKS:
|
| 299 |
try:
|
|
@@ -303,7 +368,6 @@ def main():
|
|
| 303 |
print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
|
| 304 |
scores[task_id] = 0.0
|
| 305 |
|
| 306 |
-
# Summary to stderr (stdout is reserved for structured logs only)
|
| 307 |
avg_score = sum(scores.values()) / len(scores) if scores else 0.0
|
| 308 |
print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
|
| 309 |
|
|
|
|
| 1 |
#!/usr/bin/env python3
|
| 2 |
"""
|
| 3 |
+
DataQA Inference Script — Two-Phase Agent
|
| 4 |
+
------------------------------------------
|
| 5 |
+
LLM agent that plays the DataQA environment in two phases:
|
| 6 |
+
Phase 1: Identify all data quality issues
|
| 7 |
+
Phase 2: Propose fixes for identified issues
|
| 8 |
+
|
| 9 |
Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
|
| 10 |
|
| 11 |
Required environment variables:
|
|
|
|
| 95 |
r.raise_for_status()
|
| 96 |
return r.json()
|
| 97 |
|
| 98 |
+
def step(self, issues: list[str], fixes: list[str], task_id: str = "easy") -> dict:
|
| 99 |
r = self.session.post(
|
| 100 |
f"{self.base_url}/step",
|
| 101 |
+
json={"action": {"issues": issues, "fixes": fixes, "task_id": task_id}},
|
| 102 |
timeout=30,
|
| 103 |
)
|
| 104 |
r.raise_for_status()
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
# ---------------------------------------------------------------------------
|
| 109 |
+
# LLM Prompts
|
| 110 |
# ---------------------------------------------------------------------------
|
| 111 |
|
| 112 |
+
IDENTIFY_SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
|
| 113 |
|
| 114 |
You will be given:
|
| 115 |
1. A dataset in CSV format
|
|
|
|
| 144 |
Example: row:3,col:salary,issue:missing_value"""
|
| 145 |
|
| 146 |
|
| 147 |
+
FIX_SYSTEM_PROMPT = """You are a data repair specialist. You have already identified data quality issues in a dataset. Now you must propose the correct values to fix each issue.
|
| 148 |
+
|
| 149 |
+
For each issue you identified, propose a fix in EXACTLY this format:
|
| 150 |
+
row:<row_number>,col:<column_name>,fix:<corrected_value>
|
| 151 |
+
|
| 152 |
+
Guidelines for proposing fixes:
|
| 153 |
+
- For missing_value: infer the correct value from context, schema, and other rows
|
| 154 |
+
- For wrong_type: convert to the correct type (e.g., "seventy-five thousand" → "75000")
|
| 155 |
+
- For out_of_range: propose a value within the valid range that makes sense in context
|
| 156 |
+
- For format_violation: correct the format (e.g., "26/01/2024" → "2024-01-26")
|
| 157 |
+
- For inconsistent_value: compute the correct value from related fields
|
| 158 |
+
- For duplicate_row: propose a corrected unique key or indicate removal
|
| 159 |
+
- For statistical_outlier: propose a reasonable value given the model/context
|
| 160 |
+
|
| 161 |
+
Use the schema, validation rules, and surrounding data to determine the correct fix.
|
| 162 |
+
Respond with ONLY the list of fixes, one per line. No other text.
|
| 163 |
+
Example: row:3,col:salary,fix:75000"""
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def build_user_prompt(observation: dict, include_fixes: bool = False) -> str:
|
| 167 |
obs = observation if isinstance(observation, dict) else observation
|
| 168 |
parts = []
|
| 169 |
|
|
|
|
| 182 |
if feedback and "reset" not in feedback.lower():
|
| 183 |
parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
|
| 184 |
|
| 185 |
+
if include_fixes:
|
| 186 |
+
parts.append(
|
| 187 |
+
"Now propose fixes for ALL issues. "
|
| 188 |
+
"Use format: row:<N>,col:<name>,fix:<corrected_value>"
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
return "\n\n".join(parts)
|
| 192 |
|
| 193 |
|
|
|
|
| 198 |
line = line.strip()
|
| 199 |
if not line:
|
| 200 |
continue
|
|
|
|
| 201 |
line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
|
| 202 |
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 203 |
line = line.strip()
|
|
|
|
| 213 |
return issues
|
| 214 |
|
| 215 |
|
| 216 |
+
def parse_fix_response(response: str) -> list[str]:
|
| 217 |
+
"""Extract fix lines from LLM response."""
|
| 218 |
+
fixes = []
|
| 219 |
+
for line in response.strip().split("\n"):
|
| 220 |
+
line = line.strip()
|
| 221 |
+
if not line:
|
| 222 |
+
continue
|
| 223 |
+
line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
|
| 224 |
+
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 225 |
+
line = line.strip()
|
| 226 |
+
if "row" in line.lower() and "fix" in line.lower():
|
| 227 |
+
match = re.search(
|
| 228 |
+
r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+fix\s*[:=]\s*(.+?)$",
|
| 229 |
+
line,
|
| 230 |
+
re.IGNORECASE,
|
| 231 |
+
)
|
| 232 |
+
if match:
|
| 233 |
+
normalized = f"row:{match.group(1)},col:{match.group(2).lower()},fix:{match.group(3).strip()}"
|
| 234 |
+
fixes.append(normalized)
|
| 235 |
+
return fixes
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
def call_llm(client: OpenAI, system_prompt: str, user_prompt: str) -> str:
|
| 239 |
+
"""Call the LLM with retry on rate limit."""
|
| 240 |
+
for attempt in range(3):
|
| 241 |
+
try:
|
| 242 |
+
response = client.chat.completions.create(
|
| 243 |
+
model=MODEL_NAME,
|
| 244 |
+
messages=[
|
| 245 |
+
{"role": "system", "content": system_prompt},
|
| 246 |
+
{"role": "user", "content": user_prompt},
|
| 247 |
+
],
|
| 248 |
+
temperature=0.1,
|
| 249 |
+
max_tokens=2048,
|
| 250 |
+
)
|
| 251 |
+
return response.choices[0].message.content or ""
|
| 252 |
+
except Exception as e:
|
| 253 |
+
if "rate_limit" in str(e).lower() or "429" in str(e):
|
| 254 |
+
wait = 10 * (attempt + 1)
|
| 255 |
+
print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
|
| 256 |
+
time.sleep(wait)
|
| 257 |
+
else:
|
| 258 |
+
print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
|
| 259 |
+
return ""
|
| 260 |
+
return ""
|
| 261 |
+
|
| 262 |
+
|
| 263 |
def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
| 264 |
+
"""
|
| 265 |
+
Run a single task with two-phase strategy:
|
| 266 |
+
Step 1: Identify issues only
|
| 267 |
+
Step 2: Identify + Fix (using feedback from step 1)
|
| 268 |
+
Step 3: Refined identify + fix (if needed)
|
| 269 |
+
"""
|
| 270 |
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
| 271 |
|
| 272 |
rewards: List[float] = []
|
|
|
|
| 275 |
success = False
|
| 276 |
|
| 277 |
try:
|
|
|
|
| 278 |
reset_response = env.reset(task_id=task_id)
|
| 279 |
observation = reset_response.get("observation", reset_response)
|
| 280 |
|
| 281 |
+
last_issues: list[str] = []
|
| 282 |
+
last_llm_output = ""
|
| 283 |
|
| 284 |
for step_num in range(1, MAX_STEPS_PER_TASK + 1):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 285 |
error_msg = None
|
| 286 |
+
|
| 287 |
+
# ── Phase 1: Identify issues ──
|
| 288 |
+
user_prompt = build_user_prompt(observation)
|
| 289 |
+
identify_output = call_llm(client, IDENTIFY_SYSTEM_PROMPT, user_prompt)
|
| 290 |
+
issues = parse_llm_response(identify_output)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
|
| 292 |
if not issues and not error_msg:
|
| 293 |
error_msg = "no issues parsed from LLM response"
|
| 294 |
|
| 295 |
+
# ── Phase 2: Propose fixes (from step 2 onward, or always if we have issues) ──
|
| 296 |
+
fixes: list[str] = []
|
| 297 |
+
if issues and step_num >= 2:
|
| 298 |
+
# Build a fix prompt that includes the identified issues
|
| 299 |
+
fix_prompt = build_user_prompt(observation, include_fixes=True)
|
| 300 |
+
fix_prompt += f"\n\nISSUES FOUND:\n" + "\n".join(issues)
|
| 301 |
+
fix_output = call_llm(client, FIX_SYSTEM_PROMPT, fix_prompt)
|
| 302 |
+
fixes = parse_fix_response(fix_output)
|
| 303 |
+
|
| 304 |
+
# ── Submit to environment ──
|
| 305 |
+
action_str = ";".join(issues[:5]) if issues else "none"
|
| 306 |
+
if fixes:
|
| 307 |
+
action_str += "|fixes:" + ";".join(fixes[:3])
|
| 308 |
+
|
| 309 |
+
step_response = env.step(issues, fixes, task_id=task_id)
|
| 310 |
observation = step_response.get("observation", step_response)
|
| 311 |
|
| 312 |
reward = float(step_response.get("reward", 0.0) or 0.0)
|
|
|
|
| 326 |
if done:
|
| 327 |
break
|
| 328 |
|
| 329 |
+
last_issues = issues
|
| 330 |
+
last_llm_output = identify_output
|
|
|
|
| 331 |
|
| 332 |
success = best_score >= 0.5
|
| 333 |
|
|
|
|
| 347 |
print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
|
| 348 |
print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
|
| 349 |
|
|
|
|
| 350 |
env = EnvHTTPClient(ENV_URL)
|
| 351 |
llm_client = OpenAI(
|
| 352 |
base_url=API_BASE_URL,
|
| 353 |
api_key=API_KEY or "no-key",
|
| 354 |
)
|
| 355 |
|
|
|
|
| 356 |
if not env.health():
|
| 357 |
print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
|
| 358 |
sys.exit(1)
|
| 359 |
|
| 360 |
print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
|
| 361 |
|
|
|
|
| 362 |
scores = {}
|
| 363 |
for task_id in TASKS:
|
| 364 |
try:
|
|
|
|
| 368 |
print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
|
| 369 |
scores[task_id] = 0.0
|
| 370 |
|
|
|
|
| 371 |
avg_score = sum(scores.values()) / len(scores) if scores else 0.0
|
| 372 |
print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
|
| 373 |
|
tests/test_environment.py
CHANGED
|
@@ -1,16 +1,24 @@
|
|
| 1 |
-
"""Tests for the DataQA environment (reset, step, scoring)."""
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
from dataqa_env.server.environment import (
|
| 5 |
DataQAEnvironment,
|
| 6 |
parse_issue_key,
|
|
|
|
| 7 |
compute_f1,
|
| 8 |
compute_weighted_reward,
|
|
|
|
|
|
|
|
|
|
| 9 |
)
|
| 10 |
from dataqa_env.models import DataQAAction
|
| 11 |
-
from dataqa_env.server.tasks import PlantedIssue
|
| 12 |
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
class TestParseIssueKey:
|
| 15 |
def test_standard_format(self):
|
| 16 |
assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
|
|
@@ -28,7 +36,7 @@ class TestParseIssueKey:
|
|
| 28 |
assert parse_issue_key("this is garbage") is None
|
| 29 |
|
| 30 |
def test_partial_match(self):
|
| 31 |
-
assert parse_issue_key("row:3,col:salary") is None
|
| 32 |
|
| 33 |
def test_empty_string(self):
|
| 34 |
assert parse_issue_key("") is None
|
|
@@ -38,14 +46,49 @@ class TestParseIssueKey:
|
|
| 38 |
assert result == "row:3,col:salary,issue:missing_value"
|
| 39 |
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
class TestComputeF1:
|
| 42 |
def test_perfect_match(self):
|
| 43 |
keys = {"row:1,col:a,issue:missing_value"}
|
| 44 |
result = compute_f1(keys, keys)
|
| 45 |
assert result["f1"] == 1.0
|
| 46 |
-
assert result["tp"] == 1
|
| 47 |
-
assert result["fp"] == 0
|
| 48 |
-
assert result["fn"] == 0
|
| 49 |
|
| 50 |
def test_no_reported_no_planted(self):
|
| 51 |
result = compute_f1(set(), set())
|
|
@@ -61,9 +104,6 @@ class TestComputeF1:
|
|
| 61 |
reported = {"row:99,col:x,issue:wrong_type"}
|
| 62 |
planted = {"row:1,col:a,issue:missing_value"}
|
| 63 |
result = compute_f1(reported, planted)
|
| 64 |
-
assert result["tp"] == 0
|
| 65 |
-
assert result["fp"] == 1
|
| 66 |
-
assert result["fn"] == 1
|
| 67 |
assert result["f1"] == 0.0
|
| 68 |
|
| 69 |
def test_partial_match(self):
|
|
@@ -83,6 +123,10 @@ class TestComputeF1:
|
|
| 83 |
assert result["recall"] == pytest.approx(2 / 3)
|
| 84 |
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
class TestComputeWeightedReward:
|
| 87 |
def test_perfect_match(self):
|
| 88 |
issues = [
|
|
@@ -101,14 +145,11 @@ class TestComputeWeightedReward:
|
|
| 101 |
issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
|
| 102 |
result = compute_weighted_reward(set(), issues)
|
| 103 |
assert result["weighted_reward"] == 0.0
|
| 104 |
-
assert result["difficulty_missed"] == 2.0
|
| 105 |
|
| 106 |
def test_hard_issue_worth_more(self):
|
| 107 |
easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
|
| 108 |
hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
|
| 109 |
issues = [easy, hard]
|
| 110 |
-
|
| 111 |
-
# Finding only the hard issue should score higher than only the easy issue
|
| 112 |
hard_found = compute_weighted_reward({hard.to_key()}, issues)
|
| 113 |
easy_found = compute_weighted_reward({easy.to_key()}, issues)
|
| 114 |
assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
|
|
@@ -122,6 +163,92 @@ class TestComputeWeightedReward:
|
|
| 122 |
assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
|
| 123 |
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
class TestDataQAEnvironment:
|
| 126 |
@pytest.fixture
|
| 127 |
def env(self):
|
|
@@ -137,6 +264,7 @@ class TestDataQAEnvironment:
|
|
| 137 |
assert obs.max_steps == 3
|
| 138 |
assert obs.done is False
|
| 139 |
assert obs.reward == 0.0
|
|
|
|
| 140 |
|
| 141 |
def test_reset_medium(self, env):
|
| 142 |
obs = env.reset(task_id="medium")
|
|
@@ -144,11 +272,11 @@ class TestDataQAEnvironment:
|
|
| 144 |
|
| 145 |
def test_reset_hard(self, env):
|
| 146 |
obs = env.reset(task_id="hard")
|
| 147 |
-
assert obs.num_issues_hint ==
|
| 148 |
|
| 149 |
-
def
|
|
|
|
| 150 |
env.reset(task_id="easy")
|
| 151 |
-
# Submit all correct issues for easy task
|
| 152 |
action = DataQAAction(
|
| 153 |
issues=[
|
| 154 |
"row:4,col:name,issue:missing_value",
|
|
@@ -160,7 +288,46 @@ class TestDataQAEnvironment:
|
|
| 160 |
)
|
| 161 |
obs = env.step(action)
|
| 162 |
assert obs.done is True
|
| 163 |
-
assert obs.reward >= 0.999
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
def test_step_with_partial_issues(self, env):
|
| 166 |
env.reset(task_id="easy")
|
|
@@ -186,7 +353,6 @@ class TestDataQAEnvironment:
|
|
| 186 |
assert obs.done is True
|
| 187 |
|
| 188 |
def test_auto_reset_on_step(self, env):
|
| 189 |
-
# step() without prior reset should auto-reset
|
| 190 |
action = DataQAAction(
|
| 191 |
issues=["row:4,col:name,issue:missing_value"],
|
| 192 |
task_id="easy",
|
|
@@ -214,19 +380,26 @@ class TestDataQAEnvironment:
|
|
| 214 |
env.step(action1)
|
| 215 |
score_after_1 = env.state.best_score
|
| 216 |
|
| 217 |
-
# Worse submission shouldn't decrease best_score
|
| 218 |
action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
|
| 219 |
env.step(action2)
|
| 220 |
assert env.state.best_score >= score_after_1
|
| 221 |
|
| 222 |
-
def
|
| 223 |
env.reset(task_id="easy")
|
| 224 |
-
action = DataQAAction(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
obs = env.step(action)
|
| 226 |
-
|
| 227 |
-
assert "
|
| 228 |
-
assert "
|
| 229 |
-
assert "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 230 |
|
| 231 |
def test_parse_error_in_feedback(self, env):
|
| 232 |
env.reset(task_id="easy")
|
|
@@ -243,7 +416,55 @@ class TestDataQAEnvironment:
|
|
| 243 |
for _ in range(3):
|
| 244 |
action = DataQAAction(
|
| 245 |
issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
|
|
|
|
| 246 |
task_id="hard",
|
| 247 |
)
|
| 248 |
obs = env.step(action)
|
| 249 |
assert 0.0 <= obs.reward <= 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the DataQA environment (reset, step, scoring, two-phase identify+fix)."""
|
| 2 |
|
| 3 |
import pytest
|
| 4 |
from dataqa_env.server.environment import (
|
| 5 |
DataQAEnvironment,
|
| 6 |
parse_issue_key,
|
| 7 |
+
parse_fix,
|
| 8 |
compute_f1,
|
| 9 |
compute_weighted_reward,
|
| 10 |
+
grade_fixes,
|
| 11 |
+
IDENTIFY_WEIGHT,
|
| 12 |
+
FIX_WEIGHT,
|
| 13 |
)
|
| 14 |
from dataqa_env.models import DataQAAction
|
| 15 |
+
from dataqa_env.server.tasks import PlantedIssue, create_task_easy, create_task_medium
|
| 16 |
|
| 17 |
|
| 18 |
+
# ──────────────────────────────────────────────────────
|
| 19 |
+
# Issue parsing
|
| 20 |
+
# ──────────────────────────────────────────────────────
|
| 21 |
+
|
| 22 |
class TestParseIssueKey:
|
| 23 |
def test_standard_format(self):
|
| 24 |
assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"
|
|
|
|
| 36 |
assert parse_issue_key("this is garbage") is None
|
| 37 |
|
| 38 |
def test_partial_match(self):
|
| 39 |
+
assert parse_issue_key("row:3,col:salary") is None
|
| 40 |
|
| 41 |
def test_empty_string(self):
|
| 42 |
assert parse_issue_key("") is None
|
|
|
|
| 46 |
assert result == "row:3,col:salary,issue:missing_value"
|
| 47 |
|
| 48 |
|
| 49 |
+
# ──────────────────────────────────────────────────────
|
| 50 |
+
# Fix parsing
|
| 51 |
+
# ──────────────────────────────────────────────────────
|
| 52 |
+
|
| 53 |
+
class TestParseFix:
|
| 54 |
+
def test_standard_format(self):
|
| 55 |
+
result = parse_fix("row:4,col:name,fix:Alice Chen")
|
| 56 |
+
assert result == (4, "name", "Alice Chen")
|
| 57 |
+
|
| 58 |
+
def test_with_equals(self):
|
| 59 |
+
result = parse_fix("row=4,col=name,fix=Alice Chen")
|
| 60 |
+
assert result == (4, "name", "Alice Chen")
|
| 61 |
+
|
| 62 |
+
def test_numeric_fix(self):
|
| 63 |
+
result = parse_fix("row:7,col:salary,fix:75000")
|
| 64 |
+
assert result == (7, "salary", "75000")
|
| 65 |
+
|
| 66 |
+
def test_date_fix(self):
|
| 67 |
+
result = parse_fix("row:12,col:order_date,fix:2024-01-26")
|
| 68 |
+
assert result == (12, "order_date", "2024-01-26")
|
| 69 |
+
|
| 70 |
+
def test_case_insensitive(self):
|
| 71 |
+
result = parse_fix("Row:4,Col:Name,Fix:Alice Chen")
|
| 72 |
+
assert result == (4, "name", "Alice Chen")
|
| 73 |
+
|
| 74 |
+
def test_unparseable(self):
|
| 75 |
+
assert parse_fix("garbage") is None
|
| 76 |
+
assert parse_fix("row:4,col:name") is None
|
| 77 |
+
|
| 78 |
+
def test_fix_with_special_chars(self):
|
| 79 |
+
result = parse_fix("row:1,col:email,fix:alice.chen@company.com")
|
| 80 |
+
assert result == (1, "email", "alice.chen@company.com")
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# ──────────────────────────────────────────────────────
|
| 84 |
+
# F1 scoring
|
| 85 |
+
# ──────────────────────────────────────────────────────
|
| 86 |
+
|
| 87 |
class TestComputeF1:
|
| 88 |
def test_perfect_match(self):
|
| 89 |
keys = {"row:1,col:a,issue:missing_value"}
|
| 90 |
result = compute_f1(keys, keys)
|
| 91 |
assert result["f1"] == 1.0
|
|
|
|
|
|
|
|
|
|
| 92 |
|
| 93 |
def test_no_reported_no_planted(self):
|
| 94 |
result = compute_f1(set(), set())
|
|
|
|
| 104 |
reported = {"row:99,col:x,issue:wrong_type"}
|
| 105 |
planted = {"row:1,col:a,issue:missing_value"}
|
| 106 |
result = compute_f1(reported, planted)
|
|
|
|
|
|
|
|
|
|
| 107 |
assert result["f1"] == 0.0
|
| 108 |
|
| 109 |
def test_partial_match(self):
|
|
|
|
| 123 |
assert result["recall"] == pytest.approx(2 / 3)
|
| 124 |
|
| 125 |
|
| 126 |
+
# ──────────────────────────────────────────────────────
|
| 127 |
+
# Weighted reward
|
| 128 |
+
# ──────────────────────────────────────────────────────
|
| 129 |
+
|
| 130 |
class TestComputeWeightedReward:
|
| 131 |
def test_perfect_match(self):
|
| 132 |
issues = [
|
|
|
|
| 145 |
issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
|
| 146 |
result = compute_weighted_reward(set(), issues)
|
| 147 |
assert result["weighted_reward"] == 0.0
|
|
|
|
| 148 |
|
| 149 |
def test_hard_issue_worth_more(self):
|
| 150 |
easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
|
| 151 |
hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
|
| 152 |
issues = [easy, hard]
|
|
|
|
|
|
|
| 153 |
hard_found = compute_weighted_reward({hard.to_key()}, issues)
|
| 154 |
easy_found = compute_weighted_reward({easy.to_key()}, issues)
|
| 155 |
assert hard_found["weighted_reward"] > easy_found["weighted_reward"]
|
|
|
|
| 163 |
assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
|
| 164 |
|
| 165 |
|
| 166 |
+
# ──────────────────────────────────────────────────────
|
| 167 |
+
# Fix grading
|
| 168 |
+
# ──────────────────────────────────────────────────────
|
| 169 |
+
|
| 170 |
+
class TestGradeFixes:
|
| 171 |
+
@pytest.fixture
|
| 172 |
+
def easy_task(self):
|
| 173 |
+
return create_task_easy()
|
| 174 |
+
|
| 175 |
+
def test_no_fixes_no_issues(self):
|
| 176 |
+
from dataqa_env.server.tasks import Task
|
| 177 |
+
task = Task(task_id="empty", name="", description="", schema_description="",
|
| 178 |
+
validation_rules="", clean_csv="a\n1")
|
| 179 |
+
result = grade_fixes([], task)
|
| 180 |
+
assert result["fix_score"] == 1.0
|
| 181 |
+
|
| 182 |
+
def test_no_fixes_submitted(self, easy_task):
|
| 183 |
+
result = grade_fixes([], easy_task)
|
| 184 |
+
assert result["fix_score"] == 0.0
|
| 185 |
+
assert result["fixes_attempted"] == 0
|
| 186 |
+
|
| 187 |
+
def test_exact_fix_for_missing_name(self, easy_task):
|
| 188 |
+
# Row 4 has empty name — clean value is "David Kim"
|
| 189 |
+
fixes = [(4, "name", "David Kim")]
|
| 190 |
+
result = grade_fixes(fixes, easy_task)
|
| 191 |
+
assert result["fix_score"] > 0.0
|
| 192 |
+
assert result["fixes_correct"] == 1
|
| 193 |
+
|
| 194 |
+
def test_exact_fix_for_wrong_type_salary(self, easy_task):
|
| 195 |
+
# Row 7 has "seventy-five thousand" — clean value is "75000"
|
| 196 |
+
fixes = [(7, "salary", "75000")]
|
| 197 |
+
result = grade_fixes(fixes, easy_task)
|
| 198 |
+
assert result["fixes_correct"] == 1
|
| 199 |
+
|
| 200 |
+
def test_numeric_close_match(self, easy_task):
|
| 201 |
+
# Row 9 has salary "5000" — clean value is "73000"
|
| 202 |
+
# Propose 73100 (within 1% of 73000)
|
| 203 |
+
fixes = [(9, "salary", "73100")]
|
| 204 |
+
result = grade_fixes(fixes, easy_task)
|
| 205 |
+
assert result["fixes_partial"] == 1
|
| 206 |
+
|
| 207 |
+
def test_wrong_value_for_issue_cell(self, easy_task):
|
| 208 |
+
# Row 4 name is empty — propose wrong name
|
| 209 |
+
fixes = [(4, "name", "Wrong Person")]
|
| 210 |
+
result = grade_fixes(fixes, easy_task)
|
| 211 |
+
assert result["fixes_partial"] == 1 # correct cell, wrong value
|
| 212 |
+
assert result["fix_score"] > 0.0 # gets partial credit
|
| 213 |
+
|
| 214 |
+
def test_fix_for_non_issue_cell(self, easy_task):
|
| 215 |
+
# Row 1 col name is fine — no issue there
|
| 216 |
+
fixes = [(1, "name", "Some Name")]
|
| 217 |
+
result = grade_fixes(fixes, easy_task)
|
| 218 |
+
assert result["fixes_wrong"] == 1
|
| 219 |
+
assert result["fix_score"] == 0.0
|
| 220 |
+
|
| 221 |
+
def test_multiple_fixes_best_wins(self, easy_task):
|
| 222 |
+
# Submit two fixes for same cell — best one should count
|
| 223 |
+
fixes = [
|
| 224 |
+
(4, "name", "Wrong Person"), # partial credit
|
| 225 |
+
(4, "name", "David Kim"), # exact match
|
| 226 |
+
]
|
| 227 |
+
result = grade_fixes(fixes, easy_task)
|
| 228 |
+
assert result["fixes_correct"] >= 1
|
| 229 |
+
|
| 230 |
+
def test_all_fixes_correct(self, easy_task):
|
| 231 |
+
# Fix all 4 issues with exact values
|
| 232 |
+
fixes = [
|
| 233 |
+
(4, "name", "David Kim"),
|
| 234 |
+
(7, "salary", "75000"),
|
| 235 |
+
(9, "salary", "73000"),
|
| 236 |
+
# Row 11 is duplicate — clean value for employee_id is "Bob Martinez" row
|
| 237 |
+
# The duplicate is of row 2 (Bob Martinez), so the clean row 11 doesn't exist
|
| 238 |
+
]
|
| 239 |
+
result = grade_fixes(fixes, easy_task)
|
| 240 |
+
assert result["fix_score"] > 0.5 # at least 3/4 issues fixed
|
| 241 |
+
|
| 242 |
+
def test_fix_score_bounded(self, easy_task):
|
| 243 |
+
fixes = [(4, "name", "David Kim"), (99, "x", "bad")]
|
| 244 |
+
result = grade_fixes(fixes, easy_task)
|
| 245 |
+
assert 0.0 <= result["fix_score"] <= 1.0
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
# ──────────────────────────────────────────────────────
|
| 249 |
+
# Full environment lifecycle
|
| 250 |
+
# ──────────────────────────────────────────────────────
|
| 251 |
+
|
| 252 |
class TestDataQAEnvironment:
|
| 253 |
@pytest.fixture
|
| 254 |
def env(self):
|
|
|
|
| 264 |
assert obs.max_steps == 3
|
| 265 |
assert obs.done is False
|
| 266 |
assert obs.reward == 0.0
|
| 267 |
+
assert "fix" in obs.feedback.lower() # mentions fix phase
|
| 268 |
|
| 269 |
def test_reset_medium(self, env):
|
| 270 |
obs = env.reset(task_id="medium")
|
|
|
|
| 272 |
|
| 273 |
def test_reset_hard(self, env):
|
| 274 |
obs = env.reset(task_id="hard")
|
| 275 |
+
assert obs.num_issues_hint == 10
|
| 276 |
|
| 277 |
+
def test_step_identify_only(self, env):
|
| 278 |
+
"""Backward compatible: only issues, no fixes."""
|
| 279 |
env.reset(task_id="easy")
|
|
|
|
| 280 |
action = DataQAAction(
|
| 281 |
issues=[
|
| 282 |
"row:4,col:name,issue:missing_value",
|
|
|
|
| 288 |
)
|
| 289 |
obs = env.step(action)
|
| 290 |
assert obs.done is True
|
| 291 |
+
assert obs.reward >= 0.999 # identify-only uses identify_score directly
|
| 292 |
+
|
| 293 |
+
def test_step_with_fixes_increases_reward(self, env):
|
| 294 |
+
"""Submitting correct fixes should increase reward beyond identify-only."""
|
| 295 |
+
env.reset(task_id="easy")
|
| 296 |
+
# Step 1: identify only
|
| 297 |
+
action1 = DataQAAction(
|
| 298 |
+
issues=[
|
| 299 |
+
"row:4,col:name,issue:missing_value",
|
| 300 |
+
"row:7,col:salary,issue:wrong_type",
|
| 301 |
+
"row:11,col:employee_id,issue:duplicate_row",
|
| 302 |
+
"row:9,col:salary,issue:out_of_range",
|
| 303 |
+
],
|
| 304 |
+
task_id="easy",
|
| 305 |
+
)
|
| 306 |
+
obs1 = env.step(action1)
|
| 307 |
+
score_identify = obs1.reward
|
| 308 |
+
|
| 309 |
+
# Reset for fair comparison
|
| 310 |
+
env.reset(task_id="easy")
|
| 311 |
+
# Step with identify + fixes
|
| 312 |
+
action2 = DataQAAction(
|
| 313 |
+
issues=[
|
| 314 |
+
"row:4,col:name,issue:missing_value",
|
| 315 |
+
"row:7,col:salary,issue:wrong_type",
|
| 316 |
+
"row:11,col:employee_id,issue:duplicate_row",
|
| 317 |
+
"row:9,col:salary,issue:out_of_range",
|
| 318 |
+
],
|
| 319 |
+
fixes=[
|
| 320 |
+
"row:4,col:name,fix:David Kim",
|
| 321 |
+
"row:7,col:salary,fix:75000",
|
| 322 |
+
"row:9,col:salary,fix:73000",
|
| 323 |
+
],
|
| 324 |
+
task_id="easy",
|
| 325 |
+
)
|
| 326 |
+
obs2 = env.step(action2)
|
| 327 |
+
score_with_fixes = obs2.metadata["combined_reward"]
|
| 328 |
+
|
| 329 |
+
# With correct fixes, combined should be close to 1.0
|
| 330 |
+
assert score_with_fixes > 0.8
|
| 331 |
|
| 332 |
def test_step_with_partial_issues(self, env):
|
| 333 |
env.reset(task_id="easy")
|
|
|
|
| 353 |
assert obs.done is True
|
| 354 |
|
| 355 |
def test_auto_reset_on_step(self, env):
|
|
|
|
| 356 |
action = DataQAAction(
|
| 357 |
issues=["row:4,col:name,issue:missing_value"],
|
| 358 |
task_id="easy",
|
|
|
|
| 380 |
env.step(action1)
|
| 381 |
score_after_1 = env.state.best_score
|
| 382 |
|
|
|
|
| 383 |
action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
|
| 384 |
env.step(action2)
|
| 385 |
assert env.state.best_score >= score_after_1
|
| 386 |
|
| 387 |
+
def test_metadata_includes_both_phases(self, env):
|
| 388 |
env.reset(task_id="easy")
|
| 389 |
+
action = DataQAAction(
|
| 390 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 391 |
+
fixes=["row:4,col:name,fix:David Kim"],
|
| 392 |
+
task_id="easy",
|
| 393 |
+
)
|
| 394 |
obs = env.step(action)
|
| 395 |
+
m = obs.metadata
|
| 396 |
+
assert "identify_f1" in m
|
| 397 |
+
assert "identify_score" in m
|
| 398 |
+
assert "fix_score" in m
|
| 399 |
+
assert "combined_reward" in m
|
| 400 |
+
assert "tp" in m
|
| 401 |
+
assert "fixes_correct" in m
|
| 402 |
+
assert "fixes_attempted" in m
|
| 403 |
|
| 404 |
def test_parse_error_in_feedback(self, env):
|
| 405 |
env.reset(task_id="easy")
|
|
|
|
| 416 |
for _ in range(3):
|
| 417 |
action = DataQAAction(
|
| 418 |
issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
|
| 419 |
+
fixes=["row:1,col:x,fix:wrong"],
|
| 420 |
task_id="hard",
|
| 421 |
)
|
| 422 |
obs = env.step(action)
|
| 423 |
assert 0.0 <= obs.reward <= 1.0
|
| 424 |
+
|
| 425 |
+
def test_combined_reward_weights(self, env):
|
| 426 |
+
"""Verify combined = IDENTIFY_WEIGHT * identify + FIX_WEIGHT * fix."""
|
| 427 |
+
env.reset(task_id="easy")
|
| 428 |
+
action = DataQAAction(
|
| 429 |
+
issues=[
|
| 430 |
+
"row:4,col:name,issue:missing_value",
|
| 431 |
+
"row:7,col:salary,issue:wrong_type",
|
| 432 |
+
"row:11,col:employee_id,issue:duplicate_row",
|
| 433 |
+
"row:9,col:salary,issue:out_of_range",
|
| 434 |
+
],
|
| 435 |
+
fixes=["row:4,col:name,fix:David Kim"],
|
| 436 |
+
task_id="easy",
|
| 437 |
+
)
|
| 438 |
+
obs = env.step(action)
|
| 439 |
+
m = obs.metadata
|
| 440 |
+
expected = IDENTIFY_WEIGHT * m["identify_score"] + FIX_WEIGHT * m["fix_score"]
|
| 441 |
+
assert abs(m["combined_reward"] - expected) < 0.01
|
| 442 |
+
|
| 443 |
+
def test_fix_feedback_shown_when_fixes_submitted(self, env):
|
| 444 |
+
env.reset(task_id="easy")
|
| 445 |
+
action = DataQAAction(
|
| 446 |
+
issues=["row:4,col:name,issue:missing_value"],
|
| 447 |
+
fixes=["row:4,col:name,fix:David Kim"],
|
| 448 |
+
task_id="easy",
|
| 449 |
+
)
|
| 450 |
+
obs = env.step(action)
|
| 451 |
+
assert "Fix Proposals" in obs.feedback
|
| 452 |
+
assert "Combined Reward" in obs.feedback
|
| 453 |
+
|
| 454 |
+
def test_no_fix_penalty_when_no_fixes_submitted(self, env):
|
| 455 |
+
"""If agent submits no fixes, reward = identify_score (no penalty)."""
|
| 456 |
+
env.reset(task_id="easy")
|
| 457 |
+
action = DataQAAction(
|
| 458 |
+
issues=[
|
| 459 |
+
"row:4,col:name,issue:missing_value",
|
| 460 |
+
"row:7,col:salary,issue:wrong_type",
|
| 461 |
+
"row:11,col:employee_id,issue:duplicate_row",
|
| 462 |
+
"row:9,col:salary,issue:out_of_range",
|
| 463 |
+
],
|
| 464 |
+
task_id="easy",
|
| 465 |
+
)
|
| 466 |
+
obs = env.step(action)
|
| 467 |
+
# identify_score should be ~1.0 since all issues found
|
| 468 |
+
assert obs.reward >= 0.99
|
| 469 |
+
# combined_reward equals identify_score when no fixes
|
| 470 |
+
assert obs.metadata["combined_reward"] == obs.metadata["identify_score"]
|
tests/test_extensibility.py
CHANGED
|
@@ -151,8 +151,8 @@ class TestRegisterTask:
|
|
| 151 |
|
| 152 |
|
| 153 |
class TestCustomTaskInEnvironment:
|
| 154 |
-
def
|
| 155 |
-
"""Custom task works end-to-end
|
| 156 |
task = create_task_from_config(
|
| 157 |
task_id="e2e_custom",
|
| 158 |
name="E2E Custom",
|
|
@@ -171,7 +171,6 @@ class TestCustomTaskInEnvironment:
|
|
| 171 |
obs = env.reset(task_id="e2e_custom")
|
| 172 |
assert obs.num_issues_hint == 2
|
| 173 |
|
| 174 |
-
# Submit correct answers
|
| 175 |
action = DataQAAction(
|
| 176 |
issues=[i.to_key() for i in task.planted_issues],
|
| 177 |
task_id="e2e_custom",
|
|
@@ -180,6 +179,37 @@ class TestCustomTaskInEnvironment:
|
|
| 180 |
assert obs.done is True
|
| 181 |
assert obs.reward >= 0.999
|
| 182 |
|
| 183 |
-
# Cleanup
|
| 184 |
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 185 |
del TASK_REGISTRY["e2e_custom"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
|
| 153 |
class TestCustomTaskInEnvironment:
|
| 154 |
+
def test_full_lifecycle_identify_only(self):
|
| 155 |
+
"""Custom task works end-to-end with identify-only."""
|
| 156 |
task = create_task_from_config(
|
| 157 |
task_id="e2e_custom",
|
| 158 |
name="E2E Custom",
|
|
|
|
| 171 |
obs = env.reset(task_id="e2e_custom")
|
| 172 |
assert obs.num_issues_hint == 2
|
| 173 |
|
|
|
|
| 174 |
action = DataQAAction(
|
| 175 |
issues=[i.to_key() for i in task.planted_issues],
|
| 176 |
task_id="e2e_custom",
|
|
|
|
| 179 |
assert obs.done is True
|
| 180 |
assert obs.reward >= 0.999
|
| 181 |
|
|
|
|
| 182 |
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 183 |
del TASK_REGISTRY["e2e_custom"]
|
| 184 |
+
|
| 185 |
+
def test_full_lifecycle_identify_and_fix(self):
|
| 186 |
+
"""Custom task works end-to-end with both identify and fix."""
|
| 187 |
+
task = create_task_from_config(
|
| 188 |
+
task_id="e2e_fix",
|
| 189 |
+
name="E2E Fix",
|
| 190 |
+
description="End-to-end test with fixes",
|
| 191 |
+
schema_description="id: int, name: str, score: int",
|
| 192 |
+
validation_rules="No missing values",
|
| 193 |
+
clean_csv=SIMPLE_CSV,
|
| 194 |
+
contaminations=[
|
| 195 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
|
| 196 |
+
],
|
| 197 |
+
)
|
| 198 |
+
register_task("e2e_fix", lambda seed: task)
|
| 199 |
+
|
| 200 |
+
env = DataQAEnvironment()
|
| 201 |
+
env.reset(task_id="e2e_fix")
|
| 202 |
+
|
| 203 |
+
# Submit issues + fix
|
| 204 |
+
action = DataQAAction(
|
| 205 |
+
issues=[task.planted_issues[0].to_key()],
|
| 206 |
+
fixes=["row:1,col:name,fix:Alice"], # clean value is "Alice"
|
| 207 |
+
task_id="e2e_fix",
|
| 208 |
+
)
|
| 209 |
+
obs = env.step(action)
|
| 210 |
+
assert obs.done is True
|
| 211 |
+
assert obs.metadata["fix_score"] > 0.0
|
| 212 |
+
assert obs.metadata["combined_reward"] > 0.0
|
| 213 |
+
|
| 214 |
+
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 215 |
+
del TASK_REGISTRY["e2e_fix"]
|
tests/test_inference.py
CHANGED
|
@@ -1,12 +1,11 @@
|
|
| 1 |
"""Tests for the inference script's parsing, prompt building, and log format."""
|
| 2 |
|
| 3 |
-
import re
|
| 4 |
import pytest
|
| 5 |
import sys
|
| 6 |
import os
|
| 7 |
|
| 8 |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 9 |
-
from inference import parse_llm_response, build_user_prompt, log_start, log_step, log_end
|
| 10 |
|
| 11 |
|
| 12 |
class TestParseLLMResponse:
|
|
@@ -50,7 +49,7 @@ class TestParseLLMResponse:
|
|
| 50 |
def test_deduplication_not_applied(self):
|
| 51 |
response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
|
| 52 |
issues = parse_llm_response(response)
|
| 53 |
-
assert len(issues) == 2
|
| 54 |
|
| 55 |
def test_with_column_variant(self):
|
| 56 |
response = "row:1,column:name,issue:missing_value"
|
|
@@ -58,6 +57,38 @@ class TestParseLLMResponse:
|
|
| 58 |
assert len(issues) == 1
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
class TestBuildUserPrompt:
|
| 62 |
def test_includes_all_fields(self):
|
| 63 |
obs = {
|
|
@@ -100,6 +131,18 @@ class TestBuildUserPrompt:
|
|
| 100 |
prompt = build_user_prompt(obs)
|
| 101 |
assert "FEEDBACK" not in prompt
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
|
| 104 |
class TestLogFormat:
|
| 105 |
"""Verify stdout log format matches hackathon evaluation requirements."""
|
|
|
|
| 1 |
"""Tests for the inference script's parsing, prompt building, and log format."""
|
| 2 |
|
|
|
|
| 3 |
import pytest
|
| 4 |
import sys
|
| 5 |
import os
|
| 6 |
|
| 7 |
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 8 |
+
from inference import parse_llm_response, parse_fix_response, build_user_prompt, log_start, log_step, log_end
|
| 9 |
|
| 10 |
|
| 11 |
class TestParseLLMResponse:
|
|
|
|
| 49 |
def test_deduplication_not_applied(self):
|
| 50 |
response = "row:1,col:name,issue:missing_value\nrow:1,col:name,issue:missing_value"
|
| 51 |
issues = parse_llm_response(response)
|
| 52 |
+
assert len(issues) == 2
|
| 53 |
|
| 54 |
def test_with_column_variant(self):
|
| 55 |
response = "row:1,column:name,issue:missing_value"
|
|
|
|
| 57 |
assert len(issues) == 1
|
| 58 |
|
| 59 |
|
| 60 |
+
class TestParseFixResponse:
|
| 61 |
+
def test_standard_format(self):
|
| 62 |
+
response = "row:4,col:name,fix:David Kim\nrow:7,col:salary,fix:75000"
|
| 63 |
+
fixes = parse_fix_response(response)
|
| 64 |
+
assert len(fixes) == 2
|
| 65 |
+
assert "row:4,col:name,fix:David Kim" in fixes
|
| 66 |
+
|
| 67 |
+
def test_numbered_list(self):
|
| 68 |
+
response = "1. row:4,col:name,fix:David Kim\n2. row:7,col:salary,fix:75000"
|
| 69 |
+
fixes = parse_fix_response(response)
|
| 70 |
+
assert len(fixes) == 2
|
| 71 |
+
|
| 72 |
+
def test_with_special_chars(self):
|
| 73 |
+
response = "row:1,col:email,fix:alice.chen@company.com"
|
| 74 |
+
fixes = parse_fix_response(response)
|
| 75 |
+
assert len(fixes) == 1
|
| 76 |
+
assert "alice.chen@company.com" in fixes[0]
|
| 77 |
+
|
| 78 |
+
def test_empty_response(self):
|
| 79 |
+
assert parse_fix_response("") == []
|
| 80 |
+
|
| 81 |
+
def test_date_fix(self):
|
| 82 |
+
response = "row:12,col:order_date,fix:2024-01-26"
|
| 83 |
+
fixes = parse_fix_response(response)
|
| 84 |
+
assert len(fixes) == 1
|
| 85 |
+
|
| 86 |
+
def test_ignores_issue_lines(self):
|
| 87 |
+
response = "row:4,col:name,issue:missing_value\nrow:4,col:name,fix:David Kim"
|
| 88 |
+
fixes = parse_fix_response(response)
|
| 89 |
+
assert len(fixes) == 1 # only the fix line
|
| 90 |
+
|
| 91 |
+
|
| 92 |
class TestBuildUserPrompt:
|
| 93 |
def test_includes_all_fields(self):
|
| 94 |
obs = {
|
|
|
|
| 131 |
prompt = build_user_prompt(obs)
|
| 132 |
assert "FEEDBACK" not in prompt
|
| 133 |
|
| 134 |
+
def test_include_fixes_flag(self):
|
| 135 |
+
obs = {
|
| 136 |
+
"task_description": "Find issues",
|
| 137 |
+
"schema_description": "",
|
| 138 |
+
"validation_rules": "",
|
| 139 |
+
"dataset_csv": "a\n1",
|
| 140 |
+
"num_issues_hint": 0,
|
| 141 |
+
"feedback": "",
|
| 142 |
+
}
|
| 143 |
+
prompt = build_user_prompt(obs, include_fixes=True)
|
| 144 |
+
assert "fix" in prompt.lower()
|
| 145 |
+
|
| 146 |
|
| 147 |
class TestLogFormat:
|
| 148 |
"""Verify stdout log format matches hackathon evaluation requirements."""
|
tests/test_tasks.py
CHANGED
|
@@ -113,8 +113,8 @@ class TestTaskHard:
|
|
| 113 |
def test_task_id(self, task):
|
| 114 |
assert task.task_id == "hard"
|
| 115 |
|
| 116 |
-
def
|
| 117 |
-
assert len(task.planted_issues) ==
|
| 118 |
|
| 119 |
def test_issue_types(self, task):
|
| 120 |
types = {i.issue_type for i in task.planted_issues}
|
|
|
|
| 113 |
def test_task_id(self, task):
|
| 114 |
assert task.task_id == "hard"
|
| 115 |
|
| 116 |
+
def test_has_10_issues(self, task):
|
| 117 |
+
assert len(task.planted_issues) == 10
|
| 118 |
|
| 119 |
def test_issue_types(self, task):
|
| 120 |
types = {i.issue_type for i in task.planted_issues}
|