import hashlib
import json
import os
import sys

# CUDA_VISIBLE_DEVICES must be exported before torch initializes CUDA.
# Force the script onto the two free GPUs (here, 0 and 7).
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- Patch for a transformers version mismatch ---
# Older code references PytorchGELUTanh, which newer releases renamed;
# map the old name onto the closest existing activation.
try:
    import transformers.activations
    if not hasattr(transformers.activations, "PytorchGELUTanh"):
        transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
except ImportError:
    pass
# -------------------------------------------------

PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT

# The AWQ checkpoint is about 4x smaller and much faster than full precision.
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
INPUT_FILE = "llm_hybrid_templates.json"
OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
VARIATIONS_PER_SQL = 20
BATCH_SIZE = 64  # AWQ allows much larger batches

SYSTEM_PROMPT = (
    "You are an expert SQL analyst. Write a single SELECT query that answers "
    "the question. Output ONLY the SQL query — no markdown, no explanation, "
    "no backticks."
)

EXPANSION_PROMPT = """
You are an expert linguist and NL2SQL data augmentor.
I have a SQLite database schema and a complex SQL query.
Generate exactly {count} completely different natural language questions that this exact SQL query answers.

RULES:
- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
- Structure: Completely change sentence flow.
- No direct column/table names.

DATABASE SCHEMA:
{schema}

SQL QUERY:
{sql}

OUTPUT FORMAT:
Return ONLY a valid JSON array of objects:
[{{"persona": "...", "question": "..."}}]
"""


def extract_json_array(raw_text):
    """Return the outermost [...] span of the model output, or an empty array."""
    text = raw_text.strip()
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end != -1:
        return text[start:end + 1]
    return "[]"


def get_hash(text):
    """Case- and whitespace-insensitive fingerprint used for deduplication."""
    return hashlib.md5(text.lower().strip().encode("utf-8")).hexdigest()


def main():
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        sys.exit(1)

    with open(INPUT_FILE, "r") as f:
        base_templates = json.load(f)

    print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    # AWQ checkpoints are already quantized; weights load in float16.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
    )

    seen_hashes = set()
    total_saved = 0
    if os.path.exists(OUTPUT_FILE):
        # Quick line count for resuming. Note: seen_hashes is not rebuilt
        # here, so a resumed run may re-emit questions already on disk.
        with open(OUTPUT_FILE, "r") as f:
            total_saved = sum(1 for _ in f)

    pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)

    # Batch processing
    for i in range(0, len(base_templates), BATCH_SIZE):
        batch = base_templates[i:i + BATCH_SIZE]
        prompts = []
        for temp in batch:
            msg = [
                {"role": "system", "content": "You output only JSON arrays."},
                {"role": "user", "content": EXPANSION_PROMPT.format(
                    count=VARIATIONS_PER_SQL,
                    schema=SCHEMA_CONTEXT[temp["domain"]],
                    sql=temp["sql"],
                )},
            ]
            prompts.append(tokenizer.apply_chat_template(
                msg, tokenize=False, add_generation_prompt=True))

        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        try:
            with torch.no_grad():
                # AWQ handles large batches efficiently.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id,
                )
            # Strip the prompt tokens; keep only the newly generated text.
            responses = tokenizer.batch_decode(
                outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

            with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
                for idx, resp in enumerate(responses):
                    questions_data = json.loads(extract_json_array(resp))
                    sql = batch[idx]["sql"]
                    domain = batch[idx]["domain"]
                    for item in questions_data:
                        q = item.get("question", "")
                        if len(q) > 10:
                            q_hash = get_hash(q + sql)
                            if q_hash not in seen_hashes:
                                seen_hashes.add(q_hash)
                                record = {
                                    "prompt": [
                                        {"role": "system", "content": SYSTEM_PROMPT},
                                        {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"},
                                    ],
                                    "sql": sql,
                                }
                                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                                total_saved += 1
                                pbar.update(1)
                out_file.flush()
        except Exception as e:
            print(f"Batch failed: {e}")
            continue

    pbar.close()


if __name__ == "__main__":
    main()
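The two pure helpers in the script, `extract_json_array` and `get_hash`, can be exercised without loading the model. A minimal standalone sketch (it re-declares both helpers so it runs on its own; the sample model reply is hypothetical):

```python
import hashlib
import json


def extract_json_array(raw_text):
    # Grab the outermost [...] span; fall back to an empty array.
    text = raw_text.strip()
    start, end = text.find("["), text.rfind("]")
    return text[start:end + 1] if start != -1 and end != -1 else "[]"


def get_hash(text):
    # Case- and whitespace-insensitive fingerprint for deduplication.
    return hashlib.md5(text.lower().strip().encode("utf-8")).hexdigest()


# A model reply wrapped in conversational chatter still parses cleanly,
# because only the outermost bracketed span is kept.
raw = 'Sure, here you go:\n[{"persona": "Analyst", "question": "How many rows?"}]\nHope that helps!'
parsed = json.loads(extract_json_array(raw))
print(parsed[0]["question"])  # -> How many rows?

# Replies with no JSON degrade to an empty (but valid) array.
print(extract_json_array("I cannot answer that."))  # -> []

# Hashing question+SQL together means the same question is allowed to
# reappear for a *different* query, while trivial case differences of the
# same pair are treated as duplicates.
sql = "SELECT COUNT(*) FROM orders"
print(get_hash("How Many Rows? " + sql) == get_hash("how many rows? " + sql))  # -> True
```

Because the hash lowercases and strips before digesting, near-identical phrasings that differ only in capitalization collapse to one record, which keeps the 50k dataset from being padded with trivial duplicates.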