Spaces:
Sleeping
Sleeping
| import sqlite3 | |
| import random | |
| from datetime import datetime, timedelta | |
| from typing import Optional, Any | |
| from faker import Faker | |
# Shared Faker instance used by the seeding helpers to generate emails,
# country names, product names and URI paths.
fake = Faker()

# Number of rows to generate per table.
# NOTE(review): the "order_items" entry does not appear to be read by the
# seeding code in this module (items per order are drawn at random) —
# confirm before relying on it.
SEED_CONFIG = dict(
    users=500,
    products=80,
    orders=2000,
    order_items=5000,
    events=8000,
)

# Allowed values for categorical columns. PLAN_TYPES and ORDER_STATUSES must
# stay in sync with the CHECK constraints in create_database().
CATEGORIES = [
    "Electronics",
    "Clothing",
    "Books",
    "Home & Garden",
    "Sports",
]
PLAN_TYPES = ["free", "pro", "enterprise"]
# Positional order matters: order seeding passes weights aligned to this list.
ORDER_STATUSES = ["pending", "completed", "refunded"]
EVENT_TYPES = [
    "page_view",
    "add_to_cart",
    "checkout",
    "login",
    "logout",
]
def create_database(db_path: str = ":memory:") -> sqlite3.Connection:
    """Open a SQLite connection at *db_path* and create the empty schema.

    Creates the ``users``, ``products``, ``orders``, ``order_items`` and
    ``events`` tables. The connection's ``row_factory`` is set to
    :class:`sqlite3.Row` so result columns can be read by name.

    Args:
        db_path: Filesystem path of the database file, or ``":memory:"``
            (the default) for a throwaway in-memory database.

    Returns:
        The open connection, with the schema committed.

    Raises:
        sqlite3.OperationalError: If one of the tables already exists at
            *db_path* (the statements do not use ``IF NOT EXISTS``).
    """
    ddl_statements = (
        """
        CREATE TABLE users (
            id INTEGER PRIMARY KEY,
            email TEXT NOT NULL,
            country TEXT,
            plan TEXT CHECK(plan IN ('free', 'pro', 'enterprise')),
            created_at TIMESTAMP NOT NULL,
            churned_at TIMESTAMP
        )
        """,
        """
        CREATE TABLE products (
            id INTEGER PRIMARY KEY,
            name TEXT NOT NULL,
            category TEXT NOT NULL,
            price REAL,
            cost REAL
        )
        """,
        """
        CREATE TABLE orders (
            id INTEGER PRIMARY KEY,
            user_id INTEGER REFERENCES users(id),
            created_at TIMESTAMP NOT NULL,
            status TEXT CHECK(status IN ('pending', 'completed', 'refunded')),
            total REAL
        )
        """,
        """
        CREATE TABLE order_items (
            id INTEGER PRIMARY KEY,
            order_id INTEGER REFERENCES orders(id),
            product_id INTEGER REFERENCES products(id),
            qty INTEGER NOT NULL,
            unit_price REAL
        )
        """,
        """
        CREATE TABLE events (
            id INTEGER PRIMARY KEY,
            user_id INTEGER REFERENCES users(id),
            event_type TEXT,
            metadata TEXT,
            ts TIMESTAMP NOT NULL
        )
        """,
    )

    connection = sqlite3.connect(db_path)
    connection.row_factory = sqlite3.Row
    for statement in ddl_statements:
        connection.execute(statement)
    connection.commit()
    return connection
def seed_database(conn: sqlite3.Connection) -> None:
    """Fill a freshly created schema with synthetic data.

    Users and products are seeded first because orders reference them;
    orders come next, and events are seeded last from the user list.

    Args:
        conn: Connection whose schema was built by ``create_database``.
    """
    seeded_users = _seed_users(conn)
    seeded_products = _seed_products(conn)
    seeded_orders, _ = _seed_orders(conn, seeded_users, seeded_products)
    _seed_events(conn, seeded_users, seeded_orders)
def _seed_users(conn: sqlite3.Connection) -> list:
    """Insert ``SEED_CONFIG["users"]`` synthetic users and commit.

    About 30% of users get a signup date within the last 30 days; the rest
    are spread over the last 180 days. ``country`` is NULL for roughly 80%
    of rows (one real country among five choices). Only free-plan users can
    churn (up to ~10% of them), at least 30 days after signup and never
    later than the current time.

    Args:
        conn: Connection whose ``users`` table was created by
            ``create_database``.

    Returns:
        A list of ``(user_id, created_at)`` tuples in insertion order, where
        ``user_id`` is the rowid SQLite actually assigned and ``created_at``
        is a naive local ``datetime``.
    """
    users = []
    now = datetime.now()
    base_date = now - timedelta(days=180)
    recent_date = now - timedelta(days=30)
    for _ in range(SEED_CONFIG["users"]):
        if random.random() < 0.3:
            created_at = recent_date + timedelta(days=random.randint(0, 30))
        else:
            created_at = base_date + timedelta(days=random.randint(0, 180))
        # One real value among five choices -> ~80% NULL countries.
        country = random.choice([fake.country(), None, None, None, None])
        plan = random.choice(PLAN_TYPES)
        churned_at = None
        if plan == "free" and random.random() < 0.1:
            churned_at = created_at + timedelta(days=random.randint(30, 150))
            # A churn date after "now" cannot have happened yet; treat such
            # users as still active instead of writing a future timestamp.
            if churned_at > now:
                churned_at = None
        cursor = conn.execute(
            "INSERT INTO users (email, country, plan, created_at, churned_at) VALUES (?, ?, ?, ?, ?)",
            (
                fake.email(),
                country,
                plan,
                created_at.isoformat(),
                churned_at.isoformat() if churned_at else None,
            ),
        )
        # Use the rowid SQLite assigned rather than assuming the table was
        # empty and ids start at 1.
        users.append((cursor.lastrowid, created_at))
    conn.commit()
    return users
def _seed_products(conn: sqlite3.Connection) -> list:
    """Insert ``SEED_CONFIG["products"]`` synthetic products and commit.

    Each product gets a random category from ``CATEGORIES``, a price drawn
    uniformly between 10 and 500 (rounded to cents) and a cost of 30-70% of
    that price.

    Args:
        conn: Connection whose ``products`` table was created by
            ``create_database``.

    Returns:
        A list of ``(product_id, category, price)`` tuples in insertion
        order, where ``product_id`` is the rowid SQLite actually assigned.
    """
    products = []
    for _ in range(SEED_CONFIG["products"]):
        category = random.choice(CATEGORIES)
        price = round(random.uniform(10, 500), 2)
        cost = round(price * random.uniform(0.3, 0.7), 2)
        cursor = conn.execute(
            "INSERT INTO products (name, category, price, cost) VALUES (?, ?, ?, ?)",
            (fake.catch_phrase(), category, price, cost),
        )
        # Use the rowid SQLite assigned rather than assuming ids start at 1;
        # callers copy `price` into order line items keyed by this id.
        products.append((cursor.lastrowid, category, price))
    conn.commit()
    return products
def _seed_orders(conn: sqlite3.Connection, users: list, products: list) -> tuple:
    """Insert ``SEED_CONFIG["orders"]`` orders with 1-5 line items each, and commit.

    About 20% of orders are dated in Q3 2024 (2024-07-01 through 2024-09-30);
    the rest fall between 2024-01-01 and 2024-06-29. All order timestamps are
    at midnight, stored via ``isoformat()`` (e.g. ``'2024-09-30T00:00:00'``).
    Status is drawn with weights pending 0.1 / completed 0.87 / refunded
    0.03. Each line item copies the product's price into ``unit_price``, and
    the order ``total`` is the sum of ``qty * unit_price`` rounded to 2 dp.

    Args:
        conn: Connection whose schema was created by ``create_database``.
        users: ``(user_id, ...)`` tuples; only element 0 is used.
        products: ``(product_id, category, price)`` tuples.

    Returns:
        ``(orders, order_items)``: ``orders`` holds
        ``(order_id, user_id, created_at, status)`` tuples and
        ``order_items`` holds ``(order_id, product_id, qty, unit_price)``
        tuples, both in insertion order.
    """
    orders = []
    order_items = []
    q3_start = datetime(2024, 7, 1)
    old_date = datetime(2024, 1, 1)
    # NOTE(review): order dates are fixed in 2024 and not tied to the user's
    # created_at, so an order can predate its user's signup — confirm intent.
    for _ in range(SEED_CONFIG["orders"]):
        user_id = random.choice(users)[0]
        if random.random() < 0.2:
            # 0..91 days after July 1 covers July 1 through Sept 30 inclusive.
            created_at = q3_start + timedelta(days=random.randint(0, 91))
        else:
            # 0..180 days after Jan 1 of leap-year 2024 ends on June 29.
            created_at = old_date + timedelta(days=random.randint(0, 180))
        status = random.choices(ORDER_STATUSES, weights=[0.1, 0.87, 0.03])[0]

        # Draw the line items up front so the order row can be written with
        # its real total in one INSERT (no follow-up UPDATE needed).
        line_items = []
        for _ in range(random.randint(1, 5)):
            product = random.choice(products)
            qty = random.randint(1, 3)
            line_items.append((product[0], qty, product[2]))
        order_total = sum(qty * unit_price for _, qty, unit_price in line_items)

        cursor = conn.execute(
            "INSERT INTO orders (user_id, created_at, status, total) VALUES (?, ?, ?, ?)",
            (user_id, created_at.isoformat(), status, round(order_total, 2)),
        )
        # Use the rowid SQLite assigned rather than assuming ids start at 1.
        order_id = cursor.lastrowid
        items = [
            (order_id, product_id, qty, unit_price)
            for product_id, qty, unit_price in line_items
        ]
        conn.executemany(
            "INSERT INTO order_items (order_id, product_id, qty, unit_price) VALUES (?, ?, ?, ?)",
            items,
        )
        order_items.extend(items)
        orders.append((order_id, user_id, created_at, status))
    conn.commit()
    return orders, order_items
def _seed_events(conn: sqlite3.Connection, users: list, orders: list) -> None:
    """Insert ``SEED_CONFIG["events"]`` random user events and commit.

    Each event belongs to a random user, has a random type from
    ``EVENT_TYPES``, carries a JSON ``metadata`` object of the form
    ``{"page": "/<path>"}``, and is timestamped at a whole-hour offset
    within the last 180 days, never later than the current time.

    Args:
        conn: Connection whose ``events`` table was created by
            ``create_database``.
        users: ``(user_id, ...)`` tuples; only element 0 is used.
        orders: Not used by this function; kept for interface compatibility.
    """
    import json

    now = datetime.now()
    base_date = now - timedelta(days=180)
    # Offsets of 0..180*24 hours keep every timestamp inside [base_date, now];
    # adding up to 23 extra hours on top of day 180 would land in the future.
    max_offset_hours = 180 * 24
    for _ in range(SEED_CONFIG["events"]):
        user_id = random.choice(users)[0]
        ts = base_date + timedelta(hours=random.randint(0, max_offset_hours))
        event_type = random.choice(EVENT_TYPES)
        # Serialize with json so any quote/backslash in the path is escaped.
        metadata = json.dumps({"page": "/" + fake.uri_path()})
        conn.execute(
            "INSERT INTO events (user_id, event_type, metadata, ts) VALUES (?, ?, ?, ?)",
            (user_id, event_type, metadata, ts.isoformat()),
        )
    conn.commit()
def get_schema_summary(conn: sqlite3.Connection) -> str:
    """Return a one-line-per-table summary of the database schema.

    Each line has the form ``table: (col1, col2, ...)``. Tables are listed
    alphabetically and columns in declaration order. A database with no
    tables yields an empty string.

    Args:
        conn: Open SQLite connection (works with any ``row_factory`` that
            supports integer indexing).

    Returns:
        The newline-joined summary.
    """
    table_names = [
        row[0]
        for row in conn.execute(
            "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
        )
    ]
    lines = []
    for table in table_names:
        # Bind the table name as a parameter instead of interpolating it into
        # PRAGMA text, so names with spaces, quotes or SQL keywords work.
        columns = conn.execute(
            "SELECT name FROM pragma_table_info(?) ORDER BY cid", (table,)
        ).fetchall()
        lines.append(f"{table}: ({', '.join(col[0] for col in columns)})")
    return "\n".join(lines)
def get_ground_truth(conn: sqlite3.Connection, task_id: str) -> Any:
    """Compute the reference answer for the analytics task *task_id*.

    Supported tasks:

    - ``"monthly_signups"``: ``int`` count of users whose ``created_at`` falls
      on or after the calendar date 30 days before SQLite's ``'now'``.
    - ``"top_revenue_category"``: ``str`` name of the product category with
      the highest ``SUM(qty * unit_price)`` over completed orders placed in
      Q3 2024 (2024-07-01 through 2024-09-30 inclusive), or ``None`` if there
      is no such order.
    - ``"churn_analysis"``: ``set`` of lower-cased emails of users with
      exactly three completed orders whose most recent completed order is
      before the date 90 days before SQLite's ``'now'``.

    Any other *task_id* returns ``None``.

    Args:
        conn: Connection to a database with the ``create_database`` schema.
        task_id: One of the task identifiers listed above.

    Returns:
        The task's answer as described above, or ``None``.
    """
    if task_id == "monthly_signups":
        # created_at is ISO text ('YYYY-MM-DDTHH:MM:SS...') and DATE() yields
        # 'YYYY-MM-DD', so this string comparison includes the whole boundary
        # day. NOTE(review): SQLite's 'now' is UTC while users.created_at is
        # written from naive local datetime.now() — confirm intended timezone.
        row = conn.execute(
            "SELECT COUNT(*) FROM users WHERE created_at >= DATE('now', '-30 days')"
        ).fetchone()
        return row[0]

    if task_id == "top_revenue_category":
        # Half-open date range: stored values like '2024-09-30T00:00:00' sort
        # AFTER the bare string '2024-09-30', so BETWEEN ... AND '2024-09-30'
        # would silently drop every order placed on the quarter's last day.
        row = conn.execute("""
            SELECT p.category, SUM(oi.qty * oi.unit_price) AS revenue
            FROM orders o
            JOIN order_items oi ON o.id = oi.order_id
            JOIN products p ON oi.product_id = p.id
            WHERE o.created_at >= '2024-07-01'
              AND o.created_at < '2024-10-01'
              AND o.status = 'completed'
            GROUP BY p.category
            ORDER BY revenue DESC
            LIMIT 1
        """).fetchone()
        return row[0] if row else None

    if task_id == "churn_analysis":
        # NOTE(review): HAVING COUNT(*) = 3 selects users with *exactly* three
        # completed orders — confirm this was not meant to be ">= 3".
        rows = conn.execute("""
            WITH order_counts AS (
                SELECT user_id, COUNT(*) AS total_orders,
                       MAX(created_at) AS last_order_date
                FROM orders
                WHERE status = 'completed'
                GROUP BY user_id
                HAVING COUNT(*) = 3
            ),
            churned AS (
                SELECT oc.user_id
                FROM order_counts oc
                WHERE oc.last_order_date < DATE('now', '-90 days')
            )
            SELECT u.email
            FROM users u
            JOIN churned c ON u.id = c.user_id
        """).fetchall()
        return {row[0].lower() for row in rows}

    return None