Upload folder using huggingface_hub
- DATA_LICENSE +19 -0
- LICENSE +21 -0
- README.md +86 -108
- docs/ARCHITECTURE.md +8 -12
- docs/data-sources.md +31 -20
- notebooks/showcase_sqlenv.ipynb +201 -97
- scripts/download_spider_questions.py +106 -0
- server/app.py +10 -1
- sql_env.egg-info/PKG-INFO +2 -0
- sql_env.egg-info/SOURCES.txt +2 -6
DATA_LICENSE
ADDED
@@ -0,0 +1,19 @@
+Data License Notice
+
+Data in data/ is adapted from the Spider dataset (Yu et al., 2018),
+distributed under CC BY-SA 4.0.
+
+We retrieved question/SQL pairs from the xlangai/spider HuggingFace mirror
+and SQLite databases from the taoyds/spider GitHub mirror, then curated a
+10-database subset, derived gold answers by executing the gold SQL, and
+generated SFT trajectories from those artifacts.
+
+Derived data in data/ is shared under CC BY-SA 4.0.
+Software code is licensed separately under MIT (see LICENSE).
+
+References:
+- Spider dataset: https://yale-lily.github.io/spider
+- Yu et al. (2018). Spider: A Large-Scale Human-Labeled Dataset for Complex
+  and Cross-Domain Semantic Parsing and Text-to-SQL Task. EMNLP.
+- xlangai/spider on HuggingFace: https://huggingface.co/datasets/xlangai/spider
+- taoyds/spider on GitHub: https://github.com/taoyds/spider
LICENSE
ADDED
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2026 Adam Hjerpe
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
README.md
CHANGED
@@ -5,168 +5,146 @@ colorFrom: blue
 colorTo: green
 sdk: docker
 app_port: 8000
-pinned:
 base_path: /web
 ---

-# SQLEnv: Teaching

 
 

-SQLEnv is an

-Built

-**[

 ## Quick Start

-Run these three commands to install, validate, and smoke-test the environment:
-
 ```bash
 uv sync
-uv run openenv validate --verbose
 uv run pytest tests/ -v
 ```

-

 ```bash
 uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
 ```

-

 ```bash
-docker build -t
-docker run -p 8000:8000
-```
-
-## Why SQLEnv
-
-Static text-to-SQL benchmarks reward final outputs, not reasoning quality. SQLEnv turns SQL generation into an interactive decision process with feedback at each step, making it suitable for RL training and behavior analysis.
-
-## Architecture
-
-```text
-+-------------+    WebSocket     +----------------------+      SQLite
-| RL Agent    | <--------------> | SQLEnvClient         | <----------------+
-| (GRPO/TRL)  |                  | (client.py)          |                  |
-+-------------+                  +----------+-----------+                  |
-                                 HTTP/WebSocket                            |
-                                            |                              |
-                                            v                              |
-                                 +--------------------------+              |
-                                 | FastAPI Server           |              |
-                                 | (server.app:app)         |              |
-                                 +------------+-------------+              |
-                                            |                              |
-                                            v                              |
-                                 +--------------------------+              |
-                                 | SQLEnvironment           |--------------+
-                                 | step/reset/reward/verify |
-                                 +--------------------------+
-```

 ## How It Works

-Each episode

-| Action | Purpose |
-|--------|---------|
-| `DESCRIBE
-| `SAMPLE
-| `QUERY
-| `ANSWER value` | Submit final answer

-
-1. `reset()` returns question context and available tables.
-2. `step()` executes one exploration action at a time.
-3. `ANSWER` ends the episode with correctness-based terminal reward.

-

-

-##

-

 ```bash
 docker build -f Dockerfile.test -t sqlenv-test .
 docker run --rm sqlenv-test
 ```

-
-
-### Colab training (GPU)

-

-
-
-The notebook uses `configs/colab_l4.json` settings: batch size 4, 4 generations per prompt, bf16 precision. Live reward plots and execution traces update during training.
-
-### What the model sees
-
-Each episode, TRL injects tool schemas into the prompt. The model generates structured tool calls:

 ```
-<tool_call>{"name": "describe", "arguments": {"table_name": "employee"}}</tool_call>
-```

-

-

-
-- `test_cpu.json` — 2 steps, 256 tokens, budget 3 (local validation)
-- `colab_l4.json` — full epoch, 512 tokens, budget 10, bf16 (L4 GPU)

-

-
-- Health check: `curl https://<space-url>/health`
-- Deploy command: `uv run openenv push`

 ## Project Structure

-```
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 ```

-##

-
-
-
-

-##

-
-- OpenEnv docs: https://meta-pytorch.org/OpenEnv/
-- Spider dataset: https://huggingface.co/datasets/xlangai/spider
-- TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv
-- Verification plan: `specs/F007-VERIFICATION_SPEC.md`
 colorTo: green
 sdk: docker
 app_port: 8000
+pinned: true
 base_path: /web
 ---

+# SQLEnv: Teaching Small Models to Explore Databases

 
 
+

+SQLEnv is an RL environment for training small language models to answer questions about SQL databases through iterative exploration. Instead of producing one-shot SQL from a fully visible schema, the agent discovers the schema step by step using four tools: DESCRIBE, SAMPLE, QUERY, and ANSWER.

+Built on [OpenEnv](https://github.com/meta-pytorch/OpenEnv) and trained with [TRL](https://huggingface.co/docs/trl)'s GRPO implementation. A 0.6B parameter model trained in this environment goes from 0% to ~30% accuracy on a curated Spider subset, learning to explore schemas, recover from SQL errors, and format answers correctly.

+**[Blog post](https://hjerpe-sqlenv-blog.static.hf.space)** | **[Live environment](https://huggingface.co/spaces/hjerpe/sql_env)** | **[Training notebook](notebooks/train_grpo.ipynb)**

 ## Quick Start

 ```bash
 uv sync
 uv run pytest tests/ -v
 ```

+Run the environment locally:

 ```bash
 uv run uvicorn server.app:app --reload --host 0.0.0.0 --port 8000
 ```

+Or with Docker:

 ```bash
+docker build -t sqlenv:latest -f server/Dockerfile .
+docker run -p 8000:8000 sqlenv:latest
 ```

 ## How It Works

+Each episode starts with a natural-language question and a list of table names. The schema (columns, types, relationships) is hidden. The agent uses four actions to explore:

+| Action | Purpose |
+|--------|---------|
+| `DESCRIBE table` | Reveal column names, types, and row count |
+| `SAMPLE table` | Preview representative rows |
+| `QUERY sql` | Execute read-only SQL |
+| `ANSWER value` | Submit a final answer (ends episode) |

+The environment provides dense reward at each step (operational feedback + progress toward the answer) and a terminal reward for correctness (+1.0 correct, 0.0 wrong). See the [blog post](https://hjerpe-sqlenv-blog.static.hf.space) for details on the reward architecture.

+```python
+from server.sql_environment import SQLEnvironment, SQLAction
+
+env = SQLEnvironment(questions_path="data/questions/questions_train.json",
+                     db_dir="data/databases", tokenizer=tok)
+obs = env.reset(seed=42)
+obs = env.step(SQLAction(action_type="DESCRIBE", argument="employee"))
+obs = env.step(SQLAction(action_type="QUERY", argument="SELECT COUNT(*) FROM employee"))
+obs = env.step(SQLAction(action_type="ANSWER", argument="10"))
+# obs.done=True, obs.reward=1.0
+```

+## Training

+We train [Qwen3-0.6B](https://arxiv.org/abs/2505.09388) using [GRPO](https://arxiv.org/abs/2402.03300) (from DeepSeekMath) through TRL's `environment_factory`. The full pipeline (SFT warmup + two-phase GRPO) runs in ~5 hours on a single Colab L4.
+
+**Notebooks:**
+- **[train_grpo.ipynb](notebooks/train_grpo.ipynb)** runs the full SFT + GRPO pipeline
+- **[compare_methods.ipynb](notebooks/compare_methods.ipynb)** evaluates base vs trained models
+- **[showcase_sqlenv.ipynb](notebooks/showcase_sqlenv.ipynb)** lets you explore the environment interactively
+
+**Local test (CPU, ~3 min):**

 ```bash
 docker build -f Dockerfile.test -t sqlenv-test .
 docker run --rm sqlenv-test
 ```

+## Evaluation

+All evaluation runs through the Green Agent evaluator:

+```python
+from sql_env.evaluation import evaluate, RandomPolicy, OraclePolicy
+
+result = evaluate(env, policy, n_episodes=50, seed=0)
+print(f"Accuracy: {result.success_rate:.1%}, Reward: {result.avg_reward:.3f}")
 ```

+Results on our curated 10-database Spider subset (N=50, 2 runs):
+
+| Method | Accuracy | Parse Rate | Avg Steps |
+|--------|----------|------------|-----------|
+| Zero-shot | 0% | 24-28% | 10.8-12.4 |
+| 1-shot | 0-2% | 16-17% | 14.0-14.8 |
+| 3-shot | 0% | 19-20% | 13.8-14.8 |
+| GRPO v1 (2 epochs) | 28-30% | 95-100% | 3.5-4.0 |
+| GRPO v2 (4 epochs) | 24-32% | 87-95% | 3.5-4.0 |

+This evaluation is not comparable to the official Spider leaderboard, which uses different scoring, full-schema input, and a broader database set. See the [blog post](https://hjerpe-sqlenv-blog.static.hf.space) for detailed analysis.

+## Data

+676 questions (473 train, 203 eval) across 10 Spider databases with difficulty labels, plus 120 multi-turn SFT warmup trajectories generated from gold SQL. See [docs/data-sources.md](docs/data-sources.md) for full details on provenance, curation, and regeneration.

+Data in `data/` is adapted from [Spider](https://yale-lily.github.io/spider) (Yu et al., 2018) and shared under CC BY-SA 4.0. See [DATA_LICENSE](DATA_LICENSE).

 ## Project Structure

+```
+sqlenv/
+├── __init__.py, client.py, models.py   # Core types and client
+├── server/
+│   ├── app.py                          # FastAPI server
+│   ├── sql_environment.py              # Environment implementation
+│   ├── reward.py                       # Three-layer reward function
+│   ├── verifier.py                     # Answer verification
+│   └── Dockerfile                      # HF Spaces deployment
+├── evaluation/                         # Green Agent evaluator, policies
+├── training/                           # TRL adapter, data loading
+├── scripts/                            # Data curation, SFT generation
+├── notebooks/                          # Training, evaluation, showcase
+├── data/
+│   ├── databases/                      # 10 Spider SQLite databases
+│   ├── questions/                      # Train/eval question sets
+│   └── sft/                            # SFT warmup trajectories
+├── configs/                            # Training configurations
+├── tests/                              # Unit and integration tests
+└── docs/
+    ├── data-sources.md                 # Data provenance
+    └── ARCHITECTURE.md                 # System architecture
 ```

+## References

+- Yu et al. (2018). [Spider: A Large-Scale Human-Labeled Dataset for Complex and Cross-Domain Semantic Parsing and Text-to-SQL Task](https://yale-lily.github.io/spider). EMNLP.
+- Shao et al. (2024). [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300). (GRPO algorithm)
+- Ng, Harada, Russell (1999). [Policy Invariance Under Reward Transformations](https://people.eecs.berkeley.edu/~pabbeel/cs287-fa09/readings/NgHaradaRussell-shaping-ICML1999.pdf). ICML.
+- [OpenEnv framework](https://github.com/meta-pytorch/OpenEnv)
+- [TRL OpenEnv docs](https://huggingface.co/docs/trl/openenv)

+## License

+Code: [MIT](LICENSE). Data: [CC BY-SA 4.0](DATA_LICENSE).
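The layered reward the README describes (small dense rewards for exploration, a larger reward for a successful query, and a terminal +1.0/0.0 for answer correctness) can be sketched as below. The function name and structure are illustrative assumptions, not the repo's `reward.py`; the magnitudes (0.015 per schema discovery, 0.15 per successful query, 1.0 terminal) match the episode trace shown in `showcase_sqlenv.ipynb`.

```python
# Illustrative sketch of a three-layer reward (assumption: not SQLEnv's
# actual implementation; magnitudes taken from the showcase notebook trace).
def step_reward(action_type, succeeded, is_correct=False):
    """Return the reward for one environment step."""
    if action_type == "ANSWER":
        # Terminal layer: +1.0 for a correct answer, 0.0 otherwise.
        return 1.0 if is_correct else 0.0
    if not succeeded:
        # Failed SQL or unknown table: no positive signal.
        return 0.0
    # Dense exploration layers: schema discovery earns a small bonus,
    # a successful query earns more.
    return 0.15 if action_type == "QUERY" else 0.015

# The four-step episode from the showcase notebook:
total = sum([
    step_reward("DESCRIBE", True),
    step_reward("DESCRIBE", True),
    step_reward("QUERY", True),
    step_reward("ANSWER", True, is_correct=True),
])
print(round(total, 3))  # 1.18, matching "Total reward: 1.180" in the trace
```

Keeping the terminal layer dominant means shaped exploration rewards guide early learning without letting the policy farm exploration bonuses instead of answering.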
docs/ARCHITECTURE.md
CHANGED
@@ -2,7 +2,7 @@

 > Last updated: 2026-03-29

-System map for SQLEnv

 **Goals:**
 - Show how components connect (system map + key flows)

@@ -11,8 +11,8 @@ System map for SQLEnv — an RL environment where agents learn interactive SQL e
 - Keep invariants legible (what must stay true)

 **Non-goals:**
-
-

 ---

@@ -45,7 +45,7 @@ System map for SQLEnv — an RL environment where agents learn interactive SQL e
 ──────────          │ DESCRIBE → PRAGMA  │
 +──────────────+    │ SAMPLE → SELECT N  │
 │ evaluate()   │──> env.reset/step │ QUERY → SQL exec │
-│ policies
 │ .py          │    +────────┬───────────────+
 +──────────────+             │
 │                            v

@@ -267,7 +267,7 @@ class EpisodeContext:
     cumulative_new_info_reward: float = 0.0
 ```

-**POMDP design:** The agent sees `SQLObservation`

 ---

@@ -348,7 +348,7 @@ except ImportError:
 | `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON |
 | `DB_DIR` | No | `data/databases/` | SQLite database directory |
 | `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer |
-| `PORT` | No | `8000` | Server port

 ---

@@ -431,7 +431,7 @@ uv run openenv build # build Docker image
 uv run openenv push  # push to HF Spaces
 ```

-The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `appuser`, bundles Spider databases, and exposes

 ---

@@ -451,7 +451,7 @@ The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `app
 |------|------------|
 | Episode | One question-answering session: reset -> N steps -> terminal |
 | Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER |
-| POMDP | Partially observable MDP
 | Spider | Academic text-to-SQL benchmark dataset (10 DBs used) |
 | OpenEnv | Meta's RL environment framework (Environment, EnvClient) |
 | Green Agent | OpenEnv's evaluation wrapper pattern |

@@ -464,10 +464,6 @@ The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `app

 ## References

-- Docs index: `docs/README.md`
-- Operations: `docs/RUNBOOK.md`
-- Vision: `vision/VISION.md`
-- Feature specs: `specs/FEATURES.json`
 - OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
 - Spider dataset: https://huggingface.co/datasets/xlangai/spider
 - TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv
 > Last updated: 2026-03-29

+System map for SQLEnv, an RL environment where agents learn interactive SQL exploration via the OpenEnv framework.

 **Goals:**
 - Show how components connect (system map + key flows)

 - Keep invariants legible (what must stay true)

 **Non-goals:**
+- Exhaustive API reference
+- Training hyperparameter tuning guide

 ---

 ──────────          │ DESCRIBE → PRAGMA  │
 +──────────────+    │ SAMPLE → SELECT N  │
 │ evaluate()   │──> env.reset/step │ QUERY → SQL exec │
+│ policies     │    │ ANSWER → verifier │
 │ .py          │    +────────┬───────────────+
 +──────────────+             │
 │                            v

     cumulative_new_info_reward: float = 0.0
 ```

+**POMDP design:** The agent sees `SQLObservation`. The server holds `EpisodeContext`. The agent never sees gold answers, progress scores, or the full database. This separation forces exploration.

 ---

 | `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON |
 | `DB_DIR` | No | `data/databases/` | SQLite database directory |
 | `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer |
+| `PORT` | No | `8000` | Server port |

 ---

 uv run openenv push  # push to HF Spaces
 ```

+The Dockerfile uses multi-stage build with `openenv-base`, runs as non-root `appuser`, bundles Spider databases, and exposes port 8000.

 ---

 |------|------------|
 | Episode | One question-answering session: reset -> N steps -> terminal |
 | Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER |
+| POMDP | Partially observable MDP. Agent acts under uncertainty |
 | Spider | Academic text-to-SQL benchmark dataset (10 DBs used) |
 | OpenEnv | Meta's RL environment framework (Environment, EnvClient) |
 | Green Agent | OpenEnv's evaluation wrapper pattern |

 ## References

 - OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
 - Spider dataset: https://huggingface.co/datasets/xlangai/spider
 - TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv
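The observation/context split behind the POMDP design above can be sketched with two dataclasses. `SQLObservation`, `EpisodeContext`, and `cumulative_new_info_reward` come from the document; the remaining field names are illustrative assumptions, not the repo's actual models.

```python
# Minimal sketch of the POMDP split: the client-facing observation omits
# everything the server tracks privately. Field names other than
# SQLObservation / EpisodeContext / cumulative_new_info_reward are assumed.
from dataclasses import dataclass, field

@dataclass
class SQLObservation:
    # What the agent is allowed to see.
    question: str
    tables: list
    result: str = ""
    reward: float = 0.0
    done: bool = False

@dataclass
class EpisodeContext:
    # Server-side state the agent never observes.
    gold_answer: str
    db_path: str
    described_tables: set = field(default_factory=set)
    cumulative_new_info_reward: float = 0.0

ctx = EpisodeContext(gold_answer="10", db_path="data/databases/employee.sqlite")
obs = SQLObservation(question="How many employees are there?", tables=["employee"])
# The observation carries no gold answer or progress score.
assert not hasattr(obs, "gold_answer")
```

Keeping the gold answer only in the server-side context is what makes the task partially observable: the agent must earn information through DESCRIBE/SAMPLE/QUERY rather than read it from the state.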
docs/data-sources.md
CHANGED
@@ -14,7 +14,7 @@ so a fresh clone works offline after `uv sync`.
 | DB allowlist | `data/questions/db_list.json` | hand-curated subset | 10 db_ids |
 | SFT trajectories | `data/sft/sft_trajectories.json` | generated from gold SQL | 120 trajectories |

-Total:
 warmup trajectories.

 ## Upstream: Spider

@@ -26,7 +26,7 @@ gold SQL query, and a target database. We use two mirrors:

 1. **Questions** via HuggingFace Datasets: [`xlangai/spider`](https://huggingface.co/datasets/xlangai/spider)
    — loaded with `datasets.load_dataset("xlangai/spider", split=...)` in
-   `scripts/
 2. **SQLite databases** via the Spider GitHub mirror:
    - `https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite`
    - Fallback: the official Google Drive archive

@@ -56,8 +56,8 @@ database. This prevents train/eval leakage at the schema level:
   dog_kennels, employee_hire_evaluation, flight_2, student_assessment`
 - **Eval databases** (4): `flight_2, pets_1, poker_player, world_1`

-`flight_2` appears in both
-saw during training. `sql_env.training.data_loading.validate_no_data_leak`
 asserts zero question-text overlap between the two files at load time.

 ## Question files

@@ -86,10 +86,13 @@ with this shape (actual sample from `car_1` train):
 | train | 473 | 435 | 32 | 6 |
 | eval | 203 | 185 | 18 | 0 |

-The easy-heavy distribution is deliberate for the 0.6B capacity ceiling
-
-
-

 ### Curation pipeline

@@ -126,18 +129,26 @@ one, runs the real `SQLEnvironment` programmatically:
 3. `answer(gold_answer)` — terminal step

 The captured sequence becomes an assistant-labelled trajectory. This is
-**not synthetic text**
-responses the model will see at training and inference time,
-

-
-
-

 Why multi-turn matters: an earlier per-turn SFT (347 single-turn
-examples) taught the model to always call `describe`
-
-

 ## How to regenerate from scratch

@@ -146,8 +157,8 @@ Multi-turn teaches the full `describe → query → answer` sequence. See
 uv run python scripts/download_spider_databases.py --db-id all

 # 2. Raw Spider questions (via HF Datasets)
-uv run python scripts/
-uv run python scripts/

 # 3. Curate into questions_train.json / questions_eval.json
 uv run python scripts/curate_questions.py

@@ -164,7 +175,7 @@ snapshot.
 ## What we deliberately do not use

 - **BIRD** (Li et al., 2023) — larger, harder text-to-SQL benchmark. Out
-  of scope for a 0.6B model
 - **WikiSQL** — single-table only, doesn't exercise the multi-turn
   exploration the environment is built for.
 - **Synthetic LLM-generated questions** — we want Spider's human-written
 | DB allowlist | `data/questions/db_list.json` | hand-curated subset | 10 db_ids |
 | SFT trajectories | `data/sft/sft_trajectories.json` | generated from gold SQL | 120 trajectories |

+Total: 676 questions across 10 Spider databases, plus 120 multi-turn SFT
 warmup trajectories.

 ## Upstream: Spider

 1. **Questions** via HuggingFace Datasets: [`xlangai/spider`](https://huggingface.co/datasets/xlangai/spider)
    — loaded with `datasets.load_dataset("xlangai/spider", split=...)` in
+   `scripts/download_spider_questions.py`.
 2. **SQLite databases** via the Spider GitHub mirror:
    - `https://raw.githubusercontent.com/taoyds/spider/master/database/{db_id}/{db_id}.sqlite`
    - Fallback: the official Google Drive archive

   dog_kennels, employee_hire_evaluation, flight_2, student_assessment`
 - **Eval databases** (4): `flight_2, pets_1, poker_player, world_1`

+`flight_2` appears in both. The other eval DBs are schemas the model
+never saw during training. `sql_env.training.data_loading.validate_no_data_leak`
 asserts zero question-text overlap between the two files at load time.

 ## Question files

 | train | 473 | 435 | 32 | 6 |
 | eval | 203 | 185 | 18 | 0 |

+The easy-heavy distribution is deliberate for the 0.6B capacity ceiling.
+Extended GRPO training on harder questions produced identical accuracy,
+which indicates the ceiling comes from pretraining knowledge rather than
+training budget. Medium and hard questions stay in the mix for Phase 2
+exposure but are not where this model size gains accuracy. See the
+"Limitations at 0.6B Parameters" section of the
+[blog post](https://hjerpe-sqlenv-blog.static.hf.space).

 ### Curation pipeline

 3. `answer(gold_answer)` — terminal step

 The captured sequence becomes an assistant-labelled trajectory. This is
+**not synthetic text**. The assistant turns wrap the actual environment
+responses the model will see at training and inference time, so the
+SFT-warmed reference policy already expects real env output when GRPO
+takes over.

+SFT uses 120 trajectories rather than one per training question. The
+subset is chosen to cover each database and difficulty bucket. See
+`scripts/generate_sft_data.py` for the selection logic.

 Why multi-turn matters: an earlier per-turn SFT (347 single-turn
+examples) taught the model to always call `describe`. Half those
+examples were describe calls, so the model learned "when asked a
+question, call describe." Under a KL penalty during GRPO, every rollout
+stayed identical, the advantage between rollouts was zero, and no policy
+gradient could form. Multi-turn SFT (120 full trajectories trained with
+`assistant_only_loss`) instead teaches the full
+`describe → query → answer` sequence as a coherent strategy, which GRPO
+then refines into error recovery and answer formatting. See the
+"Training" section of the
+[blog post](https://hjerpe-sqlenv-blog.static.hf.space).

 ## How to regenerate from scratch

 uv run python scripts/download_spider_databases.py --db-id all

 # 2. Raw Spider questions (via HF Datasets)
+uv run python scripts/download_spider_questions.py --db-id all --split train
+uv run python scripts/download_spider_questions.py --db-id all --split validation

 # 3. Curate into questions_train.json / questions_eval.json
 uv run python scripts/curate_questions.py

 ## What we deliberately do not use

 - **BIRD** (Li et al., 2023) — larger, harder text-to-SQL benchmark. Out
+  of scope for a 0.6B model. Revisit for a larger-model follow-up.
 - **WikiSQL** — single-table only, doesn't exercise the multi-turn
   exploration the environment is built for.
 - **Synthetic LLM-generated questions** — we want Spider's human-written
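A check in the spirit of `sql_env.training.data_loading.validate_no_data_leak` can be sketched as below: assert zero question-text overlap between the train and eval question sets. The signature and the in-memory list input are assumptions for illustration; the actual implementation operates on the repo's question files.

```python
# Hedged sketch of a train/eval leak check (assumption: the real
# validate_no_data_leak may differ in signature and error handling).
def validate_no_data_leak(train_questions, eval_questions):
    """Raise ValueError if any question text appears in both splits."""
    train_texts = {q["question"] for q in train_questions}
    eval_texts = {q["question"] for q in eval_questions}
    overlap = train_texts & eval_texts
    if overlap:
        raise ValueError(
            f"{len(overlap)} question(s) appear in both splits, "
            f"e.g. {sorted(overlap)[0]!r}"
        )

# Hypothetical records with the question/db_id shape described above.
train = [{"question": "How many employees are there?", "db_id": "employee_hire_evaluation"}]
eval_ = [{"question": "List all poker players.", "db_id": "poker_player"}]
validate_no_data_leak(train, eval_)  # passes: zero overlap
```

Checking exact question text (rather than db_id) is what allows `flight_2` to appear in both splits while still guaranteeing no shared examples.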
notebooks/showcase_sqlenv.ipynb
CHANGED
|
@@ -29,40 +29,83 @@
|
|
| 29 |
},
|
| 30 |
{
|
| 31 |
"cell_type": "code",
|
| 32 |
-
"execution_count":
|
| 33 |
"metadata": {},
|
| 34 |
-
"outputs": [
|
| 35 |
-
{
|
| 36 |
-
"name": "stdout",
|
| 37 |
-
"output_type": "stream",
|
| 38 |
-
"text": [
|
| 39 |
-
"Project root: /Users/hjerp/Projects/sql-env\n"
|
| 40 |
-
]
|
| 41 |
-
}
|
| 42 |
-
],
|
| 43 |
"source": [
|
| 44 |
"import os\n",
|
|
|
|
| 45 |
"import sys\n",
|
| 46 |
"from pathlib import Path\n",
|
| 47 |
"\n",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
"\n",
|
| 49 |
-
"
|
| 50 |
-
"
|
| 51 |
-
"
|
| 52 |
-
"
|
| 53 |
-
"
|
| 54 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
"\n",
|
|
|
|
|
|
|
| 56 |
"\n",
|
| 57 |
-
"PROJECT_ROOT = find_project_root()\n",
|
| 58 |
-
"os.chdir(PROJECT_ROOT)\n",
|
| 59 |
"if str(PROJECT_ROOT) not in sys.path:\n",
|
| 60 |
" sys.path.insert(0, str(PROJECT_ROOT))\n",
|
| 61 |
"\n",
|
| 62 |
-
"# In Colab, uncomment:\n",
|
| 63 |
-
"# !pip install -q git+https://github.com/hjerpe/sql-env.git\n",
|
| 64 |
-
"# !python scripts/download_spider_databases.py\n",
|
| 65 |
-
"\n",
|
| 66 |
"print(f\"Project root: {PROJECT_ROOT}\")"
|
| 67 |
]
|
| 68 |
},
|
|
@@ -553,62 +596,19 @@
|
|
| 553 |
},
|
| 554 |
{
|
| 555 |
"cell_type": "code",
|
| 556 |
-
"execution_count":
|
| 557 |
"metadata": {},
|
| 558 |
-
"outputs": [
|
| 559 |
-
{
|
| 560 |
-
"name": "stdout",
|
| 561 |
-
"output_type": "stream",
|
| 562 |
-
"text": [
|
| 563 |
-
"Q: List the id of students who registered some courses and the number of their registered courses?\n",
|
| 564 |
-
"\n",
|
| 565 |
-
" Step 1: DESCRIBE\n",
|
| 566 |
-
" Action: student_course_registrations\n",
|
| 567 |
-
" Result: Table 'Student_Course_Registrations' columns:\n",
|
| 568 |
-
"- student_id: INTEGER\n",
|
| 569 |
-
"- course_id: INTEGER\n",
|
| 570 |
-
"- registration_date: DATETIME\n",
|
| 571 |
-
"Row count: 9\n",
|
| 572 |
-
" Reward: +0.0150\n",
|
| 573 |
-
"\n",
|
| 574 |
-
" Step 2: DESCRIBE\n",
|
| 575 |
-
" Action: students\n",
|
| 576 |
-
" Result: Table 'Students' columns:\n",
|
| 577 |
-
"- student_id: INTEGER\n",
|
| 578 |
-
"- student_details: VARCHAR(255)\n",
|
| 579 |
-
"Row count: 8\n",
|
| 580 |
-
" Reward: +0.0150\n",
|
| 581 |
-
"\n",
|
| 582 |
-
" Step 3: QUERY\n",
|
| 583 |
-
" SQL:\n",
|
| 584 |
-
" SELECT T1.student_id , count(*) \n",
|
| 585 |
-
" FROM students AS T1 \n",
|
| 586 |
-
" JOIN student_course_registrations AS T2 \n",
|
| 587 |
-
" ON T1.student_id = T2.student_id \n",
|
| 588 |
-
" GROUP BY T1.student_id\n",
|
| 589 |
-
" Result: 1. 111 | 1\n",
|
| 590 |
-
"2. 121 | 2\n",
|
| 591 |
-
"3. 131 | 1\n",
|
| 592 |
-
"4. 141 | 2\n",
|
| 593 |
-
"5. 151 | 1\n",
|
| 594 |
-
"6. 161 | 1\n",
|
| 595 |
-
"7. 171 | 1\n",
|
| 596 |
-
" Reward: +0.1500\n",
|
| 597 |
-
"\n",
|
| 598 |
-
" Step 4: ANSWER\n",
|
| 599 |
-
" Action: [[111, 1], [121, 2], [131, 1], [141, 2], [151, 1], [161, 1], [171, 1]]\n",
|
| 600 |
-
" Result: Answer submitted: correct.\n",
|
| 601 |
-
" Reward: +1.0000\n",
|
| 602 |
-
"\n",
|
| 603 |
-
"Total reward: 1.180\n",
|
| 604 |
-
" Exploration (L1+L2): 0.180 (3 steps)\n",
|
| 605 |
-
" Terminal (L3): 1.000\n"
|
| 606 |
-
]
|
| 607 |
-
}
|
| 608 |
-
],
|
| 609 |
"source": [
|
| 610 |
"import re\n",
|
| 611 |
"\n",
|
| 612 |
"\n",
|
| 613 |
"def format_sql(sql):\n",
|
| 614 |
" \"\"\"Simple SQL formatter for display.\"\"\"\n",
|
|
@@ -617,14 +617,58 @@
|
|
| 617 |
" return formatted\n",
|
| 618 |
"\n",
|
| 619 |
"\n",
|
| 620 |
-
"
|
|
| 621 |
"obs = env.reset(seed=0)\n",
|
| 622 |
"oracle = OraclePolicy(questions)\n",
|
| 623 |
"step_rewards = []\n",
|
|
|
|
| 624 |
"\n",
|
| 625 |
"print(f\"Q: {obs.question}\\n\")\n",
|
| 626 |
"while not obs.done:\n",
|
| 627 |
" action = oracle.select_action(obs)\n",
|
| 628 |
" obs = env.step(action)\n",
|
| 629 |
" reward = obs.reward or 0.0\n",
|
| 630 |
" step_rewards.append(reward)\n",
|
|
@@ -640,7 +684,7 @@
|
|
| 640 |
" print(f\" Result: {obs.result}\")\n",
|
| 641 |
" if obs.error:\n",
|
| 642 |
" print(f\" Error: {obs.error}\")\n",
|
| 643 |
-
" print(f\" Reward: {
|
| 644 |
" print()\n",
|
| 645 |
"\n",
|
| 646 |
"exploration = sum(step_rewards[:-1]) if len(step_rewards) > 1 else 0.0\n",
|
|
@@ -655,9 +699,17 @@
|
|
| 655 |
"cell_type": "markdown",
|
| 656 |
"metadata": {},
|
| 657 |
"source": [
|
| 658 |
-
"## 8.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 659 |
"\n",
|
| 660 |
-
"The
|
| 661 |
]
|
| 662 |
},
|
| 663 |
{
|
|
@@ -666,25 +718,77 @@
|
|
| 666 |
"metadata": {},
|
| 667 |
"outputs": [],
|
| 668 |
"source": [
|
| 669 |
-
"
|
| 670 |
-
"
|
| 671 |
-
"
|
| 672 |
-
"#
|
| 673 |
-
"#\n",
|
| 674 |
-
"#
|
| 675 |
-
"
|
| 676 |
-
"
|
| 677 |
-
"
|
| 678 |
-
"
|
| 679 |
-
"#
|
| 680 |
-
"#\n",
|
| 681 |
-
"#
|
| 682 |
-
"#
|
| 683 |
-
"
|
| 684 |
-
"
|
| 685 |
-
"
|
| 686 |
-
"\n",
|
| 687 |
-
"
|
|
| 688 |
]
|
| 689 |
},
|
| 690 |
{
|
|
|
|
| 29 |
},
|
| 30 |
{
|
| 31 |
"cell_type": "code",
|
| 32 |
+
"execution_count": null,
|
| 33 |
"metadata": {},
|
| 34 |
+
"outputs": [],
|
| 35 |
"source": [
|
| 36 |
"import os\n",
|
| 37 |
+
"import subprocess\n",
|
| 38 |
"import sys\n",
|
| 39 |
"from pathlib import Path\n",
|
| 40 |
"\n",
|
| 41 |
+
"IN_COLAB = \"google.colab\" in sys.modules\n",
|
| 42 |
+
"\n",
|
| 43 |
+
"if IN_COLAB:\n",
|
| 44 |
+
" # Colab: clone the repo, install the package, fetch Spider databases.\n",
|
| 45 |
+
" # Requires a GITHUB_TOKEN in Colab userdata if the repo is private.\n",
|
| 46 |
+
" from google.colab import userdata\n",
|
| 47 |
+
"\n",
|
| 48 |
+
" token = userdata.get(\"GITHUB_TOKEN\")\n",
|
| 49 |
+
" BRANCH = \"main\" # @param {type:\"string\"}\n",
|
| 50 |
+
" repo_url = f\"https://{token}@github.com/hjerpe/sql-env.git\"\n",
|
| 51 |
+
"\n",
|
| 52 |
+
" if Path(\"sql-env\").exists():\n",
|
| 53 |
+
" subprocess.check_call([\"git\", \"-C\", \"sql-env\", \"pull\", \"-q\"])\n",
|
| 54 |
+
" else:\n",
|
| 55 |
+
" subprocess.check_call([\"git\", \"clone\", \"-q\", \"-b\", BRANCH, repo_url])\n",
|
| 56 |
+
" os.chdir(\"sql-env\")\n",
|
| 57 |
+
"\n",
|
| 58 |
+
" print(\"Colab detected: installing dependencies...\")\n",
|
| 59 |
+
" subprocess.check_call(\n",
|
| 60 |
+
" [sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--upgrade\", \"pip\"]\n",
|
| 61 |
+
" )\n",
|
| 62 |
+
" subprocess.check_call(\n",
|
| 63 |
+
" [\n",
|
| 64 |
+
" sys.executable,\n",
|
| 65 |
+
" \"-m\",\n",
|
| 66 |
+
" \"pip\",\n",
|
| 67 |
+
" \"install\",\n",
|
| 68 |
+
" \"-q\",\n",
|
| 69 |
+
" \"--no-deps\",\n",
|
| 70 |
+
" \"--force-reinstall\",\n",
|
| 71 |
+
" \".\",\n",
|
| 72 |
+
" ]\n",
|
| 73 |
+
" )\n",
|
| 74 |
+
" subprocess.check_call(\n",
|
| 75 |
+
" [\n",
|
| 76 |
+
" sys.executable,\n",
|
| 77 |
+
" \"-m\",\n",
|
| 78 |
+
" \"pip\",\n",
|
| 79 |
+
" \"install\",\n",
|
| 80 |
+
" \"-q\",\n",
|
| 81 |
+
" \"openenv-core[core]>=0.2.1\",\n",
|
| 82 |
+
" \"pydantic>=2.0.0\",\n",
|
| 83 |
+
" \"jmespath\",\n",
|
| 84 |
+
" ]\n",
|
| 85 |
+
" )\n",
|
| 86 |
+
" # Download Spider SQLite databases the notebook reads from\n",
|
| 87 |
+
" subprocess.check_call(\n",
|
| 88 |
+
" [sys.executable, \"scripts/download_spider_databases.py\", \"--db-id\", \"all\"]\n",
|
| 89 |
+
" )\n",
|
| 90 |
"\n",
|
| 91 |
+
" PROJECT_ROOT = Path.cwd()\n",
|
| 92 |
+
"else:\n",
|
| 93 |
+
" # Local: walk up from CWD to find the project root\n",
|
| 94 |
+
" def find_project_root() -> Path:\n",
|
| 95 |
+
" \"\"\"Walk up from CWD until pyproject.toml is found.\"\"\"\n",
|
| 96 |
+
" for parent in [Path.cwd(), *Path.cwd().parents]:\n",
|
| 97 |
+
" if (parent / \"pyproject.toml\").exists():\n",
|
| 98 |
+
" return parent\n",
|
| 99 |
+
" raise FileNotFoundError(\n",
|
| 100 |
+
" \"Could not locate project root (no pyproject.toml found)\"\n",
|
| 101 |
+
" )\n",
|
| 102 |
"\n",
|
| 103 |
+
" PROJECT_ROOT = find_project_root()\n",
|
| 104 |
+
" os.chdir(PROJECT_ROOT)\n",
|
| 105 |
"\n",
|
|
|
|
|
|
|
| 106 |
"if str(PROJECT_ROOT) not in sys.path:\n",
|
| 107 |
" sys.path.insert(0, str(PROJECT_ROOT))\n",
|
| 108 |
"\n",
|
| 109 |
"print(f\"Project root: {PROJECT_ROOT}\")"
|
| 110 |
]
|
| 111 |
},
|
|
|
|
| 596 |
},
|
| 597 |
{
|
| 598 |
"cell_type": "code",
|
| 599 |
+
"execution_count": null,
|
| 600 |
"metadata": {},
|
| 601 |
+
"outputs": [],
|
|
| 602 |
"source": [
|
| 603 |
"import re\n",
|
| 604 |
"\n",
|
| 605 |
+
"from server.reward import (\n",
|
| 606 |
+
" _EXEC_OK_REWARD,\n",
|
| 607 |
+
" _NEW_INFO_REWARD,\n",
|
| 608 |
+
" _REPEAT_PENALTY,\n",
|
| 609 |
+
" _STEP_COST,\n",
|
| 610 |
+
")\n",
|
| 611 |
+
"\n",
|
| 612 |
"\n",
|
| 613 |
"def format_sql(sql):\n",
|
| 614 |
" \"\"\"Simple SQL formatter for display.\"\"\"\n",
|
|
|
|
| 617 |
" return formatted\n",
|
| 618 |
"\n",
|
| 619 |
"\n",
|
| 620 |
+
"def explain_reward(action_type, error, is_repeat_query, total_reward):\n",
|
| 621 |
+
" \"\"\"Decompose a step reward into labeled components.\n",
|
| 622 |
+
"\n",
|
| 623 |
+
" Layer 1 components (step_cost, exec_ok, new_info, repeat_penalty) are\n",
|
| 624 |
+
" deterministic from action type + state, so we reconstruct them exactly\n",
|
| 625 |
+
" from the reward constants imported above. Layer 2 (progress delta on\n",
|
| 626 |
+
" QUERY) and Layer 3 (terminal on ANSWER) are not exposed in the\n",
|
| 627 |
+
" observation, so we recover them as 'total reward minus L1 sum' and\n",
|
| 628 |
+
" label them accordingly. The clip range [-0.10, +0.15] may adjust the\n",
|
| 629 |
+
" final value — any residual after layer reconstruction is labeled\n",
|
| 630 |
+
" 'clip_adjust'.\n",
|
| 631 |
+
" \"\"\"\n",
|
| 632 |
+
" at = action_type.upper()\n",
|
| 633 |
+
" parts = [(\"step_cost\", -_STEP_COST)] # always applied\n",
|
| 634 |
+
"\n",
|
| 635 |
+
" if error:\n",
|
| 636 |
+
" pass # no exec_ok when the action errored\n",
|
| 637 |
+
" elif at == \"QUERY\" and is_repeat_query:\n",
|
| 638 |
+
" parts.append((\"repeat_penalty\", -_REPEAT_PENALTY))\n",
|
| 639 |
+
" else:\n",
|
| 640 |
+
" parts.append((\"exec_ok\", +_EXEC_OK_REWARD))\n",
|
| 641 |
+
" if at == \"QUERY\":\n",
|
| 642 |
+
" parts.append((\"new_info\", +_NEW_INFO_REWARD))\n",
|
| 643 |
+
"\n",
|
| 644 |
+
" l1_sum = sum(v for _, v in parts)\n",
|
| 645 |
+
" remainder = total_reward - l1_sum\n",
|
| 646 |
+
"\n",
|
| 647 |
+
" if abs(remainder) > 1e-9:\n",
|
| 648 |
+
" if at == \"ANSWER\":\n",
|
| 649 |
+
" parts.append((\"terminal\", remainder))\n",
|
| 650 |
+
" elif at == \"QUERY\":\n",
|
| 651 |
+
" parts.append((\"layer2_progress\", remainder))\n",
|
| 652 |
+
" else:\n",
|
| 653 |
+
" parts.append((\"clip_adjust\", remainder))\n",
|
| 654 |
+
"\n",
|
| 655 |
+
" labels = \" + \".join(f\"{name}({v:+.3f})\" for name, v in parts)\n",
|
| 656 |
+
" return f\"{labels} = {total_reward:+.4f}\"\n",
|
| 657 |
+
"\n",
|
| 658 |
+
"\n",
|
| 659 |
+
"# Run one oracle episode and show per-step rewards with component breakdown\n",
|
| 660 |
"obs = env.reset(seed=0)\n",
|
| 661 |
"oracle = OraclePolicy(questions)\n",
|
| 662 |
"step_rewards = []\n",
|
| 663 |
+
"seen_queries: set[str] = set()\n",
|
| 664 |
"\n",
|
| 665 |
"print(f\"Q: {obs.question}\\n\")\n",
|
| 666 |
"while not obs.done:\n",
|
| 667 |
" action = oracle.select_action(obs)\n",
|
| 668 |
+
" is_repeat = action.action_type.upper() == \"QUERY\" and action.argument in seen_queries\n",
|
| 669 |
+
" if action.action_type.upper() == \"QUERY\":\n",
|
| 670 |
+
" seen_queries.add(action.argument)\n",
|
| 671 |
+
"\n",
|
| 672 |
" obs = env.step(action)\n",
|
| 673 |
" reward = obs.reward or 0.0\n",
|
| 674 |
" step_rewards.append(reward)\n",
|
|
|
|
| 684 |
" print(f\" Result: {obs.result}\")\n",
|
| 685 |
" if obs.error:\n",
|
| 686 |
" print(f\" Error: {obs.error}\")\n",
|
| 687 |
+
" print(f\" Reward: {explain_reward(action.action_type, obs.error, is_repeat, reward)}\")\n",
|
| 688 |
" print()\n",
|
| 689 |
"\n",
|
| 690 |
"exploration = sum(step_rewards[:-1]) if len(step_rewards) > 1 else 0.0\n",
|
|
|
|
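The Layer-1 arithmetic that `explain_reward` reconstructs can be sketched in isolation. The constant values below are illustrative placeholders, not the actual values of `_STEP_COST`, `_EXEC_OK_REWARD`, `_NEW_INFO_REWARD`, or `_REPEAT_PENALTY` in `server/reward.py`:

```python
# Sketch of the deterministic Layer-1 part of a step reward, assuming
# made-up constant values (the real ones live in server/reward.py).
STEP_COST = 0.005
EXEC_OK_REWARD = 0.01
NEW_INFO_REWARD = 0.01
REPEAT_PENALTY = 0.02


def layer1_reward(action_type, error, is_repeat):
    """Step cost always applies; exec bonus only on success; QUERY adds a
    new-info bonus unless it repeats an earlier query, which is penalized."""
    r = -STEP_COST  # always applied
    if error:
        return r  # no exec_ok when the action errored
    if action_type == "QUERY" and is_repeat:
        return r - REPEAT_PENALTY
    r += EXEC_OK_REWARD
    if action_type == "QUERY":
        r += NEW_INFO_REWARD
    return r


print(round(layer1_reward("DESCRIBE", False, False), 4))  # 0.005
print(round(layer1_reward("QUERY", False, False), 4))     # 0.015
print(round(layer1_reward("QUERY", False, True), 4))      # -0.025
```

Layer 2 (progress delta) and Layer 3 (terminal) then sit on top of this sum, which is why the notebook cell recovers them as the residual.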
| 699 |
"cell_type": "markdown",
|
| 700 |
"metadata": {},
|
| 701 |
"source": [
|
| 702 |
+
"## 8. Same Environment, Over the Wire\n",
|
| 703 |
+
"\n",
|
| 704 |
+
"The same `SQLEnvironment` runs as a Docker container on HuggingFace Spaces:\n",
|
| 705 |
+
"[**huggingface.co/spaces/hjerpe/sql_env**](https://huggingface.co/spaces/hjerpe/sql_env).\n",
|
| 706 |
+
"`SQLEnvClient` connects via WebSocket and provides the same `reset()` /\n",
|
| 707 |
+
"`step()` interface we used above — same action space, same observation shape,\n",
|
| 708 |
+
"same reward model. The only difference is that the SQLite database and\n",
|
| 709 |
+
"reward computation now live on a remote container instead of in this\n",
|
| 710 |
+
"Python process.\n",
|
| 711 |
"\n",
|
| 712 |
+
"The cell below drives one full episode against the live Space."
|
| 713 |
]
|
| 714 |
},
|
| 715 |
{
|
|
|
|
| 718 |
"metadata": {},
|
| 719 |
"outputs": [],
|
| 720 |
"source": [
|
| 721 |
+
"from sql_env.client import SQLEnvClient\n",
|
| 722 |
+
"from sql_env.models import SQLAction\n",
|
| 723 |
+
"\n",
|
| 724 |
+
"# Live hosted Space. This is the URL anyone in the world can point a client\n",
|
| 725 |
+
"# at — no local setup required. The first request may take ~30s if the\n",
|
| 726 |
+
"# container is cold-starting.\n",
|
| 727 |
+
"SPACE_URL = \"https://hjerpe-sql-env.hf.space\"\n",
|
| 728 |
+
"\n",
|
| 729 |
+
"print(f\"Connecting to {SPACE_URL} ...\\n\")\n",
|
| 730 |
+
"\n",
|
| 731 |
+
"# openenv-core's SQLEnvClient is sync-by-default in older versions but\n",
|
| 732 |
+
"# async-by-default in newer ones (the newer API exposes .sync() as an\n",
|
| 733 |
+
"# explicit synchronous wrapper). Detect at runtime so the cell works on\n",
|
| 734 |
+
"# both local dev installs and Colab's pinned >=0.2.1 version.\n",
|
| 735 |
+
"_remote_client = SQLEnvClient(base_url=SPACE_URL)\n",
|
| 736 |
+
"_remote_ctx = _remote_client.sync() if hasattr(_remote_client, \"sync\") else _remote_client\n",
|
| 737 |
+
"\n",
|
| 738 |
+
"try:\n",
|
| 739 |
+
" with _remote_ctx as remote_env:\n",
|
| 740 |
+
" # --- reset ---\n",
|
| 741 |
+
" result = remote_env.reset()\n",
|
| 742 |
+
" remote_obs = result.observation\n",
|
| 743 |
+
" print(f\"Q: {remote_obs.question}\")\n",
|
| 744 |
+
" tables = [\n",
|
| 745 |
+
" line.lstrip(\"- \").strip()\n",
|
| 746 |
+
" for line in remote_obs.schema_info.splitlines()[1:]\n",
|
| 747 |
+
" if line.strip()\n",
|
| 748 |
+
" ]\n",
|
| 749 |
+
" print(f\"Tables: {tables}\\n\")\n",
|
| 750 |
+
"\n",
|
| 751 |
+
" first_table = tables[0]\n",
|
| 752 |
+
"\n",
|
| 753 |
+
" # --- describe ---\n",
|
| 754 |
+
" result = remote_env.step(\n",
|
| 755 |
+
" SQLAction(action_type=\"DESCRIBE\", argument=first_table)\n",
|
| 756 |
+
" )\n",
|
| 757 |
+
" print(f\"DESCRIBE {first_table}\")\n",
|
| 758 |
+
" print(f\" reward: {result.observation.reward:+.4f}\")\n",
|
| 759 |
+
" # Line-based preview so truncation never cuts mid-word\n",
|
| 760 |
+
" _lines = result.observation.result.splitlines()\n",
|
| 761 |
+
" _preview = \"\\n \".join(_lines[:6])\n",
|
| 762 |
+
" _more = (\n",
|
| 763 |
+
" f\"\\n ... ({len(_lines) - 6} more lines)\"\n",
|
| 764 |
+
" if len(_lines) > 6\n",
|
| 765 |
+
" else \"\"\n",
|
| 766 |
+
" )\n",
|
| 767 |
+
" print(f\" result: {_preview}{_more}\\n\")\n",
|
| 768 |
+
"\n",
|
| 769 |
+
" # --- query ---\n",
|
| 770 |
+
" query_sql = f\"SELECT COUNT(*) FROM {first_table}\"\n",
|
| 771 |
+
" result = remote_env.step(\n",
|
| 772 |
+
" SQLAction(action_type=\"QUERY\", argument=query_sql)\n",
|
| 773 |
+
" )\n",
|
| 774 |
+
" print(f\"QUERY {query_sql}\")\n",
|
| 775 |
+
" print(f\" reward: {result.observation.reward:+.4f}\")\n",
|
| 776 |
+
" print(f\" result: {result.observation.result}\\n\")\n",
|
| 777 |
+
"\n",
|
| 778 |
+
" # --- answer (intentionally wrong — we're demoing plumbing, not correctness) ---\n",
|
| 779 |
+
" result = remote_env.step(\n",
|
| 780 |
+
" SQLAction(action_type=\"ANSWER\", argument=\"demo\")\n",
|
| 781 |
+
" )\n",
|
| 782 |
+
"      print(\"ANSWER demo\")\n",
|
| 783 |
+
" print(f\" done: {result.observation.done}\")\n",
|
| 784 |
+
" print(f\" reward: {result.observation.reward:+.4f}\")\n",
|
| 785 |
+
" print(\"\\nSame action space, same observation shape, same rewards — just running remotely.\")\n",
|
| 786 |
+
"except Exception as exc: # noqa: BLE001 — demo cell should not crash the notebook\n",
|
| 787 |
+
" print(f\"Remote call failed: {type(exc).__name__}: {exc}\")\n",
|
| 788 |
+
" print(\n",
|
| 789 |
+
" \"If the Space is sleeping, the first request usually wakes it. \"\n",
|
| 790 |
+
" \"Retry in ~30s, or skip this cell to run the notebook fully offline.\"\n",
|
| 791 |
+
" )"
|
| 792 |
]
|
| 793 |
},
|
| 794 |
{
|
scripts/download_spider_questions.py
ADDED
|
@@ -0,0 +1,106 @@
|
| 1 |
+
"""
|
| 2 |
+
Script to download Spider dataset questions for specific databases.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python download_spider_questions.py --db-id student_assessment
|
| 6 |
+
python download_spider_questions.py --db-id student_assessment --split validation
|
| 7 |
+
python download_spider_questions.py --db-id all # downloads all db_ids
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
import json
|
| 11 |
+
import argparse
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from datasets import load_dataset
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def download_spider_questions(
|
| 17 |
+
db_id: str = "student_assessment",
|
| 18 |
+
split: str = "train",
|
| 19 |
+
output_dir: str = "data/questions",
|
| 20 |
+
) -> None:
|
| 21 |
+
"""Download Spider dataset questions for specified database(s).
|
| 22 |
+
|
| 23 |
+
Args:
|
| 24 |
+
db_id: Database ID to filter by, or "all" to get all databases
|
| 25 |
+
split: Dataset split ("train" or "validation")
|
| 26 |
+
output_dir: Directory to save JSON files
|
| 27 |
+
"""
|
| 28 |
+
output_path = Path(output_dir)
|
| 29 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
| 30 |
+
|
| 31 |
+
print(f"Loading Spider dataset ({split} split)...")
|
| 32 |
+
dataset = load_dataset("xlangai/spider", split=split)
|
| 33 |
+
|
| 34 |
+
if db_id.lower() == "all":
|
| 35 |
+
# Group by db_id
|
| 36 |
+
grouped = {}
|
| 37 |
+
for item in dataset:
|
| 38 |
+
current_db_id = item.get("db_id")
|
| 39 |
+
if current_db_id not in grouped:
|
| 40 |
+
grouped[current_db_id] = []
|
| 41 |
+
grouped[current_db_id].append(item)
|
| 42 |
+
|
| 43 |
+
total_questions = 0
|
| 44 |
+
for current_db_id, questions in grouped.items():
|
| 45 |
+
filepath = output_path / f"{current_db_id}.json"
|
| 46 |
+
with open(filepath, "w") as f:
|
| 47 |
+
json.dump(questions, f, indent=2)
|
| 48 |
+
print(f" {current_db_id}: {len(questions)} questions → {filepath}")
|
| 49 |
+
total_questions += len(questions)
|
| 50 |
+
|
| 51 |
+
print(f"\nTotal: {total_questions} questions across {len(grouped)} databases")
|
| 52 |
+
else:
|
| 53 |
+
# Filter for specific db_id
|
| 54 |
+
filtered_data = [item for item in dataset if item.get("db_id") == db_id]
|
| 55 |
+
|
| 56 |
+
if not filtered_data:
|
| 57 |
+
print(f"No questions found for db_id='{db_id}'")
|
| 58 |
+
return
|
| 59 |
+
|
| 60 |
+
filepath = output_path / f"{db_id}.json"
|
| 61 |
+
with open(filepath, "w") as f:
|
| 62 |
+
json.dump(filtered_data, f, indent=2)
|
| 63 |
+
|
| 64 |
+
print(f"Found {len(filtered_data)} questions for db_id='{db_id}'")
|
| 65 |
+
print(f"Saved to {filepath}")
|
| 66 |
+
|
| 67 |
+
# Print sample
|
| 68 |
+
if filtered_data:
|
| 69 |
+
sample = filtered_data[0]
|
| 70 |
+
print("\nFirst question sample:")
|
| 71 |
+
print(
|
| 72 |
+
json.dumps(
|
| 73 |
+
{k: v for k, v in sample.items() if k != "evidence"}, indent=2
|
| 74 |
+
)
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
parser = argparse.ArgumentParser(
|
| 80 |
+
description="Download Spider dataset questions for specific databases",
|
| 81 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 82 |
+
)
|
| 83 |
+
parser.add_argument(
|
| 84 |
+
"--db-id",
|
| 85 |
+
type=str,
|
| 86 |
+
default="student_assessment",
|
| 87 |
+
help="Database ID to filter by (or 'all' for all databases)",
|
| 88 |
+
)
|
| 89 |
+
parser.add_argument(
|
| 90 |
+
"--split",
|
| 91 |
+
type=str,
|
| 92 |
+
default="train",
|
| 93 |
+
choices=["train", "validation"],
|
| 94 |
+
help="Dataset split to download",
|
| 95 |
+
)
|
| 96 |
+
parser.add_argument(
|
| 97 |
+
"--output-dir",
|
| 98 |
+
type=str,
|
| 99 |
+
default="data/questions",
|
| 100 |
+
help="Directory to save JSON files",
|
| 101 |
+
)
|
| 102 |
+
|
| 103 |
+
args = parser.parse_args()
|
| 104 |
+
download_spider_questions(
|
| 105 |
+
db_id=args.db_id, split=args.split, output_dir=args.output_dir
|
| 106 |
+
)
|
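The `--db-id all` branch above is a plain group-by over dataset rows. A minimal sketch of that step on hand-made rows (the row dicts are invented for illustration; the real script streams rows from `xlangai/spider` via `datasets.load_dataset`):

```python
# Group question rows by their db_id, as the script's "all" branch does.
from collections import defaultdict

rows = [
    {"db_id": "student_assessment", "question": "How many students are there?"},
    {"db_id": "pets_1", "question": "How many pets are there?"},
    {"db_id": "student_assessment", "question": "List all courses."},
]

grouped = defaultdict(list)
for item in rows:
    grouped[item["db_id"]].append(item)

for db_id, questions in grouped.items():
    print(db_id, len(questions))
```

Each group is then dumped to `data/questions/<db_id>.json`, one file per database.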
server/app.py
CHANGED
|
@@ -80,7 +80,16 @@ def create_sql_environment():
|
|
| 80 |
)
|
| 81 |
|
| 82 |
|
| 83 |
-
# Create the FastAPI app
|
|
| 84 |
app = create_app(
|
| 85 |
create_sql_environment,
|
| 86 |
SQLAction,
|
|
|
|
| 80 |
)
|
| 81 |
|
| 82 |
|
| 83 |
+
# Create the FastAPI app.
|
| 84 |
+
#
|
| 85 |
+
# Note: hosted Space is single-session. External users running TRL's
|
| 86 |
+
# GRPOTrainer against https://hjerpe-sql-env.hf.space with
|
| 87 |
+
# num_generations > 1 will hit openenv-core's default 1-session cap.
|
| 88 |
+
# Fix requires (a) auditing SQLEnvironment for shared mutable state
|
| 89 |
+
# across sessions, (b) declaring SUPPORTS_CONCURRENT_SESSIONS=True on
|
| 90 |
+
# the class, (c) passing max_concurrent_envs=64 here. Deferred as a
|
| 91 |
+
# post-launch follow-up. Our own training uses an in-process
|
| 92 |
+
# SQLEnvironment via SQLEnvTRL, so this does not affect internal runs.
|
| 93 |
app = create_app(
|
| 94 |
create_sql_environment,
|
| 95 |
SQLAction,
|
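The three-step fix the comment defers can be sketched with stand-ins. The names `SUPPORTS_CONCURRENT_SESSIONS` and `max_concurrent_envs` are taken from the comment itself; `create_app_stub` below is a hypothetical stand-in, not openenv-core's real `create_app`:

```python
# Hypothetical sketch of the deferred multi-session fix. Attribute and
# keyword names follow the comment in server/app.py; their exact
# openenv-core spelling is an assumption, not verified against the library.

class SQLEnvironment:  # stand-in for server.sql_environment's class
    # (b) declare concurrency support once (a) the audit confirms no
    # shared mutable state across sessions
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self):
        self._history = []  # per-instance state, not class-level


def create_app_stub(factory, max_concurrent_envs=1):
    """Stand-in showing step (c): one env instance per concurrent session."""
    return [factory() for _ in range(max_concurrent_envs)]


envs = create_app_stub(SQLEnvironment, max_concurrent_envs=64)
print(len(envs))  # 64 independent sessions
```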
sql_env.egg-info/PKG-INFO
CHANGED
|
@@ -3,6 +3,7 @@ Name: sql-env
|
|
| 3 |
Version: 0.1.0
|
| 4 |
Summary: Interactive SQL exploration RL environment for the OpenEnv Challenge
|
| 5 |
Requires-Python: <3.13,>=3.11
|
|
|
|
| 6 |
Requires-Dist: openenv-core[core]>=0.2.1
|
| 7 |
Requires-Dist: pydantic>=2.0.0
|
| 8 |
Requires-Dist: fastapi>=0.104.0
|
|
@@ -24,3 +25,4 @@ Requires-Dist: trl>=0.29.0; extra == "training"
|
|
| 24 |
Requires-Dist: accelerate>=0.34.0; extra == "training"
|
| 25 |
Requires-Dist: notebook>=7.5.5; extra == "training"
|
| 26 |
Requires-Dist: matplotlib>=3.7.0; extra == "training"
|
|
|
|
|
|
| 3 |
Version: 0.1.0
|
| 4 |
Summary: Interactive SQL exploration RL environment for the OpenEnv Challenge
|
| 5 |
Requires-Python: <3.13,>=3.11
|
| 6 |
+
License-File: LICENSE
|
| 7 |
Requires-Dist: openenv-core[core]>=0.2.1
|
| 8 |
Requires-Dist: pydantic>=2.0.0
|
| 9 |
Requires-Dist: fastapi>=0.104.0
|
|
|
|
| 25 |
Requires-Dist: accelerate>=0.34.0; extra == "training"
|
| 26 |
Requires-Dist: notebook>=7.5.5; extra == "training"
|
| 27 |
Requires-Dist: matplotlib>=3.7.0; extra == "training"
|
| 28 |
+
Dynamic: license-file
|
sql_env.egg-info/SOURCES.txt
CHANGED
|
@@ -1,8 +1,5 @@
|
|
|
|
|
| 1 |
README.md
|
| 2 |
-
__init__.py
|
| 3 |
-
client.py
|
| 4 |
-
conftest.py
|
| 5 |
-
models.py
|
| 6 |
pyproject.toml
|
| 7 |
./__init__.py
|
| 8 |
./client.py
|
|
@@ -16,9 +13,9 @@ evaluation/oracle_policy.py
|
|
| 16 |
evaluation/policies.py
|
| 17 |
server/__init__.py
|
| 18 |
server/app.py
|
|
|
|
| 19 |
server/reward.py
|
| 20 |
server/sql_environment.py
|
| 21 |
-
server/test_sql_env.py
|
| 22 |
server/verifier.py
|
| 23 |
sql_env.egg-info/PKG-INFO
|
| 24 |
sql_env.egg-info/SOURCES.txt
|
|
@@ -38,6 +35,5 @@ training/data_loading.py
|
|
| 38 |
training/few_shot_examples.py
|
| 39 |
training/notebook_pipeline.py
|
| 40 |
training/prompts.py
|
| 41 |
-
training/rewards.py
|
| 42 |
training/trl_adapter.py
|
| 43 |
training/visualization.py
|
|
|
|
| 1 |
+
LICENSE
|
| 2 |
README.md
|
| 3 |
pyproject.toml
|
| 4 |
./__init__.py
|
| 5 |
./client.py
|
|
|
|
| 13 |
evaluation/policies.py
|
| 14 |
server/__init__.py
|
| 15 |
server/app.py
|
| 16 |
+
server/mock_tokenizer.py
|
| 17 |
server/reward.py
|
| 18 |
server/sql_environment.py
|
|
|
|
| 19 |
server/verifier.py
|
| 20 |
sql_env.egg-info/PKG-INFO
|
| 21 |
sql_env.egg-info/SOURCES.txt
|
|
|
|
| 35 |
training/few_shot_examples.py
|
| 36 |
training/notebook_pipeline.py
|
| 37 |
training/prompts.py
|
|
|
|
| 38 |
training/trl_adapter.py
|
| 39 |
training/visualization.py
|