# Source: sql_env/evaluation/oracle_policy.py
# Uploaded by hjerpe using huggingface_hub (commit 9e64e71, verified).
"""Deterministic oracle policy for upper-bound evaluation baselines."""
from __future__ import annotations

from collections import deque

try:
    from ..models import QuestionRecord, SQLAction, SQLObservation
except ImportError:
    try:
        from models import QuestionRecord, SQLAction, SQLObservation  # type: ignore[no-redef]
    except ImportError:
        from sql_env.models import QuestionRecord, SQLAction, SQLObservation  # type: ignore[no-redef]
class OraclePolicy:
    """Play deterministic optimal actions using question gold data.

    For each episode the policy looks up the observed question text among the
    gold questions and then emits, in order:

    1. one ``DESCRIBE`` per distinct table in ``tables_involved`` (first-seen
       order),
    2. a single ``QUERY`` with the gold SQL,
    3. an ``ANSWER`` with the gold answer.

    Whenever ``budget_remaining`` is 1 or less the policy answers immediately.
    Questions not present in the gold set are answered with an empty string.
    """

    def __init__(self, questions: list[QuestionRecord]) -> None:
        """Index the gold questions by their text.

        Args:
            questions: Gold records. If several share the same
                ``question_text``, the last one wins.
        """
        self._question_lookup: dict[str, QuestionRecord] = {
            question.question_text: question for question in questions
        }
        # Per-episode state; (re)initialised by ``_start_episode``.
        self._current_question: QuestionRecord | None = None
        self._tables_to_describe: deque[str] = deque()
        self._gold_sql_sent = False

    def select_action(self, observation: SQLObservation) -> SQLAction:
        """Select the next deterministic oracle action.

        Args:
            observation: Current observation; ``question``, ``step_count`` and
                ``budget_remaining`` are read.

        Returns:
            The next action of the describe -> query -> answer script, or an
            ``ANSWER`` early when the budget is nearly spent or the question is
            unknown.
        """
        if self._needs_episode_reset(observation):
            self._start_episode(observation.question)

        if self._current_question is None:
            # No gold data for this question: nothing useful to explore.
            return SQLAction(action_type="ANSWER", argument="")

        answer_value = self._gold_answer()
        # Answer once budget is (nearly) exhausted so the episode still ends
        # with an answer -- presumably every action consumes budget.
        if observation.budget_remaining <= 1:
            return SQLAction(action_type="ANSWER", argument=answer_value)

        if self._tables_to_describe:
            table_name = self._tables_to_describe.popleft()
            return SQLAction(action_type="DESCRIBE", argument=table_name)

        if not self._gold_sql_sent:
            self._gold_sql_sent = True
            return SQLAction(action_type="QUERY", argument=self._gold_sql())

        return SQLAction(action_type="ANSWER", argument=answer_value)

    def _needs_episode_reset(self, observation: SQLObservation) -> bool:
        """Return True when ``observation`` should start a new episode.

        A reset happens when no gold question is active (including when the
        previous question was unknown), when ``step_count`` is 0, or when the
        observed question text differs from the active one.
        """
        if self._current_question is None:
            return True
        if observation.step_count == 0:
            return True
        return observation.question != self._current_question.question_text

    def _start_episode(self, question_text: str) -> None:
        """Load gold data for ``question_text`` and reset per-episode state."""
        self._current_question = self._question_lookup.get(question_text)
        self._gold_sql_sent = False
        if self._current_question is None:
            self._tables_to_describe = deque()
            return
        # dict.fromkeys drops duplicate table names while preserving order, so
        # the oracle never spends budget describing the same table twice.
        self._tables_to_describe = deque(
            dict.fromkeys(self._current_question.tables_involved)
        )

    def _gold_sql(self) -> str:
        """Return the active question's gold SQL, or ``""`` if none is active."""
        if self._current_question is None:
            return ""
        return self._current_question.gold_sql

    def _gold_answer(self) -> str:
        """Return the active question's gold answer, or ``""`` if none is active."""
        if self._current_question is None:
            return ""
        return self._current_question.gold_answer