Chit1324 commited on
Commit
df56a4d
·
verified ·
1 Parent(s): f2fac6d

Upload sql_model.py

Browse files
Files changed (1) hide show
  1. sql_model.py +111 -0
sql_model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
+ class SQLModel:
5
+ def __init__(self, model_name="google/flan-t5-base"):
6
+ # self.tokenizer = AutoTokenizer.from_pretrained(model_name)
7
+ # self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
8
+ self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it")
9
+ self.model = AutoModelForCausalLM.from_pretrained("google/gemma-7b-it",)
10
+
11
+ def generate_sql(self, natural_language_query):
12
+ input_text = f"""You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.
13
+ Here is the schema information for our database :
14
+
15
+ Table: Employees
16
+ - id (INT)
17
+ - NAME (VARCHAR)
18
+ - Department (VARCHAR)
19
+ - Salary (INT)
20
+ - Hire_Date (DATE)
21
+
22
+ Table: Departments
23
+ - ID (INT)
24
+ - Name (VARCHAR)
25
+ - Manager (VARCHAR)
26
+
27
+ Here are a few examples:
28
+
29
+ 1. **Input**: "Show me all employees in the Sales department."
30
+ **Output**:
31
+
32
+ SELECT *
33
+ FROM Employees
34
+ WHERE Department = 'Sales';
35
+
36
+ 2. **Input**: "Who is the manager of the Engineering department?"
37
+ **Output**:
38
+
39
+ SELECT Manager
40
+ FROM Departments
41
+ WHERE Name = 'Engineering';
42
+
43
+
44
+ 3. **Input**: "List all employees hired after 2021-01-01."
45
+ **Output**:
46
+
47
+ SELECT *
48
+ FROM Employees
49
+ WHERE Hire_Date > '2021-01-01';
50
+
51
+
52
+ 4. **Input**: "What is the total salary expense for the Marketing department?"
53
+ **Output**:
54
+
55
+ SELECT SUM(Salary)
56
+ FROM Employees
57
+ WHERE Department = 'Marketing';
58
+
59
+
60
+ 5. **Input**: "Find the average salary of employees in each department."
61
+ **Output**:
62
+
63
+ SELECT Department, AVG(Salary) AS average_salary
64
+ FROM Employees
65
+ GROUP BY Department;
66
+
67
+ Please do not return additional text besides query.
68
+
69
+ Please only answer queries which makes sense for the given schema. Else just return - "No information found"
70
+
71
+ Now, translate the following natural language query into an syntactically correct SQL query:
72
+ **Input**: {natural_language_query}
73
+ **Output**:
74
+
75
+ """
76
+ # input_text = f"""
77
+ # Translate the following natural language query into a syntactically correct SQL query using the provided database schema. Output only the SQL query with no additional text or explanation.
78
+
79
+ # Database Schema:
80
+
81
+ # Table: Employees
82
+ # - id (INT)
83
+ # - NAME (VARCHAR)
84
+ # - Department (VARCHAR)
85
+ # - Salary (INT)
86
+ # - Hire_Date (DATE)
87
+
88
+ # Table: Departments
89
+ # - ID (INT)
90
+ # - Name (VARCHAR)
91
+ # - Manager (VARCHAR)
92
+
93
+ # Examples:
94
+ # 1. Natural Language Query: "List all employees who were hired after '2020-01-01'."
95
+ # Output: SELECT * FROM Employees WHERE Hire_Date > '2020-01-01';
96
+
97
+ # 2. Natural Language Query: "Retrieve the names and salaries of employees in the 'Sales' department."
98
+ # Output: SELECT NAME, Salary FROM Employees WHERE Department = 'Sales';
99
+
100
+ # Now, translate the following query:
101
+ # {natural_language_query}
102
+ # """
103
+ # input_text = f"translate English to SQL: {natural_language_query}"
104
+ # inputs = self.tokenizer(input_text, return_tensors="pt").input_ids
105
+ # outputs = self.model.generate(inputs, max_new_tokens=100, do_sample=False)
106
+ # sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
107
+ inputs = self.tokenizer(input_text, return_tensors="pt")
108
+ outputs = self.model.generate(**inputs)
109
+ sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
110
+ print(sql_query)
111
+ return sql_query