"""Deterministic oracle policy for upper-bound evaluation baselines.""" from __future__ import annotations try: from ..models import QuestionRecord, SQLAction, SQLObservation except ImportError: try: from models import QuestionRecord, SQLAction, SQLObservation # type: ignore[no-redef] except ImportError: from sql_env.models import QuestionRecord, SQLAction, SQLObservation # type: ignore[no-redef] class OraclePolicy: """Play deterministic optimal actions using question gold data.""" def __init__(self, questions: list[QuestionRecord]) -> None: self._question_lookup: dict[str, QuestionRecord] = { question.question_text: question for question in questions } self._current_question: QuestionRecord | None = None self._tables_to_describe: list[str] = [] self._gold_sql_sent = False def select_action(self, observation: SQLObservation) -> SQLAction: """Select the next deterministic oracle action.""" if self._needs_episode_reset(observation): self._start_episode(observation.question) if self._current_question is None: return SQLAction(action_type="ANSWER", argument="") answer_value = self._gold_answer() if observation.budget_remaining <= 1: return SQLAction(action_type="ANSWER", argument=answer_value) if self._tables_to_describe: table_name = self._tables_to_describe.pop(0) return SQLAction(action_type="DESCRIBE", argument=table_name) if not self._gold_sql_sent: self._gold_sql_sent = True return SQLAction(action_type="QUERY", argument=self._gold_sql()) return SQLAction(action_type="ANSWER", argument=answer_value) def _needs_episode_reset(self, observation: SQLObservation) -> bool: if self._current_question is None: return True if observation.step_count == 0: return True return observation.question != self._current_question.question_text def _start_episode(self, question_text: str) -> None: self._current_question = self._question_lookup.get(question_text) self._tables_to_describe = [] self._gold_sql_sent = False if self._current_question is not None: self._tables_to_describe = list(self._current_question.tables_involved) def _gold_sql(self) -> str: if self._current_question is None: return "" return self._current_question.gold_sql def _gold_answer(self) -> str: if self._current_question is None: return "" return self._current_question.gold_answer