# nl2sql-bench / data_factory/run_data_factory.py
# Uploaded by ritvik360 via huggingface_hub (commit a39d8ef).
"""
run_data_factory.py
====================
Entry point and smoke-test runner for the NL2SQL Data Factory.
Run this FIRST before running the full pipeline to verify:
1. All 66 SQL templates execute without errors
2. Rule augmentation produces diverse NL variants
3. Validators correctly accept/reject queries
4. Base pipeline generates well-formed JSONL records
Usage:
# Smoke test only (fast, ~10 seconds)
python run_data_factory.py --smoke-test
# Base mode (no GPU, generates all rule-augmented records)
python run_data_factory.py --mode base
# Full mode (H100 required)
python run_data_factory.py --mode full --model meta-llama/Meta-Llama-3-70B-Instruct --tensor-parallel 4
# Preview what the dataset looks like
python run_data_factory.py --smoke-test --show-samples 3
"""
from __future__ import annotations
import argparse
import json
import sys
import textwrap
from pathlib import Path
# Make the ``data_factory`` package importable however this script is launched.
# When run as ``python data_factory/run_data_factory.py`` Python puts only the
# script's own directory on sys.path, but ``import data_factory`` needs the
# directory that *contains* the package, so add that as well.
_SCRIPT_DIR = Path(__file__).resolve().parent
sys.path.insert(0, str(_SCRIPT_DIR))
if _SCRIPT_DIR.name == "data_factory":
    sys.path.insert(0, str(_SCRIPT_DIR.parent))
# ─────────────────────────────────────────────────────────────────────────────
# SMOKE TEST
# ─────────────────────────────────────────────────────────────────────────────
def _check_templates() -> bool:
    """Step 1: execute every SQL template against seeded data; True if all pass."""
    print("\n[1/4] Validating all SQL templates against seeded data...")
    from data_factory.templates import ALL_TEMPLATES, template_stats
    from data_factory.validator import validate_all_templates

    stats = template_stats()
    result = validate_all_templates(ALL_TEMPLATES)
    print(f" Templates: {stats}")
    print(f" Validation: {result['passed']}/{result['total']} passed", end="")
    if not result["failed"]:
        print(" ✓")
        return True
    print(f" ← {result['failed']} FAILURES:")
    for failure in result["failures"]:
        print(f" [{failure['domain']}] {failure['sql']}... → {failure['error']}")
    return False


def _check_augmentation(show_samples: int) -> bool:
    """Step 2: each probe question must yield at least one rule-based variant."""
    print("\n[2/4] Testing rule-based augmentation...")
    from data_factory.augmentor import augment_nl

    test_nls = [
        "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
        "Which medications are prescribed most often? Return medication_name, category, times_prescribed.",
        "Rank active employees by salary within their department. Return salary_rank.",
    ]
    ok = True
    for nl in test_nls:
        variants = augment_nl(nl, n=3, seed=42)
        if not variants:
            print(f" FAIL: No variants generated for: {nl[:50]}")
            ok = False
            continue
        print(f" ✓ {len(variants)} variants from: '{nl[:45]}...'")
        if show_samples > 0:
            for i, variant in enumerate(variants[:show_samples], start=1):
                print(f" [{i}] {variant}")
    return ok


def _check_validator() -> bool:
    """Step 3: the SQL validator must accept/reject a fixed set of probe queries."""
    print("\n[3/4] Testing SQL validator accept/reject logic...")
    from data_factory.validator import SQLValidator

    tests = [
        ("SELECT id, name FROM customers WHERE tier = 'gold'", True, "valid SELECT"),
        ("INSERT INTO customers VALUES (1,'x','x@x.com','IN','gold','2024-01-01')", False, "rejected INSERT"),
        ("SELECT nonexistent_col FROM customers", False, "bad column name"),
        ("", False, "empty string"),
    ]
    ok = True
    validator = SQLValidator("ecommerce")
    try:
        for sql, expect_pass, label in tests:
            vr = validator.validate(sql)
            status = "✓" if vr.passed == expect_pass else "✗"
            detail = "" if vr.passed else f" (error: {vr.error})"
            print(f" {status} {label}: passed={vr.passed}{detail}")
            if vr.passed != expect_pass:
                ok = False
    finally:
        # Release the validator even if validate() raises mid-loop.
        validator.close()
    return ok


def _check_mini_pipeline(show_samples: int) -> bool:
    """Step 4: run the base pipeline on a few templates and check record shape."""
    print("\n[4/4] Running mini base pipeline (first 5 templates)...")
    from data_factory.pipeline import run_base_pipeline
    from data_factory.templates import ALL_TEMPLATES

    mini_templates = ALL_TEMPLATES[:5]
    records = run_base_pipeline(mini_templates, n_augmentations=2, seed=42)
    # At least the canonical NL of every template; never accept zero records.
    expected_min = max(len(mini_templates), 1)
    ok = True
    if len(records) < expected_min:
        print(f" FAIL: Only {len(records)} records (expected ≥{expected_min})")
        ok = False
    else:
        print(f" ✓ Generated {len(records)} records from {len(mini_templates)} templates")

    # Validate structure of every record, not just a prefix.
    required_keys = {"prompt", "sql", "metadata"}
    malformed = next((rec for rec in records if required_keys - rec.keys()), None)
    if malformed is not None:
        print(f" FAIL: Record missing keys: {required_keys - malformed.keys()}")
        # Sample printing below indexes these keys, so stop here.
        return False
    print(" ✓ Record structure validated")

    if show_samples > 0 and records:
        sample = records[0]
        meta = sample["metadata"]
        print("\n --- Sample Record ---")
        print(f" Domain: {meta['domain']}")
        print(f" Difficulty: {meta['difficulty']}")
        print(f" Persona: {meta['persona']}")
        print(f" NL: {sample['prompt'][1]['content'].split('QUESTION: ')[-1][:100]}")
        print(f" SQL: {sample['sql'][:80]}...")
    return ok


def run_smoke_test(show_samples: int = 0) -> bool:
    """Run the four smoke-test steps and print a pass/fail summary.

    Args:
        show_samples: When > 0, also print up to this many augmented NL
            variants per probe question, plus one sample pipeline record.

    Returns:
        True only if every step passed.
    """
    print("\n" + "=" * 60)
    print(" NL2SQL DATA FACTORY — SMOKE TEST")
    print("=" * 60)
    # Evaluate every step (a list literal runs all of them, in order) so a
    # single run reports all failures rather than stopping at the first.
    step_results = [
        _check_templates(),
        _check_augmentation(show_samples),
        _check_validator(),
        _check_mini_pipeline(show_samples),
    ]
    all_passed = all(step_results)

    print("\n" + "=" * 60)
    if all_passed:
        print(" ALL SMOKE TESTS PASSED ✓")
        print(" Safe to run: python run_data_factory.py --mode base")
    else:
        print(" SOME TESTS FAILED ✗ — fix errors before running pipeline")
    print("=" * 60 + "\n")
    return all_passed
# ─────────────────────────────────────────────────────────────────────────────
# INSPECT DATASET
# ─────────────────────────────────────────────────────────────────────────────
def inspect_dataset(jsonl_path: str, n: int = 5) -> None:
    """Pretty-print the first *n* records of an output JSONL file.

    Blank lines are skipped. Lines that are not valid JSON objects are
    reported and skipped, so one bad line does not abort the inspection.
    Missing metadata fields are shown as "?".

    Args:
        jsonl_path: Path to a JSONL file produced by the pipeline.
        n: Maximum number of records to display; values <= 0 show none.
    """
    path = Path(jsonl_path)
    if not path.exists():
        print(f"File not found: {path}")
        return

    records: list[dict] = []
    with path.open(encoding="utf-8") as fh:
        for lineno, line in enumerate(fh, start=1):
            if len(records) >= n:
                break
            if not line.strip():
                continue  # tolerate blank / trailing lines in the JSONL
            try:
                obj = json.loads(line)
            except json.JSONDecodeError as exc:
                print(f"Skipping malformed JSON on line {lineno}: {exc}")
                continue
            if not isinstance(obj, dict):
                print(f"Skipping non-object record on line {lineno}")
                continue
            records.append(obj)

    print(f"\n{'='*65}")
    print(f" Showing {len(records)} records from {path.name}")
    print(f"{'='*65}")
    for i, rec in enumerate(records, start=1):
        # The user turn (index 1) carries the question after a "QUESTION:" marker.
        try:
            nl = rec["prompt"][1]["content"].split("QUESTION:")[-1].strip()
        except (KeyError, IndexError, TypeError, AttributeError):
            nl = ""
        sql = str(rec.get("sql") or "")
        meta = rec.get("metadata") or {}
        print(f"\n[{i}] Domain={meta.get('domain', '?')} | Difficulty={meta.get('difficulty', '?')} | "
              f"Persona={meta.get('persona', '?')} | Source={meta.get('source', '?')}")
        print(f" NL: {textwrap.shorten(nl, 90)}")
        print(f" SQL: {textwrap.shorten(sql, 90)}")
    print()
# ─────────────────────────────────────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────────────────────────────────────
# Options defined only by this wrapper, mapped to whether they take a value.
# They are removed before delegating because the pipeline re-parses sys.argv
# with its own parser, which is not expected to know them.
_RUNNER_ONLY_OPTIONS = {
    "--smoke-test": False,
    "--show-samples": True,
    "--inspect": True,
    "--inspect-n": True,
}


def _strip_runner_options(argv: list[str]) -> list[str]:
    """Return *argv* without the options that only this wrapper defines.

    Handles both ``--opt value`` and ``--opt=value`` spellings.
    """
    kept: list[str] = []
    skip_value = False
    for token in argv:
        if skip_value:
            skip_value = False
            continue
        name, sep, _ = token.partition("=")
        if name in _RUNNER_ONLY_OPTIONS:
            # "--opt value" form: the next token is this option's value.
            skip_value = _RUNNER_ONLY_OPTIONS[name] and not sep
            continue
        kept.append(token)
    return kept


def main() -> None:
    """Parse CLI arguments and dispatch to the smoke test, the inspector, or the pipeline."""
    parser = argparse.ArgumentParser(
        description="NL2SQL Data Factory — entry point.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--smoke-test", action="store_true",
        help="Run smoke test only (validates all templates, no output written).",
    )
    parser.add_argument(
        "--show-samples", type=int, default=0,
        help="During smoke test, show N sample NL variants and records.",
    )
    parser.add_argument(
        "--inspect", type=str, default=None,
        help="Path to a JSONL output file to inspect.",
    )
    parser.add_argument(
        "--inspect-n", type=int, default=5,
        help="Number of records to show when inspecting.",
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help=(
            "base: rule augmentation only, ~450 records, no GPU needed.\n"
            "full: + vLLM persona variants, 500K+ records, H100 required."
        ),
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct")
    parser.add_argument("--tensor-parallel", type=int, default=4)
    parser.add_argument("--n-rule-augments", type=int, default=5)
    parser.add_argument("--n-persona-variants", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.85)
    parser.add_argument("--output-dir", default="generated_data/output")
    parser.add_argument("--checkpoint-dir", default="generated_data/checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-parquet", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--domains", nargs="+",
        choices=["ecommerce", "healthcare", "finance", "hr"],
        default=["ecommerce", "healthcare", "finance", "hr"],
    )
    parser.add_argument(
        "--difficulties", nargs="+",
        choices=["easy", "medium", "hard"],
        default=["easy", "medium", "hard"],
    )
    args = parser.parse_args()

    if args.smoke_test:
        ok = run_smoke_test(show_samples=args.show_samples)
        sys.exit(0 if ok else 1)

    if args.inspect:
        inspect_dataset(args.inspect, n=args.inspect_n)
        sys.exit(0)

    # Forward to the pipeline: pipeline_main() re-parses sys.argv with its own
    # parser, so drop the wrapper-only flags that parser would reject.
    from data_factory.pipeline import main as pipeline_main

    sys.argv = [sys.argv[0], *_strip_runner_options(sys.argv[1:])]
    pipeline_main()
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()