# Architecture
> Last updated: 2026-03-29
System map for SQLEnv, an RL environment where agents learn interactive SQL exploration via the OpenEnv framework.
**Goals:**
- Show how components connect (system map + key flows)
- Make hidden state explicit (what lives where)
- Define shared interfaces (Pydantic models, HTTP API)
- Keep invariants legible (what must stay true)
**Non-goals:**
- Exhaustive API reference
- Training hyperparameter tuning guide
---
## System Map
```text
                            SQLEnv System
  ================================================================

  RL Training                               SQLEnv Server (Docker)
  ─────────────                             ──────────────────────
  +──────────────+                          +─────────────────────+
  │ TRL GRPO     │                          │ server/app.py       │
  │ Trainer      │      HTTP (JSON)         │ FastAPI + OpenEnv   │
  │              │<========================>│                     │
  │ training/    │  SQLAction -> server     +──────────┬──────────+
  │ trl_adapter  │  SQLObs <- server                   │
  │ .py          │                                     v
  +──────────────+                          +─────────────────────+
         │                                  │ SQLEnvironment      │
         │ OR                               │ (sql_environment.py)│
         v                                  │                     │
  +──────────────+                          │ reset() / step()    │
  │ Custom       │                          │ action dispatch     │
  │ rollout_func │                          +──┬──────┬──────┬────+
  │ (rollout.py) │                             │      │      │
  +──────────────+                             v      v      v
                                          +────────────────────────+
  Evaluation                              │ Action Handlers        │
  ──────────                              │ DESCRIBE → PRAGMA      │
  +──────────────+                        │ SAMPLE   → SELECT N    │
  │ evaluate()   │──> env.reset/step      │ QUERY    → SQL exec    │
  │ policies     │                        │ ANSWER   → verifier    │
  │ .py          │                        +────────┬───────────────+
  +──────────────+                                 │
         │                                         v
  +──────────────+                        +────────────────────────+
  │ Policies     │                        │ SQLite (read-only)     │
  │ RandomPolicy │                        │ data/databases/        │
  │ OraclePolicy │                        │ {db_id}/{db_id}.sqlite │
  +──────────────+                        +────────┬───────────────+
                                                   │
                                           ┌───────┴───────┐
                                           v               v
                                     +───────────+   +───────────+
                                     │ reward.py │   │verifier.py│
                                     │ 3-layer   │   │ type-aware│
                                     │ dense     │   │ comparison│
                                     +───────────+   +───────────+

  Data (committed)                        Synthetic (optional)
  ────────────────                        ────────────────────
  data/questions/                         server/synthetic/
    questions_train.json (473 Q)            generate.py
    questions_eval.json  (203 Q)            mutations.py
    db_list.json  (10 databases)            validate.py
```
---
## Component Inventory
| Component | Owns | Entrypoint | State / Output |
|-----------|------|------------|----------------|
| **SQLEnvironment** | Episode lifecycle, action dispatch, step budget | `server/sql_environment.py` | `EpisodeContext` (in-memory, per episode) |
| **FastAPI app** | HTTP endpoints, tokenizer factory | `server/app.py` | Stateless (delegates to environment) |
| **SQLEnvClient** | HTTP transport, payload serialization | `client.py` | Stateless (wraps server) |
| **Pydantic models** | Type contracts (action, observation, state) | `models.py` | N/A (data classes) |
| **Reward engine** | 3-layer dense reward computation | `server/reward.py` | Mutates `EpisodeContext` accumulators |
| **Answer verifier** | Type-aware answer comparison | `server/verifier.py` | Stateless (pure function) |
| **GRPO pipeline** | Training orchestration, rollout, reward callables | `training/` (6 modules) | Training artifacts in `outputs/` |
| **TRL adapter** | `environment_factory` for TRL GRPOTrainer | `training/trl_adapter.py` | Per-session environment instances |
| **Evaluation** | Policy protocol, evaluate() runner | `evaluation/policies.py` | `EvaluationResult` metrics |
| **Oracle policy** | Deterministic upper-bound baseline | `evaluation/oracle_policy.py` | Stateless per-step |
| **Synthetic DB gen** | Metamorphic testing via data mutations | `server/synthetic/` | Variant SQLite files |
| **Question dataset** | 676 curated Spider questions across 10 DBs | `data/questions/` | JSON files |
### External Dependencies
| Dependency | Purpose | Required |
|------------|---------|----------|
| SQLite (stdlib) | Database execution | Yes |
| OpenEnv (`openenv-core`) | Environment protocol, `create_app` | Yes |
| TRL (`trl`) | GRPO training | Only for training |
| HuggingFace Transformers | Tokenizer loading | Only for production server |
---
## Key Flows
### Flow: Episode (Reset + Multi-Turn Steps)
```text
Client / Policy                          SQLEnvironment
       │                                        │
       │── reset(seed=42) ────────────────────> │
       │                                        │── pick question (random or seeded)
       │                                        │── open read-only SQLite connection
       │                                        │── execute gold_sql → store gold_rows
       │                                        │── init EpisodeContext (budget=15)
       │ <── SQLObservation ────────────────────│
       │   .question="How many students?"       │
       │   .schema_info="Tables: student"       │ (column details hidden)
       │   .budget_remaining=15                 │
       │                                        │
       │── step(DESCRIBE student) ────────────> │
       │                                        │── PRAGMA table_info(student)
       │                                        │── add to described_tables
       │                                        │── compute_step_reward()
       │ <── SQLObservation ────────────────────│
       │   .schema_info="student: id INT"       │ (columns now revealed)
       │   .result="5 columns, 20 rows"         │
       │   .reward=0.02                         │
       │   .budget_remaining=14                 │
       │                                        │
       │── step(QUERY "SELECT COUNT(*)...") ──> │
       │                                        │── validate (SELECT-only, single stmt)
       │                                        │── execute with 5s timeout
       │                                        │── compute_step_reward() (L1 + L2)
       │ <── SQLObservation ────────────────────│
       │   .result="| COUNT(*) |\n| 20 |"       │
       │   .reward=0.035                        │
       │                                        │
       │── step(ANSWER "20") ─────────────────> │
       │                                        │── verify_answer("20", gold, type)
       │                                        │── terminal reward: +1.0 or 0.0
       │ <── SQLObservation ────────────────────│
       │   .done=true                           │
       │   .reward=1.0                          │
```
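The DESCRIBE and SAMPLE steps above reduce to ordinary SQLite calls. The sketch below is a minimal illustration, not the server's handler code: the helper names (`describe_table`, `sample_table`), the output string format, the default sample size of 5, and the table-name allowlist check are all assumptions.

```python
import sqlite3


def _quote_ident(name: str) -> str:
    """Quote an SQLite identifier, doubling any embedded double quotes."""
    return '"' + name.replace('"', '""') + '"'


def _known_tables(conn: sqlite3.Connection) -> set[str]:
    """User table names; used to reject unknown or malicious table arguments."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%'"
    ).fetchall()
    return {name for (name,) in rows}


def describe_table(conn: sqlite3.Connection, table: str) -> str:
    """DESCRIBE: column names and declared types via PRAGMA table_info, plus a row count."""
    if table not in _known_tables(conn):
        raise ValueError(f"unknown table: {table}")
    # Each table_info row is (cid, name, type, notnull, dflt_value, pk).
    cols = conn.execute(f"PRAGMA table_info({_quote_ident(table)})").fetchall()
    (n_rows,) = conn.execute(f"SELECT COUNT(*) FROM {_quote_ident(table)}").fetchone()
    col_desc = ", ".join(f"{col[1]} {col[2]}" for col in cols)
    return f"{table}: {col_desc} ({len(cols)} columns, {n_rows} rows)"


def sample_table(conn: sqlite3.Connection, table: str, n: int = 5) -> list[tuple]:
    """SAMPLE: the first n rows of a table (SELECT ... LIMIT n)."""
    if table not in _known_tables(conn):
        raise ValueError(f"unknown table: {table}")
    return conn.execute(f"SELECT * FROM {_quote_ident(table)} LIMIT ?", (n,)).fetchall()


if __name__ == "__main__":
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE student (id INTEGER, name TEXT)")
    conn.executemany("INSERT INTO student VALUES (?, ?)", [(i, f"s{i}") for i in range(20)])
    print(describe_table(conn, "student"))  # student: id INTEGER, name TEXT (2 columns, 20 rows)
    print(sample_table(conn, "student", 2))
```

Checking the argument against `sqlite_master` before interpolating it matters because SQLite can bind values as parameters but not identifiers such as table names.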
### Flow: 3-Layer Reward Computation
```text
step() called with action
        │
        v
Layer 1: Operational Shaping (every action)
  ├── exec_ok?      → +0.02
  ├── new SQL hash? → +0.01 (per unique query, no cumulative cap)
  ├── repeated SQL? → -0.01
  └── step cost     → -0.005
        │
        v  (only if action_type == QUERY and no error)
Layer 2: Progress Shaping (delta-from-previous, PBRS)
  ├── cardinality score (25%): |pred_rows - gold_rows| / max
  ├── value overlap     (50%): Jaccard of cell values
  └── numeric range     (25%): log-distance proximity
        │
        v
  bin to {0.0, 0.25, 0.5, 0.75, 1.0}
  delta = binned - previous_progress → delta * 0.15
  (positive = improvement, negative = regression)
        │
        v
  Clip per step to [-0.05, +0.15]
  No cumulative tracking
        │
        v  (on ANSWER action)
Layer 3: Terminal Correctness
  └── verify_answer() → +1.0 (correct) or 0.0 (wrong)
```
### Flow: TRL Training Integration
```text
GRPOTrainer
    │
    │── discovers tool methods via docstrings
    │     (describe, sample, query, answer)
    │
    │── per rollout:
    │     SQLEnvTRL()        → SQLEnvironment (internal)
    │       .reset()         → observation string
    │       .describe(table) → schema string
    │       .query(sql)      → result string
    │       .answer(value)   → final string
    │
    │── reward:
    │     sql_env_reward_func() → accumulated .reward
    │
    v
Training loop (GRPO: generate N completions per prompt, compute group-relative advantages from their rewards)
```
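In GRPO the accumulated `.reward` values are not used as a ranking: each completion's reward is compared with the other completions sampled for the same prompt. A minimal sketch of that group-relative normalization, independent of TRL's internals (whose scaling options and epsilon are not reproduced here); `group_advantages` is a hypothetical helper, and using the population standard deviation is an assumption.

```python
import statistics


def group_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """GRPO-style advantages: (reward - group mean) / (group std + eps)."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; eps guards a zero-variance group
    return [(r - mean) / (std + eps) for r in rewards]


if __name__ == "__main__":
    # e.g. four completions for the same question, two of which answered correctly
    print(group_advantages([1.03, 0.02, 0.01, 1.05]))
```

When every completion in a group earns the same reward, every advantage is zero and the group contributes no policy-gradient signal; differences in the dense Layer 1 and Layer 2 terms are what can break such ties.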
---
## Shared Data Models
Defined in `models.py`. These cross the HTTP boundary between client and server.
### SQLAction (agent -> server)
```python
class SQLAction(Action):
    action_type: str  # DESCRIBE | SAMPLE | QUERY | ANSWER
    argument: str  # table name, SQL string, or answer value
```
### SQLObservation (server -> agent)
```python
class SQLObservation(Observation):
    question: str  # NL question to answer
    schema_info: str  # known schema (incrementally revealed)
    result: str  # last action result (truncated)
    error: str  # error message if action failed
    step_count: int  # current step number
    budget_remaining: int  # steps left
    action_history: list[str]  # summary of prior actions
    # Inherited: done (bool), reward (float | None)
```
### SQLState (metadata endpoint)
```python
class SQLState(State):
    history_messages: list[Message]
    current_action_type: str
```
### Server-Only Types (never sent to agent)
```python
import sqlite3
from dataclasses import dataclass, field


@dataclass
class QuestionRecord:
    question_id: str
    question_text: str
    database_name: str
    gold_sql: str
    gold_answer: str
    answer_type: str  # integer | float | string | list
    difficulty: str  # easy | medium | hard
    tables_involved: list[str]


@dataclass
class EpisodeContext:
    episode_id: str
    db_connection: sqlite3.Connection
    question_record: QuestionRecord
    step_count: int = 0
    budget: int = 15
    # Fields after a defaulted field need defaults too; mutable ones use default_factory.
    described_tables: set[str] = field(default_factory=set)
    action_log: list[str] = field(default_factory=list)
    done: bool = False
    gold_answer: str | None = None
    gold_rows: list[tuple] = field(default_factory=list)
    # Reward accumulators
    query_hashes: set[str] = field(default_factory=set)
    best_progress: float = 0.0
    cumulative_step_reward: float = 0.0
    cumulative_new_info_reward: float = 0.0
```
**POMDP design:** The agent sees `SQLObservation`. The server holds `EpisodeContext`. The agent never sees gold answers, progress scores, or the full database. This separation forces exploration.
---
## API Contracts
### HTTP (OpenEnv Protocol)
The server exposes HTTP endpoints via `openenv.core.env_server.create_app()`.
| Operation | Method | Payload | Response |
|-----------|--------|---------|----------|
| Reset | `POST /reset` | `{seed: int}` (optional) | `SQLObservation` (JSON) |
| Step | `POST /step` | `{action_type, argument, metadata}` | `{observation, reward, done, info}` |
| State | `GET /state` | (none) | `SQLState` (JSON) |
### Evaluation API
```python
# Policy protocol
class Policy(Protocol):
    def select_action(self, observation: SQLObservation) -> SQLAction: ...
# Built-in policies
RandomPolicy() # random baseline
OraclePolicy(questions) # gold-answer upper bound
# Runner
evaluate(env, policy, n_episodes, seed) -> EvaluationResult
# .success_rate, .avg_reward, .avg_steps, .episodes[]
```
### TRL Adapter API
```python
SQLEnvTRL.configure(questions_path, db_dir, step_budget) # class method
# Tool methods (auto-discovered by TRL):
SQLEnvTRL.describe(table_name: str) -> str
SQLEnvTRL.sample(table_name: str) -> str
SQLEnvTRL.query(sql: str) -> str
SQLEnvTRL.answer(value: str) -> str
```
---
## Cross-Cutting Concerns
### SQL Safety
All database access enforces:
- **Read-only** SQLite connections (`file:...?mode=ro`)
- **SELECT-only**: rejects INSERT, UPDATE, DELETE, ALTER, DROP
- **Single statement** β€” rejects `; ...` (no stacked queries)
- **5-second timeout** via SQLite progress handler
- **20-row truncation** on all result sets
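A minimal sketch of layering these guards with the standard `sqlite3` module. The prefix check, the naive semicolon test (it would also reject a `;` inside a string literal), and the progress-handler interval of 1000 SQLite VM instructions are assumptions; `validate_select` and `run_guarded_query` are hypothetical names, not the server's functions.

```python
import sqlite3
import time

MAX_ROWS = 20     # result truncation limit
TIMEOUT_S = 5.0   # wall-clock budget per query


def validate_select(sql: str) -> str:
    """Reject anything that is not a single SELECT statement."""
    stmt = sql.strip().rstrip(";").strip()
    if ";" in stmt:
        raise ValueError("only a single statement is allowed")
    if not stmt.lower().startswith("select"):
        raise ValueError("only SELECT statements are allowed")
    return stmt


def run_guarded_query(conn: sqlite3.Connection, sql: str,
                      timeout_s: float = TIMEOUT_S, max_rows: int = MAX_ROWS) -> list[tuple]:
    """Validate, execute under a deadline, and return at most max_rows rows."""
    stmt = validate_select(sql)
    deadline = time.monotonic() + timeout_s
    # A non-zero return from the handler aborts the running statement.
    conn.set_progress_handler(lambda: int(time.monotonic() > deadline), 1000)
    try:
        return conn.execute(stmt).fetchmany(max_rows)
    finally:
        conn.set_progress_handler(None, 0)  # remove the handler
```

An aborted statement surfaces in Python as `sqlite3.OperationalError`. String checks like these are easy to bypass, which is why the read-only connection remains the real guarantee against writes.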
### POMDP Structure
The partial observability is deliberate and load-bearing:
- Agent sees table names at reset but **not column details** (must DESCRIBE)
- Query results are **truncated** (at most 20 rows)
- Agent never sees `gold_answer`, `best_progress`, or `gold_rows`
- Step budget (default 15) forces strategic allocation of exploration
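The reveal rules above can be sketched as pure functions from server-side state to agent-visible fields. Assumptions: the server keeps a full `table -> columns` map (here `full_schema`, hypothetical), the `schema_info` format only mirrors the examples in the episode flow, and `to_observation` with its `AGENT_VISIBLE` allowlist is a hypothetical projection.

```python
AGENT_VISIBLE = frozenset({
    "question", "schema_info", "result", "error",
    "step_count", "budget_remaining", "action_history", "done", "reward",
})


def render_schema_info(full_schema: dict[str, list[str]], described_tables: set[str]) -> str:
    """All table names are listed; columns appear only for tables already DESCRIBEd."""
    revealed = [f"{t}: {', '.join(full_schema[t])}"
                for t in sorted(full_schema) if t in described_tables]
    hidden = [t for t in sorted(full_schema) if t not in described_tables]
    if hidden:
        revealed.append("Tables: " + ", ".join(hidden))
    return "\n".join(revealed)


def to_observation(server_state: dict) -> dict:
    """Allowlist projection: copy only agent-visible keys, so gold data never leaks."""
    return {k: v for k, v in server_state.items() if k in AGENT_VISIBLE}
```

An allowlist fails closed: a new server-side field such as a debug score stays hidden unless it is explicitly added to `AGENT_VISIBLE`.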
### Import Compatibility
Modules use dual import paths so the same code runs both locally and inside Docker:
```python
try:
    from sql_env.models import SQLAction  # local / pip install
except ImportError:
    from models import SQLAction  # Docker (PYTHONPATH=/app/env)
```
### Configuration
| Variable | Required | Default | Purpose |
|----------|----------|---------|---------|
| `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON |
| `DB_DIR` | No | `data/databases/` | SQLite database directory |
| `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer |
| `PORT` | No | `8000` | Server port |
---
## Data, State, and Storage
### Committed Data
| Path | Contents |
|------|----------|
| `data/questions/questions_train.json` | 473 training questions across 10 DBs |
| `data/questions/questions_eval.json` | 203 evaluation questions across 10 DBs |
| `data/questions/db_list.json` | 10 Spider database IDs |
| `data/databases/models.py` | Legacy SQLAlchemy ORM models |
### Downloaded Data (gitignored)
Spider SQLite databases in `data/databases/{db_id}/{db_id}.sqlite`. Downloaded via `scripts/download_spider_databases.py`. The 10 databases: student_assessment, concert_singer, world_1, car_1, employee_hire_evaluation, pets_1, cre_Doc_Template_Mgt, dog_kennels, flight_2, poker_player.
### Runtime State (in-memory, per episode)
`EpisodeContext` holds all episode state: DB connection, gold data, reward accumulators, action history. Created on `reset()`, discarded when episode ends. Nothing persists between episodes.
---
<!-- ARCHITECTURE-SNAPSHOT-BEGIN -->
Snapshot (auto-managed)
- Repo signals: Python (pyproject.toml)
- Roots: tests/
- Entrypoint candidates: (none detected)
```text
tests/
  e2e/
    test_training_e2e.py
  integration/
    test_training_pipeline.py
  unit/
    test_error_handling.py
    test_grpo_config.py
    test_oracle_policy.py
    test_prompts.py
    test_reward.py
    test_rewards.py
    test_rollout.py
    test_sft_terminal_message.py
    test_trl_adapter.py
  test_evaluation.py
  test_smoke.py
  test_synthetic.py
  test_trl_adapter.py
  test_verifier.py
  test_verifier_integration.py
```
<!-- ARCHITECTURE-SNAPSHOT-END -->
---
## Infrastructure
### Development
**Prerequisites:** Python 3.11-3.12, `uv`, Docker (for deployment)
```bash
git clone <repo-url> && cd sql-env
uv sync
uv run python scripts/download_spider_databases.py
uv run pytest tests/ -v
```
### Deployment
**Target:** HuggingFace Spaces (Docker, free tier)
```bash
uv run openenv build # build Docker image
uv run openenv push # push to HF Spaces
```
The Dockerfile uses a multi-stage build based on `openenv-base`, runs as the non-root `appuser`, bundles the Spider databases, and exposes port 8000.
---
## Invariants
- Token tensors in `SQLState` grow monotonically across turns (never shrink mid-episode)
- `EpisodeContext` is server-only β€” leaking gold data to the agent breaks the POMDP
- Per-step rewards are clipped to `[-0.05, 0.15]`; the terminal reward (+1.0) always dominates accumulated exploration reward (~0.3 max)
- `tests/` must pass without GPU, without network, without downloaded databases (mocks/fixtures)
- SQL execution never mutates the database (read-only mode enforced at connection level)
---
## Glossary
| Term | Definition |
|------|------------|
| Episode | One question-answering session: reset -> N steps -> terminal |
| Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER |
| POMDP | Partially observable MDP. Agent acts under uncertainty |
| Spider | Academic text-to-SQL benchmark dataset (10 DBs used) |
| OpenEnv | Meta's RL environment framework (Environment, EnvClient) |
| Green Agent | OpenEnv's evaluation wrapper pattern |
| Oracle policy | Baseline that uses gold SQL/answer for reward ceiling validation |
| TRL | HuggingFace Transformer Reinforcement Learning library |
| GRPO | Group Relative Policy Optimization (RL algorithm used for training) |
| Dense reward | Per-step reward signal (vs sparse terminal-only reward) |
---
## References
- OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
- Spider dataset: https://huggingface.co/datasets/xlangai/spider
- TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv