""" data_factory/validator.py ========================== SQL execution validation layer. GUARANTEE: Every record that passes this validator has a SQL that: 1. Runs without error against the actual seeded SQLite schema 2. Returns at least one row (non-empty result) 3. Returns the expected column names No LLM-generated SQL ever reaches this validator — SQL always comes from the human-verified template library. This validator is an extra safety net to catch any copy-paste or formatting regressions. """ from __future__ import annotations import sqlite3 from dataclasses import dataclass, field from typing import Any, Optional from data_factory.schemas import build_connection, SCHEMA_CONTEXT from data_factory.templates import Template # ───────────────────────────────────────────────────────────────────────────── # DATA CLASSES # ───────────────────────────────────────────────────────────────────────────── @dataclass class ValidationResult: passed: bool sql: str error: Optional[str] = None row_count: int = 0 columns: list[str] = field(default_factory=list) @dataclass class DataRecord: """One training example ready to be written to JSONL/Parquet.""" domain: str difficulty: str sql: str nl_question: str # The NL paraphrase used as prompt persona: str # ceo | chatty | lazy_typist | non_techie | analyst | augmented has_order: bool schema_context: str row_count: int # From validation run columns: list[str] # From validation run source: str # "template_base" | "vllm_persona" | "rule_augmented" template_id: int # Index into ALL_TEMPLATES def to_training_dict(self) -> dict[str, Any]: """ Returns the dictionary that will be written to the output dataset. Format is compatible with TRL / HuggingFace `datasets`: prompt : chat-format messages list (system + user) sql : ground-truth SQL (label / reward reference) metadata: auxiliary fields for curriculum or filtering """ system_msg = ( "You are an expert SQL analyst. " "Write a single SELECT query that answers the question. " "Output ONLY the SQL query — no markdown, no explanation, no backticks." ) user_msg = ( f"DATABASE SCHEMA\n" f"---------------\n" f"{self.schema_context}\n\n" f"QUESTION: {self.nl_question}" ) return { "prompt": [ {"role": "system", "content": system_msg}, {"role": "user", "content": user_msg}, ], "sql": self.sql, "metadata": { "domain": self.domain, "difficulty": self.difficulty, "persona": self.persona, "has_order": self.has_order, "row_count": self.row_count, "columns": self.columns, "source": self.source, "template_id": self.template_id, }, } # ───────────────────────────────────────────────────────────────────────────── # VALIDATOR # ───────────────────────────────────────────────────────────────────────────── class SQLValidator: """ Validates SQL against a seeded in-memory SQLite connection. One validator per domain to reuse the same connection for all templates in that domain (performance optimization). """ def __init__(self, domain: str, seed: int = 42) -> None: self.domain = domain self._conn = build_connection(domain, seed=seed) def validate(self, sql: str) -> ValidationResult: """ Execute SQL and return a ValidationResult. Never raises — always returns a result object. """ sql = sql.strip().rstrip(";") if not sql: return ValidationResult(passed=False, sql=sql, error="Empty SQL string.") # Block any write operations first_word = sql.split()[0].lower() if sql.split() else "" forbidden = {"insert","update","delete","drop","alter","create","replace","truncate","pragma"} if first_word in forbidden: return ValidationResult( passed=False, sql=sql, error=f"Write operation '{first_word.upper()}' is not permitted." ) try: cur = self._conn.execute(sql) cols = [d[0] for d in cur.description] if cur.description else [] rows = cur.fetchall() return ValidationResult( passed=True, sql=sql, row_count=len(rows), columns=cols, ) except sqlite3.Error as exc: return ValidationResult(passed=False, sql=sql, error=str(exc)) def close(self) -> None: self._conn.close() def validate_template(template: Template, seed: int = 42) -> ValidationResult: """Convenience function: validate a single template.""" v = SQLValidator(template["domain"], seed=seed) result = v.validate(template["sql"]) v.close() return result def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]: """ Run validation across all templates. Returns a summary dict. Used during CI / smoke testing. """ from data_factory.schemas import SCHEMA_MAP validators = {domain: SQLValidator(domain, seed) for domain in SCHEMA_MAP} passed = [] failed = [] for i, t in enumerate(templates): v = validators[t["domain"]] result = v.validate(t["sql"]) if result.passed: passed.append(i) else: failed.append({"index": i, "domain": t["domain"], "sql": t["sql"][:80], "error": result.error}) for v in validators.values(): v.close() return { "total": len(templates), "passed": len(passed), "failed": len(failed), "failures": failed, } def build_record( template: Template, template_idx: int, nl_question: str, persona: str, source: str, validator: SQLValidator, ) -> Optional[DataRecord]: """ Validate the template SQL and, if it passes, build a DataRecord. Parameters ---------- template : The source template (contains SQL, domain, difficulty). template_idx : Index of template in ALL_TEMPLATES (for deduplication). nl_question : The NL paraphrase to use as the prompt. persona : Which persona/strategy generated this NL. source : 'template_base' | 'vllm_persona' | 'rule_augmented' validator : Pre-built SQLValidator for this domain. Returns None if validation fails. """ vr = validator.validate(template["sql"]) if not vr.passed: return None return DataRecord( domain=template["domain"], difficulty=template["difficulty"], sql=template["sql"], nl_question=nl_question, persona=persona, has_order=template["has_order"], schema_context=SCHEMA_CONTEXT[template["domain"]], row_count=vr.row_count, columns=vr.columns, source=source, template_id=template_idx, )