""" Core environment implementing reset / step / state. Each call to reset() picks a task (round-robin: 1 → 2 → 3 → 1 …) or a specific task_id can be forced via reset(task_id=N). """ import re import uuid import numpy as np import pandas as pd from typing import Any, Dict, Optional, Tuple from models import DataCleaningAction, DataCleaningObservation, DataCleaningState import server.tasks.task1_missing as t1 import server.tasks.task2_format as t2 import server.tasks.task3_pipeline as t3 TASK_MODULES = {1: t1, 2: t2, 3: t3} PHONE_RE = re.compile(r"^\d{3}-\d{3}-\d{4}$") DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$") VALID_COUNTRIES = {"USA", "UK", "Canada", "Australia", "Germany"} class DataCleaningEnvironment: def __init__(self): self._df: Optional[pd.DataFrame] = None self._clean_df: Optional[pd.DataFrame] = None self._meta: Any = None # task-specific metadata self._task_id: int = 1 self._episode_id: str = "" self._step_count: int = 0 self._max_steps: int = 20 self._total_errors: int = 0 self._last_score: float = 0.0 self._task_cycle: int = 0 # for round-robin default # ------------------------------------------------------------------ # Public API # ------------------------------------------------------------------ def reset(self, task_id: Optional[int] = None) -> DataCleaningObservation: if task_id is None: self._task_cycle = (self._task_cycle % 3) + 1 task_id = self._task_cycle if task_id not in TASK_MODULES: raise ValueError(f"task_id must be 1, 2, or 3 — got {task_id}") mod = TASK_MODULES[task_id] self._task_id = task_id self._episode_id = str(uuid.uuid4()) self._step_count = 0 self._max_steps = mod.MAX_STEPS if task_id == 1: self._df, self._clean_df, self._meta = mod.load() else: self._df, self._clean_df, self._meta = mod.load() self._last_score = self._compute_score() self._total_errors = self._count_errors() return self._build_obs(self._last_score, False, "Episode started. Begin cleaning.") def step(self, action: DataCleaningAction) -> DataCleaningObservation: if self._df is None: raise RuntimeError("Call reset() before step().") self._step_count += 1 score_before = self._last_score message, applied = self._apply_action(action) score_after = self._compute_score() self._last_score = score_after delta = score_after - score_before if not applied: reward = -0.05 elif delta <= 0: reward = -0.01 else: reward = round(delta, 4) done = (score_after >= 0.95) or (self._step_count >= self._max_steps) if done and score_after >= 0.95: reward = round(reward + 0.2, 4) return self._build_obs(reward, done, message) def state(self) -> DataCleaningState: if self._df is None: return DataCleaningState( episode_id="", task_id=0, step_count=0, max_steps=0, total_errors=0, errors_remaining=0, ) return DataCleaningState( episode_id = self._episode_id, task_id = self._task_id, step_count = self._step_count, max_steps = self._max_steps, total_errors = self._total_errors, errors_remaining = self._count_errors(), ) # ------------------------------------------------------------------ # Internal helpers # ------------------------------------------------------------------ def _compute_score(self) -> float: if self._task_id == 1: return t1.score(self._df, self._meta) elif self._task_id == 2: return t2.score(self._df, self._meta) else: return t3.score(self._df, self._meta) def _count_errors(self) -> int: if self._task_id == 1: return t1.count_errors(self._df) elif self._task_id == 2: return t2.count_errors(self._df, self._meta) else: return t3.count_errors(self._df, self._meta) def _build_obs(self, reward: float, done: bool, message: str) -> DataCleaningObservation: mod = TASK_MODULES[self._task_id] missing = {col: int(n) for col, n in self._df.isnull().sum().items() if n > 0} dupes = len(self._df) - len(self._df.drop_duplicates()) dtype_issues = self._detect_dtype_issues() preview = self._df.head(10).to_csv(index=False) return DataCleaningObservation( done = done, reward = reward, data_preview = preview, data_shape = list(self._df.shape), missing_counts = missing, duplicate_count = dupes, dtype_issues = dtype_issues, task_description = mod.DESCRIPTION, message = message, step_count = self._step_count, current_score = self._last_score, ) def _detect_dtype_issues(self) -> Dict[str, str]: issues: Dict[str, str] = {} for col in self._df.columns: series = self._df[col].dropna() if series.empty: continue if self._df[col].dtype == object: numeric_count = pd.to_numeric(series, errors="coerce").notna().sum() if numeric_count / len(series) > 0.8: issues[col] = "stored as string but appears numeric" return issues # ------------------------------------------------------------------ # Action dispatcher # ------------------------------------------------------------------ def _apply_action(self, action: DataCleaningAction) -> Tuple[str, bool]: op = action.operation.strip().lower() col = action.column p = action.params or {} try: if op == "fill_missing": return self._fill_missing(col, p) elif op == "drop_duplicates": return self._drop_duplicates() elif op == "fix_format": return self._fix_format(col) elif op == "replace_value": return self._replace_value(col, p) elif op == "drop_outliers": return self._drop_outliers(col) elif op == "fix_dtype": return self._fix_dtype(col, p) else: return f"Unknown operation '{op}'. Choose from: fill_missing, drop_duplicates, fix_format, replace_value, drop_outliers, fix_dtype.", False except Exception as exc: return f"Operation failed: {exc}", False def _fill_missing(self, col, p) -> Tuple[str, bool]: if col is None or col not in self._df.columns: return f"Column '{col}' not found.", False n_before = int(self._df[col].isnull().sum()) if n_before == 0: return f"No missing values in '{col}'.", False strategy = str(p.get("strategy", "median")).lower() if strategy == "median": fill_val = self._df[col].median(skipna=True) elif strategy == "mean": fill_val = self._df[col].mean(skipna=True) elif strategy == "mode": mode = self._df[col].mode(dropna=True) fill_val = mode.iloc[0] if not mode.empty else None elif strategy == "constant": fill_val = p.get("value") else: return f"Unknown strategy '{strategy}'.", False if fill_val is None: return "Could not determine fill value.", False self._df[col] = self._df[col].fillna(fill_val) n_after = int(self._df[col].isnull().sum()) return f"Filled {n_before - n_after} missing values in '{col}' using {strategy}.", True def _drop_duplicates(self) -> Tuple[str, bool]: n_before = len(self._df) self._df = self._df.drop_duplicates().reset_index(drop=True) n_after = len(self._df) removed = n_before - n_after if removed == 0: return "No duplicate rows found.", False return f"Dropped {removed} duplicate rows.", True def _fix_format(self, col) -> Tuple[str, bool]: if col is None or col not in self._df.columns: return f"Column '{col}' not found.", False if col == "phone": return self._fix_phone(col) elif col in ("listed_date", "signup_date"): return self._fix_date(col) elif col == "country": return self._fix_country(col) else: return f"No format rule defined for column '{col}'.", False def _fix_phone(self, col) -> Tuple[str, bool]: def normalise(val): if pd.isna(val): return val digits = re.sub(r"\D", "", str(val)) if len(digits) == 10: return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}" return val before = (~self._df[col].str.match(PHONE_RE, na=False)).sum() self._df[col] = self._df[col].apply(normalise) after = (~self._df[col].str.match(PHONE_RE, na=False)).sum() fixed = int(before - after) if fixed == 0: return f"No phone format issues found in '{col}'.", False return f"Fixed {fixed} phone numbers in '{col}' to NNN-NNN-NNNN format.", True def _fix_date(self, col) -> Tuple[str, bool]: _DATE_FORMATS = ["%Y-%m-%d", "%b %d %Y", "%d/%m/%Y", "%m/%d/%Y", "%Y/%m/%d"] def normalise(val): if pd.isna(val): return val s = str(val).strip() for fmt in _DATE_FORMATS: try: return pd.to_datetime(s, format=fmt).strftime("%Y-%m-%d") except Exception: pass # last-resort flexible parse try: return pd.to_datetime(s).strftime("%Y-%m-%d") except Exception: return val before = (~self._df[col].apply( lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False )).sum() self._df[col] = self._df[col].apply(normalise) after = (~self._df[col].apply( lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False )).sum() fixed = int(before - after) if fixed == 0: return f"No date format issues found in '{col}'.", False return f"Fixed {fixed} dates in '{col}' to YYYY-MM-DD format.", True def _fix_country(self, col) -> Tuple[str, bool]: def normalise(val): if pd.isna(val): return val mapping = { "usa": "USA", "uk": "UK", "canada": "Canada", "australia": "Australia", "germany": "Germany", } return mapping.get(str(val).strip().lower(), val) before = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum() self._df[col] = self._df[col].apply(normalise) after = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum() fixed = int(before - after) if fixed == 0: return f"No country capitalisation issues found.", False return f"Fixed {fixed} country values to correct capitalisation.", True def _replace_value(self, col, p) -> Tuple[str, bool]: if col is None or col not in self._df.columns: return f"Column '{col}' not found.", False old = p.get("old") new = p.get("new") if old is None: return "params.old is required for replace_value.", False count = int((self._df[col] == old).sum()) if count == 0: return f"Value '{old}' not found in '{col}'.", False self._df[col] = self._df[col].replace(old, new) return f"Replaced {count} occurrences of '{old}' with '{new}' in '{col}'.", True def _drop_outliers(self, col) -> Tuple[str, bool]: if col is None or col not in self._df.columns: return f"Column '{col}' not found.", False if not pd.api.types.is_numeric_dtype(self._df[col]): return f"'{col}' is not numeric.", False q1 = self._df[col].quantile(0.25) q3 = self._df[col].quantile(0.75) iqr = q3 - q1 mask = (self._df[col] >= q1 - 3 * iqr) & (self._df[col] <= q3 + 3 * iqr) n_before = len(self._df) self._df = self._df[mask | self._df[col].isna()].reset_index(drop=True) removed = n_before - len(self._df) if removed == 0: return f"No outliers found in '{col}'.", False return f"Removed {removed} outlier rows from '{col}' using IQR method.", True def _fix_dtype(self, col, p) -> Tuple[str, bool]: if col is None or col not in self._df.columns: return f"Column '{col}' not found.", False dtype = str(p.get("dtype", "float")).lower() try: if dtype == "float": self._df[col] = pd.to_numeric(self._df[col], errors="coerce").astype(float) elif dtype == "int": self._df[col] = pd.to_numeric(self._df[col], errors="coerce") elif dtype == "str": self._df[col] = self._df[col].astype(str) else: return f"Unknown dtype '{dtype}'.", False return f"Converted '{col}' to {dtype}.", True except Exception as exc: return f"dtype conversion failed: {exc}", False