Chit1324 commited on
Commit
0276b0d
·
verified ·
1 Parent(s): ea46e6e

Upload 9 files

Browse files
app.py CHANGED
@@ -1,23 +1,23 @@
1
- import gradio as gr
2
- from db.database import Database
3
- from nlp.query_processor import QueryProcessor
4
-
5
- # Initialize Database and Query Processor
6
- db = Database(db_name="chat_assistant.db")
7
- query_processor = QueryProcessor(db)
8
-
9
- def respond(message, history):
10
- """Processes user queries, fetches results from the database, and returns responses."""
11
- response = query_processor.process_query(message)
12
- return response
13
-
14
- # Gradio Chat UI
15
- demo = gr.ChatInterface(
16
- respond,
17
- additional_inputs=[],
18
- title="SQL Chat Assistant",
19
- description="Ask any database-related question, and I will generate an SQL query and fetch the relevant data.",
20
- )
21
-
22
- if __name__ == "__main__":
23
- demo.launch()
 
1
+ import gradio as gr
2
+ from db.database import Database
3
+ from nlp.query_processor import QueryProcessor
4
+
5
+ # Initialize Database and Query Processor
6
+ db = Database(db_name="chat_assistant.db")
7
+ query_processor = QueryProcessor(db)
8
+
9
+ def respond(message, history):
10
+ """Processes user queries, fetches results from the database, and returns responses."""
11
+ response = query_processor.process_query(message)
12
+ return response
13
+
14
+ # Gradio Chat UI
15
+ demo = gr.ChatInterface(
16
+ respond,
17
+ additional_inputs=[],
18
+ title="SQL Chat Assistant",
19
+ description="Ask any database-related question, and I will generate an SQL query and fetch the relevant data.",
20
+ )
21
+
22
+ if __name__ == "__main__":
23
+ demo.launch()
chat_assistant.db ADDED
Binary file (28.7 kB). View file
 
db/__pycache__/database.cpython-312.pyc ADDED
Binary file (1.82 kB). View file
 
db/database.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+
3
+ class Database:
4
+ def __init__(self, db_name):
5
+ self.db_name = db_name
6
+ self.conn = None
7
+
8
+ def connect(self):
9
+ try:
10
+ self.conn = sqlite3.connect(self.db_name)
11
+ return self.conn
12
+ except sqlite3.Error as e:
13
+ print(f"Database connection error: {e}")
14
+ return None
15
+
16
+ def close(self):
17
+ if self.conn:
18
+ self.conn.close()
19
+
20
+ def execute_query(self, query, params=None):
21
+ try:
22
+ cursor = self.conn.cursor()
23
+ if params:
24
+ cursor.execute(query, params)
25
+ else:
26
+ cursor.execute(query)
27
+ self.conn.commit()
28
+ return cursor
29
+ except sqlite3.Error as e:
30
+ print(f"Query execution error: {e}")
31
+ return None
db/init_db.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+
3
+ def initialize_database(db_name='chat_assistant.db'):
4
+ conn = None
5
+ try:
6
+ conn = sqlite3.connect(db_name)
7
+ cursor = conn.cursor()
8
+
9
+ # Create Employees table
10
+ cursor.execute('''CREATE TABLE IF NOT EXISTS employees (
11
+ ID INTEGER PRIMARY KEY,
12
+ Name TEXT,
13
+ Department TEXT,
14
+ Salary REAL,
15
+ Hire_Date TEXT
16
+ )''')
17
+
18
+ # Create Departments table
19
+ cursor.execute('''CREATE TABLE IF NOT EXISTS departments (
20
+ ID INTEGER PRIMARY KEY,
21
+ Name TEXT,
22
+ Manager TEXT
23
+ )''')
24
+
25
+ cursor.execute('''CREATE TABLE IF NOT EXISTS table_metadata (
26
+ table_name TEXT PRIMARY KEY,
27
+ description TEXT
28
+ )''')
29
+
30
+ cursor.execute('''CREATE TABLE IF NOT EXISTS column_metadata (
31
+ table_name TEXT,
32
+ column_name TEXT,
33
+ data_type TEXT,
34
+ description TEXT,
35
+ PRIMARY KEY (table_name, column_name),
36
+ FOREIGN KEY (table_name) references table_metadata (table_name)
37
+ )''')
38
+ # Insert sample data into Employees table
39
+ employees_data = [
40
+ (1, 'Alice', 'Sales', 50000, '2021-01-15'),
41
+ (2, 'Bob', 'Engineering', 70000, '2020-06-10'),
42
+ (3, 'Charlie', 'Marketing', 60000, '2022-03-20')
43
+ ]
44
+ cursor.executemany('INSERT INTO Employees VALUES (?, ?, ?, ?, ?)', employees_data)
45
+
46
+ # Insert sample data into Departments table
47
+ departments_data = [
48
+ (1, 'Sales', 'Alice'),
49
+ (2, 'Engineering', 'Bob'),
50
+ (3, 'Marketing', 'Charlie')
51
+ ]
52
+ cursor.executemany('INSERT INTO Departments VALUES (?, ?, ?)', departments_data)
53
+ cursor.execute("INSERT INTO table_metadata (table_name, description) VALUES ('Employees','Details of Employees in Department.')")
54
+ cursor.execute("INSERT INTO table_metadata (table_name, description) VALUES ('Department','Details of Manager of the Department.')")
55
+
56
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Employees','id','INTEGER','Identification number of the Employee.')")
57
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Employees','name','TEXT','Name of the Employee.')")
58
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Employees','department','TEXT','Department of the Employee.')")
59
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Employees','Salary','INTEGER','Salary of the Employee.')")
60
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Employees','Hire_Date','DATE','Date in which the Employee was hired.')")
61
+
62
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Departments','id','INTEGER','Identification number of the Employee.')")
63
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Departments','name','TEXT','Name of the Department.')")
64
+ cursor.execute("INSERT INTO column_metadata (table_name, column_name, data_type,description) VALUES ('Departments','manager','TEXT','Manager of the Department.')")
65
+ conn.commit()
66
+
67
+ print("Database initialized successfully.")
68
+
69
+ except sqlite3.Error as e:
70
+ print(f"Database initialization error: {e}")
71
+
72
+ finally:
73
+ if conn:
74
+ conn.close()
75
+
76
+ if __name__ == "__main__":
77
+ initialize_database()
nlp/__pycache__/query_processor.cpython-312.pyc ADDED
Binary file (2.95 kB). View file
 
nlp/__pycache__/sql_model.cpython-312.pyc ADDED
Binary file (3.19 kB). View file
 
nlp/query_processor.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+ import sqlite3 # Added import for sqlite3
3
+ from datetime import datetime
4
+ from db.database import Database
5
+ from nlp.sql_model import SQLModel
6
+
7
+ class QueryProcessor:
8
+ def __init__(self, db):
9
+ self.db = db
10
+ self.sql_model = SQLModel()
11
+
12
+ def convert_date(self, date_str):
13
+ formats = [
14
+ '%d/%m/%Y',
15
+ '%d %B %Y',
16
+ '%B %Y',
17
+ '%b %Y',
18
+ '%Y-%m-%d'
19
+ ]
20
+ for fmt in formats:
21
+ try:
22
+ return datetime.strptime(date_str, fmt).date()
23
+ except ValueError:
24
+ continue
25
+ raise ValueError('Invalid date format')
26
+
27
+ def extract_and_convert_date(self, user_query):
28
+ date_pattern = r'\b(\b[0-9]{1,2}[/-]?[0-9]{1,2}[/-]?[0-9]{2,4}\b|\b(?:\b(?:January|February|March|April|May|June|July|August|September|October|November|December|Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\b[- ]?[0-9]{1,2}[- ]?[0-9]{2,4}\b|\b[0-9]{2,4}[- ]?(?:January|February|March|April|May|June|July|August|September|October|November|December|Jan|Feb|Mar|Apr|May|Jun|Jul|Aug|Sep|Oct|Nov|Dec)\b)\b)'
29
+ match = re.search(date_pattern, user_query, re.I)
30
+ if match:
31
+ date_str = match.group(0)
32
+ try:
33
+ return self.convert_date(date_str)
34
+ except ValueError:
35
+ return None
36
+ return None
37
+
38
+
39
+ def process_query(self, user_query):
40
+ conn = self.db.connect()
41
+ if not conn:
42
+ return "Failed to connect to the database."
43
+
44
+ response = ''
45
+ try:
46
+ # Generate SQL query using the LLM model
47
+ sql_query = self.sql_model.generate_sql(user_query)
48
+ # Execute generated SQL query against the database
49
+ cursor = self.db.execute_query(sql_query)
50
+ if cursor:
51
+ results = cursor.fetchall()
52
+ if results:
53
+ response = f'Results: {results}'
54
+ else:
55
+ response = 'No results found.'
56
+ else:
57
+ response = 'Error executing query.'
58
+ finally:
59
+ self.db.close()
60
+ return response
nlp/sql_model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ # from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
+ class SQLModel:
5
+ def __init__(self, model_name="google/flan-t5-base"):
6
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
+ # self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
9
+ # self.model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",)
10
+
11
+ def generate_sql(self, natural_language_query):
12
+ input_text = f"""You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.
13
+ Here is the schema information for our database :
14
+
15
+ Table: Employees
16
+ - id (INT)
17
+ - NAME (VARCHAR)
18
+ - Department (VARCHAR)
19
+ - Salary (INT)
20
+ - Hire_Date (DATE)
21
+
22
+ Table: Departments
23
+ - ID (INT)
24
+ - Name (VARCHAR)
25
+ - Manager (VARCHAR)
26
+
27
+ Here are a few examples:
28
+
29
+ 1. **Input**: "Show me all employees in the Sales department."
30
+ **Output**:
31
+
32
+ SELECT *
33
+ FROM Employees
34
+ WHERE Department = 'Sales';
35
+
36
+ 2. **Input**: "Who is the manager of the Engineering department?"
37
+ **Output**:
38
+
39
+ SELECT Manager
40
+ FROM Departments
41
+ WHERE Name = 'Engineering';
42
+
43
+
44
+ 3. **Input**: "List all employees hired after 2021-01-01."
45
+ **Output**:
46
+
47
+ SELECT *
48
+ FROM Employees
49
+ WHERE Hire_Date > '2021-01-01';
50
+
51
+
52
+ 4. **Input**: "What is the total salary expense for the Marketing department?"
53
+ **Output**:
54
+
55
+ SELECT SUM(Salary)
56
+ FROM Employees
57
+ WHERE Department = 'Marketing';
58
+
59
+
60
+ 5. **Input**: "Find the average salary of employees in each department."
61
+ **Output**:
62
+
63
+ SELECT Department, AVG(Salary) AS average_salary
64
+ FROM Employees
65
+ GROUP BY Department;
66
+
67
+ Please do not return additional text besides query.
68
+
69
+ Please only answer queries which makes sense for the given schema. Else just return - "No information found"
70
+
71
+ Now, translate the following natural language query into an syntactically correct SQL query:
72
+ **Input**: {natural_language_query}
73
+ **Output**:
74
+
75
+ """
76
+ # input_text = f"""
77
+ # Translate the following natural language query into a syntactically correct SQL query using the provided database schema. Output only the SQL query with no additional text or explanation.
78
+
79
+ # Database Schema:
80
+
81
+ # Table: Employees
82
+ # - id (INT)
83
+ # - NAME (VARCHAR)
84
+ # - Department (VARCHAR)
85
+ # - Salary (INT)
86
+ # - Hire_Date (DATE)
87
+
88
+ # Table: Departments
89
+ # - ID (INT)
90
+ # - Name (VARCHAR)
91
+ # - Manager (VARCHAR)
92
+
93
+ # Examples:
94
+ # 1. Natural Language Query: "List all employees who were hired after '2020-01-01'."
95
+ # Output: SELECT * FROM Employees WHERE Hire_Date > '2020-01-01';
96
+
97
+ # 2. Natural Language Query: "Retrieve the names and salaries of employees in the 'Sales' department."
98
+ # Output: SELECT NAME, Salary FROM Employees WHERE Department = 'Sales';
99
+
100
+ # Now, translate the following query:
101
+ # {natural_language_query}
102
+ # """
103
+ # input_text = f"translate English to SQL: {natural_language_query}"
104
+ # inputs = self.tokenizer(input_text, return_tensors="pt").input_ids
105
+ # outputs = self.model.generate(inputs, max_new_tokens=100, do_sample=False)
106
+ # sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
107
+ inputs = self.tokenizer(input_text, return_tensors="pt")
108
+ outputs = self.model.generate(**inputs)
109
+ sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
110
+ print(sql_query)
111
+ return sql_query