| from __future__ import annotations |
|
|
| import re |
| from collections.abc import AsyncGenerator, Awaitable, Callable |
| from typing import TYPE_CHECKING, Any |
|
|
| import docstring_parser |
|
|
| from astrbot.core import logger |
| from astrbot.core.agent.agent import Agent |
| from astrbot.core.agent.handoff import HandoffTool |
| from astrbot.core.agent.hooks import BaseAgentRunHooks |
| from astrbot.core.agent.tool import FunctionTool |
| from astrbot.core.message.message_event_result import MessageEventResult |
| from astrbot.core.provider.func_tool_manager import PY_TO_JSON_TYPE, SUPPORTED_TYPES |
| from astrbot.core.provider.register import llm_tools |
|
|
| if TYPE_CHECKING: |
| from astrbot.core.astr_agent_context import AstrAgentContext |
|
|
| from ..filter.command import CommandFilter |
| from ..filter.command_group import CommandGroupFilter |
| from ..filter.custom_filter import CustomFilterAnd, CustomFilterOr |
| from ..filter.event_message_type import EventMessageType, EventMessageTypeFilter |
| from ..filter.permission import PermissionType, PermissionTypeFilter |
| from ..filter.platform_adapter_type import ( |
| PlatformAdapterType, |
| PlatformAdapterTypeFilter, |
| ) |
| from ..filter.regex import RegexFilter |
| from ..star_handler import EventType, StarHandlerMetadata, star_handlers_registry |
|
|
|
|
def get_handler_full_name(
    awaitable: Callable[..., Awaitable[Any] | AsyncGenerator[Any]],
) -> str:
    """Return the full name of a handler.

    The full name is the callable's ``__module__`` and ``__name__`` joined
    with an underscore; it is the key used to look handlers up in the
    handler registry.
    """
    module_path = awaitable.__module__
    func_name = awaitable.__name__
    return "{}_{}".format(module_path, func_name)
|
|
|
|
def get_handler_or_create(
    handler: Callable[
        ...,
        Awaitable[MessageEventResult | str | None]
        | AsyncGenerator[MessageEventResult | str | None],
    ],
    event_type: EventType,
    dont_add=False,
    **kwargs,
) -> StarHandlerMetadata:
    """Return the metadata registered for ``handler``, creating it if absent.

    Args:
        handler: The handler callable.
        event_type: Event type stored on newly created metadata.
        dont_add: When true, newly created metadata is NOT appended to
            ``star_handlers_registry``.
        **kwargs: Extra configuration. ``desc`` (if present) becomes the
            handler description; everything else is stored as
            ``extras_configs``.

    Returns:
        The existing metadata if the registry already knows the handler's
        full name (in which case ``event_type`` and ``kwargs`` are ignored),
        otherwise the freshly created metadata.
    """
    full_name = get_handler_full_name(handler)
    existing = star_handlers_registry.get_handler_by_full_name(full_name)
    if existing:
        return existing

    metadata = StarHandlerMetadata(
        event_type=event_type,
        handler_full_name=full_name,
        handler_name=handler.__name__,
        handler_module_path=handler.__module__,
        handler=handler,
        event_filters=[],
    )

    # Description: the handler's docstring by default, overridden by an
    # explicit ``desc`` keyword (which is removed from the extras).
    if handler.__doc__:
        metadata.desc = handler.__doc__.strip()
    if "desc" in kwargs:
        metadata.desc = kwargs.pop("desc")
    metadata.extras_configs = kwargs

    if not dont_add:
        star_handlers_registry.append(metadata)
    return metadata
|
|
|
|
def register_command(
    command_name: str | None = None,
    sub_command: str | None = None,
    alias: set | None = None,
    **kwargs,
):
    """Register a command.

    Args:
        command_name: The command name for a bare (top-level) command. When
            this function is invoked as ``RegisteringCommandable.command``,
            this slot receives the ``RegisteringCommandable`` instance and
            the command is registered as a sub-command of its group.
        sub_command: Name of the sub-command; required in the group form.
        alias: Optional aliases passed to ``CommandFilter``.
        **kwargs: Forwarded to ``get_handler_or_create``.

    Returns:
        A decorator that registers the handler and returns it unchanged.
    """
    command_filter = None
    is_bare_command = False

    if isinstance(command_name, RegisteringCommandable):
        # Sub-command registered through a command group.
        if sub_command is None:
            logger.warning(
                f"注册指令{command_name} 的子指令时未提供 sub_command 参数。",
            )
        else:
            group = command_name.parent_group
            command_filter = CommandFilter(
                sub_command,
                alias,
                None,
                parent_command_names=group.get_complete_command_names(),
            )
            group.add_sub_command_filter(command_filter)
    elif command_name is None:
        # Bare command without a name: nothing to build, only warn.
        logger.warning("注册裸指令时未提供 command_name 参数。")
    else:
        # Bare command.
        command_filter = CommandFilter(command_name, alias, None)
        is_bare_command = True

    def bind_command(awaitable):
        # Everything that is not a successfully built bare command is
        # flagged as a sub-command in the handler's extra configuration.
        if not is_bare_command:
            kwargs["sub_command"] = True
        metadata = get_handler_or_create(
            awaitable,
            EventType.AdapterMessageEvent,
            **kwargs,
        )
        if command_filter:
            command_filter.init_handler_md(metadata)
            metadata.event_filters.append(command_filter)
        return awaitable

    return bind_command
|
|
|
|
def register_custom_filter(custom_type_filter, *args, **kwargs):
    """Register a custom filter (CustomFilter).

    Args:
        custom_type_filter: For a bare command, the custom filter itself.
            Inside a command group, the parent command's
            ``RegisteringCommandable`` object (i.e. ``self`` or the return
            value of ``command_group``); the filter is then ``args[0]``.
        raise_error: Positional, optional (taken from ``*args``). If the
            check fails, whether to raise the error to the message platform
            and stop event propagation. Defaults to True.
        **kwargs: Forwarded to ``get_handler_or_create``.

    Returns:
        A decorator that attaches the filter and returns its argument
        unchanged.
    """
    raise_error = True
    via_command_group = isinstance(custom_type_filter, RegisteringCommandable)

    if via_command_group:
        # Group form: (commandable, filter[, raise_error]).
        parent_register_commandable = custom_type_filter
        custom_filter = args[0]
        if len(args) > 1:
            raise_error = args[1]
    else:
        # Bare form: (filter[, raise_error]).
        custom_filter = custom_type_filter
        if args:
            raise_error = args[0]

    # Anything that is not already a combined And/Or filter instance is
    # called with ``raise_error`` to build the filter instance.
    if not isinstance(custom_filter, (CustomFilterAnd, CustomFilterOr)):
        custom_filter = custom_filter(raise_error)

    def attach_filter(awaitable):
        # Decorating a command group object: the filter goes on the group.
        if isinstance(awaitable, RegisteringCommandable):
            awaitable.parent_group.add_custom_filter(custom_filter)
            return awaitable

        handler_md = get_handler_or_create(
            awaitable,
            EventType.AdapterMessageEvent,
            **kwargs,
        )

        if via_command_group:
            # Sub-command of a group: attach the filter to the sub-command
            # filter whose handler is the one being decorated.
            target_full_name = get_handler_full_name(awaitable)
            siblings = parent_register_commandable.parent_group.sub_command_filters
            for candidate in siblings:
                if isinstance(candidate, CommandGroupFilter):
                    continue
                candidate_md = candidate.get_handler_md()
                if (
                    candidate_md
                    and candidate_md.handler_full_name == target_full_name
                ):
                    candidate.add_custom_filter(custom_filter)
        else:
            # Bare command: attach the filter directly to the handler.
            # NOTE(review): the metadata is looked up a second time here;
            # presumably this returns the entry obtained above - confirm.
            assert isinstance(awaitable, Callable)
            handler_md = get_handler_or_create(
                awaitable,
                EventType.AdapterMessageEvent,
                **kwargs,
            )
            handler_md.event_filters.append(custom_filter)

        return awaitable

    return attach_filter
|
|
|
|
def register_command_group(
    command_group_name: str | None = None,
    sub_command: str | None = None,
    alias: set | None = None,
    **kwargs,
):
    """Register a command group.

    Args:
        command_group_name: Name of a root command group. When this function
            is invoked as ``RegisteringCommandable.group``, this slot
            receives the ``RegisteringCommandable`` instance and a nested
            sub-group is created under its group.
        sub_command: Name of the nested sub-group; required in the nested
            form.
        alias: Optional aliases passed to ``CommandGroupFilter``.
        **kwargs: Forwarded to ``get_handler_or_create``.

    Returns:
        A decorator that returns a ``RegisteringCommandable`` wrapping the
        new group.

    Raises:
        ValueError: Raised by the decorator when no group could be built
            (the required name was missing).
    """
    group_filter = None

    if isinstance(command_group_name, RegisteringCommandable):
        # Nested sub-group of an existing group.
        if sub_command is not None:
            parent = command_group_name.parent_group
            group_filter = CommandGroupFilter(
                sub_command,
                alias,
                parent_group=parent,
            )
            parent.add_sub_command_filter(group_filter)
        else:
            logger.warning(f"{command_group_name} 指令组的子指令组 sub_command 未指定")
    elif command_group_name is not None:
        # Root command group.
        group_filter = CommandGroupFilter(command_group_name, alias)
    else:
        logger.warning("根指令组的名称未指定")

    def bind_group(obj):
        if not group_filter:
            raise ValueError("注册指令组失败。")

        metadata = get_handler_or_create(
            obj,
            EventType.AdapterMessageEvent,
            **kwargs,
        )
        metadata.event_filters.append(group_filter)
        return RegisteringCommandable(group_filter)

    return bind_group
|
|
|
|
class RegisteringCommandable:
    """Helper for cascading (nested) command-group registration.

    The module-level registration functions are exposed as class attributes.
    Because they are plain functions, accessing them through an instance
    binds them as methods, so the instance is passed as the first positional
    argument - which is why those functions check whether their first
    argument is a ``RegisteringCommandable``.
    """

    # Register a sub-command under ``parent_group``.
    command: Callable[..., Callable[..., None]] = register_command
    # Register a nested sub-group under ``parent_group``.
    group: Callable[..., Callable[..., RegisteringCommandable]] = register_command_group
    # Attach a custom filter to a sub-command or to the group.
    custom_filter: Callable[..., Callable[..., Any]] = register_custom_filter

    def __init__(self, parent_group: CommandGroupFilter) -> None:
        # The command group that registrations made through this object
        # are attached to.
        self.parent_group = parent_group
|
|
|
|
def register_event_message_type(event_message_type: EventMessageType, **kwargs):
    """Register an EventMessageType filter on a handler.

    Args:
        event_message_type: The message type wrapped in an
            ``EventMessageTypeFilter``.
        **kwargs: Forwarded to ``get_handler_or_create``.

    Returns:
        A decorator that adds the filter to the handler's event filters and
        returns the handler unchanged.
    """

    def attach_filter(awaitable):
        metadata = get_handler_or_create(
            awaitable,
            EventType.AdapterMessageEvent,
            **kwargs,
        )
        type_filter = EventMessageTypeFilter(event_message_type)
        metadata.event_filters.append(type_filter)
        return awaitable

    return attach_filter
|
|
|
|
def register_platform_adapter_type(
    platform_adapter_type: PlatformAdapterType,
    **kwargs,
):
    """Register a PlatformAdapterType filter on a handler.

    Args:
        platform_adapter_type: The adapter type wrapped in a
            ``PlatformAdapterTypeFilter``.
        **kwargs: Forwarded to ``get_handler_or_create`` (e.g. ``desc`` and
            other extra configuration), consistent with the other
            filter-registering decorators in this module.

    Returns:
        A decorator that adds the filter to the handler's event filters and
        returns the handler unchanged.
    """

    def decorator(awaitable):
        # Forward kwargs: they were previously accepted but silently dropped.
        handler_md = get_handler_or_create(
            awaitable,
            EventType.AdapterMessageEvent,
            **kwargs,
        )
        handler_md.event_filters.append(
            PlatformAdapterTypeFilter(platform_adapter_type),
        )
        return awaitable

    return decorator
|
|
|
|
def register_regex(regex: str, **kwargs):
    """Register a regex filter on a handler.

    Args:
        regex: Pattern passed to ``RegexFilter``.
        **kwargs: Forwarded to ``get_handler_or_create``.

    Returns:
        A decorator that adds the filter to the handler's event filters and
        returns the handler unchanged.
    """

    def attach_filter(awaitable):
        metadata = get_handler_or_create(
            awaitable,
            EventType.AdapterMessageEvent,
            **kwargs,
        )
        regex_filter = RegexFilter(regex)
        metadata.event_filters.append(regex_filter)
        return awaitable

    return attach_filter
|
|
|
|
def register_permission_type(permission_type: PermissionType, raise_error: bool = True):
    """Register a PermissionType filter on a handler.

    Args:
        permission_type: The required PermissionType.
        raise_error: If the sender lacks the permission, whether to raise
            the error to the message platform and stop event propagation.
            Defaults to True.

    Returns:
        A decorator that adds a ``PermissionTypeFilter`` to the handler's
        event filters and returns the handler unchanged.
    """

    def attach_filter(awaitable):
        metadata = get_handler_or_create(awaitable, EventType.AdapterMessageEvent)
        permission_filter = PermissionTypeFilter(permission_type, raise_error)
        metadata.event_filters.append(permission_filter)
        return awaitable

    return attach_filter
|
|
|
|
def register_on_astrbot_loaded(**kwargs):
    """Hook: when AstrBot has finished loading.

    Registers the decorated handler for ``EventType.OnAstrBotLoadedEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``. The handler is
    returned unchanged.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnAstrBotLoadedEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_platform_loaded(**kwargs):
    """Hook: when a platform has finished loading.

    Registers the decorated handler for ``EventType.OnPlatformLoadedEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``. The handler is
    returned unchanged.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnPlatformLoadedEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_plugin_error(**kwargs):
    """Hook: triggered when a plugin raises while handling a message.

    Hook parameters:
        event, plugin_name, handler_name, error, traceback_text

    Notes:
        Calling `event.stop_event()` inside the hook suppresses the default
        error echo; the plugin then decides on its own whether to forward
        the error to another session.

    Registers the decorated handler for ``EventType.OnPluginErrorEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnPluginErrorEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_plugin_loaded(**kwargs):
    """Hook: when a plugin has finished loading.

    Hook parameters:
        metadata

    Notes:
        Fired when a plugin finishes loading; the hook receives that
        plugin's metadata.

    Registers the decorated handler for ``EventType.OnPluginLoadedEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnPluginLoadedEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_plugin_unloaded(**kwargs):
    """Hook: when a plugin has finished unloading.

    Hook parameters:
        metadata

    Notes:
        Fired when a plugin finishes unloading; the hook receives that
        plugin's metadata.

    Registers the decorated handler for ``EventType.OnPluginUnloadedEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnPluginUnloadedEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_waiting_llm_request(**kwargs):
    """Hook: notification while waiting to call the LLM (before the lock is acquired).

    This hook fires once a message is determined to need an LLM call but
    before it starts queueing for the lock, which makes it suitable for
    sending user feedback such as "Thinking...".

    Examples:
    ```py
    @on_waiting_llm_request()
    async def on_waiting_llm(self, event: AstrMessageEvent) -> None:
        await event.send("🤔 Thinking...")
    ```

    Registers the decorated handler for
    ``EventType.OnWaitingLLMRequestEvent``; ``kwargs`` are forwarded to
    ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(
            awaitable,
            EventType.OnWaitingLLMRequestEvent,
            **kwargs,
        )
        return awaitable

    return register_hook
|
|
|
|
def register_on_llm_request(**kwargs):
    """Hook: when an LLM request is about to be made.

    Examples:
    ```py
    from astrbot.api.provider import ProviderRequest

    @on_llm_request()
    async def test(self, event: AstrMessageEvent, request: ProviderRequest) -> None:
        request.system_prompt += "You are a catgirl..."
    ```

    The handler must accept two parameters: event, request.

    Registers the decorated handler for ``EventType.OnLLMRequestEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnLLMRequestEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_llm_response(**kwargs):
    """Hook: after an LLM request has completed.

    Examples:
    ```py
    from astrbot.api.provider import LLMResponse

    @on_llm_response()
    async def test(self, event: AstrMessageEvent, response: LLMResponse) -> None:
        ...
    ```

    The handler must accept two parameters: event, response.

    Registers the decorated handler for ``EventType.OnLLMResponseEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnLLMResponseEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_using_llm_tool(**kwargs):
    """Hook: before a function tool is called.

    The hook receives the ``tool`` and ``tool_args`` arguments.

    Examples:
    ```py
    from astrbot.core.agent.tool import FunctionTool

    @on_using_llm_tool()
    async def test(self, event: AstrMessageEvent, tool: FunctionTool, tool_args: dict | None) -> None:
        ...
    ```

    The handler must accept three parameters: event, tool, tool_args.

    Registers the decorated handler for ``EventType.OnUsingLLMToolEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnUsingLLMToolEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_on_llm_tool_respond(**kwargs):
    """Hook: after a function tool has been called.

    The hook receives ``tool``, ``tool_args`` and the tool's call result
    ``tool_result``.

    Examples:
    ```py
    from astrbot.core.agent.tool import FunctionTool
    from mcp.types import CallToolResult

    @on_llm_tool_respond()
    async def test(self, event: AstrMessageEvent, tool: FunctionTool, tool_args: dict | None, tool_result: CallToolResult | None) -> None:
        ...
    ```

    The handler must accept four parameters: event, tool, tool_args,
    tool_result.

    Registers the decorated handler for ``EventType.OnLLMToolRespondEvent``;
    ``kwargs`` are forwarded to ``get_handler_or_create``.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnLLMToolRespondEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_llm_tool(name: str | None = None, **kwargs):
    """Add a tool for function calling (function-calling / tools-use).

    Write the tool in the following format, including the docstring -
    AstrBot parses the function's docstring to build the parameter schema:

    ```
    @llm_tool(name="get_weather")  # if name is omitted, the function name is used
    async def get_weather(event: AstrMessageEvent, location: str):
        '''Get weather information.

        Args:
            location(string): the place
        '''
        # handling logic
    ```

    Accepted parameter types: string, number, object, array, boolean.

    Return value of the tool:
        - str: the result is added to the prompt of the next LLM request so
          the LLM can summarise the tool's result.
        - None: the result is not added to the next LLM request's prompt.

    ``yield`` can be used to send messages or to terminate the event.

    Sending messages: see the documentation.

    Terminating the event:
    ```
    event.stop_event()
    yield
    ```

    Args:
        name: Tool name; defaults to the decorated function's ``__name__``.
        **kwargs: ``registering_agent`` (a ``RegisteringAgent``) makes the
            tool get appended to that agent's tool list instead of being
            registered globally through ``llm_tools.add_func``.

    Raises:
        ValueError: Raised by the decorator when a documented parameter has
            no type, or its type (or element type) is not supported.
    """
    name_ = name
    # Any falsy value is treated as "no agent".
    registering_agent = kwargs.get("registering_agent") or None

    def build_tool(
        awaitable: Callable[
            ...,
            AsyncGenerator[MessageEventResult | str | None]
            | Awaitable[MessageEventResult | str | None],
        ],
    ):
        llm_tool_name = name_ if name_ else awaitable.__name__
        docstring = docstring_parser.parse(awaitable.__doc__ or "")

        args = []
        for arg in docstring.params:
            declared_type = arg.type_name
            if not declared_type:
                raise ValueError(
                    f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 的参数 {arg.arg_name} 缺少类型注释。",
                )

            # ``outer[inner]`` declarations are split into container type
            # and element type.
            element_type = None
            generic = re.match(r"(\w+)\[(\w+)\]", declared_type)
            if generic:
                declared_type = generic.group(1)
                element_type = generic.group(2)

            # Map Python type names to JSON-schema names; unknown names
            # pass through and are validated below.
            json_type = PY_TO_JSON_TYPE.get(declared_type, declared_type)
            if element_type:
                element_type = PY_TO_JSON_TYPE.get(element_type, element_type)

            if json_type not in SUPPORTED_TYPES or (
                element_type and element_type not in SUPPORTED_TYPES
            ):
                raise ValueError(
                    f"LLM 函数工具 {awaitable.__module__}_{llm_tool_name} 不支持的参数类型:{arg.type_name}",
                )

            schema = {
                "type": json_type,
                "name": arg.arg_name,
                "description": arg.description,
            }
            # The element type is only emitted for arrays.
            if element_type and json_type == "array":
                schema["items"] = {"type": element_type}
            args.append(schema)

        description = docstring.description.strip() if docstring.description else ""

        if not registering_agent:
            # Global tool: register the handler metadata and the tool spec.
            md = get_handler_or_create(awaitable, EventType.OnCallingFuncToolEvent)
            llm_tools.add_func(llm_tool_name, args, description, md.handler)
        else:
            # Agent-scoped tool: append to the agent's own tool list.
            assert isinstance(registering_agent, RegisteringAgent)
            if registering_agent._agent.tools is None:
                registering_agent._agent.tools = []
            tool = llm_tools.spec_to_func(llm_tool_name, args, description, awaitable)
            registering_agent._agent.tools.append(tool)

        return awaitable

    return build_tool
|
|
|
|
class RegisteringAgent:
    """Helper for registering tools on an Agent."""

    def __init__(self, agent: Agent[AstrAgentContext]) -> None:
        # The agent whose tool list receives tools registered via llm_tool.
        self._agent = agent

    def llm_tool(self, *args, **kwargs):
        """Like ``register_llm_tool``, but the tool is attached to this agent.

        Injects ``registering_agent=self`` and delegates to
        ``register_llm_tool`` with the remaining arguments unchanged.
        """
        kwargs["registering_agent"] = self
        tool_decorator = register_llm_tool(*args, **kwargs)
        return tool_decorator
|
|
|
|
def register_agent(
    name: str,
    instruction: str,
    tools: list[str | FunctionTool] | None = None,
    run_hooks: BaseAgentRunHooks[AstrAgentContext] | None = None,
):
    """Register an Agent.

    Args:
        name: Name of the agent.
        instruction: Instructions for the agent.
        tools: Tools the agent uses; defaults to an empty list.
        run_hooks: Hooks invoked while the agent runs; defaults to a fresh
            ``BaseAgentRunHooks`` instance.

    Returns:
        A decorator. Applying it creates the agent, wraps it in a
        ``HandoffTool`` whose handler is the decorated callable, appends
        that tool to ``llm_tools.func_list`` and returns a
        ``RegisteringAgent`` for the new agent (NOT the decorated callable).
    """
    tools_ = tools or []

    def decorator(awaitable: Callable[..., Awaitable[Any]]):
        # At module level AstrAgentContext is imported only under
        # TYPE_CHECKING, so it does not exist at runtime; the subscript
        # expressions below are evaluated at runtime and would raise
        # NameError without this function-scope import.
        from astrbot.core.astr_agent_context import AstrAgentContext

        AstrAgent = Agent[AstrAgentContext]
        agent = AstrAgent(
            name=name,
            instructions=instruction,
            tools=tools_,
            run_hooks=run_hooks or BaseAgentRunHooks[AstrAgentContext](),
        )
        handoff_tool = HandoffTool(agent=agent)
        handoff_tool.handler = awaitable
        llm_tools.func_list.append(handoff_tool)
        return RegisteringAgent(agent)

    return decorator
|
|
|
|
def register_on_decorating_result(**kwargs):
    """Hook: before a message is sent.

    Registers the decorated handler for
    ``EventType.OnDecoratingResultEvent``; ``kwargs`` are forwarded to
    ``get_handler_or_create``. The handler is returned unchanged.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnDecoratingResultEvent, **kwargs)
        return awaitable

    return register_hook
|
|
|
|
def register_after_message_sent(**kwargs):
    """Hook: after a message has been sent.

    Registers the decorated handler for
    ``EventType.OnAfterMessageSentEvent``; ``kwargs`` are forwarded to
    ``get_handler_or_create``. The handler is returned unchanged.
    """

    def register_hook(awaitable):
        get_handler_or_create(awaitable, EventType.OnAfterMessageSentEvent, **kwargs)
        return awaitable

    return register_hook
|
|