""" generate_data.py — NL2SQL Synthetic Data Factory ================================================= Designed for H100 + vLLM. Produces a clean JSONL file ready for SFT or GRPO training with the nl2sql-bench codebase (schema: e-commerce SQLite). Architecture ------------ 1. SQL_TEMPLATES — 120+ ground-truth SQLs, hand-written and verified, NEVER LLM-generated. 2. SQLiteValidator — executes every SQL against the actual seeded DB; discards any failure. 3. VLLMGenerator — async batched calls to a local vLLM server for NL paraphrasing. 4. RuleAugmentor — pure-Python synonym / date-format / condition-order augmentation. 5. DataFactory — orchestrates the full pipeline; writes JSONL with checkpointing. Output schema (one JSON object per line) ----------------------------------------- { "id": "easy_001_persona_ceo", "difficulty": "easy" | "medium" | "hard", "persona": "ceo" | "chatty" | "lazy" | "confused" | "analyst", "question": "", "sql": "", "db_result_ok": true, # always true — failures are discarded "augmented": false # true when rule-augmentor modified the NL } Usage ----- # 1. Start vLLM server (H100): # vllm serve meta-llama/Meta-Llama-3-70B-Instruct \ # --tensor-parallel-size 4 --port 8001 \ # --max-model-len 4096 --gpu-memory-utilization 0.92 # 2. Run this script (place it next to the nl2sql-bench folder): # python generate_data.py \ # --vllm-url http://localhost:8001/v1 \ # --model meta-llama/Meta-Llama-3-70B-Instruct \ # --output nl2sql_train.jsonl \ # --personas-per-template 5 \ # --aug-rounds 2 \ # --batch-size 64 Requirements ------------ pip install openai tqdm (vLLM + your model already running separately) IMPORTANT: Copy server/db/schema.sql and server/db/seed.py from nl2sql-bench into the same directory as this script, OR set --bench-root to the repo root. """ from __future__ import annotations import argparse import asyncio import hashlib import json import logging import os import random import re import sqlite3 import sys import time from copy import deepcopy from dataclasses import dataclass, asdict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from openai import AsyncOpenAI from tqdm import tqdm # ───────────────────────────────────────────────────────────────────────────── # Logging # ───────────────────────────────────────────────────────────────────────────── logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s", datefmt="%H:%M:%S", ) log = logging.getLogger("data-factory") # ───────────────────────────────────────────────────────────────────────────── # Database: build & validate # ───────────────────────────────────────────────────────────────────────────── SCHEMA_SQL = """ CREATE TABLE IF NOT EXISTS categories ( id INTEGER PRIMARY KEY, name TEXT NOT NULL UNIQUE ); CREATE TABLE IF NOT EXISTS products ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, category_id INTEGER NOT NULL REFERENCES categories(id), price REAL NOT NULL CHECK(price >= 0), stock_quantity INTEGER NOT NULL DEFAULT 0 ); CREATE TABLE IF NOT EXISTS customers ( id INTEGER PRIMARY KEY, name TEXT NOT NULL, email TEXT NOT NULL UNIQUE, country TEXT NOT NULL, tier TEXT NOT NULL DEFAULT 'bronze' CHECK(tier IN ('bronze', 'silver', 'gold')), created_at TEXT NOT NULL ); CREATE TABLE IF NOT EXISTS orders ( id INTEGER PRIMARY KEY, customer_id INTEGER NOT NULL REFERENCES customers(id), status TEXT NOT NULL DEFAULT 'pending' CHECK(status IN ('pending','processing','shipped','delivered','cancelled')), created_at TEXT NOT NULL, total_amount REAL NOT NULL CHECK(total_amount >= 0) ); CREATE TABLE IF NOT EXISTS order_items ( id INTEGER PRIMARY KEY, order_id INTEGER NOT NULL REFERENCES orders(id), product_id INTEGER NOT NULL REFERENCES products(id), quantity INTEGER NOT NULL CHECK(quantity > 0), unit_price REAL NOT NULL CHECK(unit_price >= 0) ); CREATE TABLE IF NOT EXISTS reviews ( id INTEGER PRIMARY KEY, product_id INTEGER NOT NULL REFERENCES products(id), customer_id INTEGER NOT NULL REFERENCES customers(id), rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5), created_at TEXT NOT NULL ); """ # Minimal seeder so the validator can run the SQL against real data. # Mirrors the logic in nl2sql-bench/server/db/seed.py (fixed seed = 42). SEED_SCRIPT = """ import random, sqlite3 from datetime import date, timedelta RNG = random.Random(42) CATEGORIES = ["Electronics","Clothing","Books","Home & Garden", "Sports & Outdoors","Toys & Games","Beauty","Automotive"] PRODUCTS = { "Electronics": ["Wireless Headphones","USB-C Hub","Mechanical Keyboard", "Webcam 4K","Portable Charger","Smart Speaker", "Monitor Stand","HDMI Cable 2.1"], "Clothing": ["Cotton T-Shirt","Slim Fit Jeans","Hoodie", "Running Shorts","Winter Jacket","Polo Shirt", "Casual Sneakers","Wool Socks"], "Books": ["Clean Code","Designing Data-Intensive Applications", "The Pragmatic Programmer","System Design Interview", "Deep Learning Book","Python Cookbook", "Domain-Driven Design","Refactoring"], "Home & Garden": ["Coffee Maker","Air Purifier","LED Desk Lamp", "Plant Pot Set","Storage Organiser","Cutting Board", "Vacuum Cleaner","Electric Kettle"], "Sports & Outdoors": ["Yoga Mat","Resistance Bands","Cycling Gloves", "Trekking Poles","Water Bottle 1L","Jump Rope", "Foam Roller","Compression Socks"], "Toys & Games": ["Lego City Set","Card Game Pack","Puzzle 1000pc", "Remote Control Car","Building Blocks", "Board Game Strategy","Art Set","Toy Drone"], "Beauty": ["Face Serum","SPF 50 Sunscreen","Lip Balm", "Shampoo Pro","Hair Mask","Eye Cream", "Vitamin C Cream","Toner Mist"], "Automotive": ["Car Phone Mount","Dash Cam","Tyre Inflator", "Car Vacuum","Seat Cushion","Steering Wheel Cover", "OBD Scanner","Jump Starter"], } COUNTRIES = ["India","USA","Germany","UK","Canada", "Australia","France","Brazil","Japan","Singapore"] TIERS = ["bronze","silver","gold"] STATUSES = ["pending","processing","shipped","delivered","cancelled"] FIRST = ["Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja", "Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica", "Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura", "Yuki","Hana","Wei","Mei","Aiden","Zara"] LAST = ["Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy", "Smith","Johnson","Brown","Williams","Jones","Davis","Wilson", "Müller","Schmidt","Schneider","Fischer","Weber", "Martin","Bernard","Thomas","Richard","Petit", "Garcia","Martinez","Lopez","Sanchez","Gonzalez"] def _date(start=2022, end=2025): s = date(start, 1, 1) e = date(end, 12, 31) return str(s + timedelta(days=RNG.randint(0, (e - s).days))) def seed(conn): c = conn.cursor() for cat in CATEGORIES: c.execute("INSERT OR IGNORE INTO categories(name) VALUES (?)", (cat,)) conn.commit() cat_ids = {r[1]: r[0] for r in conn.execute("SELECT id, name FROM categories")} for cat, prods in PRODUCTS.items(): for pname in prods: c.execute( "INSERT OR IGNORE INTO products(name,category_id,price,stock_quantity) VALUES (?,?,?,?)", (pname, cat_ids[cat], round(RNG.uniform(5, 500), 2), RNG.randint(0, 200)), ) conn.commit() for i in range(200): name = f"{RNG.choice(FIRST)} {RNG.choice(LAST)}" email = f"user{i}@example.com" c.execute( "INSERT OR IGNORE INTO customers(name,email,country,tier,created_at) VALUES (?,?,?,?,?)", (name, email, RNG.choice(COUNTRIES), RNG.choice(TIERS), _date()), ) conn.commit() cust_ids = [r[0] for r in conn.execute("SELECT id FROM customers")] prod_ids = [r[0] for r in conn.execute("SELECT id FROM products")] for _ in range(600): cid = RNG.choice(cust_ids) amt = round(RNG.uniform(10, 1000), 2) status = RNG.choice(STATUSES) d = _date() c.execute( "INSERT INTO orders(customer_id,status,created_at,total_amount) VALUES (?,?,?,?)", (cid, status, d, amt), ) conn.commit() ord_ids = [r[0] for r in conn.execute("SELECT id FROM orders")] for oid in ord_ids: for _ in range(RNG.randint(1, 4)): pid = RNG.choice(prod_ids) qty = RNG.randint(1, 5) price = round(RNG.uniform(5, 500), 2) c.execute( "INSERT INTO order_items(order_id,product_id,quantity,unit_price) VALUES (?,?,?,?)", (oid, pid, qty, price), ) conn.commit() for _ in range(400): pid = RNG.choice(prod_ids) cid = RNG.choice(cust_ids) rating = RNG.randint(1, 5) c.execute( "INSERT INTO reviews(product_id,customer_id,rating,created_at) VALUES (?,?,?,?)", (pid, cid, rating, _date()), ) conn.commit() """ def build_db() -> sqlite3.Connection: """Build an in-memory SQLite DB with schema + seed data.""" conn = sqlite3.connect(":memory:") conn.executescript(SCHEMA_SQL) exec(SEED_SCRIPT, {"conn": conn}) # run the seeder inline conn.row_factory = sqlite3.Row log.info("In-memory DB built and seeded.") return conn class SQLiteValidator: """Execute SQL against the seeded DB; return (rows, error).""" def __init__(self, conn: sqlite3.Connection): self.conn = conn def validate(self, sql: str) -> Tuple[bool, Optional[str]]: sql = sql.strip().rstrip(";") if not sql: return False, "Empty SQL" first = sql.split()[0].lower() if first != "select": return False, f"Non-SELECT statement: {first}" try: cur = self.conn.execute(sql) cur.fetchmany(500) return True, None except sqlite3.Error as exc: return False, str(exc) # ───────────────────────────────────────────────────────────────────────────── # SQL Template Library (ground-truth, hand-written, execution-validated) # ───────────────────────────────────────────────────────────────────────────── @dataclass class SQLTemplate: id: str difficulty: str # easy | medium | hard description: str # plain-English description fed to the LLM sql: str order_sensitive: bool = False # NOTE: Every SQL here uses only the 6 tables in the schema and valid SQLite syntax. # They are intentionally grouped by the SQL pattern they teach, not just by difficulty. EASY_TEMPLATES: List[SQLTemplate] = [ # ── Equality filter ────────────────────────────────────────────────────── SQLTemplate( id="easy_001", difficulty="easy", description=( "List all gold-tier customers, ordered alphabetically by name. " "Return id, name, email, country." ), sql=( "SELECT id, name, email, country " "FROM customers " "WHERE tier = 'gold' " "ORDER BY name ASC" ), order_sensitive=True, ), SQLTemplate( id="easy_002", difficulty="easy", description=( "Show all products priced above $100, sorted by price descending. " "Return id, name, price." ), sql=( "SELECT id, name, price " "FROM products " "WHERE price > 100 " "ORDER BY price DESC" ), order_sensitive=True, ), SQLTemplate( id="easy_003", difficulty="easy", description=( "Find all delivered orders with a total_amount greater than $200, " "sorted by total_amount descending. " "Return id, customer_id, total_amount, created_at." ), sql=( "SELECT id, customer_id, total_amount, created_at " "FROM orders " "WHERE status = 'delivered' AND total_amount > 200 " "ORDER BY total_amount DESC" ), order_sensitive=True, ), SQLTemplate( id="easy_004", difficulty="easy", description=( "Return the top 5 most expensive products. Return id, name, price." ), sql=( "SELECT id, name, price " "FROM products " "ORDER BY price DESC " "LIMIT 5" ), order_sensitive=True, ), SQLTemplate( id="easy_005", difficulty="easy", description=( "List all distinct countries where customers come from, sorted alphabetically. " "Return a single column: country." ), sql=( "SELECT DISTINCT country " "FROM customers " "ORDER BY country ASC" ), order_sensitive=True, ), SQLTemplate( id="easy_006", difficulty="easy", description=( "Show all pending orders, ordered by created_at descending. " "Return id, customer_id, total_amount, created_at." ), sql=( "SELECT id, customer_id, total_amount, created_at " "FROM orders " "WHERE status = 'pending' " "ORDER BY created_at DESC" ), order_sensitive=True, ), SQLTemplate( id="easy_007", difficulty="easy", description=( "Find all products with zero stock (stock_quantity = 0). " "Return id, name, price, category_id." ), sql=( "SELECT id, name, price, category_id " "FROM products " "WHERE stock_quantity = 0" ), ), SQLTemplate( id="easy_008", difficulty="easy", description=( "How many customers are there in total? Return a single value: total_customers." ), sql="SELECT COUNT(*) AS total_customers FROM customers", ), SQLTemplate( id="easy_009", difficulty="easy", description=( "What is the most expensive product price in the store? " "Return a single value: max_price." ), sql="SELECT MAX(price) AS max_price FROM products", ), SQLTemplate( id="easy_010", difficulty="easy", description=( "What is the cheapest product price in the store? " "Return a single value: min_price." ), sql="SELECT MIN(price) AS min_price FROM products", ), SQLTemplate( id="easy_011", difficulty="easy", description=( "What is the average price of all products? " "Round to 2 decimal places. Return: avg_price." ), sql="SELECT ROUND(AVG(price), 2) AS avg_price FROM products", ), SQLTemplate( id="easy_012", difficulty="easy", description=( "Show all customers from India, sorted by name ascending. " "Return id, name, email, tier." ), sql=( "SELECT id, name, email, tier " "FROM customers " "WHERE country = 'India' " "ORDER BY name ASC" ), order_sensitive=True, ), SQLTemplate( id="easy_013", difficulty="easy", description=( "List the 10 most recently placed orders. " "Return id, customer_id, status, created_at, total_amount." ), sql=( "SELECT id, customer_id, status, created_at, total_amount " "FROM orders " "ORDER BY created_at DESC " "LIMIT 10" ), order_sensitive=True, ), SQLTemplate( id="easy_014", difficulty="easy", description=( "Find all reviews with a rating of 5 stars. " "Return id, product_id, customer_id, created_at." ), sql=( "SELECT id, product_id, customer_id, created_at " "FROM reviews " "WHERE rating = 5" ), ), SQLTemplate( id="easy_015", difficulty="easy", description=( "Find all reviews with a rating of 1 star (lowest possible). " "Return id, product_id, customer_id, created_at." ), sql=( "SELECT id, product_id, customer_id, created_at " "FROM reviews " "WHERE rating = 1" ), ), SQLTemplate( id="easy_016", difficulty="easy", description=( "Count the number of cancelled orders. Return: cancelled_count." ), sql=( "SELECT COUNT(*) AS cancelled_count " "FROM orders " "WHERE status = 'cancelled'" ), ), SQLTemplate( id="easy_017", difficulty="easy", description=( "List all products with stock_quantity greater than 100, " "sorted by stock_quantity descending. Return id, name, stock_quantity." ), sql=( "SELECT id, name, stock_quantity " "FROM products " "WHERE stock_quantity > 100 " "ORDER BY stock_quantity DESC" ), order_sensitive=True, ), SQLTemplate( id="easy_018", difficulty="easy", description=( "Find all silver-tier customers from the USA. " "Return id, name, email." ), sql=( "SELECT id, name, email " "FROM customers " "WHERE tier = 'silver' AND country = 'USA'" ), ), SQLTemplate( id="easy_019", difficulty="easy", description=( "What is the total revenue from all delivered orders? " "Round to 2 decimal places. Return: total_revenue." ), sql=( "SELECT ROUND(SUM(total_amount), 2) AS total_revenue " "FROM orders " "WHERE status = 'delivered'" ), ), SQLTemplate( id="easy_020", difficulty="easy", description=( "List all orders placed in 2024, sorted by created_at ascending. " "Return id, customer_id, status, total_amount, created_at." ), sql=( "SELECT id, customer_id, status, total_amount, created_at " "FROM orders " "WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' " "ORDER BY created_at ASC" ), order_sensitive=True, ), SQLTemplate( id="easy_021", difficulty="easy", description=( "Show the bottom 5 cheapest products. Return id, name, price." ), sql=( "SELECT id, name, price " "FROM products " "ORDER BY price ASC " "LIMIT 5" ), order_sensitive=True, ), SQLTemplate( id="easy_022", difficulty="easy", description=( "Count how many products exist in the catalogue. Return: product_count." ), sql="SELECT COUNT(*) AS product_count FROM products", ), SQLTemplate( id="easy_023", difficulty="easy", description=( "List all distinct order statuses that exist in the orders table. " "Return a single column: status." ), sql="SELECT DISTINCT status FROM orders ORDER BY status ASC", order_sensitive=True, ), SQLTemplate( id="easy_024", difficulty="easy", description=( "Find customers who joined (created_at) in 2023. " "Return id, name, country, tier, created_at, sorted by created_at ascending." ), sql=( "SELECT id, name, country, tier, created_at " "FROM customers " "WHERE created_at >= '2023-01-01' AND created_at < '2024-01-01' " "ORDER BY created_at ASC" ), order_sensitive=True, ), SQLTemplate( id="easy_025", difficulty="easy", description=( "Show all orders with total_amount between $50 and $150 inclusive. " "Return id, customer_id, total_amount, status." ), sql=( "SELECT id, customer_id, total_amount, status " "FROM orders " "WHERE total_amount BETWEEN 50 AND 150" ), ), SQLTemplate( id="easy_026", difficulty="easy", description=( "How many distinct customers have placed at least one order? " "Return a single value: customers_with_orders." ), sql=( "SELECT COUNT(DISTINCT customer_id) AS customers_with_orders " "FROM orders" ), ), SQLTemplate( id="easy_027", difficulty="easy", description=( "What is the total number of order line items across all orders? " "Return: total_line_items." ), sql="SELECT COUNT(*) AS total_line_items FROM order_items", ), SQLTemplate( id="easy_028", difficulty="easy", description=( "List all products priced between $20 and $80 inclusive, sorted by price ascending. " "Return id, name, price." ), sql=( "SELECT id, name, price " "FROM products " "WHERE price BETWEEN 20 AND 80 " "ORDER BY price ASC" ), order_sensitive=True, ), SQLTemplate( id="easy_029", difficulty="easy", description=( "Show all gold-tier customers from Germany. " "Return id, name, email, created_at." ), sql=( "SELECT id, name, email, created_at " "FROM customers " "WHERE tier = 'gold' AND country = 'Germany'" ), ), SQLTemplate( id="easy_030", difficulty="easy", description=( "What is the average rating across all reviews in the system? " "Round to 2 decimal places. Return: avg_rating." ), sql="SELECT ROUND(AVG(rating), 2) AS avg_rating FROM reviews", ), ] MEDIUM_TEMPLATES: List[SQLTemplate] = [ # ── JOIN + COUNT ───────────────────────────────────────────────────────── SQLTemplate( id="med_001", difficulty="medium", description=( "How many orders has each customer placed? Include customers with zero orders. " "Return customer_name and order_count. Sort by order_count descending, " "then customer_name ascending." ), sql=( "SELECT c.name AS customer_name, COUNT(o.id) AS order_count " "FROM customers c " "LEFT JOIN orders o ON c.id = o.customer_id " "GROUP BY c.id, c.name " "ORDER BY order_count DESC, customer_name ASC" ), order_sensitive=True, ), SQLTemplate( id="med_002", difficulty="medium", description=( "Average product rating per category, only for categories that have at least one review. " "Return category_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending." ), sql=( "SELECT c.name AS category_name, ROUND(AVG(r.rating), 2) AS avg_rating " "FROM categories c " "JOIN products p ON p.category_id = c.id " "JOIN reviews r ON r.product_id = p.id " "GROUP BY c.id, c.name " "ORDER BY avg_rating DESC" ), order_sensitive=True, ), SQLTemplate( id="med_003", difficulty="medium", description=( "Which categories have more than 5 in-stock products (stock_quantity > 0)? " "Return category_name and in_stock_count. Sort by in_stock_count descending." ), sql=( "SELECT c.name AS category_name, COUNT(p.id) AS in_stock_count " "FROM categories c " "JOIN products p ON p.category_id = c.id " "WHERE p.stock_quantity > 0 " "GROUP BY c.id, c.name " "HAVING COUNT(p.id) > 5 " "ORDER BY in_stock_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_004", difficulty="medium", description=( "Which customers have spent more than $500 on delivered orders? " "Return customer_name and total_spent (rounded to 2 dp). Sort by total_spent descending." ), sql=( "SELECT c.name AS customer_name, ROUND(SUM(o.total_amount), 2) AS total_spent " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "WHERE o.status = 'delivered' " "GROUP BY c.id, c.name " "HAVING SUM(o.total_amount) > 500 " "ORDER BY total_spent DESC" ), order_sensitive=True, ), SQLTemplate( id="med_005", difficulty="medium", description=( "Total quantity sold for each product that appears in at least one order. " "Return product_name and total_quantity_sold. Sort by total_quantity_sold descending." ), sql=( "SELECT p.name AS product_name, SUM(oi.quantity) AS total_quantity_sold " "FROM products p " "JOIN order_items oi ON oi.product_id = p.id " "GROUP BY p.id, p.name " "ORDER BY total_quantity_sold DESC" ), order_sensitive=True, ), SQLTemplate( id="med_006", difficulty="medium", description=( "Number of reviews per product, only for products with at least 3 reviews. " "Return product_name and review_count. Sort by review_count descending." ), sql=( "SELECT p.name AS product_name, COUNT(r.id) AS review_count " "FROM products p " "JOIN reviews r ON r.product_id = p.id " "GROUP BY p.id, p.name " "HAVING COUNT(r.id) >= 3 " "ORDER BY review_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_007", difficulty="medium", description=( "Show the total revenue (sum of total_amount) per country from all orders, " "regardless of status. Return country and total_revenue (rounded to 2 dp). " "Sort by total_revenue descending." ), sql=( "SELECT c.country, ROUND(SUM(o.total_amount), 2) AS total_revenue " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "GROUP BY c.country " "ORDER BY total_revenue DESC" ), order_sensitive=True, ), SQLTemplate( id="med_008", difficulty="medium", description=( "For each customer tier (bronze, silver, gold) show the average order value " "from delivered orders. Return tier and avg_order_value (rounded to 2 dp). " "Sort by avg_order_value descending." ), sql=( "SELECT c.tier, ROUND(AVG(o.total_amount), 2) AS avg_order_value " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "WHERE o.status = 'delivered' " "GROUP BY c.tier " "ORDER BY avg_order_value DESC" ), order_sensitive=True, ), SQLTemplate( id="med_009", difficulty="medium", description=( "Which products have never been ordered? " "Return id and name, sorted by name ascending." ), sql=( "SELECT p.id, p.name " "FROM products p " "LEFT JOIN order_items oi ON oi.product_id = p.id " "WHERE oi.id IS NULL " "ORDER BY p.name ASC" ), order_sensitive=True, ), SQLTemplate( id="med_010", difficulty="medium", description=( "Number of orders per status. " "Return status and order_count. Sort by order_count descending." ), sql=( "SELECT status, COUNT(*) AS order_count " "FROM orders " "GROUP BY status " "ORDER BY order_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_011", difficulty="medium", description=( "Show the total number of products per category. " "Return category_name and product_count. Sort by product_count descending." ), sql=( "SELECT c.name AS category_name, COUNT(p.id) AS product_count " "FROM categories c " "LEFT JOIN products p ON p.category_id = c.id " "GROUP BY c.id, c.name " "ORDER BY product_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_012", difficulty="medium", description=( "Average rating per product for products with at least one review. " "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending." ), sql=( "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating " "FROM products p " "JOIN reviews r ON r.product_id = p.id " "GROUP BY p.id, p.name " "ORDER BY avg_rating DESC" ), order_sensitive=True, ), SQLTemplate( id="med_013", difficulty="medium", description=( "Which gold-tier customers have placed more than 3 orders? " "Return customer_name and order_count. Sort by order_count descending." ), sql=( "SELECT c.name AS customer_name, COUNT(o.id) AS order_count " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "WHERE c.tier = 'gold' " "GROUP BY c.id, c.name " "HAVING COUNT(o.id) > 3 " "ORDER BY order_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_014", difficulty="medium", description=( "Total quantity of each product ordered via order_items. " "Return product_name and total_units. Sort by total_units descending." ), sql=( "SELECT p.name AS product_name, SUM(oi.quantity) AS total_units " "FROM products p " "JOIN order_items oi ON oi.product_id = p.id " "GROUP BY p.id, p.name " "ORDER BY total_units DESC" ), order_sensitive=True, ), SQLTemplate( id="med_015", difficulty="medium", description=( "For each country, count the number of gold-tier customers. " "Only show countries with at least one gold-tier customer. " "Return country and gold_count. Sort by gold_count descending." ), sql=( "SELECT country, COUNT(*) AS gold_count " "FROM customers " "WHERE tier = 'gold' " "GROUP BY country " "HAVING COUNT(*) >= 1 " "ORDER BY gold_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_016", difficulty="medium", description=( "Show how many reviews each customer has submitted. Only include customers " "who have submitted at least one review. Return customer_name and review_count. " "Sort by review_count descending." ), sql=( "SELECT c.name AS customer_name, COUNT(r.id) AS review_count " "FROM customers c " "JOIN reviews r ON r.customer_id = c.id " "GROUP BY c.id, c.name " "ORDER BY review_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_017", difficulty="medium", description=( "Total revenue generated from order_items (quantity * unit_price) per category. " "Return category_name and category_revenue (rounded to 2 dp). " "Sort by category_revenue descending." ), sql=( "SELECT c.name AS category_name, " " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue " "FROM categories c " "JOIN products p ON p.category_id = c.id " "JOIN order_items oi ON oi.product_id = p.id " "GROUP BY c.id, c.name " "ORDER BY category_revenue DESC" ), order_sensitive=True, ), SQLTemplate( id="med_018", difficulty="medium", description=( "Which products have an average rating strictly below 3? " "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating ascending." ), sql=( "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating " "FROM products p " "JOIN reviews r ON r.product_id = p.id " "GROUP BY p.id, p.name " "HAVING AVG(r.rating) < 3 " "ORDER BY avg_rating ASC" ), order_sensitive=True, ), SQLTemplate( id="med_019", difficulty="medium", description=( "Find the maximum order value for each customer tier. " "Return tier and max_order_value (rounded to 2 dp). Sort by max_order_value descending." ), sql=( "SELECT c.tier, ROUND(MAX(o.total_amount), 2) AS max_order_value " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "GROUP BY c.tier " "ORDER BY max_order_value DESC" ), order_sensitive=True, ), SQLTemplate( id="med_020", difficulty="medium", description=( "How many customers per country have placed at least one delivered order? " "Return country and customer_count. Sort by customer_count descending." ), sql=( "SELECT c.country, COUNT(DISTINCT c.id) AS customer_count " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "WHERE o.status = 'delivered' " "GROUP BY c.country " "ORDER BY customer_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_021", difficulty="medium", description=( "List all products together with their category name. " "Return product_name, category_name, price. Sort by category_name, then price ascending." ), sql=( "SELECT p.name AS product_name, c.name AS category_name, p.price " "FROM products p " "JOIN categories c ON c.id = p.category_id " "ORDER BY category_name ASC, p.price ASC" ), order_sensitive=True, ), SQLTemplate( id="med_022", difficulty="medium", description=( "For each order, show the total number of line items it contains. " "Return order_id and line_item_count. Sort by line_item_count descending." ), sql=( "SELECT order_id, COUNT(*) AS line_item_count " "FROM order_items " "GROUP BY order_id " "ORDER BY line_item_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_023", difficulty="medium", description=( "Show the minimum and maximum product price per category. " "Return category_name, min_price, max_price. Sort by category_name ascending." ), sql=( "SELECT c.name AS category_name, " " ROUND(MIN(p.price), 2) AS min_price, " " ROUND(MAX(p.price), 2) AS max_price " "FROM categories c " "JOIN products p ON p.category_id = c.id " "GROUP BY c.id, c.name " "ORDER BY category_name ASC" ), order_sensitive=True, ), SQLTemplate( id="med_024", difficulty="medium", description=( "Find customers who have given a rating of 5 to at least one product. " "Return customer_name and five_star_count. Sort by five_star_count descending." ), sql=( "SELECT c.name AS customer_name, COUNT(r.id) AS five_star_count " "FROM customers c " "JOIN reviews r ON r.customer_id = c.id " "WHERE r.rating = 5 " "GROUP BY c.id, c.name " "ORDER BY five_star_count DESC" ), order_sensitive=True, ), SQLTemplate( id="med_025", difficulty="medium", description=( "Show the average number of items per order across all orders. " "Round to 2 decimal places. Return: avg_items_per_order." ), sql=( "SELECT ROUND(AVG(item_count), 2) AS avg_items_per_order " "FROM ( " " SELECT order_id, COUNT(*) AS item_count " " FROM order_items " " GROUP BY order_id " ")" ), ), ] HARD_TEMPLATES: List[SQLTemplate] = [ # ── Window functions ───────────────────────────────────────────────────── SQLTemplate( id="hard_001", difficulty="hard", description=( "Rank customers by total spending on delivered orders using DENSE_RANK " "(rank 1 = highest spender). " "Return customer_name, total_spent (rounded to 2 dp), spending_rank. " "Sort by spending_rank ascending." ), sql=( "SELECT customer_name, total_spent, spending_rank " "FROM ( " " SELECT c.name AS customer_name, " " ROUND(SUM(o.total_amount), 2) AS total_spent, " " DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank " " FROM customers c " " JOIN orders o ON o.customer_id = c.id " " WHERE o.status = 'delivered' " " GROUP BY c.id, c.name " ") sub " "ORDER BY spending_rank ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_002", difficulty="hard", description=( "For each reviewed product, show its own average rating and the average rating " "of all products in its category (partition window). " "Return product_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). " "Sort by product_avg_rating descending." ), sql=( "SELECT p.name AS product_name, " " ROUND(AVG(r.rating), 2) AS product_avg_rating, " " ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) AS category_avg_rating " "FROM products p " "JOIN reviews r ON r.product_id = p.id " "GROUP BY p.id, p.name, p.category_id " "ORDER BY product_avg_rating DESC" ), order_sensitive=True, ), SQLTemplate( id="hard_003", difficulty="hard", description=( "Find all customers whose most recent order has status 'cancelled'. " "Use a CTE with ROW_NUMBER partitioned by customer_id ordered by created_at DESC. " "Return customer_name, last_order_status, last_order_date. Sort by customer_name ascending." ), sql=( "WITH ranked_orders AS ( " " SELECT customer_id, status, created_at, " " ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at DESC) AS rn " " FROM orders " ") " "SELECT c.name AS customer_name, " " ro.status AS last_order_status, " " ro.created_at AS last_order_date " "FROM customers c " "JOIN ranked_orders ro ON ro.customer_id = c.id " "WHERE ro.rn = 1 AND ro.status = 'cancelled' " "ORDER BY customer_name ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_004", difficulty="hard", description=( "Monthly revenue from delivered orders and its running total for all months in 2024. " "Return month (YYYY-MM format), monthly_revenue, running_total (both rounded to 2 dp). " "Sort by month ascending." ), sql=( "WITH monthly AS ( " " SELECT strftime('%Y-%m', created_at) AS month, " " ROUND(SUM(total_amount), 2) AS monthly_revenue " " FROM orders " " WHERE status = 'delivered' " " AND created_at >= '2024-01-01' AND created_at < '2025-01-01' " " GROUP BY strftime('%Y-%m', created_at) " ") " "SELECT month, monthly_revenue, " " ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total " "FROM monthly " "ORDER BY month ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_005", difficulty="hard", description=( "Find products whose average rating is strictly above the average rating of all products " "in their category. Use two CTEs: one for product-level averages and one for category-level. " "Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). " "Sort by product_avg_rating descending, then product_name ascending." ), sql=( "WITH product_ratings AS ( " " SELECT p.id AS product_id, p.name AS product_name, " " p.category_id, c.name AS category_name, " " ROUND(AVG(r.rating), 2) AS product_avg_rating " " FROM products p " " JOIN reviews r ON r.product_id = p.id " " JOIN categories c ON c.id = p.category_id " " GROUP BY p.id, p.name, p.category_id, c.name " "), " "category_ratings AS ( " " SELECT category_id, ROUND(AVG(product_avg_rating), 2) AS category_avg_rating " " FROM product_ratings " " GROUP BY category_id " ") " "SELECT pr.product_name, pr.category_name, " " pr.product_avg_rating, cr.category_avg_rating " "FROM product_ratings pr " "JOIN category_ratings cr ON cr.category_id = pr.category_id " "WHERE pr.product_avg_rating > cr.category_avg_rating " "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_006", difficulty="hard", description=( "For each customer, find their very first order date using ROW_NUMBER in a CTE. " "Return customer_name and first_order_date. Sort by first_order_date ascending." ), sql=( "WITH first_orders AS ( " " SELECT customer_id, created_at, " " ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at ASC) AS rn " " FROM orders " ") " "SELECT c.name AS customer_name, fo.created_at AS first_order_date " "FROM customers c " "JOIN first_orders fo ON fo.customer_id = c.id " "WHERE fo.rn = 1 " "ORDER BY first_order_date ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_007", difficulty="hard", description=( "Rank products by total revenue generated (quantity * unit_price from order_items) " "using RANK() window function. " "Return product_name, total_revenue (rounded to 2 dp), revenue_rank. " "Sort by revenue_rank ascending." ), sql=( "SELECT product_name, total_revenue, revenue_rank " "FROM ( " " SELECT p.name AS product_name, " " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue, " " RANK() OVER (ORDER BY SUM(oi.quantity * oi.unit_price) DESC) AS revenue_rank " " FROM products p " " JOIN order_items oi ON oi.product_id = p.id " " GROUP BY p.id, p.name " ") sub " "ORDER BY revenue_rank ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_008", difficulty="hard", description=( "For each customer, compute the running total of their order amounts ordered by " "created_at. Return customer_name, order_date (created_at), order_amount (total_amount), " "running_total (rounded to 2 dp). Sort by customer_name, order_date ascending." ), sql=( "SELECT c.name AS customer_name, " " o.created_at AS order_date, " " o.total_amount AS order_amount, " " ROUND(SUM(o.total_amount) OVER " " (PARTITION BY c.id ORDER BY o.created_at " " ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2) AS running_total " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "ORDER BY customer_name ASC, order_date ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_009", difficulty="hard", description=( "Find customers who have placed orders in every status " "(pending, processing, shipped, delivered, cancelled) at least once. " "Return customer_name and status_count. Sort by customer_name ascending." ), sql=( "SELECT c.name AS customer_name, COUNT(DISTINCT o.status) AS status_count " "FROM customers c " "JOIN orders o ON o.customer_id = c.id " "GROUP BY c.id, c.name " "HAVING COUNT(DISTINCT o.status) = 5 " "ORDER BY customer_name ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_010", difficulty="hard", description=( "Using a CTE, compute the total revenue per product, then rank the top 3 products " "in each category by revenue using DENSE_RANK. Only return rows with rank <= 3. " "Return category_name, product_name, total_revenue (rounded to 2 dp), rank_in_category. " "Sort by category_name, rank_in_category ascending." ), sql=( "WITH product_rev AS ( " " SELECT p.id, p.name AS product_name, p.category_id, " " c.name AS category_name, " " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue " " FROM products p " " JOIN categories c ON c.id = p.category_id " " JOIN order_items oi ON oi.product_id = p.id " " GROUP BY p.id, p.name, p.category_id, c.name " "), " "ranked AS ( " " SELECT product_name, category_name, total_revenue, " " DENSE_RANK() OVER (PARTITION BY category_id ORDER BY total_revenue DESC) AS rank_in_category " " FROM product_rev " ") " "SELECT category_name, product_name, total_revenue, rank_in_category " "FROM ranked " "WHERE rank_in_category <= 3 " "ORDER BY category_name ASC, rank_in_category ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_011", difficulty="hard", description=( "Compute the percentage of total revenue each category contributes. " "Use a CTE for category revenues and a window SUM for the grand total. " "Return category_name, category_revenue, pct_of_total (rounded to 2 dp). " "Sort by pct_of_total descending." ), sql=( "WITH cat_rev AS ( " " SELECT c.name AS category_name, " " ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue " " FROM categories c " " JOIN products p ON p.category_id = c.id " " JOIN order_items oi ON oi.product_id = p.id " " GROUP BY c.id, c.name " ") " "SELECT category_name, category_revenue, " " ROUND(100.0 * category_revenue / SUM(category_revenue) OVER (), 2) AS pct_of_total " "FROM cat_rev " "ORDER BY pct_of_total DESC" ), order_sensitive=True, ), SQLTemplate( id="hard_012", difficulty="hard", description=( "Find the customers who placed the highest number of orders in 2023. " "Use a CTE to count per-customer orders in 2023, then apply DENSE_RANK. " "Return customer_name, order_count_2023, rank. Sort by rank, then customer_name." ), sql=( "WITH counts_2023 AS ( " " SELECT c.name AS customer_name, COUNT(o.id) AS order_count_2023 " " FROM customers c " " JOIN orders o ON o.customer_id = c.id " " WHERE o.created_at >= '2023-01-01' AND o.created_at < '2024-01-01' " " GROUP BY c.id, c.name " ") " "SELECT customer_name, order_count_2023, " " DENSE_RANK() OVER (ORDER BY order_count_2023 DESC) AS rank " "FROM counts_2023 " "ORDER BY rank ASC, customer_name ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_013", difficulty="hard", description=( "Show a quarterly revenue breakdown for delivered orders across all years. " "Use strftime to derive year and quarter. " "Return year, quarter, quarterly_revenue (rounded to 2 dp), " "and running_total_in_year (running SUM within the same year, rounded to 2 dp). " "Sort by year, quarter ascending." ), sql=( "WITH quarterly AS ( " " SELECT strftime('%Y', created_at) AS year, " " ((CAST(strftime('%m', created_at) AS INTEGER) - 1) / 3 + 1) AS quarter, " " ROUND(SUM(total_amount), 2) AS quarterly_revenue " " FROM orders " " WHERE status = 'delivered' " " GROUP BY year, quarter " ") " "SELECT year, quarter, quarterly_revenue, " " ROUND(SUM(quarterly_revenue) OVER (PARTITION BY year ORDER BY quarter), 2) AS running_total_in_year " "FROM quarterly " "ORDER BY year ASC, quarter ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_014", difficulty="hard", description=( "Find the top-spending customer in each country using ROW_NUMBER. " "Return country, customer_name, total_spent (rounded to 2 dp). " "Sort by country, total_spent descending." ), sql=( "WITH customer_spend AS ( " " SELECT c.id, c.name AS customer_name, c.country, " " ROUND(SUM(o.total_amount), 2) AS total_spent " " FROM customers c " " JOIN orders o ON o.customer_id = c.id " " GROUP BY c.id, c.name, c.country " "), " "ranked AS ( " " SELECT country, customer_name, total_spent, " " ROW_NUMBER() OVER (PARTITION BY country ORDER BY total_spent DESC) AS rn " " FROM customer_spend " ") " "SELECT country, customer_name, total_spent " "FROM ranked " "WHERE rn = 1 " "ORDER BY country ASC" ), order_sensitive=True, ), SQLTemplate( id="hard_015", difficulty="hard", description=( "Find products that have received both 1-star and 5-star reviews. " "Use two CTEs: one for 1-star products, one for 5-star products, then intersect. " "Return product_name. Sort by product_name ascending." ), sql=( "WITH one_star AS ( " " SELECT DISTINCT product_id FROM reviews WHERE rating = 1 " "), " "five_star AS ( " " SELECT DISTINCT product_id FROM reviews WHERE rating = 5 " ") " "SELECT p.name AS product_name " "FROM products p " "JOIN one_star os ON os.product_id = p.id " "JOIN five_star fs ON fs.product_id = p.id " "ORDER BY product_name ASC" ), order_sensitive=True, ), ] ALL_TEMPLATES: List[SQLTemplate] = EASY_TEMPLATES + MEDIUM_TEMPLATES + HARD_TEMPLATES # ───────────────────────────────────────────────────────────────────────────── # Personas # ───────────────────────────────────────────────────────────────────────────── SCHEMA_CONTEXT = """ DATABASE SCHEMA (SQLite e-commerce): categories(id, name) products(id, name, category_id, price, stock_quantity) customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at) orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled}, created_at, total_amount) order_items(id, order_id, product_id, quantity, unit_price) reviews(id, product_id, customer_id, rating∈1-5, created_at) """ PERSONA_SPECS = { "ceo": ( "You are a senior business executive. Write one SHORT, direct question in active voice, " "as if you are asking an analyst to pull a number fast. Be terse, no fluff. " "Use business language: 'revenue', 'customers', 'performance', not technical SQL terms." ), "chatty": ( "You are a friendly but verbose non-technical employee. Write one long, conversational " "question with filler phrases like 'Could you please tell me...', 'I was wondering if...', " "passive voice is fine. Use everyday words like 'money' instead of 'revenue', " "'people' instead of 'customers'." ), "lazy": ( "You are typing quickly on a phone. Write an extremely short question with abbreviations, " "lowercase letters, and minor spelling mistakes. Skip articles and punctuation where possible. " "Example style: 'top 5 prods by sales?', 'hw many cust in usa'." ), "confused": ( "You are a non-technical user who is unsure of the exact terminology. Write one question " "using synonyms and vague language. Replace 'revenue' with 'money made', 'customers' with " "'people' or 'users' or 'accounts', 'orders' with 'purchases' or 'transactions', " "'tier' with 'membership level'. Include a bit of ambiguity." ), "analyst": ( "You are a data analyst with technical knowledge. Write one precise, jargon-heavy question " "using terms like 'aggregate', 'partition', 'metric', 'fiscal period', 'segmented by', " "'cohort', 'granularity'. Be specific about column names and filters." ), } # ───────────────────────────────────────────────────────────────────────────── # Rule-based Augmentor # ───────────────────────────────────────────────────────────────────────────── class RuleAugmentor: """ Applies deterministic, non-LLM transformations to a generated NL question. Returns a list of augmented variants (may be empty if no rule applied). """ SYNONYMS: Dict[str, List[str]] = { "customers": ["clients", "users", "accounts", "shoppers", "buyers"], "orders": ["purchases", "transactions", "sales", "bookings"], "products": ["items", "goods", "listings", "SKUs"], "revenue": ["sales", "income", "earnings", "money made"], "spending": ["expenditure", "purchases", "money spent"], "delivered": ["completed", "fulfilled", "received"], "cancelled": ["canceled", "voided", "aborted"], "pending": ["waiting", "unprocessed", "queued"], "gold": ["premium", "top-tier", "VIP", "platinum"], "silver": ["mid-tier", "standard-plus"], "bronze": ["basic", "standard", "entry-level"], "rating": ["score", "star rating", "review score"], "country": ["region", "location", "geography", "nation"], "category": ["department", "section", "type", "group"], "price": ["cost", "value", "amount", "fee"], "total": ["sum", "aggregate", "combined", "overall"], "average": ["mean", "typical", "avg"], "show": ["list", "display", "give me", "get", "fetch"], "find": ["identify", "locate", "get", "pull", "retrieve"], "return": ["give me", "show", "list", "provide"], } def augment(self, question: str, rng: random.Random) -> Optional[str]: words = question.split() changed = False result = [] for w in words: clean = w.lower().strip(".,?!;:") if clean in self.SYNONYMS and rng.random() < 0.4: replacement = rng.choice(self.SYNONYMS[clean]) # Preserve trailing punctuation punct = w[len(clean):] if w.lower().startswith(clean) else "" result.append(replacement + punct) changed = True else: result.append(w) if not changed: return None new_q = " ".join(result) # Capitalise first letter return new_q[0].upper() + new_q[1:] if new_q else new_q # ───────────────────────────────────────────────────────────────────────────── # vLLM Generator # ───────────────────────────────────────────────────────────────────────────── class VLLMGenerator: """ Async batched inference using the OpenAI-compatible vLLM endpoint. vLLM exposes exactly the same API as OpenAI, so we reuse AsyncOpenAI. """ def __init__(self, base_url: str, model: str, temperature: float = 0.8, max_tokens: int = 256, semaphore: int = 64): self.client = AsyncOpenAI(base_url=base_url, api_key="NONE") self.model = model self.temperature = temperature self.max_tokens = max_tokens self._sem = asyncio.Semaphore(semaphore) async def generate_one( self, system: str, user: str, retries: int = 3, ) -> Optional[str]: for attempt in range(retries): try: async with self._sem: resp = await self.client.chat.completions.create( model=self.model, messages=[ {"role": "system", "content": system}, {"role": "user", "content": user}, ], temperature=self.temperature, max_tokens=self.max_tokens, ) text = resp.choices[0].message.content.strip() return text if text else None except Exception as exc: wait = 2 ** attempt log.warning(f"vLLM call failed (attempt {attempt+1}): {exc}. Retrying in {wait}s.") await asyncio.sleep(wait) return None async def generate_batch( self, requests: List[Tuple[str, str, str]], # (request_id, system, user) ) -> Dict[str, Optional[str]]: """ Fire all requests concurrently (bounded by semaphore) and return a dict. """ async def _one(rid, sys, usr): return rid, await self.generate_one(sys, usr) tasks = [_one(rid, sys, usr) for rid, sys, usr in requests] results = await asyncio.gather(*tasks) return {rid: text for rid, text in results} # ───────────────────────────────────────────────────────────────────────────── # Data Factory # ───────────────────────────────────────────────────────────────────────────── @dataclass class DataPoint: id: str difficulty: str persona: str question: str sql: str db_result_ok: bool augmented: bool def to_training_prompt(self, system_prompt: str) -> Dict[str, Any]: """ Return the dict structure expected by train.py / SFT pipelines. Includes both the raw fields and a formatted 'messages' list. """ user_content = ( f"SCHEMA:\n{SCHEMA_CONTEXT}\n\nQUESTION: {self.question}" ) return { **asdict(self), "messages": [ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_content}, {"role": "assistant", "content": self.sql}, ], } SYSTEM_PROMPT = ( "You are an expert SQL analyst working with a SQLite e-commerce database. " "Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown." ) class DataFactory: def __init__( self, generator: VLLMGenerator, validator: SQLiteValidator, augmentor: RuleAugmentor, personas_per_template: int = 5, aug_rounds: int = 2, seed: int = 42, ): self.generator = generator self.validator = validator self.augmentor = augmentor self.personas_per_template = personas_per_template self.aug_rounds = aug_rounds self.rng = random.Random(seed) # ── Step 1: Validate all template SQLs ─────────────────────────────────── def validate_templates(self) -> List[SQLTemplate]: log.info("Validating all SQL templates against seeded DB...") valid = [] failed = [] for t in ALL_TEMPLATES: ok, err = self.validator.validate(t.sql) if ok: valid.append(t) else: failed.append((t.id, err)) if failed: log.error(f"FAILED templates (will be skipped): {failed}") log.info(f"Templates validated: {len(valid)} ok, {len(failed)} failed.") return valid # ── Step 2: Build generation requests ──────────────────────────────────── def _build_requests( self, templates: List[SQLTemplate], persona_names: List[str], ) -> List[Tuple[str, str, str]]: """ Returns a flat list of (request_id, system_prompt, user_prompt) tuples. """ requests = [] for t in templates: chosen_personas = ( persona_names if self.personas_per_template >= len(PERSONA_SPECS) else self.rng.sample(persona_names, self.personas_per_template) ) for persona in chosen_personas: rid = f"{t.id}__{persona}" system = ( f"{PERSONA_SPECS[persona]}\n\n" "Output ONLY the natural language question. " "No explanation, no SQL, no preamble, no quotes around the question." ) user = ( f"{SCHEMA_CONTEXT}\n" f"The SQL query that answers this question is:\n{t.sql}\n\n" f"Write ONE natural-language question that a {persona.upper()} user " f"would ask to get this exact result." ) requests.append((rid, system, user)) return requests # ── Step 3: Post-process a generated question ───────────────────────────── @staticmethod def _clean(text: str) -> str: """Strip quotes, markdown, leading numbers, trailing newlines.""" text = text.strip() # Remove leading numbering like "1. " or "Q: " text = re.sub(r'^[\d]+[\.\)]\s+', '', text) text = re.sub(r'^[Qq]:\s*', '', text) # Strip surrounding quotes if (text.startswith('"') and text.endswith('"')) or \ (text.startswith("'") and text.endswith("'")): text = text[1:-1].strip() # Collapse multiple whitespace text = re.sub(r'\s+', ' ', text) return text # ── Main pipeline ───────────────────────────────────────────────────────── async def run( self, output_path: str, checkpoint_path: str, batch_size: int = 64, ) -> None: # -- Validate templates templates = self.validate_templates() # -- Load checkpoint done_ids: set = set() if os.path.exists(checkpoint_path): with open(checkpoint_path) as f: done_ids = set(json.loads(line)["id"] for line in f if line.strip()) log.info(f"Resuming: {len(done_ids)} examples already generated.") persona_names = list(PERSONA_SPECS.keys())[: self.personas_per_template] all_requests = self._build_requests(templates, persona_names) # Filter already done pending = [r for r in all_requests if r[0] not in done_ids] log.info(f"Total requests to generate: {len(pending)}") # -- Build template lookup tmpl_lookup: Dict[str, SQLTemplate] = {t.id: t for t in templates} stats = {"generated": 0, "invalid_llm": 0, "augmented": 0} out_f = open(output_path, "a") ckpt_f = open(checkpoint_path, "a") try: for i in tqdm(range(0, len(pending), batch_size), desc="Batches"): batch = pending[i: i + batch_size] results = await self.generator.generate_batch(batch) for rid, raw_text in results.items(): tmpl_id, persona = rid.split("__", 1) tmpl = tmpl_lookup[tmpl_id] if not raw_text: stats["invalid_llm"] += 1 continue question = self._clean(raw_text) if len(question) < 8: stats["invalid_llm"] += 1 continue # SQL already validated; no need to re-run for NL variants dp = DataPoint( id=rid, difficulty=tmpl.difficulty, persona=persona, question=question, sql=tmpl.sql, db_result_ok=True, augmented=False, ) record = dp.to_training_prompt(SYSTEM_PROMPT) line = json.dumps(record, ensure_ascii=False) out_f.write(line + "\n") ckpt_f.write(line + "\n") stats["generated"] += 1 # -- Rule augmentation rounds for aug_i in range(self.aug_rounds): aug_q = self.augmentor.augment(question, self.rng) if aug_q and aug_q != question: aug_dp = DataPoint( id=f"{rid}__aug{aug_i}", difficulty=tmpl.difficulty, persona=persona, question=aug_q, sql=tmpl.sql, db_result_ok=True, augmented=True, ) aug_record = aug_dp.to_training_prompt(SYSTEM_PROMPT) aug_line = json.dumps(aug_record, ensure_ascii=False) out_f.write(aug_line + "\n") ckpt_f.write(aug_line + "\n") stats["augmented"] += 1 out_f.flush() ckpt_f.flush() finally: out_f.close() ckpt_f.close() log.info( f"Done. Generated={stats['generated']} " f"Augmented={stats['augmented']} " f"LLM failures={stats['invalid_llm']}" ) log.info(f"Output: {output_path}") # ───────────────────────────────────────────────────────────────────────────── # CLI # ───────────────────────────────────────────────────────────────────────────── def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser( description="NL2SQL Synthetic Data Factory — H100 + vLLM", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) p.add_argument("--vllm-url", default="http://localhost:8001/v1", help="Base URL of the running vLLM server.") p.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct", help="Model name as registered in the vLLM server.") p.add_argument("--output", default="nl2sql_train.jsonl", help="Path to write the final JSONL dataset.") p.add_argument("--checkpoint",default="nl2sql_checkpoint.jsonl", help="Path for the checkpoint file (enables resume on crash).") p.add_argument("--personas-per-template", type=int, default=5, help="Number of persona variants to generate per SQL template (max 5).") p.add_argument("--aug-rounds", type=int, default=2, help="Number of rule-based augmentation rounds per generated question.") p.add_argument("--batch-size", type=int, default=64, help="Concurrent vLLM requests per batch (tune based on GPU memory).") p.add_argument("--temperature", type=float, default=0.85, help="Sampling temperature for vLLM (higher = more diverse).") p.add_argument("--max-tokens", type=int, default=200, help="Max tokens for each generated question.") p.add_argument("--seed", type=int, default=42) p.add_argument("--validate-only", action="store_true", help="Only validate SQL templates, do not generate data.") return p.parse_args() async def main() -> None: args = parse_args() # Build DB + validator conn = build_db() validator = SQLiteValidator(conn) if args.validate_only: valid = [t for t in ALL_TEMPLATES if validator.validate(t.sql)[0]] invalid = [t for t in ALL_TEMPLATES if not validator.validate(t.sql)[0]] print(f"\n✅ Valid: {len(valid)}") print(f"❌ Invalid: {len(invalid)}") for t in invalid: _, err = validator.validate(t.sql) print(f" {t.id}: {err}") return # Build pipeline components generator = VLLMGenerator( base_url=args.vllm_url, model=args.model, temperature=args.temperature, max_tokens=args.max_tokens, semaphore=args.batch_size, ) augmentor = RuleAugmentor() factory = DataFactory( generator=generator, validator=validator, augmentor=augmentor, personas_per_template=min(args.personas_per_template, len(PERSONA_SPECS)), aug_rounds=args.aug_rounds, seed=args.seed, ) await factory.run( output_path=args.output, checkpoint_path=args.checkpoint, batch_size=args.batch_size, ) if __name__ == "__main__": asyncio.run(main())