File size: 1,904 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from astrbot.core import logger

from .entities import ProviderMetaData, ProviderType
from .func_tool_manager import FuncCall

#: Metadata of every provider adapter registered through
#: ``register_provider_adapter``, kept in registration order.
provider_registry: list[ProviderMetaData] = []
#: Maps a provider adapter type name to its ``ProviderMetaData``.
provider_cls_map: dict[str, ProviderMetaData] = {}

#: Module-level ``FuncCall`` instance (from ``.func_tool_manager``).
llm_tools: FuncCall = FuncCall()


def register_provider_adapter(
    provider_type_name: str,
    desc: str,
    provider_type: ProviderType = ProviderType.CHAT_COMPLETION,
    default_config_tmpl: dict | None = None,
    provider_display_name: str | None = None,
):
    """Parameterized class decorator that registers a provider adapter.

    The decorated class is wrapped in a ``ProviderMetaData`` record, appended
    to ``provider_registry`` and stored in ``provider_cls_map`` under
    ``provider_type_name``. The class itself is returned unchanged.

    Args:
        provider_type_name: Adapter type name; must be unique, it is the key
            in ``provider_cls_map``.
        desc: Description stored in the metadata.
        provider_type: Kind of provider; defaults to
            ``ProviderType.CHAT_COMPLETION``.
        default_config_tmpl: Optional default config template. A shallow copy
            is stored, with ``"type"`` and ``"id"`` defaulted to
            ``provider_type_name`` and ``"enable"`` defaulted to ``False``
            when those keys are absent. The caller's dict is never modified.
        provider_display_name: Optional display name stored in the metadata.

    Returns:
        The decorator to apply to the adapter class.

    Raises:
        ValueError: When the decorator is applied, if ``provider_type_name``
            is already registered.
    """

    def decorator(cls):
        if provider_type_name in provider_cls_map:
            raise ValueError(
                f"检测到大模型提供商适配器 {provider_type_name} 已经注册,可能发生了大模型提供商适配器类型命名冲突。",
            )

        # Fill in the required options on a shallow copy. Mutating the caller's
        # dict in place would leak this adapter's "type"/"id" into any other
        # adapter registered with the same template object, because the
        # "key already present" check would then skip them.
        # `is not None` (rather than truthiness) so an explicit empty template
        # still receives the required keys.
        config_tmpl = None
        if default_config_tmpl is not None:
            config_tmpl = dict(default_config_tmpl)
            config_tmpl.setdefault("type", provider_type_name)
            config_tmpl.setdefault("enable", False)
            config_tmpl.setdefault("id", provider_type_name)

        pm = ProviderMetaData(
            id="default",  # will be replaced when instantiated
            model=None,
            type=provider_type_name,
            desc=desc,
            provider_type=provider_type,
            cls_type=cls,
            default_config_tmpl=config_tmpl,
            provider_display_name=provider_display_name,
        )
        provider_registry.append(pm)
        provider_cls_map[provider_type_name] = pm
        logger.debug(f"服务提供商 Provider {provider_type_name} 已注册")
        return cls

    return decorator