"""Tests for AstrMessageEvent class."""

import re
from unittest.mock import AsyncMock, patch

import pytest

from astrbot.core.message.components import (
    At,
    AtAll,
    Face,
    Forward,
    Image,
    Plain,
    Reply,
)
from astrbot.core.message.message_event_result import MessageEventResult
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.astrbot_message import AstrBotMessage, MessageMember
from astrbot.core.platform.message_type import MessageType
from astrbot.core.platform.platform_metadata import PlatformMetadata


class ConcreteAstrMessageEvent(AstrMessageEvent):
    """Concrete implementation of AstrMessageEvent for testing purposes."""

    async def send(self, message):
        """Send message implementation; delegates to the base class."""
        await super().send(message)


def _make_event(
    platform_meta, message_obj, message_str="test", session_id="session123"
):
    """Build a ConcreteAstrMessageEvent wrapping *message_obj*.

    Centralizes the constructor call that most tests repeat, so a change to
    the constructor signature only needs to be made in one place.

    Args:
        platform_meta: PlatformMetadata passed through to the event.
        message_obj: The AstrBotMessage (or stand-in object) the event wraps.
        message_str: Plain-text form of the message.
        session_id: Session identifier for the event.

    Returns:
        A new ConcreteAstrMessageEvent.
    """
    return ConcreteAstrMessageEvent(
        message_str=message_str,
        message_obj=message_obj,
        platform_meta=platform_meta,
        session_id=session_id,
    )


@pytest.fixture
def platform_meta():
    """Create platform metadata for testing."""
    return PlatformMetadata(
        name="test_platform",
        description="Test platform",
        id="test_platform_id",
    )


@pytest.fixture
def message_member():
    """Create a message member for testing."""
    return MessageMember(user_id="user123", nickname="TestUser")


@pytest.fixture
def astrbot_message(message_member):
    """Create an AstrBotMessage for testing."""
    message = AstrBotMessage()
    message.type = MessageType.FRIEND_MESSAGE
    message.self_id = "bot123"
    message.session_id = "session123"
    message.message_id = "msg123"
    message.sender = message_member
    message.message = [Plain(text="Hello world")]
    message.message_str = "Hello world"
    message.raw_message = None
    return message


@pytest.fixture
def astr_message_event(platform_meta, astrbot_message):
    """Create an AstrMessageEvent instance for testing."""
    return _make_event(platform_meta, astrbot_message, message_str="Hello world")


class TestAstrMessageEventInit:
    """Tests for AstrMessageEvent initialization."""

    def test_init_basic(self, astr_message_event):
        """Test basic AstrMessageEvent initialization."""
        assert astr_message_event.message_str == "Hello world"
        assert astr_message_event.role == "member"
        assert astr_message_event.is_wake is False
        assert astr_message_event.is_at_or_wake_command is False
        assert astr_message_event._extras == {}
        assert astr_message_event._result is None
        assert astr_message_event.call_llm is False

    def test_init_session(self, astr_message_event):
        """Test session initialization."""
        assert astr_message_event.session_id == "session123"
        assert astr_message_event.session.platform_name == "test_platform_id"

    def test_init_platform_reference(self, astr_message_event, platform_meta):
        """Test platform reference initialization."""
        assert astr_message_event.platform_meta == platform_meta
        assert astr_message_event.platform == platform_meta  # back compatibility

    def test_init_created_at(self, astr_message_event):
        """Test created_at timestamp is set."""
        assert astr_message_event.created_at is not None
        assert isinstance(astr_message_event.created_at, float)

    def test_init_trace(self, astr_message_event):
        """Test trace/span initialization."""
        assert astr_message_event.trace is not None
        assert astr_message_event.span is not None
        assert astr_message_event.trace == astr_message_event.span


class TestUnifiedMsgOrigin:
    """Tests for unified_msg_origin property."""

    def test_unified_msg_origin_getter(self, astr_message_event):
        """Test unified_msg_origin getter."""
        expected = "test_platform_id:FriendMessage:session123"
        assert astr_message_event.unified_msg_origin == expected

    def test_unified_msg_origin_setter(self, astr_message_event):
        """Test unified_msg_origin setter."""
        astr_message_event.unified_msg_origin = "new_platform:GroupMessage:new_session"
        assert astr_message_event.session.platform_name == "new_platform"
        assert astr_message_event.session.session_id == "new_session"


class TestSessionId:
    """Tests for session_id property."""

    def test_session_id_getter(self, astr_message_event):
        """Test session_id getter."""
        assert astr_message_event.session_id == "session123"

    def test_session_id_setter(self, astr_message_event):
        """Test session_id setter."""
        astr_message_event.session_id = "new_session_id"
        assert astr_message_event.session_id == "new_session_id"


class TestGetPlatformInfo:
    """Tests for platform info methods."""

    def test_get_platform_name(self, astr_message_event):
        """Test get_platform_name method."""
        assert astr_message_event.get_platform_name() == "test_platform"

    def test_get_platform_id(self, astr_message_event):
        """Test get_platform_id method."""
        assert astr_message_event.get_platform_id() == "test_platform_id"


class TestGetMessageInfo:
    """Tests for message info methods."""

    def test_get_message_str(self, astr_message_event):
        """Test get_message_str method."""
        assert astr_message_event.get_message_str() == "Hello world"

    def test_get_message_str_none(self, platform_meta, astrbot_message):
        """Test get_message_str keeps None when source message_str is None."""
        astrbot_message.message_str = None
        event = _make_event(platform_meta, astrbot_message, message_str=None)
        assert event.get_message_str() is None

    def test_get_messages(self, astr_message_event):
        """Test get_messages method."""
        messages = astr_message_event.get_messages()
        assert len(messages) == 1
        assert isinstance(messages[0], Plain)
        assert messages[0].text == "Hello world"

    def test_get_message_type(self, astr_message_event):
        """Test get_message_type method."""
        assert astr_message_event.get_message_type() == MessageType.FRIEND_MESSAGE

    def test_get_session_id(self, astr_message_event):
        """Test get_session_id method."""
        assert astr_message_event.get_session_id() == "session123"

    def test_get_group_id_empty_for_private(self, astr_message_event):
        """Test get_group_id returns empty for private messages."""
        assert astr_message_event.get_group_id() == ""

    def test_get_self_id(self, astr_message_event):
        """Test get_self_id method."""
        assert astr_message_event.get_self_id() == "bot123"

    def test_get_sender_id(self, astr_message_event):
        """Test get_sender_id method."""
        assert astr_message_event.get_sender_id() == "user123"

    def test_get_sender_name(self, astr_message_event):
        """Test get_sender_name method."""
        assert astr_message_event.get_sender_name() == "TestUser"

    def test_get_sender_name_empty_when_none(self, platform_meta, astrbot_message):
        """Test get_sender_name returns empty string when nickname is None."""
        astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
        event = _make_event(platform_meta, astrbot_message, message_str="test")
        assert event.get_sender_name() == ""

    def test_get_sender_name_coerces_non_string(self, platform_meta, astrbot_message):
        """Test get_sender_name stringifies non-string nickname values."""
        astrbot_message.sender = MessageMember(user_id="user123", nickname=None)
        # Assigned after construction so MessageMember never sees a non-str.
        astrbot_message.sender.nickname = 12345
        event = _make_event(platform_meta, astrbot_message, message_str="test")
        assert event.get_sender_name() == "12345"


class TestGetMessageOutline:
    """Tests for get_message_outline method."""

    def test_outline_plain_text(self, astr_message_event):
        """Test outline with plain text message."""
        outline = astr_message_event.get_message_outline()
        assert "Hello world" in outline

    def test_outline_with_image(self, platform_meta, astrbot_message):
        """Test outline with image component."""
        astrbot_message.message = [
            Plain(text="Look at this"),
            Image(file="http://example.com/img.jpg"),
        ]
        event = _make_event(platform_meta, astrbot_message, message_str="Look at this")
        outline = event.get_message_outline()
        assert "Look at this" in outline
        assert "[图片]" in outline

    def test_outline_with_at(self, platform_meta, astrbot_message):
        """Test outline with At component."""
        astrbot_message.message = [At(qq="12345"), Plain(text=" hello")]
        event = _make_event(platform_meta, astrbot_message, message_str=" hello")
        outline = event.get_message_outline()
        assert "[At:12345]" in outline

    def test_outline_with_at_all(self, platform_meta, astrbot_message):
        """Test outline with AtAll component."""
        astrbot_message.message = [AtAll()]
        event = _make_event(platform_meta, astrbot_message, message_str="")
        outline = event.get_message_outline()
        # Loose check: expected to render like "[At:all]", but the exact
        # casing of "all" is not pinned here.
        assert "[At:" in outline and "all" in outline.lower()

    def test_outline_with_face(self, platform_meta, astrbot_message):
        """Test outline with Face component."""
        astrbot_message.message = [Face(id="123")]
        event = _make_event(platform_meta, astrbot_message, message_str="")
        outline = event.get_message_outline()
        assert "[表情:123]" in outline

    def test_outline_with_forward(self, platform_meta, astrbot_message):
        """Test outline with Forward component."""
        # Forward requires an id parameter
        astrbot_message.message = [Forward(id="test_forward_id")]
        event = _make_event(platform_meta, astrbot_message, message_str="")
        outline = event.get_message_outline()
        assert "[转发消息]" in outline

    def test_outline_with_reply(self, platform_meta, astrbot_message):
        """Test outline with Reply component."""
        # Reply requires an id parameter
        reply = Reply(id="test_reply_id")
        reply.message_str = "Original message"
        reply.sender_nickname = "Sender"
        astrbot_message.message = [reply, Plain(text=" reply")]
        event = _make_event(platform_meta, astrbot_message, message_str=" reply")
        outline = event.get_message_outline()
        assert "[引用消息(Sender: Original message)]" in outline

    def test_outline_with_reply_no_message(self, platform_meta, astrbot_message):
        """Test outline with Reply component without message_str."""
        # Reply requires an id parameter
        reply = Reply(id="test_reply_id")
        reply.message_str = None
        astrbot_message.message = [reply]
        event = _make_event(platform_meta, astrbot_message, message_str="")
        outline = event.get_message_outline()
        assert "[引用消息]" in outline

    def test_outline_empty_chain(self, platform_meta, astrbot_message):
        """Test outline with empty message chain."""
        astrbot_message.message = []
        event = _make_event(platform_meta, astrbot_message, message_str="")
        outline = event.get_message_outline()
        assert outline == ""

    def test_outline_very_long_plain_text(self, platform_meta, astrbot_message):
        """Test outline generation for very long plain text content."""
        long_text = "A" * 20000
        astrbot_message.message = [Plain(text=long_text)]
        event = _make_event(platform_meta, astrbot_message, message_str=long_text)
        outline = event.get_message_outline()
        assert outline.startswith("A")
        # The outline must not truncate long plain text.
        assert len(outline) >= 20000


class TestExtras:
    """Tests for extra information methods."""

    def test_set_extra(self, astr_message_event):
        """Test set_extra method."""
        astr_message_event.set_extra("key1", "value1")
        assert astr_message_event._extras["key1"] == "value1"

    def test_get_extra_with_key(self, astr_message_event):
        """Test get_extra with specific key."""
        astr_message_event.set_extra("key1", "value1")
        assert astr_message_event.get_extra("key1") == "value1"

    def test_get_extra_with_default(self, astr_message_event):
        """Test get_extra with default value."""
        result = astr_message_event.get_extra("nonexistent", "default_value")
        assert result == "default_value"

    def test_get_extra_all(self, astr_message_event):
        """Test get_extra without key returns all extras."""
        astr_message_event.set_extra("key1", "value1")
        astr_message_event.set_extra("key2", "value2")
        all_extras = astr_message_event.get_extra()
        assert all_extras == {"key1": "value1", "key2": "value2"}

    def test_clear_extra(self, astr_message_event):
        """Test clear_extra method."""
        astr_message_event.set_extra("key1", "value1")
        astr_message_event.clear_extra()
        assert astr_message_event._extras == {}


class TestSetResult:
    """Tests for set_result method."""

    def test_set_result_with_message_event_result(self, astr_message_event):
        """Test set_result with MessageEventResult object."""
        result = MessageEventResult().message("Test message")
        astr_message_event.set_result(result)
        assert astr_message_event._result == result

    def test_set_result_with_string(self, astr_message_event):
        """Test set_result with string creates MessageEventResult."""
        astr_message_event.set_result("Test message")
        assert astr_message_event._result is not None
        assert len(astr_message_event._result.chain) == 1
        assert isinstance(astr_message_event._result.chain[0], Plain)

    def test_set_result_with_empty_chain(self, astr_message_event):
        """Test set_result handles empty chain correctly."""
        result = MessageEventResult()
        # chain is already an empty list by default
        astr_message_event.set_result(result)
        assert astr_message_event._result.chain == []


class TestStopContinueEvent:
    """Tests for stop_event and continue_event methods."""

    def test_stop_event_creates_result_if_none(self, astr_message_event):
        """Test stop_event creates result if none exists."""
        astr_message_event.stop_event()
        assert astr_message_event._result is not None
        assert astr_message_event.is_stopped() is True

    def test_stop_event_with_existing_result(self, astr_message_event):
        """Test stop_event with existing result."""
        astr_message_event.set_result(MessageEventResult().message("Test"))
        astr_message_event.stop_event()
        assert astr_message_event.is_stopped() is True

    def test_continue_event_creates_result_if_none(self, astr_message_event):
        """Test continue_event creates result if none exists."""
        astr_message_event.continue_event()
        assert astr_message_event._result is not None
        assert astr_message_event.is_stopped() is False

    def test_continue_event_with_existing_result(self, astr_message_event):
        """Test continue_event with existing result."""
        astr_message_event.set_result(MessageEventResult().message("Test"))
        astr_message_event.stop_event()
        astr_message_event.continue_event()
        assert astr_message_event.is_stopped() is False

    def test_is_stopped_default_false(self, astr_message_event):
        """Test is_stopped returns False by default."""
        assert astr_message_event.is_stopped() is False


class TestIsPrivateChat:
    """Tests for is_private_chat method."""

    def test_is_private_chat_true(self, astr_message_event):
        """Test is_private_chat returns True for friend message."""
        assert astr_message_event.is_private_chat() is True

    def test_is_private_chat_false(self, platform_meta, astrbot_message):
        """Test is_private_chat returns False for group message."""
        astrbot_message.type = MessageType.GROUP_MESSAGE
        event = _make_event(platform_meta, astrbot_message, message_str="test")
        assert event.is_private_chat() is False


class TestIsWakeUp:
    """Tests for is_wake_up method."""

    def test_is_wake_up_default_false(self, astr_message_event):
        """Test is_wake_up returns False by default."""
        assert astr_message_event.is_wake_up() is False

    def test_is_wake_up_when_set(self, astr_message_event):
        """Test is_wake_up returns True when is_wake is set."""
        astr_message_event.is_wake = True
        assert astr_message_event.is_wake_up() is True


class TestIsAdmin:
    """Tests for is_admin method."""

    def test_is_admin_default_false(self, astr_message_event):
        """Test is_admin returns False by default."""
        assert astr_message_event.is_admin() is False

    def test_is_admin_when_admin(self, astr_message_event):
        """Test is_admin returns True when role is admin."""
        astr_message_event.role = "admin"
        assert astr_message_event.is_admin() is True


class TestProcessBuffer:
    """Tests for process_buffer method."""

    @pytest.mark.asyncio
    async def test_process_buffer_splits_by_pattern(self, astr_message_event):
        """Test process_buffer splits buffer by pattern."""
        buffer = "Line 1\nLine 2\nLine 3\nRemaining"
        pattern = re.compile(r".*\n")
        with patch.object(
            astr_message_event, "send", new_callable=AsyncMock
        ) as mock_send:
            result = await astr_message_event.process_buffer(buffer, pattern)
        # Should have sent 3 lines and remaining should be "Remaining"
        assert mock_send.call_count == 3
        assert result == "Remaining"

    @pytest.mark.asyncio
    async def test_process_buffer_no_match(self, astr_message_event):
        """Test process_buffer returns original buffer and sends nothing on no match."""
        buffer = "No newlines here"
        pattern = re.compile(r"\n")
        # Patch send so a regression cannot trigger a real send, and so we can
        # verify that nothing was emitted when the pattern never matches.
        with patch.object(
            astr_message_event, "send", new_callable=AsyncMock
        ) as mock_send:
            result = await astr_message_event.process_buffer(buffer, pattern)
        mock_send.assert_not_called()
        assert result == "No newlines here"


class TestResultHelpers:
    """Tests for result helper methods."""

    def test_make_result(self, astr_message_event):
        """Test make_result creates empty MessageEventResult."""
        result = astr_message_event.make_result()
        assert isinstance(result, MessageEventResult)

    def test_plain_result(self, astr_message_event):
        """Test plain_result creates result with text."""
        result = astr_message_event.plain_result("Hello")
        assert isinstance(result, MessageEventResult)
        assert len(result.chain) == 1
        assert isinstance(result.chain[0], Plain)
        assert result.chain[0].text == "Hello"

    def test_image_result_url(self, astr_message_event):
        """Test image_result with URL."""
        result = astr_message_event.image_result("http://example.com/image.jpg")
        assert isinstance(result, MessageEventResult)
        assert len(result.chain) == 1
        assert isinstance(result.chain[0], Image)

    def test_image_result_path(self, astr_message_event):
        """Test image_result with file path."""
        result = astr_message_event.image_result("/path/to/image.jpg")
        assert isinstance(result, MessageEventResult)
        assert len(result.chain) == 1
        assert isinstance(result.chain[0], Image)


class TestGetResult:
    """Tests for get_result and clear_result methods."""

    def test_get_result_returns_none_by_default(self, astr_message_event):
        """Test get_result returns None by default."""
        assert astr_message_event.get_result() is None

    def test_get_result_returns_set_result(self, astr_message_event):
        """Test get_result returns set result."""
        result = MessageEventResult().message("Test")
        astr_message_event.set_result(result)
        assert astr_message_event.get_result() == result

    def test_clear_result(self, astr_message_event):
        """Test clear_result clears the result."""
        astr_message_event.set_result(MessageEventResult().message("Test"))
        astr_message_event.clear_result()
        assert astr_message_event.get_result() is None


class TestShouldCallLlm:
    """Tests for should_call_llm method."""

    def test_should_call_llm_default(self, astr_message_event):
        """Test call_llm default is False."""
        assert astr_message_event.call_llm is False

    def test_should_call_llm_when_set(self, astr_message_event):
        """Test should_call_llm sets call_llm."""
        astr_message_event.should_call_llm(True)
        assert astr_message_event.call_llm is True


class TestRequestLlm:
    """Tests for request_llm method."""

    def test_request_llm_basic(self, astr_message_event):
        """Test request_llm creates ProviderRequest."""
        request = astr_message_event.request_llm(prompt="Hello")
        assert request.prompt == "Hello"
        assert request.session_id == ""
        assert request.image_urls == []
        assert request.contexts == []

    def test_request_llm_with_all_params(self, astr_message_event):
        """Test request_llm with all parameters."""
        request = astr_message_event.request_llm(
            prompt="Hello",
            session_id="session123",
            image_urls=["http://example.com/img.jpg"],
            contexts=[{"role": "user", "content": "Hi"}],
            system_prompt="You are helpful",
        )
        assert request.prompt == "Hello"
        assert request.session_id == "session123"
        assert request.image_urls == ["http://example.com/img.jpg"]
        assert request.contexts == [{"role": "user", "content": "Hi"}]
        assert request.system_prompt == "You are helpful"


class TestSendStreaming:
    """Tests for send_streaming method."""

    @pytest.mark.asyncio
    async def test_send_streaming_sets_has_send_oper(self, astr_message_event):
        """Test send_streaming sets _has_send_oper flag."""
        assert astr_message_event._has_send_oper is False

        async def generator():
            yield MessageEventResult().message("Test")

        # Metric.upload is stubbed so the test never reports metrics.
        with patch(
            "astrbot.core.platform.astr_message_event.Metric.upload",
            new_callable=AsyncMock,
        ):
            await astr_message_event.send_streaming(generator())
        assert astr_message_event._has_send_oper is True


class TestSendTyping:
    """Tests for send_typing method."""

    @pytest.mark.asyncio
    async def test_send_typing_default_empty(self, astr_message_event):
        """Test send_typing default implementation is empty."""
        # Should not raise any exception
        await astr_message_event.send_typing()


class TestReact:
    """Tests for react method."""

    @pytest.mark.asyncio
    async def test_react_sends_emoji(self, astr_message_event):
        """Test react sends emoji as message."""
        with patch.object(
            astr_message_event, "send", new_callable=AsyncMock
        ) as mock_send:
            await astr_message_event.react("👍")
        mock_send.assert_called_once()
        call_arg = mock_send.call_args[0][0]
        # The sent object exposes its components via a `chain` attribute.
        assert len(call_arg.chain) == 1
        assert isinstance(call_arg.chain[0], Plain)
        assert call_arg.chain[0].text == "👍"


class TestGetGroup:
    """Tests for get_group method."""

    @pytest.mark.asyncio
    async def test_get_group_returns_none_for_private(self, astr_message_event):
        """Test get_group returns None for private chat."""
        result = await astr_message_event.get_group()
        assert result is None

    @pytest.mark.asyncio
    async def test_get_group_with_group_id_param(self, astr_message_event):
        """Test get_group with group_id parameter."""
        # Default implementation returns None
        result = await astr_message_event.get_group(group_id="group123")
        assert result is None


class TestMessageTypeHandling:
    """Tests for message type handling edge cases."""

    def test_message_type_from_valid_string(self, platform_meta):
        """Valid MessageType string should be converted correctly."""
        # NOTE(review): "FRIEND_MESSAGE" is the enum member *name* (the value,
        # per unified_msg_origin, is "FriendMessage"), and FRIEND_MESSAGE is
        # also the fallback for invalid input, so this test cannot tell a real
        # conversion apart from the fallback. Consider adding a case for a
        # non-default type once the accepted string form is confirmed.
        message = AstrBotMessage()
        message.type = "FRIEND_MESSAGE"
        message.message = []
        event = _make_event(platform_meta, message, message_str="test")
        assert event.session.message_type == MessageType.FRIEND_MESSAGE
        assert event.get_message_type() == MessageType.FRIEND_MESSAGE

    def test_message_type_from_invalid_string_defaults_to_friend(self, platform_meta):
        """Invalid message type should default to FRIEND_MESSAGE."""
        message = AstrBotMessage()
        message.type = "InvalidMessageType"
        message.message = []
        event = _make_event(platform_meta, message, message_str="test")
        assert event.session.message_type == MessageType.FRIEND_MESSAGE
        assert event.get_message_type() == MessageType.FRIEND_MESSAGE

    def test_message_type_from_none_defaults_to_friend(self, platform_meta):
        """None message type should default to FRIEND_MESSAGE."""
        message = AstrBotMessage()
        message.type = None
        message.message = []
        event = _make_event(platform_meta, message, message_str="test")
        assert event.session.message_type == MessageType.FRIEND_MESSAGE
        assert event.get_message_type() == MessageType.FRIEND_MESSAGE

    def test_message_type_from_integer_defaults_to_friend(self, platform_meta):
        """Integer message type should default to FRIEND_MESSAGE."""
        message = AstrBotMessage()
        message.type = 123
        message.message = []
        event = _make_event(platform_meta, message, message_str="test")
        assert event.session.message_type == MessageType.FRIEND_MESSAGE
        assert event.get_message_type() == MessageType.FRIEND_MESSAGE


class TestDefensiveGetattr:
    """Tests for defensive getattr behavior in AstrMessageEvent."""

    def test_get_messages_without_message_attr(self, astr_message_event):
        """get_messages should handle message_obj without 'message' attribute."""
        astr_message_event.message_obj = type("DummyMessage", (), {})()
        messages = astr_message_event.get_messages()
        assert isinstance(messages, list)

    def test_get_message_type_without_type_attr(self, astr_message_event):
        """get_message_type should handle message_obj without 'type' attribute."""
        astr_message_event.message_obj = type("DummyMessage", (), {})()
        message_type = astr_message_event.get_message_type()
        assert isinstance(message_type, MessageType)

    def test_get_sender_fields_without_sender_attr(self, astr_message_event):
        """get_sender_id and get_sender_name should handle missing 'sender'."""
        astr_message_event.message_obj = type("DummyMessage", (), {})()
        sender_id = astr_message_event.get_sender_id()
        sender_name = astr_message_event.get_sender_name()
        assert isinstance(sender_id, str)
        assert isinstance(sender_name, str)

    def test_get_message_type_with_non_enum_type(self, astr_message_event):
        """get_message_type should handle message_obj.type that is not a MessageType."""

        class DummyMessage:
            def __init__(self):
                self.type = "not_an_enum"
                self.message = []

        astr_message_event.message_obj = DummyMessage()
        message_type = astr_message_event.get_message_type()
        assert isinstance(message_type, MessageType)