Chit1324 committed on
Commit
4254987
·
verified ·
1 Parent(s): 7cf4e28

Upload 4 files

Browse files
Files changed (1) hide show
  1. nlp/sql_model.py +98 -106
nlp/sql_model.py CHANGED
@@ -1,113 +1,105 @@
1
- from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
- # import os
3
- # from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
  class SQLModel:
6
- def __init__(self, model_name="google/flan-t5-base"):
7
- self.tokenizer = AutoTokenizer.from_pretrained(model_name)
8
- self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
9
- # hf_token = os.environ.get("HF_HUB_TOKEN")
10
- # self.tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b-it",use_auth_token=hf_token)
11
- # self.model = AutoModelForCausalLM.from_pretrained("google/gemma-2b-it",use_auth_token=hf_token)
12
 
13
  def generate_sql(self, natural_language_query):
14
- input_text = f"""You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries into correct and optimized SQL statements.
15
- Here is the schema information for our database :
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- Table: Employees
18
- - id (INT)
19
- - NAME (VARCHAR)
20
- - Department (VARCHAR)
21
- - Salary (INT)
22
- - Hire_Date (DATE)
23
-
24
- Table: Departments
25
- - ID (INT)
26
- - Name (VARCHAR)
27
- - Manager (VARCHAR)
28
-
29
- Here are a few examples:
30
-
31
- 1. **Input**: "Show me all employees in the Sales department."
32
- **Output**:
33
-
34
- SELECT *
35
- FROM Employees
36
- WHERE Department = 'Sales';
37
-
38
- 2. **Input**: "Who is the manager of the Engineering department?"
39
- **Output**:
40
-
41
- SELECT Manager
42
- FROM Departments
43
- WHERE Name = 'Engineering';
44
-
45
-
46
- 3. **Input**: "List all employees hired after 2021-01-01."
47
- **Output**:
48
-
49
- SELECT *
50
- FROM Employees
51
- WHERE Hire_Date > '2021-01-01';
52
-
53
-
54
- 4. **Input**: "What is the total salary expense for the Marketing department?"
55
- **Output**:
56
-
57
- SELECT SUM(Salary)
58
- FROM Employees
59
- WHERE Department = 'Marketing';
60
-
61
-
62
- 5. **Input**: "Find the average salary of employees in each department."
63
- **Output**:
64
-
65
- SELECT Department, AVG(Salary) AS average_salary
66
- FROM Employees
67
- GROUP BY Department;
68
-
69
- Please do not return additional text besides query.
70
-
71
- Please only answer queries which makes sense for the given schema. Else just return - "No information found"
72
-
73
- Now, translate the following natural language query into an syntactically correct SQL query:
74
- **Input**: {natural_language_query}
75
- **Output**:
76
-
77
- """
78
- # input_text = f"""
79
- # Translate the following natural language query into a syntactically correct SQL query using the provided database schema. Output only the SQL query with no additional text or explanation.
80
-
81
- # Database Schema:
82
-
83
- # Table: Employees
84
- # - id (INT)
85
- # - NAME (VARCHAR)
86
- # - Department (VARCHAR)
87
- # - Salary (INT)
88
- # - Hire_Date (DATE)
89
-
90
- # Table: Departments
91
- # - ID (INT)
92
- # - Name (VARCHAR)
93
- # - Manager (VARCHAR)
94
-
95
- # Examples:
96
- # 1. Natural Language Query: "List all employees who were hired after '2020-01-01'."
97
- # Output: SELECT * FROM Employees WHERE Hire_Date > '2020-01-01';
98
-
99
- # 2. Natural Language Query: "Retrieve the names and salaries of employees in the 'Sales' department."
100
- # Output: SELECT NAME, Salary FROM Employees WHERE Department = 'Sales';
101
-
102
- # Now, translate the following query:
103
- # {natural_language_query}
104
- # """
105
- # input_text = f"translate English to SQL: {natural_language_query}"
106
- # inputs = self.tokenizer(input_text, return_tensors="pt").input_ids
107
- # outputs = self.model.generate(inputs, max_new_tokens=100, do_sample=False)
108
- # sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
109
- inputs = self.tokenizer(input_text, return_tensors="pt")
110
- outputs = self.model.generate(**inputs)
111
- sql_query = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
112
  print(sql_query)
113
  return sql_query
 
1
+ import re
2
+ from huggingface_hub import InferenceClient
 
3
 
4
class SQLModel:
    """Translate natural-language questions into SQL using a hosted chat model.

    The model is prompted with a fixed description of the ``Employees`` and
    ``Departments`` tables plus five worked examples, and its reply is
    post-processed down to a single-line SQL statement.
    """

    # System prompt sent with every request. It never changes, so it is built
    # once at class-definition time instead of on every call.
    _SYSTEM_PROMPT = (
        "You are a highly skilled SQL translator. Your task is to convert natural language descriptions of data queries "
        "into correct and optimized SQL statements.\n\n"
        "Here is the schema information for our database :\n\n"
        "Table: Employees\n"
        "- id (INT)\n"
        "- NAME (VARCHAR)\n"
        "- Department (VARCHAR)\n"
        "- Salary (INT)\n"
        "- Hire_Date (DATE)\n\n"
        "Table: Departments\n"
        "- ID (INT)\n"
        "- Name (VARCHAR)\n"
        "- Manager (VARCHAR)\n\n"
        "Here are a few examples:\n\n"
        "1. **Input**: \"Show me all employees in the Sales department.\"\n"
        "**Output**:\n\n"
        " SELECT *\n"
        " FROM Employees\n"
        " WHERE Department = 'Sales';\n\n"
        "2. **Input**: \"Who is the manager of the Engineering department?\"\n"
        "**Output**:\n\n"
        " SELECT Manager\n"
        " FROM Departments\n"
        " WHERE Name = 'Engineering';\n\n"
        "3. **Input**: \"List all employees hired after 2021-01-01.\"\n"
        "**Output**:\n\n"
        " SELECT *\n"
        " FROM Employees\n"
        " WHERE Hire_Date > '2021-01-01';\n\n"
        "4. **Input**: \"What is the total salary expense for the Marketing department?\"\n"
        "**Output**:\n\n"
        " SELECT SUM(Salary)\n"
        " FROM Employees\n"
        " WHERE Department = 'Marketing';\n\n"
        "5. **Input**: \"Find the average salary of employees in each department.\"\n"
        "**Output**:\n\n"
        " SELECT Department, AVG(Salary) AS average_salary\n"
        " FROM Employees\n"
        " GROUP BY Department;\n\n"
        "Please do not return additional text besides query.\n"
        "Please only answer queries which makes sense for the given schema. Else just return - \"No information found\""
    )

    # Markdown fence explicitly tagged as SQL: ```sql ... ```
    _SQL_FENCE_RE = re.compile(r"```sql(.*?)```", re.DOTALL | re.IGNORECASE)
    # Fallback: any fenced block, for replies whose fence carries no language tag.
    _ANY_FENCE_RE = re.compile(r"```(.*?)```", re.DOTALL)
    # First SELECT statement, up to and including its terminating semicolon.
    _SELECT_STATEMENT_RE = re.compile(r"(?i)(select\s.*?;)", re.DOTALL)

    def __init__(self, model_name="HuggingFaceH4/zephyr-7b-beta"):
        """Create the inference client for ``model_name``."""
        self.client = InferenceClient(model_name)

    def generate_sql(self, natural_language_query):
        """Return the SQL the model produces for ``natural_language_query``.

        Args:
            natural_language_query: The user's question, sent as the user
                message after the fixed system prompt.

        Returns:
            The cleaned reply as a single-line string: the first
            ``SELECT ...;`` statement if one is present, otherwise the
            whitespace-normalised reply text (for example the
            "No information found" refusal). Empty if the model returned no
            text. The value is also printed to stdout.
        """
        messages = [
            {"role": "system", "content": self._SYSTEM_PROMPT},
            {"role": "user", "content": natural_language_query},
        ]

        result = self.client.chat_completion(
            messages,
            max_tokens=150,
            stream=False,
            temperature=0.7,
            top_p=0.95,
        )

        sql_query = self._clean_sql(self._response_text(result))
        print(sql_query)
        return sql_query

    @staticmethod
    def _response_text(result):
        """Flatten whatever shape the client returned into plain text."""
        if isinstance(result, str):
            return result

        if isinstance(result, list):
            # Sequence of pieces: plain strings and/or streamed chunk objects.
            pieces = []
            for token in result:
                if isinstance(token, str):
                    pieces.append(token)
                elif hasattr(token, "choices"):
                    # A chunk's content may be None (nothing to append);
                    # treating it as "" avoids a TypeError on concatenation.
                    pieces.append(token.choices[0].delta.content or "")
                else:
                    pieces.append(str(token))
            return "".join(pieces)

        if hasattr(result, "choices"):
            # Guard against a None content so the regex cleanup below
            # always receives a string.
            return result.choices[0].message.content or ""

        return str(result)

    @classmethod
    def _clean_sql(cls, raw_text):
        """Strip markdown fences and whitespace noise; isolate the SELECT statement."""
        text = raw_text

        # Prefer an explicitly SQL-tagged fence; fall back to an untagged one
        # so replies wrapped in bare ``` do not leak backticks into the query.
        fenced = cls._SQL_FENCE_RE.search(text) or cls._ANY_FENCE_RE.search(text)
        if fenced:
            text = fenced.group(1).strip()

        # Drop literal "\n" escape sequences, then collapse every run of
        # whitespace (including real newlines) to a single space.
        text = " ".join(text.replace("\\n", " ").split())

        # Keep only the first "select ... ;" statement when one is present.
        statement = cls._SELECT_STATEMENT_RE.search(text)
        if statement:
            text = statement.group(1)

        return text