Spaces:
Sleeping
Sleeping
| """ | |
| Multi-Turn Text-to-SQL Agent with Clarification Capabilities | |
| ============================================================= | |
| An intelligent SQL assistant that: | |
| - Answers clear database questions with accurate SQL | |
| - Detects ambiguous questions and asks targeted clarifications | |
| - Explains when questions can't be answered with available data | |
| - Self-corrects SQL errors via ReAct reasoning loop | |
| - Maintains multi-turn conversation context | |
| Architecture based on: | |
| - MMSQL (arXiv:2412.17867) — 4-type question classification | |
| - PRACTIQ (arXiv:2410.11076) — clarification dialogue patterns | |
| - SQLFixAgent (arXiv:2406.13408) — self-correcting SQL generation | |
| Built with smolagents CodeAgent + Gradio UI. | |
| """ | |
| import os | |
| import sqlite3 | |
| from textwrap import dedent | |
| from smolagents import ( | |
| tool, | |
| CodeAgent, | |
| InferenceClientModel, | |
| GradioUI, | |
| ) | |
| # ───────────────────────────────────────────── | |
| # 1. Database Setup — Sample multi-table DB | |
| # ───────────────────────────────────────────── | |
# Default on-disk location of the demo SQLite database used by the agent tools.
DB_PATH = "demo_company.db"


def create_demo_database(db_path: str = DB_PATH):
    """Create (or recreate) the demo company database and return its path.

    Any existing copies of the six demo tables are dropped and rebuilt from
    the fixed seed data below, so repeated calls always leave the database in
    the same state. The connection is closed even if a statement fails.

    Args:
        db_path: Filesystem path of the SQLite database file to (re)build.

    Returns:
        The same ``db_path`` that was passed in.

    Raises:
        sqlite3.Error: If the database cannot be opened or a statement fails.
    """
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Tables that reference others are dropped before the tables they
        # reference. The names are fixed literals, so the f-string is safe.
        for table in ("order_items", "orders", "products", "customers", "employees", "departments"):
            cursor.execute(f"DROP TABLE IF EXISTS {table}")

        cursor.execute("""
            CREATE TABLE departments (
                dept_id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                location TEXT,
                budget REAL
            )
        """)
        cursor.executemany("INSERT INTO departments VALUES (?, ?, ?, ?)", [
            (1, "Engineering", "San Francisco", 2500000.00),
            (2, "Sales", "New York", 1800000.00),
            (3, "Marketing", "New York", 1200000.00),
            (4, "HR", "Chicago", 800000.00),
            (5, "Finance", "Chicago", 950000.00),
        ])

        cursor.execute("""
            CREATE TABLE employees (
                emp_id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                email TEXT,
                dept_id INTEGER REFERENCES departments(dept_id),
                salary REAL,
                hire_date TEXT,
                manager_id INTEGER REFERENCES employees(emp_id),
                status TEXT DEFAULT 'active'
            )
        """)
        cursor.executemany("INSERT INTO employees VALUES (?, ?, ?, ?, ?, ?, ?, ?)", [
            (1, "Alice Chen", "alice@company.com", 1, 145000, "2019-03-15", None, "active"),
            (2, "Bob Martinez", "bob@company.com", 1, 128000, "2020-06-01", 1, "active"),
            (3, "Carol Smith", "carol@company.com", 2, 95000, "2021-01-10", None, "active"),
            (4, "David Lee", "david@company.com", 2, 88000, "2021-08-20", 3, "active"),
            (5, "Eva Johnson", "eva@company.com", 3, 102000, "2020-11-05", None, "active"),
            (6, "Frank Wilson", "frank@company.com", 1, 135000, "2019-07-22", 1, "active"),
            (7, "Grace Kim", "grace@company.com", 4, 78000, "2022-02-14", None, "active"),
            (8, "Henry Brown", "henry@company.com", 5, 115000, "2020-04-30", None, "active"),
            (9, "Iris Davis", "iris@company.com", 2, 92000, "2022-09-01", 3, "active"),
            (10, "Jack Taylor", "jack@company.com", 1, 140000, "2019-11-18", 1, "inactive"),
            (11, "Karen White", "karen@company.com", 3, 98000, "2021-05-12", 5, "active"),
            (12, "Leo Garcia", "leo@company.com", 5, 105000, "2021-03-28", 8, "active"),
        ])

        cursor.execute("""
            CREATE TABLE customers (
                customer_id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                email TEXT,
                city TEXT,
                state TEXT,
                signup_date TEXT,
                tier TEXT DEFAULT 'standard'
            )
        """)
        cursor.executemany("INSERT INTO customers VALUES (?, ?, ?, ?, ?, ?, ?)", [
            (1, "Acme Corp", "contact@acme.com", "San Francisco", "CA", "2020-01-15", "premium"),
            (2, "Beta Industries", "info@beta.com", "New York", "NY", "2020-03-22", "standard"),
            (3, "Gamma Solutions", "hello@gamma.com", "Chicago", "IL", "2020-06-10", "premium"),
            (4, "Delta Systems", "sales@delta.com", "Austin", "TX", "2021-02-05", "enterprise"),
            (5, "Epsilon LLC", "team@epsilon.com", "Seattle", "WA", "2021-08-18", "standard"),
            (6, "Zeta Partners", "info@zeta.com", "Boston", "MA", "2022-01-30", "premium"),
            (7, "Eta Global", "contact@eta.com", "Denver", "CO", "2022-07-14", "standard"),
            (8, "Theta Inc", "hello@theta.com", "Portland", "OR", "2023-03-01", "enterprise"),
        ])

        cursor.execute("""
            CREATE TABLE products (
                product_id INTEGER PRIMARY KEY,
                name TEXT NOT NULL,
                category TEXT,
                price REAL,
                cost REAL,
                stock_quantity INTEGER,
                status TEXT DEFAULT 'active'
            )
        """)
        # Software rows carry NULL stock_quantity; hardware rows carry a count.
        cursor.executemany("INSERT INTO products VALUES (?, ?, ?, ?, ?, ?, ?)", [
            (1, "Widget Pro", "Hardware", 299.99, 150.00, 500, "active"),
            (2, "Widget Basic", "Hardware", 149.99, 75.00, 1200, "active"),
            (3, "DataSync Cloud", "Software", 49.99, 10.00, None, "active"),
            (4, "DataSync Enterprise", "Software", 199.99, 40.00, None, "active"),
            (5, "SecureVault", "Software", 89.99, 20.00, None, "active"),
            (6, "PowerAdapter X", "Hardware", 39.99, 18.00, 3000, "active"),
            (7, "Legacy Suite", "Software", 299.99, 60.00, None, "discontinued"),
            (8, "SmartHub", "Hardware", 449.99, 220.00, 200, "active"),
        ])

        cursor.execute("""
            CREATE TABLE orders (
                order_id INTEGER PRIMARY KEY,
                customer_id INTEGER REFERENCES customers(customer_id),
                employee_id INTEGER REFERENCES employees(emp_id),
                order_date TEXT,
                status TEXT,
                total_amount REAL
            )
        """)
        cursor.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)", [
            (1001, 1, 3, "2024-01-15", "completed", 1499.95),
            (1002, 2, 4, "2024-01-22", "completed", 599.96),
            (1003, 3, 3, "2024-02-10", "completed", 899.97),
            (1004, 1, 9, "2024-02-28", "shipped", 449.99),
            (1005, 4, 4, "2024-03-05", "completed", 2499.90),
            (1006, 5, 3, "2024-03-18", "pending", 149.99),
            (1007, 6, 9, "2024-04-02", "completed", 749.97),
            (1008, 3, 3, "2024-04-15", "completed", 339.98),
            (1009, 7, 4, "2024-05-01", "cancelled", 299.99),
            (1010, 8, 9, "2024-05-20", "shipped", 1349.97),
            (1011, 1, 3, "2024-06-01", "completed", 199.98),
            (1012, 4, 4, "2024-06-15", "completed", 3599.88),
        ])

        cursor.execute("""
            CREATE TABLE order_items (
                item_id INTEGER PRIMARY KEY,
                order_id INTEGER REFERENCES orders(order_id),
                product_id INTEGER REFERENCES products(product_id),
                quantity INTEGER,
                unit_price REAL,
                discount REAL DEFAULT 0.0
            )
        """)
        cursor.executemany("INSERT INTO order_items VALUES (?, ?, ?, ?, ?, ?)", [
            (1, 1001, 1, 5, 299.99, 0.0),
            (2, 1002, 2, 4, 149.99, 0.0),
            (3, 1003, 3, 6, 49.99, 0.0),
            (4, 1003, 5, 3, 89.99, 10.0),
            (5, 1004, 8, 1, 449.99, 0.0),
            (6, 1005, 1, 5, 299.99, 0.0),
            (7, 1005, 4, 5, 199.99, 0.0),
            (8, 1006, 2, 1, 149.99, 0.0),
            (9, 1007, 5, 3, 89.99, 0.0),
            (10, 1007, 3, 9, 49.99, 10.0),
            (11, 1008, 6, 5, 39.99, 0.0),
            (12, 1008, 3, 3, 49.99, 10.0),
            (13, 1009, 1, 1, 299.99, 0.0),
            (14, 1010, 8, 3, 449.99, 0.0),
            (15, 1011, 3, 4, 49.99, 0.0),
            (16, 1012, 4, 12, 199.99, 15.0),
            (17, 1012, 8, 2, 449.99, 10.0),
        ])

        conn.commit()
    finally:
        # Release the file handle whether or not seeding succeeded.
        conn.close()
    return db_path
| # ───────────────────────────────────────────── | |
| # 2. Build Dynamic Schema Description | |
| # ───────────────────────────────────────────── | |
def get_schema_description(db_path: str = DB_PATH) -> str:
    """Build a plain-text description of every table in a SQLite database.

    For each table listed in ``sqlite_master`` the description contains the
    row count, one line per column (declared type, PRIMARY KEY / NOT NULL /
    DEFAULT markers and any foreign-key target), and up to eight distinct
    non-NULL sample values for each text-like column other than ``email``.

    Args:
        db_path: Filesystem path of the SQLite database to describe.

    Returns:
        The per-table descriptions joined by newlines; an empty string when
        the database has no tables.

    Raises:
        sqlite3.Error: If the database cannot be opened or introspected.
            Failures while fetching sample values are skipped instead.
    """
    def quote(identifier: str) -> str:
        # Double-quote an identifier so unusual table/column names stay valid SQL.
        return '"' + identifier.replace('"', '""') + '"'

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        tables = [row[0] for row in cursor.fetchall()]

        schema_parts = []
        for table in tables:
            quoted_table = quote(table)

            cursor.execute(f"PRAGMA table_info({quoted_table})")
            columns = cursor.fetchall()

            # foreign_key_list rows: index 2 = referenced table,
            # 3 = local column, 4 = referenced column.
            cursor.execute(f"PRAGMA foreign_key_list({quoted_table})")
            fk_map = {fk[3]: f"→ {fk[2]}({fk[4]})" for fk in cursor.fetchall()}

            cursor.execute(f"SELECT COUNT(*) FROM {quoted_table}")
            row_count = cursor.fetchone()[0]

            lines = [f"Table '{table}' ({row_count} rows):", " Columns:"]

            for _cid, col_name, col_type, not_null, default, pk in columns:
                parts = [f" - {col_name}: {col_type or 'TEXT'}"]
                if pk:
                    parts.append("PRIMARY KEY")
                if not_null and not pk:
                    parts.append("NOT NULL")
                if default is not None:
                    parts.append(f"DEFAULT {default}")
                if col_name in fk_map:
                    parts.append(f"FK {fk_map[col_name]}")
                lines.append(" ".join(parts))

            for _cid, col_name, col_type, *_rest in columns:
                # Untyped columns are displayed as TEXT above, so they are
                # sampled the same way; emails are deliberately not sampled.
                if (col_type or "TEXT").upper() != "TEXT" or col_name == "email":
                    continue
                quoted_col = quote(col_name)
                try:
                    cursor.execute(
                        f"SELECT DISTINCT {quoted_col} FROM {quoted_table} "
                        f"WHERE {quoted_col} IS NOT NULL LIMIT 8"
                    )
                except sqlite3.Error:
                    # Sampling is best-effort; the structural description still stands.
                    continue
                vals = [str(r[0]) for r in cursor.fetchall()]
                if vals:
                    lines.append(f" Sample '{col_name}' values: {', '.join(vals)}")

            schema_parts.append("\n".join(lines) + "\n")
    finally:
        conn.close()
    return "\n".join(schema_parts)
| # ───────────────────────────────────────────── | |
| # 3. Define Agent Tools | |
| # ───────────────────────────────────────────── | |
# Cached schema text; populated by create_agent() before the tools are used.
SCHEMA_DESCRIPTION = ""


# NOTE: the docstring below doubles as the tool description shown to the model.
# create_agent() replaces its "{schema}" placeholder with the live schema text,
# so keep that placeholder and the "Args:" section intact.
@tool
def execute_sql(query: str) -> str:
    """
    Executes a SQL query against the company database and returns the results.
    Use this tool to run SELECT queries to answer user questions about the data.
    IMPORTANT RULES:
    - Only use SELECT statements (no INSERT, UPDATE, DELETE, DROP)
    - Always use table and column names exactly as shown in the schema
    - Use JOINs when data spans multiple tables
    - Use LIMIT to avoid overwhelming output (max 50 rows)
    DATABASE SCHEMA:
    {schema}
    Args:
        query: A valid SQL SELECT query to execute against the database.
    """
    cleaned = query.strip().upper()
    if not cleaned.startswith(("SELECT", "WITH")):
        return "ERROR: Only SELECT queries are allowed."

    # Every failure is reported back as text so the agent can read the error
    # and retry with a corrected query instead of crashing the run.
    try:
        conn = sqlite3.connect(DB_PATH)
        try:
            # The prefix check alone can be bypassed (e.g. "WITH ... DELETE"),
            # so the connection itself is made read-only as well.
            conn.execute("PRAGMA query_only = ON")
            cursor = conn.execute(query)
            columns = [desc[0] for desc in cursor.description] if cursor.description else []
            rows = cursor.fetchall()
        finally:
            conn.close()
    except Exception as e:
        return f"SQL ERROR: {str(e)}\n\nPlease check your query syntax and column/table names against the schema."

    if not rows:
        return f"Query executed successfully.\nColumns: {', '.join(columns)}\nResult: No rows returned."

    lines = [
        f"Query executed successfully. {len(rows)} row(s) returned.",
        "Columns: " + " | ".join(columns),
        "-" * 60,
    ]
    # Only the first 50 rows are shown to keep the agent's context small.
    lines.extend(" | ".join(str(v) for v in row) for row in rows[:50])
    if len(rows) > 50:
        lines.append(f"... ({len(rows) - 50} more rows truncated)")
    return "\n".join(lines) + "\n"
# NOTE: the docstring below doubles as the tool description shown to the model;
# keep its "Args:" section intact.
@tool
def inspect_schema(table_name: str = "") -> str:
    """
    Inspect the database schema. If a table_name is provided, shows detailed info
    about that specific table including column types, foreign keys, and sample data.
    If no table_name is given, shows an overview of all tables.
    Use this tool BEFORE writing SQL to understand the database structure,
    especially when the user's question is ambiguous about which tables or columns to use.
    Args:
        table_name: Name of a specific table to inspect. Leave empty for full schema overview.
    """
    # The overview needs no database access, so answer before opening a connection.
    if not table_name:
        return f"DATABASE SCHEMA OVERVIEW:\n\n{SCHEMA_DESCRIPTION}"

    # table_name is supplied by the model, so it is quoted as an identifier
    # rather than spliced into the SQL text verbatim.
    quoted = '"' + table_name.replace('"', '""') + '"'

    conn = sqlite3.connect(DB_PATH)
    try:
        cursor = conn.cursor()
        cursor.execute(f"PRAGMA table_info({quoted})")
        columns = cursor.fetchall()
        if not columns:
            return f"Table '{table_name}' not found. Use inspect_schema() with no arguments to see all tables."

        out = [f"DETAILED INSPECTION OF TABLE '{table_name}':\n\n", "Columns:\n"]
        for col in columns:
            # table_info rows: 1 = name, 2 = declared type, 3 = notnull flag, 5 = pk flag.
            entry = f" {col[1]} ({col[2] or 'TEXT'})"
            if col[5]:
                entry += " [PRIMARY KEY]"
            if col[3]:
                entry += " [NOT NULL]"
            out.append(entry + "\n")

        cursor.execute(f"PRAGMA foreign_key_list({quoted})")
        fks = cursor.fetchall()
        if fks:
            out.append("\nForeign Keys:\n")
            # foreign_key_list rows: 2 = referenced table, 3 = local column, 4 = referenced column.
            out.extend(f" {fk[3]} → {fk[2]}({fk[4]})\n" for fk in fks)

        cursor.execute(f"SELECT COUNT(*) FROM {quoted}")
        count = cursor.fetchone()[0]
        out.append(f"\nTotal rows: {count}\n")

        cursor.execute(f"SELECT * FROM {quoted} LIMIT 3")
        sample_rows = cursor.fetchall()
        out.append("\nSample rows (first 3):\n")
        out.append(" | ".join(c[1] for c in columns) + "\n")
        out.append("-" * 60 + "\n")
        out.extend(" | ".join(str(v) for v in row) + "\n" for row in sample_rows)
        return "".join(out)
    except Exception as e:
        # Reported as text so the agent can recover instead of aborting the run.
        return f"Error inspecting table: {str(e)}"
    finally:
        conn.close()
| # ───────────────────────────────────────────── | |
| # 4. Agent System Prompt | |
| # ───────────────────────────────────────────── | |
# Instruction text handed to the agent (see the `instructions=` argument in
# create_agent). It tells the model to classify each question as answerable,
# ambiguous or unanswerable before acting, and how to respond in each case.
# The "\\n" sequences inside the example are intentional: they must reach the
# model as a literal backslash-n inside the sample final_answer() call.
SYSTEM_INSTRUCTIONS = dedent("""\
You are an expert SQL assistant that helps users query a company database. You follow a structured multi-turn approach:
## YOUR DECISION PROCESS
For EVERY user question, follow these steps:
### Step 1: Classify the Question
Determine if the question is:
- **ANSWERABLE**: The question is clear and maps directly to the database schema
- **AMBIGUOUS**: The question could have multiple valid SQL interpretations (e.g., "show me the top employees" → top by salary? by sales? by tenure?)
- **UNANSWERABLE**: The question asks for data that doesn't exist in the database
### Step 2: Handle Based on Classification
**If AMBIGUOUS:**
- Identify ALL possible interpretations
- Use `final_answer()` to return a targeted clarification question listing the specific options
- Example: call `final_answer("Your question could mean several things:\\n1. Employees with highest salary\\n2. Employees who handled the most orders\\n3. Employees with the longest tenure\\n\\nWhich interpretation do you mean?")`
- Do NOT generate SQL — return the clarification question immediately using `final_answer()`
- The user will respond in the next turn with their clarification
**If UNANSWERABLE:**
- Use `final_answer()` to explain clearly what data is missing and why the question can't be answered
- Include a suggestion for a related question that CAN be answered with the available data
**If ANSWERABLE:**
- First inspect the schema to confirm the right tables/columns
- Generate and execute the SQL query
- Present results clearly with a natural language summary
### Step 3: Self-Correct
- If your SQL returns an error, analyze the error and fix the query
- If the result seems wrong or empty, verify your joins and filters
- Always sanity-check: does the result make sense given what was asked?
## COMMON AMBIGUITY PATTERNS TO WATCH FOR
1. **Column ambiguity**: "Show employee names" → the 'name' column appears in employees, departments, customers, and products tables
2. **Metric ambiguity**: "Top customers" → by total spending? by number of orders? by most recent activity?
3. **Filter ambiguity**: "Recent orders" → last week? last month? last quarter?
4. **Scope ambiguity**: "Total sales" → all time? this year? by product? by employee?
5. **Status ambiguity**: "List products" → all products? only active ones? including discontinued?
6. **Value ambiguity**: "Expensive products" → what price threshold?
## FORMATTING RULES
- When presenting query results, format them as a clear table
- Always explain what the query does in plain language
- If you make assumptions (e.g., "I'm assuming you mean active employees only"), state them explicitly
- For numerical results, include relevant aggregations (count, sum, average) when helpful
""")
| # ───────────────────────────────────────────── | |
| # 5. Main | |
| # ───────────────────────────────────────────── | |
def create_agent(model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct"):
    """Rebuild the demo database and return a CodeAgent wired to the SQL tools.

    Side effects: recreates the demo database at the default path, stores the
    freshly generated schema text in the module-level SCHEMA_DESCRIPTION, and
    substitutes it for the "{schema}" placeholder in execute_sql's description.

    Args:
        model_id: Identifier of the model passed to InferenceClientModel.

    Returns:
        The configured CodeAgent instance.
    """
    global SCHEMA_DESCRIPTION

    create_demo_database()
    SCHEMA_DESCRIPTION = get_schema_description()

    # Inject the live schema into the tool description the model will see.
    described = execute_sql.description.replace("{schema}", SCHEMA_DESCRIPTION)
    execute_sql.description = described

    return CodeAgent(
        tools=[execute_sql, inspect_schema],
        model=InferenceClientModel(model_id=model_id),
        instructions=SYSTEM_INSTRUCTIONS,
        max_steps=15,
        additional_authorized_imports=["json", "re"],
    )
if __name__ == "__main__":
    # Script entry point: build the agent and serve it through the Gradio UI.
    # reset_agent_memory=False is passed through to GradioUI unchanged.
    GradioUI(create_agent(), reset_agent_memory=False).launch()