import json
from typing import Protocol, runtime_checkable

from ..message import Message, TextPart


@runtime_checkable
class TokenCounter(Protocol):
    """Protocol for token counters.

    Provides an interface for counting tokens in message lists.
    """

    def count_tokens(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> int:
        """Count the total tokens in the message list.

        Args:
            messages: The message list.
            trusted_token_usage: The total token usage that the LLM API
                returned. In some cases this value is more accurate, but
                some APIs do not return it, so the value defaults to 0.

        Returns:
            The total token count.
        """
        ...


class EstimateTokenCounter:
    """Estimate token counter implementation.

    Provides a simple estimation of token count based on character types.
    """

    # Heuristic per-character weights used by ``_estimate_tokens``.
    _CJK_TOKEN_RATIO = 0.6
    _OTHER_TOKEN_RATIO = 0.3

    def count_tokens(
        self, messages: list[Message], trusted_token_usage: int = 0
    ) -> int:
        """Estimate the total token count of ``messages``.

        Args:
            messages: The message list.
            trusted_token_usage: When greater than 0, this value is returned
                as-is and no estimation is performed.

        Returns:
            ``trusted_token_usage`` if it is positive; otherwise the sum of
            the estimates for every message's text content and tool calls.
        """
        if trusted_token_usage > 0:
            return trusted_token_usage

        total = 0
        for msg in messages:
            content = msg.content
            if isinstance(content, str):
                total += self._estimate_tokens(content)
            elif isinstance(content, list):
                # Multimodal content: only TextPart entries contribute;
                # every other part type is skipped.
                total += sum(
                    self._estimate_tokens(part.text)
                    for part in content
                    if isinstance(part, TextPart)
                )

            # Tool calls are estimated from their JSON serialization.
            if msg.tool_calls:
                for tc in msg.tool_calls:
                    payload = tc if isinstance(tc, dict) else tc.model_dump()
                    # ensure_ascii=False keeps non-ASCII characters intact.
                    # With the default, each one becomes a six-character
                    # "\uXXXX" escape and would be scored as six "other"
                    # characters instead of one (possibly CJK) character.
                    tc_str = json.dumps(payload, ensure_ascii=False)
                    total += self._estimate_tokens(tc_str)

        return total

    def _estimate_tokens(self, text: str) -> int:
        """Estimate the token count of a single string.

        Characters in the range U+4E00..U+9FFF are weighted at 0.6 tokens
        each and all other characters at 0.3 tokens each; the weighted sum
        is truncated to an int.

        Args:
            text: The text to estimate.

        Returns:
            The estimated token count.
        """
        chinese_count = sum(1 for c in text if "\u4e00" <= c <= "\u9fff")
        other_count = len(text) - chinese_count
        return int(
            chinese_count * self._CJK_TOKEN_RATIO
            + other_count * self._OTHER_TOKEN_RATIO
        )