File size: 7,129 Bytes
8ede856
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""会话控制"""

import abc
import asyncio
import copy
import functools
import time
from collections.abc import Awaitable, Callable
from typing import Any

import astrbot.core.message.components as Comp
from astrbot.core.platform import AstrMessageEvent

# Registered SessionWaiter instances, keyed by session id.
USER_SESSIONS: dict[str, "SessionWaiter"] = dict()
# SessionFilter instances added by session_waiter() and removed on cleanup.
FILTERS: list["SessionFilter"] = list()


class SessionController:
    """Controls whether a session has ended.

    The controller owns a future that completes when the session finishes
    (normally, with an error, or on timeout) and a keep-alive timer that
    ``keep()`` can reset or extend.
    """

    def __init__(self) -> None:
        # Completed with ``None`` on a normal stop, or with an exception on
        # error / timeout.
        self.future: asyncio.Future[Any] = asyncio.Future()
        # Event watched by the currently active keep() timer; setting it ends
        # that timer without ending the session.
        self.current_event: asyncio.Event | None = None
        # time.time() at which the most recent keep() started.
        self.ts: float | None = None
        # Timeout (seconds) the most recent keep() started with.
        self.timeout: float | int | None = None

        self.history_chains: list[list[Comp.BaseMessageComponent]] = []

        # Strong references to running timer tasks: the event loop only keeps
        # weak references to tasks, so an unreferenced task can be garbage
        # collected before it finishes.
        self._holding_tasks: set[asyncio.Task[None]] = set()

    def stop(self, error: Exception | None = None) -> None:
        """End the session immediately.

        Args:
            error: If given, the session future is failed with this exception;
                otherwise it is resolved with ``None``. Has no effect on the
                future if the session has already ended.
        """
        if not self.future.done():
            if error is not None:
                self.future.set_exception(error)
            else:
                self.future.set_result(None)
        # Release the active keep() timer so its task exits now instead of
        # lingering until its timeout elapses (it only acts if the future is
        # still pending, so releasing it early is harmless).
        if self.current_event is not None and not self.current_event.is_set():
            self.current_event.set()

    def keep(self, timeout: float = 0, reset_timeout: bool = False) -> None:
        """Keep the session alive.

        Args:
            timeout: Session timeout in seconds.
                When ``reset_timeout`` is True, the timer is reset to
                ``timeout``. It must be > 0; otherwise the session ends
                immediately.
                When ``reset_timeout`` is False, the current timeout is
                extended. The new timeout is the remaining time of the
                current one plus ``timeout``, which may be negative. If the
                result is <= 0 the session ends immediately.
            reset_timeout: Selects between the two modes described above.

        Raises:
            RuntimeError: ``reset_timeout`` is False but no previous keep()
                has established a timeout to extend.
        """
        new_ts = time.time()

        if not reset_timeout:
            if self.timeout is None or self.ts is None:
                raise RuntimeError(
                    "keep(reset_timeout=False) requires a previous keep() call"
                )
            remaining = self.timeout - (new_ts - self.ts)
            timeout = remaining + timeout

        if timeout <= 0:
            self.stop()
            return

        if self.current_event is not None and not self.current_event.is_set():
            self.current_event.set()  # end the previous keep() timer

        new_event = asyncio.Event()
        self.ts = new_ts
        self.current_event = new_event
        self.timeout = timeout

        # Start the new keep() timer and hold a reference until it finishes.
        task = asyncio.create_task(self._holding(new_event, timeout))
        self._holding_tasks.add(task)
        task.add_done_callback(self._holding_tasks.discard)

    async def _holding(self, event: asyncio.Event, timeout: float) -> None:
        """Wait until *event* is set or *timeout* seconds elapse.

        On timeout, fails the session future with ``TimeoutError`` unless the
        future has already completed.
        """
        try:
            await asyncio.wait_for(event.wait(), timeout)
        except asyncio.TimeoutError:
            if not self.future.done():
                self.future.set_exception(TimeoutError("等待超时"))
        except asyncio.CancelledError:
            # Cancellation (e.g. loop shutdown) is expected; do not report it.
            pass

    def get_history_chains(self) -> list[list[Comp.BaseMessageComponent]]:
        """Return the recorded history message chains."""
        return self.history_chains


class SessionFilter(abc.ABC):
    """Strategy that decides which session an event belongs to.

    Subclasses must implement :meth:`filter`. The class derives from
    ``abc.ABC`` so the ``@abc.abstractmethod`` marker is enforced; without
    ABCMeta the decorator has no effect, and the base class could be used
    directly with ``filter`` silently returning ``None``.
    """

    @abc.abstractmethod
    def filter(self, event: AstrMessageEvent) -> str:
        """Return the session identifier for *event*."""


class DefaultSessionFilter(SessionFilter):
    """Session filter that keys each session by the event's unified message origin."""

    def filter(self, event: AstrMessageEvent) -> str:
        """Use ``event.unified_msg_origin`` as the session identifier."""
        origin = event.unified_msg_origin
        return origin


class SessionWaiter:
    """Waits for follow-up messages belonging to one session.

    The waiter registers itself in ``USER_SESSIONS`` under ``session_id``.
    Incoming events are routed to it through :meth:`trigger`, which calls
    the registered handler with the session's controller.
    """

    def __init__(
        self,
        session_filter: SessionFilter,
        session_id: str,
        record_history_chains: bool,
    ) -> None:
        self.session_id = session_id
        self.session_filter = session_filter
        # Handler coroutine function; set by register_wait().
        self.handler: (
            Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
        ) = None

        self.session_controller = SessionController()
        # Whether trigger() deep-copies each incoming message chain into the
        # controller's history.
        self.record_history_chains = record_history_chains

        # Ensures only one trigger() runs the handler for this session at a time.
        self._lock = asyncio.Lock()

    async def register_wait(
        self,
        handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
        timeout: float = 30,
    ) -> Any:
        """Register *handler* and wait until the session ends.

        Args:
            handler: Coroutine function called by :meth:`trigger` for each
                event routed to this session.
            timeout: Initial keep-alive timeout in seconds.

        Returns:
            The session future's result (``None`` when stopped normally).

        Raises:
            TimeoutError: The keep-alive timer expired.
            Exception: Any error the session was stopped with, e.g. an
                exception raised by the handler.
        """
        self.handler = handler
        USER_SESSIONS[self.session_id] = self

        try:
            # Start the keep-alive timer for this session.
            self.session_controller.keep(timeout, reset_timeout=True)
            return await self.session_controller.future
        finally:
            # Runs on success, error and cancellation alike. When the await
            # raised, the future is already done, so no error needs to be
            # forwarded to stop().
            self._cleanup()

    def _cleanup(self, error: Exception | None = None) -> None:
        """Unregister this waiter and end its session."""
        # A newer waiter for the same session id may have replaced this one
        # in USER_SESSIONS; only remove the entry if it still points here,
        # otherwise the newer session would stop receiving messages.
        if USER_SESSIONS.get(self.session_id) is self:
            del USER_SESSIONS[self.session_id]
        try:
            FILTERS.remove(self.session_filter)
        except ValueError:
            pass  # already removed
        self.session_controller.stop(error)

    @classmethod
    async def trigger(cls, session_id: str, event: AstrMessageEvent) -> None:
        """Route an incoming event to the waiter registered for *session_id*.

        Does nothing if no waiter is registered or its session has ended.
        An exception raised by the handler ends the session with that error.
        """
        session = USER_SESSIONS.get(session_id)
        if session is None or session.session_controller.future.done():
            return

        async with session._lock:
            controller = session.session_controller
            # Re-check: the session may have ended while waiting for the lock.
            if controller.future.done():
                return
            if session.record_history_chains:
                controller.history_chains.append(
                    [copy.deepcopy(comp) for comp in event.get_messages()],
                )
            try:
                # TODO: run the handler as a tracked task so it can be
                # cancelled if the session times out while it is still running.
                assert session.handler is not None
                await session.handler(controller, event)
            except Exception as e:
                controller.stop(e)


def session_waiter(timeout: float = 30, record_history_chains: bool = False):
    """Decorator that registers a function as a SessionWaiter handler.

    The decorated coroutine ``wrapper(event, session_filter=None)`` resolves
    the session id for *event*, then waits until the session ends. While it
    waits, external input delivered through ``SessionWaiter.trigger`` invokes
    the original function.

    :param timeout: Timeout in seconds.
    :param record_history_chains: Whether to record history message chains
        automatically (as deep copies). Retrieve them with
        ``controller.get_history_chains()``.
    """

    def decorator(
        func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
    ):
        @functools.wraps(func)
        async def wrapper(
            event: AstrMessageEvent,
            session_filter: SessionFilter | None = None,
            *args,
            **kwargs,
        ):
            # Extra positional/keyword arguments are accepted but not used.
            if session_filter is None:
                session_filter = DefaultSessionFilter()
            if not isinstance(session_filter, SessionFilter):
                raise ValueError("session_filter 必须是 SessionFilter")

            session_id = session_filter.filter(event)
            FILTERS.append(session_filter)

            waiter = SessionWaiter(session_filter, session_id, record_history_chains)
            return await waiter.register_wait(func, timeout)

        return wrapper

    return decorator