| import time |
| import json |
| from typing import Dict, Any, List, Union, Optional |
| from pathlib import Path |
| import psycopg2 |
| import psycopg2.extras |
| import re |
|
|
| from .database_base import DatabaseBase, DatabaseType, QueryType, DatabaseConnection |
| from .tool import Tool, Toolkit |
| from ..core.logging import logger |
|
|
class PostgreSQLConnection(DatabaseConnection):
    """PostgreSQL-specific connection management"""

    def __init__(self, connection_string: str, **kwargs):
        super().__init__(connection_string, **kwargs)
        # Live psycopg2 connection, or None while disconnected.
        self.conn = None

    def connect(self) -> bool:
        """Open a psycopg2 connection; returns True on success, False on failure."""
        try:
            self.conn = psycopg2.connect(self.connection_string, **self.connection_params)
        except Exception as e:
            logger.error(f"Failed to connect to PostgreSQL: {str(e)}")
            self._is_connected = False
            return False
        self._is_connected = True
        logger.info("Successfully connected to PostgreSQL")
        return True

    def disconnect(self) -> bool:
        """Close the connection if one is open; returns True on success."""
        try:
            if self.conn:
                self.conn.close()
                self.conn = None
            self._is_connected = False
            logger.info("Disconnected from PostgreSQL")
            return True
        except Exception as e:
            logger.error(f"Error disconnecting from PostgreSQL: {str(e)}")
            return False

    def test_connection(self) -> bool:
        """Probe the connection with a trivial query; never raises."""
        if not self.conn:
            return False
        try:
            with self.conn.cursor() as cur:
                cur.execute("SELECT 1;")
            return True
        except Exception:
            return False
|
|
| class PostgreSQLDatabase(DatabaseBase): |
| """ |
| PostgreSQL database implementation with automatic initialization. |
| Handles remote connections, existing local databases, and new local database creation. |
| """ |
    def __init__(self,
                 connection_string: str = None,
                 database_name: str = None,
                 local_path: str = None,
                 auto_save: bool = True,
                 **kwargs):
        """Initialize the database wrapper and choose a backing mode.

        Mode selection, in order:
          1. remote server, when ``connection_string`` looks remote
             (contains ``@`` or ``postgresql://``);
          2. existing local file-based database, when ``local_path`` holds a
             ``db_info.json`` marker;
          3. a brand-new local file-based database otherwise.

        Args:
            connection_string: psycopg2 DSN/URL for a remote server (optional).
            database_name: logical database name; defaults to the local
                directory name in file-based mode.
            local_path: directory used for file-based JSON storage.
            auto_save: when True, mutations are persisted to disk immediately.
            **kwargs: extra parameters forwarded to ``psycopg2.connect``.
        """
        init_params = {
            'connection_string': connection_string,
            'database_name': database_name
        }
        super().__init__(**init_params, **kwargs)
        self.local_path = Path(local_path) if local_path else None
        self.auto_save = auto_save
        self.connection_params = kwargs  # forwarded verbatim to psycopg2.connect
        self.is_local_database = False
        self.conn = None
        self.cursor = None
        self.file_based_mode = False
        self.tables = {}  # table name -> list of row dicts (file-based mode)

        if self._is_remote_connection():
            self._init_remote_database()
        elif self._is_existing_local_database():
            self._init_existing_local_database()
        else:
            self._init_new_local_database()
|
|
| def _is_remote_connection(self) -> bool: |
| return self.connection_string and ("@" in self.connection_string or "postgresql://" in self.connection_string) |
|
|
| def _is_existing_local_database(self) -> bool: |
| if not self.local_path: |
| return False |
| if not self.local_path.exists(): |
| return False |
| db_info_file = self.local_path / "db_info.json" |
| return db_info_file.exists() |
|
|
| def _init_remote_database(self): |
| """Initialize remote PostgreSQL connection""" |
| try: |
| |
| connection_params = self.connection_params.copy() |
| connection_params.update({ |
| 'connect_timeout': 5, |
| 'options': '-c statement_timeout=5000' |
| }) |
| |
| self.conn = psycopg2.connect(self.connection_string, **connection_params) |
| self.cursor = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) |
| if self.database_name: |
| self.conn.set_isolation_level(0) |
| self.cursor.execute("SELECT 1 FROM pg_database WHERE datname = %s", (self.database_name,)) |
| self._is_initialized = True |
| self.is_local_database = False |
| self.file_based_mode = False |
| logger.info(f"Connected to remote PostgreSQL: {self.database_name}") |
| except Exception as e: |
| logger.error(f"Failed to connect to remote PostgreSQL: {str(e)}") |
| self._is_initialized = False |
| |
| logger.info("Falling back to local database mode") |
|
|
    def _init_existing_local_database(self):
        """Initialize existing local file-based database.

        Loads every ``<table>.json`` under ``local_path`` into memory; on any
        failure it falls back to creating a fresh local database.
        """
        try:
            if not self.database_name:
                # Default the logical name to the storage directory's name.
                self.database_name = self.local_path.name

            self._load_tables_from_files()

            self._is_initialized = True
            self.is_local_database = True
            self.file_based_mode = True
            logger.info(f"Loaded existing local file-based database from: {self.local_path}")
        except Exception as e:
            logger.error(f"Failed to load existing local database: {str(e)}")
            self._is_initialized = False
            logger.info("Falling back to new local database mode")
            self._init_new_local_database()
|
|
| def _init_new_local_database(self): |
| """Initialize new local file-based database""" |
| try: |
| if not self.local_path: |
| self.local_path = Path("./workplace/postgresql_local") |
| self.local_path.mkdir(parents=True, exist_ok=True) |
| |
| if not self.database_name: |
| self.database_name = self.local_path.name |
| |
| self._create_db_info_file() |
| self._is_initialized = True |
| self.is_local_database = True |
| self.file_based_mode = True |
| logger.info(f"Created new local file-based database at: {self.local_path}") |
| except Exception as e: |
| logger.error(f"Failed to create new local database: {str(e)}") |
| self._is_initialized = False |
| logger.info("Database initialization failed, but toolkit is still usable") |
|
|
| def _create_db_info_file(self): |
| """Create database info file""" |
| try: |
| db_info = { |
| "database_name": self.database_name, |
| "created_at": time.time(), |
| "local_path": str(self.local_path.absolute()), |
| "auto_save": self.auto_save, |
| "version": "1.0", |
| "mode": "file_based" |
| } |
| info_file = self.local_path / "db_info.json" |
| with open(info_file, 'w', encoding='utf-8') as f: |
| json.dump(db_info, f, indent=2, ensure_ascii=False) |
| except Exception as e: |
| logger.warning(f"Failed to create db info file: {str(e)}") |
|
|
| def _load_tables_from_files(self): |
| """Load tables from JSON files""" |
| try: |
| for json_file in self.local_path.glob("*.json"): |
| if json_file.name == "db_info.json": |
| continue |
| table_name = json_file.stem |
| with open(json_file, 'r', encoding='utf-8') as f: |
| loaded_data = json.load(f) |
| |
| if not isinstance(loaded_data, list): |
| logger.warning(f"Table {table_name} file contains non-list data: {type(loaded_data)}, converting to empty list") |
| self.tables[table_name] = [] |
| else: |
| self.tables[table_name] = loaded_data |
| except Exception as e: |
| logger.warning(f"Error loading tables from files: {str(e)}") |
|
|
| def _save_table_to_file(self, table_name: str): |
| """Save table data to JSON file""" |
| try: |
| if table_name in self.tables: |
| table_file = self.local_path / f"{table_name}.json" |
| with open(table_file, 'w', encoding='utf-8') as f: |
| json.dump(self.tables[table_name], f, indent=2, ensure_ascii=False) |
| except Exception as e: |
| logger.error(f"Error saving table {table_name}: {str(e)}") |
|
|
|
|
|
|
| def _parse_sql_query(self, sql: str) -> Dict[str, Any]: |
| """Enhanced SQL parser for file-based mode - now supports JOINs and complex queries""" |
| sql = sql.strip() |
| upper_sql = sql.upper() |
| |
| |
| if upper_sql.startswith("CREATE TABLE"): |
| match = re.search(r"CREATE TABLE (?:IF NOT EXISTS )?(\w+) *\((.*?)\)", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| table = match.group(1).lower() |
| columns = match.group(2) |
| col_defs = [c.strip() for c in columns.split(',') if c.strip()] |
| col_names = [c.split()[0] for c in col_defs] |
| return {"type": "CREATE", "table": table, "columns": col_names} |
| |
| |
| elif upper_sql.startswith("INSERT"): |
| match = re.search(r"INSERT INTO (\w+) *\((.*?)\) *VALUES", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| table = match.group(1).lower() |
| columns = [c.strip() for c in match.group(2).split(',')] |
| values_match = re.search(r"VALUES\s*(.*)", sql, re.IGNORECASE | re.DOTALL) |
| if values_match: |
| values_str = values_match.group(1) |
| value_groups = re.findall(r'\(([^)]+)\)', values_str) |
| all_values = [] |
| for group in value_groups: |
| values = [v.strip().strip("'\"") for v in group.split(',')] |
| all_values.append(values) |
| return {"type": "INSERT", "table": table, "columns": columns, "values": all_values} |
| |
| |
| elif upper_sql.startswith("SELECT"): |
| |
| if "JOIN" in upper_sql: |
| |
| match = re.search(r"SELECT (.*?) FROM (\w+)(?:\s+(\w+))?\s+(?:(\w+)\s+)?JOIN\s+(\w+)(?:\s+(\w+))?\s+ON\s+(.*?)(?: WHERE (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| columns = [c.strip() for c in match.group(1).split(',')] |
| table1 = match.group(2).lower() |
| alias1 = match.group(3) |
| join_type = match.group(4) or "INNER" |
| table2 = match.group(5).lower() |
| alias2 = match.group(6) |
| join_condition = match.group(7) |
| where = match.group(8) |
| order_by = match.group(9) |
| limit = match.group(10) |
| |
| return { |
| "type": "SELECT_JOIN", |
| "columns": columns, |
| "table1": table1, |
| "alias1": alias1, |
| "join_type": join_type, |
| "table2": table2, |
| "alias2": alias2, |
| "join_condition": join_condition, |
| "where": where, |
| "order_by": order_by, |
| "limit": limit |
| } |
| |
| |
| elif "CROSS JOIN" in upper_sql: |
| match = re.search(r"SELECT (.*?) FROM (\w+)(?:\s+(\w+))?\s+CROSS\s+JOIN\s+(\w+)(?:\s+(\w+))?(?: WHERE (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| columns = [c.strip() for c in match.group(1).split(',')] |
| table1 = match.group(2).lower() |
| alias1 = match.group(3) |
| table2 = match.group(4).lower() |
| alias2 = match.group(5) |
| where = match.group(6) |
| order_by = match.group(7) |
| limit = match.group(8) |
| |
| return { |
| "type": "SELECT_CROSS_JOIN", |
| "columns": columns, |
| "table1": table1, |
| "alias1": alias1, |
| "table2": table2, |
| "alias2": alias2, |
| "where": where, |
| "order_by": order_by, |
| "limit": limit |
| } |
| |
| |
| else: |
| match = re.search(r"SELECT (.*?) FROM (\w+)(?: WHERE (.*?))?(?: GROUP BY (.*?))?(?: ORDER BY (.*?))?(?: LIMIT (\d+))?", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| columns = [c.strip() for c in match.group(1).split(',')] |
| table = match.group(2).lower() |
| where = match.group(3) |
| group_by = match.group(4) |
| order_by = match.group(5) |
| limit = match.group(6) |
| return {"type": "SELECT", "table": table, "columns": columns, "where": where, "group_by": group_by, "order_by": order_by, "limit": limit} |
| |
| |
| elif upper_sql.startswith("UPDATE"): |
| match = re.search(r"UPDATE (\w+) SET (.*?)(?: WHERE (.*?))?$", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| table = match.group(1).lower() |
| set_clause = match.group(2) |
| where = match.group(3) |
| return {"type": "UPDATE", "table": table, "set": set_clause, "where": where} |
| |
| |
| elif upper_sql.startswith("DELETE"): |
| match = re.search(r"DELETE FROM (\w+)(?: WHERE (.*?))?", sql, re.IGNORECASE | re.DOTALL) |
| if match: |
| table = match.group(1).lower() |
| where = match.group(2) |
| return {"type": "DELETE", "table": table, "where": where} |
| |
| return {"type": "UNKNOWN"} |
|
|
| def _apply_where_filter(self, rows: List[Dict], where: str) -> List[Dict]: |
| """Apply WHERE filter to rows""" |
| if not where: |
| return rows |
| |
| |
| if not isinstance(rows, list): |
| logger.warning(f"_apply_where_filter: rows is not a list: {type(rows)}") |
| return [] |
| |
| |
| valid_rows = [r for r in rows if isinstance(r, dict)] |
| if len(valid_rows) != len(rows): |
| logger.warning(f"_apply_where_filter: filtered out {len(rows) - len(valid_rows)} non-dict rows") |
| |
| |
| m = re.match(r"(\w+) *([=><]+) *'?([\w@.\- ]+)'?", where) |
| if m: |
| col, op, val = m.group(1), m.group(2), m.group(3) |
| if op == "=": |
| return [r for r in valid_rows if str(r.get(col, "")) == val] |
| elif op == ">": |
| try: |
| val_num = int(val) |
| return [r for r in valid_rows if int(r.get(col, 0)) > val_num] |
| except ValueError: |
| pass |
| elif op == "<": |
| try: |
| val_num = int(val) |
| return [r for r in valid_rows if int(r.get(col, 0)) < val_num] |
| except ValueError: |
| pass |
| return valid_rows |
|
|
| def _apply_column_selection(self, rows: List[Dict], columns: List[str]) -> List[Dict]: |
| """Apply column selection to rows""" |
| if columns == ['*']: |
| return rows |
| |
| |
| if not isinstance(rows, list): |
| logger.warning(f"_apply_column_selection: rows is not a list: {type(rows)}") |
| return [] |
| |
| |
| valid_rows = [r for r in rows if isinstance(r, dict)] |
| if len(valid_rows) != len(rows): |
| logger.warning(f"_apply_column_selection: filtered out {len(rows) - len(valid_rows)} non-dict rows") |
| |
| filtered_rows = [] |
| for row in valid_rows: |
| filtered_row = {} |
| for col in columns: |
| if col in row: |
| filtered_row[col] = row[col] |
| filtered_rows.append(filtered_row) |
| return filtered_rows |
|
|
    def _apply_group_by(self, rows: List[Dict], group_by: str) -> List[Dict]:
        """Apply GROUP BY aggregation to rows.

        Groups by a single column and emits, per group, a fixed set of
        aggregates: ``employee_count`` (group size), plus ``avg_salary`` and
        ``max_salary`` computed from each row's "salary" field.
        NOTE(review): the aggregate names and the "salary" column are
        hard-coded, so this only yields meaningful numbers for employee-like
        tables — confirm whether generic aggregates are required.
        """
        if not group_by:
            return rows

        # Defensive: tables loaded from disk may contain malformed data.
        if not isinstance(rows, list):
            logger.warning(f"_apply_group_by: rows is not a list: {type(rows)}")
            return []

        valid_rows = [r for r in rows if isinstance(r, dict)]
        if len(valid_rows) != len(rows):
            logger.warning(f"_apply_group_by: filtered out {len(rows) - len(valid_rows)} non-dict rows")

        group_col = group_by.strip()
        groups = {}
        for row in valid_rows:
            # Rows missing the group column fall into an "Unknown" bucket.
            group_val = row.get(group_col, "Unknown")
            if group_val not in groups:
                groups[group_val] = []
            groups[group_val].append(row)

        result = []
        for group_val, group_rows in groups.items():
            group_result = {group_col: group_val}
            # Hard-coded aggregates (see docstring).
            group_result["employee_count"] = len(group_rows)
            salaries = [float(r.get("salary", 0)) for r in group_rows if r.get("salary") is not None]
            group_result["avg_salary"] = sum(salaries) / len(salaries) if salaries else 0
            group_result["max_salary"] = max(salaries) if salaries else 0
            result.append(group_result)

        return result
|
|
    def _execute_join_query(self, parsed: Dict) -> Union[List[Dict], Dict[str, Any]]:
        """Execute an equality JOIN (nested-loop) in file-based mode.

        Returns the joined rows as a list on success, or ``{"error": ...}``
        on failure.  (Return annotation widened to match: success returns a
        list, not a dict.)  NOTE(review): the parsed ``join_type`` is ignored
        — every JOIN is executed as an inner join.
        """
        try:
            table1 = parsed["table1"]
            table2 = parsed["table2"]
            columns = parsed["columns"]
            join_condition = parsed["join_condition"]
            where = parsed.get("where")

            rows1 = self.tables.get(table1, [])
            rows2 = self.tables.get(table2, [])

            # Defensive: tables loaded from disk may contain malformed data.
            if not isinstance(rows1, list):
                logger.warning(f"Table {table1} contains non-list data: {type(rows1)}")
                rows1 = []
            if not isinstance(rows2, list):
                logger.warning(f"Table {table2} contains non-list data: {type(rows2)}")
                rows2 = []

            # Only the "a.col = b.col" condition form is supported.
            join_match = re.match(r"(\w+)\.(\w+)\s*=\s*(\w+)\.(\w+)", join_condition)
            if not join_match:
                return {"error": "Invalid join condition format"}

            col1, col2 = join_match.group(2), join_match.group(4)

            # Nested-loop join; key values are compared as strings.
            result_rows = []
            for row1 in rows1:
                if not isinstance(row1, dict):
                    logger.warning(f"Skipping non-dict row1 in JOIN: {type(row1)}")
                    continue
                for row2 in rows2:
                    if not isinstance(row2, dict):
                        logger.warning(f"Skipping non-dict row2 in JOIN: {type(row2)}")
                        continue
                    if str(row1.get(col1, "")) == str(row2.get(col2, "")):
                        # Project the output row; qualified names (alias.col)
                        # are resolved against the matching side.
                        combined_row = {}
                        for col in columns:
                            if '.' in col:
                                table_alias, col_name = col.split('.', 1)
                                if table_alias == parsed.get("alias1") or table_alias == table1:
                                    combined_row[col] = row1.get(col_name, "")
                                elif table_alias == parsed.get("alias2") or table_alias == table2:
                                    combined_row[col] = row2.get(col_name, "")
                            else:
                                # Unqualified column: the left table wins on conflicts.
                                if col in row1:
                                    combined_row[col] = row1[col]
                                elif col in row2:
                                    combined_row[col] = row2[col]
                        result_rows.append(combined_row)

            # WHERE is applied after the join, on the combined rows.
            if where:
                result_rows = self._apply_where_filter(result_rows, where)

            return result_rows

        except Exception as e:
            logger.error(f"Error executing JOIN query: {str(e)}")
            return {"error": str(e)}
|
|
    def _execute_cross_join_query(self, parsed: Dict) -> Union[List[Dict], Dict[str, Any]]:
        """Execute a CROSS JOIN (full Cartesian product) in file-based mode.

        Returns the combined rows as a list on success, or ``{"error": ...}``
        on failure.  (Return annotation widened to match: success returns a
        list, not a dict.)
        """
        try:
            table1 = parsed["table1"]
            table2 = parsed["table2"]
            columns = parsed["columns"]
            where = parsed.get("where")

            rows1 = self.tables.get(table1, [])
            rows2 = self.tables.get(table2, [])

            # Defensive: tables loaded from disk may contain malformed data.
            if not isinstance(rows1, list):
                logger.warning(f"Table {table1} contains non-list data: {type(rows1)}")
                rows1 = []
            if not isinstance(rows2, list):
                logger.warning(f"Table {table2} contains non-list data: {type(rows2)}")
                rows2 = []

            # Cartesian product: every row1 is paired with every row2.
            result_rows = []
            for row1 in rows1:
                if not isinstance(row1, dict):
                    logger.warning(f"Skipping non-dict row1 in CROSS JOIN: {type(row1)}")
                    continue
                for row2 in rows2:
                    if not isinstance(row2, dict):
                        logger.warning(f"Skipping non-dict row2 in CROSS JOIN: {type(row2)}")
                        continue

                    # Project the output row; qualified names (alias.col) are
                    # resolved against the matching side.
                    combined_row = {}
                    for col in columns:
                        if '.' in col:
                            table_alias, col_name = col.split('.', 1)
                            if table_alias == parsed.get("alias1") or table_alias == table1:
                                combined_row[col] = row1.get(col_name, "")
                            elif table_alias == parsed.get("alias2") or table_alias == table2:
                                combined_row[col] = row2.get(col_name, "")
                        else:
                            # Unqualified column: the left table wins on conflicts.
                            if col in row1:
                                combined_row[col] = row1[col]
                            elif col in row2:
                                combined_row[col] = row2[col]
                    result_rows.append(combined_row)

            # WHERE is applied after the product, on the combined rows.
            if where:
                result_rows = self._apply_where_filter(result_rows, where)

            return result_rows

        except Exception as e:
            logger.error(f"Error executing CROSS JOIN query: {str(e)}")
            return {"error": str(e)}
|
|
    def _get_database_type(self) -> DatabaseType:
        """Identify this backend as PostgreSQL for the DatabaseBase machinery."""
        return DatabaseType.POSTGRESQL
|
|
    def connect(self) -> bool:
        """Report whether the database was successfully initialized.

        NOTE(review): unlike PostgreSQLConnection.connect, this opens nothing —
        the actual connection (or file-based setup) happens in __init__.
        """
        return self._is_initialized
|
|
| def disconnect(self) -> bool: |
| try: |
| if self.conn: |
| self.conn.close() |
| self.conn = None |
| self.cursor = None |
| self._is_initialized = False |
| logger.info("Disconnected from PostgreSQL") |
| return True |
| except Exception as e: |
| logger.error(f"Error disconnecting: {str(e)}") |
| return False |
|
|
| def test_connection(self) -> bool: |
| if self.file_based_mode: |
| return self._is_initialized |
| try: |
| if self.conn: |
| with self.conn.cursor() as cur: |
| cur.execute("SELECT 1;") |
| return True |
| return False |
| except Exception: |
| return False |
|
|
    def execute_query(self, query: Union[str, Dict, List], query_type: QueryType = None, **kwargs) -> Dict[str, Any]:
        """Execute a query and return a formatted result dict (never raises).

        Accepted query shapes: a raw SQL string, a ``{"sql": ..., "params": ...}``
        dict, or a list of either (executed sequentially; only the last
        statement's result is returned).  File-based mode delegates to the
        mini SQL interpreter; server mode runs psycopg2 with dict-shaped rows,
        committing on success and rolling back on error.
        """
        if not self._is_initialized:
            return self.format_error_result("Database not initialized")

        if self.file_based_mode:
            return self._execute_file_based_query(query, query_type)

        if self.conn is None:
            return self.format_error_result("PostgreSQL server not available")

        start_time = time.time()
        try:
            with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                if isinstance(query, str):
                    # Raw SQL string.
                    cur.execute(query)
                elif isinstance(query, dict):
                    # Parameterized execution via psycopg2 placeholders.
                    sql = query.get("sql")
                    params = query.get("params", None)
                    if params:
                        cur.execute(sql, params)
                    else:
                        cur.execute(sql)
                elif isinstance(query, list):
                    # Batch of statements on one cursor/transaction.
                    for q in query:
                        if isinstance(q, str):
                            cur.execute(q)
                        elif isinstance(q, dict):
                            sql = q.get("sql")
                            params = q.get("params", None)
                            if params:
                                cur.execute(sql, params)
                            else:
                                cur.execute(sql)
                else:
                    return self.format_error_result("Unsupported query format", query_type)

                # SELECT-like statements expose a row description; others only rowcount.
                if cur.description:
                    result = cur.fetchall()
                else:
                    result = {"rowcount": cur.rowcount}

                self.conn.commit()

            execution_time = time.time() - start_time
            return self.format_query_result(result, query_type or QueryType.SELECT, execution_time=execution_time)

        except Exception as e:
            execution_time = time.time() - start_time
            logger.error(f"Error executing PostgreSQL query: {str(e)}")
            # Roll back so a failed statement does not wedge the connection.
            try:
                if self.conn:
                    self.conn.rollback()
            except Exception as rollback_error:
                logger.warning(f"Error during rollback: {str(rollback_error)}")
            return self.format_error_result(str(e), query_type, execution_time=execution_time)
|
|
    def _execute_file_based_query(self, query: Union[str, Dict, List], query_type: QueryType = None) -> Dict[str, Any]:
        """Execute a query in file-based mode.

        Interprets a plain SQL string (parsed by ``_parse_sql_query``) against
        the in-memory tables, persisting mutations via ``_save_table_to_file``
        when ``auto_save`` is on.  dict/list query formats are rejected here.
        Returns a formatted result dict; never raises.
        """
        start_time = time.time()
        try:
            if isinstance(query, str):
                parsed = self._parse_sql_query(query)
                query_type = query_type or QueryType.SELECT

                # Guard against the parser returning something unusable.
                if not isinstance(parsed, dict) or "type" not in parsed:
                    logger.error(f"_execute_file_based_query: parsed is not a valid dict: {parsed}")
                    return self.format_error_result(f"Failed to parse SQL query: {query}", query_type)

                logger.debug(f"Executing {parsed['type']} query: {parsed}")

                if parsed["type"] == "CREATE":
                    table_name = parsed["table"]
                    columns = parsed.get("columns", ["id"])
                    if table_name not in self.tables:
                        self.tables[table_name] = []
                    # Self-heal tables that were loaded with non-list payloads.
                    if not isinstance(self.tables[table_name], list):
                        logger.warning(f"Reinitializing table {table_name} as list (was {type(self.tables[table_name])})")
                        self.tables[table_name] = []
                    # NOTE(review): the schema is stashed under a sentinel key in
                    # self.tables, so it also shows up in list_collections().
                    self.tables[f"__schema__{table_name}"] = columns
                    if self.auto_save:
                        self._save_table_to_file(table_name)
                    result = {"rowcount": 0}
                elif parsed["type"] == "INSERT":
                    table_name = parsed["table"]
                    columns = parsed["columns"]
                    all_values = parsed["values"]
                    if table_name not in self.tables:
                        self.tables[table_name] = []
                    # Self-heal tables that were loaded with non-list payloads.
                    if not isinstance(self.tables[table_name], list):
                        logger.warning(f"Reinitializing table {table_name} as list (was {type(self.tables[table_name])})")
                        self.tables[table_name] = []

                    valid_rows = 0
                    for values in all_values:
                        # Row arity must match the column list.
                        if len(values) != len(columns):
                            logger.warning(f"Skipping invalid row: {values} (expected {len(columns)} values, got {len(values)})")
                            continue
                        if not isinstance(values, list):
                            logger.warning(f"Skipping non-list values: {type(values)}")
                            continue
                        row = {col: val for col, val in zip(columns, values)}
                        # Auto-assigned surrogate id (overwrites any inserted "id").
                        row["id"] = len(self.tables[table_name]) + 1
                        self.tables[table_name].append(row)
                        valid_rows += 1

                    if self.auto_save:
                        self._save_table_to_file(table_name)
                    result = {"rowcount": valid_rows}
                elif parsed["type"] == "SELECT":
                    table_name = parsed["table"]
                    columns = parsed["columns"]
                    where = parsed.get("where")
                    group_by = parsed.get("group_by")
                    rows = self.tables.get(table_name, [])
                    # Defensive: tables loaded from disk may hold malformed data.
                    if not isinstance(rows, list):
                        logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                        rows = []

                    logger.debug(f"SELECT query: table={table_name}, columns={columns}, where={where}, group_by={group_by}")
                    logger.debug(f"Rows from table: {type(rows)}, length={len(rows) if isinstance(rows, list) else 'N/A'}")
                    if isinstance(rows, list) and rows:
                        logger.debug(f"First row type: {type(rows[0])}, content: {rows[0]}")

                    if where:
                        rows = self._apply_where_filter(rows, where)

                    # NOTE(review): the grouped result is returned bare while
                    # the plain path wraps rows in {"data": ...} — confirm
                    # consumers expect this asymmetry.
                    if group_by:
                        result = self._apply_group_by(rows, group_by)
                    else:
                        result = {"data": self._apply_column_selection(rows, columns)}

                elif parsed["type"] == "SELECT_JOIN":
                    logger.debug(f"Executing JOIN query: {parsed}")
                    join_result = self._execute_join_query(parsed)
                    if isinstance(join_result, dict) and "error" in join_result:
                        result = {"error": join_result["error"]}
                    else:
                        result = {"data": join_result}

                elif parsed["type"] == "SELECT_CROSS_JOIN":
                    logger.debug(f"Executing CROSS JOIN query: {parsed}")
                    cross_join_result = self._execute_cross_join_query(parsed)
                    if isinstance(cross_join_result, dict) and "error" in cross_join_result:
                        result = {"error": cross_join_result["error"]}
                    else:
                        result = {"data": cross_join_result}
                elif parsed["type"] == "UPDATE":
                    table_name = parsed["table"]
                    set_clause = parsed["set"]
                    where = parsed.get("where")
                    rows = self.tables.get(table_name, [])
                    # Defensive: tables loaded from disk may hold malformed data.
                    if not isinstance(rows, list):
                        logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                        rows = []
                    # SET clause parsed as simple "col = value" pairs.
                    updates = dict(re.findall(r"(\w+) *= *'?([\w@.\- ]+)'?", set_clause))
                    count = 0
                    for r in rows:
                        if not isinstance(r, dict):
                            logger.warning(f"Skipping non-dict row in UPDATE: {type(r)}")
                            continue
                        match = True
                        if where:
                            m = re.match(r"(\w+) *([=><]+) *'?([\w@.\- ]+)'?", where)
                            if m:
                                col, op, val = m.group(1), m.group(2), m.group(3)
                                if op == "=" and str(r.get(col, "")) != val:
                                    match = False
                                elif op == ">" and int(r.get(col, 0)) <= int(val):
                                    match = False
                                elif op == "<" and int(r.get(col, 0)) >= int(val):
                                    match = False
                        if match:
                            r.update(updates)
                            count += 1
                    if self.auto_save:
                        self._save_table_to_file(table_name)
                    result = {"rowcount": count}
                elif parsed["type"] == "DELETE":
                    table_name = parsed["table"]
                    where = parsed.get("where")
                    rows = self.tables.get(table_name, [])
                    # Defensive: tables loaded from disk may hold malformed data.
                    if not isinstance(rows, list):
                        logger.warning(f"Table {table_name} contains non-list data: {type(rows)}")
                        rows = []
                    if where:
                        m = re.match(r"(\w+) *([=><]+) *'?([\w@.\- ]+)'?", where)
                        if m:
                            col, op, val = m.group(1), m.group(2), m.group(3)
                            if op == "=":
                                # Keep rows that do NOT match the condition.
                                new_rows = [r for r in rows if isinstance(r, dict) and str(r.get(col, "")) != val]
                            elif op == ">":
                                try:
                                    val_num = int(val)
                                    new_rows = [r for r in rows if isinstance(r, dict) and int(r.get(col, 0)) <= val_num]
                                except ValueError:
                                    new_rows = rows
                            else:
                                # Unsupported operator: delete nothing.
                                new_rows = rows
                            deleted_count = len(rows) - len(new_rows)
                            self.tables[table_name] = new_rows
                        else:
                            deleted_count = 0
                    else:
                        # DELETE without WHERE truncates the table.
                        deleted_count = len(rows)
                        self.tables[table_name] = []
                    if self.auto_save:
                        self._save_table_to_file(table_name)
                    result = {"rowcount": deleted_count}
                else:
                    return self.format_error_result("Unsupported query type in file-based mode", query_type)
                execution_time = time.time() - start_time
                return self.format_query_result(result, query_type, execution_time=execution_time)
            else:
                return self.format_error_result("Unsupported query format in file-based mode", query_type)
        except Exception as e:
            execution_time = time.time() - start_time
            logger.error(f"Error executing file-based query: {str(e)}")
            logger.error(f"Query that caused error: {query}")
            logger.error(f"Query type: {query_type}")
            import traceback
            logger.error(f"Traceback: {traceback.format_exc()}")
            return self.format_error_result(str(e), query_type, execution_time=execution_time)
|
|
    def get_database_info(self) -> Dict[str, Any]:
        """Return summary info (name, user, table count, connection) as a result dict."""
        try:
            if not self._is_initialized:
                return self.format_error_result("Database not initialized")

            if self.file_based_mode:
                # NOTE(review): table_count also counts "__schema__*" sentinel keys.
                info = {
                    "database": self.database_name,
                    "user": "file_based",
                    "table_count": len(self.tables),
                    "connection_string": "file_based",
                    "is_connected": True,
                    "mode": "file_based"
                }
            else:
                if self.conn is None:
                    return self.format_error_result("PostgreSQL server not available")

                with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                    cur.execute("SELECT current_database() as database, current_user as user")
                    db_info = cur.fetchone()
                    cur.execute("SELECT COUNT(*) as table_count FROM information_schema.tables WHERE table_schema = 'public'")
                    table_count = cur.fetchone()["table_count"]
                    info = {
                        "database": db_info["database"],
                        "user": db_info["user"],
                        "table_count": table_count,
                        "connection_string": self.connection_string,
                        "is_connected": self._is_initialized
                    }
            return self.format_query_result(info, QueryType.SELECT)
        except Exception as e:
            return self.format_error_result(str(e))
|
|
| def list_collections(self) -> List[str]: |
| try: |
| if self.file_based_mode: |
| return list(self.tables.keys()) |
| if not self._is_initialized or self.conn is None: |
| return [] |
| with self.conn.cursor() as cur: |
| cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'") |
| tables = [row[0] for row in cur.fetchall()] |
| return tables |
| except Exception as e: |
| logger.error(f"Error listing tables: {str(e)}") |
| return [] |
|
|
| def get_collection_info(self, collection_name: str) -> Dict[str, Any]: |
| try: |
| if not self._is_initialized: |
| return self.format_error_result("Database not initialized") |
| |
| if self.file_based_mode: |
| if collection_name in self.tables: |
| row_count = len(self.tables[collection_name]) |
| info = { |
| "table_name": collection_name, |
| "row_count": row_count, |
| "columns": ["id"] |
| } |
| else: |
| return self.format_error_result(f"Table {collection_name} not found") |
| else: |
| if self.conn is None: |
| return self.format_error_result("PostgreSQL server not available") |
| |
| with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur: |
| cur.execute(f"SELECT COUNT(*) as row_count FROM {collection_name}") |
| row_count = cur.fetchone()["row_count"] |
| cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (collection_name,)) |
| columns = cur.fetchall() |
| info = { |
| "table_name": collection_name, |
| "row_count": row_count, |
| "columns": columns |
| } |
| return self.format_query_result(info, QueryType.SELECT) |
| except Exception as e: |
| return self.format_error_result(str(e)) |
|
|
    def get_schema(self, collection_name: str = None) -> Dict[str, Any]:
        """Return the schema of one table, or of all tables when collection_name is None.

        NOTE(review): in file-based mode the schema is a fixed stub
        ({"id": "integer"}) even though CREATE TABLE stored the real column
        names under a "__schema__<table>" key — confirm whether those should
        be surfaced here instead.
        """
        try:
            if not self._is_initialized:
                return self.format_error_result("Database not initialized")

            if self.file_based_mode:
                if collection_name:
                    if collection_name in self.tables:
                        schema = {"id": "integer"}
                        return self.format_query_result({"table_name": collection_name, "schema": schema}, QueryType.SELECT)
                    else:
                        return self.format_error_result(f"Table {collection_name} not found")
                else:
                    schemas = {}
                    for table_name in self.tables:
                        schemas[table_name] = {"id": "integer"}
                    return self.format_query_result({"database_name": self.database_name, "schemas": schemas}, QueryType.SELECT)
            else:
                if self.conn is None:
                    return self.format_error_result("PostgreSQL server not available")

                with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as cur:
                    if collection_name:
                        cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (collection_name,))
                        columns = cur.fetchall()
                        schema = {col["column_name"]: col["data_type"] for col in columns}
                        return self.format_query_result({"table_name": collection_name, "schema": schema}, QueryType.SELECT)
                    else:
                        cur.execute("SELECT table_name FROM information_schema.tables WHERE table_schema = 'public'")
                        tables = [row[0] for row in cur.fetchall()]
                        schemas = {}
                        for table in tables:
                            cur.execute("SELECT column_name, data_type FROM information_schema.columns WHERE table_name = %s", (table,))
                            columns = cur.fetchall()
                            schemas[table] = {col["column_name"]: col["data_type"] for col in columns}
                        return self.format_query_result({"database_name": self.database_name, "schemas": schemas}, QueryType.SELECT)
        except Exception as e:
            return self.format_error_result(str(e))
|
|
| def get_supported_query_types(self) -> List[QueryType]: |
| return [ |
| QueryType.SELECT, |
| QueryType.INSERT, |
| QueryType.UPDATE, |
| QueryType.DELETE, |
| QueryType.CREATE, |
| QueryType.DROP, |
| QueryType.ALTER, |
| QueryType.INDEX |
| ] |
|
|
| def get_capabilities(self) -> Dict[str, Any]: |
| base_capabilities = super().get_capabilities() |
| base_capabilities.update({ |
| "supports_sql": True, |
| "supports_transactions": not self.file_based_mode, |
| "supports_indexing": not self.file_based_mode, |
| "schema_flexible": self.file_based_mode, |
| "file_based_mode": self.file_based_mode |
| }) |
| return base_capabilities |
|
|
| |
class PostgreSQLExecuteTool(Tool):
    """Tool that runs an arbitrary SQL statement against PostgreSQL."""

    name: str = "postgresql_execute"
    description: str = "Execute arbitrary SQL queries on PostgreSQL."
    inputs: Dict[str, Dict[str, str]] = {
        "query": {"type": "string", "description": "SQL query to execute (can be SELECT, INSERT, UPDATE, DELETE, etc.)"},
        "query_type": {"type": "string", "description": "Type of query (select, insert, update, delete, create, drop, alter, index) - auto-detected if not provided"}
    }
    required: Optional[List[str]] = ["query"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str, query_type: str = None) -> Dict[str, Any]:
        """Execute *query*, optionally forcing the type via *query_type*."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            # Translate the optional string hint into a QueryType enum and
            # reject unknown values instead of passing them through.
            parsed_type = None
            if query_type:
                try:
                    parsed_type = QueryType(query_type.lower())
                except ValueError:
                    return {"success": False, "error": f"Invalid query type: {query_type}", "data": None}
            return self.database.execute_query(query=query, query_type=parsed_type)
        except Exception as e:
            logger.error(f"Error in postgresql_execute tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
class PostgreSQLFindTool(Tool):
    """Tool that assembles and runs a SELECT statement for one table."""

    name: str = "postgresql_find"
    description: str = "Find (SELECT) rows from a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to query"},
        "where": {"type": "string", "description": "WHERE clause (optional, e.g., 'age > 18')"},
        "columns": {"type": "string", "description": "Comma-separated columns to select (default '*')"},
        "limit": {"type": "integer", "description": "Maximum number of rows to return (optional)"},
        "offset": {"type": "integer", "description": "Number of rows to skip (optional)"},
        "sort": {"type": "string", "description": "ORDER BY clause (optional, e.g., 'age ASC')"}
    }
    required: Optional[List[str]] = ["table_name"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, table_name: str, where: str = None, columns: str = "*", limit: int = None, offset: int = None, sort: str = None) -> Dict[str, Any]:
        """Run a SELECT built from the supplied clause fragments."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            # NOTE(review): clause fragments are interpolated verbatim into
            # the statement (no parameterization), so callers must supply
            # trusted input.
            parts = [f"SELECT {columns} FROM {table_name}"]
            if where:
                parts.append(f"WHERE {where}")
            if sort:
                parts.append(f"ORDER BY {sort}")
            if limit is not None:
                parts.append(f"LIMIT {limit}")
            if offset is not None:
                parts.append(f"OFFSET {offset}")
            return self.database.execute_query(" ".join(parts), QueryType.SELECT)
        except Exception as e:
            logger.error(f"Error in postgresql_find tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
class PostgreSQLUpdateTool(Tool):
    """Tool that assembles and runs an UPDATE statement for one table."""

    name: str = "postgresql_update"
    description: str = "Update rows in a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to update"},
        "set": {"type": "string", "description": "SET clause (e.g., 'status = \'active\'')"},
        "where": {"type": "string", "description": "WHERE clause (optional)"}
    }
    required: Optional[List[str]] = ["table_name", "set"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    # NOTE: the parameter is named `set` (shadowing the builtin) because it
    # is part of the tool's public keyword interface.
    def __call__(self, table_name: str, set: str, where: str = None) -> Dict[str, Any]:
        """Run an UPDATE using the given SET clause and optional WHERE clause."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            parts = [f"UPDATE {table_name} SET {set}"]
            if where:
                parts.append(f"WHERE {where}")
            return self.database.execute_query(" ".join(parts), QueryType.UPDATE)
        except Exception as e:
            logger.error(f"Error in postgresql_update tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
class PostgreSQLCreateTool(Tool):
    """Tool that forwards a raw CREATE statement to PostgreSQL."""

    name: str = "postgresql_create"
    description: str = "Create a table or other object in PostgreSQL."
    inputs: Dict[str, Dict[str, str]] = {
        "query": {"type": "string", "description": "CREATE statement (e.g., CREATE TABLE ...)"}
    }
    required: Optional[List[str]] = ["query"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, query: str) -> Dict[str, Any]:
        """Execute *query* as a CREATE statement."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            return self.database.execute_query(query, QueryType.CREATE)
        except Exception as e:
            logger.error(f"Error in postgresql_create tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
class PostgreSQLDeleteTool(Tool):
    """Tool that assembles and runs a DELETE statement for one table."""

    name: str = "postgresql_delete"
    description: str = "Delete rows from a PostgreSQL table."
    inputs: Dict[str, Dict[str, str]] = {
        "table_name": {"type": "string", "description": "Table name to delete from"},
        "where": {"type": "string", "description": "WHERE clause (optional)"}
    }
    required: Optional[List[str]] = ["table_name"]

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, table_name: str, where: str = None) -> Dict[str, Any]:
        """Run a DELETE; without *where* every row in the table is removed."""
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            parts = [f"DELETE FROM {table_name}"]
            if where:
                parts.append(f"WHERE {where}")
            return self.database.execute_query(" ".join(parts), QueryType.DELETE)
        except Exception as e:
            logger.error(f"Error in postgresql_delete tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
class PostgreSQLInfoTool(Tool):
    """Tool that reports database, table, schema, or capability information."""

    name: str = "postgresql_info"
    description: str = "Get PostgreSQL database and table information."
    inputs: Dict[str, Dict[str, str]] = {
        "info_type": {"type": "string", "description": "Type of information (database, tables, table, schema, capabilities)"},
        "table_name": {"type": "string", "description": "Table name for table-specific info (optional)"}
    }
    required: Optional[List[str]] = []

    def __init__(self, database: PostgreSQLDatabase = None):
        super().__init__()
        self.database = database

    def __call__(self, info_type: str = "database", table_name: str = None) -> Dict[str, Any]:
        """Return the requested information.

        Args:
            info_type: One of 'database', 'tables', 'table', 'schema',
                'capabilities'. ``None`` is treated as 'database'.
            table_name: Required when info_type is 'table'; optional filter
                for 'schema'.
        """
        try:
            if not self.database:
                return {"success": False, "error": "PostgreSQL database not initialized", "data": None}
            # Guard against an explicit None so we don't crash on .lower().
            info_type = (info_type or "database").lower()
            if info_type == "database":
                result = self.database.get_database_info()
            elif info_type == "tables":
                tables = self.database.list_collections()
                result = {"success": True, "data": tables, "table_count": len(tables)}
            elif info_type == "table":
                # Fix: previously a missing table_name fell through to the
                # generic "Invalid info type" error, which misreported a
                # valid info_type; report the actual problem instead.
                if not table_name:
                    return {"success": False, "error": "table_name is required for info_type 'table'", "data": None}
                result = self.database.get_collection_info(table_name)
            elif info_type == "schema":
                result = self.database.get_schema(table_name)
            elif info_type == "capabilities":
                result = {"success": True, "data": self.database.get_capabilities()}
            else:
                return {"success": False, "error": f"Invalid info type: {info_type}", "data": None}
            return result
        except Exception as e:
            logger.error(f"Error in postgresql_info tool: {str(e)}")
            return {"success": False, "error": str(e), "data": None}
|
|
class PostgreSQLToolkit(Toolkit):
    """Toolkit bundling the PostgreSQL tools around a single shared database.

    Builds one PostgreSQLDatabase from the given connection parameters and
    hands the same instance to every tool, so all tools share one connection.
    """
    def __init__(self,
                 name: str = "PostgreSQLToolkit",
                 connection_string: str = None,
                 database_name: str = None,
                 local_path: str = None,
                 auto_save: bool = True,
                 **kwargs):
        # Single database instance shared by every tool below.
        database = PostgreSQLDatabase(
            connection_string=connection_string,
            database_name=database_name,
            local_path=local_path,
            auto_save=auto_save,
            **kwargs
        )
        tools = [
            PostgreSQLExecuteTool(database=database),
            PostgreSQLFindTool(database=database),
            PostgreSQLUpdateTool(database=database),
            PostgreSQLCreateTool(database=database),
            PostgreSQLDeleteTool(database=database),
            PostgreSQLInfoTool(database=database)
        ]
        super().__init__(name=name, tools=tools)
        self.database = database
        self.connection_string = connection_string
        self.database_name = database_name
        self.local_path = local_path
        self.auto_save = auto_save
        # Disconnect at interpreter shutdown so the connection isn't leaked.
        import atexit
        atexit.register(self._cleanup)
    def _cleanup(self):
        """Best-effort disconnect at exit; failures are logged, never raised."""
        try:
            if self.database:
                self.database.disconnect()
                logger.info("Disconnected from PostgreSQL database")
        except Exception as e:
            logger.warning(f"Error during cleanup: {str(e)}")
    def get_capabilities(self) -> Dict[str, Any]:
        """Return the database capability map plus toolkit-level local-storage info."""
        if self.database:
            capabilities = self.database.get_capabilities()
            capabilities.update({
                "is_local_database": self.database.is_local_database,
                "local_path": str(self.database.local_path) if self.database.local_path else None,
                "auto_save": self.database.auto_save
            })
            return capabilities
        return {"error": "PostgreSQL database not initialized"}
    def connect(self) -> bool:
        # Delegates to the shared database; False when no database was built.
        return self.database.connect() if self.database else False
    def disconnect(self) -> bool:
        # Delegates to the shared database; False when no database was built.
        return self.database.disconnect() if self.database else False
    def test_connection(self) -> bool:
        # Delegates to the shared database; False when no database was built.
        return self.database.test_connection() if self.database else False
    def get_database(self) -> PostgreSQLDatabase:
        """Expose the shared database instance for direct use."""
        return self.database
    def get_local_info(self) -> Dict[str, Any]:
        """Summarize local-storage configuration and connection parameters."""
        return {
            "is_local_database": self.database.is_local_database,
            "local_path": str(self.database.local_path) if self.database.local_path else None,
            "auto_save": self.database.auto_save,
            "database_name": self.database_name,
            "connection_string": self.connection_string
        } if self.database else {"error": "Database not initialized"}