hjerpe committed on
Commit a001a97 · verified · 1 Parent(s): a19eef8

Upload folder using huggingface_hub

DATA_LICENSE ADDED
@@ -0,0 +1,19 @@
+ Data License Notice
+
+ Data in data/ is adapted from the Spider dataset (Yu et al., 2018),
+ distributed under CC BY-SA 4.0.
+
+ We retrieved question/SQL pairs from the xlangai/spider HuggingFace mirror
+ and SQLite databases from the taoyds/spider GitHub mirror, then curated a
+ 10-database subset, derived gold answers by executing the gold SQL, and
+ generated SFT trajectories from those artifacts.
+
+ Derived data in data/ is shared under CC BY-SA 4.0.
+ Software code is licensed separately under MIT (see LICENSE).
+
+ References:
+ - Spider dataset: https://yale-lily.github.io/spider
+ - Yu et al. (2018). Spider: A Large-Scale Human-Labeled Dataset for Complex
+   and Cross-Domain Semantic Parsing and Text-to-SQL Task. EMNLP.
+ - xlangai/spider on HuggingFace: https://huggingface.co/datasets/xlangai/spider
+ - taoyds/spider on GitHub: https://github.com/taoyds/spider
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2026 Adam Hjerpe
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -5,168 +5,146 @@ colorFrom: blue
  colorTo: green
  sdk: docker
  app_port: 8000
- pinned: false
  base_path: /web
  ---

- # SQLEnv: Teaching Agents to Explore Databases

  ![Python](https://img.shields.io/badge/python-3.12-blue.svg)
  ![License](https://img.shields.io/badge/license-MIT-green.svg)

- SQLEnv is an interactive RL environment for text-to-SQL reasoning. Instead of producing one-shot SQL, agents learn to think like data analysts: inspect schema, sample rows, run exploratory queries, and submit a final answer with confidence.

- Built for the [OpenEnv Challenge](https://github.com/meta-pytorch/OpenEnv), this project packages environment runtime, dense rewards, evaluation, and training hooks so others can reproduce results and iterate quickly.

- **[Read the blog post](https://hjerpe-sqlenv-blog.static.hf.space)** | **[Source code](https://github.com/hjerpe/sqlenv)**

  ## Quick Start

- Run these three commands to install, validate, and smoke-test the environment:
-
  ```bash
  uv sync
- uv run openenv validate --verbose
  uv run pytest tests/ -v
  ```

- Local server run:

  ```bash
  uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
  ```

- Docker run:

  ```bash
- docker build -t sql-env:latest -f server/Dockerfile .
- docker run -p 8000:8000 sql-env:latest
- ```
-
- ## Why SQLEnv
-
- Static text-to-SQL benchmarks reward final outputs, not reasoning quality. SQLEnv turns SQL generation into an interactive decision process with feedback at each step, making it suitable for RL training and behavior analysis.
-
- ## Architecture
-
- ```text
- +-------------+    WebSocket     +----------------------+   SQLite
- | RL Agent    | <--------------> | SQLEnvClient         | <---------------+
- | (GRPO/TRL)  |                  | (client.py)          |                 |
- +-------------+                  +----------+-----------+                 |
-                                  HTTP/WebSocket                          |
-                                             |                            |
-                                             v                            |
-                                  +--------------------------+            |
-                                  | FastAPI Server           |            |
-                                  | (server.app:app)         |            |
-                                  +------------+-------------+            |
-                                               |                          |
-                                               v                          |
-                                  +--------------------------+            |
-                                  | SQLEnvironment           |------------+
-                                  | step/reset/reward/verify |
-                                  +--------------------------+
  ```

  ## How It Works

- Each episode begins with a natural language question mapped to a hidden Spider database. The agent acts through four environment actions:

- | Action | Purpose | Typical Output |
- |--------|---------|----------------|
- | `DESCRIBE table_name` | Inspect schema and column metadata | Column names, types, row count |
- | `SAMPLE table_name` | Inspect representative rows | Small row sample |
- | `QUERY sql_string` | Execute read-only SQL in sandbox | Query result rows or SQL error |
- | `ANSWER value` | Submit final answer | Terminal reward and completion |

- Episode flow:
- 1. `reset()` returns question context and available tables.
- 2. `step()` executes one exploration action at a time.
- 3. `ANSWER` ends the episode with correctness-based terminal reward.

- ## Train an Agent

- The environment exposes four tools (`describe`, `sample`, `query`, `answer`) that TRL's GRPOTrainer discovers automatically. The model learns to call these tools through GRPO — no custom rollout code needed.

- ### Local test (Docker, CPU)

- Verify the training pipeline end-to-end in about 3 minutes:

  ```bash
  docker build -f Dockerfile.test -t sqlenv-test .
  docker run --rm sqlenv-test
  ```

- This runs 2 training steps with `configs/test_cpu.json` and prints per-step loss, reward, tool call frequency, and model completions.
-
- ### Colab training (GPU)

- Open the notebook and select a GPU runtime (L4 recommended):

- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/hjerpe/sql-env/blob/main/notebooks/train_grpo.ipynb)
-
- The notebook uses `configs/colab_l4.json` settings: batch size 4, 4 generations per prompt, bf16 precision. Live reward plots and execution traces update during training.
-
- ### What the model sees
-
- Each episode, TRL injects tool schemas into the prompt. The model generates structured tool calls:

  ```
- <tool_call>{"name": "describe", "arguments": {"table_name": "employee"}}</tool_call>
- ```

- TRL parses this, calls `env.describe(table_name="employee")`, and appends the result. The model can then call more tools or submit an answer. Rewards accumulate from each interaction.

- ### Configuration

- Training configs live in `configs/`:
- - `test_cpu.json` — 2 steps, 256 tokens, budget 3 (local validation)
- - `colab_l4.json` — full epoch, 512 tokens, budget 10, bf16 (L4 GPU)

- ## HuggingFace Space

- - Live Space: `https://huggingface.co/spaces/<your-org-or-user>/sql-env` (update after push)
- - Health check: `curl https://<space-url>/health`
- - Deploy command: `uv run openenv push`

  ## Project Structure

- ```text
- sql-env/
- |- __init__.py
- |- client.py
- |- models.py
- |- openenv.yaml
- |- server/
- |  |- app.py
- |  |- sql_environment.py
- |  |- reward.py
- |  |- verifier.py
- |  `- Dockerfile
- |- data/
- |  |- databases/
- |  `- questions/
- |- training/
- |- evaluation/
- |- notebooks/
- |  `- train_grpo.ipynb
- |- specs/
- |- docs/
- `- tests/
  ```

- ## Deployment Checklist

- 1. `uv run openenv validate --verbose`
- 2. `uv run openenv build`
- 3. `uv run openenv push`
- 4. Verify `/health` and run one full episode through the client.

- ## Links

- - OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
- - OpenEnv docs: https://meta-pytorch.org/OpenEnv/
- - Spider dataset: https://huggingface.co/datasets/xlangai/spider
- - TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv
- - Verification plan: `specs/F007-VERIFICATION_SPEC.md`
  colorTo: green
  sdk: docker
  app_port: 8000
+ pinned: true
  base_path: /web
  ---

+ # SQLEnv: Teaching Small Models to Explore Databases

  ![Python](https://img.shields.io/badge/python-3.12-blue.svg)
  ![License](https://img.shields.io/badge/license-MIT-green.svg)
+ ![Data](https://img.shields.io/badge/data-CC%20BY--SA%204.0-orange.svg)

+ SQLEnv is an RL environment for training small language models to answer questions about SQL databases through iterative exploration. Instead of producing one-shot SQL from a fully visible schema, the agent discovers the schema step by step using four tools: DESCRIBE, SAMPLE, QUERY, and ANSWER.

+ Built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv) and trained with [TRL](https://huggingface.co/docs/trl)'s GRPO implementation. A 0.6B parameter model trained in this environment goes from 0% to ~30% accuracy on a curated Spider subset, learning to explore schemas, recover from SQL errors, and format answers correctly.

+ **[Blog post](https://hjerpe-sqlenv-blog.static.hf.space)** | **[Live environment](https://huggingface.co/spaces/hjerpe/sql_env)** | **[Training notebook](notebooks/train_grpo.ipynb)**

  ## Quick Start

  ```bash
  uv sync
  uv run pytest tests/ -v
  ```

+ Run the environment locally:

  ```bash
  uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
  ```

+ Or with Docker:

  ```bash
+ docker build -t sqlenv:latest -f server/Dockerfile .
+ docker run -p 8000:8000 sqlenv:latest
  ```

  ## How It Works

+ Each episode starts with a natural-language question and a list of table names. The schema (columns, types, relationships) is hidden. The agent uses four actions to explore:

+ | Action | Purpose |
+ |--------|---------|
+ | `DESCRIBE table` | Reveal column names, types, and row count |
+ | `SAMPLE table` | Preview representative rows |
+ | `QUERY sql` | Execute read-only SQL |
+ | `ANSWER value` | Submit a final answer (ends episode) |

+ The environment provides dense reward at each step (operational feedback + progress toward the answer) and a terminal reward for correctness (+1.0 correct, 0.0 wrong). See the [blog post](https://hjerpe-sqlenv-blog.static.hf.space) for details on the reward architecture.
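
How the dense and terminal components add up over an episode can be sketched with the reward trace from the showcase notebook's oracle episode. The values below are illustrative numbers taken from that trace, not the `reward.py` implementation:

```python
# Per-step rewards from one oracle episode (showcase notebook):
# two DESCRIBE steps, one QUERY step, then a correct ANSWER.
step_rewards = [0.015, 0.015, 0.15, 1.0]

exploration = sum(step_rewards[:-1])  # dense per-step signal
terminal = step_rewards[-1]           # correctness reward (+1.0 or 0.0)
total = exploration + terminal

print(f"Total reward: {total:.3f}")         # 1.180
print(f"  Exploration: {exploration:.3f}")  # 0.180
print(f"  Terminal:    {terminal:.3f}")     # 1.000
```

The split mirrors the notebook's summary line, where exploration rewards stay small relative to the terminal correctness reward.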
 
 
 

+ ```python
+ from server.sql_environment import SQLEnvironment, SQLAction
+
+ env = SQLEnvironment(questions_path="data/questions/questions_train.json",
+                      db_dir="data/databases", tokenizer=tok)
+ obs = env.reset(seed=42)
+ obs = env.step(SQLAction(action_type="DESCRIBE", argument="employee"))
+ obs = env.step(SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employee"))
+ obs = env.step(SQLAction(action_type="ANSWER", argument="10"))
+ # obs.done=True, obs.reward=1.0
+ ```

+ ## Training

+ We train [Qwen3-0.6B](https://arxiv.org/abs/2505.09388) using [GRPO](https://arxiv.org/abs/2402.03300) (from DeepSeekMath) through TRL's `environment_factory`. The full pipeline (SFT warmup + two-phase GRPO) runs in ~5 hours on a single Colab L4.
+
+ **Notebooks:**
+ - **[train_grpo.ipynb](notebooks/train_grpo.ipynb)** runs the full SFT + GRPO pipeline
+ - **[compare_methods.ipynb](notebooks/compare_methods.ipynb)** evaluates base vs trained models
+ - **[showcase_sqlenv.ipynb](notebooks/showcase_sqlenv.ipynb)** lets you explore the environment interactively
+
+ **Local test (CPU, ~3 min):**

  ```bash
  docker build -f Dockerfile.test -t sqlenv-test .
  docker run --rm sqlenv-test
  ```

+ ## Evaluation

+ All evaluation runs through the Green Agent evaluator:

+ ```python
+ from sql_env.evaluation import evaluate, RandomPolicy, OraclePolicy
+
+ result = evaluate(env, policy, n_episodes=50, seed=0)
+ print(f"Accuracy: {result.success_rate:.1%}, Reward: {result.avg_reward:.3f}")
  ```
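
A policy here is anything with a `select_action(obs)` method, the interface the showcase notebook's evaluation loop calls. A minimal hand-rolled policy might look like the sketch below; `SQLAction` is a simplified stand-in for the project's model class, and the strategy itself is purely illustrative:

```python
from dataclasses import dataclass

@dataclass
class SQLAction:
    action_type: str  # DESCRIBE | SAMPLE | QUERY | ANSWER
    argument: str

class DescribeFirstPolicy:
    """Toy policy: describe every known table once, then answer blindly."""

    def __init__(self, tables):
        self.pending = list(tables)

    def select_action(self, obs):
        # obs is ignored here; a real policy would read the latest result.
        if self.pending:
            return SQLAction("DESCRIBE", self.pending.pop(0))
        return SQLAction("ANSWER", "")  # ANSWER is terminal and ends the episode

policy = DescribeFirstPolicy(["students", "courses"])
a1 = policy.select_action(obs=None)
a2 = policy.select_action(obs=None)
a3 = policy.select_action(obs=None)
print(a1.action_type, a1.argument)  # DESCRIBE students
print(a3.action_type)               # ANSWER
```

Any object with this shape should be usable where `RandomPolicy` or `OraclePolicy` is passed to `evaluate`.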
 
 

+ Results on our curated 10-database Spider subset (N=50, 2 runs):
+
+ | Method | Accuracy | Parse Rate | Avg Steps |
+ |--------|----------|------------|-----------|
+ | Zero-shot | 0% | 24-28% | 10.8-12.4 |
+ | 1-shot | 0-2% | 16-17% | 14.0-14.8 |
+ | 3-shot | 0% | 19-20% | 13.8-14.8 |
+ | GRPO v1 (2 epochs) | 28-30% | 95-100% | 3.5-4.0 |
+ | GRPO v2 (4 epochs) | 24-32% | 87-95% | 3.5-4.0 |

+ This evaluation is not comparable to the official Spider leaderboard, which uses different scoring, full-schema input, and a broader database set. See the [blog post](https://hjerpe-sqlenv-blog.static.hf.space) for detailed analysis.

+ ## Data

+ 676 questions (473 train, 203 eval) across 10 Spider databases with difficulty labels, plus 120 multi-turn SFT warmup trajectories generated from gold SQL. See [docs/data-sources.md](docs/data-sources.md) for full details on provenance, curation, and regeneration.

+ Data in `data/` is adapted from [Spider](https://yale-lily.github.io/spider) (Yu et al., 2018) and shared under CC BY-SA 4.0. See [DATA_LICENSE](DATA_LICENSE).

  ## Project Structure

+ ```
+ sqlenv/
+ ├── __init__.py, client.py, models.py  # Core types and client
+ ├── server/
+ │   ├── app.py              # FastAPI server
+ │   ├── sql_environment.py  # Environment implementation
+ │   ├── reward.py           # Three-layer reward function
+ │   ├── verifier.py         # Answer verification
+ │   └── Dockerfile          # HF Spaces deployment
+ ├── evaluation/             # Green Agent evaluator, policies
+ ├── training/               # TRL adapter, data loading
+ ├── scripts/                # Data curation, SFT generation
+ ├── notebooks/              # Training, evaluation, showcase
+ ├── data/
+ │   ├── databases/          # 10 Spider SQLite databases
+ │   ├── questions/          # Train/eval question sets
+ │   └── sft/                # SFT warmup trajectories
+ ├── configs/                # Training configurations
+ ├── tests/                  # Unit and integration tests
+ └── docs/
+     ├── data-sources.md     # Data provenance
+     └── ARCHITECTURE.md     # System architecture
  ```

+ ## References

+ - Yu et al. (2018). [Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task](https://yale-lily.github.io/spider). EMNLP.
+ - Shao et al. (2024). [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300). (GRPO algorithm)
+ - Ng, Harada, Russell (1999). [Policy Invariance Under Reward Transformations](https://people.eecs.berkeley.edu/~pabbeel/cs287-fa09/readings/NgHaradaRussell-shaping-ICML1999.pdf). ICML.
+ - [OpenEnv framework](https://github.com/meta-pytorch/OpenEnv)
+ - [TRL OpenEnv docs](https://huggingface.co/docs/trl/openenv)

+ ## License

+ Code: [MIT](LICENSE). Data: [CC BY-SA 4.0](DATA_LICENSE).
docs/ARCHITECTURE.md CHANGED
@@ -2,7 +2,7 @@

  > Last updated: 2026-03-29

- System map for SQLEnv an RL environment where agents learn interactive SQL exploration via the OpenEnv framework.

  **Goals:**
  - Show how components connect (system map + key flows)
@@ -11,8 +11,8 @@ System map for SQLEnv — an RL environment where agents learn interactive SQL e
  - Keep invariants legible (what must stay true)

  **Non-goals:**
- - CLI reference (see `docs/RUNBOOK.md`)
- - Per-feature implementation details (link to specs)

  ---

@@ -45,7 +45,7 @@ System map for SQLEnv — an RL environment where agents learn interactive SQL e
  ────────── │ DESCRIBE → PRAGMA │
  +──────────────+ │ SAMPLE → SELECT N │
  │ evaluate() │──> env.reset/step │ QUERY → SQL exec │
- │ policies │ │ ANSWER → verifier │
  │ .py │ +────────┬───────────────+
  +──────────────+ │
  │ v
@@ -267,7 +267,7 @@ class EpisodeContext:
      cumulative_new_info_reward: float = 0.0
  ```

- **POMDP design:** The agent sees `SQLObservation`; the server holds `EpisodeContext`. The agent never sees gold answers, progress scores, or the full database. This separation forces exploration.

  ---

@@ -348,7 +348,7 @@ except ImportError:
  | `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON |
  | `DB_DIR` | No | `data/databases/` | SQLite database directory |
  | `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer |
- | `PORT` | No | `8000` | Server port (HF Spaces uses 7860) |

  ---

@@ -431,7 +431,7 @@ uv run openenv build # build Docker image
  uv run openenv push # push to HF Spaces
  ```

- The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `appuser`, bundles Spider databases, and exposes `PORT` (default 7860 on HF Spaces).

  ---

@@ -451,7 +451,7 @@ The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `app
  |------|------------|
  | Episode | One question-answering session: reset -> N steps -> terminal |
  | Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER |
- | POMDP | Partially observable MDP agent acts under uncertainty |
  | Spider | Academic text-to-SQL benchmark dataset (10 DBs used) |
  | OpenEnv | Meta's RL environment framework (Environment, EnvClient) |
  | Green Agent | OpenEnv's evaluation wrapper pattern |
@@ -464,10 +464,6 @@ The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `app

  ## References

- - Docs index: `docs/README.md`
- - Operations: `docs/RUNBOOK.md`
- - Vision: `vision/VISION.md`
- - Feature specs: `specs/FEATURES.json`
  - OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
  - Spider dataset: https://huggingface.co/datasets/xlangai/spider
  - TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv
 

  > Last updated: 2026-03-29

+ System map for SQLEnv, an RL environment where agents learn interactive SQL exploration via the OpenEnv framework.

  **Goals:**
  - Show how components connect (system map + key flows)

  - Keep invariants legible (what must stay true)

  **Non-goals:**
+ - Exhaustive API reference
+ - Training hyperparameter tuning guide

  ---

  ────────── │ DESCRIBE → PRAGMA │
  +──────────────+ │ SAMPLE → SELECT N │
  │ evaluate() │──> env.reset/step │ QUERY → SQL exec │
+ │ policies │ │ ANSWER → verifier │
  │ .py │ +────────┬───────────────+
  +──────────────+ │
  │ v

      cumulative_new_info_reward: float = 0.0
  ```

+ **POMDP design:** The agent sees `SQLObservation`. The server holds `EpisodeContext`. The agent never sees gold answers, progress scores, or the full database. This separation forces exploration.
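
The split can be illustrated with two toy dataclasses. Apart from `cumulative_new_info_reward`, which appears in the `EpisodeContext` snippet above, the field names here are illustrative rather than the project's exact models:

```python
from dataclasses import dataclass, fields

@dataclass
class SQLObservation:  # what the agent sees (client side)
    question: str
    result: str
    reward: float
    done: bool

@dataclass
class EpisodeContext:  # what only the server holds
    gold_answer: str
    cumulative_new_info_reward: float = 0.0

# The agent-visible surface never includes the gold answer.
agent_visible = {f.name for f in fields(SQLObservation)}
assert "gold_answer" not in agent_visible
print(sorted(agent_visible))  # ['done', 'question', 'result', 'reward']
```

Keeping the two types in separate modules (client models vs server state) is one way to make this invariant hard to violate accidentally.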

  ---

  | `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON |
  | `DB_DIR` | No | `data/databases/` | SQLite database directory |
  | `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer |
+ | `PORT` | No | `8000` | Server port |

  ---

  uv run openenv push # push to HF Spaces
  ```

+ The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `appuser`, bundles Spider databases, and exposes port 8000.

  ---

  |------|------------|
  | Episode | One question-answering session: reset -> N steps -> terminal |
  | Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER |
+ | POMDP | Partially observable MDP. Agent acts under uncertainty |
  | Spider | Academic text-to-SQL benchmark dataset (10 DBs used) |
  | OpenEnv | Meta's RL environment framework (Environment, EnvClient) |
  | Green Agent | OpenEnv's evaluation wrapper pattern |

  ## References

  - OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
  - Spider dataset: https://huggingface.co/datasets/xlangai/spider
  - TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv
docs/data-sources.md CHANGED
@@ -14,7 +14,7 @@ so a fresh clone works offline after `uv sync`.
  | DB allowlist | `data/questions/db_list.json` | hand-curated subset | 10 db_ids |
  | SFT trajectories | `data/sft/sft_trajectories.json` | generated from gold SQL | 120 trajectories |

- Total: ~676 questions across 10 Spider databases, plus 120 multi-turn SFT
  warmup trajectories.

  ## Upstream: Spider
@@ -26,7 +26,7 @@ gold SQL query, and a target database. We use two mirrors:

  1. **Questions** via HuggingFace Datasets: [`xlangai/spider`](https://huggingface.co/datasets/xlangai/spider)
     — loaded with `datasets.load_dataset("xlangai/spider", split=...)` in
-    `scripts/download_spider_data.py`.
  2. **SQLite databases** via the Spider GitHub mirror:
     - `https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite`
     - Fallback: the official Google Drive archive
@@ -56,8 +56,8 @@ database. This prevents train/eval leakage at the schema level:
    dog_kennels, employee_hire_evaluation, flight_2, student_assessment`
  - **Eval databases** (4): `flight_2, pets_1, poker_player, world_1`

- `flight_2` appears in both; other eval DBs are schemas the model never
- saw during training. `sql_env.training.data_loading.validate_no_data_leak`
  asserts zero question-text overlap between the two files at load time.

  ## Question files
@@ -86,10 +86,13 @@ with this shape (actual sample from `car_1` train):
  | train | 473 | 435 | 32 | 6 |
  | eval | 203 | 185 | 18 | 0 |

- The easy-heavy distribution is deliberate for the 0.6B capacity ceiling
- (see `docs/blog-material.md` "The 0.6B Capacity Ceiling"). Medium and
- hard questions are kept in the mix for Phase 2 exposure but are not where
- this model size gains accuracy.

  ### Curation pipeline

@@ -126,18 +129,26 @@ one, runs the real `SQLEnvironment` programmatically:
  3. `answer(gold_answer)` — terminal step

  The captured sequence becomes an assistant-labelled trajectory. This is
- **not synthetic text** the assistant turns wrap the actual environment
- responses the model will see at training and inference time, which is
- what lets GRPO's KL anchor point align with real env output.

- The 120-count is smaller than the 473 training questions because SFT
- samples a subset that exercises each database and difficulty bucket;
- see `scripts/generate_sft_data.py` for the selection logic.

  Why multi-turn matters: an earlier per-turn SFT (347 single-turn
- examples) taught the model to always call `describe` and nothing else.
- Multi-turn teaches the full `describe query answer` sequence. See
- `docs/blog-material.md` "Multi-Turn SFT Why It's Critical".

  ## How to regenerate from scratch

@@ -146,8 +157,8 @@ Multi-turn teaches the full `describe → query → answer` sequence. See
  uv run python scripts/download_spider_databases.py --db-id all

  # 2. Raw Spider questions (via HF Datasets)
- uv run python scripts/download_spider_data.py --db-id all --split train
- uv run python scripts/download_spider_data.py --db-id all --split validation

  # 3. Curate into questions_train.json / questions_eval.json
  uv run python scripts/curate_questions.py
@@ -164,7 +175,7 @@ snapshot.
  ## What we deliberately do not use

  - **BIRD** (Li et al., 2023) — larger, harder text-to-SQL benchmark. Out
-   of scope for a 0.6B model; revisit for a larger-model follow-up.
  - **WikiSQL** — single-table only, doesn't exercise the multi-turn
    exploration the environment is built for.
  - **Synthetic LLM-generated questions** — we want Spider's human-written
 
  | DB allowlist | `data/questions/db_list.json` | hand-curated subset | 10 db_ids |
  | SFT trajectories | `data/sft/sft_trajectories.json` | generated from gold SQL | 120 trajectories |

+ Total: 676 questions across 10 Spider databases, plus 120 multi-turn SFT
  warmup trajectories.

  ## Upstream: Spider

  1. **Questions** via HuggingFace Datasets: [`xlangai/spider`](https://huggingface.co/datasets/xlangai/spider)
     — loaded with `datasets.load_dataset("xlangai/spider", split=...)` in
+    `scripts/download_spider_questions.py`.
  2. **SQLite databases** via the Spider GitHub mirror:
     - `https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite`
     - Fallback: the official Google Drive archive

    dog_kennels, employee_hire_evaluation, flight_2, student_assessment`
  - **Eval databases** (4): `flight_2, pets_1, poker_player, world_1`

+ `flight_2` appears in both. The other eval DBs are schemas the model
+ never saw during training. `sql_env.training.data_loading.validate_no_data_leak`
  asserts zero question-text overlap between the two files at load time.
 
63
  ## Question files
 
86
  | train | 473 | 435 | 32 | 6 |
87
  | eval | 203 | 185 | 18 | 0 |
88
 
89
+ The easy-heavy distribution is deliberate for the 0.6B capacity ceiling.
90
+ Extended GRPO training on harder questions produced identical accuracy,
91
+ which indicates the ceiling comes from pretraining knowledge rather than
92
+ training budget. Medium and hard questions stay in the mix for Phase 2
93
+ exposure but are not where this model size gains accuracy. See the
94
+ "Limitations at 0.6B Parameters" section of the
95
+ [blog post](https://hjerpe-sqlenv-blog.static.hf.space).
96
 
97
  ### Curation pipeline
98
 
 
129
  3. `answer(gold_answer)` — terminal step
130
 
131
  The captured sequence becomes an assistant-labelled trajectory. This is
132
+ **not synthetic text**. The assistant turns wrap the actual environment
133
+ responses the model will see at training and inference time, so the
134
+ SFT-warmed reference policy already expects real env output when GRPO
135
+ takes over.
136
 
137
+ SFT uses 120 trajectories rather than one per training question. The
138
+ subset is chosen to cover each database and difficulty bucket. See
139
+ `scripts/generate_sft_data.py` for the selection logic.
140
 
141
  Why multi-turn matters: an earlier per-turn SFT (347 single-turn
142
+ examples) taught the model to always call `describe`. Half those
143
+ examples were describe calls, so the model learned "when asked a
144
+ question, call describe." Under a KL penalty during GRPO, every rollout
145
+ stayed identical, the advantage between rollouts was zero, and no policy
146
+ gradient could form. Multi-turn SFT (120 full trajectories trained with
147
+ `assistant_only_loss`) instead teaches the full
148
+ `describe → query → answer` sequence as a coherent strategy, which GRPO
149
+ then refines into error recovery and answer formatting. See the
150
+ "Training" section of the
151
+ [blog post](https://hjerpe-sqlenv-blog.static.hf.space).
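
Concretely, one such multi-turn trajectory has roughly the shape below in chat-message form; this is a hypothetical example, and the exact fields in `sft_trajectories.json` may differ. With `assistant_only_loss`, only the assistant turns contribute to the SFT loss, while the environment-output turns are context only:

```python
# Hypothetical trajectory: assistant turns carry actions, the other turns
# carry real environment output captured by replaying the gold SQL.
trajectory = [
    {"role": "user", "content": "Q: How many employees are there? Tables: employee"},
    {"role": "assistant", "content": "DESCRIBE employee"},          # loss computed
    {"role": "user", "content": "Columns: id INTEGER, name TEXT"},  # env output, no loss
    {"role": "assistant", "content": "QUERY SELECT COUNT(*) FROM employee"},
    {"role": "user", "content": "1. 10"},
    {"role": "assistant", "content": "ANSWER 10"},
]

# Only the assistant turns would be supervised under assistant_only_loss:
trainable = [m["content"] for m in trajectory if m["role"] == "assistant"]
print(len(trainable))  # 3
```

Because every supervised turn is conditioned on genuine environment responses, the warmed-up model is not surprised by real env output when GRPO rollouts begin.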

  ## How to regenerate from scratch

  uv run python scripts/download_spider_databases.py --db-id all

  # 2. Raw Spider questions (via HF Datasets)
+ uv run python scripts/download_spider_questions.py --db-id all --split train
+ uv run python scripts/download_spider_questions.py --db-id all --split validation

  # 3. Curate into questions_train.json / questions_eval.json
  uv run python scripts/curate_questions.py

  ## What we deliberately do not use

  - **BIRD** (Li et al., 2023) — larger, harder text-to-SQL benchmark. Out
+   of scope for a 0.6B model. Revisit for a larger-model follow-up.
  - **WikiSQL** — single-table only, doesn't exercise the multi-turn
    exploration the environment is built for.
  - **Synthetic LLM-generated questions** — we want Spider's human-written
notebooks/showcase_sqlenv.ipynb CHANGED
@@ -29,40 +29,83 @@
29
  },
30
  {
31
  "cell_type": "code",
32
- "execution_count": 1,
33
  "metadata": {},
34
- "outputs": [
35
- {
36
- "name": "stdout",
37
- "output_type": "stream",
38
- "text": [
39
- "Project root: /Users/hjerp/Projects/sql-env\n"
40
- ]
41
- }
42
- ],
43
  "source": [
44
  "import os\n",
 
45
  "import sys\n",
46
  "from pathlib import Path\n",
47
  "\n",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  "\n",
49
- "def find_project_root() -> Path:\n",
50
- " \"\"\"Walk up from CWD until pyproject.toml is found.\"\"\"\n",
51
- " for parent in [Path.cwd(), *Path.cwd().parents]:\n",
52
- " if (parent / \"pyproject.toml\").exists():\n",
53
- " return parent\n",
54
- " raise FileNotFoundError(\"Could not locate project root (no pyproject.toml found)\")\n",
 
 
 
 
 
55
  "\n",
 
 
56
  "\n",
57
- "PROJECT_ROOT = find_project_root()\n",
58
- "os.chdir(PROJECT_ROOT)\n",
59
  "if str(PROJECT_ROOT) not in sys.path:\n",
60
  " sys.path.insert(0, str(PROJECT_ROOT))\n",
61
  "\n",
62
- "# In Colab, uncomment:\n",
63
- "# !pip install -q git+https://github.com/hjerpe/sql-env.git\n",
64
- "# !python scripts/download_spider_databases.py\n",
65
- "\n",
66
  "print(f\"Project root: {PROJECT_ROOT}\")"
67
  ]
68
  },
@@ -553,62 +596,19 @@
553
  },
554
  {
555
  "cell_type": "code",
556
- "execution_count": 12,
557
  "metadata": {},
558
- "outputs": [
559
- {
560
- "name": "stdout",
561
- "output_type": "stream",
562
- "text": [
563
- "Q: List the id of students who registered some courses and the number of their registered courses?\n",
564
- "\n",
565
- " Step 1: DESCRIBE\n",
566
- " Action: student_course_registrations\n",
567
- " Result: Table 'Student_Course_Registrations' columns:\n",
568
- "- student_id: INTEGER\n",
569
- "- course_id: INTEGER\n",
570
- "- registration_date: DATETIME\n",
571
- "Row count: 9\n",
572
- " Reward: +0.0150\n",
573
- "\n",
574
- " Step 2: DESCRIBE\n",
575
- " Action: students\n",
576
- " Result: Table 'Students' columns:\n",
577
- "- student_id: INTEGER\n",
578
- "- student_details: VARCHAR(255)\n",
579
- "Row count: 8\n",
580
- " Reward: +0.0150\n",
581
- "\n",
582
- " Step 3: QUERY\n",
583
- " SQL:\n",
584
- " SELECT T1.student_id , count(*) \n",
585
- " FROM students AS T1 \n",
586
- " JOIN student_course_registrations AS T2 \n",
587
- " ON T1.student_id = T2.student_id \n",
588
- " GROUP BY T1.student_id\n",
589
- " Result: 1. 111 | 1\n",
590
- "2. 121 | 2\n",
591
- "3. 131 | 1\n",
592
- "4. 141 | 2\n",
593
- "5. 151 | 1\n",
594
- "6. 161 | 1\n",
595
- "7. 171 | 1\n",
596
- " Reward: +0.1500\n",
597
- "\n",
598
- " Step 4: ANSWER\n",
599
- " Action: [[111, 1], [121, 2], [131, 1], [141, 2], [151, 1], [161, 1], [171, 1]]\n",
600
- " Result: Answer submitted: correct.\n",
601
- " Reward: +1.0000\n",
602
- "\n",
603
- "Total reward: 1.180\n",
604
- " Exploration (L1+L2): 0.180 (3 steps)\n",
605
- " Terminal (L3): 1.000\n"
606
- ]
607
- }
608
- ],
609
  "source": [
610
  "import re\n",
611
  "\n",
 
 
 
 
 
 
 
612
  "\n",
613
  "def format_sql(sql):\n",
614
  " \"\"\"Simple SQL formatter for display.\"\"\"\n",
@@ -617,14 +617,58 @@
617
  " return formatted\n",
618
  "\n",
619
  "\n",
620
- "# Run one oracle episode and show per-step rewards\n",
621
  "obs = env.reset(seed=0)\n",
622
  "oracle = OraclePolicy(questions)\n",
623
  "step_rewards = []\n",
 
624
  "\n",
625
  "print(f\"Q: {obs.question}\\n\")\n",
626
  "while not obs.done:\n",
627
  " action = oracle.select_action(obs)\n",
 
 
 
 
628
  " obs = env.step(action)\n",
629
  " reward = obs.reward or 0.0\n",
630
  " step_rewards.append(reward)\n",
@@ -640,7 +684,7 @@
640
  " print(f\" Result: {obs.result}\")\n",
641
  " if obs.error:\n",
642
  " print(f\" Error: {obs.error}\")\n",
643
- " print(f\" Reward: {reward:+.4f}\")\n",
644
  " print()\n",
645
  "\n",
646
  "exploration = sum(step_rewards[:-1]) if len(step_rewards) > 1 else 0.0\n",
@@ -655,9 +699,17 @@
655
  "cell_type": "markdown",
656
  "metadata": {},
657
  "source": [
658
- "## 8. Connect to a Deployed Space\n",
  "\n",
660
- "The same environment runs as a Docker container on HuggingFace Spaces. The `SQLEnvClient` connects via WebSocket and provides the same `reset()`/`step()` interface."
661
  ]
662
  },
663
  {
@@ -666,25 +718,77 @@
666
  "metadata": {},
667
  "outputs": [],
668
  "source": [
669
- "# Uncomment to connect to a running Space:\n",
670
- "#\n",
671
- "# from sql_env.client import SQLEnvClient\n",
672
- "# from sql_env.models import SQLAction\n",
673
- "#\n",
674
- "# client = SQLEnvClient(base_url=\"wss://your-space.hf.space\")\n",
675
- "# client.connect()\n",
676
- "# result = client.reset(seed=42)\n",
677
- "# obs = result.observation\n",
678
- "# print(\"Question:\", obs.question)\n",
679
- "# print(\"Schema:\", obs.schema_info)\n",
680
- "#\n",
681
- "# # Same actions work over the wire:\n",
682
- "# step = client.step(SQLAction(action_type=\"DESCRIBE\", argument=\"employees\"))\n",
683
- "# print(\"Result:\", step.observation.result)\n",
684
- "#\n",
685
- "# client.close()\n",
686
- "\n",
687
- "print(\"Uncomment the cell above and set your HF Space URL to connect remotely.\")"
688
  ]
689
  },
690
  {
 
29
  },
30
  {
31
  "cell_type": "code",
32
+ "execution_count": null,
33
  "metadata": {},
34
+ "outputs": [],
35
  "source": [
36
  "import os\n",
37
+ "import subprocess\n",
38
  "import sys\n",
39
  "from pathlib import Path\n",
40
  "\n",
41
+ "IN_COLAB = \"google.colab\" in sys.modules\n",
42
+ "\n",
43
+ "if IN_COLAB:\n",
44
+ " # Colab: clone the repo, install the package, fetch Spider databases.\n",
45
+ " # Requires a GITHUB_TOKEN in Colab userdata if the repo is private.\n",
46
+ " from google.colab import userdata\n",
47
+ "\n",
48
+ " token = userdata.get(\"GITHUB_TOKEN\")\n",
49
+ " BRANCH = \"main\" # @param {type:\"string\"}\n",
50
+ " repo_url = f\"https://{token}@github.com/hjerpe/sql-env.git\" if token else \"https://github.com/hjerpe/sql-env.git\"\n",
51
+ "\n",
52
+ " if Path(\"sql-env\").exists():\n",
53
+ " subprocess.check_call([\"git\", \"-C\", \"sql-env\", \"pull\", \"-q\"])\n",
54
+ " else:\n",
55
+ " subprocess.check_call([\"git\", \"clone\", \"-q\", \"-b\", BRANCH, repo_url])\n",
56
+ " os.chdir(\"sql-env\")\n",
57
+ "\n",
58
+ " print(\"Colab detected: installing dependencies...\")\n",
59
+ " subprocess.check_call(\n",
60
+ " [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--upgrade\", \"pip\"]\n",
61
+ " )\n",
62
+ " subprocess.check_call(\n",
63
+ " [\n",
64
+ " sys.executable,\n",
65
+ " \"-m\",\n",
66
+ " \"pip\",\n",
67
+ " \"install\",\n",
68
+ " \"-q\",\n",
69
+ " \"--no-deps\",\n",
70
+ " \"--force-reinstall\",\n",
71
+ " \".\",\n",
72
+ " ]\n",
73
+ " )\n",
74
+ " subprocess.check_call(\n",
75
+ " [\n",
76
+ " sys.executable,\n",
77
+ " \"-m\",\n",
78
+ " \"pip\",\n",
79
+ " \"install\",\n",
80
+ " \"-q\",\n",
81
+ " \"openenv-core[core]>=0.2.1\",\n",
82
+ " \"pydantic>=2.0.0\",\n",
83
+ " \"jmespath\",\n",
84
+ " ]\n",
85
+ " )\n",
86
+ " # Download Spider SQLite databases the notebook reads from\n",
87
+ " subprocess.check_call(\n",
88
+ " [sys.executable, \"scripts/download_spider_databases.py\", \"--db-id\", \"all\"]\n",
89
+ " )\n",
90
  "\n",
91
+ " PROJECT_ROOT = Path.cwd()\n",
92
+ "else:\n",
93
+ " # Local: walk up from CWD to find the project root\n",
94
+ " def find_project_root() -> Path:\n",
95
+ " \"\"\"Walk up from CWD until pyproject.toml is found.\"\"\"\n",
96
+ " for parent in [Path.cwd(), *Path.cwd().parents]:\n",
97
+ " if (parent / \"pyproject.toml\").exists():\n",
98
+ " return parent\n",
99
+ " raise FileNotFoundError(\n",
100
+ " \"Could not locate project root (no pyproject.toml found)\"\n",
101
+ " )\n",
102
  "\n",
103
+ " PROJECT_ROOT = find_project_root()\n",
104
+ " os.chdir(PROJECT_ROOT)\n",
105
  "\n",
 
 
106
  "if str(PROJECT_ROOT) not in sys.path:\n",
107
  " sys.path.insert(0, str(PROJECT_ROOT))\n",
108
  "\n",
 
 
 
 
109
  "print(f\"Project root: {PROJECT_ROOT}\")"
110
  ]
111
  },
 
596
  },
597
  {
598
  "cell_type": "code",
599
+ "execution_count": null,
600
  "metadata": {},
601
+ "outputs": [],
602
  "source": [
603
  "import re\n",
604
  "\n",
605
+ "from server.reward import (\n",
606
+ " _EXEC_OK_REWARD,\n",
607
+ " _NEW_INFO_REWARD,\n",
608
+ " _REPEAT_PENALTY,\n",
609
+ " _STEP_COST,\n",
610
+ ")\n",
611
+ "\n",
612
  "\n",
613
  "def format_sql(sql):\n",
614
  " \"\"\"Simple SQL formatter for display.\"\"\"\n",
 
617
  " return formatted\n",
618
  "\n",
619
  "\n",
620
+ "def explain_reward(action_type, error, is_repeat_query, total_reward):\n",
621
+ " \"\"\"Decompose a step reward into labeled components.\n",
622
+ "\n",
623
+ " Layer 1 components (step_cost, exec_ok, new_info, repeat_penalty) are\n",
624
+ " deterministic from action type + state, so we reconstruct them exactly\n",
625
+ " from the reward constants imported above. Layer 2 (progress delta on\n",
626
+ " QUERY) and Layer 3 (terminal on ANSWER) are not exposed in the\n",
627
+ " observation, so we recover them as 'total reward minus L1 sum' and\n",
628
+ " label them accordingly. The clip range [-0.10, +0.15] may adjust the\n",
629
+ " final value — any residual after layer reconstruction is labeled\n",
630
+ " 'clip_adjust'.\n",
631
+ " \"\"\"\n",
632
+ " at = action_type.upper()\n",
633
+ " parts = [(\"step_cost\", -_STEP_COST)] # always applied\n",
634
+ "\n",
635
+ " if error:\n",
636
+ " pass # no exec_ok when the action errored\n",
637
+ " elif at == \"QUERY\" and is_repeat_query:\n",
638
+ " parts.append((\"repeat_penalty\", -_REPEAT_PENALTY))\n",
639
+ " else:\n",
640
+ " parts.append((\"exec_ok\", +_EXEC_OK_REWARD))\n",
641
+ " if at == \"QUERY\":\n",
642
+ " parts.append((\"new_info\", +_NEW_INFO_REWARD))\n",
643
+ "\n",
644
+ " l1_sum = sum(v for _, v in parts)\n",
645
+ " remainder = total_reward - l1_sum\n",
646
+ "\n",
647
+ " if abs(remainder) > 1e-9:\n",
648
+ " if at == \"ANSWER\":\n",
649
+ " parts.append((\"terminal\", remainder))\n",
650
+ " elif at == \"QUERY\":\n",
651
+ " parts.append((\"layer2_progress\", remainder))\n",
652
+ " else:\n",
653
+ " parts.append((\"clip_adjust\", remainder))\n",
654
+ "\n",
655
+ " labels = \" + \".join(f\"{name}({v:+.3f})\" for name, v in parts)\n",
656
+ " return f\"{labels} = {total_reward:+.4f}\"\n",
657
+ "\n",
658
+ "\n",
659
+ "# Run one oracle episode and show per-step rewards with component breakdown\n",
660
  "obs = env.reset(seed=0)\n",
661
  "oracle = OraclePolicy(questions)\n",
662
  "step_rewards = []\n",
663
+ "seen_queries: set[str] = set()\n",
664
  "\n",
665
  "print(f\"Q: {obs.question}\\n\")\n",
666
  "while not obs.done:\n",
667
  " action = oracle.select_action(obs)\n",
668
+ " is_repeat = action.action_type.upper() == \"QUERY\" and action.argument in seen_queries\n",
669
+ " if action.action_type.upper() == \"QUERY\":\n",
670
+ " seen_queries.add(action.argument)\n",
671
+ "\n",
672
  " obs = env.step(action)\n",
673
  " reward = obs.reward or 0.0\n",
674
  " step_rewards.append(reward)\n",
 
684
  " print(f\" Result: {obs.result}\")\n",
685
  " if obs.error:\n",
686
  " print(f\" Error: {obs.error}\")\n",
687
+ " print(f\" Reward: {explain_reward(action.action_type, obs.error, is_repeat, reward)}\")\n",
688
  " print()\n",
689
  "\n",
690
  "exploration = sum(step_rewards[:-1]) if len(step_rewards) > 1 else 0.0\n",
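The episode summary above reconciles exploration and terminal rewards against the per-step values. As a quick standalone arithmetic check, here is that decomposition with the numbers from the example oracle run (values copied from that run's printed output, not recomputed from `server.reward`):

```python
# Per-step rewards from the example oracle episode: two DESCRIBE steps
# (+0.015 each), one successful QUERY (+0.15, at the top of the clip
# range), and a correct ANSWER (+1.0).
step_rewards = [0.015, 0.015, 0.15, 1.0]

exploration = sum(step_rewards[:-1])  # Layer 1 + Layer 2 components
terminal = step_rewards[-1]           # Layer 3 terminal reward

print(f"Total reward: {exploration + terminal:.3f}")  # Total reward: 1.180
print(f"Exploration: {exploration:.3f}")              # Exploration: 0.180
print(f"Terminal: {terminal:.3f}")                    # Terminal: 1.000
```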
 
699
  "cell_type": "markdown",
700
  "metadata": {},
701
  "source": [
702
+ "## 8. Same Environment, Over the Wire\n",
703
+ "\n",
704
+ "The same `SQLEnvironment` runs as a Docker container on HuggingFace Spaces:\n",
705
+ "[**huggingface.co/spaces/hjerpe/sql_env**](https://huggingface.co/spaces/hjerpe/sql_env).\n",
706
+ "`SQLEnvClient` connects via WebSocket and provides the same `reset()` /\n",
707
+ "`step()` interface we used above — same action space, same observation shape,\n",
708
+ "same reward model. The only difference is that the SQLite database and\n",
709
+ "reward computation now live on a remote container instead of in this\n",
710
+ "Python process.\n",
711
  "\n",
712
+ "The cell below drives one full episode against the live Space."
713
  ]
714
  },
715
  {
 
718
  "metadata": {},
719
  "outputs": [],
720
  "source": [
721
+ "from sql_env.client import SQLEnvClient\n",
722
+ "from sql_env.models import SQLAction\n",
723
+ "\n",
724
+ "# Live hosted Space. This is the URL anyone in the world can point a client\n",
725
+ "# at — no local setup required. The first request may take ~30s if the\n",
726
+ "# container is cold-starting.\n",
727
+ "SPACE_URL = \"https://hjerpe-sql-env.hf.space\"\n",
728
+ "\n",
729
+ "print(f\"Connecting to {SPACE_URL} ...\\n\")\n",
730
+ "\n",
731
+ "# openenv-core's SQLEnvClient is sync-by-default in older versions but\n",
732
+ "# async-by-default in newer ones (the newer API exposes .sync() as an\n",
733
+ "# explicit synchronous wrapper). Detect at runtime so the cell works on\n",
734
+ "# both local dev installs and Colab's pinned >=0.2.1 version.\n",
735
+ "_remote_client = SQLEnvClient(base_url=SPACE_URL)\n",
736
+ "_remote_ctx = _remote_client.sync() if hasattr(_remote_client, \"sync\") else _remote_client\n",
737
+ "\n",
738
+ "try:\n",
739
+ " with _remote_ctx as remote_env:\n",
740
+ " # --- reset ---\n",
741
+ " result = remote_env.reset()\n",
742
+ " remote_obs = result.observation\n",
743
+ " print(f\"Q: {remote_obs.question}\")\n",
744
+ " tables = [\n",
745
+ " line.lstrip(\"- \").strip()\n",
746
+ " for line in remote_obs.schema_info.splitlines()[1:]\n",
747
+ " if line.strip()\n",
748
+ " ]\n",
749
+ " print(f\"Tables: {tables}\\n\")\n",
750
+ "\n",
751
+ " first_table = tables[0]\n",
752
+ "\n",
753
+ " # --- describe ---\n",
754
+ " result = remote_env.step(\n",
755
+ " SQLAction(action_type=\"DESCRIBE\", argument=first_table)\n",
756
+ " )\n",
757
+ " print(f\"DESCRIBE {first_table}\")\n",
758
+ " print(f\" reward: {result.observation.reward:+.4f}\")\n",
759
+ " # Line-based preview so truncation never cuts mid-word\n",
760
+ " _lines = result.observation.result.splitlines()\n",
761
+ " _preview = \"\\n \".join(_lines[:6])\n",
762
+ " _more = (\n",
763
+ " f\"\\n ... ({len(_lines) - 6} more lines)\"\n",
764
+ " if len(_lines) > 6\n",
765
+ " else \"\"\n",
766
+ " )\n",
767
+ " print(f\" result: {_preview}{_more}\\n\")\n",
768
+ "\n",
769
+ " # --- query ---\n",
770
+ " query_sql = f\"SELECT COUNT(*) FROM {first_table}\"\n",
771
+ " result = remote_env.step(\n",
772
+ " SQLAction(action_type=\"QUERY\", argument=query_sql)\n",
773
+ " )\n",
774
+ " print(f\"QUERY {query_sql}\")\n",
775
+ " print(f\" reward: {result.observation.reward:+.4f}\")\n",
776
+ " print(f\" result: {result.observation.result}\\n\")\n",
777
+ "\n",
778
+ " # --- answer (intentionally wrong — we're demoing plumbing, not correctness) ---\n",
779
+ " result = remote_env.step(\n",
780
+ " SQLAction(action_type=\"ANSWER\", argument=\"demo\")\n",
781
+ " )\n",
782
+ " print(\"ANSWER demo\")\n",
783
+ " print(f\" done: {result.observation.done}\")\n",
784
+ " print(f\" reward: {result.observation.reward:+.4f}\")\n",
785
+ " print(\"\\nSame action space, same observation shape, same rewards — just running remotely.\")\n",
786
+ "except Exception as exc: # noqa: BLE001 — demo cell should not crash the notebook\n",
787
+ " print(f\"Remote call failed: {type(exc).__name__}: {exc}\")\n",
788
+ " print(\n",
789
+ " \"If the Space is sleeping, the first request usually wakes it. \"\n",
790
+ " \"Retry in ~30s, or skip this cell to run the notebook fully offline.\"\n",
791
+ " )"
792
  ]
793
  },
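The table-list parsing in the remote cell assumes `schema_info` is a header line followed by `- <table>` bullets; that format is inferred from the parsing code, not a documented contract. A standalone check of the parsing on a hypothetical `schema_info` string:

```python
# Hypothetical schema_info payload (an assumption: header line plus
# "- <table>" bullets, mirroring what the parsing code expects).
schema_info = "Database: student_assessment\n- Students\n- Student_Course_Registrations"

# Skip the header, strip the bullet prefix, drop empty lines.
tables = [
    line.lstrip("- ").strip()
    for line in schema_info.splitlines()[1:]
    if line.strip()
]
print(tables)  # ['Students', 'Student_Course_Registrations']
```

Note that `str.lstrip("- ")` strips any run of `-` and space characters from the left, which is safe here because table names don't begin with those characters.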
794
  {
scripts/download_spider_questions.py ADDED
@@ -0,0 +1,106 @@
1
+ """
2
+ Script to download Spider dataset questions for specific databases.
3
+
4
+ Usage:
5
+ python scripts/download_spider_questions.py --db-id student_assessment
6
+ python scripts/download_spider_questions.py --db-id student_assessment --split validation
7
+ python scripts/download_spider_questions.py --db-id all # downloads all db_ids
8
+ """
9
+
10
+ import json
11
+ import argparse
12
+ from pathlib import Path
13
+ from datasets import load_dataset
14
+
15
+
16
+ def download_spider_questions(
17
+ db_id: str = "student_assessment",
18
+ split: str = "train",
19
+ output_dir: str = "data/questions",
20
+ ) -> None:
21
+ """Download Spider dataset questions for specified database(s).
22
+
23
+ Args:
24
+ db_id: Database ID to filter by, or "all" to get all databases
25
+ split: Dataset split ("train" or "validation")
26
+ output_dir: Directory to save JSON files
27
+ """
28
+ output_path = Path(output_dir)
29
+ output_path.mkdir(parents=True, exist_ok=True)
30
+
31
+ print(f"Loading Spider dataset ({split} split)...")
32
+ dataset = load_dataset("xlangai/spider", split=split)
33
+
34
+ if db_id.lower() == "all":
35
+ # Group by db_id
36
+ grouped = {}
37
+ for item in dataset:
38
+ current_db_id = item.get("db_id")
39
+ if current_db_id not in grouped:
40
+ grouped[current_db_id] = []
41
+ grouped[current_db_id].append(item)
42
+
43
+ total_questions = 0
44
+ for current_db_id, questions in grouped.items():
45
+ filepath = output_path / f"{current_db_id}.json"
46
+ with open(filepath, "w") as f:
47
+ json.dump(questions, f, indent=2)
48
+ print(f" {current_db_id}: {len(questions)} questions → {filepath}")
49
+ total_questions += len(questions)
50
+
51
+ print(f"\nTotal: {total_questions} questions across {len(grouped)} databases")
52
+ else:
53
+ # Filter for specific db_id
54
+ filtered_data = [item for item in dataset if item.get("db_id") == db_id]
55
+
56
+ if not filtered_data:
57
+ print(f"No questions found for db_id='{db_id}'")
58
+ return
59
+
60
+ filepath = output_path / f"{db_id}.json"
61
+ with open(filepath, "w") as f:
62
+ json.dump(filtered_data, f, indent=2)
63
+
64
+ print(f"Found {len(filtered_data)} questions for db_id='{db_id}'")
65
+ print(f"Saved to {filepath}")
66
+
67
+ # Print sample
68
+ if filtered_data:
69
+ sample = filtered_data[0]
70
+ print("\nFirst question sample:")
71
+ print(
72
+ json.dumps(
73
+ {k: v for k, v in sample.items() if k != "evidence"}, indent=2
74
+ )
75
+ )
76
+
77
+
78
+ if __name__ == "__main__":
79
+ parser = argparse.ArgumentParser(
80
+ description="Download Spider dataset questions for specific databases",
81
+ formatter_class=argparse.RawDescriptionHelpFormatter,
82
+ )
83
+ parser.add_argument(
84
+ "--db-id",
85
+ type=str,
86
+ default="student_assessment",
87
+ help="Database ID to filter by (or 'all' for all databases)",
88
+ )
89
+ parser.add_argument(
90
+ "--split",
91
+ type=str,
92
+ default="train",
93
+ choices=["train", "validation"],
94
+ help="Dataset split to download",
95
+ )
96
+ parser.add_argument(
97
+ "--output-dir",
98
+ type=str,
99
+ default="data/questions",
100
+ help="Directory to save JSON files",
101
+ )
102
+
103
+ args = parser.parse_args()
104
+ download_spider_questions(
105
+ db_id=args.db_id, split=args.split, output_dir=args.output_dir
106
+ )
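The grouping loop in the script's `all` branch is correct as written; for reference, the same grouping reads a little more idiomatically with `collections.defaultdict`. A sketch over hypothetical records (the `db_id`/`question` keys mirror the record shape the script filters on):

```python
from collections import defaultdict

# Hypothetical records mirroring the db_id/question shape used above.
dataset = [
    {"db_id": "student_assessment", "question": "q1"},
    {"db_id": "flight_2", "question": "q2"},
    {"db_id": "student_assessment", "question": "q3"},
]

# defaultdict(list) removes the explicit "if key not in grouped" check.
grouped = defaultdict(list)
for item in dataset:
    grouped[item["db_id"]].append(item)

print({k: len(v) for k, v in grouped.items()})
# {'student_assessment': 2, 'flight_2': 1}
```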
server/app.py CHANGED
@@ -80,7 +80,16 @@ def create_sql_environment():
80
  )
81
 
82
 
83
- # Create the FastAPI app
84
  app = create_app(
85
  create_sql_environment,
86
  SQLAction,
 
80
  )
81
 
82
 
83
+ # Create the FastAPI app.
84
+ #
85
+ # Note: hosted Space is single-session. External users running TRL's
86
+ # GRPOTrainer against https://hjerpe-sql-env.hf.space with
87
+ # num_generations > 1 will hit openenv-core's default 1-session cap.
88
+ # Fix requires (a) auditing SQLEnvironment for shared mutable state
89
+ # across sessions, (b) declaring SUPPORTS_CONCURRENT_SESSIONS=True on
90
+ # the class, (c) passing max_concurrent_envs=64 here. Deferred as a
91
+ # post-launch follow-up. Our own training uses an in-process
92
+ # SQLEnvironment via SQLEnvTRL, so this does not affect internal runs.
93
  app = create_app(
94
  create_sql_environment,
95
  SQLAction,
sql_env.egg-info/PKG-INFO CHANGED
@@ -3,6 +3,7 @@ Name: sql-env
3
  Version: 0.1.0
4
  Summary: Interactive SQL exploration RL environment for the OpenEnv Challenge
5
  Requires-Python: <3.13,>=3.11
 
6
  Requires-Dist: openenv-core[core]>=0.2.1
7
  Requires-Dist: pydantic>=2.0.0
8
  Requires-Dist: fastapi>=0.104.0
@@ -24,3 +25,4 @@ Requires-Dist: trl>=0.29.0; extra == "training"
24
  Requires-Dist: accelerate>=0.34.0; extra == "training"
25
  Requires-Dist: notebook>=7.5.5; extra == "training"
26
  Requires-Dist: matplotlib>=3.7.0; extra == "training"
 
 
3
  Version: 0.1.0
4
  Summary: Interactive SQL exploration RL environment for the OpenEnv Challenge
5
  Requires-Python: <3.13,>=3.11
6
+ License-File: LICENSE
7
  Requires-Dist: openenv-core[core]>=0.2.1
8
  Requires-Dist: pydantic>=2.0.0
9
  Requires-Dist: fastapi>=0.104.0
 
25
  Requires-Dist: accelerate>=0.34.0; extra == "training"
26
  Requires-Dist: notebook>=7.5.5; extra == "training"
27
  Requires-Dist: matplotlib>=3.7.0; extra == "training"
28
+ Dynamic: license-file
sql_env.egg-info/SOURCES.txt CHANGED
@@ -1,8 +1,5 @@
 
1
  README.md
2
- __init__.py
3
- client.py
4
- conftest.py
5
- models.py
6
  pyproject.toml
7
  ./__init__.py
8
  ./client.py
@@ -16,9 +13,9 @@ evaluation/oracle_policy.py
16
  evaluation/policies.py
17
  server/__init__.py
18
  server/app.py
 
19
  server/reward.py
20
  server/sql_environment.py
21
- server/test_sql_env.py
22
  server/verifier.py
23
  sql_env.egg-info/PKG-INFO
24
  sql_env.egg-info/SOURCES.txt
@@ -38,6 +35,5 @@ training/data_loading.py
38
  training/few_shot_examples.py
39
  training/notebook_pipeline.py
40
  training/prompts.py
41
- training/rewards.py
42
  training/trl_adapter.py
43
  training/visualization.py
 
1
+ LICENSE
2
  README.md
 
 
 
 
3
  pyproject.toml
4
  ./__init__.py
5
  ./client.py
 
13
  evaluation/policies.py
14
  server/__init__.py
15
  server/app.py
16
+ server/mock_tokenizer.py
17
  server/reward.py
18
  server/sql_environment.py
 
19
  server/verifier.py
20
  sql_env.egg-info/PKG-INFO
21
  sql_env.egg-info/SOURCES.txt
 
35
  training/few_shot_examples.py
36
  training/notebook_pipeline.py
37
  training/prompts.py
 
38
  training/trl_adapter.py
39
  training/visualization.py