# NOTE(review): removed stray non-code paste artifacts ("Spaces:", "Sleeping")
# that preceded the script — they were not Python and broke parsing.
"""Expand SQL templates into a large NL2SQL training dataset.

Loads hybrid SQL templates, asks Qwen2.5-72B (AWQ) to generate paraphrased
natural-language questions for each query, and appends deduplicated
(prompt, sql) records to a JSONL file.
"""
import os
import sys

# Must be set BEFORE torch initializes CUDA so only the two free GPUs are
# visible; previously this ran after `import torch`, which only works while
# CUDA init is still lazy and silently breaks if anything touches CUDA first.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"

import hashlib
import json

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- PATCH FOR TRANSFORMERS VERSION MISMATCH --- #
try:
    import transformers.activations
    if not hasattr(transformers.activations, "PytorchGELUTanh"):
        # Map the old activation name to the equivalent current one.
        transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
except ImportError:
    pass
# ------------------------------------------------------

# Make project-local packages importable when run as a plain script.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT

# AWQ model is 4x smaller and much faster
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
INPUT_FILE = "llm_hybrid_templates.json"
OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
VARIATIONS_PER_SQL = 20
BATCH_SIZE = 64  # AWQ allows much larger batches!

SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks."

EXPANSION_PROMPT = """
You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query.
Generate exactly {count} completely different natural language questions that this exact SQL query answers.
RULES:
- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
- Structure: Completely change sentence flow.
- No direct column/table names.
DATABASE SCHEMA:
{schema}
SQL QUERY:
{sql}
OUTPUT FORMAT:
Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}]
"""
def extract_json_array(raw_text):
    """Extract the outermost ``[...]`` span from raw model output.

    Returns the substring from the first ``[`` through the last ``]``, or
    the literal string ``"[]"`` when no plausible array is present — so the
    result can always be handed to ``json.loads``.
    """
    text = raw_text.strip()
    start = text.find("[")
    end = text.rfind("]")
    # Require a well-ordered pair: input like "] noise [" previously passed
    # the `!= -1` checks but produced an empty, invalid-JSON slice.
    if start != -1 and end > start:
        return text[start:end + 1]
    return "[]"
def get_hash(text):
    """Return the MD5 hex digest of *text*, case- and whitespace-normalized."""
    normalized = text.lower().strip()
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()
def main():
    """Generate paraphrased NL questions for each SQL template.

    Appends deduplicated ``{"prompt": [...], "sql": ...}`` records to
    OUTPUT_FILE (JSONL). Supports resuming: on restart, the dedup set is
    re-hydrated from the existing output file so no duplicate
    (question, sql) pairs are written.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        sys.exit(1)

    with open(INPUT_FILE, "r") as f:
        base_templates = json.load(f)

    print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    # Model loading (AWQ checkpoints carry their quantization config).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16,  # AWQ models use float16/bfloat16 for weights
        low_cpu_mem_usage=True
    )

    seen_hashes = set()
    total_saved = 0
    # Resume support: previously only the line count was restored, so a
    # restart could re-emit duplicate (question, sql) pairs. Rebuild the
    # dedup set from the records already on disk.
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r", encoding="utf-8") as f:
            for line in f:
                total_saved += 1
                try:
                    rec = json.loads(line)
                    # User content is f"SCHEMA: ...\nQUESTION: {q}"; recover q.
                    # (Assumes the schema text never contains "\nQUESTION: ".)
                    q = rec["prompt"][1]["content"].split("\nQUESTION: ", 1)[-1]
                    seen_hashes.add(get_hash(q + rec["sql"]))
                except (json.JSONDecodeError, KeyError, IndexError, TypeError):
                    continue  # tolerate a truncated last line from a crash

    pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)

    # Batch processing
    for i in range(0, len(base_templates), BATCH_SIZE):
        batch = base_templates[i:i + BATCH_SIZE]
        prompts = []
        for temp in batch:
            msg = [
                {"role": "system", "content": "You output only JSON arrays."},
                {"role": "user", "content": EXPANSION_PROMPT.format(count=VARIATIONS_PER_SQL, schema=SCHEMA_CONTEXT[temp['domain']], sql=temp['sql'])}
            ]
            prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))

        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        try:
            with torch.no_grad():
                # AWQ handles large batches efficiently.
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            # Strip the prompt tokens; decode only the generated suffix.
            responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

            with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
                for idx, resp in enumerate(responses):
                    # Parse per response: previously one malformed reply raised
                    # into the batch-level except and discarded the results of
                    # every other response in the batch.
                    try:
                        questions_data = json.loads(extract_json_array(resp))
                    except json.JSONDecodeError:
                        continue
                    sql = batch[idx]["sql"]
                    domain = batch[idx]["domain"]
                    for item in questions_data:
                        if not isinstance(item, dict):
                            continue  # model occasionally emits bare strings
                        q = item.get("question", "")
                        # Skip trivially short / empty paraphrases.
                        if len(q) > 10:
                            q_hash = get_hash(q + sql)
                            if q_hash not in seen_hashes:
                                seen_hashes.add(q_hash)
                                record = {
                                    "prompt": [
                                        {"role": "system", "content": SYSTEM_PROMPT},
                                        {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"}
                                    ],
                                    "sql": sql
                                }
                                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                                total_saved += 1
                                pbar.update(1)
                out_file.flush()
        except Exception as e:
            # Best-effort pipeline: log (e.g. CUDA OOM) and move to next batch.
            print(f"Batch failed: {e}")
            continue

    pbar.close()
| if __name__ == "__main__": | |
| main() |