File size: 4,540 Bytes
4254987
 
0276b0d
 
4254987
 
0276b0d
 
4254987
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0276b0d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import re
from huggingface_hub import InferenceClient

class SQLModel:
    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        self.client = InferenceClient(model_name)

    def generate_sql(self, natural_language_query):
        prompt = (
            "You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries "
            "into correct and optimized SQL statements.\n\n"
            "Here is the schema information for our database :\n\n"
            "Table: Employees\n"
            "- id (INT)\n"
            "- NAME (VARCHAR)\n"
            "- Department (VARCHAR)\n"
            "- Salary (INT)\n"
            "- Hire_Date (DATE)\n\n"
            "Table: Departments\n"
            "- ID (INT)\n"
            "- Name (VARCHAR)\n"
            "- Manager (VARCHAR)\n\n"
            "Here are a few examples:\n\n"
            "1. **Input**: \"Show me all employees in the Sales department.\"\n"
            "**Output**:\n\n"
            "        SELECT *\n"
            "        FROM Employees\n"
            "        WHERE Department = 'Sales';\n\n"
            "2. **Input**: \"Who is the manager of the Engineering department?\"\n"
            "**Output**:\n\n"
            "        SELECT Manager\n"
            "        FROM Departments\n"
            "        WHERE Name = 'Engineering';\n\n"
            "3. **Input**: \"List all employees hired after 2021-01-01.\"\n"
            "**Output**:\n\n"
            "        SELECT *\n"
            "        FROM Employees\n"
            "        WHERE Hire_Date > '2021-01-01';\n\n"
            "4. **Input**: \"What is the total salary expense for the Marketing department?\"\n"
            "**Output**:\n\n"
            "        SELECT SUM(Salary)\n"
            "        FROM Employees\n"
            "        WHERE Department = 'Marketing';\n\n"
            "5. **Input**: \"Find the average salary of employees in each department.\"\n"
            "**Output**:\n\n"
            "        SELECT Department, AVG(Salary) AS average_salary\n"
            "        FROM Employees\n"
            "        GROUP BY Department;\n\n"
            "Please do not return additional text besides query.\n"
            "Please only answer queries which makes sense for the given schema. Else just return - \"No information found\""
        )

        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": natural_language_query}
        ]

        result = self.client.chat_completion(
            messages,
            max_tokens=150,
            stream=False,
            temperature=0.7,
            top_p=0.95,
        )

        # Initialize a variable to hold the extracted SQL text.
        sql_query = ""

        # Check if the result is a plain string.
        if isinstance(result, str):
            sql_query = result
        # If the result is a list, iterate over its tokens.
        elif isinstance(result, list):
            for token in result:
                if isinstance(token, str):
                    sql_query += token
                elif hasattr(token, "choices"):
                    # Extract from the structured object.
                    sql_query += token.choices[0].delta.content
                else:
                    sql_query += str(token)
        # Otherwise, if it's an object with choices, extract its content.
        elif hasattr(result, "choices"):
            sql_query = result.choices[0].message.content
        else:
            sql_query = str(result)

        # Optional: If the model output is in a markdown code block, extract only that content.
        match = re.search(r"```sql(.*?)```", sql_query, re.DOTALL | re.IGNORECASE)
        if match:
            sql_query = match.group(1).strip()

        # Remove both literal "\n" substrings and actual newline characters.
        sql_query = sql_query.replace("\\n", " ").replace("\n", " ")
        # Remove extra spaces.
        sql_query = " ".join(sql_query.split())

        # Extract only the SQL command: starting from the first occurrence of "select" to the first semicolon.
        extraction_pattern = r"(?i)(select\s.*?;)"
        extraction_match = re.search(extraction_pattern, sql_query, re.DOTALL)
        if extraction_match:
            sql_query = " ".join(extraction_match.group(1).split())

        print(sql_query)
        return sql_query