---
title: SQL Data Analyst
emoji: 📊
colorFrom: gray
colorTo: green
sdk: docker
sdk_version: "0.1"
app_file: app.py
pinned: false
license: mit
---

An RL training environment where an AI agent learns to answer business intelligence questions by writing and executing SQL queries against a live database.

## Motivation

Data analysts spend significant time translating business questions into SQL queries. This environment trains agents to do exactly that: iteratively exploring a database schema, writing queries, observing results, and submitting final answers.

## Quick Start

```bash
# Install dependencies
pip install -r requirements.txt

# Run tests
pytest tests/ -v
```

## Observation Space

| Field | Type | Description |
|-------|------|-------------|
| `schema_summary` | string | Compact DB schema (one line per table) |
| `question` | string | Natural language business question |
| `last_query` | string \| null | Most recent SQL query |
| `last_result` | object \| null | Query result: columns, rows (max 50), error |
| `last_error` | string \| null | SQL error if last query failed |
| `step` | int | Current step number |
| `max_steps` | int | Episode step limit |
| `hints` | string[] | Progressive hints (revealed after step 5, 10, 15) |
| `done` | bool | Whether episode is complete |
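
The progressive reveal of `hints` can be sketched as below. This is a minimal illustration, not the environment's actual code: the helper name `visible_hints` is hypothetical, and it assumes that "after step 5" means the hint appears once `step` exceeds 5.

```python
from typing import List

# Thresholds taken from the `hints` row above (revealed after step 5, 10, 15).
HINT_STEPS = (5, 10, 15)


def visible_hints(step: int, all_hints: List[str]) -> List[str]:
    """Return the hints unlocked at `step` (hypothetical helper).

    Assumption: a hint unlocks once `step` is strictly greater than its threshold.
    """
    unlocked = sum(1 for threshold in HINT_STEPS if step > threshold)
    return all_hints[:unlocked]


hints = ["check the users table", "filter on created_at", "use COUNT(*)"]
print(visible_hints(4, hints))   # no hints yet
print(visible_hints(11, hints))  # first two hints
```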

## Action Space

The agent must submit exactly one of:

| Action | Type | Description |
|--------|------|-------------|
| `sql_query` | string | A SELECT or WITH SQL query to execute |
| `submit_answer` | string | Final answer; ends the episode |
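
The "exactly one of" rule could be enforced along these lines. This is a sketch using a plain dataclass; the class name `ActionSketch` is hypothetical, and the real `Action` model in `env/models.py` (a Pydantic model) may validate differently.

```python
import re
from dataclasses import dataclass
from typing import Optional

# Only read-only statements are accepted, per the table above.
_READ_ONLY = re.compile(r"\s*(select|with)\b", re.IGNORECASE)


@dataclass(frozen=True)
class ActionSketch:
    """Hypothetical stand-in for the environment's Action model."""

    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None

    def __post_init__(self) -> None:
        provided = [v for v in (self.sql_query, self.submit_answer) if v is not None]
        if len(provided) != 1:
            raise ValueError("provide exactly one of sql_query or submit_answer")
        if self.sql_query is not None and not _READ_ONLY.match(self.sql_query):
            raise ValueError("sql_query must start with SELECT or WITH")


ok = ActionSketch(sql_query="SELECT COUNT(*) FROM users")
done = ActionSketch(submit_answer="42")
```

Rejecting anything other than `SELECT` or `WITH` at construction time keeps invalid actions from ever reaching the database.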

## Tasks

| Task | Difficulty | Max Steps | Description |
|------|------------|-----------|--------------|
| `monthly_signups` | Easy | 10 | Count signups in the last 30 days |
| `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 |
| `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases |

## Reward Function

Rewards are given at every step (not just episode end):

- `+0.15`: Query executes without error
- `+0.10`: Query references a relevant table
- `+0.05`: Result has at least one row
- `+0.05`: Result is a sensible size
- `-0.02` per step beyond step 3 (efficiency penalty)
- `-0.10` if the agent repeats the same query 3 or more times
- `+0.00` to `+0.60` on final submission (task grader score × 0.60)
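
As a worked example, the components above can be combined as follows. This is a sketch, not the code in `env/reward.py`: the function names and signature are hypothetical, and it assumes the efficiency penalty is a flat `-0.02` applied on each step numbered above 3.

```python
def step_reward(step: int, executed_ok: bool, relevant_table: bool,
                has_rows: bool, sensible_size: bool, repeat_count: int) -> float:
    """Per-step shaping reward built from the listed components (hypothetical)."""
    reward = 0.0
    if executed_ok:
        reward += 0.15
    if relevant_table:
        reward += 0.10
    if has_rows:
        reward += 0.05
    if sensible_size:
        reward += 0.05
    if step > 3:            # assumption: flat penalty on every step beyond step 3
        reward -= 0.02
    if repeat_count >= 3:   # same query issued three or more times
        reward -= 0.10
    return round(reward, 2)


def final_reward(grader_score: float) -> float:
    """Scale a task grader score in [0, 1] to the 0.00-0.60 submission reward."""
    return round(0.60 * max(0.0, min(1.0, grader_score)), 2)


# A clean, relevant query on step 2 earns every bonus and no penalty.
print(step_reward(2, True, True, True, True, repeat_count=1))
# The same query repeated a third time on step 6 incurs both penalties.
print(step_reward(6, True, True, True, True, repeat_count=3))
```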

## Usage

### Python API

```python
from env import SQLAnalystEnv, Action

env = SQLAnalystEnv(task_id="monthly_signups")
result = env.reset()
print(result.observation.question)

# Agent takes a step
result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"))
print(result.reward)
```

### FastAPI Server

```bash
python -m uvicorn env.server:app --host 0.0.0.0 --port 7860
```

REST endpoints:
- `POST /reset`: Reset environment
- `POST /step`: Execute action
- `POST /state`: Get current state
- `WebSocket /ws`: WebSocket for low-latency training
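
A client call to `POST /step` might be built like this, using only the standard library. The JSON body shape (a top-level `sql_query` key) and the helper name `build_step_request` are assumptions; check `env/server.py` for the schema the server actually expects.

```python
import json
from urllib.request import Request

BASE_URL = "http://localhost:7860"  # matches the uvicorn command above


def build_step_request(sql_query: str, base_url: str = BASE_URL) -> Request:
    """Build (but do not send) a POST /step request (hypothetical helper).

    Assumption: the server accepts the action fields as a flat JSON object.
    """
    body = json.dumps({"sql_query": sql_query}).encode("utf-8")
    return Request(
        f"{base_url}/step",
        data=body,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


req = build_step_request("SELECT COUNT(*) FROM users")
print(req.get_method(), req.full_url)
# With the server running, send it via urllib.request.urlopen(req).
```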

### Baseline Inference

```bash
export OPENAI_API_KEY=sk-...
python baseline/run_baseline.py
```

### Docker

```bash
docker build -t sql-analyst-env .
docker run -p 7860:7860 sql-analyst-env
```

## Tests

```bash
pytest tests/ -v
```

- `test_env.py`: OpenEnv contract tests
- `test_graders.py`: Task grader unit tests
- `test_reward.py`: Reward calculator tests

**All 46 tests pass.**

## Baseline Scores

| Task | Score | Model |
|------|-------|-------|
| monthly_signups | ~0.85 | gpt-4o-mini |
| top_revenue_category | ~0.65 | gpt-4o-mini |
| churn_analysis | ~0.40 | gpt-4o-mini |
| **Average** | **~0.63** | gpt-4o-mini |

## File Structure

```
sql-data-analyst/
├── env/
│   ├── __init__.py
│   ├── models.py           # Pydantic models
│   ├── database.py         # SQLite + seeding
│   ├── environment.py      # Core environment
│   ├── reward.py           # Reward calculator
│   ├── utils.py            # Helpers
│   ├── server.py           # FastAPI server
│   └── tasks/
│       ├── __init__.py
│       ├── base.py
│       ├── easy.py
│       ├── medium.py
│       └── hard.py
├── baseline/
│   ├── __init__.py
│   ├── run_baseline.py
│   └── prompts.py
├── tests/
│   ├── __init__.py
│   ├── test_env.py
│   ├── test_graders.py
│   └── test_reward.py
├── openenv.yaml
├── Dockerfile
├── requirements.txt
└── README.md
```

## License

MIT