Spaces:
Sleeping
Sleeping
| """ | |
| DataQA Environment | |
| ------------------ | |
| Server-side environment for data quality assurance tasks. | |
| Two-phase RL environment: | |
| Phase 1 (Identify): Agent inspects corrupted datasets and reports quality issues. | |
| Phase 2 (Fix): Agent proposes corrections for identified issues. | |
| Combined reward = 0.6 * identify_score + 0.4 * fix_score | |
| Both phases scored with difficulty-weighted metrics for rich per-step signal. | |
| """ | |
| from __future__ import annotations | |
| import re | |
| import uuid | |
| from typing import Any, Optional, Set | |
| from openenv.core.env_server.interfaces import Action, Environment, Observation | |
| from ..models import DataQAAction, DataQAObservation, DataQAState | |
| from .tasks import PlantedIssue, Task, get_task, list_tasks | |
# Reward weights for the two phases. When an action carries fixes, step()
# blends the scores as
#     IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
# With no fixes submitted, the identify score is used on its own.
IDENTIFY_WEIGHT: float = 0.6
FIX_WEIGHT: float = 0.4
# Clamp reward to the strict open interval (0, 1) — validators reject
# exactly 0.0 and 1.0.
REWARD_MIN = 0.001
REWARD_MAX = 0.999


def _clamp_reward(r: float) -> float:
    """Clamp *r* into the closed range ``[REWARD_MIN, REWARD_MAX]``.

    Args:
        r: Raw reward value.

    Returns:
        ``r`` limited to the allowed range. NaN is mapped to ``REWARD_MIN``:
        a plain ``max(min(...))`` would let NaN come out as ``REWARD_MAX``
        (every comparison against NaN is false), turning an undefined score
        into the best possible reward.
    """
    if r != r:  # NaN is the only float that is not equal to itself
        return REWARD_MIN
    return max(REWARD_MIN, min(REWARD_MAX, r))
def parse_issue_key(raw: str) -> Optional[str]:
    """Parse an agent-reported issue string into a normalized key.

    Expected format: ``row:<N>,col:<name>,issue:<type>``. Matching is
    case-insensitive, ``=`` is accepted in place of ``:``, whitespace around
    the separator is ignored, and ``column`` is accepted as an alias for
    ``col`` (the same spelling ``parse_fix`` accepts).

    Args:
        raw: The issue string as reported by the agent.

    Returns:
        ``"row:<N>,col:<name>,issue:<type>"`` in lower case with leading
        zeros removed from the row number, or ``None`` when any of the three
        fields is missing.
    """
    text = raw.strip().lower()
    row_match = re.search(r"row\s*[:=]\s*(\d+)", text)
    col_match = re.search(r"col(?:umn)?\s*[:=]\s*(\w+)", text)
    issue_match = re.search(r"issue\s*[:=]\s*(\w+)", text)
    if not (row_match and col_match and issue_match):
        return None

    # Canonicalize the row number so "row:05" and "row:5" yield the same key.
    # Done on the string (not via int()) so arbitrarily long digit runs in
    # agent input cannot raise.
    row = row_match.group(1).lstrip("0") or "0"
    return f"row:{row},col:{col_match.group(1)},issue:{issue_match.group(1)}"
def parse_fix(raw: str) -> Optional[tuple[int, str, str]]:
    """Parse an agent-proposed fix into ``(row, col, proposed_value)``.

    Expected format: ``row:<N>,col:<name>,fix:<value>``. Field names are
    case-insensitive, ``=`` is accepted in place of ``:``, and ``column`` is
    accepted as an alias for ``col``. Everything after ``fix:`` up to the end
    of the string is the proposed value (its case is preserved).

    Args:
        raw: The fix string as submitted by the agent.

    Returns:
        ``(row, lower-cased column name, stripped value)``, or ``None`` when
        the string cannot be parsed.
    """
    text = raw.strip()
    fix_match = re.search(r"fix\s*[:=]\s*(.+?)$", text, re.IGNORECASE)
    if fix_match is None:
        return None

    row_match = re.search(r"row\s*[:=]\s*(\d+)", text, re.IGNORECASE)
    col_match = re.search(r"col(?:umn)?\s*[:=]\s*(\w+)", text, re.IGNORECASE)
    if row_match is None or col_match is None:
        return None

    # The cell address must precede the value. If the first "row:"/"col:"
    # token only shows up inside the proposed value, the real address is
    # missing and the fix must not be attributed to whatever the value
    # happens to mention.
    value_start = fix_match.start()
    if row_match.start() > value_start or col_match.start() > value_start:
        return None

    try:
        row = int(row_match.group(1))
    except ValueError:
        # Extremely long digit runs can exceed the interpreter's int
        # conversion limit; treat them as unparseable instead of crashing.
        return None

    return (row, col_match.group(1).lower(), fix_match.group(1).strip())
def compute_f1(reported_keys: Set[str], planted_keys: Set[str]) -> dict:
    """Return precision, recall, F1 and the raw tp/fp/fn counts.

    Degenerate inputs are handled up front: two empty sets count as a perfect
    result (all scores 1.0), while exactly one empty set scores 0.0 and
    reports the size of the other set as false positives / false negatives.
    """
    if not reported_keys or not planted_keys:
        nothing_to_find = not reported_keys and not planted_keys
        flat_score = 1.0 if nothing_to_find else 0.0
        return {
            "precision": flat_score,
            "recall": flat_score,
            "f1": flat_score,
            "tp": 0,
            "fp": len(reported_keys) if reported_keys else 0,
            "fn": len(planted_keys) if planted_keys else 0,
        }

    hits = len(reported_keys & planted_keys)
    spurious = len(reported_keys - planted_keys)
    overlooked = len(planted_keys - reported_keys)

    claimed = hits + spurious
    expected = hits + overlooked
    precision = hits / claimed if claimed > 0 else 0.0
    recall = hits / expected if expected > 0 else 0.0

    combined = precision + recall
    if combined > 0:
        f1 = 2 * precision * recall / combined
    else:
        f1 = 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "tp": hits,
        "fp": spurious,
        "fn": overlooked,
    }
def compute_weighted_reward(
    reported_keys: Set[str],
    planted_issues: list,
) -> dict:
    """Score identification with each planted issue weighted by difficulty.

    Recall is the share of total planted difficulty that was found. For
    precision, every false positive costs the average planted difficulty.
    The reward is the harmonic mean of those two quantities.

    Returns a dict with ``weighted_reward`` (rounded to 4 places) plus
    ``difficulty_found`` and ``difficulty_missed`` (rounded to 2 places on
    the main path).
    """
    # Nothing planted and nothing reported: a clean dataset handled correctly.
    if not planted_issues and not reported_keys:
        return {"weighted_reward": 1.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}

    issue_for_key = {}
    for planted in planted_issues:
        issue_for_key[planted.to_key()] = planted
    planted_keys = set(issue_for_key)

    if not reported_keys:
        everything_missed = sum(planted.difficulty for planted in planted_issues)
        return {
            "weighted_reward": 0.0,
            "difficulty_found": 0.0,
            "difficulty_missed": everything_missed,
        }
    if not planted_keys:
        return {"weighted_reward": 0.0, "difficulty_found": 0.0, "difficulty_missed": 0.0}

    found_keys = reported_keys & planted_keys
    missed_keys = planted_keys - reported_keys
    spurious_reports = len(reported_keys - planted_keys)

    difficulty_found = sum(issue_for_key[key].difficulty for key in found_keys)
    difficulty_missed = sum(issue_for_key[key].difficulty for key in missed_keys)
    total_weight = sum(planted.difficulty for planted in planted_issues)

    if total_weight > 0:
        weighted_recall = difficulty_found / total_weight
    else:
        weighted_recall = 0.0

    # Each false positive is charged the average difficulty of a planted issue.
    penalty = spurious_reports * (total_weight / len(planted_issues))
    precision_denominator = difficulty_found + penalty
    if precision_denominator > 0:
        weighted_precision = difficulty_found / precision_denominator
    else:
        weighted_precision = 0.0

    harmonic_denominator = weighted_precision + weighted_recall
    if harmonic_denominator > 0:
        weighted_reward = 2 * weighted_precision * weighted_recall / harmonic_denominator
    else:
        weighted_reward = 0.0

    return {
        "weighted_reward": round(weighted_reward, 4),
        "difficulty_found": round(difficulty_found, 2),
        "difficulty_missed": round(difficulty_missed, 2),
    }
def _score_fix(issue_type: str, proposed: str, clean_value: str) -> tuple[float, str, str]:
    """Score one proposed value for a cell that is known to hold a planted issue.

    Args:
        issue_type: Type of the planted issue on the targeted cell.
        proposed: Value proposed by the agent.
        clean_value: Ground-truth value of the cell.

    Returns:
        ``(score, reason, bucket)`` where ``bucket`` is ``"correct"``,
        ``"partial"`` or ``"wrong"`` and selects the counter the caller
        increments.
    """
    import math

    candidate = proposed.strip()
    target = clean_value.strip()

    # Exact match: case-insensitive, whitespace-stripped on BOTH sides so a
    # padded clean value can still be matched.
    if candidate.lower() == target.lower():
        return 1.0, "exact match", "correct"

    if issue_type == "missing_value":
        # Any non-empty value is accepted as a reasonable fill-in.
        if candidate:
            return 0.8, "valid fix (non-empty value for missing field)", "partial"
        return 0.0, "fix is still empty", "wrong"

    if issue_type == "wrong_type":
        # The corrupted cell held text where a number belongs. "nan"/"inf"
        # parse as floats but are not usable numbers, so they do not count.
        try:
            is_number = math.isfinite(float(candidate))
        except ValueError:
            is_number = False
        if is_number:
            return 0.8, "valid fix (correct type)", "partial"
        return 0.1, "fix is still wrong type", "partial"

    if issue_type == "out_of_range":
        try:
            proposed_num = float(candidate)
            clean_num = float(target)
        except ValueError:
            return 0.1, "correct cell, wrong value", "partial"
        # Within 50% of the clean value counts as a good estimate.
        if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.5:
            return 0.8, "valid fix (in reasonable range)", "partial"
        # Otherwise credit a positive proposal for a positive clean value.
        if proposed_num > 0 and clean_num > 0:
            return 0.4, "partially valid (right direction)", "partial"
        return 0.1, "fix still out of reasonable range", "partial"

    if issue_type == "format_violation":
        # Values starting with a YYYY-MM-DD date are treated as well-formed.
        if re.match(r"\d{4}-\d{2}-\d{2}", candidate):
            return 0.8, "valid fix (correct format)", "partial"
        # An exact match was already ruled out, so any non-empty value is a
        # differing attempt.
        if candidate:
            return 0.4, "fix attempted but format unclear", "partial"
        return 0.1, "correct cell, wrong value", "partial"

    if issue_type in ("inconsistent_value", "statistical_outlier"):
        try:
            proposed_num = float(candidate)
            clean_num = float(target)
        except ValueError:
            # Text values cannot be verified automatically; a longer answer
            # is treated as a genuine attempt.
            if len(candidate) > 10:
                return 0.4, "text fix attempted (cannot verify automatically)", "partial"
            return 0.1, "correct cell, wrong value", "partial"
        if clean_num == 0:
            # Relative error is undefined against a zero clean value.
            return 0.4, "numeric fix attempted", "partial"
        pct_diff = abs(proposed_num - clean_num) / abs(clean_num)
        if pct_diff <= 0.01:
            return 1.0, "exact numeric match", "correct"
        if pct_diff <= 0.2:
            return 0.8, "valid fix (within 20% of correct value)", "partial"
        if pct_diff <= 0.5:
            return 0.4, "partially valid (right ballpark)", "partial"
        return 0.1, "correct cell, value not close", "partial"

    # Any other issue type: only a numeric near-match (within 1%) earns more
    # than the minimal "right cell" credit.
    try:
        proposed_num = float(candidate)
        clean_num = float(target)
    except ValueError:
        return 0.1, "correct cell, wrong value", "partial"
    if clean_num != 0 and abs(proposed_num - clean_num) / abs(clean_num) <= 0.01:
        return 0.8, "numeric close match", "partial"
    return 0.1, "correct cell, wrong value", "partial"


def grade_fixes(
    fixes: list[tuple[int, str, str]],
    task: "Task",
) -> dict:
    """Grade proposed fixes against the clean dataset.

    Each fix ``(row, col, proposed_value)`` is scored as follows:

    - targets a cell with no planted issue, or a cell whose clean value is
      unavailable: 0.0, counted as wrong;
    - exact match with the clean value (case-insensitive, whitespace
      stripped): 1.0, counted as correct;
    - otherwise graded by the planted issue's type on a 0.0 / 0.1 / 0.4 /
      0.8 / 1.0 scale (see ``_score_fix``).

    Only the best-scoring fix per cell contributes to ``fix_score``, which is
    the difficulty-weighted sum of those best scores divided by the total
    planted difficulty. The correct/partial/wrong counters count every
    submitted fix, not just the best one per cell.

    Args:
        fixes: Parsed fixes as ``(row, col, proposed_value)`` tuples.
        task: Task supplying ``planted_issues`` and ``get_clean_value``.

    Returns:
        Dict with ``fix_score`` (0.0-1.0, rounded to 4 places),
        ``fixes_correct``, ``fixes_partial``, ``fixes_wrong``,
        ``fixes_attempted`` and the per-fix ``fix_details`` list.
    """
    planted = task.planted_issues

    if not fixes:
        # Nothing to fix and nothing proposed is a perfect outcome; planted
        # issues left without any proposal score zero.
        return {"fix_score": 0.0 if planted else 1.0, "fixes_correct": 0, "fixes_partial": 0,
                "fixes_wrong": 0, "fixes_attempted": 0, "fix_details": []}

    # First planted issue per cell; doubles as the "is this an issue cell" test.
    issue_by_cell: dict = {}
    for issue in planted:
        issue_by_cell.setdefault((issue.row, issue.col), issue)

    total_weight = sum(i.difficulty for i in planted) if planted else 1.0

    counts = {"correct": 0, "partial": 0, "wrong": 0}
    fix_details = []
    # Best score seen so far for each issue cell (best fix wins).
    best_by_cell: dict[tuple[int, str], float] = {}

    for row, col, proposed in fixes:
        cell_key = (row, col)
        matching_issue = issue_by_cell.get(cell_key)
        if matching_issue is None:
            # Fix targets a non-issue cell: no credit.
            fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "not an issue cell"})
            counts["wrong"] += 1
            continue

        clean_value = task.get_clean_value(row, col)
        if clean_value is None:
            fix_details.append({"row": row, "col": col, "score": 0.0, "reason": "cell not found"})
            counts["wrong"] += 1
            continue

        score, reason, bucket = _score_fix(matching_issue.issue_type, proposed, str(clean_value))
        counts[bucket] += 1

        if cell_key not in best_by_cell or score > best_by_cell[cell_key]:
            best_by_cell[cell_key] = score
        fix_details.append({"row": row, "col": col, "score": score, "reason": reason})

    # Weighted sum of the best fix per planted issue over the total weight.
    earned_weight = 0.0
    for issue in planted:
        cell_key = (issue.row, issue.col)
        if cell_key in best_by_cell:
            earned_weight += issue.difficulty * best_by_cell[cell_key]

    fix_score = earned_weight / total_weight if total_weight > 0 else 0.0
    fix_score = min(max(fix_score, 0.0), 1.0)

    return {
        "fix_score": round(fix_score, 4),
        "fixes_correct": counts["correct"],
        "fixes_partial": counts["partial"],
        "fixes_wrong": counts["wrong"],
        "fixes_attempted": len(fixes),
        "fix_details": fix_details,
    }
class DataQAEnvironment(Environment):
    """Data Quality Assurance environment — two-phase identify + fix.

    Phase 1 (Identify): the agent inspects a corrupted dataset and reports
    quality issues.
    Phase 2 (Fix): the agent proposes corrections for identified issues.

    When an action carries fixes the step reward is
    ``IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score``; otherwise
    it is the identify score alone. The observation's ``reward`` is the best
    combined reward seen so far in the episode, clamped by ``_clamp_reward``.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Issue types for which a false positive is flagged for human review
    # instead of being silently treated as noise.
    _REVIEWABLE_ISSUE_TYPES = frozenset({
        "missing_value", "wrong_type", "duplicate_row", "out_of_range",
        "format_violation", "inconsistent_value", "statistical_outlier",
        "referential_integrity",
    })

    # Missed planted issues at or above this difficulty are flagged for review.
    _HARD_ISSUE_DIFFICULTY = 2.5

    def __init__(self):
        self._state = DataQAState()
        self._current_task: Optional[Task] = None
        self._planted_keys: Set[str] = set()
        self._best_score: float = 0.0

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        **kwargs: Any,
    ) -> Observation:
        """Start a new episode and return the initial observation.

        Args:
            seed: Seed handed to ``get_task``; 42 is used when omitted.
            episode_id: Episode identifier; a random UUID is generated when
                omitted.
            **kwargs: ``task_id`` selects the task (default ``"easy"``).

        Returns:
            The initial ``DataQAObservation`` with ``done=False``.
        """
        # `or` (rather than a .get default) so an explicit task_id=None, as
        # step()'s auto-reset may pass, still falls back to "easy".
        task_id = kwargs.get("task_id") or "easy"
        task_seed = seed if seed is not None else 42

        self._current_task = get_task(task_id, seed=task_seed)
        self._planted_keys = {issue.to_key() for issue in self._current_task.planted_issues}
        self._best_score = 0.0

        ep_id = episode_id or str(uuid.uuid4())
        self._state = DataQAState(
            episode_id=ep_id,
            step_count=0,
            task_id=task_id,
            current_step=0,
            max_steps=self._current_task.max_steps,
            best_score=0.0,
            total_planted_issues=len(self._current_task.planted_issues),
        )

        return DataQAObservation(
            dataset_csv=self._current_task.corrupted_csv,
            schema_description=self._current_task.schema_description,
            validation_rules=self._current_task.validation_rules,
            task_description=self._current_task.description,
            feedback=(
                "Environment reset. Inspect the dataset and report all quality issues.\n"
                "You can also propose fixes in format: row:<N>,col:<name>,fix:<corrected_value>\n"
                "Combined reward = 0.6 * identify_score + 0.4 * fix_score"
            ),
            task_id=task_id,
            num_issues_hint=len(self._current_task.planted_issues),
            max_steps=self._current_task.max_steps,
            done=False,
            reward=_clamp_reward(0.0),
        )

    def _collect_review_flags(self, reported_keys: Set[str], fix_details: list) -> list:
        """Build the list of items whose automated grading deserves a human look."""
        flags: list[dict] = []

        # 1. False positives with a recognised issue type: possibly a real
        #    issue the task designer did not plant. Sorted so the feedback
        #    text is deterministic across runs.
        for fp_key in sorted(reported_keys - self._planted_keys):
            parts = fp_key.split(",")
            itype = parts[2].split(":")[1] if len(parts) >= 3 else ""
            if itype in self._REVIEWABLE_ISSUE_TYPES:
                flags.append({
                    "item": fp_key,
                    "reason": "Agent reported this issue but it's not in ground truth — may be a real issue the grader missed",
                    "type": "possible_unplanted_issue",
                })

        # 2. Partial fix matches: close but not exact, a human should verify.
        for detail in fix_details:
            if 0 < detail["score"] < 0.99:
                flags.append({
                    "item": f"row:{detail['row']},col:{detail['col']}",
                    "reason": f"Fix scored {detail['score']:.2f} ({detail['reason']}) — human should verify if acceptable",
                    "type": "partial_fix",
                })

        # 3. High-difficulty planted issues that the agent missed.
        planted_by_key = {i.to_key(): i for i in self._current_task.planted_issues}
        for fn_key in sorted(self._planted_keys - reported_keys):
            issue = planted_by_key.get(fn_key)
            if issue and issue.difficulty >= self._HARD_ISSUE_DIFFICULTY:
                flags.append({
                    "item": fn_key,
                    "reason": f"High-difficulty issue (difficulty={issue.difficulty}) missed — {issue.description}",
                    "type": "missed_hard_issue",
                })

        return flags

    def step(
        self,
        action: Action,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> Observation:
        """Score one submission of issues and fixes and return feedback.

        Args:
            action: A ``DataQAAction`` carrying ``issues`` and ``fixes``.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``DataQAObservation`` whose ``reward`` is the clamped best
            combined reward so far and whose ``metadata`` holds the detailed
            metrics for this step.

        Raises:
            ValueError: If ``action`` is not a ``DataQAAction``.
        """
        if not isinstance(action, DataQAAction):
            raise ValueError(f"Expected DataQAAction, got {type(action)}")

        # Auto-reset in stateless HTTP mode
        if self._current_task is None:
            self.reset(task_id=action.task_id)

        self._state.step_count += 1
        self._state.current_step += 1

        # -- Phase 1: parse and score issue identification --
        reported_keys: Set[str] = set()
        parse_errors: list[str] = []
        for raw_issue in action.issues:
            key = parse_issue_key(raw_issue)
            if key:
                reported_keys.add(key)
            else:
                parse_errors.append(f"Could not parse issue: '{raw_issue}'")

        metrics = compute_f1(reported_keys, self._planted_keys)
        identify_f1 = metrics["f1"]
        weighted = compute_weighted_reward(reported_keys, self._current_task.planted_issues)
        identify_score = weighted["weighted_reward"]

        # -- Phase 2: parse and score proposed fixes --
        parsed_fixes: list[tuple[int, str, str]] = []
        for raw_fix in action.fixes:
            fix = parse_fix(raw_fix)
            if fix:
                parsed_fixes.append(fix)
            else:
                parse_errors.append(f"Could not parse fix: '{raw_fix}'")

        fix_result = grade_fixes(parsed_fixes, self._current_task)
        fix_score = fix_result["fix_score"]

        # -- Combined reward --
        # If no fixes were submitted, the score is identify-only (no penalty
        # for not fixing).
        if action.fixes:
            combined_reward = IDENTIFY_WEIGHT * identify_score + FIX_WEIGHT * fix_score
        else:
            combined_reward = identify_score  # backward compatible

        self._best_score = max(self._best_score, combined_reward)
        self._state.best_score = self._best_score

        # -- Check if done: perfect identification or step budget exhausted --
        is_done = (
            identify_f1 >= 0.999
            or self._state.current_step >= self._state.max_steps
        )

        # -- Build feedback with actionable diagnostics --
        tp_keys = reported_keys & self._planted_keys
        fp_keys = reported_keys - self._planted_keys

        feedback_lines = [
            f"Step {self._state.current_step}/{self._state.max_steps}",
            "",
            "--- Identification ---",
            f"Issues reported: {len(reported_keys)}",
            f"True positives: {metrics['tp']}, False positives: {metrics['fp']}, Missed: {metrics['fn']}",
            f"Precision: {metrics['precision']:.3f}, Recall: {metrics['recall']:.3f}, F1: {identify_f1:.3f}",
            f"Identify score (weighted): {identify_score:.3f}",
        ]

        # Show which reported issues were right or wrong so the agent can
        # self-correct on the next step.
        if tp_keys:
            feedback_lines.append(f"Correct issues: {', '.join(sorted(tp_keys))}")
        if fp_keys:
            feedback_lines.append(f"Incorrect issues (false positives): {', '.join(sorted(fp_keys))}")

        if action.fixes:
            feedback_lines += [
                "",
                "--- Fix Proposals ---",
                f"Fixes attempted: {fix_result['fixes_attempted']}",
                f"Correct: {fix_result['fixes_correct']}, Partial: {fix_result['fixes_partial']}, Wrong: {fix_result['fixes_wrong']}",
                f"Fix score: {fix_score:.3f}",
            ]
            # Per-fix feedback so the agent knows which fixes worked.
            for detail in fix_result["fix_details"]:
                status = "correct" if detail["score"] >= 0.99 else ("partial" if detail["score"] > 0 else "wrong")
                feedback_lines.append(
                    f" row:{detail['row']},col:{detail['col']} -> {status} ({detail['reason']})"
                )
            feedback_lines.append(
                f"\n--- Combined Reward: {combined_reward:.3f} (identify={identify_score:.3f} x {IDENTIFY_WEIGHT} + fix={fix_score:.3f} x {FIX_WEIGHT}) ---"
            )
        else:
            feedback_lines += [
                "",
                "Tip: Submit fixes with format row:<N>,col:<name>,fix:<value> for bonus reward.",
            ]

        if parse_errors:
            feedback_lines.append(f"\nParse errors ({len(parse_errors)}): {'; '.join(parse_errors[:5])}")

        if not is_done:
            if metrics["fn"] > 0:
                feedback_lines.append(
                    f"\nYou missed {metrics['fn']} issue(s). Review the dataset carefully."
                )
            if metrics["fp"] > 0:
                feedback_lines.append(
                    f"Remove the {metrics['fp']} false positive(s) listed above and look for real issues."
                )
            feedback_lines.append("You can submit again with updated issues and/or fixes.")
        else:
            feedback_lines.append(f"\nTask complete! Final best reward: {self._best_score:.3f}")

        # -- Flag items where automated scoring has low confidence --
        human_review_flags = self._collect_review_flags(reported_keys, fix_result["fix_details"])
        if human_review_flags:
            feedback_lines.append(f"\n--- Flagged for Human Review ({len(human_review_flags)}) ---")
            for flag in human_review_flags:
                feedback_lines.append(f" [{flag['type']}] {flag['item']}: {flag['reason']}")

        return DataQAObservation(
            dataset_csv=self._current_task.corrupted_csv,
            schema_description=self._current_task.schema_description,
            validation_rules=self._current_task.validation_rules,
            task_description=self._current_task.description,
            feedback="\n".join(feedback_lines),
            task_id=self._current_task.task_id,
            num_issues_hint=len(self._current_task.planted_issues),
            max_steps=self._state.max_steps,
            done=is_done,
            reward=_clamp_reward(self._best_score),
            metadata={
                "identify_f1": identify_f1,
                "identify_score": identify_score,
                "fix_score": fix_score,
                "combined_reward": combined_reward,
                "precision": metrics["precision"],
                "recall": metrics["recall"],
                "tp": metrics["tp"],
                "fp": metrics["fp"],
                "fn": metrics["fn"],
                "difficulty_found": weighted["difficulty_found"],
                "difficulty_missed": weighted["difficulty_missed"],
                "fixes_correct": fix_result["fixes_correct"],
                "fixes_partial": fix_result["fixes_partial"],
                "fixes_wrong": fix_result["fixes_wrong"],
                "fixes_attempted": fix_result["fixes_attempted"],
                "fix_details": fix_result["fix_details"],
                "human_review_flags": human_review_flags,
            },
        )

    # NOTE(review): defined as a plain method here; confirm whether the
    # Environment base class expects `state` to be a property instead.
    def state(self) -> DataQAState:
        """Return the current episode state object."""
        return self._state