Chit1324 committed on
Commit
7cf4e28
·
verified ·
1 Parent(s): 7ed6e6e

Delete sql_model.py

Browse files
Files changed (1) hide show
  1. sql_model.py +0 -111
sql_model.py DELETED
@@ -1,111 +0,0 @@
1
- # from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- from transformers import AutoTokenizer, AutoModelForCausalLM
3
-
4
class SQLModel:
    """Translate natural-language questions into SQL with a causal language model.

    The model is prompted with a fixed schema description (``Employees`` and
    ``Departments`` tables) plus five worked examples, followed by the user's
    question. Only the text generated after the prompt is returned.
    """

    # Checkpoint that is actually loaded by ``__init__``.
    _CHECKPOINT = "google/gemma-7b-it"

    # Few-shot prompt. The text is sent to the model verbatim, so it is kept
    # flush-left to avoid feeding source-code indentation to the model.
    _PROMPT_TEMPLATE = """You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.
Here is the schema information for our database :

Table: Employees
- id (INT)
- NAME (VARCHAR)
- Department (VARCHAR)
- Salary (INT)
- Hire_Date (DATE)

Table: Departments
- ID (INT)
- Name (VARCHAR)
- Manager (VARCHAR)

Here are a few examples:

1. **Input**: "Show me all employees in the Sales department."
**Output**:

SELECT *
FROM Employees
WHERE Department = 'Sales';

2. **Input**: "Who is the manager of the Engineering department?"
**Output**:

SELECT Manager
FROM Departments
WHERE Name = 'Engineering';


3. **Input**: "List all employees hired after 2021-01-01."
**Output**:

SELECT *
FROM Employees
WHERE Hire_Date > '2021-01-01';


4. **Input**: "What is the total salary expense for the Marketing department?"
**Output**:

SELECT SUM(Salary)
FROM Employees
WHERE Department = 'Marketing';


5. **Input**: "Find the average salary of employees in each department."
**Output**:

SELECT Department, AVG(Salary) AS average_salary
FROM Employees
GROUP BY Department;

Please do not return additional text besides query.

Please only answer queries which makes sense for the given schema. Else just return - "No information found"

Now, translate the following natural language query into an syntactically correct SQL query:
**Input**: {natural_language_query}
**Output**:

"""

    def __init__(self, model_name="google/flan-t5-base"):
        """Load the tokenizer and the causal language model.

        Args:
            model_name: Accepted for backward compatibility but not used;
                the checkpoint named by ``_CHECKPOINT`` is always loaded.
                NOTE(review): the default names a seq2seq (T5) checkpoint,
                which presumably cannot be loaded through
                ``AutoModelForCausalLM``. Honouring this argument would
                also require changing its default -- confirm with callers
                before doing so.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(self._CHECKPOINT)
        self.model = AutoModelForCausalLM.from_pretrained(self._CHECKPOINT)

    def generate_sql(self, natural_language_query, max_new_tokens=100):
        """Return the SQL text the model generates for a question.

        Args:
            natural_language_query: The question to translate; it is
                interpolated into the few-shot prompt as the final input.
            max_new_tokens: Upper bound on the number of tokens generated
                after the prompt.

        Returns:
            The decoded completion with surrounding whitespace removed. The
            prompt itself is not included.
        """
        prompt = self._PROMPT_TEMPLATE.format(
            natural_language_query=natural_language_query
        )
        inputs = self.tokenizer(prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[-1]

        # Set an explicit budget for new tokens instead of relying on the
        # default generation length, which does not account for this long
        # few-shot prompt.
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)

        # A decoder-only model returns the prompt tokens followed by the new
        # tokens; drop the prompt so only the generated answer is decoded.
        completion_ids = outputs[0][prompt_length:]
        sql_query = self.tokenizer.decode(
            completion_ids, skip_special_tokens=True
        ).strip()
        print(sql_query)
        return sql_query