# environment/env.py # Main OpenEnv Environment for SQL Data Analyst # Inherits from openenv.BaseEnv and implements reset(), step(), state() from typing import Dict, Any, Tuple, Optional from dataclasses import dataclass, field from .models import Action, Observation, Reward from .db_engine import DatabaseEngine from .tasks import Task, get_random_task, TASKS from .graders import grade_answer, calculate_final_score # Try to import openenv.BaseEnv, fallback to a simple base class if not available try: from openenv import BaseEnv except ImportError: # Fallback base class for development/testing class BaseEnv: """Fallback base class when openenv-core is not installed.""" pass # ============================================ # REWARD CONSTANTS (per PRD specification) # ============================================ REWARD_SUCCESSFUL_QUERY = 0.1 # Successful, error-free SQL query REWARD_SYNTAX_ERROR = -0.1 # SQLite syntax error REWARD_DESTRUCTIVE_ACTION = -1.0 # Destructive action detected REWARD_INFINITE_LOOP = -0.5 # Step count >= 15 # Maximum steps before infinite loop shield activates MAX_STEPS = 15 @dataclass class EnvironmentState: """ Internal state of the SQL Analyst environment. Attributes: task: The current task being solved step_count: Number of steps taken in current episode done: Whether the episode has ended last_query_result: Result from the most recent SQL query error_message: Error message from the last action rewards: List of all rewards received in this episode final_score: The final grading score (0.0 to 1.0) success: Whether the task was completed successfully """ task: Optional[Task] = None step_count: int = 0 done: bool = False last_query_result: str = "" error_message: str = "" rewards: list = field(default_factory=list) final_score: float = 0.0 success: bool = False class SQLAnalystEnv(BaseEnv): """ SQL Data Analyst Reinforcement Learning Environment. This environment simulates a Data Analyst workspace where an AI agent queries a SQLite database to answer business questions. Implements the OpenEnv interface: - reset(): Initialize a clean episode - step(action): Execute an action and return (observation, reward, done, info) - state(): Return the current internal state Reward Shaping (per PRD): - +0.1: Successful, error-free SQL query - -0.1: SQLite syntax error - -1.0: Destructive action detected (done=True) - -0.5: Step count >= 15 (infinite loop shield, done=True) """ def __init__(self): """Initialize the SQL Analyst environment.""" super().__init__() self.db_engine = DatabaseEngine() self._state = EnvironmentState() def reset(self, task_id: Optional[str] = None) -> Observation: """ Reset the environment to start a new episode. This method: 1. Initializes a clean in-memory SQLite database 2. Randomly selects 1 of the 3 tasks (or uses specified task) 3. Resets step_count to 0 4. Returns the initial observation Args: task_id: Optional specific task to use Returns: Observation: The initial observation for the episode """ # Initialize clean database self.db_engine.initialize() # Select task if task_id: for task in TASKS: if task.task_id == task_id: self._state.task = task break else: self._state.task = get_random_task() else: self._state.task = get_random_task() # Reset state self._state.step_count = 0 self._state.done = False self._state.last_query_result = "" self._state.error_message = "" self._state.rewards = [] self._state.final_score = 0.0 self._state.success = False # Build initial observation return Observation( schema_info=self.db_engine.get_schema(), current_question=self._state.task.question, last_query_result="No queries executed yet.", error_message="" ) def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]: """ Execute an action in the environment. This method processes the agent's action and returns: - observation: The new state after the action - reward: The reward for this action - done: Whether the episode has ended - info: Additional information Reward Shaping: - +0.1: Successful, error-free SQL query - -0.1: SQLite syntax error - -1.0: Destructive action detected (done=True) - -0.5: Step count >= 15 (done=True) Args: action: The Action to execute Returns: Tuple containing (observation, reward, done, info) """ if self._state.done: # Episode already ended return self._get_observation(), Reward(value=0.0), True, self._get_info() # Increment step count self._state.step_count += 1 # Check for infinite loop shield FIRST if self._state.step_count >= MAX_STEPS: self._state.done = True self._state.error_message = f"Maximum steps ({MAX_STEPS}) reached. Episode terminated." reward = REWARD_INFINITE_LOOP self._state.rewards.append(reward) return self._get_observation(), Reward(value=reward), True, self._get_info() # Initialize reward for this step reward = 0.0 self._state.error_message = "" # Process action if action.sql_query: reward = self._handle_sql_query(action.sql_query) elif action.submit_answer: reward = self._handle_submit_answer(action.submit_answer) # Record reward self._state.rewards.append(reward) return self._get_observation(), Reward(value=reward), self._state.done, self._get_info() def _handle_sql_query(self, query: str) -> float: """ Handle a SQL query action. Args: query: The SQL query to execute Returns: float: The reward for this action """ # Check for destructive action first mutation_error = self.db_engine.check_mutation(query) if mutation_error: self._state.done = True self._state.error_message = mutation_error self._state.last_query_result = "" return REWARD_DESTRUCTIVE_ACTION # Execute the query result, is_error = self.db_engine.execute_query(query) if is_error: self._state.error_message = result self._state.last_query_result = "" return REWARD_SYNTAX_ERROR # Successful query self._state.last_query_result = result self._state.error_message = "" return REWARD_SUCCESSFUL_QUERY def _handle_submit_answer(self, answer: str) -> float: """ Handle a submit answer action. Args: answer: The answer to submit for grading Returns: float: The reward for this action """ # Episode ends when answer is submitted self._state.done = True # Grade the answer is_correct, grading_score = grade_answer( answer, self._state.task.ground_truth, self.db_engine ) # Calculate final score self._state.success = is_correct self._state.final_score = calculate_final_score( is_correct, self._state.step_count, MAX_STEPS ) # Reward for submission is based on correctness # This is separate from the final_score which considers efficiency if is_correct: return 1.0 # Full reward for correct answer else: return 0.0 # No reward for incorrect answer def _get_observation(self) -> Observation: """ Build the current observation. Returns: Observation: The current state visible to the agent """ return Observation( schema_info=self.db_engine.get_schema(), current_question=self._state.task.question if self._state.task else "", last_query_result=self._state.last_query_result or "No results yet.", error_message=self._state.error_message ) def _get_info(self) -> Dict[str, Any]: """ Build the info dictionary. Returns: Dict: Additional information about the current state """ return { "step_count": self._state.step_count, "task_id": self._state.task.task_id if self._state.task else None, "task_difficulty": self._state.task.difficulty if self._state.task else None, "success": self._state.success, "final_score": self._state.final_score, "total_reward": sum(self._state.rewards), "rewards_history": self._state.rewards.copy() } def state(self) -> Dict[str, Any]: """ Return the current internal state of the environment. Returns: Dict: The full internal state """ return { "task_id": self._state.task.task_id if self._state.task else None, "task_difficulty": self._state.task.difficulty if self._state.task else None, "task_question": self._state.task.question if self._state.task else None, "step_count": self._state.step_count, "done": self._state.done, "last_query_result": self._state.last_query_result, "error_message": self._state.error_message, "rewards": self._state.rewards.copy(), "total_reward": sum(self._state.rewards), "success": self._state.success, "final_score": self._state.final_score } def close(self): """Clean up resources.""" if self.db_engine: self.db_engine.close()