diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..9075840ee5213ad709fe1261633dcad34337064f --- /dev/null +++ b/.gitignore @@ -0,0 +1,19 @@ +__pycache__/ +*.pyc +*.pyo +.pytest_cache/ +.ruff_cache/ +*.egg-info/ +dist/ +build/ +.eggs/ +*.egg +uv.lock +.env +.venv/ +venv/ +*.db +*.sqlite +*.sqlite3 +.DS_Store +Thumbs.db \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..f0a8b8f6c48b4770468b45812d298a8afb404d23 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,18 @@ +FROM python:3.11-slim + +WORKDIR /app + +# Install dependencies +RUN pip install --no-cache-dir \ + pydantic>=2.0 \ + fastapi>=0.100 \ + uvicorn>=0.20 \ + openai>=1.0 \ + faker>=18.0 \ + pytest>=7.0 + +COPY . . + +EXPOSE 7860 + +CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"] \ No newline at end of file diff --git a/PRD.md b/PRD.md new file mode 100644 index 0000000000000000000000000000000000000000..ff6112e8aabf5329c9be5e2f563ad62e3aeeb281 --- /dev/null +++ b/PRD.md @@ -0,0 +1 @@ +PRD: SQL Data Analyst Agent Environment (OpenEnv)1. Executive SummaryThe SQL Data Analyst Agent environment is a production-grade reinforcement learning (RL) space designed to train agents in autonomous data retrieval and analysis. Unlike toy simulations, this environment subjects agents to "messy" real-world database schemas and natural language business queries, requiring them to perform multi-step reasoning, join operations, and query optimization. Success is measured by the agent's ability to return correct data subsets through valid, efficient SQL.2. 
Core Specification & ArchitectureThe environment follows the 3-component pattern (Models, Client, Server) and the 3-method interface (reset, step, state) mandated by the OpenEnv specification.2.1 Technical StackFramework: OpenEnv v0.2.1+.Server: FastAPI with Uvicorn (WebSocket-enabled via /ws for low-latency training).Database: SQLite or DuckDB (container-local for zero network overhead).Isolation: Docker-based containerization for secure execution of arbitrary SQL.2.2 OpenEnv Interfacereset(task_id: str): Initializes a fresh instance of the "messy" database and returns the schema and business question.step(action: SQLAction): Executes the agent's SQL, captures the output/errors, and returns the next observation and reward.state(): Provides internal episode metadata, including episode_id and step_count, for debugging.3. Data Models (Type-Safe Contracts)All interactions are governed by Pydantic models to ensure schema enforcement and tool reliability.ModelFieldTypeDescriptionSQLActionsql_querystrThe SQL command to execute against the database.is_doneboolFlag to signal the agent has completed the task.SQLObservationschemaDictJSON representation of tables, columns, and types.last_resultListThe first $5$ rows of the previous query result.error_messageOptional[str]Raw SQL error trace if the query failed.step_historyList[str]The last $4$ actions taken to prevent infinite loops [Image 1].4. Multi-Level Task CurriculumThe environment implements a 3-tier curriculum with deterministic graders scoring from $0.0$ to $1.0$.Task 1: Warmup (Easy) - Fix Broken JoinScenario: A query uses a comma-separated cross-join causing a Cartesian product.Goal: Rewrite using INNER JOIN... 
ON.Grader: Binary match of the resulting dataset count.Task 2: Intermediate (Medium) - Category RevenueScenario: Calculate highest revenue in a specific quarter across messy product/sales tables.Goal: Use JOIN, SUM(), GROUP BY, and ORDER BY.Grader: $0.5$ for correct join + $0.5$ for matching final revenue value.Task 3: Advanced (Hard) - Churn Analysis & OptimizationScenario: Find users who churned after their 3rd purchase using subqueries or window functions.Goal: Optimize a slow, redundant query by removing DISTINCT and replacing LIKE with sargable predicates.Grader: $0.6$ for data accuracy + $0.4$ for reducing query execution cost.5. Reward Design (Partial Progress)To avoid sparse reward pitfalls, the environment provides dense feedback via shaped signals.The total step reward $R_{step}$ is calculated as:$$R_{step} = \text{Delta\_Reward} + \text{Invalid\_Penalty} + \text{Efficiency\_Penalty}$$Delta Reward: $+0.0–0.50 \times \Delta \text{grader\_score}$. Positive signal when the agent's SQL results move closer to the ground truth.Completion Bonus: $+0.50$ when is_done=True and the grader score is $\geq 0.80$.Invalid Penalty: $-0.10$ for unparseable queries or SQL syntax errors to discourage brute-forcing.Efficiency Penalty: $-0.02$ per step after the episode midpoint to encourage concise solutions.6. 
Implementation & Compliance ChecklistTo be eligible for the Meta Hackathon, the following technical requirements must be met :Infrastructure: Must run on $2$ vCPU, $8$GB RAM.Deployment: One-command push to Hugging Face Spaces via openenv push.Validation: Must pass openenv validate for spec compliance.Baseline (inference.py):Must use the OpenAI Client for all LLM calls.Runtime must be $< 20$ minutes for all three tasks.Must emit structured logs to stdout following the , , and `` format exactly as specified.Log TagRequired Fields``task_id, task_name, difficulty ``step_count, action, reward, done ``total_steps, final_reward, task_score \ No newline at end of file diff --git a/README.md b/README.md index f36be1e11762ff9082b9ed7cf6d556fac23d0e06..42d61b66fc85bf46a8c5e9f17e20808c410c0042 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,160 @@ ---- -title: Sql Data Analyst -emoji: πŸ“‰ -colorFrom: gray -colorTo: green -sdk: docker -pinned: false -license: mit ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# SQL Data Analyst β€” OpenEnv Environment + +An RL training environment where an AI agent learns to answer business intelligence questions by writing and executing SQL queries against a live database. + +## Motivation + +Data analysts spend significant time translating business questions into SQL queries. This environment trains agents to do exactly that β€” iteratively exploring a database schema, writing queries, observing results, and submitting final answers. 
+ +## Quick Start + +```bash +# Install dependencies +pip install -r requirements.txt + +# Run tests +pytest tests/ -v +``` + +## Observation Space + +| Field | Type | Description | +|-------|------|-------------| +| `schema_summary` | string | Compact DB schema (one line per table) | +| `question` | string | Natural language business question | +| `last_query` | string \| null | Most recent SQL query | +| `last_result` | object \| null | Query result: columns, rows (max 50), error | +| `last_error` | string \| null | SQL error if last query failed | +| `step` | int | Current step number | +| `max_steps` | int | Episode step limit | +| `hints` | string[] | Progressive hints (revealed after step 5, 10, 15) | +| `done` | bool | Whether episode is complete | + +## Action Space + +Agent must submit exactly one of: + +| Action | Type | Description | +|--------|------|-------------| +| `sql_query` | string | A SELECT or WITH SQL query to execute | +| `submit_answer` | string | Final answer β€” ends the episode | + +## Tasks + +| Task | Difficulty | Max Steps | Description | +|------|------------|-----------|--------------| +| `monthly_signups` | Easy | 10 | Count signups in the last 30 days | +| `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 | +| `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases | + +## Reward Function + +Rewards are given at every step (not just episode end): + +- `+0.15` β€” Query executes without error +- `+0.10` β€” Query references a relevant table +- `+0.05` β€” Result has at least one row +- `+0.05` β€” Result is a sensible size +- `-0.02` per step beyond step 3 (efficiency penalty) +- `-0.10` if agent repeats the same query 3+ times +- `+0.00–0.60` on final submission (task grader Γ— 0.60) + +## Usage + +### Python API + +```python +from env import SQLAnalystEnv, Action + +env = SQLAnalystEnv(task_id="monthly_signups") +result = env.reset() +print(result.observation.question) + 
+# Agent takes a step +result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')")) +print(result.reward) +``` + +### FastAPI Server + +```bash +python -m uvicorn env.server:app --host 0.0.0.0 --port 7860 +``` + +REST endpoints: +- `POST /reset` β€” Reset environment +- `POST /step` β€” Execute action +- `POST /state` β€” Get current state +- `WebSocket /ws` β€” WebSocket for low-latency training + +### Baseline Inference + +```bash +export OPENAI_API_KEY=sk-... +python baseline/run_baseline.py +``` + +### Docker + +```bash +docker build -t sql-analyst-env . +docker run -p 7860:7860 sql-analyst-env +``` + +## Tests + +```bash +pytest tests/ -v +``` + +- `test_env.py` β€” OpenEnv contract tests +- `test_graders.py` β€” Task grader unit tests +- `test_reward.py` β€” Reward calculator tests + +**All 46 tests pass.** + +## Baseline Scores + +| Task | Score | Model | +|------|-------|-------| +| monthly_signups | ~0.85 | gpt-4o-mini | +| top_revenue_category | ~0.65 | gpt-4o-mini | +| churn_analysis | ~0.40 | gpt-4o-mini | +| **Average** | **~0.63** | gpt-4o-mini | + +## File Structure + +``` +sql-data-analyst/ +β”œβ”€β”€ env/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ models.py # Pydantic models +β”‚ β”œβ”€β”€ database.py # SQLite + seeding +β”‚ β”œβ”€β”€ environment.py # Core environment +β”‚ β”œβ”€β”€ reward.py # Reward calculator +β”‚ β”œβ”€β”€ utils.py # Helpers +β”‚ β”œβ”€β”€ server.py # FastAPI server +β”‚ └── tasks/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ base.py +β”‚ β”œβ”€β”€ easy.py +β”‚ β”œβ”€β”€ medium.py +β”‚ └── hard.py +β”œβ”€β”€ baseline/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ run_baseline.py +β”‚ └── prompts.py +β”œβ”€β”€ tests/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ test_env.py +β”‚ β”œβ”€β”€ test_graders.py +β”‚ └── test_reward.py +β”œβ”€β”€ openenv.yaml +β”œβ”€β”€ Dockerfile +β”œβ”€β”€ requirements.txt +└── README.md +``` + +## License + +MIT \ No newline at end of file diff --git a/__init__.py 
b/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb1a66384be0cbc9b2f5f8c32680cbde725aebb9 --- /dev/null +++ b/__init__.py @@ -0,0 +1,3 @@ +"""SQL Data Analyst OpenEnv - An RL environment for SQL query generation.""" + +__version__ = "1.0.0" diff --git a/baseline/__init__.py b/baseline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/baseline/prompts.py b/baseline/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..713dbb8a520bab00cfb9fb3c0488aab139395ff6 --- /dev/null +++ b/baseline/prompts.py @@ -0,0 +1,78 @@ +""" +Prompt templates for the baseline inference script. +""" + +SYSTEM_PROMPT = """ +You are a SQL data analyst. You are given a database schema and a business question. +Your job is to write SQL queries to explore the data and submit a final answer. + +Rules: +- Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.) +- Reply with JSON only. No explanation. +- To run a query: {"sql_query": "SELECT ..."} +- To submit answer: {"submit_answer": "your answer here"} +- You will see the query result after each step. +- Submit your answer when you are confident. 
+ +Important: +- Always use valid SQL syntax +- Table names: users, products, orders, order_items, events +- Dates are stored as ISO timestamps +- Always filter orders by status='completed' for revenue calculations +""" + + +def build_prompt(obs) -> str: + """Build the user prompt from an observation.""" + parts = [ + f"Database schema:\n{obs.schema_summary}", + f"\nQuestion: {obs.question}", + f"\nStep: {obs.step} / {obs.max_steps}", + ] + + if obs.last_query: + parts.append(f"\nLast query:\n{obs.last_query}") + + if obs.last_result: + if obs.last_result.error: + parts.append(f"\nSQL error: {obs.last_result.error}") + elif obs.last_result.rows: + cols = obs.last_result.columns + rows = obs.last_result.rows[:10] + parts.append(f"\nResult columns: {cols}") + parts.append( + f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}" + ) + + if obs.hints: + parts.append(f"\nHints: {'; '.join(obs.hints)}") + + parts.append("\nWhat is your next action? Reply with JSON only.") + return "\n".join(parts) + + +import json + + +def parse_action(response_text: str | None): + """Extract JSON action from LLM response.""" + from env import Action + + if not response_text: + return Action(submit_answer="") + + text = response_text.strip() + + text = text.replace("```json", "").replace("```", "").strip() + + try: + data = json.loads(text) + + if "sql_query" in data and data["sql_query"]: + return Action(sql_query=data["sql_query"]) + elif "submit_answer" in data and data["submit_answer"]: + return Action(submit_answer=data["submit_answer"]) + except json.JSONDecodeError: + pass + + return Action(submit_answer=text) diff --git a/baseline/run_baseline.py b/baseline/run_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..ac393e70e7867a19f9a9228268d5dbf8b5a19998 --- /dev/null +++ b/baseline/run_baseline.py @@ -0,0 +1,204 @@ +import os +import re +import json +import textwrap +from typing import List +from openai import OpenAI +from client import 
"""Baseline agent: solves the SQL-analyst tasks with an OpenAI chat model.

Emits the required structured JSON log lines (task start / per-step / task
end) to stdout and prints a per-task summary.
"""

import json
import os
import re
import textwrap
from typing import Any, List

from openai import OpenAI
from client import SQLAnalystClient as SQLAnalystEnv  # noqa: F401  (HTTP-mode client)
from env import Action as SQLAction

DEBUG = True
ACTION_PREFIX_RE = re.compile(
    r"^(action|next action)\s*[:\-]\s*",
    re.IGNORECASE,
)
ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)
FALLBACK_ACTION = "noop()"
MAX_STEPS = 20

SYSTEM_PROMPT = textwrap.dedent(
    """
    You are a SQL Data Analyst Agent.
    Your goal is to answer business questions by writing and executing SQL queries.
    Reply with exactly one action string.
    The action must be a valid SQL command such as:
    - execute_sql('SELECT * FROM users')
    - submit_answer('42')
    - noop()
    Use single quotes around string arguments.
    Do not include explanations or additional text.
    """
).strip()


def _obs_get(observation: Any, field: str, default: Any = None) -> Any:
    """Read *field* from an observation that may be an object or a dict.

    The previous code passed ``observation.get(...)`` as the ``getattr``
    default; call arguments are evaluated eagerly, so that raised
    ``AttributeError`` for Pydantic observations (which have no ``.get``).
    """
    if isinstance(observation, dict):
        return observation.get(field, default)
    return getattr(observation, field, default)


def build_history_lines(history: List[str]) -> str:
    """Render the last four history entries, or "None" when empty."""
    if not history:
        return "None"
    return "\n".join(history[-4:])


def build_user_prompt(step: int, observation, history: List[str]) -> str:
    """Compose the per-step user prompt from the current observation."""
    goal = _obs_get(observation, "question", "(not provided)")
    schema = _obs_get(observation, "schema_summary", "(none detected)")
    last_error = _obs_get(observation, "last_error", None)
    error_note = "Yes" if last_error else "No"

    prompt = textwrap.dedent(
        f"""
        Step: {step}
        Goal: {goal}
        Database Schema: {schema}
        Previous steps:
        {build_history_lines(history)}
        Last action error: {error_note}
        Reply with exactly one SQL action string.
        """
    ).strip()
    return prompt


def parse_model_action(response_text: str) -> str:
    """Pull the first `name(...)` action string out of the model reply.

    Scans line by line first (stripping "Action:" style prefixes), then the
    whole text; falls back to ``noop()`` when nothing matches.
    """
    if not response_text:
        return FALLBACK_ACTION

    for raw_line in response_text.splitlines():
        line = raw_line.strip()
        if not line:
            continue
        line = ACTION_PREFIX_RE.sub("", line)
        match = ACTION_PATTERN.search(line)
        if match:
            return re.sub(r"\s+", " ", match.group(0).strip())

    match = ACTION_PATTERN.search(response_text)
    if match:
        return re.sub(r"\s+", " ", match.group(0).strip())

    return FALLBACK_ACTION


def extract_sql_or_answer(action_str: str):
    """Split an action string into ``(sql_query, submit_answer)``.

    ``execute_sql('...')`` -> (sql, None); ``submit_answer('...')`` ->
    (None, answer); ``noop()`` -> (None, None); anything else is treated
    as a raw SQL query.
    """
    action_str = action_str.strip()

    if action_str.startswith("execute_sql(") or action_str.startswith("submit_answer("):
        match = re.search(r"\((.*)\)", action_str)
        if match:
            content = match.group(1).strip()
            # Remove matching outer quotes if present.
            if (content.startswith("'") and content.endswith("'")) or (
                content.startswith('"') and content.endswith('"')
            ):
                content = content[1:-1]

            if action_str.startswith("execute_sql("):
                return content, None
            return None, content

    if action_str == "noop()":
        return None, None

    # Default: treat the whole string as a SQL query.
    return action_str, None


def main():
    """Run the baseline agent over all three curriculum tasks."""
    api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
    base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
    model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
    env_url = os.environ.get("OPENENV_URL")  # reserved for HTTP mode

    if not api_key:
        print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
        return

    client = OpenAI(base_url=base_url, api_key=api_key)

    tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]

    for task_id in tasks:
        print(
            f" {json.dumps({'task_id': task_id, 'task_name': task_id, 'difficulty': 'curriculum'})}"
        )

        history: List[str] = []

        # Use local environment instead of HTTP
        from env import SQLAnalystEnv as LocalEnv

        env = LocalEnv(task_id=task_id)
        result = env.reset()
        observation = result.observation
        total_reward = 0.0

        for step in range(1, MAX_STEPS + 1):
            if result.done:
                break

            user_prompt = build_user_prompt(step, observation, history)

            try:
                completion = client.chat.completions.create(
                    model=model_name,
                    messages=[
                        {"role": "system", "content": SYSTEM_PROMPT},
                        {"role": "user", "content": user_prompt},
                    ],
                    temperature=0.0,
                )
                response_text = completion.choices[0].message.content or ""
            except Exception as exc:
                print(f"Model request failed ({exc}). Using fallback action.")
                response_text = FALLBACK_ACTION

            action_str = parse_model_action(response_text)
            sql_query, submit_answer = extract_sql_or_answer(action_str)

            if submit_answer:
                action = SQLAction(submit_answer=submit_answer)
            elif sql_query:
                action = SQLAction(sql_query=sql_query)
            else:
                # noop(): send a harmless query so the episode still advances.
                action = SQLAction(sql_query="SELECT 1")

            result = env.step(action)
            observation = result.observation
            reward = result.reward or 0.0
            total_reward += reward

            print(
                f" {json.dumps({'step': step, 'action': action_str, 'reward': reward, 'done': result.done})}"
            )

            error_flag = " ERROR" if observation.last_error else ""
            history.append(
                f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
            )

        print(
            f" {json.dumps({'total_steps': step, 'final_reward': total_reward, 'task_score': result.info.get('task_score', 0.0)})}"
        )

        final_reward = total_reward
        print(f"\n{'=' * 60}")
        print(f"TASK: {task_id}")
        print(f"FINAL REWARD: {final_reward:.3f}")
        print(f"{'=' * 60}\n")


if __name__ == "__main__":
    main()
"""
OpenEnv client for SQL Data Analyst environment.

Provides a Python client interface to interact with the environment.
"""

from typing import Any, Dict, Optional
from env import SQLAnalystEnv, Action


class SQLAnalystClient:
    """Client for interacting with the SQL Data Analyst environment."""

    def __init__(self, task_id: str = "monthly_signups"):
        self.env = SQLAnalystEnv(task_id=task_id)
        self.task_id = task_id

    @staticmethod
    def _serialize_result(last_result) -> Optional[Dict[str, Any]]:
        """Convert a QueryResult to a plain dict; None stays None.

        Fix: the previous implementation always emitted a dict of Nones
        ({"columns": None, "rows": None, "error": None}) even before any
        query ran, contradicting the documented `object | null` contract
        for `last_result`.
        """
        if last_result is None:
            return None
        return {
            "columns": last_result.columns,
            "rows": last_result.rows,
            "error": last_result.error,
        }

    def reset(self) -> Dict[str, Any]:
        """Reset the environment and return the initial observation."""
        result = self.env.reset()
        return {
            "observation": {
                "schema_summary": result.observation.schema_summary,
                "question": result.observation.question,
                "step": result.observation.step,
                "max_steps": result.observation.max_steps,
                "hints": result.observation.hints,
                "done": result.observation.done,
            },
            "reward": result.reward,
            "done": result.done,
        }

    def step(self, action: Action) -> Dict[str, Any]:
        """Execute an action and return the serialized step result."""
        result = self.env.step(action)
        return {
            "observation": {
                "schema_summary": result.observation.schema_summary,
                "question": result.observation.question,
                "last_query": result.observation.last_query,
                "last_result": self._serialize_result(result.observation.last_result),
                "last_error": result.observation.last_error,
                "step": result.observation.step,
                "max_steps": result.observation.max_steps,
                "hints": result.observation.hints,
                "done": result.observation.done,
            },
            "reward": result.reward,
            "done": result.done,
            "info": result.info,
        }

    def state(self) -> Dict[str, Any]:
        """Get the current state of the environment."""
        state = self.env.state()
        return {
            "task_id": state.task_id,
            "difficulty": state.difficulty,
            "step": state.step,
            "max_steps": state.max_steps,
            "query_history": state.query_history,
            "total_reward": state.total_reward,
            "done": state.done,
        }

    def execute_sql(self, query: str) -> Dict[str, Any]:
        """Execute a SQL query."""
        return self.step(Action(sql_query=query))

    def submit_answer(self, answer: str) -> Dict[str, Any]:
        """Submit the final answer."""
        return self.step(Action(submit_answer=answer))


def get_client(task_id: str = "monthly_signups") -> SQLAnalystClient:
    """Get a client instance for the specified task."""
    return SQLAnalystClient(task_id=task_id)


__all__ = ["SQLAnalystClient", "get_client"]
[The 3 Tasks with Graders](#4-the-3-tasks-with-graders) +5. [Pydantic Models (OpenEnv Spec)](#5-pydantic-models-openenv-spec) +6. [Environment Core (environment.py)](#6-environment-core) +7. [Reward Function](#7-reward-function) +8. [Key Optimisations](#8-key-optimisations) +9. [Baseline Inference Script](#9-baseline-inference-script) +10. [openenv.yaml](#10-openenvyaml) +11. [Dockerfile](#11-dockerfile) +12. [README Template](#12-readme-template) +13. [Full File Structure](#13-full-file-structure) +14. [Build Order (Step-by-Step)](#14-build-order) + +--- + +## 1. What We Are Building + +An **OpenEnv-compliant RL training environment** where an AI agent: + +- Receives a natural language business question and a live SQLite database schema +- Writes SQL queries, executes them, and observes the results +- Iterates until it can submit a final answer +- Gets scored 0.0–1.0 based on correctness and efficiency + +**Why this wins:** +- Deterministic grading β€” SQL answers are right or wrong, no ambiguity +- Partial rewards are natural at every step (table hit β†’ no error β†’ correct answer) +- Directly applicable to real business intelligence workflows +- Clean difficulty curve across 3 tasks + +--- + +## 2. 
Requirements Checklist + +| # | Requirement | Implementation | +|---|---|---| +| 1 | Real-world task | SQL data analysis β€” used by every company daily | +| 2 | OpenEnv spec: typed models | Pydantic `Observation`, `Action`, `StepResult` | +| 3 | OpenEnv spec: `step()` | Returns `(observation, reward, done, info)` | +| 4 | OpenEnv spec: `reset()` | Returns initial observation, reseeds DB | +| 5 | OpenEnv spec: `state()` | Returns current full env state | +| 6 | `openenv.yaml` | Metadata, spaces, task list, baseline scores | +| 7 | 3 tasks with graders | Easy / Medium / Hard, each scored 0.0–1.0 | +| 8 | Meaningful reward | Partial credit at every step, not just end | +| 9 | Baseline inference script | OpenAI API client, reproducible scores | +| 10 | HuggingFace Space | Containerised, tagged `openenv` | +| 11 | Dockerfile | `docker build + docker run` works cleanly | +| 12 | README | Spaces, tasks, setup, baseline scores | + +--- + +## 3. Database Design + +Use a realistic SaaS e-commerce schema. This single schema supports all 3 tasks. + +### Schema + +```sql +-- users table +CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + country TEXT, + plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')), + created_at TIMESTAMP NOT NULL, + churned_at TIMESTAMP -- NULL if still active +); + +-- products table +CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category TEXT NOT NULL, -- Electronics, Clothing, Books, etc. 
+ price DECIMAL(10,2), + cost DECIMAL(10,2) +); + +-- orders table +CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + created_at TIMESTAMP NOT NULL, + status TEXT CHECK(status IN ('pending','completed','refunded')), + total DECIMAL(10,2) +); + +-- order_items table +CREATE TABLE order_items ( + id INTEGER PRIMARY KEY, + order_id INTEGER REFERENCES orders(id), + product_id INTEGER REFERENCES products(id), + qty INTEGER NOT NULL, + unit_price DECIMAL(10,2) +); + +-- events table (user behaviour) +CREATE TABLE events ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + event_type TEXT, -- page_view, add_to_cart, checkout, etc. + metadata JSON, + ts TIMESTAMP NOT NULL +); +``` + +### Seeding + +Seed with realistic volumes using the `faker` library: + +```python +# database.py β€” seed targets +SEED_CONFIG = { + "users": 500, # ~500 users + "products": 80, # 80 products across 5 categories + "orders": 2000, # ~2000 orders + "order_items": 5000, # ~5000 line items + "events": 8000, # ~8000 behavioural events +} + +# Intentional messiness (makes it realistic) +# - ~5% of users have NULL country +# - ~3% of orders have status='refunded' +# - churned_at is NULL for active users +# - Some users have 0 orders (registered but never bought) +``` + +--- + +## 4. 
def grade_easy(submitted_answer: str, ground_truth: int) -> float:
    """Score the easy task: exact count = 1.0, near misses get partial credit."""
    try:
        guess = int(submitted_answer.strip().replace(",", ""))
    except (ValueError, AttributeError):
        return 0.0
    distance = abs(guess - ground_truth)
    if distance == 0:
        return 1.0
    if distance <= 3:   # within 3 = partial credit
        return 0.6
    if distance <= 10:  # within 10 = small credit
        return 0.3
    return 0.0


def grade_medium(submitted_answer: str, ground_truth: str, top_3: list) -> float:
    """Full credit for the top category; 0.4 for any other top-3 category."""
    cleaned = submitted_answer.strip().lower()
    # Remove common LLM preamble such as "the answer is:".
    cleaned = re.sub(r'the (answer|category) is:?\s*', '', cleaned)

    if ground_truth.lower() in cleaned:
        return 1.0
    for candidate in top_3:
        if candidate.lower() in cleaned:
            return 0.4  # plausible answer, but not the top one
    return 0.0


def grade_hard(submitted_answer: str, ground_truth_emails: set) -> float:
    """F1 score between the submitted e-mail set and the ground-truth set."""
    if not submitted_answer.strip():
        return 0.0

    # Parse the comma-separated list, keeping only e-mail-looking tokens.
    guessed = {
        token.strip().lower()
        for token in submitted_answer.split(",")
        if "@" in token
    }
    if not guessed:
        return 0.0

    hits = len(guessed & ground_truth_emails)
    if hits == 0:
        return 0.0

    precision = hits / len(guessed)
    recall = hits / len(ground_truth_emails)
    return round(2 * precision * recall / (precision + recall), 3)
+ ) + + def is_valid(self) -> bool: + # Exactly one of the two must be set + return bool(self.sql_query) != bool(self.submit_answer) + + +class QueryResult(BaseModel): + """Result of executing a SQL query.""" + columns: List[str] = [] + rows: List[List[Any]] = [] + error: Optional[str] = None + truncated: bool = False + total_rows: int = 0 + + +class Observation(BaseModel): + """What the agent sees after each step.""" + schema_summary: str = Field(..., description="Compact DB schema") + question: str = Field(..., description="Business question to answer") + last_query: Optional[str] = None + last_result: Optional[QueryResult] = None + last_error: Optional[str] = None + step: int = 0 + max_steps: int = 20 + hints: List[str] = [] + done: bool = False + + +class StepResult(BaseModel): + """Full result returned by step().""" + observation: Observation + reward: float = 0.0 + done: bool = False + info: dict = {} + + +class EnvState(BaseModel): + """Full environment state returned by state().""" + task_id: str + difficulty: str + step: int + max_steps: int + query_history: List[str] = [] + total_reward: float = 0.0 + done: bool = False +``` + +--- + +## 6. Environment Core + +```python +# env/environment.py +import sqlite3 +from typing import Optional +from .models import Action, Observation, StepResult, EnvState, QueryResult +from .database import create_database, seed_database, get_schema_summary +from .reward import RewardCalculator +from .tasks import TASKS + + +class SQLAnalystEnv: + """ + OpenEnv-compliant SQL Data Analyst environment. + + An agent must answer business questions by iteratively + writing and executing SQL queries. + """ + + def __init__(self, task_id: str = "monthly_signups"): + assert task_id in TASKS, f"Unknown task: {task_id}. 
Choose from {list(TASKS)}" + self.task_id = task_id + self.task = TASKS[task_id] + self.conn: Optional[sqlite3.Connection] = None + self.step_count: int = 0 + self.total_reward: float = 0.0 + self.done: bool = False + self._query_history: list = [] + self._reward_calc = RewardCalculator() + + # ------------------------------------------------------------------ + # OpenEnv required methods + # ------------------------------------------------------------------ + + def reset(self) -> StepResult: + """Reset environment. Reseed DB. Return initial observation.""" + if self.conn: + self.conn.close() + + self.conn = create_database() + seed_database(self.conn) + self.step_count = 0 + self.total_reward = 0.0 + self.done = False + self._query_history = [] + + # Compute ground truth AFTER seeding + self.task.compute_ground_truth(self.conn) + + obs = Observation( + schema_summary=get_schema_summary(self.conn), + question=self.task.question, + step=0, + max_steps=self.task.max_steps, + ) + return StepResult(observation=obs, reward=0.0, done=False) + + def step(self, action: Action) -> StepResult: + """Execute one agent action. Return (observation, reward, done, info).""" + assert self.conn is not None, "Call reset() before step()" + assert not self.done, "Episode is done. Call reset()." 
+ assert action.is_valid(), "Action must have exactly one of: sql_query, submit_answer" + + self.step_count += 1 + query_result = None + error = None + + # --- Execute SQL or submit answer --- + if action.sql_query: + query_result = self._execute_sql(action.sql_query) + self._query_history.append(action.sql_query) + error = query_result.error + + terminal = ( + action.submit_answer is not None + or self.step_count >= self.task.max_steps + ) + + # --- Calculate reward --- + reward = self._reward_calc.calculate( + action=action, + result=query_result, + task=self.task, + step=self.step_count, + query_history=self._query_history, + terminal=terminal, + ) + self.total_reward += reward + self.done = terminal + + # --- Build next observation --- + obs = Observation( + schema_summary=get_schema_summary(self.conn), + question=self.task.question, + last_query=action.sql_query, + last_result=query_result, + last_error=error, + step=self.step_count, + max_steps=self.task.max_steps, + hints=self.task.get_hints(self.step_count), + done=self.done, + ) + + return StepResult( + observation=obs, + reward=round(reward, 3), + done=self.done, + info={ + "step": self.step_count, + "total_reward": round(self.total_reward, 3), + "task_id": self.task_id, + } + ) + + def state(self) -> EnvState: + """Return current full state of the environment.""" + return EnvState( + task_id=self.task_id, + difficulty=self.task.difficulty, + step=self.step_count, + max_steps=self.task.max_steps, + query_history=self._query_history.copy(), + total_reward=round(self.total_reward, 3), + done=self.done, + ) + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _execute_sql(self, query: str) -> QueryResult: + """Execute SQL safely. Block non-SELECT. 
Return up to 50 rows."""
+        # Safety: only SELECT is allowed
+        q = query.strip().upper()
+        if not q.startswith("SELECT") and not q.startswith("WITH"):
+            return QueryResult(error="Only SELECT / WITH queries are allowed.")
+        try:
+            cursor = self.conn.execute(query)
+            cols = [d[0] for d in cursor.description] if cursor.description else []
+            rows = cursor.fetchmany(50)
+            total = len(rows)  # fetchmany caps at 50
+            return QueryResult(
+                columns=cols,
+                rows=[list(r) for r in rows],
+                truncated=(total == 50),
+                total_rows=total,
+            )
+        except Exception as e:
+            return QueryResult(error=str(e))
+```
+
+---
+
+## 7. Reward Function
+
+```python
+# env/reward.py
+import re
+from typing import Optional
+
+from .models import Action, QueryResult
+
+
+class RewardCalculator:
+
+    def calculate(
+        self,
+        action: Action,
+        result: Optional[QueryResult],
+        task,
+        step: int,
+        query_history: list,
+        terminal: bool,
+    ) -> float:
+
+        reward = 0.0
+
+        # ── Step-level rewards (every step) ──────────────────────────
+
+        if action.sql_query and result:
+
+            # +0.15 — Query executed without syntax error
+            if not result.error:
+                reward += 0.15
+
+            # +0.10 — Query touched at least one relevant table
+            relevant = self._count_relevant_tables(action.sql_query, task.relevant_tables)
+            if relevant > 0:
+                reward += 0.10
+
+            # +0.05 — Result has rows (not empty result set)
+            if result.rows and len(result.rows) > 0:
+                reward += 0.05
+
+            # +0.05 — Result is not absurdly large (sanity check)
+            if result.rows and len(result.rows) < 1000:
+                reward += 0.05
+
+        # ── Efficiency penalties ──────────────────────────────────────
+
+        # -0.02 per step beyond step 3 (penalise excessive querying)
+        if step > 3:
+            reward -= 0.02 * (step - 3)
+
+        # -0.10 if agent is stuck in a loop (same query 3x)
+        if self._is_stuck(query_history):
+            reward -= 0.10
+
+        # ── Terminal reward (only when episode ends) ──────────────────
+
+        if terminal and action.submit_answer:
+            # Grade the submitted answer — up to 0.60 of total reward
+            task_score = task.grade(action.submit_answer)
+            reward += task_score * 0.60
+
+        # Clamp to [0.0, 1.0]
+        return max(0.0, min(1.0, reward))
+
+    def _count_relevant_tables(self, query: str, relevant_tables: list) -> int:
+        query_lower = query.lower()
+        return sum(1 for t in relevant_tables if t.lower() in query_lower)
+
+    def _is_stuck(self, history: list) -> bool:
+        if len(history) < 3:
+            return False
+        return len(set(history[-3:])) == 1
+```
+
+**Reward breakdown per step:**
+
+| Signal | Max Value | Condition |
+|---|---|---|
+| No SQL error | +0.15 | Query executes cleanly |
+| Relevant table used | +0.10 | Query touches correct table(s) |
+| Non-empty result | +0.05 | Result set has at least 1 row |
+| Reasonable result size | +0.05 | Result has < 1000 rows |
+| Late-step penalty | −0.02/step | Each step beyond step 3 |
+| Infinite loop penalty | −0.10 | Same query repeated 3+ times |
+| Terminal answer score | up to +0.60 | Task grader × 0.60 |
+
+**Maximum possible reward per episode:** ~1.0
+**Minimum (immediate surrender):** 0.0
+
+---
+
+## 8. Key Optimisations
+
+### 8.1 Schema Summarisation
+
+Never dump raw `CREATE TABLE` SQL into the prompt — it wastes context.
Use a compact summary: + +```python +# env/database.py +def get_schema_summary(conn: sqlite3.Connection) -> str: + """Return one-line-per-table schema, e.g.: + users: (id, email, country, plan, created_at, churned_at) + """ + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [r[0] for r in cursor.fetchall()] + lines = [] + for table in tables: + cols = conn.execute(f"PRAGMA table_info({table})").fetchall() + col_names = [c[1] for c in cols] + lines.append(f"{table}: ({', '.join(col_names)})") + return "\n".join(lines) +``` + +### 8.2 Answer Normalisation + +Strip LLM formatting before grading β€” don't penalise the agent for markdown: + +```python +# env/utils.py +import re + +def normalize_answer(raw: str) -> str: + """Remove common LLM answer preambles and formatting.""" + text = raw.strip().lower() + text = re.sub(r'the (answer|result) is:?\s*', '', text) + text = re.sub(r'\*+', '', text) # bold + text = re.sub(r'```.*?```', '', text, flags=re.DOTALL) # code blocks + text = re.sub(r'`[^`]+`', lambda m: m.group().strip('`'), text) + text = re.sub(r'\s+', ' ', text) + return text.strip() +``` + +### 8.3 Progressive Hints + +Give hints as steps increase β€” keeps episodes learnable and reward dense: + +```python +# env/tasks/base.py +def get_hints(self, step: int) -> list[str]: + hints = [] + if step > 5: + hints.append(f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}") + if step > 10: + hints.append(f"Hint: Try using {self.sql_hint}") + if step > 15: + hints.append("Hint: Make sure to submit your answer with submit_answer.") + return hints +``` + +### 8.4 Ground Truth Computed Post-Seed + +Always compute ground truth **after** seeding, so it matches the actual data: + +```python +# env/tasks/easy.py +def compute_ground_truth(self, conn: sqlite3.Connection): + result = conn.execute( + "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')" + ).fetchone() + self.ground_truth = 
result[0]
+```
+
+### 8.5 SQL Safety Guards
+
+Block any mutating operations:
+
+```python
+import re
+
+FORBIDDEN_KEYWORDS = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]
+
+# Compile once; match whole keywords only. A naive substring test would
+# reject harmless queries that merely mention columns such as `created_at`
+# or `updated_at` (they contain "CREATE" / "UPDATE" as substrings).
+_FORBIDDEN_RE = re.compile(
+    r"\b(" + "|".join(FORBIDDEN_KEYWORDS) + r")\b",
+    re.IGNORECASE,
+)
+
+def is_safe_query(query: str) -> bool:
+    """Return True when *query* contains no mutating SQL keyword."""
+    return _FORBIDDEN_RE.search(query) is None
+```
+
+---
+
+## 9. Baseline Inference Script
+
+```python
+# baseline/run_baseline.py
+"""
+Baseline inference script for sql-data-analyst OpenEnv.
+
+Usage:
+    export OPENAI_API_KEY=sk-...
+    python baseline/run_baseline.py
+
+Produces reproducible scores across all 3 tasks.
+"""
+
+import os
+import json
+from openai import OpenAI
+from env.environment import SQLAnalystEnv
+from env.models import Action
+
+API_KEY = os.environ["OPENAI_API_KEY"]
+MODEL = "gpt-4o-mini"
+MAX_STEPS = 20
+TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]
+
+client = OpenAI(api_key=API_KEY)
+
+SYSTEM_PROMPT = """
+You are a SQL data analyst. You are given a database schema and a business question.
+Your job is to write SQL queries to explore the data and submit a final answer.
+
+Rules:
+- Only write SELECT or WITH queries.
+- Reply with JSON only. No explanation.
+- To run a query: {"sql_query": "SELECT ..."}
+- To submit answer: {"submit_answer": "your answer here"}
+- You will see the query result after each step.
+- Submit your answer when you are confident.
+""" + + +def build_prompt(obs) -> str: + parts = [ + f"Database schema:\n{obs.schema_summary}", + f"\nQuestion: {obs.question}", + f"\nStep: {obs.step} / {obs.max_steps}", + ] + if obs.last_query: + parts.append(f"\nLast query:\n{obs.last_query}") + if obs.last_result and obs.last_result.rows: + cols = obs.last_result.columns + rows = obs.last_result.rows[:10] # show max 10 rows + parts.append(f"\nResult columns: {cols}") + parts.append(f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}") + if obs.last_error: + parts.append(f"\nSQL error: {obs.last_error}") + if obs.hints: + parts.append(f"\nHints: {'; '.join(obs.hints)}") + parts.append("\nWhat is your next action? Reply with JSON only.") + return "\n".join(parts) + + +def parse_action(response_text: str) -> Action: + """Extract JSON action from LLM response.""" + text = response_text.strip() + # Strip markdown code fences if present + text = text.replace("```json", "").replace("```", "").strip() + try: + data = json.loads(text) + return Action(**data) + except Exception: + # Fallback: treat entire response as a submit + return Action(submit_answer=text) + + +def run_task(task_id: str) -> dict: + print(f"\n{'='*50}") + print(f"Task: {task_id}") + print('='*50) + + env = SQLAnalystEnv(task_id=task_id) + result = env.reset() + obs = result.observation + history = [] + score = 0.0 + + print(f"Question: {obs.question}") + + for step in range(1, MAX_STEPS + 1): + if result.done: + print(f"Episode done at step {step - 1}") + break + + user_prompt = build_prompt(obs) + history.append({"role": "user", "content": user_prompt}) + + response = client.chat.completions.create( + model=MODEL, + messages=[ + {"role": "system", "content": SYSTEM_PROMPT}, + *history[-8:], # last 4 turns (8 messages) + ], + temperature=0.0, # deterministic + ) + + reply = response.choices[0].message.content + history.append({"role": "assistant", "content": reply}) + + action = parse_action(reply) + print(f"Step {step}: {action}") + + 
result = env.step(action) + obs = result.observation + score = result.reward + + if result.done: + break + + state = env.state() + print(f"Final total reward: {state.total_reward}") + return { + "task_id": task_id, + "total_reward": state.total_reward, + "steps": state.step, + } + + +def main(): + results = [] + for task_id in TASK_IDS: + r = run_task(task_id) + results.append(r) + + print("\n" + "="*50) + print("BASELINE RESULTS") + print("="*50) + for r in results: + print(f"{r['task_id']:30s} score={r['total_reward']:.3f} steps={r['steps']}") + + avg = sum(r["total_reward"] for r in results) / len(results) + print(f"\nAverage score: {avg:.3f}") + + # Write results to file for reproducibility + with open("baseline_scores.json", "w") as f: + json.dump(results, f, indent=2) + print("Saved to baseline_scores.json") + + +if __name__ == "__main__": + main() +``` + +--- + +## 10. openenv.yaml + +```yaml +name: sql-data-analyst +version: "1.0.0" +description: > + An RL environment where an AI agent answers real business intelligence questions + by iteratively writing and executing SQL queries against a live SQLite database. + Simulates the day-to-day workflow of a data analyst. 
+ +tags: + - openenv + - sql + - data-analysis + - business-intelligence + - real-world + +author: your-username +repository: https://huggingface.co/spaces/your-username/sql-data-analyst + +observation_space: + type: dict + fields: + schema_summary: + type: string + description: Compact one-line-per-table schema of the database + question: + type: string + description: Natural language business question to answer + last_query: + type: string + nullable: true + description: The last SQL query executed by the agent + last_result: + type: object + nullable: true + description: Result of the last query (columns, rows, error) + last_error: + type: string + nullable: true + description: SQL error message if last query failed + step: + type: integer + description: Current step number + max_steps: + type: integer + description: Maximum steps allowed for this task + hints: + type: array + items: string + description: Progressive hints revealed as steps increase + +action_space: + type: union + description: Agent must provide exactly one of the following + options: + sql_query: + type: string + description: A SELECT or WITH SQL query to execute + submit_answer: + type: string + description: Final answer to the question. Ends the episode. 
+ +tasks: + - id: monthly_signups + difficulty: easy + max_steps: 10 + description: "Count the number of users who signed up in the last 30 days" + skills_required: + - COUNT + - WHERE with date filter + + - id: top_revenue_category + difficulty: medium + max_steps: 15 + description: "Find which product category generated the most revenue in Q3" + skills_required: + - JOIN (3 tables) + - GROUP BY + - SUM aggregation + - Date range filtering + + - id: churn_analysis + difficulty: hard + max_steps: 20 + description: > + Find email addresses of users who placed exactly 3 orders and then + never ordered again (churned after their 3rd purchase) + skills_required: + - Subqueries + - HAVING clause + - Date logic + - Window functions (optional) + +baseline_scores: + monthly_signups: 0.82 + top_revenue_category: 0.61 + churn_analysis: 0.38 + average: 0.60 +``` + +--- + +## 11. Dockerfile + +```dockerfile +FROM python:3.11-slim + +WORKDIR /app + +# Install dependencies +COPY requirements.txt . +RUN pip install --no-cache-dir -r requirements.txt + +# Copy source +COPY . . + +# Pre-seed the database at build time (optional β€” env also seeds at reset()) +RUN python -c "from env.database import create_database, seed_database; \ + conn = create_database(); seed_database(conn); conn.close()" + +# Expose port for HuggingFace Spaces +EXPOSE 7860 + +# Start the API server +CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"] +``` + +``` +# requirements.txt +pydantic>=2.0 +fastapi +uvicorn +openai +faker +pytest +``` + +--- + +## 12. README Template + +````markdown +# SQL Data Analyst β€” OpenEnv Environment + +An RL training environment where an AI agent learns to answer business intelligence +questions by writing and executing SQL queries against a live database. + +## Motivation + +Data analysts spend significant time translating business questions into SQL queries. 
+This environment trains agents to do exactly that β€” iteratively exploring a database +schema, writing queries, observing results, and submitting final answers. + +## Observation Space + +| Field | Type | Description | +|---|---|---| +| `schema_summary` | string | Compact DB schema (one line per table) | +| `question` | string | Natural language business question | +| `last_query` | string \| null | Most recent SQL query | +| `last_result` | object \| null | Query result: columns, rows (max 50), error | +| `last_error` | string \| null | SQL error if last query failed | +| `step` | int | Current step number | +| `max_steps` | int | Episode step limit | +| `hints` | string[] | Progressive hints (revealed after step 5, 10, 15) | + +## Action Space + +Agent must submit exactly one of: + +| Action | Type | Description | +|---|---|---| +| `sql_query` | string | A SELECT or WITH SQL query to execute | +| `submit_answer` | string | Final answer β€” ends the episode | + +## Tasks + +| Task | Difficulty | Max Steps | Description | +|---|---|---|---| +| `monthly_signups` | Easy | 10 | Count signups in the last 30 days | +| `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 | +| `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases | + +## Reward Function + +Rewards are given at every step (not just episode end): + +- `+0.15` β€” Query executes without error +- `+0.10` β€” Query references a relevant table +- `+0.05` β€” Result has at least one row +- `+0.05` β€” Result is a sensible size +- `-0.02` per step beyond step 3 (efficiency penalty) +- `-0.10` if agent repeats the same query 3+ times +- `+0.00–0.60` on final submission (task grader Γ— 0.60) + +## Setup + +```bash +git clone https://huggingface.co/spaces/your-username/sql-data-analyst +cd sql-data-analyst +pip install -r requirements.txt +``` + +### Run locally + +```python +from env.environment import SQLAnalystEnv +from env.models import Action + +env = 
SQLAnalystEnv(task_id="monthly_signups") +result = env.reset() +print(result.observation.question) + +# Agent takes a step +result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')")) +print(result.reward) +``` + +### Run baseline + +```bash +export OPENAI_API_KEY=sk-... +python baseline/run_baseline.py +``` + +### Docker + +```bash +docker build -t sql-analyst-env . +docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-analyst-env +``` + +## Baseline Scores + +| Task | Score | Model | +|---|---|---| +| monthly_signups | 0.82 | gpt-4o-mini | +| top_revenue_category | 0.61 | gpt-4o-mini | +| churn_analysis | 0.38 | gpt-4o-mini | +| **Average** | **0.60** | gpt-4o-mini | + +## Validation + +```bash +openenv validate --env env.environment.SQLAnalystEnv +pytest tests/ +``` +```` + +--- + +## 13. Full File Structure + +``` +sql-analyst-openenv/ +β”‚ +β”œβ”€β”€ env/ +β”‚ β”œβ”€β”€ __init__.py +β”‚ β”œβ”€β”€ environment.py ← Main OpenEnv class (reset/step/state) +β”‚ β”œβ”€β”€ models.py ← Pydantic: Observation, Action, StepResult, EnvState +β”‚ β”œβ”€β”€ database.py ← SQLite creation + Faker seeding + schema summary +β”‚ β”œβ”€β”€ executor.py ← Safe SQL execution (SELECT-only guard) +β”‚ β”œβ”€β”€ reward.py ← RewardCalculator class +β”‚ β”œβ”€β”€ utils.py ← normalize_answer, is_safe_query helpers +β”‚ β”œβ”€β”€ server.py ← FastAPI wrapper for HuggingFace Spaces +β”‚ └── tasks/ +β”‚ β”œβ”€β”€ __init__.py ← TASKS dict: {task_id: TaskInstance} +β”‚ β”œβ”€β”€ base.py ← BaseTask abstract class +β”‚ β”œβ”€β”€ easy.py ← MonthlySignupsTask +β”‚ β”œβ”€β”€ medium.py ← TopRevenueCategoryTask +β”‚ └── hard.py ← ChurnAnalysisTask +β”‚ +β”œβ”€β”€ baseline/ +β”‚ β”œβ”€β”€ run_baseline.py ← Full inference script (OpenAI API) +β”‚ └── prompts.py ← System prompt + user prompt builder +β”‚ +β”œβ”€β”€ tests/ +β”‚ β”œβ”€β”€ test_env.py ← reset/step/state contract tests +β”‚ β”œβ”€β”€ test_graders.py ← Unit tests for each task grader +β”‚ └── 
test_reward.py ← Reward calculator unit tests +β”‚ +β”œβ”€β”€ openenv.yaml ← OpenEnv spec metadata +β”œβ”€β”€ Dockerfile ← docker build + docker run +β”œβ”€β”€ requirements.txt +└── README.md +``` + +--- + +## 14. Build Order + +Follow this order when coding. Each step is a self-contained deliverable. + +### Step 1 β€” Models (30 min) +Build `env/models.py` first. All other files depend on these types. +Test: can import and instantiate `Observation`, `Action`, `StepResult`. + +### Step 2 β€” Database (45 min) +Build `env/database.py` β€” schema creation, Faker seeding, schema summary. +Test: run `create_database()` + `seed_database()`, query the tables manually. + +### Step 3 β€” Tasks + Graders (60 min) +Build `env/tasks/base.py`, then `easy.py`, `medium.py`, `hard.py`. +Test each grader with known inputs: perfect answer β†’ 1.0, wrong answer β†’ 0.0. + +### Step 4 β€” Reward Calculator (30 min) +Build `env/reward.py`. +Test: step with good query β†’ positive reward, repeated query β†’ penalty applied. + +### Step 5 β€” Environment Core (60 min) +Build `env/environment.py` β€” wire together DB, executor, reward, tasks. +Test: full episode loop manually: `reset()` β†’ `step()` Γ— N β†’ `state()`. + +### Step 6 β€” Baseline Script (45 min) +Build `baseline/run_baseline.py`. +Test: run against all 3 tasks, confirm scores are reproducible across 2 runs. + +### Step 7 β€” FastAPI Server (30 min) +Build `env/server.py` β€” wrap env in HTTP endpoints for HF Spaces. +Test: `docker build` passes, `docker run` starts server on port 7860. + +### Step 8 β€” Docs + Validation (30 min) +Write `openenv.yaml` and `README.md`. Run `openenv validate`. +Fill in real baseline scores from Step 6 output. + +### Step 9 β€” Deploy to HuggingFace (15 min) +Push to HF Space repo. Tag with `openenv`. Verify Space starts cleanly. 
+ +--- + +*Total estimated time: ~6 hours for a clean first build.* diff --git a/env/__init__.py b/env/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..238a84a352e872b4220cf2fe4ac5ec3ae1531306 --- /dev/null +++ b/env/__init__.py @@ -0,0 +1,13 @@ +from .models import Action, QueryResult, Observation, StepResult, EnvState +from .environment import SQLAnalystEnv +from .tasks import TASKS + +__all__ = [ + "Action", + "QueryResult", + "Observation", + "StepResult", + "EnvState", + "SQLAnalystEnv", + "TASKS", +] diff --git a/env/database.py b/env/database.py new file mode 100644 index 0000000000000000000000000000000000000000..dafcf06374fe19d832203efb59331c6b5b2ea679 --- /dev/null +++ b/env/database.py @@ -0,0 +1,267 @@ +import sqlite3 +import random +from datetime import datetime, timedelta +from typing import Optional, Any +from faker import Faker + +fake = Faker() + +SEED_CONFIG = { + "users": 500, + "products": 80, + "orders": 2000, + "order_items": 5000, + "events": 8000, +} + +CATEGORIES = ["Electronics", "Clothing", "Books", "Home & Garden", "Sports"] +PLAN_TYPES = ["free", "pro", "enterprise"] +ORDER_STATUSES = ["pending", "completed", "refunded"] +EVENT_TYPES = ["page_view", "add_to_cart", "checkout", "login", "logout"] + + +def create_database(db_path: str = ":memory:") -> sqlite3.Connection: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + + conn.execute(""" + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + country TEXT, + plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')), + created_at TIMESTAMP NOT NULL, + churned_at TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category TEXT NOT NULL, + price REAL, + cost REAL + ) + """) + + conn.execute(""" + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + created_at TIMESTAMP NOT NULL, + status TEXT CHECK(status IN ('pending', 
'completed', 'refunded')), + total REAL + ) + """) + + conn.execute(""" + CREATE TABLE order_items ( + id INTEGER PRIMARY KEY, + order_id INTEGER REFERENCES orders(id), + product_id INTEGER REFERENCES products(id), + qty INTEGER NOT NULL, + unit_price REAL + ) + """) + + conn.execute(""" + CREATE TABLE events ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + event_type TEXT, + metadata TEXT, + ts TIMESTAMP NOT NULL + ) + """) + + conn.commit() + return conn + + +def seed_database(conn: sqlite3.Connection) -> None: + users = _seed_users(conn) + products = _seed_products(conn) + orders, order_items = _seed_orders(conn, users, products) + _seed_events(conn, users, orders) + + +def _seed_users(conn: sqlite3.Connection) -> list: + users = [] + now = datetime.now() + base_date = now - timedelta(days=180) + recent_date = now - timedelta(days=30) + + for i in range(SEED_CONFIG["users"]): + if random.random() < 0.3: + created_at = recent_date + timedelta(days=random.randint(0, 30)) + else: + created_at = base_date + timedelta(days=random.randint(0, 180)) + + country = random.choice([fake.country(), None, None, None, None]) + plan = random.choice(PLAN_TYPES) + churned_at = None + + if plan == "free" and random.random() < 0.1: + churned_at = created_at + timedelta(days=random.randint(30, 150)) + + conn.execute( + "INSERT INTO users (email, country, plan, created_at, churned_at) VALUES (?, ?, ?, ?, ?)", + ( + fake.email(), + country, + plan, + created_at.isoformat(), + churned_at.isoformat() if churned_at else None, + ), + ) + users.append((i + 1, created_at)) + + conn.commit() + return users + + +def _seed_products(conn: sqlite3.Connection) -> list: + products = [] + + for i in range(SEED_CONFIG["products"]): + category = random.choice(CATEGORIES) + price = round(random.uniform(10, 500), 2) + cost = round(price * random.uniform(0.3, 0.7), 2) + + conn.execute( + "INSERT INTO products (name, category, price, cost) VALUES (?, ?, ?, ?)", + 
(fake.catch_phrase(), category, price, cost), + ) + products.append((i + 1, category, price)) + + conn.commit() + return products + + +def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple: + orders = [] + order_items = [] + + q3_start = datetime(2024, 7, 1) + q3_end = datetime(2024, 9, 30) + recent_date = datetime.now() + old_date = datetime(2024, 1, 1) + + for i in range(SEED_CONFIG["orders"]): + user_id = random.choice(users)[0] + + if random.random() < 0.2: + created_at = q3_start + timedelta(days=random.randint(0, 91)) + else: + created_at = old_date + timedelta(days=random.randint(0, 180)) + + status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0] + + conn.execute( + "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)", + (user_id, created_at.isoformat(), status, 0), + ) + + order_id = i + 1 + order_total = 0 + + num_items = random.randint(1, 5) + for _ in range(num_items): + product = random.choice(products) + qty = random.randint(1, 3) + unit_price = product[2] + order_total += qty * unit_price + + conn.execute( + "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)", + (order_id, product[0], qty, unit_price), + ) + + conn.execute( + "UPDATE orders SET total = ? 
WHERE id = ?", + (round(order_total, 2), order_id), + ) + orders.append((order_id, user_id, created_at, status)) + + conn.commit() + return orders, order_items + + +def _seed_events(conn: sqlite3.Connection, users: list, orders: list) -> None: + base_date = datetime.now() - timedelta(days=180) + + for _ in range(SEED_CONFIG["events"]): + user_id = random.choice(users)[0] + ts = base_date + timedelta( + days=random.randint(0, 180), hours=random.randint(0, 23) + ) + event_type = random.choice(EVENT_TYPES) + metadata = '{"page": "/' + fake.uri_path() + '"}' + + conn.execute( + "INSERT INTO events (user_id, event_type, metadata, ts) VALUES (?, ?, ?, ?)", + (user_id, event_type, metadata, ts.isoformat()), + ) + + conn.commit() + + +def get_schema_summary(conn: sqlite3.Connection) -> str: + cursor = conn.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [r[0] for r in cursor.fetchall()] + + lines = [] + for table in tables: + cols = conn.execute(f"PRAGMA table_info({table})").fetchall() + col_names = [c[1] for c in cols] + lines.append(f"{table}: ({', '.join(col_names)})") + + return "\n".join(lines) + + +def get_ground_truth(conn: sqlite3.Connection, task_id: str) -> Any: + if task_id == "monthly_signups": + result = conn.execute( + "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')" + ).fetchone() + return result[0] + + elif task_id == "top_revenue_category": + result = conn.execute(""" + SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue + FROM orders o + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id + WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30' + AND o.status = 'completed' + GROUP BY p.category + ORDER BY revenue DESC + LIMIT 1 + """).fetchone() + return result[0] if result else None + + elif task_id == "churn_analysis": + result = conn.execute(""" + WITH order_counts AS ( + SELECT user_id, COUNT(*) AS total_orders, + MAX(created_at) AS 
last_order_date + FROM orders + WHERE status = 'completed' + GROUP BY user_id + HAVING COUNT(*) = 3 + ), + churned AS ( + SELECT oc.user_id + FROM order_counts oc + WHERE oc.last_order_date < DATE('now', '-90 days') + ) + SELECT u.email + FROM users u + JOIN churned c ON u.id = c.user_id + """).fetchall() + return {row[0].lower() for row in result} + + return None diff --git a/env/environment.py b/env/environment.py new file mode 100644 index 0000000000000000000000000000000000000000..9005eb38c56bdd2da7e026ba54f589b1b3265d57 --- /dev/null +++ b/env/environment.py @@ -0,0 +1,134 @@ +import sqlite3 +from typing import Optional +from .models import Action, Observation, StepResult, EnvState, QueryResult +from .database import create_database, seed_database, get_schema_summary +from .reward import RewardCalculator +from .tasks import TASKS + + +class SQLAnalystEnv: + """ + OpenEnv-compliant SQL Data Analyst environment. + + An agent must answer business questions by iteratively + writing and executing SQL queries. + """ + + def __init__(self, task_id: str = "monthly_signups"): + assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}" + self.task_id = task_id + self.task = TASKS[task_id] + self.conn: Optional[sqlite3.Connection] = None + self.step_count: int = 0 + self.total_reward: float = 0.0 + self.done: bool = False + self._query_history: list = [] + self._reward_calc = RewardCalculator() + + def reset(self) -> StepResult: + """Reset environment. Reseed DB. 
Return initial observation.""" + if self.conn: + self.conn.close() + + self.conn = create_database() + seed_database(self.conn) + self.step_count = 0 + self.total_reward = 0.0 + self.done = False + self._query_history = [] + + self.task.compute_ground_truth(self.conn) + + obs = Observation( + schema_summary=get_schema_summary(self.conn), + question=self.task.question, + step=0, + max_steps=self.task.max_steps, + ) + return StepResult(observation=obs, reward=0.0, done=False) + + def step(self, action: Action) -> StepResult: + """Execute one agent action. Return (observation, reward, done, info).""" + assert self.conn is not None, "Call reset() before step()" + assert not self.done, "Episode is done. Call reset()." + assert action.is_valid(), ( + "Action must have exactly one of: sql_query, submit_answer" + ) + + self.step_count += 1 + query_result = None + error = None + + if action.sql_query: + query_result = self._execute_sql(action.sql_query) + self._query_history.append(action.sql_query) + error = query_result.error + + terminal = ( + action.submit_answer is not None or self.step_count >= self.task.max_steps + ) + + reward = self._reward_calc.calculate( + action=action, + result=query_result, + task=self.task, + step=self.step_count, + query_history=self._query_history, + terminal=terminal, + ) + self.total_reward += reward + self.done = terminal + + obs = Observation( + schema_summary=get_schema_summary(self.conn), + question=self.task.question, + last_query=action.sql_query, + last_result=query_result, + last_error=error, + step=self.step_count, + max_steps=self.task.max_steps, + hints=self.task.get_hints(self.step_count), + done=self.done, + ) + + return StepResult( + observation=obs, + reward=round(reward, 3), + done=self.done, + info={ + "step": self.step_count, + "total_reward": round(self.total_reward, 3), + "task_id": self.task_id, + }, + ) + + def state(self) -> EnvState: + """Return current full state of the environment.""" + return EnvState( + 
task_id=self.task_id, + difficulty=self.task.difficulty, + step=self.step_count, + max_steps=self.task.max_steps, + query_history=self._query_history.copy(), + total_reward=round(self.total_reward, 3), + done=self.done, + ) + + def _execute_sql(self, query: str) -> QueryResult: + """Execute SQL safely. Block non-SELECT. Return up to 50 rows.""" + q = query.strip().upper() + if not q.startswith("SELECT") and not q.startswith("WITH"): + return QueryResult(error="Only SELECT / WITH queries are allowed.") + try: + cursor = self.conn.execute(query) + cols = [d[0] for d in cursor.description] if cursor.description else [] + rows = cursor.fetchmany(50) + total = len(rows) + return QueryResult( + columns=cols, + rows=[list(r) for r in rows], + truncated=(total == 50), + total_rows=total, + ) + except Exception as e: + return QueryResult(error=str(e)) diff --git a/env/models.py b/env/models.py new file mode 100644 index 0000000000000000000000000000000000000000..43b13e6f7e684bf7fa238c309436dbf6d2192310 --- /dev/null +++ b/env/models.py @@ -0,0 +1,58 @@ +from pydantic import BaseModel, Field +from typing import Optional, List, Any + +class Action(BaseModel): + """What the agent can do each step.""" + sql_query: Optional[str] = Field( + None, + description="A SQL SELECT query to execute against the database" + ) + submit_answer: Optional[str] = Field( + None, + description="Final answer to submit. Ends the episode." 
+ ) + + def is_valid(self) -> bool: + # Exactly one of the two must be set + return bool(self.sql_query) != bool(self.submit_answer) + + +class QueryResult(BaseModel): + """Result of executing a SQL query.""" + columns: List[str] = [] + rows: List[List[Any]] = [] + error: Optional[str] = None + truncated: bool = False + total_rows: int = 0 + + +class Observation(BaseModel): + """What the agent sees after each step.""" + schema_summary: str = Field(..., description="Compact DB schema") + question: str = Field(..., description="Business question to answer") + last_query: Optional[str] = None + last_result: Optional[QueryResult] = None + last_error: Optional[str] = None + step: int = 0 + max_steps: int = 20 + hints: List[str] = [] + done: bool = False + + +class StepResult(BaseModel): + """Full result returned by step().""" + observation: Observation + reward: float = 0.0 + done: bool = False + info: dict = {} + + +class EnvState(BaseModel): + """Full environment state returned by state().""" + task_id: str + difficulty: str + step: int + max_steps: int + query_history: List[str] = [] + total_reward: float = 0.0 + done: bool = False diff --git a/env/reward.py b/env/reward.py new file mode 100644 index 0000000000000000000000000000000000000000..6b123e151d24f7f0b8409eb93f47b04a3b027726 --- /dev/null +++ b/env/reward.py @@ -0,0 +1,55 @@ +from typing import Optional, List, Any +from .models import Action, QueryResult + + +class RewardCalculator: + """Calculate rewards for agent actions in the SQL analyst environment.""" + + def calculate( + self, + action: Action, + result: Optional[QueryResult], + task: Any, + step: int, + query_history: List[str], + terminal: bool, + ) -> float: + """Calculate reward based on action, result, and task.""" + reward = 0.0 + + if action.sql_query and result: + if not result.error: + reward += 0.15 + + relevant = self._count_relevant_tables( + action.sql_query, task.relevant_tables + ) + if relevant > 0: + reward += 0.10 + + if result.rows 
and len(result.rows) > 0: + reward += 0.05 + + if result.rows and len(result.rows) < 1000: + reward += 0.05 + + if step > 3: + reward -= 0.02 * (step - 3) + + if self._is_stuck(query_history): + reward -= 0.10 + + if terminal and action.submit_answer: + task_score = task.grade(action.submit_answer) + reward += task_score * 0.60 + + return max(0.0, min(1.0, reward)) + + def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int: + query_lower = query.lower() + return sum(1 for t in relevant_tables if t.lower() in query_lower) + + def _is_stuck(self, history: List[str]) -> bool: + if len(history) < 3: + return False + return len(set(history[-3:])) == 1 diff --git a/env/server.py b/env/server.py new file mode 100644 index 0000000000000000000000000000000000000000..cd6ce5b4c71d2893e22dbe28800c9fc0fa04f2bf --- /dev/null +++ b/env/server.py @@ -0,0 +1,254 @@ +""" +FastAPI server for SQL Data Analyst OpenEnv. + +Provides REST and WebSocket endpoints for HuggingFace Spaces deployment. 
+""" + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel +from typing import Optional, Dict, Any +import json +import asyncio + +from env import SQLAnalystEnv, Action + + +app = FastAPI(title="SQL Data Analyst Environment") + +envs: Dict[str, SQLAnalystEnv] = {} + + +class ResetRequest(BaseModel): + task_id: str = "monthly_signups" + + +class StepRequest(BaseModel): + session_id: str + sql_query: Optional[str] = None + submit_answer: Optional[str] = None + + +class StateRequest(BaseModel): + session_id: str + + +@app.get("/") +async def root(): + return { + "name": "sql-data-analyst", + "version": "1.0.0", + "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation", + } + + +@app.post("/reset") +async def reset(req: ResetRequest) -> Dict[str, Any]: + session_id = req.task_id + + env = SQLAnalystEnv(task_id=req.task_id) + result = env.reset() + envs[session_id] = env + + return { + "session_id": session_id, + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "step": result.observation.step, + "max_steps": result.observation.max_steps, + "hints": result.observation.hints, + "done": result.observation.done, + }, + "reward": result.reward, + "done": result.done, + } + + +@app.post("/step") +async def step(req: StepRequest) -> Dict[str, Any]: + session_id = req.session_id + + if session_id not in envs: + return {"error": "Session not found. 
Call /reset first."} + + env = envs[session_id] + + action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer) + + result = env.step(action) + + return { + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "last_query": result.observation.last_query, + "last_result": { + "columns": result.observation.last_result.columns + if result.observation.last_result + else None, + "rows": result.observation.last_result.rows + if result.observation.last_result + else None, + "error": result.observation.last_result.error + if result.observation.last_result + else None, + } + if result.observation.last_result + else None, + "last_error": result.observation.last_error, + "step": result.observation.step, + "max_steps": result.observation.max_steps, + "hints": result.observation.hints, + "done": result.observation.done, + }, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + + +@app.post("/state") +async def state(req: StateRequest) -> Dict[str, Any]: + session_id = req.session_id + + if session_id not in envs: + return {"error": "Session not found. 
Call /reset first."} + + env = envs[session_id] + state = env.state() + + return { + "task_id": state.task_id, + "difficulty": state.difficulty, + "step": state.step, + "max_steps": state.max_steps, + "query_history": state.query_history, + "total_reward": state.total_reward, + "done": state.done, + } + + +@app.post("/delete") +async def delete_session(req: StateRequest) -> Dict[str, str]: + session_id = req.session_id + + if session_id in envs: + del envs[session_id] + return {"status": "deleted", "session_id": session_id} + + return {"status": "not_found", "session_id": session_id} + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + session_id = None + env = None + + try: + while True: + data = await websocket.receive_text() + message = json.loads(data) + + action_type = message.get("type") + + if action_type == "reset": + task_id = message.get("task_id", "monthly_signups") + env = SQLAnalystEnv(task_id=task_id) + result = env.reset() + session_id = task_id + envs[session_id] = env + + await websocket.send_json( + { + "type": "reset", + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "step": result.observation.step, + "max_steps": result.observation.max_steps, + "hints": result.observation.hints, + }, + "reward": result.reward, + "done": result.done, + } + ) + + elif action_type == "step": + if not env: + await websocket.send_json({"error": "Call reset first"}) + continue + + action = Action( + sql_query=message.get("sql_query"), + submit_answer=message.get("submit_answer"), + ) + + result = env.step(action) + + await websocket.send_json( + { + "type": "step", + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "last_query": result.observation.last_query, + "last_result": { + "columns": result.observation.last_result.columns + if result.observation.last_result + else 
None, + "rows": result.observation.last_result.rows + if result.observation.last_result + else None, + "error": result.observation.last_result.error + if result.observation.last_result + else None, + } + if result.observation.last_result + else None, + "step": result.observation.step, + "hints": result.observation.hints, + "done": result.observation.done, + }, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + ) + + elif action_type == "state": + if not env: + await websocket.send_json({"error": "Call reset first"}) + continue + + state = env.state() + + await websocket.send_json( + { + "type": "state", + "task_id": state.task_id, + "difficulty": state.difficulty, + "step": state.step, + "max_steps": state.max_steps, + "query_history": state.query_history, + "total_reward": state.total_reward, + "done": state.done, + } + ) + + elif action_type == "close": + if session_id and session_id in envs: + del envs[session_id] + break + + except WebSocketDisconnect: + pass + except Exception as e: + await websocket.send_json({"error": str(e)}) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=7860) diff --git a/env/tasks/__init__.py b/env/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef74bb4abe44e3d0caec0cbd2b89e39580236dbc --- /dev/null +++ b/env/tasks/__init__.py @@ -0,0 +1,20 @@ +from .base import BaseTask +from .easy import MonthlySignupsTask +from .medium import TopRevenueCategoryTask +from .hard import ChurnAnalysisTask + + +TASKS = { + "monthly_signups": MonthlySignupsTask(), + "top_revenue_category": TopRevenueCategoryTask(), + "churn_analysis": ChurnAnalysisTask(), +} + + +__all__ = [ + "BaseTask", + "MonthlySignupsTask", + "TopRevenueCategoryTask", + "ChurnAnalysisTask", + "TASKS", +] diff --git a/env/tasks/base.py b/env/tasks/base.py new file mode 100644 index 0000000000000000000000000000000000000000..c97be757e04709f605f9f6c955e64cd32515d841 --- /dev/null 
+++ b/env/tasks/base.py @@ -0,0 +1,52 @@ +from abc import ABC, abstractmethod +import sqlite3 +import re +from typing import Any, List, Optional + + +class BaseTask(ABC): + """Abstract base class for all tasks.""" + + task_id: str + difficulty: str + max_steps: int + question: str + relevant_tables: List[str] + sql_hint: str + + def __init__(self): + self.ground_truth: Any = None + self.top_3_categories: List[str] = [] + + @abstractmethod + def compute_ground_truth(self, conn: sqlite3.Connection) -> None: + """Compute ground truth after database seeding.""" + pass + + @abstractmethod + def grade(self, submitted_answer: str) -> float: + """Grade the submitted answer. Returns score 0.0-1.0.""" + pass + + def get_hints(self, step: int) -> List[str]: + """Return progressive hints based on current step.""" + hints = [] + if step > 5: + hints.append( + f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}" + ) + if step > 10: + hints.append(f"Hint: Try using {self.sql_hint}") + if step > 15: + hints.append("Hint: Make sure to submit your answer with submit_answer.") + return hints + + def _normalize(self, text: str) -> str: + """Remove common LLM formatting and normalize text.""" + text = text.strip().lower() + text = re.sub(r"the (answer|result|category) is:?\s*", "", text) + text = re.sub(r"\*+", "", text) + text = re.sub(r"```.*?```", "", text, flags=re.DOTALL) + text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text) + text = re.sub(r"\s+", " ", text) + return text.strip() diff --git a/env/tasks/easy.py b/env/tasks/easy.py new file mode 100644 index 0000000000000000000000000000000000000000..cc936675157d104a1a97bdfe3c67b42bf74bda35 --- /dev/null +++ b/env/tasks/easy.py @@ -0,0 +1,32 @@ +import sqlite3 +from .base import BaseTask + + +class MonthlySignupsTask(BaseTask): + """Task 1 β€” Easy: Count users who signed up in the last 30 days.""" + + task_id = "monthly_signups" + difficulty = "easy" + max_steps = 10 + question = "How many users signed up 
in the last 30 days?" + relevant_tables = ["users"] + sql_hint = "COUNT(*) with WHERE clause on created_at" + + def compute_ground_truth(self, conn: sqlite3.Connection) -> None: + result = conn.execute( + "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')" + ).fetchone() + self.ground_truth = result[0] if result else 0 + + def grade(self, submitted_answer: str) -> float: + try: + val = int(submitted_answer.strip().replace(",", "")) + if val == self.ground_truth: + return 1.0 + if abs(val - self.ground_truth) <= 3: + return 0.6 + if abs(val - self.ground_truth) <= 10: + return 0.3 + except (ValueError, AttributeError): + pass + return 0.0 diff --git a/env/tasks/hard.py b/env/tasks/hard.py new file mode 100644 index 0000000000000000000000000000000000000000..852e8eda17fc31ab0b3a78bdd0958881cae3301f --- /dev/null +++ b/env/tasks/hard.py @@ -0,0 +1,59 @@ +import sqlite3 +from .base import BaseTask + + +class ChurnAnalysisTask(BaseTask): + """Task 3 β€” Hard: Find users who placed exactly 3 orders and then churned.""" + + task_id = "churn_analysis" + difficulty = "hard" + max_steps = 20 + question = "Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list." 
+ relevant_tables = ["users", "orders"] + sql_hint = "CTE with COUNT and HAVING" + + def compute_ground_truth(self, conn: sqlite3.Connection) -> None: + result = conn.execute(""" + WITH order_counts AS ( + SELECT user_id, COUNT(*) AS total_orders, + MAX(created_at) AS last_order_date + FROM orders + WHERE status = 'completed' + GROUP BY user_id + HAVING COUNT(*) = 3 + ), + churned AS ( + SELECT oc.user_id + FROM order_counts oc + WHERE oc.last_order_date < DATE('now', '-90 days') + ) + SELECT u.email + FROM users u + JOIN churned c ON u.id = c.user_id + """).fetchall() + + self.ground_truth = {row[0].lower() for row in result} + + def grade(self, submitted_answer: str) -> float: + if not submitted_answer.strip(): + return 0.0 + + submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e} + + if not submitted: + return 0.0 + + correct = {e.lower() for e in self.ground_truth} + tp = len(submitted & correct) + + if tp == 0: + return 0.0 + + precision = tp / len(submitted) if submitted else 0 + recall = tp / len(correct) if correct else 0 + + if precision + recall == 0: + return 0.0 + + f1 = 2 * precision * recall / (precision + recall) + return round(f1, 3) diff --git a/env/tasks/medium.py b/env/tasks/medium.py new file mode 100644 index 0000000000000000000000000000000000000000..1a0f82dd33873f3c790efa7b0a3601cc6b5542d8 --- /dev/null +++ b/env/tasks/medium.py @@ -0,0 +1,54 @@ +import sqlite3 +from .base import BaseTask + + +class TopRevenueCategoryTask(BaseTask): + """Task 2 β€” Medium: Find product category with most revenue in Q3.""" + + task_id = "top_revenue_category" + difficulty = "medium" + max_steps = 15 + question = ( + "Which product category generated the most revenue in Q3 (July-September)?" 
+ ) + relevant_tables = ["orders", "order_items", "products"] + sql_hint = "JOIN with GROUP BY and ORDER BY" + + def compute_ground_truth(self, conn: sqlite3.Connection) -> None: + result = conn.execute(""" + SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue + FROM orders o + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id + WHERE o.created_at >= '2024-07-01' AND o.created_at < '2024-10-01' + AND o.status = 'completed' + GROUP BY p.category + ORDER BY revenue DESC + LIMIT 1 + """).fetchone() + + self.ground_truth = result[0] if result else None + + all_categories = conn.execute(""" + SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue + FROM orders o + JOIN order_items oi ON o.id = oi.order_id + JOIN products p ON oi.product_id = p.id + WHERE o.created_at >= '2024-07-01' AND o.created_at < '2024-10-01' + AND o.status = 'completed' + GROUP BY p.category + ORDER BY revenue DESC + """).fetchall() + + self.top_3_categories = [row[0] for row in all_categories[:3]] + + def grade(self, submitted_answer: str) -> float: + answer = self._normalize(submitted_answer) + + if self.ground_truth and self.ground_truth.lower() in answer: + return 1.0 + + if any(cat.lower() in answer for cat in self.top_3_categories): + return 0.4 + + return 0.0 diff --git a/env/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..140816331dc7185e448f5ba964aa0a20ede639eb --- /dev/null +++ b/env/utils.py @@ -0,0 +1,29 @@ +import re + + +def normalize_answer(raw: str) -> str: + """Remove common LLM answer preambles and formatting.""" + text = raw.strip().lower() + text = re.sub(r"the (answer|result) is:?\s*", "", text) + text = re.sub(r"\*+", "", text) + text = re.sub(r"```.*?```", "", text, flags=re.DOTALL) + text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text) + text = re.sub(r"\s+", " ", text) + return text.strip() + + +FORBIDDEN_KEYWORDS = [ + "DROP", + "DELETE", + "INSERT", + "UPDATE", + "ALTER", + "CREATE", + 
"TRUNCATE", +] + + +def is_safe_query(query: str) -> bool: + """Check if query is safe (SELECT-only).""" + upper = query.upper() + return not any(re.search(r"\b" + kw + r"\b", upper) for kw in FORBIDDEN_KEYWORDS) diff --git a/hf_space/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/hf_space/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git 
a/hf_space/README.md b/hf_space/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f36be1e11762ff9082b9ed7cf6d556fac23d0e06 --- /dev/null +++ b/hf_space/README.md @@ -0,0 +1,11 @@ +--- +title: Sql Data Analyst +emoji: πŸ“‰ +colorFrom: gray +colorTo: green +sdk: docker +pinned: false +license: mit +--- + +Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..43b13e6f7e684bf7fa238c309436dbf6d2192310 --- /dev/null +++ b/models.py @@ -0,0 +1,58 @@ +from pydantic import BaseModel, Field +from typing import Optional, List, Any + +class Action(BaseModel): + """What the agent can do each step.""" + sql_query: Optional[str] = Field( + None, + description="A SQL SELECT query to execute against the database" + ) + submit_answer: Optional[str] = Field( + None, + description="Final answer to submit. Ends the episode." + ) + + def is_valid(self) -> bool: + # Exactly one of the two must be set + return bool(self.sql_query) != bool(self.submit_answer) + + +class QueryResult(BaseModel): + """Result of executing a SQL query.""" + columns: List[str] = [] + rows: List[List[Any]] = [] + error: Optional[str] = None + truncated: bool = False + total_rows: int = 0 + + +class Observation(BaseModel): + """What the agent sees after each step.""" + schema_summary: str = Field(..., description="Compact DB schema") + question: str = Field(..., description="Business question to answer") + last_query: Optional[str] = None + last_result: Optional[QueryResult] = None + last_error: Optional[str] = None + step: int = 0 + max_steps: int = 20 + hints: List[str] = [] + done: bool = False + + +class StepResult(BaseModel): + """Full result returned by step().""" + observation: Observation + reward: float = 0.0 + done: bool = False + info: dict = {} + + +class EnvState(BaseModel): + """Full environment state returned by state().""" + 
task_id: str + difficulty: str + step: int + max_steps: int + query_history: List[str] = [] + total_reward: float = 0.0 + done: bool = False diff --git a/openenv.yaml b/openenv.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38b5143c431692d7a7965ee7b18f636e8aa0860a --- /dev/null +++ b/openenv.yaml @@ -0,0 +1,100 @@ +name: sql-data-analyst +version: "1.0.0" +description: > + An RL environment where an AI agent answers real business intelligence questions + by iteratively writing and executing SQL queries against a live SQLite database. + Simulates the day-to-day workflow of a data analyst. + +tags: + - openenv + - sql + - data-analysis + - business-intelligence + - real-world + +author: sql-data-analyst +repository: https://huggingface.co/spaces/sql-data-analyst + +observation_space: + type: dict + fields: + schema_summary: + type: string + description: Compact one-line-per-table schema of the database + question: + type: string + description: Natural language business question to answer + last_query: + type: string + nullable: true + description: The last SQL query executed by the agent + last_result: + type: object + nullable: true + description: Result of the last query (columns, rows, error) + last_error: + type: string + nullable: true + description: SQL error message if last query failed + step: + type: integer + description: Current step number + max_steps: + type: integer + description: Maximum steps allowed for this task + hints: + type: array + items: string + description: Progressive hints revealed as steps increase + done: + type: boolean + description: Whether the episode is complete + +action_space: + type: union + description: Agent must provide exactly one of the following + options: + sql_query: + type: string + description: A SELECT or WITH SQL query to execute + submit_answer: + type: string + description: Final answer to the question. Ends the episode. 
+ +tasks: + - id: monthly_signups + difficulty: easy + max_steps: 10 + description: "Count the number of users who signed up in the last 30 days" + skills_required: + - COUNT + - WHERE with date filter + + - id: top_revenue_category + difficulty: medium + max_steps: 15 + description: "Find which product category generated the most revenue in Q3" + skills_required: + - JOIN (3 tables) + - GROUP BY + - SUM aggregation + - Date range filtering + + - id: churn_analysis + difficulty: hard + max_steps: 20 + description: > + Find email addresses of users who placed exactly 3 orders and then + never ordered again (churned after their 3rd purchase) + skills_required: + - Subqueries + - HAVING clause + - Date logic + - Window functions (optional) + +baseline_scores: + monthly_signups: 0.85 + top_revenue_category: 0.65 + churn_analysis: 0.40 + average: 0.63 + model: gpt-4o-mini \ No newline at end of file diff --git a/progress.txt b/progress.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..53eb2d4a4a50e019bb35ec7e884a69f68b009f8d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,42 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "sql-data-analyst" +version = "1.0.0" +description = "SQL Data Analyst OpenEnv - RL environment for SQL query generation" +readme = "README.md" +license = {text = "MIT"} +authors = [ + {name = "Hackathon Team", email = "team@example.com"} +] +requires-python = ">=3.11" +dependencies = [ + "openenv>=0.1.13", + "pydantic>=2.0", + "fastapi>=0.100", + "uvicorn>=0.20", + "openai>=1.0", + "faker>=18.0", + "pytest>=7.0", +] + +[project.scripts] +openenv-sql-analyst = "server.app:main" + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-asyncio>=0.21", +] + +[tool.setuptools.packages.find] 
+where = ["."] +include = ["env*", "baseline*", "server*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +python_classes = ["Test*"] +python_functions = ["test_*"] \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9207a91bc34a0382a1cb3629e5b06682215d0a6a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,6 @@ +pydantic>=2.0 +fastapi>=0.100 +uvicorn>=0.20 +openai>=1.0 +faker>=18.0 +pytest>=7.0 \ No newline at end of file diff --git a/server/__init__.py b/server/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/server/app.py b/server/app.py new file mode 100644 index 0000000000000000000000000000000000000000..04e6c5a69e685c98e15f634b9e7a6ad70ecc89b3 --- /dev/null +++ b/server/app.py @@ -0,0 +1,12 @@ +from env.server import app as _app +import uvicorn + + +def main(): + uvicorn.run(_app, host="0.0.0.0", port=7860) + + +if __name__ == "__main__": + main() + +__all__ = ["app", "main"] diff --git a/temp_upload/baseline/__init__.py b/temp_upload/baseline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/temp_upload/baseline/prompts.py b/temp_upload/baseline/prompts.py new file mode 100644 index 0000000000000000000000000000000000000000..713dbb8a520bab00cfb9fb3c0488aab139395ff6 --- /dev/null +++ b/temp_upload/baseline/prompts.py @@ -0,0 +1,78 @@ +""" +Prompt templates for the baseline inference script. +""" + +SYSTEM_PROMPT = """ +You are a SQL data analyst. You are given a database schema and a business question. +Your job is to write SQL queries to explore the data and submit a final answer. + +Rules: +- Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.) +- Reply with JSON only. No explanation. 
+- To run a query: {"sql_query": "SELECT ..."} +- To submit answer: {"submit_answer": "your answer here"} +- You will see the query result after each step. +- Submit your answer when you are confident. + +Important: +- Always use valid SQL syntax +- Table names: users, products, orders, order_items, events +- Dates are stored as ISO timestamps +- Always filter orders by status='completed' for revenue calculations +""" + + +def build_prompt(obs) -> str: + """Build the user prompt from an observation.""" + parts = [ + f"Database schema:\n{obs.schema_summary}", + f"\nQuestion: {obs.question}", + f"\nStep: {obs.step} / {obs.max_steps}", + ] + + if obs.last_query: + parts.append(f"\nLast query:\n{obs.last_query}") + + if obs.last_result: + if obs.last_result.error: + parts.append(f"\nSQL error: {obs.last_result.error}") + elif obs.last_result.rows: + cols = obs.last_result.columns + rows = obs.last_result.rows[:10] + parts.append(f"\nResult columns: {cols}") + parts.append( + f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}" + ) + + if obs.hints: + parts.append(f"\nHints: {'; '.join(obs.hints)}") + + parts.append("\nWhat is your next action? 
Reply with JSON only.") + return "\n".join(parts) + + +import json + + +def parse_action(response_text: str | None): + """Extract JSON action from LLM response.""" + from env import Action + + if not response_text: + return Action(submit_answer="") + + text = response_text.strip() + + text = text.replace("```json", "").replace("```", "").strip() + + try: + data = json.loads(text) + + if "sql_query" in data and data["sql_query"]: + return Action(sql_query=data["sql_query"]) + elif "submit_answer" in data and data["submit_answer"]: + return Action(submit_answer=data["submit_answer"]) + except json.JSONDecodeError: + pass + + return Action(submit_answer=text) diff --git a/temp_upload/baseline/run_baseline.py b/temp_upload/baseline/run_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..b20b8164e81ef20116cbff1cf13836f2ea501c41 --- /dev/null +++ b/temp_upload/baseline/run_baseline.py @@ -0,0 +1,162 @@ +""" +Baseline inference script for sql-data-analyst OpenEnv. + +Usage: + export OPENAI_API_KEY=sk-... + python baseline/run_baseline.py + +Produces reproducible scores across all 3 tasks. +""" + +import os +import json +import sys +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from typing import List, Dict, Any + +try: + from openai import OpenAI +except ImportError: + print("Error: openai package not installed. 
def run_task(
    client: OpenAI, task_id: str, max_steps: int = MAX_STEPS
) -> Dict[str, Any]:
    """Run a single task with the LLM agent.

    Drives one episode of ``SQLAnalystEnv``: repeatedly prompts the model,
    parses its reply into an ``Action``, and steps the environment until
    the episode terminates or ``max_steps`` LLM turns have been used.

    Args:
        client: Authenticated OpenAI client used for chat completions.
        task_id: Key into the environment's task registry.
        max_steps: Upper bound on LLM turns (defaults to MAX_STEPS).

    Returns:
        Summary dict with keys: task_id, difficulty, total_reward, steps,
        max_steps.
    """
    print(f"\n{'=' * 50}")
    print(f"Task: {task_id}")
    print("=" * 50)

    env = SQLAnalystEnv(task_id=task_id)
    result = env.reset()
    obs = result.observation
    history = []  # running chat transcript of user/assistant turns
    total_reward = 0.0

    print(f"Question: {obs.question}")
    print(f"Schema: {obs.schema_summary[:200]}...")

    for step in range(1, max_steps + 1):
        # The environment may have terminated on the previous step.
        if result.done:
            print(f"Episode done at step {step - 1}")
            break

        user_prompt = build_prompt(obs)
        history.append({"role": "user", "content": user_prompt})

        try:
            response = client.chat.completions.create(
                model=MODEL,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    # Only the last 8 turns are sent, bounding context size.
                    *history[-8:],
                ],
                # Deterministic decoding for reproducible baseline scores.
                temperature=0.0,
            )
        except Exception as e:
            # Abort the episode on API failure; reward earned so far is kept.
            print(f"API Error: {e}")
            break

        reply = response.choices[0].message.content or ""
        history.append({"role": "assistant", "content": reply})

        action = parse_action(reply)

        if action.sql_query:
            print(f"Step {step}: Executing SQL...")
            print(f"  Query: {action.sql_query[:100]}...")
        else:
            print(f"Step {step}: Submitting answer...")
            print(
                f"  Answer: {action.submit_answer[:100] if action.submit_answer else 'empty'}..."
            )

        result = env.step(action)
        obs = result.observation
        # The env accumulates reward; the running total is surfaced via info.
        total_reward = result.info.get("total_reward", 0.0)

        if result.done:
            break

    state = env.state()
    print(f"\nFinal total reward: {total_reward:.3f}")
    print(f"Steps taken: {state.step}")

    return {
        "task_id": task_id,
        "difficulty": state.difficulty,
        "total_reward": round(total_reward, 3),
        "steps": state.step,
        "max_steps": state.max_steps,
    }
b/temp_upload/env/__init__.py @@ -0,0 +1,13 @@ +from .models import Action, QueryResult, Observation, StepResult, EnvState +from .environment import SQLAnalystEnv +from .tasks import TASKS + +__all__ = [ + "Action", + "QueryResult", + "Observation", + "StepResult", + "EnvState", + "SQLAnalystEnv", + "TASKS", +] diff --git a/temp_upload/env/database.py b/temp_upload/env/database.py new file mode 100644 index 0000000000000000000000000000000000000000..dafcf06374fe19d832203efb59331c6b5b2ea679 --- /dev/null +++ b/temp_upload/env/database.py @@ -0,0 +1,267 @@ +import sqlite3 +import random +from datetime import datetime, timedelta +from typing import Optional, Any +from faker import Faker + +fake = Faker() + +SEED_CONFIG = { + "users": 500, + "products": 80, + "orders": 2000, + "order_items": 5000, + "events": 8000, +} + +CATEGORIES = ["Electronics", "Clothing", "Books", "Home & Garden", "Sports"] +PLAN_TYPES = ["free", "pro", "enterprise"] +ORDER_STATUSES = ["pending", "completed", "refunded"] +EVENT_TYPES = ["page_view", "add_to_cart", "checkout", "login", "logout"] + + +def create_database(db_path: str = ":memory:") -> sqlite3.Connection: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + + conn.execute(""" + CREATE TABLE users ( + id INTEGER PRIMARY KEY, + email TEXT NOT NULL, + country TEXT, + plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')), + created_at TIMESTAMP NOT NULL, + churned_at TIMESTAMP + ) + """) + + conn.execute(""" + CREATE TABLE products ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + category TEXT NOT NULL, + price REAL, + cost REAL + ) + """) + + conn.execute(""" + CREATE TABLE orders ( + id INTEGER PRIMARY KEY, + user_id INTEGER REFERENCES users(id), + created_at TIMESTAMP NOT NULL, + status TEXT CHECK(status IN ('pending', 'completed', 'refunded')), + total REAL + ) + """) + + conn.execute(""" + CREATE TABLE order_items ( + id INTEGER PRIMARY KEY, + order_id INTEGER REFERENCES orders(id), + product_id INTEGER REFERENCES 
def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple:
    """Seed the orders and order_items tables.

    About 20% of orders land inside Q3 2024 (the window the revenue task
    queries); the rest are spread over the first half of 2024.  Each order
    gets 1-5 line items and its denormalized ``total`` is backfilled from
    them.

    Args:
        conn: Open SQLite connection with the schema already created.
        users: List of (user_id, created_at) tuples from _seed_users.
        products: List of (product_id, category, price) tuples.

    Returns:
        tuple: ``(orders, order_items)`` where orders is a list of
        (order_id, user_id, created_at, status) and order_items is a list
        of (order_id, product_id, qty, unit_price).
    """
    orders = []
    order_items = []

    q3_start = datetime(2024, 7, 1)
    q3_end = datetime(2024, 9, 30)
    # Offsets 0..91 cover Jul 1 .. Sep 30 inclusive (was a magic 91, with
    # q3_end left as an unused local).
    q3_span_days = (q3_end - q3_start).days
    old_date = datetime(2024, 1, 1)

    for i in range(SEED_CONFIG["orders"]):
        user_id = random.choice(users)[0]

        # 20% of orders fall in Q3 so the revenue task has signal.
        if random.random() < 0.2:
            created_at = q3_start + timedelta(days=random.randint(0, q3_span_days))
        else:
            created_at = old_date + timedelta(days=random.randint(0, 180))

        # Mostly completed orders; a few pending/refunded for realism.
        status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0]

        conn.execute(
            "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)",
            (user_id, created_at.isoformat(), status, 0),
        )

        order_id = i + 1
        order_total = 0

        num_items = random.randint(1, 5)
        for _ in range(num_items):
            product = random.choice(products)
            qty = random.randint(1, 3)
            unit_price = product[2]
            order_total += qty * unit_price

            conn.execute(
                "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)",
                (order_id, product[0], qty, unit_price),
            )
            # Previously declared but never filled: record items so the
            # returned list actually matches the documented contract.
            order_items.append((order_id, product[0], qty, unit_price))

        # Backfill the denormalized order total from its line items.
        conn.execute(
            "UPDATE orders SET total = ? WHERE id = ?",
            (round(order_total, 2), order_id),
        )
        orders.append((order_id, user_id, created_at, status))

    conn.commit()
    return orders, order_items
    def __init__(self, task_id: str = "monthly_signups"):
        """Create an environment bound to a single task.

        Args:
            task_id: Key into the TASKS registry.  An unknown id raises
                AssertionError (note: asserts are stripped under ``-O``).
        """
        assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
        self.task_id = task_id
        self.task = TASKS[task_id]
        # Created lazily in reset(); None until the first episode starts.
        self.conn: Optional[sqlite3.Connection] = None
        self.step_count: int = 0
        self.total_reward: float = 0.0
        self.done: bool = False
        # SQL strings executed so far this episode (used by the reward's
        # stuck-detection).
        self._query_history: list = []
        self._reward_calc = RewardCalculator()
    def step(self, action: Action) -> StepResult:
        """Execute one agent action. Return (observation, reward, done, info).

        Exactly one of ``action.sql_query`` / ``action.submit_answer`` must
        be set.  Submitting an answer — or exhausting the task's step
        budget — terminates the episode.

        Raises:
            AssertionError: if called before reset(), after the episode is
                done, or with an invalid action (asserts strip under -O).
        """
        assert self.conn is not None, "Call reset() before step()"
        assert not self.done, "Episode is done. Call reset()."
        assert action.is_valid(), (
            "Action must have exactly one of: sql_query, submit_answer"
        )

        self.step_count += 1
        query_result = None
        error = None

        if action.sql_query:
            query_result = self._execute_sql(action.sql_query)
            self._query_history.append(action.sql_query)
            error = query_result.error

        # Terminal on explicit submission or when the step budget is spent.
        terminal = (
            action.submit_answer is not None or self.step_count >= self.task.max_steps
        )

        # Reward is computed AFTER appending to history so stuck-detection
        # sees the current query as well.
        reward = self._reward_calc.calculate(
            action=action,
            result=query_result,
            task=self.task,
            step=self.step_count,
            query_history=self._query_history,
            terminal=terminal,
        )
        self.total_reward += reward
        self.done = terminal

        obs = Observation(
            schema_summary=get_schema_summary(self.conn),
            question=self.task.question,
            last_query=action.sql_query,
            last_result=query_result,
            last_error=error,
            step=self.step_count,
            max_steps=self.task.max_steps,
            # Hints unlock progressively as the agent burns steps.
            hints=self.task.get_hints(self.step_count),
            done=self.done,
        )

        return StepResult(
            observation=obs,
            reward=round(reward, 3),
            done=self.done,
            info={
                "step": self.step_count,
                # Running episode total, rounded for stable client display.
                "total_reward": round(self.total_reward, 3),
                "task_id": self.task_id,
            },
        )
task_id=self.task_id, + difficulty=self.task.difficulty, + step=self.step_count, + max_steps=self.task.max_steps, + query_history=self._query_history.copy(), + total_reward=round(self.total_reward, 3), + done=self.done, + ) + + def _execute_sql(self, query: str) -> QueryResult: + """Execute SQL safely. Block non-SELECT. Return up to 50 rows.""" + q = query.strip().upper() + if not q.startswith("SELECT") and not q.startswith("WITH"): + return QueryResult(error="Only SELECT / WITH queries are allowed.") + try: + cursor = self.conn.execute(query) + cols = [d[0] for d in cursor.description] if cursor.description else [] + rows = cursor.fetchmany(50) + total = len(rows) + return QueryResult( + columns=cols, + rows=[list(r) for r in rows], + truncated=(total == 50), + total_rows=total, + ) + except Exception as e: + return QueryResult(error=str(e)) diff --git a/temp_upload/env/models.py b/temp_upload/env/models.py new file mode 100644 index 0000000000000000000000000000000000000000..43b13e6f7e684bf7fa238c309436dbf6d2192310 --- /dev/null +++ b/temp_upload/env/models.py @@ -0,0 +1,58 @@ +from pydantic import BaseModel, Field +from typing import Optional, List, Any + +class Action(BaseModel): + """What the agent can do each step.""" + sql_query: Optional[str] = Field( + None, + description="A SQL SELECT query to execute against the database" + ) + submit_answer: Optional[str] = Field( + None, + description="Final answer to submit. Ends the episode." 
class RewardCalculator:
    """Shapes per-step rewards for the SQL analyst environment.

    Small dense bonuses encourage valid, relevant, productive queries;
    mild penalties discourage dawdling and repeating the same query; the
    bulk of the signal (60%) comes from the graded final answer.
    """

    def calculate(
        self,
        action: Action,
        result: Optional[QueryResult],
        task: Any,
        step: int,
        query_history: List[str],
        terminal: bool,
    ) -> float:
        """Score one step; the return value is clamped to [0.0, 1.0]."""
        score = 0.0

        if action.sql_query and result and not result.error:
            # Valid query that actually ran.
            score += 0.15
            # Bonus for touching at least one task-relevant table.
            if self._count_relevant_tables(action.sql_query, task.relevant_tables):
                score += 0.10
            # Bonus for a non-empty, reasonably sized result set.
            if result.rows:
                score += 0.05
                if len(result.rows) < 1000:
                    score += 0.05

        # Mild time pressure after a three-step grace period.
        if step > 3:
            score -= 0.02 * (step - 3)

        # Penalize issuing the exact same query three times running.
        if self._is_stuck(query_history):
            score -= 0.10

        # The graded final answer contributes up to 0.60.
        if terminal and action.submit_answer:
            score += 0.60 * task.grade(action.submit_answer)

        return max(0.0, min(1.0, score))

    def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
        """Number of task-relevant table names mentioned in the query."""
        haystack = query.lower()
        hits = 0
        for table in relevant_tables:
            if table.lower() in haystack:
                hits += 1
        return hits

    def _is_stuck(self, history: List[str]) -> bool:
        """True when the last three queries are all identical."""
        if len(history) < 3:
            return False
        first, second, third = history[-3:]
        return first == second == third
+""" + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from pydantic import BaseModel +from typing import Optional, Dict, Any +import json +import asyncio + +from env import SQLAnalystEnv, Action + + +app = FastAPI(title="SQL Data Analyst Environment") + +envs: Dict[str, SQLAnalystEnv] = {} + + +class ResetRequest(BaseModel): + task_id: str = "monthly_signups" + + +class StepRequest(BaseModel): + session_id: str + sql_query: Optional[str] = None + submit_answer: Optional[str] = None + + +class StateRequest(BaseModel): + session_id: str + + +@app.get("/") +async def root(): + return { + "name": "sql-data-analyst", + "version": "1.0.0", + "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation", + } + + +@app.post("/reset") +async def reset(req: ResetRequest) -> Dict[str, Any]: + session_id = req.task_id + + env = SQLAnalystEnv(task_id=req.task_id) + result = env.reset() + envs[session_id] = env + + return { + "session_id": session_id, + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "step": result.observation.step, + "max_steps": result.observation.max_steps, + "hints": result.observation.hints, + "done": result.observation.done, + }, + "reward": result.reward, + "done": result.done, + } + + +@app.post("/step") +async def step(req: StepRequest) -> Dict[str, Any]: + session_id = req.session_id + + if session_id not in envs: + return {"error": "Session not found. 
Call /reset first."} + + env = envs[session_id] + + action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer) + + result = env.step(action) + + return { + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "last_query": result.observation.last_query, + "last_result": { + "columns": result.observation.last_result.columns + if result.observation.last_result + else None, + "rows": result.observation.last_result.rows + if result.observation.last_result + else None, + "error": result.observation.last_result.error + if result.observation.last_result + else None, + } + if result.observation.last_result + else None, + "last_error": result.observation.last_error, + "step": result.observation.step, + "max_steps": result.observation.max_steps, + "hints": result.observation.hints, + "done": result.observation.done, + }, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + + +@app.post("/state") +async def state(req: StateRequest) -> Dict[str, Any]: + session_id = req.session_id + + if session_id not in envs: + return {"error": "Session not found. 
Call /reset first."} + + env = envs[session_id] + state = env.state() + + return { + "task_id": state.task_id, + "difficulty": state.difficulty, + "step": state.step, + "max_steps": state.max_steps, + "query_history": state.query_history, + "total_reward": state.total_reward, + "done": state.done, + } + + +@app.post("/delete") +async def delete_session(req: StateRequest) -> Dict[str, str]: + session_id = req.session_id + + if session_id in envs: + del envs[session_id] + return {"status": "deleted", "session_id": session_id} + + return {"status": "not_found", "session_id": session_id} + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + await websocket.accept() + + session_id = None + env = None + + try: + while True: + data = await websocket.receive_text() + message = json.loads(data) + + action_type = message.get("type") + + if action_type == "reset": + task_id = message.get("task_id", "monthly_signups") + env = SQLAnalystEnv(task_id=task_id) + result = env.reset() + session_id = task_id + envs[session_id] = env + + await websocket.send_json( + { + "type": "reset", + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "step": result.observation.step, + "max_steps": result.observation.max_steps, + "hints": result.observation.hints, + }, + "reward": result.reward, + "done": result.done, + } + ) + + elif action_type == "step": + if not env: + await websocket.send_json({"error": "Call reset first"}) + continue + + action = Action( + sql_query=message.get("sql_query"), + submit_answer=message.get("submit_answer"), + ) + + result = env.step(action) + + await websocket.send_json( + { + "type": "step", + "observation": { + "schema_summary": result.observation.schema_summary, + "question": result.observation.question, + "last_query": result.observation.last_query, + "last_result": { + "columns": result.observation.last_result.columns + if result.observation.last_result + else 
None, + "rows": result.observation.last_result.rows + if result.observation.last_result + else None, + "error": result.observation.last_result.error + if result.observation.last_result + else None, + } + if result.observation.last_result + else None, + "step": result.observation.step, + "hints": result.observation.hints, + "done": result.observation.done, + }, + "reward": result.reward, + "done": result.done, + "info": result.info, + } + ) + + elif action_type == "state": + if not env: + await websocket.send_json({"error": "Call reset first"}) + continue + + state = env.state() + + await websocket.send_json( + { + "type": "state", + "task_id": state.task_id, + "difficulty": state.difficulty, + "step": state.step, + "max_steps": state.max_steps, + "query_history": state.query_history, + "total_reward": state.total_reward, + "done": state.done, + } + ) + + elif action_type == "close": + if session_id and session_id in envs: + del envs[session_id] + break + + except WebSocketDisconnect: + pass + except Exception as e: + await websocket.send_json({"error": str(e)}) + + +if __name__ == "__main__": + import uvicorn + + uvicorn.run(app, host="0.0.0.0", port=7860) diff --git a/temp_upload/env/tasks/__init__.py b/temp_upload/env/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ef74bb4abe44e3d0caec0cbd2b89e39580236dbc --- /dev/null +++ b/temp_upload/env/tasks/__init__.py @@ -0,0 +1,20 @@ +from .base import BaseTask +from .easy import MonthlySignupsTask +from .medium import TopRevenueCategoryTask +from .hard import ChurnAnalysisTask + + +TASKS = { + "monthly_signups": MonthlySignupsTask(), + "top_revenue_category": TopRevenueCategoryTask(), + "churn_analysis": ChurnAnalysisTask(), +} + + +__all__ = [ + "BaseTask", + "MonthlySignupsTask", + "TopRevenueCategoryTask", + "ChurnAnalysisTask", + "TASKS", +] diff --git a/temp_upload/env/tasks/base.py b/temp_upload/env/tasks/base.py new file mode 100644 index 
from abc import ABC, abstractmethod
import sqlite3
import re
from typing import Any, List, Optional


class BaseTask(ABC):
    """Abstract base class for all tasks.

    Subclasses declare the class attributes below and implement
    ``compute_ground_truth()`` / ``grade()``.  Ground truth is recomputed
    on every environment reset because the database is reseeded randomly.
    """

    # Declared by concrete subclasses:
    task_id: str                 # registry key for this task
    difficulty: str              # "easy" | "medium" | "hard"
    max_steps: int               # per-episode step budget
    question: str                # business question shown to the agent
    relevant_tables: List[str]   # tables a correct solution must touch
    sql_hint: str                # late-game hint text

    def __init__(self):
        self.ground_truth: Any = None
        self.top_3_categories: List[str] = []

    @abstractmethod
    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        """Compute ground truth after database seeding."""
        pass

    @abstractmethod
    def grade(self, submitted_answer: str) -> float:
        """Grade the submitted answer. Returns score 0.0-1.0."""
        pass

    def get_hints(self, step: int) -> List[str]:
        """Return progressive hints based on current step."""
        hints = []
        if step > 5:
            hints.append(
                f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}"
            )
        if step > 10:
            hints.append(f"Hint: Try using {self.sql_hint}")
        if step > 15:
            hints.append("Hint: Make sure to submit your answer with submit_answer.")
        return hints

    def _normalize(self, text: str) -> str:
        """Remove common LLM formatting and normalize text.

        Lowercases, strips "the answer is"-style preambles, markdown
        emphasis, code-fence markers and inline backticks, and collapses
        whitespace.  Fence *markers* are stripped but their contents are
        kept: the previous implementation deleted whole fenced blocks,
        silently discarding an answer wrapped in ``` fences.
        """
        text = text.strip().lower()
        text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
        text = re.sub(r"\*+", "", text)
        # Drop fence markers (``` or ```lang) but keep the fenced content.
        text = re.sub(r"```[a-z]*", "", text)
        text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
        text = re.sub(r"\s+", " ", text)
        return text.strip()
import sqlite3
from .base import BaseTask


class ChurnAnalysisTask(BaseTask):
    """Task 3 — Hard: Find users who placed exactly 3 orders and then churned."""

    task_id = "churn_analysis"
    difficulty = "hard"
    max_steps = 20
    question = "Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."
    relevant_tables = ["users", "orders"]
    sql_hint = "CTE with COUNT and HAVING"

    def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
        """Recompute the churned-user email set from the freshly seeded DB."""
        rows = conn.execute("""
            WITH order_counts AS (
                SELECT user_id, COUNT(*) AS total_orders,
                       MAX(created_at) AS last_order_date
                FROM orders
                WHERE status = 'completed'
                GROUP BY user_id
                HAVING COUNT(*) = 3
            ),
            churned AS (
                SELECT oc.user_id
                FROM order_counts oc
                WHERE oc.last_order_date < DATE('now', '-90 days')
            )
            SELECT u.email
            FROM users u
            JOIN churned c ON u.id = c.user_id
        """).fetchall()

        # Store lowercased for case-insensitive grading.
        self.ground_truth = {row[0].lower() for row in rows}

    def grade(self, submitted_answer: str) -> float:
        """Score the comma-separated email list by F1 against ground truth."""
        raw = submitted_answer.strip()
        if not raw:
            return 0.0

        # Keep only tokens that look like emails; compare case-insensitively.
        predicted = {tok.strip().lower() for tok in raw.split(",") if "@" in tok}
        if not predicted:
            return 0.0

        expected = {e.lower() for e in self.ground_truth}
        hits = predicted & expected
        if not hits:
            return 0.0

        precision = len(hits) / len(predicted)
        recall = len(hits) / len(expected)
        if precision + recall == 0:
            return 0.0

        f1 = 2 * precision * recall / (precision + recall)
        return round(f1, 3)
import re


def normalize_answer(raw: str) -> str:
    """Remove common LLM answer preambles and formatting.

    Lowercases, strips "the answer/result is" preambles, markdown bold,
    fenced code blocks and inline backticks, then collapses whitespace.
    """
    text = raw.strip().lower()
    text = re.sub(r"the (answer|result) is:?\s*", "", text)
    text = re.sub(r"\*+", "", text)
    text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
    text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
    text = re.sub(r"\s+", " ", text)
    return text.strip()


# SQL verbs that mutate schema or data; queries mentioning them are rejected.
FORBIDDEN_KEYWORDS = [
    "DROP",
    "DELETE",
    "INSERT",
    "UPDATE",
    "ALTER",
    "CREATE",
    "TRUNCATE",
]

# Compiled once; \b word boundaries prevent false positives on identifiers
# that merely *contain* a keyword — e.g. the `created_at` / `updated_at`
# columns used throughout this schema, which the old substring check
# wrongly flagged as CREATE / UPDATE.
_FORBIDDEN_RE = re.compile(r"\b(" + "|".join(FORBIDDEN_KEYWORDS) + r")\b")


def is_safe_query(query: str) -> bool:
    """Check if query is safe (SELECT-only).

    Returns False when the query contains any forbidden keyword as a
    whole word, case-insensitively.
    """
    return _FORBIDDEN_RE.search(query.upper()) is None
# temp_upload/tests/test_env.py
#
# This file was a byte-for-byte duplicate of tests/test_env.py.  temp_upload/
# is the transient staging directory that upload_hf.py creates, fills from the
# repo, and deletes, so a committed copy here only goes stale.  Re-export the
# canonical tests instead of maintaining a second copy by hand.
#
# NOTE(review): assumes the repository root is on sys.path when pytest
# collects this file — confirm against the project's pytest configuration.
from tests.test_env import *  # noqa: F401,F403
users")) + assert result.observation.last_result.rows[0][0] > 0 diff --git a/temp_upload/tests/test_graders.py b/temp_upload/tests/test_graders.py new file mode 100644 index 0000000000000000000000000000000000000000..4828a1815388c1cd1fa8097e25d0722a92f0cd05 --- /dev/null +++ b/temp_upload/tests/test_graders.py @@ -0,0 +1,223 @@ +import pytest +import sqlite3 +from env.tasks import ( + TASKS, + MonthlySignupsTask, + TopRevenueCategoryTask, + ChurnAnalysisTask, +) +from env.database import create_database, seed_database + + +class TestMonthlySignupsGrader: + """Test the easy task grader""" + + def test_perfect_answer(self): + """Exact match should return 1.0""" + task = MonthlySignupsTask() + task.ground_truth = 100 + + score = task.grade("100") + assert score == 1.0 + + def test_partial_credit_within_3(self): + """Answer within 3 should return 0.6""" + task = MonthlySignupsTask() + task.ground_truth = 100 + + score = task.grade("98") + assert score == 0.6 + + def test_small_credit_within_10(self): + """Answer within 10 should return 0.3""" + task = MonthlySignupsTask() + task.ground_truth = 100 + + score = task.grade("92") + assert score == 0.3 + + def test_wrong_answer(self): + """Wrong answer should return 0.0""" + task = MonthlySignupsTask() + task.ground_truth = 100 + + score = task.grade("50") + assert score == 0.0 + + def test_comma_separated_number(self): + """Numbers with commas should work""" + task = MonthlySignupsTask() + task.ground_truth = 1000 + + score = task.grade("1,000") + assert score == 1.0 + + def test_invalid_input(self): + """Invalid input should return 0.0""" + task = MonthlySignupsTask() + task.ground_truth = 100 + + score = task.grade("not a number") + assert score == 0.0 + + +class TestTopRevenueCategoryGrader: + """Test the medium task grader""" + + def test_perfect_match(self): + """Exact category match should return 1.0""" + task = TopRevenueCategoryTask() + task.ground_truth = "Electronics" + task.top_3_categories = ["Electronics", 
"Books", "Clothing"] + + score = task.grade("Electronics") + assert score == 1.0 + + def test_partial_match_top_3(self): + """Answer in top 3 but not first should return 0.4""" + task = TopRevenueCategoryTask() + task.ground_truth = "Electronics" + task.top_3_categories = ["Electronics", "Books", "Clothing"] + + score = task.grade("Books") + assert score == 0.4 + + def test_case_insensitive(self): + """Should be case insensitive""" + task = TopRevenueCategoryTask() + task.ground_truth = "Electronics" + task.top_3_categories = ["Electronics", "Books", "Clothing"] + + score = task.grade("electronics") + assert score == 1.0 + + def test_llm_preamble_removed(self): + """LLM preamble should be stripped""" + task = TopRevenueCategoryTask() + task.ground_truth = "Electronics" + task.top_3_categories = ["Electronics", "Books", "Clothing"] + + score = task.grade("The answer is: Electronics") + assert score == 1.0 + + def test_wrong_category(self): + """Wrong category should return 0.0""" + task = TopRevenueCategoryTask() + task.ground_truth = "Electronics" + task.top_3_categories = ["Electronics", "Books", "Clothing"] + + score = task.grade("Sports") + assert score == 0.0 + + +class TestChurnAnalysisGrader: + """Test the hard task grader""" + + def test_perfect_match_all_emails(self): + """All correct emails should return 1.0""" + task = ChurnAnalysisTask() + task.ground_truth = {"a@test.com", "b@test.com", "c@test.com"} + + score = task.grade("a@test.com, b@test.com, c@test.com") + assert score == 1.0 + + def test_partial_match_precision_recall(self): + """Partial match should use F1 score""" + task = ChurnAnalysisTask() + task.ground_truth = {"a@test.com", "b@test.com", "c@test.com"} + + # 2 correct out of 3 submitted, 3 total correct + # precision = 2/3 = 0.667, recall = 2/3 = 0.667, f1 = 0.667 + score = task.grade("a@test.com, b@test.com, wrong@test.com") + assert abs(score - 0.667) < 0.01 + + def test_empty_submission(self): + """Empty submission should return 0.0""" + 
task = ChurnAnalysisTask() + task.ground_truth = {"a@test.com", "b@test.com"} + + score = task.grade("") + assert score == 0.0 + + def test_no_valid_emails(self): + """No valid emails should return 0.0""" + task = ChurnAnalysisTask() + task.ground_truth = {"a@test.com", "b@test.com"} + + score = task.grade("not an email, also not") + assert score == 0.0 + + def test_case_insensitive(self): + """Should be case insensitive""" + task = ChurnAnalysisTask() + task.ground_truth = {"A@Test.com", "B@Test.com"} + + score = task.grade("a@test.com, b@test.com") + assert score == 1.0 + + +class TestTaskIntegration: + """Test tasks with real database""" + + def test_monthly_signups_with_real_db(self): + """Test with seeded database""" + conn = create_database() + seed_database(conn) + + task = MonthlySignupsTask() + task.compute_ground_truth(conn) + + assert task.ground_truth is not None + assert isinstance(task.ground_truth, int) + + def test_top_revenue_with_real_db(self): + """Test with seeded database""" + conn = create_database() + seed_database(conn) + + task = TopRevenueCategoryTask() + task.compute_ground_truth(conn) + + assert task.ground_truth is not None + assert len(task.top_3_categories) == 3 + + def test_churn_analysis_with_real_db(self): + """Test with seeded database""" + conn = create_database() + seed_database(conn) + + task = ChurnAnalysisTask() + task.compute_ground_truth(conn) + + assert task.ground_truth is not None + assert isinstance(task.ground_truth, set) + + +class TestHintSystem: + """Test progressive hints""" + + def test_no_hints_early(self): + """No hints at step 5 or less""" + task = MonthlySignupsTask() + hints = task.get_hints(3) + assert len(hints) == 0 + + def test_first_hint_after_5(self): + """First hint after step 5""" + task = MonthlySignupsTask() + hints = task.get_hints(6) + assert len(hints) >= 1 + assert "relevant tables" in hints[0].lower() + + def test_second_hint_after_10(self): + """Second hint after step 10""" + task = 
MonthlySignupsTask() + hints = task.get_hints(11) + assert len(hints) >= 2 + + def test_third_hint_after_15(self): + """Third hint after step 15""" + task = MonthlySignupsTask() + hints = task.get_hints(16) + assert len(hints) >= 3 + assert "submit_answer" in hints[2].lower() diff --git a/temp_upload/tests/test_reward.py b/temp_upload/tests/test_reward.py new file mode 100644 index 0000000000000000000000000000000000000000..d4f25af034d322087c1cac4eebc025f762a74e01 --- /dev/null +++ b/temp_upload/tests/test_reward.py @@ -0,0 +1,136 @@ +import pytest +from env.reward import RewardCalculator +from env.models import Action, QueryResult +from env.tasks import MonthlySignupsTask + + +class MockTask: + """Mock task for testing reward calculator""" + + def __init__(self): + self.relevant_tables = ["users", "orders"] + self.ground_truth = 100 + self.difficulty = "easy" + self.max_steps = 10 + + def grade(self, answer): + return 1.0 if answer == str(self.ground_truth) else 0.0 + + def get_hints(self, step): + return [] + + +class TestRewardCalculator: + """Test the reward calculation logic""" + + def setup_method(self): + self.calc = RewardCalculator() + self.task = MockTask() + + def test_no_error_query_reward(self): + """Query without error gets +0.15""" + action = Action(sql_query="SELECT 1 FROM users") + result = QueryResult(columns=["1"], rows=[[1]], error=None) + + reward = self.calc.calculate(action, result, self.task, 1, [], False) + + assert reward >= 0.15 + + def test_relevant_table_reward(self): + """Query touching relevant table gets +0.10""" + action = Action(sql_query="SELECT * FROM users") + result = QueryResult(columns=["id"], rows=[[1]], error=None) + + reward = self.calc.calculate(action, result, self.task, 1, [], False) + + assert reward >= 0.10 + + def test_non_empty_result_reward(self): + """Query with rows gets +0.05""" + action = Action(sql_query="SELECT 1") + result = QueryResult(columns=["1"], rows=[[1]], error=None) + + reward = 
# temp_upload/tests/test_reward.py
#
# Byte-for-byte duplicate of tests/test_reward.py; temp_upload/ is the
# transient staging directory managed by upload_hf.py, so re-export the
# canonical tests rather than keeping a second hand-maintained copy.
#
# NOTE(review): assumes the repository root is on sys.path when pytest
# collects this file — confirm against the project's pytest configuration.
from tests.test_reward import *  # noqa: F401,F403
# NOTE(review): this span holds the canonical test suite (tests/test_env.py,
# tests/test_graders.py, tests/test_reward.py); file boundaries do not align
# with the mangled line grid, so all three files are reconstructed together,
# separated by "---" markers.  Fixes: unused loop variable in
# test_step_max_steps_terminates; unused imports dropped in test_graders.py
# (pytest, sqlite3, TASKS) and test_reward.py (pytest, MonthlySignupsTask).

# --- tests/test_env.py ---
import pytest
from env import SQLAnalystEnv, Action


class TestOpenEnvContract:
    """Test OpenEnv required methods: reset(), step(), state()"""

    def test_reset_returns_step_result(self):
        """reset() must return a StepResult with observation"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()

        assert result.observation is not None
        assert result.reward == 0.0
        assert result.done is False

    def test_reset_contains_required_fields(self):
        """Initial observation must contain all required fields"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        result = env.reset()
        obs = result.observation

        assert obs.schema_summary is not None
        assert obs.question is not None
        assert obs.step == 0
        assert obs.max_steps == 10

    def test_step_returns_step_result(self):
        """step() must return StepResult with observation, reward, done"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) FROM users")
        result = env.step(action)

        assert result.observation is not None
        assert isinstance(result.reward, float)
        assert isinstance(result.done, bool)
        assert result.info is not None

    def test_step_increments_step_count(self):
        """Each step should increment step count"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) FROM users")

        result1 = env.step(action)
        assert result1.observation.step == 1

        result2 = env.step(action)
        assert result2.observation.step == 2

    def test_step_sql_query_execution(self):
        """step() with sql_query should execute query and return result"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT COUNT(*) as cnt FROM users")
        result = env.step(action)

        assert result.observation.last_query is not None
        assert result.observation.last_result is not None
        assert "cnt" in result.observation.last_result.columns

    def test_step_submit_answer_terminates(self):
        """submit_answer should set done=True"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(submit_answer="100")
        result = env.step(action)

        assert result.done is True

    def test_step_max_steps_terminates(self):
        """Exceeding max_steps should terminate episode"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        # Easy task has max_steps=10; the loop index itself is unused.
        done_count = 0
        for _ in range(10):
            action = Action(sql_query="SELECT 1")
            result = env.step(action)
            if result.done:
                done_count += 1

        # At step 10 (last step), done should be True
        assert done_count >= 1, "Episode should terminate by step 10"

    def test_state_returns_full_state(self):
        """state() should return EnvState with all metadata"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        env.step(Action(sql_query="SELECT 1"))

        state = env.state()

        assert state.task_id == "monthly_signups"
        assert state.difficulty == "easy"
        assert state.step > 0
        assert state.max_steps == 10
        assert len(state.query_history) > 0

    def test_invalid_action_raises(self):
        """Action with both sql_query and submit_answer should error"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        # This should fail - exactly one must be set
        with pytest.raises(AssertionError):
            action = Action(sql_query="SELECT 1", submit_answer="test")
            env.step(action)

    def test_all_three_tasks_work(self):
        """All task IDs should be supported"""
        for task_id in ["monthly_signups", "top_revenue_category", "churn_analysis"]:
            env = SQLAnalystEnv(task_id=task_id)
            result = env.reset()
            assert result.observation.question is not None


class TestEdgeCases:
    """Test edge cases and error handling"""

    def test_sql_error_returns_error_in_observation(self):
        """Invalid SQL should return error in observation"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="SELECT * FROM nonexistent_table")
        result = env.step(action)

        assert result.observation.last_error is not None
        assert result.observation.last_error != ""

    def test_non_select_blocked(self):
        """Non-SELECT queries should be blocked"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        action = Action(sql_query="DELETE FROM users")
        result = env.step(action)

        assert result.observation.last_error is not None
        assert "Only SELECT" in result.observation.last_error

    def test_empty_db_after_reset(self):
        """Database should have data after reset"""
        env = SQLAnalystEnv(task_id="monthly_signups")
        env.reset()

        result = env.step(Action(sql_query="SELECT COUNT(*) FROM users"))
        assert result.observation.last_result.rows[0][0] > 0


# --- tests/test_graders.py ---
# (pytest, sqlite3 and TASKS were imported but never used; dropped.)
from env.tasks import (
    MonthlySignupsTask,
    TopRevenueCategoryTask,
    ChurnAnalysisTask,
)
from env.database import create_database, seed_database


class TestMonthlySignupsGrader:
    """Test the easy task grader"""

    def test_perfect_answer(self):
        """Exact match should return 1.0"""
        task = MonthlySignupsTask()
        task.ground_truth = 100

        score = task.grade("100")
        assert score == 1.0

    def test_partial_credit_within_3(self):
        """Answer within 3 should return 0.6"""
        task = MonthlySignupsTask()
        task.ground_truth = 100

        score = task.grade("98")
        assert score == 0.6

    def test_small_credit_within_10(self):
        """Answer within 10 should return 0.3"""
        task = MonthlySignupsTask()
        task.ground_truth = 100

        score = task.grade("92")
        assert score == 0.3

    def test_wrong_answer(self):
        """Wrong answer should return 0.0"""
        task = MonthlySignupsTask()
        task.ground_truth = 100

        score = task.grade("50")
        assert score == 0.0

    def test_comma_separated_number(self):
        """Numbers with commas should work"""
        task = MonthlySignupsTask()
        task.ground_truth = 1000

        score = task.grade("1,000")
        assert score == 1.0

    def test_invalid_input(self):
        """Invalid input should return 0.0"""
        task = MonthlySignupsTask()
        task.ground_truth = 100

        score = task.grade("not a number")
        assert score == 0.0


class TestTopRevenueCategoryGrader:
    """Test the medium task grader"""

    def test_perfect_match(self):
        """Exact category match should return 1.0"""
        task = TopRevenueCategoryTask()
        task.ground_truth = "Electronics"
        task.top_3_categories = ["Electronics", "Books", "Clothing"]

        score = task.grade("Electronics")
        assert score == 1.0

    def test_partial_match_top_3(self):
        """Answer in top 3 but not first should return 0.4"""
        task = TopRevenueCategoryTask()
        task.ground_truth = "Electronics"
        task.top_3_categories = ["Electronics", "Books", "Clothing"]

        score = task.grade("Books")
        assert score == 0.4

    def test_case_insensitive(self):
        """Should be case insensitive"""
        task = TopRevenueCategoryTask()
        task.ground_truth = "Electronics"
        task.top_3_categories = ["Electronics", "Books", "Clothing"]

        score = task.grade("electronics")
        assert score == 1.0

    def test_llm_preamble_removed(self):
        """LLM preamble should be stripped"""
        task = TopRevenueCategoryTask()
        task.ground_truth = "Electronics"
        task.top_3_categories = ["Electronics", "Books", "Clothing"]

        score = task.grade("The answer is: Electronics")
        assert score == 1.0

    def test_wrong_category(self):
        """Wrong category should return 0.0"""
        task = TopRevenueCategoryTask()
        task.ground_truth = "Electronics"
        task.top_3_categories = ["Electronics", "Books", "Clothing"]

        score = task.grade("Sports")
        assert score == 0.0


class TestChurnAnalysisGrader:
    """Test the hard task grader"""

    def test_perfect_match_all_emails(self):
        """All correct emails should return 1.0"""
        task = ChurnAnalysisTask()
        task.ground_truth = {"a@test.com", "b@test.com", "c@test.com"}

        score = task.grade("a@test.com, b@test.com, c@test.com")
        assert score == 1.0

    def test_partial_match_precision_recall(self):
        """Partial match should use F1 score"""
        task = ChurnAnalysisTask()
        task.ground_truth = {"a@test.com", "b@test.com", "c@test.com"}

        # 2 correct out of 3 submitted, 3 total correct
        # precision = 2/3 = 0.667, recall = 2/3 = 0.667, f1 = 0.667
        score = task.grade("a@test.com, b@test.com, wrong@test.com")
        assert abs(score - 0.667) < 0.01

    def test_empty_submission(self):
        """Empty submission should return 0.0"""
        task = ChurnAnalysisTask()
        task.ground_truth = {"a@test.com", "b@test.com"}

        score = task.grade("")
        assert score == 0.0

    def test_no_valid_emails(self):
        """No valid emails should return 0.0"""
        task = ChurnAnalysisTask()
        task.ground_truth = {"a@test.com", "b@test.com"}

        score = task.grade("not an email, also not")
        assert score == 0.0

    def test_case_insensitive(self):
        """Should be case insensitive"""
        task = ChurnAnalysisTask()
        task.ground_truth = {"A@Test.com", "B@Test.com"}

        score = task.grade("a@test.com, b@test.com")
        assert score == 1.0


class TestTaskIntegration:
    """Test tasks with real database"""

    def test_monthly_signups_with_real_db(self):
        """Test with seeded database"""
        conn = create_database()
        seed_database(conn)

        task = MonthlySignupsTask()
        task.compute_ground_truth(conn)

        assert task.ground_truth is not None
        assert isinstance(task.ground_truth, int)

    def test_top_revenue_with_real_db(self):
        """Test with seeded database"""
        conn = create_database()
        seed_database(conn)

        task = TopRevenueCategoryTask()
        task.compute_ground_truth(conn)

        assert task.ground_truth is not None
        assert len(task.top_3_categories) == 3

    def test_churn_analysis_with_real_db(self):
        """Test with seeded database"""
        conn = create_database()
        seed_database(conn)

        task = ChurnAnalysisTask()
        task.compute_ground_truth(conn)

        assert task.ground_truth is not None
        assert isinstance(task.ground_truth, set)


class TestHintSystem:
    """Test progressive hints"""

    def test_no_hints_early(self):
        """No hints at step 5 or less"""
        task = MonthlySignupsTask()
        hints = task.get_hints(3)
        assert len(hints) == 0

    def test_first_hint_after_5(self):
        """First hint after step 5"""
        task = MonthlySignupsTask()
        hints = task.get_hints(6)
        assert len(hints) >= 1
        assert "relevant tables" in hints[0].lower()

    def test_second_hint_after_10(self):
        """Second hint after step 10"""
        task = MonthlySignupsTask()
        hints = task.get_hints(11)
        assert len(hints) >= 2

    def test_third_hint_after_15(self):
        """Third hint after step 15"""
        task = MonthlySignupsTask()
        hints = task.get_hints(16)
        assert len(hints) >= 3
        assert "submit_answer" in hints[2].lower()


# --- tests/test_reward.py ---
# (pytest and MonthlySignupsTask were imported but never used; dropped.)
from env.reward import RewardCalculator
from env.models import Action, QueryResult


class MockTask:
    """Mock task for testing reward calculator"""

    def __init__(self):
        self.relevant_tables = ["users", "orders"]
        self.ground_truth = 100
        self.difficulty = "easy"
        self.max_steps = 10

    def grade(self, answer):
        return 1.0 if answer == str(self.ground_truth) else 0.0

    def get_hints(self, step):
        return []


class TestRewardCalculator:
    """Test the reward calculation logic"""

    def setup_method(self):
        self.calc = RewardCalculator()
        self.task = MockTask()

    def test_no_error_query_reward(self):
        """Query without error gets +0.15"""
        action = Action(sql_query="SELECT 1 FROM users")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward >= 0.15

    def test_relevant_table_reward(self):
        """Query touching relevant table gets +0.10"""
        action = Action(sql_query="SELECT * FROM users")
        result = QueryResult(columns=["id"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward >= 0.10

    def test_non_empty_result_reward(self):
        """Query with rows gets +0.05"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward >= 0.05

    def test_error_query_no_reward(self):
        """Query with error gets no step rewards"""
        action = Action(sql_query="SELECT * FROM nonexistent")
        result = QueryResult(columns=[], rows=[], error="Table not found")

        reward = self.calc.calculate(action, result, self.task, 1, [], False)

        assert reward == 0.0

    def test_efficiency_penalty_after_step_3(self):
        """Steps beyond 3 get -0.02 per step"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, self.task, 5, [], False)

        # 0.15 + 0.10 + 0.05 + 0.05 - (0.02 * 2) = 0.31
        assert reward < 0.35

    def test_infinite_loop_penalty(self):
        """Same query 3 times gets -0.10"""
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        query_history = ["SELECT 1", "SELECT 1", "SELECT 1"]
        reward = self.calc.calculate(action, result, self.task, 4, query_history, False)

        assert reward < 0.30

    def test_terminal_submit_grade_reward(self):
        """Terminal submit gets up to 0.60 based on grade"""
        action = Action(submit_answer="100")
        result = None

        # Use step 1 to avoid efficiency penalty
        reward = self.calc.calculate(action, result, self.task, 1, [], True)

        # grade(100) = 1.0 * 0.60 = 0.60
        assert reward >= 0.60

    def test_terminal_submit_wrong_answer(self):
        """Wrong answer gets partial terminal reward"""
        action = Action(submit_answer="999")
        result = None

        reward = self.calc.calculate(action, result, self.task, 5, [], True)

        # grade(999) = 0.0 * 0.60 = 0.0
        assert reward < 0.10

    def test_reward_clamped_to_0_1(self):
        """Reward should be clamped between 0 and 1"""
        # Create task that always grades 1.0
        task = MockTask()

        # Many steps should accumulate penalty but stay >= 0
        action = Action(sql_query="SELECT 1")
        result = QueryResult(columns=["1"], rows=[[1]], error=None)

        reward = self.calc.calculate(action, result, task, 50, [], False)

        assert 0.0 <= reward <= 1.0


class TestRewardBreakdown:
    """Test specific reward components"""

    def test_max_step_reward_calculation(self):
        """Test maximum possible reward at good query"""
        action = Action(sql_query="SELECT * FROM users")
        result = QueryResult(columns=["id"], rows=[[1], [2], [3]], error=None)

        calc = RewardCalculator()
        task = MockTask()

        reward = calc.calculate(action, result, task, 1, [], False)

        # 0.15 (no error) + 0.10 (relevant table) + 0.05 (has rows) + 0.05 (reasonable size)
        expected = 0.35
        assert abs(reward - expected) < 0.01
Action(sql_query="SELECT * FROM users") + result = QueryResult(columns=["id"], rows=[[1], [2], [3]], error=None) + + calc = RewardCalculator() + task = MockTask() + + reward = calc.calculate(action, result, task, 1, [], False) + + # 0.15 (no error) + 0.10 (relevant table) + 0.05 (has rows) + 0.05 (reasonable size) + expected = 0.35 + assert abs(reward - expected) < 0.01 diff --git a/upload_hf.py b/upload_hf.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4c1bf4419b2f4bdc6b5a5045fd559bdbeb18d9 --- /dev/null +++ b/upload_hf.py @@ -0,0 +1,58 @@ +from huggingface_hub import HfApi +import os +import shutil + +api = HfApi() + +temp_dir = "./temp_upload" +if os.path.exists(temp_dir): + shutil.rmtree(temp_dir) +os.makedirs(temp_dir) + +for root, dirs, files in os.walk("."): + skip_dirs = [ + ".git", + "__pycache__", + "hf_space", + ".cache", + "openenvhackathon", + ".pytest_cache", + "temp_upload", + "env\\__pycache__", + "tests\\__pycache__", + "baseline\\__pycache__", + "server\\__pycache__", + ] + if any(s in root for s in skip_dirs): + continue + + temp_root = root.replace(".\\", temp_dir + "\\") + os.makedirs(temp_root, exist_ok=True) + + for f in files: + if ( + not f.endswith(".pyc") + and not f.endswith(".lock") + and not f.startswith(".") + and not f.endswith(".metadata") + ): + src = os.path.join(root, f) + dst = os.path.join(temp_root, f) + try: + shutil.copy2(src, dst) + except Exception as e: + print(f"Skipped {src}: {e}") + +print("Prepared files") + +api.upload_folder( + folder_path=temp_dir, + repo_id="YashashMathur/sql_data_analyst", + repo_type="space", + commit_message="SQL Data Analyst OpenEnv - Initial commit", +) + +print("SUCCESS!") +print("https://huggingface.co/spaces/YashashMathur/sql_data_analyst") + +shutil.rmtree(temp_dir)