| """Deterministic oracle policy for upper-bound evaluation baselines.""" |
|
|
| from __future__ import annotations |
|
|
| try: |
| from ..models import QuestionRecord, SQLAction, SQLObservation |
| except ImportError: |
| try: |
| from models import QuestionRecord, SQLAction, SQLObservation |
| except ImportError: |
| from sql_env.models import QuestionRecord, SQLAction, SQLObservation |
|
|
|
|
class OraclePolicy:
    """Play deterministic optimal actions using question gold data.

    For each episode the policy looks up the observed question text among the
    questions it was built with and then plays, in order:

    1. one ``DESCRIBE`` per distinct, non-empty table in ``tables_involved``,
    2. one ``QUERY`` with the gold SQL (skipped when the gold SQL is blank),
    3. ``ANSWER`` with the gold answer.

    Whenever ``observation.budget_remaining <= 1`` it answers immediately, so
    the last usable step is never spent on exploration. Questions that are not
    in the lookup are answered with an empty string.
    """

    def __init__(self, questions: list[QuestionRecord]) -> None:
        """Index ``questions`` by their ``question_text``.

        If several records share the same question text, the last one wins.
        """
        self._question_lookup: dict[str, QuestionRecord] = {
            question.question_text: question for question in questions
        }
        # Per-episode state, (re)initialised by _start_episode.
        self._current_question: QuestionRecord | None = None
        self._tables_to_describe: list[str] = []
        self._gold_sql_sent = False

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Select the next deterministic oracle action for ``observation``."""
        if self._needs_episode_reset(observation):
            self._start_episode(observation.question)

        if self._current_question is None:
            # Unknown question: there is no gold data to replay.
            return SQLAction(action_type="ANSWER", argument="")

        answer_value = self._gold_answer()
        if observation.budget_remaining <= 1:
            # Last usable step: always commit the answer.
            return SQLAction(action_type="ANSWER", argument=answer_value)

        if self._tables_to_describe:
            table_name = self._tables_to_describe.pop(0)
            return SQLAction(action_type="DESCRIBE", argument=table_name)

        if not self._gold_sql_sent:
            self._gold_sql_sent = True
            gold_sql = self._gold_sql()
            # A blank QUERY would only burn a step, so fall through to ANSWER.
            if gold_sql and gold_sql.strip():
                return SQLAction(action_type="QUERY", argument=gold_sql)

        return SQLAction(action_type="ANSWER", argument=answer_value)

    def _needs_episode_reset(self, observation: SQLObservation) -> bool:
        """Return True when per-episode state must be rebuilt.

        That is the case when no question is active, when the observation is
        the first step (``step_count == 0``), or when the question text differs
        from the active question's text.
        """
        if self._current_question is None:
            return True
        if observation.step_count == 0:
            return True
        return observation.question != self._current_question.question_text

    def _start_episode(self, question_text: str) -> None:
        """Reset per-episode state for ``question_text``.

        The DESCRIBE queue keeps the order of ``tables_involved`` but drops
        repeated and empty names, since describing the same table twice spends
        budget without revealing anything new.
        """
        self._current_question = self._question_lookup.get(question_text)
        self._tables_to_describe = []
        self._gold_sql_sent = False
        if self._current_question is not None:
            tables = self._current_question.tables_involved or ()
            # dict.fromkeys keeps first-seen order while removing repeats.
            self._tables_to_describe = list(
                dict.fromkeys(name for name in tables if name)
            )

    def _gold_sql(self) -> str:
        """Return the active question's gold SQL, or "" when none is active."""
        if self._current_question is None:
            return ""
        return self._current_question.gold_sql

    def _gold_answer(self) -> str:
        """Return the active question's gold answer, or "" when none is active."""
        if self._current_question is None:
            return ""
        return self._current_question.gold_answer
|