import asyncio import copy import os import traceback from collections.abc import Callable from typing import Protocol, runtime_checkable from astrbot.core import astrbot_config, logger, sp from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.db import BaseDatabase from astrbot.core.utils.error_redaction import safe_error from ..persona_mgr import PersonaManager from .entities import ProviderType from .provider import ( EmbeddingProvider, Provider, Providers, RerankProvider, STTProvider, TTSProvider, ) from .register import llm_tools, provider_cls_map @runtime_checkable class HasInitialize(Protocol): async def initialize(self) -> None: ... class ProviderManager: def __init__( self, acm: AstrBotConfigManager, db_helper: BaseDatabase, persona_mgr: PersonaManager, ) -> None: self.reload_lock = asyncio.Lock() self.resource_lock = asyncio.Lock() self.persona_mgr = persona_mgr self.acm = acm config = acm.confs["default"] self.providers_config: list = config["provider"] self.provider_sources_config: list = config.get("provider_sources", []) self.provider_settings: dict = config["provider_settings"] self.provider_stt_settings: dict = config.get("provider_stt_settings", {}) self.provider_tts_settings: dict = config.get("provider_tts_settings", {}) # 人格相关属性,v4.0.0 版本后被废弃,推荐使用 PersonaManager self.default_persona_name = persona_mgr.default_persona self.provider_insts: list[Provider] = [] """加载的 Provider 的实例""" self.stt_provider_insts: list[STTProvider] = [] """加载的 Speech To Text Provider 的实例""" self.tts_provider_insts: list[TTSProvider] = [] """加载的 Text To Speech Provider 的实例""" self.embedding_provider_insts: list[EmbeddingProvider] = [] """加载的 Embedding Provider 的实例""" self.rerank_provider_insts: list[RerankProvider] = [] """加载的 Rerank Provider 的实例""" self.inst_map: dict[ str, Providers, ] = {} """Provider 实例映射. key: provider_id, value: Provider 实例""" self.llm_tools = llm_tools self.curr_provider_inst: Provider | None = None """默认的 Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.curr_stt_provider_inst: STTProvider | None = None """默认的 Speech To Text Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.curr_tts_provider_inst: TTSProvider | None = None """默认的 Text To Speech Provider 实例。已弃用,请使用 get_using_provider() 方法获取当前使用的 Provider 实例。""" self.db_helper = db_helper self._provider_change_callback: ( Callable[[str, ProviderType, str | None], None] | None ) = None self._provider_change_hooks: list[ Callable[[str, ProviderType, str | None], None] ] = [] self._mcp_init_task: asyncio.Task | None = None def set_provider_change_callback( self, cb: Callable[[str, ProviderType, str | None], None] | None, ) -> None: # Backward-compatible single-callback setter. # This callback coexists with register_provider_change_hook subscriptions. self._provider_change_callback = cb def register_provider_change_hook( self, hook: Callable[[str, ProviderType, str | None], None], ) -> None: if hook not in self._provider_change_hooks: self._provider_change_hooks.append(hook) def _notify_provider_changed( self, provider_id: str, provider_type: ProviderType, umo: str | None, ) -> None: if self._provider_change_callback is not None: try: self._provider_change_callback(provider_id, provider_type, umo) except Exception as e: logger.warning( "调用 provider 变更回调失败: provider_id=%s, type=%s, err=%s", provider_id, provider_type, safe_error("", e), ) for hook in list(self._provider_change_hooks): if hook is self._provider_change_callback: continue try: hook(provider_id, provider_type, umo) except Exception as e: logger.warning( "调用 provider 变更钩子失败: provider_id=%s, type=%s, err=%s", provider_id, provider_type, safe_error("", e), ) @property def persona_configs(self) -> list: """动态获取最新的 persona 配置""" return self.persona_mgr.persona_v3_config @property def personas(self) -> list: """动态获取最新的 personas 列表""" return self.persona_mgr.personas_v3 @property def selected_default_persona(self): """动态获取最新的默认选中 persona。已弃用,请使用 context.persona_mgr.get_default_persona_v3()""" return self.persona_mgr.selected_default_persona_v3 async def set_provider( self, provider_id: str, provider_type: ProviderType, umo: str | None = None, ) -> None: """设置提供商。 Args: provider_id (str): 提供商 ID。 provider_type (ProviderType): 提供商类型。 umo (str, optional): 用户会话 ID,用于提供商会话隔离。 Version 4.0.0: 这个版本下已经默认隔离提供商 """ if provider_id not in self.inst_map: raise ValueError(f"提供商 {provider_id} 不存在,无法设置。") if umo: await sp.session_put( umo, f"provider_perf_{provider_type.value}", provider_id, ) self._notify_provider_changed(provider_id, provider_type, umo) return # 不启用提供商会话隔离模式的情况 prov = self.inst_map[provider_id] if provider_type == ProviderType.TEXT_TO_SPEECH and isinstance( prov, TTSProvider, ): self.curr_tts_provider_inst = prov await sp.put_async( key="curr_provider_tts", value=provider_id, scope="global", scope_id="global", ) self._notify_provider_changed(provider_id, provider_type, umo) elif provider_type == ProviderType.SPEECH_TO_TEXT and isinstance( prov, STTProvider, ): self.curr_stt_provider_inst = prov await sp.put_async( key="curr_provider_stt", value=provider_id, scope="global", scope_id="global", ) self._notify_provider_changed(provider_id, provider_type, umo) elif provider_type == ProviderType.CHAT_COMPLETION and isinstance( prov, Provider, ): self.curr_provider_inst = prov await sp.put_async( key="curr_provider", value=provider_id, scope="global", scope_id="global", ) self._notify_provider_changed(provider_id, provider_type, umo) async def get_provider_by_id(self, provider_id: str) -> Providers | None: """根据提供商 ID 获取提供商实例""" return self.inst_map.get(provider_id) def get_using_provider( self, provider_type: ProviderType, umo=None ) -> Providers | None: """获取正在使用的提供商实例。 Args: provider_type (ProviderType): 提供商类型。 umo (str, optional): 用户会话 ID,用于提供商会话隔离。 Returns: Provider: 正在使用的提供商实例。 """ provider = None provider_id = None if umo: provider_id = sp.get( f"provider_perf_{provider_type.value}", None, scope="umo", scope_id=umo, ) if provider_id: provider = self.inst_map.get(provider_id) if not provider: # default setting config = self.acm.get_conf(umo) if provider_type == ProviderType.CHAT_COMPLETION: provider_id = config["provider_settings"].get("default_provider_id") provider = self.inst_map.get(provider_id) if not provider: provider = self.provider_insts[0] if self.provider_insts else None elif provider_type == ProviderType.SPEECH_TO_TEXT: provider_id = config["provider_stt_settings"].get("provider_id") if not provider_id: return None provider = self.inst_map.get(provider_id) if not provider: provider = ( self.stt_provider_insts[0] if self.stt_provider_insts else None ) elif provider_type == ProviderType.TEXT_TO_SPEECH: provider_id = config["provider_tts_settings"].get("provider_id") if not provider_id: return None provider = self.inst_map.get(provider_id) if not provider: provider = ( self.tts_provider_insts[0] if self.tts_provider_insts else None ) else: raise ValueError(f"Unknown provider type: {provider_type}") if not provider and provider_id: logger.warning( f"没有找到 ID 为 {provider_id} 的提供商,这可能是由于您修改了提供商(模型)ID 导致的。" ) return provider async def initialize(self) -> None: # 逐个初始化提供商 for provider_config in self.providers_config: try: await self.load_provider(provider_config) except Exception as e: logger.error(traceback.format_exc()) logger.error(e) selected_provider_id = await sp.get_async( key="curr_provider", default=self.provider_settings.get("default_provider_id"), scope="global", scope_id="global", ) selected_stt_provider_id = await sp.get_async( key="curr_provider_stt", default=self.provider_stt_settings.get("provider_id"), scope="global", scope_id="global", ) selected_tts_provider_id = await sp.get_async( key="curr_provider_tts", default=self.provider_tts_settings.get("provider_id"), scope="global", scope_id="global", ) temp_provider = ( self.inst_map.get(selected_provider_id) if isinstance(selected_provider_id, str) else None ) self.curr_provider_inst = ( temp_provider if isinstance(temp_provider, Provider) else None ) if not self.curr_provider_inst and self.provider_insts: self.curr_provider_inst = self.provider_insts[0] temp_stt = ( self.inst_map.get(selected_stt_provider_id) if isinstance(selected_stt_provider_id, str) else None ) self.curr_stt_provider_inst = ( temp_stt if isinstance(temp_stt, STTProvider) else None ) if not self.curr_stt_provider_inst and self.stt_provider_insts: self.curr_stt_provider_inst = self.stt_provider_insts[0] temp_tts = ( self.inst_map.get(selected_tts_provider_id) if isinstance(selected_tts_provider_id, str) else None ) self.curr_tts_provider_inst = ( temp_tts if isinstance(temp_tts, TTSProvider) else None ) if not self.curr_tts_provider_inst and self.tts_provider_insts: self.curr_tts_provider_inst = self.tts_provider_insts[0] async def _init_mcp_clients_bg() -> None: try: await self.llm_tools.init_mcp_clients() except Exception: logger.error("MCP init background task failed", exc_info=True) if self._mcp_init_task is None or self._mcp_init_task.done(): self._mcp_init_task = asyncio.create_task( _init_mcp_clients_bg(), name="provider-manager:mcp-init", ) def dynamic_import_provider(self, type: str) -> None: """动态导入提供商适配器模块 Args: type (str): 提供商请求类型。 Raises: ImportError: 如果提供商类型未知或无法导入对应模块,则抛出异常。 """ match type: case "openai_chat_completion": from .sources.openai_source import ( ProviderOpenAIOfficial as ProviderOpenAIOfficial, ) case "zhipu_chat_completion": from .sources.zhipu_source import ProviderZhipu as ProviderZhipu case "groq_chat_completion": from .sources.groq_source import ProviderGroq as ProviderGroq case "xai_chat_completion": from .sources.xai_source import ProviderXAI as ProviderXAI case "aihubmix_chat_completion": from .sources.oai_aihubmix_source import ( ProviderAIHubMix as ProviderAIHubMix, ) case "openrouter_chat_completion": from .sources.openrouter_source import ( ProviderOpenRouter as ProviderOpenRouter, ) case "anthropic_chat_completion": from .sources.anthropic_source import ( ProviderAnthropic as ProviderAnthropic, ) case "googlegenai_chat_completion": from .sources.gemini_source import ( ProviderGoogleGenAI as ProviderGoogleGenAI, ) case "sensevoice_stt_selfhost": from .sources.sensevoice_selfhosted_source import ( ProviderSenseVoiceSTTSelfHost as ProviderSenseVoiceSTTSelfHost, ) case "openai_whisper_api": from .sources.whisper_api_source import ( ProviderOpenAIWhisperAPI as ProviderOpenAIWhisperAPI, ) case "openai_whisper_selfhost": from .sources.whisper_selfhosted_source import ( ProviderOpenAIWhisperSelfHost as ProviderOpenAIWhisperSelfHost, ) case "xinference_stt": from .sources.xinference_stt_provider import ( ProviderXinferenceSTT as ProviderXinferenceSTT, ) case "openai_tts_api": from .sources.openai_tts_api_source import ( ProviderOpenAITTSAPI as ProviderOpenAITTSAPI, ) case "genie_tts": from .sources.genie_tts import ( GenieTTSProvider as GenieTTSProvider, ) case "edge_tts": from .sources.edge_tts_source import ( ProviderEdgeTTS as ProviderEdgeTTS, ) case "gsv_tts_selfhost": from .sources.gsv_selfhosted_source import ( ProviderGSVTTS as ProviderGSVTTS, ) case "gsvi_tts_api": from .sources.gsvi_tts_source import ( ProviderGSVITTS as ProviderGSVITTS, ) case "fishaudio_tts_api": from .sources.fishaudio_tts_api_source import ( ProviderFishAudioTTSAPI as ProviderFishAudioTTSAPI, ) case "dashscope_tts": from .sources.dashscope_tts import ( ProviderDashscopeTTSAPI as ProviderDashscopeTTSAPI, ) case "azure_tts": from .sources.azure_tts_source import ( AzureTTSProvider as AzureTTSProvider, ) case "minimax_tts_api": from .sources.minimax_tts_api_source import ( ProviderMiniMaxTTSAPI as ProviderMiniMaxTTSAPI, ) case "volcengine_tts": from .sources.volcengine_tts import ( ProviderVolcengineTTS as ProviderVolcengineTTS, ) case "gemini_tts": from .sources.gemini_tts_source import ( ProviderGeminiTTSAPI as ProviderGeminiTTSAPI, ) case "openai_embedding": from .sources.openai_embedding_source import ( OpenAIEmbeddingProvider as OpenAIEmbeddingProvider, ) case "gemini_embedding": from .sources.gemini_embedding_source import ( GeminiEmbeddingProvider as GeminiEmbeddingProvider, ) case "vllm_rerank": from .sources.vllm_rerank_source import ( VLLMRerankProvider as VLLMRerankProvider, ) case "xinference_rerank": from .sources.xinference_rerank_source import ( XinferenceRerankProvider as XinferenceRerankProvider, ) case "bailian_rerank": from .sources.bailian_rerank_source import ( BailianRerankProvider as BailianRerankProvider, ) def get_merged_provider_config(self, provider_config: dict) -> dict: """获取 provider 配置和 provider_source 配置合并后的结果 Returns: dict: 合并后的 provider 配置,key 为 provider id,value 为合并后的配置字典 """ pc = copy.deepcopy(provider_config) provider_source_id = pc.get("provider_source_id", "") if provider_source_id: provider_source = None for ps in self.provider_sources_config: if ps.get("id") == provider_source_id: provider_source = ps break if provider_source: # 合并配置,provider 的配置优先级更高 merged_config = {**provider_source, **pc} # 保持 id 为 provider 的 id,而不是 source 的 id merged_config["id"] = pc["id"] pc = merged_config return pc def _resolve_env_key_list(self, provider_config: dict) -> dict: keys = provider_config.get("key", []) if not isinstance(keys, list): return provider_config resolved_keys = [] for idx, key in enumerate(keys): if isinstance(key, str) and key.startswith("$"): env_key = key[1:] if env_key.startswith("{") and env_key.endswith("}"): env_key = env_key[1:-1] if env_key: env_val = os.getenv(env_key) if env_val is None: provider_id = provider_config.get("id") logger.warning( f"Provider {provider_id} 配置项 key[{idx}] 使用环境变量 {env_key} 但未设置。", ) resolved_keys.append("") else: resolved_keys.append(env_val) else: resolved_keys.append(key) else: resolved_keys.append(key) provider_config["key"] = resolved_keys return provider_config async def load_provider(self, provider_config: dict) -> None: # 如果 provider_source_id 存在且不为空,则从 provider_sources 中找到对应的配置并合并 provider_config = self.get_merged_provider_config(provider_config) if provider_config.get("provider_type", "") == "chat_completion": provider_config = self._resolve_env_key_list(provider_config) if not provider_config["enable"]: logger.info(f"Provider {provider_config['id']} is disabled, skipping") return if provider_config.get("provider_type", "") == "agent_runner": return logger.info( f"载入 {provider_config['type']}({provider_config['id']}) 服务提供商 ...", ) # 动态导入 try: self.dynamic_import_provider(provider_config["type"]) except (ImportError, ModuleNotFoundError) as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。", exc_info=True, ) return except Exception as e: logger.critical( f"加载 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}。未知原因", exc_info=True, ) return if provider_config["type"] not in provider_cls_map: logger.error( f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。", exc_info=True, ) return provider_metadata = provider_cls_map[provider_config["type"]] try: # 按任务实例化提供商 cls_type = provider_metadata.cls_type if not cls_type: logger.error(f"无法找到 {provider_metadata.type} 的类") return provider_metadata.id = provider_config["id"] match provider_metadata.provider_type: case ProviderType.SPEECH_TO_TEXT: # STT 任务 if not issubclass(cls_type, STTProvider): raise TypeError( f"Provider class {cls_type} is not a subclass of STTProvider" ) inst = cls_type(provider_config, self.provider_settings) if isinstance(inst, HasInitialize): await inst.initialize() self.stt_provider_insts.append(inst) if ( self.provider_stt_settings.get("provider_id") == provider_config["id"] ): self.curr_stt_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。", ) if not self.curr_stt_provider_inst: self.curr_stt_provider_inst = inst case ProviderType.TEXT_TO_SPEECH: # TTS 任务 if not issubclass(cls_type, TTSProvider): raise TypeError( f"Provider class {cls_type} is not a subclass of TTSProvider" ) inst = cls_type(provider_config, self.provider_settings) if isinstance(inst, HasInitialize): await inst.initialize() self.tts_provider_insts.append(inst) if ( self.provider_settings.get("provider_id") == provider_config["id"] ): self.curr_tts_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。", ) if not self.curr_tts_provider_inst: self.curr_tts_provider_inst = inst case ProviderType.CHAT_COMPLETION: # 文本生成任务 if not issubclass(cls_type, Provider): raise TypeError( f"Provider class {cls_type} is not a subclass of Provider" ) inst = cls_type( provider_config, self.provider_settings, ) if isinstance(inst, HasInitialize): await inst.initialize() self.provider_insts.append(inst) if ( self.provider_settings.get("default_provider_id") == provider_config["id"] ): self.curr_provider_inst = inst logger.info( f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前提供商适配器。", ) if not self.curr_provider_inst: self.curr_provider_inst = inst case ProviderType.EMBEDDING: if not issubclass(cls_type, EmbeddingProvider): raise TypeError( f"Provider class {cls_type} is not a subclass of EmbeddingProvider" ) inst = cls_type(provider_config, self.provider_settings) if isinstance(inst, HasInitialize): await inst.initialize() self.embedding_provider_insts.append(inst) case ProviderType.RERANK: if not issubclass(cls_type, RerankProvider): raise TypeError( f"Provider class {cls_type} is not a subclass of RerankProvider" ) inst = cls_type(provider_config, self.provider_settings) if isinstance(inst, HasInitialize): await inst.initialize() self.rerank_provider_insts.append(inst) case _: # 未知供应商抛出异常,确保inst初始化 # Should be unreachable raise Exception( f"未知的提供商类型:{provider_metadata.provider_type}" ) self.inst_map[provider_config["id"]] = inst except Exception as e: logger.error( f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) raise Exception( f"实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器失败:{e}", ) async def reload(self, provider_config: dict) -> None: async with self.reload_lock: await self.terminate_provider(provider_config["id"]) if provider_config["enable"]: await self.load_provider(provider_config) # 和配置文件保持同步 self.providers_config = astrbot_config["provider"] self.provider_sources_config = astrbot_config.get("provider_sources", []) config_ids = [provider["id"] for provider in self.providers_config] logger.info(f"providers in user's config: {config_ids}") for key in list(self.inst_map.keys()): if key not in config_ids: await self.terminate_provider(key) if len(self.provider_insts) == 0: self.curr_provider_inst = None elif self.curr_provider_inst is None and len(self.provider_insts) > 0: self.curr_provider_inst = self.provider_insts[0] logger.info( f"自动选择 {self.curr_provider_inst.meta().id} 作为当前提供商适配器。", ) if len(self.stt_provider_insts) == 0: self.curr_stt_provider_inst = None elif ( self.curr_stt_provider_inst is None and len(self.stt_provider_insts) > 0 ): self.curr_stt_provider_inst = self.stt_provider_insts[0] logger.info( f"自动选择 {self.curr_stt_provider_inst.meta().id} 作为当前语音转文本提供商适配器。", ) if len(self.tts_provider_insts) == 0: self.curr_tts_provider_inst = None elif ( self.curr_tts_provider_inst is None and len(self.tts_provider_insts) > 0 ): self.curr_tts_provider_inst = self.tts_provider_insts[0] logger.info( f"自动选择 {self.curr_tts_provider_inst.meta().id} 作为当前文本转语音提供商适配器。", ) def get_insts(self): return self.provider_insts async def terminate_provider(self, provider_id: str) -> None: if provider_id in self.inst_map: logger.info( f"终止 {provider_id} 提供商适配器({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)}) ...", ) if self.inst_map[provider_id] in self.provider_insts: prov_inst = self.inst_map[provider_id] if isinstance(prov_inst, Provider): self.provider_insts.remove(prov_inst) if self.inst_map[provider_id] in self.stt_provider_insts: prov_inst = self.inst_map[provider_id] if isinstance(prov_inst, STTProvider): self.stt_provider_insts.remove(prov_inst) if self.inst_map[provider_id] in self.tts_provider_insts: prov_inst = self.inst_map[provider_id] if isinstance(prov_inst, TTSProvider): self.tts_provider_insts.remove(prov_inst) if self.inst_map[provider_id] == self.curr_provider_inst: self.curr_provider_inst = None if self.inst_map[provider_id] == self.curr_stt_provider_inst: self.curr_stt_provider_inst = None if self.inst_map[provider_id] == self.curr_tts_provider_inst: self.curr_tts_provider_inst = None if getattr(self.inst_map[provider_id], "terminate", None): await self.inst_map[provider_id].terminate() # type: ignore logger.info( f"{provider_id} 提供商适配器已终止({len(self.provider_insts)}, {len(self.stt_provider_insts)}, {len(self.tts_provider_insts)})", ) del self.inst_map[provider_id] async def delete_provider( self, provider_id: str | None = None, provider_source_id: str | None = None ) -> None: """Delete provider and/or provider source from config and terminate the instances. Config will be saved after deletion.""" async with self.resource_lock: # delete from config target_prov_ids = [] if provider_id: target_prov_ids.append(provider_id) else: for prov in self.providers_config: if prov.get("provider_source_id") == provider_source_id: target_prov_ids.append(prov.get("id")) config = self.acm.default_conf for tpid in target_prov_ids: await self.terminate_provider(tpid) config["provider"] = [ prov for prov in config["provider"] if prov.get("id") != tpid ] config.save_config() logger.info(f"Provider {target_prov_ids} 已从配置中删除。") async def update_provider(self, origin_provider_id: str, new_config: dict) -> None: """Update provider config and reload the instance. Config will be saved after update.""" async with self.resource_lock: npid = new_config.get("id", None) if not npid: raise ValueError("New provider config must have an 'id' field") config = self.acm.default_conf for provider in config["provider"]: if ( provider.get("id", None) == npid and provider.get("id", None) != origin_provider_id ): raise ValueError(f"Provider ID {npid} already exists") # update config for idx, provider in enumerate(config["provider"]): if provider.get("id", None) == origin_provider_id: config["provider"][idx] = new_config break else: raise ValueError(f"Provider ID {origin_provider_id} not found") config.save_config() # reload instance await self.reload(new_config) async def create_provider(self, new_config: dict) -> None: """Add new provider config and load the instance. Config will be saved after addition.""" async with self.resource_lock: npid = new_config.get("id", None) if not npid: raise ValueError("New provider config must have an 'id' field") config = self.acm.default_conf for provider in config["provider"]: if provider.get("id", None) == npid: raise ValueError(f"Provider ID {npid} already exists") # add to config config["provider"].append(new_config) config.save_config() # load instance await self.load_provider(new_config) async def terminate(self) -> None: if self._mcp_init_task and not self._mcp_init_task.done(): self._mcp_init_task.cancel() try: await self._mcp_init_task except asyncio.CancelledError: pass for provider_inst in self.provider_insts: if hasattr(provider_inst, "terminate"): await provider_inst.terminate() # type: ignore try: await self.llm_tools.disable_mcp_server() except Exception: logger.error("Error while disabling MCP servers", exc_info=True)