# Page residue from the hosting site (kept for provenance, not part of the module):
#   YashashMathur's picture
#   Sync from GitHub - all files
#   f762b8d verified
# environment/env.py
# Main OpenEnv Environment for SQL Data Analyst
# Inherits from openenv.BaseEnv and implements reset(), step(), state()
from typing import Dict, Any, Tuple, Optional
from dataclasses import dataclass, field
from .models import Action, Observation, Reward
from .db_engine import DatabaseEngine
from .tasks import Task, get_random_task, TASKS
from .graders import grade_answer, calculate_final_score
# Prefer the real OpenEnv base class; when the openenv package cannot be
# imported, define a minimal stand-in under the same name so this module
# still loads (useful for local development and tests).
try:
    from openenv import BaseEnv
except ImportError:

    class BaseEnv:
        """Empty placeholder used only when ``openenv`` is not importable."""
# --------------------------------------------
# Episode limits and reward shaping (per PRD specification)
# --------------------------------------------

# Step cap: once the step counter reaches this value the infinite-loop
# shield ends the episode.
MAX_STEPS = 15

# Per-step rewards.
REWARD_SUCCESSFUL_QUERY = 0.1     # query executed without an error
REWARD_SYNTAX_ERROR = -0.1        # query came back flagged as an error
REWARD_DESTRUCTIVE_ACTION = -1.0  # destructive (mutating) statement detected
REWARD_INFINITE_LOOP = -0.5       # step counter reached MAX_STEPS
@dataclass
class EnvironmentState:
    """
    Mutable per-episode bookkeeping for the SQL Analyst environment.

    Every field has a default, so ``EnvironmentState()`` describes an
    episode that has not started yet (no task, zero steps, not done).
    """

    # Task currently being solved; None until one has been assigned.
    task: Optional[Task] = None
    # Number of steps taken so far in the current episode.
    step_count: int = 0
    # True once the episode has ended.
    done: bool = False
    # Text result of the most recent SQL query ("" when there is none).
    last_query_result: str = ""
    # Error text produced by the most recent action ("" when there is none).
    error_message: str = ""
    # Every reward handed out during the episode, in order.
    rewards: list = field(default_factory=list)
    # Final grading score for the episode (documented range: 0.0 to 1.0).
    final_score: float = 0.0
    # Whether the task was completed successfully.
    success: bool = False
class SQLAnalystEnv(BaseEnv):
    """
    SQL Data Analyst Reinforcement Learning Environment.

    An agent is given a business question (a ``Task``) and works towards an
    answer by sending SQL queries through ``DatabaseEngine``; the episode
    normally ends when the agent submits an answer for grading.

    Implements the OpenEnv interface:
        - reset(): start a clean episode and return the first observation
        - step(action): execute one action and return
          (observation, reward, done, info)
        - state(): return a dict snapshot of the internal state

    Reward shaping (per PRD):
        - +0.1: successful, error-free SQL query
        - -0.1: query reported as an error by the database engine
        - -1.0: destructive action detected (done=True)
        - -0.5: step count >= MAX_STEPS (infinite loop shield, done=True)
        - 1.0 / 0.0: submitted answer graded correct / incorrect (done=True)
    """

    def __init__(self):
        """Create the database engine wrapper and an empty episode state."""
        super().__init__()
        self.db_engine = DatabaseEngine()
        self._state = EnvironmentState()

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """
        Start a new episode.

        Re-initialises the database through ``db_engine.initialize()``,
        picks the task, clears all per-episode bookkeeping and returns the
        first observation.

        Args:
            task_id: Optional id of a specific task from ``TASKS``. When it
                is omitted, empty, or matches no task, a task is chosen
                with ``get_random_task()`` instead (no error is raised).

        Returns:
            Observation: The initial observation for the episode.
        """
        # Initialise first: if this raises, the previous state is untouched.
        self.db_engine.initialize()

        selected = None
        if task_id:
            # First task whose id matches, or None when there is no match.
            selected = next(
                (candidate for candidate in TASKS if candidate.task_id == task_id),
                None,
            )
        if selected is None:
            selected = get_random_task()

        # A fresh state object resets every field to its default in one go.
        self._state = EnvironmentState(task=selected)

        return Observation(
            schema_info=self.db_engine.get_schema(),
            current_question=self._state.task.question,
            last_query_result="No queries executed yet.",
            error_message="",
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """
        Execute one action in the environment.

        Behaviour, in order:
            1. If the episode is already done, nothing changes and a
               zero reward is returned with done=True.
            2. The step counter is incremented. If it has reached
               ``MAX_STEPS`` the episode ends with ``REWARD_INFINITE_LOOP``
               and the supplied action is NOT executed.
            3. Otherwise ``action.sql_query`` is executed if it is truthy,
               else ``action.submit_answer`` is graded if it is truthy.
               An action with neither still consumes a step, earns 0.0 and
               sets an explanatory error message.

        Args:
            action: The Action to execute.

        Returns:
            Tuple of (observation, reward, done, info).
        """
        state = self._state

        if state.done:
            # Episode already ended: no counting, no reward bookkeeping.
            return self._get_observation(), Reward(value=0.0), True, self._get_info()

        state.step_count += 1

        # Infinite loop shield is checked before the action is looked at.
        if state.step_count >= MAX_STEPS:
            state.done = True
            state.error_message = f"Maximum steps ({MAX_STEPS}) reached. Episode terminated."
            reward = REWARD_INFINITE_LOOP
            state.rewards.append(reward)
            return self._get_observation(), Reward(value=reward), True, self._get_info()

        # Clear the previous step's error before handling the new action.
        state.error_message = ""

        if action.sql_query:
            reward = self._handle_sql_query(action.sql_query)
        elif action.submit_answer:
            reward = self._handle_submit_answer(action.submit_answer)
        else:
            # Previously this case was silent; give the agent feedback.
            reward = 0.0
            state.error_message = (
                "Action must provide either sql_query or submit_answer."
            )

        state.rewards.append(reward)
        return self._get_observation(), Reward(value=reward), state.done, self._get_info()

    def _handle_sql_query(self, query: str) -> float:
        """
        Run a SQL query action and update the episode state.

        Args:
            query: The SQL query text to execute.

        Returns:
            float: ``REWARD_DESTRUCTIVE_ACTION`` (episode ends) when
            ``db_engine.check_mutation`` reports a problem,
            ``REWARD_SYNTAX_ERROR`` when execution is flagged as an error,
            otherwise ``REWARD_SUCCESSFUL_QUERY``.
        """
        # Destructive statements are rejected before anything is executed.
        mutation_error = self.db_engine.check_mutation(query)
        if mutation_error:
            self._state.done = True
            self._state.error_message = mutation_error
            self._state.last_query_result = ""
            return REWARD_DESTRUCTIVE_ACTION

        result, is_error = self.db_engine.execute_query(query)
        if is_error:
            # On failure the returned text is the error message.
            self._state.error_message = result
            self._state.last_query_result = ""
            return REWARD_SYNTAX_ERROR

        self._state.last_query_result = result
        self._state.error_message = ""
        return REWARD_SUCCESSFUL_QUERY

    def _handle_submit_answer(self, answer: str) -> float:
        """
        Grade a submitted answer and end the episode.

        Args:
            answer: The answer submitted for grading.

        Returns:
            float: 1.0 when the answer is graded correct, otherwise 0.0.
            This is separate from ``final_score``, which is computed by
            ``calculate_final_score`` from correctness and the step count.

        Raises:
            RuntimeError: If no task is active (``reset()`` not called yet).
        """
        task = self._state.task
        if task is None:
            # Fail clearly, and before the episode is marked as done.
            raise RuntimeError(
                "No active task: call reset() before submitting an answer."
            )

        # Submitting always ends the episode, whatever the grade.
        self._state.done = True

        # The second value returned by grade_answer is not used here.
        is_correct, _grading_score = grade_answer(
            answer,
            task.ground_truth,
            self.db_engine,
        )

        self._state.success = is_correct
        self._state.final_score = calculate_final_score(
            is_correct,
            self._state.step_count,
            MAX_STEPS,
        )

        return 1.0 if is_correct else 0.0

    def _get_observation(self) -> Observation:
        """
        Build the observation for the current state.

        Returns:
            Observation: Schema, current question ("" when no task is set),
            last query result (a placeholder when empty) and error message.
        """
        task = self._state.task
        return Observation(
            schema_info=self.db_engine.get_schema(),
            current_question=task.question if task else "",
            last_query_result=self._state.last_query_result or "No results yet.",
            error_message=self._state.error_message,
        )

    def _get_info(self) -> Dict[str, Any]:
        """
        Build the info dictionary returned alongside each step.

        Returns:
            Dict: Step count, task id/difficulty (None when no task is
            set), success flag, final score, summed reward and a copy of
            the reward history.
        """
        task = self._state.task
        return {
            "step_count": self._state.step_count,
            "task_id": task.task_id if task else None,
            "task_difficulty": task.difficulty if task else None,
            "success": self._state.success,
            "final_score": self._state.final_score,
            "total_reward": sum(self._state.rewards),
            # Copy so callers cannot mutate the internal history.
            "rewards_history": self._state.rewards.copy(),
        }

    def state(self) -> Dict[str, Any]:
        """
        Return a snapshot of the internal state of the environment.

        Returns:
            Dict: Task fields (None when no task is set), step count, done
            flag, last query result, error message, a copy of the reward
            list, the summed reward, success flag and final score.
        """
        task = self._state.task
        return {
            "task_id": task.task_id if task else None,
            "task_difficulty": task.difficulty if task else None,
            "task_question": task.question if task else None,
            "step_count": self._state.step_count,
            "done": self._state.done,
            "last_query_result": self._state.last_query_result,
            "error_message": self._state.error_message,
            # Copy so callers cannot mutate the internal history.
            "rewards": self._state.rewards.copy(),
            "total_reward": sum(self._state.rewards),
            "success": self._state.success,
            "final_score": self._state.final_score,
        }

    def close(self):
        """Clean up resources by closing the database engine, if present."""
        if self.db_engine:
            self.db_engine.close()