""" data_factory/generator.py ========================== vLLM-based Natural Language question generator for H100. This module uses a large LLM (Llama-3-70B or Qwen-72B) served via vLLM to generate diverse, persona-based natural language paraphrases of the canonical NL questions in our template library. KEY DESIGN: The LLM generates ONLY natural language questions. SQL is NEVER touched by the LLM. This guarantees zero SQL errors in the final dataset. Persona descriptions: ceo - Direct, short, active voice. Business executive style. chatty - Conversational, verbose, passive voice. lazy_typist - Short, abbreviations, possible informal grammar. non_techie - Plain English, avoids SQL/tech jargon, uses synonyms. analyst - Technical, precise, jargon-heavy. Usage (on H100 cluster): python -m data_factory.generator --templates-per-chunk 20 --n-variants 10 """ from __future__ import annotations import json import logging import time from typing import Iterator, Optional logger = logging.getLogger(__name__) # ───────────────────────────────────────────────────────────────────────────── # PERSONA SYSTEM PROMPTS # ───────────────────────────────────────────────────────────────────────────── PERSONA_SYSTEM_PROMPTS: dict[str, str] = { "ceo": ( "You are a busy C-level executive who communicates in short, punchy, " "direct sentences. You use active voice, skip filler words, and get " "straight to the point. You are asking a data analyst for information." ), "chatty": ( "You are a friendly, conversational person who likes to be thorough " "and explain things fully. You use passive voice sometimes, add context, " "and ask questions in a relaxed, detailed way. You are not technical." ), "lazy_typist": ( "You type quickly and informally. You use abbreviations (e.g. 'pls', " "'lmk', 'asap'), lowercase, minimal punctuation, and sometimes omit " "words. You get your meaning across without perfect grammar." ), "non_techie": ( "You have no database or SQL knowledge. You use everyday English words " "instead of technical terms. For example, you say 'customers' not 'rows', " "'most expensive' not 'highest price', 'total money' not 'sum'. " "You describe what you want to see, not how to get it." ), "analyst": ( "You are a data scientist or BI analyst who is precise and technical. " "You use terms like 'aggregate', 'partition', 'granularity', 'distinct', " "'filter predicate', 'ranked by metric'. Your questions are precise and unambiguous." ), } # ───────────────────────────────────────────────────────────────────────────── # PROMPT BUILDER # ───────────────────────────────────────────────────────────────────────────── def build_generation_prompt( canonical_nl: str, description: str, persona: str, schema_context: str, n_variants: int = 5, ) -> list[dict[str, str]]: """ Build a chat-format prompt asking the LLM to rephrase the canonical NL question in the style of the given persona. Parameters ---------- canonical_nl : The base NL question from the template. description : One-line SQL description (gives the LLM additional context). persona : One of the 5 persona keys. schema_context : The compact schema string for the domain. n_variants : How many rephrased questions to generate. Returns ------- list[dict] Chat messages in [{"role": ..., "content": ...}] format. """ persona_desc = PERSONA_SYSTEM_PROMPTS[persona] system = ( "You are a data labelling specialist. Your task is to rephrase a database " "question in a specific communication style (persona). The rephrased questions " "must preserve the EXACT same intent and required information as the original — " "do not change what data is being asked for, only how it is expressed.\n\n" f"PERSONA: {persona_desc}\n\n" "OUTPUT FORMAT: Return ONLY a valid JSON array of strings. " "No preamble, no markdown, no extra keys. Example: " '["question 1", "question 2", "question 3"]' ) user = ( f"DATABASE CONTEXT:\n{schema_context}\n\n" f"WHAT THE QUERY DOES: {description}\n\n" f"CANONICAL QUESTION: {canonical_nl}\n\n" f"Generate {n_variants} different ways a person with the persona described " f"above would ask this same question. The meaning must stay identical." ) return [ {"role": "system", "content": system}, {"role": "user", "content": user}, ] # ───────────────────────────────────────────────────────────────────────────── # RESPONSE PARSER # ───────────────────────────────────────────────────────────────────────────── def parse_llm_response(raw_text: str) -> list[str]: """ Extract a list of strings from the LLM's JSON response. Handles common failures: markdown fences, trailing commas, extra text. Returns an empty list if parsing fails completely. """ text = raw_text.strip() # Strip markdown fences if present if text.startswith("```"): lines = text.split("\n") text = "\n".join(l for l in lines if not l.strip().startswith("```")).strip() # Find the JSON array boundaries start = text.find("[") end = text.rfind("]") if start == -1 or end == -1 or end <= start: logger.warning("LLM response missing JSON array brackets: %s", text[:100]) return [] json_str = text[start:end + 1] # Fix trailing commas before ] (common LLM mistake) json_str = json_str.rstrip() json_str = json_str.replace(",]", "]").replace(", ]", "]") try: parsed = json.loads(json_str) if not isinstance(parsed, list): return [] # Filter to only non-empty strings return [s.strip() for s in parsed if isinstance(s, str) and s.strip()] except json.JSONDecodeError as exc: logger.warning("JSON parse error: %s | text: %s", exc, json_str[:200]) return [] # ───────────────────────────────────────────────────────────────────────────── # VLLM INTERFACE # ───────────────────────────────────────────────────────────────────────────── class VLLMGenerator: """ Wrapper around a running vLLM server for high-throughput NL generation. Supports two modes: online : Calls a running vLLM OpenAI-compatible API server. offline : Uses vllm.LLM directly (loads model in-process, H100 recommended). For H100 cluster usage, prefer 'offline' mode with tensor_parallel_size=4 to saturate all 4 H100s for maximum throughput. """ def __init__( self, model_name: str, mode: str = "offline", tensor_parallel_size: int = 4, gpu_memory_utilization: float = 0.90, max_model_len: int = 4096, # Online mode only api_base: str = "http://localhost:8000/v1", api_key: str = "EMPTY", ) -> None: self.model_name = model_name self.mode = mode self._llm = None self._client = None if mode == "offline": self._init_offline(tensor_parallel_size, gpu_memory_utilization, max_model_len) elif mode == "online": self._init_online(api_base, api_key) else: raise ValueError(f"Unknown mode: {mode!r}. Use 'offline' or 'online'.") def _init_offline( self, tensor_parallel_size: int, gpu_memory_utilization: float, max_model_len: int, ) -> None: """Load vLLM engine in-process (best for H100 cluster).""" try: from vllm import LLM, SamplingParams self._LLM = LLM self._SamplingParams = SamplingParams except ImportError: raise ImportError( "vLLM not installed. Run: pip install vllm\n" "For H100: pip install vllm --extra-index-url https://download.pytorch.org/whl/cu124" ) logger.info("Loading model %s with %d GPUs (offline mode)...", self.model_name, tensor_parallel_size) t0 = time.time() self._llm = self._LLM( model=self.model_name, tensor_parallel_size=tensor_parallel_size, gpu_memory_utilization=gpu_memory_utilization, max_model_len=max_model_len, dtype="bfloat16", trust_remote_code=True, ) logger.info("Model loaded in %.1f seconds.", time.time() - t0) def _init_online(self, api_base: str, api_key: str) -> None: """Use OpenAI-compatible vLLM server (for distributed setups).""" try: from openai import OpenAI self._client = OpenAI(base_url=api_base, api_key=api_key) except ImportError: raise ImportError("pip install openai") logger.info("Connected to vLLM server at %s", api_base) def generate_batch( self, prompts: list[list[dict[str, str]]], temperature: float = 0.85, max_new_tokens: int = 300, ) -> list[str]: """ Generate responses for a batch of chat prompts. Parameters ---------- prompts : List of chat message lists (one per item in batch). temperature : Sampling temperature. Higher = more diverse. max_new_tokens : Max tokens per response. Returns ------- list[str] Raw text response per prompt (same length as input). """ if self.mode == "offline": return self._generate_offline(prompts, temperature, max_new_tokens) else: return self._generate_online(prompts, temperature, max_new_tokens) def _generate_offline( self, prompts: list[list[dict]], temperature: float, max_new_tokens: int, ) -> list[str]: """vLLM offline batched generation — maximises H100 throughput.""" from vllm import SamplingParams sampling = SamplingParams( temperature=temperature, max_tokens=max_new_tokens, stop=["", "<|eot_id|>"], # Llama-3 stop tokens ) # Convert chat messages to tokenised prompt strings using the model's template tokenizer = self._llm.get_tokenizer() formatted_prompts: list[str] = [] for msgs in prompts: if hasattr(tokenizer, "apply_chat_template"): text = tokenizer.apply_chat_template( msgs, tokenize=False, add_generation_prompt=True ) else: # Fallback: simple concatenation text = "\n".join( f"<|{m['role']}|>\n{m['content']}" for m in msgs ) formatted_prompts.append(text) outputs = self._llm.generate(formatted_prompts, sampling) return [o.outputs[0].text for o in outputs] def _generate_online( self, prompts: list[list[dict]], temperature: float, max_new_tokens: int, ) -> list[str]: """Sequential generation via OpenAI-compatible API (fallback / debugging).""" results = [] for msgs in prompts: try: resp = self._client.chat.completions.create( model=self.model_name, messages=msgs, temperature=temperature, max_tokens=max_new_tokens, ) results.append(resp.choices[0].message.content or "") except Exception as exc: logger.warning("API call failed: %s", exc) results.append("") return results # ───────────────────────────────────────────────────────────────────────────── # HIGH-LEVEL GENERATION LOOP # ───────────────────────────────────────────────────────────────────────────── def generate_persona_variants_batch( templates_subset: list[dict], generator: VLLMGenerator, personas: list[str], n_variants_per_persona: int = 5, batch_size: int = 64, temperature: float = 0.85, max_new_tokens: int = 300, ) -> Iterator[dict]: """ For each template × persona combination, generate `n_variants_per_persona` NL question variants using the LLM. Yields dicts: { "template_idx": int, "persona": str, "nl_variants": list[str], # successfully parsed NL questions } Parameters ---------- templates_subset : List of template dicts (from templates.py). generator : VLLMGenerator instance. personas : List of persona keys to use. n_variants_per_persona : How many NL variants per (template, persona) pair. batch_size : How many LLM calls to batch together. temperature : Sampling temperature. max_new_tokens : Max tokens for LLM response (should be ~300 for JSON array). """ from data_factory.schemas import SCHEMA_CONTEXT # Build all (template_idx, persona) prompt pairs all_jobs: list[tuple[int, str, list[dict]]] = [] for t_idx, template in enumerate(templates_subset): schema_ctx = SCHEMA_CONTEXT[template["domain"]] for persona in personas: prompt = build_generation_prompt( canonical_nl=template["base_nl"], description=template["description"], persona=persona, schema_context=schema_ctx, n_variants=n_variants_per_persona, ) all_jobs.append((t_idx, persona, prompt)) total_jobs = len(all_jobs) logger.info("Starting LLM generation: %d jobs (templates × personas).", total_jobs) # Process in batches for batch_start in range(0, total_jobs, batch_size): batch = all_jobs[batch_start: batch_start + batch_size] prompts = [job[2] for job in batch] t0 = time.time() raw_responses = generator.generate_batch( prompts, temperature=temperature, max_new_tokens=max_new_tokens ) elapsed = time.time() - t0 logger.info( "Batch %d-%d completed in %.1fs (%.1f jobs/s).", batch_start, batch_start + len(batch), elapsed, len(batch) / max(elapsed, 0.001) ) for (t_idx, persona, _), raw in zip(batch, raw_responses): nl_variants = parse_llm_response(raw) if not nl_variants: logger.debug( "Empty parse for template_idx=%d persona=%s. raw=%s", t_idx, persona, raw[:100] ) # Fall back to the canonical NL rather than losing this entry nl_variants = [templates_subset[t_idx]["base_nl"]] yield { "template_idx": t_idx, "persona": persona, "nl_variants": nl_variants, }