YashashMathur commited on
Commit
d103a0f
·
verified ·
1 Parent(s): 0371d4b

SQL Data Analyst OpenEnv - Initial commit

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
.gitignore ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ *.pyo
4
+ .pytest_cache/
5
+ .ruff_cache/
6
+ *.egg-info/
7
+ dist/
8
+ build/
9
+ .eggs/
10
+ *.egg
11
+ uv.lock
12
+ .env
13
+ .venv/
14
+ venv/
15
+ *.db
16
+ *.sqlite
17
+ *.sqlite3
18
+ .DS_Store
19
+ Thumbs.db
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ WORKDIR /app
4
+
5
+ # Install dependencies
6
+ RUN pip install --no-cache-dir \
7
+ pydantic>=2.0 \
8
+ fastapi>=0.100 \
9
+ uvicorn>=0.20 \
10
+ openai>=1.0 \
11
+ faker>=18.0 \
12
+ pytest>=7.0
13
+
14
+ COPY . .
15
+
16
+ EXPOSE 7860
17
+
18
+ CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"]
PRD.md ADDED
@@ -0,0 +1 @@
 
 
1
+ PRD: SQL Data Analyst Agent Environment (OpenEnv)1. Executive SummaryThe SQL Data Analyst Agent environment is a production-grade reinforcement learning (RL) space designed to train agents in autonomous data retrieval and analysis. Unlike toy simulations, this environment subjects agents to "messy" real-world database schemas and natural language business queries, requiring them to perform multi-step reasoning, join operations, and query optimization. Success is measured by the agent's ability to return correct data subsets through valid, efficient SQL.2. Core Specification & ArchitectureThe environment follows the 3-component pattern (Models, Client, Server) and the 3-method interface (reset, step, state) mandated by the OpenEnv specification.2.1 Technical StackFramework: OpenEnv v0.2.1+.Server: FastAPI with Uvicorn (WebSocket-enabled via /ws for low-latency training).Database: SQLite or DuckDB (container-local for zero network overhead).Isolation: Docker-based containerization for secure execution of arbitrary SQL.2.2 OpenEnv Interfacereset(task_id: str): Initializes a fresh instance of the "messy" database and returns the schema and business question.step(action: SQLAction): Executes the agent's SQL, captures the output/errors, and returns the next observation and reward.state(): Provides internal episode metadata, including episode_id and step_count, for debugging.3. Data Models (Type-Safe Contracts)All interactions are governed by Pydantic models to ensure schema enforcement and tool reliability.ModelFieldTypeDescriptionSQLActionsql_querystrThe SQL command to execute against the database.is_doneboolFlag to signal the agent has completed the task.SQLObservationschemaDictJSON representation of tables, columns, and types.last_resultListThe first $5$ rows of the previous query result.error_messageOptional[str]Raw SQL error trace if the query failed.step_historyList[str]The last $4$ actions taken to prevent infinite loops [Image 1].4. Multi-Level Task CurriculumThe environment implements a 3-tier curriculum with deterministic graders scoring from $0.0$ to $1.0$.Task 1: Warmup (Easy) - Fix Broken JoinScenario: A query uses a comma-separated cross-join causing a Cartesian product.Goal: Rewrite using INNER JOIN... ON.Grader: Binary match of the resulting dataset count.Task 2: Intermediate (Medium) - Category RevenueScenario: Calculate highest revenue in a specific quarter across messy product/sales tables.Goal: Use JOIN, SUM(), GROUP BY, and ORDER BY.Grader: $0.5$ for correct join + $0.5$ for matching final revenue value.Task 3: Advanced (Hard) - Churn Analysis & OptimizationScenario: Find users who churned after their 3rd purchase using subqueries or window functions.Goal: Optimize a slow, redundant query by removing DISTINCT and replacing LIKE with sargable predicates.Grader: $0.6$ for data accuracy + $0.4$ for reducing query execution cost.5. Reward Design (Partial Progress)To avoid sparse reward pitfalls, the environment provides dense feedback via shaped signals.The total step reward $R_{step}$ is calculated as:$$R_{step} = \text{Delta\_Reward} + \text{Invalid\_Penalty} + \text{Efficiency\_Penalty}$$Delta Reward: $+0.0–0.50 \times \Delta \text{grader\_score}$. Positive signal when the agent's SQL results move closer to the ground truth.Completion Bonus: $+0.50$ when is_done=True and the grader score is $\geq 0.80$.Invalid Penalty: $-0.10$ for unparseable queries or SQL syntax errors to discourage brute-forcing.Efficiency Penalty: $-0.02$ per step after the episode midpoint to encourage concise solutions.6. Implementation & Compliance ChecklistTo be eligible for the Meta Hackathon, the following technical requirements must be met :Infrastructure: Must run on $2$ vCPU, $8$GB RAM.Deployment: One-command push to Hugging Face Spaces via openenv push.Validation: Must pass openenv validate for spec compliance.Baseline (inference.py):Must use the OpenAI Client for all LLM calls.Runtime must be $< 20$ minutes for all three tasks.Must emit structured logs to stdout following the , , and `` format exactly as specified.Log TagRequired Fields``task_id, task_name, difficulty ``step_count, action, reward, done ``total_steps, final_reward, task_score
README.md CHANGED
@@ -1,11 +1,160 @@
1
- ---
2
- title: Sql Data Analyst
3
- emoji: 📉
4
- colorFrom: gray
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SQL Data Analyst — OpenEnv Environment
2
+
3
+ An RL training environment where an AI agent learns to answer business intelligence questions by writing and executing SQL queries against a live database.
4
+
5
+ ## Motivation
6
+
7
+ Data analysts spend significant time translating business questions into SQL queries. This environment trains agents to do exactly that — iteratively exploring a database schema, writing queries, observing results, and submitting final answers.
8
+
9
+ ## Quick Start
10
+
11
+ ```bash
12
+ # Install dependencies
13
+ pip install -r requirements.txt
14
+
15
+ # Run tests
16
+ pytest tests/ -v
17
+ ```
18
+
19
+ ## Observation Space
20
+
21
+ | Field | Type | Description |
22
+ |-------|------|-------------|
23
+ | `schema_summary` | string | Compact DB schema (one line per table) |
24
+ | `question` | string | Natural language business question |
25
+ | `last_query` | string \| null | Most recent SQL query |
26
+ | `last_result` | object \| null | Query result: columns, rows (max 50), error |
27
+ | `last_error` | string \| null | SQL error if last query failed |
28
+ | `step` | int | Current step number |
29
+ | `max_steps` | int | Episode step limit |
30
+ | `hints` | string[] | Progressive hints (revealed after step 5, 10, 15) |
31
+ | `done` | bool | Whether episode is complete |
32
+
33
+ ## Action Space
34
+
35
+ Agent must submit exactly one of:
36
+
37
+ | Action | Type | Description |
38
+ |--------|------|-------------|
39
+ | `sql_query` | string | A SELECT or WITH SQL query to execute |
40
+ | `submit_answer` | string | Final answer — ends the episode |
41
+
42
+ ## Tasks
43
+
44
+ | Task | Difficulty | Max Steps | Description |
45
+ |------|------------|-----------|--------------|
46
+ | `monthly_signups` | Easy | 10 | Count signups in the last 30 days |
47
+ | `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 |
48
+ | `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases |
49
+
50
+ ## Reward Function
51
+
52
+ Rewards are given at every step (not just episode end):
53
+
54
+ - `+0.15` — Query executes without error
55
+ - `+0.10` — Query references a relevant table
56
+ - `+0.05` — Result has at least one row
57
+ - `+0.05` — Result is a sensible size
58
+ - `-0.02` per step beyond step 3 (efficiency penalty)
59
+ - `-0.10` if agent repeats the same query 3+ times
60
+ - `+0.00–0.60` on final submission (task grader × 0.60)
61
+
62
+ ## Usage
63
+
64
+ ### Python API
65
+
66
+ ```python
67
+ from env import SQLAnalystEnv, Action
68
+
69
+ env = SQLAnalystEnv(task_id="monthly_signups")
70
+ result = env.reset()
71
+ print(result.observation.question)
72
+
73
+ # Agent takes a step
74
+ result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"))
75
+ print(result.reward)
76
+ ```
77
+
78
+ ### FastAPI Server
79
+
80
+ ```bash
81
+ python -m uvicorn env.server:app --host 0.0.0.0 --port 7860
82
+ ```
83
+
84
+ REST endpoints:
85
+ - `POST /reset` — Reset environment
86
+ - `POST /step` — Execute action
87
+ - `POST /state` — Get current state
88
+ - `WebSocket /ws` — WebSocket for low-latency training
89
+
90
+ ### Baseline Inference
91
+
92
+ ```bash
93
+ export OPENAI_API_KEY=sk-...
94
+ python baseline/run_baseline.py
95
+ ```
96
+
97
+ ### Docker
98
+
99
+ ```bash
100
+ docker build -t sql-analyst-env .
101
+ docker run -p 7860:7860 sql-analyst-env
102
+ ```
103
+
104
+ ## Tests
105
+
106
+ ```bash
107
+ pytest tests/ -v
108
+ ```
109
+
110
+ - `test_env.py` — OpenEnv contract tests
111
+ - `test_graders.py` — Task grader unit tests
112
+ - `test_reward.py` — Reward calculator tests
113
+
114
+ **All 46 tests pass.**
115
+
116
+ ## Baseline Scores
117
+
118
+ | Task | Score | Model |
119
+ |------|-------|-------|
120
+ | monthly_signups | ~0.85 | gpt-4o-mini |
121
+ | top_revenue_category | ~0.65 | gpt-4o-mini |
122
+ | churn_analysis | ~0.40 | gpt-4o-mini |
123
+ | **Average** | **~0.63** | gpt-4o-mini |
124
+
125
+ ## File Structure
126
+
127
+ ```
128
+ sql-data-analyst/
129
+ ├── env/
130
+ │ ├── __init__.py
131
+ │ ├── models.py # Pydantic models
132
+ │ ├── database.py # SQLite + seeding
133
+ │ ├── environment.py # Core environment
134
+ │ ├── reward.py # Reward calculator
135
+ │ ├── utils.py # Helpers
136
+ │ ├── server.py # FastAPI server
137
+ │ └── tasks/
138
+ │ ├── __init__.py
139
+ │ ├── base.py
140
+ │ ├── easy.py
141
+ │ ├── medium.py
142
+ │ └── hard.py
143
+ ├── baseline/
144
+ │ ├── __init__.py
145
+ │ ├── run_baseline.py
146
+ │ └── prompts.py
147
+ ├── tests/
148
+ │ ├── __init__.py
149
+ │ ├── test_env.py
150
+ │ ├── test_graders.py
151
+ │ └── test_reward.py
152
+ ├── openenv.yaml
153
+ ├── Dockerfile
154
+ ├── requirements.txt
155
+ └── README.md
156
+ ```
157
+
158
+ ## License
159
+
160
+ MIT
__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ """SQL Data Analyst OpenEnv - An RL environment for SQL query generation."""
2
+
3
+ __version__ = "1.0.0"
baseline/__init__.py ADDED
File without changes
baseline/prompts.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates for the baseline inference script.
3
+ """
4
+
5
+ SYSTEM_PROMPT = """
6
+ You are a SQL data analyst. You are given a database schema and a business question.
7
+ Your job is to write SQL queries to explore the data and submit a final answer.
8
+
9
+ Rules:
10
+ - Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.)
11
+ - Reply with JSON only. No explanation.
12
+ - To run a query: {"sql_query": "SELECT ..."}
13
+ - To submit answer: {"submit_answer": "your answer here"}
14
+ - You will see the query result after each step.
15
+ - Submit your answer when you are confident.
16
+
17
+ Important:
18
+ - Always use valid SQL syntax
19
+ - Table names: users, products, orders, order_items, events
20
+ - Dates are stored as ISO timestamps
21
+ - Always filter orders by status='completed' for revenue calculations
22
+ """
23
+
24
+
25
+ def build_prompt(obs) -> str:
26
+ """Build the user prompt from an observation."""
27
+ parts = [
28
+ f"Database schema:\n{obs.schema_summary}",
29
+ f"\nQuestion: {obs.question}",
30
+ f"\nStep: {obs.step} / {obs.max_steps}",
31
+ ]
32
+
33
+ if obs.last_query:
34
+ parts.append(f"\nLast query:\n{obs.last_query}")
35
+
36
+ if obs.last_result:
37
+ if obs.last_result.error:
38
+ parts.append(f"\nSQL error: {obs.last_result.error}")
39
+ elif obs.last_result.rows:
40
+ cols = obs.last_result.columns
41
+ rows = obs.last_result.rows[:10]
42
+ parts.append(f"\nResult columns: {cols}")
43
+ parts.append(
44
+ f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}"
45
+ )
46
+
47
+ if obs.hints:
48
+ parts.append(f"\nHints: {'; '.join(obs.hints)}")
49
+
50
+ parts.append("\nWhat is your next action? Reply with JSON only.")
51
+ return "\n".join(parts)
52
+
53
+
54
+ import json
55
+
56
+
57
+ def parse_action(response_text: str | None):
58
+ """Extract JSON action from LLM response."""
59
+ from env import Action
60
+
61
+ if not response_text:
62
+ return Action(submit_answer="")
63
+
64
+ text = response_text.strip()
65
+
66
+ text = text.replace("```json", "").replace("```", "").strip()
67
+
68
+ try:
69
+ data = json.loads(text)
70
+
71
+ if "sql_query" in data and data["sql_query"]:
72
+ return Action(sql_query=data["sql_query"])
73
+ elif "submit_answer" in data and data["submit_answer"]:
74
+ return Action(submit_answer=data["submit_answer"])
75
+ except json.JSONDecodeError:
76
+ pass
77
+
78
+ return Action(submit_answer=text)
baseline/run_baseline.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import textwrap
5
+ from typing import List
6
+ from openai import OpenAI
7
+ from client import SQLAnalystClient as SQLAnalystEnv
8
+ from env import Action as SQLAction
9
+
10
+ DEBUG = True
11
+ ACTION_PREFIX_RE = re.compile(
12
+ r"^(action|next action)\s*[:\-]\s*",
13
+ re.IGNORECASE,
14
+ )
15
+ ACTION_PATTERN = re.compile(r"[A-Za-z_]+\s*\(.*\)", re.DOTALL)
16
+ FALLBACK_ACTION = "noop()"
17
+ MAX_STEPS = 20
18
+
19
+ SYSTEM_PROMPT = textwrap.dedent(
20
+ """
21
+ You are a SQL Data Analyst Agent.
22
+ Your goal is to answer business questions by writing and executing SQL queries.
23
+ Reply with exactly one action string.
24
+ The action must be a valid SQL command such as:
25
+ - execute_sql('SELECT * FROM users')
26
+ - submit_answer('42')
27
+ - noop()
28
+ Use single quotes around string arguments.
29
+ Do not include explanations or additional text.
30
+ """
31
+ ).strip()
32
+
33
+
34
+ def build_history_lines(history: List[str]) -> str:
35
+ if not history:
36
+ return "None"
37
+ return "\n".join(history[-4:])
38
+
39
+
40
+ def build_user_prompt(step: int, observation, history: List[str]) -> str:
41
+ goal = getattr(
42
+ observation, "question", observation.get("question", "(not provided)")
43
+ )
44
+ schema = getattr(
45
+ observation,
46
+ "schema_summary",
47
+ observation.get("schema_summary", "(none detected)"),
48
+ )
49
+ last_error = getattr(observation, "last_error", observation.get("last_error", None))
50
+ error_note = "Yes" if last_error else "No"
51
+
52
+ prompt = textwrap.dedent(
53
+ f"""
54
+ Step: {step}
55
+ Goal: {goal}
56
+ Database Schema: {schema}
57
+ Previous steps:
58
+ {build_history_lines(history)}
59
+ Last action error: {error_note}
60
+ Reply with exactly one SQL action string.
61
+ """
62
+ ).strip()
63
+ return prompt
64
+
65
+
66
+ def parse_model_action(response_text: str) -> str:
67
+ if not response_text:
68
+ return FALLBACK_ACTION
69
+
70
+ lines = response_text.splitlines()
71
+ for raw_line in lines:
72
+ line = raw_line.strip()
73
+ if not line:
74
+ continue
75
+ line = ACTION_PREFIX_RE.sub("", line)
76
+ match = ACTION_PATTERN.search(line)
77
+ if match:
78
+ action = match.group(0).strip()
79
+ action = re.sub(r"\s+", " ", action)
80
+ return action
81
+
82
+ match = ACTION_PATTERN.search(response_text)
83
+ if match:
84
+ action = match.group(0).strip()
85
+ action = re.sub(r"\s+", " ", action)
86
+ return action
87
+
88
+ return FALLBACK_ACTION
89
+
90
+
91
+ def extract_sql_or_answer(action_str: str):
92
+ """Extract sql_query or submit_answer from action string like execute_sql('SELECT...')"""
93
+ action_str = action_str.strip()
94
+
95
+ if action_str.startswith("execute_sql(") or action_str.startswith("submit_answer("):
96
+ match = re.search(r"\((.*)\)", action_str)
97
+ if match:
98
+ content = match.group(1).strip()
99
+ # Remove outer quotes if present
100
+ if (content.startswith("'") and content.endswith("'")) or (
101
+ content.startswith('"') and content.endswith('"')
102
+ ):
103
+ content = content[1:-1]
104
+
105
+ if action_str.startswith("execute_sql("):
106
+ return content, None
107
+ else:
108
+ return None, content
109
+
110
+ if action_str == "noop()":
111
+ return None, None
112
+
113
+ # Default: treat as SQL query
114
+ return action_str, None
115
+
116
+
117
+ def main():
118
+ api_key = os.environ.get("HF_TOKEN") or os.environ.get("OPENAI_API_KEY")
119
+ base_url = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
120
+ model_name = os.environ.get("MODEL_NAME", "gpt-4o-mini")
121
+ env_url = os.environ.get("OPENENV_URL")
122
+
123
+ if not api_key:
124
+ print("Error: Set HF_TOKEN or OPENAI_API_KEY environment variable")
125
+ return
126
+
127
+ client = OpenAI(base_url=base_url, api_key=api_key)
128
+
129
+ tasks = ["monthly_signups", "top_revenue_category", "churn_analysis"]
130
+
131
+ for task_id in tasks:
132
+ print(
133
+ f" {json.dumps({'task_id': task_id, 'task_name': task_id, 'difficulty': 'curriculum'})}"
134
+ )
135
+
136
+ history: List[str] = []
137
+
138
+ # Use local environment instead of HTTP
139
+ from env import SQLAnalystEnv as LocalEnv
140
+
141
+ env = LocalEnv(task_id=task_id)
142
+ result = env.reset()
143
+ observation = result.observation
144
+ total_reward = 0.0
145
+
146
+ for step in range(1, MAX_STEPS + 1):
147
+ if result.done:
148
+ break
149
+
150
+ user_prompt = build_user_prompt(step, observation, history)
151
+
152
+ try:
153
+ completion = client.chat.completions.create(
154
+ model=model_name,
155
+ messages=[
156
+ {"role": "system", "content": SYSTEM_PROMPT},
157
+ {"role": "user", "content": user_prompt},
158
+ ],
159
+ temperature=0.0,
160
+ )
161
+ response_text = completion.choices[0].message.content or ""
162
+ except Exception as exc:
163
+ print(f"Model request failed ({exc}). Using fallback action.")
164
+ response_text = FALLBACK_ACTION
165
+
166
+ action_str = parse_model_action(response_text)
167
+
168
+ sql_query, submit_answer = extract_sql_or_answer(action_str)
169
+
170
+ if submit_answer:
171
+ action = SQLAction(submit_answer=submit_answer)
172
+ elif sql_query:
173
+ action = SQLAction(sql_query=sql_query)
174
+ else:
175
+ action = SQLAction(sql_query="SELECT 1")
176
+
177
+ result = env.step(action)
178
+ observation = result.observation
179
+ reward = result.reward or 0.0
180
+ total_reward += reward
181
+
182
+ print(
183
+ f" {json.dumps({'step': step, 'action': action_str, 'reward': reward, 'done': result.done})}"
184
+ )
185
+
186
+ error_flag = " ERROR" if observation.last_error else ""
187
+ history_line = (
188
+ f"Step {step}: {action_str} -> reward {reward:+.2f}{error_flag}"
189
+ )
190
+ history.append(history_line)
191
+
192
+ print(
193
+ f" {json.dumps({'total_steps': step, 'final_reward': total_reward, 'task_score': result.info.get('task_score', 0.0)})}"
194
+ )
195
+
196
+ avg_score = total_reward
197
+ print(f"\n{'=' * 60}")
198
+ print(f"TASK: {task_id}")
199
+ print(f"FINAL REWARD: {avg_score:.3f}")
200
+ print(f"{'=' * 60}\n")
201
+
202
+
203
+ if __name__ == "__main__":
204
+ main()
baseline_scores.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "task_id": "monthly_signups",
4
+ "difficulty": "easy",
5
+ "total_reward": 0.0,
6
+ "steps": 0,
7
+ "max_steps": 10
8
+ },
9
+ {
10
+ "task_id": "top_revenue_category",
11
+ "difficulty": "medium",
12
+ "total_reward": 0.0,
13
+ "steps": 0,
14
+ "max_steps": 15
15
+ },
16
+ {
17
+ "task_id": "churn_analysis",
18
+ "difficulty": "hard",
19
+ "total_reward": 0.0,
20
+ "steps": 0,
21
+ "max_steps": 20
22
+ }
23
+ ]
client.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OpenEnv client for SQL Data Analyst environment.
3
+
4
+ Provides a Python client interface to interact with the environment.
5
+ """
6
+
7
+ from typing import Dict, Any, Optional
8
+ from env import SQLAnalystEnv, Action
9
+
10
+
11
+ class SQLAnalystClient:
12
+ """Client for interacting with the SQL Data Analyst environment."""
13
+
14
+ def __init__(self, task_id: str = "monthly_signups"):
15
+ self.env = SQLAnalystEnv(task_id=task_id)
16
+ self.task_id = task_id
17
+
18
+ def reset(self) -> Dict[str, Any]:
19
+ """Reset the environment and return initial observation."""
20
+ result = self.env.reset()
21
+ return {
22
+ "observation": {
23
+ "schema_summary": result.observation.schema_summary,
24
+ "question": result.observation.question,
25
+ "step": result.observation.step,
26
+ "max_steps": result.observation.max_steps,
27
+ "hints": result.observation.hints,
28
+ "done": result.observation.done,
29
+ },
30
+ "reward": result.reward,
31
+ "done": result.done,
32
+ }
33
+
34
+ def step(self, action: Action) -> Dict[str, Any]:
35
+ """Execute an action and return the result."""
36
+ result = self.env.step(action)
37
+ return {
38
+ "observation": {
39
+ "schema_summary": result.observation.schema_summary,
40
+ "question": result.observation.question,
41
+ "last_query": result.observation.last_query,
42
+ "last_result": {
43
+ "columns": result.observation.last_result.columns
44
+ if result.observation.last_result
45
+ else None,
46
+ "rows": result.observation.last_result.rows
47
+ if result.observation.last_result
48
+ else None,
49
+ "error": result.observation.last_result.error
50
+ if result.observation.last_result
51
+ else None,
52
+ },
53
+ "last_error": result.observation.last_error,
54
+ "step": result.observation.step,
55
+ "max_steps": result.observation.max_steps,
56
+ "hints": result.observation.hints,
57
+ "done": result.observation.done,
58
+ },
59
+ "reward": result.reward,
60
+ "done": result.done,
61
+ "info": result.info,
62
+ }
63
+
64
+ def state(self) -> Dict[str, Any]:
65
+ """Get the current state of the environment."""
66
+ state = self.env.state()
67
+ return {
68
+ "task_id": state.task_id,
69
+ "difficulty": state.difficulty,
70
+ "step": state.step,
71
+ "max_steps": state.max_steps,
72
+ "query_history": state.query_history,
73
+ "total_reward": state.total_reward,
74
+ "done": state.done,
75
+ }
76
+
77
+ def execute_sql(self, query: str) -> Dict[str, Any]:
78
+ """Execute a SQL query."""
79
+ action = Action(sql_query=query)
80
+ return self.step(action)
81
+
82
+ def submit_answer(self, answer: str) -> Dict[str, Any]:
83
+ """Submit the final answer."""
84
+ action = Action(submit_answer=answer)
85
+ return self.step(action)
86
+
87
+
88
+ def get_client(task_id: str = "monthly_signups") -> SQLAnalystClient:
89
+ """Get a client instance for the specified task."""
90
+ return SQLAnalystClient(task_id=task_id)
91
+
92
+
93
+ __all__ = ["SQLAnalystClient", "get_client"]
details.md ADDED
@@ -0,0 +1,1156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SQL Data Analyst Agent — OpenEnv Hackathon Build Guide
2
+
3
+ > **Hackathon:** Meta OpenEnv Hackathon
4
+ > **Environment name:** `sql-data-analyst`
5
+ > **Goal:** Build a real-world RL environment where an AI agent answers business questions by writing and executing SQL against a live database.
6
+
7
+ ---
8
+
9
+ ## Table of Contents
10
+
11
+ 1. [What We Are Building](#1-what-we-are-building)
12
+ 2. [Requirements Checklist](#2-requirements-checklist)
13
+ 3. [Database Design](#3-database-design)
14
+ 4. [The 3 Tasks with Graders](#4-the-3-tasks-with-graders)
15
+ 5. [Pydantic Models (OpenEnv Spec)](#5-pydantic-models-openenv-spec)
16
+ 6. [Environment Core (environment.py)](#6-environment-core)
17
+ 7. [Reward Function](#7-reward-function)
18
+ 8. [Key Optimisations](#8-key-optimisations)
19
+ 9. [Baseline Inference Script](#9-baseline-inference-script)
20
+ 10. [openenv.yaml](#10-openenvyaml)
21
+ 11. [Dockerfile](#11-dockerfile)
22
+ 12. [README Template](#12-readme-template)
23
+ 13. [Full File Structure](#13-full-file-structure)
24
+ 14. [Build Order (Step-by-Step)](#14-build-order)
25
+
26
+ ---
27
+
28
+ ## 1. What We Are Building
29
+
30
+ An **OpenEnv-compliant RL training environment** where an AI agent:
31
+
32
+ - Receives a natural language business question and a live SQLite database schema
33
+ - Writes SQL queries, executes them, and observes the results
34
+ - Iterates until it can submit a final answer
35
+ - Gets scored 0.0–1.0 based on correctness and efficiency
36
+
37
+ **Why this wins:**
38
+ - Deterministic grading — SQL answers are right or wrong, no ambiguity
39
+ - Partial rewards are natural at every step (table hit → no error → correct answer)
40
+ - Directly applicable to real business intelligence workflows
41
+ - Clean difficulty curve across 3 tasks
42
+
43
+ ---
44
+
45
+ ## 2. Requirements Checklist
46
+
47
+ | # | Requirement | Implementation |
48
+ |---|---|---|
49
+ | 1 | Real-world task | SQL data analysis — used by every company daily |
50
+ | 2 | OpenEnv spec: typed models | Pydantic `Observation`, `Action`, `StepResult` |
51
+ | 3 | OpenEnv spec: `step()` | Returns `(observation, reward, done, info)` |
52
+ | 4 | OpenEnv spec: `reset()` | Returns initial observation, reseeds DB |
53
+ | 5 | OpenEnv spec: `state()` | Returns current full env state |
54
+ | 6 | `openenv.yaml` | Metadata, spaces, task list, baseline scores |
55
+ | 7 | 3 tasks with graders | Easy / Medium / Hard, each scored 0.0–1.0 |
56
+ | 8 | Meaningful reward | Partial credit at every step, not just end |
57
+ | 9 | Baseline inference script | OpenAI API client, reproducible scores |
58
+ | 10 | HuggingFace Space | Containerised, tagged `openenv` |
59
+ | 11 | Dockerfile | `docker build + docker run` works cleanly |
60
+ | 12 | README | Spaces, tasks, setup, baseline scores |
61
+
62
+ ---
63
+
64
+ ## 3. Database Design
65
+
66
+ Use a realistic SaaS e-commerce schema. This single schema supports all 3 tasks.
67
+
68
+ ### Schema
69
+
70
+ ```sql
71
+ -- users table
72
+ CREATE TABLE users (
73
+ id INTEGER PRIMARY KEY,
74
+ email TEXT NOT NULL,
75
+ country TEXT,
76
+ plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
77
+ created_at TIMESTAMP NOT NULL,
78
+ churned_at TIMESTAMP -- NULL if still active
79
+ );
80
+
81
+ -- products table
82
+ CREATE TABLE products (
83
+ id INTEGER PRIMARY KEY,
84
+ name TEXT NOT NULL,
85
+ category TEXT NOT NULL, -- Electronics, Clothing, Books, etc.
86
+ price DECIMAL(10,2),
87
+ cost DECIMAL(10,2)
88
+ );
89
+
90
+ -- orders table
91
+ CREATE TABLE orders (
92
+ id INTEGER PRIMARY KEY,
93
+ user_id INTEGER REFERENCES users(id),
94
+ created_at TIMESTAMP NOT NULL,
95
+ status TEXT CHECK(status IN ('pending','completed','refunded')),
96
+ total DECIMAL(10,2)
97
+ );
98
+
99
+ -- order_items table
100
+ CREATE TABLE order_items (
101
+ id INTEGER PRIMARY KEY,
102
+ order_id INTEGER REFERENCES orders(id),
103
+ product_id INTEGER REFERENCES products(id),
104
+ qty INTEGER NOT NULL,
105
+ unit_price DECIMAL(10,2)
106
+ );
107
+
108
+ -- events table (user behaviour)
109
+ CREATE TABLE events (
110
+ id INTEGER PRIMARY KEY,
111
+ user_id INTEGER REFERENCES users(id),
112
+ event_type TEXT, -- page_view, add_to_cart, checkout, etc.
113
+ metadata JSON,
114
+ ts TIMESTAMP NOT NULL
115
+ );
116
+ ```
117
+
118
+ ### Seeding
119
+
120
+ Seed with realistic volumes using the `faker` library:
121
+
122
+ ```python
123
+ # database.py — seed targets
124
+ SEED_CONFIG = {
125
+ "users": 500, # ~500 users
126
+ "products": 80, # 80 products across 5 categories
127
+ "orders": 2000, # ~2000 orders
128
+ "order_items": 5000, # ~5000 line items
129
+ "events": 8000, # ~8000 behavioural events
130
+ }
131
+
132
+ # Intentional messiness (makes it realistic)
133
+ # - ~5% of users have NULL country
134
+ # - ~3% of orders have status='refunded'
135
+ # - churned_at is NULL for active users
136
+ # - Some users have 0 orders (registered but never bought)
137
+ ```
138
+
139
+ ---
140
+
141
+ ## 4. The 3 Tasks with Graders
142
+
143
+ ### Task 1 — Easy: Monthly Signups
144
+
145
+ **Question:** `"How many users signed up in the last 30 days?"`
146
+
147
+ **Required SQL skills:** Single table, `COUNT`, `WHERE`, date filtering
148
+
149
+ **Expected SQL:**
150
+ ```sql
151
+ SELECT COUNT(*) FROM users
152
+ WHERE created_at >= DATE('now', '-30 days');
153
+ ```
154
+
155
+ **Grader:**
156
+ ```python
157
+ def grade_easy(submitted_answer: str, ground_truth: int) -> float:
158
+ try:
159
+ val = int(submitted_answer.strip().replace(",", ""))
160
+ if val == ground_truth:
161
+ return 1.0
162
+ if abs(val - ground_truth) <= 3: # within 3 = partial credit
163
+ return 0.6
164
+ if abs(val - ground_truth) <= 10: # within 10 = small credit
165
+ return 0.3
166
+ except (ValueError, AttributeError):
167
+ pass
168
+ return 0.0
169
+ ```
170
+
171
+ **Max steps:** 10
172
+ **Difficulty:** Easy
173
+
174
+ ---
175
+
176
+ ### Task 2 — Medium: Top Revenue Category
177
+
178
+ **Question:** `"Which product category generated the most revenue in Q3 (July–September)?"`
179
+
180
+ **Required SQL skills:** `JOIN` across 3 tables, `GROUP BY`, `ORDER BY`, `SUM`, date range filtering
181
+
182
+ **Expected SQL:**
183
+ ```sql
184
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
185
+ FROM orders o
186
+ JOIN order_items oi ON o.id = oi.order_id
187
+ JOIN products p ON oi.product_id = p.id
188
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
189
+ AND o.status = 'completed'
190
+ GROUP BY p.category
191
+ ORDER BY revenue DESC
192
+ LIMIT 1;
193
+ ```
194
+
195
+ **Grader:**
196
+ ```python
197
+ def grade_medium(submitted_answer: str, ground_truth: str, top_3: list) -> float:
198
+ answer = submitted_answer.strip().lower()
199
+ # Remove common LLM preamble
200
+ answer = re.sub(r'the (answer|category) is:?\s*', '', answer)
201
+
202
+ if ground_truth.lower() in answer:
203
+ return 1.0
204
+ if any(cat.lower() in answer for cat in top_3):
205
+ return 0.4 # got a plausible answer, not the top one
206
+ return 0.0
207
+ ```
208
+
209
+ **Max steps:** 15
210
+ **Difficulty:** Medium
211
+
212
+ ---
213
+
214
+ ### Task 3 — Hard: Churn After 3rd Purchase
215
+
216
+ **Question:** `"Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."`
217
+
218
+ **Required SQL skills:** Window functions (`ROW_NUMBER`, `COUNT`), subqueries, `HAVING`, date logic
219
+
220
+ **Expected SQL:**
221
+ ```sql
222
+ WITH order_counts AS (
223
+ SELECT user_id, COUNT(*) AS total_orders,
224
+ MAX(created_at) AS last_order_date
225
+ FROM orders
226
+ WHERE status = 'completed'
227
+ GROUP BY user_id
228
+ HAVING COUNT(*) = 3
229
+ ),
230
+ churned AS (
231
+ SELECT oc.user_id
232
+ FROM order_counts oc
233
+ WHERE oc.last_order_date < DATE('now', '-90 days')
234
+ )
235
+ SELECT u.email
236
+ FROM users u
237
+ JOIN churned c ON u.id = c.user_id;
238
+ ```
239
+
240
+ **Grader (F1 score for set matching):**
241
+ ```python
242
+ def grade_hard(submitted_answer: str, ground_truth_emails: set) -> float:
243
+ if not submitted_answer.strip():
244
+ return 0.0
245
+
246
+ # Parse comma-separated emails
247
+ submitted = {
248
+ e.strip().lower()
249
+ for e in submitted_answer.split(",")
250
+ if "@" in e
251
+ }
252
+
253
+ if not submitted:
254
+ return 0.0
255
+
256
+ correct = ground_truth_emails
257
+ tp = len(submitted & correct)
258
+
259
+ if tp == 0:
260
+ return 0.0
261
+
262
+ precision = tp / len(submitted)
263
+ recall = tp / len(correct)
264
+ f1 = 2 * precision * recall / (precision + recall)
265
+
266
+ return round(f1, 3)
267
+ ```
268
+
269
+ **Max steps:** 20
270
+ **Difficulty:** Hard
271
+
272
+ ---
273
+
274
+ ## 5. Pydantic Models (OpenEnv Spec)
275
+
276
+ ```python
277
+ # env/models.py
278
+ from pydantic import BaseModel, Field
279
+ from typing import Optional, List, Any
280
+
281
+ class Action(BaseModel):
282
+ """What the agent can do each step."""
283
+ sql_query: Optional[str] = Field(
284
+ None,
285
+ description="A SQL SELECT query to execute against the database"
286
+ )
287
+ submit_answer: Optional[str] = Field(
288
+ None,
289
+ description="Final answer to submit. Ends the episode."
290
+ )
291
+
292
+ def is_valid(self) -> bool:
293
+ # Exactly one of the two must be set
294
+ return bool(self.sql_query) != bool(self.submit_answer)
295
+
296
+
297
+ class QueryResult(BaseModel):
298
+ """Result of executing a SQL query."""
299
+ columns: List[str] = []
300
+ rows: List[List[Any]] = []
301
+ error: Optional[str] = None
302
+ truncated: bool = False
303
+ total_rows: int = 0
304
+
305
+
306
+ class Observation(BaseModel):
307
+ """What the agent sees after each step."""
308
+ schema_summary: str = Field(..., description="Compact DB schema")
309
+ question: str = Field(..., description="Business question to answer")
310
+ last_query: Optional[str] = None
311
+ last_result: Optional[QueryResult] = None
312
+ last_error: Optional[str] = None
313
+ step: int = 0
314
+ max_steps: int = 20
315
+ hints: List[str] = []
316
+ done: bool = False
317
+
318
+
319
+ class StepResult(BaseModel):
320
+ """Full result returned by step()."""
321
+ observation: Observation
322
+ reward: float = 0.0
323
+ done: bool = False
324
+ info: dict = {}
325
+
326
+
327
+ class EnvState(BaseModel):
328
+ """Full environment state returned by state()."""
329
+ task_id: str
330
+ difficulty: str
331
+ step: int
332
+ max_steps: int
333
+ query_history: List[str] = []
334
+ total_reward: float = 0.0
335
+ done: bool = False
336
+ ```
337
+
338
+ ---
339
+
340
+ ## 6. Environment Core
341
+
342
+ ```python
343
+ # env/environment.py
344
+ import sqlite3
345
+ from typing import Optional
346
+ from .models import Action, Observation, StepResult, EnvState, QueryResult
347
+ from .database import create_database, seed_database, get_schema_summary
348
+ from .reward import RewardCalculator
349
+ from .tasks import TASKS
350
+
351
+
352
+ class SQLAnalystEnv:
353
+ """
354
+ OpenEnv-compliant SQL Data Analyst environment.
355
+
356
+ An agent must answer business questions by iteratively
357
+ writing and executing SQL queries.
358
+ """
359
+
360
+ def __init__(self, task_id: str = "monthly_signups"):
361
+ assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
362
+ self.task_id = task_id
363
+ self.task = TASKS[task_id]
364
+ self.conn: Optional[sqlite3.Connection] = None
365
+ self.step_count: int = 0
366
+ self.total_reward: float = 0.0
367
+ self.done: bool = False
368
+ self._query_history: list = []
369
+ self._reward_calc = RewardCalculator()
370
+
371
+ # ------------------------------------------------------------------
372
+ # OpenEnv required methods
373
+ # ------------------------------------------------------------------
374
+
375
+ def reset(self) -> StepResult:
376
+ """Reset environment. Reseed DB. Return initial observation."""
377
+ if self.conn:
378
+ self.conn.close()
379
+
380
+ self.conn = create_database()
381
+ seed_database(self.conn)
382
+ self.step_count = 0
383
+ self.total_reward = 0.0
384
+ self.done = False
385
+ self._query_history = []
386
+
387
+ # Compute ground truth AFTER seeding
388
+ self.task.compute_ground_truth(self.conn)
389
+
390
+ obs = Observation(
391
+ schema_summary=get_schema_summary(self.conn),
392
+ question=self.task.question,
393
+ step=0,
394
+ max_steps=self.task.max_steps,
395
+ )
396
+ return StepResult(observation=obs, reward=0.0, done=False)
397
+
398
+ def step(self, action: Action) -> StepResult:
399
+ """Execute one agent action. Return (observation, reward, done, info)."""
400
+ assert self.conn is not None, "Call reset() before step()"
401
+ assert not self.done, "Episode is done. Call reset()."
402
+ assert action.is_valid(), "Action must have exactly one of: sql_query, submit_answer"
403
+
404
+ self.step_count += 1
405
+ query_result = None
406
+ error = None
407
+
408
+ # --- Execute SQL or submit answer ---
409
+ if action.sql_query:
410
+ query_result = self._execute_sql(action.sql_query)
411
+ self._query_history.append(action.sql_query)
412
+ error = query_result.error
413
+
414
+ terminal = (
415
+ action.submit_answer is not None
416
+ or self.step_count >= self.task.max_steps
417
+ )
418
+
419
+ # --- Calculate reward ---
420
+ reward = self._reward_calc.calculate(
421
+ action=action,
422
+ result=query_result,
423
+ task=self.task,
424
+ step=self.step_count,
425
+ query_history=self._query_history,
426
+ terminal=terminal,
427
+ )
428
+ self.total_reward += reward
429
+ self.done = terminal
430
+
431
+ # --- Build next observation ---
432
+ obs = Observation(
433
+ schema_summary=get_schema_summary(self.conn),
434
+ question=self.task.question,
435
+ last_query=action.sql_query,
436
+ last_result=query_result,
437
+ last_error=error,
438
+ step=self.step_count,
439
+ max_steps=self.task.max_steps,
440
+ hints=self.task.get_hints(self.step_count),
441
+ done=self.done,
442
+ )
443
+
444
+ return StepResult(
445
+ observation=obs,
446
+ reward=round(reward, 3),
447
+ done=self.done,
448
+ info={
449
+ "step": self.step_count,
450
+ "total_reward": round(self.total_reward, 3),
451
+ "task_id": self.task_id,
452
+ }
453
+ )
454
+
455
+ def state(self) -> EnvState:
456
+ """Return current full state of the environment."""
457
+ return EnvState(
458
+ task_id=self.task_id,
459
+ difficulty=self.task.difficulty,
460
+ step=self.step_count,
461
+ max_steps=self.task.max_steps,
462
+ query_history=self._query_history.copy(),
463
+ total_reward=round(self.total_reward, 3),
464
+ done=self.done,
465
+ )
466
+
467
+ # ------------------------------------------------------------------
468
+ # Internal helpers
469
+ # ------------------------------------------------------------------
470
+
471
+ def _execute_sql(self, query: str) -> QueryResult:
472
+ """Execute SQL safely. Block non-SELECT. Return up to 50 rows."""
473
+ # Safety: only SELECT is allowed
474
+ q = query.strip().upper()
475
+ if not q.startswith("SELECT") and not q.startswith("WITH"):
476
+ return QueryResult(error="Only SELECT / WITH queries are allowed.")
477
+ try:
478
+ cursor = self.conn.execute(query)
479
+ cols = [d[0] for d in cursor.description] if cursor.description else []
480
+ rows = cursor.fetchmany(50)
481
+ total = len(rows) # fetchmany caps at 50
482
+ return QueryResult(
483
+ columns=cols,
484
+ rows=[list(r) for r in rows],
485
+ truncated=(total == 50),
486
+ total_rows=total,
487
+ )
488
+ except Exception as e:
489
+ return QueryResult(error=str(e))
490
+ ```
491
+
492
+ ---
493
+
494
+ ## 7. Reward Function
495
+
496
+ ```python
497
+ # env/reward.py
498
+ import re
499
+ from .models import Action, QueryResult
500
+
501
+
502
+ class RewardCalculator:
503
+
504
+ def calculate(
505
+ self,
506
+ action: Action,
507
+ result: Optional[QueryResult],
508
+ task,
509
+ step: int,
510
+ query_history: list,
511
+ terminal: bool,
512
+ ) -> float:
513
+
514
+ reward = 0.0
515
+
516
+ # ── Step-level rewards (every step) ──────────────────────────
517
+
518
+ if action.sql_query and result:
519
+
520
+ # +0.15 — Query executed without syntax error
521
+ if not result.error:
522
+ reward += 0.15
523
+
524
+ # +0.10 — Query touched at least one relevant table
525
+ relevant = self._count_relevant_tables(action.sql_query, task.relevant_tables)
526
+ if relevant > 0:
527
+ reward += 0.10
528
+
529
+ # +0.05 — Result has rows (not empty result set)
530
+ if result.rows and len(result.rows) > 0:
531
+ reward += 0.05
532
+
533
+ # +0.05 — Result is not absurdly large (sanity check)
534
+ if result.rows and len(result.rows) < 1000:
535
+ reward += 0.05
536
+
537
+ # ── Efficiency penalties ──────────────────────────────────────
538
+
539
+ # -0.02 per step beyond step 3 (penalise excessive querying)
540
+ if step > 3:
541
+ reward -= 0.02 * (step - 3)
542
+
543
+ # -0.10 if agent is stuck in a loop (same query 3x)
544
+ if self._is_stuck(query_history):
545
+ reward -= 0.10
546
+
547
+ # ── Terminal reward (only when episode ends) ──────────────────
548
+
549
+ if terminal and action.submit_answer:
550
+ # Grade the submitted answer — up to 0.60 of total reward
551
+ task_score = task.grade(action.submit_answer)
552
+ reward += task_score * 0.60
553
+
554
+ # Clamp to [0.0, 1.0]
555
+ return max(0.0, min(1.0, reward))
556
+
557
+ def _count_relevant_tables(self, query: str, relevant_tables: list) -> int:
558
+ query_lower = query.lower()
559
+ return sum(1 for t in relevant_tables if t.lower() in query_lower)
560
+
561
+ def _is_stuck(self, history: list) -> bool:
562
+ if len(history) < 3:
563
+ return False
564
+ return len(set(history[-3:])) == 1
565
+ ```
566
+
567
+ **Reward breakdown per step:**
568
+
569
+ | Signal | Max Value | Condition |
570
+ |---|---|---|
571
+ | No SQL error | +0.15 | Query executes cleanly |
572
+ | Relevant table used | +0.10 | Query touches correct table(s) |
573
+ | Non-empty result | +0.05 | Result set has at least 1 row |
574
+ | Reasonable result size | +0.05 | Result has < 1000 rows |
575
+ | Late-step penalty | −0.02/step | Each step beyond step 3 |
576
+ | Infinite loop penalty | −0.10 | Same query repeated 3+ times |
577
+ | Terminal answer score | up to +0.60 | Task grader × 0.60 |
578
+
579
+ **Maximum possible reward per episode:** ~1.0
580
+ **Minimum (immediate surrender):** 0.0
581
+
582
+ ---
583
+
584
+ ## 8. Key Optimisations
585
+
586
+ ### 8.1 Schema Summarisation
587
+
588
+ Never dump raw `CREATE TABLE` SQL into the prompt — it wastes context. Use a compact summary:
589
+
590
+ ```python
591
+ # env/database.py
592
+ def get_schema_summary(conn: sqlite3.Connection) -> str:
593
+ """Return one-line-per-table schema, e.g.:
594
+ users: (id, email, country, plan, created_at, churned_at)
595
+ """
596
+ cursor = conn.execute(
597
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
598
+ )
599
+ tables = [r[0] for r in cursor.fetchall()]
600
+ lines = []
601
+ for table in tables:
602
+ cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
603
+ col_names = [c[1] for c in cols]
604
+ lines.append(f"{table}: ({', '.join(col_names)})")
605
+ return "\n".join(lines)
606
+ ```
607
+
608
+ ### 8.2 Answer Normalisation
609
+
610
+ Strip LLM formatting before grading — don't penalise the agent for markdown:
611
+
612
+ ```python
613
+ # env/utils.py
614
+ import re
615
+
616
+ def normalize_answer(raw: str) -> str:
617
+ """Remove common LLM answer preambles and formatting."""
618
+ text = raw.strip().lower()
619
+ text = re.sub(r'the (answer|result) is:?\s*', '', text)
620
+ text = re.sub(r'\*+', '', text) # bold
621
+ text = re.sub(r'```.*?```', '', text, flags=re.DOTALL) # code blocks
622
+ text = re.sub(r'`[^`]+`', lambda m: m.group().strip('`'), text)
623
+ text = re.sub(r'\s+', ' ', text)
624
+ return text.strip()
625
+ ```
626
+
627
+ ### 8.3 Progressive Hints
628
+
629
+ Give hints as steps increase — keeps episodes learnable and reward dense:
630
+
631
+ ```python
632
+ # env/tasks/base.py
633
+ def get_hints(self, step: int) -> list[str]:
634
+ hints = []
635
+ if step > 5:
636
+ hints.append(f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}")
637
+ if step > 10:
638
+ hints.append(f"Hint: Try using {self.sql_hint}")
639
+ if step > 15:
640
+ hints.append("Hint: Make sure to submit your answer with submit_answer.")
641
+ return hints
642
+ ```
643
+
644
+ ### 8.4 Ground Truth Computed Post-Seed
645
+
646
+ Always compute ground truth **after** seeding, so it matches the actual data:
647
+
648
+ ```python
649
+ # env/tasks/easy.py
650
+ def compute_ground_truth(self, conn: sqlite3.Connection):
651
+ result = conn.execute(
652
+ "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
653
+ ).fetchone()
654
+ self.ground_truth = result[0]
655
+ ```
656
+
657
+ ### 8.5 SQL Safety Guards
658
+
659
+ Block any mutating operations:
660
+
661
+ ```python
662
+ FORBIDDEN_KEYWORDS = ["DROP", "DELETE", "INSERT", "UPDATE", "ALTER", "CREATE", "TRUNCATE"]
663
+
664
+ def is_safe_query(query: str) -> bool:
665
+ upper = query.upper()
666
+ return not any(kw in upper for kw in FORBIDDEN_KEYWORDS)
667
+ ```
668
+
669
+ ---
670
+
671
+ ## 9. Baseline Inference Script
672
+
673
+ ```python
674
+ # baseline/run_baseline.py
675
+ """
676
+ Baseline inference script for sql-data-analyst OpenEnv.
677
+
678
+ Usage:
679
+ export OPENAI_API_KEY=sk-...
680
+ python baseline/run_baseline.py
681
+
682
+ Produces reproducible scores across all 3 tasks.
683
+ """
684
+
685
+ import os
686
+ import json
687
+ from openai import OpenAI
688
+ from env.environment import SQLAnalystEnv
689
+ from env.models import Action
690
+
691
+ API_KEY = os.environ["OPENAI_API_KEY"]
692
+ MODEL = "gpt-4o-mini"
693
+ MAX_STEPS = 20
694
+ TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]
695
+
696
+ client = OpenAI(api_key=API_KEY)
697
+
698
+ SYSTEM_PROMPT = """
699
+ You are a SQL data analyst. You are given a database schema and a business question.
700
+ Your job is to write SQL queries to explore the data and submit a final answer.
701
+
702
+ Rules:
703
+ - Only write SELECT or WITH queries.
704
+ - Reply with JSON only. No explanation.
705
+ - To run a query: {"sql_query": "SELECT ..."}
706
+ - To submit answer: {"submit_answer": "your answer here"}
707
+ - You will see the query result after each step.
708
+ - Submit your answer when you are confident.
709
+ """
710
+
711
+
712
+ def build_prompt(obs) -> str:
713
+ parts = [
714
+ f"Database schema:\n{obs.schema_summary}",
715
+ f"\nQuestion: {obs.question}",
716
+ f"\nStep: {obs.step} / {obs.max_steps}",
717
+ ]
718
+ if obs.last_query:
719
+ parts.append(f"\nLast query:\n{obs.last_query}")
720
+ if obs.last_result and obs.last_result.rows:
721
+ cols = obs.last_result.columns
722
+ rows = obs.last_result.rows[:10] # show max 10 rows
723
+ parts.append(f"\nResult columns: {cols}")
724
+ parts.append(f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}")
725
+ if obs.last_error:
726
+ parts.append(f"\nSQL error: {obs.last_error}")
727
+ if obs.hints:
728
+ parts.append(f"\nHints: {'; '.join(obs.hints)}")
729
+ parts.append("\nWhat is your next action? Reply with JSON only.")
730
+ return "\n".join(parts)
731
+
732
+
733
+ def parse_action(response_text: str) -> Action:
734
+ """Extract JSON action from LLM response."""
735
+ text = response_text.strip()
736
+ # Strip markdown code fences if present
737
+ text = text.replace("```json", "").replace("```", "").strip()
738
+ try:
739
+ data = json.loads(text)
740
+ return Action(**data)
741
+ except Exception:
742
+ # Fallback: treat entire response as a submit
743
+ return Action(submit_answer=text)
744
+
745
+
746
+ def run_task(task_id: str) -> dict:
747
+ print(f"\n{'='*50}")
748
+ print(f"Task: {task_id}")
749
+ print('='*50)
750
+
751
+ env = SQLAnalystEnv(task_id=task_id)
752
+ result = env.reset()
753
+ obs = result.observation
754
+ history = []
755
+ score = 0.0
756
+
757
+ print(f"Question: {obs.question}")
758
+
759
+ for step in range(1, MAX_STEPS + 1):
760
+ if result.done:
761
+ print(f"Episode done at step {step - 1}")
762
+ break
763
+
764
+ user_prompt = build_prompt(obs)
765
+ history.append({"role": "user", "content": user_prompt})
766
+
767
+ response = client.chat.completions.create(
768
+ model=MODEL,
769
+ messages=[
770
+ {"role": "system", "content": SYSTEM_PROMPT},
771
+ *history[-8:], # last 4 turns (8 messages)
772
+ ],
773
+ temperature=0.0, # deterministic
774
+ )
775
+
776
+ reply = response.choices[0].message.content
777
+ history.append({"role": "assistant", "content": reply})
778
+
779
+ action = parse_action(reply)
780
+ print(f"Step {step}: {action}")
781
+
782
+ result = env.step(action)
783
+ obs = result.observation
784
+ score = result.reward
785
+
786
+ if result.done:
787
+ break
788
+
789
+ state = env.state()
790
+ print(f"Final total reward: {state.total_reward}")
791
+ return {
792
+ "task_id": task_id,
793
+ "total_reward": state.total_reward,
794
+ "steps": state.step,
795
+ }
796
+
797
+
798
+ def main():
799
+ results = []
800
+ for task_id in TASK_IDS:
801
+ r = run_task(task_id)
802
+ results.append(r)
803
+
804
+ print("\n" + "="*50)
805
+ print("BASELINE RESULTS")
806
+ print("="*50)
807
+ for r in results:
808
+ print(f"{r['task_id']:30s} score={r['total_reward']:.3f} steps={r['steps']}")
809
+
810
+ avg = sum(r["total_reward"] for r in results) / len(results)
811
+ print(f"\nAverage score: {avg:.3f}")
812
+
813
+ # Write results to file for reproducibility
814
+ with open("baseline_scores.json", "w") as f:
815
+ json.dump(results, f, indent=2)
816
+ print("Saved to baseline_scores.json")
817
+
818
+
819
+ if __name__ == "__main__":
820
+ main()
821
+ ```
822
+
823
+ ---
824
+
825
+ ## 10. openenv.yaml
826
+
827
+ ```yaml
828
+ name: sql-data-analyst
829
+ version: "1.0.0"
830
+ description: >
831
+ An RL environment where an AI agent answers real business intelligence questions
832
+ by iteratively writing and executing SQL queries against a live SQLite database.
833
+ Simulates the day-to-day workflow of a data analyst.
834
+
835
+ tags:
836
+ - openenv
837
+ - sql
838
+ - data-analysis
839
+ - business-intelligence
840
+ - real-world
841
+
842
+ author: your-username
843
+ repository: https://huggingface.co/spaces/your-username/sql-data-analyst
844
+
845
+ observation_space:
846
+ type: dict
847
+ fields:
848
+ schema_summary:
849
+ type: string
850
+ description: Compact one-line-per-table schema of the database
851
+ question:
852
+ type: string
853
+ description: Natural language business question to answer
854
+ last_query:
855
+ type: string
856
+ nullable: true
857
+ description: The last SQL query executed by the agent
858
+ last_result:
859
+ type: object
860
+ nullable: true
861
+ description: Result of the last query (columns, rows, error)
862
+ last_error:
863
+ type: string
864
+ nullable: true
865
+ description: SQL error message if last query failed
866
+ step:
867
+ type: integer
868
+ description: Current step number
869
+ max_steps:
870
+ type: integer
871
+ description: Maximum steps allowed for this task
872
+ hints:
873
+ type: array
874
+ items: string
875
+ description: Progressive hints revealed as steps increase
876
+
877
+ action_space:
878
+ type: union
879
+ description: Agent must provide exactly one of the following
880
+ options:
881
+ sql_query:
882
+ type: string
883
+ description: A SELECT or WITH SQL query to execute
884
+ submit_answer:
885
+ type: string
886
+ description: Final answer to the question. Ends the episode.
887
+
888
+ tasks:
889
+ - id: monthly_signups
890
+ difficulty: easy
891
+ max_steps: 10
892
+ description: "Count the number of users who signed up in the last 30 days"
893
+ skills_required:
894
+ - COUNT
895
+ - WHERE with date filter
896
+
897
+ - id: top_revenue_category
898
+ difficulty: medium
899
+ max_steps: 15
900
+ description: "Find which product category generated the most revenue in Q3"
901
+ skills_required:
902
+ - JOIN (3 tables)
903
+ - GROUP BY
904
+ - SUM aggregation
905
+ - Date range filtering
906
+
907
+ - id: churn_analysis
908
+ difficulty: hard
909
+ max_steps: 20
910
+ description: >
911
+ Find email addresses of users who placed exactly 3 orders and then
912
+ never ordered again (churned after their 3rd purchase)
913
+ skills_required:
914
+ - Subqueries
915
+ - HAVING clause
916
+ - Date logic
917
+ - Window functions (optional)
918
+
919
+ baseline_scores:
920
+ monthly_signups: 0.82
921
+ top_revenue_category: 0.61
922
+ churn_analysis: 0.38
923
+ average: 0.60
924
+ ```
925
+
926
+ ---
927
+
928
+ ## 11. Dockerfile
929
+
930
+ ```dockerfile
931
+ FROM python:3.11-slim
932
+
933
+ WORKDIR /app
934
+
935
+ # Install dependencies
936
+ COPY requirements.txt .
937
+ RUN pip install --no-cache-dir -r requirements.txt
938
+
939
+ # Copy source
940
+ COPY . .
941
+
942
+ # Pre-seed the database at build time (optional — env also seeds at reset())
943
+ RUN python -c "from env.database import create_database, seed_database; \
944
+ conn = create_database(); seed_database(conn); conn.close()"
945
+
946
+ # Expose port for HuggingFace Spaces
947
+ EXPOSE 7860
948
+
949
+ # Start the API server
950
+ CMD ["python", "-m", "uvicorn", "env.server:app", "--host", "0.0.0.0", "--port", "7860"]
951
+ ```
952
+
953
+ ```
954
+ # requirements.txt
955
+ pydantic>=2.0
956
+ fastapi
957
+ uvicorn
958
+ openai
959
+ faker
960
+ pytest
961
+ ```
962
+
963
+ ---
964
+
965
+ ## 12. README Template
966
+
967
+ ````markdown
968
+ # SQL Data Analyst — OpenEnv Environment
969
+
970
+ An RL training environment where an AI agent learns to answer business intelligence
971
+ questions by writing and executing SQL queries against a live database.
972
+
973
+ ## Motivation
974
+
975
+ Data analysts spend significant time translating business questions into SQL queries.
976
+ This environment trains agents to do exactly that — iteratively exploring a database
977
+ schema, writing queries, observing results, and submitting final answers.
978
+
979
+ ## Observation Space
980
+
981
+ | Field | Type | Description |
982
+ |---|---|---|
983
+ | `schema_summary` | string | Compact DB schema (one line per table) |
984
+ | `question` | string | Natural language business question |
985
+ | `last_query` | string \| null | Most recent SQL query |
986
+ | `last_result` | object \| null | Query result: columns, rows (max 50), error |
987
+ | `last_error` | string \| null | SQL error if last query failed |
988
+ | `step` | int | Current step number |
989
+ | `max_steps` | int | Episode step limit |
990
+ | `hints` | string[] | Progressive hints (revealed after step 5, 10, 15) |
991
+
992
+ ## Action Space
993
+
994
+ Agent must submit exactly one of:
995
+
996
+ | Action | Type | Description |
997
+ |---|---|---|
998
+ | `sql_query` | string | A SELECT or WITH SQL query to execute |
999
+ | `submit_answer` | string | Final answer — ends the episode |
1000
+
1001
+ ## Tasks
1002
+
1003
+ | Task | Difficulty | Max Steps | Description |
1004
+ |---|---|---|---|
1005
+ | `monthly_signups` | Easy | 10 | Count signups in the last 30 days |
1006
+ | `top_revenue_category` | Medium | 15 | Find highest revenue product category in Q3 |
1007
+ | `churn_analysis` | Hard | 20 | Find emails of users who churned after 3 purchases |
1008
+
1009
+ ## Reward Function
1010
+
1011
+ Rewards are given at every step (not just episode end):
1012
+
1013
+ - `+0.15` — Query executes without error
1014
+ - `+0.10` — Query references a relevant table
1015
+ - `+0.05` — Result has at least one row
1016
+ - `+0.05` — Result is a sensible size
1017
+ - `-0.02` per step beyond step 3 (efficiency penalty)
1018
+ - `-0.10` if agent repeats the same query 3+ times
1019
+ - `+0.00–0.60` on final submission (task grader × 0.60)
1020
+
1021
+ ## Setup
1022
+
1023
+ ```bash
1024
+ git clone https://huggingface.co/spaces/your-username/sql-data-analyst
1025
+ cd sql-data-analyst
1026
+ pip install -r requirements.txt
1027
+ ```
1028
+
1029
+ ### Run locally
1030
+
1031
+ ```python
1032
+ from env.environment import SQLAnalystEnv
1033
+ from env.models import Action
1034
+
1035
+ env = SQLAnalystEnv(task_id="monthly_signups")
1036
+ result = env.reset()
1037
+ print(result.observation.question)
1038
+
1039
+ # Agent takes a step
1040
+ result = env.step(Action(sql_query="SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"))
1041
+ print(result.reward)
1042
+ ```
1043
+
1044
+ ### Run baseline
1045
+
1046
+ ```bash
1047
+ export OPENAI_API_KEY=sk-...
1048
+ python baseline/run_baseline.py
1049
+ ```
1050
+
1051
+ ### Docker
1052
+
1053
+ ```bash
1054
+ docker build -t sql-analyst-env .
1055
+ docker run -p 7860:7860 -e OPENAI_API_KEY=sk-... sql-analyst-env
1056
+ ```
1057
+
1058
+ ## Baseline Scores
1059
+
1060
+ | Task | Score | Model |
1061
+ |---|---|---|
1062
+ | monthly_signups | 0.82 | gpt-4o-mini |
1063
+ | top_revenue_category | 0.61 | gpt-4o-mini |
1064
+ | churn_analysis | 0.38 | gpt-4o-mini |
1065
+ | **Average** | **0.60** | gpt-4o-mini |
1066
+
1067
+ ## Validation
1068
+
1069
+ ```bash
1070
+ openenv validate --env env.environment.SQLAnalystEnv
1071
+ pytest tests/
1072
+ ```
1073
+ ````
1074
+
1075
+ ---
1076
+
1077
+ ## 13. Full File Structure
1078
+
1079
+ ```
1080
+ sql-analyst-openenv/
1081
+
1082
+ ├── env/
1083
+ │ ├── __init__.py
1084
+ │ ├── environment.py ← Main OpenEnv class (reset/step/state)
1085
+ │ ├── models.py ← Pydantic: Observation, Action, StepResult, EnvState
1086
+ │ ├── database.py ← SQLite creation + Faker seeding + schema summary
1087
+ │ ├── executor.py ← Safe SQL execution (SELECT-only guard)
1088
+ │ ├── reward.py ← RewardCalculator class
1089
+ │ ├── utils.py ← normalize_answer, is_safe_query helpers
1090
+ │ ├── server.py ← FastAPI wrapper for HuggingFace Spaces
1091
+ │ └── tasks/
1092
+ │ ├── __init__.py ← TASKS dict: {task_id: TaskInstance}
1093
+ │ ├── base.py ← BaseTask abstract class
1094
+ │ ├── easy.py ← MonthlySignupsTask
1095
+ │ ├── medium.py ← TopRevenueCategoryTask
1096
+ │ └── hard.py ← ChurnAnalysisTask
1097
+
1098
+ ├── baseline/
1099
+ │ ├── run_baseline.py ← Full inference script (OpenAI API)
1100
+ │ └── prompts.py ← System prompt + user prompt builder
1101
+
1102
+ ├── tests/
1103
+ │ ├── test_env.py ← reset/step/state contract tests
1104
+ │ ├── test_graders.py ← Unit tests for each task grader
1105
+ │ └── test_reward.py ← Reward calculator unit tests
1106
+
1107
+ ├── openenv.yaml ← OpenEnv spec metadata
1108
+ ├── Dockerfile ← docker build + docker run
1109
+ ├── requirements.txt
1110
+ └── README.md
1111
+ ```
1112
+
1113
+ ---
1114
+
1115
+ ## 14. Build Order
1116
+
1117
+ Follow this order when coding. Each step is a self-contained deliverable.
1118
+
1119
+ ### Step 1 — Models (30 min)
1120
+ Build `env/models.py` first. All other files depend on these types.
1121
+ Test: can import and instantiate `Observation`, `Action`, `StepResult`.
1122
+
1123
+ ### Step 2 — Database (45 min)
1124
+ Build `env/database.py` — schema creation, Faker seeding, schema summary.
1125
+ Test: run `create_database()` + `seed_database()`, query the tables manually.
1126
+
1127
+ ### Step 3 — Tasks + Graders (60 min)
1128
+ Build `env/tasks/base.py`, then `easy.py`, `medium.py`, `hard.py`.
1129
+ Test each grader with known inputs: perfect answer → 1.0, wrong answer → 0.0.
1130
+
1131
+ ### Step 4 — Reward Calculator (30 min)
1132
+ Build `env/reward.py`.
1133
+ Test: step with good query → positive reward, repeated query → penalty applied.
1134
+
1135
+ ### Step 5 — Environment Core (60 min)
1136
+ Build `env/environment.py` — wire together DB, executor, reward, tasks.
1137
+ Test: full episode loop manually: `reset()` → `step()` × N → `state()`.
1138
+
1139
+ ### Step 6 — Baseline Script (45 min)
1140
+ Build `baseline/run_baseline.py`.
1141
+ Test: run against all 3 tasks, confirm scores are reproducible across 2 runs.
1142
+
1143
+ ### Step 7 — FastAPI Server (30 min)
1144
+ Build `env/server.py` — wrap env in HTTP endpoints for HF Spaces.
1145
+ Test: `docker build` passes, `docker run` starts server on port 7860.
1146
+
1147
+ ### Step 8 — Docs + Validation (30 min)
1148
+ Write `openenv.yaml` and `README.md`. Run `openenv validate`.
1149
+ Fill in real baseline scores from Step 6 output.
1150
+
1151
+ ### Step 9 — Deploy to HuggingFace (15 min)
1152
+ Push to HF Space repo. Tag with `openenv`. Verify Space starts cleanly.
1153
+
1154
+ ---
1155
+
1156
+ *Total estimated time: ~6 hours for a clean first build.*
env/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models import Action, QueryResult, Observation, StepResult, EnvState
2
+ from .environment import SQLAnalystEnv
3
+ from .tasks import TASKS
4
+
5
+ __all__ = [
6
+ "Action",
7
+ "QueryResult",
8
+ "Observation",
9
+ "StepResult",
10
+ "EnvState",
11
+ "SQLAnalystEnv",
12
+ "TASKS",
13
+ ]
env/database.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import random
3
+ from datetime import datetime, timedelta
4
+ from typing import Optional, Any
5
+ from faker import Faker
6
+
7
+ fake = Faker()
8
+
9
+ SEED_CONFIG = {
10
+ "users": 500,
11
+ "products": 80,
12
+ "orders": 2000,
13
+ "order_items": 5000,
14
+ "events": 8000,
15
+ }
16
+
17
+ CATEGORIES = ["Electronics", "Clothing", "Books", "Home & Garden", "Sports"]
18
+ PLAN_TYPES = ["free", "pro", "enterprise"]
19
+ ORDER_STATUSES = ["pending", "completed", "refunded"]
20
+ EVENT_TYPES = ["page_view", "add_to_cart", "checkout", "login", "logout"]
21
+
22
+
23
+ def create_database(db_path: str = ":memory:") -> sqlite3.Connection:
24
+ conn = sqlite3.connect(db_path)
25
+ conn.row_factory = sqlite3.Row
26
+
27
+ conn.execute("""
28
+ CREATE TABLE users (
29
+ id INTEGER PRIMARY KEY,
30
+ email TEXT NOT NULL,
31
+ country TEXT,
32
+ plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
33
+ created_at TIMESTAMP NOT NULL,
34
+ churned_at TIMESTAMP
35
+ )
36
+ """)
37
+
38
+ conn.execute("""
39
+ CREATE TABLE products (
40
+ id INTEGER PRIMARY KEY,
41
+ name TEXT NOT NULL,
42
+ category TEXT NOT NULL,
43
+ price REAL,
44
+ cost REAL
45
+ )
46
+ """)
47
+
48
+ conn.execute("""
49
+ CREATE TABLE orders (
50
+ id INTEGER PRIMARY KEY,
51
+ user_id INTEGER REFERENCES users(id),
52
+ created_at TIMESTAMP NOT NULL,
53
+ status TEXT CHECK(status IN ('pending', 'completed', 'refunded')),
54
+ total REAL
55
+ )
56
+ """)
57
+
58
+ conn.execute("""
59
+ CREATE TABLE order_items (
60
+ id INTEGER PRIMARY KEY,
61
+ order_id INTEGER REFERENCES orders(id),
62
+ product_id INTEGER REFERENCES products(id),
63
+ qty INTEGER NOT NULL,
64
+ unit_price REAL
65
+ )
66
+ """)
67
+
68
+ conn.execute("""
69
+ CREATE TABLE events (
70
+ id INTEGER PRIMARY KEY,
71
+ user_id INTEGER REFERENCES users(id),
72
+ event_type TEXT,
73
+ metadata TEXT,
74
+ ts TIMESTAMP NOT NULL
75
+ )
76
+ """)
77
+
78
+ conn.commit()
79
+ return conn
80
+
81
+
82
+ def seed_database(conn: sqlite3.Connection) -> None:
83
+ users = _seed_users(conn)
84
+ products = _seed_products(conn)
85
+ orders, order_items = _seed_orders(conn, users, products)
86
+ _seed_events(conn, users, orders)
87
+
88
+
89
+ def _seed_users(conn: sqlite3.Connection) -> list:
90
+ users = []
91
+ now = datetime.now()
92
+ base_date = now - timedelta(days=180)
93
+ recent_date = now - timedelta(days=30)
94
+
95
+ for i in range(SEED_CONFIG["users"]):
96
+ if random.random() < 0.3:
97
+ created_at = recent_date + timedelta(days=random.randint(0, 30))
98
+ else:
99
+ created_at = base_date + timedelta(days=random.randint(0, 180))
100
+
101
+ country = random.choice([fake.country(), None, None, None, None])
102
+ plan = random.choice(PLAN_TYPES)
103
+ churned_at = None
104
+
105
+ if plan == "free" and random.random() < 0.1:
106
+ churned_at = created_at + timedelta(days=random.randint(30, 150))
107
+
108
+ conn.execute(
109
+ "INSERT INTO users (email, country, plan, created_at, churned_at) VALUES (?, ?, ?, ?, ?)",
110
+ (
111
+ fake.email(),
112
+ country,
113
+ plan,
114
+ created_at.isoformat(),
115
+ churned_at.isoformat() if churned_at else None,
116
+ ),
117
+ )
118
+ users.append((i + 1, created_at))
119
+
120
+ conn.commit()
121
+ return users
122
+
123
+
124
+ def _seed_products(conn: sqlite3.Connection) -> list:
125
+ products = []
126
+
127
+ for i in range(SEED_CONFIG["products"]):
128
+ category = random.choice(CATEGORIES)
129
+ price = round(random.uniform(10, 500), 2)
130
+ cost = round(price * random.uniform(0.3, 0.7), 2)
131
+
132
+ conn.execute(
133
+ "INSERT INTO products (name, category, price, cost) VALUES (?, ?, ?, ?)",
134
+ (fake.catch_phrase(), category, price, cost),
135
+ )
136
+ products.append((i + 1, category, price))
137
+
138
+ conn.commit()
139
+ return products
140
+
141
+
142
+ def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple:
143
+ orders = []
144
+ order_items = []
145
+
146
+ q3_start = datetime(2024, 7, 1)
147
+ q3_end = datetime(2024, 9, 30)
148
+ recent_date = datetime.now()
149
+ old_date = datetime(2024, 1, 1)
150
+
151
+ for i in range(SEED_CONFIG["orders"]):
152
+ user_id = random.choice(users)[0]
153
+
154
+ if random.random() < 0.2:
155
+ created_at = q3_start + timedelta(days=random.randint(0, 91))
156
+ else:
157
+ created_at = old_date + timedelta(days=random.randint(0, 180))
158
+
159
+ status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0]
160
+
161
+ conn.execute(
162
+ "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)",
163
+ (user_id, created_at.isoformat(), status, 0),
164
+ )
165
+
166
+ order_id = i + 1
167
+ order_total = 0
168
+
169
+ num_items = random.randint(1, 5)
170
+ for _ in range(num_items):
171
+ product = random.choice(products)
172
+ qty = random.randint(1, 3)
173
+ unit_price = product[2]
174
+ order_total += qty * unit_price
175
+
176
+ conn.execute(
177
+ "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)",
178
+ (order_id, product[0], qty, unit_price),
179
+ )
180
+
181
+ conn.execute(
182
+ "UPDATE orders SET total = ? WHERE id = ?",
183
+ (round(order_total, 2), order_id),
184
+ )
185
+ orders.append((order_id, user_id, created_at, status))
186
+
187
+ conn.commit()
188
+ return orders, order_items
189
+
190
+
191
+ def _seed_events(conn: sqlite3.Connection, users: list, orders: list) -> None:
192
+ base_date = datetime.now() - timedelta(days=180)
193
+
194
+ for _ in range(SEED_CONFIG["events"]):
195
+ user_id = random.choice(users)[0]
196
+ ts = base_date + timedelta(
197
+ days=random.randint(0, 180), hours=random.randint(0, 23)
198
+ )
199
+ event_type = random.choice(EVENT_TYPES)
200
+ metadata = '{"page": "/' + fake.uri_path() + '"}'
201
+
202
+ conn.execute(
203
+ "INSERT INTO events (user_id, event_type, metadata, ts) VALUES (?, ?, ?, ?)",
204
+ (user_id, event_type, metadata, ts.isoformat()),
205
+ )
206
+
207
+ conn.commit()
208
+
209
+
210
+ def get_schema_summary(conn: sqlite3.Connection) -> str:
211
+ cursor = conn.execute(
212
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
213
+ )
214
+ tables = [r[0] for r in cursor.fetchall()]
215
+
216
+ lines = []
217
+ for table in tables:
218
+ cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
219
+ col_names = [c[1] for c in cols]
220
+ lines.append(f"{table}: ({', '.join(col_names)})")
221
+
222
+ return "\n".join(lines)
223
+
224
+
225
+ def get_ground_truth(conn: sqlite3.Connection, task_id: str) -> Any:
226
+ if task_id == "monthly_signups":
227
+ result = conn.execute(
228
+ "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
229
+ ).fetchone()
230
+ return result[0]
231
+
232
+ elif task_id == "top_revenue_category":
233
+ result = conn.execute("""
234
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
235
+ FROM orders o
236
+ JOIN order_items oi ON o.id = oi.order_id
237
+ JOIN products p ON oi.product_id = p.id
238
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
239
+ AND o.status = 'completed'
240
+ GROUP BY p.category
241
+ ORDER BY revenue DESC
242
+ LIMIT 1
243
+ """).fetchone()
244
+ return result[0] if result else None
245
+
246
+ elif task_id == "churn_analysis":
247
+ result = conn.execute("""
248
+ WITH order_counts AS (
249
+ SELECT user_id, COUNT(*) AS total_orders,
250
+ MAX(created_at) AS last_order_date
251
+ FROM orders
252
+ WHERE status = 'completed'
253
+ GROUP BY user_id
254
+ HAVING COUNT(*) = 3
255
+ ),
256
+ churned AS (
257
+ SELECT oc.user_id
258
+ FROM order_counts oc
259
+ WHERE oc.last_order_date < DATE('now', '-90 days')
260
+ )
261
+ SELECT u.email
262
+ FROM users u
263
+ JOIN churned c ON u.id = c.user_id
264
+ """).fetchall()
265
+ return {row[0].lower() for row in result}
266
+
267
+ return None
env/environment.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import Optional
3
+ from .models import Action, Observation, StepResult, EnvState, QueryResult
4
+ from .database import create_database, seed_database, get_schema_summary
5
+ from .reward import RewardCalculator
6
+ from .tasks import TASKS
7
+
8
+
9
+ class SQLAnalystEnv:
10
+ """
11
+ OpenEnv-compliant SQL Data Analyst environment.
12
+
13
+ An agent must answer business questions by iteratively
14
+ writing and executing SQL queries.
15
+ """
16
+
17
+ def __init__(self, task_id: str = "monthly_signups"):
18
+ assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
19
+ self.task_id = task_id
20
+ self.task = TASKS[task_id]
21
+ self.conn: Optional[sqlite3.Connection] = None
22
+ self.step_count: int = 0
23
+ self.total_reward: float = 0.0
24
+ self.done: bool = False
25
+ self._query_history: list = []
26
+ self._reward_calc = RewardCalculator()
27
+
28
+ def reset(self) -> StepResult:
29
+ """Reset environment. Reseed DB. Return initial observation."""
30
+ if self.conn:
31
+ self.conn.close()
32
+
33
+ self.conn = create_database()
34
+ seed_database(self.conn)
35
+ self.step_count = 0
36
+ self.total_reward = 0.0
37
+ self.done = False
38
+ self._query_history = []
39
+
40
+ self.task.compute_ground_truth(self.conn)
41
+
42
+ obs = Observation(
43
+ schema_summary=get_schema_summary(self.conn),
44
+ question=self.task.question,
45
+ step=0,
46
+ max_steps=self.task.max_steps,
47
+ )
48
+ return StepResult(observation=obs, reward=0.0, done=False)
49
+
50
+ def step(self, action: Action) -> StepResult:
51
+ """Execute one agent action. Return (observation, reward, done, info)."""
52
+ assert self.conn is not None, "Call reset() before step()"
53
+ assert not self.done, "Episode is done. Call reset()."
54
+ assert action.is_valid(), (
55
+ "Action must have exactly one of: sql_query, submit_answer"
56
+ )
57
+
58
+ self.step_count += 1
59
+ query_result = None
60
+ error = None
61
+
62
+ if action.sql_query:
63
+ query_result = self._execute_sql(action.sql_query)
64
+ self._query_history.append(action.sql_query)
65
+ error = query_result.error
66
+
67
+ terminal = (
68
+ action.submit_answer is not None or self.step_count >= self.task.max_steps
69
+ )
70
+
71
+ reward = self._reward_calc.calculate(
72
+ action=action,
73
+ result=query_result,
74
+ task=self.task,
75
+ step=self.step_count,
76
+ query_history=self._query_history,
77
+ terminal=terminal,
78
+ )
79
+ self.total_reward += reward
80
+ self.done = terminal
81
+
82
+ obs = Observation(
83
+ schema_summary=get_schema_summary(self.conn),
84
+ question=self.task.question,
85
+ last_query=action.sql_query,
86
+ last_result=query_result,
87
+ last_error=error,
88
+ step=self.step_count,
89
+ max_steps=self.task.max_steps,
90
+ hints=self.task.get_hints(self.step_count),
91
+ done=self.done,
92
+ )
93
+
94
+ return StepResult(
95
+ observation=obs,
96
+ reward=round(reward, 3),
97
+ done=self.done,
98
+ info={
99
+ "step": self.step_count,
100
+ "total_reward": round(self.total_reward, 3),
101
+ "task_id": self.task_id,
102
+ },
103
+ )
104
+
105
+ def state(self) -> EnvState:
106
+ """Return current full state of the environment."""
107
+ return EnvState(
108
+ task_id=self.task_id,
109
+ difficulty=self.task.difficulty,
110
+ step=self.step_count,
111
+ max_steps=self.task.max_steps,
112
+ query_history=self._query_history.copy(),
113
+ total_reward=round(self.total_reward, 3),
114
+ done=self.done,
115
+ )
116
+
117
+ def _execute_sql(self, query: str) -> QueryResult:
118
+ """Execute SQL safely. Block non-SELECT. Return up to 50 rows."""
119
+ q = query.strip().upper()
120
+ if not q.startswith("SELECT") and not q.startswith("WITH"):
121
+ return QueryResult(error="Only SELECT / WITH queries are allowed.")
122
+ try:
123
+ cursor = self.conn.execute(query)
124
+ cols = [d[0] for d in cursor.description] if cursor.description else []
125
+ rows = cursor.fetchmany(50)
126
+ total = len(rows)
127
+ return QueryResult(
128
+ columns=cols,
129
+ rows=[list(r) for r in rows],
130
+ truncated=(total == 50),
131
+ total_rows=total,
132
+ )
133
+ except Exception as e:
134
+ return QueryResult(error=str(e))
env/models.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Any
3
+
4
+ class Action(BaseModel):
5
+ """What the agent can do each step."""
6
+ sql_query: Optional[str] = Field(
7
+ None,
8
+ description="A SQL SELECT query to execute against the database"
9
+ )
10
+ submit_answer: Optional[str] = Field(
11
+ None,
12
+ description="Final answer to submit. Ends the episode."
13
+ )
14
+
15
+ def is_valid(self) -> bool:
16
+ # Exactly one of the two must be set
17
+ return bool(self.sql_query) != bool(self.submit_answer)
18
+
19
+
20
+ class QueryResult(BaseModel):
21
+ """Result of executing a SQL query."""
22
+ columns: List[str] = []
23
+ rows: List[List[Any]] = []
24
+ error: Optional[str] = None
25
+ truncated: bool = False
26
+ total_rows: int = 0
27
+
28
+
29
+ class Observation(BaseModel):
30
+ """What the agent sees after each step."""
31
+ schema_summary: str = Field(..., description="Compact DB schema")
32
+ question: str = Field(..., description="Business question to answer")
33
+ last_query: Optional[str] = None
34
+ last_result: Optional[QueryResult] = None
35
+ last_error: Optional[str] = None
36
+ step: int = 0
37
+ max_steps: int = 20
38
+ hints: List[str] = []
39
+ done: bool = False
40
+
41
+
42
+ class StepResult(BaseModel):
43
+ """Full result returned by step()."""
44
+ observation: Observation
45
+ reward: float = 0.0
46
+ done: bool = False
47
+ info: dict = {}
48
+
49
+
50
+ class EnvState(BaseModel):
51
+ """Full environment state returned by state()."""
52
+ task_id: str
53
+ difficulty: str
54
+ step: int
55
+ max_steps: int
56
+ query_history: List[str] = []
57
+ total_reward: float = 0.0
58
+ done: bool = False
env/reward.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List, Any
2
+ from .models import Action, QueryResult
3
+
4
+
5
+ class RewardCalculator:
6
+ """Calculate rewards for agent actions in the SQL analyst environment."""
7
+
8
+ def calculate(
9
+ self,
10
+ action: Action,
11
+ result: Optional[QueryResult],
12
+ task: Any,
13
+ step: int,
14
+ query_history: List[str],
15
+ terminal: bool,
16
+ ) -> float:
17
+ """Calculate reward based on action, result, and task."""
18
+ reward = 0.0
19
+
20
+ if action.sql_query and result:
21
+ if not result.error:
22
+ reward += 0.15
23
+
24
+ relevant = self._count_relevant_tables(
25
+ action.sql_query, task.relevant_tables
26
+ )
27
+ if relevant > 0:
28
+ reward += 0.10
29
+
30
+ if result.rows and len(result.rows) > 0:
31
+ reward += 0.05
32
+
33
+ if result.rows and len(result.rows) < 1000:
34
+ reward += 0.05
35
+
36
+ if step > 3:
37
+ reward -= 0.02 * (step - 3)
38
+
39
+ if self._is_stuck(query_history):
40
+ reward -= 0.10
41
+
42
+ if terminal and action.submit_answer:
43
+ task_score = task.grade(action.submit_answer)
44
+ reward += task_score * 0.60
45
+
46
+ return max(0.0, min(1.0, reward))
47
+
48
+ def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
49
+ query_lower = query.lower()
50
+ return sum(1 for t in relevant_tables if t.lower() in query_lower)
51
+
52
+ def _is_stuck(self, history: List[str]) -> bool:
53
+ if len(history) < 3:
54
+ return False
55
+ return len(set(history[-3:])) == 1
env/server.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI server for SQL Data Analyst OpenEnv.
3
+
4
+ Provides REST and WebSocket endpoints for HuggingFace Spaces deployment.
5
+ """
6
+
7
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
+ from pydantic import BaseModel
9
+ from typing import Optional, Dict, Any
10
+ import json
11
+ import asyncio
12
+
13
+ from env import SQLAnalystEnv, Action
14
+
15
+
16
+ app = FastAPI(title="SQL Data Analyst Environment")
17
+
18
+ envs: Dict[str, SQLAnalystEnv] = {}
19
+
20
+
21
+ class ResetRequest(BaseModel):
22
+ task_id: str = "monthly_signups"
23
+
24
+
25
+ class StepRequest(BaseModel):
26
+ session_id: str
27
+ sql_query: Optional[str] = None
28
+ submit_answer: Optional[str] = None
29
+
30
+
31
+ class StateRequest(BaseModel):
32
+ session_id: str
33
+
34
+
35
+ @app.get("/")
36
+ async def root():
37
+ return {
38
+ "name": "sql-data-analyst",
39
+ "version": "1.0.0",
40
+ "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation",
41
+ }
42
+
43
+
44
+ @app.post("/reset")
45
+ async def reset(req: ResetRequest) -> Dict[str, Any]:
46
+ session_id = req.task_id
47
+
48
+ env = SQLAnalystEnv(task_id=req.task_id)
49
+ result = env.reset()
50
+ envs[session_id] = env
51
+
52
+ return {
53
+ "session_id": session_id,
54
+ "observation": {
55
+ "schema_summary": result.observation.schema_summary,
56
+ "question": result.observation.question,
57
+ "step": result.observation.step,
58
+ "max_steps": result.observation.max_steps,
59
+ "hints": result.observation.hints,
60
+ "done": result.observation.done,
61
+ },
62
+ "reward": result.reward,
63
+ "done": result.done,
64
+ }
65
+
66
+
67
+ @app.post("/step")
68
+ async def step(req: StepRequest) -> Dict[str, Any]:
69
+ session_id = req.session_id
70
+
71
+ if session_id not in envs:
72
+ return {"error": "Session not found. Call /reset first."}
73
+
74
+ env = envs[session_id]
75
+
76
+ action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer)
77
+
78
+ result = env.step(action)
79
+
80
+ return {
81
+ "observation": {
82
+ "schema_summary": result.observation.schema_summary,
83
+ "question": result.observation.question,
84
+ "last_query": result.observation.last_query,
85
+ "last_result": {
86
+ "columns": result.observation.last_result.columns
87
+ if result.observation.last_result
88
+ else None,
89
+ "rows": result.observation.last_result.rows
90
+ if result.observation.last_result
91
+ else None,
92
+ "error": result.observation.last_result.error
93
+ if result.observation.last_result
94
+ else None,
95
+ }
96
+ if result.observation.last_result
97
+ else None,
98
+ "last_error": result.observation.last_error,
99
+ "step": result.observation.step,
100
+ "max_steps": result.observation.max_steps,
101
+ "hints": result.observation.hints,
102
+ "done": result.observation.done,
103
+ },
104
+ "reward": result.reward,
105
+ "done": result.done,
106
+ "info": result.info,
107
+ }
108
+
109
+
110
+ @app.post("/state")
111
+ async def state(req: StateRequest) -> Dict[str, Any]:
112
+ session_id = req.session_id
113
+
114
+ if session_id not in envs:
115
+ return {"error": "Session not found. Call /reset first."}
116
+
117
+ env = envs[session_id]
118
+ state = env.state()
119
+
120
+ return {
121
+ "task_id": state.task_id,
122
+ "difficulty": state.difficulty,
123
+ "step": state.step,
124
+ "max_steps": state.max_steps,
125
+ "query_history": state.query_history,
126
+ "total_reward": state.total_reward,
127
+ "done": state.done,
128
+ }
129
+
130
+
131
+ @app.post("/delete")
132
+ async def delete_session(req: StateRequest) -> Dict[str, str]:
133
+ session_id = req.session_id
134
+
135
+ if session_id in envs:
136
+ del envs[session_id]
137
+ return {"status": "deleted", "session_id": session_id}
138
+
139
+ return {"status": "not_found", "session_id": session_id}
140
+
141
+
142
+ @app.websocket("/ws")
143
+ async def websocket_endpoint(websocket: WebSocket):
144
+ await websocket.accept()
145
+
146
+ session_id = None
147
+ env = None
148
+
149
+ try:
150
+ while True:
151
+ data = await websocket.receive_text()
152
+ message = json.loads(data)
153
+
154
+ action_type = message.get("type")
155
+
156
+ if action_type == "reset":
157
+ task_id = message.get("task_id", "monthly_signups")
158
+ env = SQLAnalystEnv(task_id=task_id)
159
+ result = env.reset()
160
+ session_id = task_id
161
+ envs[session_id] = env
162
+
163
+ await websocket.send_json(
164
+ {
165
+ "type": "reset",
166
+ "observation": {
167
+ "schema_summary": result.observation.schema_summary,
168
+ "question": result.observation.question,
169
+ "step": result.observation.step,
170
+ "max_steps": result.observation.max_steps,
171
+ "hints": result.observation.hints,
172
+ },
173
+ "reward": result.reward,
174
+ "done": result.done,
175
+ }
176
+ )
177
+
178
+ elif action_type == "step":
179
+ if not env:
180
+ await websocket.send_json({"error": "Call reset first"})
181
+ continue
182
+
183
+ action = Action(
184
+ sql_query=message.get("sql_query"),
185
+ submit_answer=message.get("submit_answer"),
186
+ )
187
+
188
+ result = env.step(action)
189
+
190
+ await websocket.send_json(
191
+ {
192
+ "type": "step",
193
+ "observation": {
194
+ "schema_summary": result.observation.schema_summary,
195
+ "question": result.observation.question,
196
+ "last_query": result.observation.last_query,
197
+ "last_result": {
198
+ "columns": result.observation.last_result.columns
199
+ if result.observation.last_result
200
+ else None,
201
+ "rows": result.observation.last_result.rows
202
+ if result.observation.last_result
203
+ else None,
204
+ "error": result.observation.last_result.error
205
+ if result.observation.last_result
206
+ else None,
207
+ }
208
+ if result.observation.last_result
209
+ else None,
210
+ "step": result.observation.step,
211
+ "hints": result.observation.hints,
212
+ "done": result.observation.done,
213
+ },
214
+ "reward": result.reward,
215
+ "done": result.done,
216
+ "info": result.info,
217
+ }
218
+ )
219
+
220
+ elif action_type == "state":
221
+ if not env:
222
+ await websocket.send_json({"error": "Call reset first"})
223
+ continue
224
+
225
+ state = env.state()
226
+
227
+ await websocket.send_json(
228
+ {
229
+ "type": "state",
230
+ "task_id": state.task_id,
231
+ "difficulty": state.difficulty,
232
+ "step": state.step,
233
+ "max_steps": state.max_steps,
234
+ "query_history": state.query_history,
235
+ "total_reward": state.total_reward,
236
+ "done": state.done,
237
+ }
238
+ )
239
+
240
+ elif action_type == "close":
241
+ if session_id and session_id in envs:
242
+ del envs[session_id]
243
+ break
244
+
245
+ except WebSocketDisconnect:
246
+ pass
247
+ except Exception as e:
248
+ await websocket.send_json({"error": str(e)})
249
+
250
+
251
+ if __name__ == "__main__":
252
+ import uvicorn
253
+
254
+ uvicorn.run(app, host="0.0.0.0", port=7860)
env/tasks/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import BaseTask
2
+ from .easy import MonthlySignupsTask
3
+ from .medium import TopRevenueCategoryTask
4
+ from .hard import ChurnAnalysisTask
5
+
6
+
7
+ TASKS = {
8
+ "monthly_signups": MonthlySignupsTask(),
9
+ "top_revenue_category": TopRevenueCategoryTask(),
10
+ "churn_analysis": ChurnAnalysisTask(),
11
+ }
12
+
13
+
14
+ __all__ = [
15
+ "BaseTask",
16
+ "MonthlySignupsTask",
17
+ "TopRevenueCategoryTask",
18
+ "ChurnAnalysisTask",
19
+ "TASKS",
20
+ ]
env/tasks/base.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import sqlite3
3
+ import re
4
+ from typing import Any, List, Optional
5
+
6
+
7
+ class BaseTask(ABC):
8
+ """Abstract base class for all tasks."""
9
+
10
+ task_id: str
11
+ difficulty: str
12
+ max_steps: int
13
+ question: str
14
+ relevant_tables: List[str]
15
+ sql_hint: str
16
+
17
+ def __init__(self):
18
+ self.ground_truth: Any = None
19
+ self.top_3_categories: List[str] = []
20
+
21
+ @abstractmethod
22
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
23
+ """Compute ground truth after database seeding."""
24
+ pass
25
+
26
+ @abstractmethod
27
+ def grade(self, submitted_answer: str) -> float:
28
+ """Grade the submitted answer. Returns score 0.0-1.0."""
29
+ pass
30
+
31
+ def get_hints(self, step: int) -> List[str]:
32
+ """Return progressive hints based on current step."""
33
+ hints = []
34
+ if step > 5:
35
+ hints.append(
36
+ f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}"
37
+ )
38
+ if step > 10:
39
+ hints.append(f"Hint: Try using {self.sql_hint}")
40
+ if step > 15:
41
+ hints.append("Hint: Make sure to submit your answer with submit_answer.")
42
+ return hints
43
+
44
+ def _normalize(self, text: str) -> str:
45
+ """Remove common LLM formatting and normalize text."""
46
+ text = text.strip().lower()
47
+ text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
48
+ text = re.sub(r"\*+", "", text)
49
+ text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
50
+ text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
51
+ text = re.sub(r"\s+", " ", text)
52
+ return text.strip()
env/tasks/easy.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from .base import BaseTask
3
+
4
+
5
+ class MonthlySignupsTask(BaseTask):
6
+ """Task 1 — Easy: Count users who signed up in the last 30 days."""
7
+
8
+ task_id = "monthly_signups"
9
+ difficulty = "easy"
10
+ max_steps = 10
11
+ question = "How many users signed up in the last 30 days?"
12
+ relevant_tables = ["users"]
13
+ sql_hint = "COUNT(*) with WHERE clause on created_at"
14
+
15
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
16
+ result = conn.execute(
17
+ "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
18
+ ).fetchone()
19
+ self.ground_truth = result[0] if result else 0
20
+
21
+ def grade(self, submitted_answer: str) -> float:
22
+ try:
23
+ val = int(submitted_answer.strip().replace(",", ""))
24
+ if val == self.ground_truth:
25
+ return 1.0
26
+ if abs(val - self.ground_truth) <= 3:
27
+ return 0.6
28
+ if abs(val - self.ground_truth) <= 10:
29
+ return 0.3
30
+ except (ValueError, AttributeError):
31
+ pass
32
+ return 0.0
env/tasks/hard.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from .base import BaseTask
3
+
4
+
5
+ class ChurnAnalysisTask(BaseTask):
6
+ """Task 3 — Hard: Find users who placed exactly 3 orders and then churned."""
7
+
8
+ task_id = "churn_analysis"
9
+ difficulty = "hard"
10
+ max_steps = 20
11
+ question = "Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."
12
+ relevant_tables = ["users", "orders"]
13
+ sql_hint = "CTE with COUNT and HAVING"
14
+
15
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
16
+ result = conn.execute("""
17
+ WITH order_counts AS (
18
+ SELECT user_id, COUNT(*) AS total_orders,
19
+ MAX(created_at) AS last_order_date
20
+ FROM orders
21
+ WHERE status = 'completed'
22
+ GROUP BY user_id
23
+ HAVING COUNT(*) = 3
24
+ ),
25
+ churned AS (
26
+ SELECT oc.user_id
27
+ FROM order_counts oc
28
+ WHERE oc.last_order_date < DATE('now', '-90 days')
29
+ )
30
+ SELECT u.email
31
+ FROM users u
32
+ JOIN churned c ON u.id = c.user_id
33
+ """).fetchall()
34
+
35
+ self.ground_truth = {row[0].lower() for row in result}
36
+
37
+ def grade(self, submitted_answer: str) -> float:
38
+ if not submitted_answer.strip():
39
+ return 0.0
40
+
41
+ submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e}
42
+
43
+ if not submitted:
44
+ return 0.0
45
+
46
+ correct = {e.lower() for e in self.ground_truth}
47
+ tp = len(submitted & correct)
48
+
49
+ if tp == 0:
50
+ return 0.0
51
+
52
+ precision = tp / len(submitted) if submitted else 0
53
+ recall = tp / len(correct) if correct else 0
54
+
55
+ if precision + recall == 0:
56
+ return 0.0
57
+
58
+ f1 = 2 * precision * recall / (precision + recall)
59
+ return round(f1, 3)
env/tasks/medium.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from .base import BaseTask
3
+
4
+
5
+ class TopRevenueCategoryTask(BaseTask):
6
+ """Task 2 — Medium: Find product category with most revenue in Q3."""
7
+
8
+ task_id = "top_revenue_category"
9
+ difficulty = "medium"
10
+ max_steps = 15
11
+ question = (
12
+ "Which product category generated the most revenue in Q3 (July-September)?"
13
+ )
14
+ relevant_tables = ["orders", "order_items", "products"]
15
+ sql_hint = "JOIN with GROUP BY and ORDER BY"
16
+
17
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
18
+ result = conn.execute("""
19
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
20
+ FROM orders o
21
+ JOIN order_items oi ON o.id = oi.order_id
22
+ JOIN products p ON oi.product_id = p.id
23
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
24
+ AND o.status = 'completed'
25
+ GROUP BY p.category
26
+ ORDER BY revenue DESC
27
+ LIMIT 1
28
+ """).fetchone()
29
+
30
+ self.ground_truth = result[0] if result else None
31
+
32
+ all_categories = conn.execute("""
33
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
34
+ FROM orders o
35
+ JOIN order_items oi ON o.id = oi.order_id
36
+ JOIN products p ON oi.product_id = p.id
37
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
38
+ AND o.status = 'completed'
39
+ GROUP BY p.category
40
+ ORDER BY revenue DESC
41
+ """).fetchall()
42
+
43
+ self.top_3_categories = [row[0] for row in all_categories[:3]]
44
+
45
+ def grade(self, submitted_answer: str) -> float:
46
+ answer = self._normalize(submitted_answer)
47
+
48
+ if self.ground_truth and self.ground_truth.lower() in answer:
49
+ return 1.0
50
+
51
+ if any(cat.lower() in answer for cat in self.top_3_categories):
52
+ return 0.4
53
+
54
+ return 0.0
env/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ def normalize_answer(raw: str) -> str:
5
+ """Remove common LLM answer preambles and formatting."""
6
+ text = raw.strip().lower()
7
+ text = re.sub(r"the (answer|result) is:?\s*", "", text)
8
+ text = re.sub(r"\*+", "", text)
9
+ text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
10
+ text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
11
+ text = re.sub(r"\s+", " ", text)
12
+ return text.strip()
13
+
14
+
15
+ FORBIDDEN_KEYWORDS = [
16
+ "DROP",
17
+ "DELETE",
18
+ "INSERT",
19
+ "UPDATE",
20
+ "ALTER",
21
+ "CREATE",
22
+ "TRUNCATE",
23
+ ]
24
+
25
+
26
+ def is_safe_query(query: str) -> bool:
27
+ """Check if query is safe (SELECT-only)."""
28
+ upper = query.upper()
29
+ return not any(kw in upper for kw in FORBIDDEN_KEYWORDS)
hf_space/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
hf_space/README.md ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Sql Data Analyst
3
+ emoji: 📉
4
+ colorFrom: gray
5
+ colorTo: green
6
+ sdk: docker
7
+ pinned: false
8
+ license: mit
9
+ ---
10
+
11
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
models.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Any
3
+
4
+ class Action(BaseModel):
5
+ """What the agent can do each step."""
6
+ sql_query: Optional[str] = Field(
7
+ None,
8
+ description="A SQL SELECT query to execute against the database"
9
+ )
10
+ submit_answer: Optional[str] = Field(
11
+ None,
12
+ description="Final answer to submit. Ends the episode."
13
+ )
14
+
15
+ def is_valid(self) -> bool:
16
+ # Exactly one of the two must be set
17
+ return bool(self.sql_query) != bool(self.submit_answer)
18
+
19
+
20
+ class QueryResult(BaseModel):
21
+ """Result of executing a SQL query."""
22
+ columns: List[str] = []
23
+ rows: List[List[Any]] = []
24
+ error: Optional[str] = None
25
+ truncated: bool = False
26
+ total_rows: int = 0
27
+
28
+
29
+ class Observation(BaseModel):
30
+ """What the agent sees after each step."""
31
+ schema_summary: str = Field(..., description="Compact DB schema")
32
+ question: str = Field(..., description="Business question to answer")
33
+ last_query: Optional[str] = None
34
+ last_result: Optional[QueryResult] = None
35
+ last_error: Optional[str] = None
36
+ step: int = 0
37
+ max_steps: int = 20
38
+ hints: List[str] = []
39
+ done: bool = False
40
+
41
+
42
+ class StepResult(BaseModel):
43
+ """Full result returned by step()."""
44
+ observation: Observation
45
+ reward: float = 0.0
46
+ done: bool = False
47
+ info: dict = {}
48
+
49
+
50
+ class EnvState(BaseModel):
51
+ """Full environment state returned by state()."""
52
+ task_id: str
53
+ difficulty: str
54
+ step: int
55
+ max_steps: int
56
+ query_history: List[str] = []
57
+ total_reward: float = 0.0
58
+ done: bool = False
openenv.yaml ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: sql-data-analyst
2
+ version: "1.0.0"
3
+ description: >
4
+ An RL environment where an AI agent answers real business intelligence questions
5
+ by iteratively writing and executing SQL queries against a live SQLite database.
6
+ Simulates the day-to-day workflow of a data analyst.
7
+
8
+ tags:
9
+ - openenv
10
+ - sql
11
+ - data-analysis
12
+ - business-intelligence
13
+ - real-world
14
+
15
+ author: sql-data-analyst
16
+ repository: https://huggingface.co/spaces/sql-data-analyst
17
+
18
+ observation_space:
19
+ type: dict
20
+ fields:
21
+ schema_summary:
22
+ type: string
23
+ description: Compact one-line-per-table schema of the database
24
+ question:
25
+ type: string
26
+ description: Natural language business question to answer
27
+ last_query:
28
+ type: string
29
+ nullable: true
30
+ description: The last SQL query executed by the agent
31
+ last_result:
32
+ type: object
33
+ nullable: true
34
+ description: Result of the last query (columns, rows, error)
35
+ last_error:
36
+ type: string
37
+ nullable: true
38
+ description: SQL error message if last query failed
39
+ step:
40
+ type: integer
41
+ description: Current step number
42
+ max_steps:
43
+ type: integer
44
+ description: Maximum steps allowed for this task
45
+ hints:
46
+ type: array
47
+ items: string
48
+ description: Progressive hints revealed as steps increase
49
+ done:
50
+ type: boolean
51
+ description: Whether the episode is complete
52
+
53
+ action_space:
54
+ type: union
55
+ description: Agent must provide exactly one of the following
56
+ options:
57
+ sql_query:
58
+ type: string
59
+ description: A SELECT or WITH SQL query to execute
60
+ submit_answer:
61
+ type: string
62
+ description: Final answer to the question. Ends the episode.
63
+
64
+ tasks:
65
+ - id: monthly_signups
66
+ difficulty: easy
67
+ max_steps: 10
68
+ description: "Count the number of users who signed up in the last 30 days"
69
+ skills_required:
70
+ - COUNT
71
+ - WHERE with date filter
72
+
73
+ - id: top_revenue_category
74
+ difficulty: medium
75
+ max_steps: 15
76
+ description: "Find which product category generated the most revenue in Q3"
77
+ skills_required:
78
+ - JOIN (3 tables)
79
+ - GROUP BY
80
+ - SUM aggregation
81
+ - Date range filtering
82
+
83
+ - id: churn_analysis
84
+ difficulty: hard
85
+ max_steps: 20
86
+ description: >
87
+ Find email addresses of users who placed exactly 3 orders and then
88
+ never ordered again (churned after their 3rd purchase)
89
+ skills_required:
90
+ - Subqueries
91
+ - HAVING clause
92
+ - Date logic
93
+ - Window functions (optional)
94
+
95
+ baseline_scores:
96
+ monthly_signups: 0.85
97
+ top_revenue_category: 0.65
98
+ churn_analysis: 0.40
99
+ average: 0.63
100
+ model: gpt-4o-mini
progress.txt ADDED
File without changes
pyproject.toml ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools>=61.0", "wheel"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "sql-data-analyst"
7
+ version = "1.0.0"
8
+ description = "SQL Data Analyst OpenEnv - RL environment for SQL query generation"
9
+ readme = "README.md"
10
+ license = {text = "MIT"}
11
+ authors = [
12
+ {name = "Hackathon Team", email = "team@example.com"}
13
+ ]
14
+ requires-python = ">=3.11"
15
+ dependencies = [
16
+ "openenv>=0.1.13",
17
+ "pydantic>=2.0",
18
+ "fastapi>=0.100",
19
+ "uvicorn>=0.20",
20
+ "openai>=1.0",
21
+ "faker>=18.0",
22
+ "pytest>=7.0",
23
+ ]
24
+
25
+ [project.scripts]
26
+ openenv-sql-analyst = "server.app:main"
27
+
28
+ [project.optional-dependencies]
29
+ dev = [
30
+ "pytest>=7.0",
31
+ "pytest-asyncio>=0.21",
32
+ ]
33
+
34
+ [tool.setuptools.packages.find]
35
+ where = ["."]
36
+ include = ["env*", "baseline*", "server*"]
37
+
38
+ [tool.pytest.ini_options]
39
+ testpaths = ["tests"]
40
+ python_files = ["test_*.py"]
41
+ python_classes = ["Test*"]
42
+ python_functions = ["test_*"]
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ pydantic>=2.0
2
+ fastapi>=0.100
3
+ uvicorn>=0.20
4
+ openai>=1.0
5
+ faker>=18.0
6
+ pytest>=7.0
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from env.server import app as _app
2
+ import uvicorn
3
+
4
+
5
+ def main():
6
+ uvicorn.run(_app, host="0.0.0.0", port=7860)
7
+
8
+
9
+ if __name__ == "__main__":
10
+ main()
11
+
12
+ __all__ = ["app", "main"]
temp_upload/baseline/__init__.py ADDED
File without changes
temp_upload/baseline/prompts.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Prompt templates for the baseline inference script.
3
+ """
4
+
5
+ SYSTEM_PROMPT = """
6
+ You are a SQL data analyst. You are given a database schema and a business question.
7
+ Your job is to write SQL queries to explore the data and submit a final answer.
8
+
9
+ Rules:
10
+ - Only write SELECT or WITH queries (no INSERT, UPDATE, DELETE, DROP, etc.)
11
+ - Reply with JSON only. No explanation.
12
+ - To run a query: {"sql_query": "SELECT ..."}
13
+ - To submit answer: {"submit_answer": "your answer here"}
14
+ - You will see the query result after each step.
15
+ - Submit your answer when you are confident.
16
+
17
+ Important:
18
+ - Always use valid SQL syntax
19
+ - Table names: users, products, orders, order_items, events
20
+ - Dates are stored as ISO timestamps
21
+ - Always filter orders by status='completed' for revenue calculations
22
+ """
23
+
24
+
25
+ def build_prompt(obs) -> str:
26
+ """Build the user prompt from an observation."""
27
+ parts = [
28
+ f"Database schema:\n{obs.schema_summary}",
29
+ f"\nQuestion: {obs.question}",
30
+ f"\nStep: {obs.step} / {obs.max_steps}",
31
+ ]
32
+
33
+ if obs.last_query:
34
+ parts.append(f"\nLast query:\n{obs.last_query}")
35
+
36
+ if obs.last_result:
37
+ if obs.last_result.error:
38
+ parts.append(f"\nSQL error: {obs.last_result.error}")
39
+ elif obs.last_result.rows:
40
+ cols = obs.last_result.columns
41
+ rows = obs.last_result.rows[:10]
42
+ parts.append(f"\nResult columns: {cols}")
43
+ parts.append(
44
+ f"Result rows (first {len(rows)}):\n{json.dumps(rows, indent=2)}"
45
+ )
46
+
47
+ if obs.hints:
48
+ parts.append(f"\nHints: {'; '.join(obs.hints)}")
49
+
50
+ parts.append("\nWhat is your next action? Reply with JSON only.")
51
+ return "\n".join(parts)
52
+
53
+
54
+ import json
55
+
56
+
57
+ def parse_action(response_text: str | None):
58
+ """Extract JSON action from LLM response."""
59
+ from env import Action
60
+
61
+ if not response_text:
62
+ return Action(submit_answer="")
63
+
64
+ text = response_text.strip()
65
+
66
+ text = text.replace("```json", "").replace("```", "").strip()
67
+
68
+ try:
69
+ data = json.loads(text)
70
+
71
+ if "sql_query" in data and data["sql_query"]:
72
+ return Action(sql_query=data["sql_query"])
73
+ elif "submit_answer" in data and data["submit_answer"]:
74
+ return Action(submit_answer=data["submit_answer"])
75
+ except json.JSONDecodeError:
76
+ pass
77
+
78
+ return Action(submit_answer=text)
temp_upload/baseline/run_baseline.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Baseline inference script for sql-data-analyst OpenEnv.
3
+
4
+ Usage:
5
+ export OPENAI_API_KEY=sk-...
6
+ python baseline/run_baseline.py
7
+
8
+ Produces reproducible scores across all 3 tasks.
9
+ """
10
+
11
+ import os
12
+ import json
13
+ import sys
14
+ from pathlib import Path
15
+
16
+ sys.path.insert(0, str(Path(__file__).parent.parent))
17
+
18
+ from typing import List, Dict, Any
19
+
20
+ try:
21
+ from openai import OpenAI
22
+ except ImportError:
23
+ print("Error: openai package not installed. Run: pip install openai")
24
+ sys.exit(1)
25
+
26
+ from env import SQLAnalystEnv, Action
27
+ from baseline.prompts import SYSTEM_PROMPT, build_prompt, parse_action
28
+
29
+
30
+ MODEL = "gpt-4o-mini"
31
+ MAX_STEPS = 20
32
+ TASK_IDS = ["monthly_signups", "top_revenue_category", "churn_analysis"]
33
+
34
+
35
+ def run_task(
36
+ client: OpenAI, task_id: str, max_steps: int = MAX_STEPS
37
+ ) -> Dict[str, Any]:
38
+ """Run a single task with the LLM agent."""
39
+ print(f"\n{'=' * 50}")
40
+ print(f"Task: {task_id}")
41
+ print("=" * 50)
42
+
43
+ env = SQLAnalystEnv(task_id=task_id)
44
+ result = env.reset()
45
+ obs = result.observation
46
+ history = []
47
+ total_reward = 0.0
48
+
49
+ print(f"Question: {obs.question}")
50
+ print(f"Schema: {obs.schema_summary[:200]}...")
51
+
52
+ for step in range(1, max_steps + 1):
53
+ if result.done:
54
+ print(f"Episode done at step {step - 1}")
55
+ break
56
+
57
+ user_prompt = build_prompt(obs)
58
+ history.append({"role": "user", "content": user_prompt})
59
+
60
+ try:
61
+ response = client.chat.completions.create(
62
+ model=MODEL,
63
+ messages=[
64
+ {"role": "system", "content": SYSTEM_PROMPT},
65
+ *history[-8:],
66
+ ],
67
+ temperature=0.0,
68
+ )
69
+ except Exception as e:
70
+ print(f"API Error: {e}")
71
+ break
72
+
73
+ reply = response.choices[0].message.content or ""
74
+ history.append({"role": "assistant", "content": reply})
75
+
76
+ action = parse_action(reply)
77
+
78
+ if action.sql_query:
79
+ print(f"Step {step}: Executing SQL...")
80
+ print(f" Query: {action.sql_query[:100]}...")
81
+ else:
82
+ print(f"Step {step}: Submitting answer...")
83
+ print(
84
+ f" Answer: {action.submit_answer[:100] if action.submit_answer else 'empty'}..."
85
+ )
86
+
87
+ result = env.step(action)
88
+ obs = result.observation
89
+ total_reward = result.info.get("total_reward", 0.0)
90
+
91
+ if result.done:
92
+ break
93
+
94
+ state = env.state()
95
+ print(f"\nFinal total reward: {total_reward:.3f}")
96
+ print(f"Steps taken: {state.step}")
97
+
98
+ return {
99
+ "task_id": task_id,
100
+ "difficulty": state.difficulty,
101
+ "total_reward": round(total_reward, 3),
102
+ "steps": state.step,
103
+ "max_steps": state.max_steps,
104
+ }
105
+
106
+
107
+ def main():
108
+ api_key = os.environ.get("OPENAI_API_KEY")
109
+
110
+ if not api_key:
111
+ print("Error: OPENAI_API_KEY environment variable not set")
112
+ print("Usage: export OPENAI_API_KEY=sk-...")
113
+ sys.exit(1)
114
+
115
+ client = OpenAI(api_key=api_key)
116
+
117
+ print("=" * 60)
118
+ print("SQL Data Analyst - Baseline Inference")
119
+ print("=" * 60)
120
+ print(f"Model: {MODEL}")
121
+ print(f"Max steps per task: {MAX_STEPS}")
122
+ print(f"Tasks: {TASK_IDS}")
123
+
124
+ results = []
125
+ for task_id in TASK_IDS:
126
+ try:
127
+ r = run_task(client, task_id)
128
+ results.append(r)
129
+ except Exception as e:
130
+ print(f"Error running task {task_id}: {e}")
131
+ results.append(
132
+ {
133
+ "task_id": task_id,
134
+ "error": str(e),
135
+ "total_reward": 0.0,
136
+ "steps": 0,
137
+ }
138
+ )
139
+
140
+ print("\n" + "=" * 60)
141
+ print("BASELINE RESULTS")
142
+ print("=" * 60)
143
+
144
+ for r in results:
145
+ task = r.get("task_id", "unknown")
146
+ reward = r.get("total_reward", 0.0)
147
+ steps = r.get("steps", 0)
148
+ print(f"{task:30s} score={reward:.3f} steps={steps}")
149
+
150
+ valid_results = [r for r in results if "total_reward" in r]
151
+ if valid_results:
152
+ avg = sum(r["total_reward"] for r in valid_results) / len(valid_results)
153
+ print(f"\nAverage score: {avg:.3f}")
154
+
155
+ output_file = "baseline_scores.json"
156
+ with open(output_file, "w") as f:
157
+ json.dump(results, f, indent=2)
158
+ print(f"\nSaved results to {output_file}")
159
+
160
+
161
+ if __name__ == "__main__":
162
+ main()
temp_upload/env/__init__.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .models import Action, QueryResult, Observation, StepResult, EnvState
2
+ from .environment import SQLAnalystEnv
3
+ from .tasks import TASKS
4
+
5
+ __all__ = [
6
+ "Action",
7
+ "QueryResult",
8
+ "Observation",
9
+ "StepResult",
10
+ "EnvState",
11
+ "SQLAnalystEnv",
12
+ "TASKS",
13
+ ]
temp_upload/env/database.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import random
3
+ from datetime import datetime, timedelta
4
+ from typing import Optional, Any
5
+ from faker import Faker
6
+
7
+ fake = Faker()
8
+
9
+ SEED_CONFIG = {
10
+ "users": 500,
11
+ "products": 80,
12
+ "orders": 2000,
13
+ "order_items": 5000,
14
+ "events": 8000,
15
+ }
16
+
17
+ CATEGORIES = ["Electronics", "Clothing", "Books", "Home & Garden", "Sports"]
18
+ PLAN_TYPES = ["free", "pro", "enterprise"]
19
+ ORDER_STATUSES = ["pending", "completed", "refunded"]
20
+ EVENT_TYPES = ["page_view", "add_to_cart", "checkout", "login", "logout"]
21
+
22
+
23
+ def create_database(db_path: str = ":memory:") -> sqlite3.Connection:
24
+ conn = sqlite3.connect(db_path)
25
+ conn.row_factory = sqlite3.Row
26
+
27
+ conn.execute("""
28
+ CREATE TABLE users (
29
+ id INTEGER PRIMARY KEY,
30
+ email TEXT NOT NULL,
31
+ country TEXT,
32
+ plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
33
+ created_at TIMESTAMP NOT NULL,
34
+ churned_at TIMESTAMP
35
+ )
36
+ """)
37
+
38
+ conn.execute("""
39
+ CREATE TABLE products (
40
+ id INTEGER PRIMARY KEY,
41
+ name TEXT NOT NULL,
42
+ category TEXT NOT NULL,
43
+ price REAL,
44
+ cost REAL
45
+ )
46
+ """)
47
+
48
+ conn.execute("""
49
+ CREATE TABLE orders (
50
+ id INTEGER PRIMARY KEY,
51
+ user_id INTEGER REFERENCES users(id),
52
+ created_at TIMESTAMP NOT NULL,
53
+ status TEXT CHECK(status IN ('pending', 'completed', 'refunded')),
54
+ total REAL
55
+ )
56
+ """)
57
+
58
+ conn.execute("""
59
+ CREATE TABLE order_items (
60
+ id INTEGER PRIMARY KEY,
61
+ order_id INTEGER REFERENCES orders(id),
62
+ product_id INTEGER REFERENCES products(id),
63
+ qty INTEGER NOT NULL,
64
+ unit_price REAL
65
+ )
66
+ """)
67
+
68
+ conn.execute("""
69
+ CREATE TABLE events (
70
+ id INTEGER PRIMARY KEY,
71
+ user_id INTEGER REFERENCES users(id),
72
+ event_type TEXT,
73
+ metadata TEXT,
74
+ ts TIMESTAMP NOT NULL
75
+ )
76
+ """)
77
+
78
+ conn.commit()
79
+ return conn
80
+
81
+
82
+ def seed_database(conn: sqlite3.Connection) -> None:
83
+ users = _seed_users(conn)
84
+ products = _seed_products(conn)
85
+ orders, order_items = _seed_orders(conn, users, products)
86
+ _seed_events(conn, users, orders)
87
+
88
+
89
+ def _seed_users(conn: sqlite3.Connection) -> list:
90
+ users = []
91
+ now = datetime.now()
92
+ base_date = now - timedelta(days=180)
93
+ recent_date = now - timedelta(days=30)
94
+
95
+ for i in range(SEED_CONFIG["users"]):
96
+ if random.random() < 0.3:
97
+ created_at = recent_date + timedelta(days=random.randint(0, 30))
98
+ else:
99
+ created_at = base_date + timedelta(days=random.randint(0, 180))
100
+
101
+ country = random.choice([fake.country(), None, None, None, None])
102
+ plan = random.choice(PLAN_TYPES)
103
+ churned_at = None
104
+
105
+ if plan == "free" and random.random() < 0.1:
106
+ churned_at = created_at + timedelta(days=random.randint(30, 150))
107
+
108
+ conn.execute(
109
+ "INSERT INTO users (email, country, plan, created_at, churned_at) VALUES (?, ?, ?, ?, ?)",
110
+ (
111
+ fake.email(),
112
+ country,
113
+ plan,
114
+ created_at.isoformat(),
115
+ churned_at.isoformat() if churned_at else None,
116
+ ),
117
+ )
118
+ users.append((i + 1, created_at))
119
+
120
+ conn.commit()
121
+ return users
122
+
123
+
124
+ def _seed_products(conn: sqlite3.Connection) -> list:
125
+ products = []
126
+
127
+ for i in range(SEED_CONFIG["products"]):
128
+ category = random.choice(CATEGORIES)
129
+ price = round(random.uniform(10, 500), 2)
130
+ cost = round(price * random.uniform(0.3, 0.7), 2)
131
+
132
+ conn.execute(
133
+ "INSERT INTO products (name, category, price, cost) VALUES (?, ?, ?, ?)",
134
+ (fake.catch_phrase(), category, price, cost),
135
+ )
136
+ products.append((i + 1, category, price))
137
+
138
+ conn.commit()
139
+ return products
140
+
141
+
142
+ def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple:
143
+ orders = []
144
+ order_items = []
145
+
146
+ q3_start = datetime(2024, 7, 1)
147
+ q3_end = datetime(2024, 9, 30)
148
+ recent_date = datetime.now()
149
+ old_date = datetime(2024, 1, 1)
150
+
151
+ for i in range(SEED_CONFIG["orders"]):
152
+ user_id = random.choice(users)[0]
153
+
154
+ if random.random() < 0.2:
155
+ created_at = q3_start + timedelta(days=random.randint(0, 91))
156
+ else:
157
+ created_at = old_date + timedelta(days=random.randint(0, 180))
158
+
159
+ status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0]
160
+
161
+ conn.execute(
162
+ "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)",
163
+ (user_id, created_at.isoformat(), status, 0),
164
+ )
165
+
166
+ order_id = i + 1
167
+ order_total = 0
168
+
169
+ num_items = random.randint(1, 5)
170
+ for _ in range(num_items):
171
+ product = random.choice(products)
172
+ qty = random.randint(1, 3)
173
+ unit_price = product[2]
174
+ order_total += qty * unit_price
175
+
176
+ conn.execute(
177
+ "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)",
178
+ (order_id, product[0], qty, unit_price),
179
+ )
180
+
181
+ conn.execute(
182
+ "UPDATE orders SET total = ? WHERE id = ?",
183
+ (round(order_total, 2), order_id),
184
+ )
185
+ orders.append((order_id, user_id, created_at, status))
186
+
187
+ conn.commit()
188
+ return orders, order_items
189
+
190
+
191
+ def _seed_events(conn: sqlite3.Connection, users: list, orders: list) -> None:
192
+ base_date = datetime.now() - timedelta(days=180)
193
+
194
+ for _ in range(SEED_CONFIG["events"]):
195
+ user_id = random.choice(users)[0]
196
+ ts = base_date + timedelta(
197
+ days=random.randint(0, 180), hours=random.randint(0, 23)
198
+ )
199
+ event_type = random.choice(EVENT_TYPES)
200
+ metadata = '{"page": "/' + fake.uri_path() + '"}'
201
+
202
+ conn.execute(
203
+ "INSERT INTO events (user_id, event_type, metadata, ts) VALUES (?, ?, ?, ?)",
204
+ (user_id, event_type, metadata, ts.isoformat()),
205
+ )
206
+
207
+ conn.commit()
208
+
209
+
210
+ def get_schema_summary(conn: sqlite3.Connection) -> str:
211
+ cursor = conn.execute(
212
+ "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
213
+ )
214
+ tables = [r[0] for r in cursor.fetchall()]
215
+
216
+ lines = []
217
+ for table in tables:
218
+ cols = conn.execute(f"PRAGMA table_info({table})").fetchall()
219
+ col_names = [c[1] for c in cols]
220
+ lines.append(f"{table}: ({', '.join(col_names)})")
221
+
222
+ return "\n".join(lines)
223
+
224
+
225
+ def get_ground_truth(conn: sqlite3.Connection, task_id: str) -> Any:
226
+ if task_id == "monthly_signups":
227
+ result = conn.execute(
228
+ "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
229
+ ).fetchone()
230
+ return result[0]
231
+
232
+ elif task_id == "top_revenue_category":
233
+ result = conn.execute("""
234
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
235
+ FROM orders o
236
+ JOIN order_items oi ON o.id = oi.order_id
237
+ JOIN products p ON oi.product_id = p.id
238
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
239
+ AND o.status = 'completed'
240
+ GROUP BY p.category
241
+ ORDER BY revenue DESC
242
+ LIMIT 1
243
+ """).fetchone()
244
+ return result[0] if result else None
245
+
246
+ elif task_id == "churn_analysis":
247
+ result = conn.execute("""
248
+ WITH order_counts AS (
249
+ SELECT user_id, COUNT(*) AS total_orders,
250
+ MAX(created_at) AS last_order_date
251
+ FROM orders
252
+ WHERE status = 'completed'
253
+ GROUP BY user_id
254
+ HAVING COUNT(*) = 3
255
+ ),
256
+ churned AS (
257
+ SELECT oc.user_id
258
+ FROM order_counts oc
259
+ WHERE oc.last_order_date < DATE('now', '-90 days')
260
+ )
261
+ SELECT u.email
262
+ FROM users u
263
+ JOIN churned c ON u.id = c.user_id
264
+ """).fetchall()
265
+ return {row[0].lower() for row in result}
266
+
267
+ return None
temp_upload/env/environment.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import Optional
3
+ from .models import Action, Observation, StepResult, EnvState, QueryResult
4
+ from .database import create_database, seed_database, get_schema_summary
5
+ from .reward import RewardCalculator
6
+ from .tasks import TASKS
7
+
8
+
9
+ class SQLAnalystEnv:
10
+ """
11
+ OpenEnv-compliant SQL Data Analyst environment.
12
+
13
+ An agent must answer business questions by iteratively
14
+ writing and executing SQL queries.
15
+ """
16
+
17
+ def __init__(self, task_id: str = "monthly_signups"):
18
+ assert task_id in TASKS, f"Unknown task: {task_id}. Choose from {list(TASKS)}"
19
+ self.task_id = task_id
20
+ self.task = TASKS[task_id]
21
+ self.conn: Optional[sqlite3.Connection] = None
22
+ self.step_count: int = 0
23
+ self.total_reward: float = 0.0
24
+ self.done: bool = False
25
+ self._query_history: list = []
26
+ self._reward_calc = RewardCalculator()
27
+
28
+ def reset(self) -> StepResult:
29
+ """Reset environment. Reseed DB. Return initial observation."""
30
+ if self.conn:
31
+ self.conn.close()
32
+
33
+ self.conn = create_database()
34
+ seed_database(self.conn)
35
+ self.step_count = 0
36
+ self.total_reward = 0.0
37
+ self.done = False
38
+ self._query_history = []
39
+
40
+ self.task.compute_ground_truth(self.conn)
41
+
42
+ obs = Observation(
43
+ schema_summary=get_schema_summary(self.conn),
44
+ question=self.task.question,
45
+ step=0,
46
+ max_steps=self.task.max_steps,
47
+ )
48
+ return StepResult(observation=obs, reward=0.0, done=False)
49
+
50
+ def step(self, action: Action) -> StepResult:
51
+ """Execute one agent action. Return (observation, reward, done, info)."""
52
+ assert self.conn is not None, "Call reset() before step()"
53
+ assert not self.done, "Episode is done. Call reset()."
54
+ assert action.is_valid(), (
55
+ "Action must have exactly one of: sql_query, submit_answer"
56
+ )
57
+
58
+ self.step_count += 1
59
+ query_result = None
60
+ error = None
61
+
62
+ if action.sql_query:
63
+ query_result = self._execute_sql(action.sql_query)
64
+ self._query_history.append(action.sql_query)
65
+ error = query_result.error
66
+
67
+ terminal = (
68
+ action.submit_answer is not None or self.step_count >= self.task.max_steps
69
+ )
70
+
71
+ reward = self._reward_calc.calculate(
72
+ action=action,
73
+ result=query_result,
74
+ task=self.task,
75
+ step=self.step_count,
76
+ query_history=self._query_history,
77
+ terminal=terminal,
78
+ )
79
+ self.total_reward += reward
80
+ self.done = terminal
81
+
82
+ obs = Observation(
83
+ schema_summary=get_schema_summary(self.conn),
84
+ question=self.task.question,
85
+ last_query=action.sql_query,
86
+ last_result=query_result,
87
+ last_error=error,
88
+ step=self.step_count,
89
+ max_steps=self.task.max_steps,
90
+ hints=self.task.get_hints(self.step_count),
91
+ done=self.done,
92
+ )
93
+
94
+ return StepResult(
95
+ observation=obs,
96
+ reward=round(reward, 3),
97
+ done=self.done,
98
+ info={
99
+ "step": self.step_count,
100
+ "total_reward": round(self.total_reward, 3),
101
+ "task_id": self.task_id,
102
+ },
103
+ )
104
+
105
+ def state(self) -> EnvState:
106
+ """Return current full state of the environment."""
107
+ return EnvState(
108
+ task_id=self.task_id,
109
+ difficulty=self.task.difficulty,
110
+ step=self.step_count,
111
+ max_steps=self.task.max_steps,
112
+ query_history=self._query_history.copy(),
113
+ total_reward=round(self.total_reward, 3),
114
+ done=self.done,
115
+ )
116
+
117
+ def _execute_sql(self, query: str) -> QueryResult:
118
+ """Execute SQL safely. Block non-SELECT. Return up to 50 rows."""
119
+ q = query.strip().upper()
120
+ if not q.startswith("SELECT") and not q.startswith("WITH"):
121
+ return QueryResult(error="Only SELECT / WITH queries are allowed.")
122
+ try:
123
+ cursor = self.conn.execute(query)
124
+ cols = [d[0] for d in cursor.description] if cursor.description else []
125
+ rows = cursor.fetchmany(50)
126
+ total = len(rows)
127
+ return QueryResult(
128
+ columns=cols,
129
+ rows=[list(r) for r in rows],
130
+ truncated=(total == 50),
131
+ total_rows=total,
132
+ )
133
+ except Exception as e:
134
+ return QueryResult(error=str(e))
temp_upload/env/models.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
+ from typing import Optional, List, Any
3
+
4
+ class Action(BaseModel):
5
+ """What the agent can do each step."""
6
+ sql_query: Optional[str] = Field(
7
+ None,
8
+ description="A SQL SELECT query to execute against the database"
9
+ )
10
+ submit_answer: Optional[str] = Field(
11
+ None,
12
+ description="Final answer to submit. Ends the episode."
13
+ )
14
+
15
+ def is_valid(self) -> bool:
16
+ # Exactly one of the two must be set
17
+ return bool(self.sql_query) != bool(self.submit_answer)
18
+
19
+
20
+ class QueryResult(BaseModel):
21
+ """Result of executing a SQL query."""
22
+ columns: List[str] = []
23
+ rows: List[List[Any]] = []
24
+ error: Optional[str] = None
25
+ truncated: bool = False
26
+ total_rows: int = 0
27
+
28
+
29
+ class Observation(BaseModel):
30
+ """What the agent sees after each step."""
31
+ schema_summary: str = Field(..., description="Compact DB schema")
32
+ question: str = Field(..., description="Business question to answer")
33
+ last_query: Optional[str] = None
34
+ last_result: Optional[QueryResult] = None
35
+ last_error: Optional[str] = None
36
+ step: int = 0
37
+ max_steps: int = 20
38
+ hints: List[str] = []
39
+ done: bool = False
40
+
41
+
42
+ class StepResult(BaseModel):
43
+ """Full result returned by step()."""
44
+ observation: Observation
45
+ reward: float = 0.0
46
+ done: bool = False
47
+ info: dict = {}
48
+
49
+
50
+ class EnvState(BaseModel):
51
+ """Full environment state returned by state()."""
52
+ task_id: str
53
+ difficulty: str
54
+ step: int
55
+ max_steps: int
56
+ query_history: List[str] = []
57
+ total_reward: float = 0.0
58
+ done: bool = False
temp_upload/env/reward.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional, List, Any
2
+ from .models import Action, QueryResult
3
+
4
+
5
+ class RewardCalculator:
6
+ """Calculate rewards for agent actions in the SQL analyst environment."""
7
+
8
+ def calculate(
9
+ self,
10
+ action: Action,
11
+ result: Optional[QueryResult],
12
+ task: Any,
13
+ step: int,
14
+ query_history: List[str],
15
+ terminal: bool,
16
+ ) -> float:
17
+ """Calculate reward based on action, result, and task."""
18
+ reward = 0.0
19
+
20
+ if action.sql_query and result:
21
+ if not result.error:
22
+ reward += 0.15
23
+
24
+ relevant = self._count_relevant_tables(
25
+ action.sql_query, task.relevant_tables
26
+ )
27
+ if relevant > 0:
28
+ reward += 0.10
29
+
30
+ if result.rows and len(result.rows) > 0:
31
+ reward += 0.05
32
+
33
+ if result.rows and len(result.rows) < 1000:
34
+ reward += 0.05
35
+
36
+ if step > 3:
37
+ reward -= 0.02 * (step - 3)
38
+
39
+ if self._is_stuck(query_history):
40
+ reward -= 0.10
41
+
42
+ if terminal and action.submit_answer:
43
+ task_score = task.grade(action.submit_answer)
44
+ reward += task_score * 0.60
45
+
46
+ return max(0.0, min(1.0, reward))
47
+
48
+ def _count_relevant_tables(self, query: str, relevant_tables: List[str]) -> int:
49
+ query_lower = query.lower()
50
+ return sum(1 for t in relevant_tables if t.lower() in query_lower)
51
+
52
+ def _is_stuck(self, history: List[str]) -> bool:
53
+ if len(history) < 3:
54
+ return False
55
+ return len(set(history[-3:])) == 1
temp_upload/env/server.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI server for SQL Data Analyst OpenEnv.
3
+
4
+ Provides REST and WebSocket endpoints for HuggingFace Spaces deployment.
5
+ """
6
+
7
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
8
+ from pydantic import BaseModel
9
+ from typing import Optional, Dict, Any
10
+ import json
11
+ import asyncio
12
+
13
+ from env import SQLAnalystEnv, Action
14
+
15
+
16
+ app = FastAPI(title="SQL Data Analyst Environment")
17
+
18
+ envs: Dict[str, SQLAnalystEnv] = {}
19
+
20
+
21
+ class ResetRequest(BaseModel):
22
+ task_id: str = "monthly_signups"
23
+
24
+
25
+ class StepRequest(BaseModel):
26
+ session_id: str
27
+ sql_query: Optional[str] = None
28
+ submit_answer: Optional[str] = None
29
+
30
+
31
+ class StateRequest(BaseModel):
32
+ session_id: str
33
+
34
+
35
+ @app.get("/")
36
+ async def root():
37
+ return {
38
+ "name": "sql-data-analyst",
39
+ "version": "1.0.0",
40
+ "description": "SQL Data Analyst OpenEnv - RL environment for SQL query generation",
41
+ }
42
+
43
+
44
+ @app.post("/reset")
45
+ async def reset(req: ResetRequest) -> Dict[str, Any]:
46
+ session_id = req.task_id
47
+
48
+ env = SQLAnalystEnv(task_id=req.task_id)
49
+ result = env.reset()
50
+ envs[session_id] = env
51
+
52
+ return {
53
+ "session_id": session_id,
54
+ "observation": {
55
+ "schema_summary": result.observation.schema_summary,
56
+ "question": result.observation.question,
57
+ "step": result.observation.step,
58
+ "max_steps": result.observation.max_steps,
59
+ "hints": result.observation.hints,
60
+ "done": result.observation.done,
61
+ },
62
+ "reward": result.reward,
63
+ "done": result.done,
64
+ }
65
+
66
+
67
+ @app.post("/step")
68
+ async def step(req: StepRequest) -> Dict[str, Any]:
69
+ session_id = req.session_id
70
+
71
+ if session_id not in envs:
72
+ return {"error": "Session not found. Call /reset first."}
73
+
74
+ env = envs[session_id]
75
+
76
+ action = Action(sql_query=req.sql_query, submit_answer=req.submit_answer)
77
+
78
+ result = env.step(action)
79
+
80
+ return {
81
+ "observation": {
82
+ "schema_summary": result.observation.schema_summary,
83
+ "question": result.observation.question,
84
+ "last_query": result.observation.last_query,
85
+ "last_result": {
86
+ "columns": result.observation.last_result.columns
87
+ if result.observation.last_result
88
+ else None,
89
+ "rows": result.observation.last_result.rows
90
+ if result.observation.last_result
91
+ else None,
92
+ "error": result.observation.last_result.error
93
+ if result.observation.last_result
94
+ else None,
95
+ }
96
+ if result.observation.last_result
97
+ else None,
98
+ "last_error": result.observation.last_error,
99
+ "step": result.observation.step,
100
+ "max_steps": result.observation.max_steps,
101
+ "hints": result.observation.hints,
102
+ "done": result.observation.done,
103
+ },
104
+ "reward": result.reward,
105
+ "done": result.done,
106
+ "info": result.info,
107
+ }
108
+
109
+
110
+ @app.post("/state")
111
+ async def state(req: StateRequest) -> Dict[str, Any]:
112
+ session_id = req.session_id
113
+
114
+ if session_id not in envs:
115
+ return {"error": "Session not found. Call /reset first."}
116
+
117
+ env = envs[session_id]
118
+ state = env.state()
119
+
120
+ return {
121
+ "task_id": state.task_id,
122
+ "difficulty": state.difficulty,
123
+ "step": state.step,
124
+ "max_steps": state.max_steps,
125
+ "query_history": state.query_history,
126
+ "total_reward": state.total_reward,
127
+ "done": state.done,
128
+ }
129
+
130
+
131
+ @app.post("/delete")
132
+ async def delete_session(req: StateRequest) -> Dict[str, str]:
133
+ session_id = req.session_id
134
+
135
+ if session_id in envs:
136
+ del envs[session_id]
137
+ return {"status": "deleted", "session_id": session_id}
138
+
139
+ return {"status": "not_found", "session_id": session_id}
140
+
141
+
142
+ @app.websocket("/ws")
143
+ async def websocket_endpoint(websocket: WebSocket):
144
+ await websocket.accept()
145
+
146
+ session_id = None
147
+ env = None
148
+
149
+ try:
150
+ while True:
151
+ data = await websocket.receive_text()
152
+ message = json.loads(data)
153
+
154
+ action_type = message.get("type")
155
+
156
+ if action_type == "reset":
157
+ task_id = message.get("task_id", "monthly_signups")
158
+ env = SQLAnalystEnv(task_id=task_id)
159
+ result = env.reset()
160
+ session_id = task_id
161
+ envs[session_id] = env
162
+
163
+ await websocket.send_json(
164
+ {
165
+ "type": "reset",
166
+ "observation": {
167
+ "schema_summary": result.observation.schema_summary,
168
+ "question": result.observation.question,
169
+ "step": result.observation.step,
170
+ "max_steps": result.observation.max_steps,
171
+ "hints": result.observation.hints,
172
+ },
173
+ "reward": result.reward,
174
+ "done": result.done,
175
+ }
176
+ )
177
+
178
+ elif action_type == "step":
179
+ if not env:
180
+ await websocket.send_json({"error": "Call reset first"})
181
+ continue
182
+
183
+ action = Action(
184
+ sql_query=message.get("sql_query"),
185
+ submit_answer=message.get("submit_answer"),
186
+ )
187
+
188
+ result = env.step(action)
189
+
190
+ await websocket.send_json(
191
+ {
192
+ "type": "step",
193
+ "observation": {
194
+ "schema_summary": result.observation.schema_summary,
195
+ "question": result.observation.question,
196
+ "last_query": result.observation.last_query,
197
+ "last_result": {
198
+ "columns": result.observation.last_result.columns
199
+ if result.observation.last_result
200
+ else None,
201
+ "rows": result.observation.last_result.rows
202
+ if result.observation.last_result
203
+ else None,
204
+ "error": result.observation.last_result.error
205
+ if result.observation.last_result
206
+ else None,
207
+ }
208
+ if result.observation.last_result
209
+ else None,
210
+ "step": result.observation.step,
211
+ "hints": result.observation.hints,
212
+ "done": result.observation.done,
213
+ },
214
+ "reward": result.reward,
215
+ "done": result.done,
216
+ "info": result.info,
217
+ }
218
+ )
219
+
220
+ elif action_type == "state":
221
+ if not env:
222
+ await websocket.send_json({"error": "Call reset first"})
223
+ continue
224
+
225
+ state = env.state()
226
+
227
+ await websocket.send_json(
228
+ {
229
+ "type": "state",
230
+ "task_id": state.task_id,
231
+ "difficulty": state.difficulty,
232
+ "step": state.step,
233
+ "max_steps": state.max_steps,
234
+ "query_history": state.query_history,
235
+ "total_reward": state.total_reward,
236
+ "done": state.done,
237
+ }
238
+ )
239
+
240
+ elif action_type == "close":
241
+ if session_id and session_id in envs:
242
+ del envs[session_id]
243
+ break
244
+
245
+ except WebSocketDisconnect:
246
+ pass
247
+ except Exception as e:
248
+ await websocket.send_json({"error": str(e)})
249
+
250
+
251
+ if __name__ == "__main__":
252
+ import uvicorn
253
+
254
+ uvicorn.run(app, host="0.0.0.0", port=7860)
temp_upload/env/tasks/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .base import BaseTask
2
+ from .easy import MonthlySignupsTask
3
+ from .medium import TopRevenueCategoryTask
4
+ from .hard import ChurnAnalysisTask
5
+
6
+
7
+ TASKS = {
8
+ "monthly_signups": MonthlySignupsTask(),
9
+ "top_revenue_category": TopRevenueCategoryTask(),
10
+ "churn_analysis": ChurnAnalysisTask(),
11
+ }
12
+
13
+
14
+ __all__ = [
15
+ "BaseTask",
16
+ "MonthlySignupsTask",
17
+ "TopRevenueCategoryTask",
18
+ "ChurnAnalysisTask",
19
+ "TASKS",
20
+ ]
temp_upload/env/tasks/base.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ import sqlite3
3
+ import re
4
+ from typing import Any, List, Optional
5
+
6
+
7
+ class BaseTask(ABC):
8
+ """Abstract base class for all tasks."""
9
+
10
+ task_id: str
11
+ difficulty: str
12
+ max_steps: int
13
+ question: str
14
+ relevant_tables: List[str]
15
+ sql_hint: str
16
+
17
+ def __init__(self):
18
+ self.ground_truth: Any = None
19
+ self.top_3_categories: List[str] = []
20
+
21
+ @abstractmethod
22
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
23
+ """Compute ground truth after database seeding."""
24
+ pass
25
+
26
+ @abstractmethod
27
+ def grade(self, submitted_answer: str) -> float:
28
+ """Grade the submitted answer. Returns score 0.0-1.0."""
29
+ pass
30
+
31
+ def get_hints(self, step: int) -> List[str]:
32
+ """Return progressive hints based on current step."""
33
+ hints = []
34
+ if step > 5:
35
+ hints.append(
36
+ f"Hint: The relevant tables are: {', '.join(self.relevant_tables)}"
37
+ )
38
+ if step > 10:
39
+ hints.append(f"Hint: Try using {self.sql_hint}")
40
+ if step > 15:
41
+ hints.append("Hint: Make sure to submit your answer with submit_answer.")
42
+ return hints
43
+
44
+ def _normalize(self, text: str) -> str:
45
+ """Remove common LLM formatting and normalize text."""
46
+ text = text.strip().lower()
47
+ text = re.sub(r"the (answer|result|category) is:?\s*", "", text)
48
+ text = re.sub(r"\*+", "", text)
49
+ text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
50
+ text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
51
+ text = re.sub(r"\s+", " ", text)
52
+ return text.strip()
temp_upload/env/tasks/easy.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from .base import BaseTask
3
+
4
+
5
+ class MonthlySignupsTask(BaseTask):
6
+ """Task 1 — Easy: Count users who signed up in the last 30 days."""
7
+
8
+ task_id = "monthly_signups"
9
+ difficulty = "easy"
10
+ max_steps = 10
11
+ question = "How many users signed up in the last 30 days?"
12
+ relevant_tables = ["users"]
13
+ sql_hint = "COUNT(*) with WHERE clause on created_at"
14
+
15
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
16
+ result = conn.execute(
17
+ "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
18
+ ).fetchone()
19
+ self.ground_truth = result[0] if result else 0
20
+
21
+ def grade(self, submitted_answer: str) -> float:
22
+ try:
23
+ val = int(submitted_answer.strip().replace(",", ""))
24
+ if val == self.ground_truth:
25
+ return 1.0
26
+ if abs(val - self.ground_truth) <= 3:
27
+ return 0.6
28
+ if abs(val - self.ground_truth) <= 10:
29
+ return 0.3
30
+ except (ValueError, AttributeError):
31
+ pass
32
+ return 0.0
temp_upload/env/tasks/hard.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from .base import BaseTask
3
+
4
+
5
+ class ChurnAnalysisTask(BaseTask):
6
+ """Task 3 — Hard: Find users who placed exactly 3 orders and then churned."""
7
+
8
+ task_id = "churn_analysis"
9
+ difficulty = "hard"
10
+ max_steps = 20
11
+ question = "Find the email addresses of users who placed exactly 3 orders and then never ordered again (churned after their 3rd purchase). Return as a comma-separated list."
12
+ relevant_tables = ["users", "orders"]
13
+ sql_hint = "CTE with COUNT and HAVING"
14
+
15
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
16
+ result = conn.execute("""
17
+ WITH order_counts AS (
18
+ SELECT user_id, COUNT(*) AS total_orders,
19
+ MAX(created_at) AS last_order_date
20
+ FROM orders
21
+ WHERE status = 'completed'
22
+ GROUP BY user_id
23
+ HAVING COUNT(*) = 3
24
+ ),
25
+ churned AS (
26
+ SELECT oc.user_id
27
+ FROM order_counts oc
28
+ WHERE oc.last_order_date < DATE('now', '-90 days')
29
+ )
30
+ SELECT u.email
31
+ FROM users u
32
+ JOIN churned c ON u.id = c.user_id
33
+ """).fetchall()
34
+
35
+ self.ground_truth = {row[0].lower() for row in result}
36
+
37
+ def grade(self, submitted_answer: str) -> float:
38
+ if not submitted_answer.strip():
39
+ return 0.0
40
+
41
+ submitted = {e.strip().lower() for e in submitted_answer.split(",") if "@" in e}
42
+
43
+ if not submitted:
44
+ return 0.0
45
+
46
+ correct = {e.lower() for e in self.ground_truth}
47
+ tp = len(submitted & correct)
48
+
49
+ if tp == 0:
50
+ return 0.0
51
+
52
+ precision = tp / len(submitted) if submitted else 0
53
+ recall = tp / len(correct) if correct else 0
54
+
55
+ if precision + recall == 0:
56
+ return 0.0
57
+
58
+ f1 = 2 * precision * recall / (precision + recall)
59
+ return round(f1, 3)
temp_upload/env/tasks/medium.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ from .base import BaseTask
3
+
4
+
5
+ class TopRevenueCategoryTask(BaseTask):
6
+ """Task 2 — Medium: Find product category with most revenue in Q3."""
7
+
8
+ task_id = "top_revenue_category"
9
+ difficulty = "medium"
10
+ max_steps = 15
11
+ question = (
12
+ "Which product category generated the most revenue in Q3 (July-September)?"
13
+ )
14
+ relevant_tables = ["orders", "order_items", "products"]
15
+ sql_hint = "JOIN with GROUP BY and ORDER BY"
16
+
17
+ def compute_ground_truth(self, conn: sqlite3.Connection) -> None:
18
+ result = conn.execute("""
19
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
20
+ FROM orders o
21
+ JOIN order_items oi ON o.id = oi.order_id
22
+ JOIN products p ON oi.product_id = p.id
23
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
24
+ AND o.status = 'completed'
25
+ GROUP BY p.category
26
+ ORDER BY revenue DESC
27
+ LIMIT 1
28
+ """).fetchone()
29
+
30
+ self.ground_truth = result[0] if result else None
31
+
32
+ all_categories = conn.execute("""
33
+ SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
34
+ FROM orders o
35
+ JOIN order_items oi ON o.id = oi.order_id
36
+ JOIN products p ON oi.product_id = p.id
37
+ WHERE o.created_at BETWEEN '2024-07-01' AND '2024-09-30'
38
+ AND o.status = 'completed'
39
+ GROUP BY p.category
40
+ ORDER BY revenue DESC
41
+ """).fetchall()
42
+
43
+ self.top_3_categories = [row[0] for row in all_categories[:3]]
44
+
45
+ def grade(self, submitted_answer: str) -> float:
46
+ answer = self._normalize(submitted_answer)
47
+
48
+ if self.ground_truth and self.ground_truth.lower() in answer:
49
+ return 1.0
50
+
51
+ if any(cat.lower() in answer for cat in self.top_3_categories):
52
+ return 0.4
53
+
54
+ return 0.0
temp_upload/env/utils.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+
4
+ def normalize_answer(raw: str) -> str:
5
+ """Remove common LLM answer preambles and formatting."""
6
+ text = raw.strip().lower()
7
+ text = re.sub(r"the (answer|result) is:?\s*", "", text)
8
+ text = re.sub(r"\*+", "", text)
9
+ text = re.sub(r"```.*?```", "", text, flags=re.DOTALL)
10
+ text = re.sub(r"`[^`]+`", lambda m: m.group().strip("`"), text)
11
+ text = re.sub(r"\s+", " ", text)
12
+ return text.strip()
13
+
14
+
15
+ FORBIDDEN_KEYWORDS = [
16
+ "DROP",
17
+ "DELETE",
18
+ "INSERT",
19
+ "UPDATE",
20
+ "ALTER",
21
+ "CREATE",
22
+ "TRUNCATE",
23
+ ]
24
+
25
+
26
+ def is_safe_query(query: str) -> bool:
27
+ """Check if query is safe (SELECT-only)."""
28
+ upper = query.upper()
29
+ return not any(kw in upper for kw in FORBIDDEN_KEYWORDS)
temp_upload/server/__init__.py ADDED
File without changes
temp_upload/server/app.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from env.server import app as _app
2
+ import uvicorn
3
+
4
+
5
+ def main():
6
+ uvicorn.run(_app, host="0.0.0.0", port=7860)
7
+
8
+
9
+ if __name__ == "__main__":
10
+ main()
11
+
12
+ __all__ = ["app", "main"]
temp_upload/tests/__init__.py ADDED
File without changes