# dataops-env / env.py
# Refactor semantic cleaning evaluator and improve API docs UX.
# (commit 00cf35f, author: Graheet)
"""Semantic data-cleaning evaluation environment."""
from __future__ import annotations
from copy import deepcopy
import random
from typing import Any, Dict, List, Mapping, Optional, Tuple
from grader import grade_step_details, grade_task_result
from models import Action, Observation
from task import easy_cleaning_task, hard_conflict_resolution_task, medium_normalization_task
class DataOpsEnv:
    """Step-based semantic evaluator with strict action protocol."""

    def __init__(self, seed: int = 0, task_name: Optional[str] = None) -> None:
        """Create an evaluator seeded for deterministic task selection.

        Args:
            seed: Seed for the internal RNG used to pick tasks and variants.
            task_name: Optional fixed task; when None a task is drawn at random.
        """
        self._seed = seed
        self._rng = random.Random(seed)
        self._fixed_task_name = task_name
        # Ordered registry of (difficulty label, task factory callable).
        self._task_registry: List[Tuple[str, Any]] = [
            ("easy", easy_cleaning_task),
            ("medium", medium_normalization_task),
            ("hard", hard_conflict_resolution_task),
        ]
        # Mutable per-episode state; populated by reset(), cleared by close().
        self._state_data: Dict[str, Any] = {}
def reset(self) -> Observation:
task_name, task_factory = self._select_task_factory()
variant_count = max(1, int(getattr(task_factory, "variant_count", 1)))
task_definition = deepcopy(task_factory(variant=self._rng.randrange(variant_count)))
initial_table = deepcopy(task_definition["initial_table"])
self._state_data = {
"seed": self._seed,
"task_name": task_name,
"task_variant": task_definition.get("variant_id", task_name),
"task": task_definition,
"dataset_original": initial_table,
"dataset_modified": deepcopy(initial_table),
"action_history": [],
"per_record_scores": {},
"current_iteration_score": 0.0,
"previous_iteration_score": 0.0,
"failure_logs": [],
"steps_taken": 0,
"steps_remaining": task_definition["max_steps"],
"done": False,
"totals": {
"total_fixes": 0,
"hallucinated_fixes": 0,
"total_cannot_determine": 0,
"correct_cannot_determine": 0,
"total_related_cases": 0,
"consistent_decisions": 0,
},
"related_decisions": {},
"detected_unresolved_issues": {},
"detected_issues": {},
"hallucination_rate": 0.0,
"uncertainty_accuracy": 0.0,
"consistency_score": 1.0,
}
return self._build_observation()
    def step(self, action: Action | Mapping[str, Any]) -> Tuple[Observation, float, bool, Dict[str, Any]]:
        """Apply one agent action, grade it, and advance the episode one step.

        Args:
            action: An ``Action`` instance or a mapping coercible into one.

        Returns:
            ``(observation, reward, done, info)`` in classic gym style.

        Raises:
            RuntimeError: If called before ``reset()`` or after the episode ended.
        """
        if not self._state_data:
            raise RuntimeError("Environment must be reset before calling step().")
        if self._state_data["done"]:
            raise RuntimeError("Episode is finished. Call reset() before stepping again.")
        # Accept plain mappings for convenience; normalize to an Action model.
        parsed_action = action if isinstance(action, Action) else Action(**dict(action))
        result = self._evaluate_action(parsed_action)
        self._state_data["action_history"].append(parsed_action.model_dump())
        self._state_data["steps_taken"] += 1
        self._state_data["steps_remaining"] = max(
            0, self._state_data["task"]["max_steps"] - self._state_data["steps_taken"]
        )
        # Snapshot the aggregate score before re-accumulating it below.
        self._state_data["previous_iteration_score"] = float(
            self._state_data["current_iteration_score"]
        )
        reward, reward_components = grade_step_details(
            self._state_data, parsed_action.model_dump(), result
        )
        # Accumulate reward per record, then re-derive the episode-level score.
        rid = parsed_action.record_id
        self._state_data["per_record_scores"][rid] = float(
            self._state_data["per_record_scores"].get(rid, 0.0)
        ) + reward
        self._state_data["current_iteration_score"] = sum(
            float(v) for v in self._state_data["per_record_scores"].values()
        )
        # Small shaping bonus/penalty for moving the aggregate score up/down.
        prev = self._state_data["previous_iteration_score"]
        curr = self._state_data["current_iteration_score"]
        if curr > prev:
            reward += 0.1
            reward_components["iteration_improvement"] = 0.1
        elif curr < prev:
            reward -= 0.1
            reward_components["iteration_improvement"] = -0.1
        self._update_metrics()
        # Task-level score is recomputed every step so info stays current.
        task_score = grade_task_result(
            self._state_data["task"], self._state_data["dataset_modified"], self._state_data
        )
        # Episode ends only when the step budget is exhausted.
        done = self._state_data["steps_remaining"] <= 0
        self._state_data["done"] = done
        info = {
            "actions_taken": deepcopy(self._state_data["action_history"]),
            "updated_dataset": deepcopy(self._state_data["dataset_modified"]),
            "per_record_scores": deepcopy(self._state_data["per_record_scores"]),
            "final_task_score": task_score,
            "metrics": {
                "hallucination_rate": self._state_data["hallucination_rate"],
                "uncertainty_accuracy": self._state_data["uncertainty_accuracy"],
                "consistency_score": self._state_data["consistency_score"],
            },
            "failure_logs": deepcopy(self._state_data["failure_logs"]),
            "reward_components": reward_components,
            "result": result,
        }
        return self._build_observation(), reward, done, info
def state(self) -> Dict[str, Any]:
return deepcopy(self._state_data)
def close(self) -> None:
self._state_data = {}
def _select_task_factory(self) -> Tuple[str, Any]:
"""Pick the configured task factory deterministically."""
if self._fixed_task_name is None:
return self._rng.choice(self._task_registry)
for task_name, task_factory in self._task_registry:
if self._fixed_task_name in {task_name, task_factory.__name__}:
return task_name, task_factory
raise ValueError(f"Unknown task_name: {self._fixed_task_name}")
def _evaluate_action(self, action: Action) -> Dict[str, Any]:
table = self._state_data["dataset_modified"]
issue = self._matching_issue(action.record_id, action.field)
issue_key = self._issue_key(issue)
result: Dict[str, Any] = {"extra_fields_modified": 0}
self._apply_related_consistency(action, issue, result)
self._apply_follow_up_requirement(action, issue_key, result)
if action.action_type == "skip":
if issue is not None:
result["missed_issue"] = True
result["passive_penalty"] = True
if issue_key is not None:
self._state_data["detected_unresolved_issues"][issue_key] = True
self._append_failure(action, "missed_issue", "Issue exists but action was skip.")
return result
if action.action_type == "detect_issue":
if issue is not None:
result["classification_correct"] = True
result["correct_issue_detected"] = True
result["passive_penalty"] = True
if issue_key is not None:
if issue_key in self._state_data["detected_issues"]:
result["repeated_detection"] = True
self._state_data["detected_issues"][issue_key] = True
self._state_data["detected_unresolved_issues"][issue_key] = True
else:
result["classification_incorrect"] = True
result["false_issue"] = True
return result
if action.action_type == "cannot_determine":
self._state_data["totals"]["total_cannot_determine"] += 1
if issue is None:
result["wrong_cannot_determine"] = True
self._append_failure(
action, "wrong_fix", "cannot_determine used without any supporting issue."
)
elif issue.get("fixable", True) is False:
result["correct_cannot_determine"] = True
self._state_data["totals"]["correct_cannot_determine"] += 1
if issue_key is not None:
self._state_data["detected_unresolved_issues"].pop(issue_key, None)
if issue_key in self._state_data["detected_issues"]:
result["resolved_detected_issue"] = True
else:
result["wrong_cannot_determine"] = True
self._append_failure(
action, "wrong_fix", "cannot_determine used when evidence was sufficient."
)
return result
# fix_value
self._state_data["totals"]["total_fixes"] += 1
if issue is None:
related_issue_count = self._count_issues_for_record(action.record_id)
if related_issue_count > 0:
result["extra_fields_modified"] += 1
row = self._find_record(action.record_id, table)
if row is None or action.field not in row:
result["hallucinated_fix"] = True
self._state_data["totals"]["hallucinated_fixes"] += 1
self._append_failure(action, "hallucination", "Attempted fix with no evidence.")
return result
if issue is None:
result["hallucinated_fix"] = True
self._state_data["totals"]["hallucinated_fixes"] += 1
self._append_failure(action, "hallucination", "Field had no target issue.")
return result
if self._issue_resolved(issue, table):
result["hallucinated_fix"] = True
self._state_data["totals"]["hallucinated_fixes"] += 1
self._append_failure(action, "hallucination", "Field is already correct.")
return result
old_value = row.get(action.field)
before_row = deepcopy(row)
row[action.field] = action.value
if self._introduces_inconsistency(row, action.field, table):
result["hallucinated_fix"] = True
self._state_data["totals"]["hallucinated_fixes"] += 1
row[action.field] = old_value
self._append_failure(
action, "hallucination", "Fix introduces cross-record or temporal inconsistency."
)
return result
if self.validate_fix(issue, before_row, row, table):
result["correct_fix"] = True
result["classification_correct"] = True
if issue_key is not None:
if issue_key in self._state_data["detected_issues"]:
result["resolved_detected_issue"] = True
self._state_data["detected_unresolved_issues"].pop(issue_key, None)
else:
row[action.field] = old_value
result["wrong_fix"] = True
self._append_failure(action, "wrong_fix", "Fix does not resolve the identified issue.")
return result
def _apply_follow_up_requirement(
self, action: Action, issue_key: Optional[str], result: Dict[str, Any]
) -> None:
unresolved = self._state_data.get("detected_unresolved_issues", {})
if not unresolved:
return
# Follow-up action types are fix/cannot_determine against a detected issue.
is_follow_up = (
action.action_type in {"fix_value", "cannot_determine"}
and issue_key is not None
and issue_key in unresolved
)
if not is_follow_up:
result["passive_penalty"] = True
def _apply_related_consistency(
self, action: Action, issue: Optional[Dict[str, Any]], result: Dict[str, Any]
) -> None:
if issue is None:
return
issue_type = issue.get("type")
if issue_type not in {"duplicate", "conflict"}:
return
rows = issue.get("rows", [])
if not rows:
return
key = f"{issue_type}:{','.join(str(v) for v in sorted(rows))}"
self._state_data["totals"]["total_related_cases"] += 1
seen = self._state_data["related_decisions"]
decision = action.action_type
if key not in seen:
seen[key] = decision
result["consistent_handling"] = True
self._state_data["totals"]["consistent_decisions"] += 1
return
if seen[key] == decision:
result["consistent_handling"] = True
self._state_data["totals"]["consistent_decisions"] += 1
else:
result["inconsistent_handling"] = True
self._append_failure(
action, "inconsistency", "Related records were handled inconsistently."
)
def _matching_issue(self, record_id: str, field: str) -> Optional[Dict[str, Any]]:
rid = self._parse_record_id(record_id)
for issue in self._state_data["task"]["hidden_issues"]:
issue_type = issue.get("type")
if issue_type == "missing_value" and issue.get("row") == rid and issue.get("column") == field:
return issue
if issue_type == "invalid_format" and issue.get("row") == rid and issue.get("column") == field:
return issue
if issue_type == "inconsistent_casing" and field == issue.get("column") and rid in issue.get("rows", []):
return issue
if (
issue_type in {"duplicate", "conflict", "constraint_violation"}
and (field in {"row", "record"} or field == issue.get("field"))
and rid in issue.get("rows", [])
):
ambiguous = issue_type in {"conflict", "constraint_violation"}
c = dict(issue)
c["ambiguous"] = ambiguous
return c
return None
def _issue_resolved(self, issue: Mapping[str, Any], table: List[Dict[str, Any]]) -> bool:
if issue.get("type") in {"duplicate", "conflict", "constraint_violation"}:
return False
rid = int(issue.get("row", -1))
field = issue.get("column")
row = self._find_record(str(rid), table)
if row is None:
return True
if issue.get("type") == "missing_value":
return row.get(field) not in (None, "", "unknown", "9999")
if issue.get("type") == "invalid_format":
value = str(row.get(field, ""))
if field == "email":
return "@" in value and "." in value.split("@")[-1]
if field == "phone":
digits = "".join(ch for ch in value if ch.isdigit())
return len(digits) in {10, 11}
if field in {"start_date", "end_date"}:
start = row.get("start_date")
end = row.get("end_date")
return not (start and end and str(end) < str(start))
return row.get(field) not in (None, "", "unknown", "9999")
def validate_fix(
self,
issue: Mapping[str, Any],
before_row: Mapping[str, Any],
after_row: Mapping[str, Any],
table: List[Dict[str, Any]],
) -> bool:
"""Ground-truth validator for semantic fixes."""
issue_type = str(issue.get("type", ""))
field = str(issue.get("column") or issue.get("field") or "")
if field and before_row.get(field) == after_row.get(field):
return False
if field == "age":
try:
age = int(after_row.get("age"))
except Exception:
return False
if age < 0 or age > 120:
return False
if issue_type == "missing_value":
return after_row.get(field) not in (None, "", "unknown", "9999")
if issue_type == "invalid_format":
value = str(after_row.get(field, ""))
if field == "email":
return "@" in value and "." in value.split("@")[-1]
if field == "phone":
digits = "".join(ch for ch in value if ch.isdigit())
return len(digits) in {10, 11}
if field in {"start_date", "end_date"}:
start = after_row.get("start_date")
end = after_row.get("end_date")
return not (start and end and str(end) < str(start))
return value not in ("", "unknown", "9999")
if issue_type == "inconsistent_casing":
value = after_row.get(field)
return isinstance(value, str) and value == value.title()
if issue_type in {"duplicate", "conflict", "constraint_violation"}:
return False
return not self._introduces_inconsistency(dict(after_row), field, table) and self._issue_resolved(
issue, table
)
def _count_issues_for_record(self, record_id: str) -> int:
rid = self._parse_record_id(record_id)
count = 0
for issue in self._state_data["task"]["hidden_issues"]:
if issue.get("row") == rid:
count += 1
continue
if rid in issue.get("rows", []):
count += 1
return count
def _issue_key(self, issue: Optional[Dict[str, Any]]) -> Optional[str]:
if issue is None:
return None
issue_type = issue.get("type", "unknown")
if "row" in issue and "column" in issue:
return f"{issue_type}:row={issue.get('row')}:col={issue.get('column')}"
if "rows" in issue:
rows = ",".join(str(v) for v in sorted(issue.get("rows", [])))
field = issue.get("field", "record")
return f"{issue_type}:rows={rows}:field={field}"
return f"{issue_type}:generic"
def _introduces_inconsistency(
self, row: Dict[str, Any], field: str, table: List[Dict[str, Any]]
) -> bool:
# Unique email consistency check across records.
if field == "email":
email = row.get("email")
if email not in (None, ""):
duplicates = [
r for r in table
if r is not row and str(r.get("email", "")).strip() == str(email).strip()
]
if duplicates:
return True
# Temporal consistency check where both fields are present.
if field in {"start_date", "end_date"}:
start = row.get("start_date")
end = row.get("end_date")
if start and end and str(end) < str(start):
return True
return False
def _build_observation(self) -> Observation:
return Observation(
dataset={
"original": deepcopy(self._state_data["dataset_original"]),
"modified": deepcopy(self._state_data["dataset_modified"]),
},
action_history=deepcopy(self._state_data["action_history"]),
per_record_scores=deepcopy(self._state_data["per_record_scores"]),
current_iteration_score=float(self._state_data["current_iteration_score"]),
previous_iteration_score=float(self._state_data["previous_iteration_score"]),
steps_remaining=int(self._state_data["steps_remaining"]),
)
def _update_metrics(self) -> None:
totals = self._state_data["totals"]
total_fixes = int(totals["total_fixes"])
self._state_data["hallucination_rate"] = (
0.0 if total_fixes == 0 else float(totals["hallucinated_fixes"]) / total_fixes
)
total_cd = int(totals["total_cannot_determine"])
self._state_data["uncertainty_accuracy"] = (
0.0 if total_cd == 0 else float(totals["correct_cannot_determine"]) / total_cd
)
total_related = int(totals["total_related_cases"])
self._state_data["consistency_score"] = (
1.0 if total_related == 0 else float(totals["consistent_decisions"]) / total_related
)
def _parse_record_id(self, record_id: str) -> int:
digits = "".join(ch for ch in str(record_id) if ch.isdigit())
return int(digits) if digits else -1
def _find_record(self, record_id: str, table: List[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
rid = self._parse_record_id(record_id)
for row in table:
if int(row.get("row_id", -1)) == rid:
return row
return None
def _append_failure(self, action: Action, error_type: str, details: str) -> None:
mapped = error_type
if error_type == "wrong_fix":
mapped = "wrong_fix"
self._state_data["failure_logs"].append(
{
"record_id": action.record_id,
"error_type": mapped,
"details": details,
"confidence": float(action.confidence),
}
)
class DataOpsGymEnv(DataOpsEnv):
    """Alias subclass kept so the configured OpenEnv entrypoint resolves."""
# Public API of this module: the evaluator and its OpenEnv-compatible alias.
__all__ = ["DataOpsEnv", "DataOpsGymEnv"]