"""
nl2sql-bench/tests/test_all.py
================================
Comprehensive test suite covering:
  - Database seeder (determinism + row counts)
  - Grader (all reward components, step penalty, edge cases)
  - Task registry (all 3 tasks load and produce valid examples)
  - Environment (reset, step, episode boundary, done logic)
  - Inference log format (regex checks on START / STEP / END)

Run with:
    pytest tests/ -v
or from project root:
    PYTHONPATH=.:server pytest tests/ -v
"""

from __future__ import annotations

import re
import sqlite3
import sys
import os
from pathlib import Path

import pytest

# ── Path setup so tests can import from both project root and server/ ──────
ROOT = Path(__file__).parent.parent
SERVER = ROOT / "server"
sys.path.insert(0, str(ROOT))
sys.path.insert(0, str(SERVER))


# ── Fixtures ───────────────────────────────────────────────────────────────

@pytest.fixture(scope="session")
def db_conn():
    """Shared in-memory SQLite connection with full schema + seed data.

    The connection is created once per test session, loaded with
    ``server/db/schema.sql`` and populated through ``db.seed.seed_database``.
    Rows are returned as ``sqlite3.Row`` objects.

    Yields:
        The open ``sqlite3.Connection``; it is closed when the session ends.
    """
    from db.seed import seed_database

    # Decode explicitly: the default encoding of read_text() depends on the
    # platform locale, which would make the suite machine-dependent.
    schema = (SERVER / "db" / "schema.sql").read_text(encoding="utf-8")
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    try:
        conn.row_factory = sqlite3.Row
        conn.executescript(schema)
        seed_database(conn)
        yield conn
    finally:
        # Also reached when schema loading or seeding raises before the
        # yield, so a failed setup does not leave the connection open.
        conn.close()


@pytest.fixture
def fresh_env():
    """A fresh NL2SQLEnvironment instance per test."""
    from environment import NL2SQLEnvironment

    return NL2SQLEnvironment()
# ══════════════════════════════════════════════════════════════════════════════
# 1. DATABASE SEEDER
# ══════════════════════════════════════════════════════════════════════════════

class TestSeeder:
    """Row-count, determinism and value-range checks on the seeded database.

    Every test reads from the ``db_conn`` fixture.
    """

    def test_categories_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone()
        assert row[0] == 8, "Should have exactly 8 categories"

    def test_products_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone()
        assert row[0] == 64, "Should have 8 products × 8 categories = 64"

    def test_customers_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone()
        assert row[0] == 150

    def test_orders_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone()
        assert row[0] > 100, "Should have a meaningful number of orders"

    def test_order_items_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone()
        assert row[0] > 200

    def test_reviews_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone()
        assert row[0] > 50

    def test_determinism(self, db_conn):
        """Seeding a second connection with the same seed gives identical counts."""
        from db.seed import seed_database

        # Explicit encoding keeps the read independent of the platform locale.
        schema = (SERVER / "db" / "schema.sql").read_text(encoding="utf-8")
        conn2 = sqlite3.connect(":memory:")
        try:
            conn2.executescript(schema)
            seed_database(conn2)
            # Table names are fixed literals, so the f-string SQL is safe.
            for tbl in ["categories", "products", "customers",
                        "orders", "order_items", "reviews"]:
                c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}"
        finally:
            # Close the scratch connection even when seeding or an assertion
            # fails; previously it was only closed on the success path.
            conn2.close()

    def test_tiers_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')"
        ).fetchone()[0]
        assert bad == 0

    def test_statuses_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM orders "
            "WHERE status NOT IN ('pending','processing','shipped','delivered','cancelled')"
        ).fetchone()[0]
        assert bad == 0

    def test_ratings_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5"
        ).fetchone()[0]
        assert bad == 0

    def test_referential_integrity(self, db_conn):
        """Order items should reference valid orders and products."""
        orphan_orders = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL"
        ).fetchone()[0]
        assert orphan_orders == 0

        orphan_products = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL"
        ).fetchone()[0]
        assert orphan_products == 0


# ══════════════════════════════════════════════════════════════════════════════
# 2. GRADER
# ══════════════════════════════════════════════════════════════════════════════

class TestGrader:
    """Tests for ``grade``, ``execute_query`` and ``compute_ground_truth``
    from the ``grader`` module."""

    def test_exact_match_first_step(self):
        from grader import grade

        gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.reward == pytest.approx(1.0)
        assert result.exact_match is True
        assert result.syntax_ok is True
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.step_penalty == 0.0

    def test_syntax_error_gives_zero(self):
        from grader import grade

        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="near 'SELCT': syntax error",
            step=1,
        )
        assert result.reward == 0.0
        assert result.syntax_ok is False

    def test_step_penalty_applied(self):
        from grader import grade

        gt = [{"n": 1}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=3,  # penalty = (3-1)*0.05 = 0.10
        )
        assert result.reward == pytest.approx(1.0 - 0.10)
        assert result.step_penalty == pytest.approx(0.10)

    def test_columns_wrong_zero_higher_components(self):
        from grader import grade

        gt = [{"name": "Alice", "score": 10}]
        actual = [{"user": "Alice", "points": 10}]  # wrong column names
        result = grade(actual_rows=actual, ground_truth_rows=gt,
                       error=None, step=1)
        assert result.columns_match is False
        assert result.exact_match is False
        # Only syntax score: 0.10
        assert result.reward == pytest.approx(0.10)

    def test_correct_columns_wrong_rows(self):
        from grader import grade

        gt = [{"name": "Alice"}, {"name": "Bob"}]
        actual = [{"name": "Charlie"}, {"name": "Dave"}]
        result = grade(actual_rows=actual, ground_truth_rows=gt,
                       error=None, step=1)
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.exact_match is False
        # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50
        assert result.reward == pytest.approx(0.50)

    def test_order_sensitive_wrong_order_is_not_exact(self):
        from grader import grade

        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # reversed
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=True,
        )
        assert result.exact_match is False

    def test_order_insensitive_accepts_different_row_order(self):
        from grader import grade

        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # different order but same content
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.exact_match is True

    def test_penalty_never_makes_reward_negative(self):
        from grader import grade

        # Step 99 with syntax error → reward must be >= 0
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="some error",
            step=99,
        )
        assert result.reward >= 0.0

    def test_execute_query_blocks_writes(self, db_conn):
        from grader import execute_query

        rows, err = execute_query(db_conn, "INSERT INTO categories(name) VALUES ('x')")
        assert rows is None
        # Guard first: without it a missing error message would surface as an
        # AttributeError on err.lower() instead of a readable test failure.
        assert err is not None, "Blocked write must report an error message"
        assert "not allowed" in err.lower() or "INSERT" in err

    def test_execute_query_returns_rows(self, db_conn):
        from grader import execute_query

        rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id")
        assert err is None
        assert len(rows) == 8
        assert "id" in rows[0]
        assert "name" in rows[0]

    def test_compute_ground_truth(self, db_conn):
        from grader import compute_ground_truth

        rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers")
        assert len(rows) == 1
        assert rows[0]["n"] == 150


# ══════════════════════════════════════════════════════════════════════════════
# 3. TASK REGISTRY
# ══════════════════════════════════════════════════════════════════════════════

class TestTasks:
    """Tests for the task registry exposed by the ``tasks`` module."""

    # Single source for the task names exercised below; previously the same
    # list was repeated in every parametrize decorator.
    TASK_NAMES = ("simple-filter", "join-aggregation", "analytics-window")

    def test_all_tasks_registered(self):
        from tasks import all_task_names

        names = all_task_names()
        for expected in self.TASK_NAMES:
            assert expected in names, f"Task {expected!r} is not registered"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_has_examples(self, task_name):
        from tasks import get_task

        task = get_task(task_name)
        assert len(task.examples) >= 3, f"{task_name} needs at least 3 examples"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_sql_runs_on_real_db(self, task_name, db_conn):
        """Every ground-truth SQL must execute cleanly against the seeded DB."""
        from tasks import get_task
        from grader import execute_query

        task = get_task(task_name)
        for ex in task.examples:
            rows, error = execute_query(db_conn, ex.sql)
            assert error is None, (
                f"Task {task_name!r} SQL failed:\n{ex.sql}\nError: {error}"
            )
            assert rows is not None

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_roundrobin(self, task_name):
        from tasks import get_task

        task = get_task(task_name)
        n = len(task.examples)
        seen = [task.next_example() for _ in range(n * 2)]
        # After n calls, second half should repeat first half
        assert seen[:n] == seen[n:]

    def test_schema_context_non_empty(self):
        from tasks import get_task

        task = get_task("simple-filter")
        ctx = task.schema_context()
        assert "customers" in ctx
        assert "orders" in ctx
        assert "products" in ctx
# ══════════════════════════════════════════════════════════════════════════════
# 4. ENVIRONMENT
# ══════════════════════════════════════════════════════════════════════════════

class TestEnvironment:
    """Episode lifecycle tests: reset, step, termination and write blocking.

    Most tests receive their environment from the ``fresh_env`` fixture.
    """

    def test_reset_returns_observation(self, fresh_env):
        initial = fresh_env.reset(task_name="simple-filter")
        assert initial.question != ""
        assert initial.schema_context != ""
        assert initial.task_name == "simple-filter"
        assert initial.done is False
        assert initial.step == 0
        assert initial.reward is None

    def test_reset_state(self, fresh_env):
        fresh_env.reset(task_name="join-aggregation")
        snapshot = fresh_env.state
        assert snapshot.task_name == "join-aggregation"
        assert snapshot.task_difficulty == "medium"
        assert snapshot.step_count == 0
        assert snapshot.solved is False

    def test_step_increments_step_count(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        action = NL2SQLAction(query="SELECT 1")
        fresh_env.step(action)
        assert fresh_env.state.step_count == 1

    def test_step_syntax_error_gives_nonzero_error(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        broken = NL2SQLAction(query="SELCT * FORM broken")
        outcome = fresh_env.step(broken)
        assert outcome.last_error is not None
        assert outcome.reward == 0.0

    def test_step_valid_query_returns_result(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        action = NL2SQLAction(
            query="SELECT id, name FROM customers ORDER BY name LIMIT 5"
        )
        outcome = fresh_env.step(action)
        assert outcome.last_error is None
        assert len(outcome.last_result) <= 5
        assert outcome.reward >= 0.0

    def test_exact_match_ends_episode(self, fresh_env):
        """Submitting the exact ground-truth SQL should solve the episode."""
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        # The reference SQL is read from the environment's internal example.
        reference_sql = fresh_env._example.sql
        outcome = fresh_env.step(NL2SQLAction(query=reference_sql))
        assert outcome.done is True
        assert fresh_env.state.solved is True
        assert outcome.reward == pytest.approx(1.0)  # step 1, full score

    def test_max_steps_ends_episode(self, fresh_env):
        """Exhausting all steps should end the episode even without solving."""
        from models import NL2SQLAction
        from environment import MAX_STEPS

        fresh_env.reset(task_name="analytics-window")
        final = None
        for _attempt in range(MAX_STEPS):
            final = fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert final is not None
        assert final.done is True

    def test_reset_clears_previous_episode(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        # Resetting again must discard the step taken above.
        restarted = fresh_env.reset(task_name="join-aggregation")
        assert fresh_env.state.step_count == 0
        assert restarted.step == 0
        assert restarted.task_name == "join-aggregation"

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_all_tasks_solvable(self, task_name):
        """Ground-truth SQL should always produce reward == 1.0 on step 1."""
        from environment import NL2SQLEnvironment
        from models import NL2SQLAction

        env = NL2SQLEnvironment()
        env.reset(task_name=task_name)
        reference_sql = env._example.sql
        outcome = env.step(NL2SQLAction(query=reference_sql))
        assert outcome.done is True
        assert outcome.reward == pytest.approx(1.0), (
            f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n"
            f"SQL: {reference_sql}\nError: {outcome.last_error}\nReward: {outcome.reward}"
        )

    def test_score_normalised_to_0_1(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        for _attempt in range(3):
            outcome = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x"))
            assert 0.0 <= outcome.score <= 1.0

    def test_write_query_blocked(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        write_attempt = NL2SQLAction(
            query="INSERT INTO categories(name) VALUES ('hack')"
        )
        outcome = fresh_env.step(write_attempt)
        assert outcome.last_error is not None
        assert "not allowed" in outcome.last_error.lower() or "INSERT" in outcome.last_error
LOG FORMAT COMPLIANCE # ══════════════════════════════════════════════════════════════════════════════ class TestLogFormat: """Validate that the inference.py log helpers emit correct format.""" START_RE = re.compile( r"^\[START\] task=\S+ env=\S+ model=\S+$" ) STEP_RE = re.compile( r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} " r"done=(true|false) error=.+$" ) END_RE = re.compile( r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} " r"rewards=[\d.,]+$" ) def _capture(self, func, *args, **kwargs) -> str: import io from contextlib import redirect_stdout buf = io.StringIO() with redirect_stdout(buf): func(*args, **kwargs) return buf.getvalue().strip() def test_log_start_format(self): sys.path.insert(0, str(ROOT)) from inference import log_start out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B") assert self.START_RE.match(out), f"Bad [START] format: {out!r}" def test_log_step_format_null_error(self): from inference import log_step out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None) assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}" def test_log_step_format_with_error(self): from inference import log_step out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error") assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}" def test_log_end_format_success(self): from inference import log_end out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0]) assert self.END_RE.match(out), f"Bad [END] format: {out!r}" def test_log_end_format_failure(self): from inference import log_end out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0]) assert self.END_RE.match(out), f"Bad [END] format: {out!r}" def test_reward_two_decimal_places(self): from inference import log_step out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None) # reward= field must have exactly 2 decimal places match = re.search(r"reward=(\d+\.\d+)", out) assert match, "No reward= field found" assert 
len(match.group(1).split(".")[1]) == 2 def test_score_three_decimal_places(self): from inference import log_end out = self._capture(log_end, True, 1, 1.0, [1.0]) match = re.search(r"score=(\d+\.\d+)", out) assert match assert len(match.group(1).split(".")[1]) == 3