Spaces:
Sleeping
Sleeping
| """Semantic data-cleaning evaluation environment.""" | |
| from __future__ import annotations | |
| from copy import deepcopy | |
| import random | |
| from typing import Any, Dict, List, Mapping, Optional, Tuple | |
| from grader import grade_step_details, grade_task_result | |
| from models import Action, Observation | |
| from task import easy_cleaning_task, hard_conflict_resolution_task, medium_normalization_task | |
class DataOpsEnv:
    """Step-based semantic evaluator with strict action protocol."""

    def __init__(self, seed: int = 0, task_name: Optional[str] = None) -> None:
        """Configure deterministic task selection.

        Args:
            seed: Seeds the private RNG so episodes are reproducible.
            task_name: Optional fixed task; when None a task is sampled.
        """
        # Deterministic RNG: same seed -> same task/variant sequence.
        self._seed = seed
        self._rng = random.Random(seed)
        self._fixed_task_name = task_name
        # Difficulty tier -> factory producing a task-definition dict.
        registry: List[Tuple[str, Any]] = [
            ("easy", easy_cleaning_task),
            ("medium", medium_normalization_task),
            ("hard", hard_conflict_resolution_task),
        ]
        self._task_registry = registry
        # Empty until reset(); step() refuses to run on an empty state.
        self._state_data: Dict[str, Any] = {}
| def reset(self) -> Observation: | |
| task_name, task_factory = self._select_task_factory() | |
| variant_count = max(1, int(getattr(task_factory, "variant_count", 1))) | |
| task_definition = deepcopy(task_factory(variant=self._rng.randrange(variant_count))) | |
| initial_table = deepcopy(task_definition["initial_table"]) | |
| self._state_data = { | |
| "seed": self._seed, | |
| "task_name": task_name, | |
| "task_variant": task_definition.get("variant_id", task_name), | |
| "task": task_definition, | |
| "dataset_original": initial_table, | |
| "dataset_modified": deepcopy(initial_table), | |
| "action_history": [], | |
| "per_record_scores": {}, | |
| "current_iteration_score": 0.0, | |
| "previous_iteration_score": 0.0, | |
| "failure_logs": [], | |
| "steps_taken": 0, | |
| "steps_remaining": task_definition["max_steps"], | |
| "done": False, | |
| "totals": { | |
| "total_fixes": 0, | |
| "hallucinated_fixes": 0, | |
| "total_cannot_determine": 0, | |
| "correct_cannot_determine": 0, | |
| "total_related_cases": 0, | |
| "consistent_decisions": 0, | |
| }, | |
| "related_decisions": {}, | |
| "detected_unresolved_issues": {}, | |
| "detected_issues": {}, | |
| "hallucination_rate": 0.0, | |
| "uncertainty_accuracy": 0.0, | |
| "consistency_score": 1.0, | |
| } | |
| return self._build_observation() | |
    def step(self, action: Action | Mapping[str, Any]) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one action, grade it, and advance the episode.

        Args:
            action: An Action instance or a mapping with Action's fields.

        Returns:
            Classic gym-style tuple ``(observation, reward, done, info)``.

        Raises:
            RuntimeError: If called before reset() or after the episode ended.
        """
        if not self._state_data:
            raise RuntimeError("Environment must be reset before calling step().")
        if self._state_data["done"]:
            raise RuntimeError("Episode is finished. Call reset() before stepping again.")
        # Coerce plain mappings into the Action model so validation runs up front.
        parsed_action = action if isinstance(action, Action) else Action(**dict(action))
        result = self._evaluate_action(parsed_action)
        self._state_data["action_history"].append(parsed_action.model_dump())
        self._state_data["steps_taken"] += 1
        self._state_data["steps_remaining"] = max(
            0, self._state_data["task"]["max_steps"] - self._state_data["steps_taken"]
        )
        # Snapshot the aggregate score BEFORE this step's reward lands, so the
        # improvement bonus below compares against the previous iteration.
        self._state_data["previous_iteration_score"] = float(
            self._state_data["current_iteration_score"]
        )
        reward, reward_components = grade_step_details(
            self._state_data, parsed_action.model_dump(), result
        )
        rid = parsed_action.record_id
        self._state_data["per_record_scores"][rid] = float(
            self._state_data["per_record_scores"].get(rid, 0.0)
        ) + reward
        self._state_data["current_iteration_score"] = sum(
            float(v) for v in self._state_data["per_record_scores"].values()
        )
        prev = self._state_data["previous_iteration_score"]
        curr = self._state_data["current_iteration_score"]
        # Shaping term: +/-0.1 for moving the aggregate score up or down.
        if curr > prev:
            reward += 0.1
            reward_components["iteration_improvement"] = 0.1
        elif curr < prev:
            reward -= 0.1
            reward_components["iteration_improvement"] = -0.1
        self._update_metrics()
        # Task-level grade is recomputed every step so `info` is always current.
        task_score = grade_task_result(
            self._state_data["task"], self._state_data["dataset_modified"], self._state_data
        )
        done = self._state_data["steps_remaining"] <= 0
        self._state_data["done"] = done
        # Deep copies keep callers from mutating internal episode state via info.
        info = {
            "actions_taken": deepcopy(self._state_data["action_history"]),
            "updated_dataset": deepcopy(self._state_data["dataset_modified"]),
            "per_record_scores": deepcopy(self._state_data["per_record_scores"]),
            "final_task_score": task_score,
            "metrics": {
                "hallucination_rate": self._state_data["hallucination_rate"],
                "uncertainty_accuracy": self._state_data["uncertainty_accuracy"],
                "consistency_score": self._state_data["consistency_score"],
            },
            "failure_logs": deepcopy(self._state_data["failure_logs"]),
            "reward_components": reward_components,
            "result": result,
        }
        return self._build_observation(), reward, done, info
| def state(self) -> Dict[str, Any]: | |
| return deepcopy(self._state_data) | |
| def close(self) -> None: | |
| self._state_data = {} | |
| def _select_task_factory(self) -> Tuple[str, Any]: | |
| """Pick the configured task factory deterministically.""" | |
| if self._fixed_task_name is None: | |
| return self._rng.choice(self._task_registry) | |
| for task_name, task_factory in self._task_registry: | |
| if self._fixed_task_name in {task_name, task_factory.__name__}: | |
| return task_name, task_factory | |
| raise ValueError(f"Unknown task_name: {self._fixed_task_name}") | |
    def _evaluate_action(self, action: Action) -> Dict[str, Any]:
        """Evaluate one action against the hidden issue set.

        Mutates the working table only for valid fixes; invalid fixes are
        rolled back.  Returns a result dict of boolean outcome flags that the
        grader consumes.
        """
        table = self._state_data["dataset_modified"]
        issue = self._matching_issue(action.record_id, action.field)
        issue_key = self._issue_key(issue)
        result: Dict[str, Any] = {"extra_fields_modified": 0}
        # Cross-cutting checks run before the per-action-type branches.
        self._apply_related_consistency(action, issue, result)
        self._apply_follow_up_requirement(action, issue_key, result)
        if action.action_type == "skip":
            # Skipping a real issue is a miss and leaves it marked unresolved.
            if issue is not None:
                result["missed_issue"] = True
                result["passive_penalty"] = True
                if issue_key is not None:
                    self._state_data["detected_unresolved_issues"][issue_key] = True
                self._append_failure(action, "missed_issue", "Issue exists but action was skip.")
            return result
        if action.action_type == "detect_issue":
            if issue is not None:
                result["classification_correct"] = True
                result["correct_issue_detected"] = True
                # Detection alone does not resolve anything.
                result["passive_penalty"] = True
                if issue_key is not None:
                    if issue_key in self._state_data["detected_issues"]:
                        result["repeated_detection"] = True
                    self._state_data["detected_issues"][issue_key] = True
                    self._state_data["detected_unresolved_issues"][issue_key] = True
            else:
                result["classification_incorrect"] = True
                result["false_issue"] = True
            return result
        if action.action_type == "cannot_determine":
            self._state_data["totals"]["total_cannot_determine"] += 1
            if issue is None:
                result["wrong_cannot_determine"] = True
                self._append_failure(
                    action, "wrong_fix", "cannot_determine used without any supporting issue."
                )
            elif issue.get("fixable", True) is False:
                # Correct abstention: the issue is genuinely unfixable.
                result["correct_cannot_determine"] = True
                self._state_data["totals"]["correct_cannot_determine"] += 1
                if issue_key is not None:
                    self._state_data["detected_unresolved_issues"].pop(issue_key, None)
                    if issue_key in self._state_data["detected_issues"]:
                        result["resolved_detected_issue"] = True
            else:
                result["wrong_cannot_determine"] = True
                self._append_failure(
                    action, "wrong_fix", "cannot_determine used when evidence was sufficient."
                )
            return result
        # fix_value
        self._state_data["totals"]["total_fixes"] += 1
        if issue is None:
            # Editing an issue-free field on a record that DOES have issues
            # counts as collateral modification.
            related_issue_count = self._count_issues_for_record(action.record_id)
            if related_issue_count > 0:
                result["extra_fields_modified"] += 1
        row = self._find_record(action.record_id, table)
        if row is None or action.field not in row:
            result["hallucinated_fix"] = True
            self._state_data["totals"]["hallucinated_fixes"] += 1
            self._append_failure(action, "hallucination", "Attempted fix with no evidence.")
            return result
        if issue is None:
            result["hallucinated_fix"] = True
            self._state_data["totals"]["hallucinated_fixes"] += 1
            self._append_failure(action, "hallucination", "Field had no target issue.")
            return result
        if self._issue_resolved(issue, table):
            # Re-fixing an already-correct field is treated as hallucination.
            result["hallucinated_fix"] = True
            self._state_data["totals"]["hallucinated_fixes"] += 1
            self._append_failure(action, "hallucination", "Field is already correct.")
            return result
        # Apply tentatively; every failing check below rolls the value back.
        old_value = row.get(action.field)
        before_row = deepcopy(row)
        row[action.field] = action.value
        if self._introduces_inconsistency(row, action.field, table):
            result["hallucinated_fix"] = True
            self._state_data["totals"]["hallucinated_fixes"] += 1
            row[action.field] = old_value
            self._append_failure(
                action, "hallucination", "Fix introduces cross-record or temporal inconsistency."
            )
            return result
        if self.validate_fix(issue, before_row, row, table):
            result["correct_fix"] = True
            result["classification_correct"] = True
            if issue_key is not None:
                if issue_key in self._state_data["detected_issues"]:
                    result["resolved_detected_issue"] = True
                self._state_data["detected_unresolved_issues"].pop(issue_key, None)
        else:
            row[action.field] = old_value
            result["wrong_fix"] = True
            self._append_failure(action, "wrong_fix", "Fix does not resolve the identified issue.")
        return result
| def _apply_follow_up_requirement( | |
| self, action: Action, issue_key: Optional[str], result: Dict[str, Any] | |
| ) -> None: | |
| unresolved = self._state_data.get("detected_unresolved_issues", {}) | |
| if not unresolved: | |
| return | |
| # Follow-up action types are fix/cannot_determine against a detected issue. | |
| is_follow_up = ( | |
| action.action_type in {"fix_value", "cannot_determine"} | |
| and issue_key is not None | |
| and issue_key in unresolved | |
| ) | |
| if not is_follow_up: | |
| result["passive_penalty"] = True | |
| def _apply_related_consistency( | |
| self, action: Action, issue: Optional[Dict[str, Any]], result: Dict[str, Any] | |
| ) -> None: | |
| if issue is None: | |
| return | |
| issue_type = issue.get("type") | |
| if issue_type not in {"duplicate", "conflict"}: | |
| return | |
| rows = issue.get("rows", []) | |
| if not rows: | |
| return | |
| key = f"{issue_type}:{','.join(str(v) for v in sorted(rows))}" | |
| self._state_data["totals"]["total_related_cases"] += 1 | |
| seen = self._state_data["related_decisions"] | |
| decision = action.action_type | |
| if key not in seen: | |
| seen[key] = decision | |
| result["consistent_handling"] = True | |
| self._state_data["totals"]["consistent_decisions"] += 1 | |
| return | |
| if seen[key] == decision: | |
| result["consistent_handling"] = True | |
| self._state_data["totals"]["consistent_decisions"] += 1 | |
| else: | |
| result["inconsistent_handling"] = True | |
| self._append_failure( | |
| action, "inconsistency", "Related records were handled inconsistently." | |
| ) | |
| def _matching_issue(self, record_id: str, field: str) -> Optional[Dict[str, Any]]: | |
| rid = self._parse_record_id(record_id) | |
| for issue in self._state_data["task"]["hidden_issues"]: | |
| issue_type = issue.get("type") | |
| if issue_type == "missing_value" and issue.get("row") == rid and issue.get("column") == field: | |
| return issue | |
| if issue_type == "invalid_format" and issue.get("row") == rid and issue.get("column") == field: | |
| return issue | |
| if issue_type == "inconsistent_casing" and field == issue.get("column") and rid in issue.get("rows", []): | |
| return issue | |
| if ( | |
| issue_type in {"duplicate", "conflict", "constraint_violation"} | |
| and (field in {"row", "record"} or field == issue.get("field")) | |
| and rid in issue.get("rows", []) | |
| ): | |
| ambiguous = issue_type in {"conflict", "constraint_violation"} | |
| c = dict(issue) | |
| c["ambiguous"] = ambiguous | |
| return c | |
| return None | |
| def _issue_resolved(self, issue: Mapping[str, Any], table: List[Dict[str, Any]]) -> bool: | |
| if issue.get("type") in {"duplicate", "conflict", "constraint_violation"}: | |
| return False | |
| rid = int(issue.get("row", -1)) | |
| field = issue.get("column") | |
| row = self._find_record(str(rid), table) | |
| if row is None: | |
| return True | |
| if issue.get("type") == "missing_value": | |
| return row.get(field) not in (None, "", "unknown", "9999") | |
| if issue.get("type") == "invalid_format": | |
| value = str(row.get(field, "")) | |
| if field == "email": | |
| return "@" in value and "." in value.split("@")[-1] | |
| if field == "phone": | |
| digits = "".join(ch for ch in value if ch.isdigit()) | |
| return len(digits) in {10, 11} | |
| if field in {"start_date", "end_date"}: | |
| start = row.get("start_date") | |
| end = row.get("end_date") | |
| return not (start and end and str(end) < str(start)) | |
| return row.get(field) not in (None, "", "unknown", "9999") | |
| def validate_fix( | |
| self, | |
| issue: Mapping[str, Any], | |
| before_row: Mapping[str, Any], | |
| after_row: Mapping[str, Any], | |
| table: List[Dict[str, Any]], | |
| ) -> bool: | |
| """Ground-truth validator for semantic fixes.""" | |
| issue_type = str(issue.get("type", "")) | |
| field = str(issue.get("column") or issue.get("field") or "") | |
| if field and before_row.get(field) == after_row.get(field): | |
| return False | |
| if field == "age": | |
| try: | |
| age = int(after_row.get("age")) | |
| except Exception: | |
| return False | |
| if age < 0 or age > 120: | |
| return False | |
| if issue_type == "missing_value": | |
| return after_row.get(field) not in (None, "", "unknown", "9999") | |
| if issue_type == "invalid_format": | |
| value = str(after_row.get(field, "")) | |
| if field == "email": | |
| return "@" in value and "." in value.split("@")[-1] | |
| if field == "phone": | |
| digits = "".join(ch for ch in value if ch.isdigit()) | |
| return len(digits) in {10, 11} | |
| if field in {"start_date", "end_date"}: | |
| start = after_row.get("start_date") | |
| end = after_row.get("end_date") | |
| return not (start and end and str(end) < str(start)) | |
| return value not in ("", "unknown", "9999") | |
| if issue_type == "inconsistent_casing": | |
| value = after_row.get(field) | |
| return isinstance(value, str) and value == value.title() | |
| if issue_type in {"duplicate", "conflict", "constraint_violation"}: | |
| return False | |
| return not self._introduces_inconsistency(dict(after_row), field, table) and self._issue_resolved( | |
| issue, table | |
| ) | |
| def _count_issues_for_record(self, record_id: str) -> int: | |
| rid = self._parse_record_id(record_id) | |
| count = 0 | |
| for issue in self._state_data["task"]["hidden_issues"]: | |
| if issue.get("row") == rid: | |
| count += 1 | |
| continue | |
| if rid in issue.get("rows", []): | |
| count += 1 | |
| return count | |
| def _issue_key(self, issue: Optional[Dict[str, Any]]) -> Optional[str]: | |
| if issue is None: | |
| return None | |
| issue_type = issue.get("type", "unknown") | |
| if "row" in issue and "column" in issue: | |
| return f"{issue_type}:row={issue.get('row')}:col={issue.get('column')}" | |
| if "rows" in issue: | |
| rows = ",".join(str(v) for v in sorted(issue.get("rows", []))) | |
| field = issue.get("field", "record") | |
| return f"{issue_type}:rows={rows}:field={field}" | |
| return f"{issue_type}:generic" | |
| def _introduces_inconsistency( | |
| self, row: Dict[str, Any], field: str, table: List[Dict[str, Any]] | |
| ) -> bool: | |
| # Unique email consistency check across records. | |
| if field == "email": | |
| email = row.get("email") | |
| if email not in (None, ""): | |
| duplicates = [ | |
| r for r in table | |
| if r is not row and str(r.get("email", "")).strip() == str(email).strip() | |
| ] | |
| if duplicates: | |
| return True | |
| # Temporal consistency check where both fields are present. | |
| if field in {"start_date", "end_date"}: | |
| start = row.get("start_date") | |
| end = row.get("end_date") | |
| if start and end and str(end) < str(start): | |
| return True | |
| return False | |
| def _build_observation(self) -> Observation: | |
| return Observation( | |
| dataset={ | |
| "original": deepcopy(self._state_data["dataset_original"]), | |
| "modified": deepcopy(self._state_data["dataset_modified"]), | |
| }, | |
| action_history=deepcopy(self._state_data["action_history"]), | |
| per_record_scores=deepcopy(self._state_data["per_record_scores"]), | |
| current_iteration_score=float(self._state_data["current_iteration_score"]), | |
| previous_iteration_score=float(self._state_data["previous_iteration_score"]), | |
| steps_remaining=int(self._state_data["steps_remaining"]), | |
| ) | |
| def _update_metrics(self) -> None: | |
| totals = self._state_data["totals"] | |
| total_fixes = int(totals["total_fixes"]) | |
| self._state_data["hallucination_rate"] = ( | |
| 0.0 if total_fixes == 0 else float(totals["hallucinated_fixes"]) / total_fixes | |
| ) | |
| total_cd = int(totals["total_cannot_determine"]) | |
| self._state_data["uncertainty_accuracy"] = ( | |
| 0.0 if total_cd == 0 else float(totals["correct_cannot_determine"]) / total_cd | |
| ) | |
| total_related = int(totals["total_related_cases"]) | |
| self._state_data["consistency_score"] = ( | |
| 1.0 if total_related == 0 else float(totals["consistent_decisions"]) / total_related | |
| ) | |
| def _parse_record_id(self, record_id: str) -> int: | |
| digits = "".join(ch for ch in str(record_id) if ch.isdigit()) | |
| return int(digits) if digits else -1 | |
| def _find_record(self, record_id: str, table: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]: | |
| rid = self._parse_record_id(record_id) | |
| for row in table: | |
| if int(row.get("row_id", -1)) == rid: | |
| return row | |
| return None | |
| def _append_failure(self, action: Action, error_type: str, details: str) -> None: | |
| mapped = error_type | |
| if error_type == "wrong_fix": | |
| mapped = "wrong_fix" | |
| self._state_data["failure_logs"].append( | |
| { | |
| "record_id": action.record_id, | |
| "error_type": mapped, | |
| "details": details, | |
| "confidence": float(action.confidence), | |
| } | |
| ) | |
class DataOpsGymEnv(DataOpsEnv):
    """Compatibility alias matching the configured OpenEnv entrypoint name."""
# Public API surface for star-imports of this module.
__all__ = ["DataOpsEnv", "DataOpsGymEnv"]