| from ..message import Message |
|
|
|
|
| class ContextTruncator: |
| """Context truncator.""" |
|
|
| def _has_tool_calls(self, message: Message) -> bool: |
| """Check if a message contains tool calls.""" |
| return ( |
| message.role == "assistant" |
| and message.tool_calls is not None |
| and len(message.tool_calls) > 0 |
| ) |
|
|
| def fix_messages(self, messages: list[Message]) -> list[Message]: |
| """修复消息列表,确保 tool call 和 tool response 的配对关系有效。 |
| |
| 此方法确保: |
| 1. 每个 `tool` 消息前面都有一个包含 tool_calls 的 `assistant` 消息 |
| 2. 每个包含 tool_calls 的 `assistant` 消息后面都有对应的 `tool` 响应 |
| |
| 这是 OpenAI Chat Completions API 规范的要求(Gemini 对此执行严格检查)。 |
| """ |
| if not messages: |
| return messages |
|
|
| fixed_messages: list[Message] = [] |
| pending_assistant: Message | None = None |
| pending_tools: list[Message] = [] |
|
|
| def flush_pending_if_valid() -> None: |
| nonlocal pending_assistant, pending_tools |
| if pending_assistant is not None and pending_tools: |
| fixed_messages.append(pending_assistant) |
| fixed_messages.extend(pending_tools) |
| pending_assistant = None |
| pending_tools = [] |
|
|
| for msg in messages: |
| if msg.role == "tool": |
| |
| if pending_assistant is not None: |
| pending_tools.append(msg) |
| |
| continue |
|
|
| if self._has_tool_calls(msg): |
| |
| flush_pending_if_valid() |
| pending_assistant = msg |
| continue |
|
|
| |
| |
| flush_pending_if_valid() |
| fixed_messages.append(msg) |
|
|
| |
| flush_pending_if_valid() |
|
|
| return fixed_messages |
|
|
| def truncate_by_turns( |
| self, |
| messages: list[Message], |
| keep_most_recent_turns: int, |
| drop_turns: int = 1, |
| ) -> list[Message]: |
| """截断上下文列表,确保不超过最大长度。 |
| 一个 turn 包含一个 user 消息和一个 assistant 消息。 |
| 这个方法会保证截断后的上下文列表符合 OpenAI 的上下文格式。 |
| |
| Args: |
| messages: 上下文列表 |
| keep_most_recent_turns: 保留最近的对话轮数 |
| drop_turns: 一次性丢弃的对话轮数 |
| |
| Returns: |
| 截断后的上下文列表 |
| """ |
| if keep_most_recent_turns == -1: |
| return messages |
|
|
| first_non_system = 0 |
| for i, msg in enumerate(messages): |
| if msg.role != "system": |
| first_non_system = i |
| break |
|
|
| system_messages = messages[:first_non_system] |
| non_system_messages = messages[first_non_system:] |
|
|
| if len(non_system_messages) // 2 <= keep_most_recent_turns: |
| return messages |
|
|
| num_to_keep = keep_most_recent_turns - drop_turns + 1 |
| if num_to_keep <= 0: |
| truncated_contexts = [] |
| else: |
| truncated_contexts = non_system_messages[-num_to_keep * 2 :] |
|
|
| |
| index = next( |
| (i for i, item in enumerate(truncated_contexts) if item.role == "user"), |
| None, |
| ) |
| if index is not None and index > 0: |
| truncated_contexts = truncated_contexts[index:] |
|
|
| result = system_messages + truncated_contexts |
|
|
| return self.fix_messages(result) |
|
|
| def truncate_by_dropping_oldest_turns( |
| self, |
| messages: list[Message], |
| drop_turns: int = 1, |
| ) -> list[Message]: |
| """丢弃最旧的 N 个对话轮次。""" |
| if drop_turns <= 0: |
| return messages |
|
|
| first_non_system = 0 |
| for i, msg in enumerate(messages): |
| if msg.role != "system": |
| first_non_system = i |
| break |
|
|
| system_messages = messages[:first_non_system] |
| non_system_messages = messages[first_non_system:] |
|
|
| if len(non_system_messages) // 2 <= drop_turns: |
| truncated_non_system = [] |
| else: |
| truncated_non_system = non_system_messages[drop_turns * 2 :] |
|
|
| index = next( |
| (i for i, item in enumerate(truncated_non_system) if item.role == "user"), |
| None, |
| ) |
| if index is not None: |
| truncated_non_system = truncated_non_system[index:] |
| elif truncated_non_system: |
| truncated_non_system = [] |
|
|
| result = system_messages + truncated_non_system |
|
|
| return self.fix_messages(result) |
|
|
| def truncate_by_halving( |
| self, |
| messages: list[Message], |
| ) -> list[Message]: |
| """对半砍策略,删除 50% 的消息""" |
| if len(messages) <= 2: |
| return messages |
|
|
| first_non_system = 0 |
| for i, msg in enumerate(messages): |
| if msg.role != "system": |
| first_non_system = i |
| break |
|
|
| system_messages = messages[:first_non_system] |
| non_system_messages = messages[first_non_system:] |
|
|
| messages_to_delete = len(non_system_messages) // 2 |
| if messages_to_delete == 0: |
| return messages |
|
|
| truncated_non_system = non_system_messages[messages_to_delete:] |
|
|
| index = next( |
| (i for i, item in enumerate(truncated_non_system) if item.role == "user"), |
| None, |
| ) |
| if index is not None: |
| truncated_non_system = truncated_non_system[index:] |
|
|
| result = system_messages + truncated_non_system |
|
|
| return self.fix_messages(result) |
|
|