| from __future__ import annotations |
|
|
| import base64 |
| import enum |
| import json |
| from dataclasses import dataclass, field |
| from typing import Any |
|
|
| from anthropic.types import Message as AnthropicMessage |
| from google.genai.types import GenerateContentResponse |
| from openai.types.chat.chat_completion import ChatCompletion |
|
|
| import astrbot.core.message.components as Comp |
| from astrbot import logger |
| from astrbot.core.agent.message import ( |
| AssistantMessageSegment, |
| ContentPart, |
| ToolCall, |
| ToolCallMessageSegment, |
| ) |
| from astrbot.core.agent.tool import ToolSet |
| from astrbot.core.db.po import Conversation |
| from astrbot.core.message.message_event_result import MessageChain |
| from astrbot.core.utils.io import download_image_by_url |
|
|
|
|
class ProviderType(enum.Enum):
    """Capability categories a provider adapter can belong to."""

    # Conversational LLM completion.
    CHAT_COMPLETION = "chat_completion"
    # Audio in, text out.
    SPEECH_TO_TEXT = "speech_to_text"
    # Text in, audio out.
    TEXT_TO_SPEECH = "text_to_speech"
    # Text embedding.
    EMBEDDING = "embedding"
    # Candidate re-ranking.
    RERANK = "rerank"
|
|
|
|
@dataclass
class ProviderMeta:
    """Basic metadata describing one provider instance."""

    id: str
    """Unique id of the provider instance, as configured by the user."""
    model: str | None
    """Name of the model this provider instance currently uses."""
    type: str
    """Name of the provider adapter, e.g. ``openai`` or ``ollama``."""
    provider_type: ProviderType = ProviderType.CHAT_COMPLETION
    """Capability type of the provider adapter."""
|
|
|
|
@dataclass
class ProviderMetaData(ProviderMeta):
    """Registration metadata of a provider adapter."""

    desc: str = ""
    """Short description of the provider adapter."""
    cls_type: Any = None
    """Class type of the provider adapter."""
    default_config_tmpl: dict | None = None
    """Default configuration template of the provider adapter."""
    provider_display_name: str | None = None
    """Display name shown on the WebUI configuration page; when empty, ``type`` is used."""
|
|
|
|
@dataclass
class ToolCallsResult:
    """Result of a round of tool calls."""

    tool_calls_info: AssistantMessageSegment
    """Information about the function calls that were made."""
    tool_calls_result: list[ToolCallMessageSegment]
    """Results of the function calls."""

    def to_openai_messages(self) -> list[dict]:
        """Return the call info followed by every call result, each dumped via ``model_dump()``."""
        messages = [self.tool_calls_info.model_dump()]
        messages.extend(segment.model_dump() for segment in self.tool_calls_result)
        return messages

    def to_openai_messages_model(
        self,
    ) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
        """Return the call info followed by every call result, as the segment objects themselves."""
        segments = [self.tool_calls_info]
        segments += self.tool_calls_result
        return segments
|
|
|
|
@dataclass
class ProviderRequest:
    """A single request to be sent to an LLM provider."""

    prompt: str | None = None
    """The prompt text."""
    session_id: str | None = ""
    """Session ID."""
    image_urls: list[str] = field(default_factory=list)
    """List of image URLs (http(s) URL, ``file:///`` URI, ``base64://`` payload or local path)."""
    extra_user_content_parts: list[ContentPart] = field(default_factory=list)
    """Extra content parts appended after the user message text (e.g. system
    reminders or instructions). Both dicts and ContentPart objects are accepted."""
    func_tool: ToolSet | None = None
    """Available function tools."""
    contexts: list[dict] = field(default_factory=list)
    """
    Context list in OpenAI format.
    See https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
    """
    system_prompt: str = ""
    """System prompt."""
    conversation: Conversation | None = None
    """The associated conversation object."""
    tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
    """Tool-call results produced after the previous request, to be attached.
    See: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
    model: str | None = None
    """Model name; when None the provider's default model is used."""

    def __repr__(self) -> str:
        """Return a compact, log-friendly summary (images are counted, not printed)."""
        conversation_id = self.conversation.cid if self.conversation else "N/A"
        return (
            f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
            f"image_count={len(self.image_urls or [])}, "
            f"func_tool={self.func_tool}, "
            f"contexts={self._print_friendly_context()}, "
            f"system_prompt={self.system_prompt}, "
            f"conversation_id={conversation_id})"
        )

    def __str__(self) -> str:
        return self.__repr__()

    def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None:
        """Append a tool-call result, normalizing ``self.tool_calls_result`` to a list first."""
        if not self.tool_calls_result:
            self.tool_calls_result = []
        elif isinstance(self.tool_calls_result, ToolCallsResult):
            # A single result was stored directly; wrap it so we can append.
            self.tool_calls_result = [self.tool_calls_result]
        self.tool_calls_result.append(tool_calls_result)

    def _print_friendly_context(self) -> str:
        """Render ``contexts`` as ``role: text`` lines for logging.

        Image parts are not printed; they are summarized as an image count.
        Without contexts, a summary of the prompt and image count is returned.
        """
        if not self.contexts:
            return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"

        result_parts = []
        for ctx in self.contexts:
            role = ctx.get("role", "unknown")
            content = ctx.get("content", "")

            if isinstance(content, str):
                result_parts.append(f"{role}: {content}")
            elif isinstance(content, list):
                msg_parts = []
                image_count = 0
                for item in content:
                    if not isinstance(item, dict):
                        # This only feeds repr/logging; never crash on odd items.
                        continue
                    item_type = item.get("type", "")
                    if item_type == "text":
                        msg_parts.append(item.get("text", ""))
                    elif item_type == "image_url":
                        image_count += 1

                if image_count > 0:
                    if msg_parts:
                        msg_parts.append(f"[+{image_count} images]")
                    else:
                        msg_parts.append(f"[{image_count} images]")

                result_parts.append(f"{role}: {''.join(msg_parts)}")
            # Any other content type (e.g. None) is omitted from the output.

        return "\n".join(result_parts)

    async def assemble_context(self) -> dict:
        """Wrap the request (prompt, extra parts and images) as an OpenAI user message.

        Returns:
            ``{"role": "user", "content": ...}`` where content is the plain
            prompt string when the request is text-only, otherwise a list of
            content blocks.
        """
        content_blocks = []

        # 1. The user's text; a placeholder is used when only images were sent.
        if self.prompt and self.prompt.strip():
            content_blocks.append({"type": "text", "text": self.prompt})
        elif self.image_urls:
            content_blocks.append({"type": "text", "text": "[图片]"})

        # 2. Extra content parts; dicts are used as-is (shallow-copied),
        #    model objects are dumped.
        for part in self.extra_user_content_parts or []:
            if isinstance(part, dict):
                content_blocks.append(dict(part))
            else:
                content_blocks.append(part.model_dump())

        # 3. Images, each converted to a base64 data URL.
        for image_url in self.image_urls or []:
            if image_url.startswith("http"):
                image_path = await download_image_by_url(image_url)
                image_data = await self._encode_image_bs64(image_path)
            elif image_url.startswith("file:///"):
                # Strip only the leading scheme, not later occurrences.
                image_path = image_url[len("file:///") :]
                image_data = await self._encode_image_bs64(image_path)
            else:
                image_data = await self._encode_image_bs64(image_url)
            if not image_data:
                logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
                continue
            content_blocks.append(
                {"type": "image_url", "image_url": {"url": image_data}},
            )

        # A text-only request keeps the simple string form of ``content``.
        if (
            len(content_blocks) == 1
            and content_blocks[0]["type"] == "text"
            and not self.extra_user_content_parts
            and not self.image_urls
        ):
            return {"role": "user", "content": content_blocks[0]["text"]}

        return {"role": "user", "content": content_blocks}

    async def _encode_image_bs64(self, image_url: str) -> str:
        """Convert an image reference into a ``data:image/jpeg;base64,`` URL.

        Args:
            image_url: Either a ``base64://`` payload or a path to a local file.

        Raises:
            OSError: If the local file cannot be opened or read.
        """
        if image_url.startswith("base64://"):
            return image_url.replace("base64://", "data:image/jpeg;base64,")
        with open(image_url, "rb") as f:
            image_bs64 = base64.b64encode(f.read()).decode("utf-8")
        return "data:image/jpeg;base64," + image_bs64
|
|
|
|
@dataclass
class TokenUsage:
    """Token counts for a response, split into uncached input, cached input and output."""

    input_other: int = 0
    """Number of input tokens, not counting cached ones."""
    input_cached: int = 0
    """Number of cached input tokens."""
    output: int = 0
    """Number of output tokens."""

    @property
    def input(self) -> int:
        """All input tokens: uncached plus cached."""
        return self.input_other + self.input_cached

    @property
    def total(self) -> int:
        """All tokens: input plus output."""
        return self.input + self.output

    def __add__(self, other: TokenUsage) -> TokenUsage:
        """Return a new usage holding the field-wise sum."""
        summed = (
            self.input_other + other.input_other,
            self.input_cached + other.input_cached,
            self.output + other.output,
        )
        return TokenUsage(*summed)

    def __sub__(self, other: TokenUsage) -> TokenUsage:
        """Return a new usage holding the field-wise difference."""
        diff = (
            self.input_other - other.input_other,
            self.input_cached - other.input_cached,
            self.output - other.output,
        )
        return TokenUsage(*diff)
|
|
|
|
@dataclass
class LLMResponse:
    """A response (or streamed chunk) returned by an LLM provider."""

    role: str
    """The role of the message, e.g., assistant, tool, err"""
    result_chain: MessageChain | None = None
    """A chain of message components representing the text completion from LLM."""
    tools_call_args: list[dict[str, Any]] = field(default_factory=list)
    """Tool call arguments."""
    tools_call_name: list[str] = field(default_factory=list)
    """Tool call names."""
    tools_call_ids: list[str] = field(default_factory=list)
    """Tool call IDs."""
    tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
    """Tool call extra content. tool_call_id -> extra_content dict"""
    reasoning_content: str = ""
    """The reasoning content extracted from the LLM, if any."""
    reasoning_signature: str | None = None
    """The signature of the reasoning content, if any."""

    raw_completion: (
        ChatCompletion | GenerateContentResponse | AnthropicMessage | None
    ) = None
    """The raw completion response from the LLM provider."""

    _completion_text: str = ""
    """The plain text of the completion."""

    is_chunk: bool = False
    """Indicates if the response is a chunked response."""

    id: str | None = None
    """The ID of the response. For chunked responses, it's the ID of the chunk; for non-chunked responses, it's the ID of the response."""
    usage: TokenUsage | None = None
    """The usage of the response. For chunked responses, it's the usage of the chunk; for non-chunked responses, it's the usage of the response."""

    def __init__(
        self,
        role: str,
        completion_text: str | None = None,
        result_chain: MessageChain | None = None,
        tools_call_args: list[dict[str, Any]] | None = None,
        tools_call_name: list[str] | None = None,
        tools_call_ids: list[str] | None = None,
        tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
        reasoning_content: str | None = None,
        reasoning_signature: str | None = None,
        raw_completion: ChatCompletion
        | GenerateContentResponse
        | AnthropicMessage
        | None = None,
        is_chunk: bool = False,
        id: str | None = None,
        usage: TokenUsage | None = None,
    ) -> None:
        """Initialize an LLMResponse.

        Args:
            role: The role: assistant, tool or err.
            completion_text: The result text. Deprecated; prefer
                ``result_chain``. ``None`` is treated as "".
            result_chain: The resulting message chain. Defaults to None.
            tools_call_args: Tool call arguments. ``None`` becomes an empty list.
            tools_call_name: Tool call names. ``None`` becomes an empty list.
            tools_call_ids: Tool call IDs. ``None`` becomes an empty list.
            tools_call_extra_content: Mapping of tool call ID to its extra
                content. ``None`` becomes an empty dict.
            reasoning_content: Reasoning text. ``None`` is treated as "".
            reasoning_signature: Signature of the reasoning content, if any.
            raw_completion: The raw provider response. Defaults to None.
            is_chunk: Whether this is a chunk of a streamed response.
            id: ID of the response (or chunk). Only assigned when not None.
            usage: Token usage of the response (or chunk). Only assigned
                when not None.
        """
        # The field is declared ``str`` with a "" default, so never store
        # None: the ``completion_text`` getter would otherwise return None.
        if completion_text is None:
            completion_text = ""
        if reasoning_content is None:
            reasoning_content = ""
        if tools_call_args is None:
            tools_call_args = []
        if tools_call_name is None:
            tools_call_name = []
        if tools_call_ids is None:
            tools_call_ids = []
        if tools_call_extra_content is None:
            tools_call_extra_content = {}

        self.role = role
        # Assigned before ``result_chain``: at this point the setter still
        # sees the class-level default (None) and therefore stores the text
        # in ``_completion_text`` instead of rewriting the chain passed in.
        self.completion_text = completion_text
        self.result_chain = result_chain
        self.tools_call_args = tools_call_args
        self.tools_call_name = tools_call_name
        self.tools_call_ids = tools_call_ids
        self.tools_call_extra_content = tools_call_extra_content
        self.reasoning_content = reasoning_content
        self.reasoning_signature = reasoning_signature
        self.raw_completion = raw_completion
        self.is_chunk = is_chunk

        # When omitted, these fall back to the class-level default (None).
        if id is not None:
            self.id = id
        if usage is not None:
            self.usage = usage

    @property
    def completion_text(self):
        """Plain text of the completion; taken from ``result_chain`` when one is set."""
        if self.result_chain:
            return self.result_chain.get_plain_text()
        return self._completion_text

    @completion_text.setter
    def completion_text(self, value) -> None:
        """Set the completion text.

        With a result chain present, every ``Plain`` component is dropped and
        a single ``Plain(value)`` is inserted at the front; otherwise the
        value is stored in ``_completion_text``.
        """
        if self.result_chain:
            self.result_chain.chain = [
                comp
                for comp in self.result_chain.chain
                if not isinstance(comp, Comp.Plain)
            ]
            self.result_chain.chain.insert(0, Comp.Plain(value))
        else:
            self._completion_text = value

    def to_openai_tool_calls(self) -> list[dict]:
        """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
        ret = []
        for idx, tool_call_arg in enumerate(self.tools_call_args):
            call_id = self.tools_call_ids[idx]
            payload = {
                "id": call_id,
                "function": {
                    "name": self.tools_call_name[idx],
                    "arguments": json.dumps(tool_call_arg),
                },
                "type": "function",
            }
            # Only attach extra content when there is a non-empty entry.
            extra_content = self.tools_call_extra_content.get(call_id)
            if extra_content:
                payload["extra_content"] = extra_content
            ret.append(payload)
        return ret

    def to_openai_to_calls_model(self) -> list[ToolCall]:
        """The same as to_openai_tool_calls but return pydantic model."""
        ret = []
        for idx, tool_call_arg in enumerate(self.tools_call_args):
            call_id = self.tools_call_ids[idx]
            ret.append(
                ToolCall(
                    id=call_id,
                    function=ToolCall.FunctionBody(
                        name=self.tools_call_name[idx],
                        arguments=json.dumps(tool_call_arg),
                    ),
                    # Passed through as-is; None when the call has no entry.
                    extra_content=self.tools_call_extra_content.get(call_id),
                ),
            )
        return ret
|
|
|
|
@dataclass
class RerankResult:
    """One entry of a rerank result."""

    index: int
    """Index position within the candidate list."""
    relevance_score: float
    """Relevance score."""
|
|