k1golestan committed on
Commit 9aa0f44 · verified · 1 Parent(s): 2e5ff2b

Add main app.py for text-to-sql agent

Files changed (1)
  1. app.py +438 -0
app.py ADDED
@@ -0,0 +1,438 @@
+ """
+ Multi-Turn Text-to-SQL Agent with Clarification Capabilities
+ =============================================================
+
+ An intelligent SQL assistant that:
+ - Answers clear database questions with accurate SQL
+ - Detects ambiguous questions and asks targeted clarifications
+ - Explains when questions can't be answered with available data
+ - Self-corrects SQL errors via ReAct reasoning loop
+ - Maintains multi-turn conversation context
+
+ Architecture based on:
+ - MMSQL (arXiv:2412.17867): 4-type question classification
+ - PRACTIQ (arXiv:2410.11076): clarification dialogue patterns
+ - SQLFixAgent (arXiv:2406.13408): self-correcting SQL generation
+
+ Built with smolagents CodeAgent + Gradio UI.
+ """
+
+ import os
+ import sqlite3
+ from textwrap import dedent
+
+ from smolagents import (
+     tool,
+     CodeAgent,
+     InferenceClientModel,
+     GradioUI,
+ )
+
+ # ─────────────────────────────────────────────
+ # 1. Database Setup: sample multi-table DB
+ # ─────────────────────────────────────────────
+
+ DB_PATH = "demo_company.db"
+
+
+ def create_demo_database(db_path: str = DB_PATH):
+     """Creates a rich demo company database with realistic data and some ambiguous schema elements."""
+     conn = sqlite3.connect(db_path)
+     cursor = conn.cursor()
+
+     for table in ["order_items", "orders", "products", "customers", "employees", "departments"]:
+         cursor.execute(f"DROP TABLE IF EXISTS {table}")
+
+     cursor.execute("""
+         CREATE TABLE departments (
+             dept_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             location TEXT,
+             budget REAL
+         )
+     """)
+     cursor.executemany("INSERT INTO departments VALUES (?, ?, ?, ?)", [
+         (1, "Engineering", "San Francisco", 2500000.00),
+         (2, "Sales", "New York", 1800000.00),
+         (3, "Marketing", "New York", 1200000.00),
+         (4, "HR", "Chicago", 800000.00),
+         (5, "Finance", "Chicago", 950000.00),
+     ])
+
+     cursor.execute("""
+         CREATE TABLE employees (
+             emp_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             email TEXT,
+             dept_id INTEGER REFERENCES departments(dept_id),
+             salary REAL,
+             hire_date TEXT,
+             manager_id INTEGER REFERENCES employees(emp_id),
+             status TEXT DEFAULT 'active'
+         )
+     """)
+     cursor.executemany("INSERT INTO employees VALUES (?, ?, ?, ?, ?, ?, ?, ?)", [
+         (1, "Alice Chen", "alice@company.com", 1, 145000, "2019-03-15", None, "active"),
+         (2, "Bob Martinez", "bob@company.com", 1, 128000, "2020-06-01", 1, "active"),
+         (3, "Carol Smith", "carol@company.com", 2, 95000, "2021-01-10", None, "active"),
+         (4, "David Lee", "david@company.com", 2, 88000, "2021-08-20", 3, "active"),
+         (5, "Eva Johnson", "eva@company.com", 3, 102000, "2020-11-05", None, "active"),
+         (6, "Frank Wilson", "frank@company.com", 1, 135000, "2019-07-22", 1, "active"),
+         (7, "Grace Kim", "grace@company.com", 4, 78000, "2022-02-14", None, "active"),
+         (8, "Henry Brown", "henry@company.com", 5, 115000, "2020-04-30", None, "active"),
+         (9, "Iris Davis", "iris@company.com", 2, 92000, "2022-09-01", 3, "active"),
+         (10, "Jack Taylor", "jack@company.com", 1, 140000, "2019-11-18", 1, "inactive"),
+         (11, "Karen White", "karen@company.com", 3, 98000, "2021-05-12", 5, "active"),
+         (12, "Leo Garcia", "leo@company.com", 5, 105000, "2021-03-28", 8, "active"),
+     ])
+
+     cursor.execute("""
+         CREATE TABLE customers (
+             customer_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             email TEXT,
+             city TEXT,
+             state TEXT,
+             signup_date TEXT,
+             tier TEXT DEFAULT 'standard'
+         )
+     """)
+     cursor.executemany("INSERT INTO customers VALUES (?, ?, ?, ?, ?, ?, ?)", [
+         (1, "Acme Corp", "contact@acme.com", "San Francisco", "CA", "2020-01-15", "premium"),
+         (2, "Beta Industries", "info@beta.com", "New York", "NY", "2020-03-22", "standard"),
+         (3, "Gamma Solutions", "hello@gamma.com", "Chicago", "IL", "2020-06-10", "premium"),
+         (4, "Delta Systems", "sales@delta.com", "Austin", "TX", "2021-02-05", "enterprise"),
+         (5, "Epsilon LLC", "team@epsilon.com", "Seattle", "WA", "2021-08-18", "standard"),
+         (6, "Zeta Partners", "info@zeta.com", "Boston", "MA", "2022-01-30", "premium"),
+         (7, "Eta Global", "contact@eta.com", "Denver", "CO", "2022-07-14", "standard"),
+         (8, "Theta Inc", "hello@theta.com", "Portland", "OR", "2023-03-01", "enterprise"),
+     ])
+
+     cursor.execute("""
+         CREATE TABLE products (
+             product_id INTEGER PRIMARY KEY,
+             name TEXT NOT NULL,
+             category TEXT,
+             price REAL,
+             cost REAL,
+             stock_quantity INTEGER,
+             status TEXT DEFAULT 'active'
+         )
+     """)
+     cursor.executemany("INSERT INTO products VALUES (?, ?, ?, ?, ?, ?, ?)", [
+         (1, "Widget Pro", "Hardware", 299.99, 150.00, 500, "active"),
+         (2, "Widget Basic", "Hardware", 149.99, 75.00, 1200, "active"),
+         (3, "DataSync Cloud", "Software", 49.99, 10.00, None, "active"),
+         (4, "DataSync Enterprise", "Software", 199.99, 40.00, None, "active"),
+         (5, "SecureVault", "Software", 89.99, 20.00, None, "active"),
+         (6, "PowerAdapter X", "Hardware", 39.99, 18.00, 3000, "active"),
+         (7, "Legacy Suite", "Software", 299.99, 60.00, None, "discontinued"),
+         (8, "SmartHub", "Hardware", 449.99, 220.00, 200, "active"),
+     ])
+
+     cursor.execute("""
+         CREATE TABLE orders (
+             order_id INTEGER PRIMARY KEY,
+             customer_id INTEGER REFERENCES customers(customer_id),
+             employee_id INTEGER REFERENCES employees(emp_id),
+             order_date TEXT,
+             status TEXT,
+             total_amount REAL
+         )
+     """)
+     cursor.executemany("INSERT INTO orders VALUES (?, ?, ?, ?, ?, ?)", [
+         (1001, 1, 3, "2024-01-15", "completed", 1499.95),
+         (1002, 2, 4, "2024-01-22", "completed", 599.96),
+         (1003, 3, 3, "2024-02-10", "completed", 899.97),
+         (1004, 1, 9, "2024-02-28", "shipped", 449.99),
+         (1005, 4, 4, "2024-03-05", "completed", 2499.90),
+         (1006, 5, 3, "2024-03-18", "pending", 149.99),
+         (1007, 6, 9, "2024-04-02", "completed", 749.97),
+         (1008, 3, 3, "2024-04-15", "completed", 339.98),
+         (1009, 7, 4, "2024-05-01", "cancelled", 299.99),
+         (1010, 8, 9, "2024-05-20", "shipped", 1349.97),
+         (1011, 1, 3, "2024-06-01", "completed", 199.98),
+         (1012, 4, 4, "2024-06-15", "completed", 3599.88),
+     ])
+
+     cursor.execute("""
+         CREATE TABLE order_items (
+             item_id INTEGER PRIMARY KEY,
+             order_id INTEGER REFERENCES orders(order_id),
+             product_id INTEGER REFERENCES products(product_id),
+             quantity INTEGER,
+             unit_price REAL,
+             discount REAL DEFAULT 0.0
+         )
+     """)
+     cursor.executemany("INSERT INTO order_items VALUES (?, ?, ?, ?, ?, ?)", [
+         (1, 1001, 1, 5, 299.99, 0.0),
+         (2, 1002, 2, 4, 149.99, 0.0),
+         (3, 1003, 3, 6, 49.99, 0.0),
+         (4, 1003, 5, 3, 89.99, 10.0),
+         (5, 1004, 8, 1, 449.99, 0.0),
+         (6, 1005, 1, 5, 299.99, 0.0),
+         (7, 1005, 4, 5, 199.99, 0.0),
+         (8, 1006, 2, 1, 149.99, 0.0),
+         (9, 1007, 5, 3, 89.99, 0.0),
+         (10, 1007, 3, 9, 49.99, 10.0),
+         (11, 1008, 6, 5, 39.99, 0.0),
+         (12, 1008, 3, 3, 49.99, 10.0),
+         (13, 1009, 1, 1, 299.99, 0.0),
+         (14, 1010, 8, 3, 449.99, 0.0),
+         (15, 1011, 3, 4, 49.99, 0.0),
+         (16, 1012, 4, 12, 199.99, 15.0),
+         (17, 1012, 8, 2, 449.99, 10.0),
+     ])
+
+     conn.commit()
+     conn.close()
+     return db_path
+
+
+ # ─────────────────────────────────────────────
+ # 2. Build Dynamic Schema Description
+ # ─────────────────────────────────────────────
+
+ def get_schema_description(db_path: str = DB_PATH) -> str:
+     conn = sqlite3.connect(db_path)
+     cursor = conn.cursor()
+     cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
+     tables = [row[0] for row in cursor.fetchall()]
+
+     schema_parts = []
+     for table in tables:
+         cursor.execute(f"PRAGMA table_info({table})")  # rows: (cid, name, type, notnull, dflt_value, pk)
+         columns = cursor.fetchall()
+         cursor.execute(f"PRAGMA foreign_key_list({table})")  # rows: (id, seq, table, from, to, ...)
+         fks = cursor.fetchall()
+         fk_map = {fk[3]: f"→ {fk[2]}({fk[4]})" for fk in fks}
+         cursor.execute(f"SELECT COUNT(*) FROM {table}")
+         row_count = cursor.fetchone()[0]
+
+         table_desc = f"Table '{table}' ({row_count} rows):\n Columns:\n"
+         for col in columns:
+             col_id, col_name, col_type, not_null, default, pk = col
+             parts = [f" - {col_name}: {col_type or 'TEXT'}"]
+             if pk: parts.append("PRIMARY KEY")
+             if not_null and not pk: parts.append("NOT NULL")
+             if default is not None: parts.append(f"DEFAULT {default}")
+             if col_name in fk_map: parts.append(f"FK {fk_map[col_name]}")
+             table_desc += " ".join(parts) + "\n"
+
+         for col in columns:
+             col_name, col_type = col[1], col[2]
+             if col_type in ("TEXT", None) and col_name not in ("email",):
+                 try:
+                     cursor.execute(f"SELECT DISTINCT {col_name} FROM {table} WHERE {col_name} IS NOT NULL LIMIT 8")
+                     vals = [str(r[0]) for r in cursor.fetchall()]
+                     if vals:
+                         table_desc += f" Sample '{col_name}' values: {', '.join(vals)}\n"
+                 except sqlite3.Error:  # skip sample values for this column only
+                     pass
+         schema_parts.append(table_desc)
+
+     conn.close()
+     return "\n".join(schema_parts)
+
+
+ # ─────────────────────────────────────────────
+ # 3. Define Agent Tools
+ # ─────────────────────────────────────────────
+
+ SCHEMA_DESCRIPTION = ""
+
+
+ @tool
+ def execute_sql(query: str) -> str:
+     """
+     Executes a SQL query against the company database and returns the results.
+     Use this tool to run SELECT queries to answer user questions about the data.
+
+     IMPORTANT RULES:
+     - Only use SELECT statements (no INSERT, UPDATE, DELETE, DROP)
+     - Always use table and column names exactly as shown in the schema
+     - Use JOINs when data spans multiple tables
+     - Use LIMIT to avoid overwhelming output (max 50 rows)
+
+     DATABASE SCHEMA:
+     {schema}
+
+     Args:
+         query: A valid SQL SELECT query to execute against the database.
+     """
+     cleaned = query.strip().upper()
+     if not cleaned.startswith("SELECT") and not cleaned.startswith("WITH"):
+         return "ERROR: Only SELECT queries are allowed."
+
+     conn = sqlite3.connect(DB_PATH)
+     try:
+         cursor = conn.cursor()
+         cursor.execute(query)
+         columns = [desc[0] for desc in cursor.description] if cursor.description else []
+         rows = cursor.fetchall()
+
+         if not rows:
+             return f"Query executed successfully.\nColumns: {', '.join(columns)}\nResult: No rows returned."
+
+         result = f"Query executed successfully. {len(rows)} row(s) returned.\n"
+         result += "Columns: " + " | ".join(columns) + "\n"
+         result += "-" * 60 + "\n"
+         for row in rows[:50]:
+             result += " | ".join(str(v) for v in row) + "\n"
+         if len(rows) > 50:
+             result += f"... ({len(rows) - 50} more rows truncated)\n"
+         return result
+     except Exception as e:
+         return f"SQL ERROR: {str(e)}\n\nPlease check your query syntax and column/table names against the schema."
+     finally:
+         conn.close()  # release the connection on both success and error paths
+
+
+ @tool
+ def inspect_schema(table_name: str = "") -> str:
+     """
+     Inspect the database schema. If a table_name is provided, shows detailed info
+     about that specific table including column types, foreign keys, and sample data.
+     If no table_name is given, shows an overview of all tables.
+
+     Use this tool BEFORE writing SQL to understand the database structure,
+     especially when the user's question is ambiguous about which tables or columns to use.
+
+     Args:
+         table_name: Name of a specific table to inspect. Leave empty for full schema overview.
+     """
+     if not table_name:
+         return f"DATABASE SCHEMA OVERVIEW:\n\n{SCHEMA_DESCRIPTION}"
+
+     conn = sqlite3.connect(DB_PATH)  # opened only after the early return, so it is never leaked
+     cursor = conn.cursor()
+
+     try:
+         cursor.execute(f"PRAGMA table_info({table_name})")
+         columns = cursor.fetchall()
+         if not columns:
+             conn.close()
+             return f"Table '{table_name}' not found. Use inspect_schema() with no arguments to see all tables."
+
+         result = f"DETAILED INSPECTION OF TABLE '{table_name}':\n\n"
+         result += "Columns:\n"
+         for col in columns:
+             result += f" {col[1]} ({col[2] or 'TEXT'})"
+             if col[5]: result += " [PRIMARY KEY]"
+             if col[3]: result += " [NOT NULL]"
+             result += "\n"
+
+         cursor.execute(f"PRAGMA foreign_key_list({table_name})")
+         fks = cursor.fetchall()
+         if fks:
+             result += "\nForeign Keys:\n"
+             for fk in fks:
+                 result += f" {fk[3]} → {fk[2]}({fk[4]})\n"
+
+         cursor.execute(f"SELECT COUNT(*) FROM {table_name}")
+         count = cursor.fetchone()[0]
+         result += f"\nTotal rows: {count}\n"
+
+         cursor.execute(f"SELECT * FROM {table_name} LIMIT 3")
+         sample_rows = cursor.fetchall()
+         col_names = [c[1] for c in columns]
+         result += "\nSample rows (first 3):\n"
+         result += " | ".join(col_names) + "\n"
+         result += "-" * 60 + "\n"
+         for row in sample_rows:
+             result += " | ".join(str(v) for v in row) + "\n"
+
+         conn.close()
+         return result
+
+     except Exception as e:
+         conn.close()
+         return f"Error inspecting table: {str(e)}"
+
+
+ # ─────────────────────────────────────────────
+ # 4. Agent System Prompt
+ # ─────────────────────────────────────────────
+
+ SYSTEM_INSTRUCTIONS = dedent("""\
+     You are an expert SQL assistant that helps users query a company database. You follow a structured multi-turn approach:
+
+     ## YOUR DECISION PROCESS
+
+     For EVERY user question, follow these steps:
+
+     ### Step 1: Classify the Question
+     Determine if the question is:
+     - **ANSWERABLE**: The question is clear and maps directly to the database schema
+     - **AMBIGUOUS**: The question could have multiple valid SQL interpretations (e.g., "show me the top employees" - top by salary? by sales? by tenure?)
+     - **UNANSWERABLE**: The question asks for data that doesn't exist in the database
+
+     ### Step 2: Handle Based on Classification
+
+     **If AMBIGUOUS:**
+     - Identify ALL possible interpretations
+     - Use `final_answer()` to return a targeted clarification question listing the specific options
+     - Example: call `final_answer("Your question could mean several things:\\n1. Employees with highest salary\\n2. Employees who handled the most orders\\n3. Employees with the longest tenure\\n\\nWhich interpretation do you mean?")`
+     - Do NOT generate SQL; return the clarification question immediately using `final_answer()`
+     - The user will respond in the next turn with their clarification
+
+     **If UNANSWERABLE:**
+     - Use `final_answer()` to explain clearly what data is missing and why the question can't be answered
+     - Include a suggestion for a related question that CAN be answered with the available data
+
+     **If ANSWERABLE:**
+     - First inspect the schema to confirm the right tables/columns
+     - Generate and execute the SQL query
+     - Present results clearly with a natural language summary
+
+     ### Step 3: Self-Correct
+     - If your SQL returns an error, analyze the error and fix the query
+     - If the result seems wrong or empty, verify your joins and filters
+     - Always sanity-check: does the result make sense given what was asked?
+
+     ## COMMON AMBIGUITY PATTERNS TO WATCH FOR
+
+     1. **Column ambiguity**: "Show employee names" - the 'name' column appears in employees, departments, customers, and products tables
+     2. **Metric ambiguity**: "Top customers" - by total spending? by number of orders? by most recent activity?
+     3. **Filter ambiguity**: "Recent orders" - last week? last month? last quarter?
+     4. **Scope ambiguity**: "Total sales" - all time? this year? by product? by employee?
+     5. **Status ambiguity**: "List products" - all products? only active ones? including discontinued?
+     6. **Value ambiguity**: "Expensive products" - what price threshold?
+
+     ## FORMATTING RULES
+
+     - When presenting query results, format them as a clear table
+     - Always explain what the query does in plain language
+     - If you make assumptions (e.g., "I'm assuming you mean active employees only"), state them explicitly
+     - For numerical results, include relevant aggregations (count, sum, average) when helpful
+ """)
+
+
+ # ─────────────────────────────────────────────
+ # 5. Main
+ # ─────────────────────────────────────────────
+
+ def create_agent(model_id: str = "Qwen/Qwen2.5-Coder-32B-Instruct"):
+     create_demo_database()
+
+     global SCHEMA_DESCRIPTION
+     SCHEMA_DESCRIPTION = get_schema_description()
+     execute_sql.description = execute_sql.description.replace("{schema}", SCHEMA_DESCRIPTION)
+
+     model = InferenceClientModel(model_id=model_id)
+
+     agent = CodeAgent(
+         tools=[execute_sql, inspect_schema],
+         model=model,
+         instructions=SYSTEM_INSTRUCTIONS,
+         max_steps=15,
+         additional_authorized_imports=["json", "re"],
+     )
+     return agent
+
+
+ if __name__ == "__main__":
+     agent = create_agent()
+     ui = GradioUI(agent, reset_agent_memory=False)
+     ui.launch()