ar9avg committed on
Commit
ed79e58
·
1 Parent(s): 4ec680a
backend/api/demo.py CHANGED
@@ -42,7 +42,7 @@ _DIFFICULTY_MAP = {
42
  "hard": "complex_queries",
43
  }
44
  from env.tasks import TASKS, get_task
45
- from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, _clean_sql
46
  from rl.environment import get_bandit_state
47
  from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
48
  from rl.error_classifier import classify_error, extract_offending_token
@@ -196,7 +196,7 @@ async def execute_query_stream(req: ExecuteQueryRequest):
196
  # Multi-turn conversation — grows with each failed attempt so the LLM
197
  # sees its own history and doesn't repeat the same mistake.
198
  conversation: list[dict] = [
199
- {"role": "system", "content": BASE_SYSTEM_PROMPT},
200
  {"role": "user", "content": initial_user_msg},
201
  ]
202
 
@@ -222,14 +222,14 @@ async def execute_query_stream(req: ExecuteQueryRequest):
222
  # Update system prompt with repair-specific guidance
223
  conversation[0] = {
224
  "role": "system",
225
- "content": BASE_SYSTEM_PROMPT + get_repair_system_suffix(repair_enum),
226
  }
227
  elif attempt > 1:
228
  repair_enum = RepairAction.REWRITE_FULL
229
  action = Action(repair_action="rewrite_full")
230
  conversation[0] = {
231
  "role": "system",
232
- "content": BASE_SYSTEM_PROMPT + get_repair_system_suffix(repair_enum),
233
  }
234
 
235
  # Stream SQL generation using the full conversation history
@@ -498,7 +498,7 @@ async def run_benchmark(req: BenchmarkRequest):
498
  ep = env._episode # type: ignore[union-attr]
499
 
500
  gepa = get_gepa()
501
- system_prompt = gepa.get_current_prompt()
502
  from env.sql_env import _make_client, _MODEL
503
 
504
  for attempt in range(1, max_attempts + 1):
 
42
  "hard": "complex_queries",
43
  }
44
  from env.tasks import TASKS, get_task
45
+ from env.sql_env import SQLAgentEnv, Action, get_env, BASE_SYSTEM_PROMPT, get_system_prompt, _clean_sql
46
  from rl.environment import get_bandit_state
47
  from rl.types import RepairAction, REPAIR_ACTION_NAMES, REPAIR_ACTION_BY_NAME
48
  from rl.error_classifier import classify_error, extract_offending_token
 
196
  # Multi-turn conversation — grows with each failed attempt so the LLM
197
  # sees its own history and doesn't repeat the same mistake.
198
  conversation: list[dict] = [
199
+ {"role": "system", "content": get_system_prompt()},
200
  {"role": "user", "content": initial_user_msg},
201
  ]
202
 
 
222
  # Update system prompt with repair-specific guidance
223
  conversation[0] = {
224
  "role": "system",
225
+ "content": get_system_prompt() + get_repair_system_suffix(repair_enum),
226
  }
227
  elif attempt > 1:
228
  repair_enum = RepairAction.REWRITE_FULL
229
  action = Action(repair_action="rewrite_full")
230
  conversation[0] = {
231
  "role": "system",
232
+ "content": get_system_prompt() + get_repair_system_suffix(repair_enum),
233
  }
234
 
235
  # Stream SQL generation using the full conversation history
 
498
  ep = env._episode # type: ignore[union-attr]
499
 
500
  gepa = get_gepa()
501
+ system_prompt = gepa.get_current_prompt() or get_system_prompt()
502
  from env.sql_env import _make_client, _MODEL
503
 
504
  for attempt in range(1, max_attempts + 1):
backend/env/database.py CHANGED
@@ -1,14 +1,8 @@
1
  """
2
- SQLite database setup and schema for the benchmark marketplace.
3
 
4
- Tables:
5
- sellers (id, name, email, country, rating)
6
- users (id, name, email, created_at, country)
7
- products (id, name, category, price, stock_quantity, seller_id)
8
- orders (id, user_id, product_id, quantity, total_price, status, created_at)
9
- reviews (id, user_id, product_id, rating, comment, created_at)
10
-
11
- ~50 rows per table of realistic seed data.
12
  """
13
 
14
  from __future__ import annotations
@@ -21,23 +15,59 @@ from typing import Any
21
  _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
22
  DB_PATH = _DATA_DIR / "benchmark.db"
23
 
24
- # Active DB path — can be overridden via connect_external_db()
25
- _active_db_path: str = str(DB_PATH)
26
  _active_db_label: str = "benchmark (built-in)"
 
 
 
 
 
27
 
28
 
29
- def connect_external_db(path: str) -> tuple[bool, str]:
30
- """Switch the active SQLite database. Returns (success, message)."""
31
- global _active_db_path, _active_db_label
32
  try:
33
- conn = sqlite3.connect(path)
34
- tables = conn.execute(
35
- "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
36
- ).fetchall()
37
- conn.close()
38
- _active_db_path = path
39
- _active_db_label = Path(path).name
40
- return True, f"Connected to {Path(path).name} ({len(tables)} tables)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  except Exception as e:
42
  return False, str(e)
43
 
@@ -46,8 +76,9 @@ def get_active_db_label() -> str:
46
  return _active_db_label
47
 
48
 
49
- def _get_db_path() -> str:
50
- return _active_db_path
 
51
 
52
 
53
  # ─── Schema ───────────────────────────────────────────────────────
@@ -333,10 +364,7 @@ def get_db_path() -> Path:
333
 
334
 
335
  def ensure_seeded() -> bool:
336
- """
337
- Create the database and populate seed data if not already done.
338
- Returns True if seed was needed (first run), False if already seeded.
339
- """
340
  _DATA_DIR.mkdir(parents=True, exist_ok=True)
341
  conn = sqlite3.connect(str(DB_PATH))
342
  try:
@@ -345,7 +373,7 @@ def ensure_seeded() -> bool:
345
 
346
  count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
347
  if count >= 50:
348
- return False # Already seeded
349
 
350
  conn.execute("DELETE FROM reviews")
351
  conn.execute("DELETE FROM orders")
@@ -353,43 +381,29 @@ def ensure_seeded() -> bool:
353
  conn.execute("DELETE FROM users")
354
  conn.execute("DELETE FROM sellers")
355
 
356
- conn.executemany(
357
- "INSERT OR REPLACE INTO sellers VALUES (?,?,?,?,?)", _SELLERS
358
- )
359
- conn.executemany(
360
- "INSERT OR REPLACE INTO users VALUES (?,?,?,?,?)", _USERS
361
- )
362
- conn.executemany(
363
- "INSERT OR REPLACE INTO products VALUES (?,?,?,?,?,?)", _PRODUCTS
364
- )
365
- conn.executemany(
366
- "INSERT OR REPLACE INTO orders VALUES (?,?,?,?,?,?,?)", _ORDERS
367
- )
368
- conn.executemany(
369
- "INSERT OR REPLACE INTO reviews VALUES (?,?,?,?,?,?)", _REVIEWS
370
- )
371
  conn.commit()
372
  return True
373
  finally:
374
  conn.close()
375
 
376
 
377
- def get_schema_info() -> str:
378
- """
379
- Return a concise textual schema summary for use in prompts.
380
- """
381
- conn = sqlite3.connect(_get_db_path())
382
  try:
383
  cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
384
- tables_in_db = [r[0] for r in cur.fetchall()]
385
- # Fall back to all tables if schema is unknown
386
- tables = tables_in_db if tables_in_db else ["sellers", "users", "products", "orders", "reviews"]
387
  lines = []
388
  for table in tables:
389
  info = conn.execute(f"PRAGMA table_info({table})").fetchall()
390
  cols = ", ".join(
391
- f"{col[1]} {col[2]}{'(PK)' if col[5] else ''}"
392
- for col in info
393
  )
394
  row_count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
395
  lines.append(f"Table: {table} ({row_count} rows)\n Columns: {cols}")
@@ -398,12 +412,59 @@ def get_schema_info() -> str:
398
  conn.close()
399
 
400
 
401
- def execute_query(sql: str) -> tuple[list[dict], str | None]:
402
- """
403
- Execute a SQL query and return (rows, error_message).
404
- rows is a list of dicts; error_message is None on success.
405
- """
406
- conn = sqlite3.connect(_get_db_path())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
  conn.row_factory = sqlite3.Row
408
  try:
409
  cursor = conn.execute(sql)
@@ -415,49 +476,156 @@ def execute_query(sql: str) -> tuple[list[dict], str | None]:
415
  conn.close()
416
 
417
 
418
- def get_table_stats() -> list[dict]:
419
- """Return [{name, rows}, ...] for all tables."""
420
- conn = sqlite3.connect(_get_db_path())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
421
  try:
422
  cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
423
  tables = [r[0] for r in cur.fetchall()] or ["sellers", "users", "products", "orders", "reviews"]
424
  return [
425
- {
426
- "name": t,
427
- "rows": conn.execute(f"SELECT COUNT(*) FROM {t}").fetchone()[0],
428
- }
429
  for t in tables
430
  ]
431
  finally:
432
  conn.close()
433
 
434
 
435
- def get_schema_graph() -> dict:
436
- """Return schema graph with tables, columns, and foreign keys."""
437
- conn = sqlite3.connect(str(DB_PATH))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
  try:
 
 
439
  tables = []
440
- for table in ["sellers", "users", "products", "orders", "reviews"]:
441
  info = conn.execute(f"PRAGMA table_info({table})").fetchall()
442
- columns = [
443
- {"name": col[1], "type": col[2], "pk": bool(col[5])}
444
- for col in info
445
- ]
446
  tables.append({"name": table, "columns": columns})
447
 
448
  foreign_keys = []
449
- for table in ["sellers", "users", "products", "orders", "reviews"]:
450
  fks = conn.execute(f"PRAGMA foreign_key_list({table})").fetchall()
451
  for fk in fks:
452
- foreign_keys.append(
453
- {
454
- "from_table": table,
455
- "from_col": fk[3],
456
- "to_table": fk[2],
457
- "to_col": fk[4],
458
- }
459
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
461
  return {"tables": tables, "foreign_keys": foreign_keys}
462
  finally:
463
  conn.close()
 
 
 
 
 
 
 
1
  """
2
+ Database abstraction supporting SQLite and PostgreSQL.
3
 
4
+ Active connection is set via connect_external_db(). Defaults to the built-in
5
+ SQLite benchmark database.
 
 
 
 
 
 
6
  """
7
 
8
  from __future__ import annotations
 
15
  _DATA_DIR = Path(os.environ.get("DATA_DIR", Path(__file__).parent.parent / "data"))
16
  DB_PATH = _DATA_DIR / "benchmark.db"
17
 
18
+ # ─── Active-connection state ──────────────────────────────────────
19
+ _active_dsn: str = str(DB_PATH) # SQLite path OR postgres DSN
20
  _active_db_label: str = "benchmark (built-in)"
21
+ _active_db_type: str = "sqlite" # "sqlite" | "postgres"
22
+
23
+
24
+ def _is_postgres(dsn: str) -> bool:
25
+ return dsn.startswith(("postgresql://", "postgres://"))
26
 
27
 
28
+ def _pg_label(dsn: str) -> str:
29
+ """Extract a short display label from a postgres DSN."""
 
30
  try:
31
+ # postgresql://user:pass@host:port/dbname → host/dbname
32
+ without_scheme = dsn.split("://", 1)[1]
33
+ at_split = without_scheme.rsplit("@", 1)
34
+ hostdb = at_split[-1] # host:port/dbname
35
+ parts = hostdb.split("/", 1)
36
+ host = parts[0].split(":")[0]
37
+ dbname = parts[1] if len(parts) > 1 else "?"
38
+ return f"{host}/{dbname}"
39
+ except Exception:
40
+ return "postgres"
41
+
42
+
43
def connect_external_db(dsn: str) -> tuple[bool, str]:
    """Switch the active database after checking that *dsn* can be opened.

    Args:
        dsn: A SQLite file path (or ``:memory:``), or a PostgreSQL URI that
            starts with ``postgresql://`` / ``postgres://``.

    Returns:
        ``(True, message)`` once a probe query listing the tables succeeded and
        the module-level connection state was updated; ``(False, error_text)``
        if connecting or probing raised, in which case the previously active
        connection state is left untouched.
    """
    global _active_dsn, _active_db_label, _active_db_type
    from contextlib import closing

    try:
        if _is_postgres(dsn):
            import psycopg2  # type: ignore[import]

            # closing() releases the probe connection even when the catalog
            # query fails; psycopg2's own ``with conn:`` only ends the
            # transaction and leaves the connection open.
            with closing(psycopg2.connect(dsn)) as conn:
                cur = conn.cursor()
                cur.execute(
                    "SELECT table_name FROM information_schema.tables "
                    "WHERE table_schema='public' AND table_type='BASE TABLE'"
                )
                tables = cur.fetchall()
            db_type = "postgres"
            label = _pg_label(dsn)
            message = f"Connected to PostgreSQL: {label} ({len(tables)} tables)"
        else:
            # NOTE(review): sqlite3.connect() creates the file when the path
            # does not exist yet, so a mistyped path "connects" to a new empty
            # database — confirm whether that is intended.
            with closing(sqlite3.connect(dsn)) as conn:
                tables = conn.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
                ).fetchall()
            db_type = "sqlite"
            label = Path(dsn).name if dsn != ":memory:" else "in-memory"
            message = f"Connected to {label} ({len(tables)} tables)"
    except Exception as e:
        # Boundary for a user-supplied DSN: report any failure as text instead
        # of raising.
        return False, str(e)

    # Commit the switch only after the probe succeeded.
    _active_dsn, _active_db_label, _active_db_type = dsn, label, db_type
    return True, message
73
 
 
76
  return _active_db_label
77
 
78
 
79
def get_active_db_type() -> str:
    """Report the backend of the active connection: 'sqlite' or 'postgres'."""
    return _active_db_type
82
 
83
 
84
  # ─── Schema ───────────────────────────────────────────────────────
 
364
 
365
 
366
  def ensure_seeded() -> bool:
367
+ """Create the database and populate seed data if not already done."""
 
 
 
368
  _DATA_DIR.mkdir(parents=True, exist_ok=True)
369
  conn = sqlite3.connect(str(DB_PATH))
370
  try:
 
373
 
374
  count = conn.execute("SELECT COUNT(*) FROM users").fetchone()[0]
375
  if count >= 50:
376
+ return False
377
 
378
  conn.execute("DELETE FROM reviews")
379
  conn.execute("DELETE FROM orders")
 
381
  conn.execute("DELETE FROM users")
382
  conn.execute("DELETE FROM sellers")
383
 
384
+ conn.executemany("INSERT OR REPLACE INTO sellers VALUES (?,?,?,?,?)", _SELLERS)
385
+ conn.executemany("INSERT OR REPLACE INTO users VALUES (?,?,?,?,?)", _USERS)
386
+ conn.executemany("INSERT OR REPLACE INTO products VALUES (?,?,?,?,?,?)", _PRODUCTS)
387
+ conn.executemany("INSERT OR REPLACE INTO orders VALUES (?,?,?,?,?,?,?)", _ORDERS)
388
+ conn.executemany("INSERT OR REPLACE INTO reviews VALUES (?,?,?,?,?,?)", _REVIEWS)
 
 
 
 
 
 
 
 
 
 
389
  conn.commit()
390
  return True
391
  finally:
392
  conn.close()
393
 
394
 
395
+ # ─── Schema info ──────────────────────────────────────────────────
396
+
397
+ def _schema_info_sqlite() -> str:
398
+ conn = sqlite3.connect(_active_dsn)
 
399
  try:
400
  cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
401
+ tables = [r[0] for r in cur.fetchall()] or ["sellers", "users", "products", "orders", "reviews"]
 
 
402
  lines = []
403
  for table in tables:
404
  info = conn.execute(f"PRAGMA table_info({table})").fetchall()
405
  cols = ", ".join(
406
+ f"{col[1]} {col[2]}{'(PK)' if col[5] else ''}" for col in info
 
407
  )
408
  row_count = conn.execute(f"SELECT COUNT(*) FROM {table}").fetchone()[0]
409
  lines.append(f"Table: {table} ({row_count} rows)\n Columns: {cols}")
 
412
  conn.close()
413
 
414
 
415
def _schema_info_postgres() -> str:
    """Build the textual schema summary for the active PostgreSQL database.

    Covers every base table in the ``public`` schema: its row count and its
    columns (name, data type, and a ``(PK)`` marker on primary-key columns).

    Returns:
        One ``Table: ... / Columns: ...`` paragraph per table, separated by
        blank lines; an empty string when the schema has no base tables.
    """
    import psycopg2  # type: ignore[import]
    from psycopg2 import sql as pg_sql  # type: ignore[import]

    conn = psycopg2.connect(_active_dsn)
    try:
        cur = conn.cursor()
        cur.execute(
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema='public' AND table_type='BASE TABLE' ORDER BY table_name"
        )
        tables = [r[0] for r in cur.fetchall()]

        # Primary-key columns per table. The table name is part of the join
        # because PostgreSQL constraint names are only unique per table.
        cur.execute(
            "SELECT tc.table_name, kcu.column_name "
            "FROM information_schema.table_constraints tc "
            "JOIN information_schema.key_column_usage kcu "
            " ON tc.constraint_name = kcu.constraint_name "
            " AND tc.table_schema = kcu.table_schema "
            " AND tc.table_name = kcu.table_name "
            "WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = 'public'"
        )
        pks: dict[str, set[str]] = {}
        for tbl, col in cur.fetchall():
            pks.setdefault(tbl, set()).add(col)

        lines = []
        for table in tables:
            cur.execute(
                "SELECT column_name, data_type FROM information_schema.columns "
                "WHERE table_name = %s AND table_schema = 'public' ORDER BY ordinal_position",
                (table,),
            )
            table_pks = pks.get(table, set())
            cols = ", ".join(
                f"{col} {dtype}{'(PK)' if col in table_pks else ''}"
                for col, dtype in cur.fetchall()
            )
            # Compose the identifier with psycopg2.sql so unusual table names
            # are quoted correctly, and qualify it with the schema the table
            # was listed from instead of relying on search_path.
            cur.execute(
                pg_sql.SQL("SELECT COUNT(*) FROM {}").format(
                    pg_sql.Identifier("public", table)
                )
            )
            row_count = cur.fetchone()[0]
            lines.append(f"Table: {table} ({row_count} rows)\n Columns: {cols}")
        return "\n\n".join(lines)
    finally:
        conn.close()
456
+
457
+
458
def get_schema_info() -> str:
    """Return the textual schema summary of the active database.

    Delegates to the PostgreSQL or SQLite implementation depending on the
    active database type.
    """
    describe = _schema_info_postgres if _active_db_type == "postgres" else _schema_info_sqlite
    return describe()
462
+
463
+
464
+ # ─── Execute query ────────────────────────────────────────────────
465
+
466
+ def _execute_sqlite(sql: str) -> tuple[list[dict], str | None]:
467
+ conn = sqlite3.connect(_active_dsn)
468
  conn.row_factory = sqlite3.Row
469
  try:
470
  cursor = conn.execute(sql)
 
476
  conn.close()
477
 
478
 
479
def _execute_postgres(sql: str) -> tuple[list[dict], str | None]:
    """Execute *sql* on the active PostgreSQL database.

    Args:
        sql: The statement to run. It is executed as given and committed on
            success.

    Returns:
        ``(rows, None)`` on success, where ``rows`` is a list of plain dicts
        (empty for statements that produce no result set), or ``([], message)``
        when psycopg2 reports an error — including a failure to connect.
    """
    import psycopg2  # type: ignore[import]
    import psycopg2.extras  # type: ignore[import]

    conn = None
    try:
        # Connect inside the try block so an unreachable server is reported
        # through the error slot of the return value instead of raising.
        conn = psycopg2.connect(_active_dsn)
        cur = conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
        cur.execute(sql)
        # description is None for statements without a result set; calling
        # fetchall() on those would raise.
        if cur.description is not None:
            rows = [dict(row) for row in cur.fetchall()]
        else:
            rows = []
        conn.commit()
        return rows, None
    except psycopg2.Error as e:
        return [], str(e).strip()
    finally:
        if conn is not None:
            conn.close()
496
+
497
+
498
def execute_query(sql: str) -> tuple[list[dict], str | None]:
    """Run *sql* on the active database and return ``(rows, error_message)``.

    The work is handed to the PostgreSQL or SQLite executor according to the
    active database type.
    """
    runner = _execute_postgres if _active_db_type == "postgres" else _execute_sqlite
    return runner(sql)
503
+
504
+
505
+ # ─── Table stats ──────────────────────────────────────────────────
506
+
507
def _table_stats_sqlite() -> list[dict]:
    """Return ``[{"name": ..., "rows": ...}, ...]`` for the active SQLite database.

    Tables are listed from ``sqlite_master`` in name order. A database without
    tables (for example a fresh ``:memory:`` connection) yields an empty list.
    """
    from contextlib import closing

    with closing(sqlite3.connect(_active_dsn)) as conn:
        names = [
            row[0]
            for row in conn.execute(
                "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
            )
        ]
        stats = []
        for name in names:
            # Double-quote the identifier (doubling embedded quotes) so names
            # with spaces, dashes or reserved words still count correctly.
            quoted = '"' + name.replace('"', '""') + '"'
            count = conn.execute(f"SELECT COUNT(*) FROM {quoted}").fetchone()[0]
            stats.append({"name": name, "rows": count})
        return stats
518
 
519
 
520
def _table_stats_postgres() -> list[dict]:
    """Return ``[{"name": ..., "rows": ...}, ...]`` for the active PostgreSQL database.

    Only base tables of the ``public`` schema are reported, ordered by name.
    """
    import psycopg2  # type: ignore[import]
    from psycopg2 import sql as pg_sql  # type: ignore[import]

    conn = psycopg2.connect(_active_dsn)
    try:
        cur = conn.cursor()
        cur.execute(
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema='public' AND table_type='BASE TABLE' ORDER BY table_name"
        )
        tables = [r[0] for r in cur.fetchall()]

        count_template = pg_sql.SQL("SELECT COUNT(*) FROM {}")
        result = []
        for t in tables:
            # Quote via psycopg2.sql and qualify with the schema the table was
            # listed from, so the count does not depend on search_path or on
            # the table name being free of special characters.
            cur.execute(count_template.format(pg_sql.Identifier("public", t)))
            result.append({"name": t, "rows": cur.fetchone()[0]})
        return result
    finally:
        conn.close()
537
+
538
+
539
def get_table_stats() -> list[dict]:
    """Return per-table row counts for the active database.

    Delegates to the PostgreSQL or SQLite implementation depending on the
    active database type.
    """
    collect = _table_stats_postgres if _active_db_type == "postgres" else _table_stats_sqlite
    return collect()
543
+
544
+
545
+ # ─── Schema graph ─────────────────────────────────────────────────
546
+
547
def _schema_graph_sqlite() -> dict:
    """Describe the active SQLite database as a graph of tables and foreign keys.

    Returns:
        ``{"tables": [...], "foreign_keys": [...]}`` where each table entry is
        ``{"name", "columns": [{"name", "type", "pk"}, ...]}`` and each
        foreign-key entry is ``{"from_table", "from_col", "to_table", "to_col"}``.
    """
    from contextlib import closing

    def _quote(identifier: str) -> str:
        # Double-quote an identifier for use inside a PRAGMA argument.
        return '"' + identifier.replace('"', '""') + '"'

    with closing(sqlite3.connect(_active_dsn)) as conn:
        cur = conn.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
        table_names = [r[0] for r in cur.fetchall()]

        tables = []
        foreign_keys = []
        for table in table_names:
            # PRAGMA arguments cannot be bound as parameters, so the name is
            # quoted explicitly; unquoted names with spaces or reserved words
            # would be a syntax error.
            quoted = _quote(table)

            # table_info row fields used here: [1] name, [2] declared type, [5] pk.
            info = conn.execute(f"PRAGMA table_info({quoted})").fetchall()
            columns = [{"name": col[1], "type": col[2], "pk": bool(col[5])} for col in info]
            tables.append({"name": table, "columns": columns})

            # foreign_key_list row fields used here: [2] referenced table,
            # [3] local column, [4] referenced column.
            for fk in conn.execute(f"PRAGMA foreign_key_list({quoted})").fetchall():
                foreign_keys.append({
                    "from_table": table,
                    "from_col": fk[3],
                    "to_table": fk[2],
                    "to_col": fk[4],
                })
        return {"tables": tables, "foreign_keys": foreign_keys}
571
+
572
+
573
def _schema_graph_postgres() -> dict:
    """Describe the ``public`` schema of the active PostgreSQL database as a graph.

    Returns:
        ``{"tables": [...], "foreign_keys": [...]}`` where each table entry is
        ``{"name", "columns": [{"name", "type", "pk"}, ...]}`` and each
        foreign-key entry is ``{"from_table", "from_col", "to_table", "to_col"}``,
        with one entry per column pair of a (possibly composite) foreign key.
    """
    import psycopg2  # type: ignore[import]

    conn = psycopg2.connect(_active_dsn)
    try:
        cur = conn.cursor()
        cur.execute(
            "SELECT table_name FROM information_schema.tables "
            "WHERE table_schema='public' AND table_type='BASE TABLE' ORDER BY table_name"
        )
        table_names = [r[0] for r in cur.fetchall()]

        # Primary-key columns per table. The table name is part of the join
        # because PostgreSQL constraint names are only unique per table.
        cur.execute(
            "SELECT tc.table_name, kcu.column_name "
            "FROM information_schema.table_constraints tc "
            "JOIN information_schema.key_column_usage kcu "
            " ON tc.constraint_name = kcu.constraint_name "
            " AND tc.table_schema = kcu.table_schema "
            " AND tc.table_name = kcu.table_name "
            "WHERE tc.constraint_type = 'PRIMARY KEY' AND tc.table_schema = 'public'"
        )
        pks: dict[str, set[str]] = {}
        for tbl, col in cur.fetchall():
            pks.setdefault(tbl, set()).add(col)

        tables = []
        for table in table_names:
            cur.execute(
                "SELECT column_name, data_type FROM information_schema.columns "
                "WHERE table_name = %s AND table_schema = 'public' ORDER BY ordinal_position",
                (table,),
            )
            table_pks = pks.get(table, set())
            columns = [
                {"name": col, "type": dtype, "pk": col in table_pks}
                for col, dtype in cur.fetchall()
            ]
            tables.append({"name": table, "columns": columns})

        # Pair each referencing column with its referenced column by position
        # (position_in_unique_constraint). Joining through
        # constraint_column_usage instead would pair every local column with
        # every referenced column of a composite key.
        cur.execute(
            "SELECT kcu.table_name, kcu.column_name, ref.table_name, ref.column_name "
            "FROM information_schema.table_constraints tc "
            "JOIN information_schema.key_column_usage kcu "
            " ON tc.constraint_name = kcu.constraint_name "
            " AND tc.table_schema = kcu.table_schema "
            " AND tc.table_name = kcu.table_name "
            "JOIN information_schema.referential_constraints rc "
            " ON tc.constraint_name = rc.constraint_name "
            " AND tc.constraint_schema = rc.constraint_schema "
            "JOIN information_schema.key_column_usage ref "
            " ON ref.constraint_name = rc.unique_constraint_name "
            " AND ref.constraint_schema = rc.unique_constraint_schema "
            " AND ref.ordinal_position = kcu.position_in_unique_constraint "
            " AND ref.table_schema = tc.table_schema "
            "WHERE tc.constraint_type = 'FOREIGN KEY' AND tc.table_schema = 'public' "
            "ORDER BY kcu.table_name, tc.constraint_name, kcu.ordinal_position"
        )
        foreign_keys = [
            {"from_table": r[0], "from_col": r[1], "to_table": r[2], "to_col": r[3]}
            for r in cur.fetchall()
        ]
        return {"tables": tables, "foreign_keys": foreign_keys}
    finally:
        conn.close()
626
+
627
+
628
def get_schema_graph() -> dict:
    """Return the tables/foreign-keys graph of the active database.

    Delegates to the PostgreSQL or SQLite implementation depending on the
    active database type.
    """
    build = _schema_graph_postgres if _active_db_type == "postgres" else _schema_graph_sqlite
    return build()
backend/env/sql_env.py CHANGED
@@ -84,6 +84,22 @@ Rules:
84
  - Use SQLite syntax
85
  - Do not include semicolons at the end"""
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  def _clean_sql(raw: str) -> str:
89
  """Strip markdown code fences and extra whitespace."""
 
84
  - Use SQLite syntax
85
  - Do not include semicolons at the end"""
86
 
87
# PostgreSQL counterpart of BASE_SYSTEM_PROMPT; get_system_prompt() returns it
# when the active database type is "postgres".
_POSTGRES_SYSTEM_PROMPT = """You are a SQL expert. Given a natural language question and a PostgreSQL database schema, write a correct SQL query.

Rules:
- Output ONLY the SQL query, nothing else
- No markdown, no code fences, no explanation
- Use PostgreSQL syntax
- Do not include semicolons at the end"""
94
+
95
+
96
def get_system_prompt() -> str:
    """Pick the system prompt that matches the active database's SQL dialect."""
    # Imported at call time, as in the original.
    from env.database import get_active_db_type

    uses_postgres = get_active_db_type() == "postgres"
    return _POSTGRES_SYSTEM_PROMPT if uses_postgres else BASE_SYSTEM_PROMPT
102
+
103
 
104
  def _clean_sql(raw: str) -> str:
105
  """Strip markdown code fences and extra whitespace."""
backend/requirements.txt CHANGED
@@ -7,3 +7,4 @@ aiofiles>=24.0.0
7
  python-multipart>=0.0.9
8
  sse-starlette>=2.1.0
9
  aiosqlite>=0.20.0
 
 
7
  python-multipart>=0.0.9
8
  sse-starlette>=2.1.0
9
  aiosqlite>=0.20.0
10
+ psycopg2-binary>=2.9.0
frontend/src/components/ConnectDB.tsx CHANGED
@@ -8,23 +8,45 @@ interface ConnectDBProps {
8
  onClose: () => void
9
  }
10
 
11
- const EXAMPLES = [
 
 
12
  { label: 'In-memory (blank)', value: ':memory:' },
13
  { label: 'Custom path', value: '/path/to/your/database.db' },
14
  ]
15
 
 
 
 
 
 
16
  export function ConnectDB({ onClose }: ConnectDBProps) {
17
  const { dbLabel, setDbLabel, setTables, setDbSeeded } = useStore()
18
- const [path, setPath] = useState('')
 
19
  const [status, setStatus] = useState<'idle' | 'connecting' | 'success' | 'error'>('idle')
20
  const [message, setMessage] = useState('')
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  const handleConnect = async () => {
23
- if (!path.trim()) return
 
24
  setStatus('connecting')
25
  setMessage('')
26
  try {
27
- const res = await connectExternalDb(path.trim())
28
  if (res.success) {
29
  setDbLabel(res.dbLabel)
30
  setTables(res.tables)
@@ -86,7 +108,7 @@ export function ConnectDB({ onClose }: ConnectDBProps) {
86
  </div>
87
  <div>
88
  <h2 className="text-sm font-semibold theme-text-primary">Connect Database</h2>
89
- <p className="text-[10px] text-gray-500">SQLite file path</p>
90
  </div>
91
  </div>
92
  <button onClick={onClose} className="p-1.5 rounded-lg hover:bg-white/5 text-gray-500 hover:text-gray-300 transition-colors">
@@ -111,16 +133,38 @@ export function ConnectDB({ onClose }: ConnectDBProps) {
111
 
112
  {/* Body */}
113
  <div className="px-5 py-4 flex flex-col gap-4">
 
114
  <div>
115
  <label className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider block mb-1.5">
116
- SQLite File Path
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  </label>
118
  <input
119
- type="text"
120
- value={path}
121
- onChange={(e) => { setPath(e.target.value); setStatus('idle') }}
122
  onKeyDown={(e) => e.key === 'Enter' && void handleConnect()}
123
- placeholder="/path/to/database.db"
124
  className="w-full px-3 py-2.5 text-sm rounded-xl border focus:outline-none transition-all font-mono"
125
  style={{
126
  background: 'var(--bg-tertiary)',
@@ -129,16 +173,21 @@ export function ConnectDB({ onClose }: ConnectDBProps) {
129
  }}
130
  autoFocus
131
  />
 
 
 
 
 
132
  </div>
133
 
134
  {/* Quick examples */}
135
  <div className="flex flex-col gap-1.5">
136
  <span className="text-[10px] text-gray-600 uppercase tracking-wider">Quick select</span>
137
  <div className="flex flex-wrap gap-1.5">
138
- {EXAMPLES.map((ex) => (
139
  <button
140
  key={ex.value}
141
- onClick={() => { setPath(ex.value); setStatus('idle') }}
142
  className="text-[10px] px-2.5 py-1 rounded-full border transition-all text-gray-500 hover:text-gray-300"
143
  style={{ borderColor: 'var(--border-color)', background: 'var(--bg-tertiary)' }}
144
  >
@@ -173,7 +222,7 @@ export function ConnectDB({ onClose }: ConnectDBProps) {
173
  </button>
174
  <button
175
  onClick={() => void handleConnect()}
176
- disabled={!path.trim() || status === 'connecting'}
177
  className="flex items-center gap-1.5 px-4 py-2 rounded-xl text-xs font-semibold text-white transition-all active:scale-95 disabled:opacity-40 disabled:cursor-not-allowed"
178
  style={{ background: 'linear-gradient(135deg,#7c3aed,#2563eb)' }}
179
  >
 
8
  onClose: () => void
9
  }
10
 
11
+ type DbType = 'sqlite' | 'postgres'
12
+
13
+ const SQLITE_EXAMPLES = [
14
  { label: 'In-memory (blank)', value: ':memory:' },
15
  { label: 'Custom path', value: '/path/to/your/database.db' },
16
  ]
17
 
18
+ const POSTGRES_EXAMPLES = [
19
+ { label: 'Local default', value: 'postgresql://postgres:password@localhost:5432/mydb' },
20
+ { label: 'With SSL', value: 'postgresql://user:pass@host:5432/dbname?sslmode=require' },
21
+ ]
22
+
23
  export function ConnectDB({ onClose }: ConnectDBProps) {
24
  const { dbLabel, setDbLabel, setTables, setDbSeeded } = useStore()
25
+ const [dbType, setDbType] = useState<DbType>('sqlite')
26
+ const [value, setValue] = useState('')
27
  const [status, setStatus] = useState<'idle' | 'connecting' | 'success' | 'error'>('idle')
28
  const [message, setMessage] = useState('')
29
 
30
+ const placeholder = dbType === 'postgres'
31
+ ? 'postgresql://user:password@host:5432/dbname'
32
+ : '/path/to/database.db'
33
+
34
+ const examples = dbType === 'postgres' ? POSTGRES_EXAMPLES : SQLITE_EXAMPLES
35
+
36
const getDsn = () => {
  // Work on the trimmed input throughout.
  const trimmed = value.trim()
  const hasPgScheme = trimmed.startsWith('postgresql://') || trimmed.startsWith('postgres://')
  // In PostgreSQL mode, a non-empty value without a scheme gets the scheme prepended.
  const needsScheme = dbType === 'postgres' && trimmed !== '' && !hasPgScheme
  return needsScheme ? `postgresql://${trimmed}` : trimmed
}
42
+
43
  const handleConnect = async () => {
44
+ const dsn = getDsn()
45
+ if (!dsn) return
46
  setStatus('connecting')
47
  setMessage('')
48
  try {
49
+ const res = await connectExternalDb(dsn)
50
  if (res.success) {
51
  setDbLabel(res.dbLabel)
52
  setTables(res.tables)
 
108
  </div>
109
  <div>
110
  <h2 className="text-sm font-semibold theme-text-primary">Connect Database</h2>
111
+ <p className="text-[10px] text-gray-500">SQLite file or PostgreSQL connection string</p>
112
  </div>
113
  </div>
114
  <button onClick={onClose} className="p-1.5 rounded-lg hover:bg-white/5 text-gray-500 hover:text-gray-300 transition-colors">
 
133
 
134
  {/* Body */}
135
  <div className="px-5 py-4 flex flex-col gap-4">
136
+ {/* DB Type toggle */}
137
  <div>
138
  <label className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider block mb-1.5">
139
+ Database Type
140
+ </label>
141
+ <div className="flex rounded-xl overflow-hidden border" style={{ borderColor: 'var(--border-color)' }}>
142
+ {(['sqlite', 'postgres'] as DbType[]).map((type) => (
143
+ <button
144
+ key={type}
145
+ onClick={() => { setDbType(type); setValue(''); setStatus('idle') }}
146
+ className="flex-1 py-2 text-xs font-medium transition-all"
147
+ style={{
148
+ background: dbType === type ? 'linear-gradient(135deg,#7c3aed,#2563eb)' : 'var(--bg-tertiary)',
149
+ color: dbType === type ? '#fff' : 'var(--text-secondary)',
150
+ }}
151
+ >
152
+ {type === 'sqlite' ? 'SQLite' : 'PostgreSQL'}
153
+ </button>
154
+ ))}
155
+ </div>
156
+ </div>
157
+
158
+ <div>
159
+ <label className="text-[10px] font-semibold text-gray-500 uppercase tracking-wider block mb-1.5">
160
+ {dbType === 'postgres' ? 'Connection String' : 'File Path'}
161
  </label>
162
  <input
163
+ type={dbType === 'postgres' ? 'text' : 'text'}
164
+ value={value}
165
+ onChange={(e) => { setValue(e.target.value); setStatus('idle') }}
166
  onKeyDown={(e) => e.key === 'Enter' && void handleConnect()}
167
+ placeholder={placeholder}
168
  className="w-full px-3 py-2.5 text-sm rounded-xl border focus:outline-none transition-all font-mono"
169
  style={{
170
  background: 'var(--bg-tertiary)',
 
173
  }}
174
  autoFocus
175
  />
176
+ {dbType === 'postgres' && (
177
+ <p className="text-[10px] text-gray-600 mt-1">
178
+ Format: <span className="font-mono text-gray-500">postgresql://user:password@host:port/dbname</span>
179
+ </p>
180
+ )}
181
  </div>
182
 
183
  {/* Quick examples */}
184
  <div className="flex flex-col gap-1.5">
185
  <span className="text-[10px] text-gray-600 uppercase tracking-wider">Quick select</span>
186
  <div className="flex flex-wrap gap-1.5">
187
+ {examples.map((ex) => (
188
  <button
189
  key={ex.value}
190
+ onClick={() => { setValue(ex.value); setStatus('idle') }}
191
  className="text-[10px] px-2.5 py-1 rounded-full border transition-all text-gray-500 hover:text-gray-300"
192
  style={{ borderColor: 'var(--border-color)', background: 'var(--bg-tertiary)' }}
193
  >
 
222
  </button>
223
  <button
224
  onClick={() => void handleConnect()}
225
+ disabled={!value.trim() || status === 'connecting'}
226
  className="flex items-center gap-1.5 px-4 py-2 rounded-xl text-xs font-semibold text-white transition-all active:scale-95 disabled:opacity-40 disabled:cursor-not-allowed"
227
  style={{ background: 'linear-gradient(135deg,#7c3aed,#2563eb)' }}
228
  >