| import asyncio |
| import traceback |
| from asyncio import Queue |
| from dataclasses import dataclass |
|
|
| from astrbot.core import logger |
| from astrbot.core.config.astrbot_config import AstrBotConfig |
| from astrbot.core.star.star_handler import EventType, star_handlers_registry, star_map |
| from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config |
|
|
| from .platform import Platform, PlatformStatus |
| from .register import platform_cls_map |
| from .sources.webchat.webchat_adapter import WebChatAdapter |
|
|
|
|
@dataclass
class PlatformTasks:
    """The pair of asyncio tasks that back one running platform adapter."""

    # Task executing the adapter's own ``run()`` coroutine.
    run: asyncio.Task
    # Task that awaits ``run`` and handles its cancellation / failure.
    wrapper: asyncio.Task
|
|
|
|
| class PlatformManager: |
| def __init__(self, config: AstrBotConfig, event_queue: Queue) -> None: |
| self.platform_insts: list[Platform] = [] |
| """加载的 Platform 的实例""" |
|
|
| self._inst_map: dict[str, dict] = {} |
| self._platform_tasks: dict[str, PlatformTasks] = {} |
|
|
| self.astrbot_config = config |
| self.platforms_config = config["platform"] |
| self.settings = config["platform_settings"] |
| """NOTE: 这里是 default 的配置文件,以保证最大的兼容性; |
| 这个配置中的 unique_session 需要特殊处理, |
| 约定整个项目中对 unique_session 的引用都从 default 的配置中获取""" |
| self.event_queue = event_queue |
|
|
| def _is_valid_platform_id(self, platform_id: str | None) -> bool: |
| if not platform_id: |
| return False |
| return ":" not in platform_id and "!" not in platform_id |
|
|
| def _sanitize_platform_id(self, platform_id: str | None) -> tuple[str | None, bool]: |
| if not platform_id: |
| return platform_id, False |
| sanitized = platform_id.replace(":", "_").replace("!", "_") |
| return sanitized, sanitized != platform_id |
|
|
| def _start_platform_task(self, task_name: str, inst: Platform) -> None: |
| run_task = asyncio.create_task(inst.run(), name=task_name) |
| wrapper_task = asyncio.create_task( |
| self._task_wrapper(run_task, platform=inst), |
| name=f"{task_name}_wrapper", |
| ) |
| self._platform_tasks[inst.client_self_id] = PlatformTasks( |
| run=run_task, |
| wrapper=wrapper_task, |
| ) |
|
|
| async def _stop_platform_task(self, client_id: str) -> None: |
| tasks = self._platform_tasks.pop(client_id, None) |
| if not tasks: |
| return |
| for task in (tasks.run, tasks.wrapper): |
| if not task.done(): |
| task.cancel() |
| await asyncio.gather(tasks.run, tasks.wrapper, return_exceptions=True) |
|
|
| async def _terminate_inst_and_tasks(self, inst: Platform) -> None: |
| client_id = inst.client_self_id |
| try: |
| if getattr(inst, "terminate", None): |
| try: |
| await inst.terminate() |
| except asyncio.CancelledError: |
| raise |
| except Exception as e: |
| logger.error( |
| "终止平台适配器失败: client_id=%s, error=%s", |
| client_id, |
| e, |
| ) |
| logger.error(traceback.format_exc()) |
| finally: |
| await self._stop_platform_task(client_id) |
|
|
| async def initialize(self) -> None: |
| """初始化所有平台适配器""" |
| for platform in self.platforms_config: |
| try: |
| if ensure_platform_webhook_config(platform): |
| self.astrbot_config.save_config() |
| await self.load_platform(platform) |
| except Exception as e: |
| logger.error(f"初始化 {platform} 平台适配器失败: {e}") |
|
|
| |
| webchat_inst = WebChatAdapter({}, self.settings, self.event_queue) |
| self.platform_insts.append(webchat_inst) |
| self._start_platform_task("webchat", webchat_inst) |
|
|
| async def load_platform(self, platform_config: dict) -> None: |
| """实例化一个平台""" |
| |
| try: |
| if not platform_config["enable"]: |
| return |
| platform_id = platform_config.get("id") |
| if not self._is_valid_platform_id(platform_id): |
| sanitized_id, changed = self._sanitize_platform_id(platform_id) |
| if sanitized_id and changed: |
| logger.warning( |
| "平台 ID %r 包含非法字符 ':' 或 '!',已替换为 %r。", |
| platform_id, |
| sanitized_id, |
| ) |
| platform_config["id"] = sanitized_id |
| self.astrbot_config.save_config() |
| else: |
| logger.error( |
| f"平台 ID {platform_id!r} 不能为空,跳过加载该平台适配器。", |
| ) |
| return |
|
|
| logger.info( |
| f"载入 {platform_config['type']}({platform_config['id']}) 平台适配器 ...", |
| ) |
| match platform_config["type"]: |
| case "aiocqhttp": |
| from .sources.aiocqhttp.aiocqhttp_platform_adapter import ( |
| AiocqhttpAdapter, |
| ) |
| case "qq_official": |
| from .sources.qqofficial.qqofficial_platform_adapter import ( |
| QQOfficialPlatformAdapter, |
| ) |
| case "qq_official_webhook": |
| from .sources.qqofficial_webhook.qo_webhook_adapter import ( |
| QQOfficialWebhookPlatformAdapter, |
| ) |
| case "lark": |
| from .sources.lark.lark_adapter import ( |
| LarkPlatformAdapter, |
| ) |
| case "dingtalk": |
| from .sources.dingtalk.dingtalk_adapter import ( |
| DingtalkPlatformAdapter, |
| ) |
| case "telegram": |
| from .sources.telegram.tg_adapter import ( |
| TelegramPlatformAdapter, |
| ) |
| case "wecom": |
| from .sources.wecom.wecom_adapter import ( |
| WecomPlatformAdapter, |
| ) |
| case "wecom_ai_bot": |
| from .sources.wecom_ai_bot.wecomai_adapter import ( |
| WecomAIBotAdapter, |
| ) |
| case "weixin_official_account": |
| from .sources.weixin_official_account.weixin_offacc_adapter import ( |
| WeixinOfficialAccountPlatformAdapter, |
| ) |
| case "discord": |
| from .sources.discord.discord_platform_adapter import ( |
| DiscordPlatformAdapter, |
| ) |
| case "misskey": |
| from .sources.misskey.misskey_adapter import ( |
| MisskeyPlatformAdapter, |
| ) |
| case "slack": |
| from .sources.slack.slack_adapter import SlackAdapter |
| case "satori": |
| from .sources.satori.satori_adapter import ( |
| SatoriPlatformAdapter, |
| ) |
| case "line": |
| from .sources.line.line_adapter import ( |
| LinePlatformAdapter, |
| ) |
| case "kook": |
| from .sources.kook.kook_adapter import ( |
| KookPlatformAdapter, |
| ) |
| except (ImportError, ModuleNotFoundError) as e: |
| logger.error( |
| f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。请检查依赖库是否安装。提示:可以在 管理面板->平台日志->安装Pip库 中安装依赖库。", |
| ) |
| except Exception as e: |
| logger.error(f"加载平台适配器 {platform_config['type']} 失败,原因:{e}。") |
|
|
| if platform_config["type"] not in platform_cls_map: |
| logger.error( |
| f"未找到适用于 {platform_config['type']}({platform_config['id']}) 平台适配器,请检查是否已经安装或者名称填写错误", |
| ) |
| return |
| cls_type = platform_cls_map[platform_config["type"]] |
| inst: Platform = cls_type(platform_config, self.settings, self.event_queue) |
| self._inst_map[platform_config["id"]] = { |
| "inst": inst, |
| "client_id": inst.client_self_id, |
| } |
| self.platform_insts.append(inst) |
| self._start_platform_task( |
| f"platform_{platform_config['type']}_{platform_config['id']}", |
| inst, |
| ) |
| handlers = star_handlers_registry.get_handlers_by_event_type( |
| EventType.OnPlatformLoadedEvent, |
| ) |
| for handler in handlers: |
| try: |
| logger.info( |
| f"hook(on_platform_loaded) -> {star_map[handler.handler_module_path].name} - {handler.handler_name}", |
| ) |
| await handler.handler() |
| except Exception: |
| logger.error(traceback.format_exc()) |
|
|
| async def _task_wrapper( |
| self, task: asyncio.Task, platform: Platform | None = None |
| ) -> None: |
| |
| if platform: |
| platform.status = PlatformStatus.RUNNING |
|
|
| try: |
| await task |
| except asyncio.CancelledError: |
| if platform: |
| platform.status = PlatformStatus.STOPPED |
| except Exception as e: |
| error_msg = str(e) |
| tb_str = traceback.format_exc() |
| logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}") |
| for line in tb_str.split("\n"): |
| logger.error(f"| {line}") |
| logger.error("-------") |
|
|
| |
| if platform: |
| platform.record_error(error_msg, tb_str) |
|
|
| async def reload(self, platform_config: dict) -> None: |
| await self.terminate_platform(platform_config["id"]) |
| if platform_config["enable"]: |
| await self.load_platform(platform_config) |
|
|
| |
| config_ids = [provider["id"] for provider in self.platforms_config] |
| for key in list(self._inst_map.keys()): |
| if key not in config_ids: |
| await self.terminate_platform(key) |
|
|
| async def terminate_platform(self, platform_id: str) -> None: |
| if platform_id in self._inst_map: |
| logger.info(f"正在尝试终止 {platform_id} 平台适配器 ...") |
|
|
| |
| info = self._inst_map.pop(platform_id) |
| client_id = info["client_id"] |
| inst: Platform = info["inst"] |
| try: |
| self.platform_insts.remove( |
| next( |
| inst |
| for inst in self.platform_insts |
| if inst.client_self_id == client_id |
| ), |
| ) |
| except Exception: |
| logger.warning(f"可能未完全移除 {platform_id} 平台适配器") |
|
|
| await self._terminate_inst_and_tasks(inst) |
|
|
| async def terminate(self) -> None: |
| terminated_client_ids: set[str] = set() |
| for platform_id in list(self._inst_map.keys()): |
| info = self._inst_map.get(platform_id) |
| if info: |
| terminated_client_ids.add(info["client_id"]) |
| await self.terminate_platform(platform_id) |
|
|
| for inst in list(self.platform_insts): |
| client_id = inst.client_self_id |
| if client_id in terminated_client_ids: |
| continue |
| await self._terminate_inst_and_tasks(inst) |
|
|
| self.platform_insts.clear() |
| self._inst_map.clear() |
| self._platform_tasks.clear() |
|
|
| def get_insts(self): |
| return self.platform_insts |
|
|
| def get_all_stats(self) -> dict: |
| """获取所有平台的统计信息 |
| |
| Returns: |
| 包含所有平台统计信息的字典 |
| """ |
| stats_list = [] |
| total_errors = 0 |
| running_count = 0 |
| error_count = 0 |
|
|
| for inst in self.platform_insts: |
| try: |
| stat = inst.get_stats() |
| stats_list.append(stat) |
| total_errors += stat.get("error_count", 0) |
| if stat.get("status") == PlatformStatus.RUNNING.value: |
| running_count += 1 |
| elif stat.get("status") == PlatformStatus.ERROR.value: |
| error_count += 1 |
| except Exception as e: |
| |
| logger.warning(f"获取平台统计信息失败: {e}") |
| stats_list.append( |
| { |
| "id": getattr(inst, "config", {}).get("id", "unknown"), |
| "type": "unknown", |
| "status": "unknown", |
| "error_count": 0, |
| "last_error": None, |
| } |
| ) |
|
|
| return { |
| "platforms": stats_list, |
| "summary": { |
| "total": len(stats_list), |
| "running": running_count, |
| "error": error_count, |
| "total_errors": total_errors, |
| }, |
| } |
|
|