| |
|
|
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import Any, Optional |
| import logging |
|
|
# NOTE(review): basicConfig at import time configures the process-wide root
# logger (INFO level, plus a stderr handler if root has none yet). Libraries
# usually leave this to the application -- confirm this is intended.
logging.basicConfig(level=logging.INFO)

# Module-scoped logger used by everything below.
logger = logging.getLogger(name=__name__)
|
|
|
|
class RiskLevel(Enum):
    """Risk classification a ``RiskPolicy`` assigns to a tool name.

    Each member's value is its own name. Within this module,
    ``EscalationLadder`` maps the levels to actions:
    LOW -> run, MEDIUM -> run with audit logging,
    HIGH -> request approval instead of running, CRITICAL -> always block.
    """

    LOW = "LOW"            # executed directly
    MEDIUM = "MEDIUM"      # executed, with extra audit log lines
    HIGH = "HIGH"          # not executed; an approval request is opened
    CRITICAL = "CRITICAL"  # never executed
|
|
|
|
class ExecutionStatus(Enum):
    """Overall outcome carried by an ``ExecutionResult``.

    Each member's value is its own name.
    """

    SUCCESS = "SUCCESS"                            # the tool ran and returned
    WAITING_FOR_APPROVAL = "WAITING_FOR_APPROVAL"  # approval requested; not run
    REJECTED = "REJECTED"                          # blocked by policy; not run
    ERROR = "ERROR"                                # bad input or a failure along the way
|
|
|
|
@dataclass
class ExecutionResult:
    """Result record returned for a single tool-execution attempt.

    Only ``status`` is required; the optional fields default to ``None``
    and are filled in depending on the outcome.
    """

    # Overall outcome of the attempt.
    status: ExecutionStatus
    # Value returned by the tool; populated in this module only on SUCCESS.
    result: Optional[Any] = None
    # Human-readable reason for an ERROR or REJECTED outcome.
    error_message: Optional[str] = None
    # Identifier of the approval request, when one was created.
    decision_id: Optional[str] = None
|
|
|
|
class RiskPolicy(ABC):
    """Interface for deciding how risky invoking a named tool is."""

    @abstractmethod
    def get_level(self, tool_name: str) -> RiskLevel:
        """Return the ``RiskLevel`` assigned to ``tool_name``."""
|
|
|
|
class Tool(ABC):
    """Interface for an executable tool guarded by the escalation ladder."""

    @abstractmethod
    def run(self, parameters: dict) -> Any:
        """Execute the tool with ``parameters`` and return its result."""
|
|
|
|
class ApprovalService(ABC):
    """Interface for the human-approval workflow used for HIGH risk calls."""

    @abstractmethod
    def create_approval_request(self, tool_name: str, parameters: dict) -> str:
        """Open an approval request for this call and return its decision id."""

    @abstractmethod
    def notify_manager(self, decision_id: str) -> None:
        """Notify a manager that request ``decision_id`` awaits a decision."""
|
|
|
|
class EscalationLadder:
    """Gate execution of a single tool behind a risk-based escalation ladder.

    ``risk_policy.get_level(tool_name)`` decides what happens to a call:

    * ``RiskLevel.CRITICAL`` -- never executed; returns ``REJECTED``.
    * ``RiskLevel.HIGH`` -- not executed; an approval request is created and
      a manager notified; returns ``WAITING_FOR_APPROVAL``.
    * ``RiskLevel.MEDIUM`` -- executed, with audit log lines before and after.
    * ``RiskLevel.LOW`` -- executed.
    * anything else -- fails closed: not executed; returns ``ERROR``.

    Exceptions raised by the risk policy, the approval service, or the tool
    are caught, logged, and reported as an ``ERROR`` result.

    NOTE(review): ``tool_name`` is used only for the risk lookup and the
    approval request. The tool that actually runs is always the one injected
    as ``tool``. Callers must pass the name that matches that tool, otherwise
    the risk assessment is made for the wrong tool.
    """

    def __init__(
        self,
        risk_policy: RiskPolicy,
        tool: Tool,
        approval_service: ApprovalService,
    ):
        """Store the collaborators.

        Args:
            risk_policy: Maps a tool name to a ``RiskLevel``.
            tool: The single tool this ladder guards and executes.
            approval_service: Creates approval requests and notifies a manager
                for HIGH risk calls.
        """
        self._risk_policy = risk_policy
        self._tool = tool
        self._approval_service = approval_service

    def execute_tool(self, tool_name: str, parameters: dict) -> ExecutionResult:
        """Execute the guarded tool subject to risk-based escalation.

        Args:
            tool_name: Name used to look up the risk level and recorded in
                approval requests. Must be a non-empty ``str``.
            parameters: Arguments forwarded to ``Tool.run``. Must be a ``dict``.

        Returns:
            An ``ExecutionResult``. ``ERROR`` is returned for invalid
            arguments, a failing or unrecognised risk assessment, or a failure
            in the approval workflow or the tool.
        """
        # isinstance first, so a non-str argument is never truth-tested.
        if not isinstance(tool_name, str) or not tool_name:
            logger.error("Invalid tool_name provided")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message="Invalid tool_name: must be a non-empty string",
            )

        if not isinstance(parameters, dict):
            logger.error("Invalid parameters provided")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message="Invalid parameters: must be a dictionary",
            )

        try:
            risk_level = self._risk_policy.get_level(tool_name)
        except Exception as exc:
            # A failing policy is reported, never treated as permission to run.
            logger.exception("Failed to assess risk level")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Risk assessment failed: {exc}",
            )

        if risk_level is RiskLevel.CRITICAL:
            logger.warning("CRITICAL risk tool %r blocked", tool_name)
            return ExecutionResult(
                status=ExecutionStatus.REJECTED,
                error_message="Critical risk tools are not permitted",
            )

        if risk_level is RiskLevel.HIGH:
            return self._handle_high_risk(tool_name, parameters)

        if risk_level is RiskLevel.MEDIUM:
            return self._execute_with_logging(tool_name, parameters)

        if risk_level is RiskLevel.LOW:
            return self._execute_tool(parameters)

        # Fail closed: a policy returning None, a plain string such as "HIGH",
        # or a level this ladder does not know must never run the tool
        # unchecked.
        logger.error(
            "Unrecognized risk level %r for tool %r; refusing to execute",
            risk_level,
            tool_name,
        )
        return ExecutionResult(
            status=ExecutionStatus.ERROR,
            error_message=f"Unrecognized risk level: {risk_level!r}",
        )

    def _handle_high_risk(self, tool_name: str, parameters: dict) -> ExecutionResult:
        """Open an approval request for a HIGH risk call instead of running it.

        Args:
            tool_name: Name recorded in the approval request.
            parameters: Parameters recorded in the approval request.

        Returns:
            ``WAITING_FOR_APPROVAL`` with the new ``decision_id`` on success.
            ``ERROR`` if creating the request fails. ``ERROR`` that still
            carries the ``decision_id`` if the request was created but
            notifying the manager failed.
        """
        try:
            decision_id = self._approval_service.create_approval_request(
                tool_name, parameters
            )
        except Exception as exc:
            logger.exception("Failed to create approval request")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Approval workflow failed: {exc}",
            )

        try:
            self._approval_service.notify_manager(decision_id)
        except Exception as exc:
            # The request already exists; return its id so the caller can
            # follow up (e.g. retry notify_manager) instead of orphaning it.
            logger.exception(
                "Failed to notify manager for approval request %s", decision_id
            )
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Approval workflow failed: {exc}",
                decision_id=decision_id,
            )

        logger.info("Approval request created: %s", decision_id)
        return ExecutionResult(
            status=ExecutionStatus.WAITING_FOR_APPROVAL,
            decision_id=decision_id,
        )

    def _execute_with_logging(self, tool_name: str, parameters: dict) -> ExecutionResult:
        """Run a MEDIUM risk tool, logging audit lines before and after.

        Args:
            tool_name: Name recorded in the audit log lines.
            parameters: Arguments forwarded to ``Tool.run``.

        Returns:
            The result of ``_execute_tool``, unchanged.
        """
        logger.info("MEDIUM risk tool %r - executing with audit logging", tool_name)
        outcome = self._execute_tool(parameters)
        # Record how the call ended so the audit trail is complete.
        logger.info(
            "MEDIUM risk tool %r finished with status %s",
            tool_name,
            outcome.status.value,
        )
        return outcome

    def _execute_tool(self, parameters: dict) -> ExecutionResult:
        """Run the injected tool and wrap its outcome.

        Args:
            parameters: Arguments forwarded to ``Tool.run``.

        Returns:
            ``SUCCESS`` with the tool's return value, or ``ERROR`` with the
            exception text if ``Tool.run`` raised.
        """
        try:
            result = self._tool.run(parameters)
        except Exception as exc:
            logger.exception("Tool execution failed")
            return ExecutionResult(
                status=ExecutionStatus.ERROR,
                error_message=f"Execution failed: {exc}",
            )
        return ExecutionResult(status=ExecutionStatus.SUCCESS, result=result)
|
|