File size: 4,000 Bytes
0276b0d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# from transformers import AutoTokenizer, AutoModelForCausalLM
class SQLModel:
    """Translate natural-language questions into SQL with a seq2seq model.

    The model is prompted with a fixed, hard-coded schema (``Employees`` and
    ``Departments`` tables) plus five few-shot examples. It is asked to return
    only a SQL query, or ``No information found`` when the question does not
    fit the schema.
    """

    # Few-shot prompt sent to the model. ``{natural_language_query}`` is the
    # only placeholder. The body is written at column 0 so that no source
    # indentation leaks into the prompt text.
    _PROMPT_TEMPLATE = """You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.
Here is the schema information for our database :
Table: Employees
- id (INT)
- NAME (VARCHAR)
- Department (VARCHAR)
- Salary (INT)
- Hire_Date (DATE)
Table: Departments
- ID (INT)
- Name (VARCHAR)
- Manager (VARCHAR)
Here are a few examples:
1. **Input**: "Show me all employees in the Sales department."
**Output**:
SELECT *
FROM Employees
WHERE Department = 'Sales';
2. **Input**: "Who is the manager of the Engineering department?"
**Output**:
SELECT Manager
FROM Departments
WHERE Name = 'Engineering';
3. **Input**: "List all employees hired after 2021-01-01."
**Output**:
SELECT *
FROM Employees
WHERE Hire_Date > '2021-01-01';
4. **Input**: "What is the total salary expense for the Marketing department?"
**Output**:
SELECT SUM(Salary)
FROM Employees
WHERE Department = 'Marketing';
5. **Input**: "Find the average salary of employees in each department."
**Output**:
SELECT Department, AVG(Salary) AS average_salary
FROM Employees
GROUP BY Department;
Please do not return additional text besides query.
Please only answer queries which makes sense for the given schema. Else just return - "No information found"
Now, translate the following natural language query into an syntactically correct SQL query:
**Input**: {natural_language_query}
**Output**:
"""

    def __init__(self, model_name: str = "google/flan-t5-base") -> None:
        """Load the tokenizer and the seq2seq model named *model_name*.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the encoder-decoder model. Defaults to
                ``"google/flan-t5-base"``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

    @classmethod
    def _build_prompt(cls, natural_language_query: str) -> str:
        """Return the few-shot prompt with *natural_language_query* inserted."""
        # str.format only parses braces in the template itself, so any braces
        # in the user's query are inserted verbatim.
        return cls._PROMPT_TEMPLATE.format(
            natural_language_query=natural_language_query
        )

    def generate_sql(
        self, natural_language_query: str, *, max_new_tokens: int = 128
    ) -> str:
        """Translate *natural_language_query* into a SQL string.

        Args:
            natural_language_query: The user's question in plain language.
            max_new_tokens: Upper bound on the number of tokens the model may
                generate for the answer.

        Returns:
            The decoded text of the first generated sequence, with special
            tokens removed. The text is not validated as SQL; per the prompt,
            it may also be ``No information found``.
        """
        import logging

        prompt = self._build_prompt(natural_language_query)
        # Put the input tensors on the model's device, so generation still
        # works if the caller has moved the model (e.g. to a GPU).
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
        # Without an explicit budget, generate() falls back to its default
        # max_length of 20 tokens, which cuts most SQL statements off
        # mid-query. Greedy decoding keeps the output deterministic.
        outputs = self.model.generate(
            **inputs, max_new_tokens=max_new_tokens, do_sample=False
        )
        sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Log instead of print: library code should not write to stdout.
        logging.getLogger(__name__).debug("Generated SQL: %s", sql_query)
        return sql_query
|