"""Core environment logic for DataDetective.""" import random import uuid from typing import Any, Optional from openenv.core.env_server import Environment try: from ..models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState from .database import create_database, get_schema_info from .tasks import TASKS, grade_answer except (ImportError, ModuleNotFoundError): from models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState from server.database import create_database, get_schema_info from server.tasks import TASKS, grade_answer class DataDetectiveEnvironment( Environment[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState] ): SUPPORTS_CONCURRENT_SESSIONS = True MAX_STEPS = 30 def __init__(self): super().__init__() self._db = None self._task_id: str = "" self._step_count: int = 0 self._episode_id: str = "" self._queries_executed: int = 0 self._state = DataDetectiveState() def reset( self, seed: Optional[int] = None, episode_id: Optional[str] = None, task_id: Optional[str] = None, **kwargs: Any, ) -> DataDetectiveObservation: if seed is not None: random.seed(seed) self._episode_id = episode_id or str(uuid.uuid4()) self._task_id = task_id if task_id in TASKS else random.choice(list(TASKS)) self._step_count = 0 self._queries_executed = 0 if self._db is not None: self._db.close() self._db = create_database() task = TASKS[self._task_id] schema = get_schema_info(self._db) self._state = DataDetectiveState( episode_id=self._episode_id, step_count=0, task_id=self._task_id, queries_executed=0, max_steps=self.MAX_STEPS, ) return DataDetectiveObservation( done=False, reward=None, output="Environment ready. Run SQL queries to investigate the issue, then submit your answer.", task_description=task["description"], schema_info=schema, step_number=0, max_steps=self.MAX_STEPS, message=f"Investigation: {task['title']} [{task['difficulty'].upper()}] -- {self.MAX_STEPS} steps available.", ) def step( self, action: DataDetectiveAction, timeout_s: Optional[float] = None, **kwargs: Any, ) -> DataDetectiveObservation: self._step_count += 1 self._state.step_count = self._step_count remaining = self.MAX_STEPS - self._step_count if self._step_count > self.MAX_STEPS: return self._obs( done=True, reward=0.0, output="Maximum steps reached -- investigation ended with no answer submitted.", message="Out of steps.", ) atype = (action.action_type or "").strip().lower() if atype == "query": return self._handle_query(action.content, remaining) elif atype == "answer": return self._handle_answer(action.content) else: return self._obs( done=False, reward=0.0, output="", message=f"Unknown action_type '{action.action_type}'. Use 'query' or 'answer'. ({remaining} steps left)", ) @property def state(self) -> DataDetectiveState: return self._state def close(self) -> None: if self._db is not None: self._db.close() self._db = None def _obs(self, *, done: bool, reward: float | None, output: str, message: str) -> DataDetectiveObservation: return DataDetectiveObservation( done=done, reward=reward, output=output, task_description=TASKS[self._task_id]["description"], schema_info="", step_number=self._step_count, max_steps=self.MAX_STEPS, message=message, ) def _handle_query(self, sql: str, remaining: int) -> DataDetectiveObservation: self._queries_executed += 1 self._state.queries_executed = self._queries_executed if not sql or not sql.strip(): return self._obs( done=False, reward=0.0, output="Empty query -- please provide a valid SQL statement.", message=f"{remaining} steps left.", ) try: cur = self._db.cursor() cur.execute(sql) columns = [d[0] for d in cur.description] if cur.description else [] rows = cur.fetchall() output = _format_table(columns, rows) if rows else "Query returned 0 rows." except Exception as exc: output = f"SQL Error: {exc}" return self._obs( done=False, reward=0.0, output=output, message=f"Query failed. Fix your SQL and retry. ({remaining} steps left)", ) return self._obs( done=False, reward=0.0, output=output, message=f"{len(rows)} row(s) returned. ({remaining} steps left)", ) def _handle_answer(self, answer_text: str) -> DataDetectiveObservation: reward = grade_answer(self._task_id, answer_text) if reward >= 0.8: verdict = "Excellent investigation!" elif reward >= 0.5: verdict = "Good findings, but some details missing." else: verdict = "Several key findings were missed." return self._obs( done=True, reward=reward, output=f"Score: {reward:.2f} / 1.00 -- {verdict}", message=f"Investigation complete. Final score: {reward:.2f}", ) def _format_table(columns: list[str], rows: list, max_rows: int = 100) -> str: truncated = len(rows) > max_rows display = rows[:max_rows] widths = [len(str(c)) for c in columns] for row in display: for i, v in enumerate(row): widths[i] = max(widths[i], min(len(str(v)), 60)) header = " | ".join(str(c).ljust(widths[i]) for i, c in enumerate(columns)) sep = "-+-".join("-" * w for w in widths) lines = [header, sep] for row in display: lines.append(" | ".join(str(v).ljust(widths[i])[:60] for i, v in enumerate(row))) if truncated: lines.append(f"... (showing {max_rows} of {len(rows)} rows)") return "\n".join(lines)