"""
Core OpenEnv environment: SQLOptimizerEnv

Implements the three required methods:
    reset(task_id) → Observation
    step(action)   → (Observation, Reward, done, info)
    state()        → dict (current internal snapshot)
"""

from __future__ import annotations

import math
from typing import Any, Dict, Optional, Tuple

from .models import Action, Observation, Reward, RewardBreakdown
from .tasks import TASKS, TaskDef, get_task
from .reward import compute_step_reward

# Scores reported by the environment are clamped into this closed range,
# so they never equal exactly 0 or 1.
_MIN_SCORE_EPS = 0.001
_MAX_SCORE_EPS = 0.999

# Value written to RewardBreakdown.step_penalty once the episode is past its
# halfway point and the agent has not signalled that it is done.
_LATE_STEP_PENALTY = -0.02


def _strict_score(value: float) -> float:
    """Clamp *value* into [_MIN_SCORE_EPS, _MAX_SCORE_EPS] and round to 4 places."""
    return round(min(max(float(value), _MIN_SCORE_EPS), _MAX_SCORE_EPS), 4)


class SQLOptimizerEnv:
    """SQL Query Optimizer OpenEnv environment.

    Call ``reset(task_id)`` to load a task, then call ``step(action)``
    repeatedly until it reports ``done``. ``state()`` may be called at any
    time for a snapshot of the current episode.
    """

    def __init__(self) -> None:
        self._task: Optional[TaskDef] = None  # None until reset() is called
        self._step_number: int = 0  # step() calls made in this episode
        self._done: bool = False
        self._cumulative_score: float = 0.0  # clamped running sum of step rewards
        self._prev_grader_score: float = 0.0  # baseline passed to reward shaping
        self._history: list[Dict[str, Any]] = []  # one record per step() call
        self._last_grader_score: float = 0.0  # grader score of the latest step

    # ──────────────────────────────────────────────────────────────────────────
    # reset
    # ──────────────────────────────────────────────────────────────────────────

    def reset(self, task_id: int = 1) -> Observation:
        """Start a fresh episode for the given task.

        Args:
            task_id: identifier forwarded to ``get_task``.

        Returns:
            The initial observation (step 0, not done).
        """
        # Look the task up before touching any state: if get_task raises,
        # the previous episode is left as it was.
        self._task = get_task(task_id)
        self._step_number = 0
        self._done = False
        self._cumulative_score = 0.0
        self._prev_grader_score = 0.0
        self._last_grader_score = 0.0
        self._history = []
        return self._make_observation()

    # ──────────────────────────────────────────────────────────────────────────
    # step
    # ──────────────────────────────────────────────────────────────────────────

    def step(
        self, action: Action
    ) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """
        Advance the environment by one step.

        Args:
            action: the agent's submission. ``rewritten_query`` is graded, and
                a true ``is_done`` ends the episode after this step.

        Returns:
            observation: next Observation
            reward: Reward for this step
            done: whether the episode has ended
            info: auxiliary dict with keys ``step_number``, ``grader_score``,
                ``cumulative_score`` and ``is_invalid``

        Raises:
            RuntimeError: if reset() has not been called yet, or if the
                current episode has already ended.
        """
        if self._task is None:
            raise RuntimeError("Call reset() before step().")
        if self._done:
            raise RuntimeError("Episode is done. Call reset() to start a new episode.")

        # A missing, empty or whitespace-only query is never sent to the grader.
        is_invalid = not action.rewritten_query or not action.rewritten_query.strip()

        # Run grader
        if is_invalid:
            # Carry the previous grader score forward instead of grading.
            grader_result_score = self._prev_grader_score
            breakdown = RewardBreakdown()
            feedback = "Empty or invalid query submitted."
        else:
            gr = self._task.grader(action.rewritten_query)
            grader_result_score = gr.score
            breakdown = RewardBreakdown(
                correctness=gr.correctness,
                performance=gr.performance,
                style=gr.style,
                step_penalty=0.0,
            )
            feedback = gr.feedback

        grader_result_score = _strict_score(grader_result_score)

        # Compute the shaped reward from the score delta and step position.
        step_reward = compute_step_reward(
            grader_score=grader_result_score,
            prev_grader_score=self._prev_grader_score,
            step_number=self._step_number,
            max_steps=self._task.max_steps,
            is_done=action.is_done,
            is_invalid=is_invalid,
        )

        # Reflect the late-episode step penalty in the breakdown.
        # NOTE(review): this is presumably meant to mirror the penalty applied
        # inside compute_step_reward — confirm the two thresholds agree.
        halfway = math.ceil(self._task.max_steps / 2)
        if self._step_number > halfway and not action.is_done:
            breakdown.step_penalty = _LATE_STEP_PENALTY

        self._cumulative_score = _strict_score(self._cumulative_score + step_reward)
        self._prev_grader_score = grader_result_score
        self._last_grader_score = grader_result_score
        self._step_number += 1

        # Episode ends if agent signals done OR max steps reached
        self._done = action.is_done or self._step_number >= self._task.max_steps

        # History keeps the unclamped step reward; Reward.score below is clamped.
        self._history.append(
            {
                "step": self._step_number,
                "rewritten_query": action.rewritten_query,
                "grader_score": grader_result_score,
                "step_reward": step_reward,
                "is_done": action.is_done,
            }
        )

        reward = Reward(
            score=_strict_score(step_reward),
            grader_score=grader_result_score,
            breakdown=breakdown,
            feedback=feedback,
            cumulative_score=self._cumulative_score,
        )

        info = {
            "step_number": self._step_number,
            "grader_score": grader_result_score,
            "cumulative_score": self._cumulative_score,
            "is_invalid": is_invalid,
        }

        return self._make_observation(), reward, self._done, info

    # ──────────────────────────────────────────────────────────────────────────
    # state
    # ──────────────────────────────────────────────────────────────────────────

    def state(self) -> Dict[str, Any]:
        """Return a snapshot of the current internal state.

        Before the first reset() this is ``{"status": "not_started"}``.

        The ``history`` list and each of its records are copied. The returned
        dict therefore does not change as the episode continues, and mutating
        it cannot alter the environment's own history.
        """
        if self._task is None:
            return {"status": "not_started"}

        return {
            "task_id": self._task.id,
            "task_name": self._task.name,
            "difficulty": self._task.difficulty,
            "step_number": self._step_number,
            "max_steps": self._task.max_steps,
            "done": self._done,
            "cumulative_score": self._cumulative_score,
            "last_grader_score": self._last_grader_score,
            "history": [dict(record) for record in self._history],
        }

    # ──────────────────────────────────────────────────────────────────────────
    # Internal helpers
    # ──────────────────────────────────────────────────────────────────────────

    def _make_observation(self) -> Observation:
        """Build an Observation from the loaded task and the current progress."""
        # Only called from reset()/step(), which guarantee a task is loaded.
        assert self._task is not None
        return Observation(
            task_id=self._task.id,
            task_name=self._task.name,
            task_description=self._task.description,
            query=self._task.query,
            schema_context=self._task.schema_context,
            hint=self._task.hint,
            step_number=self._step_number,
            max_steps=self._task.max_steps,
            done=self._done,
        )