Deploy DataDetective: 9-task business investigation environment
- .env.example +11 -0
- .gitignore +12 -0
- Dockerfile +12 -0
- README.md +131 -3
- __init__.py +9 -0
- client.py +51 -0
- inference.py +280 -0
- models.py +34 -0
- openenv.yaml +40 -0
- pyproject.toml +27 -0
- requirements.txt +6 -0
- server/__init__.py +0 -0
- server/app.py +36 -0
- server/database.py +657 -0
- server/environment.py +192 -0
- server/requirements.txt +6 -0
- server/tasks.py +408 -0
.env.example
ADDED

```
# LLM Configuration (required by hackathon evaluator)
API_BASE_URL=https://router.huggingface.co/v1
MODEL_NAME=gpt-4.1-mini
HF_TOKEN=hf_your_token_here

# Environment server
ENV_URL=http://localhost:7860

# AMD LLM Gateway (local development only — overrides API_BASE_URL when set)
# AMD_LLM_API_KEY=your-ocp-apim-subscription-key-here
# AMD_GATEWAY_BASE=https://llm-api.amd.com/openai
```
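The precedence these variables imply (the AMD gateway overrides `API_BASE_URL` when `AMD_LLM_API_KEY` is set, otherwise the Hugging Face router default applies) can be sketched with a small helper. `select_llm_endpoint` is a hypothetical illustration, not part of the project:

```python
def select_llm_endpoint(env: dict) -> str:
    """Return the LLM base URL implied by the variables above.

    Hypothetical helper mirroring the documented precedence:
    AMD_LLM_API_KEY (local dev) overrides API_BASE_URL; otherwise
    the Hugging Face router default applies.
    """
    if env.get("AMD_LLM_API_KEY"):
        # Local-dev path: talk to the AMD gateway instead.
        return env.get("AMD_GATEWAY_BASE", "https://llm-api.amd.com/openai")
    return env.get("API_BASE_URL", "https://router.huggingface.co/v1")
```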
.gitignore
ADDED

```
__pycache__/
*.pyc
*.pyo
*.egg-info/
dist/
build/
.env
.venv/
venv/
*.sqlite
*.db
.DS_Store
```
Dockerfile
ADDED

```dockerfile
FROM python:3.11-slim

WORKDIR /app

COPY server/requirements.txt ./requirements.txt
RUN pip install --no-cache-dir -r requirements.txt

COPY . .

EXPOSE 7860

CMD ["uvicorn", "server.app:app", "--host", "0.0.0.0", "--port", "7860", "--workers", "2", "--ws-ping-interval", "300", "--ws-ping-timeout", "300"]
```
README.md
CHANGED

````markdown
---
title: DataDetective
emoji: 🔍
colorFrom: blue
colorTo: green
sdk: docker
app_port: 7860
---

# DataDetective — Business Incident Investigation Environment

An [OpenEnv](https://github.com/meta-pytorch/OpenEnv) environment where AI
agents investigate real-world business incidents by querying a SQL database,
analysing patterns, and submitting root-cause findings.

## What It Does

The agent is given a realistic company database (TechMart — a mid-size B2B+B2C
electronics retailer) and a business problem to investigate. It can execute
SQL queries to explore the data, then submit a final written analysis. The
environment automatically grades the analysis based on whether key findings
were identified. Each task has 5 grading criteria worth 0.20 each, enabling
meaningful partial credit.

## Tasks (Easy → Hard)

| # | Task ID | Difficulty | Scenario |
|---|---------|-----------|----------|
| 1 | `orders_drop` | Easy | Order volume dropped sharply after promo ended |
| 2 | `returns_spike` | Medium | Product returns spiking in West region (defective SKU) |
| 3 | `supplier_quality` | Medium | Supplier-level quality crisis across multiple products |
| 4 | `shipping_delay` | Medium-Hard | Customer satisfaction crisis from carrier delays |
| 5 | `inventory_stockout` | Medium-Hard | Regional sales underperformance from warehouse stockout |
| 6 | `customer_churn` | Hard | Active customer decline across segments post price hike |
| 7 | `revenue_paradox` | Hard | Revenue up but profit down — multi-causal margin erosion |
| 8 | `fraud_detection` | Hard | Coordinated fraud ring with fake accounts |
| 9 | `repeat_purchase_decline` | Hard | Repeat purchase collapse masked by acquisition spend |

Each task is scored 0.0 – 1.0 based on specific findings the agent must discover.

## Action / Observation Spaces

### Action (`DataDetectiveAction`)

| Field | Type | Description |
|-------|------|-------------|
| `action_type` | `str` | `"query"` to run SQL, `"answer"` to submit findings |
| `content` | `str` | SQL query string or final analysis text |

### Observation (`DataDetectiveObservation`)

| Field | Type | Description |
|-------|------|-------------|
| `output` | `str` | Query results (formatted table) or feedback |
| `task_description` | `str` | The investigation task |
| `schema_info` | `str` | Database schema (shown at reset) |
| `step_number` | `int` | Current step |
| `max_steps` | `int` | Maximum steps allowed (30) |
| `message` | `str` | Status message |

## Database Schema (11 Tables)

The TechMart database includes:

| Table | Description |
|-------|-------------|
| `customers` | Customer demographics (region, segment, signup date) |
| `products` | Product catalog (category, price, cost, supplier) |
| `orders` | Order history with totals |
| `order_items` | Line items with quantity and unit price |
| `returns` | Product returns with reasons and refund amounts |
| `promotions` | Promotional campaigns with discount percentages |
| `price_changes` | Historical price adjustments |
| `shipping` | Shipment records with carrier and delivery dates |
| `support_tickets` | Customer support tickets by category and priority |
| `inventory_log` | Daily stock levels per product per warehouse region |
| `marketing_spend` | Daily marketing spend by channel, campaign, and region |

All data is synthetic, generated in-memory (no external databases required).

## Quick Start

### 1. Install Dependencies

```bash
pip install -r requirements.txt
```

### 2. Start the Server

```bash
uvicorn server.app:app --host 0.0.0.0 --port 7860
```

### 3. Health Check

```bash
curl http://localhost:7860/health
```

### 4. Run the Baseline Agent

```bash
API_BASE_URL="https://router.huggingface.co/v1" \
MODEL_NAME="gpt-4.1-mini" \
HF_TOKEN="hf_..." \
python inference.py
```

### 5. Docker

```bash
docker build -t data-detective .
docker run -p 7860:7860 data-detective
```

## Environment Variables

| Env Var | Purpose | Required |
|---------|---------|----------|
| `API_BASE_URL` | LLM endpoint URL | Yes |
| `MODEL_NAME` | Model identifier | Yes |
| `HF_TOKEN` | API key / HF token | Yes |
| `ENV_URL` | Environment server URL | No (default: `http://localhost:7860`) |

## How Grading Works

Each task has an automated grader that checks the agent's final answer for
specific key findings (keywords, patterns, named entities). Each task has 5
grading criteria worth 0.20 each, for a maximum score of 1.0. Partial credit
is awarded for each finding discovered.

## Setup Requirements

- Python 3.10+
- No GPU required
- Runs within 2 vCPU / 8 GB memory
- All data is generated in-memory (no external databases)
````
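The keyword-style grading the README describes (5 criteria worth 0.20 each, partial credit per finding) can be illustrated with a minimal stand-alone grader. The criteria strings below are hypothetical, not the environment's real graders:

```python
def grade_answer(answer: str, criteria: list[str], weight: float = 0.20) -> float:
    """Score a final analysis: each satisfied criterion earns `weight`.

    A criterion counts as satisfied here if its keyword appears in the
    answer (case-insensitive). Hypothetical sketch of the real graders,
    which also match patterns and named entities.
    """
    text = answer.lower()
    hits = sum(1 for kw in criteria if kw.lower() in text)
    return round(min(hits * weight, 1.0), 2)


# Hypothetical criteria for a returns-spike investigation:
criteria = ["promotion", "West", "SKU-1042", "refund", "supplier"]
score = grade_answer("Returns spiked in the West after the promotion ended.", criteria)
```

Two of the five criteria match, so the answer earns 0.40 of the possible 1.0, mirroring the partial-credit behaviour described above.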
__init__.py
ADDED

```python
from .models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
from .client import DataDetectiveEnv

__all__ = [
    "DataDetectiveAction",
    "DataDetectiveObservation",
    "DataDetectiveState",
    "DataDetectiveEnv",
]
```
client.py
ADDED

```python
"""WebSocket client for the DataDetective environment."""

from typing import Dict

from openenv.core.env_client import EnvClient
from openenv.core.client_types import StepResult

from .models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState


class DataDetectiveEnv(
    EnvClient[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
):
    """
    Async/sync client for DataDetective.

    Example (sync):
        >>> with DataDetectiveEnv(base_url="http://localhost:7860").sync() as env:
        ...     result = env.reset(task_id="orders_drop")
        ...     result = env.step(DataDetectiveAction(action_type="query", content="SELECT COUNT(*) FROM orders"))
    """

    def _step_payload(self, action: DataDetectiveAction) -> Dict:
        return {"action_type": action.action_type, "content": action.content}

    def _parse_result(self, payload: Dict) -> StepResult[DataDetectiveObservation]:
        obs = payload.get("observation", {})
        observation = DataDetectiveObservation(
            output=obs.get("output", ""),
            task_description=obs.get("task_description", ""),
            schema_info=obs.get("schema_info", ""),
            step_number=obs.get("step_number", 0),
            max_steps=obs.get("max_steps", 30),
            message=obs.get("message", ""),
            done=payload.get("done", False),
            reward=payload.get("reward"),
        )
        return StepResult(
            observation=observation,
            reward=payload.get("reward"),
            done=payload.get("done", False),
        )

    def _parse_state(self, payload: Dict) -> DataDetectiveState:
        return DataDetectiveState(
            episode_id=payload.get("episode_id"),
            step_count=payload.get("step_count", 0),
            task_id=payload.get("task_id", ""),
            queries_executed=payload.get("queries_executed", 0),
            max_steps=payload.get("max_steps", 30),
        )
```
inference.py
ADDED

```python
#!/usr/bin/env python3
"""
Baseline inference script for DataDetective.

Uses an LLM via the OpenAI-compatible API to investigate each task by
running SQL queries and submitting a final analysis.

Required environment variables (set by hackathon evaluator):
    API_BASE_URL — LLM endpoint (e.g. https://router.huggingface.co/v1)
    MODEL_NAME   — model identifier (e.g. gpt-4.1-mini)
    HF_TOKEN     — API key / Hugging Face token

Optional:
    ENV_URL         — DataDetective server URL (default http://localhost:7860)
    AMD_LLM_API_KEY — If set, uses AMD Gateway instead (local dev only)
"""

import asyncio
import json
import os
import re
import sys

from openai import AzureOpenAI, OpenAI

# Patch websockets so long LLM calls don't trip the keepalive timeout.
import websockets.asyncio.client as _wsc

_orig_ws_connect = _wsc.connect


def _patched_connect(*a, **kw):
    kw.setdefault("ping_interval", 300)
    kw.setdefault("ping_timeout", 300)
    return _orig_ws_connect(*a, **kw)


_wsc.connect = _patched_connect

import openenv.core.env_client as _ec

_ec.ws_connect = _patched_connect

from openenv.core.generic_client import GenericEnvClient

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY", "")
AMD_LLM_API_KEY = os.environ.get("AMD_LLM_API_KEY", "")
ENV_URL = os.environ.get("ENV_URL", "http://localhost:7860").rstrip("/")

BENCHMARK = "data_detective"
MAX_STEPS = 20

TASK_IDS = [
    "orders_drop",
    "returns_spike",
    "customer_churn",
    "shipping_delay",
    "revenue_paradox",
    "supplier_quality",
    "inventory_stockout",
    "fraud_detection",
    "repeat_purchase_decline",
]


def _build_llm_client() -> OpenAI:
    if AMD_LLM_API_KEY:
        return AzureOpenAI(
            api_key="dummy",
            api_version="2024-02-01",
            base_url=os.environ.get("AMD_GATEWAY_BASE", "https://llm-api.amd.com/openai"),
            default_headers={"Ocp-Apim-Subscription-Key": AMD_LLM_API_KEY},
        )

    if HF_TOKEN:
        return OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

    print(
        "ERROR: Set HF_TOKEN (or API_KEY) for LLM access, "
        "or AMD_LLM_API_KEY for AMD Gateway. Exiting.",
        file=sys.stderr,
    )
    sys.exit(1)


llm = _build_llm_client()

SYSTEM_PROMPT = """\
You are an expert data analyst investigating a business incident using a
SQL database. You have a LIMITED number of query steps, so be strategic.

At each turn respond with EXACTLY one JSON object (no extra text):

{{"action_type": "query", "content": "<SQL query>"}}
{{"action_type": "answer", "content": "<your analysis>"}}

Investigation strategy:
1. EXPLORE (1-2 queries): List tables and sample key columns to understand
   the schema. Note all available tables -- some may hold critical clues.
2. HYPOTHESISE: Based on the task description, form 2-3 likely root causes.
3. QUERY (targeted): Run focused queries that confirm or reject each
   hypothesis. Use JOINs across tables, GROUP BY with aggregates, and
   compare time periods. Avoid broad SELECT * scans.
4. QUANTIFY: For every finding, gather specific numbers -- counts, totals,
   percentages, before/after comparisons.
5. ANSWER: Submit a thorough analysis naming every root cause with
   supporting evidence. Include specific product names, regions, customer
   segments, suppliers, dollar amounts, dates, and percentages.

You have {max_steps} steps total. Budget roughly 70% for querying and
reserve the last few steps for your answer. Do NOT run out of steps
without submitting -- partial evidence is better than none.
"""

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _extract_json(text: str) -> dict:
    text = text.strip()
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if m:
        try:
            return json.loads(m.group(1))
        except json.JSONDecodeError:
            pass
    m = re.search(r"\{[^{}]*\}", text, re.DOTALL)
    if m:
        try:
            return json.loads(m.group(0))
        except json.JSONDecodeError:
            pass
    return {"action_type": "answer", "content": text}


def _log_start(task_id: str) -> None:
    print(
        f"[START] task={task_id} env={BENCHMARK} model={MODEL_NAME}",
        flush=True,
    )


def _log_step(step: int, action: dict, reward: float, done: bool, error: str | None) -> None:
    action_str = json.dumps(action, separators=(",", ":"))
    error_val = f"'{error}'" if error else "null"
    print(
        f"[STEP] step={step} action={action_str} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_val}",
        flush=True,
    )


def _log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


async def run_task(task_id: str) -> float:
    _log_start(task_id)

    rewards: list[float] = []
    step = 0
    reward = 0.0
    done = False
    success = False
    error_msg = None

    try:
        async with GenericEnvClient(base_url=ENV_URL) as env:
            result = await env.reset(task_id=task_id)
            obs = result.observation

            system = SYSTEM_PROMPT.format(max_steps=MAX_STEPS)
            messages = [
                {"role": "system", "content": system},
                {
                    "role": "user",
                    "content": (
                        f"## Investigation Task\n{obs.get('task_description', '')}\n\n"
                        f"## Database\n{obs.get('schema_info', '')}\n\n"
                        f"You have {MAX_STEPS} steps. Begin your investigation."
                    ),
                },
            ]

            while not done and step < MAX_STEPS:
                try:
                    completion = llm.chat.completions.create(
                        model=MODEL_NAME,
                        messages=messages,
                        temperature=0.1,
                        max_completion_tokens=1024,
                    )
                    llm_text = completion.choices[0].message.content or ""
                except Exception as exc:
                    llm_text = json.dumps({
                        "action_type": "answer",
                        "content": "Unable to complete analysis due to LLM error.",
                    })
                    error_msg = str(exc)

                action = _extract_json(llm_text)
                if "action_type" not in action:
                    action["action_type"] = "query"
                if "content" not in action:
                    action["content"] = llm_text

                result = await env.step(action)
                step += 1
                done = result.done
                reward = result.reward or 0.0
                rewards.append(reward)
                result_obs = result.observation
                remaining = MAX_STEPS - step

                _log_step(step, action, reward, done, error_msg)
                error_msg = None

                messages.append({"role": "assistant", "content": llm_text})

                if not done and remaining <= 3:
                    urgency = (
                        f"URGENT: Only {remaining} step(s) left! "
                        "You MUST submit your final answer NOW using "
                        '{"action_type": "answer", "content": "..."}. '
                        "Summarize ALL findings so far."
                    )
                else:
                    urgency = "Continue investigating or submit your final answer."

                messages.append({
                    "role": "user",
                    "content": (
                        f"Query result:\n{result_obs.get('output', '')}\n\n"
                        f"{result_obs.get('message', '')}\n\n"
                        f"[Step {step}/{MAX_STEPS}] {urgency}"
                    ),
                })

            success = done and reward > 0.0
    except Exception as exc:
        error_msg = str(exc)
        _log_step(step + 1, {"action_type": "error"}, 0.0, False, error_msg)

    score = reward if done else 0.0
    _log_end(success=success, steps=step, score=score, rewards=rewards)
    return score


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


async def amain():
    total = 0.0
    for tid in TASK_IDS:
        try:
            r = await run_task(tid)
        except Exception:
            print("[END] success=false steps=0 score=0.000 rewards=", flush=True)
            r = 0.0
        total += r
    avg = total / len(TASK_IDS) if TASK_IDS else 0
    print(f"\n=== Overall average score: {avg:.2f} ===", flush=True)


def main():
    asyncio.run(amain())


if __name__ == "__main__":
    main()
```
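The `_extract_json` fallback chain in inference.py (raw JSON, then a ```json fence, then the first brace-delimited span, else wrap the whole reply as a final answer) can be exercised in isolation. This is a compact re-statement of that logic for illustration, not the exact function from the file:

```python
import json
import re


def extract_action(text: str) -> dict:
    """Parse an LLM reply into an action dict, mirroring inference.py's
    fallbacks: raw JSON -> ```json fence -> first {...} span -> wrap the
    text as a final answer."""
    text = text.strip()
    candidates = [text]
    m = re.search(r"```(?:json)?\s*(\{.*?\})\s*```", text, re.DOTALL)
    if m:
        candidates.append(m.group(1))
    m = re.search(r"\{[^{}]*\}", text, re.DOTALL)
    if m:
        candidates.append(m.group(0))
    for candidate in candidates:
        try:
            return json.loads(candidate)
        except json.JSONDecodeError:
            continue
    # Nothing parsed: treat the whole reply as the agent's final answer.
    return {"action_type": "answer", "content": text}
```

This ordering means a reply that is pure JSON never pays the regex cost, while chatty replies that wrap JSON in a code fence still parse cleanly.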
models.py
ADDED

```python
from pydantic import Field
from openenv.core.env_server.types import Action, Observation, State


class DataDetectiveAction(Action):
    """Agent action: run a SQL query or submit a final answer."""

    action_type: str = Field(
        ...,
        description="'query' to execute SQL against the database, or 'answer' to submit findings",
    )
    content: str = Field(
        ...,
        description="SQL query string (for action_type='query') or final analysis text (for action_type='answer')",
    )


class DataDetectiveObservation(Observation):
    """Observation returned after each action."""

    output: str = Field(default="", description="Query results or system feedback")
    task_description: str = Field(default="", description="The investigation task to solve")
    schema_info: str = Field(default="", description="Database schema (provided at reset)")
    step_number: int = Field(default=0, description="Current step in the episode")
    max_steps: int = Field(default=30, description="Maximum steps allowed")
    message: str = Field(default="", description="Status or feedback message")


class DataDetectiveState(State):
    """Internal environment state."""

    task_id: str = Field(default="", description="Current task identifier")
    queries_executed: int = Field(default=0, description="Number of SQL queries run so far")
    max_steps: int = Field(default=30, description="Maximum steps allowed")
```
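Stripped of the openenv/pydantic machinery, the action schema above is two required string fields with a constrained `action_type`. A dependency-free sketch of the same validation (the `Action` class below is a hypothetical stand-in, not the project's model):

```python
from dataclasses import dataclass

# The two action types the environment accepts, per the schema above.
VALID_ACTION_TYPES = {"query", "answer"}


@dataclass(frozen=True)
class Action:
    """Dependency-free stand-in for DataDetectiveAction (hypothetical)."""

    action_type: str
    content: str

    def __post_init__(self):
        # Reject anything other than the two documented action types.
        if self.action_type not in VALID_ACTION_TYPES:
            raise ValueError(f"action_type must be one of {VALID_ACTION_TYPES}")
```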
openenv.yaml
ADDED

```yaml
name: data_detective
version: "1.0.0"
description: >
  DataDetective: A business incident investigation environment where AI agents
  use SQL queries to analyze a realistic e-commerce company database (TechMart)
  and uncover root causes of business problems. Covers 9 tasks spanning order
  analysis, product returns, customer churn, shipping ops, margin analysis,
  supplier quality, inventory stockouts, fraud detection, and retention.
endpoints:
  reset: /reset
  step: /step
  state: /state
tasks:
  - id: orders_drop
    difficulty: easy
    description: Order volume dropped sharply after a major promotion ended
  - id: returns_spike
    difficulty: medium
    description: Product returns spiking in a specific region due to defective SKU
  - id: customer_churn
    difficulty: hard
    description: Active customer count declining across specific segments
  - id: shipping_delay
    difficulty: medium-hard
    description: Customer satisfaction crisis driven by carrier delays in one region
  - id: revenue_paradox
    difficulty: hard
    description: Revenue is up but profit is down — multi-causal margin erosion
  - id: supplier_quality
    difficulty: medium
    description: Systemic quality issues from a single supplier across multiple products
  - id: inventory_stockout
    difficulty: medium-hard
    description: Regional sales underperformance caused by warehouse stockout during promo
  - id: fraud_detection
    difficulty: hard
    description: Coordinated fraud ring of fake accounts placing high-value orders
  - id: repeat_purchase_decline
    difficulty: hard
    description: Repeat purchase rates collapsing while acquisition masks the problem
```
pyproject.toml
ADDED

```toml
[project]
name = "data_detective_env"
version = "1.0.0"
description = "DataDetective: Business incident investigation environment for OpenEnv"
readme = "README.md"
requires-python = ">=3.10"
dependencies = [
    "openenv-core>=0.2.0",
    "fastapi>=0.104.0",
    "uvicorn>=0.24.0",
    "pydantic>=2.0.0",
    "websockets>=12.0",
]

[project.optional-dependencies]
dev = ["pytest", "httpx"]
inference = ["openai>=1.0.0"]

[build-system]
requires = ["setuptools>=68.0"]
build-backend = "setuptools.build_meta"

[tool.setuptools.packages.find]
include = ["data_detective_env*"]

[project.scripts]
server = "server.app:main"
```
requirements.txt
ADDED

```
openenv-core>=0.2.0
fastapi>=0.104.0
uvicorn>=0.24.0
pydantic>=2.0.0
websockets>=12.0
openai>=1.0.0
```
server/__init__.py
ADDED (empty file)
server/app.py
ADDED

```python
"""FastAPI application for the DataDetective environment."""

try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:
    raise ImportError(
        "openenv-core is required. pip install openenv-core"
    ) from e

try:
    from ..models import DataDetectiveAction, DataDetectiveObservation
    from .environment import DataDetectiveEnvironment
except (ImportError, ModuleNotFoundError):
    from models import DataDetectiveAction, DataDetectiveObservation
    from server.environment import DataDetectiveEnvironment

app = create_app(
    DataDetectiveEnvironment,
    DataDetectiveAction,
    DataDetectiveObservation,
    env_name="data_detective",
    max_concurrent_envs=10,
)


def main(host: str = "0.0.0.0", port: int = 7860):
    import uvicorn

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--port", type=int, default=7860)
    args = parser.parse_args()
    main(port=args.port)
```
server/database.py
ADDED
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generates a realistic in-memory SQLite database for TechMart, a fictional
|
| 3 |
+
e-commerce company. The data contains deliberate patterns that support
|
| 4 |
+
nine investigation tasks:
|
| 5 |
+
|
| 6 |
+
1. Orders drop after a major promotion ends
|
| 7 |
+
2. Product returns spike for a specific SKU in the West region
|
| 8 |
+
3. Customer churn concentrated in the Enterprise/Northeast segment
|
| 9 |
+
4. Shipping delays by QuickShip in the Midwest driving support tickets
|
| 10 |
+
5. Revenue up but profit down (multi-causal paradox)
|
| 11 |
+
6. Supplier quality crisis (AudioTech products 6 & 7)
|
| 12 |
+
7. Inventory stockout in West for Monitor 27-inch during promo
|
| 13 |
+
8. Coordinated fraud ring in Southeast with new accounts
|
| 14 |
+
9. Repeat purchase decline masked by new-customer acquisition spend
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
import random
|
| 18 |
+
import sqlite3
|
| 19 |
+
from datetime import datetime, timedelta
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
PRODUCTS = [
|
| 23 |
+
(1, "Laptop Pro 15", "Electronics", 999.99, 650.00, "TechCorp"),
|
| 24 |
+
(2, "Desktop Workstation", "Electronics", 1499.99, 950.00, "TechCorp"),
|
| 25 |
+
(3, "Tablet Ultra", "Electronics", 599.99, 350.00, "TechCorp"),
|
| 26 |
+
(4, "Monitor 27-inch", "Electronics", 449.99, 280.00, "DisplayMax"),
|
| 27 |
+
(5, "Smart TV 55-inch", "Electronics", 699.99, 420.00, "DisplayMax"),
|
| 28 |
+
(6, "Wireless Headphones Pro", "Accessories", 149.99, 45.00, "AudioTech"),
|
| 29 |
+
(7, "Bluetooth Speaker", "Accessories", 79.99, 30.00, "AudioTech"),
|
| 30 |
+
(8, "USB-C Hub", "Accessories", 49.99, 15.00, "ConnectPlus"),
|
| 31 |
+
(9, "Laptop Bag Premium", "Accessories", 39.99, 12.00, "CarryAll"),
|
| 32 |
+
(10, "Mouse Pad XL", "Accessories", 24.99, 8.00, "CarryAll"),
|
| 33 |
+
(11, "Office Suite License", "Software", 199.99, 20.00, "SoftVault"),
|
| 34 |
+
(12, "Antivirus Pro Annual", "Software", 49.99, 5.00, "SecureNet"),
|
| 35 |
+
(13, "Cloud Backup 1TB", "Software", 99.99, 10.00, "CloudStore"),
|
| 36 |
+
(14, "Design Studio Pro", "Software", 299.99, 30.00, "CreativeSoft"),
|
| 37 |
+
(15, "DevTools Ultimate", "Software", 149.99, 15.00, "CodeForge"),
|
| 38 |
+
(16, "Mechanical Keyboard RGB", "Peripherals", 129.99, 60.00, "KeyMaster"),
|
| 39 |
+
(17, "Wireless Mouse Pro", "Peripherals", 59.99, 20.00, "ClickTech"),
|
| 40 |
+
(18, "Webcam HD 1080p", "Peripherals", 89.99, 35.00, "VisionCam"),
|
| 41 |
+
(19, "External SSD 1TB", "Peripherals", 109.99, 55.00, "StoragePro"),
|
| 42 |
+
(20, "Laser Printer Pro", "Peripherals", 249.99, 130.00, "PrintMax"),
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
_FIRST = [
|
| 46 |
+
"James","Mary","Robert","Patricia","John","Jennifer","Michael","Linda",
|
| 47 |
+
"David","Elizabeth","William","Barbara","Richard","Susan","Joseph","Jessica",
|
| 48 |
+
"Thomas","Sarah","Christopher","Karen","Charles","Lisa","Daniel","Nancy",
|
| 49 |
+
"Matthew","Betty","Anthony","Margaret","Mark","Sandra","Donald","Ashley",
|
| 50 |
+
"Steven","Dorothy","Andrew","Kimberly","Paul","Emily","Joshua","Donna",
|
| 51 |
+
"Kenneth","Michelle","Kevin","Carol","Brian","Amanda","George","Melissa",
|
| 52 |
+
"Timothy","Deborah",
|
| 53 |
+
]
|
| 54 |
+
|
| 55 |
+
_LAST = [
|
| 56 |
+
"Smith","Johnson","Williams","Brown","Jones","Garcia","Miller","Davis",
|
| 57 |
+
"Rodriguez","Martinez","Hernandez","Lopez","Gonzalez","Wilson","Anderson",
|
| 58 |
+
"Thomas","Taylor","Moore","Jackson","Martin","Lee","Perez","Thompson",
|
| 59 |
+
"White","Harris","Sanchez","Clark","Ramirez","Lewis","Robinson","Walker",
|
| 60 |
+
"Young","Allen","King","Wright","Scott","Torres","Nguyen","Hill","Flores",
|
| 61 |
+
"Green","Adams","Nelson","Baker","Hall","Rivera","Campbell","Mitchell",
|
| 62 |
+
"Carter","Roberts",
|
| 63 |
+
]
|
| 64 |
+
|
| 65 |
+
REGIONS = ["Northeast", "Southeast", "West", "Midwest"]
|
| 66 |
+
|
| 67 |
+
PRICE_CHANGES = [
|
| 68 |
+
(1, 999.99, 1149.99, "2024-02-01", "Annual pricing adjustment"),
|
| 69 |
+
(2, 1499.99, 1699.99, "2024-02-01", "Annual pricing adjustment"),
|
| 70 |
+
(11, 199.99, 229.99, "2024-02-01", "Annual pricing adjustment"),
|
| 71 |
+
(15, 149.99, 174.99, "2024-02-01", "Annual pricing adjustment"),
|
| 72 |
+
(19, 109.99, 129.99, "2024-02-01", "Annual pricing adjustment"),
|
| 73 |
+
]
|
| 74 |
+
|
| 75 |
+
PROMOTIONS = [
|
| 76 |
+
(1, "New Year Kickoff", "2024-01-01", "2024-01-15", 10.0, "All"),
|
| 77 |
+
(2, "Valentine Tech Sale", "2024-02-10", "2024-02-14", 15.0, "Electronics"),
|
| 78 |
+
(3, "Spring Mega Sale", "2024-02-15", "2024-03-01", 25.0, "All"),
|
| 79 |
+
]
|
| 80 |
+
|
| 81 |
+
CARRIERS = ["QuickShip", "FastFreight", "ReliableLogistics"]
|
| 82 |
+
|
| 83 |
+
TICKET_CATEGORIES = ["delivery_delay", "product_defect", "billing_issue", "general_inquiry"]
|
| 84 |
+
|
| 85 |
+
MARKETING_CHANNELS = ["email", "social_media", "search_ads", "display_ads", "affiliate"]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _date_range(start: datetime, end: datetime):
|
| 89 |
+
d = start
|
| 90 |
+
while d <= end:
|
| 91 |
+
yield d
|
| 92 |
+
d += timedelta(days=1)
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def _effective_price(base_prices: dict, changes_by_pid: dict, pid: int, date_str: str):
|
| 96 |
+
"""Return the unit price for *pid* on *date_str*, considering price changes."""
|
| 97 |
+
price = base_prices[pid]
|
| 98 |
+
for new_price, change_date in changes_by_pid.get(pid, []):
|
| 99 |
+
if date_str >= change_date:
|
| 100 |
+
price = new_price
|
| 101 |
+
return price
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def create_database(seed: int = 42) -> sqlite3.Connection:
|
| 105 |
+
rng = random.Random(seed)
|
| 106 |
+
conn = sqlite3.connect(":memory:", check_same_thread=False)
|
| 107 |
+
c = conn.cursor()
|
| 108 |
+
|
| 109 |
+
c.executescript("""
|
| 110 |
+
CREATE TABLE customers (
|
| 111 |
+
customer_id INTEGER PRIMARY KEY,
|
| 112 |
+
name TEXT NOT NULL,
|
| 113 |
+
email TEXT NOT NULL,
|
| 114 |
+
region TEXT NOT NULL,
|
| 115 |
+
segment TEXT NOT NULL,
|
| 116 |
+
signup_date TEXT NOT NULL
|
| 117 |
+
);
|
| 118 |
+
CREATE TABLE products (
|
| 119 |
+
product_id INTEGER PRIMARY KEY,
|
| 120 |
+
name TEXT NOT NULL,
|
| 121 |
+
category TEXT NOT NULL,
|
| 122 |
+
price REAL NOT NULL,
|
| 123 |
+
cost REAL NOT NULL,
|
| 124 |
+
supplier TEXT NOT NULL
|
| 125 |
+
);
|
| 126 |
+
CREATE TABLE orders (
|
| 127 |
+
order_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 128 |
+
customer_id INTEGER NOT NULL,
|
| 129 |
+
order_date TEXT NOT NULL,
|
| 130 |
+
status TEXT NOT NULL DEFAULT 'completed',
|
| 131 |
+
total_amount REAL NOT NULL DEFAULT 0,
|
| 132 |
+
FOREIGN KEY (customer_id) REFERENCES customers(customer_id)
|
| 133 |
+
);
|
| 134 |
+
CREATE TABLE order_items (
|
| 135 |
+
item_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 136 |
+
order_id INTEGER NOT NULL,
|
| 137 |
+
product_id INTEGER NOT NULL,
|
| 138 |
+
quantity INTEGER NOT NULL DEFAULT 1,
|
| 139 |
+
unit_price REAL NOT NULL,
|
| 140 |
+
FOREIGN KEY (order_id) REFERENCES orders(order_id),
|
| 141 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
| 142 |
+
);
|
| 143 |
+
CREATE TABLE returns (
|
| 144 |
+
return_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 145 |
+
order_id INTEGER NOT NULL,
|
| 146 |
+
product_id INTEGER NOT NULL,
|
| 147 |
+
return_date TEXT NOT NULL,
|
| 148 |
+
reason TEXT NOT NULL,
|
| 149 |
+
refund_amount REAL NOT NULL,
|
| 150 |
+
FOREIGN KEY (order_id) REFERENCES orders(order_id),
|
| 151 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
| 152 |
+
);
|
| 153 |
+
CREATE TABLE promotions (
|
| 154 |
+
promo_id INTEGER PRIMARY KEY,
|
| 155 |
+
name TEXT NOT NULL,
|
| 156 |
+
start_date TEXT NOT NULL,
|
| 157 |
+
end_date TEXT NOT NULL,
|
| 158 |
+
discount_pct REAL NOT NULL,
|
| 159 |
+
applicable_category TEXT
|
| 160 |
+
);
|
| 161 |
+
CREATE TABLE price_changes (
|
| 162 |
+
change_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 163 |
+
product_id INTEGER NOT NULL,
|
| 164 |
+
old_price REAL NOT NULL,
|
| 165 |
+
new_price REAL NOT NULL,
|
| 166 |
+
change_date TEXT NOT NULL,
|
| 167 |
+
reason TEXT,
|
| 168 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
| 169 |
+
);
|
| 170 |
+
CREATE TABLE shipping (
|
| 171 |
+
shipment_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 172 |
+
order_id INTEGER NOT NULL,
|
| 173 |
+
carrier TEXT NOT NULL,
|
| 174 |
+
ship_date TEXT NOT NULL,
|
| 175 |
+
delivery_date TEXT NOT NULL,
|
| 176 |
+
status TEXT NOT NULL DEFAULT 'delivered',
|
| 177 |
+
FOREIGN KEY (order_id) REFERENCES orders(order_id)
|
| 178 |
+
);
|
| 179 |
+
CREATE TABLE support_tickets (
|
| 180 |
+
ticket_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 181 |
+
customer_id INTEGER NOT NULL,
|
| 182 |
+
product_id INTEGER,
|
| 183 |
+
created_date TEXT NOT NULL,
|
| 184 |
+
category TEXT NOT NULL,
|
| 185 |
+
priority TEXT NOT NULL DEFAULT 'medium',
|
| 186 |
+
resolution_status TEXT NOT NULL DEFAULT 'open',
|
| 187 |
+
FOREIGN KEY (customer_id) REFERENCES customers(customer_id),
|
| 188 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
| 189 |
+
);
|
| 190 |
+
CREATE TABLE inventory_log (
|
| 191 |
+
log_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 192 |
+
product_id INTEGER NOT NULL,
|
| 193 |
+
log_date TEXT NOT NULL,
|
| 194 |
+
units_in_stock INTEGER NOT NULL,
|
| 195 |
+
units_ordered INTEGER NOT NULL DEFAULT 0,
|
| 196 |
+
warehouse_region TEXT NOT NULL,
|
| 197 |
+
FOREIGN KEY (product_id) REFERENCES products(product_id)
|
| 198 |
+
);
|
| 199 |
+
CREATE TABLE marketing_spend (
|
| 200 |
+
spend_id INTEGER PRIMARY KEY AUTOINCREMENT,
|
| 201 |
+
channel TEXT NOT NULL,
|
| 202 |
+
campaign_name TEXT NOT NULL,
|
| 203 |
+
region TEXT NOT NULL,
|
| 204 |
+
spend_date TEXT NOT NULL,
|
| 205 |
+
amount REAL NOT NULL
|
| 206 |
+
);
|
| 207 |
+
""")
|
| 208 |
+
|
| 209 |
+
c.executemany("INSERT INTO products VALUES (?,?,?,?,?,?)", PRODUCTS)
|
| 210 |
+
base_prices = {p[0]: p[3] for p in PRODUCTS}
|
| 211 |
+
|
| 212 |
+
segments_pool = ["Enterprise"] * 35 + ["SMB"] * 55 + ["Consumer"] * 60
|
| 213 |
+
rng.shuffle(segments_pool)
|
| 214 |
+
customers = []
|
| 215 |
+
for i in range(150):
|
| 216 |
+
first = rng.choice(_FIRST)
|
| 217 |
+
last = rng.choice(_LAST)
|
| 218 |
+
name = f"{first} {last}"
|
| 219 |
+
email = f"{first.lower()}.{last.lower()}{i}@techmart.com"
|
| 220 |
+
region = REGIONS[i % 4]
|
| 221 |
+
segment = segments_pool[i]
|
| 222 |
+
signup = (datetime(2023, 1, 1) + timedelta(days=rng.randint(0, 364))).strftime("%Y-%m-%d")
|
| 223 |
+
c.execute("INSERT INTO customers VALUES (?,?,?,?,?,?)",
|
| 224 |
+
(i + 1, name, email, region, segment, signup))
|
| 225 |
+
customers.append((i + 1, name, email, region, segment, signup))
|
| 226 |
+
|
| 227 |
+
ent_ne = [cu for cu in customers if cu[4] == "Enterprise" and cu[3] == "Northeast"]
|
| 228 |
+
ent_other = [cu for cu in customers if cu[4] == "Enterprise" and cu[3] != "Northeast"]
|
| 229 |
+
smb_all = [cu for cu in customers if cu[4] == "SMB"]
|
| 230 |
+
con_all = [cu for cu in customers if cu[4] == "Consumer"]
|
| 231 |
+
|
| 232 |
+
c.executemany("INSERT INTO promotions VALUES (?,?,?,?,?,?)", PROMOTIONS)
|
| 233 |
+
for pid, old_p, new_p, dt, reason in PRICE_CHANGES:
|
| 234 |
+
c.execute(
|
| 235 |
+
"INSERT INTO price_changes (product_id,old_price,new_price,change_date,reason) VALUES (?,?,?,?,?)",
|
| 236 |
+
(pid, old_p, new_p, dt, reason),
|
| 237 |
+
)
|
| 238 |
+
changes_by_pid: dict[int, list] = {}
|
| 239 |
+
for pid, _, new_p, dt, _ in PRICE_CHANGES:
|
| 240 |
+
changes_by_pid.setdefault(pid, []).append((new_p, dt))
|
| 241 |
+
|
| 242 |
+
START = datetime(2024, 1, 1)
|
| 243 |
+
END = datetime(2024, 3, 15)
|
| 244 |
+
PROMO_S = datetime(2024, 2, 15)
|
| 245 |
+
PROMO_E = datetime(2024, 3, 1)
|
| 246 |
+
PRICE_INC = datetime(2024, 2, 1)
|
| 247 |
+
|
| 248 |
+
product_weights_base = [1.0] * 20
|
| 249 |
+
product_weights_base[5] = 3.0
|
| 250 |
+
|
| 251 |
+
for day in _date_range(START, END):
|
| 252 |
+
date_str = day.strftime("%Y-%m-%d")
|
| 253 |
+
is_promo = PROMO_S <= day <= PROMO_E
|
| 254 |
+
after_price_inc = day >= PRICE_INC
|
| 255 |
+
|
| 256 |
+
daily_count = rng.randint(25, 35) if is_promo else rng.randint(12, 18)
|
| 257 |
+
|
| 258 |
+
for _ in range(daily_count):
|
| 259 |
+
roll = rng.random()
|
| 260 |
+
if roll < 0.08:
|
| 261 |
+
pool = ent_ne
|
| 262 |
+
if after_price_inc and rng.random() < 0.85:
|
| 263 |
+
continue
|
| 264 |
+
elif roll < 0.22:
|
| 265 |
+
pool = ent_other
|
| 266 |
+
if after_price_inc and rng.random() < 0.50:
|
| 267 |
+
continue
|
| 268 |
+
elif roll < 0.55:
|
| 269 |
+
pool = smb_all
|
| 270 |
+
if after_price_inc and rng.random() < 0.20:
|
| 271 |
+
continue
|
| 272 |
+
else:
|
| 273 |
+
pool = con_all
|
| 274 |
+
|
| 275 |
+
cust = rng.choice(pool)
|
| 276 |
+
cust_id, _, _, cust_region, _, _ = cust
|
| 277 |
+
|
| 278 |
+
weights = list(product_weights_base)
|
| 279 |
+
if cust_region == "West":
|
| 280 |
+
weights[5] = 7.0
|
| 281 |
+
if is_promo:
|
| 282 |
+
weights[3] = 0.1
|
| 283 |
+
num_items = rng.choices([1, 2, 3], weights=[0.6, 0.3, 0.1])[0]
|
| 284 |
+
pids = list(set(rng.choices(range(1, 21), weights=weights, k=num_items)))
|
| 285 |
+
|
| 286 |
+
c.execute(
|
| 287 |
+
"INSERT INTO orders (customer_id, order_date, status, total_amount) VALUES (?,?,?,?)",
|
| 288 |
+
(cust_id, date_str, "completed", 0),
|
| 289 |
+
)
|
| 290 |
+
order_id = c.lastrowid
|
| 291 |
+
total = 0.0
|
| 292 |
+
for pid in pids:
|
| 293 |
+
qty = rng.choices([1, 2, 3], weights=[0.75, 0.20, 0.05])[0]
|
| 294 |
+
price = _effective_price(base_prices, changes_by_pid, pid, date_str)
|
| 295 |
+
if is_promo:
|
| 296 |
+
price = round(price * 0.75, 2)
|
| 297 |
+
total += price * qty
|
| 298 |
+
c.execute(
|
| 299 |
+
"INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES (?,?,?,?)",
|
| 300 |
+
(order_id, pid, qty, round(price, 2)),
|
| 301 |
+
)
|
| 302 |
+
c.execute("UPDATE orders SET total_amount=? WHERE order_id=?",
|
| 303 |
+
(round(total, 2), order_id))
|
| 304 |
+
|
| 305 |
+
c.execute("""
|
| 306 |
+
SELECT oi.item_id, oi.order_id, oi.product_id, oi.unit_price, oi.quantity,
|
| 307 |
+
o.order_date, cu.region
|
| 308 |
+
FROM order_items oi
|
| 309 |
+
JOIN orders o ON oi.order_id = o.order_id
|
| 310 |
+
JOIN customers cu ON o.customer_id = cu.customer_id
|
| 311 |
+
""")
|
| 312 |
+
items = c.fetchall()
|
| 313 |
+
|
| 314 |
+
defect_reasons = ["defective_unit", "stopped_working", "poor_audio_quality", "battery_issue"]
|
| 315 |
+
normal_reasons = ["changed_mind", "wrong_size", "found_cheaper", "not_as_expected"]
|
| 316 |
+
|
| 317 |
+
speaker_defect_reasons = ["audio_distortion", "bluetooth_disconnect", "battery_issue", "stopped_working"]
|
| 318 |
+
|
| 319 |
+
for _, order_id, product_id, unit_price, qty, order_date, region in items:
|
| 320 |
+
if product_id == 6 and region == "West":
|
| 321 |
+
prob = 0.38
|
| 322 |
+
reasons = defect_reasons
|
| 323 |
+
elif product_id == 6:
|
| 324 |
+
prob = 0.08
|
| 325 |
+
reasons = defect_reasons + normal_reasons
|
| 326 |
+
elif product_id == 7:
|
| 327 |
+
prob = 0.12
|
| 328 |
+
reasons = speaker_defect_reasons
|
| 329 |
+
else:
|
| 330 |
+
prob = 0.04
|
| 331 |
+
reasons = normal_reasons
|
| 332 |
+
|
| 333 |
+
if rng.random() < prob:
|
| 334 |
+
ret_date = (datetime.strptime(order_date, "%Y-%m-%d")
|
| 335 |
+
+ timedelta(days=rng.randint(3, 14))).strftime("%Y-%m-%d")
|
| 336 |
+
c.execute(
|
| 337 |
+
"INSERT INTO returns (order_id, product_id, return_date, reason, refund_amount) VALUES (?,?,?,?,?)",
|
| 338 |
+
(order_id, product_id, ret_date, rng.choice(reasons), round(unit_price * qty, 2)),
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# -- Shipping records for every order ------------------------------------
|
| 342 |
+
QUICKSHIP_DELAY_START = datetime(2024, 2, 10)
|
| 343 |
+
c.execute("SELECT order_id, order_date, customer_id FROM orders")
|
| 344 |
+
all_orders = c.fetchall()
|
| 345 |
+
cust_region_map = {cu[0]: cu[3] for cu in customers}
|
| 346 |
+
|
| 347 |
+
for order_id, order_date_str, cust_id in all_orders:
|
| 348 |
+
order_dt = datetime.strptime(order_date_str, "%Y-%m-%d")
|
| 349 |
+
region = cust_region_map[cust_id]
|
| 350 |
+
|
| 351 |
+
if region == "Midwest":
|
| 352 |
+
carrier = rng.choices(CARRIERS, weights=[0.40, 0.35, 0.25])[0]
|
| 353 |
+
else:
|
| 354 |
+
carrier = rng.choices(CARRIERS, weights=[0.25, 0.40, 0.35])[0]
|
| 355 |
+
|
| 356 |
+
ship_dt = order_dt + timedelta(days=rng.randint(0, 1))
|
| 357 |
+
base_transit = rng.randint(2, 4)
|
| 358 |
+
|
| 359 |
+
if carrier == "QuickShip" and region == "Midwest" and order_dt >= QUICKSHIP_DELAY_START:
|
| 360 |
+
extra_delay = rng.randint(5, 10)
|
| 361 |
+
status = "delayed"
|
| 362 |
+
elif carrier == "FastFreight":
|
| 363 |
+
extra_delay = rng.randint(0, 2)
|
| 364 |
+
status = "delivered"
|
| 365 |
+
else:
|
| 366 |
+
extra_delay = 0
|
| 367 |
+
status = "delivered"
|
| 368 |
+
|
| 369 |
+
delivery_dt = ship_dt + timedelta(days=base_transit + extra_delay)
|
| 370 |
+
if status == "delayed" and rng.random() < 0.7:
|
| 371 |
+
status = "delivered"
|
| 372 |
+
|
| 373 |
+
c.execute(
|
| 374 |
+
"INSERT INTO shipping (order_id, carrier, ship_date, delivery_date, status) "
|
| 375 |
+
"VALUES (?,?,?,?,?)",
|
| 376 |
+
(order_id, carrier, ship_dt.strftime("%Y-%m-%d"),
|
| 377 |
+
delivery_dt.strftime("%Y-%m-%d"), status),
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
# -- Support tickets -----------------------------------------------------
|
| 381 |
+
ticket_priorities = ["low", "medium", "high", "critical"]
|
| 382 |
+
ticket_resolutions = ["open", "resolved", "escalated"]
|
| 383 |
+
|
| 384 |
+
for day in _date_range(START, END):
|
| 385 |
+
date_str = day.strftime("%Y-%m-%d")
|
| 386 |
+
after_qs_issues = day >= QUICKSHIP_DELAY_START
|
| 387 |
+
|
| 388 |
+
for region_name in REGIONS:
|
| 389 |
+
region_custs = [cu for cu in customers if cu[3] == region_name]
|
| 390 |
+
|
| 391 |
+
# Delivery delay tickets: spike in Midwest after QuickShip issues
|
| 392 |
+
if region_name == "Midwest" and after_qs_issues:
|
| 393 |
+
n_delay = rng.randint(3, 6)
|
| 394 |
+
else:
|
| 395 |
+
n_delay = rng.randint(0, 1)
|
| 396 |
+
|
| 397 |
+
for _ in range(n_delay):
|
| 398 |
+
cu = rng.choice(region_custs)
|
| 399 |
+
pri = rng.choices(ticket_priorities, weights=[0.1, 0.3, 0.4, 0.2])[0]
|
| 400 |
+
res = rng.choices(ticket_resolutions, weights=[0.3, 0.5, 0.2])[0]
|
| 401 |
+
c.execute(
|
| 402 |
+
"INSERT INTO support_tickets "
|
| 403 |
+
"(customer_id, product_id, created_date, category, priority, resolution_status) "
|
| 404 |
+
"VALUES (?,?,?,?,?,?)",
|
| 405 |
+
(cu[0], None, date_str, "delivery_delay", pri, res),
|
| 406 |
+
)
|
| 407 |
+
|
| 408 |
+
# Product defect tickets: elevated for AudioTech products (6 in West, 7 everywhere)
|
| 409 |
+
if region_name == "West":
|
| 410 |
+
n_defect = rng.randint(1, 3)
|
| 411 |
+
else:
|
| 412 |
+
n_defect = 1 if rng.random() < 0.3 else 0
|
| 413 |
+
|
| 414 |
+
for _ in range(n_defect):
|
| 415 |
+
cu = rng.choice(region_custs)
|
| 416 |
+
pid = 6 if region_name == "West" or rng.random() < 0.4 else rng.randint(1, 20)
|
| 417 |
+
pri = rng.choices(ticket_priorities, weights=[0.1, 0.3, 0.4, 0.2])[0]
|
| 418 |
+
res = rng.choices(ticket_resolutions, weights=[0.4, 0.4, 0.2])[0]
|
| 419 |
+
c.execute(
|
| 420 |
+
"INSERT INTO support_tickets "
|
| 421 |
+
"(customer_id, product_id, created_date, category, priority, resolution_status) "
|
| 422 |
+
"VALUES (?,?,?,?,?,?)",
|
| 423 |
+
(cu[0], pid, date_str, "product_defect", pri, res),
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Product 7 (Bluetooth Speaker) defect tickets across all regions
|
| 427 |
+
if rng.random() < 0.45:
|
| 428 |
+
cu = rng.choice(region_custs)
|
| 429 |
+
pri = rng.choices(ticket_priorities, weights=[0.1, 0.4, 0.35, 0.15])[0]
|
| 430 |
+
res = rng.choices(ticket_resolutions, weights=[0.35, 0.45, 0.2])[0]
|
| 431 |
+
c.execute(
|
| 432 |
+
"INSERT INTO support_tickets "
|
| 433 |
+
"(customer_id, product_id, created_date, category, priority, resolution_status) "
|
| 434 |
+
"VALUES (?,?,?,?,?,?)",
|
| 435 |
+
(cu[0], 7, date_str, "product_defect", pri, res),
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# Billing issue tickets: evenly spread (red herring / noise)
|
| 439 |
+
if rng.random() < 0.25:
|
| 440 |
+
cu = rng.choice(region_custs)
|
| 441 |
+
c.execute(
|
| 442 |
+
"INSERT INTO support_tickets "
|
| 443 |
+
"(customer_id, product_id, created_date, category, priority, resolution_status) "
|
| 444 |
+
"VALUES (?,?,?,?,?,?)",
|
| 445 |
+
(cu[0], None, date_str, "billing_issue",
|
| 446 |
+
rng.choice(ticket_priorities), rng.choice(ticket_resolutions)),
|
| 447 |
+
)
|
| 448 |
+
|
| 449 |
+
# General inquiry: background noise
|
| 450 |
+
if rng.random() < 0.35:
|
| 451 |
+
cu = rng.choice(region_custs)
|
| 452 |
+
c.execute(
|
| 453 |
+
"INSERT INTO support_tickets "
|
| 454 |
+
"(customer_id, product_id, created_date, category, priority, resolution_status) "
|
| 455 |
+
"VALUES (?,?,?,?,?,?)",
|
| 456 |
+
(cu[0], None, date_str, "general_inquiry", "low",
|
| 457 |
+
rng.choice(["resolved", "open"])),
|
| 458 |
+
)
|
| 459 |
+
|
| 460 |
+
# -- Inventory log -------------------------------------------------------
|
| 461 |
+
# Daily stock levels per product per warehouse region.
|
| 462 |
+
# Product 4 (Monitor 27-inch) stocks out in West during promo.
|
| 463 |
+
base_stock = {}
|
| 464 |
+
for p in PRODUCTS:
|
| 465 |
+
pid = p[0]
|
| 466 |
+
if pid in (1, 2, 3, 4, 5): # electronics — higher stock
|
| 467 |
+
base_stock[pid] = 200
|
| 468 |
+
elif pid <= 10: # accessories
|
| 469 |
+
base_stock[pid] = 350
|
| 470 |
+
elif pid <= 15: # software (digital)
|
| 471 |
+
base_stock[pid] = 9999
|
| 472 |
+
else: # peripherals
|
| 473 |
+
base_stock[pid] = 250
|
| 474 |
+
|
| 475 |
+
for day in _date_range(START, END):
|
| 476 |
+
date_str = day.strftime("%Y-%m-%d")
|
| 477 |
+
is_promo = PROMO_S <= day <= PROMO_E
|
| 478 |
+
|
| 479 |
+
for region_name in REGIONS:
|
| 480 |
+
for p in PRODUCTS:
|
| 481 |
+
pid = p[0]
|
| 482 |
+
stock = base_stock[pid]
|
| 483 |
+
|
| 484 |
+
daily_sold = rng.randint(2, 8)
|
| 485 |
+
if is_promo:
|
| 486 |
+
daily_sold = rng.randint(5, 15)
|
| 487 |
+
|
| 488 |
+
# Product 4 stockout in West during promo
|
| 489 |
+
if pid == 4 and region_name == "West" and is_promo:
|
| 490 |
+
stock = rng.randint(0, 2)
|
| 491 |
+
daily_sold = rng.randint(0, 1)
|
| 492 |
+
else:
|
| 493 |
+
stock = max(stock - daily_sold + rng.randint(1, 6), 10)
|
| 494 |
+
|
| 495 |
+
# Product 6 fluctuates in West but never stocks out (red herring)
|
| 496 |
+
if pid == 6 and region_name == "West":
|
| 497 |
+
stock = rng.randint(30, 80)
|
| 498 |
+
|
| 499 |
+
reorder = 0
|
| 500 |
+
if stock < 20 and pid <= 15:
|
| 501 |
+
reorder = rng.randint(50, 100)
|
| 502 |
+
|
| 503 |
+
c.execute(
|
| 504 |
+
"INSERT INTO inventory_log "
|
| 505 |
+
"(product_id, log_date, units_in_stock, units_ordered, warehouse_region) "
|
| 506 |
+
"VALUES (?,?,?,?,?)",
|
| 507 |
+
(pid, date_str, stock, reorder, region_name),
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
# -- Fraudulent accounts ---------------------------------------------------
|
| 511 |
+
# ~15 fake accounts in Southeast, Consumer, all signed up late Feb,
|
| 512 |
+
# placing high-value Electronics orders (products 1 & 2).
|
| 513 |
+
fraud_customers = []
|
| 514 |
+
for i in range(15):
|
| 515 |
+
cid = 151 + i
|
| 516 |
+
first = rng.choice(_FIRST)
|
| 517 |
+
last = rng.choice(_LAST)
|
| 518 |
+
name = f"{first} {last}"
|
| 519 |
+
email = f"{first.lower()}.{last.lower()}{cid}@techmart.com"
|
| 520 |
+
signup = (datetime(2024, 2, 20) + timedelta(days=rng.randint(0, 7))).strftime("%Y-%m-%d")
|
| 521 |
+
c.execute("INSERT INTO customers VALUES (?,?,?,?,?,?)",
|
| 522 |
+
(cid, name, email, "Southeast", "Consumer", signup))
|
| 523 |
+
fraud_customers.append(cid)
|
| 524 |
+
customers.append((cid, name, email, "Southeast", "Consumer", signup))
|
| 525 |
+
|
| 526 |
+
cust_region_map.update({cid: "Southeast" for cid in fraud_customers})
|
| 527 |
+
|
| 528 |
+
FRAUD_ORDER_START = datetime(2024, 2, 25)
|
| 529 |
+
FRAUD_ORDER_END = datetime(2024, 3, 10)
|
| 530 |
+
for cid in fraud_customers:
|
| 531 |
+
n_orders = rng.randint(3, 5)
|
| 532 |
+
for _ in range(n_orders):
|
| 533 |
+
order_day = FRAUD_ORDER_START + timedelta(
|
| 534 |
+
days=rng.randint(0, (FRAUD_ORDER_END - FRAUD_ORDER_START).days))
|
| 535 |
+
date_str = order_day.strftime("%Y-%m-%d")
|
| 536 |
+
fraud_pid = rng.choice([1, 2])
|
| 537 |
+
qty = rng.randint(1, 2)
|
| 538 |
+
price = _effective_price(base_prices, changes_by_pid, fraud_pid, date_str)
|
| 539 |
+
is_promo_day = PROMO_S <= order_day <= PROMO_E
|
| 540 |
+
if is_promo_day:
|
| 541 |
+
price = round(price * 0.75, 2)
|
| 542 |
+
total = round(price * qty, 2)
|
| 543 |
+
c.execute(
|
| 544 |
+
"INSERT INTO orders (customer_id, order_date, status, total_amount) VALUES (?,?,?,?)",
|
| 545 |
+
(cid, date_str, "completed", total),
|
| 546 |
+
)
|
| 547 |
+
oid = c.lastrowid
|
| 548 |
+
c.execute(
|
| 549 |
+
"INSERT INTO order_items (order_id, product_id, quantity, unit_price) VALUES (?,?,?,?)",
|
| 550 |
+
(oid, fraud_pid, qty, round(price, 2)),
|
| 551 |
+
)
|
| 552 |
+
ship_dt = order_day + timedelta(days=rng.randint(0, 1))
|
| 553 |
+
delivery_dt = ship_dt + timedelta(days=rng.randint(2, 4))
|
| 554 |
+
c.execute(
|
| 555 |
+
"INSERT INTO shipping (order_id, carrier, ship_date, delivery_date, status) "
|
| 556 |
+
"VALUES (?,?,?,?,?)",
|
| 557 |
+
(oid, "FastFreight", ship_dt.strftime("%Y-%m-%d"),
|
| 558 |
+
delivery_dt.strftime("%Y-%m-%d"), "delivered"),
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# -- Marketing spend -------------------------------------------------------
|
| 562 |
+
# Heavy acquisition spend (search_ads, social_media) in Feb/Mar.
|
| 563 |
+
# Email (retention) drops off after Jan. Southeast gets a big bump in Feb
|
| 564 |
+
# (red herring for fraud task).
|
| 565 |
+
acq_channels = ["search_ads", "social_media", "display_ads", "affiliate"]
|
| 566 |
+
for day in _date_range(START, END):
|
| 567 |
+
date_str = day.strftime("%Y-%m-%d")
|
| 568 |
+
month = day.month
|
| 569 |
+
|
| 570 |
+
for region_name in REGIONS:
|
| 571 |
+
            # Retention channel: email
            if month == 1:
                email_spend = round(rng.uniform(200, 400), 2)
            else:
                email_spend = round(rng.uniform(20, 60), 2)
            c.execute(
                "INSERT INTO marketing_spend (channel, campaign_name, region, spend_date, amount) "
                "VALUES (?,?,?,?,?)",
                ("email", "Customer Retention", region_name, date_str, email_spend),
            )

            # Acquisition channels
            for ch in acq_channels:
                if month == 1:
                    base_spend = rng.uniform(100, 200)
                else:
                    base_spend = rng.uniform(300, 600)

                if region_name == "Southeast" and month >= 2:
                    base_spend *= 1.5

                c.execute(
                    "INSERT INTO marketing_spend (channel, campaign_name, region, spend_date, amount) "
                    "VALUES (?,?,?,?,?)",
                    (ch, "New Customer Acquisition", region_name, date_str,
                     round(base_spend, 2)),
                )

    conn.commit()
    return conn


def get_schema_info(conn: sqlite3.Connection) -> str:
    """Human-readable database schema for the LLM agent."""
    c = conn.cursor()
    c.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
    tables = [r[0] for r in c.fetchall()]

    parts = ["DATABASE SCHEMA", "=" * 50, ""]
    for table in tables:
        c.execute(f"SELECT COUNT(*) FROM {table}")
        count = c.fetchone()[0]
        parts.append(f"Table: {table} ({count} rows)")

        c.execute(f"PRAGMA table_info({table})")
        for col in c.fetchall():
            pk = " [PK]" if col[5] else ""
            parts.append(f"  - {col[1]} {col[2]}{pk}")

        if table == "customers":
            c.execute("SELECT DISTINCT region FROM customers ORDER BY region")
            parts.append(f"  Regions: {', '.join(r[0] for r in c.fetchall())}")
            c.execute("SELECT DISTINCT segment FROM customers ORDER BY segment")
            parts.append(f"  Segments: {', '.join(r[0] for r in c.fetchall())}")
        elif table == "products":
            c.execute("SELECT DISTINCT category FROM products ORDER BY category")
            parts.append(f"  Categories: {', '.join(r[0] for r in c.fetchall())}")
            c.execute("SELECT DISTINCT supplier FROM products ORDER BY supplier")
            parts.append(f"  Suppliers: {', '.join(r[0] for r in c.fetchall())}")
        elif table == "shipping":
            c.execute("SELECT DISTINCT carrier FROM shipping ORDER BY carrier")
            parts.append(f"  Carriers: {', '.join(r[0] for r in c.fetchall())}")
            c.execute("SELECT DISTINCT status FROM shipping ORDER BY status")
            parts.append(f"  Statuses: {', '.join(r[0] for r in c.fetchall())}")
        elif table == "support_tickets":
            c.execute("SELECT DISTINCT category FROM support_tickets ORDER BY category")
            parts.append(f"  Categories: {', '.join(r[0] for r in c.fetchall())}")
            c.execute("SELECT DISTINCT priority FROM support_tickets ORDER BY priority")
            parts.append(f"  Priorities: {', '.join(r[0] for r in c.fetchall())}")
        elif table == "inventory_log":
            c.execute("SELECT DISTINCT warehouse_region FROM inventory_log ORDER BY warehouse_region")
            parts.append(f"  Warehouse regions: {', '.join(r[0] for r in c.fetchall())}")
        elif table == "marketing_spend":
            c.execute("SELECT DISTINCT channel FROM marketing_spend ORDER BY channel")
            parts.append(f"  Channels: {', '.join(r[0] for r in c.fetchall())}")
            c.execute("SELECT DISTINCT campaign_name FROM marketing_spend ORDER BY campaign_name")
            parts.append(f"  Campaigns: {', '.join(r[0] for r in c.fetchall())}")

        parts.append("")

    parts += [
        "=" * 50,
        "Data spans: 2024-01-01 to 2024-03-15",
        "All dates stored as YYYY-MM-DD text.",
        "Use standard SQLite functions (strftime, date, etc.).",
    ]
    return "\n".join(parts)
server/environment.py
ADDED
@@ -0,0 +1,192 @@
"""Core environment logic for DataDetective."""

import random
import uuid
from typing import Any, Optional

from openenv.core.env_server import Environment

try:
    from ..models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
    from .database import create_database, get_schema_info
    from .tasks import TASKS, grade_answer
except (ImportError, ModuleNotFoundError):
    from models import DataDetectiveAction, DataDetectiveObservation, DataDetectiveState
    from server.database import create_database, get_schema_info
    from server.tasks import TASKS, grade_answer


class DataDetectiveEnvironment(
    Environment[DataDetectiveAction, DataDetectiveObservation, DataDetectiveState]
):
    SUPPORTS_CONCURRENT_SESSIONS = True
    MAX_STEPS = 30

    def __init__(self):
        super().__init__()
        self._db = None
        self._task_id: str = ""
        self._step_count: int = 0
        self._episode_id: str = ""
        self._queries_executed: int = 0
        self._state = DataDetectiveState()

    def reset(
        self,
        seed: Optional[int] = None,
        episode_id: Optional[str] = None,
        task_id: Optional[str] = None,
        **kwargs: Any,
    ) -> DataDetectiveObservation:
        if seed is not None:
            random.seed(seed)

        self._episode_id = episode_id or str(uuid.uuid4())
        self._task_id = task_id if task_id in TASKS else random.choice(list(TASKS))
        self._step_count = 0
        self._queries_executed = 0

        if self._db is not None:
            self._db.close()
        self._db = create_database()

        task = TASKS[self._task_id]
        schema = get_schema_info(self._db)

        self._state = DataDetectiveState(
            episode_id=self._episode_id,
            step_count=0,
            task_id=self._task_id,
            queries_executed=0,
            max_steps=self.MAX_STEPS,
        )

        return DataDetectiveObservation(
            done=False,
            reward=None,
            output="Environment ready. Run SQL queries to investigate the issue, then submit your answer.",
            task_description=task["description"],
            schema_info=schema,
            step_number=0,
            max_steps=self.MAX_STEPS,
            message=f"Investigation: {task['title']} [{task['difficulty'].upper()}] -- {self.MAX_STEPS} steps available.",
        )

    def step(
        self,
        action: DataDetectiveAction,
        timeout_s: Optional[float] = None,
        **kwargs: Any,
    ) -> DataDetectiveObservation:
        self._step_count += 1
        self._state.step_count = self._step_count

        remaining = self.MAX_STEPS - self._step_count

        if self._step_count > self.MAX_STEPS:
            return self._obs(
                done=True, reward=0.0,
                output="Maximum steps reached -- investigation ended with no answer submitted.",
                message="Out of steps.",
            )

        atype = (action.action_type or "").strip().lower()

        if atype == "query":
            return self._handle_query(action.content, remaining)
        elif atype == "answer":
            return self._handle_answer(action.content)
        else:
            return self._obs(
                done=False, reward=0.0,
                output="",
                message=f"Unknown action_type '{action.action_type}'. Use 'query' or 'answer'. ({remaining} steps left)",
            )

    @property
    def state(self) -> DataDetectiveState:
        return self._state

    def close(self) -> None:
        if self._db is not None:
            self._db.close()
            self._db = None

    def _obs(self, *, done: bool, reward: float | None, output: str, message: str) -> DataDetectiveObservation:
        return DataDetectiveObservation(
            done=done,
            reward=reward,
            output=output,
            task_description=TASKS[self._task_id]["description"],
            schema_info="",
            step_number=self._step_count,
            max_steps=self.MAX_STEPS,
            message=message,
        )

    def _handle_query(self, sql: str, remaining: int) -> DataDetectiveObservation:
        self._queries_executed += 1
        self._state.queries_executed = self._queries_executed

        if not sql or not sql.strip():
            return self._obs(
                done=False, reward=0.0,
                output="Empty query -- please provide a valid SQL statement.",
                message=f"{remaining} steps left.",
            )

        try:
            cur = self._db.cursor()
            cur.execute(sql)
            columns = [d[0] for d in cur.description] if cur.description else []
            rows = cur.fetchall()
            output = _format_table(columns, rows) if rows else "Query returned 0 rows."
        except Exception as exc:
            output = f"SQL Error: {exc}"
            return self._obs(
                done=False, reward=0.0,
                output=output,
                message=f"Query failed. Fix your SQL and retry. ({remaining} steps left)",
            )

        return self._obs(
            done=False, reward=0.0,
            output=output,
            message=f"{len(rows)} row(s) returned. ({remaining} steps left)",
        )

    def _handle_answer(self, answer_text: str) -> DataDetectiveObservation:
        reward = grade_answer(self._task_id, answer_text)
        if reward >= 0.8:
            verdict = "Excellent investigation!"
        elif reward >= 0.5:
            verdict = "Good findings, but some details missing."
        else:
            verdict = "Several key findings were missed."

        return self._obs(
            done=True,
            reward=reward,
            output=f"Score: {reward:.2f} / 1.00 -- {verdict}",
            message=f"Investigation complete. Final score: {reward:.2f}",
        )


def _format_table(columns: list[str], rows: list, max_rows: int = 100) -> str:
    truncated = len(rows) > max_rows
    display = rows[:max_rows]

    widths = [len(str(c)) for c in columns]
    for row in display:
        for i, v in enumerate(row):
            widths[i] = max(widths[i], min(len(str(v)), 60))

    header = " | ".join(str(c).ljust(widths[i]) for i, c in enumerate(columns))
    sep = "-+-".join("-" * w for w in widths)
    lines = [header, sep]
    for row in display:
        lines.append(" | ".join(str(v).ljust(widths[i])[:60] for i, v in enumerate(row)))

    if truncated:
        lines.append(f"... (showing {max_rows} of {len(rows)} rows)")
    return "\n".join(lines)
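The `_format_table` helper sizes each column to its widest value (capped at 60 characters) and appends a truncation notice past `max_rows`. That logic is pure Python and can be checked in isolation; a standalone copy with made-up sample rows:

```python
def format_table(columns, rows, max_rows=100):
    """Standalone copy of _format_table's column-sizing and truncation logic."""
    truncated = len(rows) > max_rows
    display = rows[:max_rows]

    # Each column is at least as wide as its header, at most 60 chars.
    widths = [len(str(c)) for c in columns]
    for row in display:
        for i, v in enumerate(row):
            widths[i] = max(widths[i], min(len(str(v)), 60))

    header = " | ".join(str(c).ljust(widths[i]) for i, c in enumerate(columns))
    sep = "-+-".join("-" * w for w in widths)
    lines = [header, sep]
    for row in display:
        lines.append(" | ".join(str(v).ljust(widths[i])[:60] for i, v in enumerate(row)))

    if truncated:
        lines.append(f"... (showing {max_rows} of {len(rows)} rows)")
    return "\n".join(lines)

print(format_table(["region", "orders"], [("West", 42), ("Midwest", 7)]))
```

Note the truncation path: with 150 single-column rows and `max_rows=100`, the last line reads `... (showing 100 of 150 rows)`.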
server/requirements.txt
ADDED
@@ -0,0 +1,6 @@
openenv-core>=0.2.0
fastapi>=0.104.0
uvicorn>=0.24.0
pydantic>=2.0.0
websockets>=12.0
openai>=1.0.0
server/tasks.py
ADDED
@@ -0,0 +1,408 @@
"""
Task definitions and automated graders for the DataDetective environment.

Each task has:
- id, title, difficulty, description
- A grader function that scores the agent's final answer (0.0 - 1.0)
  based on whether key findings are mentioned.
"""

import re
from typing import Callable


def _has_any(text: str, keywords: list[str]) -> bool:
    """Case-insensitive check: does *text* contain any of *keywords*?"""
    low = text.lower()
    return any(kw.lower() in low for kw in keywords)


def _has_pattern(text: str, pattern: str) -> bool:
    return bool(re.search(pattern, text, re.IGNORECASE))


def _grade_orders_drop(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["drop", "decrease", "decline", "fell", "fewer", "reduction", "lower"]):
        score += 0.20
    if _has_any(answer, ["spring mega sale", "spring sale", "mega sale"]) or (
        _has_any(answer, ["promotion", "promo", "sale", "discount", "campaign"])
    ):
        score += 0.20
    if _has_any(answer, ["ended", "expired", "over", "concluded", "stopped"]) or _has_pattern(
        answer, r"march\s*0?1"
    ):
        score += 0.20
    if _has_any(answer, [
        "caused", "because", "due to", "result of", "led to",
        "when the", "after the", "ending of", "end of the",
        "correlated", "explains",
    ]):
        score += 0.20
    if _has_pattern(answer, r"\d+\s*(orders|transactions)") or _has_pattern(
        answer, r"\d+\s*%"
    ) or _has_pattern(answer, r"from\s+\d+.*to\s+\d+"):
        score += 0.20
    return min(score, 1.0)


def _grade_returns_spike(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["wireless headphones", "headphones pro", "headphone"]):
        score += 0.20
    if _has_any(answer, ["west"]):
        score += 0.20
    if _has_any(answer, ["audiotech", "audio tech"]):
        score += 0.20
    if _has_any(answer, [
        "defect", "defective", "faulty", "quality",
        "high return", "return rate", "abnormal",
        "stopped working", "battery issue", "poor audio",
    ]):
        score += 0.20
    if _has_pattern(answer, r"\d+\s*%") or _has_pattern(
        answer, r"\d+\s*(returns|returned|units)"
    ) or _has_any(answer, ["return rate", "compared to"]):
        score += 0.20
    return min(score, 1.0)


def _grade_customer_churn(answer: str) -> float:
    score = 0.0
    if _has_pattern(answer, r"\d+\s*%") or _has_any(answer, [
        "decline", "decrease", "drop", "churn", "fewer active",
        "lost customers", "stopped ordering",
    ]):
        score += 0.20
    if _has_any(answer, ["enterprise"]):
        score += 0.20
    if _has_any(answer, ["northeast", "north east", "north-east"]):
        score += 0.20
    if _has_any(answer, [
        "price increase", "price change", "price hike", "pricing",
        "more expensive", "raised price", "cost increase",
    ]):
        score += 0.20
    if _has_any(answer, [
        "laptop pro", "desktop workstation", "office suite",
        "devtools", "external ssd",
    ]) or _has_pattern(answer, r"product.*(1|2|11|15|19)"):
        score += 0.20
    return min(score, 1.0)


def _grade_shipping_delay(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["midwest"]):
        score += 0.20
    if _has_any(answer, ["quickship", "quick ship"]):
        score += 0.20
    if _has_any(answer, [
        "delivery delay", "late delivery", "delayed shipment",
        "shipping delay", "late shipment", "delivery time",
        "delayed delivery", "slow delivery",
    ]):
        score += 0.20
    if _has_pattern(answer, r"feb(ruary)?\s*(10|mid|middle)") or _has_any(answer, [
        "mid-february", "mid february", "around february",
        "starting in february", "beginning of february",
    ]):
        score += 0.20
    if _has_any(answer, [
        "support ticket", "complaint", "ticket volume",
        "customer satisfaction", "support request",
    ]) and _has_any(answer, [
        "delivery", "shipping", "carrier", "quickship",
    ]):
        score += 0.20
    return min(score, 1.0)


def _grade_revenue_paradox(answer: str) -> float:
    score = 0.0
    if _has_any(answer, [
        "spring mega sale", "mega sale", "25%", "25 percent",
    ]) or (
        _has_any(answer, ["promotion", "promo", "discount", "sale"])
        and _has_any(answer, ["margin", "profit", "cost"])
    ):
        score += 0.20
    if _has_any(answer, [
        "product mix", "category mix", "mix shift", "shifted toward",
        "higher proportion", "more electronics", "low-margin",
        "composition changed",
    ]):
        score += 0.20
    if _has_any(answer, ["enterprise"]) and _has_any(answer, [
        "price increase", "price change", "price hike",
        "lost", "churn", "left", "fewer", "decline",
    ]):
        score += 0.20
    if _has_any(answer, ["return", "refund"]) and _has_any(answer, [
        "cost", "expense", "profit", "margin", "loss", "erode",
    ]):
        score += 0.20
    if _has_pattern(answer, r"\$\s*[\d,]+") or _has_pattern(
        answer, r"\d+\s*%"
    ) or _has_pattern(answer, r"from\s+\$?[\d,]+.*to\s+\$?[\d,]+"):
        score += 0.20
    return min(score, 1.0)


def _grade_supplier_quality(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["audiotech", "audio tech"]):
        score += 0.20
    if _has_any(answer, ["wireless headphones", "headphones pro", "product 6"]):
        score += 0.20
    if _has_any(answer, ["bluetooth speaker", "product 7"]):
        score += 0.20
    if _has_any(answer, ["return rate", "refund", "return volume"]) or _has_pattern(
        answer, r"\d+\s*%.*return"
    ) or _has_pattern(answer, r"return.*\d+\s*%") or _has_pattern(
        answer, r"\$\s*[\d,]+"
    ):
        score += 0.20
    if _has_any(answer, [
        "support ticket", "defect", "complaint", "product_defect",
        "quality issue", "customer complaint",
    ]):
        score += 0.20
    return min(score, 1.0)


def _grade_inventory_stockout(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["west"]):
        score += 0.20
    if _has_any(answer, ["monitor", "product 4", "monitor 27"]):
        score += 0.20
    if _has_any(answer, [
        "inventory", "stock", "out of stock", "stockout", "stock-out",
        "zero units", "no inventory", "warehouse",
    ]):
        score += 0.20
    if _has_any(answer, [
        "spring mega sale", "mega sale", "promo", "promotion",
        "february 15", "feb 15", "during the sale",
    ]):
        score += 0.20
    if _has_pattern(answer, r"\d+\s*(units|orders|sales)") or _has_pattern(
        answer, r"\d+\s*%"
    ) or _has_pattern(answer, r"from\s+\d+.*to\s+\d+"):
        score += 0.20
    return min(score, 1.0)


def _grade_fraud_detection(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["southeast"]):
        score += 0.20
    if _has_any(answer, [
        "new account", "recent signup", "recently created",
        "new customer", "account creation", "registered in feb",
        "signed up",
    ]):
        score += 0.20
    if _has_any(answer, [
        "high-value", "high value", "expensive", "laptop pro",
        "desktop workstation", "large order", "electronics",
    ]):
        score += 0.20
    if _has_pattern(answer, r"1[0-5]\s*(account|customer|user)") or _has_pattern(
        answer, r"\$\s*[\d,]+"
    ) or _has_pattern(answer, r"\d+\s*(order|transaction)"):
        score += 0.20
    if _has_any(answer, [
        "pattern", "cluster", "coordinated", "suspicious",
        "same product", "no return", "never returned",
        "concentrated", "anomal", "fraud ring",
    ]):
        score += 0.20
    return min(score, 1.0)


def _grade_repeat_purchase_decline(answer: str) -> float:
    score = 0.0
    if _has_any(answer, [
        "repeat purchase", "repeat rate", "returning customer",
        "repeat buyer", "repurchase", "order frequency",
        "second order", "came back",
    ]) and (_has_pattern(answer, r"\d+\s*%") or _has_any(answer, [
        "decline", "drop", "decrease", "fell", "collapsed",
    ])):
        score += 0.20
    if _has_any(answer, ["enterprise"]) and _has_any(answer, [
        "price", "increase", "hike", "stopped", "left", "churn",
    ]):
        score += 0.20
    if (_has_any(answer, ["midwest"]) or _has_any(answer, [
        "shipping", "delivery", "quickship",
    ])) and _has_any(answer, [
        "repeat", "return", "reorder", "come back", "second order",
    ]):
        score += 0.20
    if _has_any(answer, ["marketing", "acquisition", "spend"]) and _has_any(answer, [
        "retention", "email", "loyalty", "re-engage", "lapsed",
        "shifted", "new customer",
    ]):
        score += 0.20
    if _has_any(answer, [
        "segment", "cohort", "by region", "by segment",
        "enterprise vs", "consumer vs", "smb vs",
    ]) or _has_pattern(answer, r"(enterprise|smb|consumer).*\d+\s*%"):
        score += 0.20
    return min(score, 1.0)


TASKS: dict[str, dict] = {
    "orders_drop": {
        "id": "orders_drop",
        "difficulty": "easy",
        "title": "Weekly Orders Drop Investigation",
        "description": (
            "URGENT -- Our order volume dropped sharply in the first two weeks "
            "of March compared to the last two weeks of February. Leadership "
            "needs to know why.\n\n"
            "Investigate the database, identify the root cause of the drop, "
            "and submit a clear summary of your findings."
        ),
    },
    "returns_spike": {
        "id": "returns_spike",
        "difficulty": "medium",
        "title": "Product Returns Spike Investigation",
        "description": (
            "ALERT -- Our return rate has spiked significantly in recent weeks, "
            "with particular concentration in one geographic region. This is "
            "eating into margins.\n\n"
            "Use the database to identify which product(s) are driving the "
            "spike, which region is most affected, and what the likely root "
            "cause is. Include the supplier if relevant."
        ),
    },
    "customer_churn": {
        "id": "customer_churn",
        "difficulty": "hard",
        "title": "Customer Churn Root Cause Analysis",
        "description": (
            "CRITICAL -- Our monthly active customer count has declined "
            "significantly from January to March. The executive team wants a "
            "full root-cause analysis.\n\n"
            "Determine which customer segments and regions are most affected, "
            "quantify the decline, and identify the most likely causes. "
            "Check all available tables for clues."
        ),
    },
    "shipping_delay": {
        "id": "shipping_delay",
        "difficulty": "medium-hard",
        "title": "Customer Satisfaction Crisis Investigation",
        "description": (
            "ESCALATION -- Customer satisfaction scores have plummeted in one "
            "of our regions. The support team is overwhelmed with complaints "
            "and escalations are piling up.\n\n"
            "Investigate what operational issue is driving the complaints, "
            "identify the responsible party (carrier, warehouse, etc.), "
            "determine when the problem started, and quantify the impact. "
            "Cross-reference multiple data sources for a complete picture."
        ),
    },
    "revenue_paradox": {
        "id": "revenue_paradox",
        "difficulty": "hard",
        "title": "Revenue vs. Profit Paradox Investigation",
        "description": (
            "CRITICAL -- Revenue in February was our highest month ever, yet "
            "gross profit actually *decreased* compared to January. The CFO "
            "wants a full breakdown of why we are selling more but earning "
            "less.\n\n"
            "Analyze revenue, costs, margins, discounts, product mix, customer "
            "segments, and any other relevant factors. This is likely multi-"
            "causal -- identify ALL contributing factors and quantify their "
            "impact. Use the products.cost column to compute margins."
        ),
    },
    "supplier_quality": {
        "id": "supplier_quality",
        "difficulty": "medium",
        "title": "Supplier Quality Crisis Investigation",
        "description": (
            "ESCALATION -- The VP of Merchandising has received escalating "
            "complaints about product quality across multiple SKUs. Quality "
            "Assurance wants a supplier-level analysis.\n\n"
            "Determine which supplier(s) have systemic quality issues, which "
            "of their products are affected, and quantify the total business "
            "impact in returns, refunds, and support ticket volume. Include "
            "return rates by supplier to support a contract renegotiation."
        ),
    },
    "inventory_stockout": {
        "id": "inventory_stockout",
        "difficulty": "medium-hard",
        "title": "Regional Sales Underperformance Investigation",
        "description": (
            "INVESTIGATION -- Our West region was projected to be the top "
            "performer during the Spring Mega Sale based on historical trends "
            "and marketing investment, but actual sales came in significantly "
            "below the other regions.\n\n"
            "The Regional VP demands an explanation. Investigate what caused "
            "the West to underperform during our biggest promotional event. "
            "Check product-level sales, inventory data, and any operational "
            "issues that may have limited fulfillment."
        ),
    },
    "fraud_detection": {
        "id": "fraud_detection",
        "difficulty": "hard",
        "title": "Suspicious Order Pattern Investigation",
        "description": (
            "ALERT -- The Finance team has flagged a suspicious spike in "
            "high-value orders from recently created accounts. Several of "
            "these orders have already shipped.\n\n"
            "Investigate the pattern: identify the suspicious accounts, "
            "determine the scope of potential fraud, estimate the financial "
            "exposure, and describe the behavioral signatures that "
            "distinguish these accounts from legitimate customers. Look at "
            "signup dates, order values, product choices, and geographic "
            "concentration."
        ),
    },
    "repeat_purchase_decline": {
        "id": "repeat_purchase_decline",
        "difficulty": "hard",
        "title": "Customer Retention Crisis Investigation",
        "description": (
            "CRITICAL -- Monthly unique buyer count has held steady around "
            "100, but the Customer Success team reports that repeat purchase "
            "rates have collapsed. In January, roughly 40% of orders came "
            "from returning customers; by March, it appears to be under 20%."
            "\n\n"
            "The CEO asks: are we becoming a one-time-purchase business? "
            "Diagnose which customer segments and regions lost repeat buyers, "
            "identify the root causes, and determine whether our marketing "
            "spend strategy is masking a retention problem. Check the "
            "marketing_spend table for clues about acquisition vs. retention "
            "investment."
        ),
    },
}

_GRADERS: dict[str, Callable[[str], float]] = {
    "orders_drop": _grade_orders_drop,
    "returns_spike": _grade_returns_spike,
    "customer_churn": _grade_customer_churn,
    "shipping_delay": _grade_shipping_delay,
    "revenue_paradox": _grade_revenue_paradox,
    "supplier_quality": _grade_supplier_quality,
    "inventory_stockout": _grade_inventory_stockout,
    "fraud_detection": _grade_fraud_detection,
    "repeat_purchase_decline": _grade_repeat_purchase_decline,
}


def grade_answer(task_id: str, answer: str) -> float:
    grader = _GRADERS.get(task_id)
    if grader is None:
        return 0.0
    return grader(answer)
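All nine graders share one shape: a handful of independent keyword/regex checks, each worth 0.20, capped at 1.0. A self-contained sketch of that pattern with a two-check rubric; the keywords here are illustrative, not any task's real rubric:

```python
import re

def _has_any(text: str, keywords: list[str]) -> bool:
    """Case-insensitive substring check against a keyword list."""
    low = text.lower()
    return any(kw.lower() in low for kw in keywords)

def _has_pattern(text: str, pattern: str) -> bool:
    return bool(re.search(pattern, text, re.IGNORECASE))

def grade(answer: str) -> float:
    score = 0.0
    if _has_any(answer, ["drop", "decline"]):   # finding is named
        score += 0.20
    if _has_pattern(answer, r"\d+\s*%"):        # finding is quantified
        score += 0.20
    return min(score, 1.0)

print(grade("Orders declined by 30% after the sale ended"))  # 0.4: both checks hit
print(grade("no relevant findings"))                          # 0.0
```

Because the checks are independent, a partially correct answer earns partial credit, and the `min(..., 1.0)` cap keeps overlapping checks from exceeding a perfect score.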