"""Tests for synthetic database schema introspection utilities.""" from __future__ import annotations import json import sqlite3 from pathlib import Path import pytest from sql_env.server.synthetic.generate import ( VariantResult, generate_variant, generate_variants_for_question, ) from sql_env.server.synthetic.__main__ import main as synthetic_cli_main from sql_env.server.synthetic.mutations import ( MutationResult, TableSchema, detect_bridge_tables, duplicate_bridge_rows, get_table_schemas, inject_irrelevant_rows, remap_ids, ) from sql_env.server.synthetic.validate import validate_gold_sql def _sqlite_table_definitions(db_path: Path) -> list[tuple[str, str]]: with sqlite3.connect(db_path) as connection: cursor = connection.cursor() cursor.execute( "SELECT name, sql FROM sqlite_master " "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name" ) return [(str(row[0]), str(row[1])) for row in cursor.fetchall()] def _find_real_spider_case() -> tuple[Path, str, str] | None: repo_root = Path(__file__).resolve().parents[1] question_files = [ repo_root / "data" / "questions" / "questions_train.json", repo_root / "data" / "questions" / "questions_eval.json", ] for question_file in question_files: if not question_file.exists(): continue questions = json.loads(question_file.read_text(encoding="utf-8")) for question in questions: db_name = question.get("database_name") gold_sql = question.get("gold_sql") if not isinstance(db_name, str) or not isinstance(gold_sql, str): continue db_path = repo_root / "data" / "databases" / db_name / f"{db_name}.sqlite" if not db_path.exists(): continue try: is_valid, _ = validate_gold_sql(str(db_path), gold_sql) except sqlite3.OperationalError: continue if is_valid: return db_path, gold_sql, db_name return None @pytest.fixture def sample_db_path(tmp_path): db_path = tmp_path / "sample.sqlite" with sqlite3.connect(db_path) as connection: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys = ON") cursor.execute( "CREATE TABLE departments (id INTEGER PRIMARY KEY, name TEXT NOT NULL)" ) cursor.execute( "CREATE TABLE employees (" "id INTEGER PRIMARY KEY," "name TEXT NOT NULL," "department_id INTEGER," "FOREIGN KEY(department_id) REFERENCES departments(id)" ")" ) cursor.execute( "CREATE TABLE students (id INTEGER PRIMARY KEY, name TEXT NOT NULL)" ) cursor.execute( "CREATE TABLE courses (id INTEGER PRIMARY KEY, title TEXT NOT NULL)" ) cursor.execute( "CREATE TABLE enrollments (" "student_id INTEGER," "course_id INTEGER," "PRIMARY KEY(student_id, course_id)," "FOREIGN KEY(student_id) REFERENCES students(id)," "FOREIGN KEY(course_id) REFERENCES courses(id)" ")" ) cursor.execute("CREATE TABLE audit_log (event TEXT, created_at TEXT)") cursor.execute( "INSERT INTO departments (id, name) VALUES (1, 'Engineering'), (2, 'Sales')" ) cursor.execute( "INSERT INTO employees (id, name, department_id) VALUES " "(1, 'Alice', 1), (2, 'Bob', 2)" ) cursor.execute( "INSERT INTO students (id, name) VALUES (1, 'Sam'), (2, 'Riley')" ) cursor.execute( "INSERT INTO courses (id, title) VALUES (1, 'Math'), (2, 'Science')" ) cursor.execute( "INSERT INTO enrollments (student_id, course_id) VALUES (1, 1), (2, 2)" ) cursor.execute( "INSERT INTO audit_log (event, created_at) " "VALUES ('seed', '2026-03-27T00:00:00Z')" ) connection.commit() return str(db_path) def test_table_schema_dataclass_fields(): schema = TableSchema( name="enrollments", columns=["student_id", "course_id"], pk_columns=["student_id", "course_id"], fk_columns=[("student_id", "students", "id")], ) assert schema.name == "enrollments" assert schema.columns == ["student_id", "course_id"] assert schema.pk_columns == ["student_id", "course_id"] assert schema.fk_columns == [("student_id", "students", "id")] def test_mutation_result_dataclass_fields(): result = MutationResult( mutation_name="inject_irrelevant_rows", tables_affected=["employees"], rows_added=5, success=True, ) assert result.mutation_name == "inject_irrelevant_rows" assert result.tables_affected == ["employees"] assert result.rows_added == 5 assert result.success is True def test_get_table_schemas_multi_table_with_fk_and_composite_pk(sample_db_path): schemas = get_table_schemas(sample_db_path) by_name = {schema.name: schema for schema in schemas} assert set(by_name) == { "audit_log", "courses", "departments", "employees", "enrollments", "students", } assert by_name["departments"].pk_columns == ["id"] assert by_name["employees"].fk_columns == [("department_id", "departments", "id")] assert by_name["enrollments"].pk_columns == ["student_id", "course_id"] assert set(by_name["enrollments"].fk_columns) == { ("student_id", "students", "id"), ("course_id", "courses", "id"), } def test_get_table_schemas_no_pk_table(sample_db_path): schemas = get_table_schemas(sample_db_path) by_name = {schema.name: schema for schema in schemas} assert by_name["audit_log"].pk_columns == [] def test_get_table_schemas_empty_db(tmp_path): db_path = tmp_path / "empty.sqlite" sqlite3.connect(db_path).close() assert get_table_schemas(str(db_path)) == [] def test_get_table_schemas_nonexistent_db_raises_operational_error(tmp_path): missing_path = tmp_path / "missing.sqlite" with pytest.raises(sqlite3.OperationalError): get_table_schemas(str(missing_path)) def test_detect_bridge_tables_identifies_tables_with_two_or_more_fks(sample_db_path): schemas = get_table_schemas(sample_db_path) assert detect_bridge_tables(schemas) == ["enrollments"] def test_detect_bridge_tables_empty_when_no_bridge_tables(): schemas = [ TableSchema( name="employees", columns=["id", "department_id"], pk_columns=["id"], fk_columns=[("department_id", "departments", "id")], ), TableSchema( name="departments", columns=["id", "name"], pk_columns=["id"], fk_columns=[], ), ] assert detect_bridge_tables(schemas) == [] def test_inject_irrelevant_rows_adds_rows_with_new_primary_keys(sample_db_path): schemas = get_table_schemas(sample_db_path) result = inject_irrelevant_rows(sample_db_path, schemas, n_rows=2) assert result.mutation_name == "inject_irrelevant_rows" assert result.success is True assert result.rows_added == 8 assert result.tables_affected == ["courses", "departments", "employees", "students"] with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*) FROM employees") assert cursor.fetchone()[0] == 4 cursor.execute("SELECT MIN(id), MAX(id) FROM employees") assert cursor.fetchone() == (1, 4) def test_inject_irrelevant_rows_preserves_existing_rows(sample_db_path): schemas = get_table_schemas(sample_db_path) inject_irrelevant_rows(sample_db_path, schemas, n_rows=1) with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT name FROM employees ORDER BY id") names = [row[0] for row in cursor.fetchall()] assert names[0:2] == ["Alice", "Bob"] def test_inject_irrelevant_rows_zero_rows_no_change(sample_db_path): schemas = get_table_schemas(sample_db_path) result = inject_irrelevant_rows(sample_db_path, schemas, n_rows=0) assert result.rows_added == 0 assert result.tables_affected == [] assert result.success is True with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*) FROM employees") assert cursor.fetchone()[0] == 2 def test_remap_ids_basic_changes_integer_primary_keys(sample_db_path): schemas = get_table_schemas(sample_db_path) with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT id FROM departments ORDER BY id") before_department_ids = [row[0] for row in cursor.fetchall()] result = remap_ids(sample_db_path, schemas) assert result.mutation_name == "remap_ids" assert result.success is True assert "departments" in result.tables_affected assert "employees" in result.tables_affected assert result.rows_added >= 2 with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT id FROM departments ORDER BY id") after_department_ids = [row[0] for row in cursor.fetchall()] assert after_department_ids != before_department_ids assert len(after_department_ids) == len(before_department_ids) def test_remap_ids_updates_foreign_keys_and_preserves_join(sample_db_path): schemas = get_table_schemas(sample_db_path) remap_ids(sample_db_path, schemas) with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("PRAGMA foreign_key_check") assert cursor.fetchall() == [] cursor.execute( "SELECT e.name, d.name FROM employees e " "JOIN departments d ON e.department_id = d.id " "ORDER BY e.name" ) joined = cursor.fetchall() assert joined == [("Alice", "Engineering"), ("Bob", "Sales")] def test_remap_ids_is_bijective_and_preserves_row_counts(sample_db_path): schemas = get_table_schemas(sample_db_path) with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*), COUNT(DISTINCT id) FROM departments") before_counts = cursor.fetchone() remap_ids(sample_db_path, schemas) with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*), COUNT(DISTINCT id) FROM departments") after_counts = cursor.fetchone() assert after_counts == before_counts def test_remap_ids_skips_tables_without_integer_primary_key(tmp_path): db_path = tmp_path / "text_pk.sqlite" with sqlite3.connect(db_path) as connection: cursor = connection.cursor() cursor.execute("CREATE TABLE labels (name TEXT PRIMARY KEY, value TEXT)") cursor.execute("INSERT INTO labels (name, value) VALUES ('alpha', '1')") connection.commit() schemas = get_table_schemas(str(db_path)) result = remap_ids(str(db_path), schemas) assert result.success is True assert result.rows_added == 0 assert result.tables_affected == [] def test_duplicate_bridge_rows_adds_rows_for_bridge_tables(tmp_path): db_path = tmp_path / "bridge.sqlite" with sqlite3.connect(db_path) as connection: cursor = connection.cursor() cursor.execute("PRAGMA foreign_keys = ON") cursor.execute("CREATE TABLE students (id INTEGER PRIMARY KEY, name TEXT)") cursor.execute("CREATE TABLE clubs (id INTEGER PRIMARY KEY, name TEXT)") cursor.execute( "CREATE TABLE club_memberships (" "student_id INTEGER," "club_id INTEGER," "FOREIGN KEY(student_id) REFERENCES students(id)," "FOREIGN KEY(club_id) REFERENCES clubs(id)" ")" ) cursor.execute( "INSERT INTO students (id, name) VALUES (1, 'Sam'), (2, 'Riley')" ) cursor.execute( "INSERT INTO clubs (id, name) VALUES (1, 'Chess'), (2, 'Robotics')" ) cursor.execute( "INSERT INTO club_memberships (student_id, club_id) VALUES (1, 1), (2, 2)" ) connection.commit() schemas = get_table_schemas(str(db_path)) bridge_tables = detect_bridge_tables(schemas) result = duplicate_bridge_rows(str(db_path), schemas, bridge_tables) assert result.mutation_name == "duplicate_bridge_rows" assert result.success is True assert result.rows_added == 2 assert result.tables_affected == ["club_memberships"] with sqlite3.connect(db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*) FROM club_memberships") assert cursor.fetchone()[0] == 4 def test_duplicate_bridge_rows_empty_bridge_tables_returns_noop(sample_db_path): schemas = get_table_schemas(sample_db_path) result = duplicate_bridge_rows(sample_db_path, schemas, []) assert result.success is True assert result.rows_added == 0 assert result.tables_affected == [] def test_duplicate_bridge_rows_skips_rows_blocked_by_unique_constraints(sample_db_path): schemas = get_table_schemas(sample_db_path) result = duplicate_bridge_rows(sample_db_path, schemas, ["enrollments"]) assert result.success is True assert result.rows_added == 0 assert result.tables_affected == [] with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*) FROM enrollments") assert cursor.fetchone()[0] == 2 def test_duplicate_bridge_rows_ignores_nonexistent_tables(sample_db_path): schemas = get_table_schemas(sample_db_path) result = duplicate_bridge_rows(sample_db_path, schemas, ["missing_bridge"]) assert result.success is True assert result.rows_added == 0 assert result.tables_affected == [] def test_validate_gold_sql_valid_returns_serialized_result(sample_db_path): is_valid, result = validate_gold_sql( sample_db_path, "SELECT name FROM employees ORDER BY id", ) assert is_valid is True assert result == "[('Alice',), ('Bob',)]" def test_validate_gold_sql_empty_result_returns_false(sample_db_path): is_valid, result = validate_gold_sql( sample_db_path, "SELECT name FROM employees WHERE id = -1", ) assert is_valid is False assert result is None def test_validate_gold_sql_invalid_query_raises_operational_error(sample_db_path): with pytest.raises(sqlite3.OperationalError): validate_gold_sql(sample_db_path, "SELECT * FROM definitely_missing_table") def test_validate_gold_sql_honors_custom_timeout(sample_db_path): is_valid, result = validate_gold_sql( sample_db_path, "SELECT COUNT(*) FROM employees", timeout=0.001, ) assert is_valid is True assert result == "[(2,)]" def test_variant_result_dataclass_fields(sample_db_path): mutation = MutationResult( mutation_name="inject_irrelevant_rows", tables_affected=["employees"], rows_added=2, success=True, ) result = VariantResult( variant_path="/tmp/sample_variant.sqlite", original_path=sample_db_path, mutations_applied=[mutation], gold_sql_valid=True, gold_answer="[(1,)]", ) assert result.variant_path.endswith("sample_variant.sqlite") assert result.original_path == sample_db_path assert result.mutations_applied == [mutation] assert result.gold_sql_valid is True assert result.gold_answer == "[(1,)]" def test_generate_variant_default_applies_all_mutations_and_creates_file( sample_db_path, tmp_path ): output_dir = tmp_path / "variants" result = generate_variant( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), ) assert result.gold_sql_valid is True assert len(result.mutations_applied) == 3 assert [m.mutation_name for m in result.mutations_applied] == [ "inject_irrelevant_rows", "remap_ids", "duplicate_bridge_rows", ] assert (output_dir / "sample_variant_0.sqlite").exists() def test_generate_variant_single_mutation(sample_db_path, tmp_path): output_dir = tmp_path / "variants" result = generate_variant( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), mutations=["inject_irrelevant_rows"], ) assert result.gold_sql_valid is True assert [m.mutation_name for m in result.mutations_applied] == [ "inject_irrelevant_rows" ] def test_generate_variant_does_not_modify_original_db(sample_db_path, tmp_path): output_dir = tmp_path / "variants" with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*) FROM employees") before_count = cursor.fetchone()[0] generate_variant( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), ) with sqlite3.connect(sample_db_path) as connection: cursor = connection.cursor() cursor.execute("SELECT COUNT(*) FROM employees") after_count = cursor.fetchone()[0] assert before_count == after_count == 2 def test_generate_variant_raises_on_missing_db(tmp_path): with pytest.raises(FileNotFoundError): generate_variant( db_path=str(tmp_path / "missing.sqlite"), gold_sql="SELECT 1", output_dir=str(tmp_path / "variants"), ) def test_generate_variant_raises_on_unknown_mutation(sample_db_path, tmp_path): with pytest.raises(ValueError, match="Unknown mutation"): generate_variant( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(tmp_path / "variants"), mutations=["unknown_mutation"], ) def test_generate_variant_invalid_gold_sql_discards_variant(sample_db_path, tmp_path): output_dir = tmp_path / "variants" result = generate_variant( db_path=sample_db_path, gold_sql="SELECT name FROM employees WHERE id = 1", output_dir=str(output_dir), ) assert result.gold_sql_valid is False assert result.gold_answer is None assert not (output_dir / "sample_variant_0.sqlite").exists() def test_generate_variant_uses_variant_id_in_filename(sample_db_path, tmp_path): output_dir = tmp_path / "variants" result = generate_variant( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), variant_id=7, ) assert result.variant_path.endswith("sample_variant_7.sqlite") def test_generate_variants_for_question_default_count(sample_db_path, tmp_path): output_dir = tmp_path / "variants" results = generate_variants_for_question( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), n_variants=2, ) assert len(results) == 2 assert all(result.gold_sql_valid for result in results) def test_generate_variants_for_question_zero_returns_empty_list( sample_db_path, tmp_path ): output_dir = tmp_path / "variants" results = generate_variants_for_question( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), n_variants=0, ) assert results == [] def test_generate_variants_for_question_returns_unique_paths(sample_db_path, tmp_path): output_dir = tmp_path / "variants" results = generate_variants_for_question( db_path=sample_db_path, gold_sql="SELECT name FROM employees", output_dir=str(output_dir), n_variants=3, ) paths = [result.variant_path for result in results] assert len(paths) == len(set(paths)) def test_synthetic_cli_smoke_generates_variants(sample_db_path, tmp_path, capsys): output_dir = tmp_path / "variants_cli" exit_code = synthetic_cli_main( [ "--db-path", sample_db_path, "--gold-sql", "SELECT name FROM employees", "--output-dir", str(output_dir), "--n-variants", "2", ] ) captured = capsys.readouterr() assert exit_code == 0 assert "Generated 2 valid variant(s)" in captured.out assert (output_dir / "sample_variant_0.sqlite").exists() assert (output_dir / "sample_variant_1.sqlite").exists() @pytest.mark.slow def test_generate_variants_integration_with_real_spider_database(tmp_path): case = _find_real_spider_case() if case is None: pytest.skip( "No local Spider DB + valid gold SQL pair found. " "Run: uv run python scripts/download_spider_databases.py" ) db_path, gold_sql, db_name = case output_dir = tmp_path / "spider_variants" / db_name results = generate_variants_for_question( db_path=str(db_path), gold_sql=gold_sql, output_dir=str(output_dir), n_variants=2, ) assert len(results) >= 1 variant = results[0] variant_path = Path(variant.variant_path) assert variant_path.exists() assert variant.gold_sql_valid is True assert any(mutation.rows_added > 0 for mutation in variant.mutations_applied) original_schema = _sqlite_table_definitions(db_path) variant_schema = _sqlite_table_definitions(variant_path) assert variant_schema == original_schema