---
title: NL2SQL Bench
emoji: πŸ“Š
colorFrom: blue
colorTo: indigo
sdk: docker
pinned: false
---

NL2SQL-Bench

Natural Language to SQL Analytics Environment for RL Training



What is NL2SQL-Bench?

NL2SQL-Bench is an OpenEnv-compliant RL training environment where an AI agent must iteratively write and refine SQLite queries to answer natural-language business questions against a synthetic e-commerce database.

This fills a genuine gap in the OpenEnv ecosystem β€” no SQL query environment currently exists. Every data-driven company employs analysts who translate business questions into SQL. Training agents to do this well (and to recover from errors) is immediately valuable.

Why it's a great RL domain:

  • Rewards are 100% deterministic β€” no LLM-as-judge, no subjectivity
  • Multi-turn episodes create dense reward signal across the trajectory
  • The error β†’ fix β†’ retry loop is a novel mechanic not present in existing environments
  • Three clearly graduated difficulty levels challenge models across the full skill range

Environment Description

The agent interacts with a synthetic e-commerce SQLite database containing ~150 customers, 64 products across 8 categories, ~600 orders, ~1000 order items, and ~400 reviews. The database is seeded deterministically (seed=42), so results are reproducible on any machine.

The agent receives a natural-language question and iteratively submits SQL queries. Each query is executed, graded against the ground truth, and the reward + error/result is fed back as the next observation.


Database Schema

categories(id, name)
products(id, name, category_id, price, stock_quantity)
customers(id, name, email, country, tier∈{bronze|silver|gold}, created_at)
orders(id, customer_id, status∈{pending|processing|shipped|delivered|cancelled},
       created_at, total_amount)
order_items(id, order_id, product_id, quantity, unit_price)
reviews(id, product_id, customer_id, rating∈1-5, created_at)

All dates are ISO-8601 strings sortable by text comparison. SQLite window functions and CTEs are fully supported.
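The schema above can be exercised directly with Python's built-in sqlite3 module. The sketch below is a cut-down, hypothetical version of the documented tables (the real schema.sql may differ in constraints and defaults) with a couple of hand-written rows, plus the kind of single-join query an agent might submit:

```python
import sqlite3

# Minimal sketch of two of the documented tables; column names follow the
# schema listing above, but the exact DDL is an assumption.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE categories (id INTEGER PRIMARY KEY, name TEXT);
CREATE TABLE products (
    id INTEGER PRIMARY KEY,
    name TEXT,
    category_id INTEGER REFERENCES categories(id),
    price REAL,
    stock_quantity INTEGER
);
INSERT INTO categories VALUES (1, 'Electronics');
INSERT INTO products VALUES
    (1, 'Headphones', 1, 79.99, 120),
    (2, 'Keyboard',   1, 49.99,  60);
""")

# A typical simple-filter/join query: product names with their category,
# most expensive first.
rows = conn.execute("""
    SELECT p.name, c.name AS category, p.price
    FROM products p JOIN categories c ON p.category_id = c.id
    ORDER BY p.price DESC
""").fetchall()
print(rows)  # [('Headphones', 'Electronics', 79.99), ('Keyboard', 'Electronics', 49.99)]
```

Because dates are stored as ISO-8601 strings, `ORDER BY created_at` and string comparisons such as `created_at >= '2024-01-01'` behave as expected without any date parsing.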


Action & Observation Space

Action

@dataclass
class NL2SQLAction(Action):
    query: str  # A SQLite SELECT query string

Observation

@dataclass
class NL2SQLObservation(Observation):
    question:       str               # The NL question to answer
    schema_context: str               # Compact schema description
    task_name:      str               # Active task identifier
    last_query:     str               # SQL submitted on previous step
    last_result:    List[Dict]        # Up to 10 result rows
    last_error:     Optional[str]     # SQLite error string or None
    result_columns: List[str]         # Column names of last_result
    step:           int               # Current step (0 after reset)
    max_steps:      int               # Maximum steps per episode
    done:           bool              # Episode ended?
    reward:         Optional[float]   # Step reward [0.0, 1.0]
    score:          float             # Cumulative normalised score

Tasks & Expected Difficulty

Task 1 β€” simple-filter (easy)

Single-table SELECT queries with WHERE, ORDER BY, LIMIT. Tests basic SQL fluency. Example questions:

  • "List all gold-tier customers ordered by name alphabetically."
  • "Return the top 5 most expensive products."

Expected solve rate (frontier model, 5 steps): ~80%

Task 2 β€” join-aggregation (medium)

Multi-table JOINs with GROUP BY, HAVING, and aggregation functions. Example questions:

  • "How many orders has each customer placed? Include customers with zero orders."
  • "Which customers have spent more than $500 total on delivered orders?"

Expected solve rate (frontier model, 5 steps): ~55%

Task 3 β€” analytics-window (hard)

CTEs, window functions (DENSE_RANK, ROW_NUMBER, running SUM), and nested subqueries. Example questions:

  • "Rank customers by total spending using DENSE_RANK."
  • "Show monthly revenue and running total for delivered orders in 2024."

Expected solve rate (frontier model, 5 steps): ~30%
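To make the hard tier concrete, here is a hedged sketch of the kind of query the "rank customers by total spending" question calls for, run against a tiny hand-built orders table (the real task data comes from the seeded database, and window functions require SQLite >= 3.25, which ships with Python 3.10+):

```python
import sqlite3

# Toy orders table: customers 1 and 2 tie on total spend, customer 3 trails.
conn = sqlite3.connect(":memory:")
conn.executescript("""
CREATE TABLE orders (id INTEGER, customer_id INTEGER, status TEXT, total_amount REAL);
INSERT INTO orders VALUES
    (1, 1, 'delivered', 100.0),
    (2, 1, 'delivered',  50.0),
    (3, 2, 'delivered', 150.0),
    (4, 3, 'delivered',  80.0);
""")

# DENSE_RANK over aggregated spend: ties share a rank with no gaps after them.
rows = conn.execute("""
    SELECT customer_id,
           SUM(total_amount) AS total_spent,
           DENSE_RANK() OVER (ORDER BY SUM(total_amount) DESC) AS spend_rank
    FROM orders
    WHERE status = 'delivered'
    GROUP BY customer_id
    ORDER BY spend_rank, customer_id
""").fetchall()
print(rows)  # [(1, 150.0, 1), (2, 150.0, 1), (3, 80.0, 2)]
```

Note how customers 1 and 2 both receive rank 1 and customer 3 receives rank 2, not 3; that tie behaviour is exactly what distinguishes DENSE_RANK from ROW_NUMBER in these tasks.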


Reward Function

Rewards are computed by deterministic comparison of the agent's result set against the ground truth:

| Component       | Score      | Description                              |
|-----------------|------------|------------------------------------------|
| syntax_ok       | +0.10      | Query runs without SQLite error          |
| columns_match   | +0.20      | Returned column names match ground truth |
| row_count_match | +0.20      | Number of rows matches                   |
| exact_match     | +0.50      | Full result set equals ground truth      |
| step_penalty    | βˆ’0.05/step | Deducted per step beyond the first       |
Final reward is clamped to [0.0, 1.0]. Order sensitivity matches the ground-truth query: ORDER BY queries require correct row ordering; others are order-agnostic.
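The components above compose into a single scalar roughly as follows. This is a hypothetical sketch of the documented scoring, not the actual grader.py (which may differ in details such as value normalisation); `step` is assumed to be 0-indexed, matching the observation's `step` field:

```python
def grade(result, columns, error, expected_rows, expected_cols, step, ordered=False):
    """Sketch of the documented reward components (illustrative only)."""
    reward = 0.0
    if error is None:
        reward += 0.10                          # syntax_ok
        if columns == expected_cols:
            reward += 0.20                      # columns_match
        if len(result) == len(expected_rows):
            reward += 0.20                      # row_count_match
        # Order-sensitive only when the ground-truth query used ORDER BY.
        agent = list(map(tuple, result))
        truth = list(map(tuple, expected_rows))
        if not ordered:
            agent, truth = sorted(agent), sorted(truth)
        if agent == truth:
            reward += 0.50                      # exact_match
    reward -= 0.05 * step                       # step_penalty (step 0 = first attempt)
    return max(0.0, min(1.0, reward))           # clamp to [0.0, 1.0]
```

For example, a fully correct answer on the first step scores 1.0, the same answer on the second step scores 0.95, and a query that raises a SQLite error scores 0.0.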


Baseline Scores

Produced by the inference.py baseline script using Qwen/Qwen2.5-72B-Instruct via the HuggingFace router:

| Task             | Expected Score |
|------------------|----------------|
| simple-filter    | ~0.70          |
| join-aggregation | ~0.45          |
| analytics-window | ~0.25          |

Setup & Usage

Prerequisites

  • Python 3.10+
  • Docker (for containerised deployment)
  • A HuggingFace account + token

Local Development (no Docker)

# Clone the repository
git clone https://huggingface.co/spaces/your-username/nl2sql-bench
cd nl2sql-bench

# Quick start
chmod +x scripts/run_local.sh
./scripts/run_local.sh

# Or manually:
python3 -m venv .venv && source .venv/bin/activate
pip install openenv-core fastapi "uvicorn[standard]" openai pydantic
export PYTHONPATH=".:server"
cd server && uvicorn app:app --reload --port 8000

Test the Running Server

# Run smoke tests
chmod +x scripts/smoke_test.sh
./scripts/smoke_test.sh http://localhost:8000

# Run full test suite
pip install pytest pytest-asyncio
PYTHONPATH=".:server" pytest tests/ -v

Docker

# Build
docker build -t nl2sql-bench:latest .

# Run
docker run -p 7860:7860 nl2sql-bench:latest

# Test
./scripts/smoke_test.sh http://localhost:7860

Pre-submission Validation

# Run the official validator (replace with your HF Space URL)
chmod +x pre_validation_script.sh
./pre_validation_script.sh https://your-username-nl2sql-bench.hf.space .

Running the Baseline Inference

# Set mandatory variables
export API_BASE_URL="https://router.huggingface.co/v1"
export MODEL_NAME="Qwen/Qwen2.5-72B-Instruct"
export HF_TOKEN="hf_your_token_here"
export SPACE_URL="https://your-username-nl2sql-bench.hf.space"

python inference.py

Using the Client Programmatically

import asyncio
from client import NL2SQLEnv
from models import NL2SQLAction

async def main():
    async with NL2SQLEnv(base_url="http://localhost:8000") as env:
        result = await env.reset()
        print(result.observation.question)

        result = await env.step(NL2SQLAction(
            query="SELECT id, name FROM customers WHERE tier='gold' ORDER BY name"
        ))
        print(f"Reward: {result.reward:.2f}")
        print(f"Done: {result.done}")
        print(f"Error: {result.observation.last_error}")

asyncio.run(main())

Project Structure

nl2sql-bench/
β”œβ”€β”€ models.py              # NL2SQLAction, NL2SQLObservation, NL2SQLState
β”œβ”€β”€ client.py              # NL2SQLEnv(HTTPEnvClient)
β”œβ”€β”€ inference.py           # Baseline inference script (mandatory name)
β”œβ”€β”€ openenv.yaml           # OpenEnv manifest
β”œβ”€β”€ pyproject.toml
β”œβ”€β”€ Dockerfile             # HF Spaces compatible (port 7860)
β”œβ”€β”€ .env.example
β”œβ”€β”€ server/
β”‚   β”œβ”€β”€ app.py             # FastAPI entry point
β”‚   β”œβ”€β”€ environment.py     # Core RL environment logic
β”‚   β”œβ”€β”€ grader.py          # Deterministic reward computation
β”‚   β”œβ”€β”€ requirements.txt
β”‚   β”œβ”€β”€ db/
β”‚   β”‚   β”œβ”€β”€ schema.sql     # 6-table e-commerce schema
β”‚   β”‚   └── seed.py        # Deterministic data generator (seed=42)
β”‚   └── tasks/
β”‚       β”œβ”€β”€ base.py        # BaseTask + registry
β”‚       β”œβ”€β”€ easy.py        # simple-filter (5 examples)
β”‚       β”œβ”€β”€ medium.py      # join-aggregation (5 examples)
β”‚       └── hard.py        # analytics-window (5 examples)
β”œβ”€β”€ tests/
β”‚   β”œβ”€β”€ conftest.py
β”‚   └── test_all.py        # 30+ pytest tests
└── scripts/
    β”œβ”€β”€ run_local.sh        # Local dev server
    └── smoke_test.sh       # Endpoint smoke tests

Design Decisions

Why SQLite in-memory? Zero runtime dependency, deterministic, and it runs comfortably within the 2 vCPU / 8 GB constraint. The database loads in ~50ms.

Why multi-turn (up to 5 steps)? A single-shot SQL environment gives binary rewards. Multi-turn with error feedback gives the agent β€” and the GRPO trainer β€” a rich signal: the model learns not just to write SQL, but to debug and refine its queries.

Why step penalty? Without it, an agent that accidentally gets the right answer on step 5 scores the same as one that gets it on step 1. The penalty creates pressure to solve efficiently, which is realistic.

Why order-sensitive comparison for ORDER BY queries? Business questions that say "rank by spending" expect a ranked output. Order-agnostic comparison would give spurious credit.
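The two comparison modes can be sketched in a few lines. This is an illustrative helper, not the grader's actual implementation; the function name and signature are assumptions:

```python
def rows_equal(agent_rows, truth_rows, order_sensitive):
    """Hypothetical sketch of the comparison policy described above."""
    agent = [tuple(r) for r in agent_rows]
    truth = [tuple(r) for r in truth_rows]
    if order_sensitive:                 # ground truth used ORDER BY: positions matter
        return agent == truth
    return sorted(agent) == sorted(truth)  # otherwise compare as multisets

# Same rows, different order:
a = [(2, 'Bob'), (1, 'Ann')]
t = [(1, 'Ann'), (2, 'Bob')]
print(rows_equal(a, t, order_sensitive=True))   # False: a "rank" question was answered unranked
print(rows_equal(a, t, order_sensitive=False))  # True: for unordered questions this is full credit
```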


License

MIT β€” see LICENSE