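# NOTE(review): the three lines below look like web-page residue captured with
# the file (not Python); kept verbatim as comments so the module can be parsed.
#   YashashMathur's picture
#   Sync from GitHub - all files
#   f762b8d verified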
# environment/db_engine.py
# SQLite Database Engine with Security Safeguards
# Implements: Mutation Blocker, OOM Protection, Timeout Wrapper
import os
import re
import signal
import sqlite3
import threading
import time
from contextlib import contextmanager
from pathlib import Path
from typing import Tuple, Optional
# Regex pattern for blocking statements that can modify data or touch the
# filesystem. Besides the classic DML/DDL verbs it also covers:
#   - CREATE        (e.g. CREATE TABLE ... AS SELECT, CREATE INDEX/VIEW/TRIGGER)
#   - REPLACE INTO  (SQLite's alias for INSERT OR REPLACE; the scalar function
#                    replace(x, y, z) is deliberately NOT matched)
#   - ATTACH        (can open or create database files on disk)
#   - VACUUM        (VACUUM INTO writes a new database file)
# NOTE: this is a plain keyword scan, so a keyword that appears inside a string
# literal or as a bare identifier is rejected as well.
MUTATION_PATTERN = re.compile(
    r'\b(INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|REPLACE\s+INTO|ATTACH|VACUUM)\b',
    re.IGNORECASE,
)

# Query execution timeout in seconds
QUERY_TIMEOUT = 2.0

# Maximum rows to fetch (OOM protection)
MAX_FETCH_ROWS = 50
class TimeoutError(Exception):
    """Signals that a query ran past its allowed execution time.

    NOTE(review): inside this module the name shadows the built-in
    ``TimeoutError``; this class derives directly from ``Exception`` and is
    therefore not an ``OSError``.
    """
@contextmanager
def timeout_handler(seconds: float):
    """
    Context manager that raises TimeoutError if its body runs too long.

    The timeout is implemented with SIGALRM / setitimer. That mechanism is
    only usable when the platform provides it (not on Windows) AND the caller
    is the main thread, because ``signal.signal`` raises ValueError in any
    other thread. When it is not usable the body simply runs without a
    signal-based timeout instead of failing.

    Note that Python-level signal handlers only run between bytecode
    instructions, so a long call that stays inside C code is not interrupted
    until it returns.

    Args:
        seconds: Time budget in seconds for the guarded block.

    Raises:
        TimeoutError: If the alarm fires while the body is executing.
    """
    signals_usable = (
        hasattr(signal, "SIGALRM")
        and hasattr(signal, "setitimer")
        and threading.current_thread() is threading.main_thread()
    )
    if not signals_usable:
        # No usable signal support here: run the body unguarded.
        yield
        return

    def handler(signum, frame):
        raise TimeoutError(f"Query execution exceeded {seconds} seconds timeout")

    old_handler = signal.signal(signal.SIGALRM, handler)
    try:
        # Arm the timer inside the try so the previous handler is restored
        # even if arming itself fails (e.g. an invalid interval).
        signal.setitimer(signal.ITIMER_REAL, seconds)
        yield
    finally:
        signal.setitimer(signal.ITIMER_REAL, 0)
        signal.signal(signal.SIGALRM, old_handler)
class DatabaseEngine:
    """
    SQLite Database Engine with security safeguards.

    Features:
    - In-memory SQLite database (:memory: mode)
    - Mutation Blocker: queries matching MUTATION_PATTERN are rejected
    - OOM Protection: results are read with cursor.fetchmany(MAX_FETCH_ROWS),
      never fetchall()
    - Timeout: a SQLite progress handler aborts a statement once it has run
      for more than QUERY_TIMEOUT seconds (works in any thread and on any
      platform, and interrupts work inside SQLite itself)
    - Stringified errors: initialize() and execute_query() report failures as
      strings instead of raising
    """

    # Number of SQLite virtual-machine instructions between deadline checks.
    _PROGRESS_CHECK_OPCODES = 1000

    def __init__(self):
        """Create an engine with no open connection; call initialize() next."""
        self.connection: Optional[sqlite3.Connection] = None
        self.cursor: Optional[sqlite3.Cursor] = None
        self._schema_cache: Optional[str] = None

    def initialize(self) -> str:
        """
        Open a fresh in-memory SQLite database and load the mock data script.

        Any previously open connection is closed first. The script is read
        from ``<parent of this file's directory>/data/mock_data.sql``.

        Returns:
            str: Success message or error string
        """
        try:
            # Close existing connection if any
            self.close()

            # Create new in-memory database
            self.connection = sqlite3.connect(
                ':memory:',
                timeout=QUERY_TIMEOUT,
                check_same_thread=False,
            )
            self.cursor = self.connection.cursor()

            # Load mock data from SQL file
            mock_data_path = Path(__file__).parent.parent / 'data' / 'mock_data.sql'
            if not mock_data_path.exists():
                return f"Error: Mock data file not found at {mock_data_path}"

            # Explicit encoding: do not depend on the platform default.
            sql_script = mock_data_path.read_text(encoding='utf-8')
            self.cursor.executescript(sql_script)
            self.connection.commit()

            # Cache schema info
            self._schema_cache = self._get_schema_info()
            return "Database initialized successfully"
        except Exception as e:
            return f"Error initializing database: {str(e)}"

    def _get_schema_info(self) -> str:
        """
        Build a human-readable description of every table and its columns.

        At most MAX_FETCH_ROWS tables, and MAX_FETCH_ROWS columns per table,
        are listed.

        Returns:
            str: Formatted schema information, or an error string
        """
        if self.cursor is None:
            return "Error: Database not initialized"
        try:
            # Get all table names
            self.cursor.execute(
                "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
            )
            tables = [row[0] for row in self.cursor.fetchmany(MAX_FETCH_ROWS)]

            schema_parts = ["DATABASE SCHEMA:", "=" * 50]
            for table in tables:
                schema_parts.append(f"\nTable: {table}")
                schema_parts.append("-" * 30)

                # Quote the identifier so names containing spaces, quotes or
                # keywords do not break (or alter) the PRAGMA statement.
                quoted = '"' + table.replace('"', '""') + '"'
                self.cursor.execute(f"PRAGMA table_info({quoted})")
                for _cid, name, col_type, not_null, _default, pk in (
                    self.cursor.fetchmany(MAX_FETCH_ROWS)
                ):
                    pk_marker = " [PRIMARY KEY]" if pk else ""
                    null_marker = " NOT NULL" if not_null else ""
                    schema_parts.append(
                        f" - {name}: {col_type}{null_marker}{pk_marker}"
                    )
            return "\n".join(schema_parts)
        except Exception as e:
            return f"Error getting schema: {str(e)}"

    def get_schema(self) -> str:
        """
        Return the schema description, using the cached copy when available.

        Returns:
            str: Schema information string
        """
        if self._schema_cache:
            return self._schema_cache
        return self._get_schema_info()

    def check_mutation(self, query: str) -> Optional[str]:
        """
        Check whether the query contains a keyword matched by MUTATION_PATTERN.

        Args:
            query: SQL query string

        Returns:
            Optional[str]: Error message if mutation detected, None otherwise
        """
        match = MUTATION_PATTERN.search(query)
        if match is None:
            return None
        matched = match.group(1).upper()
        return (
            f"DESTRUCTIVE_ACTION_BLOCKED: {matched} operations are not allowed. "
            f"This environment is read-only. Only SELECT queries are permitted."
        )

    def execute_query(self, query: str) -> Tuple[str, bool]:
        """
        Execute a SQL query with all safety measures.

        The query is rejected if it is not a non-empty string or if it trips
        the mutation blocker. Execution is bounded by QUERY_TIMEOUT and at
        most MAX_FETCH_ROWS rows are rendered as a pipe-delimited table.

        Args:
            query: SQL query string

        Returns:
            Tuple[str, bool]: (result_string, is_error)
            - result_string: Query results or error message
            - is_error: True if an error occurred, False otherwise
        """
        if self.connection is None or self.cursor is None:
            return "Error: Database not initialized", True

        # Reject non-string input up front instead of letting .strip() raise.
        if not isinstance(query, str):
            return "Error: Query must be a string", True

        # Strip and validate query
        query = query.strip()
        if not query:
            return "Error: Empty query provided", True

        # MUTATION BLOCKER: Check for destructive operations
        mutation_error = self.check_mutation(query)
        if mutation_error:
            return mutation_error, True

        deadline = time.monotonic() + QUERY_TIMEOUT
        timed_out = False

        def _abort_if_overdue() -> int:
            # Called by SQLite while a statement runs; a non-zero return
            # value makes SQLite abort the statement.
            nonlocal timed_out
            if time.monotonic() > deadline:
                timed_out = True
                return 1
            return 0

        try:
            # TIMEOUT: let SQLite poll the deadline while it executes.
            self.connection.set_progress_handler(
                _abort_if_overdue, self._PROGRESS_CHECK_OPCODES
            )
            self.cursor.execute(query)

            # OOM PROTECTION: Use fetchmany, NEVER fetchall()
            rows = self.cursor.fetchmany(MAX_FETCH_ROWS)
            if not rows:
                # Check if it was a query that doesn't return rows
                if self.cursor.description is None:
                    return "Query executed successfully (no results)", False
                return "Query returned no results", False

            # Get column names
            columns = [desc[0] for desc in self.cursor.description]

            # Format results
            result_lines = [
                "| " + " | ".join(columns) + " |",
                "|" + "|".join(["---"] * len(columns)) + "|",
            ]
            for row in rows:
                formatted_row = [
                    str(val) if val is not None else "NULL" for val in row
                ]
                result_lines.append("| " + " | ".join(formatted_row) + " |")
            result = "\n".join(result_lines)

            # Try to fetch one more row to see whether results were truncated
            if self.cursor.fetchmany(1):
                result += (
                    f"\n\n[TRUNCATED] Results limited to {MAX_FETCH_ROWS} rows. "
                    f"More rows exist."
                )
            return result, False
        except sqlite3.Error as e:
            if timed_out:
                return (
                    f"Error: Query execution exceeded {QUERY_TIMEOUT} seconds timeout",
                    True,
                )
            return f"SQLite Error: {str(e)}", True
        except Exception as e:
            return f"Error: {str(e)}", True
        finally:
            # Always remove the handler so it cannot outlive this call.
            try:
                self.connection.set_progress_handler(None, 0)
            except Exception:
                pass  # best effort: the connection may already be unusable

    def close(self):
        """Close the cursor and connection (if any) and drop the schema cache."""
        # getattr: tolerate instances whose __init__ never ran to completion.
        cursor = getattr(self, "cursor", None)
        connection = getattr(self, "connection", None)
        self.cursor = None
        self.connection = None
        self._schema_cache = None
        try:
            if cursor is not None:
                cursor.close()
        finally:
            # Close the connection even if closing the cursor failed.
            if connection is not None:
                connection.close()

    def __del__(self):
        """Destructor to ensure connection is closed."""
        try:
            self.close()
        except Exception:
            # Never let cleanup errors escape from a finalizer.
            pass