SQL Data Analyst OpenEnv - Initial commit
This view is limited to 50 files because it contains too many changes.
- .gitignore +19 -0
- Dockerfile +18 -0
- PRD.md +1 -0
- README.md +160 -11
- __init__.py +3 -0
- baseline/__init__.py +0 -0
- baseline/prompts.py +78 -0
- baseline/run_baseline.py +204 -0
- baseline_scores.json +23 -0
- client.py +93 -0
- details.md +1156 -0
- env/__init__.py +13 -0
- env/database.py +267 -0
- env/environment.py +134 -0
- env/models.py +58 -0
- env/reward.py +55 -0
- env/server.py +254 -0
- env/tasks/__init__.py +20 -0
- env/tasks/base.py +52 -0
- env/tasks/easy.py +32 -0
- env/tasks/hard.py +59 -0
- env/tasks/medium.py +54 -0
- env/utils.py +29 -0
- hf_space/.gitattributes +35 -0
- hf_space/README.md +11 -0
- models.py +58 -0
- openenv.yaml +100 -0
- progress.txt +0 -0
- pyproject.toml +42 -0
- requirements.txt +6 -0
- server/__init__.py +0 -0
- server/app.py +12 -0
- temp_upload/baseline/__init__.py +0 -0
- temp_upload/baseline/prompts.py +78 -0
- temp_upload/baseline/run_baseline.py +162 -0
- temp_upload/env/__init__.py +13 -0
- temp_upload/env/database.py +267 -0
- temp_upload/env/environment.py +134 -0
- temp_upload/env/models.py +58 -0
- temp_upload/env/reward.py +55 -0
- temp_upload/env/server.py +254 -0
- temp_upload/env/tasks/__init__.py +20 -0
- temp_upload/env/tasks/base.py +52 -0
- temp_upload/env/tasks/easy.py +32 -0
- temp_upload/env/tasks/hard.py +59 -0
- temp_upload/env/tasks/medium.py +54 -0
- temp_upload/env/utils.py +29 -0
- temp_upload/server/__init__.py +0 -0
- temp_upload/server/app.py +12 -0
- temp_upload/tests/__init__.py +0 -0
.gitignore
ADDED
@@ -0,0 +1,19 @@
```
__pycache__/
*.pyc
*.pyo
.pytest_cache/
.ruff_cache/
*.egg-info/
dist/
build/
.eggs/
*.egg
uv.lock
.env
.venv/
venv/
*.db
*.sqlite
*.sqlite3
.DS_Store
Thumbs.db
```
Dockerfile
ADDED
@@ -0,0 +1,18 @@
```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies (version specifiers quoted so the shell does not treat `>` as redirection)
RUN pip install --no-cache-dir \
    "pydantic>=2.0" \
    "fastapi>=0.100" \
    "uvicorn>=0.20" \
    "openai>=1.0" \
    "faker>=18.0" \
    "pytest>=7.0"

COPY . .

EXPOSE 7860

CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"]
```
PRD.md
ADDED
@@ -0,0 +1 @@

# PRD: SQL Data Analyst Agent Environment (OpenEnv)

## 1. Executive Summary

The SQL Data Analyst Agent environment is a production-grade reinforcement learning (RL) space designed to train agents in autonomous data retrieval and analysis. Unlike toy simulations, this environment subjects agents to "messy" real-world database schemas and natural-language business queries, requiring them to perform multi-step reasoning, join operations, and query optimization. Success is measured by the agent's ability to return correct data subsets through valid, efficient SQL.

## 2. Core Specification & Architecture

The environment follows the 3-component pattern (Models, Client, Server) and the 3-method interface (`reset`, `step`, `state`) mandated by the OpenEnv specification.

### 2.1 Technical Stack

- Framework: OpenEnv v0.2.1+
- Server: FastAPI with Uvicorn (WebSocket-enabled via `/ws` for low-latency training)
- Database: SQLite or DuckDB (container-local for zero network overhead)
- Isolation: Docker-based containerization for secure execution of arbitrary SQL

### 2.2 OpenEnv Interface

- `reset(task_id: str)`: Initializes a fresh instance of the "messy" database and returns the schema and business question.
- `step(action: SQLAction)`: Executes the agent's SQL, captures the output/errors, and returns the next observation and reward.
- `state()`: Provides internal episode metadata, including `episode_id` and `step_count`, for debugging.

## 3. Data Models (Type-Safe Contracts)

All interactions are governed by Pydantic models to ensure schema enforcement and tool reliability.

| Model | Field | Type | Description |
|-------|-------|------|-------------|
| SQLAction | `sql_query` | str | The SQL command to execute against the database. |
| SQLAction | `is_done` | bool | Flag to signal the agent has completed the task. |
| SQLObservation | `schema` | Dict | JSON representation of tables, columns, and types. |
| SQLObservation | `last_result` | List | The first 5 rows of the previous query result. |
| SQLObservation | `error_message` | Optional[str] | Raw SQL error trace if the query failed. |
| SQLObservation | `step_history` | List[str] | The last 4 actions taken, to prevent infinite loops. |

## 4. Multi-Level Task Curriculum

The environment implements a 3-tier curriculum with deterministic graders scoring from 0.0 to 1.0.

**Task 1: Warmup (Easy) - Fix Broken Join**
- Scenario: A query uses a comma-separated cross-join causing a Cartesian product.
- Goal: Rewrite using `INNER JOIN ... ON`.
- Grader: Binary match of the resulting dataset count.

**Task 2: Intermediate (Medium) - Category Revenue**
- Scenario: Calculate the highest revenue in a specific quarter across messy product/sales tables.
- Goal: Use `JOIN`, `SUM()`, `GROUP BY`, and `ORDER BY`.
- Grader: 0.5 for the correct join + 0.5 for matching the final revenue value.

**Task 3: Advanced (Hard) - Churn Analysis & Optimization**
- Scenario: Find users who churned after their 3rd purchase using subqueries or window functions.
- Goal: Optimize a slow, redundant query by removing `DISTINCT` and replacing `LIKE` with sargable predicates.
- Grader: 0.6 for data accuracy + 0.4 for reducing query execution cost.

## 5. Reward Design (Partial Progress)

To avoid sparse-reward pitfalls, the environment provides dense feedback via shaped signals. The total step reward is:

$$R_{step} = \text{Delta\_Reward} + \text{Invalid\_Penalty} + \text{Efficiency\_Penalty}$$

- Delta Reward: +0.0–0.50 × Δ grader_score. Positive signal when the agent's SQL results move closer to the ground truth.
- Completion Bonus: +0.50 when `is_done=True` and the grader score is ≥ 0.80.
- Invalid Penalty: -0.10 for unparseable queries or SQL syntax errors, to discourage brute-forcing.
- Efficiency Penalty: -0.02 per step after the episode midpoint, to encourage concise solutions.

## 6. Implementation & Compliance Checklist

To be eligible for the Meta Hackathon, the following technical requirements must be met:

- Infrastructure: Must run on 2 vCPU, 8 GB RAM.
- Deployment: One-command push to Hugging Face Spaces via `openenv push`.
- Validation: Must pass `openenv validate` for spec compliance.
- Baseline (`inference.py`):
  - Must use the OpenAI client for all LLM calls.
  - Runtime must be < 20 minutes for all three tasks.
  - Must emit structured logs to stdout following the specified tag format (the tag names were stripped from this view):

| Log Tag | Required Fields |
|---------|-----------------|
| (episode start tag) | `task_id`, `task_name`, `difficulty` |
| (step tag) | `step_count`, `action`, `reward`, `done` |
| (episode end tag) | `total_steps`, `final_reward`, `task_score` |
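The shaped reward in Section 5 can be sketched in a few lines of Python. This is a minimal illustration only — the function name `shaped_step_reward` and its signature are assumptions, not the repository's actual API — and it treats the delta term as clipped at zero, which is one plausible reading of "positive signal when results move closer":

```python
def shaped_step_reward(
    prev_score: float,
    new_score: float,
    parse_error: bool,
    step: int,
    max_steps: int,
    is_done: bool,
) -> float:
    """Sketch of the dense per-step reward described in Section 5 (hypothetical helper)."""
    # Delta reward: up to +0.50, proportional to grader-score improvement
    delta_reward = 0.50 * max(0.0, new_score - prev_score)
    # Invalid penalty: discourage unparseable / syntactically broken SQL
    invalid_penalty = -0.10 if parse_error else 0.0
    # Efficiency penalty: -0.02 per step after the episode midpoint
    efficiency_penalty = -0.02 if step > max_steps / 2 else 0.0
    # Completion bonus: +0.50 for finishing with a grader score >= 0.80
    completion_bonus = 0.50 if (is_done and new_score >= 0.80) else 0.0
    return delta_reward + invalid_penalty + efficiency_penalty + completion_bonus
```

Solving the easy task perfectly in one step would yield 0.50 (delta) + 0.50 (bonus) = 1.0 under this sketch.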
README.md
CHANGED
@@ -1,11 +1,160 @@

# SQL Data Analyst — OpenEnv Environment

An RL training environment where an AI agent learns to answer business intelligence questions by writing and executing SQL queries against a live database.

## Motivation

Data analysts spend significant time translating business questions into SQL queries. This environment trains agents to do exactly that — iteratively exploring a database schema, writing queries, observing results, and submitting final answers.

## Quick Start

```bash
# Install dependencies
pip install -r requirements.txt

# Run tests
pytest tests/ -v
```

## Observation Space

| Field | Type | Description |
|-------|------|-------------|
| `schema_summary` | string | Compact DB schema (one line per table) |
| `question` | string | Natural language business question |
| `last_query` | string \| null | Most recent SQL query |
| `last_result` | object \| null | Query result: columns, rows (max 50), error |
| `last_error` | string \| null | SQL error if last query failed |
| `step` | int | Current step number |
| `max_steps` | int | Episode step limit |
| `hints` | string[] | Progressive hints (revealed after steps 5, 10, 15) |
| `done` | bool | Whether episode is complete |

## Action Space

Agent must submit exactly one of:

| Action | Type | Description |
|--------|------|-------------|
| `sql_query` | string | A SELECT or WITH SQL query to execute |
| `submit_answer` | string | Final answer — ends the episode |

## Tasks

| Task | Difficulty | Max Steps | Description |
|------|------------|-----------|-------------|
| `monthly_signups` | Easy | 10 | Count signups in the last 30 days |
| `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 |
| `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases |

## Reward Function

Rewards are given at every step (not just episode end):

- `+0.15` — Query executes without error
- `+0.10` — Query references a relevant table
- `+0.05` — Result has at least one row
- `+0.05` — Result is a sensible size
- `-0.02` per step beyond step 3 (efficiency penalty)
- `-0.10` if agent repeats the same query 3+ times
- `+0.00–0.60` on final submission (task grader × 0.60)

## Usage

### Python API

```python
from env import SQLAnalystEnv, Action

env = SQLAnalystEnv(task_id="monthly_signups")
result = env.reset()
print(result.observation.question)

# Agent takes a step
result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"))
print(result.reward)
```

### FastAPI Server

```bash
python -m uvicorn env.server:app --host 0.0.0.0 --port 7860
```

REST endpoints:
- `POST /reset` — Reset environment
- `POST /step` — Execute action
- `POST /state` — Get current state
- `WebSocket /ws` — WebSocket for low-latency training

### Baseline Inference

```bash
export OPENAI_API_KEY=sk-...
python baseline/run_baseline.py
```

### Docker

```bash
docker build -t sql-analyst-env .
docker run -p 7860:7860 sql-analyst-env
```

## Tests

```bash
pytest tests/ -v
```

- `test_env.py` — OpenEnv contract tests
- `test_graders.py` — Task grader unit tests
- `test_reward.py` — Reward calculator tests

**All 46 tests pass.**

## Baseline Scores

| Task | Score | Model |
|------|-------|-------|
| monthly_signups | ~0.85 | gpt-4o-mini |
| top_revenue_category | ~0.65 | gpt-4o-mini |
| churn_analysis | ~0.40 | gpt-4o-mini |
| **Average** | **~0.63** | gpt-4o-mini |

## File Structure

```
sql-data-analyst/
├── env/
│   ├── __init__.py
│   ├── models.py        # Pydantic models
│   ├── database.py      # SQLite + seeding
│   ├── environment.py   # Core environment
│   ├── reward.py        # Reward calculator
│   ├── utils.py         # Helpers
│   ├── server.py        # FastAPI server
│   └── tasks/
│       ├── __init__.py
│       ├── base.py
│       ├── easy.py
│       ├── medium.py
│       └── hard.py
├── baseline/
│   ├── __init__.py
│   ├── run_baseline.py
│   └── prompts.py
├── tests/
│   ├── __init__.py
│   ├── test_env.py
│   ├── test_graders.py
│   └── test_reward.py
├── openenv.yaml
├── Dockerfile
├── requirements.txt
└── README.md
```

## License

MIT
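The per-step rubric in the README's Reward Function section can be sketched as a checklist of additive terms. This is illustrative only — the real logic lives in `env/reward.py`, and names such as `step_reward` and `references_relevant_table` are hypothetical:

```python
def step_reward(
    executed_ok: bool,
    references_relevant_table: bool,
    row_count: int,
    sensible_size: bool,
    step: int,
    repeat_count: int,
) -> float:
    """Illustrative per-step reward matching the README rubric (not the repo's code)."""
    r = 0.0
    if executed_ok:
        r += 0.15  # query ran without error
    if references_relevant_table:
        r += 0.10  # touched a table the task cares about
    if row_count > 0:
        r += 0.05  # produced at least one row
    if sensible_size:
        r += 0.05  # result is neither empty nor absurdly large
    if step > 3:
        r -= 0.02  # efficiency penalty for each step beyond step 3
    if repeat_count >= 3:
        r -= 0.10  # repeated the same query 3+ times
    return r
```

Under this sketch, a clean, relevant query early in the episode earns the full +0.35 of step-level shaping; the remaining 0.60 comes from the grader on final submission.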
__init__.py
ADDED
@@ -0,0 +1,3 @@
```python
"""SQL Data Analyst OpenEnv - An RL environment for SQL query generation."""

__version__ = "1.0.0"
```
baseline/__init__.py
ADDED
File without changes
baseline/prompts.py
ADDED
@@ -0,0 +1,78 @@
````python
"""
Prompt templates for the baseline inference script.
"""

import json

SYSTEM_PROMPT = """
You are a SQL data analyst. You are given a database schema and a business question.
Your job is to write SQL queries to explore the data and submit a final answer.

Rules:
- Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.)
- Reply with JSON only. No explanation.
- To run a query: {"sql_query": "SELECT ..."}
- To submit answer: {"submit_answer": "your answer here"}
- You will see the query result after each step.
- Submit your answer when you are confident.

Important:
- Always use valid SQL syntax
- Table names: users, products, orders, order_items, events
- Dates are stored as ISO timestamps
- Always filter orders by status='completed' for revenue calculations
"""


def build_prompt(obs) -> str:
    """Build the user prompt from an observation."""
    parts = [
        f"Database schema:\n{obs.schema_summary}",
        f"\nQuestion: {obs.question}",
        f"\nStep: {obs.step} / {obs.max_steps}",
    ]

    if obs.last_query:
        parts.append(f"\nLast query:\n{obs.last_query}")

    if obs.last_result:
        if obs.last_result.error:
            parts.append(f"\nSQL error: {obs.last_result.error}")
        elif obs.last_result.rows:
            cols = obs.last_result.columns
            rows = obs.last_result.rows[:10]
            parts.append(f"\nResult columns: {cols}")
            parts.append(
                f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}"
            )

    if obs.hints:
        parts.append(f"\nHints: {'; '.join(obs.hints)}")

    parts.append("\nWhat is your next action? Reply with JSON only.")
    return "\n".join(parts)


def parse_action(response_text: str | None):
    """Extract JSON action from LLM response."""
    from env import Action  # imported here to avoid a circular import

    if not response_text:
        return Action(submit_answer="")

    text = response_text.strip()
    # Strip markdown code fences if the model wrapped its JSON in them
    text = text.replace("```json", "").replace("```", "").strip()

    try:
        data = json.loads(text)
        if "sql_query" in data and data["sql_query"]:
            return Action(sql_query=data["sql_query"])
        elif "submit_answer" in data and data["submit_answer"]:
            return Action(submit_answer=data["submit_answer"])
    except json.JSONDecodeError:
        pass

    # Fall back to treating the raw text as the final answer
    return Action(submit_answer=text)
````
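The JSON action contract that `parse_action` expects can be demonstrated with a standalone sketch. It returns plain dicts instead of the project's `Action` model (which is not importable outside the repo), and `parse_action_dict` is a hypothetical name; the fence-stripping and fallback behavior mirror the file above:

````python
import json


def parse_action_dict(response_text: str) -> dict:
    """Standalone illustration of the parse logic: fence-stripping, JSON, fallback."""
    text = (response_text or "").strip()
    # Strip markdown code fences if the model wrapped its JSON in them
    text = text.replace("```json", "").replace("```", "").strip()
    try:
        data = json.loads(text)
        if isinstance(data, dict):
            if data.get("sql_query"):
                return {"sql_query": data["sql_query"]}
            if data.get("submit_answer"):
                return {"submit_answer": data["submit_answer"]}
    except json.JSONDecodeError:
        pass
    # Anything unparseable is treated as a final answer
    return {"submit_answer": text}
````

The fallback matters: a chatty model reply that fails to parse still terminates the episode as a submission rather than crashing the rollout.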
baseline/run_baseline.py
ADDED
@@ -0,0 +1,204 @@
```python
import os
import re
import json
import textwrap
from typing import List

from openai import OpenAI

from client import SQLAnalystClient as SQLAnalystEnv
from env import Action as SQLAction

DEBUG = True
ACTION_PREFIX_RE = re.compile(
    r"^(action|next action)\s*[:\-]\s*",
    re.IGNORECASE,
)
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)
FALLBACK_ACTION = "noop()"
MAX_STEPS = 20

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are a SQL Data Analyst Agent.
    Your goal is to answer business questions by writing and executing SQL queries.
    Reply with exactly one action string.
    The action must be a valid SQL command such as:
    - execute_sql('SELECT * FROM users')
    - submit_answer('42')
    - noop()
    Use single quotes around string arguments.
    Do not include explanations or additional text.
    """
).strip()


def _obs_field(observation, name, default=None):
    """Read a field from either an attribute-style observation or a dict payload."""
    if isinstance(observation, dict):
        return observation.get(name, default)
    return getattr(observation, name, default)


def build_history_lines(history: List[str]) -> str:
    if not history:
        return "None"
    return "\n".join(history[-4:])


def build_user_prompt(step: int, observation, history: List[str]) -> str:
    goal = _obs_field(observation, "question", "(not provided)")
    schema = _obs_field(observation, "schema_summary", "(none detected)")
    last_error = _obs_field(observation, "last_error", None)
    error_note = "Yes" if last_error else "No"

    prompt = textwrap.dedent(
        f"""
        Step: {step}
        Goal: {goal}
        Database Schema: {schema}
        Previous steps:
        {build_history_lines(history)}
        Last action error: {error_note}
        Reply with exactly one SQL action string.
        """
    ).strip()
    return prompt


def parse_model_action(response_text: str) -> str:
    if not response_text:
        return FALLBACK_ACTION

    for raw_line in response_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        line = ACTION_PREFIX_RE.sub("", line)
        match = ACTION_PATTERN.search(line)
        if match:
            return re.sub(r"\s+", " ", match.group(0).strip())

    match = ACTION_PATTERN.search(response_text)
    if match:
        return re.sub(r"\s+", " ", match.group(0).strip())

    return FALLBACK_ACTION


def extract_sql_or_answer(action_str: str):
    """Extract sql_query or submit_answer from an action string like execute_sql('SELECT...')."""
    action_str = action_str.strip()

    if action_str.startswith("execute_sql(") or action_str.startswith("submit_answer("):
        match = re.search(r"\((.*)\)", action_str)
        if match:
            content = match.group(1).strip()
            # Remove outer quotes if present
            if (content.startswith("'") and content.endswith("'")) or (
                content.startswith('"') and content.endswith('"')
            ):
                content = content[1:-1]

            if action_str.startswith("execute_sql("):
                return content, None
            return None, content

    if action_str == "noop()":
        return None, None

    # Default: treat the whole string as a SQL query
    return action_str, None


def main():
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
    env_url = os.environ.get("OPENENV_URL")

    if not api_key:
        print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
        return

    client = OpenAI(base_url=base_url, api_key=api_key)

    tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]

    for task_id in tasks:
        print(
            f" {json.dumps({'task_id': task_id, 'task_name': task_id, 'difficulty': 'curriculum'})}"
        )

        history: List[str] = []

        # Use the local environment instead of HTTP
        from env import SQLAnalystEnv as LocalEnv

        env = LocalEnv(task_id=task_id)
        result = env.reset()
        observation = result.observation
        total_reward = 0.0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            user_prompt = build_user_prompt(step, observation, history)

            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.0,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as exc:
                print(f"Model request failed ({exc}). Using fallback action.")
                response_text = FALLBACK_ACTION

            action_str = parse_model_action(response_text)
            sql_query, submit_answer = extract_sql_or_answer(action_str)

            if submit_answer:
                action = SQLAction(submit_answer=submit_answer)
            elif sql_query:
                action = SQLAction(sql_query=sql_query)
            else:
                action = SQLAction(sql_query="SELECT 1")

            result = env.step(action)
            observation = result.observation
            reward = result.reward or 0.0
            total_reward += reward

            print(
                f" {json.dumps({'step': step, 'action': action_str, 'reward': reward, 'done': result.done})}"
            )

            error_flag = " ERROR" if observation.last_error else ""
            history.append(
                f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
            )

        print(
            f" {json.dumps({'total_steps': step, 'final_reward': total_reward, 'task_score': (result.info or {}).get('task_score', 0.0)})}"
        )

        print(f"\n{'=' * 60}")
        print(f"TASK: {task_id}")
        print(f"FINAL REWARD: {total_reward:.3f}")
        print(f"{'=' * 60}\n")


if __name__ == "__main__":
    main()
```
baseline_scores.json
ADDED
@@ -0,0 +1,23 @@
```json
[
  {
    "task_id": "monthly_signups",
    "difficulty": "easy",
    "total_reward": 0.0,
    "steps": 0,
    "max_steps": 10
  },
  {
    "task_id": "top_revenue_category",
    "difficulty": "medium",
    "total_reward": 0.0,
    "steps": 0,
    "max_steps": 15
  },
  {
    "task_id": "churn_analysis",
    "difficulty": "hard",
    "total_reward": 0.0,
    "steps": 0,
    "max_steps": 20
  }
]
```
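Given the record shape above, aggregating the scores file is a one-liner over `total_reward`. A minimal sketch (the helper name `average_reward` is hypothetical):

```python
import json


def average_reward(scores_json: str) -> float:
    """Average total_reward across task records in a baseline_scores.json payload."""
    records = json.loads(scores_json)
    if not records:
        return 0.0
    return sum(r["total_reward"] for r in records) / len(records)
```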
client.py
ADDED
@@ -0,0 +1,93 @@
```python
"""
OpenEnv client for SQL Data Analyst environment.

Provides a Python client interface to interact with the environment.
"""

from typing import Dict, Any

from env import SQLAnalystEnv, Action


class SQLAnalystClient:
    """Client for interacting with the SQL Data Analyst environment."""

    def __init__(self, task_id: str = "monthly_signups"):
        self.env = SQLAnalystEnv(task_id=task_id)
        self.task_id = task_id

    def reset(self) -> Dict[str, Any]:
        """Reset the environment and return initial observation."""
        result = self.env.reset()
        return {
            "observation": {
                "schema_summary": result.observation.schema_summary,
                "question": result.observation.question,
                "step": result.observation.step,
                "max_steps": result.observation.max_steps,
                "hints": result.observation.hints,
                "done": result.observation.done,
            },
            "reward": result.reward,
            "done": result.done,
        }

    def step(self, action: Action) -> Dict[str, Any]:
        """Execute an action and return the result."""
        result = self.env.step(action)
        last_result = result.observation.last_result
        return {
            "observation": {
                "schema_summary": result.observation.schema_summary,
                "question": result.observation.question,
                "last_query": result.observation.last_query,
                "last_result": {
                    "columns": last_result.columns if last_result else None,
                    "rows": last_result.rows if last_result else None,
                    "error": last_result.error if last_result else None,
                },
                "last_error": result.observation.last_error,
                "step": result.observation.step,
                "max_steps": result.observation.max_steps,
                "hints": result.observation.hints,
                "done": result.observation.done,
            },
            "reward": result.reward,
            "done": result.done,
            "info": result.info,
        }

    def state(self) -> Dict[str, Any]:
        """Get the current state of the environment."""
        state = self.env.state()
        return {
            "task_id": state.task_id,
            "difficulty": state.difficulty,
            "step": state.step,
            "max_steps": state.max_steps,
            "query_history": state.query_history,
            "total_reward": state.total_reward,
            "done": state.done,
        }

    def execute_sql(self, query: str) -> Dict[str, Any]:
        """Execute a SQL query."""
        return self.step(Action(sql_query=query))

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """Submit the final answer."""
        return self.step(Action(submit_answer=answer))


def get_client(task_id: str = "monthly_signups") -> SQLAnalystClient:
    """Get a client instance for the specified task."""
    return SQLAnalystClient(task_id=task_id)


__all__ = ["SQLAnalystClient", "get_client"]
```
details.md
ADDED
|
@@ -0,0 +1,1156 @@
# SQL Data Analyst Agent — OpenEnv Hackathon Build Guide

> **Hackathon:** Meta OpenEnv Hackathon
> **Environment name:** `sql-data-analyst`
> **Goal:** Build a real-world RL environment where an AI agent answers business questions by writing and executing SQL against a live database.

---

## Table of Contents

1. [What We Are Building](#1-what-we-are-building)
2. [Requirements Checklist](#2-requirements-checklist)
3. [Database Design](#3-database-design)
4. [The 3 Tasks with Graders](#4-the-3-tasks-with-graders)
5. [Pydantic Models (OpenEnv Spec)](#5-pydantic-models-openenv-spec)
6. [Environment Core (environment.py)](#6-environment-core)
7. [Reward Function](#7-reward-function)
8. [Key Optimisations](#8-key-optimisations)
9. [Baseline Inference Script](#9-baseline-inference-script)
10. [openenv.yaml](#10-openenvyaml)
11. [Dockerfile](#11-dockerfile)
12. [README Template](#12-readme-template)
13. [Full File Structure](#13-full-file-structure)
14. [Build Order (Step-by-Step)](#14-build-order)

---

## 1. What We Are Building

An **OpenEnv-compliant RL training environment** where an AI agent:

- Receives a natural language business question and a live SQLite database schema
- Writes SQL queries, executes them, and observes the results
- Iterates until it can submit a final answer
- Gets scored 0.0–1.0 based on correctness and efficiency

**Why this wins:**
- Deterministic grading — SQL answers are right or wrong, no ambiguity
- Partial rewards are natural at every step (table hit → no error → correct answer)
- Directly applicable to real business intelligence workflows
- Clean difficulty curve across 3 tasks

---

## 2. Requirements Checklist

| # | Requirement | Implementation |
|---|---|---|
| 1 | Real-world task | SQL data analysis — used by every company daily |
| 2 | OpenEnv spec: typed models | Pydantic `Observation`, `Action`, `StepResult` |
| 3 | OpenEnv spec: `step()` | Returns `(observation, reward, done, info)` |
| 4 | OpenEnv spec: `reset()` | Returns initial observation, reseeds DB |
| 5 | OpenEnv spec: `state()` | Returns current full env state |
| 6 | `openenv.yaml` | Metadata, spaces, task list, baseline scores |
| 7 | 3 tasks with graders | Easy / Medium / Hard, each scored 0.0–1.0 |
| 8 | Meaningful reward | Partial credit at every step, not just end |
| 9 | Baseline inference script | OpenAI API client, reproducible scores |
| 10 | HuggingFace Space | Containerised, tagged `openenv` |
| 11 | Dockerfile | `docker build + docker run` works cleanly |
| 12 | README | Spaces, tasks, setup, baseline scores |

---

## 3. Database Design

Use a realistic SaaS e-commerce schema. This single schema supports all 3 tasks.

### Schema

```sql
-- users table
CREATE TABLE users (
    id INTEGER PRIMARY KEY,
    email TEXT NOT NULL,
    country TEXT,
    plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
    created_at TIMESTAMP NOT NULL,
    churned_at TIMESTAMP  -- NULL if still active
);

-- products table
CREATE TABLE products (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    category TEXT NOT NULL,  -- Electronics, Clothing, Books, etc.
    price DECIMAL(10,2),
    cost DECIMAL(10,2)
);

-- orders table
CREATE TABLE orders (
    id INTEGER PRIMARY KEY,
    user_id INTEGER REFERENCES users(id),
    created_at TIMESTAMP NOT NULL,
    status TEXT CHECK(status IN ('pending','completed','refunded')),
    total DECIMAL(10,2)
);

-- order_items table
CREATE TABLE order_items (
    id INTEGER PRIMARY KEY,
    order_id INTEGER REFERENCES orders(id),
    product_id INTEGER REFERENCES products(id),
    qty INTEGER NOT NULL,
    unit_price DECIMAL(10,2)
);

-- events table (user behaviour)
CREATE TABLE events (
    id INTEGER PRIMARY KEY,
    user_id INTEGER REFERENCES users(id),
    event_type TEXT,  -- page_view, add_to_cart, checkout, etc.
    metadata JSON,
    ts TIMESTAMP NOT NULL
);
```

### Seeding

Seed with realistic volumes using the `faker` library:

```python
# database.py — seed targets
SEED_CONFIG = {
    "users": 500,         # ~500 users
    "products": 80,       # 80 products across 5 categories
    "orders": 2000,       # ~2000 orders
    "order_items": 5000,  # ~5000 line items
    "events": 8000,       # ~8000 behavioural events
}

# Intentional messiness (makes it realistic)
# - ~5% of users have NULL country
# - ~3% of orders have status='refunded'
# - churned_at is NULL for active users
# - Some users have 0 orders (registered but never bought)
```
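The actual `database.py` uses `faker` for realistic values; as a rough stdlib-only sketch of the messiness targets above (the `seed_users` helper name and the value pools are illustrative, not the real implementation):

```python
import random
import sqlite3
from datetime import datetime, timedelta

def seed_users(conn: sqlite3.Connection, n: int = 500, seed: int = 42) -> None:
    """Illustrative seeder: ~5% NULL country, signup dates spread over a year."""
    rng = random.Random(seed)
    now = datetime(2024, 10, 1)
    rows = []
    for i in range(n):
        # ~5% of users get a NULL country, matching the messiness target
        country = None if rng.random() < 0.05 else rng.choice(["US", "DE", "IN", "BR"])
        created = now - timedelta(days=rng.randint(0, 365))
        plan = rng.choice(["free", "pro", "enterprise"])
        rows.append((i + 1, f"user{i}@example.com", country, plan, created.isoformat()))
    # churned_at starts NULL for everyone (active users)
    conn.executemany("INSERT INTO users VALUES (?, ?, ?, ?, ?, NULL)", rows)

conn = sqlite3.connect(":memory:")
conn.execute("""CREATE TABLE users (
    id INTEGER PRIMARY KEY, email TEXT NOT NULL, country TEXT,
    plan TEXT, created_at TIMESTAMP NOT NULL, churned_at TIMESTAMP)""")
seed_users(conn)
total = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
null_country = conn.execute("SELECT COUNT(*) FROM users WHERE country IS NULL").fetchone()[0]
```

Seeding from a fixed RNG seed keeps episodes reproducible while still looking organic.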

---

## 4. The 3 Tasks with Graders

### Task 1 — Easy: Monthly Signups

**Question:** `"How many users signed up in the last 30 days?"`

**Required SQL skills:** Single table, `COUNT`, `WHERE`, date filtering

**Expected SQL:**
```sql
SELECT COUNT(*) FROM users
WHERE created_at >= DATE('now', '-30 days');
```

**Grader:**
```python
def grade_easy(submitted_answer: str, ground_truth: int) -> float:
    try:
        val = int(submitted_answer.strip().replace(",", ""))
        if val == ground_truth:
            return 1.0
        if abs(val - ground_truth) <= 3:   # within 3 = partial credit
            return 0.6
        if abs(val - ground_truth) <= 10:  # within 10 = small credit
            return 0.3
    except (ValueError, AttributeError):
        pass
    return 0.0
```
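A quick sanity check of the tolerance bands (the grader is copied verbatim so the snippet runs standalone):

```python
def grade_easy(submitted_answer: str, ground_truth: int) -> float:
    try:
        val = int(submitted_answer.strip().replace(",", ""))
        if val == ground_truth:
            return 1.0
        if abs(val - ground_truth) <= 3:
            return 0.6
        if abs(val - ground_truth) <= 10:
            return 0.3
    except (ValueError, AttributeError):
        pass
    return 0.0

exact = grade_easy("1,234", 1234)       # comma-formatted exact match -> 1.0
close = grade_easy("1236", 1234)        # off by 2 -> 0.6
rough = grade_easy("1240", 1234)        # off by 6 -> 0.3
junk = grade_easy("about 1200", 1234)   # unparseable -> 0.0
```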

**Max steps:** 10
**Difficulty:** Easy

---

### Task 2 — Medium: Top Revenue Category

**Question:** `"Which product category generated the most revenue in Q3 (July–September)?"`

**Required SQL skills:** `JOIN` across 3 tables, `GROUP BY`, `ORDER BY`, `SUM`, date range filtering

**Expected SQL:**
```sql
SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
FROM orders o
JOIN order_items oi ON o.id = oi.order_id
JOIN products p ON oi.product_id = p.id
WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
  AND o.status = 'completed'
GROUP BY p.category
ORDER BY revenue DESC
LIMIT 1;
```

**Grader:**
```python
import re

def grade_medium(submitted_answer: str, ground_truth: str, top_3: list) -> float:
    answer = submitted_answer.strip().lower()
    # Remove common LLM preamble
    answer = re.sub(r'the (answer|category) is:?\s*', '', answer)

    if ground_truth.lower() in answer:
        return 1.0
    if any(cat.lower() in answer for cat in top_3):
        return 0.4  # got a plausible answer, not the top one
    return 0.0
```
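Example behaviour of the medium grader, showing the preamble stripping and the partial-credit path (grader condensed from above so the snippet runs standalone):

```python
import re

def grade_medium(submitted_answer: str, ground_truth: str, top_3: list) -> float:
    answer = submitted_answer.strip().lower()
    answer = re.sub(r'the (answer|category) is:?\s*', '', answer)
    if ground_truth.lower() in answer:
        return 1.0
    if any(cat.lower() in answer for cat in top_3):
        return 0.4
    return 0.0

top_3 = ["Electronics", "Books", "Clothing"]
full = grade_medium("The category is: Electronics", "Electronics", top_3)  # preamble stripped -> 1.0
partial = grade_medium("Books", "Electronics", top_3)  # plausible but not top -> 0.4
zero = grade_medium("Toys", "Electronics", top_3)      # not in top 3 -> 0.0
```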

**Max steps:** 15
**Difficulty:** Medium

---

### Task 3 — Hard: Churn After 3rd Purchase

**Question:** `"Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."`

**Required SQL skills:** CTEs (`WITH`), aggregates (`COUNT`, `MAX`), `GROUP BY` + `HAVING`, date logic

**Expected SQL:**
```sql
WITH order_counts AS (
    SELECT user_id, COUNT(*) AS total_orders,
           MAX(created_at) AS last_order_date
    FROM orders
    WHERE status = 'completed'
    GROUP BY user_id
    HAVING COUNT(*) = 3
),
churned AS (
    SELECT oc.user_id
    FROM order_counts oc
    WHERE oc.last_order_date < DATE('now', '-90 days')
)
SELECT u.email
FROM users u
JOIN churned c ON u.id = c.user_id;
```

**Grader (F1 score for set matching):**
```python
def grade_hard(submitted_answer: str, ground_truth_emails: set) -> float:
    if not submitted_answer.strip():
        return 0.0

    # Parse comma-separated emails
    submitted = {
        e.strip().lower()
        for e in submitted_answer.split(",")
        if "@" in e
    }

    if not submitted:
        return 0.0

    correct = ground_truth_emails
    tp = len(submitted & correct)

    if tp == 0:
        return 0.0

    precision = tp / len(submitted)
    recall = tp / len(correct)
    f1 = 2 * precision * recall / (precision + recall)

    return round(f1, 3)
```
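A worked F1 example (grader condensed from above so it runs standalone):

```python
def grade_hard(submitted_answer: str, ground_truth_emails: set) -> float:
    if not submitted_answer.strip():
        return 0.0
    submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e}
    if not submitted:
        return 0.0
    tp = len(submitted & ground_truth_emails)
    if tp == 0:
        return 0.0
    precision = tp / len(submitted)
    recall = tp / len(ground_truth_emails)
    return round(2 * precision * recall / (precision + recall), 3)

truth = {"a@x.com", "b@x.com"}
perfect = grade_hard("A@x.com, b@x.com", truth)    # case-insensitive match -> 1.0
half = grade_hard("a@x.com, wrong@x.com", truth)   # precision 0.5, recall 0.5 -> 0.5
miss = grade_hard("none of these", truth)          # no emails parsed -> 0.0
```

F1 rewards both finding the right users and not padding the list with guesses, which is exactly the failure mode on open-ended list answers.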

**Max steps:** 20
**Difficulty:** Hard

---

## 5. Pydantic Models (OpenEnv Spec)

```python
# env/models.py
from pydantic import BaseModel, Field
from typing import Optional, List, Any

class Action(BaseModel):
    """What the agent can do each step."""
    sql_query: Optional[str] = Field(
        None,
        description="A SQL SELECT query to execute against the database"
    )
    submit_answer: Optional[str] = Field(
        None,
        description="Final answer to submit. Ends the episode."
    )

    def is_valid(self) -> bool:
        # Exactly one of the two must be set
        return bool(self.sql_query) != bool(self.submit_answer)


class QueryResult(BaseModel):
    """Result of executing a SQL query."""
    columns: List[str] = []
    rows: List[List[Any]] = []
    error: Optional[str] = None
    truncated: bool = False
    total_rows: int = 0


class Observation(BaseModel):
    """What the agent sees after each step."""
    schema_summary: str = Field(..., description="Compact DB schema")
    question: str = Field(..., description="Business question to answer")
    last_query: Optional[str] = None
    last_result: Optional[QueryResult] = None
    last_error: Optional[str] = None
    step: int = 0
    max_steps: int = 20
    hints: List[str] = []
    done: bool = False


class StepResult(BaseModel):
    """Full result returned by step()."""
    observation: Observation
    reward: float = 0.0
    done: bool = False
    info: dict = {}


class EnvState(BaseModel):
    """Full environment state returned by state()."""
    task_id: str
    difficulty: str
    step: int
    max_steps: int
    query_history: List[str] = []
    total_reward: float = 0.0
    done: bool = False
```

---

## 6. Environment Core

```python
# env/environment.py
import sqlite3
from typing import Optional
from .models import Action, Observation, StepResult, EnvState, QueryResult
from .database import create_database, seed_database, get_schema_summary
from .reward import RewardCalculator
from .tasks import TASKS


class SQLAnalystEnv:
    """
    OpenEnv-compliant SQL Data Analyst environment.

    An agent must answer business questions by iteratively
    writing and executing SQL queries.
    """

    def __init__(self, task_id: str = "monthly_signups"):
        assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
        self.task_id = task_id
        self.task = TASKS[task_id]
        self.conn: Optional[sqlite3.Connection] = None
        self.step_count: int = 0
        self.total_reward: float = 0.0
        self.done: bool = False
        self._query_history: list = []
        self._reward_calc = RewardCalculator()

    # ------------------------------------------------------------------
    # OpenEnv required methods
    # ------------------------------------------------------------------

    def reset(self) -> StepResult:
        """Reset environment. Reseed DB. Return initial observation."""
        if self.conn:
            self.conn.close()

        self.conn = create_database()
        seed_database(self.conn)
        self.step_count = 0
        self.total_reward = 0.0
        self.done = False
        self._query_history = []

        # Compute ground truth AFTER seeding
        self.task.compute_ground_truth(self.conn)

        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            step=0,
            max_steps=self.task.max_steps,
        )
        return StepResult(observation=obs, reward=0.0, done=False)

    def step(self, action: Action) -> StepResult:
        """Execute one agent action. Return (observation, reward, done, info)."""
        assert self.conn is not None, "Call reset() before step()"
        assert not self.done, "Episode is done. Call reset()."
        assert action.is_valid(), "Action must have exactly one of: sql_query, submit_answer"

        self.step_count += 1
        query_result = None
        error = None

        # --- Execute SQL or submit answer ---
        if action.sql_query:
            query_result = self._execute_sql(action.sql_query)
            self._query_history.append(action.sql_query)
            error = query_result.error

        terminal = (
            action.submit_answer is not None
            or self.step_count >= self.task.max_steps
        )

        # --- Calculate reward ---
        reward = self._reward_calc.calculate(
            action=action,
            result=query_result,
            task=self.task,
            step=self.step_count,
            query_history=self._query_history,
            terminal=terminal,
        )
        self.total_reward += reward
        self.done = terminal

        # --- Build next observation ---
        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            last_query=action.sql_query,
            last_result=query_result,
            last_error=error,
            step=self.step_count,
            max_steps=self.task.max_steps,
            hints=self.task.get_hints(self.step_count),
            done=self.done,
        )

        return StepResult(
            observation=obs,
            reward=round(reward, 3),
            done=self.done,
            info={
                "step": self.step_count,
                "total_reward": round(self.total_reward, 3),
                "task_id": self.task_id,
            }
        )

    def state(self) -> EnvState:
        """Return current full state of the environment."""
        return EnvState(
            task_id=self.task_id,
            difficulty=self.task.difficulty,
            step=self.step_count,
            max_steps=self.task.max_steps,
            query_history=self._query_history.copy(),
            total_reward=round(self.total_reward, 3),
            done=self.done,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _execute_sql(self, query: str) -> QueryResult:
        """Execute SQL safely. Block non-SELECT. Return up to 50 rows."""
        # Safety: only SELECT is allowed
        q = query.strip().upper()
        if not q.startswith("SELECT") and not q.startswith("WITH"):
            return QueryResult(error="Only SELECT / WITH queries are allowed.")
        try:
            cursor = self.conn.execute(query)
            cols = [d[0] for d in cursor.description] if cursor.description else []
            rows = cursor.fetchmany(50)
            total = len(rows)  # fetchmany caps at 50
            return QueryResult(
                columns=cols,
                rows=[list(r) for r in rows],
                truncated=(total == 50),
                total_rows=total,
            )
        except Exception as e:
            return QueryResult(error=str(e))
```

---

## 7. Reward Function

```python
# env/reward.py
import re
from typing import Optional

from .models import Action, QueryResult


class RewardCalculator:

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task,
        step: int,
        query_history: list,
        terminal: bool,
    ) -> float:

        reward = 0.0

        # ── Step-level rewards (every step) ──────────────────────────

        if action.sql_query and result:

            # +0.15 — Query executed without syntax error
            if not result.error:
                reward += 0.15

            # +0.10 — Query touched at least one relevant table
            relevant = self._count_relevant_tables(action.sql_query, task.relevant_tables)
            if relevant > 0:
                reward += 0.10

            # +0.05 — Result has rows (not empty result set)
            if result.rows and len(result.rows) > 0:
                reward += 0.05

            # +0.05 — Result is not absurdly large (sanity check)
            if result.rows and len(result.rows) < 1000:
                reward += 0.05

        # ── Efficiency penalties ──────────────────────────────────────

        # -0.02 per step beyond step 3 (penalise excessive querying)
        if step > 3:
            reward -= 0.02 * (step - 3)

        # -0.10 if agent is stuck in a loop (same query 3x)
        if self._is_stuck(query_history):
            reward -= 0.10

        # ── Terminal reward (only when episode ends) ──────────────────

        if terminal and action.submit_answer:
            # Grade the submitted answer — up to 0.60 of total reward
            task_score = task.grade(action.submit_answer)
            reward += task_score * 0.60

        # Clamp to [0.0, 1.0]
        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(self, query: str, relevant_tables: list) -> int:
        query_lower = query.lower()
        return sum(1 for t in relevant_tables if t.lower() in query_lower)

    def _is_stuck(self, history: list) -> bool:
        if len(history) < 3:
            return False
        return len(set(history[-3:])) == 1
```

**Reward breakdown per step:**

| Signal | Max Value | Condition |
|---|---|---|
| No SQL error | +0.15 | Query executes cleanly |
| Relevant table used | +0.10 | Query touches correct table(s) |
| Non-empty result | +0.05 | Result set has at least 1 row |
| Reasonable result size | +0.05 | Result has < 1000 rows |
| Late-step penalty | −0.02/step | Each step beyond step 3 |
| Infinite loop penalty | −0.10 | Same query repeated 3+ times |
| Terminal answer score | up to +0.60 | Task grader × 0.60 |

**Maximum reward per episode:** ≈ 1.0 for a short, correct episode — e.g. one clean exploratory query (+0.35) followed by a fully correct submission (+0.60). Each step's reward is individually clamped to [0.0, 1.0].
**Minimum (immediate surrender):** 0.0
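As a worked example of how the signals combine, mirroring the table above (pure arithmetic, no environment needed):

```python
# Step 1: clean query on a relevant table, non-empty, small result.
# No late-step penalty yet (step <= 3).
step1 = 0.15 + 0.10 + 0.05 + 0.05                     # 0.35

# Step 5: same signals, but two steps past the grace window.
step5 = 0.15 + 0.10 + 0.05 + 0.05 - 0.02 * (5 - 3)    # 0.31

# Terminal step: grader returns 1.0, scaled by the 0.60 answer weight.
final = 1.0 * 0.60                                    # 0.60

# A tight two-step perfect episode: explore once, then submit.
episode_total = step1 + final                         # ~0.95
```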

---

## 8. Key Optimisations

### 8.1 Schema Summarisation

Never dump raw `CREATE TABLE` SQL into the prompt — it wastes context. Use a compact summary:

```python
# env/database.py
import sqlite3

def get_schema_summary(conn: sqlite3.Connection) -> str:
    """Return one-line-per-table schema, e.g.:
    users: (id, email, country, plan, created_at, churned_at)
    """
    cursor = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    )
    tables = [r[0] for r in cursor.fetchall()]
    lines = []
    for table in tables:
        cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
        col_names = [c[1] for c in cols]
        lines.append(f"{table}: ({', '.join(col_names)})")
    return "\n".join(lines)
```
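For example, against a small in-memory database the summary comes out one line per table, sorted by table name (function copied from above so the snippet runs standalone):

```python
import sqlite3

def get_schema_summary(conn: sqlite3.Connection) -> str:
    cursor = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    )
    tables = [r[0] for r in cursor.fetchall()]
    lines = []
    for table in tables:
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk)
        cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
        lines.append(f"{table}: ({', '.join(c[1] for c in cols)})")
    return "\n".join(lines)

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT, country TEXT)")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER, total REAL)")
summary = get_schema_summary(conn)
# summary == "orders: (id, user_id, total)\nusers: (id, email, country)"
```

Two short lines instead of two full `CREATE TABLE` statements — the saving grows with every table and every episode.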

### 8.2 Answer Normalisation

Strip LLM formatting before grading — don't penalise the agent for markdown:

```python
# env/utils.py
import re

def normalize_answer(raw: str) -> str:
    """Remove common LLM answer preambles and formatting."""
    text = raw.strip().lower()
    text = re.sub(r'the (answer|result) is:?\s*', '', text)
    text = re.sub(r'\*+', '', text)                          # bold
    text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)   # code blocks
    text = re.sub(r'`[^`]+`', lambda m: m.group().strip('`'), text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
```
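Example behaviour — a bolded answer with a preamble and an inline-code answer both normalise to the bare value (function copied from above so the snippet runs standalone):

```python
import re

def normalize_answer(raw: str) -> str:
    text = raw.strip().lower()
    text = re.sub(r'the (answer|result) is:?\s*', '', text)
    text = re.sub(r'\*+', '', text)
    text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)
    text = re.sub(r'`[^`]+`', lambda m: m.group().strip('`'), text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip()

clean = normalize_answer("The answer is: **Electronics**")  # preamble + bold stripped
inline = normalize_answer("`electronics`")                  # inline code unwrapped
```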

### 8.3 Progressive Hints

Give hints as steps increase — keeps episodes learnable and reward dense:

```python
# env/tasks/base.py
def get_hints(self, step: int) -> list[str]:
    hints = []
    if step > 5:
        hints.append(f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}")
    if step > 10:
        hints.append(f"Hint: Try using {self.sql_hint}")
    if step > 15:
        hints.append("Hint: Make sure to submit your answer with submit_answer.")
    return hints
```

### 8.4 Ground Truth Computed Post-Seed

Always compute ground truth **after** seeding, so it matches the actual data:

```python
# env/tasks/easy.py
def compute_ground_truth(self, conn: sqlite3.Connection):
    result = conn.execute(
        "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
    ).fetchone()
    self.ground_truth = result[0]
```

### 8.5 SQL Safety Guards

Block any mutating operations. Match whole words, not substrings — a plain `"CREATE" in query` check would reject legitimate queries that mention the `created_at` column:

```python
import re

FORBIDDEN_KEYWORDS = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]
_FORBIDDEN_RE = re.compile(r"\b(" + "|".join(FORBIDDEN_KEYWORDS) + r")\b")

def is_safe_query(query: str) -> bool:
    return _FORBIDDEN_RE.search(query.upper()) is None
```
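A runnable check of the guard, using a word-boundary match so the `created_at` column does not trip the `CREATE` keyword (naive substring matching would flag it):

```python
import re

FORBIDDEN = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]
_FORBIDDEN_RE = re.compile(r"\b(" + "|".join(FORBIDDEN) + r")\b")

def is_safe_query(query: str) -> bool:
    # \b ensures CREATE matches only as a whole word, not inside CREATED_AT
    return _FORBIDDEN_RE.search(query.upper()) is None

ok = is_safe_query("SELECT COUNT(*) FROM users WHERE created_at >= '2024-01-01'")
bad = is_safe_query("DROP TABLE users")
```

Belt-and-braces: this keyword guard sits on top of the `_execute_sql` check that the statement starts with `SELECT` or `WITH`.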

---

## 9. Baseline Inference Script

```python
# baseline/run_baseline.py
"""
Baseline inference script for sql-data-analyst OpenEnv.

Usage:
    export OPENAI_API_KEY=sk-...
    python baseline/run_baseline.py

Produces reproducible scores across all 3 tasks.
"""

import os
import json

from openai import OpenAI

from env.environment import SQLAnalystEnv
from env.models import Action

API_KEY = os.environ["OPENAI_API_KEY"]
MODEL = "gpt-4o-mini"
MAX_STEPS = 20
TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]

client = OpenAI(api_key=API_KEY)

SYSTEM_PROMPT = """
You are a SQL data analyst. You are given a database schema and a business question.
Your job is to write SQL queries to explore the data and submit a final answer.

Rules:
- Only write SELECT or WITH queries.
- Reply with JSON only. No explanation.
- To run a query: {"sql_query": "SELECT ..."}
- To submit answer: {"submit_answer": "your answer here"}
- You will see the query result after each step.
- Submit your answer when you are confident.
"""


def build_prompt(obs) -> str:
    parts = [
        f"Database schema:\n{obs.schema_summary}",
        f"\nQuestion: {obs.question}",
        f"\nStep: {obs.step} / {obs.max_steps}",
    ]
    if obs.last_query:
        parts.append(f"\nLast query:\n{obs.last_query}")
    if obs.last_result and obs.last_result.rows:
        cols = obs.last_result.columns
        rows = obs.last_result.rows[:10]  # show max 10 rows
        parts.append(f"\nResult columns: {cols}")
        parts.append(f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}")
    if obs.last_error:
        parts.append(f"\nSQL error: {obs.last_error}")
    if obs.hints:
        parts.append(f"\nHints: {'; '.join(obs.hints)}")
    parts.append("\nWhat is your next action? Reply with JSON only.")
    return "\n".join(parts)


def parse_action(response_text: str) -> Action:
    """Extract JSON action from LLM response."""
    text = response_text.strip()
    # Strip markdown code fences if present
    text = text.replace("```json", "").replace("```", "").strip()
    try:
        data = json.loads(text)
        return Action(**data)
    except Exception:
        # Fallback: treat entire response as a submit
        return Action(submit_answer=text)


def run_task(task_id: str) -> dict:
    print(f"\n{'='*50}")
    print(f"Task: {task_id}")
    print('='*50)

    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    obs = result.observation
    history = []
    score = 0.0

    print(f"Question: {obs.question}")

    for step in range(1, MAX_STEPS + 1):
        if result.done:
            print(f"Episode done at step {step - 1}")
            break

        user_prompt = build_prompt(obs)
        history.append({"role": "user", "content": user_prompt})

        response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                *history[-8:],  # last 4 turns (8 messages)
            ],
            temperature=0.0,  # deterministic
        )

        reply = response.choices[0].message.content
        history.append({"role": "assistant", "content": reply})

        action = parse_action(reply)
        print(f"Step {step}: {action}")

        result = env.step(action)
        obs = result.observation
        score = result.reward

        if result.done:
            break

    state = env.state()
    print(f"Final total reward: {state.total_reward}")
    return {
        "task_id": task_id,
        "total_reward": state.total_reward,
        "steps": state.step,
    }


def main():
    results = []
    for task_id in TASK_IDS:
        r = run_task(task_id)
        results.append(r)

    print("\n" + "="*50)
    print("BASELINE RESULTS")
    print("="*50)
    for r in results:
        print(f"{r['task_id']:30s} score={r['total_reward']:.3f} steps={r['steps']}")

    avg = sum(r["total_reward"] for r in results) / len(results)
    print(f"\nAverage score: {avg:.3f}")

    # Write results to file for reproducibility
    with open("baseline_scores.json", "w") as f:
        json.dump(results, f, indent=2)
    print("Saved to baseline_scores.json")


if __name__ == "__main__":
    main()
```
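
The fence-stripping fallback in `parse_action` can be exercised standalone (using a plain dict in place of the Pydantic `Action`, purely for illustration):

```python
import json

FENCE = "`" * 3  # the literal ``` marker, built indirectly to keep this snippet fence-safe

def parse_action_dict(response_text: str) -> dict:
    # Mirrors parse_action: strip markdown fences, then try JSON;
    # fall back to treating the whole reply as a submitted answer.
    text = response_text.strip().replace(FENCE + "json", "").replace(FENCE, "").strip()
    try:
        return json.loads(text)
    except Exception:
        return {"submit_answer": text}

print(parse_action_dict(FENCE + 'json\n{"sql_query": "SELECT 1"}\n' + FENCE))  # {'sql_query': 'SELECT 1'}
print(parse_action_dict("The answer is 42"))  # {'submit_answer': 'The answer is 42'}
```

The fallback means a malformed reply still ends the episode with a graded submission instead of crashing the rollout.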

---

## 10. openenv.yaml

```yaml
name: sql-data-analyst
version: "1.0.0"
description: >
  An RL environment where an AI agent answers real business intelligence questions
  by iteratively writing and executing SQL queries against a live SQLite database.
  Simulates the day-to-day workflow of a data analyst.

tags:
  - openenv
  - sql
  - data-analysis
  - business-intelligence
  - real-world

author: your-username
repository: https://huggingface.co/spaces/your-username/sql-data-analyst

observation_space:
  type: dict
  fields:
    schema_summary:
      type: string
      description: Compact one-line-per-table schema of the database
    question:
      type: string
      description: Natural language business question to answer
    last_query:
      type: string
      nullable: true
      description: The last SQL query executed by the agent
    last_result:
      type: object
      nullable: true
      description: Result of the last query (columns, rows, error)
    last_error:
      type: string
      nullable: true
      description: SQL error message if last query failed
    step:
      type: integer
      description: Current step number
    max_steps:
      type: integer
      description: Maximum steps allowed for this task
    hints:
      type: array
      items: string
      description: Progressive hints revealed as steps increase

action_space:
  type: union
  description: Agent must provide exactly one of the following
  options:
    sql_query:
      type: string
      description: A SELECT or WITH SQL query to execute
    submit_answer:
      type: string
      description: Final answer to the question. Ends the episode.

tasks:
  - id: monthly_signups
    difficulty: easy
    max_steps: 10
    description: "Count the number of users who signed up in the last 30 days"
    skills_required:
      - COUNT
      - WHERE with date filter

  - id: top_revenue_category
    difficulty: medium
    max_steps: 15
    description: "Find which product category generated the most revenue in Q3"
    skills_required:
      - JOIN (3 tables)
      - GROUP BY
      - SUM aggregation
      - Date range filtering

  - id: churn_analysis
    difficulty: hard
    max_steps: 20
    description: >
      Find email addresses of users who placed exactly 3 orders and then
      never ordered again (churned after their 3rd purchase)
    skills_required:
      - Subqueries
      - HAVING clause
      - Date logic
      - Window functions (optional)

baseline_scores:
  monthly_signups: 0.82
  top_revenue_category: 0.61
  churn_analysis: 0.38
  average: 0.60
```

---

## 11. Dockerfile

```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source
COPY . .

# Pre-seed the database at build time (optional — env also seeds at reset())
RUN python -c "from env.database import create_database, seed_database; \
    conn = create_database(); seed_database(conn); conn.close()"

# Expose port for HuggingFace Spaces
EXPOSE 7860

# Start the API server
CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"]
```

```
# requirements.txt
pydantic>=2.0
fastapi
uvicorn
openai
faker
pytest
```

---

## 12. README Template

````markdown
# SQL Data Analyst — OpenEnv Environment

An RL training environment where an AI agent learns to answer business intelligence
questions by writing and executing SQL queries against a live database.

## Motivation

Data analysts spend significant time translating business questions into SQL queries.
This environment trains agents to do exactly that — iteratively exploring a database
schema, writing queries, observing results, and submitting final answers.

## Observation Space

| Field | Type | Description |
|---|---|---|
| `schema_summary` | string | Compact DB schema (one line per table) |
| `question` | string | Natural language business question |
| `last_query` | string \| null | Most recent SQL query |
| `last_result` | object \| null | Query result: columns, rows (max 50), error |
| `last_error` | string \| null | SQL error if last query failed |
| `step` | int | Current step number |
| `max_steps` | int | Episode step limit |
| `hints` | string[] | Progressive hints (revealed after step 5, 10, 15) |

## Action Space

Agent must submit exactly one of:

| Action | Type | Description |
|---|---|---|
| `sql_query` | string | A SELECT or WITH SQL query to execute |
| `submit_answer` | string | Final answer — ends the episode |

## Tasks

| Task | Difficulty | Max Steps | Description |
|---|---|---|---|
| `monthly_signups` | Easy | 10 | Count signups in the last 30 days |
| `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 |
| `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases |

## Reward Function

Rewards are given at every step (not just episode end):

- `+0.15` — Query executes without error
- `+0.10` — Query references a relevant table
- `+0.05` — Result has at least one row
- `+0.05` — Result is a sensible size
- `-0.02` per step beyond step 3 (efficiency penalty)
- `-0.10` if agent repeats the same query 3+ times
- `+0.00–0.60` on final submission (task grader × 0.60)

## Setup

```bash
git clone https://huggingface.co/spaces/your-username/sql-data-analyst
cd sql-data-analyst
pip install -r requirements.txt
```

### Run locally

```python
from env.environment import SQLAnalystEnv
from env.models import Action

env = SQLAnalystEnv(task_id="monthly_signups")
result = env.reset()
print(result.observation.question)

# Agent takes a step
result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"))
print(result.reward)
```

### Run baseline

```bash
export OPENAI_API_KEY=sk-...
python baseline/run_baseline.py
```

### Docker

```bash
docker build -t sql-analyst-env .
docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-analyst-env
```

## Baseline Scores

| Task | Score | Model |
|---|---|---|
| monthly_signups | 0.82 | gpt-4o-mini |
| top_revenue_category | 0.61 | gpt-4o-mini |
| churn_analysis | 0.38 | gpt-4o-mini |
| **Average** | **0.60** | gpt-4o-mini |

## Validation

```bash
openenv validate --env env.environment.SQLAnalystEnv
pytest tests/
```
````
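
The per-step reward schedule listed in the README above can be sketched as a pure function. This is illustrative only (the real logic lives in `env/reward.py`, and the boolean inputs here stand in for checks that module performs):

```python
def step_reward(executed_ok: bool, used_relevant_table: bool, has_rows: bool,
                sensible_size: bool, step: int, repeated_3x: bool) -> float:
    reward = 0.0
    if executed_ok:
        reward += 0.15
    if used_relevant_table:
        reward += 0.10
    if has_rows:
        reward += 0.05
    if sensible_size:
        reward += 0.05
    if step > 3:
        reward -= 0.02   # efficiency penalty per step beyond step 3
    if repeated_3x:
        reward -= 0.10   # discourage looping on the same query
    return round(reward, 3)

print(step_reward(True, True, True, True, step=2, repeated_3x=False))  # 0.35
```

A perfect exploratory step caps at +0.35, so the final graded submission (up to +0.60) still dominates the episode return.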

---

## 13. Full File Structure

```
sql-analyst-openenv/
│
├── env/
│   ├── __init__.py
│   ├── environment.py      ← Main OpenEnv class (reset/step/state)
│   ├── models.py           ← Pydantic: Observation, Action, StepResult, EnvState
│   ├── database.py         ← SQLite creation + Faker seeding + schema summary
│   ├── executor.py         ← Safe SQL execution (SELECT-only guard)
│   ├── reward.py           ← RewardCalculator class
│   ├── utils.py            ← normalize_answer, is_safe_query helpers
│   ├── server.py           ← FastAPI wrapper for HuggingFace Spaces
│   └── tasks/
│       ├── __init__.py     ← TASKS dict: {task_id: TaskInstance}
│       ├── base.py         ← BaseTask abstract class
│       ├── easy.py         ← MonthlySignupsTask
│       ├── medium.py       ← TopRevenueCategoryTask
│       └── hard.py         ← ChurnAnalysisTask
│
├── baseline/
│   ├── run_baseline.py     ← Full inference script (OpenAI API)
│   └── prompts.py          ← System prompt + user prompt builder
│
├── tests/
│   ├── test_env.py         ← reset/step/state contract tests
│   ├── test_graders.py     ← Unit tests for each task grader
│   └── test_reward.py      ← Reward calculator unit tests
│
├── openenv.yaml            ← OpenEnv spec metadata
├── Dockerfile              ← docker build + docker run
├── requirements.txt
└── README.md
```

---

## 14. Build Order

Follow this order when coding. Each step is a self-contained deliverable.

### Step 1 — Models (30 min)
Build `env/models.py` first. All other files depend on these types.
Test: can import and instantiate `Observation`, `Action`, `StepResult`.

### Step 2 — Database (45 min)
Build `env/database.py` — schema creation, Faker seeding, schema summary.
Test: run `create_database()` + `seed_database()`, query the tables manually.

### Step 3 — Tasks + Graders (60 min)
Build `env/tasks/base.py`, then `easy.py`, `medium.py`, `hard.py`.
Test each grader with known inputs: perfect answer → 1.0, wrong answer → 0.0.

### Step 4 — Reward Calculator (30 min)
Build `env/reward.py`.
Test: step with good query → positive reward, repeated query → penalty applied.

### Step 5 — Environment Core (60 min)
Build `env/environment.py` — wire together DB, executor, reward, tasks.
Test: full episode loop manually: `reset()` → `step()` × N → `state()`.

### Step 6 — Baseline Script (45 min)
Build `baseline/run_baseline.py`.
Test: run against all 3 tasks, confirm scores are reproducible across 2 runs.

### Step 7 — FastAPI Server (30 min)
Build `env/server.py` — wrap env in HTTP endpoints for HF Spaces.
Test: `docker build` passes, `docker run` starts server on port 7860.

### Step 8 — Docs + Validation (30 min)
Write `openenv.yaml` and `README.md`. Run `openenv validate`.
Fill in real baseline scores from Step 6 output.

### Step 9 — Deploy to HuggingFace (15 min)
Push to HF Space repo. Tag with `openenv`. Verify Space starts cleanly.

---

*Total estimated time: ~6 hours for a clean first build.*

env/__init__.py
ADDED

@@ -0,0 +1,13 @@
+from .models import Action, QueryResult, Observation, StepResult, EnvState
+from .environment import SQLAnalystEnv
+from .tasks import TASKS
+
+__all__ = [
+    "Action",
+    "QueryResult",
+    "Observation",
+    "StepResult",
+    "EnvState",
+    "SQLAnalystEnv",
+    "TASKS",
+]
env/database.py
ADDED

@@ -0,0 +1,267 @@
+import sqlite3
+import random
+from datetime import datetime, timedelta
+from typing import Optional, Any
+from faker import Faker
+
+fake = Faker()
+
+SEED_CONFIG = {
+    "users": 500,
+    "products": 80,
+    "orders": 2000,
+    "order_items": 5000,
+    "events": 8000,
+}
+
+CATEGORIES = ["Electronics", "Clothing", "Books", "Home & Garden", "Sports"]
+PLAN_TYPES = ["free", "pro", "enterprise"]
+ORDER_STATUSES = ["pending", "completed", "refunded"]
+EVENT_TYPES = ["page_view", "add_to_cart", "checkout", "login", "logout"]
+
+
+def create_database(db_path: str = ":memory:") -> sqlite3.Connection:
+    conn = sqlite3.connect(db_path)
+    conn.row_factory = sqlite3.Row
+
+    conn.execute("""
+        CREATE TABLE users (
+            id INTEGER PRIMARY KEY,
+            email TEXT NOT NULL,
+            country TEXT,
+            plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
+            created_at TIMESTAMP NOT NULL,
+            churned_at TIMESTAMP
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE products (
+            id INTEGER PRIMARY KEY,
+            name TEXT NOT NULL,
+            category TEXT NOT NULL,
+            price REAL,
+            cost REAL
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE orders (
+            id INTEGER PRIMARY KEY,
+            user_id INTEGER REFERENCES users(id),
+            created_at TIMESTAMP NOT NULL,
+            status TEXT CHECK(status IN ('pending', 'completed', 'refunded')),
+            total REAL
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE order_items (
+            id INTEGER PRIMARY KEY,
+            order_id INTEGER REFERENCES orders(id),
+            product_id INTEGER REFERENCES products(id),
+            qty INTEGER NOT NULL,
+            unit_price REAL
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE events (
+            id INTEGER PRIMARY KEY,
+            user_id INTEGER REFERENCES users(id),
+            event_type TEXT,
+            metadata TEXT,
+            ts TIMESTAMP NOT NULL
+        )
+    """)
+
+    conn.commit()
+    return conn
+
+
+def seed_database(conn: sqlite3.Connection) -> None:
+    users = _seed_users(conn)
+    products = _seed_products(conn)
+    orders, order_items = _seed_orders(conn, users, products)
+    _seed_events(conn, users, orders)
+
+
+def _seed_users(conn: sqlite3.Connection) -> list:
+    users = []
+    now = datetime.now()
+    base_date = now - timedelta(days=180)
+    recent_date = now - timedelta(days=30)
+
+    for i in range(SEED_CONFIG["users"]):
+        if random.random() < 0.3:
+            created_at = recent_date + timedelta(days=random.randint(0, 30))
+        else:
+            created_at = base_date + timedelta(days=random.randint(0, 180))
+
+        country = random.choice([fake.country(), None, None, None, None])
+        plan = random.choice(PLAN_TYPES)
+        churned_at = None
+
+        if plan == "free" and random.random() < 0.1:
+            churned_at = created_at + timedelta(days=random.randint(30, 150))
+
+        conn.execute(
+            "INSERT INTO users (email, country, plan, created_at, churned_at) VALUES (?, ?, ?, ?, ?)",
+            (
+                fake.email(),
+                country,
+                plan,
+                created_at.isoformat(),
+                churned_at.isoformat() if churned_at else None,
+            ),
+        )
+        users.append((i + 1, created_at))
+
+    conn.commit()
+    return users
+
+
+def _seed_products(conn: sqlite3.Connection) -> list:
+    products = []
+
+    for i in range(SEED_CONFIG["products"]):
+        category = random.choice(CATEGORIES)
+        price = round(random.uniform(10, 500), 2)
+        cost = round(price * random.uniform(0.3, 0.7), 2)
+
+        conn.execute(
+            "INSERT INTO products (name, category, price, cost) VALUES (?, ?, ?, ?)",
+            (fake.catch_phrase(), category, price, cost),
+        )
+        products.append((i + 1, category, price))
+
+    conn.commit()
+    return products
+
+
+def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple:
+    orders = []
+    order_items = []
+
+    q3_start = datetime(2024, 7, 1)
+    q3_end = datetime(2024, 9, 30)
+    recent_date = datetime.now()
+    old_date = datetime(2024, 1, 1)
+
+    for i in range(SEED_CONFIG["orders"]):
+        user_id = random.choice(users)[0]
+
+        if random.random() < 0.2:
+            created_at = q3_start + timedelta(days=random.randint(0, 91))
+        else:
+            created_at = old_date + timedelta(days=random.randint(0, 180))
+
+        status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0]
+
+        conn.execute(
+            "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)",
+            (user_id, created_at.isoformat(), status, 0),
+        )
+
+        order_id = i + 1
+        order_total = 0
+
+        num_items = random.randint(1, 5)
+        for _ in range(num_items):
+            product = random.choice(products)
+            qty = random.randint(1, 3)
+            unit_price = product[2]
+            order_total += qty * unit_price
+
+            conn.execute(
+                "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)",
+                (order_id, product[0], qty, unit_price),
+            )
+
+        conn.execute(
+            "UPDATE orders SET total = ? WHERE id = ?",
+            (round(order_total, 2), order_id),
+        )
+        orders.append((order_id, user_id, created_at, status))
+
+    conn.commit()
+    return orders, order_items
+
+
+def _seed_events(conn: sqlite3.Connection, users: list, orders: list) -> None:
+    base_date = datetime.now() - timedelta(days=180)
+
+    for _ in range(SEED_CONFIG["events"]):
+        user_id = random.choice(users)[0]
+        ts = base_date + timedelta(
+            days=random.randint(0, 180), hours=random.randint(0, 23)
+        )
+        event_type = random.choice(EVENT_TYPES)
+        metadata = '{"page": "/' + fake.uri_path() + '"}'
+
+        conn.execute(
+            "INSERT INTO events (user_id, event_type, metadata, ts) VALUES (?, ?, ?, ?)",
+            (user_id, event_type, metadata, ts.isoformat()),
+        )
+
+    conn.commit()
+
+
+def get_schema_summary(conn: sqlite3.Connection) -> str:
+    cursor = conn.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
+    )
+    tables = [r[0] for r in cursor.fetchall()]
+
+    lines = []
+    for table in tables:
+        cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
+        col_names = [c[1] for c in cols]
+        lines.append(f"{table}: ({', '.join(col_names)})")
+
+    return "\n".join(lines)
+
+
+def get_ground_truth(conn: sqlite3.Connection, task_id: str) -> Any:
+    if task_id == "monthly_signups":
+        result = conn.execute(
+            "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
+        ).fetchone()
+        return result[0]
+
+    elif task_id == "top_revenue_category":
+        result = conn.execute("""
+            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
+            FROM orders o
+            JOIN order_items oi ON o.id = oi.order_id
+            JOIN products p ON oi.product_id = p.id
+            WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
+              AND o.status = 'completed'
+            GROUP BY p.category
+            ORDER BY revenue DESC
+            LIMIT 1
+        """).fetchone()
+        return result[0] if result else None
+
+    elif task_id == "churn_analysis":
+        result = conn.execute("""
+            WITH order_counts AS (
+                SELECT user_id, COUNT(*) AS total_orders,
+                       MAX(created_at) AS last_order_date
+                FROM orders
+                WHERE status = 'completed'
+                GROUP BY user_id
+                HAVING COUNT(*) = 3
+            ),
+            churned AS (
+                SELECT oc.user_id
+                FROM order_counts oc
+                WHERE oc.last_order_date < DATE('now', '-90 days')
+            )
+            SELECT u.email
+            FROM users u
+            JOIN churned c ON u.id = c.user_id
+        """).fetchall()
+        return {row[0].lower() for row in result}
+
+    return None
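
`get_schema_summary` relies only on `sqlite_master` and `PRAGMA table_info`, so its one-line-per-table output format can be checked without Faker. A stripped-down version (the helper name below is ours, not the module's):

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)")

def schema_summary(conn: sqlite3.Connection) -> str:
    # Same approach as get_schema_summary: one compact line per table.
    tables = [r[0] for r in conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    )]
    lines = []
    for table in tables:
        cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})")]
        lines.append(f"{table}: ({', '.join(cols)})")
    return "\n".join(lines)

print(schema_summary(conn))  # users: (id, email)
```

Keeping the summary to one line per table keeps the agent's prompt small even with five seeded tables.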
env/environment.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+import sqlite3
+from typing import Optional
+from .models import Action, Observation, StepResult, EnvState, QueryResult
+from .database import create_database, seed_database, get_schema_summary
+from .reward import RewardCalculator
+from .tasks import TASKS
+
+
+class SQLAnalystEnv:
+    """
+    OpenEnv-compliant SQL Data Analyst environment.
+
+    An agent must answer business questions by iteratively
+    writing and executing SQL queries.
+    """
+
+    def __init__(self, task_id: str = "monthly_signups"):
+        assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
+        self.task_id = task_id
+        self.task = TASKS[task_id]
+        self.conn: Optional[sqlite3.Connection] = None
+        self.step_count: int = 0
+        self.total_reward: float = 0.0
+        self.done: bool = False
+        self._query_history: list = []
+        self._reward_calc = RewardCalculator()
+
+    def reset(self) -> StepResult:
+        """Reset environment. Reseed DB. Return initial observation."""
+        if self.conn:
+            self.conn.close()
+
+        self.conn = create_database()
+        seed_database(self.conn)
+        self.step_count = 0
+        self.total_reward = 0.0
+        self.done = False
+        self._query_history = []
+
+        self.task.compute_ground_truth(self.conn)
+
+        obs = Observation(
+            schema_summary=get_schema_summary(self.conn),
+            question=self.task.question,
+            step=0,
+            max_steps=self.task.max_steps,
+        )
+        return StepResult(observation=obs, reward=0.0, done=False)
+
+    def step(self, action: Action) -> StepResult:
+        """Execute one agent action. Return (observation, reward, done, info)."""
+        assert self.conn is not None, "Call reset() before step()"
+        assert not self.done, "Episode is done. Call reset()."
+        assert action.is_valid(), (
+            "Action must have exactly one of: sql_query, submit_answer"
+        )
+
+        self.step_count += 1
+        query_result = None
+        error = None
+
+        if action.sql_query:
+            query_result = self._execute_sql(action.sql_query)
+            self._query_history.append(action.sql_query)
+            error = query_result.error
+
+        terminal = (
+            action.submit_answer is not None or self.step_count >= self.task.max_steps
+        )
+
+        reward = self._reward_calc.calculate(
+            action=action,
+            result=query_result,
+            task=self.task,
+            step=self.step_count,
+            query_history=self._query_history,
+            terminal=terminal,
+        )
+        self.total_reward += reward
+        self.done = terminal
+
+        obs = Observation(
+            schema_summary=get_schema_summary(self.conn),
+            question=self.task.question,
+            last_query=action.sql_query,
+            last_result=query_result,
+            last_error=error,
+            step=self.step_count,
+            max_steps=self.task.max_steps,
+            hints=self.task.get_hints(self.step_count),
+            done=self.done,
+        )
+
+        return StepResult(
+            observation=obs,
+            reward=round(reward, 3),
+            done=self.done,
+            info={
+                "step": self.step_count,
+                "total_reward": round(self.total_reward, 3),
+                "task_id": self.task_id,
+            },
+        )
+
+    def state(self) -> EnvState:
+        """Return current full state of the environment."""
+        return EnvState(
+            task_id=self.task_id,
+            difficulty=self.task.difficulty,
+            step=self.step_count,
+            max_steps=self.task.max_steps,
+            query_history=self._query_history.copy(),
+            total_reward=round(self.total_reward, 3),
+            done=self.done,
+        )
+
+    def _execute_sql(self, query: str) -> QueryResult:
+        """Execute SQL safely. Block non-SELECT. Return up to 50 rows."""
+        q = query.strip().upper()
+        if not q.startswith("SELECT") and not q.startswith("WITH"):
+            return QueryResult(error="Only SELECT / WITH queries are allowed.")
+        try:
+            cursor = self.conn.execute(query)
+            cols = [d[0] for d in cursor.description] if cursor.description else []
+            rows = cursor.fetchmany(50)
+            total = len(rows)
+            return QueryResult(
+                columns=cols,
+                rows=[list(r) for r in rows],
+                truncated=(total == 50),
+                total_rows=total,
+            )
+        except Exception as e:
+            return QueryResult(error=str(e))

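The `_execute_sql` guard can be exercised on its own. Below is a minimal standalone sketch (the helper name `execute_select_only` and the sample `users` table are illustrative, not part of this repo) reproducing the same SELECT/WITH prefix check and 50-row cap:

```python
import sqlite3

def execute_select_only(conn: sqlite3.Connection, query: str, limit: int = 50) -> dict:
    """Run read-only queries: allow only SELECT/WITH prefixes, cap rows at `limit`."""
    q = query.strip().upper()
    if not (q.startswith("SELECT") or q.startswith("WITH")):
        return {"error": "Only SELECT / WITH queries are allowed."}
    cursor = conn.execute(query)
    cols = [d[0] for d in cursor.description] if cursor.description else []
    rows = cursor.fetchmany(limit)  # never pull more than `limit` rows into memory
    return {"columns": cols, "rows": [list(r) for r in rows],
            "truncated": len(rows) == limit}

# Tiny demo schema (illustrative, not the environment's seed data).
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER, email TEXT)")
conn.executemany("INSERT INTO users VALUES (?, ?)",
                 [(i, f"u{i}@example.com") for i in range(100)])

ok = execute_select_only(conn, "SELECT COUNT(*) FROM users")
blocked = execute_select_only(conn, "DELETE FROM users")
```

Note the prefix check is a coarse heuristic; it blocks obvious writes but is not a full SQL parser.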
env/models.py
ADDED
|
@@ -0,0 +1,58 @@
|
+from pydantic import BaseModel, Field
+from typing import Optional, List, Any
+
+class Action(BaseModel):
+    """What the agent can do each step."""
+    sql_query: Optional[str] = Field(
+        None,
+        description="A SQL SELECT query to execute against the database"
+    )
+    submit_answer: Optional[str] = Field(
+        None,
+        description="Final answer to submit. Ends the episode."
+    )
+
+    def is_valid(self) -> bool:
+        # Exactly one of the two must be set
+        return bool(self.sql_query) != bool(self.submit_answer)
+
+
+class QueryResult(BaseModel):
+    """Result of executing a SQL query."""
+    columns: List[str] = []
+    rows: List[List[Any]] = []
+    error: Optional[str] = None
+    truncated: bool = False
+    total_rows: int = 0
+
+
+class Observation(BaseModel):
+    """What the agent sees after each step."""
+    schema_summary: str = Field(..., description="Compact DB schema")
+    question: str = Field(..., description="Business question to answer")
+    last_query: Optional[str] = None
+    last_result: Optional[QueryResult] = None
+    last_error: Optional[str] = None
+    step: int = 0
+    max_steps: int = 20
+    hints: List[str] = []
+    done: bool = False
+
+
+class StepResult(BaseModel):
+    """Full result returned by step()."""
+    observation: Observation
+    reward: float = 0.0
+    done: bool = False
+    info: dict = {}
+
+
+class EnvState(BaseModel):
+    """Full environment state returned by state()."""
+    task_id: str
+    difficulty: str
+    step: int
+    max_steps: int
+    query_history: List[str] = []
+    total_reward: float = 0.0
+    done: bool = False

env/reward.py
ADDED
|
@@ -0,0 +1,55 @@
|
+from typing import Optional, List, Any
+from .models import Action, QueryResult
+
+
+class RewardCalculator:
+    """Calculate rewards for agent actions in the SQL analyst environment."""
+
+    def calculate(
+        self,
+        action: Action,
+        result: Optional[QueryResult],
+        task: Any,
+        step: int,
+        query_history: List[str],
+        terminal: bool,
+    ) -> float:
+        """Calculate reward based on action, result, and task."""
+        reward = 0.0
+
+        if action.sql_query and result:
+            if not result.error:
+                reward += 0.15
+
+                relevant = self._count_relevant_tables(
+                    action.sql_query, task.relevant_tables
+                )
+                if relevant > 0:
+                    reward += 0.10
+
+                if result.rows and len(result.rows) > 0:
+                    reward += 0.05
+
+                if result.rows and len(result.rows) < 1000:
+                    reward += 0.05
+
+        if step > 3:
+            reward -= 0.02 * (step - 3)
+
+        if self._is_stuck(query_history):
+            reward -= 0.10
+
+        if terminal and action.submit_answer:
+            task_score = task.grade(action.submit_answer)
+            reward += task_score * 0.60
+
+        return max(0.0, min(1.0, reward))
+
+    def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
+        query_lower = query.lower()
+        return sum(1 for t in relevant_tables if t.lower() in query_lower)
+
+    def _is_stuck(self, history: List[str]) -> bool:
+        if len(history) < 3:
+            return False
+        return len(set(history[-3:])) == 1

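The loop detection and clamping that `RewardCalculator` relies on are easy to check in isolation. This sketch re-implements `_is_stuck` and the final `[0, 1]` clamp as free functions (the names are illustrative):

```python
def is_stuck(history: list, window: int = 3) -> bool:
    """An agent counts as 'stuck' when its last `window` queries are identical."""
    if len(history) < window:
        return False
    return len(set(history[-window:])) == 1

def clamp_reward(value: float) -> float:
    """Per-step rewards are clipped into [0.0, 1.0] after bonuses and penalties."""
    return max(0.0, min(1.0, value))
```

One consequence of clamping after summation: the step penalty and the stuck penalty can never push a step's reward below zero, they only erode the bonuses earned that step.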
env/server.py
ADDED
|
@@ -0,0 +1,254 @@
|
+"""
+FastAPI server for SQL Data Analyst OpenEnv.
+
+Provides REST and WebSocket endpoints for HuggingFace Spaces deployment.
+"""
+
+from fastapi import FastAPI, WebSocket, WebSocketDisconnect
+from pydantic import BaseModel
+from typing import Optional, Dict, Any
+import json
+import asyncio
+
+from env import SQLAnalystEnv, Action
+
+
+app = FastAPI(title="SQL Data Analyst Environment")
+
+envs: Dict[str, SQLAnalystEnv] = {}
+
+
+class ResetRequest(BaseModel):
+    task_id: str = "monthly_signups"
+
+
+class StepRequest(BaseModel):
+    session_id: str
+    sql_query: Optional[str] = None
+    submit_answer: Optional[str] = None
+
+
+class StateRequest(BaseModel):
+    session_id: str
+
+
+@app.get("/")
+async def root():
+    return {
+        "name": "sql-data-analyst",
+        "version": "1.0.0",
+        "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation",
+    }
+
+
+@app.post("/reset")
+async def reset(req: ResetRequest) -> Dict[str, Any]:
+    session_id = req.task_id
+
+    env = SQLAnalystEnv(task_id=req.task_id)
+    result = env.reset()
+    envs[session_id] = env
+
+    return {
+        "session_id": session_id,
+        "observation": {
+            "schema_summary": result.observation.schema_summary,
+            "question": result.observation.question,
+            "step": result.observation.step,
+            "max_steps": result.observation.max_steps,
+            "hints": result.observation.hints,
+            "done": result.observation.done,
+        },
+        "reward": result.reward,
+        "done": result.done,
+    }
+
+
+@app.post("/step")
+async def step(req: StepRequest) -> Dict[str, Any]:
+    session_id = req.session_id
+
+    if session_id not in envs:
+        return {"error": "Session not found. Call /reset first."}
+
+    env = envs[session_id]
+
+    action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer)
+
+    result = env.step(action)
+
+    return {
+        "observation": {
+            "schema_summary": result.observation.schema_summary,
+            "question": result.observation.question,
+            "last_query": result.observation.last_query,
+            "last_result": {
+                "columns": result.observation.last_result.columns
+                if result.observation.last_result
+                else None,
+                "rows": result.observation.last_result.rows
+                if result.observation.last_result
+                else None,
+                "error": result.observation.last_result.error
+                if result.observation.last_result
+                else None,
+            }
+            if result.observation.last_result
+            else None,
+            "last_error": result.observation.last_error,
+            "step": result.observation.step,
+            "max_steps": result.observation.max_steps,
+            "hints": result.observation.hints,
+            "done": result.observation.done,
+        },
+        "reward": result.reward,
+        "done": result.done,
+        "info": result.info,
+    }
+
+
+@app.post("/state")
+async def state(req: StateRequest) -> Dict[str, Any]:
+    session_id = req.session_id
+
+    if session_id not in envs:
+        return {"error": "Session not found. Call /reset first."}
+
+    env = envs[session_id]
+    state = env.state()
+
+    return {
+        "task_id": state.task_id,
+        "difficulty": state.difficulty,
+        "step": state.step,
+        "max_steps": state.max_steps,
+        "query_history": state.query_history,
+        "total_reward": state.total_reward,
+        "done": state.done,
+    }
+
+
+@app.post("/delete")
+async def delete_session(req: StateRequest) -> Dict[str, str]:
+    session_id = req.session_id
+
+    if session_id in envs:
+        del envs[session_id]
+        return {"status": "deleted", "session_id": session_id}
+
+    return {"status": "not_found", "session_id": session_id}
+
+
+@app.websocket("/ws")
+async def websocket_endpoint(websocket: WebSocket):
+    await websocket.accept()
+
+    session_id = None
+    env = None
+
+    try:
+        while True:
+            data = await websocket.receive_text()
+            message = json.loads(data)
+
+            action_type = message.get("type")
+
+            if action_type == "reset":
+                task_id = message.get("task_id", "monthly_signups")
+                env = SQLAnalystEnv(task_id=task_id)
+                result = env.reset()
+                session_id = task_id
+                envs[session_id] = env
+
+                await websocket.send_json(
+                    {
+                        "type": "reset",
+                        "observation": {
+                            "schema_summary": result.observation.schema_summary,
+                            "question": result.observation.question,
+                            "step": result.observation.step,
+                            "max_steps": result.observation.max_steps,
+                            "hints": result.observation.hints,
+                        },
+                        "reward": result.reward,
+                        "done": result.done,
+                    }
+                )
+
+            elif action_type == "step":
+                if not env:
+                    await websocket.send_json({"error": "Call reset first"})
+                    continue
+
+                action = Action(
+                    sql_query=message.get("sql_query"),
+                    submit_answer=message.get("submit_answer"),
+                )
+
+                result = env.step(action)
+
+                await websocket.send_json(
+                    {
+                        "type": "step",
+                        "observation": {
+                            "schema_summary": result.observation.schema_summary,
+                            "question": result.observation.question,
+                            "last_query": result.observation.last_query,
+                            "last_result": {
+                                "columns": result.observation.last_result.columns
+                                if result.observation.last_result
+                                else None,
+                                "rows": result.observation.last_result.rows
+                                if result.observation.last_result
+                                else None,
+                                "error": result.observation.last_result.error
+                                if result.observation.last_result
+                                else None,
+                            }
+                            if result.observation.last_result
+                            else None,
+                            "step": result.observation.step,
+                            "hints": result.observation.hints,
+                            "done": result.observation.done,
+                        },
+                        "reward": result.reward,
+                        "done": result.done,
+                        "info": result.info,
+                    }
+                )
+
+            elif action_type == "state":
+                if not env:
+                    await websocket.send_json({"error": "Call reset first"})
+                    continue
+
+                state = env.state()
+
+                await websocket.send_json(
+                    {
+                        "type": "state",
+                        "task_id": state.task_id,
+                        "difficulty": state.difficulty,
+                        "step": state.step,
+                        "max_steps": state.max_steps,
+                        "query_history": state.query_history,
+                        "total_reward": state.total_reward,
+                        "done": state.done,
+                    }
+                )
+
+            elif action_type == "close":
+                if session_id and session_id in envs:
+                    del envs[session_id]
+                break
+
+    except WebSocketDisconnect:
+        pass
+    except Exception as e:
+        await websocket.send_json({"error": str(e)})
+
+
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=7860)

env/tasks/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
+from .base import BaseTask
+from .easy import MonthlySignupsTask
+from .medium import TopRevenueCategoryTask
+from .hard import ChurnAnalysisTask
+
+
+TASKS = {
+    "monthly_signups": MonthlySignupsTask(),
+    "top_revenue_category": TopRevenueCategoryTask(),
+    "churn_analysis": ChurnAnalysisTask(),
+}
+
+
+__all__ = [
+    "BaseTask",
+    "MonthlySignupsTask",
+    "TopRevenueCategoryTask",
+    "ChurnAnalysisTask",
+    "TASKS",
+]

env/tasks/base.py
ADDED
|
@@ -0,0 +1,52 @@
|
+from abc import ABC, abstractmethod
+import sqlite3
+import re
+from typing import Any, List, Optional
+
+
+class BaseTask(ABC):
+    """Abstract base class for all tasks."""
+
+    task_id: str
+    difficulty: str
+    max_steps: int
+    question: str
+    relevant_tables: List[str]
+    sql_hint: str
+
+    def __init__(self):
+        self.ground_truth: Any = None
+        self.top_3_categories: List[str] = []
+
+    @abstractmethod
+    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
+        """Compute ground truth after database seeding."""
+        pass
+
+    @abstractmethod
+    def grade(self, submitted_answer: str) -> float:
+        """Grade the submitted answer. Returns score 0.0-1.0."""
+        pass
+
+    def get_hints(self, step: int) -> List[str]:
+        """Return progressive hints based on current step."""
+        hints = []
+        if step > 5:
+            hints.append(
+                f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}"
+            )
+        if step > 10:
+            hints.append(f"Hint: Try using {self.sql_hint}")
+        if step > 15:
+            hints.append("Hint: Make sure to submit your answer with submit_answer.")
+        return hints
+
+    def _normalize(self, text: str) -> str:
+        """Remove common LLM formatting and normalize text."""
+        text = text.strip().lower()
+        text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
+        text = re.sub(r"\*+", "", text)
+        text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
+        text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
+        text = re.sub(r"\s+", " ", text)
+        return text.strip()

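`_normalize` strips the wrappers LLMs typically add around answers: preambles, bold markers, and code formatting. A standalone sketch of the same pipeline, re-implemented as a free function for illustration:

```python
import re

def normalize(text: str) -> str:
    """Strip preambles, markdown emphasis, code fences, and extra whitespace."""
    text = text.strip().lower()
    text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
    text = re.sub(r"\*+", "", text)                                   # **bold** markers
    text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)            # fenced code blocks
    text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)   # keep inline-code text
    text = re.sub(r"\s+", " ", text)
    return text.strip()
```

Note the ordering matters: fenced blocks are removed wholesale before inline backticks are unwrapped, so ```` ```sql ...``` ```` residue never survives into the graded string.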
env/tasks/easy.py
ADDED
|
@@ -0,0 +1,32 @@
|
+import sqlite3
+from .base import BaseTask
+
+
+class MonthlySignupsTask(BaseTask):
+    """Task 1 — Easy: Count users who signed up in the last 30 days."""
+
+    task_id = "monthly_signups"
+    difficulty = "easy"
+    max_steps = 10
+    question = "How many users signed up in the last 30 days?"
+    relevant_tables = ["users"]
+    sql_hint = "COUNT(*) with WHERE clause on created_at"
+
+    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
+        result = conn.execute(
+            "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
+        ).fetchone()
+        self.ground_truth = result[0] if result else 0
+
+    def grade(self, submitted_answer: str) -> float:
+        try:
+            val = int(submitted_answer.strip().replace(",", ""))
+            if val == self.ground_truth:
+                return 1.0
+            if abs(val - self.ground_truth) <= 3:
+                return 0.6
+            if abs(val - self.ground_truth) <= 10:
+                return 0.3
+        except (ValueError, AttributeError):
+            pass
+        return 0.0

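The graded tolerance in `grade` — full credit for the exact count, partial credit for near misses — can be sketched standalone (the function name `grade_count` is illustrative):

```python
def grade_count(submitted: str, ground_truth: int) -> float:
    """Exact count -> 1.0; within 3 -> 0.6; within 10 -> 0.3; otherwise 0.0."""
    try:
        val = int(submitted.strip().replace(",", ""))  # tolerate "1,234" formatting
    except (ValueError, AttributeError):
        return 0.0  # non-numeric answers earn nothing
    if val == ground_truth:
        return 1.0
    if abs(val - ground_truth) <= 3:
        return 0.6
    if abs(val - ground_truth) <= 10:
        return 0.3
    return 0.0
```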
env/tasks/hard.py
ADDED
|
@@ -0,0 +1,59 @@
|
+import sqlite3
+from .base import BaseTask
+
+
+class ChurnAnalysisTask(BaseTask):
+    """Task 3 — Hard: Find users who placed exactly 3 orders and then churned."""
+
+    task_id = "churn_analysis"
+    difficulty = "hard"
+    max_steps = 20
+    question = "Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."
+    relevant_tables = ["users", "orders"]
+    sql_hint = "CTE with COUNT and HAVING"
+
+    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
+        result = conn.execute("""
+            WITH order_counts AS (
+                SELECT user_id, COUNT(*) AS total_orders,
+                       MAX(created_at) AS last_order_date
+                FROM orders
+                WHERE status = 'completed'
+                GROUP BY user_id
+                HAVING COUNT(*) = 3
+            ),
+            churned AS (
+                SELECT oc.user_id
+                FROM order_counts oc
+                WHERE oc.last_order_date < DATE('now', '-90 days')
+            )
+            SELECT u.email
+            FROM users u
+            JOIN churned c ON u.id = c.user_id
+        """).fetchall()
+
+        self.ground_truth = {row[0].lower() for row in result}
+
+    def grade(self, submitted_answer: str) -> float:
+        if not submitted_answer.strip():
+            return 0.0
+
+        submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e}
+
+        if not submitted:
+            return 0.0
+
+        correct = {e.lower() for e in self.ground_truth}
+        tp = len(submitted & correct)
+
+        if tp == 0:
+            return 0.0
+
+        precision = tp / len(submitted) if submitted else 0
+        recall = tp / len(correct) if correct else 0
+
+        if precision + recall == 0:
+            return 0.0
+
+        f1 = 2 * precision * recall / (precision + recall)
+        return round(f1, 3)

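`grade` here scores set overlap with an F1 measure, so partial credit reflects both precision (no spurious emails) and recall (no missed churners). A self-contained sketch of the same scoring (the name `f1_grade` is illustrative):

```python
def f1_grade(submitted: str, correct: set) -> float:
    """Score a comma-separated email list against a ground-truth set with F1."""
    # Keep only tokens that look like emails, normalized to lowercase.
    submitted_set = {e.strip().lower() for e in submitted.split(",") if "@" in e}
    if not submitted_set or not correct:
        return 0.0
    tp = len(submitted_set & correct)  # true positives
    if tp == 0:
        return 0.0
    precision = tp / len(submitted_set)
    recall = tp / len(correct)
    return round(2 * precision * recall / (precision + recall), 3)
```

For example, submitting one of two correct emails gives precision 1.0 and recall 0.5, so F1 = 0.667 rather than the half credit a plain overlap ratio would give.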
env/tasks/medium.py
ADDED
|
@@ -0,0 +1,54 @@
|
+import sqlite3
+from .base import BaseTask
+
+
+class TopRevenueCategoryTask(BaseTask):
+    """Task 2 — Medium: Find product category with most revenue in Q3."""
+
+    task_id = "top_revenue_category"
+    difficulty = "medium"
+    max_steps = 15
+    question = (
+        "Which product category generated the most revenue in Q3 (July-September)?"
+    )
+    relevant_tables = ["orders", "order_items", "products"]
+    sql_hint = "JOIN with GROUP BY and ORDER BY"
+
+    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
+        result = conn.execute("""
+            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
+            FROM orders o
+            JOIN order_items oi ON o.id = oi.order_id
+            JOIN products p ON oi.product_id = p.id
+            WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
+              AND o.status = 'completed'
+            GROUP BY p.category
+            ORDER BY revenue DESC
+            LIMIT 1
+        """).fetchone()
+
+        self.ground_truth = result[0] if result else None
+
+        all_categories = conn.execute("""
+            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
+            FROM orders o
+            JOIN order_items oi ON o.id = oi.order_id
+            JOIN products p ON oi.product_id = p.id
+            WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
+              AND o.status = 'completed'
+            GROUP BY p.category
+            ORDER BY revenue DESC
+        """).fetchall()
+
+        self.top_3_categories = [row[0] for row in all_categories[:3]]
+
+    def grade(self, submitted_answer: str) -> float:
+        answer = self._normalize(submitted_answer)
+
+        if self.ground_truth and self.ground_truth.lower() in answer:
+            return 1.0
+
+        if any(cat.lower() in answer for cat in self.top_3_categories):
+            return 0.4
+
+        return 0.0

env/utils.py
ADDED
|
@@ -0,0 +1,29 @@
|
+import re
+
+
+def normalize_answer(raw: str) -> str:
+    """Remove common LLM answer preambles and formatting."""
+    text = raw.strip().lower()
+    text = re.sub(r"the (answer|result) is:?\s*", "", text)
+    text = re.sub(r"\*+", "", text)
+    text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
+    text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
+    text = re.sub(r"\s+", " ", text)
+    return text.strip()
+
+
+FORBIDDEN_KEYWORDS = [
+    "DROP",
+    "DELETE",
+    "INSERT",
+    "UPDATE",
+    "ALTER",
+    "CREATE",
+    "TRUNCATE",
+]
+
+
+def is_safe_query(query: str) -> bool:
+    """Check if query is safe (SELECT-only)."""
+    upper = query.upper()
+    # Whole-word match: a bare substring test would reject columns like CREATED_AT.
+    return not any(re.search(rf"\b{kw}\b", upper) for kw in FORBIDDEN_KEYWORDS)

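A word-boundary match avoids the false positive a plain substring test produces on column names like `created_at` (which contains `CREATE`). A self-contained sketch of such a keyword screen (the name `keyword_screen` is illustrative):

```python
import re

FORBIDDEN = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]

def keyword_screen(query: str) -> bool:
    """True when no forbidden keyword appears as a whole word in the query."""
    upper = query.upper()
    # \b anchors prevent matching inside identifiers such as CREATED_AT.
    return not any(re.search(rf"\b{kw}\b", upper) for kw in FORBIDDEN)
```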
hf_space/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text

hf_space/README.md
ADDED
|
@@ -0,0 +1,11 @@
|
+---
+title: Sql Data Analyst
+emoji: 📉
+colorFrom: gray
+colorTo: green
+sdk: docker
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference

models.py
ADDED
|
@@ -0,0 +1,58 @@
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
+
from typing import Optional, List, Any
|
| 3 |
+
|
| 4 |
+
class Action(BaseModel):
|
| 5 |
+
"""What the agent can do each step."""
|
| 6 |
+
sql_query: Optional[str] = Field(
|
| 7 |
+
None,
|
| 8 |
+
description="A SQL SELECT query to execute against the database"
|
| 9 |
+
)
|
| 10 |
+
submit_answer: Optional[str] = Field(
|
| 11 |
+
None,
|
| 12 |
+
description="Final answer to submit. Ends the episode."
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
def is_valid(self) -> bool:
|
| 16 |
+
# Exactly one of the two must be set
|
| 17 |
+
return bool(self.sql_query) != bool(self.submit_answer)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class QueryResult(BaseModel):
|
| 21 |
+
"""Result of executing a SQL query."""
|
| 22 |
+
columns: List[str] = []
|
| 23 |
+
rows: List[List[Any]] = []
|
| 24 |
+
error: Optional[str] = None
|
| 25 |
+
truncated: bool = False
|
| 26 |
+
total_rows: int = 0
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Observation(BaseModel):
|
| 30 |
+
"""What the agent sees after each step."""
|
| 31 |
+
schema_summary: str = Field(..., description="Compact DB schema")
|
| 32 |
+
question: str = Field(..., description="Business question to answer")
|
| 33 |
+
last_query: Optional[str] = None
|
| 34 |
+
last_result: Optional[QueryResult] = None
|
| 35 |
+
last_error: Optional[str] = None
|
| 36 |
+
step: int = 0
|
| 37 |
+
max_steps: int = 20
|
| 38 |
+
hints: List[str] = []
|
| 39 |
+
done: bool = False
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class StepResult(BaseModel):
|
| 43 |
+
"""Full result returned by step()."""
|
| 44 |
+
observation: Observation
|
| 45 |
+
reward: float = 0.0
|
| 46 |
+
done: bool = False
|
| 47 |
+
info: dict = {}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
class EnvState(BaseModel):
|
| 51 |
+
"""Full environment state returned by state()."""
|
| 52 |
+
task_id: str
|
| 53 |
+
difficulty: str
|
| 54 |
+
step: int
|
| 55 |
+
max_steps: int
|
| 56 |
+
query_history: List[str] = []
|
| 57 |
+
total_reward: float = 0.0
|
| 58 |
+
done: bool = False
|
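The `Action.is_valid()` check above enforces an exclusive-or between the two fields. A minimal sketch of that rule, using a plain dataclass in place of the pydantic model so it runs with the standard library only:

```python
# Sketch of Action.is_valid()'s XOR rule from models.py, reproduced
# standalone with a dataclass instead of pydantic.BaseModel.
from dataclasses import dataclass
from typing import Optional

@dataclass
class ActionSketch:
    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None

    def is_valid(self) -> bool:
        # Valid iff exactly one of the two fields is non-empty.
        return bool(self.sql_query) != bool(self.submit_answer)

assert ActionSketch(sql_query="SELECT 1").is_valid()
assert ActionSketch(submit_answer="42").is_valid()
assert not ActionSketch().is_valid()                                     # neither set
assert not ActionSketch(sql_query="SELECT 1", submit_answer="42").is_valid()  # both set
```

Because `bool("")` is falsy, an empty-string query counts as unset, which matches how the environment treats it.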
openenv.yaml
ADDED
@@ -0,0 +1,100 @@
+name: sql-data-analyst
+version: "1.0.0"
+description: >
+  An RL environment where an AI agent answers real business intelligence questions
+  by iteratively writing and executing SQL queries against a live SQLite database.
+  Simulates the day-to-day workflow of a data analyst.
+
+tags:
+  - openenv
+  - sql
+  - data-analysis
+  - business-intelligence
+  - real-world
+
+author: sql-data-analyst
+repository: https://huggingface.co/spaces/sql-data-analyst
+
+observation_space:
+  type: dict
+  fields:
+    schema_summary:
+      type: string
+      description: Compact one-line-per-table schema of the database
+    question:
+      type: string
+      description: Natural language business question to answer
+    last_query:
+      type: string
+      nullable: true
+      description: The last SQL query executed by the agent
+    last_result:
+      type: object
+      nullable: true
+      description: Result of the last query (columns, rows, error)
+    last_error:
+      type: string
+      nullable: true
+      description: SQL error message if last query failed
+    step:
+      type: integer
+      description: Current step number
+    max_steps:
+      type: integer
+      description: Maximum steps allowed for this task
+    hints:
+      type: array
+      items: string
+      description: Progressive hints revealed as steps increase
+    done:
+      type: boolean
+      description: Whether the episode is complete
+
+action_space:
+  type: union
+  description: Agent must provide exactly one of the following
+  options:
+    sql_query:
+      type: string
+      description: A SELECT or WITH SQL query to execute
+    submit_answer:
+      type: string
+      description: Final answer to the question. Ends the episode.
+
+tasks:
+  - id: monthly_signups
+    difficulty: easy
+    max_steps: 10
+    description: "Count the number of users who signed up in the last 30 days"
+    skills_required:
+      - COUNT
+      - WHERE with date filter
+
+  - id: top_revenue_category
+    difficulty: medium
+    max_steps: 15
+    description: "Find which product category generated the most revenue in Q3"
+    skills_required:
+      - JOIN (3 tables)
+      - GROUP BY
+      - SUM aggregation
+      - Date range filtering
+
+  - id: churn_analysis
+    difficulty: hard
+    max_steps: 20
+    description: >
+      Find email addresses of users who placed exactly 3 orders and then
+      never ordered again (churned after their 3rd purchase)
+    skills_required:
+      - Subqueries
+      - HAVING clause
+      - Date logic
+      - Window functions (optional)
+
+baseline_scores:
+  monthly_signups: 0.85
+  top_revenue_category: 0.65
+  churn_analysis: 0.40
+  average: 0.63
+  model: gpt-4o-mini
progress.txt
ADDED
File without changes
pyproject.toml
ADDED
@@ -0,0 +1,42 @@
+[build-system]
+requires = ["setuptools>=61.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "sql-data-analyst"
+version = "1.0.0"
+description = "SQL Data Analyst OpenEnv - RL environment for SQL query generation"
+readme = "README.md"
+license = {text = "MIT"}
+authors = [
+    {name = "Hackathon Team", email = "team@example.com"}
+]
+requires-python = ">=3.11"
+dependencies = [
+    "openenv>=0.1.13",
+    "pydantic>=2.0",
+    "fastapi>=0.100",
+    "uvicorn>=0.20",
+    "openai>=1.0",
+    "faker>=18.0",
+    "pytest>=7.0",
+]
+
+[project.scripts]
+openenv-sql-analyst = "server.app:main"
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.0",
+    "pytest-asyncio>=0.21",
+]
+
+[tool.setuptools.packages.find]
+where = ["."]
+include = ["env*", "baseline*", "server*"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = ["test_*.py"]
+python_classes = ["Test*"]
+python_functions = ["test_*"]
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+pydantic>=2.0
+fastapi>=0.100
+uvicorn>=0.20
+openai>=1.0
+faker>=18.0
+pytest>=7.0
server/__init__.py
ADDED
File without changes
server/app.py
ADDED
@@ -0,0 +1,12 @@
+import uvicorn
+
+from env.server import app as _app
+
+# Re-export under the public name so __all__ refers to a real attribute.
+app = _app
+
+__all__ = ["app", "main"]
+
+
+def main():
+    uvicorn.run(app, host="0.0.0.0", port=7860)
+
+
+if __name__ == "__main__":
+    main()
temp_upload/baseline/__init__.py
ADDED
File without changes
temp_upload/baseline/prompts.py
ADDED
@@ -0,0 +1,78 @@
+"""
+Prompt templates for the baseline inference script.
+"""
+
+import json
+
+SYSTEM_PROMPT = """
+You are a SQL data analyst. You are given a database schema and a business question.
+Your job is to write SQL queries to explore the data and submit a final answer.
+
+Rules:
+- Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.)
+- Reply with JSON only. No explanation.
+- To run a query: {"sql_query": "SELECT ..."}
+- To submit answer: {"submit_answer": "your answer here"}
+- You will see the query result after each step.
+- Submit your answer when you are confident.
+
+Important:
+- Always use valid SQL syntax
+- Table names: users, products, orders, order_items, events
+- Dates are stored as ISO timestamps
+- Always filter orders by status='completed' for revenue calculations
+"""
+
+
+def build_prompt(obs) -> str:
+    """Build the user prompt from an observation."""
+    parts = [
+        f"Database schema:\n{obs.schema_summary}",
+        f"\nQuestion: {obs.question}",
+        f"\nStep: {obs.step} / {obs.max_steps}",
+    ]
+
+    if obs.last_query:
+        parts.append(f"\nLast query:\n{obs.last_query}")
+
+    if obs.last_result:
+        if obs.last_result.error:
+            parts.append(f"\nSQL error: {obs.last_result.error}")
+        elif obs.last_result.rows:
+            cols = obs.last_result.columns
+            rows = obs.last_result.rows[:10]
+            parts.append(f"\nResult columns: {cols}")
+            parts.append(
+                f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}"
+            )
+
+    if obs.hints:
+        parts.append(f"\nHints: {'; '.join(obs.hints)}")
+
+    parts.append("\nWhat is your next action? Reply with JSON only.")
+    return "\n".join(parts)
+
+
+def parse_action(response_text: str | None):
+    """Extract JSON action from LLM response."""
+    from env import Action
+
+    if not response_text:
+        return Action(submit_answer="")
+
+    text = response_text.strip()
+    # Strip markdown code fences the model may wrap around the JSON.
+    text = text.replace("```json", "").replace("```", "").strip()
+
+    try:
+        data = json.loads(text)
+
+        if "sql_query" in data and data["sql_query"]:
+            return Action(sql_query=data["sql_query"])
+        elif "submit_answer" in data and data["submit_answer"]:
+            return Action(submit_answer=data["submit_answer"])
+    except json.JSONDecodeError:
+        pass
+
+    # Fall back to treating the raw reply as the final answer.
+    return Action(submit_answer=text)
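The fence-stripping and JSON-fallback behaviour of `parse_action` can be exercised standalone; this sketch substitutes a plain dict for the env `Action` model so it has no package dependency:

```python
# Standalone sketch of parse_action()'s parsing logic from prompts.py,
# returning dicts instead of env.Action instances.
import json

def parse_action_sketch(response_text):
    if not response_text:
        return {"submit_answer": ""}
    # Strip markdown code fences, then try to decode JSON.
    text = response_text.strip().replace("```json", "").replace("```", "").strip()
    try:
        data = json.loads(text)
        if data.get("sql_query"):
            return {"sql_query": data["sql_query"]}
        if data.get("submit_answer"):
            return {"submit_answer": data["submit_answer"]}
    except json.JSONDecodeError:
        pass
    # Non-JSON replies become the final answer verbatim.
    return {"submit_answer": text}

assert parse_action_sketch('```json\n{"sql_query": "SELECT 1"}\n```') == {"sql_query": "SELECT 1"}
assert parse_action_sketch("not json at all") == {"submit_answer": "not json at all"}
assert parse_action_sketch(None) == {"submit_answer": ""}
```

Note the fallback is deliberate: a model that answers in prose rather than JSON still terminates the episode with a submitted answer instead of crashing the loop.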
temp_upload/baseline/run_baseline.py
ADDED
@@ -0,0 +1,162 @@
+"""
+Baseline inference script for sql-data-analyst OpenEnv.
+
+Usage:
+    export OPENAI_API_KEY=sk-...
+    python baseline/run_baseline.py
+
+Produces reproducible scores across all 3 tasks.
+"""
+
+import os
+import json
+import sys
+from pathlib import Path
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from typing import List, Dict, Any
+
+try:
+    from openai import OpenAI
+except ImportError:
+    print("Error: openai package not installed. Run: pip install openai")
+    sys.exit(1)
+
+from env import SQLAnalystEnv, Action
+from baseline.prompts import SYSTEM_PROMPT, build_prompt, parse_action
+
+
+MODEL = "gpt-4o-mini"
+MAX_STEPS = 20
+TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]
+
+
+def run_task(
+    client: OpenAI, task_id: str, max_steps: int = MAX_STEPS
+) -> Dict[str, Any]:
+    """Run a single task with the LLM agent."""
+    print(f"\n{'=' * 50}")
+    print(f"Task: {task_id}")
+    print("=" * 50)
+
+    env = SQLAnalystEnv(task_id=task_id)
+    result = env.reset()
+    obs = result.observation
+    history = []
+    total_reward = 0.0
+
+    print(f"Question: {obs.question}")
+    print(f"Schema: {obs.schema_summary[:200]}...")
+
+    for step in range(1, max_steps + 1):
+        if result.done:
+            print(f"Episode done at step {step - 1}")
+            break
+
+        user_prompt = build_prompt(obs)
+        history.append({"role": "user", "content": user_prompt})
+
+        try:
+            response = client.chat.completions.create(
+                model=MODEL,
+                messages=[
+                    {"role": "system", "content": SYSTEM_PROMPT},
+                    *history[-8:],
+                ],
+                temperature=0.0,
+            )
+        except Exception as e:
+            print(f"API Error: {e}")
+            break
+
+        reply = response.choices[0].message.content or ""
+        history.append({"role": "assistant", "content": reply})
+
+        action = parse_action(reply)
+
+        if action.sql_query:
+            print(f"Step {step}: Executing SQL...")
+            print(f"  Query: {action.sql_query[:100]}...")
+        else:
+            print(f"Step {step}: Submitting answer...")
+            print(
+                f"  Answer: {action.submit_answer[:100] if action.submit_answer else 'empty'}..."
+            )
+
+        result = env.step(action)
+        obs = result.observation
+        total_reward = result.info.get("total_reward", 0.0)
+
+        if result.done:
+            break
+
+    state = env.state()
+    print(f"\nFinal total reward: {total_reward:.3f}")
+    print(f"Steps taken: {state.step}")
+
+    return {
+        "task_id": task_id,
+        "difficulty": state.difficulty,
+        "total_reward": round(total_reward, 3),
+        "steps": state.step,
+        "max_steps": state.max_steps,
+    }
+
+
+def main():
+    api_key = os.environ.get("OPENAI_API_KEY")
+
+    if not api_key:
+        print("Error: OPENAI_API_KEY environment variable not set")
+        print("Usage: export OPENAI_API_KEY=sk-...")
+        sys.exit(1)
+
+    client = OpenAI(api_key=api_key)
+
+    print("=" * 60)
+    print("SQL Data Analyst - Baseline Inference")
+    print("=" * 60)
+    print(f"Model: {MODEL}")
+    print(f"Max steps per task: {MAX_STEPS}")
+    print(f"Tasks: {TASK_IDS}")
+
+    results = []
+    for task_id in TASK_IDS:
+        try:
+            r = run_task(client, task_id)
+            results.append(r)
+        except Exception as e:
+            print(f"Error running task {task_id}: {e}")
+            results.append(
+                {
+                    "task_id": task_id,
+                    "error": str(e),
+                    "total_reward": 0.0,
+                    "steps": 0,
+                }
+            )
+
+    print("\n" + "=" * 60)
+    print("BASELINE RESULTS")
+    print("=" * 60)
+
+    for r in results:
+        task = r.get("task_id", "unknown")
+        reward = r.get("total_reward", 0.0)
+        steps = r.get("steps", 0)
+        print(f"{task:30s} score={reward:.3f} steps={steps}")
+
+    valid_results = [r for r in results if "total_reward" in r]
+    if valid_results:
+        avg = sum(r["total_reward"] for r in valid_results) / len(valid_results)
+        print(f"\nAverage score: {avg:.3f}")
+
+    output_file = "baseline_scores.json"
+    with open(output_file, "w") as f:
+        json.dump(results, f, indent=2)
+    print(f"\nSaved results to {output_file}")
+
+
+if __name__ == "__main__":
+    main()
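The reset/step loop that `run_task` drives can be shown in miniature with a stub environment in place of `SQLAnalystEnv` and a scripted agent in place of the LLM; everything here is illustrative, not part of the package:

```python
# Toy sketch of the reset()/step() episode loop from run_baseline.py,
# with a hypothetical stub env and hard-coded actions standing in for
# SQLAnalystEnv and the OpenAI-driven agent.
from dataclasses import dataclass

@dataclass
class StubResult:
    observation: str
    reward: float
    done: bool

class StubEnv:
    def __init__(self, answer: str):
        self.answer = answer
        self.step_count = 0

    def reset(self) -> StubResult:
        self.step_count = 0
        return StubResult("question: how many users?", 0.0, False)

    def step(self, action: dict) -> StubResult:
        self.step_count += 1
        if "submit_answer" in action:
            # Terminal reward only when the submitted answer matches.
            reward = 1.0 if action["submit_answer"] == self.answer else 0.0
            return StubResult("episode over", reward, True)
        return StubResult("rows: [[7]]", 0.0, False)

env = StubEnv(answer="7")
result = env.reset()
total = 0.0
for action in [{"sql_query": "SELECT COUNT(*) FROM users"}, {"submit_answer": "7"}]:
    result = env.step(action)
    total += result.reward
    if result.done:
        break

assert total == 1.0
assert env.step_count == 2
```

The real script adds prompt construction, history truncation (`history[-8:]`), and error handling around the API call, but the control flow is the same.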
temp_upload/env/__init__.py
ADDED
@@ -0,0 +1,13 @@
+from .models import Action, QueryResult, Observation, StepResult, EnvState
+from .environment import SQLAnalystEnv
+from .tasks import TASKS
+
+__all__ = [
+    "Action",
+    "QueryResult",
+    "Observation",
+    "StepResult",
+    "EnvState",
+    "SQLAnalystEnv",
+    "TASKS",
+]
temp_upload/env/database.py
ADDED
@@ -0,0 +1,267 @@
+import sqlite3
+import random
+from datetime import datetime, timedelta
+from typing import Optional, Any
+from faker import Faker
+
+fake = Faker()
+
+SEED_CONFIG = {
+    "users": 500,
+    "products": 80,
+    "orders": 2000,
+    "order_items": 5000,
+    "events": 8000,
+}
+
+CATEGORIES = ["Electronics", "Clothing", "Books", "Home & Garden", "Sports"]
+PLAN_TYPES = ["free", "pro", "enterprise"]
+ORDER_STATUSES = ["pending", "completed", "refunded"]
+EVENT_TYPES = ["page_view", "add_to_cart", "checkout", "login", "logout"]
+
+
+def create_database(db_path: str = ":memory:") -> sqlite3.Connection:
+    conn = sqlite3.connect(db_path)
+    conn.row_factory = sqlite3.Row
+
+    conn.execute("""
+        CREATE TABLE users (
+            id INTEGER PRIMARY KEY,
+            email TEXT NOT NULL,
+            country TEXT,
+            plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
+            created_at TIMESTAMP NOT NULL,
+            churned_at TIMESTAMP
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE products (
+            id INTEGER PRIMARY KEY,
+            name TEXT NOT NULL,
+            category TEXT NOT NULL,
+            price REAL,
+            cost REAL
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE orders (
+            id INTEGER PRIMARY KEY,
+            user_id INTEGER REFERENCES users(id),
+            created_at TIMESTAMP NOT NULL,
+            status TEXT CHECK(status IN ('pending', 'completed', 'refunded')),
+            total REAL
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE order_items (
+            id INTEGER PRIMARY KEY,
+            order_id INTEGER REFERENCES orders(id),
+            product_id INTEGER REFERENCES products(id),
+            qty INTEGER NOT NULL,
+            unit_price REAL
+        )
+    """)
+
+    conn.execute("""
+        CREATE TABLE events (
+            id INTEGER PRIMARY KEY,
+            user_id INTEGER REFERENCES users(id),
+            event_type TEXT,
+            metadata TEXT,
+            ts TIMESTAMP NOT NULL
+        )
+    """)
+
+    conn.commit()
+    return conn
+
+
+def seed_database(conn: sqlite3.Connection) -> None:
+    users = _seed_users(conn)
+    products = _seed_products(conn)
+    orders, order_items = _seed_orders(conn, users, products)
+    _seed_events(conn, users, orders)
+
+
+def _seed_users(conn: sqlite3.Connection) -> list:
+    users = []
+    now = datetime.now()
+    base_date = now - timedelta(days=180)
+    recent_date = now - timedelta(days=30)
+
+    for i in range(SEED_CONFIG["users"]):
+        if random.random() < 0.3:
+            created_at = recent_date + timedelta(days=random.randint(0, 30))
+        else:
+            created_at = base_date + timedelta(days=random.randint(0, 180))
+
+        country = random.choice([fake.country(), None, None, None, None])
+        plan = random.choice(PLAN_TYPES)
+        churned_at = None
+
+        if plan == "free" and random.random() < 0.1:
+            churned_at = created_at + timedelta(days=random.randint(30, 150))
+
+        conn.execute(
+            "INSERT INTO users (email, country, plan, created_at, churned_at) VALUES (?, ?, ?, ?, ?)",
+            (
+                fake.email(),
+                country,
+                plan,
+                created_at.isoformat(),
+                churned_at.isoformat() if churned_at else None,
+            ),
+        )
+        users.append((i + 1, created_at))
+
+    conn.commit()
+    return users
+
+
+def _seed_products(conn: sqlite3.Connection) -> list:
+    products = []
+
+    for i in range(SEED_CONFIG["products"]):
+        category = random.choice(CATEGORIES)
+        price = round(random.uniform(10, 500), 2)
+        cost = round(price * random.uniform(0.3, 0.7), 2)
+
+        conn.execute(
+            "INSERT INTO products (name, category, price, cost) VALUES (?, ?, ?, ?)",
+            (fake.catch_phrase(), category, price, cost),
+        )
+        products.append((i + 1, category, price))
+
+    conn.commit()
+    return products
+
+
+def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple:
+    orders = []
+    order_items = []
+
+    q3_start = datetime(2024, 7, 1)
+    q3_end = datetime(2024, 9, 30)
+    recent_date = datetime.now()
+    old_date = datetime(2024, 1, 1)
+
+    for i in range(SEED_CONFIG["orders"]):
+        user_id = random.choice(users)[0]
+
+        if random.random() < 0.2:
+            created_at = q3_start + timedelta(days=random.randint(0, 91))
+        else:
+            created_at = old_date + timedelta(days=random.randint(0, 180))
+
+        status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0]
+
+        conn.execute(
+            "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)",
+            (user_id, created_at.isoformat(), status, 0),
+        )
+
+        order_id = i + 1
+        order_total = 0
+
+        num_items = random.randint(1, 5)
+        for _ in range(num_items):
+            product = random.choice(products)
+            qty = random.randint(1, 3)
+            unit_price = product[2]
+            order_total += qty * unit_price
+
+            conn.execute(
+                "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)",
+                (order_id, product[0], qty, unit_price),
+            )
+
+        conn.execute(
+            "UPDATE orders SET total = ? WHERE id = ?",
+            (round(order_total, 2), order_id),
+        )
+        orders.append((order_id, user_id, created_at, status))
+
+    conn.commit()
+    return orders, order_items
+
+
+def _seed_events(conn: sqlite3.Connection, users: list, orders: list) -> None:
+    base_date = datetime.now() - timedelta(days=180)
+
+    for _ in range(SEED_CONFIG["events"]):
+        user_id = random.choice(users)[0]
+        ts = base_date + timedelta(
+            days=random.randint(0, 180), hours=random.randint(0, 23)
+        )
+        event_type = random.choice(EVENT_TYPES)
+        metadata = '{"page": "/' + fake.uri_path() + '"}'
+
+        conn.execute(
+            "INSERT INTO events (user_id, event_type, metadata, ts) VALUES (?, ?, ?, ?)",
+            (user_id, event_type, metadata, ts.isoformat()),
+        )
+
+    conn.commit()
+
+
+def get_schema_summary(conn: sqlite3.Connection) -> str:
+    cursor = conn.execute(
+        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
+    )
+    tables = [r[0] for r in cursor.fetchall()]
+
+    lines = []
+    for table in tables:
+        cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
+        col_names = [c[1] for c in cols]
+        lines.append(f"{table}: ({', '.join(col_names)})")
+
+    return "\n".join(lines)
+
+
+def get_ground_truth(conn: sqlite3.Connection, task_id: str) -> Any:
+    if task_id == "monthly_signups":
+        result = conn.execute(
+            "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
+        ).fetchone()
+        return result[0]
+
+    elif task_id == "top_revenue_category":
+        result = conn.execute("""
+            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
+            FROM orders o
+            JOIN order_items oi ON o.id = oi.order_id
+            JOIN products p ON oi.product_id = p.id
+            WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
+              AND o.status = 'completed'
+            GROUP BY p.category
+            ORDER BY revenue DESC
+            LIMIT 1
+        """).fetchone()
+        return result[0] if result else None
+
+    elif task_id == "churn_analysis":
+        result = conn.execute("""
+            WITH order_counts AS (
+                SELECT user_id, COUNT(*) AS total_orders,
+                       MAX(created_at) AS last_order_date
+                FROM orders
+                WHERE status = 'completed'
+                GROUP BY user_id
+                HAVING COUNT(*) = 3
+            ),
+            churned AS (
+                SELECT oc.user_id
+                FROM order_counts oc
+                WHERE oc.last_order_date < DATE('now', '-90 days')
+            )
+            SELECT u.email
+            FROM users u
+            JOIN churned c ON u.id = c.user_id
+        """).fetchall()
+        return {row[0].lower() for row in result}
+
+    return None
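`get_schema_summary` builds its one-line-per-table summary from `sqlite_master` plus `PRAGMA table_info`; the same technique works standalone against any SQLite connection with the standard library alone:

```python
# Self-contained demo of the get_schema_summary() technique: enumerate
# tables via sqlite_master, then read column names with PRAGMA table_info.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT)")
conn.execute("CREATE TABLE orders (id INTEGER PRIMARY KEY, user_id INTEGER)")

tables = [r[0] for r in conn.execute(
    "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
)]
lines = []
for table in tables:
    # table_info rows are (cid, name, type, notnull, dflt_value, pk).
    cols = [c[1] for c in conn.execute(f"PRAGMA table_info({table})")]
    lines.append(f"{table}: ({', '.join(cols)})")

summary = "\n".join(lines)
assert summary == "orders: (id, user_id)\nusers: (id, email)"
```

Keeping the schema to one line per table is what makes it cheap to resend in every observation without blowing up the prompt.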
temp_upload/env/environment.py
ADDED
@@ -0,0 +1,134 @@
+import sqlite3
+from typing import Optional
+from .models import Action, Observation, StepResult, EnvState, QueryResult
+from .database import create_database, seed_database, get_schema_summary
+from .reward import RewardCalculator
+from .tasks import TASKS
+
+
+class SQLAnalystEnv:
+    """
+    OpenEnv-compliant SQL Data Analyst environment.
+
+    An agent must answer business questions by iteratively
+    writing and executing SQL queries.
+    """
+
+    def __init__(self, task_id: str = "monthly_signups"):
+        assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
+        self.task_id = task_id
+        self.task = TASKS[task_id]
+        self.conn: Optional[sqlite3.Connection] = None
+        self.step_count: int = 0
+        self.total_reward: float = 0.0
+        self.done: bool = False
+        self._query_history: list = []
+        self._reward_calc = RewardCalculator()
+
+    def reset(self) -> StepResult:
+        """Reset environment. Reseed DB. Return initial observation."""
+        if self.conn:
+            self.conn.close()
+
+        self.conn = create_database()
+        seed_database(self.conn)
+        self.step_count = 0
+        self.total_reward = 0.0
+        self.done = False
+        self._query_history = []
+
+        self.task.compute_ground_truth(self.conn)
+
+        obs = Observation(
+            schema_summary=get_schema_summary(self.conn),
+            question=self.task.question,
+            step=0,
+            max_steps=self.task.max_steps,
+        )
+        return StepResult(observation=obs, reward=0.0, done=False)
+
+    def step(self, action: Action) -> StepResult:
+        """Execute one agent action. Return (observation, reward, done, info)."""
+        assert self.conn is not None, "Call reset() before step()"
+        assert not self.done, "Episode is done. Call reset()."
+        assert action.is_valid(), (
+            "Action must have exactly one of: sql_query, submit_answer"
+        )
+
+        self.step_count += 1
+        query_result = None
+        error = None
+
+        if action.sql_query:
+            query_result = self._execute_sql(action.sql_query)
+            self._query_history.append(action.sql_query)
+            error = query_result.error
+
+        terminal = (
+            action.submit_answer is not None or self.step_count >= self.task.max_steps
+        )
+
+        reward = self._reward_calc.calculate(
+            action=action,
+            result=query_result,
+            task=self.task,
+            step=self.step_count,
+            query_history=self._query_history,
+            terminal=terminal,
+        )
+        self.total_reward += reward
+        self.done = terminal
+
+        obs = Observation(
+            schema_summary=get_schema_summary(self.conn),
+            question=self.task.question,
+            last_query=action.sql_query,
+            last_result=query_result,
+            last_error=error,
+            step=self.step_count,
+            max_steps=self.task.max_steps,
+            hints=self.task.get_hints(self.step_count),
+            done=self.done,
+        )
+
+        return StepResult(
+            observation=obs,
+            reward=round(reward, 3),
+            done=self.done,
+            info={
+                "step": self.step_count,
+                "total_reward": round(self.total_reward, 3),
|
| 101 |
+
"task_id": self.task_id,
|
| 102 |
+
},
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def state(self) -> EnvState:
|
| 106 |
+
"""Return current full state of the environment."""
|
| 107 |
+
return EnvState(
|
| 108 |
+
task_id=self.task_id,
|
| 109 |
+
difficulty=self.task.difficulty,
|
| 110 |
+
step=self.step_count,
|
| 111 |
+
max_steps=self.task.max_steps,
|
| 112 |
+
query_history=self._query_history.copy(),
|
| 113 |
+
total_reward=round(self.total_reward, 3),
|
| 114 |
+
done=self.done,
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
def _execute_sql(self, query: str) -> QueryResult:
|
| 118 |
+
"""Execute SQL safely. Block non-SELECT. Return up to 50 rows."""
|
| 119 |
+
q = query.strip().upper()
|
| 120 |
+
if not q.startswith("SELECT") and not q.startswith("WITH"):
|
| 121 |
+
return QueryResult(error="Only SELECT / WITH queries are allowed.")
|
| 122 |
+
try:
|
| 123 |
+
cursor = self.conn.execute(query)
|
| 124 |
+
cols = [d[0] for d in cursor.description] if cursor.description else []
|
| 125 |
+
rows = cursor.fetchmany(50)
|
| 126 |
+
total = len(rows)
|
| 127 |
+
return QueryResult(
|
| 128 |
+
columns=cols,
|
| 129 |
+
rows=[list(r) for r in rows],
|
| 130 |
+
truncated=(total == 50),
|
| 131 |
+
total_rows=total,
|
| 132 |
+
)
|
| 133 |
+
except Exception as e:
|
| 134 |
+
return QueryResult(error=str(e))
|
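The `truncated=(total == 50)` heuristic cannot tell a result of exactly 50 rows apart from a larger one. A common variant fetches one row past the limit; here is a standalone sketch of the same guard-and-truncate logic using stdlib sqlite3 (the function and table names are illustrative, not part of the environment's API):

```python
import sqlite3

ROW_LIMIT = 50  # same cap as _execute_sql

def run_readonly(conn: sqlite3.Connection, query: str, limit: int = ROW_LIMIT):
    """Run a SELECT/WITH query; return (columns, rows, truncated)."""
    q = query.strip().upper()
    if not (q.startswith("SELECT") or q.startswith("WITH")):
        raise ValueError("Only SELECT / WITH queries are allowed.")
    cur = conn.execute(query)
    cols = [d[0] for d in cur.description] if cur.description else []
    rows = cur.fetchmany(limit + 1)  # fetch one extra row to detect truncation exactly
    truncated = len(rows) > limit
    return cols, [list(r) for r in rows[:limit]], truncated

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (id INTEGER)")
conn.executemany("INSERT INTO users VALUES (?)", [(i,) for i in range(60)])
cols, rows, truncated = run_readonly(conn, "SELECT id FROM users")

# the guard rejects anything that is not SELECT/WITH
try:
    run_readonly(conn, "DELETE FROM users")
    guard_raised = False
except ValueError:
    guard_raised = True
```

With 60 rows in the table, the call returns the first 50 rows and reports truncation unambiguously.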
temp_upload/env/models.py
ADDED

from pydantic import BaseModel, Field
from typing import Optional, List, Any


class Action(BaseModel):
    """What the agent can do each step."""

    sql_query: Optional[str] = Field(
        None,
        description="A SQL SELECT query to execute against the database",
    )
    submit_answer: Optional[str] = Field(
        None,
        description="Final answer to submit. Ends the episode.",
    )

    def is_valid(self) -> bool:
        # Exactly one of the two must be set
        return bool(self.sql_query) != bool(self.submit_answer)


class QueryResult(BaseModel):
    """Result of executing a SQL query."""

    columns: List[str] = []
    rows: List[List[Any]] = []
    error: Optional[str] = None
    truncated: bool = False
    total_rows: int = 0


class Observation(BaseModel):
    """What the agent sees after each step."""

    schema_summary: str = Field(..., description="Compact DB schema")
    question: str = Field(..., description="Business question to answer")
    last_query: Optional[str] = None
    last_result: Optional[QueryResult] = None
    last_error: Optional[str] = None
    step: int = 0
    max_steps: int = 20
    hints: List[str] = []
    done: bool = False


class StepResult(BaseModel):
    """Full result returned by step()."""

    observation: Observation
    reward: float = 0.0
    done: bool = False
    info: dict = {}


class EnvState(BaseModel):
    """Full environment state returned by state()."""

    task_id: str
    difficulty: str
    step: int
    max_steps: int
    query_history: List[str] = []
    total_reward: float = 0.0
    done: bool = False
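The `is_valid` check above is a boolean XOR: exactly one of the two fields must be truthy, so an empty-string query also counts as unset. A standalone sketch of the same check with a plain dataclass standing in for the pydantic model (`ActionSketch` is a hypothetical name, not part of the package):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ActionSketch:
    """Plain-dataclass stand-in for the pydantic Action model."""
    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None

    def is_valid(self) -> bool:
        # XOR on truthiness: exactly one field set; "" counts as unset
        return bool(self.sql_query) != bool(self.submit_answer)
```

Setting both fields, neither, or only an empty string all fail validation; exactly one non-empty field passes.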
temp_upload/env/reward.py
ADDED

from typing import Optional, List, Any

from .models import Action, QueryResult


class RewardCalculator:
    """Calculate rewards for agent actions in the SQL analyst environment."""

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Calculate reward based on action, result, and task."""
        reward = 0.0

        if action.sql_query and result:
            if not result.error:
                reward += 0.15  # query executed without error

                relevant = self._count_relevant_tables(
                    action.sql_query, task.relevant_tables
                )
                if relevant > 0:
                    reward += 0.10  # query touches at least one relevant table

                if result.rows:
                    reward += 0.05  # query returned data

                if result.rows and len(result.rows) < 1000:
                    reward += 0.05  # result set is reasonably sized

        if step > 3:
            reward -= 0.02 * (step - 3)  # per-step penalty after step 3

        if self._is_stuck(query_history):
            reward -= 0.10  # same query repeated three times in a row

        if terminal and action.submit_answer:
            task_score = task.grade(action.submit_answer)
            reward += task_score * 0.60  # final answer dominates the reward

        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
        query_lower = query.lower()
        return sum(1 for t in relevant_tables if t.lower() in query_lower)

    def _is_stuck(self, history: List[str]) -> bool:
        if len(history) < 3:
            return False
        return len(set(history[-3:])) == 1
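Two details of the shaping above are easy to miss: the stuck penalty only fires once three identical queries sit at the tail of the history, and the per-step reward is clamped to [0, 1] after all bonuses and penalties. A standalone sketch of both helpers (these mirror, but do not import, the class above):

```python
def is_stuck(history: list) -> bool:
    """True when the last three queries are identical (mirrors _is_stuck)."""
    return len(history) >= 3 and len(set(history[-3:])) == 1

def clip_reward(reward: float) -> float:
    """Clamp a per-step reward to [0, 1], as the last line of calculate() does."""
    return max(0.0, min(1.0, reward))
```

Note that only the trailing window matters: a repeated query earlier in the episode does not trigger the penalty.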
temp_upload/env/server.py
ADDED

"""
FastAPI server for SQL Data Analyst OpenEnv.

Provides REST and WebSocket endpoints for HuggingFace Spaces deployment.
"""

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from typing import Optional, Dict, Any
import json

from env import SQLAnalystEnv, Action


app = FastAPI(title="SQL Data Analyst Environment")

envs: Dict[str, SQLAnalystEnv] = {}


class ResetRequest(BaseModel):
    task_id: str = "monthly_signups"


class StepRequest(BaseModel):
    session_id: str
    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None


class StateRequest(BaseModel):
    session_id: str


@app.get("/")
async def root():
    return {
        "name": "sql-data-analyst",
        "version": "1.0.0",
        "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation",
    }


@app.post("/reset")
async def reset(req: ResetRequest) -> Dict[str, Any]:
    # Sessions are keyed by task_id, so concurrent clients on the same task share one environment.
    session_id = req.task_id

    env = SQLAnalystEnv(task_id=req.task_id)
    result = env.reset()
    envs[session_id] = env

    return {
        "session_id": session_id,
        "observation": {
            "schema_summary": result.observation.schema_summary,
            "question": result.observation.question,
            "step": result.observation.step,
            "max_steps": result.observation.max_steps,
            "hints": result.observation.hints,
            "done": result.observation.done,
        },
        "reward": result.reward,
        "done": result.done,
    }


@app.post("/step")
async def step(req: StepRequest) -> Dict[str, Any]:
    session_id = req.session_id

    if session_id not in envs:
        return {"error": "Session not found. Call /reset first."}

    env = envs[session_id]

    action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer)

    result = env.step(action)

    last_result = result.observation.last_result
    return {
        "observation": {
            "schema_summary": result.observation.schema_summary,
            "question": result.observation.question,
            "last_query": result.observation.last_query,
            "last_result": {
                "columns": last_result.columns,
                "rows": last_result.rows,
                "error": last_result.error,
            }
            if last_result
            else None,
            "last_error": result.observation.last_error,
            "step": result.observation.step,
            "max_steps": result.observation.max_steps,
            "hints": result.observation.hints,
            "done": result.observation.done,
        },
        "reward": result.reward,
        "done": result.done,
        "info": result.info,
    }


@app.post("/state")
async def state(req: StateRequest) -> Dict[str, Any]:
    session_id = req.session_id

    if session_id not in envs:
        return {"error": "Session not found. Call /reset first."}

    env = envs[session_id]
    state = env.state()

    return {
        "task_id": state.task_id,
        "difficulty": state.difficulty,
        "step": state.step,
        "max_steps": state.max_steps,
        "query_history": state.query_history,
        "total_reward": state.total_reward,
        "done": state.done,
    }


@app.post("/delete")
async def delete_session(req: StateRequest) -> Dict[str, str]:
    session_id = req.session_id

    if session_id in envs:
        del envs[session_id]
        return {"status": "deleted", "session_id": session_id}

    return {"status": "not_found", "session_id": session_id}


@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket):
    await websocket.accept()

    session_id = None
    env = None

    try:
        while True:
            data = await websocket.receive_text()
            message = json.loads(data)

            action_type = message.get("type")

            if action_type == "reset":
                task_id = message.get("task_id", "monthly_signups")
                env = SQLAnalystEnv(task_id=task_id)
                result = env.reset()
                session_id = task_id
                envs[session_id] = env

                await websocket.send_json(
                    {
                        "type": "reset",
                        "observation": {
                            "schema_summary": result.observation.schema_summary,
                            "question": result.observation.question,
                            "step": result.observation.step,
                            "max_steps": result.observation.max_steps,
                            "hints": result.observation.hints,
                        },
                        "reward": result.reward,
                        "done": result.done,
                    }
                )

            elif action_type == "step":
                if not env:
                    await websocket.send_json({"error": "Call reset first"})
                    continue

                action = Action(
                    sql_query=message.get("sql_query"),
                    submit_answer=message.get("submit_answer"),
                )

                result = env.step(action)

                last_result = result.observation.last_result
                await websocket.send_json(
                    {
                        "type": "step",
                        "observation": {
                            "schema_summary": result.observation.schema_summary,
                            "question": result.observation.question,
                            "last_query": result.observation.last_query,
                            "last_result": {
                                "columns": last_result.columns,
                                "rows": last_result.rows,
                                "error": last_result.error,
                            }
                            if last_result
                            else None,
                            "step": result.observation.step,
                            "hints": result.observation.hints,
                            "done": result.observation.done,
                        },
                        "reward": result.reward,
                        "done": result.done,
                        "info": result.info,
                    }
                )

            elif action_type == "state":
                if not env:
                    await websocket.send_json({"error": "Call reset first"})
                    continue

                state = env.state()

                await websocket.send_json(
                    {
                        "type": "state",
                        "task_id": state.task_id,
                        "difficulty": state.difficulty,
                        "step": state.step,
                        "max_steps": state.max_steps,
                        "query_history": state.query_history,
                        "total_reward": state.total_reward,
                        "done": state.done,
                    }
                )

            elif action_type == "close":
                if session_id and session_id in envs:
                    del envs[session_id]
                break

    except WebSocketDisconnect:
        pass
    except Exception as e:
        await websocket.send_json({"error": str(e)})


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
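The REST endpoints accept JSON bodies matching the pydantic request models above. A sketch of the payloads one session would send (the SQL string and values are illustrative; `/reset` keys the session by `task_id`):

```python
import json

# /reset body (ResetRequest)
reset_body = {"task_id": "monthly_signups"}

# /step body (StepRequest): exactly one of sql_query / submit_answer may be set
step_body = {
    "session_id": "monthly_signups",
    "sql_query": "SELECT COUNT(*) FROM users",
    "submit_answer": None,
}

# /state and /delete bodies (StateRequest)
state_body = {"session_id": "monthly_signups"}

payload = json.dumps(step_body)
```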
temp_upload/env/tasks/__init__.py
ADDED

from .base import BaseTask
from .easy import MonthlySignupsTask
from .medium import TopRevenueCategoryTask
from .hard import ChurnAnalysisTask


TASKS = {
    "monthly_signups": MonthlySignupsTask(),
    "top_revenue_category": TopRevenueCategoryTask(),
    "churn_analysis": ChurnAnalysisTask(),
}


__all__ = [
    "BaseTask",
    "MonthlySignupsTask",
    "TopRevenueCategoryTask",
    "ChurnAnalysisTask",
    "TASKS",
]
temp_upload/env/tasks/base.py
ADDED

from abc import ABC, abstractmethod
import sqlite3
import re
from typing import Any, List, Optional


class BaseTask(ABC):
    """Abstract base class for all tasks."""

    task_id: str
    difficulty: str
    max_steps: int
    question: str
    relevant_tables: List[str]
    sql_hint: str

    def __init__(self):
        self.ground_truth: Any = None
        self.top_3_categories: List[str] = []

    @abstractmethod
    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        """Compute ground truth after database seeding."""

    @abstractmethod
    def grade(self, submitted_answer: str) -> float:
        """Grade the submitted answer. Returns a score in [0.0, 1.0]."""

    def get_hints(self, step: int) -> List[str]:
        """Return progressive hints based on the current step."""
        hints = []
        if step > 5:
            hints.append(
                f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}"
            )
        if step > 10:
            hints.append(f"Hint: Try using {self.sql_hint}")
        if step > 15:
            hints.append("Hint: Make sure to submit your answer with submit_answer.")
        return hints

    def _normalize(self, text: str) -> str:
        """Remove common LLM formatting and normalize text."""
        text = text.strip().lower()
        text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
        text = re.sub(r"\*+", "", text)
        text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
        text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
        text = re.sub(r"\s+", " ", text)
        return text.strip()
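The `_normalize` pipeline strips typical LLM answer wrappers (preambles, bold markers, code fences, inline code) before grading. A standalone copy shows the effect; the fence-stripping pattern is written with a `{3}` quantifier here, which matches the same triple-backtick sequence:

```python
import re

def normalize(text: str) -> str:
    """Standalone copy of the BaseTask._normalize pipeline."""
    text = text.strip().lower()
    text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
    text = re.sub(r"\*+", "", text)                            # bold/italic markers
    text = re.sub(r"`{3}.*?`{3}", "", text, flags=re.DOTALL)   # fenced code blocks
    text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)  # inline code
    text = re.sub(r"\s+", " ", text)
    return text.strip()
```

So an answer like "The answer is: **Electronics**" grades the same as a bare "electronics".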
temp_upload/env/tasks/easy.py
ADDED

import sqlite3
from .base import BaseTask


class MonthlySignupsTask(BaseTask):
    """Task 1 — Easy: Count users who signed up in the last 30 days."""

    task_id = "monthly_signups"
    difficulty = "easy"
    max_steps = 10
    question = "How many users signed up in the last 30 days?"
    relevant_tables = ["users"]
    sql_hint = "COUNT(*) with WHERE clause on created_at"

    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        result = conn.execute(
            "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
        ).fetchone()
        self.ground_truth = result[0] if result else 0

    def grade(self, submitted_answer: str) -> float:
        try:
            val = int(submitted_answer.strip().replace(",", ""))
            if val == self.ground_truth:
                return 1.0
            if abs(val - self.ground_truth) <= 3:
                return 0.6
            if abs(val - self.ground_truth) <= 10:
                return 0.3
        except (ValueError, AttributeError):
            pass
        return 0.0
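The grading above is banded: an exact count scores 1.0, within 3 scores 0.6, within 10 scores 0.3, anything else (including non-numeric answers) scores 0.0. A sketch of the same bands as a pure function (`grade_count` is a hypothetical standalone name):

```python
def grade_count(submitted: str, truth: int) -> float:
    """Banded scoring: exact 1.0, within 3 scores 0.6, within 10 scores 0.3."""
    try:
        val = int(submitted.strip().replace(",", ""))
    except (ValueError, AttributeError):
        return 0.0
    diff = abs(val - truth)
    if diff == 0:
        return 1.0
    if diff <= 3:
        return 0.6
    if diff <= 10:
        return 0.3
    return 0.0
```

The comma-stripping means "1,234" is accepted as 1234 before comparison.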
temp_upload/env/tasks/hard.py
ADDED

import sqlite3
from .base import BaseTask


class ChurnAnalysisTask(BaseTask):
    """Task 3 — Hard: Find users who placed exactly 3 orders and then churned."""

    task_id = "churn_analysis"
    difficulty = "hard"
    max_steps = 20
    question = (
        "Find the email addresses of users who placed exactly 3 orders and then "
        "never ordered again (churned after their 3rd purchase). "
        "Return as a comma-separated list."
    )
    relevant_tables = ["users", "orders"]
    sql_hint = "CTE with COUNT and HAVING"

    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        result = conn.execute("""
            WITH order_counts AS (
                SELECT user_id, COUNT(*) AS total_orders,
                       MAX(created_at) AS last_order_date
                FROM orders
                WHERE status = 'completed'
                GROUP BY user_id
                HAVING COUNT(*) = 3
            ),
            churned AS (
                SELECT oc.user_id
                FROM order_counts oc
                WHERE oc.last_order_date < DATE('now', '-90 days')
            )
            SELECT u.email
            FROM users u
            JOIN churned c ON u.id = c.user_id
        """).fetchall()

        self.ground_truth = {row[0].lower() for row in result}

    def grade(self, submitted_answer: str) -> float:
        if not submitted_answer.strip():
            return 0.0

        submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e}

        if not submitted:
            return 0.0

        correct = self.ground_truth  # already lowercased in compute_ground_truth
        tp = len(submitted & correct)

        if tp == 0:
            return 0.0

        precision = tp / len(submitted)
        recall = tp / len(correct)

        f1 = 2 * precision * recall / (precision + recall)
        return round(f1, 3)
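Grading here is set-based F1 over the email addresses extracted from the comma-separated answer, so partial lists earn partial credit. A standalone sketch of the same arithmetic (the addresses are invented for illustration):

```python
def f1_grade(submitted_answer: str, correct: set) -> float:
    """Set-based F1 over extracted email addresses, mirroring grade() above."""
    submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e}
    if not submitted or not correct:
        return 0.0
    tp = len(submitted & correct)
    if tp == 0:
        return 0.0
    precision = tp / len(submitted)  # fraction of submitted emails that are correct
    recall = tp / len(correct)       # fraction of correct emails that were submitted
    return round(2 * precision * recall / (precision + recall), 3)

truth = {"a@example.com", "b@example.com"}
```

Submitting one of two correct emails gives precision 1.0 and recall 0.5, so an F1 of about 0.667 rather than zero.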
temp_upload/env/tasks/medium.py
ADDED

import sqlite3
from .base import BaseTask


class TopRevenueCategoryTask(BaseTask):
    """Task 2 — Medium: Find the product category with the most revenue in Q3."""

    task_id = "top_revenue_category"
    difficulty = "medium"
    max_steps = 15
    question = (
        "Which product category generated the most revenue in Q3 (July-September)?"
    )
    relevant_tables = ["orders", "order_items", "products"]
    sql_hint = "JOIN with GROUP BY and ORDER BY"

    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        result = conn.execute("""
            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
            FROM orders o
            JOIN order_items oi ON o.id = oi.order_id
            JOIN products p ON oi.product_id = p.id
            WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
              AND o.status = 'completed'
            GROUP BY p.category
            ORDER BY revenue DESC
            LIMIT 1
        """).fetchone()

        self.ground_truth = result[0] if result else None

        # Keep the full ranking so grade() can give partial credit for a top-3 category.
        all_categories = conn.execute("""
            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
            FROM orders o
            JOIN order_items oi ON o.id = oi.order_id
            JOIN products p ON oi.product_id = p.id
            WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
              AND o.status = 'completed'
            GROUP BY p.category
            ORDER BY revenue DESC
        """).fetchall()

        self.top_3_categories = [row[0] for row in all_categories[:3]]

    def grade(self, submitted_answer: str) -> float:
        answer = self._normalize(submitted_answer)

        if self.ground_truth and self.ground_truth.lower() in answer:
            return 1.0

        if any(cat.lower() in answer for cat in self.top_3_categories):
            return 0.4

        return 0.0
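The ground-truth query can be exercised against a toy in-memory schema. The table layout below is inferred from the query itself; the column names are assumptions and the rows are invented for illustration:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.executescript("""
    CREATE TABLE products (id INTEGER PRIMARY KEY, category TEXT);
    CREATE TABLE orders (id INTEGER PRIMARY KEY, status TEXT, created_at TEXT);
    CREATE TABLE order_items (order_id INTEGER, product_id INTEGER,
                              qty INTEGER, unit_price REAL);
    INSERT INTO products VALUES (1, 'Electronics'), (2, 'Books');
    -- two completed Q3 orders and one pending order that must be excluded
    INSERT INTO orders VALUES (1, 'completed', '2024-07-15'),
                              (2, 'completed', '2024-08-02'),
                              (3, 'pending',   '2024-08-03');
    INSERT INTO order_items VALUES (1, 1, 2, 100.0),  -- 200.0 of Electronics
                                   (2, 2, 5, 10.0),   -- 50.0 of Books
                                   (3, 1, 9, 100.0);  -- pending, so ignored
""")

top = conn.execute("""
    SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
    FROM orders o
    JOIN order_items oi ON o.id = oi.order_id
    JOIN products p ON oi.product_id = p.id
    WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
      AND o.status = 'completed'
    GROUP BY p.category
    ORDER BY revenue DESC
    LIMIT 1
""").fetchone()
```

On this toy data only the two completed orders count, so Electronics wins with 200.0 of revenue; the large pending order never contributes.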
temp_upload/env/utils.py
ADDED

import re


def normalize_answer(raw: str) -> str:
    """Remove common LLM answer preambles and formatting."""
    text = raw.strip().lower()
    text = re.sub(r"the (answer|result) is:?\s*", "", text)
    text = re.sub(r"\*+", "", text)
    text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
    text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


FORBIDDEN_KEYWORDS = [
    "DROP",
    "DELETE",
    "INSERT",
    "UPDATE",
    "ALTER",
    "CREATE",
    "TRUNCATE",
]


def is_safe_query(query: str) -> bool:
    """Check that a query is read-only (no forbidden keyword as a whole word)."""
    upper = query.upper()
    # Match whole words only, so column names like created_at don't trip the CREATE check.
    return not any(re.search(rf"\b{kw}\b", upper) for kw in FORBIDDEN_KEYWORDS)
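A bare substring scan over-matches here: `"CREATED_AT"` contains `CREATE`, so an innocent `SELECT created_at ...` would be rejected. Matching the keywords as whole words avoids that; a standalone sketch (`is_readonly_query` is a hypothetical name, not the module's function):

```python
import re

FORBIDDEN = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]

def is_readonly_query(query: str) -> bool:
    """Reject a query only when a forbidden keyword appears as a whole word."""
    upper = query.upper()
    # \b word boundaries: CREATE matches "CREATE TABLE" but not "CREATED_AT"
    return not any(re.search(rf"\b{kw}\b", upper) for kw in FORBIDDEN)
```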
temp_upload/server/__init__.py
ADDED

File without changes
temp_upload/server/app.py
ADDED

from env.server import app
import uvicorn


def main():
    uvicorn.run(app, host="0.0.0.0", port=7860)


if __name__ == "__main__":
    main()


__all__ = ["app", "main"]
temp_upload/tests/__init__.py
ADDED

File without changes