| import asyncio |
| import json |
| from typing import Any, List, Optional, Union |
|
|
| from pydantic import Field |
|
|
| from app.agent.react import ReActAgent |
| from app.exceptions import TokenLimitExceeded |
| from app.logger import logger |
| from app.prompt.toolcall import NEXT_STEP_PROMPT, SYSTEM_PROMPT |
| from app.schema import TOOL_CHOICE_TYPE, AgentState, Message, ToolCall, ToolChoice |
| from app.tool import CreateChatCompletion, Terminate, ToolCollection |
|
|
|
|
# Message of the ValueError raised by ToolCallAgent.act() when the tool choice
# is REQUIRED but there are no tool calls to execute.
TOOL_CALL_REQUIRED: str = "Tool calls required but none provided"
|
|
|
|
class ToolCallAgent(ReActAgent):
    """Agent that asks the LLM for tool calls and then executes them.

    ``think`` sends the conversation and the schemas of ``available_tools`` to
    the LLM and stores the tool calls it returns; ``act`` runs those calls and
    records every result in memory as a tool message. ``run`` wraps the parent
    ``run`` so that tool resources are always cleaned up afterwards.
    """

    # NOTE(review): the leading symbols in the log messages of this class look
    # like mis-encoded emoji. They are runtime strings and are kept unchanged
    # here - confirm the file's encoding before repairing them.

    name: str = "toolcall"
    description: str = "an agent that can execute tool calls."

    system_prompt: str = SYSTEM_PROMPT
    next_step_prompt: str = NEXT_STEP_PROMPT

    # Tools whose schemas are offered to the LLM on every ``think`` call.
    available_tools: ToolCollection = ToolCollection(
        CreateChatCompletion(), Terminate()
    )
    tool_choices: TOOL_CHOICE_TYPE = ToolChoice.AUTO
    # Tools that move the agent to FINISHED once executed (names are compared
    # case-insensitively).
    special_tool_names: List[str] = Field(default_factory=lambda: [Terminate().name])

    # Tool calls selected by the most recent ``think``; consumed by ``act``.
    tool_calls: List[ToolCall] = Field(default_factory=list)
    # Image attached to the result of the tool call currently being processed.
    _current_base64_image: Optional[str] = None

    max_steps: int = 30
    # Maximum number of characters of each tool observation kept by ``act``;
    # a falsy value disables truncation.
    max_observe: Optional[Union[int, bool]] = None

    async def think(self) -> bool:
        """Ask the LLM for the next step and record its answer in memory.

        Returns:
            ``True`` when ``act`` should run next, ``False`` otherwise:

            * tool choice NONE: ``True`` only if the response has text content;
            * tool choice REQUIRED without tool calls: ``True`` (``act`` then
              raises ``ValueError``);
            * tool choice AUTO without tool calls: ``True`` only if the
              response has text content;
            * otherwise: ``True`` if at least one tool call was returned.

            ``False`` is also returned when the token limit was exceeded (the
            agent state is set to FINISHED) or when processing the response
            failed.

        Raises:
            ValueError: re-raised unchanged from the LLM call.
            Exception: any other LLM-call error that was not caused by
                ``TokenLimitExceeded``.
        """
        if self.next_step_prompt:
            user_msg = Message.user_message(self.next_step_prompt)
            self.messages += [user_msg]

        try:
            system_msgs = (
                [Message.system_message(self.system_prompt)]
                if self.system_prompt
                else None
            )
            response = await self.llm.ask_tool(
                messages=self.messages,
                system_msgs=system_msgs,
                tools=self.available_tools.to_params(),
                tool_choice=self.tool_choices,
            )
        except ValueError:
            raise
        except Exception as e:
            # The token-limit error arrives wrapped (as ``__cause__``) in the
            # exception raised by the LLM call; every exception has a
            # ``__cause__`` attribute, so no hasattr() check is needed.
            cause = e.__cause__
            if isinstance(cause, TokenLimitExceeded):
                logger.error(f"π¨ Token limit error (from RetryError): {cause}")
                self.memory.add_message(
                    Message.assistant_message(
                        f"Maximum token limit reached, cannot continue execution: {str(cause)}"
                    )
                )
                self.state = AgentState.FINISHED
                return False
            raise

        # Normalise a missing response / missing fields to empty values.
        self.tool_calls = tool_calls = (
            response.tool_calls if response and response.tool_calls else []
        )
        content = response.content if response and response.content else ""

        logger.info(f"β¨ {self.name}'s thoughts: {content}")
        logger.info(f"π οΈ {self.name} selected {len(tool_calls)} tools to use")
        if tool_calls:
            logger.info(
                f"π§° Tools being prepared: {[call.function.name for call in tool_calls]}"
            )
            # Only the first call's arguments are logged.
            logger.info(f"π§ Tool arguments: {tool_calls[0].function.arguments}")

        try:
            if response is None:
                raise RuntimeError("No response received from the LLM")

            if self.tool_choices == ToolChoice.NONE:
                # Tools are disabled: keep only the text part of the answer.
                if tool_calls:
                    logger.warning(
                        f"π€ Hmm, {self.name} tried to use tools when they weren't available!"
                    )
                if content:
                    self.memory.add_message(Message.assistant_message(content))
                    return True
                return False

            if self.tool_calls:
                assistant_msg = Message.from_tool_calls(
                    content=content, tool_calls=self.tool_calls
                )
            else:
                assistant_msg = Message.assistant_message(content)
            self.memory.add_message(assistant_msg)

            if self.tool_choices == ToolChoice.REQUIRED and not self.tool_calls:
                # Let act() report the missing tool calls.
                return True

            if self.tool_choices == ToolChoice.AUTO and not self.tool_calls:
                # Without tool calls, continue only if there is text content.
                return bool(content)

            return bool(self.tool_calls)
        except Exception as e:
            logger.error(f"π¨ Oops! The {self.name}'s thinking process hit a snag: {e}")
            self.memory.add_message(
                Message.assistant_message(
                    f"Error encountered while processing: {str(e)}"
                )
            )
            return False

    async def act(self) -> str:
        """Execute the pending tool calls and store their results in memory.

        Returns:
            The observations of all executed tool calls joined by blank lines.
            When there are no tool calls, the content of the latest message,
            or a fixed fallback text if that is empty or no message exists.

        Raises:
            ValueError: if there are no tool calls while the tool choice is
                REQUIRED.
        """
        if not self.tool_calls:
            if self.tool_choices == ToolChoice.REQUIRED:
                raise ValueError(TOOL_CALL_REQUIRED)

            # Guard against an empty history so this path cannot raise
            # IndexError.
            history = self.messages
            last_content = history[-1].content if history else None
            return last_content or "No content or commands to execute"

        results = []
        for command in self.tool_calls:
            # Reset so an image from a previous call is never attached to
            # this call's message.
            self._current_base64_image = None

            result = await self.execute_tool(command)

            # bool is a subclass of int, so ``result[:True]`` would cut the
            # observation down to a single character. Only genuine integer
            # limits are applied.
            # NOTE(review): the intended meaning of a boolean max_observe is
            # not visible here - confirm with the owners.
            limit = self.max_observe
            if limit and not isinstance(limit, bool):
                result = result[:limit]

            logger.info(
                f"π― Tool '{command.function.name}' completed its mission! Result: {result}"
            )

            tool_msg = Message.tool_message(
                content=result,
                tool_call_id=command.id,
                name=command.function.name,
                base64_image=self._current_base64_image,
            )
            self.memory.add_message(tool_msg)
            results.append(result)

        return "\n\n".join(results)

    async def execute_tool(self, command: ToolCall) -> str:
        """Run one tool call and turn its outcome into an observation string.

        Failures are never raised to the caller: an invalid command, an
        unknown tool, arguments that are not valid JSON and any exception
        raised while running the tool are all reported as an ``"Error: ..."``
        string.

        Args:
            command: The tool call to execute; its ``function.arguments`` is
                parsed as JSON (an empty value is treated as ``{}``).

        Returns:
            The observation text, or an error description.
        """
        if not command or not command.function or not command.function.name:
            return "Error: Invalid command format"

        name = command.function.name
        if name not in self.available_tools.tool_map:
            return f"Error: Unknown tool '{name}'"

        try:
            args = json.loads(command.function.arguments or "{}")

            logger.info(f"π§ Activating tool: '{name}'...")
            result = await self.available_tools.execute(name=name, tool_input=args)

            # May switch the agent to FINISHED (e.g. for the terminate tool).
            await self._handle_special_tool(name=name, result=result)

            # Remember an attached image so act() can add it to the tool message.
            if hasattr(result, "base64_image") and result.base64_image:
                self._current_base64_image = result.base64_image

            if result:
                return f"Observed output of cmd `{name}` executed:\n{str(result)}"
            return f"Cmd `{name}` completed with no output"
        except json.JSONDecodeError:
            error_msg = f"Error parsing arguments for {name}: Invalid JSON format"
            logger.error(
                f"π Oops! The arguments for '{name}' don't make sense - invalid JSON, arguments:{command.function.arguments}"
            )
            return f"Error: {error_msg}"
        except Exception as e:
            error_msg = f"β οΈ Tool '{name}' encountered a problem: {str(e)}"
            logger.exception(error_msg)
            return f"Error: {error_msg}"

    async def _handle_special_tool(self, name: str, result: Any, **kwargs):
        """Mark the agent as FINISHED when a special tool asks for it.

        Does nothing for tools that are not listed in ``special_tool_names``.
        """
        if not self._is_special_tool(name):
            return

        if self._should_finish_execution(name=name, result=result, **kwargs):
            logger.info(f"π Special tool '{name}' has completed the task!")
            self.state = AgentState.FINISHED

    @staticmethod
    def _should_finish_execution(**kwargs) -> bool:
        """Decide whether a special tool ends the run; always ``True`` here."""
        return True

    def _is_special_tool(self, name: str) -> bool:
        """Return whether ``name`` matches a special tool, ignoring case."""
        wanted = name.lower()
        return any(wanted == special.lower() for special in self.special_tool_names)

    async def cleanup(self):
        """Await the ``cleanup`` coroutine method of every tool that has one.

        Errors raised by an individual tool are logged and do not stop the
        cleanup of the remaining tools.
        """
        # Function-scope import: asyncio.iscoroutinefunction is deprecated
        # since Python 3.14 in favour of inspect.iscoroutinefunction.
        import inspect

        logger.info(f"π§Ή Cleaning up resources for agent '{self.name}'...")
        for tool_name, tool_instance in self.available_tools.tool_map.items():
            tool_cleanup = getattr(tool_instance, "cleanup", None)
            # Tools without a coroutine-function ``cleanup`` are skipped.
            if tool_cleanup is None or not inspect.iscoroutinefunction(tool_cleanup):
                continue
            try:
                logger.debug(f"π§Ό Cleaning up tool: {tool_name}")
                await tool_cleanup()
            except Exception as e:
                logger.error(
                    f"π¨ Error cleaning up tool '{tool_name}': {e}", exc_info=True
                )
        logger.info(f"β¨ Cleanup complete for agent '{self.name}'.")

    async def run(self, request: Optional[str] = None) -> str:
        """Run the agent via the parent ``run`` and always clean up afterwards."""
        try:
            return await super().run(request)
        finally:
            await self.cleanup()
|
|