Spaces:
Sleeping
Sleeping
| """ | |
| run_data_factory.py | |
| ==================== | |
| Entry point and smoke-test runner for the NL2SQL Data Factory. | |
| Run this FIRST before running the full pipeline to verify: | |
| 1. All 66 SQL templates execute without errors | |
| 2. Rule augmentation produces diverse NL variants | |
| 3. Validators correctly accept/reject queries | |
| 4. Base pipeline generates well-formed JSONL records | |
| Usage: | |
| # Smoke test only (fast, ~10 seconds) | |
| python run_data_factory.py --smoke-test | |
| # Base mode (no GPU, generates all rule-augmented records) | |
| python run_data_factory.py --mode base | |
| # Full mode (H100 required) | |
| python run_data_factory.py --mode full --model meta-llama/Meta-Llama-3-70B-Instruct --tensor-parallel 4 | |
| # Preview what the dataset looks like | |
| python run_data_factory.py --smoke-test --show-samples 3 | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import sys | |
| import textwrap | |
| from pathlib import Path | |
# Allow running from project root: put the directory containing this script
# at the front of sys.path so the sibling ``data_factory`` package imports.
# The path is resolved to an absolute one, so a relative ``__file__`` or a
# later ``os.chdir`` cannot turn it into a wrong, cwd-relative entry.
sys.path.insert(0, str(Path(__file__).resolve().parent))
# ─────────────────────────────────────────────────────────────────────────────
# SMOKE TEST
# ─────────────────────────────────────────────────────────────────────────────
def _check_templates() -> bool:
    """Step 1/4: execute every SQL template against the seeded data.

    Returns:
        True when ``validate_all_templates`` reports no failed templates.
    """
    print("\n[1/4] Validating all SQL templates against seeded data...")
    from data_factory.templates import ALL_TEMPLATES, template_stats
    from data_factory.validator import validate_all_templates

    stats = template_stats()
    result = validate_all_templates(ALL_TEMPLATES)
    print(f" Templates: {stats}")
    print(f" Validation: {result['passed']}/{result['total']} passed", end="")
    if not result["failed"]:
        print(" ✓")
        return True
    print(f" ✗ {result['failed']} FAILURES:")
    for failure in result["failures"]:
        print(f" [{failure['domain']}] {failure['sql']}... → {failure['error']}")
    return False


def _check_augmentation(show_samples: int) -> bool:
    """Step 2/4: confirm rule-based augmentation yields variants for probe questions.

    Args:
        show_samples: When > 0, print up to this many variants per question.

    Returns:
        True when every probe question produced at least one variant.
    """
    print("\n[2/4] Testing rule-based augmentation...")
    from data_factory.augmentor import augment_nl

    test_nls = [
        "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
        "Which medications are prescribed most often? Return medication_name, category, times_prescribed.",
        "Rank active employees by salary within their department. Return salary_rank.",
    ]
    ok = True
    for nl in test_nls:
        variants = augment_nl(nl, n=3, seed=42)
        if not variants:
            print(f" FAIL: No variants generated for: {nl[:50]}")
            ok = False
            continue
        print(f" ✓ {len(variants)} variants from: '{nl[:45]}...'")
        # Guard needed: a negative slice bound would print all-but-last variants.
        if show_samples > 0:
            for i, variant in enumerate(variants[:show_samples], start=1):
                print(f" [{i}] {variant}")
    return ok


def _check_validator() -> bool:
    """Step 3/4: check that ``SQLValidator`` accepts/rejects known queries as expected.

    Returns:
        True when every case's ``passed`` flag matches its expectation.
    """
    print("\n[3/4] Testing SQL validator accept/reject logic...")
    from data_factory.validator import SQLValidator

    cases = [
        ("SELECT id, name FROM customers WHERE tier = 'gold'", True, "valid SELECT"),
        ("INSERT INTO customers VALUES (1,'x','x@x.com','IN','gold','2024-01-01')", False, "rejected INSERT"),
        ("SELECT nonexistent_col FROM customers", False, "bad column name"),
        ("", False, "empty string"),
    ]
    ok = True
    validator = SQLValidator("ecommerce")
    try:
        for sql, expect_pass, label in cases:
            vr = validator.validate(sql)
            matched = vr.passed == expect_pass
            # Distinct markers: previously both branches rendered the same glyph.
            status = "✓" if matched else "✗"
            print(f" {status} {label}: passed={vr.passed}", end="")
            if not vr.passed:
                print(f" (error: {vr.error})", end="")
            print()
            if not matched:
                ok = False
    finally:
        # Always close the validator, even if validate() raises.
        validator.close()
    return ok


def _check_mini_pipeline(show_samples: int) -> bool:
    """Step 4/4: run the base pipeline on the first five templates and sanity-check output.

    Args:
        show_samples: When > 0 and records exist, print one sample record.

    Returns:
        True when enough records were produced and the first few carry the
        ``prompt``, ``sql`` and ``metadata`` keys.
    """
    print("\n[4/4] Running mini base pipeline (first 5 templates)...")
    from data_factory.pipeline import run_base_pipeline
    from data_factory.templates import ALL_TEMPLATES

    mini_templates = ALL_TEMPLATES[:5]
    records = run_base_pipeline(mini_templates, n_augmentations=2, seed=42)
    # At least one canonical NL record per template; derived from the actual
    # slice so a catalogue with fewer than five templates is judged correctly.
    expected_min = len(mini_templates)
    ok = True
    if len(records) < expected_min:
        print(f" FAIL: Only {len(records)} records (expected ≥{expected_min})")
        ok = False
    else:
        print(f" ✓ Generated {len(records)} records from {len(mini_templates)} templates")

    # Validate structure -- only meaningful when something was generated.
    required_keys = {"prompt", "sql", "metadata"}
    if records:
        for rec in records[:3]:
            missing = required_keys - rec.keys()
            if missing:
                print(f" FAIL: Record missing keys: {missing}")
                ok = False
                break
        else:
            print(" ✓ Record structure validated")

    if show_samples > 0 and records:
        print("\n --- Sample Record ---")
        sample = records[0]
        print(f" Domain: {sample['metadata']['domain']}")
        print(f" Difficulty: {sample['metadata']['difficulty']}")
        print(f" Persona: {sample['metadata']['persona']}")
        print(f" NL: {sample['prompt'][1]['content'].split('QUESTION: ')[-1][:100]}")
        print(f" SQL: {sample['sql'][:80]}...")
    return ok


def run_smoke_test(show_samples: int = 0) -> bool:
    """Run the four data-factory sanity checks and print a pass/fail report.

    The checks are:
      1. every SQL template executes against the seeded data;
      2. rule-based augmentation produces NL variants;
      3. ``SQLValidator`` accepts/rejects a fixed set of queries as expected;
      4. a mini base pipeline over the first five templates emits
         well-formed records.

    All four steps always run, so one invocation reports every failure.

    Args:
        show_samples: When > 0, also print up to this many augmentation
            variants per probe question and one sample pipeline record.

    Returns:
        True if every check passed, False otherwise.
    """
    print("\n" + "=" * 60)
    print(" NL2SQL DATA FACTORY — SMOKE TEST")
    print("=" * 60)

    # A list (not a generator fed to all()) so no step is short-circuited.
    step_results = [
        _check_templates(),
        _check_augmentation(show_samples),
        _check_validator(),
        _check_mini_pipeline(show_samples),
    ]
    all_passed = all(step_results)

    # Summary
    print("\n" + "=" * 60)
    if all_passed:
        print(" ALL SMOKE TESTS PASSED ✓")
        print(" Safe to run: python run_data_factory.py --mode base")
    else:
        print(" SOME TESTS FAILED ✗ — fix errors before running pipeline")
    print("=" * 60 + "\n")
    return all_passed
# ─────────────────────────────────────────────────────────────────────────────
# INSPECT DATASET
# ─────────────────────────────────────────────────────────────────────────────
def inspect_dataset(jsonl_path: str, n: int = 5) -> None:
    """Pretty-print the first *n* records of an output JSONL file.

    Blank lines are skipped (they neither abort the read with a
    ``JSONDecodeError`` nor count toward *n*). If *jsonl_path* is not an
    existing regular file, a message is printed and nothing else happens.

    Args:
        jsonl_path: Path to the ``.jsonl`` file to read.
        n: Maximum number of records to show; values <= 0 show none.

    Each record is read as ``rec["prompt"][1]["content"]`` (the text after
    ``QUESTION:`` is shown as the NL question), ``rec["sql"]``, and
    ``rec["metadata"]`` with ``domain``, ``difficulty``, ``persona`` and
    ``source`` keys; a record missing any of these raises ``KeyError`` /
    ``IndexError``.
    """
    path = Path(jsonl_path)
    # is_file() rather than exists(): a directory would otherwise reach open()
    # and fail with IsADirectoryError instead of the friendly message.
    if not path.is_file():
        print(f"File not found: {path}")
        return

    records = []
    if n > 0:
        with path.open(encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate stray blank lines in the JSONL
                records.append(json.loads(line))
                if len(records) >= n:
                    break

    rule = "=" * 65
    print(f"\n{rule}")
    print(f" Showing {len(records)} records from {path.name}")
    print(rule)
    for i, rec in enumerate(records, start=1):
        nl = rec["prompt"][1]["content"].split("QUESTION:")[-1].strip()
        meta = rec["metadata"]
        print(f"\n[{i}] Domain={meta['domain']} | Difficulty={meta['difficulty']} | "
              f"Persona={meta['persona']} | Source={meta['source']}")
        print(f" NL: {textwrap.shorten(nl, 90)}")
        print(f" SQL: {textwrap.shorten(rec['sql'], 90)}")
    print()
# ─────────────────────────────────────────────────────────────────────────────
# MAIN
# ─────────────────────────────────────────────────────────────────────────────
def main() -> None:
    """Parse the command line and dispatch to the requested action.

    * ``--smoke-test`` runs ``run_smoke_test`` and exits 0 on success, 1 on failure.
    * ``--inspect PATH`` pretty-prints records from a JSONL file, then exits
      0, or 1 when PATH is not an existing regular file.
    * Otherwise control passes to ``data_factory.pipeline.main``.

    The pipeline options (``--mode``, ``--model``, ...) are declared here so
    that ``--help`` documents them and invalid values are rejected early; the
    parsed values themselves are not forwarded. ``pipeline_main()`` is called
    with no arguments and presumably re-parses ``sys.argv`` with its own
    parser -- NOTE(review): confirm it accepts every flag declared here.
    """
    # Single source of truth for both ``choices`` and ``default``.
    domains = ["ecommerce", "healthcare", "finance", "hr"]
    difficulties = ["easy", "medium", "hard"]

    parser = argparse.ArgumentParser(
        description="NL2SQL Data Factory — entry point.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--smoke-test", action="store_true",
        help="Run smoke test only (validates all templates, no output written).",
    )
    parser.add_argument(
        "--show-samples", type=int, default=0,
        help="During smoke test, show N sample NL variants and records.",
    )
    parser.add_argument(
        "--inspect", type=str, default=None,
        help="Path to a JSONL output file to inspect.",
    )
    parser.add_argument(
        "--inspect-n", type=int, default=5,
        help="Number of records to show when inspecting.",
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help=(
            "base: rule augmentation only, ~450 records, no GPU needed.\n"
            "full: + vLLM persona variants, 500K+ records, H100 required."
        ),
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct")
    parser.add_argument("--tensor-parallel", type=int, default=4)
    parser.add_argument("--n-rule-augments", type=int, default=5)
    parser.add_argument("--n-persona-variants", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.85)
    parser.add_argument("--output-dir", default="generated_data/output")
    parser.add_argument("--checkpoint-dir", default="generated_data/checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-parquet", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--domains", nargs="+",
        choices=domains,
        default=list(domains),
    )
    parser.add_argument(
        "--difficulties", nargs="+",
        choices=difficulties,
        default=list(difficulties),
    )
    args = parser.parse_args()

    if args.smoke_test:
        ok = run_smoke_test(show_samples=args.show_samples)
        sys.exit(0 if ok else 1)

    if args.inspect:
        inspect_dataset(args.inspect, n=args.inspect_n)
        # inspect_dataset only prints a message for a missing file, so report
        # that failure through the exit status for scripts and CI.
        sys.exit(0 if Path(args.inspect).is_file() else 1)

    # Forward to the pipeline; sys.argv is left untouched for its own parser.
    from data_factory.pipeline import main as pipeline_main

    pipeline_main()
# Script entry point: dispatch the CLI only when executed directly, never on
# import. main() returns None, so the process exit status stays 0 unless
# main() itself calls sys.exit() with another code.
if __name__ == "__main__":
    sys.exit(main())