Spaces:
Sleeping
Sleeping
Commit ·
cd11aba
1
Parent(s): 4c1a85d
fixes v1: add per step reward
Browse files- README.md +198 -41
- dataqa_env/__init__.py +16 -1
- dataqa_env/server/environment.py +77 -4
- dataqa_env/server/tasks.py +158 -36
- inference.py +116 -126
- prevalidation_script.sh +185 -0
- sample_inference_script.py +188 -0
- tests/__init__.py +0 -0
- tests/test_environment.py +249 -0
- tests/test_extensibility.py +185 -0
- tests/test_inference.py +148 -0
- tests/test_tasks.py +161 -0
README.md
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
---
|
| 2 |
title: DataQA Environment Server
|
| 3 |
-
emoji:
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: gray
|
| 6 |
sdk: docker
|
|
@@ -15,49 +15,173 @@ tags:
|
|
| 15 |
|
| 16 |
An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
|
| 17 |
|
| 18 |
-
##
|
| 19 |
|
| 20 |
-
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, inconsistencies, and subtle statistical anomalies. This environment turns that task into a structured, gradable challenge.
|
| 25 |
|
| 26 |
## Environment API
|
| 27 |
|
| 28 |
-
| Endpoint | Description |
|
| 29 |
-
|----------|-------------|
|
| 30 |
-
| `reset
|
| 31 |
-
| `step
|
| 32 |
-
| `state
|
|
|
|
| 33 |
|
| 34 |
## Tasks
|
| 35 |
|
| 36 |
-
| Task | Issues | Difficulty | Description |
|
| 37 |
-
|------|--------|-----------|-------------|
|
| 38 |
-
| `easy` | 4 | Beginner | Employee
|
| 39 |
-
| `medium` | 6 | Intermediate | E-commerce orders
|
| 40 |
-
| `hard` | 8 | Advanced | ML experiment metadata
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
## Reward Function
|
| 43 |
|
| 44 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 45 |
|
| 46 |
-
|
| 47 |
-
- **Recall**: What fraction of planted issues did the agent find?
|
| 48 |
-
- **F1**: `2 * precision * recall / (precision + recall)`
|
| 49 |
|
| 50 |
-
|
| 51 |
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
|
| 56 |
-
|
| 57 |
|
| 58 |
-
|
| 59 |
|
| 60 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
## Quick Start
|
| 63 |
|
|
@@ -68,42 +192,75 @@ pip install -e .
|
|
| 68 |
# Run server locally
|
| 69 |
uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
|
| 70 |
|
| 71 |
-
# Run inference
|
| 72 |
-
API_BASE_URL=https://
|
| 73 |
-
MODEL_NAME=
|
| 74 |
-
|
| 75 |
python inference.py
|
| 76 |
```
|
| 77 |
|
| 78 |
## Docker
|
| 79 |
|
| 80 |
```bash
|
| 81 |
-
docker build -t dataqa-env
|
| 82 |
docker run -p 8000:8000 dataqa-env
|
| 83 |
```
|
| 84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
## Environment Variables
|
| 86 |
|
| 87 |
| Variable | Description | Default |
|
| 88 |
|----------|-------------|---------|
|
| 89 |
-
| `API_BASE_URL` | LLM API endpoint | `https://
|
| 90 |
-
| `MODEL_NAME` | Model identifier | `
|
| 91 |
-
| `HF_TOKEN` | HuggingFace token | - |
|
| 92 |
| `ENV_URL` | Environment server URL | `http://localhost:8000` |
|
| 93 |
-
| `LLM_API_KEY` | API key for LLM provider | Falls back to HF_TOKEN |
|
| 94 |
|
| 95 |
## Architecture
|
| 96 |
|
| 97 |
```
|
| 98 |
dataqa_env/
|
|
|
|
| 99 |
├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
|
| 100 |
├── client.py # EnvClient for WebSocket connections
|
| 101 |
├── server/
|
| 102 |
-
│ ├── environment.py # Core DataQAEnvironment (reset/step/state)
|
| 103 |
-
│ ├── tasks.py # Task definitions +
|
| 104 |
-
│ ├── app.py # FastAPI server
|
| 105 |
│ └── Dockerfile
|
| 106 |
-
|
| 107 |
-
├──
|
| 108 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 109 |
```
|
|
|
|
| 1 |
---
|
| 2 |
title: DataQA Environment Server
|
| 3 |
+
emoji: "\U0001F50D"
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: gray
|
| 6 |
sdk: docker
|
|
|
|
| 15 |
|
| 16 |
An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
|
| 17 |
|
| 18 |
+
## Motivation
|
| 19 |
|
| 20 |
+
Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, logical inconsistencies, and subtle statistical anomalies — before data enters ML pipelines or production databases. This is a genuine, high-frequency human task that directly impacts model quality and business outcomes.
|
| 21 |
|
| 22 |
+
DataQA turns this into a structured, gradable RL environment where agents must systematically inspect corrupted datasets, reason about schema constraints and validation rules, and pinpoint every planted issue — from obvious nulls to subtle data leakage signals that require domain expertise.
|
|
|
|
|
|
|
| 23 |
|
| 24 |
## Environment API
|
| 25 |
|
| 26 |
+
| Endpoint | Method | Description |
|
| 27 |
+
|----------|--------|-------------|
|
| 28 |
+
| `/reset` | POST | Start a new episode with a corrupted dataset |
|
| 29 |
+
| `/step` | POST | Submit identified issues, receive scored feedback |
|
| 30 |
+
| `/state` | GET | Get current episode state |
|
| 31 |
+
| `/health` | GET | Health check |
|
| 32 |
|
| 33 |
## Tasks
|
| 34 |
|
| 35 |
+
| Task | Issues | Difficulty | Domain | Description |
|
| 36 |
+
|------|--------|-----------|--------|-------------|
|
| 37 |
+
| `easy` | 4 | Beginner | HR/Employee data | Nulls, wrong types, duplicates, out-of-range values |
|
| 38 |
+
| `medium` | 6 | Intermediate | E-commerce orders | Format violations, inconsistent computed fields, duplicate keys |
|
| 39 |
+
| `hard` | 8 | Advanced | ML experiment metadata | Data leakage signals, unreasonable GPU memory, timestamp ordering, whitespace-only fields |
|
| 40 |
+
|
| 41 |
+
**Difficulty progression**: Easy issues are individually obvious (empty fields, text in numeric columns). Medium issues require cross-column reasoning (total != qty * price) and set membership checks. Hard issues require ML domain knowledge (val_loss < train_loss = data leakage) and multi-row temporal reasoning.
|
| 42 |
+
|
| 43 |
+
## Action Space
|
| 44 |
+
|
| 45 |
+
The agent submits a list of issue strings, each in the format:
|
| 46 |
+
```
|
| 47 |
+
row:<row_number>,col:<column_name>,issue:<issue_type>
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
- `row_number`: 1-indexed position in the CSV data (after header). Row 1 = first data row.
|
| 51 |
+
- `column_name`: Exact column header name, lowercase.
|
| 52 |
+
- `issue_type`: One of the supported types below.
|
| 53 |
+
|
| 54 |
+
**Supported Issue Types:**
|
| 55 |
+
|
| 56 |
+
| Type | Description | Example |
|
| 57 |
+
|------|-------------|---------|
|
| 58 |
+
| `missing_value` | Null, empty, or whitespace-only | Empty name field |
|
| 59 |
+
| `wrong_type` | Value doesn't match expected type | Salary as "seventy-five thousand" |
|
| 60 |
+
| `duplicate_row` | Exact duplicate or duplicate key | Two rows with same employee_id |
|
| 61 |
+
| `out_of_range` | Value outside valid range | Salary of 5000 when min is 50000 |
|
| 62 |
+
| `format_violation` | Wrong format or invalid enum | Date as DD/MM/YYYY instead of YYYY-MM-DD |
|
| 63 |
+
| `inconsistent_value` | Computed field mismatch, logical inconsistency | total != qty * price |
|
| 64 |
+
| `statistical_outlier` | Unreasonable value given context | resnet18 using 42.5GB GPU |
|
| 65 |
+
| `referential_integrity` | Foreign key violation | (available for custom tasks) |
|
| 66 |
+
|
| 67 |
+
## Observation Space
|
| 68 |
+
|
| 69 |
+
Each observation contains:
|
| 70 |
+
|
| 71 |
+
| Field | Type | Description |
|
| 72 |
+
|-------|------|-------------|
|
| 73 |
+
| `dataset_csv` | str | The corrupted dataset in CSV format |
|
| 74 |
+
| `schema_description` | str | Column types, ranges, and constraints |
|
| 75 |
+
| `validation_rules` | str | Business rules the data must satisfy |
|
| 76 |
+
| `task_description` | str | Task context and instructions |
|
| 77 |
+
| `feedback` | str | Results from previous step (TP/FP/FN counts, precision/recall) |
|
| 78 |
+
| `num_issues_hint` | int | Exact count of planted issues |
|
| 79 |
+
| `max_steps` | int | Maximum attempts allowed |
|
| 80 |
+
| `done` | bool | Whether episode has terminated |
|
| 81 |
+
| `reward` | float | Best weighted reward so far (0.0-1.0) |
|
| 82 |
+
|
| 83 |
+
**Observation Metadata** (available after each step):
|
| 84 |
+
- `f1`, `weighted_reward`, `precision`, `recall`
|
| 85 |
+
- `tp`, `fp`, `fn`
|
| 86 |
+
- `difficulty_found`, `difficulty_missed`
|
| 87 |
|
| 88 |
## Reward Function
|
| 89 |
|
| 90 |
+
### Difficulty-Weighted Reward (Primary)
|
| 91 |
+
|
| 92 |
+
Each planted issue has a **difficulty weight** (1.0-3.0) reflecting how hard it is to detect. The primary reward is a **weighted F1 score** that provides meaningful per-step partial progress signals:
|
| 93 |
+
|
| 94 |
+
| Weight | Category | Examples |
|
| 95 |
+
|--------|----------|----------|
|
| 96 |
+
| 1.0 | Easy | Missing values, obvious out-of-range, wrong type |
|
| 97 |
+
| 1.5-2.0 | Medium | Duplicate keys, format violations, cross-column checks |
|
| 98 |
+
| 2.5-3.0 | Hard | Data leakage, statistical outliers, whitespace-only |
|
| 99 |
+
|
| 100 |
+
**Formula:**
|
| 101 |
+
- **Weighted Recall** = (sum of difficulty weights for found issues) / (total difficulty weight)
|
| 102 |
+
- **Weighted Precision** = (found weight) / (found weight + FP count * avg difficulty)
|
| 103 |
+
- **Weighted F1** = harmonic mean of weighted precision and recall
|
| 104 |
+
|
| 105 |
+
This means:
|
| 106 |
+
- Finding a hard issue (difficulty 3.0) increases reward 3x more than finding an easy one (1.0)
|
| 107 |
+
- False positives are penalized proportionally to average issue difficulty
|
| 108 |
+
- The agent sees meaningful reward differences at every step, not just pass/fail
|
| 109 |
+
|
| 110 |
+
### Standard F1 (also computed)
|
| 111 |
+
|
| 112 |
+
Available in observation metadata for comparison. Uses unweighted set matching.
|
| 113 |
+
|
| 114 |
+
### Episode Boundaries
|
| 115 |
+
|
| 116 |
+
- Each task allows up to 3 steps (attempts)
|
| 117 |
+
- Episode ends when F1 >= 0.999 (perfect) or max steps reached
|
| 118 |
+
- Best score across all steps is the final reward (monotonically non-decreasing)
|
| 119 |
+
- Reward is always in [0.0, 1.0]
|
| 120 |
|
| 121 |
+
## Baseline Scores
|
|
|
|
|
|
|
| 122 |
|
| 123 |
+
Baseline scores using Qwen2.5-72B-Instruct via HuggingFace Router:
|
| 124 |
|
| 125 |
+
| Task | Expected Score Range | Description |
|
| 126 |
+
|------|---------------------|-------------|
|
| 127 |
+
| `easy` | 0.7 - 1.0 | Most LLMs find obvious issues reliably |
|
| 128 |
+
| `medium` | 0.5 - 0.8 | Cross-column reasoning is challenging |
|
| 129 |
+
| `hard` | 0.3 - 0.6 | ML domain knowledge and subtle patterns |
|
| 130 |
|
| 131 |
+
Scores vary by model capability. Frontier models (GPT-4, Claude) typically score higher on the hard task due to better domain reasoning.
|
| 132 |
|
| 133 |
+
## Extensibility
|
| 134 |
|
| 135 |
+
DataQA supports custom tasks, contamination rules, and difficulty levels via a programmatic API.
|
| 136 |
|
| 137 |
+
### Custom Contamination Rules
|
| 138 |
+
|
| 139 |
+
```python
|
| 140 |
+
from dataqa_env import register_contamination_rule
|
| 141 |
+
from dataqa_env.server.tasks import PlantedIssue
|
| 142 |
+
|
| 143 |
+
def swap_digits(rows, header, col_idx, row_idx, rng):
|
| 144 |
+
val = rows[row_idx][col_idx]
|
| 145 |
+
corrupted = val[::-1]
|
| 146 |
+
issue = PlantedIssue(
|
| 147 |
+
row=row_idx + 1, col=header[col_idx],
|
| 148 |
+
issue_type="format_violation",
|
| 149 |
+
description=f"Digits swapped in {header[col_idx]}",
|
| 150 |
+
difficulty=2.0,
|
| 151 |
+
)
|
| 152 |
+
return corrupted, issue
|
| 153 |
+
|
| 154 |
+
register_contamination_rule("swap_digits", swap_digits)
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
### Custom Tasks from Config
|
| 158 |
+
|
| 159 |
+
```python
|
| 160 |
+
from dataqa_env import create_task_from_config, register_task
|
| 161 |
+
|
| 162 |
+
task = create_task_from_config(
|
| 163 |
+
task_id="custom",
|
| 164 |
+
name="Custom Validation",
|
| 165 |
+
description="Find quality issues in this dataset.",
|
| 166 |
+
schema_description="id: int, name: str, score: int (0-100)",
|
| 167 |
+
validation_rules="No missing values. Scores must be 0-100.",
|
| 168 |
+
clean_csv="id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92",
|
| 169 |
+
contaminations=[
|
| 170 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
|
| 171 |
+
{"rule": "negative_value", "row": 2, "col": 2, "difficulty": 1.5},
|
| 172 |
+
],
|
| 173 |
+
)
|
| 174 |
+
register_task("custom", lambda seed: task)
|
| 175 |
+
```
|
| 176 |
+
|
| 177 |
+
### Built-in Contamination Rules
|
| 178 |
+
|
| 179 |
+
| Rule | Effect | Default Difficulty |
|
| 180 |
+
|------|--------|--------------------|
|
| 181 |
+
| `missing_value` | Sets field to empty string | 1.0 |
|
| 182 |
+
| `whitespace_value` | Sets field to single space | 2.5 |
|
| 183 |
+
| `wrong_type_text` | Replaces with random text ("N/A", "null", etc.) | 1.0 |
|
| 184 |
+
| `negative_value` | Negates numeric value | 1.0 |
|
| 185 |
|
| 186 |
## Quick Start
|
| 187 |
|
|
|
|
| 192 |
# Run server locally
|
| 193 |
uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
|
| 194 |
|
| 195 |
+
# Run inference (set your API credentials)
|
| 196 |
+
API_BASE_URL=https://router.huggingface.co/v1 \
|
| 197 |
+
MODEL_NAME=Qwen/Qwen2.5-72B-Instruct \
|
| 198 |
+
HF_TOKEN=your-token \
|
| 199 |
python inference.py
|
| 200 |
```
|
| 201 |
|
| 202 |
## Docker
|
| 203 |
|
| 204 |
```bash
|
| 205 |
+
docker build -t dataqa-env .
|
| 206 |
docker run -p 8000:8000 dataqa-env
|
| 207 |
```
|
| 208 |
|
| 209 |
+
## Testing
|
| 210 |
+
|
| 211 |
+
```bash
|
| 212 |
+
pip install -e ".[dev]"
|
| 213 |
+
pytest tests/ -v
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
89 tests covering:
|
| 217 |
+
- Task creation, corruption, and issue planting (difficulty weights, seed determinism)
|
| 218 |
+
- Issue key parsing (standard, lenient, edge cases)
|
| 219 |
+
- F1 and difficulty-weighted reward computation
|
| 220 |
+
- Full environment reset/step lifecycle
|
| 221 |
+
- Inference script parsing and prompt building
|
| 222 |
+
- **Structured log format** ([START], [STEP], [END] — exact field names and ordering)
|
| 223 |
+
- Score bounds (0.0-1.0), best-score monotonicity
|
| 224 |
+
- Extensibility API (custom rules, custom tasks, environment integration)
|
| 225 |
+
|
| 226 |
+
## Validation
|
| 227 |
+
|
| 228 |
+
```bash
|
| 229 |
+
# OpenEnv spec validation
|
| 230 |
+
openenv validate .
|
| 231 |
+
|
| 232 |
+
# Pre-submission validation (requires HF Space URL)
|
| 233 |
+
./prevalidation_script.sh https://your-space.hf.space
|
| 234 |
+
```
|
| 235 |
+
|
| 236 |
## Environment Variables
|
| 237 |
|
| 238 |
| Variable | Description | Default |
|
| 239 |
|----------|-------------|---------|
|
| 240 |
+
| `API_BASE_URL` | LLM API endpoint | `https://router.huggingface.co/v1` |
|
| 241 |
+
| `MODEL_NAME` | Model identifier | `Qwen/Qwen2.5-72B-Instruct` |
|
| 242 |
+
| `HF_TOKEN` | HuggingFace token / API key | - |
|
| 243 |
| `ENV_URL` | Environment server URL | `http://localhost:8000` |
|
|
|
|
| 244 |
|
| 245 |
## Architecture
|
| 246 |
|
| 247 |
```
|
| 248 |
dataqa_env/
|
| 249 |
+
├── __init__.py # Public API + extensibility exports
|
| 250 |
├── models.py # Pydantic: DataQAAction, DataQAObservation, DataQAState
|
| 251 |
├── client.py # EnvClient for WebSocket connections
|
| 252 |
├── server/
|
| 253 |
+
│ ├── environment.py # Core DataQAEnvironment (reset/step/state + weighted rewards)
|
| 254 |
+
│ ├── tasks.py # Task definitions + contamination rules + extensibility API
|
| 255 |
+
│ ├── app.py # FastAPI server (via openenv-core create_app)
|
| 256 |
│ └── Dockerfile
|
| 257 |
+
tests/
|
| 258 |
+
├── test_tasks.py # Task creation, corruption, difficulty weights
|
| 259 |
+
├── test_environment.py # Environment lifecycle, scoring, metadata
|
| 260 |
+
├── test_inference.py # LLM response parsing, prompt building, log format
|
| 261 |
+
└── test_extensibility.py # Custom rules, custom tasks, registration API
|
| 262 |
+
inference.py # Baseline LLM agent (OpenAI client, structured logs)
|
| 263 |
+
openenv.yaml # OpenEnv/HF Spaces spec
|
| 264 |
+
pyproject.toml # Package metadata and dependencies
|
| 265 |
+
Dockerfile # Production container
|
| 266 |
```
|
dataqa_env/__init__.py
CHANGED
|
@@ -1,4 +1,19 @@
|
|
| 1 |
from .client import DataQAEnv
|
| 2 |
from .models import DataQAAction, DataQAObservation, DataQAState
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
-
__all__ = [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from .client import DataQAEnv
|
| 2 |
from .models import DataQAAction, DataQAObservation, DataQAState
|
| 3 |
+
from .server.tasks import (
|
| 4 |
+
create_task_from_config,
|
| 5 |
+
register_task,
|
| 6 |
+
register_contamination_rule,
|
| 7 |
+
CONTAMINATION_RULES,
|
| 8 |
+
)
|
| 9 |
|
| 10 |
+
__all__ = [
|
| 11 |
+
"DataQAEnv",
|
| 12 |
+
"DataQAAction",
|
| 13 |
+
"DataQAObservation",
|
| 14 |
+
"DataQAState",
|
| 15 |
+
"create_task_from_config",
|
| 16 |
+
"register_task",
|
| 17 |
+
"register_contamination_rule",
|
| 18 |
+
"CONTAMINATION_RULES",
|
| 19 |
+
]
|
dataqa_env/server/environment.py
CHANGED
|
@@ -58,6 +58,61 @@ def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
|
|
| 58 |
return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
class DataQAEnvironment(Environment):
|
| 62 |
"""
|
| 63 |
Data Quality Assurance environment.
|
|
@@ -138,15 +193,21 @@ class DataQAEnvironment(Environment):
|
|
| 138 |
else:
|
| 139 |
parse_errors.append(f"Could not parse: '{raw_issue}'")
|
| 140 |
|
| 141 |
-
# Compute score
|
| 142 |
metrics = compute_f1(reported_keys, self._planted_keys)
|
| 143 |
score = metrics["f1"]
|
| 144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 145 |
self._state.best_score = self._best_score
|
| 146 |
|
| 147 |
# Check if done
|
| 148 |
is_done = (
|
| 149 |
-
score >= 0.999 # Perfect score
|
| 150 |
or self._state.current_step >= self._state.max_steps
|
| 151 |
)
|
| 152 |
|
|
@@ -156,6 +217,7 @@ class DataQAEnvironment(Environment):
|
|
| 156 |
f"Issues reported: {len(reported_keys)}",
|
| 157 |
f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
|
| 158 |
f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
|
|
|
|
| 159 |
]
|
| 160 |
|
| 161 |
if parse_errors:
|
|
@@ -173,7 +235,7 @@ class DataQAEnvironment(Environment):
|
|
| 173 |
)
|
| 174 |
feedback_lines.append("You can submit again with an updated list of issues.")
|
| 175 |
else:
|
| 176 |
-
feedback_lines.append(f"Task complete! Final best
|
| 177 |
|
| 178 |
return DataQAObservation(
|
| 179 |
dataset_csv=self._current_task.corrupted_csv,
|
|
@@ -186,6 +248,17 @@ class DataQAEnvironment(Environment):
|
|
| 186 |
max_steps=self._state.max_steps,
|
| 187 |
done=is_done,
|
| 188 |
reward=self._best_score,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
)
|
| 190 |
|
| 191 |
@property
|
|
|
|
| 58 |
return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
|
| 59 |
|
| 60 |
|
| 61 |
+
def compute_weighted_reward(
|
| 62 |
+
reported_keys: Set[str],
|
| 63 |
+
planted_issues: list,
|
| 64 |
+
) -> dict:
|
| 65 |
+
"""
|
| 66 |
+
Compute difficulty-weighted reward for richer per-step signal.
|
| 67 |
+
|
| 68 |
+
Each planted issue has a difficulty weight (1.0-3.0). Finding harder issues
|
| 69 |
+
earns more reward. False positives incur a penalty scaled by average difficulty.
|
| 70 |
+
|
| 71 |
+
Returns dict with weighted_reward (0.0-1.0), plus per-issue breakdown.
|
| 72 |
+
"""
|
| 73 |
+
if not planted_issues and not reported_keys:
|
| 74 |
+
return {"weighted_reward": 1.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
|
| 75 |
+
|
| 76 |
+
planted_by_key = {issue.to_key(): issue for issue in planted_issues}
|
| 77 |
+
planted_keys = set(planted_by_key.keys())
|
| 78 |
+
|
| 79 |
+
if not reported_keys:
|
| 80 |
+
total_weight = sum(i.difficulty for i in planted_issues)
|
| 81 |
+
return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": total_weight}
|
| 82 |
+
|
| 83 |
+
if not planted_keys:
|
| 84 |
+
return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}
|
| 85 |
+
|
| 86 |
+
# Sum difficulty weights for found vs missed issues
|
| 87 |
+
found_keys = reported_keys & planted_keys
|
| 88 |
+
missed_keys = planted_keys - reported_keys
|
| 89 |
+
false_positive_count = len(reported_keys - planted_keys)
|
| 90 |
+
|
| 91 |
+
difficulty_found = sum(planted_by_key[k].difficulty for k in found_keys)
|
| 92 |
+
difficulty_missed = sum(planted_by_key[k].difficulty for k in missed_keys)
|
| 93 |
+
total_weight = sum(i.difficulty for i in planted_issues)
|
| 94 |
+
|
| 95 |
+
# Weighted recall: proportion of difficulty captured
|
| 96 |
+
weighted_recall = difficulty_found / total_weight if total_weight > 0 else 0.0
|
| 97 |
+
|
| 98 |
+
# Penalty for false positives (scaled by avg difficulty so penalty is proportional)
|
| 99 |
+
avg_difficulty = total_weight / len(planted_issues)
|
| 100 |
+
fp_penalty_weight = false_positive_count * avg_difficulty
|
| 101 |
+
weighted_precision = difficulty_found / (difficulty_found + fp_penalty_weight) if (difficulty_found + fp_penalty_weight) > 0 else 0.0
|
| 102 |
+
|
| 103 |
+
# Weighted F1
|
| 104 |
+
if (weighted_precision + weighted_recall) > 0:
|
| 105 |
+
weighted_reward = 2 * weighted_precision * weighted_recall / (weighted_precision + weighted_recall)
|
| 106 |
+
else:
|
| 107 |
+
weighted_reward = 0.0
|
| 108 |
+
|
| 109 |
+
return {
|
| 110 |
+
"weighted_reward": round(weighted_reward, 4),
|
| 111 |
+
"difficulty_found": round(difficulty_found, 2),
|
| 112 |
+
"difficulty_missed": round(difficulty_missed, 2),
|
| 113 |
+
}
|
| 114 |
+
|
| 115 |
+
|
| 116 |
class DataQAEnvironment(Environment):
|
| 117 |
"""
|
| 118 |
Data Quality Assurance environment.
|
|
|
|
| 193 |
else:
|
| 194 |
parse_errors.append(f"Could not parse: '{raw_issue}'")
|
| 195 |
|
| 196 |
+
# Compute score (standard F1)
|
| 197 |
metrics = compute_f1(reported_keys, self._planted_keys)
|
| 198 |
score = metrics["f1"]
|
| 199 |
+
|
| 200 |
+
# Compute difficulty-weighted reward (richer per-step signal)
|
| 201 |
+
weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
|
| 202 |
+
weighted_reward = weighted["weighted_reward"]
|
| 203 |
+
|
| 204 |
+
# Use weighted reward as the primary reward signal
|
| 205 |
+
self._best_score = max(self._best_score, weighted_reward)
|
| 206 |
self._state.best_score = self._best_score
|
| 207 |
|
| 208 |
# Check if done
|
| 209 |
is_done = (
|
| 210 |
+
score >= 0.999 # Perfect score (all issues found exactly)
|
| 211 |
or self._state.current_step >= self._state.max_steps
|
| 212 |
)
|
| 213 |
|
|
|
|
| 217 |
f"Issues reported: {len(reported_keys)}",
|
| 218 |
f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
|
| 219 |
f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
|
| 220 |
+
f"Weighted reward: {weighted_reward:.3f} (difficulty found: {weighted['difficulty_found']}, missed: {weighted['difficulty_missed']})",
|
| 221 |
]
|
| 222 |
|
| 223 |
if parse_errors:
|
|
|
|
| 235 |
)
|
| 236 |
feedback_lines.append("You can submit again with an updated list of issues.")
|
| 237 |
else:
|
| 238 |
+
feedback_lines.append(f"Task complete! Final best weighted reward: {self._best_score:.3f}")
|
| 239 |
|
| 240 |
return DataQAObservation(
|
| 241 |
dataset_csv=self._current_task.corrupted_csv,
|
|
|
|
| 248 |
max_steps=self._state.max_steps,
|
| 249 |
done=is_done,
|
| 250 |
reward=self._best_score,
|
| 251 |
+
metadata={
|
| 252 |
+
"f1": score,
|
| 253 |
+
"weighted_reward": weighted_reward,
|
| 254 |
+
"precision": metrics["precision"],
|
| 255 |
+
"recall": metrics["recall"],
|
| 256 |
+
"tp": metrics["tp"],
|
| 257 |
+
"fp": metrics["fp"],
|
| 258 |
+
"fn": metrics["fn"],
|
| 259 |
+
"difficulty_found": weighted["difficulty_found"],
|
| 260 |
+
"difficulty_missed": weighted["difficulty_missed"],
|
| 261 |
+
},
|
| 262 |
)
|
| 263 |
|
| 264 |
@property
|
dataqa_env/server/tasks.py
CHANGED
|
@@ -25,6 +25,7 @@ class PlantedIssue:
|
|
| 25 |
col: str
|
| 26 |
issue_type: str
|
| 27 |
description: str
|
|
|
|
| 28 |
|
| 29 |
def to_key(self) -> str:
|
| 30 |
return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
|
|
@@ -93,29 +94,29 @@ def create_task_easy(seed: int = 42) -> Task:
|
|
| 93 |
data = rows[1:]
|
| 94 |
issues: List[PlantedIssue] = []
|
| 95 |
|
| 96 |
-
# Issue 1: Missing value - null out a name
|
| 97 |
r = 3 # row index in data (0-based), displayed as row 4 in CSV
|
| 98 |
data[r][1] = ""
|
| 99 |
issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
|
| 100 |
-
description="Empty name field"))
|
| 101 |
|
| 102 |
-
# Issue 2: Wrong type - salary as text
|
| 103 |
r = 6
|
| 104 |
data[r][4] = "seventy-five thousand"
|
| 105 |
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
|
| 106 |
-
description="Salary is text instead of integer"))
|
| 107 |
|
| 108 |
-
# Issue 3: Duplicate row
|
| 109 |
dup_source = 1
|
| 110 |
data.append(list(data[dup_source]))
|
| 111 |
issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
|
| 112 |
-
description=f"Exact duplicate of row {dup_source + 1}"))
|
| 113 |
|
| 114 |
-
# Issue 4: Out of range salary
|
| 115 |
r = 8
|
| 116 |
data[r][4] = "5000"
|
| 117 |
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
|
| 118 |
-
description="Salary 5000 is below minimum 50000"))
|
| 119 |
|
| 120 |
corrupted = _rows_to_csv([header] + data)
|
| 121 |
|
|
@@ -190,41 +191,41 @@ ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.
|
|
| 190 |
data = rows[1:]
|
| 191 |
issues: List[PlantedIssue] = []
|
| 192 |
|
| 193 |
-
# Issue 1: total doesn't match quantity * unit_price
|
| 194 |
r = 4 # ORD-005
|
| 195 |
data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
|
| 196 |
issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
|
| 197 |
-
description="total (84.00) != quantity (1) * unit_price (42.00)"))
|
| 198 |
|
| 199 |
-
# Issue 2: Invalid category
|
| 200 |
r = 9 # ORD-010
|
| 201 |
data[r][3] = "Fitness" # should be Sports
|
| 202 |
issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
|
| 203 |
-
description="'Fitness' is not in allowed categories"))
|
| 204 |
|
| 205 |
-
# Issue 3: Missing value in product_name
|
| 206 |
r = 13 # ORD-014
|
| 207 |
data[r][2] = ""
|
| 208 |
issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
|
| 209 |
-
description="Empty product_name"))
|
| 210 |
|
| 211 |
-
# Issue 4: Out of range quantity
|
| 212 |
r = 16 # ORD-017
|
| 213 |
data[r][4] = "-1"
|
| 214 |
issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
|
| 215 |
-
description="Negative quantity"))
|
| 216 |
|
| 217 |
-
# Issue 5: Duplicate order_id
|
| 218 |
r = 18 # ORD-019
|
| 219 |
data[r][0] = "ORD-003"
|
| 220 |
issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
|
| 221 |
-
description="Duplicate order_id ORD-003"))
|
| 222 |
|
| 223 |
-
# Issue 6: Wrong date format
|
| 224 |
r = 11 # ORD-012
|
| 225 |
data[r][6] = "26/01/2024"
|
| 226 |
issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
|
| 227 |
-
description="Date format DD/MM/YYYY instead of YYYY-MM-DD"))
|
| 228 |
|
| 229 |
corrupted = _rows_to_csv([header] + data)
|
| 230 |
|
|
@@ -301,53 +302,57 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
|
|
| 301 |
data = rows[1:]
|
| 302 |
issues: List[PlantedIssue] = []
|
| 303 |
|
| 304 |
-
# Issue 1: Data leakage signal — val_loss much lower than train_loss
|
| 305 |
r = 4 # EXP-005
|
| 306 |
data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
|
| 307 |
issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
|
| 308 |
-
description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage"
|
|
|
|
| 309 |
|
| 310 |
-
# Issue 2: Batch size not power of 2
|
| 311 |
r = 8 # EXP-009
|
| 312 |
data[r][7] = "250" # not a power of 2
|
| 313 |
issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
|
| 314 |
-
description="batch_size 250 is not a power of 2"))
|
| 315 |
|
| 316 |
-
# Issue 3: GPU memory unreasonable for model
|
| 317 |
r = 6 # EXP-007 resnet18 on cifar10
|
| 318 |
data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
|
| 319 |
issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
|
| 320 |
-
description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable"
|
|
|
|
| 321 |
|
| 322 |
-
# Issue 4: Timestamp out of order
|
| 323 |
r = 10 # EXP-011
|
| 324 |
data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
|
| 325 |
issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
|
| 326 |
-
description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11"
|
|
|
|
| 327 |
|
| 328 |
-
# Issue 5: Train size smaller than test size
|
| 329 |
r = 9 # EXP-010
|
| 330 |
data[r][3] = "500" # train_size=500 but test_size=1821
|
| 331 |
issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
|
| 332 |
-
description="train_size (500) is smaller than test_size (1821)"
|
|
|
|
| 333 |
|
| 334 |
-
# Issue 6: Negative training time
|
| 335 |
r = 13 # EXP-014
|
| 336 |
data[r][13] = "-72.0"
|
| 337 |
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
|
| 338 |
-
description="Negative training time"))
|
| 339 |
|
| 340 |
-
# Issue 7: Learning rate out of range
|
| 341 |
r = 12 # EXP-013
|
| 342 |
data[r][6] = "2.5" # way too high
|
| 343 |
issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
|
| 344 |
-
description="Learning rate 2.5 exceeds maximum of 1.0"))
|
| 345 |
|
| 346 |
-
# Issue 8: Missing model name (
|
| 347 |
r = 14 # EXP-015
|
| 348 |
data[r][1] = " "
|
| 349 |
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
|
| 350 |
-
description="model_name is whitespace-only"))
|
| 351 |
|
| 352 |
corrupted = _rows_to_csv([header] + data)
|
| 353 |
|
|
@@ -370,6 +375,123 @@ EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0
|
|
| 370 |
)
|
| 371 |
|
| 372 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
# ---------------------------------------------------------------------------
|
| 374 |
# Task registry
|
| 375 |
# ---------------------------------------------------------------------------
|
|
|
|
| 25 |
col: str
|
| 26 |
issue_type: str
|
| 27 |
description: str
|
| 28 |
+
difficulty: float = 1.0 # 1.0=easy, 2.0=medium, 3.0=hard (for weighted reward)
|
| 29 |
|
| 30 |
def to_key(self) -> str:
|
| 31 |
return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
|
|
|
|
| 94 |
data = rows[1:]
|
| 95 |
issues: List[PlantedIssue] = []
|
| 96 |
|
| 97 |
+
# Issue 1: Missing value - null out a name (easy to spot)
|
| 98 |
r = 3 # row index in data (0-based), displayed as row 4 in CSV
|
| 99 |
data[r][1] = ""
|
| 100 |
issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
|
| 101 |
+
description="Empty name field", difficulty=1.0))
|
| 102 |
|
| 103 |
+
# Issue 2: Wrong type - salary as text (easy to spot)
|
| 104 |
r = 6
|
| 105 |
data[r][4] = "seventy-five thousand"
|
| 106 |
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
|
| 107 |
+
description="Salary is text instead of integer", difficulty=1.0))
|
| 108 |
|
| 109 |
+
# Issue 3: Duplicate row (moderate — requires cross-row comparison)
|
| 110 |
dup_source = 1
|
| 111 |
data.append(list(data[dup_source]))
|
| 112 |
issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
|
| 113 |
+
description=f"Exact duplicate of row {dup_source + 1}", difficulty=1.5))
|
| 114 |
|
| 115 |
+
# Issue 4: Out of range salary (easy to spot)
|
| 116 |
r = 8
|
| 117 |
data[r][4] = "5000"
|
| 118 |
issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
|
| 119 |
+
description="Salary 5000 is below minimum 50000", difficulty=1.0))
|
| 120 |
|
| 121 |
corrupted = _rows_to_csv([header] + data)
|
| 122 |
|
|
|
|
| 191 |
data = rows[1:]
|
| 192 |
issues: List[PlantedIssue] = []
|
| 193 |
|
| 194 |
+
# Issue 1: total doesn't match quantity * unit_price (requires cross-column check)
|
| 195 |
r = 4 # ORD-005
|
| 196 |
data[r][9] = "84.00" # should be 42.00 (qty=1, price=42.00)
|
| 197 |
issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
|
| 198 |
+
description="total (84.00) != quantity (1) * unit_price (42.00)", difficulty=2.0))
|
| 199 |
|
| 200 |
+
# Issue 2: Invalid category (requires knowing the allowed set)
|
| 201 |
r = 9 # ORD-010
|
| 202 |
data[r][3] = "Fitness" # should be Sports
|
| 203 |
issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
|
| 204 |
+
description="'Fitness' is not in allowed categories", difficulty=1.5))
|
| 205 |
|
| 206 |
+
# Issue 3: Missing value in product_name (easy to spot)
|
| 207 |
r = 13 # ORD-014
|
| 208 |
data[r][2] = ""
|
| 209 |
issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
|
| 210 |
+
description="Empty product_name", difficulty=1.0))
|
| 211 |
|
| 212 |
+
# Issue 4: Out of range quantity (easy to spot)
|
| 213 |
r = 16 # ORD-017
|
| 214 |
data[r][4] = "-1"
|
| 215 |
issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
|
| 216 |
+
description="Negative quantity", difficulty=1.0))
|
| 217 |
|
| 218 |
+
# Issue 5: Duplicate order_id (requires cross-row comparison)
|
| 219 |
r = 18 # ORD-019
|
| 220 |
data[r][0] = "ORD-003"
|
| 221 |
issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
|
| 222 |
+
description="Duplicate order_id ORD-003", difficulty=1.5))
|
| 223 |
|
| 224 |
+
# Issue 6: Wrong date format (moderate — format mismatch)
|
| 225 |
r = 11 # ORD-012
|
| 226 |
data[r][6] = "26/01/2024"
|
| 227 |
issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
|
| 228 |
+
description="Date format DD/MM/YYYY instead of YYYY-MM-DD", difficulty=1.5))
|
| 229 |
|
| 230 |
corrupted = _rows_to_csv([header] + data)
|
| 231 |
|
|
|
|
| 302 |
data = rows[1:]
|
| 303 |
issues: List[PlantedIssue] = []
|
| 304 |
|
| 305 |
+
# Issue 1: Data leakage signal — val_loss much lower than train_loss (hard — requires ML knowledge)
|
| 306 |
r = 4 # EXP-005
|
| 307 |
data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
|
| 308 |
issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
|
| 309 |
+
description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage",
|
| 310 |
+
difficulty=3.0))
|
| 311 |
|
| 312 |
+
# Issue 2: Batch size not power of 2 (moderate — domain convention)
|
| 313 |
r = 8 # EXP-009
|
| 314 |
data[r][7] = "250" # not a power of 2
|
| 315 |
issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
|
| 316 |
+
description="batch_size 250 is not a power of 2", difficulty=2.0))
|
| 317 |
|
| 318 |
+
# Issue 3: GPU memory unreasonable for model (hard — requires model size reasoning)
|
| 319 |
r = 6 # EXP-007 resnet18 on cifar10
|
| 320 |
data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
|
| 321 |
issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
|
| 322 |
+
description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable",
|
| 323 |
+
difficulty=3.0))
|
| 324 |
|
| 325 |
+
# Issue 4: Timestamp out of order (moderate — requires sequential comparison)
|
| 326 |
r = 10 # EXP-011
|
| 327 |
data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
|
| 328 |
issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
|
| 329 |
+
description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11",
|
| 330 |
+
difficulty=2.0))
|
| 331 |
|
| 332 |
+
# Issue 5: Train size smaller than test size (moderate — cross-column logic)
|
| 333 |
r = 9 # EXP-010
|
| 334 |
data[r][3] = "500" # train_size=500 but test_size=1821
|
| 335 |
issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
|
| 336 |
+
description="train_size (500) is smaller than test_size (1821)",
|
| 337 |
+
difficulty=2.0))
|
| 338 |
|
| 339 |
+
# Issue 6: Negative training time (easy to spot)
|
| 340 |
r = 13 # EXP-014
|
| 341 |
data[r][13] = "-72.0"
|
| 342 |
issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
|
| 343 |
+
description="Negative training time", difficulty=1.0))
|
| 344 |
|
| 345 |
+
# Issue 7: Learning rate out of range (easy to spot)
|
| 346 |
r = 12 # EXP-013
|
| 347 |
data[r][6] = "2.5" # way too high
|
| 348 |
issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
|
| 349 |
+
description="Learning rate 2.5 exceeds maximum of 1.0", difficulty=1.5))
|
| 350 |
|
| 351 |
+
# Issue 8: Missing model name (hard — whitespace-only is subtle)
|
| 352 |
r = 14 # EXP-015
|
| 353 |
data[r][1] = " "
|
| 354 |
issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
|
| 355 |
+
description="model_name is whitespace-only", difficulty=2.5))
|
| 356 |
|
| 357 |
corrupted = _rows_to_csv([header] + data)
|
| 358 |
|
|
|
|
| 375 |
)
|
| 376 |
|
| 377 |
|
| 378 |
+
# ---------------------------------------------------------------------------
|
| 379 |
+
# Contamination rules for extensible task creation
|
| 380 |
+
# ---------------------------------------------------------------------------
|
| 381 |
+
|
| 382 |
+
# Each contamination rule is a callable: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
|
| 383 |
+
# Users can define their own and register them.
|
| 384 |
+
|
| 385 |
+
CONTAMINATION_RULES = {
|
| 386 |
+
"missing_value": lambda rows, header, col_idx, row_idx, rng: (
|
| 387 |
+
"",
|
| 388 |
+
PlantedIssue(
|
| 389 |
+
row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
|
| 390 |
+
description=f"Empty {header[col_idx]} field", difficulty=1.0,
|
| 391 |
+
),
|
| 392 |
+
),
|
| 393 |
+
"whitespace_value": lambda rows, header, col_idx, row_idx, rng: (
|
| 394 |
+
" ",
|
| 395 |
+
PlantedIssue(
|
| 396 |
+
row=row_idx + 1, col=header[col_idx], issue_type="missing_value",
|
| 397 |
+
description=f"Whitespace-only {header[col_idx]} field", difficulty=2.5,
|
| 398 |
+
),
|
| 399 |
+
),
|
| 400 |
+
"wrong_type_text": lambda rows, header, col_idx, row_idx, rng: (
|
| 401 |
+
rng.choice(["not-a-number", "N/A", "null", "undefined"]),
|
| 402 |
+
PlantedIssue(
|
| 403 |
+
row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
|
| 404 |
+
description=f"{header[col_idx]} is text instead of expected type", difficulty=1.0,
|
| 405 |
+
),
|
| 406 |
+
),
|
| 407 |
+
"negative_value": lambda rows, header, col_idx, row_idx, rng: (
|
| 408 |
+
str(-abs(float(rows[row_idx][col_idx]) if rows[row_idx][col_idx] else 1)),
|
| 409 |
+
PlantedIssue(
|
| 410 |
+
row=row_idx + 1, col=header[col_idx], issue_type="out_of_range",
|
| 411 |
+
description=f"Negative {header[col_idx]}", difficulty=1.0,
|
| 412 |
+
),
|
| 413 |
+
),
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
def create_task_from_config(
|
| 418 |
+
task_id: str,
|
| 419 |
+
name: str,
|
| 420 |
+
description: str,
|
| 421 |
+
schema_description: str,
|
| 422 |
+
validation_rules: str,
|
| 423 |
+
clean_csv: str,
|
| 424 |
+
contaminations: List[dict],
|
| 425 |
+
max_steps: int = 3,
|
| 426 |
+
seed: int = 42,
|
| 427 |
+
) -> Task:
|
| 428 |
+
"""
|
| 429 |
+
Create a custom task from a configuration dict.
|
| 430 |
+
|
| 431 |
+
Each contamination entry should have:
|
| 432 |
+
- rule: str (key in CONTAMINATION_RULES) or callable
|
| 433 |
+
- row: int (0-based row index in data)
|
| 434 |
+
- col: int (column index in header)
|
| 435 |
+
- difficulty: float (optional, overrides rule default)
|
| 436 |
+
|
| 437 |
+
Example:
|
| 438 |
+
contaminations = [
|
| 439 |
+
{"rule": "missing_value", "row": 2, "col": 1, "difficulty": 1.5},
|
| 440 |
+
{"rule": "negative_value", "row": 5, "col": 4},
|
| 441 |
+
]
|
| 442 |
+
"""
|
| 443 |
+
rng = random.Random(seed)
|
| 444 |
+
rows = _csv_to_rows(clean_csv)
|
| 445 |
+
header = rows[0]
|
| 446 |
+
data = rows[1:]
|
| 447 |
+
issues: List[PlantedIssue] = []
|
| 448 |
+
|
| 449 |
+
for spec in contaminations:
|
| 450 |
+
rule = spec["rule"]
|
| 451 |
+
row_idx = spec["row"]
|
| 452 |
+
col_idx = spec["col"]
|
| 453 |
+
|
| 454 |
+
if callable(rule):
|
| 455 |
+
new_val, issue = rule(data, header, col_idx, row_idx, rng)
|
| 456 |
+
elif rule in CONTAMINATION_RULES:
|
| 457 |
+
new_val, issue = CONTAMINATION_RULES[rule](data, header, col_idx, row_idx, rng)
|
| 458 |
+
else:
|
| 459 |
+
raise ValueError(f"Unknown contamination rule: {rule}. Available: {list(CONTAMINATION_RULES.keys())}")
|
| 460 |
+
|
| 461 |
+
data[row_idx][col_idx] = new_val
|
| 462 |
+
if "difficulty" in spec:
|
| 463 |
+
issue.difficulty = spec["difficulty"]
|
| 464 |
+
issues.append(issue)
|
| 465 |
+
|
| 466 |
+
corrupted = _rows_to_csv([header] + data)
|
| 467 |
+
|
| 468 |
+
return Task(
|
| 469 |
+
task_id=task_id,
|
| 470 |
+
name=name,
|
| 471 |
+
description=description,
|
| 472 |
+
schema_description=schema_description,
|
| 473 |
+
validation_rules=validation_rules,
|
| 474 |
+
clean_csv=clean_csv,
|
| 475 |
+
planted_issues=issues,
|
| 476 |
+
corrupted_csv=corrupted,
|
| 477 |
+
max_steps=max_steps,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def register_task(task_id: str, factory_fn):
|
| 482 |
+
"""Register a custom task factory. Factory must accept (seed: int) -> Task."""
|
| 483 |
+
TASK_REGISTRY[task_id] = factory_fn
|
| 484 |
+
|
| 485 |
+
|
| 486 |
+
def register_contamination_rule(name: str, rule_fn):
|
| 487 |
+
"""
|
| 488 |
+
Register a custom contamination rule.
|
| 489 |
+
|
| 490 |
+
rule_fn signature: (rows, header, col_idx, row_idx, rng) -> (new_value, PlantedIssue)
|
| 491 |
+
"""
|
| 492 |
+
CONTAMINATION_RULES[name] = rule_fn
|
| 493 |
+
|
| 494 |
+
|
| 495 |
# ---------------------------------------------------------------------------
|
| 496 |
# Task registry
|
| 497 |
# ---------------------------------------------------------------------------
|
inference.py
CHANGED
|
@@ -6,21 +6,23 @@ LLM agent that plays the DataQA environment.
|
|
| 6 |
Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
|
| 7 |
|
| 8 |
Required environment variables:
|
| 9 |
-
API_BASE_URL - LLM API endpoint (e.g., https://
|
| 10 |
-
MODEL_NAME - Model identifier (e.g.,
|
| 11 |
-
HF_TOKEN - HuggingFace token
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
"""
|
| 15 |
|
| 16 |
from __future__ import annotations
|
| 17 |
|
| 18 |
-
import json
|
| 19 |
import os
|
| 20 |
import re
|
| 21 |
import sys
|
| 22 |
import time
|
| 23 |
-
from typing import Optional
|
| 24 |
|
| 25 |
import requests
|
| 26 |
from openai import OpenAI
|
|
@@ -28,52 +30,43 @@ from openai import OpenAI
|
|
| 28 |
# ---------------------------------------------------------------------------
|
| 29 |
# Configuration
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
-
API_BASE_URL = os.
|
| 32 |
-
MODEL_NAME = os.
|
| 33 |
-
|
| 34 |
-
ENV_URL = os.
|
| 35 |
|
|
|
|
| 36 |
TASKS = ["easy", "medium", "hard"]
|
| 37 |
MAX_STEPS_PER_TASK = 3
|
| 38 |
|
|
|
|
| 39 |
# ---------------------------------------------------------------------------
|
| 40 |
-
# Logging helpers (structured stdout
|
| 41 |
# ---------------------------------------------------------------------------
|
| 42 |
|
| 43 |
-
def log_start(
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
}
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
def log_end(task_id: str, final_score: float, metadata: Optional[dict] = None):
|
| 64 |
-
entry = {
|
| 65 |
-
"event": "END",
|
| 66 |
-
"task_id": task_id,
|
| 67 |
-
"final_score": final_score,
|
| 68 |
-
"timestamp": time.time(),
|
| 69 |
-
}
|
| 70 |
-
if metadata:
|
| 71 |
-
entry["metadata"] = metadata
|
| 72 |
-
print(f"[END] {json.dumps(entry)}", flush=True)
|
| 73 |
|
| 74 |
|
| 75 |
# ---------------------------------------------------------------------------
|
| 76 |
-
# Environment HTTP client
|
| 77 |
# ---------------------------------------------------------------------------
|
| 78 |
|
| 79 |
class EnvHTTPClient:
|
|
@@ -108,11 +101,6 @@ class EnvHTTPClient:
|
|
| 108 |
r.raise_for_status()
|
| 109 |
return r.json()
|
| 110 |
|
| 111 |
-
def state(self) -> dict:
|
| 112 |
-
r = self.session.get(f"{self.base_url}/state", timeout=10)
|
| 113 |
-
r.raise_for_status()
|
| 114 |
-
return r.json()
|
| 115 |
-
|
| 116 |
|
| 117 |
# ---------------------------------------------------------------------------
|
| 118 |
# LLM Agent
|
|
@@ -142,7 +130,6 @@ CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
|
|
| 142 |
- Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
|
| 143 |
- Row 1 = the FIRST data row after the header
|
| 144 |
- Row 2 = the SECOND data row after the header
|
| 145 |
-
- For example, if the CSV has header on line 1 and data starting on line 2, the data on line 2 is row 1, line 3 is row 2, etc.
|
| 146 |
- DO NOT use the employee_id, order_id, or experiment_id as the row number
|
| 147 |
- Column names must match exactly (use the CSV header names, lowercase)
|
| 148 |
- Check EVERY row and EVERY column systematically
|
|
@@ -188,14 +175,12 @@ def parse_llm_response(response: str) -> list[str]:
|
|
| 188 |
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 189 |
line = line.strip()
|
| 190 |
if "row" in line.lower() and "col" in line.lower():
|
| 191 |
-
# Lenient regex: accept : or = as delimiters, case-insensitive
|
| 192 |
match = re.search(
|
| 193 |
r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
|
| 194 |
line,
|
| 195 |
re.IGNORECASE,
|
| 196 |
)
|
| 197 |
if match:
|
| 198 |
-
# Normalize to lowercase canonical format
|
| 199 |
normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
|
| 200 |
issues.append(normalized)
|
| 201 |
return issues
|
|
@@ -203,68 +188,84 @@ def parse_llm_response(response: str) -> list[str]:
|
|
| 203 |
|
| 204 |
def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
| 205 |
"""Run a single task and return the best score."""
|
| 206 |
-
log_start(task_id)
|
| 207 |
-
|
| 208 |
-
# Reset environment for this task
|
| 209 |
-
reset_response = env.reset(task_id=task_id)
|
| 210 |
-
observation = reset_response.get("observation", reset_response)
|
| 211 |
|
|
|
|
|
|
|
| 212 |
best_score = 0.0
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
if not issues:
|
| 244 |
-
print(f"[WARN] No issues parsed from LLM response for {task_id} step {step_num}", file=sys.stderr, flush=True)
|
| 245 |
-
|
| 246 |
-
# Submit to environment
|
| 247 |
-
step_response = env.step(issues, task_id=task_id)
|
| 248 |
-
observation = step_response.get("observation", step_response)
|
| 249 |
-
|
| 250 |
-
# reward and done are at the top level of the response, not inside observation
|
| 251 |
-
reward = float(step_response.get("reward", 0.0) or 0.0)
|
| 252 |
-
done = bool(step_response.get("done", False))
|
| 253 |
-
best_score = max(best_score, reward)
|
| 254 |
|
| 255 |
-
|
| 256 |
-
"
|
| 257 |
-
"
|
| 258 |
-
})
|
| 259 |
|
| 260 |
-
|
| 261 |
-
break
|
| 262 |
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
messages.append({"role": "assistant", "content": llm_output})
|
| 266 |
|
| 267 |
-
log_end(task_id, best_score)
|
| 268 |
return best_score
|
| 269 |
|
| 270 |
|
|
@@ -273,49 +274,38 @@ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
|
| 273 |
# ---------------------------------------------------------------------------
|
| 274 |
|
| 275 |
def main():
|
| 276 |
-
print(f"[
|
| 277 |
-
print(f"[
|
| 278 |
-
print(f"[
|
| 279 |
-
print(f"[
|
| 280 |
|
| 281 |
# Initialize clients
|
| 282 |
env = EnvHTTPClient(ENV_URL)
|
| 283 |
llm_client = OpenAI(
|
| 284 |
base_url=API_BASE_URL,
|
| 285 |
-
api_key=
|
| 286 |
)
|
| 287 |
|
| 288 |
# Check environment health
|
| 289 |
if not env.health():
|
| 290 |
-
print("[
|
| 291 |
sys.exit(1)
|
| 292 |
|
| 293 |
-
print(f"[
|
| 294 |
|
| 295 |
# Run all tasks
|
| 296 |
scores = {}
|
| 297 |
for task_id in TASKS:
|
| 298 |
-
print(f"\n{'='*60}", flush=True)
|
| 299 |
-
print(f"[INFO] Starting task: {task_id}", flush=True)
|
| 300 |
-
print(f"{'='*60}", flush=True)
|
| 301 |
-
|
| 302 |
try:
|
| 303 |
score = run_task(llm_client, env, task_id)
|
| 304 |
scores[task_id] = score
|
| 305 |
-
print(f"[INFO] Task {task_id} completed with score: {score:.3f}", flush=True)
|
| 306 |
except Exception as e:
|
| 307 |
-
print(f"[
|
| 308 |
scores[task_id] = 0.0
|
| 309 |
|
| 310 |
-
# Summary
|
| 311 |
-
print(f"\n{'='*60}", flush=True)
|
| 312 |
-
print("[INFO] FINAL RESULTS", flush=True)
|
| 313 |
-
print(f"{'='*60}", flush=True)
|
| 314 |
-
for task_id, score in scores.items():
|
| 315 |
-
print(f"[INFO] {task_id}: {score:.3f}", flush=True)
|
| 316 |
-
|
| 317 |
avg_score = sum(scores.values()) / len(scores) if scores else 0.0
|
| 318 |
-
print(f"[
|
| 319 |
|
| 320 |
|
| 321 |
if __name__ == "__main__":
|
|
|
|
| 6 |
Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
|
| 7 |
|
| 8 |
Required environment variables:
|
| 9 |
+
API_BASE_URL - LLM API endpoint (e.g., https://router.huggingface.co/v1)
|
| 10 |
+
MODEL_NAME - Model identifier (e.g., Qwen/Qwen2.5-72B-Instruct)
|
| 11 |
+
HF_TOKEN - HuggingFace token / API key
|
| 12 |
+
|
| 13 |
+
STDOUT FORMAT (mandatory for evaluation):
|
| 14 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 15 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 16 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
| 17 |
"""
|
| 18 |
|
| 19 |
from __future__ import annotations
|
| 20 |
|
|
|
|
| 21 |
import os
|
| 22 |
import re
|
| 23 |
import sys
|
| 24 |
import time
|
| 25 |
+
from typing import List, Optional
|
| 26 |
|
| 27 |
import requests
|
| 28 |
from openai import OpenAI
|
|
|
|
| 30 |
# ---------------------------------------------------------------------------
|
| 31 |
# Configuration
|
| 32 |
# ---------------------------------------------------------------------------
|
| 33 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 34 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 35 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
|
| 36 |
+
ENV_URL = os.getenv("ENV_URL", "http://localhost:8000")
|
| 37 |
|
| 38 |
+
BENCHMARK = "dataqa_env"
|
| 39 |
TASKS = ["easy", "medium", "hard"]
|
| 40 |
MAX_STEPS_PER_TASK = 3
|
| 41 |
|
| 42 |
+
|
| 43 |
# ---------------------------------------------------------------------------
|
| 44 |
+
# Logging helpers (structured stdout — exact format required by evaluation)
|
| 45 |
# ---------------------------------------------------------------------------
|
| 46 |
|
| 47 |
+
def log_start(task: str, env: str, model: str) -> None:
|
| 48 |
+
print(f"[START] task={task} env={env} model={model}", flush=True)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
|
| 52 |
+
error_val = error if error else "null"
|
| 53 |
+
done_val = str(done).lower()
|
| 54 |
+
print(
|
| 55 |
+
f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}",
|
| 56 |
+
flush=True,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
|
| 61 |
+
rewards_str = ",".join(f"{r:.2f}" for r in rewards)
|
| 62 |
+
print(
|
| 63 |
+
f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}",
|
| 64 |
+
flush=True,
|
| 65 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
|
| 67 |
|
| 68 |
# ---------------------------------------------------------------------------
|
| 69 |
+
# Environment HTTP client
|
| 70 |
# ---------------------------------------------------------------------------
|
| 71 |
|
| 72 |
class EnvHTTPClient:
|
|
|
|
| 101 |
r.raise_for_status()
|
| 102 |
return r.json()
|
| 103 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
# ---------------------------------------------------------------------------
|
| 106 |
# LLM Agent
|
|
|
|
| 130 |
- Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
|
| 131 |
- Row 1 = the FIRST data row after the header
|
| 132 |
- Row 2 = the SECOND data row after the header
|
|
|
|
| 133 |
- DO NOT use the employee_id, order_id, or experiment_id as the row number
|
| 134 |
- Column names must match exactly (use the CSV header names, lowercase)
|
| 135 |
- Check EVERY row and EVERY column systematically
|
|
|
|
| 175 |
line = re.sub(r"^\s*[-*]\s*", "", line)
|
| 176 |
line = line.strip()
|
| 177 |
if "row" in line.lower() and "col" in line.lower():
|
|
|
|
| 178 |
match = re.search(
|
| 179 |
r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
|
| 180 |
line,
|
| 181 |
re.IGNORECASE,
|
| 182 |
)
|
| 183 |
if match:
|
|
|
|
| 184 |
normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
|
| 185 |
issues.append(normalized)
|
| 186 |
return issues
|
|
|
|
| 188 |
|
| 189 |
def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
|
| 190 |
"""Run a single task and return the best score."""
|
| 191 |
+
log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
rewards: List[float] = []
|
| 194 |
+
steps_taken = 0
|
| 195 |
best_score = 0.0
|
| 196 |
+
success = False
|
| 197 |
+
|
| 198 |
+
try:
|
| 199 |
+
# Reset environment for this task
|
| 200 |
+
reset_response = env.reset(task_id=task_id)
|
| 201 |
+
observation = reset_response.get("observation", reset_response)
|
| 202 |
+
|
| 203 |
+
messages = [{"role": "system", "content": SYSTEM_PROMPT}]
|
| 204 |
+
|
| 205 |
+
for step_num in range(1, MAX_STEPS_PER_TASK + 1):
|
| 206 |
+
user_prompt = build_user_prompt(observation)
|
| 207 |
+
messages_for_call = messages + [{"role": "user", "content": user_prompt}]
|
| 208 |
+
|
| 209 |
+
# Call LLM with retry on rate limit
|
| 210 |
+
llm_output = ""
|
| 211 |
+
error_msg = None
|
| 212 |
+
for attempt in range(3):
|
| 213 |
+
try:
|
| 214 |
+
response = client.chat.completions.create(
|
| 215 |
+
model=MODEL_NAME,
|
| 216 |
+
messages=messages_for_call,
|
| 217 |
+
temperature=0.1,
|
| 218 |
+
max_tokens=2048,
|
| 219 |
+
)
|
| 220 |
+
llm_output = response.choices[0].message.content or ""
|
| 221 |
break
|
| 222 |
+
except Exception as e:
|
| 223 |
+
if "rate_limit" in str(e).lower() or "429" in str(e):
|
| 224 |
+
wait = 10 * (attempt + 1)
|
| 225 |
+
print(f"[DEBUG] Rate limited, waiting {wait}s...", file=sys.stderr, flush=True)
|
| 226 |
+
time.sleep(wait)
|
| 227 |
+
else:
|
| 228 |
+
error_msg = str(e)
|
| 229 |
+
print(f"[DEBUG] LLM call failed: {e}", file=sys.stderr, flush=True)
|
| 230 |
+
break
|
| 231 |
+
|
| 232 |
+
# Parse issues from LLM response
|
| 233 |
+
issues = parse_llm_response(llm_output)
|
| 234 |
+
action_str = ";".join(issues) if issues else "none"
|
| 235 |
+
|
| 236 |
+
if not issues and not error_msg:
|
| 237 |
+
error_msg = "no issues parsed from LLM response"
|
| 238 |
+
|
| 239 |
+
# Submit to environment
|
| 240 |
+
step_response = env.step(issues, task_id=task_id)
|
| 241 |
+
observation = step_response.get("observation", step_response)
|
| 242 |
+
|
| 243 |
+
reward = float(step_response.get("reward", 0.0) or 0.0)
|
| 244 |
+
done = bool(step_response.get("done", False))
|
| 245 |
+
best_score = max(best_score, reward)
|
| 246 |
+
rewards.append(reward)
|
| 247 |
+
steps_taken = step_num
|
| 248 |
+
|
| 249 |
+
log_step(
|
| 250 |
+
step=step_num,
|
| 251 |
+
action=action_str,
|
| 252 |
+
reward=reward,
|
| 253 |
+
done=done,
|
| 254 |
+
error=error_msg,
|
| 255 |
+
)
|
| 256 |
|
| 257 |
+
if done:
|
| 258 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
|
| 260 |
+
# Add context for next attempt
|
| 261 |
+
messages.append({"role": "user", "content": user_prompt})
|
| 262 |
+
messages.append({"role": "assistant", "content": llm_output})
|
|
|
|
| 263 |
|
| 264 |
+
success = best_score >= 0.5
|
|
|
|
| 265 |
|
| 266 |
+
finally:
|
| 267 |
+
log_end(success=success, steps=steps_taken, score=best_score, rewards=rewards)
|
|
|
|
| 268 |
|
|
|
|
| 269 |
return best_score
|
| 270 |
|
| 271 |
|
|
|
|
| 274 |
# ---------------------------------------------------------------------------
|
| 275 |
|
| 276 |
def main():
|
| 277 |
+
print(f"[DEBUG] DataQA Inference starting", file=sys.stderr, flush=True)
|
| 278 |
+
print(f"[DEBUG] ENV_URL={ENV_URL}", file=sys.stderr, flush=True)
|
| 279 |
+
print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", file=sys.stderr, flush=True)
|
| 280 |
+
print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", file=sys.stderr, flush=True)
|
| 281 |
|
| 282 |
# Initialize clients
|
| 283 |
env = EnvHTTPClient(ENV_URL)
|
| 284 |
llm_client = OpenAI(
|
| 285 |
base_url=API_BASE_URL,
|
| 286 |
+
api_key=API_KEY or "no-key",
|
| 287 |
)
|
| 288 |
|
| 289 |
# Check environment health
|
| 290 |
if not env.health():
|
| 291 |
+
print("[DEBUG] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
|
| 292 |
sys.exit(1)
|
| 293 |
|
| 294 |
+
print(f"[DEBUG] Environment is healthy", file=sys.stderr, flush=True)
|
| 295 |
|
| 296 |
# Run all tasks
|
| 297 |
scores = {}
|
| 298 |
for task_id in TASKS:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
try:
|
| 300 |
score = run_task(llm_client, env, task_id)
|
| 301 |
scores[task_id] = score
|
|
|
|
| 302 |
except Exception as e:
|
| 303 |
+
print(f"[DEBUG] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
|
| 304 |
scores[task_id] = 0.0
|
| 305 |
|
| 306 |
+
# Summary to stderr (stdout is reserved for structured logs only)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
avg_score = sum(scores.values()) / len(scores) if scores else 0.0
|
| 308 |
+
print(f"\n[DEBUG] FINAL RESULTS: {scores} avg={avg_score:.3f}", file=sys.stderr, flush=True)
|
| 309 |
|
| 310 |
|
| 311 |
if __name__ == "__main__":
|
prevalidation_script.sh
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
#
|
| 3 |
+
# validate-submission.sh — OpenEnv Submission Validator
|
| 4 |
+
#
|
| 5 |
+
# Checks that your HF Space is live, Docker image builds, and openenv validate passes.
|
| 6 |
+
#
|
| 7 |
+
# Prerequisites:
|
| 8 |
+
# - Docker: https://docs.docker.com/get-docker/
|
| 9 |
+
# - openenv-core: pip install openenv-core
|
| 10 |
+
# - curl (usually pre-installed)
|
| 11 |
+
#
|
| 12 |
+
# Run:
|
| 13 |
+
# curl -fsSL https://raw.githubusercontent.com/<owner>/<repo>/main/scripts/validate-submission.sh | bash -s -- <ping_url> [repo_dir]
|
| 14 |
+
#
|
| 15 |
+
# Or download and run locally:
|
| 16 |
+
# chmod +x validate-submission.sh
|
| 17 |
+
# ./validate-submission.sh <ping_url> [repo_dir]
|
| 18 |
+
#
|
| 19 |
+
# Arguments:
|
| 20 |
+
# ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)
|
| 21 |
+
# repo_dir Path to your repo (default: current directory)
|
| 22 |
+
#
|
| 23 |
+
# Examples:
|
| 24 |
+
# ./validate-submission.sh https://my-team.hf.space
|
| 25 |
+
# ./validate-submission.sh https://my-team.hf.space ./my-repo
|
| 26 |
+
#
|
| 27 |
+
|
| 28 |
+
set -uo pipefail
|
| 29 |
+
|
| 30 |
+
DOCKER_BUILD_TIMEOUT=600
|
| 31 |
+
if [ -t 1 ]; then
|
| 32 |
+
RED='\033[0;31m'
|
| 33 |
+
GREEN='\033[0;32m'
|
| 34 |
+
YELLOW='\033[1;33m'
|
| 35 |
+
BOLD='\033[1m'
|
| 36 |
+
NC='\033[0m'
|
| 37 |
+
else
|
| 38 |
+
RED='' GREEN='' YELLOW='' BOLD='' NC=''
|
| 39 |
+
fi
|
| 40 |
+
|
| 41 |
+
run_with_timeout() {
|
| 42 |
+
local secs="$1"; shift
|
| 43 |
+
if command -v timeout &>/dev/null; then
|
| 44 |
+
timeout "$secs" "$@"
|
| 45 |
+
elif command -v gtimeout &>/dev/null; then
|
| 46 |
+
gtimeout "$secs" "$@"
|
| 47 |
+
else
|
| 48 |
+
"$@" &
|
| 49 |
+
local pid=$!
|
| 50 |
+
( sleep "$secs" && kill "$pid" 2>/dev/null ) &
|
| 51 |
+
local watcher=$!
|
| 52 |
+
wait "$pid" 2>/dev/null
|
| 53 |
+
local rc=$?
|
| 54 |
+
kill "$watcher" 2>/dev/null
|
| 55 |
+
wait "$watcher" 2>/dev/null
|
| 56 |
+
return $rc
|
| 57 |
+
fi
|
| 58 |
+
}
|
| 59 |
+
|
| 60 |
+
portable_mktemp() {
|
| 61 |
+
local prefix="${1:-validate}"
|
| 62 |
+
mktemp "${TMPDIR:-/tmp}/${prefix}-XXXXXX" 2>/dev/null || mktemp
|
| 63 |
+
}
|
| 64 |
+
|
| 65 |
+
CLEANUP_FILES=()
|
| 66 |
+
cleanup() { rm -f "${CLEANUP_FILES[@]+"${CLEANUP_FILES[@]}"}"; }
|
| 67 |
+
trap cleanup EXIT
|
| 68 |
+
|
| 69 |
+
# --- Argument parsing ----------------------------------------------------
PING_URL="${1:-}"
REPO_DIR="${2:-.}"

if [ -z "$PING_URL" ]; then
    printf "Usage: %s <ping_url> [repo_dir]\n" "$0"
    printf "\n"
    printf " ping_url Your HuggingFace Space URL (e.g. https://your-space.hf.space)\n"
    printf " repo_dir Path to your repo (default: current directory)\n"
    exit 1
fi

# Resolve repo_dir to an absolute path; this doubles as an existence check.
if ! REPO_DIR="$(cd "$REPO_DIR" 2>/dev/null && pwd)"; then
    printf "Error: directory '%s' not found\n" "${2:-.}"
    exit 1
fi
# Strip a trailing slash so "$PING_URL/reset" stays well-formed.
PING_URL="${PING_URL%/}"
export PING_URL
PASS=0
|
| 87 |
+
|
| 88 |
+
# Timestamped logging helpers. pass() also bumps the global PASS counter;
# stop_at() aborts the whole run with a red banner.
log() {
    printf "[%s] %b\n" "$(date -u +%H:%M:%S)" "$*"
}

pass() {
    log "${GREEN}PASSED${NC} -- $1"
    PASS=$((PASS + 1))
}

fail() {
    log "${RED}FAILED${NC} -- $1"
}

hint() {
    printf " ${YELLOW}Hint:${NC} %b\n" "$1"
}

stop_at() {
    printf "\n"
    printf "${RED}${BOLD}Validation stopped at %s.${NC} Fix the above before continuing.\n" "$1"
    exit 1
}
|
| 97 |
+
|
| 98 |
+
printf "\n"
|
| 99 |
+
printf "${BOLD}========================================${NC}\n"
|
| 100 |
+
printf "${BOLD} OpenEnv Submission Validator${NC}\n"
|
| 101 |
+
printf "${BOLD}========================================${NC}\n"
|
| 102 |
+
log "Repo: $REPO_DIR"
|
| 103 |
+
log "Ping URL: $PING_URL"
|
| 104 |
+
printf "\n"
|
| 105 |
+
|
| 106 |
+
log "${BOLD}Step 1/3: Pinging HF Space${NC} ($PING_URL/reset) ..."
|
| 107 |
+
|
| 108 |
+
CURL_OUTPUT=$(portable_mktemp "validate-curl")
|
| 109 |
+
CLEANUP_FILES+=("$CURL_OUTPUT")
|
| 110 |
+
HTTP_CODE=$(curl -s -o "$CURL_OUTPUT" -w "%{http_code}" -X POST \
|
| 111 |
+
-H "Content-Type: application/json" -d '{}' \
|
| 112 |
+
"$PING_URL/reset" --max-time 30 2>"$CURL_OUTPUT" || printf "000")
|
| 113 |
+
|
| 114 |
+
if [ "$HTTP_CODE" = "200" ]; then
|
| 115 |
+
pass "HF Space is live and responds to /reset"
|
| 116 |
+
elif [ "$HTTP_CODE" = "000" ]; then
|
| 117 |
+
fail "HF Space not reachable (connection failed or timed out)"
|
| 118 |
+
hint "Check your network connection and that the Space is running."
|
| 119 |
+
hint "Try: curl -s -o /dev/null -w '%%{http_code}' -X POST $PING_URL/reset"
|
| 120 |
+
stop_at "Step 1"
|
| 121 |
+
else
|
| 122 |
+
fail "HF Space /reset returned HTTP $HTTP_CODE (expected 200)"
|
| 123 |
+
hint "Make sure your Space is running and the URL is correct."
|
| 124 |
+
hint "Try opening $PING_URL in your browser first."
|
| 125 |
+
stop_at "Step 1"
|
| 126 |
+
fi
|
| 127 |
+
|
| 128 |
+
log "${BOLD}Step 2/3: Running docker build${NC} ..."
|
| 129 |
+
|
| 130 |
+
if ! command -v docker &>/dev/null; then
|
| 131 |
+
fail "docker command not found"
|
| 132 |
+
hint "Install Docker: https://docs.docker.com/get-docker/"
|
| 133 |
+
stop_at "Step 2"
|
| 134 |
+
fi
|
| 135 |
+
|
| 136 |
+
if [ -f "$REPO_DIR/Dockerfile" ]; then
|
| 137 |
+
DOCKER_CONTEXT="$REPO_DIR"
|
| 138 |
+
elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
|
| 139 |
+
DOCKER_CONTEXT="$REPO_DIR/server"
|
| 140 |
+
else
|
| 141 |
+
fail "No Dockerfile found in repo root or server/ directory"
|
| 142 |
+
stop_at "Step 2"
|
| 143 |
+
fi
|
| 144 |
+
|
| 145 |
+
log " Found Dockerfile in $DOCKER_CONTEXT"
|
| 146 |
+
|
| 147 |
+
BUILD_OK=false
|
| 148 |
+
BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true
|
| 149 |
+
|
| 150 |
+
if [ "$BUILD_OK" = true ]; then
|
| 151 |
+
pass "Docker build succeeded"
|
| 152 |
+
else
|
| 153 |
+
fail "Docker build failed (timeout=${DOCKER_BUILD_TIMEOUT}s)"
|
| 154 |
+
printf "%s\n" "$BUILD_OUTPUT" | tail -20
|
| 155 |
+
stop_at "Step 2"
|
| 156 |
+
fi
|
| 157 |
+
|
| 158 |
+
log "${BOLD}Step 3/3: Running openenv validate${NC} ..."
|
| 159 |
+
|
| 160 |
+
if ! command -v openenv &>/dev/null; then
|
| 161 |
+
fail "openenv command not found"
|
| 162 |
+
hint "Install it: pip install openenv-core"
|
| 163 |
+
stop_at "Step 3"
|
| 164 |
+
fi
|
| 165 |
+
|
| 166 |
+
VALIDATE_OK=false
|
| 167 |
+
VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true
|
| 168 |
+
|
| 169 |
+
if [ "$VALIDATE_OK" = true ]; then
|
| 170 |
+
pass "openenv validate passed"
|
| 171 |
+
[ -n "$VALIDATE_OUTPUT" ] && log " $VALIDATE_OUTPUT"
|
| 172 |
+
else
|
| 173 |
+
fail "openenv validate failed"
|
| 174 |
+
printf "%s\n" "$VALIDATE_OUTPUT"
|
| 175 |
+
stop_at "Step 3"
|
| 176 |
+
fi
|
| 177 |
+
|
| 178 |
+
printf "\n"
|
| 179 |
+
printf "${BOLD}========================================${NC}\n"
|
| 180 |
+
printf "${GREEN}${BOLD} All 3/3 checks passed!${NC}\n"
|
| 181 |
+
printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
|
| 182 |
+
printf "${BOLD}========================================${NC}\n"
|
| 183 |
+
printf "\n"
|
| 184 |
+
|
| 185 |
+
exit 0
|
sample_inference_script.py
ADDED
|
@@ -0,0 +1,188 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Inference Script Example
|
| 3 |
+
===================================
|
| 4 |
+
MANDATORY
|
| 5 |
+
- Before submitting, ensure the following variables are defined in your environment configuration:
|
| 6 |
+
API_BASE_URL The API endpoint for the LLM.
|
| 7 |
+
MODEL_NAME The model identifier to use for inference.
|
| 8 |
+
HF_TOKEN Your Hugging Face / API key.
|
| 9 |
+
LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image()
|
| 10 |
+
method
|
| 11 |
+
|
| 12 |
+
- Defaults are set only for API_BASE_URL and MODEL_NAME
|
| 13 |
+
(and should reflect your active inference setup):
|
| 14 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "<your-active-endpoint>")
|
| 15 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "<your-active-model>")
|
| 16 |
+
|
| 17 |
+
- The inference script must be named `inference.py` and placed in the root directory of the project
|
| 18 |
+
- Participants must use OpenAI Client for all LLM calls using above variables
|
| 19 |
+
|
| 20 |
+
STDOUT FORMAT
|
| 21 |
+
- The script must emit exactly three line types to stdout, in this order:
|
| 22 |
+
|
| 23 |
+
[START] task=<task_name> env=<benchmark> model=<model_name>
|
| 24 |
+
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
|
| 25 |
+
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
|
| 26 |
+
|
| 27 |
+
Rules:
|
| 28 |
+
- One [START] line at episode begin.
|
| 29 |
+
- One [STEP] line per step, immediately after env.step() returns.
|
| 30 |
+
- One [END] line after env.close(), always emitted (even on exception).
|
| 31 |
+
- reward and rewards are formatted to 2 decimal places.
|
| 32 |
+
- done and success are lowercase booleans: true or false.
|
| 33 |
+
- error is the raw last_action_error string, or null if none.
|
| 34 |
+
- All fields on a single line with no newlines within a line.
|
| 35 |
+
- Each task should return a score in [0, 1]
|
| 36 |
+
|
| 37 |
+
Example:
|
| 38 |
+
[START] task=click-test env=miniwob model=Qwen3-VL-30B
|
| 39 |
+
[STEP] step=1 action=click('123') reward=0.00 done=false error=null
|
| 40 |
+
[STEP] step=2 action=fill('456','text') reward=0.00 done=false error=null
|
| 41 |
+
[STEP] step=3 action=click('789') reward=1.00 done=true error=null
|
| 42 |
+
[END] success=true steps=3 score=1.00 rewards=0.00,0.00,1.00
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
import asyncio
|
| 46 |
+
import os
|
| 47 |
+
import textwrap
|
| 48 |
+
from typing import List, Optional
|
| 49 |
+
|
| 50 |
+
from openai import OpenAI
|
| 51 |
+
|
| 52 |
+
from my_env_v4 import MyEnvV4Action, MyEnvV4Env
|
| 53 |
+
# --- Runtime configuration (all overridable via environment variables) ---
IMAGE_NAME = os.getenv("IMAGE_NAME")  # If you are using docker image
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")

API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "Qwen/Qwen2.5-72B-Instruct"
TASK_NAME = os.getenv("MY_ENV_V4_TASK", "echo")
BENCHMARK = os.getenv("MY_ENV_V4_BENCHMARK", "my_env_v4")
MAX_STEPS = 8  # episode length cap
TEMPERATURE = 0.7
MAX_TOKENS = 150
SUCCESS_SCORE_THRESHOLD = 0.1  # normalized score in [0, 1]

# Max possible reward: each token contributes 0.1, across all steps
# NOTE(review): MAX_TOKENS caps completion *tokens*, while SYSTEM_PROMPT
# says the env rewards len(message) * 0.1 (characters) — so this ceiling is
# only approximate; confirm before relying on exact score normalization.
_MAX_REWARD_PER_STEP = MAX_TOKENS * 0.1
MAX_TOTAL_REWARD = MAX_STEPS * _MAX_REWARD_PER_STEP

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are interacting with a simple echo environment.
    Each turn you must send a message. The environment will echo it back.
    Reward is proportional to message length: reward = len(message) * 0.1
    Your goal is to maximize total reward by sending meaningful, substantive messages.
    Reply with exactly one message string — no quotes, no prefixes, just the message text.
    """
).strip()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def log_start(task: str, env: str, model: str) -> None:
    """Emit the single mandatory [START] marker line at episode begin."""
    line = "[START] task={} env={} model={}".format(task, env, model)
    print(line, flush=True)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one mandatory [STEP] marker line right after env.step() returns."""
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error if error else 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the single mandatory [END] marker line after env.close()."""
    joined = ",".join("{:.2f}".format(value) for value in rewards)
    flag = str(success).lower()
    line = f"[END] success={flag} steps={steps} score={score:.3f} rewards={joined}"
    print(line, flush=True)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def build_user_prompt(step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
    """Build the user prompt for one step, showing the last echo/reward and
    up to the four most recent history entries.

    Fix: the previous implementation ran ``textwrap.dedent`` on an f-string
    *after* interpolation. When ``history`` held more than one entry, the
    inserted block contained column-0 lines, the common leading whitespace
    became empty, and the entire prompt kept its literal indentation.
    Joining explicit lines avoids that failure mode entirely.
    """
    history_block = "\n".join(history[-4:]) if history else "None"
    lines = [
        f"Step: {step}",
        f"Last echoed message: {last_echoed!r}",
        f"Last reward: {last_reward:.2f}",
        "Previous steps:",
        history_block,
        "Send your next message.",
    ]
    return "\n".join(lines)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def get_model_message(client: OpenAI, step: int, last_echoed: str, last_reward: float, history: List[str]) -> str:
    """Ask the LLM for the next message; fall back to "hello" if the request
    fails or the model returns an empty reply."""
    prompt = build_user_prompt(step, last_echoed, last_reward, history)
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": prompt},
    ]
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=chat_messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        reply = (completion.choices[0].message.content or "").strip()
        return reply if reply else "hello"
    except Exception as exc:  # network/auth/model errors — degrade gracefully
        print(f"[DEBUG] Model request failed: {exc}", flush=True)
        return "hello"
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
async def main() -> None:
    """Run one episode against the echo environment, emitting the mandatory
    [START]/[STEP]/[END] stdout protocol documented in the module docstring.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    env = await MyEnvV4Env.from_docker_image(IMAGE_NAME)

    history: List[str] = []    # compact per-step summaries fed back to the model
    rewards: List[float] = []  # per-step rewards, reported on the [END] line
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        result = await env.reset()  # OpenENV.reset()
        last_echoed = result.observation.echoed_message
        last_reward = 0.0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            message = get_model_message(client, step, last_echoed, last_reward, history)

            result = await env.step(MyEnvV4Action(message=message))
            obs = result.observation

            reward = result.reward or 0.0
            done = result.done
            error = None

            rewards.append(reward)
            steps_taken = step
            last_echoed = obs.echoed_message
            last_reward = reward

            # One [STEP] line per step, immediately after env.step() returns.
            log_step(step=step, action=message, reward=reward, done=done, error=error)

            history.append(f"Step {step}: {message!r} -> reward {reward:+.2f}")

            if done:
                break

        # Normalize the episode score to [0, 1] against the reward ceiling.
        score = sum(rewards) / MAX_TOTAL_REWARD if MAX_TOTAL_REWARD > 0 else 0.0
        score = min(max(score, 0.0), 1.0)  # clamp to [0, 1]
        success = score >= SUCCESS_SCORE_THRESHOLD

    finally:
        # Always release the container and always emit [END], even on error.
        try:
            await env.close()
        except Exception as e:
            print(f"[DEBUG] env.close() error (container cleanup): {e}", flush=True)
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


if __name__ == "__main__":
    asyncio.run(main())
|
tests/__init__.py
ADDED
|
File without changes
|
tests/test_environment.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the DataQA environment (reset, step, scoring)."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from dataqa_env.server.environment import (
|
| 5 |
+
DataQAEnvironment,
|
| 6 |
+
parse_issue_key,
|
| 7 |
+
compute_f1,
|
| 8 |
+
compute_weighted_reward,
|
| 9 |
+
)
|
| 10 |
+
from dataqa_env.models import DataQAAction
|
| 11 |
+
from dataqa_env.server.tasks import PlantedIssue
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class TestParseIssueKey:
    """parse_issue_key() should normalize varied reporting formats to the
    canonical ``row:<n>,col:<name>,issue:<type>`` key, or return None."""

    def test_standard_format(self):
        assert parse_issue_key("row:3,col:salary,issue:missing_value") == "row:3,col:salary,issue:missing_value"

    def test_with_equals(self):
        # '=' is accepted as the key/value separator.
        assert parse_issue_key("row=3,col=salary,issue=missing_value") == "row:3,col:salary,issue:missing_value"

    def test_case_insensitive(self):
        assert parse_issue_key("Row:3,Col:Salary,Issue:Missing_Value") == "row:3,col:salary,issue:missing_value"

    def test_with_spaces(self):
        assert parse_issue_key("row: 3, col: salary, issue: missing_value") == "row:3,col:salary,issue:missing_value"

    def test_unparseable(self):
        assert parse_issue_key("this is garbage") is None

    def test_partial_match(self):
        assert parse_issue_key("row:3,col:salary") is None  # missing issue

    def test_empty_string(self):
        assert parse_issue_key("") is None

    def test_semicolon_separator(self):
        # ';' is accepted as the field separator.
        result = parse_issue_key("row:3;col:salary;issue:missing_value")
        assert result == "row:3,col:salary,issue:missing_value"
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class TestComputeF1:
    """compute_f1() over sets of reported vs planted issue keys; the result
    dict exposes f1, precision, recall, tp, fp, fn."""

    def test_perfect_match(self):
        keys = {"row:1,col:a,issue:missing_value"}
        result = compute_f1(keys, keys)
        assert result["f1"] == 1.0
        assert result["tp"] == 1
        assert result["fp"] == 0
        assert result["fn"] == 0

    def test_no_reported_no_planted(self):
        # Vacuously perfect: nothing to find, nothing claimed.
        result = compute_f1(set(), set())
        assert result["f1"] == 1.0

    def test_no_reported_some_planted(self):
        planted = {"row:1,col:a,issue:missing_value"}
        result = compute_f1(set(), planted)
        assert result["f1"] == 0.0
        assert result["fn"] == 1

    def test_all_false_positives(self):
        reported = {"row:99,col:x,issue:wrong_type"}
        planted = {"row:1,col:a,issue:missing_value"}
        result = compute_f1(reported, planted)
        assert result["tp"] == 0
        assert result["fp"] == 1
        assert result["fn"] == 1
        assert result["f1"] == 0.0

    def test_partial_match(self):
        reported = {"row:1,col:a,issue:missing_value", "row:2,col:b,issue:wrong_type"}
        planted = {"row:1,col:a,issue:missing_value", "row:3,col:c,issue:duplicate_row"}
        result = compute_f1(reported, planted)
        assert result["tp"] == 1
        assert result["fp"] == 1
        assert result["fn"] == 1
        assert 0 < result["f1"] < 1

    def test_precision_recall_calculation(self):
        # 2 of 3 reported are correct; 2 of 3 planted are found.
        reported = {"a", "b", "c"}
        planted = {"a", "b", "d"}
        result = compute_f1(reported, planted)
        assert result["precision"] == pytest.approx(2 / 3)
        assert result["recall"] == pytest.approx(2 / 3)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TestComputeWeightedReward:
    """compute_weighted_reward(): difficulty-weighted credit for found
    issues, with false positives reducing the reward."""

    def test_perfect_match(self):
        issues = [
            PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0),
            PlantedIssue(row=2, col="b", issue_type="wrong_type", description="", difficulty=3.0),
        ]
        reported = {i.to_key() for i in issues}
        result = compute_weighted_reward(reported, issues)
        assert result["weighted_reward"] == 1.0

    def test_empty_both(self):
        # No planted issues and nothing reported is a perfect outcome.
        result = compute_weighted_reward(set(), [])
        assert result["weighted_reward"] == 1.0

    def test_no_reported(self):
        issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=2.0)]
        result = compute_weighted_reward(set(), issues)
        assert result["weighted_reward"] == 0.0
        assert result["difficulty_missed"] == 2.0

    def test_hard_issue_worth_more(self):
        easy = PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)
        hard = PlantedIssue(row=2, col="b", issue_type="statistical_outlier", description="", difficulty=3.0)
        issues = [easy, hard]

        # Finding only the hard issue should score higher than only the easy issue
        hard_found = compute_weighted_reward({hard.to_key()}, issues)
        easy_found = compute_weighted_reward({easy.to_key()}, issues)
        assert hard_found["weighted_reward"] > easy_found["weighted_reward"]

    def test_false_positives_reduce_reward(self):
        issues = [PlantedIssue(row=1, col="a", issue_type="missing_value", description="", difficulty=1.0)]
        correct = {issues[0].to_key()}
        with_fp = correct | {"row:99,col:x,issue:wrong_type"}
        r_correct = compute_weighted_reward(correct, issues)
        r_with_fp = compute_weighted_reward(with_fp, issues)
        assert r_correct["weighted_reward"] > r_with_fp["weighted_reward"]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class TestDataQAEnvironment:
    """End-to-end reset/step/scoring behavior of DataQAEnvironment.

    Built-in task sizes: easy=4 issues / 3 steps, medium=6, hard=8.
    """

    @pytest.fixture
    def env(self):
        # Fresh environment per test — no cross-test session state.
        return DataQAEnvironment()

    def test_reset_returns_observation(self, env):
        obs = env.reset(task_id="easy")
        assert obs.dataset_csv
        assert obs.schema_description
        assert obs.validation_rules
        assert obs.task_description
        assert obs.num_issues_hint == 4
        assert obs.max_steps == 3
        assert obs.done is False
        assert obs.reward == 0.0

    def test_reset_medium(self, env):
        obs = env.reset(task_id="medium")
        assert obs.num_issues_hint == 6

    def test_reset_hard(self, env):
        obs = env.reset(task_id="hard")
        assert obs.num_issues_hint == 8

    def test_step_with_correct_issues(self, env):
        env.reset(task_id="easy")
        # Submit all correct issues for easy task
        action = DataQAAction(
            issues=[
                "row:4,col:name,issue:missing_value",
                "row:7,col:salary,issue:wrong_type",
                "row:11,col:employee_id,issue:duplicate_row",
                "row:9,col:salary,issue:out_of_range",
            ],
            task_id="easy",
        )
        obs = env.step(action)
        assert obs.done is True
        assert obs.reward >= 0.999

    def test_step_with_partial_issues(self, env):
        env.reset(task_id="easy")
        action = DataQAAction(
            issues=["row:4,col:name,issue:missing_value"],
            task_id="easy",
        )
        obs = env.step(action)
        assert 0 < obs.reward < 1.0
        assert obs.done is False

    def test_step_with_no_issues(self, env):
        env.reset(task_id="easy")
        action = DataQAAction(issues=[], task_id="easy")
        obs = env.step(action)
        assert obs.reward == 0.0

    def test_step_exhausts_max_steps(self, env):
        # After max_steps (3) wrong submissions the episode must terminate.
        env.reset(task_id="easy")
        for _ in range(3):
            action = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
            obs = env.step(action)
        assert obs.done is True

    def test_auto_reset_on_step(self, env):
        # step() without prior reset should auto-reset
        action = DataQAAction(
            issues=["row:4,col:name,issue:missing_value"],
            task_id="easy",
        )
        obs = env.step(action)
        assert obs.task_id == "easy"

    def test_state_tracking(self, env):
        env.reset(task_id="easy")
        assert env.state.task_id == "easy"
        assert env.state.current_step == 0
        assert env.state.best_score == 0.0

        action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
        env.step(action)
        assert env.state.current_step == 1
        assert env.state.best_score > 0.0

    def test_best_score_monotonic(self, env):
        env.reset(task_id="easy")
        action1 = DataQAAction(
            issues=["row:4,col:name,issue:missing_value", "row:7,col:salary,issue:wrong_type"],
            task_id="easy",
        )
        env.step(action1)
        score_after_1 = env.state.best_score

        # Worse submission shouldn't decrease best_score
        action2 = DataQAAction(issues=["row:99,col:x,issue:wrong_type"], task_id="easy")
        env.step(action2)
        assert env.state.best_score >= score_after_1

    def test_metadata_included_in_observation(self, env):
        env.reset(task_id="easy")
        action = DataQAAction(issues=["row:4,col:name,issue:missing_value"], task_id="easy")
        obs = env.step(action)
        assert "f1" in obs.metadata
        assert "weighted_reward" in obs.metadata
        assert "tp" in obs.metadata
        assert "difficulty_found" in obs.metadata

    def test_parse_error_in_feedback(self, env):
        env.reset(task_id="easy")
        action = DataQAAction(issues=["garbage input"], task_id="easy")
        obs = env.step(action)
        assert "Parse error" in obs.feedback

    def test_concurrent_sessions_flag(self):
        assert DataQAEnvironment.SUPPORTS_CONCURRENT_SESSIONS is True

    def test_reward_between_0_and_1(self, env):
        """Hackathon requirement: scores must be 0.0-1.0."""
        env.reset(task_id="hard")
        for _ in range(3):
            action = DataQAAction(
                issues=["row:1,col:x,issue:wrong_type", "row:99,col:y,issue:missing_value"],
                task_id="hard",
            )
            obs = env.step(action)
            assert 0.0 <= obs.reward <= 1.0
|
tests/test_extensibility.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the extensibility API — custom tasks and contamination rules."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from dataqa_env.server.tasks import (
|
| 5 |
+
PlantedIssue,
|
| 6 |
+
create_task_from_config,
|
| 7 |
+
register_task,
|
| 8 |
+
register_contamination_rule,
|
| 9 |
+
CONTAMINATION_RULES,
|
| 10 |
+
get_task,
|
| 11 |
+
list_tasks,
|
| 12 |
+
)
|
| 13 |
+
from dataqa_env.server.environment import DataQAEnvironment, compute_weighted_reward
|
| 14 |
+
from dataqa_env.models import DataQAAction
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Minimal 4-row "clean" dataset used as the base for contamination tests.
SIMPLE_CSV = "id,name,score\n1,Alice,95\n2,Bob,87\n3,Carol,92\n4,Dave,78"
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TestCreateTaskFromConfig:
    """create_task_from_config(): declarative task construction from a clean
    CSV plus a list of contamination specs (rule name or callable)."""

    def test_basic_creation(self):
        task = create_task_from_config(
            task_id="test_custom",
            name="Test Task",
            description="Test",
            schema_description="id: int, name: str, score: int",
            validation_rules="No missing values",
            clean_csv=SIMPLE_CSV,
            contaminations=[
                {"rule": "missing_value", "row": 0, "col": 1},
            ],
        )
        assert task.task_id == "test_custom"
        assert len(task.planted_issues) == 1
        assert task.planted_issues[0].issue_type == "missing_value"
        # col index 1 resolves to the "name" column of SIMPLE_CSV.
        assert task.planted_issues[0].col == "name"

    def test_multiple_contaminations(self):
        task = create_task_from_config(
            task_id="multi",
            name="Multi",
            description="Test",
            schema_description="",
            validation_rules="",
            clean_csv=SIMPLE_CSV,
            contaminations=[
                {"rule": "missing_value", "row": 0, "col": 1},
                {"rule": "missing_value", "row": 2, "col": 1},
            ],
        )
        assert len(task.planted_issues) == 2

    def test_custom_difficulty_override(self):
        # A per-contamination "difficulty" overrides the rule's default.
        task = create_task_from_config(
            task_id="custom_diff",
            name="Custom Difficulty",
            description="Test",
            schema_description="",
            validation_rules="",
            clean_csv=SIMPLE_CSV,
            contaminations=[
                {"rule": "missing_value", "row": 0, "col": 1, "difficulty": 2.5},
            ],
        )
        assert task.planted_issues[0].difficulty == 2.5

    def test_callable_rule(self):
        # A contamination "rule" may be a callable instead of a registry name.
        def custom_rule(rows, header, col_idx, row_idx, rng):
            return "CORRUPTED", PlantedIssue(
                row=row_idx + 1, col=header[col_idx], issue_type="wrong_type",
                description="Custom corruption", difficulty=1.5,
            )

        task = create_task_from_config(
            task_id="callable",
            name="Callable Rule",
            description="Test",
            schema_description="",
            validation_rules="",
            clean_csv=SIMPLE_CSV,
            contaminations=[
                {"rule": custom_rule, "row": 1, "col": 2},
            ],
        )
        assert task.planted_issues[0].issue_type == "wrong_type"
        assert "CORRUPTED" in task.corrupted_csv

    def test_unknown_rule_raises(self):
        with pytest.raises(ValueError, match="Unknown contamination rule"):
            create_task_from_config(
                task_id="bad",
                name="Bad",
                description="",
                schema_description="",
                validation_rules="",
                clean_csv=SIMPLE_CSV,
                contaminations=[{"rule": "nonexistent_rule", "row": 0, "col": 0}],
            )
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class TestRegisterContaminationRule:
    """register_contamination_rule() should make a custom rule usable by
    name in create_task_from_config()."""

    def test_register_and_use(self):
        def reverse_value(rows, header, col_idx, row_idx, rng):
            # Corrupt a cell by reversing its string value.
            val = rows[row_idx][col_idx]
            return val[::-1], PlantedIssue(
                row=row_idx + 1, col=header[col_idx], issue_type="format_violation",
                description="Reversed value", difficulty=1.5,
            )

        register_contamination_rule("reverse", reverse_value)
        # Fix: unregister inside `finally` so a failing assertion cannot leak
        # the "reverse" rule into the global registry and poison other tests.
        try:
            assert "reverse" in CONTAMINATION_RULES

            task = create_task_from_config(
                task_id="rev_test",
                name="Reverse Test",
                description="",
                schema_description="",
                validation_rules="",
                clean_csv=SIMPLE_CSV,
                contaminations=[{"rule": "reverse", "row": 0, "col": 1}],
            )
            assert task.planted_issues[0].issue_type == "format_violation"
            # "Alice" reversed is "ecilA"
            assert "ecilA" in task.corrupted_csv
        finally:
            # Cleanup
            del CONTAMINATION_RULES["reverse"]
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class TestRegisterTask:
    """register_task() should install a factory visible via list_tasks()
    and retrievable via get_task()."""

    def test_register_and_get(self):
        task = create_task_from_config(
            task_id="registered",
            name="Registered Task",
            description="Test registered task",
            schema_description="id: int, name: str",
            validation_rules="No missing values",
            clean_csv=SIMPLE_CSV,
            contaminations=[{"rule": "missing_value", "row": 1, "col": 1}],
        )
        register_task("registered", lambda seed: task)
        # Fix: unregister inside `finally` so a failing assertion cannot
        # leave the "registered" entry behind and affect other tests.
        try:
            assert "registered" in list_tasks()

            fetched = get_task("registered")
            assert fetched.task_id == "registered"
            assert len(fetched.planted_issues) == 1
        finally:
            # Cleanup
            from dataqa_env.server.tasks import TASK_REGISTRY
            del TASK_REGISTRY["registered"]
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class TestCustomTaskInEnvironment:
|
| 154 |
+
def test_full_lifecycle(self):
|
| 155 |
+
"""Custom task works end-to-end in the environment."""
|
| 156 |
+
task = create_task_from_config(
|
| 157 |
+
task_id="e2e_custom",
|
| 158 |
+
name="E2E Custom",
|
| 159 |
+
description="End-to-end test",
|
| 160 |
+
schema_description="id: int, name: str, score: int",
|
| 161 |
+
validation_rules="No missing values",
|
| 162 |
+
clean_csv=SIMPLE_CSV,
|
| 163 |
+
contaminations=[
|
| 164 |
+
{"rule": "missing_value", "row": 0, "col": 1, "difficulty": 1.0},
|
| 165 |
+
{"rule": "whitespace_value", "row": 2, "col": 1, "difficulty": 2.5},
|
| 166 |
+
],
|
| 167 |
+
)
|
| 168 |
+
register_task("e2e_custom", lambda seed: task)
|
| 169 |
+
|
| 170 |
+
env = DataQAEnvironment()
|
| 171 |
+
obs = env.reset(task_id="e2e_custom")
|
| 172 |
+
assert obs.num_issues_hint == 2
|
| 173 |
+
|
| 174 |
+
# Submit correct answers
|
| 175 |
+
action = DataQAAction(
|
| 176 |
+
issues=[i.to_key() for i in task.planted_issues],
|
| 177 |
+
task_id="e2e_custom",
|
| 178 |
+
)
|
| 179 |
+
obs = env.step(action)
|
| 180 |
+
assert obs.done is True
|
| 181 |
+
assert obs.reward >= 0.999
|
| 182 |
+
|
| 183 |
+
# Cleanup
|
| 184 |
+
from dataqa_env.server.tasks import TASK_REGISTRY
|
| 185 |
+
del TASK_REGISTRY["e2e_custom"]
|
tests/test_inference.py
ADDED
|
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for the inference script's parsing, prompt building, and log format."""
|
| 2 |
+
|
| 3 |
+
import re
|
| 4 |
+
import pytest
|
| 5 |
+
import sys
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
|
| 9 |
+
from inference import parse_llm_response, build_user_prompt, log_start, log_step, log_end
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TestParseLLMResponse:
    """parse_llm_response should tolerate the answer formats LLMs commonly emit."""

    def test_standard_format(self):
        raw = "row:1,col:name,issue:missing_value\nrow:2,col:salary,issue:wrong_type"
        parsed = parse_llm_response(raw)
        assert len(parsed) == 2
        assert "row:1,col:name,issue:missing_value" in parsed

    def test_numbered_list(self):
        raw = "1. row:1,col:name,issue:missing_value\n2. row:2,col:salary,issue:wrong_type"
        assert len(parse_llm_response(raw)) == 2

    def test_bullet_list(self):
        raw = "- row:1,col:name,issue:missing_value\n* row:2,col:salary,issue:wrong_type"
        assert len(parse_llm_response(raw)) == 2

    def test_equals_delimiter(self):
        # "=" delimiters are normalized to the canonical ":" form.
        parsed = parse_llm_response("row=1,col=name,issue=missing_value")
        assert parsed == ["row:1,col:name,issue:missing_value"]

    def test_mixed_case(self):
        # Case is normalized to lowercase.
        parsed = parse_llm_response("Row:1,Col:Name,Issue:Missing_Value")
        assert parsed == ["row:1,col:name,issue:missing_value"]

    def test_empty_response(self):
        assert parse_llm_response("") == []
        assert parse_llm_response(" ") == []

    def test_garbage_lines_skipped(self):
        raw = "Here are the issues:\nrow:1,col:name,issue:missing_value\nNo more issues."
        assert len(parse_llm_response(raw)) == 1

    def test_deduplication_not_applied(self):
        line = "row:1,col:name,issue:missing_value"
        parsed = parse_llm_response(line + "\n" + line)
        assert len(parsed) == 2  # duplicates kept, env handles dedup

    def test_with_column_variant(self):
        assert len(parse_llm_response("row:1,column:name,issue:missing_value")) == 1
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TestBuildUserPrompt:
    """build_user_prompt must surface every observation field the agent needs."""

    @staticmethod
    def _obs(**overrides):
        # Build a minimal observation dict; override only the fields under test.
        base = {
            "task_description": "",
            "schema_description": "",
            "validation_rules": "",
            "dataset_csv": "",
            "num_issues_hint": 0,
            "feedback": "",
        }
        base.update(overrides)
        return base

    def test_includes_all_fields(self):
        prompt = build_user_prompt(self._obs(
            task_description="Find issues",
            schema_description="col: int",
            validation_rules="no nulls",
            dataset_csv="a,b\n1,2",
            num_issues_hint=3,
        ))
        for fragment in ("Find issues", "col: int", "no nulls", "a,b", "3 issues"):
            assert fragment in prompt

    def test_includes_feedback_on_retry(self):
        prompt = build_user_prompt(self._obs(
            task_description="Find issues",
            dataset_csv="a\n1",
            feedback="Step 1/3: You missed 2 issues",
        ))
        assert "FEEDBACK" in prompt
        assert "missed 2" in prompt

    def test_excludes_reset_feedback(self):
        # The boilerplate reset message must not be echoed as feedback.
        prompt = build_user_prompt(self._obs(feedback="Environment reset. Start inspecting."))
        assert "FEEDBACK" not in prompt
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class TestLogFormat:
    """Verify stdout log format matches hackathon evaluation requirements."""

    @staticmethod
    def _captured(capsys):
        # Single place to grab what the log helpers printed.
        return capsys.readouterr().out

    def test_log_start_format(self, capsys):
        log_start(task="easy", env="dataqa_env", model="test-model")
        line = self._captured(capsys).strip()
        assert line == "[START] task=easy env=dataqa_env model=test-model"

    def test_log_step_format(self, capsys):
        log_step(step=1, action="row:1,col:name,issue:missing_value", reward=0.50, done=False, error=None)
        line = self._captured(capsys).strip()
        assert line == "[STEP] step=1 action=row:1,col:name,issue:missing_value reward=0.50 done=false error=null"

    def test_log_step_with_error(self, capsys):
        log_step(step=2, action="none", reward=0.00, done=True, error="timeout")
        line = self._captured(capsys).strip()
        assert "error=timeout" in line
        assert "done=true" in line

    def test_log_end_format(self, capsys):
        log_end(success=True, steps=3, score=0.85, rewards=[0.25, 0.50, 0.85])
        line = self._captured(capsys).strip()
        assert line == "[END] success=true steps=3 score=0.850 rewards=0.25,0.50,0.85"

    def test_log_end_failure(self, capsys):
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        line = self._captured(capsys).strip()
        assert "success=false" in line
        assert "score=0.000" in line

    def test_reward_format_2_decimal(self, capsys):
        log_step(step=1, action="test", reward=0.123456, done=False, error=None)
        assert "reward=0.12" in self._captured(capsys).strip()

    def test_no_newlines_within_line(self, capsys):
        log_start(task="easy", env="dataqa_env", model="model")
        log_step(step=1, action="act", reward=0.0, done=False, error=None)
        log_end(success=False, steps=1, score=0.0, rewards=[0.0])
        lines = [ln for ln in self._captured(capsys).split("\n") if ln.strip()]
        assert len(lines) == 3
        for prefix, line in zip(("[START]", "[STEP]", "[END]"), lines):
            assert line.startswith(prefix)
|
tests/test_tasks.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Tests for task definitions, data corruption, and issue planting."""
|
| 2 |
+
|
| 3 |
+
import pytest
|
| 4 |
+
from dataqa_env.server.tasks import (
|
| 5 |
+
PlantedIssue,
|
| 6 |
+
Task,
|
| 7 |
+
create_task_easy,
|
| 8 |
+
create_task_medium,
|
| 9 |
+
create_task_hard,
|
| 10 |
+
get_task,
|
| 11 |
+
list_tasks,
|
| 12 |
+
_csv_to_rows,
|
| 13 |
+
_rows_to_csv,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class TestPlantedIssue:
    """PlantedIssue key formatting and difficulty defaults."""

    def test_to_key(self):
        key = PlantedIssue(row=3, col="salary", issue_type="missing_value", description="test").to_key()
        assert key == "row:3,col:salary,issue:missing_value"

    def test_difficulty_default(self):
        # difficulty defaults to 1.0 when not supplied.
        default_issue = PlantedIssue(row=1, col="name", issue_type="missing_value", description="test")
        assert default_issue.difficulty == 1.0

    def test_difficulty_custom(self):
        custom_issue = PlantedIssue(
            row=1, col="name", issue_type="missing_value", description="test", difficulty=3.0
        )
        assert custom_issue.difficulty == 3.0
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class TestCSVHelpers:
    """Round-trip behavior of the internal CSV <-> row-list helpers."""

    def test_roundtrip(self):
        parsed = _csv_to_rows("a,b,c\n1,2,3\n4,5,6")
        assert len(parsed) == 3
        assert "1,2,3" in _rows_to_csv(parsed)

    def test_empty_csv(self):
        # A header with no data rows parses to exactly one row (the header).
        assert len(_csv_to_rows("a,b\n")) == 1
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class TestTaskEasy:
    """Shape and contents of the built-in 'easy' task."""

    @pytest.fixture
    def task(self):
        return create_task_easy()

    def test_task_id(self, task):
        assert task.task_id == "easy"

    def test_has_4_issues(self, task):
        assert len(task.planted_issues) == 4

    def test_issue_types(self, task):
        present = {issue.issue_type for issue in task.planted_issues}
        assert {"missing_value", "wrong_type", "duplicate_row", "out_of_range"} <= present

    def test_corrupted_csv_differs_from_clean(self, task):
        assert task.corrupted_csv != task.clean_csv

    def test_issue_keys_unique(self, task):
        keys = [issue.to_key() for issue in task.planted_issues]
        assert len(set(keys)) == len(keys)

    def test_max_steps(self, task):
        assert task.max_steps == 3

    def test_corrupted_csv_has_more_rows(self, task):
        # The planted duplicate row makes the corrupted table longer.
        assert len(_csv_to_rows(task.corrupted_csv)) > len(_csv_to_rows(task.clean_csv))

    def test_difficulty_weights(self, task):
        assert all(1.0 <= issue.difficulty <= 3.0 for issue in task.planted_issues)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
class TestTaskMedium:
    """Shape and contents of the built-in 'medium' task."""

    @pytest.fixture
    def task(self):
        return create_task_medium()

    def test_task_id(self, task):
        assert task.task_id == "medium"

    def test_has_6_issues(self, task):
        assert len(task.planted_issues) == 6

    def test_issue_types(self, task):
        present = {issue.issue_type for issue in task.planted_issues}
        assert {"inconsistent_value", "format_violation", "missing_value"} <= present

    def test_issue_keys_unique(self, task):
        keys = [issue.to_key() for issue in task.planted_issues]
        assert len(set(keys)) == len(keys)

    def test_difficulty_weights(self, task):
        assert all(1.0 <= issue.difficulty <= 3.0 for issue in task.planted_issues)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
class TestTaskHard:
    """Shape and contents of the built-in 'hard' task."""

    @pytest.fixture
    def task(self):
        return create_task_hard()

    def test_task_id(self, task):
        assert task.task_id == "hard"

    def test_has_8_issues(self, task):
        assert len(task.planted_issues) == 8

    def test_issue_types(self, task):
        present = {issue.issue_type for issue in task.planted_issues}
        required = {
            "inconsistent_value",
            "format_violation",
            "statistical_outlier",
            "out_of_range",
            "missing_value",
        }
        assert required <= present

    def test_has_high_difficulty_issues(self, task):
        # data leakage, GPU outlier, whitespace
        hard_count = sum(1 for issue in task.planted_issues if issue.difficulty >= 2.5)
        assert hard_count >= 2

    def test_issue_keys_unique(self, task):
        keys = [issue.to_key() for issue in task.planted_issues]
        assert len(set(keys)) == len(keys)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class TestTaskRegistry:
    """Built-in task registry lookups and seeded determinism."""

    def test_list_tasks(self):
        assert set(list_tasks()) == {"easy", "medium", "hard"}

    def test_get_task_easy(self):
        assert get_task("easy").task_id == "easy"

    def test_get_task_medium(self):
        assert get_task("medium").task_id == "medium"

    def test_get_task_hard(self):
        assert get_task("hard").task_id == "hard"

    def test_get_task_unknown_raises(self):
        with pytest.raises(ValueError, match="Unknown task"):
            get_task("nonexistent")

    def test_seed_determinism(self):
        # The same seed must yield byte-identical corruption and issue keys.
        first = get_task("easy", seed=42)
        second = get_task("easy", seed=42)
        assert first.corrupted_csv == second.corrupted_csv
        assert [i.to_key() for i in first.planted_issues] == [i.to_key() for i in second.planted_issues]
|