# Source: astrbbbb / astrbot/core/provider/register.py
# Uploaded by qa1145 ("Upload 1245 files"), commit 8ede856 (verified).
from astrbot.core import logger
from .entities import ProviderMetaData, ProviderType
from .func_tool_manager import FuncCall
# Every ProviderMetaData created through the register_provider_adapter
# decorator, in the order the decorated classes were registered.
provider_registry: list[ProviderMetaData] = list()
"""维护了通过装饰器注册的 Provider"""
# Lookup table: provider type name -> its ProviderMetaData record.
provider_cls_map: dict[str, ProviderMetaData] = dict()
"""维护了 Provider 类型名称和 ProviderMetadata 的映射"""
# Module-wide FuncCall instance (LLM function-tool manager).
llm_tools = FuncCall()
def register_provider_adapter(
    provider_type_name: str,
    desc: str,
    provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
    default_config_tmpl: dict | None = None,
    provider_display_name: str | None = None,
):
    """Return a class decorator that registers a provider adapter.

    The decorated class is wrapped in a ``ProviderMetaData`` record. The record
    is appended to ``provider_registry`` and indexed by ``provider_type_name``
    in ``provider_cls_map``. The class itself is returned unchanged.

    Args:
        provider_type_name: Type name of the adapter. It is the key in
            ``provider_cls_map`` and the fallback ``type``/``id`` of the
            default config template. It must not already be registered.
        desc: Description stored on the metadata record.
        provider_type: Provider category stored on the metadata record
            (defaults to ``ProviderType.CHAT_COMPLETION``).
        default_config_tmpl: Optional default configuration template. If it is
            non-empty, a shallow copy is stored with the required keys
            ``type`` (-> ``provider_type_name``), ``enable`` (-> ``False``) and
            ``id`` (-> ``provider_type_name``) added when missing. The caller's
            dict is never modified. ``None`` or an empty dict is stored as-is.
        provider_display_name: Optional display name stored on the metadata.

    Raises:
        ValueError: Raised by the returned decorator when
            ``provider_type_name`` is already present in ``provider_cls_map``.
    """

    def decorator(cls):
        if provider_type_name in provider_cls_map:
            raise ValueError(
                f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。",
            )

        # Add the required options to a *copy* of the template. Mutating the
        # caller's dict would leak these keys back to it. It would also make a
        # template shared by several adapters keep the first adapter's
        # ``type``/``id``. A new local name is used because rebinding the
        # enclosing parameter here would make it local to ``decorator``.
        config_tmpl = default_config_tmpl
        if config_tmpl:
            config_tmpl = dict(config_tmpl)
            config_tmpl.setdefault("type", provider_type_name)
            config_tmpl.setdefault("enable", False)
            config_tmpl.setdefault("id", provider_type_name)

        pm = ProviderMetaData(
            id="default",  # will be replaced when instantiated
            model=None,
            type=provider_type_name,
            desc=desc,
            provider_type=provider_type,
            cls_type=cls,
            default_config_tmpl=config_tmpl,
            provider_display_name=provider_display_name,
        )
        provider_registry.append(pm)
        provider_cls_map[provider_type_name] = pm
        logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
        return cls

    return decorator