# sql_env/tests/test_synthetic.py
# hjerpe's picture
# Upload folder using huggingface_hub
# 5dd1bb4 verified
"""Tests for synthetic database schema introspection utilities."""
from __future__ import annotations

import json
import sqlite3
from contextlib import closing
from pathlib import Path

import pytest

from sql_env.server.synthetic.generate import (
    VariantResult,
    generate_variant,
    generate_variants_for_question,
)
from sql_env.server.synthetic.__main__ import main as synthetic_cli_main
from sql_env.server.synthetic.mutations import (
    MutationResult,
    TableSchema,
    detect_bridge_tables,
    duplicate_bridge_rows,
    get_table_schemas,
    inject_irrelevant_rows,
    remap_ids,
)
from sql_env.server.synthetic.validate import validate_gold_sql
def _sqlite_table_definitions(db_path: Path) -> list[tuple[str, str]]:
    """Return ``(name, sql)`` for every user table in the SQLite file at *db_path*.

    Internal ``sqlite_%`` tables are excluded and rows are ordered by table
    name, so the result of two databases can be compared with ``==``.
    """
    query = (
        "SELECT name, sql FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    )
    # sqlite3's own context manager only commits or rolls back; ``closing`` is
    # what actually releases the database handle when the block exits.
    with closing(sqlite3.connect(db_path)) as connection:
        rows = connection.execute(query).fetchall()
    return [(str(name), str(sql)) for name, sql in rows]
def _find_real_spider_case() -> tuple[Path, str, str] | None:
    """Locate the first local question whose gold SQL validates on its database.

    The train question file is scanned before the eval one. Entries are skipped
    when their ``database_name``/``gold_sql`` are not strings, when the
    ``.sqlite`` file is absent, or when validation raises
    ``sqlite3.OperationalError``.

    Returns:
        ``(db_path, gold_sql, db_name)`` for the first valid pair, or ``None``
        when nothing usable is found.
    """
    root = Path(__file__).resolve().parents[1]
    questions_dir = root / "data" / "questions"
    databases_dir = root / "data" / "databases"

    for file_name in ("questions_train.json", "questions_eval.json"):
        source = questions_dir / file_name
        if not source.exists():
            continue

        for entry in json.loads(source.read_text(encoding="utf-8")):
            name = entry.get("database_name")
            sql = entry.get("gold_sql")
            if not (isinstance(name, str) and isinstance(sql, str)):
                continue

            candidate = databases_dir / name / f"{name}.sqlite"
            if not candidate.exists():
                continue

            try:
                valid, _ = validate_gold_sql(str(candidate), sql)
            except sqlite3.OperationalError:
                # Gold SQL that cannot run against this database is not usable.
                continue

            if valid:
                return candidate, sql, name

    return None
@pytest.fixture
def sample_db_path(tmp_path: Path) -> str:
    """Create a small seeded SQLite database and return its path as a string.

    Schema: ``departments`` and ``employees`` (one foreign key), ``students``,
    ``courses`` and ``enrollments`` (composite primary key plus two foreign
    keys), and ``audit_log`` (no primary key). Every table is seeded with one
    or two rows.
    """
    db_path = tmp_path / "sample.sqlite"
    statements = (
        "PRAGMA foreign_keys = ON",
        "CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT NOT NULL)",
        (
            "CREATE TABLE employees ("
            "id INTEGER PRIMARY KEY,"
            "name TEXT NOT NULL,"
            "department_id INTEGER,"
            "FOREIGN KEY(department_id) REFERENCES departments(id)"
            ")"
        ),
        "CREATE TABLE students (id INTEGER PRIMARY KEY, name TEXT NOT NULL)",
        "CREATE TABLE courses (id INTEGER PRIMARY KEY, title TEXT NOT NULL)",
        (
            "CREATE TABLE enrollments ("
            "student_id INTEGER,"
            "course_id INTEGER,"
            "PRIMARY KEY(student_id, course_id),"
            "FOREIGN KEY(student_id) REFERENCES students(id),"
            "FOREIGN KEY(course_id) REFERENCES courses(id)"
            ")"
        ),
        "CREATE TABLE audit_log (event TEXT, created_at TEXT)",
        "INSERT INTO departments (id, name) VALUES (1, 'Engineering'), (2, 'Sales')",
        (
            "INSERT INTO employees (id, name, department_id) VALUES "
            "(1, 'Alice', 1), (2, 'Bob', 2)"
        ),
        "INSERT INTO students (id, name) VALUES (1, 'Sam'), (2, 'Riley')",
        "INSERT INTO courses (id, title) VALUES (1, 'Math'), (2, 'Science')",
        "INSERT INTO enrollments (student_id, course_id) VALUES (1, 1), (2, 2)",
        (
            "INSERT INTO audit_log (event, created_at) "
            "VALUES ('seed', '2026-03-27T00:00:00Z')"
        ),
    )
    # ``closing`` releases the handle; sqlite3's own context manager would only
    # commit and leave the connection open.
    with closing(sqlite3.connect(db_path)) as connection:
        cursor = connection.cursor()
        for statement in statements:
            cursor.execute(statement)
        connection.commit()
    return str(db_path)
def test_table_schema_dataclass_fields():
    """TableSchema exposes exactly the values it was constructed with."""
    key_columns = ["student_id", "course_id"]
    foreign_keys = [("student_id", "students", "id")]

    table = TableSchema(
        name="enrollments",
        columns=list(key_columns),
        pk_columns=list(key_columns),
        fk_columns=list(foreign_keys),
    )

    assert table.name == "enrollments"
    assert table.columns == key_columns
    assert table.pk_columns == key_columns
    assert table.fk_columns == foreign_keys
def test_mutation_result_dataclass_fields():
    """MutationResult exposes exactly the values it was constructed with."""
    outcome = MutationResult(
        mutation_name="inject_irrelevant_rows",
        tables_affected=["employees"],
        rows_added=5,
        success=True,
    )

    assert outcome.mutation_name == "inject_irrelevant_rows"
    assert outcome.tables_affected == ["employees"]
    assert outcome.rows_added == 5
    assert outcome.success is True
def test_get_table_schemas_multi_table_with_fk_and_composite_pk(sample_db_path):
    """Introspection reports all tables, single/composite PKs and foreign keys."""
    tables = {table.name: table for table in get_table_schemas(sample_db_path)}
    expected_names = {
        "audit_log",
        "courses",
        "departments",
        "employees",
        "enrollments",
        "students",
    }

    assert set(tables) == expected_names
    assert tables["departments"].pk_columns == ["id"]
    assert tables["employees"].fk_columns == [("department_id", "departments", "id")]

    enrollments = tables["enrollments"]
    assert enrollments.pk_columns == ["student_id", "course_id"]
    # Foreign-key order is not asserted, hence the set comparison.
    assert set(enrollments.fk_columns) == {
        ("student_id", "students", "id"),
        ("course_id", "courses", "id"),
    }
def test_get_table_schemas_no_pk_table(sample_db_path):
    """A table declared without a primary key reports an empty PK list."""
    tables = {table.name: table for table in get_table_schemas(sample_db_path)}
    audit_log = tables["audit_log"]
    assert audit_log.pk_columns == []
def test_get_table_schemas_empty_db(tmp_path):
    """A database with no tables yields an empty schema list."""
    empty_db = tmp_path / "empty.sqlite"
    # Open and immediately close: no tables are ever created in this file.
    connection = sqlite3.connect(empty_db)
    connection.close()

    schemas = get_table_schemas(str(empty_db))
    assert schemas == []
def test_get_table_schemas_nonexistent_db_raises_operational_error(tmp_path):
    """Introspecting a path that does not exist raises OperationalError."""
    absent_db = str(tmp_path / "missing.sqlite")
    with pytest.raises(sqlite3.OperationalError):
        get_table_schemas(absent_db)
def test_detect_bridge_tables_identifies_tables_with_two_or_more_fks(sample_db_path):
    """Only ``enrollments`` (two foreign keys) is reported as a bridge table."""
    bridge_names = detect_bridge_tables(get_table_schemas(sample_db_path))
    assert bridge_names == ["enrollments"]
def test_detect_bridge_tables_empty_when_no_bridge_tables():
    """Tables with at most one foreign key are never reported as bridges."""
    employees = TableSchema(
        name="employees",
        columns=["id", "department_id"],
        pk_columns=["id"],
        fk_columns=[("department_id", "departments", "id")],
    )
    departments = TableSchema(
        name="departments",
        columns=["id", "name"],
        pk_columns=["id"],
        fk_columns=[],
    )

    bridge_names = detect_bridge_tables([employees, departments])
    assert bridge_names == []
def test_inject_irrelevant_rows_adds_rows_with_new_primary_keys(sample_db_path):
    """Injecting two rows per table adds eight rows with fresh primary keys.

    Expects four affected tables and ``employees`` growing from ids 1-2 to
    ids 1-4.
    """
    schemas = get_table_schemas(sample_db_path)
    result = inject_irrelevant_rows(sample_db_path, schemas, n_rows=2)

    assert result.mutation_name == "inject_irrelevant_rows"
    assert result.success is True
    assert result.rows_added == 8
    assert result.tables_affected == ["courses", "departments", "employees", "students"]

    # ``closing`` releases the handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("SELECT COUNT(*) FROM employees")
        assert cursor.fetchone()[0] == 4
        cursor.execute("SELECT MIN(id), MAX(id) FROM employees")
        assert cursor.fetchone() == (1, 4)
def test_inject_irrelevant_rows_preserves_existing_rows(sample_db_path):
    """Seeded employees are still the first rows by id after one injection."""
    schemas = get_table_schemas(sample_db_path)
    inject_irrelevant_rows(sample_db_path, schemas, n_rows=1)

    # ``closing`` releases the handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("SELECT name FROM employees ORDER BY id")
        names = [row[0] for row in cursor.fetchall()]

    assert names[0:2] == ["Alice", "Bob"]
def test_inject_irrelevant_rows_zero_rows_no_change(sample_db_path):
    """``n_rows=0`` succeeds, reports nothing added and leaves employees alone."""
    schemas = get_table_schemas(sample_db_path)
    result = inject_irrelevant_rows(sample_db_path, schemas, n_rows=0)

    assert result.rows_added == 0
    assert result.tables_affected == []
    assert result.success is True

    # ``closing`` releases the handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("SELECT COUNT(*) FROM employees")
        assert cursor.fetchone()[0] == 2
def test_remap_ids_basic_changes_integer_primary_keys(sample_db_path):
    """Remapping changes department ids while keeping the number of rows."""
    schemas = get_table_schemas(sample_db_path)
    id_query = "SELECT id FROM departments ORDER BY id"

    # ``closing`` releases each handle; sqlite3's context manager would leave
    # the connection open while remap_ids rewrites the same file.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(id_query)
        before_department_ids = [row[0] for row in cursor.fetchall()]

    result = remap_ids(sample_db_path, schemas)

    assert result.mutation_name == "remap_ids"
    assert result.success is True
    assert "departments" in result.tables_affected
    assert "employees" in result.tables_affected
    assert result.rows_added >= 2

    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(id_query)
        after_department_ids = [row[0] for row in cursor.fetchall()]

    assert after_department_ids != before_department_ids
    assert len(after_department_ids) == len(before_department_ids)
def test_remap_ids_updates_foreign_keys_and_preserves_join(sample_db_path):
    """After remapping, foreign keys are consistent and the join is unchanged."""
    schemas = get_table_schemas(sample_db_path)
    remap_ids(sample_db_path, schemas)

    # ``closing`` releases the handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        # An empty result from foreign_key_check means no dangling references.
        cursor.execute("PRAGMA foreign_key_check")
        assert cursor.fetchall() == []
        cursor.execute(
            "SELECT e.name, d.name FROM employees e "
            "JOIN departments d ON e.department_id = d.id "
            "ORDER BY e.name"
        )
        joined = cursor.fetchall()

    assert joined == [("Alice", "Engineering"), ("Bob", "Sales")]
def test_remap_ids_is_bijective_and_preserves_row_counts(sample_db_path):
    """Total and distinct department id counts are unchanged by remapping."""
    schemas = get_table_schemas(sample_db_path)
    count_query = "SELECT COUNT(*), COUNT(DISTINCT id) FROM departments"

    # ``closing`` releases each handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(count_query)
        before_counts = cursor.fetchone()

    remap_ids(sample_db_path, schemas)

    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(count_query)
        after_counts = cursor.fetchone()

    assert after_counts == before_counts
def test_remap_ids_skips_tables_without_integer_primary_key(tmp_path):
    """A table keyed by a TEXT primary key is left untouched by remap_ids."""
    db_path = tmp_path / "text_pk.sqlite"
    # ``closing`` releases the handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("CREATE TABLE labels (name TEXT PRIMARY KEY, value TEXT)")
        cursor.execute("INSERT INTO labels (name, value) VALUES ('alpha', '1')")
        connection.commit()

    schemas = get_table_schemas(str(db_path))
    result = remap_ids(str(db_path), schemas)

    assert result.success is True
    assert result.rows_added == 0
    assert result.tables_affected == []
def test_duplicate_bridge_rows_adds_rows_for_bridge_tables(tmp_path):
    """Rows of a bridge table without a uniqueness constraint are duplicated.

    ``club_memberships`` has two foreign keys and no primary key, so its two
    seeded rows are expected to become four.
    """
    db_path = tmp_path / "bridge.sqlite"
    # ``closing`` releases each handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("PRAGMA foreign_keys = ON")
        cursor.execute("CREATE TABLE students (id INTEGER PRIMARY KEY, name TEXT)")
        cursor.execute("CREATE TABLE clubs (id INTEGER PRIMARY KEY, name TEXT)")
        cursor.execute(
            "CREATE TABLE club_memberships ("
            "student_id INTEGER,"
            "club_id INTEGER,"
            "FOREIGN KEY(student_id) REFERENCES students(id),"
            "FOREIGN KEY(club_id) REFERENCES clubs(id)"
            ")"
        )
        cursor.execute(
            "INSERT INTO students (id, name) VALUES (1, 'Sam'), (2, 'Riley')"
        )
        cursor.execute(
            "INSERT INTO clubs (id, name) VALUES (1, 'Chess'), (2, 'Robotics')"
        )
        cursor.execute(
            "INSERT INTO club_memberships (student_id, club_id) VALUES (1, 1), (2, 2)"
        )
        connection.commit()

    schemas = get_table_schemas(str(db_path))
    bridge_tables = detect_bridge_tables(schemas)
    result = duplicate_bridge_rows(str(db_path), schemas, bridge_tables)

    assert result.mutation_name == "duplicate_bridge_rows"
    assert result.success is True
    assert result.rows_added == 2
    assert result.tables_affected == ["club_memberships"]

    with closing(sqlite3.connect(db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("SELECT COUNT(*) FROM club_memberships")
        assert cursor.fetchone()[0] == 4
def test_duplicate_bridge_rows_empty_bridge_tables_returns_noop(sample_db_path):
    """An empty bridge-table list yields a successful no-op result."""
    no_bridges = []
    outcome = duplicate_bridge_rows(
        sample_db_path, get_table_schemas(sample_db_path), no_bridges
    )

    assert outcome.success is True
    assert outcome.rows_added == 0
    assert outcome.tables_affected == []
def test_duplicate_bridge_rows_skips_rows_blocked_by_unique_constraints(sample_db_path):
    """Duplicating ``enrollments`` adds nothing and still reports success.

    The fixture declares a composite primary key on ``enrollments``, so no
    duplicate row is expected to be inserted.
    """
    schemas = get_table_schemas(sample_db_path)
    result = duplicate_bridge_rows(sample_db_path, schemas, ["enrollments"])

    assert result.success is True
    assert result.rows_added == 0
    assert result.tables_affected == []

    # ``closing`` releases the handle; sqlite3's context manager would not.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute("SELECT COUNT(*) FROM enrollments")
        assert cursor.fetchone()[0] == 2
def test_duplicate_bridge_rows_ignores_nonexistent_tables(sample_db_path):
    """A bridge-table name absent from the database results in a no-op."""
    unknown_tables = ["missing_bridge"]
    outcome = duplicate_bridge_rows(
        sample_db_path, get_table_schemas(sample_db_path), unknown_tables
    )

    assert outcome.success is True
    assert outcome.rows_added == 0
    assert outcome.tables_affected == []
def test_validate_gold_sql_valid_returns_serialized_result(sample_db_path):
    """A query with rows is valid and its result comes back as a string."""
    query = "SELECT name FROM employees ORDER BY id"
    ok, serialized = validate_gold_sql(sample_db_path, query)

    assert ok is True
    assert serialized == "[('Alice',), ('Bob',)]"
def test_validate_gold_sql_empty_result_returns_false(sample_db_path):
    """A query that matches no rows is reported invalid with a ``None`` answer."""
    query = "SELECT name FROM employees WHERE id = -1"
    ok, serialized = validate_gold_sql(sample_db_path, query)

    assert ok is False
    assert serialized is None
def test_validate_gold_sql_invalid_query_raises_operational_error(sample_db_path):
    """Querying a table that does not exist raises OperationalError."""
    bad_query = "SELECT * FROM definitely_missing_table"
    with pytest.raises(sqlite3.OperationalError):
        validate_gold_sql(sample_db_path, bad_query)
def test_validate_gold_sql_honors_custom_timeout(sample_db_path):
    """A trivial query still validates when a very small timeout is passed."""
    query = "SELECT COUNT(*) FROM employees"
    ok, serialized = validate_gold_sql(sample_db_path, query, timeout=0.001)

    assert ok is True
    assert serialized == "[(2,)]"
def test_variant_result_dataclass_fields(sample_db_path):
    """VariantResult exposes exactly the values it was constructed with."""
    applied = MutationResult(
        mutation_name="inject_irrelevant_rows",
        tables_affected=["employees"],
        rows_added=2,
        success=True,
    )

    variant = VariantResult(
        variant_path="/tmp/sample_variant.sqlite",
        original_path=sample_db_path,
        mutations_applied=[applied],
        gold_sql_valid=True,
        gold_answer="[(1,)]",
    )

    assert variant.variant_path.endswith("sample_variant.sqlite")
    assert variant.original_path == sample_db_path
    assert variant.mutations_applied == [applied]
    assert variant.gold_sql_valid is True
    assert variant.gold_answer == "[(1,)]"
def test_generate_variant_default_applies_all_mutations_and_creates_file(
    sample_db_path, tmp_path
):
    """With no mutation list, three mutations run in order and a file is written."""
    variants_dir = tmp_path / "variants"
    expected_order = [
        "inject_irrelevant_rows",
        "remap_ids",
        "duplicate_bridge_rows",
    ]

    variant = generate_variant(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(variants_dir),
    )

    assert variant.gold_sql_valid is True
    assert len(variant.mutations_applied) == 3
    applied_names = [step.mutation_name for step in variant.mutations_applied]
    assert applied_names == expected_order
    assert (variants_dir / "sample_variant_0.sqlite").exists()
def test_generate_variant_single_mutation(sample_db_path, tmp_path):
    """Only the explicitly requested mutation is applied."""
    variants_dir = tmp_path / "variants"
    requested = ["inject_irrelevant_rows"]

    variant = generate_variant(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(variants_dir),
        mutations=requested,
    )

    assert variant.gold_sql_valid is True
    applied_names = [step.mutation_name for step in variant.mutations_applied]
    assert applied_names == ["inject_irrelevant_rows"]
def test_generate_variant_does_not_modify_original_db(sample_db_path, tmp_path):
    """The source database keeps its two employees after a variant is generated."""
    output_dir = tmp_path / "variants"
    count_query = "SELECT COUNT(*) FROM employees"

    # ``closing`` releases each handle; sqlite3's context manager would leave
    # the source database open while generate_variant works from it.
    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(count_query)
        before_count = cursor.fetchone()[0]

    generate_variant(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(output_dir),
    )

    with closing(sqlite3.connect(sample_db_path)) as connection:
        cursor = connection.cursor()
        cursor.execute(count_query)
        after_count = cursor.fetchone()[0]

    assert before_count == after_count == 2
def test_generate_variant_raises_on_missing_db(tmp_path):
    """A source database path that does not exist raises FileNotFoundError."""
    absent_db = str(tmp_path / "missing.sqlite")
    variants_dir = str(tmp_path / "variants")

    with pytest.raises(FileNotFoundError):
        generate_variant(
            db_path=absent_db,
            gold_sql="SELECT 1",
            output_dir=variants_dir,
        )
def test_generate_variant_raises_on_unknown_mutation(sample_db_path, tmp_path):
    """An unrecognised mutation name raises ValueError matching 'Unknown mutation'."""
    variants_dir = str(tmp_path / "variants")
    requested = ["unknown_mutation"]

    with pytest.raises(ValueError, match="Unknown mutation"):
        generate_variant(
            db_path=sample_db_path,
            gold_sql="SELECT name FROM employees",
            output_dir=variants_dir,
            mutations=requested,
        )
def test_generate_variant_invalid_gold_sql_discards_variant(sample_db_path, tmp_path):
    """A variant whose gold SQL no longer validates is flagged and not kept."""
    variants_dir = tmp_path / "variants"
    # NOTE(review): presumably this query returns nothing on the variant because
    # the default mutations remap id 1 -- confirm against generate_variant.
    id_bound_query = "SELECT name FROM employees WHERE id = 1"

    variant = generate_variant(
        db_path=sample_db_path,
        gold_sql=id_bound_query,
        output_dir=str(variants_dir),
    )

    assert variant.gold_sql_valid is False
    assert variant.gold_answer is None
    assert not (variants_dir / "sample_variant_0.sqlite").exists()
def test_generate_variant_uses_variant_id_in_filename(sample_db_path, tmp_path):
    """The ``variant_id`` argument appears as the numeric suffix of the file name."""
    variants_dir = tmp_path / "variants"

    variant = generate_variant(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(variants_dir),
        variant_id=7,
    )

    produced_path = variant.variant_path
    assert produced_path.endswith("sample_variant_7.sqlite")
def test_generate_variants_for_question_default_count(sample_db_path, tmp_path):
    """Requesting two variants returns two results, each with valid gold SQL."""
    variants_dir = tmp_path / "variants"

    variants = generate_variants_for_question(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(variants_dir),
        n_variants=2,
    )

    assert len(variants) == 2
    assert all(variant.gold_sql_valid for variant in variants)
def test_generate_variants_for_question_zero_returns_empty_list(
    sample_db_path, tmp_path
):
    """Requesting zero variants returns an empty list."""
    variants_dir = tmp_path / "variants"

    variants = generate_variants_for_question(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(variants_dir),
        n_variants=0,
    )

    assert variants == []
def test_generate_variants_for_question_returns_unique_paths(sample_db_path, tmp_path):
    """Every returned variant points at a distinct file path."""
    variants_dir = tmp_path / "variants"

    variants = generate_variants_for_question(
        db_path=sample_db_path,
        gold_sql="SELECT name FROM employees",
        output_dir=str(variants_dir),
        n_variants=3,
    )

    variant_paths = [variant.variant_path for variant in variants]
    distinct_paths = set(variant_paths)
    assert len(variant_paths) == len(distinct_paths)
def test_synthetic_cli_smoke_generates_variants(sample_db_path, tmp_path, capsys):
    """The CLI exits 0, reports two valid variants and writes both files."""
    variants_dir = tmp_path / "variants_cli"
    cli_args = [
        "--db-path",
        sample_db_path,
        "--gold-sql",
        "SELECT name FROM employees",
        "--output-dir",
        str(variants_dir),
        "--n-variants",
        "2",
    ]

    status = synthetic_cli_main(cli_args)
    stdout = capsys.readouterr().out

    assert status == 0
    assert "Generated 2 valid variant(s)" in stdout
    assert (variants_dir / "sample_variant_0.sqlite").exists()
    assert (variants_dir / "sample_variant_1.sqlite").exists()
@pytest.mark.slow
def test_generate_variants_integration_with_real_spider_database(tmp_path):
    """Generate variants from a local Spider database and check they are sound.

    Skipped when no Spider database with a validating gold SQL is available.
    Checks the first variant: its file exists, its gold SQL is valid, at least
    one mutation added rows, and its table definitions match the original.
    """
    found = _find_real_spider_case()
    if found is None:
        pytest.skip(
            "No local Spider DB + valid gold SQL pair found. "
            "Run: uv run python scripts/download_spider_databases.py"
        )

    source_db, gold_sql, db_name = found
    variants_dir = tmp_path / "spider_variants" / db_name

    variants = generate_variants_for_question(
        db_path=str(source_db),
        gold_sql=gold_sql,
        output_dir=str(variants_dir),
        n_variants=2,
    )
    assert len(variants) >= 1

    first = variants[0]
    first_path = Path(first.variant_path)
    assert first_path.exists()
    assert first.gold_sql_valid is True
    assert any(step.rows_added > 0 for step in first.mutations_applied)

    # Mutations may change data, but the table definitions must be identical.
    source_tables = _sqlite_table_definitions(source_db)
    variant_tables = _sqlite_table_definitions(first_path)
    assert variant_tables == source_tables