Viani committed · Commit bcd8636 · verified · 1 Parent(s): 5e4b568

Deploy DataDetective: 9-task business investigation environment

.env.example ADDED
@@ -0,0 +1,11 @@
+ # LLM Configuration (required by hackathon evaluator)
+ API_BASE_URL=https://router.huggingface.co/v1
+ MODEL_NAME=gpt-4.1-mini
+ HF_TOKEN=hf_your_token_here
+
+ # Environment server
+ ENV_URL=http://localhost:7860
+
+ # AMD LLM Gateway (local development only — overrides API_BASE_URL when set)
+ # AMD_LLM_API_KEY=your-ocp-apim-subscription-key-here
+ # AMD_GATEWAY_BASE=https://llm-api.amd.com/openai
.gitignore ADDED
@@ -0,0 +1,12 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.egg-info/
+ dist/
+ build/
+ .env
+ .venv/
+ venv/
+ *.sqlite
+ *.db
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,12 @@
+ FROM python:3.11-slim
+
+ WORKDIR /app
+
+ COPY server/requirements.txt ./requirements.txt
+ RUN pip install --no-cache-dir -r requirements.txt
+
+ COPY . .
+
+ EXPOSE 7860
+
+ CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "2", "--ws-ping-interval", "300", "--ws-ping-timeout", "300"]
README.md CHANGED
@@ -1,10 +1,138 @@
  ---
  title: DataDetective
- emoji: 🐠
+ emoji: 🔍
  colorFrom: blue
  colorTo: green
  sdk: docker
- pinned: false
+ app_port: 7860
  ---
 
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # DataDetective Business Incident Investigation Environment
+
+ An [OpenEnv](https://github.com/meta-pytorch/OpenEnv) environment where AI
+ agents investigate real-world business incidents by querying a SQL database,
+ analysing patterns, and submitting root-cause findings.
+
+ ## What It Does
+
+ The agent is given a realistic company database (TechMart — a mid-size B2B+B2C
+ electronics retailer) and a business problem to investigate. It can execute
+ SQL queries to explore the data, then submit a final written analysis. The
+ environment automatically grades the analysis based on whether key findings
+ were identified. Each task has 5 grading criteria worth 0.20 each, enabling
+ meaningful partial credit.
+
+ ## Tasks (Easy → Hard)
+
+ | # | Task ID | Difficulty | Scenario |
+ |---|---------|-----------|----------|
+ | 1 | `orders_drop` | Easy | Order volume dropped sharply after promo ended |
+ | 2 | `returns_spike` | Medium | Product returns spiking in West region (defective SKU) |
+ | 3 | `supplier_quality` | Medium | Supplier-level quality crisis across multiple products |
+ | 4 | `shipping_delay` | Medium-Hard | Customer satisfaction crisis from carrier delays |
+ | 5 | `inventory_stockout` | Medium-Hard | Regional sales underperformance from warehouse stockout |
+ | 6 | `customer_churn` | Hard | Active customer decline across segments post price hike |
+ | 7 | `revenue_paradox` | Hard | Revenue up but profit down — multi-causal margin erosion |
+ | 8 | `fraud_detection` | Hard | Coordinated fraud ring with fake accounts |
+ | 9 | `repeat_purchase_decline` | Hard | Repeat purchase collapse masked by acquisition spend |
+
+ Each task is scored 0.0 – 1.0 based on specific findings the agent must discover.
+
+ ## Action / Observation Spaces
+
+ ### Action (`DataDetectiveAction`)
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `action_type` | `str` | `"query"` to run SQL, `"answer"` to submit findings |
+ | `content` | `str` | SQL query string or final analysis text |
+
+ ### Observation (`DataDetectiveObservation`)
+
+ | Field | Type | Description |
+ |-------|------|-------------|
+ | `output` | `str` | Query results (formatted table) or feedback |
+ | `task_description` | `str` | The investigation task |
+ | `schema_info` | `str` | Database schema (shown at reset) |
+ | `step_number` | `int` | Current step |
+ | `max_steps` | `int` | Maximum steps allowed (30) |
+ | `message` | `str` | Status message |
+
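The two action types above travel as plain JSON objects. A minimal sketch of the wire format (the SQL text and analysis text below are illustrative, not taken from the repo):

```python
import json

# Illustrative payloads matching the documented action fields.
query_action = {"action_type": "query", "content": "SELECT COUNT(*) FROM orders"}
answer_action = {"action_type": "answer", "content": "Order volume fell after the promo ended."}

for action in (query_action, answer_action):
    encoded = json.dumps(action)          # what the agent sends
    decoded = json.loads(encoded)         # what the server parses
    assert set(decoded) == {"action_type", "content"}
    assert decoded["action_type"] in {"query", "answer"}

print(query_action["action_type"])  # → query
```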
+ ## Database Schema (11 Tables)
+
+ The TechMart database includes:
+
+ | Table | Description |
+ |-------|-------------|
+ | `customers` | Customer demographics (region, segment, signup date) |
+ | `products` | Product catalog (category, price, cost, supplier) |
+ | `orders` | Order history with totals |
+ | `order_items` | Line items with quantity and unit price |
+ | `returns` | Product returns with reasons and refund amounts |
+ | `promotions` | Promotional campaigns with discount percentages |
+ | `price_changes` | Historical price adjustments |
+ | `shipping` | Shipment records with carrier and delivery dates |
+ | `support_tickets` | Customer support tickets by category and priority |
+ | `inventory_log` | Daily stock levels per product per warehouse region |
+ | `marketing_spend` | Daily marketing spend by channel, campaign, and region |
+
+ All data is synthetic, generated in-memory (no external databases required).
+
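The in-memory approach can be sketched with Python's stdlib `sqlite3` (a toy stand-in for the real generator in `server/database.py`; the table and rows here are invented for illustration):

```python
import sqlite3

# Toy stand-in for the real generator: an in-memory SQLite database
# with a miniature orders table, no server or file on disk.
conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE orders (order_id INTEGER PRIMARY KEY, order_date TEXT, total_amount REAL)"
)
conn.executemany(
    "INSERT INTO orders (order_date, total_amount) VALUES (?, ?)",
    [("2024-02-20", 120.0), ("2024-02-25", 80.0), ("2024-03-05", 25.0)],
)

# The kind of before/after comparison an agent runs when a promo ends.
(promo_total,) = conn.execute(
    "SELECT COUNT(*) FROM orders WHERE order_date <= '2024-03-01'"
).fetchone()
print(promo_total)  # → 2
```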
+ ## Quick Start
+
+ ### 1. Install Dependencies
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ ### 2. Start the Server
+
+ ```bash
+ uvicorn server.app:app --host 0.0.0.0 --port 7860
+ ```
+
+ ### 3. Health Check
+
+ ```bash
+ curl http://localhost:7860/health
+ ```
+
+ ### 4. Run the Baseline Agent
+
+ ```bash
+ API_BASE_URL="https://router.huggingface.co/v1" \
+ MODEL_NAME="gpt-4.1-mini" \
+ HF_TOKEN="hf_..." \
+ python inference.py
+ ```
+
+ ### 5. Docker
+
+ ```bash
+ docker build -t data-detective .
+ docker run -p 7860:7860 data-detective
+ ```
+
+ ## Environment Variables
+
+ | Env Var | Purpose | Required |
+ |---------|---------|----------|
+ | `API_BASE_URL` | LLM endpoint URL | Yes |
+ | `MODEL_NAME` | Model identifier | Yes |
+ | `HF_TOKEN` | API key / HF token | Yes |
+ | `ENV_URL` | Environment server URL | No (default: `http://localhost:7860`) |
+
+ ## How Grading Works
+
+ Each task has an automated grader that checks the agent's final answer for
+ specific key findings (keywords, patterns, named entities). Each task has 5
+ grading criteria worth 0.20 each, for a maximum score of 1.0. Partial credit
+ is awarded for each finding discovered.
+
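A grader of this shape could be sketched as follows. This is a hedged illustration, not the repo's actual grading code: the criteria, keywords, and answer text are all invented, and the real graders are task-specific and live server-side.

```python
# Hypothetical keyword-criterion grader: 5 criteria worth 0.20 each.
CRITERIA = [
    ["promo", "promotion"],            # each criterion accepts several keywords
    ["spring mega sale"],
    ["discount"],
    ["order volume", "orders dropped"],
    ["march"],
]

def grade(answer: str, criteria=CRITERIA) -> float:
    """Score 0.20 per criterion whose keywords appear in the answer."""
    text = answer.lower()
    hits = sum(any(kw in text for kw in group) for group in criteria)
    return round(hits * 0.20, 2)

score = grade(
    "Order volume dropped after the Spring Mega Sale promotion ended on "
    "2024-03-01; the 25% discount had pulled demand forward into late February."
)
print(score)  # → 0.8 (4 of 5 criteria matched)
```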
+ ## Setup Requirements
+
+ - Python 3.10+
+ - No GPU required
+ - Runs within 2 vCPU / 8 GB memory
+ - All data is generated in-memory (no external databases)
__init__.py ADDED
@@ -0,0 +1,9 @@
+ from .models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
+ from .client import DataDetectiveEnv
+
+ __all__ = [
+     "DataDetectiveAction",
+     "DataDetectiveObservation",
+     "DataDetectiveState",
+     "DataDetectiveEnv",
+ ]
client.py ADDED
@@ -0,0 +1,51 @@
+ """WebSocket client for the DataDetective environment."""
+
+ from typing import Dict
+
+ from openenv.core.env_client import EnvClient
+ from openenv.core.client_types import StepResult
+
+ from .models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
+
+
+ class DataDetectiveEnv(
+     EnvClient[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
+ ):
+     """
+     Async/sync client for DataDetective.
+
+     Example (sync):
+         >>> with DataDetectiveEnv(base_url="http://localhost:7860").sync() as env:
+         ...     result = env.reset(task_id="orders_drop")
+         ...     result = env.step(DataDetectiveAction(action_type="query", content="SELECT COUNT(*) FROM orders"))
+     """
+
+     def _step_payload(self, action: DataDetectiveAction) -> Dict:
+         return {"action_type": action.action_type, "content": action.content}
+
+     def _parse_result(self, payload: Dict) -> StepResult[DataDetectiveObservation]:
+         obs = payload.get("observation", {})
+         observation = DataDetectiveObservation(
+             output=obs.get("output", ""),
+             task_description=obs.get("task_description", ""),
+             schema_info=obs.get("schema_info", ""),
+             step_number=obs.get("step_number", 0),
+             max_steps=obs.get("max_steps", 30),
+             message=obs.get("message", ""),
+             done=payload.get("done", False),
+             reward=payload.get("reward"),
+         )
+         return StepResult(
+             observation=observation,
+             reward=payload.get("reward"),
+             done=payload.get("done", False),
+         )
+
+     def _parse_state(self, payload: Dict) -> DataDetectiveState:
+         return DataDetectiveState(
+             episode_id=payload.get("episode_id"),
+             step_count=payload.get("step_count", 0),
+             task_id=payload.get("task_id", ""),
+             queries_executed=payload.get("queries_executed", 0),
+             max_steps=payload.get("max_steps", 30),
+         )
inference.py ADDED
@@ -0,0 +1,280 @@
+ #!/usr/bin/env python3
+ """
+ Baseline inference script for DataDetective.
+
+ Uses an LLM via the OpenAI-compatible API to investigate each task by
+ running SQL queries and submitting a final analysis.
+
+ Required environment variables (set by hackathon evaluator):
+     API_BASE_URL — LLM endpoint (e.g. https://router.huggingface.co/v1)
+     MODEL_NAME   — model identifier (e.g. gpt-4.1-mini)
+     HF_TOKEN     — API key / Hugging Face token
+
+ Optional:
+     ENV_URL          — DataDetective server URL (default http://localhost:7860)
+     AMD_LLM_API_KEY  — If set, uses AMD Gateway instead (local dev only)
+ """
+
+ import asyncio
+ import json
+ import os
+ import re
+ import sys
+ import time
+
+ from openai import AzureOpenAI, OpenAI
+
+ # Patch the websockets client defaults so long LLM calls don't trip keepalives.
+ import websockets.asyncio.client as _wsc
+ _orig_ws_connect = _wsc.connect
+ def _patched_connect(*a, **kw):
+     kw.setdefault("ping_interval", 300)
+     kw.setdefault("ping_timeout", 300)
+     return _orig_ws_connect(*a, **kw)
+ _wsc.connect = _patched_connect
+
+ import openenv.core.env_client as _ec
+ _ec.ws_connect = _patched_connect
+
+ from openenv.core.generic_client import GenericEnvClient
+
+ # ---------------------------------------------------------------------------
+ # Configuration
+ # ---------------------------------------------------------------------------
+
+ API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
+ MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
+ HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
+ AMD_LLM_API_KEY = os.environ.get("AMD_LLM_API_KEY", "")
+ ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")
+
+ BENCHMARK = "data_detective"
+ MAX_STEPS = 20
+
+ TASK_IDS = [
+     "orders_drop",
+     "returns_spike",
+     "customer_churn",
+     "shipping_delay",
+     "revenue_paradox",
+     "supplier_quality",
+     "inventory_stockout",
+     "fraud_detection",
+     "repeat_purchase_decline",
+ ]
+
+
+ def _build_llm_client() -> OpenAI:
+     if AMD_LLM_API_KEY:
+         return AzureOpenAI(
+             api_key="dummy",
+             api_version="2024-02-01",
+             base_url=os.environ.get("AMD_GATEWAY_BASE", "https://llm-api.amd.com/openai"),
+             default_headers={"Ocp-Apim-Subscription-Key": AMD_LLM_API_KEY},
+         )
+
+     if HF_TOKEN:
+         return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
+
+     print(
+         "ERROR: Set HF_TOKEN (or API_KEY) for LLM access, "
+         "or AMD_LLM_API_KEY for AMD Gateway. Exiting.",
+         file=sys.stderr,
+     )
+     sys.exit(1)
+
+
+ llm = _build_llm_client()
+
+ SYSTEM_PROMPT = """\
+ You are an expert data analyst investigating a business incident using a
+ SQL database. You have a LIMITED number of query steps, so be strategic.
+
+ At each turn respond with EXACTLY one JSON object (no extra text):
+
+ {{"action_type": "query", "content": "<SQL query>"}}
+ {{"action_type": "answer", "content": "<your analysis>"}}
+
+ Investigation strategy:
+ 1. EXPLORE (1-2 queries): List tables and sample key columns to understand
+    the schema. Note all available tables -- some may hold critical clues.
+ 2. HYPOTHESISE: Based on the task description, form 2-3 likely root causes.
+ 3. QUERY (targeted): Run focused queries that confirm or reject each
+    hypothesis. Use JOINs across tables, GROUP BY with aggregates, and
+    compare time periods. Avoid broad SELECT * scans.
+ 4. QUANTIFY: For every finding, gather specific numbers -- counts, totals,
+    percentages, before/after comparisons.
+ 5. ANSWER: Submit a thorough analysis naming every root cause with
+    supporting evidence. Include specific product names, regions, customer
+    segments, suppliers, dollar amounts, dates, and percentages.
+
+ You have {max_steps} steps total. Budget roughly 70% for querying and
+ reserve the last few steps for your answer. Do NOT run out of steps
+ without submitting -- partial evidence is better than none.
+ """
+
+ # ---------------------------------------------------------------------------
+ # Helpers
+ # ---------------------------------------------------------------------------
+
+ def _extract_json(text: str) -> dict:
+     text = text.strip()
+     try:
+         return json.loads(text)
+     except json.JSONDecodeError:
+         pass
+     m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
+     if m:
+         try:
+             return json.loads(m.group(1))
+         except json.JSONDecodeError:
+             pass
+     m = re.search(r"\{[^{}]*\}", text, re.DOTALL)
+     if m:
+         try:
+             return json.loads(m.group(0))
+         except json.JSONDecodeError:
+             pass
+     return {"action_type": "answer", "content": text}
+
+
+ def _log_start(task_id: str) -> None:
+     print(
+         f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}",
+         flush=True,
+     )
+
+
+ def _log_step(step: int, action: dict, reward: float, done: bool, error: str | None) -> None:
+     action_str = json.dumps(action, separators=(",", ":"))
+     error_val = f"'{error}'" if error else "null"
+     print(
+         f"[STEP] step={step} action={action_str} "
+         f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
+         flush=True,
+     )
+
+
+ def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
+     rewards_str = ",".join(f"{r:.2f}" for r in rewards)
+     print(
+         f"[END] success={str(success).lower()} steps={steps} "
+         f"score={score:.3f} rewards={rewards_str}",
+         flush=True,
+     )
+
+
+ async def run_task(task_id: str) -> float:
+     _log_start(task_id)
+
+     rewards: list[float] = []
+     step = 0
+     reward = 0.0
+     done = False
+     success = False
+     error_msg = None
+
+     try:
+         async with GenericEnvClient(base_url=ENV_URL) as env:
+             result = await env.reset(task_id=task_id)
+             obs = result.observation
+
+             system = SYSTEM_PROMPT.format(max_steps=MAX_STEPS)
+             messages = [
+                 {"role": "system", "content": system},
+                 {
+                     "role": "user",
+                     "content": (
+                         f"## Investigation Task\n{obs.get('task_description', '')}\n\n"
+                         f"## Database\n{obs.get('schema_info', '')}\n\n"
+                         f"You have {MAX_STEPS} steps. Begin your investigation."
+                     ),
+                 },
+             ]
+
+             while not done and step < MAX_STEPS:
+                 try:
+                     completion = llm.chat.completions.create(
+                         model=MODEL_NAME,
+                         messages=messages,
+                         temperature=0.1,
+                         max_completion_tokens=1024,
+                     )
+                     llm_text = completion.choices[0].message.content or ""
+                 except Exception as exc:
+                     llm_text = json.dumps({
+                         "action_type": "answer",
+                         "content": "Unable to complete analysis due to LLM error.",
+                     })
+                     error_msg = str(exc)
+
+                 action = _extract_json(llm_text)
+                 if "action_type" not in action:
+                     action["action_type"] = "query"
+                 if "content" not in action:
+                     action["content"] = llm_text
+
+                 result = await env.step(action)
+                 step += 1
+                 done = result.done
+                 reward = result.reward or 0.0
+                 rewards.append(reward)
+                 result_obs = result.observation
+                 remaining = MAX_STEPS - step
+
+                 _log_step(step, action, reward, done, error_msg)
+                 error_msg = None
+
+                 messages.append({"role": "assistant", "content": llm_text})
+
+                 if not done and remaining <= 3:
+                     urgency = (
+                         f"URGENT: Only {remaining} step(s) left! "
+                         "You MUST submit your final answer NOW using "
+                         '{"action_type": "answer", "content": "..."}. '
+                         "Summarize ALL findings so far."
+                     )
+                 else:
+                     urgency = "Continue investigating or submit your final answer."
+
+                 messages.append({
+                     "role": "user",
+                     "content": (
+                         f"Query result:\n{result_obs.get('output', '')}\n\n"
+                         f"{result_obs.get('message', '')}\n\n"
+                         f"[Step {step}/{MAX_STEPS}] {urgency}"
+                     ),
+                 })
+
+             success = done and reward > 0.0
+     except Exception as exc:
+         error_msg = str(exc)
+         _log_step(step + 1, {"action_type": "error"}, 0.0, False, error_msg)
+
+     score = reward if done else 0.0
+     _log_end(success=success, steps=step, score=score, rewards=rewards)
+     return score
+
+
+ # ---------------------------------------------------------------------------
+ # Main
+ # ---------------------------------------------------------------------------
+
+ async def amain():
+     total = 0.0
+     for tid in TASK_IDS:
+         try:
+             r = await run_task(tid)
+         except Exception as exc:
+             print(f"[ERROR] task={tid} failed: {exc}", file=sys.stderr, flush=True)
+             print("[END] success=false steps=0 score=0.000 rewards=", flush=True)
+             r = 0.0
+         total += r
+     avg = total / len(TASK_IDS) if TASK_IDS else 0.0
+     print(f"\n=== Overall average score: {avg:.2f} ===", flush=True)
+
+
+ def main():
+     asyncio.run(amain())
+
+
+ if __name__ == "__main__":
+     main()
models.py ADDED
@@ -0,0 +1,34 @@
+ from pydantic import Field
+ from openenv.core.env_server.types import Action, Observation, State
+
+
+ class DataDetectiveAction(Action):
+     """Agent action: run a SQL query or submit a final answer."""
+
+     action_type: str = Field(
+         ...,
+         description="'query' to execute SQL against the database, or 'answer' to submit findings",
+     )
+     content: str = Field(
+         ...,
+         description="SQL query string (for action_type='query') or final analysis text (for action_type='answer')",
+     )
+
+
+ class DataDetectiveObservation(Observation):
+     """Observation returned after each action."""
+
+     output: str = Field(default="", description="Query results or system feedback")
+     task_description: str = Field(default="", description="The investigation task to solve")
+     schema_info: str = Field(default="", description="Database schema (provided at reset)")
+     step_number: int = Field(default=0, description="Current step in the episode")
+     max_steps: int = Field(default=30, description="Maximum steps allowed")
+     message: str = Field(default="", description="Status or feedback message")
+
+
+ class DataDetectiveState(State):
+     """Internal environment state."""
+
+     task_id: str = Field(default="", description="Current task identifier")
+     queries_executed: int = Field(default=0, description="Number of SQL queries run so far")
+     max_steps: int = Field(default=30, description="Maximum steps allowed")
openenv.yaml ADDED
@@ -0,0 +1,40 @@
+ name: data_detective
+ version: "1.0.0"
+ description: >
+   DataDetective: A business incident investigation environment where AI agents
+   use SQL queries to analyze a realistic e-commerce company database (TechMart)
+   and uncover root causes of business problems. Covers 9 tasks spanning order
+   analysis, product returns, customer churn, shipping ops, margin analysis,
+   supplier quality, inventory stockouts, fraud detection, and retention.
+ endpoints:
+   reset: /reset
+   step: /step
+   state: /state
+ tasks:
+   - id: orders_drop
+     difficulty: easy
+     description: Order volume dropped sharply after a major promotion ended
+   - id: returns_spike
+     difficulty: medium
+     description: Product returns spiking in a specific region due to defective SKU
+   - id: customer_churn
+     difficulty: hard
+     description: Active customer count declining across specific segments
+   - id: shipping_delay
+     difficulty: medium-hard
+     description: Customer satisfaction crisis driven by carrier delays in one region
+   - id: revenue_paradox
+     difficulty: hard
+     description: Revenue is up but profit is down — multi-causal margin erosion
+   - id: supplier_quality
+     difficulty: medium
+     description: Systemic quality issues from a single supplier across multiple products
+   - id: inventory_stockout
+     difficulty: medium-hard
+     description: Regional sales underperformance caused by warehouse stockout during promo
+   - id: fraud_detection
+     difficulty: hard
+     description: Coordinated fraud ring of fake accounts placing high-value orders
+   - id: repeat_purchase_decline
+     difficulty: hard
+     description: Repeat purchase rates collapsing while acquisition masks the problem
pyproject.toml ADDED
@@ -0,0 +1,27 @@
+ [project]
+ name = "data_detective_env"
+ version = "1.0.0"
+ description = "DataDetective: Business incident investigation environment for OpenEnv"
+ readme = "README.md"
+ requires-python = ">=3.10"
+ dependencies = [
+     "openenv-core>=0.2.0",
+     "fastapi>=0.104.0",
+     "uvicorn>=0.24.0",
+     "pydantic>=2.0.0",
+     "websockets>=12.0",
+ ]
+
+ [project.optional-dependencies]
+ dev = ["pytest", "httpx"]
+ inference = ["openai>=1.0.0"]
+
+ [build-system]
+ requires = ["setuptools>=68.0"]
+ build-backend = "setuptools.build_meta"
+
+ [tool.setuptools.packages.find]
+ include = ["data_detective_env*"]
+
+ [project.scripts]
+ server = "server.app:main"
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openenv-core>=0.2.0
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ pydantic>=2.0.0
+ websockets>=12.0
+ openai>=1.0.0
server/__init__.py ADDED
File without changes
server/app.py ADDED
@@ -0,0 +1,36 @@
+ """FastAPI application for the DataDetective environment."""
+
+ try:
+     from openenv.core.env_server.http_server import create_app
+ except Exception as e:
+     raise ImportError(
+         "openenv-core is required. pip install openenv-core"
+     ) from e
+
+ try:
+     from ..models import DataDetectiveAction, DataDetectiveObservation
+     from .environment import DataDetectiveEnvironment
+ except (ImportError, ModuleNotFoundError):
+     from models import DataDetectiveAction, DataDetectiveObservation
+     from server.environment import DataDetectiveEnvironment
+
+ app = create_app(
+     DataDetectiveEnvironment,
+     DataDetectiveAction,
+     DataDetectiveObservation,
+     env_name="data_detective",
+     max_concurrent_envs=10,
+ )
+
+
+ def main(host: str = "0.0.0.0", port: int = 7860):
+     import uvicorn
+     uvicorn.run(app, host=host, port=port)
+
+
+ if __name__ == "__main__":
+     import argparse
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--port", type=int, default=7860)
+     args = parser.parse_args()
+     main(port=args.port)
server/database.py ADDED
@@ -0,0 +1,657 @@
+ """
+ Generates a realistic in-memory SQLite database for TechMart, a fictional
+ e-commerce company. The data contains deliberate patterns that support
+ nine investigation tasks:
+
+ 1. Orders drop after a major promotion ends
+ 2. Product returns spike for a specific SKU in the West region
+ 3. Customer churn concentrated in the Enterprise/Northeast segment
+ 4. Shipping delays by QuickShip in the Midwest driving support tickets
+ 5. Revenue up but profit down (multi-causal paradox)
+ 6. Supplier quality crisis (AudioTech products 6 & 7)
+ 7. Inventory stockout in West for Monitor 27-inch during promo
+ 8. Coordinated fraud ring in Southeast with new accounts
+ 9. Repeat purchase decline masked by new-customer acquisition spend
+ """
+
+ import random
+ import sqlite3
+ from datetime import datetime, timedelta
+
+
+ PRODUCTS = [
+     (1, "Laptop Pro 15", "Electronics", 999.99, 650.00, "TechCorp"),
+     (2, "Desktop Workstation", "Electronics", 1499.99, 950.00, "TechCorp"),
+     (3, "Tablet Ultra", "Electronics", 599.99, 350.00, "TechCorp"),
+     (4, "Monitor 27-inch", "Electronics", 449.99, 280.00, "DisplayMax"),
+     (5, "Smart TV 55-inch", "Electronics", 699.99, 420.00, "DisplayMax"),
+     (6, "Wireless Headphones Pro", "Accessories", 149.99, 45.00, "AudioTech"),
+     (7, "Bluetooth Speaker", "Accessories", 79.99, 30.00, "AudioTech"),
+     (8, "USB-C Hub", "Accessories", 49.99, 15.00, "ConnectPlus"),
+     (9, "Laptop Bag Premium", "Accessories", 39.99, 12.00, "CarryAll"),
+     (10, "Mouse Pad XL", "Accessories", 24.99, 8.00, "CarryAll"),
+     (11, "Office Suite License", "Software", 199.99, 20.00, "SoftVault"),
+     (12, "Antivirus Pro Annual", "Software", 49.99, 5.00, "SecureNet"),
+     (13, "Cloud Backup 1TB", "Software", 99.99, 10.00, "CloudStore"),
+     (14, "Design Studio Pro", "Software", 299.99, 30.00, "CreativeSoft"),
+     (15, "DevTools Ultimate", "Software", 149.99, 15.00, "CodeForge"),
+     (16, "Mechanical Keyboard RGB", "Peripherals", 129.99, 60.00, "KeyMaster"),
+     (17, "Wireless Mouse Pro", "Peripherals", 59.99, 20.00, "ClickTech"),
+     (18, "Webcam HD 1080p", "Peripherals", 89.99, 35.00, "VisionCam"),
+     (19, "External SSD 1TB", "Peripherals", 109.99, 55.00, "StoragePro"),
+     (20, "Laser Printer Pro", "Peripherals", 249.99, 130.00, "PrintMax"),
+ ]
+
+ _FIRST = [
+     "James","Mary","Robert","Patricia","John","Jennifer","Michael","Linda",
+     "David","Elizabeth","William","Barbara","Richard","Susan","Joseph","Jessica",
+     "Thomas","Sarah","Christopher","Karen","Charles","Lisa","Daniel","Nancy",
+     "Matthew","Betty","Anthony","Margaret","Mark","Sandra","Donald","Ashley",
+     "Steven","Dorothy","Andrew","Kimberly","Paul","Emily","Joshua","Donna",
+     "Kenneth","Michelle","Kevin","Carol","Brian","Amanda","George","Melissa",
+     "Timothy","Deborah",
+ ]
+
+ _LAST = [
+     "Smith","Johnson","Williams","Brown","Jones","Garcia","Miller","Davis",
+     "Rodriguez","Martinez","Hernandez","Lopez","Gonzalez","Wilson","Anderson",
+     "Thomas","Taylor","Moore","Jackson","Martin","Lee","Perez","Thompson",
+     "White","Harris","Sanchez","Clark","Ramirez","Lewis","Robinson","Walker",
+     "Young","Allen","King","Wright","Scott","Torres","Nguyen","Hill","Flores",
+     "Green","Adams","Nelson","Baker","Hall","Rivera","Campbell","Mitchell",
+     "Carter","Roberts",
+ ]
+
+ REGIONS = ["Northeast", "Southeast", "West", "Midwest"]
+
+ PRICE_CHANGES = [
+     (1, 999.99, 1149.99, "2024-02-01", "Annual pricing adjustment"),
+     (2, 1499.99, 1699.99, "2024-02-01", "Annual pricing adjustment"),
+     (11, 199.99, 229.99, "2024-02-01", "Annual pricing adjustment"),
+     (15, 149.99, 174.99, "2024-02-01", "Annual pricing adjustment"),
+     (19, 109.99, 129.99, "2024-02-01", "Annual pricing adjustment"),
+ ]
+
+ PROMOTIONS = [
+     (1, "New Year Kickoff", "2024-01-01", "2024-01-15", 10.0, "All"),
+     (2, "Valentine Tech Sale", "2024-02-10", "2024-02-14", 15.0, "Electronics"),
+     (3, "Spring Mega Sale", "2024-02-15", "2024-03-01", 25.0, "All"),
+ ]
+
+ CARRIERS = ["QuickShip", "FastFreight", "ReliableLogistics"]
+
+ TICKET_CATEGORIES = ["delivery_delay", "product_defect", "billing_issue", "general_inquiry"]
+
+ MARKETING_CHANNELS = ["email", "social_media", "search_ads", "display_ads", "affiliate"]
+
+
+ def _date_range(start: datetime, end: datetime):
+     d = start
+     while d <= end:
+         yield d
+         d += timedelta(days=1)
+
+
+ def _effective_price(base_prices: dict, changes_by_pid: dict, pid: int, date_str: str):
+     """Return the unit price for *pid* on *date_str*, considering price changes."""
+     price = base_prices[pid]
+     for new_price, change_date in changes_by_pid.get(pid, []):
+         if date_str >= change_date:
+             price = new_price
+     return price
+
+
+ def create_database(seed: int = 42) -> sqlite3.Connection:
+     rng = random.Random(seed)
+     conn = sqlite3.connect(":memory:", check_same_thread=False)
+     c = conn.cursor()
+
+     c.executescript("""
+         CREATE TABLE customers (
+             customer_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             email TEXT NOT NULL,
+             region TEXT NOT NULL,
+             segment TEXT NOT NULL,
+             signup_date TEXT NOT NULL
+         );
+         CREATE TABLE products (
+             product_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             category TEXT NOT NULL,
+             price REAL NOT NULL,
+             cost REAL NOT NULL,
+             supplier TEXT NOT NULL
+         );
+         CREATE TABLE orders (
+             order_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             customer_id INTEGER NOT NULL,
+             order_date TEXT NOT NULL,
+             status TEXT NOT NULL DEFAULT 'completed',
+             total_amount REAL NOT NULL DEFAULT 0,
+             FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
+         );
+         CREATE TABLE order_items (
+             item_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             order_id INTEGER NOT NULL,
+             product_id INTEGER NOT NULL,
+             quantity INTEGER NOT NULL DEFAULT 1,
+             unit_price REAL NOT NULL,
+             FOREIGN KEY (order_id) REFERENCES orders(order_id),
+             FOREIGN KEY (product_id) REFERENCES products(product_id)
+         );
+         CREATE TABLE returns (
+             return_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             order_id INTEGER NOT NULL,
+             product_id INTEGER NOT NULL,
+             return_date TEXT NOT NULL,
+             reason TEXT NOT NULL,
+             refund_amount REAL NOT NULL,
+             FOREIGN KEY (order_id) REFERENCES orders(order_id),
+             FOREIGN KEY (product_id) REFERENCES products(product_id)
+         );
+         CREATE TABLE promotions (
+             promo_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             start_date TEXT NOT NULL,
+             end_date TEXT NOT NULL,
+             discount_pct REAL NOT NULL,
+             applicable_category TEXT
+         );
+         CREATE TABLE price_changes (
+             change_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             product_id INTEGER NOT NULL,
+             old_price REAL NOT NULL,
+             new_price REAL NOT NULL,
+             change_date TEXT NOT NULL,
+             reason TEXT,
+             FOREIGN KEY (product_id) REFERENCES products(product_id)
+         );
+         CREATE TABLE shipping (
+             shipment_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             order_id INTEGER NOT NULL,
+             carrier TEXT NOT NULL,
+             ship_date TEXT NOT NULL,
+             delivery_date TEXT NOT NULL,
+             status TEXT NOT NULL DEFAULT 'delivered',
+             FOREIGN KEY (order_id) REFERENCES orders(order_id)
+         );
+         CREATE TABLE support_tickets (
+             ticket_id INTEGER PRIMARY KEY AUTOINCREMENT,
+             customer_id INTEGER NOT NULL,
+             product_id INTEGER,
+             created_date TEXT NOT NULL,
+             category TEXT NOT NULL,
185
+ priority TEXT NOT NULL DEFAULT 'medium',
186
+ resolution_status TEXT NOT NULL DEFAULT 'open',
187
+ FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
188
+ FOREIGN KEY (product_id) REFERENCES products(product_id)
189
+ );
190
+ CREATE TABLE inventory_log (
191
+ log_id INTEGER PRIMARY KEY AUTOINCREMENT,
192
+ product_id INTEGER NOT NULL,
193
+ log_date TEXT NOT NULL,
194
+ units_in_stock INTEGER NOT NULL,
195
+ units_ordered INTEGER NOT NULL DEFAULT 0,
196
+ warehouse_region TEXT NOT NULL,
197
+ FOREIGN KEY (product_id) REFERENCES products(product_id)
198
+ );
199
+ CREATE TABLE marketing_spend (
200
+ spend_id INTEGER PRIMARY KEY AUTOINCREMENT,
201
+ channel TEXT NOT NULL,
202
+ campaign_name TEXT NOT NULL,
203
+ region TEXT NOT NULL,
204
+ spend_date TEXT NOT NULL,
205
+ amount REAL NOT NULL
206
+ );
207
+ """)
208
+
209
+ c.executemany("INSERT INTO products VALUES (?,?,?,?,?,?)", PRODUCTS)
210
+ base_prices = {p[0]: p[3] for p in PRODUCTS}
211
+
212
+ segments_pool = ["Enterprise"] * 35 + ["SMB"] * 55 + ["Consumer"] * 60
213
+ rng.shuffle(segments_pool)
214
+ customers = []
215
+ for i in range(150):
216
+ first = rng.choice(_FIRST)
217
+ last = rng.choice(_LAST)
218
+ name = f"{first} {last}"
219
+ email = f"{first.lower()}.{last.lower()}{i}@techmart.com"
220
+ region = REGIONS[i % 4]
221
+ segment = segments_pool[i]
222
+ signup = (datetime(2023, 1, 1) + timedelta(days=rng.randint(0, 364))).strftime("%Y-%m-%d")
223
+ c.execute("INSERT INTO customers VALUES (?,?,?,?,?,?)",
224
+ (i + 1, name, email, region, segment, signup))
225
+ customers.append((i + 1, name, email, region, segment, signup))
226
+
227
+ ent_ne = [cu for cu in customers if cu[4] == "Enterprise" and cu[3] == "Northeast"]
228
+ ent_other = [cu for cu in customers if cu[4] == "Enterprise" and cu[3] != "Northeast"]
229
+ smb_all = [cu for cu in customers if cu[4] == "SMB"]
230
+ con_all = [cu for cu in customers if cu[4] == "Consumer"]
231
+
232
+ c.executemany("INSERT INTO promotions VALUES (?,?,?,?,?,?)", PROMOTIONS)
233
+ for pid, old_p, new_p, dt, reason in PRICE_CHANGES:
234
+ c.execute(
235
+ "INSERT INTO price_changes (product_id,old_price,new_price,change_date,reason) VALUES (?,?,?,?,?)",
236
+ (pid, old_p, new_p, dt, reason),
237
+ )
238
+ changes_by_pid: dict[int, list] = {}
239
+ for pid, _, new_p, dt, _ in PRICE_CHANGES:
240
+ changes_by_pid.setdefault(pid, []).append((new_p, dt))
241
+
242
+ START = datetime(2024, 1, 1)
243
+ END = datetime(2024, 3, 15)
244
+ PROMO_S = datetime(2024, 2, 15)
245
+ PROMO_E = datetime(2024, 3, 1)
246
+ PRICE_INC = datetime(2024, 2, 1)
247
+
248
+ product_weights_base = [1.0] * 20
249
+ product_weights_base[5] = 3.0
250
+
251
+ for day in _date_range(START, END):
252
+ date_str = day.strftime("%Y-%m-%d")
253
+ is_promo = PROMO_S <= day <= PROMO_E
254
+ after_price_inc = day >= PRICE_INC
255
+
256
+ daily_count = rng.randint(25, 35) if is_promo else rng.randint(12, 18)
257
+
258
+ for _ in range(daily_count):
259
+ roll = rng.random()
260
+ if roll < 0.08:
261
+ pool = ent_ne
262
+ if after_price_inc and rng.random() < 0.85:
263
+ continue
264
+ elif roll < 0.22:
265
+ pool = ent_other
266
+ if after_price_inc and rng.random() < 0.50:
267
+ continue
268
+ elif roll < 0.55:
269
+ pool = smb_all
270
+ if after_price_inc and rng.random() < 0.20:
271
+ continue
272
+ else:
273
+ pool = con_all
274
+
275
+ cust = rng.choice(pool)
276
+ cust_id, _, _, cust_region, _, _ = cust
277
+
278
+ weights = list(product_weights_base)
279
+ if cust_region == "West":
280
+ weights[5] = 7.0
281
+ if is_promo:
282
+ weights[3] = 0.1
283
+ num_items = rng.choices([1, 2, 3], weights=[0.6, 0.3, 0.1])[0]
284
+ pids = list(set(rng.choices(range(1, 21), weights=weights, k=num_items)))
285
+
286
+ c.execute(
287
+ "INSERT INTO orders (customer_id, order_date, status, total_amount) VALUES (?,?,?,?)",
288
+ (cust_id, date_str, "completed", 0),
289
+ )
290
+ order_id = c.lastrowid
291
+ total = 0.0
292
+ for pid in pids:
293
+ qty = rng.choices([1, 2, 3], weights=[0.75, 0.20, 0.05])[0]
294
+ price = _effective_price(base_prices, changes_by_pid, pid, date_str)
295
+ if is_promo:
296
+ price = round(price * 0.75, 2)
297
+ total += price * qty
298
+ c.execute(
299
+ "INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES (?,?,?,?)",
300
+ (order_id, pid, qty, round(price, 2)),
301
+ )
302
+ c.execute("UPDATE orders SET total_amount=? WHERE order_id=?",
303
+ (round(total, 2), order_id))
304
+
305
+ c.execute("""
306
+ SELECT oi.item_id, oi.order_id, oi.product_id, oi.unit_price, oi.quantity,
307
+ o.order_date, cu.region
308
+ FROM order_items oi
309
+ JOIN orders o ON oi.order_id = o.order_id
310
+ JOIN customers cu ON o.customer_id = cu.customer_id
311
+ """)
312
+ items = c.fetchall()
313
+
314
+ defect_reasons = ["defective_unit", "stopped_working", "poor_audio_quality", "battery_issue"]
315
+ normal_reasons = ["changed_mind", "wrong_size", "found_cheaper", "not_as_expected"]
316
+
317
+ speaker_defect_reasons = ["audio_distortion", "bluetooth_disconnect", "battery_issue", "stopped_working"]
318
+
319
+ for _, order_id, product_id, unit_price, qty, order_date, region in items:
320
+ if product_id == 6 and region == "West":
321
+ prob = 0.38
322
+ reasons = defect_reasons
323
+ elif product_id == 6:
324
+ prob = 0.08
325
+ reasons = defect_reasons + normal_reasons
326
+ elif product_id == 7:
327
+ prob = 0.12
328
+ reasons = speaker_defect_reasons
329
+ else:
330
+ prob = 0.04
331
+ reasons = normal_reasons
332
+
333
+ if rng.random() < prob:
334
+ ret_date = (datetime.strptime(order_date, "%Y-%m-%d")
335
+ + timedelta(days=rng.randint(3, 14))).strftime("%Y-%m-%d")
336
+ c.execute(
337
+ "INSERT INTO returns (order_id, product_id, return_date, reason, refund_amount) VALUES (?,?,?,?,?)",
338
+ (order_id, product_id, ret_date, rng.choice(reasons), round(unit_price * qty, 2)),
339
+ )
340
+
341
+ # -- Shipping records for every order ------------------------------------
342
+ QUICKSHIP_DELAY_START = datetime(2024, 2, 10)
343
+ c.execute("SELECT order_id, order_date, customer_id FROM orders")
344
+ all_orders = c.fetchall()
345
+ cust_region_map = {cu[0]: cu[3] for cu in customers}
346
+
347
+ for order_id, order_date_str, cust_id in all_orders:
348
+ order_dt = datetime.strptime(order_date_str, "%Y-%m-%d")
349
+ region = cust_region_map[cust_id]
350
+
351
+ if region == "Midwest":
352
+ carrier = rng.choices(CARRIERS, weights=[0.40, 0.35, 0.25])[0]
353
+ else:
354
+ carrier = rng.choices(CARRIERS, weights=[0.25, 0.40, 0.35])[0]
355
+
356
+ ship_dt = order_dt + timedelta(days=rng.randint(0, 1))
357
+ base_transit = rng.randint(2, 4)
358
+
359
+ if carrier == "QuickShip" and region == "Midwest" and order_dt >= QUICKSHIP_DELAY_START:
360
+ extra_delay = rng.randint(5, 10)
361
+ status = "delayed"
362
+ elif carrier == "FastFreight":
363
+ extra_delay = rng.randint(0, 2)
364
+ status = "delivered"
365
+ else:
366
+ extra_delay = 0
367
+ status = "delivered"
368
+
369
+ delivery_dt = ship_dt + timedelta(days=base_transit + extra_delay)
370
+ if status == "delayed" and rng.random() < 0.7:
371
+ status = "delivered"
372
+
373
+ c.execute(
374
+ "INSERT INTO shipping (order_id, carrier, ship_date, delivery_date, status) "
375
+ "VALUES (?,?,?,?,?)",
376
+ (order_id, carrier, ship_dt.strftime("%Y-%m-%d"),
377
+ delivery_dt.strftime("%Y-%m-%d"), status),
378
+ )
379
+
380
+ # -- Support tickets -----------------------------------------------------
381
+ ticket_priorities = ["low", "medium", "high", "critical"]
382
+ ticket_resolutions = ["open", "resolved", "escalated"]
383
+
384
+ for day in _date_range(START, END):
385
+ date_str = day.strftime("%Y-%m-%d")
386
+ after_qs_issues = day >= QUICKSHIP_DELAY_START
387
+
388
+ for region_name in REGIONS:
389
+ region_custs = [cu for cu in customers if cu[3] == region_name]
390
+
391
+ # Delivery delay tickets: spike in Midwest after QuickShip issues
392
+ if region_name == "Midwest" and after_qs_issues:
393
+ n_delay = rng.randint(3, 6)
394
+ else:
395
+ n_delay = rng.randint(0, 1)
396
+
397
+ for _ in range(n_delay):
398
+ cu = rng.choice(region_custs)
399
+ pri = rng.choices(ticket_priorities, weights=[0.1, 0.3, 0.4, 0.2])[0]
400
+ res = rng.choices(ticket_resolutions, weights=[0.3, 0.5, 0.2])[0]
401
+ c.execute(
402
+ "INSERT INTO support_tickets "
403
+ "(customer_id, product_id, created_date, category, priority, resolution_status) "
404
+ "VALUES (?,?,?,?,?,?)",
405
+ (cu[0], None, date_str, "delivery_delay", pri, res),
406
+ )
407
+
408
+ # Product defect tickets: elevated for AudioTech products (6 in West, 7 everywhere)
409
+ if region_name == "West":
410
+ n_defect = rng.randint(1, 3)
411
+ else:
412
+ n_defect = 1 if rng.random() < 0.3 else 0
413
+
414
+ for _ in range(n_defect):
415
+ cu = rng.choice(region_custs)
416
+ pid = 6 if region_name == "West" or rng.random() < 0.4 else rng.randint(1, 20)
417
+ pri = rng.choices(ticket_priorities, weights=[0.1, 0.3, 0.4, 0.2])[0]
418
+ res = rng.choices(ticket_resolutions, weights=[0.4, 0.4, 0.2])[0]
419
+ c.execute(
420
+ "INSERT INTO support_tickets "
421
+ "(customer_id, product_id, created_date, category, priority, resolution_status) "
422
+ "VALUES (?,?,?,?,?,?)",
423
+ (cu[0], pid, date_str, "product_defect", pri, res),
424
+ )
425
+
426
+ # Product 7 (Bluetooth Speaker) defect tickets across all regions
427
+ if rng.random() < 0.45:
428
+ cu = rng.choice(region_custs)
429
+ pri = rng.choices(ticket_priorities, weights=[0.1, 0.4, 0.35, 0.15])[0]
430
+ res = rng.choices(ticket_resolutions, weights=[0.35, 0.45, 0.2])[0]
431
+ c.execute(
432
+ "INSERT INTO support_tickets "
433
+ "(customer_id, product_id, created_date, category, priority, resolution_status) "
434
+ "VALUES (?,?,?,?,?,?)",
435
+ (cu[0], 7, date_str, "product_defect", pri, res),
436
+ )
437
+
438
+ # Billing issue tickets: evenly spread (red herring / noise)
439
+ if rng.random() < 0.25:
440
+ cu = rng.choice(region_custs)
441
+ c.execute(
442
+ "INSERT INTO support_tickets "
443
+ "(customer_id, product_id, created_date, category, priority, resolution_status) "
444
+ "VALUES (?,?,?,?,?,?)",
445
+ (cu[0], None, date_str, "billing_issue",
446
+ rng.choice(ticket_priorities), rng.choice(ticket_resolutions)),
447
+ )
448
+
449
+ # General inquiry: background noise
450
+ if rng.random() < 0.35:
451
+ cu = rng.choice(region_custs)
452
+ c.execute(
453
+ "INSERT INTO support_tickets "
454
+ "(customer_id, product_id, created_date, category, priority, resolution_status) "
455
+ "VALUES (?,?,?,?,?,?)",
456
+ (cu[0], None, date_str, "general_inquiry", "low",
457
+ rng.choice(["resolved", "open"])),
458
+ )
459
+
460
+ # -- Inventory log -------------------------------------------------------
461
+ # Daily stock levels per product per warehouse region.
462
+ # Product 4 (Monitor 27-inch) stocks out in West during promo.
463
+ base_stock = {}
464
+ for p in PRODUCTS:
465
+ pid = p[0]
466
+ if pid in (1, 2, 3, 4, 5): # electronics — higher stock
467
+ base_stock[pid] = 200
468
+ elif pid <= 10: # accessories
469
+ base_stock[pid] = 350
470
+ elif pid <= 15: # software (digital)
471
+ base_stock[pid] = 9999
472
+ else: # peripherals
473
+ base_stock[pid] = 250
474
+
475
+ for day in _date_range(START, END):
476
+ date_str = day.strftime("%Y-%m-%d")
477
+ is_promo = PROMO_S <= day <= PROMO_E
478
+
479
+ for region_name in REGIONS:
480
+ for p in PRODUCTS:
481
+ pid = p[0]
482
+ stock = base_stock[pid]
483
+
484
+ daily_sold = rng.randint(2, 8)
485
+ if is_promo:
486
+ daily_sold = rng.randint(5, 15)
487
+
488
+ # Product 4 stockout in West during promo
489
+ if pid == 4 and region_name == "West" and is_promo:
490
+ stock = rng.randint(0, 2)
491
+ daily_sold = rng.randint(0, 1)
492
+ else:
493
+ stock = max(stock - daily_sold + rng.randint(1, 6), 10)
494
+
495
+ # Product 6 fluctuates in West but never stocks out (red herring)
496
+ if pid == 6 and region_name == "West":
497
+ stock = rng.randint(30, 80)
498
+
499
+ reorder = 0
500
+ if stock < 20 and pid <= 15:
501
+ reorder = rng.randint(50, 100)
502
+
503
+ c.execute(
504
+ "INSERT INTO inventory_log "
505
+ "(product_id, log_date, units_in_stock, units_ordered, warehouse_region) "
506
+ "VALUES (?,?,?,?,?)",
507
+ (pid, date_str, stock, reorder, region_name),
508
+ )
509
+
510
+ # -- Fraudulent accounts ---------------------------------------------------
511
+ # ~15 fake accounts in Southeast, Consumer, all signed up late Feb,
512
+ # placing high-value Electronics orders (products 1 & 2).
513
+ fraud_customers = []
514
+ for i in range(15):
515
+ cid = 151 + i
516
+ first = rng.choice(_FIRST)
517
+ last = rng.choice(_LAST)
518
+ name = f"{first} {last}"
519
+ email = f"{first.lower()}.{last.lower()}{cid}@techmart.com"
520
+ signup = (datetime(2024, 2, 20) + timedelta(days=rng.randint(0, 7))).strftime("%Y-%m-%d")
521
+ c.execute("INSERT INTO customers VALUES (?,?,?,?,?,?)",
522
+ (cid, name, email, "Southeast", "Consumer", signup))
523
+ fraud_customers.append(cid)
524
+ customers.append((cid, name, email, "Southeast", "Consumer", signup))
525
+
526
+ cust_region_map.update({cid: "Southeast" for cid in fraud_customers})
527
+
528
+ FRAUD_ORDER_START = datetime(2024, 2, 25)
529
+ FRAUD_ORDER_END = datetime(2024, 3, 10)
530
+ for cid in fraud_customers:
531
+ n_orders = rng.randint(3, 5)
532
+ for _ in range(n_orders):
533
+ order_day = FRAUD_ORDER_START + timedelta(
534
+ days=rng.randint(0, (FRAUD_ORDER_END - FRAUD_ORDER_START).days))
535
+ date_str = order_day.strftime("%Y-%m-%d")
536
+ fraud_pid = rng.choice([1, 2])
537
+ qty = rng.randint(1, 2)
538
+ price = _effective_price(base_prices, changes_by_pid, fraud_pid, date_str)
539
+ is_promo_day = PROMO_S <= order_day <= PROMO_E
540
+ if is_promo_day:
541
+ price = round(price * 0.75, 2)
542
+ total = round(price * qty, 2)
543
+ c.execute(
544
+ "INSERT INTO orders (customer_id, order_date, status, total_amount) VALUES (?,?,?,?)",
545
+ (cid, date_str, "completed", total),
546
+ )
547
+ oid = c.lastrowid
548
+ c.execute(
549
+ "INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES (?,?,?,?)",
550
+ (oid, fraud_pid, qty, round(price, 2)),
551
+ )
552
+ ship_dt = order_day + timedelta(days=rng.randint(0, 1))
553
+ delivery_dt = ship_dt + timedelta(days=rng.randint(2, 4))
554
+ c.execute(
555
+ "INSERT INTO shipping (order_id, carrier, ship_date, delivery_date, status) "
556
+ "VALUES (?,?,?,?,?)",
557
+ (oid, "FastFreight", ship_dt.strftime("%Y-%m-%d"),
558
+ delivery_dt.strftime("%Y-%m-%d"), "delivered"),
559
+ )
560
+
561
+ # -- Marketing spend -------------------------------------------------------
562
+ # Heavy acquisition spend (search_ads, social_media) in Feb/Mar.
563
+ # Email (retention) drops off after Jan. Southeast gets a big bump in Feb
564
+ # (red herring for fraud task).
565
+ acq_channels = ["search_ads", "social_media", "display_ads", "affiliate"]
566
+ for day in _date_range(START, END):
567
+ date_str = day.strftime("%Y-%m-%d")
568
+ month = day.month
569
+
570
+ for region_name in REGIONS:
571
+ # Retention channel: email
572
+ if month == 1:
573
+ email_spend = round(rng.uniform(200, 400), 2)
574
+ else:
575
+ email_spend = round(rng.uniform(20, 60), 2)
576
+ c.execute(
577
+ "INSERT INTO marketing_spend (channel, campaign_name, region, spend_date, amount) "
578
+ "VALUES (?,?,?,?,?)",
579
+ ("email", "Customer Retention", region_name, date_str, email_spend),
580
+ )
581
+
582
+ # Acquisition channels
583
+ for ch in acq_channels:
584
+ if month == 1:
585
+ base_spend = rng.uniform(100, 200)
586
+ else:
587
+ base_spend = rng.uniform(300, 600)
588
+
589
+ if region_name == "Southeast" and month >= 2:
590
+ base_spend *= 1.5
591
+
592
+ c.execute(
593
+ "INSERT INTO marketing_spend (channel, campaign_name, region, spend_date, amount) "
594
+ "VALUES (?,?,?,?,?)",
595
+ (ch, "New Customer Acquisition", region_name, date_str,
596
+ round(base_spend, 2)),
597
+ )
598
+
599
+ conn.commit()
600
+ return conn
601
+
602
+
603
+ def get_schema_info(conn: sqlite3.Connection) -> str:
604
+ """Human-readable database schema for the LLM agent."""
605
+ c = conn.cursor()
606
+ c.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
607
+ tables = [r[0] for r in c.fetchall()]
608
+
609
+ parts = ["DATABASE SCHEMA", "=" * 50, ""]
610
+ for table in tables:
611
+ c.execute(f"SELECT COUNT(*) FROM {table}")
612
+ count = c.fetchone()[0]
613
+ parts.append(f"Table: {table} ({count} rows)")
614
+
615
+ c.execute(f"PRAGMA table_info({table})")
616
+ for col in c.fetchall():
617
+ pk = " [PK]" if col[5] else ""
618
+ parts.append(f" - {col[1]} {col[2]}{pk}")
619
+
620
+ if table == "customers":
621
+ c.execute("SELECT DISTINCT region FROM customers ORDER BY region")
622
+ parts.append(f" Regions: {', '.join(r[0] for r in c.fetchall())}")
623
+ c.execute("SELECT DISTINCT segment FROM customers ORDER BY segment")
624
+ parts.append(f" Segments: {', '.join(r[0] for r in c.fetchall())}")
625
+ elif table == "products":
626
+ c.execute("SELECT DISTINCT category FROM products ORDER BY category")
627
+ parts.append(f" Categories: {', '.join(r[0] for r in c.fetchall())}")
628
+ c.execute("SELECT DISTINCT supplier FROM products ORDER BY supplier")
629
+ parts.append(f" Suppliers: {', '.join(r[0] for r in c.fetchall())}")
630
+ elif table == "shipping":
631
+ c.execute("SELECT DISTINCT carrier FROM shipping ORDER BY carrier")
632
+ parts.append(f" Carriers: {', '.join(r[0] for r in c.fetchall())}")
633
+ c.execute("SELECT DISTINCT status FROM shipping ORDER BY status")
634
+ parts.append(f" Statuses: {', '.join(r[0] for r in c.fetchall())}")
635
+ elif table == "support_tickets":
636
+ c.execute("SELECT DISTINCT category FROM support_tickets ORDER BY category")
637
+ parts.append(f" Categories: {', '.join(r[0] for r in c.fetchall())}")
638
+ c.execute("SELECT DISTINCT priority FROM support_tickets ORDER BY priority")
639
+ parts.append(f" Priorities: {', '.join(r[0] for r in c.fetchall())}")
640
+ elif table == "inventory_log":
641
+ c.execute("SELECT DISTINCT warehouse_region FROM inventory_log ORDER BY warehouse_region")
642
+ parts.append(f" Warehouse regions: {', '.join(r[0] for r in c.fetchall())}")
643
+ elif table == "marketing_spend":
644
+ c.execute("SELECT DISTINCT channel FROM marketing_spend ORDER BY channel")
645
+ parts.append(f" Channels: {', '.join(r[0] for r in c.fetchall())}")
646
+ c.execute("SELECT DISTINCT campaign_name FROM marketing_spend ORDER BY campaign_name")
647
+ parts.append(f" Campaigns: {', '.join(r[0] for r in c.fetchall())}")
648
+
649
+ parts.append("")
650
+
651
+ parts += [
652
+ "=" * 50,
653
+ "Data spans: 2024-01-01 to 2024-03-15",
654
+ "All dates stored as YYYY-MM-DD text.",
655
+ "Use standard SQLite functions (strftime, date, etc.).",
656
+ ]
657
+ return "\n".join(parts)
server/environment.py ADDED
@@ -0,0 +1,192 @@
+ """Core environment logic for DataDetective."""
+
+ import random
+ import uuid
+ from typing import Any, Optional
+
+ from openenv.core.env_server import Environment
+
+ try:
+     from ..models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
+     from .database import create_database, get_schema_info
+     from .tasks import TASKS, grade_answer
+ except (ImportError, ModuleNotFoundError):
+     from models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
+     from server.database import create_database, get_schema_info
+     from server.tasks import TASKS, grade_answer
+
+
+ class DataDetectiveEnvironment(
+     Environment[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
+ ):
+     SUPPORTS_CONCURRENT_SESSIONS = True
+     MAX_STEPS = 30
+
+     def __init__(self):
+         super().__init__()
+         self._db = None
+         self._task_id: str = ""
+         self._step_count: int = 0
+         self._episode_id: str = ""
+         self._queries_executed: int = 0
+         self._state = DataDetectiveState()
+
+     def reset(
+         self,
+         seed: Optional[int] = None,
+         episode_id: Optional[str] = None,
+         task_id: Optional[str] = None,
+         **kwargs: Any,
+     ) -> DataDetectiveObservation:
+         if seed is not None:
+             random.seed(seed)
+
+         self._episode_id = episode_id or str(uuid.uuid4())
+         self._task_id = task_id if task_id in TASKS else random.choice(list(TASKS))
+         self._step_count = 0
+         self._queries_executed = 0
+
+         if self._db is not None:
+             self._db.close()
+         self._db = create_database()
+
+         task = TASKS[self._task_id]
+         schema = get_schema_info(self._db)
+
+         self._state = DataDetectiveState(
+             episode_id=self._episode_id,
+             step_count=0,
+             task_id=self._task_id,
+             queries_executed=0,
+             max_steps=self.MAX_STEPS,
+         )
+
+         return DataDetectiveObservation(
+             done=False,
+             reward=None,
+             output="Environment ready. Run SQL queries to investigate the issue, then submit your answer.",
+             task_description=task["description"],
+             schema_info=schema,
+             step_number=0,
+             max_steps=self.MAX_STEPS,
+             message=f"Investigation: {task['title']} [{task['difficulty'].upper()}] -- {self.MAX_STEPS} steps available.",
+         )
+
+     def step(
+         self,
+         action: DataDetectiveAction,
+         timeout_s: Optional[float] = None,
+         **kwargs: Any,
+     ) -> DataDetectiveObservation:
+         self._step_count += 1
+         self._state.step_count = self._step_count
+
+         remaining = self.MAX_STEPS - self._step_count
+
+         if self._step_count > self.MAX_STEPS:
+             return self._obs(
+                 done=True, reward=0.0,
+                 output="Maximum steps reached -- investigation ended with no answer submitted.",
+                 message="Out of steps.",
+             )
+
+         atype = (action.action_type or "").strip().lower()
+
+         if atype == "query":
+             return self._handle_query(action.content, remaining)
+         elif atype == "answer":
+             return self._handle_answer(action.content)
+         else:
+             return self._obs(
+                 done=False, reward=0.0,
+                 output="",
+                 message=f"Unknown action_type '{action.action_type}'. Use 'query' or 'answer'. ({remaining} steps left)",
+             )
+
+     @property
+     def state(self) -> DataDetectiveState:
+         return self._state
+
+     def close(self) -> None:
+         if self._db is not None:
+             self._db.close()
+         self._db = None
+
+     def _obs(self, *, done: bool, reward: float | None, output: str, message: str) -> DataDetectiveObservation:
+         return DataDetectiveObservation(
+             done=done,
+             reward=reward,
+             output=output,
+             task_description=TASKS[self._task_id]["description"],
+             schema_info="",
+             step_number=self._step_count,
+             max_steps=self.MAX_STEPS,
+             message=message,
+         )
+
+     def _handle_query(self, sql: str, remaining: int) -> DataDetectiveObservation:
+         self._queries_executed += 1
+         self._state.queries_executed = self._queries_executed
+
+         if not sql or not sql.strip():
+             return self._obs(
+                 done=False, reward=0.0,
+                 output="Empty query -- please provide a valid SQL statement.",
+                 message=f"{remaining} steps left.",
+             )
+
+         try:
+             cur = self._db.cursor()
+             cur.execute(sql)
+             columns = [d[0] for d in cur.description] if cur.description else []
+             rows = cur.fetchall()
+             output = _format_table(columns, rows) if rows else "Query returned 0 rows."
+         except Exception as exc:
+             output = f"SQL Error: {exc}"
+             return self._obs(
+                 done=False, reward=0.0,
+                 output=output,
+                 message=f"Query failed. Fix your SQL and retry. ({remaining} steps left)",
+             )
+
+         return self._obs(
+             done=False, reward=0.0,
+             output=output,
+             message=f"{len(rows)} row(s) returned. ({remaining} steps left)",
+         )
+
+     def _handle_answer(self, answer_text: str) -> DataDetectiveObservation:
+         reward = grade_answer(self._task_id, answer_text)
+         if reward >= 0.8:
+             verdict = "Excellent investigation!"
+         elif reward >= 0.5:
+             verdict = "Good findings, but some details missing."
+         else:
+             verdict = "Several key findings were missed."
+
+         return self._obs(
+             done=True,
+             reward=reward,
+             output=f"Score: {reward:.2f} / 1.00 -- {verdict}",
+             message=f"Investigation complete. Final score: {reward:.2f}",
+         )
+
+
+ def _format_table(columns: list[str], rows: list, max_rows: int = 100) -> str:
+     truncated = len(rows) > max_rows
+     display = rows[:max_rows]
+
+     widths = [len(str(c)) for c in columns]
+     for row in display:
+         for i, v in enumerate(row):
+             widths[i] = max(widths[i], min(len(str(v)), 60))
+
+     header = " | ".join(str(c).ljust(widths[i]) for i, c in enumerate(columns))
+     sep = "-+-".join("-" * w for w in widths)
+     lines = [header, sep]
+     for row in display:
+         lines.append(" | ".join(str(v).ljust(widths[i])[:60] for i, v in enumerate(row)))
+
+     if truncated:
+         lines.append(f"... (showing {max_rows} of {len(rows)} rows)")
+     return "\n".join(lines)
server/requirements.txt ADDED
@@ -0,0 +1,6 @@
+ openenv-core>=0.2.0
+ fastapi>=0.104.0
+ uvicorn>=0.24.0
+ pydantic>=2.0.0
+ websockets>=12.0
+ openai>=1.0.0
server/tasks.py ADDED
@@ -0,0 +1,408 @@
1
+ """
2
+ Task definitions and automated graders for the DataDetective environment.
3
+
4
+ Each task has:
5
+ - id, title, difficulty, description
6
+ - A grader function that scores the agent's final answer (0.0 - 1.0)
7
+ based on whether key findings are mentioned.
8
+ """
9
+
10
+ import re
11
+ from typing import Callable
12
+
13
+
14
+ def _has_any(text: str, keywords: list[str]) -> bool:
15
+ """Case-insensitive check: does *text* contain any of *keywords*?"""
16
+ low = text.lower()
17
+ return any(kw.lower() in low for kw in keywords)
18
+
19
+
20
+ def _has_pattern(text: str, pattern: str) -> bool:
21
+ return bool(re.search(pattern, text, re.IGNORECASE))
22
+
23
+
24
+ def _grade_orders_drop(answer: str) -> float:
25
+ score = 0.0
26
+ if _has_any(answer, ["drop", "decrease", "decline", "fell", "fewer", "reduction", "lower"]):
27
+ score += 0.20
28
+ if _has_any(answer, ["spring mega sale", "spring sale", "mega sale"]) or (
29
+ _has_any(answer, ["promotion", "promo", "sale", "discount", "campaign"])
30
+ ):
31
+ score += 0.20
32
+ if _has_any(answer, ["ended", "expired", "over", "concluded", "stopped"]) or _has_pattern(
33
+ answer, r"march\s*0?1"
34
+ ):
35
+ score += 0.20
36
+ if _has_any(answer, [
37
+ "caused", "because", "due to", "result of", "led to",
38
+ "when the", "after the", "ending of", "end of the",
39
+ "correlated", "explains",
40
+ ]):
41
+ score += 0.20
42
+     if _has_pattern(answer, r"\d+\s*(orders|transactions)") or _has_pattern(
+         answer, r"\d+\s*%"
+     ) or _has_pattern(answer, r"from\s+\d+.*to\s+\d+"):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_returns_spike(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, ["wireless headphones", "headphones pro", "headphone"]):
+         score += 0.20
+     if _has_any(answer, ["west"]):
+         score += 0.20
+     if _has_any(answer, ["audiotech", "audio tech"]):
+         score += 0.20
+     if _has_any(answer, [
+         "defect", "defective", "faulty", "quality",
+         "high return", "return rate", "abnormal",
+         "stopped working", "battery issue", "poor audio",
+     ]):
+         score += 0.20
+     if _has_pattern(answer, r"\d+\s*%") or _has_pattern(
+         answer, r"\d+\s*(returns|returned|units)"
+     ) or _has_any(answer, ["return rate", "compared to"]):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_customer_churn(answer: str) -> float:
+     score = 0.0
+     if _has_pattern(answer, r"\d+\s*%") or _has_any(answer, [
+         "decline", "decrease", "drop", "churn", "fewer active",
+         "lost customers", "stopped ordering",
+     ]):
+         score += 0.20
+     if _has_any(answer, ["enterprise"]):
+         score += 0.20
+     if _has_any(answer, ["northeast", "north east", "north-east"]):
+         score += 0.20
+     if _has_any(answer, [
+         "price increase", "price change", "price hike", "pricing",
+         "more expensive", "raised price", "cost increase",
+     ]):
+         score += 0.20
+     if _has_any(answer, [
+         "laptop pro", "desktop workstation", "office suite",
+         "devtools", "external ssd",
+     ]) or _has_pattern(answer, r"product.*(1|2|11|15|19)"):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_shipping_delay(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, ["midwest"]):
+         score += 0.20
+     if _has_any(answer, ["quickship", "quick ship"]):
+         score += 0.20
+     if _has_any(answer, [
+         "delivery delay", "late delivery", "delayed shipment",
+         "shipping delay", "late shipment", "delivery time",
+         "delayed delivery", "slow delivery",
+     ]):
+         score += 0.20
+     if _has_pattern(answer, r"feb(ruary)?\s*(10|mid|middle)") or _has_any(answer, [
+         "mid-february", "mid february", "around february",
+         "starting in february", "beginning of february",
+     ]):
+         score += 0.20
+     if _has_any(answer, [
+         "support ticket", "complaint", "ticket volume",
+         "customer satisfaction", "support request",
+     ]) and _has_any(answer, [
+         "delivery", "shipping", "carrier", "quickship",
+     ]):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_revenue_paradox(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, [
+         "spring mega sale", "mega sale", "25%", "25 percent",
+     ]) or (
+         _has_any(answer, ["promotion", "promo", "discount", "sale"])
+         and _has_any(answer, ["margin", "profit", "cost"])
+     ):
+         score += 0.20
+     if _has_any(answer, [
+         "product mix", "category mix", "mix shift", "shifted toward",
+         "higher proportion", "more electronics", "low-margin",
+         "composition changed",
+     ]):
+         score += 0.20
+     if _has_any(answer, ["enterprise"]) and _has_any(answer, [
+         "price increase", "price change", "price hike",
+         "lost", "churn", "left", "fewer", "decline",
+     ]):
+         score += 0.20
+     if _has_any(answer, ["return", "refund"]) and _has_any(answer, [
+         "cost", "expense", "profit", "margin", "loss", "erode",
+     ]):
+         score += 0.20
+     if _has_pattern(answer, r"\$\s*[\d,]+") or _has_pattern(
+         answer, r"\d+\s*%"
+     ) or _has_pattern(answer, r"from\s+\$?[\d,]+.*to\s+\$?[\d,]+"):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_supplier_quality(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, ["audiotech", "audio tech"]):
+         score += 0.20
+     if _has_any(answer, ["wireless headphones", "headphones pro", "product 6"]):
+         score += 0.20
+     if _has_any(answer, ["bluetooth speaker", "product 7"]):
+         score += 0.20
+     if _has_any(answer, ["return rate", "refund", "return volume"]) or _has_pattern(
+         answer, r"\d+\s*%.*return"
+     ) or _has_pattern(answer, r"return.*\d+\s*%") or _has_pattern(
+         answer, r"\$\s*[\d,]+"
+     ):
+         score += 0.20
+     if _has_any(answer, [
+         "support ticket", "defect", "complaint", "product_defect",
+         "quality issue", "customer complaint",
+     ]):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_inventory_stockout(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, ["west"]):
+         score += 0.20
+     if _has_any(answer, ["monitor", "product 4", "monitor 27"]):
+         score += 0.20
+     if _has_any(answer, [
+         "inventory", "stock", "out of stock", "stockout", "stock-out",
+         "zero units", "no inventory", "warehouse",
+     ]):
+         score += 0.20
+     if _has_any(answer, [
+         "spring mega sale", "mega sale", "promo", "promotion",
+         "february 15", "feb 15", "during the sale",
+     ]):
+         score += 0.20
+     if _has_pattern(answer, r"\d+\s*(units|orders|sales)") or _has_pattern(
+         answer, r"\d+\s*%"
+     ) or _has_pattern(answer, r"from\s+\d+.*to\s+\d+"):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_fraud_detection(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, ["southeast"]):
+         score += 0.20
+     if _has_any(answer, [
+         "new account", "recent signup", "recently created",
+         "new customer", "account creation", "registered in feb",
+         "signed up",
+     ]):
+         score += 0.20
+     if _has_any(answer, [
+         "high-value", "high value", "expensive", "laptop pro",
+         "desktop workstation", "large order", "electronics",
+     ]):
+         score += 0.20
+     if _has_pattern(answer, r"1[0-5]\s*(account|customer|user)") or _has_pattern(
+         answer, r"\$\s*[\d,]+"
+     ) or _has_pattern(answer, r"\d+\s*(order|transaction)"):
+         score += 0.20
+     if _has_any(answer, [
+         "pattern", "cluster", "coordinated", "suspicious",
+         "same product", "no return", "never returned",
+         "concentrated", "anomal", "fraud ring",
+     ]):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ def _grade_repeat_purchase_decline(answer: str) -> float:
+     score = 0.0
+     if _has_any(answer, [
+         "repeat purchase", "repeat rate", "returning customer",
+         "repeat buyer", "repurchase", "order frequency",
+         "second order", "came back",
+     ]) and (_has_pattern(answer, r"\d+\s*%") or _has_any(answer, [
+         "decline", "drop", "decrease", "fell", "collapsed",
+     ])):
+         score += 0.20
+     if _has_any(answer, ["enterprise"]) and _has_any(answer, [
+         "price", "increase", "hike", "stopped", "left", "churn",
+     ]):
+         score += 0.20
+     if (_has_any(answer, ["midwest"]) or _has_any(answer, [
+         "shipping", "delivery", "quickship",
+     ])) and _has_any(answer, [
+         "repeat", "return", "reorder", "come back", "second order",
+     ]):
+         score += 0.20
+     if _has_any(answer, ["marketing", "acquisition", "spend"]) and _has_any(answer, [
+         "retention", "email", "loyalty", "re-engage", "lapsed",
+         "shifted", "new customer",
+     ]):
+         score += 0.20
+     if _has_any(answer, [
+         "segment", "cohort", "by region", "by segment",
+         "enterprise vs", "consumer vs", "smb vs",
+     ]) or _has_pattern(answer, r"(enterprise|smb|consumer).*\d+\s*%"):
+         score += 0.20
+     return min(score, 1.0)
+
+
+ TASKS: dict[str, dict] = {
+     "orders_drop": {
+         "id": "orders_drop",
+         "difficulty": "easy",
+         "title": "Weekly Orders Drop Investigation",
+         "description": (
+             "URGENT -- Our order volume dropped sharply in the first two weeks "
+             "of March compared to the last two weeks of February. Leadership "
+             "needs to know why.\n\n"
+             "Investigate the database, identify the root cause of the drop, "
+             "and submit a clear summary of your findings."
+         ),
+     },
+     "returns_spike": {
+         "id": "returns_spike",
+         "difficulty": "medium",
+         "title": "Product Returns Spike Investigation",
+         "description": (
+             "ALERT -- Our return rate has spiked significantly in recent weeks, "
+             "with particular concentration in one geographic region. This is "
+             "eating into margins.\n\n"
+             "Use the database to identify which product(s) are driving the "
+             "spike, which region is most affected, and what the likely root "
+             "cause is. Include the supplier if relevant."
+         ),
+     },
+     "customer_churn": {
+         "id": "customer_churn",
+         "difficulty": "hard",
+         "title": "Customer Churn Root Cause Analysis",
+         "description": (
+             "CRITICAL -- Our monthly active customer count has declined "
+             "significantly from January to March. The executive team wants a "
+             "full root-cause analysis.\n\n"
+             "Determine which customer segments and regions are most affected, "
+             "quantify the decline, and identify the most likely causes. "
+             "Check all available tables for clues."
+         ),
+     },
+     "shipping_delay": {
+         "id": "shipping_delay",
+         "difficulty": "medium-hard",
+         "title": "Customer Satisfaction Crisis Investigation",
+         "description": (
+             "ESCALATION -- Customer satisfaction scores have plummeted in one "
+             "of our regions. The support team is overwhelmed with complaints "
+             "and escalations are piling up.\n\n"
+             "Investigate what operational issue is driving the complaints, "
+             "identify the responsible party (carrier, warehouse, etc.), "
+             "determine when the problem started, and quantify the impact. "
+             "Cross-reference multiple data sources for a complete picture."
+         ),
+     },
+     "revenue_paradox": {
+         "id": "revenue_paradox",
+         "difficulty": "hard",
+         "title": "Revenue vs. Profit Paradox Investigation",
+         "description": (
+             "CRITICAL -- Revenue in February was our highest month ever, yet "
+             "gross profit actually *decreased* compared to January. The CFO "
+             "wants a full breakdown of why we are selling more but earning "
+             "less.\n\n"
+             "Analyze revenue, costs, margins, discounts, product mix, customer "
+             "segments, and any other relevant factors. This is likely multi-"
+             "causal -- identify ALL contributing factors and quantify their "
+             "impact. Use the products.cost column to compute margins."
+         ),
+     },
+     "supplier_quality": {
+         "id": "supplier_quality",
+         "difficulty": "medium",
+         "title": "Supplier Quality Crisis Investigation",
+         "description": (
+             "ESCALATION -- The VP of Merchandising has received escalating "
+             "complaints about product quality across multiple SKUs. Quality "
+             "Assurance wants a supplier-level analysis.\n\n"
+             "Determine which supplier(s) have systemic quality issues, which "
+             "of their products are affected, and quantify the total business "
+             "impact in returns, refunds, and support ticket volume. Include "
+             "return rates by supplier to support a contract renegotiation."
+         ),
+     },
+     "inventory_stockout": {
+         "id": "inventory_stockout",
+         "difficulty": "medium-hard",
+         "title": "Regional Sales Underperformance Investigation",
+         "description": (
+             "INVESTIGATION -- Our West region was projected to be the top "
+             "performer during the Spring Mega Sale based on historical trends "
+             "and marketing investment, but actual sales came in significantly "
+             "below the other regions.\n\n"
+             "The Regional VP demands an explanation. Investigate what caused "
+             "the West to underperform during our biggest promotional event. "
+             "Check product-level sales, inventory data, and any operational "
+             "issues that may have limited fulfillment."
+         ),
+     },
+     "fraud_detection": {
+         "id": "fraud_detection",
+         "difficulty": "hard",
+         "title": "Suspicious Order Pattern Investigation",
+         "description": (
+             "ALERT -- The Finance team has flagged a suspicious spike in "
+             "high-value orders from recently created accounts. Several of "
+             "these orders have already shipped.\n\n"
+             "Investigate the pattern: identify the suspicious accounts, "
+             "determine the scope of potential fraud, estimate the financial "
+             "exposure, and describe the behavioral signatures that "
+             "distinguish these accounts from legitimate customers. Look at "
+             "signup dates, order values, product choices, and geographic "
+             "concentration."
+         ),
+     },
+     "repeat_purchase_decline": {
+         "id": "repeat_purchase_decline",
+         "difficulty": "hard",
+         "title": "Customer Retention Crisis Investigation",
+         "description": (
+             "CRITICAL -- Monthly unique buyer count has held steady around "
+             "100, but the Customer Success team reports that repeat purchase "
+             "rates have collapsed. In January, roughly 40% of orders came "
+             "from returning customers; by March, it appears to be under 20%."
+             "\n\n"
+             "The CEO asks: are we becoming a one-time-purchase business? "
+             "Diagnose which customer segments and regions lost repeat buyers, "
+             "identify the root causes, and determine whether our marketing "
+             "spend strategy is masking a retention problem. Check the "
+             "marketing_spend table for clues about acquisition vs. retention "
+             "investment."
+         ),
+     },
+ }
+
+ _GRADERS: dict[str, Callable[[str], float]] = {
+     "orders_drop": _grade_orders_drop,
+     "returns_spike": _grade_returns_spike,
+     "customer_churn": _grade_customer_churn,
+     "shipping_delay": _grade_shipping_delay,
+     "revenue_paradox": _grade_revenue_paradox,
+     "supplier_quality": _grade_supplier_quality,
+     "inventory_stockout": _grade_inventory_stockout,
+     "fraud_detection": _grade_fraud_detection,
+     "repeat_purchase_decline": _grade_repeat_purchase_decline,
+ }
+
+
+ def grade_answer(task_id: str, answer: str) -> float:
+     grader = _GRADERS.get(task_id)
+     if grader is None:
+         return 0.0
+     return grader(answer)
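
Every `_grade_*` function above follows the same additive rubric: five independent keyword or regex checks, each worth 0.20, capped at 1.0. A minimal self-contained sketch of that pattern, with stand-in implementations of the `_has_any` / `_has_pattern` helpers this diff assumes are defined earlier in the module (their exact signatures here are an assumption):

```python
import re


def _has_any(answer: str, phrases: list[str]) -> bool:
    # Case-insensitive substring match against any phrase.
    low = answer.lower()
    return any(p in low for p in phrases)


def _has_pattern(answer: str, pattern: str) -> bool:
    # Case-insensitive regex search anywhere in the answer.
    return re.search(pattern, answer, re.IGNORECASE) is not None


def grade_sketch(answer: str) -> float:
    # Two independent findings, 0.20 each, capped at 1.0 -- the same
    # additive scoring used by the graders above (which have five checks).
    score = 0.0
    if _has_any(answer, ["west"]):        # names the affected region
        score += 0.20
    if _has_pattern(answer, r"\d+\s*%"):  # quantifies the change
        score += 0.20
    return min(score, 1.0)
```

Because the checks are independent, partial answers earn partial credit: an answer naming only the region scores 0.20, one that also quantifies the drop scores 0.40.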