| import gradio as gr |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
| import torch |
| import re |
| import sqlparse |
|
|
| |
# Hugging Face Hub id shared by the model weights and the tokenizer.
_MODEL_ID = "onkolahmet/Qwen2-0.5B-Instruct-SQL-generator"

# Device that tokenized inputs are moved to before generation.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the fine-tuned causal LM; dtype and device placement are left to the
# library ("auto").
model = AutoModelForCausalLM.from_pretrained(
    _MODEL_ID,
    torch_dtype="auto",
    device_map="auto",
)

# Tokenizer matching the model checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
def generate_sql(question, context=None):
    """Generate a raw SQL completion for a natural-language question.

    Builds a plain-text prompt (instruction header, optional table context,
    then ``Q: <question>`` / ``SQL:``), samples up to 128 new tokens from the
    module-level ``model`` and decodes only the newly generated tokens.

    Args:
        question: Natural-language question to translate.
        context: Optional schema / table definitions. Ignored when ``None``,
            empty or whitespace-only.

    Returns:
        The decoded completion with surrounding whitespace stripped. The text
        is unprocessed model output and may contain several statements,
        comments or code fences.
    """
    prompt = "Translate natural language questions to SQL queries.\n\n"

    # Only add the schema section when the caller actually supplied one.
    if context and context.strip():
        prompt += f"Table Context:\n{context}\n\n"

    prompt += f"Q: {question}\nSQL:"

    # The model is loaded with device_map="auto", so ask the model where it
    # lives instead of assuming it matches the hand-computed `device` string.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Unpack the whole encoding so the attention mask is passed along with the
    # input ids; passing only input_ids drops the mask.
    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        eos_token_id=tokenizer.eos_token_id,
    )

    # generate() returns prompt + completion; slice the prompt tokens off.
    prompt_length = inputs["input_ids"].shape[-1]
    sql_query = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    return sql_query.strip()
|
|
# Tokens that matter when looking for statement boundaries in text that has no
# semicolons. String literals and quoted identifiers are matched first so that
# parentheses or the word SELECT inside them are skipped; a set operator
# followed by SELECT is consumed as one token so it never counts as the start
# of a new statement.
_SQL_BOUNDARY_TOKEN = re.compile(
    r"'(?:[^']|'')*'"
    r'|"[^"]*"'
    r"|\b(?:UNION|INTERSECT|EXCEPT)(?:\s+(?:ALL|DISTINCT))?\s+SELECT\b"
    r"|\bSELECT\b"
    r"|[()]",
    re.IGNORECASE,
)


def _top_level_select_starts(sql):
    """Return offsets of SELECT keywords that open a new top-level statement."""
    starts = []
    depth = 0
    for match in _SQL_BOUNDARY_TOKEN.finditer(sql):
        token = match.group()
        if token == '(':
            depth += 1
        elif token == ')':
            # Clamp so an unbalanced ')' cannot hide later statements.
            depth = max(depth - 1, 0)
        elif depth == 0 and token.upper() == 'SELECT':
            starts.append(match.start())
    return starts


def _dedup_key(query):
    """Return a case/whitespace-insensitive form of *query* for duplicate detection."""
    try:
        return sqlparse.format(
            query + ('' if query.strip().endswith(';') else ';'),
            keyword_case='lower',
            identifier_case='lower',
            strip_comments=True,
            reindent=True
        )
    except Exception:
        # Best effort: fall back to a plain lower-cased, whitespace-collapsed key.
        return re.sub(r'\s+', ' ', query.lower().strip())


def clean_sql_output(sql_text):
    """Reduce raw model output to a single, formatted SQL statement.

    Steps:
        1. Strip ``--`` line comments, ``/* */`` block comments and Markdown
           code fences.
        2. Split into candidate statements: on ``;`` when present, otherwise
           at each top-level ``SELECT`` (subqueries in parentheses and
           ``UNION``/``INTERSECT``/``EXCEPT`` arms stay attached to their
           statement).
        3. Drop duplicates (compared case- and whitespace-insensitively).
        4. Keep the longest statement containing ``SELECT``, or the longest
           statement overall when none does.
        5. Terminate with ``;``, collapse whitespace and pretty-print with
           ``sqlparse`` (upper-case keywords, lower-case identifiers).

    Args:
        sql_text: Raw text produced by the model. ``None`` or empty input is
            accepted.

    Returns:
        The formatted statement, the unformatted statement if formatting
        fails, or ``""`` when no statement could be extracted.

    Note:
        Comment stripping and ``;`` splitting are purely textual, so ``--`` or
        ``;`` inside a string literal is treated like the real thing.
    """
    if not sql_text:
        return ""

    sql_text = re.sub(r'--.*?$', '', sql_text, flags=re.MULTILINE)
    sql_text = re.sub(r'/\*.*?\*/', '', sql_text, flags=re.DOTALL)

    # Remove Markdown fences, with or without a (case-insensitive) sql tag.
    sql_text = re.sub(r'```(?:sql)?', '', sql_text, flags=re.IGNORECASE)

    if ';' in sql_text:
        queries = [q.strip() for q in sql_text.split(';')]
    else:
        flattened = re.sub(r'\s+', ' ', sql_text)
        starts = _top_level_select_starts(flattened)
        if len(starts) > 1:
            # Slice between consecutive top-level SELECTs; the last slice runs
            # to the end of the text.
            bounds = starts + [len(flattened)]
            queries = [flattened[a:b].strip() for a, b in zip(bounds, bounds[1:])]
        else:
            queries = [sql_text]

    queries = [q for q in queries if q.strip()]
    if not queries:
        return ""

    if len(queries) == 1:
        best_query = queries[0]
    else:
        # Keep the first occurrence of each distinct statement, in order.
        seen = set()
        unique_queries = []
        for query in queries:
            key = _dedup_key(query)
            if key not in seen:
                seen.add(key)
                unique_queries.append(query)

        select_queries = [
            q for q in unique_queries if re.search(r'\bSELECT\b', q, re.IGNORECASE)
        ]
        # Heuristic: the longest candidate is taken to be the most complete one.
        best_query = max(select_queries or unique_queries, key=len)

    best_query = best_query.strip()
    if not best_query.endswith(';'):
        best_query += ';'

    best_query = re.sub(r'\s+', ' ', best_query)

    try:
        return sqlparse.format(
            best_query,
            keyword_case='upper',
            identifier_case='lower',
            reindent=True,
            indent_width=2
        )
    except Exception:
        # Formatting is cosmetic; return the unformatted statement on failure.
        return best_query
|
|
def process_input(question, table_context):
    """Run a user question through the model and return text for the UI.

    Args:
        question: Natural-language question from the UI. ``None``, empty and
            whitespace-only values are all treated as "no question".
        table_context: Optional schema text forwarded to ``generate_sql``.

    Returns:
        The cleaned SQL statement, or a user-facing message when the question
        is missing or no statement could be extracted from the model output.
    """
    # Guard against None as well as blank strings so .strip() cannot raise.
    if not question or not question.strip():
        return "Please enter a question."

    raw_sql = generate_sql(question, table_context)
    cleaned_sql = clean_sql_output(raw_sql)

    if not cleaned_sql:
        return "Sorry, I couldn't generate a valid SQL query. Please try rephrasing your question."

    return cleaned_sql
|
|
| |
# Single-table schema: customers.
_CUSTOMERS_SCHEMA = """
CREATE TABLE customers (
id INT PRIMARY KEY,
name VARCHAR(100),
email VARCHAR(100),
order_date DATE
);
"""

# Single-table schema: products.
_PRODUCTS_SCHEMA = """
CREATE TABLE products (
id INT PRIMARY KEY,
name VARCHAR(100),
category VARCHAR(50),
price DECIMAL(10,2),
stock_quantity INT
);
"""

# Two-table schema: employees and departments.
_EMPLOYEES_DEPARTMENTS_SCHEMA = """
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2),
hire_date DATE
);
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(50),
manager_id INT,
budget DECIMAL(15,2)
);
"""

# Sample table contexts, indexed by position when building the UI examples:
# 0 = customers, 1 = products, 2 = employees + departments.
example_contexts = [
    _CUSTOMERS_SCHEMA,
    _PRODUCTS_SCHEMA,
    _EMPLOYEES_DEPARTMENTS_SCHEMA,
]
|
|
| |
# Sample natural-language questions.
# NOTE(review): nothing in the visible code reads this list (the UI examples
# use their own inline questions) - confirm whether it is still needed.
example_questions = [
    "Get the names and emails of customers who placed an order in the last 30 days.",
    "Find all products with less than 10 items in stock.",
    "List all employees in the Sales department with a salary greater than 50000.",
    "What is the total budget for departments with more than 5 employees?",
    "Count how many products are in each category where the price is greater than 100."
]
|
|
| |
# Build the Gradio interface. Components are created in statement order inside
# the nested context managers, so the statements are kept in their original
# sequence.
# NOTE(review): the source had lost its indentation; the nesting below is the
# conventional reading (inputs + button in the first column, output in the
# second, examples and "About" below the row) - confirm against the original.
with gr.Blocks(title="Text to SQL Converter") as demo:
    gr.Markdown("# Text to SQL Query Converter")
    gr.Markdown("Enter your question and optional table context to generate an SQL query.")

    with gr.Row():
        # First column: question, optional schema, and the submit button.
        with gr.Column():
            question_input = gr.Textbox(
                label="Your Question",
                placeholder="e.g., Find all products with price less than $50",
                lines=2
            )

            table_context = gr.Textbox(
                label="Table Context (Optional)",
                placeholder="Enter your database schema or table definitions here...",
                lines=10
            )

            submit_btn = gr.Button("Generate SQL Query")

        # Second column: the generated query, shown as SQL code.
        with gr.Column():
            sql_output = gr.Code(
                label="Generated SQL Query",
                language="sql",
                lines=12
            )

    gr.Markdown("### Try some examples")

    # Each example is a [question, table context] pair matching `inputs` below.
    example_selector = gr.Examples(
        examples=[
            ["List all products in the 'Electronics' category with price less than $500", example_contexts[1]],
            ["Find the total number of employees in each department", example_contexts[2]],
            ["Get customers who placed orders in the last 7 days", example_contexts[0]],
            ["Count the number of products in each category", example_contexts[1]],
            ["Find the average salary by department", example_contexts[2]]
        ],
        inputs=[question_input, table_context]
    )

    # Clicking the button runs process_input and writes its result to sql_output.
    submit_btn.click(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output
    )

    # Submitting the question textbox triggers the same handler as the button.
    question_input.submit(
        fn=process_input,
        inputs=[question_input, table_context],
        outputs=sql_output
    )

    # Static "About" section. The string is kept exactly as it appears in the
    # source, hence the flush-left lines.
    gr.Markdown("""
### About
This app uses a fine-tuned language model to convert natural language questions into SQL queries.

- **Model**: [onkolahmet/Qwen2-0.5B-Instruct-SQL-generator](https://huggingface.co/onkolahmet/Qwen2-0.5B-Instruct-SQL-generator)
- **How to use**:
1. Enter your question in natural language
2. If you have specific table schemas, add them in the Table Context field
3. Click "Generate SQL Query" or press Enter

Note: The model works best when table context is provided, but can generate generic SQL queries without it.
""")
|
|
| |
# Launch the app. There is no `__main__` guard, so this runs whenever the
# module is executed or imported.
demo.launch()