# astrbbbb / astrbot/core/provider/entities.py
# Uploaded by qa1145 ("Upload 1245 files", commit 8ede856, verified)
from __future__ import annotations
import base64
import enum
import json
from dataclasses import dataclass, field
from typing import Any
from anthropic.types import Message as AnthropicMessage
from google.genai.types import GenerateContentResponse
from openai.types.chat.chat_completion import ChatCompletion
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
AssistantMessageSegment,
ContentPart,
ToolCall,
ToolCallMessageSegment,
)
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url
class ProviderType(enum.Enum):
    """Capability category that a provider adapter implements."""

    # Chat-completion (text generation) providers.
    CHAT_COMPLETION = "chat_completion"
    # Speech-to-text providers.
    SPEECH_TO_TEXT = "speech_to_text"
    # Text-to-speech providers.
    TEXT_TO_SPEECH = "text_to_speech"
    # Embedding providers.
    EMBEDDING = "embedding"
    # Rerank providers.
    RERANK = "rerank"
@dataclass
class ProviderMeta:
    """Basic metadata describing one configured provider instance."""

    # Unique ID of the provider instance, as configured by the user.
    id: str
    # Name of the model this instance currently uses (may be None).
    model: str | None
    # Name of the provider adapter, e.g. "openai" or "ollama".
    type: str
    # Capability type of the adapter; chat completion unless given explicitly.
    provider_type: ProviderType = ProviderType.CHAT_COMPLETION
@dataclass
class ProviderMetaData(ProviderMeta):
    """Registration metadata for a provider adapter (extends ProviderMeta)."""

    # Short description of the adapter.
    desc: str = ""
    # The adapter's class object.
    cls_type: Any = None
    # Default configuration template for the adapter.
    default_config_tmpl: dict | None = None
    # Name shown on the WebUI configuration page; when empty, `type` is shown instead.
    provider_display_name: str | None = None
@dataclass
class ToolCallsResult:
    """The outcome of one round of tool calls.

    Pairs the assistant message segment that requested the tool calls with the
    tool message segments carrying their results.
    """

    # Assistant message segment describing the tool calls.
    tool_calls_info: AssistantMessageSegment
    # Tool message segments holding the call results.
    tool_calls_result: list[ToolCallMessageSegment]

    def to_openai_messages_model(
        self,
    ) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
        """Return the assistant segment followed by every tool-result segment."""
        messages = [self.tool_calls_info]
        messages.extend(self.tool_calls_result)
        return messages

    def to_openai_messages(self) -> list[dict]:
        """Like ``to_openai_messages_model``, but each segment is ``model_dump()``-ed."""
        return [segment.model_dump() for segment in self.to_openai_messages_model()]
@dataclass
class ProviderRequest:
    """One request to a chat-completion provider.

    Holds the user's turn (prompt, images, extra content parts) together with
    the surrounding context (system prompt, prior messages, tools, previous
    tool-call results). ``assemble_context`` turns the user's turn into an
    OpenAI-format user message.
    """

    # The user's prompt text.
    prompt: str | None = None
    # Session ID.
    session_id: str | None = ""
    # Image sources for the user turn: http(s) URLs, file:/// URLs,
    # base64:// strings, or local file paths.
    image_urls: list[str] = field(default_factory=list)
    # Extra content blocks appended after the user's prompt (e.g. system
    # reminders or instructions). Entries may be ContentPart objects or
    # already-serialized dicts.
    extra_user_content_parts: list[ContentPart | dict[str, Any]] = field(
        default_factory=list
    )
    # Function tools available to the model.
    func_tool: ToolSet | None = None
    # Prior context as a list of OpenAI-format messages. See
    # https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
    contexts: list[dict] = field(default_factory=list)
    # System prompt.
    system_prompt: str = ""
    # The conversation this request is associated with, if any.
    conversation: Conversation | None = None
    # Tool-call results from the previous request, attached to this one. See
    # https://platform.openai.com/docs/guides/function-calling#handling-function-calls
    tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
    # Model name; None means the provider's default model is used.
    model: str | None = None

    def __repr__(self) -> str:
        conversation_id = self.conversation.cid if self.conversation else "N/A"
        return (
            f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
            f"image_count={len(self.image_urls or [])}, "
            f"func_tool={self.func_tool}, "
            f"contexts={self._print_friendly_context()}, "
            f"system_prompt={self.system_prompt}, "
            f"conversation_id={conversation_id}, "
            f"model={self.model})"
        )

    def __str__(self) -> str:
        return self.__repr__()

    def append_tool_calls_result(self, tool_calls_result: ToolCallsResult) -> None:
        """Append one round of tool-call results to this request.

        ``self.tool_calls_result`` may be ``None`` or a single
        ``ToolCallsResult``; it is normalized to a list before appending.
        """
        if not self.tool_calls_result:
            self.tool_calls_result = []
        if isinstance(self.tool_calls_result, ToolCallsResult):
            self.tool_calls_result = [self.tool_calls_result]
        self.tool_calls_result.append(tool_calls_result)

    def _print_friendly_context(self) -> str:
        """Render ``contexts`` as readable text for logging.

        Each message becomes ``"<role>: <text>"``. Image parts are not printed;
        they are summarized as an image count such as ``[+2 images]``. When
        ``contexts`` is empty, the prompt and the image count are summarized
        instead.
        """
        if not self.contexts:
            return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"

        result_parts = []
        for ctx in self.contexts:
            role = ctx.get("role", "unknown")
            content = ctx.get("content", "")
            if isinstance(content, str):
                result_parts.append(f"{role}: {content}")
            elif isinstance(content, list):
                msg_parts = []
                image_count = 0
                for item in content:
                    item_type = item.get("type", "")
                    if item_type == "text":
                        msg_parts.append(item.get("text", ""))
                    elif item_type == "image_url":
                        image_count += 1
                if image_count > 0:
                    # "+" marks images that accompany text.
                    prefix = "+" if msg_parts else ""
                    msg_parts.append(f"[{prefix}{image_count} images]")
                result_parts.append(f"{role}: {''.join(msg_parts)}")
        return "\n".join(result_parts)

    async def assemble_context(self) -> dict:
        """Wrap the prompt, extra content parts and images into one OpenAI user message.

        Returns:
            ``{"role": "user", "content": <str>}`` when the request carries only
            prompt text (kept for backward compatibility); otherwise
            ``{"role": "user", "content": [<content block>, ...]}``.
        """
        content_blocks: list[dict] = []

        # 1. The user's own words come first (as OpenAI recommends).
        if self.prompt and self.prompt.strip():
            content_blocks.append({"type": "text", "text": self.prompt})
        elif self.image_urls:
            # No text but there are images: add a placeholder text block.
            content_blocks.append({"type": "text", "text": "[图片]"})

        # 2. Extra content blocks (system reminders, instructions, ...). Both
        #    ContentPart objects and plain dicts are accepted.
        for part in self.extra_user_content_parts or []:
            if isinstance(part, dict):
                content_blocks.append(dict(part))
            else:
                content_blocks.append(part.model_dump())

        # 3. Images, inlined as base64 data URLs; empty results are skipped.
        for image_url in self.image_urls or []:
            image_data = await self._resolve_image_data(image_url)
            if not image_data:
                logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
                continue
            content_blocks.append(
                {"type": "image_url", "image_url": {"url": image_data}},
            )

        # Collapse to the plain-string form only when the single block is the
        # prompt text and nothing else was requested (backward compatibility).
        if (
            len(content_blocks) == 1
            and content_blocks[0]["type"] == "text"
            and not self.extra_user_content_parts
            and not self.image_urls
        ):
            return {"role": "user", "content": content_blocks[0]["text"]}
        return {"role": "user", "content": content_blocks}

    async def _resolve_image_data(self, image_url: str) -> str:
        """Turn one ``image_urls`` entry into a base64 data URL.

        Values starting with "http" are downloaded first. ``file:///`` URLs have
        that prefix removed. Anything else (``base64://`` or a local path) is
        encoded as-is.
        """
        if image_url.startswith("http"):
            image_path = await download_image_by_url(image_url)
        elif image_url.startswith("file:///"):
            # NOTE(review): removing all of "file:///" turns a standard POSIX
            # URL such as file:///home/a.jpg into the relative path
            # "home/a.jpg"; it only yields an absolute path for URLs built as
            # "file:///" + "/abs/path" or Windows drive paths. Confirm against
            # the code that produces these URLs.
            image_path = image_url.replace("file:///", "")
        else:
            image_path = image_url
        return await self._encode_image_bs64(image_path)

    async def _encode_image_bs64(self, image_url: str) -> str:
        """Return the image as a ``data:image/jpeg;base64,...`` URL.

        ``base64://`` inputs are re-prefixed without decoding. Any other value
        is treated as a local file path and read from disk. The MIME type is
        always labelled ``image/jpeg``, whatever the actual format is.
        """
        if image_url.startswith("base64://"):
            return image_url.replace("base64://", "data:image/jpeg;base64,")
        # NOTE: a blocking file read inside a coroutine.
        with open(image_url, "rb") as f:
            image_bs64 = base64.b64encode(f.read()).decode("utf-8")
        return "data:image/jpeg;base64," + image_bs64
@dataclass
class TokenUsage:
    """Token counts for an LLM call or a streamed chunk.

    Input tokens are split into non-cached and cached parts. ``input`` and
    ``total`` are derived sums. Instances support element-wise ``+`` and ``-``.
    """

    # Input tokens, excluding cached ones.
    input_other: int = 0
    # Cached input tokens.
    input_cached: int = 0
    # Output tokens.
    output: int = 0

    @property
    def input(self) -> int:
        """All input tokens, cached and non-cached."""
        return self.input_other + self.input_cached

    @property
    def total(self) -> int:
        """Input tokens plus output tokens."""
        return self.input + self.output

    def __add__(self, other: TokenUsage) -> TokenUsage:
        """Element-wise sum of two usage records."""
        return TokenUsage(
            self.input_other + other.input_other,
            self.input_cached + other.input_cached,
            self.output + other.output,
        )

    def __sub__(self, other: TokenUsage) -> TokenUsage:
        """Element-wise difference of two usage records."""
        return TokenUsage(
            self.input_other - other.input_other,
            self.input_cached - other.input_cached,
            self.output - other.output,
        )
@dataclass
class LLMResponse:
    """A (possibly chunked) response from an LLM provider.

    The completion text lives either in ``result_chain`` (preferred) or in the
    legacy ``_completion_text`` field. The ``completion_text`` property reads
    and writes whichever one applies.
    """

    # Role of the message, e.g. "assistant", "tool", "err".
    role: str
    # Chain of message components holding the completion from the LLM.
    result_chain: MessageChain | None = None
    # Tool-call arguments; indexed in parallel with tools_call_name / tools_call_ids.
    tools_call_args: list[dict[str, Any]] = field(default_factory=list)
    # Tool-call names.
    tools_call_name: list[str] = field(default_factory=list)
    # Tool-call IDs.
    tools_call_ids: list[str] = field(default_factory=list)
    # Tool-call extra content: tool_call_id -> extra_content dict.
    tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
    # Reasoning content extracted from the LLM, if any.
    reasoning_content: str = ""
    # Signature of the reasoning content, if any.
    reasoning_signature: str | None = None
    # Raw completion object returned by the provider.
    raw_completion: (
        ChatCompletion | GenerateContentResponse | AnthropicMessage | None
    ) = None
    # Plain completion text; used only while result_chain is not set.
    _completion_text: str = ""
    # Whether this is one chunk of a streamed response.
    is_chunk: bool = False
    # Response ID (the chunk's ID for chunked responses).
    id: str | None = None
    # Token usage (the chunk's usage for chunked responses).
    usage: TokenUsage | None = None

    def __init__(
        self,
        role: str,
        completion_text: str | None = None,
        result_chain: MessageChain | None = None,
        tools_call_args: list[dict[str, Any]] | None = None,
        tools_call_name: list[str] | None = None,
        tools_call_ids: list[str] | None = None,
        tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
        reasoning_content: str | None = None,
        reasoning_signature: str | None = None,
        raw_completion: ChatCompletion
        | GenerateContentResponse
        | AnthropicMessage
        | None = None,
        is_chunk: bool = False,
        id: str | None = None,
        usage: TokenUsage | None = None,
    ) -> None:
        """Initialize an LLMResponse.

        Args:
            role: Message role: "assistant", "tool" or "err".
            completion_text: Plain result text. Deprecated; prefer
                ``result_chain``. ``None`` (the default) is stored as ``""``.
            result_chain: Message chain holding the result. Defaults to None.
            tools_call_args: Tool-call arguments. Defaults to an empty list.
            tools_call_name: Tool-call names. Defaults to an empty list.
            tools_call_ids: Tool-call IDs. Defaults to an empty list.
            tools_call_extra_content: Mapping tool_call_id -> extra content.
                Defaults to an empty dict.
            reasoning_content: Reasoning text. ``None`` is stored as ``""``.
            reasoning_signature: Signature of the reasoning content.
            raw_completion: Raw provider response. Defaults to None.
            is_chunk: Whether this is a chunk of a streamed response.
            id: Response (or chunk) ID.
            usage: Token usage.
        """
        self.role = role
        # Store the text directly: at this point no result_chain is set, so the
        # completion_text setter would write _completion_text anyway.
        self._completion_text = "" if completion_text is None else completion_text
        self.result_chain = result_chain
        self.tools_call_args = [] if tools_call_args is None else tools_call_args
        self.tools_call_name = [] if tools_call_name is None else tools_call_name
        self.tools_call_ids = [] if tools_call_ids is None else tools_call_ids
        self.tools_call_extra_content = (
            {} if tools_call_extra_content is None else tools_call_extra_content
        )
        self.reasoning_content = "" if reasoning_content is None else reasoning_content
        self.reasoning_signature = reasoning_signature
        self.raw_completion = raw_completion
        self.is_chunk = is_chunk
        self.id = id
        self.usage = usage

    @property
    def completion_text(self):
        """Plain text of the completion.

        Taken from ``result_chain`` when it is set, otherwise from ``_completion_text``.
        """
        if self.result_chain:
            return self.result_chain.get_plain_text()
        return self._completion_text

    @completion_text.setter
    def completion_text(self, value) -> None:
        """Set the completion text.

        When ``result_chain`` is set, every ``Plain`` component is removed and a
        single ``Plain(value)`` is inserted at the front. Otherwise the value is
        stored in ``_completion_text``.
        """
        if self.result_chain:
            self.result_chain.chain = [
                comp
                for comp in self.result_chain.chain
                if not isinstance(comp, Comp.Plain)
            ]
            self.result_chain.chain.insert(0, Comp.Plain(value))
        else:
            self._completion_text = value

    def to_openai_tool_calls(self) -> list[dict]:
        """Convert the tool calls into OpenAI ``tool_calls`` dicts.

        Deprecated: use ``to_openai_to_calls_model`` instead. ``extra_content``
        is included only when a non-empty entry exists for that call ID.
        """
        ret = []
        for idx, tool_call_arg in enumerate(self.tools_call_args):
            call_id = self.tools_call_ids[idx]
            payload = {
                "id": call_id,
                "function": {
                    "name": self.tools_call_name[idx],
                    "arguments": json.dumps(tool_call_arg),
                },
                "type": "function",
            }
            extra_content = self.tools_call_extra_content.get(call_id)
            if extra_content:
                payload["extra_content"] = extra_content
            ret.append(payload)
        return ret

    def to_openai_to_calls_model(self) -> list[ToolCall]:
        """Same as ``to_openai_tool_calls`` but returns ``ToolCall`` models."""
        ret = []
        for idx, tool_call_arg in enumerate(self.tools_call_args):
            call_id = self.tools_call_ids[idx]
            ret.append(
                ToolCall(
                    id=call_id,
                    function=ToolCall.FunctionBody(
                        name=self.tools_call_name[idx],
                        arguments=json.dumps(tool_call_arg),
                    ),
                    # extra_content is not serialized by ToolCall.model_dump()
                    # when it is None.
                    extra_content=self.tools_call_extra_content.get(call_id),
                ),
            )
        return ret
@dataclass
class RerankResult:
    """A single scored entry returned by a rerank provider."""

    # Position of the entry in the original candidate list.
    index: int
    # Relevance score assigned to that entry.
    relevance_score: float