import abc import asyncio import os from collections.abc import AsyncGenerator from typing import TypeAlias, Union from astrbot.core.agent.message import ContentPart, Message from astrbot.core.agent.tool import ToolSet from astrbot.core.provider.entities import ( LLMResponse, ProviderMeta, RerankResult, ToolCallsResult, ) from astrbot.core.provider.register import provider_cls_map from astrbot.core.utils.astrbot_path import get_astrbot_path Providers: TypeAlias = Union[ "Provider", "STTProvider", "TTSProvider", "EmbeddingProvider", "RerankProvider", ] class AbstractProvider(abc.ABC): """Provider Abstract Class""" def __init__(self, provider_config: dict) -> None: super().__init__() self.model_name = "" self.provider_config = provider_config def set_model(self, model_name: str) -> None: """Set the current model name""" self.model_name = model_name def get_model(self) -> str: """Get the current model name""" return self.model_name def meta(self) -> ProviderMeta: """Get the provider metadata""" provider_type_name = self.provider_config["type"] meta_data = provider_cls_map.get(provider_type_name) if not meta_data: raise ValueError(f"Provider type {provider_type_name} not registered") meta = ProviderMeta( id=self.provider_config.get("id", "default"), model=self.get_model(), type=provider_type_name, provider_type=meta_data.provider_type, ) return meta async def test(self) -> None: """test the provider is a raises: Exception: if the provider is not available """ ... class Provider(AbstractProvider): """Chat Provider""" def __init__( self, provider_config: dict, provider_settings: dict, ) -> None: super().__init__(provider_config) self.provider_settings = provider_settings @abc.abstractmethod def get_current_key(self) -> str: raise NotImplementedError def get_keys(self) -> list[str]: """获得提供商 Key""" keys = self.provider_config.get("key", [""]) return keys or [""] @abc.abstractmethod def set_key(self, key: str) -> None: raise NotImplementedError @abc.abstractmethod async def get_models(self) -> list[str]: """获得支持的模型列表""" raise NotImplementedError @abc.abstractmethod async def text_chat( self, prompt: str | None = None, session_id: str | None = None, image_urls: list[str] | None = None, func_tool: ToolSet | None = None, contexts: list[Message] | list[dict] | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, extra_user_content_parts: list[ContentPart] | None = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 Args: prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: tool set contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling extra_user_content_parts: 额外的内容块列表,用于在用户消息后添加额外的文本块(如系统提醒、指令等) kwargs: 其他参数 Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ ... async def text_chat_stream( self, prompt: str | None = None, session_id: str | None = None, image_urls: list[str] | None = None, func_tool: ToolSet | None = None, contexts: list[Message] | list[dict] | None = None, system_prompt: str | None = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None, model: str | None = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 Args: prompt: 提示词,和 contexts 二选一使用,如果都指定,则会将 prompt(以及可能的 image_urls) 作为最新的一条记录添加到 contexts 中 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: tool set contexts: 上下文,和 prompt 二选一使用 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ if False: # pragma: no cover - make this an async generator for typing yield None # type: ignore raise NotImplementedError() async def pop_record(self, context: list) -> None: """弹出 context 第一条非系统提示词对话记录""" poped = 0 indexs_to_pop = [] for idx, record in enumerate(context): if record["role"] == "system": continue indexs_to_pop.append(idx) poped += 1 if poped == 2: break for idx in reversed(indexs_to_pop): context.pop(idx) def _ensure_message_to_dicts( self, messages: list[dict] | list[Message] | None, ) -> list[dict]: """Convert a list of Message objects to a list of dictionaries.""" if not messages: return [] dicts: list[dict] = [] for message in messages: if isinstance(message, Message): dicts.append(message.model_dump()) else: dicts.append(message) return dicts async def test(self, timeout: float = 45.0) -> None: await asyncio.wait_for( self.text_chat(prompt="REPLY `PONG` ONLY"), timeout=timeout, ) class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def get_text(self, audio_url: str) -> str: """获取音频的文本""" raise NotImplementedError async def test(self) -> None: sample_audio_path = os.path.join( get_astrbot_path(), "samples", "stt_health_check.wav", ) await self.get_text(sample_audio_path) class TTSProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings def support_stream(self) -> bool: """是否支持流式 TTS Returns: bool: True 表示支持流式处理,False 表示不支持(默认) Notes: 子类可以重写此方法返回 True 来启用流式 TTS 支持 """ return False @abc.abstractmethod async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError async def get_audio_stream( self, text_queue: asyncio.Queue[str | None], audio_queue: "asyncio.Queue[bytes | tuple[str, bytes] | None]", ) -> None: """流式 TTS 处理方法。 从 text_queue 中读取文本片段,将生成的音频数据(WAV 格式的 in-memory bytes)放入 audio_queue。 当 text_queue 收到 None 时,表示文本输入结束,此时应该处理完所有剩余文本并向 audio_queue 发送 None 表示结束。 Args: text_queue: 输入文本队列,None 表示输入结束 audio_queue: 输出音频队列(bytes 或 (text, bytes)),None 表示输出结束 Notes: - 默认实现会将文本累积后一次性调用 get_audio 生成完整音频 - 子类可以重写此方法实现真正的流式 TTS - 音频数据应该是 WAV 格式的 bytes """ accumulated_text = "" while True: text_part = await text_queue.get() if text_part is None: # 输入结束,处理累积的文本 if accumulated_text: try: # 调用原有的 get_audio 方法获取音频文件路径 audio_path = await self.get_audio(accumulated_text) # 读取音频文件内容 with open(audio_path, "rb") as f: audio_data = f.read() await audio_queue.put((accumulated_text, audio_data)) except Exception: # 出错时也要发送 None 结束标记 pass # 发送结束标记 await audio_queue.put(None) break accumulated_text += text_part async def test(self) -> None: audio_path = await self.get_audio("hi") # 检查生成的音频文件是否有效 if not os.path.exists(audio_path): raise Exception("TTS test failed: audio file was not created") file_size = os.path.getsize(audio_path) if file_size == 0: raise Exception( "TTS test failed: generated audio file is empty (0 bytes). " "Please check your TTS provider configuration, especially required parameters like group_id for MiniMax." ) # 清理测试文件 try: os.remove(audio_path) except Exception: pass class EmbeddingProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def get_embedding(self, text: str) -> list[float]: """获取文本的向量""" ... @abc.abstractmethod async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的向量""" ... @abc.abstractmethod def get_dim(self) -> int: """获取向量的维度""" ... async def test(self) -> None: await self.get_embedding("astrbot") async def get_embeddings_batch( self, texts: list[str], batch_size: int = 16, tasks_limit: int = 3, max_retries: int = 3, progress_callback=None, ) -> list[list[float]]: """批量获取文本的向量,分批处理以节省内存 Args: texts: 文本列表 batch_size: 每批处理的文本数量 tasks_limit: 并发任务数量限制 max_retries: 失败时的最大重试次数 progress_callback: 进度回调函数,接收参数 (current, total) Returns: 向量列表 """ semaphore = asyncio.Semaphore(tasks_limit) all_embeddings: list[list[float]] = [] failed_batches: list[tuple[int, list[str]]] = [] completed_count = 0 total_count = len(texts) async def process_batch(batch_idx: int, batch_texts: list[str]) -> None: nonlocal completed_count async with semaphore: for attempt in range(max_retries): try: batch_embeddings = await self.get_embeddings(batch_texts) all_embeddings.extend(batch_embeddings) completed_count += len(batch_texts) if progress_callback: await progress_callback(completed_count, total_count) return except Exception as e: if attempt == max_retries - 1: # 最后一次重试失败,记录失败的批次 failed_batches.append((batch_idx, batch_texts)) raise Exception( f"批次 {batch_idx} 处理失败,已重试 {max_retries} 次: {e!s}", ) # 等待一段时间后重试,使用指数退避 await asyncio.sleep(2**attempt) tasks = [] for i in range(0, len(texts), batch_size): batch_texts = texts[i : i + batch_size] batch_idx = i // batch_size tasks.append(process_batch(batch_idx, batch_texts)) # 收集所有任务的结果,包括失败的任务 results = await asyncio.gather(*tasks, return_exceptions=True) # 检查是否有失败的任务 errors = [r for r in results if isinstance(r, Exception)] if errors: error_msg = ( f"有 {len(errors)} 个批次处理失败: {'; '.join(str(e) for e in errors)}" ) raise Exception(error_msg) return all_embeddings class RerankProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def rerank( self, query: str, documents: list[str], top_n: int | None = None, ) -> list[RerankResult]: """获取查询和文档的重排序分数""" ... async def test(self) -> None: result = await self.rerank("Apple", documents=["apple", "banana"]) if not result: raise Exception("Rerank provider test failed, no results returned")