| import re
|
| from huggingface_hub import InferenceClient
|
|
|
class SQLModel:
    """Translate natural-language questions into SQL through a hosted chat model.

    The model is given a fixed system prompt that describes two tables
    (``Employees`` and ``Departments``), shows six worked examples, and asks
    for the bare query or the literal text "No information found".
    """

    # The prompt never changes between calls, so it is built once here instead
    # of on every ``generate_sql`` invocation.
    _SYSTEM_PROMPT = (
        "You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries "
        "into correct and optimized SQL statements.\n\n"
        "Here is the schema information for our database :\n\n"
        "Table: Employees\n"
        "- id (INT)\n"
        "- NAME (VARCHAR)\n"
        "- Department (VARCHAR)\n"
        "- Salary (INT)\n"
        "- Hire_Date (DATE)\n\n"
        "Table: Departments\n"
        "- ID (INT)\n"
        "- Name (VARCHAR)\n"
        "- Manager (VARCHAR)\n\n"
        "Here are a few examples:\n\n"
        "1. **Input**: \"Show me all employees in the Sales department.\"\n"
        "**Output**:\n\n"
        " SELECT *\n"
        " FROM Employees\n"
        " WHERE Department = 'Sales';\n\n"
        "2. **Input**: \"Who is the manager of the Engineering department?\"\n"
        "**Output**:\n\n"
        " SELECT Manager\n"
        " FROM Departments\n"
        " WHERE Name = 'Engineering';\n\n"
        "3. **Input**: \"List all employees hired after 2021-01-01.\"\n"
        "**Output**:\n\n"
        " SELECT *\n"
        " FROM Employees\n"
        " WHERE Hire_Date > '2021-01-01';\n\n"
        "4. **Input**: \"What is the total salary expense for the Marketing department?\"\n"
        "**Output**:\n\n"
        " SELECT SUM(Salary)\n"
        " FROM Employees\n"
        " WHERE Department = 'Marketing';\n\n"
        "5. **Input**: \"Find the average salary of employees in each department.\"\n"
        "**Output**:\n\n"
        " SELECT Department, AVG(Salary) AS average_salary\n"
        " FROM Employees\n"
        " GROUP BY Department;\n\n"
        "6. **Input**: \"Find the name of employee with Highest salary in HR department.\"\n"
        "**Output**:\n\n"
        " SELECT Name\n"
        " FROM Employees\n"
        " WHERE Department = 'HR'\n"
        " AND Salary = (SELECT MAX(Salary)\n"
        " FROM Employees\n"
        " WHERE Department = 'HR');\n\n"
        "Please do not return additional text besides query.\n"
        "Please only answer queries which makes sense for the given schema. Else just return - \"No information found\""
    )

    # Contents of a ```sql ... ``` markdown fence (case-insensitive tag).
    _FENCED_SQL = re.compile(r"```sql(.*?)```", re.DOTALL | re.IGNORECASE)
    # First "select ... ;" statement, shortest match up to the first semicolon.
    _FIRST_SELECT = re.compile(r"(?i)(select\s.*?;)", re.DOTALL)

    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        """Create the inference client for ``model_name``."""
        self.client = InferenceClient(model_name)

    def generate_sql(self, natural_language_query):
        """Ask the model for a SQL query answering ``natural_language_query``.

        Args:
            natural_language_query: The user's question, sent as the user
                message after the fixed system prompt.

        Returns:
            The model's reply collapsed onto a single line. When the reply
            contains a ``select ... ;`` statement only that statement is
            returned; otherwise the whole normalised reply is returned (for
            example "No information found"). An empty string is returned when
            the response carries no text.

        Side effects:
            Prints the returned string to stdout.
        """
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": natural_language_query},
        ]

        result = self.client.chat_completion(
            messages,
            max_tokens=150,
            stream=False,
            temperature=0.7,
            top_p=0.95,
        )

        sql_query = self._clean_sql(self._response_text(result))

        print(sql_query)
        return sql_query

    @staticmethod
    def _choice_text(response, field):
        """Return ``response.choices[0].<field>.content``, or "" when absent."""
        choices = response.choices
        if not choices:
            # An empty (or None) choices list used to raise IndexError/TypeError.
            return ""
        payload = getattr(choices[0], field, None)
        # ``content`` may be None; treating it as "" avoids a TypeError in the
        # later string handling.
        return getattr(payload, "content", None) or ""

    @classmethod
    def _response_text(cls, result):
        """Flatten any of the accepted response shapes into one string."""
        if isinstance(result, str):
            return result

        if isinstance(result, list):
            # A list is treated as a sequence of chunks: plain strings,
            # objects exposing ``choices`` (text read from ``delta.content``),
            # or anything else, which is stringified.
            parts = []
            for token in result:
                if isinstance(token, str):
                    parts.append(token)
                elif hasattr(token, "choices"):
                    parts.append(cls._choice_text(token, "delta"))
                else:
                    parts.append(str(token))
            return "".join(parts)

        if hasattr(result, "choices"):
            return cls._choice_text(result, "message")

        return str(result)

    @classmethod
    def _clean_sql(cls, text):
        """Strip markdown fencing, collapse whitespace, isolate the SELECT."""
        fenced = cls._FENCED_SQL.search(text)
        if fenced:
            text = fenced.group(1).strip()

        # Literal backslash-n sequences (two characters) become spaces; real
        # newlines and runs of whitespace are collapsed by split/join.
        text = " ".join(text.replace("\\n", " ").split())

        statement = cls._FIRST_SELECT.search(text)
        if statement:
            text = statement.group(1)
        return text
|
|
|