""" generate_edge_cases.py ====================== Targeted edge-case data generator for the 4 failure patterns found in eval: 1. ROW_NUMBER vs RANK vs DENSE_RANK (tie-breaking semantics) 2. strftime month as INTEGER (not '%Y-%m' string) 3. SELECT column discipline (no unrequested extras) 4. LAG/LEAD period-over-period 5. HAVING vs WHERE placement 6. COUNT(DISTINCT) vs COUNT Produces: edge_cases.jsonl (same chat format as nl2sql_cleaned_ready_to_train.jsonl) Run: python generate_edge_cases.py """ import os, sys, json, re, hashlib from tqdm import tqdm from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, BitsAndBytesConfig import torch import transformers.activations # Yeh line AutoAWQ ko bewakoof banayegi taaki wo crash na ho if not hasattr(transformers.activations, 'PytorchGELUTanh'): transformers.activations.PytorchGELUTanh = transformers.activations.NewGELUActivation os.environ["CUDA_VISIBLE_DEVICES"] = "3,1,6,7" PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__)) if PROJECT_ROOT not in sys.path: sys.path.insert(0, PROJECT_ROOT) from data_factory.schemas import SCHEMA_CONTEXT quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, bnb_4bit_quant_type="nf4" ) MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct" OUTPUT_FILE = "edge_cases.jsonl" BATCH_SIZE = 8 # smaller — edge prompts are long SAMPLES_PER_PATTERN = 715 # ~6 batches per pattern → 5005 total edge samples SYSTEM_PROMPT = ( "You are a Senior SQL Architect. " "Output ONLY the SQL query. Use SQLite syntax." ) # ── Edge-case prompt templates ────────────────────────────────────────────── # Each entry: (pattern_tag, user_prompt_template) # {schema} is filled at runtime with a random domain schema. EDGE_PATTERNS = [ # 1. ROW_NUMBER tie-breaking — the #1 failure ("row_number_tiebreak", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs that REQUIRE ROW_NUMBER() (not RANK or DENSE_RANK) \ because the question explicitly says "pick one winner when there is a tie" \ using a tiebreaker column (e.g. lower id, earlier date). Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - Every SQL must use ROW_NUMBER() OVER (...) not RANK(). - The OVER clause ORDER BY must include the tiebreaker column. - WHERE rn = 1 must appear in an outer query or CTE. - No markdown. No explanation. Just the JSON array."""), # 2. RANK / DENSE_RANK — when ties SHOULD persist ("rank_dense_rank", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs where RANK() or DENSE_RANK() is the CORRECT choice \ because the question says "show all tied records at the same rank". Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - Use RANK() when question implies gaps after ties, DENSE_RANK() when no gaps. - NL must make the tie-semantics explicit ("same rank", "tied positions"). - No markdown. No explanation. Just the JSON array."""), # 3. strftime integer month output ("strftime_integer_month", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs where the question asks for a numeric month number \ (1–12), NOT a 'YYYY-MM' string. Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - SQL must use CAST(strftime('%m', ) AS INTEGER) to produce integer month. - Do NOT use strftime('%Y-%m', ...) when the question asks for month number. - NL questions must say "month number", "which month (1–12)", or similar. - No markdown. No explanation. Just the JSON array."""), # 4. SELECT column discipline ("select_column_discipline", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs where the question explicitly names ONLY the columns \ to return. The SQL must select EXACTLY those columns — no extras like avg_salary, \ row counts, or intermediate aggregates. Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - NL must say "return only X, Y, Z" or "show me only the name and total". - SQL SELECT list must contain only those columns. - If aggregation is needed internally (e.g. for HAVING), do NOT expose it in SELECT. - No markdown. No explanation. Just the JSON array."""), # 5. LAG / LEAD period-over-period ("lag_lead_period", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs that require LAG() or LEAD() window functions \ for period-over-period comparison (e.g. month-over-month revenue change, \ previous order amount, next appointment date). Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - Use LAG(, 1) OVER (ORDER BY ...) or LEAD(...) correctly. - NL must imply comparison with previous or next row/period. - No markdown. No explanation. Just the JSON array."""), # 6. HAVING vs WHERE ("having_vs_where", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs that test correct placement of filter conditions: - Conditions on raw columns → WHERE - Conditions on aggregates → HAVING Include 4 pairs where a wrong model might put an aggregate condition in WHERE (trap). Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - SQL must never filter an aggregate (COUNT, SUM, AVG) inside WHERE. - SQL must never put a raw column filter inside HAVING. - No markdown. No explanation. Just the JSON array."""), # 7. COUNT(DISTINCT) vs COUNT ("count_distinct", """SCHEMA: {schema} Generate exactly 8 NL2SQL pairs where the question specifically asks for \ "unique", "distinct", or "different" counts — requiring COUNT(DISTINCT col). Also include 2 pairs where COUNT(*) is correct to reinforce the contrast. Output ONLY a valid JSON array: [ {{"nl": "...", "sql": "SELECT ..."}}, ... ] Rules: - When NL says "unique/distinct", SQL must use COUNT(DISTINCT ). - When NL says "total orders placed" (not distinct), use COUNT(*) or COUNT(id). - No markdown. No explanation. Just the JSON array."""), ] # ── Helpers ────────────────────────────────────────────────────────────────── def extract_json_array(text: str) -> str: text = text.strip() # strip code fences if model leaks them text = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", text, flags=re.DOTALL).strip() s, e = text.find("["), text.rfind("]") return text[s:e+1] if s != -1 and e != -1 else "[]" def get_hash(text: str) -> str: return hashlib.md5(text.lower().strip().encode()).hexdigest() def build_record(nl: str, sql: str, domain: str) -> dict: return { "prompt": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {nl}"} ], "sql": sql } # ── Main ───────────────────────────────────────────────────────────────────── def main(): print(f"Loading {MODEL_NAME}...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left") tokenizer.pad_token = tokenizer.eos_token custom_memory = {0:"30GiB",1:"75GiB",2:"45GiB",3:"45GiB"} model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, device_map="auto", max_memory=custom_memory, quantization_config=quantization_config, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, attn_implementation = "sdpa" ) domains = list(SCHEMA_CONTEXT.keys()) seen = set() total = 0 out = open(OUTPUT_FILE, "a", encoding="utf-8") for pattern_tag, prompt_tmpl in EDGE_PATTERNS: print(f"\n[PATTERN] {pattern_tag}") collected = 0 domain_idx = 0 pbar = tqdm(total=SAMPLES_PER_PATTERN, desc=pattern_tag) while collected < SAMPLES_PER_PATTERN: # Build a batch of prompts, cycling through domains batch_domains = [] batch_prompts = [] for _ in range(BATCH_SIZE): domain = domains[domain_idx % len(domains)] domain_idx += 1 user_msg = prompt_tmpl.format(schema=SCHEMA_CONTEXT[domain]) msgs = [ {"role": "system", "content": "You output only valid JSON arrays. No markdown."}, {"role": "user", "content": user_msg} ] batch_prompts.append( tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True) ) batch_domains.append(domain) inputs = tokenizer( batch_prompts, return_tensors="pt", padding=True, truncation=True ).to(model.device) try: with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=2048, do_sample=True, temperature=0.5, top_p=0.9, pad_token_id=tokenizer.eos_token_id ) responses = tokenizer.batch_decode( outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True ) for resp, domain in zip(responses, batch_domains): raw = extract_json_array(resp) try: pairs = json.loads(raw) except Exception: continue for pair in pairs: nl = pair.get("nl", "").strip() sql = pair.get("sql", "").strip() if not nl or not sql: continue # strip fences in sql just in case sql = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", sql, flags=re.DOTALL).strip() h = get_hash(nl + sql) if h in seen: continue seen.add(h) record = build_record(nl, sql, domain) out.write(json.dumps(record, ensure_ascii=False) + "\n") out.flush() collected += 1 total += 1 pbar.update(1) if collected >= SAMPLES_PER_PATTERN: break except Exception as e: tqdm.write(f"[WARN] Batch failed: {e}") continue pbar.close() out.close() print(f"\nDone! {total} edge-case records saved to {OUTPUT_FILE}") if __name__ == "__main__": main()