nl2sql-bench / generate_edge_cases.py
ritvik360's picture
Upload folder using huggingface_hub
a39d8ef verified
"""
generate_edge_cases.py
======================
Targeted edge-case data generator for the 6 failure patterns found in eval:
1. ROW_NUMBER vs RANK vs DENSE_RANK (tie-breaking semantics)
2. strftime month as INTEGER (not '%Y-%m' string)
3. SELECT column discipline (no unrequested extras)
4. LAG/LEAD period-over-period
5. HAVING vs WHERE placement
6. COUNT(DISTINCT) vs COUNT
Produces: edge_cases.jsonl (same chat format as nl2sql_cleaned_ready_to_train.jsonl)
Run: python generate_edge_cases.py
"""
import os, sys, json, re, hashlib
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, BitsAndBytesConfig
import torch
import transformers.activations
# Compatibility shim: this line fools AutoAWQ so that it does not crash.
# When the installed transformers build has no `PytorchGELUTanh` attribute,
# alias it to `NewGELUActivation` so the attribute lookup succeeds.
if not hasattr(transformers.activations, 'PytorchGELUTanh'):
    transformers.activations.PytorchGELUTanh = transformers.activations.NewGELUActivation
# Limit the process to these GPU ids (listed in this order).
# NOTE(review): this is assigned after `import torch`; presumably it still takes
# effect because CUDA is only initialized on first use -- confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "3,1,6,7"
# Put this file's directory at the front of sys.path before the project import
# below (presumably the `data_factory` package lives next to this script).
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# Mapping used below as domain name -> schema text (keys are iterated as domains).
from data_factory.schemas import SCHEMA_CONTEXT
# bitsandbytes settings handed to `from_pretrained`: 4-bit weights using the
# "nf4" quant type with double quantization, computing in bfloat16.
quantization_config = BitsAndBytesConfig(
    # storage format
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    # compute precision
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Generator model and the JSONL file the generated records are written to.
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
OUTPUT_FILE = "edge_cases.jsonl"

# Prompts per generate() call; kept small because the edge-case prompts are long.
BATCH_SIZE = 8
# Target number of unique records per pattern (715 x 7 templates = 5005 in total).
SAMPLES_PER_PATTERN = 715

# System message stored in every emitted training record.
SYSTEM_PROMPT = "You are a Senior SQL Architect. Output ONLY the SQL query. Use SQLite syntax."
# ── Edge-case prompt templates ──────────────────────────────────────────────
# Each entry: (pattern_tag, user_prompt_template)
# {schema} is filled at runtime with a domain schema; domains are cycled in order, not sampled randomly.
# List of (pattern_tag, prompt_template) tuples. Each template is rendered with
# str.format(schema=...): `{schema}` is the only placeholder and the doubled
# braces `{{ }}` come out as literal JSON braces. A backslash at the end of a
# line inside a template joins it with the next line (no newline is emitted).
EDGE_PATTERNS = [
    # 1. ROW_NUMBER tie-breaking — the #1 failure
    ("row_number_tiebreak", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs that REQUIRE ROW_NUMBER() (not RANK or DENSE_RANK) \
because the question explicitly says "pick one winner when there is a tie" \
using a tiebreaker column (e.g. lower id, earlier date).
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- Every SQL must use ROW_NUMBER() OVER (...) not RANK().
- The OVER clause ORDER BY must include the tiebreaker column.
- WHERE rn = 1 must appear in an outer query or CTE.
- No markdown. No explanation. Just the JSON array."""),
    # 2. RANK / DENSE_RANK — when ties SHOULD persist
    ("rank_dense_rank", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where RANK() or DENSE_RANK() is the CORRECT choice \
because the question says "show all tied records at the same rank".
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- Use RANK() when question implies gaps after ties, DENSE_RANK() when no gaps.
- NL must make the tie-semantics explicit ("same rank", "tied positions").
- No markdown. No explanation. Just the JSON array."""),
    # 3. strftime integer month output
    ("strftime_integer_month", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where the question asks for a numeric month number \
(1–12), NOT a 'YYYY-MM' string.
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- SQL must use CAST(strftime('%m', <col>) AS INTEGER) to produce integer month.
- Do NOT use strftime('%Y-%m', ...) when the question asks for month number.
- NL questions must say "month number", "which month (1–12)", or similar.
- No markdown. No explanation. Just the JSON array."""),
    # 4. SELECT column discipline
    ("select_column_discipline", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where the question explicitly names ONLY the columns \
to return. The SQL must select EXACTLY those columns — no extras like avg_salary, \
row counts, or intermediate aggregates.
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- NL must say "return only X, Y, Z" or "show me only the name and total".
- SQL SELECT list must contain only those columns.
- If aggregation is needed internally (e.g. for HAVING), do NOT expose it in SELECT.
- No markdown. No explanation. Just the JSON array."""),
    # 5. LAG / LEAD period-over-period
    ("lag_lead_period", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs that require LAG() or LEAD() window functions \
for period-over-period comparison (e.g. month-over-month revenue change, \
previous order amount, next appointment date).
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- Use LAG(<col>, 1) OVER (ORDER BY ...) or LEAD(...) correctly.
- NL must imply comparison with previous or next row/period.
- No markdown. No explanation. Just the JSON array."""),
    # 6. HAVING vs WHERE
    ("having_vs_where", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs that test correct placement of filter conditions:
- Conditions on raw columns → WHERE
- Conditions on aggregates → HAVING
Include 4 pairs where a wrong model might put an aggregate condition in WHERE (trap).
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- SQL must never filter an aggregate (COUNT, SUM, AVG) inside WHERE.
- SQL must never put a raw column filter inside HAVING.
- No markdown. No explanation. Just the JSON array."""),
    # 7. COUNT(DISTINCT) vs COUNT
    ("count_distinct", """SCHEMA:
{schema}
Generate exactly 8 NL2SQL pairs where the question specifically asks for \
"unique", "distinct", or "different" counts — requiring COUNT(DISTINCT col).
Also include 2 pairs where COUNT(*) is correct to reinforce the contrast.
Output ONLY a valid JSON array:
[
{{"nl": "...", "sql": "SELECT ..."}},
...
]
Rules:
- When NL says "unique/distinct", SQL must use COUNT(DISTINCT <col>).
- When NL says "total orders placed" (not distinct), use COUNT(*) or COUNT(id).
- No markdown. No explanation. Just the JSON array."""),
]
# ── Helpers ──────────────────────────────────────────────────────────────────
def extract_json_array(text: str) -> str:
    """Return the outermost ``[...]`` span of a model response.

    Markdown code fences (```` ``` ```` or ```` ```json ````) are unwrapped
    first, then everything from the first ``[`` to the last ``]`` is returned.
    The result is not validated as JSON; the caller is expected to parse it.

    Args:
        text: Raw decoded model output.

    Returns:
        The bracketed substring, or ``"[]"`` when there is no ``[`` followed
        by a later ``]`` (no brackets, only one of them, or ``]`` before ``[``).
    """
    text = text.strip()
    # Strip code fences if the model leaks them.
    text = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", text, flags=re.DOTALL).strip()
    start, end = text.find("["), text.rfind("]")
    # `end < start` also covers end == -1; without it a "]" that appears
    # before the first "[" would produce an empty slice instead of "[]".
    if start == -1 or end < start:
        return "[]"
    return text[start:end + 1]
def get_hash(text: str) -> str:
    """Return the MD5 hex digest of ``text`` after lower-casing and trimming.

    Used as a de-duplication key, so inputs that differ only in letter case
    or surrounding whitespace map to the same digest.
    """
    normalized = text.lower().strip()
    digest = hashlib.md5(normalized.encode())
    return digest.hexdigest()
def build_record(nl: str, sql: str, domain: str) -> dict:
    """Assemble one training record for a question/SQL pair.

    Args:
        nl: Natural-language question.
        sql: SQL answer, stored under the ``"sql"`` key.
        domain: Key into ``SCHEMA_CONTEXT`` selecting the schema text that is
            embedded in the user message.

    Returns:
        A dict with a two-message ``"prompt"`` (system, then user) and ``"sql"``.
    """
    system_turn = {"role": "system", "content": SYSTEM_PROMPT}
    user_turn = {
        "role": "user",
        "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {nl}",
    }
    return {"prompt": [system_turn, user_turn], "sql": sql}
# ── Main ─────────────────────────────────────────────────────────────────────
def _parse_pairs(response: str) -> list:
    """Extract the well-formed (nl, sql) pairs from one decoded model response.

    Anything that is not a JSON array of objects with non-empty string "nl"
    and "sql" values is skipped item by item, so one malformed element never
    discards the rest of the response (or the rest of the batch).

    Args:
        response: Decoded generation for a single prompt.

    Returns:
        A list of ``(nl, sql)`` tuples with whitespace trimmed and any
        markdown fences removed from the SQL. Empty if nothing is usable.
    """
    try:
        payload = json.loads(extract_json_array(response))
    except ValueError:
        # json.JSONDecodeError is a ValueError: malformed or truncated output.
        return []
    if not isinstance(payload, list):
        return []
    pairs = []
    for item in payload:
        if not isinstance(item, dict):
            continue
        nl, sql = item.get("nl"), item.get("sql")
        if not isinstance(nl, str) or not isinstance(sql, str):
            continue
        nl = nl.strip()
        # Strip fences in the SQL just in case, then re-check for emptiness.
        sql = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", sql.strip(), flags=re.DOTALL).strip()
        if nl and sql:
            pairs.append((nl, sql))
    return pairs


def main():
    """Generate edge-case NL2SQL records with the LLM and append them to OUTPUT_FILE.

    For every entry of EDGE_PATTERNS, batches of BATCH_SIZE prompts are built
    by cycling through the SCHEMA_CONTEXT domains, sampled from the model, and
    parsed. Each new, non-duplicate pair is written as one JSON line until
    SAMPLES_PER_PATTERN records have been collected for that pattern.

    A failing generation batch is logged and skipped. If a pattern produces no
    new record for many consecutive batches it is abandoned with a warning
    instead of looping forever. Errors while writing the output file propagate.
    """
    print(f"Loading {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    # Per-device memory caps, keyed by logical device index.
    custom_memory = {0: "30GiB", 1: "75GiB", 2: "45GiB", 3: "45GiB"}
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        max_memory=custom_memory,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",
    )
    domains = list(SCHEMA_CONTEXT.keys())
    seen = set()  # hashes of nl+sql already written during this run
    total = 0
    # Give up on a pattern after this many consecutive batches with no new record.
    max_stalled_batches = 25

    # NOTE: the file is opened in append mode but `seen` starts empty, so
    # de-duplication does not cover records written by earlier runs.
    with open(OUTPUT_FILE, "a", encoding="utf-8") as out:
        for pattern_tag, prompt_tmpl in EDGE_PATTERNS:
            print(f"\n[PATTERN] {pattern_tag}")
            collected = 0
            domain_idx = 0
            stalled = 0
            with tqdm(total=SAMPLES_PER_PATTERN, desc=pattern_tag) as pbar:
                while collected < SAMPLES_PER_PATTERN:
                    collected_before = collected
                    # Build a batch of prompts, cycling through domains.
                    batch_domains = []
                    batch_prompts = []
                    for _ in range(BATCH_SIZE):
                        domain = domains[domain_idx % len(domains)]
                        domain_idx += 1
                        user_msg = prompt_tmpl.format(schema=SCHEMA_CONTEXT[domain])
                        msgs = [
                            {"role": "system", "content": "You output only valid JSON arrays. No markdown."},
                            {"role": "user", "content": user_msg},
                        ]
                        batch_prompts.append(
                            tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
                        )
                        batch_domains.append(domain)
                    inputs = tokenizer(
                        batch_prompts, return_tensors="pt", padding=True, truncation=True
                    ).to(model.device)
                    try:
                        with torch.no_grad():
                            outputs = model.generate(
                                **inputs,
                                max_new_tokens=2048,
                                do_sample=True,
                                temperature=0.5,
                                top_p=0.9,
                                pad_token_id=tokenizer.eos_token_id,
                            )
                        # Drop the (left-padded) prompt tokens, keep only the completion.
                        responses = tokenizer.batch_decode(
                            outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
                        )
                    except Exception as e:
                        # Best effort: a failed batch is skipped, not fatal.
                        tqdm.write(f"[WARN] Batch failed: {e}")
                        responses = []
                    for resp, domain in zip(responses, batch_domains):
                        # Stop consuming responses once the quota is met, so the
                        # pattern never overshoots SAMPLES_PER_PATTERN.
                        if collected >= SAMPLES_PER_PATTERN:
                            break
                        for nl, sql in _parse_pairs(resp):
                            h = get_hash(nl + sql)
                            if h in seen:
                                continue
                            seen.add(h)
                            record = build_record(nl, sql, domain)
                            out.write(json.dumps(record, ensure_ascii=False) + "\n")
                            # Flush per record so a crash mid-run loses nothing.
                            out.flush()
                            collected += 1
                            total += 1
                            pbar.update(1)
                            if collected >= SAMPLES_PER_PATTERN:
                                break
                    if collected > collected_before:
                        stalled = 0
                        continue
                    stalled += 1
                    if stalled >= max_stalled_batches:
                        tqdm.write(
                            f"[WARN] {pattern_tag}: no new records in {stalled} consecutive "
                            f"batches; stopping at {collected}/{SAMPLES_PER_PATTERN}"
                        )
                        break
    print(f"\nDone! {total} edge-case records saved to {OUTPUT_FILE}")


if __name__ == "__main__":
    main()