# NL-SQL / database.py
# sidharth-pm's picture
# Create database.py
# ff679e0 verified
import sqlite3
import os
# Path of the SQLite database file that the helpers in this module connect to.
# It is a relative path, so it resolves against the current working directory.
DB_PATH = "company.db"
# Table definitions. IF NOT EXISTS makes it harmless to run this on every start.
_SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS employees (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL,
    department TEXT NOT NULL,
    salary REAL NOT NULL,
    hire_date TEXT NOT NULL,
    manager_id INTEGER,
    email TEXT UNIQUE
);
CREATE TABLE IF NOT EXISTS departments (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    name TEXT NOT NULL UNIQUE,
    budget REAL NOT NULL,
    location TEXT NOT NULL,
    head_count INTEGER DEFAULT 0
);
CREATE TABLE IF NOT EXISTS projects (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    title TEXT NOT NULL,
    department TEXT NOT NULL,
    start_date TEXT NOT NULL,
    end_date TEXT,
    status TEXT DEFAULT 'active',
    budget REAL NOT NULL
);
CREATE TABLE IF NOT EXISTS sales (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    employee_id INTEGER NOT NULL,
    product TEXT NOT NULL,
    amount REAL NOT NULL,
    sale_date TEXT NOT NULL,
    region TEXT NOT NULL,
    FOREIGN KEY (employee_id) REFERENCES employees(id)
);
"""

# One INSERT per table, executed in this order inside a single transaction.
_SEED_STATEMENTS = (
    """
    INSERT INTO departments (name, budget, location, head_count) VALUES
        ('Engineering', 2500000, 'San Francisco', 12),
        ('Sales', 1800000, 'New York', 15),
        ('Marketing', 900000, 'Chicago', 8),
        ('HR', 600000, 'Austin', 5),
        ('Finance', 750000, 'Boston', 6)
    """,
    """
    INSERT INTO employees (name, department, salary, hire_date, email) VALUES
        ('Alice Johnson', 'Engineering', 120000, '2020-03-15', 'alice@company.com'),
        ('Bob Smith', 'Engineering', 115000, '2019-07-22', 'bob@company.com'),
        ('Carol White', 'Sales', 85000, '2021-01-10', 'carol@company.com'),
        ('David Brown', 'Sales', 92000, '2018-11-05', 'david@company.com'),
        ('Eve Davis', 'Marketing', 78000, '2022-04-18', 'eve@company.com'),
        ('Frank Miller', 'HR', 72000, '2020-09-30', 'frank@company.com'),
        ('Grace Wilson', 'Engineering', 130000, '2017-06-01', 'grace@company.com'),
        ('Henry Moore', 'Finance', 95000, '2019-02-14', 'henry@company.com'),
        ('Iris Taylor', 'Marketing', 81000, '2021-08-25', 'iris@company.com'),
        ('Jack Anderson', 'Sales', 88000, '2020-12-07', 'jack@company.com'),
        ('Karen Thomas', 'Engineering', 125000, '2018-05-20', 'karen@company.com'),
        ('Leo Jackson', 'Finance', 98000, '2016-10-11', 'leo@company.com')
    """,
    """
    INSERT INTO projects (title, department, start_date, end_date, status, budget) VALUES
        ('AI Platform v2', 'Engineering', '2024-01-01', '2024-12-31', 'active', 500000),
        ('Customer Portal', 'Engineering', '2023-06-01', '2024-03-31', 'completed', 200000),
        ('Q4 Campaign', 'Marketing', '2024-10-01', '2024-12-31', 'active', 150000),
        ('Sales CRM Migration', 'Sales', '2024-03-01', NULL, 'active', 80000),
        ('Annual Audit', 'Finance', '2024-11-01', '2024-11-30', 'completed', 30000),
        ('Talent Pipeline', 'HR', '2024-07-01', NULL, 'active', 50000)
    """,
    # employee_id 3, 4 and 10 are the positions of Carol White, David Brown and
    # Jack Anderson in the employees INSERT above; this assumes the ids are
    # handed out starting at 1, i.e. the table has never held rows before.
    """
    INSERT INTO sales (employee_id, product, amount, sale_date, region) VALUES
        (3, 'Enterprise License', 45000, '2024-01-15', 'East'),
        (4, 'SaaS Subscription', 12000, '2024-02-20', 'West'),
        (10, 'Consulting Package', 28000, '2024-03-10', 'Central'),
        (3, 'Enterprise License', 52000, '2024-04-05', 'East'),
        (4, 'Support Plan', 8500, '2024-05-18', 'West'),
        (10, 'Enterprise License', 61000, '2024-06-22', 'East'),
        (3, 'SaaS Subscription', 15000, '2024-07-30', 'Central'),
        (4, 'Consulting Package', 33000, '2024-08-14', 'West'),
        (10, 'Support Plan', 9000, '2024-09-01', 'East'),
        (3, 'Enterprise License', 47000, '2024-10-17', 'Central')
    """,
)


def init_database():
    """Create the sample tables in ``DB_PATH`` and seed them on first use.

    The tables are created with ``CREATE TABLE IF NOT EXISTS``, so calling
    this repeatedly is safe. Sample rows are inserted only when the
    ``employees`` table is empty. All seed INSERTs run in one transaction:
    either every table is seeded or, if any INSERT fails, none is, so a
    failed run never leaves a half-populated database behind.

    Prints a confirmation message on success.

    Raises:
        sqlite3.Error: if the database cannot be opened or a statement fails.
            The connection is closed before the exception propagates.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        conn.executescript(_SCHEMA_SQL)

        (employee_count,) = conn.execute(
            "SELECT COUNT(*) FROM employees"
        ).fetchone()
        if employee_count == 0:
            # The connection context manager commits when the body finishes
            # and rolls back if it raises, which keeps the seed all-or-nothing.
            with conn:
                for statement in _SEED_STATEMENTS:
                    conn.execute(statement)
    finally:
        conn.close()
    print("✅ Database initialized successfully.")
def get_schema() -> str:
    """Return the CREATE TABLE statements of the database at ``DB_PATH``.

    The statements are read from ``sqlite_master`` and joined with a blank
    line between them, ready to be embedded in a prompt. Tables whose name
    starts with the reserved ``sqlite_`` prefix (for example the
    ``sqlite_sequence`` bookkeeping table that AUTOINCREMENT columns cause
    SQLite to create) are left out, since they are not part of the
    application schema.

    Returns:
        The schema text, or an empty string when the database has no tables.

    Raises:
        sqlite3.Error: if the database cannot be opened or queried. The
            connection is closed before the exception propagates.
    """
    conn = sqlite3.connect(DB_PATH)
    try:
        rows = conn.execute(
            "SELECT sql FROM sqlite_master "
            "WHERE type = 'table' AND sql IS NOT NULL "
            # '_' is a LIKE wildcard, so escape it to match the literal prefix.
            "AND name NOT LIKE 'sqlite\\_%' ESCAPE '\\'"
        ).fetchall()
    finally:
        conn.close()
    return "\n\n".join(statement for (statement,) in rows)
def execute_query(sql: str):
    """Run ``sql`` against the database at ``DB_PATH`` and return its result.

    Returns:
        A ``(columns, rows)`` pair. ``columns`` holds the column names taken
        from the cursor description, or an empty list when the statement
        produced no description; ``rows`` is whatever ``fetchall()`` yields.

    Any change made by the statement is committed before returning. Errors
    raised while executing are not caught here; the connection is closed
    whether or not the statement succeeds.

    NOTE(review): the statement is executed exactly as given and then
    committed, so writes and DDL go through as well. Presumably callers are
    expected to vet the SQL first; confirm against the call sites.
    """
    connection = sqlite3.connect(DB_PATH)
    cur = connection.cursor()
    try:
        cur.execute(sql)
        if cur.description:
            column_names = [entry[0] for entry in cur.description]
        else:
            # Statements without a result set leave the description unset.
            column_names = []
        result_rows = cur.fetchall()
        connection.commit()
        return column_names, result_rows
    finally:
        connection.close()
if __name__ == "__main__":
    # When run as a script: build (and, on first run, seed) the database,
    # then dump the resulting schema under a header line.
    init_database()
    schema_text = get_schema()
    print("\n--- SCHEMA ---", schema_text, sep="\n")