File size: 16,219 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
"""AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库.

在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话,
在一个会话中可以建立多个对话, 并且支持对话的切换和删除
"""

import json
from collections.abc import Awaitable, Callable

from astrbot.core import sp
from astrbot.core.agent.message import AssistantMessageSegment, UserMessageSegment
from astrbot.core.db import BaseDatabase
from astrbot.core.db.po import Conversation, ConversationV2
from astrbot.core.utils.datetime_utils import to_utc_timestamp


class ConversationManager:
    """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。"""

    def __init__(self, db_helper: BaseDatabase) -> None:
        self.session_conversations: dict[str, str] = {}
        self.db = db_helper
        self.save_interval = 60  # 每 60 秒保存一次

        # 会话删除回调函数列表(用于级联清理,如知识库配置)
        self._on_session_deleted_callbacks: list[Callable[[str], Awaitable[None]]] = []

    def register_on_session_deleted(
        self,
        callback: Callable[[str], Awaitable[None]],
    ) -> None:
        """注册会话删除回调函数.

        其他模块可以注册回调来响应会话删除事件,实现级联清理。
        例如:知识库模块可以注册回调来清理会话的知识库配置。

        Args:
            callback: 回调函数,接收会话ID (unified_msg_origin) 作为参数

        """
        self._on_session_deleted_callbacks.append(callback)

    async def _trigger_session_deleted(self, unified_msg_origin: str) -> None:
        """触发会话删除回调.

        Args:
            unified_msg_origin: 会话ID

        """
        for callback in self._on_session_deleted_callbacks:
            try:
                await callback(unified_msg_origin)
            except Exception as e:
                from astrbot.core import logger

                logger.error(
                    f"会话删除回调执行失败 (session: {unified_msg_origin}): {e}",
                )

    def _convert_conv_from_v2_to_v1(self, conv_v2: ConversationV2) -> Conversation:
        """Build a legacy ``Conversation`` from a ``ConversationV2`` record.

        Args:
            conv_v2: The V2 record to convert.

        Returns:
            A ``Conversation`` whose ``history`` is the JSON-encoded V2
            ``content`` (an empty list when ``content`` is falsy) and whose
            timestamps are integers.

        """

        def _as_int_timestamp(moment) -> int:
            # A missing timestamp (None from to_utc_timestamp) is stored as 0.
            stamp = to_utc_timestamp(moment)
            return 0 if stamp is None else int(stamp)

        created_at = _as_int_timestamp(conv_v2.created_at)
        updated_at = _as_int_timestamp(conv_v2.updated_at)
        serialized_history = json.dumps(conv_v2.content or [])
        return Conversation(
            platform_id=conv_v2.platform_id,
            user_id=conv_v2.user_id,
            cid=conv_v2.conversation_id,
            history=serialized_history,
            title=conv_v2.title,
            persona_id=conv_v2.persona_id,
            created_at=created_at,
            updated_at=updated_at,
            token_usage=conv_v2.token_usage,
        )

    async def new_conversation(
        self,
        unified_msg_origin: str,
        platform_id: str | None = None,
        content: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
    ) -> str:
        """Create a conversation and make it the session's current one.

        Args:
            unified_msg_origin (str): Unified message origin string, formatted
                as platform_name:message_type:session_id. Used as the
                conversation's ``user_id``.
            platform_id (str | None): Platform ID. When falsy it is taken from
                the first ``:``-separated part of ``unified_msg_origin``
                (only if there are at least three parts); otherwise
                ``"unknown"`` is used.
            content (list[dict] | None): Initial content, forwarded to the
                database layer.
            title (str | None): Conversation title, forwarded to the database
                layer.
            persona_id (str | None): Persona ID, forwarded to the database
                layer.

        Returns:
            conversation_id (str): ID of the new conversation (a UUID-format
            string according to the original documentation).

        """
        if not platform_id:
            # Derive the platform from the origin string when possible.
            head, *rest = unified_msg_origin.split(":")
            platform_id = head if len(rest) >= 2 else None
        platform_id = platform_id or "unknown"

        created = await self.db.create_conversation(
            user_id=unified_msg_origin,
            platform_id=platform_id,
            content=content,
            title=title,
            persona_id=persona_id,
        )
        new_cid = created.conversation_id
        # Record the selection both in memory and in shared preferences.
        self.session_conversations[unified_msg_origin] = new_cid
        await sp.session_put(unified_msg_origin, "sel_conv_id", new_cid)
        return new_cid

    async def switch_conversation(
        self, unified_msg_origin: str, conversation_id: str
    ) -> None:
        """Select a different conversation for the session.

        The selection is written to the in-memory cache and to shared
        preferences under the ``sel_conv_id`` key. The conversation ID is not
        checked against the database here.

        Args:
            unified_msg_origin (str): Unified message origin string, formatted
                as platform_name:message_type:session_id.
            conversation_id (str): ID of the conversation to switch to.

        """
        selection_cache = self.session_conversations
        selection_cache[unified_msg_origin] = conversation_id
        await sp.session_put(unified_msg_origin, "sel_conv_id", conversation_id)

    async def delete_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str | None = None,
    ) -> None:
        """删除会话的对话,当 conversation_id 为 None 时删除会话当前的对话

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串

        """
        if not conversation_id:
            conversation_id = self.session_conversations.get(unified_msg_origin)
        if conversation_id:
            await self.db.delete_conversation(cid=conversation_id)
            curr_cid = await self.get_curr_conversation_id(unified_msg_origin)
            if curr_cid == conversation_id:
                self.session_conversations.pop(unified_msg_origin, None)
                await sp.session_remove(unified_msg_origin, "sel_conv_id")

    async def delete_conversations_by_user_id(self, unified_msg_origin: str) -> None:
        """Delete every conversation that belongs to the session.

        After the database rows are removed, the session's current selection
        is cleared from the cache and from shared preferences, and the
        registered session-deleted callbacks are triggered.

        Args:
            unified_msg_origin (str): Unified message origin string, formatted
                as platform_name:message_type:session_id.

        """
        session_id = unified_msg_origin
        await self.db.delete_conversations_by_user_id(user_id=session_id)

        # Forget which conversation was selected for this session.
        self.session_conversations.pop(session_id, None)
        await sp.session_remove(session_id, "sel_conv_id")

        # Notify listeners so they can run their cascade cleanup.
        await self._trigger_session_deleted(session_id)

    async def get_curr_conversation_id(self, unified_msg_origin: str) -> str | None:
        """获取会话当前的对话 ID

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
        Returns:
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串

        """
        ret = self.session_conversations.get(unified_msg_origin, None)
        if not ret:
            ret = await sp.session_get(unified_msg_origin, "sel_conv_id", None)
            if ret:
                self.session_conversations[unified_msg_origin] = ret
        return ret

    async def get_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str,
        create_if_not_exists: bool = False,
    ) -> Conversation | None:
        """获取会话的对话.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
            create_if_not_exists (bool): 如果对话不存在,是否创建一个新的对话
        Returns:
            conversation (Conversation): 对话对象

        """
        conv = await self.db.get_conversation_by_id(cid=conversation_id)
        if not conv and create_if_not_exists:
            # 如果对话不存在且需要创建,则新建一个对话
            conversation_id = await self.new_conversation(unified_msg_origin)
            conv = await self.db.get_conversation_by_id(cid=conversation_id)
        conv_res = None
        if conv:
            conv_res = self._convert_conv_from_v2_to_v1(conv)
        return conv_res

    async def get_conversations(
        self,
        unified_msg_origin: str | None = None,
        platform_id: str | None = None,
    ) -> list[Conversation]:
        """获取对话列表.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id,可选
            platform_id (str): 平台 ID, 可选参数, 用于过滤对话
        Returns:
            conversations (List[Conversation]): 对话对象列表

        """
        convs = await self.db.get_conversations(
            user_id=unified_msg_origin,
            platform_id=platform_id,
        )
        convs_res = []
        for conv in convs:
            conv_res = self._convert_conv_from_v2_to_v1(conv)
            convs_res.append(conv_res)
        return convs_res

    async def get_filtered_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
        platform_ids: list[str] | None = None,
        search_query: str = "",
        **kwargs,
    ) -> tuple[list[Conversation], int]:
        """获取过滤后的对话列表.

        Args:
            page (int): 页码, 默认为 1
            page_size (int): 每页大小, 默认为 20
            platform_ids (list[str]): 平台 ID 列表, 可选
            search_query (str): 搜索查询字符串, 可选
        Returns:
            conversations (list[Conversation]): 对话对象列表

        """
        convs, cnt = await self.db.get_filtered_conversations(
            page=page,
            page_size=page_size,
            platform_ids=platform_ids,
            search_query=search_query,
            **kwargs,
        )
        convs_res = []
        for conv in convs:
            conv_res = self._convert_conv_from_v2_to_v1(conv)
            convs_res.append(conv_res)
        return convs_res, cnt

    async def update_conversation(
        self,
        unified_msg_origin: str,
        conversation_id: str | None = None,
        history: list[dict] | None = None,
        title: str | None = None,
        persona_id: str | None = None,
        token_usage: int | None = None,
    ) -> None:
        """更新会话的对话.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
            history (List[Dict]): 对话历史记录, 是一个字典列表, 每个字典包含 role 和 content 字段
            token_usage (int | None): token 使用量。None 表示不更新

        """
        if not conversation_id:
            # 如果没有提供 conversation_id,则获取当前的
            conversation_id = await self.get_curr_conversation_id(unified_msg_origin)
        if conversation_id:
            await self.db.update_conversation(
                cid=conversation_id,
                title=title,
                persona_id=persona_id,
                content=history,
                token_usage=token_usage,
            )

    async def update_conversation_title(
        self,
        unified_msg_origin: str,
        title: str,
        conversation_id: str | None = None,
    ) -> None:
        """更新会话的对话标题.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            title (str): 对话标题
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
        Deprecated:
            Use `update_conversation` with `title` parameter instead.

        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,
            conversation_id=conversation_id,
            title=title,
        )

    async def update_conversation_persona_id(
        self,
        unified_msg_origin: str,
        persona_id: str,
        conversation_id: str | None = None,
    ) -> None:
        """更新会话的对话 Persona ID.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            persona_id (str): 对话 Persona ID
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
        Deprecated:
            Use `update_conversation` with `persona_id` parameter instead.

        """
        await self.update_conversation(
            unified_msg_origin=unified_msg_origin,
            conversation_id=conversation_id,
            persona_id=persona_id,
        )

    async def add_message_pair(
        self,
        cid: str,
        user_message: UserMessageSegment | dict,
        assistant_message: AssistantMessageSegment | dict,
    ) -> None:
        """Add a user-assistant message pair to the conversation history.

        Args:
            cid (str): Conversation ID
            user_message (UserMessageSegment | dict): OpenAI-format user message object or dict
            assistant_message (AssistantMessageSegment | dict): OpenAI-format assistant message object or dict

        Raises:
            Exception: If the conversation with the given ID is not found
        """
        conv = await self.db.get_conversation_by_id(cid=cid)
        if not conv:
            raise Exception(f"Conversation with id {cid} not found")
        history = conv.content or []
        if isinstance(user_message, UserMessageSegment):
            user_msg_dict = user_message.model_dump()
        else:
            user_msg_dict = user_message
        if isinstance(assistant_message, AssistantMessageSegment):
            assistant_msg_dict = assistant_message.model_dump()
        else:
            assistant_msg_dict = assistant_message
        history.append(user_msg_dict)
        history.append(assistant_msg_dict)
        await self.db.update_conversation(
            cid=cid,
            content=history,
        )

    async def get_human_readable_context(
        self,
        unified_msg_origin: str,
        conversation_id: str,
        page: int = 1,
        page_size: int = 10,
    ) -> tuple[list[str], int]:
        """获取人类可读的上下文.

        Args:
            unified_msg_origin (str): 统一的消息来源字符串。格式为 platform_name:message_type:session_id
            conversation_id (str): 对话 ID, 是 uuid 格式的字符串
            page (int): 页码
            page_size (int): 每页大小

        """
        conversation = await self.get_conversation(unified_msg_origin, conversation_id)
        if not conversation:
            return [], 0
        history = json.loads(conversation.history)

        # contexts_groups 存放按顺序的段落(每个段落是一个 str 列表),
        # 之后会被展平成一个扁平的 str 列表返回。
        contexts_groups: list[list[str]] = []
        temp_contexts: list[str] = []
        for record in history:
            if record["role"] == "user":
                temp_contexts.append(f"User: {record['content']}")
            elif record["role"] == "assistant":
                if record.get("content"):
                    temp_contexts.append(f"Assistant: {record['content']}")
                elif "tool_calls" in record:
                    tool_calls_str = json.dumps(
                        record["tool_calls"],
                        ensure_ascii=False,
                    )
                    temp_contexts.append(f"Assistant: [函数调用] {tool_calls_str}")
                else:
                    temp_contexts.append("Assistant: [未知的内容]")
                contexts_groups.insert(0, temp_contexts)
                temp_contexts = []

        # 展平分组后的 contexts 列表为单层字符串列表
        contexts = [item for sublist in contexts_groups for item in sublist]

        # 计算分页
        paged_contexts = contexts[(page - 1) * page_size : page * page_size]
        total_pages = len(contexts) // page_size
        if len(contexts) % page_size != 0:
            total_pages += 1

        return paged_contexts, total_pages