File size: 1,454 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
"""Pipeline bootstrap utilities."""

from importlib import import_module

from .stage import registered_stages

# Built-in stages as (module path, stage class name) pairs. Each module
# appears in the same position as the class name listed beside it; presumably
# that module defines that class (confirm when adding or renaming a stage).
# Keeping both values in one row means the two derived collections below
# cannot fall out of sync.
_BUILTIN_STAGES = (
    ("astrbot.core.pipeline.waking_check.stage", "WakingCheckStage"),
    ("astrbot.core.pipeline.whitelist_check.stage", "WhitelistCheckStage"),
    ("astrbot.core.pipeline.session_status_check.stage", "SessionStatusCheckStage"),
    ("astrbot.core.pipeline.rate_limit_check.stage", "RateLimitStage"),
    ("astrbot.core.pipeline.content_safety_check.stage", "ContentSafetyCheckStage"),
    ("astrbot.core.pipeline.preprocess_stage.stage", "PreProcessStage"),
    ("astrbot.core.pipeline.process_stage.stage", "ProcessStage"),
    ("astrbot.core.pipeline.result_decorate.stage", "ResultDecorateStage"),
    ("astrbot.core.pipeline.respond.stage", "RespondStage"),
)

# Dotted module paths, in the order they are imported at bootstrap.
_BUILTIN_STAGE_MODULES = tuple(path for path, _ in _BUILTIN_STAGES)

# Class names that must all be present in the stage registry before the
# imports can be skipped.
_EXPECTED_STAGE_NAMES = {cls_name for _, cls_name in _BUILTIN_STAGES}

# Flipped to True by ensure_builtin_stages_registered() once its check (and,
# if needed, the imports) has completed without raising.
_builtin_stages_registered = False


def ensure_builtin_stages_registered() -> None:
    """Make sure the built-in pipeline stages have been imported and registered.

    While the module-level ``_builtin_stages_registered`` flag is unset, each
    call collects the ``__name__`` of every entry in ``registered_stages``.
    If any name from ``_EXPECTED_STAGE_NAMES`` is absent, every module in
    ``_BUILTIN_STAGE_MODULES`` is imported, in order. Importing is presumably
    what registers each stage; that mechanism lives in ``.stage`` and is not
    visible here. The flag is then set, so later calls return immediately.

    If an import raises, the exception propagates and the flag stays unset,
    so the next call tries again.

    NOTE(review): the registry is not re-checked after the imports, so a
    stage class that fails to register (e.g. after a rename) goes unreported
    here.
    """
    global _builtin_stages_registered

    if _builtin_stages_registered:
        return

    present_names = {cls.__name__ for cls in registered_stages}
    missing_names = _EXPECTED_STAGE_NAMES - present_names

    # Import only when at least one expected stage is absent. import_module
    # returns the cached module for anything already loaded.
    if missing_names:
        for dotted_path in _BUILTIN_STAGE_MODULES:
            import_module(dotted_path)

    _builtin_stages_registered = True


# Public API of this module: only the bootstrap entry point is exported.
__all__ = [
    "ensure_builtin_stages_registered",
]