| import abc |
| import asyncio |
| import os |
| from collections.abc import AsyncGenerator |
| from typing import TypeAlias, Union |
|
|
| from astrbot.core.agent.message import ContentPart, Message |
| from astrbot.core.agent.tool import ToolSet |
| from astrbot.core.provider.entities import ( |
| LLMResponse, |
| ProviderMeta, |
| RerankResult, |
| ToolCallsResult, |
| ) |
| from astrbot.core.provider.register import provider_cls_map |
| from astrbot.core.utils.astrbot_path import get_astrbot_path |
|
|
| Providers: TypeAlias = Union[ |
| "Provider", |
| "STTProvider", |
| "TTSProvider", |
| "EmbeddingProvider", |
| "RerankProvider", |
| ] |
|
|
|
|
| class AbstractProvider(abc.ABC): |
| """Provider Abstract Class""" |
|
|
| def __init__(self, provider_config: dict) -> None: |
| super().__init__() |
| self.model_name = "" |
| self.provider_config = provider_config |
|
|
| def set_model(self, model_name: str) -> None: |
| """Set the current model name""" |
| self.model_name = model_name |
|
|
| def get_model(self) -> str: |
| """Get the current model name""" |
| return self.model_name |
|
|
| def meta(self) -> ProviderMeta: |
| """Get the provider metadata""" |
| provider_type_name = self.provider_config["type"] |
| meta_data = provider_cls_map.get(provider_type_name) |
| if not meta_data: |
| raise ValueError(f"Provider type {provider_type_name} not registered") |
| meta = ProviderMeta( |
| id=self.provider_config.get("id", "default"), |
| model=self.get_model(), |
| type=provider_type_name, |
| provider_type=meta_data.provider_type, |
| ) |
| return meta |
|
|
| async def test(self) -> None: |
| """test the provider is a |
| |
| raises: |
| Exception: if the provider is not available |
| """ |
| ... |
|
|
|
|
| class Provider(AbstractProvider): |
| """Chat Provider""" |
|
|
| def __init__( |
| self, |
| provider_config: dict, |
| provider_settings: dict, |
| ) -> None: |
| super().__init__(provider_config) |
| self.provider_settings = provider_settings |
|
|
| @abc.abstractmethod |
| def get_current_key(self) -> str: |
| raise NotImplementedError |
|
|
| def get_keys(self) -> list[str]: |
| """获得提供商 Key""" |
| keys = self.provider_config.get("key", [""]) |
| return keys or [""] |
|
|
| @abc.abstractmethod |
| def set_key(self, key: str) -> None: |
| raise NotImplementedError |
|
|
| @abc.abstractmethod |
| async def get_models(self) -> list[str]: |
| """获得支持的模型列表""" |
| raise NotImplementedError |
|
|
| @abc.abstractmethod |
| async def text_chat( |
| self, |
| prompt: str | None = None, |
| session_id: str | None = None, |
| image_urls: list[str] | None = None, |
| func_tool: ToolSet | None = None, |
| contexts: list[Message] | list[dict] | None = None, |
| system_prompt: str | None = None, |
| tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, |
| model: str | None = None, |
| extra_user_content_parts: list[ContentPart] | None = None, |
| **kwargs, |
| ) -> LLMResponse: |
| """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 |
| |
| Args: |
| prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 |
| session_id: 会话 ID(此属性已经被废弃) |
| image_urls: 图片 URL 列表 |
| tools: tool set |
| contexts: 上下文,和 prompt 二选一使用 |
| tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling |
| extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等) |
| kwargs: 其他参数 |
| |
| Notes: |
| - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 |
| - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 |
| |
| """ |
| ... |
|
|
| async def text_chat_stream( |
| self, |
| prompt: str | None = None, |
| session_id: str | None = None, |
| image_urls: list[str] | None = None, |
| func_tool: ToolSet | None = None, |
| contexts: list[Message] | list[dict] | None = None, |
| system_prompt: str | None = None, |
| tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, |
| model: str | None = None, |
| **kwargs, |
| ) -> AsyncGenerator[LLMResponse, None]: |
| """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 |
| |
| Args: |
| prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 |
| session_id: 会话 ID(此属性已经被废弃) |
| image_urls: 图片 URL 列表 |
| tools: tool set |
| contexts: 上下文,和 prompt 二选一使用 |
| tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling |
| kwargs: 其他参数 |
| |
| Notes: |
| - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 |
| - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 |
| |
| """ |
| if False: |
| yield None |
| raise NotImplementedError() |
|
|
| async def pop_record(self, context: list) -> None: |
| """弹出 context 第一条非系统提示词对话记录""" |
| poped = 0 |
| indexs_to_pop = [] |
| for idx, record in enumerate(context): |
| if record["role"] == "system": |
| continue |
| indexs_to_pop.append(idx) |
| poped += 1 |
| if poped == 2: |
| break |
|
|
| for idx in reversed(indexs_to_pop): |
| context.pop(idx) |
|
|
| def _ensure_message_to_dicts( |
| self, |
| messages: list[dict] | list[Message] | None, |
| ) -> list[dict]: |
| """Convert a list of Message objects to a list of dictionaries.""" |
| if not messages: |
| return [] |
| dicts: list[dict] = [] |
| for message in messages: |
| if isinstance(message, Message): |
| dicts.append(message.model_dump()) |
| else: |
| dicts.append(message) |
|
|
| return dicts |
|
|
| async def test(self, timeout: float = 45.0) -> None: |
| await asyncio.wait_for( |
| self.text_chat(prompt="REPLY `PONG` ONLY"), |
| timeout=timeout, |
| ) |
|
|
|
|
| class STTProvider(AbstractProvider): |
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: |
| super().__init__(provider_config) |
| self.provider_config = provider_config |
| self.provider_settings = provider_settings |
|
|
| @abc.abstractmethod |
| async def get_text(self, audio_url: str) -> str: |
| """获取音频的文本""" |
| raise NotImplementedError |
|
|
| async def test(self) -> None: |
| sample_audio_path = os.path.join( |
| get_astrbot_path(), |
| "samples", |
| "stt_health_check.wav", |
| ) |
| await self.get_text(sample_audio_path) |
|
|
|
|
| class TTSProvider(AbstractProvider): |
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: |
| super().__init__(provider_config) |
| self.provider_config = provider_config |
| self.provider_settings = provider_settings |
|
|
| def support_stream(self) -> bool: |
| """是否支持流式 TTS |
| |
| Returns: |
| bool: True 表示支持流式处理,False 表示不支持(默认) |
| |
| Notes: |
| 子类可以重写此方法返回 True 来启用流式 TTS 支持 |
| """ |
| return False |
|
|
| @abc.abstractmethod |
| async def get_audio(self, text: str) -> str: |
| """获取文本的音频,返回音频文件路径""" |
| raise NotImplementedError |
|
|
| async def get_audio_stream( |
| self, |
| text_queue: asyncio.Queue[str | None], |
| audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]", |
| ) -> None: |
| """流式 TTS 处理方法。 |
| |
| 从 text_queue 中读取文本片段,将生成的音频数据(WAV 格式的 in-memory bytes)放入 audio_queue。 |
| 当 text_queue 收到 None 时,表示文本输入结束,此时应该处理完所有剩余文本并向 audio_queue 发送 None 表示结束。 |
| |
| Args: |
| text_queue: 输入文本队列,None 表示输入结束 |
| audio_queue: 输出音频队列(bytes 或 (text, bytes)),None 表示输出结束 |
| |
| Notes: |
| - 默认实现会将文本累积后一次性调用 get_audio 生成完整音频 |
| - 子类可以重写此方法实现真正的流式 TTS |
| - 音频数据应该是 WAV 格式的 bytes |
| """ |
| accumulated_text = "" |
|
|
| while True: |
| text_part = await text_queue.get() |
|
|
| if text_part is None: |
| |
| if accumulated_text: |
| try: |
| |
| audio_path = await self.get_audio(accumulated_text) |
| |
| with open(audio_path, "rb") as f: |
| audio_data = f.read() |
| await audio_queue.put((accumulated_text, audio_data)) |
| except Exception: |
| |
| pass |
| |
| await audio_queue.put(None) |
| break |
|
|
| accumulated_text += text_part |
|
|
| async def test(self) -> None: |
| audio_path = await self.get_audio("hi") |
|
|
| |
| if not os.path.exists(audio_path): |
| raise Exception("TTS test failed: audio file was not created") |
|
|
| file_size = os.path.getsize(audio_path) |
| if file_size == 0: |
| raise Exception( |
| "TTS test failed: generated audio file is empty (0 bytes). " |
| "Please check your TTS provider configuration, especially required parameters like group_id for MiniMax." |
| ) |
|
|
| |
| try: |
| os.remove(audio_path) |
| except Exception: |
| pass |
|
|
|
|
| class EmbeddingProvider(AbstractProvider): |
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: |
| super().__init__(provider_config) |
| self.provider_config = provider_config |
| self.provider_settings = provider_settings |
|
|
| @abc.abstractmethod |
| async def get_embedding(self, text: str) -> list[float]: |
| """获取文本的向量""" |
| ... |
|
|
| @abc.abstractmethod |
| async def get_embeddings(self, text: list[str]) -> list[list[float]]: |
| """批量获取文本的向量""" |
| ... |
|
|
| @abc.abstractmethod |
| def get_dim(self) -> int: |
| """获取向量的维度""" |
| ... |
|
|
| async def test(self) -> None: |
| await self.get_embedding("astrbot") |
|
|
| async def get_embeddings_batch( |
| self, |
| texts: list[str], |
| batch_size: int = 16, |
| tasks_limit: int = 3, |
| max_retries: int = 3, |
| progress_callback=None, |
| ) -> list[list[float]]: |
| """批量获取文本的向量,分批处理以节省内存 |
| |
| Args: |
| texts: 文本列表 |
| batch_size: 每批处理的文本数量 |
| tasks_limit: 并发任务数量限制 |
| max_retries: 失败时的最大重试次数 |
| progress_callback: 进度回调函数,接收参数 (current, total) |
| |
| Returns: |
| 向量列表 |
| |
| """ |
| semaphore = asyncio.Semaphore(tasks_limit) |
| all_embeddings: list[list[float]] = [] |
| failed_batches: list[tuple[int, list[str]]] = [] |
| completed_count = 0 |
| total_count = len(texts) |
|
|
| async def process_batch(batch_idx: int, batch_texts: list[str]) -> None: |
| nonlocal completed_count |
| async with semaphore: |
| for attempt in range(max_retries): |
| try: |
| batch_embeddings = await self.get_embeddings(batch_texts) |
| all_embeddings.extend(batch_embeddings) |
| completed_count += len(batch_texts) |
| if progress_callback: |
| await progress_callback(completed_count, total_count) |
| return |
| except Exception as e: |
| if attempt == max_retries - 1: |
| |
| failed_batches.append((batch_idx, batch_texts)) |
| raise Exception( |
| f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}", |
| ) |
| |
| await asyncio.sleep(2**attempt) |
|
|
| tasks = [] |
| for i in range(0, len(texts), batch_size): |
| batch_texts = texts[i : i + batch_size] |
| batch_idx = i // batch_size |
| tasks.append(process_batch(batch_idx, batch_texts)) |
|
|
| |
| results = await asyncio.gather(*tasks, return_exceptions=True) |
|
|
| |
| errors = [r for r in results if isinstance(r, Exception)] |
| if errors: |
| error_msg = ( |
| f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}" |
| ) |
| raise Exception(error_msg) |
|
|
| return all_embeddings |
|
|
|
|
| class RerankProvider(AbstractProvider): |
| def __init__(self, provider_config: dict, provider_settings: dict) -> None: |
| super().__init__(provider_config) |
| self.provider_config = provider_config |
| self.provider_settings = provider_settings |
|
|
| @abc.abstractmethod |
| async def rerank( |
| self, |
| query: str, |
| documents: list[str], |
| top_n: int | None = None, |
| ) -> list[RerankResult]: |
| """获取查询和文档的重排序分数""" |
| ... |
|
|
| async def test(self) -> None: |
| result = await self.rerank("Apple", documents=["apple", "banana"]) |
| if not result: |
| raise Exception("Rerank provider test failed, no results returned") |
|
|