| from abc import ABC, abstractmethod |
| from contextlib import asynccontextmanager |
| from typing import List, Optional |
|
|
| from pydantic import BaseModel, Field, model_validator |
|
|
| from app.llm import LLM |
| from app.logger import logger |
| from app.sandbox.client import SANDBOX_CLIENT |
| from app.schema import ROLE_TYPE, AgentState, Memory, Message |
|
|
|
|
class BaseAgent(BaseModel, ABC):
    """Abstract base class for managing agent state and execution.

    Provides foundational functionality for state transitions, memory
    management, and a step-based execution loop. Subclasses must implement
    the `step` method.
    """

    # Core attributes
    name: str = Field(..., description="Unique name of the agent")
    description: Optional[str] = Field(None, description="Optional agent description")

    # Prompts
    system_prompt: Optional[str] = Field(
        None, description="System-level instruction prompt"
    )
    next_step_prompt: Optional[str] = Field(
        None, description="Prompt for determining next action"
    )

    # Dependencies
    llm: LLM = Field(default_factory=LLM, description="Language model instance")
    memory: Memory = Field(default_factory=Memory, description="Agent's memory store")
    state: AgentState = Field(
        default=AgentState.IDLE, description="Current agent state"
    )

    # Execution control
    max_steps: int = Field(default=10, description="Maximum steps before termination")
    current_step: int = Field(default=0, description="Current step in execution")

    # Number of earlier assistant messages that must equal the latest message
    # before `is_stuck` reports a loop.
    duplicate_threshold: int = 2

    class Config:
        # Allow field types pydantic cannot validate natively, and keep
        # attributes that are not declared as fields.
        arbitrary_types_allowed = True
        extra = "allow"

    @model_validator(mode="after")
    def initialize_agent(self) -> "BaseAgent":
        """Replace an unusable `llm` or `memory` with a default instance.

        Returns:
            The validated agent itself, as required for an "after" validator.
        """
        # isinstance() is already False for None, so no separate None check.
        if not isinstance(self.llm, LLM):
            self.llm = LLM(config_name=self.name.lower())
        if not isinstance(self.memory, Memory):
            self.memory = Memory()
        return self

    @asynccontextmanager
    async def state_context(self, new_state: AgentState):
        """Context manager for temporary agent state transitions.

        The agent is switched to `new_state` for the duration of the context
        and is always switched back to the state it had on entry, whether the
        body completes or raises.

        Args:
            new_state: The state to transition to during the context.

        Yields:
            None: Allows execution within the new state.

        Raises:
            ValueError: If `new_state` is not an `AgentState`.
        """
        if not isinstance(new_state, AgentState):
            raise ValueError(f"Invalid state: {new_state}")

        previous_state = self.state
        self.state = new_state
        try:
            yield
        except Exception:
            # NOTE(review): this ERROR value is overwritten by the `finally`
            # clause below before control leaves the context, so callers
            # never observe it. Kept to preserve existing behaviour; making
            # ERROR persist would prevent any later `run()` call.
            self.state = AgentState.ERROR
            raise
        finally:
            self.state = previous_state

    def update_memory(
        self,
        role: ROLE_TYPE,
        content: str,
        base64_image: Optional[str] = None,
        **kwargs,
    ) -> None:
        """Add a message to the agent's memory.

        Args:
            role: The role of the message sender (user, system, assistant, tool).
            content: The message content.
            base64_image: Optional base64 encoded image.
            **kwargs: Additional arguments (e.g., tool_call_id for tool
                messages). Only forwarded when `role` is "tool".

        Raises:
            ValueError: If the role is unsupported.
        """
        factories = {
            "user": Message.user_message,
            "system": Message.system_message,
            "assistant": Message.assistant_message,
            "tool": Message.tool_message,
        }

        if role not in factories:
            raise ValueError(f"Unsupported message role: {role}")

        # Extra keyword arguments are only meaningful for tool messages.
        factory_kwargs = dict(kwargs) if role == "tool" else {}
        # Only forward the image when one was supplied, so a factory is not
        # required to accept `base64_image` just to receive None.
        # (The factory signatures are defined outside this file.)
        if base64_image is not None:
            factory_kwargs["base64_image"] = base64_image

        self.memory.add_message(factories[role](content, **factory_kwargs))

    async def run(self, request: Optional[str] = None) -> str:
        """Execute the agent's main loop asynchronously.

        Each call starts with a fresh step counter. Steps are executed until
        `max_steps` is reached or the state becomes `AgentState.FINISHED`.
        The sandbox client is cleaned up when the loop ends, including when a
        step raises.

        Args:
            request: Optional initial user request to process.

        Returns:
            A string summarizing the execution results, one line per step,
            or "No steps executed" when nothing was recorded.

        Raises:
            RuntimeError: If the agent is not in IDLE state at start.
        """
        if self.state != AgentState.IDLE:
            raise RuntimeError(f"Cannot run agent from state: {self.state}")

        # A previous run that finished before max_steps leaves its counter
        # behind; start every run with the full step budget.
        self.current_step = 0

        if request:
            self.update_memory("user", request)

        results: List[str] = []
        try:
            async with self.state_context(AgentState.RUNNING):
                while (
                    self.current_step < self.max_steps
                    and self.state != AgentState.FINISHED
                ):
                    self.current_step += 1
                    logger.info(f"Executing step {self.current_step}/{self.max_steps}")
                    step_result = await self.step()

                    # Check for stuck state
                    if self.is_stuck():
                        self.handle_stuck_state()

                    results.append(f"Step {self.current_step}: {step_result}")

                # Report termination only when the budget ran out without the
                # agent finishing; finishing on the last allowed step is a
                # normal completion.
                if (
                    self.current_step >= self.max_steps
                    and self.state != AgentState.FINISHED
                ):
                    self.current_step = 0
                    self.state = AgentState.IDLE
                    results.append(f"Terminated: Reached max steps ({self.max_steps})")
        finally:
            # Release sandbox resources even when a step raised.
            await SANDBOX_CLIENT.cleanup()

        return "\n".join(results) if results else "No steps executed"

    @abstractmethod
    async def step(self) -> str:
        """Execute a single step in the agent's workflow.

        Must be implemented by subclasses to define specific behavior.

        Returns:
            A string describing the step's result; `run` includes it in its
            summary.
        """

    def handle_stuck_state(self):
        """Handle stuck state by adding a prompt to change strategy.

        The hint is placed in front of `next_step_prompt`. It is not added
        again when the prompt already starts with it, and a missing prompt is
        replaced by the hint alone.
        """
        stuck_prompt = (
            "Observed duplicate responses. Consider new strategies and avoid "
            "repeating ineffective paths already attempted."
        )
        existing_prompt = self.next_step_prompt or ""
        # Avoid stacking the same hint on every consecutive stuck detection.
        if not existing_prompt.startswith(stuck_prompt):
            # Do not interpolate None/empty into the prompt text.
            self.next_step_prompt = (
                f"{stuck_prompt}\n{existing_prompt}"
                if existing_prompt
                else stuck_prompt
            )
        logger.warning(f"Agent detected stuck state. Added prompt: {stuck_prompt}")

    def is_stuck(self) -> bool:
        """Check if the agent is stuck in a loop by detecting duplicate content.

        Returns:
            True when at least `duplicate_threshold` earlier assistant
            messages have the same non-empty content as the latest message;
            False otherwise, including when memory holds fewer than two
            messages.
        """
        history = self.memory.messages
        if len(history) < 2:
            return False

        latest_content = history[-1].content
        if not latest_content:
            return False

        # Count earlier assistant messages repeating the latest content.
        duplicate_count = sum(
            1
            for msg in history[:-1]
            if msg.role == "assistant" and msg.content == latest_content
        )

        return duplicate_count >= self.duplicate_threshold

    @property
    def messages(self) -> List[Message]:
        """Retrieve a list of messages from the agent's memory."""
        return self.memory.messages

    @messages.setter
    def messages(self, value: List[Message]):
        """Set the list of messages in the agent's memory."""
        self.memory.messages = value
|
|