# astrbot_help / src / session_storage.py
# Commit d347708 by qa1145: "Upload 28 files" (verified)
"""
会话持久化存储
使用 SQLite 存储会话状态,支持服务器重启后恢复会话
"""
import json
import logging
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
# Module-level logger; SessionStorage reports successes (debug/info) and failures (error) through it.
logger = logging.getLogger(__name__)
class SessionStorage:
    """SQLite-backed store for session state, so sessions survive a server restart.

    Data lives in three tables: ``sessions`` (one metadata/configuration row
    per session) plus ``conversation_history`` and ``memories`` (child rows
    keyed by ``session_id``). Each public method holds ``self._lock`` and opens
    its own short-lived connection through ``_connect``.

    Because every operation opens a new connection, ``db_path`` should be a
    real file; SQLite's ``":memory:"`` would hand each call a fresh, empty
    database.
    """

    def __init__(self, db_path: str = "./sessions.db"):
        """Create the store and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file. Missing parent
                directories are created.
        """
        self.db_path = Path(db_path)
        # Serialises all public operations on this instance.
        self._lock = threading.Lock()
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    @contextmanager
    def _connect(self):
        """Yield a connection with foreign keys enabled and close it afterwards.

        The transaction is committed when the ``with`` body completes and
        rolled back when it raises. (A bare ``with sqlite3.connect(...)`` only
        ends the transaction; it never closes the connection.)
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # SQLite ignores FOREIGN KEY / ON DELETE CASCADE unless this pragma
            # is on, and it must be set on every new connection, before any
            # transaction starts, to take effect.
            conn.execute("PRAGMA foreign_keys = ON")
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _memory_file_path(memory: Dict) -> Any:
        """Return a memory dict's file path from its ``file`` or ``file_path`` key.

        ``load_session`` returns memories under ``file_path`` while freshly
        saved memories use ``file``; accepting both lets a loaded session be
        saved back unchanged. Raises ``KeyError`` if neither key is present.
        """
        return memory["file"] if "file" in memory else memory["file_path"]

    @staticmethod
    def _json_list(raw: Optional[str]) -> List:
        """Decode a JSON-array column; NULL or empty text decodes to ``[]``."""
        return json.loads(raw) if raw else []

    def _init_db(self):
        """Create the tables and indexes if they do not exist yet."""
        with self._connect() as conn:
            cursor = conn.cursor()
            # sessions: per-session metadata and configuration
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    model TEXT NOT NULL,
                    base_url TEXT NOT NULL,
                    code_dir TEXT NOT NULL,
                    max_steps INTEGER NOT NULL,
                    stream_output INTEGER NOT NULL,
                    tree_depth INTEGER NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    last_active TEXT
                )
            """)
            # conversation_history: one row per message; id gives the order
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS conversation_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
                )
            """)
            # memories: one row per memory; list fields are stored as JSON text
            cursor.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    file_path TEXT NOT NULL,
                    overview TEXT,
                    key_definitions TEXT, -- JSON 数组
                    core_logic TEXT,
                    dependencies TEXT, -- JSON 数组
                    needed_info TEXT,
                    FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
                )
            """)
            # Child tables are always queried by session_id.
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_session_conversation
                ON conversation_history (session_id)
            """)
            cursor.execute("""
                CREATE INDEX IF NOT EXISTS idx_session_memories
                ON memories (session_id)
            """)

    def save_session(
        self,
        session_id: str,
        model: str,
        base_url: str,
        code_dir: str,
        max_steps: int,
        stream_output: bool,
        tree_depth: int,
        conversation_history: List[Dict],
        memories: List[Dict]
    ) -> bool:
        """Save the complete state of a session, replacing any previous state.

        The metadata row is updated in place (keeping ``created_at``) or
        inserted if the session is new; stored conversation history and
        memories are deleted and rewritten. Everything runs in one
        transaction, so a failure leaves the previous state intact.

        Args:
            session_id: Session ID.
            model: LLM model name.
            base_url: API base URL.
            code_dir: Code directory.
            max_steps: Maximum number of steps.
            stream_output: Whether output is streamed (stored as 0/1).
            tree_depth: Directory-tree depth.
            conversation_history: Messages, each a dict with ``role`` and
                ``content`` keys, stored in list order.
            memories: Memory dicts with a ``file`` (or ``file_path``) key and
                optional ``overview``, ``key_definitions``, ``core_logic``,
                ``dependencies`` and ``needed_info`` keys.

        Returns:
            True on success, False if any error occurred (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    cursor = conn.cursor()
                    now = datetime.now().isoformat()
                    stream_flag = 1 if stream_output else 0

                    # Update the existing row; if none matched, this is a new session.
                    cursor.execute("""
                        UPDATE sessions
                        SET model = ?, base_url = ?, code_dir = ?,
                            max_steps = ?, stream_output = ?, tree_depth = ?,
                            updated_at = ?, last_active = ?
                        WHERE session_id = ?
                    """, (
                        model, base_url, code_dir,
                        max_steps, stream_flag, tree_depth,
                        now, now, session_id
                    ))
                    if cursor.rowcount == 0:
                        cursor.execute("""
                            INSERT INTO sessions
                            (session_id, model, base_url, code_dir, max_steps,
                             stream_output, tree_depth, created_at, updated_at, last_active)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                        """, (
                            session_id, model, base_url, code_dir, max_steps,
                            stream_flag, tree_depth, now, now, now
                        ))

                    # Replace, rather than merge, the stored history and memories.
                    cursor.execute("DELETE FROM conversation_history WHERE session_id = ?", (session_id,))
                    cursor.execute("DELETE FROM memories WHERE session_id = ?", (session_id,))

                    cursor.executemany("""
                        INSERT INTO conversation_history (session_id, role, content)
                        VALUES (?, ?, ?)
                    """, [
                        (session_id, msg["role"], msg["content"])
                        for msg in conversation_history
                    ])
                    cursor.executemany("""
                        INSERT INTO memories
                        (session_id, file_path, overview, key_definitions,
                         core_logic, dependencies, needed_info)
                        VALUES (?, ?, ?, ?, ?, ?, ?)
                    """, [
                        (
                            session_id,
                            self._memory_file_path(memory),
                            memory.get("overview", ""),
                            json.dumps(memory.get("key_definitions", []), ensure_ascii=False),
                            memory.get("core_logic", ""),
                            json.dumps(memory.get("dependencies", []), ensure_ascii=False),
                            memory.get("needed_info", "")
                        )
                        for memory in memories
                    ])
                logger.debug("[SessionStorage] 保存会话: %s", session_id)
                return True
            except Exception as e:
                logger.error("[SessionStorage] 保存会话失败: %s", e)
                return False

    def load_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Load a session's full state.

        Args:
            session_id: Session ID.

        Returns:
            A dict with the session's metadata plus ``conversation_history``
            (list of ``{"role", "content"}`` dicts) and ``memories`` (list of
            dicts keyed ``file_path``, ``overview``, ``key_definitions``,
            ``core_logic``, ``dependencies``, ``needed_info``), both in the
            order they were saved. None if the session does not exist or an
            error occurred (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    conn.row_factory = sqlite3.Row
                    cursor = conn.cursor()

                    cursor.execute("SELECT * FROM sessions WHERE session_id = ?", (session_id,))
                    session_row = cursor.fetchone()
                    if session_row is None:
                        return None

                    cursor.execute("""
                        SELECT role, content FROM conversation_history
                        WHERE session_id = ? ORDER BY id
                    """, (session_id,))
                    conversation_history = [
                        {"role": row["role"], "content": row["content"]}
                        for row in cursor.fetchall()
                    ]

                    # ORDER BY id so memories come back in the order they were saved.
                    cursor.execute("""
                        SELECT * FROM memories WHERE session_id = ? ORDER BY id
                    """, (session_id,))
                    memories = [
                        {
                            "file_path": row["file_path"],
                            "overview": row["overview"],
                            "key_definitions": self._json_list(row["key_definitions"]),
                            "core_logic": row["core_logic"],
                            "dependencies": self._json_list(row["dependencies"]),
                            "needed_info": row["needed_info"]
                        }
                        for row in cursor.fetchall()
                    ]

                    return {
                        "session_id": session_row["session_id"],
                        "model": session_row["model"],
                        "base_url": session_row["base_url"],
                        "code_dir": session_row["code_dir"],
                        "max_steps": session_row["max_steps"],
                        "stream_output": bool(session_row["stream_output"]),
                        "tree_depth": session_row["tree_depth"],
                        "created_at": session_row["created_at"],
                        "updated_at": session_row["updated_at"],
                        "last_active": session_row["last_active"],
                        "conversation_history": conversation_history,
                        "memories": memories
                    }
            except Exception as e:
                logger.error("[SessionStorage] 加载会话失败: %s", e)
                return None

    def delete_session(self, session_id: str) -> bool:
        """Delete a session together with its conversation history and memories.

        Child rows are removed by the schema's ``ON DELETE CASCADE``, which
        ``_connect`` enables.

        Args:
            session_id: Session ID.

        Returns:
            True if a session was deleted; False if it did not exist or an
            error occurred (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    cursor = conn.execute("DELETE FROM sessions WHERE session_id = ?", (session_id,))
                    deleted = cursor.rowcount
                logger.debug("[SessionStorage] 删除会话: %s", session_id)
                return deleted > 0
            except Exception as e:
                logger.error("[SessionStorage] 删除会话失败: %s", e)
                return False

    def list_sessions(self) -> List[Dict[str, Any]]:
        """List all sessions without their message/memory contents.

        Returns:
            One dict per session, most recently updated first, with keys
            ``session_id``, ``model``, ``code_dir``, ``created_at``,
            ``updated_at``, ``last_active``, ``message_count`` and
            ``memory_count``. An empty list on error (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    conn.row_factory = sqlite3.Row
                    rows = conn.execute("""
                        SELECT session_id, model, code_dir, created_at, updated_at, last_active,
                               (SELECT COUNT(*) FROM conversation_history WHERE session_id = s.session_id) AS message_count,
                               (SELECT COUNT(*) FROM memories WHERE session_id = s.session_id) AS memory_count
                        FROM sessions s
                        ORDER BY updated_at DESC
                    """).fetchall()
                # Row keys are exactly the column names selected above.
                return [dict(row) for row in rows]
            except Exception as e:
                logger.error("[SessionStorage] 列出会话失败: %s", e)
                return []

    def clear_all(self) -> bool:
        """Delete every session along with its history and memories (ON DELETE CASCADE).

        Returns:
            True on success, False on error (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    conn.execute("DELETE FROM sessions")
                logger.info("[SessionStorage] 清空所有会话")
                return True
            except Exception as e:
                logger.error("[SessionStorage] 清空会话失败: %s", e)
                return False

    def cleanup_old_sessions(self, days: int = 30) -> int:
        """Delete sessions that have not been active for more than ``days`` days.

        Timestamps are written with ``datetime.now().isoformat()`` (local,
        naive time), so the cutoff is computed in local time too; SQLite's
        ``'now'`` on its own is UTC. Sessions without ``last_active`` fall
        back to ``updated_at``. Child rows are removed by ON DELETE CASCADE.

        Args:
            days: Inactivity threshold in days.

        Returns:
            Number of sessions deleted; 0 on error (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    cursor = conn.execute("""
                        DELETE FROM sessions
                        WHERE datetime(COALESCE(last_active, updated_at))
                              < datetime('now', 'localtime', '-' || ? || ' days')
                    """, (days,))
                    count = cursor.rowcount
                logger.info("[SessionStorage] 清理了 %s 个旧会话", count)
                return count
            except Exception as e:
                logger.error("[SessionStorage] 清理旧会话失败: %s", e)
                return 0