# Source: nl2sql-bench / data_factory / augmentor.py
# Uploaded by ritvik360 (avatar image omitted)
# Commit message: Upload folder using huggingface_hub
# Commit: a39d8ef (verified)
"""
data_factory/augmentor.py
==========================
Rule-based Natural Language augmentation.

These transformations operate ONLY on NL question strings.
SQL is NEVER modified — it always comes from the verified template library.

Three augmentation strategies:
1. Synonym replacement — swaps domain words with semantically equivalent ones
2. Condition reordering — shuffles conjunctive phrases (preserves meaning)
3. Date normalisation — expresses dates in different formats when applicable
"""
from __future__ import annotations
import random
import re
from copy import deepcopy
from typing import Iterator
# ─────────────────────────────────────────────────────────────────────────────
# SYNONYM DICTIONARIES
# ─────────────────────────────────────────────────────────────────────────────
# Format: "canonical_term": ["synonym1", "synonym2", ...]
# All synonyms are semantically equivalent in a business context.
# Lookup table: canonical term -> interchangeable wordings.
# Keys are matched case-insensitively on word boundaries by the synonym pass.
# Insertion order is significant: the key list is shuffled with a seeded RNG,
# so reordering entries changes which variants a given seed produces.
_SYNONYMS: dict[str, list[str]] = {
    # --- Action verbs that typically open a question -------------------------
    "list": ["show", "display", "return", "give me", "find", "retrieve"],
    "show": ["list", "display", "return", "get", "retrieve"],
    "find": ["identify", "locate", "get", "show", "retrieve", "look up"],
    "return": ["show", "give", "list", "retrieve", "output"],
    "retrieve": ["fetch", "get", "return", "pull"],
    "get": ["retrieve", "fetch", "return", "give me"],

    # --- Aggregation vocabulary ----------------------------------------------
    "total": ["sum", "aggregate", "overall", "cumulative", "combined"],
    "average": ["mean", "avg", "typical"],
    "count": ["number of", "quantity of", "how many"],
    "highest": ["largest", "maximum", "top", "greatest"],
    "lowest": ["smallest", "minimum", "least"],

    # --- Commerce / general business -----------------------------------------
    "customer": ["client", "buyer", "user", "account holder", "shopper"],
    "customers": ["clients", "buyers", "users", "account holders", "shoppers"],
    "product": ["item", "SKU", "article", "goods"],
    "products": ["items", "SKUs", "articles", "goods"],
    "order": ["purchase", "transaction", "sale"],
    "orders": ["purchases", "transactions", "sales"],
    "revenue": ["income", "earnings", "sales amount", "money earned"],
    "spending": ["expenditure", "spend", "purchases"],
    "amount": ["value", "sum", "total", "figure"],
    "price": ["cost", "rate", "charge", "fee"],

    # --- Healthcare ----------------------------------------------------------
    "patient": ["person", "individual", "case"],
    "patients": ["persons", "individuals", "cases"],
    "doctor": ["physician", "clinician", "practitioner", "specialist"],
    "doctors": ["physicians", "clinicians", "practitioners"],
    "appointment": ["visit", "consultation", "session"],
    "appointments": ["visits", "consultations", "sessions"],
    "medication": ["drug", "medicine", "pharmaceutical", "prescription drug"],
    "medications": ["drugs", "medicines", "pharmaceuticals"],
    "diagnosis": ["condition", "finding", "medical finding"],

    # --- Finance -------------------------------------------------------------
    "account": ["bank account", "profile", "portfolio entry"],
    "accounts": ["bank accounts", "profiles"],
    "loan": ["credit", "borrowing", "debt instrument"],
    "loans": ["credits", "borrowings", "debt instruments"],
    "transaction": ["transfer", "payment", "operation", "activity"],
    "transactions": ["transfers", "payments", "operations"],
    "balance": ["funds", "available amount", "account balance"],

    # --- HR ------------------------------------------------------------------
    "employee": ["staff member", "worker", "team member", "headcount"],
    "employees": ["staff", "workers", "team members", "workforce"],
    "department": ["team", "division", "unit", "group"],
    "departments": ["teams", "divisions", "units"],
    "salary": ["pay", "compensation", "remuneration", "earnings"],
    "project": ["initiative", "program", "assignment", "engagement"],
    "projects": ["initiatives", "programs", "assignments"],

    # --- Adjectives, qualifiers and ordering phrases -------------------------
    "active": ["current", "ongoing", "live", "existing"],
    "delivered": ["completed", "fulfilled", "received"],
    "cancelled": ["voided", "aborted", "terminated"],
    "alphabetically": ["by name", "in alphabetical order", "A to Z"],
    "descending": [
        "from highest to lowest",
        "in decreasing order",
        "largest first",
    ],
    "ascending": [
        "from lowest to highest",
        "in increasing order",
        "smallest first",
    ],
    "distinct": ["unique", "different"],
    "in stock": ["available", "with available inventory", "not out of stock"],
}
# ─────────────────────────────────────────────────────────────────────────────
# DATE PHRASE PATTERNS
# These will be replaced with alternative date expressions.
# ─────────────────────────────────────────────────────────────────────────────
# Ordered (phrase, alternatives) pairs consumed by the date-variation pass.
# The list is walked top to bottom, so the full ISO dates are tried before the
# shorter "in <year>" phrases.
_DATE_ALTERNATES: list[tuple[str, list[str]]] = [
    # Full ISO dates (first day of a year)
    (
        "2024-01-01",
        ["January 1st 2024", "Jan 1, 2024", "the start of 2024", "2024 start"],
    ),
    (
        "2023-01-01",
        ["January 1st 2023", "Jan 1, 2023", "the start of 2023"],
    ),
    (
        "2025-01-01",
        ["January 1st 2025", "the start of 2025"],
    ),
    # Fiscal-quarter shorthand
    ("Q1", ["the first quarter", "January through March", "Jan-Mar"]),
    ("Q2", ["the second quarter", "April through June", "Apr-Jun"]),
    ("Q3", ["the third quarter", "July through September", "Jul-Sep"]),
    ("Q4", ["the fourth quarter", "October through December", "Oct-Dec"]),
    # Whole-year phrases
    ("in 2024", ["during 2024", "throughout 2024", "for the year 2024"]),
    ("in 2023", ["during 2023", "throughout 2023", "for the year 2023"]),
]
# ─────────────────────────────────────────────────────────────────────────────
# CONDITION REORDERING
# Splits on "and" between two conditions and reverses them.
# ─────────────────────────────────────────────────────────────────────────────
def _reorder_conditions(text: str, rng: random.Random) -> str:
    """
    Swap the text on either side of the first clause connector.

    Recognised connectors (whole words, case-insensitive) are "and",
    "who are", "that are" and "with". When one is present, the swap is
    applied when ``rng.random()`` is at most 0.5 (about half the time);
    otherwise `text` is returned unchanged. This is a purely textual swap:
    no grammatical analysis is performed.

    Example:
        "Show active customers and pending orders?"
        → "Pending orders and Show active customers?"

    Parameters
    ----------
    text : str
        The question to rewrite.
    rng : random.Random
        Source of randomness. One draw is consumed only when a connector
        is found.

    Returns
    -------
    str
        The rewritten question, or `text` itself when no connector is found,
        the random draw says "keep", or one side of the connector is empty.
    """
    # Set sentence-final punctuation aside so it stays at the end of the
    # rewritten sentence instead of travelling with the moved clause.
    body = text.rstrip()
    core = body.rstrip(".?!")
    terminal = body[len(core):]

    match = re.search(r'\b(?:and|who are|that are|with)\b', core, re.IGNORECASE)
    if match is None or rng.random() > 0.5:
        return text

    before = core[:match.start()].strip()
    after = core[match.end():].strip()
    # A connector at the very start or end has nothing to swap with; swapping
    # would leave a dangling connector such as "More and ".
    if not before or not after:
        return text

    connector = match.group(0).lower()
    swapped = f"{after} {connector} {before}"

    # The moved clause now opens the sentence, so capitalise its first letter.
    if not swapped[0].isupper():
        swapped = swapped[0].upper() + swapped[1:]

    if terminal:
        # Avoid doubled punctuation when the clause that moved to the end
        # already finishes with a period (e.g. an abbreviation).
        swapped = swapped.rstrip(".?!") + terminal
    return swapped
# ─────────────────────────────────────────────────────────────────────────────
# SYNONYM REPLACEMENT
# ─────────────────────────────────────────────────────────────────────────────
def _apply_synonyms(
    text: str,
    rng: random.Random,
    max_replacements: int = 3,
    synonyms: "dict[str, list[str]] | None" = None,
) -> str:
    """
    Replace up to `max_replacements` words/phrases in `text` with synonyms.

    Keys of the synonym table are visited in a shuffled order. For each key
    that occurs in the text (case-insensitive, whole-word match) there is a
    50% chance that its first occurrence is replaced by a randomly chosen
    synonym. If the matched text starts with an upper-case letter, the first
    letter of the synonym is upper-cased as well.

    Text inserted by an earlier replacement is never matched again, so a
    synonym cannot itself be rewritten by a later key (for example
    "customer" -> "account holder" will not become "profile holder").

    Parameters
    ----------
    text : str
        The question to rewrite.
    rng : random.Random
        Source of randomness (key order, replacement decision, synonym pick).
    max_replacements : int
        Upper bound on the number of replacements; values <= 0 leave the
        text unchanged.
    synonyms : dict[str, list[str]] | None
        Table mapping a canonical term to its synonyms. Defaults to the
        module-level ``_SYNONYMS`` table.

    Returns
    -------
    str
        The text with zero or more terms replaced.
    """
    table = _SYNONYMS if synonyms is None else synonyms

    # Shuffle the keys so each call visits replacement targets in a new order.
    keys = list(table)
    rng.shuffle(keys)

    # The text is kept as (chunk, frozen) pieces. Frozen pieces are synonyms
    # that were already inserted and must not be searched again.
    segments: list[tuple[str, bool]] = [(text, False)]
    replacements_done = 0

    for canonical in keys:
        if replacements_done >= max_replacements:
            break
        options = table[canonical]
        if not canonical or not options:
            continue

        # Case-insensitive match on word boundaries.
        pattern = re.compile(r'\b' + re.escape(canonical) + r'\b', re.IGNORECASE)

        # Find the leftmost occurrence inside text that is still original.
        hit = None
        for index, (chunk, frozen) in enumerate(segments):
            if frozen:
                continue
            found = pattern.search(chunk)
            if found:
                hit = (index, found)
                break

        # The random draw happens only when the term is present, which keeps
        # the RNG stream independent of terms that do not occur.
        if hit is None or rng.random() >= 0.5:
            continue

        index, found = hit
        chunk = segments[index][0]
        replacement = rng.choice(options)
        # Preserve the casing of the first character of the matched text.
        if replacement and found.group(0)[0].isupper():
            replacement = replacement[0].upper() + replacement[1:]

        segments[index:index + 1] = [
            (chunk[:found.start()], False),
            (replacement, True),
            (chunk[found.end():], False),
        ]
        replacements_done += 1

    return "".join(chunk for chunk, _ in segments)
# ─────────────────────────────────────────────────────────────────────────────
# DATE FORMAT VARIATION
# ─────────────────────────────────────────────────────────────────────────────
def _vary_dates(
    text: str,
    rng: random.Random,
    alternates: "list[tuple[str, list[str]]] | None" = None,
) -> str:
    """
    Replace date phrases in `text` with alternate representations.

    Each (phrase, alternatives) pair is tried in order. A phrase is matched
    case-insensitively and only as a standalone token: it must not be glued
    to a preceding or following word character, nor be followed by a
    hyphen-digit continuation. That prevents "in 2024" from firing inside
    "begin 2024", "Q1" inside "Q10", or a year phrase inside an ISO date such
    as "in 2024-01-01". When a phrase is present, its first occurrence is
    replaced with a randomly chosen alternative with probability 0.6.

    Parameters
    ----------
    text : str
        The question to rewrite.
    rng : random.Random
        Source of randomness. Draws are consumed only for phrases that occur.
    alternates : list[tuple[str, list[str]]] | None
        Ordered (phrase, alternatives) pairs. Defaults to the module-level
        ``_DATE_ALTERNATES`` table.

    Returns
    -------
    str
        The text with zero or more date phrases rewritten.
    """
    table = _DATE_ALTERNATES if alternates is None else alternates
    result = text
    for phrase, options in table:
        if not phrase or not options:
            continue
        pattern = re.compile(
            r'(?<!\w)' + re.escape(phrase) + r'(?!\w|-\d)',
            re.IGNORECASE,
        )
        if pattern.search(result) and rng.random() < 0.6:
            alt = rng.choice(options)
            # A callable replacement inserts `alt` literally; a plain string
            # would be parsed as a template (backslashes, group references).
            result = pattern.sub(lambda _m, _alt=alt: _alt, result, count=1)
    return result
# ─────────────────────────────────────────────────────────────────────────────
# PUBLIC API
# ─────────────────────────────────────────────────────────────────────────────
def augment_nl(
    nl_question: str,
    n: int = 3,
    seed: int = 42,
) -> list[str]:
    """
    Generate `n` rule-based augmented variants of a natural language question.

    Variant attempts cycle through five strategies, each a combination of:
      - synonym replacement
      - condition reordering
      - date format variation

    Up to ``n * 3`` attempts are made, each with its own RNG seeded with
    ``seed + attempt * 31``. Candidates are whitespace-normalised and
    de-duplicated. The original question is NOT included in the output,
    neither verbatim nor in its whitespace-normalised form.

    Parameters
    ----------
    nl_question : str
        The base NL question to augment.
    n : int
        Number of variants to generate. Values <= 0 yield an empty list.
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    list[str]
        Up to `n` distinct augmented strings. May be fewer if the question
        is too short to vary meaningfully.
    """
    if n <= 0:
        return []

    strategies = [
        # Strategy 1: synonym only
        lambda t, r: _apply_synonyms(t, r, max_replacements=2),
        # Strategy 2: synonym + date
        lambda t, r: _vary_dates(_apply_synonyms(t, r, max_replacements=2), r),
        # Strategy 3: condition reorder + synonym
        lambda t, r: _apply_synonyms(_reorder_conditions(t, r), r, max_replacements=1),
        # Strategy 4: heavy synonym
        lambda t, r: _apply_synonyms(t, r, max_replacements=4),
        # Strategy 5: date only
        lambda t, r: _vary_dates(t, r),
    ]

    # Candidates are compared after whitespace normalisation, so the original
    # must be registered in that form too. Otherwise a question containing
    # irregular spacing would come back as a "variant" of itself.
    seen: set[str] = {nl_question, " ".join(nl_question.split())}
    variants: list[str] = []

    for attempt in range(n * 3):  # Over-generate, then deduplicate
        strategy = strategies[attempt % len(strategies)]
        # Use a different seed offset per variant attempt
        local_rng = random.Random(seed + attempt * 31)
        # split/join trims the ends and collapses internal whitespace runs.
        candidate = " ".join(strategy(nl_question, local_rng).split())
        if not candidate or candidate in seen:
            continue
        seen.add(candidate)
        variants.append(candidate)
        if len(variants) >= n:
            break
    return variants
def generate_all_augmentations(
    nl_question: str,
    seed: int = 42,
    n_per_template: int = 3,
) -> Iterator[str]:
    """
    Generator wrapper around `augment_nl` that yields one variant at a time.

    Lets callers stream variants into a larger dataset pipeline. The
    variants for this question are produced by a single `augment_nl` call
    when iteration starts and are then handed out one by one.

    Parameters
    ----------
    nl_question : str
        The base NL question to augment.
    seed : int
        Random seed forwarded to `augment_nl`.
    n_per_template : int
        Maximum number of variants to yield (forwarded as `n`).
    """
    variants = augment_nl(nl_question, n=n_per_template, seed=seed)
    yield from variants