Spaces:
Sleeping
Sleeping
File size: 14,555 Bytes
d347708 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 | """
会话持久化存储
使用 SQLite 存储会话状态,支持服务器重启后恢复会话
"""
import json
import logging
import sqlite3
import threading
from contextlib import contextmanager
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional
# Module-level logger used by SessionStorage for its debug/info/error messages.
logger: logging.Logger = logging.getLogger(__name__)
class SessionStorage:
    """SQLite-backed store for session state.

    Session metadata/configuration, conversation history and memories are
    kept in three tables so a session can be restored after a server
    restart. Each public method holds a per-instance lock, opens a
    short-lived connection via ``_connect``, and reports failures through
    its return value (after logging) instead of raising.
    """

    def __init__(self, db_path: str = "./sessions.db"):
        """Initialise the store and create the schema if needed.

        Args:
            db_path: Path of the SQLite database file; missing parent
                directories are created.
        """
        self.db_path = Path(db_path)
        self._lock = threading.Lock()
        # Make sure the directory holding the database exists.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    @contextmanager
    def _connect(self, row_factory: bool = False) -> Iterator[sqlite3.Connection]:
        """Yield a connection that is committed or rolled back, then closed.

        Using a ``sqlite3.Connection`` as a context manager only ends the
        transaction; it does not close the connection, so it is closed
        explicitly here. SQLite enforces foreign keys only when enabled on
        each connection, so enforcement is switched on here to make the
        schema's ``ON DELETE CASCADE`` clauses take effect.

        Args:
            row_factory: When true, rows are returned as ``sqlite3.Row``.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            # The pragma is a no-op inside an open transaction, so run it first.
            conn.execute("PRAGMA foreign_keys = ON")
            if row_factory:
                conn.row_factory = sqlite3.Row
            with conn:  # commit on success, rollback on exception
                yield conn
        finally:
            conn.close()

    def _init_db(self):
        """Create the tables and indexes if they do not exist yet."""
        with self._connect() as conn:
            # sessions: per-session metadata and configuration.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS sessions (
                    session_id TEXT PRIMARY KEY,
                    model TEXT NOT NULL,
                    base_url TEXT NOT NULL,
                    code_dir TEXT NOT NULL,
                    max_steps INTEGER NOT NULL,
                    stream_output INTEGER NOT NULL,
                    tree_depth INTEGER NOT NULL,
                    created_at TEXT NOT NULL,
                    updated_at TEXT NOT NULL,
                    last_active TEXT
                )
            """)
            # conversation_history: one row per message; id preserves order.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS conversation_history (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    role TEXT NOT NULL,
                    content TEXT NOT NULL,
                    FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
                )
            """)
            # memories: one row per memory; list fields are stored as JSON text.
            conn.execute("""
                CREATE TABLE IF NOT EXISTS memories (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    file_path TEXT NOT NULL,
                    overview TEXT,
                    key_definitions TEXT,  -- JSON 数组
                    core_logic TEXT,
                    dependencies TEXT,  -- JSON 数组
                    needed_info TEXT,
                    FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
                )
            """)
            # Indexes for the per-session lookups done by every method.
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_session_conversation
                ON conversation_history (session_id)
            """)
            conn.execute("""
                CREATE INDEX IF NOT EXISTS idx_session_memories
                ON memories (session_id)
            """)

    @staticmethod
    def _memory_to_row(session_id: str, memory: Dict) -> tuple:
        """Convert a memory dict into a ``memories`` row tuple.

        The file path is read from ``"file"`` or, failing that, from
        ``"file_path"`` (the key ``load_session`` returns), so a loaded
        session can be saved back unchanged. List fields that are missing
        or ``None`` are stored as ``[]``.
        """
        file_path = memory["file"] if "file" in memory else memory["file_path"]
        return (
            session_id,
            file_path,
            memory.get("overview", ""),
            json.dumps(memory.get("key_definitions") or [], ensure_ascii=False),
            memory.get("core_logic", ""),
            json.dumps(memory.get("dependencies") or [], ensure_ascii=False),
            memory.get("needed_info", ""),
        )

    @staticmethod
    def _memory_from_row(row: sqlite3.Row) -> Dict[str, Any]:
        """Convert a ``memories`` row back into a memory dict."""
        return {
            "file_path": row["file_path"],
            "overview": row["overview"],
            "key_definitions": json.loads(row["key_definitions"]) if row["key_definitions"] else [],
            "core_logic": row["core_logic"],
            "dependencies": json.loads(row["dependencies"]) if row["dependencies"] else [],
            "needed_info": row["needed_info"],
        }

    def save_session(
        self,
        session_id: str,
        model: str,
        base_url: str,
        code_dir: str,
        max_steps: int,
        stream_output: bool,
        tree_depth: int,
        conversation_history: List[Dict],
        memories: List[Dict]
    ) -> bool:
        """Save the complete state of a session, replacing any previous state.

        Metadata is updated in place (keeping ``created_at``) or inserted.
        Conversation history and memories are replaced wholesale. All
        writes happen in one transaction.

        Args:
            session_id: Session identifier.
            model: LLM model name.
            base_url: API base URL.
            code_dir: Code directory.
            max_steps: Maximum number of steps.
            stream_output: Whether output is streamed.
            tree_depth: Directory-tree depth.
            conversation_history: Messages, each with ``"role"`` and ``"content"``.
            memories: Memory dicts; the file path may be given as ``"file"``
                or ``"file_path"``.

        Returns:
            True on success, False if anything failed (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    now = datetime.now().isoformat()
                    exists = conn.execute(
                        "SELECT 1 FROM sessions WHERE session_id = ?", (session_id,)
                    ).fetchone() is not None
                    if exists:
                        conn.execute("""
                            UPDATE sessions
                            SET model = ?, base_url = ?, code_dir = ?,
                                max_steps = ?, stream_output = ?, tree_depth = ?,
                                updated_at = ?, last_active = ?
                            WHERE session_id = ?
                        """, (
                            model, base_url, code_dir,
                            max_steps, 1 if stream_output else 0, tree_depth,
                            now, now, session_id,
                        ))
                    else:
                        conn.execute("""
                            INSERT INTO sessions
                            (session_id, model, base_url, code_dir, max_steps,
                             stream_output, tree_depth, created_at, updated_at, last_active)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                        """, (
                            session_id, model, base_url, code_dir, max_steps,
                            1 if stream_output else 0, tree_depth, now, now, now,
                        ))

                    # Children are replaced wholesale with the caller's current state.
                    conn.execute("DELETE FROM conversation_history WHERE session_id = ?", (session_id,))
                    conn.execute("DELETE FROM memories WHERE session_id = ?", (session_id,))
                    conn.executemany(
                        "INSERT INTO conversation_history (session_id, role, content) VALUES (?, ?, ?)",
                        [(session_id, msg["role"], msg["content"]) for msg in conversation_history],
                    )
                    conn.executemany(
                        """
                        INSERT INTO memories
                        (session_id, file_path, overview, key_definitions,
                         core_logic, dependencies, needed_info)
                        VALUES (?, ?, ?, ?, ?, ?, ?)
                        """,
                        [self._memory_to_row(session_id, memory) for memory in memories],
                    )
                logger.debug("[SessionStorage] 保存会话: %s", session_id)
                return True
            except Exception as e:
                logger.exception("[SessionStorage] 保存会话失败: %s", e)
                return False

    def load_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Load the full state of a session.

        Args:
            session_id: Session identifier.

        Returns:
            A dict with the session metadata plus ``"conversation_history"``
            and ``"memories"`` (both in insertion order), or None if the
            session does not exist or loading failed (the error is logged).
        """
        with self._lock:
            try:
                with self._connect(row_factory=True) as conn:
                    session_row = conn.execute(
                        "SELECT * FROM sessions WHERE session_id = ?", (session_id,)
                    ).fetchone()
                    if session_row is None:
                        return None

                    conversation_history = [
                        {"role": row["role"], "content": row["content"]}
                        for row in conn.execute(
                            """
                            SELECT role, content FROM conversation_history
                            WHERE session_id = ? ORDER BY id
                            """,
                            (session_id,),
                        )
                    ]
                    # ORDER BY id returns memories in the order they were saved.
                    memories = [
                        self._memory_from_row(row)
                        for row in conn.execute(
                            "SELECT * FROM memories WHERE session_id = ? ORDER BY id",
                            (session_id,),
                        )
                    ]

                return {
                    "session_id": session_row["session_id"],
                    "model": session_row["model"],
                    "base_url": session_row["base_url"],
                    "code_dir": session_row["code_dir"],
                    "max_steps": session_row["max_steps"],
                    "stream_output": bool(session_row["stream_output"]),
                    "tree_depth": session_row["tree_depth"],
                    "created_at": session_row["created_at"],
                    "updated_at": session_row["updated_at"],
                    "last_active": session_row["last_active"],
                    "conversation_history": conversation_history,
                    "memories": memories,
                }
            except Exception as e:
                logger.exception("[SessionStorage] 加载会话失败: %s", e)
                return None

    def delete_session(self, session_id: str) -> bool:
        """Delete a session together with its history and memories.

        Args:
            session_id: Session identifier.

        Returns:
            True if a session row was deleted. False if it did not exist or
            deletion failed (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    # Child rows go via ON DELETE CASCADE (enabled in _connect).
                    deleted = conn.execute(
                        "DELETE FROM sessions WHERE session_id = ?", (session_id,)
                    ).rowcount
                logger.debug("[SessionStorage] 删除会话: %s", session_id)
                return deleted > 0
            except Exception as e:
                logger.exception("[SessionStorage] 删除会话失败: %s", e)
                return False

    def list_sessions(self) -> List[Dict[str, Any]]:
        """List all sessions without their history or memories.

        Returns:
            One dict per session, most recently updated first, with message
            and memory counts. Returns an empty list on failure (the error
            is logged).
        """
        with self._lock:
            try:
                with self._connect(row_factory=True) as conn:
                    rows = conn.execute("""
                        SELECT session_id, model, code_dir, created_at, updated_at, last_active,
                            (SELECT COUNT(*) FROM conversation_history WHERE session_id = s.session_id) as message_count,
                            (SELECT COUNT(*) FROM memories WHERE session_id = s.session_id) as memory_count
                        FROM sessions s
                        ORDER BY updated_at DESC
                    """).fetchall()
                # The selected column names are exactly the keys callers receive.
                return [dict(row) for row in rows]
            except Exception as e:
                logger.exception("[SessionStorage] 列出会话失败: %s", e)
                return []

    def clear_all(self) -> bool:
        """Delete every session, message and memory.

        Returns:
            True on success, False on failure (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    # Child tables are emptied explicitly so that rows whose
                    # session no longer exists are removed too.
                    conn.execute("DELETE FROM conversation_history")
                    conn.execute("DELETE FROM memories")
                    conn.execute("DELETE FROM sessions")
                logger.info("[SessionStorage] 清空所有会话")
                return True
            except Exception as e:
                logger.exception("[SessionStorage] 清空会话失败: %s", e)
                return False

    def cleanup_old_sessions(self, days: int = 30) -> int:
        """Delete sessions that have been inactive for more than ``days`` days.

        Activity is taken from ``last_active``, falling back to
        ``updated_at`` when it is NULL. Timestamps are written by
        ``save_session`` with naive local ``datetime.now()``, so the cutoff
        is computed in local time as well.

        Args:
            days: Inactivity threshold in days.

        Returns:
            Number of sessions deleted; 0 on failure (the error is logged).
        """
        with self._lock:
            try:
                with self._connect() as conn:
                    # Children go via ON DELETE CASCADE (enabled in _connect).
                    count = conn.execute("""
                        DELETE FROM sessions
                        WHERE datetime(COALESCE(last_active, updated_at))
                              < datetime('now', 'localtime', '-' || ? || ' days')
                    """, (days,)).rowcount
                logger.info("[SessionStorage] 清理了 %s 个旧会话", count)
                return count
            except Exception as e:
                logger.exception("[SessionStorage] 清理旧会话失败: %s", e)
                return 0
|