"""Unit tests for OraclePolicy action selection."""

from __future__ import annotations

from sql_env.evaluation.policies import Policy
from sql_env.evaluation.oracle_policy import OraclePolicy
from sql_env.models import QuestionRecord, SQLAction, SQLObservation


def _question(
    *,
    text: str,
    tables: list[str],
    gold_sql: str,
    gold_answer: str,
    question_id: str = "q1",
) -> QuestionRecord:
    """Build a QuestionRecord for the tests.

    ``database_name``, ``answer_type`` and ``difficulty`` are fixed
    placeholders; callers vary only the text, tables, gold SQL/answer and id.
    """
    return QuestionRecord(
        question_id=question_id,
        question_text=text,
        database_name="db",
        gold_sql=gold_sql,
        gold_answer=gold_answer,
        answer_type="string",
        difficulty="easy",
        tables_involved=tables,
    )


def _obs(
    *,
    question: str,
    step_count: int = 0,
    budget_remaining: int = 10,
) -> SQLObservation:
    """Build a not-done observation for ``question``.

    ``step_count`` and ``budget_remaining`` are the values tests change to
    drive the policy through an episode; the schema text is fixed, the
    result/error are empty, and the action history is empty.
    """
    return SQLObservation(
        question=question,
        schema_info="Available tables:\n- employees\n- departments",
        result="",
        error="",
        step_count=step_count,
        budget_remaining=budget_remaining,
        action_history=[],
        done=False,
        reward=None,
    )


# --- construction -----------------------------------------------------------


def test_init_builds_lookup_from_questions() -> None:
    first = _question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")
    second = _question(
        text="Q2",
        tables=["t2"],
        gold_sql="SELECT 2",
        gold_answer="2",
        question_id="q2",
    )

    policy = OraclePolicy([first, second])

    assert set(policy._question_lookup) == {"Q1", "Q2"}


def test_init_empty_questions() -> None:
    policy = OraclePolicy([])

    assert policy._question_lookup == {}


def test_init_single_question() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")]
    )

    assert set(policy._question_lookup) == {"Q1"}


def test_init_duplicate_question_text() -> None:
    """When two records share question text, the later one wins."""
    first = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="first"
    )
    second = _question(
        text="Q1",
        tables=["t2"],
        gold_sql="SELECT 2",
        gold_answer="second",
        question_id="q2",
    )

    policy = OraclePolicy([first, second])

    assert policy._question_lookup["Q1"].gold_answer == "second"


def test_init_state_defaults() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")]
    )

    assert policy._current_question is None
    assert policy._tables_to_describe == []
    assert policy._gold_sql_sent is False


# --- phase progression: DESCRIBE* -> QUERY -> ANSWER ------------------------


def test_select_action_describe_phase() -> None:
    policy = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["t1", "t2"],
                gold_sql="SELECT * FROM t1",
                gold_answer="A",
            )
        ]
    )

    action = policy.select_action(_obs(question="Q1"))

    assert action == SQLAction(action_type="DESCRIBE", argument="t1")


def test_select_action_describe_second_table() -> None:
    policy = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["t1", "t2"],
                gold_sql="SELECT * FROM t1",
                gold_answer="A",
            )
        ]
    )

    policy.select_action(_obs(question="Q1"))
    action = policy.select_action(_obs(question="Q1", step_count=1))

    assert action == SQLAction(action_type="DESCRIBE", argument="t2")


def test_select_action_query_phase() -> None:
    policy = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["t1", "t2"],
                gold_sql="SELECT * FROM t1",
                gold_answer="A",
            )
        ]
    )

    policy.select_action(_obs(question="Q1"))
    policy.select_action(_obs(question="Q1", step_count=1))
    action = policy.select_action(_obs(question="Q1", step_count=2))

    assert action == SQLAction(action_type="QUERY", argument="SELECT * FROM t1")


def test_select_action_answer_phase() -> None:
    policy = OraclePolicy(
        [
            _question(
                text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
            )
        ]
    )

    policy.select_action(_obs(question="Q1"))
    policy.select_action(_obs(question="Q1", step_count=1))
    action = policy.select_action(_obs(question="Q1", step_count=2))

    assert action == SQLAction(action_type="ANSWER", argument="A")


def test_full_episode_sequence() -> None:
    """A one-table question yields exactly DESCRIBE, QUERY, ANSWER."""
    policy = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["employees"],
                gold_sql="SELECT COUNT(*) FROM employees",
                gold_answer="3",
            )
        ]
    )

    actions = [
        policy.select_action(_obs(question="Q1", step_count=step))
        for step in range(3)
    ]

    # Compare whole actions, not just their types, so a wrong table name,
    # SQL string or answer is caught as well as a wrong phase order.
    assert actions == [
        SQLAction(action_type="DESCRIBE", argument="employees"),
        SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employees"),
        SQLAction(action_type="ANSWER", argument="3"),
    ]


# --- episode boundaries -----------------------------------------------------


def test_new_episode_resets_state() -> None:
    """Switching to a new question restarts at the DESCRIBE phase."""
    q1 = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    q2 = _question(
        text="Q2",
        tables=["t2"],
        gold_sql="SELECT * FROM t2",
        gold_answer="B",
        question_id="q2",
    )
    policy = OraclePolicy([q1, q2])

    policy.select_action(_obs(question="Q1"))
    policy.select_action(_obs(question="Q1", step_count=1))
    action = policy.select_action(_obs(question="Q2", step_count=0))

    assert action == SQLAction(action_type="DESCRIBE", argument="t2")
    assert policy._gold_sql_sent is False


def test_new_episode_lookup() -> None:
    q1 = _question(
        text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
    )
    q2 = _question(
        text="Q2",
        tables=["t2"],
        gold_sql="SELECT * FROM t2",
        gold_answer="B",
        question_id="q2",
    )
    policy = OraclePolicy([q1, q2])

    policy.select_action(_obs(question="Q1"))
    policy.select_action(_obs(question="Q2", step_count=0))

    assert policy._current_question is q2


# --- questions with no tables -----------------------------------------------


def test_zero_tables_skips_describe() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")]
    )

    action = policy.select_action(_obs(question="Q1"))

    assert action == SQLAction(action_type="QUERY", argument="SELECT 1")


def test_zero_tables_then_answer() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")]
    )

    policy.select_action(_obs(question="Q1"))
    action = policy.select_action(_obs(question="Q1", step_count=1))

    assert action == SQLAction(action_type="ANSWER", argument="1")


# --- unknown questions ------------------------------------------------------


def test_unknown_question_returns_empty_answer() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")]
    )

    action = policy.select_action(_obs(question="UNKNOWN"))

    assert action == SQLAction(action_type="ANSWER", argument="")


def test_unknown_question_no_crash() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")]
    )

    result = policy.select_action(_obs(question="UNKNOWN"))

    assert isinstance(result, SQLAction)
    assert result.action_type == "ANSWER"


# --- budget handling --------------------------------------------------------


def test_budget_one_forces_answer() -> None:
    """With one step left, the policy answers even if tables remain to describe."""
    policy = OraclePolicy(
        [
            _question(
                text="Q1",
                tables=["t1", "t2"],
                gold_sql="SELECT * FROM t1",
                gold_answer="A",
            )
        ]
    )

    policy.select_action(_obs(question="Q1"))
    action = policy.select_action(_obs(question="Q1", step_count=1, budget_remaining=1))

    assert action == SQLAction(action_type="ANSWER", argument="A")


def test_budget_one_forces_answer_before_query() -> None:
    """With one step left, the policy answers instead of sending the gold SQL."""
    policy = OraclePolicy(
        [
            _question(
                text="Q1", tables=["t1"], gold_sql="SELECT * FROM t1", gold_answer="A"
            )
        ]
    )

    policy.select_action(_obs(question="Q1"))
    action = policy.select_action(_obs(question="Q1", step_count=1, budget_remaining=1))

    assert action == SQLAction(action_type="ANSWER", argument="A")


def test_budget_one_unknown_question() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")]
    )

    action = policy.select_action(_obs(question="UNKNOWN", budget_remaining=1))

    assert action == SQLAction(action_type="ANSWER", argument="")


# --- interface --------------------------------------------------------------


def test_select_action_returns_sql_action() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")]
    )

    result = policy.select_action(_obs(question="Q1"))

    assert isinstance(result, SQLAction)


def test_select_action_valid_action_types() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=["t1"], gold_sql="SELECT 1", gold_answer="1")]
    )

    actions = [
        policy.select_action(_obs(question="Q1", step_count=0)),
        policy.select_action(_obs(question="Q1", step_count=1)),
        policy.select_action(_obs(question="Q1", step_count=2)),
    ]

    assert {action.action_type for action in actions}.issubset(
        {"DESCRIBE", "QUERY", "ANSWER"}
    )


def test_oracle_satisfies_policy_protocol() -> None:
    policy = OraclePolicy(
        [_question(text="Q1", tables=[], gold_sql="SELECT 1", gold_answer="1")]
    )

    # NOTE(review): isinstance() against a typing.Protocol only works when the
    # protocol is decorated with @runtime_checkable (otherwise it raises
    # TypeError) — presumably Policy is; confirm in sql_env.evaluation.policies.
    assert isinstance(policy, Policy)


def test_oracle_has_select_action_method() -> None:
    assert callable(getattr(OraclePolicy, "select_action", None))