# File size: 10,902 Bytes
# f762b8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
# environment/env.py
# Main OpenEnv Environment for SQL Data Analyst
# Inherits from openenv.BaseEnv and implements reset(), step(), state()

from typing import Dict, Any, Tuple, Optional
from dataclasses import dataclass, field
from .models import Action, Observation, Reward
from .db_engine import DatabaseEngine
from .tasks import Task, get_random_task, TASKS
from .graders import grade_answer, calculate_final_score

# Prefer the real OpenEnv base class. When the ``openenv`` package is missing
# (local development, unit tests), substitute an empty placeholder so that
# SQLAnalystEnv can still be defined and instantiated.
try:
    from openenv import BaseEnv
except ImportError:

    class BaseEnv:
        """Empty stand-in for ``openenv.BaseEnv`` used when openenv-core is not installed."""


# ============================================
# REWARD CONSTANTS (per PRD specification)
# ============================================

# Step budget: once an episode's step_count reaches this value the
# infinite-loop shield terminates it.
MAX_STEPS = 15

REWARD_SUCCESSFUL_QUERY = 0.1  # SQL query executed without error
REWARD_SYNTAX_ERROR = -0.1  # SQL query rejected with an error by SQLite
REWARD_DESTRUCTIVE_ACTION = -1.0  # destructive (mutating) statement detected
REWARD_INFINITE_LOOP = -0.5  # step_count >= MAX_STEPS


@dataclass
class EnvironmentState:
    """
    Mutable per-episode state of the SQL Analyst environment.

    Field order defines the positional parameters of the generated __init__,
    so do not reorder fields.

    Attributes:
        task: The task being solved in the current episode (None until a task
            has been assigned, e.g. by SQLAnalystEnv.reset()).
        step_count: Number of steps taken in the current episode.
        done: Whether the episode has ended.
        last_query_result: Result text of the most recent successful SQL
            query ("" after a failed or rejected query).
        error_message: Error message produced by the last action ("" if none).
        rewards: Rewards received in this episode, one entry per recorded step.
        final_score: Score from calculate_final_score(), set when an answer is
            submitted (documented as 0.0 to 1.0 — TODO confirm in graders).
        success: Whether the submitted answer was graded correct.
    """
    task: Optional[Task] = None
    step_count: int = 0
    done: bool = False
    last_query_result: str = ""
    error_message: str = ""
    rewards: list = field(default_factory=list)  # fresh list per instance, never shared
    final_score: float = 0.0
    success: bool = False


class SQLAnalystEnv(BaseEnv):
    """
    SQL Data Analyst reinforcement-learning environment.

    The agent works in a simulated data-analyst workspace: it issues SQL
    queries against a SQLite database and finally submits an answer to a
    business question.

    Implements the OpenEnv interface:
    - reset(): initialize a clean episode
    - step(action): apply an action and return (observation, reward, done, info)
    - state(): return a snapshot of the internal state

    Reward shaping (per PRD):
    - +0.1: successful, error-free SQL query
    - -0.1: SQL query the database engine reports as an error
    - -1.0: destructive action detected (done=True)
    - -0.5: step count >= MAX_STEPS, the infinite-loop shield (done=True)
    - +1.0 / 0.0: submitted answer graded correct / incorrect (done=True)
    """

    def __init__(self):
        """Create the database engine and an empty, not-yet-reset episode state."""
        super().__init__()
        self.db_engine = DatabaseEngine()
        self._state = EnvironmentState()

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """
        Reset the environment to start a new episode.

        Steps:
        1. Re-initialize the database via ``db_engine.initialize()``.
        2. Select the task whose ``task_id`` matches, or a random task.
        3. Clear all per-episode state (step count, rewards, score, ...).
        4. Return the initial observation.

        Args:
            task_id: Optional id of the task to use. If it is empty/None or
                matches no entry in ``TASKS``, a random task is chosen instead.

        Returns:
            Observation: The initial observation for the episode.
        """
        self.db_engine.initialize()

        task = None
        if task_id:
            task = next((t for t in TASKS if t.task_id == task_id), None)
        if task is None:
            # NOTE: an unknown task_id silently falls back to a random task;
            # kept for backward compatibility with existing callers.
            task = get_random_task()

        # A fresh state object puts every per-episode field back at its
        # default, including any field added to EnvironmentState later.
        self._state = EnvironmentState(task=task)

        return Observation(
            schema_info=self.db_engine.get_schema(),
            current_question=task.question,
            last_query_result="No queries executed yet.",
            error_message="",
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """
        Execute one agent action.

        Reward shaping:
        - +0.1: successful, error-free SQL query
        - -0.1: SQL query reported as an error
        - -1.0: destructive action detected (done=True)
        - -0.5: step count >= MAX_STEPS (done=True)
        - +1.0 / 0.0: correct / incorrect submitted answer (done=True)

        Args:
            action: The Action to execute. ``sql_query`` takes precedence over
                ``submit_answer`` when both are set.

        Returns:
            Tuple of (observation, reward, done, info).

        Raises:
            RuntimeError: If called before ``reset()``.
        """
        if self._state.task is None:
            # Without reset() there is no task and the database was never
            # initialized; fail loudly instead of crashing later in grading.
            raise RuntimeError("reset() must be called before step().")

        if self._state.done:
            # Episode already ended: zero reward and nothing is recorded.
            return self._get_observation(), Reward(value=0.0), True, self._get_info()

        self._state.step_count += 1

        # The infinite-loop shield runs BEFORE the action is processed, so the
        # MAX_STEPS-th call ends the episode without executing its action.
        if self._state.step_count >= MAX_STEPS:
            self._state.done = True
            self._state.error_message = f"Maximum steps ({MAX_STEPS}) reached. Episode terminated."
            self._state.rewards.append(REWARD_INFINITE_LOOP)
            return (
                self._get_observation(),
                Reward(value=REWARD_INFINITE_LOOP),
                True,
                self._get_info(),
            )

        self._state.error_message = ""
        if action.sql_query:
            reward = self._handle_sql_query(action.sql_query)
        elif action.submit_answer:
            reward = self._handle_submit_answer(action.submit_answer)
        else:
            # Neither field is set: the step is still consumed, but tell the
            # agent why nothing happened instead of failing silently.
            reward = 0.0
            self._state.error_message = "Action must contain either sql_query or submit_answer."

        self._state.rewards.append(reward)
        return self._get_observation(), Reward(value=reward), self._state.done, self._get_info()

    def _handle_sql_query(self, query: str) -> float:
        """
        Run a SQL query action and update the query-result/error state.

        Args:
            query: The SQL query to execute.

        Returns:
            float: REWARD_DESTRUCTIVE_ACTION (and ends the episode) if the
            engine flags the query as a mutation, REWARD_SYNTAX_ERROR if
            execution reports an error, otherwise REWARD_SUCCESSFUL_QUERY.
        """
        # Mutations are rejected before execution so the database is never modified.
        mutation_error = self.db_engine.check_mutation(query)
        if mutation_error:
            self._state.done = True
            self._state.error_message = mutation_error
            self._state.last_query_result = ""
            return REWARD_DESTRUCTIVE_ACTION

        result, is_error = self.db_engine.execute_query(query)
        if is_error:
            # On error the engine's result text is the error message.
            self._state.error_message = result
            self._state.last_query_result = ""
            return REWARD_SYNTAX_ERROR

        self._state.last_query_result = result
        self._state.error_message = ""
        return REWARD_SUCCESSFUL_QUERY

    def _handle_submit_answer(self, answer: str) -> float:
        """
        Grade a submitted answer and end the episode.

        Args:
            answer: The answer to submit for grading.

        Returns:
            float: 1.0 if the answer is graded correct, otherwise 0.0. This
            step reward is separate from ``final_score``, which also accounts
            for efficiency (step count).
        """
        # Submitting always ends the episode, whatever the grade.
        self._state.done = True

        # The grader's secondary score is not used by the environment.
        is_correct, _ = grade_answer(
            answer,
            self._state.task.ground_truth,
            self.db_engine,
        )

        self._state.success = is_correct
        self._state.final_score = calculate_final_score(
            is_correct,
            self._state.step_count,
            MAX_STEPS,
        )
        return 1.0 if is_correct else 0.0

    def _get_observation(self) -> Observation:
        """Build the observation visible to the agent from the current state."""
        task = self._state.task
        return Observation(
            schema_info=self.db_engine.get_schema(),
            current_question=task.question if task else "",
            last_query_result=self._state.last_query_result or "No results yet.",
            error_message=self._state.error_message,
        )

    def _get_info(self) -> Dict[str, Any]:
        """Build the ``info`` dict returned by ``step()``."""
        task = self._state.task
        return {
            "step_count": self._state.step_count,
            "task_id": task.task_id if task else None,
            "task_difficulty": task.difficulty if task else None,
            "success": self._state.success,
            "final_score": self._state.final_score,
            "total_reward": sum(self._state.rewards),
            # Copy so callers cannot mutate the environment's history.
            "rewards_history": self._state.rewards.copy(),
        }

    def state(self) -> Dict[str, Any]:
        """
        Return a snapshot of the environment's internal state.

        Returns:
            Dict: Task metadata, progress counters, last query/error and
            reward/score information. The rewards list is a copy.
        """
        task = self._state.task
        return {
            "task_id": task.task_id if task else None,
            "task_difficulty": task.difficulty if task else None,
            "task_question": task.question if task else None,
            "step_count": self._state.step_count,
            "done": self._state.done,
            "last_query_result": self._state.last_query_result,
            "error_message": self._state.error_message,
            "rewards": self._state.rewards.copy(),
            "total_reward": sum(self._state.rewards),
            "success": self._state.success,
            "final_score": self._state.final_score,
        }

    def close(self):
        """Release the database engine's resources."""
        if self.db_engine:
            self.db_engine.close()