File size: 4,000 Bytes
0276b0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from transformers import AutoTokenizer, AutoModelForCausalLM

class SQLModel:
    """Translate natural-language questions into SQL with a seq2seq model.

    The model is given a fixed few-shot prompt describing an ``Employees`` /
    ``Departments`` schema (built by ``_build_prompt``) and is instructed to
    return only the SQL statement, or ``No information found`` when the
    question does not fit that schema.
    """

    def __init__(self, model_name: str = "google/flan-t5-base") -> None:
        """Load the tokenizer and the seq2seq model.

        Args:
            model_name: Hugging Face model id or local path handed to
                ``from_pretrained`` for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    @staticmethod
    def _build_prompt(natural_language_query: str) -> str:
        """Return the few-shot translation prompt with the user query appended."""
        # NOTE: everything inside the literal (blank lines and indentation
        # included) is exactly what the model sees; edit with care.
        return f"""You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.

        Here is the schema information for our database :



        Table: Employees

        - id (INT)

        - NAME (VARCHAR)

        - Department (VARCHAR)

        - Salary (INT)

        - Hire_Date (DATE)



        Table: Departments

        - ID (INT)

        - Name (VARCHAR)

        - Manager (VARCHAR)



        Here are a few examples:



        1. **Input**: "Show me all employees in the Sales department."

        **Output**:



                SELECT *

                FROM Employees

                WHERE Department = 'Sales';



        2. **Input**: "Who is the manager of the Engineering department?"

        **Output**:



                SELECT Manager

                FROM Departments

                WHERE Name = 'Engineering';





        3. **Input**: "List all employees hired after 2021-01-01."

        **Output**:



                SELECT *

                FROM Employees

                WHERE Hire_Date > '2021-01-01';





        4. **Input**: "What is the total salary expense for the Marketing department?"

        **Output**:



                SELECT SUM(Salary)

                FROM Employees

                WHERE Department = 'Marketing';





        5. **Input**: "Find the average salary of employees in each department."

        **Output**:



                SELECT Department, AVG(Salary) AS average_salary

                FROM Employees

                GROUP BY Department;



        Please do not return additional text besides query.



        Please only answer queries which makes sense for the given schema. Else just return - "No information found"



        Now, translate the following natural language query into an syntactically correct SQL query:

        **Input**: {natural_language_query}

        **Output**:



"""

    def generate_sql(
        self,
        natural_language_query: str,
        *,
        max_new_tokens: int = 256,
    ) -> str:
        """Translate a natural-language question into a SQL string.

        Args:
            natural_language_query: The user's question, inserted verbatim
                into the few-shot prompt.
            max_new_tokens: Upper bound on the number of tokens generated
                for the answer.

        Returns:
            The decoded model output with special tokens removed and
            surrounding whitespace stripped. This is expected to be a SQL
            statement or ``No information found``, but the model's output
            is not validated.
        """
        prompt = self._build_prompt(natural_language_query)
        inputs = self.tokenizer(prompt, return_tensors="pt")
        # Pass an explicit length limit. Without one, generate() falls back to
        # the model's generation config. For checkpoints that set no length
        # (flan-t5 appears to be one), that is transformers' model-agnostic
        # default max_length of 20 tokens, which truncates most SQL
        # statements mid-query.
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        sql_query = self.tokenizer.decode(
            outputs[0], skip_special_tokens=True
        ).strip()
        print(sql_query)
        return sql_query