# nl2sql-bench / data_factory / generator.py
# Uploaded by ritvik360 via huggingface_hub (revision a39d8ef, verified).
"""
data_factory/generator.py
==========================
vLLM-based Natural Language question generator for H100.
This module uses a large LLM (Llama-3-70B or Qwen-72B) served via vLLM
to generate diverse, persona-based natural language paraphrases of the
canonical NL questions in our template library.
KEY DESIGN: The LLM generates ONLY natural language questions.
SQL is NEVER touched by the LLM.
This guarantees zero SQL errors in the final dataset.
Persona descriptions:
ceo - Direct, short, active voice. Business executive style.
chatty - Conversational, verbose, passive voice.
lazy_typist - Short, abbreviations, possible informal grammar.
non_techie - Plain English, avoids SQL/tech jargon, uses synonyms.
analyst - Technical, precise, jargon-heavy.
Usage (on H100 cluster):
python -m data_factory.generator --templates-per-chunk 20 --n-variants 10
"""
from __future__ import annotations

import json
import logging
import re
import time
from typing import Iterator, Optional
logger = logging.getLogger(__name__)
# ─────────────────────────────────────────────────────────────────────────────
# PERSONA SYSTEM PROMPTS
# ─────────────────────────────────────────────────────────────────────────────
# Persona key -> system-prompt description injected into the generation prompt
# (see build_generation_prompt).  Keys must match the persona names listed in
# the module docstring and passed to generate_persona_variants_batch().
PERSONA_SYSTEM_PROMPTS: dict[str, str] = {
    # Terse, imperative, executive voice.
    "ceo": (
        "You are a busy C-level executive who communicates in short, punchy, "
        "direct sentences. You use active voice, skip filler words, and get "
        "straight to the point. You are asking a data analyst for information."
    ),
    # Verbose, friendly, non-technical voice.
    "chatty": (
        "You are a friendly, conversational person who likes to be thorough "
        "and explain things fully. You use passive voice sometimes, add context, "
        "and ask questions in a relaxed, detailed way. You are not technical."
    ),
    # Abbreviated, lowercase, informal typing.
    "lazy_typist": (
        "You type quickly and informally. You use abbreviations (e.g. 'pls', "
        "'lmk', 'asap'), lowercase, minimal punctuation, and sometimes omit "
        "words. You get your meaning across without perfect grammar."
    ),
    # Everyday-English synonyms for technical concepts.
    "non_techie": (
        "You have no database or SQL knowledge. You use everyday English words "
        "instead of technical terms. For example, you say 'customers' not 'rows', "
        "'most expensive' not 'highest price', 'total money' not 'sum'. "
        "You describe what you want to see, not how to get it."
    ),
    # Precise, jargon-heavy BI / data-science phrasing.
    "analyst": (
        "You are a data scientist or BI analyst who is precise and technical. "
        "You use terms like 'aggregate', 'partition', 'granularity', 'distinct', "
        "'filter predicate', 'ranked by metric'. Your questions are precise and unambiguous."
    ),
}
# ─────────────────────────────────────────────────────────────────────────────
# PROMPT BUILDER
# ─────────────────────────────────────────────────────────────────────────────
def build_generation_prompt(
    canonical_nl: str,
    description: str,
    persona: str,
    schema_context: str,
    n_variants: int = 5,
) -> list[dict[str, str]]:
    """
    Build a chat-format prompt asking the LLM to rephrase the canonical NL
    question in the style of the given persona.

    Parameters
    ----------
    canonical_nl : The base NL question from the template.
    description : One-line SQL description (gives the LLM additional context).
    persona : One of the PERSONA_SYSTEM_PROMPTS keys.
    schema_context : The compact schema string for the domain.
    n_variants : How many rephrased questions to generate.

    Returns
    -------
    list[dict]
        Chat messages in [{"role": ..., "content": ...}] format.

    Raises
    ------
    KeyError
        If `persona` is not a known persona key.
    """
    persona_desc = PERSONA_SYSTEM_PROMPTS[persona]
    # NOTE: fixed a mojibake em-dash ("β€”", UTF-8 bytes mis-decoded as cp1252)
    # that was previously sent to the LLM verbatim inside this prompt.
    system = (
        "You are a data labelling specialist. Your task is to rephrase a database "
        "question in a specific communication style (persona). The rephrased questions "
        "must preserve the EXACT same intent and required information as the original — "
        "do not change what data is being asked for, only how it is expressed.\n\n"
        f"PERSONA: {persona_desc}\n\n"
        "OUTPUT FORMAT: Return ONLY a valid JSON array of strings. "
        "No preamble, no markdown, no extra keys. Example: "
        '["question 1", "question 2", "question 3"]'
    )
    user = (
        f"DATABASE CONTEXT:\n{schema_context}\n\n"
        f"WHAT THE QUERY DOES: {description}\n\n"
        f"CANONICAL QUESTION: {canonical_nl}\n\n"
        f"Generate {n_variants} different ways a person with the persona described "
        f"above would ask this same question. The meaning must stay identical."
    )
    return [
        {"role": "system", "content": system},
        {"role": "user", "content": user},
    ]
# ─────────────────────────────────────────────────────────────────────────────
# RESPONSE PARSER
# ─────────────────────────────────────────────────────────────────────────────
def parse_llm_response(raw_text: str) -> list[str]:
    """
    Extract a list of strings from the LLM's JSON response.

    Handles common failure modes: markdown code fences around the JSON,
    a trailing comma before the closing bracket, and extra prose before
    or after the array.

    Parameters
    ----------
    raw_text : Raw completion text from the LLM.

    Returns
    -------
    list[str]
        The stripped, non-empty string entries; empty list if no JSON
        array could be recovered.
    """
    text = raw_text.strip()
    # Strip markdown fences (```json ... ```) if present.
    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(ln for ln in lines if not ln.strip().startswith("```")).strip()
    # Find the JSON array boundaries (ignores any surrounding prose).
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end == -1 or end <= start:
        logger.warning("LLM response missing JSON array brackets: %s", text[:100])
        return []
    json_str = text[start:end + 1]
    # Remove a trailing comma before the closing bracket (common LLM mistake).
    # The regex handles any whitespace (including newlines) between comma and
    # bracket, which the old literal ",]" / ", ]" replacements missed.
    json_str = re.sub(r",\s*\]", "]", json_str)
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as exc:
        logger.warning("JSON parse error: %s | text: %s", exc, json_str[:200])
        return []
    if not isinstance(parsed, list):
        return []
    # Keep only non-empty strings; drop numbers/null/objects the LLM may emit.
    return [s.strip() for s in parsed if isinstance(s, str) and s.strip()]
# ─────────────────────────────────────────────────────────────────────────────
# VLLM INTERFACE
# ─────────────────────────────────────────────────────────────────────────────
class VLLMGenerator:
    """
    Wrapper around vLLM for high-throughput NL generation.

    Supports two modes:
        online  : Calls a running vLLM OpenAI-compatible API server.
        offline : Uses vllm.LLM directly (loads model in-process, H100 recommended).

    For H100 cluster usage, prefer 'offline' mode with tensor_parallel_size=4
    to saturate all 4 H100s for maximum throughput.
    """

    def __init__(
        self,
        model_name: str,
        mode: str = "offline",
        tensor_parallel_size: int = 4,
        gpu_memory_utilization: float = 0.90,
        max_model_len: int = 4096,
        # Online mode only
        api_base: str = "http://localhost:8000/v1",
        api_key: str = "EMPTY",
    ) -> None:
        """
        Parameters
        ----------
        model_name : HF model id or local path loaded/served by vLLM.
        mode : 'offline' (in-process vllm.LLM) or 'online' (OpenAI-style API).
        tensor_parallel_size : GPUs to shard across (offline mode).
        gpu_memory_utilization : Fraction of GPU memory vLLM may claim (offline).
        max_model_len : Max context length for the engine (offline).
        api_base, api_key : Endpoint and credential (online mode only).

        Raises
        ------
        ValueError : If `mode` is neither 'offline' nor 'online'.
        ImportError : If the backing library (vllm / openai) is not installed.
        """
        self.model_name = model_name
        self.mode = mode
        self._llm = None      # vllm.LLM engine (offline mode only)
        self._client = None   # OpenAI-compatible client (online mode only)
        if mode == "offline":
            self._init_offline(tensor_parallel_size, gpu_memory_utilization, max_model_len)
        elif mode == "online":
            self._init_online(api_base, api_key)
        else:
            raise ValueError(f"Unknown mode: {mode!r}. Use 'offline' or 'online'.")

    def _init_offline(
        self,
        tensor_parallel_size: int,
        gpu_memory_utilization: float,
        max_model_len: int,
    ) -> None:
        """Load vLLM engine in-process (best for H100 cluster)."""
        try:
            from vllm import LLM, SamplingParams
            self._LLM = LLM
            # Cache the class so generation code does not re-import vllm per call.
            self._SamplingParams = SamplingParams
        except ImportError as exc:
            # Chain the original failure so the real cause stays visible.
            raise ImportError(
                "vLLM not installed. Run: pip install vllm\n"
                "For H100: pip install vllm --extra-index-url https://download.pytorch.org/whl/cu124"
            ) from exc
        logger.info("Loading model %s with %d GPUs (offline mode)...", self.model_name, tensor_parallel_size)
        t0 = time.time()
        self._llm = self._LLM(
            model=self.model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            dtype="bfloat16",
            trust_remote_code=True,
        )
        logger.info("Model loaded in %.1f seconds.", time.time() - t0)

    def _init_online(self, api_base: str, api_key: str) -> None:
        """Use OpenAI-compatible vLLM server (for distributed setups)."""
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError("pip install openai") from exc
        self._client = OpenAI(base_url=api_base, api_key=api_key)
        logger.info("Connected to vLLM server at %s", api_base)

    def generate_batch(
        self,
        prompts: list[list[dict[str, str]]],
        temperature: float = 0.85,
        max_new_tokens: int = 300,
    ) -> list[str]:
        """
        Generate responses for a batch of chat prompts.

        Parameters
        ----------
        prompts : List of chat message lists (one per item in batch).
        temperature : Sampling temperature. Higher = more diverse.
        max_new_tokens : Max tokens per response.

        Returns
        -------
        list[str]
            Raw text response per prompt (same length as input).
        """
        if self.mode == "offline":
            return self._generate_offline(prompts, temperature, max_new_tokens)
        else:
            return self._generate_online(prompts, temperature, max_new_tokens)

    def _generate_offline(
        self,
        prompts: list[list[dict]],
        temperature: float,
        max_new_tokens: int,
    ) -> list[str]:
        """vLLM offline batched generation — maximises H100 throughput."""
        # Use the SamplingParams class cached by _init_offline instead of
        # re-importing vllm on every call (fixes the per-call re-import).
        sampling = self._SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            stop=["</s>", "<|eot_id|>"],  # Llama-3 stop tokens
        )
        # Convert chat messages to prompt strings using the model's own template.
        tokenizer = self._llm.get_tokenizer()
        formatted_prompts: list[str] = []
        for msgs in prompts:
            if hasattr(tokenizer, "apply_chat_template"):
                text = tokenizer.apply_chat_template(
                    msgs, tokenize=False, add_generation_prompt=True
                )
            else:
                # Fallback: simple role-tagged concatenation.
                text = "\n".join(
                    f"<|{m['role']}|>\n{m['content']}" for m in msgs
                )
            formatted_prompts.append(text)
        outputs = self._llm.generate(formatted_prompts, sampling)
        # One RequestOutput per prompt; take the first candidate of each.
        return [o.outputs[0].text for o in outputs]

    def _generate_online(
        self,
        prompts: list[list[dict]],
        temperature: float,
        max_new_tokens: int,
    ) -> list[str]:
        """Sequential generation via OpenAI-compatible API (fallback / debugging)."""
        results = []
        for msgs in prompts:
            try:
                resp = self._client.chat.completions.create(
                    model=self.model_name,
                    messages=msgs,
                    temperature=temperature,
                    max_tokens=max_new_tokens,
                )
                results.append(resp.choices[0].message.content or "")
            except Exception as exc:
                # Best-effort: log and keep output length aligned with prompts.
                logger.warning("API call failed: %s", exc)
                results.append("")
        return results
# ─────────────────────────────────────────────────────────────────────────────
# HIGH-LEVEL GENERATION LOOP
# ─────────────────────────────────────────────────────────────────────────────
def generate_persona_variants_batch(
    templates_subset: list[dict],
    generator: VLLMGenerator,
    personas: list[str],
    n_variants_per_persona: int = 5,
    batch_size: int = 64,
    temperature: float = 0.85,
    max_new_tokens: int = 300,
) -> Iterator[dict]:
    """
    For each template × persona combination, generate `n_variants_per_persona`
    NL question variants using the LLM.

    Yields dicts:
        {
            "template_idx": int,
            "persona": str,
            "nl_variants": list[str],  # successfully parsed NL questions
        }

    Parameters
    ----------
    templates_subset : List of template dicts (from templates.py).
    generator : VLLMGenerator instance.
    personas : List of persona keys to use.
    n_variants_per_persona : How many NL variants per (template, persona) pair.
    batch_size : How many LLM calls to batch together.
    temperature : Sampling temperature.
    max_new_tokens : Max tokens for LLM response (should be ~300 for JSON array).
    """
    from data_factory.schemas import SCHEMA_CONTEXT

    # Build one prompt job per (template_idx, persona) pair up front so the
    # whole workload can be batched for the vLLM engine.
    all_jobs: list[tuple[int, str, list[dict]]] = []
    for t_idx, template in enumerate(templates_subset):
        schema_ctx = SCHEMA_CONTEXT[template["domain"]]
        for persona in personas:
            prompt = build_generation_prompt(
                canonical_nl=template["base_nl"],
                description=template["description"],
                persona=persona,
                schema_context=schema_ctx,
                n_variants=n_variants_per_persona,
            )
            all_jobs.append((t_idx, persona, prompt))
    total_jobs = len(all_jobs)
    # NOTE: fixed mojibake "Γ—" (mis-decoded U+00D7) in this log message.
    logger.info("Starting LLM generation: %d jobs (templates × personas).", total_jobs)

    # Process in batches to keep GPU utilisation high.
    for batch_start in range(0, total_jobs, batch_size):
        batch = all_jobs[batch_start: batch_start + batch_size]
        prompts = [job[2] for job in batch]
        t0 = time.time()
        raw_responses = generator.generate_batch(
            prompts, temperature=temperature, max_new_tokens=max_new_tokens
        )
        elapsed = time.time() - t0
        logger.info(
            "Batch %d-%d completed in %.1fs (%.1f jobs/s).",
            batch_start, batch_start + len(batch), elapsed,
            len(batch) / max(elapsed, 0.001),  # guard against div-by-zero
        )
        for (t_idx, persona, _), raw in zip(batch, raw_responses):
            nl_variants = parse_llm_response(raw)
            if not nl_variants:
                logger.debug(
                    "Empty parse for template_idx=%d persona=%s. raw=%s",
                    t_idx, persona, raw[:100]
                )
                # Fall back to the canonical NL rather than losing this entry.
                nl_variants = [templates_subset[t_idx]["base_nl"]]
            yield {
                "template_idx": t_idx,
                "persona": persona,
                "nl_variants": nl_variants,
            }