# Architecture

> Last updated: 2026-03-29

System map for SQLEnv, an RL environment built on the OpenEnv framework in which agents learn interactive SQL exploration.

**Goals:**
- Show how components connect (system map + key flows)
- Make hidden state explicit (what lives where)
- Define shared interfaces (Pydantic models, HTTP API)
- Keep invariants legible (what must stay true)

**Non-goals:**
- Exhaustive API reference
- Training hyperparameter tuning guide

---

## System Map

```text
                         SQLEnv System
  ================================================================

  RL Training                                SQLEnv Server (Docker)
  ─────────────                              ──────────────────────
  +──────────────+                          +─────────────────────+
  │ TRL GRPO     │                          │ server/app.py       │
  │ Trainer      │    HTTP (JSON)           │ FastAPI + OpenEnv   │
  │              │<========================>│                     │
  │ training/    │  SQLAction  -> server    +──────────┬──────────+
  │ trl_adapter  │  SQLObs    <- server               │
  │ .py          │                                    v
  +──────────────+                          +─────────────────────+
        │                                   │ SQLEnvironment      │
        │ OR                                │ (sql_environment.py)│
        v                                   │                     │
  +──────────────+                          │ reset() / step()    │
  │ Custom       │                          │ action dispatch     │
  │ rollout_func │                          +──┬──────┬──────┬────+
  │ (rollout.py) │                             │      │      │
  +──────────────+                             v      v      v
                                         +────────────────────────+
  Evaluation                             │  Action Handlers       │
  ──────────                             │  DESCRIBE → PRAGMA     │
  +──────────────+                       │  SAMPLE   → SELECT N   │
  │ evaluate()   │──> env.reset/step     │  QUERY    → SQL exec   │
  │ policies     │                       │  ANSWER   → verifier   │
  │ .py          │                       +────────┬───────────────+
  +──────────────+                                │
        │                                         v
  +──────────────+                       +────────────────────────+
  │ Policies     │                       │ SQLite (read-only)     │
  │ RandomPolicy │                       │ data/databases/        │
  │ OraclePolicy │                       │ {db_id}/{db_id}.sqlite │
  +──────────────+                       +────────────────────────+
                                                  │
                                          ┌───────┴───────┐
                                          v               v
                                   +───────────+   +───────────+
                                   │ reward.py │   │verifier.py│
                                   │ 3-layer   │   │ type-aware│
                                   │ dense     │   │ comparison│
                                   +───────────+   +───────────+

  Data (committed)                    Synthetic (optional)
  ────────────────                    ────────────────────
  data/questions/                     server/synthetic/
    questions_train.json (473 Q)        generate.py
    questions_eval.json  (203 Q)        mutations.py
    db_list.json (10 databases)         validate.py
```

---

## Component Inventory

| Component | Owns | Entrypoint | State / Output |
|-----------|------|------------|----------------|
| **SQLEnvironment** | Episode lifecycle, action dispatch, step budget | `server/sql_environment.py` | `EpisodeContext` (in-memory, per episode) |
| **FastAPI app** | HTTP endpoints, tokenizer factory | `server/app.py` | Stateless (delegates to environment) |
| **SQLEnvClient** | HTTP transport, payload serialization | `client.py` | Stateless (wraps server) |
| **Pydantic models** | Type contracts (action, observation, state) | `models.py` | N/A (data classes) |
| **Reward engine** | 3-layer dense reward computation | `server/reward.py` | Mutates `EpisodeContext` accumulators |
| **Answer verifier** | Type-aware answer comparison | `server/verifier.py` | Stateless (pure function) |
| **GRPO pipeline** | Training orchestration, rollout, reward callables | `training/` (6 modules) | Training artifacts in `outputs/` |
| **TRL adapter** | `environment_factory` for TRL GRPOTrainer | `training/trl_adapter.py` | Per-session environment instances |
| **Evaluation** | Policy protocol, evaluate() runner | `evaluation/policies.py` | `EvaluationResult` metrics |
| **Oracle policy** | Deterministic upper-bound baseline | `evaluation/oracle_policy.py` | Stateless per-step |
| **Synthetic DB gen** | Metamorphic testing via data mutations | `server/synthetic/` | Variant SQLite files |
| **Question dataset** | 676 curated Spider questions across 10 DBs | `data/questions/` | JSON files |

### External Dependencies

| Dependency | Purpose | Required |
|------------|---------|----------|
| SQLite (stdlib) | Database execution | Yes |
| OpenEnv (`openenv-core`) | Environment protocol, `create_app` | Yes |
| TRL (`trl`) | GRPO training | Only for training |
| HuggingFace Transformers | Tokenizer loading | Only for production server |

---

## Key Flows

### Flow: Episode (Reset + Multi-Turn Steps)

```text
Client / Policy                  SQLEnvironment
  │                                    │
  │── reset(seed=42) ────────────────> │
  │                                    │── pick question (random or seeded)
  │                                    │── open read-only SQLite connection
  │                                    │── execute gold_sql → store gold_rows
  │                                    │── init EpisodeContext (budget=15)
  │ <── SQLObservation ────────────────│
  │     .question="How many students?" │
  │     .schema_info="Tables: student" │   (column details hidden)
  │     .budget_remaining=15           │
  │                                    │
  │── step(DESCRIBE student) ────────> │
  │                                    │── PRAGMA table_info(student)
  │                                    │── add to described_tables
  │                                    │── compute_step_reward()
  │ <── SQLObservation ────────────────│
  │     .schema_info="student: id INT" │   (columns now revealed)
  │     .result="5 columns, 20 rows"   │
  │     .reward=0.02                   │
  │     .budget_remaining=14           │
  │                                    │
  │── step(QUERY "SELECT COUNT(*)...") │
  │                                    │── validate (SELECT-only, single stmt)
  │                                    │── execute with 5s timeout
  │                                    │── compute_step_reward() (L1 + L2)
  │ <── SQLObservation ────────────────│
  │     .result="| COUNT(*) |\n| 20 |" │
  │     .reward=0.035                  │
  │                                    │
  │── step(ANSWER "20") ─────────────> │
  │                                    │── verify_answer("20", gold, type)
  │                                    │── terminal reward: +1.0 or 0.0
  │ <── SQLObservation ────────────────│
  │     .done=true                     │
  │     .reward=1.0                    │
```
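The terminal step calls `verify_answer("20", gold, type)`, and `QuestionRecord.answer_type` lists four answer types. A minimal sketch of what a type-aware comparison could look like, assuming those four types; the float tolerance, the whitespace/case normalization, the comma-separated list format, and the name `verify_answer_sketch` are illustrative assumptions, not the behavior of `server/verifier.py`.

```python
import math


def _normalize(text: str) -> str:
    """Case-fold and collapse whitespace so formatting noise does not matter."""
    return " ".join(text.strip().lower().split())


def verify_answer_sketch(predicted: str, gold: str, answer_type: str,
                         rel_tol: float = 1e-2) -> bool:
    """Hypothetical type-aware check; True means the answer earns +1.0."""
    try:
        if answer_type == "integer":
            value = float(predicted)
            # Accept "20" or "20.0", reject "20.5".
            return value.is_integer() and int(value) == int(float(gold))
        if answer_type == "float":
            return math.isclose(float(predicted), float(gold), rel_tol=rel_tol)
        if answer_type == "list":
            # Assumed format: comma-separated items, compared order-insensitively.
            pred_items = sorted(_normalize(x) for x in predicted.split(",") if x.strip())
            gold_items = sorted(_normalize(x) for x in gold.split(",") if x.strip())
            return pred_items == gold_items
    except ValueError:
        # Non-numeric text for a numeric question is simply wrong, not an error.
        return False
    # "string" (and any unknown type) falls back to normalized text equality.
    return _normalize(predicted) == _normalize(gold)
```

Returning `False` instead of raising keeps a malformed ANSWER an ordinary zero-reward terminal step.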

### Flow: 3-Layer Reward Computation

```text
step() called with action
        │
        v
  Layer 1: Operational Shaping (every action)
  ├── exec_ok?        → +0.02
  ├── new SQL hash?   → +0.01 (per unique query, no cumulative cap)
  ├── repeated SQL?   → -0.01
  └── step cost       → -0.005
        │
        v (only if action_type == QUERY and no error)
  Layer 2: Progress Shaping (delta-from-previous, PBRS)
  ├── cardinality score  (25%): |pred_rows - gold_rows| / max
  ├── value overlap      (50%): Jaccard of cell values
  └── numeric range      (25%): log-distance proximity
        │
        v
  bin to {0.0, 0.25, 0.5, 0.75, 1.0}
  delta = binned - previous_progress → delta * 0.15
  (positive = improvement, negative = regression)
        │
        v
  Clip per step to [-0.05, +0.15]
  No cumulative tracking
        │
        v (on ANSWER action)
  Layer 3: Terminal Correctness
  └── verify_answer() → +1.0 (correct) or 0.0 (wrong)
```
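The arithmetic of Layers 1 and 2 can be traced with a small sketch. It uses the constants from the diagram (Layer 1 bonuses, 25/50/25 weighting, five bins, 0.15 delta scale, per-step clip), but several details are assumptions rather than the code in `server/reward.py`: the cardinality score is taken as `1 - |diff| / max`, the numeric-range term compares log-magnitudes of the largest numeric cells, binning snaps to the nearest bin, the clip is applied to the combined Layer 1 + Layer 2 value, and all function names are hypothetical.

```python
from __future__ import annotations

import hashlib
import math

PROGRESS_BINS = (0.0, 0.25, 0.5, 0.75, 1.0)


def layer1_reward(exec_ok: bool, sql: str | None, seen_hashes: set[str]) -> float:
    """Operational shaping: step cost, execution bonus, novelty bonus/penalty."""
    reward = -0.005                      # step cost on every action
    if exec_ok:
        reward += 0.02
    if sql is not None:
        digest = hashlib.sha256(sql.strip().encode("utf-8")).hexdigest()
        if digest in seen_hashes:
            reward -= 0.01               # repeated SQL
        else:
            seen_hashes.add(digest)
            reward += 0.01               # first time this SQL is seen
    return reward


def numeric_proximity(pred_vals: set, gold_vals: set) -> float:
    """Assumed form: closeness of log-magnitudes of the largest numeric cells."""
    pred_nums = [abs(v) for v in pred_vals if isinstance(v, (int, float))]
    gold_nums = [abs(v) for v in gold_vals if isinstance(v, (int, float))]
    if not pred_nums or not gold_nums:
        return 0.0
    distance = abs(math.log10(1 + max(pred_nums)) - math.log10(1 + max(gold_nums)))
    return 1.0 / (1.0 + distance)


def progress_score(pred_rows: list[tuple], gold_rows: list[tuple]) -> float:
    """Weighted 25/50/25 blend of cardinality, value overlap, and numeric range."""
    denom = max(len(pred_rows), len(gold_rows), 1)
    cardinality = 1.0 - abs(len(pred_rows) - len(gold_rows)) / denom
    pred_vals = {v for row in pred_rows for v in row}
    gold_vals = {v for row in gold_rows for v in row}
    union = pred_vals | gold_vals
    overlap = len(pred_vals & gold_vals) / len(union) if union else 1.0  # Jaccard
    numeric = numeric_proximity(pred_vals, gold_vals)
    return 0.25 * cardinality + 0.5 * overlap + 0.25 * numeric


def bin_progress(score: float) -> float:
    """Snap to the nearest of the five bins (rounding rule is an assumption)."""
    return min(PROGRESS_BINS, key=lambda b: abs(b - score))


def step_reward(exec_ok: bool, sql: str | None, seen_hashes: set[str],
                is_query: bool, pred_rows: list[tuple], gold_rows: list[tuple],
                previous_progress: float) -> tuple[float, float]:
    """Return (clipped step reward, progress to carry into the next step)."""
    reward = layer1_reward(exec_ok, sql, seen_hashes)
    progress = previous_progress
    if is_query and exec_ok:
        progress = bin_progress(progress_score(pred_rows, gold_rows))
        reward += (progress - previous_progress) * 0.15   # PBRS-style delta
    return max(-0.05, min(0.15, reward)), progress
```

Before clipping, the Layer 2 deltas telescope: over an episode they sum to 0.15 × (final progress - starting progress), so repeating a good query earns no further progress reward.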

### Flow: TRL Training Integration

```text
  GRPOTrainer
      │
      │── discovers tool methods via docstrings
      │   (describe, sample, query, answer)
      │
      │── per rollout:
      │     SQLEnvTRL() → SQLEnvironment (internal)
      │     .reset() → observation string
      │     .describe(table) → schema string
      │     .query(sql) → result string
      │     .answer(value) → final string
      │
      │── reward:
      │     sql_env_reward_func() → accumulated .reward
      │
      v
  Training loop (GRPO: sample N completions per prompt, score each relative to the group)
```
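GRPO needs no learned value function: for each prompt it samples a group of completions, scores each with the reward function, and normalizes every reward against its own group's mean and standard deviation. A minimal stdlib sketch of that normalization step; the `eps` value, the use of the population standard deviation, and the function name are assumptions, and TRL's own implementation may differ in these details.

```python
import statistics


def group_relative_advantages(rewards: list[float], eps: float = 1e-4) -> list[float]:
    """Score each completion relative to the other completions for the same prompt."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)  # population std; eps guards all-equal groups
    return [(r - mean) / (std + eps) for r in rewards]
```

When every completion in a group earns the same reward, all advantages are zero and that group contributes no learning signal; per-step shaping rewards are one way to break such ties between otherwise identical failures.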

---

## Shared Data Models

Defined in `models.py`. These cross the HTTP boundary between client and server.

### SQLAction (agent -> server)

```python
class SQLAction(Action):
    action_type: str   # DESCRIBE | SAMPLE | QUERY | ANSWER
    argument: str      # table name, SQL string, or answer value
```

### SQLObservation (server -> agent)

```python
class SQLObservation(Observation):
    question: str              # NL question to answer
    schema_info: str           # known schema (incrementally revealed)
    result: str                # last action result (truncated)
    error: str                 # error message if action failed
    step_count: int            # current step number
    budget_remaining: int      # steps left
    action_history: list[str]  # summary of prior actions
    # Inherited: done (bool), reward (float | None)
```

### SQLState (metadata endpoint)

```python
class SQLState(State):
    history_messages: list[Message]
    current_action_type: str
```

### Server-Only Types (never sent to agent)

```python
import sqlite3
from dataclasses import dataclass, field


@dataclass
class QuestionRecord:
    question_id: str
    question_text: str
    database_name: str
    gold_sql: str
    gold_answer: str
    answer_type: str          # integer | float | string | list
    difficulty: str           # easy | medium | hard
    tables_involved: list[str]


@dataclass
class EpisodeContext:
    episode_id: str
    db_connection: sqlite3.Connection
    question_record: QuestionRecord
    step_count: int = 0
    budget: int = 15
    # Collections use default_factory so each episode gets its own empty set/list.
    described_tables: set[str] = field(default_factory=set)
    action_log: list[str] = field(default_factory=list)
    done: bool = False
    gold_answer: str | None = None
    gold_rows: list[tuple] = field(default_factory=list)
    # Reward accumulators
    query_hashes: set[str] = field(default_factory=set)
    best_progress: float = 0.0
    cumulative_step_reward: float = 0.0
    cumulative_new_info_reward: float = 0.0
```

**POMDP design:** The agent sees `SQLObservation`. The server holds `EpisodeContext`. The agent never sees gold answers, progress scores, or the full database. This separation forces exploration.

---

## API Contracts

### HTTP (OpenEnv Protocol)

The server exposes HTTP endpoints via `openenv.core.env_server.create_app()`.

| Operation | Method | Payload | Response |
|-----------|--------|---------|----------|
| Reset | `POST /reset` | `{seed: int}` (optional) | `SQLObservation` (JSON) |
| Step | `POST /step` | `{action_type, argument, metadata}` | `{observation, reward, done, info}` |
| State | `GET /state` | (none) | `SQLState` (JSON) |

### Evaluation API

```python
# Policy protocol
class Policy(Protocol):
    def select_action(self, observation: SQLObservation) -> SQLAction: ...

# Built-in policies
RandomPolicy()                    # random baseline
OraclePolicy(questions)           # gold-answer upper bound

# Runner
evaluate(env, policy, n_episodes, seed) -> EvaluationResult
#   .success_rate, .avg_reward, .avg_steps, .episodes[]
```
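The runner's contract fits in a short, self-contained sketch: reset, step until `done`, then average. `StubObservation`, `StubAction`, `ToyEnv`, and `AlwaysAnswer20` are hypothetical stand-ins so the example runs without the server; the per-episode seeding (`seed + i`) is an assumption, and the success criterion (terminal reward of at least 1.0) is inferred from the terminal reward layer rather than taken from `evaluation/policies.py`.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Protocol


@dataclass
class StubObservation:
    """Stand-in for SQLObservation (only the fields this sketch reads)."""
    question: str
    done: bool = False
    reward: float | None = None


@dataclass
class StubAction:
    """Stand-in for SQLAction."""
    action_type: str
    argument: str


class Policy(Protocol):
    def select_action(self, observation: StubObservation) -> StubAction: ...


class Env(Protocol):
    def reset(self, seed: int | None = None) -> StubObservation: ...
    def step(self, action: StubAction) -> StubObservation: ...


@dataclass
class EvaluationResult:
    success_rate: float
    avg_reward: float
    avg_steps: float
    episodes: list[dict] = field(default_factory=list)


def evaluate(env: Env, policy: Policy, n_episodes: int, seed: int = 0) -> EvaluationResult:
    episodes = []
    for i in range(n_episodes):
        obs = env.reset(seed=seed + i)        # distinct but reproducible seed per episode
        total, steps = 0.0, 0
        while not obs.done:                   # relies on the env's step budget to terminate
            obs = env.step(policy.select_action(obs))
            total += obs.reward or 0.0
            steps += 1
        # Assumed success criterion: the terminal ANSWER earned the +1.0 reward.
        episodes.append({"reward": total, "steps": steps,
                         "success": (obs.reward or 0.0) >= 1.0})
    n = max(n_episodes, 1)
    return EvaluationResult(
        success_rate=sum(e["success"] for e in episodes) / n,
        avg_reward=sum(e["reward"] for e in episodes) / n,
        avg_steps=sum(e["steps"] for e in episodes) / n,
        episodes=episodes,
    )


class ToyEnv:
    """Toy environment: the answer is always "20" and the budget is 3 steps."""

    def reset(self, seed: int | None = None) -> StubObservation:
        self.budget = 3
        return StubObservation("How many students?")

    def step(self, action: StubAction) -> StubObservation:
        self.budget -= 1
        if action.action_type == "ANSWER":
            return StubObservation("How many students?", done=True,
                                   reward=1.0 if action.argument == "20" else 0.0)
        return StubObservation("How many students?", done=self.budget == 0, reward=0.0)


class AlwaysAnswer20:
    def select_action(self, observation: StubObservation) -> StubAction:
        return StubAction("ANSWER", "20")


result = evaluate(ToyEnv(), AlwaysAnswer20(), n_episodes=3, seed=42)
print(result.success_rate, result.avg_steps)  # 1.0 1.0
```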

### TRL Adapter API

```python
SQLEnvTRL.configure(questions_path, db_dir, step_budget)  # class method
# Tool methods (auto-discovered by TRL):
SQLEnvTRL.describe(table_name: str) -> str
SQLEnvTRL.sample(table_name: str) -> str
SQLEnvTRL.query(sql: str) -> str
SQLEnvTRL.answer(value: str) -> str
```
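The adapter's job is mostly translation: each public method is one action type, takes and returns plain strings, and adds the step reward to a running total that a reward callable can read after the rollout. A minimal sketch under stated assumptions: the wrapped environment interface (`reset() -> str`, `step(action_type, argument) -> StepResult`), the `StepResult` type, and the class name `SQLEnvToolsSketch` are hypothetical, and the Google-style `Args:` docstrings are written on the assumption that tool schemas are derived from type hints and docstrings, as the TRL flow above describes.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Protocol


@dataclass
class StepResult:
    """Hypothetical per-step output of the wrapped environment."""
    text: str
    reward: float
    done: bool


class InnerEnv(Protocol):
    def reset(self) -> str: ...
    def step(self, action_type: str, argument: str) -> StepResult: ...


class SQLEnvToolsSketch:
    """One public method per action type; strings in, strings out."""

    def __init__(self, env: InnerEnv) -> None:
        self._env = env
        self.reward = 0.0   # running total, read by the reward callable after the rollout
        self.done = False

    def reset(self) -> str:
        self.reward, self.done = 0.0, False
        return self._env.reset()

    def _act(self, action_type: str, argument: str) -> str:
        if self.done:
            return "Episode is over; no further actions are accepted."
        result = self._env.step(action_type, argument)
        self.reward += result.reward
        self.done = result.done
        return result.text

    def describe(self, table_name: str) -> str:
        """Show the column names and types of a table.

        Args:
            table_name: Name of the table to describe.
        """
        return self._act("DESCRIBE", table_name)

    def sample(self, table_name: str) -> str:
        """Show a few example rows from a table.

        Args:
            table_name: Name of the table to sample.
        """
        return self._act("SAMPLE", table_name)

    def query(self, sql: str) -> str:
        """Run a single read-only SELECT statement.

        Args:
            sql: The SELECT statement to execute.
        """
        return self._act("QUERY", sql)

    def answer(self, value: str) -> str:
        """Submit the final answer and end the episode.

        Args:
            value: The answer to the question.
        """
        return self._act("ANSWER", value)
```

Returning a message after the episode ends, rather than raising, keeps a model that continues calling tools from crashing the rollout.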

---

## Cross-Cutting Concerns

### SQL Safety

All database access enforces:
- **Read-only** SQLite connections (`file:...?mode=ro`)
- **SELECT-only**: rejects INSERT, UPDATE, DELETE, ALTER, DROP
- **Single statement**: rejects `; ...` (no stacked queries)
- **5-second timeout** via SQLite progress handler
- **20-row truncation** on all result sets
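These guards compose into a small stdlib helper. A minimal sketch, assuming `sqlite3`'s URI syntax for the read-only connection and a progress handler for the deadline; `open_readonly`, `validate_sql`, and `run_query` are hypothetical names, and the text checks are deliberately naive (a `;` inside a string literal is rejected, and `WITH ... SELECT` queries are not allowed), so the read-only connection remains the real backstop.

```python
import sqlite3
import time

MAX_ROWS = 20
TIMEOUT_S = 5.0


def open_readonly(db_path: str) -> sqlite3.Connection:
    # mode=ro makes SQLite refuse writes at the connection level.
    return sqlite3.connect(f"file:{db_path}?mode=ro", uri=True)


def validate_sql(sql: str) -> str:
    """Reject anything that is not a single SELECT statement (naive text check)."""
    statement = sql.strip().rstrip(";").strip()
    if ";" in statement:
        raise ValueError("stacked statements are not allowed")
    if not statement.upper().startswith("SELECT"):
        raise ValueError("only SELECT statements are allowed")
    return statement


def run_query(conn: sqlite3.Connection, sql: str) -> list[tuple]:
    statement = validate_sql(sql)
    deadline = time.monotonic() + TIMEOUT_S

    def check_deadline() -> int:
        # A non-zero return value makes SQLite abort the running statement.
        return 1 if time.monotonic() > deadline else 0

    conn.set_progress_handler(check_deadline, 10_000)
    try:
        return conn.execute(statement).fetchmany(MAX_ROWS)  # truncate to 20 rows
    finally:
        conn.set_progress_handler(None, 0)  # remove the handler after each query
```

A query that overruns the deadline surfaces as `sqlite3.OperationalError`, which the caller can report through the observation's `error` field.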

### POMDP Structure

The partial observability is deliberate and load-bearing:
- Agent sees table names at reset but **not column details** (must DESCRIBE)
- Query results are **truncated** (at most 20 rows)
- Agent never sees `gold_answer`, `best_progress`, or `gold_rows`
- Step budget (default 15) forces strategic allocation of exploration
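The reveal-on-DESCRIBE rule can be sketched as a function that renders `schema_info` from the set of described tables. The output format (table names only until described, then `name TYPE` pairs) is modeled on the example observations in the episode flow, and `render_schema_info` is a hypothetical name, not the server's actual formatter.

```python
import sqlite3


def render_schema_info(conn: sqlite3.Connection, described_tables: set[str]) -> str:
    """Table names are always visible; columns appear only after DESCRIBE."""
    tables = [name for (name,) in conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name")]
    lines = [f"Tables: {', '.join(tables)}"]
    for table in tables:
        if table in described_tables:
            quoted = table.replace('"', '""')  # escape for a quoted identifier
            # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
            columns = conn.execute(f'PRAGMA table_info("{quoted}")').fetchall()
            lines.append(f"{table}: " + ", ".join(f"{c[1]} {c[2]}" for c in columns))
    return "\n".join(lines)
```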

### Import Compatibility

Modules use dual import paths so the same code runs both locally and inside Docker:
```python
try:
    from sql_env.models import SQLAction      # local / pip install
except ImportError:
    from models import SQLAction              # Docker (PYTHONPATH=/app/env)
```

### Configuration

| Variable | Required | Default | Purpose |
|----------|----------|---------|---------|
| `QUESTIONS_PATH` | No | `data/questions/student_assessment.json` | Questions JSON |
| `DB_DIR` | No | `data/databases/` | SQLite database directory |
| `TOKENIZER_NAME` | No | `mistralai/Mistral-7B-Instruct-v0.1` | HuggingFace tokenizer |
| `PORT` | No | `8000` | Server port |
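Reading these variables with their documented defaults takes a few lines of `os.environ` access. The `ServerConfig` dataclass and `load_config` helper below are hypothetical names for illustration, not necessarily how `server/app.py` is structured.

```python
from __future__ import annotations

import os
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class ServerConfig:
    questions_path: str
    db_dir: str
    tokenizer_name: str
    port: int


def load_config(env: Mapping[str, str] | None = None) -> ServerConfig:
    """Read the documented variables, falling back to the defaults in the table."""
    source = os.environ if env is None else env
    return ServerConfig(
        questions_path=source.get("QUESTIONS_PATH", "data/questions/student_assessment.json"),
        db_dir=source.get("DB_DIR", "data/databases/"),
        tokenizer_name=source.get("TOKENIZER_NAME", "mistralai/Mistral-7B-Instruct-v0.1"),
        port=int(source.get("PORT", "8000")),  # env vars are strings; convert once here
    )
```

Accepting an explicit mapping makes the defaults testable without touching the process environment.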

---

## Data, State, and Storage

### Committed Data

| Path | Contents |
|------|----------|
| `data/questions/questions_train.json` | 473 training questions across 10 DBs |
| `data/questions/questions_eval.json` | 203 evaluation questions across 10 DBs |
| `data/questions/db_list.json` | 10 Spider database IDs |
| `data/databases/models.py` | Legacy SQLAlchemy ORM models |

### Downloaded Data (gitignored)

Spider SQLite databases in `data/databases/{db_id}/{db_id}.sqlite`. Downloaded via `scripts/download_spider_databases.py`. The 10 databases: student_assessment, concert_singer, world_1, car_1, employee_hire_evaluation, pets_1, cre_Doc_Template_Mgt, dog_kennels, flight_2, poker_player.

### Runtime State (in-memory, per episode)

`EpisodeContext` holds all episode state: DB connection, gold data, reward accumulators, action history. Created on `reset()`, discarded when episode ends. Nothing persists between episodes.

---

<!-- ARCHITECTURE-SNAPSHOT-BEGIN -->

Snapshot (auto-managed)

- Repo signals: Python (pyproject.toml)
- Roots: tests/
- Entrypoint candidates: (none detected)

```text
tests/
  e2e/
    test_training_e2e.py
  integration/
    test_training_pipeline.py
  unit/
    test_error_handling.py
    test_grpo_config.py
    test_oracle_policy.py
    test_prompts.py
    test_reward.py
    test_rewards.py
    test_rollout.py
    test_sft_terminal_message.py
    test_trl_adapter.py
  test_evaluation.py
  test_smoke.py
  test_synthetic.py
  test_trl_adapter.py
  test_verifier.py
  test_verifier_integration.py
```

<!-- ARCHITECTURE-SNAPSHOT-END -->

---

## Infrastructure

### Development

**Prerequisites:** Python 3.11-3.12, `uv`, Docker (for deployment)

```bash
git clone <repo-url> && cd sql-env
uv sync
uv run python scripts/download_spider_databases.py
uv run pytest tests/ -v
```

### Deployment

**Target:** HuggingFace Spaces (Docker, free tier)

```bash
uv run openenv build           # build Docker image
uv run openenv push            # push to HF Spaces
```

The Dockerfile uses a multi-stage build with `openenv-base`, runs as the non-root `appuser`, bundles the Spider databases, and exposes port 8000.

---

## Invariants

- Token tensors in `SQLState` grow monotonically across turns (never shrink mid-episode)
- `EpisodeContext` is server-only; leaking gold data to the agent breaks the POMDP
- Per-step rewards clipped to `[-0.05, 0.15]`; terminal reward (+1.0) always dominates exploration (~0.3 max)
- `tests/` must pass without GPU, without network, without downloaded databases (mocks/fixtures)
- SQL execution never mutates the database (read-only mode enforced at connection level)

---

## Glossary

| Term | Definition |
|------|------------|
| Episode | One question-answering session: reset -> N steps -> terminal |
| Action type | One of: DESCRIBE, SAMPLE, QUERY, ANSWER |
| POMDP | Partially observable MDP. Agent acts under uncertainty |
| Spider | Academic text-to-SQL benchmark dataset (10 DBs used) |
| OpenEnv | Meta's RL environment framework (Environment, EnvClient) |
| Green Agent | OpenEnv's evaluation wrapper pattern |
| Oracle policy | Baseline that uses gold SQL/answer for reward ceiling validation |
| TRL | HuggingFace Transformer Reinforcement Learning library |
| GRPO | Group Relative Policy Optimization (RL algorithm used for training) |
| Dense reward | Per-step reward signal (vs sparse terminal-only reward) |

---

## References

- OpenEnv framework: https://github.com/meta-pytorch/OpenEnv
- Spider dataset: https://huggingface.co/datasets/xlangai/spider
- TRL OpenEnv docs: https://huggingface.co/docs/trl/openenv