# nl2sql-bench / value_swapper.py
# ritvik360's picture
# Upload folder using huggingface_hub
# a39d8ef verified
import json
import re
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent))
from data_factory.templates import ALL_TEMPLATES
# Categorical value swaps, grouped by template domain.
#
# Each rule is a 4-tuple:
#   (sql_pattern, nl_pattern, sql_replacements, nl_replacements)
# - sql_pattern: regex for the quoted literal as it appears in the SQL text.
# - nl_pattern: regex for the same value as it appears in the natural-language
#   question and description.
# - sql_replacements / nl_replacements: parallel lists; entry i of each is
#   applied together to produce one swapped variant.
# NOTE(review): the values are presumably the categorical values allowed by
# each domain's schema -- confirm against the schema definitions.
SWAP_RULES = {
    "ecommerce": [
        (
            r"'gold'",
            r"gold",
            ["'silver'", "'bronze'"],
            ["silver", "bronze"],
        ),
        (
            r"'delivered'",
            r"delivered",
            ["'pending'", "'processing'", "'shipped'", "'cancelled'"],
            ["pending", "processing", "shipped", "cancelled"],
        ),
        (
            r"'India'",
            r"India",
            ["'USA'", "'Germany'", "'UK'", "'Canada'"],
            ["USA", "Germany", "UK", "Canada"],
        ),
    ],
    "healthcare": [
        (
            r"'severe'",
            r"severe",
            ["'mild'", "'moderate'"],
            ["mild", "moderate"],
        ),
        (
            r"'completed'",
            r"completed",
            # NOTE(review): SQL uses 'no_show' while the NL form is "no-show";
            # this looks deliberate (natural spelling) -- confirm.
            ["'scheduled'", "'cancelled'", "'no_show'"],
            ["scheduled", "cancelled", "no-show"],
        ),
    ],
    "finance": [
        (
            r"'active'",
            r"active",
            ["'dormant'", "'closed'"],
            ["dormant", "closed"],
        ),
        (
            r"'credit'",
            r"credit",
            ["'debit'"],
            ["debit"],
        ),
        (
            r"'verified'",
            r"verified",
            ["'pending'", "'rejected'"],
            ["pending", "rejected"],
        ),
    ],
    "hr": [
        (
            r"'active'",
            r"active",
            ["'resigned'", "'terminated'"],
            ["resigned", "terminated"],
        ),
    ],
}
def generate_swaps(templates=None, rules=None):
    """Expand templates with categorical value-swapped variants.

    Every input template is kept as-is in the output. For each rule of the
    template's domain whose SQL pattern occurs in the template's SQL, one
    extra template is appended per (sql_replacement, nl_replacement) pair,
    with the value swapped in ``sql``, ``base_nl`` and (when present)
    ``description``, and an ``id`` of the form ``<orig_id>_swap_<nl_value>``.

    Args:
        templates: Iterable of template dicts. Each is expected to carry
            ``domain``, ``sql`` and ``base_nl``; ``description`` and ``id``
            are optional. Defaults to the module's ``ALL_TEMPLATES``.
        rules: Mapping of domain -> list of
            ``(sql_pattern, nl_pattern, sql_replacements, nl_replacements)``
            tuples. Patterns are regexes matched case-insensitively.
            Defaults to the module's ``SWAP_RULES``.

    Returns:
        A new list: the original template objects (same identity, unmodified)
        interleaved with their swapped variants, in input order. Variants are
        shallow copies of their source template.
    """
    if templates is None:
        templates = ALL_TEMPLATES
    if rules is None:
        rules = SWAP_RULES

    expanded_templates = []
    for template in templates:
        expanded_templates.append(template)  # Keep the original

        domain_rules = rules.get(template.get("domain"))
        sql = template.get("sql")
        base_nl = template.get("base_nl")
        # Templates without applicable rules or without text to rewrite are
        # passed through untouched instead of raising KeyError.
        if not domain_rules or not isinstance(sql, str) or not isinstance(base_nl, str):
            continue

        description = template.get("description")
        base_id = template.get("id", "temp")

        for sql_target, nl_target, sql_replacements, nl_replacements in domain_rules:
            # Compile once per rule rather than once per replacement.
            sql_pattern = re.compile(sql_target, re.IGNORECASE)
            if not sql_pattern.search(sql):
                continue

            nl_pattern = re.compile(nl_target, re.IGNORECASE)
            if not nl_pattern.search(base_nl):
                # The question never mentions the value, so swapping only the
                # SQL would pair a changed query with an unchanged question
                # (a mislabeled example). Skip this rule for this template.
                continue

            for sql_repl, nl_repl in zip(sql_replacements, nl_replacements):
                new_template = template.copy()

                # Function replacements insert the text literally; a plain
                # string would have backslashes / group references expanded.
                new_template["sql"] = sql_pattern.sub(
                    lambda _m, text=sql_repl: text, sql
                )
                new_template["base_nl"] = nl_pattern.sub(
                    lambda _m, text=nl_repl: text, base_nl
                )
                if isinstance(description, str):
                    new_template["description"] = nl_pattern.sub(
                        lambda _m, text=nl_repl: text, description
                    )

                # Create a unique ID
                new_template["id"] = f"{base_id}_swap_{nl_repl.replace(' ', '_')}"
                expanded_templates.append(new_template)

    return expanded_templates
if __name__ == "__main__":
    # Script entry point: expand every template, report the before/after
    # counts, and write the expanded list as indented JSON into the current
    # working directory.
    expanded = generate_swaps()

    print(f"Original Templates: {len(ALL_TEMPLATES)}")
    print(f"After Value Swapping: {len(expanded)}")

    serialized = json.dumps(expanded, indent=2)
    with open("swapped_templates.json", "w") as out_file:
        out_file.write(serialized)

    print("Saved to swapped_templates.json")