jpatel commited on
Commit
22654ec
·
1 Parent(s): 027dfb7

adding sql generator app

Browse files
Files changed (2) hide show
  1. app.py +128 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import re
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+
6
+ MODEL = "jinesh90/qwen2.5-coder-sql-generator"
7
+
8
+ print("Loading model...")
9
+ tokenizer = AutoTokenizer.from_pretrained(MODEL)
10
+ model = AutoModelForCausalLM.from_pretrained(
11
+ MODEL,
12
+ torch_dtype = torch.float16,
13
+ device_map = "auto",
14
+ low_cpu_mem_usage = True,
15
+ )
16
+ model.eval()
17
+ print("Ready!")
18
+
19
+ def clean_sql(text):
20
+ text = text.strip()
21
+ clean = re.sub(r'[^\x00-\x7F].*', '', text).strip()
22
+ for stop in ["###", "assistant", "\n\n"]:
23
+ if stop in clean:
24
+ clean = clean.split(stop)[0].strip()
25
+ return clean
26
+
27
+ def build_prompt(question, schema):
28
+ return f"""You are a SQL expert. Generate the simplest and most direct SQL query.
29
+ Use JOINs only when multiple tables are needed.
30
+
31
+ ### Schema:
32
+ {schema}
33
+
34
+ ### Question:
35
+ {question}
36
+
37
+ ### SQL:"""
38
+
39
+ def generate(question, schema):
40
+ if not question or not schema:
41
+ return "Please provide both a question and schema!"
42
+
43
+ messages = [{"role": "user", "content": build_prompt(question, schema)}]
44
+ text = tokenizer.apply_chat_template(
45
+ messages,
46
+ tokenize = False,
47
+ add_generation_prompt = True
48
+ )
49
+ inputs = tokenizer(
50
+ text,
51
+ return_tensors = "pt",
52
+ truncation = True,
53
+ max_length = 1024
54
+ ).to(model.device)
55
+
56
+ stop_tokens = [
57
+ tokenizer.eos_token_id,
58
+ tokenizer.convert_tokens_to_ids("<|im_end|>"),
59
+ ]
60
+
61
+ with torch.no_grad():
62
+ outputs = model.generate(
63
+ **inputs,
64
+ max_new_tokens = 200,
65
+ do_sample = False,
66
+ temperature = 0,
67
+ repetition_penalty = 1.3,
68
+ eos_token_id = stop_tokens,
69
+ pad_token_id = tokenizer.eos_token_id,
70
+ )
71
+
72
+ input_len = inputs["input_ids"].shape[1]
73
+ raw = tokenizer.decode(outputs[0, input_len:], skip_special_tokens=True)
74
+ return clean_sql(raw)
75
+
76
+ # Example schemas for demo
77
+ example_schema = """CREATE TABLE employees (
78
+ id INTEGER,
79
+ name VARCHAR,
80
+ salary REAL,
81
+ department VARCHAR,
82
+ age INTEGER
83
+ );"""
84
+
85
+ with gr.Blocks(title="SQL Query Generator") as demo:
86
+ gr.Markdown("# 🗄️ SQL Query Generator")
87
+ gr.Markdown("Fine-tuned Qwen2.5-Coder 7B on Spider dataset | 42% execution accuracy")
88
+
89
+ with gr.Row():
90
+ with gr.Column():
91
+ schema = gr.Textbox(
92
+ label = "Database Schema (CREATE TABLE statements)",
93
+ value = example_schema,
94
+ lines = 10
95
+ )
96
+ question = gr.Textbox(
97
+ label = "Question",
98
+ placeholder = "How many employees have salary > 50000?",
99
+ lines = 2
100
+ )
101
+ btn = gr.Button("🚀 Generate SQL", variant="primary")
102
+
103
+ with gr.Column():
104
+ output = gr.Code(
105
+ label = "Generated SQL",
106
+ language = "sql"
107
+ )
108
+ gr.Markdown("""
109
+ ### 📊 Model Stats
110
+ - **Base model**: Qwen2.5-Coder-7B
111
+ - **Training data**: Spider dataset (7.9k samples)
112
+ - **Simple queries**: 64.2% accuracy
113
+ - **Complex queries**: 17.0% accuracy
114
+ - **Overall**: 42% execution accuracy
115
+ """)
116
+
117
+ btn.click(fn=generate, inputs=[question, schema], outputs=output)
118
+
119
+ gr.Examples(
120
+ examples=[
121
+ ["How many employees are there?", example_schema],
122
+ ["Find all employees with salary greater than 50000", example_schema],
123
+ ["What is the average salary by department?", example_schema],
124
+ ],
125
+ inputs=[question, schema]
126
+ )
127
+
128
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ accelerate
4
+ gradio