ritvik360 committed on
Commit a39d8ef · verified · 1 Parent(s): deedaad

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes. See raw diff
.env.example ADDED
@@ -0,0 +1,23 @@
+ # nl2sql-bench/.env.example
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Copy this file to .env and fill in your values.
+ # NEVER commit .env to version control.
+ #
+ # All three variables below are MANDATORY per competition rules.
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ # LLM API endpoint (HuggingFace router or any OpenAI-compatible base URL)
+ API_BASE_URL=https://router.huggingface.co/v1
+
+ # Model identifier — must be accessible at the above endpoint
+ MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
+
+ # HuggingFace API token (also used as the OpenAI-client api_key)
+ HF_TOKEN=hf_your_token_here
+
+ # ── Optional overrides ────────────────────────────────────────────────────
+ # LOCAL_IMAGE_NAME=nl2sql-bench:latest    # Docker image name for local dev
+ # SPACE_URL=https://your-space.hf.space   # Deployed HF Space URL
+ # NL2SQL_DEFAULT_TASK=simple-filter       # Default task (overridden per episode)
+ # NL2SQL_MAX_STEPS=5                      # Max steps per episode
+ # ENABLE_WEB_INTERFACE=true               # Enable /web UI for debugging
Dockerfile ADDED
@@ -0,0 +1,56 @@
+ # nl2sql-bench/server/Dockerfile
+ # ─────────────────────────────────────────────────────────────────────────────
+ # NL2SQL-Bench OpenEnv Server
+ # Hugging Face Spaces compatible (port 7860, non-root user).
+ # Build: docker build -t nl2sql-bench:latest .
+ # Run:   docker run -p 7860:7860 nl2sql-bench:latest
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ FROM python:3.11-slim
+
+ # HF Spaces runs as non-root by default
+ ARG UID=1000
+ RUN useradd -m -u $UID appuser
+
+ WORKDIR /app
+
+ # ── System deps ───────────────────────────────────────────────────────────
+ RUN apt-get update -qq && \
+     apt-get install -y --no-install-recommends curl && \
+     rm -rf /var/lib/apt/lists/*
+
+ # ── Python deps ───────────────────────────────────────────────────────────
+ COPY server/requirements.txt /app/requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ # ── Application code ──────────────────────────────────────────────────────
+ # Copy server code
+ COPY server/ /app/server/
+ # Copy shared models (client imports from parent — we flatten for Docker)
+ COPY models.py /app/models.py
+
+ # Flatten server submodules into /app so Python can find them
+ # (avoids complex PYTHONPATH games inside the container)
+ RUN cp -r /app/server/tasks /app/tasks && \
+     cp -r /app/server/db /app/db && \
+     cp /app/server/grader.py /app/grader.py && \
+     cp /app/server/environment.py /app/environment.py && \
+     cp /app/server/app.py /app/app.py
+
+ # ── Runtime config ────────────────────────────────────────────────────────
+ ENV PYTHONPATH=/app
+ ENV PYTHONUNBUFFERED=1
+ # HF Spaces requires port 7860
+ ENV PORT=7860
+ ENV NL2SQL_DEFAULT_TASK=simple-filter
+ ENV NL2SQL_MAX_STEPS=5
+
+ USER appuser
+ WORKDIR /app
+
+ EXPOSE 7860
+
+ HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
+     CMD curl -sf http://localhost:${PORT}/health || exit 1
+
+ CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT} --workers 2 --log-level info"]
README.md CHANGED
@@ -1,10 +1,273 @@
  ---
- title: Nl2sql Bench
- emoji: 📚
- colorFrom: yellow
- colorTo: yellow
- sdk: docker
- pinned: false
  ---

- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # NL2SQL-Bench
+
+ **Natural Language to SQL Analytics Environment for RL Training**
+
+ [![openenv](https://img.shields.io/badge/openenv-compatible-blue)](https://github.com/meta-pytorch/OpenEnv)
+ [![Python 3.10+](https://img.shields.io/badge/python-3.10+-green)](https://www.python.org)
+ [![License: MIT](https://img.shields.io/badge/License-MIT-yellow)](LICENSE)
+
+ ---
+
+ ## What is NL2SQL-Bench?
+
+ NL2SQL-Bench is an OpenEnv-compliant RL training environment in which an AI agent iteratively writes and refines **SQLite queries** to answer natural-language business questions against a synthetic e-commerce database.
+
+ This fills a genuine gap in the OpenEnv ecosystem — no SQL query environment currently exists. Every data-driven company employs analysts who translate business questions into SQL. Training agents to do this well (and to recover from errors) is immediately valuable.
+
+ **Why it's a great RL domain:**
+ - Rewards are **100% deterministic** — no LLM-as-judge, no subjectivity
+ - Multi-turn episodes create a **dense reward signal** across the trajectory
+ - The error → fix → retry loop is a novel mechanic not present in existing environments
+ - Three clearly graduated difficulty levels challenge models across the full skill range
+
+ ---
+
+ ## Environment Description
+
+ The agent interacts with a **synthetic e-commerce SQLite database** containing ~150 customers, 64 products across 8 categories, ~600 orders, ~1000 order items, and ~400 reviews. The database is seeded deterministically (seed=42) so results are reproducible on any machine.
+
+ The agent receives a natural-language question and iteratively submits SQL queries. Each query is executed and graded against the ground truth, and the reward plus any error or result is fed back as the next observation.
+
+ ---
+
+ ## Database Schema
+
+ ```
+ categories(id, name)
+ products(id, name, category_id, price, stock_quantity)
+ customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
+ orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
+        created_at, total_amount)
+ order_items(id, order_id, product_id, quantity, unit_price)
+ reviews(id, product_id, customer_id, rating∈1-5, created_at)
+ ```
+
+ All dates are ISO-8601 strings sortable by text comparison. SQLite window functions and CTEs are fully supported.
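The hard tasks lean directly on those SQLite features. As a quick, self-contained sanity check, a CTE plus `DENSE_RANK` runs on any modern `sqlite3` build. This sketch uses a toy two-table subset of the schema above with made-up rows, not the benchmark's seeded database:

```python
import sqlite3

# Toy subset of the `customers`/`orders` schema above; illustrative data only.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT, tier TEXT);
CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER,
                     status TEXT, created_at TEXT, total_amount REAL);
INSERT INTO customers VALUES (1, 'Ada', 'gold'), (2, 'Bo', 'silver');
INSERT INTO orders VALUES
  (1, 1, 'delivered', '2024-01-05', 120.0),
  (2, 1, 'delivered', '2024-02-10', 80.0),
  (3, 2, 'delivered', '2024-01-20', 50.0);
""")

# CTE + window function, the style the analytics-window task expects.
rows = conn.execute("""
WITH spend AS (
    SELECT c.name AS name, SUM(o.total_amount) AS total
    FROM customers c JOIN orders o ON o.customer_id = c.id
    WHERE o.status = 'delivered'
    GROUP BY c.id
)
SELECT name, total, DENSE_RANK() OVER (ORDER BY total DESC) AS rnk
FROM spend
ORDER BY rnk
""").fetchall()
conn.close()
print(rows)
```

Because dates are ISO-8601 strings, range filters such as `created_at >= '2024-01-01'` also work with plain text comparison.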
+
+ ---
+
+ ## Action & Observation Space
+
+ ### Action
+ ```python
+ @dataclass
+ class NL2SQLAction(Action):
+     query: str                    # A SQLite SELECT query string
+ ```
+
+ ### Observation
+ ```python
+ @dataclass
+ class NL2SQLObservation(Observation):
+     question: str                 # The NL question to answer
+     schema_context: str           # Compact schema description
+     task_name: str                # Active task identifier
+     last_query: str               # SQL submitted on previous step
+     last_result: List[Dict]       # Up to 10 result rows
+     last_error: Optional[str]     # SQLite error string or None
+     result_columns: List[str]     # Column names of last_result
+     step: int                     # Current step (0 after reset)
+     max_steps: int                # Maximum steps per episode
+     done: bool                    # Episode ended?
+     reward: Optional[float]       # Step reward [0.0, 1.0]
+     score: float                  # Cumulative normalised score
+ ```
+
+ ---
+
+ ## Tasks & Expected Difficulty
+
+ ### Task 1 — `simple-filter` (easy)
+ Single-table SELECT queries with WHERE, ORDER BY, LIMIT. Tests basic SQL fluency. Example questions:
+ - "List all gold-tier customers ordered by name alphabetically."
+ - "Return the top 5 most expensive products."
+
+ **Expected solve rate (frontier model, 5 steps):** ~80%
+
+ ### Task 2 — `join-aggregation` (medium)
+ Multi-table JOINs with GROUP BY, HAVING, and aggregation functions. Example questions:
+ - "How many orders has each customer placed? Include customers with zero orders."
+ - "Which customers have spent more than $500 total on delivered orders?"
+
+ **Expected solve rate (frontier model, 5 steps):** ~55%
+
+ ### Task 3 — `analytics-window` (hard)
+ CTEs, window functions (DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. Example questions:
+ - "Rank customers by total spending using DENSE_RANK."
+ - "Show monthly revenue and running total for delivered orders in 2024."
+
+ **Expected solve rate (frontier model, 5 steps):** ~30%
+
+ ---
+
+ ## Reward Function
+
+ Rewards are computed by deterministic comparison of the agent's result set against the ground truth:
+
+ | Component | Score | Description |
+ |---|---|---|
+ | `syntax_ok` | +0.10 | Query runs without SQLite error |
+ | `columns_match` | +0.20 | Returned column names match ground truth |
+ | `row_count_match` | +0.20 | Number of rows matches |
+ | `exact_match` | +0.50 | Full result set equals ground truth |
+ | `step_penalty` | −0.05/step | Deducted per step beyond the first |
+
+ Final reward is clamped to `[0.0, 1.0]`. Order sensitivity matches the ground-truth query: ORDER BY queries require correct row ordering; others are order-agnostic.
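Read as pseudocode, the rubric above composes additively and is then penalised and clamped. The following is a minimal sketch of that arithmetic; the actual `grader.py` internals are not shown in this README, so the function and parameter names here are illustrative assumptions:

```python
def score_attempt(error, cols, gt_cols, rows, gt_rows, order_sensitive, step):
    """Hypothetical rubric sketch: additive components, step penalty, clamp."""
    reward = 0.0
    if error is None:
        reward += 0.10                                  # syntax_ok
        if cols == gt_cols:
            reward += 0.20                              # columns_match
        if len(rows) == len(gt_rows):
            reward += 0.20                              # row_count_match
        # Order-agnostic tasks compare sorted result sets instead
        got = rows if order_sensitive else sorted(map(tuple, rows))
        want = gt_rows if order_sensitive else sorted(map(tuple, gt_rows))
        if got == want:
            reward += 0.50                              # exact_match
    reward -= 0.05 * step                               # step 0 = first attempt, no penalty
    return max(0.0, min(1.0, reward))                   # clamp to [0.0, 1.0]
```

A perfect first-step query yields 1.0, while a query that errors out on a later step clamps to 0.0.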
+
+ ---
+
+ ## Baseline Scores
+
+ Run by the `inference.py` script using `Qwen/Qwen2.5-72B-Instruct` via the HuggingFace router:
+
+ | Task | Expected Score |
+ |---|---|
+ | `simple-filter` | ~0.70 |
+ | `join-aggregation` | ~0.45 |
+ | `analytics-window` | ~0.25 |
+
+ ---
+
+ ## Setup & Usage
+
+ ### Prerequisites
+ - Python 3.10+
+ - Docker (for containerised deployment)
+ - A HuggingFace account + token
+
+ ### Local Development (no Docker)
+
+ ```bash
+ # Clone the repository
+ git clone https://huggingface.co/spaces/your-username/nl2sql-bench
+ cd nl2sql-bench
+
+ # Quick start
+ chmod +x scripts/run_local.sh
+ ./scripts/run_local.sh
+
+ # Or manually:
+ python3 -m venv .venv && source .venv/bin/activate
+ pip install openenv-core fastapi "uvicorn[standard]" openai pydantic
+ export PYTHONPATH=".:server"
+ cd server && uvicorn app:app --reload --port 8000
+ ```
+
+ ### Test the Running Server
+
+ ```bash
+ # Run smoke tests
+ chmod +x scripts/smoke_test.sh
+ ./scripts/smoke_test.sh http://localhost:8000
+
+ # Run full test suite
+ pip install pytest pytest-asyncio
+ PYTHONPATH=".:server" pytest tests/ -v
+ ```
+
+ ### Docker
+
+ ```bash
+ # Build
+ docker build -t nl2sql-bench:latest .
+
+ # Run
+ docker run -p 7860:7860 nl2sql-bench:latest
+
+ # Test
+ ./scripts/smoke_test.sh http://localhost:7860
+ ```
+
+ ### Pre-submission Validation
+
+ ```bash
+ # Run the official validator (replace with your HF Space URL)
+ chmod +x pre_validation_script.sh
+ ./pre_validation_script.sh https://your-username-nl2sql-bench.hf.space .
+ ```
+
+ ### Running the Baseline Inference
+
+ ```bash
+ # Set mandatory variables
+ export API_BASE_URL="https://router.huggingface.co/v1"
+ export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
+ export HF_TOKEN="hf_your_token_here"
+ export SPACE_URL="https://your-username-nl2sql-bench.hf.space"
+
+ python inference.py
+ ```
+
+ ### Using the Client Programmatically
+
+ ```python
+ import asyncio
+ from client import NL2SQLEnv
+ from models import NL2SQLAction
+
+ async def main():
+     async with NL2SQLEnv(base_url="http://localhost:8000") as env:
+         result = await env.reset()
+         print(result.observation.question)
+
+         result = await env.step(NL2SQLAction(
+             query="SELECT id, name FROM customers WHERE tier='gold' ORDER BY name"
+         ))
+         print(f"Reward: {result.reward:.2f}")
+         print(f"Done: {result.done}")
+         print(f"Error: {result.observation.last_error}")
+
+ asyncio.run(main())
+ ```
+
+ ---
+
+ ## Project Structure
+
+ ```
+ nl2sql-bench/
+ ├── models.py              # NL2SQLAction, NL2SQLObservation, NL2SQLState
+ ├── client.py              # NL2SQLEnv(HTTPEnvClient)
+ ├── inference.py           # Baseline inference script (mandatory name)
+ ├── openenv.yaml           # OpenEnv manifest
+ ├── pyproject.toml
+ ├── Dockerfile             # HF Spaces compatible (port 7860)
+ ├── .env.example
+ ├── server/
+ │   ├── app.py             # FastAPI entry point
+ │   ├── environment.py     # Core RL environment logic
+ │   ├── grader.py          # Deterministic reward computation
+ │   ├── requirements.txt
+ │   ├── db/
+ │   │   ├── schema.sql     # 6-table e-commerce schema
+ │   │   └── seed.py        # Deterministic data generator (seed=42)
+ │   └── tasks/
+ │       ├── base.py        # BaseTask + registry
+ │       ├── easy.py        # simple-filter (5 examples)
+ │       ├── medium.py      # join-aggregation (5 examples)
+ │       └── hard.py        # analytics-window (5 examples)
+ ├── tests/
+ │   ├── conftest.py
+ │   └── test_all.py        # 30+ pytest tests
+ └── scripts/
+     ├── run_local.sh       # Local dev server
+     └── smoke_test.sh      # Endpoint smoke tests
+ ```
+
  ---
+
+ ## Design Decisions
+
+ **Why SQLite in-memory?** Zero runtime dependency, deterministic, and it runs comfortably within the 2 vCPU / 8 GB constraint. The database loads in ~50 ms.
+
+ **Why multi-turn (up to 5 steps)?** A single-shot SQL environment gives binary rewards. Multi-turn with error feedback gives the agent — and the GRPO trainer — a rich signal: the model learns not just to write SQL, but to debug and refine its queries.
+
+ **Why a step penalty?** Without it, an agent that accidentally gets the right answer on step 5 scores the same as one that gets it on step 1. The penalty creates pressure to solve efficiently, which is realistic.
+
+ **Why order-sensitive comparison for ORDER BY queries?** Business questions that say "rank by spending" expect a ranked output. Order-agnostic comparison would give spurious credit.
+
  ---

+ ## License
+
+ MIT — see [LICENSE](LICENSE)
__init__.py ADDED
@@ -0,0 +1,18 @@
+ """
+ nl2sql-bench — NL2SQL Analytics OpenEnv Environment
+ ====================================================
+ Public API surface for client-side use.
+
+     from nl2sql_bench import NL2SQLEnv, NL2SQLAction, NL2SQLObservation, NL2SQLState
+ """
+
+ from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
+ from client import NL2SQLEnv
+
+ __version__ = "0.1.0"
+ __all__ = [
+     "NL2SQLEnv",
+     "NL2SQLAction",
+     "NL2SQLObservation",
+     "NL2SQLState",
+ ]
check_quality.py ADDED
@@ -0,0 +1,131 @@
+ import json
+ import os
+ import sys
+ import re
+ from collections import Counter
+ from tqdm import tqdm
+
+ # Add project root to path
+ PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
+ if PROJECT_ROOT not in sys.path:
+     sys.path.insert(0, PROJECT_ROOT)
+
+ from data_factory.validator import SQLValidator
+
+ DATASET_FILE = "edge_cases.jsonl"
+
+ def main():
+     if not os.path.exists(DATASET_FILE):
+         print(f"Error: {DATASET_FILE} not found!")
+         return
+
+     print("Starting Dataset Quality & Sanity Check...\n")
+
+     total_rows = 0
+     corrupt_json = 0
+     sql_execution_failures = 0
+     empty_outputs = 0
+     missing_domains = 0
+
+     persona_counts = Counter()
+     unique_sqls = set()
+     unique_questions = set()
+     domain_counts = Counter()
+
+     validators = {}
+
+     with open(DATASET_FILE, "r", encoding="utf-8") as f:
+         lines = f.readlines()
+
+     for line in tqdm(lines, desc="Analyzing Rows"):
+         total_rows += 1
+         try:
+             record = json.loads(line)
+         except json.JSONDecodeError:
+             corrupt_json += 1
+             continue
+
+         prompt_block = record.get("prompt", [])
+         sql = record.get("sql", "").strip()
+         metadata = record.get("metadata", {})
+
+         if not prompt_block or len(prompt_block) < 2 or not sql:
+             empty_outputs += 1
+             continue
+
+         user_content = prompt_block[1].get("content", "")
+         question = user_content.split("QUESTION: ")[-1]
+
+         # Smart domain extraction: try metadata first, fall back to prompt parsing
+         domain = metadata.get("domain")
+         if not domain:
+             match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", user_content)
+             domain = match.group(1) if match else "unknown"
+
+         persona = metadata.get("persona", "unknown")
+
+         persona_counts[persona] += 1
+         domain_counts[domain] += 1
+         unique_sqls.add(sql)
+         unique_questions.add(question)
+
+         # Skip validation if the domain is completely unknown/corrupted
+         if domain == "unknown":
+             missing_domains += 1
+             continue
+
+         # Strict execution quality check
+         try:
+             if domain not in validators:
+                 validators[domain] = SQLValidator(domain, seed=42)
+
+             val_result = validators[domain].validate(sql)
+             if not val_result.passed or val_result.row_count == 0:
+                 sql_execution_failures += 1
+         except Exception:
+             # If any schema error occurs, count the row as a missing domain
+             missing_domains += 1
+             continue
+
+     # Clean up validators
+     for v in validators.values():
+         v.close()
+
+     # --- REPORT GENERATION ---
+     print("\n" + "="*60)
+     print("DATASET HEALTH REPORT")
+     print("="*60)
+     print(f"Total Rows Parsed     : {total_rows}")
+     print(f"Corrupt JSON Lines    : {corrupt_json}")
+     print(f"Missing SQL/Domains   : {empty_outputs + missing_domains}")
+
+     print("\nDIVERSITY METRICS:")
+     print(f"Unique SQL Queries    : {len(unique_sqls)} (Base logic templates)")
+     print(f"Unique NL Questions   : {len(unique_questions)}")
+
+     valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)
+     duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
+     print(f"NL Duplication Rate   : {duplication_rate:.2f}% (Should be low!)")
+
+     print("\nPERSONA DISTRIBUTION:")
+     for p, count in persona_counts.most_common():
+         print(f"  - {p}: {count} ({(count / valid_total) * 100:.1f}%)" if valid_total else f"  - {p}: {count}")
+
+     print("\nDOMAIN DISTRIBUTION:")
+     for d, count in domain_counts.most_common():
+         print(f"  - {d}: {count} ({(count / valid_total) * 100:.1f}%)" if valid_total else f"  - {d}: {count}")
+
+     print("\nCRITICAL QUALITY CHECK:")
+     fail_rate = (sql_execution_failures / valid_total) * 100 if valid_total else 0
+     print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")
+
+     if fail_rate > 5.0:
+         print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
+     elif fail_rate > 0:
+         print("GOOD: Very low failure rate. Safe to train after minor filtering.")
+     else:
+         print("PERFECT: Zero execution failures. Pure Gold Dataset!")
+     print("="*60)
+
+ if __name__ == "__main__":
+     main()
clean_dataset.py ADDED
@@ -0,0 +1,79 @@
+ import json
+ import os
+ import sys
+ import re
+ from tqdm import tqdm
+
+ PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
+ if PROJECT_ROOT not in sys.path:
+     sys.path.insert(0, PROJECT_ROOT)
+
+ from data_factory.validator import SQLValidator
+
+ INPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
+ OUTPUT_FILE = "nl2sql_cleaned_ready_to_train.jsonl"
+
+ def main():
+     if not os.path.exists(INPUT_FILE):
+         print(f"Error: {INPUT_FILE} not found!")
+         return
+
+     print("Sweeping dataset to remove bad SQLs...")
+
+     with open(INPUT_FILE, "r", encoding="utf-8") as f:
+         lines = f.readlines()
+
+     validators = {}
+     cleaned_count = 0
+     failed_count = 0
+
+     with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
+         for line in tqdm(lines, desc="Filtering Garbage"):
+             try:
+                 record = json.loads(line)
+             except json.JSONDecodeError:
+                 failed_count += 1
+                 continue
+
+             sql = record.get("sql", "").strip()
+             metadata = record.get("metadata", {})
+             domain = metadata.get("domain")
+
+             # Fallback for domain extraction
+             if not domain or domain == "unknown":
+                 content = record.get("prompt", [{}, {}])[1].get("content", "")
+                 match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", content)
+                 domain = match.group(1) if match else "unknown"
+
+             if domain == "unknown":
+                 failed_count += 1
+                 continue
+
+             if domain not in validators:
+                 validators[domain] = SQLValidator(domain, seed=42)
+
+             try:
+                 val_result = validators[domain].validate(sql)
+                 # Keep ONLY if SQL is 100% valid and returns data
+                 if val_result.passed and val_result.row_count > 0:
+                     out_f.write(line)
+                     cleaned_count += 1
+                 else:
+                     failed_count += 1
+             except Exception:
+                 failed_count += 1
+
+     for v in validators.values():
+         v.close()
+
+     print("\n" + "="*50)
+     print("DATASET CLEANUP COMPLETE")
+     print("="*50)
+     print(f"Original Rows : {len(lines)}")
+     print(f"Cleaned Rows  : {cleaned_count} (100% Valid SQL)")
+     print(f"Removed Rows  : {failed_count}")
+     print(f"Saved To      : {OUTPUT_FILE}")
+     print("="*50)
+
+ if __name__ == "__main__":
+     main()
client.py ADDED
@@ -0,0 +1,81 @@
+ import httpx
+ import json
+ import os
+ from typing import Any, Dict, Optional
+ from dataclasses import dataclass
+
+ @dataclass
+ class NL2SQLAction:
+     query: str
+
+ @dataclass
+ class NL2SQLObservation:
+     question: str
+     schema_context: str
+     task_name: str
+     last_query: str
+     last_result: list
+     last_error: Optional[str]
+     result_columns: list
+     step: int
+     max_steps: int
+     done: bool
+     reward: float
+     score: float
+
+ @dataclass
+ class StepResult:
+     observation: NL2SQLObservation
+     reward: float
+     done: bool
+
+ class NL2SQLEnv:
+     def __init__(self, base_url: str = "http://localhost:8000"):
+         self.base_url = base_url
+         self.client = httpx.AsyncClient(base_url=base_url, timeout=60.0)
+
+     async def __aenter__(self):
+         return self
+
+     async def __aexit__(self, exc_type, exc_val, exc_tb):
+         await self.client.aclose()
+
+     async def reset(self) -> StepResult:
+         task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
+         resp = await self.client.post("/reset", json={"task_name": task_name})
+         return self._parse_result(resp.json())
+
+     async def step(self, action: NL2SQLAction) -> StepResult:
+         payload = {"query": action.query}
+         resp = await self.client.post("/step", json=payload)
+         return self._parse_result(resp.json())
+
+     def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
+         obs_data = payload.get("observation", payload)
+
+         # SAFETY CHECK: if reward or score is None/null, default to 0.0
+         raw_reward = obs_data.get("reward")
+         safe_reward = float(raw_reward) if raw_reward is not None else 0.0
+
+         raw_score = obs_data.get("score")
+         safe_score = float(raw_score) if raw_score is not None else 0.0
+
+         obs = NL2SQLObservation(
+             question=obs_data.get("question", ""),
+             schema_context=obs_data.get("schema_context", ""),
+             task_name=obs_data.get("task_name", ""),
+             last_query=obs_data.get("last_query", ""),
+             last_result=obs_data.get("last_result", []),
+             last_error=obs_data.get("last_error"),
+             result_columns=obs_data.get("result_columns", []),
+             step=obs_data.get("step", 0),
+             max_steps=obs_data.get("max_steps", 5),
+             done=obs_data.get("done", False),
+             reward=safe_reward,
+             score=safe_score,
+         )
+         return StepResult(
+             observation=obs,
+             reward=safe_reward,
+             done=obs.done,
+         )
custom_train.py ADDED
@@ -0,0 +1,250 @@
1
+ """
2
+ merge_and_train.py
3
+ ==================
4
+ 1. Merges nl2sql_cleaned_ready_to_train.jsonl + edge_cases.jsonl
5
+ 2. Shuffles the combined dataset
6
+ 3. Retrains using the same GRPO setup as train.py
7
+
8
+ Run:
9
+ python merge_and_train.py
10
+
11
+ Flags (env vars):
12
+ EDGE_FILE — path to edge cases jsonl (default: edge_cases.jsonl)
13
+ BASE_FILE — path to existing cleaned (default: nl2sql_cleaned_ready_to_train.jsonl)
14
+ MERGED_FILE — merged output path (default: nl2sql_merged_final.jsonl)
15
+ SKIP_MERGE — set "1" to skip merge step and go straight to training
16
+ """
17
+
18
+ import os, sys, json, random
19
+ import torch
20
+ from datasets import Dataset
21
+ from transformers import AutoModelForCausalLM, AutoTokenizer
22
+ from peft import LoraConfig
23
+ from trl import GRPOConfig, GRPOTrainer
24
+
25
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,5,1,6"
26
+
27
+ sys.path.insert(0, "./server")
28
+ from environment import NL2SQLEnvironment
29
+ from models import NL2SQLAction
30
+ from tasks import all_task_names, get_task
31
+
32
+ # ── Config ───────────────────────────────────────────────────────────────────
33
+ BASE_FILE = os.getenv("BASE_FILE", "nl2sql_cleaned_ready_to_train.jsonl")
34
+ EDGE_FILE = os.getenv("EDGE_FILE", "edge_cases.jsonl")
35
+ MERGED_FILE = os.getenv("MERGED_FILE", "nl2sql_merged_final.jsonl")
36
+ SKIP_MERGE = os.getenv("SKIP_MERGE", "0") == "1"
37
+
38
+ MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
39
+ OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo-v2"
40
+
41
+ SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite.
42
+ Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries.
43
+
44
+ STRICT RULES:
45
+ 1. Output EXACTLY ONE valid SQLite query.
46
+ 2. DO NOT wrap the query in markdown formatting (no ```sql or ```).
47
+ 3. DO NOT output any explanations, conversational text, or preambles.
48
+ 4. ONLY use standard SQLite functions.
49
+ 5. If the question implies ordering, use the correct ORDER BY clause.
50
+ 6. SELECT only the columns explicitly requested — no extras.
51
+
52
+ Your output must be executable directly against the database as-is."""
53
+
54
+
55
+ # ── Step 1: Merge ─────────────────────────────────────────────────────────────
56
+
57
+ def merge_datasets():
58
+ if SKIP_MERGE:
59
+ print(f"[SKIP_MERGE=1] Using existing {MERGED_FILE}")
60
+ return
61
+
62
+ print(f"Loading base: {BASE_FILE}")
63
+ print(f"Loading edges: {EDGE_FILE}")
64
+
65
+ base_lines = []
66
+ with open(BASE_FILE, "r", encoding="utf-8") as f:
67
+ for line in f:
68
+ line = line.strip()
69
+ if line:
70
+ base_lines.append(line)
71
+
72
+ edge_lines = []
73
+ with open(EDGE_FILE, "r", encoding="utf-8") as f:
74
+ for line in f:
75
+ line = line.strip()
76
+ if line:
77
+ edge_lines.append(line)
78
+
79
+ combined = base_lines + edge_lines
80
+ random.shuffle(combined)
81
+
82
+ with open(MERGED_FILE, "w", encoding="utf-8") as f:
83
+ for line in combined:
84
+ f.write(line + "\n")
85
+
86
+ print(
87
+ f"Merged: {len(base_lines)} base + {len(edge_lines)} edge "
88
+ f"= {len(combined)} total → {MERGED_FILE}"
89
+ )
90
+
91
+
92
+ # ── Step 2: Build HF Dataset ──────────────────────────────────────────────────
93
+
94
+ def build_dataset():
95
+ """
96
+ Primary source: merged JSONL (base + edge cases).
97
+ Fallback: task examples from server/tasks/ (same as original train.py).
98
+ Both are combined so GRPO sees everything.
99
+ """
100
+ data = []
101
+
102
+ # Load merged JSONL
103
+ with open(MERGED_FILE, "r", encoding="utf-8") as f:
104
+ for line in f:
105
+ line = line.strip()
106
+ if not line:
107
+ continue
108
+ rec = json.loads(line)
109
+ # rec has "prompt" (list of messages) and "sql"
110
+ # GRPO needs "prompt" and "task_name" — we use a synthetic task_name
111
+ data.append({
112
+ "prompt": rec["prompt"],
113
+ "task_name": "merged_jsonl" # grader falls back to execution-based reward
114
+ })
115
+
116
+ # Also keep the original task examples so GRPO reward env works for them
117
+ for t_name in all_task_names():
118
+ task = get_task(t_name)
119
+ schema = task.schema_context()
120
+ for ex in task.examples:
121
+ data.append({
122
+ "prompt": [
123
+ {"role": "system", "content": SYSTEM_PROMPT},
124
+ {"role": "user", "content": f"SCHEMA:\n{schema}\n\nQUESTION: {ex.question}"}
125
+ ],
126
+ "task_name": t_name
127
+ })
128
+
129
+ random.shuffle(data)
130
+ print(f"Dataset size: {len(data)} samples")
131
+ return Dataset.from_list(data)
132
+
133
+
134
+ # ── Step 3: Reward function ─────────────────────���─────────────────────────────
135
+
136
+ def sql_reward_func(prompts, completions, task_name, **kwargs):
137
+ rewards = []
138
+ env = NL2SQLEnvironment()
139
+
140
+ for idx, completion in enumerate(completions):
141
+ generated = (
142
+ completion[0]["content"] if isinstance(completion, list) else completion
143
+ )
144
+ # Strip code fences defensively
145
+ import re
146
+ generated = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", generated, flags=re.DOTALL).strip()
147
+
148
+ t = task_name[idx] if isinstance(task_name, list) else task_name
149
+
150
+ # For merged_jsonl rows the env won't have a matching task →
151
+ # reward purely on execution (non-empty result set = +1, error = 0)
152
+ if t == "merged_jsonl":
153
+ rewards.append(_execution_reward(generated, prompts[idx]))
154
+ continue
155
+
156
+ env.reset(task_name=t)
157
+ try:
158
+ obs = env.step(NL2SQLAction(query=generated))
159
+ rewards.append(float(obs.reward))
160
+ except Exception:
161
+ rewards.append(0.0)
162
+
163
+ return rewards
164
+
165
+
166
+ def _execution_reward(sql: str, prompt) -> float:
167
+ """Simple execution check for merged_jsonl samples."""
168
+ import sqlite3, re as _re
169
+
170
+ # Extract schema from the user message
171
+ user_content = ""
172
+ for msg in (prompt if isinstance(prompt, list) else []):
173
+ if isinstance(msg, dict) and msg.get("role") == "user":
174
+ user_content = msg.get("content", "")
175
+ break
176
+
177
+ schema_match = _re.search(r"SCHEMA:\s*(.*?)\nQUESTION:", user_content, _re.DOTALL)
178
+ if not schema_match:
179
+ return 0.5 # can't verify, neutral reward
180
+
181
+ schema_sql = schema_match.group(1).strip()
182
+ try:
183
+ conn = sqlite3.connect(":memory:")
184
+ conn.executescript(schema_sql)
185
+ rows = conn.execute(sql).fetchall()
186
+ conn.close()
187
+ return 1.0 if rows else 0.3 # ran cleanly but empty → partial credit
188
+ except Exception:
189
+ return 0.0
190
+
191
+
192
+ # ── Step 4: Train ─────────────────────────────────────────────────────────────
193
+
194
+ def main():
195
+ merge_datasets()
196
+ dataset = build_dataset()
197
+
198
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
199
+ if tokenizer.pad_token is None:
200
+ tokenizer.pad_token = tokenizer.eos_token
201
+
202
+ model = AutoModelForCausalLM.from_pretrained(
203
+ MODEL_NAME,
204
+ torch_dtype=torch.bfloat16,
205
+ attn_implementation="sdpa"
206
+ )
207
+
208
+ peft_config = LoraConfig(
209
+ r=128,
210
+ lora_alpha=256,
211
+ target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
212
+ bias="none",
213
+ task_type="CAUSAL_LM"
214
+ )
215
+
216
+ training_args = GRPOConfig(
217
+ output_dir=OUTPUT_DIR,
218
+ learning_rate=1e-5, # lower LR for fine-grained edge case tuning
219
+ per_device_train_batch_size=2,
220
+ gradient_accumulation_steps=4,
221
+ max_completion_length=256,
222
+ num_generations=8,
223
+ temperature=0.5,
224
+ bf16=True,
225
+ logging_steps=5,
226
+ num_train_epochs=5, # fewer epochs — base knowledge already there
227
+ report_to="none",
228
+ remove_unused_columns=False,
229
+ ddp_find_unused_parameters=False
230
+ )
231
+
232
+ trainer = GRPOTrainer(
233
+ model=model,
234
+ reward_funcs=sql_reward_func,
235
+ args=training_args,
236
+ train_dataset=dataset,
237
+ peft_config=peft_config,
238
+ processing_class=tokenizer
239
+ )
240
+
241
+ trainer.train()
242
+
243
+ if trainer.accelerator.is_main_process:
244
+ trainer.model.save_pretrained(f"{OUTPUT_DIR}/final")
245
+ tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
246
+ print(f"\nSaved to {OUTPUT_DIR}/final")
247
+
248
+
249
+ if __name__ == "__main__":
250
+ main()
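For quick local testing, the fence-stripping and execution-fallback logic above can be exercised standalone; `strip_fences` and `execution_reward` are illustrative names wrapping the same regex and scoring constants used in `sql_reward_func` / `_execution_reward`:

```python
import re
import sqlite3

def strip_fences(text: str) -> str:
    # Same defensive regex as sql_reward_func: unwrap ```sql ... ``` fences
    return re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", text, flags=re.DOTALL).strip()

def execution_reward(sql: str, schema_sql: str) -> float:
    # Mirrors _execution_reward: 1.0 for rows, 0.3 for clean-but-empty, 0.0 on error
    try:
        conn = sqlite3.connect(":memory:")
        conn.executescript(schema_sql)
        rows = conn.execute(sql).fetchall()
        conn.close()
        return 1.0 if rows else 0.3
    except Exception:
        return 0.0
```

Running a generated completion through both functions reproduces the merged_jsonl reward path without spinning up the environment server.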
data_expander.py ADDED
@@ -0,0 +1,161 @@
1
+ import os
2
+ import sys
3
+ import json
4
+ import torch
5
+ import hashlib
6
+ from pathlib import Path
7
+ from tqdm import tqdm
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
10
+
11
+ # --- PATCH FOR TRANSFORMERS VERSION MISMATCH ---
12
+ try:
13
+ import transformers.activations
14
+ if not hasattr(transformers.activations, "PytorchGELUTanh"):
15
+ # Mapping the old name to the new existing one
16
+ transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
17
+ except ImportError:
18
+ pass
19
+ # ------------------------------------------------------
20
+
25
+
26
+ # Force script to use only the 2 free GPUs (e.g., 0 and 7)
27
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"
28
+
29
+ PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
30
+ if PROJECT_ROOT not in sys.path:
31
+ sys.path.insert(0, PROJECT_ROOT)
32
+
33
+ from data_factory.schemas import SCHEMA_CONTEXT
34
+
35
+ # AWQ model is 4x smaller and much faster
36
+ MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
37
+ INPUT_FILE = "llm_hybrid_templates.json"
38
+ OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
39
+ VARIATIONS_PER_SQL = 20
40
+ BATCH_SIZE = 64 # AWQ allows much larger batches!
41
+
42
+ SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks."
43
+
44
+ EXPANSION_PROMPT = """
45
+ You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query.
46
+ Generate exactly {count} completely different natural language questions that this exact SQL query answers.
47
+
48
+ RULES:
49
+ - Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
50
+ - Structure: Completely change sentence flow.
51
+ - No direct column/table names.
52
+
53
+ DATABASE SCHEMA:
54
+ {schema}
55
+
56
+ SQL QUERY:
57
+ {sql}
58
+
59
+ OUTPUT FORMAT:
60
+ Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}]
61
+ """
62
+
63
+ def extract_json_array(raw_text):
64
+ text = raw_text.strip()
65
+ start = text.find("[")
66
+ end = text.rfind("]")
67
+ if start != -1 and end != -1:
68
+ return text[start:end+1]
69
+ return "[]"
70
+
71
+ def get_hash(text):
72
+ return hashlib.md5(text.lower().strip().encode('utf-8')).hexdigest()
73
+
74
+ def main():
75
+ if not os.path.exists(INPUT_FILE):
76
+ print(f"Error: {INPUT_FILE} not found.")
77
+ sys.exit(1)
78
+
79
+ with open(INPUT_FILE, "r") as f:
80
+ base_templates = json.load(f)
81
+
82
+ print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")
83
+
84
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
85
+ tokenizer.pad_token = tokenizer.eos_token
86
+
87
+ # Model loading (AWQ version automatically handles quantization)
88
+ model = AutoModelForCausalLM.from_pretrained(
89
+ MODEL_NAME,
90
+ device_map="auto",
91
+ torch_dtype=torch.float16, # AWQ models use float16/bfloat16 for weights
92
+ low_cpu_mem_usage=True
93
+ )
94
+
95
+ seen_hashes = set()
96
+ total_saved = 0
97
+ if os.path.exists(OUTPUT_FILE):
98
+ with open(OUTPUT_FILE, "r") as f:
99
+ for line in f:
100
+ total_saved += 1 # quick count only; seen_hashes is not rebuilt, so resuming may re-emit duplicates
101
+
102
+ pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)
103
+
104
+ # Batch processing
105
+ for i in range(0, len(base_templates), BATCH_SIZE):
106
+ batch = base_templates[i:i + BATCH_SIZE]
107
+ prompts = []
108
+
109
+ for temp in batch:
110
+ msg = [
111
+ {"role": "system", "content": "You output only JSON arrays."},
112
+ {"role": "user", "content": EXPANSION_PROMPT.format(count=VARIATIONS_PER_SQL, schema=SCHEMA_CONTEXT[temp['domain']], sql=temp['sql'])}
113
+ ]
114
+ prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))
115
+
116
+ inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
117
+
118
+ try:
119
+ with torch.no_grad():
120
+ # Increased speed: AWQ handles large batches efficiently
121
+ outputs = model.generate(
122
+ **inputs,
123
+ max_new_tokens=2048,
124
+ temperature=0.5,
125
+ do_sample=True,
126
+ pad_token_id=tokenizer.eos_token_id
127
+ )
128
+
129
+ responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)
130
+
131
+ with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
132
+ for idx, resp in enumerate(responses):
133
+ try:
+     questions_data = json.loads(extract_json_array(resp))
+ except json.JSONDecodeError:
+     continue  # skip malformed model output instead of failing the whole batch
134
+ sql = batch[idx]["sql"]
135
+ domain = batch[idx]["domain"]
136
+
137
+ for item in questions_data:
138
+ q = item.get("question", "")
139
+ if len(q) > 10:
140
+ q_hash = get_hash(q + sql)
141
+ if q_hash not in seen_hashes:
142
+ seen_hashes.add(q_hash)
143
+ record = {
144
+ "prompt": [
145
+ {"role": "system", "content": SYSTEM_PROMPT},
146
+ {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"}
147
+ ],
148
+ "sql": sql
149
+ }
150
+ out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
151
+ total_saved += 1
152
+ pbar.update(1)
153
+ out_file.flush()
154
+ except Exception as e:
155
+ print(f"Batch failed: {e}")
156
+ continue
157
+
158
+ pbar.close()
159
+
160
+ if __name__ == "__main__":
161
+ main()
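The two helpers that drive de-duplication above are self-contained and easy to verify in isolation (restated here unchanged):

```python
import hashlib

def extract_json_array(raw_text: str) -> str:
    # Salvage the outermost [...] span from a noisy model response
    text = raw_text.strip()
    start, end = text.find("["), text.rfind("]")
    return text[start:end + 1] if start != -1 and end != -1 else "[]"

def get_hash(text: str) -> str:
    # Case- and whitespace-insensitive fingerprint of a question+SQL pair
    return hashlib.md5(text.lower().strip().encode("utf-8")).hexdigest()
```

Hashing `question + sql` rather than the question alone lets the same paraphrase legitimately appear under two different target queries.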
data_factory/__init__.py ADDED
File without changes
data_factory/augmentor.py ADDED
@@ -0,0 +1,288 @@
1
+ """
2
+ data_factory/augmentor.py
3
+ ==========================
4
+ Rule-based Natural Language augmentation.
5
+
6
+ These transformations operate ONLY on NL question strings.
7
+ SQL is NEVER modified — it always comes from the verified template library.
8
+
9
+ Three augmentation strategies:
10
+ 1. Synonym replacement — swaps domain words with semantically equivalent ones
11
+ 2. Condition reordering — shuffles conjunctive phrases (preserves meaning)
12
+ 3. Date normalisation — expresses dates in different formats when applicable
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import random
18
+ import re
20
+ from typing import Iterator
21
+
22
+
23
+ # ─────────────────────────────────────────────────────────────────────────────
24
+ # SYNONYM DICTIONARIES
25
+ # ─────────────────────────────────────────────────────────────────────────────
26
+
27
+ # Format: "canonical_term": ["synonym1", "synonym2", ...]
28
+ # All synonyms are semantically equivalent in a business context.
29
+
30
+ _SYNONYMS: dict[str, list[str]] = {
31
+
32
+ # Verbs / action starters
33
+ "list": ["show", "display", "return", "give me", "find", "retrieve"],
34
+ "show": ["list", "display", "return", "get", "retrieve"],
35
+ "find": ["identify", "locate", "get", "show", "retrieve", "look up"],
36
+ "return": ["show", "give", "list", "retrieve", "output"],
37
+ "retrieve": ["fetch", "get", "return", "pull"],
38
+ "get": ["retrieve", "fetch", "return", "give me"],
39
+
40
+ # Aggregation words
41
+ "total": ["sum", "aggregate", "overall", "cumulative", "combined"],
42
+ "average": ["mean", "avg", "typical"],
43
+ "count": ["number of", "quantity of", "how many"],
44
+ "highest": ["largest", "maximum", "top", "greatest"],
45
+ "lowest": ["smallest", "minimum", "least"],
46
+
47
+ # Business / domain
48
+ "customer": ["client", "buyer", "user", "account holder", "shopper"],
49
+ "customers": ["clients", "buyers", "users", "account holders", "shoppers"],
50
+ "product": ["item", "SKU", "article", "goods"],
51
+ "products": ["items", "SKUs", "articles", "goods"],
52
+ "order": ["purchase", "transaction", "sale"],
53
+ "orders": ["purchases", "transactions", "sales"],
54
+ "revenue": ["income", "earnings", "sales amount", "money earned"],
55
+ "spending": ["expenditure", "spend", "purchases"],
56
+ "amount": ["value", "sum", "total", "figure"],
57
+ "price": ["cost", "rate", "charge", "fee"],
58
+
59
+ # Healthcare
60
+ "patient": ["person", "individual", "case"],
61
+ "patients": ["persons", "individuals", "cases"],
62
+ "doctor": ["physician", "clinician", "practitioner", "specialist"],
63
+ "doctors": ["physicians", "clinicians", "practitioners"],
64
+ "appointment": ["visit", "consultation", "session"],
65
+ "appointments": ["visits", "consultations", "sessions"],
66
+ "medication": ["drug", "medicine", "pharmaceutical", "prescription drug"],
67
+ "medications": ["drugs", "medicines", "pharmaceuticals"],
68
+ "diagnosis": ["condition", "finding", "medical finding"],
69
+
70
+ # Finance
71
+ "account": ["bank account", "profile", "portfolio entry"],
72
+ "accounts": ["bank accounts", "profiles"],
73
+ "loan": ["credit", "borrowing", "debt instrument"],
74
+ "loans": ["credits", "borrowings", "debt instruments"],
75
+ "transaction": ["transfer", "payment", "operation", "activity"],
76
+ "transactions": ["transfers", "payments", "operations"],
77
+ "balance": ["funds", "available amount", "account balance"],
78
+
79
+ # HR
80
+ "employee": ["staff member", "worker", "team member", "headcount"],
81
+ "employees": ["staff", "workers", "team members", "workforce"],
82
+ "department": ["team", "division", "unit", "group"],
83
+ "departments": ["teams", "divisions", "units"],
84
+ "salary": ["pay", "compensation", "remuneration", "earnings"],
85
+ "project": ["initiative", "program", "assignment", "engagement"],
86
+ "projects": ["initiatives", "programs", "assignments"],
87
+
88
+ # Adjectives / Qualifiers
89
+ "active": ["current", "ongoing", "live", "existing"],
90
+ "delivered": ["completed", "fulfilled", "received"],
91
+ "cancelled": ["voided", "aborted", "terminated"],
92
+ "alphabetically": ["by name", "in alphabetical order", "A to Z"],
93
+ "descending": ["from highest to lowest", "in decreasing order", "largest first"],
94
+ "ascending": ["from lowest to highest", "in increasing order", "smallest first"],
95
+ "distinct": ["unique", "different"],
96
+ "in stock": ["available", "with available inventory", "not out of stock"],
97
+ }
98
+
99
+
100
+ # ─────────────────────────────────────────────────────────────────────────────
101
+ # DATE PHRASE PATTERNS
102
+ # These will be replaced with alternative date expressions.
103
+ # ─────────────────────────────────────────────────────────────────────────────
104
+
105
+ _DATE_ALTERNATES: list[tuple[str, list[str]]] = [
106
+ # ISO partial
107
+ ("2024-01-01", ["January 1st 2024", "Jan 1, 2024", "the start of 2024", "2024 start"]),
108
+ ("2023-01-01", ["January 1st 2023", "Jan 1, 2023", "the start of 2023"]),
109
+ ("2025-01-01", ["January 1st 2025", "the start of 2025"]),
110
+ # Quarter references
111
+ ("Q1", ["the first quarter", "January through March", "Jan-Mar"]),
112
+ ("Q2", ["the second quarter", "April through June", "Apr-Jun"]),
113
+ ("Q3", ["the third quarter", "July through September", "Jul-Sep"]),
114
+ ("Q4", ["the fourth quarter", "October through December", "Oct-Dec"]),
115
+ # Year references
116
+ ("in 2024", ["during 2024", "throughout 2024", "for the year 2024"]),
117
+ ("in 2023", ["during 2023", "throughout 2023", "for the year 2023"]),
118
+ ]
119
+
120
+
121
+ # ─────────────────────────────────────────────────────────────────────────────
122
+ # CONDITION REORDERING
123
+ # Splits on "and" between two conditions and reverses them.
124
+ # ─────────────────────────────────────────────────────────────────────────────
125
+
126
+ def _reorder_conditions(text: str, rng: random.Random) -> str:
127
+ """
128
+ If the text contains ' and ' connecting two distinct clauses,
129
+ randomly swap their order 50% of the time.
130
+
131
+ Example:
132
+ "active employees earning above $100,000"
133
+ → "employees earning above $100,000 that are active"
134
+ """
135
+ # Only attempt if "and" is present as a clause connector
136
+ matches = list(re.finditer(r'\b(?:and|who are|that are|with)\b', text, re.IGNORECASE))
137
+ if not matches or rng.random() > 0.5:
138
+ return text
139
+
140
+ # Take the first match and swap text around it
141
+ m = matches[0]
142
+ before = text[:m.start()].strip()
143
+ after = text[m.end():].strip()
144
+ connector = m.group(0).lower()
145
+
146
+ # Build swapped version
147
+ if connector in ("and",):
148
+ swapped = f"{after} and {before}"
149
+ else:
150
+ swapped = f"{after} {connector} {before}"
151
+
152
+ # The swapped sentence is always returned; tidy it first
153
+ # by capitalising the first character if needed
154
+ if swapped and not swapped[0].isupper():
155
+ swapped = swapped[0].upper() + swapped[1:]
156
+ return swapped
157
+
158
+
159
+ # ─────────────────────────────────────────────────────────────────────────────
160
+ # SYNONYM REPLACEMENT
161
+ # ─────────────────────────────────────────────────────────────────────────────
162
+
163
+ def _apply_synonyms(text: str, rng: random.Random, max_replacements: int = 3) -> str:
164
+ """
165
+ Replace up to `max_replacements` words/phrases with synonyms.
166
+ Replacement is probabilistic (50% chance per match) to maintain diversity.
167
+ """
168
+ result = text
169
+ replacements_done = 0
170
+
171
+ # Shuffle the synonym keys to get different replacement targets each call
172
+ keys = list(_SYNONYMS.keys())
173
+ rng.shuffle(keys)
174
+
175
+ for canonical in keys:
176
+ if replacements_done >= max_replacements:
177
+ break
178
+ synonyms = _SYNONYMS[canonical]
179
+ # Case-insensitive match on word boundary
180
+ pattern = re.compile(r'\b' + re.escape(canonical) + r'\b', re.IGNORECASE)
181
+ if pattern.search(result) and rng.random() < 0.5:
182
+ replacement = rng.choice(synonyms)
183
+ # Preserve original casing for first character
184
+ def _replace(m: re.Match) -> str:
185
+ original = m.group(0)
186
+ if original[0].isupper():
187
+ return replacement[0].upper() + replacement[1:]
188
+ return replacement
189
+ result = pattern.sub(_replace, result, count=1)
190
+ replacements_done += 1
191
+
192
+ return result
193
+
194
+
195
+ # ─────────────────────────────────────────────────────────────────────────────
196
+ # DATE FORMAT VARIATION
197
+ # ─────────────────────────────────────────────────────────────────────────────
198
+
199
+ def _vary_dates(text: str, rng: random.Random) -> str:
200
+ """Replace date phrases with alternate representations."""
201
+ result = text
202
+ for phrase, alternates in _DATE_ALTERNATES:
203
+ if phrase.lower() in result.lower() and rng.random() < 0.6:
204
+ alt = rng.choice(alternates)
205
+ result = re.sub(re.escape(phrase), alt, result, count=1, flags=re.IGNORECASE)
206
+ return result
207
+
208
+
209
+ # ─────────────────────────────────────────────────────────────────────────────
210
+ # PUBLIC API
211
+ # ─────────────────────────────────────────────────────────────────────────────
212
+
213
+ def augment_nl(
214
+ nl_question: str,
215
+ n: int = 3,
216
+ seed: int = 42,
217
+ ) -> list[str]:
218
+ """
219
+ Generate `n` rule-based augmented variants of a natural language question.
220
+
221
+ Each variant applies a different combination of:
222
+ - synonym replacement
223
+ - condition reordering
224
+ - date format variation
225
+
226
+ The original question is NOT included in the output.
227
+
228
+ Parameters
229
+ ----------
230
+ nl_question : str
231
+ The base NL question to augment.
232
+ n : int
233
+ Number of variants to generate.
234
+ seed : int
235
+ Random seed for reproducibility.
236
+
237
+ Returns
238
+ -------
239
+ list[str]
240
+ Up to `n` distinct augmented strings. May be fewer if the question
241
+ is too short to vary meaningfully.
242
+ """
243
+ rng = random.Random(seed)
244
+ variants: list[str] = []
245
+ seen: set[str] = {nl_question}
246
+
247
+ strategies = [
248
+ # Strategy 1: synonym only
249
+ lambda t, r: _apply_synonyms(t, r, max_replacements=2),
250
+ # Strategy 2: synonym + date
251
+ lambda t, r: _vary_dates(_apply_synonyms(t, r, max_replacements=2), r),
252
+ # Strategy 3: condition reorder + synonym
253
+ lambda t, r: _apply_synonyms(_reorder_conditions(t, r), r, max_replacements=1),
254
+ # Strategy 4: heavy synonym
255
+ lambda t, r: _apply_synonyms(t, r, max_replacements=4),
256
+ # Strategy 5: date only
257
+ lambda t, r: _vary_dates(t, r),
258
+ ]
259
+
260
+ for i in range(n * 3): # Over-generate, then deduplicate
261
+ strategy = strategies[i % len(strategies)]
262
+ # Use a different seed offset per variant attempt
263
+ local_rng = random.Random(seed + i * 31)
264
+ candidate = strategy(nl_question, local_rng).strip()
265
+
266
+ # Normalise whitespace
267
+ candidate = " ".join(candidate.split())
268
+
269
+ if candidate and candidate not in seen:
270
+ seen.add(candidate)
271
+ variants.append(candidate)
272
+
273
+ if len(variants) >= n:
274
+ break
275
+
276
+ return variants
277
+
278
+
279
+ def generate_all_augmentations(
280
+ nl_question: str,
281
+ seed: int = 42,
282
+ n_per_template: int = 3,
283
+ ) -> Iterator[str]:
284
+ """
285
+ Yield augmented NL variants one at a time (generator).
286
+ Suitable for streaming into a large dataset without memory pressure.
287
+ """
288
+ yield from augment_nl(nl_question, n=n_per_template, seed=seed)
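A condensed sketch of the synonym pass in `_apply_synonyms`, with a tiny dictionary for illustration; the 50% probability gate is omitted here so every matching key is replaced, which makes the behaviour deterministic enough to check:

```python
import random
import re

# Tiny illustrative dictionary; the real module defines many more entries
SYNONYMS = {
    "list": ["show", "display"],
    "customers": ["clients", "buyers"],
}

def apply_synonyms(text: str, rng: random.Random, max_replacements: int = 2) -> str:
    """Replace up to max_replacements matched terms, preserving leading case."""
    result, done = text, 0
    keys = list(SYNONYMS)
    rng.shuffle(keys)  # vary which terms get swapped first
    for canonical in keys:
        if done >= max_replacements:
            break
        pattern = re.compile(r"\b" + re.escape(canonical) + r"\b", re.IGNORECASE)
        if pattern.search(result):
            replacement = rng.choice(SYNONYMS[canonical])
            def _swap(m: re.Match) -> str:
                # Keep the original capitalisation of the first character
                s = m.group(0)
                return replacement[0].upper() + replacement[1:] if s[0].isupper() else replacement
            result = pattern.sub(_swap, result, count=1)
            done += 1
    return result
```

Because SQL is never touched, even an awkward synonym swap cannot corrupt a training label; at worst it produces a slightly unnatural paraphrase.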
data_factory/config.py ADDED
@@ -0,0 +1,50 @@
1
+ """
2
+ data_factory/config.py
3
+ ======================
4
+ Central configuration for the NL2SQL Synthetic Data Factory.
5
+
6
+ Design philosophy:
7
+ - SQL ALWAYS comes from human-verified templates → zero SQL errors
8
+ - LLM ONLY generates natural language paraphrases → no SQL hallucination
9
+ - Every SQL is execution-validated before saving → guaranteed correctness
10
+ """
11
+
12
+ from __future__ import annotations
13
+ from pathlib import Path
14
+
15
+ # ── Paths ────────────────────────────────────────────────────────────────
16
+ ROOT_DIR = Path(__file__).parent.parent
17
+ DATA_DIR = ROOT_DIR / "generated_data"
18
+ CHECKPOINT_DIR = DATA_DIR / "checkpoints"
19
+ OUTPUT_DIR = DATA_DIR / "output"
20
+
21
+ # ── vLLM / Model ─────────────────────────────────────────────────────────
22
+ # For H100 with 80GB VRAM — run Llama-3-70B or Qwen-72B at full bf16
23
+ GENERATOR_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct" # change to your preferred model
24
+ TENSOR_PARALLEL = 4 # Number of GPUs for tensor parallelism (H100 cluster)
25
+ MAX_MODEL_LEN = 4096 # Max context length
26
+ GPU_MEMORY_UTIL = 0.90 # Leave 10% headroom
27
+
28
+ # ── Generation settings ──────────────────────────────────────────────────
29
+ PERSONAS = ["ceo", "chatty", "lazy_typist", "non_techie", "analyst"]
30
+ NL_VARIANTS_PER_TEMPLATE = 5 # One per persona
31
+ AUGMENTATIONS_PER_NL = 3 # Rule-based variations per NL string
32
+ TEMPERATURE = 0.85 # Slightly high for diversity
33
+ MAX_NEW_TOKENS = 150 # NL questions are short
34
+
35
+ # ── Scale targets ────────────────────────────────────────────────────────
36
+ # 56 base SQL templates × 5 personas × 3 augmentations = 840 "original" records
37
+ # With vLLM generating more NL variants, target: ~500K-1M clean records
38
+ VLLM_EXTRA_VARIANTS = 10 # Additional vLLM NL variants per template beyond personas
39
+
40
+ # ── Validation ───────────────────────────────────────────────────────────
41
+ RANDOM_SEED = 42
42
+
43
+ # ── Domains ──────────────────────────────────────────────────────────────
44
+ DOMAINS = ["ecommerce", "healthcare", "finance", "hr"]
45
+
46
+ DIFFICULTY_LABELS = {
47
+ "easy": "Single-table SELECT with basic WHERE/ORDER/LIMIT.",
48
+ "medium": "Multi-table JOIN with GROUP BY/HAVING/aggregates.",
49
+ "hard": "CTEs, window functions, subqueries.",
50
+ }
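The scale target in the comments above follows from simple arithmetic; the constants below just mirror those comments as an illustrative sanity check:

```python
# Constants mirrored from the config comments above (illustrative restatement)
base_templates = 56        # hand-verified SQL templates
personas = 5               # NL variants per template, one per persona
augmentations_per_nl = 3   # rule-based variations per NL string

original_records = base_templates * personas * augmentations_per_nl
# 56 * 5 * 3 = 840 "original" records before the extra vLLM variants
```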
data_factory/generate_data.py ADDED
@@ -0,0 +1,1947 @@
1
+ """
2
+ generate_data.py — NL2SQL Synthetic Data Factory
3
+ =================================================
4
+ Designed for H100 + vLLM. Produces a clean JSONL file ready for SFT or GRPO training
5
+ with the nl2sql-bench codebase (schema: e-commerce SQLite).
6
+
7
+ Architecture
8
+ ------------
9
+ 1. SQL_TEMPLATES — 120+ ground-truth SQLs, hand-written and verified, NEVER LLM-generated.
10
+ 2. SQLiteValidator — executes every SQL against the actual seeded DB; discards any failure.
11
+ 3. VLLMGenerator — async batched calls to a local vLLM server for NL paraphrasing.
12
+ 4. RuleAugmentor — pure-Python synonym / date-format / condition-order augmentation.
13
+ 5. DataFactory — orchestrates the full pipeline; writes JSONL with checkpointing.
14
+
15
+ Output schema (one JSON object per line)
16
+ -----------------------------------------
17
+ {
18
+ "id": "easy_001_persona_ceo",
19
+ "difficulty": "easy" | "medium" | "hard",
20
+ "persona": "ceo" | "chatty" | "lazy" | "confused" | "analyst",
21
+ "question": "<natural language question>",
22
+ "sql": "<ground-truth SQL>",
23
+ "db_result_ok": true, # always true — failures are discarded
24
+ "augmented": false # true when rule-augmentor modified the NL
25
+ }
26
+
27
+ Usage
28
+ -----
29
+ # 1. Start vLLM server (H100):
30
+ # vllm serve meta-llama/Meta-Llama-3-70B-Instruct \
31
+ # --tensor-parallel-size 4 --port 8001 \
32
+ # --max-model-len 4096 --gpu-memory-utilization 0.92
33
+
34
+ # 2. Run this script (place it next to the nl2sql-bench folder):
35
+ # python generate_data.py \
36
+ # --vllm-url http://localhost:8001/v1 \
37
+ # --model meta-llama/Meta-Llama-3-70B-Instruct \
38
+ # --output nl2sql_train.jsonl \
39
+ # --personas-per-template 5 \
40
+ # --aug-rounds 2 \
41
+ # --batch-size 64
42
+
43
+ Requirements
44
+ ------------
45
+ pip install openai tqdm
46
+ (vLLM + your model already running separately)
47
+
48
+ IMPORTANT: Copy server/db/schema.sql and server/db/seed.py from nl2sql-bench
49
+ into the same directory as this script, OR set --bench-root to the repo root.
50
+ """
51
+
52
+ from __future__ import annotations
53
+
54
+ import argparse
55
+ import asyncio
56
+ import hashlib
57
+ import json
58
+ import logging
59
+ import os
60
+ import random
61
+ import re
62
+ import sqlite3
63
+ import sys
64
+ import time
65
+ from copy import deepcopy
66
+ from dataclasses import dataclass, asdict
67
+ from pathlib import Path
68
+ from typing import Any, Dict, List, Optional, Tuple
69
+
70
+ from openai import AsyncOpenAI
71
+ from tqdm import tqdm
72
+
73
+ # ─────────────────────────────────────────────────────────────────────────────
74
+ # Logging
75
+ # ─────────────────────────────────────────────────────────────────────────────
76
+
77
+ logging.basicConfig(
78
+ level=logging.INFO,
79
+ format="%(asctime)s %(levelname)-8s %(message)s",
80
+ datefmt="%H:%M:%S",
81
+ )
82
+ log = logging.getLogger("data-factory")
83
+
84
+
85
+ # ─────────────────────────────────────────────────────────────────────────────
86
+ # Database: build & validate
87
+ # ─────────────────────────────────────────────────────────────────────────────
88
+
89
+ SCHEMA_SQL = """
90
+ CREATE TABLE IF NOT EXISTS categories (
91
+ id INTEGER PRIMARY KEY,
92
+ name TEXT NOT NULL UNIQUE
93
+ );
94
+
95
+ CREATE TABLE IF NOT EXISTS products (
96
+ id INTEGER PRIMARY KEY,
97
+ name TEXT NOT NULL,
98
+ category_id INTEGER NOT NULL REFERENCES categories(id),
99
+ price REAL NOT NULL CHECK(price >= 0),
100
+ stock_quantity INTEGER NOT NULL DEFAULT 0
101
+ );
102
+
103
+ CREATE TABLE IF NOT EXISTS customers (
104
+ id INTEGER PRIMARY KEY,
105
+ name TEXT NOT NULL,
106
+ email TEXT NOT NULL UNIQUE,
107
+ country TEXT NOT NULL,
108
+ tier TEXT NOT NULL DEFAULT 'bronze'
109
+ CHECK(tier IN ('bronze', 'silver', 'gold')),
110
+ created_at TEXT NOT NULL
111
+ );
112
+
113
+ CREATE TABLE IF NOT EXISTS orders (
114
+ id INTEGER PRIMARY KEY,
115
+ customer_id INTEGER NOT NULL REFERENCES customers(id),
116
+ status TEXT NOT NULL DEFAULT 'pending'
117
+ CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
118
+ created_at TEXT NOT NULL,
119
+ total_amount REAL NOT NULL CHECK(total_amount >= 0)
120
+ );
121
+
122
+ CREATE TABLE IF NOT EXISTS order_items (
123
+ id INTEGER PRIMARY KEY,
124
+ order_id INTEGER NOT NULL REFERENCES orders(id),
125
+ product_id INTEGER NOT NULL REFERENCES products(id),
126
+ quantity INTEGER NOT NULL CHECK(quantity > 0),
127
+ unit_price REAL NOT NULL CHECK(unit_price >= 0)
128
+ );
129
+
130
+ CREATE TABLE IF NOT EXISTS reviews (
131
+ id INTEGER PRIMARY KEY,
132
+ product_id INTEGER NOT NULL REFERENCES products(id),
133
+ customer_id INTEGER NOT NULL REFERENCES customers(id),
134
+ rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
135
+ created_at TEXT NOT NULL
136
+ );
137
+ """
138
+
139
+ # Minimal seeder so the validator can run the SQL against real data.
140
+ # Mirrors the logic in nl2sql-bench/server/db/seed.py (fixed seed = 42).
141
+ SEED_SCRIPT = """
142
+ import random, sqlite3
+ from datetime import date, timedelta
+
+ RNG = random.Random(42)
+
+ CATEGORIES = ["Electronics","Clothing","Books","Home & Garden",
+               "Sports & Outdoors","Toys & Games","Beauty","Automotive"]
+
+ PRODUCTS = {
+     "Electronics": ["Wireless Headphones","USB-C Hub","Mechanical Keyboard",
+                     "Webcam 4K","Portable Charger","Smart Speaker",
+                     "Monitor Stand","HDMI Cable 2.1"],
+     "Clothing": ["Cotton T-Shirt","Slim Fit Jeans","Hoodie",
+                  "Running Shorts","Winter Jacket","Polo Shirt",
+                  "Casual Sneakers","Wool Socks"],
+     "Books": ["Clean Code","Designing Data-Intensive Applications",
+               "The Pragmatic Programmer","System Design Interview",
+               "Deep Learning Book","Python Cookbook",
+               "Domain-Driven Design","Refactoring"],
+     "Home & Garden": ["Coffee Maker","Air Purifier","LED Desk Lamp",
+                       "Plant Pot Set","Storage Organiser","Cutting Board",
+                       "Vacuum Cleaner","Electric Kettle"],
+     "Sports & Outdoors": ["Yoga Mat","Resistance Bands","Cycling Gloves",
+                           "Trekking Poles","Water Bottle 1L","Jump Rope",
+                           "Foam Roller","Compression Socks"],
+     "Toys & Games": ["Lego City Set","Card Game Pack","Puzzle 1000pc",
+                      "Remote Control Car","Building Blocks",
+                      "Board Game Strategy","Art Set","Toy Drone"],
+     "Beauty": ["Face Serum","SPF 50 Sunscreen","Lip Balm",
+                "Shampoo Pro","Hair Mask","Eye Cream",
+                "Vitamin C Cream","Toner Mist"],
+     "Automotive": ["Car Phone Mount","Dash Cam","Tyre Inflator",
+                    "Car Vacuum","Seat Cushion","Steering Wheel Cover",
+                    "OBD Scanner","Jump Starter"],
+ }
+
+ COUNTRIES = ["India","USA","Germany","UK","Canada",
+              "Australia","France","Brazil","Japan","Singapore"]
+ TIERS = ["bronze","silver","gold"]
+ STATUSES = ["pending","processing","shipped","delivered","cancelled"]
+
+ FIRST = ["Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja",
+          "Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica",
+          "Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura",
+          "Yuki","Hana","Wei","Mei","Aiden","Zara"]
+ LAST = ["Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy",
+         "Smith","Johnson","Brown","Williams","Jones","Davis","Wilson",
+         "Müller","Schmidt","Schneider","Fischer","Weber",
+         "Martin","Bernard","Thomas","Richard","Petit",
+         "Garcia","Martinez","Lopez","Sanchez","Gonzalez"]
+
+
+ def _date(start=2022, end=2025):
+     s = date(start, 1, 1)
+     e = date(end, 12, 31)
+     return str(s + timedelta(days=RNG.randint(0, (e - s).days)))
+
+
+ def seed(conn):
+     c = conn.cursor()
+     for cat in CATEGORIES:
+         c.execute("INSERT OR IGNORE INTO categories(name) VALUES (?)", (cat,))
+     conn.commit()
+
+     cat_ids = {r[1]: r[0] for r in conn.execute("SELECT id, name FROM categories")}
+
+     for cat, prods in PRODUCTS.items():
+         for pname in prods:
+             c.execute(
+                 "INSERT OR IGNORE INTO products(name,category_id,price,stock_quantity) VALUES (?,?,?,?)",
+                 (pname, cat_ids[cat], round(RNG.uniform(5, 500), 2), RNG.randint(0, 200)),
+             )
+     conn.commit()
+
+     for i in range(200):
+         name = f"{RNG.choice(FIRST)} {RNG.choice(LAST)}"
+         email = f"user{i}@example.com"
+         c.execute(
+             "INSERT OR IGNORE INTO customers(name,email,country,tier,created_at) VALUES (?,?,?,?,?)",
+             (name, email, RNG.choice(COUNTRIES), RNG.choice(TIERS), _date()),
+         )
+     conn.commit()
+
+     cust_ids = [r[0] for r in conn.execute("SELECT id FROM customers")]
+     prod_ids = [r[0] for r in conn.execute("SELECT id FROM products")]
+
+     for _ in range(600):
+         cid = RNG.choice(cust_ids)
+         amt = round(RNG.uniform(10, 1000), 2)
+         status = RNG.choice(STATUSES)
+         d = _date()
+         c.execute(
+             "INSERT INTO orders(customer_id,status,created_at,total_amount) VALUES (?,?,?,?)",
+             (cid, status, d, amt),
+         )
+     conn.commit()
+
+     ord_ids = [r[0] for r in conn.execute("SELECT id FROM orders")]
+     for oid in ord_ids:
+         for _ in range(RNG.randint(1, 4)):
+             pid = RNG.choice(prod_ids)
+             qty = RNG.randint(1, 5)
+             price = round(RNG.uniform(5, 500), 2)
+             c.execute(
+                 "INSERT INTO order_items(order_id,product_id,quantity,unit_price) VALUES (?,?,?,?)",
+                 (oid, pid, qty, price),
+             )
+     conn.commit()
+
+     for _ in range(400):
+         pid = RNG.choice(prod_ids)
+         cid = RNG.choice(cust_ids)
+         rating = RNG.randint(1, 5)
+         c.execute(
+             "INSERT INTO reviews(product_id,customer_id,rating,created_at) VALUES (?,?,?,?)",
+             (pid, cid, rating, _date()),
+         )
+     conn.commit()
+ """
+
+
+ def build_db() -> sqlite3.Connection:
+     """Build an in-memory SQLite DB with schema + seed data."""
+     conn = sqlite3.connect(":memory:")
+     conn.executescript(SCHEMA_SQL)
+     ns: dict = {}
+     exec(SEED_SCRIPT, ns)  # define the inline seeder ...
+     ns["seed"](conn)       # ... then actually run it against this connection
+     conn.row_factory = sqlite3.Row
+     log.info("In-memory DB built and seeded.")
+     return conn
+
+
+ class SQLiteValidator:
+     """Execute SQL against the seeded DB; return (ok, error)."""
+
+     def __init__(self, conn: sqlite3.Connection):
+         self.conn = conn
+
+     def validate(self, sql: str) -> Tuple[bool, Optional[str]]:
+         sql = sql.strip().rstrip(";")
+         if not sql:
+             return False, "Empty SQL"
+         first = sql.split()[0].lower()
+         if first not in ("select", "with"):  # allow CTEs, which the hard templates use
+             return False, f"Non-SELECT statement: {first}"
+         try:
+             cur = self.conn.execute(sql)
+             cur.fetchmany(500)  # force execution without materialising everything
+             return True, None
+         except sqlite3.Error as exc:
+             return False, str(exc)
+
+
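A quick usage sketch of the validator against a throwaway in-memory database. The table `t` and the standalone `validate` helper are illustrative only, mirroring the shape of the checks in `SQLiteValidator.validate`:

```python
import sqlite3
from typing import Optional, Tuple

def validate(conn: sqlite3.Connection, sql: str) -> Tuple[bool, Optional[str]]:
    # Same shape of checks as the class method: read-only statements only,
    # then execute against the live connection to surface SQL errors.
    sql = sql.strip().rstrip(";")
    if not sql:
        return False, "Empty SQL"
    first = sql.split()[0].lower()
    if first not in ("select", "with"):
        return False, f"Non-SELECT statement: {first}"
    try:
        conn.execute(sql).fetchmany(500)
        return True, None
    except sqlite3.Error as exc:
        return False, str(exc)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT)")
ok, err = validate(conn, "SELECT name FROM t")       # accepted: (True, None)
bad, err2 = validate(conn, "DELETE FROM t")          # rejected: not a SELECT
broken, err3 = validate(conn, "SELECT nope FROM t")  # sqlite error surfaced
```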
+ # ─────────────────────────────────────────────────────────────────────────────
+ # SQL Template Library (ground-truth, hand-written, execution-validated)
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ @dataclass
+ class SQLTemplate:
+     id: str
+     difficulty: str               # easy | medium | hard
+     description: str              # plain-English description fed to the LLM
+     sql: str
+     order_sensitive: bool = False
+
+
+ # NOTE: Every SQL here uses only the 6 tables in the schema and valid SQLite syntax.
+ # They are intentionally grouped by the SQL pattern they teach, not just by difficulty.
+
+ EASY_TEMPLATES: List[SQLTemplate] = [
+     # ── Equality filter ──────────────────────────────────────────────────────
+     SQLTemplate(
+         id="easy_001",
+         difficulty="easy",
+         description=(
+             "List all gold-tier customers, ordered alphabetically by name. "
+             "Return id, name, email, country."
+         ),
+         sql=(
+             "SELECT id, name, email, country "
+             "FROM customers "
+             "WHERE tier = 'gold' "
+             "ORDER BY name ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_002",
+         difficulty="easy",
+         description=(
+             "Show all products priced above $100, sorted by price descending. "
+             "Return id, name, price."
+         ),
+         sql=(
+             "SELECT id, name, price "
+             "FROM products "
+             "WHERE price > 100 "
+             "ORDER BY price DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_003",
+         difficulty="easy",
+         description=(
+             "Find all delivered orders with a total_amount greater than $200, "
+             "sorted by total_amount descending. "
+             "Return id, customer_id, total_amount, created_at."
+         ),
+         sql=(
+             "SELECT id, customer_id, total_amount, created_at "
+             "FROM orders "
+             "WHERE status = 'delivered' AND total_amount > 200 "
+             "ORDER BY total_amount DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_004",
+         difficulty="easy",
+         description=(
+             "Return the top 5 most expensive products. Return id, name, price."
+         ),
+         sql=(
+             "SELECT id, name, price "
+             "FROM products "
+             "ORDER BY price DESC "
+             "LIMIT 5"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_005",
+         difficulty="easy",
+         description=(
+             "List all distinct countries where customers come from, sorted alphabetically. "
+             "Return a single column: country."
+         ),
+         sql=(
+             "SELECT DISTINCT country "
+             "FROM customers "
+             "ORDER BY country ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_006",
+         difficulty="easy",
+         description=(
+             "Show all pending orders, ordered by created_at descending. "
+             "Return id, customer_id, total_amount, created_at."
+         ),
+         sql=(
+             "SELECT id, customer_id, total_amount, created_at "
+             "FROM orders "
+             "WHERE status = 'pending' "
+             "ORDER BY created_at DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_007",
+         difficulty="easy",
+         description=(
+             "Find all products with zero stock (stock_quantity = 0). "
+             "Return id, name, price, category_id."
+         ),
+         sql=(
+             "SELECT id, name, price, category_id "
+             "FROM products "
+             "WHERE stock_quantity = 0"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_008",
+         difficulty="easy",
+         description=(
+             "How many customers are there in total? Return a single value: total_customers."
+         ),
+         sql="SELECT COUNT(*) AS total_customers FROM customers",
+     ),
+     SQLTemplate(
+         id="easy_009",
+         difficulty="easy",
+         description=(
+             "What is the most expensive product price in the store? "
+             "Return a single value: max_price."
+         ),
+         sql="SELECT MAX(price) AS max_price FROM products",
+     ),
+     SQLTemplate(
+         id="easy_010",
+         difficulty="easy",
+         description=(
+             "What is the cheapest product price in the store? "
+             "Return a single value: min_price."
+         ),
+         sql="SELECT MIN(price) AS min_price FROM products",
+     ),
+     SQLTemplate(
+         id="easy_011",
+         difficulty="easy",
+         description=(
+             "What is the average price of all products? "
+             "Round to 2 decimal places. Return: avg_price."
+         ),
+         sql="SELECT ROUND(AVG(price), 2) AS avg_price FROM products",
+     ),
+     SQLTemplate(
+         id="easy_012",
+         difficulty="easy",
+         description=(
+             "Show all customers from India, sorted by name ascending. "
+             "Return id, name, email, tier."
+         ),
+         sql=(
+             "SELECT id, name, email, tier "
+             "FROM customers "
+             "WHERE country = 'India' "
+             "ORDER BY name ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_013",
+         difficulty="easy",
+         description=(
+             "List the 10 most recently placed orders. "
+             "Return id, customer_id, status, created_at, total_amount."
+         ),
+         sql=(
+             "SELECT id, customer_id, status, created_at, total_amount "
+             "FROM orders "
+             "ORDER BY created_at DESC "
+             "LIMIT 10"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_014",
+         difficulty="easy",
+         description=(
+             "Find all reviews with a rating of 5 stars. "
+             "Return id, product_id, customer_id, created_at."
+         ),
+         sql=(
+             "SELECT id, product_id, customer_id, created_at "
+             "FROM reviews "
+             "WHERE rating = 5"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_015",
+         difficulty="easy",
+         description=(
+             "Find all reviews with a rating of 1 star (lowest possible). "
+             "Return id, product_id, customer_id, created_at."
+         ),
+         sql=(
+             "SELECT id, product_id, customer_id, created_at "
+             "FROM reviews "
+             "WHERE rating = 1"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_016",
+         difficulty="easy",
+         description=(
+             "Count the number of cancelled orders. Return: cancelled_count."
+         ),
+         sql=(
+             "SELECT COUNT(*) AS cancelled_count "
+             "FROM orders "
+             "WHERE status = 'cancelled'"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_017",
+         difficulty="easy",
+         description=(
+             "List all products with stock_quantity greater than 100, "
+             "sorted by stock_quantity descending. Return id, name, stock_quantity."
+         ),
+         sql=(
+             "SELECT id, name, stock_quantity "
+             "FROM products "
+             "WHERE stock_quantity > 100 "
+             "ORDER BY stock_quantity DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_018",
+         difficulty="easy",
+         description=(
+             "Find all silver-tier customers from the USA. "
+             "Return id, name, email."
+         ),
+         sql=(
+             "SELECT id, name, email "
+             "FROM customers "
+             "WHERE tier = 'silver' AND country = 'USA'"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_019",
+         difficulty="easy",
+         description=(
+             "What is the total revenue from all delivered orders? "
+             "Round to 2 decimal places. Return: total_revenue."
+         ),
+         sql=(
+             "SELECT ROUND(SUM(total_amount), 2) AS total_revenue "
+             "FROM orders "
+             "WHERE status = 'delivered'"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_020",
+         difficulty="easy",
+         description=(
+             "List all orders placed in 2024, sorted by created_at ascending. "
+             "Return id, customer_id, status, total_amount, created_at."
+         ),
+         sql=(
+             "SELECT id, customer_id, status, total_amount, created_at "
+             "FROM orders "
+             "WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' "
+             "ORDER BY created_at ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_021",
+         difficulty="easy",
+         description=(
+             "Show the bottom 5 cheapest products. Return id, name, price."
+         ),
+         sql=(
+             "SELECT id, name, price "
+             "FROM products "
+             "ORDER BY price ASC "
+             "LIMIT 5"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_022",
+         difficulty="easy",
+         description=(
+             "Count how many products exist in the catalogue. Return: product_count."
+         ),
+         sql="SELECT COUNT(*) AS product_count FROM products",
+     ),
+     SQLTemplate(
+         id="easy_023",
+         difficulty="easy",
+         description=(
+             "List all distinct order statuses that exist in the orders table. "
+             "Return a single column: status."
+         ),
+         sql="SELECT DISTINCT status FROM orders ORDER BY status ASC",
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_024",
+         difficulty="easy",
+         description=(
+             "Find customers who joined (created_at) in 2023. "
+             "Return id, name, country, tier, created_at, sorted by created_at ascending."
+         ),
+         sql=(
+             "SELECT id, name, country, tier, created_at "
+             "FROM customers "
+             "WHERE created_at >= '2023-01-01' AND created_at < '2024-01-01' "
+             "ORDER BY created_at ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_025",
+         difficulty="easy",
+         description=(
+             "Show all orders with total_amount between $50 and $150 inclusive. "
+             "Return id, customer_id, total_amount, status."
+         ),
+         sql=(
+             "SELECT id, customer_id, total_amount, status "
+             "FROM orders "
+             "WHERE total_amount BETWEEN 50 AND 150"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_026",
+         difficulty="easy",
+         description=(
+             "How many distinct customers have placed at least one order? "
+             "Return a single value: customers_with_orders."
+         ),
+         sql=(
+             "SELECT COUNT(DISTINCT customer_id) AS customers_with_orders "
+             "FROM orders"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_027",
+         difficulty="easy",
+         description=(
+             "What is the total number of order line items across all orders? "
+             "Return: total_line_items."
+         ),
+         sql="SELECT COUNT(*) AS total_line_items FROM order_items",
+     ),
+     SQLTemplate(
+         id="easy_028",
+         difficulty="easy",
+         description=(
+             "List all products priced between $20 and $80 inclusive, sorted by price ascending. "
+             "Return id, name, price."
+         ),
+         sql=(
+             "SELECT id, name, price "
+             "FROM products "
+             "WHERE price BETWEEN 20 AND 80 "
+             "ORDER BY price ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="easy_029",
+         difficulty="easy",
+         description=(
+             "Show all gold-tier customers from Germany. "
+             "Return id, name, email, created_at."
+         ),
+         sql=(
+             "SELECT id, name, email, created_at "
+             "FROM customers "
+             "WHERE tier = 'gold' AND country = 'Germany'"
+         ),
+     ),
+     SQLTemplate(
+         id="easy_030",
+         difficulty="easy",
+         description=(
+             "What is the average rating across all reviews in the system? "
+             "Round to 2 decimal places. Return: avg_rating."
+         ),
+         sql="SELECT ROUND(AVG(rating), 2) AS avg_rating FROM reviews",
+     ),
+ ]
+
+ MEDIUM_TEMPLATES: List[SQLTemplate] = [
+     # ── JOIN + COUNT ─────────────────────────────────────────────────────────
+     SQLTemplate(
+         id="med_001",
+         difficulty="medium",
+         description=(
+             "How many orders has each customer placed? Include customers with zero orders. "
+             "Return customer_name and order_count. Sort by order_count descending, "
+             "then customer_name ascending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
+             "FROM customers c "
+             "LEFT JOIN orders o ON c.id = o.customer_id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY order_count DESC, customer_name ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_002",
+         difficulty="medium",
+         description=(
+             "Average product rating per category, only for categories that have at least one review. "
+             "Return category_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending."
+         ),
+         sql=(
+             "SELECT c.name AS category_name, ROUND(AVG(r.rating), 2) AS avg_rating "
+             "FROM categories c "
+             "JOIN products p ON p.category_id = c.id "
+             "JOIN reviews r ON r.product_id = p.id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY avg_rating DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_003",
+         difficulty="medium",
+         description=(
+             "Which categories have more than 5 in-stock products (stock_quantity > 0)? "
+             "Return category_name and in_stock_count. Sort by in_stock_count descending."
+         ),
+         sql=(
+             "SELECT c.name AS category_name, COUNT(p.id) AS in_stock_count "
+             "FROM categories c "
+             "JOIN products p ON p.category_id = c.id "
+             "WHERE p.stock_quantity > 0 "
+             "GROUP BY c.id, c.name "
+             "HAVING COUNT(p.id) > 5 "
+             "ORDER BY in_stock_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_004",
+         difficulty="medium",
+         description=(
+             "Which customers have spent more than $500 on delivered orders? "
+             "Return customer_name and total_spent (rounded to 2 dp). Sort by total_spent descending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, ROUND(SUM(o.total_amount), 2) AS total_spent "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "WHERE o.status = 'delivered' "
+             "GROUP BY c.id, c.name "
+             "HAVING SUM(o.total_amount) > 500 "
+             "ORDER BY total_spent DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_005",
+         difficulty="medium",
+         description=(
+             "Total quantity sold for each product that appears in at least one order. "
+             "Return product_name and total_quantity_sold. Sort by total_quantity_sold descending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, SUM(oi.quantity) AS total_quantity_sold "
+             "FROM products p "
+             "JOIN order_items oi ON oi.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "ORDER BY total_quantity_sold DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_006",
+         difficulty="medium",
+         description=(
+             "Number of reviews per product, only for products with at least 3 reviews. "
+             "Return product_name and review_count. Sort by review_count descending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, COUNT(r.id) AS review_count "
+             "FROM products p "
+             "JOIN reviews r ON r.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "HAVING COUNT(r.id) >= 3 "
+             "ORDER BY review_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_007",
+         difficulty="medium",
+         description=(
+             "Show the total revenue (sum of total_amount) per country from all orders, "
+             "regardless of status. Return country and total_revenue (rounded to 2 dp). "
+             "Sort by total_revenue descending."
+         ),
+         sql=(
+             "SELECT c.country, ROUND(SUM(o.total_amount), 2) AS total_revenue "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "GROUP BY c.country "
+             "ORDER BY total_revenue DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_008",
+         difficulty="medium",
+         description=(
+             "For each customer tier (bronze, silver, gold) show the average order value "
+             "from delivered orders. Return tier and avg_order_value (rounded to 2 dp). "
+             "Sort by avg_order_value descending."
+         ),
+         sql=(
+             "SELECT c.tier, ROUND(AVG(o.total_amount), 2) AS avg_order_value "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "WHERE o.status = 'delivered' "
+             "GROUP BY c.tier "
+             "ORDER BY avg_order_value DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_009",
+         difficulty="medium",
+         description=(
+             "Which products have never been ordered? "
+             "Return id and name, sorted by name ascending."
+         ),
+         sql=(
+             "SELECT p.id, p.name "
+             "FROM products p "
+             "LEFT JOIN order_items oi ON oi.product_id = p.id "
+             "WHERE oi.id IS NULL "
+             "ORDER BY p.name ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_010",
+         difficulty="medium",
+         description=(
+             "Number of orders per status. "
+             "Return status and order_count. Sort by order_count descending."
+         ),
+         sql=(
+             "SELECT status, COUNT(*) AS order_count "
+             "FROM orders "
+             "GROUP BY status "
+             "ORDER BY order_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_011",
+         difficulty="medium",
+         description=(
+             "Show the total number of products per category. "
+             "Return category_name and product_count. Sort by product_count descending."
+         ),
+         sql=(
+             "SELECT c.name AS category_name, COUNT(p.id) AS product_count "
+             "FROM categories c "
+             "LEFT JOIN products p ON p.category_id = c.id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY product_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_012",
+         difficulty="medium",
+         description=(
+             "Average rating per product for products with at least one review. "
+             "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating "
+             "FROM products p "
+             "JOIN reviews r ON r.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "ORDER BY avg_rating DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_013",
+         difficulty="medium",
+         description=(
+             "Which gold-tier customers have placed more than 3 orders? "
+             "Return customer_name and order_count. Sort by order_count descending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "WHERE c.tier = 'gold' "
+             "GROUP BY c.id, c.name "
+             "HAVING COUNT(o.id) > 3 "
+             "ORDER BY order_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_014",
+         difficulty="medium",
+         description=(
+             "Total quantity of each product ordered via order_items. "
+             "Return product_name and total_units. Sort by total_units descending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, SUM(oi.quantity) AS total_units "
+             "FROM products p "
+             "JOIN order_items oi ON oi.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "ORDER BY total_units DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_015",
+         difficulty="medium",
+         description=(
+             "For each country, count the number of gold-tier customers. "
+             "Only show countries with at least one gold-tier customer. "
+             "Return country and gold_count. Sort by gold_count descending."
+         ),
+         sql=(
+             "SELECT country, COUNT(*) AS gold_count "
+             "FROM customers "
+             "WHERE tier = 'gold' "
+             "GROUP BY country "
+             "HAVING COUNT(*) >= 1 "
+             "ORDER BY gold_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_016",
+         difficulty="medium",
+         description=(
+             "Show how many reviews each customer has submitted. Only include customers "
+             "who have submitted at least one review. Return customer_name and review_count. "
+             "Sort by review_count descending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, COUNT(r.id) AS review_count "
+             "FROM customers c "
+             "JOIN reviews r ON r.customer_id = c.id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY review_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_017",
+         difficulty="medium",
+         description=(
+             "Total revenue generated from order_items (quantity * unit_price) per category. "
+             "Return category_name and category_revenue (rounded to 2 dp). "
+             "Sort by category_revenue descending."
+         ),
+         sql=(
+             "SELECT c.name AS category_name, "
+             "       ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue "
+             "FROM categories c "
+             "JOIN products p ON p.category_id = c.id "
+             "JOIN order_items oi ON oi.product_id = p.id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY category_revenue DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_018",
+         difficulty="medium",
+         description=(
+             "Which products have an average rating strictly below 3? "
+             "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating ascending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating "
+             "FROM products p "
+             "JOIN reviews r ON r.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "HAVING AVG(r.rating) < 3 "
+             "ORDER BY avg_rating ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_019",
+         difficulty="medium",
+         description=(
+             "Find the maximum order value for each customer tier. "
+             "Return tier and max_order_value (rounded to 2 dp). Sort by max_order_value descending."
+         ),
+         sql=(
+             "SELECT c.tier, ROUND(MAX(o.total_amount), 2) AS max_order_value "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "GROUP BY c.tier "
+             "ORDER BY max_order_value DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_020",
+         difficulty="medium",
+         description=(
+             "How many customers per country have placed at least one delivered order? "
+             "Return country and customer_count. Sort by customer_count descending."
+         ),
+         sql=(
+             "SELECT c.country, COUNT(DISTINCT c.id) AS customer_count "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "WHERE o.status = 'delivered' "
+             "GROUP BY c.country "
+             "ORDER BY customer_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_021",
+         difficulty="medium",
+         description=(
+             "List all products together with their category name. "
+             "Return product_name, category_name, price. Sort by category_name, then price ascending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, c.name AS category_name, p.price "
+             "FROM products p "
+             "JOIN categories c ON c.id = p.category_id "
+             "ORDER BY category_name ASC, p.price ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_022",
+         difficulty="medium",
+         description=(
+             "For each order, show the total number of line items it contains. "
+             "Return order_id and line_item_count. Sort by line_item_count descending."
+         ),
+         sql=(
+             "SELECT order_id, COUNT(*) AS line_item_count "
+             "FROM order_items "
+             "GROUP BY order_id "
+             "ORDER BY line_item_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_023",
+         difficulty="medium",
+         description=(
+             "Show the minimum and maximum product price per category. "
+             "Return category_name, min_price, max_price. Sort by category_name ascending."
+         ),
+         sql=(
+             "SELECT c.name AS category_name, "
+             "       ROUND(MIN(p.price), 2) AS min_price, "
+             "       ROUND(MAX(p.price), 2) AS max_price "
+             "FROM categories c "
+             "JOIN products p ON p.category_id = c.id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY category_name ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_024",
+         difficulty="medium",
+         description=(
+             "Find customers who have given a rating of 5 to at least one product. "
+             "Return customer_name and five_star_count. Sort by five_star_count descending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, COUNT(r.id) AS five_star_count "
+             "FROM customers c "
+             "JOIN reviews r ON r.customer_id = c.id "
+             "WHERE r.rating = 5 "
+             "GROUP BY c.id, c.name "
+             "ORDER BY five_star_count DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="med_025",
+         difficulty="medium",
+         description=(
+             "Show the average number of items per order across all orders. "
+             "Round to 2 decimal places. Return: avg_items_per_order."
+         ),
+         sql=(
+             "SELECT ROUND(AVG(item_count), 2) AS avg_items_per_order "
+             "FROM ( "
+             "    SELECT order_id, COUNT(*) AS item_count "
+             "    FROM order_items "
+             "    GROUP BY order_id "
+             ")"
+         ),
+     ),
+ ]
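Each template carries an `order_sensitive` flag. A grader consuming these templates might compare the agent's rows against the ground-truth rows along the lines of the sketch below; the `rows_match` helper is an illustrative assumption, not code from this repo:

```python
from collections import Counter

def rows_match(expected, actual, order_sensitive):
    # Ordered queries (ORDER BY in the template) must match row-for-row;
    # unordered ones are compared as multisets, so row order is irrelevant.
    if order_sensitive:
        return list(expected) == list(actual)
    return Counter(map(tuple, expected)) == Counter(map(tuple, actual))

# Same rows, different order: fine without ORDER BY, wrong with it.
assert rows_match([(1,), (2,)], [(2,), (1,)], order_sensitive=False)
assert not rows_match([(1,), (2,)], [(2,), (1,)], order_sensitive=True)
```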
1116
+
1117
+ HARD_TEMPLATES: List[SQLTemplate] = [
1118
+ # ── Window functions ────────────────────────────────────────────────���────
1119
+     SQLTemplate(
+         id="hard_001",
+         difficulty="hard",
+         description=(
+             "Rank customers by total spending on delivered orders using DENSE_RANK "
+             "(rank 1 = highest spender). "
+             "Return customer_name, total_spent (rounded to 2 dp), spending_rank. "
+             "Sort by spending_rank ascending."
+         ),
+         sql=(
+             "SELECT customer_name, total_spent, spending_rank "
+             "FROM ( "
+             "    SELECT c.name AS customer_name, "
+             "    ROUND(SUM(o.total_amount), 2) AS total_spent, "
+             "    DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
+             "    FROM customers c "
+             "    JOIN orders o ON o.customer_id = c.id "
+             "    WHERE o.status = 'delivered' "
+             "    GROUP BY c.id, c.name "
+             ") sub "
+             "ORDER BY spending_rank ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_002",
+         difficulty="hard",
+         description=(
+             "For each reviewed product, show its own average rating and the average rating "
+             "of all products in its category (partition window). "
+             "Return product_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). "
+             "Sort by product_avg_rating descending."
+         ),
+         sql=(
+             "SELECT p.name AS product_name, "
+             "    ROUND(AVG(r.rating), 2) AS product_avg_rating, "
+             "    ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) AS category_avg_rating "
+             "FROM products p "
+             "JOIN reviews r ON r.product_id = p.id "
+             "GROUP BY p.id, p.name, p.category_id "
+             "ORDER BY product_avg_rating DESC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_003",
+         difficulty="hard",
+         description=(
+             "Find all customers whose most recent order has status 'cancelled'. "
+             "Use a CTE with ROW_NUMBER partitioned by customer_id ordered by created_at DESC. "
+             "Return customer_name, last_order_status, last_order_date. Sort by customer_name ascending."
+         ),
+         sql=(
+             "WITH ranked_orders AS ( "
+             "    SELECT customer_id, status, created_at, "
+             "    ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at DESC) AS rn "
+             "    FROM orders "
+             ") "
+             "SELECT c.name AS customer_name, "
+             "    ro.status AS last_order_status, "
+             "    ro.created_at AS last_order_date "
+             "FROM customers c "
+             "JOIN ranked_orders ro ON ro.customer_id = c.id "
+             "WHERE ro.rn = 1 AND ro.status = 'cancelled' "
+             "ORDER BY customer_name ASC"
+         ),
+         order_sensitive=True,
+     ),
1187
+     SQLTemplate(
+         id="hard_004",
+         difficulty="hard",
+         description=(
+             "Monthly revenue from delivered orders and its running total for all months in 2024. "
+             "Return month (YYYY-MM format), monthly_revenue, running_total (both rounded to 2 dp). "
+             "Sort by month ascending."
+         ),
+         sql=(
+             "WITH monthly AS ( "
+             "    SELECT strftime('%Y-%m', created_at) AS month, "
+             "    ROUND(SUM(total_amount), 2) AS monthly_revenue "
+             "    FROM orders "
+             "    WHERE status = 'delivered' "
+             "    AND created_at >= '2024-01-01' AND created_at < '2025-01-01' "
+             "    GROUP BY strftime('%Y-%m', created_at) "
+             ") "
+             "SELECT month, monthly_revenue, "
+             "    ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
+             "FROM monthly "
+             "ORDER BY month ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_005",
+         difficulty="hard",
+         description=(
+             "Find products whose average rating is strictly above the average rating of all products "
+             "in their category. Use two CTEs: one for product-level averages and one for category-level. "
+             "Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). "
+             "Sort by product_avg_rating descending, then product_name ascending."
+         ),
+         sql=(
+             "WITH product_ratings AS ( "
+             "    SELECT p.id AS product_id, p.name AS product_name, "
+             "    p.category_id, c.name AS category_name, "
+             "    ROUND(AVG(r.rating), 2) AS product_avg_rating "
+             "    FROM products p "
+             "    JOIN reviews r ON r.product_id = p.id "
+             "    JOIN categories c ON c.id = p.category_id "
+             "    GROUP BY p.id, p.name, p.category_id, c.name "
+             "), "
+             "category_ratings AS ( "
+             "    SELECT category_id, ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
+             "    FROM product_ratings "
+             "    GROUP BY category_id "
+             ") "
+             "SELECT pr.product_name, pr.category_name, "
+             "    pr.product_avg_rating, cr.category_avg_rating "
+             "FROM product_ratings pr "
+             "JOIN category_ratings cr ON cr.category_id = pr.category_id "
+             "WHERE pr.product_avg_rating > cr.category_avg_rating "
+             "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
+         ),
+         order_sensitive=True,
+     ),
1244
+     SQLTemplate(
+         id="hard_006",
+         difficulty="hard",
+         description=(
+             "For each customer, find their very first order date using ROW_NUMBER in a CTE. "
+             "Return customer_name and first_order_date. Sort by first_order_date ascending."
+         ),
+         sql=(
+             "WITH first_orders AS ( "
+             "    SELECT customer_id, created_at, "
+             "    ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at ASC) AS rn "
+             "    FROM orders "
+             ") "
+             "SELECT c.name AS customer_name, fo.created_at AS first_order_date "
+             "FROM customers c "
+             "JOIN first_orders fo ON fo.customer_id = c.id "
+             "WHERE fo.rn = 1 "
+             "ORDER BY first_order_date ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_007",
+         difficulty="hard",
+         description=(
+             "Rank products by total revenue generated (quantity * unit_price from order_items) "
+             "using RANK() window function. "
+             "Return product_name, total_revenue (rounded to 2 dp), revenue_rank. "
+             "Sort by revenue_rank ascending."
+         ),
+         sql=(
+             "SELECT product_name, total_revenue, revenue_rank "
+             "FROM ( "
+             "    SELECT p.name AS product_name, "
+             "    ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue, "
+             "    RANK() OVER (ORDER BY SUM(oi.quantity * oi.unit_price) DESC) AS revenue_rank "
+             "    FROM products p "
+             "    JOIN order_items oi ON oi.product_id = p.id "
+             "    GROUP BY p.id, p.name "
+             ") sub "
+             "ORDER BY revenue_rank ASC"
+         ),
+         order_sensitive=True,
+     ),
1288
+     SQLTemplate(
+         id="hard_008",
+         difficulty="hard",
+         description=(
+             "For each customer, compute the running total of their order amounts ordered by "
+             "created_at. Return customer_name, order_date (created_at), order_amount (total_amount), "
+             "running_total (rounded to 2 dp). Sort by customer_name, order_date ascending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, "
+             "    o.created_at AS order_date, "
+             "    o.total_amount AS order_amount, "
+             "    ROUND(SUM(o.total_amount) OVER "
+             "    (PARTITION BY c.id ORDER BY o.created_at "
+             "    ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2) AS running_total "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "ORDER BY customer_name ASC, order_date ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_009",
+         difficulty="hard",
+         description=(
+             "Find customers who have placed orders in every status "
+             "(pending, processing, shipped, delivered, cancelled) at least once. "
+             "Return customer_name and status_count. Sort by customer_name ascending."
+         ),
+         sql=(
+             "SELECT c.name AS customer_name, COUNT(DISTINCT o.status) AS status_count "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "GROUP BY c.id, c.name "
+             "HAVING COUNT(DISTINCT o.status) = 5 "
+             "ORDER BY customer_name ASC"
+         ),
+         order_sensitive=True,
+     ),
1327
+     SQLTemplate(
+         id="hard_010",
+         difficulty="hard",
+         description=(
+             "Using a CTE, compute the total revenue per product, then rank the top 3 products "
+             "in each category by revenue using DENSE_RANK. Only return rows with rank <= 3. "
+             "Return category_name, product_name, total_revenue (rounded to 2 dp), rank_in_category. "
+             "Sort by category_name, rank_in_category ascending."
+         ),
+         sql=(
+             "WITH product_rev AS ( "
+             "    SELECT p.id, p.name AS product_name, p.category_id, "
+             "    c.name AS category_name, "
+             "    ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue "
+             "    FROM products p "
+             "    JOIN categories c ON c.id = p.category_id "
+             "    JOIN order_items oi ON oi.product_id = p.id "
+             "    GROUP BY p.id, p.name, p.category_id, c.name "
+             "), "
+             "ranked AS ( "
+             "    SELECT product_name, category_name, total_revenue, "
+             "    DENSE_RANK() OVER (PARTITION BY category_id ORDER BY total_revenue DESC) AS rank_in_category "
+             "    FROM product_rev "
+             ") "
+             "SELECT category_name, product_name, total_revenue, rank_in_category "
+             "FROM ranked "
+             "WHERE rank_in_category <= 3 "
+             "ORDER BY category_name ASC, rank_in_category ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_011",
+         difficulty="hard",
+         description=(
+             "Compute the percentage of total revenue each category contributes. "
+             "Use a CTE for category revenues and a window SUM for the grand total. "
+             "Return category_name, category_revenue, pct_of_total (rounded to 2 dp). "
+             "Sort by pct_of_total descending."
+         ),
+         sql=(
+             "WITH cat_rev AS ( "
+             "    SELECT c.name AS category_name, "
+             "    ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue "
+             "    FROM categories c "
+             "    JOIN products p ON p.category_id = c.id "
+             "    JOIN order_items oi ON oi.product_id = p.id "
+             "    GROUP BY c.id, c.name "
+             ") "
+             "SELECT category_name, category_revenue, "
+             "    ROUND(100.0 * category_revenue / SUM(category_revenue) OVER (), 2) AS pct_of_total "
+             "FROM cat_rev "
+             "ORDER BY pct_of_total DESC"
+         ),
+         order_sensitive=True,
+     ),
1383
+     SQLTemplate(
+         id="hard_012",
+         difficulty="hard",
+         description=(
+             "Find the customers who placed the highest number of orders in 2023. "
+             "Use a CTE to count per-customer orders in 2023, then apply DENSE_RANK. "
+             "Return customer_name, order_count_2023, rank. Sort by rank, then customer_name."
+         ),
+         sql=(
+             "WITH counts_2023 AS ( "
+             "    SELECT c.name AS customer_name, COUNT(o.id) AS order_count_2023 "
+             "    FROM customers c "
+             "    JOIN orders o ON o.customer_id = c.id "
+             "    WHERE o.created_at >= '2023-01-01' AND o.created_at < '2024-01-01' "
+             "    GROUP BY c.id, c.name "
+             ") "
+             "SELECT customer_name, order_count_2023, "
+             "    DENSE_RANK() OVER (ORDER BY order_count_2023 DESC) AS rank "
+             "FROM counts_2023 "
+             "ORDER BY rank ASC, customer_name ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_013",
+         difficulty="hard",
+         description=(
+             "Show a quarterly revenue breakdown for delivered orders across all years. "
+             "Use strftime to derive year and quarter. "
+             "Return year, quarter, quarterly_revenue (rounded to 2 dp), "
+             "and running_total_in_year (running SUM within the same year, rounded to 2 dp). "
+             "Sort by year, quarter ascending."
+         ),
+         sql=(
+             "WITH quarterly AS ( "
+             "    SELECT strftime('%Y', created_at) AS year, "
+             "    ((CAST(strftime('%m', created_at) AS INTEGER) - 1) / 3 + 1) AS quarter, "
+             "    ROUND(SUM(total_amount), 2) AS quarterly_revenue "
+             "    FROM orders "
+             "    WHERE status = 'delivered' "
+             "    GROUP BY year, quarter "
+             ") "
+             "SELECT year, quarter, quarterly_revenue, "
+             "    ROUND(SUM(quarterly_revenue) OVER (PARTITION BY year ORDER BY quarter), 2) AS running_total_in_year "
+             "FROM quarterly "
+             "ORDER BY year ASC, quarter ASC"
+         ),
+         order_sensitive=True,
+     ),
1432
+     SQLTemplate(
+         id="hard_014",
+         difficulty="hard",
+         description=(
+             "Find the top-spending customer in each country using ROW_NUMBER. "
+             "Return country, customer_name, total_spent (rounded to 2 dp). "
+             "Sort by country ascending."
+         ),
+         sql=(
+             "WITH customer_spend AS ( "
+             "    SELECT c.id, c.name AS customer_name, c.country, "
+             "    ROUND(SUM(o.total_amount), 2) AS total_spent "
+             "    FROM customers c "
+             "    JOIN orders o ON o.customer_id = c.id "
+             "    GROUP BY c.id, c.name, c.country "
+             "), "
+             "ranked AS ( "
+             "    SELECT country, customer_name, total_spent, "
+             "    ROW_NUMBER() OVER (PARTITION BY country ORDER BY total_spent DESC) AS rn "
+             "    FROM customer_spend "
+             ") "
+             "SELECT country, customer_name, total_spent "
+             "FROM ranked "
+             "WHERE rn = 1 "
+             "ORDER BY country ASC"
+         ),
+         order_sensitive=True,
+     ),
+     SQLTemplate(
+         id="hard_015",
+         difficulty="hard",
+         description=(
+             "Find products that have received both 1-star and 5-star reviews. "
+             "Use two CTEs: one for 1-star products, one for 5-star products, then intersect. "
+             "Return product_name. Sort by product_name ascending."
+         ),
+         sql=(
+             "WITH one_star AS ( "
+             "    SELECT DISTINCT product_id FROM reviews WHERE rating = 1 "
+             "), "
+             "five_star AS ( "
+             "    SELECT DISTINCT product_id FROM reviews WHERE rating = 5 "
+             ") "
+             "SELECT p.name AS product_name "
+             "FROM products p "
+             "JOIN one_star os ON os.product_id = p.id "
+             "JOIN five_star fs ON fs.product_id = p.id "
+             "ORDER BY product_name ASC"
+         ),
+         order_sensitive=True,
+     ),
+ ]
1484
+
+ ALL_TEMPLATES: List[SQLTemplate] = EASY_TEMPLATES + MEDIUM_TEMPLATES + HARD_TEMPLATES
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Personas
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ SCHEMA_CONTEXT = """
+ DATABASE SCHEMA (SQLite e-commerce):
+   categories(id, name)
+   products(id, name, category_id, price, stock_quantity)
+   customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
+   orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
+          created_at, total_amount)
+   order_items(id, order_id, product_id, quantity, unit_price)
+   reviews(id, product_id, customer_id, rating∈1-5, created_at)
+ """
1502
+
+ PERSONA_SPECS = {
+     "ceo": (
+         "You are a senior business executive. Write one SHORT, direct question in active voice, "
+         "as if you are asking an analyst to pull a number fast. Be terse, no fluff. "
+         "Use business language: 'revenue', 'customers', 'performance', not technical SQL terms."
+     ),
+     "chatty": (
+         "You are a friendly but verbose non-technical employee. Write one long, conversational "
+         "question with filler phrases like 'Could you please tell me...', 'I was wondering if...'; "
+         "passive voice is fine. Use everyday words like 'money' instead of 'revenue', "
+         "'people' instead of 'customers'."
+     ),
+     "lazy": (
+         "You are typing quickly on a phone. Write an extremely short question with abbreviations, "
+         "lowercase letters, and minor spelling mistakes. Skip articles and punctuation where possible. "
+         "Example style: 'top 5 prods by sales?', 'hw many cust in usa'."
+     ),
+     "confused": (
+         "You are a non-technical user who is unsure of the exact terminology. Write one question "
+         "using synonyms and vague language. Replace 'revenue' with 'money made', 'customers' with "
+         "'people' or 'users' or 'accounts', 'orders' with 'purchases' or 'transactions', "
+         "'tier' with 'membership level'. Include a bit of ambiguity."
+     ),
+     "analyst": (
+         "You are a data analyst with technical knowledge. Write one precise, jargon-heavy question "
+         "using terms like 'aggregate', 'partition', 'metric', 'fiscal period', 'segmented by', "
+         "'cohort', 'granularity'. Be specific about column names and filters."
+     ),
+ }
1532
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Rule-based Augmentor
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ class RuleAugmentor:
+     """
+     Applies deterministic, non-LLM transformations to a generated NL question.
+     Returns an augmented variant, or None if no rule applied.
+     """
+
+     SYNONYMS: Dict[str, List[str]] = {
+         "customers": ["clients", "users", "accounts", "shoppers", "buyers"],
+         "orders": ["purchases", "transactions", "sales", "bookings"],
+         "products": ["items", "goods", "listings", "SKUs"],
+         "revenue": ["sales", "income", "earnings", "money made"],
+         "spending": ["expenditure", "purchases", "money spent"],
+         "delivered": ["completed", "fulfilled", "received"],
+         "cancelled": ["canceled", "voided", "aborted"],
+         "pending": ["waiting", "unprocessed", "queued"],
+         "gold": ["premium", "top-tier", "VIP", "platinum"],
+         "silver": ["mid-tier", "standard-plus"],
+         "bronze": ["basic", "standard", "entry-level"],
+         "rating": ["score", "star rating", "review score"],
+         "country": ["region", "location", "geography", "nation"],
+         "category": ["department", "section", "type", "group"],
+         "price": ["cost", "value", "amount", "fee"],
+         "total": ["sum", "aggregate", "combined", "overall"],
+         "average": ["mean", "typical", "avg"],
+         "show": ["list", "display", "give me", "get", "fetch"],
+         "find": ["identify", "locate", "get", "pull", "retrieve"],
+         "return": ["give me", "show", "list", "provide"],
+     }
1566
+
+     def augment(self, question: str, rng: random.Random) -> Optional[str]:
+         words = question.split()
+         changed = False
+         result = []
+         for w in words:
+             clean = w.lower().strip(".,?!;:")
+             if clean in self.SYNONYMS and rng.random() < 0.4:
+                 replacement = rng.choice(self.SYNONYMS[clean])
+                 # Preserve trailing punctuation
+                 punct = w[len(clean):] if w.lower().startswith(clean) else ""
+                 result.append(replacement + punct)
+                 changed = True
+             else:
+                 result.append(w)
+         if not changed:
+             return None
+         new_q = " ".join(result)
+         # Capitalise first letter
+         return new_q[0].upper() + new_q[1:] if new_q else new_q
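The synonym-swap pass above can be exercised in isolation. A minimal sketch (the tiny `SYNONYMS` table and the always-fire probability `p` are illustrative choices, not the class defaults):

```python
import random

# Illustrative mini version of RuleAugmentor.augment: replace known nouns
# with synonyms, preserving any trailing punctuation on the original word.
SYNONYMS = {"customers": ["clients", "buyers"], "orders": ["purchases"]}

def augment(question: str, rng: random.Random, p: float = 1.0):
    out, changed = [], False
    for w in question.split():
        clean = w.lower().strip(".,?!;:")
        if clean in SYNONYMS and rng.random() < p:
            # w[len(clean):] is the stripped punctuation (e.g. "?")
            out.append(rng.choice(SYNONYMS[clean]) + w[len(clean):])
            changed = True
        else:
            out.append(w)
    if not changed:
        return None          # no rule applied -> caller skips this variant
    s = " ".join(out)
    return s[0].upper() + s[1:]

print(augment("show customers with orders?", random.Random(0)))
```

Returning `None` when nothing matched lets the pipeline skip variants that would duplicate the original question.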
1586
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # vLLM Generator
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ class VLLMGenerator:
+     """
+     Async batched inference using the OpenAI-compatible vLLM endpoint.
+     vLLM exposes an OpenAI-compatible API, so we can reuse AsyncOpenAI.
+     """
+
+     def __init__(self, base_url: str, model: str, temperature: float = 0.8,
+                  max_tokens: int = 256, semaphore: int = 64):
+         self.client = AsyncOpenAI(base_url=base_url, api_key="NONE")
+         self.model = model
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self._sem = asyncio.Semaphore(semaphore)
1605
+
+     async def generate_one(
+         self,
+         system: str,
+         user: str,
+         retries: int = 3,
+     ) -> Optional[str]:
+         for attempt in range(retries):
+             try:
+                 async with self._sem:
+                     resp = await self.client.chat.completions.create(
+                         model=self.model,
+                         messages=[
+                             {"role": "system", "content": system},
+                             {"role": "user", "content": user},
+                         ],
+                         temperature=self.temperature,
+                         max_tokens=self.max_tokens,
+                     )
+                     # content can be None on some failure modes; guard before strip
+                     text = (resp.choices[0].message.content or "").strip()
+                     return text if text else None
+             except Exception as exc:
+                 wait = 2 ** attempt
+                 log.warning(f"vLLM call failed (attempt {attempt+1}): {exc}. Retrying in {wait}s.")
+                 await asyncio.sleep(wait)
+         return None
1631
+
+     async def generate_batch(
+         self,
+         requests: List[Tuple[str, str, str]],  # (request_id, system, user)
+     ) -> Dict[str, Optional[str]]:
+         """
+         Fire all requests concurrently (bounded by semaphore) and return a dict.
+         """
+         async def _one(rid, sys, usr):
+             return rid, await self.generate_one(sys, usr)
+
+         tasks = [_one(rid, sys, usr) for rid, sys, usr in requests]
+         results = await asyncio.gather(*tasks)
+         return {rid: text for rid, text in results}
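The fan-out pattern above (gather everything, let a semaphore cap concurrency) can be sketched without any API dependency; `fetch` and `run_batch` here are illustrative names standing in for the real client calls:

```python
import asyncio

# Sketch of the generate_batch pattern: fire all coroutines at once, but
# let a shared semaphore cap how many are in flight simultaneously.
async def fetch(rid: str, sem: asyncio.Semaphore):
    async with sem:
        await asyncio.sleep(0)        # stands in for the real API call
        return rid, f"answer-for-{rid}"

async def run_batch(rids):
    sem = asyncio.Semaphore(2)        # at most 2 requests in flight
    results = await asyncio.gather(*(fetch(r, sem) for r in rids))
    return dict(results)              # gather preserves input order

out = asyncio.run(run_batch(["a", "b", "c"]))
print(out)
```

Because each coroutine returns its own request id, results can be keyed reliably even though completion order is nondeterministic.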
1645
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # Data Factory
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ @dataclass
+ class DataPoint:
+     id: str
+     difficulty: str
+     persona: str
+     question: str
+     sql: str
+     db_result_ok: bool
+     augmented: bool
+
+     def to_training_prompt(self, system_prompt: str) -> Dict[str, Any]:
+         """
+         Return the dict structure expected by train.py / SFT pipelines.
+         Includes both the raw fields and a formatted 'messages' list.
+         """
+         user_content = (
+             f"SCHEMA:\n{SCHEMA_CONTEXT}\n\nQUESTION: {self.question}"
+         )
+         return {
+             **asdict(self),
+             "messages": [
+                 {"role": "system", "content": system_prompt},
+                 {"role": "user", "content": user_content},
+                 {"role": "assistant", "content": self.sql},
+             ],
+         }
1677
+
+
+ SYSTEM_PROMPT = (
+     "You are an expert SQL analyst working with a SQLite e-commerce database. "
+     "Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."
+ )
+
+
+ class DataFactory:
+     def __init__(
+         self,
+         generator: VLLMGenerator,
+         validator: SQLiteValidator,
+         augmentor: RuleAugmentor,
+         personas_per_template: int = 5,
+         aug_rounds: int = 2,
+         seed: int = 42,
+     ):
+         self.generator = generator
+         self.validator = validator
+         self.augmentor = augmentor
+         self.personas_per_template = personas_per_template
+         self.aug_rounds = aug_rounds
+         self.rng = random.Random(seed)
+
+     # ── Step 1: Validate all template SQLs ───────────────────────────────────
+
+     def validate_templates(self) -> List[SQLTemplate]:
+         log.info("Validating all SQL templates against seeded DB...")
+         valid = []
+         failed = []
+         for t in ALL_TEMPLATES:
+             ok, err = self.validator.validate(t.sql)
+             if ok:
+                 valid.append(t)
+             else:
+                 failed.append((t.id, err))
+         if failed:
+             log.error(f"FAILED templates (will be skipped): {failed}")
+         log.info(f"Templates validated: {len(valid)} ok, {len(failed)} failed.")
+         return valid
1718
+
+     # ── Step 2: Build generation requests ────────────────────────────────────
+
+     def _build_requests(
+         self,
+         templates: List[SQLTemplate],
+         persona_names: List[str],
+     ) -> List[Tuple[str, str, str]]:
+         """
+         Returns a flat list of (request_id, system_prompt, user_prompt) tuples.
+         """
+         requests = []
+         for t in templates:
+             chosen_personas = (
+                 persona_names
+                 if self.personas_per_template >= len(PERSONA_SPECS)
+                 else self.rng.sample(persona_names, self.personas_per_template)
+             )
+             for persona in chosen_personas:
+                 rid = f"{t.id}__{persona}"
+                 system = (
+                     f"{PERSONA_SPECS[persona]}\n\n"
+                     "Output ONLY the natural language question. "
+                     "No explanation, no SQL, no preamble, no quotes around the question."
+                 )
+                 user = (
+                     f"{SCHEMA_CONTEXT}\n"
+                     f"The SQL query that answers this question is:\n{t.sql}\n\n"
+                     f"Write ONE natural-language question that a {persona.upper()} user "
+                     f"would ask to get this exact result."
+                 )
+                 requests.append((rid, system, user))
+         return requests
1751
+
+     # ── Step 3: Post-process a generated question ─────────────────────────────
+
+     @staticmethod
+     def _clean(text: str) -> str:
+         """Strip quotes, markdown, leading numbers, trailing newlines."""
+         text = text.strip()
+         # Remove leading numbering like "1. " or "Q: "
+         text = re.sub(r'^[\d]+[\.\)]\s+', '', text)
+         text = re.sub(r'^[Qq]:\s*', '', text)
+         # Strip surrounding quotes
+         if (text.startswith('"') and text.endswith('"')) or \
+            (text.startswith("'") and text.endswith("'")):
+             text = text[1:-1].strip()
+         # Collapse multiple whitespace
+         text = re.sub(r'\s+', ' ', text)
+         return text
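The post-processing regexes above are easy to verify with a standalone copy (duplicated here as a free function so they can be exercised outside the class):

```python
import re

# Standalone copy of DataFactory._clean, for checking the regexes in isolation.
def clean(text: str) -> str:
    text = text.strip()
    text = re.sub(r'^[\d]+[\.\)]\s+', '', text)   # "1. " / "2) " prefixes
    text = re.sub(r'^[Qq]:\s*', '', text)          # "Q: " prefixes
    if (text.startswith('"') and text.endswith('"')) or \
       (text.startswith("'") and text.endswith("'")):
        text = text[1:-1].strip()                  # surrounding quotes
    return re.sub(r'\s+', ' ', text)               # collapse whitespace

print(clean('  1. "How many   orders shipped?"  '))
# How many orders shipped?
```

Note the order matters: numbering is stripped before the quote check, so `1. "..."` still has its quotes removed.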
1768
+
+     # ── Main pipeline ─────────────────────────────────────────────────────────
+
+     async def run(
+         self,
+         output_path: str,
+         checkpoint_path: str,
+         batch_size: int = 64,
+     ) -> None:
+         # -- Validate templates
+         templates = self.validate_templates()
+
+         # -- Load checkpoint
+         done_ids: set = set()
+         if os.path.exists(checkpoint_path):
+             with open(checkpoint_path) as f:
+                 done_ids = set(json.loads(line)["id"] for line in f if line.strip())
+             log.info(f"Resuming: {len(done_ids)} examples already generated.")
+
+         persona_names = list(PERSONA_SPECS.keys())[: self.personas_per_template]
+
+         all_requests = self._build_requests(templates, persona_names)
+         # Filter already done
+         pending = [r for r in all_requests if r[0] not in done_ids]
+         log.info(f"Total requests to generate: {len(pending)}")
+
+         # -- Build template lookup
+         tmpl_lookup: Dict[str, SQLTemplate] = {t.id: t for t in templates}
+
+         stats = {"generated": 0, "invalid_llm": 0, "augmented": 0}
+
+         out_f = open(output_path, "a")
+         ckpt_f = open(checkpoint_path, "a")
+
+         try:
+             for i in tqdm(range(0, len(pending), batch_size), desc="Batches"):
+                 batch = pending[i: i + batch_size]
+                 results = await self.generator.generate_batch(batch)
+
+                 for rid, raw_text in results.items():
+                     tmpl_id, persona = rid.split("__", 1)
+                     tmpl = tmpl_lookup[tmpl_id]
+
+                     if not raw_text:
+                         stats["invalid_llm"] += 1
+                         continue
+
+                     question = self._clean(raw_text)
+                     if len(question) < 8:
+                         stats["invalid_llm"] += 1
+                         continue
+
+                     # SQL already validated; no need to re-run for NL variants
+                     dp = DataPoint(
+                         id=rid,
+                         difficulty=tmpl.difficulty,
+                         persona=persona,
+                         question=question,
+                         sql=tmpl.sql,
+                         db_result_ok=True,
+                         augmented=False,
+                     )
+                     record = dp.to_training_prompt(SYSTEM_PROMPT)
+                     line = json.dumps(record, ensure_ascii=False)
+                     out_f.write(line + "\n")
+                     ckpt_f.write(line + "\n")
+                     stats["generated"] += 1
+
+                     # -- Rule augmentation rounds
+                     for aug_i in range(self.aug_rounds):
+                         aug_q = self.augmentor.augment(question, self.rng)
+                         if aug_q and aug_q != question:
+                             aug_dp = DataPoint(
+                                 id=f"{rid}__aug{aug_i}",
+                                 difficulty=tmpl.difficulty,
+                                 persona=persona,
+                                 question=aug_q,
+                                 sql=tmpl.sql,
+                                 db_result_ok=True,
+                                 augmented=True,
+                             )
+                             aug_record = aug_dp.to_training_prompt(SYSTEM_PROMPT)
+                             aug_line = json.dumps(aug_record, ensure_ascii=False)
+                             out_f.write(aug_line + "\n")
+                             ckpt_f.write(aug_line + "\n")
+                             stats["augmented"] += 1
+
+                 out_f.flush()
+                 ckpt_f.flush()
+
+         finally:
+             out_f.close()
+             ckpt_f.close()
+
+         log.info(
+             f"Done. Generated={stats['generated']} "
+             f"Augmented={stats['augmented']} "
+             f"LLM failures={stats['invalid_llm']}"
+         )
+         log.info(f"Output: {output_path}")
1868
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # CLI
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ def parse_args() -> argparse.Namespace:
+     p = argparse.ArgumentParser(
+         description="NL2SQL Synthetic Data Factory — H100 + vLLM",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+     p.add_argument("--vllm-url", default="http://localhost:8001/v1",
+                    help="Base URL of the running vLLM server.")
+     p.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
+                    help="Model name as registered in the vLLM server.")
+     p.add_argument("--output", default="nl2sql_train.jsonl",
+                    help="Path to write the final JSONL dataset.")
+     p.add_argument("--checkpoint", default="nl2sql_checkpoint.jsonl",
+                    help="Path for the checkpoint file (enables resume on crash).")
+     p.add_argument("--personas-per-template", type=int, default=5,
+                    help="Number of persona variants to generate per SQL template (max 5).")
+     p.add_argument("--aug-rounds", type=int, default=2,
+                    help="Number of rule-based augmentation rounds per generated question.")
+     p.add_argument("--batch-size", type=int, default=64,
+                    help="Concurrent vLLM requests per batch (tune based on GPU memory).")
+     p.add_argument("--temperature", type=float, default=0.85,
+                    help="Sampling temperature for vLLM (higher = more diverse).")
+     p.add_argument("--max-tokens", type=int, default=200,
+                    help="Max tokens for each generated question.")
+     p.add_argument("--seed", type=int, default=42)
+     p.add_argument("--validate-only", action="store_true",
+                    help="Only validate SQL templates, do not generate data.")
+     return p.parse_args()
1901
+
+
+ async def main() -> None:
+     args = parse_args()
+
+     # Build DB + validator
+     conn = build_db()
+     validator = SQLiteValidator(conn)
+
+     if args.validate_only:
+         valid = [t for t in ALL_TEMPLATES if validator.validate(t.sql)[0]]
+         invalid = [t for t in ALL_TEMPLATES if not validator.validate(t.sql)[0]]
+         print(f"\n✅ Valid: {len(valid)}")
+         print(f"❌ Invalid: {len(invalid)}")
+         for t in invalid:
+             _, err = validator.validate(t.sql)
+             print(f"   {t.id}: {err}")
+         return
+
+     # Build pipeline components
+     generator = VLLMGenerator(
+         base_url=args.vllm_url,
+         model=args.model,
+         temperature=args.temperature,
+         max_tokens=args.max_tokens,
+         semaphore=args.batch_size,
+     )
+     augmentor = RuleAugmentor()
+
+     factory = DataFactory(
+         generator=generator,
+         validator=validator,
+         augmentor=augmentor,
+         personas_per_template=min(args.personas_per_template, len(PERSONA_SPECS)),
+         aug_rounds=args.aug_rounds,
+         seed=args.seed,
+     )
+
+     await factory.run(
+         output_path=args.output,
+         checkpoint_path=args.checkpoint,
+         batch_size=args.batch_size,
+     )
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
data_factory/generator.py ADDED
@@ -0,0 +1,410 @@
+ """
+ data_factory/generator.py
+ ==========================
+ vLLM-based Natural Language question generator for H100.
+
+ This module uses a large LLM (Llama-3-70B or Qwen-72B) served via vLLM
+ to generate diverse, persona-based natural language paraphrases of the
+ canonical NL questions in our template library.
+
+ KEY DESIGN: The LLM generates ONLY natural language questions.
+ SQL is NEVER touched by the LLM.
+ This guarantees the LLM cannot introduce SQL errors into the final dataset.
+
+ Persona descriptions:
+     ceo          - Direct, short, active voice. Business executive style.
+     chatty       - Conversational, verbose, passive voice.
+     lazy_typist  - Short, abbreviations, possible informal grammar.
+     non_techie   - Plain English, avoids SQL/tech jargon, uses synonyms.
+     analyst      - Technical, precise, jargon-heavy.
+
+ Usage (on H100 cluster):
+     python -m data_factory.generator --templates-per-chunk 20 --n-variants 10
+ """
+
25
+ from __future__ import annotations
26
+
27
+ import json
28
+ import logging
29
+ import time
30
+ from typing import Iterator
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
+ # ─────────────────────────────────────────────────────────────────────────────
36
+ # PERSONA SYSTEM PROMPTS
37
+ # ─────────────────────────────────────────────────────────────────────────────
38
+
39
+ PERSONA_SYSTEM_PROMPTS: dict[str, str] = {
40
+
41
+ "ceo": (
42
+ "You are a busy C-level executive who communicates in short, punchy, "
43
+ "direct sentences. You use active voice, skip filler words, and get "
44
+ "straight to the point. You are asking a data analyst for information."
45
+ ),
46
+
47
+ "chatty": (
48
+ "You are a friendly, conversational person who likes to be thorough "
49
+ "and explain things fully. You use passive voice sometimes, add context, "
50
+ "and ask questions in a relaxed, detailed way. You are not technical."
51
+ ),
52
+
53
+ "lazy_typist": (
54
+ "You type quickly and informally. You use abbreviations (e.g. 'pls', "
55
+ "'lmk', 'asap'), lowercase, minimal punctuation, and sometimes omit "
56
+ "words. You get your meaning across without perfect grammar."
57
+ ),
58
+
59
+ "non_techie": (
60
+ "You have no database or SQL knowledge. You use everyday English words "
61
+ "instead of technical terms. For example, you say 'customers' not 'rows', "
62
+ "'most expensive' not 'highest price', 'total money' not 'sum'. "
63
+ "You describe what you want to see, not how to get it."
64
+ ),
65
+
66
+ "analyst": (
67
+ "You are a data scientist or BI analyst who is precise and technical. "
68
+ "You use terms like 'aggregate', 'partition', 'granularity', 'distinct', "
69
+ "'filter predicate', 'ranked by metric'. Your questions are precise and unambiguous."
70
+ ),
71
+ }
72
+
73
+
74
+ # ─────────────────────────────────────────────────────────────────────────────
75
+ # PROMPT BUILDER
76
+ # ─────────────────────────────────────────────────────────────────────────────
77
+
78
+ def build_generation_prompt(
79
+ canonical_nl: str,
80
+ description: str,
81
+ persona: str,
82
+ schema_context: str,
83
+ n_variants: int = 5,
84
+ ) -> list[dict[str, str]]:
85
+ """
86
+ Build a chat-format prompt asking the LLM to rephrase the canonical NL
87
+ question in the style of the given persona.
88
+
89
+ Parameters
90
+ ----------
91
+ canonical_nl : The base NL question from the template.
92
+ description : One-line SQL description (gives the LLM additional context).
93
+ persona : One of the 5 persona keys.
94
+ schema_context : The compact schema string for the domain.
95
+ n_variants : How many rephrased questions to generate.
96
+
97
+ Returns
98
+ -------
99
+ list[dict] Chat messages in [{"role": ..., "content": ...}] format.
100
+ """
101
+ persona_desc = PERSONA_SYSTEM_PROMPTS[persona]
102
+
103
+ system = (
104
+ "You are a data labelling specialist. Your task is to rephrase a database "
105
+ "question in a specific communication style (persona). The rephrased questions "
106
+ "must preserve the EXACT same intent and required information as the original — "
107
+ "do not change what data is being asked for, only how it is expressed.\n\n"
108
+ f"PERSONA: {persona_desc}\n\n"
109
+ "OUTPUT FORMAT: Return ONLY a valid JSON array of strings. "
110
+ "No preamble, no markdown, no extra keys. Example: "
111
+ '["question 1", "question 2", "question 3"]'
112
+ )
113
+
114
+ user = (
115
+ f"DATABASE CONTEXT:\n{schema_context}\n\n"
116
+ f"WHAT THE QUERY DOES: {description}\n\n"
117
+ f"CANONICAL QUESTION: {canonical_nl}\n\n"
118
+ f"Generate {n_variants} different ways a person with the persona described "
119
+ f"above would ask this same question. The meaning must stay identical."
120
+ )
121
+
122
+ return [
123
+ {"role": "system", "content": system},
124
+ {"role": "user", "content": user},
125
+ ]
126
+
127
+
128
+ # ─────────────────────────────────────────────────────────────────────────────
129
+ # RESPONSE PARSER
130
+ # ─────────────────────────────────────────────────────────────────────────────
131
+
132
+ def parse_llm_response(raw_text: str) -> list[str]:
133
+ """
134
+ Extract a list of strings from the LLM's JSON response.
135
+ Handles common failures: markdown fences, trailing commas, extra text.
136
+
137
+ Returns an empty list if parsing fails completely.
138
+ """
139
+ text = raw_text.strip()
140
+
141
+ # Strip markdown fences if present
142
+ if text.startswith("```"):
143
+ lines = text.split("\n")
144
+ text = "\n".join(line for line in lines if not line.strip().startswith("```")).strip()
145
+
146
+ # Find the JSON array boundaries
147
+ start = text.find("[")
148
+ end = text.rfind("]")
149
+ if start == -1 or end == -1 or end <= start:
150
+ logger.warning("LLM response missing JSON array brackets: %s", text[:100])
151
+ return []
152
+
153
+ json_str = text[start:end + 1]
154
+
155
+ # Fix a trailing comma before the closing ] (common LLM mistake)
156
+ json_str = json_str.rstrip()
157
+ json_str = json_str[:-1].rstrip().rstrip(",") + "]" if json_str.endswith("]") else json_str
158
+
159
+ try:
160
+ parsed = json.loads(json_str)
161
+ if not isinstance(parsed, list):
162
+ return []
163
+ # Filter to only non-empty strings
164
+ return [s.strip() for s in parsed if isinstance(s, str) and s.strip()]
165
+ except json.JSONDecodeError as exc:
166
+ logger.warning("JSON parse error: %s | text: %s", exc, json_str[:200])
167
+ return []
168
+
169
+
170
+ # ─────────────────────────────────────────────────────────────────────────────
171
+ # VLLM INTERFACE
172
+ # ─────────────────────────────────────────────────────────────────────────────
173
+
174
+ class VLLMGenerator:
175
+ """
176
+ Wrapper around a running vLLM server for high-throughput NL generation.
177
+
178
+ Supports two modes:
179
+ online : Calls a running vLLM OpenAI-compatible API server.
180
+ offline : Uses vllm.LLM directly (loads model in-process, H100 recommended).
181
+
182
+ For H100 cluster usage, prefer 'offline' mode with tensor_parallel_size=4
183
+ to saturate all 4 H100s for maximum throughput.
184
+ """
185
+
186
+ def __init__(
187
+ self,
188
+ model_name: str,
189
+ mode: str = "offline",
190
+ tensor_parallel_size: int = 4,
191
+ gpu_memory_utilization: float = 0.90,
192
+ max_model_len: int = 4096,
193
+ # Online mode only
194
+ api_base: str = "http://localhost:8000/v1",
195
+ api_key: str = "EMPTY",
196
+ ) -> None:
197
+ self.model_name = model_name
198
+ self.mode = mode
199
+ self._llm = None
200
+ self._client = None
201
+
202
+ if mode == "offline":
203
+ self._init_offline(tensor_parallel_size, gpu_memory_utilization, max_model_len)
204
+ elif mode == "online":
205
+ self._init_online(api_base, api_key)
206
+ else:
207
+ raise ValueError(f"Unknown mode: {mode!r}. Use 'offline' or 'online'.")
208
+
209
+ def _init_offline(
210
+ self,
211
+ tensor_parallel_size: int,
212
+ gpu_memory_utilization: float,
213
+ max_model_len: int,
214
+ ) -> None:
215
+ """Load vLLM engine in-process (best for H100 cluster)."""
216
+ try:
217
+ from vllm import LLM, SamplingParams
218
+ self._LLM = LLM
219
+ self._SamplingParams = SamplingParams
220
+ except ImportError as exc:
221
+ raise ImportError(
222
+ "vLLM not installed. Run: pip install vllm\n"
223
+ "For H100: pip install vllm --extra-index-url https://download.pytorch.org/whl/cu124"
224
+ ) from exc
225
+
226
+ logger.info("Loading model %s with %d GPUs (offline mode)...", self.model_name, tensor_parallel_size)
227
+ t0 = time.time()
228
+ self._llm = self._LLM(
229
+ model=self.model_name,
230
+ tensor_parallel_size=tensor_parallel_size,
231
+ gpu_memory_utilization=gpu_memory_utilization,
232
+ max_model_len=max_model_len,
233
+ dtype="bfloat16",
234
+ trust_remote_code=True,
235
+ )
236
+ logger.info("Model loaded in %.1f seconds.", time.time() - t0)
237
+
238
+ def _init_online(self, api_base: str, api_key: str) -> None:
239
+ """Use OpenAI-compatible vLLM server (for distributed setups)."""
240
+ try:
241
+ from openai import OpenAI
242
+ self._client = OpenAI(base_url=api_base, api_key=api_key)
243
+ except ImportError as exc:
244
+ raise ImportError("openai package not installed. Run: pip install openai") from exc
245
+ logger.info("Connected to vLLM server at %s", api_base)
246
+
247
+ def generate_batch(
248
+ self,
249
+ prompts: list[list[dict[str, str]]],
250
+ temperature: float = 0.85,
251
+ max_new_tokens: int = 300,
252
+ ) -> list[str]:
253
+ """
254
+ Generate responses for a batch of chat prompts.
255
+
256
+ Parameters
257
+ ----------
258
+ prompts : List of chat message lists (one per item in batch).
259
+ temperature : Sampling temperature. Higher = more diverse.
260
+ max_new_tokens : Max tokens per response.
261
+
262
+ Returns
263
+ -------
264
+ list[str] Raw text response per prompt (same length as input).
265
+ """
266
+ if self.mode == "offline":
267
+ return self._generate_offline(prompts, temperature, max_new_tokens)
268
+ else:
269
+ return self._generate_online(prompts, temperature, max_new_tokens)
270
+
271
+ def _generate_offline(
272
+ self,
273
+ prompts: list[list[dict]],
274
+ temperature: float,
275
+ max_new_tokens: int,
276
+ ) -> list[str]:
277
+ """vLLM offline batched generation — maximises H100 throughput."""
278
+ # Reuse the SamplingParams class cached by _init_offline
279
+
280
+ sampling = self._SamplingParams(
281
+ temperature=temperature,
282
+ max_tokens=max_new_tokens,
283
+ stop=["</s>", "<|eot_id|>"], # Llama-3 stop tokens
284
+ )
285
+
286
+ # Convert chat messages to tokenised prompt strings using the model's template
287
+ tokenizer = self._llm.get_tokenizer()
288
+ formatted_prompts: list[str] = []
289
+ for msgs in prompts:
290
+ if hasattr(tokenizer, "apply_chat_template"):
291
+ text = tokenizer.apply_chat_template(
292
+ msgs, tokenize=False, add_generation_prompt=True
293
+ )
294
+ else:
295
+ # Fallback: simple concatenation
296
+ text = "\n".join(
297
+ f"<|{m['role']}|>\n{m['content']}" for m in msgs
298
+ )
299
+ formatted_prompts.append(text)
300
+
301
+ outputs = self._llm.generate(formatted_prompts, sampling)
302
+ return [o.outputs[0].text for o in outputs]
303
+
304
+ def _generate_online(
305
+ self,
306
+ prompts: list[list[dict]],
307
+ temperature: float,
308
+ max_new_tokens: int,
309
+ ) -> list[str]:
310
+ """Sequential generation via OpenAI-compatible API (fallback / debugging)."""
311
+ results = []
312
+ for msgs in prompts:
313
+ try:
314
+ resp = self._client.chat.completions.create(
315
+ model=self.model_name,
316
+ messages=msgs,
317
+ temperature=temperature,
318
+ max_tokens=max_new_tokens,
319
+ )
320
+ results.append(resp.choices[0].message.content or "")
321
+ except Exception as exc:
322
+ logger.warning("API call failed: %s", exc)
323
+ results.append("")
324
+ return results
325
+
326
+
327
+ # ─────────────────────────────────────────────────────────────────────────────
328
+ # HIGH-LEVEL GENERATION LOOP
329
+ # ─────────────────────────────────────────────────────────────────────────────
330
+
331
+ def generate_persona_variants_batch(
332
+ templates_subset: list[dict],
333
+ generator: VLLMGenerator,
334
+ personas: list[str],
335
+ n_variants_per_persona: int = 5,
336
+ batch_size: int = 64,
337
+ temperature: float = 0.85,
338
+ max_new_tokens: int = 300,
339
+ ) -> Iterator[dict]:
340
+ """
341
+ For each template × persona combination, generate `n_variants_per_persona`
342
+ NL question variants using the LLM.
343
+
344
+ Yields dicts:
345
+ {
346
+ "template_idx": int,
347
+ "persona": str,
348
+ "nl_variants": list[str], # successfully parsed NL questions
349
+ }
350
+
351
+ Parameters
352
+ ----------
353
+ templates_subset : List of template dicts (from templates.py).
354
+ generator : VLLMGenerator instance.
355
+ personas : List of persona keys to use.
356
+ n_variants_per_persona : How many NL variants per (template, persona) pair.
357
+ batch_size : How many LLM calls to batch together.
358
+ temperature : Sampling temperature.
359
+ max_new_tokens : Max tokens for LLM response (should be ~300 for JSON array).
360
+ """
361
+ from data_factory.schemas import SCHEMA_CONTEXT
362
+
363
+ # Build all (template_idx, persona) prompt pairs
364
+ all_jobs: list[tuple[int, str, list[dict]]] = []
365
+
366
+ for t_idx, template in enumerate(templates_subset):
367
+ schema_ctx = SCHEMA_CONTEXT[template["domain"]]
368
+ for persona in personas:
369
+ prompt = build_generation_prompt(
370
+ canonical_nl=template["base_nl"],
371
+ description=template["description"],
372
+ persona=persona,
373
+ schema_context=schema_ctx,
374
+ n_variants=n_variants_per_persona,
375
+ )
376
+ all_jobs.append((t_idx, persona, prompt))
377
+
378
+ total_jobs = len(all_jobs)
379
+ logger.info("Starting LLM generation: %d jobs (templates × personas).", total_jobs)
380
+
381
+ # Process in batches
382
+ for batch_start in range(0, total_jobs, batch_size):
383
+ batch = all_jobs[batch_start: batch_start + batch_size]
384
+ prompts = [job[2] for job in batch]
385
+
386
+ t0 = time.time()
387
+ raw_responses = generator.generate_batch(
388
+ prompts, temperature=temperature, max_new_tokens=max_new_tokens
389
+ )
390
+ elapsed = time.time() - t0
391
+ logger.info(
392
+ "Batch %d-%d completed in %.1fs (%.1f jobs/s).",
393
+ batch_start, batch_start + len(batch), elapsed, len(batch) / max(elapsed, 0.001)
394
+ )
395
+
396
+ for (t_idx, persona, _), raw in zip(batch, raw_responses):
397
+ nl_variants = parse_llm_response(raw)
398
+ if not nl_variants:
399
+ logger.debug(
400
+ "Empty parse for template_idx=%d persona=%s. raw=%s",
401
+ t_idx, persona, raw[:100]
402
+ )
403
+ # Fall back to the canonical NL rather than losing this entry
404
+ nl_variants = [templates_subset[t_idx]["base_nl"]]
405
+
406
+ yield {
407
+ "template_idx": t_idx,
408
+ "persona": persona,
409
+ "nl_variants": nl_variants,
410
+ }
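The recovery strategy in `parse_llm_response` above (strip markdown fences, locate the array brackets, drop a trailing comma, then filter to non-empty strings) can be exercised standalone. This is a minimal sketch of the same approach, not the module's exact function:

```python
import json

def extract_json_array(raw: str) -> list[str]:
    """Recover a JSON array of strings from noisy LLM output."""
    text = raw.strip()
    # Drop markdown fences the model may have wrapped around the JSON
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.split("\n")
            if not line.strip().startswith("```")
        ).strip()
    # Locate the outermost array brackets
    start, end = text.find("["), text.rfind("]")
    if start == -1 or end <= start:
        return []
    json_str = text[start:end + 1].rstrip()
    # Remove a trailing comma before the closing bracket
    if json_str.endswith("]"):
        json_str = json_str[:-1].rstrip().rstrip(",") + "]"
    try:
        parsed = json.loads(json_str)
    except json.JSONDecodeError:
        return []
    if not isinstance(parsed, list):
        return []
    # Keep only non-empty strings
    return [s.strip() for s in parsed if isinstance(s, str) and s.strip()]

noisy = '```json\n["show revenue", "list sales",]\n```'
print(extract_json_array(noisy))  # ['show revenue', 'list sales']
```

Returning an empty list on any failure (rather than raising) lets the caller fall back to the canonical NL question, which is exactly what `generate_persona_variants_batch` does.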
data_factory/pipeline.py ADDED
@@ -0,0 +1,443 @@
1
+ """
2
+ data_factory/pipeline.py
3
+ =========================
4
+ Master orchestration pipeline for the NL2SQL Synthetic Data Factory.
5
+
6
+ This module ties together:
7
+ 1. Template library (66 verified SQL templates across 4 domains)
8
+ 2. Rule-based NL augmentation (augmentor.py)
9
+ 3. vLLM persona-based NL generation (generator.py)
10
+ 4. SQL execution validation (validator.py)
11
+ 5. Output serialisation (JSONL + Parquet)
12
+
13
+ Run modes:
14
+ --mode base : Only uses template base_nl + rule augmentation (no GPU required)
15
+ --mode full : base + vLLM persona generation (requires H100)
16
+
17
+ Output dataset format (JSONL, one record per line):
18
+ {
19
+ "prompt": [{"role": "system", ...}, {"role": "user", ...}],
20
+ "sql": "SELECT ...",
21
+ "metadata": { "domain", "difficulty", "persona", ... }
22
+ }
23
+
24
+ This format is directly loadable by:
25
+ datasets.load_dataset("json", data_files="output/train.jsonl")
26
+ """
27
+
28
+ from __future__ import annotations
29
+
30
+ import argparse
31
+ import json
32
+ import logging
33
+ import os
34
+ import random
35
+ import time
36
+ from pathlib import Path
37
+ from typing import Optional
38
+
39
+ logging.basicConfig(
40
+ level=logging.INFO,
41
+ format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
42
+ datefmt="%H:%M:%S",
43
+ )
44
+ logger = logging.getLogger("pipeline")
45
+
46
+
47
+ # ─────────────────────────────────────────────────────────────────────────────
48
+ # HELPERS
49
+ # ─────────────────────────────────────────────────────────────────────────────
50
+
51
+ def _ensure_dirs(*dirs: Path) -> None:
52
+ for d in dirs:
53
+ d.mkdir(parents=True, exist_ok=True)
54
+
55
+
56
+ def _write_jsonl(records: list[dict], path: Path) -> None:
57
+ with open(path, "w", encoding="utf-8") as f:
58
+ for rec in records:
59
+ f.write(json.dumps(rec, ensure_ascii=False) + "\n")
60
+ logger.info("Wrote %d records to %s", len(records), path)
61
+
62
+
63
+ def _write_parquet(records: list[dict], path: Path) -> None:
64
+ try:
65
+ import pandas as pd
66
+ df = pd.DataFrame(records)
67
+ df.to_parquet(path, index=False, engine="pyarrow", compression="snappy")
68
+ logger.info("Wrote %d records to %s (Parquet)", len(records), path)
69
+ except ImportError:
70
+ logger.warning("pandas/pyarrow not installed — skipping Parquet output.")
71
+
72
+
73
+ def _train_val_test_split(
74
+ records: list[dict],
75
+ train_frac: float = 0.90,
76
+ val_frac: float = 0.05,
77
+ seed: int = 42,
78
+ ) -> tuple[list[dict], list[dict], list[dict]]:
79
+ """
80
+ Stratified split by (domain, difficulty) so every combination appears
81
+ in train; very small buckets may leave val or test sparse.
82
+ """
83
+ rng = random.Random(seed)
84
+ from collections import defaultdict
85
+
86
+ buckets: dict[str, list[dict]] = defaultdict(list)
87
+ for rec in records:
88
+ key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}"
89
+ buckets[key].append(rec)
90
+
91
+ train, val, test = [], [], []
92
+ for key, bucket in buckets.items():
93
+ rng.shuffle(bucket)
94
+ n = len(bucket)
95
+ n_train = max(1, int(n * train_frac))
96
+ n_val = max(1, int(n * val_frac))
97
+ train.extend(bucket[:n_train])
98
+ val.extend(bucket[n_train:n_train + n_val])
99
+ test.extend(bucket[n_train + n_val:])
100
+
101
+ rng.shuffle(train)
102
+ rng.shuffle(val)
103
+ rng.shuffle(test)
104
+ return train, val, test
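The bucket-then-slice behaviour above can be checked standalone. This sketch assumes only that each record carries `metadata.domain` and `metadata.difficulty`, as the pipeline's records do:

```python
import random
from collections import defaultdict

def stratified_split(records, train_frac=0.90, val_frac=0.05, seed=42):
    """Shuffle within (domain, difficulty) buckets, then slice each bucket."""
    rng = random.Random(seed)
    buckets = defaultdict(list)
    for rec in records:
        key = (rec["metadata"]["domain"], rec["metadata"]["difficulty"])
        buckets[key].append(rec)
    train, val, test = [], [], []
    for bucket in buckets.values():
        rng.shuffle(bucket)
        n = len(bucket)
        # max(1, ...) guarantees train and val each get at least one record
        n_train = max(1, int(n * train_frac))
        n_val = max(1, int(n * val_frac))
        train.extend(bucket[:n_train])
        val.extend(bucket[n_train:n_train + n_val])
        test.extend(bucket[n_train + n_val:])
    return train, val, test

recs = [{"metadata": {"domain": "hr", "difficulty": "easy"}, "id": i}
        for i in range(100)]
tr, va, te = stratified_split(recs)
print(len(tr), len(va), len(te))  # 90 5 5
```

Because slicing happens per bucket, no record is dropped or duplicated: the three splits always partition the input.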
105
+
106
+
107
+ # ─────────────────────────────────────────────────────────────────────────────
108
+ # PHASE 1: BASE + RULE AUGMENTATION (no GPU required)
109
+ # ─────────────────────────────────────────────────────────────────────────────
110
+
111
+ def run_base_pipeline(
112
+ templates: list,
113
+ n_augmentations: int = 5,
114
+ seed: int = 42,
115
+ ) -> list[dict]:
116
+ """
117
+ Generate training records from:
118
+ (a) the canonical base_nl of each template
119
+ (b) rule-based augmented NL variants
120
+
121
+ Returns a list of training dicts (ready to write to JSONL).
122
+ """
123
+ from data_factory.augmentor import augment_nl
124
+ from data_factory.validator import SQLValidator, build_record
125
+ from data_factory.schemas import SCHEMA_MAP
126
+
127
+ # Build one validator per domain (reuse connection across templates)
128
+ validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
129
+ records: list[dict] = []
130
+
131
+ for t_idx, template in enumerate(templates):
132
+ v = validators[template["domain"]]
133
+
134
+ # (a) Canonical base_nl
135
+ rec = build_record(
136
+ template=template,
137
+ template_idx=t_idx,
138
+ nl_question=template["base_nl"],
139
+ persona="canonical",
140
+ source="template_base",
141
+ validator=v,
142
+ )
143
+ if rec:
144
+ records.append(rec.to_training_dict())
145
+
146
+ # (b) Rule-augmented variants
147
+ augmented = augment_nl(
148
+ nl_question=template["base_nl"],
149
+ n=n_augmentations,
150
+ seed=seed + t_idx,
151
+ )
152
+ for nl_variant in augmented:
153
+ rec = build_record(
154
+ template=template,
155
+ template_idx=t_idx,
156
+ nl_question=nl_variant,
157
+ persona="rule_augmented",
158
+ source="rule_augmented",
159
+ validator=v,
160
+ )
161
+ if rec:
162
+ records.append(rec.to_training_dict())
163
+
164
+ for v in validators.values():
165
+ v.close()
166
+
167
+ logger.info("Base pipeline: %d records generated from %d templates.", len(records), len(templates))
168
+ return records
169
+
170
+
171
+ # ─────────────────────────────────────────────────────────────────────────────
172
+ # PHASE 2: vLLM PERSONA GENERATION (H100 required)
173
+ # ─────────────────────────────────────────────────────────────────────────────
174
+
175
+ def run_vllm_pipeline(
176
+ templates: list,
177
+ generator, # VLLMGenerator instance
178
+ personas: list[str],
179
+ n_variants_per_persona: int = 10,
180
+ batch_size: int = 64,
181
+ temperature: float = 0.85,
182
+ max_new_tokens: int = 350,
183
+ seed: int = 42,
184
+ ) -> list[dict]:
185
+ """
186
+ Generate additional NL variants using the LLM, then validate SQL.
187
+
188
+ Returns a list of training dicts.
189
+ """
190
+ from data_factory.generator import generate_persona_variants_batch
191
+ from data_factory.validator import SQLValidator, build_record
192
+ from data_factory.schemas import SCHEMA_MAP
193
+
194
+ validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
195
+ records: list[dict] = []
196
+
197
+ gen_iter = generate_persona_variants_batch(
198
+ templates_subset=templates,
199
+ generator=generator,
200
+ personas=personas,
201
+ n_variants_per_persona=n_variants_per_persona,
202
+ batch_size=batch_size,
203
+ temperature=temperature,
204
+ max_new_tokens=max_new_tokens,
205
+ )
206
+
207
+ for job_result in gen_iter:
208
+ t_idx = job_result["template_idx"]
209
+ persona = job_result["persona"]
210
+ template = templates[t_idx]
211
+ v = validators[template["domain"]]
212
+
213
+ for nl_variant in job_result["nl_variants"]:
214
+ rec = build_record(
215
+ template=template,
216
+ template_idx=t_idx,
217
+ nl_question=nl_variant,
218
+ persona=persona,
219
+ source="vllm_persona",
220
+ validator=v,
221
+ )
222
+ if rec:
223
+ records.append(rec.to_training_dict())
224
+
225
+ for v in validators.values():
226
+ v.close()
227
+
228
+ logger.info("vLLM pipeline: %d records generated.", len(records))
229
+ return records
230
+
231
+
232
+ # ─────────────────────────────────────────────────────────────────────────────
233
+ # CHECKPOINT UTILITIES
234
+ # ─────────────────────────────────────────────────────────────────────────────
235
+
236
+ def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path:
237
+ path = checkpoint_dir / f"{name}.jsonl"
238
+ _write_jsonl(records, path)
239
+ return path
240
+
241
+
242
+ def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]:
243
+ path = checkpoint_dir / f"{name}.jsonl"
244
+ if not path.exists():
245
+ return None
246
+ records = []
247
+ with open(path, encoding="utf-8") as f:
248
+ for line in f:
249
+ line = line.strip()
250
+ if line:
251
+ records.append(json.loads(line))
252
+ logger.info("Loaded %d records from checkpoint %s", len(records), path)
253
+ return records
254
+
255
+
256
+ # ─────────────────────────────────────────────────────────────────────────────
257
+ # DATASET STATISTICS
258
+ # ─────────────────────────────────────────────────────────────────────────────
259
+
260
+ def print_dataset_stats(records: list[dict]) -> None:
261
+ from collections import Counter
262
+ domains = Counter(r["metadata"]["domain"] for r in records)
263
+ diffs = Counter(r["metadata"]["difficulty"] for r in records)
264
+ personas = Counter(r["metadata"]["persona"] for r in records)
265
+ sources = Counter(r["metadata"]["source"] for r in records)
266
+
267
+ print("\n" + "=" * 55)
268
+ print(f" DATASET STATISTICS ({len(records):,} total records)")
269
+ print("=" * 55)
270
+ print("\nBy Domain:")
271
+ for k, v in sorted(domains.items()):
272
+ print(f" {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
273
+ print("\nBy Difficulty:")
274
+ for k, v in sorted(diffs.items()):
275
+ print(f" {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
276
+ print("\nBy Persona:")
277
+ for k, v in sorted(personas.items()):
278
+ print(f" {k:20s}: {v:6,}")
279
+ print("\nBy Source:")
280
+ for k, v in sorted(sources.items()):
281
+ print(f" {k:20s}: {v:6,}")
282
+ print("=" * 55 + "\n")
283
+
284
+
285
+ # ─────────────────────────────────────────────────────────────────────────────
286
+ # MAIN ENTRY POINT
287
+ # ─────────────────────────────────────────────────────────────────────────────
288
+
289
+ def main() -> None:
290
+ parser = argparse.ArgumentParser(
291
+ description="NL2SQL Synthetic Data Factory — generates verified training data."
292
+ )
293
+ parser.add_argument(
294
+ "--mode", choices=["base", "full"], default="base",
295
+ help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
296
+ )
297
+ parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
298
+ help="HuggingFace model name for vLLM (full mode only).")
299
+ parser.add_argument("--tensor-parallel", type=int, default=4,
300
+ help="Tensor parallel size for vLLM (number of H100s).")
301
+ parser.add_argument("--n-rule-augments", type=int, default=5,
302
+ help="Number of rule-based NL augmentations per template.")
303
+ parser.add_argument("--n-persona-variants", type=int, default=10,
304
+ help="Number of vLLM NL variants per (template, persona) pair.")
305
+ parser.add_argument("--batch-size", type=int, default=64,
306
+ help="vLLM batch size (larger = faster on H100).")
307
+ parser.add_argument("--temperature", type=float, default=0.85,
308
+ help="Sampling temperature for vLLM generation.")
309
+ parser.add_argument("--output-dir", type=str, default="generated_data/output",
310
+ help="Directory to write final dataset files.")
311
+ parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
312
+ help="Directory for intermediate checkpoints.")
313
+ parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
314
+ parser.add_argument("--no-parquet", action="store_true",
315
+ help="Skip Parquet output (write only JSONL).")
316
+ parser.add_argument("--resume", action="store_true",
317
+ help="Resume from latest checkpoint if available.")
318
+ parser.add_argument("--domains", nargs="+",
319
+ choices=["ecommerce","healthcare","finance","hr"],
320
+ default=["ecommerce","healthcare","finance","hr"],
321
+ help="Domains to include (default: all 4).")
322
+ parser.add_argument("--difficulties", nargs="+",
323
+ choices=["easy","medium","hard"],
324
+ default=["easy","medium","hard"],
325
+ help="Difficulty levels to include (default: all 3).")
326
+ args = parser.parse_args()
327
+
328
+ output_dir = Path(args.output_dir)
329
+ checkpoint_dir = Path(args.checkpoint_dir)
330
+ _ensure_dirs(output_dir, checkpoint_dir)
331
+
332
+ # ── Load templates ─────────────────────────────────────────────────────
333
+ from data_factory.templates import ALL_TEMPLATES
334
+
335
+ templates = [
336
+ t for t in ALL_TEMPLATES
337
+ if t["domain"] in args.domains and t["difficulty"] in args.difficulties
338
+ ]
339
+ logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
340
+ len(templates), args.domains, args.difficulties)
341
+
342
+ # ── Phase 1: Base + rule augmentation ─────────────────────────────────
343
+ all_records: list[dict] = []
344
+
345
+ ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
346
+ if ckpt_base is not None:
347
+ all_records.extend(ckpt_base)
348
+ logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
349
+ else:
350
+ logger.info("=== Phase 1: Base + Rule Augmentation ===")
351
+ base_records = run_base_pipeline(
352
+ templates=templates,
353
+ n_augmentations=args.n_rule_augments,
354
+ seed=args.seed,
355
+ )
356
+ all_records.extend(base_records)
357
+ save_checkpoint(base_records, checkpoint_dir, "phase1_base")
358
+
359
+ # ── Phase 2: vLLM persona generation (full mode only) ─────────────────
360
+ if args.mode == "full":
361
+ ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
362
+ if ckpt_vllm is not None:
363
+ all_records.extend(ckpt_vllm)
364
+ logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
365
+ else:
366
+ logger.info("=== Phase 2: vLLM Persona Generation ===")
367
+
368
+ from data_factory.generator import VLLMGenerator
369
+ from data_factory.config import PERSONAS
370
+
371
+ generator = VLLMGenerator(
372
+ model_name=args.model,
373
+ mode="offline",
374
+ tensor_parallel_size=args.tensor_parallel,
375
+ gpu_memory_utilization=0.90,
376
+ )
377
+
378
+ vllm_records = run_vllm_pipeline(
379
+ templates=templates,
380
+ generator=generator,
381
+ personas=PERSONAS,
382
+ n_variants_per_persona=args.n_persona_variants,
383
+ batch_size=args.batch_size,
384
+ temperature=args.temperature,
385
+ max_new_tokens=350,
386
+ seed=args.seed,
387
+ )
388
+ all_records.extend(vllm_records)
389
+ save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")
390
+
391
+ # ── Deduplication ──────────────────────────────────────────────────────
392
+ logger.info("Deduplicating %d records...", len(all_records))
393
+ seen_nl: set[str] = set()
394
+ deduped: list[dict] = []
395
+ for rec in all_records:
396
+ nl = rec["prompt"][1]["content"] # user message contains the NL question
397
+ if nl not in seen_nl:
398
+ seen_nl.add(nl)
399
+ deduped.append(rec)
400
+ logger.info("After dedup: %d unique records (removed %d duplicates).",
401
+ len(deduped), len(all_records) - len(deduped))
402
+
403
+ # ── Statistics ─────────────────────────────────────────────────────────
404
+ print_dataset_stats(deduped)
405
+
406
+ # ── Train / Val / Test split ───────────────────────────────────────────
407
+ train, val, test = _train_val_test_split(deduped, seed=args.seed)
408
+ logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))
409
+
410
+ # ── Write outputs ─────────────────────────────────────────────────────
411
+ _write_jsonl(train, output_dir / "train.jsonl")
412
+ _write_jsonl(val, output_dir / "val.jsonl")
413
+ _write_jsonl(test, output_dir / "test.jsonl")
414
+
415
+ if not args.no_parquet:
416
+ _write_parquet(train, output_dir / "train.parquet")
417
+ _write_parquet(val, output_dir / "val.parquet")
418
+ _write_parquet(test, output_dir / "test.parquet")
419
+
420
+ # ── Write dataset card ─────────────────────────────────────────────────
421
+ card = {
422
+ "name": "NL2SQL-Bench Synthetic Training Dataset",
423
+ "version": "1.0",
424
+ "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
425
+ "total_records": len(deduped),
426
+ "splits": {"train": len(train), "val": len(val), "test": len(test)},
427
+ "domains": args.domains,
428
+ "difficulties": args.difficulties,
429
+ "mode": args.mode,
430
+ "seed": args.seed,
431
+ "sql_guarantee": (
432
+ "Every SQL in this dataset was human-authored and execution-validated "
433
+ "against a seeded SQLite database. Zero LLM-generated SQL."
434
+ ),
435
+ }
436
+ with open(output_dir / "dataset_card.json", "w") as f:
437
+ json.dump(card, f, indent=2)
438
+
439
+ logger.info("=== Done! Dataset written to %s ===", output_dir)
440
+
441
+
442
+ if __name__ == "__main__":
443
+ main()
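Deduplication in `main()` keys on the user-turn text of each prompt and keeps the first occurrence. A self-contained sketch of that pass, with the record shape taken from the module docstring:

```python
def dedup_by_user_message(records: list[dict]) -> list[dict]:
    """Keep the first record per unique NL question (prompt[1] is the user turn)."""
    seen: set[str] = set()
    unique: list[dict] = []
    for rec in records:
        nl = rec["prompt"][1]["content"]
        if nl not in seen:
            seen.add(nl)
            unique.append(rec)
    return unique

records = [
    {"prompt": [{"role": "system", "content": "s"},
                {"role": "user", "content": "total sales?"}], "sql": "A"},
    {"prompt": [{"role": "system", "content": "s"},
                {"role": "user", "content": "total sales?"}], "sql": "B"},
    {"prompt": [{"role": "system", "content": "s"},
                {"role": "user", "content": "top customers?"}], "sql": "C"},
]
print([r["sql"] for r in dedup_by_user_message(records)])  # ['A', 'C']
```

First-seen-wins matters here: base/rule records are appended before vLLM records, so when an LLM paraphrase collides with an existing question, the earlier (cheaper) record survives.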
data_factory/run_data_factory.py ADDED
@@ -0,0 +1,260 @@
+"""
+run_data_factory.py
+====================
+Entry point and smoke-test runner for the NL2SQL Data Factory.
+
+Run this FIRST before running the full pipeline to verify:
+  1. All 66 SQL templates execute without errors
+  2. Rule augmentation produces diverse NL variants
+  3. Validators correctly accept/reject queries
+  4. Base pipeline generates well-formed JSONL records
+
+Usage:
+    # Smoke test only (fast, ~10 seconds)
+    python run_data_factory.py --smoke-test
+
+    # Base mode (no GPU, generates all rule-augmented records)
+    python run_data_factory.py --mode base
+
+    # Full mode (H100 required)
+    python run_data_factory.py --mode full --model meta-llama/Meta-Llama-3-70B-Instruct --tensor-parallel 4
+
+    # Preview what the dataset looks like
+    python run_data_factory.py --smoke-test --show-samples 3
+"""
+
+from __future__ import annotations
+
+import argparse
+import json
+import sys
+import textwrap
+from pathlib import Path
+
+# Allow running from project root
+sys.path.insert(0, str(Path(__file__).parent))
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# SMOKE TEST
+# ─────────────────────────────────────────────────────────────────────────────
+
+def run_smoke_test(show_samples: int = 0) -> bool:
+    print("\n" + "=" * 60)
+    print(" NL2SQL DATA FACTORY — SMOKE TEST")
+    print("=" * 60)
+
+    all_passed = True
+
+    # 1. Template validation
+    print("\n[1/4] Validating all SQL templates against seeded data...")
+    from data_factory.templates import ALL_TEMPLATES, template_stats
+    from data_factory.validator import validate_all_templates
+
+    stats = template_stats()
+    result = validate_all_templates(ALL_TEMPLATES)
+
+    print(f" Templates: {stats}")
+    print(f" Validation: {result['passed']}/{result['total']} passed", end="")
+
+    if result["failed"]:
+        print(f" ← {result['failed']} FAILURES:")
+        for f in result["failures"]:
+            print(f" [{f['domain']}] {f['sql']}... → {f['error']}")
+        all_passed = False
+    else:
+        print(" ✓")
+
+    # 2. Rule augmentation
+    print("\n[2/4] Testing rule-based augmentation...")
+    from data_factory.augmentor import augment_nl
+
+    test_nls = [
+        "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
+        "Which medications are prescribed most often? Return medication_name, category, times_prescribed.",
+        "Rank active employees by salary within their department. Return salary_rank.",
+    ]
+    for nl in test_nls:
+        variants = augment_nl(nl, n=3, seed=42)
+        if not variants:
+            print(f" FAIL: No variants generated for: {nl[:50]}")
+            all_passed = False
+        else:
+            print(f" ✓ {len(variants)} variants from: '{nl[:45]}...'")
+            if show_samples > 0:
+                for i, v in enumerate(variants[:show_samples]):
+                    print(f" [{i+1}] {v}")
+
+    # 3. Validator accept/reject
+    print("\n[3/4] Testing SQL validator accept/reject logic...")
+    from data_factory.validator import SQLValidator
+
+    v = SQLValidator("ecommerce")
+    tests = [
+        ("SELECT id, name FROM customers WHERE tier = 'gold'", True, "valid SELECT"),
+        ("INSERT INTO customers VALUES (1,'x','x@x.com','IN','gold','2024-01-01')", False, "rejected INSERT"),
+        ("SELECT nonexistent_col FROM customers", False, "bad column name"),
+        ("", False, "empty string"),
+    ]
+    for sql, expect_pass, label in tests:
+        vr = v.validate(sql)
+        status = "✓" if vr.passed == expect_pass else "✗"
+        print(f" {status} {label}: passed={vr.passed}", end="")
+        if not vr.passed:
+            print(f" (error: {vr.error})", end="")
+        print()
+        if vr.passed != expect_pass:
+            all_passed = False
+    v.close()
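The real `SQLValidator` lives in data_factory/validator.py and is not shown in this diff; a minimal sketch of the accept/reject contract the tests above exercise (SELECT-only plus execution validation against an in-memory SQLite database) might look like this — the one-table schema and class name are invented for illustration:

```python
from __future__ import annotations
import sqlite3
from dataclasses import dataclass

@dataclass
class ValidationResult:
    passed: bool
    error: str | None = None

class MiniSQLValidator:
    """Toy validator: accepts only SELECTs that actually execute."""

    def __init__(self, schema_sql: str) -> None:
        self._conn = sqlite3.connect(":memory:")
        self._conn.executescript(schema_sql)

    def validate(self, sql: str) -> ValidationResult:
        if not sql.strip():
            return ValidationResult(False, "empty query")
        if not sql.lstrip().upper().startswith("SELECT"):
            return ValidationResult(False, "only SELECT statements are allowed")
        try:
            self._conn.execute(sql).fetchall()  # execution-validate
        except sqlite3.Error as exc:
            return ValidationResult(False, str(exc))
        return ValidationResult(True)

    def close(self) -> None:
        self._conn.close()
```

Bad column names are caught "for free" by the execution step: SQLite raises an `OperationalError`, which surfaces as a failed result rather than a crash.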
+
+    # 4. Mini base pipeline (first 5 templates only)
+    print("\n[4/4] Running mini base pipeline (first 5 templates)...")
+    from data_factory.pipeline import run_base_pipeline
+
+    mini_templates = ALL_TEMPLATES[:5]
+    records = run_base_pipeline(mini_templates, n_augmentations=2, seed=42)
+    expected_min = 5  # at least canonical NLs
+    if len(records) < expected_min:
+        print(f" FAIL: Only {len(records)} records (expected ≥{expected_min})")
+        all_passed = False
+    else:
+        print(f" ✓ Generated {len(records)} records from 5 templates")
+
+    # Validate structure
+    required_keys = {"prompt", "sql", "metadata"}
+    for rec in records[:3]:
+        missing = required_keys - rec.keys()
+        if missing:
+            print(f" FAIL: Record missing keys: {missing}")
+            all_passed = False
+            break
+    else:
+        print(" ✓ Record structure validated")
+
+    if show_samples > 0 and records:
+        print("\n --- Sample Record ---")
+        sample = records[0]
+        print(f" Domain: {sample['metadata']['domain']}")
+        print(f" Difficulty: {sample['metadata']['difficulty']}")
+        print(f" Persona: {sample['metadata']['persona']}")
+        print(f" NL: {sample['prompt'][1]['content'].split('QUESTION: ')[-1][:100]}")
+        print(f" SQL: {sample['sql'][:80]}...")
+
+    # Summary
+    print("\n" + "=" * 60)
+    if all_passed:
+        print(" ALL SMOKE TESTS PASSED ✓")
+        print(" Safe to run: python run_data_factory.py --mode base")
+    else:
+        print(" SOME TESTS FAILED ✗ — fix errors before running pipeline")
+    print("=" * 60 + "\n")
+
+    return all_passed
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# INSPECT DATASET
+# ─────────────────────────────────────────────────────────────────────────────
+
+def inspect_dataset(jsonl_path: str, n: int = 5) -> None:
+    """Pretty-print N records from an output JSONL file."""
+    path = Path(jsonl_path)
+    if not path.exists():
+        print(f"File not found: {path}")
+        return
+
+    records = []
+    with open(path, encoding="utf-8") as f:
+        for i, line in enumerate(f):
+            if i >= n:
+                break
+            records.append(json.loads(line))
+
+    print(f"\n{'='*65}")
+    print(f" Showing {len(records)} records from {path.name}")
+    print(f"{'='*65}")
+
+    for i, rec in enumerate(records):
+        nl = rec["prompt"][1]["content"].split("QUESTION:")[-1].strip()
+        sql = rec["sql"]
+        meta = rec["metadata"]
+        print(f"\n[{i+1}] Domain={meta['domain']} | Difficulty={meta['difficulty']} | "
+              f"Persona={meta['persona']} | Source={meta['source']}")
+        print(f" NL: {textwrap.shorten(nl, 90)}")
+        print(f" SQL: {textwrap.shorten(sql, 90)}")
+
+    print()
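`inspect_dataset()` assumes chat-style records with the NL question embedded after a `QUESTION:` marker in the second prompt message. A hypothetical record in that shape (field names taken from the code above; the contents are invented for illustration):

```python
# Hypothetical record matching the structure inspect_dataset() reads.
record = {
    "prompt": [
        {"role": "system", "content": "You are a text-to-SQL assistant."},
        {"role": "user", "content": "SCHEMA: ...\n\nQUESTION: List all gold-tier customers."},
    ],
    "sql": "SELECT id, name FROM customers WHERE tier = 'gold'",
    "metadata": {"domain": "ecommerce", "difficulty": "easy",
                 "persona": "analyst", "source": "rule_augmented"},
}

# Same extraction as inspect_dataset(): everything after the last QUESTION:.
nl = record["prompt"][1]["content"].split("QUESTION:")[-1].strip()
assert nl == "List all gold-tier customers."
```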
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# MAIN
+# ─────────────────────────────────────────────────────────────────────────────
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="NL2SQL Data Factory — entry point.",
+        formatter_class=argparse.RawTextHelpFormatter,
+    )
+    parser.add_argument(
+        "--smoke-test", action="store_true",
+        help="Run smoke test only (validates all templates, no output written).",
+    )
+    parser.add_argument(
+        "--show-samples", type=int, default=0,
+        help="During smoke test, show N sample NL variants and records.",
+    )
+    parser.add_argument(
+        "--inspect", type=str, default=None,
+        help="Path to a JSONL output file to inspect.",
+    )
+    parser.add_argument(
+        "--inspect-n", type=int, default=5,
+        help="Number of records to show when inspecting.",
+    )
+    parser.add_argument(
+        "--mode", choices=["base", "full"], default="base",
+        help=(
+            "base: rule augmentation only, ~450 records, no GPU needed.\n"
+            "full: + vLLM persona variants, 500K+ records, H100 required."
+        ),
+    )
+    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct")
+    parser.add_argument("--tensor-parallel", type=int, default=4)
+    parser.add_argument("--n-rule-augments", type=int, default=5)
+    parser.add_argument("--n-persona-variants", type=int, default=10)
+    parser.add_argument("--batch-size", type=int, default=64)
+    parser.add_argument("--temperature", type=float, default=0.85)
+    parser.add_argument("--output-dir", default="generated_data/output")
+    parser.add_argument("--checkpoint-dir", default="generated_data/checkpoints")
+    parser.add_argument("--seed", type=int, default=42)
+    parser.add_argument("--no-parquet", action="store_true")
+    parser.add_argument("--resume", action="store_true")
+    parser.add_argument(
+        "--domains", nargs="+",
+        choices=["ecommerce", "healthcare", "finance", "hr"],
+        default=["ecommerce", "healthcare", "finance", "hr"],
+    )
+    parser.add_argument(
+        "--difficulties", nargs="+",
+        choices=["easy", "medium", "hard"],
+        default=["easy", "medium", "hard"],
+    )
+
+    args = parser.parse_args()
+
+    if args.smoke_test:
+        ok = run_smoke_test(show_samples=args.show_samples)
+        sys.exit(0 if ok else 1)
+
+    if args.inspect:
+        inspect_dataset(args.inspect, n=args.inspect_n)
+        sys.exit(0)
+
+    # Forward to pipeline
+    from data_factory.pipeline import main as pipeline_main
+    # Re-parse with pipeline's own parser by forwarding sys.argv
+    pipeline_main()
+
+
+if __name__ == "__main__":
+    main()
data_factory/schemas.py ADDED
@@ -0,0 +1,564 @@
+"""
+data_factory/schemas.py
+========================
+SQLite CREATE TABLE statements, compact prompt-ready schema contexts,
+and deterministic seed functions for all four domains.
+Each schema is fully self-contained and has been verified to create
+without errors in SQLite 3.x.
+"""
+
+from __future__ import annotations
+import sqlite3
+import random
+from datetime import date, timedelta
+from typing import Callable
+
+# ─────────────────────────────────────────────────────────────────────────────
+# SQL SCHEMAS
+# ─────────────────────────────────────────────────────────────────────────────
+
+ECOMMERCE_SCHEMA = """
+CREATE TABLE IF NOT EXISTS categories (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL UNIQUE
+);
+
+CREATE TABLE IF NOT EXISTS products (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    category_id INTEGER NOT NULL REFERENCES categories(id),
+    price REAL NOT NULL CHECK(price >= 0),
+    stock_quantity INTEGER NOT NULL DEFAULT 0
+);
+
+CREATE TABLE IF NOT EXISTS customers (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    email TEXT NOT NULL UNIQUE,
+    country TEXT NOT NULL,
+    tier TEXT NOT NULL DEFAULT 'bronze'
+        CHECK(tier IN ('bronze', 'silver', 'gold')),
+    created_at TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS orders (
+    id INTEGER PRIMARY KEY,
+    customer_id INTEGER NOT NULL REFERENCES customers(id),
+    status TEXT NOT NULL DEFAULT 'pending'
+        CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
+    created_at TEXT NOT NULL,
+    total_amount REAL NOT NULL CHECK(total_amount >= 0)
+);
+
+CREATE TABLE IF NOT EXISTS order_items (
+    id INTEGER PRIMARY KEY,
+    order_id INTEGER NOT NULL REFERENCES orders(id),
+    product_id INTEGER NOT NULL REFERENCES products(id),
+    quantity INTEGER NOT NULL CHECK(quantity > 0),
+    unit_price REAL NOT NULL CHECK(unit_price >= 0)
+);
+
+CREATE TABLE IF NOT EXISTS reviews (
+    id INTEGER PRIMARY KEY,
+    product_id INTEGER NOT NULL REFERENCES products(id),
+    customer_id INTEGER NOT NULL REFERENCES customers(id),
+    rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
+    created_at TEXT NOT NULL
+);
+
+CREATE INDEX IF NOT EXISTS idx_products_category ON products(category_id);
+CREATE INDEX IF NOT EXISTS idx_orders_customer ON orders(customer_id);
+CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status);
+CREATE INDEX IF NOT EXISTS idx_orders_created ON orders(created_at);
+CREATE INDEX IF NOT EXISTS idx_order_items_order ON order_items(order_id);
+CREATE INDEX IF NOT EXISTS idx_order_items_product ON order_items(product_id);
+CREATE INDEX IF NOT EXISTS idx_reviews_product ON reviews(product_id);
+CREATE INDEX IF NOT EXISTS idx_customers_tier ON customers(tier);
+"""
+
+HEALTHCARE_SCHEMA = """
+CREATE TABLE IF NOT EXISTS patients (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    date_of_birth TEXT NOT NULL,
+    gender TEXT NOT NULL CHECK(gender IN ('M','F','Other')),
+    blood_type TEXT NOT NULL,
+    country TEXT NOT NULL,
+    registered_at TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS doctors (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    specialization TEXT NOT NULL,
+    department TEXT NOT NULL,
+    experience_years INTEGER NOT NULL CHECK(experience_years >= 0),
+    consultation_fee REAL NOT NULL CHECK(consultation_fee >= 0)
+);
+
+CREATE TABLE IF NOT EXISTS appointments (
+    id INTEGER PRIMARY KEY,
+    patient_id INTEGER NOT NULL REFERENCES patients(id),
+    doctor_id INTEGER NOT NULL REFERENCES doctors(id),
+    scheduled_at TEXT NOT NULL,
+    status TEXT NOT NULL
+        CHECK(status IN ('scheduled','completed','cancelled','no_show')),
+    notes TEXT
+);
+
+CREATE TABLE IF NOT EXISTS diagnoses (
+    id INTEGER PRIMARY KEY,
+    appointment_id INTEGER NOT NULL REFERENCES appointments(id),
+    icd_code TEXT NOT NULL,
+    description TEXT NOT NULL,
+    severity TEXT NOT NULL CHECK(severity IN ('mild','moderate','severe'))
+);
+
+CREATE TABLE IF NOT EXISTS medications (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    category TEXT NOT NULL,
+    unit_price REAL NOT NULL CHECK(unit_price >= 0)
+);
+
+CREATE TABLE IF NOT EXISTS prescriptions (
+    id INTEGER PRIMARY KEY,
+    appointment_id INTEGER NOT NULL REFERENCES appointments(id),
+    medication_id INTEGER NOT NULL REFERENCES medications(id),
+    dosage TEXT NOT NULL,
+    duration_days INTEGER NOT NULL CHECK(duration_days > 0),
+    quantity INTEGER NOT NULL CHECK(quantity > 0)
+);
+
+CREATE INDEX IF NOT EXISTS idx_appt_patient ON appointments(patient_id);
+CREATE INDEX IF NOT EXISTS idx_appt_doctor ON appointments(doctor_id);
+CREATE INDEX IF NOT EXISTS idx_appt_status ON appointments(status);
+CREATE INDEX IF NOT EXISTS idx_diag_appt ON diagnoses(appointment_id);
+CREATE INDEX IF NOT EXISTS idx_presc_appt ON prescriptions(appointment_id);
+CREATE INDEX IF NOT EXISTS idx_presc_med ON prescriptions(medication_id);
+"""
+
+FINANCE_SCHEMA = """
+CREATE TABLE IF NOT EXISTS fin_customers (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    email TEXT NOT NULL UNIQUE,
+    country TEXT NOT NULL,
+    kyc_status TEXT NOT NULL CHECK(kyc_status IN ('pending','verified','rejected')),
+    created_at TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS accounts (
+    id INTEGER PRIMARY KEY,
+    customer_id INTEGER NOT NULL REFERENCES fin_customers(id),
+    account_type TEXT NOT NULL
+        CHECK(account_type IN ('savings','current','fixed_deposit','loan')),
+    balance REAL NOT NULL DEFAULT 0,
+    currency TEXT NOT NULL DEFAULT 'USD',
+    status TEXT NOT NULL CHECK(status IN ('active','dormant','closed')),
+    opened_at TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS transactions (
+    id INTEGER PRIMARY KEY,
+    account_id INTEGER NOT NULL REFERENCES accounts(id),
+    txn_type TEXT NOT NULL CHECK(txn_type IN ('credit','debit')),
+    amount REAL NOT NULL CHECK(amount > 0),
+    currency TEXT NOT NULL DEFAULT 'USD',
+    merchant TEXT,
+    created_at TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS loans (
+    id INTEGER PRIMARY KEY,
+    customer_id INTEGER NOT NULL REFERENCES fin_customers(id),
+    loan_type TEXT NOT NULL
+        CHECK(loan_type IN ('personal','home','auto','business')),
+    principal_amount REAL NOT NULL,
+    interest_rate REAL NOT NULL,
+    tenure_months INTEGER NOT NULL,
+    status TEXT NOT NULL CHECK(status IN ('active','closed','defaulted')),
+    disbursed_at TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS loan_payments (
+    id INTEGER PRIMARY KEY,
+    loan_id INTEGER NOT NULL REFERENCES loans(id),
+    amount_paid REAL NOT NULL CHECK(amount_paid > 0),
+    payment_date TEXT NOT NULL,
+    is_late INTEGER NOT NULL DEFAULT 0 CHECK(is_late IN (0,1))
+);
+
+CREATE INDEX IF NOT EXISTS idx_acct_customer ON accounts(customer_id);
+CREATE INDEX IF NOT EXISTS idx_txn_account ON transactions(account_id);
+CREATE INDEX IF NOT EXISTS idx_txn_type ON transactions(txn_type);
+CREATE INDEX IF NOT EXISTS idx_loan_customer ON loans(customer_id);
+CREATE INDEX IF NOT EXISTS idx_lp_loan ON loan_payments(loan_id);
+"""
+
+HR_SCHEMA = """
+CREATE TABLE IF NOT EXISTS departments (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL UNIQUE,
+    location TEXT NOT NULL,
+    budget REAL NOT NULL CHECK(budget >= 0)
+);
+
+CREATE TABLE IF NOT EXISTS employees (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    email TEXT NOT NULL UNIQUE,
+    department_id INTEGER NOT NULL REFERENCES departments(id),
+    job_title TEXT NOT NULL,
+    hire_date TEXT NOT NULL,
+    salary REAL NOT NULL CHECK(salary >= 0),
+    status TEXT NOT NULL CHECK(status IN ('active','resigned','terminated'))
+);
+
+CREATE TABLE IF NOT EXISTS performance_reviews (
+    id INTEGER PRIMARY KEY,
+    employee_id INTEGER NOT NULL REFERENCES employees(id),
+    review_year INTEGER NOT NULL,
+    rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
+    reviewer_id INTEGER NOT NULL REFERENCES employees(id),
+    comments TEXT
+);
+
+CREATE TABLE IF NOT EXISTS projects (
+    id INTEGER PRIMARY KEY,
+    name TEXT NOT NULL,
+    department_id INTEGER NOT NULL REFERENCES departments(id),
+    start_date TEXT NOT NULL,
+    end_date TEXT,
+    budget REAL NOT NULL,
+    status TEXT NOT NULL
+        CHECK(status IN ('planned','active','completed','cancelled'))
+);
+
+CREATE TABLE IF NOT EXISTS project_assignments (
+    id INTEGER PRIMARY KEY,
+    employee_id INTEGER NOT NULL REFERENCES employees(id),
+    project_id INTEGER NOT NULL REFERENCES projects(id),
+    role TEXT NOT NULL,
+    hours_allocated INTEGER NOT NULL CHECK(hours_allocated > 0)
+);
+
+CREATE INDEX IF NOT EXISTS idx_emp_dept ON employees(department_id);
+CREATE INDEX IF NOT EXISTS idx_emp_status ON employees(status);
+CREATE INDEX IF NOT EXISTS idx_pr_employee ON performance_reviews(employee_id);
+CREATE INDEX IF NOT EXISTS idx_proj_dept ON projects(department_id);
+CREATE INDEX IF NOT EXISTS idx_pa_employee ON project_assignments(employee_id);
+CREATE INDEX IF NOT EXISTS idx_pa_project ON project_assignments(project_id);
+"""
+
+# ─────────────────────────────────────────────────────────────────────────────
+# SCHEMA REGISTRY
+# ─────────────────────────────────────────────────────────────────────────────
+
+SCHEMA_MAP: dict[str, str] = {
+    "ecommerce": ECOMMERCE_SCHEMA,
+    "healthcare": HEALTHCARE_SCHEMA,
+    "finance": FINANCE_SCHEMA,
+    "hr": HR_SCHEMA,
+}
+
+# ─────────────────────────────────────────────────────────────────────────────
+# COMPACT SCHEMA CONTEXT (injected into every training prompt)
+# ─────────────────────────────────────────────────────────────────────────────
+
+SCHEMA_CONTEXT: dict[str, str] = {
+    "ecommerce": """\
+Database: ecommerce (SQLite, read-only)
+
+TABLES
+------
+categories(id INTEGER PK, name TEXT)
+products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id, price REAL, stock_quantity INTEGER)
+customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601)
+orders(id INTEGER PK, customer_id INTEGER FK→customers.id, status TEXT ∈ {pending|processing|shipped|delivered|cancelled}, created_at TEXT ISO-8601, total_amount REAL)
+order_items(id INTEGER PK, order_id INTEGER FK→orders.id, product_id INTEGER FK→products.id, quantity INTEGER, unit_price REAL)
+reviews(id INTEGER PK, product_id INTEGER FK→products.id, customer_id INTEGER FK→customers.id, rating INTEGER 1-5, created_at TEXT ISO-8601)
+
+NOTES
+-----
+- Use created_at >= '2024-01-01' for date filtering (ISO text sort works)
+- SQLite window functions: RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD
+- strftime('%Y-%m', created_at) returns 'YYYY-MM'
+- All monetary values in USD
+""",
+
+    "healthcare": """\
+Database: healthcare (SQLite, read-only)
+
+TABLES
+------
+patients(id INTEGER PK, name TEXT, date_of_birth TEXT ISO-8601, gender TEXT ∈ {M|F|Other}, blood_type TEXT, country TEXT, registered_at TEXT ISO-8601)
+doctors(id INTEGER PK, name TEXT, specialization TEXT, department TEXT, experience_years INTEGER, consultation_fee REAL)
+appointments(id INTEGER PK, patient_id INTEGER FK→patients.id, doctor_id INTEGER FK→doctors.id, scheduled_at TEXT ISO-8601, status TEXT ∈ {scheduled|completed|cancelled|no_show}, notes TEXT nullable)
+diagnoses(id INTEGER PK, appointment_id INTEGER FK→appointments.id, icd_code TEXT, description TEXT, severity TEXT ∈ {mild|moderate|severe})
+medications(id INTEGER PK, name TEXT, category TEXT, unit_price REAL)
+prescriptions(id INTEGER PK, appointment_id INTEGER FK→appointments.id, medication_id INTEGER FK→medications.id, dosage TEXT, duration_days INTEGER, quantity INTEGER)
+
+NOTES
+-----
+- consultation_fee is in USD per visit
+- ICD codes follow WHO ICD-10 format (e.g. 'I10', 'E11')
+- SQLite window functions available
+""",
+
+    "finance": """\
+Database: finance (SQLite, read-only)
+
+TABLES
+------
+fin_customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, kyc_status TEXT ∈ {pending|verified|rejected}, created_at TEXT ISO-8601)
+accounts(id INTEGER PK, customer_id INTEGER FK→fin_customers.id, account_type TEXT ∈ {savings|current|fixed_deposit|loan}, balance REAL, currency TEXT, status TEXT ∈ {active|dormant|closed}, opened_at TEXT ISO-8601)
+transactions(id INTEGER PK, account_id INTEGER FK→accounts.id, txn_type TEXT ∈ {credit|debit}, amount REAL, currency TEXT, merchant TEXT nullable, created_at TEXT ISO-8601)
+loans(id INTEGER PK, customer_id INTEGER FK→fin_customers.id, loan_type TEXT ∈ {personal|home|auto|business}, principal_amount REAL, interest_rate REAL, tenure_months INTEGER, status TEXT ∈ {active|closed|defaulted}, disbursed_at TEXT ISO-8601)
+loan_payments(id INTEGER PK, loan_id INTEGER FK→loans.id, amount_paid REAL, payment_date TEXT ISO-8601, is_late INTEGER ∈ {0|1})
+
+NOTES
+-----
+- All monetary values in USD unless currency column specifies otherwise
+- is_late = 1 means the payment was overdue
+- SQLite window functions available
+""",
+
+    "hr": """\
+Database: hr (SQLite, read-only)
+
+TABLES
+------
+departments(id INTEGER PK, name TEXT, location TEXT, budget REAL)
+employees(id INTEGER PK, name TEXT, email TEXT, department_id INTEGER FK→departments.id, job_title TEXT, hire_date TEXT ISO-8601, salary REAL, status TEXT ∈ {active|resigned|terminated})
+performance_reviews(id INTEGER PK, employee_id INTEGER FK→employees.id, review_year INTEGER, rating INTEGER 1-5, reviewer_id INTEGER FK→employees.id, comments TEXT nullable)
+projects(id INTEGER PK, name TEXT, department_id INTEGER FK→departments.id, start_date TEXT ISO-8601, end_date TEXT nullable, budget REAL, status TEXT ∈ {planned|active|completed|cancelled})
+project_assignments(id INTEGER PK, employee_id INTEGER FK→employees.id, project_id INTEGER FK→projects.id, role TEXT, hours_allocated INTEGER)
+
+NOTES
+-----
+- salary is annual in USD
+- performance rating: 1 (lowest) to 5 (highest)
+- end_date is NULL for ongoing projects
+- SQLite window functions available
+""",
+}
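The NOTES blocks above lean on two SQLite behaviours worth sanity-checking: ISO-8601 text compares lexicographically in date order (so plain string comparison works as a date filter), and `strftime('%Y-%m', …)` buckets a date by month. A quick check:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# ISO-8601 strings sort lexicographically in chronological order,
# so a plain text comparison works as a date filter.
assert conn.execute("SELECT '2024-02-05' >= '2024-01-01'").fetchone()[0] == 1
# strftime('%Y-%m', ...) truncates an ISO date to a monthly bucket.
assert conn.execute("SELECT strftime('%Y-%m', '2024-02-05')").fetchone()[0] == "2024-02"
conn.close()
```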
+
+
+# ─────────────────────────────────────────────────────────────────────────────
+# SEED FUNCTIONS (deterministic, SEED=42)
+# ─────────────────────────────────────────────────────────────────────────────
+
+def _rdate(rng: random.Random, start: str = "2022-01-01", end: str = "2024-12-31") -> str:
+    """Return a random ISO date between start and end (inclusive)."""
+    s = date.fromisoformat(start)
+    e = date.fromisoformat(end)
+    return (s + timedelta(days=rng.randint(0, (e - s).days))).isoformat()
+
+
+def seed_ecommerce(conn: sqlite3.Connection, seed: int = 42) -> None:
+    rng = random.Random(seed)
+    cats = ["Electronics", "Clothing", "Books", "Home & Garden",
+            "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive"]
+    conn.executemany("INSERT INTO categories(id,name) VALUES(?,?)", enumerate(cats, 1))
+
+    products = [
+        (1,"Wireless Headphones",1,149.99,50),(2,"Laptop Stand",1,59.99,120),
+        (3,"USB-C Hub",1,49.99,90),(4,"Webcam 4K",1,89.99,30),
+        (5,"Cotton T-Shirt",2,19.99,200),(6,"Winter Jacket",2,129.99,60),
+        (7,"Running Shorts",2,34.99,150),(8,"Clean Code",3,39.99,80),
+        (9,"Deep Learning Book",3,59.99,45),(10,"Coffee Maker",4,89.99,40),
+        (11,"Air Purifier",4,199.99,25),(12,"Yoga Mat",5,29.99,150),
+        (13,"Resistance Bands",5,14.99,200),(14,"Lego City Set",6,79.99,60),
+        (15,"Face Serum",7,34.99,100),(16,"Dash Cam",8,119.99,35),
+    ]
+    conn.executemany("INSERT INTO products VALUES(?,?,?,?,?)", products)
+
+    countries = ["India","USA","Germany","UK","Canada","Australia","France","Brazil"]
+    tiers = ["bronze","silver","gold"]
+    customers = []
+    for i in range(1, 51):
+        customers.append((i, f"Customer {i}", f"cust{i}@shop.com",
+                          rng.choice(countries), rng.choice(tiers), _rdate(rng)))
+    conn.executemany("INSERT INTO customers VALUES(?,?,?,?,?,?)", customers)
+
+    statuses = ["pending","processing","shipped","delivered","cancelled"]
+    orders = []
+    for i in range(1, 201):
+        orders.append((i, rng.randint(1, 50), rng.choice(statuses),
+                       _rdate(rng), round(rng.uniform(20, 800), 2)))
+    conn.executemany("INSERT INTO orders VALUES(?,?,?,?,?)", orders)
+
+    items = []
+    for i in range(1, 301):
+        items.append((i, rng.randint(1, 200), rng.randint(1, 16),
+                      rng.randint(1, 5), round(rng.uniform(10, 200), 2)))
+    conn.executemany("INSERT INTO order_items VALUES(?,?,?,?,?)", items)
+
+    reviews = []
+    for i in range(1, 151):
+        reviews.append((i, rng.randint(1, 16), rng.randint(1, 50),
+                        rng.randint(1, 5), _rdate(rng)))
+    conn.executemany("INSERT INTO reviews VALUES(?,?,?,?,?)", reviews)
+    conn.commit()
+
+
+def seed_healthcare(conn: sqlite3.Connection, seed: int = 42) -> None:
+    rng = random.Random(seed)
+    specs = [("Cardiology","Cardiology"), ("Neurology","Neurology"),
+             ("Orthopedics","Orthopedics"), ("Dermatology","Dermatology"),
+             ("Pediatrics","Pediatrics"), ("Oncology","Oncology"),
+             ("Endocrinology","Endocrinology"), ("Gastroenterology","Gastroenterology")]
+    for i, (spec, dept) in enumerate(specs, 1):
+        conn.execute("INSERT INTO doctors VALUES(?,?,?,?,?,?)",
+                     (i, f"Dr. {['Smith','Patel','Kim','Müller','Okafor','Chen','Lopez','Roy'][i-1]}",
+                      spec, dept, rng.randint(2, 25), round(rng.uniform(50, 350), 2)))
+
+    genders = ["M", "F", "Other"]
+    blood_types = ["A+","A-","B+","B-","O+","O-","AB+","AB-"]
+    countries = ["India","USA","Germany","UK","Canada","Australia"]
+    for i in range(1, 101):
+        conn.execute("INSERT INTO patients VALUES(?,?,?,?,?,?,?)",
+                     (i, f"Patient {i}", _rdate(rng, "1950-01-01", "2010-01-01"),
+                      rng.choice(genders), rng.choice(blood_types),
+                      rng.choice(countries), _rdate(rng, "2020-01-01", "2024-12-31")))
+
+    appt_statuses = ["scheduled", "completed", "cancelled", "no_show"]
+    weights = [0.15, 0.60, 0.15, 0.10]
+    for i in range(1, 301):
+        conn.execute("INSERT INTO appointments VALUES(?,?,?,?,?,?)",
+                     (i, rng.randint(1, 100), rng.randint(1, 8),
+                      _rdate(rng, "2022-01-01", "2024-12-31"),
+                      rng.choices(appt_statuses, weights)[0], None))
+
+    icd_codes = ["I10","E11","J45","M54","K21","F32","G43","L30","N39","R05",
+                 "C50","Z87","I25","E78","J18"]
+    descs = ["Hypertension","Type 2 Diabetes","Asthma","Back Pain","GERD",
+             "Depression","Migraine","Dermatitis","UTI","Cough",
+             "Breast Cancer","Family History","Coronary Artery Disease",
+             "Hyperlipidemia","Pneumonia"]
+    severities = ["mild","moderate","severe"]
+    for i in range(1, 201):
+        conn.execute("INSERT INTO diagnoses VALUES(?,?,?,?,?)",
+                     (i, rng.randint(1, 300), rng.choice(icd_codes),
+                      rng.choice(descs), rng.choice(severities)))
+
+    meds = [("Metformin","Antidiabetic",0.15),("Lisinopril","Antihypertensive",0.20),
+            ("Atorvastatin","Statin",0.25),("Amoxicillin","Antibiotic",0.30),
+            ("Ibuprofen","NSAID",0.10),("Omeprazole","PPI",0.18),
+            ("Sertraline","Antidepressant",0.35),("Cetirizine","Antihistamine",0.08),
+            ("Paracetamol","Analgesic",0.05),("Aspirin","Antiplatelet",0.07)]
+    for i, (name, cat, price) in enumerate(meds, 1):
+        conn.execute("INSERT INTO medications VALUES(?,?,?,?)", (i, name, cat, price))
+
+    dosages = ["1x daily","2x daily","3x daily","once at night","as needed"]
+    for i in range(1, 251):
+        conn.execute("INSERT INTO prescriptions VALUES(?,?,?,?,?,?)",
+                     (i, rng.randint(1, 300), rng.randint(1, 10),
+                      rng.choice(dosages), rng.randint(5, 60), rng.randint(10, 90)))
+    conn.commit()
+
+
+def seed_finance(conn: sqlite3.Connection, seed: int = 42) -> None:
+    rng = random.Random(seed)
+    countries = ["India","USA","Germany","UK","Singapore","UAE","Canada"]
+    kyc = ["pending","verified","verified","verified","rejected"]
+    for i in range(1, 51):
+        conn.execute("INSERT INTO fin_customers VALUES(?,?,?,?,?,?)",
+                     (i, f"FinClient {i}", f"fincli{i}@bank.com",
+                      rng.choice(countries), rng.choice(kyc), _rdate(rng)))
+
+    acct_types = ["savings","savings","current","fixed_deposit"]
+    statuses = ["active","active","active","dormant","closed"]
+    for i in range(1, 101):
+        conn.execute("INSERT INTO accounts VALUES(?,?,?,?,?,?,?)",
+                     (i, rng.randint(1, 50), rng.choice(acct_types),
+                      round(rng.uniform(100, 100000), 2), "USD",
+                      rng.choice(statuses), _rdate(rng)))
+
+    merchants = [None, "Amazon", "Walmart", "Netflix", "Uber", "Apple",
+                 "Google Pay", "Zomato", "Flipkart", "Airbnb"]
+    for i in range(1, 501):
+        conn.execute("INSERT INTO transactions VALUES(?,?,?,?,?,?,?)",
+                     (i, rng.randint(1, 100), rng.choice(["credit","debit"]),
+                      round(rng.uniform(5, 10000), 2), "USD",
+                      rng.choice(merchants), _rdate(rng)))
+
+    loan_types = ["personal","home","auto","business"]
+    loan_statuses = ["active","active","closed","defaulted"]
+    for i in range(1, 51):
+        conn.execute("INSERT INTO loans VALUES(?,?,?,?,?,?,?,?)",
+                     (i, rng.randint(1, 50), rng.choice(loan_types),
+                      round(rng.uniform(5000, 500000), 2),
+                      round(rng.uniform(5, 18), 2), rng.randint(12, 360),
+                      rng.choice(loan_statuses), _rdate(rng)))
+
+    for i in range(1, 201):
+        conn.execute("INSERT INTO loan_payments VALUES(?,?,?,?,?)",
+                     (i, rng.randint(1, 50), round(rng.uniform(500, 10000), 2),
+                      _rdate(rng), rng.randint(0, 1)))
+    conn.commit()
+
+
+def seed_hr(conn: sqlite3.Connection, seed: int = 42) -> None:
+    rng = random.Random(seed)
+    depts = [("Engineering","Bangalore",8000000),("Marketing","Mumbai",3000000),
+             ("Finance","Delhi",2000000),("HR","Chennai",1500000),
+             ("Sales","Hyderabad",5000000),("Product","Pune",4000000),
+             ("Legal","Delhi",1000000),("Operations","Kolkata",2500000)]
+    for i, (name, loc, bud) in enumerate(depts, 1):
+        conn.execute("INSERT INTO departments VALUES(?,?,?,?)", (i, name, loc, bud))
+
+    titles = ["Software Engineer","Senior Engineer","Staff Engineer","Principal Engineer",
+              "Engineering Manager","Product Manager","Data Analyst","Data Scientist",
+              "Marketing Specialist","Sales Executive","HR Specialist","Finance Analyst",
+              "Director","VP","Legal Counsel"]
+    statuses = ["active","active","active","active","resigned","terminated"]
+    for i in range(1, 101):
+        conn.execute("INSERT INTO employees VALUES(?,?,?,?,?,?,?,?)",
517
+ (i, f"Employee {i}", f"emp{i}@corp.com",
518
+ rng.randint(1, 8), rng.choice(titles),
519
+ _rdate(rng, "2015-01-01", "2024-01-01"),
520
+ round(rng.uniform(25000, 200000), 2), rng.choice(statuses)))
521
+
522
+ for i in range(1, 201):
523
+ conn.execute("INSERT INTO performance_reviews VALUES(?,?,?,?,?,?)",
524
+ (i, rng.randint(1, 100), rng.randint(2019, 2024),
525
+ rng.randint(1, 5), rng.randint(1, 100),
526
+ rng.choice(["Excellent work","Good performance","Needs improvement",
527
+ "Outstanding","Meeting expectations"])))
528
+
529
+ proj_statuses = ["planned","active","active","completed","cancelled"]
530
+ for i in range(1, 51):
531
+ sd = _rdate(rng, "2021-01-01", "2024-01-01")
532
+ conn.execute("INSERT INTO projects VALUES(?,?,?,?,?,?,?)",
533
+ (i, f"Project {i}", rng.randint(1, 8), sd,
534
+ _rdate(rng, sd, "2025-06-01") if rng.random() > 0.25 else None,
535
+ round(rng.uniform(50000, 2000000), 2), rng.choice(proj_statuses)))
536
+
537
+ roles = ["Lead","Senior Developer","Developer","Tester","Analyst","DevOps"]
538
+ for i in range(1, 251):
539
+ conn.execute("INSERT INTO project_assignments VALUES(?,?,?,?,?)",
540
+ (i, rng.randint(1, 100), rng.randint(1, 50),
541
+ rng.choice(roles), rng.randint(20, 400)))
542
+ conn.commit()
543
+
544
+
545
+ # ─────────────────────────────────────────────────────────────────────────────
546
+ # REGISTRY
547
+ # ─────────────────────────────────────────────────────────────────────────────
548
+
549
+ SEED_MAP: dict[str, Callable] = {
550
+ "ecommerce": seed_ecommerce,
551
+ "healthcare": seed_healthcare,
552
+ "finance": seed_finance,
553
+ "hr": seed_hr,
554
+ }
555
+
556
+
557
+ def build_connection(domain: str, seed: int = 42) -> sqlite3.Connection:
558
+ """Return a seeded in-memory SQLite connection for the given domain."""
559
+ conn = sqlite3.connect(":memory:", check_same_thread=False)
560
+ conn.row_factory = sqlite3.Row
561
+ conn.execute("PRAGMA foreign_keys = ON")
562
+ conn.executescript(SCHEMA_MAP[domain])
563
+ SEED_MAP[domain](conn, seed=seed)
564
+ return conn
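
`build_connection` exists so that every episode sees the same data for a given `(domain, seed)` pair. The sketch below reproduces that pattern standalone so it can actually run here: the `pets` table, `seed_pets` helper, and single-domain `build_connection` are illustrative stand-ins for the module's `SCHEMA_MAP`/`SEED_MAP`, not part of the repo.

```python
import random
import sqlite3

# Illustrative stand-in for one SCHEMA_MAP entry.
SCHEMA = "CREATE TABLE pets (id INTEGER PRIMARY KEY, name TEXT, age INTEGER)"

def seed_pets(conn: sqlite3.Connection, seed: int = 42) -> None:
    # Same shape as the seed_* functions above: a seeded RNG drives inserts.
    rng = random.Random(seed)
    for i in range(1, 11):
        conn.execute("INSERT INTO pets VALUES(?,?,?)",
                     (i, f"Pet {i}", rng.randint(1, 15)))
    conn.commit()

def build_connection(seed: int = 42) -> sqlite3.Connection:
    # Mirrors the module's build_connection: in-memory DB, Row factory, FK pragma.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    conn.row_factory = sqlite3.Row
    conn.execute("PRAGMA foreign_keys = ON")
    conn.executescript(SCHEMA)
    seed_pets(conn, seed=seed)
    return conn

# Identical seeds yield identical data — the property the benchmark relies on.
a = [dict(r) for r in build_connection(7).execute("SELECT * FROM pets")]
b = [dict(r) for r in build_connection(7).execute("SELECT * FROM pets")]
assert a == b and len(a) == 10
```

Because each connection is `:memory:`, episodes cannot leak state into one another; reproducibility comes entirely from the seed.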
data_factory/templates.py ADDED
@@ -0,0 +1,993 @@
+ """
+ data_factory/templates.py
+ ==========================
+ Human-authored, execution-verified SQL templates across 4 domains × 3 difficulty tiers.
+
+ CRITICAL DESIGN PRINCIPLE:
+     SQL is NEVER generated by an LLM in this pipeline.
+     Every SQL statement here was written by hand and verified by running it
+     against seeded SQLite data. Zero errors guaranteed.
+
+ Structure per entry:
+     {
+         "domain": str,        # ecommerce | healthcare | finance | hr
+         "difficulty": str,    # easy | medium | hard
+         "sql": str,           # verified ground-truth SQL
+         "description": str,   # one-line English summary (seed for NL generation)
+         "base_nl": str,       # canonical natural-language question
+         "has_order": bool,    # True → comparison is order-sensitive
+     }
+ """
+
+ from __future__ import annotations
+ from typing import TypedDict
+
+
+ class Template(TypedDict):
+     domain: str
+     difficulty: str
+     sql: str
+     description: str
+     base_nl: str
+     has_order: bool
+
+
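
The "execution-verified" claim in the docstring is a checkable property: every entry's `sql` must run cleanly against the seeded database. A minimal standalone sketch of that check, using two made-up entries in the `Template` shape and a tiny inline `products` table (none of which come from the real template lists):

```python
import sqlite3

# Two illustrative entries in the Template shape (hypothetical data).
TEMPLATES = [
    {"domain": "demo", "difficulty": "easy", "has_order": True,
     "description": "Products over 10, price descending",
     "base_nl": "List products priced above 10, highest first.",
     "sql": "SELECT id, name, price FROM products WHERE price > 10 ORDER BY price DESC"},
    {"domain": "demo", "difficulty": "easy", "has_order": False,
     "description": "Count products",
     "base_nl": "How many products are there?",
     "sql": "SELECT COUNT(*) AS n FROM products"},
]

conn = sqlite3.connect(":memory:")
conn.executescript(
    "CREATE TABLE products (id INTEGER PRIMARY KEY, name TEXT, price REAL);"
    "INSERT INTO products VALUES (1,'pen',2.5),(2,'lamp',30.0),(3,'desk',120.0);"
)

# Execution-verify: every template's SQL must run without raising,
# and we keep the rows so graders can compare against model output.
results = {t["description"]: conn.execute(t["sql"]).fetchall() for t in TEMPLATES}
```

Running this loop over the real `SEED_MAP` databases at import time (or in CI) is what makes "zero errors guaranteed" enforceable rather than aspirational.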
+ # ─────────────────────────────────────────────────────────────────────────────
+ # DOMAIN: ECOMMERCE
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ ECOMMERCE_TEMPLATES: list[Template] = [
+
+     # ── EASY ────────────────────────────────────────────────────────────────
+
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "List gold-tier customers sorted alphabetically with id, name, email, country",
+         "base_nl": "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
+         "sql": (
+             "SELECT id, name, email, country "
+             "FROM customers "
+             "WHERE tier = 'gold' "
+             "ORDER BY name ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Products priced above $100, sorted by price descending",
+         "base_nl": "Show all products with a price above $100, sorted from highest to lowest price. Return id, name, price.",
+         "sql": (
+             "SELECT id, name, price "
+             "FROM products "
+             "WHERE price > 100 "
+             "ORDER BY price DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Delivered orders with total_amount > 200, sorted by amount descending",
+         "base_nl": "Find all delivered orders with a total amount greater than $200, sorted by total amount descending. Return id, customer_id, total_amount, created_at.",
+         "sql": (
+             "SELECT id, customer_id, total_amount, created_at "
+             "FROM orders "
+             "WHERE status = 'delivered' "
+             "  AND total_amount > 200 "
+             "ORDER BY total_amount DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Top 5 most expensive products",
+         "base_nl": "Return the top 5 most expensive products. Return id, name, price.",
+         "sql": (
+             "SELECT id, name, price "
+             "FROM products "
+             "ORDER BY price DESC "
+             "LIMIT 5"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Distinct countries where customers come from, sorted alphabetically",
+         "base_nl": "List all distinct countries our customers come from, sorted alphabetically. Return country.",
+         "sql": (
+             "SELECT DISTINCT country "
+             "FROM customers "
+             "ORDER BY country ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": False,
+         "description": "Count total number of customers",
+         "base_nl": "How many customers do we have in total? Return a single column total_customers.",
+         "sql": "SELECT COUNT(*) AS total_customers FROM customers",
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Products with zero stock",
+         "base_nl": "List all out-of-stock products. Return id, name, stock_quantity.",
+         "sql": (
+             "SELECT id, name, stock_quantity "
+             "FROM products "
+             "WHERE stock_quantity = 0 "
+             "ORDER BY name ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Customers from India sorted by name",
+         "base_nl": "Show all customers from India, sorted by name. Return id, name, email.",
+         "sql": (
+             "SELECT id, name, email "
+             "FROM customers "
+             "WHERE country = 'India' "
+             "ORDER BY name ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": True,
+         "description": "Products in a price range of $20 to $100 sorted by price ascending",
+         "base_nl": "Which products are priced between $20 and $100? Sort by price ascending. Return id, name, price.",
+         "sql": (
+             "SELECT id, name, price "
+             "FROM products "
+             "WHERE price BETWEEN 20 AND 100 "
+             "ORDER BY price ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "easy", "has_order": False,
+         "description": "Count orders by status",
+         "base_nl": "How many orders are there for each status? Return status, order_count.",
+         "sql": (
+             "SELECT status, COUNT(*) AS order_count "
+             "FROM orders "
+             "GROUP BY status"
+         ),
+     },
+
+     # ── MEDIUM ───────────────────────────────────────────────────────────────
+
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Order count per customer including those with zero orders, sorted by count desc",
+         "base_nl": "How many orders has each customer placed? Include customers with zero orders. Return customer_name, order_count, sorted by order_count descending then customer_name ascending.",
+         "sql": (
+             "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
+             "FROM customers c "
+             "LEFT JOIN orders o ON c.id = o.customer_id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY order_count DESC, customer_name ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Average product rating per category sorted descending",
+         "base_nl": "What is the average product rating per category? Only include categories with at least one review. Return category_name, avg_rating (rounded to 2 decimal places), sorted by avg_rating descending.",
+         "sql": (
+             "SELECT c.name AS category_name, "
+             "       ROUND(AVG(r.rating), 2) AS avg_rating "
+             "FROM categories c "
+             "JOIN products p ON p.category_id = c.id "
+             "JOIN reviews r ON r.product_id = p.id "
+             "GROUP BY c.id, c.name "
+             "ORDER BY avg_rating DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Customers who spent more than $500 on delivered orders",
+         "base_nl": "Which customers have spent more than $500 total on delivered orders? Return customer_name, total_spent (rounded to 2 decimal places), sorted by total_spent descending.",
+         "sql": (
+             "SELECT c.name AS customer_name, "
+             "       ROUND(SUM(o.total_amount), 2) AS total_spent "
+             "FROM customers c "
+             "JOIN orders o ON o.customer_id = c.id "
+             "WHERE o.status = 'delivered' "
+             "GROUP BY c.id, c.name "
+             "HAVING SUM(o.total_amount) > 500 "
+             "ORDER BY total_spent DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Total quantity sold per product sorted descending",
+         "base_nl": "Show the total quantity sold for each product that appears in at least one order. Return product_name, total_quantity_sold, sorted by total_quantity_sold descending.",
+         "sql": (
+             "SELECT p.name AS product_name, "
+             "       SUM(oi.quantity) AS total_quantity_sold "
+             "FROM products p "
+             "JOIN order_items oi ON oi.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "ORDER BY total_quantity_sold DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Product count and average price per category sorted by count desc",
+         "base_nl": "For each category, show the number of products and their average price. Return category_name, product_count, avg_price (rounded to 2 decimal places), sorted by product_count descending.",
+         "sql": (
+             "SELECT cat.name AS category_name, "
+             "       COUNT(p.id) AS product_count, "
+             "       ROUND(AVG(p.price), 2) AS avg_price "
+             "FROM categories cat "
+             "JOIN products p ON p.category_id = cat.id "
+             "GROUP BY cat.id, cat.name "
+             "ORDER BY product_count DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Categories with more than 5 in-stock products sorted by count desc",
+         "base_nl": "Which categories have more than 5 products in stock (stock_quantity > 0)? Return category_name, in_stock_count, sorted by in_stock_count descending.",
+         "sql": (
+             "SELECT c.name AS category_name, "
+             "       COUNT(p.id) AS in_stock_count "
+             "FROM categories c "
+             "JOIN products p ON p.category_id = c.id "
+             "WHERE p.stock_quantity > 0 "
+             "GROUP BY c.id, c.name "
+             "HAVING COUNT(p.id) > 5 "
+             "ORDER BY in_stock_count DESC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "medium", "has_order": True,
+         "description": "Total revenue per product from order items, sorted descending",
+         "base_nl": "What is the total revenue generated by each product from order items? Return product_name, total_revenue (rounded to 2 decimal places), sorted by total_revenue descending.",
+         "sql": (
+             "SELECT p.name AS product_name, "
+             "       ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue "
+             "FROM products p "
+             "JOIN order_items oi ON oi.product_id = p.id "
+             "GROUP BY p.id, p.name "
+             "ORDER BY total_revenue DESC"
+         ),
+     },
+
+     # ── HARD ─────────────────────────────────────────────────────────────────
+
+     {
+         "domain": "ecommerce", "difficulty": "hard", "has_order": True,
+         "description": "Customer spending rank using DENSE_RANK on delivered orders",
+         "base_nl": "Rank customers by total spending on delivered orders using DENSE_RANK (rank 1 = highest spender). Return customer_name, total_spent (rounded to 2 decimal places), spending_rank, sorted by spending_rank ascending.",
+         "sql": (
+             "SELECT customer_name, total_spent, spending_rank "
+             "FROM ( "
+             "    SELECT c.name AS customer_name, "
+             "           ROUND(SUM(o.total_amount), 2) AS total_spent, "
+             "           DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
+             "    FROM customers c "
+             "    JOIN orders o ON o.customer_id = c.id "
+             "    WHERE o.status = 'delivered' "
+             "    GROUP BY c.id, c.name "
+             ") sub "
+             "ORDER BY spending_rank ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "hard", "has_order": True,
+         "description": "Monthly delivered revenue with running total using window SUM",
+         "base_nl": "Show the monthly revenue from delivered orders and its running cumulative total. Return month (YYYY-MM), monthly_revenue, running_total (both rounded to 2 decimal places), sorted by month ascending.",
+         "sql": (
+             "WITH monthly AS ( "
+             "    SELECT strftime('%Y-%m', created_at) AS month, "
+             "           ROUND(SUM(total_amount), 2) AS monthly_revenue "
+             "    FROM orders "
+             "    WHERE status = 'delivered' "
+             "    GROUP BY strftime('%Y-%m', created_at) "
+             ") "
+             "SELECT month, "
+             "       monthly_revenue, "
+             "       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
+             "FROM monthly "
+             "ORDER BY month ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "hard", "has_order": True,
+         "description": "Customers whose most recent order was cancelled, using ROW_NUMBER CTE",
+         "base_nl": "Find all customers whose most recent order has status 'cancelled'. Use ROW_NUMBER to identify the latest order per customer. Return customer_name, last_order_status, last_order_date, sorted by customer_name ascending.",
+         "sql": (
+             "WITH ranked_orders AS ( "
+             "    SELECT customer_id, status, created_at, "
+             "           ROW_NUMBER() OVER (PARTITION BY customer_id "
+             "                              ORDER BY created_at DESC) AS rn "
+             "    FROM orders "
+             ") "
+             "SELECT c.name AS customer_name, "
+             "       ro.status AS last_order_status, "
+             "       ro.created_at AS last_order_date "
+             "FROM customers c "
+             "JOIN ranked_orders ro ON ro.customer_id = c.id "
+             "WHERE ro.rn = 1 "
+             "  AND ro.status = 'cancelled' "
+             "ORDER BY customer_name ASC"
+         ),
+     },
+     {
+         "domain": "ecommerce", "difficulty": "hard", "has_order": True,
+         "description": "Products above their category average rating, using two CTEs",
+         "base_nl": "Find products whose average rating is strictly above the average rating of all products in their category. Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 decimal places), sorted by product_avg_rating descending then product_name ascending.",
+         "sql": (
+             "WITH product_ratings AS ( "
+             "    SELECT p.id AS product_id, p.name AS product_name, "
+             "           p.category_id, c.name AS category_name, "
+             "           ROUND(AVG(r.rating), 2) AS product_avg_rating "
+             "    FROM products p "
+             "    JOIN reviews r ON r.product_id = p.id "
+             "    JOIN categories c ON c.id = p.category_id "
+             "    GROUP BY p.id, p.name, p.category_id, c.name "
+             "), "
+             "category_ratings AS ( "
+             "    SELECT category_id, "
+             "           ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
+             "    FROM product_ratings "
+             "    GROUP BY category_id "
+             ") "
+             "SELECT pr.product_name, pr.category_name, "
+             "       pr.product_avg_rating, cr.category_avg_rating "
+             "FROM product_ratings pr "
+             "JOIN category_ratings cr ON cr.category_id = pr.category_id "
+             "WHERE pr.product_avg_rating > cr.category_avg_rating "
+             "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
+         ),
+     },
+ ]
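
The `has_order` flag on each template above determines how a predicted result set should be compared against the ground truth: templates with `ORDER BY` require positional equality, while unordered aggregates only require the same multiset of rows. A minimal sketch of a comparison helper built on that flag (this is an assumption about how a grader would consume the flag, not the benchmark's actual comparison code):

```python
def rows_match(expected: list[tuple], actual: list[tuple], has_order: bool) -> bool:
    """Compare two SQL result sets.

    When has_order is True, row order is part of the answer and the lists
    must match positionally; otherwise rows are compared as a multiset.
    """
    if has_order:
        return expected == actual
    # Sorting both sides gives a cheap multiset comparison for hashable rows.
    return sorted(expected) == sorted(actual)

# Same rows, different order:
a = [(1, "x"), (2, "y")]
b = [(2, "y"), (1, "x")]
```

With these inputs, `rows_match(a, b, has_order=False)` accepts the reordering while `rows_match(a, b, has_order=True)` rejects it, which is exactly the distinction the flag encodes.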
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # DOMAIN: HEALTHCARE
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ HEALTHCARE_TEMPLATES: list[Template] = [
+
+     # ── EASY ────────────────────────────────────────────────────────────────
+
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": True,
+         "description": "Doctors sorted by consultation fee descending",
+         "base_nl": "List all doctors sorted by consultation fee from highest to lowest. Return id, name, specialization, consultation_fee.",
+         "sql": (
+             "SELECT id, name, specialization, consultation_fee "
+             "FROM doctors "
+             "ORDER BY consultation_fee DESC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": True,
+         "description": "Doctors with more than 10 years experience sorted desc",
+         "base_nl": "Show doctors with more than 10 years of experience, sorted by experience descending. Return id, name, specialization, experience_years.",
+         "sql": (
+             "SELECT id, name, specialization, experience_years "
+             "FROM doctors "
+             "WHERE experience_years > 10 "
+             "ORDER BY experience_years DESC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": True,
+         "description": "Patients from India sorted by name",
+         "base_nl": "List all patients from India sorted alphabetically by name. Return id, name, country, blood_type.",
+         "sql": (
+             "SELECT id, name, country, blood_type "
+             "FROM patients "
+             "WHERE country = 'India' "
+             "ORDER BY name ASC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": True,
+         "description": "Medications with unit price under $0.20 sorted ascending",
+         "base_nl": "Which medications cost less than $0.20 per unit? Sort by price ascending. Return id, name, category, unit_price.",
+         "sql": (
+             "SELECT id, name, category, unit_price "
+             "FROM medications "
+             "WHERE unit_price < 0.20 "
+             "ORDER BY unit_price ASC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": True,
+         "description": "Top 5 most expensive medications",
+         "base_nl": "What are the top 5 most expensive medications? Return id, name, unit_price.",
+         "sql": (
+             "SELECT id, name, unit_price "
+             "FROM medications "
+             "ORDER BY unit_price DESC "
+             "LIMIT 5"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": False,
+         "description": "Count of completed appointments",
+         "base_nl": "How many appointments have been completed? Return a single value total_completed.",
+         "sql": (
+             "SELECT COUNT(*) AS total_completed "
+             "FROM appointments "
+             "WHERE status = 'completed'"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": True,
+         "description": "Severe diagnoses sorted by ICD code",
+         "base_nl": "List all severe diagnoses sorted by ICD code. Return id, icd_code, description, severity.",
+         "sql": (
+             "SELECT id, icd_code, description, severity "
+             "FROM diagnoses "
+             "WHERE severity = 'severe' "
+             "ORDER BY icd_code ASC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "easy", "has_order": False,
+         "description": "Count patients by gender",
+         "base_nl": "How many patients are there by gender? Return gender, patient_count.",
+         "sql": (
+             "SELECT gender, COUNT(*) AS patient_count "
+             "FROM patients "
+             "GROUP BY gender"
+         ),
+     },
+
+     # ── MEDIUM ───────────────────────────────────────────────────────────────
+
+     {
+         "domain": "healthcare", "difficulty": "medium", "has_order": True,
+         "description": "Appointment count per doctor including those with no appointments",
+         "base_nl": "How many appointments has each doctor had (including those with none)? Return doctor_name, appointment_count, sorted by appointment_count descending.",
+         "sql": (
+             "SELECT d.name AS doctor_name, COUNT(a.id) AS appointment_count "
+             "FROM doctors d "
+             "LEFT JOIN appointments a ON a.doctor_id = d.id "
+             "GROUP BY d.id, d.name "
+             "ORDER BY appointment_count DESC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "medium", "has_order": True,
+         "description": "Most prescribed medications by count",
+         "base_nl": "Which medications are prescribed most often? Return medication_name, category, times_prescribed, sorted by times_prescribed descending.",
+         "sql": (
+             "SELECT m.name AS medication_name, m.category, COUNT(p.id) AS times_prescribed "
+             "FROM medications m "
+             "JOIN prescriptions p ON p.medication_id = m.id "
+             "GROUP BY m.id, m.name, m.category "
+             "ORDER BY times_prescribed DESC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "medium", "has_order": True,
+         "description": "Patients with more than one completed visit",
+         "base_nl": "Which patients have had more than one completed appointment? Return patient_name, visit_count, sorted by visit_count descending.",
+         "sql": (
+             "SELECT pat.name AS patient_name, COUNT(DISTINCT a.id) AS visit_count "
+             "FROM patients pat "
+             "JOIN appointments a ON a.patient_id = pat.id "
+             "WHERE a.status = 'completed' "
+             "GROUP BY pat.id, pat.name "
+             "HAVING COUNT(DISTINCT a.id) > 1 "
+             "ORDER BY visit_count DESC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "medium", "has_order": True,
+         "description": "Estimated revenue per doctor from completed appointments",
+         "base_nl": "What is the estimated total revenue per doctor from completed appointments (based on consultation fee)? Return doctor_name, specialization, estimated_revenue (rounded to 2 decimal places), sorted by estimated_revenue descending.",
+         "sql": (
+             "SELECT d.name AS doctor_name, d.specialization, "
+             "       ROUND(SUM(d.consultation_fee), 2) AS estimated_revenue "
+             "FROM doctors d "
+             "JOIN appointments a ON a.doctor_id = d.id "
+             "WHERE a.status = 'completed' "
+             "GROUP BY d.id, d.name, d.specialization "
+             "ORDER BY estimated_revenue DESC"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "medium", "has_order": True,
+         "description": "Diagnosis count per severity level",
+         "base_nl": "How many diagnoses are there per severity level? Return severity, diagnosis_count, sorted by diagnosis_count descending.",
+         "sql": (
+             "SELECT severity, COUNT(*) AS diagnosis_count "
+             "FROM diagnoses "
+             "GROUP BY severity "
+             "ORDER BY diagnosis_count DESC"
+         ),
+     },
+
+     # ── HARD ─────────────────────────────────────────────────────────────────
+
+     {
+         "domain": "healthcare", "difficulty": "hard", "has_order": True,
+         "description": "Doctors ranked by appointment count within specialization using RANK",
+         "base_nl": "Rank doctors by appointment count within their specialization (rank 1 = most appointments). Return doctor_name, specialization, appointment_count, rank_in_spec, sorted by specialization then rank_in_spec ascending.",
+         "sql": (
+             "SELECT doctor_name, specialization, appointment_count, "
+             "       RANK() OVER (PARTITION BY specialization ORDER BY appointment_count DESC) AS rank_in_spec "
+             "FROM ( "
+             "    SELECT d.name AS doctor_name, d.specialization, COUNT(a.id) AS appointment_count "
+             "    FROM doctors d "
+             "    JOIN appointments a ON a.doctor_id = d.id "
+             "    GROUP BY d.id, d.name, d.specialization "
+             ") sub "
+             "ORDER BY specialization, rank_in_spec"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "hard", "has_order": True,
+         "description": "Top 10 patients by total completed visits using CTE",
+         "base_nl": "Find the top 10 patients by number of completed appointments. Return patient_name, total_visits, last_visit, sorted by total_visits descending.",
+         "sql": (
+             "WITH patient_visits AS ( "
+             "    SELECT a.patient_id, COUNT(a.id) AS total_visits, "
+             "           MAX(a.scheduled_at) AS last_visit "
+             "    FROM appointments a "
+             "    WHERE a.status = 'completed' "
+             "    GROUP BY a.patient_id "
+             ") "
+             "SELECT p.name AS patient_name, pv.total_visits, pv.last_visit "
+             "FROM patients p "
+             "JOIN patient_visits pv ON pv.patient_id = p.id "
+             "ORDER BY pv.total_visits DESC "
+             "LIMIT 10"
+         ),
+     },
+     {
+         "domain": "healthcare", "difficulty": "hard", "has_order": True,
+         "description": "Medications total prescription cost per category using window SUM",
+         "base_nl": "For each medication, show its total prescription cost (unit_price × quantity) and the running total of cost within its category. Return medication_name, category, total_cost, category_running_cost (both rounded to 2 decimal places), sorted by category then total_cost descending.",
+         "sql": (
+             "WITH med_costs AS ( "
+             "    SELECT m.name AS medication_name, m.category, "
+             "           ROUND(SUM(m.unit_price * pr.quantity), 2) AS total_cost "
+             "    FROM medications m "
+             "    JOIN prescriptions pr ON pr.medication_id = m.id "
+             "    GROUP BY m.id, m.name, m.category "
+             ") "
+             "SELECT medication_name, category, total_cost, "
+             "       ROUND(SUM(total_cost) OVER (PARTITION BY category ORDER BY total_cost DESC), 2) "
+             "           AS category_running_cost "
+             "FROM med_costs "
+             "ORDER BY category, total_cost DESC"
+         ),
+     },
+ ]
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # DOMAIN: FINANCE
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ FINANCE_TEMPLATES: list[Template] = [
+
+     # ── EASY ────────────────────────────────────────────────────────────────
+
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": True,
+         "description": "Verified KYC customers sorted by name",
+         "base_nl": "List all customers with verified KYC status, sorted alphabetically. Return id, name, country, kyc_status.",
+         "sql": (
+             "SELECT id, name, country, kyc_status "
+             "FROM fin_customers "
+             "WHERE kyc_status = 'verified' "
+             "ORDER BY name ASC"
+         ),
+     },
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": True,
+         "description": "Accounts with balance over $10,000 sorted by balance descending",
+         "base_nl": "Which accounts have a balance greater than $10,000? Return id, customer_id, account_type, balance, sorted by balance descending.",
+         "sql": (
+             "SELECT id, customer_id, account_type, balance "
+             "FROM accounts "
+             "WHERE balance > 10000 "
+             "ORDER BY balance DESC"
+         ),
+     },
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": True,
+         "description": "Large credit transactions above $1,000 sorted by amount descending",
+         "base_nl": "Show all credit transactions with an amount greater than $1,000. Return id, account_id, txn_type, amount, created_at, sorted by amount descending.",
+         "sql": (
+             "SELECT id, account_id, txn_type, amount, created_at "
+             "FROM transactions "
+             "WHERE txn_type = 'credit' AND amount > 1000 "
+             "ORDER BY amount DESC"
+         ),
+     },
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": True,
+         "description": "Defaulted loans sorted by principal amount descending",
+         "base_nl": "List all defaulted loans, sorted by principal amount descending. Return id, loan_type, principal_amount, interest_rate, status.",
+         "sql": (
+             "SELECT id, loan_type, principal_amount, interest_rate, status "
+             "FROM loans "
+             "WHERE status = 'defaulted' "
+             "ORDER BY principal_amount DESC"
+         ),
+     },
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": False,
+         "description": "Count of late loan payments",
+         "base_nl": "How many loan payments were made late? Return a single value late_payments.",
+         "sql": "SELECT COUNT(*) AS late_payments FROM loan_payments WHERE is_late = 1",
+     },
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": True,
+         "description": "Top 5 highest principal loans",
+         "base_nl": "What are the top 5 loans by principal amount? Return id, customer_id, loan_type, principal_amount.",
+         "sql": (
+             "SELECT id, customer_id, loan_type, principal_amount "
+             "FROM loans "
+             "ORDER BY principal_amount DESC "
+             "LIMIT 5"
+         ),
+     },
+     {
+         "domain": "finance", "difficulty": "easy", "has_order": False,
+         "description": "Count of accounts by account type",
+         "base_nl": "How many accounts exist for each account type? Return account_type, account_count.",
+         "sql": (
+             "SELECT account_type, COUNT(*) AS account_count "
+             "FROM accounts "
+             "GROUP BY account_type"
+         ),
+     },
+
+     # ── MEDIUM ───────────────────────────────────────────────────────────────
+
+     {
+         "domain": "finance", "difficulty": "medium", "has_order": True,
+         "description": "Total active account balance per customer sorted by balance descending",
+         "base_nl": "What is the total active account balance per customer? Return customer_name, account_count, total_balance (rounded to 2 decimal places), sorted by total_balance descending.",
643
+ "sql": (
644
+ "SELECT fc.name AS customer_name, COUNT(a.id) AS account_count, "
645
+ " ROUND(SUM(a.balance), 2) AS total_balance "
646
+ "FROM fin_customers fc "
647
+ "JOIN accounts a ON a.customer_id = fc.id "
648
+ "WHERE a.status = 'active' "
649
+ "GROUP BY fc.id, fc.name "
650
+ "ORDER BY total_balance DESC"
651
+ ),
652
+ },
653
+ {
654
+ "domain": "finance", "difficulty": "medium", "has_order": True,
655
+ "description": "Total credit transaction amount by account type",
656
+ "base_nl": "What is the total credit amount per account type? Return account_type, total_credits (rounded to 2 decimal places), sorted by total_credits descending.",
657
+ "sql": (
658
+ "SELECT a.account_type, ROUND(SUM(t.amount), 2) AS total_credits "
659
+ "FROM accounts a "
660
+ "JOIN transactions t ON t.account_id = a.id "
661
+ "WHERE t.txn_type = 'credit' "
662
+ "GROUP BY a.account_type "
663
+ "ORDER BY total_credits DESC"
664
+ ),
665
+ },
666
+ {
667
+ "domain": "finance", "difficulty": "medium", "has_order": True,
668
+ "description": "Total loan borrowing per customer sorted descending",
669
+ "base_nl": "How much has each customer borrowed in total across all loans? Return customer_name, loan_count, total_borrowed (rounded to 2 decimal places), sorted by total_borrowed descending.",
670
+ "sql": (
671
+ "SELECT fc.name AS customer_name, COUNT(l.id) AS loan_count, "
672
+ " ROUND(SUM(l.principal_amount), 2) AS total_borrowed "
673
+ "FROM fin_customers fc "
674
+ "JOIN loans l ON l.customer_id = fc.id "
675
+ "GROUP BY fc.id, fc.name "
676
+ "ORDER BY total_borrowed DESC"
677
+ ),
678
+ },
679
+ {
680
+ "domain": "finance", "difficulty": "medium", "has_order": True,
681
+ "description": "Late payment count and total amount by loan type",
682
+ "base_nl": "For each loan type, how many late payments were there and what was the total amount paid late? Return loan_type, late_payments, total_late_paid (rounded to 2 decimal places), sorted by late_payments descending.",
683
+ "sql": (
684
+ "SELECT l.loan_type, COUNT(lp.id) AS late_payments, "
685
+ " ROUND(SUM(lp.amount_paid), 2) AS total_late_paid "
686
+ "FROM loans l "
687
+ "JOIN loan_payments lp ON lp.loan_id = l.id "
688
+ "WHERE lp.is_late = 1 "
689
+ "GROUP BY l.loan_type "
690
+ "ORDER BY late_payments DESC"
691
+ ),
692
+ },
693
+
694
+ # ── HARD ─────────────────────────────────────────────────────────────────
695
+
696
+ {
697
+ "domain": "finance", "difficulty": "hard", "has_order": True,
698
+ "description": "Customer balance rank using DENSE_RANK on active accounts",
699
+ "base_nl": "Rank customers by their total active account balance using DENSE_RANK. Return customer_name, total_balance, balance_rank, sorted by balance_rank ascending.",
700
+ "sql": (
701
+ "SELECT customer_name, total_balance, "
702
+ " DENSE_RANK() OVER (ORDER BY total_balance DESC) AS balance_rank "
703
+ "FROM ( "
704
+ " SELECT fc.name AS customer_name, "
705
+ " ROUND(SUM(a.balance), 2) AS total_balance "
706
+ " FROM fin_customers fc "
707
+ " JOIN accounts a ON a.customer_id = fc.id "
708
+ " WHERE a.status = 'active' "
709
+ " GROUP BY fc.id, fc.name "
710
+ ") sub "
711
+ "ORDER BY balance_rank"
712
+ ),
713
+ },
714
+ {
715
+ "domain": "finance", "difficulty": "hard", "has_order": True,
716
+ "description": "Monthly transaction totals by type with running total using window SUM",
717
+ "base_nl": "Show monthly transaction totals per type (credit/debit) with a running cumulative total. Return month (YYYY-MM), txn_type, total, running_total (rounded to 2 decimal places), sorted by month then txn_type.",
718
+ "sql": (
719
+ "WITH monthly_txn AS ( "
720
+ " SELECT strftime('%Y-%m', created_at) AS month, "
721
+ " txn_type, "
722
+ " ROUND(SUM(amount), 2) AS total "
723
+ " FROM transactions "
724
+ " GROUP BY strftime('%Y-%m', created_at), txn_type "
725
+ ") "
726
+ "SELECT month, txn_type, total, "
727
+ " ROUND(SUM(total) OVER (PARTITION BY txn_type ORDER BY month), 2) AS running_total "
728
+ "FROM monthly_txn "
729
+ "ORDER BY month, txn_type"
730
+ ),
731
+ },
732
+ {
733
+ "domain": "finance", "difficulty": "hard", "has_order": True,
734
+ "description": "Customers with only defaulted loans using NOT EXISTS",
735
+ "base_nl": "Find customers who have at least one loan and ALL their loans are defaulted. Return customer_name, loan_count, sorted by customer_name ascending.",
736
+ "sql": (
737
+ "SELECT fc.name AS customer_name, COUNT(l.id) AS loan_count "
738
+ "FROM fin_customers fc "
739
+ "JOIN loans l ON l.customer_id = fc.id "
740
+ "GROUP BY fc.id, fc.name "
741
+ "HAVING COUNT(l.id) > 0 "
742
+ " AND SUM(CASE WHEN l.status != 'defaulted' THEN 1 ELSE 0 END) = 0 "
743
+ "ORDER BY customer_name ASC"
744
+ ),
745
+ },
746
+ ]
747
+
748
+
749
+ # ─────────────────────────────────────────────────────────────────────────────
750
+ # DOMAIN: HR
751
+ # ─────────────────────────────────────────────────────────────────────────────
752
+
753
+ HR_TEMPLATES: list[Template] = [
754
+
755
+ # ── EASY ────────────────────────────────────────────────────────────────
756
+
757
+ {
758
+ "domain": "hr", "difficulty": "easy", "has_order": True,
759
+ "description": "Active employees sorted by salary descending",
760
+ "base_nl": "List all active employees sorted by salary from highest to lowest. Return id, name, job_title, salary.",
761
+ "sql": (
762
+ "SELECT id, name, job_title, salary "
763
+ "FROM employees "
764
+ "WHERE status = 'active' "
765
+ "ORDER BY salary DESC"
766
+ ),
767
+ },
768
+ {
769
+ "domain": "hr", "difficulty": "easy", "has_order": True,
770
+ "description": "Departments sorted by budget descending",
771
+ "base_nl": "Show all departments sorted by budget from largest to smallest. Return id, name, location, budget.",
772
+ "sql": (
773
+ "SELECT id, name, location, budget "
774
+ "FROM departments "
775
+ "ORDER BY budget DESC"
776
+ ),
777
+ },
778
+ {
779
+ "domain": "hr", "difficulty": "easy", "has_order": True,
780
+ "description": "Employees hired in 2023 or later sorted by hire date descending",
781
+ "base_nl": "Which employees were hired on or after January 1st 2023? Sort by hire date descending. Return id, name, job_title, hire_date.",
782
+ "sql": (
783
+ "SELECT id, name, job_title, hire_date "
784
+ "FROM employees "
785
+ "WHERE hire_date >= '2023-01-01' "
786
+ "ORDER BY hire_date DESC"
787
+ ),
788
+ },
789
+ {
790
+ "domain": "hr", "difficulty": "easy", "has_order": True,
791
+ "description": "Active projects sorted by budget descending",
792
+ "base_nl": "Show all currently active projects sorted by budget descending. Return id, name, status, budget.",
793
+ "sql": (
794
+ "SELECT id, name, status, budget "
795
+ "FROM projects "
796
+ "WHERE status = 'active' "
797
+ "ORDER BY budget DESC"
798
+ ),
799
+ },
800
+ {
801
+ "domain": "hr", "difficulty": "easy", "has_order": True,
802
+ "description": "Active employees earning above $100,000 sorted by salary descending",
803
+ "base_nl": "Which active employees earn more than $100,000? Return id, name, email, job_title, sorted by salary descending.",
804
+ "sql": (
805
+ "SELECT id, name, email, job_title "
806
+ "FROM employees "
807
+ "WHERE status = 'active' AND salary > 100000 "
808
+ "ORDER BY salary DESC"
809
+ ),
810
+ },
811
+ {
812
+ "domain": "hr", "difficulty": "easy", "has_order": False,
813
+ "description": "Count of active employees",
814
+ "base_nl": "How many active employees do we currently have? Return active_employees.",
815
+ "sql": "SELECT COUNT(*) AS active_employees FROM employees WHERE status = 'active'",
816
+ },
817
+ {
818
+ "domain": "hr", "difficulty": "easy", "has_order": True,
819
+ "description": "Projects with no end date (ongoing) sorted by budget descending",
820
+ "base_nl": "List all ongoing projects that have no end date set. Return id, name, start_date, budget, sorted by budget descending.",
821
+ "sql": (
822
+ "SELECT id, name, start_date, budget "
823
+ "FROM projects "
824
+ "WHERE end_date IS NULL "
825
+ "ORDER BY budget DESC"
826
+ ),
827
+ },
828
+
829
+ # ── MEDIUM ───────────────────────────────────────────────────────────────
830
+
831
+ {
832
+ "domain": "hr", "difficulty": "medium", "has_order": True,
833
+ "description": "Headcount and average salary per department for active employees",
834
+ "base_nl": "For each department, what is the headcount and average salary of active employees? Return department_name, headcount, avg_salary (rounded to 2 decimal places), sorted by headcount descending.",
835
+ "sql": (
836
+ "SELECT d.name AS department_name, COUNT(e.id) AS headcount, "
837
+ " ROUND(AVG(e.salary), 2) AS avg_salary "
838
+ "FROM departments d "
839
+ "LEFT JOIN employees e ON e.department_id = d.id AND e.status = 'active' "
840
+ "GROUP BY d.id, d.name "
841
+ "ORDER BY headcount DESC"
842
+ ),
843
+ },
844
+ {
845
+ "domain": "hr", "difficulty": "medium", "has_order": True,
846
+ "description": "Average performance rating per employee sorted descending",
847
+ "base_nl": "What is the average performance review rating per active employee? Return employee_name, job_title, avg_rating (rounded to 2 decimal places), sorted by avg_rating descending.",
848
+ "sql": (
849
+ "SELECT e.name AS employee_name, e.job_title, "
850
+ " ROUND(AVG(pr.rating), 2) AS avg_rating "
851
+ "FROM employees e "
852
+ "JOIN performance_reviews pr ON pr.employee_id = e.id "
853
+ "WHERE e.status = 'active' "
854
+ "GROUP BY e.id, e.name, e.job_title "
855
+ "ORDER BY avg_rating DESC"
856
+ ),
857
+ },
858
+ {
859
+ "domain": "hr", "difficulty": "medium", "has_order": True,
860
+ "description": "Employees with the most total allocated project hours",
861
+ "base_nl": "Which employees have the most total hours allocated across projects? Return employee_name, total_hours, sorted by total_hours descending, top 10.",
862
+ "sql": (
863
+ "SELECT e.name AS employee_name, SUM(pa.hours_allocated) AS total_hours "
864
+ "FROM employees e "
865
+ "JOIN project_assignments pa ON pa.employee_id = e.id "
866
+ "GROUP BY e.id, e.name "
867
+ "ORDER BY total_hours DESC "
868
+ "LIMIT 10"
869
+ ),
870
+ },
871
+ {
872
+ "domain": "hr", "difficulty": "medium", "has_order": True,
873
+ "description": "Departments with distinct employees assigned to active projects",
874
+ "base_nl": "For each department, how many distinct employees are assigned to active projects? Return department_name, assigned_employees, sorted by assigned_employees descending.",
875
+ "sql": (
876
+ "SELECT d.name AS department_name, "
877
+ " COUNT(DISTINCT pa.employee_id) AS assigned_employees "
878
+ "FROM departments d "
879
+ "JOIN projects p ON p.department_id = d.id "
880
+ "JOIN project_assignments pa ON pa.project_id = p.id "
881
+ "WHERE p.status = 'active' "
882
+ "GROUP BY d.id, d.name "
883
+ "ORDER BY assigned_employees DESC"
884
+ ),
885
+ },
886
+ {
887
+ "domain": "hr", "difficulty": "medium", "has_order": True,
888
+ "description": "Total project budget per department sorted descending",
889
+ "base_nl": "What is the total project budget per department? Return department_name, total_project_budget (rounded to 2 decimal places), sorted by total_project_budget descending.",
890
+ "sql": (
891
+ "SELECT d.name AS department_name, "
892
+ " ROUND(SUM(p.budget), 2) AS total_project_budget "
893
+ "FROM departments d "
894
+ "JOIN projects p ON p.department_id = d.id "
895
+ "GROUP BY d.id, d.name "
896
+ "ORDER BY total_project_budget DESC"
897
+ ),
898
+ },
899
+
900
+ # ── HARD ─────────────────────────────────────────────────────────────────
901
+
902
+ {
903
+ "domain": "hr", "difficulty": "hard", "has_order": True,
904
+ "description": "Salary rank within department using DENSE_RANK",
905
+ "base_nl": "Rank active employees by salary within their department using DENSE_RANK (rank 1 = highest paid). Return employee_name, salary, department_name, salary_rank, sorted by department_name then salary_rank ascending.",
906
+ "sql": (
907
+ "SELECT employee_name, salary, department_name, "
908
+ " DENSE_RANK() OVER (PARTITION BY department_name ORDER BY salary DESC) AS salary_rank "
909
+ "FROM ( "
910
+ " SELECT e.name AS employee_name, e.salary, d.name AS department_name "
911
+ " FROM employees e "
912
+ " JOIN departments d ON d.id = e.department_id "
913
+ " WHERE e.status = 'active' "
914
+ ") sub "
915
+ "ORDER BY department_name, salary_rank"
916
+ ),
917
+ },
918
+ {
919
+ "domain": "hr", "difficulty": "hard", "has_order": True,
920
+ "description": "Employee performance band classification using CASE with avg rating CTE",
921
+ "base_nl": "Classify active employees into performance bands (High Performer: avg rating >= 4, Average: >= 3, Needs Improvement: < 3) based on their average review rating. Return employee_name, salary, avg_rating, performance_band, sorted by avg_rating descending.",
922
+ "sql": (
923
+ "WITH avg_ratings AS ( "
924
+ " SELECT employee_id, ROUND(AVG(rating), 2) AS avg_rating "
925
+ " FROM performance_reviews "
926
+ " GROUP BY employee_id "
927
+ ") "
928
+ "SELECT e.name AS employee_name, e.salary, ar.avg_rating, "
929
+ " CASE WHEN ar.avg_rating >= 4 THEN 'High Performer' "
930
+ " WHEN ar.avg_rating >= 3 THEN 'Average' "
931
+ " ELSE 'Needs Improvement' "
932
+ " END AS performance_band "
933
+ "FROM employees e "
934
+ "JOIN avg_ratings ar ON ar.employee_id = e.id "
935
+ "WHERE e.status = 'active' "
936
+ "ORDER BY ar.avg_rating DESC"
937
+ ),
938
+ },
939
+ {
940
+ "domain": "hr", "difficulty": "hard", "has_order": True,
941
+ "description": "Employees above their department average salary using CTE",
942
+ "base_nl": "Find active employees whose salary is above their department's average. Return employee_name, department_name, salary, dept_avg_salary (rounded to 2 decimal places), sorted by salary descending.",
943
+ "sql": (
944
+ "WITH dept_avg AS ( "
945
+ " SELECT department_id, ROUND(AVG(salary), 2) AS dept_avg_salary "
946
+ " FROM employees "
947
+ " WHERE status = 'active' "
948
+ " GROUP BY department_id "
949
+ ") "
950
+ "SELECT e.name AS employee_name, d.name AS department_name, "
951
+ " e.salary, da.dept_avg_salary "
952
+ "FROM employees e "
953
+ "JOIN departments d ON d.id = e.department_id "
954
+ "JOIN dept_avg da ON da.department_id = e.department_id "
955
+ "WHERE e.status = 'active' AND e.salary > da.dept_avg_salary "
956
+ "ORDER BY e.salary DESC"
957
+ ),
958
+ },
959
+ ]
960
+
961
+
962
+ # ─────────────────────────────────────────────────────────────────────────────
963
+ # MASTER TEMPLATE REGISTRY
964
+ # ─────────────────────────────────────────────────────────────────────────────
965
+
966
+ ALL_TEMPLATES: list[Template] = (
967
+ ECOMMERCE_TEMPLATES +
968
+ HEALTHCARE_TEMPLATES +
969
+ FINANCE_TEMPLATES +
970
+ HR_TEMPLATES
971
+ )
972
+
973
+ TEMPLATES_BY_DOMAIN: dict[str, list[Template]] = {
974
+ "ecommerce": ECOMMERCE_TEMPLATES,
975
+ "healthcare": HEALTHCARE_TEMPLATES,
976
+ "finance": FINANCE_TEMPLATES,
977
+ "hr": HR_TEMPLATES,
978
+ }
979
+
980
+ TEMPLATES_BY_DIFFICULTY: dict[str, list[Template]] = {
981
+ "easy": [t for t in ALL_TEMPLATES if t["difficulty"] == "easy"],
982
+ "medium": [t for t in ALL_TEMPLATES if t["difficulty"] == "medium"],
983
+ "hard": [t for t in ALL_TEMPLATES if t["difficulty"] == "hard"],
984
+ }
985
+
986
+
987
+ def template_stats() -> dict:
988
+ stats: dict = {"total": len(ALL_TEMPLATES), "by_domain": {}, "by_difficulty": {}}
989
+ for d in ["ecommerce","healthcare","finance","hr"]:
990
+ stats["by_domain"][d] = len(TEMPLATES_BY_DOMAIN[d])
991
+ for diff in ["easy","medium","hard"]:
992
+ stats["by_difficulty"][diff] = len(TEMPLATES_BY_DIFFICULTY[diff])
993
+ return stats
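The registry bookkeeping above can be exercised in isolation. This is a minimal sketch using hypothetical three-entry stand-in lists rather than the real `ECOMMERCE_`/`HEALTHCARE_`/`FINANCE_`/`HR_TEMPLATES`, just to show the shape `template_stats()` returns:

```python
# Stand-in template dicts (hypothetical data, not the real libraries).
ALL_TEMPLATES = [
    {"domain": "finance", "difficulty": "easy"},
    {"domain": "finance", "difficulty": "hard"},
    {"domain": "hr", "difficulty": "easy"},
]

# Same derivation pattern as TEMPLATES_BY_DIFFICULTY above.
TEMPLATES_BY_DIFFICULTY = {
    diff: [t for t in ALL_TEMPLATES if t["difficulty"] == diff]
    for diff in ("easy", "medium", "hard")
}

def template_stats() -> dict:
    stats = {"total": len(ALL_TEMPLATES), "by_domain": {}, "by_difficulty": {}}
    for d in ("finance", "hr"):
        stats["by_domain"][d] = sum(1 for t in ALL_TEMPLATES if t["domain"] == d)
    for diff in ("easy", "medium", "hard"):
        stats["by_difficulty"][diff] = len(TEMPLATES_BY_DIFFICULTY[diff])
    return stats

print(template_stats())
# → {'total': 3, 'by_domain': {'finance': 2, 'hr': 1}, 'by_difficulty': {'easy': 2, 'medium': 0, 'hard': 1}}
```

A stats dict of this shape is handy in CI for asserting that each domain/difficulty bucket stays non-empty as templates are edited.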
data_factory/validator.py ADDED
@@ -0,0 +1,221 @@
+ """
+ data_factory/validator.py
+ ==========================
+ SQL execution validation layer.
+
+ GUARANTEE: Every record that passes this validator has a SQL query that:
+ 1. Runs without error against the actual seeded SQLite schema
+ 2. Returns at least one row (non-empty result)
+ 3. Returns the expected column names
+
+ No LLM-generated SQL ever reaches this validator — SQL always comes from
+ the human-verified template library. This validator is an extra safety net
+ to catch any copy-paste or formatting regressions.
+ """
+
+ from __future__ import annotations
+
+ import sqlite3
+ from dataclasses import dataclass, field
+ from typing import Any, Optional
+
+ from data_factory.schemas import build_connection, SCHEMA_CONTEXT
+ from data_factory.templates import Template
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # DATA CLASSES
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ @dataclass
+ class ValidationResult:
+     passed: bool
+     sql: str
+     error: Optional[str] = None
+     row_count: int = 0
+     columns: list[str] = field(default_factory=list)
+
+
+ @dataclass
+ class DataRecord:
+     """One training example ready to be written to JSONL/Parquet."""
+     domain: str
+     difficulty: str
+     sql: str
+     nl_question: str    # The NL paraphrase used as prompt
+     persona: str        # ceo | chatty | lazy_typist | non_techie | analyst | augmented
+     has_order: bool
+     schema_context: str
+     row_count: int      # From validation run
+     columns: list[str]  # From validation run
+     source: str         # "template_base" | "vllm_persona" | "rule_augmented"
+     template_id: int    # Index into ALL_TEMPLATES
+
+     def to_training_dict(self) -> dict[str, Any]:
+         """
+         Returns the dictionary that will be written to the output dataset.
+
+         Format is compatible with TRL / HuggingFace `datasets`:
+             prompt  : chat-format messages list (system + user)
+             sql     : ground-truth SQL (label / reward reference)
+             metadata: auxiliary fields for curriculum or filtering
+         """
+         system_msg = (
+             "You are an expert SQL analyst. "
+             "Write a single SELECT query that answers the question. "
+             "Output ONLY the SQL query — no markdown, no explanation, no backticks."
+         )
+         user_msg = (
+             f"DATABASE SCHEMA\n"
+             f"---------------\n"
+             f"{self.schema_context}\n\n"
+             f"QUESTION: {self.nl_question}"
+         )
+         return {
+             "prompt": [
+                 {"role": "system", "content": system_msg},
+                 {"role": "user", "content": user_msg},
+             ],
+             "sql": self.sql,
+             "metadata": {
+                 "domain": self.domain,
+                 "difficulty": self.difficulty,
+                 "persona": self.persona,
+                 "has_order": self.has_order,
+                 "row_count": self.row_count,
+                 "columns": self.columns,
+                 "source": self.source,
+                 "template_id": self.template_id,
+             },
+         }
+
+
+ # ─────────────────────────────────────────────────────────────────────────────
+ # VALIDATOR
+ # ─────────────────────────────────────────────────────────────────────────────
+
+ class SQLValidator:
+     """
+     Validates SQL against a seeded in-memory SQLite connection.
+
+     One validator per domain to reuse the same connection for all templates
+     in that domain (performance optimization).
+     """
+
+     def __init__(self, domain: str, seed: int = 42) -> None:
+         self.domain = domain
+         self._conn = build_connection(domain, seed=seed)
+
+     def validate(self, sql: str) -> ValidationResult:
+         """
+         Execute SQL and return a ValidationResult.
+         Never raises — always returns a result object.
+         """
+         sql = sql.strip().rstrip(";")
+         if not sql:
+             return ValidationResult(passed=False, sql=sql, error="Empty SQL string.")
+
+         # Block any write operations
+         first_word = sql.split()[0].lower() if sql.split() else ""
+         forbidden = {"insert", "update", "delete", "drop", "alter", "create", "replace", "truncate", "pragma"}
+         if first_word in forbidden:
+             return ValidationResult(
+                 passed=False, sql=sql,
+                 error=f"Write operation '{first_word.upper()}' is not permitted."
+             )
+
+         try:
+             cur = self._conn.execute(sql)
+             cols = [d[0] for d in cur.description] if cur.description else []
+             rows = cur.fetchall()
+             return ValidationResult(
+                 passed=True,
+                 sql=sql,
+                 row_count=len(rows),
+                 columns=cols,
+             )
+         except sqlite3.Error as exc:
+             return ValidationResult(passed=False, sql=sql, error=str(exc))
+
+     def close(self) -> None:
+         self._conn.close()
+
+
+ def validate_template(template: Template, seed: int = 42) -> ValidationResult:
+     """Convenience function: validate a single template."""
+     v = SQLValidator(template["domain"], seed=seed)
+     result = v.validate(template["sql"])
+     v.close()
+     return result
+
+
+ def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]:
+     """
+     Run validation across all templates. Returns a summary dict.
+     Used during CI / smoke testing.
+     """
+     from data_factory.schemas import SCHEMA_MAP
+
+     validators = {domain: SQLValidator(domain, seed) for domain in SCHEMA_MAP}
+     passed = []
+     failed = []
+
+     for i, t in enumerate(templates):
+         v = validators[t["domain"]]
+         result = v.validate(t["sql"])
+         if result.passed:
+             passed.append(i)
+         else:
+             failed.append({"index": i, "domain": t["domain"],
+                            "sql": t["sql"][:80], "error": result.error})
+
+     for v in validators.values():
+         v.close()
+
+     return {
+         "total": len(templates),
+         "passed": len(passed),
+         "failed": len(failed),
+         "failures": failed,
+     }
+
+
+ def build_record(
+     template: Template,
+     template_idx: int,
+     nl_question: str,
+     persona: str,
+     source: str,
+     validator: SQLValidator,
+ ) -> Optional[DataRecord]:
+     """
+     Validate the template SQL and, if it passes, build a DataRecord.
+
+     Parameters
+     ----------
+     template     : The source template (contains SQL, domain, difficulty).
+     template_idx : Index of template in ALL_TEMPLATES (for deduplication).
+     nl_question  : The NL paraphrase to use as the prompt.
+     persona      : Which persona/strategy generated this NL.
+     source       : 'template_base' | 'vllm_persona' | 'rule_augmented'
+     validator    : Pre-built SQLValidator for this domain.
+
+     Returns None if validation fails.
+     """
+     vr = validator.validate(template["sql"])
+     if not vr.passed:
+         return None
+
+     return DataRecord(
+         domain=template["domain"],
+         difficulty=template["difficulty"],
+         sql=template["sql"],
+         nl_question=nl_question,
+         persona=persona,
+         has_order=template["has_order"],
+         schema_context=SCHEMA_CONTEXT[template["domain"]],
+         row_count=vr.row_count,
+         columns=vr.columns,
+         source=source,
+         template_id=template_idx,
+     )
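The read-only gate and execute/inspect flow in `SQLValidator.validate()` can be seen end to end against a throwaway database. This is a hedged, self-contained sketch: it mirrors the validator's logic with plain dicts and a tiny hand-seeded table, since `build_connection()` and the real seeded schemas are project-specific.

```python
import sqlite3

# Throwaway in-memory DB standing in for the project's seeded schema.
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE employees (id INTEGER, name TEXT, salary REAL, status TEXT)")
conn.executemany(
    "INSERT INTO employees VALUES (?, ?, ?, ?)",
    [(1, "Ada", 120000, "active"), (2, "Lin", 90000, "inactive")],
)

def validate(sql: str) -> dict:
    """Mirror of SQLValidator.validate(): never raises, always returns a result."""
    sql = sql.strip().rstrip(";")
    first = sql.split()[0].lower() if sql.split() else ""
    if first in {"insert", "update", "delete", "drop", "alter", "create"}:
        return {"passed": False, "error": f"{first.upper()} not permitted"}
    try:
        cur = conn.execute(sql)
        # Column names come from cursor.description before fetching rows.
        cols = [d[0] for d in cur.description] if cur.description else []
        return {"passed": True, "row_count": len(cur.fetchall()), "columns": cols}
    except sqlite3.Error as exc:
        return {"passed": False, "error": str(exc)}

print(validate("SELECT name FROM employees WHERE status = 'active'"))
# → {'passed': True, 'row_count': 1, 'columns': ['name']}
print(validate("DROP TABLE employees"))
# → {'passed': False, 'error': 'DROP not permitted'}
```

Note the first-keyword check is a coarse filter, not a parser; the try/except around `conn.execute` is what ultimately catches anything malformed.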
env_server ADDED
@@ -0,0 +1,33 @@
+ import os
+ import sys
+ from fastapi import FastAPI, Request
+ import uvicorn
+
+ sys.path.insert(0, "./server")
+ from environment import NL2SQLEnvironment
+ from models import NL2SQLAction
+
+ app = FastAPI()
+ env = NL2SQLEnvironment()
+
+ @app.post("/reset")
+ async def reset(request: Request):
+     data = await request.json()
+     # Now we take task_name directly from the API call
+     task_name = data.get("task_name", "simple-filter")
+     print(f"🔄 Environment Resetting for Task: {task_name}")
+     obs = env.reset(task_name=task_name)
+     return {"observation": obs.__dict__}
+
+ @app.post("/step")
+ async def step(request: Request):
+     data = await request.json()
+     query = data.get("query", "")
+     print(f"⏩ Executing SQL: {query[:60]}...")
+
+     action = NL2SQLAction(query=query)
+     obs = env.step(action)
+     return {"observation": obs.__dict__}
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=8000)
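From the client side, the two endpoints above each take a small JSON body. A hedged sketch of the request construction follows; `BASE_URL` is a placeholder, and the payload keys simply mirror what the handlers read via `data.get(...)`:

```python
import json

BASE_URL = "http://localhost:8000"  # placeholder; substitute your server address

def reset_request(task_name: str = "simple-filter") -> tuple[str, bytes]:
    """Build the (url, body) pair that POST /reset expects."""
    return f"{BASE_URL}/reset", json.dumps({"task_name": task_name}).encode()

def step_request(query: str) -> tuple[str, bytes]:
    """Build the (url, body) pair that POST /step expects."""
    return f"{BASE_URL}/step", json.dumps({"query": query}).encode()

url, body = reset_request("simple-filter")
print(url)               # → http://localhost:8000/reset
print(json.loads(body))  # → {'task_name': 'simple-filter'}
```

An actual call would pass these through `urllib.request.urlopen` or `requests.post(url, data=body, headers={"Content-Type": "application/json"})` and read back `{"observation": ...}`.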
folder.txt ADDED
@@ -0,0 +1,95 @@
1
+ .
2
+ ├── check_quality.py
3
+ ├── clean_dataset.py
4
+ ├── client.py
5
+ ├── custom_train.py
6
+ ├── data_expander.py
7
+ ├── data_factory
8
+ │   ├── augmentor.py
9
+ │   ├── config.py
10
+ │   ├── generate_data.py
11
+ │   ├── generator.py
12
+ │   ├── __init__.py
13
+ │   ├── pipeline.py
14
+ │   ├── run_data_factory.py
15
+ │   ├── schemas.py
16
+ │   ├── templates.py
17
+ │   └── validator.py
18
+ ├── Dockerfile
19
+ ├── edge_cases.jsonl
20
+ ├── env_server
21
+ ├── folder.txt
22
+ ├── generate_data.py
23
+ ├── generate_edge_cases.py
24
+ ├── inference.py
25
+ ├── __init__.py
26
+ ├── llm_hybrid_templates.json
27
+ ├── local_test.py
28
+ ├── merge_model.py
29
+ ├── mini_server.py
30
+ ├── models.py
31
+ ├── nl2sql_50k_elite_dataset_1.jsonl
32
+ ├── nl2sql_50k_elite_dataset.jsonl
33
+ ├── nl2sql_cleaned_ready_to_train.jsonl
34
+ ├── nl2sql_merged_final.jsonl
35
+ ├── openenv.yaml
36
+ ├── pyproject.toml
37
+ ├── qwen-7b-coder-nl2sql-grpo
38
+ │   ├── checkpoint-70
39
+ │   │   ├── adapter_config.json
40
+ │   │   ├── adapter_model.safetensors
41
+ │   │   ├── chat_template.jinja
42
+ │   │   ├── optimizer.pt
43
+ │   │   ├── README.md
44
+ │   │   ├── rng_state_0.pth
45
+ │   │   ├── rng_state_1.pth
46
+ │   │   ├── scheduler.pt
47
+ │   │   ├── tokenizer_config.json
48
+ │   │   ├── tokenizer.json
49
+ │   │   ├── trainer_state.json
50
+ │   │   └── training_args.bin
51
+ │   ├── final
52
+ │   │   ├── adapter_config.json
53
+ │   │   ├── adapter_model.safetensors
54
+ │   │   ├── chat_template.jinja
55
+ │   │   ├── README.md
56
+ │   │   ├── tokenizer_config.json
57
+ │   │   └── tokenizer.json
58
+ │   └── README.md
59
+ ├── qwen-7b-coder-nl2sql-grpo-v2
60
+ ├── qwen-7b-nl2sql-merged
61
+ │   ├── chat_template.jinja
62
+ │   ├── config.json
63
+ │   ├── generation_config.json
64
+ │   ├── model.safetensors
65
+ │   ├── tokenizer_config.json
66
+ │   └── tokenizer.json
67
+ ├── README.md
68
+ ├── scripts
69
+ │   ├── run_local.sh
70
+ │   └── smoke_test.sh
71
+ ├── server
72
+ │   ├── app.py
73
+ │   ├── db
74
+ │   │   ├── __init__.py
75
+ │   │   ├── schema.sql
76
+ │   │   └── seed.py
77
+ │   ├── environment.py
78
+ │   ├── grader.py
79
+ │   ├── __init__.py
80
+ │   ├── requirements.txt
81
+ │   └── tasks
82
+ │   ├── base.py
83
+ │   ├── easy.py
84
+ │   ├── hard.py
85
+ │   ├── __init__.py
86
+ │   └── medium.py
87
+ ├── swapped_templates.json
88
+ ├── tests
89
+ │   ├── conftest.py
90
+ │   ├── __init__.py
91
+ │   └── test_all.py
92
+ ├── train.py
93
+ └── value_swapper.py
94
+
95
+ 11 directories, 81 files
generate_data.py ADDED
@@ -0,0 +1,263 @@
+ import os
+ import sys
+ import json
+ import torch
+ import hashlib
+ from pathlib import Path
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # GPU CONFIG - expose six GPUs (0,1,2,3,4,7) to this process
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,7"
+
+ PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
+ if PROJECT_ROOT not in sys.path:
+     sys.path.insert(0, PROJECT_ROOT)
+
+ from data_factory.schemas import SCHEMA_CONTEXT
+ from data_factory.validator import SQLValidator
+
+ # CONFIG
+ MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
+ TARGET_TEMPLATES = 10000
+ OUTPUT_FILE = "llm_10k_base_templates.json"
+ BATCH_SIZE = 64
+
+ PROMPT_TEMPLATE = """
+ You are a senior expert in SQLite schema design and NL2SQL dataset generation.
+
+ TASK
+ Generate exactly 10 UNIQUE, COMPLEX, and FULLY VALID SQLite SQL SELECT queries for the given schema.
+ For each query, also write a natural language question that a real user might ask.
+
+ HARD RULES
+ - Output ONLY a valid JSON array.
+ - Do NOT wrap output in markdown, code fences, or explanations.
+ - Every item must be a JSON object with exactly these keys:
+   - "sql"
+   - "base_nl"
+   - "difficulty"
+   - "has_order"
+ - All SQL must be a single SELECT statement.
+ - Do NOT use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, PRAGMA, ATTACH, DETACH, or any DDL/DML.
+ - Every table and column used in SQL must exist in the provided schema.
+ - Do NOT invent columns, tables, aliases, or constraints.
+ - SQL must be valid for SQLite.
+ - Prefer queries that are meaningfully different from each other.
+ - Avoid repetitive templates.
+ - Each SQL should test a different reasoning pattern.
+ - Each base_nl should sound natural and distinct from the others.
+ - Use advanced SQL patterns where appropriate:
+   - multiple JOINs
+   - CTEs
+   - subqueries
+   - window functions such as ROW_NUMBER, RANK, DENSE_RANK, LAG, LEAD
+   - GROUP BY and HAVING
+   - conditional aggregation
+   - anti-joins / exclusion logic
+   - top-N per group
+   - time-based filtering
+ - Exactly 3 of the 10 queries must be "easy" (basic filtering, simple lookups, 1-2 tables).
+ - Exactly 3 of the 10 queries must be "medium" (moderate complexity, standard JOINs, basic aggregation).
+ - Exactly 4 of the 10 queries must be genuinely "hard" (advanced patterns, CTEs, subqueries, window functions).
+ - Ensure the "difficulty" key strictly contains one of these exact string values: "easy", "medium", or "hard".
+
+ QUALITY TARGETS
+ - The SQL should be executable as written.
+ - The question should be answerable from the schema alone.
+ - Prefer business-like, realistic analytics questions.
+ - Prefer queries that require combining 2 to 4 tables.
+ - If a query uses aggregation, ensure the NL clearly implies aggregation.
+ - If a query uses ordering, include "has_order": true.
+ - If a query does not require ordering, set "has_order": false.
+ - Make the 10 queries cover diverse intent types:
+   1. ranking
+   2. comparison against average or median
+   3. top/bottom-N
+   4. grouped aggregation
+   5. time filtering
+   6. multi-join analysis
+   7. exclusion / NOT EXISTS
+   8. window-function based analysis
+   9. conditional counting
+   10. trend or interval-based logic
+
+ SCHEMA
+ {schema}
+
+ OUTPUT FORMAT
+ Return ONLY a valid JSON array of 10 objects.
+
+ Example structure:
+ [
+   {{
+     "sql": "SELECT ...",
+     "base_nl": "Show ...",
+     "difficulty": "hard",
+     "has_order": true
+   }}
+ ]
+
+ FINAL SELF-CHECK BEFORE RESPONDING
+ - Confirm the output is valid JSON.
+ - Confirm there are exactly 10 objects.
+ - Confirm every SQL is a single SELECT.
+ - Confirm no hallucinated schema elements exist.
+ - Confirm the 10 questions are not paraphrases of each other.
+ """
+
+ def extract_json(raw_text):
+     text = raw_text.strip()
+     if text.startswith("```json"):
+         text = text[7:-3].strip()
113
+ elif text.startswith("```"):
114
+ text = text[3:-3].strip()
115
+ start = text.find("[")
116
+ end = text.rfind("]")
117
+ if start != -1 and end != -1:
118
+ return text[start:end+1]
119
+ return None
120
+
121
+ def main():
122
+ print("Loading Model Qwen-72B (SDPA) for 10K Mining...")
123
+
124
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
125
+ custom_max_memory = {
126
+ 0: "60GiB", # System GPU 0 (Has 13GB used, ~67GB free)
127
+ 1: "75GiB", # System GPU 1 (Fully free)
128
+ 2: "75GiB", # System GPU 2 (Fully free)
129
+ 3: "75GiB", # System GPU 3 (Fully free)
130
+ 4: "75GiB", # System GPU 4 (Fully free)
131
+ 5: "45GiB" # System GPU 7 (Has 25GB used, ~55GB free)
132
+ }
133
+ model = AutoModelForCausalLM.from_pretrained(
134
+ MODEL_NAME,
135
+ device_map="auto",
136
+ max_memory = custom_max_memory,
137
+ torch_dtype=torch.bfloat16,
138
+ attn_implementation="sdpa"
139
+ )
140
+
141
+ domains = list(SCHEMA_CONTEXT.keys())
142
+ valid_templates = []
143
+ seen_sql_hashes = set()
144
+
145
+ # Resume support: Load existing templates to prevent duplicates
146
+ if os.path.exists(OUTPUT_FILE):
147
+ with open(OUTPUT_FILE, "r") as f:
148
+ valid_templates = json.load(f)
149
+ for t in valid_templates:
150
+ seen_sql_hashes.add(hashlib.md5(t["sql"].lower().encode()).hexdigest())
151
+
152
+ pbar = tqdm(total=TARGET_TEMPLATES, initial=len(valid_templates), desc="Mining 10K Base Templates")
153
+
154
+ validators = {}
155
+ domain_idx = 0
156
+
157
+ while len(valid_templates) < TARGET_TEMPLATES:
158
+ batch_prompts = []
159
+ batch_domains = []
160
+
161
+ # Prepare Batch
162
+ for _ in range(BATCH_SIZE):
163
+ domain = domains[domain_idx % len(domains)]
164
+ schema_string = SCHEMA_CONTEXT[domain]
165
+ domain_idx += 1
166
+
167
+ messages = [
168
+ {"role": "system", "content": "You output only valid JSON arrays. Do not include markdown."},
169
+ {"role": "user", "content": PROMPT_TEMPLATE.format(schema=schema_string)}
170
+ ]
171
+ chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
172
+ batch_prompts.append(chat_text)
173
+ batch_domains.append(domain)
174
+
175
+ inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)
176
+
177
+ try:
178
+ tqdm.write(f"\n[DEBUG] Sending batch of {BATCH_SIZE} to model.generate(). Please wait...")
179
+ with torch.no_grad():
180
+ outputs = model.generate(
181
+ **inputs,
182
+ max_new_tokens=5000,
183
+ do_sample=True,
184
+ temperature=0.55,
185
+ top_p=0.9,
186
+ pad_token_id=tokenizer.eos_token_id
187
+ )
188
+ tqdm.write("[DEBUG] Model generation finished. Decoding responses...")
189
+
190
+ # Output Slicing
191
+ input_length = inputs.input_ids.shape[1]
192
+ generated_tokens = outputs[:, input_length:]
193
+ responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
194
+
195
+ batch_added = 0
196
+ for i, (response, domain) in enumerate(zip(responses, batch_domains)):
197
+ tqdm.write(f"\n[DEBUG] Processing Response {i+1}/{BATCH_SIZE} for domain: {domain}")
198
+
199
+ json_text = extract_json(response)
200
+ if not json_text:
201
+ tqdm.write(f"[DEBUG] extract_json failed. Raw text snippet: {response[:200]}...")
202
+ continue
203
+
204
+ try:
205
+ generated_data = json.loads(json_text)
206
+ tqdm.write(f"[DEBUG] JSON loaded successfully. Found {len(generated_data)} items.")
207
+ except Exception as e:
208
+ tqdm.write(f"[DEBUG] json.loads failed. Error: {e}")
209
+ tqdm.write(f"[DEBUG] Bad JSON snippet: {json_text[:200]}...")
210
+ continue
211
+
212
+ if domain not in validators:
213
+ validators[domain] = SQLValidator(domain, seed=42)
214
+ validator = validators[domain]
215
+
216
+ for item in generated_data:
217
+ if not isinstance(item, dict): continue
218
+
219
+ sql = item.get("sql", "").strip()
220
+ if not sql: continue
221
+
222
+ # Check for duplicates using hash
223
+ sql_hash = hashlib.md5(sql.lower().encode()).hexdigest()
224
+ if sql_hash in seen_sql_hashes:
225
+ tqdm.write("[DEBUG] Duplicate query skipped.")
226
+ continue
227
+
228
+ val_result = validator.validate(sql)
229
+
230
+ # Hard validation rule: SQL must execute AND return rows
231
+ if val_result.passed and val_result.row_count > 0:
232
+ tqdm.write(f"[DEBUG] SQL Passed (Rows: {val_result.row_count}): {sql[:50]}...")
233
+ item["domain"] = domain
234
+ item["id"] = f"base_{len(valid_templates)}"
235
+ valid_templates.append(item)
236
+ seen_sql_hashes.add(sql_hash)
237
+ batch_added += 1
238
+ else:
239
+ tqdm.write(f"[DEBUG] SQL Failed Validation or 0 Rows (Passed: {val_result.passed}, Rows: {val_result.row_count}): {sql[:50]}...")
240
+
241
+ if batch_added > 0:
242
+ pbar.update(batch_added)
243
+ tqdm.write(f"[DEBUG] Auto-saving {batch_added} new templates to JSON...")
244
+ # Auto-save after every successful batch
245
+ with open(OUTPUT_FILE, "w") as f:
246
+ json.dump(valid_templates, f, indent=2)
247
+
248
+ if len(valid_templates) >= TARGET_TEMPLATES:
249
+ break
250
+
251
+ except Exception as e:
252
+ tqdm.write(f"\n[DEBUG] CRITICAL EXCEPTION CAUGHT: {e}")
253
+ continue
254
+
255
+ # Close validators
256
+ for v in validators.values():
257
+ v.close()
258
+
259
+ pbar.close()
260
+ print(f"\nBoom! Generated {len(valid_templates)} Elite Base Templates!")
261
+
262
+ if __name__ == "__main__":
263
+ main()
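The `extract_json` helper above assumes the model emits either a bare JSON array or one wrapped in a single code fence. A standalone sketch of the same fence-stripping idea, written to also tolerate a fence without the trailing backticks (illustrative only, not the script's exact helper):

```python
import json

def extract_json(raw_text):
    """Pull the first JSON array out of a (possibly fenced) LLM response."""
    text = raw_text.strip()
    # Drop a leading ```json / ``` fence and a trailing ``` if present.
    if text.startswith("```json"):
        text = text[len("```json"):]
    elif text.startswith("```"):
        text = text[len("```"):]
    if text.endswith("```"):
        text = text[:-3]
    # Keep only the outermost [...] span.
    start, end = text.find("["), text.rfind("]")
    return text[start:end + 1] if start != -1 and end != -1 else None

parsed = json.loads(extract_json('```json\n[{"sql": "SELECT 1"}]\n```'))
```

Slicing between the first `[` and last `]` is what makes this robust to chatty preambles, since everything outside the array is discarded before `json.loads`.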
generate_edge_cases.py ADDED
@@ -0,0 +1,319 @@
+ """
+ generate_edge_cases.py
+ ======================
+ Targeted edge-case data generator for the 7 failure patterns found in eval:
+ 1. ROW_NUMBER when the question demands a single winner per tie
+ 2. RANK vs DENSE_RANK when ties should persist
+ 3. strftime month as INTEGER (not '%Y-%m' string)
+ 4. SELECT column discipline (no unrequested extras)
+ 5. LAG/LEAD period-over-period
+ 6. HAVING vs WHERE placement
+ 7. COUNT(DISTINCT) vs COUNT
+
+ Produces: edge_cases.jsonl (same chat format as nl2sql_cleaned_ready_to_train.jsonl)
+ Run: python generate_edge_cases.py
+ """
+
+ import os, sys, json, re, hashlib
+ from tqdm import tqdm
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+ import torch
+ import transformers.activations
+
+ # Shim so AutoAWQ does not crash on transformers versions that removed PytorchGELUTanh
+ if not hasattr(transformers.activations, 'PytorchGELUTanh'):
+     transformers.activations.PytorchGELUTanh = transformers.activations.NewGELUActivation
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "3,1,6,7"
+
+ PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
+ if PROJECT_ROOT not in sys.path:
+     sys.path.insert(0, PROJECT_ROOT)
+
+ from data_factory.schemas import SCHEMA_CONTEXT
+
+ quantization_config = BitsAndBytesConfig(
+     load_in_4bit=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
+     bnb_4bit_use_double_quant=True,
+     bnb_4bit_quant_type="nf4"
+ )
+
+ MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
+ OUTPUT_FILE = "edge_cases.jsonl"
+ BATCH_SIZE = 8  # smaller batch: the edge-case prompts are long
+ SAMPLES_PER_PATTERN = 715  # 7 patterns x 715 = 5005 total edge samples
+
+ SYSTEM_PROMPT = (
+     "You are a Senior SQL Architect. "
+     "Output ONLY the SQL query. Use SQLite syntax."
+ )
+
+ # ── Edge-case prompt templates ──────────────────────────────────────────────
+ # Each entry: (pattern_tag, user_prompt_template)
+ # {schema} is filled at runtime with a random domain schema.
+
+ EDGE_PATTERNS = [
+
+     # 1. ROW_NUMBER tie-breaking — the #1 failure
+     ("row_number_tiebreak", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs that REQUIRE ROW_NUMBER() (not RANK or DENSE_RANK) \
+ because the question explicitly says "pick one winner when there is a tie" \
+ using a tiebreaker column (e.g. lower id, earlier date).
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - Every SQL must use ROW_NUMBER() OVER (...) not RANK().
+ - The OVER clause ORDER BY must include the tiebreaker column.
+ - WHERE rn = 1 must appear in an outer query or CTE.
+ - No markdown. No explanation. Just the JSON array."""),
+
+     # 2. RANK / DENSE_RANK — when ties SHOULD persist
+     ("rank_dense_rank", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs where RANK() or DENSE_RANK() is the CORRECT choice \
+ because the question says "show all tied records at the same rank".
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - Use RANK() when question implies gaps after ties, DENSE_RANK() when no gaps.
+ - NL must make the tie-semantics explicit ("same rank", "tied positions").
+ - No markdown. No explanation. Just the JSON array."""),
+
+     # 3. strftime integer month output
+     ("strftime_integer_month", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs where the question asks for a numeric month number \
+ (1–12), NOT a 'YYYY-MM' string.
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - SQL must use CAST(strftime('%m', <col>) AS INTEGER) to produce integer month.
+ - Do NOT use strftime('%Y-%m', ...) when the question asks for month number.
+ - NL questions must say "month number", "which month (1–12)", or similar.
+ - No markdown. No explanation. Just the JSON array."""),
+
+     # 4. SELECT column discipline
+     ("select_column_discipline", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs where the question explicitly names ONLY the columns \
+ to return. The SQL must select EXACTLY those columns — no extras like avg_salary, \
+ row counts, or intermediate aggregates.
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - NL must say "return only X, Y, Z" or "show me only the name and total".
+ - SQL SELECT list must contain only those columns.
+ - If aggregation is needed internally (e.g. for HAVING), do NOT expose it in SELECT.
+ - No markdown. No explanation. Just the JSON array."""),
+
+     # 5. LAG / LEAD period-over-period
+     ("lag_lead_period", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs that require LAG() or LEAD() window functions \
+ for period-over-period comparison (e.g. month-over-month revenue change, \
+ previous order amount, next appointment date).
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - Use LAG(<col>, 1) OVER (ORDER BY ...) or LEAD(...) correctly.
+ - NL must imply comparison with previous or next row/period.
+ - No markdown. No explanation. Just the JSON array."""),
+
+     # 6. HAVING vs WHERE
+     ("having_vs_where", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs that test correct placement of filter conditions:
+ - Conditions on raw columns → WHERE
+ - Conditions on aggregates → HAVING
+ Include 4 pairs where a wrong model might put an aggregate condition in WHERE (trap).
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - SQL must never filter an aggregate (COUNT, SUM, AVG) inside WHERE.
+ - SQL must never put a raw column filter inside HAVING.
+ - No markdown. No explanation. Just the JSON array."""),
+
+     # 7. COUNT(DISTINCT) vs COUNT
+     ("count_distinct", """SCHEMA:
+ {schema}
+
+ Generate exactly 8 NL2SQL pairs where the question specifically asks for \
+ "unique", "distinct", or "different" counts — requiring COUNT(DISTINCT col).
+ Also include 2 pairs where COUNT(*) is correct to reinforce the contrast.
+
+ Output ONLY a valid JSON array:
+ [
+   {{"nl": "...", "sql": "SELECT ..."}},
+   ...
+ ]
+
+ Rules:
+ - When NL says "unique/distinct", SQL must use COUNT(DISTINCT <col>).
+ - When NL says "total orders placed" (not distinct), use COUNT(*) or COUNT(id).
+ - No markdown. No explanation. Just the JSON array."""),
+ ]
+
+
+ # ── Helpers ──────────────────────────────────────────────────────────────────
+
+ def extract_json_array(text: str) -> str:
+     text = text.strip()
+     # strip code fences if the model leaks them
+     text = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", text, flags=re.DOTALL).strip()
+     s, e = text.find("["), text.rfind("]")
+     return text[s:e+1] if s != -1 and e != -1 else "[]"
+
+ def get_hash(text: str) -> str:
+     return hashlib.md5(text.lower().strip().encode()).hexdigest()
+
+ def build_record(nl: str, sql: str, domain: str) -> dict:
+     return {
+         "prompt": [
+             {"role": "system", "content": SYSTEM_PROMPT},
+             {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {nl}"}
+         ],
+         "sql": sql
+     }
+
+
+ # ── Main ─────────────────────────────────────────────────────────────────────
+
+ def main():
+     print(f"Loading {MODEL_NAME}...")
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
+     tokenizer.pad_token = tokenizer.eos_token
+     custom_memory = {0: "30GiB", 1: "75GiB", 2: "45GiB", 3: "45GiB"}
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         device_map="auto",
+         max_memory=custom_memory,
+         quantization_config=quantization_config,
+         torch_dtype=torch.bfloat16,
+         low_cpu_mem_usage=True,
+         attn_implementation="sdpa"
+     )
+
+     domains = list(SCHEMA_CONTEXT.keys())
+     seen = set()
+     total = 0
+
+     out = open(OUTPUT_FILE, "a", encoding="utf-8")
+
+     for pattern_tag, prompt_tmpl in EDGE_PATTERNS:
+         print(f"\n[PATTERN] {pattern_tag}")
+         collected = 0
+         domain_idx = 0
+         pbar = tqdm(total=SAMPLES_PER_PATTERN, desc=pattern_tag)
+
+         while collected < SAMPLES_PER_PATTERN:
+             # Build a batch of prompts, cycling through domains
+             batch_domains = []
+             batch_prompts = []
+             for _ in range(BATCH_SIZE):
+                 domain = domains[domain_idx % len(domains)]
+                 domain_idx += 1
+                 user_msg = prompt_tmpl.format(schema=SCHEMA_CONTEXT[domain])
+                 msgs = [
+                     {"role": "system", "content": "You output only valid JSON arrays. No markdown."},
+                     {"role": "user", "content": user_msg}
+                 ]
+                 batch_prompts.append(
+                     tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
+                 )
+                 batch_domains.append(domain)
+
+             inputs = tokenizer(
+                 batch_prompts, return_tensors="pt", padding=True, truncation=True
+             ).to(model.device)
+
+             try:
+                 with torch.no_grad():
+                     outputs = model.generate(
+                         **inputs,
+                         max_new_tokens=2048,
+                         do_sample=True,
+                         temperature=0.5,
+                         top_p=0.9,
+                         pad_token_id=tokenizer.eos_token_id
+                     )
+                 responses = tokenizer.batch_decode(
+                     outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
+                 )
+
+                 for resp, domain in zip(responses, batch_domains):
+                     raw = extract_json_array(resp)
+                     try:
+                         pairs = json.loads(raw)
+                     except Exception:
+                         continue
+
+                     for pair in pairs:
+                         nl = pair.get("nl", "").strip()
+                         sql = pair.get("sql", "").strip()
+                         if not nl or not sql:
+                             continue
+                         # strip fences in sql just in case
+                         sql = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", sql, flags=re.DOTALL).strip()
+
+                         h = get_hash(nl + sql)
+                         if h in seen:
+                             continue
+                         seen.add(h)
+
+                         record = build_record(nl, sql, domain)
+                         out.write(json.dumps(record, ensure_ascii=False) + "\n")
+                         out.flush()
+                         collected += 1
+                         total += 1
+                         pbar.update(1)
+
+                         if collected >= SAMPLES_PER_PATTERN:
+                             break
+
+             except Exception as e:
+                 tqdm.write(f"[WARN] Batch failed: {e}")
+                 continue
+
+         pbar.close()
+
+     out.close()
+     print(f"\nDone! {total} edge-case records saved to {OUTPUT_FILE}")
+
+
+ if __name__ == "__main__":
+     main()
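Pattern 6 above encodes the rule that aggregate conditions belong in HAVING while raw-column conditions belong in WHERE. A minimal sqlite3 check of exactly that rule (the toy table and values here are invented for illustration):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE orders (customer_id INTEGER, total_amount REAL);
    INSERT INTO orders VALUES (1, 50), (1, 70), (2, 30);
""")

# Aggregate conditions go in HAVING: only customer 1 (120 total) passes.
rows = conn.execute(
    "SELECT customer_id, SUM(total_amount) AS spent "
    "FROM orders GROUP BY customer_id "
    "HAVING SUM(total_amount) > 100"
).fetchall()

# Putting the same aggregate in WHERE is a hard error in SQLite
# ("misuse of aggregate"), which is why the trap pairs matter.
try:
    conn.execute("SELECT customer_id FROM orders WHERE SUM(total_amount) > 100")
    aggregate_in_where_ok = True
except sqlite3.OperationalError:
    aggregate_in_where_ok = False
```

Because WHERE is evaluated per row before grouping, SQLite rejects aggregates there outright, so a model that gets this wrong fails execution rather than merely returning wrong rows.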
inference.py ADDED
@@ -0,0 +1,265 @@
+ """
+ inference.py — NL2SQL-Bench Baseline Inference Script
+ ========================================================
+
+ MANDATORY COMPLIANCE
+ --------------------
+ - Named `inference.py`, placed in project root.
+ - Uses OpenAI client for all LLM calls.
+ - Reads: API_BASE_URL, MODEL_NAME, HF_TOKEN from environment.
+ - Emits [START] / [STEP] / [END] lines to stdout in the exact format below.
+ - Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB.
+
+ STDOUT FORMAT (exact — any deviation breaks scoring)
+ ----------------------------------------------------
+ [START] task=<task_name> env=nl2sql-bench model=<model_name>
+ [STEP] step=<n> action=<sql_one_line> reward=<0.00> done=<true|false> error=<msg|null>
+ [END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
+ """
+
+ from __future__ import annotations
+
+ import asyncio
+ import os
+ import sys
+ import textwrap
+ from typing import List, Optional
+
+ from openai import OpenAI
+
+ # ── Configuration ──────────────────────────────────────────────────────────
+ API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
+ MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
+ API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
+ IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
+ SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
+
+ BENCHMARK = "nl2sql-bench"
+ MAX_STEPS = 5
+ TEMPERATURE = 0.2  # low temperature for SQL generation
+ MAX_TOKENS = 512
+ SUCCESS_THRESHOLD = 0.7  # score >= 0.7 → success
+
+ TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
+
+ # ── System prompt ──────────────────────────────────────────────────────────
+ SYSTEM_PROMPT = textwrap.dedent("""
+     You are an expert SQL analyst working with a SQLite e-commerce database.
+
+     DATABASE SCHEMA
+     ---------------
+     categories(id, name)
+     products(id, name, category_id, price, stock_quantity)
+     customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
+     orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
+            created_at, total_amount)
+     order_items(id, order_id, product_id, quantity, unit_price)
+     reviews(id, product_id, customer_id, rating∈1-5, created_at)
+
+     RULES
+     -----
+     1. Write a single SELECT query — no INSERT/UPDATE/DELETE.
+     2. Output ONLY the SQL query, nothing else. No markdown, no explanation.
+     3. Use SQLite syntax: strftime('%Y-%m', date_col) for month, ROUND(x, 2) for decimals.
+     4. Window functions (RANK, DENSE_RANK, ROW_NUMBER, running SUM) are supported.
+     5. CTEs (WITH ... AS (...)) are supported.
+     6. If you receive an error, fix it carefully in your next attempt.
+     7. If you receive partial results, refine your query to match the expected output.
+ """).strip()
+
+
+ # ── Stdout logging (mandatory format) ─────────────────────────────────────
+
+ def log_start(task: str, model: str) -> None:
+     print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)
+
+
+ def log_step(
+     step: int, action: str, reward: float, done: bool, error: Optional[str]
+ ) -> None:
+     # Collapse multi-line SQL to a single line for log compliance
+     action_single = " ".join(action.split())
+     error_val = error.replace("\n", " ") if error else "null"
+     print(
+         f"[STEP] step={step} action={action_single!r} "
+         f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
+         flush=True,
+     )
+
+
+ def log_end(
+     success: bool, steps: int, score: float, rewards: List[float]
+ ) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(
+         f"[END] success={str(success).lower()} steps={steps} "
+         f"score={score:.3f} rewards={rewards_str}",
+         flush=True,
+     )
+
+
+ # ── LLM interaction ────────────────────────────────────────────────────────
+
+ def build_user_prompt(
+     question: str,
+     schema_context: str,
+     step: int,
+     last_query: str,
+     last_error: Optional[str],
+     last_result: list,
+     result_columns: list,
+ ) -> str:
+     parts = [f"QUESTION: {question}", ""]
+
+     if step > 1:
+         parts.append(f"Your previous SQL (step {step - 1}):")
+         parts.append(f"  {' '.join(last_query.split())}")
+         parts.append("")
+         if last_error:
+             parts.append(f"ERROR: {last_error}")
+         elif last_result:
+             preview = str(last_result[:3]).replace("\n", " ")
+             parts.append(f"RESULT PREVIEW (first 3 rows): {preview}")
+             parts.append(f"COLUMNS: {result_columns}")
+         parts.append("")
+         parts.append("Please correct or refine your query.")
+     else:
+         parts.append("Write a SQL query to answer the question.")
+
+     return "\n".join(parts)
+
+
+ def call_llm(client: OpenAI, user_prompt: str) -> str:
+     try:
+         resp = client.chat.completions.create(
+             model=MODEL_NAME,
+             messages=[
+                 {"role": "system", "content": SYSTEM_PROMPT},
+                 {"role": "user", "content": user_prompt},
+             ],
+             temperature=TEMPERATURE,
+             max_tokens=MAX_TOKENS,
+             stream=False,
+         )
+         text = (resp.choices[0].message.content or "").strip()
+         # Strip markdown code fences if the model wraps output in ```sql ... ```
+         if text.startswith("```"):
+             lines = text.split("\n")
+             text = "\n".join(
+                 l for l in lines
+                 if not l.strip().startswith("```")
+             ).strip()
+         return text if text else "SELECT 1"
+     except Exception as exc:
+         print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
+         return "SELECT 1"
+
+
+ # ── Single-task episode ────────────────────────────────────────────────────
+
+ async def run_task(client: OpenAI, env, task_name: str) -> dict:
+     """Run one full episode for the given task. Returns a result dict."""
+     rewards: List[float] = []
+     steps_taken = 0
+     score = 0.0
+     success = False
+
+     log_start(task_name, MODEL_NAME)
+
+     try:
+         # Reset — pass task_name via action payload or query param.
+         # OpenEnv reset() may not accept task args via HTTP; we rely on
+         # NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
+         # pass it as a reset parameter if the server supports it.
+         result = await env.reset()
+         obs = result.observation
+
+         for step in range(1, MAX_STEPS + 1):
+             if result.done:
+                 break
+
+             user_prompt = build_user_prompt(
+                 question=obs.question,
+                 schema_context=obs.schema_context,
+                 step=step,
+                 last_query=obs.last_query,
+                 last_error=obs.last_error,
+                 last_result=obs.last_result,
+                 result_columns=obs.result_columns,
+             )
+
+             sql = call_llm(client, user_prompt)
+
+             from client import NL2SQLAction  # local to avoid circular import at module level
+             result = await env.step(NL2SQLAction(query=sql))
+             obs = result.observation
+
+             reward = obs.reward or 0.0
+             done = obs.done
+             error = obs.last_error
+
+             rewards.append(reward)
+             steps_taken = step
+
+             log_step(step=step, action=sql, reward=reward, done=done, error=error)
+
+             if done:
+                 break
+
+         # Compute final score
+         score = sum(rewards) / max(len(rewards), 1)
+         score = round(min(max(score, 0.0), 1.0), 4)
+         success = score >= SUCCESS_THRESHOLD
+
+     except Exception as exc:
+         print(f"[DEBUG] Episode error for {task_name}: {exc}", file=sys.stderr, flush=True)
+     finally:
+         log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
+
+     return {"task": task_name, "success": success, "score": score, "rewards": rewards}
+
+
+ # ── Main ───────────────────────────────────────────────────────────────────
+
+ async def main() -> None:
+     client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
+
+     # Import here to avoid import errors if openenv is not installed during lint
+     from client import NL2SQLEnv
+
+     all_results = []
+
+     for task_name in TASKS:
+         # Set the default task for the server session via the env-var approach.
+         # For the hosted Space, we rely on the task cycling implemented in
+         # the task registry's round-robin iterator.
+         os.environ["NL2SQL_DEFAULT_TASK"] = task_name
+
+         try:
+             async with NL2SQLEnv(base_url=SPACE_URL) as env:
+                 result = await run_task(client, env, task_name)
+                 all_results.append(result)
+         except Exception as exc:
+             print(
+                 f"[DEBUG] Failed to connect for task {task_name}: {exc}",
+                 file=sys.stderr,
+                 flush=True,
+             )
+             # Emit a zero-score END to keep the log format valid
+             log_end(success=False, steps=0, score=0.0, rewards=[])
+             all_results.append({"task": task_name, "success": False, "score": 0.0})
+
+     # Summary to stderr (not scored, for human readability)
+     print("\n=== Baseline Summary ===", file=sys.stderr)
+     for r in all_results:
+         print(
+             f"  {r['task']:20s} score={r['score']:.3f} "
+             f"success={r['success']}",
+             file=sys.stderr,
+         )
+     avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)
+     print(f"  {'AVERAGE':20s} score={avg:.3f}", file=sys.stderr)
+
+
+ if __name__ == "__main__":
+     asyncio.run(main())
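When debugging locally it can help to parse the [STEP] lines this script emits. A small sketch matching the format produced by `log_step` above, where the action is printed via `!r` (quoted); the regex is illustrative tooling, not part of the scoring harness:

```python
import re

# Matches: [STEP] step=<n> action=<repr(sql)> reward=<x.xx> done=<true|false> error=<msg|null>
STEP_RE = re.compile(
    r"\[STEP\] step=(?P<step>\d+) action=(?P<action>.+?) "
    r"reward=(?P<reward>\d+\.\d+) done=(?P<done>true|false) error=(?P<error>.*)"
)

line = "[STEP] step=2 action='SELECT 1' reward=0.50 done=false error=null"
m = STEP_RE.match(line)
reward = float(m.group("reward"))
done = m.group("done") == "true"
```

The non-greedy `.+?` on the action group is what keeps a SQL string containing spaces from swallowing the `reward=` field.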
local_test.py ADDED
@@ -0,0 +1,118 @@
+ import asyncio
+ import os
+ import sys
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from peft import PeftModel
+
+ # --- Configuration ---
+ BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
+ LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50"
+ SPACE_URL = "http://localhost:8000"  # Local server URL
+ TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
+ MAX_STEPS = 5
+
+ print("Loading Base Model and LoRA weights...")
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
+ base_model = AutoModelForCausalLM.from_pretrained(
+     BASE_MODEL,
+     torch_dtype=torch.bfloat16,
+     device_map="auto"
+ )
+ model = PeftModel.from_pretrained(base_model, LORA_DIR)
+
+ # --- System Prompt & LLM Call ---
+ SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database.
+ Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."""
+
+ def call_local_llm(user_prompt: str) -> str:
+     messages = [
+         {"role": "system", "content": SYSTEM_PROMPT},
+         {"role": "user", "content": user_prompt}
+     ]
+     text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+     inputs = tokenizer([text], return_tensors="pt").to(model.device)
+
+     with torch.no_grad():
+         outputs = model.generate(**inputs, max_new_tokens=256, temperature=0.2, do_sample=True)
+
+     response = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
+
+     # Strip markdown code fences if the model wraps output in ```sql ... ```
+     if response.startswith("```"):
+         lines = response.split("\n")
+         response = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()
+     return response if response else "SELECT 1"
+
+ def build_user_prompt(question, schema_context, step, last_query, last_error, last_result, result_columns):
+     parts = [f"QUESTION: {question}", ""]
+     if step > 1:
+         parts.append(f"Your previous SQL (step {step - 1}):")
+         parts.append(f"  {' '.join(last_query.split())}")
+         parts.append("")
+         if last_error:
+             parts.append(f"ERROR: {last_error}")
+         elif last_result:
+             preview = str(last_result[:3]).replace("\n", " ")
+             parts.append(f"RESULT PREVIEW (first 3 rows): {preview}")
+             parts.append(f"COLUMNS: {result_columns}")
+         parts.append("")
+         parts.append("Please correct or refine your query.")
+     else:
+         parts.append("Write a SQL query to answer the question.")
+     return "\n".join(parts)
+
+ async def main():
+     from client import NL2SQLEnv, NL2SQLAction
+
+     all_results = []
+
+     for task_name in TASKS:
+         print(f"\n--- Starting Task: {task_name} ---")
+         os.environ["NL2SQL_DEFAULT_TASK"] = task_name
+
+         try:
+             async with NL2SQLEnv(base_url=SPACE_URL) as env:
+                 result = await env.reset()
+                 obs = result.observation
+
+                 rewards = []
+                 success = False
+
+                 for step in range(1, MAX_STEPS + 1):
+                     if obs.done:
+                         break
+
+                     user_prompt = build_user_prompt(
+                         obs.question, obs.schema_context, step,
+                         obs.last_query, obs.last_error, obs.last_result, obs.result_columns
+                     )
+
+                     sql = call_local_llm(user_prompt)
+
+                     print(f"Step {step} Agent Output: {sql}")
+
+                     step_result = await env.step(NL2SQLAction(query=sql))
+                     obs = step_result.observation
+
+                     reward = obs.reward or 0.0
+                     rewards.append(reward)
+                     print(f"Step {step} Reward: {reward}")
+
+                     if obs.done:
+                         break
+
+                 score = sum(rewards) / max(len(rewards), 1)
+                 success = score >= 0.7
+                 print(f"Final Score for {task_name}: {score:.3f}")
+                 all_results.append({"task": task_name, "score": score, "success": success})
+
+         except Exception as e:
+             print(f"Error testing task {task_name}: {e}")
+
+     print("\n=== Final Results ===")
+     for r in all_results:
+         print(f"{r['task']}: Score {r['score']:.3f} | Success: {r['success']}")
+
+ if __name__ == "__main__":
+     asyncio.run(main())
merge_model.py ADDED
@@ -0,0 +1,32 @@
1
+ import torch
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ from peft import PeftModel
4
+
5
+ BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
6
+ ADAPTER_DIR = "./qwen-7b-coder-nl2sql-grpo/final"
7
+ OUTPUT_DIR = "./qwen-7b-nl2sql-merged"
8
+
9
+ def main():
10
+ print("Loading Base Model...")
11
+ base_model = AutoModelForCausalLM.from_pretrained(
12
+ BASE_MODEL,
13
+ torch_dtype=torch.bfloat16,
14
+ device_map="auto"
15
+ )
16
+
17
+ print("Loading Adapters and Merging...")
18
+ # Load the LoRA adapters into the base model
19
+ model = PeftModel.from_pretrained(base_model, ADAPTER_DIR)
20
+
21
+ # Merge weights permanently
22
+ merged_model = model.merge_and_unload()
23
+
24
+ print("Saving Merged Model...")
25
+ merged_model.save_pretrained(OUTPUT_DIR)
26
+
27
+ tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
28
+ tokenizer.save_pretrained(OUTPUT_DIR)
29
+ print(f"Done! Merged model saved to {OUTPUT_DIR}")
30
+
31
+ if __name__ == "__main__":
32
+ main()
mini_server.py ADDED
@@ -0,0 +1,74 @@
1
+ import os
2
+ import torch
3
+ import uvicorn
4
+ from fastapi import FastAPI
5
+ from pydantic import BaseModel
6
+ from typing import List
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+
9
+ # CRITICAL: GPU 0 pe host karenge
10
+ os.environ["CUDA_VISIBLE_DEVICES"] = "7"
11
+
12
+ app = FastAPI()
13
+
14
+ # Tera Merged Model Path
15
+ MODEL_PATH = "./qwen-7b-nl2sql-merged"
16
+
17
+ print("🚀 Loading Local Model for Inference API... (Takes a minute)")
18
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
19
+ model = AutoModelForCausalLM.from_pretrained(
20
+ MODEL_PATH,
21
+ device_map="auto",
22
+ torch_dtype=torch.bfloat16,
23
+ attn_implementation="sdpa" # Super stable, no vLLM crashes
24
+ )
25
+ print("✅ Server Ready! Acting as OpenAI on Port 8000.")
26
+
27
+ # OpenAI Request Schemas
28
+ class Message(BaseModel):
29
+ role: str
30
+ content: str
31
+
32
+ class ChatRequest(BaseModel):
33
+ model: str
34
+ messages: List[Message]
35
+ temperature: float = 0.2
36
+ max_tokens: int = 512
37
+
38
+ @app.post("/v1/chat/completions")
39
+ async def chat(request: ChatRequest):
40
+ # Convert OpenAI messages to Qwen format
41
+ messages = [{"role": m.role, "content": m.content} for m in request.messages]
42
+ prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
43
+
44
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
45
+
46
+ # Generate SQL
47
+ with torch.no_grad():
48
+ outputs = model.generate(
49
+ **inputs,
50
+ max_new_tokens=request.max_tokens,
51
+ temperature=request.temperature,
52
+ do_sample=True if request.temperature > 0 else False,
53
+ pad_token_id=tokenizer.eos_token_id
54
+ )
55
+
56
+ # Decode only the newly generated text
57
+ response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
58
+
59
+ # Return EXACT OpenAI JSON Structure
60
+ return {
61
+ "id": "chatcmpl-local-hackathon",
62
+ "object": "chat.completion",
63
+ "created": 1700000000,
64
+ "model": request.model,
65
+ "choices": [{
66
+ "index": 0,
67
+ "message": {"role": "assistant", "content": response_text},
68
+ "finish_reason": "stop"
69
+ }],
70
+ "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
71
+ }
72
+
73
+ if __name__ == "__main__":
74
+ uvicorn.run(app, host="0.0.0.0", port=8001)
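Once mini_server.py is running, any OpenAI-style HTTP client can talk to it. A minimal stdlib-only sketch is shown below; the model name and base URL are assumptions matching the script above (the server echoes the model name back without validating it):

```python
import json
import urllib.request


def build_chat_payload(question: str) -> dict:
    """Build a request body matching mini_server.py's ChatRequest schema."""
    return {
        "model": "qwen-7b-nl2sql-merged",  # hypothetical name; echoed back, not validated
        "messages": [{"role": "user", "content": question}],
        "temperature": 0.2,
        "max_tokens": 256,
    }


def ask(base_url: str, question: str) -> str:
    """POST to /v1/chat/completions and return the assistant's reply text."""
    req = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(build_chat_payload(question)).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(req) as resp:
        body = json.load(resp)
    return body["choices"][0]["message"]["content"]


# Usage (with the server up): ask("http://localhost:8001", "Top 5 products by revenue?")
```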
models.py ADDED
@@ -0,0 +1,79 @@
1
+ """
2
+ nl2sql-bench/models.py
3
+ ======================
4
+ Typed contracts for the NL2SQL-Bench OpenEnv environment.
5
+
6
+ Action : The SQL query the agent submits.
7
+ Observation : What the agent sees after each step.
8
+ State : Episode-level metadata (for state() endpoint).
9
+ """
10
+
11
+ from __future__ import annotations
12
+
13
+ from dataclasses import dataclass, field
14
+ from typing import Any, Dict, List, Optional
15
+
16
+ from openenv.core.env_server import Action, Observation, State
17
+
18
+
19
+ # ---------------------------------------------------------------------------
20
+ # Action
21
+ # ---------------------------------------------------------------------------
22
+
23
+ class NL2SQLAction(Action):
24
+ """A single SQL query submitted by the agent."""
25
+ query: str = ""
26
+
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # Observation
30
+ # ---------------------------------------------------------------------------
31
+
32
+
33
+ class NL2SQLObservation(Observation):
34
+ """
35
+ Everything the agent needs to reason about and iterate its SQL query.
36
+
37
+ Fields
38
+ ------
39
+ question : The natural-language question to answer.
40
+ schema_context : Relevant table/column descriptions as a string block.
41
+ task_name : Identifier of the current task (easy / medium / hard).
42
+ last_query : The SQL the agent submitted on the last step (empty on reset).
43
+ last_result : Up to 10 rows returned by the last query (list of dicts).
44
+ last_error : SQLite error string if the query failed, else None.
45
+ result_columns : Column names of last_result rows.
46
+ step : Current step number (1-indexed).
47
+ max_steps : Maximum steps allowed per episode.
48
+ done : True when the episode is over (success or step limit reached).
49
+ reward : Reward for the most recent action (None on reset).
50
+ score : Normalised cumulative score so far [0.0, 1.0].
51
+ """
52
+ question: str = ""
53
+ schema_context: str = ""
54
+ task_name: str = ""
55
+ last_query: str = ""
56
+ last_result: List[Dict[str, Any]] = field(default_factory=list)
57
+ last_error: Optional[str] = None
58
+ result_columns: List[str] = field(default_factory=list)
59
+ step: int = 0
60
+ max_steps: int = 5
61
+ done: bool = False
62
+ reward: Optional[float] = None
63
+ score: float = 0.0
64
+
65
+
66
+ # ---------------------------------------------------------------------------
67
+ # State
68
+ # ---------------------------------------------------------------------------
69
+
70
+ class NL2SQLState(State):
71
+ """Episode-level state (returned by the /state endpoint)."""
72
+ episode_id: Optional[str] = None
73
+ step_count: int = 0
74
+ task_name: str = ""
75
+ task_difficulty: str = "" # easy | medium | hard
76
+ question: str = ""
77
+ best_reward: float = 0.0 # highest reward seen this episode
78
+ cumulative_reward: float = 0.0
79
+ solved: bool = False # True if exact match was achieved
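One plausible way the reward-tracking fields of NL2SQLState evolve per step can be sketched with a plain dataclass standing in for the openenv base class (an illustration only, not the environment's actual update logic):

```python
from dataclasses import dataclass


@dataclass
class EpisodeState:
    """Stand-in mirroring NL2SQLState's reward-tracking fields."""
    step_count: int = 0
    best_reward: float = 0.0
    cumulative_reward: float = 0.0
    solved: bool = False


def record_step(state: EpisodeState, reward: float, exact_match: bool) -> EpisodeState:
    """Fold one step's outcome into the episode-level state."""
    state.step_count += 1
    state.cumulative_reward += reward
    state.best_reward = max(state.best_reward, reward)  # highest reward seen so far
    state.solved = state.solved or exact_match          # sticky once achieved
    return state
```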
openenv.yaml ADDED
@@ -0,0 +1,135 @@
1
+ # nl2sql-bench/openenv.yaml
2
+ # OpenEnv environment manifest — validated by `openenv validate`
3
+
4
+ name: nl2sql-bench
5
+ version: "0.1.0"
6
+ description: >
7
+ Natural Language to SQL query generation environment for RL training.
8
+ An agent iteratively writes and refines SQLite queries against a synthetic
9
+ e-commerce database to answer business questions. Multi-turn episodes with
10
+ dense, shaped rewards. Three difficulty tasks: easy (single-table),
11
+ medium (JOIN + GROUP BY), hard (window functions + CTEs).
12
+
13
+ author: "nl2sql-bench team"
14
+ license: MIT
15
+ tags:
16
+ - openenv
17
+ - nl2sql
18
+ - sql
19
+ - analytics
20
+ - rl-training
21
+ - deterministic
22
+ - multi-turn
23
+
24
+ # ── Task definitions ────────────────────────────────────────────────────────
25
+ tasks:
26
+ - name: simple-filter
27
+ difficulty: easy
28
+ description: >
29
+ Single-table SELECT with WHERE, ORDER BY, and LIMIT.
30
+ Tests basic SQL fluency. Expected solve rate: high.
31
+ max_steps: 5
32
+ reward_range: [0.0, 1.0]
33
+
34
+ - name: join-aggregation
35
+ difficulty: medium
36
+ description: >
37
+ Multi-table JOINs with GROUP BY, HAVING, and aggregation functions
38
+ (COUNT, SUM, AVG, ROUND). Tests relational reasoning.
39
+ max_steps: 5
40
+ reward_range: [0.0, 1.0]
41
+
42
+ - name: analytics-window
43
+ difficulty: hard
44
+ description: >
45
+ Advanced analytics using CTEs, window functions (DENSE_RANK,
46
+ ROW_NUMBER, running SUM), and nested subqueries. Tests multi-step
47
+ planning and SQLite-specific syntax.
48
+ max_steps: 5
49
+ reward_range: [0.0, 1.0]
50
+
51
+ # ── Action / Observation space ──────────────────────────────────────────────
52
+ action_space:
53
+ type: object
54
+ properties:
55
+ query:
56
+ type: string
57
+ description: "A SQLite SELECT query string."
58
+
59
+ observation_space:
60
+ type: object
61
+ properties:
62
+ question:
63
+ type: string
64
+ description: "Natural-language question the agent must answer."
65
+ schema_context:
66
+ type: string
67
+ description: "Compact database schema description for the agent."
68
+ task_name:
69
+ type: string
70
+ description: "Active task identifier."
71
+ last_query:
72
+ type: string
73
+ description: "The SQL query submitted on the previous step."
74
+ last_result:
75
+ type: array
76
+ description: "Up to 10 rows returned by the last query (list of dicts)."
77
+ last_error:
78
+ type: string
79
+ nullable: true
80
+ description: "SQLite error string if last query failed, else null."
81
+ result_columns:
82
+ type: array
83
+ description: "Column names of last_result."
84
+ step:
85
+ type: integer
86
+ description: "Current step number (1-indexed; 0 after reset)."
87
+ max_steps:
88
+ type: integer
89
+ description: "Maximum steps per episode."
90
+ done:
91
+ type: boolean
92
+ description: "True when episode ends (exact match or step limit reached)."
93
+ reward:
94
+ type: number
95
+ nullable: true
96
+ description: "Reward for the most recent step [0.0, 1.0]."
97
+ score:
98
+ type: number
99
+ description: "Normalised cumulative episode score [0.0, 1.0]."
100
+
101
+ # ── Reward function description ─────────────────────────────────────────────
102
+ reward:
103
+ type: shaped
104
+ range: [0.0, 1.0]
105
+ components:
106
+ - name: syntax_ok
107
+ weight: 0.10
108
+ description: "Query executes without SQLite error."
109
+ - name: columns_match
110
+ weight: 0.20
111
+ description: "Returned column names match ground truth exactly."
112
+ - name: row_count_match
113
+ weight: 0.20
114
+ description: "Number of returned rows matches ground truth."
115
+ - name: exact_match
116
+ weight: 0.50
117
+ description: "Full result set matches ground truth (order-aware for ORDER BY)."
118
+ - name: step_penalty
119
+ weight: -0.05
120
+ description: "Deducted per step beyond the first (encourages efficiency)."
121
+
122
+ # ── Deployment ──────────────────────────────────────────────────────────────
123
+ server:
124
+ port: 7860
125
+ dockerfile: Dockerfile
126
+ healthcheck: /health
127
+
128
+ # ── Baseline ────────────────────────────────────────────────────────────────
129
+ baseline:
130
+ script: inference.py
131
+ model: Qwen/Qwen2.5-72B-Instruct
132
+ expected_scores:
133
+ simple-filter: 0.70
134
+ join-aggregation: 0.45
135
+ analytics-window: 0.25
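The shaped reward declared in openenv.yaml can be sketched as a weighted sum of its listed components (an illustration of the manifest, not the environment's actual grader):

```python
def shaped_reward(syntax_ok: bool, columns_match: bool, row_count_match: bool,
                  exact_match: bool, step: int) -> float:
    """Weighted sum of the reward components listed in openenv.yaml."""
    r = 0.0
    if syntax_ok:
        r += 0.10
    if columns_match:
        r += 0.20
    if row_count_match:
        r += 0.20
    if exact_match:
        r += 0.50
    r -= 0.05 * max(step - 1, 0)      # step_penalty: each step beyond the first
    return max(0.0, min(1.0, r))      # clamp to the declared [0.0, 1.0] range
```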
pyproject.toml ADDED
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["setuptools>=69", "wheel"]
3
+ build-backend = "setuptools.backends.legacy:build"
4
+
5
+ [project]
6
+ name = "nl2sql-bench"
7
+ version = "0.1.0"
8
+ description = "NL2SQL-Bench: Natural Language to SQL Analytics OpenEnv environment for RL training"
9
+ readme = "README.md"
10
+ requires-python = ">=3.10"
11
+ license = { text = "MIT" }
12
+
13
+ dependencies = [
14
+ "openenv-core>=0.2.3",
15
+ "fastapi>=0.110.0",
16
+ "uvicorn[standard]>=0.29.0",
17
+ "pydantic>=2.0.0",
18
+ "openai>=1.0.0",
19
+ ]
20
+
21
+ [project.optional-dependencies]
22
+ dev = [
23
+ "pytest>=7.0",
24
+ "pytest-asyncio>=0.23",
25
+ "httpx>=0.27",
26
+ "black",
27
+ "ruff",
28
+ ]
29
+
30
+ [tool.setuptools.packages.find]
31
+ where = ["."]
32
+ include = ["nl2sql*", "server*"]
33
+
34
+ [tool.ruff]
35
+ line-length = 100
36
+ target-version = "py310"
37
+
38
+ [tool.pytest.ini_options]
39
+ asyncio_mode = "auto"
scripts/run_local.sh ADDED
@@ -0,0 +1,67 @@
1
+ #!/usr/bin/env bash
2
+ # nl2sql-bench/scripts/run_local.sh
3
+ # ─────────────────────────────────────────────────────────────────────────────
4
+ # Quick local development server (no Docker needed).
5
+ # Prerequisites: Python 3.10+, pip or uv
6
+ #
7
+ # Usage:
8
+ # chmod +x scripts/run_local.sh
9
+ # ./scripts/run_local.sh
10
+ # ─────────────────────────────────────────────────────────────────────────────
11
+ set -euo pipefail
12
+
13
+ REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
14
+ cd "$REPO_ROOT"
15
+
16
+ echo "═══════════════════════════════════════════════"
17
+ echo " NL2SQL-Bench — Local Dev Server"
18
+ echo "═══════════════════════════════════════════════"
19
+
20
+ # ── Check Python ────────────────────────────────────────────────────────────
21
+ if ! command -v python3 &>/dev/null; then
22
+ echo "ERROR: python3 not found. Install Python 3.10+." && exit 1
23
+ fi
24
+ PY_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
25
+ echo "Python: $PY_VERSION"
26
+
27
+ # ── Virtual environment ──────────────────────────────────────────────────────
28
+ if [ ! -d ".venv" ]; then
29
+ echo "Creating virtualenv..."
30
+ python3 -m venv .venv
31
+ fi
32
+ source .venv/bin/activate
33
+
34
+ # ── Install deps ─────────────────────────────────────────────────────────────
35
+ echo "Installing dependencies..."
36
+ pip install -q --upgrade pip
37
+ pip install -q openenv-core fastapi "uvicorn[standard]" openai pydantic pytest pytest-asyncio
38
+
39
+ # ── Load .env if present ─────────────────────────────────────────────────────
40
+ if [ -f ".env" ]; then
41
+ echo "Loading .env..."
42
+ set -a
43
+ source .env
44
+ set +a
45
+ fi
46
+
47
+ # ── Export PYTHONPATH ─────────────────────────────────────────────────────────
48
+ export PYTHONPATH="$REPO_ROOT:$REPO_ROOT/server"
49
+
50
+ echo ""
51
+ echo "Starting server at http://localhost:8000"
52
+ echo " /reset → POST (start episode)"
53
+ echo " /step → POST (submit SQL)"
54
+ echo " /state → GET (episode metadata)"
55
+ echo " /health → GET (liveness probe)"
56
+ echo " /docs → GET (Swagger UI)"
57
+ echo ""
58
+ echo "Press Ctrl+C to stop."
59
+ echo "───────────────────────────────────────────────"
60
+
61
+ cd "$REPO_ROOT/server"
62
+ uvicorn app:app \
63
+ --host 0.0.0.0 \
64
+ --port 8000 \
65
+ --reload \
66
+ --reload-dir "$REPO_ROOT" \
67
+ --log-level info
scripts/smoke_test.sh ADDED
@@ -0,0 +1,62 @@
1
+ #!/usr/bin/env bash
2
+ # nl2sql-bench/scripts/smoke_test.sh
3
+ # ─────────────────────────────────────────────────────────────────────────────
4
+ # Smoke tests against a running server (local or HF Space).
5
+ # Verifies all endpoints return expected HTTP codes and JSON shapes.
6
+ #
7
+ # Usage:
8
+ # ./scripts/smoke_test.sh # default localhost:8000
9
+ # ./scripts/smoke_test.sh https://your.hf.space # HF Space URL
10
+ # ─────────────────────────────────────────────────────────────────────────────
11
+ set -euo pipefail
12
+
13
+ BASE_URL="${1:-http://localhost:8000}"
14
+ BASE_URL="${BASE_URL%/}"
15
+ PASS=0; FAIL=0
16
+
17
+ GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'; BOLD='\033[1m'
18
+
19
+ pass() { echo -e "${GREEN}✓${NC} $1"; PASS=$((PASS+1)); }
20
+ fail() { echo -e "${RED}✗${NC} $1"; FAIL=$((FAIL+1)); }
21
+
22
+ echo ""
23
+ echo -e "${BOLD}NL2SQL-Bench Smoke Tests${NC}"
24
+ echo "Target: $BASE_URL"
25
+ echo "────────────────────────────────────────"
26
+
27
+ # ── /health ──────────────────────────────────────────────────────────────────
28
+ CODE=$(curl -s -o /dev/null -w "%{http_code}" "$BASE_URL/health")
29
+ [ "$CODE" = "200" ] && pass "/health → 200" || fail "/health → $CODE (expected 200)"
30
+
31
+ # ── /reset ───────────────────────────────────────────────────────────────────
32
+ RESET_BODY=$(curl -s -X POST "$BASE_URL/reset" \
33
+ -H "Content-Type: application/json" -d '{}')
34
+ echo "$RESET_BODY" | grep -q "question" && pass "/reset → has 'question' field" \
35
+ || fail "/reset → missing 'question' field. Body: $RESET_BODY"
36
+
37
+ # ── /step (valid SQL) ─────────────────────────────────────────────────────────
38
+ STEP_BODY=$(curl -s -X POST "$BASE_URL/step" \
39
+ -H "Content-Type: application/json" \
40
+ -d '{"query": "SELECT id, name FROM customers LIMIT 3"}')
41
+ echo "$STEP_BODY" | grep -q "reward" && pass "/step valid SQL → has 'reward'" \
42
+ || fail "/step valid SQL → missing 'reward'. Body: $STEP_BODY"
43
+ echo "$STEP_BODY" | grep -q '"done"' && pass "/step valid SQL → has 'done'" \
44
+ || fail "/step valid SQL → missing 'done'. Body: $STEP_BODY"
45
+
46
+ # ── /step (syntax error SQL) ──────────────────────────────────────────────────
47
+ STEP_ERR=$(curl -s -X POST "$BASE_URL/step" \
48
+ -H "Content-Type: application/json" \
49
+ -d '{"query": "SELCT * FORM broken_tbl"}')
50
+ echo "$STEP_ERR" | grep -q "last_error" && pass "/step bad SQL → has 'last_error'" \
51
+ || fail "/step bad SQL → missing 'last_error'. Body: $STEP_ERR"
52
+
53
+ # ── /state ────────────────────────────────────────────────────────────────────
54
+ STATE_BODY=$(curl -s "$BASE_URL/state")
55
+ echo "$STATE_BODY" | grep -q "step_count" && pass "/state → has 'step_count'" \
56
+ || fail "/state → missing 'step_count'. Body: $STATE_BODY"
57
+
58
+ echo "────────────────────────────────────────"
59
+ echo -e "${BOLD}Results: ${GREEN}${PASS} passed${NC}, ${RED}${FAIL} failed${NC}"
60
+ echo ""
61
+
62
+ [ "$FAIL" -eq 0 ] && exit 0 || exit 1
server/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # server/__init__.py
server/app.py ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ nl2sql-bench/server/app.py
3
+ ============================
4
+ FastAPI application entry point for the NL2SQL-Bench OpenEnv server.
5
+
6
+ create_fastapi_app() auto-creates all required OpenEnv endpoints:
7
+ POST /reset — start a new episode
8
+ POST /step — submit an action
9
+ GET /state — retrieve episode state
10
+ GET /health — health check
11
+ GET /web — interactive web UI (if ENABLE_WEB_INTERFACE=true)
12
+ GET /docs — Swagger UI
13
+ """
14
+
15
+ import sys
16
+ from pathlib import Path
17
+
18
+ # Ensure local modules resolve before importing them from the parent directory
19
+ _HERE = Path(__file__).parent
20
+ sys.path.insert(0, str(_HERE.parent))
21
+
22
+ from openenv.core.env_server import create_fastapi_app
23
+ from environment import NL2SQLEnvironment
24
+ from models import NL2SQLAction, NL2SQLObservation
25
+
26
+ # Pass the explicitly required action and observation classes
27
+ app = create_fastapi_app(
28
+ NL2SQLEnvironment,
29
+ action_cls=NL2SQLAction,
30
+ observation_cls=NL2SQLObservation
31
+ )
server/app.py.bak ADDED
@@ -0,0 +1,31 @@
1
+ """
2
+ nl2sql-bench/server/app.py
3
+ ============================
4
+ FastAPI application entry point for the NL2SQL-Bench OpenEnv server.
5
+
6
+ create_fastapi_app() auto-creates all required OpenEnv endpoints:
7
+ POST /reset — start a new episode
8
+ POST /step — submit an action
9
+ GET /state — retrieve episode state
10
+ GET /health — health check
11
+ GET /web — interactive web UI (if ENABLE_WEB_INTERFACE=true)
12
+ GET /docs — Swagger UI
13
+ """
14
+
15
+ import sys
16
+ from pathlib import Path
17
+ from openenv.core.env_server import create_fastapi_app
18
+ from environment import NL2SQLEnvironment
19
+
20
+ # Ensure models can be imported from the parent directory
21
+ _HERE = Path(__file__).parent
22
+ sys.path.insert(0, str(_HERE.parent))
23
+
24
+ from models import NL2SQLAction, NL2SQLObservation
25
+
26
+ # Pass the explicitly required action and observation classes
27
+ app = create_fastapi_app(
28
+ NL2SQLEnvironment,
29
+ action_cls=NL2SQLAction,
30
+ observation_cls=NL2SQLObservation
31
+ )
server/db/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # server/db/__init__.py
server/db/schema.sql ADDED
@@ -0,0 +1,62 @@
1
+ -- nl2sql-bench/server/db/schema.sql
2
+ -- E-commerce database schema for NL2SQL-Bench
3
+ -- Designed for in-memory SQLite: realistic, universally understood domain.
4
+
5
+ CREATE TABLE IF NOT EXISTS categories (
6
+ id INTEGER PRIMARY KEY,
7
+ name TEXT NOT NULL UNIQUE
8
+ );
9
+
10
+ CREATE TABLE IF NOT EXISTS products (
11
+ id INTEGER PRIMARY KEY,
12
+ name TEXT NOT NULL,
13
+ category_id INTEGER NOT NULL REFERENCES categories(id),
14
+ price REAL NOT NULL CHECK(price >= 0),
15
+ stock_quantity INTEGER NOT NULL DEFAULT 0
16
+ );
17
+
18
+ CREATE TABLE IF NOT EXISTS customers (
19
+ id INTEGER PRIMARY KEY,
20
+ name TEXT NOT NULL,
21
+ email TEXT NOT NULL UNIQUE,
22
+ country TEXT NOT NULL,
23
+ tier TEXT NOT NULL DEFAULT 'bronze' -- bronze | silver | gold
24
+ CHECK(tier IN ('bronze', 'silver', 'gold')),
25
+ created_at TEXT NOT NULL -- ISO-8601 date string
26
+ );
27
+
28
+ CREATE TABLE IF NOT EXISTS orders (
29
+ id INTEGER PRIMARY KEY,
30
+ customer_id INTEGER NOT NULL REFERENCES customers(id),
31
+ status TEXT NOT NULL DEFAULT 'pending'
32
+ CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
33
+ created_at TEXT NOT NULL,
34
+ total_amount REAL NOT NULL CHECK(total_amount >= 0)
35
+ );
36
+
37
+ CREATE TABLE IF NOT EXISTS order_items (
38
+ id INTEGER PRIMARY KEY,
39
+ order_id INTEGER NOT NULL REFERENCES orders(id),
40
+ product_id INTEGER NOT NULL REFERENCES products(id),
41
+ quantity INTEGER NOT NULL CHECK(quantity > 0),
42
+ unit_price REAL NOT NULL CHECK(unit_price >= 0)
43
+ );
44
+
45
+ CREATE TABLE IF NOT EXISTS reviews (
46
+ id INTEGER PRIMARY KEY,
47
+ product_id INTEGER NOT NULL REFERENCES products(id),
48
+ customer_id INTEGER NOT NULL REFERENCES customers(id),
49
+ rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
50
+ created_at TEXT NOT NULL
51
+ );
52
+
53
+ -- Indexes for common join/filter patterns
54
+ CREATE INDEX IF NOT EXISTS idx_products_category ON products(category_id);
55
+ CREATE INDEX IF NOT EXISTS idx_orders_customer ON orders(customer_id);
56
+ CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status);
57
+ CREATE INDEX IF NOT EXISTS idx_orders_created ON orders(created_at);
58
+ CREATE INDEX IF NOT EXISTS idx_order_items_order ON order_items(order_id);
59
+ CREATE INDEX IF NOT EXISTS idx_order_items_product ON order_items(product_id);
60
+ CREATE INDEX IF NOT EXISTS idx_reviews_product ON reviews(product_id);
61
+ CREATE INDEX IF NOT EXISTS idx_customers_country ON customers(country);
62
+ CREATE INDEX IF NOT EXISTS idx_customers_tier ON customers(tier);
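For a quick local sanity check, the customers table from the schema above can be created in an in-memory SQLite database and probed with the smoke test's sample query. Only one table and two made-up rows are reproduced here for brevity:

```python
import sqlite3

# DDL copied from the customers table in schema.sql
CUSTOMERS_DDL = """
CREATE TABLE IF NOT EXISTS customers (
    id         INTEGER PRIMARY KEY,
    name       TEXT NOT NULL,
    email      TEXT NOT NULL UNIQUE,
    country    TEXT NOT NULL,
    tier       TEXT NOT NULL DEFAULT 'bronze'
               CHECK(tier IN ('bronze', 'silver', 'gold')),
    created_at TEXT NOT NULL
);
"""


def sample_rows() -> list:
    """Create the table, insert two illustrative rows, run the smoke-test query."""
    conn = sqlite3.connect(":memory:")
    conn.executescript(CUSTOMERS_DDL)
    conn.executemany(
        "INSERT INTO customers(id, name, email, country, tier, created_at) "
        "VALUES (?, ?, ?, ?, ?, ?)",
        [
            (1, "Priya Sharma", "priya.sharma1@example.com", "India", "gold", "2022-01-05"),
            (2, "James Smith", "james.smith2@example.com", "USA", "bronze", "2022-03-14"),
        ],
    )
    # Same shape as the smoke test's "/step valid SQL" probe
    return conn.execute("SELECT id, name FROM customers LIMIT 3").fetchall()
```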
server/db/seed.py ADDED
@@ -0,0 +1,225 @@
1
+ """
2
+ nl2sql-bench/server/db/seed.py
3
+ ==============================
4
+ Deterministic synthetic data generator for the NL2SQL-Bench SQLite database.
5
+
6
+ Uses a fixed random seed so every fresh environment build produces
7
+ IDENTICAL data, which is essential for reproducible grader scores across
8
+ different machines, runs, and Docker containers.
9
+
10
+ Call: seed_database(conn) once after creating tables.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import random
16
+ import sqlite3
17
+ from datetime import date, timedelta
18
+ from typing import List
19
+
20
+ # ── Deterministic seed ────────────────────────────────────────────────────
21
+ SEED = 42
22
+ RNG = random.Random(SEED)
23
+
24
+ # ── Domain constants ──────────────────────────────────────────────────────
25
+ CATEGORIES = [
26
+ "Electronics", "Clothing", "Books", "Home & Garden",
27
+ "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive",
28
+ ]
29
+
30
+ PRODUCT_NAMES = {
31
+ "Electronics": ["Wireless Headphones", "USB-C Hub", "Mechanical Keyboard",
32
+ "Webcam 4K", "Portable Charger", "Smart Speaker",
33
+ "Monitor Stand", "HDMI Cable 2.1"],
34
+ "Clothing": ["Cotton T-Shirt", "Slim Fit Jeans", "Hoodie",
35
+ "Running Shorts", "Winter Jacket", "Polo Shirt",
36
+ "Casual Sneakers", "Wool Socks"],
37
+ "Books": ["Clean Code", "Designing Data-Intensive Applications",
38
+ "The Pragmatic Programmer", "System Design Interview",
39
+ "Deep Learning Book", "Python Cookbook",
40
+ "Domain-Driven Design", "Refactoring"],
41
+ "Home & Garden": ["Coffee Maker", "Air Purifier", "LED Desk Lamp",
42
+ "Plant Pot Set", "Storage Organiser", "Cutting Board",
43
+ "Vacuum Cleaner", "Electric Kettle"],
44
+ "Sports & Outdoors":["Yoga Mat", "Resistance Bands", "Cycling Gloves",
45
+ "Trekking Poles", "Water Bottle 1L", "Jump Rope",
46
+ "Foam Roller", "Compression Socks"],
47
+ "Toys & Games": ["Lego City Set", "Card Game Pack", "Puzzle 1000pc",
48
+ "Remote Control Car", "Building Blocks",
49
+ "Board Game Strategy", "Art Set", "Toy Drone"],
50
+ "Beauty": ["Face Serum", "SPF 50 Sunscreen", "Lip Balm",
51
+ "Shampoo Pro", "Hair Mask", "Eye Cream",
52
+ "Vitamin C Cream", "Toner Mist"],
53
+ "Automotive": ["Car Phone Mount", "Dash Cam", "Tyre Inflator",
54
+ "Car Vacuum", "Seat Cushion", "Steering Wheel Cover",
55
+ "OBD Scanner", "Jump Starter"],
56
+ }
57
+
58
+ COUNTRIES = ["India", "USA", "Germany", "UK", "Canada",
59
+ "Australia", "France", "Brazil", "Japan", "Singapore"]
60
+
61
+ TIERS = ["bronze", "silver", "gold"]
62
+ STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"]
63
+
64
+ FIRST_NAMES = [
65
+ "Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja",
66
+ "Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica",
67
+ "Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura",
68
+ "Yuki","Hana","Wei","Mei","Aiden","Zara",
69
+ ]
70
+ LAST_NAMES = [
71
+ "Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy",
72
+ "Smith","Johnson","Brown","Williams","Jones","Davis","Wilson",
73
+ "Müller","Schmidt","Schneider","Fischer","Weber",
74
+ "Martin","Bernard","Thomas","Richard","Petit",
75
+ "Garcia","Martinez","Lopez","Sanchez","Gonzalez",
76
+ ]
77
+
78
+
79
+ def _random_date(start_year: int = 2022, end_year: int = 2025) -> str:
80
+ start = date(start_year, 1, 1)
81
+ end = date(end_year, 12, 31)
82
+ delta = (end - start).days
83
+ return (start + timedelta(days=RNG.randint(0, delta))).isoformat()
84
+
85
+
86
+ def seed_database(conn: sqlite3.Connection) -> None:
87
+ """Populate the database with deterministic synthetic data."""
88
+ conn.execute("PRAGMA foreign_keys = ON")
89
+ cur = conn.cursor()
90
+
91
+ # ── Categories ────────────────────────────────────────────────────────
92
+ for i, name in enumerate(CATEGORIES, 1):
93
+ cur.execute(
94
+ "INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)",
95
+ (i, name),
96
+ )
97
+
98
+ # ── Products (8 per category → 64 total) ─────────────────────────────
99
+ pid = 1
100
+ for cat_id, (cat_name, names) in enumerate(PRODUCT_NAMES.items(), 1):
101
+ for pname in names:
102
+ price = round(RNG.uniform(5.0, 250.0), 2)
103
+ stock = RNG.randint(0, 500)
104
+ cur.execute(
105
+ "INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) "
106
+ "VALUES (?, ?, ?, ?, ?)",
107
+ (pid, pname, cat_id, price, stock),
108
+ )
109
+ pid += 1
110
+
111
+ # ── Customers (150 total) ─────────────────────────────────────────────
112
+ used_emails: set = set()
113
+ for cid in range(1, 151):
+        fname = RNG.choice(FIRST_NAMES)
+        lname = RNG.choice(LAST_NAMES)
+        name = f"{fname} {lname}"
+        email_base = f"{fname.lower()}.{lname.lower()}"
+        email = f"{email_base}{cid}@example.com"
+        # cid already makes emails unique; the suffix counter guarantees the
+        # retry loop terminates even on repeated collisions.
+        suffix = 0
+        while email in used_emails:
+            suffix += 1
+            email = f"{email_base}{cid}x{suffix}@example.com"
+        used_emails.add(email)
+
+        # Bias: 60% bronze, 30% silver, 10% gold
+        tier = RNG.choices(TIERS, weights=[60, 30, 10])[0]
+        country = RNG.choice(COUNTRIES)
+        created = _random_date(2021, 2023)
+        cur.execute(
+            "INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) "
+            "VALUES (?, ?, ?, ?, ?, ?)",
+            (cid, name, email, country, tier, created),
+        )
+
+    # ── Orders + Order items ──────────────────────────────────────────────
+    oid = 1
+    item_id = 1
+    for cid in range(1, 151):
+        # Each customer has 0–8 orders; gold customers tend to have more
+        tier_row = cur.execute(
+            "SELECT tier FROM customers WHERE id=?", (cid,)
+        ).fetchone()
+        tier = tier_row[0] if tier_row else "bronze"
+        n_orders = RNG.choices(
+            range(9),
+            weights=[5, 20, 20, 15, 15, 10, 8, 5, 2] if tier == "bronze"
+            else ([2, 10, 15, 20, 20, 15, 10, 5, 3] if tier == "silver"
+                  else [1, 5, 10, 15, 20, 20, 15, 10, 4]),
+        )[0]
+
+        for _ in range(n_orders):
+            status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0]
+            order_date = _random_date(2022, 2025)
+            # Pick 1–4 products for this order
+            n_items = RNG.randint(1, 4)
+            chosen_pids = RNG.sample(range(1, 65), k=min(n_items, 64))
+            total = 0.0
+
+            cur.execute(
+                "INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) "
+                "VALUES (?, ?, ?, ?, ?)",
+                (oid, cid, status, order_date, 0.0),  # update total after items
+            )
+
+            for cpid in chosen_pids:
+                qty = RNG.randint(1, 5)
+                price_row = cur.execute(
+                    "SELECT price FROM products WHERE id=?", (cpid,)
+                ).fetchone()
+                unit_price = price_row[0] if price_row else 10.0
+                total += round(qty * unit_price, 2)
+                cur.execute(
+                    "INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) "
+                    "VALUES (?, ?, ?, ?, ?)",
+                    (item_id, oid, cpid, qty, unit_price),
+                )
+                item_id += 1
+
+            cur.execute(
+                "UPDATE orders SET total_amount=? WHERE id=?",
+                (round(total, 2), oid),
+            )
+            oid += 1
+
+    # ── Reviews ───────────────────────────────────────────────────────────
+    # Each customer reviews 0–6 products they (may have) ordered
+    rev_id = 1
+    reviewed: set = set()  # (customer_id, product_id) pairs
+    for cid in range(1, 151):
+        n_reviews = RNG.randint(0, 6)
+        for _ in range(n_reviews):
+            rpid = RNG.randint(1, 64)
+            if (cid, rpid) in reviewed:
+                continue
+            reviewed.add((cid, rpid))
+            rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0]
+            rev_date = _random_date(2022, 2025)
+            cur.execute(
+                "INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) "
+                "VALUES (?, ?, ?, ?, ?)",
+                (rev_id, rpid, cid, rating, rev_date),
+            )
+            rev_id += 1
+
+    conn.commit()
+
+
+def get_db_summary(conn: sqlite3.Connection) -> dict:
+    """Return row counts per table for debugging / README stats."""
+    tables = ["categories", "products", "customers", "orders", "order_items", "reviews"]
+    summary = {}
+    for t in tables:
+        row = conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()
+        summary[t] = row[0] if row else 0
+    return summary
+
+
+if __name__ == "__main__":
+    import os
+    schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
+    conn = sqlite3.connect(":memory:")
+    conn.row_factory = sqlite3.Row
+    with open(schema_path) as f:
+        conn.executescript(f.read())
+    seed_database(conn)
+    print("Seed stats:", get_db_summary(conn))
+    conn.close()
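The customer and tier sampling above stays reproducible only because every draw goes through a single seeded `random.Random` instance. A minimal sketch of that property (the seed values here are illustrative, not the one defined earlier in this file):

```python
import random

TIERS = ["bronze", "silver", "gold"]  # same labels as the seed data

def draw_tiers(seed: int, n: int = 10) -> list:
    # Biased draw mirroring seed_database: 60% bronze, 30% silver, 10% gold
    rng = random.Random(seed)
    return [rng.choices(TIERS, weights=[60, 30, 10])[0] for _ in range(n)]

# Two generators built from the same seed agree draw-for-draw, so
# re-seeding the in-memory DB always produces identical rows.
assert draw_tiers(42) == draw_tiers(42)
```

Because the generator is owned by the module rather than the global `random` state, other libraries calling `random.random()` cannot perturb the seed sequence.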
server/environment.py ADDED
@@ -0,0 +1,223 @@
+"""
+nl2sql-bench/server/environment.py
+====================================
+NL2SQL-Bench core environment — implements the OpenEnv Environment interface.
+
+Episode flow
+------------
+1. reset(task_name?)  → picks a task + question, returns initial observation
+2. step(action)       → executes the SQL, grades it, returns observation + reward
+3. state (property)   → returns episode metadata
+4. Episode ends when: exact_match OR step count reaches max_steps
+
+The environment manages its own SQLite connection (in-memory, seeded
+deterministically). One connection per Environment instance; the FastAPI
+server creates one Environment per WebSocket session.
+"""
+
+from __future__ import annotations
+
+import os
+import sqlite3
+import uuid
+from pathlib import Path
+from typing import Optional
+
+from openenv.core.env_server import Environment
+
+# Import after openenv so the path is correct regardless of working directory
+_HERE = Path(__file__).parent
+
+# Lazy import of task registry (avoids circular imports)
+from tasks import get_task, all_task_names, BaseTask
+from tasks.base import TaskExample
+from grader import (
+    GradeResult,
+    compute_ground_truth,
+    execute_query,
+    grade,
+    has_order_by,
+)
+
+# We import our models from one level up (models.py at project root)
+import sys
+sys.path.insert(0, str(_HERE.parent))
+from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
+
+# ── Constants ──────────────────────────────────────────────────────────────
+DEFAULT_TASK = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
+MAX_STEPS = int(os.getenv("NL2SQL_MAX_STEPS", "5"))
+RESULT_LIMIT = 10  # Max rows shown to agent per step
+
+
+class NL2SQLEnvironment(Environment):
+    """
+    OpenEnv-compliant environment for NL-to-SQL query generation.
+
+    One instance per WebSocket session (created by create_fastapi_app).
+    """
+
+    def __init__(self) -> None:
+        self._conn: Optional[sqlite3.Connection] = None
+        self._task: Optional[BaseTask] = None
+        self._example: Optional[TaskExample] = None
+        self._ground_truth: list = []
+        self._order_sensitive: bool = False
+        self._state = NL2SQLState(
+            episode_id=None,
+            step_count=0,
+            task_name="",
+            task_difficulty="",
+            question="",
+            best_reward=0.0,
+            cumulative_reward=0.0,
+            solved=False,
+        )
+        self._last_obs = NL2SQLObservation(
+            question="",
+            schema_context="",
+            task_name="",
+            last_query="",
+            last_result=[],
+            last_error=None,
+            result_columns=[],
+            step=0,
+            max_steps=MAX_STEPS,
+            done=False,
+            reward=None,
+            score=0.0,
+        )
+        self._episode_rewards: list = []
+        self._setup_db()
+
+    # ── DB lifecycle ───────────────────────────────────────────────────────
+
+    def _setup_db(self) -> None:
+        """Create in-memory SQLite DB and seed it."""
+        schema_path = _HERE / "db" / "schema.sql"
+        from db.seed import seed_database  # local import after sys.path setup
+        conn = sqlite3.connect(":memory:", check_same_thread=False)
+        conn.row_factory = sqlite3.Row
+        conn.execute("PRAGMA foreign_keys = ON")
+        conn.executescript(schema_path.read_text())
+        seed_database(conn)
+        self._conn = conn
+
+    # ── OpenEnv interface ──────────────────────────────────────────────────
+
+    def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation:
+        """
+        Start a new episode.
+
+        task_name: one of 'simple-filter', 'join-aggregation', 'analytics-window'.
+        Defaults to the NL2SQL_DEFAULT_TASK env var or 'simple-filter'.
+        """
+        task_name = task_name or DEFAULT_TASK
+        if task_name not in all_task_names():
+            task_name = DEFAULT_TASK
+
+        self._task = get_task(task_name)
+        self._example = self._task.next_example()
+        self._order_sensitive = has_order_by(self._example.sql)
+
+        # Pre-compute ground truth once per episode
+        self._ground_truth = compute_ground_truth(self._conn, self._example.sql)
+
+        self._episode_rewards = []
+        self._state = NL2SQLState(
+            episode_id=str(uuid.uuid4()),
+            step_count=0,
+            task_name=self._task.name,
+            task_difficulty=self._task.difficulty,
+            question=self._example.question,
+            best_reward=0.0,
+            cumulative_reward=0.0,
+            solved=False,
+        )
+
+        obs = NL2SQLObservation(
+            question=self._example.question,
+            schema_context=self._task.schema_context(),
+            task_name=self._task.name,
+            last_query="",
+            last_result=[],
+            last_error=None,
+            result_columns=[],
+            step=0,
+            max_steps=MAX_STEPS,
+            done=False,
+            reward=None,
+            score=0.0,
+        )
+        self._last_obs = obs
+        return obs
+
+    def step(self, action: NL2SQLAction) -> NL2SQLObservation:
+        """Execute the agent's SQL and return a graded observation."""
+        if self._task is None or self._example is None:
+            # Called before reset — auto-reset
+            self.reset()
+
+        self._state.step_count += 1
+        current_step = self._state.step_count
+        done = False
+
+        # Execute the query
+        rows, error = execute_query(self._conn, action.query)
+
+        # Grade it
+        result: GradeResult = grade(
+            actual_rows=rows,
+            ground_truth_rows=self._ground_truth,
+            error=error,
+            step=current_step,
+            order_sensitive=self._order_sensitive,
+        )
+
+        reward = result.reward
+        self._episode_rewards.append(reward)
+        self._state.cumulative_reward += reward
+        self._state.best_reward = max(self._state.best_reward, reward)
+
+        if result.exact_match:
+            self._state.solved = True
+            done = True
+        elif current_step >= MAX_STEPS:
+            done = True
+
+        # Prepare result rows for observation (truncated for agent readability)
+        display_rows = (rows or [])[:RESULT_LIMIT]
+        result_columns = list(display_rows[0].keys()) if display_rows else []
+        # Convert sqlite3.Row objects if needed
+        display_rows = [dict(r) for r in display_rows]
+
+        # Normalised cumulative score
+        n = len(self._episode_rewards)
+        score = self._state.cumulative_reward / n if n else 0.0
+        score = round(min(max(score, 0.0), 1.0), 4)
+
+        obs = NL2SQLObservation(
+            question=self._example.question,
+            schema_context=self._task.schema_context(),
+            task_name=self._task.name,
+            last_query=action.query,
+            last_result=display_rows,
+            last_error=error,
+            result_columns=result_columns,
+            step=current_step,
+            max_steps=MAX_STEPS,
+            done=done,
+            reward=reward,
+            score=score,
+        )
+        self._last_obs = obs
+        return obs
+
+    @property
+    def state(self) -> NL2SQLState:
+        return self._state
+
+    # ── Helpers ────────────────────────────────────────────────────────────
+
+    def available_tasks(self) -> list:
+        return all_task_names()
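The `_setup_db` pattern (an in-memory connection shared across server threads, a `sqlite3.Row` row factory, schema loaded via `executescript`) can be exercised in isolation; the toy schema below is illustrative, not the real `schema.sql`:

```python
import sqlite3

conn = sqlite3.connect(":memory:", check_same_thread=False)
conn.row_factory = sqlite3.Row  # rows become addressable by column name
conn.execute("PRAGMA foreign_keys = ON")
conn.executescript("""
CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, tier TEXT);
INSERT INTO customers(name, tier) VALUES ('Ada', 'gold'), ('Grace', 'silver');
""")

row = conn.execute("SELECT name, tier FROM customers WHERE id = 1").fetchone()
assert row["name"] == "Ada"  # name-based access via sqlite3.Row
assert dict(row) == {"name": "Ada", "tier": "gold"}
```

`check_same_thread=False` matters because the FastAPI server may touch the connection from worker threads other than the one that created it; `:memory:` keeps each session's database fully isolated.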
server/grader.py ADDED
@@ -0,0 +1,214 @@
+"""
+nl2sql-bench/server/grader.py
+==============================
+Deterministic, programmatic reward grader.
+
+No LLM-as-judge. Every reward is computed by comparing the agent's SQL
+execution results against a ground-truth result set.
+
+Reward decomposition (sums to 1.0 for a perfect first-attempt answer):
+  +0.10 syntax_ok       — query runs without SQLite error
+  +0.20 columns_match   — returned column names match ground truth exactly
+  +0.20 row_count_match — number of returned rows matches
+  +0.50 exact_match     — full result set equals ground truth (order-aware
+                          for ORDER BY queries, order-agnostic otherwise)
+
+Step penalty:
+  -0.05 per step beyond the first (encourages solving in fewer steps),
+  clamped so the minimum is always 0.0.
+
+All rewards are floats in [0.0, 1.0].
+"""
+
+from __future__ import annotations
+
+import sqlite3
+from typing import Any, Dict, List, Optional, Tuple
+
+
+# ── Result normalisation ───────────────────────────────────────────────────
+
+def _normalise_value(v: Any) -> Any:
+    """Round floats for comparison so 1.2300000001 == 1.23."""
+    if isinstance(v, float):
+        return round(v, 4)
+    if isinstance(v, str):
+        return v.strip()
+    return v
+
+
+def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]:
+    return {k: _normalise_value(v) for k, v in row.items()}
+
+
+def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+    return [_normalise_row(r) for r in rows]
+
+
+# ── SQL execution ──────────────────────────────────────────────────────────
+
+def execute_query(
+    conn: sqlite3.Connection,
+    query: str,
+    max_rows: int = 200,
+) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
+    """
+    Execute a SQL query safely.
+
+    Returns (rows, error_string).
+    rows is None on error.
+    """
+    query = query.strip().rstrip(";")
+    if not query:
+        return None, "Empty query."
+
+    # Block write operations — the environment is read-only from the agent's view.
+    forbidden = ("insert", "update", "delete", "drop", "alter",
+                 "create", "replace", "truncate", "pragma")
+    first_word = query.split()[0].lower() if query.split() else ""
+    if first_word in forbidden:
+        return None, (
+            f"Operation '{first_word.upper()}' is not allowed. "
+            "Only SELECT queries are permitted."
+        )
+
+    try:
+        cur = conn.execute(query)
+        cols = [d[0] for d in cur.description] if cur.description else []
+        rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)]
+        return rows, None
+    except sqlite3.Error as exc:
+        return None, str(exc)
+
+
+# ── Grading logic ──────────────────────────────────────────────────────────
+
+class GradeResult:
+    __slots__ = (
+        "reward", "syntax_ok", "columns_match",
+        "row_count_match", "exact_match", "step_penalty",
+        "breakdown",
+    )
+
+    def __init__(
+        self,
+        reward: float,
+        syntax_ok: bool,
+        columns_match: bool,
+        row_count_match: bool,
+        exact_match: bool,
+        step_penalty: float,
+    ) -> None:
+        self.reward = reward
+        self.syntax_ok = syntax_ok
+        self.columns_match = columns_match
+        self.row_count_match = row_count_match
+        self.exact_match = exact_match
+        self.step_penalty = step_penalty
+        self.breakdown = {
+            "syntax_ok": 0.10 if syntax_ok else 0.0,
+            "columns_match": 0.20 if (syntax_ok and columns_match) else 0.0,
+            "row_count_match": 0.20 if (syntax_ok and row_count_match) else 0.0,
+            "exact_match": 0.50 if (syntax_ok and exact_match) else 0.0,
+            "step_penalty": -step_penalty,
+        }
+
+    def __repr__(self) -> str:  # pragma: no cover
+        return (
+            f"GradeResult(reward={self.reward:.3f}, "
+            f"exact={self.exact_match}, cols={self.columns_match}, "
+            f"rows={self.row_count_match}, syntax={self.syntax_ok})"
+        )
+
+
+def grade(
+    actual_rows: Optional[List[Dict[str, Any]]],
+    ground_truth_rows: List[Dict[str, Any]],
+    error: Optional[str],
+    step: int,
+    order_sensitive: bool = False,
+) -> GradeResult:
+    """
+    Grade the agent's query result against ground truth.
+
+    Parameters
+    ----------
+    actual_rows       : Rows returned by the agent's query (None on error).
+    ground_truth_rows : Expected rows (pre-computed at task load time).
+    error             : SQLite error string (None if query ran successfully).
+    step              : Current step number (1-indexed) for penalty calculation.
+    order_sensitive   : If True, row order matters (queries with ORDER BY).
+    """
+    # ── Syntax ──────────────────────────────────────────────────────────
+    syntax_ok = error is None and actual_rows is not None
+
+    if not syntax_ok:
+        return GradeResult(
+            reward=0.0,
+            syntax_ok=False,
+            columns_match=False,
+            row_count_match=False,
+            exact_match=False,
+            step_penalty=0.0,
+        )
+
+    gt_norm = _normalise_rows(ground_truth_rows)
+    act_norm = _normalise_rows(actual_rows)
+
+    gt_cols = set(gt_norm[0].keys()) if gt_norm else set()
+    act_cols = set(act_norm[0].keys()) if act_norm else set()
+    columns_match = act_cols == gt_cols
+    row_count_match = len(act_norm) == len(gt_norm)
+
+    # Exact match: if order matters, compare lists; otherwise compare sorted sets
+    if columns_match and row_count_match:
+        if order_sensitive:
+            exact_match = act_norm == gt_norm
+        else:
+            # Sort rows by their string representation for order-agnostic compare
+            def _sort_key(r: Dict) -> str:
+                return str(sorted(r.items()))
+            exact_match = (
+                sorted(act_norm, key=_sort_key) == sorted(gt_norm, key=_sort_key)
+            )
+    else:
+        exact_match = False
+
+    # ── Score assembly ────────────────────────────────────────────────
+    raw = (
+        0.10  # syntax
+        + (0.20 if columns_match else 0.0)
+        + (0.20 if row_count_match else 0.0)
+        + (0.50 if exact_match else 0.0)
+    )
+
+    penalty = max(0.0, step - 1) * 0.05
+    reward = float(max(0.0, min(1.0, raw - penalty)))
+
+    return GradeResult(
+        reward=reward,
+        syntax_ok=syntax_ok,
+        columns_match=columns_match,
+        row_count_match=row_count_match,
+        exact_match=exact_match,
+        step_penalty=penalty,
+    )
+
+
+# ── Convenience: pre-compute ground truth rows ─────────────────────────────
+
+def compute_ground_truth(
+    conn: sqlite3.Connection,
+    sql: str,
+) -> List[Dict[str, Any]]:
+    """Execute the ground-truth SQL and return normalised rows."""
+    rows, error = execute_query(conn, sql)
+    if error or rows is None:
+        raise ValueError(f"Ground-truth SQL failed: {error}\nSQL: {sql}")
+    return _normalise_rows(rows)
+
+
+def has_order_by(sql: str) -> bool:
+    """Heuristic: does the top-level query have an ORDER BY?"""
+    # Simple check sufficient for our controlled task SQL
+    return "ORDER BY" in sql.upper()
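The reward decomposition documented above can be sanity-checked with a few lines of arithmetic. This is a standalone restatement of `grade()`'s score assembly (not an import of the module):

```python
def reward(syntax_ok: bool, columns_match: bool, row_count_match: bool,
           exact_match: bool, step: int) -> float:
    """Restates the grader's scoring rule: partial credit plus step penalty."""
    if not syntax_ok:
        return 0.0
    raw = (0.10
           + (0.20 if columns_match else 0.0)
           + (0.20 if row_count_match else 0.0)
           + (0.50 if exact_match else 0.0))
    penalty = max(0, step - 1) * 0.05
    return max(0.0, min(1.0, raw - penalty))

assert reward(True, True, True, True, step=1) == 1.0     # perfect first try
assert round(reward(True, True, True, True, step=3), 2) == 0.90
assert reward(True, False, False, False, step=1) == 0.10  # syntax credit only
assert reward(False, True, True, True, step=1) == 0.0     # error zeroes everything
```

Note that the per-step penalty applies to every graded attempt, so an agent that finally matches exactly on step 5 earns 0.80 rather than 1.0.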
server/requirements.txt ADDED
@@ -0,0 +1,13 @@
+# nl2sql-bench/server/requirements.txt
+# Minimal dependency set for 2 vCPU / 8 GB constraint.
+# SQLite is part of the Python stdlib — no extra DB dependency needed.
+
+# OpenEnv core framework
+openenv-core>=0.2.3
+
+# Web server
+fastapi>=0.110.0
+uvicorn[standard]>=0.29.0
+
+# Typing helpers (already included in openenv-core but listed explicitly)
+pydantic>=2.0.0
server/tasks/__init__.py ADDED
@@ -0,0 +1,14 @@
+"""Auto-registers all tasks by importing them."""
+from .easy import SimpleFilterTask
+from .medium import JoinAggregationTask
+from .hard import AnalyticsWindowTask
+from .base import get_task, all_task_names, BaseTask
+
+__all__ = [
+    "SimpleFilterTask",
+    "JoinAggregationTask",
+    "AnalyticsWindowTask",
+    "get_task",
+    "all_task_names",
+    "BaseTask",
+]
server/tasks/base.py ADDED
@@ -0,0 +1,108 @@
+"""
+nl2sql-bench/server/tasks/base.py
+==================================
+Abstract base for all NL2SQL tasks and the global task registry.
+
+Each task holds a list of (question, ground_truth_sql) pairs.
+The environment picks one pair per episode via a deterministic round-robin
+so that the same task always cycles through the same question sequence —
+this keeps grader results reproducible across runs.
+"""
+
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Dict, List, NamedTuple, Type
+
+
+class TaskExample(NamedTuple):
+    question: str
+    sql: str
+    # Human-readable description of what makes this question that difficulty
+    notes: str = ""
+
+
+class BaseTask(ABC):
+    """Abstract base class for all tasks."""
+
+    name: str = ""
+    difficulty: str = ""  # easy | medium | hard
+    examples: List[TaskExample] = []
+
+    def __init__(self) -> None:
+        if not self.examples:
+            raise ValueError(f"Task {self.name!r} has no examples defined.")
+        self._cursor = 0  # round-robin index
+
+    def next_example(self) -> TaskExample:
+        """Return the next question in round-robin order."""
+        example = self.examples[self._cursor % len(self.examples)]
+        self._cursor += 1
+        return example
+
+    @classmethod
+    def schema_context(cls) -> str:
+        """Return a compact schema description for the agent system prompt."""
+        return _SCHEMA_CONTEXT
+
+    @abstractmethod
+    def description(self) -> str:
+        """One-sentence description for openenv.yaml."""
+
+
+# ── Global schema context string (injected into every observation) ─────────
+
+_SCHEMA_CONTEXT = """\
+Database: ecommerce (SQLite, read-only)
+
+TABLES
+------
+categories(id INTEGER PK, name TEXT)
+
+products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id,
+         price REAL, stock_quantity INTEGER)
+
+customers(id INTEGER PK, name TEXT, email TEXT, country TEXT,
+          tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601)
+
+orders(id INTEGER PK, customer_id INTEGER FK→customers.id,
+       status TEXT ∈ {pending|processing|shipped|delivered|cancelled},
+       created_at TEXT ISO-8601, total_amount REAL)
+
+order_items(id INTEGER PK, order_id INTEGER FK→orders.id,
+            product_id INTEGER FK→products.id,
+            quantity INTEGER, unit_price REAL)
+
+reviews(id INTEGER PK, product_id INTEGER FK→products.id,
+        customer_id INTEGER FK→customers.id,
+        rating INTEGER 1-5, created_at TEXT ISO-8601)
+
+NOTES
+-----
+- Date comparisons: use created_at >= '2024-01-01' (text ISO sort works)
+- SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available
+- strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings
+- All monetary values are in USD
+"""
+
+
+# ── Task registry ──────────────────────────────────────────────────────────
+
+_REGISTRY: Dict[str, Type[BaseTask]] = {}
+
+
+def register(cls: Type[BaseTask]) -> Type[BaseTask]:
+    """Class decorator to register a task."""
+    _REGISTRY[cls.name] = cls
+    return cls
+
+
+def get_task(name: str) -> BaseTask:
+    if name not in _REGISTRY:
+        raise KeyError(f"Unknown task {name!r}. Available: {list(_REGISTRY)}")
+    return _REGISTRY[name]()
+
+
+def all_task_names() -> List[str]:
+    return list(_REGISTRY.keys())
server/tasks/easy.py ADDED
@@ -0,0 +1,93 @@
+"""
+nl2sql-bench/server/tasks/easy.py
+===================================
+Task 1 — Simple Filter (difficulty: easy)
+
+All questions target a SINGLE table with basic WHERE / ORDER BY / LIMIT.
+A competent small model should solve these in 1–2 steps.
+"""
+
+from __future__ import annotations
+
+from .base import BaseTask, TaskExample, register
+
+
+@register
+class SimpleFilterTask(BaseTask):
+    name = "simple-filter"
+    difficulty = "easy"
+
+    examples = [
+        TaskExample(
+            question=(
+                "List all gold-tier customers ordered by their name alphabetically. "
+                "Return columns: id, name, email, country."
+            ),
+            sql=(
+                "SELECT id, name, email, country "
+                "FROM customers "
+                "WHERE tier = 'gold' "
+                "ORDER BY name ASC"
+            ),
+            notes="Single table, equality filter, text sort.",
+        ),
+        TaskExample(
+            question=(
+                "Show all products with a price above $100, sorted by price from "
+                "highest to lowest. Return columns: id, name, price."
+            ),
+            sql=(
+                "SELECT id, name, price "
+                "FROM products "
+                "WHERE price > 100 "
+                "ORDER BY price DESC"
+            ),
+            notes="Numeric range filter, descending sort.",
+        ),
+        TaskExample(
+            question=(
+                "Find all delivered orders with a total_amount greater than $200, "
+                "ordered by total_amount descending. "
+                "Return columns: id, customer_id, total_amount, created_at."
+            ),
+            sql=(
+                "SELECT id, customer_id, total_amount, created_at "
+                "FROM orders "
+                "WHERE status = 'delivered' "
+                "  AND total_amount > 200 "
+                "ORDER BY total_amount DESC"
+            ),
+            notes="Two-condition WHERE on a single table.",
+        ),
+        TaskExample(
+            question=(
+                "Return the top 5 most expensive products. "
+                "Return columns: id, name, price."
+            ),
+            sql=(
+                "SELECT id, name, price "
+                "FROM products "
+                "ORDER BY price DESC "
+                "LIMIT 5"
+            ),
+            notes="ORDER BY + LIMIT, no WHERE clause.",
+        ),
+        TaskExample(
+            question=(
+                "List all distinct countries where our customers come from, "
+                "sorted alphabetically. Return a single column: country."
+            ),
+            sql=(
+                "SELECT DISTINCT country "
+                "FROM customers "
+                "ORDER BY country ASC"
+            ),
+            notes="DISTINCT on a single column.",
+        ),
+    ]
+
+    def description(self) -> str:
+        return (
+            "Single-table SELECT queries with WHERE filters, ORDER BY, and LIMIT. "
+            "Tests basic SQL fluency."
+        )
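Each ground-truth query in this task is plain single-table SQLite; for instance, the DISTINCT-countries example behaves as described against any small fixture (the fixture rows below are made up):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE customers(id INTEGER PRIMARY KEY, name TEXT, email TEXT,
                       country TEXT, tier TEXT, created_at TEXT);
INSERT INTO customers(name, country) VALUES
  ('A', 'USA'), ('B', 'Germany'), ('C', 'USA'), ('D', 'Brazil');
""")
rows = conn.execute(
    "SELECT DISTINCT country FROM customers ORDER BY country ASC"
).fetchall()
# Duplicates collapse and the result is alphabetically sorted.
assert [r[0] for r in rows] == ["Brazil", "Germany", "USA"]
```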
server/tasks/hard.py ADDED
@@ -0,0 +1,156 @@
+"""
+nl2sql-bench/server/tasks/hard.py
+===================================
+Task 3 — Analytics & Window (difficulty: hard)
+
+Questions require CTEs, window functions (RANK, ROW_NUMBER, running totals),
+or non-trivial subqueries. Even strong frontier models often need 3–5 steps.
+"""
+
+from __future__ import annotations
+
+from .base import BaseTask, TaskExample, register
+
+
+@register
+class AnalyticsWindowTask(BaseTask):
+    name = "analytics-window"
+    difficulty = "hard"
+
+    examples = [
+        TaskExample(
+            question=(
+                "Rank customers by their total spending on delivered orders "
+                "using DENSE_RANK (rank 1 = highest spender). "
+                "Return columns: customer_name, total_spent, spending_rank. "
+                "Round total_spent to 2 decimal places. "
+                "Sort by spending_rank ascending."
+            ),
+            sql=(
+                "SELECT customer_name, total_spent, spending_rank "
+                "FROM ( "
+                "    SELECT c.name AS customer_name, "
+                "           ROUND(SUM(o.total_amount), 2) AS total_spent, "
+                "           DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
+                "    FROM customers c "
+                "    JOIN orders o ON o.customer_id = c.id "
+                "    WHERE o.status = 'delivered' "
+                "    GROUP BY c.id, c.name "
+                ") sub "
+                "ORDER BY spending_rank ASC"
+            ),
+            notes="Window function DENSE_RANK inside a subquery wrapping a GROUP BY.",
+        ),
+        TaskExample(
+            question=(
+                "For each product that has been reviewed, show its name, its own "
+                "average rating, and the average rating of all products in its category. "
+                "Return columns: product_name, product_avg_rating, category_avg_rating. "
+                "Round both averages to 2 decimal places. "
+                "Sort by product_avg_rating descending."
+            ),
+            sql=(
+                "SELECT p.name AS product_name, "
+                "       ROUND(AVG(r.rating), 2) AS product_avg_rating, "
+                "       ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) "
+                "           AS category_avg_rating "
+                "FROM products p "
+                "JOIN reviews r ON r.product_id = p.id "
+                "GROUP BY p.id, p.name, p.category_id "
+                "ORDER BY product_avg_rating DESC"
+            ),
+            notes="AVG of AVG via window PARTITION BY — requires nested aggregate understanding.",
+        ),
+        TaskExample(
+            question=(
+                "Find all customers whose most recent order has status 'cancelled'. "
+                "Use a CTE with ROW_NUMBER to identify the latest order per customer. "
+                "Return columns: customer_name, last_order_status, last_order_date. "
+                "Sort by customer_name ascending."
+            ),
+            sql=(
+                "WITH ranked_orders AS ( "
+                "    SELECT customer_id, status, created_at, "
+                "           ROW_NUMBER() OVER (PARTITION BY customer_id "
+                "                              ORDER BY created_at DESC) AS rn "
+                "    FROM orders "
+                ") "
+                "SELECT c.name AS customer_name, "
+                "       ro.status AS last_order_status, "
+                "       ro.created_at AS last_order_date "
+                "FROM customers c "
+                "JOIN ranked_orders ro ON ro.customer_id = c.id "
+                "WHERE ro.rn = 1 "
+                "  AND ro.status = 'cancelled' "
+                "ORDER BY customer_name ASC"
+            ),
+            notes="CTE + ROW_NUMBER window partitioned by customer_id.",
+        ),
+        TaskExample(
+            question=(
+                "Show the monthly revenue from delivered orders and its running total, "
+                "for all months in 2024. "
+                "Return columns: month (format YYYY-MM), monthly_revenue, running_total. "
+                "Round both revenue columns to 2 decimal places. "
+                "Sort by month ascending."
+            ),
+            sql=(
+                "WITH monthly AS ( "
+                "    SELECT strftime('%Y-%m', created_at) AS month, "
+                "           ROUND(SUM(total_amount), 2) AS monthly_revenue "
+                "    FROM orders "
+                "    WHERE status = 'delivered' "
+                "      AND created_at >= '2024-01-01' "
+                "      AND created_at < '2025-01-01' "
+                "    GROUP BY strftime('%Y-%m', created_at) "
+                ") "
+                "SELECT month, "
+                "       monthly_revenue, "
+                "       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
+                "FROM monthly "
+                "ORDER BY month ASC"
+            ),
+            notes="CTE + cumulative SUM window ordered by month string.",
+        ),
+        TaskExample(
+            question=(
+                "Find products whose average rating is strictly above the average "
+                "rating of all products in their category. "
+                "Return columns: product_name, category_name, "
+                "product_avg_rating, category_avg_rating. "
+                "Round both averages to 2 decimal places. "
+                "Sort by product_avg_rating descending, then product_name ascending."
+            ),
+            sql=(
+                "WITH product_ratings AS ( "
+                "    SELECT p.id AS product_id, p.name AS product_name, "
+                "           p.category_id, c.name AS category_name, "
+                "           ROUND(AVG(r.rating), 2) AS product_avg_rating "
+                "    FROM products p "
+                "    JOIN reviews r ON r.product_id = p.id "
+                "    JOIN categories c ON c.id = p.category_id "
+                "    GROUP BY p.id, p.name, p.category_id, c.name "
+                "), "
+                "category_ratings AS ( "
+                "    SELECT category_id, "
+                "           ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
+                "    FROM product_ratings "
+                "    GROUP BY category_id "
+                ") "
+                "SELECT pr.product_name, pr.category_name, "
+                "       pr.product_avg_rating, cr.category_avg_rating "
+                "FROM product_ratings pr "
+                "JOIN category_ratings cr ON cr.category_id = pr.category_id "
+                "WHERE pr.product_avg_rating > cr.category_avg_rating "
+                "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
+            ),
+            notes="Two CTEs, correlated comparison between product and category averages.",
+        ),
+    ]
+
+    def description(self) -> str:
+        return (
+            "Advanced analytics queries using CTEs, window functions "
+            "(DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. "
+            "Tests multi-step reasoning and SQLite-specific syntax."
+        )
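The running-total ground truth relies on SQLite window functions (available since SQLite 3.25, bundled with recent CPython builds); the monthly-revenue CTE plus cumulative `SUM(...) OVER` can be verified against a small made-up fixture:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE orders(id INTEGER PRIMARY KEY, status TEXT,
                    created_at TEXT, total_amount REAL);
INSERT INTO orders(status, created_at, total_amount) VALUES
  ('delivered', '2024-01-15', 100.0),
  ('delivered', '2024-01-20',  50.0),
  ('delivered', '2024-02-02',  25.0),
  ('cancelled', '2024-02-10', 999.0);  -- excluded by the status filter
""")
rows = conn.execute("""
WITH monthly AS (
  SELECT strftime('%Y-%m', created_at) AS month,
         ROUND(SUM(total_amount), 2) AS monthly_revenue
  FROM orders
  WHERE status = 'delivered'
  GROUP BY month
)
SELECT month, monthly_revenue,
       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total
FROM monthly
ORDER BY month
""").fetchall()
assert rows == [('2024-01', 150.0, 150.0), ('2024-02', 25.0, 175.0)]
```

The default window frame for `SUM(...) OVER (ORDER BY month)` ends at the current row, which is exactly what makes it a running total.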
server/tasks/medium.py ADDED
@@ -0,0 +1,117 @@
+"""
+nl2sql-bench/server/tasks/medium.py
+=====================================
+Task 2 — Join & Aggregation (difficulty: medium)
+
+Questions require at least one JOIN and GROUP BY / HAVING.
+Expect most frontier models to succeed in 2–3 steps.
+"""
+
+from __future__ import annotations
+
+from .base import BaseTask, TaskExample, register
+
+
+@register
+class JoinAggregationTask(BaseTask):
+    name = "join-aggregation"
+    difficulty = "medium"
+
+    examples = [
+        TaskExample(
+            question=(
+                "How many orders has each customer placed? "
+                "Return columns: customer_name, order_count. "
+                "Include customers with zero orders. "
+                "Sort by order_count descending, then customer_name ascending."
+            ),
+            sql=(
+                "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
+                "FROM customers c "
+                "LEFT JOIN orders o ON c.id = o.customer_id "
+                "GROUP BY c.id, c.name "
+                "ORDER BY order_count DESC, customer_name ASC"
+            ),
+            notes="LEFT JOIN to include zero-order customers, COUNT aggregate.",
+        ),
+        TaskExample(
+            question=(
+                "What is the average product rating per category? "
+                "Only include categories that have at least one review. "
+                "Return columns: category_name, avg_rating. "
+                "Round avg_rating to 2 decimal places. "
+                "Sort by avg_rating descending."
+            ),
+            sql=(
+                "SELECT c.name AS category_name, "
+                "       ROUND(AVG(r.rating), 2) AS avg_rating "
+                "FROM categories c "
+                "JOIN products p ON p.category_id = c.id "
+                "JOIN reviews r ON r.product_id = p.id "
+                "GROUP BY c.id, c.name "
+                "ORDER BY avg_rating DESC"
+            ),
+            notes="Two JOINs, AVG aggregate, ROUND function.",
+        ),
+        TaskExample(
+            question=(
+                "Which categories have more than 5 products in stock "
+                "(i.e., stock_quantity > 0)? "
+                "Return columns: category_name, in_stock_count. "
+                "Sort by in_stock_count descending."
+            ),
+            sql=(
+                "SELECT c.name AS category_name, "
+                "       COUNT(p.id) AS in_stock_count "
+                "FROM categories c "
+                "JOIN products p ON p.category_id = c.id "
+                "WHERE p.stock_quantity > 0 "
+                "GROUP BY c.id, c.name "
+                "HAVING COUNT(p.id) > 5 "
+                "ORDER BY in_stock_count DESC"
+            ),
+            notes="WHERE before GROUP BY, HAVING filter on aggregate.",
+        ),
+        TaskExample(
+            question=(
+                "Which customers have spent more than $500 total on delivered orders? "
+                "Return columns: customer_name, total_spent. "
+                "Round total_spent to 2 decimal places. "
+                "Sort by total_spent descending."
+            ),
+            sql=(
+                "SELECT c.name AS customer_name, "
+                "       ROUND(SUM(o.total_amount), 2) AS total_spent "
+                "FROM customers c "
+                "JOIN orders o ON o.customer_id = c.id "
87
+ "WHERE o.status = 'delivered' "
88
+ "GROUP BY c.id, c.name "
89
+ "HAVING SUM(o.total_amount) > 500 "
90
+ "ORDER BY total_spent DESC"
91
+ ),
92
+ notes="SUM aggregate, HAVING on SUM, status filter.",
93
+ ),
94
+ TaskExample(
95
+ question=(
96
+ "Show the total quantity sold for each product. "
97
+ "Only include products that appear in at least one order item. "
98
+ "Return columns: product_name, total_quantity_sold. "
99
+ "Sort by total_quantity_sold descending."
100
+ ),
101
+ sql=(
102
+ "SELECT p.name AS product_name, "
103
+ " SUM(oi.quantity) AS total_quantity_sold "
104
+ "FROM products p "
105
+ "JOIN order_items oi ON oi.product_id = p.id "
106
+ "GROUP BY p.id, p.name "
107
+ "ORDER BY total_quantity_sold DESC"
108
+ ),
109
+ notes="JOIN on order_items, SUM aggregate.",
110
+ ),
111
+ ]
112
+
113
+ def description(self) -> str:
114
+ return (
115
+ "Multi-table JOIN queries with GROUP BY, HAVING, and aggregation "
116
+ "functions (COUNT, SUM, AVG, ROUND). Tests relational reasoning."
117
+ )
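The WHERE-before-GROUP-BY versus HAVING distinction called out in the notes above can be demonstrated end to end on an in-memory database; the `products` table here is a minimal stand-in, not the seeded benchmark schema:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE products (id INTEGER, category TEXT, stock INTEGER);
    INSERT INTO products VALUES
        (1, 'toys', 5), (2, 'toys', 0), (3, 'toys', 2),
        (4, 'books', 1), (5, 'books', 3);
""")

# WHERE prunes rows before aggregation (the zero-stock toy is dropped);
# HAVING then filters the aggregated groups.
rows = conn.execute(
    "SELECT category, COUNT(*) AS in_stock_count "
    "FROM products "
    "WHERE stock > 0 "
    "GROUP BY category "
    "HAVING COUNT(*) >= 2"
).fetchall()
# Both categories keep exactly 2 in-stock rows, so both survive HAVING.
```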
tests/__init__.py ADDED
@@ -0,0 +1 @@
1
+ # tests/__init__.py
tests/conftest.py ADDED
@@ -0,0 +1,14 @@
1
+ # nl2sql-bench/tests/conftest.py
2
+ """
3
+ Pytest configuration — adds project root and server/ to sys.path
4
+ so all test imports resolve without installing the package.
5
+ """
6
+ import sys
7
+ from pathlib import Path
8
+
9
+ ROOT = Path(__file__).parent.parent
10
+ SERVER = ROOT / "server"
11
+
12
+ for p in [str(ROOT), str(SERVER)]:
13
+ if p not in sys.path:
14
+ sys.path.insert(0, p)
tests/test_all.py ADDED
@@ -0,0 +1,493 @@
1
+ """
2
+ nl2sql-bench/tests/test_all.py
3
+ ================================
4
+ Comprehensive test suite covering:
5
+ - Database seeder (determinism + row counts)
6
+ - Grader (all reward components, step penalty, edge cases)
7
+ - Task registry (all 3 tasks load and produce valid examples)
8
+ - Environment (reset, step, episode boundary, done logic)
9
+ - Inference log format (regex checks on START / STEP / END)
10
+
11
+ Run with:
12
+ pytest tests/ -v
13
+ or from project root:
14
+ PYTHONPATH=.:server pytest tests/ -v
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import re
20
+ import sqlite3
21
+ import sys
22
+ import os
23
+ from pathlib import Path
24
+
25
+ import pytest
26
+
27
+ # ── Path setup so tests can import from both project root and server/ ──────
28
+ ROOT = Path(__file__).parent.parent
29
+ SERVER = ROOT / "server"
30
+ sys.path.insert(0, str(ROOT))
31
+ sys.path.insert(0, str(SERVER))
32
+
33
+ # ── Fixtures ───────────────────────────────────────────────────────────────
34
+
35
+ @pytest.fixture(scope="session")
36
+ def db_conn():
37
+ """Shared in-memory SQLite connection with full schema + seed data."""
38
+ from db.seed import seed_database
39
+ schema = (SERVER / "db" / "schema.sql").read_text()
40
+ conn = sqlite3.connect(":memory:", check_same_thread=False)
41
+ conn.row_factory = sqlite3.Row
42
+ conn.executescript(schema)
43
+ seed_database(conn)
44
+ yield conn
45
+ conn.close()
46
+
47
+
48
+ @pytest.fixture
49
+ def fresh_env():
50
+ """A fresh NL2SQLEnvironment instance per test."""
51
+ from environment import NL2SQLEnvironment
52
+ return NL2SQLEnvironment()
53
+
54
+
55
+ # ══════════════════════════════════════════════════════════════════════════════
56
+ # 1. DATABASE SEEDER
57
+ # ══════════════════════════════════════════════════════════════════════════════
58
+
59
+ class TestSeeder:
60
+
61
+ def test_categories_count(self, db_conn):
62
+ row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone()
63
+ assert row[0] == 8, "Should have exactly 8 categories"
64
+
65
+ def test_products_count(self, db_conn):
66
+ row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone()
67
+ assert row[0] == 64, "Should have 8 products × 8 categories = 64"
68
+
69
+ def test_customers_count(self, db_conn):
70
+ row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone()
71
+ assert row[0] == 150
72
+
73
+ def test_orders_exist(self, db_conn):
74
+ row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone()
75
+ assert row[0] > 100, "Should have a meaningful number of orders"
76
+
77
+ def test_order_items_exist(self, db_conn):
78
+ row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone()
79
+ assert row[0] > 200
80
+
81
+ def test_reviews_exist(self, db_conn):
82
+ row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone()
83
+ assert row[0] > 50
84
+
85
+ def test_determinism(self, db_conn):
86
+ """Seeding a second connection with the same seed gives identical counts."""
87
+ from db.seed import seed_database
88
+ schema = (SERVER / "db" / "schema.sql").read_text()
89
+ conn2 = sqlite3.connect(":memory:")
90
+ conn2.executescript(schema)
91
+ seed_database(conn2)
92
+
93
+ for tbl in ["categories", "products", "customers", "orders",
94
+ "order_items", "reviews"]:
95
+ c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
96
+ c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
97
+ assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}"
98
+ conn2.close()
99
+
100
+ def test_tiers_valid(self, db_conn):
101
+ bad = db_conn.execute(
102
+ "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')"
103
+ ).fetchone()[0]
104
+ assert bad == 0
105
+
106
+ def test_statuses_valid(self, db_conn):
107
+ bad = db_conn.execute(
108
+ "SELECT COUNT(*) FROM orders "
109
+ "WHERE status NOT IN ('pending','processing','shipped','delivered','cancelled')"
110
+ ).fetchone()[0]
111
+ assert bad == 0
112
+
113
+ def test_ratings_valid(self, db_conn):
114
+ bad = db_conn.execute(
115
+ "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5"
116
+ ).fetchone()[0]
117
+ assert bad == 0
118
+
119
+ def test_referential_integrity(self, db_conn):
120
+ """Order items should reference valid orders and products."""
121
+ orphan_orders = db_conn.execute(
122
+ "SELECT COUNT(*) FROM order_items oi "
123
+ "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL"
124
+ ).fetchone()[0]
125
+ assert orphan_orders == 0
126
+
127
+ orphan_products = db_conn.execute(
128
+ "SELECT COUNT(*) FROM order_items oi "
129
+ "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL"
130
+ ).fetchone()[0]
131
+ assert orphan_products == 0
132
+
133
+
134
+ # ══════════════════════════════════════════════════════════════════════════════
135
+ # 2. GRADER
136
+ # ══════════════════════════════════════════════════════════════════════════════
137
+
138
+ class TestGrader:
139
+
140
+ def test_exact_match_first_step(self):
141
+ from grader import grade
142
+ gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
143
+ result = grade(
144
+ actual_rows=gt.copy(),
145
+ ground_truth_rows=gt,
146
+ error=None,
147
+ step=1,
148
+ order_sensitive=False,
149
+ )
150
+ assert result.reward == pytest.approx(1.0)
151
+ assert result.exact_match is True
152
+ assert result.syntax_ok is True
153
+ assert result.columns_match is True
154
+ assert result.row_count_match is True
155
+ assert result.step_penalty == 0.0
156
+
157
+ def test_syntax_error_gives_zero(self):
158
+ from grader import grade
159
+ result = grade(
160
+ actual_rows=None,
161
+ ground_truth_rows=[{"x": 1}],
162
+ error="near 'SELCT': syntax error",
163
+ step=1,
164
+ )
165
+ assert result.reward == 0.0
166
+ assert result.syntax_ok is False
167
+
168
+ def test_step_penalty_applied(self):
169
+ from grader import grade
170
+ gt = [{"n": 1}]
171
+ result = grade(
172
+ actual_rows=gt.copy(),
173
+ ground_truth_rows=gt,
174
+ error=None,
175
+ step=3, # penalty = (3-1)*0.05 = 0.10
176
+ )
177
+ assert result.reward == pytest.approx(1.0 - 0.10)
178
+ assert result.step_penalty == pytest.approx(0.10)
179
+
180
+ def test_columns_wrong_zero_higher_components(self):
181
+ from grader import grade
182
+ gt = [{"name": "Alice", "score": 10}]
183
+ actual = [{"user": "Alice", "points": 10}] # wrong column names
184
+ result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
185
+ assert result.columns_match is False
186
+ assert result.exact_match is False
187
+ # Only syntax score: 0.10
188
+ assert result.reward == pytest.approx(0.10)
189
+
190
+ def test_correct_columns_wrong_rows(self):
191
+ from grader import grade
192
+ gt = [{"name": "Alice"}, {"name": "Bob"}]
193
+ actual = [{"name": "Charlie"}, {"name": "Dave"}]
194
+ result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
195
+ assert result.columns_match is True
196
+ assert result.row_count_match is True
197
+ assert result.exact_match is False
198
+ # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50
199
+ assert result.reward == pytest.approx(0.50)
200
+
201
+ def test_order_sensitive_wrong_order_is_not_exact(self):
202
+ from grader import grade
203
+ gt = [{"id": 1}, {"id": 2}]
204
+ actual = [{"id": 2}, {"id": 1}] # reversed
205
+ result = grade(
206
+ actual_rows=actual,
207
+ ground_truth_rows=gt,
208
+ error=None,
209
+ step=1,
210
+ order_sensitive=True,
211
+ )
212
+ assert result.exact_match is False
213
+
214
+ def test_order_insensitive_accepts_different_row_order(self):
215
+ from grader import grade
216
+ gt = [{"id": 1}, {"id": 2}]
217
+ actual = [{"id": 2}, {"id": 1}] # different order but same content
218
+ result = grade(
219
+ actual_rows=actual,
220
+ ground_truth_rows=gt,
221
+ error=None,
222
+ step=1,
223
+ order_sensitive=False,
224
+ )
225
+ assert result.exact_match is True
226
+
227
+ def test_penalty_never_makes_reward_negative(self):
228
+ from grader import grade
229
+ # Step 99 with syntax error → reward must be >= 0
230
+ result = grade(
231
+ actual_rows=None,
232
+ ground_truth_rows=[{"x": 1}],
233
+ error="some error",
234
+ step=99,
235
+ )
236
+ assert result.reward >= 0.0
237
+
238
+ def test_execute_query_blocks_writes(self, db_conn):
239
+ from grader import execute_query
240
+ rows, err = execute_query(db_conn, "INSERT INTO categories(name) VALUES ('x')")
241
+ assert rows is None
242
+ assert "not allowed" in err.lower() or "INSERT" in err
243
+
244
+ def test_execute_query_returns_rows(self, db_conn):
245
+ from grader import execute_query
246
+ rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id")
247
+ assert err is None
248
+ assert len(rows) == 8
249
+ assert "id" in rows[0]
250
+ assert "name" in rows[0]
251
+
252
+ def test_compute_ground_truth(self, db_conn):
253
+ from grader import compute_ground_truth
254
+ rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers")
255
+ assert len(rows) == 1
256
+ assert rows[0]["n"] == 150
257
+
258
+
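The assertions above pin down the grader's scoring scheme: 0.10 for clean syntax, +0.20 for matching columns, +0.20 for matching row count, +0.50 for an exact match, minus 0.05 per step after the first, floored at zero, with each component gated on the previous one (wrong columns zero out the higher components). A standalone sketch of that formula, reconstructed from these tests rather than copied from grader.py:

```python
def sketch_reward(syntax_ok: bool, columns_match: bool,
                  row_count_match: bool, exact_match: bool,
                  step: int) -> float:
    """Reward formula implied by the TestGrader assertions; illustrative only."""
    score = 0.0
    if syntax_ok:
        score += 0.10
        if columns_match:          # higher components gated on column match
            score += 0.20
            if row_count_match:
                score += 0.20
                if exact_match:
                    score += 0.50
    penalty = (step - 1) * 0.05    # first step is penalty-free
    return max(0.0, score - penalty)
```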
259
+ # ══════════════════════════════════════════════════════════════════════════════
260
+ # 3. TASK REGISTRY
261
+ # ══════════════════════════════════════════════════════════════════════════════
262
+
263
+ class TestTasks:
264
+
265
+ def test_all_tasks_registered(self):
266
+ from tasks import all_task_names
267
+ names = all_task_names()
268
+ assert "simple-filter" in names
269
+ assert "join-aggregation" in names
270
+ assert "analytics-window" in names
271
+
272
+ @pytest.mark.parametrize("task_name", [
273
+ "simple-filter", "join-aggregation", "analytics-window"
274
+ ])
275
+ def test_task_has_examples(self, task_name):
276
+ from tasks import get_task
277
+ task = get_task(task_name)
278
+ assert len(task.examples) >= 3, f"{task_name} needs at least 3 examples"
279
+
280
+ @pytest.mark.parametrize("task_name", [
281
+ "simple-filter", "join-aggregation", "analytics-window"
282
+ ])
283
+ def test_task_sql_runs_on_real_db(self, task_name, db_conn):
284
+ """Every ground-truth SQL must execute cleanly against the seeded DB."""
285
+ from tasks import get_task
286
+ from grader import execute_query
287
+ task = get_task(task_name)
288
+ for ex in task.examples:
289
+ rows, error = execute_query(db_conn, ex.sql)
290
+ assert error is None, (
291
+ f"Task {task_name!r} SQL failed:\n{ex.sql}\nError: {error}"
292
+ )
293
+ assert rows is not None
294
+
295
+ @pytest.mark.parametrize("task_name", [
296
+ "simple-filter", "join-aggregation", "analytics-window"
297
+ ])
298
+ def test_task_roundrobin(self, task_name):
299
+ from tasks import get_task
300
+ task = get_task(task_name)
301
+ n = len(task.examples)
302
+ seen = [task.next_example() for _ in range(n * 2)]
303
+ # After n calls, second half should repeat first half
304
+ assert seen[:n] == seen[n:]
305
+
306
+ def test_schema_context_non_empty(self):
307
+ from tasks import get_task
308
+ task = get_task("simple-filter")
309
+ ctx = task.schema_context()
310
+ assert "customers" in ctx
311
+ assert "orders" in ctx
312
+ assert "products" in ctx
313
+
314
+
315
+ # ══════════════════════════════════════════════════════════════════════════════
316
+ # 4. ENVIRONMENT
317
+ # ══════════════════════════════════════════════════════════════════════════════
318
+
319
+ class TestEnvironment:
320
+
321
+ def test_reset_returns_observation(self, fresh_env):
322
+ obs = fresh_env.reset(task_name="simple-filter")
323
+ assert obs.question != ""
324
+ assert obs.schema_context != ""
325
+ assert obs.task_name == "simple-filter"
326
+ assert obs.done is False
327
+ assert obs.step == 0
328
+ assert obs.reward is None
329
+
330
+ def test_reset_state(self, fresh_env):
331
+ fresh_env.reset(task_name="join-aggregation")
332
+ state = fresh_env.state
333
+ assert state.task_name == "join-aggregation"
334
+ assert state.task_difficulty == "medium"
335
+ assert state.step_count == 0
336
+ assert state.solved is False
337
+
338
+ def test_step_increments_step_count(self, fresh_env):
339
+ from models import NL2SQLAction
340
+ fresh_env.reset(task_name="simple-filter")
341
+ fresh_env.step(NL2SQLAction(query="SELECT 1"))
342
+ assert fresh_env.state.step_count == 1
343
+
344
+ def test_step_syntax_error_gives_nonzero_error(self, fresh_env):
345
+ from models import NL2SQLAction
346
+ fresh_env.reset(task_name="simple-filter")
347
+ obs = fresh_env.step(NL2SQLAction(query="SELCT * FORM broken"))
348
+ assert obs.last_error is not None
349
+ assert obs.reward == 0.0
350
+
351
+ def test_step_valid_query_returns_result(self, fresh_env):
352
+ from models import NL2SQLAction
353
+ fresh_env.reset(task_name="simple-filter")
354
+ obs = fresh_env.step(NL2SQLAction(
355
+ query="SELECT id, name FROM customers ORDER BY name LIMIT 5"
356
+ ))
357
+ assert obs.last_error is None
358
+ assert len(obs.last_result) <= 5
359
+ assert obs.reward >= 0.0
360
+
361
+ def test_exact_match_ends_episode(self, fresh_env):
362
+ """Submitting the exact ground-truth SQL should solve the episode."""
363
+ from models import NL2SQLAction
364
+ fresh_env.reset(task_name="simple-filter")
365
+ # Get the ground truth SQL from the internal example
366
+ gt_sql = fresh_env._example.sql
367
+ obs = fresh_env.step(NL2SQLAction(query=gt_sql))
368
+ assert obs.done is True
369
+ assert fresh_env.state.solved is True
370
+ assert obs.reward == pytest.approx(1.0) # step 1, full score
371
+
372
+ def test_max_steps_ends_episode(self, fresh_env):
373
+ """Exhausting all steps should end the episode even without solving."""
374
+ from models import NL2SQLAction
375
+ from environment import MAX_STEPS
376
+ fresh_env.reset(task_name="analytics-window")
377
+ obs = None
378
+ for _ in range(MAX_STEPS):
379
+ obs = fresh_env.step(NL2SQLAction(query="SELECT 1"))
380
+ assert obs is not None
381
+ assert obs.done is True
382
+
383
+ def test_reset_clears_previous_episode(self, fresh_env):
384
+ from models import NL2SQLAction
385
+ fresh_env.reset(task_name="simple-filter")
386
+ fresh_env.step(NL2SQLAction(query="SELECT 1"))
387
+ # Second reset should clear state
388
+ obs = fresh_env.reset(task_name="join-aggregation")
389
+ assert fresh_env.state.step_count == 0
390
+ assert obs.step == 0
391
+ assert obs.task_name == "join-aggregation"
392
+
393
+ @pytest.mark.parametrize("task_name", [
394
+ "simple-filter", "join-aggregation", "analytics-window"
395
+ ])
396
+ def test_all_tasks_solvable(self, task_name):
397
+ """Ground-truth SQL should always produce reward == 1.0 on step 1."""
398
+ from environment import NL2SQLEnvironment
399
+ from models import NL2SQLAction
400
+ env = NL2SQLEnvironment()
401
+ env.reset(task_name=task_name)
402
+ gt_sql = env._example.sql
403
+ obs = env.step(NL2SQLAction(query=gt_sql))
404
+ assert obs.done is True
405
+ assert obs.reward == pytest.approx(1.0), (
406
+ f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n"
407
+ f"SQL: {gt_sql}\nError: {obs.last_error}\nReward: {obs.reward}"
408
+ )
409
+
410
+ def test_score_normalised_to_0_1(self, fresh_env):
411
+ from models import NL2SQLAction
412
+ fresh_env.reset(task_name="simple-filter")
413
+ for _ in range(3):
414
+ obs = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x"))
415
+ assert 0.0 <= obs.score <= 1.0
416
+
417
+ def test_write_query_blocked(self, fresh_env):
418
+ from models import NL2SQLAction
419
+ fresh_env.reset(task_name="simple-filter")
420
+ obs = fresh_env.step(NL2SQLAction(
421
+ query="INSERT INTO categories(name) VALUES ('hack')"
422
+ ))
423
+ assert obs.last_error is not None
424
+ assert "not allowed" in obs.last_error.lower() or "INSERT" in obs.last_error
425
+
426
+
427
+ # ══════════════════════════════════════════════════════════════════════════════
428
+ # 5. LOG FORMAT COMPLIANCE
429
+ # ══════════════════════════════════════════════════════════════════════════════
430
+
431
+ class TestLogFormat:
432
+ """Validate that the inference.py log helpers emit correct format."""
433
+
434
+ START_RE = re.compile(
435
+ r"^\[START\] task=\S+ env=\S+ model=\S+$"
436
+ )
437
+ STEP_RE = re.compile(
438
+ r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} "
439
+ r"done=(true|false) error=.+$"
440
+ )
441
+ END_RE = re.compile(
442
+ r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} "
443
+ r"rewards=[\d.,]+$"
444
+ )
445
+
446
+ def _capture(self, func, *args, **kwargs) -> str:
447
+ import io
448
+ from contextlib import redirect_stdout
449
+ buf = io.StringIO()
450
+ with redirect_stdout(buf):
451
+ func(*args, **kwargs)
452
+ return buf.getvalue().strip()
453
+
454
+ def test_log_start_format(self):
455
+ sys.path.insert(0, str(ROOT))
456
+ from inference import log_start
457
+ out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B")
458
+ assert self.START_RE.match(out), f"Bad [START] format: {out!r}"
459
+
460
+ def test_log_step_format_null_error(self):
461
+ from inference import log_step
462
+ out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None)
463
+ assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"
464
+
465
+ def test_log_step_format_with_error(self):
466
+ from inference import log_step
467
+ out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error")
468
+ assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"
469
+
470
+ def test_log_end_format_success(self):
471
+ from inference import log_end
472
+ out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0])
473
+ assert self.END_RE.match(out), f"Bad [END] format: {out!r}"
474
+
475
+ def test_log_end_format_failure(self):
476
+ from inference import log_end
477
+ out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0])
478
+ assert self.END_RE.match(out), f"Bad [END] format: {out!r}"
479
+
480
+ def test_reward_two_decimal_places(self):
481
+ from inference import log_step
482
+ out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None)
483
+ # reward= field must have exactly 2 decimal places
484
+ match = re.search(r"reward=(\d+\.\d+)", out)
485
+ assert match, "No reward= field found"
486
+ assert len(match.group(1).split(".")[1]) == 2
487
+
488
+ def test_score_three_decimal_places(self):
489
+ from inference import log_end
490
+ out = self._capture(log_end, True, 1, 1.0, [1.0])
491
+ match = re.search(r"score=(\d+\.\d+)", out)
492
+ assert match
493
+ assert len(match.group(1).split(".")[1]) == 3
train.py ADDED
@@ -0,0 +1,125 @@
1
+ import os
2
+ # CRITICAL: This line must be set before any PyTorch import!
3
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"
4
+
5
+ import sys
6
+ import torch
7
+ from datasets import Dataset
8
+ from transformers import AutoModelForCausalLM, AutoTokenizer
9
+ from peft import LoraConfig
10
+ from trl import GRPOConfig, GRPOTrainer
11
+
12
+ sys.path.insert(0, "./server")
13
+ from environment import NL2SQLEnvironment
14
+ from models import NL2SQLAction
15
+ from tasks import all_task_names, get_task
16
+
17
+ MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
18
+ OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo"
19
+
20
+ SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite.
21
+ Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries.
22
+
23
+ STRICT RULES:
24
+ 1. Output EXACTLY ONE valid SQLite query.
25
+ 2. DO NOT wrap the query in markdown formatting (no ```sql or ```).
26
+ 3. DO NOT output any explanations, conversational text, or preambles (e.g., never say "Here is the query").
27
+ 4. ONLY use standard SQLite functions. Avoid SQL Server, MySQL, or PostgreSQL specific syntax.
28
+ 5. If the question implies ordering, use the correct ORDER BY clause.
29
+
30
+ Your output must be executable directly against the database as-is."""
31
+
32
+ def build_dataset():
33
+ data = []
34
+ for t_name in all_task_names():
35
+ task = get_task(t_name)
36
+ schema = task.schema_context()
37
+ for ex in task.examples:
38
+ user_content = f"SCHEMA:\n{schema}\n\nQUESTION: {ex.question}"
39
+ data.append({
40
+ "prompt": [
41
+ {"role": "system", "content": SYSTEM_PROMPT},
42
+ {"role": "user", "content": user_content}
43
+ ],
44
+ "task_name": t_name
45
+ })
46
+ return Dataset.from_list(data)
47
+
48
+ def sql_reward_func(prompts, completions, task_name, **kwargs):
49
+ rewards = []
50
+ env = NL2SQLEnvironment()
51
+
52
+ for idx, completion in enumerate(completions):
53
+ generated_text = completion[0]['content'] if isinstance(completion, list) else completion
54
+
55
+ if generated_text.startswith("```"):
56
+ lines = generated_text.split("\n")
57
+ generated_text = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()
58
+
59
+ current_task = task_name[idx] if isinstance(task_name, list) else task_name
60
+
61
+ env.reset(task_name=current_task)
62
+
63
+ try:
64
+ action = NL2SQLAction(query=generated_text)
65
+ obs = env.step(action)
66
+ rewards.append(float(obs.reward))
67
+ except Exception:
68
+ rewards.append(0.0)
69
+
70
+ return rewards
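The inline markdown-fence cleanup inside `sql_reward_func` can be factored out for testing; `strip_fences` is a hypothetical helper name, mirroring the exact logic above:

```python
def strip_fences(text: str) -> str:
    # Drop ``` fence lines so only the raw SQL remains.
    if text.startswith("```"):
        lines = text.split("\n")
        text = "\n".join(
            l for l in lines if not l.strip().startswith("```")
        ).strip()
    return text

# strip_fences("```sql\nSELECT 1\n```") → "SELECT 1"
```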
71
+
72
+ def main():
73
+ dataset = build_dataset()
74
+
75
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
76
+ if tokenizer.pad_token is None:
77
+ tokenizer.pad_token = tokenizer.eos_token
78
+
79
+ model = AutoModelForCausalLM.from_pretrained(
80
+ MODEL_NAME,
81
+ torch_dtype=torch.bfloat16,
82
+ attn_implementation="sdpa" # Defaulting to sdpa to avoid any flash_attn setup issues
83
+ )
84
+
85
+ peft_config = LoraConfig(
86
+ r=128,
87
+ lora_alpha=256,
88
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
89
+ bias="none",
90
+ task_type="CAUSAL_LM"
91
+ )
92
+
93
+ training_args = GRPOConfig(
94
+ output_dir=OUTPUT_DIR,
95
+ learning_rate=2e-5,
96
+ per_device_train_batch_size=2,
97
+ gradient_accumulation_steps=4,
98
+ max_completion_length=256,
99
+ num_generations=8,
100
+ temperature=0.5,
101
+ bf16=True,
102
+ logging_steps=5,
103
+ num_train_epochs=10,
104
+ report_to="none",
105
+ remove_unused_columns=False,
106
+ ddp_find_unused_parameters=False
107
+ )
108
+
109
+ trainer = GRPOTrainer(
110
+ model=model,
111
+ reward_funcs=sql_reward_func,
112
+ args=training_args,
113
+ train_dataset=dataset,
114
+ peft_config=peft_config,
115
+ processing_class=tokenizer
116
+ )
117
+
118
+ trainer.train()
119
+
120
+ if trainer.accelerator.is_main_process:
121
+ trainer.model.save_pretrained(f"{OUTPUT_DIR}/final")
122
+ tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
123
+
124
+ if __name__ == "__main__":
125
+ main()