import asyncio import inspect import json import traceback import typing as T import uuid from collections.abc import Sequence from collections.abc import Set as AbstractSet import mcp from astrbot import logger from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.message import Message from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.agent.tool_executor import BaseFunctionToolExecutor from astrbot.core.astr_agent_context import AstrAgentContext from astrbot.core.astr_main_agent_resources import ( BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT, EXECUTE_SHELL_TOOL, FILE_DOWNLOAD_TOOL, FILE_UPLOAD_TOOL, LOCAL_EXECUTE_SHELL_TOOL, LOCAL_PYTHON_TOOL, PYTHON_TOOL, SEND_MESSAGE_TO_USER_TOOL, ) from astrbot.core.cron.events import CronMessageEvent from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( CommandResult, MessageChain, MessageEventResult, ) from astrbot.core.platform.message_session import MessageSession from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history from astrbot.core.utils.image_ref_utils import is_supported_image_ref from astrbot.core.utils.string_utils import normalize_and_dedupe_strings class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): @classmethod def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: if image_urls_raw is None: return [] if isinstance(image_urls_raw, str): return [image_urls_raw] if isinstance(image_urls_raw, (Sequence, AbstractSet)) and not isinstance( image_urls_raw, (str, bytes, bytearray) ): return [item for item in image_urls_raw if isinstance(item, str)] logger.debug( "Unsupported image_urls type in handoff tool args: %s", type(image_urls_raw).__name__, ) return [] @classmethod async def _collect_image_urls_from_message( cls, run_context: ContextWrapper[AstrAgentContext] ) -> list[str]: urls: list[str] = [] event = getattr(run_context.context, "event", None) message_obj = getattr(event, "message_obj", None) message = getattr(message_obj, "message", None) if message: for idx, component in enumerate(message): if not isinstance(component, Image): continue try: path = await component.convert_to_file_path() if path: urls.append(path) except Exception as e: logger.error( "Failed to convert handoff image component at index %d: %s", idx, e, exc_info=True, ) return urls @classmethod async def _collect_handoff_image_urls( cls, run_context: ContextWrapper[AstrAgentContext], image_urls_raw: T.Any, ) -> list[str]: candidates: list[str] = [] candidates.extend(cls._collect_image_urls_from_args(image_urls_raw)) candidates.extend(await cls._collect_image_urls_from_message(run_context)) normalized = normalize_and_dedupe_strings(candidates) extensionless_local_roots = (get_astrbot_temp_path(),) sanitized = [ item for item in normalized if is_supported_image_ref( item, allow_extensionless_existing_local_file=True, extensionless_local_roots=extensionless_local_roots, ) ] dropped_count = len(normalized) - len(sanitized) if dropped_count > 0: logger.debug( "Dropped %d invalid image_urls entries in handoff image inputs.", dropped_count, ) return sanitized @classmethod async def execute(cls, tool, run_context, **tool_args): """执行函数调用。 Args: event (AstrMessageEvent): 事件对象, 当 origin 为 local 时必须提供。 **kwargs: 函数调用的参数。 Returns: AsyncGenerator[None | mcp.types.CallToolResult, None] """ if isinstance(tool, HandoffTool): is_bg = tool_args.pop("background_task", False) if is_bg: async for r in cls._execute_handoff_background( tool, run_context, **tool_args ): yield r return async for r in cls._execute_handoff(tool, run_context, **tool_args): yield r return elif isinstance(tool, MCPTool): async for r in cls._execute_mcp(tool, run_context, **tool_args): yield r return elif tool.is_background_task: task_id = uuid.uuid4().hex async def _run_in_background() -> None: try: await cls._execute_background( tool=tool, run_context=run_context, task_id=task_id, **tool_args, ) except Exception as e: # noqa: BLE001 logger.error( f"Background task {task_id} failed: {e!s}", exc_info=True, ) asyncio.create_task(_run_in_background()) text_content = mcp.types.TextContent( type="text", text=f"Background task submitted. task_id={task_id}", ) yield mcp.types.CallToolResult(content=[text_content]) return else: async for r in cls._execute_local(tool, run_context, **tool_args): yield r return @classmethod def _get_runtime_computer_tools(cls, runtime: str) -> dict[str, FunctionTool]: if runtime == "sandbox": return { EXECUTE_SHELL_TOOL.name: EXECUTE_SHELL_TOOL, PYTHON_TOOL.name: PYTHON_TOOL, FILE_UPLOAD_TOOL.name: FILE_UPLOAD_TOOL, FILE_DOWNLOAD_TOOL.name: FILE_DOWNLOAD_TOOL, } if runtime == "local": return { LOCAL_EXECUTE_SHELL_TOOL.name: LOCAL_EXECUTE_SHELL_TOOL, LOCAL_PYTHON_TOOL.name: LOCAL_PYTHON_TOOL, } return {} @classmethod def _build_handoff_toolset( cls, run_context: ContextWrapper[AstrAgentContext], tools: list[str | FunctionTool] | None, ) -> ToolSet | None: ctx = run_context.context.context event = run_context.context.event cfg = ctx.get_config(umo=event.unified_msg_origin) provider_settings = cfg.get("provider_settings", {}) runtime = str(provider_settings.get("computer_use_runtime", "local")) runtime_computer_tools = cls._get_runtime_computer_tools(runtime) # Keep persona semantics aligned with the main agent: tools=None means # "all tools", including runtime computer-use tools. if tools is None: toolset = ToolSet() for registered_tool in llm_tools.func_list: if isinstance(registered_tool, HandoffTool): continue if registered_tool.active: toolset.add_tool(registered_tool) for runtime_tool in runtime_computer_tools.values(): toolset.add_tool(runtime_tool) return None if toolset.empty() else toolset if not tools: return None toolset = ToolSet() for tool_name_or_obj in tools: if isinstance(tool_name_or_obj, str): registered_tool = llm_tools.get_func(tool_name_or_obj) if registered_tool and registered_tool.active: toolset.add_tool(registered_tool) continue runtime_tool = runtime_computer_tools.get(tool_name_or_obj) if runtime_tool: toolset.add_tool(runtime_tool) elif isinstance(tool_name_or_obj, FunctionTool): toolset.add_tool(tool_name_or_obj) return None if toolset.empty() else toolset @classmethod async def _execute_handoff( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], *, image_urls_prepared: bool = False, **tool_args: T.Any, ): tool_args = dict(tool_args) input_ = tool_args.get("input") if image_urls_prepared: prepared_image_urls = tool_args.get("image_urls") if isinstance(prepared_image_urls, list): image_urls = prepared_image_urls else: logger.debug( "Expected prepared handoff image_urls as list[str], got %s.", type(prepared_image_urls).__name__, ) image_urls = [] else: image_urls = await cls._collect_handoff_image_urls( run_context, tool_args.get("image_urls"), ) tool_args["image_urls"] = image_urls # Build handoff toolset from registered tools plus runtime computer tools. toolset = cls._build_handoff_toolset(run_context, tool.agent.tools) ctx = run_context.context.context event = run_context.context.event umo = event.unified_msg_origin # Use per-subagent provider override if configured; otherwise fall back # to the current/default provider resolution. prov_id = getattr( tool, "provider_id", None ) or await ctx.get_current_chat_provider_id(umo) # prepare begin dialogs contexts = None dialogs = tool.agent.begin_dialogs if dialogs: contexts = [] for dialog in dialogs: try: contexts.append( dialog if isinstance(dialog, Message) else Message.model_validate(dialog) ) except Exception: continue prov_settings: dict = ctx.get_config(umo=umo).get("provider_settings", {}) agent_max_step = int(prov_settings.get("max_agent_step", 30)) stream = prov_settings.get("streaming_response", False) llm_resp = await ctx.tool_loop_agent( event=event, chat_provider_id=prov_id, prompt=input_, image_urls=image_urls, system_prompt=tool.agent.instructions, tools=toolset, contexts=contexts, max_steps=agent_max_step, stream=stream, ) yield mcp.types.CallToolResult( content=[mcp.types.TextContent(type="text", text=llm_resp.completion_text)] ) @classmethod async def _execute_handoff_background( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], **tool_args, ): """Execute a handoff as a background task. Immediately yields a success response with a task_id, then runs the subagent asynchronously. When the subagent finishes, a ``CronMessageEvent`` is created so the main LLM can inform the user of the result – the same pattern used by ``_execute_background`` for regular background tasks. """ task_id = uuid.uuid4().hex async def _run_handoff_in_background() -> None: try: await cls._do_handoff_background( tool=tool, run_context=run_context, task_id=task_id, **tool_args, ) except Exception as e: # noqa: BLE001 logger.error( f"Background handoff {task_id} ({tool.name}) failed: {e!s}", exc_info=True, ) asyncio.create_task(_run_handoff_in_background()) text_content = mcp.types.TextContent( type="text", text=( f"Background task dedicated to subagent '{tool.agent.name}' submitted. task_id={task_id}. " f"The subagent '{tool.agent.name}' is working on the task on hehalf you. " f"You will be notified when it finishes." ), ) yield mcp.types.CallToolResult(content=[text_content]) @classmethod async def _do_handoff_background( cls, tool: HandoffTool, run_context: ContextWrapper[AstrAgentContext], task_id: str, **tool_args, ) -> None: """Run the subagent handoff and, on completion, wake the main agent.""" result_text = "" tool_args = dict(tool_args) tool_args["image_urls"] = await cls._collect_handoff_image_urls( run_context, tool_args.get("image_urls"), ) try: async for r in cls._execute_handoff( tool, run_context, image_urls_prepared=True, **tool_args, ): if isinstance(r, mcp.types.CallToolResult): for content in r.content: if isinstance(content, mcp.types.TextContent): result_text += content.text + "\n" except Exception as e: result_text = ( f"error: Background task execution failed, internal error: {e!s}" ) event = run_context.context.event await cls._wake_main_agent_for_background_result( run_context=run_context, task_id=task_id, tool_name=tool.name, result_text=result_text, tool_args=tool_args, note=( event.get_extra("background_note") or f"Background task for subagent '{tool.agent.name}' finished." ), summary_name=f"Dedicated to subagent `{tool.agent.name}`", extra_result_fields={"subagent_name": tool.agent.name}, ) @classmethod async def _execute_background( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], task_id: str, **tool_args, ) -> None: # run the tool result_text = "" try: async for r in cls._execute_local( tool, run_context, tool_call_timeout=3600, **tool_args ): # collect results, currently we just collect the text results if isinstance(r, mcp.types.CallToolResult): result_text = "" for content in r.content: if isinstance(content, mcp.types.TextContent): result_text += content.text + "\n" except Exception as e: result_text = ( f"error: Background task execution failed, internal error: {e!s}" ) event = run_context.context.event await cls._wake_main_agent_for_background_result( run_context=run_context, task_id=task_id, tool_name=tool.name, result_text=result_text, tool_args=tool_args, note=( event.get_extra("background_note") or f"Background task {tool.name} finished." ), summary_name=tool.name, ) @classmethod async def _wake_main_agent_for_background_result( cls, run_context: ContextWrapper[AstrAgentContext], *, task_id: str, tool_name: str, result_text: str, tool_args: dict[str, T.Any], note: str, summary_name: str, extra_result_fields: dict[str, T.Any] | None = None, ) -> None: from astrbot.core.astr_main_agent import ( MainAgentBuildConfig, _get_session_conv, build_main_agent, ) event = run_context.context.event ctx = run_context.context.context task_result = { "task_id": task_id, "tool_name": tool_name, "result": result_text or "", "tool_args": tool_args, } if extra_result_fields: task_result.update(extra_result_fields) extras = {"background_task_result": task_result} session = MessageSession.from_str(event.unified_msg_origin) cron_event = CronMessageEvent( context=ctx, session=session, message=note, extras=extras, message_type=session.message_type, ) cron_event.role = event.role config = MainAgentBuildConfig( tool_call_timeout=3600, streaming_response=ctx.get_config() .get("provider_settings", {}) .get("stream", False), ) req = ProviderRequest() conv = await _get_session_conv(event=cron_event, plugin_context=ctx) req.conversation = conv context = json.loads(conv.history) if context: req.contexts = context context_dump = req._print_friendly_context() req.contexts = [] req.system_prompt += ( "\n\nBellow is you and user previous conversation history:\n" f"{context_dump}" ) bg = json.dumps(extras["background_task_result"], ensure_ascii=False) req.system_prompt += BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT.format( background_task_result=bg ) req.prompt = ( "Proceed according to your system instructions. " "Output using same language as previous conversation. " "If you need to deliver the result to the user immediately, " "you MUST use `send_message_to_user` tool to send the message directly to the user, " "otherwise the user will not see the result. " "After completing your task, summarize and output your actions and results. " ) if not req.func_tool: req.func_tool = ToolSet() req.func_tool.add_tool(SEND_MESSAGE_TO_USER_TOOL) result = await build_main_agent( event=cron_event, plugin_context=ctx, config=config, req=req ) if not result: logger.error(f"Failed to build main agent for background task {tool_name}.") return runner = result.agent_runner async for _ in runner.step_until_done(30): # agent will send message to user via using tools pass llm_resp = runner.get_final_llm_resp() task_meta = extras.get("background_task_result", {}) summary_note = ( f"[BackgroundTask] {summary_name} " f"(task_id={task_meta.get('task_id', task_id)}) finished. " f"Result: {task_meta.get('result') or result_text or 'no content'}" ) if llm_resp and llm_resp.completion_text: summary_note += ( f"I finished the task, here is the result: {llm_resp.completion_text}" ) await persist_agent_history( ctx.conversation_manager, event=cron_event, req=req, summary_note=summary_note, ) if not llm_resp: logger.warning("background task agent got no response") return @classmethod async def _execute_local( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], *, tool_call_timeout: int | None = None, **tool_args, ): event = run_context.context.event if not event: raise ValueError("Event must be provided for local function tools.") is_override_call = False for ty in type(tool).mro(): if "call" in ty.__dict__ and ty.__dict__["call"] is not FunctionTool.call: is_override_call = True break # 检查 tool 下有没有 run 方法 if not tool.handler and not hasattr(tool, "run") and not is_override_call: raise ValueError("Tool must have a valid handler or override 'run' method.") awaitable = None method_name = "" if tool.handler: awaitable = tool.handler method_name = "decorator_handler" elif is_override_call: awaitable = tool.call method_name = "call" elif hasattr(tool, "run"): awaitable = getattr(tool, "run") method_name = "run" if awaitable is None: raise ValueError("Tool must have a valid handler or override 'run' method.") wrapper = call_local_llm_tool( context=run_context, handler=awaitable, method_name=method_name, **tool_args, ) while True: try: resp = await asyncio.wait_for( anext(wrapper), timeout=tool_call_timeout or run_context.tool_call_timeout, ) if resp is not None: if isinstance(resp, mcp.types.CallToolResult): yield resp else: text_content = mcp.types.TextContent( type="text", text=str(resp), ) yield mcp.types.CallToolResult(content=[text_content]) else: # NOTE: Tool 在这里直接请求发送消息给用户 # TODO: 是否需要判断 event.get_result() 是否为空? # 如果为空,则说明没有发送消息给用户,并且返回值为空,将返回一个特殊的 TextContent,其内容如"工具没有返回内容" if res := run_context.context.event.get_result(): if res.chain: try: await event.send( MessageChain( chain=res.chain, type="tool_direct_result", ) ) except Exception as e: logger.error( f"Tool 直接发送消息失败: {e}", exc_info=True, ) yield None except asyncio.TimeoutError: raise Exception( f"tool {tool.name} execution timeout after {tool_call_timeout or run_context.tool_call_timeout} seconds.", ) except StopAsyncIteration: break @classmethod async def _execute_mcp( cls, tool: FunctionTool, run_context: ContextWrapper[AstrAgentContext], **tool_args, ): res = await tool.call(run_context, **tool_args) if not res: return yield res async def call_local_llm_tool( context: ContextWrapper[AstrAgentContext], handler: T.Callable[ ..., T.Awaitable[MessageEventResult | mcp.types.CallToolResult | str | None] | T.AsyncGenerator[MessageEventResult | CommandResult | str | None, None], ], method_name: str, *args, **kwargs, ) -> T.AsyncGenerator[T.Any, None]: """执行本地 LLM 工具的处理函数并处理其返回结果""" ready_to_call = None # 一个协程或者异步生成器 trace_ = None event = context.context.event try: if method_name == "run" or method_name == "decorator_handler": ready_to_call = handler(event, *args, **kwargs) elif method_name == "call": ready_to_call = handler(context, *args, **kwargs) else: raise ValueError(f"未知的方法名: {method_name}") except ValueError as e: raise Exception(f"Tool execution ValueError: {e}") from e except TypeError as e: # 获取函数的签名(包括类型),除了第一个 event/context 参数。 try: sig = inspect.signature(handler) params = list(sig.parameters.values()) # 跳过第一个参数(event 或 context) if params: params = params[1:] param_strs = [] for param in params: param_str = param.name if param.annotation != inspect.Parameter.empty: # 获取类型注解的字符串表示 if isinstance(param.annotation, type): type_str = param.annotation.__name__ else: type_str = str(param.annotation) param_str += f": {type_str}" if param.default != inspect.Parameter.empty: param_str += f" = {param.default!r}" param_strs.append(param_str) handler_param_str = ( ", ".join(param_strs) if param_strs else "(no additional parameters)" ) except Exception: handler_param_str = "(unable to inspect signature)" raise Exception( f"Tool handler parameter mismatch, please check the handler definition. Handler parameters: {handler_param_str}" ) from e except Exception as e: trace_ = traceback.format_exc() raise Exception(f"Tool execution error: {e}. Traceback: {trace_}") from e if not ready_to_call: return if inspect.isasyncgen(ready_to_call): _has_yielded = False try: async for ret in ready_to_call: # 这里逐步执行异步生成器, 对于每个 yield 返回的 ret, 执行下面的代码 # 返回值只能是 MessageEventResult 或者 None(无返回值) _has_yielded = True if isinstance(ret, MessageEventResult | CommandResult): # 如果返回值是 MessageEventResult, 设置结果并继续 event.set_result(ret) yield else: # 如果返回值是 None, 则不设置结果并继续 # 继续执行后续阶段 yield ret if not _has_yielded: # 如果这个异步生成器没有执行到 yield 分支 yield except Exception as e: logger.error(f"Previous Error: {trace_}") raise e elif inspect.iscoroutine(ready_to_call): # 如果只是一个协程, 直接执行 ret = await ready_to_call if isinstance(ret, MessageEventResult | CommandResult): event.set_result(ret) yield else: yield ret