| |
|
|
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass |
| from enum import Enum |
| from typing import List, Optional |
| import logging |
|
|
# Module-level logging setup.
# NOTE(review): calling logging.basicConfig at import time configures the
# root logger as a side effect — confirm this is intended if this module is
# ever used as a library rather than an application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Default kill-switch thresholds, overridable per SafetySentinel instance.
# Kill switch fires when user anger is strictly ABOVE this score (0.0-1.0).
ANGER_THRESHOLD = 0.9
# Kill switch fires when fact accuracy is strictly BELOW this score (0.0-1.0).
FACT_SCORE_THRESHOLD = 0.5
|
|
|
|
class SafetyStatus(Enum):
    """Verdict of a sentinel safety check."""

    SAFE = "SAFE"
    KILL_SWITCH_ACTIVE = "KILL_SWITCH_ACTIVE"
    CHECK_FAILED = "CHECK_FAILED"
|
|
|
|
@dataclass
class AgentState:
    """Snapshot of an agent's conversation context, as seen by the sentinel."""

    # Chronological user messages for the current conversation.
    user_history: List[str]
    # The agent's most recent reply; subject to fact verification.
    last_response: str
|
|
|
|
@dataclass
class SafetyCheckResult:
    """Outcome of a single sentinel safety evaluation."""

    # Overall verdict of the check.
    status: SafetyStatus
    # Measured user anger; None when sentiment analysis never ran.
    anger_score: Optional[float] = None
    # Measured fact accuracy; None when fact verification never ran.
    fact_score: Optional[float] = None
    # Human-readable failure detail, populated on errors.
    error_message: Optional[str] = None
|
|
|
|
class SentimentAnalyzer(ABC):
    """Interface for scoring user frustration from conversation history."""

    @abstractmethod
    def analyze(self, user_history: List[str]) -> float:
        """Return an anger score in the range 0.0 to 1.0."""
        ...
|
|
|
|
class FactVerifier(ABC):
    """Interface for scoring the factual accuracy of an agent response."""

    @abstractmethod
    def verify_facts(self, response: str) -> float:
        """Return a fact-accuracy score in the range 0.0 to 1.0."""
        ...
|
|
|
|
class AgentController(ABC):
    """Interface for halting an agent and escalating to a human."""

    @abstractmethod
    def stop_agent(self) -> None:
        """Immediately halt the running agent."""
        ...

    @abstractmethod
    def alert_human_manager(self) -> None:
        """Notify a human supervisor that intervention is required."""
        ...

    @abstractmethod
    def display_message(self, message: str) -> None:
        """Show ``message`` to the end user."""
        ...
|
|
|
|
class SafetySentinel:
    """Monitors agent behavior and triggers safety responses.

    Runs two independent checks against an ``AgentState`` — user anger
    (sentiment) and response factuality — and reports whether the kill
    switch should fire.  ``execute_safety_response`` additionally carries
    out the kill-switch actions via the injected ``AgentController``.
    """

    # User-facing text shown when the session is handed off to a human.
    HANDOFF_MESSAGE = "I am having trouble. Connecting you to a human..."

    def __init__(
        self,
        sentiment_analyzer: SentimentAnalyzer,
        fact_verifier: FactVerifier,
        agent_controller: AgentController,
        anger_threshold: float = ANGER_THRESHOLD,
        fact_score_threshold: float = FACT_SCORE_THRESHOLD
    ):
        """Store collaborators and thresholds.

        Args:
            sentiment_analyzer: Scores user anger from conversation history.
            fact_verifier: Scores factual accuracy of the last response.
            agent_controller: Executes stop/alert/display actions.
            anger_threshold: Kill switch fires when anger is strictly
                above this value.
            fact_score_threshold: Kill switch fires when the fact score is
                strictly below this value.
        """
        self._sentiment_analyzer = sentiment_analyzer
        self._fact_verifier = fact_verifier
        self._agent_controller = agent_controller
        self._anger_threshold = anger_threshold
        self._fact_score_threshold = fact_score_threshold

    def safety_check(self, agent_state: AgentState) -> SafetyCheckResult:
        """Perform safety checks on the current agent state.

        Returns:
            A ``SafetyCheckResult`` whose status is ``CHECK_FAILED`` when
            the state is missing or a check raised, ``KILL_SWITCH_ACTIVE``
            when a threshold is breached, and ``SAFE`` otherwise.
        """
        # Fix: original used `if not agent_state:` (truthiness). A dataclass
        # instance is always truthy today, but an explicit None test states
        # the actual intent and stays correct if AgentState ever gains a
        # __bool__/__len__.
        if agent_state is None:
            logger.error("Invalid agent state provided")
            return SafetyCheckResult(
                status=SafetyStatus.CHECK_FAILED,
                error_message="Agent state is required"
            )

        try:
            user_anger = self._sentiment_analyzer.analyze(
                agent_state.user_history
            )
        except Exception as e:
            logger.exception("Sentiment analysis failed")
            return SafetyCheckResult(
                status=SafetyStatus.CHECK_FAILED,
                error_message=f"Sentiment analysis error: {e}"
            )

        try:
            fact_score = self._fact_verifier.verify_facts(
                agent_state.last_response
            )
        except Exception as e:
            logger.exception("Fact verification failed")
            # The sentiment score is already known at this point, so it is
            # surfaced even though the overall check failed.
            return SafetyCheckResult(
                status=SafetyStatus.CHECK_FAILED,
                anger_score=user_anger,
                error_message=f"Fact verification error: {e}"
            )

        # Either signal alone is enough to trip the kill switch.
        is_unsafe = (
            user_anger > self._anger_threshold or
            fact_score < self._fact_score_threshold
        )

        if is_unsafe:
            # Lazy %-style args: the message is only formatted when the
            # WARNING level is actually emitted (logging best practice).
            logger.warning(
                "Safety threshold breached: anger=%s, fact_score=%s",
                user_anger,
                fact_score,
            )
            return SafetyCheckResult(
                status=SafetyStatus.KILL_SWITCH_ACTIVE,
                anger_score=user_anger,
                fact_score=fact_score
            )

        return SafetyCheckResult(
            status=SafetyStatus.SAFE,
            anger_score=user_anger,
            fact_score=fact_score
        )

    def execute_safety_response(self, agent_state: AgentState) -> SafetyCheckResult:
        """Check safety and execute appropriate response.

        On ``KILL_SWITCH_ACTIVE`` the controller is told to stop the agent,
        alert a human manager, and display :attr:`HANDOFF_MESSAGE`.
        Controller failures are logged and recorded on the returned result
        rather than raised, so callers always receive a result object.
        """
        result = self.safety_check(agent_state)

        if result.status == SafetyStatus.KILL_SWITCH_ACTIVE:
            try:
                self._agent_controller.stop_agent()
                self._agent_controller.alert_human_manager()
                self._agent_controller.display_message(self.HANDOFF_MESSAGE)
                logger.info("Kill switch activated - handed off to human")
            except Exception as e:
                logger.exception("Failed to execute safety response")
                # Best effort: record the failure on the result instead of
                # raising, preserving the original contract.
                result.error_message = f"Safety response failed: {e}"

        return result
|
|