# OpenEnv_hack/server/environment.py
# (Hugging Face file-viewer residue removed; provenance kept for reference:
#  uploaded by Taniieeee83, commit 69e5273 "made changes to reset()", 13.9 kB)
"""
Core environment implementing reset / step / state.
Each call to reset() picks a task (round-robin: 1 → 2 → 3 → 1 …)
or a specific task_id can be forced via reset(task_id=N).
"""
import re
import uuid
import numpy as np
import pandas as pd
from typing import Any, Dict, Optional, Tuple
from models import DataCleaningAction, DataCleaningObservation, DataCleaningState
import server.tasks.task1_missing as t1
import server.tasks.task2_format as t2
import server.tasks.task3_pipeline as t3
# task_id -> task module; each module provides load(), MAX_STEPS and DESCRIPTION
# (see reset()/_build_obs()), plus score() and count_errors() used for scoring.
TASK_MODULES = {1: t1, 2: t2, 3: t3}
# Canonical phone format enforced by _fix_phone: NNN-NNN-NNNN.
PHONE_RE = re.compile(r"^\d{3}-\d{3}-\d{4}$")
# Canonical date format enforced by _fix_date: YYYY-MM-DD.
DATE_RE = re.compile(r"^\d{4}-\d{2}-\d{2}$")
# Country spellings considered already-correct by _fix_country.
VALID_COUNTRIES = {"USA", "UK", "Canada", "Australia", "Germany"}
class DataCleaningEnvironment:
    """Data-cleaning environment: reset() starts an episode on one of three
    tasks, step() applies a single cleaning action, state() reports progress."""

    def __init__(self):
        """Initialise an idle environment; no episode is active until reset()."""
        # Working frame, reference (clean) frame, and task-specific metadata.
        self._df: Optional[pd.DataFrame] = None
        self._clean_df: Optional[pd.DataFrame] = None
        self._meta: Any = None  # task-specific metadata
        # Per-episode bookkeeping.
        self._episode_id: str = ""
        self._task_id: int = 1
        self._step_count: int = 0
        self._max_steps: int = 20
        # Scoring bookkeeping.
        self._total_errors: int = 0
        self._last_score: float = 0.0
        # Round-robin task selector used when reset() gets no explicit task_id.
        self._task_cycle: int = 0
# ------------------------------------------------------------------
# Public API
# ------------------------------------------------------------------
def reset(self, task_id: Optional[int] = None) -> DataCleaningObservation:
if task_id is None:
self._task_cycle = (self._task_cycle % 3) + 1
task_id = self._task_cycle
if task_id not in TASK_MODULES:
raise ValueError(f"task_id must be 1, 2, or 3 — got {task_id}")
mod = TASK_MODULES[task_id]
self._task_id = task_id
self._episode_id = str(uuid.uuid4())
self._step_count = 0
self._max_steps = mod.MAX_STEPS
if task_id == 1:
self._df, self._clean_df, self._meta = mod.load()
else:
self._df, self._clean_df, self._meta = mod.load()
self._last_score = self._compute_score()
self._total_errors = self._count_errors()
return self._build_obs(self._last_score, False, "Episode started. Begin cleaning.")
def step(self, action: DataCleaningAction) -> DataCleaningObservation:
if self._df is None:
raise RuntimeError("Call reset() before step().")
self._step_count += 1
score_before = self._last_score
message, applied = self._apply_action(action)
score_after = self._compute_score()
self._last_score = score_after
delta = score_after - score_before
if not applied:
reward = -0.05
elif delta <= 0:
reward = -0.01
else:
reward = round(delta, 4)
done = (score_after >= 0.95) or (self._step_count >= self._max_steps)
if done and score_after >= 0.95:
reward = round(reward + 0.2, 4)
return self._build_obs(reward, done, message)
def state(self) -> DataCleaningState:
if self._df is None:
return DataCleaningState(
episode_id="", task_id=0, step_count=0,
max_steps=0, total_errors=0, errors_remaining=0,
)
return DataCleaningState(
episode_id = self._episode_id,
task_id = self._task_id,
step_count = self._step_count,
max_steps = self._max_steps,
total_errors = self._total_errors,
errors_remaining = self._count_errors(),
)
# ------------------------------------------------------------------
# Internal helpers
# ------------------------------------------------------------------
def _compute_score(self) -> float:
if self._task_id == 1:
return t1.score(self._df, self._meta)
elif self._task_id == 2:
return t2.score(self._df, self._meta)
else:
return t3.score(self._df, self._meta)
def _count_errors(self) -> int:
if self._task_id == 1:
return t1.count_errors(self._df)
elif self._task_id == 2:
return t2.count_errors(self._df, self._meta)
else:
return t3.count_errors(self._df, self._meta)
def _build_obs(self, reward: float, done: bool, message: str) -> DataCleaningObservation:
mod = TASK_MODULES[self._task_id]
missing = {col: int(n) for col, n in self._df.isnull().sum().items() if n > 0}
dupes = len(self._df) - len(self._df.drop_duplicates())
dtype_issues = self._detect_dtype_issues()
preview = self._df.head(10).to_csv(index=False)
return DataCleaningObservation(
done = done,
reward = reward,
data_preview = preview,
data_shape = list(self._df.shape),
missing_counts = missing,
duplicate_count = dupes,
dtype_issues = dtype_issues,
task_description = mod.DESCRIPTION,
message = message,
step_count = self._step_count,
current_score = self._last_score,
)
def _detect_dtype_issues(self) -> Dict[str, str]:
issues: Dict[str, str] = {}
for col in self._df.columns:
series = self._df[col].dropna()
if series.empty:
continue
if self._df[col].dtype == object:
numeric_count = pd.to_numeric(series, errors="coerce").notna().sum()
if numeric_count / len(series) > 0.8:
issues[col] = "stored as string but appears numeric"
return issues
# ------------------------------------------------------------------
# Action dispatcher
# ------------------------------------------------------------------
def _apply_action(self, action: DataCleaningAction) -> Tuple[str, bool]:
op = action.operation.strip().lower()
col = action.column
p = action.params or {}
try:
if op == "fill_missing":
return self._fill_missing(col, p)
elif op == "drop_duplicates":
return self._drop_duplicates()
elif op == "fix_format":
return self._fix_format(col)
elif op == "replace_value":
return self._replace_value(col, p)
elif op == "drop_outliers":
return self._drop_outliers(col)
elif op == "fix_dtype":
return self._fix_dtype(col, p)
else:
return f"Unknown operation '{op}'. Choose from: fill_missing, drop_duplicates, fix_format, replace_value, drop_outliers, fix_dtype.", False
except Exception as exc:
return f"Operation failed: {exc}", False
def _fill_missing(self, col, p) -> Tuple[str, bool]:
if col is None or col not in self._df.columns:
return f"Column '{col}' not found.", False
n_before = int(self._df[col].isnull().sum())
if n_before == 0:
return f"No missing values in '{col}'.", False
strategy = str(p.get("strategy", "median")).lower()
if strategy == "median":
fill_val = self._df[col].median(skipna=True)
elif strategy == "mean":
fill_val = self._df[col].mean(skipna=True)
elif strategy == "mode":
mode = self._df[col].mode(dropna=True)
fill_val = mode.iloc[0] if not mode.empty else None
elif strategy == "constant":
fill_val = p.get("value")
else:
return f"Unknown strategy '{strategy}'.", False
if fill_val is None:
return "Could not determine fill value.", False
self._df[col] = self._df[col].fillna(fill_val)
n_after = int(self._df[col].isnull().sum())
return f"Filled {n_before - n_after} missing values in '{col}' using {strategy}.", True
def _drop_duplicates(self) -> Tuple[str, bool]:
n_before = len(self._df)
self._df = self._df.drop_duplicates().reset_index(drop=True)
n_after = len(self._df)
removed = n_before - n_after
if removed == 0:
return "No duplicate rows found.", False
return f"Dropped {removed} duplicate rows.", True
def _fix_format(self, col) -> Tuple[str, bool]:
if col is None or col not in self._df.columns:
return f"Column '{col}' not found.", False
if col == "phone":
return self._fix_phone(col)
elif col in ("listed_date", "signup_date"):
return self._fix_date(col)
elif col == "country":
return self._fix_country(col)
else:
return f"No format rule defined for column '{col}'.", False
def _fix_phone(self, col) -> Tuple[str, bool]:
def normalise(val):
if pd.isna(val):
return val
digits = re.sub(r"\D", "", str(val))
if len(digits) == 10:
return f"{digits[:3]}-{digits[3:6]}-{digits[6:]}"
return val
before = (~self._df[col].str.match(PHONE_RE, na=False)).sum()
self._df[col] = self._df[col].apply(normalise)
after = (~self._df[col].str.match(PHONE_RE, na=False)).sum()
fixed = int(before - after)
if fixed == 0:
return f"No phone format issues found in '{col}'.", False
return f"Fixed {fixed} phone numbers in '{col}' to NNN-NNN-NNNN format.", True
def _fix_date(self, col) -> Tuple[str, bool]:
_DATE_FORMATS = ["%Y-%m-%d", "%b %d %Y", "%d/%m/%Y", "%m/%d/%Y", "%Y/%m/%d"]
def normalise(val):
if pd.isna(val):
return val
s = str(val).strip()
for fmt in _DATE_FORMATS:
try:
return pd.to_datetime(s, format=fmt).strftime("%Y-%m-%d")
except Exception:
pass
# last-resort flexible parse
try:
return pd.to_datetime(s).strftime("%Y-%m-%d")
except Exception:
return val
before = (~self._df[col].apply(
lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
)).sum()
self._df[col] = self._df[col].apply(normalise)
after = (~self._df[col].apply(
lambda x: bool(DATE_RE.match(str(x))) if pd.notna(x) else False
)).sum()
fixed = int(before - after)
if fixed == 0:
return f"No date format issues found in '{col}'.", False
return f"Fixed {fixed} dates in '{col}' to YYYY-MM-DD format.", True
def _fix_country(self, col) -> Tuple[str, bool]:
def normalise(val):
if pd.isna(val):
return val
mapping = {
"usa": "USA", "uk": "UK", "canada": "Canada",
"australia": "Australia", "germany": "Germany",
}
return mapping.get(str(val).strip().lower(), val)
before = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum()
self._df[col] = self._df[col].apply(normalise)
after = (~self._df[col].isin(VALID_COUNTRIES) & self._df[col].notna()).sum()
fixed = int(before - after)
if fixed == 0:
return f"No country capitalisation issues found.", False
return f"Fixed {fixed} country values to correct capitalisation.", True
def _replace_value(self, col, p) -> Tuple[str, bool]:
if col is None or col not in self._df.columns:
return f"Column '{col}' not found.", False
old = p.get("old")
new = p.get("new")
if old is None:
return "params.old is required for replace_value.", False
count = int((self._df[col] == old).sum())
if count == 0:
return f"Value '{old}' not found in '{col}'.", False
self._df[col] = self._df[col].replace(old, new)
return f"Replaced {count} occurrences of '{old}' with '{new}' in '{col}'.", True
def _drop_outliers(self, col) -> Tuple[str, bool]:
if col is None or col not in self._df.columns:
return f"Column '{col}' not found.", False
if not pd.api.types.is_numeric_dtype(self._df[col]):
return f"'{col}' is not numeric.", False
q1 = self._df[col].quantile(0.25)
q3 = self._df[col].quantile(0.75)
iqr = q3 - q1
mask = (self._df[col] >= q1 - 3 * iqr) & (self._df[col] <= q3 + 3 * iqr)
n_before = len(self._df)
self._df = self._df[mask | self._df[col].isna()].reset_index(drop=True)
removed = n_before - len(self._df)
if removed == 0:
return f"No outliers found in '{col}'.", False
return f"Removed {removed} outlier rows from '{col}' using IQR method.", True
def _fix_dtype(self, col, p) -> Tuple[str, bool]:
if col is None or col not in self._df.columns:
return f"Column '{col}' not found.", False
dtype = str(p.get("dtype", "float")).lower()
try:
if dtype == "float":
self._df[col] = pd.to_numeric(self._df[col], errors="coerce").astype(float)
elif dtype == "int":
self._df[col] = pd.to_numeric(self._df[col], errors="coerce")
elif dtype == "str":
self._df[col] = self._df[col].astype(str)
else:
return f"Unknown dtype '{dtype}'.", False
return f"Converted '{col}' to {dtype}.", True
except Exception as exc:
return f"dtype conversion failed: {exc}", False