# astrbbbb/tests/unit/test_astr_main_agent.py
# Uploaded by qa1145 ("Upload 1245 files"), commit 8ede856 (verified).
"""Tests for astr_main_agent module."""
import os
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from astrbot.core import astr_main_agent as ama
from astrbot.core.agent.mcp_client import MCPTool
from astrbot.core.agent.tool import FunctionTool, ToolSet
from astrbot.core.conversation_mgr import Conversation
from astrbot.core.message.components import File, Image, Plain, Reply
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.platform_metadata import PlatformMetadata
from astrbot.core.provider import Provider
from astrbot.core.provider.entities import ProviderRequest
@pytest.fixture
def mock_provider():
    """Provide a ``Provider`` double advertising image and tool-use modalities.

    The double is spec'd against ``Provider`` and reports ``gpt-4`` as its model.
    """
    fake_provider = MagicMock(spec=Provider)
    fake_provider.get_model.return_value = "gpt-4"
    fake_provider.provider_config = {
        "id": "test-provider",
        "modalities": ["image", "tool_use"],
    }
    return fake_provider
@pytest.fixture
def mock_context():
    """Provide a plugin ``Context`` double with empty config and no personas.

    ``persona_manager.resolve_selected_persona`` resolves to
    ``(None, None, None, False)`` and no subagent orchestrator is attached.
    """
    context = MagicMock()
    context.get_config.return_value = {}
    context.conversation_manager = MagicMock()

    personas = MagicMock()
    personas.personas_v3 = []
    personas.resolve_selected_persona = AsyncMock(
        return_value=(None, None, None, False)
    )
    context.persona_manager = personas

    context.get_llm_tool_manager.return_value = MagicMock()
    context.subagent_orchestrator = None
    return context
@pytest.fixture
def mock_event():
    """Provide an ``AstrMessageEvent`` double for a private chat from ``TestUser``.

    The message chain is a single ``Plain("Hello")``. The event is bound to the
    ``test_platform`` platform with session ``session123`` and has no group and
    no plugin filter.
    """
    incoming = MagicMock()
    incoming.message = [Plain(text="Hello")]
    incoming.sender = MagicMock(user_id="user123", nickname="TestUser")
    incoming.group_id = None
    incoming.group = None

    event = MagicMock(spec=AstrMessageEvent)
    event.message_str = "Hello"
    event.message_obj = incoming
    event.platform_meta = PlatformMetadata(
        id="test_platform",
        name="test_platform",
        description="Test platform",
    )
    event.session_id = "session123"
    event.unified_msg_origin = "test_platform:private:session123"

    # Canned return values for the accessor methods the code under test calls.
    canned_returns = {
        "get_extra": None,
        "get_platform_name": "test_platform",
        "get_platform_id": "test_platform",
        "get_group_id": None,
        "get_sender_name": "TestUser",
    }
    for method_name, value in canned_returns.items():
        getattr(event, method_name).return_value = value

    event.trace = MagicMock()
    event.plugins_name = None
    return event
@pytest.fixture
def mock_conversation():
    """Provide a ``Conversation`` double with id ``conv-id`` and empty history.

    Delegates to ``_new_mock_conversation`` so that this fixture and the
    conversations created by the build helpers are always constructed the
    same way (same spec, persona and history defaults).
    """
    return _new_mock_conversation()
@pytest.fixture
def sample_config():
    """Provide a ``MainAgentBuildConfig`` with Moonshot file extraction enabled."""
    settings = {
        "tool_call_timeout": 60,
        "streaming_response": True,
        "file_extract_enabled": True,
        "file_extract_prov": "moonshotai",
        "file_extract_msh_api_key": "test-api-key",
    }
    return ama.MainAgentBuildConfig(**settings)
def _new_mock_conversation(cid: str = "conv-id") -> MagicMock:
    """Build a ``Conversation`` double with id ``cid``, no persona and ``"[]"`` history."""
    conversation = MagicMock(spec=Conversation)
    conversation.cid, conversation.persona_id, conversation.history = cid, None, "[]"
    return conversation
def _setup_conversation_for_build(conv_mgr, cid: str = "conv-id") -> MagicMock:
    """Wire ``conv_mgr`` so that a fresh conversation ``cid`` is created and fetched.

    The manager reports no current conversation, ``new_conversation`` resolves
    to ``cid``, and ``get_conversation`` resolves to a matching conversation
    double, which is also returned to the caller.
    """
    conversation = _new_mock_conversation(cid=cid)
    conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None)
    conv_mgr.new_conversation = AsyncMock(return_value=cid)
    conv_mgr.get_conversation = AsyncMock(return_value=conversation)
    return conversation
class TestMainAgentBuildConfig:
    """Tests for MainAgentBuildConfig dataclass."""

    def test_config_initialization(self):
        """Fields other than ``tool_call_timeout`` fall back to their defaults."""
        cfg = ama.MainAgentBuildConfig(tool_call_timeout=60)

        assert cfg.tool_call_timeout == 60
        assert cfg.tool_schema_mode == "full"
        assert cfg.provider_wake_prefix == ""
        for enabled_flag in ("streaming_response", "llm_safety_mode"):
            assert getattr(cfg, enabled_flag) is True, enabled_flag
        for disabled_flag in (
            "sanitize_context_by_modalities",
            "kb_agentic_mode",
            "file_extract_enabled",
        ):
            assert getattr(cfg, disabled_flag) is False, disabled_flag

    def test_config_with_custom_values(self):
        """Explicit keyword arguments are stored verbatim on the config."""
        overrides = {
            "tool_call_timeout": 120,
            "tool_schema_mode": "skills-like",
            "provider_wake_prefix": "/",
            "streaming_response": False,
            "kb_agentic_mode": True,
            "file_extract_enabled": True,
            "computer_use_runtime": "sandbox",
            "add_cron_tools": False,
        }
        cfg = ama.MainAgentBuildConfig(**overrides)

        for field_name, expected in overrides.items():
            actual = getattr(cfg, field_name)
            # Booleans are checked by identity, everything else by equality.
            if isinstance(expected, bool):
                assert actual is expected, field_name
            else:
                assert actual == expected, field_name
class TestSelectProvider:
    """Tests for _select_provider function."""

    @staticmethod
    def _request_provider(event, provider_id):
        """Make ``event.get_extra("selected_provider")`` yield ``provider_id``."""
        event.get_extra.side_effect = lambda key: (
            provider_id if key == "selected_provider" else None
        )

    def test_select_provider_by_id(self, mock_event, mock_context, mock_provider):
        """An explicitly selected provider id is looked up and returned."""
        self._request_provider(mock_event, "test-provider")
        mock_context.get_provider_by_id.return_value = mock_provider

        selected = ama._select_provider(mock_event, mock_context)

        assert selected == mock_provider
        mock_context.get_provider_by_id.assert_called_once_with("test-provider")

    def test_select_provider_not_found(self, mock_event, mock_context):
        """An unknown provider id yields ``None``."""
        self._request_provider(mock_event, "non-existent")
        mock_context.get_provider_by_id.return_value = None

        assert ama._select_provider(mock_event, mock_context) is None

    def test_select_provider_invalid_type(self, mock_event, mock_context):
        """A lookup result that is not a ``Provider`` instance yields ``None``."""
        self._request_provider(mock_event, "invalid")
        mock_context.get_provider_by_id.return_value = "not a provider"

        assert ama._select_provider(mock_event, mock_context) is None

    def test_select_provider_fallback(self, mock_event, mock_context, mock_provider):
        """Without an explicit selection, the session's in-use provider is used."""
        mock_event.get_extra.return_value = None
        mock_context.get_using_provider.return_value = mock_provider

        selected = ama._select_provider(mock_event, mock_context)

        assert selected == mock_provider
        mock_context.get_using_provider.assert_called_once_with(
            umo=mock_event.unified_msg_origin
        )

    def test_select_provider_fallback_error(self, mock_event, mock_context):
        """A ``ValueError`` from the fallback lookup yields ``None``."""
        mock_event.get_extra.return_value = None
        mock_context.get_using_provider.side_effect = ValueError("Test error")

        assert ama._select_provider(mock_event, mock_context) is None
class TestGetSessionConv:
    """Tests for _get_session_conv function."""

    @pytest.mark.asyncio
    async def test_get_session_conv_existing(
        self, mock_event, mock_context, mock_conversation
    ):
        """The session's current conversation is fetched when one exists."""
        conv_mgr = mock_context.conversation_manager
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value="existing-conv-id")
        conv_mgr.get_conversation = AsyncMock(return_value=mock_conversation)

        result = await ama._get_session_conv(mock_event, mock_context)

        assert result == mock_conversation
        conv_mgr.get_curr_conversation_id.assert_called_once_with(
            mock_event.unified_msg_origin
        )
        conv_mgr.get_conversation.assert_called_once_with(
            mock_event.unified_msg_origin, "existing-conv-id"
        )

    @pytest.mark.asyncio
    async def test_get_session_conv_create_new(self, mock_event, mock_context):
        """A new conversation is created when the session has none yet."""
        conv_mgr = mock_context.conversation_manager
        # No current id -> new_conversation returns "new-conv-id" -> fetched.
        new_conversation = _setup_conversation_for_build(conv_mgr, cid="new-conv-id")

        result = await ama._get_session_conv(mock_event, mock_context)

        assert result == new_conversation
        conv_mgr.new_conversation.assert_called_once_with(
            mock_event.unified_msg_origin, mock_event.get_platform_id()
        )

    @pytest.mark.asyncio
    async def test_get_session_conv_retry(self, mock_event, mock_context):
        """A stale current id triggers one re-creation followed by a re-fetch."""
        conv_mgr = mock_context.conversation_manager
        retry_conversation = _new_mock_conversation(cid="retry-conv-id")
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value="conv-id")
        # First lookup misses (stale id); the lookup after re-creation succeeds.
        conv_mgr.get_conversation = AsyncMock(side_effect=[None, retry_conversation])
        conv_mgr.new_conversation = AsyncMock(return_value="retry-conv-id")

        result = await ama._get_session_conv(mock_event, mock_context)

        assert result == retry_conversation
        assert conv_mgr.new_conversation.call_count == 1
        assert conv_mgr.get_conversation.call_count == 2

    @pytest.mark.asyncio
    async def test_get_session_conv_failure(self, mock_event, mock_context):
        """A ``RuntimeError`` is raised when even a new conversation cannot be fetched."""
        conv_mgr = mock_context.conversation_manager
        conv_mgr.get_curr_conversation_id = AsyncMock(return_value=None)
        conv_mgr.new_conversation = AsyncMock(return_value="new-conv-id")
        conv_mgr.get_conversation = AsyncMock(return_value=None)

        with pytest.raises(RuntimeError, match="无法创建新的对话。"):
            await ama._get_session_conv(mock_event, mock_context)
class TestApplyKb:
    """Tests for _apply_kb function."""

    @pytest.mark.asyncio
    async def test_apply_kb_without_agentic_mode(self, mock_event, mock_context):
        """Non-agentic mode injects the retrieval result into the system prompt."""
        req = ProviderRequest(prompt="test question", system_prompt="System prompt")
        config = ama.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=False)

        with patch(
            "astrbot.core.astr_main_agent.retrieve_knowledge_base",
            AsyncMock(return_value="KB result"),
        ):
            await ama._apply_kb(mock_event, req, mock_context, config)

        # The original system prompt must survive alongside the injected result.
        assert "System prompt" in req.system_prompt
        assert "[Related Knowledge Base Results]:" in req.system_prompt
        assert "KB result" in req.system_prompt

    @pytest.mark.asyncio
    async def test_apply_kb_with_agentic_mode(self, mock_event, mock_context):
        """Agentic mode exposes the knowledge base as a tool instead."""
        req = ProviderRequest(prompt="test question")
        config = ama.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=True)

        await ama._apply_kb(mock_event, req, mock_context, config)

        assert req.func_tool is not None
        assert req.func_tool.names()

    @pytest.mark.asyncio
    async def test_apply_kb_no_prompt(self, mock_event, mock_context):
        """Without a prompt there is nothing to query, so retrieval is skipped."""
        req = ProviderRequest(prompt=None, system_prompt="System")
        config = ama.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=False)

        # Patch retrieval so the test stays hermetic and can prove it was skipped.
        with patch(
            "astrbot.core.astr_main_agent.retrieve_knowledge_base",
            AsyncMock(return_value="KB result"),
        ) as mock_retrieve:
            await ama._apply_kb(mock_event, req, mock_context, config)

        mock_retrieve.assert_not_awaited()
        assert req.system_prompt == "System"

    @pytest.mark.asyncio
    async def test_apply_kb_no_result(self, mock_event, mock_context):
        """An empty retrieval result leaves the system prompt untouched."""
        req = ProviderRequest(prompt="test", system_prompt="System")
        config = ama.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=False)

        with patch(
            "astrbot.core.astr_main_agent.retrieve_knowledge_base",
            AsyncMock(return_value=None),
        ):
            await ama._apply_kb(mock_event, req, mock_context, config)

        assert req.system_prompt == "System"

    @pytest.mark.asyncio
    async def test_apply_kb_with_existing_tools(self, mock_event, mock_context):
        """Agentic mode adds the KB tool to an already-present toolset."""
        existing_tools = ToolSet()
        req = ProviderRequest(prompt="test", func_tool=existing_tools)
        config = ama.MainAgentBuildConfig(tool_call_timeout=60, kb_agentic_mode=True)

        await ama._apply_kb(mock_event, req, mock_context, config)

        # ``func_tool`` was non-None before the call, so check that the existing
        # toolset was reused and actually gained a tool.
        assert req.func_tool is existing_tools
        assert req.func_tool.names()
class TestApplyFileExtract:
    """Tests for _apply_file_extract function."""

    @staticmethod
    def _mock_file(name, path):
        """Return a ``File`` double named ``name`` whose ``get_file()`` resolves to ``path``."""
        file_comp = MagicMock(spec=File)
        file_comp.name = name
        file_comp.get_file = AsyncMock(return_value=path)
        return file_comp

    @pytest.mark.asyncio
    async def test_file_extract_basic(self, mock_event, sample_config):
        """A file in the message is extracted and added as a context entry."""
        mock_event.message_obj.message = [
            self._mock_file("test.pdf", "/path/to/test.pdf")
        ]
        req = ProviderRequest(prompt="Summarize")

        with patch(
            "astrbot.core.astr_main_agent.extract_file_moonshotai"
        ) as mock_extract:
            mock_extract.return_value = "File content"
            await ama._apply_file_extract(mock_event, req, sample_config)

        mock_extract.assert_called_once()
        assert len(req.contexts) == 1
        assert "File Extract Results" in req.contexts[0]["content"]

    @pytest.mark.asyncio
    async def test_file_extract_no_files(self, mock_event, sample_config):
        """A message without files adds no context and never calls the extractor."""
        mock_event.message_obj.message = [Plain(text="Hello")]
        req = ProviderRequest(prompt="Hello")

        # Patched so a regression cannot reach the real Moonshot API.
        with patch(
            "astrbot.core.astr_main_agent.extract_file_moonshotai"
        ) as mock_extract:
            await ama._apply_file_extract(mock_event, req, sample_config)

        mock_extract.assert_not_called()
        assert len(req.contexts) == 0

    @pytest.mark.asyncio
    async def test_file_extract_in_reply(self, mock_event, sample_config):
        """Files nested inside a quoted reply are extracted too."""
        mock_reply = MagicMock(spec=Reply)
        mock_reply.chain = [self._mock_file("reply.pdf", "/path/to/reply.pdf")]
        mock_event.message_obj.message = [mock_reply]
        req = ProviderRequest(prompt="Summarize")

        with patch(
            "astrbot.core.astr_main_agent.extract_file_moonshotai"
        ) as mock_extract:
            mock_extract.return_value = "Reply content"
            await ama._apply_file_extract(mock_event, req, sample_config)

        assert len(req.contexts) == 1

    @pytest.mark.asyncio
    async def test_file_extract_no_prompt(self, mock_event, sample_config):
        """An empty prompt is replaced with the default summarisation request."""
        mock_event.message_obj.message = [
            self._mock_file("test.pdf", "/path/to/test.pdf")
        ]
        req = ProviderRequest(prompt=None)

        with patch(
            "astrbot.core.astr_main_agent.extract_file_moonshotai"
        ) as mock_extract:
            mock_extract.return_value = "Content"
            await ama._apply_file_extract(mock_event, req, sample_config)

        assert req.prompt == "总结一下文件里面讲了什么?"

    @pytest.mark.asyncio
    async def test_file_extract_no_api_key(self, mock_event):
        """Without an API key nothing is extracted and no context is added."""
        config = ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            file_extract_enabled=True,
            file_extract_msh_api_key="",
        )
        mock_event.message_obj.message = [
            self._mock_file("test.pdf", "/path/to/test.pdf")
        ]
        req = ProviderRequest(prompt="Summarize")

        # Patched so a regression cannot reach the real Moonshot API.
        with patch(
            "astrbot.core.astr_main_agent.extract_file_moonshotai"
        ) as mock_extract:
            await ama._apply_file_extract(mock_event, req, config)

        mock_extract.assert_not_called()
        assert len(req.contexts) == 0
class TestEnsurePersonaAndSkills:
    """Tests for _ensure_persona_and_skills function."""

    @pytest.mark.asyncio
    async def test_ensure_persona_from_session(self, mock_event, mock_context):
        """A persona resolved from the session config is written into the system prompt."""
        persona = {"name": "test-persona", "prompt": "You are helpful."}
        mock_context.persona_manager.personas_v3 = [persona]
        mock_context.persona_manager.resolve_selected_persona = AsyncMock(
            return_value=("test-persona", persona, "test-persona", False)
        )
        # ``mock_event.trace`` is already a MagicMock (with ``record``) from the fixture.
        req = ProviderRequest()
        req.conversation = MagicMock(persona_id=None)

        await ama._ensure_persona_and_skills(req, {}, mock_context, mock_event)

        assert "You are helpful." in req.system_prompt

    @pytest.mark.asyncio
    async def test_ensure_persona_from_conversation(self, mock_event, mock_context):
        """A persona chosen on the conversation is written into the system prompt."""
        persona = {"name": "conv-persona", "prompt": "Custom persona."}
        mock_context.persona_manager.personas_v3 = [persona]
        mock_context.persona_manager.resolve_selected_persona = AsyncMock(
            return_value=("conv-persona", persona, None, False)
        )
        req = ProviderRequest()
        req.conversation = MagicMock(persona_id="conv-persona")

        await ama._ensure_persona_and_skills(req, {}, mock_context, mock_event)

        assert "Custom persona." in req.system_prompt

    @pytest.mark.asyncio
    async def test_ensure_persona_none_explicit(self, mock_event, mock_context):
        """The ``[%None]`` persona id explicitly disables persona instructions."""
        mock_context.persona_manager.personas_v3 = []
        mock_context.persona_manager.resolve_selected_persona = AsyncMock(
            return_value=("[%None]", None, None, False)
        )
        req = ProviderRequest()
        req.conversation = MagicMock(persona_id="[%None]")

        await ama._ensure_persona_and_skills(req, {}, mock_context, mock_event)

        assert "Persona Instructions" not in req.system_prompt

    @pytest.mark.asyncio
    async def test_ensure_tools_from_persona(self, mock_event, mock_context):
        """Tools listed on the persona are resolved and exposed on the request."""
        mock_tool = MagicMock()
        mock_tool.name = "test_tool"
        mock_tool.active = True
        persona = {"name": "persona", "prompt": "Test", "tools": ["test_tool"]}
        mock_context.persona_manager.personas_v3 = [persona]
        mock_context.persona_manager.resolve_selected_persona = AsyncMock(
            return_value=("persona", persona, None, False)
        )
        tool_manager = mock_context.get_llm_tool_manager.return_value
        tool_manager.get_func.return_value = mock_tool
        req = ProviderRequest()
        req.conversation = MagicMock(persona_id="persona")

        await ama._ensure_persona_and_skills(req, {}, mock_context, mock_event)

        assert req.func_tool is not None
        # Checks the persona's tool really landed in the toolset, not just that one exists.
        assert "test_tool" in req.func_tool.names()
class TestDecorateLlmRequest:
    """Tests for _decorate_llm_request function.

    ``mock_context.get_config()`` already returns ``{}`` via the fixture, so
    no extra patching is needed to keep global config out of the way.
    """

    @pytest.mark.asyncio
    async def test_decorate_llm_request_basic(
        self, mock_event, mock_context, sample_config
    ):
        """Without any prefix settings the prompt and system prompt are unchanged."""
        req = ProviderRequest(prompt="Hello", system_prompt="System")

        await ama._decorate_llm_request(mock_event, req, mock_context, sample_config)

        assert req.prompt == "Hello"
        assert req.system_prompt == "System"

    @pytest.mark.asyncio
    async def test_decorate_llm_request_with_prefix(self, mock_event, mock_context):
        """A plain ``prompt_prefix`` is prepended to the prompt."""
        req = ProviderRequest(prompt="Hello")
        config = ama.MainAgentBuildConfig(
            tool_call_timeout=60, provider_settings={"prompt_prefix": "AI: "}
        )

        await ama._decorate_llm_request(mock_event, req, mock_context, config)

        assert req.prompt == "AI: Hello"

    @pytest.mark.asyncio
    async def test_decorate_llm_request_prefix_with_placeholder(
        self, mock_event, mock_context
    ):
        """A ``{{prompt}}`` placeholder in the prefix is substituted with the prompt."""
        req = ProviderRequest(prompt="Hello")
        config = ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            provider_settings={"prompt_prefix": "AI {{prompt}} - Please respond:"},
        )

        await ama._decorate_llm_request(mock_event, req, mock_context, config)

        assert req.prompt == "AI Hello - Please respond:"

    @pytest.mark.asyncio
    async def test_decorate_llm_request_no_conversation(self, mock_event, mock_context):
        """A request with no conversation attached is decorated without errors."""
        req = ProviderRequest(prompt="Hello")
        req.conversation = None
        config = ama.MainAgentBuildConfig(tool_call_timeout=60)

        await ama._decorate_llm_request(mock_event, req, mock_context, config)

        assert req.prompt == "Hello"
class TestModalitiesFix:
    """Tests for _modalities_fix function."""

    @staticmethod
    def _dummy_toolset():
        """Return a ToolSet holding a single no-argument ``dummy_tool``."""
        tools = ToolSet()
        tools.add_tool(
            FunctionTool(
                name="dummy_tool",
                description="dummy",
                parameters={"type": "object", "properties": {}},
            )
        )
        return tools

    def test_modalities_fix_image_not_supported(self, mock_provider):
        """Images are replaced by a textual placeholder for text-only providers."""
        mock_provider.provider_config = {"modalities": ["text"]}
        req = ProviderRequest(prompt="Hello", image_urls=["/path/to/image.jpg"])

        ama._modalities_fix(mock_provider, req)

        assert "[图片]" in req.prompt
        assert req.image_urls == []

    def test_modalities_fix_tool_not_supported(self, mock_provider):
        """Tools are dropped for providers lacking the ``tool_use`` modality."""
        mock_provider.provider_config = {"modalities": ["text", "image"]}
        req = ProviderRequest(prompt="Hello")
        req.func_tool = self._dummy_toolset()

        ama._modalities_fix(mock_provider, req)

        assert req.func_tool is None

    def test_modalities_fix_all_supported(self, mock_provider):
        """Nothing is altered when the provider supports every modality used."""
        mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
        req = ProviderRequest(
            prompt="Hello",
            image_urls=["/path/to/image.jpg"],
            func_tool=self._dummy_toolset(),
        )

        ama._modalities_fix(mock_provider, req)

        assert req.prompt == "Hello"
        assert len(req.image_urls) == 1
        assert req.func_tool is not None
class TestSanitizeContextByModalities:
    """Tests for _sanitize_context_by_modalities function."""

    @staticmethod
    def _config(enabled=True):
        """Build a config with context sanitisation switched on or off."""
        return ama.MainAgentBuildConfig(
            tool_call_timeout=60, sanitize_context_by_modalities=enabled
        )

    def test_sanitize_no_op(self, mock_provider):
        """Contexts are untouched when sanitisation is disabled."""
        mock_provider.provider_config = {"modalities": ["image", "tool_use"]}
        req = ProviderRequest(contexts=[{"role": "user", "content": "Hello"}])

        ama._sanitize_context_by_modalities(
            self._config(enabled=False), mock_provider, req
        )

        assert len(req.contexts) == 1

    def test_sanitize_removes_tool_messages(self, mock_provider):
        """``tool`` role messages are dropped when ``tool_use`` is unsupported."""
        mock_provider.provider_config = {"modalities": ["image"]}
        user_turn = {"role": "user", "content": "Hello"}
        tool_turn = {"role": "tool", "content": "Tool result"}
        req = ProviderRequest(contexts=[user_turn, tool_turn])

        ama._sanitize_context_by_modalities(self._config(), mock_provider, req)

        assert len(req.contexts) == 1
        assert req.contexts[0]["role"] == "user"

    def test_sanitize_removes_tool_calls(self, mock_provider):
        """``tool_calls`` is stripped from assistant turns when tools are unsupported."""
        mock_provider.provider_config = {"modalities": ["image"]}
        assistant_turn = {
            "role": "assistant",
            "content": "Response",
            "tool_calls": [{"name": "tool"}],
        }
        req = ProviderRequest(contexts=[assistant_turn])

        ama._sanitize_context_by_modalities(self._config(), mock_provider, req)

        assert "tool_calls" not in req.contexts[0]

    def test_sanitize_removes_image_blocks(self, mock_provider):
        """Image parts are dropped from multi-part content when images are unsupported."""
        mock_provider.provider_config = {"modalities": ["tool_use"]}
        mixed_parts = [
            {"type": "text", "text": "Hello"},
            {"type": "image_url", "url": "image.jpg"},
        ]
        req = ProviderRequest(contexts=[{"role": "user", "content": mixed_parts}])

        ama._sanitize_context_by_modalities(self._config(), mock_provider, req)

        remaining = req.contexts[0]["content"]
        assert len(remaining) == 1
        assert remaining[0]["type"] == "text"
class TestPluginToolFix:
    """Tests for _plugin_tool_fix function."""

    @staticmethod
    def _plugin_tool(name, module_path):
        """Return an active plugin-tool double owned by the module ``module_path``."""
        tool = MagicMock()
        tool.name = name
        tool.handler_module_path = module_path
        tool.active = True
        return tool

    @staticmethod
    def _plugin(name, reserved=False):
        """Return a plugin-metadata double with the given name and reserved flag."""
        plugin = MagicMock()
        plugin.name = name
        plugin.reserved = reserved
        return plugin

    def test_plugin_tool_fix_none_plugins(self, mock_event):
        """Without a plugin filter on the event the toolset is left untouched."""
        original_tools = ToolSet()
        req = ProviderRequest(func_tool=original_tools)
        mock_event.plugins_name = None

        ama._plugin_tool_fix(mock_event, req)

        # ``is not None`` was already true before the call; check it is unchanged.
        assert req.func_tool is original_tools

    def test_plugin_tool_fix_filters_by_plugin(self, mock_event):
        """Tools of enabled plugins and MCP tools are kept; others are dropped."""
        mcp_tool = MagicMock(spec=MCPTool)
        mcp_tool.name = "mcp_tool"
        enabled_tool = self._plugin_tool("plugin_tool", "test_plugin")
        disabled_tool = self._plugin_tool("disabled_tool", "disabled_plugin")
        tool_set = ToolSet()
        for tool in (mcp_tool, enabled_tool, disabled_tool):
            tool_set.add_tool(tool)
        req = ProviderRequest(func_tool=tool_set)
        mock_event.plugins_name = ["test_plugin"]
        plugins_by_module = {
            "test_plugin": self._plugin("test_plugin"),
            "disabled_plugin": self._plugin("disabled_plugin"),
        }

        with patch("astrbot.core.astr_main_agent.star_map") as mock_star_map:
            mock_star_map.get.side_effect = plugins_by_module.get
            ama._plugin_tool_fix(mock_event, req)

        kept = req.func_tool.names()
        assert "mcp_tool" in kept
        assert "plugin_tool" in kept
        # Without a tool from a non-enabled plugin, filtering was never exercised.
        assert "disabled_tool" not in kept

    def test_plugin_tool_fix_mcp_preserved(self, mock_event):
        """MCP tools survive even when their name matches no enabled plugin."""
        mcp_tool = MagicMock(spec=MCPTool)
        mcp_tool.name = "mcp_tool"
        mcp_tool.active = True
        tool_set = ToolSet()
        tool_set.add_tool(mcp_tool)
        req = ProviderRequest(func_tool=tool_set)
        mock_event.plugins_name = ["other_plugin"]

        with patch("astrbot.core.astr_main_agent.star_map"):
            ama._plugin_tool_fix(mock_event, req)

        assert "mcp_tool" in req.func_tool.names()
class TestBuildMainAgent:
    """Tests for build_main_agent function."""

    @staticmethod
    def _use_provider(ctx, provider):
        """Route provider lookup to ``provider`` and prepare a fresh conversation."""
        ctx.get_provider_by_id.return_value = None
        ctx.get_using_provider.return_value = provider
        ctx.get_config.return_value = {}
        _setup_conversation_for_build(ctx.conversation_manager)

    @staticmethod
    def _stub_runner(runner_cls):
        """Make the patched ``AgentRunner`` class return a runner with an async ``reset``."""
        runner = MagicMock()
        runner.reset = AsyncMock()
        runner_cls.return_value = runner
        return runner

    @pytest.mark.asyncio
    async def test_build_main_agent_basic(
        self, mock_event, mock_context, mock_provider
    ):
        """A plain text message yields a ``MainAgentBuildResult``."""
        self._use_provider(mock_context, mock_provider)

        with (
            patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
            patch("astrbot.core.astr_main_agent.AstrAgentContext"),
        ):
            self._stub_runner(mock_runner_cls)
            result = await ama.build_main_agent(
                event=mock_event,
                plugin_context=mock_context,
                config=ama.MainAgentBuildConfig(tool_call_timeout=60),
            )

        assert result is not None
        assert isinstance(result, ama.MainAgentBuildResult)

    @pytest.mark.asyncio
    async def test_build_main_agent_no_provider(self, mock_event, mock_context):
        """No agent is built when no provider can be resolved."""
        mock_context.get_provider_by_id.return_value = None
        mock_context.get_using_provider.side_effect = ValueError("No provider")

        result = await ama.build_main_agent(
            event=mock_event,
            plugin_context=mock_context,
            config=ama.MainAgentBuildConfig(tool_call_timeout=60),
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_build_main_agent_with_wake_prefix(
        self, mock_event, mock_context, mock_provider
    ):
        """A message starting with the wake prefix builds an agent."""
        mock_event.message_str = "/command"
        self._use_provider(mock_context, mock_provider)

        with (
            patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
            patch("astrbot.core.astr_main_agent.AstrAgentContext"),
        ):
            self._stub_runner(mock_runner_cls)
            result = await ama.build_main_agent(
                event=mock_event,
                plugin_context=mock_context,
                config=ama.MainAgentBuildConfig(
                    tool_call_timeout=60, provider_wake_prefix="/"
                ),
            )

        assert result is not None

    @pytest.mark.asyncio
    async def test_build_main_agent_no_wake_prefix(
        self, mock_event, mock_context, mock_provider
    ):
        """A message lacking the configured wake prefix builds nothing."""
        mock_event.message_str = "hello"
        mock_context.get_provider_by_id.return_value = None
        mock_context.get_using_provider.return_value = mock_provider

        result = await ama.build_main_agent(
            event=mock_event,
            plugin_context=mock_context,
            config=ama.MainAgentBuildConfig(
                tool_call_timeout=60, provider_wake_prefix="/"
            ),
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_build_main_agent_with_images(
        self, mock_event, mock_context, mock_provider
    ):
        """An image-only message still builds an agent."""
        mock_image = MagicMock(spec=Image)
        mock_image.convert_to_file_path = AsyncMock(return_value="/path/to/image.jpg")
        mock_event.message_obj.message = [mock_image]
        self._use_provider(mock_context, mock_provider)

        with (
            patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
            patch("astrbot.core.astr_main_agent.AstrAgentContext"),
        ):
            self._stub_runner(mock_runner_cls)
            result = await ama.build_main_agent(
                event=mock_event,
                plugin_context=mock_context,
                config=ama.MainAgentBuildConfig(tool_call_timeout=60),
            )

        assert result is not None

    @pytest.mark.asyncio
    async def test_build_main_agent_no_prompt_no_images(
        self, mock_event, mock_context, mock_provider
    ):
        """A message with neither text nor images builds nothing."""
        mock_event.message_str = ""
        mock_event.message_obj.message = []
        self._use_provider(mock_context, mock_provider)

        result = await ama.build_main_agent(
            event=mock_event,
            plugin_context=mock_context,
            config=ama.MainAgentBuildConfig(tool_call_timeout=60),
        )

        assert result is None

    @pytest.mark.asyncio
    async def test_build_main_agent_apply_reset_false(
        self, mock_event, mock_context, mock_provider
    ):
        """With ``apply_reset=False`` the reset coroutine is returned un-awaited."""
        self._use_provider(mock_context, mock_provider)

        with (
            patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
            patch("astrbot.core.astr_main_agent.AstrAgentContext"),
        ):
            mock_runner = self._stub_runner(mock_runner_cls)
            result = await ama.build_main_agent(
                event=mock_event,
                plugin_context=mock_context,
                config=ama.MainAgentBuildConfig(tool_call_timeout=60),
                apply_reset=False,
            )

        try:
            assert result is not None
            assert result.reset_coro is not None
            mock_runner.reset.assert_called_once()
            # The coroutine was created but handed back, not awaited.
            mock_runner.reset.assert_not_awaited()
        finally:
            # Always close the coroutine, even if an assertion above failed;
            # otherwise it is garbage-collected un-awaited and warns.
            if result is not None and result.reset_coro is not None:
                result.reset_coro.close()

    @pytest.mark.asyncio
    async def test_build_main_agent_with_existing_request(
        self, mock_event, mock_context, mock_provider
    ):
        """An explicitly passed ``ProviderRequest`` is carried through to the result."""
        existing_req = ProviderRequest(prompt="Existing prompt")
        mock_event.get_extra.side_effect = lambda k: (
            existing_req if k == "provider_request" else None
        )

        with (
            patch("astrbot.core.astr_main_agent.AgentRunner") as mock_runner_cls,
            patch("astrbot.core.astr_main_agent.AstrAgentContext"),
        ):
            self._stub_runner(mock_runner_cls)
            result = await ama.build_main_agent(
                event=mock_event,
                plugin_context=mock_context,
                config=ama.MainAgentBuildConfig(tool_call_timeout=60),
                provider=mock_provider,
                req=existing_req,
            )

        assert result is not None
        assert result.provider_request == existing_req
class TestHandleWebchat:
    """Tests for _handle_webchat function.

    Events use a ``platform!``-prefixed session id; the DB helper is expected
    to be queried and updated with the id that follows the ``!``.
    """

    EVENT_SESSION_ID = "platform!webchat-session-123"
    DB_SESSION_ID = "webchat-session-123"

    @staticmethod
    def _llm_response(completion_text):
        """Return a fake LLM response exposing only ``completion_text``."""
        response = MagicMock()
        response.completion_text = completion_text
        return response

    @staticmethod
    def _make_provider(response=None, side_effect=None):
        """Return a Provider mock whose async ``text_chat`` yields *response*.

        When *side_effect* is set (e.g. an exception instance), it takes
        precedence over *response*, following ``AsyncMock`` semantics.
        """
        prov = MagicMock(spec=Provider)
        prov.text_chat = AsyncMock(return_value=response, side_effect=side_effect)
        return prov

    @staticmethod
    def _make_session(display_name=None):
        """Return a fake platform session with the given ``display_name``."""
        session = MagicMock()
        session.display_name = display_name
        return session

    async def _run_handle_webchat(self, mock_event, prompt, prov, session):
        """Run ``_handle_webchat`` against a patched ``db_helper``.

        Sets the event's session id, makes the patched DB return *session*
        from ``get_platform_session_by_id``, and returns the DB mock for
        assertions. ``update_platform_session`` is always an ``AsyncMock`` so
        an unexpected awaited call is recorded (and can be asserted against)
        rather than failing on a non-awaitable MagicMock, which exception
        handling in the code under test might hide.
        """
        mock_event.session_id = self.EVENT_SESSION_ID
        req = ProviderRequest(prompt=prompt)
        with patch("astrbot.core.db_helper") as mock_db:
            mock_db.get_platform_session_by_id = AsyncMock(return_value=session)
            mock_db.update_platform_session = AsyncMock()
            await ama._handle_webchat(mock_event, req, prov)
        return mock_db

    @staticmethod
    def _assert_title_generation_skipped(prov, mock_db):
        """Assert the LLM was not consulted and the session was not renamed."""
        prov.text_chat.assert_not_called()
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_generates_title(self, mock_event):
        """Test generating title for webchat session without display name."""
        prov = self._make_provider(
            self._llm_response("Machine Learning Introduction")
        )
        mock_db = await self._run_handle_webchat(
            mock_event, "What is machine learning?", prov, self._make_session()
        )
        mock_db.get_platform_session_by_id.assert_called_once_with(
            self.DB_SESSION_ID
        )
        mock_db.update_platform_session.assert_called_once_with(
            session_id=self.DB_SESSION_ID,
            display_name="Machine Learning Introduction",
        )

    @pytest.mark.asyncio
    async def test_handle_webchat_no_user_prompt(self, mock_event):
        """Test that title generation is skipped when there is no user prompt."""
        prov = self._make_provider()
        mock_db = await self._run_handle_webchat(
            mock_event, None, prov, self._make_session()
        )
        self._assert_title_generation_skipped(prov, mock_db)

    @pytest.mark.asyncio
    async def test_handle_webchat_empty_user_prompt(self, mock_event):
        """Test that title generation is skipped when user prompt is empty."""
        prov = self._make_provider()
        mock_db = await self._run_handle_webchat(
            mock_event, "", prov, self._make_session()
        )
        self._assert_title_generation_skipped(prov, mock_db)

    @pytest.mark.asyncio
    async def test_handle_webchat_session_already_has_display_name(self, mock_event):
        """Test that title generation is skipped when session already has display name."""
        prov = self._make_provider()
        mock_db = await self._run_handle_webchat(
            mock_event,
            "What is AI?",
            prov,
            self._make_session(display_name="Existing Title"),
        )
        self._assert_title_generation_skipped(prov, mock_db)

    @pytest.mark.asyncio
    async def test_handle_webchat_no_session_found(self, mock_event):
        """Test that title generation is skipped when session is not found."""
        prov = self._make_provider()
        mock_db = await self._run_handle_webchat(
            mock_event, "What is AI?", prov, None
        )
        self._assert_title_generation_skipped(prov, mock_db)

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_none_title(self, mock_event):
        """Test that title is not updated when LLM returns <None>."""
        prov = self._make_provider(self._llm_response("<None>"))
        mock_db = await self._run_handle_webchat(
            mock_event, "hi", prov, self._make_session()
        )
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_empty_title(self, mock_event):
        """Test that title is not updated when LLM returns a blank (whitespace-only) string."""
        prov = self._make_provider(self._llm_response(" "))
        mock_db = await self._run_handle_webchat(
            mock_event, "hello", prov, self._make_session()
        )
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_none_response(self, mock_event):
        """Test handling when LLM returns None response."""
        prov = self._make_provider(None)
        mock_db = await self._run_handle_webchat(
            mock_event, "test question", prov, self._make_session()
        )
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_llm_returns_no_completion_text(self, mock_event):
        """Test handling when LLM response has no completion_text."""
        prov = self._make_provider(self._llm_response(None))
        mock_db = await self._run_handle_webchat(
            mock_event, "test question", prov, self._make_session()
        )
        mock_db.update_platform_session.assert_not_called()

    @pytest.mark.asyncio
    async def test_handle_webchat_strips_title_whitespace(self, mock_event):
        """Test that generated title has surrounding whitespace stripped."""
        prov = self._make_provider(
            self._llm_response(" Python Programming Guide ")
        )
        mock_db = await self._run_handle_webchat(
            mock_event, "What is Python?", prov, self._make_session()
        )
        mock_db.update_platform_session.assert_called_once_with(
            session_id=self.DB_SESSION_ID,
            display_name="Python Programming Guide",
        )

    @pytest.mark.asyncio
    async def test_handle_webchat_provider_exception_is_handled(self, mock_event):
        """Test that provider exception during title generation is handled."""
        prov = self._make_provider(side_effect=RuntimeError("provider failed"))
        with patch("astrbot.core.astr_main_agent.logger") as mock_logger:
            mock_db = await self._run_handle_webchat(
                mock_event, "What is Python?", prov, self._make_session()
            )
        mock_logger.exception.assert_called_once()
        mock_db.update_platform_session.assert_not_called()
class TestApplyLlmSafetyMode:
    """Tests for _apply_llm_safety_mode function."""

    SAFE_MODE_MARKER = "You are running in Safe Mode"

    @staticmethod
    def _config(strategy="system_prompt"):
        """Build a config with safety mode enabled and the given *strategy*.

        Every test enables ``llm_safety_mode`` explicitly so the tests model
        a realistic configuration regardless of the field's default value.
        """
        return ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            llm_safety_mode=True,
            safety_mode_strategy=strategy,
        )

    def test_apply_llm_safety_mode_system_prompt_strategy(self):
        """Test applying safety mode with system_prompt strategy."""
        req = ProviderRequest(prompt="Test", system_prompt="Original prompt")
        ama._apply_llm_safety_mode(self._config(), req)
        assert self.SAFE_MODE_MARKER in req.system_prompt
        assert "Original prompt" in req.system_prompt

    def test_apply_llm_safety_mode_prepends_safety_prompt(self):
        """Test that safety prompt is prepended before original system prompt."""
        req = ProviderRequest(prompt="Test", system_prompt="My custom prompt")
        ama._apply_llm_safety_mode(self._config(), req)
        assert req.system_prompt.startswith(self.SAFE_MODE_MARKER)
        assert "My custom prompt" in req.system_prompt

    def test_apply_llm_safety_mode_with_none_system_prompt(self):
        """Test applying safety mode when original system_prompt is None."""
        req = ProviderRequest(prompt="Test", system_prompt=None)
        ama._apply_llm_safety_mode(self._config(), req)
        assert self.SAFE_MODE_MARKER in req.system_prompt

    def test_apply_llm_safety_mode_unsupported_strategy(self):
        """Test that unsupported strategy logs a warning and leaves the prompt unchanged."""
        req = ProviderRequest(prompt="Test", system_prompt="Original")
        with patch("astrbot.core.astr_main_agent.logger") as mock_logger:
            ama._apply_llm_safety_mode(self._config("unsupported_strategy"), req)
        mock_logger.warning.assert_called_once()
        assert (
            "Unsupported llm_safety_mode strategy"
            in mock_logger.warning.call_args[0][0]
        )
        assert req.system_prompt == "Original"

    def test_apply_llm_safety_mode_empty_system_prompt(self):
        """Test applying safety mode when original system_prompt is empty."""
        req = ProviderRequest(prompt="Test", system_prompt="")
        ama._apply_llm_safety_mode(self._config(), req)
        assert self.SAFE_MODE_MARKER in req.system_prompt
class TestApplySandboxTools:
    """Tests for _apply_sandbox_tools function."""

    SESSION_ID = "session-123"
    SHIPYARD_ENV_VARS = ("SHIPYARD_ENDPOINT", "SHIPYARD_ACCESS_TOKEN")
    SANDBOX_TOOL_NAMES = (
        "astrbot_execute_shell",
        "astrbot_execute_ipython",
        "astrbot_upload_file",
        "astrbot_download_file",
    )

    @staticmethod
    def _config(sandbox_cfg=None):
        """Build a sandbox-runtime config; *sandbox_cfg* defaults to ``{}``."""
        return ama.MainAgentBuildConfig(
            tool_call_timeout=60,
            computer_use_runtime="sandbox",
            sandbox_cfg={} if sandbox_cfg is None else sandbox_cfg,
        )

    @classmethod
    def _isolate_shipyard_env(cls, monkeypatch):
        """Unset the Shipyard env vars for this test and restore them afterwards.

        ``monkeypatch.delenv`` only records variables that already exist, so
        a variable that is absent before the test but later set by the code
        under test would leak into subsequent tests. Setting a placeholder
        first makes monkeypatch remember the original state (including
        "unset") and undo it at teardown.
        """
        for name in cls.SHIPYARD_ENV_VARS:
            monkeypatch.setenv(name, "")
            monkeypatch.delenv(name)

    def test_apply_sandbox_tools_creates_toolset_if_none(self):
        """Test that ToolSet is created when func_tool is None."""
        req = ProviderRequest(prompt="Test", func_tool=None)
        ama._apply_sandbox_tools(self._config(), req, self.SESSION_ID)
        assert req.func_tool is not None
        assert isinstance(req.func_tool, ToolSet)

    def test_apply_sandbox_tools_adds_required_tools(self):
        """Test that all required sandbox tools are added."""
        req = ProviderRequest(prompt="Test", func_tool=None)
        ama._apply_sandbox_tools(self._config(), req, self.SESSION_ID)
        tool_names = req.func_tool.names()
        for tool_name in self.SANDBOX_TOOL_NAMES:
            assert tool_name in tool_names, f"missing sandbox tool {tool_name!r}"

    def test_apply_sandbox_tools_adds_sandbox_prompt(self):
        """Test that sandbox mode prompt is added to system_prompt."""
        req = ProviderRequest(prompt="Test", system_prompt="Original prompt")
        ama._apply_sandbox_tools(self._config(), req, self.SESSION_ID)
        assert "sandboxed environment" in req.system_prompt

    def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch):
        """Test sandbox tools with shipyard booter configuration."""
        self._isolate_shipyard_env(monkeypatch)
        config = self._config(
            {
                "booter": "shipyard",
                "shipyard_endpoint": "https://shipyard.example.com",
                "shipyard_access_token": "test-token",
            }
        )
        req = ProviderRequest(prompt="Test", func_tool=None)
        ama._apply_sandbox_tools(config, req, self.SESSION_ID)
        assert os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com"
        assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token"

    def test_apply_sandbox_tools_shipyard_missing_endpoint(self, monkeypatch):
        """Test that shipyard config is skipped when endpoint is missing."""
        self._isolate_shipyard_env(monkeypatch)
        config = self._config(
            {
                "booter": "shipyard",
                "shipyard_endpoint": "",
                "shipyard_access_token": "test-token",
            }
        )
        req = ProviderRequest(prompt="Test", func_tool=None)
        with patch("astrbot.core.astr_main_agent.logger") as mock_logger:
            ama._apply_sandbox_tools(config, req, self.SESSION_ID)
        mock_logger.error.assert_called_once()
        assert (
            "Shipyard sandbox configuration is incomplete"
            in mock_logger.error.call_args[0][0]
        )

    def test_apply_sandbox_tools_shipyard_missing_access_token(self, monkeypatch):
        """Test that shipyard config is skipped when access token is missing."""
        self._isolate_shipyard_env(monkeypatch)
        config = self._config(
            {
                "booter": "shipyard",
                "shipyard_endpoint": "https://shipyard.example.com",
                "shipyard_access_token": "",
            }
        )
        req = ProviderRequest(prompt="Test", func_tool=None)
        with patch("astrbot.core.astr_main_agent.logger") as mock_logger:
            ama._apply_sandbox_tools(config, req, self.SESSION_ID)
        mock_logger.error.assert_called_once()
        assert (
            "Shipyard sandbox configuration is incomplete"
            in mock_logger.error.call_args[0][0]
        )

    def test_apply_sandbox_tools_preserves_existing_toolset(self):
        """Test that existing tools are preserved when adding sandbox tools."""
        existing_toolset = ToolSet()
        existing_tool = MagicMock()
        existing_tool.name = "existing_tool"
        existing_toolset.add_tool(existing_tool)
        req = ProviderRequest(prompt="Test", func_tool=existing_toolset)
        ama._apply_sandbox_tools(self._config(), req, self.SESSION_ID)
        tool_names = req.func_tool.names()
        assert "existing_tool" in tool_names
        assert "astrbot_execute_shell" in tool_names

    def test_apply_sandbox_tools_appends_to_existing_system_prompt(self):
        """Test that sandbox prompt is appended to existing system prompt."""
        req = ProviderRequest(prompt="Test", system_prompt="Base prompt")
        ama._apply_sandbox_tools(self._config(), req, self.SESSION_ID)
        assert req.system_prompt.startswith("Base prompt")
        assert "sandboxed environment" in req.system_prompt

    def test_apply_sandbox_tools_with_none_system_prompt(self):
        """Test that sandbox prompt is applied when system_prompt is None."""
        req = ProviderRequest(prompt="Test", system_prompt=None)
        ama._apply_sandbox_tools(self._config(), req, self.SESSION_ID)
        assert isinstance(req.system_prompt, str)
        assert "sandboxed environment" in req.system_prompt