"""Tests for astr_main_agent module.""" import os from unittest.mock import AsyncMock, MagicMock, patch import pytest from astrbot.core import astr_main_agent as ama from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Plain, Reply from astrbot.core.platform.astr_message_event import AstrMessageEvent from astrbot.core.platform.platform_metadata import PlatformMetadata from astrbot.core.provider import Provider from astrbot.core.provider.entities import ProviderRequest @pytest.fixture def mock_provider(): """Create a mock provider.""" provider = MagicMock(spec=Provider) provider.provider_config = { "id": "test-provider", "modalities": ["image", "tool_use"], } provider.get_model.return_value = "gpt-4" return provider @pytest.fixture def mock_context(): """Create a mock Context.""" ctx = MagicMock() ctx.get_config.return_value = {} ctx.conversation_manager = MagicMock() ctx.persona_manager = MagicMock() ctx.persona_manager.personas_v3 = [] ctx.persona_manager.resolve_selected_persona = AsyncMock( return_value=(None, None, None, False) ) ctx.get_llm_tool_manager.return_value = MagicMock() ctx.subagent_orchestrator = None return ctx @pytest.fixture def mock_event(): """Create a mock AstrMessageEvent.""" platform_meta = PlatformMetadata( id="test_platform", name="test_platform", description="Test platform", ) message_obj = MagicMock() message_obj.message = [Plain(text="Hello")] message_obj.sender = MagicMock(user_id="user123", nickname="TestUser") message_obj.group_id = None message_obj.group = None event = MagicMock(spec=AstrMessageEvent) event.message_str = "Hello" event.message_obj = message_obj event.platform_meta = platform_meta event.session_id = "session123" event.unified_msg_origin = "test_platform:private:session123" event.get_extra.return_value = None event.get_platform_name.return_value = "test_platform" event.get_platform_id.return_value = "test_platform" event.get_group_id.return_value = None event.get_sender_name.return_value = "TestUser" event.trace = MagicMock() event.plugins_name = None return event @pytest.fixture def mock_conversation(): """Create a mock conversation.""" conv = MagicMock(spec=Conversation) conv.cid = "conv-id" conv.persona_id = None conv.history = "[]" return conv @pytest.fixture def sample_config(): """Create a sample MainAgentBuildConfig.""" module = ama return module.MainAgentBuildConfig( tool_call_timeout=60, streaming_response=True, file_extract_enabled=True, file_extract_prov="moonshotai", file_extract_msh_api_key="test-api-key", ) def _new_mock_conversation(cid: str = "conv-id") -> MagicMock: conv = MagicMock(spec=Conversation) conv.cid = cid conv.persona_id = None conv.history = "[]" return conv def _setup_conversation_for_build(conv_mgr, cid: str = "conv-id") -> MagicMock: conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None) conv_mgr.new_conversation = AsyncMock(return_value=cid) conversation = _new_mock_conversation(cid=cid) conv_mgr.get_conversation = AsyncMock(return_value=conversation) return conversation class TestMainAgentBuildConfig: """Tests for MainAgentBuildConfig dataclass.""" def test_config_initialization(self): """Test MainAgentBuildConfig initialization with defaults.""" module = ama config = module.MainAgentBuildConfig(tool_call_timeout=60) assert config.tool_call_timeout == 60 assert config.tool_schema_mode == "full" assert config.provider_wake_prefix == "" assert config.streaming_response is True assert config.sanitize_context_by_modalities is False assert config.kb_agentic_mode is False assert config.file_extract_enabled is False assert config.llm_safety_mode is True def test_config_with_custom_values(self): """Test MainAgentBuildConfig with custom values.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=120, tool_schema_mode="skills-like", provider_wake_prefix="/", streaming_response=False, kb_agentic_mode=True, file_extract_enabled=True, computer_use_runtime="sandbox", add_cron_tools=False, ) assert config.tool_call_timeout == 120 assert config.tool_schema_mode == "skills-like" assert config.provider_wake_prefix == "/" assert config.streaming_response is False assert config.kb_agentic_mode is True assert config.file_extract_enabled is True assert config.computer_use_runtime == "sandbox" assert config.add_cron_tools is False class TestSelectProvider: """Tests for _select_provider function.""" def test_select_provider_by_id(self, mock_event, mock_context, mock_provider): """Test selecting provider by ID from event extra.""" module = ama mock_event.get_extra.side_effect = lambda k: ( "test-provider" if k == "selected_provider" else None ) mock_context.get_provider_by_id.return_value = mock_provider result = module._select_provider(mock_event, mock_context) assert result == mock_provider mock_context.get_provider_by_id.assert_called_once_with("test-provider") def test_select_provider_not_found(self, mock_event, mock_context): """Test selecting provider when ID is not found.""" module = ama mock_event.get_extra.side_effect = lambda k: ( "non-existent" if k == "selected_provider" else None ) mock_context.get_provider_by_id.return_value = None result = module._select_provider(mock_event, mock_context) assert result is None def test_select_provider_invalid_type(self, mock_event, mock_context): """Test selecting provider when result is not a Provider instance.""" module = ama mock_event.get_extra.side_effect = lambda k: ( "invalid" if k == "selected_provider" else None ) mock_context.get_provider_by_id.return_value = "not a provider" result = module._select_provider(mock_event, mock_context) assert result is None def test_select_provider_fallback(self, mock_event, mock_context, mock_provider): """Test provider selection fallback to using provider.""" module = ama mock_event.get_extra.return_value = None mock_context.get_using_provider.return_value = mock_provider result = module._select_provider(mock_event, mock_context) assert result == mock_provider mock_context.get_using_provider.assert_called_once_with( umo=mock_event.unified_msg_origin ) def test_select_provider_fallback_error(self, mock_event, mock_context): """Test provider selection when fallback raises ValueError.""" module = ama mock_event.get_extra.return_value = None mock_context.get_using_provider.side_effect = ValueError("Test error") result = module._select_provider(mock_event, mock_context) assert result is None class TestGetSessionConv: """Tests for _get_session_conv function.""" @pytest.mark.asyncio async def test_get_session_conv_existing( self, mock_event, mock_context, mock_conversation ): """Test getting existing conversation.""" module = ama conv_mgr = mock_context.conversation_manager conv_mgr.get_curr_conversation_id = AsyncMock(return_value="existing-conv-id") conv_mgr.get_conversation = AsyncMock(return_value=mock_conversation) result = await module._get_session_conv(mock_event, mock_context) assert result == mock_conversation conv_mgr.get_curr_conversation_id.assert_called_once_with( mock_event.unified_msg_origin ) conv_mgr.get_conversation.assert_called_once_with( mock_event.unified_msg_origin, "existing-conv-id" ) @pytest.mark.asyncio async def test_get_session_conv_create_new(self, mock_event, mock_context): """Test creating new conversation when none exists.""" module = ama conv_mgr = mock_context.conversation_manager conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None) conv_mgr.new_conversation = AsyncMock(return_value="new-conv-id") mock_conversation = MagicMock(spec=Conversation) mock_conversation.cid = "new-conv-id" mock_conversation.persona_id = None mock_conversation.history = "[]" conv_mgr.get_conversation = AsyncMock(return_value=mock_conversation) result = await module._get_session_conv(mock_event, mock_context) assert result == mock_conversation conv_mgr.new_conversation.assert_called_once_with( mock_event.unified_msg_origin, mock_event.get_platform_id() ) @pytest.mark.asyncio async def test_get_session_conv_retry(self, mock_event, mock_context): """Test retrying conversation creation after failure.""" module = ama conv_mgr = mock_context.conversation_manager conv_mgr.get_curr_conversation_id = AsyncMock(return_value="conv-id") conv_mgr.get_conversation = AsyncMock(return_value=None) conv_mgr.new_conversation = AsyncMock(return_value="retry-conv-id") mock_conversation = MagicMock(spec=Conversation) mock_conversation.cid = "retry-conv-id" mock_conversation.persona_id = None mock_conversation.history = "[]" conv_mgr.get_conversation.side_effect = [None, mock_conversation] result = await module._get_session_conv(mock_event, mock_context) assert result == mock_conversation assert conv_mgr.new_conversation.call_count == 1 assert conv_mgr.get_conversation.call_count == 2 @pytest.mark.asyncio async def test_get_session_conv_failure(self, mock_event, mock_context): """Test RuntimeError when conversation creation fails.""" module = ama conv_mgr = mock_context.conversation_manager conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None) conv_mgr.new_conversation = AsyncMock(return_value="new-conv-id") conv_mgr.get_conversation = AsyncMock(return_value=None) with pytest.raises(RuntimeError, match="无法创建新的对话。"): await module._get_session_conv(mock_event, mock_context) class TestApplyKb: """Tests for _apply_kb function.""" @pytest.mark.asyncio async def test_apply_kb_without_agentic_mode(self, mock_event, mock_context): """Test applying knowledge base in non-agentic mode.""" module = ama req = ProviderRequest(prompt="test question", system_prompt="System prompt") config = module.MainAgentBuildConfig( tool_call_timeout=60, kb_agentic_mode=False ) with patch( "astrbot.core.astr_main_agent.retrieve_knowledge_base", AsyncMock(return_value="KB result"), ): await module._apply_kb(mock_event, req, mock_context, config) assert "[Related Knowledge Base Results]:" in req.system_prompt assert "KB result" in req.system_prompt @pytest.mark.asyncio async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context): """Test applying knowledge base in agentic mode.""" module = ama req = ProviderRequest(prompt="test question") config = module.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=True) await module._apply_kb(mock_event, req, mock_context, config) assert req.func_tool is not None @pytest.mark.asyncio async def test_apply_kb_no_prompt(self, mock_event, mock_context): """Test applying knowledge base when prompt is None.""" module = ama req = ProviderRequest(prompt=None, system_prompt="System") config = module.MainAgentBuildConfig( tool_call_timeout=60, kb_agentic_mode=False ) await module._apply_kb(mock_event, req, mock_context, config) assert req.system_prompt == "System" @pytest.mark.asyncio async def test_apply_kb_no_result(self, mock_event, mock_context): """Test applying knowledge base when no result is returned.""" module = ama req = ProviderRequest(prompt="test", system_prompt="System") config = module.MainAgentBuildConfig( tool_call_timeout=60, kb_agentic_mode=False ) with patch( "astrbot.core.astr_main_agent.retrieve_knowledge_base", AsyncMock(return_value=None), ): await module._apply_kb(mock_event, req, mock_context, config) assert req.system_prompt == "System" @pytest.mark.asyncio async def test_apply_kb_with_existing_tools(self, mock_event, mock_context): """Test applying knowledge base with existing toolset.""" module = ama existing_tools = ToolSet() req = ProviderRequest(prompt="test", func_tool=existing_tools) config = module.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=True) await module._apply_kb(mock_event, req, mock_context, config) assert req.func_tool is not None class TestApplyFileExtract: """Tests for _apply_file_extract function.""" @pytest.mark.asyncio async def test_file_extract_basic(self, mock_event, sample_config): """Test basic file extraction.""" module = ama mock_file = MagicMock(spec=File) mock_file.name = "test.pdf" mock_file.get_file = AsyncMock(return_value="/path/to/test.pdf") mock_event.message_obj.message = [mock_file] req = ProviderRequest(prompt="Summarize") with patch( "astrbot.core.astr_main_agent.extract_file_moonshotai" ) as mock_extract: mock_extract.return_value = "File content" await module._apply_file_extract(mock_event, req, sample_config) assert len(req.contexts) == 1 assert "File Extract Results" in req.contexts[0]["content"] @pytest.mark.asyncio async def test_file_extract_no_files(self, mock_event, sample_config): """Test file extraction when no files present.""" module = ama mock_event.message_obj.message = [Plain(text="Hello")] req = ProviderRequest(prompt="Hello") await module._apply_file_extract(mock_event, req, sample_config) assert len(req.contexts) == 0 @pytest.mark.asyncio async def test_file_extract_in_reply(self, mock_event, sample_config): """Test file extraction from reply chain.""" module = ama mock_file = MagicMock(spec=File) mock_file.name = "reply.pdf" mock_file.get_file = AsyncMock(return_value="/path/to/reply.pdf") mock_reply = MagicMock(spec=Reply) mock_reply.chain = [mock_file] mock_event.message_obj.message = [mock_reply] req = ProviderRequest(prompt="Summarize") with patch( "astrbot.core.astr_main_agent.extract_file_moonshotai" ) as mock_extract: mock_extract.return_value = "Reply content" await module._apply_file_extract(mock_event, req, sample_config) assert len(req.contexts) == 1 @pytest.mark.asyncio async def test_file_extract_no_prompt(self, mock_event, sample_config): """Test file extraction when prompt is empty.""" module = ama mock_file = MagicMock(spec=File) mock_file.name = "test.pdf" mock_file.get_file = AsyncMock(return_value="/path/to/test.pdf") mock_event.message_obj.message = [mock_file] req = ProviderRequest(prompt=None) with patch( "astrbot.core.astr_main_agent.extract_file_moonshotai" ) as mock_extract: mock_extract.return_value = "Content" await module._apply_file_extract(mock_event, req, sample_config) assert req.prompt == "总结一下文件里面讲了什么?" @pytest.mark.asyncio async def test_file_extract_no_api_key(self, mock_event): """Test file extraction when no API key is configured.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, file_extract_enabled=True, file_extract_msh_api_key="", ) mock_file = MagicMock(spec=File) mock_file.name = "test.pdf" mock_file.get_file = AsyncMock(return_value="/path/to/test.pdf") mock_event.message_obj.message = [mock_file] req = ProviderRequest(prompt="Summarize") await module._apply_file_extract(mock_event, req, config) assert len(req.contexts) == 0 class TestEnsurePersonaAndSkills: """Tests for _ensure_persona_and_skills function.""" @pytest.mark.asyncio async def test_ensure_persona_from_session(self, mock_event, mock_context): """Test applying persona from session service config.""" module = ama persona = {"name": "test-persona", "prompt": "You are helpful."} mock_context.persona_manager.personas_v3 = [persona] mock_context.persona_manager.resolve_selected_persona = AsyncMock( return_value=("test-persona", persona, "test-persona", False) ) mock_event.trace = MagicMock(record=MagicMock()) req = ProviderRequest() req.conversation = MagicMock(persona_id=None) await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) assert "You are helpful." in req.system_prompt @pytest.mark.asyncio async def test_ensure_persona_from_conversation(self, mock_event, mock_context): """Test applying persona from conversation setting.""" module = ama persona = {"name": "conv-persona", "prompt": "Custom persona."} mock_context.persona_manager.personas_v3 = [persona] mock_context.persona_manager.resolve_selected_persona = AsyncMock( return_value=("conv-persona", persona, None, False) ) req = ProviderRequest() req.conversation = MagicMock(persona_id="conv-persona") await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) assert "Custom persona." in req.system_prompt @pytest.mark.asyncio async def test_ensure_persona_none_explicit(self, mock_event, mock_context): """Test that [%None] persona is explicitly set to no persona.""" module = ama mock_context.persona_manager.personas_v3 = [] mock_context.persona_manager.resolve_selected_persona = AsyncMock( return_value=("[%None]", None, None, False) ) req = ProviderRequest() req.conversation = MagicMock(persona_id="[%None]") await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) assert "Persona Instructions" not in req.system_prompt @pytest.mark.asyncio async def test_ensure_tools_from_persona(self, mock_event, mock_context): """Test applying tools from persona.""" module = ama mock_tool = MagicMock() mock_tool.name = "test_tool" mock_tool.active = True persona = {"name": "persona", "prompt": "Test", "tools": ["test_tool"]} mock_context.persona_manager.personas_v3 = [persona] mock_context.persona_manager.resolve_selected_persona = AsyncMock( return_value=("persona", persona, None, False) ) tmgr = mock_context.get_llm_tool_manager.return_value tmgr.get_func.return_value = mock_tool req = ProviderRequest() req.conversation = MagicMock(persona_id="persona") await module._ensure_persona_and_skills(req, {}, mock_context, mock_event) assert req.func_tool is not None class TestDecorateLlmRequest: """Tests for _decorate_llm_request function.""" @pytest.mark.asyncio async def test_decorate_llm_request_basic( self, mock_event, mock_context, sample_config ): """Test basic LLM request decoration.""" module = ama req = ProviderRequest(prompt="Hello", system_prompt="System") await module._decorate_llm_request(mock_event, req, mock_context, sample_config) assert req.prompt == "Hello" assert req.system_prompt == "System" @pytest.mark.asyncio async def test_decorate_llm_request_with_prefix(self, mock_event, mock_context): """Test LLM request decoration with prompt prefix.""" module = ama req = ProviderRequest(prompt="Hello") config = module.MainAgentBuildConfig( tool_call_timeout=60, provider_settings={"prompt_prefix": "AI: "} ) with patch.object(mock_context, "get_config") as mock_get_config: mock_get_config.return_value = {} await module._decorate_llm_request(mock_event, req, mock_context, config) assert req.prompt == "AI: Hello" @pytest.mark.asyncio async def test_decorate_llm_request_prefix_with_placeholder( self, mock_event, mock_context ): """Test prompt prefix with {{prompt}} placeholder.""" module = ama req = ProviderRequest(prompt="Hello") config = module.MainAgentBuildConfig( tool_call_timeout=60, provider_settings={"prompt_prefix": "AI {{prompt}} - Please respond:"}, ) with patch.object(mock_context, "get_config") as mock_get_config: mock_get_config.return_value = {} await module._decorate_llm_request(mock_event, req, mock_context, config) assert req.prompt == "AI Hello - Please respond:" @pytest.mark.asyncio async def test_decorate_llm_request_no_conversation(self, mock_event, mock_context): """Test decoration when no conversation exists.""" module = ama req = ProviderRequest(prompt="Hello") req.conversation = None config = module.MainAgentBuildConfig(tool_call_timeout=60) with patch.object(mock_context, "get_config") as mock_get_config: mock_get_config.return_value = {} await module._decorate_llm_request(mock_event, req, mock_context, config) assert req.prompt == "Hello" class TestModalitiesFix: """Tests for _modalities_fix function.""" def test_modalities_fix_image_not_supported(self, mock_provider): """Test modality fix when image is not supported.""" module = ama mock_provider.provider_config = {"modalities": ["text"]} req = ProviderRequest(prompt="Hello", image_urls=["/path/to/image.jpg"]) module._modalities_fix(mock_provider, req) assert "[图片]" in req.prompt assert req.image_urls == [] def test_modalities_fix_tool_not_supported(self, mock_provider): """Test modality fix when tool is not supported.""" module = ama mock_provider.provider_config = {"modalities": ["text", "image"]} req = ProviderRequest(prompt="Hello") req.func_tool = ToolSet() req.func_tool.add_tool( FunctionTool( name="dummy_tool", description="dummy", parameters={"type": "object", "properties": {}}, ) ) module._modalities_fix(mock_provider, req) assert req.func_tool is None def test_modalities_fix_all_supported(self, mock_provider): """Test modality fix when all features are supported.""" module = ama mock_provider.provider_config = {"modalities": ["image", "tool_use"]} tool_set = ToolSet() tool_set.add_tool( FunctionTool( name="dummy_tool", description="dummy", parameters={"type": "object", "properties": {}}, ) ) req = ProviderRequest( prompt="Hello", image_urls=["/path/to/image.jpg"], func_tool=tool_set, ) module._modalities_fix(mock_provider, req) assert req.prompt == "Hello" assert len(req.image_urls) == 1 assert req.func_tool is not None class TestSanitizeContextByModalities: """Tests for _sanitize_context_by_modalities function.""" def test_sanitize_no_op(self, mock_provider): """Test sanitize when disabled or modalities support everything.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, sanitize_context_by_modalities=False ) mock_provider.provider_config = {"modalities": ["image", "tool_use"]} req = ProviderRequest(contexts=[{"role": "user", "content": "Hello"}]) module._sanitize_context_by_modalities(config, mock_provider, req) assert len(req.contexts) == 1 def test_sanitize_removes_tool_messages(self, mock_provider): """Test sanitize removes tool messages when tool_use not supported.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, sanitize_context_by_modalities=True ) mock_provider.provider_config = {"modalities": ["image"]} req = ProviderRequest( contexts=[ {"role": "user", "content": "Hello"}, {"role": "tool", "content": "Tool result"}, ] ) module._sanitize_context_by_modalities(config, mock_provider, req) assert len(req.contexts) == 1 assert req.contexts[0]["role"] == "user" def test_sanitize_removes_tool_calls(self, mock_provider): """Test sanitize removes tool_calls from assistant messages.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, sanitize_context_by_modalities=True ) mock_provider.provider_config = {"modalities": ["image"]} req = ProviderRequest( contexts=[ { "role": "assistant", "content": "Response", "tool_calls": [{"name": "tool"}], } ] ) module._sanitize_context_by_modalities(config, mock_provider, req) assert "tool_calls" not in req.contexts[0] def test_sanitize_removes_image_blocks(self, mock_provider): """Test sanitize removes image blocks when image not supported.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, sanitize_context_by_modalities=True ) mock_provider.provider_config = {"modalities": ["tool_use"]} req = ProviderRequest( contexts=[ { "role": "user", "content": [ {"type": "text", "text": "Hello"}, {"type": "image_url", "url": "image.jpg"}, ], } ] ) module._sanitize_context_by_modalities(config, mock_provider, req) assert len(req.contexts[0]["content"]) == 1 assert req.contexts[0]["content"][0]["type"] == "text" class TestPluginToolFix: """Tests for _plugin_tool_fix function.""" def test_plugin_tool_fix_none_plugins(self, mock_event): """Test plugin tool fix when no plugins specified.""" module = ama req = ProviderRequest(func_tool=ToolSet()) mock_event.plugins_name = None module._plugin_tool_fix(mock_event, req) assert req.func_tool is not None def test_plugin_tool_fix_filters_by_plugin(self, mock_event): """Test plugin tool fix filters tools by enabled plugins.""" module = ama mcp_tool = MagicMock(spec=MCPTool) mcp_tool.name = "mcp_tool" plugin_tool = MagicMock() plugin_tool.name = "plugin_tool" plugin_tool.handler_module_path = "test_plugin" plugin_tool.active = True tool_set = ToolSet() tool_set.add_tool(mcp_tool) tool_set.add_tool(plugin_tool) req = ProviderRequest(func_tool=tool_set) mock_event.plugins_name = ["test_plugin"] with patch("astrbot.core.astr_main_agent.star_map") as mock_star_map: mock_plugin = MagicMock() mock_plugin.name = "test_plugin" mock_plugin.reserved = False mock_star_map.get.return_value = mock_plugin module._plugin_tool_fix(mock_event, req) assert "mcp_tool" in req.func_tool.names() assert "plugin_tool" in req.func_tool.names() def test_plugin_tool_fix_mcp_preserved(self, mock_event): """Test that MCP tools are always preserved.""" module = ama mcp_tool = MagicMock(spec=MCPTool) mcp_tool.name = "mcp_tool" mcp_tool.active = True tool_set = ToolSet() tool_set.add_tool(mcp_tool) req = ProviderRequest(func_tool=tool_set) mock_event.plugins_name = ["other_plugin"] with patch("astrbot.core.astr_main_agent.star_map"): module._plugin_tool_fix(mock_event, req) assert "mcp_tool" in req.func_tool.names() class TestBuildMainAgent: """Tests for build_main_agent function.""" @pytest.mark.asyncio async def test_build_main_agent_basic( self, mock_event, mock_context, mock_provider ): """Test basic main agent building.""" module = ama mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.return_value = mock_provider mock_context.get_config.return_value = {} conv_mgr = mock_context.conversation_manager _setup_conversation_for_build(conv_mgr) with ( patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, patch("astrbot.core.astr_main_agent.AstrAgentContext"), ): mock_runner = MagicMock() mock_runner.reset = AsyncMock() mock_runner_cls.return_value = mock_runner result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig(tool_call_timeout=60), ) assert result is not None assert isinstance(result, module.MainAgentBuildResult) @pytest.mark.asyncio async def test_build_main_agent_no_provider(self, mock_event, mock_context): """Test building main agent when no provider is available.""" module = ama mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.side_effect = ValueError("No provider") result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig(tool_call_timeout=60), ) assert result is None @pytest.mark.asyncio async def test_build_main_agent_with_wake_prefix( self, mock_event, mock_context, mock_provider ): """Test building main agent with wake prefix.""" module = ama mock_event.message_str = "/command" mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.return_value = mock_provider mock_context.get_config.return_value = {} conv_mgr = mock_context.conversation_manager _setup_conversation_for_build(conv_mgr) with ( patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, patch("astrbot.core.astr_main_agent.AstrAgentContext"), ): mock_runner = MagicMock() mock_runner.reset = AsyncMock() mock_runner_cls.return_value = mock_runner result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig( tool_call_timeout=60, provider_wake_prefix="/" ), ) assert result is not None @pytest.mark.asyncio async def test_build_main_agent_no_wake_prefix( self, mock_event, mock_context, mock_provider ): """Test building main agent without matching wake prefix.""" module = ama mock_event.message_str = "hello" mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.return_value = mock_provider result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig( tool_call_timeout=60, provider_wake_prefix="/" ), ) assert result is None @pytest.mark.asyncio async def test_build_main_agent_with_images( self, mock_event, mock_context, mock_provider ): """Test building main agent with image attachments.""" module = ama mock_image = MagicMock(spec=Image) mock_image.convert_to_file_path = AsyncMock(return_value="/path/to/image.jpg") mock_event.message_obj.message = [mock_image] mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.return_value = mock_provider mock_context.get_config.return_value = {} conv_mgr = mock_context.conversation_manager _setup_conversation_for_build(conv_mgr) with ( patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, patch("astrbot.core.astr_main_agent.AstrAgentContext"), ): mock_runner = MagicMock() mock_runner.reset = AsyncMock() mock_runner_cls.return_value = mock_runner result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig(tool_call_timeout=60), ) assert result is not None @pytest.mark.asyncio async def test_build_main_agent_no_prompt_no_images( self, mock_event, mock_context, mock_provider ): """Test building main agent returns None when no prompt or images.""" module = ama mock_event.message_str = "" mock_event.message_obj.message = [] mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.return_value = mock_provider mock_context.get_config.return_value = {} conv_mgr = mock_context.conversation_manager _setup_conversation_for_build(conv_mgr) result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig(tool_call_timeout=60), ) assert result is None @pytest.mark.asyncio async def test_build_main_agent_apply_reset_false( self, mock_event, mock_context, mock_provider ): """Test building main agent without applying reset.""" module = ama mock_context.get_provider_by_id.return_value = None mock_context.get_using_provider.return_value = mock_provider mock_context.get_config.return_value = {} conv_mgr = mock_context.conversation_manager _setup_conversation_for_build(conv_mgr) with ( patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, patch("astrbot.core.astr_main_agent.AstrAgentContext"), ): mock_runner = MagicMock() mock_runner.reset = AsyncMock() mock_runner_cls.return_value = mock_runner result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig(tool_call_timeout=60), apply_reset=False, ) assert result is not None assert result.reset_coro is not None mock_runner.reset.assert_called_once() result.reset_coro.close() @pytest.mark.asyncio async def test_build_main_agent_with_existing_request( self, mock_event, mock_context, mock_provider ): """Test building main agent with existing ProviderRequest.""" module = ama existing_req = ProviderRequest(prompt="Existing prompt") mock_event.get_extra.side_effect = lambda k: ( existing_req if k == "provider_request" else None ) with ( patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls, patch("astrbot.core.astr_main_agent.AstrAgentContext"), ): mock_runner = MagicMock() mock_runner.reset = AsyncMock() mock_runner_cls.return_value = mock_runner result = await module.build_main_agent( event=mock_event, plugin_context=mock_context, config=module.MainAgentBuildConfig(tool_call_timeout=60), provider=mock_provider, req=existing_req, ) assert result is not None assert result.provider_request == existing_req class TestHandleWebchat: """Tests for _handle_webchat function.""" @pytest.mark.asyncio async def test_handle_webchat_generates_title(self, mock_event): """Test generating title for webchat session without display name.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="What is machine learning?") prov = MagicMock(spec=Provider) llm_response = MagicMock() llm_response.completion_text = "Machine Learning Introduction" prov.text_chat = AsyncMock(return_value=llm_response) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_db.get_platform_session_by_id.assert_called_once_with( "webchat-session-123" ) mock_db.update_platform_session.assert_called_once_with( session_id="webchat-session-123", display_name="Machine Learning Introduction", ) @pytest.mark.asyncio async def test_handle_webchat_no_user_prompt(self, mock_event): """Test that title generation is skipped when no user prompt.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt=None) prov = MagicMock(spec=Provider) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) await module._handle_webchat(mock_event, req, prov) prov.text_chat.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_empty_user_prompt(self, mock_event): """Test that title generation is skipped when user prompt is empty.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="") prov = MagicMock(spec=Provider) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) await module._handle_webchat(mock_event, req, prov) prov.text_chat.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_session_already_has_display_name(self, mock_event): """Test that title generation is skipped when session already has display name.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="What is AI?") prov = MagicMock(spec=Provider) mock_session = MagicMock() mock_session.display_name = "Existing Title" with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) await module._handle_webchat(mock_event, req, prov) prov.text_chat.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_no_session_found(self, mock_event): """Test that title generation is skipped when session is not found.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="What is AI?") prov = MagicMock(spec=Provider) with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=None) await module._handle_webchat(mock_event, req, prov) prov.text_chat.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_llm_returns_none_title(self, mock_event): """Test that title is not updated when LLM returns .""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="hi") prov = MagicMock(spec=Provider) llm_response = MagicMock() llm_response.completion_text = "" prov.text_chat = AsyncMock(return_value=llm_response) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_db.update_platform_session.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_llm_returns_empty_title(self, mock_event): """Test that title is not updated when LLM returns empty string.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="hello") prov = MagicMock(spec=Provider) llm_response = MagicMock() llm_response.completion_text = " " prov.text_chat = AsyncMock(return_value=llm_response) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_db.update_platform_session.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_llm_returns_none_response(self, mock_event): """Test handling when LLM returns None response.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="test question") prov = MagicMock(spec=Provider) prov.text_chat = AsyncMock(return_value=None) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_db.update_platform_session.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_llm_returns_no_completion_text(self, mock_event): """Test handling when LLM response has no completion_text.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="test question") prov = MagicMock(spec=Provider) llm_response = MagicMock() llm_response.completion_text = None prov.text_chat = AsyncMock(return_value=llm_response) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_db.update_platform_session.assert_not_called() @pytest.mark.asyncio async def test_handle_webchat_strips_title_whitespace(self, mock_event): """Test that generated title has whitespace stripped.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="What is Python?") prov = MagicMock(spec=Provider) llm_response = MagicMock() llm_response.completion_text = " Python Programming Guide " prov.text_chat = AsyncMock(return_value=llm_response) mock_session = MagicMock() mock_session.display_name = None with patch("astrbot.core.db_helper") as mock_db: mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_db.update_platform_session.assert_called_once_with( session_id="webchat-session-123", display_name="Python Programming Guide", ) @pytest.mark.asyncio async def test_handle_webchat_provider_exception_is_handled(self, mock_event): """Test that provider exception during title generation is handled.""" module = ama mock_event.session_id = "platform!webchat-session-123" req = ProviderRequest(prompt="What is Python?") prov = MagicMock(spec=Provider) prov.text_chat = AsyncMock(side_effect=RuntimeError("provider failed")) mock_session = MagicMock() mock_session.display_name = None with ( patch("astrbot.core.db_helper") as mock_db, patch("astrbot.core.astr_main_agent.logger") as mock_logger, ): mock_db.get_platform_session_by_id = AsyncMock(return_value=mock_session) mock_db.update_platform_session = AsyncMock() await module._handle_webchat(mock_event, req, prov) mock_logger.exception.assert_called_once() mock_db.update_platform_session.assert_not_called() class TestApplyLlmSafetyMode: """Tests for _apply_llm_safety_mode function.""" def test_apply_llm_safety_mode_system_prompt_strategy(self): """Test applying safety mode with system_prompt strategy.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, llm_safety_mode=True, safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt="Original prompt") module._apply_llm_safety_mode(config, req) assert "You are running in Safe Mode" in req.system_prompt assert "Original prompt" in req.system_prompt def test_apply_llm_safety_mode_prepends_safety_prompt(self): """Test that safety prompt is prepended before original system prompt.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt="My custom prompt") module._apply_llm_safety_mode(config, req) assert req.system_prompt.startswith("You are running in Safe Mode") assert "My custom prompt" in req.system_prompt def test_apply_llm_safety_mode_with_none_system_prompt(self): """Test applying safety mode when original system_prompt is None.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt=None) module._apply_llm_safety_mode(config, req) assert "You are running in Safe Mode" in req.system_prompt def test_apply_llm_safety_mode_unsupported_strategy(self): """Test that unsupported strategy logs warning and does nothing.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, safety_mode_strategy="unsupported_strategy", ) req = ProviderRequest(prompt="Test", system_prompt="Original") with patch("astrbot.core.astr_main_agent.logger") as mock_logger: module._apply_llm_safety_mode(config, req) mock_logger.warning.assert_called_once() assert ( "Unsupported llm_safety_mode strategy" in mock_logger.warning.call_args[0][0] ) assert req.system_prompt == "Original" def test_apply_llm_safety_mode_empty_system_prompt(self): """Test applying safety mode when original system_prompt is empty.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, safety_mode_strategy="system_prompt", ) req = ProviderRequest(prompt="Test", system_prompt="") module._apply_llm_safety_mode(config, req) assert "You are running in Safe Mode" in req.system_prompt class TestApplySandboxTools: """Tests for _apply_sandbox_tools function.""" def test_apply_sandbox_tools_creates_toolset_if_none(self): """Test that ToolSet is created when func_tool is None.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", func_tool=None) module._apply_sandbox_tools(config, req, "session-123") assert req.func_tool is not None assert isinstance(req.func_tool, ToolSet) def test_apply_sandbox_tools_adds_required_tools(self): """Test that all required sandbox tools are added.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", func_tool=None) module._apply_sandbox_tools(config, req, "session-123") tool_names = req.func_tool.names() assert "astrbot_execute_shell" in tool_names assert "astrbot_execute_ipython" in tool_names assert "astrbot_upload_file" in tool_names assert "astrbot_download_file" in tool_names def test_apply_sandbox_tools_adds_sandbox_prompt(self): """Test that sandbox mode prompt is added to system_prompt.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", system_prompt="Original prompt") module._apply_sandbox_tools(config, req, "session-123") assert "sandboxed environment" in req.system_prompt def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch): """Test sandbox tools with shipyard booter configuration.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={ "booter": "shipyard", "shipyard_endpoint": "https://shipyard.example.com", "shipyard_access_token": "test-token", }, ) req = ProviderRequest(prompt="Test", func_tool=None) monkeypatch.delenv("SHIPYARD_ENDPOINT", raising=False) monkeypatch.delenv("SHIPYARD_ACCESS_TOKEN", raising=False) module._apply_sandbox_tools(config, req, "session-123") assert os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com" assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token" def test_apply_sandbox_tools_shipyard_missing_endpoint(self): """Test that shipyard config is skipped when endpoint is missing.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={ "booter": "shipyard", "shipyard_endpoint": "", "shipyard_access_token": "test-token", }, ) req = ProviderRequest(prompt="Test", func_tool=None) with patch("astrbot.core.astr_main_agent.logger") as mock_logger: module._apply_sandbox_tools(config, req, "session-123") mock_logger.error.assert_called_once() assert ( "Shipyard sandbox configuration is incomplete" in mock_logger.error.call_args[0][0] ) def test_apply_sandbox_tools_shipyard_missing_access_token(self): """Test that shipyard config is skipped when access token is missing.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={ "booter": "shipyard", "shipyard_endpoint": "https://shipyard.example.com", "shipyard_access_token": "", }, ) req = ProviderRequest(prompt="Test", func_tool=None) with patch("astrbot.core.astr_main_agent.logger") as mock_logger: module._apply_sandbox_tools(config, req, "session-123") mock_logger.error.assert_called_once() def test_apply_sandbox_tools_preserves_existing_toolset(self): """Test that existing tools are preserved when adding sandbox tools.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={}, ) existing_toolset = ToolSet() existing_tool = MagicMock() existing_tool.name = "existing_tool" existing_toolset.add_tool(existing_tool) req = ProviderRequest(prompt="Test", func_tool=existing_toolset) module._apply_sandbox_tools(config, req, "session-123") assert "existing_tool" in req.func_tool.names() assert "astrbot_execute_shell" in req.func_tool.names() def test_apply_sandbox_tools_appends_to_existing_system_prompt(self): """Test that sandbox prompt is appended to existing system prompt.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", system_prompt="Base prompt") module._apply_sandbox_tools(config, req, "session-123") assert req.system_prompt.startswith("Base prompt") assert "sandboxed environment" in req.system_prompt def test_apply_sandbox_tools_with_none_system_prompt(self): """Test that sandbox prompt is applied when system_prompt is None.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", sandbox_cfg={}, ) req = ProviderRequest(prompt="Test", system_prompt=None) module._apply_sandbox_tools(config, req, "session-123") assert isinstance(req.system_prompt, str) assert "sandboxed environment" in req.system_prompt