File size: 14,752 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
"""Tests for ContextTruncator."""

from astrbot.core.agent.context.truncator import ContextTruncator
from astrbot.core.agent.message import Message


class TestContextTruncator:
    """Test suite for ContextTruncator.

    Covers ``fix_messages``, ``truncate_by_turns``,
    ``truncate_by_dropping_oldest_turns`` and ``truncate_by_halving``, plus a
    few scenarios that combine them. Throughout the suite a "turn" means one
    user message followed by one assistant message, which is how
    ``create_messages`` lays the conversation out.
    """

    def create_message(self, role: str, content: str = "test content") -> Message:
        """Build a single ``Message`` with the given role and content.

        Args:
            role: Role of the message (the tests use "system", "user",
                "assistant" and "tool").
            content: Text content of the message.

        Returns:
            The newly created message.
        """
        return Message(role=role, content=content)

    def create_messages(
        self, count: int, include_system: bool = False
    ) -> list[Message]:
        """Build a conversation of alternating user/assistant messages.

        Args:
            count: Number of user/assistant messages to create. The optional
                system message is not included in this count.
            include_system: Whether to put a system message ("System prompt")
                at the start of the list.

        Returns:
            List of messages, starting with the system message when requested,
            followed by "Message 0" (user), "Message 1" (assistant), and so on.
        """
        messages: list[Message] = []
        if include_system:
            messages.append(self.create_message("system", "System prompt"))

        # Even indices are user messages and odd indices assistant replies, so
        # every consecutive pair forms one user/assistant turn.
        messages.extend(
            self.create_message("user" if i % 2 == 0 else "assistant", f"Message {i}")
            for i in range(count)
        )
        return messages

    # ==================== fix_messages Tests ====================

    def test_fix_messages_empty_list(self):
        """Test fix_messages with an empty list."""
        truncator = ContextTruncator()
        result = truncator.fix_messages([])
        assert result == []

    def test_fix_messages_normal_messages(self):
        """Test fix_messages leaves plain user/assistant messages untouched."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("user", "Hello"),
            self.create_message("assistant", "Hi"),
            self.create_message("user", "How are you?"),
        ]
        result = truncator.fix_messages(messages)
        assert len(result) == 3
        assert result == messages

    def test_fix_messages_tool_without_context(self):
        """Test fix_messages with a tool message that has no preceding context."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("tool", "Tool result"),
        ]
        result = truncator.fix_messages(messages)
        # Tool message without context should be removed
        assert len(result) == 0

    # ==================== truncate_by_turns Tests ====================

    def test_truncate_by_turns_no_limit(self):
        """Test truncate_by_turns with -1 (no limit)."""
        truncator = ContextTruncator()
        messages = self.create_messages(20)
        result = truncator.truncate_by_turns(messages, keep_most_recent_turns=-1)
        assert len(result) == 20
        assert result == messages

    def test_truncate_by_turns_basic(self):
        """Test basic truncate_by_turns functionality."""
        truncator = ContextTruncator()
        # Create 10 messages = 5 turns (user/assistant pairs)
        messages = self.create_messages(10)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # The expected size is (keep - drop + 1) turns = (3 - 1 + 1) * 2 = 6
        # messages. The bound is deliberately left at 8 to allow for
        # adjustments the truncator makes to keep the format valid.
        # NOTE(review): tighten to == 6 once the exact count is confirmed.
        assert len(result) <= 8

    def test_truncate_by_turns_with_system_message(self):
        """Test truncate_by_turns preserves system messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(10, include_system=True)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=2, drop_turns=1
        )

        # System message should always be preserved
        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_turns_zero_keep(self):
        """Test truncate_by_turns with keep_most_recent_turns=0."""
        truncator = ContextTruncator()
        messages = self.create_messages(10)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=0, drop_turns=1
        )

        # Keeping zero turns of a system-less conversation should leave nothing
        assert len(result) == 0

    def test_truncate_by_turns_below_threshold(self):
        """Test truncate_by_turns when messages are below threshold."""
        truncator = ContextTruncator()
        # Create 4 messages = 2 turns
        messages = self.create_messages(4)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=1
        )

        # No truncation should happen
        assert len(result) == 4
        assert result == messages

    def test_truncate_by_turns_exact_threshold(self):
        """Test truncate_by_turns when messages exactly match threshold."""
        truncator = ContextTruncator()
        # Create 6 messages = 3 turns
        messages = self.create_messages(6)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # No truncation should happen
        assert len(result) == 6
        assert result == messages

    def test_truncate_by_turns_ensures_user_first(self):
        """Test that truncate_by_turns ensures user message comes first."""
        truncator = ContextTruncator()
        # Create scenario where truncation might start with assistant
        messages = self.create_messages(20)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=3, drop_turns=1
        )

        # No system message in this conversation, so index 0 must be a user
        assert result[0].role == "user"

    def test_truncate_by_turns_multiple_drop(self):
        """Test truncate_by_turns with multiple turns dropped at once."""
        truncator = ContextTruncator()
        messages = self.create_messages(20)
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=3
        )

        # 10 turns exceed the limit of 5, so the result must be shorter
        assert len(result) < len(messages)

    # ==================== truncate_by_dropping_oldest_turns Tests ====================

    def test_truncate_by_dropping_oldest_turns_zero(self):
        """Test truncate_by_dropping_oldest_turns with drop_turns=0."""
        truncator = ContextTruncator()
        messages = self.create_messages(10)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=0)
        assert result == messages

    def test_truncate_by_dropping_oldest_turns_negative(self):
        """Test truncate_by_dropping_oldest_turns with negative drop_turns."""
        truncator = ContextTruncator()
        messages = self.create_messages(10)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=-1)
        assert result == messages

    def test_truncate_by_dropping_oldest_turns_basic(self):
        """Test basic truncate_by_dropping_oldest_turns functionality."""
        truncator = ContextTruncator()
        # Create 10 messages = 5 turns
        messages = self.create_messages(10)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

        # Should drop 2 oldest turns (4 messages)
        assert len(result) == 6
        # Should start with user message
        assert result[0].role == "user"

    def test_truncate_by_dropping_oldest_turns_with_system(self):
        """Test truncate_by_dropping_oldest_turns preserves system messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(10, include_system=True)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

        # System message should be preserved
        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_dropping_oldest_turns_drop_all(self):
        """Test truncate_by_dropping_oldest_turns dropping all turns."""
        truncator = ContextTruncator()
        # Create 4 messages = 2 turns
        messages = self.create_messages(4)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=2)

        # Should drop all turns
        assert len(result) == 0

    def test_truncate_by_dropping_oldest_turns_drop_more_than_available(self):
        """Test truncate_by_dropping_oldest_turns with drop_turns > available turns."""
        truncator = ContextTruncator()
        # Create 4 messages = 2 turns
        messages = self.create_messages(4)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=5)

        # Should result in empty list
        assert len(result) == 0

    def test_truncate_by_dropping_oldest_turns_ensures_user_first(self):
        """Test that result starts with user message after dropping."""
        truncator = ContextTruncator()
        messages = self.create_messages(20)
        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=3)

        # 20 messages = 10 turns, so dropping 3 is expected to leave messages
        # behind. Assert that explicitly: an empty result would otherwise let
        # this test pass without checking anything.
        assert len(result) > 0, "dropping 3 of 10 turns should leave messages"
        # First message should be user
        assert result[0].role == "user"

    # ==================== truncate_by_halving Tests ====================

    def test_truncate_by_halving_empty(self):
        """Test truncate_by_halving with empty list."""
        truncator = ContextTruncator()
        result = truncator.truncate_by_halving([])
        assert result == []

    def test_truncate_by_halving_single_message(self):
        """Test truncate_by_halving with single message."""
        truncator = ContextTruncator()
        messages = [self.create_message("user", "Hello")]
        result = truncator.truncate_by_halving(messages)
        # Should not truncate if <= 2 messages
        assert result == messages

    def test_truncate_by_halving_two_messages(self):
        """Test truncate_by_halving with two messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(2)
        result = truncator.truncate_by_halving(messages)
        # Should not truncate if <= 2 messages
        assert result == messages

    def test_truncate_by_halving_basic(self):
        """Test basic truncate_by_halving functionality."""
        truncator = ContextTruncator()
        # Create 20 messages
        messages = self.create_messages(20)
        result = truncator.truncate_by_halving(messages)

        # Should delete 50% = 10 messages, keep 10
        assert len(result) == 10
        # First message should be user
        assert result[0].role == "user"

    def test_truncate_by_halving_with_system_message(self):
        """Test truncate_by_halving preserves system messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(20, include_system=True)
        result = truncator.truncate_by_halving(messages)

        # System message should be preserved
        assert result[0].role == "system"
        assert result[0].content == "System prompt"

    def test_truncate_by_halving_odd_count(self):
        """Test truncate_by_halving with odd number of messages."""
        truncator = ContextTruncator()
        messages = self.create_messages(11)
        result = truncator.truncate_by_halving(messages)

        # Should delete floor(11/2) = 5 messages, keep 6
        # But after ensuring user first, may be 5
        assert len(result) >= 5
        assert result[0].role == "user"

    def test_truncate_by_halving_ensures_user_first(self):
        """Test that result starts with user message."""
        truncator = ContextTruncator()
        # Create messages starting with user
        messages = self.create_messages(30)
        result = truncator.truncate_by_halving(messages)

        # First message should be user
        assert result[0].role == "user"

    def test_truncate_by_halving_preserves_recent_messages(self):
        """Test that truncate_by_halving keeps the most recent 50%."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("user", "Message 0"),
            self.create_message("assistant", "Message 1"),
            self.create_message("user", "Message 2"),
            self.create_message("assistant", "Message 3"),
        ]
        result = truncator.truncate_by_halving(messages)

        # Should keep last 2 messages
        assert len(result) == 2
        assert result[0].content == "Message 2"
        assert result[1].content == "Message 3"

    # ==================== Integration Tests ====================

    def test_truncate_with_tool_messages(self):
        """Test truncation with tool messages."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("user", "Run tool"),
            self.create_message("assistant", "Running..."),
            self.create_message("tool", "Tool result"),
            self.create_message("user", "Thanks"),
            self.create_message("assistant", "Welcome"),
        ]

        result = truncator.truncate_by_dropping_oldest_turns(messages, drop_turns=1)

        # First turn (user+assistant+tool) should be dropped, so at most the
        # final user/assistant pair may remain; the orphaned tool message
        # must not survive.
        assert len(result) <= 2

    def test_chain_multiple_truncations(self):
        """Test chaining multiple truncation methods."""
        truncator = ContextTruncator()
        messages = self.create_messages(40, include_system=True)

        # First: truncate by turns
        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=10, drop_turns=2
        )
        # Then: halve
        result = truncator.truncate_by_halving(result)

        # Should have system message + truncated content
        assert result[0].role == "system"
        assert len(result) < len(messages)

    def test_empty_after_system_message(self):
        """Test truncation when only system message exists."""
        truncator = ContextTruncator()
        messages = [self.create_message("system", "System prompt")]

        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=5, drop_turns=1
        )

        # Should keep system message
        assert len(result) == 1
        assert result[0].role == "system"

    def test_all_system_messages(self):
        """Test truncation with only system messages."""
        truncator = ContextTruncator()
        messages = [
            self.create_message("system", "System 1"),
            self.create_message("system", "System 2"),
        ]

        result = truncator.truncate_by_turns(
            messages, keep_most_recent_turns=0, drop_turns=1
        )

        # The input holds no non-system messages, so whatever comes back must
        # consist solely of system messages. all() is True for an empty list,
        # which means both outcomes are accepted: keeping the system messages
        # or clearing everything.
        assert all(msg.role == "system" for msg in result)