# nl2sql-bench / data_expander.py
# Hugging Face Hub upload header (uploaded by ritvik360 via huggingface_hub,
# revision a39d8ef, verified) — kept as a comment so the file parses as Python.
import hashlib
import json
import os
import sys
from pathlib import Path

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- PATCH FOR TRANSFORMERS VERSION MISMATCH ---
# Newer transformers releases removed PytorchGELUTanh; alias the old name to
# the equivalent GELUActivation so model code importing it keeps working.
try:
    import transformers.activations
    if not hasattr(transformers.activations, "PytorchGELUTanh"):
        # Mapping the old name to the new existing one
        transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
except ImportError:
    pass
# ------------------------------------------------------

# Force script to use only the 2 free GPUs (e.g., 0 and 7).
# NOTE(review): torch is already imported above; this still takes effect
# because CUDA_VISIBLE_DEVICES is read lazily at first CUDA initialization,
# but it is safer to export it before launching the process — confirm no
# CUDA call happens before this line.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"

# Make the project root importable so local packages resolve regardless of
# the working directory the script is launched from.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT
# AWQ model is 4x smaller and much faster
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
# Input: base {domain, sql} templates to expand.
INPUT_FILE = "llm_hybrid_templates.json"
# Output: JSONL chat-format training records (appended to across runs).
OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
# Number of paraphrased questions requested per SQL template.
VARIATIONS_PER_SQL = 20
BATCH_SIZE = 64 # AWQ allows much larger batches!
# System message embedded in every output record's prompt.
SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks."
# Generation prompt; .format() fills {count}, {schema}, {sql}. The doubled
# braces on the OUTPUT FORMAT line survive .format() as literal JSON braces.
EXPANSION_PROMPT = """
You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query.
Generate exactly {count} completely different natural language questions that this exact SQL query answers.
RULES:
- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
- Structure: Completely change sentence flow.
- No direct column/table names.
DATABASE SCHEMA:
{schema}
SQL QUERY:
{sql}
OUTPUT FORMAT:
Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}]
"""
def extract_json_array(raw_text):
    """Extract the outermost JSON array span from raw LLM output.

    Returns the substring from the first '[' through the last ']' inclusive,
    or the literal "[]" when no well-ordered bracket pair exists, so the
    result is always safe to hand to json.loads without an empty slice.
    """
    text = raw_text.strip()
    start = text.find("[")
    end = text.rfind("]")
    # Require end > start: input like "] junk [" previously slipped past the
    # `!= -1` checks and yielded an empty/invalid slice that json.loads rejects.
    if start != -1 and end > start:
        return text[start:end + 1]
    return "[]"
def get_hash(text):
    """Return the MD5 hex digest of *text* normalized for deduplication.

    Normalization lowercases the string and strips surrounding whitespace,
    so trivially different spellings of the same question collide.
    """
    normalized = text.lower().strip()
    digest = hashlib.md5(normalized.encode("utf-8"))
    return digest.hexdigest()
def main():
    """Expand base SQL templates into a large NL2SQL training dataset.

    Loads {domain, sql} templates from INPUT_FILE, asks the model for
    VARIATIONS_PER_SQL paraphrased questions per query, and appends
    deduplicated chat-format records to OUTPUT_FILE as JSONL.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        sys.exit(1)
    with open(INPUT_FILE, "r") as f:
        base_templates = json.load(f)
    print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")
    # Left padding so every generated continuation starts at the batch's
    # right edge, which keeps the prompt-length slice below valid per row.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    # Model loading (AWQ version automatically handles quantization)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16, # AWQ models use float16/bfloat16 for weights
        low_cpu_mem_usage=True
    )
    seen_hashes = set()
    total_saved = 0
    # Resume support: count previously written lines so the progress bar
    # starts where the last run stopped.
    # NOTE(review): seen_hashes is NOT rebuilt from the existing file and
    # every template is re-processed from the start, so a resumed run can
    # append duplicate records — confirm whether that is acceptable.
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            for line in f:
                total_saved += 1 # Quick count
    pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)
    # Batch processing
    for i in range(0, len(base_templates), BATCH_SIZE):
        batch = base_templates[i:i + BATCH_SIZE]
        prompts = []
        for temp in batch:
            msg = [
                {"role": "system", "content": "You output only JSON arrays."},
                {"role": "user", "content": EXPANSION_PROMPT.format(count=VARIATIONS_PER_SQL, schema=SCHEMA_CONTEXT[temp['domain']], sql=temp['sql'])}
            ]
            prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))
        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
        try:
            with torch.no_grad():
                # Increased speed: AWQ handles large batches efficiently
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )
            # Slice off the prompt tokens; decode only the generated tail.
            responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
            with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
                for idx, resp in enumerate(responses):
                    # May raise on malformed model output; the surrounding
                    # try/except then abandons the rest of this batch.
                    questions_data = json.loads(extract_json_array(resp))
                    sql = batch[idx]["sql"]
                    domain = batch[idx]["domain"]
                    for item in questions_data:
                        q = item.get("question", "")
                        # Discard trivially short questions.
                        if len(q) > 10:
                            # Hash question + SQL together so the same
                            # phrasing is allowed for different queries.
                            q_hash = get_hash(q + sql)
                            if q_hash not in seen_hashes:
                                seen_hashes.add(q_hash)
                                record = {
                                    "prompt": [
                                        {"role": "system", "content": SYSTEM_PROMPT},
                                        {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"}
                                    ],
                                    "sql": sql
                                }
                                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                                total_saved += 1
                                pbar.update(1)
                # Persist progress immediately in case a later batch crashes.
                out_file.flush()
        except Exception as e:
            # Best-effort: log and move on so one bad batch doesn't kill the run.
            print(f"Batch failed: {e}")
            continue
    pbar.close()
# Script entry point: run the expansion only when executed directly.
if __name__ == "__main__":
    main()