# nl2sql-bench / tests /test_all.py
# ritvik360's picture
# Upload folder using huggingface_hub
# a39d8ef verified
"""
nl2sql-bench/tests/test_all.py
================================
Comprehensive test suite covering:
- Database seeder (determinism + row counts)
- Grader (all reward components, step penalty, edge cases)
- Task registry (all 3 tasks load and produce valid examples)
- Environment (reset, step, episode boundary, done logic)
- Inference log format (regex checks on START / STEP / END)
Run with:
    pytest tests/ -v
or from project root:
    PYTHONPATH=.:server pytest tests/ -v
"""
from __future__ import annotations
import re
import sqlite3
import sys
import os
from pathlib import Path
import pytest
# ── Path setup so tests can import from both project root and server/ ──────
# ROOT is two directory levels above this file; SERVER is its "server" child.
ROOT = Path(__file__).parent.parent
SERVER = ROOT.joinpath("server")
# Prepend both directories in one step. SERVER lands ahead of ROOT, which is
# the same order two successive insert(0, ...) calls would produce.
sys.path[:0] = [str(SERVER), str(ROOT)]
# ── Fixtures ───────────────────────────────────────────────────────────────
@pytest.fixture(scope="session")
def db_conn():
    """Shared in-memory SQLite connection with full schema + seed data.

    The connection is created once per test session, loaded with
    ``server/db/schema.sql`` and populated through ``seed_database``.
    Rows are returned as ``sqlite3.Row`` objects.

    Yields:
        The open ``sqlite3.Connection``; it is closed on teardown.
    """
    from db.seed import seed_database

    schema = (SERVER / "db" / "schema.sql").read_text()
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    # try/finally guarantees the connection is released even when the schema
    # script or the seeder raises before the fixture ever yields.
    try:
        conn.row_factory = sqlite3.Row
        conn.executescript(schema)
        seed_database(conn)
        yield conn
    finally:
        conn.close()
@pytest.fixture
def fresh_env():
    """Build a brand-new ``NL2SQLEnvironment`` for every requesting test."""
    from environment import NL2SQLEnvironment as _EnvironmentCls

    env = _EnvironmentCls()
    return env
# ══════════════════════════════════════════════════════════════════════════════
# 1. DATABASE SEEDER
# ══════════════════════════════════════════════════════════════════════════════
class TestSeeder:
    """Checks on the seeded database: row counts, value domains, integrity.

    Every test receives the session-wide ``db_conn`` fixture.
    """

    def test_categories_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone()
        assert row[0] == 8, "Should have exactly 8 categories"

    def test_products_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone()
        assert row[0] == 64, "Should have 8 products × 8 categories = 64"

    def test_customers_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone()
        assert row[0] == 150

    def test_orders_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone()
        assert row[0] > 100, "Should have a meaningful number of orders"

    def test_order_items_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone()
        assert row[0] > 200

    def test_reviews_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone()
        assert row[0] > 50

    def test_determinism(self, db_conn):
        """Seeding a second connection gives identical per-table row counts."""
        from db.seed import seed_database

        schema = (SERVER / "db" / "schema.sql").read_text()
        conn2 = sqlite3.connect(":memory:")
        # Close the scratch connection even when an assertion below fails.
        try:
            conn2.executescript(schema)
            seed_database(conn2)
            # Table names come from this fixed list, so formatting them into
            # the SQL text is safe.
            for tbl in ["categories", "products", "customers", "orders",
                        "order_items", "reviews"]:
                c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
                assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}"
        finally:
            conn2.close()

    def test_tiers_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')"
        ).fetchone()[0]
        assert bad == 0

    def test_statuses_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM orders "
            "WHERE status NOT IN ('pending','processing','shipped','delivered','cancelled')"
        ).fetchone()[0]
        assert bad == 0

    def test_ratings_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5"
        ).fetchone()[0]
        assert bad == 0

    def test_referential_integrity(self, db_conn):
        """Order items should reference valid orders and products."""
        # A LEFT JOIN row whose right-hand id is NULL has no matching parent.
        orphan_orders = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL"
        ).fetchone()[0]
        assert orphan_orders == 0
        orphan_products = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL"
        ).fetchone()[0]
        assert orphan_products == 0
# ══════════════════════════════════════════════════════════════════════════════
# 2. GRADER
# ══════════════════════════════════════════════════════════════════════════════
class TestGrader:
    """Tests for ``grader.grade``, ``execute_query`` and ``compute_ground_truth``."""

    def test_exact_match_first_step(self):
        from grader import grade

        gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.reward == pytest.approx(1.0)
        assert result.exact_match is True
        assert result.syntax_ok is True
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.step_penalty == 0.0

    def test_syntax_error_gives_zero(self):
        from grader import grade

        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="near 'SELCT': syntax error",
            step=1,
        )
        assert result.reward == 0.0
        assert result.syntax_ok is False

    def test_step_penalty_applied(self):
        from grader import grade

        gt = [{"n": 1}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=3,  # penalty = (3-1)*0.05 = 0.10
        )
        assert result.reward == pytest.approx(1.0 - 0.10)
        assert result.step_penalty == pytest.approx(0.10)

    def test_columns_wrong_zero_higher_components(self):
        from grader import grade

        gt = [{"name": "Alice", "score": 10}]
        actual = [{"user": "Alice", "points": 10}]  # wrong column names
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is False
        assert result.exact_match is False
        # Only syntax score: 0.10
        assert result.reward == pytest.approx(0.10)

    def test_correct_columns_wrong_rows(self):
        from grader import grade

        gt = [{"name": "Alice"}, {"name": "Bob"}]
        actual = [{"name": "Charlie"}, {"name": "Dave"}]
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.exact_match is False
        # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50
        assert result.reward == pytest.approx(0.50)

    def test_order_sensitive_wrong_order_is_not_exact(self):
        from grader import grade

        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # reversed
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=True,
        )
        assert result.exact_match is False

    def test_order_insensitive_accepts_different_row_order(self):
        from grader import grade

        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # different order but same content
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.exact_match is True

    def test_penalty_never_makes_reward_negative(self):
        from grader import grade

        # Step 99 with syntax error → reward must be >= 0
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="some error",
            step=99,
        )
        assert result.reward >= 0.0

    def test_execute_query_blocks_writes(self, db_conn):
        from grader import execute_query

        rows, err = execute_query(db_conn, "INSERT INTO categories(name) VALUES ('x')")
        assert rows is None
        # Fail with a readable message instead of an AttributeError on
        # ``None.lower()`` when the write was not rejected at all.
        assert err is not None, "write statement should be rejected with an error"
        assert "not allowed" in err.lower() or "INSERT" in err

    def test_execute_query_returns_rows(self, db_conn):
        from grader import execute_query

        rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id")
        assert err is None
        assert len(rows) == 8
        assert "id" in rows[0]
        assert "name" in rows[0]

    def test_compute_ground_truth(self, db_conn):
        from grader import compute_ground_truth

        rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers")
        assert len(rows) == 1
        assert rows[0]["n"] == 150
# ══════════════════════════════════════════════════════════════════════════════
# 3. TASK REGISTRY
# ══════════════════════════════════════════════════════════════════════════════
class TestTasks:
    """Task-registry checks: registration, examples, SQL validity, cycling."""

    # Task names shared by the registration check and the parametrised tests.
    _TASK_NAMES = ["simple-filter", "join-aggregation", "analytics-window"]

    def test_all_tasks_registered(self):
        from tasks import all_task_names

        registered = all_task_names()
        for expected in self._TASK_NAMES:
            assert expected in registered

    @pytest.mark.parametrize("task_name", _TASK_NAMES)
    def test_task_has_examples(self, task_name):
        from tasks import get_task

        examples = get_task(task_name).examples
        assert len(examples) >= 3, f"{task_name} needs at least 3 examples"

    @pytest.mark.parametrize("task_name", _TASK_NAMES)
    def test_task_sql_runs_on_real_db(self, task_name, db_conn):
        """Each example's ground-truth SQL must run without error on the seeded DB."""
        from tasks import get_task
        from grader import execute_query

        for example in get_task(task_name).examples:
            rows, error = execute_query(db_conn, example.sql)
            assert error is None, (
                f"Task {task_name!r} SQL failed:\n{example.sql}\nError: {error}"
            )
            assert rows is not None

    @pytest.mark.parametrize("task_name", _TASK_NAMES)
    def test_task_roundrobin(self, task_name):
        from tasks import get_task

        task = get_task(task_name)
        cycle_len = len(task.examples)
        # Two consecutive full passes must hand out the same sequence.
        first_pass = [task.next_example() for _ in range(cycle_len)]
        second_pass = [task.next_example() for _ in range(cycle_len)]
        assert first_pass == second_pass

    def test_schema_context_non_empty(self):
        from tasks import get_task

        context = get_task("simple-filter").schema_context()
        for table in ("customers", "orders", "products"):
            assert table in context
# ══════════════════════════════════════════════════════════════════════════════
# 4. ENVIRONMENT
# ══════════════════════════════════════════════════════════════════════════════
class TestEnvironment:
    """Tests for ``NL2SQLEnvironment``: reset, step, episode end, write blocking.

    Every test works on its own environment supplied by the ``fresh_env``
    fixture.
    """

    def test_reset_returns_observation(self, fresh_env):
        obs = fresh_env.reset(task_name="simple-filter")
        assert obs.question != ""
        assert obs.schema_context != ""
        assert obs.task_name == "simple-filter"
        assert obs.done is False
        assert obs.step == 0
        assert obs.reward is None

    def test_reset_state(self, fresh_env):
        fresh_env.reset(task_name="join-aggregation")
        state = fresh_env.state
        assert state.task_name == "join-aggregation"
        assert state.task_difficulty == "medium"
        assert state.step_count == 0
        assert state.solved is False

    def test_step_increments_step_count(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert fresh_env.state.step_count == 1

    def test_step_syntax_error_gives_nonzero_error(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(query="SELCT * FORM broken"))
        assert obs.last_error is not None
        assert obs.reward == 0.0

    def test_step_valid_query_returns_result(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="SELECT id, name FROM customers ORDER BY name LIMIT 5"
        ))
        assert obs.last_error is None
        assert len(obs.last_result) <= 5
        assert obs.reward >= 0.0

    def test_exact_match_ends_episode(self, fresh_env):
        """Submitting the exact ground-truth SQL should solve the episode."""
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        # Get the ground truth SQL from the internal example
        gt_sql = fresh_env._example.sql
        obs = fresh_env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert fresh_env.state.solved is True
        assert obs.reward == pytest.approx(1.0)  # step 1, full score

    def test_max_steps_ends_episode(self, fresh_env):
        """Exhausting all steps should end the episode even without solving."""
        from models import NL2SQLAction
        from environment import MAX_STEPS

        fresh_env.reset(task_name="analytics-window")
        obs = None
        for _ in range(MAX_STEPS):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1"))
        # obs stays None only if MAX_STEPS is 0, which would make the test vacuous.
        assert obs is not None
        assert obs.done is True

    def test_reset_clears_previous_episode(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        # Second reset should clear state
        obs = fresh_env.reset(task_name="join-aggregation")
        assert fresh_env.state.step_count == 0
        assert obs.step == 0
        assert obs.task_name == "join-aggregation"

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_all_tasks_solvable(self, fresh_env, task_name):
        """Ground-truth SQL should always produce reward == 1.0 on step 1."""
        from models import NL2SQLAction

        # Use the fresh_env fixture like the other tests in this class rather
        # than constructing a separate environment by hand.
        fresh_env.reset(task_name=task_name)
        gt_sql = fresh_env._example.sql
        obs = fresh_env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert obs.reward == pytest.approx(1.0), (
            f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n"
            f"SQL: {gt_sql}\nError: {obs.last_error}\nReward: {obs.reward}"
        )

    def test_score_normalised_to_0_1(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        for _ in range(3):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x"))
            # Checked after every step, not only the last one.
            assert 0.0 <= obs.score <= 1.0

    def test_write_query_blocked(self, fresh_env):
        from models import NL2SQLAction

        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="INSERT INTO categories(name) VALUES ('hack')"
        ))
        assert obs.last_error is not None
        assert "not allowed" in obs.last_error.lower() or "INSERT" in obs.last_error
# ══════════════════════════════════════════════════════════════════════════════
# 5. LOG FORMAT COMPLIANCE
# ══════════════════════════════════════════════════════════════════════════════
class TestLogFormat:
    """Validate that the inference.py log helpers emit correct format."""

    # Expected single-line shapes of the three log records.
    START_RE = re.compile(
        r"^\[START\] task=\S+ env=\S+ model=\S+$"
    )
    STEP_RE = re.compile(
        r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} "
        r"done=(true|false) error=.+$"
    )
    END_RE = re.compile(
        r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} "
        r"rewards=[\d.,]+$"
    )

    def _capture(self, func, *args, **kwargs) -> str:
        """Call ``func`` and return what it wrote to stdout, stripped.

        Args:
            func: Callable to invoke.
            *args: Positional arguments forwarded to ``func``.
            **kwargs: Keyword arguments forwarded to ``func``.

        Returns:
            The captured stdout text with surrounding whitespace removed.
        """
        import io
        from contextlib import redirect_stdout

        buf = io.StringIO()
        with redirect_stdout(buf):
            func(*args, **kwargs)
        return buf.getvalue().strip()

    def test_log_start_format(self):
        # ROOT is already placed on sys.path by the module-level path setup,
        # so no further sys.path mutation is needed here.
        from inference import log_start

        out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B")
        assert self.START_RE.match(out), f"Bad [START] format: {out!r}"

    def test_log_step_format_null_error(self):
        from inference import log_step

        out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None)
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_step_format_with_error(self):
        from inference import log_step

        out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error")
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_end_format_success(self):
        from inference import log_end

        out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_log_end_format_failure(self):
        from inference import log_end

        out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_reward_two_decimal_places(self):
        from inference import log_step

        out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None)
        # reward= field must have exactly 2 decimal places
        match = re.search(r"reward=(\d+\.\d+)", out)
        assert match, "No reward= field found"
        assert len(match.group(1).split(".")[1]) == 2

    def test_score_three_decimal_places(self):
        from inference import log_end

        out = self._capture(log_end, True, 1, 1.0, [1.0])
        # score= field must have exactly 3 decimal places
        match = re.search(r"score=(\d+\.\d+)", out)
        assert match
        assert len(match.group(1).split(".")[1]) == 3