from __future__ import annotations import base64 import enum import json from dataclasses import dataclass, field from typing import Any from anthropic.types import Message as AnthropicMessage from google.genai.types import GenerateContentResponse from openai.types.chat.chat_completion import ChatCompletion import astrbot.core.message.components as Comp from astrbot import logger from astrbot.core.agent.message import ( AssistantMessageSegment, ContentPart, ToolCall, ToolCallMessageSegment, ) from astrbot.core.agent.tool import ToolSet from astrbot.core.db.po import Conversation from astrbot.core.message.message_event_result import MessageChain from astrbot.core.utils.io import download_image_by_url class ProviderType(enum.Enum): CHAT_COMPLETION = "chat_completion" SPEECH_TO_TEXT = "speech_to_text" TEXT_TO_SPEECH = "text_to_speech" EMBEDDING = "embedding" RERANK = "rerank" @dataclass class ProviderMeta: """The basic metadata of a provider instance.""" id: str """the unique id of the provider instance that user configured""" model: str | None """the model name of the provider instance currently used""" type: str """the name of the provider adapter, such as openai, ollama""" provider_type: ProviderType = ProviderType.CHAT_COMPLETION """the capability type of the provider adapter""" @dataclass class ProviderMetaData(ProviderMeta): """The metadata of a provider adapter for registration.""" desc: str = "" """the short description of the provider adapter""" cls_type: Any = None """the class type of the provider adapter""" default_config_tmpl: dict | None = None """the default configuration template of the provider adapter""" provider_display_name: str | None = None """the display name of the provider shown in the WebUI configuration page; if empty, the type is used""" @dataclass class ToolCallsResult: """工具调用结果""" tool_calls_info: AssistantMessageSegment """函数调用的信息""" tool_calls_result: list[ToolCallMessageSegment] """函数调用的结果""" def to_openai_messages(self) -> list[dict]: ret = [ self.tool_calls_info.model_dump(), *[item.model_dump() for item in self.tool_calls_result], ] return ret def to_openai_messages_model( self, ) -> list[AssistantMessageSegment | ToolCallMessageSegment]: return [ self.tool_calls_info, *self.tool_calls_result, ] @dataclass class ProviderRequest: prompt: str | None = None """提示词""" session_id: str | None = "" """会话 ID""" image_urls: list[str] = field(default_factory=list) """图片 URL 列表""" extra_user_content_parts: list[ContentPart] = field(default_factory=list) """额外的用户消息内容部分列表,用于在用户消息后添加额外的内容块(如系统提醒、指令等)。支持 dict 或 ContentPart 对象""" func_tool: ToolSet | None = None """可用的函数工具""" contexts: list[dict] = field(default_factory=list) """ OpenAI 格式上下文列表。 参考 https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages """ system_prompt: str = "" """系统提示词""" conversation: Conversation | None = None """关联的对话对象""" tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None """附加的上次请求后工具调用的结果。参考: https://platform.openai.com/docs/guides/function-calling#handling-function-calls""" model: str | None = None """模型名称,为 None 时使用提供商的默认模型""" def __repr__(self) -> str: return ( f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, " f"image_count={len(self.image_urls or [])}, " f"func_tool={self.func_tool}, " f"contexts={self._print_friendly_context()}, " f"system_prompt={self.system_prompt}, " f"conversation_id={self.conversation.cid if self.conversation else 'N/A'}, " ) def __str__(self) -> str: return self.__repr__() def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None: """添加工具调用结果到请求中""" if not self.tool_calls_result: self.tool_calls_result = [] if isinstance(self.tool_calls_result, ToolCallsResult): self.tool_calls_result = [self.tool_calls_result] self.tool_calls_result.append(tool_calls_result) def _print_friendly_context(self): """打印友好的消息上下文。将 image_url 的值替换为 """ if not self.contexts: return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}" result_parts = [] for ctx in self.contexts: role = ctx.get("role", "unknown") content = ctx.get("content", "") if isinstance(content, str): result_parts.append(f"{role}: {content}") elif isinstance(content, list): msg_parts = [] image_count = 0 for item in content: item_type = item.get("type", "") if item_type == "text": msg_parts.append(item.get("text", "")) elif item_type == "image_url": image_count += 1 if image_count > 0: if msg_parts: msg_parts.append(f"[+{image_count} images]") else: msg_parts.append(f"[{image_count} images]") result_parts.append(f"{role}: {''.join(msg_parts)}") return "\n".join(result_parts) async def assemble_context(self) -> dict: """将请求(prompt 和 image_urls)包装成 OpenAI 的消息格式。""" # 构建内容块列表 content_blocks = [] # 1. 用户原始发言(OpenAI 建议:用户发言在前) if self.prompt and self.prompt.strip(): content_blocks.append({"type": "text", "text": self.prompt}) elif self.image_urls: # 如果没有文本但有图片,添加占位文本 content_blocks.append({"type": "text", "text": "[图片]"}) # 2. 额外的内容块(系统提醒、指令等) if self.extra_user_content_parts: for part in self.extra_user_content_parts: content_blocks.append(part.model_dump()) # 3. 图片内容 if self.image_urls: for image_url in self.image_urls: if image_url.startswith("http"): image_path = await download_image_by_url(image_url) image_data = await self._encode_image_bs64(image_path) elif image_url.startswith("file:///"): image_path = image_url.replace("file:///", "") image_data = await self._encode_image_bs64(image_path) else: image_data = await self._encode_image_bs64(image_url) if not image_data: logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。") continue content_blocks.append( {"type": "image_url", "image_url": {"url": image_data}}, ) # 只有当只有一个来自 prompt 的文本块且没有额外内容块时,才降级为简单格式以保持向后兼容 if ( len(content_blocks) == 1 and content_blocks[0]["type"] == "text" and not self.extra_user_content_parts and not self.image_urls ): return {"role": "user", "content": content_blocks[0]["text"]} # 否则返回多模态格式 return {"role": "user", "content": content_blocks} async def _encode_image_bs64(self, image_url: str) -> str: """将图片转换为 base64""" if image_url.startswith("base64://"): return image_url.replace("base64://", "data:image/jpeg;base64,") with open(image_url, "rb") as f: image_bs64 = base64.b64encode(f.read()).decode("utf-8") return "data:image/jpeg;base64," + image_bs64 return "" @dataclass class TokenUsage: input_other: int = 0 """The number of input tokens, excluding cached tokens.""" input_cached: int = 0 """The number of input cached tokens.""" output: int = 0 """The number of output tokens.""" @property def total(self) -> int: return self.input_other + self.input_cached + self.output @property def input(self) -> int: return self.input_other + self.input_cached def __add__(self, other: TokenUsage) -> TokenUsage: return TokenUsage( input_other=self.input_other + other.input_other, input_cached=self.input_cached + other.input_cached, output=self.output + other.output, ) def __sub__(self, other: TokenUsage) -> TokenUsage: return TokenUsage( input_other=self.input_other - other.input_other, input_cached=self.input_cached - other.input_cached, output=self.output - other.output, ) @dataclass class LLMResponse: role: str """The role of the message, e.g., assistant, tool, err""" result_chain: MessageChain | None = None """A chain of message components representing the text completion from LLM.""" tools_call_args: list[dict[str, Any]] = field(default_factory=list) """Tool call arguments.""" tools_call_name: list[str] = field(default_factory=list) """Tool call names.""" tools_call_ids: list[str] = field(default_factory=list) """Tool call IDs.""" tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict) """Tool call extra content. tool_call_id -> extra_content dict""" reasoning_content: str = "" """The reasoning content extracted from the LLM, if any.""" reasoning_signature: str | None = None """The signature of the reasoning content, if any.""" raw_completion: ( ChatCompletion | GenerateContentResponse | AnthropicMessage | None ) = None """The raw completion response from the LLM provider.""" _completion_text: str = "" """The plain text of the completion.""" is_chunk: bool = False """Indicates if the response is a chunked response.""" id: str | None = None """The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response.""" usage: TokenUsage | None = None """The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response.""" def __init__( self, role: str, completion_text: str | None = None, result_chain: MessageChain | None = None, tools_call_args: list[dict[str, Any]] | None = None, tools_call_name: list[str] | None = None, tools_call_ids: list[str] | None = None, tools_call_extra_content: dict[str, dict[str, Any]] | None = None, reasoning_content: str | None = None, reasoning_signature: str | None = None, raw_completion: ChatCompletion | GenerateContentResponse | AnthropicMessage | None = None, is_chunk: bool = False, id: str | None = None, usage: TokenUsage | None = None, ) -> None: """初始化 LLMResponse Args: role (str): 角色, assistant, tool, err completion_text (str, optional): 返回的结果文本,已经过时,推荐使用 result_chain. Defaults to "". result_chain (MessageChain, optional): 返回的消息链. Defaults to None. tools_call_args (List[Dict[str, any]], optional): 工具调用参数. Defaults to None. tools_call_name (List[str], optional): 工具调用名称. Defaults to None. raw_completion (ChatCompletion, optional): 原始响应, OpenAI 格式. Defaults to None. """ if reasoning_content is None: reasoning_content = "" if tools_call_args is None: tools_call_args = [] if tools_call_name is None: tools_call_name = [] if tools_call_ids is None: tools_call_ids = [] if tools_call_extra_content is None: tools_call_extra_content = {} self.role = role self.completion_text = completion_text self.result_chain = result_chain self.tools_call_args = tools_call_args self.tools_call_name = tools_call_name self.tools_call_ids = tools_call_ids self.tools_call_extra_content = tools_call_extra_content self.reasoning_content = reasoning_content self.reasoning_signature = reasoning_signature self.raw_completion = raw_completion self.is_chunk = is_chunk if id is not None: self.id = id if usage is not None: self.usage = usage @property def completion_text(self): if self.result_chain: return self.result_chain.get_plain_text() return self._completion_text @completion_text.setter def completion_text(self, value) -> None: if self.result_chain: self.result_chain.chain = [ comp for comp in self.result_chain.chain if not isinstance(comp, Comp.Plain) ] # 清空 Plain 组件 self.result_chain.chain.insert(0, Comp.Plain(value)) else: self._completion_text = value def to_openai_tool_calls(self) -> list[dict]: """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): payload = { "id": self.tools_call_ids[idx], "function": { "name": self.tools_call_name[idx], "arguments": json.dumps(tool_call_arg), }, "type": "function", } if self.tools_call_extra_content.get(self.tools_call_ids[idx]): payload["extra_content"] = self.tools_call_extra_content[ self.tools_call_ids[idx] ] ret.append(payload) return ret def to_openai_to_calls_model(self) -> list[ToolCall]: """The same as to_openai_tool_calls but return pydantic model.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): ret.append( ToolCall( id=self.tools_call_ids[idx], function=ToolCall.FunctionBody( name=self.tools_call_name[idx], arguments=json.dumps(tool_call_arg), ), # the extra_content will not serialize if it's None when calling ToolCall.model_dump() extra_content=self.tools_call_extra_content.get( self.tools_call_ids[idx] ), ), ) return ret @dataclass class RerankResult: index: int """在候选列表中的索引位置""" relevance_score: float """相关性分数"""