Spaces:
Sleeping
Sleeping
| """ | |
| 会话持久化存储 | |
| 使用 SQLite 存储会话状态,支持服务器重启后恢复会话 | |
| """ | |
| import sqlite3 | |
| import json | |
| import threading | |
| from pathlib import Path | |
| from datetime import datetime | |
| from typing import Dict, List, Optional, Any | |
| import logging | |
logger = logging.getLogger(__name__)


class SessionStorage:
    """Persist session state in a SQLite database.

    Three tables are used: ``sessions`` (metadata and configuration),
    ``conversation_history`` (ordered messages) and ``memories`` (per-file
    memory records). Both child tables reference ``sessions`` with
    ``ON DELETE CASCADE``.

    Every public method serialises access through a single
    ``threading.Lock`` and works on a short-lived connection that is closed
    before the method returns. Database errors in the public methods are
    logged and reported through the return value instead of being raised.
    """

    def __init__(self, db_path: str = "./sessions.db"):
        """Create the storage, its parent directory and the schema.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = Path(db_path)
        self._lock = threading.Lock()
        # Make sure the directory that holds the database file exists.
        self.db_path.parent.mkdir(parents=True, exist_ok=True)
        self._init_db()

    def _connect(self) -> sqlite3.Connection:
        """Open a connection with foreign-key enforcement switched on.

        SQLite disables foreign keys by default on every new connection;
        without the pragma the ``ON DELETE CASCADE`` clauses of the schema
        are ignored and deleting a session leaves its child rows behind.

        Returns:
            An open connection. The caller must close it.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            conn.execute("PRAGMA foreign_keys = ON")
        except Exception:
            conn.close()
            raise
        return conn

    def _init_db(self):
        """Create the tables and indexes if they do not exist yet."""
        conn = self._connect()
        try:
            # ``with conn`` commits on success and rolls back on error; it
            # does not close the connection, hence the ``finally`` below.
            with conn:
                # Session metadata and configuration.
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS sessions (
                        session_id TEXT PRIMARY KEY,
                        model TEXT NOT NULL,
                        base_url TEXT NOT NULL,
                        code_dir TEXT NOT NULL,
                        max_steps INTEGER NOT NULL,
                        stream_output INTEGER NOT NULL,
                        tree_depth INTEGER NOT NULL,
                        created_at TEXT NOT NULL,
                        updated_at TEXT NOT NULL,
                        last_active TEXT
                    )
                """)
                # Conversation messages of each session.
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS conversation_history (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT NOT NULL,
                        role TEXT NOT NULL,
                        content TEXT NOT NULL,
                        FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
                    )
                """)
                # Memory records of each session.
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS memories (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        session_id TEXT NOT NULL,
                        file_path TEXT NOT NULL,
                        overview TEXT,
                        key_definitions TEXT, -- JSON 数组
                        core_logic TEXT,
                        dependencies TEXT, -- JSON 数组
                        needed_info TEXT,
                        FOREIGN KEY (session_id) REFERENCES sessions (session_id) ON DELETE CASCADE
                    )
                """)
                # Lookup indexes on the foreign-key columns.
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_session_conversation
                    ON conversation_history (session_id)
                """)
                conn.execute("""
                    CREATE INDEX IF NOT EXISTS idx_session_memories
                    ON memories (session_id)
                """)
        finally:
            conn.close()

    def save_session(
        self,
        session_id: str,
        model: str,
        base_url: str,
        code_dir: str,
        max_steps: int,
        stream_output: bool,
        tree_depth: int,
        conversation_history: List[Dict],
        memories: List[Dict]
    ) -> bool:
        """Save the complete state of one session.

        The session row is updated if it exists and inserted otherwise;
        the stored conversation history and memories are then replaced by
        the given ones. Everything happens in one transaction, so a failure
        leaves the previously stored state untouched.

        Args:
            session_id: Session identifier.
            model: LLM model name.
            base_url: API base URL.
            code_dir: Code directory.
            max_steps: Maximum number of steps.
            stream_output: Whether output is streamed.
            tree_depth: Directory tree depth.
            conversation_history: Messages, each a dict with ``"role"`` and
                ``"content"``.
            memories: Memory dicts. The file is read from ``"file"`` or,
                if that key is absent, from ``"file_path"`` (the key used
                by ``load_session``), so loaded memories can be saved back.

        Returns:
            True on success, False if anything failed (the error is logged).
        """
        with self._lock:
            conn = None
            try:
                now = datetime.now().isoformat()
                stream_flag = 1 if stream_output else 0

                message_rows = [
                    (session_id, msg["role"], msg["content"])
                    for msg in conversation_history
                ]
                memory_rows = [
                    (
                        session_id,
                        memory["file"] if "file" in memory else memory["file_path"],
                        memory.get("overview", ""),
                        json.dumps(memory.get("key_definitions", []), ensure_ascii=False),
                        memory.get("core_logic", ""),
                        json.dumps(memory.get("dependencies", []), ensure_ascii=False),
                        memory.get("needed_info", ""),
                    )
                    for memory in memories
                ]

                conn = self._connect()
                with conn:
                    # Update first; fall back to INSERT when no row matched.
                    # (INSERT OR REPLACE would reset created_at.)
                    updated = conn.execute("""
                        UPDATE sessions
                        SET model = ?, base_url = ?, code_dir = ?,
                            max_steps = ?, stream_output = ?, tree_depth = ?,
                            updated_at = ?, last_active = ?
                        WHERE session_id = ?
                    """, (
                        model, base_url, code_dir,
                        max_steps, stream_flag, tree_depth,
                        now, now, session_id
                    )).rowcount
                    if updated == 0:
                        conn.execute("""
                            INSERT INTO sessions
                            (session_id, model, base_url, code_dir, max_steps,
                             stream_output, tree_depth, created_at, updated_at, last_active)
                            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                        """, (
                            session_id, model, base_url, code_dir, max_steps,
                            stream_flag, tree_depth, now, now, now
                        ))

                    # Replace the stored history and memories wholesale.
                    conn.execute("DELETE FROM conversation_history WHERE session_id = ?", (session_id,))
                    conn.execute("DELETE FROM memories WHERE session_id = ?", (session_id,))
                    conn.executemany("""
                        INSERT INTO conversation_history (session_id, role, content)
                        VALUES (?, ?, ?)
                    """, message_rows)
                    conn.executemany("""
                        INSERT INTO memories
                        (session_id, file_path, overview, key_definitions,
                         core_logic, dependencies, needed_info)
                        VALUES (?, ?, ?, ?, ?, ?, ?)
                    """, memory_rows)

                logger.debug(f"[SessionStorage] 保存会话: {session_id}")
                return True
            except Exception as e:
                logger.error(f"[SessionStorage] 保存会话失败: {e}")
                return False
            finally:
                if conn is not None:
                    conn.close()

    def load_session(self, session_id: str) -> Optional[Dict[str, Any]]:
        """Load the stored state of one session.

        Args:
            session_id: Session identifier.

        Returns:
            A dict with the session columns plus ``"conversation_history"``
            and ``"memories"`` (both in insertion order), or None when the
            session does not exist or loading failed (the error is logged).
        """
        with self._lock:
            conn = None
            try:
                conn = self._connect()
                conn.row_factory = sqlite3.Row

                session_row = conn.execute("""
                    SELECT * FROM sessions WHERE session_id = ?
                """, (session_id,)).fetchone()
                if session_row is None:
                    return None

                message_rows = conn.execute("""
                    SELECT role, content FROM conversation_history
                    WHERE session_id = ? ORDER BY id
                """, (session_id,)).fetchall()
                conversation_history = [
                    {"role": row["role"], "content": row["content"]}
                    for row in message_rows
                ]

                # ORDER BY id keeps memories in the order they were saved.
                memory_rows = conn.execute("""
                    SELECT * FROM memories WHERE session_id = ? ORDER BY id
                """, (session_id,)).fetchall()
                memories = [
                    {
                        "file_path": row["file_path"],
                        "overview": row["overview"],
                        "key_definitions": json.loads(row["key_definitions"]) if row["key_definitions"] else [],
                        "core_logic": row["core_logic"],
                        "dependencies": json.loads(row["dependencies"]) if row["dependencies"] else [],
                        "needed_info": row["needed_info"]
                    }
                    for row in memory_rows
                ]

                return {
                    "session_id": session_row["session_id"],
                    "model": session_row["model"],
                    "base_url": session_row["base_url"],
                    "code_dir": session_row["code_dir"],
                    "max_steps": session_row["max_steps"],
                    # Stored as 0/1; hand back a real bool.
                    "stream_output": bool(session_row["stream_output"]),
                    "tree_depth": session_row["tree_depth"],
                    "created_at": session_row["created_at"],
                    "updated_at": session_row["updated_at"],
                    "last_active": session_row["last_active"],
                    "conversation_history": conversation_history,
                    "memories": memories
                }
            except Exception as e:
                logger.error(f"[SessionStorage] 加载会话失败: {e}")
                return None
            finally:
                if conn is not None:
                    conn.close()

    def delete_session(self, session_id: str) -> bool:
        """Delete one session together with its history and memories.

        Args:
            session_id: Session identifier.

        Returns:
            True if a session row was deleted, False if none matched or the
            deletion failed (the error is logged).
        """
        with self._lock:
            conn = None
            try:
                conn = self._connect()
                with conn:
                    # Child rows go away through ON DELETE CASCADE.
                    deleted = conn.execute(
                        "DELETE FROM sessions WHERE session_id = ?", (session_id,)
                    ).rowcount
                logger.debug(f"[SessionStorage] 删除会话: {session_id}")
                return deleted > 0
            except Exception as e:
                logger.error(f"[SessionStorage] 删除会话失败: {e}")
                return False
            finally:
                if conn is not None:
                    conn.close()

    def list_sessions(self) -> List[Dict[str, Any]]:
        """List all sessions without their detailed content.

        Returns:
            One dict per session, most recently updated first, carrying the
            metadata columns plus ``"message_count"`` and
            ``"memory_count"``. An empty list is returned on failure (the
            error is logged).
        """
        with self._lock:
            conn = None
            try:
                conn = self._connect()
                conn.row_factory = sqlite3.Row
                rows = conn.execute("""
                    SELECT session_id, model, code_dir, created_at, updated_at, last_active,
                           (SELECT COUNT(*) FROM conversation_history WHERE session_id = s.session_id) as message_count,
                           (SELECT COUNT(*) FROM memories WHERE session_id = s.session_id) as memory_count
                    FROM sessions s
                    ORDER BY updated_at DESC
                """).fetchall()
                return [
                    {
                        "session_id": row["session_id"],
                        "model": row["model"],
                        "code_dir": row["code_dir"],
                        "created_at": row["created_at"],
                        "updated_at": row["updated_at"],
                        "last_active": row["last_active"],
                        "message_count": row["message_count"],
                        "memory_count": row["memory_count"]
                    }
                    for row in rows
                ]
            except Exception as e:
                logger.error(f"[SessionStorage] 列出会话失败: {e}")
                return []
            finally:
                if conn is not None:
                    conn.close()

    def clear_all(self) -> bool:
        """Delete every session and all associated rows.

        Returns:
            True on success, False on failure (the error is logged).
        """
        with self._lock:
            conn = None
            try:
                conn = self._connect()
                with conn:
                    conn.execute("DELETE FROM sessions")
                    # Also purge child rows orphaned by databases that were
                    # written while foreign keys were not enforced.
                    conn.execute("DELETE FROM conversation_history")
                    conn.execute("DELETE FROM memories")
                logger.info("[SessionStorage] 清空所有会话")
                return True
            except Exception as e:
                logger.error(f"[SessionStorage] 清空会话失败: {e}")
                return False
            finally:
                if conn is not None:
                    conn.close()

    def cleanup_old_sessions(self, days: int = 30) -> int:
        """Delete sessions that have been inactive for more than ``days`` days.

        Args:
            days: Inactivity threshold in days.

        Returns:
            Number of sessions deleted; 0 on failure (the error is logged).
        """
        with self._lock:
            conn = None
            try:
                conn = self._connect()
                with conn:
                    # last_active is written by save_session as naive local
                    # time, so the cutoff must be local time too; a bare
                    # datetime('now') would be UTC.
                    count = conn.execute("""
                        DELETE FROM sessions
                        WHERE datetime(last_active) < datetime('now', 'localtime', ?)
                    """, (f"-{days} days",)).rowcount
                logger.info(f"[SessionStorage] 清理了 {count} 个旧会话")
                return count
            except Exception as e:
                logger.error(f"[SessionStorage] 清理旧会话失败: {e}")
                return 0
            finally:
                if conn is not None:
                    conn.close()