"""
generate_edge_cases.py
======================
Targeted edge-case data generator for the 7 failure patterns found in eval:
1. ROW_NUMBER vs RANK vs DENSE_RANK (tie-breaking semantics)
2. RANK/DENSE_RANK when ties should persist
3. strftime month as INTEGER (not '%Y-%m' string)
4. SELECT column discipline (no unrequested extras)
5. LAG/LEAD period-over-period
6. HAVING vs WHERE placement
7. COUNT(DISTINCT) vs COUNT
Produces: edge_cases.jsonl (same chat format as nl2sql_cleaned_ready_to_train.jsonl)
Run: python generate_edge_cases.py
"""
import os

# Set before torch is imported so CUDA device enumeration picks it up.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,1,6,7"

import sys, json, re, hashlib
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import transformers.activations

# This shim fools AutoAWQ so it doesn't crash on the missing PytorchGELUTanh symbol.
if not hasattr(transformers.activations, 'PytorchGELUTanh'):
    transformers.activations.PytorchGELUTanh = transformers.activations.NewGELUActivation

PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
OUTPUT_FILE = "edge_cases.jsonl"
BATCH_SIZE = 8             # smaller batch: edge prompts are long
SAMPLES_PER_PATTERN = 715  # 7 patterns x 715 -> 5005 total edge samples

SYSTEM_PROMPT = (
    "You are a Senior SQL Architect. "
    "Output ONLY the SQL query. Use SQLite syntax."
)
# ── Edge-case prompt templates ───────────────────────────────────────────────
# Each entry: (pattern_tag, user_prompt_template)
# {schema} is filled at runtime with a random domain schema.
EDGE_PATTERNS = [
    # 1. ROW_NUMBER tie-breaking - the #1 failure
    ("row_number_tiebreak", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs that REQUIRE ROW_NUMBER() (not RANK or DENSE_RANK) \
because the question explicitly says "pick one winner when there is a tie" \
using a tiebreaker column (e.g. lower id, earlier date).
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- Every SQL must use ROW_NUMBER() OVER (...) not RANK().
- The OVER clause ORDER BY must include the tiebreaker column.
- WHERE rn = 1 must appear in an outer query or CTE.
- No markdown. No explanation. Just the JSON array."""),

    # 2. RANK / DENSE_RANK - when ties SHOULD persist
    ("rank_dense_rank", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where RANK() or DENSE_RANK() is the CORRECT choice \
because the question says "show all tied records at the same rank".
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- Use RANK() when question implies gaps after ties, DENSE_RANK() when no gaps.
- NL must make the tie-semantics explicit ("same rank", "tied positions").
- No markdown. No explanation. Just the JSON array."""),

    # 3. strftime integer month output
    ("strftime_integer_month", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where the question asks for a numeric month number \
(1-12), NOT a 'YYYY-MM' string.
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- SQL must use CAST(strftime('%m', <col>) AS INTEGER) to produce integer month.
- Do NOT use strftime('%Y-%m', ...) when the question asks for month number.
- NL questions must say "month number", "which month (1-12)", or similar.
- No markdown. No explanation. Just the JSON array."""),

    # 4. SELECT column discipline
    ("select_column_discipline", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where the question explicitly names ONLY the columns \
to return. The SQL must select EXACTLY those columns - no extras like avg_salary, \
row counts, or intermediate aggregates.
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- NL must say "return only X, Y, Z" or "show me only the name and total".
- SQL SELECT list must contain only those columns.
- If aggregation is needed internally (e.g. for HAVING), do NOT expose it in SELECT.
- No markdown. No explanation. Just the JSON array."""),

    # 5. LAG / LEAD period-over-period
    ("lag_lead_period", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs that require LAG() or LEAD() window functions \
for period-over-period comparison (e.g. month-over-month revenue change, \
previous order amount, next appointment date).
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- Use LAG(<col>, 1) OVER (ORDER BY ...) or LEAD(...) correctly.
- NL must imply comparison with previous or next row/period.
- No markdown. No explanation. Just the JSON array."""),

    # 6. HAVING vs WHERE
    ("having_vs_where", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs that test correct placement of filter conditions:
- Conditions on raw columns -> WHERE
- Conditions on aggregates -> HAVING
Include 4 pairs where a wrong model might put an aggregate condition in WHERE (trap).
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- SQL must never filter an aggregate (COUNT, SUM, AVG) inside WHERE.
- SQL must never put a raw column filter inside HAVING.
- No markdown. No explanation. Just the JSON array."""),

    # 7. COUNT(DISTINCT) vs COUNT
    ("count_distinct", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where the question specifically asks for \
"unique", "distinct", or "different" counts, requiring COUNT(DISTINCT col).
Also include 2 pairs where COUNT(*) is correct to reinforce the contrast.
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- When NL says "unique/distinct", SQL must use COUNT(DISTINCT <col>).
- When NL says "total orders placed" (not distinct), use COUNT(*) or COUNT(id).
- No markdown. No explanation. Just the JSON array."""),
]
# ── Helpers ──────────────────────────────────────────────────────────────────
def extract_json_array(text: str) -> str:
    text = text.strip()
    # strip code fences if model leaks them
    text = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", text, flags=re.DOTALL).strip()
    s, e = text.find("["), text.rfind("]")
    return text[s:e+1] if s != -1 and e != -1 else "[]"
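
# Illustrative behaviour of extract_json_array on a hypothetical fenced leak
# (example input is made up, not part of the pipeline):
#   extract_json_array('```json\n[{"nl": "q", "sql": "SELECT 1"}]\n```')
#   returns '[{"nl": "q", "sql": "SELECT 1"}]'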
def get_hash(text: str) -> str:
    return hashlib.md5(text.lower().strip().encode()).hexdigest()

def build_record(nl: str, sql: str, domain: str) -> dict:
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {nl}"}
        ],
        "sql": sql
    }
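
# Shape of each JSONL record written to edge_cases.jsonl (field values here are
# illustrative; the real schema text comes from SCHEMA_CONTEXT[domain]):
#   {"prompt": [{"role": "system", "content": "You are a Senior SQL Architect. ..."},
#               {"role": "user", "content": "SCHEMA: ...\nQUESTION: ..."}],
#    "sql": "SELECT ..."}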
# ── Main ─────────────────────────────────────────────────────────────────────
def main():
    print(f"Loading {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    # Keys are the *visible* device indices after CUDA_VISIBLE_DEVICES remapping.
    custom_memory = {0: "30GiB", 1: "75GiB", 2: "45GiB", 3: "45GiB"}
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        max_memory=custom_memory,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa"
    )

    domains = list(SCHEMA_CONTEXT.keys())
    seen = set()
    total = 0
    out = open(OUTPUT_FILE, "a", encoding="utf-8")

    for pattern_tag, prompt_tmpl in EDGE_PATTERNS:
        print(f"\n[PATTERN] {pattern_tag}")
        collected = 0
        domain_idx = 0
        pbar = tqdm(total=SAMPLES_PER_PATTERN, desc=pattern_tag)
        while collected < SAMPLES_PER_PATTERN:
            # Build a batch of prompts, cycling through domains
            batch_domains = []
            batch_prompts = []
            for _ in range(BATCH_SIZE):
                domain = domains[domain_idx % len(domains)]
                domain_idx += 1
                user_msg = prompt_tmpl.format(schema=SCHEMA_CONTEXT[domain])
                msgs = [
                    {"role": "system", "content": "You output only valid JSON arrays. No markdown."},
                    {"role": "user", "content": user_msg}
                ]
                batch_prompts.append(
                    tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
                )
                batch_domains.append(domain)

            inputs = tokenizer(
                batch_prompts, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)

            try:
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=2048,
                        do_sample=True,
                        temperature=0.5,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id
                    )
                responses = tokenizer.batch_decode(
                    outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
                )
                for resp, domain in zip(responses, batch_domains):
                    raw = extract_json_array(resp)
                    try:
                        pairs = json.loads(raw)
                    except Exception:
                        continue
                    for pair in pairs:
                        nl = pair.get("nl", "").strip()
                        sql = pair.get("sql", "").strip()
                        if not nl or not sql:
                            continue
                        # strip fences in sql just in case
                        sql = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", sql, flags=re.DOTALL).strip()
                        h = get_hash(nl + sql)
                        if h in seen:
                            continue
                        seen.add(h)
                        record = build_record(nl, sql, domain)
                        out.write(json.dumps(record, ensure_ascii=False) + "\n")
                        out.flush()
                        collected += 1
                        total += 1
                        pbar.update(1)
                        if collected >= SAMPLES_PER_PATTERN:
                            break
            except Exception as e:
                tqdm.write(f"[WARN] Batch failed: {e}")
                continue
        pbar.close()

    out.close()
    print(f"\nDone! {total} edge-case records saved to {OUTPUT_FILE}")

if __name__ == "__main__":
    main()