Spaces:
Sleeping
Sleeping
File size: 9,320 Bytes
f762b8d | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 | # environment/db_engine.py
# SQLite Database Engine with Security Safeguards
# Implements: Mutation Blocker, OOM Protection, Timeout Wrapper
import os
import re
import signal
import sqlite3
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Tuple, Optional
# Regex pattern for blocking destructive SQL operations.
# It is a fast, human-friendly first line of defence: it matches the listed
# keywords as whole words anywhere in the text it is given, so it knows nothing
# about SQL string literals or comments, and any write it does not list passes.
# CREATE/ATTACH/DETACH/VACUUM/REINDEX are listed because they change the
# database (ATTACH and VACUUM INTO can even write files to disk). REPLACE only
# counts when followed by INTO, so the built-in replace() string function keeps
# working inside SELECT queries; INSERT OR REPLACE is already caught by INSERT.
MUTATION_PATTERN = re.compile(
    r'\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|ATTACH|DETACH|VACUUM|REINDEX'
    r'|REPLACE(?=\s+INTO\b))\b',
    re.IGNORECASE,
)
# Query execution timeout in seconds
QUERY_TIMEOUT = 2.0
# Maximum rows to fetch (OOM protection)
MAX_FETCH_ROWS = 50
class TimeoutError(Exception):
    """Raised by timeout_handler when a guarded block overruns its time budget.

    Note: this module-level class shadows Python's built-in ``TimeoutError``
    for all code inside this module.
    """
@contextmanager
def timeout_handler(seconds: float):
    """
    Context manager that raises TimeoutError if its body runs longer than
    ``seconds`` of wall-clock time.

    The timer is implemented with SIGALRM (``signal.setitimer``), so it is only
    armed when that signal exists on the platform (it does not on Windows) and
    the caller is the main thread, because ``signal.signal`` raises ValueError
    anywhere else. In every other case, and for a non-positive ``seconds``, the
    context manager is a no-op instead of failing the guarded call.

    Caveat: Python signal handlers only run between bytecode instructions, so
    this cannot interrupt a long-running call into C code (such as a single
    sqlite3 step); the error is raised once control returns to Python.

    Args:
        seconds: Time budget for the ``with`` body, in seconds.

    Raises:
        TimeoutError: When the alarm fires while the body is executing.
    """
    alarm_usable = (
        hasattr(signal, "SIGALRM")
        and threading.current_thread() is threading.main_thread()
    )
    if not alarm_usable or seconds <= 0:
        yield
        return

    def _on_alarm(signum, frame):
        raise TimeoutError(f"Query execution exceeded {seconds} seconds timeout")

    previous_handler = signal.signal(signal.SIGALRM, _on_alarm)
    signal.setitimer(signal.ITIMER_REAL, seconds)
    try:
        yield
    finally:
        # Disarm first so the alarm cannot fire after the old handler is back.
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, previous_handler)
class DatabaseEngine:
    """
    SQLite database engine with security safeguards for agent-written SQL.

    Features:
    - In-memory SQLite database (:memory: mode), loaded from
      <parent of this package>/data/mock_data.sql
    - Mutation Blocker: MUTATION_PATTERN keyword check (string literals,
      quoted identifiers and comments are blanked out first), backed by an
      SQLite authorizer that denies every operation a read-only query does
      not need (CREATE, ATTACH, VACUUM, writable PRAGMAs, ...)
    - OOM Protection: cursor.fetchmany(MAX_FETCH_ROWS), never fetchall()
    - Timeout: QUERY_TIMEOUT-second wall-clock budget enforced by an SQLite
      progress handler. It aborts the statement from inside the SQLite VM, so
      it works on every platform and thread and also covers the row fetches.
    - Stringified errors: initialize/execute_query report failures as strings
      instead of raising
    """

    # Authorizer action codes a read-only query needs: SELECT statements,
    # column reads, SQL function calls and recursive CTEs. The getattr
    # fallbacks are the numeric codes from sqlite3.h, for Python builds whose
    # sqlite3 module does not export the named constant.
    _READ_ONLY_ACTIONS = frozenset({
        sqlite3.SQLITE_SELECT,
        sqlite3.SQLITE_READ,
        getattr(sqlite3, "SQLITE_FUNCTION", 31),
        getattr(sqlite3, "SQLITE_RECURSIVE", 33),
    })

    # Introspection PRAGMAs that only read schema metadata. Every other PRAGMA
    # (including setters such as `PRAGMA query_only = OFF`) is denied.
    _READ_ONLY_PRAGMAS = frozenset({
        "table_info",
        "table_xinfo",
        "index_list",
        "index_info",
        "index_xinfo",
        "foreign_key_list",
    })

    # String literals, quoted identifiers and comments. They are blanked out
    # before the keyword scan so e.g. WHERE action = 'DELETE' is not reported
    # as a DELETE statement.
    _LITERALS_AND_COMMENTS = re.compile(
        r"'(?:[^']|'')*'"        # 'string literal' ('' escapes a quote)
        r'|"(?:[^"]|"")*"'       # "quoted identifier"
        r"|`[^`]*`"              # `backtick identifier`
        r"|\[[^\]]*\]"           # [bracket identifier]
        r"|--[^\n]*"             # -- line comment
        r"|/\*.*?(?:\*/|\Z)",    # /* block comment */ (may be unterminated)
        re.DOTALL,
    )

    # Number of SQLite VM instructions between two timeout checks.
    _PROGRESS_INTERVAL = 1000

    def __init__(self):
        """Create an engine with no open database; call initialize() before use."""
        self.connection: Optional[sqlite3.Connection] = None
        self.cursor: Optional[sqlite3.Cursor] = None
        self._schema_cache: Optional[str] = None

    @classmethod
    def _authorize(cls, action, arg1, arg2, db_name, trigger_or_view):
        """
        SQLite authorizer callback: allow reads, deny everything else.

        Installed by initialize() after the mock data has been loaded.
        Returning SQLITE_DENY makes statement preparation fail with an
        sqlite3.Error, which execute_query reports as an "SQLite Error".
        For SQLITE_PRAGMA, arg1 is the pragma name.
        """
        if action in cls._READ_ONLY_ACTIONS:
            return sqlite3.SQLITE_OK
        if action == sqlite3.SQLITE_PRAGMA and (arg1 or "").lower() in cls._READ_ONLY_PRAGMAS:
            return sqlite3.SQLITE_OK
        return sqlite3.SQLITE_DENY

    def initialize(self) -> str:
        """
        Initialize a clean in-memory SQLite database and load mock data.

        On any failure the half-built connection is closed again, so the
        engine is never left "initialized" with missing or partial data.

        Returns:
            str: Success message or error string
        """
        try:
            # Close existing connection if any
            self.close()
            # `timeout` is SQLite's lock busy-wait, not a query time limit; the
            # per-query limit is enforced in execute_query. check_same_thread=False
            # lets the connection be used from threads other than its creator.
            self.connection = sqlite3.connect(
                ':memory:',
                timeout=QUERY_TIMEOUT,
                check_same_thread=False
            )
            self.cursor = self.connection.cursor()

            mock_data_path = Path(__file__).parent.parent / 'data' / 'mock_data.sql'
            if not mock_data_path.exists():
                self.close()
                return f"Error: Mock data file not found at {mock_data_path}"
            # Explicit encoding: the platform default (e.g. cp1252 on Windows)
            # would mis-decode non-ASCII mock data. Assumes the file is UTF-8.
            sql_script = mock_data_path.read_text(encoding='utf-8')
            self.cursor.executescript(sql_script)
            self.connection.commit()

            # Cache schema info before locking the connection down.
            self._schema_cache = self._get_schema_info()
            # From here on, only read operations can be prepared on this connection.
            self.connection.set_authorizer(self._authorize)
            return "Database initialized successfully"
        except Exception as e:
            self.close()
            return f"Error initializing database: {str(e)}"

    def _get_schema_info(self) -> str:
        """
        Build a human-readable description of every table and its columns.

        At most MAX_FETCH_ROWS tables and MAX_FETCH_ROWS columns per table are
        listed.

        Returns:
            str: Formatted schema information, or an error string
        """
        if not self.cursor:
            return "Error: Database not initialized"
        try:
            self.cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
            )
            tables = [row[0] for row in self.cursor.fetchmany(MAX_FETCH_ROWS)]
            schema_parts = ["DATABASE SCHEMA:", "=" * 50]
            for table in tables:
                schema_parts.append(f"\nTable: {table}")
                schema_parts.append("-" * 30)
                # Quote the identifier so names containing spaces, keywords or
                # quote characters do not break the PRAGMA statement.
                quoted_table = '"' + table.replace('"', '""') + '"'
                self.cursor.execute(f"PRAGMA table_info({quoted_table})")
                for _cid, name, col_type, not_null, _default, pk in self.cursor.fetchmany(MAX_FETCH_ROWS):
                    pk_marker = " [PRIMARY KEY]" if pk else ""
                    null_marker = " NOT NULL" if not_null else ""
                    schema_parts.append(f"  - {name}: {col_type}{null_marker}{pk_marker}")
            return "\n".join(schema_parts)
        except Exception as e:
            return f"Error getting schema: {str(e)}"

    def get_schema(self) -> str:
        """
        Get cached schema information, rebuilding it if nothing is cached.

        Returns:
            str: Schema information string
        """
        if self._schema_cache:
            return self._schema_cache
        return self._get_schema_info()

    def check_mutation(self, query: str) -> Optional[str]:
        """
        Check if query contains mutation operations.

        String literals, quoted identifiers and comments are ignored, so a
        value such as 'DELETE' inside a WHERE clause is not a false positive.
        The SQLite authorizer installed by initialize() remains the real
        enforcement; this check provides the friendly error message.

        Args:
            query: SQL query string

        Returns:
            Optional[str]: Error message if mutation detected, None otherwise
        """
        scrubbed = self._LITERALS_AND_COMMENTS.sub(" ", query)
        match = MUTATION_PATTERN.search(scrubbed)
        if match is None:
            return None
        matched = match.group(1).upper()
        return (
            f"DESTRUCTIVE_ACTION_BLOCKED: {matched} operations are not allowed. "
            f"This environment is read-only. Only SELECT queries are permitted."
        )

    @staticmethod
    def _format_table(columns, rows) -> str:
        """Render column names and rows as a Markdown table; SQL NULL becomes 'NULL'."""
        lines = [
            "| " + " | ".join(columns) + " |",
            "|" + "|".join(["---"] * len(columns)) + "|",
        ]
        for row in rows:
            cells = ["NULL" if val is None else str(val) for val in row]
            lines.append("| " + " | ".join(cells) + " |")
        return "\n".join(lines)

    def execute_query(self, query: str) -> Tuple[str, bool]:
        """
        Execute a single SQL query with all safety measures.

        Args:
            query: SQL query string (surrounding whitespace is ignored)

        Returns:
            Tuple[str, bool]: (result_string, is_error)
            - result_string: Markdown table of at most MAX_FETCH_ROWS rows
              (plus a [TRUNCATED] note when more exist), a status message,
              or an error message
            - is_error: True if an error occurred, False otherwise
        """
        if not self.connection or not self.cursor:
            return "Error: Database not initialized", True
        if not isinstance(query, str):
            return f"Error: Query must be a string, got {type(query).__name__}", True

        query = query.strip()
        if not query:
            return "Error: Empty query provided", True

        # MUTATION BLOCKER: friendly keyword check (the authorizer is the backstop)
        mutation_error = self.check_mutation(query)
        if mutation_error:
            return mutation_error, True

        # TIMEOUT: SQLite calls _watchdog every _PROGRESS_INTERVAL VM
        # instructions; a non-zero return aborts the running statement with
        # sqlite3.OperationalError. Unlike a signal handler this interrupts
        # C-level work and also bounds the fetchmany() calls below.
        deadline = time.monotonic() + QUERY_TIMEOUT
        timed_out = False

        def _watchdog() -> int:
            nonlocal timed_out
            if time.monotonic() > deadline:
                timed_out = True
                return 1
            return 0

        self.connection.set_progress_handler(_watchdog, self._PROGRESS_INTERVAL)
        try:
            self.cursor.execute(query)
            # OOM PROTECTION: Use fetchmany(MAX_FETCH_ROWS), NEVER fetchall()
            rows = self.cursor.fetchmany(MAX_FETCH_ROWS)
            if not rows:
                # description is None for statements that produce no result set
                if self.cursor.description is None:
                    return "Query executed successfully (no results)", False
                return "Query returned no results", False

            columns = [desc[0] for desc in self.cursor.description]
            result = self._format_table(columns, rows)
            # Probe for one more row to tell the caller the output was cut short
            if self.cursor.fetchmany(1):
                result += f"\n\n[TRUNCATED] Results limited to {MAX_FETCH_ROWS} rows. More rows exist."
            return result, False
        except sqlite3.Error as e:
            if timed_out:
                return f"Error: Query execution exceeded {QUERY_TIMEOUT} seconds timeout", True
            return f"SQLite Error: {str(e)}", True
        except Exception as e:
            return f"Error: {str(e)}", True
        finally:
            self.connection.set_progress_handler(None, 0)

    def close(self):
        """Close the cursor and connection (if open) and drop the cached schema."""
        if self.cursor:
            self.cursor.close()
            self.cursor = None
        if self.connection:
            self.connection.close()
            self.connection = None
        self._schema_cache = None

    def __del__(self):
        """Destructor to ensure connection is closed."""
        self.close()
|