import html
import json
import os
import re
import sqlite3
import warnings
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
from groq import Groq
from sklearn.metrics import accuracy_score
# Silence every Python warning for the whole process (library deprecation
# chatter would otherwise flood the Space logs).
warnings.filterwarnings('ignore')

# Groq API key, read from the environment (e.g. Hugging Face Space secrets).
# Surrounding whitespace is stripped: a secret pasted with a trailing newline
# would otherwise produce an invalid Authorization header.
GROQ_API_KEY = (os.getenv("GROQ_API_KEY") or "").strip()

if not GROQ_API_KEY:
    print("⚠️ WARNING: GROQ_API_KEY environment variable not set!")
    print("Please add your Groq API key to your Hugging Face Space secrets.")
    print("For demo purposes, the app will continue but API calls will fail.")
    # Placeholder sentinel: EnhancedNL2SQLConverter compares against this exact
    # string and skips creating a Groq client when it sees it.
    GROQ_API_KEY = "dummy-key-for-demo"
|
|
| |
| |
| |
|
|
class EnhancedNL2SQLConverter:
    """Translate natural-language questions into SQL with a Groq-hosted LLM.

    The question and a textual schema description are sent to the Groq
    chat-completions API; the reply is post-processed into a bare SQL
    statement by ``_clean_sql``.
    """

    # A cleaned reply that already starts with one of these is kept as-is.
    # WITH is included so common-table-expression queries are not truncated
    # to their inner SELECT.
    _LEADING_KEYWORDS = ("SELECT", "WITH", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER")
    # Otherwise the reply is cut at the first of these that occurs (tried in
    # this priority order) to drop leading prose. WITH is deliberately absent:
    # the English word "with" would match far too eagerly.
    _FALLBACK_KEYWORDS = ("SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER")

    def __init__(
        self,
        model_name: str = "llama-3.3-70b-versatile",
        dialect: str = "SQLite",
        temperature: float = 0.1,
        max_tokens: int = 200,
    ):
        """Create the converter and, if a real API key is configured, a Groq client.

        Args:
            model_name: Groq model identifier used for completions.
            dialect: SQL dialect the model is told to target; it should match
                the database the generated SQL runs against.
            temperature: Sampling temperature passed to the API.
            max_tokens: Completion length limit passed to the API.

        Client creation failures are printed and leave ``self.client`` as
        None; ``generate_sql`` then returns an ``"ERROR: ..."`` string.
        """
        self.model_name = model_name
        self.dialect = dialect
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.client = None

        try:
            if GROQ_API_KEY and GROQ_API_KEY != "dummy-key-for-demo":
                self.client = Groq(api_key=GROQ_API_KEY)
                print(f"✅ Successfully initialized Groq client with model: {self.model_name}")
            else:
                print("⚠️ Groq client not initialized - API key missing")
        except Exception as e:
            print(f"❌ Error initializing Groq client: {str(e)}")
            self.client = None

        # Schema description used when generate_sql() is not given one.
        self.default_schema = """
Table: employees
Columns:
- id (INTEGER) PRIMARY KEY
- name (TEXT) NOT NULL
- department (TEXT)
- salary (REAL)
- hire_date (TEXT)
- manager_id (INTEGER)
"""

    def generate_sql(self, query: str, schema: Optional[str] = None) -> str:
        """Return a SQL statement answering *query*, or an ``"ERROR: ..."`` message.

        Args:
            query: The natural-language question.
            schema: Schema description shown to the model; falls back to
                ``self.default_schema`` when None or empty.

        Returns:
            The cleaned SQL text. Failures (no client, API error, empty reply)
            are reported as a string starting with ``"ERROR"`` instead of
            raising; callers test for that prefix.
        """
        if not self.client:
            return "ERROR: Groq API client not initialized. Please check your API key."

        schema_to_use = schema or self.default_schema

        # Naming the dialect matters: without it the model may emit functions
        # (e.g. MySQL's YEAR()) that the target database does not support.
        system_prompt = f"""You are an expert SQL query generator. Convert natural language questions to SQL queries based on the provided database schema.

Rules:
1. Only return the SQL query, nothing else
2. Use proper {self.dialect} SQL syntax and only functions supported by {self.dialect}
3. Be precise with column names and table names
4. Use appropriate WHERE clauses, JOINs, and aggregations as needed
5. For date comparisons, use proper date format
6. Don't include explanations, just the SQL query"""

        user_prompt = f"""Database Schema:
{schema_to_use}

Natural Language Question: {query}

Generate the SQL query:"""

        try:
            chat_completion = self.client.chat.completions.create(
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
                model=self.model_name,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
            )
            # message.content may be None (e.g. an empty completion).
            sql_query = self._clean_sql(chat_completion.choices[0].message.content or "")
        except Exception as e:
            print(f"Error generating SQL: {str(e)}")
            return f"ERROR: Could not generate SQL query - {str(e)}"

        if not sql_query:
            return "ERROR: Could not generate SQL query - the model returned no SQL"
        return sql_query

    def _clean_sql(self, sql: str) -> str:
        """Reduce a raw model reply to a single bare SQL statement.

        Removes Markdown code fences (any case, optional ``sql`` tag), quotes
        that wrap the entire reply, trailing semicolons/whitespace, and any
        prose before the first SQL keyword.
        """
        sql = sql.strip()
        sql = re.sub(r"```(?:sql)?[ \t]*\n?", "", sql, flags=re.IGNORECASE).strip()

        # Only strip quotes that wrap the whole reply. A lone trailing quote
        # normally closes a string literal (WHERE department = 'Sales') and
        # must be kept.
        if len(sql) >= 2 and sql[0] == sql[-1] and sql[0] in "\"'":
            sql = sql[1:-1].strip()

        # Strip whitespace first so a semicolon followed by a newline is removed too.
        sql = sql.rstrip().rstrip(";").rstrip()

        if not sql.upper().startswith(self._LEADING_KEYWORDS):
            for keyword in self._FALLBACK_KEYWORDS:
                # Search the original string so the slice index is exact.
                match = re.search(keyword, sql, flags=re.IGNORECASE)
                if match:
                    sql = sql[match.start():]
                    break
        return sql
|
|
| |
| |
| |
|
|
class SQLEvaluator:
    """Execute SQL against a small SQLite database seeded with sample employees."""

    def __init__(self, db_path: str = "test_database.db"):
        """Create (or refresh) the sample database.

        Args:
            db_path: SQLite database file path. ``":memory:"`` is not usable
                here because every operation opens a fresh connection.
        """
        self.db_path = db_path
        self.setup_test_database()

    def setup_test_database(self):
        """Create the ``employees`` table if missing and upsert the 8 sample rows.

        ``INSERT OR REPLACE`` keyed on ``id`` restores rows 1-8 to their sample
        values on every call; other rows are left untouched.
        """
        sample_data = [
            (1, 'Alice Johnson', 'Engineering', 75000, '2022-01-15', None),
            (2, 'Bob Smith', 'Sales', 65000, '2021-06-20', None),
            (3, 'Charlie Brown', 'Engineering', 80000, '2020-03-10', 1),
            (4, 'Diana Prince', 'HR', 60000, '2023-02-28', None),
            (5, 'Eve Wilson', 'Sales', 70000, '2022-11-05', 2),
            (6, 'Frank Miller', 'Engineering', 85000, '2019-08-12', 1),
            (7, 'Grace Lee', 'Marketing', 55000, '2023-01-20', None),
            (8, 'Henry Davis', 'Engineering', 72000, '2022-07-30', 1),
        ]

        conn = sqlite3.connect(self.db_path)
        try:
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS employees (
                    id INTEGER PRIMARY KEY,
                    name TEXT NOT NULL,
                    department TEXT,
                    salary REAL,
                    hire_date TEXT,
                    manager_id INTEGER
                )''')
            cursor.executemany('''
                INSERT OR REPLACE INTO employees (id, name, department, salary, hire_date, manager_id)
                VALUES (?, ?, ?, ?, ?, ?)''', sample_data)
            conn.commit()
        finally:
            # Close even if table creation or the insert fails.
            conn.close()
        print("✅ Test database initialized successfully")

    def execute_sql(self, sql_query: str) -> Tuple[bool, Any]:
        """Execute one SQL statement and report the outcome.

        Returns:
            ``(True, {'columns': [...], 'data': [row tuples]})`` when the
            statement yields a result set (SELECT, WITH ... SELECT, PRAGMA, ...).
            ``(True, "Query executed successfully")`` for statements that do not.
            ``(False, error_message)`` if connecting or executing raises.

        NOTE(review): any statement is executed and committed, including
        DELETE/DROP produced from user prompts.
        """
        conn = None
        try:
            conn = sqlite3.connect(self.db_path)
            cursor = conn.cursor()
            cursor.execute(sql_query)
            # cursor.description is set exactly when the statement returns rows;
            # a startswith('SELECT') check misses CTEs and PRAGMAs.
            if cursor.description is not None:
                columns = [description[0] for description in cursor.description]
                results = cursor.fetchall()
                conn.commit()  # harmless for reads; needed for INSERT ... RETURNING
                return True, {'columns': columns, 'data': results}
            conn.commit()
            return True, "Query executed successfully"
        except Exception as e:
            return False, str(e)
        finally:
            if conn is not None:
                conn.close()
|
|
| |
| |
| |
# Build the shared converter and evaluator once at import time.
# The converter is optional: without it the UI still loads and reports the
# problem on each query. The evaluator's database is required, so a failure
# there propagates once instead of being retried inside an except clause.
try:
    converter = EnhancedNL2SQLConverter()
except Exception as e:
    print(f"❌ Error initializing components: {str(e)}")
    converter = None

evaluator = SQLEvaluator()
if converter is not None:
    print("✅ Application components initialized successfully")
|
|
| |
| |
| |
|
|
def process_nl_query(nl_query: str) -> Tuple[str, str, str]:
    """Process natural language query and return SQL + results.

    Uses the module-level ``converter`` and ``evaluator``.

    Args:
        nl_query: The user's question; None or blank input is rejected.

    Returns:
        ``(generated_sql, formatted_results, status_message)``. Result sets
        are rendered with ``DataFrame.to_string``. Status messages contain
        "successfully" on success and "Error"/"Failed" on failure.
    """
    if not nl_query or not nl_query.strip():
        return "", "", "⚠️ Please enter a natural language query."

    try:
        if not converter:
            return "", "", "❌ Error: SQL converter not initialized. Please check API configuration."

        generated_sql = converter.generate_sql(nl_query)
        # generate_sql signals failure by returning an "ERROR..." string.
        if generated_sql.startswith("ERROR"):
            return generated_sql, "", "❌ Failed to generate SQL query. Please check your API key."

        success, result = evaluator.execute_sql(generated_sql)
        if not success:
            return generated_sql, "", f"❌ Error executing query: {result}"

        if isinstance(result, dict):
            df = pd.DataFrame(result['data'], columns=result['columns'])
            formatted_output = "No results found." if df.empty else df.to_string(index=False)
            return generated_sql, formatted_output, "✅ Query executed successfully!"

        # Statements without a result set return a plain message string.
        return generated_sql, str(result), "✅ Query executed successfully!"

    except Exception as e:
        return "", "", f"❌ Unexpected error: {str(e)}"
|
|
# Canonical example questions shown as one-click buttons in the UI.
_SAMPLE_QUERIES = (
    "Show all employees in the Engineering department",
    "Find employees with salary greater than 70000",
    "List all employees hired after 2022",
    "Count employees by department",
    "Show the highest paid employee in each department",
    "Find employees who don't have a manager",
    "Show average salary by department",
)


def get_sample_queries():
    """Return a fresh list of example natural-language questions.

    A new list is built on every call, so callers may mutate the result freely.
    """
    return list(_SAMPLE_QUERIES)
|
|
| |
| |
| |
|
|
# Custom stylesheet passed to gr.Blocks(css=...). It restyles the page
# background and the elements tagged through elem_classes in the layout
# (header-container, card, results-container, sample-btn, fade-in), the
# schema box, and the status-success / status-error / status-warning classes
# emitted by the query handler.
# NOTE(review): .gr-textbox / .gr-button look like Gradio 3.x class names;
# confirm they still match the DOM of the installed Gradio version.
custom_css = """
/* Main container styling */
.gradio-container {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%) !important;
    min-height: 100vh;
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}

/* Header styling */
.header-container {
    background: rgba(255, 255, 255, 0.1);
    backdrop-filter: blur(20px);
    border-radius: 20px;
    padding: 2rem;
    margin-bottom: 2rem;
    border: 1px solid rgba(255, 255, 255, 0.2);
    box-shadow: 0 8px 32px rgba(31, 38, 135, 0.37);
}

/* Card styling */
.card {
    background: rgba(255, 255, 255, 0.95);
    backdrop-filter: blur(20px);
    border-radius: 16px;
    padding: 1.5rem;
    margin: 1rem 0;
    border: 1px solid rgba(255, 255, 255, 0.3);
    box-shadow: 0 8px 32px rgba(31, 38, 135, 0.15);
    transition: all 0.3s ease;
}

.card:hover {
    transform: translateY(-2px);
    box-shadow: 0 12px 40px rgba(31, 38, 135, 0.25);
}

/* Input styling */
.gr-textbox {
    border-radius: 12px !important;
    border: 2px solid rgba(103, 126, 234, 0.3) !important;
    background: rgba(255, 255, 255, 0.9) !important;
    transition: all 0.3s ease !important;
}

.gr-textbox:focus {
    border-color: #667eea !important;
    box-shadow: 0 0 0 3px rgba(103, 126, 234, 0.1) !important;
    transform: scale(1.02);
}

/* Button styling */
.gr-button {
    background: linear-gradient(45deg, #667eea, #764ba2) !important;
    border: none !important;
    border-radius: 12px !important;
    padding: 12px 24px !important;
    font-weight: 600 !important;
    color: white !important;
    transition: all 0.3s ease !important;
    box-shadow: 0 4px 15px rgba(103, 126, 234, 0.4) !important;
}

.gr-button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 8px 25px rgba(103, 126, 234, 0.6) !important;
}

.sample-btn {
    background: linear-gradient(45deg, #f093fb, #f5576c) !important;
    margin: 0.25rem !important;
    font-size: 0.9rem !important;
    padding: 8px 16px !important;
}

.sample-btn:hover {
    background: linear-gradient(45deg, #f5576c, #f093fb) !important;
}

/* Results area styling */
.results-container {
    background: linear-gradient(135deg, #a8edea 0%, #fed6e3 100%);
    border-radius: 16px;
    padding: 1.5rem;
    margin-top: 1rem;
}

/* Status indicators */
.status-success {
    color: #10b981 !important;
    font-weight: 600 !important;
}

.status-error {
    color: #ef4444 !important;
    font-weight: 600 !important;
}

.status-warning {
    color: #f59e0b !important;
    font-weight: 600 !important;
}

/* Schema box */
.schema-box {
    background: linear-gradient(135deg, #ffecd2 0%, #fcb69f 100%);
    border-radius: 12px;
    padding: 1rem;
    font-family: 'Monaco', 'Consolas', monospace;
    border-left: 4px solid #f59e0b;
}

/* Animation keyframes */
@keyframes fadeInUp {
    from {
        opacity: 0;
        transform: translateY(30px);
    }
    to {
        opacity: 1;
        transform: translateY(0);
    }
}

.fade-in {
    animation: fadeInUp 0.6s ease-out;
}

/* Responsive design */
@media (max-width: 768px) {
    .gradio-container {
        padding: 1rem;
    }

    .card {
        padding: 1rem;
        margin: 0.5rem 0;
    }
}

/* Loading spinner */
.loading {
    display: inline-block;
    width: 20px;
    height: 20px;
    border: 3px solid rgba(255,255,255,.3);
    border-radius: 50%;
    border-top-color: #fff;
    animation: spin 1s ease-in-out infinite;
}

@keyframes spin {
    to { transform: rotate(360deg); }
}
"""
|
|
| |
| |
| |
|
|
# Gradio user interface: layout, event handlers, and event wiring.
with gr.Blocks(css=custom_css, title="AI-Powered NL2SQL Converter", theme=gr.themes.Glass()) as iface:
    # Status shown on first load and after "Clear"; shared so both stay identical.
    _READY_HTML = "<div style='padding: 1rem; text-align: center; font-size: 1.1rem;'>Ready to process your query! 🚀</div>"

    # --- Header banner -------------------------------------------------------
    with gr.Row(elem_classes="header-container fade-in"):
        gr.HTML("""
        <div style="text-align: center; color: white;">
            <h1 style="font-size: 3rem; margin-bottom: 0.5rem; background: linear-gradient(45deg, #fff, #f0f0f0); -webkit-background-clip: text; -webkit-text-fill-color: transparent;">
                🚀 AI-Powered NL2SQL Converter
            </h1>
            <p style="font-size: 1.2rem; opacity: 0.9; margin-bottom: 1rem;">
                Transform natural language into powerful SQL queries using Groq's advanced AI
            </p>
            <div style="display: flex; justify-content: center; gap: 2rem; margin-top: 1rem;">
                <div style="text-align: center;">
                    <div style="font-size: 2rem;">🤖</div>
                    <div style="font-size: 0.9rem; opacity: 0.8;">AI-Powered</div>
                </div>
                <div style="text-align: center;">
                    <div style="font-size: 2rem;">⚡</div>
                    <div style="font-size: 0.9rem; opacity: 0.8;">Lightning Fast</div>
                </div>
                <div style="text-align: center;">
                    <div style="font-size: 2rem;">🎯</div>
                    <div style="font-size: 0.9rem; opacity: 0.8;">Precise Results</div>
                </div>
            </div>
        </div>
        """)

    # --- Static schema reference (same columns as the SQLEvaluator table) ---
    with gr.Row(elem_classes="card fade-in"):
        gr.HTML("""
        <div class="schema-box">
            <h3 style="color: #d97706; margin-bottom: 1rem;">📋 Database Schema</h3>
            <div style="background: rgba(255,255,255,0.7); padding: 1rem; border-radius: 8px;">
                <strong>employees</strong> table:<br>
                • <code>id</code> (INTEGER) - Primary Key<br>
                • <code>name</code> (TEXT) - Employee Name<br>
                • <code>department</code> (TEXT) - Department<br>
                • <code>salary</code> (REAL) - Salary Amount<br>
                • <code>hire_date</code> (TEXT) - Hiring Date<br>
                • <code>manager_id</code> (INTEGER) - Manager Reference
            </div>
        </div>
        """)

    # --- Question input, actions and example prompts -------------------------
    with gr.Row(elem_classes="card fade-in"):
        with gr.Column(scale=3):
            nl_input = gr.Textbox(
                label="💬 Ask your question in plain English",
                placeholder="e.g., Show me all engineers earning more than $75,000",
                lines=3,
                elem_classes="main-input",
            )

            with gr.Row():
                submit_btn = gr.Button(
                    "🔮 Generate & Execute SQL",
                    variant="primary",
                    size="lg",
                    elem_classes="main-button",
                )
                clear_btn = gr.Button(
                    "🗑️ Clear",
                    variant="secondary",
                    size="lg",
                )

        with gr.Column(scale=2):
            gr.HTML("<h3 style='color: #667eea; margin-bottom: 1rem;'>🎯 Try These Examples</h3>")

            for sample in get_sample_queries():
                sample_btn = gr.Button(
                    f"💡 {sample}",
                    variant="secondary",
                    size="sm",
                    elem_classes="sample-btn",
                )
                # Bind the current sample as a default argument; a plain closure
                # would late-bind and make every button insert the last query.
                sample_btn.click(
                    lambda q=sample: q,
                    outputs=nl_input,
                )

    # --- Generated SQL and status ----------------------------------------------
    with gr.Row(elem_classes="results-container fade-in"):
        with gr.Column():
            gr.HTML("<h3 style='color: #6366f1; margin-bottom: 1rem;'>📝 Generated SQL Query</h3>")
            sql_output = gr.Code(
                label="",
                language="sql",
                lines=4,
                interactive=False,
                elem_classes="sql-output",
            )

            status_output = gr.HTML(_READY_HTML)

    # --- Query results ---------------------------------------------------------
    with gr.Row(elem_classes="card fade-in"):
        gr.HTML("<h3 style='color: #059669; margin-bottom: 1rem;'>📊 Query Results</h3>")
        results_output = gr.Code(
            label="",
            lines=12,
            interactive=False,
            elem_classes="results-output",
        )

    # --- About / tips ----------------------------------------------------------
    with gr.Row(elem_classes="card fade-in"):
        gr.HTML("""
        <div style="text-align: center; padding: 1rem;">
            <h3 style="color: #667eea; margin-bottom: 1rem;">📖 About This Application</h3>
            <div style="display: grid; grid-template-columns: repeat(auto-fit, minmax(250px, 1fr)); gap: 1rem; margin-top: 1rem;">
                <div style="background: linear-gradient(135deg, #667eea, #764ba2); color: white; padding: 1rem; border-radius: 12px;">
                    <h4>🤖 AI Model</h4>
                    <p>Powered by Groq's Llama 3.3 70B for intelligent SQL generation</p>
                </div>
                <div style="background: linear-gradient(135deg, #f093fb, #f5576c); color: white; padding: 1rem; border-radius: 12px;">
                    <h4>💾 Database</h4>
                    <p>SQLite with sample employee data for testing and learning</p>
                </div>
                <div style="background: linear-gradient(135deg, #a8edea, #fed6e3); color: #374151; padding: 1rem; border-radius: 12px;">
                    <h4>✨ Features</h4>
                    <p>Natural language processing, SQL execution, and formatted results</p>
                </div>
            </div>
            <div style="margin-top: 2rem; padding: 1rem; background: rgba(103, 126, 234, 0.1); border-radius: 12px;">
                <h4 style="color: #667eea;">💡 Pro Tips for Better Results</h4>
                <ul style="text-align: left; display: inline-block; color: #4b5563;">
                    <li>Be specific and clear in your questions</li>
                    <li>Use column names mentioned in the schema</li>
                    <li>Try the sample queries to understand the format</li>
                    <li>Use natural language - no need for technical jargon</li>
                </ul>
            </div>
        </div>
        """)

    def enhanced_process(query):
        """Gradio handler returning (sql, status_html, results) for the three outputs.

        The status text is colour-coded by keyword. It is HTML-escaped because
        it can embed SQLite/API error messages that echo user-controlled text.
        """
        if not query or not query.strip():
            return "", "<div class='status-warning'>⚠️ Please enter a question first!</div>", ""

        try:
            sql, results, status = process_nl_query(query)
        except Exception as e:
            return "", f"<div class='status-error'>❌ Unexpected error: {html.escape(str(e))}</div>", ""

        lowered = status.lower()
        if "successfully" in lowered:
            css_class = "status-success"
        elif "error" in lowered or "failed" in lowered:
            css_class = "status-error"
        else:
            css_class = "status-warning"
        return sql, f"<div class='{css_class}'>{html.escape(status)}</div>", results

    def clear_all():
        """Reset the input, SQL, status and results widgets."""
        return "", "", _READY_HTML, ""

    submit_btn.click(
        fn=enhanced_process,
        inputs=[nl_input],
        outputs=[sql_output, status_output, results_output],
    )

    nl_input.submit(
        fn=enhanced_process,
        inputs=[nl_input],
        outputs=[sql_output, status_output, results_output],
    )

    clear_btn.click(
        fn=clear_all,
        outputs=[nl_input, sql_output, status_output, results_output],
    )
|
|
| |
if __name__ == "__main__":
    print("🚀 Launching Enhanced NL2SQL Application...")
    # NOTE(review): share=True publishes a public tunnel to an app whose
    # evaluator executes and commits arbitrary generated SQL; confirm intended.
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
        show_error=True,
    )