TextToSQL / nlp /sql_model.py
Chit1324's picture
Upload 9 files
a16d353 verified
import re
from huggingface_hub import InferenceClient
class SQLModel:
def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
self.client = InferenceClient(model_name)
def generate_sql(self, natural_language_query):
prompt = (
"You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries "
"into correct and optimized SQL statements.\n\n"
"Here is the schema information for our database :\n\n"
"Table: Employees\n"
"- id (INT)\n"
"- NAME (VARCHAR)\n"
"- Department (VARCHAR)\n"
"- Salary (INT)\n"
"- Hire_Date (DATE)\n\n"
"Table: Departments\n"
"- ID (INT)\n"
"- Name (VARCHAR)\n"
"- Manager (VARCHAR)\n\n"
"Here are a few examples:\n\n"
"1. **Input**: \"Show me all employees in the Sales department.\"\n"
"**Output**:\n\n"
" SELECT *\n"
" FROM Employees\n"
" WHERE Department = 'Sales';\n\n"
"2. **Input**: \"Who is the manager of the Engineering department?\"\n"
"**Output**:\n\n"
" SELECT Manager\n"
" FROM Departments\n"
" WHERE Name = 'Engineering';\n\n"
"3. **Input**: \"List all employees hired after 2021-01-01.\"\n"
"**Output**:\n\n"
" SELECT *\n"
" FROM Employees\n"
" WHERE Hire_Date > '2021-01-01';\n\n"
"4. **Input**: \"What is the total salary expense for the Marketing department?\"\n"
"**Output**:\n\n"
" SELECT SUM(Salary)\n"
" FROM Employees\n"
" WHERE Department = 'Marketing';\n\n"
"5. **Input**: \"Find the average salary of employees in each department.\"\n"
"**Output**:\n\n"
" SELECT Department, AVG(Salary) AS average_salary\n"
" FROM Employees\n"
" GROUP BY Department;\n\n"
"6. **Input**: \"Find the name of employee with Highest salary in HR department.\"\n"
"**Output**:\n\n"
" SELECT Name\n"
" FROM Employees\n"
" WHERE Department = 'HR'\n"
" AND Salary = (SELECT MAX(Salary)\n"
" FROM Employees\n"
" WHERE Department = 'HR');\n\n"
"Please do not return additional text besides query.\n"
"Please only answer queries which makes sense for the given schema. Else just return - \"No information found\""
)
messages = [
{"role": "system", "content": prompt},
{"role": "user", "content": natural_language_query}
]
result = self.client.chat_completion(
messages,
max_tokens=150,
stream=False,
temperature=0.7,
top_p=0.95,
)
# Initialize a variable to hold the extracted SQL text.
sql_query = ""
# Check if the result is a plain string.
if isinstance(result, str):
sql_query = result
# If the result is a list, iterate over its tokens.
elif isinstance(result, list):
for token in result:
if isinstance(token, str):
sql_query += token
elif hasattr(token, "choices"):
# Extract from the structured object.
sql_query += token.choices[0].delta.content
else:
sql_query += str(token)
# Otherwise, if it's an object with choices, extract its content.
elif hasattr(result, "choices"):
sql_query = result.choices[0].message.content
else:
sql_query = str(result)
# Optional: If the model output is in a markdown code block, extract only that content.
match = re.search(r"```sql(.*?)```", sql_query, re.DOTALL | re.IGNORECASE)
if match:
sql_query = match.group(1).strip()
# Remove both literal "\n" substrings and actual newline characters.
sql_query = sql_query.replace("\\n", " ").replace("\n", " ")
# Remove extra spaces.
sql_query = " ".join(sql_query.split())
# Extract only the SQL command: starting from the first occurrence of "select" to the first semicolon.
extraction_pattern = r"(?i)(select\s.*?;)"
extraction_match = re.search(extraction_pattern, sql_query, re.DOTALL)
if extraction_match:
sql_query = " ".join(extraction_match.group(1).split())
print(sql_query)
return sql_query