""" data_factory/pipeline.py ========================= Master orchestration pipeline for the NL2SQL Synthetic Data Factory. This module ties together: 1. Template library (66 verified SQL templates across 4 domains) 2. Rule-based NL augmentation (augmentor.py) 3. vLLM persona-based NL generation (generator.py) 4. SQL execution validation (validator.py) 5. Output serialisation (JSONL + Parquet) Run modes: --mode base : Only uses template base_nl + rule augmentation (no GPU required) --mode full : base + vLLM persona generation (requires H100) Output dataset format (JSONL, one record per line): { "prompt": [{"role": "system", ...}, {"role": "user", ...}], "sql": "SELECT ...", "metadata": { "domain", "difficulty", "persona", ... } } This format is directly loadable by: datasets.load_dataset("json", data_files="output/train.jsonl") """ from __future__ import annotations import argparse import json import logging import os import random import time from pathlib import Path from typing import Any, Iterator, Optional logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s", datefmt="%H:%M:%S", ) logger = logging.getLogger("pipeline") # ───────────────────────────────────────────────────────────────────────────── # HELPERS # ───────────────────────────────────────────────────────────────────────────── def _ensure_dirs(*dirs: Path) -> None: for d in dirs: d.mkdir(parents=True, exist_ok=True) def _write_jsonl(records: list[dict], path: Path) -> None: with open(path, "w", encoding="utf-8") as f: for rec in records: f.write(json.dumps(rec, ensure_ascii=False) + "\n") logger.info("Wrote %d records to %s", len(records), path) def _write_parquet(records: list[dict], path: Path) -> None: try: import pandas as pd df = pd.DataFrame(records) df.to_parquet(path, index=False, engine="pyarrow", compression="snappy") logger.info("Wrote %d records to %s (Parquet)", len(records), path) except ImportError: logger.warning("pandas/pyarrow not installed — skipping Parquet output.") def _train_val_test_split( records: list[dict], train_frac: float = 0.90, val_frac: float = 0.05, seed: int = 42, ) -> tuple[list[dict], list[dict], list[dict]]: """ Stratified split by (domain, difficulty) to ensure all combinations are represented in every split. """ rng = random.Random(seed) from collections import defaultdict buckets: dict[str, list[dict]] = defaultdict(list) for rec in records: key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}" buckets[key].append(rec) train, val, test = [], [], [] for key, bucket in buckets.items(): rng.shuffle(bucket) n = len(bucket) n_train = max(1, int(n * train_frac)) n_val = max(1, int(n * val_frac)) train.extend(bucket[:n_train]) val.extend(bucket[n_train:n_train + n_val]) test.extend(bucket[n_train + n_val:]) rng.shuffle(train) rng.shuffle(val) rng.shuffle(test) return train, val, test # ───────────────────────────────────────────────────────────────────────────── # PHASE 1: BASE + RULE AUGMENTATION (no GPU required) # ───────────────────────────────────────────────────────────────────────────── def run_base_pipeline( templates: list, n_augmentations: int = 5, seed: int = 42, ) -> list[dict]: """ Generate training records from: (a) the canonical base_nl of each template (b) rule-based augmented NL variants Returns a list of training dicts (ready to write to JSONL). 
""" from data_factory.augmentor import augment_nl from data_factory.validator import SQLValidator, build_record from data_factory.schemas import SCHEMA_MAP # Build one validator per domain (reuse connection across templates) validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP} records: list[dict] = [] for t_idx, template in enumerate(templates): v = validators[template["domain"]] # (a) Canonical base_nl rec = build_record( template=template, template_idx=t_idx, nl_question=template["base_nl"], persona="canonical", source="template_base", validator=v, ) if rec: records.append(rec.to_training_dict()) # (b) Rule-augmented variants augmented = augment_nl( nl_question=template["base_nl"], n=n_augmentations, seed=seed + t_idx, ) for nl_variant in augmented: rec = build_record( template=template, template_idx=t_idx, nl_question=nl_variant, persona="rule_augmented", source="rule_augmented", validator=v, ) if rec: records.append(rec.to_training_dict()) for v in validators.values(): v.close() logger.info("Base pipeline: %d records generated from %d templates.", len(records), len(templates)) return records # ───────────────────────────────────────────────────────────────────────────── # PHASE 2: vLLM PERSONA GENERATION (H100 required) # ───────────────────────────────────────────────────────────────────────────── def run_vllm_pipeline( templates: list, generator, # VLLMGenerator instance personas: list[str], n_variants_per_persona: int = 10, batch_size: int = 64, temperature: float = 0.85, max_new_tokens: int = 350, seed: int = 42, ) -> list[dict]: """ Generate additional NL variants using the LLM, then validate SQL. Returns a list of training dicts. """ from data_factory.generator import generate_persona_variants_batch from data_factory.validator import SQLValidator, build_record from data_factory.schemas import SCHEMA_MAP validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP} records: list[dict] = [] gen_iter = generate_persona_variants_batch( templates_subset=templates, generator=generator, personas=personas, n_variants_per_persona=n_variants_per_persona, batch_size=batch_size, temperature=temperature, max_new_tokens=max_new_tokens, ) for job_result in gen_iter: t_idx = job_result["template_idx"] persona = job_result["persona"] template = templates[t_idx] v = validators[template["domain"]] for nl_variant in job_result["nl_variants"]: rec = build_record( template=template, template_idx=t_idx, nl_question=nl_variant, persona=persona, source="vllm_persona", validator=v, ) if rec: records.append(rec.to_training_dict()) for v in validators.values(): v.close() logger.info("vLLM pipeline: %d records generated.", len(records)) return records # ───────────────────────────────────────────────────────────────────────────── # CHECKPOINT UTILITIES # ───────────────────────────────────────────────────────────────────────────── def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path: path = checkpoint_dir / f"{name}.jsonl" _write_jsonl(records, path) return path def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]: path = checkpoint_dir / f"{name}.jsonl" if not path.exists(): return None records = [] with open(path, encoding="utf-8") as f: for line in f: line = line.strip() if line: records.append(json.loads(line)) logger.info("Loaded %d records from checkpoint %s", len(records), path) return records # ───────────────────────────────────────────────────────────────────────────── # DATASET STATISTICS # 

# ─────────────────────────────────────────────────────────────────────────────
# DATASET STATISTICS
# ─────────────────────────────────────────────────────────────────────────────

def print_dataset_stats(records: list[dict]) -> None:
    from collections import Counter

    if not records:
        print("\nNo records to summarise.\n")
        return

    domains = Counter(r["metadata"]["domain"] for r in records)
    diffs = Counter(r["metadata"]["difficulty"] for r in records)
    personas = Counter(r["metadata"]["persona"] for r in records)
    sources = Counter(r["metadata"]["source"] for r in records)

    print("\n" + "=" * 55)
    print(f"  DATASET STATISTICS ({len(records):,} total records)")
    print("=" * 55)
    print("\nBy Domain:")
    for k, v in sorted(domains.items()):
        print(f"  {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
    print("\nBy Difficulty:")
    for k, v in sorted(diffs.items()):
        print(f"  {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
    print("\nBy Persona:")
    for k, v in sorted(personas.items()):
        print(f"  {k:20s}: {v:6,}")
    print("\nBy Source:")
    for k, v in sorted(sources.items()):
        print(f"  {k:20s}: {v:6,}")
    print("=" * 55 + "\n")


# ─────────────────────────────────────────────────────────────────────────────
# MAIN ENTRY POINT
# ─────────────────────────────────────────────────────────────────────────────

def main() -> None:
    parser = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory — generates verified training data."
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
                        help="HuggingFace model name for vLLM (full mode only).")
    parser.add_argument("--tensor-parallel", type=int, default=4,
                        help="Tensor parallel size for vLLM (number of H100s).")
    parser.add_argument("--n-rule-augments", type=int, default=5,
                        help="Number of rule-based NL augmentations per template.")
    parser.add_argument("--n-persona-variants", type=int, default=10,
                        help="Number of vLLM NL variants per (template, persona) pair.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="vLLM batch size (larger = faster on H100).")
    parser.add_argument("--temperature", type=float, default=0.85,
                        help="Sampling temperature for vLLM generation.")
    parser.add_argument("--output-dir", type=str, default="generated_data/output",
                        help="Directory to write final dataset files.")
    parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
                        help="Directory for intermediate checkpoints.")
    parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
    parser.add_argument("--no-parquet", action="store_true",
                        help="Skip Parquet output (write only JSONL).")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from latest checkpoint if available.")
    parser.add_argument("--domains", nargs="+",
                        choices=["ecommerce", "healthcare", "finance", "hr"],
                        default=["ecommerce", "healthcare", "finance", "hr"],
                        help="Domains to include (default: all 4).")
    parser.add_argument("--difficulties", nargs="+",
                        choices=["easy", "medium", "hard"],
                        default=["easy", "medium", "hard"],
                        help="Difficulty levels to include (default: all 3).")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    checkpoint_dir = Path(args.checkpoint_dir)
    _ensure_dirs(output_dir, checkpoint_dir)

    # ── Load templates ──────────────────────────────────────────────────────
    from data_factory.templates import ALL_TEMPLATES

    templates = [
        t for t in ALL_TEMPLATES
        if t["domain"] in args.domains and t["difficulty"] in args.difficulties
    ]
    logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
                len(templates), args.domains, args.difficulties)
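
    # Each phase below checkpoints its own records under --checkpoint-dir, so a
    # "full" run that crashes can be re-launched with --resume and will reload
    # already-generated phases from JSONL instead of regenerating them.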
    # ── Phase 1: Base + rule augmentation ───────────────────────────────────
    all_records: list[dict] = []

    ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
    if ckpt_base is not None:
        all_records.extend(ckpt_base)
        logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
    else:
        logger.info("=== Phase 1: Base + Rule Augmentation ===")
        base_records = run_base_pipeline(
            templates=templates,
            n_augmentations=args.n_rule_augments,
            seed=args.seed,
        )
        all_records.extend(base_records)
        save_checkpoint(base_records, checkpoint_dir, "phase1_base")

    # ── Phase 2: vLLM persona generation (full mode only) ───────────────────
    if args.mode == "full":
        ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
        if ckpt_vllm is not None:
            all_records.extend(ckpt_vllm)
            logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
        else:
            logger.info("=== Phase 2: vLLM Persona Generation ===")
            from data_factory.generator import VLLMGenerator
            from data_factory.config import PERSONAS

            generator = VLLMGenerator(
                model_name=args.model,
                mode="offline",
                tensor_parallel_size=args.tensor_parallel,
                gpu_memory_utilization=0.90,
            )
            vllm_records = run_vllm_pipeline(
                templates=templates,
                generator=generator,
                personas=PERSONAS,
                n_variants_per_persona=args.n_persona_variants,
                batch_size=args.batch_size,
                temperature=args.temperature,
                max_new_tokens=350,
                seed=args.seed,
            )
            all_records.extend(vllm_records)
            save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")

    # ── Deduplication ────────────────────────────────────────────────────────
    logger.info("Deduplicating %d records...", len(all_records))
    seen_nl: set[str] = set()
    deduped: list[dict] = []
    for rec in all_records:
        nl = rec["prompt"][1]["content"]  # user message contains the NL question
        if nl not in seen_nl:
            seen_nl.add(nl)
            deduped.append(rec)
    logger.info("After dedup: %d unique records (removed %d duplicates).",
                len(deduped), len(all_records) - len(deduped))

    # ── Statistics ───────────────────────────────────────────────────────────
    print_dataset_stats(deduped)

    # ── Train / Val / Test split ─────────────────────────────────────────────
    train, val, test = _train_val_test_split(deduped, seed=args.seed)
    logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))

    # ── Write outputs ────────────────────────────────────────────────────────
    _write_jsonl(train, output_dir / "train.jsonl")
    _write_jsonl(val, output_dir / "val.jsonl")
    _write_jsonl(test, output_dir / "test.jsonl")
    if not args.no_parquet:
        _write_parquet(train, output_dir / "train.parquet")
        _write_parquet(val, output_dir / "val.parquet")
        _write_parquet(test, output_dir / "test.parquet")

    # ── Write dataset card ───────────────────────────────────────────────────
    card = {
        "name": "NL2SQL-Bench Synthetic Training Dataset",
        "version": "1.0",
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "total_records": len(deduped),
        "splits": {"train": len(train), "val": len(val), "test": len(test)},
        "domains": args.domains,
        "difficulties": args.difficulties,
        "mode": args.mode,
        "seed": args.seed,
        "sql_guarantee": (
            "Every SQL in this dataset was human-authored and execution-validated "
            "against a seeded SQLite database. Zero LLM-generated SQL."
        ),
    }
    with open(output_dir / "dataset_card.json", "w") as f:
        json.dump(card, f, indent=2)

    logger.info("=== Done! Dataset written to %s ===", output_dir)


if __name__ == "__main__":
    main()
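
# Consuming the output (illustrative sketch; paths are the defaults written by
# main(), adjust data_files if you pass a different --output-dir):
#
#     from datasets import load_dataset
#     ds = load_dataset("json", data_files={
#         "train": "generated_data/output/train.jsonl",
#         "validation": "generated_data/output/val.jsonl",
#         "test": "generated_data/output/test.jsonl",
#     })
#     print(ds["train"][0]["prompt"], ds["train"][0]["sql"])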