varb15 committed on
Commit 0c216ef · verified · 1 Parent(s): 8ced877

Upload folder using huggingface_hub
Dockerfile ADDED
@@ -0,0 +1,36 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install system deps
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    git curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install uv for fast dependency management
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+    mv /root/.local/bin/uv /usr/local/bin/uv && \
+    mv /root/.local/bin/uvx /usr/local/bin/uvx
+
+# Copy project files
+COPY pyproject.toml /app/
+COPY openenv.yaml /app/
+COPY dataqa_env/ /app/dataqa_env/
+COPY inference.py /app/
+COPY README.md /app/
+
+# Install dependencies
+RUN uv sync --no-editable 2>/dev/null || pip install -e .
+
+# Set environment
+ENV PATH="/app/.venv/bin:$PATH"
+ENV PYTHONPATH="/app:$PYTHONPATH"
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
+
+EXPOSE 8000
+
+ENV ENABLE_WEB_INTERFACE=true
+CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
README.md CHANGED
@@ -1,10 +1,109 @@
 ---
-title: Dataqa Env
-emoji: 💻
-colorFrom: red
-colorTo: yellow
+title: DataQA Environment Server
+emoji: 🔍
+colorFrom: blue
+colorTo: gray
 sdk: docker
 pinned: false
+app_port: 8000
+base_path: /web
+tags:
+  - openenv
 ---
 
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# DataQA Environment
+
+An OpenEnv environment for **Data Quality Assurance** — an LLM agent inspects datasets with planted quality issues and must identify them all.
+
+## Overview
+
+DataQA simulates the real-world task of validating datasets before they enter ML training pipelines or production databases. The agent receives a corrupted dataset along with its schema and validation rules, then must identify all planted data quality issues.
+
+### Why Data QA?
+
+Every ML engineer and data scientist spends significant time debugging data quality issues — missing values, type mismatches, inconsistencies, and subtle statistical anomalies. This environment turns that task into a structured, gradable challenge.
+
+## Environment API
+
+| Endpoint | Description |
+|----------|-------------|
+| `reset(task_id)` | Start a new episode with a corrupted dataset |
+| `step(issues)` | Submit identified issues, receive F1-scored feedback |
+| `state()` | Get current episode state |
+
+## Tasks
+
+| Task | Issues | Difficulty | Description |
+|------|--------|------------|-------------|
+| `easy` | 4 | Beginner | Employee directory — nulls, wrong types, duplicates, out-of-range |
+| `medium` | 6 | Intermediate | E-commerce orders — format violations, inconsistent totals, duplicate keys |
+| `hard` | 8 | Advanced | ML experiment metadata — data leakage signals, unreasonable GPU usage, timestamp ordering |
+
+## Reward Function
+
+Scoring uses the **F1 score** (harmonic mean of precision and recall):
+
+- **Precision**: What fraction of reported issues are real?
+- **Recall**: What fraction of planted issues did the agent find?
+- **F1**: `2 * precision * recall / (precision + recall)`
+
+Issues are matched by `row:<N>,col:<column>,issue:<type>` keys.
+
+The agent gets up to 3 attempts per task, with feedback on each attempt (true positives, false positives, missed count).
+
+## Action/Observation Space
+
+**Action**: List of issue strings in the format `row:<row_number>,col:<column_name>,issue:<issue_type>`
+
+**Observation**: Dataset CSV + schema + validation rules + feedback from the previous attempt
+
+**Issue Types**: `missing_value`, `wrong_type`, `duplicate_row`, `out_of_range`, `format_violation`, `inconsistent_value`, `statistical_outlier`, `referential_integrity`
+
+## Quick Start
+
+```bash
+# Install
+pip install -e .
+
+# Run server locally
+uvicorn dataqa_env.server.app:app --host 0.0.0.0 --port 8000
+
+# Run inference
+API_BASE_URL=https://api.groq.com/openai/v1 \
+MODEL_NAME=llama-3.3-70b-versatile \
+LLM_API_KEY=your-key \
+python inference.py
+```
+
+## Docker
+
+```bash
+docker build -t dataqa-env -f dataqa_env/server/Dockerfile .
+docker run -p 8000:8000 dataqa-env
+```
+
+## Environment Variables
+
+| Variable | Description | Default |
+|----------|-------------|---------|
+| `API_BASE_URL` | LLM API endpoint | `https://api.groq.com/openai/v1` |
+| `MODEL_NAME` | Model identifier | `llama-3.3-70b-versatile` |
+| `HF_TOKEN` | Hugging Face token | - |
+| `ENV_URL` | Environment server URL | `http://localhost:8000` |
+| `LLM_API_KEY` | API key for LLM provider | Falls back to `HF_TOKEN` |
+
+## Architecture
+
+```
+dataqa_env/
+├── models.py          # Pydantic: DataQAAction, DataQAObservation, DataQAState
+├── client.py          # EnvClient for WebSocket connections
+├── server/
+│   ├── environment.py # Core DataQAEnvironment (reset/step/state)
+│   ├── tasks.py       # Task definitions + data corruption + grading
+│   ├── app.py         # FastAPI server
+│   └── Dockerfile
+├── openenv.yaml
+├── pyproject.toml
+└── inference.py       # LLM agent using OpenAI client
+```
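The reward function above can be sketched directly from the stated formula. A minimal sketch, assuming reported and planted issues are already normalized `row:<N>,col:<column>,issue:<type>` strings (the `f1_score` helper is illustrative, not part of the package):

```python
def f1_score(reported: set, planted: set) -> float:
    """F1 over matched issue keys, mirroring the README's formula."""
    # Degenerate cases: reporting nothing against nothing is perfect,
    # otherwise an empty side yields zero.
    if not reported or not planted:
        return 1.0 if not reported and not planted else 0.0
    tp = len(reported & planted)          # true positives: exact key matches
    precision = tp / len(reported)        # fraction of reported issues that are real
    recall = tp / len(planted)            # fraction of planted issues found
    if precision + recall == 0:
        return 0.0
    return 2 * precision * recall / (precision + recall)

planted = {
    "row:4,col:name,issue:missing_value",
    "row:7,col:salary,issue:wrong_type",
}
reported = {
    "row:4,col:name,issue:missing_value",       # correct
    "row:2,col:email,issue:format_violation",   # false positive
}
print(round(f1_score(reported, planted), 3))  # 0.5
```

With one true positive out of two reported and two planted, precision and recall are both 0.5, so F1 is 0.5.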
__init__.py ADDED
@@ -0,0 +1,3 @@
+from dataqa_env import DataQAEnv, DataQAAction, DataQAObservation, DataQAState
+
+__all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
client.py ADDED
@@ -0,0 +1,5 @@
+"""Root-level client for OpenEnv compatibility."""
+from dataqa_env.client import DataQAEnv
+from dataqa_env.models import DataQAAction, DataQAObservation, DataQAState
+
+__all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
dataqa_env/__init__.py ADDED
@@ -0,0 +1,4 @@
+from .client import DataQAEnv
+from .models import DataQAAction, DataQAObservation, DataQAState
+
+__all__ = ["DataQAEnv", "DataQAAction", "DataQAObservation", "DataQAState"]
dataqa_env/client.py ADDED
@@ -0,0 +1,37 @@
+"""
+DataQAEnv Client
+----------------
+Client-side wrapper for the DataQA environment server.
+"""
+
+from __future__ import annotations
+
+from openenv.core.client_types import StepResult
+from openenv.core.env_client import EnvClient
+
+from .models import DataQAAction, DataQAObservation, DataQAState
+
+
+class DataQAEnv(EnvClient[DataQAAction, DataQAObservation, DataQAState]):
+
+    def _step_payload(self, action: DataQAAction) -> dict:
+        return {"issues": action.issues, "task_id": action.task_id}
+
+    def _parse_result(self, payload: dict) -> StepResult[DataQAObservation]:
+        obs = DataQAObservation(**payload["observation"])
+        return StepResult(
+            observation=obs,
+            reward=payload.get("reward"),
+            done=bool(payload.get("done", False)),
+        )
+
+    def _parse_state(self, payload: dict) -> DataQAState:
+        return DataQAState(
+            episode_id=payload.get("episode_id"),
+            step_count=payload.get("step_count", 0),
+            task_id=payload.get("task_id", ""),
+            current_step=payload.get("current_step", 0),
+            max_steps=payload.get("max_steps", 3),
+            best_score=payload.get("best_score", 0.0),
+            total_planted_issues=payload.get("total_planted_issues", 0),
+        )
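The defaults in `_parse_result` matter when the server omits optional fields. A standalone sketch of the same logic, using a plain dataclass in place of OpenEnv's `StepResult` so it runs without the `openenv` package:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class StepResultSketch:
    """Stand-in for openenv's StepResult, for illustration only."""
    observation: dict
    reward: Optional[float]
    done: bool


def parse_result(payload: dict) -> StepResultSketch:
    # Mirrors DataQAEnv._parse_result: a missing "reward" becomes None,
    # a missing "done" becomes False, so partial payloads still parse.
    return StepResultSketch(
        observation=payload["observation"],
        reward=payload.get("reward"),
        done=bool(payload.get("done", False)),
    )


result = parse_result({"observation": {"feedback": "ok"}})
print(result.reward, result.done)  # None False
```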
dataqa_env/models.py ADDED
@@ -0,0 +1,75 @@
+"""
+DataQA Environment Models
+-------------------------
+Action/Observation/State types for the Data Quality Assurance environment.
+
+The agent receives a dataset with planted quality issues and must identify them.
+Grading is based on the F1 score (harmonic mean of precision and recall) over correctly identified issues.
+"""
+
+from __future__ import annotations
+
+from typing import List
+
+from openenv.core.env_server.interfaces import Action, Observation, State
+
+
+class DataQAAction(Action):
+    """
+    Agent submits a list of identified data quality issues.
+
+    Each issue is a string in the format: "row:<row_idx>,col:<col_name>,issue:<issue_type>"
+    Supported issue types:
+    - missing_value
+    - wrong_type
+    - duplicate_row
+    - out_of_range
+    - format_violation
+    - inconsistent_value
+    - statistical_outlier
+    - referential_integrity
+    """
+
+    issues: List[str]
+    # Include task_id so step() can reconstruct context in stateless HTTP mode
+    task_id: str = "easy"
+
+
+class DataQAObservation(Observation):
+    """
+    What the agent sees: a dataset, its schema/rules, and feedback.
+    """
+
+    # The dataset as CSV text
+    dataset_csv: str = ""
+
+    # Schema description (column names, expected types, constraints)
+    schema_description: str = ""
+
+    # Validation rules in plain text
+    validation_rules: str = ""
+
+    # Task description
+    task_description: str = ""
+
+    # Feedback from previous step (empty on reset)
+    feedback: str = ""
+
+    # Current task ID
+    task_id: str = ""
+
+    # Number of planted issues (hint for the agent)
+    num_issues_hint: int = 0
+
+    # Max allowed steps for this task
+    max_steps: int = 3
+
+
+class DataQAState(State):
+    """Tracks episode progress."""
+
+    task_id: str = ""
+    current_step: int = 0
+    max_steps: int = 3
+    best_score: float = 0.0
+    total_planted_issues: int = 0
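The action format documented above can be checked with a tiny validator before submitting, which avoids wasting one of the three attempts on malformed strings. A sketch assuming only the documented string grammar (`is_valid_issue` is illustrative, not part of the package):

```python
import re

# The eight issue types listed in DataQAAction's docstring
ISSUE_TYPES = {
    "missing_value", "wrong_type", "duplicate_row", "out_of_range",
    "format_violation", "inconsistent_value", "statistical_outlier",
    "referential_integrity",
}

# Strict form of the documented grammar: row:<row_idx>,col:<col_name>,issue:<issue_type>
PATTERN = re.compile(r"^row:(\d+),col:(\w+),issue:(\w+)$")


def is_valid_issue(s: str) -> bool:
    """Return True if s matches the canonical issue-string grammar."""
    m = PATTERN.match(s)
    return bool(m) and m.group(3) in ISSUE_TYPES


print(is_valid_issue("row:4,col:name,issue:missing_value"))  # True
print(is_valid_issue("row:4,col:name,issue:typo"))           # False
```

The server itself parses more leniently, but validating client-side against the canonical form keeps submissions unambiguous.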
dataqa_env/server/Dockerfile ADDED
@@ -0,0 +1,33 @@
+FROM python:3.11-slim
+
+WORKDIR /app
+
+# Install system deps
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    git curl \
+    && rm -rf /var/lib/apt/lists/*
+
+# Install uv for fast dependency management
+RUN curl -LsSf https://astral.sh/uv/install.sh | sh && \
+    mv /root/.local/bin/uv /usr/local/bin/uv && \
+    mv /root/.local/bin/uvx /usr/local/bin/uvx
+
+# Copy project files
+COPY . /app/env
+
+WORKDIR /app/env
+
+# Install dependencies
+RUN uv sync --frozen --no-editable 2>/dev/null || uv sync --no-editable
+
+# Set environment
+ENV PATH="/app/env/.venv/bin:$PATH"
+ENV PYTHONPATH="/app/env:$PYTHONPATH"
+
+# Health check
+HEALTHCHECK --interval=30s --timeout=3s --start-period=5s --retries=3 \
+    CMD python -c "import urllib.request; urllib.request.urlopen('http://localhost:8000/health')" || exit 1
+
+EXPOSE 8000
+
+CMD ["uvicorn", "dataqa_env.server.app:app", "--host", "0.0.0.0", "--port", "8000"]
dataqa_env/server/__init__.py ADDED
File without changes
dataqa_env/server/app.py ADDED
@@ -0,0 +1,28 @@
+"""
+FastAPI application for the DataQA Environment.
+
+Usage:
+    uvicorn dataqa_env.server.app:app --reload --host 0.0.0.0 --port 8000
+"""
+
+try:
+    from openenv.core.env_server.http_server import create_app
+    from .environment import DataQAEnvironment
+    from ..models import DataQAAction, DataQAObservation
+except ImportError:
+    from openenv.core.env_server.http_server import create_app
+    from dataqa_env.server.environment import DataQAEnvironment
+    from dataqa_env.models import DataQAAction, DataQAObservation
+
+app = create_app(
+    DataQAEnvironment, DataQAAction, DataQAObservation, env_name="dataqa_env"
+)
+
+
+def main():
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
+if __name__ == "__main__":
+    main()
dataqa_env/server/environment.py ADDED
@@ -0,0 +1,193 @@
+"""
+DataQA Environment
+------------------
+Server-side environment for data quality assurance tasks.
+
+The agent receives corrupted datasets and must identify planted quality issues.
+Scoring is based on the F1 score (harmonic mean of precision and recall) over correctly matched issues.
+"""
+
+from __future__ import annotations
+
+import re
+import uuid
+from typing import Any, Optional, Set
+
+from openenv.core.env_server.interfaces import Action, Environment, Observation
+
+from ..models import DataQAAction, DataQAObservation, DataQAState
+from .tasks import PlantedIssue, Task, get_task, list_tasks
+
+
+def parse_issue_key(raw: str) -> Optional[str]:
+    """
+    Parse an agent-reported issue string into a normalized key.
+    Expected format: row:<N>,col:<name>,issue:<type>
+    Returns the normalized key, or None if unparseable.
+    """
+    raw = raw.strip().lower()
+    # Be lenient with formatting
+    row_match = re.search(r"row\s*[:=]\s*(\d+)", raw)
+    col_match = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
+    issue_match = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
+
+    if row_match and col_match and issue_match:
+        return f"row:{row_match.group(1)},col:{col_match.group(1)},issue:{issue_match.group(1)}"
+    return None
+
+
+def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
+    """Compute precision, recall, and F1 score."""
+    if not reported_keys and not planted_keys:
+        return {"precision": 1.0, "recall": 1.0, "f1": 1.0, "tp": 0, "fp": 0, "fn": 0}
+
+    if not reported_keys:
+        return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "tp": 0, "fp": 0, "fn": len(planted_keys)}
+
+    if not planted_keys:
+        return {"precision": 0.0, "recall": 0.0, "f1": 0.0, "tp": 0, "fp": len(reported_keys), "fn": 0}
+
+    tp = len(reported_keys & planted_keys)
+    fp = len(reported_keys - planted_keys)
+    fn = len(planted_keys - reported_keys)
+
+    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
+    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0
+
+    return {"precision": precision, "recall": recall, "f1": f1, "tp": tp, "fp": fp, "fn": fn}
+
+
+class DataQAEnvironment(Environment):
+    """
+    Data Quality Assurance environment.
+
+    The agent inspects corrupted datasets and reports quality issues.
+    Reward is the F1 score of correctly identified issues vs the planted ground truth.
+    """
+
+    SUPPORTS_CONCURRENT_SESSIONS = True
+
+    def __init__(self):
+        self._state = DataQAState()
+        self._current_task: Optional[Task] = None
+        self._planted_keys: Set[str] = set()
+        self._best_score: float = 0.0
+
+    def reset(
+        self,
+        seed: Optional[int] = None,
+        episode_id: Optional[str] = None,
+        **kwargs: Any,
+    ) -> Observation:
+        task_id = kwargs.get("task_id", "easy")
+        task_seed = seed if seed is not None else 42
+
+        self._current_task = get_task(task_id, seed=task_seed)
+        self._planted_keys = {issue.to_key() for issue in self._current_task.planted_issues}
+        self._best_score = 0.0
+
+        ep_id = episode_id or str(uuid.uuid4())
+        self._state = DataQAState(
+            episode_id=ep_id,
+            step_count=0,
+            task_id=task_id,
+            current_step=0,
+            max_steps=self._current_task.max_steps,
+            best_score=0.0,
+            total_planted_issues=len(self._current_task.planted_issues),
+        )
+
+        return DataQAObservation(
+            dataset_csv=self._current_task.corrupted_csv,
+            schema_description=self._current_task.schema_description,
+            validation_rules=self._current_task.validation_rules,
+            task_description=self._current_task.description,
+            feedback="Environment reset. Inspect the dataset and report all quality issues.",
+            task_id=task_id,
+            num_issues_hint=len(self._current_task.planted_issues),
+            max_steps=self._current_task.max_steps,
+            done=False,
+            reward=0.0,
+        )
+
+    def step(
+        self,
+        action: Action,
+        timeout_s: Optional[float] = None,
+        **kwargs: Any,
+    ) -> Observation:
+        if not isinstance(action, DataQAAction):
+            raise ValueError(f"Expected DataQAAction, got {type(action)}")
+
+        # In stateless HTTP mode, each request creates a fresh env instance.
+        # Auto-reset using the task_id from the action so step() works standalone.
+        if self._current_task is None:
+            self.reset(task_id=action.task_id)
+
+        self._state.step_count += 1
+        self._state.current_step += 1
+
+        # Parse reported issues
+        reported_keys: Set[str] = set()
+        parse_errors: list[str] = []
+        for raw_issue in action.issues:
+            key = parse_issue_key(raw_issue)
+            if key:
+                reported_keys.add(key)
+            else:
+                parse_errors.append(f"Could not parse: '{raw_issue}'")
+
+        # Compute score
+        metrics = compute_f1(reported_keys, self._planted_keys)
+        score = metrics["f1"]
+        self._best_score = max(self._best_score, score)
+        self._state.best_score = self._best_score
+
+        # Check if done
+        is_done = (
+            score >= 0.999  # Perfect score
+            or self._state.current_step >= self._state.max_steps
+        )
+
+        # Build feedback
+        feedback_lines = [
+            f"Step {self._state.current_step}/{self._state.max_steps}",
+            f"Issues reported: {len(reported_keys)}",
+            f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
+            f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {score:.3f}",
+        ]
+
+        if parse_errors:
+            feedback_lines.append(f"Parse errors ({len(parse_errors)}): {'; '.join(parse_errors[:3])}")
+
+        if not is_done:
+            # Give hints about what was missed without revealing exact answers
+            if metrics["fn"] > 0:
+                feedback_lines.append(
+                    f"You missed {metrics['fn']} issue(s). Review the dataset carefully."
+                )
+            if metrics["fp"] > 0:
+                feedback_lines.append(
+                    f"{metrics['fp']} of your reported issues were incorrect."
+                )
+            feedback_lines.append("You can submit again with an updated list of issues.")
+        else:
+            feedback_lines.append(f"Task complete! Final best F1 score: {self._best_score:.3f}")
+
+        return DataQAObservation(
+            dataset_csv=self._current_task.corrupted_csv,
+            schema_description=self._current_task.schema_description,
+            validation_rules=self._current_task.validation_rules,
+            task_description=self._current_task.description,
+            feedback="\n".join(feedback_lines),
+            task_id=self._current_task.task_id,
+            num_issues_hint=len(self._current_task.planted_issues),
+            max_steps=self._state.max_steps,
+            done=is_done,
+            reward=self._best_score,
+        )
+
+    @property
+    def state(self) -> DataQAState:
+        return self._state
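The lenient parsing in `parse_issue_key` is worth seeing in action: the regexes accept `:` or `=` separators, arbitrary whitespace, and mixed case, then emit one canonical key so set intersection works for grading. The function below repeats the server's logic so the demo runs standalone:

```python
import re
from typing import Optional


def parse_issue_key(raw: str) -> Optional[str]:
    """Normalize a reported issue string, mirroring the server's lenient parser."""
    raw = raw.strip().lower()
    # Accept ':' or '=' as the separator, with optional whitespace around it
    row = re.search(r"row\s*[:=]\s*(\d+)", raw)
    col = re.search(r"col\s*[:=]\s*([\w_]+)", raw)
    issue = re.search(r"issue\s*[:=]\s*([\w_]+)", raw)
    if row and col and issue:
        return f"row:{row.group(1)},col:{col.group(1)},issue:{issue.group(1)}"
    return None


print(parse_issue_key("Row: 4, Col: salary, Issue: out_of_range"))
# row:4,col:salary,issue:out_of_range
print(parse_issue_key("not an issue"))  # None
```

Because all accepted variants collapse to the same key, an agent that writes `Row=4` still gets credit for the issue at `row:4`.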
dataqa_env/server/tasks.py ADDED
@@ -0,0 +1,391 @@
+"""
+Task definitions for the DataQA environment.
+
+Each task provides:
+- A clean dataset (CSV)
+- A schema + validation rules
+- A set of planted issues (ground truth)
+- A function to inject those issues into the clean data
+"""
+
+from __future__ import annotations
+
+import csv
+import io
+import random
+from dataclasses import dataclass, field
+from typing import List, Set
+
+
+@dataclass
+class PlantedIssue:
+    """A single planted data quality issue."""
+
+    row: int
+    col: str
+    issue_type: str
+    description: str
+
+    def to_key(self) -> str:
+        return f"row:{self.row},col:{self.col},issue:{self.issue_type}"
+
+
+@dataclass
+class Task:
+    task_id: str
+    name: str
+    description: str
+    schema_description: str
+    validation_rules: str
+    clean_csv: str
+    planted_issues: List[PlantedIssue] = field(default_factory=list)
+    corrupted_csv: str = ""
+    max_steps: int = 3
+
+
+def _csv_to_rows(csv_text: str) -> List[List[str]]:
+    reader = csv.reader(io.StringIO(csv_text.strip()))
+    return [row for row in reader]
+
+
+def _rows_to_csv(rows: List[List[str]]) -> str:
+    output = io.StringIO()
+    writer = csv.writer(output)
+    writer.writerows(rows)
+    return output.getvalue()
+
+
+# ---------------------------------------------------------------------------
+# TASK 1: Easy — Employee directory with obvious issues
+# ---------------------------------------------------------------------------
+
+def create_task_easy(seed: int = 42) -> Task:
+    rng = random.Random(seed)
+
+    clean_csv = """employee_id,name,email,department,salary,start_date
+101,Alice Chen,alice.chen@company.com,Engineering,95000,2022-03-15
+102,Bob Martinez,bob.martinez@company.com,Marketing,72000,2021-07-01
+103,Carol Davis,carol.davis@company.com,Engineering,98000,2020-11-20
+104,David Kim,david.kim@company.com,Sales,68000,2023-01-10
+105,Eve Johnson,eve.johnson@company.com,HR,71000,2022-06-05
+106,Frank Wilson,frank.wilson@company.com,Engineering,102000,2019-08-12
+107,Grace Lee,grace.lee@company.com,Marketing,75000,2021-12-01
+108,Hank Brown,hank.brown@company.com,Sales,65000,2023-04-18
+109,Iris Patel,iris.patel@company.com,HR,73000,2020-02-28
+110,Jack Taylor,jack.taylor@company.com,Engineering,97000,2022-09-14"""
+
+    schema_desc = """Columns:
+- employee_id: integer, unique, range 100-999
+- name: string, non-empty, format "FirstName LastName"
+- email: string, valid email format, must match pattern firstname.lastname@company.com
+- department: string, one of [Engineering, Marketing, Sales, HR]
+- salary: integer, range 50000-150000
+- start_date: string, format YYYY-MM-DD, must be between 2015-01-01 and 2025-12-31"""
+
+    rules = """1. No missing values in any column
+2. employee_id must be unique
+3. email must follow the pattern: lowercase(firstname).lowercase(lastname)@company.com
+4. salary must be within the valid range
+5. No duplicate rows"""
+
+    rows = _csv_to_rows(clean_csv)
+    header = rows[0]
+    data = rows[1:]
+    issues: List[PlantedIssue] = []
+
+    # Issue 1: Missing value - null out a name
+    r = 3  # row index in data (0-based), displayed as row 4 in CSV
+    data[r][1] = ""
+    issues.append(PlantedIssue(row=r + 1, col="name", issue_type="missing_value",
+                               description="Empty name field"))
+
+    # Issue 2: Wrong type - salary as text
+    r = 6
+    data[r][4] = "seventy-five thousand"
+    issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="wrong_type",
+                               description="Salary is text instead of integer"))
+
+    # Issue 3: Duplicate row
+    dup_source = 1
+    data.append(list(data[dup_source]))
+    issues.append(PlantedIssue(row=len(data), col="employee_id", issue_type="duplicate_row",
+                               description=f"Exact duplicate of row {dup_source + 1}"))
+
+    # Issue 4: Out of range salary
+    r = 8
+    data[r][4] = "5000"
+    issues.append(PlantedIssue(row=r + 1, col="salary", issue_type="out_of_range",
+                               description="Salary 5000 is below minimum 50000"))
+
+    corrupted = _rows_to_csv([header] + data)
+
+    return Task(
+        task_id="easy",
+        name="Employee Directory Validation",
+        description=(
+            "You are given an employee directory dataset. "
+            "Find all data quality issues based on the schema and validation rules. "
+            "Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
+        ),
+        schema_description=schema_desc,
+        validation_rules=rules,
+        clean_csv=clean_csv,
+        planted_issues=issues,
+        corrupted_csv=corrupted,
+        max_steps=3,
+    )
+
+
+# ---------------------------------------------------------------------------
+# TASK 2: Medium — E-commerce orders with moderate issues
+# ---------------------------------------------------------------------------
+
+def create_task_medium(seed: int = 42) -> Task:
+    rng = random.Random(seed)
+
+    clean_csv = """order_id,customer_id,product_name,category,quantity,unit_price,order_date,shipping_country,status,total
+ORD-001,CUST-100,Wireless Mouse,Electronics,2,29.99,2024-01-15,US,delivered,59.98
+ORD-002,CUST-101,Python Cookbook,Books,1,45.50,2024-01-16,UK,delivered,45.50
+ORD-003,CUST-102,USB-C Hub,Electronics,1,35.00,2024-01-17,US,shipped,35.00
+ORD-004,CUST-103,Yoga Mat,Sports,1,25.99,2024-01-18,CA,delivered,25.99
+ORD-005,CUST-104,Desk Lamp,Home,1,42.00,2024-01-19,US,processing,42.00
+ORD-006,CUST-105,Running Shoes,Sports,1,89.99,2024-01-20,DE,delivered,89.99
+ORD-007,CUST-106,Mechanical Keyboard,Electronics,1,129.99,2024-01-21,US,shipped,129.99
+ORD-008,CUST-100,Monitor Stand,Home,1,55.00,2024-01-22,US,delivered,55.00
+ORD-009,CUST-107,Data Science Handbook,Books,2,39.99,2024-01-23,UK,delivered,79.98
+ORD-010,CUST-108,Resistance Bands,Sports,3,12.99,2024-01-24,CA,shipped,38.97
+ORD-011,CUST-109,Webcam HD,Electronics,1,65.00,2024-01-25,US,delivered,65.00
+ORD-012,CUST-110,Standing Desk,Home,1,299.99,2024-01-26,US,processing,299.99
+ORD-013,CUST-111,Tennis Racket,Sports,1,75.00,2024-01-27,AU,delivered,75.00
+ORD-014,CUST-112,LED Strip Lights,Home,2,18.50,2024-01-28,US,shipped,37.00
+ORD-015,CUST-113,AI Textbook,Books,1,59.99,2024-01-29,DE,delivered,59.99
+ORD-016,CUST-114,Bluetooth Speaker,Electronics,1,49.99,2024-01-30,UK,delivered,49.99
+ORD-017,CUST-115,Jump Rope,Sports,2,8.99,2024-01-31,US,shipped,17.98
+ORD-018,CUST-116,Coffee Table Book,Books,1,32.00,2024-02-01,CA,delivered,32.00
+ORD-019,CUST-117,Ergonomic Chair,Home,1,450.00,2024-02-02,US,processing,450.00
+ORD-020,CUST-118,Fitness Tracker,Electronics,1,79.99,2024-02-03,AU,delivered,79.99"""
+
+    schema_desc = """Columns:
+- order_id: string, unique, format ORD-NNN
+- customer_id: string, format CUST-NNN
+- product_name: string, non-empty
+- category: string, one of [Electronics, Books, Sports, Home]
+- quantity: integer, range 1-100
+- unit_price: float, range 0.01-10000.00
+- order_date: string, format YYYY-MM-DD
+- shipping_country: string, ISO 2-letter country code
+- status: string, one of [processing, shipped, delivered, cancelled, returned]
+- total: float, must equal quantity * unit_price"""
+
+    rules = """1. No missing values in any column
+2. order_id must be unique
+3. total must equal quantity * unit_price (tolerance: 0.01)
+4. order_date must be in valid chronological order for sequential order_ids
+5. category must be from the allowed set
+6. All monetary values must have at most 2 decimal places
+7. shipping_country must be a valid ISO 2-letter code"""
+
+    rows = _csv_to_rows(clean_csv)
+    header = rows[0]
+    data = rows[1:]
+    issues: List[PlantedIssue] = []
+
+    # Issue 1: total doesn't match quantity * unit_price
+    r = 4  # ORD-005
+    data[r][9] = "84.00"  # should be 42.00 (qty=1, price=42.00)
+    issues.append(PlantedIssue(row=r + 1, col="total", issue_type="inconsistent_value",
+                               description="total (84.00) != quantity (1) * unit_price (42.00)"))
+
+    # Issue 2: Invalid category
+    r = 9  # ORD-010
+    data[r][3] = "Fitness"  # should be Sports
+    issues.append(PlantedIssue(row=r + 1, col="category", issue_type="format_violation",
+                               description="'Fitness' is not in allowed categories"))
+
+    # Issue 3: Missing value in product_name
+    r = 13  # ORD-014
+    data[r][2] = ""
+    issues.append(PlantedIssue(row=r + 1, col="product_name", issue_type="missing_value",
+                               description="Empty product_name"))
+
+    # Issue 4: Out of range quantity
+    r = 16  # ORD-017
+    data[r][4] = "-1"
+    issues.append(PlantedIssue(row=r + 1, col="quantity", issue_type="out_of_range",
+                               description="Negative quantity"))
+
+    # Issue 5: Duplicate order_id
+    r = 18  # ORD-019
+    data[r][0] = "ORD-003"
+    issues.append(PlantedIssue(row=r + 1, col="order_id", issue_type="duplicate_row",
+                               description="Duplicate order_id ORD-003"))
+
+    # Issue 6: Wrong date format
+    r = 11  # ORD-012
+    data[r][6] = "26/01/2024"
+    issues.append(PlantedIssue(row=r + 1, col="order_date", issue_type="format_violation",
+                               description="Date format DD/MM/YYYY instead of YYYY-MM-DD"))
+
+    corrupted = _rows_to_csv([header] + data)
+
+    return Task(
+        task_id="medium",
+        name="E-commerce Orders Validation",
+        description=(
+            "You are given an e-commerce orders dataset. "
+            "Find all data quality issues based on the schema and validation rules. "
+            "Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
+        ),
+        schema_description=schema_desc,
+        validation_rules=rules,
+        clean_csv=clean_csv,
+        planted_issues=issues,
+        corrupted_csv=corrupted,
+        max_steps=3,
+    )
+
+
+# ---------------------------------------------------------------------------
+# TASK 3: Hard — ML training metadata with subtle issues
+# ---------------------------------------------------------------------------
+
+def create_task_hard(seed: int = 42) -> Task:
+    rng = random.Random(seed)
+
+    clean_csv = """experiment_id,model_name,dataset,train_size,val_size,test_size,learning_rate,batch_size,epochs,train_loss,val_loss,test_accuracy,gpu_memory_gb,training_time_hours,timestamp
+EXP-001,resnet50,imagenet-1k,1281167,50000,100000,0.001,256,90,0.85,1.12,76.3,12.4,48.5,2024-03-01T10:00:00
+EXP-002,bert-base,squad-v2,130319,11873,8862,0.00003,32,3,0.45,0.52,81.2,7.8,2.1,2024-03-02T14:30:00
+EXP-003,gpt2-small,openwebtext,8013769,100000,100000,0.0003,64,1,3.12,3.28,0.0,14.2,72.0,2024-03-03T09:15:00
+EXP-004,vit-base,imagenet-1k,1281167,50000,100000,0.001,512,300,0.72,0.98,79.8,15.6,96.0,2024-03-05T08:00:00
+EXP-005,distilbert,mnli,392702,9815,9796,0.00005,16,5,0.28,0.35,84.6,5.2,1.5,2024-03-06T11:00:00
+EXP-006,llama2-7b,alpaca-52k,51760,500,500,0.00002,4,3,1.05,1.18,0.0,38.5,8.2,2024-03-07T16:00:00
+EXP-007,resnet18,cifar10,50000,5000,10000,0.01,128,200,0.15,0.28,93.5,3.2,1.8,2024-03-08T10:30:00
+EXP-008,t5-small,cnn-dailymail,287113,13368,11490,0.0001,16,10,1.45,1.62,0.0,6.8,4.5,2024-03-09T13:00:00
+EXP-009,efficientnet-b0,imagenet-1k,1281167,50000,100000,0.005,256,350,0.68,0.89,77.1,8.4,36.0,2024-03-10T07:45:00
+EXP-010,roberta-large,sst2,67349,872,1821,0.00001,8,10,0.08,0.12,95.1,14.8,3.2,2024-03-11T15:00:00
+EXP-011,yolov5-m,coco-2017,118287,5000,40670,0.01,32,300,0.032,0.045,0.0,10.2,24.0,2024-03-12T09:00:00
+EXP-012,wav2vec2,librispeech,281241,5567,2620,0.0001,8,20,0.92,1.05,0.0,12.6,15.0,2024-03-13T11:30:00
+EXP-013,clip-base,cc3m,2818102,15000,15000,0.00001,256,32,2.15,2.38,0.0,22.4,48.0,2024-03-14T08:00:00
+EXP-014,detr,coco-2017,118287,5000,40670,0.0001,4,500,1.85,2.12,0.0,16.0,72.0,2024-03-15T10:00:00
+EXP-015,whisper-small,common-voice,520000,16000,16000,0.00005,16,5,0.55,0.68,0.0,7.4,6.5,2024-03-16T14:00:00"""
+
+    schema_desc = """Columns:
+- experiment_id: string, unique, format EXP-NNN
+- model_name: string, non-empty
+- dataset: string, non-empty
+- train_size: integer, positive, must be > val_size and > test_size
+- val_size: integer, positive
+- test_size: integer, positive
+- learning_rate: float, range 1e-7 to 1.0
+- batch_size: integer, must be power of 2, range 1-1024
+- epochs: integer, positive, range 1-1000
+- train_loss: float, non-negative
+- val_loss: float, non-negative, typically >= train_loss (if not, may indicate data leakage)
+- test_accuracy: float, range 0-100 (percentage), 0.0 is valid for generative models
+- gpu_memory_gb: float, positive
+- training_time_hours: float, positive
+- timestamp: string, ISO 8601 format, chronological order by experiment_id"""
+
+    rules = """1. No missing values
290
+ 2. experiment_id must be unique
291
+ 3. val_loss should be >= train_loss (if val_loss < train_loss significantly, flag as potential data leakage)
292
+ 4. batch_size must be a power of 2
293
+ 5. train_size must be larger than both val_size and test_size
294
+ 6. learning_rate must be within valid range
295
+ 7. gpu_memory_gb should be reasonable for the model size (e.g., resnet18 shouldn't need 40GB)
296
+ 8. training_time should be proportional to dataset size and epochs (flag major inconsistencies)
297
+ 9. timestamps must be in chronological order"""
298
+
299
+ rows = _csv_to_rows(clean_csv)
300
+ header = rows[0]
301
+ data = rows[1:]
302
+ issues: List[PlantedIssue] = []
303
+
304
+ # Issue 1: Data leakage signal — val_loss much lower than train_loss
305
+ r = 4 # EXP-005
306
+ data[r][10] = "0.15" # val_loss=0.15 but train_loss=0.28 → suspicious
307
+ issues.append(PlantedIssue(row=r + 1, col="val_loss", issue_type="inconsistent_value",
308
+ description="val_loss (0.15) significantly less than train_loss (0.28), potential data leakage"))
309
+
310
+ # Issue 2: Batch size not power of 2
311
+ r = 8 # EXP-009
312
+ data[r][7] = "250" # not a power of 2
313
+ issues.append(PlantedIssue(row=r + 1, col="batch_size", issue_type="format_violation",
314
+ description="batch_size 250 is not a power of 2"))
315
+
316
+ # Issue 3: GPU memory unreasonable for model
317
+ r = 6 # EXP-007 resnet18 on cifar10
318
+ data[r][12] = "42.5" # resnet18 shouldn't need 42.5 GB
319
+ issues.append(PlantedIssue(row=r + 1, col="gpu_memory_gb", issue_type="statistical_outlier",
320
+ description="resnet18 on cifar10 using 42.5 GB GPU memory is unreasonable"))
321
+
322
+ # Issue 4: Timestamp out of order
323
+ r = 10 # EXP-011
324
+ data[r][14] = "2024-03-02T09:00:00" # should be after EXP-010's timestamp
325
+ issues.append(PlantedIssue(row=r + 1, col="timestamp", issue_type="inconsistent_value",
326
+ description="Timestamp 2024-03-02 is before EXP-010's timestamp 2024-03-11"))
327
+
328
+ # Issue 5: Train size smaller than test size
329
+ r = 9 # EXP-010
330
+ data[r][3] = "500" # train_size=500 but test_size=1821
331
+ issues.append(PlantedIssue(row=r + 1, col="train_size", issue_type="inconsistent_value",
332
+ description="train_size (500) is smaller than test_size (1821)"))
333
+
334
+ # Issue 6: Negative training time
335
+ r = 13 # EXP-014
336
+ data[r][13] = "-72.0"
337
+ issues.append(PlantedIssue(row=r + 1, col="training_time_hours", issue_type="out_of_range",
338
+ description="Negative training time"))
339
+
340
+ # Issue 7: Learning rate out of range
341
+ r = 12 # EXP-013
342
+ data[r][6] = "2.5" # way too high
343
+ issues.append(PlantedIssue(row=r + 1, col="learning_rate", issue_type="out_of_range",
344
+ description="Learning rate 2.5 exceeds maximum of 1.0"))
345
+
346
+ # Issue 8: Missing model name (subtle — single space instead of empty)
347
+ r = 14 # EXP-015
348
+ data[r][1] = " "
349
+ issues.append(PlantedIssue(row=r + 1, col="model_name", issue_type="missing_value",
350
+ description="model_name is whitespace-only"))
351
+
352
+ corrupted = _rows_to_csv([header] + data)
353
+
354
+ return Task(
355
+ task_id="hard",
356
+ name="ML Experiment Metadata Validation",
357
+ description=(
358
+ "You are given an ML experiment tracking dataset. "
359
+ "Find all data quality issues based on the schema and validation rules. "
360
+ "This dataset contains subtle issues including potential data leakage signals, "
361
+ "unreasonable resource usage, and logical inconsistencies. "
362
+ "Report each issue in the format: row:<row_number>,col:<column_name>,issue:<issue_type>"
363
+ ),
364
+ schema_description=schema_desc,
365
+ validation_rules=rules,
366
+ clean_csv=clean_csv,
367
+ planted_issues=issues,
368
+ corrupted_csv=corrupted,
369
+ max_steps=3,
370
+ )
371
+
372
+
373
+ # ---------------------------------------------------------------------------
374
+ # Task registry
375
+ # ---------------------------------------------------------------------------
376
+
377
+ TASK_REGISTRY = {
378
+ "easy": create_task_easy,
379
+ "medium": create_task_medium,
380
+ "hard": create_task_hard,
381
+ }
382
+
383
+
384
+ def get_task(task_id: str, seed: int = 42) -> Task:
385
+ if task_id not in TASK_REGISTRY:
386
+ raise ValueError(f"Unknown task: {task_id}. Available: {list(TASK_REGISTRY.keys())}")
387
+ return TASK_REGISTRY[task_id](seed=seed)
388
+
389
+
390
+ def list_tasks() -> List[str]:
391
+ return list(TASK_REGISTRY.keys())
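Two conventions recur throughout the tasks above: rule 4 of the hard task ("batch_size must be a power of 2") and the canonical report string `row:<row_number>,col:<column_name>,issue:<issue_type>`. A minimal standalone sketch (not part of the repository; `is_power_of_two` and `issue_string` are illustrative names):

```python
def is_power_of_two(n: int) -> bool:
    # A positive power of two has exactly one bit set,
    # so n & (n - 1) clears that bit and yields 0.
    return n > 0 and (n & (n - 1)) == 0


def issue_string(row: int, col: str, issue: str) -> str:
    # The canonical report format the tasks expect agents to emit.
    return f"row:{row},col:{col},issue:{issue}"


print(is_power_of_two(256))  # True  (a valid batch size)
print(is_power_of_two(250))  # False (the value planted in EXP-009)
print(issue_string(9, "batch_size", "format_violation"))
```

Note that the planted issues use `row=r + 1` for the same reason: `r` indexes the `data` list (0-based), while reports are 1-based over data rows.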
inference.py ADDED
@@ -0,0 +1,322 @@
+ #!/usr/bin/env python3
+ """
+ DataQA Inference Script
+ -----------------------
+ LLM agent that plays the DataQA environment.
+ Uses the OpenAI client to interact with any OpenAI-compatible LLM API.
+
+ Required environment variables:
+     API_BASE_URL - LLM API endpoint (e.g., https://api.groq.com/openai/v1)
+     MODEL_NAME - Model identifier (e.g., llama-3.3-70b-versatile)
+     HF_TOKEN - HuggingFace token (for HF Spaces access)
+
+ Structured logging format: [START], [STEP], [END] tags for evaluation.
+ """
+
+ from __future__ import annotations
+
+ import json
+ import os
+ import re
+ import sys
+ import time
+ from typing import Optional
+
+ import requests
+ from openai import OpenAI
+
+ # ---------------------------------------------------------------------------
+ # Configuration
+ # ---------------------------------------------------------------------------
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://api.groq.com/openai/v1")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "llama-3.3-70b-versatile")
+ HF_TOKEN = os.environ.get("HF_TOKEN", "")
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:8000")
+
+ TASKS = ["easy", "medium", "hard"]
+ MAX_STEPS_PER_TASK = 3
+
+ # ---------------------------------------------------------------------------
+ # Logging helpers (structured stdout for evaluation)
+ # ---------------------------------------------------------------------------
+
+ def log_start(task_id: str, metadata: Optional[dict] = None):
+     entry = {"event": "START", "task_id": task_id, "timestamp": time.time()}
+     if metadata:
+         entry["metadata"] = metadata
+     print(f"[START] {json.dumps(entry)}", flush=True)
+
+
+ def log_step(task_id: str, step: int, reward: float, details: Optional[dict] = None):
+     entry = {
+         "event": "STEP",
+         "task_id": task_id,
+         "step": step,
+         "reward": reward,
+         "timestamp": time.time(),
+     }
+     if details:
+         entry["details"] = details
+     print(f"[STEP] {json.dumps(entry)}", flush=True)
+
+
+ def log_end(task_id: str, final_score: float, metadata: Optional[dict] = None):
+     entry = {
+         "event": "END",
+         "task_id": task_id,
+         "final_score": final_score,
+         "timestamp": time.time(),
+     }
+     if metadata:
+         entry["metadata"] = metadata
+     print(f"[END] {json.dumps(entry)}", flush=True)
+
+
+ # ---------------------------------------------------------------------------
+ # Environment HTTP client (simple, no WebSocket needed for inference)
+ # ---------------------------------------------------------------------------
+
+ class EnvHTTPClient:
+     """Minimal HTTP client for the DataQA environment."""
+
+     def __init__(self, base_url: str):
+         self.base_url = base_url.rstrip("/")
+         self.session = requests.Session()
+
+     def health(self) -> bool:
+         try:
+             r = self.session.get(f"{self.base_url}/health", timeout=10)
+             return r.status_code == 200
+         except Exception:
+             return False
+
+     def reset(self, task_id: str = "easy") -> dict:
+         r = self.session.post(
+             f"{self.base_url}/reset",
+             json={"task_id": task_id},
+             timeout=30,
+         )
+         r.raise_for_status()
+         return r.json()
+
+     def step(self, issues: list[str], task_id: str = "easy") -> dict:
+         r = self.session.post(
+             f"{self.base_url}/step",
+             json={"action": {"issues": issues, "task_id": task_id}},
+             timeout=30,
+         )
+         r.raise_for_status()
+         return r.json()
+
+     def state(self) -> dict:
+         r = self.session.get(f"{self.base_url}/state", timeout=10)
+         r.raise_for_status()
+         return r.json()
+
+
+ # ---------------------------------------------------------------------------
+ # LLM Agent
+ # ---------------------------------------------------------------------------
+
+ SYSTEM_PROMPT = """You are a data quality analyst. Your job is to inspect datasets and identify data quality issues.
+
+ You will be given:
+ 1. A dataset in CSV format
+ 2. A schema describing expected column types and constraints
+ 3. Validation rules that the data should satisfy
+
+ You must identify ALL data quality issues and report each one in EXACTLY this format:
+ row:<row_number>,col:<column_name>,issue:<issue_type>
+
+ Supported issue types:
+ - missing_value (null, empty, or whitespace-only)
+ - wrong_type (value doesn't match expected type)
+ - duplicate_row (exact duplicate or duplicate key)
+ - out_of_range (value outside valid range)
+ - format_violation (wrong format, invalid enum value)
+ - inconsistent_value (computed field doesn't match, logical inconsistency)
+ - statistical_outlier (value is unreasonable given context)
+ - referential_integrity (foreign key violation)
+
+ CRITICAL INSTRUCTIONS FOR ROW NUMBERING:
+ - Row numbers refer to the ROW POSITION in the CSV data, NOT the value of any ID column
+ - Row 1 = the FIRST data row after the header
+ - Row 2 = the SECOND data row after the header
+ - For example, if the CSV has the header on line 1 and data starting on line 2, the data on line 2 is row 1, line 3 is row 2, etc.
+ - DO NOT use the employee_id, order_id, or experiment_id as the row number
+ - Column names must match exactly (use the CSV header names, lowercase)
+ - Check EVERY row and EVERY column systematically
+ - Consider cross-column consistency (e.g., total = quantity * price)
+ - Look for subtle issues like whitespace-only values and near-duplicates
+ - Report ALL issues you find, even if uncertain
+
+ Respond with ONLY the list of issues, one per line. No other text.
+ Example: row:3,col:salary,issue:missing_value"""
+
+
+ def build_user_prompt(observation: dict) -> str:
+     obs = observation if isinstance(observation, dict) else {}
+     parts = []
+
+     if obs.get("task_description"):
+         parts.append(f"TASK: {obs['task_description']}")
+
+     parts.append(f"SCHEMA:\n{obs.get('schema_description', '')}")
+     parts.append(f"VALIDATION RULES:\n{obs.get('validation_rules', '')}")
+     parts.append(f"DATASET:\n{obs.get('dataset_csv', '')}")
+
+     hint = obs.get("num_issues_hint", 0)
+     if hint:
+         parts.append(f"HINT: There are exactly {hint} issues to find.")
+
+     feedback = obs.get("feedback", "")
+     if feedback and "reset" not in feedback.lower():
+         parts.append(f"FEEDBACK FROM PREVIOUS ATTEMPT:\n{feedback}")
+
+     return "\n\n".join(parts)
+
+
+ def parse_llm_response(response: str) -> list[str]:
+     """Extract issue lines from an LLM response."""
+     issues = []
+     for line in response.strip().split("\n"):
+         line = line.strip()
+         if not line:
+             continue
+         # Remove numbering like "1. " or "- " or "* "
+         line = re.sub(r"^\s*[\d]+[.\)]\s*", "", line)
+         line = re.sub(r"^\s*[-*]\s*", "", line)
+         line = line.strip()
+         if "row" in line.lower() and "col" in line.lower():
+             # Lenient regex: accept : or = as delimiters, case-insensitive
+             match = re.search(
+                 r"row\s*[:=]\s*(\d+)\s*[,;\s]+col(?:umn)?\s*[:=]\s*([\w_]+)\s*[,;\s]+issue\s*[:=]\s*([\w_]+)",
+                 line,
+                 re.IGNORECASE,
+             )
+             if match:
+                 # Normalize to lowercase canonical format
+                 normalized = f"row:{match.group(1)},col:{match.group(2).lower()},issue:{match.group(3).lower()}"
+                 issues.append(normalized)
+     return issues
+
+
+ def run_task(client: OpenAI, env: EnvHTTPClient, task_id: str) -> float:
+     """Run a single task and return the best score."""
+     log_start(task_id)
+
+     # Reset environment for this task
+     reset_response = env.reset(task_id=task_id)
+     observation = reset_response.get("observation", reset_response)
+
+     best_score = 0.0
+     messages = [{"role": "system", "content": SYSTEM_PROMPT}]
+
+     for step_num in range(1, MAX_STEPS_PER_TASK + 1):
+         user_prompt = build_user_prompt(observation)
+         messages_for_call = messages + [{"role": "user", "content": user_prompt}]
+
+         # Call LLM with retry on rate limit
+         llm_output = ""
+         for attempt in range(3):
+             try:
+                 response = client.chat.completions.create(
+                     model=MODEL_NAME,
+                     messages=messages_for_call,
+                     temperature=0.1,
+                     max_tokens=2048,
+                 )
+                 llm_output = response.choices[0].message.content or ""
+                 break
+             except Exception as e:
+                 if "rate_limit" in str(e).lower() or "429" in str(e):
+                     wait = 10 * (attempt + 1)
+                     print(f"[WARN] Rate limited, waiting {wait}s...", flush=True)
+                     time.sleep(wait)
+                 else:
+                     print(f"[ERROR] LLM call failed: {e}", file=sys.stderr, flush=True)
+                     break
+
+         # Parse issues from LLM response
+         issues = parse_llm_response(llm_output)
+
+         if not issues:
+             print(f"[WARN] No issues parsed from LLM response for {task_id} step {step_num}", file=sys.stderr, flush=True)
+
+         # Submit to environment
+         step_response = env.step(issues, task_id=task_id)
+         observation = step_response.get("observation", step_response)
+
+         # reward and done are at the top level of the response, not inside observation
+         reward = float(step_response.get("reward", 0.0) or 0.0)
+         done = bool(step_response.get("done", False))
+         best_score = max(best_score, reward)
+
+         log_step(task_id, step_num, reward, {
+             "issues_reported": len(issues),
+             "feedback": observation.get("feedback", ""),
+         })
+
+         if done:
+             break
+
+         # Add context for next attempt
+         messages.append({"role": "user", "content": user_prompt})
+         messages.append({"role": "assistant", "content": llm_output})
+
+     log_end(task_id, best_score)
+     return best_score
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ def main():
+     print("[INFO] DataQA Inference starting", flush=True)
+     print(f"[INFO] ENV_URL={ENV_URL}", flush=True)
+     print(f"[INFO] API_BASE_URL={API_BASE_URL}", flush=True)
+     print(f"[INFO] MODEL_NAME={MODEL_NAME}", flush=True)
+
+     # Initialize clients
+     env = EnvHTTPClient(ENV_URL)
+     llm_client = OpenAI(
+         base_url=API_BASE_URL,
+         api_key=os.environ.get("LLM_API_KEY", HF_TOKEN or "no-key"),
+     )
+
+     # Check environment health
+     if not env.health():
+         print("[ERROR] Environment is not healthy. Exiting.", file=sys.stderr, flush=True)
+         sys.exit(1)
+
+     print("[INFO] Environment is healthy", flush=True)
+
+     # Run all tasks
+     scores = {}
+     for task_id in TASKS:
+         print(f"\n{'='*60}", flush=True)
+         print(f"[INFO] Starting task: {task_id}", flush=True)
+         print(f"{'='*60}", flush=True)
+
+         try:
+             score = run_task(llm_client, env, task_id)
+             scores[task_id] = score
+             print(f"[INFO] Task {task_id} completed with score: {score:.3f}", flush=True)
+         except Exception as e:
+             print(f"[ERROR] Task {task_id} failed: {e}", file=sys.stderr, flush=True)
+             scores[task_id] = 0.0
+
+     # Summary
+     print(f"\n{'='*60}", flush=True)
+     print("[INFO] FINAL RESULTS", flush=True)
+     print(f"{'='*60}", flush=True)
+     for task_id, score in scores.items():
+         print(f"[INFO] {task_id}: {score:.3f}", flush=True)
+
+     avg_score = sum(scores.values()) / len(scores) if scores else 0.0
+     print(f"[INFO] Average score: {avg_score:.3f}", flush=True)
+
+
+ if __name__ == "__main__":
+     main()
models.py ADDED
@@ -0,0 +1,4 @@
+ """Root-level models for OpenEnv compatibility."""
+ from dataqa_env.models import DataQAAction, DataQAObservation, DataQAState
+
+ __all__ = ["DataQAAction", "DataQAObservation", "DataQAState"]
openenv.yaml ADDED
@@ -0,0 +1,6 @@
+ spec_version: 1
+ name: dataqa_env
+ type: space
+ runtime: fastapi
+ app: dataqa_env.server.app:app
+ port: 8000
openenv_dataqa_env.egg-info/PKG-INFO ADDED
@@ -0,0 +1,13 @@
+ Metadata-Version: 2.4
+ Name: openenv-dataqa-env
+ Version: 0.1.0
+ Summary: Data Quality Assurance Environment for OpenEnv - An LLM agent inspects datasets to find planted quality issues
+ Requires-Python: >=3.10
+ Requires-Dist: openenv-core[core]>=0.2.2
+ Requires-Dist: fastapi>=0.115.0
+ Requires-Dist: pydantic>=2.0.0
+ Requires-Dist: uvicorn[standard]>=0.24.0
+ Requires-Dist: requests>=2.31.0
+ Provides-Extra: dev
+ Requires-Dist: pytest>=8.0.0; extra == "dev"
+ Requires-Dist: pytest-cov>=4.0.0; extra == "dev"
openenv_dataqa_env.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,15 @@
+ README.md
+ pyproject.toml
+ dataqa_env/__init__.py
+ dataqa_env/client.py
+ dataqa_env/models.py
+ dataqa_env/server/__init__.py
+ dataqa_env/server/app.py
+ dataqa_env/server/environment.py
+ dataqa_env/server/tasks.py
+ openenv_dataqa_env.egg-info/PKG-INFO
+ openenv_dataqa_env.egg-info/SOURCES.txt
+ openenv_dataqa_env.egg-info/dependency_links.txt
+ openenv_dataqa_env.egg-info/entry_points.txt
+ openenv_dataqa_env.egg-info/requires.txt
+ openenv_dataqa_env.egg-info/top_level.txt
openenv_dataqa_env.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
+
openenv_dataqa_env.egg-info/entry_points.txt ADDED
@@ -0,0 +1,2 @@
+ [console_scripts]
+ server = dataqa_env.server.app:main
openenv_dataqa_env.egg-info/requires.txt ADDED
@@ -0,0 +1,9 @@
+ openenv-core[core]>=0.2.2
+ fastapi>=0.115.0
+ pydantic>=2.0.0
+ uvicorn[standard]>=0.24.0
+ requests>=2.31.0
+
+ [dev]
+ pytest>=8.0.0
+ pytest-cov>=4.0.0
openenv_dataqa_env.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
+ dataqa_env
pyproject.toml ADDED
@@ -0,0 +1,32 @@
+ [build-system]
+ requires = ["setuptools>=45", "wheel"]
+ build-backend = "setuptools.build_meta"
+
+ [project]
+ name = "openenv-dataqa-env"
+ version = "0.1.0"
+ description = "Data Quality Assurance Environment for OpenEnv - An LLM agent inspects datasets to find planted quality issues"
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core[core]>=0.2.2",
+     "fastapi>=0.115.0",
+     "pydantic>=2.0.0",
+     "uvicorn[standard]>=0.24.0",
+     "requests>=2.31.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = [
+     "pytest>=8.0.0",
+     "pytest-cov>=4.0.0",
+ ]
+
+ [project.scripts]
+ server = "dataqa_env.server.app:main"
+
+ [tool.setuptools]
+ packages = ["dataqa_env", "dataqa_env.server"]
+ package-dir = { "dataqa_env" = "dataqa_env", "dataqa_env.server" = "dataqa_env/server" }
+
+ [tool.setuptools.package-data]
+ dataqa_env = ["**/*.yaml", "**/*.yml"]
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,14 @@
+ """
+ Root-level server entry point for OpenEnv compatibility.
+ """
+
+ from dataqa_env.server.app import app  # noqa: F401
+
+
+ def main():
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=8000)
+
+
+ if __name__ == "__main__":
+     main()
uv.lock ADDED
The diff for this file is too large to render. See raw diff