| import os |
| import sys |
| from unittest.mock import AsyncMock |
|
|
| import pytest |
|
|
| |
| sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) |
|
|
| from astrbot.core.agent.hooks import BaseAgentRunHooks |
| from astrbot.core.agent.run_context import ContextWrapper |
| from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner |
| from astrbot.core.agent.tool import FunctionTool, ToolSet |
| from astrbot.core.provider.entities import LLMResponse, ProviderRequest, TokenUsage |
| from astrbot.core.provider.provider import Provider |
|
|
|
|
class MockProvider(Provider):
    """Mock provider used to drive the tool-loop agent runner in tests."""

    def __init__(self):
        super().__init__({}, {})
        # How many times text_chat has been invoked.
        self.call_count = 0
        # While True, answer with a tool call until the threshold below is hit.
        self.should_call_tools = True
        # After this many calls, fall back to a plain-text final answer.
        self.max_calls_before_normal_response = 10

    def get_current_key(self) -> str:
        return "test_key"

    def set_key(self, key: str):
        pass

    async def get_models(self) -> list[str]:
        return ["test_model"]

    async def text_chat(self, **kwargs) -> LLMResponse:
        self.call_count += 1

        tools = kwargs.get("func_tool")

        # No tools available, or the tool-call budget is exhausted:
        # produce the final plain-text answer.
        if tools is None or self.call_count > self.max_calls_before_normal_response:
            return LLMResponse(
                role="assistant",
                completion_text="这是我的最终回答",
                usage=TokenUsage(input_other=10, output=5),
            )

        # Otherwise request a tool call when configured to do so.
        if self.should_call_tools:
            return LLMResponse(
                role="assistant",
                completion_text="我需要使用工具来帮助您",
                tools_call_name=["test_tool"],
                tools_call_args=[{"query": "test"}],
                tools_call_ids=["call_123"],
                usage=TokenUsage(input_other=10, output=5),
            )

        return LLMResponse(
            role="assistant",
            completion_text="这是我的最终回答",
            usage=TokenUsage(input_other=10, output=5),
        )

    async def text_chat_stream(self, **kwargs):
        """Replay the non-streaming answer as a chunk, then as the final piece.

        NOTE(review): both yields reuse the same response object, so its
        ``is_chunk`` flag is mutated between yields — fine for these tests,
        but consumers holding the first reference see the mutation.
        """
        response = await self.text_chat(**kwargs)
        response.is_chunk = True
        yield response
        response.is_chunk = False
        yield response
|
|
|
|
class MockToolExecutor:
    """Mock tool executor that yields a single canned MCP tool result."""

    @classmethod
    def execute(cls, tool, run_context, **tool_args):
        async def _run():
            # Imported lazily so the module loads without mcp installed.
            from mcp.types import CallToolResult, TextContent

            yield CallToolResult(
                content=[TextContent(type="text", text="工具执行结果")]
            )

        return _run()
|
|
|
|
class MockFailingProvider(MockProvider):
    """Provider whose text_chat always raises, to exercise fallback handling."""

    async def text_chat(self, **kwargs) -> LLMResponse:
        self.call_count += 1
        raise RuntimeError("primary provider failed")
|
|
|
|
class MockErrProvider(MockProvider):
    """Provider that returns an error-role response, to exercise fallback handling."""

    async def text_chat(self, **kwargs) -> LLMResponse:
        self.call_count += 1
        return LLMResponse(
            role="err",
            completion_text="primary provider returned error",
        )
|
|
|
|
class MockAbortableStreamProvider(MockProvider):
    """Streaming provider that checks an abort signal between chunks."""

    async def text_chat_stream(self, **kwargs):
        abort_signal = kwargs.get("abort_signal")

        # The first chunk is always delivered.
        yield LLMResponse(
            role="assistant",
            completion_text="partial ",
            is_chunk=True,
        )

        # If the caller aborted after the first chunk, close out with the
        # partial text only.
        if abort_signal and abort_signal.is_set():
            yield LLMResponse(
                role="assistant",
                completion_text="partial ",
                is_chunk=False,
            )
            return

        yield LLMResponse(
            role="assistant",
            completion_text="partial final",
            is_chunk=False,
        )
|
|
|
|
class MockHooks(BaseAgentRunHooks):
    """Hook implementation that records which lifecycle callbacks fired."""

    def __init__(self):
        # One boolean flag per hook, all starting False.
        for hook in ("agent_begin", "agent_done", "tool_start", "tool_end"):
            setattr(self, f"{hook}_called", False)

    async def on_agent_begin(self, run_context):
        self.agent_begin_called = True

    async def on_tool_start(self, run_context, tool, tool_args):
        self.tool_start_called = True

    async def on_tool_end(self, run_context, tool, tool_args, tool_result):
        self.tool_end_called = True

    async def on_agent_done(self, run_context, llm_response):
        self.agent_done_called = True
|
|
|
|
class MockEvent:
    """Minimal stand-in for a message event: origin plus sender id."""

    def __init__(self, umo: str, sender_id: str):
        self.unified_msg_origin = umo
        self._sender_id = sender_id

    def get_sender_id(self):
        return self._sender_id
|
|
|
|
class MockAgentContext:
    """Minimal agent context exposing only the triggering event."""

    def __init__(self, event):
        self.event = event
|
|
|
|
@pytest.fixture
def mock_provider():
    """Provide a fresh MockProvider for each test."""
    return MockProvider()
|
|
|
|
@pytest.fixture
def mock_tool_executor():
    """Provide a fresh MockToolExecutor for each test."""
    return MockToolExecutor()
|
|
|
|
@pytest.fixture
def mock_hooks():
    """Provide a fresh MockHooks recorder for each test."""
    return MockHooks()
|
|
|
|
@pytest.fixture
def tool_set():
    """Build a single-tool ToolSet used by the tests."""
    schema = {"type": "object", "properties": {"query": {"type": "string"}}}
    test_tool = FunctionTool(
        name="test_tool",
        description="测试工具",
        parameters=schema,
        handler=AsyncMock(),
    )
    return ToolSet(tools=[test_tool])
|
|
|
|
@pytest.fixture
def provider_request(tool_set):
    """Build the ProviderRequest fed into the runner under test."""
    return ProviderRequest(prompt="请帮我查询信息", func_tool=tool_set, contexts=[])
|
|
|
|
@pytest.fixture
def runner():
    """Provide a fresh ToolLoopAgentRunner instance for each test."""
    return ToolLoopAgentRunner()
|
|
|
|
@pytest.mark.asyncio
async def test_max_step_limit_functionality(
    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
):
    """Hitting the max-step limit should disable tools and force a final answer."""
    # Provider keeps requesting tools far beyond the step budget.
    mock_provider.should_call_tools = True
    mock_provider.max_calls_before_normal_response = 100

    await runner.reset(
        provider=mock_provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    # Drive the agent with a tight step budget.
    max_steps = 3
    collected = [resp async for resp in runner.step_until_done(max_steps)]

    assert runner.done(), "代理应该在达到最大步数后完成"
    assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"

    llm_results = [r for r in collected if r.type == "llm_result"]
    assert len(llm_results) > 0, "应该有最终的LLM响应"

    final_message = runner.run_context.messages[-1]
    assert final_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
|
|
|
@pytest.mark.asyncio
async def test_normal_completion_without_max_step(
    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
):
    """The agent should finish naturally before the step budget is reached."""
    mock_provider.should_call_tools = True
    # Provider stops requesting tools after two calls, well below max_steps.
    mock_provider.max_calls_before_normal_response = 2

    await runner.reset(
        provider=mock_provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    max_steps = 10
    async for _ in runner.step_until_done(max_steps):
        pass

    assert runner.done(), "代理应该正常完成"

    # Natural completion must use fewer provider calls than the budget.
    assert mock_provider.call_count < max_steps, (
        f"正常完成时调用次数({mock_provider.call_count})应该小于最大步数({max_steps})"
    )

    # No "tool-call limit reached" notice should have been injected.
    limit_notices = [
        m
        for m in runner.run_context.messages
        if m.role == "user" and "工具调用次数已达到上限" in m.content
    ]
    assert len(limit_notices) == 0, "正常完成时不应该有步数限制消息"

    assert runner.req.func_tool is not None, "正常完成时工具不应该被禁用"
|
|
|
|
@pytest.mark.asyncio
async def test_max_step_with_streaming(
    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
):
    """Max-step enforcement should also apply in streaming mode."""
    mock_provider.should_call_tools = True
    mock_provider.max_calls_before_normal_response = 100

    await runner.reset(
        provider=mock_provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=True,
    )

    max_steps = 2
    collected = [resp async for resp in runner.step_until_done(max_steps)]

    assert runner.done(), "代理应该在达到最大步数后完成"

    deltas = [r for r in collected if r.type == "streaming_delta"]
    assert len(deltas) > 0, "应该有流式响应"

    assert runner.req.func_tool is None, "达到最大步数后工具应该被禁用"

    final_message = runner.run_context.messages[-1]
    assert final_message.role == "assistant", "最后一条消息应该是assistant的最终回答"
|
|
|
|
@pytest.mark.asyncio
async def test_hooks_called_with_max_step(
    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
):
    """All lifecycle hooks should still fire when the step limit is hit."""
    mock_provider.should_call_tools = True
    mock_provider.max_calls_before_normal_response = 100

    await runner.reset(
        provider=mock_provider,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    # Exhaust the (small) step budget; responses themselves are irrelevant here.
    async for _ in runner.step_until_done(2):
        pass

    assert mock_hooks.agent_begin_called, "on_agent_begin应该被调用"
    assert mock_hooks.agent_done_called, "on_agent_done应该被调用"
    assert mock_hooks.tool_start_called, "on_tool_start应该被调用"
    assert mock_hooks.tool_end_called, "on_tool_end应该被调用"
|
|
|
|
@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_raises(
    runner, provider_request, mock_tool_executor, mock_hooks
):
    """When the primary provider raises, the fallback provider should answer."""
    primary = MockFailingProvider()
    fallback = MockProvider()
    fallback.should_call_tools = False

    await runner.reset(
        provider=primary,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
        fallback_providers=[fallback],
    )

    async for _ in runner.step_until_done(5):
        pass

    final = runner.get_final_llm_resp()
    assert final is not None
    assert final.role == "assistant"
    assert final.completion_text == "这是我的最终回答"
    # Each provider should have been tried exactly once.
    assert primary.call_count == 1
    assert fallback.call_count == 1
|
|
|
|
@pytest.mark.asyncio
async def test_fallback_provider_used_when_primary_returns_err(
    runner, provider_request, mock_tool_executor, mock_hooks
):
    """An err-role response from the primary should trigger the fallback."""
    primary = MockErrProvider()
    fallback = MockProvider()
    fallback.should_call_tools = False

    await runner.reset(
        provider=primary,
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
        fallback_providers=[fallback],
    )

    async for _ in runner.step_until_done(5):
        pass

    final = runner.get_final_llm_resp()
    assert final is not None
    assert final.role == "assistant"
    assert final.completion_text == "这是我的最终回答"
    # Each provider should have been tried exactly once.
    assert primary.call_count == 1
    assert fallback.call_count == 1
|
|
|
|
@pytest.mark.asyncio
async def test_stop_signal_returns_aborted_and_persists_partial_message(
    runner, provider_request, mock_tool_executor, mock_hooks
):
    """Requesting a stop mid-stream should abort and keep the partial answer."""
    await runner.reset(
        provider=MockAbortableStreamProvider(),
        request=provider_request,
        run_context=ContextWrapper(context=None),
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=True,
    )

    stream = runner.step()
    first = await stream.__anext__()
    assert first.type == "streaming_delta"

    # Abort after the first chunk has been delivered.
    runner.request_stop()

    remaining = [resp async for resp in stream]

    assert any(resp.type == "aborted" for resp in remaining)
    assert runner.was_aborted() is True

    final = runner.get_final_llm_resp()
    assert final is not None
    assert final.role == "assistant"
    # The runner appends an interruption notice to the partial completion.
    assert "interrupted" in final.completion_text.lower()
    assert runner.run_context.messages[-1].role == "assistant"
|
|
|
|
@pytest.mark.asyncio
async def test_tool_result_injects_follow_up_notice(
    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
):
    """Queued follow-up messages should be injected into the next tool result."""
    event = MockEvent("test:FriendMessage:follow_up", "u1")
    run_context = ContextWrapper(context=MockAgentContext(event))

    await runner.reset(
        provider=mock_provider,
        request=provider_request,
        run_context=run_context,
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    tickets = [
        runner.follow_up(message_text="follow up 1"),
        runner.follow_up(message_text="follow up 2"),
    ]
    assert all(t is not None for t in tickets)

    async for _ in runner.step():
        pass

    results = provider_request.tool_calls_result
    assert results is not None
    assert isinstance(results, list)
    assert results

    # Both follow-ups must appear, numbered, inside a SYSTEM NOTICE block.
    first_content = str(results[0].tool_calls_result[0].content)
    assert "SYSTEM NOTICE" in first_content
    assert "1. follow up 1" in first_content
    assert "2. follow up 2" in first_content

    for ticket in tickets:
        assert ticket.resolved.is_set() is True
        assert ticket.consumed is True
|
|
|
|
@pytest.mark.asyncio
async def test_follow_up_ticket_not_consumed_when_no_next_tool_call(
    runner, mock_provider, provider_request, mock_tool_executor, mock_hooks
):
    """Without a subsequent tool call, a follow-up ticket resolves unconsumed."""
    mock_provider.should_call_tools = False
    event = MockEvent("test:FriendMessage:follow_up_no_tool", "u1")
    run_context = ContextWrapper(context=MockAgentContext(event))

    await runner.reset(
        provider=mock_provider,
        request=provider_request,
        run_context=run_context,
        tool_executor=mock_tool_executor,
        agent_hooks=mock_hooks,
        streaming=False,
    )

    ticket = runner.follow_up(message_text="follow up without tool")
    assert ticket is not None

    async for _ in runner.step():
        pass

    assert ticket.resolved.is_set() is True
    assert ticket.consumed is False
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly (equivalent to `pytest -v <file>`).
    pytest.main([__file__, "-v"])
|
|