| # Architecture |
|
|
| > Last updated: 2026-03-29 |
|
|
| System map for SQLEnv, an RL environment built on the OpenEnv framework in which agents learn to explore SQL databases interactively. |
|
|
| **Goals:** |
| - Show how components connect (system map + key flows) |
| - Make hidden state explicit (what lives where) |
| - Define shared interfaces (Pydantic models, HTTP API) |
| - Keep invariants legible (what must stay true) |
|
|
| **Non-goals:** |
| - Exhaustive API reference |
| - Training hyperparameter tuning guide |
|
|
| --- |
|
|
| ## System Map |
|
|
| ```text |
|                          SQLEnv System |
| ================================================================ |
| |
| RL Training                             SQLEnv Server (Docker) |
| ───────────                             ────────────────────── |
| +──────────────+                        +─────────────────────+ |
| │ TRL GRPO     │                        │ server/app.py       │ |
| │ Trainer      │      HTTP (JSON)       │ FastAPI + OpenEnv   │ |
| │              │<======================>│                     │ |
| │ training/    │  SQLAction -> server   +──────────┬──────────+ |
| │ trl_adapter  │  SQLObs    <- server              │ |
| │ .py          │                                   v |
| +──────────────+                        +─────────────────────+ |
|        │                                │ SQLEnvironment      │ |
|        │ OR                             │ (sql_environment.py)│ |
|        v                                │                     │ |
| +──────────────+                        │ reset() / step()    │ |
| │ Custom       │                        │ action dispatch     │ |
| │ rollout_func │                        +──┬──────┬──────┬────+ |
| │ (rollout.py) │                           │      │      │ |
| +──────────────+                           v      v      v |
|                                       +────────────────────────+ |
| Evaluation                            │ Action Handlers        │ |
| ──────────                            │ DESCRIBE → PRAGMA      │ |
| +──────────────+                      │ SAMPLE   → SELECT N    │ |
| │ evaluate()   │──> env.reset/step    │ QUERY    → SQL exec    │ |
| │ policies     │                      │ ANSWER   → verifier    │ |
| │ .py          │                      +─────────┬──────────────+ |
| +──────────────+                                │ |
|        v                                        v |
| +──────────────+                      +────────────────────────+ |
| │ Policies     │                      │ SQLite (read-only)     │ |
| │ RandomPolicy │                      │ data/databases/        │ |
| │ OraclePolicy │                      │ {db_id}/{db_id}.sqlite │ |
| +──────────────+                      +────────────────────────+ |
|                                                   │ |
|                                           ┌───────┴───────┐ |
|                                           v               v |
|                                     +───────────+   +───────────+ |
|                                     │ reward.py │   │verifier.py│ |
|                                     │ 3-layer   │   │ type-aware│ |
|                                     │ dense     │   │ comparison│ |
|                                     +───────────+   +───────────+ |
| |
| Data (committed)                        Synthetic (optional) |
| ────────────────                        ──────────────────── |
| data/questions/                         server/synthetic/ |
|   questions_train.json (473 Q)            generate.py |
|   questions_eval.json  (203 Q)            mutations.py |
|   db_list.json         (10 databases)     validate.py |
| ``` |
|
|
| --- |
|
|
| ## Component Inventory |
|
|
| | Component | Owns | Entrypoint | State / Output | |
| |-----------|------|------------|----------------| |
| | **SQLEnvironment** | Episode lifecycle, action dispatch, step budget | `server/sql_environment.py` | `EpisodeContext` (in-memory, per episode) | |
| | **FastAPI app** | HTTP endpoints, tokenizer factory | `server/app.py` | Stateless (delegates to environment) | |
| | **SQLEnvClient** | HTTP transport, payload serialization | `client.py` | Stateless (wraps server) | |
| | **Pydantic models** | Type contracts (action, observation, state) | `models.py` | N/A (data classes) | |
| | **Reward engine** | 3-layer dense reward computation | `server/reward.py` | Mutates `EpisodeContext` accumulators | |
| | **Answer verifier** | Type-aware answer comparison | `server/verifier.py` | Stateless (pure function) | |
| | **GRPO pipeline** | Training orchestration, rollout, reward callables | `training/` (6 modules) | Training artifacts in `outputs/` | |
| | **TRL adapter** | `environment_factory` for TRL GRPOTrainer | `training/trl_adapter.py` | Per-session environment instances | |
| | **Evaluation** | Policy protocol, evaluate() runner | `evaluation/policies.py` | `EvaluationResult` metrics | |
| | **Oracle policy** | Deterministic upper-bound baseline | `evaluation/oracle_policy.py` | Stateless per-step | |
| | **Synthetic DB gen** | Metamorphic testing via data mutations | `server/synthetic/` | Variant SQLite files | |
| | **Question dataset** | 676 curated Spider questions across 10 DBs | `data/questions/` | JSON files | |
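The **Answer verifier** row describes a type-aware comparison. The sketch below shows one way such a comparison could work, assuming the four `answer_type` values listed on `QuestionRecord` (integer, float, string, list), a relative float tolerance of `1e-6`, comma-separated order-insensitive lists, and case-insensitive strings. `verify_answer_sketch` and `_normalize_number` are hypothetical names, and the real `server/verifier.py` may normalize differently.

```python
import math


def _normalize_number(text: str) -> float:
    """Parse '1,234' or ' 20 ' as a float; raises ValueError on non-numeric text."""
    return float(text.strip().replace(",", ""))


def verify_answer_sketch(predicted: str, gold: str, answer_type: str) -> bool:
    """Hypothetical type-aware comparison (not the actual server/verifier.py)."""
    try:
        if answer_type == "integer":
            # "20" and "20.0" match 20; "20.5" does not.
            return _normalize_number(predicted) == _normalize_number(gold)
        if answer_type == "float":
            # Assumed relative tolerance; the real verifier may use another.
            return math.isclose(
                _normalize_number(predicted), _normalize_number(gold), rel_tol=1e-6
            )
    except ValueError:
        return False  # non-numeric text never matches a numeric gold answer
    if answer_type == "list":
        # Assumption: list answers are comma-separated and order-insensitive.
        def items(text: str) -> list[str]:
            return sorted(p.strip().lower() for p in text.split(",") if p.strip())

        return items(predicted) == items(gold)
    # "string" (and any unknown type): case- and whitespace-insensitive equality.
    return " ".join(predicted.lower().split()) == " ".join(gold.lower().split())
```

Keeping the verifier a pure function of `(predicted, gold, answer_type)` matches the "Stateless (pure function)" entry in the table and makes it easy to unit-test in isolation.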
|
|
| ### External Dependencies |
|
|
| | Dependency | Purpose | Required | |
| |------------|---------|----------| |
| | SQLite (stdlib) | Database execution | Yes | |
| | OpenEnv (`openenv-core`) | Environment protocol, `create_app` | Yes | |
| | TRL (`trl`) | GRPO training | Only for training | |
| | HuggingFace Transformers | Tokenizer loading | Only for production server | |
|
|
| --- |
|
|
| ## Key Flows |
|
|
| ### Flow: Episode (Reset + Multi-Turn Steps) |
|
|
| ```text |
| Client / Policy                          SQLEnvironment |
|       │                                         │ |
|       ├── reset(seed=42) ─────────────────────> │ |
|       │                                         ├── pick question (random or seeded) |
|       │                                         ├── open read-only SQLite connection |
|       │                                         ├── execute gold_sql → store gold_rows |
|       │                                         └── init EpisodeContext (budget=15) |
|       │ <── SQLObservation ─────────────────────┤ |
|       │     .question="How many students?"      │ |
|       │     .schema_info="Tables: student"      │  (column details hidden) |
|       │     .budget_remaining=15                │ |
|       │                                         │ |
|       ├── step(DESCRIBE student) ─────────────> │ |
|       │                                         ├── PRAGMA table_info(student) |
|       │                                         ├── add to described_tables |
|       │                                         └── compute_step_reward() |
|       │ <── SQLObservation ─────────────────────┤ |
|       │     .schema_info="student: id INT"      │  (columns now revealed) |
|       │     .result="5 columns, 20 rows"        │ |
|       │     .reward=0.02                        │ |
|       │     .budget_remaining=14                │ |
|       │                                         │ |
|       ├── step(QUERY "SELECT COUNT(*)...") ───> │ |
|       │                                         ├── validate (SELECT-only, single stmt) |
|       │                                         ├── execute with 5s timeout |
|       │                                         └── compute_step_reward() (L1 + L2) |
|       │ <── SQLObservation ─────────────────────┤ |
|       │     .result="| COUNT(*) |\n| 20 |"      │ |
|       │     .reward=0.035                       │ |
|       │                                         │ |
|       ├── step(ANSWER "20") ──────────────────> │ |
|       │                                         ├── verify_answer("20", gold, type) |
|       │                                         └── terminal reward: +1.0 or 0.0 |
|       │ <── SQLObservation ─────────────────────┤ |
|       │     .done=true                          │ |
|       │     .reward=1.0                         │ |
| ``` |
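The flow implies a simple budget contract: each step consumes one unit, ANSWER ends the episode, and an exhausted budget ends it as well. The sketch below models only that bookkeeping. `EpisodeSketch`, `StepResult`, and the handler callables are hypothetical stand-ins for `SQLEnvironment` and its action handlers, and charging budget for unknown action types is an assumption, not documented behavior.

```python
from dataclasses import dataclass, field
from typing import Callable


@dataclass
class StepResult:
    result: str
    done: bool
    budget_remaining: int
    error: str = ""


@dataclass
class EpisodeSketch:
    """Hypothetical budget/dispatch bookkeeping, not the real SQLEnvironment."""

    handlers: dict[str, Callable[[str], str]]  # action type -> handler(argument)
    budget: int = 15
    step_count: int = 0
    done: bool = False
    action_log: list[str] = field(default_factory=list)

    def step(self, action_type: str, argument: str) -> StepResult:
        if self.done:
            raise RuntimeError("episode already finished; call reset()")
        self.step_count += 1
        self.budget -= 1  # assumption: every step, valid or not, costs one unit
        self.action_log.append(f"{action_type} {argument}")
        handler = self.handlers.get(action_type.upper())
        if handler is None:
            result, error = "", f"unknown action type: {action_type}"
        else:
            result, error = handler(argument), ""
        # ANSWER is terminal; so is an exhausted budget.
        self.done = action_type.upper() == "ANSWER" or self.budget <= 0
        return StepResult(result, self.done, self.budget, error)
```

Raising on a step after `done` makes accidental over-stepping loud during testing; a server might instead return an error observation.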
|
|
| ### Flow: 3-Layer Reward Computation |
|
|
| ```text |
| step() called with action |
|   │ |
|   v |
| Layer 1: Operational Shaping (every action) |
|   ├── exec_ok?       → +0.02 |
|   ├── new SQL hash?  → +0.01 (per unique query, no cumulative cap) |
|   ├── repeated SQL?  → -0.01 |
|   └── step cost      → -0.005 |
|   │ |
|   v  (only if action_type == QUERY and no error) |
| Layer 2: Progress Shaping (delta-from-previous, PBRS) |
|   ├── cardinality score (25%) → 1 - |pred_rows - gold_rows| / max(pred_rows, gold_rows) |
|   ├── value overlap     (50%) → Jaccard of cell values |
|   └── numeric range     (25%) → log-distance proximity |
|   │ |
|   v |
|   bin to {0.0, 0.25, 0.5, 0.75, 1.0} |
|   delta = binned - previous_progress → delta * 0.15 |
|   (positive = improvement, negative = regression) |
|   │ |
|   v |
|   Clip per step to [-0.05, +0.15] |
|   No cumulative tracking |
|   │ |
|   v  (on ANSWER action) |
| Layer 3: Terminal Correctness |
|   └── verify_answer() → +1.0 (correct) or 0.0 (wrong) |
| ``` |
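The three layers fit in a few small functions. This is an illustrative reimplementation under stated assumptions, not `server/reward.py`. The Layer 2 components are passed in precomputed (each in `[0, 1]`), binning rounds *down* to a multiple of 0.25, SQL is normalized by stripping and lowercasing before hashing, and the clip is applied to the combined Layer 1 + Layer 2 step reward. The weights, bins, and constants are the ones listed in the flow; all function names are hypothetical.

```python
import hashlib

# Constants taken from the flow above.
EXEC_OK, NEW_SQL, REPEAT_SQL, STEP_COST = 0.02, 0.01, -0.01, -0.005
PROGRESS_SCALE, CLIP_LO, CLIP_HI = 0.15, -0.05, 0.15


def operational_reward(sql: str, exec_ok: bool, seen_hashes: set[str]) -> float:
    """Layer 1: applies to every action (pass sql='' for non-SQL actions)."""
    reward = STEP_COST + (EXEC_OK if exec_ok else 0.0)
    if sql:
        # Assumption: normalize whitespace at the ends and case before hashing.
        digest = hashlib.sha256(sql.strip().lower().encode()).hexdigest()
        reward += REPEAT_SQL if digest in seen_hashes else NEW_SQL
        seen_hashes.add(digest)
    return reward


def progress_score(cardinality: float, overlap: float, numeric: float) -> float:
    """Layer 2 raw score: 25/50/25 weighted mix, binned down to a multiple of 0.25."""
    raw = 0.25 * cardinality + 0.5 * overlap + 0.25 * numeric
    return min(int(raw * 4) / 4, 1.0)


def step_reward(layer1: float, binned: float, previous_progress: float) -> float:
    """Combine Layer 1 with the scaled Layer 2 delta, then clip per step."""
    delta = (binned - previous_progress) * PROGRESS_SCALE
    return max(CLIP_LO, min(CLIP_HI, layer1 + delta))


def terminal_reward(correct: bool) -> float:
    """Layer 3: +1.0 for a verified answer, 0.0 otherwise."""
    return 1.0 if correct else 0.0
```

Because Layer 2 pays only the *change* in binned progress, a query that merely repeats an earlier level of progress earns nothing from Layer 2, and a regression is penalized until the clip floor.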
|
|
| ### Flow: TRL Training Integration |
|
|
| ```text |
| GRPOTrainer |
|   │ |
|   ├── discovers tool methods via docstrings |
|   │     (describe, sample, query, answer) |
|   │ |
|   ├── per rollout: |
|   │     SQLEnvTRL()        → SQLEnvironment (internal) |
|   │       .reset()         → observation string |
|   │       .describe(table) → schema string |
|   │       .query(sql)      → result string |
|   │       .answer(value)   → final string |
|   │ |
|   ├── reward: |
|   │     sql_env_reward_func() → accumulated .reward |
|   │ |
|   v |
| Training loop (GRPO: generate N completions per prompt, normalize rewards within each group) |
| ``` |
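GRPO does not train a value network. Each prompt gets a group of N completions, and each completion's advantage is its reward measured against the rest of its group. The sketch below shows that normalization using the population standard deviation and a small epsilon. TRL's `GRPOTrainer` has its own options for how rewards are scaled, so treat this as an illustration of the idea rather than its exact code; `group_relative_advantages` is a hypothetical name.

```python
from statistics import fmean, pstdev


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Advantage of each completion = (reward - group mean) / (group std + eps)."""
    if not rewards:
        return []
    mean = fmean(rewards)
    std = pstdev(rewards)  # population std over the group
    return [(r - mean) / (std + eps) for r in rewards]
```

If every completion in a group earns the same reward, all advantages are zero and the group contributes no learning signal. Per-step shaping can separate completions that would otherwise tie on the terminal reward.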
|
|
| --- |
|
|
| ## Shared Data Models |
|
|
| Defined in `models.py`. These cross the HTTP boundary between client and server. |
|
|
| ### SQLAction (agent -> server) |
|
|
| ```python |
| class SQLAction(Action): |
|     action_type: str  # DESCRIBE | SAMPLE | QUERY | ANSWER |
|     argument: str     # table name, SQL string, or answer value |
| ``` |
|
|
| ### SQLObservation (server -> agent) |
|
|
| ```python |
| class SQLObservation(Observation): |
|     question: str              # NL question to answer |
|     schema_info: str           # known schema (incrementally revealed) |
|     result: str                # last action result (truncated) |
|     error: str                 # error message if action failed |
|     step_count: int            # current step number |
|     budget_remaining: int      # steps left |
|     action_history: list[str]  # summary of prior actions |
|     # Inherited: done (bool), reward (float | None) |
| ``` |
|
|
| ### SQLState (metadata endpoint) |
|
|
| ```python |
| class SQLState(State): |
|     history_messages: list[Message] |
|     current_action_type: str |
| ``` |
|
|
| ### Server-Only Types (never sent to agent) |
|
|
| ```python |
| import sqlite3 |
| from dataclasses import dataclass, field |
| |
| |
| @dataclass |
| class QuestionRecord: |
|     question_id: str |
|     question_text: str |
|     database_name: str |
|     gold_sql: str |
|     gold_answer: str |
|     answer_type: str  # integer | float | string | list |
|     difficulty: str   # easy | medium | hard |
|     tables_involved: list[str] |
| |
| |
| @dataclass |
| class EpisodeContext: |
|     episode_id: str |
|     db_connection: sqlite3.Connection |
|     question_record: QuestionRecord |
|     step_count: int = 0 |
|     budget: int = 15 |
|     # Fields after a defaulted field need defaults too; mutable ones use default_factory. |
|     described_tables: set[str] = field(default_factory=set) |
|     action_log: list[str] = field(default_factory=list) |
|     done: bool = False |
|     gold_answer: str | None = None |
|     gold_rows: list[tuple] = field(default_factory=list) |
|     # Reward accumulators |
|     query_hashes: set[str] = field(default_factory=set) |
|     best_progress: float = 0.0 |
|     cumulative_step_reward: float = 0.0 |
|     cumulative_new_info_reward: float = 0.0 |
| ``` |
|
|
| **POMDP design:** The agent sees `SQLObservation`. The server holds `EpisodeContext`. The agent never sees gold answers, progress scores, or the full database. This separation forces exploration. |
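One way to make this separation hard to break is to build each observation through an allow-list projection from server state. The sketch below uses stand-in dataclasses: `ServerState` and `AgentView` are hypothetical simplifications of `EpisodeContext` and `SQLObservation`, and deriving `budget_remaining` as `budget - step_count` is an assumption. Because the builder copies only named visible fields, a new server-side field stays hidden unless someone adds it explicitly.

```python
from dataclasses import dataclass, field


@dataclass
class ServerState:
    """Hypothetical slice of EpisodeContext (server-only)."""

    question_text: str
    schema_info: str
    step_count: int
    budget: int
    gold_answer: str
    gold_rows: list[tuple] = field(default_factory=list)
    best_progress: float = 0.0


@dataclass
class AgentView:
    """Hypothetical slice of SQLObservation: nothing gold-derived."""

    question: str
    schema_info: str
    step_count: int
    budget_remaining: int


def project(state: ServerState) -> AgentView:
    # Allow-list copy: any field not named here stays on the server.
    return AgentView(
        question=state.question_text,
        schema_info=state.schema_info,
        step_count=state.step_count,
        budget_remaining=state.budget - state.step_count,
    )
```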
|
|
| --- |
|
|
| ## API Contracts |
|
|
| ### HTTP (OpenEnv Protocol) |
|
|
| The server exposes HTTP endpoints via `openenv.core.env_server.create_app()`. |
|
|
| | Operation | Method | Payload | Response | |
| |-----------|--------|---------|----------| |
| | Reset | `POST /reset` | `{seed: int}` (optional) | `SQLObservation` (JSON) | |
| | Step | `POST /step` | `{action_type, argument, metadata}` | `{observation, reward, done, info}` | |
| | State | `GET /state` | (none) | `SQLState` (JSON) | |
|
|
| ### Evaluation API |
|
|
| ```python |
| # Policy protocol |
| class Policy(Protocol): |
|     def select_action(self, observation: SQLObservation) -> SQLAction: ... |
| |
| # Built-in policies |
| RandomPolicy() # random baseline |
| OraclePolicy(questions) # gold-answer upper bound |
| |
| # Runner |
| evaluate(env, policy, n_episodes, seed) -> EvaluationResult |
| # .success_rate, .avg_reward, .avg_steps, .episodes[] |
| ``` |
|
|
| ### TRL Adapter API |
|
|
| ```python |
| SQLEnvTRL.configure(questions_path, db_dir, step_budget) # class method |
| # Tool methods (auto-discovered by TRL): |
| SQLEnvTRL.describe(table_name: str) -> str |
| SQLEnvTRL.sample(table_name: str) -> str |
| SQLEnvTRL.query(sql: str) -> str |
| SQLEnvTRL.answer(value: str) -> str |
| ``` |
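The adapter has two jobs. It exposes the four actions as plain methods whose docstrings describe their arguments, since the TRL flow notes that tools are discovered via docstrings. It also keeps a running reward on the instance. The sketch below shows that shape, assuming Google-style `Args:` docstrings and a hypothetical `step_fn` callable in place of a real `SQLEnvironment`. `SQLEnvTRLSketch` and `reward_from_envs` are illustrative names, and the exact way TRL hands environment instances to reward functions is not shown here.

```python
from typing import Callable

# Hypothetical backend: (action_type, argument) -> (result_text, reward, done)
StepFn = Callable[[str, str], tuple[str, float, bool]]


class SQLEnvTRLSketch:
    def __init__(self, step_fn: StepFn) -> None:
        self._step_fn = step_fn
        self.reward = 0.0  # accumulated across the whole rollout
        self.done = False

    def _act(self, action_type: str, argument: str) -> str:
        if self.done:
            return "Episode is over."
        text, reward, self.done = self._step_fn(action_type, argument)
        self.reward += reward
        return text

    def describe(self, table_name: str) -> str:
        """Show the columns of a table.

        Args:
            table_name: Name of the table to describe.
        """
        return self._act("DESCRIBE", table_name)

    def sample(self, table_name: str) -> str:
        """Show a few example rows from a table.

        Args:
            table_name: Name of the table to sample.
        """
        return self._act("SAMPLE", table_name)

    def query(self, sql: str) -> str:
        """Run a read-only SELECT query.

        Args:
            sql: A single SELECT statement.
        """
        return self._act("QUERY", sql)

    def answer(self, value: str) -> str:
        """Submit the final answer and end the episode.

        Args:
            value: The answer to the question.
        """
        return self._act("ANSWER", value)


def reward_from_envs(envs: list[SQLEnvTRLSketch]) -> list[float]:
    """One scalar per rollout: the reward accumulated on each instance."""
    return [env.reward for env in envs]
```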
|
|
| --- |
|
|
| ## Cross-Cutting Concerns |
|
|
| ### SQL Safety |
|
|
| All database access enforces: |
| - **Read-only** SQLite connections (`file:...?mode=ro`) |
| - **SELECT-only**: rejects INSERT, UPDATE, DELETE, ALTER, DROP |
| - **Single statement**: rejects `; ...` (no stacked queries) |
| - **5-second timeout** via SQLite progress handler |
| - **20-row truncation** on all result sets |
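The five guards above can be combined in one helper using only the stdlib `sqlite3` module. This is a minimal sketch that assumes a leading-keyword check is enough for SELECT-only enforcement. The naive checks also reject a `;` inside a string literal and do not allow `WITH ... SELECT`, and the real server may parse more carefully. `run_readonly_query` is a hypothetical name.

```python
import sqlite3
import time
from pathlib import Path


def run_readonly_query(db_path: str, sql: str, timeout_s: float = 5.0, max_rows: int = 20) -> list[tuple]:
    """Execute one SELECT against a read-only connection, with a time limit and row cap."""
    stmt = sql.strip().rstrip(";").strip()
    if ";" in stmt:
        raise ValueError("only a single statement is allowed")
    if not stmt.lower().startswith("select"):
        raise ValueError("only SELECT statements are allowed")

    # mode=ro makes SQLite itself refuse writes, independent of the checks above.
    uri = Path(db_path).resolve().as_uri() + "?mode=ro"
    conn = sqlite3.connect(uri, uri=True)
    deadline = time.monotonic() + timeout_s
    # Called every 1000 SQLite VM instructions; a truthy return aborts the
    # running statement with sqlite3.OperationalError ("interrupted").
    conn.set_progress_handler(lambda: time.monotonic() > deadline, 1000)
    try:
        cursor = conn.execute(stmt)
        return cursor.fetchmany(max_rows)  # truncate to at most max_rows
    finally:
        conn.close()
```

The string checks are only a first filter. The read-only connection is what actually guarantees that the database is never mutated.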
|
|
| ### POMDP Structure |
|
|
| The partial observability is deliberate and load-bearing: |
| - Agent sees table names at reset but **not column details** (must DESCRIBE) |
| - Query results are **truncated** (at most 20 rows) |
| - Agent never sees `gold_answer`, `best_progress`, or `gold_rows` |
| - Step budget (default 15) forces strategic allocation of exploration |
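Incremental schema reveal is what DESCRIBE exists for. The sketch below shows how `schema_info` could be rendered from a SQLite connection plus the episode's `described_tables` set. It assumes the `Tables: ...` and `table: col TYPE` text formats shown in the episode flow. `render_schema_info` is a hypothetical helper, not the server's formatter.

```python
import sqlite3


def render_schema_info(conn: sqlite3.Connection, described: set[str]) -> str:
    """List every user table; include column details only for tables already DESCRIBEd."""
    tables = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master "
            "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
        )
    ]
    lines = ["Tables: " + ", ".join(tables)]
    for name in tables:
        if name in described:
            # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
            cols = conn.execute(f'PRAGMA table_info("{name}")').fetchall()
            lines.append(f"{name}: " + ", ".join(f"{c[1]} {c[2]}" for c in cols))
    return "\n".join(lines)
```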
|
|
| ### Import Compatibility |
|
|
| Modules use dual import paths so the same code runs both locally and inside Docker: |
| ```python |
| try: |
|     from sql_env.models import SQLAction  # local / pip install |
| except ImportError: |
|     from models import SQLAction  # Docker (PYTHONPATH=/app/env) |
| ``` |
|
|
| ### Configuration |
|
|
| | Variable | Required | Default | Purpose | |
| |----------|----------|---------|---------| |
| | `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON | |
| | `DB_DIR` | No | `data/databases/` | SQLite database directory | |
| | `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer | |
| | `PORT` | No | `8000` | Server port | |
|
|
| --- |
|
|
| ## Data, State, and Storage |
|
|
| ### Committed Data |
|
|
| | Path | Contents | |
| |------|----------| |
| | `data/questions/questions_train.json` | 473 training questions across 10 DBs | |
| | `data/questions/questions_eval.json` | 203 evaluation questions across 10 DBs | |
| | `data/questions/db_list.json` | 10 Spider database IDs | |
| | `data/databases/models.py` | Legacy SQLAlchemy ORM models | |
|
|
| ### Downloaded Data (gitignored) |
|
|
| Spider SQLite databases in `data/databases/{db_id}/{db_id}.sqlite`. Downloaded via `scripts/download_spider_databases.py`. The 10 databases: student_assessment, concert_singer, world_1, car_1, employee_hire_evaluation, pets_1, cre_Doc_Template_Mgt, dog_kennels, flight_2, poker_player. |
| |
| ### Runtime State (in-memory, per episode) |
| |
| `EpisodeContext` holds all episode state: DB connection, gold data, reward accumulators, action history. Created on `reset()`, discarded when episode ends. Nothing persists between episodes. |
| |
| --- |
| |
| <!-- ARCHITECTURE-SNAPSHOT-BEGIN --> |
| |
| Snapshot (auto-managed) |
| |
| - Repo signals: Python (pyproject.toml) |
| - Roots: tests/ |
| - Entrypoint candidates: (none detected) |
| |
| ```text |
| tests/ |
|   e2e/ |
|     test_training_e2e.py |
|   integration/ |
|     test_training_pipeline.py |
|   unit/ |
|     test_error_handling.py |
|     test_grpo_config.py |
|     test_oracle_policy.py |
|     test_prompts.py |
|     test_reward.py |
|     test_rewards.py |
|     test_rollout.py |
|     test_sft_terminal_message.py |
|     test_trl_adapter.py |
|   test_evaluation.py |
|   test_smoke.py |
|   test_synthetic.py |
|   test_trl_adapter.py |
|   test_verifier.py |
|   test_verifier_integration.py |
| ``` |
| |
| <!-- ARCHITECTURE-SNAPSHOT-END --> |
|
|
| --- |
|
|
| ## Infrastructure |
|
|
| ### Development |
|
|
| **Prerequisites:** Python 3.11-3.12, `uv`, Docker (for deployment) |
|
|
| ```bash |
| git clone <repo-url> && cd sql-env |
| uv sync |
| uv run python scripts/download_spider_databases.py |
| uv run pytest tests/ -v |
| ``` |
|
|
| ### Deployment |
|
|
| **Target:** HuggingFace Spaces (Docker, free tier) |
|
|
| ```bash |
| uv run openenv build # build Docker image |
| uv run openenv push # push to HF Spaces |
| ``` |
|
|
| The Dockerfile uses a multi-stage build with `openenv-base`, runs as the non-root `appuser`, bundles the Spider databases, and exposes port 8000. |
|
|
| --- |
|
|
| ## Invariants |
|
|
| - Message history in `SQLState` (`history_messages`) grows monotonically across turns (never shrinks mid-episode) |
| - `EpisodeContext` is server-only; leaking gold data to the agent breaks the POMDP |
| - Per-step rewards are clipped to `[-0.05, 0.15]`, so the terminal reward (+1.0) always dominates accumulated exploration reward (~0.3 max) |
| - `tests/` must pass without GPU, without network, without downloaded databases (mocks/fixtures) |
| - SQL execution never mutates the database (read-only mode enforced at connection level) |
|
|
| --- |
|
|
| ## Glossary |
|
|
| | Term | Definition | |
| |------|------------| |
| | Episode | One question-answering session: reset -> N steps -> terminal | |
| | Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER | |
| | POMDP | Partially observable MDP. Agent acts under uncertainty | |
| | Spider | Academic text-to-SQL benchmark dataset (10 DBs used) | |
| | OpenEnv | Meta's RL environment framework (Environment, EnvClient) | |
| | Green Agent | OpenEnv's evaluation wrapper pattern | |
| | Oracle policy | Baseline that uses gold SQL/answer for reward ceiling validation | |
| | TRL | HuggingFace Transformer Reinforcement Learning library | |
| | GRPO | Group Relative Policy Optimization (RL algorithm used for training) | |
| | Dense reward | Per-step reward signal (vs sparse terminal-only reward) | |
|
|
| --- |
|
|
| ## References |
|
|
| - OpenEnv framework: https://github.com/meta-pytorch/OpenEnv |
| - Spider dataset: https://huggingface.co/datasets/xlangai/spider |
| - TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv |
|
|