Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import json | |
| import torch | |
| import hashlib | |
| from pathlib import Path | |
| from tqdm import tqdm | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| # GPU CONFIG - 6 GPUs engaged (system devices 0-4 and 7) | |
| os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,7" | |
| PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__)) | |
| if PROJECT_ROOT not in sys.path: | |
| sys.path.insert(0, PROJECT_ROOT) | |
| from data_factory.schemas import SCHEMA_CONTEXT | |
| from data_factory.validator import SQLValidator | |
| # CONFIG | |
| MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct" | |
| TARGET_TEMPLATES = 10000 | |
| OUTPUT_FILE = "llm_10k_base_templates.json" | |
| BATCH_SIZE = 64 | |
| PROMPT_TEMPLATE = """ | |
| You are a senior expert in SQLite schema design and NL2SQL dataset generation. | |
| TASK | |
| Generate exactly 10 UNIQUE, COMPLEX, and FULLY VALID SQLite SQL SELECT queries for the given schema. | |
| For each query, also write a natural language question that a real user might ask. | |
| HARD RULES | |
| - Output ONLY a valid JSON array. | |
| - Do NOT wrap output in markdown, code fences, or explanations. | |
| - Every item must be a JSON object with exactly these keys: | |
| - "sql" | |
| - "base_nl" | |
| - "difficulty" | |
| - "has_order" | |
| - All SQL must be a single SELECT statement. | |
| - Do NOT use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, PRAGMA, ATTACH, DETACH, or any DDL/DML. | |
| - Every table and column used in SQL must exist in the provided schema. | |
| - Do NOT invent columns, tables, aliases, or constraints. | |
| - SQL must be valid for SQLite. | |
| - Prefer queries that are meaningfully different from each other. | |
| - Avoid repetitive templates. | |
| - Each SQL should test a different reasoning pattern. | |
| - Each base_nl should sound natural and distinct from the others. | |
| - Use advanced SQL patterns where appropriate: | |
| - multiple JOINs | |
| - CTEs | |
| - subqueries | |
| - window functions such as ROW_NUMBER, RANK, DENSE_RANK, LAG, LEAD | |
| - GROUP BY and HAVING | |
| - conditional aggregation | |
| - anti-joins / exclusion logic | |
| - top-N per group | |
| - time-based filtering | |
| - Exactly 3 of the 10 queries must be "easy" (basic filtering, simple lookups, 1-2 tables). | |
| - Exactly 3 of the 10 queries must be "medium" (moderate complexity, standard JOINs, basic aggregation). | |
| - Exactly 4 of the 10 queries must be genuinely "hard" (advanced patterns, CTEs, subqueries, window functions). | |
| - Ensure the "difficulty" key strictly contains one of these exact string values: "easy", "medium", or "hard". | |
| QUALITY TARGETS | |
| - The SQL should be executable as written. | |
| - The question should be answerable from the schema alone. | |
| - Prefer business-like, realistic analytics questions. | |
| - Prefer queries that require combining 2 to 4 tables. | |
| - If a query uses aggregation, ensure the NL clearly implies aggregation. | |
| - If a query uses ordering, include "has_order": true. | |
| - If a query does not require ordering, set "has_order": false. | |
| - Make the 10 queries cover diverse intent types: | |
| 1. ranking | |
| 2. comparison against average or median | |
| 3. top/bottom-N | |
| 4. grouped aggregation | |
| 5. time filtering | |
| 6. multi-join analysis | |
| 7. exclusion / NOT EXISTS | |
| 8. window-function based analysis | |
| 9. conditional counting | |
| 10. trend or interval-based logic | |
| SCHEMA | |
| {schema} | |
| OUTPUT FORMAT | |
| Return ONLY a valid JSON array of 10 objects. | |
| Example structure: | |
| [ | |
| {{ | |
| "sql": "SELECT ...", | |
| "base_nl": "Show ...", | |
| "difficulty": "hard", | |
| "has_order": true | |
| }} | |
| ] | |
| FINAL SELF-CHECK BEFORE RESPONDING | |
| - Confirm the output is valid JSON. | |
| - Confirm there are exactly 10 objects. | |
| - Confirm every SQL is a single SELECT. | |
| - Confirm no hallucinated schema elements exist. | |
| - Confirm the 10 questions are not paraphrases of each other. | |
| """ | |
def extract_json(raw_text):
    """Return the JSON-array-looking span of an LLM response, or None.

    The span runs from the first ``[`` to the last ``]`` in *raw_text*,
    inclusive. Anything outside it (markdown code fences, a leading
    "Sure, here you go:", trailing commentary) is dropped. The result is
    NOT guaranteed to be valid JSON; the caller is expected to pass it to
    ``json.loads`` and handle failure.

    Args:
        raw_text: Decoded model output. ``None`` and ``""`` are accepted.

    Returns:
        The bracketed substring, or ``None`` when there is no ``[`` that is
        followed (at or after its position) by a ``]``.
    """
    if not raw_text:
        return None

    # No explicit fence stripping: slicing a fixed number of characters off
    # the end corrupts responses that have an opening fence but no closing
    # one (e.g. truncated generations), and the bracket search below already
    # skips over any fence text.
    start = raw_text.find("[")
    end = raw_text.rfind("]")

    # rfind returns -1 when there is no "]", which is also < start.
    if start == -1 or end < start:
        return None
    return raw_text[start:end + 1]
def _sql_fingerprint(sql):
    """Return the MD5 hex digest of the lower-cased SQL text (the dedup key)."""
    return hashlib.md5(sql.lower().encode()).hexdigest()


def _load_existing_templates(path):
    """Load templates saved by an earlier run so mining can resume.

    Returns a ``(templates, seen_hashes)`` tuple; both are empty when *path*
    does not exist.
    """
    if not os.path.exists(path):
        return [], set()
    with open(path, "r", encoding="utf-8") as f:
        templates = json.load(f)
    return templates, {_sql_fingerprint(t["sql"]) for t in templates}


def _save_templates(path, templates):
    """Write *templates* to *path* as indented JSON via a temp file + rename.

    Writing to a sibling temp file and renaming it over the target means an
    interruption mid-write cannot leave a truncated resume file behind.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(templates, f, indent=2)
    os.replace(tmp_path, path)


def _harvest_items(generated_data, domain, validator, valid_templates, seen_sql_hashes):
    """Validate the items of one parsed response and keep the new, passing ones.

    Accepted items are tagged with ``domain`` and a ``base_<n>`` id, appended
    to *valid_templates*, and their hash is added to *seen_sql_hashes* (both
    containers are mutated in place).
    """
    for item in generated_data:
        if not isinstance(item, dict):
            continue
        sql = item.get("sql")
        # The model occasionally emits null / list / number for "sql";
        # skip those instead of raising on .strip().
        if not isinstance(sql, str):
            continue
        sql = sql.strip()
        if not sql:
            continue

        # Check for duplicates using hash
        sql_hash = _sql_fingerprint(sql)
        if sql_hash in seen_sql_hashes:
            tqdm.write("[DEBUG] Duplicate query skipped.")
            continue

        val_result = validator.validate(sql)
        # Hard validation rule: SQL must execute AND return rows
        if val_result.passed and val_result.row_count > 0:
            tqdm.write(f"[DEBUG] SQL Passed (Rows: {val_result.row_count}): {sql[:50]}...")
            item["domain"] = domain
            item["id"] = f"base_{len(valid_templates)}"
            valid_templates.append(item)
            seen_sql_hashes.add(sql_hash)
        else:
            tqdm.write(f"[DEBUG] SQL Failed Validation or 0 Rows (Passed: {val_result.passed}, Rows: {val_result.row_count}): {sql[:50]}...")


def main():
    """Mine validated NL2SQL base templates with the LLM until the target is met.

    Loads the tokenizer and model, resumes from ``OUTPUT_FILE`` if it exists,
    then repeatedly: builds ``BATCH_SIZE`` prompts cycling through the schema
    domains, generates responses, extracts and parses the JSON arrays, keeps
    every new SQL that the domain's ``SQLValidator`` passes with at least one
    row, and saves the accumulated list after each batch that added something.

    The loop ends when ``TARGET_TEMPLATES`` is reached, or after several
    consecutive generation failures (so a persistent error such as
    out-of-memory does not spin forever).
    """
    print("Loading Model Qwen-72B (SDPA) for 10K Mining...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Batched generation with a decoder-only model needs LEFT padding:
    # with right padding the pad tokens sit between a shorter prompt and its
    # continuation. Left padding also keeps the fixed-offset slice of the
    # generated tokens below correct for every row.
    tokenizer.padding_side = "left"
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Keys are indices among the *visible* devices (CUDA_VISIBLE_DEVICES is
    # set at module top), so index 5 is system GPU 7.
    custom_max_memory = {
        0: "60GiB",  # System GPU 0 (Has 13GB used, ~67GB free)
        1: "75GiB",  # System GPU 1 (Fully free)
        2: "75GiB",  # System GPU 2 (Fully free)
        3: "75GiB",  # System GPU 3 (Fully free)
        4: "75GiB",  # System GPU 4 (Fully free)
        5: "45GiB",  # System GPU 7 (Has 25GB used, ~55GB free)
    }
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        max_memory=custom_max_memory,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    domains = list(SCHEMA_CONTEXT.keys())

    # Resume support: Load existing templates to prevent duplicates
    valid_templates, seen_sql_hashes = _load_existing_templates(OUTPUT_FILE)

    pbar = tqdm(total=TARGET_TEMPLATES, initial=len(valid_templates), desc="Mining 10K Base Templates")
    validators = {}
    domain_idx = 0

    # Give up after this many back-to-back generate() failures.
    max_consecutive_failures = 5
    consecutive_failures = 0

    try:
        while len(valid_templates) < TARGET_TEMPLATES:
            batch_prompts = []
            batch_domains = []

            # Prepare Batch
            for _ in range(BATCH_SIZE):
                domain = domains[domain_idx % len(domains)]
                schema_string = SCHEMA_CONTEXT[domain]
                domain_idx += 1
                messages = [
                    {"role": "system", "content": "You output only valid JSON arrays. Do not include markdown."},
                    {"role": "user", "content": PROMPT_TEMPLATE.format(schema=schema_string)},
                ]
                chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
                batch_prompts.append(chat_text)
                batch_domains.append(domain)

            inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)

            try:
                tqdm.write(f"\n[DEBUG] Sending batch of {BATCH_SIZE} to model.generate(). Please wait...")
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=5000,
                        do_sample=True,
                        temperature=0.55,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                tqdm.write("[DEBUG] Model generation finished. Decoding responses...")

                # Output Slicing: drop the (padded) prompt columns.
                input_length = inputs.input_ids.shape[1]
                generated_tokens = outputs[:, input_length:]
                responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
            except Exception as e:
                tqdm.write(f"\n[DEBUG] CRITICAL EXCEPTION CAUGHT: {e}")
                consecutive_failures += 1
                if consecutive_failures >= max_consecutive_failures:
                    tqdm.write(f"[DEBUG] {consecutive_failures} consecutive generation failures. Stopping.")
                    break
                continue
            consecutive_failures = 0

            # Count additions by list growth so a failure part-way through
            # one response cannot lose track of items already appended.
            batch_start = len(valid_templates)

            for i, (response, domain) in enumerate(zip(responses, batch_domains)):
                tqdm.write(f"\n[DEBUG] Processing Response {i+1}/{BATCH_SIZE} for domain: {domain}")
                # Errors are contained per response so one bad response does
                # not throw away the rest of an expensive batch.
                try:
                    json_text = extract_json(response)
                    if not json_text:
                        tqdm.write(f"[DEBUG] extract_json failed. Raw text snippet: {response[:200]}...")
                        continue

                    try:
                        generated_data = json.loads(json_text)
                    except ValueError as e:  # json.JSONDecodeError is a ValueError
                        tqdm.write(f"[DEBUG] json.loads failed. Error: {e}")
                        tqdm.write(f"[DEBUG] Bad JSON snippet: {json_text[:200]}...")
                        continue
                    tqdm.write(f"[DEBUG] JSON loaded successfully. Found {len(generated_data)} items.")

                    # One validator per domain, created lazily and reused.
                    if domain not in validators:
                        validators[domain] = SQLValidator(domain, seed=42)

                    _harvest_items(generated_data, domain, validators[domain], valid_templates, seen_sql_hashes)
                except Exception as e:
                    tqdm.write(f"\n[DEBUG] CRITICAL EXCEPTION CAUGHT: {e}")

            batch_added = len(valid_templates) - batch_start
            if batch_added > 0:
                pbar.update(batch_added)
                tqdm.write(f"[DEBUG] Auto-saving {batch_added} new templates to JSON...")
                # Auto-save after every successful batch
                try:
                    _save_templates(OUTPUT_FILE, valid_templates)
                except OSError as e:
                    # Keep mining; the next successful batch retries the save.
                    tqdm.write(f"\n[DEBUG] CRITICAL EXCEPTION CAUGHT: {e}")
    finally:
        # Close validators and the progress bar even on Ctrl-C / errors.
        for v in validators.values():
            v.close()
        pbar.close()

    print(f"\nBoom! Generated {len(valid_templates)} Elite Base Templates!")
| if __name__ == "__main__": | |
| main() |