Graheet committed on
Commit
00cf35f
·
1 Parent(s): 40128b8

Refactor semantic cleaning evaluator and improve API docs UX.


Align environment, scoring, inference, and task metadata with strict step-based semantic actions, add robust uncertainty/hallucination handling, and polish Swagger docs with readable themed cards.

Made-with: Cursor

Files changed (6)
  1. README.md +213 -245
  2. env.py +367 -632
  3. grader.py +128 -563
  4. inference.py +156 -156
  5. server/app.py +145 -4
  6. task.py +28 -5
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
 
3
  title: Dataops Env
4
- emoji: 📊
5
  colorFrom: indigo
6
  colorTo: gray
7
  sdk: docker
@@ -10,335 +10,303 @@ pinned: false
10
 
11
  ---
12
 
13
- # `dataops-env`
14
 
15
- `dataops-env` is an OpenEnv benchmark for training and evaluating agents on
16
- multi-step data operations work. Instead of a single obvious cleanup action, an
17
- agent must inspect messy business tables, choose corrective actions in the right
18
- order, preserve valid-but-unusual records, and know when the table is truly
19
- ready for validation.
20
 
21
- It exposes the standard `reset()`, `step(action)`, and `state()` interface,
22
- ships with a production-ready FastAPI server and Docker image, and includes a
23
- reproducible OpenAI-compatible baseline runner.
24
 
25
- ## Benchmark Purpose
26
 
27
- Many toy data-cleaning tasks reward shallow pattern matching. Real operational
28
- data work is harder:
29
 
30
- - duplicates may be safe to remove, but conflicting rows require judgment
31
- - some malformed values should be normalized, while unusual valid values must be preserved
32
- - deletion is often the riskiest action, not the default fix
33
- - agents need partial credit for progress, but strong penalties for repeated mistakes
34
 
35
- `dataops-env` is designed to capture those decisions in a compact benchmark that
36
- is still easy to run, validate, and deploy in the OpenEnv ecosystem.
37
 
38
- ## Why It Feels Real
 
 
 
 
 
39
 
40
- The environment models common enterprise data quality problems:
41
 
42
- - exact duplicates in customer or vendor master data
43
- - missing required fields
44
- - inconsistent casing in names and locations
45
- - invalid email and phone formats
46
- - conflicting records for the same real-world entity
47
- - uniqueness constraints such as shared-email violations
48
- - trap rows that look suspicious but are actually valid
49
 
50
- Agents are rewarded for minimal corrective behavior and punished for destructive
51
- or repetitive actions. That makes the environment useful for both learning and
52
- evaluation.
53
 
54
- ## Task Families
55
 
56
- The benchmark keeps the hackathon-friendly `easy`, `medium`, and `hard` task
57
- structure, while each family now contains deterministic variants so policies
58
- cannot overfit a single table.
 
 
59
 
60
- 1. `easy`
61
- Remove duplicates and fill missing required fields.
62
- 2. `medium`
63
- Remove duplicates, normalize casing, and repair invalid emails.
64
- 3. `hard`
65
- Resolve conflicts, enforce unique-email constraints, fix invalid formats,
66
- and preserve valid trap rows.
67
 
68
- Each task definition includes:
69
 
70
- - `goal`
71
- - `difficulty`
72
- - `variant_id`
73
- - `required_columns`
74
- - `hidden_issues`
75
- - `constraints`
76
- - `expected_outcome`
77
- - `max_steps`
78
 
79
- ## Learning Signals
80
 
81
- The environment provides both dense rewards and a deterministic final score:
82
 
83
- - partial rewards for duplicate removal, normalization, and filling missing values
84
- - step costs and no-progress penalties to discourage random actions
85
- - escalating penalties for repeated mistakes
86
- - destructive-action penalties for harmful deletions
87
- - proactive hints after recurring failures
88
- - final task scoring on a strict `0.0` to `1.0` scale
89
 
90
- The final task score and the visible validation failures are produced from the
91
- same explicit rule set, reducing mismatch between what the agent sees and how it
92
- is ultimately judged.
 
 
 
 
 
 
93
 
94
- ## Action Space
95
 
96
- Agents interact with the environment through a typed `Action` object.
97
 
98
- Supported action types:
99
 
100
- - `remove_duplicate`
101
- Remove one row from an exact duplicate group. Can be called with an explicit
102
- `row_id`, or the environment can choose the default duplicate target.
103
- - `fill_missing`
104
- Fill a missing field on a target row. Requires `column` and `value`, and may
105
- also include `row_id`.
106
- - `normalize_column`
107
- Apply deterministic normalization to a supported column such as `name`,
108
- `city`, `email`, or `phone`.
109
- - `delete_row`
110
- Delete a row when doing so resolves a structural issue like a conflict or a
111
- uniqueness violation. Requires `row_id`.
112
- - `validate`
113
- Signal that the agent believes the table is ready for completion.
114
- - `noop`
115
- Explicitly take no action. This is allowed but penalized when unresolved
116
- issues remain.
117
-
118
- Typed action schema:
119
 
120
- - `action_id: Optional[str]`
121
- - `action_type: Literal["remove_duplicate", "fill_missing", "normalize_column", "delete_row", "validate", "noop"]`
122
- - `column: Optional[str]`
123
- - `row_id: Optional[int]`
124
- - `value: Optional[str]`
125
-
126
- Validation rules:
127
-
128
- - `delete_row` requires `row_id`
129
- - `normalize_column` requires `column`
130
- - `fill_missing` requires `column` and `value`
131
-
132
- Example actions:
133
 
134
- ```json
135
- {"action_id":"step-001","action_type":"remove_duplicate","row_id":33}
136
- {"action_id":"step-002","action_type":"fill_missing","row_id":35,"column":"email","value":"peak.systems@example.com"}
137
- {"action_id":"step-003","action_type":"normalize_column","column":"email"}
138
- {"action_id":"step-004","action_type":"validate"}
139
- ```
 
140
 
141
- ## Observation Space
142
 
143
- The environment returns a typed `Observation` object after `reset()` and each
144
- call to `step()`.
145
 
146
- Observation fields:
147
 
148
- - `goal: str`
149
- Natural-language description of what the agent should accomplish.
150
- - `table: List[Dict[str, Any]]`
151
- Current JSON-serializable table snapshot.
152
- - `issues: List[str]`
153
- Human-readable unresolved issues and validation failures.
154
- - `history: List[str]`
155
- Ordered record of previous actions/events in the current episode.
156
- - `mistakes: Dict[str, int]`
157
- Counts of repeated mistake categories tracked during the episode.
158
- - `hints: List[str]`
159
- Proactive or reactive guidance derived from issue state and prior failures.
160
- - `progress: float`
161
- Normalized progress estimate in `[0.0, 1.0]`.
162
- - `steps_remaining: int`
163
- Number of remaining actions before the episode terminates.
164
 
165
- Example observation shape:
 
 
 
166
 
167
- ```json
168
- {
169
- "goal": "Normalize the dataset by fixing casing, removing duplicates, and correcting invalid email formats.",
170
- "table": [
171
- {"row_id": 10, "customer_id": "C100", "name": "jane miller", "city": "new york", "email": "jane.miller@example.com"}
172
- ],
173
- "issues": [
174
- "Rows 11 and 13 are duplicates and only one should remain."
175
- ],
176
- "history": [],
177
- "mistakes": {},
178
- "hints": [],
179
- "progress": 0.0,
180
- "steps_remaining": 9
181
- }
182
- ```
183
 
184
- ## Expected Agent Behavior
 
 
 
185
 
186
- A strong agent should behave roughly like this:
 
 
187
 
188
- 1. inspect the visible table and unresolved issues
189
- 2. remove safe duplicates first
190
- 3. repair missing or malformed values without over-editing valid rows
191
- 4. resolve structural conflicts carefully, especially in hard tasks
192
- 5. validate only when the remaining issue list is empty
193
 
194
- Example successful baseline trace:
 
 
195
 
196
  ```text
197
- [START] task=medium env=dataops-env model=your-model
198
- [STEP] step=1 action=remove_duplicate(row_id=13) reward=0.37 done=false error=null
199
- [STEP] step=2 action=normalize_column(column='email') reward=0.27 done=false error=null
200
- [STEP] step=3 action=normalize_column(column='name') reward=0.24 done=false error=null
201
- [STEP] step=4 action=normalize_column(column='city') reward=0.44 done=true error=null
202
- [END] success=true steps=4 rewards=0.37,0.27,0.24,0.44
203
  ```
204
 
205
- ## Project Layout
206
 
207
- - `env.py`: core `DataOpsEnv` implementation
208
- - `task.py`: task families and deterministic variants
209
- - `models.py`: typed `Action`, `Observation`, and `Reward` contracts
210
- - `grader.py`: dense rewards, explicit validation checks, and final task scoring
211
- - `server/app.py`: FastAPI runtime API
212
- - `inference.py`: hybrid heuristic/model baseline runner
213
- - `openenv.yaml`: OpenEnv metadata and task registration
214
- - `pyproject.toml`: package metadata and server script entry point
215
- - `Dockerfile`: production container image
216
 
217
- ## Local Setup
 
 
 
 
218
 
219
- ```bash
220
- pip install -r requirements.txt
221
- openenv validate
222
- ```
223
 
224
- Run the FastAPI server:
 
225
 
226
- ```bash
227
- python -m server.app
228
- ```
229
 
230
- By default, the local server runs on port `8000`.
231
 
232
- Or use the packaged entry point:
 
 
233
 
234
- ```bash
235
- server
236
  ```
237
 
238
- ## API
 
 
239
 
240
- Health check:
 
 
241
 
242
  ```bash
243
- curl http://localhost:8000/health
244
  ```
245
 
246
- Create a session with an optional seed and task selection:
 
 
247
 
248
  ```bash
249
- curl -X POST http://localhost:8000/reset \
250
- -H "Content-Type: application/json" \
251
- -d '{"seed": 0, "task_name": "easy"}'
252
  ```
253
 
254
- Step the environment:
 
 
255
 
256
  ```bash
257
- curl -X POST "http://localhost:8000/step" \
258
- -H "Content-Type: application/json" \
259
- -d '{"action_id":"step-001","action_type":"validate"}'
260
  ```
261
 
262
- Read internal state:
263
 
264
- ```bash
265
- curl "http://localhost:8000/state"
 
 
 
 
266
  ```
267
 
268
- ## Baseline Inference
 
 
 
 
 
 
 
 
 
 
 
269
 
270
- The baseline runner now combines deterministic local planning with optional
271
- model arbitration. The local planner proposes ranked candidate actions from the
272
- visible table state, and the model is constrained to choose only from those
273
- candidates. This avoids many common failure modes such as invalid actions,
274
- repeated no-op loops, and reckless deletion choices.
275
 
276
- Run it with an OpenAI-compatible endpoint:
277
 
278
  ```bash
279
- set HF_TOKEN=your_token
280
- set MODEL_NAME=your_model
281
- set API_BASE_URL=https://router.huggingface.co/v1
282
- python inference.py
283
  ```
284
 
285
- Key properties:
286
 
287
- - strict `[START]`, `[STEP]`, and `[END]` output formatting
288
- - fixed task ordering for reproducibility
289
- - retry logic for invalid or blocked model suggestions
290
- - strong heuristic fallback when the model is unavailable
291
- - action filtering based on prior no-progress or errorful behavior
292
 
293
- ## Docker
 
 
 
 
294
 
295
- Build:
296
 
297
- ```bash
298
- docker build -t dataops-env .
299
- ```
300
 
301
- Run locally:
 
 
 
 
302
 
303
- ```bash
304
- docker run -p 8000:8000 dataops-env
305
- ```
306
 
307
- ## Hugging Face Spaces Notes
308
 
309
- For Hugging Face `Docker` Spaces, the container should normally listen on port
310
- `7860`, or the Space must be explicitly configured to expect a different
311
- internal port.
 
 
312
 
313
- If you keep the current container on port `8000`, make sure your Space is
314
- configured with:
315
 
316
- ```yaml
317
- app_port: 8000
318
- ```
319
 
320
- If you want the simplest Hugging Face Spaces setup, change the container to use
321
- port `7860` instead:
322
 
323
- ```dockerfile
324
- EXPOSE 7860
325
- CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860"]
326
- ```
327
 
328
- Then local Docker testing would become:
 
 
 
 
329
 
330
- ```bash
331
- docker run -p 7860:7860 dataops-env
332
- curl http://localhost:7860/health
333
- ```
334
 
335
- ## Submission Notes
336
 
337
- - `openenv validate` passes
338
- - the server and Docker image run successfully
339
- - the packaged benchmark supports multi-mode deployment
340
- - the default baseline now completes the public task families deterministically
341
 
342
- Leaderboard performance will still depend on the quality of the external model,
343
- but the repository is now structured and documented like a serious benchmark
344
- submission rather than a starter scaffold.
 
1
  ---
2
 
3
  title: Dataops Env
4
+ emoji: 🧼
5
  colorFrom: indigo
6
  colorTo: gray
7
  sdk: docker
 
10
 
11
  ---
12
 
13
+ # ✨ DataOps Gym
14
 
15
+ ### The First Hallucination-Aware Data Cleaning Environment
 
 
 
 
16
 
17
+ > Most systems ask: *“Did you fix the data?”*
18
+ > We ask: *“Did you think before fixing?”*
 
19
 
20
+ ---
21
 
22
+ # 🚨 THE PROBLEM
 
23
 
24
+ **By oft-cited estimates, 60–80% of a data scientist’s time is spent cleaning data.**
 
 
 
25
 
26
+ But current systems:
 
27
 
28
+ * blindly fix values
29
+ * hallucinate corrections
30
+ * ignore contradictions
31
+ * break real-world logic
32
+
33
+ ---
34
 
35
+ > 💡 **Wrong data is worse than missing data.**
36
 
37
+ ---
 
 
 
 
 
 
38
 
39
+ # 🧠 WHAT THIS PROJECT DOES
 
 
40
 
41
+ DataOps Gym is a **step-based OpenEnv environment** where an AI agent:
42
 
43
+ 1. Detects semantic inconsistencies
44
+ 2. Fixes data **only when confident**
45
+ 3. Outputs **"cannot determine"** when uncertain
46
+ 4. Maintains **cross-record consistency**
47
+ 5. Learns through **reward-based feedback**
48
 
49
+ ---
 
 
 
 
 
 
50
 
51
+ Each step teaches the agent:
52
 
53
+ * when to fix ✅
54
+ * when to abstain ⚠️
55
+ * when to say “I don’t know” 🧠
 
 
 
 
 
56
 
57
+ ---
58
 
59
+ # 🧩 ACTION SPACE
60
 
61
+ All actions must follow strict JSON format:
 
 
 
 
 
62
 
63
+ ```json
64
+ {
65
+ "action_type": "detect_issue | fix_value | cannot_determine | skip",
66
+ "record_id": "string",
67
+ "field": "string",
68
+ "value": "string",
69
+ "confidence": 0.0
70
+ }
71
+ ```
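As a sketch of how a client might construct and sanity-check one of these actions before sending it (the field names follow the schema above, but this helper is illustrative and not part of the environment):

```python
import json

# Hypothetical client-side helper; validates against the schema shown above.
ALLOWED_ACTIONS = {"detect_issue", "fix_value", "cannot_determine", "skip"}

def make_action(action_type, record_id, field, value="", confidence=0.0):
    if action_type not in ALLOWED_ACTIONS:
        raise ValueError(f"unsupported action_type: {action_type}")
    if not 0.0 <= confidence <= 1.0:
        raise ValueError("confidence must be in [0.0, 1.0]")
    return json.dumps({
        "action_type": action_type,
        "record_id": record_id,
        "field": field,
        "value": value,
        "confidence": confidence,
    })

payload = make_action("cannot_determine", "T3", "email", confidence=0.4)
```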
72
 
73
+ ---
74
 
75
+ ## 🔥 Key Innovation
76
 
77
+ 👉 `cannot_determine` is a **first-class action**
78
 
79
+ ---
80
 
81
+ # 🧠 WHY THIS IS DIFFERENT
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ | Traditional Systems | DataOps Gym |
84
+ | ------------------- | ---------------------- |
85
+ | Fix everything | Fix only when safe |
86
+ | Always answer | Can abstain |
87
+ | Ignore confidence | Confidence-aware |
88
+ | Single-row logic | Cross-record reasoning |
89
+ | Output-based | Behavior-based |
90
 
91
+ ---
92
 
93
+ # 💰 REWARD SYSTEM
 
94
 
95
+ ---
96
 
97
+ ## Rewards
 
98
 
99
+ * correct reasoning
100
+ * safe corrections
101
+ * correct uncertainty
102
+ * consistency across records
103
 
104
+ ---
105
+
106
+ ## Penalties
 
107
 
108
+ * hallucinated fixes 🚫
109
+ * overconfidence 🚫
110
+ * over-correction 🚫
111
+ * inconsistency 🚫
112
 
113
+ ---
114
+
115
+ ### 🔥 Core Principle
116
 
117
+ > **“Better to not fix than to fix incorrectly.”**
 
 
 
 
118
 
119
+ ---
120
+
121
+ # 📊 FINAL SCORING (0–1)
122
 
123
  ```text
124
+ task_score =
125
+ 0.5 * normalized_record_score
126
+ + 0.2 * (1 - hallucination_rate)
127
+ + 0.15 * uncertainty_accuracy
128
+ + 0.15 * consistency_score
 
129
  ```
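The weights above sum to 1.0, so the blend stays on the 0–1 scale as long as each component does. A minimal transcription of the same formula (assuming all four inputs are already normalized to [0, 1]):

```python
def task_score(normalized_record_score, hallucination_rate,
               uncertainty_accuracy, consistency_score):
    # Direct transcription of the weighted blend above; all inputs in [0.0, 1.0].
    return (0.5 * normalized_record_score
            + 0.2 * (1 - hallucination_rate)
            + 0.15 * uncertainty_accuracy
            + 0.15 * consistency_score)
```

A policy that hallucinates on every fix forfeits the full 0.2 hallucination term, so even an otherwise perfect episode caps out at 0.8.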
130
 
131
+ ---
132
 
133
+ # 📉 METRICS
 
 
 
 
 
 
 
 
134
 
135
+ | Metric | Description |
136
+ | ----------------------- | ---------------------- |
137
+ | 🧠 Hallucination Rate | Wrong invented fixes |
138
+ | ⚖️ Uncertainty Accuracy | Correct abstentions |
139
+ | 🔗 Consistency Score | Cross-record reasoning |
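The first two metrics fall out of a per-step log directly; a hedged sketch (the per-step record layout here is illustrative, not the environment's actual trace schema):

```python
def episode_metrics(steps):
    # Each step is a dict with "action", "was_correct", and "was_answerable"
    # (an assumed logging schema, for illustration only).
    fixes = [s for s in steps if s["action"] == "fix_value"]
    abstains = [s for s in steps if s["action"] == "cannot_determine"]
    hallucination_rate = (
        sum(1 for s in fixes if not s["was_correct"]) / len(fixes)
        if fixes else 0.0
    )
    # An abstention counts as correct when the field truly was not determinable.
    uncertainty_accuracy = (
        sum(1 for s in abstains if not s["was_answerable"]) / len(abstains)
        if abstains else 0.0
    )
    return hallucination_rate, uncertainty_accuracy
```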
140
 
141
+ ---
 
 
 
142
 
143
+ # 🧪 TASKS
144
+ > ⚡ Each task is carefully designed to evaluate **reasoning, restraint, and reliability** — not just accuracy.
145
 
146
+ ---
 
 
147
 
148
+ ## 🟢 EASY — *Foundational Data Hygiene*
149
 
150
+ <p align="left">
151
+ <b>“Can the agent fix obvious issues without breaking anything?”</b>
152
+ </p>
153
 
154
+ * Basic inconsistencies
155
+ * Missing values
156
+ * Duplicate records
157
+
158
+ ---
159
+
160
+ ## 🟡 MEDIUM — *Contextual Reasoning & Ambiguity*
161
+
162
+ <p align="left">
163
+ <b>“Can the agent reason across records and handle uncertainty?”</b>
164
+ </p>
165
+
166
+ * Cross-table inconsistencies
167
+ * Identity ambiguity
168
+ * Data normalization
169
+
170
+ ---
171
+
172
+ ## 🔴 HARD — *Real-World Data Chaos*
173
+
174
+ <p align="left">
175
+ <b>“Can the agent survive contradictions, missing context, and unsolvable data?”</b>
176
+ </p>
177
+
178
+ * Multi-table conflicts
179
+ * Temporal inconsistencies
180
+ * Non-fixable contradictions
181
+
182
+ ---
183
+
184
+ > 🔥 **Difficulty is not about complexity — it's about uncertainty.**
185
+
186
+ | Level | Focus |
187
+ |--------|------|
188
+ | 🟢 Easy | Precision on clear signals |
189
+ | 🟡 Medium | Reasoning under ambiguity |
190
+ | 🔴 Hard | Decision-making under uncertainty |
191
+
192
+ ---
193
+
194
+ # 🧪 EXAMPLE FAILURE LOG
195
+
196
+ ```json
197
+ {
198
+ "record_id": "T3",
199
+ "error_type": "hallucination",
200
+ "details": "assigned value without evidence",
201
+ "confidence": 0.9
202
+ }
203
  ```
204
 
205
+ ---
206
+
207
+ # 🚀 QUICK START
208
 
209
+ ---
210
+
211
+ ## Install
212
 
213
  ```bash
214
+ pip install -r requirements.txt
215
  ```
216
 
217
+ ---
218
+
219
+ ## Run Server
220
 
221
  ```bash
222
+ python -m server.app
 
 
223
  ```
224
 
225
+ ---
226
+
227
+ ## Run Baseline
228
 
229
  ```bash
230
+ python inference.py
 
 
231
  ```
232
 
233
+ ---
234
 
235
+ ## Example Output
236
+
237
+ ```text
238
+ easy → 0.73
239
+ medium → 0.55
240
+ hard → 0.38
241
  ```
242
 
243
+ > ⚠️ Replace with your actual results
244
+
245
+ ---
246
+
247
+ # 🌐 API ENDPOINTS
248
+
249
+ | Endpoint | Description |
250
+ | --------- | ----------------- |
251
+ | `/reset` | Start new episode |
252
+ | `/step` | Take action |
253
+ | `/state` | Get current state |
254
+ | `/health` | Health check |
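A minimal client loop over these endpoints, using only the standard library (the port follows the Docker section below; the request and response field names are assumptions based on this README, not a verified contract):

```python
import json
import urllib.request

BASE = "http://localhost:7860"  # assumed port, matching the Docker run command

def post_json(path, payload):
    # Small JSON-over-HTTP helper for /reset and /step.
    req = urllib.request.Request(
        BASE + path,
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        return json.load(resp)

def run_one_abstention(task_name="easy"):
    # Start an episode, then abstain once; field names are assumptions.
    observation = post_json("/reset", {"task_name": task_name})
    step = post_json("/step", {
        "action_type": "cannot_determine",
        "record_id": "T1",
        "field": "email",
        "value": "",
        "confidence": 0.3,
    })
    return observation, step
```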
255
 
256
+ ---
 
 
 
 
257
 
258
+ # 🐳 DOCKER
259
 
260
  ```bash
261
+ docker build -t dataops-gym .
262
+ docker run -p 7860:7860 dataops-gym
 
 
263
  ```
264
 
265
+ ---
266
 
267
+ # 🧠 DESIGN PRINCIPLES
 
 
 
 
268
 
269
+ 1. Prefer uncertainty over hallucination
270
+ 2. Penalize confident mistakes
271
+ 3. Avoid over-correction
272
+ 4. Enforce cross-record consistency
273
+ 5. Reward safe reasoning
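Principles 1 and 2 are naturally expressed as confidence-weighted reward shaping; a sketch with illustrative constants (not the environment's actual reward values):

```python
def shaped_reward(was_correct, confidence, base_reward=1.0):
    # Correct fixes earn more when stated confidently; wrong fixes cost
    # more when stated confidently, so overconfidence is the worst case.
    if was_correct:
        return base_reward * (0.5 + 0.5 * confidence)
    return -base_reward * confidence
```

Under this shaping, a confidently wrong fix (-1.0) is strictly worse than a hesitant wrong one, which is exactly the "penalize confident mistakes" principle.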
274
 
275
+ ---
276
 
277
+ # 🏆 BENCHMARK (EXPECTED)
 
 
278
 
279
+ | Task | Score |
280
+ | ------ | ----------- |
281
+ | Easy | 0.65 – 0.85 |
282
+ | Medium | 0.45 – 0.65 |
283
+ | Hard | 0.05 – 0.40 |
284
 
285
+ ---
 
 
286
 
287
+ # 📌 USE CASES
288
 
289
+ * AI data pipelines
290
+ * automated ETL validation
291
+ * financial data cleaning
292
+ * healthcare record validation
293
+ * LLM safety benchmarking
294
 
295
+ ---
 
296
 
297
+ # 🏁 FINAL TAKEAWAY
 
 
298
 
299
+ > 🧠 **The future of AI is not about answering everything.**
300
+ > **It’s about knowing when NOT to answer.**
301
 
302
+ ---
 
 
 
303
 
304
+ # 🔥 TAGLINE
305
+
306
+ > **“We built a system that teaches AI when NOT to change data.”**
307
+
308
+ ---
309
env.py CHANGED
@@ -1,36 +1,20 @@
1
- """OpenEnv environment entrypoint for ``dataops-gym``.
2
-
3
- This module is responsible for declaring top-level environment metadata,
4
- configuration wiring, and lifecycle integration points for the OpenEnv runtime.
5
- """
6
 
7
  from __future__ import annotations
8
 
9
  from copy import deepcopy
10
  import random
11
- import re
12
- from typing import Any, Dict, Iterable, List, Mapping, MutableMapping, Optional, Tuple
13
 
14
- from grader import grade_step_details, grade_task_result, task_failure_messages
15
  from models import Action, Observation
16
- from task import (
17
- HiddenIssue,
18
- TaskDefinition,
19
- easy_cleaning_task,
20
- hard_conflict_resolution_task,
21
- medium_normalization_task,
22
- )
23
-
24
-
25
- EMAIL_PATTERN = re.compile(r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$")
26
 
27
 
28
  class DataOpsEnv:
29
- """Deterministic multi-step data-cleaning environment for OpenEnv."""
30
 
31
  def __init__(self, seed: int = 0, task_name: Optional[str] = None) -> None:
32
- """Initialize the environment with deterministic task sampling."""
33
-
34
  self._seed = seed
35
  self._rng = random.Random(seed)
36
  self._task_registry: List[Tuple[str, Any]] = [
@@ -39,127 +23,109 @@ class DataOpsEnv:
39
  ("hard", hard_conflict_resolution_task),
40
  ]
41
  self._fixed_task_name = task_name
42
- self._global_mistake_memory: Dict[str, int] = {}
43
  self._state_data: Dict[str, Any] = {}
44
 
45
  def reset(self) -> Observation:
46
- """Load a random task, initialize episode state, and return an observation."""
47
-
48
  task_name, task_factory = self._select_task_factory()
49
  variant_count = max(1, int(getattr(task_factory, "variant_count", 1)))
50
- variant_index = self._rng.randrange(variant_count)
51
- task_definition = deepcopy(task_factory(variant=variant_index))
52
  initial_table = deepcopy(task_definition["initial_table"])
53
- initial_table_by_row_id = self._table_by_row_id(initial_table)
54
-
55
  self._state_data = {
56
  "seed": self._seed,
57
  "task_name": task_name,
58
- "task_variant": task_definition.get("variant_id", f"{task_name}_variant_{variant_index}"),
59
  "task": task_definition,
60
- "table": initial_table,
61
- "history": [],
62
- "mistakes": {},
63
- "mistake_memory": [],
64
- "hints": [],
 
 
65
  "steps_taken": 0,
66
  "steps_remaining": task_definition["max_steps"],
67
  "done": False,
68
- "last_reward_components": {},
69
- "last_info": {},
70
- "last_task_score": 0.0,
71
- "initial_issue_count": 1,
72
- "initial_table_by_row_id": initial_table_by_row_id,
 
 
 
 
 
 
 
 
 
73
  }
74
- initial_issue_count = len(self._current_issue_messages(initial_table, task_definition))
75
- self._state_data["initial_issue_count"] = max(1, initial_issue_count)
76
  return self._build_observation()
77
 
78
- def step(
79
- self, action: Action | Mapping[str, Any]
80
- ) -> Tuple[Observation, float, bool, Dict[str, Any]]:
81
- """Apply one action, score it, update state, and return a gym-style step tuple."""
82
-
83
  if not self._state_data:
84
  raise RuntimeError("Environment must be reset before calling step().")
85
- if self._state_data.get("done", False):
86
  raise RuntimeError("Episode is finished. Call reset() before stepping again.")
87
 
88
- parsed_action, action_error = self._coerce_action(action)
89
- task_definition: TaskDefinition = self._state_data["task"]
90
- table_before = deepcopy(self._state_data["table"])
91
- issues_before = self._current_issue_messages(table_before, task_definition)
92
-
93
- result: Dict[str, Any] = {
94
- "mistake_keys": [],
95
- "error_type": "general",
96
- }
97
-
98
- if action_error is not None:
99
- parsed_action = Action(action_type="noop")
100
- result["noop"] = True
101
- result["unnecessary_action"] = True
102
- result["error_type"] = "invalid_action"
103
- result["mistake_keys"].append("invalid_action:general")
104
- history_entry = f"invalid_action({action_error})"
105
- else:
106
- history_entry = self._apply_action(parsed_action, result)
107
 
108
- self._state_data["history"].append(history_entry)
109
  self._state_data["steps_taken"] += 1
110
  self._state_data["steps_remaining"] = max(
111
- 0, task_definition["max_steps"] - self._state_data["steps_taken"]
112
  )
113
 
114
- table_after = deepcopy(self._state_data["table"])
115
- issues_after = self._current_issue_messages(table_after, task_definition)
116
- self._populate_result_signals(
117
- parsed_action,
118
- table_before,
119
- table_after,
120
- issues_before,
121
- issues_after,
122
- result,
123
  )
124
-
125
- reward, components = grade_step_details(
126
  self._state_data, parsed_action.model_dump(), result
127
  )
128
- self._record_mistake_memory(parsed_action, result)
129
- self._update_hints(result, issues_after)
130
-
131
- done = not issues_after or self._state_data["steps_remaining"] <= 0
132
- self._state_data["done"] = done
 
 
 
 
 
 
 
 
 
 
 
 
133
  task_score = grade_task_result(
134
- task_definition, self._state_data["table"], self._state_data
135
  )
136
- self._state_data["last_task_score"] = task_score
137
 
138
- observation = self._build_observation()
 
139
  info = {
140
- "task_name": self._state_data["task_name"],
141
- "task_variant": self._state_data["task_variant"],
142
- "difficulty": task_definition["difficulty"],
143
- "reward_components": components,
144
- "mistakes": deepcopy(self._state_data["mistakes"]),
145
- "hints": list(self._state_data["hints"]),
146
- "issues_remaining": len(issues_after),
147
- "done_reason": "resolved" if not issues_after else "max_steps" if done else None,
148
- "task_score": task_score,
149
- "result": deepcopy(result),
 
 
150
  }
151
- self._state_data["last_reward_components"] = deepcopy(components)
152
- self._state_data["last_info"] = deepcopy(info)
153
- return observation, reward, done, info
154
 
155
  def state(self) -> Dict[str, Any]:
156
- """Return a deep copy of the internal environment state."""
157
-
158
  return deepcopy(self._state_data)
159
 
160
  def close(self) -> None:
161
- """Release environment state for callers using explicit lifecycle cleanup."""
162
-
163
  self._state_data = {}
164
 
165
  def _select_task_factory(self) -> Tuple[str, Any]:
@@ -174,578 +140,347 @@ class DataOpsEnv:
174
 
175
  raise ValueError(f"Unknown task_name: {self._fixed_task_name}")
176
 
177
- def _coerce_action(
178
- self, action: Action | Mapping[str, Any]
179
- ) -> Tuple[Optional[Action], Optional[str]]:
180
- """Convert user input into an ``Action`` model without raising outward."""
181
-
182
- if isinstance(action, Action):
183
- return action, None
184
-
185
- try:
186
- return Action(**dict(action)), None
187
- except Exception as exc: # pragma: no cover - defensive runtime boundary
188
- return None, str(exc)
189
-
190
- def _apply_action(self, action: Action, result: MutableMapping[str, Any]) -> str:
191
- """Apply a single action to the current table and capture side effects."""
192
-
193
- if action.action_type == "noop":
194
- result["noop"] = True
195
- result["mistake_keys"].append(f"{action.action_type}:noop")
196
- return self._format_history(action)
197
-
198
- if action.action_type == "remove_duplicate":
199
- self._remove_duplicate(action, result)
200
- return self._format_history(action)
201
-
202
- if action.action_type == "delete_row":
203
- self._delete_row(action, result)
204
- return self._format_history(action)
205
-
206
- if action.action_type == "fill_missing":
207
- self._fill_missing(action, result)
208
- return self._format_history(action)
209
-
210
- if action.action_type == "normalize_column":
211
- self._normalize_column(action, result)
212
- return self._format_history(action)
213
-
214
- if action.action_type == "validate":
215
- return self._format_history(action)
216
-
217
- result["unnecessary_action"] = True
218
- result["error_type"] = "unsupported_action"
219
- result["mistake_keys"].append(f"{action.action_type}:unsupported_action")
220
- return self._format_history(action)
 
221
 
222
- def _remove_duplicate(
223
- self, action: Action, result: MutableMapping[str, Any]
224
  ) -> None:
225
- """Remove a duplicate row when the target belongs to a duplicate issue."""
226
-
227
- duplicate_groups = [
228
- issue
229
- for issue in self._state_data["task"]["hidden_issues"]
230
- if issue["type"] == "duplicate" and self._is_issue_unresolved(issue, self._state_data["table"])
231
- ]
232
- if not duplicate_groups:
233
- result["unnecessary_action"] = True
234
- result["error_type"] = "no_duplicate_available"
235
  return
236
 
237
- candidate_rows = set(duplicate_groups[0].get("rows", []))
238
- target_row_id = action.row_id or max(candidate_rows)
239
-
240
- if target_row_id not in candidate_rows:
241
- result["unnecessary_action"] = True
242
- result["error_type"] = "invalid_duplicate_target"
243
- return
244
-
245
- removed = self._remove_row_by_id(target_row_id)
246
- if not removed:
247
- result["unnecessary_action"] = True
248
- result["error_type"] = "missing_row"
249
-
250
- def _delete_row(self, action: Action, result: MutableMapping[str, Any]) -> None:
251
- """Delete a row and mark destructive behavior when the target is unsafe."""
252
 
253
- target_row = self._get_row_by_id(action.row_id)
254
- if target_row is None:
255
- result["unnecessary_action"] = True
256
- result["error_type"] = "missing_row"
257
  return
258
-
259
- if self._row_is_protected(action.row_id):
260
- result["wrong_deletion"] = True
261
- result["destructive_action"] = True
262
- result["error_type"] = "protected_row"
263
- result["mistake_keys"].append(f"{action.action_type}:protected_row")
264
- elif not self._row_belongs_to_removable_issue(action.row_id):
265
- result["wrong_deletion"] = True
266
- result["destructive_action"] = True
267
- result["error_type"] = "wrong_deletion"
268
- result["mistake_keys"].append(f"{action.action_type}:wrong_deletion")
269
-
270
- self._remove_row_by_id(action.row_id)
271
-
272
- def _fill_missing(self, action: Action, result: MutableMapping[str, Any]) -> None:
273
- """Fill a missing field on the target row or the first matching missing cell."""
274
-
275
- target_row = self._resolve_missing_target_row(action.row_id, action.column)
276
- if target_row is None or action.column is None:
277
- result["unnecessary_action"] = True
278
- result["error_type"] = "missing_target"
279
  return
280
 
281
- if not self._is_missing_value(target_row.get(action.column)):
282
- result["unnecessary_action"] = True
283
- result["error_type"] = "cell_not_missing"
284
  return
285
-
286
- target_row[action.column] = action.value
287
-
288
- def _normalize_column(self, action: Action, result: MutableMapping[str, Any]) -> None:
289
- """Normalize a supported column using deterministic, minimal edits."""
290
-
291
- if action.column is None:
292
- result["unnecessary_action"] = True
293
- result["error_type"] = "missing_column"
294
  return
 
 
 
 
 
 
 
 
295
 
296
- changed_rows = 0
297
- for row in self._state_data["table"]:
298
- original = row.get(action.column)
299
- normalized = self._normalized_value(action.column, original)
300
- if normalized is None or normalized == original:
301
- continue
302
-
303
- # Keep trap rows stable unless the value is actually invalid.
304
- if self._row_is_protected(row.get("row_id")) and self._value_is_valid(
305
- action.column, original
 
 
 
 
306
  ):
307
- continue
308
-
309
- row[action.column] = normalized
310
- changed_rows += 1
311
-
312
- if changed_rows == 0:
313
- result["unnecessary_action"] = True
314
- result["error_type"] = "no_normalization_needed"
315
 
316
- def _populate_result_signals(
317
  self,
318
- action: Action,
319
- table_before: List[Dict[str, Any]],
320
- table_after: List[Dict[str, Any]],
321
- issues_before: List[str],
322
- issues_after: List[str],
323
- result: MutableMapping[str, Any],
324
- ) -> None:
325
- """Derive reward signals from before/after state transitions."""
326
-
327
- task_definition: TaskDefinition = self._state_data["task"]
328
- hidden_before = self._issue_type_counts(table_before, task_definition)
329
- hidden_after = self._issue_type_counts(table_after, task_definition)
330
-
331
- if hidden_after.get("duplicate", 0) < hidden_before.get("duplicate", 0):
332
- result["correct_duplicate_removal"] = True
333
-
334
- if hidden_after.get("missing_value", 0) < hidden_before.get("missing_value", 0):
335
- result["fixed_missing_value"] = True
336
-
337
- normalization_before = hidden_before.get("inconsistent_casing", 0) + hidden_before.get(
338
- "invalid_format", 0
339
- )
340
- normalization_after = hidden_after.get("inconsistent_casing", 0) + hidden_after.get(
341
- "invalid_format", 0
342
- )
343
- if (
344
- action.action_type == "normalize_column"
345
- and normalization_after < normalization_before
346
- ):
347
- result["correct_normalization"] = True
348
-
349
- if action.action_type == "validate" and not issues_after:
350
- result["validation_success"] = True
351
- result["task_completed"] = True
352
-
353
- if not issues_after:
354
- result["task_completed"] = True
355
-
356
- issue_delta = max(0, len(issues_before) - len(issues_after))
357
- result["progress_delta"] = round(
358
- issue_delta / float(self._state_data["initial_issue_count"]),
359
- 4,
360
- )
361
-
362
- if issue_delta > 0 and any(self._state_data["mistakes"].values()):
363
- result["corrected_previous_mistake"] = True
364
-
365
- if action.action_type == "noop" and issues_after:
366
- result["unnecessary_action"] = True
367
- result["error_type"] = result.get("error_type", "noop")
368
-
369
- def _build_observation(self) -> Observation:
370
- """Construct the typed observation returned to callers."""
371
-
372
- task_definition: TaskDefinition = self._state_data["task"]
373
- issue_messages = self._current_issue_messages(self._state_data["table"], task_definition)
374
- progress = self._compute_progress(issue_messages)
375
- return Observation(
376
- goal=task_definition["goal"],
377
- table=deepcopy(self._state_data["table"]),
378
- issues=issue_messages,
379
- history=list(self._state_data["history"]),
380
- mistakes=deepcopy(self._state_data["mistakes"]),
381
- hints=list(self._state_data["hints"]),
382
- progress=progress,
383
- steps_remaining=int(self._state_data["steps_remaining"]),
384
- )
385
-
386
- def _compute_progress(self, issue_messages: List[str]) -> float:
387
- """Estimate progress from the current unresolved issue count."""
388
-
389
- baseline = float(self._state_data["initial_issue_count"])
390
- remaining = min(len(issue_messages), self._state_data["initial_issue_count"])
391
- resolved_fraction = 1.0 - (remaining / baseline)
392
- return round(max(0.0, min(1.0, resolved_fraction)), 4)
393
-
394
- def _current_issue_messages(
395
- self, table: List[Dict[str, Any]], task_definition: TaskDefinition
396
- ) -> List[str]:
397
- """Return unresolved issue descriptions plus validation-rule failures."""
398
-
399
- messages: List[str] = []
400
- for issue in task_definition["hidden_issues"]:
401
- if self._is_issue_unresolved(issue, table):
402
- description = issue.get("description")
403
- if description:
404
- messages.append(description)
405
-
406
- messages.extend(self._validation_failures(table, task_definition))
407
- return messages
408
-
409
- def _validation_failures(
410
- self, table: List[Dict[str, Any]], task_definition: TaskDefinition
411
- ) -> List[str]:
412
- """Evaluate rule-based outcome constraints beyond the hidden issue list."""
413
-
414
- return task_failure_messages(task_definition, table, self._state_data)
415
-
416
- def _issue_type_counts(
417
- self, table: List[Dict[str, Any]], task_definition: TaskDefinition
418
- ) -> Dict[str, int]:
419
- """Count unresolved hidden issues by type."""
420
-
421
- counts: Dict[str, int] = {}
422
- for issue in task_definition["hidden_issues"]:
423
- if self._is_issue_unresolved(issue, table):
424
- issue_type = issue["type"]
425
- counts[issue_type] = counts.get(issue_type, 0) + 1
426
- return counts
427
-
428
- def _is_issue_unresolved(self, issue: HiddenIssue, table: List[Dict[str, Any]]) -> bool:
429
- """Determine whether a hidden issue is still unresolved."""
430
 
431
- issue_type = issue["type"]
432
- table_by_row_id = self._table_by_row_id(table)
433
 
434
- if issue_type == "valid_trap":
435
  return False
436
 
437
- if issue_type in {"duplicate", "conflict"}:
438
- rows = issue.get("rows", [])
439
- return all(row_id in table_by_row_id for row_id in rows)
 
 
 
 
440
 
441
  if issue_type == "missing_value":
442
- row = table_by_row_id.get(issue.get("row"))
443
- column = issue.get("column")
444
- return row is not None and column is not None and self._is_missing_value(row.get(column))
445
-
446
- if issue_type == "inconsistent_casing":
447
- column = issue.get("column")
448
- return any(
449
- row_id in table_by_row_id
450
- and self._needs_title_case(str(table_by_row_id[row_id].get(column, "")))
451
- for row_id in issue.get("rows", [])
452
- )
453
 
454
  if issue_type == "invalid_format":
455
- row = table_by_row_id.get(issue.get("row"))
456
- column = issue.get("column")
457
- return row is not None and column is not None and not self._value_is_valid(
458
- column, row.get(column)
459
- )
 
 
 
 
 
 
460
 
461
- if issue_type == "constraint_violation" and issue.get("constraint") == "unique_email":
462
- rows = issue.get("rows", [])
463
- emails = [
464
- table_by_row_id[row_id].get("email")
465
- for row_id in rows
466
- if row_id in table_by_row_id
467
- ]
468
- return len(emails) != len(set(emails))
469
-
470
- return False
471
-
472
- def _update_hints(self, result: Mapping[str, Any], issues_after: List[str]) -> None:
473
- """Add deterministic hints when the agent stalls or accumulates mistakes."""
474
 
475
- if not issues_after:
476
- return
477
 
478
- global_wrong_deletion_count = sum(
479
- count
480
- for key, count in self._global_mistake_memory.items()
481
- if key == "wrong_deletion" or key.endswith(":wrong_deletion")
482
  )
483
- if global_wrong_deletion_count >= 3:
484
- hint = (
485
- "You are repeatedly deleting valid rows. Try resolving issues "
486
- "instead of deleting."
487
- )
488
- if hint not in self._state_data["hints"]:
489
- self._state_data["hints"].append(hint)
490
-
491
- total_mistakes = sum(self._state_data["mistakes"].values())
492
- should_hint = bool(result.get("unnecessary_action")) or bool(
493
- result.get("wrong_deletion")
494
- ) or total_mistakes >= 2 or float(result.get("progress_delta", 0.0)) == 0.0
495
-
496
- if not should_hint:
497
- return
498
-
499
- next_hint = self._build_hint(issues_after[0])
500
- if next_hint not in self._state_data["hints"]:
501
- self._state_data["hints"].append(next_hint)
502
-
503
- def _build_hint(self, issue_message: str) -> str:
504
- """Map unresolved issue descriptions to small, actionable hints."""
505
-
506
- lowered = issue_message.lower()
507
- if "duplicate" in lowered:
508
- return "Look for rows that describe the same entity and keep only one representative record."
509
- if "missing" in lowered:
510
- return "A required field is still empty. Fill the missing value instead of deleting the row."
511
- if "email" in lowered and "format" in lowered:
512
- return "Normalize only the invalid email values; valid addresses should be preserved."
513
- if "phone" in lowered:
514
- return "Repair only phone values that are actually malformed."
515
- if "title-case" in lowered or "casing" in lowered:
516
- return "Normalize text columns to a consistent title-case style."
517
- if "unchanged" in lowered:
518
- return "Some unusual-looking rows are valid traps and should be preserved."
519
- return "Focus on the first unresolved issue and prefer minimal corrective actions."
520
-
521
- def _record_mistake_memory(
522
- self, action: Action, result: Mapping[str, Any]
523
- ) -> None:
524
- """Persist mistake events so hinting can look at prior failures."""
525
 
526
- for key, count in self._state_data["mistakes"].items():
527
- if count <= 0:
 
 
 
 
528
  continue
529
- if action.action_id:
530
- memory_entry = f"{action.action_id}:{key}:{count}"
531
- else:
532
- memory_entry = f"{action.action_type}:{key}:{count}"
533
- if memory_entry not in self._state_data["mistake_memory"]:
534
- self._state_data["mistake_memory"].append(memory_entry)
535
-
536
- self._global_mistake_memory[key] = (
537
- self._global_mistake_memory.get(key, 0) + 1
538
- )
539
- category_key = key.split(":")[-1]
540
- self._global_mistake_memory[category_key] = (
541
- self._global_mistake_memory.get(category_key, 0) + 1
542
- )
543
 
544
- if result.get("destructive_action"):
545
- entry = f"{action.action_type}:destructive_action"
546
- if entry not in self._state_data["mistake_memory"]:
547
- self._state_data["mistake_memory"].append(entry)
548
-
549
- def _resolve_missing_target_row(
550
- self, row_id: Optional[int], column: Optional[str]
551
- ) -> Optional[Dict[str, Any]]:
552
- """Choose the requested row or the first matching missing-value row."""
553
-
554
- if row_id is not None:
555
- return self._get_row_by_id(row_id)
556
-
557
- if column is None:
558
  return None
559
-
560
- for row in self._state_data["table"]:
561
- if self._is_missing_value(row.get(column)):
562
- return row
563
- return None
564
-
565
- def _normalized_value(self, column: str, value: Any) -> Any:
566
- """Return a normalized value for supported columns."""
567
-
568
- if not isinstance(value, str):
569
- return value
570
-
571
- if column in {"name", "city"}:
572
- return value.title()
573
-
574
- if column == "email" and not self._is_valid_email(value):
575
- normalized = value.strip().lower()
576
- normalized = normalized.replace("[at]", "@").replace(" at ", "@")
577
- if "@" not in normalized and normalized.endswith(".example.com"):
578
- normalized = normalized.replace(".example.com", "@example.com", 1)
579
- if "@" in normalized and "." not in normalized.split("@", 1)[1]:
580
- normalized = normalized + ".com"
581
- return normalized
582
-
583
- if column == "phone" and not self._is_valid_phone(value):
584
- digits = re.sub(r"\D", "", value)
585
- if len(digits) == 11 and digits.startswith("1"):
586
- digits = digits[1:]
587
- if len(digits) == 10:
588
- return f"{digits[0:3]}-{digits[3:6]}-{digits[6:10]}"
589
- return value
590
-
591
- def _value_is_valid(self, column: str, value: Any) -> bool:
592
- """Validate known column types used by the tasks."""
593
-
594
- if value is None:
595
- return False
596
- if column == "email":
597
- return self._is_valid_email(str(value))
598
- if column == "phone":
599
- return self._is_valid_phone(str(value))
600
- if column in {"name", "city"}:
601
- return not self._needs_title_case(str(value))
602
- return True
603
-
604
- def _is_valid_email(self, value: str) -> bool:
605
- """Return whether the supplied email string looks valid."""
606
-
607
- return bool(EMAIL_PATTERN.match(value.strip()))
608
-
609
- def _is_valid_phone(self, value: str) -> bool:
610
- """Return whether the supplied phone value is valid for this environment."""
611
-
612
- digits = re.sub(r"\D", "", value)
613
- return len(digits) == 10 or (len(digits) == 11 and digits.startswith("1"))
614
-
615
- def _needs_title_case(self, value: str) -> bool:
616
- """Detect whether a string still needs title-case normalization."""
617
-
618
- cleaned = value.strip()
619
- return bool(cleaned) and cleaned != cleaned.title()
620
-
621
- def _has_missing_required_values(
622
- self, table: Iterable[Dict[str, Any]], required_columns: Iterable[str]
623
  ) -> bool:
624
- """Check whether any required field remains missing."""
625
-
626
- for row in table:
627
- for column in required_columns:
628
- if self._is_missing_value(row.get(column)):
 
 
 
 
629
  return True
630
- return False
631
 
632
- def _has_duplicates(self, table: Iterable[Dict[str, Any]], column: str) -> bool:
633
- """Check whether a column contains duplicate non-empty values."""
634
-
635
- values = [row.get(column) for row in table if row.get(column) not in (None, "")]
636
- return len(values) != len(set(values))
 
637
 
638
- def _column_has_invalid_email(
639
- self, table: Iterable[Dict[str, Any]], column: str
640
- ) -> bool:
641
- """Check whether any remaining email value is invalid."""
642
 
643
- return any(
644
- row.get(column) not in (None, "") and not self._is_valid_email(str(row.get(column)))
645
- for row in table
 
 
 
 
 
 
 
 
646
  )
647
 
648
- def _column_has_invalid_phone(
649
- self, table: Iterable[Dict[str, Any]], column: str
650
- ) -> bool:
651
- """Check whether any remaining phone value is invalid."""
652
-
653
- return any(
654
- row.get(column) not in (None, "") and not self._is_valid_phone(str(row.get(column)))
655
- for row in table
656
  )
657
-
658
- def _column_needs_title_case(
659
- self, table: Iterable[Dict[str, Any]], column: str
660
- ) -> bool:
661
- """Check whether any remaining value still violates title-case normalization."""
662
-
663
- return any(
664
- isinstance(row.get(column), str) and self._needs_title_case(str(row.get(column)))
665
- for row in table
666
  )
667
 
668
- def _row_has_changed_from_initial(
669
- self, row_id: int, current_table: List[Dict[str, Any]]
670
- ) -> bool:
671
- """Check whether a protected row has changed relative to the task start."""
672
-
673
- current_row = self._table_by_row_id(current_table).get(row_id)
674
- initial_row = self._state_data["initial_table_by_row_id"].get(row_id)
675
- if current_row is None or initial_row is None:
676
- return True
677
- return current_row != initial_row
678
-
679
- def _row_is_protected(self, row_id: Optional[int]) -> bool:
680
- """Return whether a row is marked as a valid trap in the current task."""
681
-
682
- if row_id is None:
683
- return False
684
- for issue in self._state_data["task"]["hidden_issues"]:
685
- if issue["type"] == "valid_trap" and issue.get("row") == row_id:
686
- return True
687
- return False
688
-
689
- def _row_belongs_to_removable_issue(self, row_id: Optional[int]) -> bool:
690
- """Return whether deleting a row could plausibly resolve a structural issue."""
691
-
692
- if row_id is None:
693
- return False
694
- for issue in self._state_data["task"]["hidden_issues"]:
695
- if issue["type"] in {"duplicate", "conflict", "constraint_violation"} and row_id in issue.get(
696
- "rows", []
697
- ):
698
- return True
699
- return False
700
-
701
- def _remove_row_by_id(self, row_id: Optional[int]) -> bool:
702
- """Remove a row by id and report whether a row was deleted."""
703
-
704
- if row_id is None:
705
- return False
706
- table = self._state_data["table"]
707
- for index, row in enumerate(table):
708
- if row.get("row_id") == row_id:
709
- del table[index]
710
- return True
711
- return False
712
-
713
- def _get_row_by_id(self, row_id: Optional[int]) -> Optional[Dict[str, Any]]:
714
- """Return a mutable row reference by id."""
715
 
716
- if row_id is None:
717
- return None
718
- for row in self._state_data["table"]:
719
- if row.get("row_id") == row_id:
720
  return row
721
  return None
722
 
723
- def _table_by_row_id(self, table: List[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
724
- """Index a table by row id."""
725
-
726
- return {
727
- int(row["row_id"]): deepcopy(row)
728
- for row in table
729
- if row.get("row_id") is not None
730
- }
731
-
732
- def _is_missing_value(self, value: Any) -> bool:
733
- """Return whether a cell should be treated as missing."""
734
-
735
- return value is None or value == ""
736
-
737
- def _format_history(self, action: Action) -> str:
738
- """Return a compact history entry for the applied action."""
739
-
740
- details = []
741
- if action.row_id is not None:
742
- details.append(f"row_id={action.row_id}")
743
- if action.column is not None:
744
- details.append(f"column={action.column}")
745
- if action.value is not None:
746
- details.append(f"value={action.value}")
747
- detail_text = ", ".join(details)
748
- return f"{action.action_type}({detail_text})" if detail_text else action.action_type
749
 
750
 
751
 class DataOpsGymEnv(DataOpsEnv):
+"""Semantic data-cleaning evaluation environment."""
 
 from __future__ import annotations
 
 from copy import deepcopy
 import random
+from typing import Any, Dict, List, Mapping, Optional, Tuple
 
+from grader import grade_step_details, grade_task_result
 from models import Action, Observation
+from task import easy_cleaning_task, hard_conflict_resolution_task, medium_normalization_task
 
 
 class DataOpsEnv:
+    """Step-based semantic evaluator with strict action protocol."""
 
     def __init__(self, seed: int = 0, task_name: Optional[str] = None) -> None:
         self._seed = seed
         self._rng = random.Random(seed)
         self._task_registry: List[Tuple[str, Any]] = [
             ("hard", hard_conflict_resolution_task),
         ]
         self._fixed_task_name = task_name
         self._state_data: Dict[str, Any] = {}
 
     def reset(self) -> Observation:
         task_name, task_factory = self._select_task_factory()
         variant_count = max(1, int(getattr(task_factory, "variant_count", 1)))
+        task_definition = deepcopy(task_factory(variant=self._rng.randrange(variant_count)))
         initial_table = deepcopy(task_definition["initial_table"])
         self._state_data = {
             "seed": self._seed,
             "task_name": task_name,
+            "task_variant": task_definition.get("variant_id", task_name),
             "task": task_definition,
+            "dataset_original": initial_table,
+            "dataset_modified": deepcopy(initial_table),
+            "action_history": [],
+            "per_record_scores": {},
+            "current_iteration_score": 0.0,
+            "previous_iteration_score": 0.0,
+            "failure_logs": [],
             "steps_taken": 0,
             "steps_remaining": task_definition["max_steps"],
             "done": False,
+            "totals": {
+                "total_fixes": 0,
+                "hallucinated_fixes": 0,
+                "total_cannot_determine": 0,
+                "correct_cannot_determine": 0,
+                "total_related_cases": 0,
+                "consistent_decisions": 0,
+            },
+            "related_decisions": {},
+            "detected_unresolved_issues": {},
+            "detected_issues": {},
+            "hallucination_rate": 0.0,
+            "uncertainty_accuracy": 0.0,
+            "consistency_score": 1.0,
         }
         return self._build_observation()
 
+    def step(self, action: Action | Mapping[str, Any]) -> Tuple[Observation, float, bool, Dict[str, Any]]:
         if not self._state_data:
             raise RuntimeError("Environment must be reset before calling step().")
+        if self._state_data["done"]:
             raise RuntimeError("Episode is finished. Call reset() before stepping again.")
 
+        parsed_action = action if isinstance(action, Action) else Action(**dict(action))
+        result = self._evaluate_action(parsed_action)
 
+        self._state_data["action_history"].append(parsed_action.model_dump())
         self._state_data["steps_taken"] += 1
         self._state_data["steps_remaining"] = max(
+            0, self._state_data["task"]["max_steps"] - self._state_data["steps_taken"]
         )
 
+        self._state_data["previous_iteration_score"] = float(
+            self._state_data["current_iteration_score"]
         )
+        reward, reward_components = grade_step_details(
             self._state_data, parsed_action.model_dump(), result
         )
+        rid = parsed_action.record_id
+        self._state_data["per_record_scores"][rid] = float(
+            self._state_data["per_record_scores"].get(rid, 0.0)
+        ) + reward
+        self._state_data["current_iteration_score"] = sum(
+            float(v) for v in self._state_data["per_record_scores"].values()
+        )
+        prev = self._state_data["previous_iteration_score"]
+        curr = self._state_data["current_iteration_score"]
+        if curr > prev:
+            reward += 0.1
+            reward_components["iteration_improvement"] = 0.1
+        elif curr < prev:
+            reward -= 0.1
+            reward_components["iteration_improvement"] = -0.1
+
+        self._update_metrics()
         task_score = grade_task_result(
+            self._state_data["task"], self._state_data["dataset_modified"], self._state_data
         )
 
+        done = self._state_data["steps_remaining"] <= 0
+        self._state_data["done"] = done
         info = {
+            "actions_taken": deepcopy(self._state_data["action_history"]),
+            "updated_dataset": deepcopy(self._state_data["dataset_modified"]),
+            "per_record_scores": deepcopy(self._state_data["per_record_scores"]),
+            "final_task_score": task_score,
+            "metrics": {
+                "hallucination_rate": self._state_data["hallucination_rate"],
+                "uncertainty_accuracy": self._state_data["uncertainty_accuracy"],
+                "consistency_score": self._state_data["consistency_score"],
+            },
+            "failure_logs": deepcopy(self._state_data["failure_logs"]),
+            "reward_components": reward_components,
+            "result": result,
         }
+        return self._build_observation(), reward, done, info
 
     def state(self) -> Dict[str, Any]:
         return deepcopy(self._state_data)
 
     def close(self) -> None:
         self._state_data = {}
 
     def _select_task_factory(self) -> Tuple[str, Any]:
         raise ValueError(f"Unknown task_name: {self._fixed_task_name}")
 
+    def _evaluate_action(self, action: Action) -> Dict[str, Any]:
+        table = self._state_data["dataset_modified"]
+        issue = self._matching_issue(action.record_id, action.field)
+        issue_key = self._issue_key(issue)
+        result: Dict[str, Any] = {"extra_fields_modified": 0}
+        self._apply_related_consistency(action, issue, result)
+        self._apply_follow_up_requirement(action, issue_key, result)
+
+        if action.action_type == "skip":
+            if issue is not None:
+                result["missed_issue"] = True
+                result["passive_penalty"] = True
+                if issue_key is not None:
+                    self._state_data["detected_unresolved_issues"][issue_key] = True
+                self._append_failure(action, "missed_issue", "Issue exists but action was skip.")
+            return result
+
+        if action.action_type == "detect_issue":
+            if issue is not None:
+                result["classification_correct"] = True
+                result["correct_issue_detected"] = True
+                result["passive_penalty"] = True
+                if issue_key is not None:
+                    if issue_key in self._state_data["detected_issues"]:
+                        result["repeated_detection"] = True
+                    self._state_data["detected_issues"][issue_key] = True
+                    self._state_data["detected_unresolved_issues"][issue_key] = True
+            else:
+                result["classification_incorrect"] = True
+                result["false_issue"] = True
+            return result
+
+        if action.action_type == "cannot_determine":
+            self._state_data["totals"]["total_cannot_determine"] += 1
+            if issue is None:
+                result["wrong_cannot_determine"] = True
+                self._append_failure(
+                    action, "wrong_fix", "cannot_determine used without any supporting issue."
+                )
+            elif issue.get("fixable", True) is False:
+                result["correct_cannot_determine"] = True
+                self._state_data["totals"]["correct_cannot_determine"] += 1
+                if issue_key is not None:
+                    self._state_data["detected_unresolved_issues"].pop(issue_key, None)
+                    if issue_key in self._state_data["detected_issues"]:
+                        result["resolved_detected_issue"] = True
+            else:
+                result["wrong_cannot_determine"] = True
+                self._append_failure(
+                    action, "wrong_fix", "cannot_determine used when evidence was sufficient."
+                )
+            return result
+
+        # fix_value
+        self._state_data["totals"]["total_fixes"] += 1
+        if issue is None:
+            related_issue_count = self._count_issues_for_record(action.record_id)
+            if related_issue_count > 0:
+                result["extra_fields_modified"] += 1
+
+        row = self._find_record(action.record_id, table)
+        if row is None or action.field not in row:
+            result["hallucinated_fix"] = True
+            self._state_data["totals"]["hallucinated_fixes"] += 1
+            self._append_failure(action, "hallucination", "Attempted fix with no evidence.")
+            return result
+
+        if issue is None:
+            result["hallucinated_fix"] = True
+            self._state_data["totals"]["hallucinated_fixes"] += 1
+            self._append_failure(action, "hallucination", "Field had no target issue.")
+            return result
+
+        if self._issue_resolved(issue, table):
+            result["hallucinated_fix"] = True
+            self._state_data["totals"]["hallucinated_fixes"] += 1
+            self._append_failure(action, "hallucination", "Field is already correct.")
+            return result
+
+        old_value = row.get(action.field)
+        before_row = deepcopy(row)
+        row[action.field] = action.value
+        if self._introduces_inconsistency(row, action.field, table):
+            result["hallucinated_fix"] = True
+            self._state_data["totals"]["hallucinated_fixes"] += 1
+            row[action.field] = old_value
+            self._append_failure(
+                action, "hallucination", "Fix introduces cross-record or temporal inconsistency."
+            )
+            return result
+
+        if self.validate_fix(issue, before_row, row, table):
+            result["correct_fix"] = True
+            result["classification_correct"] = True
+            if issue_key is not None:
+                if issue_key in self._state_data["detected_issues"]:
+                    result["resolved_detected_issue"] = True
+                self._state_data["detected_unresolved_issues"].pop(issue_key, None)
+        else:
+            row[action.field] = old_value
+            result["wrong_fix"] = True
+            self._append_failure(action, "wrong_fix", "Fix does not resolve the identified issue.")
+        return result
 
+    def _apply_follow_up_requirement(
+        self, action: Action, issue_key: Optional[str], result: Dict[str, Any]
     ) -> None:
+        unresolved = self._state_data.get("detected_unresolved_issues", {})
+        if not unresolved:
             return
 
+        # Follow-up action types are fix/cannot_determine against a detected issue.
+        is_follow_up = (
+            action.action_type in {"fix_value", "cannot_determine"}
+            and issue_key is not None
+            and issue_key in unresolved
+        )
+        if not is_follow_up:
+            result["passive_penalty"] = True
 
+    def _apply_related_consistency(
+        self, action: Action, issue: Optional[Dict[str, Any]], result: Dict[str, Any]
+    ) -> None:
+        if issue is None:
             return
+        issue_type = issue.get("type")
+        if issue_type not in {"duplicate", "conflict"}:
             return
 
+        rows = issue.get("rows", [])
+        if not rows:
             return
+        key = f"{issue_type}:{','.join(str(v) for v in sorted(rows))}"
+        self._state_data["totals"]["total_related_cases"] += 1
+        seen = self._state_data["related_decisions"]
+        decision = action.action_type
+        if key not in seen:
+            seen[key] = decision
+            result["consistent_handling"] = True
+            self._state_data["totals"]["consistent_decisions"] += 1
             return
+        if seen[key] == decision:
+            result["consistent_handling"] = True
+            self._state_data["totals"]["consistent_decisions"] += 1
+        else:
+            result["inconsistent_handling"] = True
+            self._append_failure(
+                action, "inconsistency", "Related records were handled inconsistently."
+            )
 
+    def _matching_issue(self, record_id: str, field: str) -> Optional[Dict[str, Any]]:
+        rid = self._parse_record_id(record_id)
+        for issue in self._state_data["task"]["hidden_issues"]:
+            issue_type = issue.get("type")
+            if issue_type == "missing_value" and issue.get("row") == rid and issue.get("column") == field:
+                return issue
+            if issue_type == "invalid_format" and issue.get("row") == rid and issue.get("column") == field:
+                return issue
+            if issue_type == "inconsistent_casing" and field == issue.get("column") and rid in issue.get("rows", []):
+                return issue
+            if (
+                issue_type in {"duplicate", "conflict", "constraint_violation"}
+                and (field in {"row", "record"} or field == issue.get("field"))
+                and rid in issue.get("rows", [])
             ):
+                ambiguous = issue_type in {"conflict", "constraint_violation"}
+                c = dict(issue)
+                c["ambiguous"] = ambiguous
+                return c
+        return None
 
+    def _issue_resolved(self, issue: Mapping[str, Any], table: List[Dict[str, Any]]) -> bool:
+        if issue.get("type") in {"duplicate", "conflict", "constraint_violation"}:
+            return False
+        rid = int(issue.get("row", -1))
+        field = issue.get("column")
+        row = self._find_record(str(rid), table)
+        if row is None:
+            return True
+        if issue.get("type") == "missing_value":
+            return row.get(field) not in (None, "", "unknown", "9999")
+        if issue.get("type") == "invalid_format":
+            value = str(row.get(field, ""))
+            if field == "email":
+                return "@" in value and "." in value.split("@")[-1]
+            if field == "phone":
+                digits = "".join(ch for ch in value if ch.isdigit())
+                return len(digits) in {10, 11}
+            if field in {"start_date", "end_date"}:
+                start = row.get("start_date")
+                end = row.get("end_date")
+                return not (start and end and str(end) < str(start))
+        return row.get(field) not in (None, "", "unknown", "9999")
+
+    def validate_fix(
         self,
+        issue: Mapping[str, Any],
+        before_row: Mapping[str, Any],
+        after_row: Mapping[str, Any],
+        table: List[Dict[str, Any]],
+    ) -> bool:
+        """Ground-truth validator for semantic fixes."""
 
+        issue_type = str(issue.get("type", ""))
+        field = str(issue.get("column") or issue.get("field") or "")
 
+        if field and before_row.get(field) == after_row.get(field):
             return False
 
+        if field == "age":
+            try:
+                age = int(after_row.get("age"))
+            except Exception:
+                return False
+            if age < 0 or age > 120:
+                return False
 
         if issue_type == "missing_value":
+            return after_row.get(field) not in (None, "", "unknown", "9999")
 
         if issue_type == "invalid_format":
+            value = str(after_row.get(field, ""))
+            if field == "email":
+                return "@" in value and "." in value.split("@")[-1]
+            if field == "phone":
+                digits = "".join(ch for ch in value if ch.isdigit())
+                return len(digits) in {10, 11}
+            if field in {"start_date", "end_date"}:
+                start = after_row.get("start_date")
+                end = after_row.get("end_date")
+                return not (start and end and str(end) < str(start))
+            return value not in ("", "unknown", "9999")
 
+        if issue_type == "inconsistent_casing":
+            value = after_row.get(field)
+            return isinstance(value, str) and value == value.title()
 
+        if issue_type in {"duplicate", "conflict", "constraint_violation"}:
+            return False
 
+        return not self._introduces_inconsistency(dict(after_row), field, table) and self._issue_resolved(
+            issue, table
         )
 
+ def _count_issues_for_record(self, record_id: str) -> int:
388
+ rid = self._parse_record_id(record_id)
389
+ count = 0
390
+ for issue in self._state_data["task"]["hidden_issues"]:
391
+ if issue.get("row") == rid:
392
+ count += 1
393
  continue
394
+ if rid in issue.get("rows", []):
395
+ count += 1
396
+ return count
 
 
 
 
 
 
 
 
 
 
 
397
 
398
+ def _issue_key(self, issue: Optional[Dict[str, Any]]) -> Optional[str]:
399
+ if issue is None:
 
 
 
 
 
 
 
 
 
 
 
 
400
  return None
401
+ issue_type = issue.get("type", "unknown")
402
+ if "row" in issue and "column" in issue:
403
+ return f"{issue_type}:row={issue.get('row')}:col={issue.get('column')}"
404
+ if "rows" in issue:
405
+ rows = ",".join(str(v) for v in sorted(issue.get("rows", [])))
406
+ field = issue.get("field", "record")
407
+ return f"{issue_type}:rows={rows}:field={field}"
408
+ return f"{issue_type}:generic"
409 +
410 +    def _introduces_inconsistency(
411 +        self, row: Dict[str, Any], field: str, table: List[Dict[str, Any]]
412      ) -> bool:
413 +        # Unique email consistency check across records.
414 +        if field == "email":
415 +            email = row.get("email")
416 +            if email not in (None, ""):
417 +                duplicates = [
418 +                    r for r in table
419 +                    if r is not row and str(r.get("email", "")).strip() == str(email).strip()
420 +                ]
421 +                if duplicates:
422                      return True
423 
424 +        # Temporal consistency check where both fields are present.
425 +        if field in {"start_date", "end_date"}:
426 +            start = row.get("start_date")
427 +            end = row.get("end_date")
428 +            if start and end and str(end) < str(start):
429 +                return True
430 
431 +        return False
432 
433 +    def _build_observation(self) -> Observation:
434 +        return Observation(
435 +            dataset={
436 +                "original": deepcopy(self._state_data["dataset_original"]),
437 +                "modified": deepcopy(self._state_data["dataset_modified"]),
438 +            },
439 +            action_history=deepcopy(self._state_data["action_history"]),
440 +            per_record_scores=deepcopy(self._state_data["per_record_scores"]),
441 +            current_iteration_score=float(self._state_data["current_iteration_score"]),
442 +            previous_iteration_score=float(self._state_data["previous_iteration_score"]),
443 +            steps_remaining=int(self._state_data["steps_remaining"]),
444          )
445 
446 +    def _update_metrics(self) -> None:
447 +        totals = self._state_data["totals"]
448 +        total_fixes = int(totals["total_fixes"])
449 +        self._state_data["hallucination_rate"] = (
450 +            0.0 if total_fixes == 0 else float(totals["hallucinated_fixes"]) / total_fixes
451          )
452 +        total_cd = int(totals["total_cannot_determine"])
453 +        self._state_data["uncertainty_accuracy"] = (
454 +            0.0 if total_cd == 0 else float(totals["correct_cannot_determine"]) / total_cd
455 +        )
456 +        total_related = int(totals["total_related_cases"])
457 +        self._state_data["consistency_score"] = (
458 +            1.0 if total_related == 0 else float(totals["consistent_decisions"]) / total_related
459          )
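With hypothetical running totals (the `totals` keys match the diff; the counts here are invented), the three metrics maintained by `_update_metrics` work out as follows. Note the asymmetric defaults: no fixes means a 0.0 hallucination rate, but no related cases means a perfect 1.0 consistency score.

```python
# Invented episode totals, aggregated the same way _update_metrics does.
totals = {
    "total_fixes": 4, "hallucinated_fixes": 1,
    "total_cannot_determine": 2, "correct_cannot_determine": 2,
    "total_related_cases": 0, "consistent_decisions": 0,
}

# Fraction of fixes that invented values not supported by the data.
hallucination_rate = (
    0.0 if totals["total_fixes"] == 0
    else totals["hallucinated_fixes"] / totals["total_fixes"]
)
# Fraction of cannot_determine calls that were actually warranted.
uncertainty_accuracy = (
    0.0 if totals["total_cannot_determine"] == 0
    else totals["correct_cannot_determine"] / totals["total_cannot_determine"]
)
# Consistency defaults to perfect when no related cases were seen.
consistency_score = (
    1.0 if totals["total_related_cases"] == 0
    else totals["consistent_decisions"] / totals["total_related_cases"]
)

print(hallucination_rate, uncertainty_accuracy, consistency_score)  # → 0.25 1.0 1.0
```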
460 
461 +    def _parse_record_id(self, record_id: str) -> int:
462 +        digits = "".join(ch for ch in str(record_id) if ch.isdigit())
463 +        return int(digits) if digits else -1
464 
465 +    def _find_record(self, record_id: str, table: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
466 +        rid = self._parse_record_id(record_id)
467 +        for row in table:
468 +            if int(row.get("row_id", -1)) == rid:
469                  return row
470          return None
471 
472 +    def _append_failure(self, action: Action, error_type: str, details: str) -> None:
473 +        mapped = error_type
474 +        if error_type == "wrong_fix":
475 +            mapped = "wrong_fix"
476 +        self._state_data["failure_logs"].append(
477 +            {
478 +                "record_id": action.record_id,
479 +                "error_type": mapped,
480 +                "details": details,
481 +                "confidence": float(action.confidence),
482 +            }
483 +        )
484 
485 
486  class DataOpsGymEnv(DataOpsEnv):
grader.py CHANGED
@@ -1,438 +1,8 @@
1
- """Evaluation and grading interfaces for ``dataops-gym``.
2
-
3
- This module is responsible for validating outputs, scoring task results, and
4
- capturing assessment metadata independently from task execution logic.
5
- """
6
 
7
  from __future__ import annotations
8
 
9
- import re
10
- from typing import Any, Dict, Iterable, Mapping, MutableMapping, Optional, Tuple
11
-
12
-
13
- # Dense reward values are intentionally small and additive so the agent receives
14
- # feedback for intermediate progress without requiring full task completion.
15
- CORRECT_DUPLICATE_REMOVAL_REWARD = 0.3
16
- CORRECT_NORMALIZATION_REWARD = 0.2
17
- FIX_MISSING_VALUE_REWARD = 0.2
18
- VALIDATION_SUCCESS_REWARD = 0.2
19
- EFFICIENCY_BONUS = 0.2
20
- RECOVERY_BONUS = 0.25
21
- STEP_PENALTY = -0.02
22
- PROGRESS_REWARD_SCALE = 0.3
23
-
24
- # Penalties are split into:
25
- # 1. a direct penalty for the current bad action, and
26
- # 2. an escalating repetition penalty if the same mistake keeps happening.
27
- WRONG_DELETION_PENALTY = -0.3
28
- UNNECESSARY_ACTION_PENALTY = -0.1
29
- NOOP_PENALTY = -0.05
30
- DESTRUCTIVE_ACTION_PENALTY = -0.4
31
-
32
- FIRST_REPEAT_PENALTY = -0.1
33
- SECOND_REPEAT_PENALTY = -0.2
34
- THIRD_OR_MORE_REPEAT_PENALTY = -0.4
35
- EMAIL_PATTERN = re.compile(r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$")
36
-
37
-
38
- def detect_repeated_mistake(mistakes: Mapping[str, int], mistake_key: str) -> int:
39
- """Return how many times a mistake has already occurred before this step."""
40
-
41
- return int(mistakes.get(mistake_key, 0))
42
-
43
-
44
- def track_mistake(state: MutableMapping[str, Any], mistake_key: str) -> int:
45
- """Update the mistake counter in state and return the new occurrence count."""
46
-
47
- mistakes = state.setdefault("mistakes", {})
48
- if not isinstance(mistakes, dict):
49
- raise ValueError("state['mistakes'] must be a dictionary for mistake tracking")
50
-
51
- current_count = int(mistakes.get(mistake_key, 0))
52
- new_count = current_count + 1
53
- mistakes[mistake_key] = new_count
54
- return new_count
55
-
56
-
57
- def repeated_mistake_penalty(occurrence_count: int) -> float:
58
- """Return the escalating penalty for repeated mistakes."""
59
-
60
- if occurrence_count <= 1:
61
- return FIRST_REPEAT_PENALTY
62
- if occurrence_count == 2:
63
- return SECOND_REPEAT_PENALTY
64
- return THIRD_OR_MORE_REPEAT_PENALTY
65
-
66
-
67
- def _to_bool(mapping: Mapping[str, Any], key: str) -> bool:
68
- """Normalize truthy result flags into deterministic boolean checks."""
69
-
70
- return bool(mapping.get(key, False))
71
-
72
-
73
- def _mistake_key(
74
- action: Mapping[str, Any],
75
- result: Mapping[str, Any],
76
- fallback_key: str,
77
- ) -> str:
78
- """Build an action-specific mistake key with a safe fallback."""
79
-
80
- action_type = action.get("action_type")
81
- error_type = result.get("error_type", "general")
82
-
83
- if action_type:
84
- return f"{action_type}:{error_type}"
85
- return fallback_key
86
-
87
-
88
- def _clamp_reward(value: float) -> float:
89
- """Keep rewards in the required [-1.0, 1.0] range."""
90
-
91
- return max(-1.0, min(1.0, round(value, 4)))
92
-
93
-
94
- def _clamp_score(value: float) -> float:
95
- """Keep task-level scores in the required [0.0, 1.0] range."""
96
-
97
- return max(0.0, min(1.0, round(value, 4)))
98
-
99
-
100
- def _is_missing_value(value: Any) -> bool:
101
- """Return whether a cell should be considered missing."""
102
-
103
- return value is None or value == ""
104
-
105
-
106
- def _is_valid_email(value: str) -> bool:
107
- """Validate email formatting used by task graders."""
108
-
109
- return bool(EMAIL_PATTERN.match(value.strip()))
110
-
111
-
112
- def _is_valid_phone(value: str) -> bool:
113
- """Validate phone formatting used by task graders."""
114
-
115
- digits = re.sub(r"\D", "", value)
116
- return len(digits) == 10 or (len(digits) == 11 and digits.startswith("1"))
117
-
118
-
119
- def _needs_title_case(value: str) -> bool:
120
- """Return whether text still violates title-case normalization."""
121
-
122
- cleaned = value.strip()
123
- return bool(cleaned) and cleaned != cleaned.title()
124
-
125
-
126
- def _has_duplicates(table: Iterable[Dict[str, Any]], column: str) -> bool:
127
- """Check whether a column contains duplicate non-empty values."""
128
-
129
- values = [row.get(column) for row in table if row.get(column) not in (None, "")]
130
- return len(values) != len(set(values))
131
-
132
-
133
- def _table_by_row_id(table: Iterable[Dict[str, Any]]) -> Dict[int, Dict[str, Any]]:
134
- """Index a table by ``row_id`` for deterministic issue evaluation."""
135
-
136
- return {
137
- int(row["row_id"]): dict(row)
138
- for row in table
139
- if row.get("row_id") is not None
140
- }
141
-
142
-
143
- def _is_issue_resolved(issue: Mapping[str, Any], table_by_row_id: Dict[int, Dict[str, Any]]) -> bool:
144
- """Return whether a structured hidden issue has been resolved."""
145
-
146
- issue_type = issue.get("type")
147
-
148
- if issue_type == "valid_trap":
149
- return True
150
-
151
- if issue_type in {"duplicate", "conflict"}:
152
- rows = issue.get("rows", [])
153
- return not all(row_id in table_by_row_id for row_id in rows)
154
-
155
- if issue_type == "missing_value":
156
- row = table_by_row_id.get(issue.get("row"))
157
- column = issue.get("column")
158
- return row is None or column is None or not _is_missing_value(row.get(column))
159
-
160
- if issue_type == "inconsistent_casing":
161
- column = issue.get("column")
162
- rows = issue.get("rows", [])
163
- return not any(
164
- row_id in table_by_row_id
165
- and isinstance(table_by_row_id[row_id].get(column), str)
166
- and _needs_title_case(str(table_by_row_id[row_id].get(column)))
167
- for row_id in rows
168
- )
169
-
170
- if issue_type == "invalid_format":
171
- row = table_by_row_id.get(issue.get("row"))
172
- column = issue.get("column")
173
- if row is None or column is None:
174
- return True
175
- value = row.get(column)
176
- if column == "email":
177
- return _is_valid_email(str(value))
178
- if column == "phone":
179
- return _is_valid_phone(str(value))
180
- return True
181
-
182
- if issue_type == "constraint_violation" and issue.get("constraint") == "unique_email":
183
- rows = issue.get("rows", [])
184
- emails = [
185
- table_by_row_id[row_id].get("email")
186
- for row_id in rows
187
- if row_id in table_by_row_id
188
- ]
189
- return len(emails) == len(set(emails))
190
-
191
- return True
192
-
193
-
194
- def _task_check_results(
195
- task_definition: Mapping[str, Any],
196
- table: Iterable[Dict[str, Any]],
197
- state: Optional[Mapping[str, Any]] = None,
198
- ) -> list[Dict[str, Any]]:
199
- """Build explicit pass/fail checks for final grading and validation."""
200
-
201
- rows = [dict(row) for row in table]
202
- table_by_row_id = _table_by_row_id(rows)
203
- expected_outcome = dict(task_definition.get("expected_outcome", {}))
204
- checks: list[Dict[str, Any]] = []
205
-
206
- expected_row_count = expected_outcome.get("expected_row_count")
207
- if expected_row_count is not None:
208
- checks.append(
209
- {
210
- "name": "expected_row_count",
211
- "passed": len(rows) == expected_row_count,
212
- "message": f"Expected exactly {expected_row_count} rows in the cleaned table.",
213
- }
214
- )
215
-
216
- expected_row_range = expected_outcome.get("expected_row_count_range")
217
- if expected_row_range is not None:
218
- checks.append(
219
- {
220
- "name": "expected_row_count_range",
221
- "passed": expected_row_range["min"] <= len(rows) <= expected_row_range["max"],
222
- "message": (
223
- "Expected the cleaned table to contain between "
224
- f"{expected_row_range['min']} and {expected_row_range['max']} rows."
225
- ),
226
- }
227
- )
228
-
229
- required_columns = expected_outcome.get(
230
- "required_non_null_columns", task_definition.get("required_columns", [])
231
- )
232
- if required_columns:
233
- checks.append(
234
- {
235
- "name": "required_non_null_columns",
236
- "passed": not any(
237
- _is_missing_value(row.get(column))
238
- for row in rows
239
- for column in required_columns
240
- ),
241
- "message": "Required columns must be populated for all remaining rows.",
242
- }
243
- )
244
-
245
- for unique_column in expected_outcome.get("unique_by", []):
246
- checks.append(
247
- {
248
- "name": f"unique_by:{unique_column}",
249
- "passed": not _has_duplicates(rows, unique_column),
250
- "message": f"Values in '{unique_column}' must remain unique.",
251
- }
252
- )
253
-
254
- for column, rule in expected_outcome.get("normalized_columns", {}).items():
255
- if rule == "title_case":
256
- checks.append(
257
- {
258
- "name": f"normalized_column:{column}",
259
- "passed": not any(
260
- isinstance(row.get(column), str)
261
- and _needs_title_case(str(row.get(column)))
262
- for row in rows
263
- ),
264
- "message": f"Column '{column}' should use a consistent title-case style.",
265
- }
266
- )
267
-
268
- for column, rule in expected_outcome.get("format_rules", {}).items():
269
- if rule == "valid_email":
270
- checks.append(
271
- {
272
- "name": f"valid_email:{column}",
273
- "passed": not any(
274
- row.get(column) not in (None, "")
275
- and not _is_valid_email(str(row.get(column)))
276
- for row in rows
277
- ),
278
- "message": "All remaining email values must use a valid email format.",
279
- }
280
- )
281
- if rule == "normalized_phone":
282
- checks.append(
283
- {
284
- "name": f"normalized_phone:{column}",
285
- "passed": not any(
286
- row.get(column) not in (None, "")
287
- and not _is_valid_phone(str(row.get(column)))
288
- for row in rows
289
- ),
290
- "message": "All remaining phone values must use a consistent valid format.",
291
- }
292
- )
293
-
294
- initial_rows = {}
295
- if state is not None:
296
- initial_rows = dict(state.get("initial_table_by_row_id", {}))
297
-
298
- for row_id in expected_outcome.get("must_preserve_valid_rows", []):
299
- current_row = table_by_row_id.get(row_id)
300
- checks.append(
301
- {
302
- "name": f"preserve_valid_row:{row_id}",
303
- "passed": current_row is not None and current_row == initial_rows.get(row_id),
304
- "message": f"Valid row {row_id} should remain logically unchanged.",
305
- }
306
- )
307
-
308
- for row_group in expected_outcome.get("exactly_one_of_rows", []):
309
- surviving = [row_id for row_id in row_group if row_id in table_by_row_id]
310
- checks.append(
311
- {
312
- "name": f"exactly_one_of_rows:{','.join(str(row_id) for row_id in row_group)}",
313
- "passed": len(surviving) == 1,
314
- "message": f"Exactly one of rows {row_group} should remain in the cleaned table.",
315
- }
316
- )
317
-
318
- for row_id in expected_outcome.get("rows_must_survive", []):
319
- checks.append(
320
- {
321
- "name": f"rows_must_survive:{row_id}",
322
- "passed": row_id in table_by_row_id,
323
- "message": f"Row {row_id} must still be present in the cleaned table.",
324
- }
325
- )
326
-
327
- for row_id in expected_outcome.get("rows_must_be_removed", []):
328
- checks.append(
329
- {
330
- "name": f"rows_must_be_removed:{row_id}",
331
- "passed": row_id not in table_by_row_id,
332
- "message": f"Row {row_id} should not remain in the cleaned table.",
333
- }
334
- )
335
-
336
- for issue in task_definition.get("hidden_issues", []):
337
- if issue.get("type") == "valid_trap":
338
- continue
339
- message = issue.get("description") or f"Issue '{issue.get('type')}' must be resolved."
340
- checks.append(
341
- {
342
- "name": f"hidden_issue:{issue.get('type')}",
343
- "passed": _is_issue_resolved(issue, table_by_row_id),
344
- "message": message,
345
- }
346
- )
347
-
348
- return checks
349
-
350
-
351
- def _calculate_reward(
352
- state: MutableMapping[str, Any],
353
- action: Mapping[str, Any],
354
- result: MutableMapping[str, Any],
355
- ) -> float:
356
- """Compute the deterministic scalar reward for a single environment step."""
357
-
358
- reward = 0.0
359
-
360
- # Every step incurs a small cost so the agent is encouraged to solve the
361
- # task quickly instead of exploring indefinitely.
362
- reward += STEP_PENALTY
363
-
364
- # Intermediate rewards encourage the agent to make progress even when the
365
- # dataset is not fully clean yet.
366
- if _to_bool(result, "correct_duplicate_removal"):
367
- reward += CORRECT_DUPLICATE_REMOVAL_REWARD
368
-
369
- if _to_bool(result, "correct_normalization"):
370
- reward += CORRECT_NORMALIZATION_REWARD
371
-
372
- if _to_bool(result, "fixed_missing_value") or _to_bool(
373
- result, "fixing_missing_values"
374
- ):
375
- reward += FIX_MISSING_VALUE_REWARD
376
-
377
- if _to_bool(result, "validation_success"):
378
- reward += VALIDATION_SUCCESS_REWARD
379
-
380
- if _to_bool(result, "corrected_previous_mistake"):
381
- reward += RECOVERY_BONUS
382
-
383
- if _to_bool(result, "noop"):
384
- reward += NOOP_PENALTY
385
-
386
- if _to_bool(result, "destructive_action"):
387
- reward += DESTRUCTIVE_ACTION_PENALTY
388
-
389
- # Progress-based shaping provides a smoother learning signal for partial
390
- # improvement, even when a step does not fully resolve a visible issue.
391
- progress_delta = float(result.get("progress_delta", 0.0))
392
- progress_delta = max(0.0, min(1.0, progress_delta))
393
- reward += progress_delta * PROGRESS_REWARD_SCALE
394
-
395
- # Explicitly penalize steps that fail to improve task progress so agents do
396
- # not learn that random but harmless actions are equivalent to useful ones.
397
- if progress_delta == 0.0:
398
- reward -= 0.05
399
-
400
- # Direct penalties handle obviously harmful moves. Repetition is tracked
401
- # separately so the same bad behavior becomes more expensive over time.
402
- if _to_bool(result, "wrong_deletion"):
403
- reward += WRONG_DELETION_PENALTY
404
- mistake_key = _mistake_key(action, result, "wrong_deletion")
405
- occurrence_count = track_mistake(state, mistake_key)
406
- reward += repeated_mistake_penalty(occurrence_count)
407
-
408
- if _to_bool(result, "unnecessary_action"):
409
- reward += UNNECESSARY_ACTION_PENALTY
410
- mistake_key = _mistake_key(action, result, "unnecessary_action")
411
- occurrence_count = track_mistake(state, mistake_key)
412
- reward += repeated_mistake_penalty(occurrence_count)
413
-
414
- # Support arbitrary custom mistake keys in addition to the built-in ones.
415
- for mistake_key in result.get("mistake_keys", []):
416
- if mistake_key not in {"wrong_deletion", "unnecessary_action"}:
417
- occurrence_count = track_mistake(state, str(mistake_key))
418
- reward += repeated_mistake_penalty(occurrence_count)
419
-
420
- # Reward early completion only when the task finishes with steps still
421
- # available. This creates a simple deterministic efficiency incentive.
422
- if _to_bool(result, "task_completed") and int(state.get("steps_remaining", 0)) > 0:
423
- reward += EFFICIENCY_BONUS
424
-
425
- return _clamp_reward(reward)
426
-
427
-
428
- def grade_step(
429
- state: MutableMapping[str, Any],
430
- action: Mapping[str, Any],
431
- result: MutableMapping[str, Any],
432
- ) -> float:
433
- """Compute a deterministic dense reward for a single environment step."""
434
-
435
- return _calculate_reward(state, action, result)
436
 
437
 
438
  def grade_step_details(
@@ -440,153 +10,148 @@ def grade_step_details(
440
  action: Mapping[str, Any],
441
  result: MutableMapping[str, Any],
442
  ) -> Tuple[float, Dict[str, Any]]:
443
- """Compute reward plus a structured component breakdown for debugging."""
444
-
445
- previous_mistakes = {
446
- key: int(value)
447
- for key, value in state.get("mistakes", {}).items()
448
- }
449
- reward = grade_step(state, action, result)
450
-
451
- wrong_deletion_repeat_penalty = 0.0
452
- if result.get("wrong_deletion"):
453
- mistake_key = _mistake_key(action, result, "wrong_deletion")
454
- occurrence_count = int(state.get("mistakes", {}).get(mistake_key, 0))
455
- if occurrence_count > int(previous_mistakes.get(mistake_key, 0)):
456
- wrong_deletion_repeat_penalty = repeated_mistake_penalty(occurrence_count)
457
-
458
- unnecessary_repeat_penalty = 0.0
459
- if result.get("unnecessary_action"):
460
- mistake_key = _mistake_key(action, result, "unnecessary_action")
461
- occurrence_count = int(state.get("mistakes", {}).get(mistake_key, 0))
462
- if occurrence_count > int(previous_mistakes.get(mistake_key, 0)):
463
- unnecessary_repeat_penalty = repeated_mistake_penalty(occurrence_count)
464
 
465
- components: Dict[str, Any] = {
466
- "step_penalty": STEP_PENALTY,
467
- "duplicate_reward": (
468
- CORRECT_DUPLICATE_REMOVAL_REWARD
469
- if result.get("correct_duplicate_removal")
470
- else 0.0
471
- ),
472
- "normalization_reward": (
473
- CORRECT_NORMALIZATION_REWARD
474
- if result.get("correct_normalization")
475
- else 0.0
476
- ),
477
- "missing_value_reward": (
478
- FIX_MISSING_VALUE_REWARD if result.get("fixed_missing_value") else 0.0
479
- ),
480
- "validation_reward": (
481
- VALIDATION_SUCCESS_REWARD if result.get("validation_success") else 0.0
482
- ),
483
- "penalties": {
484
- "wrong_deletion": (
485
- WRONG_DELETION_PENALTY if result.get("wrong_deletion") else 0.0
486
- ),
487
- "unnecessary_action": (
488
- UNNECESSARY_ACTION_PENALTY if result.get("unnecessary_action") else 0.0
489
- ),
490
- "wrong_deletion_repeat": wrong_deletion_repeat_penalty,
491
- "unnecessary_action_repeat": unnecessary_repeat_penalty,
492
- "noop": NOOP_PENALTY if result.get("noop") else 0.0,
493
- "destructive_action": (
494
- DESTRUCTIVE_ACTION_PENALTY
495
- if result.get("destructive_action")
496
- else 0.0
497
- ),
498
- },
499
- "progress_reward": round(
500
- max(0.0, min(1.0, float(result.get("progress_delta", 0.0))))
501
- * PROGRESS_REWARD_SCALE,
502
- 4,
503
- ),
504
- "recovery_bonus": (
505
- RECOVERY_BONUS if result.get("corrected_previous_mistake") else 0.0
506
- ),
507
- "efficiency_bonus": (
508
- EFFICIENCY_BONUS
509
- if result.get("task_completed") and int(state.get("steps_remaining", 0)) > 0
510
- else 0.0
511
- ),
512
- }
513
 
514
- if float(result.get("progress_delta", 0.0)) == 0.0:
515
- components["no_progress_penalty"] = -0.05
 
516
 
517
- result["reward_components"] = components
518
- result["reward_total"] = reward
519
- return reward, components
520
 
521
 
522
  def grade_task_result(
523
  task_definition: Mapping[str, Any],
524
- table: Iterable[Dict[str, Any]],
525
  state: Optional[Mapping[str, Any]] = None,
526
  ) -> float:
527
- """Compute a deterministic final task score between 0.0 and 1.0."""
528
-
529
- checks = _task_check_results(task_definition, table, state)
530
- if not checks:
531
- return 0.0
532
- return _clamp_score(
533
- sum(1.0 for check in checks if check["passed"]) / len(checks)
534
  )
 
535
 
536
 
537
  def task_failure_messages(
538
  task_definition: Mapping[str, Any],
539
- table: Iterable[Dict[str, Any]],
540
  state: Optional[Mapping[str, Any]] = None,
541
  ) -> list[str]:
542
- """Return explicit failure messages for unresolved outcome checks."""
543
-
544
- return [
545
- str(check["message"])
546
- for check in _task_check_results(task_definition, table, state)
547
- if not bool(check["passed"])
548
- ]
549
-
550
-
551
- def grade_easy_cleaning_task(
552
- task_definition: Mapping[str, Any],
553
- table: Iterable[Dict[str, Any]],
554
- state: Optional[Mapping[str, Any]] = None,
555
- ) -> float:
556
- """Grade the easy cleaning task on a 0.0–1.0 scale."""
557
-
558
- return grade_task_result(task_definition, table, state)
559
-
560
-
561
- def grade_medium_normalization_task(
562
- task_definition: Mapping[str, Any],
563
- table: Iterable[Dict[str, Any]],
564
- state: Optional[Mapping[str, Any]] = None,
565
- ) -> float:
566
- """Grade the medium normalization task on a 0.0–1.0 scale."""
567
-
568
- return grade_task_result(task_definition, table, state)
569
-
570
-
571
- def grade_hard_conflict_resolution_task(
572
- task_definition: Mapping[str, Any],
573
- table: Iterable[Dict[str, Any]],
574
- state: Optional[Mapping[str, Any]] = None,
575
- ) -> float:
576
- """Grade the hard conflict-resolution task on a 0.0–1.0 scale."""
577
 
578
- return grade_task_result(task_definition, table, state)
 
 
 
 
579
 
580
 
581
- __all__ = [
582
- "detect_repeated_mistake",
583
- "grade_step",
584
- "grade_step_details",
585
- "grade_task_result",
586
- "task_failure_messages",
587
- "grade_easy_cleaning_task",
588
- "grade_medium_normalization_task",
589
- "grade_hard_conflict_resolution_task",
590
- "repeated_mistake_penalty",
591
- "track_mistake",
592
- ]
 
1 + """Strict semantic evaluation math for ``dataops-gym``."""
2 
3   from __future__ import annotations
4 
5 + from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple
6 
7 
8   def grade_step_details(
9       state: MutableMapping[str, Any],
10      action: Mapping[str, Any],
11      result: MutableMapping[str, Any],
12  ) -> Tuple[float, Dict[str, Any]]:
13 +     """Apply the exact per-step reward rules with no score clamping."""
14 +
15 +     score = 0.0
16 +     components: Dict[str, float] = {}
17 +     confidence = float(action.get("confidence", 0.0))
18 +
19 +     action_type = str(action.get("action_type", ""))
20 +     if result.get("classification_correct"):
21 +         # Detect is intentionally lower value than fix/cannot_determine.
22 +         if action_type == "detect_issue":
23 +             score += 0.1
24 +             components["classification"] = 0.1
25 +         else:
26 +             score += 0.2
27 +             components["classification"] = 0.2
28 +     elif result.get("classification_incorrect"):
29 +         score -= 0.2
30 +         components["classification"] = -0.2
31 +
32 +     if result.get("correct_issue_detected"):
33 +         if action_type == "detect_issue":
34 +             score += 0.05
35 +             components["issue_detection"] = 0.05
36 +         else:
37 +             score += 0.15
38 +             components["issue_detection"] = 0.15
39 +     elif result.get("missed_issue"):
40 +         score -= 0.15
41 +         components["issue_detection"] = -0.15
42 +     elif result.get("false_issue"):
43 +         score -= 0.05
44 +         components["issue_detection"] = -0.05
45 +
46 +     if result.get("correct_fix"):
47 +         score += 0.25
48 +         components["decision"] = 0.25
49 +     elif result.get("correct_cannot_determine"):
50 +         score += 0.25
51 +         components["decision"] = 0.25
52 +     elif result.get("hallucinated_fix"):
53 +         score -= 0.5
54 +         components["decision"] = -0.5
55 +     elif result.get("wrong_fix"):
56 +         score -= 0.4
57 +         components["decision"] = -0.4
58 +     elif result.get("wrong_cannot_determine"):
59 +         score -= 0.2
60 +         components["decision"] = -0.2
61 +
62 +     if result.get("passive_penalty"):
63 +         score -= 0.05
64 +         components["passive_penalty"] = -0.05
65 +
66 +     if result.get("repeated_detection"):
67 +         score -= 0.1
68 +         components["repeated_detection_penalty"] = -0.1
69 +
70 +     extra_mods = int(result.get("extra_fields_modified", 0))
71 +     if extra_mods > 0:
72 +         over = -0.05 * extra_mods
73 +         score += over
74 +         components["overcorrection"] = over
75 +
76 +     if result.get("consistent_handling"):
77 +         score += 0.2
78 +         components["cross_record_consistency"] = 0.2
79 +     elif result.get("inconsistent_handling"):
80 +         score -= 0.3
81 +         components["cross_record_consistency"] = -0.3
82 +
83 +     is_correct = bool(
84 +         result.get("classification_correct")
85 +         or result.get("correct_fix")
86 +         or result.get("correct_cannot_determine")
87 +         or result.get("correct_issue_detected")
88 +     )
89 +     is_wrong = bool(
90 +         result.get("classification_incorrect")
91 +         or result.get("wrong_fix")
92 +         or result.get("hallucinated_fix")
93 +         or result.get("wrong_cannot_determine")
94 +         or result.get("false_issue")
95 +     )
96 +     if confidence > 0.7 and is_correct:
97 +         score += 0.05
98 +         components["confidence"] = 0.05
99 +     elif confidence > 0.7 and is_wrong:
100 +         score -= 0.1
101 +         components["confidence"] = -0.1
102 
103 +     if result.get("hallucinated_fix") and confidence > 0.8:
104 +         score -= 0.2
105 +         components["confident_hallucination_amplification"] = -0.2
106 
107 +     if result.get("resolved_detected_issue"):
108 +         score += 0.15
109 +         components["resolution_reward"] = 0.15
110 
111 +     return score, components
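As a sanity check on the reward table above, a minimal standalone sketch (not the module itself; it re-implements only the fix/hallucination rules and the confidence adjustments from this diff) composes a score for one hypothetical step:

```python
# Illustration only: a few of the per-step rules from grade_step_details.
def sketch_step_score(action: dict, result: dict) -> float:
    score = 0.0
    confidence = float(action.get("confidence", 0.0))
    if result.get("correct_fix"):
        score += 0.25                 # correct fix reward
    elif result.get("hallucinated_fix"):
        score -= 0.5                  # hallucinated fix penalty
        if confidence > 0.8:
            score -= 0.2              # confident-hallucination amplification
    if confidence > 0.7 and result.get("correct_fix"):
        score += 0.05                 # confidence bonus on correct steps
    return score

# Confident correct fix: 0.25 + 0.05
print(round(sketch_step_score({"confidence": 0.9}, {"correct_fix": True}), 2))   # → 0.3
# Confident hallucinated fix: -0.5 - 0.2
print(round(sketch_step_score({"confidence": 0.9}, {"hallucinated_fix": True}), 2))  # → -0.7
```

The asymmetry is deliberate: a confident hallucination costs more than two correct fixes earn, so high confidence only pays off when it is justified.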
 
 
112 
113 
114  def grade_task_result(
115      task_definition: Mapping[str, Any],
116 +    table: Any,
117      state: Optional[Mapping[str, Any]] = None,
118  ) -> float:
119 +     """Compute final task score in [0, 1] using required formula."""
120 +
121 +     _ = task_definition
122 +     _ = table
123 +     state = state or {}
124 +     per_record_scores = dict(state.get("per_record_scores", {}))
125 +     n = max(1, len(per_record_scores))
126 +     avg_record_score = sum(float(v) for v in per_record_scores.values()) / n
127 +     normalized_record_score = (avg_record_score + 1.0) / 2.0
128 +     normalized_record_score = max(0.0, min(1.0, normalized_record_score))
129 +
130 +     hallucination_rate = float(state.get("hallucination_rate", 0.0))
131 +     uncertainty_accuracy = float(state.get("uncertainty_accuracy", 0.0))
132 +     consistency_score = float(state.get("consistency_score", 1.0))
133 +
134 +     task_score = (
135 +         0.5 * normalized_record_score
136 +         + 0.2 * (1.0 - hallucination_rate)
137 +         + 0.15 * uncertainty_accuracy
138 +         + 0.15 * consistency_score
139      )
140 +     return max(0.0, min(1.0, task_score))
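Plugging hypothetical episode aggregates into the weighted formula above (all values here are invented for illustration) gives:

```python
# Invented per-record step-score sums; real values come from the env state.
per_record_scores = {"rec-001": 0.6, "rec-002": -0.2}
avg = sum(per_record_scores.values()) / len(per_record_scores)   # 0.2
normalized = (avg + 1.0) / 2.0                                   # maps [-1, 1] → [0, 1], here 0.6

hallucination_rate = 0.25
uncertainty_accuracy = 0.5
consistency_score = 1.0

# Same weights as grade_task_result: 0.5 / 0.2 / 0.15 / 0.15.
task_score = (
    0.5 * normalized
    + 0.2 * (1.0 - hallucination_rate)
    + 0.15 * uncertainty_accuracy
    + 0.15 * consistency_score
)
print(round(task_score, 3))  # → 0.675
```

Because record quality carries only half the weight, an agent that fixes everything but hallucinates freely cannot reach a high task score.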
141 
142 
143  def task_failure_messages(
144      task_definition: Mapping[str, Any],
145 +    table: Any,
146      state: Optional[Mapping[str, Any]] = None,
147  ) -> list[str]:
148 +     """Return lightweight failure reasons collected during stepping."""
149 +
150 +     _ = task_definition
151 +     _ = table
152 +     state = state or {}
153 +     failures = state.get("failure_logs", [])
154 +     return [str(f.get("details", "")) for f in failures if f.get("details")]
155 
156 
157 + __all__ = ["grade_step_details", "grade_task_result", "task_failure_messages"]
inference.py CHANGED
@@ -23,15 +23,16 @@ from env import DataOpsEnv
23   API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
24   MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen3-VL-30B-A3B-Instruct:novita")
25   HF_TOKEN = os.getenv("HF_TOKEN")
26 - BENCHMARK = os.getenv("OPENENV_BENCHMARK", "dataops-env")
27   MAX_STEPS = 10
28   TEMPERATURE = 0.0
29   MAX_TOKENS = 160
30   MODEL_RETRIES = 2
31 - FALLBACK_ACTION = "noop()"
32   ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", re.IGNORECASE)
33   EMAIL_PATTERN = re.compile(r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$")
34 - TASK_ORDER = ["easy", "medium", "hard"]
35   IDENTIFIER_COLUMNS = ("customer_id", "vendor_id", "partner_id")
36   POLICY_CACHE_PATH = os.getenv("POLICY_CACHE_PATH", ".dataops_policy_cache.json")
37   POLICY_CACHE_VERSION = 1
@@ -209,12 +210,12 @@ def log_step(
209      )
210 
211 
212 - def log_end(success: bool, steps: int, rewards: List[float]) -> None:
213      """Emit the required episode end line."""
214 
215      rewards_text = ",".join(f"{reward:.2f}" for reward in rewards)
216      print(
217 -        f"[END] success={str(success).lower()} steps={steps} rewards={rewards_text}",
218          flush=True,
219      )
220 
@@ -293,8 +294,13 @@ def build_memory_keys(
293  ) -> Tuple[str, str]:
294      """Build exact-state and generalized problem-pattern keys."""
295 
296 -    table = list(observation.get("table", []))
297 -    normalized_issues = sorted(_normalize_issue_text(str(issue)) for issue in observation.get("issues", []))
298      state_key = _hash_key(
299          {
300              "task_name": task_name,
@@ -304,7 +310,7 @@
304              {key: row.get(key) for key in sorted(row.keys())}
305              for row in sorted(table, key=lambda row: int(row.get("row_id", 0)))
306          ],
307 -        "issues": normalized_issues,
308          }
309      )
310      pattern_key = _hash_key(
@@ -312,7 +318,7 @@
312          "task_name": task_name,
313          "goal": goal,
314          "summary": _table_summary(table),
315 -        "issues": normalized_issues,
316          }
317      )
318      return state_key, pattern_key
@@ -432,7 +438,7 @@ def _build_action_string(payload: Mapping[str, Any]) -> str:
432 
433      action_type = str(payload["action_type"])
434      args: List[str] = []
435 -    for key in ("row_id", "column", "value"):
436          if key not in payload or payload[key] is None:
437              continue
438          value = payload[key]
@@ -474,25 +480,22 @@ def action_string_to_payload(action_str: str, step_number: int) -> Tuple[str, Di
      try:
          expression = ast.parse(action_str, mode="eval").body
      except SyntaxError:
-         return FALLBACK_ACTION, {"action_id": f"step-{step_number:03d}", "action_type": "noop"}

      if not isinstance(expression, ast.Call) or not isinstance(expression.func, ast.Name):
-         return FALLBACK_ACTION, {"action_id": f"step-{step_number:03d}", "action_type": "noop"}

      allowed_actions = {
-         "remove_duplicate",
-         "fill_missing",
-         "normalize_column",
-         "delete_row",
-         "validate",
-         "noop",
      }
      action_type = expression.func.id
      if action_type not in allowed_actions:
-         return FALLBACK_ACTION, {"action_id": f"step-{step_number:03d}", "action_type": "noop"}

      payload: Dict[str, Any] = {
-         "action_id": f"step-{step_number:03d}",
          "action_type": action_type,
      }
      try:
@@ -501,7 +504,11 @@ def action_string_to_payload(action_str: str, step_number: int) -> Tuple[str, Di
              continue
          payload[keyword.arg] = ast.literal_eval(keyword.value)
      except (SyntaxError, ValueError, TypeError):
-         return FALLBACK_ACTION, {"action_id": f"step-{step_number:03d}", "action_type": "noop"}

      return _build_action_string(payload), payload
@@ -536,7 +543,7 @@ def _table_preview(table: Sequence[Mapping[str, Any]], limit: int = 6) -> str:
      summary = ", ".join(
          f"{key}={value}"
          for key, value in row.items()
-         if key in {"row_id", "name", "city", "email", "phone", "status", "customer_id", "vendor_id", "partner_id"}
      )
      preview_lines.append(f"- {summary}")
      return "\n".join(preview_lines) if preview_lines else "- None"
@@ -553,10 +560,8 @@ def build_user_prompt(
  ) -> str:
      """Construct a compact prompt that constrains the model to useful actions."""

-     issues = observation.get("issues", [])
-     hints = observation.get("hints", [])
-     issues_text = "\n".join(f"- {issue}" for issue in issues[:6]) if issues else "- None"
-     hints_text = "\n".join(f"- {hint}" for hint in hints[:3]) if hints else "- None"
      candidates_text = "\n".join(f"- {action}" for action in candidate_actions)
      blocked_text = "\n".join(f"- {action}" for action in blocked_actions[:5]) if blocked_actions else "- None"

@@ -565,13 +570,11 @@ def build_user_prompt(
  Step: {step}
  Goal: {goal}
  Steps remaining: {observation.get("steps_remaining")}
- Progress: {observation.get("progress")}
- Current issues:
- {issues_text}
- Current hints:
- {hints_text}
  Table preview:
- {_table_preview(observation.get("table", []))}
  Recent history:
  {build_history_lines(history)}
  Last action error: {last_error or "null"}
@@ -594,124 +597,74 @@ def _prefer_action(
      action_text = _build_action_string(candidate)
      if action_text not in blocked_actions:
          return dict(candidate)
-     return {"action_type": "validate"}


- def _exact_duplicate_candidates(table: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
-     """Generate explicit remove-duplicate actions for exact duplicate rows."""
-
-     groups: Dict[Tuple[Tuple[str, Any], ...], List[int]] = defaultdict(list)
-     for row in table:
-         row_id = row.get("row_id")
-         if row_id is None:
-             continue
-         groups[_row_signature(row)].append(int(row_id))
-
-     actions: List[Dict[str, Any]] = []
-     for row_ids in groups.values():
-         if len(row_ids) > 1:
-             actions.append({"action_type": "remove_duplicate", "row_id": max(row_ids)})
-     return actions
-
-
- def _group_by_identifier(table: Sequence[Mapping[str, Any]]) -> Dict[Tuple[str, str], List[Dict[str, Any]]]:
-     """Group rows by likely business identifiers."""
-
-     groups: Dict[Tuple[str, str], List[Dict[str, Any]]] = defaultdict(list)
-     for row in table:
-         for key in IDENTIFIER_COLUMNS:
-             value = row.get(key)
-             if value not in (None, ""):
-                 groups[(key, str(value))].append(dict(row))
-     return groups
-
-
- def _row_quality_score(row: Mapping[str, Any]) -> int:
-     """Score a row so lower-quality conflict rows can be removed first."""
-
-     score = 0
-     if _is_valid_email(row.get("email")):
-         score += 3
-     if _is_valid_phone(row.get("phone")) or row.get("phone") in (None, ""):
-         score += 2
-     if isinstance(row.get("status"), str) and row.get("status") == "active":
-         score += 1
-     if isinstance(row.get("name"), str) and row.get("name").strip():
-         score += 1
-     return score
-
-
- def _structural_delete_candidates(table: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
-     """Generate delete actions for non-exact structural conflicts."""

      actions: List[Dict[str, Any]] = []
-     for rows in _group_by_identifier(table).values():
-         if len(rows) < 2:
-             continue
-         signatures = {_row_signature(row) for row in rows}
-         if len(signatures) == 1:
-             continue
-         worst_row = sorted(
-             rows,
-             key=lambda row: (_row_quality_score(row), int(row.get("row_id", 0))),
-         )[0]
-         actions.append({"action_type": "delete_row", "row_id": int(worst_row["row_id"])})
-
-     email_groups: Dict[str, List[Dict[str, Any]]] = defaultdict(list)
      for row in table:
-         email = row.get("email")
-         if email not in (None, ""):
-             email_groups[str(email)].append(dict(row))
-     for rows in email_groups.values():
-         if len(rows) < 2:
-             continue
-         worst_row = sorted(
-             rows,
-             key=lambda row: (_row_quality_score(row), int(row.get("row_id", 0))),
-         )[0]
-         action = {"action_type": "delete_row", "row_id": int(worst_row["row_id"])}
-         if action not in actions:
-             actions.append(action)
-     return actions
-
-
- def _missing_value_candidates(table: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
-     """Generate candidate fill actions for visible missing values."""
-
-     present_columns = {key for row in table for key in row.keys()}
-     priorities = [
-         column
-         for column in ["email", "city", "phone", "status", "name"]
-         if column in present_columns
-     ]
-     actions: List[Dict[str, Any]] = []
-     for column in priorities:
-         for row in table:
-             if _is_missing(row.get(column)):
                  actions.append(
                      {
-                         "action_type": "fill_missing",
-                         "row_id": int(row["row_id"]),
-                         "column": column,
-                         "value": _infer_fill_value(row, column, table),
                      }
                  )
      return actions


- def _normalization_candidates(table: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
-     """Generate candidate column-normalization actions."""
-
-     candidates: List[Dict[str, Any]] = []
-     if any(row.get("email") not in (None, "") and not _is_valid_email(row.get("email")) for row in table):
-         candidates.append({"action_type": "normalize_column", "column": "email"})
-     if any(row.get("phone") not in (None, "") and not _is_valid_phone(row.get("phone")) for row in table):
-         candidates.append({"action_type": "normalize_column", "column": "phone"})
-     if any(_needs_title_case(row.get("name")) for row in table):
-         candidates.append({"action_type": "normalize_column", "column": "name"})
-     if any(_needs_title_case(row.get("city")) for row in table):
-         candidates.append({"action_type": "normalize_column", "column": "city"})
-     return candidates


  def propose_candidate_actions(
720
  ) -> List[Dict[str, Any]]:
721
  """Generate ranked candidate actions from visible table state."""
722
 
723
- table = list(observation.get("table", []))
724
- candidates = (
725
- _exact_duplicate_candidates(table)
726
- + _structural_delete_candidates(table)
727
- + _missing_value_candidates(table)
728
- + _normalization_candidates(table)
729
- + [{"action_type": "validate"}]
730
- + [{"action_type": "noop"}]
731
- )
 
 
 
 
 
 
732
 
733
  unique_candidates: List[Dict[str, Any]] = []
734
  seen: set[str] = set()
@@ -746,7 +705,7 @@ def propose_candidate_actions(
746
  for candidate in unique_candidates
747
  if _build_action_string(candidate) != preferred_text
748
  ]
749
- return ordered[:8]
750
 
751
 
752
  def _order_candidates_with_memory(
@@ -754,15 +713,26 @@ def _order_candidates_with_memory(
      memory: PolicyMemory,
      state_key: str,
      pattern_key: str,
  ) -> List[Dict[str, Any]]:
      """Re-rank candidates using persistent cross-episode memory."""

      scored = []
      for index, candidate in enumerate(candidates):
          action_text = _build_action_string(candidate)
          scored.append(
              (
-                 -memory.score_action(state_key, pattern_key, action_text),
                  index,
                  dict(candidate),
              )
@@ -847,7 +817,9 @@ def choose_action(
      memory_blocked = memory.blocked_actions(state_key, pattern_key)
      combined_blocked = set(blocked_actions) | set(memory_blocked)
      candidates = propose_candidate_actions(observation, combined_blocked)
-     candidates = _order_candidates_with_memory(candidates, memory, state_key, pattern_key)
      heuristic_candidate = candidates[0]
      heuristic_text = _build_action_string(heuristic_candidate)
      candidate_texts = [_build_action_string(candidate) for candidate in candidates]
@@ -864,6 +836,8 @@ def choose_action(
          blocked_actions=sorted(combined_blocked),
      )

      chosen_text = model_text or heuristic_text
      normalized_text, payload = action_string_to_payload(chosen_text, step_number)
      if normalized_text in combined_blocked:
@@ -889,6 +863,8 @@ def run_episode(
      last_error: Optional[str] = None
      final_score = 0.0
      task_variant = "unknown"

      try:
          log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
@@ -903,7 +879,7 @@ def run_episode(
              task_name=task_name,
              task_variant=task_variant,
              observation=observation,
-             goal=observation_model.goal,
              step_number=step_number,
              history=history,
              last_error=last_error,
@@ -911,12 +887,23 @@ def run_episode(
          )

          try:
              observation_model, reward, done, info = env.step(action_payload)
              observation = observation_model.model_dump()
              result = info.get("result", {})
-             progress_delta = float(result.get("progress_delta", 0.0))
-             error_value = result.get("error_type") or info.get("error") or None
-             final_score = float(info.get("task_score", 0.0))
              if error_value == "general":
                  error_value = None
              memory.update(
@@ -931,6 +918,15 @@ def run_episode(
              )
              if error_value or progress_delta == 0.0 or reward <= 0.0:
                  blocked_actions.add(action_text)
          except Exception as exc:  # noqa: BLE001
              reward = 0.0
              done = True
@@ -965,24 +961,28 @@ def run_episode(
          )

          if done:
-             success = bool(final_score >= 0.95 and error_value is None)
              break
  finally:
      memory.save()
      close_method = getattr(env, "close", None)
      if callable(close_method):
          close_method()
-     log_end(success=success, steps=steps_taken, rewards=rewards)
      return final_score


  def main() -> None:
-     """Run all benchmark tasks with deterministic ordering and stdout formatting."""

      client = create_client()
      memory = PolicyMemory(POLICY_CACHE_PATH)
-     for task_index, task_name in enumerate(TASK_ORDER):
-         run_episode(client=client, memory=memory, task_name=task_name, seed=task_index)


  if __name__ == "__main__":
 
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen3-VL-30B-A3B-Instruct:novita")
  HF_TOKEN = os.getenv("HF_TOKEN")
+ BENCHMARK = os.getenv("BROWSERGYM_BENCHMARK", "dataops-env")
+ TASK_NAME = os.getenv("BROWSERGYM_TASK_NAME", "all")
+ TASK_ORDER = ["easy", "medium", "hard"]
  MAX_STEPS = 10
  TEMPERATURE = 0.0
  MAX_TOKENS = 160
  MODEL_RETRIES = 2
+ FALLBACK_ACTION = "skip(record_id='0', field='record', confidence=0.0)"
  ACTION_PREFIX_RE = re.compile(r"^(action|next action)\s*[:\-]\s*", re.IGNORECASE)
  EMAIL_PATTERN = re.compile(r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}$")
  IDENTIFIER_COLUMNS = ("customer_id", "vendor_id", "partner_id")
  POLICY_CACHE_PATH = os.getenv("POLICY_CACHE_PATH", ".dataops_policy_cache.json")
  POLICY_CACHE_VERSION = 1
 
      )


+ def log_end(success: bool, steps: int, rewards: List[float], final_score: float) -> None:
      """Emit the required episode end line."""

      rewards_text = ",".join(f"{reward:.2f}" for reward in rewards)
      print(
+         f"[END] success={str(success).lower()} steps={steps} rewards={rewards_text} final_score={final_score:.4f}",
          flush=True,
      )
294
  ) -> Tuple[str, str]:
295
  """Build exact-state and generalized problem-pattern keys."""
296
 
297
+ dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
298
+ table = list(dataset.get("modified", []))
299
+ normalized_issues = [
300
+ f"rows={len(table)}",
301
+ f"history={len(observation.get('action_history', []))}",
302
+ f"iter={observation.get('current_iteration_score', 0.0)}",
303
+ ]
304
  state_key = _hash_key(
305
  {
306
  "task_name": task_name,
 
310
  {key: row.get(key) for key in sorted(row.keys())}
311
  for row in sorted(table, key=lambda row: int(row.get("row_id", 0)))
312
  ],
313
+ "issues": sorted(normalized_issues),
314
  }
315
  )
316
  pattern_key = _hash_key(
 
318
  "task_name": task_name,
319
  "goal": goal,
320
  "summary": _table_summary(table),
321
+ "issues": sorted(normalized_issues),
322
  }
323
  )
324
  return state_key, pattern_key
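
> Note: `_hash_key` itself is not shown in this diff. As a hedged sketch of what a stable
> key builder like it typically looks like (names and digest choice are assumptions, not
> the module's actual code), canonical JSON plus a cryptographic digest gives
> order-insensitive, collision-resistant keys for the memory cache:

```python
import hashlib
import json

def hash_key(payload):
    """Stable digest of a JSON-serializable payload, usable as a cache key.

    sort_keys=True makes dict ordering irrelevant; default=str keeps
    non-JSON values (e.g. dates) from raising instead of hashing.
    """
    blob = json.dumps(payload, sort_keys=True, separators=(",", ":"), default=str)
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()
```

Because the keys are content-addressed, two episodes that reach the same normalized state map to the same memory entry regardless of dict insertion order.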
 

      action_type = str(payload["action_type"])
      args: List[str] = []
+     for key in ("record_id", "field", "value", "confidence"):
          if key not in payload or payload[key] is None:
              continue
          value = payload[key]
 
      try:
          expression = ast.parse(action_str, mode="eval").body
      except SyntaxError:
+         return FALLBACK_ACTION, {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}

      if not isinstance(expression, ast.Call) or not isinstance(expression.func, ast.Name):
+         return FALLBACK_ACTION, {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}

      allowed_actions = {
+         "detect_issue",
+         "fix_value",
+         "cannot_determine",
+         "skip",
      }
      action_type = expression.func.id
      if action_type not in allowed_actions:
+         return FALLBACK_ACTION, {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}

      payload: Dict[str, Any] = {
          "action_type": action_type,
      }
      try:
              continue
          payload[keyword.arg] = ast.literal_eval(keyword.value)
      except (SyntaxError, ValueError, TypeError):
+         return FALLBACK_ACTION, {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}
+
+     payload.setdefault("record_id", "0")
+     payload.setdefault("field", "record")
+     payload.setdefault("confidence", 0.6 if action_type != "skip" else 0.0)

      return _build_action_string(payload), payload
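
> The `ast`-based parsing idea in this hunk can be exercised standalone. A minimal sketch
> of the same approach (function and constant names here are illustrative, not the
> module's own): parse the action string as a Python expression, accept only a whitelisted
> call by name, and recover keyword arguments with `ast.literal_eval` so no code executes.

```python
import ast

ALLOWED = {"detect_issue", "fix_value", "cannot_determine", "skip"}
FALLBACK = ("skip", {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0})

def parse_action(action_str):
    """Parse a call-style action string into (action_type, payload), falling back to skip."""
    try:
        expr = ast.parse(action_str, mode="eval").body
    except SyntaxError:
        return FALLBACK
    # Only bare calls like fix_value(...) are accepted; attributes/lambdas are rejected.
    if not isinstance(expr, ast.Call) or not isinstance(expr.func, ast.Name):
        return FALLBACK
    if expr.func.id not in ALLOWED:
        return FALLBACK
    payload = {"action_type": expr.func.id}
    try:
        for kw in expr.keywords:
            if kw.arg is None:  # ignore **kwargs splats
                continue
            # literal_eval only evaluates constants, never arbitrary expressions.
            payload[kw.arg] = ast.literal_eval(kw.value)
    except (SyntaxError, ValueError, TypeError):
        return FALLBACK
    payload.setdefault("record_id", "0")
    payload.setdefault("field", "record")
    return expr.func.id, payload
```

Using `ast.literal_eval` on keyword values (rather than `eval`) is what makes this safe against model outputs that embed arbitrary code.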

      summary = ", ".join(
          f"{key}={value}"
          for key, value in row.items()
+         if key in {"row_id", "name", "city", "email", "phone", "status", "customer_id", "vendor_id", "partner_id", "age", "start_date", "end_date"}
      )
      preview_lines.append(f"- {summary}")
      return "\n".join(preview_lines) if preview_lines else "- None"
560
  ) -> str:
561
  """Construct a compact prompt that constrains the model to useful actions."""
562
 
563
+ dataset = observation.get("dataset", {})
564
+ modified = dataset.get("modified", [])
 
 
565
  candidates_text = "\n".join(f"- {action}" for action in candidate_actions)
566
  blocked_text = "\n".join(f"- {action}" for action in blocked_actions[:5]) if blocked_actions else "- None"
567
 
 
570
  Step: {step}
571
  Goal: {goal}
572
  Steps remaining: {observation.get("steps_remaining")}
573
+ Current iteration score: {observation.get("current_iteration_score")}
574
+ Previous iteration score: {observation.get("previous_iteration_score")}
575
+ Per-record scores: {observation.get("per_record_scores")}
 
 
576
  Table preview:
577
+ {_table_preview(modified)}
578
  Recent history:
579
  {build_history_lines(history)}
580
  Last action error: {last_error or "null"}
 
      action_text = _build_action_string(candidate)
      if action_text not in blocked_actions:
          return dict(candidate)
+     return {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}


+ def _record_id(row: Mapping[str, Any]) -> str:
+     rid = row.get("row_id")
+     return str(rid) if rid is not None else "0"


+ def _issue_like_candidates(table: Sequence[Mapping[str, Any]]) -> List[Dict[str, Any]]:
+     """Generate issue detection/fix candidates for new semantic action schema."""

      actions: List[Dict[str, Any]] = []
      for row in table:
+         rid = _record_id(row)
+         for field, value in row.items():
+             if field == "row_id":
+                 continue
+             if _is_missing(value) or str(value).strip().lower() in {"unknown", "9999"}:
+                 actions.append(
+                     {"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.85}
+                 )
                  actions.append(
                      {
+                         "action_type": "fix_value",
+                         "record_id": rid,
+                         "field": field,
+                         "value": _infer_fill_value(row, field, table),
+                         "confidence": 0.75,
                      }
                  )
+             elif field == "email" and not _is_valid_email(value):
+                 fixed = str(value).replace("[at]", "@").replace(" at ", "@").replace(" ", "")
+                 if "@" in fixed and "." not in fixed.split("@")[-1]:
+                     fixed += ".com"
+                 actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.85})
+                 actions.append({"action_type": "fix_value", "record_id": rid, "field": field, "value": fixed, "confidence": 0.8})
+             elif field == "phone" and not _is_valid_phone(value):
+                 digits = re.sub(r"\D", "", str(value))
+                 if len(digits) == 10:
+                     fixed = f"{digits[0:3]}-{digits[3:6]}-{digits[6:10]}"
+                     actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.8})
+                     actions.append({"action_type": "fix_value", "record_id": rid, "field": field, "value": fixed, "confidence": 0.75})
+             elif field in {"start_date", "end_date"}:
+                 start = row.get("start_date")
+                 end = row.get("end_date")
+                 if start and end and str(end) < str(start):
+                     actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.8})
+                     actions.append({"action_type": "cannot_determine", "record_id": rid, "field": field, "confidence": 0.7})
+             elif field == "age":
+                 try:
+                     age = int(value)
+                 except Exception:
+                     age = -1
+                 if age < 0 or age > 120:
+                     actions.append({"action_type": "detect_issue", "record_id": rid, "field": field, "confidence": 0.9})
+                     actions.append({"action_type": "cannot_determine", "record_id": rid, "field": field, "confidence": 0.8})
      return actions
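
> The phone branch above normalizes by stripping non-digits and re-grouping. A minimal
> standalone sketch of that transformation (the helper name is illustrative; the hunk
> inlines the same logic):

```python
import re

def normalize_phone(raw):
    """Format a 10-digit phone number as XXX-XXX-XXXX; return None if not recoverable."""
    digits = re.sub(r"\D", "", str(raw))  # drop parentheses, spaces, dashes, etc.
    if len(digits) != 10:
        return None  # too few/many digits: abstain rather than guess
    return f"{digits[0:3]}-{digits[3:6]}-{digits[6:10]}"
```

Returning `None` for anything that is not exactly 10 digits mirrors the benchmark's abstention stance: an 11-digit or truncated number becomes a `cannot_determine` case instead of a fabricated fix.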


+ def _detected_keys_from_history(action_history: Sequence[Mapping[str, Any]]) -> set[str]:
+     """Extract previously detected issue keys from observation history."""
+
+     keys: set[str] = set()
+     for action in action_history:
+         if action.get("action_type") != "detect_issue":
+             continue
+         keys.add(f"{action.get('record_id')}::{action.get('field')}")
+     return keys


  def propose_candidate_actions(
 
  ) -> List[Dict[str, Any]]:
      """Generate ranked candidate actions from visible table state."""

+     dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
+     table = list(dataset.get("modified", []))
+     detected_keys = _detected_keys_from_history(observation.get("action_history", []))
+     raw_candidates = _issue_like_candidates(table)
+     candidates: List[Dict[str, Any]] = []
+     for candidate in raw_candidates:
+         if candidate.get("action_type") == "detect_issue":
+             key = f"{candidate.get('record_id')}::{candidate.get('field')}"
+             # Detect once; then prefer follow-up actions.
+             if key in detected_keys:
+                 continue
+         candidates.append(candidate)
+     candidates += [
+         {"action_type": "skip", "record_id": "0", "field": "record", "confidence": 0.0}
+     ]

      unique_candidates: List[Dict[str, Any]] = []
      seen: set[str] = set()
          for candidate in unique_candidates
          if _build_action_string(candidate) != preferred_text
      ]
+     return ordered[:12]


  def _order_candidates_with_memory(
 
      memory: PolicyMemory,
      state_key: str,
      pattern_key: str,
+     recent_history: Sequence[str],
  ) -> List[Dict[str, Any]]:
      """Re-rank candidates using persistent cross-episode memory."""

      scored = []
+     recent_action_counts = Counter()
+     for item in recent_history[-5:]:
+         try:
+             parsed = item.split(" action=", 1)[1].split(" reward=", 1)[0].strip()
+             if parsed:
+                 recent_action_counts[parsed] += 1
+         except Exception:
+             continue
+
      for index, candidate in enumerate(candidates):
          action_text = _build_action_string(candidate)
+         repeat_penalty = recent_action_counts.get(action_text, 0) * 2.0
          scored.append(
              (
+                 -memory.score_action(state_key, pattern_key, action_text) + repeat_penalty,
                  index,
                  dict(candidate),
              )
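
> The re-ranking change above subtracts a penalty for actions chosen in the last few
> steps so the agent stops looping on a no-op fix. A self-contained sketch of the same
> scheme (names and the `penalty=2.0` weight mirror the hunk, but this is an
> illustration, not the module's `_order_candidates_with_memory`):

```python
from collections import Counter

def rerank(candidates, base_scores, recent_actions, penalty=2.0):
    """Sort action strings by memory score minus a repetition penalty.

    candidates: list of action strings to order.
    base_scores: action -> learned score (higher is better).
    recent_actions: last chosen actions; only the final 5 are counted.
    """
    counts = Counter(recent_actions[-5:])

    def key(action):
        # Negate so that sorted() puts the best (highest effective score) first.
        return -(base_scores.get(action, 0.0) - counts[action] * penalty)

    return sorted(candidates, key=key)
```

With `penalty=2.0`, a single repetition outweighs most learned score differences, which is the intended behavior: one wasted step is enough evidence to try something else.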
 
      memory_blocked = memory.blocked_actions(state_key, pattern_key)
      combined_blocked = set(blocked_actions) | set(memory_blocked)
      candidates = propose_candidate_actions(observation, combined_blocked)
+     candidates = _order_candidates_with_memory(
+         candidates, memory, state_key, pattern_key, history
+     )
      heuristic_candidate = candidates[0]
      heuristic_text = _build_action_string(heuristic_candidate)
      candidate_texts = [_build_action_string(candidate) for candidate in candidates]
          blocked_actions=sorted(combined_blocked),
      )

+     if model_text not in candidate_texts:
+         model_text = None
      chosen_text = model_text or heuristic_text
      normalized_text, payload = action_string_to_payload(chosen_text, step_number)
      if normalized_text in combined_blocked:
 
      last_error: Optional[str] = None
      final_score = 0.0
      task_variant = "unknown"
+     action_repeat_counts: Dict[str, int] = defaultdict(int)
+     no_change_counts: Dict[str, int] = defaultdict(int)

      try:
          log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
              task_name=task_name,
              task_variant=task_variant,
              observation=observation,
+             goal=str(env.state().get("task", {}).get("goal", "")),
              step_number=step_number,
              history=history,
              last_error=last_error,
          )

          try:
+             before_dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
+             before_modified = before_dataset.get("modified", [])
              observation_model, reward, done, info = env.step(action_payload)
              observation = observation_model.model_dump()
+             after_dataset = observation.get("dataset", {}) if isinstance(observation, dict) else {}
+             after_modified = after_dataset.get("modified", [])
              result = info.get("result", {})
+             curr_iter = float(observation.get("current_iteration_score", 0.0))
+             prev_iter = float(observation.get("previous_iteration_score", 0.0))
+             progress_delta = max(0.0, curr_iter - prev_iter)
+             error_value = "step_error" if (
+                 result.get("wrong_fix")
+                 or result.get("hallucinated_fix")
+                 or result.get("wrong_cannot_determine")
+                 or result.get("classification_incorrect")
+             ) else None
+             final_score = float(info.get("final_task_score", 0.0))
              if error_value == "general":
                  error_value = None
              memory.update(
              )
              if error_value or progress_delta == 0.0 or reward <= 0.0:
                  blocked_actions.add(action_text)
+                 action_repeat_counts[action_text] += 1
+                 if action_repeat_counts[action_text] > 2:
+                     blocked_actions.add(action_text)
+             if _stable_json(before_modified) == _stable_json(after_modified):
+                 no_change_counts[action_text] += 1
+                 if no_change_counts[action_text] >= 2:
+                     blocked_actions.add(action_text)
+             else:
+                 no_change_counts[action_text] = 0
          except Exception as exc:  # noqa: BLE001
              reward = 0.0
              done = True
          )

          if done:
+             success = bool(final_score > 0.0)
              break
  finally:
      memory.save()
      close_method = getattr(env, "close", None)
      if callable(close_method):
          close_method()
+     log_end(success=success, steps=steps_taken, rewards=rewards, final_score=final_score)
      return final_score
  def main() -> None:
+     """Run one configured task or all tasks in deterministic order."""

      client = create_client()
      memory = PolicyMemory(POLICY_CACHE_PATH)
+     task_name = str(TASK_NAME).strip().lower()
+     if task_name in {"all", "*"}:
+         for task_index, ordered_task in enumerate(TASK_ORDER):
+             run_episode(client=client, memory=memory, task_name=ordered_task, seed=task_index)
+         return
+     run_episode(client=client, memory=memory, task_name=task_name, seed=0)


  if __name__ == "__main__":
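
> The new `main()` dispatches on a task-name setting: `"all"` or `"*"` runs every task in
> order, anything else runs a single named episode. A minimal sketch of just that
> dispatch rule (the helper name `select_tasks` is illustrative; the hunk inlines the
> logic inside `main()`):

```python
from typing import List, Optional

TASK_ORDER = ["easy", "medium", "hard"]

def select_tasks(raw: Optional[str]) -> List[str]:
    """Resolve a task-name setting into the ordered list of tasks to run."""
    name = (raw or "all").strip().lower()  # unset behaves like "all"
    if name in {"all", "*"}:
        return list(TASK_ORDER)
    return [name]
```

Normalizing with `.strip().lower()` first means env-var values like `" Medium "` still select the intended task instead of silently failing.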
server/app.py CHANGED
@@ -6,6 +6,7 @@ deployment-facing application setup for the environment.

  from __future__ import annotations

  import logging
  import os
  from pathlib import Path
@@ -14,7 +15,8 @@ from threading import RLock
  from typing import Any, Dict, Optional

  from fastapi import FastAPI, HTTPException, Request
- from fastapi.responses import JSONResponse, RedirectResponse
  from pydantic import BaseModel, Field
  import uvicorn

@@ -30,7 +32,28 @@ from models import Action
  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

- app = FastAPI(title="dataops-env", version="1.0.0")
  active_env: Optional[DataOpsEnv] = None
  active_env_lock = RLock()

@@ -39,12 +62,20 @@ class ResetRequest(BaseModel):
      """Optional reset controls for reproducible task selection."""

      seed: int = Field(default=0, description="Deterministic seed for task sampling.")
-     task_name: str | None = Field(
          default=None,
          description="Optional fixed task name: easy, medium, or hard.",
      )


  @app.exception_handler(Exception)
  async def unhandled_exception_handler(
      request: Request, exc: Exception
@@ -65,6 +96,111 @@ def root() -> RedirectResponse:
      return RedirectResponse(url="/docs")


  @app.get("/health")
  def health() -> Dict[str, str]:
      """Return a lightweight deployment health signal."""
@@ -78,7 +214,10 @@ def reset(payload: ResetRequest | None = None) -> Dict[str, Any]:

      try:
          request = payload or ResetRequest()
-         env = DataOpsEnv(seed=request.seed, task_name=request.task_name)
          observation = env.reset()

          global active_env
@@ -90,6 +229,8 @@ def reset(payload: ResetRequest | None = None) -> Dict[str, Any]:
          "task_name": env.state().get("task_name"),
          "observation": observation.model_dump(),
      }
      except Exception as exc:
          logger.exception("Failed to reset environment", exc_info=exc)
          raise HTTPException(status_code=500, detail="Failed to reset environment") from exc
 
  from __future__ import annotations

+ from enum import Enum
  import logging
  import os
  from pathlib import Path
  from typing import Any, Dict, Optional

  from fastapi import FastAPI, HTTPException, Request
+ from fastapi.openapi.docs import get_swagger_ui_html
+ from fastapi.responses import HTMLResponse, JSONResponse, RedirectResponse
  from pydantic import BaseModel, Field
  import uvicorn

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

+ app = FastAPI(
+     title="dataops-env",
+     version="1.0.0",
+     summary="Reasoning-first semantic data cleaning benchmark.",
+     description=(
+         "### DataOps Gym: Clean Data, Keep Truth\n"
+         "A step-based evaluation environment for testing whether agents can detect issues, "
+         "fix only with evidence, abstain with `cannot_determine` under ambiguity, and stay "
+         "consistent across related records.\n\n"
+         "**Tagline:** *Fix data without fabricating reality.*\n\n"
+         "#### Why this API matters\n"
+         "- Strict JSON action schema (no free-form outputs)\n"
+         "- Reward shaping that penalizes hallucinations and over-correction\n"
+         "- Cross-record consistency and uncertainty-aware scoring\n"
+     ),
+     contact={
+         "name": "DataOps Gym",
+         "url": "https://github.com/graheetphartyal23/Dataops--GYM",
+     },
+     docs_url=None,
+ )
  active_env: Optional[DataOpsEnv] = None
  active_env_lock = RLock()

      """Optional reset controls for reproducible task selection."""

      seed: int = Field(default=0, description="Deterministic seed for task sampling.")
+     task_name: "TaskName | None" = Field(
          default=None,
          description="Optional fixed task name: easy, medium, or hard.",
      )


+ class TaskName(str, Enum):
+     """Allowed benchmark task names."""
+
+     EASY = "easy"
+     MEDIUM = "medium"
+     HARD = "hard"


  @app.exception_handler(Exception)
  async def unhandled_exception_handler(
      request: Request, exc: Exception

      return RedirectResponse(url="/docs")


+ @app.get("/docs", include_in_schema=False)
+ def custom_docs() -> HTMLResponse:
+     """Serve Swagger UI with a dark theme override."""
+
+     swagger = get_swagger_ui_html(
+         openapi_url=app.openapi_url,
+         title=f"{app.title} - API Docs",
+         swagger_ui_parameters={
+             "syntaxHighlight.theme": "obsidian",
+             "displayRequestDuration": True,
+         },
+     )
+     dark_css = """
+     <style>
+     html, body { background: #0b1020 !important; color: #e5e7eb !important; }
+     .swagger-ui, .swagger-ui .topbar { background: #0b1020 !important; }
+     .swagger-ui .topbar { border-bottom: 1px solid #1f2937 !important; }
+     .swagger-ui .topbar a, .swagger-ui .topbar span { color: #e5e7eb !important; }
+
+     /* Keep top API details readable: white card + black text */
+     .swagger-ui .info {
+         background: #ffffff !important;
+         color: #111827 !important;
+         border: 1px solid #e5e7eb !important;
+         border-radius: 12px !important;
+         padding: 18px !important;
+         margin: 18px 0 24px 0 !important;
+     }
+     .swagger-ui .info .title, .swagger-ui .info h1, .swagger-ui .info h2,
+     .swagger-ui .info h3, .swagger-ui .info p, .swagger-ui .info li,
+     .swagger-ui .info a, .swagger-ui .info .base-url, .swagger-ui .info .version {
+         color: #111827 !important;
+     }
+     .swagger-ui .info ul { margin: 10px 0 0 18px !important; }
+
+     /* Default + Schemas sections as white cards with black text */
+     .swagger-ui .opblock-tag {
+         background: #ffffff !important;
+         color: #111827 !important;
+         border: 1px solid #e5e7eb !important;
+         border-radius: 10px !important;
+         padding: 10px 12px !important;
+         margin-bottom: 12px !important;
+     }
+     .swagger-ui .opblock {
+         background: #ffffff !important;
+         border: 1px solid #e5e7eb !important;
+         border-radius: 10px !important;
+         margin: 0 0 14px 0 !important;
+         box-shadow: 0 2px 10px rgba(0, 0, 0, 0.25) !important;
+     }
+     .swagger-ui .opblock .opblock-summary {
+         background: #ffffff !important;
+         border-bottom: 1px solid #e5e7eb !important;
+     }
+     .swagger-ui .opblock .opblock-summary-method,
+     .swagger-ui .opblock .opblock-summary-path,
+     .swagger-ui .opblock .opblock-summary-path__deprecated,
+     .swagger-ui .opblock .opblock-summary-description {
+         color: #111827 !important;
+         fill: #111827 !important;
+     }
+     .swagger-ui .opblock-section-header,
+     .swagger-ui .responses-inner h4,
+     .swagger-ui .responses-inner h5,
+     .swagger-ui .tab li,
+     .swagger-ui .parameter__type,
+     .swagger-ui .model-title,
+     .swagger-ui .models h4 {
+         color: #111827 !important;
+     }
+     .swagger-ui .models {
+         background: #ffffff !important;
+         border: 1px solid #e5e7eb !important;
+         border-radius: 10px !important;
+         padding: 8px !important;
+     }
+     .swagger-ui .model-container, .swagger-ui .model-box {
+         background: #ffffff !important;
+         color: #111827 !important;
+         border-color: #e5e7eb !important;
+     }
+     .swagger-ui .model, .swagger-ui .prop-name, .swagger-ui .prop-type, .swagger-ui .prop-format {
+         color: #111827 !important;
183
+ }
184
+ .swagger-ui .response-col_status, .swagger-ui .response-col_description,
185
+ .swagger-ui label, .swagger-ui .parameter__name,
186
+ .swagger-ui table tbody tr td, .swagger-ui .responses-table, .swagger-ui .parameters-col_description {
187
+ color: #111827 !important;
188
+ background: #ffffff !important;
189
+ border-color: #e5e7eb !important;
190
+ }
191
+ .swagger-ui input, .swagger-ui textarea, .swagger-ui select {
192
+ background: #0f172a !important;
193
+ color: #e5e7eb !important;
194
+ border-color: #374151 !important;
195
+ }
196
+ .swagger-ui .btn.execute { background: #2563eb !important; color: white !important; }
197
+ .swagger-ui .btn { border-color: #4b5563 !important; }
198
+ </style>
199
+ """
200
+ html = swagger.body.decode("utf-8").replace("</head>", f"{dark_css}</head>")
201
+ return HTMLResponse(content=html, status_code=200)
202
+
203
+
204
  @app.get("/health")
205
  def health() -> Dict[str, str]:
206
  """Return a lightweight deployment health signal."""
 
214
 
215
  try:
216
  request = payload or ResetRequest()
217
+ env = DataOpsEnv(
218
+ seed=request.seed,
219
+ task_name=request.task_name.value if request.task_name is not None else None,
220
+ )
221
  observation = env.reset()
222
 
223
  global active_env
 
229
  "task_name": env.state().get("task_name"),
230
  "observation": observation.model_dump(),
231
  }
232
+ except ValueError as exc:
233
+ raise HTTPException(status_code=422, detail=str(exc)) from exc
234
  except Exception as exc:
235
  logger.exception("Failed to reset environment", exc_info=exc)
236
  raise HTTPException(status_code=500, detail="Failed to reset environment") from exc
task.py CHANGED
@@ -8,6 +8,7 @@ broader and less gameable.
 
 from __future__ import annotations
 
+from copy import deepcopy
 from typing import Any, Dict, List, TypedDict
 
 
@@ -43,7 +44,23 @@ def _pick_variant(variant: int | None, variants: List[TaskDefinition]) -> TaskDe
     """Select a deterministic task variant with a stable default."""
 
     index = 0 if variant is None else max(0, min(len(variants) - 1, int(variant)))
-    return variants[index]
+    selected = deepcopy(variants[index])
+    selected["hidden_issues"] = _with_fixable_flags(selected.get("hidden_issues", []))
+    return selected
+
+
+def _with_fixable_flags(hidden_issues: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    """Ensure each hidden issue carries an explicit ``fixable`` flag."""
+
+    enriched: List[Dict[str, Any]] = []
+    for issue in hidden_issues:
+        issue_copy = dict(issue)
+        if "fixable" not in issue_copy:
+            issue_type = issue_copy.get("type")
+            # Structural conflicts usually require judgment across rows.
+            issue_copy["fixable"] = issue_type not in {"duplicate", "conflict", "constraint_violation"}
+        enriched.append(issue_copy)
+    return enriched
 
 
 def easy_cleaning_task(variant: int | None = None) -> TaskDefinition:
@@ -320,8 +337,8 @@ def hard_conflict_resolution_task(variant: int | None = None) -> TaskDefinition:
             "initial_table": [
                 {"row_id": 21, "customer_id": "C200", "name": "Nina Patel", "email": "nina.patel@example.com", "phone": "206-555-0101", "status": "active"},
                 {"row_id": 22, "customer_id": "C200", "name": "Nina Patel", "email": "nina.patel@example.com", "phone": "206-555-0101", "status": "active"},
-                {"row_id": 23, "customer_id": "C201", "name": "Evan Cole", "email": "evan.cole@example", "phone": "4155550102", "status": "active"},
-                {"row_id": 24, "customer_id": "C201", "name": "Evan Cole", "email": "evan.cole@example.com", "phone": "(415) 555-0102", "status": "inactive"},
+                {"row_id": 23, "customer_id": "C201", "name": "Evan Cole", "email": "evan.cole@example", "phone": "4155550102", "status": "active", "age": 250},
+                {"row_id": 24, "customer_id": "C201", "name": "Evan Cole", "email": "evan.cole@example.com", "phone": "(415) 555-0102", "status": "inactive", "age": 45},
                 {"row_id": 25, "customer_id": "C202", "name": "A. J. Brown", "email": "aj.brown@example.com", "phone": "+1-312-555-0103", "status": "active"},
                 {"row_id": 26, "customer_id": "C203", "name": "Marta Silva", "email": "shared@example.com", "phone": "646-555-0104", "status": "active"},
                 {"row_id": 27, "customer_id": "C204", "name": "Martin Silva", "email": "shared@example.com", "phone": "646-555-0105", "status": "active"},
@@ -336,6 +353,9 @@ def hard_conflict_resolution_task(variant: int | None = None) -> TaskDefinition:
                 {
                     "type": "conflict",
                     "rows": [23, 24],
+                    "field": "age",
+                    "values": [250, 45],
+                    "fixable": False,
                     "description": "Rows 23 and 24 conflict for the same customer and must be reconciled into one trustworthy record.",
                 },
                 {
@@ -399,8 +419,8 @@ def hard_conflict_resolution_task(variant: int | None = None) -> TaskDefinition:
             "initial_table": [
                 {"row_id": 51, "customer_id": "A900", "name": "Lena Brooks", "email": "lena.brooks@example.com", "phone": "212-555-0111", "status": "active"},
                 {"row_id": 52, "customer_id": "A900", "name": "Lena Brooks", "email": "lena.brooks@example.com", "phone": "212-555-0111", "status": "active"},
-                {"row_id": 53, "customer_id": "A901", "name": "Ravi Shah", "email": "ravi.shah example.com", "phone": "6465550112", "status": "active"},
-                {"row_id": 54, "customer_id": "A901", "name": "Ravi Shah", "email": "ravi.shah@example.com", "phone": "646-555-0112", "status": "inactive"},
+                {"row_id": 53, "customer_id": "A901", "name": "Ravi Shah", "email": "ravi.shah example.com", "phone": "6465550112", "status": "active", "age": 250},
+                {"row_id": 54, "customer_id": "A901", "name": "Ravi Shah", "email": "ravi.shah@example.com", "phone": "646-555-0112", "status": "inactive", "age": 45},
                 {"row_id": 55, "customer_id": "A902", "name": "M. E. Klein", "email": "mek@example.com", "phone": "+1-303-555-0113", "status": "active"},
                 {"row_id": 56, "customer_id": "A903", "name": "Sana Noor", "email": "ops@example.com", "phone": "718-555-0114", "status": "active"},
                 {"row_id": 57, "customer_id": "A904", "name": "Sana N.", "email": "ops@example.com", "phone": "718-555-0115", "status": "active"},
@@ -415,6 +435,9 @@ def hard_conflict_resolution_task(variant: int | None = None) -> TaskDefinition:
                 {
                     "type": "conflict",
                     "rows": [53, 54],
+                    "field": "age",
+                    "values": [250, 45],
+                    "fixable": False,
                    "description": "Rows 53 and 54 conflict for the same customer and must be reconciled into one trustworthy record.",
                 },
                 {