Spaces:
Sleeping
Sleeping
"""
nl2sql-bench/tests/test_all.py
================================
End-to-end test suite for nl2sql-bench. It exercises:
  - the database seeder (deterministic output and expected row counts)
  - the grader (every reward component, the step penalty, edge cases)
  - the task registry (all three tasks load and yield valid examples)
  - the environment (reset, step, episode boundaries, done logic)
  - inference log formatting (regex checks on START / STEP / END lines)

Run with:
    pytest tests/ -v
or, from the project root:
    PYTHONPATH=.:server pytest tests/ -v
"""
from __future__ import annotations

import os
import re
import sqlite3
import sys
from pathlib import Path

import pytest

# ── Import paths: make both the project root and server/ importable ─────────
ROOT = Path(__file__).parent.parent
SERVER = ROOT.joinpath("server")
# Same effect as inserting ROOT and then SERVER at index 0: SERVER ends up
# first, so modules under server/ take precedence over same-named root ones.
sys.path[0:0] = [str(SERVER), str(ROOT)]
# ── Fixtures ─────────────────────────────────────────────────────────────────
@pytest.fixture(scope="module")
def db_conn():
    """Yield an in-memory SQLite connection loaded with the schema and seed data.

    Module-scoped so all tests in this file share one seeded database. The
    connection is opened with ``check_same_thread=False`` and uses
    ``sqlite3.Row`` so columns can be read by name. It is closed on teardown,
    and also if building the schema or seeding fails.
    """
    from db.seed import seed_database

    # Read the schema before opening the connection so that a missing or
    # unreadable file cannot leave a connection open.
    schema = (SERVER / "db" / "schema.sql").read_text(encoding="utf-8")
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    conn.row_factory = sqlite3.Row
    try:
        conn.executescript(schema)
        seed_database(conn)
        yield conn
    finally:
        conn.close()
@pytest.fixture
def fresh_env():
    """Return a brand-new NL2SQLEnvironment for each test (function scope)."""
    from environment import NL2SQLEnvironment

    return NL2SQLEnvironment()
# ──────────────────────────────────────────────────────────────────────────────
# 1. DATABASE SEEDER
# ──────────────────────────────────────────────────────────────────────────────
class TestSeeder:
    """Sanity checks on the seeded database: row counts, value domains, FKs."""

    def test_categories_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone()
        assert row[0] == 8, "Should have exactly 8 categories"

    def test_products_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone()
        assert row[0] == 64, "Should have 8 products × 8 categories = 64"

    def test_customers_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone()
        assert row[0] == 150

    def test_orders_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone()
        assert row[0] > 100, "Should have a meaningful number of orders"

    def test_order_items_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone()
        assert row[0] > 200

    def test_reviews_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone()
        assert row[0] > 50

    def test_determinism(self, db_conn):
        """Seeding a second connection with the same seed gives identical counts."""
        from contextlib import closing

        from db.seed import seed_database

        schema = (SERVER / "db" / "schema.sql").read_text(encoding="utf-8")
        # A bare sqlite3 connection used as a context manager only manages the
        # transaction; closing() guarantees conn2 is released even when one of
        # the assertions below fails.
        with closing(sqlite3.connect(":memory:")) as conn2:
            conn2.executescript(schema)
            seed_database(conn2)
            for tbl in ["categories", "products", "customers", "orders",
                        "order_items", "reviews"]:
                c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}"

    def test_tiers_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')"
        ).fetchone()[0]
        assert bad == 0

    def test_statuses_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM orders "
            "WHERE status NOT IN ('pending','processing','shipped','delivered','cancelled')"
        ).fetchone()[0]
        assert bad == 0

    def test_ratings_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5"
        ).fetchone()[0]
        assert bad == 0

    def test_referential_integrity(self, db_conn):
        """Order items should reference valid orders and products."""
        orphan_orders = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL"
        ).fetchone()[0]
        assert orphan_orders == 0
        orphan_products = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL"
        ).fetchone()[0]
        assert orphan_products == 0
# ──────────────────────────────────────────────────────────────────────────────
# 2. GRADER
# ──────────────────────────────────────────────────────────────────────────────
class TestGrader:
    """Reward components, step penalty and query execution of the grader."""

    def test_exact_match_first_step(self):
        from grader import grade

        gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.reward == pytest.approx(1.0)
        assert result.exact_match is True
        assert result.syntax_ok is True
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.step_penalty == 0.0

    def test_syntax_error_gives_zero(self):
        from grader import grade

        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="near 'SELCT': syntax error",
            step=1,
        )
        assert result.reward == 0.0
        assert result.syntax_ok is False

    def test_step_penalty_applied(self):
        from grader import grade

        gt = [{"n": 1}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=3,  # penalty = (3 - 1) * 0.05 = 0.10
        )
        assert result.reward == pytest.approx(1.0 - 0.10)
        assert result.step_penalty == pytest.approx(0.10)

    def test_columns_wrong_zero_higher_components(self):
        from grader import grade

        gt = [{"name": "Alice", "score": 10}]
        actual = [{"user": "Alice", "points": 10}]  # wrong column names
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is False
        assert result.exact_match is False
        # Only the syntax component is earned: 0.10
        assert result.reward == pytest.approx(0.10)

    def test_correct_columns_wrong_rows(self):
        from grader import grade

        gt = [{"name": "Alice"}, {"name": "Bob"}]
        actual = [{"name": "Charlie"}, {"name": "Dave"}]
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.exact_match is False
        # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50
        assert result.reward == pytest.approx(0.50)

    def test_order_sensitive_wrong_order_is_not_exact(self):
        from grader import grade

        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # reversed
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=True,
        )
        assert result.exact_match is False

    def test_order_insensitive_accepts_different_row_order(self):
        from grader import grade

        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # different order but same content
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.exact_match is True

    def test_penalty_never_makes_reward_negative(self):
        from grader import grade

        # Step 99 with a syntax error — the reward must still be >= 0.
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="some error",
            step=99,
        )
        assert result.reward >= 0.0

    def test_execute_query_blocks_writes(self, db_conn):
        from grader import execute_query

        rows, err = execute_query(db_conn, "INSERT INTO categories(name) VALUES ('x')")
        assert rows is None
        # Check err first so a missing message fails as an assertion instead
        # of raising AttributeError on None.lower().
        assert err is not None
        assert "not allowed" in err.lower() or "INSERT" in err

    def test_execute_query_returns_rows(self, db_conn):
        from grader import execute_query

        rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id")
        assert err is None
        assert len(rows) == 8
        assert "id" in rows[0]
        assert "name" in rows[0]

    def test_compute_ground_truth(self, db_conn):
        from grader import compute_ground_truth

        rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers")
        assert len(rows) == 1
        assert rows[0]["n"] == 150
# ──────────────────────────────────────────────────────────────────────────────
# 3. TASK REGISTRY
# ──────────────────────────────────────────────────────────────────────────────
class TestTasks:
    """Registry contents, example validity and round-robin example selection."""

    # Every task the registry must expose; drives the parametrized tests below.
    TASK_NAMES = ("simple-filter", "join-aggregation", "analytics-window")

    def test_all_tasks_registered(self):
        from tasks import all_task_names

        names = all_task_names()
        for name in self.TASK_NAMES:
            assert name in names, f"Task {name!r} is not registered"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_has_examples(self, task_name):
        from tasks import get_task

        task = get_task(task_name)
        assert len(task.examples) >= 3, f"{task_name} needs at least 3 examples"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_sql_runs_on_real_db(self, task_name, db_conn):
        """Every ground-truth SQL must execute cleanly against the seeded DB."""
        from grader import execute_query
        from tasks import get_task

        task = get_task(task_name)
        for ex in task.examples:
            rows, error = execute_query(db_conn, ex.sql)
            assert error is None, (
                f"Task {task_name!r} SQL failed:\n{ex.sql}\nError: {error}"
            )
            assert rows is not None

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_task_roundrobin(self, task_name):
        from tasks import get_task

        task = get_task(task_name)
        n = len(task.examples)
        seen = [task.next_example() for _ in range(n * 2)]
        # After n calls the sequence should wrap, so the second half repeats
        # the first half regardless of where the cursor started.
        assert seen[:n] == seen[n:]

    def test_schema_context_non_empty(self):
        from tasks import get_task

        task = get_task("simple-filter")
        ctx = task.schema_context()
        assert "customers" in ctx
        assert "orders" in ctx
        assert "products" in ctx
# ──────────────────────────────────────────────────────────────────────────────
# 4. ENVIRONMENT
# ──────────────────────────────────────────────────────────────────────────────
class TestEnvironment:
    """reset/step behaviour, episode termination and reward bounds."""

    # Tasks whose ground-truth SQL must be solvable in one step.
    TASK_NAMES = ("simple-filter", "join-aggregation", "analytics-window")

    def test_reset_returns_observation(self, fresh_env):
        obs = fresh_env.reset(task_name="simple-filter")
        assert obs.question != ""
        assert obs.schema_context != ""
        assert obs.task_name == "simple-filter"
        assert obs.done is False
        assert obs.step == 0
        assert obs.reward is None

    def test_reset_state(self, fresh_env):
        fresh_env.reset(task_name="join-aggregation")
        state = fresh_env.state
        assert state.task_name == "join-aggregation"
        assert state.task_difficulty == "medium"
        assert state.step_count == 0
        assert state.solved is False

    def test_step_increments_step_count(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert fresh_env.state.step_count == 1

    def test_step_syntax_error_gives_nonzero_error(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(query="SELCT * FORM broken"))
        assert obs.last_error is not None
        assert obs.reward == 0.0

    def test_step_valid_query_returns_result(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="SELECT id, name FROM customers ORDER BY name LIMIT 5"
        ))
        assert obs.last_error is None
        assert len(obs.last_result) <= 5
        assert obs.reward >= 0.0

    def test_exact_match_ends_episode(self, fresh_env):
        """Submitting the exact ground-truth SQL should solve the episode."""
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        # Reach into the environment for the ground-truth SQL of this episode.
        gt_sql = fresh_env._example.sql
        obs = fresh_env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert fresh_env.state.solved is True
        assert obs.reward == pytest.approx(1.0)  # step 1, full score

    def test_max_steps_ends_episode(self, fresh_env):
        """Exhausting all steps should end the episode even without solving."""
        from environment import MAX_STEPS
        from models import NL2SQLAction

        fresh_env.reset(task_name="analytics-window")
        obs = None
        for _ in range(MAX_STEPS):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert obs is not None
        assert obs.done is True

    def test_reset_clears_previous_episode(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        # A second reset must discard the previous episode's state.
        obs = fresh_env.reset(task_name="join-aggregation")
        assert fresh_env.state.step_count == 0
        assert obs.step == 0
        assert obs.task_name == "join-aggregation"

    @pytest.mark.parametrize("task_name", TASK_NAMES)
    def test_all_tasks_solvable(self, task_name):
        """Ground-truth SQL should always produce reward == 1.0 on step 1."""
        from environment import NL2SQLEnvironment
        from models import NL2SQLAction

        env = NL2SQLEnvironment()
        env.reset(task_name=task_name)
        gt_sql = env._example.sql
        obs = env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert obs.reward == pytest.approx(1.0), (
            f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n"
            f"SQL: {gt_sql}\nError: {obs.last_error}\nReward: {obs.reward}"
        )

    def test_score_normalised_to_0_1(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        for _ in range(3):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x"))
            assert 0.0 <= obs.score <= 1.0

    def test_write_query_blocked(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="INSERT INTO categories(name) VALUES ('hack')"
        ))
        assert obs.last_error is not None
        assert "not allowed" in obs.last_error.lower() or "INSERT" in obs.last_error
# ──────────────────────────────────────────────────────────────────────────────
# 5. LOG FORMAT COMPLIANCE
# ──────────────────────────────────────────────────────────────────────────────
class TestLogFormat:
    """Validate that the inference.py log helpers emit correct format."""

    START_RE = re.compile(
        r"^\[START\] task=\S+ env=\S+ model=\S+$"
    )
    STEP_RE = re.compile(
        r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} "
        r"done=(true|false) error=.+$"
    )
    END_RE = re.compile(
        r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} "
        r"rewards=[\d.,]+$"
    )

    def _capture(self, func, *args, **kwargs) -> str:
        """Call ``func(*args, **kwargs)`` and return its stripped stdout output."""
        import io
        from contextlib import redirect_stdout

        buf = io.StringIO()
        with redirect_stdout(buf):
            func(*args, **kwargs)
        return buf.getvalue().strip()

    def test_log_start_format(self):
        # ROOT is already on sys.path (module-level setup), so no extra
        # path manipulation is needed to import inference.
        from inference import log_start

        out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B")
        assert self.START_RE.match(out), f"Bad [START] format: {out!r}"

    def test_log_step_format_null_error(self):
        from inference import log_step

        out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None)
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_step_format_with_error(self):
        from inference import log_step

        out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error")
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_end_format_success(self):
        from inference import log_end

        out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_log_end_format_failure(self):
        from inference import log_end

        out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_reward_two_decimal_places(self):
        from inference import log_step

        out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None)
        # The reward= field must have exactly 2 decimal places.
        match = re.search(r"reward=(\d+\.\d+)", out)
        assert match, "No reward= field found"
        assert len(match.group(1).split(".")[1]) == 2

    def test_score_three_decimal_places(self):
        from inference import log_end

        out = self._capture(log_end, True, 1, 1.0, [1.0])
        # The score= field must have exactly 3 decimal places.
        match = re.search(r"score=(\d+\.\d+)", out)
        assert match
        assert len(match.group(1).split(".")[1]) == 3