import logging
import sqlite3
import time
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any

from astrbot.core.db.po import Platform, Stats
|
|
|
|
@dataclass
class Conversation:
    """A stored LLM conversation (one row of ``webchat_conversation``).

    For web chat, ``history`` holds every message, including commands,
    replies, images and so on. For chats on other platforms, non-LLM replies
    are not stored, since those platforms already keep their own records.
    """

    user_id: str
    cid: str
    # The message list, serialized as a string.
    history: str = ""
    created_at: int = 0
    updated_at: int = 0
    title: str = ""
    persona_id: str = ""
|
|
|
|
# Schema bootstrap script, run through ``executescript`` when the database is
# opened. Every table uses CREATE TABLE IF NOT EXISTS, so the script is safe to
# run on every start-up. Existing tables are left untouched, so columns added
# to a table after it was first created need a separate migration.
#
# ``PRAGMA encoding`` only takes effect before the database has been
# initialized (SQLite silently ignores it afterwards), so it must come before
# the first CREATE TABLE.
INIT_SQL = """
PRAGMA encoding = 'UTF-8';

CREATE TABLE IF NOT EXISTS platform(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS plugin(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS command(
    name VARCHAR(32),
    count INTEGER,
    timestamp INTEGER
);
CREATE TABLE IF NOT EXISTS llm_history(
    provider_type VARCHAR(32),
    session_id VARCHAR(32),
    content TEXT
);

-- ATRI
CREATE TABLE IF NOT EXISTS atri_vision(
    id TEXT,
    url_or_path TEXT,
    caption TEXT,
    is_meme BOOLEAN,
    keywords TEXT,
    platform_name VARCHAR(32),
    session_id VARCHAR(32),
    sender_nickname VARCHAR(32),
    timestamp INTEGER
);

CREATE TABLE IF NOT EXISTS webchat_conversation(
    user_id TEXT, -- 会话 id
    cid TEXT, -- 对话 id
    history TEXT,
    created_at INTEGER,
    updated_at INTEGER,
    title TEXT,
    persona_id TEXT
);
"""
|
|
|
|
class SQLiteDatabase:
    """SQLite storage for usage metrics and web-chat conversations.

    One connection (``self.conn``) is opened in ``__init__`` and shared by
    all methods. If it cannot hand out a cursor (sqlite3 raises
    ``ProgrammingError``, e.g. when the connection is closed or is used from
    a thread other than the one that created it), a temporary connection is
    opened for that single operation and closed again afterwards.
    """

    def __init__(self, db_path: str) -> None:
        """Open the database at ``db_path``, create the schema and migrate it.

        Args:
            db_path: Path of the SQLite database file.
        """
        super().__init__()
        self.db_path = db_path

        self.conn = self._get_conn(self.db_path)
        c = self.conn.cursor()
        try:
            c.executescript(INIT_SQL)
            self.conn.commit()

            # CREATE TABLE IF NOT EXISTS leaves existing tables untouched, so
            # databases created before these columns existed must gain them here.
            c.execute("PRAGMA table_info(webchat_conversation)")
            existing_columns = {row[1] for row in c.fetchall()}
            for column in ("title", "persona_id"):
                if column not in existing_columns:
                    # ``column`` comes from the fixed tuple above, never from input.
                    c.execute(
                        f"ALTER TABLE webchat_conversation ADD COLUMN {column} TEXT;"
                    )
                    self.conn.commit()
        finally:
            c.close()

    def _get_conn(self, db_path: str) -> sqlite3.Connection:
        """Open a new connection to ``db_path`` that decodes TEXT values to ``str``."""
        conn = sqlite3.connect(db_path)
        conn.text_factory = str
        return conn

    @contextmanager
    def _cursor(self, commit: bool = False) -> Iterator[sqlite3.Cursor]:
        """Yield a cursor, closing it (and any temporary connection) on exit.

        Args:
            commit: Commit the connection once the ``with`` body has finished
                without raising.
        """
        conn = self.conn
        temp_conn: sqlite3.Connection | None = None
        try:
            c = conn.cursor()
        except sqlite3.ProgrammingError:
            # The shared connection is unusable here; fall back to a
            # short-lived connection dedicated to this one operation.
            temp_conn = conn = self._get_conn(self.db_path)
            c = conn.cursor()
        try:
            yield c
            if commit:
                conn.commit()
        finally:
            c.close()
            if temp_conn is not None:
                temp_conn.close()

    def _exec_sql(self, sql: str, params: tuple | None = None) -> None:
        """Execute a single statement and commit it.

        Args:
            sql: The statement, using ``?`` placeholders.
            params: Placeholder values; a falsy value means "no parameters".
        """
        with self._cursor(commit=True) as c:
            if params:
                c.execute(sql, params)
            else:
                c.execute(sql)

    def _insert_metrics(self, table: str, metrics: dict) -> None:
        """Insert one ``(name, count, timestamp)`` row per entry of ``metrics``.

        All rows of one call share the same timestamp. ``table`` is
        interpolated into the SQL, so it must be a trusted table name from the
        schema, never user input.
        """
        timestamp = int(time.time())
        for name, count in metrics.items():
            self._exec_sql(
                f"INSERT INTO {table}(name, count, timestamp) VALUES (?, ?, ?)",
                (name, count, timestamp),
            )

    def insert_platform_metrics(self, metrics: dict) -> None:
        """Record per-platform counts, keyed by platform name."""
        self._insert_metrics("platform", metrics)

    def insert_llm_metrics(self, metrics: dict) -> None:
        """Record per-LLM counts, keyed by LLM name."""
        self._insert_metrics("llm", metrics)

    def get_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Return the platform stats rows recorded in the last ``offset_sec`` seconds."""
        since = int(time.time()) - offset_sec
        with self._cursor() as c:
            c.execute("SELECT * FROM platform WHERE timestamp >= ?", (since,))
            platform = [Platform(*row) for row in c.fetchall()]
        return Stats(platform=platform)

    def get_total_message_count(self) -> int:
        """Return the sum of all platform message counts (0 when there are none)."""
        with self._cursor() as c:
            # SUM over zero rows is NULL; COALESCE keeps the ``int`` contract.
            c.execute("SELECT COALESCE(SUM(count), 0) FROM platform")
            return c.fetchone()[0]

    def get_grouped_base_stats(self, offset_sec: int = 86400) -> Stats:
        """Return platform stats from the last ``offset_sec`` seconds, summed per name.

        Each returned row carries the most recent timestamp of its group.
        """
        since = int(time.time()) - offset_sec
        with self._cursor() as c:
            # A bare ``timestamp`` next to SUM() would come from an arbitrary
            # row of the group; MAX() makes the value deterministic.
            c.execute(
                """
                SELECT name, SUM(count), MAX(timestamp) FROM platform
                WHERE timestamp >= ?
                GROUP BY name
                """,
                (since,),
            )
            platform = [Platform(*row) for row in c.fetchall()]
        return Stats(platform=platform)

    def get_conversation_by_user_id(
        self, user_id: str, cid: str
    ) -> Conversation | None:
        """Return the conversation ``cid`` of ``user_id``, or ``None`` if absent."""
        with self._cursor() as c:
            # Columns are listed explicitly so the positional Conversation(...)
            # call below does not depend on the table's physical column order.
            c.execute(
                """
                SELECT user_id, cid, history, created_at, updated_at, title, persona_id
                FROM webchat_conversation
                WHERE user_id = ? AND cid = ?
                """,
                (user_id, cid),
            )
            row = c.fetchone()
        return Conversation(*row) if row else None

    def new_conversation(self, user_id: str, cid: str) -> None:
        """Create an empty conversation whose created/updated times are now."""
        now = int(time.time())
        self._exec_sql(
            """
            INSERT INTO webchat_conversation(user_id, cid, history, updated_at, created_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (user_id, cid, "[]", now, now),
        )

    def get_conversations(self, user_id: str) -> list[Conversation]:
        """List the conversations of ``user_id``, most recently updated first.

        This is a lightweight listing: the returned objects have an empty
        ``user_id`` and an empty ``"[]"`` history; load a single conversation
        with ``get_conversation_by_user_id`` to get its content.
        """
        with self._cursor() as c:
            c.execute(
                """
                SELECT cid, created_at, updated_at, title, persona_id
                FROM webchat_conversation
                WHERE user_id = ?
                ORDER BY updated_at DESC
                """,
                (user_id,),
            )
            rows = c.fetchall()
        return [
            Conversation("", cid, "[]", created_at, updated_at, title, persona_id)
            for cid, created_at, updated_at, title, persona_id in rows
        ]

    def update_conversation(self, user_id: str, cid: str, history: str) -> None:
        """Replace the conversation history and bump ``updated_at`` to now."""
        self._exec_sql(
            """
            UPDATE webchat_conversation SET history = ?, updated_at = ?
            WHERE user_id = ? AND cid = ?
            """,
            (history, int(time.time()), user_id, cid),
        )

    def update_conversation_title(self, user_id: str, cid: str, title: str) -> None:
        """Set the title of a conversation (``updated_at`` is not changed)."""
        self._exec_sql(
            "UPDATE webchat_conversation SET title = ? WHERE user_id = ? AND cid = ?",
            (title, user_id, cid),
        )

    def update_conversation_persona_id(
        self, user_id: str, cid: str, persona_id: str
    ) -> None:
        """Set the persona of a conversation (``updated_at`` is not changed)."""
        self._exec_sql(
            "UPDATE webchat_conversation SET persona_id = ? WHERE user_id = ? AND cid = ?",
            (persona_id, user_id, cid),
        )

    def delete_conversation(self, user_id: str, cid: str) -> None:
        """Delete one conversation."""
        self._exec_sql(
            "DELETE FROM webchat_conversation WHERE user_id = ? AND cid = ?",
            (user_id, cid),
        )

    def get_all_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of all conversations, most recently updated first.

        Returns:
            ``(conversations, total_count)``; see ``get_filtered_conversations``
            for the dict layout. On a database error the failure is logged and
            ``([], 0)`` is returned.
        """
        # With no filters, the filtered query is exactly the unfiltered one.
        return self.get_filtered_conversations(page=page, page_size=page_size)

    def get_filtered_conversations(
        self,
        page: int = 1,
        page_size: int = 20,
        platforms: list[str] | None = None,
        message_types: list[str] | None = None,
        search_query: str | None = None,
        exclude_ids: list[str] | None = None,
        exclude_platforms: list[str] | None = None,
    ) -> tuple[list[dict[str, Any]], int]:
        """Return one page of conversations matching the filters, newest first.

        Args:
            page: 1-based page number.
            page_size: Number of conversations per page.
            platforms: Keep ``user_id`` values starting with ``"<platform>:"``.
            message_types: Keep ``user_id`` values containing ``":<type>:"``.
            search_query: Substring searched in title, user_id, cid and history.
            exclude_ids: Drop ``user_id`` values starting with any of these.
            exclude_platforms: Drop ``user_id`` values starting with ``"<platform>:"``.

        Returns:
            ``(conversations, total_count)``. Each conversation is a dict with
            keys ``user_id``, ``cid``, ``title``, ``persona_id``, ``created_at``
            and ``updated_at``. ``total_count`` counts every matching row, not
            just this page. On any error the exception is logged and
            ``([], 0)`` is returned.
        """
        try:
            where_sql, params = self._build_conversation_filter(
                platforms=platforms,
                message_types=message_types,
                search_query=search_query,
                exclude_ids=exclude_ids,
                exclude_platforms=exclude_platforms,
            )
            offset = (page - 1) * page_size

            with self._cursor() as c:
                # ``where_sql`` contains only placeholders; values are bound.
                c.execute(f"SELECT COUNT(*) FROM webchat_conversation{where_sql}", params)
                total_count = c.fetchone()[0]

                c.execute(
                    f"""
                    SELECT user_id, cid, created_at, updated_at, title, persona_id
                    FROM webchat_conversation
                    {where_sql}
                    ORDER BY updated_at DESC
                    LIMIT ? OFFSET ?
                    """,
                    [*params, page_size, offset],
                )
                rows = c.fetchall()

            conversations = [self._conversation_row_to_dict(row) for row in rows]
        except Exception:
            # Best effort: callers get an empty page instead of an exception,
            # but the failure is no longer silent.
            logging.getLogger(__name__).exception("Failed to query conversations")
            return [], 0

        return conversations, total_count

    @staticmethod
    def _build_conversation_filter(
        platforms: list[str] | None,
        message_types: list[str] | None,
        search_query: str | None,
        exclude_ids: list[str] | None,
        exclude_platforms: list[str] | None,
    ) -> tuple[str, list[Any]]:
        """Build a ``" WHERE ..."`` clause (or ``""``) plus its bound parameters.

        Parameters are appended in the same order as their placeholders.
        """
        where_clauses: list[str] = []
        params: list[Any] = []

        # NOTE(review): the LIKE patterns assume user_id is shaped like
        # "<platform>:<message_type>:..." — inferred from the patterns, verify.
        if platforms:
            where_clauses.append(
                "(" + " OR ".join("user_id LIKE ?" for _ in platforms) + ")"
            )
            params.extend(f"{platform}:%" for platform in platforms)

        if message_types:
            where_clauses.append(
                "(" + " OR ".join("user_id LIKE ?" for _ in message_types) + ")"
            )
            params.extend(f"%:{msg_type}:%" for msg_type in message_types)

        if search_query:
            # title/user_id/cid are stored verbatim, so they are matched
            # against the raw query. history is matched against the
            # unicode_escape form, presumably because it is serialized with
            # non-ASCII characters as \uXXXX escapes — TODO confirm.
            escaped_query = search_query.encode("unicode_escape").decode("utf-8")
            raw_param = f"%{search_query}%"
            where_clauses.append(
                "(title LIKE ? OR user_id LIKE ? OR cid LIKE ? OR history LIKE ?)"
            )
            params.extend([raw_param, raw_param, raw_param, f"%{escaped_query}%"])

        for exclude_id in exclude_ids or ():
            where_clauses.append("user_id NOT LIKE ?")
            params.append(f"{exclude_id}%")

        for exclude_platform in exclude_platforms or ():
            where_clauses.append("user_id NOT LIKE ?")
            params.append(f"{exclude_platform}:%")

        where_sql = (" WHERE " + " AND ".join(where_clauses)) if where_clauses else ""
        return where_sql, params

    @staticmethod
    def _conversation_row_to_dict(row: tuple) -> dict[str, Any]:
        """Convert a ``(user_id, cid, created_at, updated_at, title, persona_id)`` row.

        NULL columns are replaced by empty/zero values, and a missing title
        falls back to a label built from the first 8 characters of the cid.
        """
        user_id, cid, created_at, updated_at, title, persona_id = row
        safe_cid = str(cid) if cid else "unknown"
        return {
            "user_id": user_id or "",
            "cid": safe_cid,
            "title": title or f"对话 {safe_cid[:8]}",
            "persona_id": persona_id or "",
            "created_at": created_at or 0,
            "updated_at": updated_at or 0,
        }
|
|