# SQL Data Analyst Agent – OpenEnv Hackathon Build Guide

> **Hackathon:** Meta OpenEnv Hackathon
> **Environment name:** `sql-data-analyst`
> **Goal:** Build a real-world RL environment where an AI agent answers business questions by writing and executing SQL against a live database.

---

## Table of Contents

1. [What We Are Building](#1-what-we-are-building)
2. [Requirements Checklist](#2-requirements-checklist)
3. [Database Design](#3-database-design)
4. [The 3 Tasks with Graders](#4-the-3-tasks-with-graders)
5. [Pydantic Models (OpenEnv Spec)](#5-pydantic-models-openenv-spec)
6. [Environment Core (environment.py)](#6-environment-core)
7. [Reward Function](#7-reward-function)
8. [Key Optimisations](#8-key-optimisations)
9. [Baseline Inference Script](#9-baseline-inference-script)
10. [openenv.yaml](#10-openenvyaml)
11. [Dockerfile](#11-dockerfile)
12. [README Template](#12-readme-template)
13. [Full File Structure](#13-full-file-structure)
14. [Build Order (Step-by-Step)](#14-build-order)

---
## 1. What We Are Building

An **OpenEnv-compliant RL training environment** where an AI agent:

- Receives a natural language business question and a live SQLite database schema
- Writes SQL queries, executes them, and observes the results
- Iterates until it can submit a final answer
- Gets scored 0.0–1.0 based on correctness and efficiency

**Why this wins:**

- Deterministic grading – SQL answers are right or wrong, no ambiguity
- Partial rewards are natural at every step (table hit → no error → correct answer)
- Directly applicable to real business intelligence workflows
- Clean difficulty curve across 3 tasks

---
## 2. Requirements Checklist

| # | Requirement | Implementation |
|---|---|---|
| 1 | Real-world task | SQL data analysis – used by every company daily |
| 2 | OpenEnv spec: typed models | Pydantic `Observation`, `Action`, `StepResult` |
| 3 | OpenEnv spec: `step()` | Returns `(observation, reward, done, info)` |
| 4 | OpenEnv spec: `reset()` | Returns initial observation, reseeds DB |
| 5 | OpenEnv spec: `state()` | Returns current full env state |
| 6 | `openenv.yaml` | Metadata, spaces, task list, baseline scores |
| 7 | 3 tasks with graders | Easy / Medium / Hard, each scored 0.0–1.0 |
| 8 | Meaningful reward | Partial credit at every step, not just at the end |
| 9 | Baseline inference script | OpenAI API client, reproducible scores |
| 10 | HuggingFace Space | Containerised, tagged `openenv` |
| 11 | Dockerfile | `docker build` + `docker run` works cleanly |
| 12 | README | Spaces, tasks, setup, baseline scores |

---
## 3. Database Design

Use a realistic SaaS e-commerce schema. This single schema supports all 3 tasks.

### Schema

```sql
-- users table
CREATE TABLE users (
    id INTEGER PRIMARY KEY,
    email TEXT NOT NULL,
    country TEXT,
    plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
    created_at TIMESTAMP NOT NULL,
    churned_at TIMESTAMP  -- NULL if still active
);

-- products table
CREATE TABLE products (
    id INTEGER PRIMARY KEY,
    name TEXT NOT NULL,
    category TEXT NOT NULL,  -- Electronics, Clothing, Books, etc.
    price DECIMAL(10,2),
    cost DECIMAL(10,2)
);

-- orders table
CREATE TABLE orders (
    id INTEGER PRIMARY KEY,
    user_id INTEGER REFERENCES users(id),
    created_at TIMESTAMP NOT NULL,
    status TEXT CHECK(status IN ('pending','completed','refunded')),
    total DECIMAL(10,2)
);

-- order_items table
CREATE TABLE order_items (
    id INTEGER PRIMARY KEY,
    order_id INTEGER REFERENCES orders(id),
    product_id INTEGER REFERENCES products(id),
    qty INTEGER NOT NULL,
    unit_price DECIMAL(10,2)
);

-- events table (user behaviour)
CREATE TABLE events (
    id INTEGER PRIMARY KEY,
    user_id INTEGER REFERENCES users(id),
    event_type TEXT,  -- page_view, add_to_cart, checkout, etc.
    metadata JSON,
    ts TIMESTAMP NOT NULL
);
```
### Seeding

Seed with realistic volumes using the `faker` library:

```python
# database.py – seed targets
SEED_CONFIG = {
    "users": 500,         # ~500 users
    "products": 80,       # 80 products across 5 categories
    "orders": 2000,       # ~2000 orders
    "order_items": 5000,  # ~5000 line items
    "events": 8000,       # ~8000 behavioural events
}

# Intentional messiness (makes it realistic)
# - ~5% of users have NULL country
# - ~3% of orders have status='refunded'
# - churned_at is NULL for active users
# - Some users have 0 orders (registered but never bought)
```
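A minimal seeding sketch for the `users` table (hedged: helper names like `seed_users` are illustrative, and plain `random` stands in for `faker` here so the example has no third-party dependency — the real `database.py` can swap in `faker` for realistic emails):

```python
import random
import sqlite3
from datetime import datetime, timedelta


def create_database() -> sqlite3.Connection:
    """Create an in-memory database with a cut-down users table."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE users (id INTEGER PRIMARY KEY, email TEXT NOT NULL, "
        "country TEXT, plan TEXT, created_at TIMESTAMP NOT NULL, churned_at TIMESTAMP)"
    )
    return conn


def seed_users(conn: sqlite3.Connection, n: int = 500, seed: int = 42) -> None:
    """Seed n users with the intentional messiness described above."""
    rng = random.Random(seed)  # fixed seed keeps ground truth reproducible
    now = datetime.now()
    countries = ["US", "DE", "IN", "BR", "JP"]
    for i in range(n):
        created = now - timedelta(days=rng.randint(0, 365))
        conn.execute(
            "INSERT INTO users (email, country, plan, created_at, churned_at) "
            "VALUES (?, ?, ?, ?, ?)",
            (
                f"user{i}@example.com",
                None if rng.random() < 0.05 else rng.choice(countries),  # ~5% NULL country
                rng.choice(["free", "pro", "enterprise"]),
                created.isoformat(),
                None,  # churned_at stays NULL for active users
            ),
        )
    conn.commit()


conn = create_database()
seed_users(conn)
count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
```

Seeding the other tables follows the same pattern; the fixed RNG seed is what makes graded ground truths stable between `reset()` calls within a build.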
---

## 4. The 3 Tasks with Graders

### Task 1 – Easy: Monthly Signups

**Question:** `"How many users signed up in the last 30 days?"`

**Required SQL skills:** Single table, `COUNT`, `WHERE`, date filtering

**Expected SQL:**

```sql
SELECT COUNT(*) FROM users
WHERE created_at >= DATE('now', '-30 days');
```

**Grader:**

```python
def grade_easy(submitted_answer: str, ground_truth: int) -> float:
    try:
        val = int(submitted_answer.strip().replace(",", ""))
        if val == ground_truth:
            return 1.0
        if abs(val - ground_truth) <= 3:   # within 3 = partial credit
            return 0.6
        if abs(val - ground_truth) <= 10:  # within 10 = small credit
            return 0.3
    except (ValueError, AttributeError):
        pass
    return 0.0
```

**Max steps:** 10
**Difficulty:** Easy

---
### Task 2 – Medium: Top Revenue Category

**Question:** `"Which product category generated the most revenue in Q3 (July–September)?"`

**Required SQL skills:** `JOIN` across 3 tables, `GROUP BY`, `ORDER BY`, `SUM`, date range filtering

**Expected SQL:**

```sql
SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
FROM orders o
JOIN order_items oi ON o.id = oi.order_id
JOIN products p ON oi.product_id = p.id
WHERE o.created_at >= '2024-07-01'
  AND o.created_at < '2024-10-01'  -- half-open range; BETWEEN '…-09-30' would drop Sep 30 timestamps after midnight
  AND o.status = 'completed'
GROUP BY p.category
ORDER BY revenue DESC
LIMIT 1;
```

**Grader:**

```python
import re


def grade_medium(submitted_answer: str, ground_truth: str, top_3: list) -> float:
    answer = submitted_answer.strip().lower()
    # Remove common LLM preamble
    answer = re.sub(r'the (answer|category) is:?\s*', '', answer)
    if ground_truth.lower() in answer:
        return 1.0
    if any(cat.lower() in answer for cat in top_3):
        return 0.4  # got a plausible answer, just not the top one
    return 0.0
```

**Max steps:** 15
**Difficulty:** Medium

---
### Task 3 – Hard: Churn After 3rd Purchase

**Question:** `"Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."`

**Required SQL skills:** CTEs / subqueries, `HAVING`, aggregates (`COUNT`, `MAX`), date logic (window functions optional)

**Expected SQL:**

```sql
WITH order_counts AS (
    SELECT user_id, COUNT(*) AS total_orders,
           MAX(created_at) AS last_order_date
    FROM orders
    WHERE status = 'completed'
    GROUP BY user_id
    HAVING COUNT(*) = 3
),
churned AS (
    SELECT oc.user_id
    FROM order_counts oc
    WHERE oc.last_order_date < DATE('now', '-90 days')
)
SELECT u.email
FROM users u
JOIN churned c ON u.id = c.user_id;
```

**Grader (F1 score for set matching):**

```python
def grade_hard(submitted_answer: str, ground_truth_emails: set) -> float:
    if not submitted_answer.strip():
        return 0.0
    # Parse comma-separated emails
    submitted = {
        e.strip().lower()
        for e in submitted_answer.split(",")
        if "@" in e
    }
    if not submitted:
        return 0.0
    correct = ground_truth_emails
    tp = len(submitted & correct)
    if tp == 0:
        return 0.0
    precision = tp / len(submitted)
    recall = tp / len(correct)
    f1 = 2 * precision * recall / (precision + recall)
    return round(f1, 3)
```

**Max steps:** 20
**Difficulty:** Hard
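The F1 grading can be checked by hand. A standalone sketch of the same set-matching arithmetic (the sample emails are made up for illustration):

```python
# Ground truth: three churned users; the agent found two of them plus one wrong email.
truth = {"a@x.com", "b@x.com", "c@x.com"}
submitted = {"a@x.com", "b@x.com", "d@x.com"}

tp = len(submitted & truth)      # 2 true positives
precision = tp / len(submitted)  # 2/3 of submitted emails are correct
recall = tp / len(truth)         # 2/3 of churned users were found
f1 = 2 * precision * recall / (precision + recall)

print(round(f1, 3))  # a partially correct list still earns partial credit
```

This is why F1 suits the hard task: a near-miss list scores meaningfully above zero, while spamming every email in the database is punished through precision.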
---

## 5. Pydantic Models (OpenEnv Spec)

```python
# env/models.py
from typing import Any, List, Optional

from pydantic import BaseModel, Field


class Action(BaseModel):
    """What the agent can do each step."""
    sql_query: Optional[str] = Field(
        None,
        description="A SQL SELECT query to execute against the database"
    )
    submit_answer: Optional[str] = Field(
        None,
        description="Final answer to submit. Ends the episode."
    )

    def is_valid(self) -> bool:
        # Exactly one of the two must be set
        return bool(self.sql_query) != bool(self.submit_answer)


class QueryResult(BaseModel):
    """Result of executing a SQL query."""
    columns: List[str] = []
    rows: List[List[Any]] = []
    error: Optional[str] = None
    truncated: bool = False
    total_rows: int = 0


class Observation(BaseModel):
    """What the agent sees after each step."""
    schema_summary: str = Field(..., description="Compact DB schema")
    question: str = Field(..., description="Business question to answer")
    last_query: Optional[str] = None
    last_result: Optional[QueryResult] = None
    last_error: Optional[str] = None
    step: int = 0
    max_steps: int = 20
    hints: List[str] = []
    done: bool = False


class StepResult(BaseModel):
    """Full result returned by step()."""
    observation: Observation
    reward: float = 0.0
    done: bool = False
    info: dict = {}


class EnvState(BaseModel):
    """Full environment state returned by state()."""
    task_id: str
    difficulty: str
    step: int
    max_steps: int
    query_history: List[str] = []
    total_reward: float = 0.0
    done: bool = False
```
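The XOR rule in `Action.is_valid()` is the contract the environment asserts on every step. Restated in miniature so it runs standalone (assuming Pydantic v2 is installed):

```python
from typing import Optional

from pydantic import BaseModel


class Action(BaseModel):
    sql_query: Optional[str] = None
    submit_answer: Optional[str] = None

    def is_valid(self) -> bool:
        # Exactly one of the two fields must be set
        return bool(self.sql_query) != bool(self.submit_answer)


assert Action(sql_query="SELECT 1").is_valid()       # query only: valid
assert Action(submit_answer="42").is_valid()         # answer only: valid
assert not Action().is_valid()                       # neither set: invalid
assert not Action(sql_query="SELECT 1",
                  submit_answer="42").is_valid()     # both set: invalid
```

Keeping the rule as a boolean XOR (rather than a validator that raises) lets the environment decide how to penalise malformed actions.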
---

## 6. Environment Core

```python
# env/environment.py
import sqlite3
from typing import Optional

from .models import Action, Observation, StepResult, EnvState, QueryResult
from .database import create_database, seed_database, get_schema_summary
from .reward import RewardCalculator
from .tasks import TASKS

ROW_LIMIT = 50  # max rows returned to the agent per query


class SQLAnalystEnv:
    """
    OpenEnv-compliant SQL Data Analyst environment.

    An agent must answer business questions by iteratively
    writing and executing SQL queries.
    """

    def __init__(self, task_id: str = "monthly_signups"):
        assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
        self.task_id = task_id
        self.task = TASKS[task_id]
        self.conn: Optional[sqlite3.Connection] = None
        self.step_count: int = 0
        self.total_reward: float = 0.0
        self.done: bool = False
        self._query_history: list = []
        self._reward_calc = RewardCalculator()

    # ------------------------------------------------------------------
    # OpenEnv required methods
    # ------------------------------------------------------------------
    def reset(self) -> StepResult:
        """Reset environment. Reseed DB. Return initial observation."""
        if self.conn:
            self.conn.close()
        self.conn = create_database()
        seed_database(self.conn)
        self.step_count = 0
        self.total_reward = 0.0
        self.done = False
        self._query_history = []

        # Compute ground truth AFTER seeding
        self.task.compute_ground_truth(self.conn)

        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            step=0,
            max_steps=self.task.max_steps,
        )
        return StepResult(observation=obs, reward=0.0, done=False)

    def step(self, action: Action) -> StepResult:
        """Execute one agent action. Return observation, reward, done, info."""
        assert self.conn is not None, "Call reset() before step()"
        assert not self.done, "Episode is done. Call reset()."
        assert action.is_valid(), "Action must have exactly one of: sql_query, submit_answer"

        self.step_count += 1
        query_result = None
        error = None

        # --- Execute SQL or submit answer ---
        if action.sql_query:
            query_result = self._execute_sql(action.sql_query)
            self._query_history.append(action.sql_query)
            error = query_result.error

        terminal = (
            action.submit_answer is not None
            or self.step_count >= self.task.max_steps
        )

        # --- Calculate reward ---
        reward = self._reward_calc.calculate(
            action=action,
            result=query_result,
            task=self.task,
            step=self.step_count,
            query_history=self._query_history,
            terminal=terminal,
        )
        self.total_reward += reward
        self.done = terminal

        # --- Build next observation ---
        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            last_query=action.sql_query,
            last_result=query_result,
            last_error=error,
            step=self.step_count,
            max_steps=self.task.max_steps,
            hints=self.task.get_hints(self.step_count),
            done=self.done,
        )
        return StepResult(
            observation=obs,
            reward=round(reward, 3),
            done=self.done,
            info={
                "step": self.step_count,
                "total_reward": round(self.total_reward, 3),
                "task_id": self.task_id,
            }
        )

    def state(self) -> EnvState:
        """Return current full state of the environment."""
        return EnvState(
            task_id=self.task_id,
            difficulty=self.task.difficulty,
            step=self.step_count,
            max_steps=self.task.max_steps,
            query_history=self._query_history.copy(),
            total_reward=round(self.total_reward, 3),
            done=self.done,
        )

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _execute_sql(self, query: str) -> QueryResult:
        """Execute SQL safely. Block non-SELECT. Return up to ROW_LIMIT rows."""
        # Safety: only SELECT / WITH queries are allowed
        q = query.strip().upper()
        if not q.startswith("SELECT") and not q.startswith("WITH"):
            return QueryResult(error="Only SELECT / WITH queries are allowed.")
        try:
            cursor = self.conn.execute(query)
            cols = [d[0] for d in cursor.description] if cursor.description else []
            # Fetch one extra row so truncation can be detected reliably
            rows = cursor.fetchmany(ROW_LIMIT + 1)
            truncated = len(rows) > ROW_LIMIT
            rows = rows[:ROW_LIMIT]
            return QueryResult(
                columns=cols,
                rows=[list(r) for r in rows],
                truncated=truncated,
                total_rows=len(rows),
            )
        except Exception as e:
            return QueryResult(error=str(e))
```
---

## 7. Reward Function

```python
# env/reward.py
from typing import Optional

from .models import Action, QueryResult


class RewardCalculator:
    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task,
        step: int,
        query_history: list,
        terminal: bool,
    ) -> float:
        reward = 0.0

        # ── Step-level rewards (every step) ──────────────────────────
        if action.sql_query and result:
            # +0.15 → query executed without syntax error
            if not result.error:
                reward += 0.15

            # +0.10 → query touched at least one relevant table
            relevant = self._count_relevant_tables(action.sql_query, task.relevant_tables)
            if relevant > 0:
                reward += 0.10

            # +0.05 → result has rows (not an empty result set)
            if result.rows and len(result.rows) > 0:
                reward += 0.05

            # +0.05 → result is not absurdly large (sanity check)
            if result.rows and len(result.rows) < 1000:
                reward += 0.05

        # ── Efficiency penalties ─────────────────────────────────────
        # -0.02 per step beyond step 3 (penalise excessive querying)
        if step > 3:
            reward -= 0.02 * (step - 3)

        # -0.10 if agent is stuck in a loop (same query 3x)
        if self._is_stuck(query_history):
            reward -= 0.10

        # ── Terminal reward (only when episode ends) ─────────────────
        if terminal and action.submit_answer:
            # Grade the submitted answer – worth up to 0.60
            task_score = task.grade(action.submit_answer)
            reward += task_score * 0.60

        # Clamp to [0.0, 1.0]
        return max(0.0, min(1.0, reward))

    def _count_relevant_tables(self, query: str, relevant_tables: list) -> int:
        query_lower = query.lower()
        return sum(1 for t in relevant_tables if t.lower() in query_lower)

    def _is_stuck(self, history: list) -> bool:
        if len(history) < 3:
            return False
        return len(set(history[-3:])) == 1
```
**Reward breakdown per step:**

| Signal | Max Value | Condition |
|---|---|---|
| No SQL error | +0.15 | Query executes cleanly |
| Relevant table used | +0.10 | Query touches correct table(s) |
| Non-empty result | +0.05 | Result set has at least 1 row |
| Reasonable result size | +0.05 | Result has < 1000 rows |
| Late-step penalty | -0.02/step | Each step beyond step 3 |
| Infinite loop penalty | -0.10 | Same query repeated 3+ times |
| Terminal answer score | up to +0.60 | Task grader × 0.60 |

Each individual step's reward is clamped to [0.0, 1.0], with the final submission worth up to 0.60 of it. Note that `total_reward` is the sum over steps, so an episode with several productive queries can accumulate above 1.0; an immediate surrender scores 0.0.
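For intuition, the table above can be played through by hand. A sketch of two representative steps, mirroring the `RewardCalculator` arithmetic (clamp, penalties, and the env's rounding to 3 decimals):

```python
# Step 1: clean query on a relevant table, non-empty result of sane size.
# All four step-level signals fire; no penalties yet.
r1 = 0.15 + 0.10 + 0.05 + 0.05
r1 = round(max(0.0, min(1.0, r1)), 3)  # env rounds step rewards to 3 decimals

# Step 5: the agent submits a fully correct answer (grader returns 1.0).
# The late-step penalty has been active since step 4.
r5 = 1.0 * 0.60 - 0.02 * (5 - 3)
r5 = round(max(0.0, min(1.0, r5)), 3)

print(r1, r5)
```

A good exploratory query is worth 0.35 and a perfect step-5 submission 0.56, so the signal is dense enough for RL credit assignment but the final answer still dominates.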
---

## 8. Key Optimisations

### 8.1 Schema Summarisation

Never dump raw `CREATE TABLE` SQL into the prompt – it wastes context. Use a compact summary:

```python
# env/database.py
import sqlite3


def get_schema_summary(conn: sqlite3.Connection) -> str:
    """Return one-line-per-table schema, e.g.:
    users: (id, email, country, plan, created_at, churned_at)
    """
    cursor = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    )
    tables = [r[0] for r in cursor.fetchall()]
    lines = []
    for table in tables:
        cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
        col_names = [c[1] for c in cols]  # field 1 of PRAGMA table_info is the column name
        lines.append(f"{table}: ({', '.join(col_names)})")
    return "\n".join(lines)
```
### 8.2 Answer Normalisation

Strip LLM formatting before grading – don't penalise the agent for markdown:

```python
# env/utils.py
import re


def normalize_answer(raw: str) -> str:
    """Remove common LLM answer preambles and formatting."""
    text = raw.strip().lower()
    text = re.sub(r'the (answer|result) is:?\s*', '', text)
    text = re.sub(r'\*+', '', text)                                  # bold markers
    text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)           # fenced code blocks
    text = re.sub(r'`[^`]+`', lambda m: m.group().strip('`'), text)  # inline code, keep contents
    text = re.sub(r'\s+', ' ', text)
    return text.strip()
```
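For example, a typical chatty LLM reply reduces to a cleanly gradable string (the function is reproduced here so the snippet runs standalone):

```python
import re


def normalize_answer(raw: str) -> str:
    """Remove common LLM answer preambles and formatting."""
    text = raw.strip().lower()
    text = re.sub(r'the (answer|result) is:?\s*', '', text)
    text = re.sub(r'\*+', '', text)                                  # bold markers
    text = re.sub(r'```.*?```', '', text, flags=re.DOTALL)           # fenced code blocks
    text = re.sub(r'`[^`]+`', lambda m: m.group().strip('`'), text)  # inline code
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


print(normalize_answer("The answer is: **Electronics**"))  # -> electronics
print(normalize_answer("`electronics`"))                   # -> electronics
```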
### 8.3 Progressive Hints

Reveal hints as the step count grows – this keeps episodes learnable and reward dense:

```python
# env/tasks/base.py
def get_hints(self, step: int) -> list[str]:
    hints = []
    if step > 5:
        hints.append(f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}")
    if step > 10:
        hints.append(f"Hint: Try using {self.sql_hint}")
    if step > 15:
        hints.append("Hint: Make sure to submit your answer with submit_answer.")
    return hints
```
### 8.4 Ground Truth Computed Post-Seed

Always compute ground truth **after** seeding, so it matches the actual data:

```python
# env/tasks/easy.py
def compute_ground_truth(self, conn: sqlite3.Connection):
    result = conn.execute(
        "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
    ).fetchone()
    self.ground_truth = result[0]
```
### 8.5 SQL Safety Guards

Block any mutating operations. Match keywords on word boundaries – a naive substring check would reject harmless queries, since `created_at` contains `CREATE` and `updated_at` contains `UPDATE`:

```python
import re

FORBIDDEN_KEYWORDS = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]


def is_safe_query(query: str) -> bool:
    upper = query.upper()
    # \b avoids false positives on column names like created_at / updated_at
    return not any(re.search(rf"\b{kw}\b", upper) for kw in FORBIDDEN_KEYWORDS)
```

---
## 9. Baseline Inference Script

```python
# baseline/run_baseline.py
"""
Baseline inference script for sql-data-analyst OpenEnv.

Usage:
    export OPENAI_API_KEY=sk-...
    python baseline/run_baseline.py

Produces reproducible scores across all 3 tasks.
"""
import json
import os

from openai import OpenAI

from env.environment import SQLAnalystEnv
from env.models import Action

API_KEY = os.environ["OPENAI_API_KEY"]
MODEL = "gpt-4o-mini"
MAX_STEPS = 20
TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]

client = OpenAI(api_key=API_KEY)

SYSTEM_PROMPT = """
You are a SQL data analyst. You are given a database schema and a business question.
Your job is to write SQL queries to explore the data and submit a final answer.
Rules:
- Only write SELECT or WITH queries.
- Reply with JSON only. No explanation.
- To run a query: {"sql_query": "SELECT ..."}
- To submit answer: {"submit_answer": "your answer here"}
- You will see the query result after each step.
- Submit your answer when you are confident.
"""


def build_prompt(obs) -> str:
    parts = [
        f"Database schema:\n{obs.schema_summary}",
        f"\nQuestion: {obs.question}",
        f"\nStep: {obs.step} / {obs.max_steps}",
    ]
    if obs.last_query:
        parts.append(f"\nLast query:\n{obs.last_query}")
    if obs.last_result and obs.last_result.rows:
        cols = obs.last_result.columns
        rows = obs.last_result.rows[:10]  # show max 10 rows
        parts.append(f"\nResult columns: {cols}")
        parts.append(f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}")
    if obs.last_error:
        parts.append(f"\nSQL error: {obs.last_error}")
    if obs.hints:
        parts.append(f"\nHints: {'; '.join(obs.hints)}")
    parts.append("\nWhat is your next action? Reply with JSON only.")
    return "\n".join(parts)


def parse_action(response_text: str) -> Action:
    """Extract JSON action from LLM response."""
    text = response_text.strip()
    # Strip markdown code fences if present
    text = text.replace("```json", "").replace("```", "").strip()
    try:
        data = json.loads(text)
        return Action(**data)
    except Exception:
        # Fallback: treat the entire response as a submitted answer
        return Action(submit_answer=text)


def run_task(task_id: str) -> dict:
    print(f"\n{'=' * 50}")
    print(f"Task: {task_id}")
    print('=' * 50)

    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    obs = result.observation
    history = []

    print(f"Question: {obs.question}")

    for step in range(1, MAX_STEPS + 1):
        user_prompt = build_prompt(obs)
        history.append({"role": "user", "content": user_prompt})

        response = client.chat.completions.create(
            model=MODEL,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                *history[-8:],  # last 4 turns (8 messages)
            ],
            temperature=0.0,  # deterministic
        )
        reply = response.choices[0].message.content
        history.append({"role": "assistant", "content": reply})

        action = parse_action(reply)
        print(f"Step {step}: {action}")

        result = env.step(action)
        obs = result.observation
        if result.done:
            print(f"Episode done at step {step}")
            break

    state = env.state()
    print(f"Final total reward: {state.total_reward}")
    return {
        "task_id": task_id,
        "total_reward": state.total_reward,
        "steps": state.step,
    }


def main():
    results = []
    for task_id in TASK_IDS:
        results.append(run_task(task_id))

    print("\n" + "=" * 50)
    print("BASELINE RESULTS")
    print("=" * 50)
    for r in results:
        print(f"{r['task_id']:30s} score={r['total_reward']:.3f} steps={r['steps']}")
    avg = sum(r["total_reward"] for r in results) / len(results)
    print(f"\nAverage score: {avg:.3f}")

    # Write results to file for reproducibility
    with open("baseline_scores.json", "w") as f:
        json.dump(results, f, indent=2)
    print("Saved to baseline_scores.json")


if __name__ == "__main__":
    main()
```
---

## 10. openenv.yaml

```yaml
name: sql-data-analyst
version: "1.0.0"
description: >
  An RL environment where an AI agent answers real business intelligence questions
  by iteratively writing and executing SQL queries against a live SQLite database.
  Simulates the day-to-day workflow of a data analyst.

tags:
  - openenv
  - sql
  - data-analysis
  - business-intelligence
  - real-world

author: your-username
repository: https://huggingface.co/spaces/your-username/sql-data-analyst

observation_space:
  type: dict
  fields:
    schema_summary:
      type: string
      description: Compact one-line-per-table schema of the database
    question:
      type: string
      description: Natural language business question to answer
    last_query:
      type: string
      nullable: true
      description: The last SQL query executed by the agent
    last_result:
      type: object
      nullable: true
      description: Result of the last query (columns, rows, error)
    last_error:
      type: string
      nullable: true
      description: SQL error message if last query failed
    step:
      type: integer
      description: Current step number
    max_steps:
      type: integer
      description: Maximum steps allowed for this task
    hints:
      type: array
      items: string
      description: Progressive hints revealed as steps increase

action_space:
  type: union
  description: Agent must provide exactly one of the following
  options:
    sql_query:
      type: string
      description: A SELECT or WITH SQL query to execute
    submit_answer:
      type: string
      description: Final answer to the question. Ends the episode.

tasks:
  - id: monthly_signups
    difficulty: easy
    max_steps: 10
    description: "Count the number of users who signed up in the last 30 days"
    skills_required:
      - COUNT
      - WHERE with date filter
  - id: top_revenue_category
    difficulty: medium
    max_steps: 15
    description: "Find which product category generated the most revenue in Q3"
    skills_required:
      - JOIN (3 tables)
      - GROUP BY
      - SUM aggregation
      - Date range filtering
  - id: churn_analysis
    difficulty: hard
    max_steps: 20
    description: >
      Find email addresses of users who placed exactly 3 orders and then
      never ordered again (churned after their 3rd purchase)
    skills_required:
      - Subqueries
      - HAVING clause
      - Date logic
      - Window functions (optional)

baseline_scores:
  monthly_signups: 0.82
  top_revenue_category: 0.61
  churn_analysis: 0.38
  average: 0.60
```
---

## 11. Dockerfile

```dockerfile
FROM python:3.11-slim

WORKDIR /app

# Install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy source
COPY . .

# Pre-seed the database at build time (optional – the env also reseeds at reset())
RUN python -c "from env.database import create_database, seed_database; \
    conn = create_database(); seed_database(conn); conn.close()"

# Expose port for HuggingFace Spaces
EXPOSE 7860

# Start the API server
CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"]
```

```
# requirements.txt
pydantic>=2.0
fastapi
uvicorn
openai
faker
pytest
```
---

## 12. README Template

````markdown
# SQL Data Analyst – OpenEnv Environment

An RL training environment where an AI agent learns to answer business intelligence
questions by writing and executing SQL queries against a live database.

## Motivation

Data analysts spend significant time translating business questions into SQL queries.
This environment trains agents to do exactly that – iteratively exploring a database
schema, writing queries, observing results, and submitting final answers.

## Observation Space

| Field | Type | Description |
|---|---|---|
| `schema_summary` | string | Compact DB schema (one line per table) |
| `question` | string | Natural language business question |
| `last_query` | string \| null | Most recent SQL query |
| `last_result` | object \| null | Query result: columns, rows (max 50), error |
| `last_error` | string \| null | SQL error if last query failed |
| `step` | int | Current step number |
| `max_steps` | int | Episode step limit |
| `hints` | string[] | Progressive hints (revealed after steps 5, 10, 15) |

## Action Space

Agent must submit exactly one of:

| Action | Type | Description |
|---|---|---|
| `sql_query` | string | A SELECT or WITH SQL query to execute |
| `submit_answer` | string | Final answer – ends the episode |

## Tasks

| Task | Difficulty | Max Steps | Description |
|---|---|---|---|
| `monthly_signups` | Easy | 10 | Count signups in the last 30 days |
| `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 |
| `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases |

## Reward Function

Rewards are given at every step (not just at episode end):

- `+0.15` – Query executes without error
- `+0.10` – Query references a relevant table
- `+0.05` – Result has at least one row
- `+0.05` – Result is a sensible size
- `-0.02` per step beyond step 3 (efficiency penalty)
- `-0.10` if agent repeats the same query 3+ times
- `+0.00` to `+0.60` on final submission (task grader × 0.60)

## Setup

```bash
git clone https://huggingface.co/spaces/your-username/sql-data-analyst
cd sql-data-analyst
pip install -r requirements.txt
```

### Run locally

```python
from env.environment import SQLAnalystEnv
from env.models import Action

env = SQLAnalystEnv(task_id="monthly_signups")
result = env.reset()
print(result.observation.question)

# Agent takes a step
result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"))
print(result.reward)
```

### Run baseline

```bash
export OPENAI_API_KEY=sk-...
python baseline/run_baseline.py
```

### Docker

```bash
docker build -t sql-analyst-env .
docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-analyst-env
```

## Baseline Scores

| Task | Score | Model |
|---|---|---|
| monthly_signups | 0.82 | gpt-4o-mini |
| top_revenue_category | 0.61 | gpt-4o-mini |
| churn_analysis | 0.38 | gpt-4o-mini |
| **Average** | **0.60** | gpt-4o-mini |

## Validation

```bash
openenv validate --env env.environment.SQLAnalystEnv
pytest tests/
```
````
---

## 13. Full File Structure

```
sql-analyst-openenv/
│
├── env/
│   ├── __init__.py
│   ├── environment.py     – Main OpenEnv class (reset/step/state)
│   ├── models.py          – Pydantic: Observation, Action, StepResult, EnvState
│   ├── database.py        – SQLite creation + Faker seeding + schema summary
│   ├── executor.py        – Safe SQL execution (SELECT-only guard)
│   ├── reward.py          – RewardCalculator class
│   ├── utils.py           – normalize_answer, is_safe_query helpers
│   ├── server.py          – FastAPI wrapper for HuggingFace Spaces
│   └── tasks/
│       ├── __init__.py    – TASKS dict: {task_id: TaskInstance}
│       ├── base.py        – BaseTask abstract class
│       ├── easy.py        – MonthlySignupsTask
│       ├── medium.py      – TopRevenueCategoryTask
│       └── hard.py        – ChurnAnalysisTask
│
├── baseline/
│   ├── run_baseline.py    – Full inference script (OpenAI API)
│   └── prompts.py         – System prompt + user prompt builder
│
├── tests/
│   ├── test_env.py        – reset/step/state contract tests
│   ├── test_graders.py    – Unit tests for each task grader
│   └── test_reward.py     – Reward calculator unit tests
│
├── openenv.yaml           – OpenEnv spec metadata
├── Dockerfile             – docker build + docker run
├── requirements.txt
└── README.md
```
---

## 14. Build Order

Follow this order when coding. Each step is a self-contained deliverable.

### Step 1 – Models (30 min)
Build `env/models.py` first. All other files depend on these types.
Test: can import and instantiate `Observation`, `Action`, `StepResult`.

### Step 2 – Database (45 min)
Build `env/database.py` – schema creation, Faker seeding, schema summary.
Test: run `create_database()` + `seed_database()`, query the tables manually.

### Step 3 – Tasks + Graders (60 min)
Build `env/tasks/base.py`, then `easy.py`, `medium.py`, `hard.py`.
Test each grader with known inputs: perfect answer → 1.0, wrong answer → 0.0.

### Step 4 – Reward Calculator (30 min)
Build `env/reward.py`.
Test: step with good query → positive reward; repeated query → penalty applied.

### Step 5 – Environment Core (60 min)
Build `env/environment.py` – wire together DB, executor, reward, tasks.
Test: full episode loop manually: `reset()` → `step()` × N → `state()`.
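The episode-loop contract in Step 5 can be pinned down with a small test against a stub. The `StubEnv` below is a hypothetical stand-in (dicts instead of Pydantic models) so the test shape runs without the real package; swap in `SQLAnalystEnv` once it exists:

```python
from dataclasses import dataclass, field


@dataclass
class StubEnv:
    """Minimal stand-in honouring the reset/step/state contract."""
    step_count: int = 0
    done: bool = False
    history: list = field(default_factory=list)

    def reset(self):
        self.step_count, self.done, self.history = 0, False, []
        return {"observation": {"step": 0}, "reward": 0.0, "done": False}

    def step(self, action: dict):
        assert not self.done, "Episode is done. Call reset()."
        self.step_count += 1
        self.history.append(action)
        self.done = "submit_answer" in action  # submitting ends the episode
        return {"observation": {"step": self.step_count}, "reward": 0.1, "done": self.done}

    def state(self):
        return {"step": self.step_count, "done": self.done}


# Contract check: query, then submit; the episode must terminate.
env = StubEnv()
env.reset()
env.step({"sql_query": "SELECT 1"})
result = env.step({"submit_answer": "42"})
assert result["done"]
assert env.state() == {"step": 2, "done": True}
```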
### Step 6 – Baseline Script (45 min)
Build `baseline/run_baseline.py`.
Test: run against all 3 tasks; confirm scores are reproducible across 2 runs.

### Step 7 – FastAPI Server (30 min)
Build `env/server.py` – wrap the env in HTTP endpoints for HF Spaces.
Test: `docker build` passes; `docker run` starts the server on port 7860.

### Step 8 – Docs + Validation (30 min)
Write `openenv.yaml` and `README.md`. Run `openenv validate`.
Fill in real baseline scores from Step 6 output.

### Step 9 – Deploy to HuggingFace (15 min)
Push to the HF Space repo. Tag with `openenv`. Verify the Space starts cleanly.

---

*Total estimated time: ~6 hours for a clean first build.*