"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库. 在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话, 在一个会话中可以建立多个对话, 并且支持对话的切换和删除 """ import json from collections.abc import Awaitable, Callable from astrbot.core import sp from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment from astrbot.core.db import BaseDatabase from astrbot.core.db.po import Conversation, ConversationV2 from astrbot.core.utils.datetime_utils import to_utc_timestamp class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase) -> None: self.session_conversations: dict[str, str] = {} self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 # 会话删除回调函数列表(用于级联清理,如知识库配置) self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = [] def register_on_session_deleted( self, callback: Callable[[str], Awaitable[None]], ) -> None: """注册会话删除回调函数. 其他模块可以注册回调来响应会话删除事件,实现级联清理。 例如:知识库模块可以注册回调来清理会话的知识库配置。 Args: callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数 """ self._on_session_deleted_callbacks.append(callback) async def _trigger_session_deleted(self, unified_msg_origin: str) -> None: """触发会话删除回调. Args: unified_msg_origin: 会话ID """ for callback in self._on_session_deleted_callbacks: try: await callback(unified_msg_origin) except Exception as e: from astrbot.core import logger logger.error( f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}", ) def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation: """将 ConversationV2 对象转换为 Conversation 对象""" created_ts = to_utc_timestamp(conv_v2.created_at) updated_ts = to_utc_timestamp(conv_v2.updated_at) created_at = int(created_ts) if created_ts is not None else 0 updated_at = int(updated_ts) if updated_ts is not None else 0 return Conversation( platform_id=conv_v2.platform_id, user_id=conv_v2.user_id, cid=conv_v2.conversation_id, history=json.dumps(conv_v2.content or []), title=conv_v2.title, persona_id=conv_v2.persona_id, created_at=created_at, updated_at=updated_at, token_usage=conv_v2.token_usage, ) async def new_conversation( self, unified_msg_origin: str, platform_id: str | None = None, content: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, ) -> str: """新建对话,并将当前会话的对话转移到新对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ if not platform_id: # 如果没有提供 platform_id,则从 unified_msg_origin 中解析 parts = unified_msg_origin.split(":") if len(parts) >= 3: platform_id = parts[0] if not platform_id: platform_id = "unknown" conv = await self.db.create_conversation( user_id=unified_msg_origin, platform_id=platform_id, content=content, title=title, persona_id=persona_id, ) self.session_conversations[unified_msg_origin] = conv.conversation_id await sp.session_put(unified_msg_origin, "sel_conv_id", conv.conversation_id) return conv.conversation_id async def switch_conversation( self, unified_msg_origin: str, conversation_id: str ) -> None: """切换会话的对话 Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ self.session_conversations[unified_msg_origin] = conversation_id await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id) async def delete_conversation( self, unified_msg_origin: str, conversation_id: str | None = None, ) -> None: """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话 Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ if not conversation_id: conversation_id = self.session_conversations.get(unified_msg_origin) if conversation_id: await self.db.delete_conversation(cid=conversation_id) curr_cid = await self.get_curr_conversation_id(unified_msg_origin) if curr_cid == conversation_id: self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None: """删除会话的所有对话 Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id """ await self.db.delete_conversations_by_user_id(user_id=unified_msg_origin) self.session_conversations.pop(unified_msg_origin, None) await sp.session_remove(unified_msg_origin, "sel_conv_id") # 触发会话删除回调(级联清理) await self._trigger_session_deleted(unified_msg_origin) async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None: """获取会话当前的对话 ID Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id Returns: conversation_id (str): 对话 ID, 是 uuid 格式的字符串 """ ret = self.session_conversations.get(unified_msg_origin, None) if not ret: ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None) if ret: self.session_conversations[unified_msg_origin] = ret return ret async def get_conversation( self, unified_msg_origin: str, conversation_id: str, create_if_not_exists: bool = False, ) -> Conversation | None: """获取会话的对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话 Returns: conversation (Conversation): 对话对象 """ conv = await self.db.get_conversation_by_id(cid=conversation_id) if not conv and create_if_not_exists: # 如果对话不存在且需要创建,则新建一个对话 conversation_id = await self.new_conversation(unified_msg_origin) conv = await self.db.get_conversation_by_id(cid=conversation_id) conv_res = None if conv: conv_res = self._convert_conv_from_v2_to_v1(conv) return conv_res async def get_conversations( self, unified_msg_origin: str | None = None, platform_id: str | None = None, ) -> list[Conversation]: """获取对话列表. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选 platform_id (str): 平台 ID, 可选参数, 用于过滤对话 Returns: conversations (List[Conversation]): 对话对象列表 """ convs = await self.db.get_conversations( user_id=unified_msg_origin, platform_id=platform_id, ) convs_res = [] for conv in convs: conv_res = self._convert_conv_from_v2_to_v1(conv) convs_res.append(conv_res) return convs_res async def get_filtered_conversations( self, page: int = 1, page_size: int = 20, platform_ids: list[str] | None = None, search_query: str = "", **kwargs, ) -> tuple[list[Conversation], int]: """获取过滤后的对话列表. Args: page (int): 页码, 默认为 1 page_size (int): 每页大小, 默认为 20 platform_ids (list[str]): 平台 ID 列表, 可选 search_query (str): 搜索查询字符串, 可选 Returns: conversations (list[Conversation]): 对话对象列表 """ convs, cnt = await self.db.get_filtered_conversations( page=page, page_size=page_size, platform_ids=platform_ids, search_query=search_query, **kwargs, ) convs_res = [] for conv in convs: conv_res = self._convert_conv_from_v2_to_v1(conv) convs_res.append(conv_res) return convs_res, cnt async def update_conversation( self, unified_msg_origin: str, conversation_id: str | None = None, history: list[dict] | None = None, title: str | None = None, persona_id: str | None = None, token_usage: int | None = None, ) -> None: """更新会话的对话. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段 token_usage (int | None): token 使用量。None 表示不更新 """ if not conversation_id: # 如果没有提供 conversation_id,则获取当前的 conversation_id = await self.get_curr_conversation_id(unified_msg_origin) if conversation_id: await self.db.update_conversation( cid=conversation_id, title=title, persona_id=persona_id, content=history, token_usage=token_usage, ) async def update_conversation_title( self, unified_msg_origin: str, title: str, conversation_id: str | None = None, ) -> None: """更新会话的对话标题. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id title (str): 对话标题 conversation_id (str): 对话 ID, 是 uuid 格式的字符串 Deprecated: Use `update_conversation` with `title` parameter instead. """ await self.update_conversation( unified_msg_origin=unified_msg_origin, conversation_id=conversation_id, title=title, ) async def update_conversation_persona_id( self, unified_msg_origin: str, persona_id: str, conversation_id: str | None = None, ) -> None: """更新会话的对话 Persona ID. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id persona_id (str): 对话 Persona ID conversation_id (str): 对话 ID, 是 uuid 格式的字符串 Deprecated: Use `update_conversation` with `persona_id` parameter instead. """ await self.update_conversation( unified_msg_origin=unified_msg_origin, conversation_id=conversation_id, persona_id=persona_id, ) async def add_message_pair( self, cid: str, user_message: UserMessageSegment | dict, assistant_message: AssistantMessageSegment | dict, ) -> None: """Add a user-assistant message pair to the conversation history. Args: cid (str): Conversation ID user_message (UserMessageSegment | dict): OpenAI-format user message object or dict assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict Raises: Exception: If the conversation with the given ID is not found """ conv = await self.db.get_conversation_by_id(cid=cid) if not conv: raise Exception(f"Conversation with id {cid} not found") history = conv.content or [] if isinstance(user_message, UserMessageSegment): user_msg_dict = user_message.model_dump() else: user_msg_dict = user_message if isinstance(assistant_message, AssistantMessageSegment): assistant_msg_dict = assistant_message.model_dump() else: assistant_msg_dict = assistant_message history.append(user_msg_dict) history.append(assistant_msg_dict) await self.db.update_conversation( cid=cid, content=history, ) async def get_human_readable_context( self, unified_msg_origin: str, conversation_id: str, page: int = 1, page_size: int = 10, ) -> tuple[list[str], int]: """获取人类可读的上下文. Args: unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id conversation_id (str): 对话 ID, 是 uuid 格式的字符串 page (int): 页码 page_size (int): 每页大小 """ conversation = await self.get_conversation(unified_msg_origin, conversation_id) if not conversation: return [], 0 history = json.loads(conversation.history) # contexts_groups 存放按顺序的段落(每个段落是一个 str 列表), # 之后会被展平成一个扁平的 str 列表返回。 contexts_groups: list[list[str]] = [] temp_contexts: list[str] = [] for record in history: if record["role"] == "user": temp_contexts.append(f"User: {record['content']}") elif record["role"] == "assistant": if record.get("content"): temp_contexts.append(f"Assistant: {record['content']}") elif "tool_calls" in record: tool_calls_str = json.dumps( record["tool_calls"], ensure_ascii=False, ) temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}") else: temp_contexts.append("Assistant: [未知的内容]") contexts_groups.insert(0, temp_contexts) temp_contexts = [] # 展平分组后的 contexts 列表为单层字符串列表 contexts = [item for sublist in contexts_groups for item in sublist] # 计算分页 paged_contexts = contexts[(page - 1) * page_size : page * page_size] total_pages = len(contexts) // page_size if len(contexts) % page_size != 0: total_pages += 1 return paged_contexts, total_pages