Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .env.example +23 -0
- Dockerfile +56 -0
- README.md +270 -7
- __init__.py +18 -0
- check_quality.py +131 -0
- clean_dataset.py +79 -0
- client.py +81 -0
- custom_train.py +250 -0
- data_expander.py +161 -0
- data_factory/__init__.py +0 -0
- data_factory/augmentor.py +288 -0
- data_factory/config.py +50 -0
- data_factory/generate_data.py +1947 -0
- data_factory/generator.py +410 -0
- data_factory/pipeline.py +443 -0
- data_factory/run_data_factory.py +260 -0
- data_factory/schemas.py +564 -0
- data_factory/templates.py +993 -0
- data_factory/validator.py +221 -0
- env_server +33 -0
- folder.txt +95 -0
- generate_data.py +263 -0
- generate_edge_cases.py +319 -0
- inference.py +265 -0
- local_test.py +118 -0
- merge_model.py +32 -0
- mini_server.py +74 -0
- models.py +79 -0
- openenv.yaml +135 -0
- pyproject.toml +39 -0
- scripts/run_local.sh +67 -0
- scripts/smoke_test.sh +62 -0
- server/__init__.py +1 -0
- server/app.py +31 -0
- server/app.py.bak +31 -0
- server/db/__init__.py +1 -0
- server/db/schema.sql +62 -0
- server/db/seed.py +225 -0
- server/environment.py +223 -0
- server/grader.py +214 -0
- server/requirements.txt +13 -0
- server/tasks/__init__.py +14 -0
- server/tasks/base.py +108 -0
- server/tasks/easy.py +93 -0
- server/tasks/hard.py +156 -0
- server/tasks/medium.py +117 -0
- tests/__init__.py +1 -0
- tests/conftest.py +14 -0
- tests/test_all.py +493 -0
- train.py +125 -0
.env.example
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nl2sql-bench/.env.example
|
| 2 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 3 |
+
# Copy this file to .env and fill in your values.
|
| 4 |
+
# NEVER commit .env to version control.
|
| 5 |
+
#
|
| 6 |
+
# All three variables below are MANDATORY per competition rules.
|
| 7 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 8 |
+
|
| 9 |
+
# LLM API endpoint (HuggingFace router or any OpenAI-compatible base URL)
|
| 10 |
+
API_BASE_URL=https://router.huggingface.co/v1
|
| 11 |
+
|
| 12 |
+
# Model identifier — must be accessible at the above endpoint
|
| 13 |
+
MODEL_NAME=Qwen/Qwen2.5-72B-Instruct
|
| 14 |
+
|
| 15 |
+
# HuggingFace API token (also used as the OpenAI-client api_key)
|
| 16 |
+
HF_TOKEN=hf_your_token_here
|
| 17 |
+
|
| 18 |
+
# ── Optional overrides ────────────────────────────────────────────────────
|
| 19 |
+
# LOCAL_IMAGE_NAME=nl2sql-bench:latest # Docker image name for local dev
|
| 20 |
+
# SPACE_URL=https://your-space.hf.space # Deployed HF Space URL
|
| 21 |
+
# NL2SQL_DEFAULT_TASK=simple-filter # Default task (overridden per episode)
|
| 22 |
+
# NL2SQL_MAX_STEPS=5 # Max steps per episode
|
| 23 |
+
# ENABLE_WEB_INTERFACE=true # Enable /web UI for debugging
|
Dockerfile
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nl2sql-bench/Dockerfile  (built from the repo root: `docker build -t nl2sql-bench .`)
# ─────────────────────────────────────────────────────────────────────────────
# NL2SQL-Bench OpenEnv Server
# Hugging Face Spaces compatible (port 7860, non-root user).
# Build:  docker build -t nl2sql-bench:latest .
# Run:    docker run -p 7860:7860 nl2sql-bench:latest
# ─────────────────────────────────────────────────────────────────────────────

FROM python:3.11-slim

# HF Spaces runs as non-root by default
ARG UID=1000
RUN useradd -m -u $UID appuser

WORKDIR /app

# ── System deps ───────────────────────────────────────────────────────────
# curl is required by the HEALTHCHECK below.
RUN apt-get update -qq && \
    apt-get install -y --no-install-recommends curl && \
    rm -rf /var/lib/apt/lists/*

# ── Python deps ───────────────────────────────────────────────────────────
COPY server/requirements.txt /app/requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

# ── Application code ──────────────────────────────────────────────────────
# Copy server code
COPY server/ /app/server/
# Copy shared models (client imports from parent — we flatten for Docker)
COPY models.py /app/models.py

# Flatten server submodules into /app so Python can find them
# (avoids complex PYTHONPATH games inside the container)
RUN cp -r /app/server/tasks /app/tasks && \
    cp -r /app/server/db /app/db && \
    cp /app/server/grader.py /app/grader.py && \
    cp /app/server/environment.py /app/environment.py && \
    cp /app/server/app.py /app/app.py

# ── Runtime config ────────────────────────────────────────────────────────
ENV PYTHONPATH=/app
ENV PYTHONUNBUFFERED=1
# HF Spaces requires port 7860
ENV PORT=7860
ENV NL2SQL_DEFAULT_TASK=simple-filter
ENV NL2SQL_MAX_STEPS=5

USER appuser
WORKDIR /app

EXPOSE 7860

HEALTHCHECK --interval=30s --timeout=10s --start-period=15s --retries=3 \
    CMD curl -sf http://localhost:${PORT}/health || exit 1

# Single worker: episode state lives in-process (in-memory SQLite), so
# multiple workers would route /reset and /step to different processes
# and break sessions. NOTE(review): confirm state is per-process before
# ever raising the worker count.
CMD ["sh", "-c", "uvicorn app:app --host 0.0.0.0 --port ${PORT} --workers 1 --log-level info"]
|
README.md
CHANGED
|
@@ -1,10 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
---
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
---
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
| 1 |
+
# NL2SQL-Bench
|
| 2 |
+
|
| 3 |
+
**Natural Language to SQL Analytics Environment for RL Training**
|
| 4 |
+
|
| 5 |
+
[](https://github.com/meta-pytorch/OpenEnv)
|
| 6 |
+
[](https://www.python.org)
|
| 7 |
+
[](LICENSE)
|
| 8 |
+
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
## What is NL2SQL-Bench?
|
| 12 |
+
|
| 13 |
+
NL2SQL-Bench is an OpenEnv-compliant RL training environment where an AI agent must iteratively write and refine **SQLite queries** to answer natural-language business questions against a synthetic e-commerce database.
|
| 14 |
+
|
| 15 |
+
This fills a genuine gap in the OpenEnv ecosystem — no SQL query environment currently exists. Every data-driven company employs analysts who translate business questions into SQL. Training agents to do this well (and to recover from errors) is immediately valuable.
|
| 16 |
+
|
| 17 |
+
**Why it's a great RL domain:**
|
| 18 |
+
- Rewards are **100% deterministic** — no LLM-as-judge, no subjectivity
|
| 19 |
+
- Multi-turn episodes create **dense reward signal** across the trajectory
|
| 20 |
+
- The error → fix → retry loop is a novel mechanic not present in existing environments
|
| 21 |
+
- Three clearly graduated difficulty levels challenge models across the full skill range
|
| 22 |
+
|
| 23 |
+
---
|
| 24 |
+
|
| 25 |
+
## Environment Description
|
| 26 |
+
|
| 27 |
+
The agent interacts with a **synthetic e-commerce SQLite database** containing ~150 customers, 64 products across 8 categories, ~600 orders, ~1000 order items, and ~400 reviews. The database is seeded deterministically (seed=42) so results are reproducible across any machine.
|
| 28 |
+
|
| 29 |
+
The agent receives a natural-language question and iteratively submits SQL queries. Each query is executed, graded against the ground truth, and the reward + error/result is fed back as the next observation.
|
| 30 |
+
|
| 31 |
+
---
|
| 32 |
+
|
| 33 |
+
## Database Schema
|
| 34 |
+
|
| 35 |
+
```
|
| 36 |
+
categories(id, name)
|
| 37 |
+
products(id, name, category_id, price, stock_quantity)
|
| 38 |
+
customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
|
| 39 |
+
orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
|
| 40 |
+
created_at, total_amount)
|
| 41 |
+
order_items(id, order_id, product_id, quantity, unit_price)
|
| 42 |
+
reviews(id, product_id, customer_id, rating∈1-5, created_at)
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
All dates are ISO-8601 strings sortable by text comparison. SQLite window functions and CTEs are fully supported.
|
| 46 |
+
|
| 47 |
+
---
|
| 48 |
+
|
| 49 |
+
## Action & Observation Space
|
| 50 |
+
|
| 51 |
+
### Action
|
| 52 |
+
```python
|
| 53 |
+
@dataclass
|
| 54 |
+
class NL2SQLAction(Action):
|
| 55 |
+
query: str # A SQLite SELECT query string
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
### Observation
|
| 59 |
+
```python
|
| 60 |
+
@dataclass
|
| 61 |
+
class NL2SQLObservation(Observation):
|
| 62 |
+
question: str # The NL question to answer
|
| 63 |
+
schema_context: str # Compact schema description
|
| 64 |
+
task_name: str # Active task identifier
|
| 65 |
+
last_query: str # SQL submitted on previous step
|
| 66 |
+
last_result: List[Dict] # Up to 10 result rows
|
| 67 |
+
last_error: Optional[str] # SQLite error string or None
|
| 68 |
+
result_columns: List[str] # Column names of last_result
|
| 69 |
+
step: int # Current step (0 after reset)
|
| 70 |
+
max_steps: int # Maximum steps per episode
|
| 71 |
+
done: bool # Episode ended?
|
| 72 |
+
reward: Optional[float] # Step reward [0.0, 1.0]
|
| 73 |
+
score: float # Cumulative normalised score
|
| 74 |
+
```
|
| 75 |
+
|
| 76 |
+
---
|
| 77 |
+
|
| 78 |
+
## Tasks & Expected Difficulty
|
| 79 |
+
|
| 80 |
+
### Task 1 — `simple-filter` (easy)
|
| 81 |
+
Single-table SELECT queries with WHERE, ORDER BY, LIMIT. Tests basic SQL fluency. Example questions:
|
| 82 |
+
- "List all gold-tier customers ordered by name alphabetically."
|
| 83 |
+
- "Return the top 5 most expensive products."
|
| 84 |
+
|
| 85 |
+
**Expected solve rate (frontier model, 5 steps):** ~80%
|
| 86 |
+
|
| 87 |
+
### Task 2 — `join-aggregation` (medium)
|
| 88 |
+
Multi-table JOINs with GROUP BY, HAVING, and aggregation functions. Example questions:
|
| 89 |
+
- "How many orders has each customer placed? Include customers with zero orders."
|
| 90 |
+
- "Which customers have spent more than $500 total on delivered orders?"
|
| 91 |
+
|
| 92 |
+
**Expected solve rate (frontier model, 5 steps):** ~55%
|
| 93 |
+
|
| 94 |
+
### Task 3 — `analytics-window` (hard)
|
| 95 |
+
CTEs, window functions (DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. Example questions:
|
| 96 |
+
- "Rank customers by total spending using DENSE_RANK."
|
| 97 |
+
- "Show monthly revenue and running total for delivered orders in 2024."
|
| 98 |
+
|
| 99 |
+
**Expected solve rate (frontier model, 5 steps):** ~30%
|
| 100 |
+
|
| 101 |
+
---
|
| 102 |
+
|
| 103 |
+
## Reward Function
|
| 104 |
+
|
| 105 |
+
Rewards are computed by deterministic comparison of the agent's result set against the ground truth:
|
| 106 |
+
|
| 107 |
+
| Component | Score | Description |
|
| 108 |
+
|---|---|---|
|
| 109 |
+
| `syntax_ok` | +0.10 | Query runs without SQLite error |
|
| 110 |
+
| `columns_match` | +0.20 | Returned column names match ground truth |
|
| 111 |
+
| `row_count_match` | +0.20 | Number of rows matches |
|
| 112 |
+
| `exact_match` | +0.50 | Full result set equals ground truth |
|
| 113 |
+
| `step_penalty` | −0.05/step | Deducted per step beyond the first |
|
| 114 |
+
|
| 115 |
+
Final reward is clamped to `[0.0, 1.0]`. Order sensitivity matches the ground-truth query: ORDER BY queries require correct row ordering; others are order-agnostic.
|
| 116 |
+
|
| 117 |
+
---
|
| 118 |
+
|
| 119 |
+
## Baseline Scores
|
| 120 |
+
|
| 121 |
+
Run by the `inference.py` script using `Qwen/Qwen2.5-72B-Instruct` via HuggingFace router:
|
| 122 |
+
|
| 123 |
+
| Task | Expected Score |
|
| 124 |
+
|---|---|
|
| 125 |
+
| `simple-filter` | ~0.70 |
|
| 126 |
+
| `join-aggregation` | ~0.45 |
|
| 127 |
+
| `analytics-window` | ~0.25 |
|
| 128 |
+
|
| 129 |
+
---
|
| 130 |
+
|
| 131 |
+
## Setup & Usage
|
| 132 |
+
|
| 133 |
+
### Prerequisites
|
| 134 |
+
- Python 3.10+
|
| 135 |
+
- Docker (for containerised deployment)
|
| 136 |
+
- A HuggingFace account + token
|
| 137 |
+
|
| 138 |
+
### Local Development (no Docker)
|
| 139 |
+
|
| 140 |
+
```bash
|
| 141 |
+
# Clone the repository
|
| 142 |
+
git clone https://huggingface.co/spaces/your-username/nl2sql-bench
|
| 143 |
+
cd nl2sql-bench
|
| 144 |
+
|
| 145 |
+
# Quick start
|
| 146 |
+
chmod +x scripts/run_local.sh
|
| 147 |
+
./scripts/run_local.sh
|
| 148 |
+
|
| 149 |
+
# Or manually:
|
| 150 |
+
python3 -m venv .venv && source .venv/bin/activate
|
| 151 |
+
pip install openenv-core fastapi "uvicorn[standard]" openai pydantic
|
| 152 |
+
export PYTHONPATH=".:server"
|
| 153 |
+
cd server && uvicorn app:app --reload --port 8000
|
| 154 |
+
```
|
| 155 |
+
|
| 156 |
+
### Test the Running Server
|
| 157 |
+
|
| 158 |
+
```bash
|
| 159 |
+
# Run smoke tests
|
| 160 |
+
chmod +x scripts/smoke_test.sh
|
| 161 |
+
./scripts/smoke_test.sh http://localhost:8000
|
| 162 |
+
|
| 163 |
+
# Run full test suite
|
| 164 |
+
pip install pytest pytest-asyncio
|
| 165 |
+
PYTHONPATH=".:server" pytest tests/ -v
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
### Docker
|
| 169 |
+
|
| 170 |
+
```bash
|
| 171 |
+
# Build
|
| 172 |
+
docker build -t nl2sql-bench:latest .
|
| 173 |
+
|
| 174 |
+
# Run
|
| 175 |
+
docker run -p 7860:7860 nl2sql-bench:latest
|
| 176 |
+
|
| 177 |
+
# Test
|
| 178 |
+
./scripts/smoke_test.sh http://localhost:7860
|
| 179 |
+
```
|
| 180 |
+
|
| 181 |
+
### Pre-submission Validation
|
| 182 |
+
|
| 183 |
+
```bash
|
| 184 |
+
# Run the official validator (replace with your HF Space URL)
|
| 185 |
+
chmod +x pre_validation_script.sh
|
| 186 |
+
./pre_validation_script.sh https://your-username-nl2sql-bench.hf.space .
|
| 187 |
+
```
|
| 188 |
+
|
| 189 |
+
### Running the Baseline Inference
|
| 190 |
+
|
| 191 |
+
```bash
|
| 192 |
+
# Set mandatory variables
|
| 193 |
+
export API_BASE_URL="https://router.huggingface.co/v1"
|
| 194 |
+
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
|
| 195 |
+
export HF_TOKEN="hf_your_token_here"
|
| 196 |
+
export SPACE_URL="https://your-username-nl2sql-bench.hf.space"
|
| 197 |
+
|
| 198 |
+
python inference.py
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
### Using the Client Programmatically
|
| 202 |
+
|
| 203 |
+
```python
|
| 204 |
+
import asyncio
|
| 205 |
+
from client import NL2SQLEnv
|
| 206 |
+
from models import NL2SQLAction
|
| 207 |
+
|
| 208 |
+
async def main():
|
| 209 |
+
async with NL2SQLEnv(base_url="http://localhost:8000") as env:
|
| 210 |
+
result = await env.reset()
|
| 211 |
+
print(result.observation.question)
|
| 212 |
+
|
| 213 |
+
result = await env.step(NL2SQLAction(
|
| 214 |
+
query="SELECT id, name FROM customers WHERE tier='gold' ORDER BY name"
|
| 215 |
+
))
|
| 216 |
+
print(f"Reward: {result.reward:.2f}")
|
| 217 |
+
print(f"Done: {result.done}")
|
| 218 |
+
print(f"Error: {result.observation.last_error}")
|
| 219 |
+
|
| 220 |
+
asyncio.run(main())
|
| 221 |
+
```
|
| 222 |
+
|
| 223 |
+
---
|
| 224 |
+
|
| 225 |
+
## Project Structure
|
| 226 |
+
|
| 227 |
+
```
|
| 228 |
+
nl2sql-bench/
|
| 229 |
+
├── models.py # NL2SQLAction, NL2SQLObservation, NL2SQLState
|
| 230 |
+
├── client.py # NL2SQLEnv(HTTPEnvClient)
|
| 231 |
+
├── inference.py # Baseline inference script (mandatory name)
|
| 232 |
+
├── openenv.yaml # OpenEnv manifest
|
| 233 |
+
├── pyproject.toml
|
| 234 |
+
├── Dockerfile # HF Spaces compatible (port 7860)
|
| 235 |
+
├── .env.example
|
| 236 |
+
├── server/
|
| 237 |
+
│ ├── app.py # FastAPI entry point
|
| 238 |
+
│ ├── environment.py # Core RL environment logic
|
| 239 |
+
│ ├── grader.py # Deterministic reward computation
|
| 240 |
+
│ ├── requirements.txt
|
| 241 |
+
│ ├── db/
|
| 242 |
+
│ │ ├── schema.sql # 6-table e-commerce schema
|
| 243 |
+
│ │ └── seed.py # Deterministic data generator (seed=42)
|
| 244 |
+
│ └── tasks/
|
| 245 |
+
│ ├── base.py # BaseTask + registry
|
| 246 |
+
│ ├── easy.py # simple-filter (5 examples)
|
| 247 |
+
│ ├── medium.py # join-aggregation (5 examples)
|
| 248 |
+
│ └── hard.py # analytics-window (5 examples)
|
| 249 |
+
├── tests/
|
| 250 |
+
│ ├── conftest.py
|
| 251 |
+
│ └── test_all.py # 30+ pytest tests
|
| 252 |
+
└── scripts/
|
| 253 |
+
├── run_local.sh # Local dev server
|
| 254 |
+
└── smoke_test.sh # Endpoint smoke tests
|
| 255 |
+
```
|
| 256 |
+
|
| 257 |
---
|
| 258 |
+
|
| 259 |
+
## Design Decisions
|
| 260 |
+
|
| 261 |
+
**Why SQLite in-memory?** Zero runtime dependency, deterministic, and it runs comfortably within the 2 vCPU / 8 GB constraint. The database loads in ~50ms.
|
| 262 |
+
|
| 263 |
+
**Why multi-turn (up to 5 steps)?** A single-shot SQL environment gives binary rewards. Multi-turn with error feedback gives the agent — and the GRPO trainer — a rich signal: the model learns not just to write SQL, but to debug and refine its queries.
|
| 264 |
+
|
| 265 |
+
**Why step penalty?** Without it, an agent that accidentally gets the right answer on step 5 scores the same as one that gets it on step 1. The penalty creates pressure to solve efficiently, which is realistic.
|
| 266 |
+
|
| 267 |
+
**Why order-sensitive comparison for ORDER BY queries?** Business questions that say "rank by spending" expect a ranked output. Order-agnostic comparison would give spurious credit.
|
| 268 |
+
|
| 269 |
---
|
| 270 |
|
| 271 |
+
## License
|
| 272 |
+
|
| 273 |
+
MIT — see [LICENSE](LICENSE)
|
__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench — NL2SQL Analytics OpenEnv Environment
|
| 3 |
+
====================================================
|
| 4 |
+
Public API surface for client-side use.
|
| 5 |
+
|
| 6 |
+
from nl2sql_bench import NL2SQLEnv, NL2SQLAction, NL2SQLObservation, NL2SQLState
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
|
| 10 |
+
from client import NL2SQLEnv
|
| 11 |
+
|
| 12 |
+
__version__ = "0.1.0"
|
| 13 |
+
__all__ = [
|
| 14 |
+
"NL2SQLEnv",
|
| 15 |
+
"NL2SQLAction",
|
| 16 |
+
"NL2SQLObservation",
|
| 17 |
+
"NL2SQLState",
|
| 18 |
+
]
|
check_quality.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import re
|
| 5 |
+
from collections import Counter
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# Add project root to path
|
| 9 |
+
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 10 |
+
if PROJECT_ROOT not in sys.path:
|
| 11 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 12 |
+
|
| 13 |
+
from data_factory.validator import SQLValidator
|
| 14 |
+
|
| 15 |
+
DATASET_FILE = "edge_cases.jsonl"
|
| 16 |
+
|
| 17 |
+
def main():
    """Run quality and sanity checks over DATASET_FILE and print a health report.

    Per JSONL row this verifies that:
      * the JSON parses cleanly,
      * the prompt block has >= 2 messages and a non-empty SQL string,
      * a domain can be resolved (metadata first, then prompt parsing),
      * the SQL executes against the domain's seeded DB and returns rows.
    """
    if not os.path.exists(DATASET_FILE):
        print(f"Error: {DATASET_FILE} not found!")
        return

    print("Starting Dataset Quality & Sanity Check...\n")

    total_rows = 0
    corrupt_json = 0
    sql_execution_failures = 0
    empty_outputs = 0
    missing_domains = 0

    persona_counts = Counter()
    unique_sqls = set()
    unique_questions = set()
    domain_counts = Counter()

    # One validator per domain, created lazily and reused across rows.
    validators = {}

    with open(DATASET_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    for line in tqdm(lines, desc="Analyzing Rows"):
        total_rows += 1
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            corrupt_json += 1
            continue

        prompt_block = record.get("prompt", [])
        sql = record.get("sql", "").strip()
        metadata = record.get("metadata", {})

        if not prompt_block or len(prompt_block) < 2 or not sql:
            empty_outputs += 1
            continue

        user_content = prompt_block[1].get("content", "")
        question = user_content.split("QUESTION: ")[-1]

        # Smart Domain Extraction: try metadata first, fall back to prompt parsing.
        domain = metadata.get("domain")
        if not domain:
            match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", user_content)
            domain = match.group(1) if match else "unknown"

        persona = metadata.get("persona", "unknown")

        persona_counts[persona] += 1
        domain_counts[domain] += 1
        unique_sqls.add(sql)
        unique_questions.add(question)

        # Skip validation if domain is completely unknown/corrupted.
        if domain == "unknown":
            missing_domains += 1
            continue

        # Strict execution quality check: the SQL must run and return data.
        try:
            if domain not in validators:
                validators[domain] = SQLValidator(domain, seed=42)

            val_result = validators[domain].validate(sql)
            if not val_result.passed or val_result.row_count == 0:
                sql_execution_failures += 1
        except Exception:
            # Schema/setup errors are lumped in with unresolvable domains so
            # they don't inflate the SQL-failure rate.
            missing_domains += 1
            continue

    # Cleanup validators (releases DB connections).
    for v in validators.values():
        v.close()

    # --- REPORT GENERATION ---
    valid_total = total_rows - (corrupt_json + empty_outputs + missing_domains)

    def pct(count):
        """Percentage of valid rows; 0.0 when the dataset is empty."""
        return (count / valid_total) * 100 if valid_total else 0.0

    print("\n" + "=" * 60)
    print("DATASET HEALTH REPORT")
    print("=" * 60)
    print(f"Total Rows Parsed : {total_rows}")
    print(f"Corrupt JSON Lines : {corrupt_json}")
    print(f"Missing SQL/Domains : {empty_outputs + missing_domains}")

    print("\nDIVERSITY METRICS:")
    print(f"Unique SQL Queries : {len(unique_sqls)} (Base logic templates)")
    print(f"Unique NL Questions : {len(unique_questions)}")

    duplication_rate = (1 - (len(unique_questions) / valid_total)) * 100 if valid_total else 0
    print(f"NL Duplication Rate : {duplication_rate:.2f}% (Should be low!)")

    print("\nPERSONA DISTRIBUTION:")
    for p, count in persona_counts.most_common():
        # Omit percentages entirely when no row survived filtering.
        print(f" - {p}: {count} ({pct(count):.1f}%)" if valid_total else f" - {p}: {count}")

    print("\nDOMAIN DISTRIBUTION:")
    for d, count in domain_counts.most_common():
        print(f" - {d}: {count} ({pct(count):.1f}%)" if valid_total else f" - {d}: {count}")

    print("\nCRITICAL QUALITY CHECK:")
    fail_rate = pct(sql_execution_failures)
    print(f"SQL Execution Failures : {sql_execution_failures} ({fail_rate:.2f}%)")

    if fail_rate > 5.0:
        print("WARNING: Too many SQLs are failing. Dataset needs cleanup.")
    elif fail_rate > 0:
        print("GOOD: Very low failure rate. Safe to train after minor filtering.")
    else:
        print("PERFECT: Zero execution failures. Pure Gold Dataset!")
    print("=" * 60)

if __name__ == "__main__":
    main()
|
clean_dataset.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import re
|
| 5 |
+
from tqdm import tqdm
|
| 6 |
+
|
| 7 |
+
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 8 |
+
if PROJECT_ROOT not in sys.path:
|
| 9 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 10 |
+
|
| 11 |
+
from data_factory.validator import SQLValidator
|
| 12 |
+
|
| 13 |
+
INPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
|
| 14 |
+
OUTPUT_FILE = "nl2sql_cleaned_ready_to_train.jsonl"
|
| 15 |
+
|
| 16 |
+
def main():
    """Filter INPUT_FILE down to rows whose SQL executes and returns data.

    Rows are dropped when the JSON is corrupt, the domain cannot be resolved,
    or the SQL fails validation / returns zero rows. Surviving lines are
    written verbatim (byte-for-byte) to OUTPUT_FILE.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found!")
        return

    print("Sweeping dataset to remove bad SQLs...")

    with open(INPUT_FILE, "r", encoding="utf-8") as f:
        lines = f.readlines()

    validators = {}      # one SQLValidator per domain, reused across rows
    cleaned_count = 0
    failed_count = 0

    with open(OUTPUT_FILE, "w", encoding="utf-8") as out_f:
        for line in tqdm(lines, desc="Filtering Garbage"):
            try:
                record = json.loads(line)
            except json.JSONDecodeError:
                failed_count += 1
                continue

            sql = record.get("sql", "").strip()
            metadata = record.get("metadata", {})
            domain = metadata.get("domain")

            # Fallback domain extraction from the user prompt. Guard against
            # prompts with fewer than two messages (indexing [1] unchecked
            # previously raised IndexError and aborted the whole sweep).
            if not domain or domain == "unknown":
                prompt = record.get("prompt", [])
                content = prompt[1].get("content", "") if len(prompt) > 1 else ""
                match = re.search(r"Database:\s*([a-zA-Z0-9_]+)", content)
                domain = match.group(1) if match else "unknown"

            if domain == "unknown":
                failed_count += 1
                continue

            try:
                # Validator construction is inside the try: an unknown/bad
                # domain name now drops the row instead of crashing the run.
                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)

                val_result = validators[domain].validate(sql)
                # Keep ONLY if SQL is 100% perfect and returns data
                if val_result.passed and val_result.row_count > 0:
                    out_f.write(line)
                    cleaned_count += 1
                else:
                    failed_count += 1
            except Exception:
                failed_count += 1

    # Release per-domain DB resources.
    for v in validators.values():
        v.close()

    print("\n" + "="*50)
    print("DATASET CLEANUP COMPLETE")
    print("="*50)
    print(f"Original Rows : {len(lines)}")
    print(f"Cleaned Rows : {cleaned_count} (100% Valid SQL)")
    print(f"Removed Rows : {failed_count}")
    print(f"Saved To : {OUTPUT_FILE}")
    print("="*50)

if __name__ == "__main__":
    main()
|
client.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import httpx
|
| 2 |
+
import json
|
| 3 |
+
import os
|
| 4 |
+
from typing import Any, Dict, Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
|
| 7 |
+
@dataclass
class NL2SQLAction:
    """Agent action: one SQLite SELECT statement to submit to the server."""

    # Raw SQL text; sent unmodified as the /step payload.
    query: str
|
| 10 |
+
|
| 11 |
+
@dataclass
class NL2SQLObservation:
    """Environment observation returned after every reset/step call."""

    question: str              # natural-language question to answer
    schema_context: str        # compact description of the database schema
    task_name: str             # active task identifier
    last_query: str            # SQL submitted on the previous step
    last_result: list          # rows returned by the previous query
    last_error: Optional[str]  # SQLite error message, or None on success
    result_columns: list       # column names of last_result
    step: int                  # current step index (0 right after reset)
    max_steps: int             # episode step budget
    done: bool                 # True once the episode has ended
    reward: float              # step reward (coerced to 0.0 when server sends null)
    score: float               # cumulative normalised score
|
| 25 |
+
|
| 26 |
+
@dataclass
class StepResult:
    """Bundle returned by NL2SQLEnv.reset()/step()."""

    observation: NL2SQLObservation  # full observation payload
    reward: float                   # reward mirrored from the observation
    done: bool                      # whether the episode has terminated
|
| 31 |
+
|
| 32 |
+
class NL2SQLEnv:
    """Async HTTP client for the NL2SQL-Bench environment server.

    Usage:
        async with NL2SQLEnv(base_url="http://localhost:8000") as env:
            result = await env.reset()
            result = await env.step(NL2SQLAction(query="SELECT ..."))
    """

    def __init__(self, base_url: str = "http://localhost:8000"):
        self.base_url = base_url
        # Generous timeout: each step executes and grades SQL server-side.
        self.client = httpx.AsyncClient(base_url=base_url, timeout=60.0)

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.client.aclose()

    async def reset(self) -> StepResult:
        """Start a new episode; the task comes from NL2SQL_DEFAULT_TASK (env var)."""
        task_name = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
        resp = await self.client.post("/reset", json={"task_name": task_name})
        # Fail loudly on HTTP errors instead of JSON-decoding an error page.
        resp.raise_for_status()
        return self._parse_result(resp.json())

    async def step(self, action: NL2SQLAction) -> StepResult:
        """Submit a SQL query and return the graded result."""
        payload = {"query": action.query}
        resp = await self.client.post("/step", json=payload)
        resp.raise_for_status()
        return self._parse_result(resp.json())

    def _parse_result(self, payload: Dict[str, Any]) -> StepResult:
        """Convert a server JSON payload into a StepResult.

        Tolerates both {"observation": {...}} and flat payload shapes, and
        coerces missing/null reward and score to 0.0.
        """
        obs_data = payload.get("observation", payload)

        # SAFETY CHECK: if reward or score is None/null, default to 0.0.
        raw_reward = obs_data.get("reward")
        safe_reward = float(raw_reward) if raw_reward is not None else 0.0

        raw_score = obs_data.get("score")
        safe_score = float(raw_score) if raw_score is not None else 0.0

        obs = NL2SQLObservation(
            question=obs_data.get("question", ""),
            schema_context=obs_data.get("schema_context", ""),
            task_name=obs_data.get("task_name", ""),
            last_query=obs_data.get("last_query", ""),
            last_result=obs_data.get("last_result", []),
            last_error=obs_data.get("last_error"),
            result_columns=obs_data.get("result_columns", []),
            step=obs_data.get("step", 0),
            max_steps=obs_data.get("max_steps", 5),
            done=obs_data.get("done", False),
            reward=safe_reward,
            score=safe_score,
        )
        return StepResult(
            observation=obs,
            reward=safe_reward,
            done=obs.done,
        )
|
custom_train.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
merge_and_train.py
|
| 3 |
+
==================
|
| 4 |
+
1. Merges nl2sql_cleaned_ready_to_train.jsonl + edge_cases.jsonl
|
| 5 |
+
2. Shuffles the combined dataset
|
| 6 |
+
3. Retrains using the same GRPO setup as train.py
|
| 7 |
+
|
| 8 |
+
Run:
|
| 9 |
+
python merge_and_train.py
|
| 10 |
+
|
| 11 |
+
Flags (env vars):
|
| 12 |
+
EDGE_FILE — path to edge cases jsonl (default: edge_cases.jsonl)
|
| 13 |
+
BASE_FILE — path to existing cleaned (default: nl2sql_cleaned_ready_to_train.jsonl)
|
| 14 |
+
MERGED_FILE — merged output path (default: nl2sql_merged_final.jsonl)
|
| 15 |
+
SKIP_MERGE — set "1" to skip merge step and go straight to training
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os, sys, json, random
|
| 19 |
+
import torch
|
| 20 |
+
from datasets import Dataset
|
| 21 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 22 |
+
from peft import LoraConfig
|
| 23 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 24 |
+
|
| 25 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,5,1,6"
|
| 26 |
+
|
| 27 |
+
sys.path.insert(0, "./server")
|
| 28 |
+
from environment import NL2SQLEnvironment
|
| 29 |
+
from models import NL2SQLAction
|
| 30 |
+
from tasks import all_task_names, get_task
|
| 31 |
+
|
| 32 |
+
# ── Config ───────────────────────────────────────────────────────────────────
|
| 33 |
+
BASE_FILE = os.getenv("BASE_FILE", "nl2sql_cleaned_ready_to_train.jsonl")
|
| 34 |
+
EDGE_FILE = os.getenv("EDGE_FILE", "edge_cases.jsonl")
|
| 35 |
+
MERGED_FILE = os.getenv("MERGED_FILE", "nl2sql_merged_final.jsonl")
|
| 36 |
+
SKIP_MERGE = os.getenv("SKIP_MERGE", "0") == "1"
|
| 37 |
+
|
| 38 |
+
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
|
| 39 |
+
OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo-v2"
|
| 40 |
+
|
| 41 |
+
SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite.
|
| 42 |
+
Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries.
|
| 43 |
+
|
| 44 |
+
STRICT RULES:
|
| 45 |
+
1. Output EXACTLY ONE valid SQLite query.
|
| 46 |
+
2. DO NOT wrap the query in markdown formatting (no ```sql or ```).
|
| 47 |
+
3. DO NOT output any explanations, conversational text, or preambles.
|
| 48 |
+
4. ONLY use standard SQLite functions.
|
| 49 |
+
5. If the question implies ordering, use the correct ORDER BY clause.
|
| 50 |
+
6. SELECT only the columns explicitly requested — no extras.
|
| 51 |
+
|
| 52 |
+
Your output must be executable directly against the database as-is."""
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ── Step 1: Merge ─────────────────────────────────────────────────────────────
|
| 56 |
+
|
| 57 |
+
def _read_nonempty_lines(path: str) -> list:
    """Return every non-blank line of *path*, stripped of surrounding whitespace."""
    with open(path, "r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip()]


def merge_datasets(base_file=None, edge_file=None, merged_file=None, skip=None):
    """Merge the base and edge-case JSONL files into one shuffled JSONL file.

    Parameters default to the module-level config (BASE_FILE, EDGE_FILE,
    MERGED_FILE, SKIP_MERGE), so existing `merge_datasets()` callers behave
    exactly as before; the parameters exist to make the function testable
    and reusable with explicit paths.
    """
    # Resolve defaults lazily so the module-level config stays authoritative.
    if skip is None:
        skip = SKIP_MERGE
    if base_file is None:
        base_file = BASE_FILE
    if edge_file is None:
        edge_file = EDGE_FILE
    if merged_file is None:
        merged_file = MERGED_FILE

    if skip:
        print(f"[SKIP_MERGE=1] Using existing {merged_file}")
        return

    print(f"Loading base: {base_file}")
    print(f"Loading edges: {edge_file}")

    # Previously this read logic was duplicated for both files.
    base_lines = _read_nonempty_lines(base_file)
    edge_lines = _read_nonempty_lines(edge_file)

    combined = base_lines + edge_lines
    random.shuffle(combined)  # mix edge cases evenly through the base data

    with open(merged_file, "w", encoding="utf-8") as f:
        for line in combined:
            f.write(line + "\n")

    print(
        f"Merged: {len(base_lines)} base + {len(edge_lines)} edge "
        f"= {len(combined)} total → {merged_file}"
    )
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ── Step 2: Build HF Dataset ──────────────────────────────────────────────────
|
| 93 |
+
|
| 94 |
+
def build_dataset():
    """Build the GRPO training dataset.

    Two sources are combined:
      * rows from the merged JSONL file (base + edge cases), tagged
        "merged_jsonl" so the grader falls back to execution-based reward, and
      * the examples bundled with each server task, so the environment-based
        reward keeps working for them.
    """
    samples = []

    # 1) Rows from the merged JSONL file.
    with open(MERGED_FILE, "r", encoding="utf-8") as handle:
        for raw in handle:
            raw = raw.strip()
            if not raw:
                continue
            record = json.loads(raw)
            # Each record carries "prompt" (chat messages) and "sql"; GRPO only
            # needs the prompt plus a synthetic task tag.
            samples.append({"prompt": record["prompt"], "task_name": "merged_jsonl"})

    # 2) Examples shipped with the original server tasks.
    for name in all_task_names():
        task = get_task(name)
        schema = task.schema_context()
        for example in task.examples:
            user_msg = f"SCHEMA:\n{schema}\n\nQUESTION: {example.question}"
            samples.append(
                {
                    "prompt": [
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_msg},
                    ],
                    "task_name": name,
                }
            )

    random.shuffle(samples)
    print(f"Dataset size: {len(samples)} samples")
    return Dataset.from_list(samples)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ── Step 3: Reward function ─────────────────────���─────────────────────────────
|
| 135 |
+
|
| 136 |
+
def sql_reward_func(prompts, completions, task_name, **kwargs):
    """GRPO reward function.

    Rows tagged "merged_jsonl" have no matching server task, so they are scored
    purely by executing the SQL against the schema embedded in the prompt
    (see _execution_reward). All other rows are scored by the environment's
    grader. Any grading failure yields 0.0 instead of crashing training.
    """
    # Fix: `import re` previously sat inside the per-completion loop and the
    # fence-stripping regex was recompiled on every completion. Hoisted here.
    import re

    fence_pattern = re.compile(r"```(?:sql)?\n?(.*?)```", re.DOTALL)

    rewards = []
    env = NL2SQLEnvironment()

    for idx, completion in enumerate(completions):
        generated = (
            completion[0]["content"] if isinstance(completion, list) else completion
        )
        # Strip code fences defensively — the model is told not to emit them.
        generated = fence_pattern.sub(r"\1", generated).strip()

        t = task_name[idx] if isinstance(task_name, list) else task_name

        # For merged_jsonl rows the env won't have a matching task →
        # reward purely on execution (non-empty result set = +1, error = 0)
        if t == "merged_jsonl":
            rewards.append(_execution_reward(generated, prompts[idx]))
            continue

        env.reset(task_name=t)
        try:
            obs = env.step(NL2SQLAction(query=generated))
            rewards.append(float(obs.reward))
        except Exception:
            # A grading crash counts as zero reward rather than killing training.
            rewards.append(0.0)

    return rewards
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def _execution_reward(sql: str, prompt) -> float:
    """Execution-based reward for merged_jsonl samples.

    Rebuilds the schema embedded in the user message inside an in-memory
    SQLite database, runs the generated SQL against it, and scores:
      1.0 — query ran and returned rows
      0.5 — schema could not be extracted (unverifiable → neutral)
      0.3 — query ran cleanly but returned no rows (partial credit)
      0.0 — schema build or query execution failed
    """
    import sqlite3, re as _re

    # Extract the user message from the chat-format prompt.
    user_content = ""
    for msg in (prompt if isinstance(prompt, list) else []):
        if isinstance(msg, dict) and msg.get("role") == "user":
            user_content = msg.get("content", "")
            break

    schema_match = _re.search(r"SCHEMA:\s*(.*?)\nQUESTION:", user_content, _re.DOTALL)
    if not schema_match:
        return 0.5  # can't verify, neutral reward

    schema_sql = schema_match.group(1).strip()
    try:
        conn = sqlite3.connect(":memory:")
        try:
            conn.executescript(schema_sql)
            rows = conn.execute(sql).fetchall()
        finally:
            # Fix: the connection previously leaked when executescript/execute
            # raised, because conn.close() was only reached on success.
            conn.close()
        return 1.0 if rows else 0.3  # ran cleanly but empty → partial credit
    except Exception:
        return 0.0
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ── Step 4: Train ─────────────────────────────────────────────────────────────
|
| 193 |
+
|
| 194 |
+
def main():
    """Merge the datasets, build the HF dataset, and run GRPO fine-tuning."""
    merge_datasets()
    dataset = build_dataset()

    # Right padding for training; fall back to EOS when no pad token exists.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa"
    )

    # LoRA over all attention and MLP projection matrices.
    peft_config = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
        bias="none",
        task_type="CAUSAL_LM"
    )

    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        learning_rate=1e-5, # lower LR for fine-grained edge case tuning
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_completion_length=256,
        num_generations=8,
        temperature=0.5,
        bf16=True,
        logging_steps=5,
        num_train_epochs=5, # fewer epochs — base knowledge already there
        report_to="none",
        # Keep extra columns (e.g. "task_name") so the reward fn receives them.
        remove_unused_columns=False,
        ddp_find_unused_parameters=False
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=sql_reward_func,
        args=training_args,
        train_dataset=dataset,
        peft_config=peft_config,
        processing_class=tokenizer
    )

    trainer.train()

    # Save the adapter and tokenizer once, from the main process only.
    if trainer.accelerator.is_main_process:
        trainer.model.save_pretrained(f"{OUTPUT_DIR}/final")
        tokenizer.save_pretrained(f"{OUTPUT_DIR}/final")
        print(f"\nSaved to {OUTPUT_DIR}/final")
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
if __name__ == "__main__":
|
| 250 |
+
main()
|
data_expander.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import sys
import json
import torch
import hashlib
from pathlib import Path
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
import sys

# --- PATCH FOR TRANSFORMERS VERSION MISMATCH ---
try:
    import transformers.activations
    if not hasattr(transformers.activations, "PytorchGELUTanh"):
        # Mapping the old name to the new existing one
        transformers.activations.PytorchGELUTanh = transformers.activations.GELUActivation
except ImportError:
    pass
# ------------------------------------------------------

# NOTE(review): os/json/torch (and sys above) are imported twice in this file —
# harmless, but worth cleaning up.
import os
import json
import torch
# ... all the remaining old imports

# Force script to use only the 2 free GPUs (e.g., 0 and 7)
# NOTE(review): set after `import torch`; effective only if no CUDA context
# has been created before this point — confirm.
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"

# Make the project root importable so data_factory resolves as a package.
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

from data_factory.schemas import SCHEMA_CONTEXT

# AWQ model is 4x smaller and much faster
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct-AWQ"
INPUT_FILE = "llm_hybrid_templates.json"
OUTPUT_FILE = "nl2sql_50k_elite_dataset.jsonl"
VARIATIONS_PER_SQL = 20   # NL paraphrases generated per SQL template
BATCH_SIZE = 64 # AWQ allows much larger batches!

# System prompt written into every training record. Runtime string — do not edit.
SYSTEM_PROMPT = "You are an expert SQL analyst. Write a single SELECT query that answers the question. Output ONLY the SQL query — no markdown, no explanation, no backticks."

# Prompt sent to the generator LLM; {count}/{schema}/{sql} are filled per batch.
EXPANSION_PROMPT = """
You are an expert linguist and NL2SQL data augmentor. I have a SQLite database schema and a complex SQL query.
Generate exactly {count} completely different natural language questions that this exact SQL query answers.

RULES:
- Personas: Executive (direct), Non-tech (wordy), Analyst (technical), Curious (investigative).
- Structure: Completely change sentence flow.
- No direct column/table names.

DATABASE SCHEMA:
{schema}

SQL QUERY:
{sql}

OUTPUT FORMAT:
Return ONLY a valid JSON array of objects: [{{"persona": "...", "question": "..."}}]
"""
|
| 62 |
+
|
| 63 |
+
def extract_json_array(raw_text):
    """Extract the outermost JSON array substring from LLM output.

    Returns the text between the first '[' and the last ']' (inclusive).
    Falls back to "[]" when no well-ordered bracket pair exists, so callers
    can always pass the result straight to json.loads().
    """
    text = raw_text.strip()
    start = text.find("[")
    end = text.rfind("]")
    # Bug fix: also require end > start. Previously an input like "] [" passed
    # the `!= -1` checks and produced an empty slice — invalid JSON.
    if start != -1 and end > start:
        return text[start:end + 1]
    return "[]"
|
| 70 |
+
|
| 71 |
+
def get_hash(text):
    """MD5 hex digest of *text* after lower-casing and stripping — used for dedup."""
    normalized = text.lower().strip()
    return hashlib.md5(normalized.encode("utf-8")).hexdigest()
|
| 73 |
+
|
| 74 |
+
def main():
    """Expand each SQL template into many NL paraphrases via a local LLM.

    Reads the template list from INPUT_FILE, prompts the model in batches for
    VARIATIONS_PER_SQL questions per SQL query, deduplicates on (question, sql)
    hashes, and appends JSONL training records to OUTPUT_FILE.
    """
    if not os.path.exists(INPUT_FILE):
        print(f"Error: {INPUT_FILE} not found.")
        sys.exit(1)

    with open(INPUT_FILE, "r") as f:
        base_templates = json.load(f)

    print(f"🚀 Loading {MODEL_NAME} on 2 GPUs...")

    # Left padding so batched generation aligns prompts at the right edge.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token

    # Model loading (AWQ version automatically handles quantization)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        torch_dtype=torch.float16, # AWQ models use float16/bfloat16 for weights
        low_cpu_mem_usage=True
    )

    # Resume support: count existing output lines to seed the progress bar.
    # NOTE(review): seen_hashes is NOT repopulated from the existing file, so a
    # restarted run can append duplicates of earlier records — confirm intended.
    seen_hashes = set()
    total_saved = 0
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            for line in f:
                total_saved += 1 # Quick count

    pbar = tqdm(total=len(base_templates) * VARIATIONS_PER_SQL, initial=total_saved)

    # Batch processing
    for i in range(0, len(base_templates), BATCH_SIZE):
        batch = base_templates[i:i + BATCH_SIZE]
        prompts = []

        # Build one chat prompt per template in the batch.
        for temp in batch:
            msg = [
                {"role": "system", "content": "You output only JSON arrays."},
                {"role": "user", "content": EXPANSION_PROMPT.format(count=VARIATIONS_PER_SQL, schema=SCHEMA_CONTEXT[temp['domain']], sql=temp['sql'])}
            ]
            prompts.append(tokenizer.apply_chat_template(msg, tokenize=False, add_generation_prompt=True))

        inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)

        try:
            with torch.no_grad():
                # Increased speed: AWQ handles large batches efficiently
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=2048,
                    temperature=0.5,
                    do_sample=True,
                    pad_token_id=tokenizer.eos_token_id
                )

            # Decode only the newly generated tokens (skip the echoed prompt).
            responses = tokenizer.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True)

            with open(OUTPUT_FILE, "a", encoding="utf-8") as out_file:
                for idx, resp in enumerate(responses):
                    questions_data = json.loads(extract_json_array(resp))
                    sql = batch[idx]["sql"]
                    domain = batch[idx]["domain"]

                    for item in questions_data:
                        q = item.get("question", "")
                        # Skip trivially short questions; dedupe on (q, sql).
                        if len(q) > 10:
                            q_hash = get_hash(q + sql)
                            if q_hash not in seen_hashes:
                                seen_hashes.add(q_hash)
                                record = {
                                    "prompt": [
                                        {"role": "system", "content": SYSTEM_PROMPT},
                                        {"role": "user", "content": f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {q}"}
                                    ],
                                    "sql": sql
                                }
                                out_file.write(json.dumps(record, ensure_ascii=False) + "\n")
                                total_saved += 1
                                pbar.update(1)
                out_file.flush()
        except Exception as e:
            # A single failed batch (OOM, bad JSON, etc.) should not kill the run.
            print(f"Batch failed: {e}")
            continue

    pbar.close()
|
| 159 |
+
|
| 160 |
+
if __name__ == "__main__":
|
| 161 |
+
main()
|
data_factory/__init__.py
ADDED
|
File without changes
|
data_factory/augmentor.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/augmentor.py
|
| 3 |
+
==========================
|
| 4 |
+
Rule-based Natural Language augmentation.
|
| 5 |
+
|
| 6 |
+
These transformations operate ONLY on NL question strings.
|
| 7 |
+
SQL is NEVER modified — it always comes from the verified template library.
|
| 8 |
+
|
| 9 |
+
Three augmentation strategies:
|
| 10 |
+
1. Synonym replacement — swaps domain words with semantically equivalent ones
|
| 11 |
+
2. Condition reordering — shuffles conjunctive phrases (preserves meaning)
|
| 12 |
+
3. Date normalisation — expresses dates in different formats when applicable
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import random
|
| 18 |
+
import re
|
| 19 |
+
from copy import deepcopy
|
| 20 |
+
from typing import Iterator
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 24 |
+
# SYNONYM DICTIONARIES
|
| 25 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
# Format: "canonical_term": ["synonym1", "synonym2", ...]
|
| 28 |
+
# All synonyms are semantically equivalent in a business context.
|
| 29 |
+
|
| 30 |
+
_SYNONYMS: dict[str, list[str]] = {
|
| 31 |
+
|
| 32 |
+
# Verbs / action starters
|
| 33 |
+
"list": ["show", "display", "return", "give me", "find", "retrieve"],
|
| 34 |
+
"show": ["list", "display", "return", "get", "retrieve"],
|
| 35 |
+
"find": ["identify", "locate", "get", "show", "retrieve", "look up"],
|
| 36 |
+
"return": ["show", "give", "list", "retrieve", "output"],
|
| 37 |
+
"retrieve": ["fetch", "get", "return", "pull"],
|
| 38 |
+
"get": ["retrieve", "fetch", "return", "give me"],
|
| 39 |
+
|
| 40 |
+
# Aggregation words
|
| 41 |
+
"total": ["sum", "aggregate", "overall", "cumulative", "combined"],
|
| 42 |
+
"average": ["mean", "avg", "typical"],
|
| 43 |
+
"count": ["number of", "quantity of", "how many"],
|
| 44 |
+
"highest": ["largest", "maximum", "top", "greatest"],
|
| 45 |
+
"lowest": ["smallest", "minimum", "least"],
|
| 46 |
+
|
| 47 |
+
# Business / domain
|
| 48 |
+
"customer": ["client", "buyer", "user", "account holder", "shopper"],
|
| 49 |
+
"customers": ["clients", "buyers", "users", "account holders", "shoppers"],
|
| 50 |
+
"product": ["item", "SKU", "article", "goods"],
|
| 51 |
+
"products": ["items", "SKUs", "articles", "goods"],
|
| 52 |
+
"order": ["purchase", "transaction", "sale"],
|
| 53 |
+
"orders": ["purchases", "transactions", "sales"],
|
| 54 |
+
"revenue": ["income", "earnings", "sales amount", "money earned"],
|
| 55 |
+
"spending": ["expenditure", "spend", "purchases"],
|
| 56 |
+
"amount": ["value", "sum", "total", "figure"],
|
| 57 |
+
"price": ["cost", "rate", "charge", "fee"],
|
| 58 |
+
|
| 59 |
+
# Healthcare
|
| 60 |
+
"patient": ["person", "individual", "case"],
|
| 61 |
+
"patients": ["persons", "individuals", "cases"],
|
| 62 |
+
"doctor": ["physician", "clinician", "practitioner", "specialist"],
|
| 63 |
+
"doctors": ["physicians", "clinicians", "practitioners"],
|
| 64 |
+
"appointment": ["visit", "consultation", "session"],
|
| 65 |
+
"appointments": ["visits", "consultations", "sessions"],
|
| 66 |
+
"medication": ["drug", "medicine", "pharmaceutical", "prescription drug"],
|
| 67 |
+
"medications": ["drugs", "medicines", "pharmaceuticals"],
|
| 68 |
+
"diagnosis": ["condition", "finding", "medical finding"],
|
| 69 |
+
|
| 70 |
+
# Finance
|
| 71 |
+
"account": ["bank account", "profile", "portfolio entry"],
|
| 72 |
+
"accounts": ["bank accounts", "profiles"],
|
| 73 |
+
"loan": ["credit", "borrowing", "debt instrument"],
|
| 74 |
+
"loans": ["credits", "borrowings", "debt instruments"],
|
| 75 |
+
"transaction": ["transfer", "payment", "operation", "activity"],
|
| 76 |
+
"transactions": ["transfers", "payments", "operations"],
|
| 77 |
+
"balance": ["funds", "available amount", "account balance"],
|
| 78 |
+
|
| 79 |
+
# HR
|
| 80 |
+
"employee": ["staff member", "worker", "team member", "headcount"],
|
| 81 |
+
"employees": ["staff", "workers", "team members", "workforce"],
|
| 82 |
+
"department": ["team", "division", "unit", "group"],
|
| 83 |
+
"departments": ["teams", "divisions", "units"],
|
| 84 |
+
"salary": ["pay", "compensation", "remuneration", "earnings"],
|
| 85 |
+
"project": ["initiative", "program", "assignment", "engagement"],
|
| 86 |
+
"projects": ["initiatives", "programs", "assignments"],
|
| 87 |
+
|
| 88 |
+
# Adjectives / Qualifiers
|
| 89 |
+
"active": ["current", "ongoing", "live", "existing"],
|
| 90 |
+
"delivered": ["completed", "fulfilled", "received"],
|
| 91 |
+
"cancelled": ["voided", "aborted", "terminated"],
|
| 92 |
+
"alphabetically": ["by name", "in alphabetical order", "A to Z"],
|
| 93 |
+
"descending": ["from highest to lowest", "in decreasing order", "largest first"],
|
| 94 |
+
"ascending": ["from lowest to highest", "in increasing order", "smallest first"],
|
| 95 |
+
"distinct": ["unique", "different"],
|
| 96 |
+
"in stock": ["available", "with available inventory", "not out of stock"],
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# ────────────────────────────────────────���────────────────────────────────────
|
| 101 |
+
# DATE PHRASE PATTERNS
|
| 102 |
+
# These will be replaced with alternative date expressions.
|
| 103 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 104 |
+
|
| 105 |
+
# Each entry maps a literal date phrase (matched case-insensitively by
# _vary_dates) to alternative wordings of the same date/period.
_DATE_ALTERNATES: list[tuple[str, list[str]]] = [
    # ISO partial
    ("2024-01-01", ["January 1st 2024", "Jan 1, 2024", "the start of 2024", "2024 start"]),
    ("2023-01-01", ["January 1st 2023", "Jan 1, 2023", "the start of 2023"]),
    ("2025-01-01", ["January 1st 2025", "the start of 2025"]),
    # Quarter references
    ("Q1", ["the first quarter", "January through March", "Jan-Mar"]),
    ("Q2", ["the second quarter", "April through June", "Apr-Jun"]),
    ("Q3", ["the third quarter", "July through September", "Jul-Sep"]),
    ("Q4", ["the fourth quarter", "October through December", "Oct-Dec"]),
    # Year references
    ("in 2024", ["during 2024", "throughout 2024", "for the year 2024"]),
    ("in 2023", ["during 2023", "throughout 2023", "for the year 2023"]),
]
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 122 |
+
# CONDITION REORDERING
|
| 123 |
+
# Splits on "and" between two conditions and reverses them.
|
| 124 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 125 |
+
|
| 126 |
+
def _reorder_conditions(text: str, rng: random.Random) -> str:
    """Swap the clauses around the first connector ('and'/'who are'/'that are'/'with').

    Fires roughly 50% of the time (one rng.random() draw); otherwise the text
    is returned untouched. When no connector is present, no randomness is
    consumed at all. The swapped result is capitalised if needed.
    """
    connector_re = re.compile(r'\b(?:and|who are|that are|with)\b', re.IGNORECASE)
    first = next(connector_re.finditer(text), None)
    # Short-circuit order matters: rng is only consumed when a connector exists.
    if first is None or rng.random() > 0.5:
        return text

    left_part = text[:first.start()].strip()
    right_part = text[first.end():].strip()
    joiner = first.group(0).lower()

    # 'and' keeps its literal form; other connectors are re-used lower-cased.
    reordered = (
        f"{right_part} and {left_part}"
        if joiner == "and"
        else f"{right_part} {joiner} {left_part}"
    )

    # Capitalise the first letter so the result still reads like a sentence.
    if reordered and not reordered[0].isupper():
        reordered = reordered[0].upper() + reordered[1:]
    return reordered
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 160 |
+
# SYNONYM REPLACEMENT
|
| 161 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 162 |
+
|
| 163 |
+
def _apply_synonyms(text: str, rng: random.Random, max_replacements: int = 3) -> str:
    """Swap up to *max_replacements* known terms for random synonyms.

    Candidate terms are visited in a shuffled order and each matched term is
    replaced with probability 0.5 (first occurrence only, word-boundary,
    case-insensitive), so repeated calls with different rngs yield diverse
    paraphrases. Leading capitalisation of the matched word is preserved.
    """
    augmented = text
    done = 0

    # Randomise which canonical terms are considered this call.
    candidates = list(_SYNONYMS.keys())
    rng.shuffle(candidates)

    for term in candidates:
        if done >= max_replacements:
            break
        options = _SYNONYMS[term]
        matcher = re.compile(r'\b' + re.escape(term) + r'\b', re.IGNORECASE)
        # Short-circuit order matters: rng.random() is drawn only on a match.
        if matcher.search(augmented) and rng.random() < 0.5:
            chosen = rng.choice(options)

            def _keep_case(match: re.Match) -> str:
                # Carry the matched word's leading capital over to the synonym.
                found = match.group(0)
                return chosen[0].upper() + chosen[1:] if found[0].isupper() else chosen

            augmented = matcher.sub(_keep_case, augmented, count=1)
            done += 1

    return augmented
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 196 |
+
# DATE FORMAT VARIATION
|
| 197 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 198 |
+
|
| 199 |
+
def _vary_dates(text: str, rng: random.Random) -> str:
    """Replace date phrases with alternate representations.

    For each `(phrase, alternates)` pair in ``_DATE_ALTERNATES``, if the
    phrase occurs in *text* (case-insensitively) it is replaced — first
    occurrence only — with probability 0.6.
    """
    result = text
    for phrase, alternates in _DATE_ALTERNATES:
        if phrase.lower() in result.lower() and rng.random() < 0.6:
            alt = rng.choice(alternates)
            # Use a callable replacement so characters such as '\' or '\g'
            # in `alt` are inserted literally; passing `alt` as a string
            # would make re.sub parse it as a replacement template and
            # raise (or silently mangle the output).
            result = re.sub(
                re.escape(phrase),
                lambda _m, _alt=alt: _alt,
                result,
                count=1,
                flags=re.IGNORECASE,
            )
    return result
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 210 |
+
# PUBLIC API
|
| 211 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 212 |
+
|
| 213 |
+
def augment_nl(
    nl_question: str,
    n: int = 3,
    seed: int = 42,
) -> list[str]:
    """
    Generate `n` rule-based augmented variants of a natural language question.

    Each variant applies a different combination of:
    - synonym replacement
    - condition reordering
    - date format variation

    The original question is NOT included in the output.

    Parameters
    ----------
    nl_question : str
        The base NL question to augment.
    n : int
        Number of variants to generate.
    seed : int
        Random seed for reproducibility.

    Returns
    -------
    list[str]
        Up to `n` distinct augmented strings. May be fewer if the question
        is too short to vary meaningfully.
    """
    variants: list[str] = []
    # Seed `seen` with the original so it can never be emitted as a variant.
    seen: set[str] = {nl_question}

    strategies = [
        # Strategy 1: synonym only
        lambda t, r: _apply_synonyms(t, r, max_replacements=2),
        # Strategy 2: synonym + date
        lambda t, r: _vary_dates(_apply_synonyms(t, r, max_replacements=2), r),
        # Strategy 3: condition reorder + synonym
        lambda t, r: _apply_synonyms(_reorder_conditions(t, r), r, max_replacements=1),
        # Strategy 4: heavy synonym
        lambda t, r: _apply_synonyms(t, r, max_replacements=4),
        # Strategy 5: date only
        lambda t, r: _vary_dates(t, r),
    ]

    for i in range(n * 3):  # Over-generate, then deduplicate
        strategy = strategies[i % len(strategies)]
        # A distinct, deterministic RNG per attempt. (A shared
        # `random.Random(seed)` was previously constructed here as well
        # but never used — removed.)
        local_rng = random.Random(seed + i * 31)
        candidate = strategy(nl_question, local_rng).strip()

        # Normalise whitespace
        candidate = " ".join(candidate.split())

        if candidate and candidate not in seen:
            seen.add(candidate)
            variants.append(candidate)

        if len(variants) >= n:
            break

    return variants
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def generate_all_augmentations(
    nl_question: str,
    seed: int = 42,
    n_per_template: int = 3,
) -> Iterator[str]:
    """
    Yield augmented NL variants one at a time (generator).
    Suitable for streaming into a large dataset without memory pressure.
    """
    for variant in augment_nl(nl_question, n=n_per_template, seed=seed):
        yield variant
|
data_factory/config.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/config.py
|
| 3 |
+
======================
|
| 4 |
+
Central configuration for the NL2SQL Synthetic Data Factory.
|
| 5 |
+
|
| 6 |
+
Design philosophy:
|
| 7 |
+
- SQL ALWAYS comes from human-verified templates → zero SQL errors
|
| 8 |
+
- LLM ONLY generates natural language paraphrases → no SQL hallucination
|
| 9 |
+
- Every SQL is execution-validated before saving → guaranteed correctness
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
# ── Paths ────────────────────────────────────────────────────────────────
|
| 16 |
+
ROOT_DIR = Path(__file__).parent.parent
|
| 17 |
+
DATA_DIR = ROOT_DIR / "generated_data"
|
| 18 |
+
CHECKPOINT_DIR = DATA_DIR / "checkpoints"
|
| 19 |
+
OUTPUT_DIR = DATA_DIR / "output"
|
| 20 |
+
|
| 21 |
+
# ── vLLM / Model ─────────────────────────────────────────────────────────
|
| 22 |
+
# For H100 with 80GB VRAM — run Llama-3-70B or Qwen-72B at full bf16
|
| 23 |
+
GENERATOR_MODEL = "meta-llama/Meta-Llama-3-70B-Instruct" # change to your preferred model
|
| 24 |
+
TENSOR_PARALLEL = 4 # Number of GPUs for tensor parallelism (H100 cluster)
|
| 25 |
+
MAX_MODEL_LEN = 4096 # Max context length
|
| 26 |
+
GPU_MEMORY_UTIL = 0.90 # Leave 10% headroom
|
| 27 |
+
|
| 28 |
+
# ── Generation settings ──────────────────────────────────────────────────
|
| 29 |
+
PERSONAS = ["ceo", "chatty", "lazy_typist", "non_techie", "analyst"]
|
| 30 |
+
NL_VARIANTS_PER_TEMPLATE = 5 # One per persona
|
| 31 |
+
AUGMENTATIONS_PER_NL = 3 # Rule-based variations per NL string
|
| 32 |
+
TEMPERATURE = 0.85 # Slightly high for diversity
|
| 33 |
+
MAX_NEW_TOKENS = 150 # NL questions are short
|
| 34 |
+
|
| 35 |
+
# ── Scale targets ────────────────────────────────────────────────────────
|
| 36 |
+
# 56 base SQL templates × 5 personas × 3 augmentations = 840 "original" records
|
| 37 |
+
# With vLLM generating more NL variants, target: ~500K-1M clean records
|
| 38 |
+
VLLM_EXTRA_VARIANTS = 10 # Additional vLLM NL variants per template beyond personas
|
| 39 |
+
|
| 40 |
+
# ── Validation ───────────────────────────────────────────────────────────
|
| 41 |
+
RANDOM_SEED = 42
|
| 42 |
+
|
| 43 |
+
# ── Domains ──────────────────────────────────────────────────────────────
|
| 44 |
+
DOMAINS = ["ecommerce", "healthcare", "finance", "hr"]
|
| 45 |
+
|
| 46 |
+
DIFFICULTY_LABELS = {
|
| 47 |
+
"easy": "Single-table SELECT with basic WHERE/ORDER/LIMIT.",
|
| 48 |
+
"medium": "Multi-table JOIN with GROUP BY/HAVING/aggregates.",
|
| 49 |
+
"hard": "CTEs, window functions, subqueries.",
|
| 50 |
+
}
|
data_factory/generate_data.py
ADDED
|
@@ -0,0 +1,1947 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
generate_data.py — NL2SQL Synthetic Data Factory
|
| 3 |
+
=================================================
|
| 4 |
+
Designed for H100 + vLLM. Produces a clean JSONL file ready for SFT or GRPO training
|
| 5 |
+
with the nl2sql-bench codebase (schema: e-commerce SQLite).
|
| 6 |
+
|
| 7 |
+
Architecture
|
| 8 |
+
------------
|
| 9 |
+
1. SQL_TEMPLATES — 120+ ground-truth SQLs, hand-written and verified, NEVER LLM-generated.
|
| 10 |
+
2. SQLiteValidator — executes every SQL against the actual seeded DB; discards any failure.
|
| 11 |
+
3. VLLMGenerator — async batched calls to a local vLLM server for NL paraphrasing.
|
| 12 |
+
4. RuleAugmentor — pure-Python synonym / date-format / condition-order augmentation.
|
| 13 |
+
5. DataFactory — orchestrates the full pipeline; writes JSONL with checkpointing.
|
| 14 |
+
|
| 15 |
+
Output schema (one JSON object per line)
|
| 16 |
+
-----------------------------------------
|
| 17 |
+
{
|
| 18 |
+
"id": "easy_001_persona_ceo",
|
| 19 |
+
"difficulty": "easy" | "medium" | "hard",
|
| 20 |
+
"persona": "ceo" | "chatty" | "lazy" | "confused" | "analyst",
|
| 21 |
+
"question": "<natural language question>",
|
| 22 |
+
"sql": "<ground-truth SQL>",
|
| 23 |
+
"db_result_ok": true, # always true — failures are discarded
|
| 24 |
+
"augmented": false # true when rule-augmentor modified the NL
|
| 25 |
+
}
|
| 26 |
+
|
| 27 |
+
Usage
|
| 28 |
+
-----
|
| 29 |
+
# 1. Start vLLM server (H100):
|
| 30 |
+
# vllm serve meta-llama/Meta-Llama-3-70B-Instruct \
|
| 31 |
+
# --tensor-parallel-size 4 --port 8001 \
|
| 32 |
+
# --max-model-len 4096 --gpu-memory-utilization 0.92
|
| 33 |
+
|
| 34 |
+
# 2. Run this script (place it next to the nl2sql-bench folder):
|
| 35 |
+
# python generate_data.py \
|
| 36 |
+
# --vllm-url http://localhost:8001/v1 \
|
| 37 |
+
# --model meta-llama/Meta-Llama-3-70B-Instruct \
|
| 38 |
+
# --output nl2sql_train.jsonl \
|
| 39 |
+
# --personas-per-template 5 \
|
| 40 |
+
# --aug-rounds 2 \
|
| 41 |
+
# --batch-size 64
|
| 42 |
+
|
| 43 |
+
Requirements
|
| 44 |
+
------------
|
| 45 |
+
pip install openai tqdm
|
| 46 |
+
(vLLM + your model already running separately)
|
| 47 |
+
|
| 48 |
+
IMPORTANT: Copy server/db/schema.sql and server/db/seed.py from nl2sql-bench
|
| 49 |
+
into the same directory as this script, OR set --bench-root to the repo root.
|
| 50 |
+
"""
|
| 51 |
+
|
| 52 |
+
from __future__ import annotations
|
| 53 |
+
|
| 54 |
+
import argparse
|
| 55 |
+
import asyncio
|
| 56 |
+
import hashlib
|
| 57 |
+
import json
|
| 58 |
+
import logging
|
| 59 |
+
import os
|
| 60 |
+
import random
|
| 61 |
+
import re
|
| 62 |
+
import sqlite3
|
| 63 |
+
import sys
|
| 64 |
+
import time
|
| 65 |
+
from copy import deepcopy
|
| 66 |
+
from dataclasses import dataclass, asdict
|
| 67 |
+
from pathlib import Path
|
| 68 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 69 |
+
|
| 70 |
+
from openai import AsyncOpenAI
|
| 71 |
+
from tqdm import tqdm
|
| 72 |
+
|
| 73 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 74 |
+
# Logging
|
| 75 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 76 |
+
|
| 77 |
+
# Console logging: timestamped, level-tagged lines for pipeline progress.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
# Module-wide logger shared by the whole data factory.
log = logging.getLogger("data-factory")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 86 |
+
# Database: build & validate
|
| 87 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 88 |
+
|
| 89 |
+
SCHEMA_SQL = """
|
| 90 |
+
CREATE TABLE IF NOT EXISTS categories (
|
| 91 |
+
id INTEGER PRIMARY KEY,
|
| 92 |
+
name TEXT NOT NULL UNIQUE
|
| 93 |
+
);
|
| 94 |
+
|
| 95 |
+
CREATE TABLE IF NOT EXISTS products (
|
| 96 |
+
id INTEGER PRIMARY KEY,
|
| 97 |
+
name TEXT NOT NULL,
|
| 98 |
+
category_id INTEGER NOT NULL REFERENCES categories(id),
|
| 99 |
+
price REAL NOT NULL CHECK(price >= 0),
|
| 100 |
+
stock_quantity INTEGER NOT NULL DEFAULT 0
|
| 101 |
+
);
|
| 102 |
+
|
| 103 |
+
CREATE TABLE IF NOT EXISTS customers (
|
| 104 |
+
id INTEGER PRIMARY KEY,
|
| 105 |
+
name TEXT NOT NULL,
|
| 106 |
+
email TEXT NOT NULL UNIQUE,
|
| 107 |
+
country TEXT NOT NULL,
|
| 108 |
+
tier TEXT NOT NULL DEFAULT 'bronze'
|
| 109 |
+
CHECK(tier IN ('bronze', 'silver', 'gold')),
|
| 110 |
+
created_at TEXT NOT NULL
|
| 111 |
+
);
|
| 112 |
+
|
| 113 |
+
CREATE TABLE IF NOT EXISTS orders (
|
| 114 |
+
id INTEGER PRIMARY KEY,
|
| 115 |
+
customer_id INTEGER NOT NULL REFERENCES customers(id),
|
| 116 |
+
status TEXT NOT NULL DEFAULT 'pending'
|
| 117 |
+
CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
|
| 118 |
+
created_at TEXT NOT NULL,
|
| 119 |
+
total_amount REAL NOT NULL CHECK(total_amount >= 0)
|
| 120 |
+
);
|
| 121 |
+
|
| 122 |
+
CREATE TABLE IF NOT EXISTS order_items (
|
| 123 |
+
id INTEGER PRIMARY KEY,
|
| 124 |
+
order_id INTEGER NOT NULL REFERENCES orders(id),
|
| 125 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 126 |
+
quantity INTEGER NOT NULL CHECK(quantity > 0),
|
| 127 |
+
unit_price REAL NOT NULL CHECK(unit_price >= 0)
|
| 128 |
+
);
|
| 129 |
+
|
| 130 |
+
CREATE TABLE IF NOT EXISTS reviews (
|
| 131 |
+
id INTEGER PRIMARY KEY,
|
| 132 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 133 |
+
customer_id INTEGER NOT NULL REFERENCES customers(id),
|
| 134 |
+
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
|
| 135 |
+
created_at TEXT NOT NULL
|
| 136 |
+
);
|
| 137 |
+
"""
|
| 138 |
+
|
| 139 |
+
# Minimal seeder so the validator can run the SQL against real data.
|
| 140 |
+
# Mirrors the logic in nl2sql-bench/server/db/seed.py (fixed seed = 42).
|
| 141 |
+
SEED_SCRIPT = """
|
| 142 |
+
import random, sqlite3
|
| 143 |
+
from datetime import date, timedelta
|
| 144 |
+
|
| 145 |
+
RNG = random.Random(42)
|
| 146 |
+
|
| 147 |
+
CATEGORIES = ["Electronics","Clothing","Books","Home & Garden",
|
| 148 |
+
"Sports & Outdoors","Toys & Games","Beauty","Automotive"]
|
| 149 |
+
|
| 150 |
+
PRODUCTS = {
|
| 151 |
+
"Electronics": ["Wireless Headphones","USB-C Hub","Mechanical Keyboard",
|
| 152 |
+
"Webcam 4K","Portable Charger","Smart Speaker",
|
| 153 |
+
"Monitor Stand","HDMI Cable 2.1"],
|
| 154 |
+
"Clothing": ["Cotton T-Shirt","Slim Fit Jeans","Hoodie",
|
| 155 |
+
"Running Shorts","Winter Jacket","Polo Shirt",
|
| 156 |
+
"Casual Sneakers","Wool Socks"],
|
| 157 |
+
"Books": ["Clean Code","Designing Data-Intensive Applications",
|
| 158 |
+
"The Pragmatic Programmer","System Design Interview",
|
| 159 |
+
"Deep Learning Book","Python Cookbook",
|
| 160 |
+
"Domain-Driven Design","Refactoring"],
|
| 161 |
+
"Home & Garden": ["Coffee Maker","Air Purifier","LED Desk Lamp",
|
| 162 |
+
"Plant Pot Set","Storage Organiser","Cutting Board",
|
| 163 |
+
"Vacuum Cleaner","Electric Kettle"],
|
| 164 |
+
"Sports & Outdoors": ["Yoga Mat","Resistance Bands","Cycling Gloves",
|
| 165 |
+
"Trekking Poles","Water Bottle 1L","Jump Rope",
|
| 166 |
+
"Foam Roller","Compression Socks"],
|
| 167 |
+
"Toys & Games": ["Lego City Set","Card Game Pack","Puzzle 1000pc",
|
| 168 |
+
"Remote Control Car","Building Blocks",
|
| 169 |
+
"Board Game Strategy","Art Set","Toy Drone"],
|
| 170 |
+
"Beauty": ["Face Serum","SPF 50 Sunscreen","Lip Balm",
|
| 171 |
+
"Shampoo Pro","Hair Mask","Eye Cream",
|
| 172 |
+
"Vitamin C Cream","Toner Mist"],
|
| 173 |
+
"Automotive": ["Car Phone Mount","Dash Cam","Tyre Inflator",
|
| 174 |
+
"Car Vacuum","Seat Cushion","Steering Wheel Cover",
|
| 175 |
+
"OBD Scanner","Jump Starter"],
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
COUNTRIES = ["India","USA","Germany","UK","Canada",
|
| 179 |
+
"Australia","France","Brazil","Japan","Singapore"]
|
| 180 |
+
TIERS = ["bronze","silver","gold"]
|
| 181 |
+
STATUSES = ["pending","processing","shipped","delivered","cancelled"]
|
| 182 |
+
|
| 183 |
+
FIRST = ["Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja",
|
| 184 |
+
"Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica",
|
| 185 |
+
"Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura",
|
| 186 |
+
"Yuki","Hana","Wei","Mei","Aiden","Zara"]
|
| 187 |
+
LAST = ["Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy",
|
| 188 |
+
"Smith","Johnson","Brown","Williams","Jones","Davis","Wilson",
|
| 189 |
+
"Müller","Schmidt","Schneider","Fischer","Weber",
|
| 190 |
+
"Martin","Bernard","Thomas","Richard","Petit",
|
| 191 |
+
"Garcia","Martinez","Lopez","Sanchez","Gonzalez"]
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _date(start=2022, end=2025):
|
| 195 |
+
s = date(start, 1, 1)
|
| 196 |
+
e = date(end, 12, 31)
|
| 197 |
+
return str(s + timedelta(days=RNG.randint(0, (e - s).days)))
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def seed(conn):
|
| 201 |
+
c = conn.cursor()
|
| 202 |
+
for cat in CATEGORIES:
|
| 203 |
+
c.execute("INSERT OR IGNORE INTO categories(name) VALUES (?)", (cat,))
|
| 204 |
+
conn.commit()
|
| 205 |
+
|
| 206 |
+
cat_ids = {r[1]: r[0] for r in conn.execute("SELECT id, name FROM categories")}
|
| 207 |
+
|
| 208 |
+
for cat, prods in PRODUCTS.items():
|
| 209 |
+
for pname in prods:
|
| 210 |
+
c.execute(
|
| 211 |
+
"INSERT OR IGNORE INTO products(name,category_id,price,stock_quantity) VALUES (?,?,?,?)",
|
| 212 |
+
(pname, cat_ids[cat], round(RNG.uniform(5, 500), 2), RNG.randint(0, 200)),
|
| 213 |
+
)
|
| 214 |
+
conn.commit()
|
| 215 |
+
|
| 216 |
+
for i in range(200):
|
| 217 |
+
name = f"{RNG.choice(FIRST)} {RNG.choice(LAST)}"
|
| 218 |
+
email = f"user{i}@example.com"
|
| 219 |
+
c.execute(
|
| 220 |
+
"INSERT OR IGNORE INTO customers(name,email,country,tier,created_at) VALUES (?,?,?,?,?)",
|
| 221 |
+
(name, email, RNG.choice(COUNTRIES), RNG.choice(TIERS), _date()),
|
| 222 |
+
)
|
| 223 |
+
conn.commit()
|
| 224 |
+
|
| 225 |
+
cust_ids = [r[0] for r in conn.execute("SELECT id FROM customers")]
|
| 226 |
+
prod_ids = [r[0] for r in conn.execute("SELECT id FROM products")]
|
| 227 |
+
|
| 228 |
+
for _ in range(600):
|
| 229 |
+
cid = RNG.choice(cust_ids)
|
| 230 |
+
amt = round(RNG.uniform(10, 1000), 2)
|
| 231 |
+
status = RNG.choice(STATUSES)
|
| 232 |
+
d = _date()
|
| 233 |
+
c.execute(
|
| 234 |
+
"INSERT INTO orders(customer_id,status,created_at,total_amount) VALUES (?,?,?,?)",
|
| 235 |
+
(cid, status, d, amt),
|
| 236 |
+
)
|
| 237 |
+
conn.commit()
|
| 238 |
+
|
| 239 |
+
ord_ids = [r[0] for r in conn.execute("SELECT id FROM orders")]
|
| 240 |
+
for oid in ord_ids:
|
| 241 |
+
for _ in range(RNG.randint(1, 4)):
|
| 242 |
+
pid = RNG.choice(prod_ids)
|
| 243 |
+
qty = RNG.randint(1, 5)
|
| 244 |
+
price = round(RNG.uniform(5, 500), 2)
|
| 245 |
+
c.execute(
|
| 246 |
+
"INSERT INTO order_items(order_id,product_id,quantity,unit_price) VALUES (?,?,?,?)",
|
| 247 |
+
(oid, pid, qty, price),
|
| 248 |
+
)
|
| 249 |
+
conn.commit()
|
| 250 |
+
|
| 251 |
+
for _ in range(400):
|
| 252 |
+
pid = RNG.choice(prod_ids)
|
| 253 |
+
cid = RNG.choice(cust_ids)
|
| 254 |
+
rating = RNG.randint(1, 5)
|
| 255 |
+
c.execute(
|
| 256 |
+
"INSERT INTO reviews(product_id,customer_id,rating,created_at) VALUES (?,?,?,?)",
|
| 257 |
+
(pid, cid, rating, _date()),
|
| 258 |
+
)
|
| 259 |
+
conn.commit()
|
| 260 |
+
"""
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def build_db() -> sqlite3.Connection:
    """Build and return an in-memory SQLite DB with schema + seed data.

    Returns
    -------
    sqlite3.Connection
        Connection with ``row_factory = sqlite3.Row`` so validator code can
        access columns by name.
    """
    conn = sqlite3.connect(":memory:")
    conn.executescript(SCHEMA_SQL)

    # BUG FIX: SEED_SCRIPT only *defines* seed(conn); executing it does not
    # populate the DB, so validation would previously run against empty
    # tables. Run the script in a private namespace, then call the seeder.
    # NOTE: exec() is acceptable here only because SEED_SCRIPT is a trusted
    # module-level constant — never external input.
    namespace: Dict[str, Any] = {}
    exec(SEED_SCRIPT, namespace)
    namespace["seed"](conn)

    conn.row_factory = sqlite3.Row
    log.info("In-memory DB built and seeded.")
    return conn
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class SQLiteValidator:
    """Execute candidate SQL against the seeded DB and report validity.

    ``validate`` returns ``(ok, error)``: ``ok`` is True when the statement
    is a read-only query that executes without error against ``conn``;
    otherwise ``error`` carries a short human-readable reason.
    """

    # Keywords that may legally begin a read-only SQLite query. WITH is
    # included because hard-difficulty templates use CTEs ("WITH ... SELECT"),
    # which the previous SELECT-only check rejected wholesale.
    _READONLY_STARTS = ("select", "with")

    # Whole-word write verbs that must not appear inside a WITH statement
    # (SQLite permits `WITH ... INSERT/UPDATE/DELETE`). Word boundaries keep
    # column names such as `created_at` from matching. Heuristic: a literal
    # string containing one of these words would be a false positive.
    _DML_RE = re.compile(
        r"\b(insert|update|delete|replace|drop|alter)\b", re.IGNORECASE
    )

    def __init__(self, conn: sqlite3.Connection):
        self.conn = conn

    def validate(self, sql: str) -> Tuple[bool, Optional[str]]:
        """Return ``(True, None)`` iff `sql` is a read-only query that runs.

        Fetches at most 500 rows so evaluation is forced without slurping
        an unbounded result set.
        """
        sql = sql.strip().rstrip(";")
        if not sql:
            return False, "Empty SQL"
        first = sql.split()[0].lower()
        if first not in self._READONLY_STARTS:
            return False, f"Non-SELECT statement: {first}"
        if first == "with" and self._DML_RE.search(sql):
            return False, "WITH statement contains a write verb"
        try:
            cur = self.conn.execute(sql)
            cur.fetchmany(500)
            return True, None
        except sqlite3.Error as exc:
            return False, str(exc)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 295 |
+
# SQL Template Library (ground-truth, hand-written, execution-validated)
|
| 296 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 297 |
+
|
| 298 |
+
@dataclass
class SQLTemplate:
    """One hand-written, ground-truth SQL query plus its metadata.

    Templates are the ONLY source of SQL in the pipeline; the LLM never
    writes SQL — it only paraphrases `description` into NL questions.
    """

    id: str  # stable unique key, e.g. "easy_001"
    difficulty: str  # easy | medium | hard
    description: str  # plain-English description fed to the LLM
    sql: str  # hand-written SQLite query, execution-validated before use
    # True for templates whose SQL carries ORDER BY (row order matters
    # when grading results).
    order_sensitive: bool = False
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
# NOTE: Every SQL here uses only the 6 tables in the schema and valid SQLite syntax.
|
| 308 |
+
# They are intentionally grouped by the SQL pattern they teach, not just by difficulty.
|
| 309 |
+
|
| 310 |
+
# 30 single-table tasks: equality/range filters, ORDER BY + LIMIT,
# DISTINCT, and scalar aggregates (COUNT/MIN/MAX/AVG/SUM).
# order_sensitive=True exactly where the description pins a sort order.
EASY_TEMPLATES: List[SQLTemplate] = [
    # ── Equality filter ──────────────────────────────────────────────────────
    SQLTemplate(
        id="easy_001",
        difficulty="easy",
        description=(
            "List all gold-tier customers, ordered alphabetically by name. "
            "Return id, name, email, country."
        ),
        sql=(
            "SELECT id, name, email, country "
            "FROM customers "
            "WHERE tier = 'gold' "
            "ORDER BY name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_002",
        difficulty="easy",
        description=(
            "Show all products priced above $100, sorted by price descending. "
            "Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "WHERE price > 100 "
            "ORDER BY price DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_003",
        difficulty="easy",
        description=(
            "Find all delivered orders with a total_amount greater than $200, "
            "sorted by total_amount descending. "
            "Return id, customer_id, total_amount, created_at."
        ),
        sql=(
            "SELECT id, customer_id, total_amount, created_at "
            "FROM orders "
            "WHERE status = 'delivered' AND total_amount > 200 "
            "ORDER BY total_amount DESC"
        ),
        order_sensitive=True,
    ),
    # ── ORDER BY + LIMIT (top-N) ─────────────────────────────────────────────
    SQLTemplate(
        id="easy_004",
        difficulty="easy",
        description=(
            "Return the top 5 most expensive products. Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "ORDER BY price DESC "
            "LIMIT 5"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_005",
        difficulty="easy",
        description=(
            "List all distinct countries where customers come from, sorted alphabetically. "
            "Return a single column: country."
        ),
        sql=(
            "SELECT DISTINCT country "
            "FROM customers "
            "ORDER BY country ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_006",
        difficulty="easy",
        description=(
            "Show all pending orders, ordered by created_at descending. "
            "Return id, customer_id, total_amount, created_at."
        ),
        sql=(
            "SELECT id, customer_id, total_amount, created_at "
            "FROM orders "
            "WHERE status = 'pending' "
            "ORDER BY created_at DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_007",
        difficulty="easy",
        description=(
            "Find all products with zero stock (stock_quantity = 0). "
            "Return id, name, price, category_id."
        ),
        sql=(
            "SELECT id, name, price, category_id "
            "FROM products "
            "WHERE stock_quantity = 0"
        ),
    ),
    # ── Scalar aggregates ────────────────────────────────────────────────────
    SQLTemplate(
        id="easy_008",
        difficulty="easy",
        description=(
            "How many customers are there in total? Return a single value: total_customers."
        ),
        sql="SELECT COUNT(*) AS total_customers FROM customers",
    ),
    SQLTemplate(
        id="easy_009",
        difficulty="easy",
        description=(
            "What is the most expensive product price in the store? "
            "Return a single value: max_price."
        ),
        sql="SELECT MAX(price) AS max_price FROM products",
    ),
    SQLTemplate(
        id="easy_010",
        difficulty="easy",
        description=(
            "What is the cheapest product price in the store? "
            "Return a single value: min_price."
        ),
        sql="SELECT MIN(price) AS min_price FROM products",
    ),
    SQLTemplate(
        id="easy_011",
        difficulty="easy",
        description=(
            "What is the average price of all products? "
            "Round to 2 decimal places. Return: avg_price."
        ),
        sql="SELECT ROUND(AVG(price), 2) AS avg_price FROM products",
    ),
    SQLTemplate(
        id="easy_012",
        difficulty="easy",
        description=(
            "Show all customers from India, sorted by name ascending. "
            "Return id, name, email, tier."
        ),
        sql=(
            "SELECT id, name, email, tier "
            "FROM customers "
            "WHERE country = 'India' "
            "ORDER BY name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_013",
        difficulty="easy",
        description=(
            "List the 10 most recently placed orders. "
            "Return id, customer_id, status, created_at, total_amount."
        ),
        sql=(
            "SELECT id, customer_id, status, created_at, total_amount "
            "FROM orders "
            "ORDER BY created_at DESC "
            "LIMIT 10"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_014",
        difficulty="easy",
        description=(
            "Find all reviews with a rating of 5 stars. "
            "Return id, product_id, customer_id, created_at."
        ),
        sql=(
            "SELECT id, product_id, customer_id, created_at "
            "FROM reviews "
            "WHERE rating = 5"
        ),
    ),
    SQLTemplate(
        id="easy_015",
        difficulty="easy",
        description=(
            "Find all reviews with a rating of 1 star (lowest possible). "
            "Return id, product_id, customer_id, created_at."
        ),
        sql=(
            "SELECT id, product_id, customer_id, created_at "
            "FROM reviews "
            "WHERE rating = 1"
        ),
    ),
    SQLTemplate(
        id="easy_016",
        difficulty="easy",
        description=(
            "Count the number of cancelled orders. Return: cancelled_count."
        ),
        sql=(
            "SELECT COUNT(*) AS cancelled_count "
            "FROM orders "
            "WHERE status = 'cancelled'"
        ),
    ),
    SQLTemplate(
        id="easy_017",
        difficulty="easy",
        description=(
            "List all products with stock_quantity greater than 100, "
            "sorted by stock_quantity descending. Return id, name, stock_quantity."
        ),
        sql=(
            "SELECT id, name, stock_quantity "
            "FROM products "
            "WHERE stock_quantity > 100 "
            "ORDER BY stock_quantity DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_018",
        difficulty="easy",
        description=(
            "Find all silver-tier customers from the USA. "
            "Return id, name, email."
        ),
        sql=(
            "SELECT id, name, email "
            "FROM customers "
            "WHERE tier = 'silver' AND country = 'USA'"
        ),
    ),
    SQLTemplate(
        id="easy_019",
        difficulty="easy",
        description=(
            "What is the total revenue from all delivered orders? "
            "Round to 2 decimal places. Return: total_revenue."
        ),
        sql=(
            "SELECT ROUND(SUM(total_amount), 2) AS total_revenue "
            "FROM orders "
            "WHERE status = 'delivered'"
        ),
    ),
    # ── Date-range filters (half-open [start, next-year) ranges) ─────────────
    SQLTemplate(
        id="easy_020",
        difficulty="easy",
        description=(
            "List all orders placed in 2024, sorted by created_at ascending. "
            "Return id, customer_id, status, total_amount, created_at."
        ),
        sql=(
            "SELECT id, customer_id, status, total_amount, created_at "
            "FROM orders "
            "WHERE created_at >= '2024-01-01' AND created_at < '2025-01-01' "
            "ORDER BY created_at ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_021",
        difficulty="easy",
        description=(
            "Show the bottom 5 cheapest products. Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "ORDER BY price ASC "
            "LIMIT 5"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_022",
        difficulty="easy",
        description=(
            "Count how many products exist in the catalogue. Return: product_count."
        ),
        sql="SELECT COUNT(*) AS product_count FROM products",
    ),
    SQLTemplate(
        id="easy_023",
        difficulty="easy",
        description=(
            "List all distinct order statuses that exist in the orders table. "
            "Return a single column: status."
        ),
        sql="SELECT DISTINCT status FROM orders ORDER BY status ASC",
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_024",
        difficulty="easy",
        description=(
            "Find customers who joined (created_at) in 2023. "
            "Return id, name, country, tier, created_at, sorted by created_at ascending."
        ),
        sql=(
            "SELECT id, name, country, tier, created_at "
            "FROM customers "
            "WHERE created_at >= '2023-01-01' AND created_at < '2024-01-01' "
            "ORDER BY created_at ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_025",
        difficulty="easy",
        description=(
            "Show all orders with total_amount between $50 and $150 inclusive. "
            "Return id, customer_id, total_amount, status."
        ),
        sql=(
            "SELECT id, customer_id, total_amount, status "
            "FROM orders "
            "WHERE total_amount BETWEEN 50 AND 150"
        ),
    ),
    SQLTemplate(
        id="easy_026",
        difficulty="easy",
        description=(
            "How many distinct customers have placed at least one order? "
            "Return a single value: customers_with_orders."
        ),
        sql=(
            "SELECT COUNT(DISTINCT customer_id) AS customers_with_orders "
            "FROM orders"
        ),
    ),
    SQLTemplate(
        id="easy_027",
        difficulty="easy",
        description=(
            "What is the total number of order line items across all orders? "
            "Return: total_line_items."
        ),
        sql="SELECT COUNT(*) AS total_line_items FROM order_items",
    ),
    SQLTemplate(
        id="easy_028",
        difficulty="easy",
        description=(
            "List all products priced between $20 and $80 inclusive, sorted by price ascending. "
            "Return id, name, price."
        ),
        sql=(
            "SELECT id, name, price "
            "FROM products "
            "WHERE price BETWEEN 20 AND 80 "
            "ORDER BY price ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="easy_029",
        difficulty="easy",
        description=(
            "Show all gold-tier customers from Germany. "
            "Return id, name, email, created_at."
        ),
        sql=(
            "SELECT id, name, email, created_at "
            "FROM customers "
            "WHERE tier = 'gold' AND country = 'Germany'"
        ),
    ),
    SQLTemplate(
        id="easy_030",
        difficulty="easy",
        description=(
            "What is the average rating across all reviews in the system? "
            "Round to 2 decimal places. Return: avg_rating."
        ),
        sql="SELECT ROUND(AVG(rating), 2) AS avg_rating FROM reviews",
    ),
]
|
| 692 |
+
|
| 693 |
+
# 25 multi-table tasks: INNER/LEFT JOINs, GROUP BY (grouping by the key
# plus the selected name to be portable), HAVING filters, anti-joins
# (LEFT JOIN ... IS NULL), and one subquery-of-aggregates.
MEDIUM_TEMPLATES: List[SQLTemplate] = [
    # ── JOIN + COUNT ─────────────────────────────────────────────────────────
    SQLTemplate(
        id="med_001",
        difficulty="medium",
        description=(
            "How many orders has each customer placed? Include customers with zero orders. "
            "Return customer_name and order_count. Sort by order_count descending, "
            "then customer_name ascending."
        ),
        sql=(
            "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
            "FROM customers c "
            "LEFT JOIN orders o ON c.id = o.customer_id "
            "GROUP BY c.id, c.name "
            "ORDER BY order_count DESC, customer_name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_002",
        difficulty="medium",
        description=(
            "Average product rating per category, only for categories that have at least one review. "
            "Return category_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending."
        ),
        sql=(
            "SELECT c.name AS category_name, ROUND(AVG(r.rating), 2) AS avg_rating "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY c.id, c.name "
            "ORDER BY avg_rating DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_003",
        difficulty="medium",
        description=(
            "Which categories have more than 5 in-stock products (stock_quantity > 0)? "
            "Return category_name and in_stock_count. Sort by in_stock_count descending."
        ),
        sql=(
            "SELECT c.name AS category_name, COUNT(p.id) AS in_stock_count "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "WHERE p.stock_quantity > 0 "
            "GROUP BY c.id, c.name "
            "HAVING COUNT(p.id) > 5 "
            "ORDER BY in_stock_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_004",
        difficulty="medium",
        description=(
            "Which customers have spent more than $500 on delivered orders? "
            "Return customer_name and total_spent (rounded to 2 dp). Sort by total_spent descending."
        ),
        sql=(
            "SELECT c.name AS customer_name, ROUND(SUM(o.total_amount), 2) AS total_spent "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "WHERE o.status = 'delivered' "
            "GROUP BY c.id, c.name "
            "HAVING SUM(o.total_amount) > 500 "
            "ORDER BY total_spent DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_005",
        difficulty="medium",
        description=(
            "Total quantity sold for each product that appears in at least one order. "
            "Return product_name and total_quantity_sold. Sort by total_quantity_sold descending."
        ),
        sql=(
            "SELECT p.name AS product_name, SUM(oi.quantity) AS total_quantity_sold "
            "FROM products p "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY total_quantity_sold DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_006",
        difficulty="medium",
        description=(
            "Number of reviews per product, only for products with at least 3 reviews. "
            "Return product_name and review_count. Sort by review_count descending."
        ),
        sql=(
            "SELECT p.name AS product_name, COUNT(r.id) AS review_count "
            "FROM products p "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY p.id, p.name "
            "HAVING COUNT(r.id) >= 3 "
            "ORDER BY review_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_007",
        difficulty="medium",
        description=(
            "Show the total revenue (sum of total_amount) per country from all orders, "
            "regardless of status. Return country and total_revenue (rounded to 2 dp). "
            "Sort by total_revenue descending."
        ),
        sql=(
            "SELECT c.country, ROUND(SUM(o.total_amount), 2) AS total_revenue "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "GROUP BY c.country "
            "ORDER BY total_revenue DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_008",
        difficulty="medium",
        description=(
            "For each customer tier (bronze, silver, gold) show the average order value "
            "from delivered orders. Return tier and avg_order_value (rounded to 2 dp). "
            "Sort by avg_order_value descending."
        ),
        sql=(
            "SELECT c.tier, ROUND(AVG(o.total_amount), 2) AS avg_order_value "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "WHERE o.status = 'delivered' "
            "GROUP BY c.tier "
            "ORDER BY avg_order_value DESC"
        ),
        order_sensitive=True,
    ),
    # ── Anti-join: rows with no match in the child table ─────────────────────
    SQLTemplate(
        id="med_009",
        difficulty="medium",
        description=(
            "Which products have never been ordered? "
            "Return id and name, sorted by name ascending."
        ),
        sql=(
            "SELECT p.id, p.name "
            "FROM products p "
            "LEFT JOIN order_items oi ON oi.product_id = p.id "
            "WHERE oi.id IS NULL "
            "ORDER BY p.name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_010",
        difficulty="medium",
        description=(
            "Number of orders per status. "
            "Return status and order_count. Sort by order_count descending."
        ),
        sql=(
            "SELECT status, COUNT(*) AS order_count "
            "FROM orders "
            "GROUP BY status "
            "ORDER BY order_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_011",
        difficulty="medium",
        description=(
            "Show the total number of products per category. "
            "Return category_name and product_count. Sort by product_count descending."
        ),
        sql=(
            "SELECT c.name AS category_name, COUNT(p.id) AS product_count "
            "FROM categories c "
            "LEFT JOIN products p ON p.category_id = c.id "
            "GROUP BY c.id, c.name "
            "ORDER BY product_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_012",
        difficulty="medium",
        description=(
            "Average rating per product for products with at least one review. "
            "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating descending."
        ),
        sql=(
            "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating "
            "FROM products p "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY avg_rating DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_013",
        difficulty="medium",
        description=(
            "Which gold-tier customers have placed more than 3 orders? "
            "Return customer_name and order_count. Sort by order_count descending."
        ),
        sql=(
            "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "WHERE c.tier = 'gold' "
            "GROUP BY c.id, c.name "
            "HAVING COUNT(o.id) > 3 "
            "ORDER BY order_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_014",
        difficulty="medium",
        description=(
            "Total quantity of each product ordered via order_items. "
            "Return product_name and total_units. Sort by total_units descending."
        ),
        sql=(
            "SELECT p.name AS product_name, SUM(oi.quantity) AS total_units "
            "FROM products p "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY total_units DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_015",
        difficulty="medium",
        description=(
            "For each country, count the number of gold-tier customers. "
            "Only show countries with at least one gold-tier customer. "
            "Return country and gold_count. Sort by gold_count descending."
        ),
        sql=(
            "SELECT country, COUNT(*) AS gold_count "
            "FROM customers "
            "WHERE tier = 'gold' "
            "GROUP BY country "
            "HAVING COUNT(*) >= 1 "
            "ORDER BY gold_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_016",
        difficulty="medium",
        description=(
            "Show how many reviews each customer has submitted. Only include customers "
            "who have submitted at least one review. Return customer_name and review_count. "
            "Sort by review_count descending."
        ),
        sql=(
            "SELECT c.name AS customer_name, COUNT(r.id) AS review_count "
            "FROM customers c "
            "JOIN reviews r ON r.customer_id = c.id "
            "GROUP BY c.id, c.name "
            "ORDER BY review_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_017",
        difficulty="medium",
        description=(
            "Total revenue generated from order_items (quantity * unit_price) per category. "
            "Return category_name and category_revenue (rounded to 2 dp). "
            "Sort by category_revenue descending."
        ),
        sql=(
            "SELECT c.name AS category_name, "
            "       ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY c.id, c.name "
            "ORDER BY category_revenue DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_018",
        difficulty="medium",
        description=(
            "Which products have an average rating strictly below 3? "
            "Return product_name and avg_rating (rounded to 2 dp). Sort by avg_rating ascending."
        ),
        sql=(
            "SELECT p.name AS product_name, ROUND(AVG(r.rating), 2) AS avg_rating "
            "FROM products p "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY p.id, p.name "
            "HAVING AVG(r.rating) < 3 "
            "ORDER BY avg_rating ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_019",
        difficulty="medium",
        description=(
            "Find the maximum order value for each customer tier. "
            "Return tier and max_order_value (rounded to 2 dp). Sort by max_order_value descending."
        ),
        sql=(
            "SELECT c.tier, ROUND(MAX(o.total_amount), 2) AS max_order_value "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "GROUP BY c.tier "
            "ORDER BY max_order_value DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_020",
        difficulty="medium",
        description=(
            "How many customers per country have placed at least one delivered order? "
            "Return country and customer_count. Sort by customer_count descending."
        ),
        sql=(
            "SELECT c.country, COUNT(DISTINCT c.id) AS customer_count "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "WHERE o.status = 'delivered' "
            "GROUP BY c.country "
            "ORDER BY customer_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_021",
        difficulty="medium",
        description=(
            "List all products together with their category name. "
            "Return product_name, category_name, price. Sort by category_name, then price ascending."
        ),
        sql=(
            "SELECT p.name AS product_name, c.name AS category_name, p.price "
            "FROM products p "
            "JOIN categories c ON c.id = p.category_id "
            "ORDER BY category_name ASC, p.price ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_022",
        difficulty="medium",
        description=(
            "For each order, show the total number of line items it contains. "
            "Return order_id and line_item_count. Sort by line_item_count descending."
        ),
        sql=(
            "SELECT order_id, COUNT(*) AS line_item_count "
            "FROM order_items "
            "GROUP BY order_id "
            "ORDER BY line_item_count DESC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_023",
        difficulty="medium",
        description=(
            "Show the minimum and maximum product price per category. "
            "Return category_name, min_price, max_price. Sort by category_name ascending."
        ),
        sql=(
            "SELECT c.name AS category_name, "
            "       ROUND(MIN(p.price), 2) AS min_price, "
            "       ROUND(MAX(p.price), 2) AS max_price "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "GROUP BY c.id, c.name "
            "ORDER BY category_name ASC"
        ),
        order_sensitive=True,
    ),
    SQLTemplate(
        id="med_024",
        difficulty="medium",
        description=(
            "Find customers who have given a rating of 5 to at least one product. "
            "Return customer_name and five_star_count. Sort by five_star_count descending."
        ),
        sql=(
            "SELECT c.name AS customer_name, COUNT(r.id) AS five_star_count "
            "FROM customers c "
            "JOIN reviews r ON r.customer_id = c.id "
            "WHERE r.rating = 5 "
            "GROUP BY c.id, c.name "
            "ORDER BY five_star_count DESC"
        ),
        order_sensitive=True,
    ),
    # ── Aggregate over an aggregate: average of per-order counts ─────────────
    SQLTemplate(
        id="med_025",
        difficulty="medium",
        description=(
            "Show the average number of items per order across all orders. "
            "Round to 2 decimal places. Return: avg_items_per_order."
        ),
        sql=(
            "SELECT ROUND(AVG(item_count), 2) AS avg_items_per_order "
            "FROM ( "
            "    SELECT order_id, COUNT(*) AS item_count "
            "    FROM order_items "
            "    GROUP BY order_id "
            ")"
        ),
    ),
]
|
| 1116 |
+
|
| 1117 |
+
HARD_TEMPLATES: List[SQLTemplate] = [
|
| 1118 |
+
# ── Window functions ────────────────────────────────────────────────���────
|
| 1119 |
+
SQLTemplate(
|
| 1120 |
+
id="hard_001",
|
| 1121 |
+
difficulty="hard",
|
| 1122 |
+
description=(
|
| 1123 |
+
"Rank customers by total spending on delivered orders using DENSE_RANK "
|
| 1124 |
+
"(rank 1 = highest spender). "
|
| 1125 |
+
"Return customer_name, total_spent (rounded to 2 dp), spending_rank. "
|
| 1126 |
+
"Sort by spending_rank ascending."
|
| 1127 |
+
),
|
| 1128 |
+
sql=(
|
| 1129 |
+
"SELECT customer_name, total_spent, spending_rank "
|
| 1130 |
+
"FROM ( "
|
| 1131 |
+
" SELECT c.name AS customer_name, "
|
| 1132 |
+
" ROUND(SUM(o.total_amount), 2) AS total_spent, "
|
| 1133 |
+
" DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
|
| 1134 |
+
" FROM customers c "
|
| 1135 |
+
" JOIN orders o ON o.customer_id = c.id "
|
| 1136 |
+
" WHERE o.status = 'delivered' "
|
| 1137 |
+
" GROUP BY c.id, c.name "
|
| 1138 |
+
") sub "
|
| 1139 |
+
"ORDER BY spending_rank ASC"
|
| 1140 |
+
),
|
| 1141 |
+
order_sensitive=True,
|
| 1142 |
+
),
|
| 1143 |
+
SQLTemplate(
|
| 1144 |
+
id="hard_002",
|
| 1145 |
+
difficulty="hard",
|
| 1146 |
+
description=(
|
| 1147 |
+
"For each reviewed product, show its own average rating and the average rating "
|
| 1148 |
+
"of all products in its category (partition window). "
|
| 1149 |
+
"Return product_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). "
|
| 1150 |
+
"Sort by product_avg_rating descending."
|
| 1151 |
+
),
|
| 1152 |
+
sql=(
|
| 1153 |
+
"SELECT p.name AS product_name, "
|
| 1154 |
+
" ROUND(AVG(r.rating), 2) AS product_avg_rating, "
|
| 1155 |
+
" ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) AS category_avg_rating "
|
| 1156 |
+
"FROM products p "
|
| 1157 |
+
"JOIN reviews r ON r.product_id = p.id "
|
| 1158 |
+
"GROUP BY p.id, p.name, p.category_id "
|
| 1159 |
+
"ORDER BY product_avg_rating DESC"
|
| 1160 |
+
),
|
| 1161 |
+
order_sensitive=True,
|
| 1162 |
+
),
|
| 1163 |
+
SQLTemplate(
|
| 1164 |
+
id="hard_003",
|
| 1165 |
+
difficulty="hard",
|
| 1166 |
+
description=(
|
| 1167 |
+
"Find all customers whose most recent order has status 'cancelled'. "
|
| 1168 |
+
"Use a CTE with ROW_NUMBER partitioned by customer_id ordered by created_at DESC. "
|
| 1169 |
+
"Return customer_name, last_order_status, last_order_date. Sort by customer_name ascending."
|
| 1170 |
+
),
|
| 1171 |
+
sql=(
|
| 1172 |
+
"WITH ranked_orders AS ( "
|
| 1173 |
+
" SELECT customer_id, status, created_at, "
|
| 1174 |
+
" ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at DESC) AS rn "
|
| 1175 |
+
" FROM orders "
|
| 1176 |
+
") "
|
| 1177 |
+
"SELECT c.name AS customer_name, "
|
| 1178 |
+
" ro.status AS last_order_status, "
|
| 1179 |
+
" ro.created_at AS last_order_date "
|
| 1180 |
+
"FROM customers c "
|
| 1181 |
+
"JOIN ranked_orders ro ON ro.customer_id = c.id "
|
| 1182 |
+
"WHERE ro.rn = 1 AND ro.status = 'cancelled' "
|
| 1183 |
+
"ORDER BY customer_name ASC"
|
| 1184 |
+
),
|
| 1185 |
+
order_sensitive=True,
|
| 1186 |
+
),
|
| 1187 |
+
SQLTemplate(
|
| 1188 |
+
id="hard_004",
|
| 1189 |
+
difficulty="hard",
|
| 1190 |
+
description=(
|
| 1191 |
+
"Monthly revenue from delivered orders and its running total for all months in 2024. "
|
| 1192 |
+
"Return month (YYYY-MM format), monthly_revenue, running_total (both rounded to 2 dp). "
|
| 1193 |
+
"Sort by month ascending."
|
| 1194 |
+
),
|
| 1195 |
+
sql=(
|
| 1196 |
+
"WITH monthly AS ( "
|
| 1197 |
+
" SELECT strftime('%Y-%m', created_at) AS month, "
|
| 1198 |
+
" ROUND(SUM(total_amount), 2) AS monthly_revenue "
|
| 1199 |
+
" FROM orders "
|
| 1200 |
+
" WHERE status = 'delivered' "
|
| 1201 |
+
" AND created_at >= '2024-01-01' AND created_at < '2025-01-01' "
|
| 1202 |
+
" GROUP BY strftime('%Y-%m', created_at) "
|
| 1203 |
+
") "
|
| 1204 |
+
"SELECT month, monthly_revenue, "
|
| 1205 |
+
" ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
|
| 1206 |
+
"FROM monthly "
|
| 1207 |
+
"ORDER BY month ASC"
|
| 1208 |
+
),
|
| 1209 |
+
order_sensitive=True,
|
| 1210 |
+
),
|
| 1211 |
+
SQLTemplate(
|
| 1212 |
+
id="hard_005",
|
| 1213 |
+
difficulty="hard",
|
| 1214 |
+
description=(
|
| 1215 |
+
"Find products whose average rating is strictly above the average rating of all products "
|
| 1216 |
+
"in their category. Use two CTEs: one for product-level averages and one for category-level. "
|
| 1217 |
+
"Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 dp). "
|
| 1218 |
+
"Sort by product_avg_rating descending, then product_name ascending."
|
| 1219 |
+
),
|
| 1220 |
+
sql=(
|
| 1221 |
+
"WITH product_ratings AS ( "
|
| 1222 |
+
" SELECT p.id AS product_id, p.name AS product_name, "
|
| 1223 |
+
" p.category_id, c.name AS category_name, "
|
| 1224 |
+
" ROUND(AVG(r.rating), 2) AS product_avg_rating "
|
| 1225 |
+
" FROM products p "
|
| 1226 |
+
" JOIN reviews r ON r.product_id = p.id "
|
| 1227 |
+
" JOIN categories c ON c.id = p.category_id "
|
| 1228 |
+
" GROUP BY p.id, p.name, p.category_id, c.name "
|
| 1229 |
+
"), "
|
| 1230 |
+
"category_ratings AS ( "
|
| 1231 |
+
" SELECT category_id, ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
|
| 1232 |
+
" FROM product_ratings "
|
| 1233 |
+
" GROUP BY category_id "
|
| 1234 |
+
") "
|
| 1235 |
+
"SELECT pr.product_name, pr.category_name, "
|
| 1236 |
+
" pr.product_avg_rating, cr.category_avg_rating "
|
| 1237 |
+
"FROM product_ratings pr "
|
| 1238 |
+
"JOIN category_ratings cr ON cr.category_id = pr.category_id "
|
| 1239 |
+
"WHERE pr.product_avg_rating > cr.category_avg_rating "
|
| 1240 |
+
"ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
|
| 1241 |
+
),
|
| 1242 |
+
order_sensitive=True,
|
| 1243 |
+
),
|
| 1244 |
+
SQLTemplate(
|
| 1245 |
+
id="hard_006",
|
| 1246 |
+
difficulty="hard",
|
| 1247 |
+
description=(
|
| 1248 |
+
"For each customer, find their very first order date using ROW_NUMBER in a CTE. "
|
| 1249 |
+
"Return customer_name and first_order_date. Sort by first_order_date ascending."
|
| 1250 |
+
),
|
| 1251 |
+
sql=(
|
| 1252 |
+
"WITH first_orders AS ( "
|
| 1253 |
+
" SELECT customer_id, created_at, "
|
| 1254 |
+
" ROW_NUMBER() OVER (PARTITION BY customer_id ORDER BY created_at ASC) AS rn "
|
| 1255 |
+
" FROM orders "
|
| 1256 |
+
") "
|
| 1257 |
+
"SELECT c.name AS customer_name, fo.created_at AS first_order_date "
|
| 1258 |
+
"FROM customers c "
|
| 1259 |
+
"JOIN first_orders fo ON fo.customer_id = c.id "
|
| 1260 |
+
"WHERE fo.rn = 1 "
|
| 1261 |
+
"ORDER BY first_order_date ASC"
|
| 1262 |
+
),
|
| 1263 |
+
order_sensitive=True,
|
| 1264 |
+
),
|
| 1265 |
+
SQLTemplate(
|
| 1266 |
+
id="hard_007",
|
| 1267 |
+
difficulty="hard",
|
| 1268 |
+
description=(
|
| 1269 |
+
"Rank products by total revenue generated (quantity * unit_price from order_items) "
|
| 1270 |
+
"using RANK() window function. "
|
| 1271 |
+
"Return product_name, total_revenue (rounded to 2 dp), revenue_rank. "
|
| 1272 |
+
"Sort by revenue_rank ascending."
|
| 1273 |
+
),
|
| 1274 |
+
sql=(
|
| 1275 |
+
"SELECT product_name, total_revenue, revenue_rank "
|
| 1276 |
+
"FROM ( "
|
| 1277 |
+
" SELECT p.name AS product_name, "
|
| 1278 |
+
" ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue, "
|
| 1279 |
+
" RANK() OVER (ORDER BY SUM(oi.quantity * oi.unit_price) DESC) AS revenue_rank "
|
| 1280 |
+
" FROM products p "
|
| 1281 |
+
" JOIN order_items oi ON oi.product_id = p.id "
|
| 1282 |
+
" GROUP BY p.id, p.name "
|
| 1283 |
+
") sub "
|
| 1284 |
+
"ORDER BY revenue_rank ASC"
|
| 1285 |
+
),
|
| 1286 |
+
order_sensitive=True,
|
| 1287 |
+
),
|
| 1288 |
+
SQLTemplate(
|
| 1289 |
+
id="hard_008",
|
| 1290 |
+
difficulty="hard",
|
| 1291 |
+
description=(
|
| 1292 |
+
"For each customer, compute the running total of their order amounts ordered by "
|
| 1293 |
+
"created_at. Return customer_name, order_date (created_at), order_amount (total_amount), "
|
| 1294 |
+
"running_total (rounded to 2 dp). Sort by customer_name, order_date ascending."
|
| 1295 |
+
),
|
| 1296 |
+
sql=(
|
| 1297 |
+
"SELECT c.name AS customer_name, "
|
| 1298 |
+
" o.created_at AS order_date, "
|
| 1299 |
+
" o.total_amount AS order_amount, "
|
| 1300 |
+
" ROUND(SUM(o.total_amount) OVER "
|
| 1301 |
+
" (PARTITION BY c.id ORDER BY o.created_at "
|
| 1302 |
+
" ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW), 2) AS running_total "
|
| 1303 |
+
"FROM customers c "
|
| 1304 |
+
"JOIN orders o ON o.customer_id = c.id "
|
| 1305 |
+
"ORDER BY customer_name ASC, order_date ASC"
|
| 1306 |
+
),
|
| 1307 |
+
order_sensitive=True,
|
| 1308 |
+
),
|
| 1309 |
+
SQLTemplate(
|
| 1310 |
+
id="hard_009",
|
| 1311 |
+
difficulty="hard",
|
| 1312 |
+
description=(
|
| 1313 |
+
"Find customers who have placed orders in every status "
|
| 1314 |
+
"(pending, processing, shipped, delivered, cancelled) at least once. "
|
| 1315 |
+
"Return customer_name and status_count. Sort by customer_name ascending."
|
| 1316 |
+
),
|
| 1317 |
+
sql=(
|
| 1318 |
+
"SELECT c.name AS customer_name, COUNT(DISTINCT o.status) AS status_count "
|
| 1319 |
+
"FROM customers c "
|
| 1320 |
+
"JOIN orders o ON o.customer_id = c.id "
|
| 1321 |
+
"GROUP BY c.id, c.name "
|
| 1322 |
+
"HAVING COUNT(DISTINCT o.status) = 5 "
|
| 1323 |
+
"ORDER BY customer_name ASC"
|
| 1324 |
+
),
|
| 1325 |
+
order_sensitive=True,
|
| 1326 |
+
),
|
| 1327 |
+
SQLTemplate(
|
| 1328 |
+
id="hard_010",
|
| 1329 |
+
difficulty="hard",
|
| 1330 |
+
description=(
|
| 1331 |
+
"Using a CTE, compute the total revenue per product, then rank the top 3 products "
|
| 1332 |
+
"in each category by revenue using DENSE_RANK. Only return rows with rank <= 3. "
|
| 1333 |
+
"Return category_name, product_name, total_revenue (rounded to 2 dp), rank_in_category. "
|
| 1334 |
+
"Sort by category_name, rank_in_category ascending."
|
| 1335 |
+
),
|
| 1336 |
+
sql=(
|
| 1337 |
+
"WITH product_rev AS ( "
|
| 1338 |
+
" SELECT p.id, p.name AS product_name, p.category_id, "
|
| 1339 |
+
" c.name AS category_name, "
|
| 1340 |
+
" ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue "
|
| 1341 |
+
" FROM products p "
|
| 1342 |
+
" JOIN categories c ON c.id = p.category_id "
|
| 1343 |
+
" JOIN order_items oi ON oi.product_id = p.id "
|
| 1344 |
+
" GROUP BY p.id, p.name, p.category_id, c.name "
|
| 1345 |
+
"), "
|
| 1346 |
+
"ranked AS ( "
|
| 1347 |
+
" SELECT product_name, category_name, total_revenue, "
|
| 1348 |
+
" DENSE_RANK() OVER (PARTITION BY category_id ORDER BY total_revenue DESC) AS rank_in_category "
|
| 1349 |
+
" FROM product_rev "
|
| 1350 |
+
") "
|
| 1351 |
+
"SELECT category_name, product_name, total_revenue, rank_in_category "
|
| 1352 |
+
"FROM ranked "
|
| 1353 |
+
"WHERE rank_in_category <= 3 "
|
| 1354 |
+
"ORDER BY category_name ASC, rank_in_category ASC"
|
| 1355 |
+
),
|
| 1356 |
+
order_sensitive=True,
|
| 1357 |
+
),
|
| 1358 |
+
SQLTemplate(
|
| 1359 |
+
id="hard_011",
|
| 1360 |
+
difficulty="hard",
|
| 1361 |
+
description=(
|
| 1362 |
+
"Compute the percentage of total revenue each category contributes. "
|
| 1363 |
+
"Use a CTE for category revenues and a window SUM for the grand total. "
|
| 1364 |
+
"Return category_name, category_revenue, pct_of_total (rounded to 2 dp). "
|
| 1365 |
+
"Sort by pct_of_total descending."
|
| 1366 |
+
),
|
| 1367 |
+
sql=(
|
| 1368 |
+
"WITH cat_rev AS ( "
|
| 1369 |
+
" SELECT c.name AS category_name, "
|
| 1370 |
+
" ROUND(SUM(oi.quantity * oi.unit_price), 2) AS category_revenue "
|
| 1371 |
+
" FROM categories c "
|
| 1372 |
+
" JOIN products p ON p.category_id = c.id "
|
| 1373 |
+
" JOIN order_items oi ON oi.product_id = p.id "
|
| 1374 |
+
" GROUP BY c.id, c.name "
|
| 1375 |
+
") "
|
| 1376 |
+
"SELECT category_name, category_revenue, "
|
| 1377 |
+
" ROUND(100.0 * category_revenue / SUM(category_revenue) OVER (), 2) AS pct_of_total "
|
| 1378 |
+
"FROM cat_rev "
|
| 1379 |
+
"ORDER BY pct_of_total DESC"
|
| 1380 |
+
),
|
| 1381 |
+
order_sensitive=True,
|
| 1382 |
+
),
|
| 1383 |
+
SQLTemplate(
|
| 1384 |
+
id="hard_012",
|
| 1385 |
+
difficulty="hard",
|
| 1386 |
+
description=(
|
| 1387 |
+
"Find the customers who placed the highest number of orders in 2023. "
|
| 1388 |
+
"Use a CTE to count per-customer orders in 2023, then apply DENSE_RANK. "
|
| 1389 |
+
"Return customer_name, order_count_2023, rank. Sort by rank, then customer_name."
|
| 1390 |
+
),
|
| 1391 |
+
sql=(
|
| 1392 |
+
"WITH counts_2023 AS ( "
|
| 1393 |
+
" SELECT c.name AS customer_name, COUNT(o.id) AS order_count_2023 "
|
| 1394 |
+
" FROM customers c "
|
| 1395 |
+
" JOIN orders o ON o.customer_id = c.id "
|
| 1396 |
+
" WHERE o.created_at >= '2023-01-01' AND o.created_at < '2024-01-01' "
|
| 1397 |
+
" GROUP BY c.id, c.name "
|
| 1398 |
+
") "
|
| 1399 |
+
"SELECT customer_name, order_count_2023, "
|
| 1400 |
+
" DENSE_RANK() OVER (ORDER BY order_count_2023 DESC) AS rank "
|
| 1401 |
+
"FROM counts_2023 "
|
| 1402 |
+
"ORDER BY rank ASC, customer_name ASC"
|
| 1403 |
+
),
|
| 1404 |
+
order_sensitive=True,
|
| 1405 |
+
),
|
| 1406 |
+
SQLTemplate(
|
| 1407 |
+
id="hard_013",
|
| 1408 |
+
difficulty="hard",
|
| 1409 |
+
description=(
|
| 1410 |
+
"Show a quarterly revenue breakdown for delivered orders across all years. "
|
| 1411 |
+
"Use strftime to derive year and quarter. "
|
| 1412 |
+
"Return year, quarter, quarterly_revenue (rounded to 2 dp), "
|
| 1413 |
+
"and running_total_in_year (running SUM within the same year, rounded to 2 dp). "
|
| 1414 |
+
"Sort by year, quarter ascending."
|
| 1415 |
+
),
|
| 1416 |
+
sql=(
|
| 1417 |
+
"WITH quarterly AS ( "
|
| 1418 |
+
" SELECT strftime('%Y', created_at) AS year, "
|
| 1419 |
+
" ((CAST(strftime('%m', created_at) AS INTEGER) - 1) / 3 + 1) AS quarter, "
|
| 1420 |
+
" ROUND(SUM(total_amount), 2) AS quarterly_revenue "
|
| 1421 |
+
" FROM orders "
|
| 1422 |
+
" WHERE status = 'delivered' "
|
| 1423 |
+
" GROUP BY year, quarter "
|
| 1424 |
+
") "
|
| 1425 |
+
"SELECT year, quarter, quarterly_revenue, "
|
| 1426 |
+
" ROUND(SUM(quarterly_revenue) OVER (PARTITION BY year ORDER BY quarter), 2) AS running_total_in_year "
|
| 1427 |
+
"FROM quarterly "
|
| 1428 |
+
"ORDER BY year ASC, quarter ASC"
|
| 1429 |
+
),
|
| 1430 |
+
order_sensitive=True,
|
| 1431 |
+
),
|
| 1432 |
+
SQLTemplate(
|
| 1433 |
+
id="hard_014",
|
| 1434 |
+
difficulty="hard",
|
| 1435 |
+
description=(
|
| 1436 |
+
"Find the top-spending customer in each country using ROW_NUMBER. "
|
| 1437 |
+
"Return country, customer_name, total_spent (rounded to 2 dp). "
|
| 1438 |
+
"Sort by country, total_spent descending."
|
| 1439 |
+
),
|
| 1440 |
+
sql=(
|
| 1441 |
+
"WITH customer_spend AS ( "
|
| 1442 |
+
" SELECT c.id, c.name AS customer_name, c.country, "
|
| 1443 |
+
" ROUND(SUM(o.total_amount), 2) AS total_spent "
|
| 1444 |
+
" FROM customers c "
|
| 1445 |
+
" JOIN orders o ON o.customer_id = c.id "
|
| 1446 |
+
" GROUP BY c.id, c.name, c.country "
|
| 1447 |
+
"), "
|
| 1448 |
+
"ranked AS ( "
|
| 1449 |
+
" SELECT country, customer_name, total_spent, "
|
| 1450 |
+
" ROW_NUMBER() OVER (PARTITION BY country ORDER BY total_spent DESC) AS rn "
|
| 1451 |
+
" FROM customer_spend "
|
| 1452 |
+
") "
|
| 1453 |
+
"SELECT country, customer_name, total_spent "
|
| 1454 |
+
"FROM ranked "
|
| 1455 |
+
"WHERE rn = 1 "
|
| 1456 |
+
"ORDER BY country ASC"
|
| 1457 |
+
),
|
| 1458 |
+
order_sensitive=True,
|
| 1459 |
+
),
|
| 1460 |
+
SQLTemplate(
|
| 1461 |
+
id="hard_015",
|
| 1462 |
+
difficulty="hard",
|
| 1463 |
+
description=(
|
| 1464 |
+
"Find products that have received both 1-star and 5-star reviews. "
|
| 1465 |
+
"Use two CTEs: one for 1-star products, one for 5-star products, then intersect. "
|
| 1466 |
+
"Return product_name. Sort by product_name ascending."
|
| 1467 |
+
),
|
| 1468 |
+
sql=(
|
| 1469 |
+
"WITH one_star AS ( "
|
| 1470 |
+
" SELECT DISTINCT product_id FROM reviews WHERE rating = 1 "
|
| 1471 |
+
"), "
|
| 1472 |
+
"five_star AS ( "
|
| 1473 |
+
" SELECT DISTINCT product_id FROM reviews WHERE rating = 5 "
|
| 1474 |
+
") "
|
| 1475 |
+
"SELECT p.name AS product_name "
|
| 1476 |
+
"FROM products p "
|
| 1477 |
+
"JOIN one_star os ON os.product_id = p.id "
|
| 1478 |
+
"JOIN five_star fs ON fs.product_id = p.id "
|
| 1479 |
+
"ORDER BY product_name ASC"
|
| 1480 |
+
),
|
| 1481 |
+
order_sensitive=True,
|
| 1482 |
+
),
|
| 1483 |
+
]
|
| 1484 |
+
|
| 1485 |
+
# Master pool of SQL templates across all difficulty tiers. Each entry is
# re-validated against the live seeded DB by DataFactory.validate_templates
# before any question generation happens.
ALL_TEMPLATES: List[SQLTemplate] = EASY_TEMPLATES + MEDIUM_TEMPLATES + HARD_TEMPLATES
|
| 1486 |
+
|
| 1487 |
+
|
| 1488 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1489 |
+
# Personas
|
| 1490 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1491 |
+
|
| 1492 |
+
SCHEMA_CONTEXT = """
|
| 1493 |
+
DATABASE SCHEMA (SQLite e-commerce):
|
| 1494 |
+
categories(id, name)
|
| 1495 |
+
products(id, name, category_id, price, stock_quantity)
|
| 1496 |
+
customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
|
| 1497 |
+
orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
|
| 1498 |
+
created_at, total_amount)
|
| 1499 |
+
order_items(id, order_id, product_id, quantity, unit_price)
|
| 1500 |
+
reviews(id, product_id, customer_id, rating∈1-5, created_at)
|
| 1501 |
+
"""
|
| 1502 |
+
|
| 1503 |
+
# Persona style instructions, keyed by persona name. The key is embedded in
# each request id ("<template_id>__<persona>") and stored on every DataPoint;
# the value becomes part of the system prompt when asking the LLM to phrase a
# SQL template as a natural-language question in that persona's voice.
PERSONA_SPECS = {
    "ceo": (
        "You are a senior business executive. Write one SHORT, direct question in active voice, "
        "as if you are asking an analyst to pull a number fast. Be terse, no fluff. "
        "Use business language: 'revenue', 'customers', 'performance', not technical SQL terms."
    ),
    "chatty": (
        "You are a friendly but verbose non-technical employee. Write one long, conversational "
        "question with filler phrases like 'Could you please tell me...', 'I was wondering if...', "
        "passive voice is fine. Use everyday words like 'money' instead of 'revenue', "
        "'people' instead of 'customers'."
    ),
    "lazy": (
        "You are typing quickly on a phone. Write an extremely short question with abbreviations, "
        "lowercase letters, and minor spelling mistakes. Skip articles and punctuation where possible. "
        "Example style: 'top 5 prods by sales?', 'hw many cust in usa'."
    ),
    "confused": (
        "You are a non-technical user who is unsure of the exact terminology. Write one question "
        "using synonyms and vague language. Replace 'revenue' with 'money made', 'customers' with "
        "'people' or 'users' or 'accounts', 'orders' with 'purchases' or 'transactions', "
        "'tier' with 'membership level'. Include a bit of ambiguity."
    ),
    "analyst": (
        "You are a data analyst with technical knowledge. Write one precise, jargon-heavy question "
        "using terms like 'aggregate', 'partition', 'metric', 'fiscal period', 'segmented by', "
        "'cohort', 'granularity'. Be specific about column names and filters."
    ),
}
|
| 1532 |
+
|
| 1533 |
+
|
| 1534 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1535 |
+
# Rule-based Augmentor
|
| 1536 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1537 |
+
|
| 1538 |
+
class RuleAugmentor:
    """
    Applies deterministic, non-LLM transformations to a generated NL question.

    Each call walks the question word by word and, with probability 0.4 per
    eligible word, swaps it for a random synonym while keeping any punctuation
    attached to the word. Returns None if no rule applied.
    """

    # Lowercase, punctuation-stripped word -> candidate replacements.
    SYNONYMS: Dict[str, List[str]] = {
        "customers": ["clients", "users", "accounts", "shoppers", "buyers"],
        "orders": ["purchases", "transactions", "sales", "bookings"],
        "products": ["items", "goods", "listings", "SKUs"],
        "revenue": ["sales", "income", "earnings", "money made"],
        "spending": ["expenditure", "purchases", "money spent"],
        "delivered": ["completed", "fulfilled", "received"],
        "cancelled": ["canceled", "voided", "aborted"],
        "pending": ["waiting", "unprocessed", "queued"],
        "gold": ["premium", "top-tier", "VIP", "platinum"],
        "silver": ["mid-tier", "standard-plus"],
        "bronze": ["basic", "standard", "entry-level"],
        "rating": ["score", "star rating", "review score"],
        "country": ["region", "location", "geography", "nation"],
        "category": ["department", "section", "type", "group"],
        "price": ["cost", "value", "amount", "fee"],
        "total": ["sum", "aggregate", "combined", "overall"],
        "average": ["mean", "typical", "avg"],
        "show": ["list", "display", "give me", "get", "fetch"],
        "find": ["identify", "locate", "get", "pull", "retrieve"],
        "return": ["give me", "show", "list", "provide"],
    }

    # Punctuation characters considered detachable from a word.
    _PUNCT = ".,?!;:"

    def augment(self, question: str, rng: random.Random) -> Optional[str]:
        """
        Return a synonym-substituted copy of *question*, or None if nothing
        was replaced.

        Fix vs. previous version: punctuation on EITHER side of a word is now
        preserved. Previously `strip()` removed punctuation from both sides
        but only a trailing suffix was re-attached, so a word with leading
        punctuation (e.g. "...orders") lost ALL of its punctuation when
        replaced.
        """
        changed = False
        result = []
        for w in question.split():
            # Split w into lead-punct + core + trail-punct (w == lead+core+trail).
            lead = w[: len(w) - len(w.lstrip(self._PUNCT))]
            core = w.strip(self._PUNCT)
            trail = w[len(lead) + len(core):]
            if core.lower() in self.SYNONYMS and rng.random() < 0.4:
                replacement = rng.choice(self.SYNONYMS[core.lower()])
                result.append(lead + replacement + trail)
                changed = True
            else:
                result.append(w)
        if not changed:
            return None
        new_q = " ".join(result)
        # Capitalise first letter
        return new_q[0].upper() + new_q[1:] if new_q else new_q
|
| 1586 |
+
|
| 1587 |
+
|
| 1588 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1589 |
+
# vLLM Generator
|
| 1590 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1591 |
+
|
| 1592 |
+
class VLLMGenerator:
    """
    Async batched inference using the OpenAI-compatible vLLM endpoint.

    vLLM exposes exactly the same API as OpenAI, so we reuse AsyncOpenAI.
    An asyncio.Semaphore bounds the number of in-flight requests so a large
    batch cannot overload the server.
    """

    def __init__(self, base_url: str, model: str, temperature: float = 0.8,
                 max_tokens: int = 256, semaphore: int = 64):
        # vLLM ignores the API key, but the client requires a non-empty value.
        self.client = AsyncOpenAI(base_url=base_url, api_key="NONE")
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self._sem = asyncio.Semaphore(semaphore)

    async def generate_one(
        self,
        system: str,
        user: str,
        retries: int = 3,
    ) -> Optional[str]:
        """
        Run one chat completion with exponential backoff.

        Returns the stripped completion text, or None when the model returned
        an empty/None message or every attempt failed.
        """
        for attempt in range(retries):
            try:
                async with self._sem:
                    resp = await self.client.chat.completions.create(
                        model=self.model,
                        messages=[
                            {"role": "system", "content": system},
                            {"role": "user", "content": user},
                        ],
                        temperature=self.temperature,
                        max_tokens=self.max_tokens,
                    )
                # message.content is Optional in the OpenAI SDK; guard before
                # .strip() so an empty completion returns None instead of
                # raising AttributeError and burning the retry budget.
                text = (resp.choices[0].message.content or "").strip()
                return text if text else None
            except Exception as exc:
                log.warning(f"vLLM call failed (attempt {attempt+1}/{retries}): {exc}")
                if attempt < retries - 1:
                    # Exponential backoff: 1s, 2s, 4s, ... The old code also
                    # slept (and claimed "Retrying") after the FINAL attempt,
                    # delaying the None return for no reason.
                    await asyncio.sleep(2 ** attempt)
        return None

    async def generate_batch(
        self,
        requests: List[Tuple[str, str, str]],  # (request_id, system, user)
    ) -> Dict[str, Optional[str]]:
        """
        Fire all requests concurrently (bounded by the semaphore) and return
        a mapping of request_id -> generated text (None on failure).
        """
        async def _one(rid, sys, usr):
            return rid, await self.generate_one(sys, usr)

        tasks = [_one(rid, sys, usr) for rid, sys, usr in requests]
        results = await asyncio.gather(*tasks)
        return {rid: text for rid, text in results}
|
| 1645 |
+
|
| 1646 |
+
|
| 1647 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1648 |
+
# Data Factory
|
| 1649 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1650 |
+
|
| 1651 |
+
@dataclass
class DataPoint:
    """A single NL-question / gold-SQL training example."""

    id: str             # "<template_id>__<persona>" (plus "__augN" for variants)
    difficulty: str     # copied from the source template (easy/medium/hard)
    persona: str        # persona used to phrase the question
    question: str       # natural-language question text
    sql: str            # gold SQL answer (the template's SQL)
    db_result_ok: bool  # True when the SQL ran successfully against the DB
    augmented: bool     # True for rule-based synonym variants

    def to_training_prompt(self, system_prompt: str) -> Dict[str, Any]:
        """
        Return the dict structure expected by train.py / SFT pipelines:
        every raw field plus a chat-formatted 'messages' list.
        """
        user_content = f"SCHEMA:\n{SCHEMA_CONTEXT}\n\nQUESTION: {self.question}"
        record: Dict[str, Any] = asdict(self)
        record["messages"] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_content},
            {"role": "assistant", "content": self.sql},
        ]
        return record
|
| 1677 |
+
|
| 1678 |
+
|
| 1679 |
+
# System prompt attached to every training example's messages list; it pins
# the assistant to emitting exactly one raw SELECT statement (no markdown,
# no commentary).
SYSTEM_PROMPT = (
    "You are an expert SQL analyst working with a SQLite e-commerce database. "
    "Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."
)
|
| 1683 |
+
|
| 1684 |
+
|
| 1685 |
+
class DataFactory:
    """
    Orchestrates the synthetic NL2SQL pipeline:

      1. Validate every SQL template against the seeded DB.
      2. Ask the vLLM server to phrase each template as a natural-language
         question in several persona styles.
      3. Clean the generated text, wrap it into training records, and apply
         rule-based synonym augmentation.

    Output and checkpoint files are written in append mode, so a crashed run
    can be resumed without regenerating finished examples.
    """

    def __init__(
        self,
        generator: VLLMGenerator,
        validator: SQLiteValidator,
        augmentor: RuleAugmentor,
        personas_per_template: int = 5,
        aug_rounds: int = 2,
        seed: int = 42,
    ):
        self.generator = generator
        self.validator = validator
        self.augmentor = augmentor
        self.personas_per_template = personas_per_template
        self.aug_rounds = aug_rounds
        # Dedicated RNG so persona sampling and augmentation are reproducible.
        self.rng = random.Random(seed)

    # ── Step 1: Validate all template SQLs ───────────────────────────────────

    def validate_templates(self) -> List[SQLTemplate]:
        """Run every template's SQL against the DB; return only the valid ones."""
        log.info("Validating all SQL templates against seeded DB...")
        valid = []
        failed = []
        for t in ALL_TEMPLATES:
            ok, err = self.validator.validate(t.sql)
            if ok:
                valid.append(t)
            else:
                failed.append((t.id, err))
        if failed:
            log.error(f"FAILED templates (will be skipped): {failed}")
        log.info(f"Templates validated: {len(valid)} ok, {len(failed)} failed.")
        return valid

    # ── Step 2: Build generation requests ────────────────────────────────────

    def _build_requests(
        self,
        templates: List[SQLTemplate],
        persona_names: List[str],
    ) -> List[Tuple[str, str, str]]:
        """
        Returns a flat list of (request_id, system_prompt, user_prompt)
        tuples, one per (template, persona) pair. request_id is
        "<template_id>__<persona>".
        """
        requests = []
        for t in templates:
            chosen_personas = (
                persona_names
                if self.personas_per_template >= len(PERSONA_SPECS)
                else self.rng.sample(persona_names, self.personas_per_template)
            )
            for persona in chosen_personas:
                rid = f"{t.id}__{persona}"
                system = (
                    f"{PERSONA_SPECS[persona]}\n\n"
                    "Output ONLY the natural language question. "
                    "No explanation, no SQL, no preamble, no quotes around the question."
                )
                user = (
                    f"{SCHEMA_CONTEXT}\n"
                    f"The SQL query that answers this question is:\n{t.sql}\n\n"
                    f"Write ONE natural-language question that a {persona.upper()} user "
                    f"would ask to get this exact result."
                )
                requests.append((rid, system, user))
        return requests

    # ── Step 3: Post-process a generated question ─────────────────────────────

    @staticmethod
    def _clean(text: str) -> str:
        """Strip quotes, markdown, leading numbers, trailing newlines."""
        text = text.strip()
        # Remove leading numbering like "1. " or "Q: "
        text = re.sub(r'^[\d]+[\.\)]\s+', '', text)
        text = re.sub(r'^[Qq]:\s*', '', text)
        # Strip surrounding quotes
        if (text.startswith('"') and text.endswith('"')) or \
           (text.startswith("'") and text.endswith("'")):
            text = text[1:-1].strip()
        # Collapse multiple whitespace
        text = re.sub(r'\s+', ' ', text)
        return text

    # ── Main pipeline ─────────────────────────────────────────────────────────

    async def run(
        self,
        output_path: str,
        checkpoint_path: str,
        batch_size: int = 64,
    ) -> None:
        """
        Generate the dataset, appending JSONL records to *output_path* and
        mirroring them to *checkpoint_path* (used for crash resume).
        """
        # -- Validate templates
        templates = self.validate_templates()

        # -- Load checkpoint. A crash may leave a half-written final line, so
        #    corrupt lines are skipped instead of aborting the resume.
        done_ids: set = set()
        if os.path.exists(checkpoint_path):
            with open(checkpoint_path, encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue
                    try:
                        done_ids.add(json.loads(line)["id"])
                    except (json.JSONDecodeError, KeyError):
                        log.warning("Skipping corrupt checkpoint line.")
            log.info(f"Resuming: {len(done_ids)} examples already generated.")

        persona_names = list(PERSONA_SPECS.keys())[: self.personas_per_template]

        all_requests = self._build_requests(templates, persona_names)
        # Filter already done
        pending = [r for r in all_requests if r[0] not in done_ids]
        log.info(f"Total requests to generate: {len(pending)}")

        # -- Build template lookup
        tmpl_lookup: Dict[str, SQLTemplate] = {t.id: t for t in templates}

        stats = {"generated": 0, "invalid_llm": 0, "augmented": 0}

        # Context managers guarantee both handles close even if the second
        # open() or any batch raises (the old code opened them outside the
        # try block and leaked out_f in that case). encoding="utf-8" matches
        # the ensure_ascii=False JSON below regardless of locale.
        with open(output_path, "a", encoding="utf-8") as out_f, \
             open(checkpoint_path, "a", encoding="utf-8") as ckpt_f:
            for i in tqdm(range(0, len(pending), batch_size), desc="Batches"):
                batch = pending[i: i + batch_size]
                results = await self.generator.generate_batch(batch)

                for rid, raw_text in results.items():
                    tmpl_id, persona = rid.split("__", 1)
                    tmpl = tmpl_lookup[tmpl_id]

                    if not raw_text:
                        stats["invalid_llm"] += 1
                        continue

                    question = self._clean(raw_text)
                    if len(question) < 8:  # too short to be a real question
                        stats["invalid_llm"] += 1
                        continue

                    # SQL already validated; no need to re-run for NL variants
                    dp = DataPoint(
                        id=rid,
                        difficulty=tmpl.difficulty,
                        persona=persona,
                        question=question,
                        sql=tmpl.sql,
                        db_result_ok=True,
                        augmented=False,
                    )
                    record = dp.to_training_prompt(SYSTEM_PROMPT)
                    line = json.dumps(record, ensure_ascii=False)
                    out_f.write(line + "\n")
                    ckpt_f.write(line + "\n")
                    stats["generated"] += 1

                    # -- Rule augmentation rounds
                    for aug_i in range(self.aug_rounds):
                        aug_q = self.augmentor.augment(question, self.rng)
                        if aug_q and aug_q != question:
                            aug_dp = DataPoint(
                                id=f"{rid}__aug{aug_i}",
                                difficulty=tmpl.difficulty,
                                persona=persona,
                                question=aug_q,
                                sql=tmpl.sql,
                                db_result_ok=True,
                                augmented=True,
                            )
                            aug_record = aug_dp.to_training_prompt(SYSTEM_PROMPT)
                            aug_line = json.dumps(aug_record, ensure_ascii=False)
                            out_f.write(aug_line + "\n")
                            ckpt_f.write(aug_line + "\n")
                            stats["augmented"] += 1

                # Flush per batch so a crash loses at most one batch of work.
                out_f.flush()
                ckpt_f.flush()

        log.info(
            f"Done. Generated={stats['generated']} "
            f"Augmented={stats['augmented']} "
            f"LLM failures={stats['invalid_llm']}"
        )
        log.info(f"Output: {output_path}")
|
| 1868 |
+
|
| 1869 |
+
|
| 1870 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1871 |
+
# CLI
|
| 1872 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 1873 |
+
|
| 1874 |
+
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the data-factory CLI."""
    parser = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory — H100 + vLLM",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--vllm-url",
        default="http://localhost:8001/v1",
        help="Base URL of the running vLLM server.",
    )
    parser.add_argument(
        "--model",
        default="meta-llama/Meta-Llama-3-70B-Instruct",
        help="Model name as registered in the vLLM server.",
    )
    parser.add_argument(
        "--output",
        default="nl2sql_train.jsonl",
        help="Path to write the final JSONL dataset.",
    )
    parser.add_argument(
        "--checkpoint",
        default="nl2sql_checkpoint.jsonl",
        help="Path for the checkpoint file (enables resume on crash).",
    )
    parser.add_argument(
        "--personas-per-template",
        type=int,
        default=5,
        help="Number of persona variants to generate per SQL template (max 5).",
    )
    parser.add_argument(
        "--aug-rounds",
        type=int,
        default=2,
        help="Number of rule-based augmentation rounds per generated question.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=64,
        help="Concurrent vLLM requests per batch (tune based on GPU memory).",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.85,
        help="Sampling temperature for vLLM (higher = more diverse).",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=200,
        help="Max tokens for each generated question.",
    )
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--validate-only",
        action="store_true",
        help="Only validate SQL templates, do not generate data.",
    )
    return parser.parse_args()
|
| 1901 |
+
|
| 1902 |
+
|
| 1903 |
+
async def main() -> None:
    """
    CLI entry point.

    Builds the in-memory DB + SQL validator, then either:
      * with --validate-only: checks every SQL template against the DB and
        reports which fail, or
      * runs the full generation pipeline (vLLM persona generation plus
        rule-based augmentation) via DataFactory.
    """
    args = parse_args()

    # Build DB + validator
    conn = build_db()
    validator = SQLiteValidator(conn)

    if args.validate_only:
        # Validate each template exactly once and reuse the result — the
        # previous version re-ran validator.validate() up to three times
        # per template (twice for the valid/invalid partitions, once more
        # to fetch the error message).
        results = [(t, validator.validate(t.sql)) for t in ALL_TEMPLATES]
        valid = [t for t, (ok, _) in results if ok]
        invalid = [(t, err) for t, (ok, err) in results if not ok]
        print(f"\n✅ Valid: {len(valid)}")
        print(f"❌ Invalid: {len(invalid)}")
        for t, err in invalid:
            print(f"  {t.id}: {err}")
        return

    # Build pipeline components
    generator = VLLMGenerator(
        base_url=args.vllm_url,
        model=args.model,
        temperature=args.temperature,
        max_tokens=args.max_tokens,
        semaphore=args.batch_size,
    )
    augmentor = RuleAugmentor()

    factory = DataFactory(
        generator=generator,
        validator=validator,
        augmentor=augmentor,
        # Cap at the number of available persona specs so we never request
        # more variants than personas exist.
        personas_per_template=min(args.personas_per_template, len(PERSONA_SPECS)),
        aug_rounds=args.aug_rounds,
        seed=args.seed,
    )

    await factory.run(
        output_path=args.output,
        checkpoint_path=args.checkpoint,
        batch_size=args.batch_size,
    )
|
| 1944 |
+
|
| 1945 |
+
|
| 1946 |
+
# Script entry point — run the async CLI driver on the default event loop.
if __name__ == "__main__":
    asyncio.run(main())
|
data_factory/generator.py
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/generator.py
|
| 3 |
+
==========================
|
| 4 |
+
vLLM-based Natural Language question generator for H100.
|
| 5 |
+
|
| 6 |
+
This module uses a large LLM (Llama-3-70B or Qwen-72B) served via vLLM
|
| 7 |
+
to generate diverse, persona-based natural language paraphrases of the
|
| 8 |
+
canonical NL questions in our template library.
|
| 9 |
+
|
| 10 |
+
KEY DESIGN: The LLM generates ONLY natural language questions.
|
| 11 |
+
SQL is NEVER touched by the LLM.
|
| 12 |
+
This guarantees zero SQL errors in the final dataset.
|
| 13 |
+
|
| 14 |
+
Persona descriptions:
|
| 15 |
+
ceo - Direct, short, active voice. Business executive style.
|
| 16 |
+
chatty - Conversational, verbose, passive voice.
|
| 17 |
+
lazy_typist - Short, abbreviations, possible informal grammar.
|
| 18 |
+
non_techie - Plain English, avoids SQL/tech jargon, uses synonyms.
|
| 19 |
+
analyst - Technical, precise, jargon-heavy.
|
| 20 |
+
|
| 21 |
+
Usage (on H100 cluster):
|
| 22 |
+
python -m data_factory.generator --templates-per-chunk 20 --n-variants 10
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import json
|
| 28 |
+
import logging
|
| 29 |
+
import time
|
| 30 |
+
from typing import Iterator, Optional
|
| 31 |
+
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 36 |
+
# PERSONA SYSTEM PROMPTS
|
| 37 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
# Maps each persona key to the system-prompt text that instructs the LLM to
# speak in that persona's voice when rephrasing a canonical question.
# Keys here must match the persona names used by build_generation_prompt().
PERSONA_SYSTEM_PROMPTS: dict[str, str] = {

    "ceo": (
        "You are a busy C-level executive who communicates in short, punchy, "
        "direct sentences. You use active voice, skip filler words, and get "
        "straight to the point. You are asking a data analyst for information."
    ),

    "chatty": (
        "You are a friendly, conversational person who likes to be thorough "
        "and explain things fully. You use passive voice sometimes, add context, "
        "and ask questions in a relaxed, detailed way. You are not technical."
    ),

    "lazy_typist": (
        "You type quickly and informally. You use abbreviations (e.g. 'pls', "
        "'lmk', 'asap'), lowercase, minimal punctuation, and sometimes omit "
        "words. You get your meaning across without perfect grammar."
    ),

    "non_techie": (
        "You have no database or SQL knowledge. You use everyday English words "
        "instead of technical terms. For example, you say 'customers' not 'rows', "
        "'most expensive' not 'highest price', 'total money' not 'sum'. "
        "You describe what you want to see, not how to get it."
    ),

    "analyst": (
        "You are a data scientist or BI analyst who is precise and technical. "
        "You use terms like 'aggregate', 'partition', 'granularity', 'distinct', "
        "'filter predicate', 'ranked by metric'. Your questions are precise and unambiguous."
    ),
}
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 75 |
+
# PROMPT BUILDER
|
| 76 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 77 |
+
|
| 78 |
+
def build_generation_prompt(
    canonical_nl: str,
    description: str,
    persona: str,
    schema_context: str,
    n_variants: int = 5,
) -> list[dict[str, str]]:
    """
    Construct the chat prompt that asks the LLM to rephrase ``canonical_nl``
    in the voice of ``persona``.

    Parameters
    ----------
    canonical_nl : The base NL question from the template.
    description : One-line SQL description (extra context for the LLM).
    persona : One of the 5 persona keys (must exist in PERSONA_SYSTEM_PROMPTS).
    schema_context : The compact schema string for the domain.
    n_variants : How many rephrased questions to request.

    Returns
    -------
    list[dict] Chat messages in [{"role": ..., "content": ...}] format.
    """
    style_instructions = PERSONA_SYSTEM_PROMPTS[persona]

    # System turn: fixes the task, the persona, and the strict JSON-array
    # output contract that parse_llm_response() relies on.
    system_msg = (
        "You are a data labelling specialist. Your task is to rephrase a database "
        "question in a specific communication style (persona). The rephrased questions "
        "must preserve the EXACT same intent and required information as the original — "
        "do not change what data is being asked for, only how it is expressed.\n\n"
        f"PERSONA: {style_instructions}\n\n"
        "OUTPUT FORMAT: Return ONLY a valid JSON array of strings. "
        "No preamble, no markdown, no extra keys. Example: "
        '["question 1", "question 2", "question 3"]'
    )

    # User turn: concrete context + the canonical question to rephrase.
    user_msg = (
        f"DATABASE CONTEXT:\n{schema_context}\n\n"
        f"WHAT THE QUERY DOES: {description}\n\n"
        f"CANONICAL QUESTION: {canonical_nl}\n\n"
        f"Generate {n_variants} different ways a person with the persona described "
        f"above would ask this same question. The meaning must stay identical."
    )

    messages: list[dict[str, str]] = [{"role": "system", "content": system_msg}]
    messages.append({"role": "user", "content": user_msg})
    return messages
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 129 |
+
# RESPONSE PARSER
|
| 130 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 131 |
+
|
| 132 |
+
def parse_llm_response(raw_text: str) -> list[str]:
|
| 133 |
+
"""
|
| 134 |
+
Extract a list of strings from the LLM's JSON response.
|
| 135 |
+
Handles common failures: markdown fences, trailing commas, extra text.
|
| 136 |
+
|
| 137 |
+
Returns an empty list if parsing fails completely.
|
| 138 |
+
"""
|
| 139 |
+
text = raw_text.strip()
|
| 140 |
+
|
| 141 |
+
# Strip markdown fences if present
|
| 142 |
+
if text.startswith("```"):
|
| 143 |
+
lines = text.split("\n")
|
| 144 |
+
text = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()
|
| 145 |
+
|
| 146 |
+
# Find the JSON array boundaries
|
| 147 |
+
start = text.find("[")
|
| 148 |
+
end = text.rfind("]")
|
| 149 |
+
if start == -1 or end == -1 or end <= start:
|
| 150 |
+
logger.warning("LLM response missing JSON array brackets: %s", text[:100])
|
| 151 |
+
return []
|
| 152 |
+
|
| 153 |
+
json_str = text[start:end + 1]
|
| 154 |
+
|
| 155 |
+
# Fix trailing commas before ] (common LLM mistake)
|
| 156 |
+
json_str = json_str.rstrip()
|
| 157 |
+
json_str = json_str.replace(",]", "]").replace(", ]", "]")
|
| 158 |
+
|
| 159 |
+
try:
|
| 160 |
+
parsed = json.loads(json_str)
|
| 161 |
+
if not isinstance(parsed, list):
|
| 162 |
+
return []
|
| 163 |
+
# Filter to only non-empty strings
|
| 164 |
+
return [s.strip() for s in parsed if isinstance(s, str) and s.strip()]
|
| 165 |
+
except json.JSONDecodeError as exc:
|
| 166 |
+
logger.warning("JSON parse error: %s | text: %s", exc, json_str[:200])
|
| 167 |
+
return []
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 171 |
+
# VLLM INTERFACE
|
| 172 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 173 |
+
|
| 174 |
+
class VLLMGenerator:
    """
    Wrapper around a running vLLM server for high-throughput NL generation.

    Supports two modes:
        online  : Calls a running vLLM OpenAI-compatible API server.
        offline : Uses vllm.LLM directly (loads model in-process, H100 recommended).

    For H100 cluster usage, prefer 'offline' mode with tensor_parallel_size=4
    to saturate all 4 H100s for maximum throughput.
    """

    def __init__(
        self,
        model_name: str,
        mode: str = "offline",
        tensor_parallel_size: int = 4,
        gpu_memory_utilization: float = 0.90,
        max_model_len: int = 4096,
        # Online mode only
        api_base: str = "http://localhost:8000/v1",
        api_key: str = "EMPTY",
    ) -> None:
        self.model_name = model_name
        self.mode = mode
        # Exactly one of these is populated, depending on `mode`:
        # _llm holds the in-process vllm.LLM engine; _client the OpenAI client.
        self._llm = None
        self._client = None

        if mode == "offline":
            self._init_offline(tensor_parallel_size, gpu_memory_utilization, max_model_len)
        elif mode == "online":
            self._init_online(api_base, api_key)
        else:
            raise ValueError(f"Unknown mode: {mode!r}. Use 'offline' or 'online'.")

    def _init_offline(
        self,
        tensor_parallel_size: int,
        gpu_memory_utilization: float,
        max_model_len: int,
    ) -> None:
        """Load vLLM engine in-process (best for H100 cluster)."""
        try:
            from vllm import LLM, SamplingParams
            # Cache the classes on the instance so later calls don't re-import.
            # NOTE(review): _generate_offline re-imports SamplingParams anyway,
            # so _SamplingParams is currently unused — candidates for cleanup.
            self._LLM = LLM
            self._SamplingParams = SamplingParams
        except ImportError:
            raise ImportError(
                "vLLM not installed. Run: pip install vllm\n"
                "For H100: pip install vllm --extra-index-url https://download.pytorch.org/whl/cu124"
            )

        logger.info("Loading model %s with %d GPUs (offline mode)...", self.model_name, tensor_parallel_size)
        t0 = time.time()
        self._llm = self._LLM(
            model=self.model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            max_model_len=max_model_len,
            dtype="bfloat16",
            trust_remote_code=True,
        )
        logger.info("Model loaded in %.1f seconds.", time.time() - t0)

    def _init_online(self, api_base: str, api_key: str) -> None:
        """Use OpenAI-compatible vLLM server (for distributed setups)."""
        try:
            from openai import OpenAI
            self._client = OpenAI(base_url=api_base, api_key=api_key)
        except ImportError:
            raise ImportError("pip install openai")
        logger.info("Connected to vLLM server at %s", api_base)

    def generate_batch(
        self,
        prompts: list[list[dict[str, str]]],
        temperature: float = 0.85,
        max_new_tokens: int = 300,
    ) -> list[str]:
        """
        Generate responses for a batch of chat prompts.

        Parameters
        ----------
        prompts : List of chat message lists (one per item in batch).
        temperature : Sampling temperature. Higher = more diverse.
        max_new_tokens : Max tokens per response.

        Returns
        -------
        list[str] Raw text response per prompt (same length as input).
        """
        # Dispatch on the mode chosen at construction time.
        if self.mode == "offline":
            return self._generate_offline(prompts, temperature, max_new_tokens)
        else:
            return self._generate_online(prompts, temperature, max_new_tokens)

    def _generate_offline(
        self,
        prompts: list[list[dict]],
        temperature: float,
        max_new_tokens: int,
    ) -> list[str]:
        """vLLM offline batched generation — maximises H100 throughput."""
        from vllm import SamplingParams

        sampling = SamplingParams(
            temperature=temperature,
            max_tokens=max_new_tokens,
            stop=["</s>", "<|eot_id|>"],  # Llama-3 stop tokens
        )

        # Convert chat messages to tokenised prompt strings using the model's template
        tokenizer = self._llm.get_tokenizer()
        formatted_prompts: list[str] = []
        for msgs in prompts:
            if hasattr(tokenizer, "apply_chat_template"):
                text = tokenizer.apply_chat_template(
                    msgs, tokenize=False, add_generation_prompt=True
                )
            else:
                # Fallback: simple concatenation when the tokenizer has no
                # chat template (non-chat or very old models).
                text = "\n".join(
                    f"<|{m['role']}|>\n{m['content']}" for m in msgs
                )
            formatted_prompts.append(text)

        # Single engine call for the whole batch; vLLM schedules internally.
        outputs = self._llm.generate(formatted_prompts, sampling)
        # Only the first candidate per prompt is kept.
        return [o.outputs[0].text for o in outputs]

    def _generate_online(
        self,
        prompts: list[list[dict]],
        temperature: float,
        max_new_tokens: int,
    ) -> list[str]:
        """Sequential generation via OpenAI-compatible API (fallback / debugging)."""
        results = []
        for msgs in prompts:
            try:
                resp = self._client.chat.completions.create(
                    model=self.model_name,
                    messages=msgs,
                    temperature=temperature,
                    max_tokens=max_new_tokens,
                )
                results.append(resp.choices[0].message.content or "")
            except Exception as exc:
                # Best-effort: keep the batch aligned by appending an empty
                # string for a failed call instead of aborting the whole run.
                logger.warning("API call failed: %s", exc)
                results.append("")
        return results
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 328 |
+
# HIGH-LEVEL GENERATION LOOP
|
| 329 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 330 |
+
|
| 331 |
+
def generate_persona_variants_batch(
    templates_subset: list[dict],
    generator: VLLMGenerator,
    personas: list[str],
    n_variants_per_persona: int = 5,
    batch_size: int = 64,
    temperature: float = 0.85,
    max_new_tokens: int = 300,
) -> Iterator[dict]:
    """
    For each template × persona combination, generate `n_variants_per_persona`
    NL question variants using the LLM.

    Yields dicts:
        {
          "template_idx": int,
          "persona": str,
          "nl_variants": list[str],   # successfully parsed NL questions
        }

    Parameters
    ----------
    templates_subset : List of template dicts (from templates.py).
    generator : VLLMGenerator instance.
    personas : List of persona keys to use.
    n_variants_per_persona : How many NL variants per (template, persona) pair.
    batch_size : How many LLM calls to batch together.
    temperature : Sampling temperature.
    max_new_tokens : Max tokens for LLM response (should be ~300 for JSON array).
    """
    # Local import — presumably to avoid a circular import with schemas.py;
    # TODO confirm.
    from data_factory.schemas import SCHEMA_CONTEXT

    # Build all (template_idx, persona) prompt pairs up front so batching
    # can cut across template boundaries.
    all_jobs: list[tuple[int, str, list[dict]]] = []

    for t_idx, template in enumerate(templates_subset):
        schema_ctx = SCHEMA_CONTEXT[template["domain"]]
        for persona in personas:
            prompt = build_generation_prompt(
                canonical_nl=template["base_nl"],
                description=template["description"],
                persona=persona,
                schema_context=schema_ctx,
                n_variants=n_variants_per_persona,
            )
            all_jobs.append((t_idx, persona, prompt))

    total_jobs = len(all_jobs)
    logger.info("Starting LLM generation: %d jobs (templates × personas).", total_jobs)

    # Process in batches
    for batch_start in range(0, total_jobs, batch_size):
        batch = all_jobs[batch_start: batch_start + batch_size]
        prompts = [job[2] for job in batch]

        t0 = time.time()
        raw_responses = generator.generate_batch(
            prompts, temperature=temperature, max_new_tokens=max_new_tokens
        )
        elapsed = time.time() - t0
        logger.info(
            "Batch %d-%d completed in %.1fs (%.1f jobs/s).",
            # max(elapsed, 0.001) guards against division by zero on a
            # pathologically fast (or mocked) generator.
            batch_start, batch_start + len(batch), elapsed, len(batch) / max(elapsed, 0.001)
        )

        # raw_responses is positionally aligned with batch (same length).
        for (t_idx, persona, _), raw in zip(batch, raw_responses):
            nl_variants = parse_llm_response(raw)
            if not nl_variants:
                logger.debug(
                    "Empty parse for template_idx=%d persona=%s. raw=%s",
                    t_idx, persona, raw[:100]
                )
                # Fall back to the canonical NL rather than losing this entry
                nl_variants = [templates_subset[t_idx]["base_nl"]]

            yield {
                "template_idx": t_idx,
                "persona": persona,
                "nl_variants": nl_variants,
            }
|
data_factory/pipeline.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/pipeline.py
|
| 3 |
+
=========================
|
| 4 |
+
Master orchestration pipeline for the NL2SQL Synthetic Data Factory.
|
| 5 |
+
|
| 6 |
+
This module ties together:
|
| 7 |
+
1. Template library (66 verified SQL templates across 4 domains)
|
| 8 |
+
2. Rule-based NL augmentation (augmentor.py)
|
| 9 |
+
3. vLLM persona-based NL generation (generator.py)
|
| 10 |
+
4. SQL execution validation (validator.py)
|
| 11 |
+
5. Output serialisation (JSONL + Parquet)
|
| 12 |
+
|
| 13 |
+
Run modes:
|
| 14 |
+
--mode base : Only uses template base_nl + rule augmentation (no GPU required)
|
| 15 |
+
--mode full : base + vLLM persona generation (requires H100)
|
| 16 |
+
|
| 17 |
+
Output dataset format (JSONL, one record per line):
|
| 18 |
+
{
|
| 19 |
+
"prompt": [{"role": "system", ...}, {"role": "user", ...}],
|
| 20 |
+
"sql": "SELECT ...",
|
| 21 |
+
"metadata": { "domain", "difficulty", "persona", ... }
|
| 22 |
+
}
|
| 23 |
+
|
| 24 |
+
This format is directly loadable by:
|
| 25 |
+
datasets.load_dataset("json", data_files="output/train.jsonl")
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
from __future__ import annotations
|
| 29 |
+
|
| 30 |
+
import argparse
|
| 31 |
+
import json
|
| 32 |
+
import logging
|
| 33 |
+
import os
|
| 34 |
+
import random
|
| 35 |
+
import time
|
| 36 |
+
from pathlib import Path
|
| 37 |
+
from typing import Any, Iterator, Optional
|
| 38 |
+
|
| 39 |
+
logging.basicConfig(
|
| 40 |
+
level=logging.INFO,
|
| 41 |
+
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
| 42 |
+
datefmt="%H:%M:%S",
|
| 43 |
+
)
|
| 44 |
+
logger = logging.getLogger("pipeline")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 48 |
+
# HELPERS
|
| 49 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 50 |
+
|
| 51 |
+
def _ensure_dirs(*dirs: Path) -> None:
|
| 52 |
+
for d in dirs:
|
| 53 |
+
d.mkdir(parents=True, exist_ok=True)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def _write_jsonl(records: list[dict], path: Path) -> None:
    """Serialise *records* to *path* as UTF-8 JSON Lines (one object per line)."""
    with open(path, "w", encoding="utf-8") as fh:
        fh.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in records
        )
    logger.info("Wrote %d records to %s", len(records), path)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _write_parquet(records: list[dict], path: Path) -> None:
    """Write *records* to *path* as snappy-compressed Parquet.

    Silently skips (with a warning) when pandas or pyarrow is unavailable.
    """
    try:
        import pandas as pd

        frame = pd.DataFrame(records)
        frame.to_parquet(path, index=False, engine="pyarrow", compression="snappy")
        logger.info("Wrote %d records to %s (Parquet)", len(records), path)
    except ImportError:
        logger.warning("pandas/pyarrow not installed — skipping Parquet output.")
|
| 72 |
+
|
| 73 |
+
def _train_val_test_split(
|
| 74 |
+
records: list[dict],
|
| 75 |
+
train_frac: float = 0.90,
|
| 76 |
+
val_frac: float = 0.05,
|
| 77 |
+
seed: int = 42,
|
| 78 |
+
) -> tuple[list[dict], list[dict], list[dict]]:
|
| 79 |
+
"""
|
| 80 |
+
Stratified split by (domain, difficulty) to ensure all combinations
|
| 81 |
+
are represented in every split.
|
| 82 |
+
"""
|
| 83 |
+
rng = random.Random(seed)
|
| 84 |
+
from collections import defaultdict
|
| 85 |
+
|
| 86 |
+
buckets: dict[str, list[dict]] = defaultdict(list)
|
| 87 |
+
for rec in records:
|
| 88 |
+
key = f"{rec['metadata']['domain']}_{rec['metadata']['difficulty']}"
|
| 89 |
+
buckets[key].append(rec)
|
| 90 |
+
|
| 91 |
+
train, val, test = [], [], []
|
| 92 |
+
for key, bucket in buckets.items():
|
| 93 |
+
rng.shuffle(bucket)
|
| 94 |
+
n = len(bucket)
|
| 95 |
+
n_train = max(1, int(n * train_frac))
|
| 96 |
+
n_val = max(1, int(n * val_frac))
|
| 97 |
+
train.extend(bucket[:n_train])
|
| 98 |
+
val.extend(bucket[n_train:n_train + n_val])
|
| 99 |
+
test.extend(bucket[n_train + n_val:])
|
| 100 |
+
|
| 101 |
+
rng.shuffle(train)
|
| 102 |
+
rng.shuffle(val)
|
| 103 |
+
rng.shuffle(test)
|
| 104 |
+
return train, val, test
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 108 |
+
# PHASE 1: BASE + RULE AUGMENTATION (no GPU required)
|
| 109 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 110 |
+
|
| 111 |
+
def run_base_pipeline(
    templates: list,
    n_augmentations: int = 5,
    seed: int = 42,
) -> list[dict]:
    """
    Generate training records from:
      (a) the canonical base_nl of each template
      (b) rule-based augmented NL variants

    No GPU is required for this phase — only rule augmentation and SQL
    execution validation.

    Returns a list of training dicts (ready to write to JSONL).
    """
    from data_factory.augmentor import augment_nl
    from data_factory.validator import SQLValidator, build_record
    from data_factory.schemas import SCHEMA_MAP

    # Build one validator per domain (reuse connection across templates)
    validators = {domain: SQLValidator(domain, seed=seed) for domain in SCHEMA_MAP}
    records: list[dict] = []

    for t_idx, template in enumerate(templates):
        v = validators[template["domain"]]

        # (a) Canonical base_nl
        rec = build_record(
            template=template,
            template_idx=t_idx,
            nl_question=template["base_nl"],
            persona="canonical",
            source="template_base",
            validator=v,
        )
        # build_record can return a falsy value — presumably when SQL
        # validation fails; such records are dropped. TODO confirm.
        if rec:
            records.append(rec.to_training_dict())

        # (b) Rule-augmented variants
        # seed + t_idx keeps augmentation deterministic yet different
        # for each template.
        augmented = augment_nl(
            nl_question=template["base_nl"],
            n=n_augmentations,
            seed=seed + t_idx,
        )
        for nl_variant in augmented:
            rec = build_record(
                template=template,
                template_idx=t_idx,
                nl_question=nl_variant,
                persona="rule_augmented",
                source="rule_augmented",
                validator=v,
            )
            if rec:
                records.append(rec.to_training_dict())

    # Release the per-domain DB connections.
    for v in validators.values():
        v.close()

    logger.info("Base pipeline: %d records generated from %d templates.", len(records), len(templates))
    return records
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 172 |
+
# PHASE 2: vLLM PERSONA GENERATION (H100 required)
|
| 173 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 174 |
+
|
| 175 |
+
def run_vllm_pipeline(
    templates: list,
    generator,  # VLLMGenerator instance
    personas: list[str],
    n_variants_per_persona: int = 10,
    batch_size: int = 64,
    temperature: float = 0.85,
    max_new_tokens: int = 350,
    seed: int = 42,
) -> list[dict]:
    """Generate additional NL variants with the LLM, then validate the SQL.

    The LLM only rewrites the *question*; the SQL attached to each record is
    the template's hand-written query, validated per domain.

    Returns a list of training dicts.
    """
    from data_factory.generator import generate_persona_variants_batch
    from data_factory.validator import SQLValidator, build_record
    from data_factory.schemas import SCHEMA_MAP

    # Shared per-domain validators, reused for every generation job.
    domain_validators = {name: SQLValidator(name, seed=seed) for name in SCHEMA_MAP}
    training_dicts: list[dict] = []

    job_stream = generate_persona_variants_batch(
        templates_subset=templates,
        generator=generator,
        personas=personas,
        n_variants_per_persona=n_variants_per_persona,
        batch_size=batch_size,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
    )

    for job in job_stream:
        idx = job["template_idx"]
        tmpl = templates[idx]
        checker = domain_validators[tmpl["domain"]]

        for variant in job["nl_variants"]:
            candidate = build_record(
                template=tmpl,
                template_idx=idx,
                nl_question=variant,
                persona=job["persona"],
                source="vllm_persona",
                validator=checker,
            )
            if candidate:
                training_dicts.append(candidate.to_training_dict())

    for checker in domain_validators.values():
        checker.close()

    logger.info("vLLM pipeline: %d records generated.", len(training_dicts))
    return training_dicts
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 233 |
+
# CHECKPOINT UTILITIES
|
| 234 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 235 |
+
|
| 236 |
+
def save_checkpoint(records: list[dict], checkpoint_dir: Path, name: str) -> Path:
    """Write *records* to ``<checkpoint_dir>/<name>.jsonl`` and return the path."""
    target = checkpoint_dir / f"{name}.jsonl"
    _write_jsonl(records, target)
    return target
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
def load_checkpoint(checkpoint_dir: Path, name: str) -> Optional[list[dict]]:
    """Load a JSONL checkpoint back into memory.

    Returns the parsed records, or ``None`` when no checkpoint file exists
    (blank lines in the file are skipped).
    """
    path = checkpoint_dir / f"{name}.jsonl"
    if not path.exists():
        return None

    with open(path, encoding="utf-8") as fh:
        records = [json.loads(stripped) for raw in fh if (stripped := raw.strip())]

    logger.info("Loaded %d records from checkpoint %s", len(records), path)
    return records
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 257 |
+
# DATASET STATISTICS
|
| 258 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 259 |
+
|
| 260 |
+
def print_dataset_stats(records: list[dict]) -> None:
    """Print a summary table of the dataset: counts and percentages by
    domain, difficulty, persona, and source.

    Fix: an empty ``records`` list previously raised ZeroDivisionError in
    the percentage computation (``v/len(records)*100``); it now prints the
    header with a zero count and returns early.
    """
    from collections import Counter

    print("\n" + "=" * 55)
    print(f" DATASET STATISTICS ({len(records):,} total records)")
    print("=" * 55)

    if not records:
        # Nothing to tabulate — avoid dividing by len(records) == 0 below.
        print("=" * 55 + "\n")
        return

    # Each record carries a flat metadata dict with these four keys.
    domains = Counter(r["metadata"]["domain"] for r in records)
    diffs = Counter(r["metadata"]["difficulty"] for r in records)
    personas = Counter(r["metadata"]["persona"] for r in records)
    sources = Counter(r["metadata"]["source"] for r in records)

    print("\nBy Domain:")
    for k, v in sorted(domains.items()):
        print(f" {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
    print("\nBy Difficulty:")
    for k, v in sorted(diffs.items()):
        print(f" {k:20s}: {v:6,} ({v/len(records)*100:.1f}%)")
    print("\nBy Persona/Source:")
    for k, v in sorted(personas.items()):
        print(f" {k:20s}: {v:6,}")
    print("\nBy Source:")
    for k, v in sorted(sources.items()):
        print(f" {k:20s}: {v:6,}")
    print("=" * 55 + "\n")
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 286 |
+
# MAIN ENTRY POINT
|
| 287 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 288 |
+
|
| 289 |
+
def main() -> None:
    """CLI entry point for the data-factory pipeline.

    Flow: parse args → load templates filtered by domain/difficulty →
    Phase 1 (rule augmentation, always) → Phase 2 (vLLM persona generation,
    ``--mode full`` only) → dedup on the NL question → stats → split →
    write JSONL (and optionally Parquet) plus a dataset card.
    ``--resume`` short-circuits a phase when its checkpoint file exists.
    """
    parser = argparse.ArgumentParser(
        description="NL2SQL Synthetic Data Factory — generates verified training data."
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help="base = rule augmentation only (no GPU). full = + vLLM on H100.",
    )
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct",
                        help="HuggingFace model name for vLLM (full mode only).")
    parser.add_argument("--tensor-parallel", type=int, default=4,
                        help="Tensor parallel size for vLLM (number of H100s).")
    parser.add_argument("--n-rule-augments", type=int, default=5,
                        help="Number of rule-based NL augmentations per template.")
    parser.add_argument("--n-persona-variants", type=int, default=10,
                        help="Number of vLLM NL variants per (template, persona) pair.")
    parser.add_argument("--batch-size", type=int, default=64,
                        help="vLLM batch size (larger = faster on H100).")
    parser.add_argument("--temperature", type=float, default=0.85,
                        help="Sampling temperature for vLLM generation.")
    parser.add_argument("--output-dir", type=str, default="generated_data/output",
                        help="Directory to write final dataset files.")
    parser.add_argument("--checkpoint-dir", type=str, default="generated_data/checkpoints",
                        help="Directory for intermediate checkpoints.")
    parser.add_argument("--seed", type=int, default=42, help="Global random seed.")
    parser.add_argument("--no-parquet", action="store_true",
                        help="Skip Parquet output (write only JSONL).")
    parser.add_argument("--resume", action="store_true",
                        help="Resume from latest checkpoint if available.")
    parser.add_argument("--domains", nargs="+",
                        choices=["ecommerce","healthcare","finance","hr"],
                        default=["ecommerce","healthcare","finance","hr"],
                        help="Domains to include (default: all 4).")
    parser.add_argument("--difficulties", nargs="+",
                        choices=["easy","medium","hard"],
                        default=["easy","medium","hard"],
                        help="Difficulty levels to include (default: all 3).")
    args = parser.parse_args()

    # _ensure_dirs / _write_jsonl / _write_parquet / _train_val_test_split
    # are presumably module-level helpers defined earlier in this file.
    output_dir = Path(args.output_dir)
    checkpoint_dir = Path(args.checkpoint_dir)
    _ensure_dirs(output_dir, checkpoint_dir)

    # ── Load templates ─────────────────────────────────────────────────────
    from data_factory.templates import ALL_TEMPLATES

    templates = [
        t for t in ALL_TEMPLATES
        if t["domain"] in args.domains and t["difficulty"] in args.difficulties
    ]
    logger.info("Loaded %d templates (domains=%s, difficulties=%s).",
                len(templates), args.domains, args.difficulties)

    # ── Phase 1: Base + rule augmentation ─────────────────────────────────
    all_records: list[dict] = []

    # A present checkpoint wins over re-running the phase when --resume is set.
    ckpt_base = load_checkpoint(checkpoint_dir, "phase1_base") if args.resume else None
    if ckpt_base is not None:
        all_records.extend(ckpt_base)
        logger.info("Resumed Phase 1 from checkpoint (%d records).", len(ckpt_base))
    else:
        logger.info("=== Phase 1: Base + Rule Augmentation ===")
        base_records = run_base_pipeline(
            templates=templates,
            n_augmentations=args.n_rule_augments,
            seed=args.seed,
        )
        all_records.extend(base_records)
        save_checkpoint(base_records, checkpoint_dir, "phase1_base")

    # ── Phase 2: vLLM persona generation (full mode only) ─────────────────
    if args.mode == "full":
        ckpt_vllm = load_checkpoint(checkpoint_dir, "phase2_vllm") if args.resume else None
        if ckpt_vllm is not None:
            all_records.extend(ckpt_vllm)
            logger.info("Resumed Phase 2 from checkpoint (%d records).", len(ckpt_vllm))
        else:
            logger.info("=== Phase 2: vLLM Persona Generation ===")

            # Heavy imports are deferred so base mode works without vLLM installed.
            from data_factory.generator import VLLMGenerator
            from data_factory.config import PERSONAS

            generator = VLLMGenerator(
                model_name=args.model,
                mode="offline",
                tensor_parallel_size=args.tensor_parallel,
                gpu_memory_utilization=0.90,
            )

            vllm_records = run_vllm_pipeline(
                templates=templates,
                generator=generator,
                personas=PERSONAS,
                n_variants_per_persona=args.n_persona_variants,
                batch_size=args.batch_size,
                temperature=args.temperature,
                max_new_tokens=350,
                seed=args.seed,
            )
            all_records.extend(vllm_records)
            save_checkpoint(vllm_records, checkpoint_dir, "phase2_vllm")

    # ── Deduplication ──────────────────────────────────────────────────────
    # First occurrence of a given NL question wins (order preserved).
    logger.info("Deduplicating %d records...", len(all_records))
    seen_nl: set[str] = set()
    deduped: list[dict] = []
    for rec in all_records:
        nl = rec["prompt"][1]["content"]  # user message contains the NL question
        if nl not in seen_nl:
            seen_nl.add(nl)
            deduped.append(rec)
    logger.info("After dedup: %d unique records (removed %d duplicates).",
                len(deduped), len(all_records) - len(deduped))

    # ── Statistics ─────────────────────────────────────────────────────────
    print_dataset_stats(deduped)

    # ── Train / Val / Test split ───────────────────────────────────────────
    train, val, test = _train_val_test_split(deduped, seed=args.seed)
    logger.info("Split: train=%d | val=%d | test=%d", len(train), len(val), len(test))

    # ── Write outputs ─────────────────────────────────────────────────────
    _write_jsonl(train, output_dir / "train.jsonl")
    _write_jsonl(val, output_dir / "val.jsonl")
    _write_jsonl(test, output_dir / "test.jsonl")

    if not args.no_parquet:
        _write_parquet(train, output_dir / "train.parquet")
        _write_parquet(val, output_dir / "val.parquet")
        _write_parquet(test, output_dir / "test.parquet")

    # ── Write dataset card ─────────────────────────────────────────────────
    # Small JSON provenance record alongside the splits.
    card = {
        "name": "NL2SQL-Bench Synthetic Training Dataset",
        "version": "1.0",
        "generated_at": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        "total_records": len(deduped),
        "splits": {"train": len(train), "val": len(val), "test": len(test)},
        "domains": args.domains,
        "difficulties": args.difficulties,
        "mode": args.mode,
        "seed": args.seed,
        "sql_guarantee": (
            "Every SQL in this dataset was human-authored and execution-validated "
            "against a seeded SQLite database. Zero LLM-generated SQL."
        ),
    }
    with open(output_dir / "dataset_card.json", "w") as f:
        json.dump(card, f, indent=2)

    logger.info("=== Done! Dataset written to %s ===", output_dir)


if __name__ == "__main__":
    main()
|
data_factory/run_data_factory.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
run_data_factory.py
|
| 3 |
+
====================
|
| 4 |
+
Entry point and smoke-test runner for the NL2SQL Data Factory.
|
| 5 |
+
|
| 6 |
+
Run this FIRST before running the full pipeline to verify:
|
| 7 |
+
1. All 66 SQL templates execute without errors
|
| 8 |
+
2. Rule augmentation produces diverse NL variants
|
| 9 |
+
3. Validators correctly accept/reject queries
|
| 10 |
+
4. Base pipeline generates well-formed JSONL records
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
# Smoke test only (fast, ~10 seconds)
|
| 14 |
+
python run_data_factory.py --smoke-test
|
| 15 |
+
|
| 16 |
+
# Base mode (no GPU, generates all rule-augmented records)
|
| 17 |
+
python run_data_factory.py --mode base
|
| 18 |
+
|
| 19 |
+
# Full mode (H100 required)
|
| 20 |
+
python run_data_factory.py --mode full --model meta-llama/Meta-Llama-3-70B-Instruct --tensor-parallel 4
|
| 21 |
+
|
| 22 |
+
# Preview what the dataset looks like
|
| 23 |
+
python run_data_factory.py --smoke-test --show-samples 3
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
from __future__ import annotations
|
| 27 |
+
|
| 28 |
+
import argparse
|
| 29 |
+
import json
|
| 30 |
+
import sys
|
| 31 |
+
import textwrap
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
|
| 34 |
+
# Allow running from project root
|
| 35 |
+
sys.path.insert(0, str(Path(__file__).parent))
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 39 |
+
# SMOKE TEST
|
| 40 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
def run_smoke_test(show_samples: int = 0) -> bool:
    """Run four fast sanity checks; return True only when all of them pass.

    Checks: (1) every SQL template validates against the seeded database,
    (2) rule-based augmentation yields NL variants, (3) the SQL validator
    accepts/rejects the expected queries, (4) a mini base pipeline on the
    first 5 templates produces well-formed records.

    ``show_samples`` > 0 additionally prints that many sample NL variants
    and one sample record.
    """
    print("\n" + "=" * 60)
    print(" NL2SQL DATA FACTORY — SMOKE TEST")
    print("=" * 60)

    all_passed = True

    # 1. Template validation — every template's SQL must execute cleanly.
    print("\n[1/4] Validating all SQL templates against seeded data...")
    from data_factory.templates import ALL_TEMPLATES, template_stats
    from data_factory.validator import validate_all_templates

    stats = template_stats()
    result = validate_all_templates(ALL_TEMPLATES)

    print(f" Templates: {stats}")
    print(f" Validation: {result['passed']}/{result['total']} passed", end="")

    if result["failed"]:
        print(f" ← {result['failed']} FAILURES:")
        for f in result["failures"]:
            print(f" [{f['domain']}] {f['sql']}... → {f['error']}")
        all_passed = False
    else:
        print(" ✓")

    # 2. Rule augmentation — each probe question must yield ≥1 variant.
    print("\n[2/4] Testing rule-based augmentation...")
    from data_factory.augmentor import augment_nl

    test_nls = [
        "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
        "Which medications are prescribed most often? Return medication_name, category, times_prescribed.",
        "Rank active employees by salary within their department. Return salary_rank.",
    ]
    for nl in test_nls:
        variants = augment_nl(nl, n=3, seed=42)
        if not variants:
            print(f" FAIL: No variants generated for: {nl[:50]}")
            all_passed = False
        else:
            print(f" ✓ {len(variants)} variants from: '{nl[:45]}...'")
            if show_samples > 0:
                for i, v in enumerate(variants[:show_samples]):
                    print(f" [{i+1}] {v}")

    # 3. Validator accept/reject — expected pass/fail verdict per query.
    print("\n[3/4] Testing SQL validator accept/reject logic...")
    from data_factory.validator import SQLValidator

    v = SQLValidator("ecommerce")
    # (sql, expected_pass, label) triples.
    tests = [
        ("SELECT id, name FROM customers WHERE tier = 'gold'", True, "valid SELECT"),
        ("INSERT INTO customers VALUES (1,'x','x@x.com','IN','gold','2024-01-01')", False, "rejected INSERT"),
        ("SELECT nonexistent_col FROM customers", False, "bad column name"),
        ("", False, "empty string"),
    ]
    for sql, expect_pass, label in tests:
        vr = v.validate(sql)
        status = "✓" if vr.passed == expect_pass else "✗"
        print(f" {status} {label}: passed={vr.passed}", end="")
        if not vr.passed:
            print(f" (error: {vr.error})", end="")
        print()
        if vr.passed != expect_pass:
            all_passed = False
    v.close()

    # 4. Mini base pipeline (first 5 templates only) — end-to-end record check.
    print("\n[4/4] Running mini base pipeline (first 5 templates)...")
    from data_factory.pipeline import run_base_pipeline

    mini_templates = ALL_TEMPLATES[:5]
    records = run_base_pipeline(mini_templates, n_augmentations=2, seed=42)
    expected_min = 5  # at least the canonical NLs, one per template
    if len(records) < expected_min:
        print(f" FAIL: Only {len(records)} records (expected ≥{expected_min})")
        all_passed = False
    else:
        print(f" ✓ Generated {len(records)} records from 5 templates")

    # Validate structure of the first few records.
    required_keys = {"prompt", "sql", "metadata"}
    for rec in records[:3]:
        missing = required_keys - rec.keys()
        if missing:
            print(f" FAIL: Record missing keys: {missing}")
            all_passed = False
            break
    else:
        # for/else: runs only when no record was missing keys.
        print(" ✓ Record structure validated")

    if show_samples > 0 and records:
        print(f"\n --- Sample Record ---")
        sample = records[0]
        print(f" Domain: {sample['metadata']['domain']}")
        print(f" Difficulty: {sample['metadata']['difficulty']}")
        print(f" Persona: {sample['metadata']['persona']}")
        print(f" NL: {sample['prompt'][1]['content'].split('QUESTION: ')[-1][:100]}")
        print(f" SQL: {sample['sql'][:80]}...")

    # Summary
    print("\n" + "=" * 60)
    if all_passed:
        print(" ALL SMOKE TESTS PASSED ✓")
        print(" Safe to run: python run_data_factory.py --mode base")
    else:
        print(" SOME TESTS FAILED ✗ — fix errors before running pipeline")
    print("=" * 60 + "\n")

    return all_passed
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 156 |
+
# INSPECT DATASET
|
| 157 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 158 |
+
|
| 159 |
+
def inspect_dataset(jsonl_path: str, n: int = 5) -> None:
    """Pretty-print up to *n* records from an output JSONL file.

    Prints a short header, then one three-line summary per record
    (metadata, NL question, SQL). Missing files are reported, not raised.
    """
    path = Path(jsonl_path)
    if not path.exists():
        print(f"File not found: {path}")
        return

    # Read at most the first n lines of the file.
    records = []
    with open(path, encoding="utf-8") as handle:
        for raw_line in handle:
            if len(records) >= n:
                break
            records.append(json.loads(raw_line))

    print(f"\n{'='*65}")
    print(f" Showing {len(records)} records from {path.name}")
    print(f"{'='*65}")

    for idx, rec in enumerate(records, start=1):
        meta = rec["metadata"]
        question = rec["prompt"][1]["content"].split("QUESTION:")[-1].strip()
        print(f"\n[{idx}] Domain={meta['domain']} | Difficulty={meta['difficulty']} | "
              f"Persona={meta['persona']} | Source={meta['source']}")
        print(f" NL: {textwrap.shorten(question, 90)}")
        print(f" SQL: {textwrap.shorten(rec['sql'], 90)}")

    print()
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 190 |
+
# MAIN
|
| 191 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 192 |
+
|
| 193 |
+
def main() -> None:
    """Entry point: dispatch to smoke test, dataset inspection, or pipeline.

    ``--smoke-test`` and ``--inspect`` are handled here and exit; any other
    invocation is forwarded to ``data_factory.pipeline.main``, which
    re-parses ``sys.argv`` with its own parser. The pipeline flags are
    duplicated below so they appear in this script's ``--help`` and pass
    this parser's validation before forwarding.
    """
    parser = argparse.ArgumentParser(
        description="NL2SQL Data Factory — entry point.",
        formatter_class=argparse.RawTextHelpFormatter,
    )
    parser.add_argument(
        "--smoke-test", action="store_true",
        help="Run smoke test only (validates all templates, no output written).",
    )
    parser.add_argument(
        "--show-samples", type=int, default=0,
        help="During smoke test, show N sample NL variants and records.",
    )
    parser.add_argument(
        "--inspect", type=str, default=None,
        help="Path to a JSONL output file to inspect.",
    )
    parser.add_argument(
        "--inspect-n", type=int, default=5,
        help="Number of records to show when inspecting.",
    )
    parser.add_argument(
        "--mode", choices=["base", "full"], default="base",
        help=(
            "base: rule augmentation only, ~450 records, no GPU needed.\n"
            "full: + vLLM persona variants, 500K+ records, H100 required."
        ),
    )
    # Mirrors of the pipeline's own flags (see module docstring note above).
    parser.add_argument("--model", default="meta-llama/Meta-Llama-3-70B-Instruct")
    parser.add_argument("--tensor-parallel", type=int, default=4)
    parser.add_argument("--n-rule-augments", type=int, default=5)
    parser.add_argument("--n-persona-variants", type=int, default=10)
    parser.add_argument("--batch-size", type=int, default=64)
    parser.add_argument("--temperature", type=float, default=0.85)
    parser.add_argument("--output-dir", default="generated_data/output")
    parser.add_argument("--checkpoint-dir", default="generated_data/checkpoints")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--no-parquet", action="store_true")
    parser.add_argument("--resume", action="store_true")
    parser.add_argument(
        "--domains", nargs="+",
        choices=["ecommerce","healthcare","finance","hr"],
        default=["ecommerce","healthcare","finance","hr"],
    )
    parser.add_argument(
        "--difficulties", nargs="+",
        choices=["easy","medium","hard"],
        default=["easy","medium","hard"],
    )

    args = parser.parse_args()

    # Exit code mirrors the smoke-test verdict (0 = all checks passed).
    if args.smoke_test:
        ok = run_smoke_test(show_samples=args.show_samples)
        sys.exit(0 if ok else 1)

    if args.inspect:
        inspect_dataset(args.inspect, n=args.inspect_n)
        sys.exit(0)

    # Forward to pipeline
    from data_factory.pipeline import main as pipeline_main
    # Re-parse with pipeline's own parser by forwarding sys.argv
    pipeline_main()


if __name__ == "__main__":
    main()
|
data_factory/schemas.py
ADDED
|
@@ -0,0 +1,564 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/schemas.py
|
| 3 |
+
========================
|
| 4 |
+
SQLite CREATE TABLE statements for all four domains.
|
| 5 |
+
Each schema is fully self-contained and has been verified to create
|
| 6 |
+
without errors in SQLite 3.x.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
import sqlite3
|
| 11 |
+
import random
|
| 12 |
+
from datetime import date, timedelta
|
| 13 |
+
from typing import Callable
|
| 14 |
+
|
| 15 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 16 |
+
# SQL SCHEMAS
|
| 17 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 18 |
+
|
| 19 |
+
ECOMMERCE_SCHEMA = """
|
| 20 |
+
CREATE TABLE IF NOT EXISTS categories (
|
| 21 |
+
id INTEGER PRIMARY KEY,
|
| 22 |
+
name TEXT NOT NULL UNIQUE
|
| 23 |
+
);
|
| 24 |
+
|
| 25 |
+
CREATE TABLE IF NOT EXISTS products (
|
| 26 |
+
id INTEGER PRIMARY KEY,
|
| 27 |
+
name TEXT NOT NULL,
|
| 28 |
+
category_id INTEGER NOT NULL REFERENCES categories(id),
|
| 29 |
+
price REAL NOT NULL CHECK(price >= 0),
|
| 30 |
+
stock_quantity INTEGER NOT NULL DEFAULT 0
|
| 31 |
+
);
|
| 32 |
+
|
| 33 |
+
CREATE TABLE IF NOT EXISTS customers (
|
| 34 |
+
id INTEGER PRIMARY KEY,
|
| 35 |
+
name TEXT NOT NULL,
|
| 36 |
+
email TEXT NOT NULL UNIQUE,
|
| 37 |
+
country TEXT NOT NULL,
|
| 38 |
+
tier TEXT NOT NULL DEFAULT 'bronze'
|
| 39 |
+
CHECK(tier IN ('bronze', 'silver', 'gold')),
|
| 40 |
+
created_at TEXT NOT NULL
|
| 41 |
+
);
|
| 42 |
+
|
| 43 |
+
CREATE TABLE IF NOT EXISTS orders (
|
| 44 |
+
id INTEGER PRIMARY KEY,
|
| 45 |
+
customer_id INTEGER NOT NULL REFERENCES customers(id),
|
| 46 |
+
status TEXT NOT NULL DEFAULT 'pending'
|
| 47 |
+
CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
|
| 48 |
+
created_at TEXT NOT NULL,
|
| 49 |
+
total_amount REAL NOT NULL CHECK(total_amount >= 0)
|
| 50 |
+
);
|
| 51 |
+
|
| 52 |
+
CREATE TABLE IF NOT EXISTS order_items (
|
| 53 |
+
id INTEGER PRIMARY KEY,
|
| 54 |
+
order_id INTEGER NOT NULL REFERENCES orders(id),
|
| 55 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 56 |
+
quantity INTEGER NOT NULL CHECK(quantity > 0),
|
| 57 |
+
unit_price REAL NOT NULL CHECK(unit_price >= 0)
|
| 58 |
+
);
|
| 59 |
+
|
| 60 |
+
CREATE TABLE IF NOT EXISTS reviews (
|
| 61 |
+
id INTEGER PRIMARY KEY,
|
| 62 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 63 |
+
customer_id INTEGER NOT NULL REFERENCES customers(id),
|
| 64 |
+
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
|
| 65 |
+
created_at TEXT NOT NULL
|
| 66 |
+
);
|
| 67 |
+
|
| 68 |
+
CREATE INDEX IF NOT EXISTS idx_products_category ON products(category_id);
|
| 69 |
+
CREATE INDEX IF NOT EXISTS idx_orders_customer ON orders(customer_id);
|
| 70 |
+
CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status);
|
| 71 |
+
CREATE INDEX IF NOT EXISTS idx_orders_created ON orders(created_at);
|
| 72 |
+
CREATE INDEX IF NOT EXISTS idx_order_items_order ON order_items(order_id);
|
| 73 |
+
CREATE INDEX IF NOT EXISTS idx_order_items_product ON order_items(product_id);
|
| 74 |
+
CREATE INDEX IF NOT EXISTS idx_reviews_product ON reviews(product_id);
|
| 75 |
+
CREATE INDEX IF NOT EXISTS idx_customers_tier ON customers(tier);
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
HEALTHCARE_SCHEMA = """
|
| 79 |
+
CREATE TABLE IF NOT EXISTS patients (
|
| 80 |
+
id INTEGER PRIMARY KEY,
|
| 81 |
+
name TEXT NOT NULL,
|
| 82 |
+
date_of_birth TEXT NOT NULL,
|
| 83 |
+
gender TEXT NOT NULL CHECK(gender IN ('M','F','Other')),
|
| 84 |
+
blood_type TEXT NOT NULL,
|
| 85 |
+
country TEXT NOT NULL,
|
| 86 |
+
registered_at TEXT NOT NULL
|
| 87 |
+
);
|
| 88 |
+
|
| 89 |
+
CREATE TABLE IF NOT EXISTS doctors (
|
| 90 |
+
id INTEGER PRIMARY KEY,
|
| 91 |
+
name TEXT NOT NULL,
|
| 92 |
+
specialization TEXT NOT NULL,
|
| 93 |
+
department TEXT NOT NULL,
|
| 94 |
+
experience_years INTEGER NOT NULL CHECK(experience_years >= 0),
|
| 95 |
+
consultation_fee REAL NOT NULL CHECK(consultation_fee >= 0)
|
| 96 |
+
);
|
| 97 |
+
|
| 98 |
+
CREATE TABLE IF NOT EXISTS appointments (
|
| 99 |
+
id INTEGER PRIMARY KEY,
|
| 100 |
+
patient_id INTEGER NOT NULL REFERENCES patients(id),
|
| 101 |
+
doctor_id INTEGER NOT NULL REFERENCES doctors(id),
|
| 102 |
+
scheduled_at TEXT NOT NULL,
|
| 103 |
+
status TEXT NOT NULL
|
| 104 |
+
CHECK(status IN ('scheduled','completed','cancelled','no_show')),
|
| 105 |
+
notes TEXT
|
| 106 |
+
);
|
| 107 |
+
|
| 108 |
+
CREATE TABLE IF NOT EXISTS diagnoses (
|
| 109 |
+
id INTEGER PRIMARY KEY,
|
| 110 |
+
appointment_id INTEGER NOT NULL REFERENCES appointments(id),
|
| 111 |
+
icd_code TEXT NOT NULL,
|
| 112 |
+
description TEXT NOT NULL,
|
| 113 |
+
severity TEXT NOT NULL CHECK(severity IN ('mild','moderate','severe'))
|
| 114 |
+
);
|
| 115 |
+
|
| 116 |
+
CREATE TABLE IF NOT EXISTS medications (
|
| 117 |
+
id INTEGER PRIMARY KEY,
|
| 118 |
+
name TEXT NOT NULL,
|
| 119 |
+
category TEXT NOT NULL,
|
| 120 |
+
unit_price REAL NOT NULL CHECK(unit_price >= 0)
|
| 121 |
+
);
|
| 122 |
+
|
| 123 |
+
CREATE TABLE IF NOT EXISTS prescriptions (
|
| 124 |
+
id INTEGER PRIMARY KEY,
|
| 125 |
+
appointment_id INTEGER NOT NULL REFERENCES appointments(id),
|
| 126 |
+
medication_id INTEGER NOT NULL REFERENCES medications(id),
|
| 127 |
+
dosage TEXT NOT NULL,
|
| 128 |
+
duration_days INTEGER NOT NULL CHECK(duration_days > 0),
|
| 129 |
+
quantity INTEGER NOT NULL CHECK(quantity > 0)
|
| 130 |
+
);
|
| 131 |
+
|
| 132 |
+
CREATE INDEX IF NOT EXISTS idx_appt_patient ON appointments(patient_id);
|
| 133 |
+
CREATE INDEX IF NOT EXISTS idx_appt_doctor ON appointments(doctor_id);
|
| 134 |
+
CREATE INDEX IF NOT EXISTS idx_appt_status ON appointments(status);
|
| 135 |
+
CREATE INDEX IF NOT EXISTS idx_diag_appt ON diagnoses(appointment_id);
|
| 136 |
+
CREATE INDEX IF NOT EXISTS idx_presc_appt ON prescriptions(appointment_id);
|
| 137 |
+
CREATE INDEX IF NOT EXISTS idx_presc_med ON prescriptions(medication_id);
|
| 138 |
+
"""
|
| 139 |
+
|
| 140 |
+
FINANCE_SCHEMA = """
|
| 141 |
+
CREATE TABLE IF NOT EXISTS fin_customers (
|
| 142 |
+
id INTEGER PRIMARY KEY,
|
| 143 |
+
name TEXT NOT NULL,
|
| 144 |
+
email TEXT NOT NULL UNIQUE,
|
| 145 |
+
country TEXT NOT NULL,
|
| 146 |
+
kyc_status TEXT NOT NULL CHECK(kyc_status IN ('pending','verified','rejected')),
|
| 147 |
+
created_at TEXT NOT NULL
|
| 148 |
+
);
|
| 149 |
+
|
| 150 |
+
CREATE TABLE IF NOT EXISTS accounts (
|
| 151 |
+
id INTEGER PRIMARY KEY,
|
| 152 |
+
customer_id INTEGER NOT NULL REFERENCES fin_customers(id),
|
| 153 |
+
account_type TEXT NOT NULL
|
| 154 |
+
CHECK(account_type IN ('savings','current','fixed_deposit','loan')),
|
| 155 |
+
balance REAL NOT NULL DEFAULT 0,
|
| 156 |
+
currency TEXT NOT NULL DEFAULT 'USD',
|
| 157 |
+
status TEXT NOT NULL CHECK(status IN ('active','dormant','closed')),
|
| 158 |
+
opened_at TEXT NOT NULL
|
| 159 |
+
);
|
| 160 |
+
|
| 161 |
+
CREATE TABLE IF NOT EXISTS transactions (
|
| 162 |
+
id INTEGER PRIMARY KEY,
|
| 163 |
+
account_id INTEGER NOT NULL REFERENCES accounts(id),
|
| 164 |
+
txn_type TEXT NOT NULL CHECK(txn_type IN ('credit','debit')),
|
| 165 |
+
amount REAL NOT NULL CHECK(amount > 0),
|
| 166 |
+
currency TEXT NOT NULL DEFAULT 'USD',
|
| 167 |
+
merchant TEXT,
|
| 168 |
+
created_at TEXT NOT NULL
|
| 169 |
+
);
|
| 170 |
+
|
| 171 |
+
CREATE TABLE IF NOT EXISTS loans (
|
| 172 |
+
id INTEGER PRIMARY KEY,
|
| 173 |
+
customer_id INTEGER NOT NULL REFERENCES fin_customers(id),
|
| 174 |
+
loan_type TEXT NOT NULL
|
| 175 |
+
CHECK(loan_type IN ('personal','home','auto','business')),
|
| 176 |
+
principal_amount REAL NOT NULL,
|
| 177 |
+
interest_rate REAL NOT NULL,
|
| 178 |
+
tenure_months INTEGER NOT NULL,
|
| 179 |
+
status TEXT NOT NULL CHECK(status IN ('active','closed','defaulted')),
|
| 180 |
+
disbursed_at TEXT NOT NULL
|
| 181 |
+
);
|
| 182 |
+
|
| 183 |
+
CREATE TABLE IF NOT EXISTS loan_payments (
|
| 184 |
+
id INTEGER PRIMARY KEY,
|
| 185 |
+
loan_id INTEGER NOT NULL REFERENCES loans(id),
|
| 186 |
+
amount_paid REAL NOT NULL CHECK(amount_paid > 0),
|
| 187 |
+
payment_date TEXT NOT NULL,
|
| 188 |
+
is_late INTEGER NOT NULL DEFAULT 0 CHECK(is_late IN (0,1))
|
| 189 |
+
);
|
| 190 |
+
|
| 191 |
+
CREATE INDEX IF NOT EXISTS idx_acct_customer ON accounts(customer_id);
|
| 192 |
+
CREATE INDEX IF NOT EXISTS idx_txn_account ON transactions(account_id);
|
| 193 |
+
CREATE INDEX IF NOT EXISTS idx_txn_type ON transactions(txn_type);
|
| 194 |
+
CREATE INDEX IF NOT EXISTS idx_loan_customer ON loans(customer_id);
|
| 195 |
+
CREATE INDEX IF NOT EXISTS idx_lp_loan ON loan_payments(loan_id);
|
| 196 |
+
"""
|
| 197 |
+
|
| 198 |
+
HR_SCHEMA = """
|
| 199 |
+
CREATE TABLE IF NOT EXISTS departments (
|
| 200 |
+
id INTEGER PRIMARY KEY,
|
| 201 |
+
name TEXT NOT NULL UNIQUE,
|
| 202 |
+
location TEXT NOT NULL,
|
| 203 |
+
budget REAL NOT NULL CHECK(budget >= 0)
|
| 204 |
+
);
|
| 205 |
+
|
| 206 |
+
CREATE TABLE IF NOT EXISTS employees (
|
| 207 |
+
id INTEGER PRIMARY KEY,
|
| 208 |
+
name TEXT NOT NULL,
|
| 209 |
+
email TEXT NOT NULL UNIQUE,
|
| 210 |
+
department_id INTEGER NOT NULL REFERENCES departments(id),
|
| 211 |
+
job_title TEXT NOT NULL,
|
| 212 |
+
hire_date TEXT NOT NULL,
|
| 213 |
+
salary REAL NOT NULL CHECK(salary >= 0),
|
| 214 |
+
status TEXT NOT NULL CHECK(status IN ('active','resigned','terminated'))
|
| 215 |
+
);
|
| 216 |
+
|
| 217 |
+
CREATE TABLE IF NOT EXISTS performance_reviews (
|
| 218 |
+
id INTEGER PRIMARY KEY,
|
| 219 |
+
employee_id INTEGER NOT NULL REFERENCES employees(id),
|
| 220 |
+
review_year INTEGER NOT NULL,
|
| 221 |
+
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
|
| 222 |
+
reviewer_id INTEGER NOT NULL REFERENCES employees(id),
|
| 223 |
+
comments TEXT
|
| 224 |
+
);
|
| 225 |
+
|
| 226 |
+
CREATE TABLE IF NOT EXISTS projects (
|
| 227 |
+
id INTEGER PRIMARY KEY,
|
| 228 |
+
name TEXT NOT NULL,
|
| 229 |
+
department_id INTEGER NOT NULL REFERENCES departments(id),
|
| 230 |
+
start_date TEXT NOT NULL,
|
| 231 |
+
end_date TEXT,
|
| 232 |
+
budget REAL NOT NULL,
|
| 233 |
+
status TEXT NOT NULL
|
| 234 |
+
CHECK(status IN ('planned','active','completed','cancelled'))
|
| 235 |
+
);
|
| 236 |
+
|
| 237 |
+
CREATE TABLE IF NOT EXISTS project_assignments (
|
| 238 |
+
id INTEGER PRIMARY KEY,
|
| 239 |
+
employee_id INTEGER NOT NULL REFERENCES employees(id),
|
| 240 |
+
project_id INTEGER NOT NULL REFERENCES projects(id),
|
| 241 |
+
role TEXT NOT NULL,
|
| 242 |
+
hours_allocated INTEGER NOT NULL CHECK(hours_allocated > 0)
|
| 243 |
+
);
|
| 244 |
+
|
| 245 |
+
CREATE INDEX IF NOT EXISTS idx_emp_dept ON employees(department_id);
|
| 246 |
+
CREATE INDEX IF NOT EXISTS idx_emp_status ON employees(status);
|
| 247 |
+
CREATE INDEX IF NOT EXISTS idx_pr_employee ON performance_reviews(employee_id);
|
| 248 |
+
CREATE INDEX IF NOT EXISTS idx_proj_dept ON projects(department_id);
|
| 249 |
+
CREATE INDEX IF NOT EXISTS idx_pa_employee ON project_assignments(employee_id);
|
| 250 |
+
CREATE INDEX IF NOT EXISTS idx_pa_project ON project_assignments(project_id);
|
| 251 |
+
"""
|
| 252 |
+
|
| 253 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 254 |
+
# SCHEMA REGISTRY
|
| 255 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 256 |
+
|
| 257 |
+
# Registry: domain name → DDL script.  Consumed by build_connection(), which
# executes the value with executescript().
SCHEMA_MAP: dict[str, str] = {
    "ecommerce": ECOMMERCE_SCHEMA,
    "healthcare": HEALTHCARE_SCHEMA,
    "finance": FINANCE_SCHEMA,
    "hr": HR_SCHEMA,
}
|
| 263 |
+
|
| 264 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 265 |
+
# COMPACT SCHEMA CONTEXT (injected into every training prompt)
|
| 266 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 267 |
+
|
| 268 |
+
# Human-readable schema summaries, one per domain, intended for injection into
# training prompts.  These describe the same tables as the DDL constants above;
# keep the two in sync when the schemas change.
SCHEMA_CONTEXT: dict[str, str] = {
    "ecommerce": """\
Database: ecommerce (SQLite, read-only)

TABLES
------
categories(id INTEGER PK, name TEXT)
products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id, price REAL, stock_quantity INTEGER)
customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601)
orders(id INTEGER PK, customer_id INTEGER FK→customers.id, status TEXT ∈ {pending|processing|shipped|delivered|cancelled}, created_at TEXT ISO-8601, total_amount REAL)
order_items(id INTEGER PK, order_id INTEGER FK→orders.id, product_id INTEGER FK→products.id, quantity INTEGER, unit_price REAL)
reviews(id INTEGER PK, product_id INTEGER FK→products.id, customer_id INTEGER FK→customers.id, rating INTEGER 1-5, created_at TEXT ISO-8601)

NOTES
-----
- Use created_at >= '2024-01-01' for date filtering (ISO text sort works)
- SQLite window functions: RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD
- strftime('%Y-%m', created_at) returns 'YYYY-MM'
- All monetary values in USD
""",

    "healthcare": """\
Database: healthcare (SQLite, read-only)

TABLES
------
patients(id INTEGER PK, name TEXT, date_of_birth TEXT ISO-8601, gender TEXT ∈ {M|F|Other}, blood_type TEXT, country TEXT, registered_at TEXT ISO-8601)
doctors(id INTEGER PK, name TEXT, specialization TEXT, department TEXT, experience_years INTEGER, consultation_fee REAL)
appointments(id INTEGER PK, patient_id INTEGER FK→patients.id, doctor_id INTEGER FK→doctors.id, scheduled_at TEXT ISO-8601, status TEXT ∈ {scheduled|completed|cancelled|no_show}, notes TEXT nullable)
diagnoses(id INTEGER PK, appointment_id INTEGER FK→appointments.id, icd_code TEXT, description TEXT, severity TEXT ∈ {mild|moderate|severe})
medications(id INTEGER PK, name TEXT, category TEXT, unit_price REAL)
prescriptions(id INTEGER PK, appointment_id INTEGER FK→appointments.id, medication_id INTEGER FK→medications.id, dosage TEXT, duration_days INTEGER, quantity INTEGER)

NOTES
-----
- consultation_fee is in USD per visit
- ICD codes follow WHO ICD-10 format (e.g. 'I10', 'E11')
- SQLite window functions available
""",

    "finance": """\
Database: finance (SQLite, read-only)

TABLES
------
fin_customers(id INTEGER PK, name TEXT, email TEXT, country TEXT, kyc_status TEXT ∈ {pending|verified|rejected}, created_at TEXT ISO-8601)
accounts(id INTEGER PK, customer_id INTEGER FK→fin_customers.id, account_type TEXT ∈ {savings|current|fixed_deposit|loan}, balance REAL, currency TEXT, status TEXT ∈ {active|dormant|closed}, opened_at TEXT ISO-8601)
transactions(id INTEGER PK, account_id INTEGER FK→accounts.id, txn_type TEXT ∈ {credit|debit}, amount REAL, currency TEXT, merchant TEXT nullable, created_at TEXT ISO-8601)
loans(id INTEGER PK, customer_id INTEGER FK→fin_customers.id, loan_type TEXT ∈ {personal|home|auto|business}, principal_amount REAL, interest_rate REAL, tenure_months INTEGER, status TEXT ∈ {active|closed|defaulted}, disbursed_at TEXT ISO-8601)
loan_payments(id INTEGER PK, loan_id INTEGER FK→loans.id, amount_paid REAL, payment_date TEXT ISO-8601, is_late INTEGER ∈ {0|1})

NOTES
-----
- All monetary values in USD unless currency column specifies otherwise
- is_late = 1 means the payment was overdue
- SQLite window functions available
""",

    "hr": """\
Database: hr (SQLite, read-only)

TABLES
------
departments(id INTEGER PK, name TEXT, location TEXT, budget REAL)
employees(id INTEGER PK, name TEXT, email TEXT, department_id INTEGER FK→departments.id, job_title TEXT, hire_date TEXT ISO-8601, salary REAL, status TEXT ∈ {active|resigned|terminated})
performance_reviews(id INTEGER PK, employee_id INTEGER FK→employees.id, review_year INTEGER, rating INTEGER 1-5, reviewer_id INTEGER FK→employees.id, comments TEXT nullable)
projects(id INTEGER PK, name TEXT, department_id INTEGER FK→departments.id, start_date TEXT ISO-8601, end_date TEXT nullable, budget REAL, status TEXT ∈ {planned|active|completed|cancelled})
project_assignments(id INTEGER PK, employee_id INTEGER FK→employees.id, project_id INTEGER FK→projects.id, role TEXT, hours_allocated INTEGER)

NOTES
-----
- salary is annual in USD
- performance rating: 1 (lowest) to 5 (highest)
- end_date is NULL for ongoing projects
- SQLite window functions available
""",
}
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 348 |
+
# SEED FUNCTIONS (deterministic, SEED=42)
|
| 349 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 350 |
+
|
| 351 |
+
def _rdate(rng: random.Random, start: str = "2022-01-01", end: str = "2024-12-31") -> str:
|
| 352 |
+
s = date.fromisoformat(start)
|
| 353 |
+
e = date.fromisoformat(end)
|
| 354 |
+
return (s + timedelta(days=rng.randint(0, (e - s).days))).isoformat()
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def seed_ecommerce(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Fill the ecommerce schema with deterministic pseudo-random rows.

    Inserts 8 categories, 16 products, 50 customers, 200 orders,
    300 order items and 150 reviews, then commits.  All randomness comes
    from a local random.Random(seed), so identical seeds yield identical
    data; rng draws happen left-to-right within each row tuple, in row
    order.
    """
    rng = random.Random(seed)

    category_names = ["Electronics", "Clothing", "Books", "Home & Garden",
                      "Sports & Outdoors", "Toys & Games", "Beauty", "Automotive"]
    conn.executemany("INSERT INTO categories(id,name) VALUES(?,?)",
                     enumerate(category_names, 1))

    # Fixed product catalog: (id, name, category_id, price, stock_quantity).
    product_rows = [
        (1, "Wireless Headphones", 1, 149.99, 50), (2, "Laptop Stand", 1, 59.99, 120),
        (3, "USB-C Hub", 1, 49.99, 90), (4, "Webcam 4K", 1, 89.99, 30),
        (5, "Cotton T-Shirt", 2, 19.99, 200), (6, "Winter Jacket", 2, 129.99, 60),
        (7, "Running Shorts", 2, 34.99, 150), (8, "Clean Code", 3, 39.99, 80),
        (9, "Deep Learning Book", 3, 59.99, 45), (10, "Coffee Maker", 4, 89.99, 40),
        (11, "Air Purifier", 4, 199.99, 25), (12, "Yoga Mat", 5, 29.99, 150),
        (13, "Resistance Bands", 5, 14.99, 200), (14, "Lego City Set", 6, 79.99, 60),
        (15, "Face Serum", 7, 34.99, 100), (16, "Dash Cam", 8, 119.99, 35),
    ]
    conn.executemany("INSERT INTO products VALUES(?,?,?,?,?)", product_rows)

    country_pool = ["India", "USA", "Germany", "UK", "Canada", "Australia", "France", "Brazil"]
    tier_pool = ["bronze", "silver", "gold"]
    customer_rows = [
        (cid, f"Customer {cid}", f"cust{cid}@shop.com",
         rng.choice(country_pool), rng.choice(tier_pool), _rdate(rng))
        for cid in range(1, 51)
    ]
    conn.executemany("INSERT INTO customers VALUES(?,?,?,?,?,?)", customer_rows)

    status_pool = ["pending", "processing", "shipped", "delivered", "cancelled"]
    order_rows = [
        (oid, rng.randint(1, 50), rng.choice(status_pool),
         _rdate(rng), round(rng.uniform(20, 800), 2))
        for oid in range(1, 201)
    ]
    conn.executemany("INSERT INTO orders VALUES(?,?,?,?,?)", order_rows)

    item_rows = [
        (iid, rng.randint(1, 200), rng.randint(1, 16),
         rng.randint(1, 5), round(rng.uniform(10, 200), 2))
        for iid in range(1, 301)
    ]
    conn.executemany("INSERT INTO order_items VALUES(?,?,?,?,?)", item_rows)

    review_rows = [
        (rid, rng.randint(1, 16), rng.randint(1, 50),
         rng.randint(1, 5), _rdate(rng))
        for rid in range(1, 151)
    ]
    conn.executemany("INSERT INTO reviews VALUES(?,?,?,?,?)", review_rows)
    conn.commit()
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def seed_healthcare(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the healthcare tables with deterministic pseudo-random rows.

    Row counts: 8 doctors, 100 patients, 300 appointments, 200 diagnoses,
    10 medications, 250 prescriptions.  All randomness flows through a
    local random.Random(seed), so identical seeds yield identical data.
    The statement/draw order must not be changed, or reproducibility
    across existing seeds is lost.
    """
    rng = random.Random(seed)
    # One doctor per (specialization, department) pair; here the two match.
    specs = [("Cardiology","Cardiology"), ("Neurology","Neurology"),
             ("Orthopedics","Orthopedics"), ("Dermatology","Dermatology"),
             ("Pediatrics","Pediatrics"), ("Oncology","Oncology"),
             ("Endocrinology","Endocrinology"), ("Gastroenterology","Gastroenterology")]
    for i, (spec, dept) in enumerate(specs, 1):
        conn.execute("INSERT INTO doctors VALUES(?,?,?,?,?,?)",
                     (i, f"Dr. {['Smith','Patel','Kim','Müller','Okafor','Chen','Lopez','Roy'][i-1]}",
                      spec, dept, rng.randint(2, 25), round(rng.uniform(50, 350), 2)))

    genders = ["M", "F", "Other"]
    blood_types = ["A+","A-","B+","B-","O+","O-","AB+","AB-"]
    countries = ["India","USA","Germany","UK","Canada","Australia"]
    # date_of_birth 1950-2010, registration 2020-2024.
    for i in range(1, 101):
        conn.execute("INSERT INTO patients VALUES(?,?,?,?,?,?,?)",
                     (i, f"Patient {i}", _rdate(rng, "1950-01-01", "2010-01-01"),
                      rng.choice(genders), rng.choice(blood_types),
                      rng.choice(countries), _rdate(rng, "2020-01-01", "2024-12-31")))

    appt_statuses = ["scheduled", "completed", "cancelled", "no_show"]
    # Weighted status sampling: ~60% of appointments end up 'completed'.
    weights = [0.15, 0.60, 0.15, 0.10]
    for i in range(1, 301):
        conn.execute("INSERT INTO appointments VALUES(?,?,?,?,?,?)",
                     (i, rng.randint(1, 100), rng.randint(1, 8),
                      _rdate(rng, "2022-01-01", "2024-12-31"),
                      rng.choices(appt_statuses, weights)[0], None))

    icd_codes = ["I10","E11","J45","M54","K21","F32","G43","L30","N39","R05",
                 "C50","Z87","I25","E78","J18"]
    descs = ["Hypertension","Type 2 Diabetes","Asthma","Back Pain","GERD",
             "Depression","Migraine","Dermatitis","UTI","Cough",
             "Breast Cancer","Family History","Coronary Artery Disease",
             "Hyperlipidemia","Pneumonia"]
    severities = ["mild","moderate","severe"]
    # NOTE(review): icd_code and description are drawn independently, so a
    # row's code and text may not correspond — confirm this is acceptable.
    for i in range(1, 201):
        conn.execute("INSERT INTO diagnoses VALUES(?,?,?,?,?)",
                     (i, rng.randint(1, 300), rng.choice(icd_codes),
                      rng.choice(descs), rng.choice(severities)))

    # Fixed medication catalog: (name, category, unit_price).
    meds = [("Metformin","Antidiabetic",0.15),("Lisinopril","Antihypertensive",0.20),
            ("Atorvastatin","Statin",0.25),("Amoxicillin","Antibiotic",0.30),
            ("Ibuprofen","NSAID",0.10),("Omeprazole","PPI",0.18),
            ("Sertraline","Antidepressant",0.35),("Cetirizine","Antihistamine",0.08),
            ("Paracetamol","Analgesic",0.05),("Aspirin","Antiplatelet",0.07)]
    for i, (name, cat, price) in enumerate(meds, 1):
        conn.execute("INSERT INTO medications VALUES(?,?,?,?)", (i, name, cat, price))

    dosages = ["1x daily","2x daily","3x daily","once at night","as needed"]
    for i in range(1, 251):
        conn.execute("INSERT INTO prescriptions VALUES(?,?,?,?,?,?)",
                     (i, rng.randint(1, 300), rng.randint(1, 10),
                      rng.choice(dosages), rng.randint(5, 60), rng.randint(10, 90)))
    conn.commit()
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
def seed_finance(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the finance tables with deterministic pseudo-random rows.

    Row counts: 50 customers, 100 accounts, 500 transactions, 50 loans,
    200 loan payments.  All randomness comes from a local
    random.Random(seed); do not reorder statements or the data generated
    for existing seeds changes.
    """
    rng = random.Random(seed)
    countries = ["India","USA","Germany","UK","Singapore","UAE","Canada"]
    # Duplicated entries act as sampling weights: 'verified' is 3x as likely.
    kyc = ["pending","verified","verified","verified","rejected"]
    for i in range(1, 51):
        conn.execute("INSERT INTO fin_customers VALUES(?,?,?,?,?,?)",
                     (i, f"FinClient {i}", f"fincli{i}@bank.com",
                      rng.choice(countries), rng.choice(kyc), _rdate(rng)))

    acct_types = ["savings","savings","current","fixed_deposit"]  # 'savings' 2x weight
    statuses = ["active","active","active","dormant","closed"]    # 'active' 3x weight
    for i in range(1, 101):
        conn.execute("INSERT INTO accounts VALUES(?,?,?,?,?,?,?)",
                     (i, rng.randint(1, 50), rng.choice(acct_types),
                      round(rng.uniform(100, 100000), 2), "USD",
                      rng.choice(statuses), _rdate(rng)))

    # None is a valid merchant (the column is nullable in the schema).
    merchants = [None, "Amazon", "Walmart", "Netflix", "Uber", "Apple",
                 "Google Pay", "Zomato", "Flipkart", "Airbnb"]
    for i in range(1, 501):
        conn.execute("INSERT INTO transactions VALUES(?,?,?,?,?,?,?)",
                     (i, rng.randint(1, 100), rng.choice(["credit","debit"]),
                      round(rng.uniform(5, 10000), 2), "USD",
                      rng.choice(merchants), _rdate(rng)))

    loan_types = ["personal","home","auto","business"]
    loan_statuses = ["active","active","closed","defaulted"]  # 'active' 2x weight
    for i in range(1, 51):
        conn.execute("INSERT INTO loans VALUES(?,?,?,?,?,?,?,?)",
                     (i, rng.randint(1, 50), rng.choice(loan_types),
                      round(rng.uniform(5000, 500000), 2),
                      round(rng.uniform(5, 18), 2), rng.randint(12, 360),
                      rng.choice(loan_statuses), _rdate(rng)))

    # Payments reference loan ids 1-50; is_late is a 0/1 coin flip.
    for i in range(1, 201):
        conn.execute("INSERT INTO loan_payments VALUES(?,?,?,?,?)",
                     (i, rng.randint(1, 50), round(rng.uniform(500, 10000), 2),
                      _rdate(rng), rng.randint(0, 1)))
    conn.commit()
|
| 499 |
+
|
| 500 |
+
|
| 501 |
+
def seed_hr(conn: sqlite3.Connection, seed: int = 42) -> None:
    """Populate the HR tables with deterministic pseudo-random rows.

    Row counts: 8 departments, 100 employees, 200 performance reviews,
    50 projects, 250 project assignments.  All randomness comes from a
    local random.Random(seed); the draw order (including the conditional
    end_date draw below) must be preserved for reproducibility.
    """
    rng = random.Random(seed)
    # Fixed department list: (name, location, budget).
    depts = [("Engineering","Bangalore",8000000),("Marketing","Mumbai",3000000),
             ("Finance","Delhi",2000000),("HR","Chennai",1500000),
             ("Sales","Hyderabad",5000000),("Product","Pune",4000000),
             ("Legal","Delhi",1000000),("Operations","Kolkata",2500000)]
    for i, (name, loc, bud) in enumerate(depts, 1):
        conn.execute("INSERT INTO departments VALUES(?,?,?,?)", (i, name, loc, bud))

    titles = ["Software Engineer","Senior Engineer","Staff Engineer","Principal Engineer",
              "Engineering Manager","Product Manager","Data Analyst","Data Scientist",
              "Marketing Specialist","Sales Executive","HR Specialist","Finance Analyst",
              "Director","VP","Legal Counsel"]
    # Duplicated 'active' entries act as sampling weights (~2/3 active).
    statuses = ["active","active","active","active","resigned","terminated"]
    for i in range(1, 101):
        conn.execute("INSERT INTO employees VALUES(?,?,?,?,?,?,?,?)",
                     (i, f"Employee {i}", f"emp{i}@corp.com",
                      rng.randint(1, 8), rng.choice(titles),
                      _rdate(rng, "2015-01-01", "2024-01-01"),
                      round(rng.uniform(25000, 200000), 2), rng.choice(statuses)))

    # Reviewer is any employee id 1-100; self-review is possible by chance.
    for i in range(1, 201):
        conn.execute("INSERT INTO performance_reviews VALUES(?,?,?,?,?,?)",
                     (i, rng.randint(1, 100), rng.randint(2019, 2024),
                      rng.randint(1, 5), rng.randint(1, 100),
                      rng.choice(["Excellent work","Good performance","Needs improvement",
                                  "Outstanding","Meeting expectations"])))

    proj_statuses = ["planned","active","active","completed","cancelled"]  # 'active' 2x weight
    for i in range(1, 51):
        sd = _rdate(rng, "2021-01-01", "2024-01-01")
        # ~25% of projects get a NULL end_date; note the _rdate draw only
        # happens when rng.random() > 0.25 (conditional rng consumption).
        conn.execute("INSERT INTO projects VALUES(?,?,?,?,?,?,?)",
                     (i, f"Project {i}", rng.randint(1, 8), sd,
                      _rdate(rng, sd, "2025-06-01") if rng.random() > 0.25 else None,
                      round(rng.uniform(50000, 2000000), 2), rng.choice(proj_statuses)))

    roles = ["Lead","Senior Developer","Developer","Tester","Analyst","DevOps"]
    for i in range(1, 251):
        conn.execute("INSERT INTO project_assignments VALUES(?,?,?,?,?)",
                     (i, rng.randint(1, 100), rng.randint(1, 50),
                      rng.choice(roles), rng.randint(20, 400)))
    conn.commit()
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 546 |
+
# REGISTRY
|
| 547 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 548 |
+
|
| 549 |
+
# Registry: domain name → seed function.  Every callable has the signature
# (conn: sqlite3.Connection, seed: int = 42) -> None and commits on success.
SEED_MAP: dict[str, Callable[..., None]] = {
    "ecommerce": seed_ecommerce,
    "healthcare": seed_healthcare,
    "finance": seed_finance,
    "hr": seed_hr,
}
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def build_connection(domain: str, seed: int = 42) -> sqlite3.Connection:
    """Return a seeded in-memory SQLite connection for the given domain.

    Args:
        domain: One of the keys of SCHEMA_MAP ('ecommerce', 'healthcare',
            'finance', 'hr').
        seed: RNG seed forwarded to the domain's seed function, so the
            generated data is deterministic per (domain, seed).

    Returns:
        An open in-memory connection with the schema created, data seeded
        and committed, and foreign-key enforcement enabled.

    Raises:
        KeyError: If *domain* is not a registered domain; the message now
            lists the valid choices instead of echoing only the bad key.
    """
    if domain not in SCHEMA_MAP:
        raise KeyError(
            f"unknown domain {domain!r}; expected one of {sorted(SCHEMA_MAP)}"
        )
    # check_same_thread=False: NOTE(review) — presumably this connection is
    # shared across threads by the caller; confirm access is serialized.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    conn.row_factory = sqlite3.Row  # rows addressable by column name
    conn.execute("PRAGMA foreign_keys = ON")  # enforce the REFERENCES clauses
    conn.executescript(SCHEMA_MAP[domain])
    SEED_MAP[domain](conn, seed=seed)
    return conn
|
data_factory/templates.py
ADDED
|
@@ -0,0 +1,993 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/templates.py
|
| 3 |
+
==========================
|
| 4 |
+
Human-authored, execution-verified SQL templates across 4 domains × 3 difficulty tiers.
|
| 5 |
+
|
| 6 |
+
CRITICAL DESIGN PRINCIPLE:
|
| 7 |
+
SQL is NEVER generated by an LLM in this pipeline.
|
| 8 |
+
Every SQL here was written by hand and verified by running it against
|
| 9 |
+
seeded SQLite data. Zero errors guaranteed.
|
| 10 |
+
|
| 11 |
+
Structure per entry:
|
| 12 |
+
{
|
| 13 |
+
"domain": str, # ecommerce | healthcare | finance | hr
|
| 14 |
+
"difficulty": str, # easy | medium | hard
|
| 15 |
+
"sql": str, # verified ground-truth SQL
|
| 16 |
+
"description": str, # one-line English summary (seed for NL generation)
|
| 17 |
+
"base_nl": str, # canonical natural-language question
|
| 18 |
+
"has_order": bool, # True → comparison is order-sensitive
|
| 19 |
+
}
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
from __future__ import annotations
|
| 23 |
+
from typing import TypedDict
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Template(TypedDict):
    """One hand-written, execution-verified NL→SQL template.

    Each instance pairs a canonical natural-language question with its
    ground-truth SQL. The SQL is human-authored (never LLM-generated) and
    was verified against seeded SQLite data.
    """

    domain: str          # one of: ecommerce | healthcare | finance | hr
    difficulty: str      # one of: easy | medium | hard
    sql: str             # verified ground-truth SQL
    description: str     # one-line English summary (seed for NL generation)
    base_nl: str         # canonical natural-language question
    has_order: bool      # True → result comparison is order-sensitive
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 36 |
+
# DOMAIN: ECOMMERCE
|
| 37 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 38 |
+
|
| 39 |
+
# Ecommerce templates: 10 easy, 7 medium, 4 hard. SQL strings are built from
# adjacent string literals (implicit concatenation) so each query reads as
# pretty-printed SQL in source but is stored as a single line.
ECOMMERCE_TEMPLATES: list[Template] = [

    # ── EASY ────────────────────────────────────────────────────────────────

    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "List gold-tier customers sorted alphabetically with id, name, email, country",
        "base_nl": "List all gold-tier customers ordered by name alphabetically. Return id, name, email, country.",
        "sql": (
            "SELECT id, name, email, country "
            "FROM customers "
            "WHERE tier = 'gold' "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Products priced above $100, sorted by price descending",
        "base_nl": "Show all products with a price above $100, sorted from highest to lowest price. Return id, name, price.",
        "sql": (
            "SELECT id, name, price "
            "FROM products "
            "WHERE price > 100 "
            "ORDER BY price DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Delivered orders with total_amount > 200, sorted by amount descending",
        "base_nl": "Find all delivered orders with a total amount greater than $200, sorted by total amount descending. Return id, customer_id, total_amount, created_at.",
        "sql": (
            "SELECT id, customer_id, total_amount, created_at "
            "FROM orders "
            "WHERE status = 'delivered' "
            "  AND total_amount > 200 "
            "ORDER BY total_amount DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Top 5 most expensive products",
        "base_nl": "Return the top 5 most expensive products. Return id, name, price.",
        "sql": (
            "SELECT id, name, price "
            "FROM products "
            "ORDER BY price DESC "
            "LIMIT 5"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Distinct countries where customers come from, sorted alphabetically",
        "base_nl": "List all distinct countries our customers come from, sorted alphabetically. Return country.",
        "sql": (
            "SELECT DISTINCT country "
            "FROM customers "
            "ORDER BY country ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": False,
        "description": "Count total number of customers",
        "base_nl": "How many customers do we have in total? Return a single column total_customers.",
        "sql": "SELECT COUNT(*) AS total_customers FROM customers",
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Products with zero stock",
        "base_nl": "List all out-of-stock products. Return id, name, stock_quantity.",
        "sql": (
            "SELECT id, name, stock_quantity "
            "FROM products "
            "WHERE stock_quantity = 0 "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Customers from India sorted by name",
        "base_nl": "Show all customers from India, sorted by name. Return id, name, email.",
        "sql": (
            "SELECT id, name, email "
            "FROM customers "
            "WHERE country = 'India' "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": True,
        "description": "Products in a price range of $20 to $100 sorted by price ascending",
        "base_nl": "Which products are priced between $20 and $100? Sort by price ascending. Return id, name, price.",
        "sql": (
            "SELECT id, name, price "
            "FROM products "
            "WHERE price BETWEEN 20 AND 100 "
            "ORDER BY price ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "easy", "has_order": False,
        "description": "Count orders by status",
        "base_nl": "How many orders are there for each status? Return status, order_count.",
        "sql": (
            "SELECT status, COUNT(*) AS order_count "
            "FROM orders "
            "GROUP BY status"
        ),
    },

    # ── MEDIUM ───────────────────────────────────────────────────────────────

    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Order count per customer including those with zero orders, sorted by count desc",
        "base_nl": "How many orders has each customer placed? Include customers with zero orders. Return customer_name, order_count, sorted by order_count descending then customer_name ascending.",
        "sql": (
            "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
            "FROM customers c "
            "LEFT JOIN orders o ON c.id = o.customer_id "
            "GROUP BY c.id, c.name "
            "ORDER BY order_count DESC, customer_name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Average product rating per category sorted descending",
        "base_nl": "What is the average product rating per category? Only include categories with at least one review. Return category_name, avg_rating (rounded to 2 decimal places), sorted by avg_rating descending.",
        "sql": (
            "SELECT c.name AS category_name, "
            "       ROUND(AVG(r.rating), 2) AS avg_rating "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "JOIN reviews r ON r.product_id = p.id "
            "GROUP BY c.id, c.name "
            "ORDER BY avg_rating DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Customers who spent more than $500 on delivered orders",
        "base_nl": "Which customers have spent more than $500 total on delivered orders? Return customer_name, total_spent (rounded to 2 decimal places), sorted by total_spent descending.",
        "sql": (
            "SELECT c.name AS customer_name, "
            "       ROUND(SUM(o.total_amount), 2) AS total_spent "
            "FROM customers c "
            "JOIN orders o ON o.customer_id = c.id "
            "WHERE o.status = 'delivered' "
            "GROUP BY c.id, c.name "
            "HAVING SUM(o.total_amount) > 500 "
            "ORDER BY total_spent DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Total quantity sold per product sorted descending",
        "base_nl": "Show the total quantity sold for each product that appears in at least one order. Return product_name, total_quantity_sold, sorted by total_quantity_sold descending.",
        "sql": (
            "SELECT p.name AS product_name, "
            "       SUM(oi.quantity) AS total_quantity_sold "
            "FROM products p "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY total_quantity_sold DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Product count and average price per category sorted by count desc",
        "base_nl": "For each category, show the number of products and their average price. Return category_name, product_count, avg_price (rounded to 2 decimal places), sorted by product_count descending.",
        "sql": (
            "SELECT cat.name AS category_name, "
            "       COUNT(p.id) AS product_count, "
            "       ROUND(AVG(p.price), 2) AS avg_price "
            "FROM categories cat "
            "JOIN products p ON p.category_id = cat.id "
            "GROUP BY cat.id, cat.name "
            "ORDER BY product_count DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Categories with more than 5 in-stock products sorted by count desc",
        "base_nl": "Which categories have more than 5 products in stock (stock_quantity > 0)? Return category_name, in_stock_count, sorted by in_stock_count descending.",
        "sql": (
            "SELECT c.name AS category_name, "
            "       COUNT(p.id) AS in_stock_count "
            "FROM categories c "
            "JOIN products p ON p.category_id = c.id "
            "WHERE p.stock_quantity > 0 "
            "GROUP BY c.id, c.name "
            "HAVING COUNT(p.id) > 5 "
            "ORDER BY in_stock_count DESC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "medium", "has_order": True,
        "description": "Total revenue per product from order items, sorted descending",
        "base_nl": "What is the total revenue generated by each product from order items? Return product_name, total_revenue (rounded to 2 decimal places), sorted by total_revenue descending.",
        "sql": (
            "SELECT p.name AS product_name, "
            "       ROUND(SUM(oi.quantity * oi.unit_price), 2) AS total_revenue "
            "FROM products p "
            "JOIN order_items oi ON oi.product_id = p.id "
            "GROUP BY p.id, p.name "
            "ORDER BY total_revenue DESC"
        ),
    },

    # ── HARD ─────────────────────────────────────────────────────────────────

    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Customer spending rank using DENSE_RANK on delivered orders",
        "base_nl": "Rank customers by total spending on delivered orders using DENSE_RANK (rank 1 = highest spender). Return customer_name, total_spent (rounded to 2 decimal places), spending_rank, sorted by spending_rank ascending.",
        "sql": (
            "SELECT customer_name, total_spent, spending_rank "
            "FROM ( "
            "    SELECT c.name AS customer_name, "
            "           ROUND(SUM(o.total_amount), 2) AS total_spent, "
            "           DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
            "    FROM customers c "
            "    JOIN orders o ON o.customer_id = c.id "
            "    WHERE o.status = 'delivered' "
            "    GROUP BY c.id, c.name "
            ") sub "
            "ORDER BY spending_rank ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Monthly delivered revenue with running total using window SUM",
        "base_nl": "Show the monthly revenue from delivered orders and its running cumulative total. Return month (YYYY-MM), monthly_revenue, running_total (both rounded to 2 decimal places), sorted by month ascending.",
        "sql": (
            "WITH monthly AS ( "
            "    SELECT strftime('%Y-%m', created_at) AS month, "
            "           ROUND(SUM(total_amount), 2) AS monthly_revenue "
            "    FROM orders "
            "    WHERE status = 'delivered' "
            "    GROUP BY strftime('%Y-%m', created_at) "
            ") "
            "SELECT month, "
            "       monthly_revenue, "
            "       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
            "FROM monthly "
            "ORDER BY month ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Customers whose most recent order was cancelled, using ROW_NUMBER CTE",
        "base_nl": "Find all customers whose most recent order has status 'cancelled'. Use ROW_NUMBER to identify the latest order per customer. Return customer_name, last_order_status, last_order_date, sorted by customer_name ascending.",
        "sql": (
            "WITH ranked_orders AS ( "
            "    SELECT customer_id, status, created_at, "
            "           ROW_NUMBER() OVER (PARTITION BY customer_id "
            "                              ORDER BY created_at DESC) AS rn "
            "    FROM orders "
            ") "
            "SELECT c.name AS customer_name, "
            "       ro.status AS last_order_status, "
            "       ro.created_at AS last_order_date "
            "FROM customers c "
            "JOIN ranked_orders ro ON ro.customer_id = c.id "
            "WHERE ro.rn = 1 "
            "  AND ro.status = 'cancelled' "
            "ORDER BY customer_name ASC"
        ),
    },
    {
        "domain": "ecommerce", "difficulty": "hard", "has_order": True,
        "description": "Products above their category average rating, using two CTEs",
        "base_nl": "Find products whose average rating is strictly above the average rating of all products in their category. Return product_name, category_name, product_avg_rating, category_avg_rating (both rounded to 2 decimal places), sorted by product_avg_rating descending then product_name ascending.",
        "sql": (
            "WITH product_ratings AS ( "
            "    SELECT p.id AS product_id, p.name AS product_name, "
            "           p.category_id, c.name AS category_name, "
            "           ROUND(AVG(r.rating), 2) AS product_avg_rating "
            "    FROM products p "
            "    JOIN reviews r ON r.product_id = p.id "
            "    JOIN categories c ON c.id = p.category_id "
            "    GROUP BY p.id, p.name, p.category_id, c.name "
            "), "
            "category_ratings AS ( "
            "    SELECT category_id, "
            "           ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
            "    FROM product_ratings "
            "    GROUP BY category_id "
            ") "
            "SELECT pr.product_name, pr.category_name, "
            "       pr.product_avg_rating, cr.category_avg_rating "
            "FROM product_ratings pr "
            "JOIN category_ratings cr ON cr.category_id = pr.category_id "
            "WHERE pr.product_avg_rating > cr.category_avg_rating "
            "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
        ),
    },
]
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 339 |
+
# DOMAIN: HEALTHCARE
|
| 340 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 341 |
+
|
| 342 |
+
# Healthcare templates: 8 easy, 5 medium, 3 hard. Same structure and string
# conventions as ECOMMERCE_TEMPLATES (implicit literal concatenation keeps
# each stored SQL on a single line).
HEALTHCARE_TEMPLATES: list[Template] = [

    # ── EASY ────────────────────────────────────────────────────────────────

    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Doctors sorted by consultation fee descending",
        "base_nl": "List all doctors sorted by consultation fee from highest to lowest. Return id, name, specialization, consultation_fee.",
        "sql": (
            "SELECT id, name, specialization, consultation_fee "
            "FROM doctors "
            "ORDER BY consultation_fee DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Doctors with more than 10 years experience sorted desc",
        "base_nl": "Show doctors with more than 10 years of experience, sorted by experience descending. Return id, name, specialization, experience_years.",
        "sql": (
            "SELECT id, name, specialization, experience_years "
            "FROM doctors "
            "WHERE experience_years > 10 "
            "ORDER BY experience_years DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Patients from India sorted by name",
        "base_nl": "List all patients from India sorted alphabetically by name. Return id, name, country, blood_type.",
        "sql": (
            "SELECT id, name, country, blood_type "
            "FROM patients "
            "WHERE country = 'India' "
            "ORDER BY name ASC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Medications with unit price under $0.20 sorted ascending",
        "base_nl": "Which medications cost less than $0.20 per unit? Sort by price ascending. Return id, name, category, unit_price.",
        "sql": (
            "SELECT id, name, category, unit_price "
            "FROM medications "
            "WHERE unit_price < 0.20 "
            "ORDER BY unit_price ASC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Top 5 most expensive medications",
        "base_nl": "What are the top 5 most expensive medications? Return id, name, unit_price.",
        "sql": (
            "SELECT id, name, unit_price "
            "FROM medications "
            "ORDER BY unit_price DESC "
            "LIMIT 5"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": False,
        "description": "Count of completed appointments",
        "base_nl": "How many appointments have been completed? Return a single value total_completed.",
        "sql": (
            "SELECT COUNT(*) AS total_completed "
            "FROM appointments "
            "WHERE status = 'completed'"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": True,
        "description": "Severe diagnoses sorted by ICD code",
        "base_nl": "List all severe diagnoses sorted by ICD code. Return id, icd_code, description, severity.",
        "sql": (
            "SELECT id, icd_code, description, severity "
            "FROM diagnoses "
            "WHERE severity = 'severe' "
            "ORDER BY icd_code ASC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "easy", "has_order": False,
        "description": "Count patients by gender",
        "base_nl": "How many patients are there by gender? Return gender, patient_count.",
        "sql": (
            "SELECT gender, COUNT(*) AS patient_count "
            "FROM patients "
            "GROUP BY gender"
        ),
    },

    # ── MEDIUM ───────────────────────────────────────────────────────────────

    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Appointment count per doctor including those with no appointments",
        "base_nl": "How many appointments has each doctor had (including those with none)? Return doctor_name, appointment_count, sorted by appointment_count descending.",
        "sql": (
            "SELECT d.name AS doctor_name, COUNT(a.id) AS appointment_count "
            "FROM doctors d "
            "LEFT JOIN appointments a ON a.doctor_id = d.id "
            "GROUP BY d.id, d.name "
            "ORDER BY appointment_count DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Most prescribed medications by count",
        "base_nl": "Which medications are prescribed most often? Return medication_name, category, times_prescribed, sorted by times_prescribed descending.",
        "sql": (
            "SELECT m.name AS medication_name, m.category, COUNT(p.id) AS times_prescribed "
            "FROM medications m "
            "JOIN prescriptions p ON p.medication_id = m.id "
            "GROUP BY m.id, m.name, m.category "
            "ORDER BY times_prescribed DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Patients with more than one completed visit",
        "base_nl": "Which patients have had more than one completed appointment? Return patient_name, visit_count, sorted by visit_count descending.",
        "sql": (
            "SELECT pat.name AS patient_name, COUNT(DISTINCT a.id) AS visit_count "
            "FROM patients pat "
            "JOIN appointments a ON a.patient_id = pat.id "
            "WHERE a.status = 'completed' "
            "GROUP BY pat.id, pat.name "
            "HAVING COUNT(DISTINCT a.id) > 1 "
            "ORDER BY visit_count DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Estimated revenue per doctor from completed appointments",
        "base_nl": "What is the estimated total revenue per doctor from completed appointments (based on consultation fee)? Return doctor_name, specialization, estimated_revenue (rounded to 2 decimal places), sorted by estimated_revenue descending.",
        "sql": (
            "SELECT d.name AS doctor_name, d.specialization, "
            "       ROUND(SUM(d.consultation_fee), 2) AS estimated_revenue "
            "FROM doctors d "
            "JOIN appointments a ON a.doctor_id = d.id "
            "WHERE a.status = 'completed' "
            "GROUP BY d.id, d.name, d.specialization "
            "ORDER BY estimated_revenue DESC"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "medium", "has_order": True,
        "description": "Diagnosis count per severity level",
        "base_nl": "How many diagnoses are there per severity level? Return severity, diagnosis_count, sorted by diagnosis_count descending.",
        "sql": (
            "SELECT severity, COUNT(*) AS diagnosis_count "
            "FROM diagnoses "
            "GROUP BY severity "
            "ORDER BY diagnosis_count DESC"
        ),
    },

    # ── HARD ─────────────────────────────────────────────────────────────────

    {
        "domain": "healthcare", "difficulty": "hard", "has_order": True,
        "description": "Doctors ranked by appointment count within specialization using RANK",
        "base_nl": "Rank doctors by appointment count within their specialization (rank 1 = most appointments). Return doctor_name, specialization, appointment_count, rank_in_spec, sorted by specialization then rank_in_spec ascending.",
        "sql": (
            "SELECT doctor_name, specialization, appointment_count, "
            "       RANK() OVER (PARTITION BY specialization ORDER BY appointment_count DESC) AS rank_in_spec "
            "FROM ( "
            "    SELECT d.name AS doctor_name, d.specialization, COUNT(a.id) AS appointment_count "
            "    FROM doctors d "
            "    JOIN appointments a ON a.doctor_id = d.id "
            "    GROUP BY d.id, d.name, d.specialization "
            ") sub "
            "ORDER BY specialization, rank_in_spec"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "hard", "has_order": True,
        "description": "Top 10 patients by total completed visits using CTE",
        "base_nl": "Find the top 10 patients by number of completed appointments. Return patient_name, total_visits, last_visit, sorted by total_visits descending.",
        "sql": (
            "WITH patient_visits AS ( "
            "    SELECT a.patient_id, COUNT(a.id) AS total_visits, "
            "           MAX(a.scheduled_at) AS last_visit "
            "    FROM appointments a "
            "    WHERE a.status = 'completed' "
            "    GROUP BY a.patient_id "
            ") "
            "SELECT p.name AS patient_name, pv.total_visits, pv.last_visit "
            "FROM patients p "
            "JOIN patient_visits pv ON pv.patient_id = p.id "
            "ORDER BY pv.total_visits DESC "
            "LIMIT 10"
        ),
    },
    {
        "domain": "healthcare", "difficulty": "hard", "has_order": True,
        "description": "Medications total prescription cost per category using window SUM",
        "base_nl": "For each medication, show its total prescription cost (unit_price × quantity) and the running total of cost within its category. Return medication_name, category, total_cost, category_running_cost (both rounded to 2 decimal places), sorted by category then total_cost descending.",
        "sql": (
            "WITH med_costs AS ( "
            "    SELECT m.name AS medication_name, m.category, "
            "           ROUND(SUM(m.unit_price * pr.quantity), 2) AS total_cost "
            "    FROM medications m "
            "    JOIN prescriptions pr ON pr.medication_id = m.id "
            "    GROUP BY m.id, m.name, m.category "
            ") "
            "SELECT medication_name, category, total_cost, "
            "       ROUND(SUM(total_cost) OVER (PARTITION BY category ORDER BY total_cost DESC), 2) "
            "       AS category_running_cost "
            "FROM med_costs "
            "ORDER BY category, total_cost DESC"
        ),
    },
]
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 558 |
+
# DOMAIN: FINANCE
|
| 559 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 560 |
+
|
| 561 |
+
FINANCE_TEMPLATES: list[Template] = [
|
| 562 |
+
|
| 563 |
+
# ── EASY ────────────────────────────────────────────────────────────────
|
| 564 |
+
|
| 565 |
+
{
|
| 566 |
+
"domain": "finance", "difficulty": "easy", "has_order": True,
|
| 567 |
+
"description": "Verified KYC customers sorted by name",
|
| 568 |
+
"base_nl": "List all customers with verified KYC status, sorted alphabetically. Return id, name, country, kyc_status.",
|
| 569 |
+
"sql": (
|
| 570 |
+
"SELECT id, name, country, kyc_status "
|
| 571 |
+
"FROM fin_customers "
|
| 572 |
+
"WHERE kyc_status = 'verified' "
|
| 573 |
+
"ORDER BY name ASC"
|
| 574 |
+
),
|
| 575 |
+
},
|
| 576 |
+
{
|
| 577 |
+
"domain": "finance", "difficulty": "easy", "has_order": True,
|
| 578 |
+
"description": "Accounts with balance over $10,000 sorted by balance descending",
|
| 579 |
+
"base_nl": "Which accounts have a balance greater than $10,000? Return id, customer_id, account_type, balance, sorted by balance descending.",
|
| 580 |
+
"sql": (
|
| 581 |
+
"SELECT id, customer_id, account_type, balance "
|
| 582 |
+
"FROM accounts "
|
| 583 |
+
"WHERE balance > 10000 "
|
| 584 |
+
"ORDER BY balance DESC"
|
| 585 |
+
),
|
| 586 |
+
},
|
| 587 |
+
{
|
| 588 |
+
"domain": "finance", "difficulty": "easy", "has_order": True,
|
| 589 |
+
"description": "Large credit transactions above $1,000 sorted by amount descending",
|
| 590 |
+
"base_nl": "Show all credit transactions with an amount greater than $1,000. Return id, account_id, txn_type, amount, created_at, sorted by amount descending.",
|
| 591 |
+
"sql": (
|
| 592 |
+
"SELECT id, account_id, txn_type, amount, created_at "
|
| 593 |
+
"FROM transactions "
|
| 594 |
+
"WHERE txn_type = 'credit' AND amount > 1000 "
|
| 595 |
+
"ORDER BY amount DESC"
|
| 596 |
+
),
|
| 597 |
+
},
|
| 598 |
+
{
|
| 599 |
+
"domain": "finance", "difficulty": "easy", "has_order": True,
|
| 600 |
+
"description": "Defaulted loans sorted by principal amount descending",
|
| 601 |
+
"base_nl": "List all defaulted loans, sorted by principal amount descending. Return id, loan_type, principal_amount, interest_rate, status.",
|
| 602 |
+
"sql": (
|
| 603 |
+
"SELECT id, loan_type, principal_amount, interest_rate, status "
|
| 604 |
+
"FROM loans "
|
| 605 |
+
"WHERE status = 'defaulted' "
|
| 606 |
+
"ORDER BY principal_amount DESC"
|
| 607 |
+
),
|
| 608 |
+
},
|
| 609 |
+
{
|
| 610 |
+
"domain": "finance", "difficulty": "easy", "has_order": False,
|
| 611 |
+
"description": "Count of late loan payments",
|
| 612 |
+
"base_nl": "How many loan payments were made late? Return a single value late_payments.",
|
| 613 |
+
"sql": "SELECT COUNT(*) AS late_payments FROM loan_payments WHERE is_late = 1",
|
| 614 |
+
},
|
| 615 |
+
{
|
| 616 |
+
"domain": "finance", "difficulty": "easy", "has_order": True,
|
| 617 |
+
"description": "Top 5 highest principal loans",
|
| 618 |
+
"base_nl": "What are the top 5 loans by principal amount? Return id, customer_id, loan_type, principal_amount.",
|
| 619 |
+
"sql": (
|
| 620 |
+
"SELECT id, customer_id, loan_type, principal_amount "
|
| 621 |
+
"FROM loans "
|
| 622 |
+
"ORDER BY principal_amount DESC "
|
| 623 |
+
"LIMIT 5"
|
| 624 |
+
),
|
| 625 |
+
},
|
| 626 |
+
{
|
| 627 |
+
"domain": "finance", "difficulty": "easy", "has_order": False,
|
| 628 |
+
"description": "Count of accounts by account type",
|
| 629 |
+
"base_nl": "How many accounts exist for each account type? Return account_type, account_count.",
|
| 630 |
+
"sql": (
|
| 631 |
+
"SELECT account_type, COUNT(*) AS account_count "
|
| 632 |
+
"FROM accounts "
|
| 633 |
+
"GROUP BY account_type"
|
| 634 |
+
),
|
| 635 |
+
},
|
| 636 |
+
|
| 637 |
+
# ── MEDIUM ───────────────────────────────────────────────────────────────
|
| 638 |
+
|
| 639 |
+
{
|
| 640 |
+
"domain": "finance", "difficulty": "medium", "has_order": True,
|
| 641 |
+
"description": "Total active account balance per customer sorted by balance descending",
|
| 642 |
+
"base_nl": "What is the total active account balance per customer? Return customer_name, account_count, total_balance (rounded to 2 decimal places), sorted by total_balance descending.",
|
| 643 |
+
"sql": (
|
| 644 |
+
"SELECT fc.name AS customer_name, COUNT(a.id) AS account_count, "
|
| 645 |
+
" ROUND(SUM(a.balance), 2) AS total_balance "
|
| 646 |
+
"FROM fin_customers fc "
|
| 647 |
+
"JOIN accounts a ON a.customer_id = fc.id "
|
| 648 |
+
"WHERE a.status = 'active' "
|
| 649 |
+
"GROUP BY fc.id, fc.name "
|
| 650 |
+
"ORDER BY total_balance DESC"
|
| 651 |
+
),
|
| 652 |
+
},
|
| 653 |
+
{
|
| 654 |
+
"domain": "finance", "difficulty": "medium", "has_order": True,
|
| 655 |
+
"description": "Total credit transaction amount by account type",
|
| 656 |
+
"base_nl": "What is the total credit amount per account type? Return account_type, total_credits (rounded to 2 decimal places), sorted by total_credits descending.",
|
| 657 |
+
"sql": (
|
| 658 |
+
"SELECT a.account_type, ROUND(SUM(t.amount), 2) AS total_credits "
|
| 659 |
+
"FROM accounts a "
|
| 660 |
+
"JOIN transactions t ON t.account_id = a.id "
|
| 661 |
+
"WHERE t.txn_type = 'credit' "
|
| 662 |
+
"GROUP BY a.account_type "
|
| 663 |
+
"ORDER BY total_credits DESC"
|
| 664 |
+
),
|
| 665 |
+
},
|
| 666 |
+
{
|
| 667 |
+
"domain": "finance", "difficulty": "medium", "has_order": True,
|
| 668 |
+
"description": "Total loan borrowing per customer sorted descending",
|
| 669 |
+
"base_nl": "How much has each customer borrowed in total across all loans? Return customer_name, loan_count, total_borrowed (rounded to 2 decimal places), sorted by total_borrowed descending.",
|
| 670 |
+
"sql": (
|
| 671 |
+
"SELECT fc.name AS customer_name, COUNT(l.id) AS loan_count, "
|
| 672 |
+
" ROUND(SUM(l.principal_amount), 2) AS total_borrowed "
|
| 673 |
+
"FROM fin_customers fc "
|
| 674 |
+
"JOIN loans l ON l.customer_id = fc.id "
|
| 675 |
+
"GROUP BY fc.id, fc.name "
|
| 676 |
+
"ORDER BY total_borrowed DESC"
|
| 677 |
+
),
|
| 678 |
+
},
|
| 679 |
+
{
|
| 680 |
+
"domain": "finance", "difficulty": "medium", "has_order": True,
|
| 681 |
+
"description": "Late payment count and total amount by loan type",
|
| 682 |
+
"base_nl": "For each loan type, how many late payments were there and what was the total amount paid late? Return loan_type, late_payments, total_late_paid (rounded to 2 decimal places), sorted by late_payments descending.",
|
| 683 |
+
"sql": (
|
| 684 |
+
"SELECT l.loan_type, COUNT(lp.id) AS late_payments, "
|
| 685 |
+
" ROUND(SUM(lp.amount_paid), 2) AS total_late_paid "
|
| 686 |
+
"FROM loans l "
|
| 687 |
+
"JOIN loan_payments lp ON lp.loan_id = l.id "
|
| 688 |
+
"WHERE lp.is_late = 1 "
|
| 689 |
+
"GROUP BY l.loan_type "
|
| 690 |
+
"ORDER BY late_payments DESC"
|
| 691 |
+
),
|
| 692 |
+
},
|
| 693 |
+
|
| 694 |
+
# ── HARD ─────────────────────────────────────────────────────────────────
|
| 695 |
+
|
| 696 |
+
{
|
| 697 |
+
"domain": "finance", "difficulty": "hard", "has_order": True,
|
| 698 |
+
"description": "Customer balance rank using DENSE_RANK on active accounts",
|
| 699 |
+
"base_nl": "Rank customers by their total active account balance using DENSE_RANK. Return customer_name, total_balance, balance_rank, sorted by balance_rank ascending.",
|
| 700 |
+
"sql": (
|
| 701 |
+
"SELECT customer_name, total_balance, "
|
| 702 |
+
" DENSE_RANK() OVER (ORDER BY total_balance DESC) AS balance_rank "
|
| 703 |
+
"FROM ( "
|
| 704 |
+
" SELECT fc.name AS customer_name, "
|
| 705 |
+
" ROUND(SUM(a.balance), 2) AS total_balance "
|
| 706 |
+
" FROM fin_customers fc "
|
| 707 |
+
" JOIN accounts a ON a.customer_id = fc.id "
|
| 708 |
+
" WHERE a.status = 'active' "
|
| 709 |
+
" GROUP BY fc.id, fc.name "
|
| 710 |
+
") sub "
|
| 711 |
+
"ORDER BY balance_rank"
|
| 712 |
+
),
|
| 713 |
+
},
|
| 714 |
+
{
|
| 715 |
+
"domain": "finance", "difficulty": "hard", "has_order": True,
|
| 716 |
+
"description": "Monthly transaction totals by type with running total using window SUM",
|
| 717 |
+
"base_nl": "Show monthly transaction totals per type (credit/debit) with a running cumulative total. Return month (YYYY-MM), txn_type, total, running_total (rounded to 2 decimal places), sorted by month then txn_type.",
|
| 718 |
+
"sql": (
|
| 719 |
+
"WITH monthly_txn AS ( "
|
| 720 |
+
" SELECT strftime('%Y-%m', created_at) AS month, "
|
| 721 |
+
" txn_type, "
|
| 722 |
+
" ROUND(SUM(amount), 2) AS total "
|
| 723 |
+
" FROM transactions "
|
| 724 |
+
" GROUP BY strftime('%Y-%m', created_at), txn_type "
|
| 725 |
+
") "
|
| 726 |
+
"SELECT month, txn_type, total, "
|
| 727 |
+
" ROUND(SUM(total) OVER (PARTITION BY txn_type ORDER BY month), 2) AS running_total "
|
| 728 |
+
"FROM monthly_txn "
|
| 729 |
+
"ORDER BY month, txn_type"
|
| 730 |
+
),
|
| 731 |
+
},
|
| 732 |
+
{
|
| 733 |
+
"domain": "finance", "difficulty": "hard", "has_order": True,
|
| 734 |
+
"description": "Customers with only defaulted loans using NOT EXISTS",
|
| 735 |
+
"base_nl": "Find customers who have at least one loan and ALL their loans are defaulted. Return customer_name, loan_count, sorted by customer_name ascending.",
|
| 736 |
+
"sql": (
|
| 737 |
+
"SELECT fc.name AS customer_name, COUNT(l.id) AS loan_count "
|
| 738 |
+
"FROM fin_customers fc "
|
| 739 |
+
"JOIN loans l ON l.customer_id = fc.id "
|
| 740 |
+
"GROUP BY fc.id, fc.name "
|
| 741 |
+
"HAVING COUNT(l.id) > 0 "
|
| 742 |
+
" AND SUM(CASE WHEN l.status != 'defaulted' THEN 1 ELSE 0 END) = 0 "
|
| 743 |
+
"ORDER BY customer_name ASC"
|
| 744 |
+
),
|
| 745 |
+
},
|
| 746 |
+
]
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 750 |
+
# DOMAIN: HR
|
| 751 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 752 |
+
|
| 753 |
+
# HR-domain templates; same Template shape as the other domain lists
# (domain, difficulty, has_order, description, base_nl, verified sql).
HR_TEMPLATES: list[Template] = [

    # ── EASY ────────────────────────────────────────────────────────────────

    {
        "domain": "hr", "difficulty": "easy", "has_order": True,
        "description": "Active employees sorted by salary descending",
        "base_nl": "List all active employees sorted by salary from highest to lowest. Return id, name, job_title, salary.",
        "sql": (
            "SELECT id, name, job_title, salary "
            "FROM employees "
            "WHERE status = 'active' "
            "ORDER BY salary DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "easy", "has_order": True,
        "description": "Departments sorted by budget descending",
        "base_nl": "Show all departments sorted by budget from largest to smallest. Return id, name, location, budget.",
        "sql": (
            "SELECT id, name, location, budget "
            "FROM departments "
            "ORDER BY budget DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "easy", "has_order": True,
        "description": "Employees hired in 2023 or later sorted by hire date descending",
        "base_nl": "Which employees were hired on or after January 1st 2023? Sort by hire date descending. Return id, name, job_title, hire_date.",
        "sql": (
            "SELECT id, name, job_title, hire_date "
            "FROM employees "
            "WHERE hire_date >= '2023-01-01' "
            "ORDER BY hire_date DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "easy", "has_order": True,
        "description": "Active projects sorted by budget descending",
        "base_nl": "Show all currently active projects sorted by budget descending. Return id, name, status, budget.",
        "sql": (
            "SELECT id, name, status, budget "
            "FROM projects "
            "WHERE status = 'active' "
            "ORDER BY budget DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "easy", "has_order": True,
        "description": "Active employees earning above $100,000 sorted by salary descending",
        "base_nl": "Which active employees earn more than $100,000? Return id, name, email, job_title, sorted by salary descending.",
        "sql": (
            "SELECT id, name, email, job_title "
            "FROM employees "
            "WHERE status = 'active' AND salary > 100000 "
            "ORDER BY salary DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "easy", "has_order": False,
        "description": "Count of active employees",
        "base_nl": "How many active employees do we currently have? Return active_employees.",
        "sql": "SELECT COUNT(*) AS active_employees FROM employees WHERE status = 'active'",
    },
    {
        "domain": "hr", "difficulty": "easy", "has_order": True,
        "description": "Projects with no end date (ongoing) sorted by budget descending",
        "base_nl": "List all ongoing projects that have no end date set. Return id, name, start_date, budget, sorted by budget descending.",
        "sql": (
            "SELECT id, name, start_date, budget "
            "FROM projects "
            "WHERE end_date IS NULL "
            "ORDER BY budget DESC"
        ),
    },

    # ── MEDIUM ───────────────────────────────────────────────────────────────

    {
        "domain": "hr", "difficulty": "medium", "has_order": True,
        "description": "Headcount and average salary per department for active employees",
        "base_nl": "For each department, what is the headcount and average salary of active employees? Return department_name, headcount, avg_salary (rounded to 2 decimal places), sorted by headcount descending.",
        "sql": (
            "SELECT d.name AS department_name, COUNT(e.id) AS headcount, "
            "       ROUND(AVG(e.salary), 2) AS avg_salary "
            "FROM departments d "
            "LEFT JOIN employees e ON e.department_id = d.id AND e.status = 'active' "
            "GROUP BY d.id, d.name "
            "ORDER BY headcount DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "medium", "has_order": True,
        "description": "Average performance rating per employee sorted descending",
        "base_nl": "What is the average performance review rating per active employee? Return employee_name, job_title, avg_rating (rounded to 2 decimal places), sorted by avg_rating descending.",
        "sql": (
            "SELECT e.name AS employee_name, e.job_title, "
            "       ROUND(AVG(pr.rating), 2) AS avg_rating "
            "FROM employees e "
            "JOIN performance_reviews pr ON pr.employee_id = e.id "
            "WHERE e.status = 'active' "
            "GROUP BY e.id, e.name, e.job_title "
            "ORDER BY avg_rating DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "medium", "has_order": True,
        "description": "Employees with the most total allocated project hours",
        "base_nl": "Which employees have the most total hours allocated across projects? Return employee_name, total_hours, sorted by total_hours descending, top 10.",
        "sql": (
            "SELECT e.name AS employee_name, SUM(pa.hours_allocated) AS total_hours "
            "FROM employees e "
            "JOIN project_assignments pa ON pa.employee_id = e.id "
            "GROUP BY e.id, e.name "
            "ORDER BY total_hours DESC "
            "LIMIT 10"
        ),
    },
    {
        "domain": "hr", "difficulty": "medium", "has_order": True,
        "description": "Departments with distinct employees assigned to active projects",
        "base_nl": "For each department, how many distinct employees are assigned to active projects? Return department_name, assigned_employees, sorted by assigned_employees descending.",
        "sql": (
            "SELECT d.name AS department_name, "
            "       COUNT(DISTINCT pa.employee_id) AS assigned_employees "
            "FROM departments d "
            "JOIN projects p ON p.department_id = d.id "
            "JOIN project_assignments pa ON pa.project_id = p.id "
            "WHERE p.status = 'active' "
            "GROUP BY d.id, d.name "
            "ORDER BY assigned_employees DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "medium", "has_order": True,
        "description": "Total project budget per department sorted descending",
        "base_nl": "What is the total project budget per department? Return department_name, total_project_budget (rounded to 2 decimal places), sorted by total_project_budget descending.",
        "sql": (
            "SELECT d.name AS department_name, "
            "       ROUND(SUM(p.budget), 2) AS total_project_budget "
            "FROM departments d "
            "JOIN projects p ON p.department_id = d.id "
            "GROUP BY d.id, d.name "
            "ORDER BY total_project_budget DESC"
        ),
    },

    # ── HARD ─────────────────────────────────────────────────────────────────

    {
        "domain": "hr", "difficulty": "hard", "has_order": True,
        "description": "Salary rank within department using DENSE_RANK",
        "base_nl": "Rank active employees by salary within their department using DENSE_RANK (rank 1 = highest paid). Return employee_name, salary, department_name, salary_rank, sorted by department_name then salary_rank ascending.",
        "sql": (
            "SELECT employee_name, salary, department_name, "
            "       DENSE_RANK() OVER (PARTITION BY department_name ORDER BY salary DESC) AS salary_rank "
            "FROM ( "
            "    SELECT e.name AS employee_name, e.salary, d.name AS department_name "
            "    FROM employees e "
            "    JOIN departments d ON d.id = e.department_id "
            "    WHERE e.status = 'active' "
            ") sub "
            "ORDER BY department_name, salary_rank"
        ),
    },
    {
        "domain": "hr", "difficulty": "hard", "has_order": True,
        "description": "Employee performance band classification using CASE with avg rating CTE",
        "base_nl": "Classify active employees into performance bands (High Performer: avg rating >= 4, Average: >= 3, Needs Improvement: < 3) based on their average review rating. Return employee_name, salary, avg_rating, performance_band, sorted by avg_rating descending.",
        "sql": (
            "WITH avg_ratings AS ( "
            "    SELECT employee_id, ROUND(AVG(rating), 2) AS avg_rating "
            "    FROM performance_reviews "
            "    GROUP BY employee_id "
            ") "
            "SELECT e.name AS employee_name, e.salary, ar.avg_rating, "
            "       CASE WHEN ar.avg_rating >= 4 THEN 'High Performer' "
            "            WHEN ar.avg_rating >= 3 THEN 'Average' "
            "            ELSE 'Needs Improvement' "
            "       END AS performance_band "
            "FROM employees e "
            "JOIN avg_ratings ar ON ar.employee_id = e.id "
            "WHERE e.status = 'active' "
            "ORDER BY ar.avg_rating DESC"
        ),
    },
    {
        "domain": "hr", "difficulty": "hard", "has_order": True,
        "description": "Employees above their department average salary using CTE",
        "base_nl": "Find active employees whose salary is above their department's average. Return employee_name, department_name, salary, dept_avg_salary (rounded to 2 decimal places), sorted by salary descending.",
        "sql": (
            "WITH dept_avg AS ( "
            "    SELECT department_id, ROUND(AVG(salary), 2) AS dept_avg_salary "
            "    FROM employees "
            "    WHERE status = 'active' "
            "    GROUP BY department_id "
            ") "
            "SELECT e.name AS employee_name, d.name AS department_name, "
            "       e.salary, da.dept_avg_salary "
            "FROM employees e "
            "JOIN departments d ON d.id = e.department_id "
            "JOIN dept_avg da ON da.department_id = e.department_id "
            "WHERE e.status = 'active' AND e.salary > da.dept_avg_salary "
            "ORDER BY e.salary DESC"
        ),
    },
]
|
| 960 |
+
|
| 961 |
+
|
| 962 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 963 |
+
# MASTER TEMPLATE REGISTRY
|
| 964 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 965 |
+
|
| 966 |
+
# Master registry: one flat list plus two index views over the same objects.
ALL_TEMPLATES: list[Template] = [
    *ECOMMERCE_TEMPLATES,
    *HEALTHCARE_TEMPLATES,
    *FINANCE_TEMPLATES,
    *HR_TEMPLATES,
]

# Lookup of the per-domain lists by their domain key.
TEMPLATES_BY_DOMAIN: dict[str, list[Template]] = {
    "ecommerce": ECOMMERCE_TEMPLATES,
    "healthcare": HEALTHCARE_TEMPLATES,
    "finance": FINANCE_TEMPLATES,
    "hr": HR_TEMPLATES,
}

# Bucket every template by its declared difficulty level.
TEMPLATES_BY_DIFFICULTY: dict[str, list[Template]] = {
    level: [tpl for tpl in ALL_TEMPLATES if tpl["difficulty"] == level]
    for level in ("easy", "medium", "hard")
}
|
| 985 |
+
|
| 986 |
+
|
| 987 |
+
def template_stats() -> dict:
    """Summarize template counts: total, per domain, and per difficulty.

    Counts are derived directly from the registry dicts rather than from
    hard-coded key lists, so adding a new domain or difficulty level is
    automatically reflected here without editing this function.

    Returns:
        dict with keys "total" (int), "by_domain" (dict[str, int]) and
        "by_difficulty" (dict[str, int]).
    """
    return {
        "total": len(ALL_TEMPLATES),
        "by_domain": {d: len(ts) for d, ts in TEMPLATES_BY_DOMAIN.items()},
        "by_difficulty": {d: len(ts) for d, ts in TEMPLATES_BY_DIFFICULTY.items()},
    }
|
data_factory/validator.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
data_factory/validator.py
|
| 3 |
+
==========================
|
| 4 |
+
SQL execution validation layer.
|
| 5 |
+
|
| 6 |
+
GUARANTEE: Every record that passes this validator has a SQL that:
|
| 7 |
+
1. Runs without error against the actual seeded SQLite schema
|
| 8 |
+
2. Returns at least one row (non-empty result)
|
| 9 |
+
3. Returns the expected column names
|
| 10 |
+
|
| 11 |
+
No LLM-generated SQL ever reaches this validator — SQL always comes from
|
| 12 |
+
the human-verified template library. This validator is an extra safety net
|
| 13 |
+
to catch any copy-paste or formatting regressions.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import sqlite3
|
| 19 |
+
from dataclasses import dataclass, field
|
| 20 |
+
from typing import Any, Optional
|
| 21 |
+
|
| 22 |
+
from data_factory.schemas import build_connection, SCHEMA_CONTEXT
|
| 23 |
+
from data_factory.templates import Template
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 27 |
+
# DATA CLASSES
|
| 28 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 29 |
+
|
| 30 |
+
@dataclass
class ValidationResult:
    """Outcome of executing one SQL statement against the seeded database."""
    passed: bool                  # True iff the statement executed without a sqlite3.Error
    sql: str                      # normalized SQL (stripped, trailing ';' removed)
    error: Optional[str] = None   # error text when passed is False, else None
    row_count: int = 0            # number of rows the query returned
    columns: list[str] = field(default_factory=list)  # result column names, in order
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclass
class DataRecord:
    """One training example ready to be written to JSONL/Parquet."""
    domain: str           # source schema domain (ecommerce | healthcare | finance | hr)
    difficulty: str       # easy | medium | hard
    sql: str              # ground-truth SQL from the template library
    nl_question: str      # The NL paraphrase used as prompt
    persona: str          # ceo | chatty | lazy_typist | non_techie | analyst | augmented
    has_order: bool       # whether the reference SQL has a deterministic ORDER BY
    schema_context: str   # textual schema shown to the model in the prompt
    row_count: int        # From validation run
    columns: list[str]    # From validation run
    source: str           # "template_base" | "vllm_persona" | "rule_augmented"
    template_id: int      # Index into ALL_TEMPLATES

    def to_training_dict(self) -> dict[str, Any]:
        """
        Returns the dictionary that will be written to the output dataset.

        Format is compatible with TRL / HuggingFace `datasets`:
            prompt  : chat-format messages list (system + user)
            sql     : ground-truth SQL (label / reward reference)
            metadata: auxiliary fields for curriculum or filtering
        """
        # System message pins the output contract: bare SQL, no markdown fences.
        system_msg = (
            "You are an expert SQL analyst. "
            "Write a single SELECT query that answers the question. "
            "Output ONLY the SQL query — no markdown, no explanation, no backticks."
        )
        # User message carries the schema followed by the NL question.
        user_msg = (
            f"DATABASE SCHEMA\n"
            f"---------------\n"
            f"{self.schema_context}\n\n"
            f"QUESTION: {self.nl_question}"
        )
        return {
            "prompt": [
                {"role": "system", "content": system_msg},
                {"role": "user", "content": user_msg},
            ],
            "sql": self.sql,
            "metadata": {
                "domain": self.domain,
                "difficulty": self.difficulty,
                "persona": self.persona,
                "has_order": self.has_order,
                "row_count": self.row_count,
                "columns": self.columns,
                "source": self.source,
                "template_id": self.template_id,
            },
        }
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 94 |
+
# VALIDATOR
|
| 95 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 96 |
+
|
| 97 |
+
class SQLValidator:
    """
    Executes candidate SQL against a seeded in-memory SQLite connection.

    Build one validator per domain so every template in that domain reuses
    the same seeded connection (avoids re-seeding per query).
    """

    # Leading keywords that mark a write/DDL statement; the validator is
    # strictly read-only, so these are rejected before execution.
    _FORBIDDEN = frozenset({
        "insert", "update", "delete", "drop", "alter",
        "create", "replace", "truncate", "pragma",
    })

    def __init__(self, domain: str, seed: int = 42) -> None:
        self.domain = domain
        self._conn = build_connection(domain, seed=seed)

    def validate(self, sql: str) -> ValidationResult:
        """
        Execute *sql* and return a ValidationResult describing the outcome.
        Never raises — all failures are reported through the result object.
        """
        cleaned = sql.strip().rstrip(";")
        if not cleaned:
            return ValidationResult(passed=False, sql=cleaned, error="Empty SQL string.")

        # Reject write operations by inspecting the leading keyword only.
        tokens = cleaned.split()
        first_word = tokens[0].lower() if tokens else ""
        if first_word in self._FORBIDDEN:
            return ValidationResult(
                passed=False, sql=cleaned,
                error=f"Write operation '{first_word.upper()}' is not permitted."
            )

        try:
            cursor = self._conn.execute(cleaned)
            column_names = [desc[0] for desc in cursor.description] if cursor.description else []
            fetched = cursor.fetchall()
        except sqlite3.Error as exc:
            return ValidationResult(passed=False, sql=cleaned, error=str(exc))

        return ValidationResult(
            passed=True,
            sql=cleaned,
            row_count=len(fetched),
            columns=column_names,
        )

    def close(self) -> None:
        """Release the underlying SQLite connection."""
        self._conn.close()
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def validate_template(template: Template, seed: int = 42) -> ValidationResult:
    """Validate one template's SQL on a throwaway validator for its domain."""
    checker = SQLValidator(template["domain"], seed=seed)
    try:
        return checker.validate(template["sql"])
    finally:
        checker.close()
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def validate_all_templates(templates: list[Template], seed: int = 42) -> dict[str, Any]:
    """
    Run validation across all templates and return a summary dict
    ("total" / "passed" / "failed" / "failures"). Used during CI / smoke testing.
    """
    from data_factory.schemas import SCHEMA_MAP

    # One shared validator per domain so each seeded DB is built only once.
    validators = {domain: SQLValidator(domain, seed) for domain in SCHEMA_MAP}
    ok_indices: list[int] = []
    failures: list[dict] = []

    for idx, tpl in enumerate(templates):
        outcome = validators[tpl["domain"]].validate(tpl["sql"])
        if outcome.passed:
            ok_indices.append(idx)
        else:
            failures.append({
                "index": idx,
                "domain": tpl["domain"],
                "sql": tpl["sql"][:80],       # truncated preview for readability
                "error": outcome.error,
            })

    for checker in validators.values():
        checker.close()

    return {
        "total": len(templates),
        "passed": len(ok_indices),
        "failed": len(failures),
        "failures": failures,
    }
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def build_record(
    template: Template,
    template_idx: int,
    nl_question: str,
    persona: str,
    source: str,
    validator: SQLValidator,
) -> Optional[DataRecord]:
    """
    Validate the template SQL and, if it passes, assemble a DataRecord.

    Parameters
    ----------
    template     : The source template (contains SQL, domain, difficulty).
    template_idx : Index of template in ALL_TEMPLATES (for deduplication).
    nl_question  : The NL paraphrase to use as the prompt.
    persona      : Which persona/strategy generated this NL.
    source       : 'template_base' | 'vllm_persona' | 'rule_augmented'
    validator    : Pre-built SQLValidator for this domain.

    Returns
    -------
    A populated DataRecord, or None when validation fails.
    """
    outcome = validator.validate(template["sql"])
    if not outcome.passed:
        return None

    return DataRecord(
        domain=template["domain"],
        difficulty=template["difficulty"],
        sql=template["sql"],
        nl_question=nl_question,
        persona=persona,
        has_order=template["has_order"],
        schema_context=SCHEMA_CONTEXT[template["domain"]],
        row_count=outcome.row_count,
        columns=outcome.columns,
        source=source,
        template_id=template_idx,
    )
|
env_server
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Minimal FastAPI wrapper exposing the NL2SQL environment over HTTP.

Endpoints:
    POST /reset — re-seed the environment for a named task, return the first observation
    POST /step  — execute one SQL action, return the resulting observation
"""
import os  # NOTE(review): appears unused in this script — confirm before removing
import sys
from fastapi import FastAPI, Request
import uvicorn

# The environment implementation lives under ./server; make it importable.
sys.path.insert(0, "./server")
from environment import NL2SQLEnvironment
from models import NL2SQLAction

app = FastAPI()
# Single shared environment instance for the whole process; /reset re-seeds it.
env = NL2SQLEnvironment()

@app.post("/reset")
async def reset(request: Request):
    """Reset the environment for the requested task and return its first observation."""
    data = await request.json()
    # Now we take task_name directly from the API call
    task_name = data.get("task_name", "simple-filter")
    print(f"🔄 Environment Resetting for Task: {task_name}")
    obs = env.reset(task_name=task_name)
    # Observation is a plain object; expose its attributes as a JSON dict.
    return {"observation": obs.__dict__}

@app.post("/step")
async def step(request: Request):
    """Run one SQL query through the environment and return the observation."""
    data = await request.json()
    query = data.get("query", "")
    print(f"⏩ Executing SQL: {query[:60]}...")

    action = NL2SQLAction(query=query)
    obs = env.step(action)
    return {"observation": obs.__dict__}

if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
|
folder.txt
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.
|
| 2 |
+
├── check_quality.py
|
| 3 |
+
├── clean_dataset.py
|
| 4 |
+
├── client.py
|
| 5 |
+
├── custom_train.py
|
| 6 |
+
├── data_expander.py
|
| 7 |
+
├── data_factory
|
| 8 |
+
│ ├── augmentor.py
|
| 9 |
+
│ ├── config.py
|
| 10 |
+
│ ├── generate_data.py
|
| 11 |
+
│ ├── generator.py
|
| 12 |
+
│ ├── __init__.py
|
| 13 |
+
│ ├── pipeline.py
|
| 14 |
+
│ ├── run_data_factory.py
|
| 15 |
+
│ ├── schemas.py
|
| 16 |
+
│ ├── templates.py
|
| 17 |
+
│ └── validator.py
|
| 18 |
+
├── Dockerfile
|
| 19 |
+
├── edge_cases.jsonl
|
| 20 |
+
├── env_server
|
| 21 |
+
├── folder.txt
|
| 22 |
+
├── generate_data.py
|
| 23 |
+
├── generate_edge_cases.py
|
| 24 |
+
├── inference.py
|
| 25 |
+
├── __init__.py
|
| 26 |
+
├── llm_hybrid_templates.json
|
| 27 |
+
├── local_test.py
|
| 28 |
+
├── merge_model.py
|
| 29 |
+
├── mini_server.py
|
| 30 |
+
├── models.py
|
| 31 |
+
├── nl2sql_50k_elite_dataset_1.jsonl
|
| 32 |
+
├── nl2sql_50k_elite_dataset.jsonl
|
| 33 |
+
├── nl2sql_cleaned_ready_to_train.jsonl
|
| 34 |
+
├── nl2sql_merged_final.jsonl
|
| 35 |
+
├── openenv.yaml
|
| 36 |
+
├── pyproject.toml
|
| 37 |
+
├── qwen-7b-coder-nl2sql-grpo
|
| 38 |
+
│ ├── checkpoint-70
|
| 39 |
+
│ │ ├── adapter_config.json
|
| 40 |
+
│ │ ├── adapter_model.safetensors
|
| 41 |
+
│ │ ├── chat_template.jinja
|
| 42 |
+
│ │ ├── optimizer.pt
|
| 43 |
+
│ │ ├── README.md
|
| 44 |
+
│ │ ├── rng_state_0.pth
|
| 45 |
+
│ │ ├── rng_state_1.pth
|
| 46 |
+
│ │ ├── scheduler.pt
|
| 47 |
+
│ │ ├── tokenizer_config.json
|
| 48 |
+
│ │ ├── tokenizer.json
|
| 49 |
+
│ │ ├── trainer_state.json
|
| 50 |
+
│ │ └── training_args.bin
|
| 51 |
+
│ ├── final
|
| 52 |
+
│ │ ├── adapter_config.json
|
| 53 |
+
│ │ ├── adapter_model.safetensors
|
| 54 |
+
│ │ ├── chat_template.jinja
|
| 55 |
+
│ │ ├── README.md
|
| 56 |
+
│ │ ├── tokenizer_config.json
|
| 57 |
+
│ │ └── tokenizer.json
|
| 58 |
+
│ └── README.md
|
| 59 |
+
├── qwen-7b-coder-nl2sql-grpo-v2
|
| 60 |
+
├── qwen-7b-nl2sql-merged
|
| 61 |
+
│ ├── chat_template.jinja
|
| 62 |
+
│ ├── config.json
|
| 63 |
+
│ ├── generation_config.json
|
| 64 |
+
│ ├── model.safetensors
|
| 65 |
+
│ ├── tokenizer_config.json
|
| 66 |
+
│ └── tokenizer.json
|
| 67 |
+
├── README.md
|
| 68 |
+
├── scripts
|
| 69 |
+
│ ├── run_local.sh
|
| 70 |
+
│ └── smoke_test.sh
|
| 71 |
+
├── server
|
| 72 |
+
│ ├── app.py
|
| 73 |
+
│ ├── db
|
| 74 |
+
│ │ ├── __init__.py
|
| 75 |
+
│ │ ├── schema.sql
|
| 76 |
+
│ │ └── seed.py
|
| 77 |
+
│ ├── environment.py
|
| 78 |
+
│ ├── grader.py
|
| 79 |
+
│ ├── __init__.py
|
| 80 |
+
│ ├── requirements.txt
|
| 81 |
+
│ └── tasks
|
| 82 |
+
│ ├── base.py
|
| 83 |
+
│ ├── easy.py
|
| 84 |
+
│ ├── hard.py
|
| 85 |
+
│ ├── __init__.py
|
| 86 |
+
│ └── medium.py
|
| 87 |
+
├── swapped_templates.json
|
| 88 |
+
├── tests
|
| 89 |
+
│ ├── conftest.py
|
| 90 |
+
│ ├── __init__.py
|
| 91 |
+
│ └── test_all.py
|
| 92 |
+
├── train.py
|
| 93 |
+
└── value_swapper.py
|
| 94 |
+
|
| 95 |
+
11 directories, 81 files
|
generate_data.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import torch
|
| 5 |
+
import hashlib
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
+
|
| 10 |
+
# GPU CONFIG - All 4 H100s engaged
|
| 11 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,7"
|
| 12 |
+
|
| 13 |
+
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 14 |
+
if PROJECT_ROOT not in sys.path:
|
| 15 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 16 |
+
|
| 17 |
+
from data_factory.schemas import SCHEMA_CONTEXT
|
| 18 |
+
from data_factory.validator import SQLValidator
|
| 19 |
+
|
| 20 |
+
# CONFIG
|
| 21 |
+
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
|
| 22 |
+
TARGET_TEMPLATES = 10000
|
| 23 |
+
OUTPUT_FILE = "llm_10k_base_templates.json"
|
| 24 |
+
BATCH_SIZE = 64
|
| 25 |
+
|
| 26 |
+
PROMPT_TEMPLATE = """
|
| 27 |
+
You are a senior expert in SQLite schema design and NL2SQL dataset generation.
|
| 28 |
+
|
| 29 |
+
TASK
|
| 30 |
+
Generate exactly 10 UNIQUE, COMPLEX, and FULLY VALID SQLite SQL SELECT queries for the given schema.
|
| 31 |
+
For each query, also write a natural language question that a real user might ask.
|
| 32 |
+
|
| 33 |
+
HARD RULES
|
| 34 |
+
- Output ONLY a valid JSON array.
|
| 35 |
+
- Do NOT wrap output in markdown, code fences, or explanations.
|
| 36 |
+
- Every item must be a JSON object with exactly these keys:
|
| 37 |
+
- "sql"
|
| 38 |
+
- "base_nl"
|
| 39 |
+
- "difficulty"
|
| 40 |
+
- "has_order"
|
| 41 |
+
- All SQL must be a single SELECT statement.
|
| 42 |
+
- Do NOT use INSERT, UPDATE, DELETE, DROP, CREATE, ALTER, PRAGMA, ATTACH, DETACH, or any DDL/DML.
|
| 43 |
+
- Every table and column used in SQL must exist in the provided schema.
|
| 44 |
+
- Do NOT invent columns, tables, aliases, or constraints.
|
| 45 |
+
- SQL must be valid for SQLite.
|
| 46 |
+
- Prefer queries that are meaningfully different from each other.
|
| 47 |
+
- Avoid repetitive templates.
|
| 48 |
+
- Each SQL should test a different reasoning pattern.
|
| 49 |
+
- Each base_nl should sound natural and distinct from the others.
|
| 50 |
+
- Use advanced SQL patterns where appropriate:
|
| 51 |
+
- multiple JOINs
|
| 52 |
+
- CTEs
|
| 53 |
+
- subqueries
|
| 54 |
+
- window functions such as ROW_NUMBER, RANK, DENSE_RANK, LAG, LEAD
|
| 55 |
+
- GROUP BY and HAVING
|
| 56 |
+
- conditional aggregation
|
| 57 |
+
- anti-joins / exclusion logic
|
| 58 |
+
- top-N per group
|
| 59 |
+
- time-based filtering
|
| 60 |
+
- Exactly 3 of the 10 queries must be "easy" (basic filtering, simple lookups, 1-2 tables).
|
| 61 |
+
- Exactly 3 of the 10 queries must be "medium" (moderate complexity, standard JOINs, basic aggregation).
|
| 62 |
+
- Exactly 4 of the 10 queries must be genuinely "hard" (advanced patterns, CTEs, subqueries, window functions).
|
| 63 |
+
- Ensure the "difficulty" key strictly contains one of these exact string values: "easy", "medium", or "hard".
|
| 64 |
+
|
| 65 |
+
QUALITY TARGETS
|
| 66 |
+
- The SQL should be executable as written.
|
| 67 |
+
- The question should be answerable from the schema alone.
|
| 68 |
+
- Prefer business-like, realistic analytics questions.
|
| 69 |
+
- Prefer queries that require combining 2 to 4 tables.
|
| 70 |
+
- If a query uses aggregation, ensure the NL clearly implies aggregation.
|
| 71 |
+
- If a query uses ordering, include "has_order": true.
|
| 72 |
+
- If a query does not require ordering, set "has_order": false.
|
| 73 |
+
- Make the 10 queries cover diverse intent types:
|
| 74 |
+
1. ranking
|
| 75 |
+
2. comparison against average or median
|
| 76 |
+
3. top/bottom-N
|
| 77 |
+
4. grouped aggregation
|
| 78 |
+
5. time filtering
|
| 79 |
+
6. multi-join analysis
|
| 80 |
+
7. exclusion / NOT EXISTS
|
| 81 |
+
8. window-function based analysis
|
| 82 |
+
9. conditional counting
|
| 83 |
+
10. trend or interval-based logic
|
| 84 |
+
|
| 85 |
+
SCHEMA
|
| 86 |
+
{schema}
|
| 87 |
+
|
| 88 |
+
OUTPUT FORMAT
|
| 89 |
+
Return ONLY a valid JSON array of 10 objects.
|
| 90 |
+
|
| 91 |
+
Example structure:
|
| 92 |
+
[
|
| 93 |
+
{{
|
| 94 |
+
"sql": "SELECT ...",
|
| 95 |
+
"base_nl": "Show ...",
|
| 96 |
+
"difficulty": "hard",
|
| 97 |
+
"has_order": true
|
| 98 |
+
}}
|
| 99 |
+
]
|
| 100 |
+
|
| 101 |
+
FINAL SELF-CHECK BEFORE RESPONDING
|
| 102 |
+
- Confirm the output is valid JSON.
|
| 103 |
+
- Confirm there are exactly 10 objects.
|
| 104 |
+
- Confirm every SQL is a single SELECT.
|
| 105 |
+
- Confirm no hallucinated schema elements exist.
|
| 106 |
+
- Confirm the 10 questions are not paraphrases of each other.
|
| 107 |
+
"""
|
| 108 |
+
|
| 109 |
+
def extract_json(raw_text):
    """Extract the first JSON array substring from an LLM response.

    Strips an optional leading markdown fence (```json or ```) and an
    optional trailing ``` fence, then returns the span from the first
    '[' to the last ']'.

    Args:
        raw_text: Raw decoded model output, possibly fenced.

    Returns:
        The JSON-array substring, or None if no well-ordered bracket
        pair is present.
    """
    text = raw_text.strip()
    # Remove fences one side at a time. The original sliced text[7:-3] /
    # text[3:-3], which silently chopped 3 valid characters whenever the
    # closing fence was missing (e.g. a truncated generation).
    if text.startswith("```json"):
        text = text[len("```json"):]
    elif text.startswith("```"):
        text = text[len("```"):]
    if text.endswith("```"):
        text = text[:-3]
    text = text.strip()

    start = text.find("[")
    end = text.rfind("]")
    # Require end > start so inputs like "] [" do not yield a bogus slice.
    if start != -1 and end > start:
        return text[start:end + 1]
    return None
|
| 120 |
+
|
| 121 |
+
def main():
    """Mine validated SQL base templates from Qwen-72B until TARGET_TEMPLATES pass.

    Resume-safe: templates already present in OUTPUT_FILE are reloaded, their
    hashes seed the dedup set, and the file is rewritten after every batch
    that adds new templates.
    """
    print("Loading Model Qwen-72B (SDPA) for 10K Mining...")

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

    # Per-GPU memory caps; GPUs 0 and 5 are partially occupied by other jobs.
    custom_max_memory = {
        0: "60GiB",  # System GPU 0 (Has 13GB used, ~67GB free)
        1: "75GiB",  # System GPU 1 (Fully free)
        2: "75GiB",  # System GPU 2 (Fully free)
        3: "75GiB",  # System GPU 3 (Fully free)
        4: "75GiB",  # System GPU 4 (Fully free)
        5: "45GiB",  # System GPU 7 (Has 25GB used, ~55GB free)
    }
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        max_memory=custom_max_memory,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa",
    )

    domains = list(SCHEMA_CONTEXT.keys())
    valid_templates = []
    seen_sql_hashes = set()

    # Resume support: load existing templates to prevent duplicates.
    if os.path.exists(OUTPUT_FILE):
        with open(OUTPUT_FILE, "r") as f:
            valid_templates = json.load(f)
        for t in valid_templates:
            seen_sql_hashes.add(hashlib.md5(t["sql"].lower().encode()).hexdigest())

    pbar = tqdm(total=TARGET_TEMPLATES, initial=len(valid_templates), desc="Mining 10K Base Templates")

    validators = {}  # one SQLValidator per domain, created lazily on first use
    domain_idx = 0   # round-robin cursor across domains

    while len(valid_templates) < TARGET_TEMPLATES:
        batch_prompts = []
        batch_domains = []

        # Build one chat-formatted prompt per batch slot, cycling domains.
        for _ in range(BATCH_SIZE):
            domain = domains[domain_idx % len(domains)]
            schema_string = SCHEMA_CONTEXT[domain]
            domain_idx += 1

            messages = [
                {"role": "system", "content": "You output only valid JSON arrays. Do not include markdown."},
                {"role": "user", "content": PROMPT_TEMPLATE.format(schema=schema_string)},
            ]
            chat_text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            batch_prompts.append(chat_text)
            batch_domains.append(domain)

        inputs = tokenizer(batch_prompts, return_tensors="pt", padding=True, truncation=True).to(model.device)

        try:
            tqdm.write(f"\n[DEBUG] Sending batch of {BATCH_SIZE} to model.generate(). Please wait...")
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=5000,
                    do_sample=True,
                    temperature=0.55,
                    top_p=0.9,
                    pad_token_id=tokenizer.eos_token_id,
                )
            tqdm.write("[DEBUG] Model generation finished. Decoding responses...")

            # Keep only the generated continuation; drop the echoed prompt tokens.
            input_length = inputs.input_ids.shape[1]
            generated_tokens = outputs[:, input_length:]
            responses = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)

            batch_added = 0
            for i, (response, domain) in enumerate(zip(responses, batch_domains)):
                tqdm.write(f"\n[DEBUG] Processing Response {i+1}/{BATCH_SIZE} for domain: {domain}")

                json_text = extract_json(response)
                if not json_text:
                    tqdm.write(f"[DEBUG] extract_json failed. Raw text snippet: {response[:200]}...")
                    continue

                try:
                    generated_data = json.loads(json_text)
                    tqdm.write(f"[DEBUG] JSON loaded successfully. Found {len(generated_data)} items.")
                except Exception as e:
                    tqdm.write(f"[DEBUG] json.loads failed. Error: {e}")
                    tqdm.write(f"[DEBUG] Bad JSON snippet: {json_text[:200]}...")
                    continue

                if domain not in validators:
                    validators[domain] = SQLValidator(domain, seed=42)
                validator = validators[domain]

                for item in generated_data:
                    if not isinstance(item, dict):
                        continue

                    sql = item.get("sql", "").strip()
                    if not sql:
                        continue

                    # Dedup on a normalized (lower-cased) hash of the SQL text.
                    sql_hash = hashlib.md5(sql.lower().encode()).hexdigest()
                    if sql_hash in seen_sql_hashes:
                        tqdm.write("[DEBUG] Duplicate query skipped.")
                        continue

                    val_result = validator.validate(sql)

                    # Hard validation rule: SQL must execute AND return rows.
                    if val_result.passed and val_result.row_count > 0:
                        tqdm.write(f"[DEBUG] SQL Passed (Rows: {val_result.row_count}): {sql[:50]}...")
                        item["domain"] = domain
                        item["id"] = f"base_{len(valid_templates)}"
                        valid_templates.append(item)
                        seen_sql_hashes.add(sql_hash)
                        batch_added += 1
                    else:
                        tqdm.write(f"[DEBUG] SQL Failed Validation or 0 Rows (Passed: {val_result.passed}, Rows: {val_result.row_count}): {sql[:50]}...")

            if batch_added > 0:
                pbar.update(batch_added)
                tqdm.write(f"[DEBUG] Auto-saving {batch_added} new templates to JSON...")
                # Auto-save after every successful batch so progress survives crashes.
                with open(OUTPUT_FILE, "w") as f:
                    json.dump(valid_templates, f, indent=2)

            if len(valid_templates) >= TARGET_TEMPLATES:
                break

        except Exception as e:
            tqdm.write(f"\n[DEBUG] CRITICAL EXCEPTION CAUGHT: {e}")
            continue

    # Close validators (releases their underlying resources — presumably DB
    # connections; confirm against SQLValidator).
    for v in validators.values():
        v.close()

    pbar.close()
    print(f"\nBoom! Generated {len(valid_templates)} Elite Base Templates!")


if __name__ == "__main__":
    main()
|
generate_edge_cases.py
ADDED
|
@@ -0,0 +1,319 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
generate_edge_cases.py
|
| 3 |
+
======================
|
| 4 |
+
Targeted edge-case data generator for the 4 failure patterns found in eval:
|
| 5 |
+
1. ROW_NUMBER vs RANK vs DENSE_RANK (tie-breaking semantics)
|
| 6 |
+
2. strftime month as INTEGER (not '%Y-%m' string)
|
| 7 |
+
3. SELECT column discipline (no unrequested extras)
|
| 8 |
+
4. LAG/LEAD period-over-period
|
| 9 |
+
5. HAVING vs WHERE placement
|
| 10 |
+
6. COUNT(DISTINCT) vs COUNT
|
| 11 |
+
|
| 12 |
+
Produces: edge_cases.jsonl (same chat format as nl2sql_cleaned_ready_to_train.jsonl)
|
| 13 |
+
Run: python generate_edge_cases.py
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
import os, sys, json, re, hashlib
|
| 17 |
+
from tqdm import tqdm
|
| 18 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, AwqConfig, BitsAndBytesConfig
|
| 19 |
+
import torch
|
| 20 |
+
import transformers.activations
|
| 21 |
+
# Shim: alias the missing activation class so AutoAWQ's import does not crash
|
| 22 |
+
if not hasattr(transformers.activations, 'PytorchGELUTanh'):
|
| 23 |
+
transformers.activations.PytorchGELUTanh = transformers.activations.NewGELUActivation
|
| 24 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "3,1,6,7"
|
| 25 |
+
|
| 26 |
+
PROJECT_ROOT = os.path.abspath(os.path.dirname(__file__))
|
| 27 |
+
if PROJECT_ROOT not in sys.path:
|
| 28 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 29 |
+
|
| 30 |
+
from data_factory.schemas import SCHEMA_CONTEXT
|
| 31 |
+
|
| 32 |
+
quantization_config = BitsAndBytesConfig(
|
| 33 |
+
load_in_4bit=True,
|
| 34 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
| 35 |
+
bnb_4bit_use_double_quant=True,
|
| 36 |
+
bnb_4bit_quant_type="nf4"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
MODEL_NAME = "Qwen/Qwen2.5-72B-Instruct"
|
| 40 |
+
OUTPUT_FILE = "edge_cases.jsonl"
|
| 41 |
+
BATCH_SIZE = 8 # smaller — edge prompts are long
|
| 42 |
+
SAMPLES_PER_PATTERN = 715 # ~6 batches per pattern → 5005 total edge samples
|
| 43 |
+
|
| 44 |
+
SYSTEM_PROMPT = (
|
| 45 |
+
"You are a Senior SQL Architect. "
|
| 46 |
+
"Output ONLY the SQL query. Use SQLite syntax."
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
# ── Edge-case prompt templates ──────────────────────────────────────────────
|
| 50 |
+
# Each entry: (pattern_tag, user_prompt_template)
|
| 51 |
+
# {schema} is filled at runtime with a random domain schema.
|
| 52 |
+
|
| 53 |
+
EDGE_PATTERNS = [
|
| 54 |
+
|
| 55 |
+
# 1. ROW_NUMBER tie-breaking — the #1 failure
|
| 56 |
+
("row_number_tiebreak", """SCHEMA:
|
| 57 |
+
{schema}
|
| 58 |
+
|
| 59 |
+
Generate exactly 8 NL2SQL pairs that REQUIRE ROW_NUMBER() (not RANK or DENSE_RANK) \
|
| 60 |
+
because the question explicitly says "pick one winner when there is a tie" \
|
| 61 |
+
using a tiebreaker column (e.g. lower id, earlier date).
|
| 62 |
+
|
| 63 |
+
Output ONLY a valid JSON array:
|
| 64 |
+
[
|
| 65 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 66 |
+
...
|
| 67 |
+
]
|
| 68 |
+
|
| 69 |
+
Rules:
|
| 70 |
+
- Every SQL must use ROW_NUMBER() OVER (...) not RANK().
|
| 71 |
+
- The OVER clause ORDER BY must include the tiebreaker column.
|
| 72 |
+
- WHERE rn = 1 must appear in an outer query or CTE.
|
| 73 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 74 |
+
|
| 75 |
+
# 2. RANK / DENSE_RANK — when ties SHOULD persist
|
| 76 |
+
("rank_dense_rank", """SCHEMA:
|
| 77 |
+
{schema}
|
| 78 |
+
|
| 79 |
+
Generate exactly 8 NL2SQL pairs where RANK() or DENSE_RANK() is the CORRECT choice \
|
| 80 |
+
because the question says "show all tied records at the same rank".
|
| 81 |
+
|
| 82 |
+
Output ONLY a valid JSON array:
|
| 83 |
+
[
|
| 84 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 85 |
+
...
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
Rules:
|
| 89 |
+
- Use RANK() when question implies gaps after ties, DENSE_RANK() when no gaps.
|
| 90 |
+
- NL must make the tie-semantics explicit ("same rank", "tied positions").
|
| 91 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 92 |
+
|
| 93 |
+
# 3. strftime integer month output
|
| 94 |
+
("strftime_integer_month", """SCHEMA:
|
| 95 |
+
{schema}
|
| 96 |
+
|
| 97 |
+
Generate exactly 8 NL2SQL pairs where the question asks for a numeric month number \
|
| 98 |
+
(1–12), NOT a 'YYYY-MM' string.
|
| 99 |
+
|
| 100 |
+
Output ONLY a valid JSON array:
|
| 101 |
+
[
|
| 102 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 103 |
+
...
|
| 104 |
+
]
|
| 105 |
+
|
| 106 |
+
Rules:
|
| 107 |
+
- SQL must use CAST(strftime('%m', <col>) AS INTEGER) to produce integer month.
|
| 108 |
+
- Do NOT use strftime('%Y-%m', ...) when the question asks for month number.
|
| 109 |
+
- NL questions must say "month number", "which month (1–12)", or similar.
|
| 110 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 111 |
+
|
| 112 |
+
# 4. SELECT column discipline
|
| 113 |
+
("select_column_discipline", """SCHEMA:
|
| 114 |
+
{schema}
|
| 115 |
+
|
| 116 |
+
Generate exactly 8 NL2SQL pairs where the question explicitly names ONLY the columns \
|
| 117 |
+
to return. The SQL must select EXACTLY those columns — no extras like avg_salary, \
|
| 118 |
+
row counts, or intermediate aggregates.
|
| 119 |
+
|
| 120 |
+
Output ONLY a valid JSON array:
|
| 121 |
+
[
|
| 122 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 123 |
+
...
|
| 124 |
+
]
|
| 125 |
+
|
| 126 |
+
Rules:
|
| 127 |
+
- NL must say "return only X, Y, Z" or "show me only the name and total".
|
| 128 |
+
- SQL SELECT list must contain only those columns.
|
| 129 |
+
- If aggregation is needed internally (e.g. for HAVING), do NOT expose it in SELECT.
|
| 130 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 131 |
+
|
| 132 |
+
# 5. LAG / LEAD period-over-period
|
| 133 |
+
("lag_lead_period", """SCHEMA:
|
| 134 |
+
{schema}
|
| 135 |
+
|
| 136 |
+
Generate exactly 8 NL2SQL pairs that require LAG() or LEAD() window functions \
|
| 137 |
+
for period-over-period comparison (e.g. month-over-month revenue change, \
|
| 138 |
+
previous order amount, next appointment date).
|
| 139 |
+
|
| 140 |
+
Output ONLY a valid JSON array:
|
| 141 |
+
[
|
| 142 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 143 |
+
...
|
| 144 |
+
]
|
| 145 |
+
|
| 146 |
+
Rules:
|
| 147 |
+
- Use LAG(<col>, 1) OVER (ORDER BY ...) or LEAD(...) correctly.
|
| 148 |
+
- NL must imply comparison with previous or next row/period.
|
| 149 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 150 |
+
|
| 151 |
+
# 6. HAVING vs WHERE
|
| 152 |
+
("having_vs_where", """SCHEMA:
|
| 153 |
+
{schema}
|
| 154 |
+
|
| 155 |
+
Generate exactly 8 NL2SQL pairs that test correct placement of filter conditions:
|
| 156 |
+
- Conditions on raw columns → WHERE
|
| 157 |
+
- Conditions on aggregates → HAVING
|
| 158 |
+
Include 4 pairs where a wrong model might put an aggregate condition in WHERE (trap).
|
| 159 |
+
|
| 160 |
+
Output ONLY a valid JSON array:
|
| 161 |
+
[
|
| 162 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 163 |
+
...
|
| 164 |
+
]
|
| 165 |
+
|
| 166 |
+
Rules:
|
| 167 |
+
- SQL must never filter an aggregate (COUNT, SUM, AVG) inside WHERE.
|
| 168 |
+
- SQL must never put a raw column filter inside HAVING.
|
| 169 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 170 |
+
|
| 171 |
+
# 7. COUNT(DISTINCT) vs COUNT
|
| 172 |
+
("count_distinct", """SCHEMA:
|
| 173 |
+
{schema}
|
| 174 |
+
|
| 175 |
+
Generate exactly 8 NL2SQL pairs where the question specifically asks for \
|
| 176 |
+
"unique", "distinct", or "different" counts — requiring COUNT(DISTINCT col).
|
| 177 |
+
Also include 2 pairs where COUNT(*) is correct to reinforce the contrast.
|
| 178 |
+
|
| 179 |
+
Output ONLY a valid JSON array:
|
| 180 |
+
[
|
| 181 |
+
{{"nl": "...", "sql": "SELECT ..."}},
|
| 182 |
+
...
|
| 183 |
+
]
|
| 184 |
+
|
| 185 |
+
Rules:
|
| 186 |
+
- When NL says "unique/distinct", SQL must use COUNT(DISTINCT <col>).
|
| 187 |
+
- When NL says "total orders placed" (not distinct), use COUNT(*) or COUNT(id).
|
| 188 |
+
- No markdown. No explanation. Just the JSON array."""),
|
| 189 |
+
]
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ── Helpers ──────────────────────────────────────────────────────────────────
|
| 193 |
+
|
| 194 |
+
def extract_json_array(text: str) -> str:
    """Return the bracketed JSON-array substring of *text*, or "[]" if absent.

    Any markdown code fences the model leaked are removed before locating
    the array boundaries.
    """
    stripped = text.strip()
    # Unwrap ``` / ```json fences, keeping only the fenced content.
    stripped = re.sub(r"```(?:json)?\n?(.*?)```", r"\1", stripped, flags=re.DOTALL).strip()
    lo = stripped.find("[")
    hi = stripped.rfind("]")
    if lo == -1 or hi == -1:
        return "[]"
    return stripped[lo:hi + 1]
|
| 200 |
+
|
| 201 |
+
def get_hash(text: str) -> str:
    """MD5 hex digest of the lower-cased, whitespace-trimmed text (dedup key)."""
    normalized = text.lower().strip()
    return hashlib.md5(normalized.encode()).hexdigest()
|
| 203 |
+
|
| 204 |
+
def build_record(nl: str, sql: str, domain: str) -> dict:
    """Package one NL/SQL pair as a chat-format training record for *domain*.

    The record mirrors the format of nl2sql_cleaned_ready_to_train.jsonl:
    a two-message "prompt" list plus the gold "sql" string.
    """
    user_content = f"SCHEMA: {SCHEMA_CONTEXT[domain]}\nQUESTION: {nl}"
    return {
        "prompt": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": user_content},
        ],
        "sql": sql,
    }
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
# ── Main ─────────────────────────────────────────────────────────────────────
|
| 215 |
+
|
| 216 |
+
def main():
    """Generate SAMPLES_PER_PATTERN validated edge-case records for every
    pattern in EDGE_PATTERNS, appending them to OUTPUT_FILE as JSONL.

    Fix vs. original: once a pattern's quota was met mid-batch, the break
    only exited the inner pair loop — every remaining response in the batch
    still wrote one extra record, overshooting the quota and the progress
    bar. A quota guard on the response loop stops that.
    """
    print(f"Loading {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="left")
    tokenizer.pad_token = tokenizer.eos_token
    # Per-GPU memory caps (GPU 0 is partially occupied by another job).
    custom_memory = {0: "30GiB", 1: "75GiB", 2: "45GiB", 3: "45GiB"}
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        device_map="auto",
        max_memory=custom_memory,
        quantization_config=quantization_config,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        attn_implementation="sdpa",
    )

    domains = list(SCHEMA_CONTEXT.keys())
    seen = set()   # hashes of (nl + sql) already written, across all patterns
    total = 0

    out = open(OUTPUT_FILE, "a", encoding="utf-8")

    for pattern_tag, prompt_tmpl in EDGE_PATTERNS:
        print(f"\n[PATTERN] {pattern_tag}")
        collected = 0
        domain_idx = 0
        pbar = tqdm(total=SAMPLES_PER_PATTERN, desc=pattern_tag)

        while collected < SAMPLES_PER_PATTERN:
            # Build a batch of prompts, cycling through domains.
            batch_domains = []
            batch_prompts = []
            for _ in range(BATCH_SIZE):
                domain = domains[domain_idx % len(domains)]
                domain_idx += 1
                user_msg = prompt_tmpl.format(schema=SCHEMA_CONTEXT[domain])
                msgs = [
                    {"role": "system", "content": "You output only valid JSON arrays. No markdown."},
                    {"role": "user", "content": user_msg},
                ]
                batch_prompts.append(
                    tokenizer.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
                )
                batch_domains.append(domain)

            inputs = tokenizer(
                batch_prompts, return_tensors="pt", padding=True, truncation=True
            ).to(model.device)

            try:
                with torch.no_grad():
                    outputs = model.generate(
                        **inputs,
                        max_new_tokens=2048,
                        do_sample=True,
                        temperature=0.5,
                        top_p=0.9,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                # Decode only the generated continuation, not the prompt.
                responses = tokenizer.batch_decode(
                    outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True
                )

                for resp, domain in zip(responses, batch_domains):
                    # Quota guard: stop consuming responses once this
                    # pattern's target is reached (bug fix — see docstring).
                    if collected >= SAMPLES_PER_PATTERN:
                        break

                    raw = extract_json_array(resp)
                    try:
                        pairs = json.loads(raw)
                    except Exception:
                        continue

                    for pair in pairs:
                        nl = pair.get("nl", "").strip()
                        sql = pair.get("sql", "").strip()
                        if not nl or not sql:
                            continue
                        # Strip fences in sql just in case the model leaked them.
                        sql = re.sub(r"```(?:sql)?\n?(.*?)```", r"\1", sql, flags=re.DOTALL).strip()

                        h = get_hash(nl + sql)
                        if h in seen:
                            continue
                        seen.add(h)

                        record = build_record(nl, sql, domain)
                        out.write(json.dumps(record, ensure_ascii=False) + "\n")
                        out.flush()
                        collected += 1
                        total += 1
                        pbar.update(1)

                        if collected >= SAMPLES_PER_PATTERN:
                            break

            except Exception as e:
                tqdm.write(f"[WARN] Batch failed: {e}")
                continue

        pbar.close()

    out.close()
    print(f"\nDone! {total} edge-case records saved to {OUTPUT_FILE}")


if __name__ == "__main__":
    main()
|
inference.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
inference.py — NL2SQL-Bench Baseline Inference Script
|
| 3 |
+
========================================================
|
| 4 |
+
|
| 5 |
+
MANDATORY COMPLIANCE
|
| 6 |
+
--------------------
|
| 7 |
+
- Named `inference.py`, placed in project root.
|
| 8 |
+
- Uses OpenAI client for all LLM calls.
|
| 9 |
+
- Reads: API_BASE_URL, MODEL_NAME, HF_TOKEN from environment.
|
| 10 |
+
- Emits [START] / [STEP] / [END] lines to stdout in the exact format below.
|
| 11 |
+
- Runs all 3 tasks; total runtime < 20 min on 2 vCPU / 8 GB.
|
| 12 |
+
|
| 13 |
+
STDOUT FORMAT (exact — any deviation breaks scoring)
|
| 14 |
+
----------------------------------------------------
|
| 15 |
+
[START] task=<task_name> env=nl2sql-bench model=<model_name>
|
| 16 |
+
[STEP] step=<n> action=<sql_one_line> reward=<0.00> done=<true|false> error=<msg|null>
|
| 17 |
+
[END] success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import asyncio
|
| 23 |
+
import os
|
| 24 |
+
import sys
|
| 25 |
+
import textwrap
|
| 26 |
+
from typing import List, Optional
|
| 27 |
+
|
| 28 |
+
from openai import OpenAI
|
| 29 |
+
|
| 30 |
+
# ── Configuration ──────────────────────────────────────────────────────────
|
| 31 |
+
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 32 |
+
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 33 |
+
API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY", "")
|
| 34 |
+
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "nl2sql-bench:latest")
|
| 35 |
+
SPACE_URL = os.getenv("SPACE_URL", "http://localhost:8000")
|
| 36 |
+
|
| 37 |
+
BENCHMARK = "nl2sql-bench"
|
| 38 |
+
MAX_STEPS = 5
|
| 39 |
+
TEMPERATURE = 0.2 # Low temp for SQL generation
|
| 40 |
+
MAX_TOKENS = 512
|
| 41 |
+
SUCCESS_THRESHOLD = 0.7 # score >= 0.7 → success
|
| 42 |
+
|
| 43 |
+
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 44 |
+
|
| 45 |
+
# ── System prompt ──────────────────────────────────────────────────────────
|
| 46 |
+
SYSTEM_PROMPT = textwrap.dedent("""
|
| 47 |
+
You are an expert SQL analyst working with a SQLite e-commerce database.
|
| 48 |
+
|
| 49 |
+
DATABASE SCHEMA
|
| 50 |
+
---------------
|
| 51 |
+
categories(id, name)
|
| 52 |
+
products(id, name, category_id, price, stock_quantity)
|
| 53 |
+
customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
|
| 54 |
+
orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
|
| 55 |
+
created_at, total_amount)
|
| 56 |
+
order_items(id, order_id, product_id, quantity, unit_price)
|
| 57 |
+
reviews(id, product_id, customer_id, rating∈1-5, created_at)
|
| 58 |
+
|
| 59 |
+
RULES
|
| 60 |
+
-----
|
| 61 |
+
1. Write a single SELECT query — no INSERT/UPDATE/DELETE.
|
| 62 |
+
2. Output ONLY the SQL query, nothing else. No markdown, no explanation.
|
| 63 |
+
3. Use SQLite syntax: strftime('%Y-%m', date_col) for month, ROUND(x, 2) for decimals.
|
| 64 |
+
4. Window functions (RANK, DENSE_RANK, ROW_NUMBER, running SUM) are supported.
|
| 65 |
+
5. CTEs (WITH ... AS (...)) are supported.
|
| 66 |
+
6. If you receive an error, fix it carefully in your next attempt.
|
| 67 |
+
7. If you receive partial results, refine your query to match the expected output.
|
| 68 |
+
""").strip()
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
# ── Stdout logging (mandatory format) ─────────────────────────────────────
|
| 72 |
+
|
| 73 |
+
def log_start(task: str, model: str) -> None:
    """Emit the episode-opening [START] line in the grader's mandatory format."""
    header = f"[START] task={task} env={BENCHMARK} model={model}"
    print(header, flush=True)
|
| 76 |
+
|
| 77 |
+
def log_step(
    step: int, action: str, reward: float, done: bool, error: Optional[str]
) -> None:
    """Emit one [STEP] line in the grader's mandatory stdout format.

    The SQL action is collapsed onto a single line and any newlines in the
    error text are flattened so the log line stays machine-parseable.
    """
    flat_sql = " ".join(action.split())
    err_text = "null" if not error else error.replace("\n", " ")
    line = (
        f"[STEP] step={step} action={flat_sql!r} "
        f"reward={reward:.2f} done={str(done).lower()} error={err_text}"
    )
    print(line, flush=True)
|
| 89 |
+
|
| 90 |
+
def log_end(
    success: bool, steps: int, score: float, rewards: List[float]
) -> None:
    """Emit the final [END] summary line in the grader's mandatory format."""
    joined = ",".join(f"{r:.2f}" for r in rewards)
    summary = (
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={joined}"
    )
    print(summary, flush=True)
|
| 100 |
+
|
| 101 |
+
# ── LLM interaction ────────────────────────────────────────────────────────
|
| 102 |
+
|
| 103 |
+
def build_user_prompt(
    question: str,
    schema_context: str,
    step: int,
    last_query: str,
    last_error: Optional[str],
    last_result: list,
    result_columns: list,
) -> str:
    """Assemble the per-step user prompt for the SQL-writing LLM.

    On the first step the model is simply asked to answer the question.
    On later steps the previous query plus its error (or a short result
    preview) is echoed back so the model can refine its answer.

    NOTE(review): ``schema_context`` is accepted but not embedded here —
    the schema appears to live in the system prompt; confirm intentional.
    """
    lines = [f"QUESTION: {question}", ""]

    if step <= 1:
        lines.append("Write a SQL query to answer the question.")
        return "\n".join(lines)

    lines.append(f"Your previous SQL (step {step - 1}):")
    lines.append(f"  {' '.join(last_query.split())}")
    lines.append("")
    if last_error:
        lines.append(f"ERROR: {last_error}")
    elif last_result:
        snippet = str(last_result[:3]).replace("\n", " ")
        lines.append(f"RESULT PREVIEW (first 3 rows): {snippet}")
        lines.append(f"COLUMNS: {result_columns}")
    lines.append("")
    lines.append("Please correct or refine your query.")
    return "\n".join(lines)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def call_llm(client: OpenAI, user_prompt: str) -> str:
    """Ask the chat model for a SQL query; always return a non-empty string.

    Any API failure or empty completion degrades to the harmless fallback
    query ``"SELECT 1"`` so the episode loop never crashes.
    """
    try:
        completion = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            stream=False,
        )
        answer = (completion.choices[0].message.content or "").strip()
        # Drop ``` fence lines if the model wrapped its SQL in markdown.
        if answer.startswith("```"):
            kept = [
                line for line in answer.split("\n")
                if not line.strip().startswith("```")
            ]
            answer = "\n".join(kept).strip()
        return answer or "SELECT 1"
    except Exception as exc:
        print(f"[DEBUG] LLM call failed: {exc}", file=sys.stderr, flush=True)
        return "SELECT 1"
| 156 |
+
|
| 157 |
+
|
| 158 |
+
# ── Single-task episode ────────────────────────────────────────────────────
|
| 159 |
+
|
| 160 |
+
async def run_task(client: OpenAI, env, task_name: str) -> dict:
    """Run one full episode for the given task. Returns result dict.

    Flow: reset the environment, then loop up to MAX_STEPS times —
    build a feedback-aware prompt, ask the LLM for SQL, step the env,
    log the outcome — stopping early once the env reports done.
    The [END] line is printed in ``finally`` so the grader always sees
    a complete [START]/[STEP]/[END] record even if the episode raises.

    Returns a summary dict: {"task", "success", "score", "rewards"}.
    """
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False

    log_start(task_name, MODEL_NAME)

    try:
        # Reset — pass task_name via action payload or query param
        # OpenEnv reset() may not accept task args via HTTP; we rely on
        # NL2SQL_DEFAULT_TASK env-var being set before calling, OR we
        # pass it as a reset parameter if the server supports it.
        result = await env.reset()
        obs = result.observation

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            user_prompt = build_user_prompt(
                question=obs.question,
                schema_context=obs.schema_context,
                step=step,
                last_query=obs.last_query,
                last_error=obs.last_error,
                last_result=obs.last_result,
                result_columns=obs.result_columns,
            )

            sql = call_llm(client, user_prompt)

            from client import NL2SQLAction  # local to avoid circular at module level
            result = await env.step(NL2SQLAction(query=sql))
            obs = result.observation

            reward = obs.reward or 0.0  # reward is None on reset → treat as 0.0
            done = obs.done
            error = obs.last_error

            rewards.append(reward)
            steps_taken = step

            log_step(step=step, action=sql, reward=reward, done=done, error=error)

            if done:
                break

        # Compute final score: mean per-step reward, clamped to [0.0, 1.0].
        score = sum(rewards) / max(len(rewards), 1)
        score = round(min(max(score, 0.0), 1.0), 4)
        success = score >= SUCCESS_THRESHOLD

    except Exception as exc:
        print(f"[DEBUG] Episode error for {task_name}: {exc}", file=sys.stderr, flush=True)
    finally:
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

    return {"task": task_name, "success": success, "score": score, "rewards": rewards}
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
# ── Main ───────────────────────────────────────────────────────────────────
|
| 223 |
+
|
| 224 |
+
async def main() -> None:
    """Run the baseline over all three tasks and print a summary to stderr.

    Only the [START]/[STEP]/[END] lines on stdout are scored; the summary
    block at the end goes to stderr for human readability.
    """
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Import here to avoid import errors if openenv not installed during lint
    from client import NL2SQLEnv

    all_results = []

    for task_name in TASKS:
        # Set the default task for the server session via env-var approach.
        # For the hosted Space, we rely on the task cycling implemented in
        # the task registry's round-robin iterator.
        os.environ["NL2SQL_DEFAULT_TASK"] = task_name

        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                result = await run_task(client, env, task_name)
                all_results.append(result)
        except Exception as exc:
            print(
                f"[DEBUG] Failed to connect for task {task_name}: {exc}",
                file=sys.stderr,
                flush=True,
            )
            # Emit a zero-score END to keep log format valid
            log_end(success=False, steps=0, score=0.0, rewards=[])
            all_results.append({"task": task_name, "success": False, "score": 0.0})

    # Summary to stderr (not scored, for human readability)
    print("\n=== Baseline Summary ===", file=sys.stderr)
    for r in all_results:
        print(
            f" {r['task']:20s} score={r['score']:.3f} "
            f"success={r['success']}",
            file=sys.stderr,
        )
    avg = sum(r["score"] for r in all_results) / max(len(all_results), 1)
    print(f" {'AVERAGE':20s} score={avg:.3f}", file=sys.stderr)
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
    # Entry point: runs all benchmark tasks sequentially via asyncio.
    asyncio.run(main())
|
local_test.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import torch
|
| 5 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 6 |
+
from peft import PeftModel
|
| 7 |
+
|
| 8 |
+
# --- Configuration ---
|
| 9 |
+
BASE_MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
|
| 10 |
+
LORA_DIR = "./qwen-nl2sql-grpo/checkpoint-50"
|
| 11 |
+
SPACE_URL = "http://localhost:8000" # Local server URL
|
| 12 |
+
TASKS = ["simple-filter", "join-aggregation", "analytics-window"]
|
| 13 |
+
MAX_STEPS = 5
|
| 14 |
+
|
| 15 |
+
print("Loading Base Model and LoRA weights...")
|
| 16 |
+
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
|
| 17 |
+
base_model = AutoModelForCausalLM.from_pretrained(
|
| 18 |
+
BASE_MODEL,
|
| 19 |
+
torch_dtype=torch.bfloat16,
|
| 20 |
+
device_map="auto"
|
| 21 |
+
)
|
| 22 |
+
model = PeftModel.from_pretrained(base_model, LORA_DIR)
|
| 23 |
+
|
| 24 |
+
# --- System Prompt & LLM Call ---
|
| 25 |
+
SYSTEM_PROMPT = """You are an expert SQL analyst working with a SQLite e-commerce database.
|
| 26 |
+
Write a single SELECT query. Output ONLY the SQL query, nothing else. No markdown."""
|
| 27 |
+
|
| 28 |
+
def call_local_llm(user_prompt: str) -> str:
    """Generate a SQL answer from the locally loaded LoRA model.

    Falls back to ``"SELECT 1"`` when the model produces no output so the
    episode loop always has a query to submit.
    """
    chat = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    rendered = tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
    encoded = tokenizer([rendered], return_tensors="pt").to(model.device)

    with torch.no_grad():
        generated = model.generate(**encoded, max_new_tokens=256, temperature=0.2, do_sample=True)

    reply = tokenizer.decode(generated[0][encoded.input_ids.shape[1]:], skip_special_tokens=True)

    # Strip markdown code fences if model wraps in ```sql ... ```
    if reply.startswith("```"):
        reply = "\n".join(
            ln for ln in reply.split("\n") if not ln.strip().startswith("```")
        ).strip()
    return reply if reply else "SELECT 1"
|
| 46 |
+
|
| 47 |
+
def build_user_prompt(question, schema_context, step, last_query, last_error, last_result, result_columns):
    """Build the per-step user message: the question plus, after step 1,
    feedback (error or result preview) from the previous attempt."""
    out = [f"QUESTION: {question}", ""]
    if step <= 1:
        out.append("Write a SQL query to answer the question.")
    else:
        out.append(f"Your previous SQL (step {step - 1}):")
        out.append(f"  {' '.join(last_query.split())}")
        out.append("")
        if last_error:
            out.append(f"ERROR: {last_error}")
        elif last_result:
            preview = str(last_result[:3]).replace("\n", " ")
            out.append(f"RESULT PREVIEW (first 3 rows): {preview}")
            out.append(f"COLUMNS: {result_columns}")
        out.append("")
        out.append("Please correct or refine your query.")
    return "\n".join(out)
|
| 64 |
+
|
| 65 |
+
async def main():
    """Smoke-test the LoRA checkpoint against a locally running env server.

    Runs every task in TASKS for up to MAX_STEPS steps each, printing
    per-step agent output/rewards and a final per-task score summary.
    """
    from client import NL2SQLEnv, NL2SQLAction

    all_results = []

    for task_name in TASKS:
        print(f"\n--- Starting Task: {task_name} ---")
        # The env server reads this variable to pick which task to serve.
        os.environ["NL2SQL_DEFAULT_TASK"] = task_name

        try:
            async with NL2SQLEnv(base_url=SPACE_URL) as env:
                result = await env.reset()
                obs = result.observation

                rewards = []
                success = False

                for step in range(1, MAX_STEPS + 1):
                    if obs.done:
                        break

                    user_prompt = build_user_prompt(
                        obs.question, obs.schema_context, step,
                        obs.last_query, obs.last_error, obs.last_result, obs.result_columns
                    )

                    sql = call_local_llm(user_prompt)

                    print(f"Step {step} Agent Output: {sql}")

                    step_result = await env.step(NL2SQLAction(query=sql))
                    obs = step_result.observation

                    reward = obs.reward or 0.0  # reward is None on reset → 0.0
                    rewards.append(reward)
                    print(f"Step {step} Reward: {reward}")

                    if obs.done:
                        break

                # Mean per-step reward; >= 0.7 counts as a solved task.
                score = sum(rewards) / max(len(rewards), 1)
                success = score >= 0.7
                print(f"Final Score for {task_name}: {score:.3f}")
                all_results.append({"task": task_name, "score": score, "success": success})

        except Exception as e:
            print(f"Error testing task {task_name}: {e}")

    print("\n=== Final Results ===")
    for r in all_results:
        print(f"{r['task']}: Score {r['score']:.3f} | Success: {r['success']}")
|
| 116 |
+
|
| 117 |
+
if __name__ == "__main__":
|
| 118 |
+
asyncio.run(main())
|
merge_model.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
from peft import PeftModel
|
| 4 |
+
|
| 5 |
+
BASE_MODEL = "Qwen/Qwen2.5-Coder-7B-Instruct"
|
| 6 |
+
ADAPTER_DIR = "./qwen-7b-coder-nl2sql-grpo/final"
|
| 7 |
+
OUTPUT_DIR = "./qwen-7b-nl2sql-merged"
|
| 8 |
+
|
| 9 |
+
def main():
    """Merge LoRA adapter weights into the base model and save the result."""
    print("Loading Base Model...")
    backbone = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )

    print("Loading Adapters and Merging...")
    # Attach the LoRA adapters, then bake them into the base weights.
    adapted = PeftModel.from_pretrained(backbone, ADAPTER_DIR)
    merged = adapted.merge_and_unload()

    print("Saving Merged Model...")
    merged.save_pretrained(OUTPUT_DIR)

    # Ship the tokenizer alongside so the output directory is self-contained.
    tok = AutoTokenizer.from_pretrained(BASE_MODEL)
    tok.save_pretrained(OUTPUT_DIR)
    print(f"Done! Merged model saved to {OUTPUT_DIR}")
|
| 30 |
+
|
| 31 |
+
if __name__ == "__main__":
|
| 32 |
+
main()
|
mini_server.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from typing import List
from transformers import AutoModelForCausalLM, AutoTokenizer

# CRITICAL: pin inference to GPU 7 (the original comment claimed GPU 0,
# which contradicted the actual setting below).
# NOTE(review): setting CUDA_VISIBLE_DEVICES after `import torch` works only
# because CUDA initializes lazily — confirm no CUDA call precedes this.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

app = FastAPI()

# Path to the merged (base + LoRA) model produced by merge_model.py.
MODEL_PATH = "./qwen-7b-nl2sql-merged"

print("🚀 Loading Local Model for Inference API... (Takes a minute)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa"  # Super stable, no vLLM crashes
)
# Fixed: the banner previously said "Port 8000" but uvicorn binds port 8001.
print("✅ Server Ready! Acting as OpenAI on Port 8001.")
+
|
| 27 |
+
# OpenAI Request Schemas
|
| 28 |
+
class Message(BaseModel):
    """One OpenAI-style chat message."""
    # role is expected to be "system" | "user" | "assistant" — TODO confirm
    role: str
    content: str
+
|
| 32 |
+
class ChatRequest(BaseModel):
    """Minimal OpenAI /v1/chat/completions request body."""
    # model name is echoed back in the response; it does not select a model
    model: str
    messages: List[Message]
    # temperature of 0 disables sampling in the generate call
    temperature: float = 0.2
    max_tokens: int = 512
+
|
| 38 |
+
@app.post("/v1/chat/completions")
async def chat(request: ChatRequest):
    """OpenAI-compatible chat completion endpoint backed by the local model.

    Renders the messages with the model's chat template, generates, and
    returns a minimal OpenAI-shaped JSON payload (token usage is not
    tracked and is reported as zeros).
    """
    # Convert OpenAI messages to the plain-dict format the tokenizer expects.
    messages = [{"role": m.role, "content": m.content} for m in request.messages]
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Generate SQL; greedy decoding when temperature is zero.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=request.max_tokens,
            temperature=request.temperature,
            do_sample=request.temperature > 0,  # idiomatic: expression is already a bool
            pad_token_id=tokenizer.eos_token_id
        )

    # Decode only the newly generated tokens (skip the echoed prompt).
    response_text = tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)

    # Return EXACT OpenAI JSON Structure
    return {
        "id": "chatcmpl-local-hackathon",
        "object": "chat.completion",
        "created": 1700000000,
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": response_text},
            "finish_reason": "stop"
        }],
        "usage": {"prompt_tokens": 0, "completion_tokens": 0, "total_tokens": 0}
    }
| 72 |
+
|
| 73 |
+
if __name__ == "__main__":
    # Bind on all interfaces, port 8001.
    uvicorn.run(app, host="0.0.0.0", port=8001)
|
models.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/models.py
|
| 3 |
+
======================
|
| 4 |
+
Typed contracts for the NL2SQL-Bench OpenEnv environment.
|
| 5 |
+
|
| 6 |
+
Action : The SQL query the agent submits.
|
| 7 |
+
Observation : What the agent sees after each step.
|
| 8 |
+
State : Episode-level metadata (for state() endpoint).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
from __future__ import annotations
|
| 12 |
+
|
| 13 |
+
from dataclasses import dataclass, field
|
| 14 |
+
from typing import Any, Dict, List, Optional
|
| 15 |
+
|
| 16 |
+
from openenv.core.env_server import Action, Observation, State
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ---------------------------------------------------------------------------
|
| 20 |
+
# Action
|
| 21 |
+
# ---------------------------------------------------------------------------
|
| 22 |
+
|
| 23 |
+
class NL2SQLAction(Action):
    """A single SQL query submitted by the agent."""
    # The raw SQLite SELECT statement to execute (empty string by default).
    query: str = ""
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# Observation
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class NL2SQLObservation(Observation):
    """
    Everything the agent needs to reason about and iterate its SQL query.

    Fields
    ------
    question       : The natural-language question to answer.
    schema_context : Relevant table/column descriptions as a string block.
    task_name      : Identifier of the current task (easy / medium / hard).
    last_query     : The SQL the agent submitted on the last step (empty on reset).
    last_result    : Up to 10 rows returned by the last query (list of dicts).
    last_error     : SQLite error string if the query failed, else None.
    result_columns : Column names of last_result rows.
    step           : Current step number (1-indexed).
    max_steps      : Maximum steps allowed per episode.
    done           : True when the episode is over (success or step exhausted).
    reward         : Reward for the most recent action (None on reset).
    score          : Normalised cumulative score so far [0.0, 1.0].
    """
    question: str = ""
    schema_context: str = ""
    task_name: str = ""
    last_query: str = ""
    # mutable defaults must go through field() to avoid sharing across instances
    last_result: List[Dict[str, Any]] = field(default_factory=list)
    last_error: Optional[str] = None
    result_columns: List[str] = field(default_factory=list)
    step: int = 0
    max_steps: int = 5
    done: bool = False
    reward: Optional[float] = None
    score: float = 0.0
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ---------------------------------------------------------------------------
|
| 67 |
+
# State
|
| 68 |
+
# ---------------------------------------------------------------------------
|
| 69 |
+
|
| 70 |
+
class NL2SQLState(State):
    """Episode-level state (returned by the /state endpoint)."""
    # presumably assigned at reset by the server — verify in environment.py
    episode_id: Optional[str] = None
    # number of step() calls taken so far this episode
    step_count: int = 0
    task_name: str = ""
    task_difficulty: str = ""  # easy | medium | hard
    question: str = ""
    best_reward: float = 0.0  # highest reward seen this episode
    cumulative_reward: float = 0.0
    solved: bool = False  # True if exact match was achieved
|
openenv.yaml
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nl2sql-bench/openenv.yaml
|
| 2 |
+
# OpenEnv environment manifest — validated by `openenv validate`
|
| 3 |
+
|
| 4 |
+
name: nl2sql-bench
|
| 5 |
+
version: "0.1.0"
|
| 6 |
+
description: >
|
| 7 |
+
Natural Language to SQL query generation environment for RL training.
|
| 8 |
+
An agent iteratively writes and refines SQLite queries against a synthetic
|
| 9 |
+
e-commerce database to answer business questions. Multi-turn episodes with
|
| 10 |
+
dense, shaped rewards. Three difficulty tasks: easy (single-table),
|
| 11 |
+
medium (JOIN + GROUP BY), hard (window functions + CTEs).
|
| 12 |
+
|
| 13 |
+
author: "nl2sql-bench team"
|
| 14 |
+
license: MIT
|
| 15 |
+
tags:
|
| 16 |
+
- openenv
|
| 17 |
+
- nl2sql
|
| 18 |
+
- sql
|
| 19 |
+
- analytics
|
| 20 |
+
- rl-training
|
| 21 |
+
- deterministic
|
| 22 |
+
- multi-turn
|
| 23 |
+
|
| 24 |
+
# ── Task definitions ────────────────────────────────────────────────────────
|
| 25 |
+
tasks:
|
| 26 |
+
- name: simple-filter
|
| 27 |
+
difficulty: easy
|
| 28 |
+
description: >
|
| 29 |
+
Single-table SELECT with WHERE, ORDER BY, and LIMIT.
|
| 30 |
+
Tests basic SQL fluency. Expected solve rate: high.
|
| 31 |
+
max_steps: 5
|
| 32 |
+
reward_range: [0.0, 1.0]
|
| 33 |
+
|
| 34 |
+
- name: join-aggregation
|
| 35 |
+
difficulty: medium
|
| 36 |
+
description: >
|
| 37 |
+
Multi-table JOINs with GROUP BY, HAVING, and aggregation functions
|
| 38 |
+
(COUNT, SUM, AVG, ROUND). Tests relational reasoning.
|
| 39 |
+
max_steps: 5
|
| 40 |
+
reward_range: [0.0, 1.0]
|
| 41 |
+
|
| 42 |
+
- name: analytics-window
|
| 43 |
+
difficulty: hard
|
| 44 |
+
description: >
|
| 45 |
+
Advanced analytics using CTEs, window functions (DENSE_RANK,
|
| 46 |
+
ROW_NUMBER, running SUM), and nested subqueries. Tests multi-step
|
| 47 |
+
planning and SQLite-specific syntax.
|
| 48 |
+
max_steps: 5
|
| 49 |
+
reward_range: [0.0, 1.0]
|
| 50 |
+
|
| 51 |
+
# ── Action / Observation space ──────────────────────────────────────────────
|
| 52 |
+
action_space:
|
| 53 |
+
type: object
|
| 54 |
+
properties:
|
| 55 |
+
query:
|
| 56 |
+
type: string
|
| 57 |
+
description: "A SQLite SELECT query string."
|
| 58 |
+
|
| 59 |
+
observation_space:
|
| 60 |
+
type: object
|
| 61 |
+
properties:
|
| 62 |
+
question:
|
| 63 |
+
type: string
|
| 64 |
+
description: "Natural-language question the agent must answer."
|
| 65 |
+
schema_context:
|
| 66 |
+
type: string
|
| 67 |
+
description: "Compact database schema description for the agent."
|
| 68 |
+
task_name:
|
| 69 |
+
type: string
|
| 70 |
+
description: "Active task identifier."
|
| 71 |
+
last_query:
|
| 72 |
+
type: string
|
| 73 |
+
description: "The SQL query submitted on the previous step."
|
| 74 |
+
last_result:
|
| 75 |
+
type: array
|
| 76 |
+
description: "Up to 10 rows returned by the last query (list of dicts)."
|
| 77 |
+
last_error:
|
| 78 |
+
type: string
|
| 79 |
+
nullable: true
|
| 80 |
+
description: "SQLite error string if last query failed, else null."
|
| 81 |
+
result_columns:
|
| 82 |
+
type: array
|
| 83 |
+
description: "Column names of last_result."
|
| 84 |
+
step:
|
| 85 |
+
type: integer
|
| 86 |
+
description: "Current step number (1-indexed; 0 after reset)."
|
| 87 |
+
max_steps:
|
| 88 |
+
type: integer
|
| 89 |
+
description: "Maximum steps per episode."
|
| 90 |
+
done:
|
| 91 |
+
type: boolean
|
| 92 |
+
description: "True when episode ends (exact match or step limit reached)."
|
| 93 |
+
reward:
|
| 94 |
+
type: number
|
| 95 |
+
nullable: true
|
| 96 |
+
description: "Reward for the most recent step [0.0, 1.0]."
|
| 97 |
+
score:
|
| 98 |
+
type: number
|
| 99 |
+
description: "Normalised cumulative episode score [0.0, 1.0]."
|
| 100 |
+
|
| 101 |
+
# ── Reward function description ─────────────────────────────────────────────
|
| 102 |
+
reward:
|
| 103 |
+
type: shaped
|
| 104 |
+
range: [0.0, 1.0]
|
| 105 |
+
components:
|
| 106 |
+
- name: syntax_ok
|
| 107 |
+
weight: 0.10
|
| 108 |
+
description: "Query executes without SQLite error."
|
| 109 |
+
- name: columns_match
|
| 110 |
+
weight: 0.20
|
| 111 |
+
description: "Returned column names match ground truth exactly."
|
| 112 |
+
- name: row_count_match
|
| 113 |
+
weight: 0.20
|
| 114 |
+
description: "Number of returned rows matches ground truth."
|
| 115 |
+
- name: exact_match
|
| 116 |
+
weight: 0.50
|
| 117 |
+
description: "Full result set matches ground truth (order-aware for ORDER BY)."
|
| 118 |
+
- name: step_penalty
|
| 119 |
+
weight: -0.05
|
| 120 |
+
description: "Deducted per step beyond the first (encourages efficiency)."
|
| 121 |
+
|
| 122 |
+
# ── Deployment ──────────────────────────────────────────────────────────────
|
| 123 |
+
server:
|
| 124 |
+
port: 7860
|
| 125 |
+
dockerfile: Dockerfile
|
| 126 |
+
healthcheck: /health
|
| 127 |
+
|
| 128 |
+
# ── Baseline ────────────────────────────────────────────────────────────────
|
| 129 |
+
baseline:
|
| 130 |
+
script: inference.py
|
| 131 |
+
model: Qwen/Qwen2.5-72B-Instruct
|
| 132 |
+
expected_scores:
|
| 133 |
+
simple-filter: 0.70
|
| 134 |
+
join-aggregation: 0.45
|
| 135 |
+
analytics-window: 0.25
|
pyproject.toml
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=69", "wheel"]
|
| 3 |
+
build-backend = "setuptools.backends.legacy:build"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "nl2sql-bench"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "NL2SQL-Bench: Natural Language to SQL Analytics OpenEnv environment for RL training"
|
| 9 |
+
readme = "README.md"
|
| 10 |
+
requires-python = ">=3.10"
|
| 11 |
+
license = { text = "MIT" }
|
| 12 |
+
|
| 13 |
+
dependencies = [
|
| 14 |
+
"openenv-core>=0.2.3",
|
| 15 |
+
"fastapi>=0.110.0",
|
| 16 |
+
"uvicorn[standard]>=0.29.0",
|
| 17 |
+
"pydantic>=2.0.0",
|
| 18 |
+
"openai>=1.0.0",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
[project.optional-dependencies]
|
| 22 |
+
dev = [
|
| 23 |
+
"pytest>=7.0",
|
| 24 |
+
"pytest-asyncio>=0.23",
|
| 25 |
+
"httpx>=0.27",
|
| 26 |
+
"black",
|
| 27 |
+
"ruff",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
[tool.setuptools.packages.find]
|
| 31 |
+
where = ["."]
|
| 32 |
+
include = ["nl2sql*", "server*"]
|
| 33 |
+
|
| 34 |
+
[tool.ruff]
|
| 35 |
+
line-length = 100
|
| 36 |
+
target-version = "py310"
|
| 37 |
+
|
| 38 |
+
[tool.pytest.ini_options]
|
| 39 |
+
asyncio_mode = "auto"
|
scripts/run_local.sh
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# nl2sql-bench/scripts/run_local.sh
|
| 3 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 4 |
+
# Quick local development server (no Docker needed).
|
| 5 |
+
# Prerequisites: Python 3.10+, pip or uv
|
| 6 |
+
#
|
| 7 |
+
# Usage:
|
| 8 |
+
# chmod +x scripts/run_local.sh
|
| 9 |
+
# ./scripts/run_local.sh
|
| 10 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 11 |
+
set -euo pipefail
|
| 12 |
+
|
| 13 |
+
REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)"
|
| 14 |
+
cd "$REPO_ROOT"
|
| 15 |
+
|
| 16 |
+
echo "═══════════════════════════════════════════════"
|
| 17 |
+
echo " NL2SQL-Bench — Local Dev Server"
|
| 18 |
+
echo "═══════════════════════════════════════════════"
|
| 19 |
+
|
| 20 |
+
# ── Check Python ────────────────────────────────────────────────────────────
|
| 21 |
+
if ! command -v python3 &>/dev/null; then
|
| 22 |
+
echo "ERROR: python3 not found. Install Python 3.10+." && exit 1
|
| 23 |
+
fi
|
| 24 |
+
PY_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
| 25 |
+
echo "Python: $PY_VERSION"
|
| 26 |
+
|
| 27 |
+
# ── Virtual environment ──────────────────────────────────────────────────────
|
| 28 |
+
if [ ! -d ".venv" ]; then
|
| 29 |
+
echo "Creating virtualenv..."
|
| 30 |
+
python3 -m venv .venv
|
| 31 |
+
fi
|
| 32 |
+
source .venv/bin/activate
|
| 33 |
+
|
| 34 |
+
# ── Install deps ─────────────────────────────────────────────────────────────
|
| 35 |
+
echo "Installing dependencies..."
|
| 36 |
+
pip install -q --upgrade pip
|
| 37 |
+
pip install -q openenv-core fastapi "uvicorn[standard]" openai pydantic pytest pytest-asyncio
|
| 38 |
+
|
| 39 |
+
# ── Load .env if present ─────────────────────────────────────────────────────
|
| 40 |
+
if [ -f ".env" ]; then
|
| 41 |
+
echo "Loading .env..."
|
| 42 |
+
set -a
|
| 43 |
+
source .env
|
| 44 |
+
set +a
|
| 45 |
+
fi
|
| 46 |
+
|
| 47 |
+
# ── Export PYTHONPATH ─────────────────────────────────────────────────────────
|
| 48 |
+
export PYTHONPATH="$REPO_ROOT:$REPO_ROOT/server"
|
| 49 |
+
|
| 50 |
+
echo ""
|
| 51 |
+
echo "Starting server at http://localhost:8000"
|
| 52 |
+
echo " /reset → POST (start episode)"
|
| 53 |
+
echo " /step → POST (submit SQL)"
|
| 54 |
+
echo " /state → GET (episode metadata)"
|
| 55 |
+
echo " /health → GET (liveness probe)"
|
| 56 |
+
echo " /docs → GET (Swagger UI)"
|
| 57 |
+
echo ""
|
| 58 |
+
echo "Press Ctrl+C to stop."
|
| 59 |
+
echo "───────────────────────────────────────────────"
|
| 60 |
+
|
| 61 |
+
cd "$REPO_ROOT/server"
|
| 62 |
+
uvicorn app:app \
|
| 63 |
+
--host 0.0.0.0 \
|
| 64 |
+
--port 8000 \
|
| 65 |
+
--reload \
|
| 66 |
+
--reload-dir "$REPO_ROOT" \
|
| 67 |
+
--log-level info
|
scripts/smoke_test.sh
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# nl2sql-bench/scripts/smoke_test.sh
|
| 3 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 4 |
+
# Smoke tests against a running server (local or HF Space).
|
| 5 |
+
# Verifies all endpoints return expected HTTP codes and JSON shapes.
|
| 6 |
+
#
|
| 7 |
+
# Usage:
|
| 8 |
+
# ./scripts/smoke_test.sh # default localhost:8000
|
| 9 |
+
# ./scripts/smoke_test.sh https://your.hf.space # HF Space URL
|
| 10 |
+
# ─────────────────────────────────────────────────────────────────────────────
|
| 11 |
+
set -euo pipefail
|
| 12 |
+
|
| 13 |
+
BASE_URL="${1:-http://localhost:8000}"
|
| 14 |
+
BASE_URL="${BASE_URL%/}"
|
| 15 |
+
PASS=0; FAIL=0
|
| 16 |
+
|
| 17 |
+
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'; BOLD='\033[1m'
|
| 18 |
+
|
| 19 |
+
pass() { echo -e "${GREEN}✓${NC} $1"; PASS=$((PASS+1)); }
|
| 20 |
+
fail() { echo -e "${RED}✗${NC} $1"; FAIL=$((FAIL+1)); }
|
| 21 |
+
|
| 22 |
+
echo ""
|
| 23 |
+
echo -e "${BOLD}NL2SQL-Bench Smoke Tests${NC}"
|
| 24 |
+
echo "Target: $BASE_URL"
|
| 25 |
+
echo "────────────────────────────────────────"
|
| 26 |
+
|
| 27 |
+
# ── /health ──────────────────────────────────────────────────────────────────
|
| 28 |
+
CODE=$(curl -s -o /dev/null -w "%{http_code}" "$BASE_URL/health")
|
| 29 |
+
[ "$CODE" = "200" ] && pass "/health → 200" || fail "/health → $CODE (expected 200)"
|
| 30 |
+
|
| 31 |
+
# ── /reset ───────────────────────────────────────────────────────────────────
|
| 32 |
+
RESET_BODY=$(curl -s -X POST "$BASE_URL/reset" \
|
| 33 |
+
-H "Content-Type: application/json" -d '{}')
|
| 34 |
+
echo "$RESET_BODY" | grep -q "question" && pass "/reset → has 'question' field" \
|
| 35 |
+
|| fail "/reset → missing 'question' field. Body: $RESET_BODY"
|
| 36 |
+
|
| 37 |
+
# ── /step (valid SQL) ─────────────────────────────────────────────────────────
|
| 38 |
+
STEP_BODY=$(curl -s -X POST "$BASE_URL/step" \
|
| 39 |
+
-H "Content-Type: application/json" \
|
| 40 |
+
-d '{"query": "SELECT id, name FROM customers LIMIT 3"}')
|
| 41 |
+
echo "$STEP_BODY" | grep -q "reward" && pass "/step valid SQL → has 'reward'" \
|
| 42 |
+
|| fail "/step valid SQL → missing 'reward'. Body: $STEP_BODY"
|
| 43 |
+
echo "$STEP_BODY" | grep -q '"done"' && pass "/step valid SQL → has 'done'" \
|
| 44 |
+
|| fail "/step valid SQL → missing 'done'. Body: $STEP_BODY"
|
| 45 |
+
|
| 46 |
+
# ── /step (syntax error SQL) ──────────────────────────────────────────────────
|
| 47 |
+
STEP_ERR=$(curl -s -X POST "$BASE_URL/step" \
|
| 48 |
+
-H "Content-Type: application/json" \
|
| 49 |
+
-d '{"query": "SELCT * FORM broken_tbl"}')
|
| 50 |
+
echo "$STEP_ERR" | grep -q "last_error" && pass "/step bad SQL → has 'last_error'" \
|
| 51 |
+
|| fail "/step bad SQL → missing 'last_error'. Body: $STEP_ERR"
|
| 52 |
+
|
| 53 |
+
# ── /state ────────────────────────────────────────────────────────────────────
|
| 54 |
+
STATE_BODY=$(curl -s "$BASE_URL/state")
|
| 55 |
+
echo "$STATE_BODY" | grep -q "step_count" && pass "/state → has 'step_count'" \
|
| 56 |
+
|| fail "/state → missing 'step_count'. Body: $STATE_BODY"
|
| 57 |
+
|
| 58 |
+
echo "────────────────────────────────────────"
|
| 59 |
+
echo -e "${BOLD}Results: ${GREEN}${PASS} passed${NC}, ${RED}${FAIL} failed${NC}"
|
| 60 |
+
echo ""
|
| 61 |
+
|
| 62 |
+
[ "$FAIL" -eq 0 ] && exit 0 || exit 1
|
server/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# server/__init__.py
|
server/app.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/app.py
|
| 3 |
+
============================
|
| 4 |
+
FastAPI application entry point for the NL2SQL-Bench OpenEnv server.
|
| 5 |
+
|
| 6 |
+
create_fastapi_app() auto-creates all required OpenEnv endpoints:
|
| 7 |
+
POST /reset — start a new episode
|
| 8 |
+
POST /step — submit an action
|
| 9 |
+
GET /state — retrieve episode state
|
| 10 |
+
GET /health — health check
|
| 11 |
+
GET /web — interactive web UI (if ENABLE_WEB_INTERFACE=true)
|
| 12 |
+
GET /docs — Swagger UI
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from openenv.core.env_server import create_fastapi_app
|
| 18 |
+
from environment import NL2SQLEnvironment
|
| 19 |
+
|
| 20 |
+
# Ensure models can be imported from the parent directory
|
| 21 |
+
_HERE = Path(__file__).parent
|
| 22 |
+
sys.path.insert(0, str(_HERE.parent))
|
| 23 |
+
|
| 24 |
+
from models import NL2SQLAction, NL2SQLObservation
|
| 25 |
+
|
| 26 |
+
# Pass the explicitly required action and observation classes
|
| 27 |
+
app = create_fastapi_app(
|
| 28 |
+
NL2SQLEnvironment,
|
| 29 |
+
action_cls=NL2SQLAction,
|
| 30 |
+
observation_cls=NL2SQLObservation
|
| 31 |
+
)
|
server/app.py.bak
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/app.py
|
| 3 |
+
============================
|
| 4 |
+
FastAPI application entry point for the NL2SQL-Bench OpenEnv server.
|
| 5 |
+
|
| 6 |
+
create_fastapi_app() auto-creates all required OpenEnv endpoints:
|
| 7 |
+
POST /reset — start a new episode
|
| 8 |
+
POST /step — submit an action
|
| 9 |
+
GET /state — retrieve episode state
|
| 10 |
+
GET /health — health check
|
| 11 |
+
GET /web — interactive web UI (if ENABLE_WEB_INTERFACE=true)
|
| 12 |
+
GET /docs — Swagger UI
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import sys
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
from openenv.core.env_server import create_fastapi_app
|
| 18 |
+
from environment import NL2SQLEnvironment
|
| 19 |
+
|
| 20 |
+
# Ensure models can be imported from the parent directory
|
| 21 |
+
_HERE = Path(__file__).parent
|
| 22 |
+
sys.path.insert(0, str(_HERE.parent))
|
| 23 |
+
|
| 24 |
+
from models import NL2SQLAction, NL2SQLObservation
|
| 25 |
+
|
| 26 |
+
# Pass the explicitly required action and observation classes
|
| 27 |
+
app = create_fastapi_app(
|
| 28 |
+
NL2SQLEnvironment,
|
| 29 |
+
action_cls=NL2SQLAction,
|
| 30 |
+
observation_cls=NL2SQLObservation
|
| 31 |
+
)
|
server/db/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# server/db/__init__.py
|
server/db/schema.sql
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- nl2sql-bench/server/db/schema.sql
|
| 2 |
+
-- E-commerce database schema for NL2SQL-Bench
|
| 3 |
+
-- Designed for in-memory SQLite: realistic, universally understood domain.
|
| 4 |
+
|
| 5 |
+
CREATE TABLE IF NOT EXISTS categories (
|
| 6 |
+
id INTEGER PRIMARY KEY,
|
| 7 |
+
name TEXT NOT NULL UNIQUE
|
| 8 |
+
);
|
| 9 |
+
|
| 10 |
+
CREATE TABLE IF NOT EXISTS products (
|
| 11 |
+
id INTEGER PRIMARY KEY,
|
| 12 |
+
name TEXT NOT NULL,
|
| 13 |
+
category_id INTEGER NOT NULL REFERENCES categories(id),
|
| 14 |
+
price REAL NOT NULL CHECK(price >= 0),
|
| 15 |
+
stock_quantity INTEGER NOT NULL DEFAULT 0
|
| 16 |
+
);
|
| 17 |
+
|
| 18 |
+
CREATE TABLE IF NOT EXISTS customers (
|
| 19 |
+
id INTEGER PRIMARY KEY,
|
| 20 |
+
name TEXT NOT NULL,
|
| 21 |
+
email TEXT NOT NULL UNIQUE,
|
| 22 |
+
country TEXT NOT NULL,
|
| 23 |
+
tier TEXT NOT NULL DEFAULT 'bronze' -- bronze | silver | gold
|
| 24 |
+
CHECK(tier IN ('bronze', 'silver', 'gold')),
|
| 25 |
+
created_at TEXT NOT NULL -- ISO-8601 date string
|
| 26 |
+
);
|
| 27 |
+
|
| 28 |
+
CREATE TABLE IF NOT EXISTS orders (
|
| 29 |
+
id INTEGER PRIMARY KEY,
|
| 30 |
+
customer_id INTEGER NOT NULL REFERENCES customers(id),
|
| 31 |
+
status TEXT NOT NULL DEFAULT 'pending'
|
| 32 |
+
CHECK(status IN ('pending','processing','shipped','delivered','cancelled')),
|
| 33 |
+
created_at TEXT NOT NULL,
|
| 34 |
+
total_amount REAL NOT NULL CHECK(total_amount >= 0)
|
| 35 |
+
);
|
| 36 |
+
|
| 37 |
+
CREATE TABLE IF NOT EXISTS order_items (
|
| 38 |
+
id INTEGER PRIMARY KEY,
|
| 39 |
+
order_id INTEGER NOT NULL REFERENCES orders(id),
|
| 40 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 41 |
+
quantity INTEGER NOT NULL CHECK(quantity > 0),
|
| 42 |
+
unit_price REAL NOT NULL CHECK(unit_price >= 0)
|
| 43 |
+
);
|
| 44 |
+
|
| 45 |
+
CREATE TABLE IF NOT EXISTS reviews (
|
| 46 |
+
id INTEGER PRIMARY KEY,
|
| 47 |
+
product_id INTEGER NOT NULL REFERENCES products(id),
|
| 48 |
+
customer_id INTEGER NOT NULL REFERENCES customers(id),
|
| 49 |
+
rating INTEGER NOT NULL CHECK(rating BETWEEN 1 AND 5),
|
| 50 |
+
created_at TEXT NOT NULL
|
| 51 |
+
);
|
| 52 |
+
|
| 53 |
+
-- Indexes for common join/filter patterns
|
| 54 |
+
CREATE INDEX IF NOT EXISTS idx_products_category ON products(category_id);
|
| 55 |
+
CREATE INDEX IF NOT EXISTS idx_orders_customer ON orders(customer_id);
|
| 56 |
+
CREATE INDEX IF NOT EXISTS idx_orders_status ON orders(status);
|
| 57 |
+
CREATE INDEX IF NOT EXISTS idx_orders_created ON orders(created_at);
|
| 58 |
+
CREATE INDEX IF NOT EXISTS idx_order_items_order ON order_items(order_id);
|
| 59 |
+
CREATE INDEX IF NOT EXISTS idx_order_items_product ON order_items(product_id);
|
| 60 |
+
CREATE INDEX IF NOT EXISTS idx_reviews_product ON reviews(product_id);
|
| 61 |
+
CREATE INDEX IF NOT EXISTS idx_customers_country ON customers(country);
|
| 62 |
+
CREATE INDEX IF NOT EXISTS idx_customers_tier ON customers(tier);
|
server/db/seed.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/db/seed.py
|
| 3 |
+
==============================
|
| 4 |
+
Deterministic synthetic data generator for the NL2SQL-Bench SQLite database.
|
| 5 |
+
|
| 6 |
+
Uses a fixed random seed so every fresh environment build produces
|
| 7 |
+
IDENTICAL data, which is essential for reproducible grader scores across
|
| 8 |
+
different machines, runs, and Docker containers.
|
| 9 |
+
|
| 10 |
+
Call: seed_database(conn) once after creating tables.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import random
|
| 16 |
+
import sqlite3
|
| 17 |
+
from datetime import date, timedelta
|
| 18 |
+
from typing import List
|
| 19 |
+
|
| 20 |
+
# ── Deterministic seed ────────────────────────────────────────────────────
|
| 21 |
+
SEED = 42
|
| 22 |
+
RNG = random.Random(SEED)
|
| 23 |
+
|
| 24 |
+
# ── Domain constants ──────────────────────────────────────────────────────
|
| 25 |
+
CATEGORIES = [
|
| 26 |
+
"Electronics", "Clothing", "Books", "Home & Garden",
|
| 27 |
+
"Sports & Outdoors", "Toys & Games", "Beauty", "Automotive",
|
| 28 |
+
]
|
| 29 |
+
|
| 30 |
+
PRODUCT_NAMES = {
|
| 31 |
+
"Electronics": ["Wireless Headphones", "USB-C Hub", "Mechanical Keyboard",
|
| 32 |
+
"Webcam 4K", "Portable Charger", "Smart Speaker",
|
| 33 |
+
"Monitor Stand", "HDMI Cable 2.1"],
|
| 34 |
+
"Clothing": ["Cotton T-Shirt", "Slim Fit Jeans", "Hoodie",
|
| 35 |
+
"Running Shorts", "Winter Jacket", "Polo Shirt",
|
| 36 |
+
"Casual Sneakers", "Wool Socks"],
|
| 37 |
+
"Books": ["Clean Code", "Designing Data-Intensive Applications",
|
| 38 |
+
"The Pragmatic Programmer", "System Design Interview",
|
| 39 |
+
"Deep Learning Book", "Python Cookbook",
|
| 40 |
+
"Domain-Driven Design", "Refactoring"],
|
| 41 |
+
"Home & Garden": ["Coffee Maker", "Air Purifier", "LED Desk Lamp",
|
| 42 |
+
"Plant Pot Set", "Storage Organiser", "Cutting Board",
|
| 43 |
+
"Vacuum Cleaner", "Electric Kettle"],
|
| 44 |
+
"Sports & Outdoors":["Yoga Mat", "Resistance Bands", "Cycling Gloves",
|
| 45 |
+
"Trekking Poles", "Water Bottle 1L", "Jump Rope",
|
| 46 |
+
"Foam Roller", "Compression Socks"],
|
| 47 |
+
"Toys & Games": ["Lego City Set", "Card Game Pack", "Puzzle 1000pc",
|
| 48 |
+
"Remote Control Car", "Building Blocks",
|
| 49 |
+
"Board Game Strategy", "Art Set", "Toy Drone"],
|
| 50 |
+
"Beauty": ["Face Serum", "SPF 50 Sunscreen", "Lip Balm",
|
| 51 |
+
"Shampoo Pro", "Hair Mask", "Eye Cream",
|
| 52 |
+
"Vitamin C Cream", "Toner Mist"],
|
| 53 |
+
"Automotive": ["Car Phone Mount", "Dash Cam", "Tyre Inflator",
|
| 54 |
+
"Car Vacuum", "Seat Cushion", "Steering Wheel Cover",
|
| 55 |
+
"OBD Scanner", "Jump Starter"],
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
COUNTRIES = ["India", "USA", "Germany", "UK", "Canada",
|
| 59 |
+
"Australia", "France", "Brazil", "Japan", "Singapore"]
|
| 60 |
+
|
| 61 |
+
TIERS = ["bronze", "silver", "gold"]
|
| 62 |
+
STATUSES = ["pending", "processing", "shipped", "delivered", "cancelled"]
|
| 63 |
+
|
| 64 |
+
FIRST_NAMES = [
|
| 65 |
+
"Aarav","Priya","Rahul","Neha","Arjun","Sneha","Vikram","Pooja",
|
| 66 |
+
"Karthik","Divya","James","Sarah","Michael","Emily","David","Jessica",
|
| 67 |
+
"Hans","Lena","Oliver","Sofia","Pierre","Amelie","Carlos","Laura",
|
| 68 |
+
"Yuki","Hana","Wei","Mei","Aiden","Zara",
|
| 69 |
+
]
|
| 70 |
+
LAST_NAMES = [
|
| 71 |
+
"Sharma","Singh","Patel","Kumar","Gupta","Verma","Nair","Reddy",
|
| 72 |
+
"Smith","Johnson","Brown","Williams","Jones","Davis","Wilson",
|
| 73 |
+
"Müller","Schmidt","Schneider","Fischer","Weber",
|
| 74 |
+
"Martin","Bernard","Thomas","Richard","Petit",
|
| 75 |
+
"Garcia","Martinez","Lopez","Sanchez","Gonzalez",
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _random_date(start_year: int = 2022, end_year: int = 2025) -> str:
|
| 80 |
+
start = date(start_year, 1, 1)
|
| 81 |
+
end = date(end_year, 12, 31)
|
| 82 |
+
delta = (end - start).days
|
| 83 |
+
return (start + timedelta(days=RNG.randint(0, delta))).isoformat()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def seed_database(conn: sqlite3.Connection) -> None:
|
| 87 |
+
"""Populate the database with deterministic synthetic data."""
|
| 88 |
+
conn.execute("PRAGMA foreign_keys = ON")
|
| 89 |
+
cur = conn.cursor()
|
| 90 |
+
|
| 91 |
+
# ── Categories ────────────────────────────────────────────────────────
|
| 92 |
+
for i, name in enumerate(CATEGORIES, 1):
|
| 93 |
+
cur.execute(
|
| 94 |
+
"INSERT OR IGNORE INTO categories(id, name) VALUES (?, ?)",
|
| 95 |
+
(i, name),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# ── Products (8 per category → 64 total) ─────────────────────────────
|
| 99 |
+
pid = 1
|
| 100 |
+
for cat_id, (cat_name, names) in enumerate(PRODUCT_NAMES.items(), 1):
|
| 101 |
+
for pname in names:
|
| 102 |
+
price = round(RNG.uniform(5.0, 250.0), 2)
|
| 103 |
+
stock = RNG.randint(0, 500)
|
| 104 |
+
cur.execute(
|
| 105 |
+
"INSERT OR IGNORE INTO products(id, name, category_id, price, stock_quantity) "
|
| 106 |
+
"VALUES (?, ?, ?, ?, ?)",
|
| 107 |
+
(pid, pname, cat_id, price, stock),
|
| 108 |
+
)
|
| 109 |
+
pid += 1
|
| 110 |
+
|
| 111 |
+
# ── Customers (150 total) ─────────────────────────────────────────────
|
| 112 |
+
used_emails: set = set()
|
| 113 |
+
for cid in range(1, 151):
|
| 114 |
+
fname = RNG.choice(FIRST_NAMES)
|
| 115 |
+
lname = RNG.choice(LAST_NAMES)
|
| 116 |
+
name = f"{fname} {lname}"
|
| 117 |
+
email_base = f"{fname.lower()}.{lname.lower()}"
|
| 118 |
+
email = f"{email_base}{cid}@example.com"
|
| 119 |
+
while email in used_emails:
|
| 120 |
+
email = f"{email_base}{cid}x@example.com"
|
| 121 |
+
used_emails.add(email)
|
| 122 |
+
|
| 123 |
+
# Bias: 60% bronze, 30% silver, 10% gold
|
| 124 |
+
tier = RNG.choices(TIERS, weights=[60, 30, 10])[0]
|
| 125 |
+
country = RNG.choice(COUNTRIES)
|
| 126 |
+
created = _random_date(2021, 2023)
|
| 127 |
+
cur.execute(
|
| 128 |
+
"INSERT OR IGNORE INTO customers(id, name, email, country, tier, created_at) "
|
| 129 |
+
"VALUES (?, ?, ?, ?, ?, ?)",
|
| 130 |
+
(cid, name, email, country, tier, created),
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
# ── Orders + Order items ──────────────────────────────────────────────
|
| 134 |
+
oid = 1
|
| 135 |
+
item_id = 1
|
| 136 |
+
for cid in range(1, 151):
|
| 137 |
+
# Each customer has 0–8 orders; gold customers tend to have more
|
| 138 |
+
tier_row = cur.execute(
|
| 139 |
+
"SELECT tier FROM customers WHERE id=?", (cid,)
|
| 140 |
+
).fetchone()
|
| 141 |
+
tier = tier_row[0] if tier_row else "bronze"
|
| 142 |
+
n_orders = RNG.choices(
|
| 143 |
+
range(9),
|
| 144 |
+
weights=[5, 20, 20, 15, 15, 10, 8, 5, 2] if tier == "bronze"
|
| 145 |
+
else ([2, 10, 15, 20, 20, 15, 10, 5, 3] if tier == "silver"
|
| 146 |
+
else [1, 5, 10, 15, 20, 20, 15, 10, 4]),
|
| 147 |
+
)[0]
|
| 148 |
+
|
| 149 |
+
for _ in range(n_orders):
|
| 150 |
+
status = RNG.choices(STATUSES, weights=[5, 10, 15, 60, 10])[0]
|
| 151 |
+
order_date = _random_date(2022, 2025)
|
| 152 |
+
# Pick 1–4 products for this order
|
| 153 |
+
n_items = RNG.randint(1, 4)
|
| 154 |
+
chosen_pids = RNG.sample(range(1, 65), k=min(n_items, 64))
|
| 155 |
+
total = 0.0
|
| 156 |
+
|
| 157 |
+
cur.execute(
|
| 158 |
+
"INSERT OR IGNORE INTO orders(id, customer_id, status, created_at, total_amount) "
|
| 159 |
+
"VALUES (?, ?, ?, ?, ?)",
|
| 160 |
+
(oid, cid, status, order_date, 0.0), # update total after items
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
for cpid in chosen_pids:
|
| 164 |
+
qty = RNG.randint(1, 5)
|
| 165 |
+
price_row = cur.execute(
|
| 166 |
+
"SELECT price FROM products WHERE id=?", (cpid,)
|
| 167 |
+
).fetchone()
|
| 168 |
+
unit_price = price_row[0] if price_row else 10.0
|
| 169 |
+
total += round(qty * unit_price, 2)
|
| 170 |
+
cur.execute(
|
| 171 |
+
"INSERT OR IGNORE INTO order_items(id, order_id, product_id, quantity, unit_price) "
|
| 172 |
+
"VALUES (?, ?, ?, ?, ?)",
|
| 173 |
+
(item_id, oid, cpid, qty, unit_price),
|
| 174 |
+
)
|
| 175 |
+
item_id += 1
|
| 176 |
+
|
| 177 |
+
cur.execute(
|
| 178 |
+
"UPDATE orders SET total_amount=? WHERE id=?",
|
| 179 |
+
(round(total, 2), oid),
|
| 180 |
+
)
|
| 181 |
+
oid += 1
|
| 182 |
+
|
| 183 |
+
# ── Reviews ───────────────────────────────────────────────────────────
|
| 184 |
+
# Each customer reviews 0–6 products they (may have) ordered
|
| 185 |
+
rev_id = 1
|
| 186 |
+
reviewed: set = set() # (customer_id, product_id) pairs
|
| 187 |
+
for cid in range(1, 151):
|
| 188 |
+
n_reviews = RNG.randint(0, 6)
|
| 189 |
+
for _ in range(n_reviews):
|
| 190 |
+
rpid = RNG.randint(1, 64)
|
| 191 |
+
if (cid, rpid) in reviewed:
|
| 192 |
+
continue
|
| 193 |
+
reviewed.add((cid, rpid))
|
| 194 |
+
rating = RNG.choices([1, 2, 3, 4, 5], weights=[5, 10, 15, 35, 35])[0]
|
| 195 |
+
rev_date = _random_date(2022, 2025)
|
| 196 |
+
cur.execute(
|
| 197 |
+
"INSERT OR IGNORE INTO reviews(id, product_id, customer_id, rating, created_at) "
|
| 198 |
+
"VALUES (?, ?, ?, ?, ?)",
|
| 199 |
+
(rev_id, rpid, cid, rating, rev_date),
|
| 200 |
+
)
|
| 201 |
+
rev_id += 1
|
| 202 |
+
|
| 203 |
+
conn.commit()
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_db_summary(conn: sqlite3.Connection) -> dict:
|
| 207 |
+
"""Return row counts per table for debugging / README stats."""
|
| 208 |
+
tables = ["categories", "products", "customers", "orders", "order_items", "reviews"]
|
| 209 |
+
summary = {}
|
| 210 |
+
for t in tables:
|
| 211 |
+
row = conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()
|
| 212 |
+
summary[t] = row[0] if row else 0
|
| 213 |
+
return summary
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
if __name__ == "__main__":
|
| 217 |
+
import os
|
| 218 |
+
schema_path = os.path.join(os.path.dirname(__file__), "schema.sql")
|
| 219 |
+
conn = sqlite3.connect(":memory:")
|
| 220 |
+
conn.row_factory = sqlite3.Row
|
| 221 |
+
with open(schema_path) as f:
|
| 222 |
+
conn.executescript(f.read())
|
| 223 |
+
seed_database(conn)
|
| 224 |
+
print("Seed stats:", get_db_summary(conn))
|
| 225 |
+
conn.close()
|
server/environment.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/environment.py
|
| 3 |
+
====================================
|
| 4 |
+
NL2SQL-Bench core environment — implements the OpenEnv Environment interface.
|
| 5 |
+
|
| 6 |
+
Episode flow
|
| 7 |
+
------------
|
| 8 |
+
1. reset(task_name?) → picks a task + question, returns initial observation
|
| 9 |
+
2. step(action) → executes the SQL, grades it, returns observation + reward
|
| 10 |
+
3. state() → returns episode metadata
|
| 11 |
+
4. Episode ends when: exact_match OR step count reaches max_steps
|
| 12 |
+
|
| 13 |
+
The environment manages its own SQLite connection (in-memory, seeded
|
| 14 |
+
deterministically). One connection per Environment instance; the FastAPI
|
| 15 |
+
server creates one Environment per WebSocket session.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import os
|
| 21 |
+
import sqlite3
|
| 22 |
+
import uuid
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Optional
|
| 25 |
+
|
| 26 |
+
from openenv.core.env_server import Environment
|
| 27 |
+
|
| 28 |
+
# Import after openenv so path is correct regardless of working directory
|
| 29 |
+
_HERE = Path(__file__).parent
|
| 30 |
+
|
| 31 |
+
# Lazy import of task registry (avoids circular imports)
|
| 32 |
+
from tasks import get_task, all_task_names, BaseTask
|
| 33 |
+
from tasks.base import TaskExample
|
| 34 |
+
from grader import (
|
| 35 |
+
GradeResult,
|
| 36 |
+
compute_ground_truth,
|
| 37 |
+
execute_query,
|
| 38 |
+
grade,
|
| 39 |
+
has_order_by,
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# We import our models from one level up (models.py at project root)
|
| 43 |
+
import sys
|
| 44 |
+
sys.path.insert(0, str(_HERE.parent))
|
| 45 |
+
from models import NL2SQLAction, NL2SQLObservation, NL2SQLState
|
| 46 |
+
|
| 47 |
+
# ── Constants ──────────────────────────────────────────────────────────────
|
| 48 |
+
DEFAULT_TASK = os.getenv("NL2SQL_DEFAULT_TASK", "simple-filter")
|
| 49 |
+
MAX_STEPS = int(os.getenv("NL2SQL_MAX_STEPS", "5"))
|
| 50 |
+
RESULT_LIMIT = 10 # Max rows shown to agent per step
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class NL2SQLEnvironment(Environment):
    """
    OpenEnv-compliant environment for NL-to-SQL query generation.

    One instance per WebSocket session (created by create_fastapi_app).
    """

    def __init__(self) -> None:
        self._conn: Optional[sqlite3.Connection] = None
        self._task: Optional[BaseTask] = None
        self._example: Optional[TaskExample] = None
        self._ground_truth: list = []
        self._order_sensitive: bool = False
        # Cache one task instance per task name so that each task's
        # round-robin example cursor advances across episodes.  Creating a
        # fresh instance on every reset() would always serve the task's
        # first example, defeating the deterministic rotation documented
        # in tasks/base.py.
        self._task_cache: dict = {}
        self._state = NL2SQLState(
            episode_id=None,
            step_count=0,
            task_name="",
            task_difficulty="",
            question="",
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False
        )
        self._last_obs = NL2SQLObservation(
            question="",
            schema_context="",
            task_name="",
            last_query="",
            last_result=[],
            last_error=None,
            result_columns=[],
            step=0,
            max_steps=5,
            done=False,
            reward=None,
            score=0.0
        )
        self._episode_rewards: list = []
        self._setup_db()

    # ── DB lifecycle ───────────────────────────────────────────────────────

    def _setup_db(self) -> None:
        """Create in-memory SQLite DB and seed it."""
        schema_path = _HERE / "db" / "schema.sql"
        from db.seed import seed_database  # local import after sys.path setup
        conn = sqlite3.connect(":memory:", check_same_thread=False)
        conn.row_factory = sqlite3.Row
        conn.execute("PRAGMA foreign_keys = ON")
        conn.executescript(schema_path.read_text())
        seed_database(conn)
        self._conn = conn

    # ── OpenEnv interface ──────────────────────────────────────────────────

    def _get_task(self, task_name: str) -> BaseTask:
        """Return a cached task instance, creating it on first use.

        Caching preserves the task's internal round-robin cursor between
        episodes within this session.
        """
        if task_name not in self._task_cache:
            self._task_cache[task_name] = get_task(task_name)
        return self._task_cache[task_name]

    def reset(self, task_name: Optional[str] = None) -> NL2SQLObservation:
        """
        Start a new episode.

        task_name: one of 'simple-filter', 'join-aggregation', 'analytics-window'.
        Defaults to NL2SQL_DEFAULT_TASK env-var or 'simple-filter'.
        """
        task_name = task_name or DEFAULT_TASK
        if task_name not in all_task_names():
            task_name = DEFAULT_TASK

        self._task = self._get_task(task_name)
        self._example = self._task.next_example()
        self._order_sensitive = has_order_by(self._example.sql)

        # Pre-compute ground truth once per episode
        self._ground_truth = compute_ground_truth(self._conn, self._example.sql)

        self._episode_rewards = []
        self._state = NL2SQLState(
            episode_id=str(uuid.uuid4()),
            step_count=0,
            task_name=self._task.name,
            task_difficulty=self._task.difficulty,
            question=self._example.question,
            best_reward=0.0,
            cumulative_reward=0.0,
            solved=False,
        )

        obs = NL2SQLObservation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query="",
            last_result=[],
            last_error=None,
            result_columns=[],
            step=0,
            max_steps=MAX_STEPS,
            done=False,
            reward=None,
            score=0.0,
        )
        self._last_obs = obs
        return obs

    def step(self, action: NL2SQLAction) -> NL2SQLObservation:
        """Execute the agent's SQL and return graded observation."""
        if self._task is None or self._example is None:
            # Called before reset — auto-reset
            self.reset()

        self._state.step_count += 1
        current_step = self._state.step_count
        done = False

        # Execute the query
        rows, error = execute_query(self._conn, action.query)

        # Grade it
        result: GradeResult = grade(
            actual_rows=rows,
            ground_truth_rows=self._ground_truth,
            error=error,
            step=current_step,
            order_sensitive=self._order_sensitive,
        )

        reward = result.reward
        self._episode_rewards.append(reward)
        self._state.cumulative_reward += reward
        self._state.best_reward = max(self._state.best_reward, reward)

        # Episode ends on an exact match or when the step budget runs out.
        if result.exact_match:
            self._state.solved = True
            done = True
        elif current_step >= MAX_STEPS:
            done = True

        # Prepare result rows for observation (truncated for agent readability)
        display_rows = (rows or [])[:RESULT_LIMIT]
        result_columns = list(display_rows[0].keys()) if display_rows else []
        # Convert sqlite3.Row objects if needed
        display_rows = [dict(r) for r in display_rows]

        # Normalised cumulative score: mean per-step reward, clamped to [0, 1]
        n = len(self._episode_rewards)
        score = self._state.cumulative_reward / max(n, 1) if n else 0.0
        score = round(min(max(score, 0.0), 1.0), 4)

        obs = NL2SQLObservation(
            question=self._example.question,
            schema_context=self._task.schema_context(),
            task_name=self._task.name,
            last_query=action.query,
            last_result=display_rows,
            last_error=error,
            result_columns=result_columns,
            step=current_step,
            max_steps=MAX_STEPS,
            done=done,
            reward=reward,
            score=score,
        )
        self._last_obs = obs
        return obs

    @property
    def state(self) -> NL2SQLState:
        return self._state

    # ── Helpers ────────────────────────────────────────────────────────────

    def available_tasks(self) -> list:
        return all_task_names()
|
server/grader.py
ADDED
|
@@ -0,0 +1,214 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/grader.py
|
| 3 |
+
==============================
|
| 4 |
+
Deterministic, programmatic reward grader.
|
| 5 |
+
|
| 6 |
+
No LLM-as-judge. Every reward is computed by comparing the agent's SQL
|
| 7 |
+
execution results against a ground-truth result set.
|
| 8 |
+
|
| 9 |
+
Reward decomposition (sums to 1.0 for a perfect first-attempt answer):
|
| 10 |
+
+0.10 syntax_ok — query runs without SQLite error
|
| 11 |
+
+0.20 columns_match — returned column names match ground truth exactly
|
| 12 |
+
+0.20 row_count_match — number of returned rows matches
|
| 13 |
+
+0.50 exact_match — full result set equals ground truth (order-aware
|
| 14 |
+
for ORDER BY queries, order-agnostic otherwise)
|
| 15 |
+
|
| 16 |
+
Step penalty:
|
| 17 |
+
-0.05 per step beyond the first (encourages solving in fewer steps),
|
| 18 |
+
clamped so the minimum is always 0.0.
|
| 19 |
+
|
| 20 |
+
All rewards are floats in [0.0, 1.0].
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations

import re
import sqlite3
from typing import Any, Dict, List, Optional, Tuple
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# ── Result normalisation ───────────────────────────────────────────────────
|
| 30 |
+
|
| 31 |
+
def _normalise_value(v: Any) -> Any:
|
| 32 |
+
"""Round floats for comparison so 1.2300000001 == 1.23."""
|
| 33 |
+
if isinstance(v, float):
|
| 34 |
+
return round(v, 4)
|
| 35 |
+
if isinstance(v, str):
|
| 36 |
+
return v.strip()
|
| 37 |
+
return v
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _normalise_row(row: Dict[str, Any]) -> Dict[str, Any]:
|
| 41 |
+
return {k: _normalise_value(v) for k, v in row.items()}
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _normalise_rows(rows: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
| 45 |
+
return [_normalise_row(r) for r in rows]
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# ── SQL execution ──────────────────────────────────────────────────────────
|
| 49 |
+
|
| 50 |
+
def execute_query(
    conn: sqlite3.Connection,
    query: str,
    max_rows: int = 200,
) -> Tuple[Optional[List[Dict[str, Any]]], Optional[str]]:
    """
    Execute a SQL query safely.

    Parameters
    ----------
    conn     : Open SQLite connection.
    query    : SQL text supplied by the agent.
    max_rows : Cap on the number of rows fetched (default 200).

    Returns (rows, error_string).
    rows is None on error.
    """
    query = query.strip().rstrip(";")
    if not query:
        return None, "Empty query."

    # Block write/DDL operations — the environment is read-only from the
    # agent's view.  Checking only the first keyword is not enough: a write
    # statement can be prefixed with a CTE ("WITH x AS (...) DELETE ..."),
    # so we conservatively reject any query containing a forbidden keyword
    # as a standalone word anywhere in the text.
    forbidden = ("insert", "update", "delete", "drop", "alter",
                 "create", "replace", "truncate", "pragma",
                 "attach", "detach", "vacuum", "reindex")
    hit = re.search(
        r"\b(" + "|".join(forbidden) + r")\b", query, flags=re.IGNORECASE
    )
    if hit:
        return None, (
            f"Operation '{hit.group(1).upper()}' is not allowed. "
            "Only SELECT queries are permitted."
        )

    try:
        cur = conn.execute(query)
        cols = [d[0] for d in cur.description] if cur.description else []
        rows = [dict(zip(cols, row)) for row in cur.fetchmany(max_rows)]
        return rows, None
    except sqlite3.Error as exc:
        return None, str(exc)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# ── Grading logic ──────────────────────────────────────────────────────────
|
| 85 |
+
|
| 86 |
+
class GradeResult:
    """Outcome of grading a single query attempt.

    Carries the final scalar ``reward``, the boolean criteria that produced
    it, the applied ``step_penalty``, and a per-criterion ``breakdown`` dict
    mirroring the weights documented at the top of this module.
    """

    __slots__ = (
        "reward", "syntax_ok", "columns_match",
        "row_count_match", "exact_match", "step_penalty",
        "breakdown",
    )

    def __init__(
        self,
        reward: float,
        syntax_ok: bool,
        columns_match: bool,
        row_count_match: bool,
        exact_match: bool,
        step_penalty: float,
    ) -> None:
        for attr, value in (
            ("reward", reward),
            ("syntax_ok", syntax_ok),
            ("columns_match", columns_match),
            ("row_count_match", row_count_match),
            ("exact_match", exact_match),
            ("step_penalty", step_penalty),
        ):
            setattr(self, attr, value)
        # A criterion contributes its weight only when the query at least
        # ran (syntax_ok); the step penalty is recorded as a negative.
        self.breakdown = {
            "syntax_ok": 0.10 if syntax_ok else 0.0,
            "columns_match": 0.20 if (syntax_ok and columns_match) else 0.0,
            "row_count_match": 0.20 if (syntax_ok and row_count_match) else 0.0,
            "exact_match": 0.50 if (syntax_ok and exact_match) else 0.0,
            "step_penalty": -step_penalty,
        }

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"GradeResult(reward={self.reward:.3f}, "
            f"exact={self.exact_match}, cols={self.columns_match}, "
            f"rows={self.row_count_match}, syntax={self.syntax_ok})"
        )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def grade(
    actual_rows: Optional[List[Dict[str, Any]]],
    ground_truth_rows: List[Dict[str, Any]],
    error: Optional[str],
    step: int,
    order_sensitive: bool = False,
) -> GradeResult:
    """
    Grade the agent's query result against ground truth.

    Parameters
    ----------
    actual_rows       : Rows returned by the agent's query (None on error).
    ground_truth_rows : Expected rows (pre-computed at task load time).
    error             : SQLite error string (None if query ran successfully).
    step              : Current step number (1-indexed) for penalty calculation.
    order_sensitive   : If True, row order matters (queries with ORDER BY).
    """
    # ── Syntax ──────────────────────────────────────────────────────────
    # A failed query (or one that returned nothing at all) earns zero.
    if error is not None or actual_rows is None:
        return GradeResult(
            reward=0.0,
            syntax_ok=False,
            columns_match=False,
            row_count_match=False,
            exact_match=False,
            step_penalty=0.0,
        )

    expected = _normalise_rows(ground_truth_rows)
    observed = _normalise_rows(actual_rows)

    # Column names are compared as sets taken from the first row of each side.
    expected_cols = set(expected[0]) if expected else set()
    observed_cols = set(observed[0]) if observed else set()
    columns_match = observed_cols == expected_cols
    row_count_match = len(observed) == len(expected)

    # Exact match: order-aware for ORDER BY queries, order-agnostic otherwise.
    exact_match = False
    if columns_match and row_count_match:
        if order_sensitive:
            exact_match = observed == expected
        else:
            # Sort rows by a canonical string key for order-agnostic compare.
            def _canonical(row: Dict) -> str:
                return str(sorted(row.items()))

            exact_match = (
                sorted(observed, key=_canonical)
                == sorted(expected, key=_canonical)
            )

    # ── Score assembly ────────────────────────────────────────────────
    raw = 0.10  # syntax weight — always earned at this point
    if columns_match:
        raw += 0.20
    if row_count_match:
        raw += 0.20
    if exact_match:
        raw += 0.50

    penalty = max(0.0, step - 1) * 0.05
    reward = float(max(0.0, min(1.0, raw - penalty)))

    return GradeResult(
        reward=reward,
        syntax_ok=True,
        columns_match=columns_match,
        row_count_match=row_count_match,
        exact_match=exact_match,
        step_penalty=penalty,
    )
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ── Convenience: pre-compute ground truth rows ─────────────────────────────
|
| 199 |
+
|
| 200 |
+
def compute_ground_truth(
    conn: sqlite3.Connection,
    sql: str,
) -> List[Dict[str, Any]]:
    """Execute the ground-truth SQL and return normalised rows.

    Raises ValueError if the reference SQL fails — that indicates a broken
    task definition rather than a bad agent answer.
    """
    rows, error = execute_query(conn, sql)
    if rows is None or error:
        raise ValueError(f"Ground-truth SQL failed: {error}\nSQL: {sql}")
    return _normalise_rows(rows)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def has_order_by(sql: str) -> bool:
    """Heuristic: does the query contain an ORDER BY clause?

    Uses a case-insensitive word-boundary regex so the two keywords may be
    separated by any run of whitespace (spaces, tabs, newlines) — the plain
    substring check "ORDER BY" missed multi-line SQL and false-positived on
    identifiers that merely end in "ORDER".
    """
    return re.search(r"\bORDER\s+BY\b", sql, flags=re.IGNORECASE) is not None
|
server/requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nl2sql-bench/server/requirements.txt
|
| 2 |
+
# Minimal dependency set for 2 vCPU / 8 GB constraint.
|
| 3 |
+
# SQLite is part of the Python stdlib — no extra DB dependency needed.
|
| 4 |
+
|
| 5 |
+
# OpenEnv core framework
|
| 6 |
+
openenv-core>=0.2.3
|
| 7 |
+
|
| 8 |
+
# Web server
|
| 9 |
+
fastapi>=0.110.0
|
| 10 |
+
uvicorn[standard]>=0.29.0
|
| 11 |
+
|
| 12 |
+
# Typing helpers (already included in openenv-core but listed explicitly)
|
| 13 |
+
pydantic>=2.0.0
|
server/tasks/__init__.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auto-registers all tasks by importing them."""
|
| 2 |
+
from .easy import SimpleFilterTask
|
| 3 |
+
from .medium import JoinAggregationTask
|
| 4 |
+
from .hard import AnalyticsWindowTask
|
| 5 |
+
from .base import get_task, all_task_names, BaseTask
|
| 6 |
+
|
| 7 |
+
__all__ = [
|
| 8 |
+
"SimpleFilterTask",
|
| 9 |
+
"JoinAggregationTask",
|
| 10 |
+
"AnalyticsWindowTask",
|
| 11 |
+
"get_task",
|
| 12 |
+
"all_task_names",
|
| 13 |
+
"BaseTask",
|
| 14 |
+
]
|
server/tasks/base.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/tasks/base.py
|
| 3 |
+
==================================
|
| 4 |
+
Abstract base for all NL2SQL tasks and the global task registry.
|
| 5 |
+
|
| 6 |
+
Each task holds a list of (question, ground_truth_sql) pairs.
|
| 7 |
+
The environment picks one pair per episode via a deterministic round-robin
|
| 8 |
+
so that the same task always cycles through the same question sequence —
|
| 9 |
+
this keeps grader results reproducible across runs.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import sqlite3
|
| 15 |
+
from abc import ABC, abstractmethod
|
| 16 |
+
from typing import Dict, List, NamedTuple, Tuple, Type
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class TaskExample(NamedTuple):
    """A single (natural-language question, ground-truth SQL) pair.

    The environment executes ``sql`` against the seeded database at episode
    start to obtain the reference result set the grader compares against.
    """

    question: str  # natural-language prompt shown to the agent
    sql: str  # reference SQL used to compute ground truth
    # Human-readable description of what makes this question that difficulty
    notes: str = ""
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class BaseTask(ABC):
    """Abstract base class for all tasks.

    Subclasses define ``name``, ``difficulty`` and a non-empty ``examples``
    list; episodes consume examples in deterministic round-robin order.
    """

    name: str = ""
    difficulty: str = ""  # easy | medium | hard
    examples: List[TaskExample] = []

    def __init__(self) -> None:
        # Fail fast on a mis-defined subclass instead of at episode time.
        if not self.examples:
            raise ValueError(f"Task {self.name!r} has no examples defined.")
        self._cursor = 0  # round-robin index

    def next_example(self) -> TaskExample:
        """Return the next question in round-robin order."""
        index = self._cursor % len(self.examples)
        self._cursor = self._cursor + 1
        return self.examples[index]

    @classmethod
    def schema_context(cls) -> str:
        """Return a compact schema description for the agent system prompt."""
        return _SCHEMA_CONTEXT

    @abstractmethod
    def description(self) -> str:
        """One-sentence description for openenv.yaml."""
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ── Global schema context string (injected into every observation) ─────────
|
| 55 |
+
|
| 56 |
+
_SCHEMA_CONTEXT = """\
|
| 57 |
+
Database: ecommerce (SQLite, read-only)
|
| 58 |
+
|
| 59 |
+
TABLES
|
| 60 |
+
------
|
| 61 |
+
categories(id INTEGER PK, name TEXT)
|
| 62 |
+
|
| 63 |
+
products(id INTEGER PK, name TEXT, category_id INTEGER FK→categories.id,
|
| 64 |
+
price REAL, stock_quantity INTEGER)
|
| 65 |
+
|
| 66 |
+
customers(id INTEGER PK, name TEXT, email TEXT, country TEXT,
|
| 67 |
+
tier TEXT ∈ {bronze|silver|gold}, created_at TEXT ISO-8601)
|
| 68 |
+
|
| 69 |
+
orders(id INTEGER PK, customer_id INTEGER FK→customers.id,
|
| 70 |
+
status TEXT ∈ {pending|processing|shipped|delivered|cancelled},
|
| 71 |
+
created_at TEXT ISO-8601, total_amount REAL)
|
| 72 |
+
|
| 73 |
+
order_items(id INTEGER PK, order_id INTEGER FK→orders.id,
|
| 74 |
+
product_id INTEGER FK→products.id,
|
| 75 |
+
quantity INTEGER, unit_price REAL)
|
| 76 |
+
|
| 77 |
+
reviews(id INTEGER PK, product_id INTEGER FK→products.id,
|
| 78 |
+
customer_id INTEGER FK→customers.id,
|
| 79 |
+
rating INTEGER 1-5, created_at TEXT ISO-8601)
|
| 80 |
+
|
| 81 |
+
NOTES
|
| 82 |
+
-----
|
| 83 |
+
- Date comparisons: use created_at >= '2024-01-01' (text ISO sort works)
|
| 84 |
+
- SQLite window functions (RANK, DENSE_RANK, ROW_NUMBER, LAG, LEAD) are available
|
| 85 |
+
- strftime('%Y-%m', created_at) returns 'YYYY-MM' month strings
|
| 86 |
+
- All monetary values are in USD
|
| 87 |
+
"""
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
# ── Task registry ──────────────────────────────────────────────────────────

_REGISTRY: Dict[str, Type[BaseTask]] = {}


def register(cls: Type[BaseTask]) -> Type[BaseTask]:
    """Class decorator: register a task class under its ``name`` attribute."""
    _REGISTRY[cls.name] = cls
    return cls


def get_task(name: str) -> BaseTask:
    """Instantiate and return the task registered under *name*.

    Raises KeyError (listing the known tasks) for an unknown name.
    """
    if name not in _REGISTRY:
        available = list(_REGISTRY)
        raise KeyError(f"Unknown task {name!r}. Available: {available}")
    return _REGISTRY[name]()


def all_task_names() -> List[str]:
    """Return the names of all registered tasks, in registration order."""
    return list(_REGISTRY)
|
server/tasks/easy.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/tasks/easy.py
|
| 3 |
+
===================================
|
| 4 |
+
Task 1 — Simple Filter (difficulty: easy)
|
| 5 |
+
|
| 6 |
+
All questions target a SINGLE table with basic WHERE / ORDER BY / LIMIT.
|
| 7 |
+
A competent small model should solve these in 1–2 steps.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from .base import BaseTask, TaskExample, register
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register
class SimpleFilterTask(BaseTask):
    """Easy task: single-table SELECTs with WHERE, ORDER BY, and LIMIT.

    Every example targets exactly one table; a competent small model should
    solve these in 1-2 steps.
    """

    name = "simple-filter"
    difficulty = "easy"

    examples = [
        TaskExample(
            question=(
                "List all gold-tier customers ordered by their name alphabetically. "
                "Return columns: id, name, email, country."
            ),
            sql=(
                "SELECT id, name, email, country "
                "FROM customers "
                "WHERE tier = 'gold' "
                "ORDER BY name ASC"
            ),
            notes="Single table, equality filter, text sort.",
        ),
        TaskExample(
            question=(
                "Show all products with a price above $100, sorted by price from "
                "highest to lowest. Return columns: id, name, price."
            ),
            sql=(
                "SELECT id, name, price "
                "FROM products "
                "WHERE price > 100 "
                "ORDER BY price DESC"
            ),
            notes="Numeric range filter, descending sort.",
        ),
        TaskExample(
            question=(
                "Find all delivered orders with a total_amount greater than $200, "
                "ordered by total_amount descending. "
                "Return columns: id, customer_id, total_amount, created_at."
            ),
            sql=(
                "SELECT id, customer_id, total_amount, created_at "
                "FROM orders "
                "WHERE status = 'delivered' "
                "  AND total_amount > 200 "
                "ORDER BY total_amount DESC"
            ),
            notes="Two-condition WHERE on a single table.",
        ),
        TaskExample(
            question=(
                "Return the top 5 most expensive products. "
                "Return columns: id, name, price."
            ),
            sql=(
                "SELECT id, name, price "
                "FROM products "
                "ORDER BY price DESC "
                "LIMIT 5"
            ),
            notes="ORDER BY + LIMIT, no WHERE clause.",
        ),
        TaskExample(
            question=(
                "List all distinct countries where our customers come from, "
                "sorted alphabetically. Return a single column: country."
            ),
            sql=(
                "SELECT DISTINCT country "
                "FROM customers "
                "ORDER BY country ASC"
            ),
            notes="DISTINCT on a single column.",
        ),
    ]

    def description(self) -> str:
        return (
            "Single-table SELECT queries with WHERE filters, ORDER BY, and LIMIT. "
            "Tests basic SQL fluency."
        )
|
server/tasks/hard.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/tasks/hard.py
|
| 3 |
+
===================================
|
| 4 |
+
Task 3 — Analytics & Window (difficulty: hard)
|
| 5 |
+
|
| 6 |
+
Questions require CTEs, window functions (RANK, ROW_NUMBER, running totals),
|
| 7 |
+
or non-trivial subqueries. Even strong frontier models often need 3–5 steps.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from .base import BaseTask, TaskExample, register
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register
class AnalyticsWindowTask(BaseTask):
    """Hard task: CTEs, window functions, and nested subqueries.

    Even strong frontier models often need 3-5 steps on these questions.
    """

    name = "analytics-window"
    difficulty = "hard"

    examples = [
        TaskExample(
            question=(
                "Rank customers by their total spending on delivered orders "
                "using DENSE_RANK (rank 1 = highest spender). "
                "Return columns: customer_name, total_spent, spending_rank. "
                "Round total_spent to 2 decimal places. "
                "Sort by spending_rank ascending."
            ),
            sql=(
                "SELECT customer_name, total_spent, spending_rank "
                "FROM ( "
                "  SELECT c.name AS customer_name, "
                "         ROUND(SUM(o.total_amount), 2) AS total_spent, "
                "         DENSE_RANK() OVER (ORDER BY SUM(o.total_amount) DESC) AS spending_rank "
                "  FROM customers c "
                "  JOIN orders o ON o.customer_id = c.id "
                "  WHERE o.status = 'delivered' "
                "  GROUP BY c.id, c.name "
                ") sub "
                "ORDER BY spending_rank ASC"
            ),
            notes="Window function DENSE_RANK inside a subquery wrapping a GROUP BY.",
        ),
        TaskExample(
            question=(
                "For each product that has been reviewed, show its name, its own "
                "average rating, and the average rating of all products in its category. "
                "Return columns: product_name, product_avg_rating, category_avg_rating. "
                "Round both averages to 2 decimal places. "
                "Sort by product_avg_rating descending."
            ),
            sql=(
                "SELECT p.name AS product_name, "
                "       ROUND(AVG(r.rating), 2) AS product_avg_rating, "
                "       ROUND(AVG(AVG(r.rating)) OVER (PARTITION BY p.category_id), 2) "
                "         AS category_avg_rating "
                "FROM products p "
                "JOIN reviews r ON r.product_id = p.id "
                "GROUP BY p.id, p.name, p.category_id "
                "ORDER BY product_avg_rating DESC"
            ),
            notes="AVG of AVG via window PARTITION BY — requires nested aggregate understanding.",
        ),
        TaskExample(
            question=(
                "Find all customers whose most recent order has status 'cancelled'. "
                "Use a CTE with ROW_NUMBER to identify the latest order per customer. "
                "Return columns: customer_name, last_order_status, last_order_date. "
                "Sort by customer_name ascending."
            ),
            sql=(
                "WITH ranked_orders AS ( "
                "  SELECT customer_id, status, created_at, "
                "         ROW_NUMBER() OVER (PARTITION BY customer_id "
                "                            ORDER BY created_at DESC) AS rn "
                "  FROM orders "
                ") "
                "SELECT c.name AS customer_name, "
                "       ro.status AS last_order_status, "
                "       ro.created_at AS last_order_date "
                "FROM customers c "
                "JOIN ranked_orders ro ON ro.customer_id = c.id "
                "WHERE ro.rn = 1 "
                "  AND ro.status = 'cancelled' "
                "ORDER BY customer_name ASC"
            ),
            notes="CTE + ROW_NUMBER window partitioned by customer_id.",
        ),
        TaskExample(
            question=(
                "Show the monthly revenue from delivered orders and its running total, "
                "for all months in 2024. "
                "Return columns: month (format YYYY-MM), monthly_revenue, running_total. "
                "Round both revenue columns to 2 decimal places. "
                "Sort by month ascending."
            ),
            sql=(
                "WITH monthly AS ( "
                "  SELECT strftime('%Y-%m', created_at) AS month, "
                "         ROUND(SUM(total_amount), 2) AS monthly_revenue "
                "  FROM orders "
                "  WHERE status = 'delivered' "
                "    AND created_at >= '2024-01-01' "
                "    AND created_at < '2025-01-01' "
                "  GROUP BY strftime('%Y-%m', created_at) "
                ") "
                "SELECT month, "
                "       monthly_revenue, "
                "       ROUND(SUM(monthly_revenue) OVER (ORDER BY month), 2) AS running_total "
                "FROM monthly "
                "ORDER BY month ASC"
            ),
            notes="CTE + cumulative SUM window ordered by month string.",
        ),
        TaskExample(
            question=(
                "Find products whose average rating is strictly above the average "
                "rating of all products in their category. "
                "Return columns: product_name, category_name, "
                "product_avg_rating, category_avg_rating. "
                "Round both averages to 2 decimal places. "
                "Sort by product_avg_rating descending, then product_name ascending."
            ),
            sql=(
                "WITH product_ratings AS ( "
                "  SELECT p.id AS product_id, p.name AS product_name, "
                "         p.category_id, c.name AS category_name, "
                "         ROUND(AVG(r.rating), 2) AS product_avg_rating "
                "  FROM products p "
                "  JOIN reviews r ON r.product_id = p.id "
                "  JOIN categories c ON c.id = p.category_id "
                "  GROUP BY p.id, p.name, p.category_id, c.name "
                "), "
                "category_ratings AS ( "
                "  SELECT category_id, "
                "         ROUND(AVG(product_avg_rating), 2) AS category_avg_rating "
                "  FROM product_ratings "
                "  GROUP BY category_id "
                ") "
                "SELECT pr.product_name, pr.category_name, "
                "       pr.product_avg_rating, cr.category_avg_rating "
                "FROM product_ratings pr "
                "JOIN category_ratings cr ON cr.category_id = pr.category_id "
                "WHERE pr.product_avg_rating > cr.category_avg_rating "
                "ORDER BY pr.product_avg_rating DESC, pr.product_name ASC"
            ),
            notes="Two CTEs, correlated comparison between product and category averages.",
        ),
    ]

    def description(self) -> str:
        return (
            "Advanced analytics queries using CTEs, window functions "
            "(DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. "
            "Tests multi-step reasoning and SQLite-specific syntax."
        )
|
server/tasks/medium.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/server/tasks/medium.py
|
| 3 |
+
=====================================
|
| 4 |
+
Task 2 — Join & Aggregation (difficulty: medium)
|
| 5 |
+
|
| 6 |
+
Questions require at least one JOIN and GROUP BY / HAVING.
|
| 7 |
+
Expect most frontier models to succeed in 2–3 steps.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
from __future__ import annotations
|
| 11 |
+
|
| 12 |
+
from .base import BaseTask, TaskExample, register
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@register
class JoinAggregationTask(BaseTask):
    """Medium-difficulty NL→SQL task: JOINs plus GROUP BY / HAVING.

    Each TaskExample pairs a natural-language question with the
    ground-truth SQLite query used for grading; `notes` records which
    SQL features the example exercises.  The question text spells out
    the exact output columns, rounding, and sort order so grading can
    compare result sets deterministically.
    """

    name = "join-aggregation"
    difficulty = "medium"

    # NOTE(review): the SQL strings are the grading ground truth — any
    # whitespace change inside them is harmless to SQLite, but changing
    # column aliases or ORDER BY would change the graded result set.
    examples = [
        TaskExample(
            question=(
                "How many orders has each customer placed? "
                "Return columns: customer_name, order_count. "
                "Include customers with zero orders. "
                "Sort by order_count descending, then customer_name ascending."
            ),
            sql=(
                # LEFT JOIN + COUNT(o.id) so customers with no orders count as 0.
                "SELECT c.name AS customer_name, COUNT(o.id) AS order_count "
                "FROM customers c "
                "LEFT JOIN orders o ON c.id = o.customer_id "
                "GROUP BY c.id, c.name "
                "ORDER BY order_count DESC, customer_name ASC"
            ),
            notes="LEFT JOIN to include zero-order customers, COUNT aggregate.",
        ),
        TaskExample(
            question=(
                "What is the average product rating per category? "
                "Only include categories that have at least one review. "
                "Return columns: category_name, avg_rating. "
                "Round avg_rating to 2 decimal places. "
                "Sort by avg_rating descending."
            ),
            sql=(
                # Inner JOINs implicitly drop categories without reviews.
                "SELECT c.name AS category_name, "
                "       ROUND(AVG(r.rating), 2) AS avg_rating "
                "FROM categories c "
                "JOIN products p ON p.category_id = c.id "
                "JOIN reviews r ON r.product_id = p.id "
                "GROUP BY c.id, c.name "
                "ORDER BY avg_rating DESC"
            ),
            notes="Two JOINs, AVG aggregate, ROUND function.",
        ),
        TaskExample(
            question=(
                "Which categories have more than 5 products in stock "
                "(i.e., stock_quantity > 0)? "
                "Return columns: category_name, in_stock_count. "
                "Sort by in_stock_count descending."
            ),
            sql=(
                "SELECT c.name AS category_name, "
                "       COUNT(p.id) AS in_stock_count "
                "FROM categories c "
                "JOIN products p ON p.category_id = c.id "
                "WHERE p.stock_quantity > 0 "
                "GROUP BY c.id, c.name "
                "HAVING COUNT(p.id) > 5 "
                "ORDER BY in_stock_count DESC"
            ),
            notes="WHERE before GROUP BY, HAVING filter on aggregate.",
        ),
        TaskExample(
            question=(
                "Which customers have spent more than $500 total on delivered orders? "
                "Return columns: customer_name, total_spent. "
                "Round total_spent to 2 decimal places. "
                "Sort by total_spent descending."
            ),
            sql=(
                "SELECT c.name AS customer_name, "
                "       ROUND(SUM(o.total_amount), 2) AS total_spent "
                "FROM customers c "
                "JOIN orders o ON o.customer_id = c.id "
                "WHERE o.status = 'delivered' "
                "GROUP BY c.id, c.name "
                "HAVING SUM(o.total_amount) > 500 "
                "ORDER BY total_spent DESC"
            ),
            notes="SUM aggregate, HAVING on SUM, status filter.",
        ),
        TaskExample(
            question=(
                "Show the total quantity sold for each product. "
                "Only include products that appear in at least one order item. "
                "Return columns: product_name, total_quantity_sold. "
                "Sort by total_quantity_sold descending."
            ),
            sql=(
                # Inner JOIN on order_items keeps only products that were sold.
                "SELECT p.name AS product_name, "
                "       SUM(oi.quantity) AS total_quantity_sold "
                "FROM products p "
                "JOIN order_items oi ON oi.product_id = p.id "
                "GROUP BY p.id, p.name "
                "ORDER BY total_quantity_sold DESC"
            ),
            notes="JOIN on order_items, SUM aggregate.",
        ),
    ]

    def description(self) -> str:
        """Human-readable summary used to describe this task suite."""
        return (
            "Multi-table JOIN queries with GROUP BY, HAVING, and aggregation "
            "functions (COUNT, SUM, AVG, ROUND). Tests relational reasoning."
        )
|
tests/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# tests/__init__.py
|
tests/conftest.py
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# nl2sql-bench/tests/conftest.py
"""
Pytest configuration — adds project root and server/ to sys.path
so all test imports resolve without installing the package.
"""
import sys
from pathlib import Path

# Project layout:  <ROOT>/tests/conftest.py  and  <ROOT>/server/
ROOT = Path(__file__).parent.parent
SERVER = ROOT / "server"

# Prepend each path once; ROOT first, then SERVER, so SERVER ends up
# ahead of ROOT in sys.path (same order the original insertions produce).
for candidate in (str(ROOT), str(SERVER)):
    if candidate in sys.path:
        continue
    sys.path.insert(0, candidate)
|
tests/test_all.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
nl2sql-bench/tests/test_all.py
|
| 3 |
+
================================
|
| 4 |
+
Comprehensive test suite covering:
|
| 5 |
+
- Database seeder (determinism + row counts)
|
| 6 |
+
- Grader (all reward components, step penalty, edge cases)
|
| 7 |
+
- Task registry (all 3 tasks load and produce valid examples)
|
| 8 |
+
- Environment (reset, step, episode boundary, done logic)
|
| 9 |
+
- Inference log format (regex checks on START / STEP / END)
|
| 10 |
+
|
| 11 |
+
Run with:
|
| 12 |
+
pytest tests/ -v
|
| 13 |
+
or from project root:
|
| 14 |
+
PYTHONPATH=.:server pytest tests/ -v
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import re
|
| 20 |
+
import sqlite3
|
| 21 |
+
import sys
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import pytest
|
| 26 |
+
|
| 27 |
+
# ── Path setup so tests can import from both project root and server/ ──────
|
| 28 |
+
ROOT = Path(__file__).parent.parent
|
| 29 |
+
SERVER = ROOT / "server"
|
| 30 |
+
sys.path.insert(0, str(ROOT))
|
| 31 |
+
sys.path.insert(0, str(SERVER))
|
| 32 |
+
|
| 33 |
+
# ── Fixtures ───────────────────────────────────────────────────────────────
|
| 34 |
+
|
| 35 |
+
@pytest.fixture(scope="session")
def db_conn():
    """Shared in-memory SQLite connection with full schema + seed data.

    Session-scoped: the database is built once and reused by every test,
    so tests must treat it as read-only (writes are separately blocked by
    the grader's execute_query).
    """
    from db.seed import seed_database
    schema = (SERVER / "db" / "schema.sql").read_text()
    # check_same_thread=False: the session-scoped connection may be touched
    # from whatever thread pytest runs each test on.
    conn = sqlite3.connect(":memory:", check_same_thread=False)
    # sqlite3.Row lets tests access columns by name as well as index.
    conn.row_factory = sqlite3.Row
    conn.executescript(schema)
    seed_database(conn)
    yield conn
    conn.close()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@pytest.fixture
def fresh_env():
    """A fresh NL2SQLEnvironment instance per test.

    Function-scoped so episode state (step count, solved flag) never
    leaks between tests.
    """
    # Imported lazily — presumably so the sys.path setup above runs first
    # (NOTE(review): confirm `environment` is only importable via server/).
    from environment import NL2SQLEnvironment
    return NL2SQLEnvironment()
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 56 |
+
# 1. DATABASE SEEDER
|
| 57 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 58 |
+
|
| 59 |
+
class TestSeeder:
    """Row counts, value constraints, determinism, and referential
    integrity of the seeded SQLite database."""

    def test_categories_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM categories").fetchone()
        assert row[0] == 8, "Should have exactly 8 categories"

    def test_products_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM products").fetchone()
        assert row[0] == 64, "Should have 8 products × 8 categories = 64"

    def test_customers_count(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM customers").fetchone()
        assert row[0] == 150

    def test_orders_exist(self, db_conn):
        # Orders are randomized per customer, so only a lower bound is checked.
        row = db_conn.execute("SELECT COUNT(*) FROM orders").fetchone()
        assert row[0] > 100, "Should have a meaningful number of orders"

    def test_order_items_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM order_items").fetchone()
        assert row[0] > 200

    def test_reviews_exist(self, db_conn):
        row = db_conn.execute("SELECT COUNT(*) FROM reviews").fetchone()
        assert row[0] > 50

    def test_determinism(self, db_conn):
        """Seeding a second connection with the same seed gives identical counts."""
        from db.seed import seed_database
        schema = (SERVER / "db" / "schema.sql").read_text()
        conn2 = sqlite3.connect(":memory:")
        conn2.executescript(schema)
        seed_database(conn2)

        # Compare per-table row counts between the two independently seeded DBs.
        for tbl in ["categories", "products", "customers", "orders",
                    "order_items", "reviews"]:
            c1 = db_conn.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
            c2 = conn2.execute(f"SELECT COUNT(*) FROM {tbl}").fetchone()[0]
            assert c1 == c2, f"Table {tbl} count mismatch: {c1} vs {c2}"
        conn2.close()

    def test_tiers_valid(self, db_conn):
        # Every customer tier must come from the closed set of three values.
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM customers WHERE tier NOT IN ('bronze','silver','gold')"
        ).fetchone()[0]
        assert bad == 0

    def test_statuses_valid(self, db_conn):
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM orders "
            "WHERE status NOT IN ('pending','processing','shipped','delivered','cancelled')"
        ).fetchone()[0]
        assert bad == 0

    def test_ratings_valid(self, db_conn):
        # Ratings are constrained to the 1–5 star range.
        bad = db_conn.execute(
            "SELECT COUNT(*) FROM reviews WHERE rating < 1 OR rating > 5"
        ).fetchone()[0]
        assert bad == 0

    def test_referential_integrity(self, db_conn):
        """Order items should reference valid orders and products."""
        # LEFT JOIN + IS NULL finds order_items rows whose FK has no parent.
        orphan_orders = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN orders o ON o.id = oi.order_id WHERE o.id IS NULL"
        ).fetchone()[0]
        assert orphan_orders == 0

        orphan_products = db_conn.execute(
            "SELECT COUNT(*) FROM order_items oi "
            "LEFT JOIN products p ON p.id = oi.product_id WHERE p.id IS NULL"
        ).fetchone()[0]
        assert orphan_products == 0
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 135 |
+
# 2. GRADER
|
| 136 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 137 |
+
|
| 138 |
+
class TestGrader:
    """Unit tests for grader.grade() reward components and the
    read-only query executor."""

    def test_exact_match_first_step(self):
        # Perfect answer on step 1: every component true, no penalty.
        from grader import grade
        gt = [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.reward == pytest.approx(1.0)
        assert result.exact_match is True
        assert result.syntax_ok is True
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.step_penalty == 0.0

    def test_syntax_error_gives_zero(self):
        from grader import grade
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="near 'SELCT': syntax error",
            step=1,
        )
        assert result.reward == 0.0
        assert result.syntax_ok is False

    def test_step_penalty_applied(self):
        from grader import grade
        gt = [{"n": 1}]
        result = grade(
            actual_rows=gt.copy(),
            ground_truth_rows=gt,
            error=None,
            step=3,  # penalty = (3-1)*0.05 = 0.10
        )
        assert result.reward == pytest.approx(1.0 - 0.10)
        assert result.step_penalty == pytest.approx(0.10)

    def test_columns_wrong_zero_higher_components(self):
        # Wrong column names: everything above the syntax component is lost.
        from grader import grade
        gt = [{"name": "Alice", "score": 10}]
        actual = [{"user": "Alice", "points": 10}]  # wrong column names
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is False
        assert result.exact_match is False
        # Only syntax score: 0.10
        assert result.reward == pytest.approx(0.10)

    def test_correct_columns_wrong_rows(self):
        from grader import grade
        gt = [{"name": "Alice"}, {"name": "Bob"}]
        actual = [{"name": "Charlie"}, {"name": "Dave"}]
        result = grade(actual_rows=actual, ground_truth_rows=gt, error=None, step=1)
        assert result.columns_match is True
        assert result.row_count_match is True
        assert result.exact_match is False
        # syntax(0.10) + columns(0.20) + row_count(0.20) = 0.50
        assert result.reward == pytest.approx(0.50)

    def test_order_sensitive_wrong_order_is_not_exact(self):
        from grader import grade
        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # reversed
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=True,
        )
        assert result.exact_match is False

    def test_order_insensitive_accepts_different_row_order(self):
        from grader import grade
        gt = [{"id": 1}, {"id": 2}]
        actual = [{"id": 2}, {"id": 1}]  # different order but same content
        result = grade(
            actual_rows=actual,
            ground_truth_rows=gt,
            error=None,
            step=1,
            order_sensitive=False,
        )
        assert result.exact_match is True

    def test_penalty_never_makes_reward_negative(self):
        from grader import grade
        # Step 99 with syntax error → reward must be >= 0
        result = grade(
            actual_rows=None,
            ground_truth_rows=[{"x": 1}],
            error="some error",
            step=99,
        )
        assert result.reward >= 0.0

    def test_execute_query_blocks_writes(self, db_conn):
        # The executor must reject non-SELECT statements with (None, error).
        from grader import execute_query
        rows, err = execute_query(db_conn, "INSERT INTO categories(name) VALUES ('x')")
        assert rows is None
        assert "not allowed" in err.lower() or "INSERT" in err

    def test_execute_query_returns_rows(self, db_conn):
        from grader import execute_query
        rows, err = execute_query(db_conn, "SELECT id, name FROM categories ORDER BY id")
        assert err is None
        assert len(rows) == 8
        # Rows are mappings keyed by column name.
        assert "id" in rows[0]
        assert "name" in rows[0]

    def test_compute_ground_truth(self, db_conn):
        from grader import compute_ground_truth
        rows = compute_ground_truth(db_conn, "SELECT COUNT(*) AS n FROM customers")
        assert len(rows) == 1
        assert rows[0]["n"] == 150
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
# ════��═════════════════════════════════════════════════════════════════════════
|
| 260 |
+
# 3. TASK REGISTRY
|
| 261 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 262 |
+
|
| 263 |
+
class TestTasks:
    """The task registry: all three tasks load, their ground-truth SQL
    runs on the seeded DB, and examples are served round-robin."""

    def test_all_tasks_registered(self):
        from tasks import all_task_names
        names = all_task_names()
        assert "simple-filter" in names
        assert "join-aggregation" in names
        assert "analytics-window" in names

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_task_has_examples(self, task_name):
        from tasks import get_task
        task = get_task(task_name)
        assert len(task.examples) >= 3, f"{task_name} needs at least 3 examples"

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_task_sql_runs_on_real_db(self, task_name, db_conn):
        """Every ground-truth SQL must execute cleanly against the seeded DB."""
        from tasks import get_task
        from grader import execute_query
        task = get_task(task_name)
        for ex in task.examples:
            rows, error = execute_query(db_conn, ex.sql)
            assert error is None, (
                f"Task {task_name!r} SQL failed:\n{ex.sql}\nError: {error}"
            )
            assert rows is not None

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_task_roundrobin(self, task_name):
        # next_example() should cycle through all examples with period n.
        from tasks import get_task
        task = get_task(task_name)
        n = len(task.examples)
        seen = [task.next_example() for _ in range(n * 2)]
        # After n calls, second half should repeat first half
        assert seen[:n] == seen[n:]

    def test_schema_context_non_empty(self):
        # The schema prompt shown to models must mention the core tables.
        from tasks import get_task
        task = get_task("simple-filter")
        ctx = task.schema_context()
        assert "customers" in ctx
        assert "orders" in ctx
        assert "products" in ctx
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 316 |
+
# 4. ENVIRONMENT
|
| 317 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 318 |
+
|
| 319 |
+
class TestEnvironment:
    """Episode lifecycle of NL2SQLEnvironment: reset, step, reward,
    termination on exact match or step budget, and write blocking."""

    def test_reset_returns_observation(self, fresh_env):
        obs = fresh_env.reset(task_name="simple-filter")
        assert obs.question != ""
        assert obs.schema_context != ""
        assert obs.task_name == "simple-filter"
        assert obs.done is False
        assert obs.step == 0
        assert obs.reward is None

    def test_reset_state(self, fresh_env):
        fresh_env.reset(task_name="join-aggregation")
        state = fresh_env.state
        assert state.task_name == "join-aggregation"
        assert state.task_difficulty == "medium"
        assert state.step_count == 0
        assert state.solved is False

    def test_step_increments_step_count(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert fresh_env.state.step_count == 1

    def test_step_syntax_error_gives_nonzero_error(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(query="SELCT * FORM broken"))
        assert obs.last_error is not None
        assert obs.reward == 0.0

    def test_step_valid_query_returns_result(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="SELECT id, name FROM customers ORDER BY name LIMIT 5"
        ))
        assert obs.last_error is None
        # NOTE(review): `<= 5` suggests the observation may truncate results;
        # confirm against environment.py.
        assert len(obs.last_result) <= 5
        assert obs.reward >= 0.0

    def test_exact_match_ends_episode(self, fresh_env):
        """Submitting the exact ground-truth SQL should solve the episode."""
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        # Get the ground truth SQL from the internal example
        gt_sql = fresh_env._example.sql
        obs = fresh_env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert fresh_env.state.solved is True
        assert obs.reward == pytest.approx(1.0)  # step 1, full score

    def test_max_steps_ends_episode(self, fresh_env):
        """Exhausting all steps should end the episode even without solving."""
        from models import NL2SQLAction
        from environment import MAX_STEPS
        fresh_env.reset(task_name="analytics-window")
        obs = None
        for _ in range(MAX_STEPS):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1"))
        assert obs is not None
        assert obs.done is True

    def test_reset_clears_previous_episode(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        fresh_env.step(NL2SQLAction(query="SELECT 1"))
        # Second reset should clear state
        obs = fresh_env.reset(task_name="join-aggregation")
        assert fresh_env.state.step_count == 0
        assert obs.step == 0
        assert obs.task_name == "join-aggregation"

    @pytest.mark.parametrize("task_name", [
        "simple-filter", "join-aggregation", "analytics-window"
    ])
    def test_all_tasks_solvable(self, task_name):
        """Ground-truth SQL should always produce reward == 1.0 on step 1."""
        from environment import NL2SQLEnvironment
        from models import NL2SQLAction
        env = NL2SQLEnvironment()
        env.reset(task_name=task_name)
        gt_sql = env._example.sql
        obs = env.step(NL2SQLAction(query=gt_sql))
        assert obs.done is True
        assert obs.reward == pytest.approx(1.0), (
            f"Task {task_name!r}: ground-truth SQL did not score 1.0.\n"
            f"SQL: {gt_sql}\nError: {obs.last_error}\nReward: {obs.reward}"
        )

    def test_score_normalised_to_0_1(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        # Score must stay in [0, 1] after every step.
        for _ in range(3):
            obs = fresh_env.step(NL2SQLAction(query="SELECT 1 AS x"))
            assert 0.0 <= obs.score <= 1.0

    def test_write_query_blocked(self, fresh_env):
        from models import NL2SQLAction
        fresh_env.reset(task_name="simple-filter")
        obs = fresh_env.step(NL2SQLAction(
            query="INSERT INTO categories(name) VALUES ('hack')"
        ))
        assert obs.last_error is not None
        assert "not allowed" in obs.last_error.lower() or "INSERT" in obs.last_error
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 428 |
+
# 5. LOG FORMAT COMPLIANCE
|
| 429 |
+
# ══════════════════════════════════════════════════════════════════════════════
|
| 430 |
+
|
| 431 |
+
class TestLogFormat:
    """Validate that the inference.py log helpers emit correct format."""

    # Anchored patterns for the three structured log lines. Rewards use
    # 2 decimal places, scores use 3 (checked separately below).
    START_RE = re.compile(
        r"^\[START\] task=\S+ env=\S+ model=\S+$"
    )
    STEP_RE = re.compile(
        r"^\[STEP\] step=\d+ action=.+ reward=\d+\.\d{2} "
        r"done=(true|false) error=.+$"
    )
    END_RE = re.compile(
        r"^\[END\] success=(true|false) steps=\d+ score=\d+\.\d{3} "
        r"rewards=[\d.,]+$"
    )

    def _capture(self, func, *args, **kwargs) -> str:
        """Run *func* and return whatever it printed to stdout, stripped."""
        import io
        from contextlib import redirect_stdout
        buf = io.StringIO()
        with redirect_stdout(buf):
            func(*args, **kwargs)
        return buf.getvalue().strip()

    def test_log_start_format(self):
        # inference.py lives at the project root, so make sure ROOT is importable.
        sys.path.insert(0, str(ROOT))
        from inference import log_start
        out = self._capture(log_start, "simple-filter", "Qwen/Qwen2.5-72B")
        assert self.START_RE.match(out), f"Bad [START] format: {out!r}"

    def test_log_step_format_null_error(self):
        from inference import log_step
        out = self._capture(log_step, 1, "SELECT 1", 0.10, False, None)
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_step_format_with_error(self):
        from inference import log_step
        out = self._capture(log_step, 2, "SELCT 1", 0.0, False, "syntax error")
        assert self.STEP_RE.match(out), f"Bad [STEP] format: {out!r}"

    def test_log_end_format_success(self):
        from inference import log_end
        out = self._capture(log_end, True, 3, 0.850, [0.50, 1.0, 1.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_log_end_format_failure(self):
        from inference import log_end
        out = self._capture(log_end, False, 5, 0.100, [0.1, 0.0, 0.0, 0.0, 0.0])
        assert self.END_RE.match(out), f"Bad [END] format: {out!r}"

    def test_reward_two_decimal_places(self):
        from inference import log_step
        out = self._capture(log_step, 1, "SELECT 1", 0.5, False, None)
        # reward= field must have exactly 2 decimal places
        match = re.search(r"reward=(\d+\.\d+)", out)
        assert match, "No reward= field found"
        assert len(match.group(1).split(".")[1]) == 2

    def test_score_three_decimal_places(self):
        from inference import log_end
        out = self._capture(log_end, True, 1, 1.0, [1.0])
        match = re.search(r"score=(\d+\.\d+)", out)
        assert match
        assert len(match.group(1).split(".")[1]) == 3
|
train.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
# CRITICAL: Ye line sabse upar honi chahiye kisi bhi PyTorch import se pehle!
|
| 3 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0,7"
|
| 4 |
+
|
| 5 |
+
import sys
|
| 6 |
+
import torch
|
| 7 |
+
from datasets import Dataset
|
| 8 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 9 |
+
from peft import LoraConfig
|
| 10 |
+
from trl import GRPOConfig, GRPOTrainer
|
| 11 |
+
|
| 12 |
+
sys.path.insert(0, "./server")
|
| 13 |
+
from environment import NL2SQLEnvironment
|
| 14 |
+
from models import NL2SQLAction
|
| 15 |
+
from tasks import all_task_names, get_task
|
| 16 |
+
|
| 17 |
+
MODEL_NAME = "Qwen/Qwen2.5-Coder-7B-Instruct"
|
| 18 |
+
OUTPUT_DIR = "./qwen-7b-coder-nl2sql-grpo"
|
| 19 |
+
|
| 20 |
+
SYSTEM_PROMPT = """You are a Senior Database Architect and an expert in SQLite.
|
| 21 |
+
Your task is to translate natural language questions into highly optimized, correct SQLite SELECT queries.
|
| 22 |
+
|
| 23 |
+
STRICT RULES:
|
| 24 |
+
1. Output EXACTLY ONE valid SQLite query.
|
| 25 |
+
2. DO NOT wrap the query in markdown formatting (no ```sql or ```).
|
| 26 |
+
3. DO NOT output any explanations, conversational text, or preambles (e.g., never say "Here is the query").
|
| 27 |
+
4. ONLY use standard SQLite functions. Avoid SQL Server, MySQL, or PostgreSQL specific syntax.
|
| 28 |
+
5. If the question implies ordering, use the correct ORDER BY clause.
|
| 29 |
+
|
| 30 |
+
Your output must be executable directly against the database as-is."""
|
| 31 |
+
|
| 32 |
+
def build_dataset():
    """Build the GRPO prompt dataset: one chat-formatted row per task example.

    Each row carries the shared system prompt, a user message embedding the
    task's schema plus the natural-language question, and the originating
    task name so the reward function can replay the query against the
    correct task database.
    """
    rows = []
    for name in all_task_names():
        task = get_task(name)
        schema_text = task.schema_context()
        rows.extend(
            {
                "prompt": [
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": f"SCHEMA:\n{schema_text}\n\nQUESTION: {example.question}"},
                ],
                "task_name": name,
            }
            for example in task.examples
        )
    return Dataset.from_list(rows)
|
| 47 |
+
|
| 48 |
+
def sql_reward_func(prompts, completions, task_name, **kwargs):
    """GRPO reward function: execute each generated SQL query in the NL2SQL
    environment and return the environment's reward for each completion.

    Args:
        prompts: chat prompts (unused; required by the TRL reward signature).
        completions: model outputs — either plain strings or chat-message
            lists of the form ``[{"role": ..., "content": ...}]``.
        task_name: per-sample task identifier (list) or one shared name.
        **kwargs: extra dataset columns forwarded by the trainer (ignored).

    Returns:
        list[float]: one reward per completion; 0.0 on any execution error.
    """
    rewards = []
    env = NL2SQLEnvironment()

    for idx, completion in enumerate(completions):
        generated_text = completion[0]['content'] if isinstance(completion, list) else completion

        # Models sometimes ignore the "no markdown" rule. The previous check
        # (startswith("```")) was defeated by leading whitespace/newlines, so
        # normalize first, then drop any fence lines wherever they appear.
        generated_text = generated_text.strip()
        if "```" in generated_text:
            lines = generated_text.split("\n")
            generated_text = "\n".join(l for l in lines if not l.strip().startswith("```")).strip()

        # task_name arrives as a per-sample list in batched training, or as a
        # single shared string when the whole batch is one task.
        current_task = task_name[idx] if isinstance(task_name, list) else task_name

        env.reset(task_name=current_task)

        try:
            action = NL2SQLAction(query=generated_text)
            obs = env.step(action)
            rewards.append(float(obs.reward))
        except Exception:
            # A malformed or unexecutable query earns zero reward instead of
            # crashing the training loop.
            rewards.append(0.0)

    return rewards
|
| 71 |
+
|
| 72 |
+
def main():
    """Run GRPO LoRA fine-tuning of the coder model on the NL2SQL tasks."""
    train_ds = build_dataset()

    # Right-side padding keeps generation-time logits aligned for GRPO.
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, padding_side="right")
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token

    base_model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16,
        attn_implementation="sdpa"  # Defaulting to sdpa to avoid any flash_attn setup issues
    )

    # LoRA over all attention + MLP projections.
    lora_cfg = LoraConfig(
        r=128,
        lora_alpha=256,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
        bias="none",
        task_type="CAUSAL_LM"
    )

    grpo_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        learning_rate=2e-5,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=4,
        max_completion_length=256,
        num_generations=8,
        temperature=0.5,
        bf16=True,
        logging_steps=5,
        num_train_epochs=10,
        report_to="none",
        remove_unused_columns=False,  # keep task_name column for the reward fn
        ddp_find_unused_parameters=False
    )

    grpo_trainer = GRPOTrainer(
        model=base_model,
        reward_funcs=sql_reward_func,
        args=grpo_args,
        train_dataset=train_ds,
        peft_config=lora_cfg,
        processing_class=tok
    )

    grpo_trainer.train()

    # Persist adapter + tokenizer from rank 0 only under multi-GPU training.
    if grpo_trainer.accelerator.is_main_process:
        grpo_trainer.model.save_pretrained(f"{OUTPUT_DIR}/final")
        tok.save_pretrained(f"{OUTPUT_DIR}/final")
| 124 |
+
# Script entry point: launch training only when executed directly.
if __name__ == "__main__":
    main()
|