| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
|
|
|
|
|
|
|
class SQLModel:
    """Translate natural-language questions into SQL with a seq2seq model.

    The model is prompted with a fixed description of an ``Employees`` and a
    ``Departments`` table plus a few worked examples. Subclasses can target a
    different schema by overriding ``PROMPT_TEMPLATE``.
    """

    # Few-shot prompt sent to the model. ``{query}`` is filled in with
    # ``str.format``, so the template itself must not contain any other
    # literal braces (the substituted query may contain braces freely).
    PROMPT_TEMPLATE = """\
You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.

Here is the schema information for our database:

Table: Employees
- id (INT)
- NAME (VARCHAR)
- Department (VARCHAR)
- Salary (INT)
- Hire_Date (DATE)

Table: Departments
- ID (INT)
- Name (VARCHAR)
- Manager (VARCHAR)

Here are a few examples:

1. **Input**: "Show me all employees in the Sales department."
**Output**:
SELECT *
FROM Employees
WHERE Department = 'Sales';

2. **Input**: "Who is the manager of the Engineering department?"
**Output**:
SELECT Manager
FROM Departments
WHERE Name = 'Engineering';

3. **Input**: "List all employees hired after 2021-01-01."
**Output**:
SELECT *
FROM Employees
WHERE Hire_Date > '2021-01-01';

4. **Input**: "What is the total salary expense for the Marketing department?"
**Output**:
SELECT SUM(Salary)
FROM Employees
WHERE Department = 'Marketing';

5. **Input**: "Find the average salary of employees in each department."
**Output**:
SELECT Department, AVG(Salary) AS average_salary
FROM Employees
GROUP BY Department;

Please do not return any additional text besides the query.

Please only answer queries that make sense for the given schema. Otherwise, just return "No information found".

Now, translate the following natural language query into a syntactically correct SQL query:

**Input**: {query}
**Output**:
"""

    def __init__(self, model_name="google/flan-t5-base"):
        """Load the tokenizer and seq2seq model.

        Args:
            model_name: Hub id or local path passed to ``from_pretrained``
                for both the tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    @classmethod
    def _build_prompt(cls, natural_language_query):
        """Return the full few-shot prompt with *natural_language_query* inserted."""
        return cls.PROMPT_TEMPLATE.format(query=natural_language_query)

    def generate_sql(self, natural_language_query, *, max_new_tokens=256):
        """Translate a natural-language question into SQL.

        Args:
            natural_language_query: The user's question, inserted verbatim
                into the prompt.
            max_new_tokens: Upper bound on the number of generated tokens.
                Without an explicit limit, ``generate`` falls back to the
                generation config's ``max_length`` (20 by default in
                transformers), which cuts off most non-trivial queries.

        Returns:
            The decoded model output with special tokens removed. The prompt
            asks for a SQL statement or the text "No information found", but
            the output is not validated here.
        """
        prompt = self._build_prompt(natural_language_query)
        # Keep the inputs on whatever device the model lives on, so generation
        # still works if the caller has moved the model (e.g. to a GPU).
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        outputs = self.model.generate(**inputs, max_new_tokens=max_new_tokens)
        sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Echo the generated text to stdout before returning it.
        print(sql_query)
        return sql_query
|
|
|