Spaces:
Sleeping
Sleeping
Sync from GitHub - all files
Browse files- inference.py +2 -2
- openenv-sql-analyst/Dockerfile +31 -0
- openenv-sql-analyst/README.md +337 -0
- openenv-sql-analyst/data/mock_data.sql +95 -0
- openenv-sql-analyst/environment/__init__.py +19 -0
- openenv-sql-analyst/environment/db_engine.py +260 -0
- openenv-sql-analyst/environment/env.py +304 -0
- openenv-sql-analyst/environment/graders.py +232 -0
- openenv-sql-analyst/environment/models.py +70 -0
- openenv-sql-analyst/environment/tasks.py +143 -0
- openenv-sql-analyst/inference.py +267 -0
- openenv-sql-analyst/openenv.yaml +98 -0
- openenv-sql-analyst/pyproject.toml +20 -0
- openenv-sql-analyst/requirements.txt +20 -0
- openenv-sql-analyst/server/app.py +41 -0
- openenv-sql-analyst/validate.sh +112 -0
inference.py
CHANGED
|
@@ -115,13 +115,13 @@ def extract_sql_or_answer(action_str: str):
|
|
| 115 |
|
| 116 |
|
| 117 |
def main():
|
| 118 |
-
api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
|
| 119 |
base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 120 |
model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 121 |
env_url = os.environ.get("OPENENV_URL")
|
| 122 |
|
| 123 |
if not api_key:
|
| 124 |
-
print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
|
| 125 |
return
|
| 126 |
|
| 127 |
client = OpenAI(base_url=base_url, api_key=api_key)
|
|
|
|
| 115 |
|
| 116 |
|
| 117 |
def main():
|
| 118 |
+
api_key = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
|
| 119 |
base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
|
| 120 |
model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
|
| 121 |
env_url = os.environ.get("OPENENV_URL")
|
| 122 |
|
| 123 |
if not api_key:
|
| 124 |
+
print("Error: Set API_KEY, HF_TOKEN, or OPENAI_API_KEY environment variable")
|
| 125 |
return
|
| 126 |
|
| 127 |
client = OpenAI(base_url=base_url, api_key=api_key)
|
openenv-sql-analyst/Dockerfile
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv SQL Analyst Environment
# Base: python:3.10-slim for minimal memory footprint (<8GB RAM limit)

FROM python:3.10-slim

# Set working directory
WORKDIR /app

# gcc is needed so pip can build any dependency that ships only as an sdist;
# the apt cache is removed in the same layer to keep the image small.
RUN apt-get update && apt-get install -y --no-install-recommends \
    gcc \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first so the dependency layer is cached across code changes
COPY requirements.txt .

# Install Python dependencies; uv is installed too because the CMD below
# launches the server through `uv run` (hotfix for the deprecated entrypoint).
RUN pip install --no-cache-dir -r requirements.txt uv

# Copy application code
COPY . .

# Expose the OpenEnv serving port
EXPOSE 7860

# PYTHONUNBUFFERED=1 streams logs immediately (needed for container log capture);
# PYTHONDONTWRITEBYTECODE=1 avoids .pyc clutter in the image.
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Replaced deprecated 'openenv serve' with the command the runtime error requested
CMD ["uv", "run", "--project", ".", "server", "--port", "7860"]
|
openenv-sql-analyst/README.md
ADDED
|
@@ -0,0 +1,337 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: OpenEnv SQL Analyst
|
| 3 |
+
emoji: 📊
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
pinned: false
|
| 8 |
+
tags:
|
| 9 |
+
- openenv
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# SQL Data Analyst RL Environment
|
| 13 |
+
|
| 14 |
+
> A production-grade, containerized Reinforcement Learning environment for evaluating LLM-powered Data Analysts on real SQL business intelligence tasks.
|
| 15 |
+
|
| 16 |
+
**OpenEnv Hackathon Submission** | Meta x Scaler
|
| 17 |
+
|
| 18 |
+
---
|
| 19 |
+
|
| 20 |
+
## Environment Description and Motivation
|
| 21 |
+
|
| 22 |
+
This environment simulates a **mission-critical enterprise task**: an AI agent querying a production SQL database to extract business intelligence. In real-world enterprises, data analysts spend countless hours writing SQL queries to answer ad-hoc business questions from stakeholders. This environment provides a standardized benchmark to evaluate whether LLM agents can safely and accurately perform this task autonomously, measuring both **correctness** and **efficiency**.
|
| 23 |
+
|
| 24 |
+
### Why This Matters
|
| 25 |
+
|
| 26 |
+
- **Real-World Applicability**: Data analysis is one of the most common knowledge work tasks that LLMs are being deployed for
|
| 27 |
+
- **Safety-Critical**: Database access requires strict guardrails to prevent data corruption
|
| 28 |
+
- **Measurable Outcomes**: Business questions have definitive correct answers, enabling objective evaluation
|
| 29 |
+
|
| 30 |
+
### Production-Grade Security
|
| 31 |
+
|
| 32 |
+
The environment implements security safeguards that mirror real enterprise database access controls:
|
| 33 |
+
|
| 34 |
+
| Security Layer | Implementation | Purpose |
|
| 35 |
+
|----------------|----------------|---------|
|
| 36 |
+
| **Mutation Blocker** | Regex-based blocking of `INSERT`, `UPDATE`, `DELETE`, `DROP`, `ALTER`, `TRUNCATE` | Prevents data corruption |
|
| 37 |
+
| **OOM Protection** | `cursor.fetchmany(50)` instead of `fetchall()` | Prevents memory exhaustion on large result sets |
|
| 38 |
+
| **Query Timeout** | 2-second timeout wrapper | Prevents runaway queries from consuming resources |
|
| 39 |
+
| **Read-Only Sandbox** | In-memory SQLite (`:memory:` mode) | Isolated execution environment |
|
| 40 |
+
|
| 41 |
+
---
|
| 42 |
+
|
| 43 |
+
## Action Space
|
| 44 |
+
|
| 45 |
+
The agent submits an `Action` object with **exactly one** of two fields:
|
| 46 |
+
|
| 47 |
+
| Field | Type | Description |
|
| 48 |
+
|-------|------|-------------|
|
| 49 |
+
| `sql_query` | `Optional[str]` | Execute a SQL query against the database |
|
| 50 |
+
| `submit_answer` | `Optional[str]` | Submit a final answer for grading |
|
| 51 |
+
|
| 52 |
+
**Mutual Exclusivity Enforced**: A Pydantic `@model_validator` ensures the agent provides exactly one of `sql_query` or `submit_answer`. Providing both or neither raises a `ValueError`.
|
| 53 |
+
|
| 54 |
+
```python
|
| 55 |
+
# Example Actions
|
| 56 |
+
action_query = Action(sql_query="SELECT COUNT(*) FROM users")
|
| 57 |
+
action_submit = Action(submit_answer="15")
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
---
|
| 61 |
+
|
| 62 |
+
## Observation Space
|
| 63 |
+
|
| 64 |
+
The agent receives an `Observation` object containing four fields:
|
| 65 |
+
|
| 66 |
+
| Field | Type | Description |
|
| 67 |
+
|-------|------|-------------|
|
| 68 |
+
| `schema_info` | `str` | Database schema information (tables, columns, types) |
|
| 69 |
+
| `current_question` | `str` | The business question the agent must answer |
|
| 70 |
+
| `last_query_result` | `str` | Result from the most recent SQL query (markdown table format) |
|
| 71 |
+
| `error_message` | `str` | Any error from the last action (empty string if none) |
|
| 72 |
+
|
| 73 |
+
---
|
| 74 |
+
|
| 75 |
+
## Reward Shaping
|
| 76 |
+
|
| 77 |
+
The environment implements precise partial reward signals to guide learning:
|
| 78 |
+
|
| 79 |
+
| Event | Reward | Episode Ends? |
|
| 80 |
+
|-------|--------|---------------|
|
| 81 |
+
| Successful SQL query (no errors) | `+0.1` | No |
|
| 82 |
+
| SQLite syntax error | `-0.1` | No |
|
| 83 |
+
| Destructive action detected | `-1.0` | **Yes** |
|
| 84 |
+
| Step count >= 15 (infinite loop shield) | `-0.5` | **Yes** |
|
| 85 |
+
| Correct answer submitted | `+1.0` | **Yes** |
|
| 86 |
+
| Incorrect answer submitted | `0.0` | **Yes** |
|
| 87 |
+
|
| 88 |
+
**Final Score Calculation**:
|
| 89 |
+
- If incorrect: `score = 0.0`
|
| 90 |
+
- If correct: `score = 0.7 + (1 - steps/15) * 0.3`
|
| 91 |
+
- Score range: `0.0` to `1.0`
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## Task Descriptions
|
| 96 |
+
|
| 97 |
+
The environment includes **3 deterministic tasks** of increasing difficulty:
|
| 98 |
+
|
| 99 |
+
### Easy: User Count
|
| 100 |
+
| Attribute | Value |
|
| 101 |
+
|-----------|-------|
|
| 102 |
+
| **Task ID** | `easy_user_count` |
|
| 103 |
+
| **Difficulty** | Easy |
|
| 104 |
+
| **Question** | "How many users are registered in the system? Provide the total count as a single number." |
|
| 105 |
+
| **Ground Truth** | `15` |
|
| 106 |
+
| **SQL Complexity** | Single table `COUNT` query |
|
| 107 |
+
| **Reference SQL** | `SELECT COUNT(*) FROM users` |
|
| 108 |
+
|
| 109 |
+
### Medium: USA Revenue
|
| 110 |
+
| Attribute | Value |
|
| 111 |
+
|-----------|-------|
|
| 112 |
+
| **Task ID** | `medium_usa_revenue` |
|
| 113 |
+
| **Difficulty** | Medium |
|
| 114 |
+
| **Question** | "What is the total revenue (sum of total_amount) from purchases made by users in the USA? Provide the total as a number (rounded to 2 decimal places if needed)." |
|
| 115 |
+
| **Ground Truth** | `2423.87` |
|
| 116 |
+
| **SQL Complexity** | Two-table `JOIN` with `SUM` aggregation filtered by country |
|
| 117 |
+
| **Reference SQL** | `SELECT ROUND(SUM(p.total_amount), 2) FROM purchases p JOIN users u ON p.user_id = u.user_id WHERE u.country = 'USA'` |
|
| 118 |
+
|
| 119 |
+
### Hard: Top Spender
|
| 120 |
+
| Attribute | Value |
|
| 121 |
+
|-----------|-------|
|
| 122 |
+
| **Task ID** | `hard_top_spender` |
|
| 123 |
+
| **Difficulty** | Hard |
|
| 124 |
+
| **Question** | "Who is the top spender (user with highest total purchase amount)? Provide the username of the user who spent the most money in total." |
|
| 125 |
+
| **Ground Truth** | `alice` |
|
| 126 |
+
| **SQL Complexity** | Complex query with `JOIN`, `GROUP BY`, `ORDER BY`, and `LIMIT` |
|
| 127 |
+
| **Reference SQL** | `SELECT u.username FROM users u JOIN purchases p ON u.user_id = p.user_id GROUP BY u.user_id, u.username ORDER BY SUM(p.total_amount) DESC LIMIT 1` |
|
| 128 |
+
|
| 129 |
+
### Grading System
|
| 130 |
+
|
| 131 |
+
All graders implement:
|
| 132 |
+
- **Type-agnostic normalization**: Whitespace trimming, lowercasing, numeric rounding to 2 decimal places
|
| 133 |
+
- **Numeric tolerance**: Answers within 0.01 absolute tolerance are exact matches
|
| 134 |
+
- **Partial credit**: Numeric answers within 10% receive 0.5 score
|
| 135 |
+
- **SQL evaluation**: If agent submits SQL as answer, it's executed and results compared
|
| 136 |
+
|
| 137 |
+
---
|
| 138 |
+
|
| 139 |
+
## Setup and Usage Instructions
|
| 140 |
+
|
| 141 |
+
### Prerequisites
|
| 142 |
+
|
| 143 |
+
- Docker installed and running
|
| 144 |
+
- Python 3.10+ (for local development)
|
| 145 |
+
- (Optional) HuggingFace token for inference with HF-hosted models
|
| 146 |
+
|
| 147 |
+
### Quick Start with Docker
|
| 148 |
+
|
| 149 |
+
```bash
|
| 150 |
+
# Clone the repository
|
| 151 |
+
git clone https://github.com/hitanshu04/openenv-sql-analyst.git
|
| 152 |
+
cd openenv-sql-analyst
|
| 153 |
+
|
| 154 |
+
# Build the Docker image
|
| 155 |
+
docker build -t openenv-sql-analyst .
|
| 156 |
+
|
| 157 |
+
# Run the container
|
| 158 |
+
docker run -p 7860:7860 openenv-sql-analyst
|
| 159 |
+
```
|
| 160 |
+
|
| 161 |
+
The server will be available at `http://localhost:7860`
|
| 162 |
+
|
| 163 |
+
### API Endpoints
|
| 164 |
+
|
| 165 |
+
| Endpoint | Method | Description |
|
| 166 |
+
|----------|--------|-------------|
|
| 167 |
+
| `/` | GET | Health check (returns 200 OK) |
|
| 168 |
+
| `/reset` | POST | Reset environment, returns initial observation |
|
| 169 |
+
| `/step` | POST | Execute action, returns (observation, reward, done, info) |
|
| 170 |
+
| `/state` | GET | Get current internal state |
|
| 171 |
+
|
| 172 |
+
### Local Development (Without Docker)
|
| 173 |
+
|
| 174 |
+
```bash
|
| 175 |
+
# Create virtual environment
|
| 176 |
+
python -m venv venv
|
| 177 |
+
source venv/bin/activate # On Windows: venv\Scripts\activate
|
| 178 |
+
|
| 179 |
+
# Install dependencies
|
| 180 |
+
pip install -r requirements.txt
|
| 181 |
+
|
| 182 |
+
# Run the server directly
|
| 183 |
+
python -m server.app
|
| 184 |
+
|
| 185 |
+
# Or run validation
|
| 186 |
+
chmod +x validate.sh
|
| 187 |
+
./validate.sh
|
| 188 |
+
```
|
| 189 |
+
|
| 190 |
+
### Running Inference
|
| 191 |
+
|
| 192 |
+
```bash
|
| 193 |
+
# Set environment variables
|
| 194 |
+
export HF_TOKEN="your-huggingface-token"
|
| 195 |
+
export API_BASE_URL="https://api.openai.com/v1" # or HF inference endpoint
|
| 196 |
+
export MODEL_NAME="gpt-4o-mini"
|
| 197 |
+
|
| 198 |
+
# Run inference
|
| 199 |
+
python inference.py
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
### Environment Variables
|
| 203 |
+
|
| 204 |
+
| Variable | Description | Default |
|
| 205 |
+
|----------|-------------|---------|
|
| 206 |
+
| `HF_TOKEN` | HuggingFace API token (used as API key) | Required for inference |
|
| 207 |
+
| `API_BASE_URL` | OpenAI-compatible API endpoint | `https://api.openai.com/v1` |
|
| 208 |
+
| `MODEL_NAME` | Model identifier | `gpt-4o-mini` |
|
| 209 |
+
|
| 210 |
+
### Validation Gates
|
| 211 |
+
|
| 212 |
+
Run `./validate.sh` before submission. All 4 checks must pass:
|
| 213 |
+
|
| 214 |
+
| Step | Check | Failure Condition |
|
| 215 |
+
|------|-------|-------------------|
|
| 216 |
+
| 1/4 | Prerequisites | `docker` or `openenv` CLI not found |
|
| 217 |
+
| 2/4 | Docker Build | `Dockerfile` missing or build fails |
|
| 218 |
+
| 3/4 | OpenEnv Spec | `openenv validate` fails (yaml/models mismatch) |
|
| 219 |
+
| 4/4 | Inference Logs | Missing `[START]`/`[STEP]`/`[END]` tags or invalid score |
|
| 220 |
+
|
| 221 |
+
---
|
| 222 |
+
|
| 223 |
+
## Baseline Scores
|
| 224 |
+
|
| 225 |
+
Expected performance with `gpt-4o-mini`:
|
| 226 |
+
|
| 227 |
+
| Task | Difficulty | Expected Steps | Expected Score |
|
| 228 |
+
|------|------------|----------------|----------------|
|
| 229 |
+
| `easy_user_count` | Easy | 2-3 | 0.90 - 1.00 |
|
| 230 |
+
| `medium_usa_revenue` | Medium | 3-5 | 0.85 - 0.95 |
|
| 231 |
+
| `hard_top_spender` | Hard | 4-7 | 0.75 - 0.90 |
|
| 232 |
+
|
| 233 |
+
### STDOUT Log Format
|
| 234 |
+
|
| 235 |
+
The inference script outputs logs in the exact required format:
|
| 236 |
+
|
| 237 |
+
```
|
| 238 |
+
[START] task=<task_id> env=sql_analyst model=<model_name>
|
| 239 |
+
[STEP] step=<n> action=<action_type>=<value> reward=<r.rr> done=<bool> error=<msg>
|
| 240 |
+
[END] success=<bool> steps=<n> score=<s.ss> rewards=<r1>,<r2>,...
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
**Example Output**:
|
| 244 |
+
```
|
| 245 |
+
[START] task=easy_user_count env=sql_analyst model=gpt-4o-mini
|
| 246 |
+
[STEP] step=1 action=sql_query=SELECT COUNT(*) FROM users reward=0.10 done=false error=null
|
| 247 |
+
[STEP] step=2 action=submit_answer=15 reward=1.00 done=true error=null
|
| 248 |
+
[END] success=true steps=2 score=0.96 rewards=0.10,1.00
|
| 249 |
+
```
|
| 250 |
+
|
| 251 |
+
---
|
| 252 |
+
|
| 253 |
+
## Project Architecture
|
| 254 |
+
|
| 255 |
+
```
|
| 256 |
+
openenv_sql_analyst/
|
| 257 |
+
├── openenv.yaml          # OpenEnv specification (name, schemas, endpoints)
├── Dockerfile            # Container config (python:3.10-slim, port 7860)
├── requirements.txt      # Python dependencies
├── pyproject.toml        # Python project configuration
├── validate.sh           # Pre-submission validation (4 gates)
├── inference.py          # Baseline LLM agent implementation
├── data/
│   └── mock_data.sql     # SQLite mock database (3 tables, ~50 rows)
├── environment/
│   ├── __init__.py       # Package exports
│   ├── models.py         # Pydantic schemas (Action, Observation, Reward)
│   ├── db_engine.py      # SQLite engine with security safeguards
│   ├── tasks.py          # Task definitions (Easy, Medium, Hard)
│   ├── graders.py        # Deterministic grading system
│   └── env.py            # Main SQLAnalystEnv class (reset, step, state)
└── server/
    └── app.py            # FastAPI server (/reset, /step, /state endpoints)
|
| 274 |
+
```
|
| 275 |
+
|
| 276 |
+
---
|
| 277 |
+
|
| 278 |
+
## Technical Specifications
|
| 279 |
+
|
| 280 |
+
| Specification | Value |
|
| 281 |
+
|---------------|-------|
|
| 282 |
+
| Python Version | 3.10 |
|
| 283 |
+
| Container Base | `python:3.10-slim` |
|
| 284 |
+
| Container Port | 7860 |
|
| 285 |
+
| vCPU Limit | 2 |
|
| 286 |
+
| Memory Limit | 8 GB |
|
| 287 |
+
| Max Runtime | 20 minutes |
|
| 288 |
+
| Max Steps per Episode | 15 |
|
| 289 |
+
| Query Timeout | 2 seconds |
|
| 290 |
+
| Max Fetch Rows | 50 |
|
| 291 |
+
| Database | SQLite (in-memory) |
|
| 292 |
+
|
| 293 |
+
---
|
| 294 |
+
|
| 295 |
+
## Database Schema
|
| 296 |
+
|
| 297 |
+
The mock database contains 3 tables:
|
| 298 |
+
|
| 299 |
+
### users
|
| 300 |
+
| Column | Type | Constraints |
|
| 301 |
+
|--------|------|-------------|
|
| 302 |
+
| user_id | INTEGER | PRIMARY KEY |
|
| 303 |
+
| username | TEXT | NOT NULL |
|
| 304 |
+
| email | TEXT | NOT NULL |
|
| 305 |
+
| country | TEXT | NOT NULL |
|
| 306 |
+
| created_at | TEXT | NOT NULL |
|
| 307 |
+
|
| 308 |
+
### products
|
| 309 |
+
| Column | Type | Constraints |
|
| 310 |
+
|--------|------|-------------|
|
| 311 |
+
| product_id | INTEGER | PRIMARY KEY |
|
| 312 |
+
| product_name | TEXT | NOT NULL |
|
| 313 |
+
| category | TEXT | NOT NULL |
|
| 314 |
+
| price | REAL | NOT NULL |
|
| 315 |
+
| stock | INTEGER | NOT NULL |
|
| 316 |
+
|
| 317 |
+
### purchases
|
| 318 |
+
| Column | Type | Constraints |
|
| 319 |
+
|--------|------|-------------|
|
| 320 |
+
| purchase_id | INTEGER | PRIMARY KEY |
|
| 321 |
+
| user_id | INTEGER | NOT NULL, FOREIGN KEY |
|
| 322 |
+
| product_id | INTEGER | NOT NULL, FOREIGN KEY |
|
| 323 |
+
| quantity | INTEGER | NOT NULL |
|
| 324 |
+
| purchase_date | TEXT | NOT NULL |
|
| 325 |
+
| total_amount | REAL | NOT NULL |
|
| 326 |
+
|
| 327 |
+
---
|
| 328 |
+
|
| 329 |
+
## License
|
| 330 |
+
|
| 331 |
+
MIT License
|
| 332 |
+
|
| 333 |
+
---
|
| 334 |
+
|
| 335 |
+
## Acknowledgments
|
| 336 |
+
|
| 337 |
+
Built for the **Meta x Scaler OpenEnv Hackathon** - advancing the frontier of LLM agent evaluation through standardized, production-grade reinforcement learning environments.
|
openenv-sql-analyst/data/mock_data.sql
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
-- OpenEnv SQL Analyst - Mock Data
-- Tables: users, products, purchases
-- Approximately 50 rows total for lightweight operation
-- NOTE: loaded via sqlite3 executescript() into an in-memory database;
-- row contents are the ground truth for the environment's graded tasks.

-- =============================================
-- TABLE: users (15 rows)
-- =============================================
CREATE TABLE IF NOT EXISTS users (
    user_id INTEGER PRIMARY KEY,
    username TEXT NOT NULL,
    email TEXT NOT NULL,
    country TEXT NOT NULL,
    created_at TEXT NOT NULL
);

INSERT INTO users (user_id, username, email, country, created_at) VALUES
(1, 'alice', 'alice@example.com', 'USA', '2023-01-15'),
(2, 'bob', 'bob@example.com', 'Canada', '2023-02-20'),
(3, 'charlie', 'charlie@example.com', 'UK', '2023-03-10'),
(4, 'diana', 'diana@example.com', 'USA', '2023-04-05'),
(5, 'eve', 'eve@example.com', 'Germany', '2023-05-12'),
(6, 'frank', 'frank@example.com', 'France', '2023-06-18'),
(7, 'grace', 'grace@example.com', 'USA', '2023-07-22'),
(8, 'henry', 'henry@example.com', 'Canada', '2023-08-30'),
(9, 'iris', 'iris@example.com', 'UK', '2023-09-14'),
(10, 'jack', 'jack@example.com', 'USA', '2023-10-01'),
(11, 'karen', 'karen@example.com', 'Germany', '2023-10-15'),
(12, 'leo', 'leo@example.com', 'France', '2023-11-02'),
(13, 'maria', 'maria@example.com', 'Spain', '2023-11-20'),
(14, 'nathan', 'nathan@example.com', 'USA', '2023-12-05'),
(15, 'olivia', 'olivia@example.com', 'Canada', '2023-12-18');

-- =============================================
-- TABLE: products (15 rows)
-- =============================================
CREATE TABLE IF NOT EXISTS products (
    product_id INTEGER PRIMARY KEY,
    product_name TEXT NOT NULL,
    category TEXT NOT NULL,
    price REAL NOT NULL,
    stock INTEGER NOT NULL
);

INSERT INTO products (product_id, product_name, category, price, stock) VALUES
(1, 'Laptop Pro', 'Electronics', 1299.99, 50),
(2, 'Wireless Mouse', 'Electronics', 29.99, 200),
(3, 'USB-C Hub', 'Electronics', 49.99, 150),
(4, 'Mechanical Keyboard', 'Electronics', 89.99, 100),
(5, 'Monitor 27"', 'Electronics', 349.99, 75),
(6, 'Desk Chair', 'Furniture', 199.99, 40),
(7, 'Standing Desk', 'Furniture', 449.99, 25),
(8, 'Desk Lamp', 'Furniture', 34.99, 120),
(9, 'Notebook Pack', 'Office', 12.99, 300),
(10, 'Pen Set', 'Office', 8.99, 500),
(11, 'Headphones', 'Electronics', 149.99, 80),
(12, 'Webcam HD', 'Electronics', 79.99, 90),
(13, 'Mousepad XL', 'Electronics', 19.99, 250),
(14, 'Cable Organizer', 'Office', 14.99, 180),
(15, 'Monitor Stand', 'Furniture', 59.99, 60);

-- =============================================
-- TABLE: purchases (20 rows; FKs reference users and products)
-- =============================================
CREATE TABLE IF NOT EXISTS purchases (
    purchase_id INTEGER PRIMARY KEY,
    user_id INTEGER NOT NULL,
    product_id INTEGER NOT NULL,
    quantity INTEGER NOT NULL,
    purchase_date TEXT NOT NULL,
    total_amount REAL NOT NULL,
    FOREIGN KEY (user_id) REFERENCES users(user_id),
    FOREIGN KEY (product_id) REFERENCES products(product_id)
);

INSERT INTO purchases (purchase_id, user_id, product_id, quantity, purchase_date, total_amount) VALUES
(1, 1, 1, 1, '2023-06-01', 1299.99),
(2, 1, 2, 2, '2023-06-01', 59.98),
(3, 2, 4, 1, '2023-06-15', 89.99),
(4, 3, 5, 1, '2023-07-01', 349.99),
(5, 4, 6, 1, '2023-07-10', 199.99),
(6, 5, 7, 1, '2023-07-20', 449.99),
(7, 1, 11, 1, '2023-08-01', 149.99),
(8, 6, 3, 2, '2023-08-05', 99.98),
(9, 7, 9, 5, '2023-08-10', 64.95),
(10, 8, 10, 10, '2023-08-15', 89.90),
(11, 2, 12, 1, '2023-09-01', 79.99),
(12, 9, 8, 2, '2023-09-10', 69.98),
(13, 10, 13, 1, '2023-09-15', 19.99),
(14, 3, 14, 3, '2023-09-20', 44.97),
(15, 4, 15, 1, '2023-10-01', 59.99),
(16, 11, 1, 1, '2023-10-05', 1299.99),
(17, 12, 2, 3, '2023-10-10', 89.97),
(18, 5, 4, 1, '2023-10-15', 89.99),
(19, 13, 11, 2, '2023-10-20', 299.98),
(20, 14, 5, 1, '2023-11-01', 349.99);
|
openenv-sql-analyst/environment/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment/__init__.py
# OpenEnv SQL Analyst Environment Package

"""Public package surface for the SQL Analyst environment.

Re-exports the request/response schemas (Action, Observation, Reward), the
sandboxed DatabaseEngine, the task registry and lookup helper, the grading
entry point, and the main SQLAnalystEnv class so callers can import
everything from ``environment`` directly.
"""

from .models import Action, Observation, Reward
from .db_engine import DatabaseEngine
from .tasks import TASKS, get_task_by_difficulty
from .graders import grade_answer
from .env import SQLAnalystEnv

__all__ = [
    "Action",
    "Observation",
    "Reward",
    "DatabaseEngine",
    "TASKS",
    "get_task_by_difficulty",
    "grade_answer",
    "SQLAnalystEnv",
]
|
openenv-sql-analyst/environment/db_engine.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment/db_engine.py
# SQLite Database Engine with Security Safeguards
# Implements: Mutation Blocker, OOM Protection, Timeout Wrapper

import re
import sqlite3
import signal
import os
from typing import Tuple, Optional
from contextlib import contextmanager
from pathlib import Path


# Regex pattern for blocking destructive SQL operations.
# \b word boundaries + IGNORECASE catch the keyword anywhere in the statement
# (e.g. "delete", "DELETE", "...; DROP TABLE ...").
MUTATION_PATTERN = re.compile(
    r'\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE)\b',
    re.IGNORECASE
)

# Query execution timeout in seconds (also passed to sqlite3.connect()).
QUERY_TIMEOUT = 2.0

# Maximum rows to fetch per query (OOM protection: fetchmany, never fetchall).
MAX_FETCH_ROWS = 50
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class TimeoutError(Exception):
    """Custom exception raised when a query exceeds its time budget.

    NOTE(review): this class shadows the builtin ``TimeoutError`` within this
    module; any ``except TimeoutError`` here catches this class, not the
    builtin. Kept as-is because external callers may already catch it by this
    name.
    """
    pass
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@contextmanager
def timeout_handler(seconds: float):
    """Abort the enclosed block after *seconds* seconds (best effort).

    Args:
        seconds: Wall-clock budget for the enclosed code.

    Raises:
        TimeoutError: if the enclosed code does not finish within the budget
            (only when SIGALRM enforcement is available, see below).

    Notes:
        SIGALRM does not exist on Windows, and ``signal.signal`` may only be
        called from the *main* thread — calling it from a server worker
        thread raises ``ValueError``. In either situation this degrades to a
        no-op and we rely on sqlite3's own connection timeout instead.
    """
    import threading  # local import to keep the module's import surface unchanged

    can_use_alarm = (
        os.name != 'nt'
        and threading.current_thread() is threading.main_thread()
    )
    if not can_use_alarm:
        # No safe signal-based enforcement here; run unguarded.
        yield
        return

    def handler(signum, frame):
        raise TimeoutError(f"Query execution exceeded {seconds} seconds timeout")

    old_handler = signal.signal(signal.SIGALRM, handler)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Always cancel the timer and restore the previous handler, even when
        # the body raised — otherwise a stale alarm could fire later.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class DatabaseEngine:
|
| 57 |
+
"""
|
| 58 |
+
SQLite Database Engine with security safeguards.
|
| 59 |
+
|
| 60 |
+
Features:
|
| 61 |
+
- In-memory SQLite database (:memory: mode)
|
| 62 |
+
- Mutation Blocker: Regex-based blocking of INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE
|
| 63 |
+
- OOM Protection: cursor.fetchmany(50), never fetchall()
|
| 64 |
+
- Timeout Wrapper: 2.0-second timeout for query execution
|
| 65 |
+
- Stringified errors: Never raises Python exceptions to caller
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
    def __init__(self):
        """Initialize the database engine with an in-memory SQLite database."""
        # Connection and cursor are created lazily by initialize(); both stay
        # None until then so methods can detect an uninitialized engine.
        self.connection: Optional[sqlite3.Connection] = None
        self.cursor: Optional[sqlite3.Cursor] = None
        # Cached output of _get_schema_info(); rebuilt on each initialize().
        self._schema_cache: Optional[str] = None
|
| 73 |
+
|
| 74 |
+
    def initialize(self) -> str:
        """
        Initialize a clean in-memory SQLite database and load mock data.

        Errors are reported as strings rather than raised, matching the
        engine's "stringified errors" contract.

        Returns:
            str: Success message or error string
        """
        try:
            # Close existing connection if any (fresh state per episode)
            self.close()

            # Create new in-memory database; check_same_thread=False because
            # the server may touch the connection from worker threads.
            self.connection = sqlite3.connect(
                ':memory:',
                timeout=QUERY_TIMEOUT,
                check_same_thread=False
            )
            self.cursor = self.connection.cursor()

            # Load mock data from SQL file shipped alongside the package
            # (repo-relative: <package root>/data/mock_data.sql)
            mock_data_path = Path(__file__).parent.parent / 'data' / 'mock_data.sql'

            if mock_data_path.exists():
                with open(mock_data_path, 'r') as f:
                    sql_script = f.read()
                self.cursor.executescript(sql_script)
                self.connection.commit()
            else:
                return f"Error: Mock data file not found at {mock_data_path}"

            # Cache schema info so later get_schema() calls are cheap
            self._schema_cache = self._get_schema_info()

            return "Database initialized successfully"

        except Exception as e:
            return f"Error initializing database: {str(e)}"
|
| 111 |
+
|
| 112 |
+
def _get_schema_info(self) -> str:
    """
    Build a human-readable description of every table and its columns.

    Returns:
        str: Formatted schema text, or an error string on failure.
    """
    if not self.cursor:
        return "Error: Database not initialized"

    try:
        self.cursor.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        )
        # fetchmany keeps memory bounded even for pathological schemas.
        table_names = [r[0] for r in self.cursor.fetchmany(MAX_FETCH_ROWS)]

        lines = ["DATABASE SCHEMA:", "=" * 50]

        for table_name in table_names:
            lines.append(f"\nTable: {table_name}")
            lines.append("-" * 30)

            # PRAGMA table_info rows are (cid, name, type, notnull, default, pk).
            self.cursor.execute(f"PRAGMA table_info({table_name})")
            for col in self.cursor.fetchmany(MAX_FETCH_ROWS):
                _cid, name, col_type, not_null, _default, pk = col
                pk_marker = " [PRIMARY KEY]" if pk else ""
                null_marker = " NOT NULL" if not_null else ""
                lines.append(f" - {name}: {col_type}{null_marker}{pk_marker}")

        return "\n".join(lines)

    except Exception as e:
        return f"Error getting schema: {str(e)}"
|
| 149 |
+
|
| 150 |
+
def get_schema(self) -> str:
    """
    Return the cached schema description, rebuilding it when no cache exists.

    Returns:
        str: Schema information string.
    """
    # Fall back to live introspection only when the cache is empty.
    return self._schema_cache if self._schema_cache else self._get_schema_info()
|
| 160 |
+
|
| 161 |
+
def check_mutation(self, query: str) -> Optional[str]:
    """
    Screen *query* for write/DDL operations.

    Args:
        query: SQL query string.

    Returns:
        Optional[str]: Blocking error message when a mutation keyword is
        found, otherwise None (query is safe to run).
    """
    hit = MUTATION_PATTERN.search(query)
    if hit is None:
        return None

    keyword = hit.group(1).upper()
    return (
        f"DESTRUCTIVE_ACTION_BLOCKED: {keyword} operations are not allowed. "
        f"This environment is read-only. Only SELECT queries are permitted."
    )
|
| 179 |
+
|
| 180 |
+
def execute_query(self, query: str) -> Tuple[str, bool]:
    """
    Execute a read-only SQL query with timeout, OOM, and mutation guards.

    Args:
        query: SQL query string.

    Returns:
        Tuple[str, bool]: (result_string, is_error)
        - result_string: Rendered query results or an error message.
        - is_error: True when any error occurred, False otherwise.
    """
    if not self.connection or not self.cursor:
        return "Error: Database not initialized", True

    query = query.strip()
    if not query:
        return "Error: Empty query provided", True

    # MUTATION BLOCKER: refuse any write/DDL statement up front.
    blocked = self.check_mutation(query)
    if blocked:
        return blocked, True

    try:
        # Bound wall-clock time for the actual execution.
        with timeout_handler(QUERY_TIMEOUT):
            self.cursor.execute(query)

        # OOM PROTECTION: fetch at most MAX_FETCH_ROWS rows, never fetchall().
        rows = self.cursor.fetchmany(MAX_FETCH_ROWS)

        if not rows:
            # Statements without a result set leave cursor.description as None.
            if self.cursor.description is None:
                return "Query executed successfully (no results)", False
            return "Query returned no results", False

        headers = [d[0] for d in self.cursor.description]

        # Render a markdown-style table.
        table = ["| " + " | ".join(headers) + " |",
                 "|" + "|".join(["---"] * len(headers)) + "|"]
        for row in rows:
            cells = [str(v) if v is not None else "NULL" for v in row]
            table.append("| " + " | ".join(cells) + " |")
        rendered = "\n".join(table)

        # Probe for one extra row to detect (and report) truncation.
        if self.cursor.fetchmany(1):
            rendered += f"\n\n[TRUNCATED] Results limited to {MAX_FETCH_ROWS} rows. More rows exist."

        return rendered, False

    except TimeoutError as e:
        return f"Error: {str(e)}", True
    except sqlite3.Error as e:
        return f"SQLite Error: {str(e)}", True
    except Exception as e:
        return f"Error: {str(e)}", True
|
| 247 |
+
|
| 248 |
+
def close(self):
    """Release the cursor and connection, and invalidate the schema cache."""
    if self.cursor:
        self.cursor.close()
        self.cursor = None
    if self.connection:
        self.connection.close()
        self.connection = None
    # Schema may change on the next initialize(); drop the cache now.
    self._schema_cache = None
|
| 257 |
+
|
| 258 |
+
def __del__(self):
    """Best-effort cleanup when the engine is garbage-collected."""
    self.close()
|
openenv-sql-analyst/environment/env.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment/env.py
|
| 2 |
+
# Main OpenEnv Environment for SQL Data Analyst
|
| 3 |
+
# Inherits from openenv.BaseEnv and implements reset(), step(), state()
|
| 4 |
+
|
| 5 |
+
from typing import Dict, Any, Tuple, Optional
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
from .models import Action, Observation, Reward
|
| 8 |
+
from .db_engine import DatabaseEngine
|
| 9 |
+
from .tasks import Task, get_random_task, TASKS
|
| 10 |
+
from .graders import grade_answer, calculate_final_score
|
| 11 |
+
|
| 12 |
+
# Try to import openenv.BaseEnv, fallback to a simple base class if not available
|
| 13 |
+
try:
|
| 14 |
+
from openenv import BaseEnv
|
| 15 |
+
except ImportError:
|
| 16 |
+
# Fallback base class for development/testing
|
| 17 |
+
class BaseEnv:
|
| 18 |
+
"""Fallback base class when openenv-core is not installed."""
|
| 19 |
+
pass
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# ============================================
|
| 23 |
+
# REWARD CONSTANTS (per PRD specification)
|
| 24 |
+
# ============================================
|
| 25 |
+
REWARD_SUCCESSFUL_QUERY = 0.1 # Successful, error-free SQL query
|
| 26 |
+
REWARD_SYNTAX_ERROR = -0.1 # SQLite syntax error
|
| 27 |
+
REWARD_DESTRUCTIVE_ACTION = -1.0 # Destructive action detected
|
| 28 |
+
REWARD_INFINITE_LOOP = -0.5 # Step count >= 15
|
| 29 |
+
|
| 30 |
+
# Maximum steps before infinite loop shield activates
|
| 31 |
+
MAX_STEPS = 15
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class EnvironmentState:
    """
    Mutable per-episode bookkeeping for the SQL Analyst environment.

    Tracks the active task, the step budget, the terminal flag, the most
    recent query output / error, and the reward history used for scoring.
    """
    task: Optional[Task] = None                   # currently active task, None before reset()
    step_count: int = 0                           # actions taken this episode
    done: bool = False                            # True once the episode has terminated
    last_query_result: str = ""                   # output of the most recent SQL query
    error_message: str = ""                       # error text from the last action, if any
    rewards: list = field(default_factory=list)   # per-step reward history
    final_score: float = 0.0                      # grading score in [0.0, 1.0]
    success: bool = False                         # whether the submitted answer was correct
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class SQLAnalystEnv(BaseEnv):
    """
    SQL Data Analyst reinforcement-learning environment.

    Simulates a data-analyst workspace: the agent explores a read-only
    SQLite database and answers a business question.

    OpenEnv interface:
        - reset(): start a clean episode
        - step(action): apply an action, returning (observation, reward, done, info)
        - state(): expose the full internal state

    Reward shaping (per PRD):
        - +0.1: successful, error-free SQL query
        - -0.1: SQLite syntax error
        - -1.0: destructive action detected (episode ends)
        - -0.5: step count >= 15 (infinite-loop shield, episode ends)
    """

    def __init__(self):
        """Create the database engine and an empty episode state."""
        super().__init__()
        self.db_engine = DatabaseEngine()
        self._state = EnvironmentState()

    def reset(self, task_id: Optional[str] = None) -> Observation:
        """
        Begin a new episode.

        Re-creates the in-memory database, picks a task (the requested one
        when task_id matches a known task, otherwise a random one), zeroes
        all per-episode counters, and returns the opening observation.

        Args:
            task_id: Optional identifier of a specific task to run.

        Returns:
            Observation: Initial observation for the episode.
        """
        self.db_engine.initialize()

        # Honour an explicit task_id when it exists; otherwise sample one.
        chosen: Optional[Task] = None
        if task_id:
            chosen = next((t for t in TASKS if t.task_id == task_id), None)
        self._state.task = chosen if chosen is not None else get_random_task()

        # Zero out all per-episode bookkeeping.
        self._state.step_count = 0
        self._state.done = False
        self._state.last_query_result = ""
        self._state.error_message = ""
        self._state.rewards = []
        self._state.final_score = 0.0
        self._state.success = False

        return Observation(
            schema_info=self.db_engine.get_schema(),
            current_question=self._state.task.question,
            last_query_result="No queries executed yet.",
            error_message=""
        )

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """
        Apply *action* and advance the episode by one step.

        Reward shaping: +0.1 clean query, -0.1 SQL error, -1.0 destructive
        action (ends episode), -0.5 when the step budget is exhausted
        (ends episode).

        Args:
            action: Validated Action (exactly one of sql_query / submit_answer).

        Returns:
            Tuple containing (observation, reward, done, info).
        """
        # A finished episode is inert: zero reward, done stays True.
        if self._state.done:
            return self._get_observation(), Reward(value=0.0), True, self._get_info()

        self._state.step_count += 1

        # Infinite-loop shield: terminate before processing the action.
        if self._state.step_count >= MAX_STEPS:
            self._state.done = True
            self._state.error_message = f"Maximum steps ({MAX_STEPS}) reached. Episode terminated."
            self._state.rewards.append(REWARD_INFINITE_LOOP)
            return self._get_observation(), Reward(value=REWARD_INFINITE_LOOP), True, self._get_info()

        self._state.error_message = ""
        step_reward = 0.0
        if action.sql_query:
            step_reward = self._handle_sql_query(action.sql_query)
        elif action.submit_answer:
            step_reward = self._handle_submit_answer(action.submit_answer)

        self._state.rewards.append(step_reward)
        return self._get_observation(), Reward(value=step_reward), self._state.done, self._get_info()

    def _handle_sql_query(self, query: str) -> float:
        """
        Run one SQL query and translate the outcome into a reward.

        Args:
            query: SQL text supplied by the agent.

        Returns:
            float: Step reward (+0.1 success, -0.1 error, -1.0 mutation).
        """
        # Destructive statements end the episode immediately.
        blocked = self.db_engine.check_mutation(query)
        if blocked:
            self._state.done = True
            self._state.error_message = blocked
            self._state.last_query_result = ""
            return REWARD_DESTRUCTIVE_ACTION

        output, failed = self.db_engine.execute_query(query)
        if failed:
            self._state.error_message = output
            self._state.last_query_result = ""
            return REWARD_SYNTAX_ERROR

        # Successful query: surface the result on the next observation.
        self._state.last_query_result = output
        self._state.error_message = ""
        return REWARD_SUCCESSFUL_QUERY

    def _handle_submit_answer(self, answer: str) -> float:
        """
        Grade a final answer and terminate the episode.

        Args:
            answer: The agent's final answer text (may itself be SQL).

        Returns:
            float: 1.0 for a correct answer, 0.0 otherwise.
        """
        # Submitting always ends the episode.
        self._state.done = True

        is_correct, _grading_score = grade_answer(
            answer,
            self._state.task.ground_truth,
            self.db_engine
        )

        self._state.success = is_correct
        # final_score additionally rewards step efficiency; the step reward
        # below reflects correctness only.
        self._state.final_score = calculate_final_score(
            is_correct,
            self._state.step_count,
            MAX_STEPS
        )

        return 1.0 if is_correct else 0.0

    def _get_observation(self) -> Observation:
        """Assemble the observation currently visible to the agent."""
        task = self._state.task
        return Observation(
            schema_info=self.db_engine.get_schema(),
            current_question=task.question if task else "",
            last_query_result=self._state.last_query_result or "No results yet.",
            error_message=self._state.error_message
        )

    def _get_info(self) -> Dict[str, Any]:
        """Assemble the auxiliary info dict returned from step()."""
        task = self._state.task
        return {
            "step_count": self._state.step_count,
            "task_id": task.task_id if task else None,
            "task_difficulty": task.difficulty if task else None,
            "success": self._state.success,
            "final_score": self._state.final_score,
            "total_reward": sum(self._state.rewards),
            "rewards_history": self._state.rewards.copy()
        }

    def state(self) -> Dict[str, Any]:
        """
        Expose the full internal state as a plain dictionary.

        Returns:
            Dict: The full internal state.
        """
        task = self._state.task
        return {
            "task_id": task.task_id if task else None,
            "task_difficulty": task.difficulty if task else None,
            "task_question": task.question if task else None,
            "step_count": self._state.step_count,
            "done": self._state.done,
            "last_query_result": self._state.last_query_result,
            "error_message": self._state.error_message,
            "rewards": self._state.rewards.copy(),
            "total_reward": sum(self._state.rewards),
            "success": self._state.success,
            "final_score": self._state.final_score
        }

    def close(self):
        """Release database resources."""
        if self.db_engine:
            self.db_engine.close()
|
openenv-sql-analyst/environment/graders.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment/graders.py
|
| 2 |
+
# Deterministic grading system for SQL Data Analyst environment
|
| 3 |
+
# Implements type-agnostic normalization and SQL evaluation
|
| 4 |
+
|
| 5 |
+
from typing import Any, Tuple, Optional
|
| 6 |
+
import re
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def normalize_value(value: Any) -> str:
    """
    Normalize *value* into a canonical string for comparison.

    Rules:
        - None becomes ""
        - strings are stripped, lowercased, and inner whitespace collapsed
        - anything parseable as a number is rounded to two decimal places

    Args:
        value: Any value to normalize.

    Returns:
        str: Normalized string representation.
    """
    if value is None:
        return ""

    # Canonical text form: stripped, lowercased, single spaces.
    text = re.sub(r'\s+', ' ', str(value).strip().lower())

    # Numeric inputs compare via a rounded canonical float rendering.
    try:
        return str(round(float(text), 2))
    except (ValueError, TypeError):
        return text
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def extract_numeric(value: str) -> Optional[float]:
    """
    Parse *value* as a float after stripping '$' and ',' formatting.

    Args:
        value: String that may contain a number.

    Returns:
        Optional[float]: Parsed number, or None when unparseable.
    """
    stripped = re.sub(r'[$,]', '', str(value).strip())
    try:
        return float(stripped)
    except (ValueError, TypeError):
        return None
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def compare_values(submitted: Any, ground_truth: Any) -> Tuple[bool, float]:
    """
    Compare a submitted answer to the ground truth.

    Matching strategy, in order:
        1. Exact match after type-agnostic normalization.
        2. Numeric comparison (tolerance 0.01; partial credit 0.5 within 10%)
           when the ground truth is an int/float and the submission parses
           as a number. A purely-numeric submission that fails this check
           is wrong -- it never falls through to substring matching.
        3. Substring containment, for free-text submissions that embed the
           (non-empty) normalized ground truth.

    Args:
        submitted: The agent's submitted answer.
        ground_truth: The expected correct answer.

    Returns:
        Tuple[bool, float]: (is_correct, score)
        - is_correct: True when the answer matches.
        - score: Value between 0.0 and 1.0.
    """
    norm_submitted = normalize_value(submitted)
    norm_truth = normalize_value(ground_truth)

    # Direct string comparison after normalization
    if norm_submitted == norm_truth:
        return True, 1.0

    # Numeric comparison for numeric ground truths
    if isinstance(ground_truth, (int, float)):
        submitted_num = extract_numeric(submitted)
        if submitted_num is not None:
            truth_num = float(ground_truth)
            # Allow small floating point tolerance
            if abs(submitted_num - truth_num) < 0.01:
                return True, 1.0
            # Partial credit for being close (within 10%)
            if truth_num != 0:
                error_pct = abs(submitted_num - truth_num) / abs(truth_num)
                if error_pct < 0.1:
                    return False, 0.5
            # BUG FIX: a plain number that failed the numeric check must not
            # fall through to containment -- normalized "5.0" is a substring
            # of "15.0", which previously graded 15 as correct for truth 5.
            return False, 0.0

    # Free-text submissions may embed the answer (guard against "" which
    # is a substring of everything).
    if norm_truth and norm_truth in norm_submitted:
        return True, 1.0

    return False, 0.0
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def grade_sql_result(
    query_result: str,
    ground_truth: Any,
    is_error: bool
) -> Tuple[bool, float]:
    """
    Grade a SQL query result (markdown table text) against ground truth.

    Used when the agent submits a SQL query as its final answer: the query
    output is parsed and its first data row is checked cell-by-cell.

    Args:
        query_result: The result string from executing the SQL query.
        ground_truth: The expected correct answer.
        is_error: Whether query execution resulted in an error.

    Returns:
        Tuple[bool, float]: (is_correct, score)
    """
    if is_error:
        return False, 0.0

    # Keep non-empty lines that are not the markdown separator row.
    content_lines = [
        line for line in query_result.strip().split('\n')
        if line.strip() and not line.startswith('|---')
    ]

    # Need at least a header plus one data row.
    if len(content_lines) < 2:
        return False, 0.0

    # First data row follows the header.
    first_data_row = content_lines[1]
    cells = [cell.strip() for cell in first_data_row.split('|') if cell.strip()]
    if not cells:
        return False, 0.0

    # Single-value answers live in one cell; for multi-column rows, accept
    # a match in any cell.
    for cell in cells:
        matched, cell_score = compare_values(cell, ground_truth)
        if matched:
            return True, cell_score

    return False, 0.0
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def grade_answer(
    submitted_answer: str,
    ground_truth: Any,
    db_engine: Any = None
) -> Tuple[bool, float]:
    """
    Grade the agent's submitted answer (main entry point for the environment).

    When the submission looks like SQL and a database engine is available,
    the query is executed and its result graded; otherwise the text is
    compared directly to the ground truth.

    Args:
        submitted_answer: The agent's submitted answer string.
        ground_truth: The expected correct answer.
        db_engine: Optional database engine for SQL evaluation.

    Returns:
        Tuple[bool, float]: (is_correct, score)
    """
    if not submitted_answer or not submitted_answer.strip():
        return False, 0.0

    candidate = submitted_answer.strip()

    # Heuristic: treat the submission as SQL when it mentions common keywords.
    looks_like_sql = any(
        kw in candidate.upper()
        for kw in ('SELECT', 'FROM', 'WHERE', 'JOIN', 'GROUP', 'ORDER')
    )

    if looks_like_sql and db_engine is not None:
        output, failed = db_engine.execute_query(candidate)
        return grade_sql_result(output, ground_truth, failed)

    return compare_values(candidate, ground_truth)
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def calculate_final_score(
    is_correct: bool,
    total_steps: int,
    max_steps: int = 15
) -> float:
    """
    Combine correctness (primary) with a step-efficiency bonus.

    An incorrect answer always scores 0.0. A correct answer starts at a
    0.7 base and earns up to 0.3 extra for using fewer steps.

    Args:
        is_correct: Whether the answer was correct.
        total_steps: Number of steps taken.
        max_steps: Maximum allowed steps (scales the bonus).

    Returns:
        float: Final score clamped to [0.0, 1.0].
    """
    if not is_correct:
        return 0.0

    # Unused step budget, as a fraction; negative when over budget.
    unused_budget = 1.0 - (total_steps / max_steps)
    bonus = max(0.0, unused_budget * 0.3)

    # Clamp to the valid scoring range.
    return min(1.0, max(0.0, 0.7 + bonus))
|
openenv-sql-analyst/environment/models.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment/models.py
|
| 2 |
+
# Typed Pydantic models for OpenEnv interface
|
| 3 |
+
# Implements Action, Observation, and Reward schemas
|
| 4 |
+
|
| 5 |
+
from typing import Optional
|
| 6 |
+
from pydantic import BaseModel, model_validator
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Action(BaseModel):
    """
    Agent action for the SQL Analyst environment.

    Exactly ONE of the two fields must be populated:
        - sql_query: execute a SQL query against the database
        - submit_answer: submit a final answer for grading

    Edge Case Shield: a pydantic model_validator rejects ambiguous
    actions (both fields set, or neither).
    """
    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None

    @model_validator(mode='after')
    def validate_exactly_one_action(self) -> 'Action':
        """Reject actions that set both fields, or neither (after stripping)."""
        query_given = bool(self.sql_query and self.sql_query.strip() != "")
        answer_given = bool(self.submit_answer and self.submit_answer.strip() != "")

        if query_given and answer_given:
            raise ValueError(
                "Invalid action: Provide exactly ONE of 'sql_query' or 'submit_answer', not both."
            )
        if not query_given and not answer_given:
            raise ValueError(
                "Invalid action: Must provide exactly ONE of 'sql_query' or 'submit_answer'."
            )
        return self
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Observation(BaseModel):
    """
    Snapshot of the environment state shown to the agent each step.

    Fields:
        - schema_info: Database schema text (tables, columns, types)
        - current_question: The task question the agent must answer
        - last_query_result: Output of the most recent SQL query
        - error_message: Error from the last action ("" when none)
    """
    schema_info: str
    current_question: str
    last_query_result: str
    error_message: str
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class Reward(BaseModel):
    """
    Scalar reward wrapper (single float value).

    Values follow the PRD reward shaping:
        - +0.1: successful, error-free SQL query
        - -0.1: SQLite syntax error
        - -1.0: destructive action detected (done=True)
        - -0.5: step count >= 15 (infinite loop shield, done=True)
    """
    value: float
|
openenv-sql-analyst/environment/tasks.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# environment/tasks.py
|
| 2 |
+
# Task definitions for SQL Data Analyst environment
|
| 3 |
+
# 3 Tasks: Easy (single table COUNT), Medium (JOIN + aggregation), Hard (subquery/ordering)
|
| 4 |
+
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from typing import List, Callable, Any
|
| 7 |
+
import random
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@dataclass
class Task:
    """A single data-analysis challenge posed to the agent.

    Attributes:
        task_id: Unique identifier for the task.
        difficulty: One of "easy", "medium", or "hard".
        question: The business question the agent must answer.
        ground_truth: The expected correct answer.
        ground_truth_sql: A reference SQL query that yields the answer.
        description: Extra context about what the task exercises.
    """
    task_id: str
    difficulty: str
    question: str
    ground_truth: Any
    ground_truth_sql: str
    description: str


# --------------------------------------------
# Task definitions
# --------------------------------------------

TASK_EASY = Task(
    task_id="easy_user_count",
    difficulty="easy",
    question=(
        "How many users are registered in the system? "
        "Provide the total count as a single number."
    ),
    ground_truth=15,
    ground_truth_sql="SELECT COUNT(*) FROM users",
    description="Single table COUNT query on users table"
)

TASK_MEDIUM = Task(
    task_id="medium_usa_revenue",
    difficulty="medium",
    question=(
        "What is the total revenue (sum of total_amount) from purchases made by users in the USA? "
        "Provide the total as a number (rounded to 2 decimal places if needed)."
    ),
    # Sum of purchases by USA users (user_ids: 1, 4, 7, 10, 14).
    ground_truth=2423.87,
    ground_truth_sql="""
        SELECT ROUND(SUM(p.total_amount), 2) as total_revenue
        FROM purchases p
        JOIN users u ON p.user_id = u.user_id
        WHERE u.country = 'USA'
    """,
    description="Two-table JOIN with SUM aggregation filtered by country"
)

TASK_HARD = Task(
    task_id="hard_top_spender",
    difficulty="hard",
    question=(
        "Who is the top spender (user with highest total purchase amount)? "
        "Provide the username of the user who spent the most money in total."
    ),
    # alice has purchases totaling 1509.96 (1299.99 + 59.98 + 149.99).
    ground_truth="alice",
    ground_truth_sql="""
        SELECT u.username
        FROM users u
        JOIN purchases p ON u.user_id = p.user_id
        GROUP BY u.user_id, u.username
        ORDER BY SUM(p.total_amount) DESC
        LIMIT 1
    """,
    description="Complex query with JOIN, GROUP BY, ORDER BY, and LIMIT"
)


# Every task the environment can serve, in ascending difficulty.
TASKS: List[Task] = [TASK_EASY, TASK_MEDIUM, TASK_HARD]


def get_task_by_id(task_id: str) -> Task:
    """Return the task whose ``task_id`` matches.

    Args:
        task_id: The unique task identifier.

    Returns:
        Task: The matching task.

    Raises:
        ValueError: If no task carries that id.
    """
    found = next((candidate for candidate in TASKS if candidate.task_id == task_id), None)
    if found is None:
        raise ValueError(f"Task not found: {task_id}")
    return found


def get_task_by_difficulty(difficulty: str) -> Task:
    """Return the first task at the given difficulty level.

    Args:
        difficulty: "easy", "medium", or "hard".

    Returns:
        Task: A task matching the difficulty.

    Raises:
        ValueError: If no task has that difficulty.
    """
    for candidate in TASKS:
        if candidate.difficulty == difficulty:
            return candidate
    raise ValueError(f"No task found for difficulty: {difficulty}")


def get_random_task() -> Task:
    """Return one task drawn uniformly at random from TASKS."""
    return random.choice(TASKS)


def get_all_tasks() -> List[Task]:
    """Return a shallow copy of the task list so callers cannot mutate it."""
    return list(TASKS)
|
openenv-sql-analyst/inference.py
ADDED
|
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# inference.py
|
| 3 |
+
# Baseline Inference Script for OpenEnv SQL Analyst
|
| 4 |
+
# Uses OpenAI API client to run model against the environment
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import sys
|
| 8 |
+
import json
|
| 9 |
+
from typing import Optional
|
| 10 |
+
|
| 11 |
+
# Add the project root to path for imports
|
| 12 |
+
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
| 13 |
+
|
| 14 |
+
from openai import OpenAI
|
| 15 |
+
from environment.env import SQLAnalystEnv
|
| 16 |
+
from environment.models import Action
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# ============================================
# CONFIGURATION
# ============================================
# API_BASE_URL must point at an OpenAI-compatible endpoint.
API_BASE_URL = os.environ.get("API_BASE_URL")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4o-mini")
# Accept the same fallback chain as the repo-root inference.py so any of
# API_KEY, HF_TOKEN, or OPENAI_API_KEY can authenticate the client.
API_KEY = (
    os.environ.get("API_KEY")
    or os.environ.get("HF_TOKEN")
    or os.environ.get("OPENAI_API_KEY")
)

if not API_BASE_URL:
    raise ValueError("API_BASE_URL environment variable is required")
if not API_KEY:
    raise ValueError("Set API_KEY, HF_TOKEN, or OPENAI_API_KEY environment variable")

# Environment configuration
BENCHMARK_NAME = "sql_analyst"
MAX_STEPS = 15  # hard cap on agent turns; mirrors the env's infinite-loop shield
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# ============================================
# SYSTEM PROMPT
# ============================================
# Template rendered with str.format() in run_inference(); placeholders are
# {schema_info}, {current_question}, {last_query_result}, {error_section}.
# BUG FIX: the literal JSON examples previously used single braces, which
# str.format() parsed as replacement fields and raised KeyError on every
# step. Literal braces in a format template must be doubled ({{ }}).
SYSTEM_PROMPT = """You are an expert SQL Data Analyst AI agent. Your task is to answer business questions by querying a SQLite database.

You have two possible actions each turn:
1. Execute a SQL query to explore the data: {{"sql_query": "SELECT ..."}}
2. Submit your final answer: {{"submit_answer": "your answer"}}

IMPORTANT RULES:
- Only use SELECT queries. INSERT, UPDATE, DELETE, DROP, ALTER, TRUNCATE are blocked.
- Explore the data step by step before submitting your final answer.
- Your final answer should be just the value requested (a number, name, etc.), not a SQL query.
- Respond with ONLY a valid JSON object, no other text.

DATABASE SCHEMA:
{schema_info}

CURRENT QUESTION:
{current_question}

LAST QUERY RESULT:
{last_query_result}

{error_section}

Respond with a JSON object containing either "sql_query" or "submit_answer"."""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def format_action_str(action: Action) -> str:
    """Render a one-line, truncated description of *action* for log output."""
    def _clip(text: str, limit: int) -> str:
        # Keep log lines short: cut to `limit` chars, ending in "...".
        return text if len(text) <= limit else text[: limit - 3] + "..."

    if action.sql_query:
        flattened = action.sql_query.replace("\n", " ").strip()
        return f"sql_query={_clip(flattened, 50)}"
    if action.submit_answer:
        return f"submit_answer={_clip(str(action.submit_answer).strip(), 30)}"
    return "invalid_action"
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def parse_model_response(response_text: str) -> Optional[Action]:
    """
    Parse the model's response into an Action.

    Handles responses wrapped in markdown code fences (```json ... ``` or
    ``` ... ```), including the edge case where the closing fence is
    missing.

    Args:
        response_text: The raw text response from the model

    Returns:
        Action, or None if parsing fails (invalid JSON, or an Action that
        violates the exactly-one-field validator)
    """
    try:
        text = response_text.strip()

        # Strip a markdown code fence if present. Previously a missing
        # closing fence made str.find return -1 and text[start:-1] silently
        # dropped the last character of the JSON payload; slice to the end
        # of the string instead.
        for fence in ("```json", "```"):
            if fence in text:
                start = text.find(fence) + len(fence)
                end = text.find("```", start)
                text = (text[start:end] if end != -1 else text[start:]).strip()
                break

        data = json.loads(text)

        return Action(
            sql_query=data.get("sql_query"), submit_answer=data.get("submit_answer")
        )
    except (json.JSONDecodeError, ValueError):
        # ValueError also covers pydantic validation failures from Action.
        return None
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def run_inference():
    """
    Run the baseline inference loop.

    This function:
    1. Initializes the environment
    2. Runs the model against the environment
    3. Outputs structured logs in the exact required format

    Returns:
        tuple[bool, float]: (success, final_score) as reported by the
        environment's `info` dict on the terminal step.
    """
    # Initialize OpenAI client
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Initialize environment
    env = SQLAnalystEnv()

    # Reset environment and get initial observation
    observation = env.reset()

    # Get task info from state
    # NOTE(review): assumes env.state() returns a dict with a "task_id" key
    # — confirm against SQLAnalystEnv.state().
    state = env.state()
    task_name = state.get("task_id", "unknown")

    # ============================================
    # [START] LOG - EXACT FORMAT REQUIRED
    # ============================================
    print(f"[START] task={task_name} env={BENCHMARK_NAME} model={MODEL_NAME}")

    # Track rewards and steps
    rewards = []
    step_num = 0
    done = False
    success = False
    final_score = 0.0

    while not done and step_num < MAX_STEPS:
        step_num += 1

        # Build the prompt
        error_section = ""
        if observation.error_message:
            error_section = f"ERROR FROM LAST ACTION:\n{observation.error_message}"

        prompt = SYSTEM_PROMPT.format(
            schema_info=observation.schema_info,
            current_question=observation.current_question,
            last_query_result=observation.last_query_result,
            error_section=error_section,
        )

        try:
            # Call the model (temperature 0 for deterministic baseline runs)
            response = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a SQL expert. Respond only with valid JSON.",
                    },
                    {"role": "user", "content": prompt},
                ],
                temperature=0.0,
                max_tokens=500,
            )

            # Extract response text
            response_text = response.choices[0].message.content

            # Parse into Action
            action = parse_model_response(response_text)

            if action is None:
                # Failed to parse, try a simple query as fallback
                action = Action(sql_query="SELECT 1")
                error_msg = "parse_error"
            else:
                error_msg = "null"

            # Execute action in environment: gym-style
            # (observation, reward, done, info) tuple.
            observation, reward, done, info = env.step(action)

            # Track reward
            reward_value = reward.value
            rewards.append(reward_value)

            # Check for errors in observation (truncated to keep the log line short)
            if observation.error_message:
                error_msg = observation.error_message.replace("\n", " ")[:50]

            # ============================================
            # [STEP] LOG - EXACT FORMAT REQUIRED
            # ============================================
            action_str = format_action_str(action)
            done_str = "true" if done else "false"
            print(
                f"[STEP] step={step_num} action={action_str} reward={reward_value:.2f} done={done_str} error={error_msg}"
            )

            # Update final results
            if done:
                success = info.get("success", False)
                final_score = info.get("final_score", 0.0)

        except Exception as e:
            # Handle API or other errors
            error_msg = str(e).replace("\n", " ")[:50]
            print(
                f"[STEP] step={step_num} action=error reward=0.00 done=false error={error_msg}"
            )
            rewards.append(0.0)

            # Try to continue with a simple action so the episode terminates
            # cleanly and [END] can still be emitted.
            try:
                action = Action(submit_answer="error")
                observation, reward, done, info = env.step(action)
                success = info.get("success", False)
                final_score = info.get("final_score", 0.0)
            # NOTE(review): bare except deliberately swallows everything here
            # so the required [END] log is always produced.
            except:
                done = True
                success = False
                final_score = 0.0

    # ============================================
    # [END] LOG - EXACT FORMAT REQUIRED
    # ============================================
    success_str = "true" if success else "false"
    rewards_str = ",".join([f"{r:.2f}" for r in rewards])
    print(
        f"[END] success={success_str} steps={step_num} score={final_score:.2f} rewards={rewards_str}"
    )

    # Cleanup
    env.close()

    return success, final_score
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def main():
    """CLI entry point: run the loop, always exiting 0 for the validator."""
    try:
        success, score = run_inference()
    except Exception as e:
        # Emergency fallback - still output required logs
        print(f"[START] task=error env={BENCHMARK_NAME} model={MODEL_NAME}")
        print(f"[STEP] step=1 action=error reward=0.00 done=true error={str(e)[:50]}")
        print(f"[END] success=false steps=1 score=0.00 rewards=0.00")
        sys.exit(0)
    else:
        sys.exit(0 if success else 0)  # Always exit 0 for validation script


if __name__ == "__main__":
    main()
|
openenv-sql-analyst/openenv.yaml
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv Specification for SQL Data Analyst Environment
|
| 2 |
+
# Hackathon: Meta x Scaler - OpenEnv Framework
|
| 3 |
+
|
| 4 |
+
name: sql_analyst
|
| 5 |
+
version: "1.0.0"
|
| 6 |
+
description: >
|
| 7 |
+
A Reinforcement Learning environment simulating a Data Analyst workspace
|
| 8 |
+
where an AI agent queries a SQLite database to answer business questions.
|
| 9 |
+
|
| 10 |
+
tags:
|
| 11 |
+
- openenv
|
| 12 |
+
- sql
|
| 13 |
+
- data-analyst
|
| 14 |
+
- reinforcement-learning
|
| 15 |
+
|
| 16 |
+
infrastructure:
|
| 17 |
+
vcpu: 2
|
| 18 |
+
memory: 8gb
|
| 19 |
+
timeout: 1200 # 20 minutes max runtime
|
| 20 |
+
|
| 21 |
+
entry_point: environment.env:SQLAnalystEnv
|
| 22 |
+
|
| 23 |
+
models:
|
| 24 |
+
action: environment.models:Action
|
| 25 |
+
observation: environment.models:Observation
|
| 26 |
+
reward: environment.models:Reward
|
| 27 |
+
|
| 28 |
+
schemas:
|
| 29 |
+
action:
|
| 30 |
+
type: object
|
| 31 |
+
properties:
|
| 32 |
+
sql_query:
|
| 33 |
+
type: string
|
| 34 |
+
description: SQL query to execute against the database
|
| 35 |
+
nullable: true
|
| 36 |
+
submit_answer:
|
| 37 |
+
type: string
|
| 38 |
+
description: Final answer to submit for grading
|
| 39 |
+
nullable: true
|
| 40 |
+
required: []
|
| 41 |
+
additionalProperties: false
|
| 42 |
+
|
| 43 |
+
observation:
|
| 44 |
+
type: object
|
| 45 |
+
properties:
|
| 46 |
+
schema_info:
|
| 47 |
+
type: string
|
| 48 |
+
description: Database schema information
|
| 49 |
+
current_question:
|
| 50 |
+
type: string
|
| 51 |
+
description: The current task question to answer
|
| 52 |
+
last_query_result:
|
| 53 |
+
type: string
|
| 54 |
+
description: Result from the last SQL query execution
|
| 55 |
+
error_message:
|
| 56 |
+
type: string
|
| 57 |
+
description: Error message from last action, if any
|
| 58 |
+
required:
|
| 59 |
+
- schema_info
|
| 60 |
+
- current_question
|
| 61 |
+
- last_query_result
|
| 62 |
+
- error_message
|
| 63 |
+
|
| 64 |
+
reward:
|
| 65 |
+
type: object
|
| 66 |
+
properties:
|
| 67 |
+
value:
|
| 68 |
+
type: number
|
| 69 |
+
description: Reward value for the action taken
|
| 70 |
+
required:
|
| 71 |
+
- value
|
| 72 |
+
|
| 73 |
+
endpoints:
|
| 74 |
+
reset:
|
| 75 |
+
method: POST
|
| 76 |
+
path: /reset
|
| 77 |
+
description: Reset the environment and get initial observation
|
| 78 |
+
response: observation
|
| 79 |
+
|
| 80 |
+
step:
|
| 81 |
+
method: POST
|
| 82 |
+
path: /step
|
| 83 |
+
description: Execute an action and receive observation, reward, done, info
|
| 84 |
+
request: action
|
| 85 |
+
response:
|
| 86 |
+
type: object
|
| 87 |
+
properties:
|
| 88 |
+
observation: observation
|
| 89 |
+
reward: reward
|
| 90 |
+
done:
|
| 91 |
+
type: boolean
|
| 92 |
+
info:
|
| 93 |
+
type: object
|
| 94 |
+
|
| 95 |
+
state:
|
| 96 |
+
method: GET
|
| 97 |
+
path: /state
|
| 98 |
+
description: Get the current internal state of the environment
|
openenv-sql-analyst/pyproject.toml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = ["setuptools>=61.0"]
|
| 3 |
+
build-backend = "setuptools.build_meta"
|
| 4 |
+
|
| 5 |
+
[project]
|
| 6 |
+
name = "openenv_sql_analyst"
|
| 7 |
+
version = "0.1.0"
|
| 8 |
+
description = "OpenEnv SQL Data Analyst Agent"
|
| 9 |
+
requires-python = ">=3.10"
|
| 10 |
+
dependencies = [
|
| 11 |
+
"openenv-core",
|
| 12 |
+
"pydantic",
|
| 13 |
+
"openai"
|
| 14 |
+
]
|
| 15 |
+
|
| 16 |
+
[project.scripts]
|
| 17 |
+
server = "server.app:main"
|
| 18 |
+
|
| 19 |
+
[tool.setuptools]
|
| 20 |
+
packages = ["environment", "server"]
|
openenv-sql-analyst/requirements.txt
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# OpenEnv SQL Analyst Environment Dependencies
|
| 2 |
+
# Optimized for 8GB RAM constraint
|
| 3 |
+
|
| 4 |
+
# Core framework
|
| 5 |
+
openenv-core>=0.1.0
|
| 6 |
+
|
| 7 |
+
# Pydantic for typed models
|
| 8 |
+
pydantic>=2.0.0
|
| 9 |
+
|
| 10 |
+
# OpenAI client for inference
|
| 11 |
+
openai>=1.0.0
|
| 12 |
+
|
| 13 |
+
# Database (sqlite3 is built-in, no extra deps needed)
|
| 14 |
+
|
| 15 |
+
# HTTP server dependencies (typically bundled with openenv-core)
|
| 16 |
+
uvicorn>=0.23.0
|
| 17 |
+
fastapi>=0.100.0
|
| 18 |
+
|
| 19 |
+
# Utilities
|
| 20 |
+
python-dotenv>=1.0.0
|
openenv-sql-analyst/server/app.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uvicorn
|
| 3 |
+
from fastapi import FastAPI
|
| 4 |
+
from environment.env import SQLAnalystEnv
|
| 5 |
+
from environment.models import Action
|
| 6 |
+
|
| 7 |
+
# Initialize the API and our RL Environment
# NOTE(review): a single module-level env instance is shared by every
# request, so concurrent clients would share episode state — confirm the
# Space only serves one agent at a time.
app = FastAPI(title="OpenEnv SQL Analyst")
env = SQLAnalystEnv()

@app.get("/")
def health_check():
    """Hackathon requirement: Ping must return 200 OK"""
    return {"status": "ok", "message": "OpenEnv SQL Analyst is live!"}

@app.post("/reset")
def reset():
    """Hackathon requirement: Must respond to reset()"""
    return env.reset()

@app.post("/step")
def step(action: Action):
    """Executes the agent's action and returns the new state"""
    # env.step returns a gym-style (observation, reward, done, info) tuple.
    obs, reward, done, info = env.step(action)
    return {
        "observation": obs,
        "reward": reward,
        "done": done,
        "info": info
    }

@app.get("/state")
def state():
    """Expose the environment's internal state (GET /state per openenv.yaml)."""
    return env.state()

def main():
    """Start the server on port 7860 (the port Hugging Face Spaces expects)."""
    # NOTE(review): the leading character below looks like mojibake of an
    # emoji in the original source — confirm the file's encoding.
    print("π Starting OpenEnv Production Server on port 7860...")
    uvicorn.run(app, host="0.0.0.0", port=7860)

if __name__ == "__main__":
    main()
|
openenv-sql-analyst/validate.sh
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
# OpenEnv Hackathon Pre-Submission Validation Script
# Based on Meta x Scaler Hackathon Round 1 Guidelines
#
# Runs four checks in order: prerequisites, Docker build, `openenv validate`,
# and a baseline inference run whose stdout must follow the
# [START]/[STEP]/[END] log contract. Any failure exits non-zero immediately.

# Colors for output
GREEN='\033[0;32m'
RED='\033[0;31m'
BOLD='\033[1m'
NC='\033[0m'  # reset

echo -e "${BOLD}Starting Validation...${NC}\n"

# ---------------------------------------------
# STEP 1: Prerequisite Check
# ---------------------------------------------
echo -e "${BOLD}Step 1/4: Checking Prerequisites...${NC}"

if ! command -v docker &>/dev/null; then
    echo -e "${RED}[FAIL] Docker command not found. Install it: https://docs.docker.com/get-docker/${NC}"
    exit 1
fi

if ! command -v openenv &>/dev/null; then
    echo -e "${RED}[FAIL] openenv-core not found. Install it: pip install openenv-core${NC}"
    exit 1
fi

echo -e "${GREEN}[PASS] Prerequisites found.${NC}\n"

# ---------------------------------------------
# STEP 2: Docker Build Check
# ---------------------------------------------
echo -e "${BOLD}Step 2/4: Running Docker Build...${NC}"

# Accept a Dockerfile in either the repo root or server/.
if [ -f "Dockerfile" ]; then
    DOCKER_CONTEXT="."
elif [ -f "server/Dockerfile" ]; then
    DOCKER_CONTEXT="server"
else
    echo -e "${RED}[FAIL] No Dockerfile found in root or server/ directory.${NC}"
    exit 1
fi

docker build -t openenv-validator "$DOCKER_CONTEXT"

if [ $? -eq 0 ]; then
    echo -e "${GREEN}[PASS] Docker build succeeded.${NC}\n"
else
    echo -e "${RED}[FAIL] Docker build failed. Check your Dockerfile.${NC}"
    exit 1
fi

# ---------------------------------------------
# STEP 3: OpenEnv Spec Validation
# ---------------------------------------------
echo -e "${BOLD}Step 3/4: Running openenv validate...${NC}"

openenv validate

if [ $? -eq 0 ]; then
    echo -e "${GREEN}[PASS] OpenEnv spec compliance verified (yaml, models, endpoints).${NC}\n"
else
    echo -e "${RED}[FAIL] OpenEnv validation failed. Check openenv.yaml and models.py.${NC}"
    exit 1
fi

# ---------------------------------------------
# STEP 4: Baseline Inference & Log Format Check
# ---------------------------------------------
echo -e "${BOLD}Step 4/4: Running Baseline Inference Check...${NC}"

if [ ! -f "inference.py" ]; then
    echo -e "${RED}[FAIL] inference.py NOT found in root directory.${NC}"
    exit 1
fi

# Run inference and capture output (stdout + stderr) to check STDOUT format
OUTPUT=$(python inference.py 2>&1)
EXIT_CODE=$?

if [ $EXIT_CODE -ne 0 ]; then
    echo -e "${RED}[FAIL] inference.py failed to execute without errors.${NC}"
    echo "$OUTPUT"
    exit 1
fi

# Verify mandatory log tags: [START], [STEP], [END]
if [[ "$OUTPUT" == *"[START]"* ]] && [[ "$OUTPUT" == *"[STEP]"* ]] && [[ "$OUTPUT" == *"[END]"* ]]; then
    echo -e "${GREEN}[PASS] Mandatory STDOUT log format ([START], [STEP], [END]) detected.${NC}"
else
    echo -e "${RED}[FAIL] STDOUT format incorrect. Must strictly follow [START], [STEP], [END] lines.${NC}"
    exit 1
fi

# Verify score is within valid 0.0-1.0 range
# NOTE(review): if no "score=" token appears in OUTPUT this check is
# silently skipped rather than failed — confirm that is intentional.
if [[ "$OUTPUT" =~ "score="([0-9]*\.[0-9]+|[0-9]+) ]]; then
    SCORE=${BASH_REMATCH[1]}
    # awk handles the floating-point comparison bash cannot do natively.
    if awk "BEGIN {exit !($SCORE >= 0.0 && $SCORE <= 1.0)}"; then
        echo -e "${GREEN}[PASS] Score ($SCORE) is within valid 0.0-1.0 range.${NC}"
    else
        echo -e "${RED}[FAIL] Invalid score: $SCORE. Must be between 0.0 and 1.0.${NC}"
        exit 1
    fi
fi

# ---------------------------------------------
# ALL CHECKS PASSED
# ---------------------------------------------
echo -e "\n${GREEN}${BOLD}========================================${NC}"
echo -e "${GREEN}${BOLD}  ALL 4/4 CHECKS PASSED!${NC}"
echo -e "${GREEN}${BOLD}  YOUR SUBMISSION IS READY.${NC}"
echo -e "${GREEN}${BOLD}========================================${NC}"
|