Spaces:
Sleeping
Sleeping
| """ | |
| data_factory/generator.py | |
| ========================== | |
| vLLM-based Natural Language question generator for H100. | |
| This module uses a large LLM (Llama-3-70B or Qwen-72B) served via vLLM | |
| to generate diverse, persona-based natural language paraphrases of the | |
| canonical NL questions in our template library. | |
| KEY DESIGN: The LLM generates ONLY natural language questions. | |
| SQL is NEVER touched by the LLM. | |
| This guarantees zero SQL errors in the final dataset. | |
| Persona descriptions: | |
| ceo - Direct, short, active voice. Business executive style. | |
| chatty - Conversational, verbose, passive voice. | |
| lazy_typist - Short, abbreviations, possible informal grammar. | |
| non_techie - Plain English, avoids SQL/tech jargon, uses synonyms. | |
| analyst - Technical, precise, jargon-heavy. | |
| Usage (on H100 cluster): | |
| python -m data_factory.generator --templates-per-chunk 20 --n-variants 10 | |
| """ | |
from __future__ import annotations

import json
import logging
import re
import time
from typing import Iterator, Optional
| logger = logging.getLogger(__name__) | |
# ─────────────────────────────────────────────────────────────────────────────
# PERSONA SYSTEM PROMPTS
# ─────────────────────────────────────────────────────────────────────────────
# Maps each persona key to the system-prompt text that describes that
# persona's voice.  The keys correspond to the five personas listed in the
# module docstring; build_generation_prompt() looks a persona up here and
# embeds the text verbatim into the LLM system message.
PERSONA_SYSTEM_PROMPTS: dict[str, str] = {
    "ceo": (
        "You are a busy C-level executive who communicates in short, punchy, "
        "direct sentences. You use active voice, skip filler words, and get "
        "straight to the point. You are asking a data analyst for information."
    ),
    "chatty": (
        "You are a friendly, conversational person who likes to be thorough "
        "and explain things fully. You use passive voice sometimes, add context, "
        "and ask questions in a relaxed, detailed way. You are not technical."
    ),
    "lazy_typist": (
        "You type quickly and informally. You use abbreviations (e.g. 'pls', "
        "'lmk', 'asap'), lowercase, minimal punctuation, and sometimes omit "
        "words. You get your meaning across without perfect grammar."
    ),
    "non_techie": (
        "You have no database or SQL knowledge. You use everyday English words "
        "instead of technical terms. For example, you say 'customers' not 'rows', "
        "'most expensive' not 'highest price', 'total money' not 'sum'. "
        "You describe what you want to see, not how to get it."
    ),
    "analyst": (
        "You are a data scientist or BI analyst who is precise and technical. "
        "You use terms like 'aggregate', 'partition', 'granularity', 'distinct', "
        "'filter predicate', 'ranked by metric'. Your questions are precise and unambiguous."
    ),
}
# ─────────────────────────────────────────────────────────────────────────────
# PROMPT BUILDER
# ─────────────────────────────────────────────────────────────────────────────
def build_generation_prompt(
    canonical_nl: str,
    description: str,
    persona: str,
    schema_context: str,
    n_variants: int = 5,
) -> list[dict[str, str]]:
    """
    Construct the chat-format prompt that asks the LLM to rephrase the
    canonical NL question in the voice of the given persona.

    Parameters
    ----------
    canonical_nl : The base NL question from the template.
    description : One-line SQL description (gives the LLM additional context).
    persona : One of the 5 persona keys in PERSONA_SYSTEM_PROMPTS.
    schema_context : The compact schema string for the domain.
    n_variants : How many rephrased questions to generate.

    Returns
    -------
    list[dict] Chat messages in [{"role": ..., "content": ...}] format.
    """
    voice = PERSONA_SYSTEM_PROMPTS[persona]
    system_msg = "".join([
        "You are a data labelling specialist. Your task is to rephrase a database ",
        "question in a specific communication style (persona). The rephrased questions ",
        "must preserve the EXACT same intent and required information as the original β ",
        "do not change what data is being asked for, only how it is expressed.\n\n",
        f"PERSONA: {voice}\n\n",
        "OUTPUT FORMAT: Return ONLY a valid JSON array of strings. ",
        "No preamble, no markdown, no extra keys. Example: ",
        '["question 1", "question 2", "question 3"]',
    ])
    user_msg = (
        f"DATABASE CONTEXT:\n{schema_context}\n\n"
        f"WHAT THE QUERY DOES: {description}\n\n"
        f"CANONICAL QUESTION: {canonical_nl}\n\n"
        f"Generate {n_variants} different ways a person with the persona described "
        f"above would ask this same question. The meaning must stay identical."
    )
    messages = [{"role": "system", "content": system_msg}]
    messages.append({"role": "user", "content": user_msg})
    return messages
# ─────────────────────────────────────────────────────────────────────────────
# RESPONSE PARSER
# ─────────────────────────────────────────────────────────────────────────────
def parse_llm_response(raw_text: str) -> list[str]:
    """
    Extract a list of strings from the LLM's JSON response.

    Handles common failure modes: markdown code fences, a trailing comma
    before the closing bracket, and extra prose surrounding the JSON array.

    Parameters
    ----------
    raw_text : Raw completion text from the LLM.

    Returns
    -------
    list[str]
        The stripped, non-empty string entries of the array; an empty list
        if parsing fails completely.
    """
    text = raw_text.strip()
    # Strip markdown fences if present
    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()
    # Find the JSON array boundaries
    start = text.find("[")
    end = text.rfind("]")
    if start == -1 or end == -1 or end <= start:
        logger.warning("LLM response missing JSON array brackets: %s", text[:100])
        return []
    json_str = text[start:end + 1].rstrip()
    # Fix a trailing comma before ] (common LLM mistake).  The substitution is
    # anchored at the END of the string so a literal ",]" inside a quoted
    # question is never corrupted (the previous blanket replace(",]", "]")
    # mangled such content, and also missed ",\n]" with a newline).
    json_str = re.sub(r",\s*\]$", "]", json_str)
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError as exc:
        logger.warning("JSON parse error: %s | text: %s", exc, json_str[:200])
        return []
    if not isinstance(parsed, list):
        return []
    # Filter to only non-empty strings
    return [s.strip() for s in parsed if isinstance(s, str) and s.strip()]
# ─────────────────────────────────────────────────────────────────────────────
# VLLM INTERFACE
# ─────────────────────────────────────────────────────────────────────────────
class VLLMGenerator:
    """
    Wrapper around a running vLLM server for high-throughput NL generation.

    Supports two modes:
        online  : Calls a running vLLM OpenAI-compatible API server.
        offline : Uses vllm.LLM directly (loads model in-process, H100 recommended).

    For H100 cluster usage, prefer 'offline' mode with tensor_parallel_size=4
    to saturate all 4 H100s for maximum throughput.
    """

    def __init__(
        self,
        model_name: str,
        mode: str = "offline",
        tensor_parallel_size: int = 4,
        gpu_memory_utilization: float = 0.90,
        max_model_len: int = 4096,
        # Online mode only
        api_base: str = "http://localhost:8000/v1",
        api_key: str = "EMPTY",
    ) -> None:
        """
        Initialise the generator in either 'offline' or 'online' mode.

        Raises
        ------
        ValueError  : If `mode` is not 'offline' or 'online'.
        ImportError : If the backend package (vllm / openai) is missing.
        """
        self.model_name = model_name
        self.mode = mode
        self._llm = None       # vllm.LLM engine (offline mode only)
        self._client = None    # OpenAI-compatible client (online mode only)
        if mode == "offline":
            self._init_offline(tensor_parallel_size, gpu_memory_utilization, max_model_len)
        elif mode == "online":
            self._init_online(api_base, api_key)
        else:
            raise ValueError(f"Unknown mode: {mode!r}. Use 'offline' or 'online'.")

    def _init_offline(
        self,
        tensor_parallel_size: int,
        gpu_memory_utilization: float,
        max_model_len: int,
    ) -> None:
        """Load vLLM engine in-process (best for H100 cluster)."""
        try:
            from vllm import LLM, SamplingParams
            # Stash the classes so other methods never need to re-import vllm.
            self._LLM = LLM
            self._SamplingParams = SamplingParams
        except ImportError as exc:
            # Chain the original error so the real import failure is visible.
            raise ImportError(
                "vLLM not installed. Run: pip install vllm\n"
                "For H100: pip install vllm --extra-index-url https://download.pytorch.org/whl/cu124"
            ) from exc
        logger.info("Loading model %s with %d GPUs (offline mode)...", self.model_name, tensor_parallel_size)
        t0 = time.time()
        self._llm = self._LLM(
            model=self.model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            dtype="bfloat16",
            trust_remote_code=True,
        )
        logger.info("Model loaded in %.1f seconds.", time.time() - t0)

    def _init_online(self, api_base: str, api_key: str) -> None:
        """Use OpenAI-compatible vLLM server (for distributed setups)."""
        try:
            from openai import OpenAI
        except ImportError as exc:
            raise ImportError("pip install openai") from exc
        # Constructed outside the try so client errors are not masked as
        # ImportError.
        self._client = OpenAI(base_url=api_base, api_key=api_key)
        logger.info("Connected to vLLM server at %s", api_base)

    def generate_batch(
        self,
        prompts: list[list[dict[str, str]]],
        temperature: float = 0.85,
        max_new_tokens: int = 300,
    ) -> list[str]:
        """
        Generate responses for a batch of chat prompts.

        Parameters
        ----------
        prompts : List of chat message lists (one per item in batch).
        temperature : Sampling temperature. Higher = more diverse.
        max_new_tokens : Max tokens per response.

        Returns
        -------
        list[str] Raw text response per prompt (same length as input).
        """
        if self.mode == "offline":
            return self._generate_offline(prompts, temperature, max_new_tokens)
        return self._generate_online(prompts, temperature, max_new_tokens)

    def _generate_offline(
        self,
        prompts: list[list[dict]],
        temperature: float,
        max_new_tokens: int,
    ) -> list[str]:
        """vLLM offline batched generation -- maximises H100 throughput."""
        # Use the SamplingParams class stashed by _init_offline instead of
        # re-importing vllm here (keeps the import in one place).
        sampling = self._SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            stop=["</s>", "<|eot_id|>"],  # Llama-3 stop tokens
        )
        # Convert chat messages to prompt strings using the model's template
        tokenizer = self._llm.get_tokenizer()
        formatted_prompts: list[str] = []
        for msgs in prompts:
            if hasattr(tokenizer, "apply_chat_template"):
                text = tokenizer.apply_chat_template(
                    msgs, tokenize=False, add_generation_prompt=True
                )
            else:
                # Fallback: simple concatenation
                text = "\n".join(
                    f"<|{m['role']}|>\n{m['content']}" for m in msgs
                )
            formatted_prompts.append(text)
        outputs = self._llm.generate(formatted_prompts, sampling)
        return [o.outputs[0].text for o in outputs]

    def _generate_online(
        self,
        prompts: list[list[dict]],
        temperature: float,
        max_new_tokens: int,
    ) -> list[str]:
        """Sequential generation via OpenAI-compatible API (fallback / debugging)."""
        results: list[str] = []
        for msgs in prompts:
            try:
                resp = self._client.chat.completions.create(
                    model=self.model_name,
                    messages=msgs,
                    temperature=temperature,
                    max_tokens=max_new_tokens,
                )
                results.append(resp.choices[0].message.content or "")
            except Exception as exc:
                # Deliberate best-effort: an empty string keeps the output
                # list aligned 1:1 with the input prompts.
                logger.warning("API call failed: %s", exc)
                results.append("")
        return results
# ─────────────────────────────────────────────────────────────────────────────
# HIGH-LEVEL GENERATION LOOP
# ─────────────────────────────────────────────────────────────────────────────
def generate_persona_variants_batch(
    templates_subset: list[dict],
    generator: VLLMGenerator,
    personas: list[str],
    n_variants_per_persona: int = 5,
    batch_size: int = 64,
    temperature: float = 0.85,
    max_new_tokens: int = 300,
) -> Iterator[dict]:
    """
    Generate `n_variants_per_persona` NL question variants for every
    template x persona combination, using batched LLM calls.

    Yields dicts:
        {
            "template_idx": int,
            "persona": str,
            "nl_variants": list[str],   # successfully parsed NL questions
        }

    Parameters
    ----------
    templates_subset : List of template dicts (from templates.py).
    generator : VLLMGenerator instance.
    personas : List of persona keys to use.
    n_variants_per_persona : How many NL variants per (template, persona) pair.
    batch_size : How many LLM calls to batch together.
    temperature : Sampling temperature.
    max_new_tokens : Max tokens for LLM response (should be ~300 for JSON array).
    """
    from data_factory.schemas import SCHEMA_CONTEXT

    # Enumerate every (template, persona) pair up front so batching is free
    # to cut across template boundaries.
    jobs: list[tuple[int, str, list[dict]]] = []
    for idx, tmpl in enumerate(templates_subset):
        ctx = SCHEMA_CONTEXT[tmpl["domain"]]
        jobs.extend(
            (
                idx,
                persona_key,
                build_generation_prompt(
                    canonical_nl=tmpl["base_nl"],
                    description=tmpl["description"],
                    persona=persona_key,
                    schema_context=ctx,
                    n_variants=n_variants_per_persona,
                ),
            )
            for persona_key in personas
        )

    n_jobs = len(jobs)
    logger.info("Starting LLM generation: %d jobs (templates Γ personas).", n_jobs)

    # Walk the job list one batch at a time.
    for offset in range(0, n_jobs, batch_size):
        chunk = jobs[offset:offset + batch_size]
        started = time.time()
        raw_batch = generator.generate_batch(
            [prompt for _, _, prompt in chunk],
            temperature=temperature,
            max_new_tokens=max_new_tokens,
        )
        took = time.time() - started
        logger.info(
            "Batch %d-%d completed in %.1fs (%.1f jobs/s).",
            offset, offset + len(chunk), took, len(chunk) / max(took, 0.001)
        )
        for (idx, persona_key, _), raw_text in zip(chunk, raw_batch):
            variants = parse_llm_response(raw_text)
            if not variants:
                logger.debug(
                    "Empty parse for template_idx=%d persona=%s. raw=%s",
                    idx, persona_key, raw_text[:100]
                )
                # Fall back to the canonical NL rather than losing this entry
                variants = [templates_subset[idx]["base_nl"]]
            yield {
                "template_idx": idx,
                "persona": persona_key,
                "nl_variants": variants,
            }