File size: 13,762 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
import asyncio
import traceback
from asyncio import Queue
from dataclasses import dataclass

from astrbot.core import logger
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map
from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config

from .platform import Platform, PlatformStatus
from .register import platform_cls_map
from .sources.webchat.webchat_adapter import WebChatAdapter


@dataclass
class PlatformTasks:
    """Pair of asyncio tasks that the manager tracks for one platform adapter."""

    # Task that executes the adapter's ``run()`` coroutine.
    run: asyncio.Task
    # Task that executes ``PlatformManager._task_wrapper`` around ``run``.
    wrapper: asyncio.Task


class PlatformManager:
    def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None:
        self.platform_insts: list[Platform] = []
        """加载的 Platform 的实例"""

        self._inst_map: dict[str, dict] = {}
        self._platform_tasks: dict[str, PlatformTasks] = {}

        self.astrbot_config = config
        self.platforms_config = config["platform"]
        self.settings = config["platform_settings"]
        """NOTE: 这里是 default 的配置文件,以保证最大的兼容性;
        这个配置中的 unique_session 需要特殊处理,
        约定整个项目中对 unique_session 的引用都从 default 的配置中获取"""
        self.event_queue = event_queue

    def _is_valid_platform_id(self, platform_id: str | None) -> bool:
        if not platform_id:
            return False
        return ":" not in platform_id and "!" not in platform_id

    def _sanitize_platform_id(self, platform_id: str | None) -> tuple[str | None, bool]:
        if not platform_id:
            return platform_id, False
        sanitized = platform_id.replace(":", "_").replace("!", "_")
        return sanitized, sanitized != platform_id

    def _start_platform_task(self, task_name: str, inst: Platform) -> None:
        """Schedule an adapter's ``run()`` coroutine together with its wrapper.

        Both tasks are stored in ``self._platform_tasks`` under the adapter's
        ``client_self_id`` so they can be cancelled later.

        Args:
            task_name: Name given to the run task; the wrapper task is named
                ``f"{task_name}_wrapper"``.
            inst: Adapter instance whose ``run()`` coroutine is scheduled.
        """
        runner = asyncio.create_task(inst.run(), name=task_name)
        supervisor = asyncio.create_task(
            self._task_wrapper(runner, platform=inst),
            name=f"{task_name}_wrapper",
        )
        tracked = PlatformTasks(run=runner, wrapper=supervisor)
        self._platform_tasks[inst.client_self_id] = tracked

    async def _stop_platform_task(self, client_id: str) -> None:
        tasks = self._platform_tasks.pop(client_id, None)
        if not tasks:
            return
        for task in (tasks.run, tasks.wrapper):
            if not task.done():
                task.cancel()
        await asyncio.gather(tasks.run, tasks.wrapper, return_exceptions=True)

    async def _terminate_inst_and_tasks(self, inst: Platform) -> None:
        client_id = inst.client_self_id
        try:
            if getattr(inst, "terminate", None):
                try:
                    await inst.terminate()
                except asyncio.CancelledError:
                    raise
                except Exception as e:
                    logger.error(
                        "终止平台适配器失败: client_id=%s, error=%s",
                        client_id,
                        e,
                    )
                    logger.error(traceback.format_exc())
        finally:
            await self._stop_platform_task(client_id)

    async def initialize(self) -> None:
        """Load every configured platform adapter, then start the WebChat adapter.

        For each entry of ``self.platforms_config`` the webhook configuration
        is checked first (the global config is saved when
        ``ensure_platform_webhook_config`` reports a truthy result) and the
        adapter is loaded. A failure on one entry is logged and does not stop
        the remaining entries from loading.

        The WebChat adapter is always created with an empty config, appended
        to ``self.platform_insts`` and started; it is not added to
        ``self._inst_map``.
        """
        for platform in self.platforms_config:
            try:
                if ensure_platform_webhook_config(platform):
                    self.astrbot_config.save_config()
                await self.load_platform(platform)
            except Exception as e:
                # Identify the adapter by type/id only. Interpolating the whole
                # config entry would dump every adapter setting (presumably
                # including credentials) into the log.
                if isinstance(platform, dict):
                    label = f"{platform.get('type')}({platform.get('id')})"
                else:
                    label = type(platform).__name__
                logger.error(f"初始化 {label} 平台适配器失败: {e}")
                # Keep the stack trace so the failure can be diagnosed.
                logger.error(traceback.format_exc())

        # Web chat
        webchat_inst = WebChatAdapter({}, self.settings, self.event_queue)
        self.platform_insts.append(webchat_inst)
        self._start_platform_task("webchat", webchat_inst)

    async def load_platform(self, platform_config: dict) -> None:
        """Instantiate and start one platform adapter from its config entry.

        Steps, in order:
          1. Return immediately when ``platform_config["enable"]`` is falsy.
          2. Validate the ``id``; an ID containing ``:`` or ``!`` is rewritten
             in place (and the global config saved), an empty ID aborts.
          3. Import the built-in adapter module matching ``type``.
          4. Look the adapter class up in ``platform_cls_map``, instantiate it,
             record it in ``self._inst_map`` / ``self.platform_insts`` and
             start its tasks.
          5. Await every handler registered for ``OnPlatformLoadedEvent``;
             exceptions raised by a handler are logged and ignored.

        Args:
            platform_config: Config entry of the platform; the keys
                ``enable``, ``id`` and ``type`` are read here. The dict may be
                mutated when the ID is sanitized.
        """
        # Dynamic import
        try:
            if not platform_config["enable"]:
                return
            platform_id = platform_config.get("id")
            if not self._is_valid_platform_id(platform_id):
                sanitized_id, changed = self._sanitize_platform_id(platform_id)
                if sanitized_id and changed:
                    logger.warning(
                        "平台 ID %r 包含非法字符 ':' 或 '!',已替换为 %r。",
                        platform_id,
                        sanitized_id,
                    )
                    # Persist the sanitized ID back into the configuration.
                    platform_config["id"] = sanitized_id
                    self.astrbot_config.save_config()
                else:
                    # Reached only for an empty/None ID: nothing to sanitize.
                    logger.error(
                        f"平台 ID {platform_id!r} 不能为空,跳过加载该平台适配器。",
                    )
                    return

            logger.info(
                f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...",
            )
            # The imported names are unused here (hence the noqa markers);
            # presumably importing the module registers the adapter class in
            # platform_cls_map -- TODO confirm against .register.
            # Types without a matching case skip the import and rely on the
            # platform_cls_map lookup below.
            match platform_config["type"]:
                case "aiocqhttp":
                    from .sources.aiocqhttp.aiocqhttp_platform_adapter import (
                        AiocqhttpAdapter,  # noqa: F401
                    )
                case "qq_official":
                    from .sources.qqofficial.qqofficial_platform_adapter import (
                        QQOfficialPlatformAdapter,  # noqa: F401
                    )
                case "qq_official_webhook":
                    from .sources.qqofficial_webhook.qo_webhook_adapter import (
                        QQOfficialWebhookPlatformAdapter,  # noqa: F401
                    )
                case "lark":
                    from .sources.lark.lark_adapter import (
                        LarkPlatformAdapter,  # noqa: F401
                    )
                case "dingtalk":
                    from .sources.dingtalk.dingtalk_adapter import (
                        DingtalkPlatformAdapter,  # noqa: F401
                    )
                case "telegram":
                    from .sources.telegram.tg_adapter import (
                        TelegramPlatformAdapter,  # noqa: F401
                    )
                case "wecom":
                    from .sources.wecom.wecom_adapter import (
                        WecomPlatformAdapter,  # noqa: F401
                    )
                case "wecom_ai_bot":
                    from .sources.wecom_ai_bot.wecomai_adapter import (
                        WecomAIBotAdapter,  # noqa: F401
                    )
                case "weixin_official_account":
                    from .sources.weixin_official_account.weixin_offacc_adapter import (
                        WeixinOfficialAccountPlatformAdapter,  # noqa: F401
                    )
                case "discord":
                    from .sources.discord.discord_platform_adapter import (
                        DiscordPlatformAdapter,  # noqa: F401
                    )
                case "misskey":
                    from .sources.misskey.misskey_adapter import (
                        MisskeyPlatformAdapter,  # noqa: F401
                    )
                case "slack":
                    from .sources.slack.slack_adapter import SlackAdapter  # noqa: F401
                case "satori":
                    from .sources.satori.satori_adapter import (
                        SatoriPlatformAdapter,  # noqa: F401
                    )
                case "line":
                    from .sources.line.line_adapter import (
                        LinePlatformAdapter,  # noqa: F401
                    )
                case "kook":
                    from .sources.kook.kook_adapter import (
                        KookPlatformAdapter,  # noqa: F401
                    )
        except (ImportError, ModuleNotFoundError) as e:
            logger.error(
                f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。",
            )
        except Exception as e:
            logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。")

        # NOTE: neither except branch above returns, so after a logged failure
        # execution still continues with the class lookup below.
        if platform_config["type"] not in platform_cls_map:
            logger.error(
                f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误",
            )
            return
        cls_type = platform_cls_map[platform_config["type"]]
        inst: Platform = cls_type(platform_config, self.settings, self.event_queue)
        # Remember the client id as it is at load time, alongside the instance.
        self._inst_map[platform_config["id"]] = {
            "inst": inst,
            "client_id": inst.client_self_id,
        }
        self.platform_insts.append(inst)
        self._start_platform_task(
            f"platform_{platform_config['type']}_{platform_config['id']}",
            inst,
        )
        # Notify handlers registered for the "platform loaded" event.
        handlers = star_handlers_registry.get_handlers_by_event_type(
            EventType.OnPlatformLoadedEvent,
        )
        for handler in handlers:
            try:
                logger.info(
                    f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}",
                )
                await handler.handler()
            except Exception:
                # A failing hook must not prevent the remaining hooks from running.
                logger.error(traceback.format_exc())

    async def _task_wrapper(
        self, task: asyncio.Task, platform: Platform | None = None
    ) -> None:
        # 设置平台状态为运行中
        if platform:
            platform.status = PlatformStatus.RUNNING

        try:
            await task
        except asyncio.CancelledError:
            if platform:
                platform.status = PlatformStatus.STOPPED
        except Exception as e:
            error_msg = str(e)
            tb_str = traceback.format_exc()
            logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
            for line in tb_str.split("\n"):
                logger.error(f"|    {line}")
            logger.error("-------")

            # 记录错误到平台实例
            if platform:
                platform.record_error(error_msg, tb_str)

    async def reload(self, platform_config: dict) -> None:
        await self.terminate_platform(platform_config["id"])
        if platform_config["enable"]:
            await self.load_platform(platform_config)

        # 和配置文件保持同步
        config_ids = [provider["id"] for provider in self.platforms_config]
        for key in list(self._inst_map.keys()):
            if key not in config_ids:
                await self.terminate_platform(key)

    async def terminate_platform(self, platform_id: str) -> None:
        if platform_id in self._inst_map:
            logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...")

            # client_id = self._inst_map.pop(platform_id, None)
            info = self._inst_map.pop(platform_id)
            client_id = info["client_id"]
            inst: Platform = info["inst"]
            try:
                self.platform_insts.remove(
                    next(
                        inst
                        for inst in self.platform_insts
                        if inst.client_self_id == client_id
                    ),
                )
            except Exception:
                logger.warning(f"可能未完全移除 {platform_id} 平台适配器")

            await self._terminate_inst_and_tasks(inst)

    async def terminate(self) -> None:
        terminated_client_ids: set[str] = set()
        for platform_id in list(self._inst_map.keys()):
            info = self._inst_map.get(platform_id)
            if info:
                terminated_client_ids.add(info["client_id"])
            await self.terminate_platform(platform_id)

        for inst in list(self.platform_insts):
            client_id = inst.client_self_id
            if client_id in terminated_client_ids:
                continue
            await self._terminate_inst_and_tasks(inst)

        self.platform_insts.clear()
        self._inst_map.clear()
        self._platform_tasks.clear()

    def get_insts(self) -> list[Platform]:
        """Return the manager's own list of loaded adapter instances (not a copy)."""
        return self.platform_insts

    def get_all_stats(self) -> dict:
        """获取所有平台的统计信息

        Returns:
            包含所有平台统计信息的字典
        """
        stats_list = []
        total_errors = 0
        running_count = 0
        error_count = 0

        for inst in self.platform_insts:
            try:
                stat = inst.get_stats()
                stats_list.append(stat)
                total_errors += stat.get("error_count", 0)
                if stat.get("status") == PlatformStatus.RUNNING.value:
                    running_count += 1
                elif stat.get("status") == PlatformStatus.ERROR.value:
                    error_count += 1
            except Exception as e:
                # 如果获取统计信息失败,记录基本信息
                logger.warning(f"获取平台统计信息失败: {e}")
                stats_list.append(
                    {
                        "id": getattr(inst, "config", {}).get("id", "unknown"),
                        "type": "unknown",
                        "status": "unknown",
                        "error_count": 0,
                        "last_error": None,
                    }
                )

        return {
            "platforms": stats_list,
            "summary": {
                "total": len(stats_list),
                "running": running_count,
                "error": error_count,
                "total_errors": total_errors,
            },
        }