import hashlib
import os
import re
from typing import Dict, List, Optional

import chromadb
from chromadb.config import Settings

from config import Config
from core.openai_client import OpenAIClient


class MemoryManager:
    """Vector memory manager: stores and retrieves text chunks related to a character.

    Each character gets its own Chroma collection, persisted under
    ``Config.VECTOR_DB_PATH``.
    """

    # Mirrors Chroma's documented collection-name constraints: 3-63 chars,
    # ASCII alphanumerics plus "._-", starting and ending with an
    # alphanumeric character (the ".." rule is checked separately).
    _VALID_COLLECTION_NAME = re.compile(
        r"[a-zA-Z0-9][a-zA-Z0-9._-]{1,61}[a-zA-Z0-9]"
    )

    def __init__(self, character_name: str):
        """Open (or create) the persistent collection for ``character_name``.

        Args:
            character_name: Character whose memories this manager handles.

        Raises:
            Whatever Chroma raises if even the hashed-name fallback
            collection cannot be opened.
        """
        self.character_name = character_name
        self.client = OpenAIClient.get_client()

        os.makedirs(Config.VECTOR_DB_PATH, exist_ok=True)
        self.chroma_client = self._create_chroma_client()

        collection_name = self._build_collection_name(character_name)
        try:
            self.collection = self._open_collection(collection_name)
        except Exception as e:
            print(f"创建集合时出错: {e}")
            # Deterministic, always-valid fallback: unlike the builtin
            # hash(), sha1 is stable across processes, so the same
            # character reopens the same persisted collection next run.
            self.collection = self._open_collection(
                self._hashed_collection_name(character_name)
            )

    @staticmethod
    def _create_chroma_client():
        """Return a Chroma client that actually persists to Config.VECTOR_DB_PATH."""
        if hasattr(chromadb, "PersistentClient"):
            # chromadb >= 0.4: Client(Settings(persist_directory=...)) is
            # in-memory unless is_persistent=True; PersistentClient sets it.
            return chromadb.PersistentClient(
                path=Config.VECTOR_DB_PATH,
                settings=Settings(anonymized_telemetry=False),
            )
        # chromadb < 0.4: only the duckdb+parquet backend writes to disk.
        return chromadb.Client(Settings(
            chroma_db_impl="duckdb+parquet",
            persist_directory=Config.VECTOR_DB_PATH,
            anonymized_telemetry=False,
        ))

    @staticmethod
    def _hashed_collection_name(character_name: str) -> str:
        """Return a stable, Chroma-valid name derived from a sha1 of the name."""
        digest = hashlib.sha1(character_name.encode("utf-8")).hexdigest()[:16]
        return f"char_{digest}"

    @classmethod
    def _build_collection_name(cls, character_name: str) -> str:
        """Return the collection name to use for ``character_name``.

        The readable ``char_<name>`` form is kept whenever Chroma accepts it
        (so previously created collections keep their names); otherwise,
        e.g. for non-ASCII names, a stable hashed name is used.
        """
        readable = f"char_{character_name.replace(' ', '_').lower()}"[:63]
        if cls._VALID_COLLECTION_NAME.fullmatch(readable) and ".." not in readable:
            return readable
        return cls._hashed_collection_name(character_name)

    def _open_collection(self, name: str):
        """Get or create the collection ``name`` tagged with this character."""
        return self.chroma_client.get_or_create_collection(
            name=name,
            metadata={"character": self.character_name},
        )

    def add_text_chunks(self, chunks: List[Dict], character_chunks: List[int]):
        """Add the text chunks in which the character appears.

        Failures while writing to Chroma are reported and swallowed so the
        caller can continue without memory.

        Args:
            chunks: All text chunks; each is read via ``chunk['text']`` and
                ``chunk['start']``.
            character_chunks: Indices into ``chunks`` where the character
                appears. Out-of-range (including negative) and repeated
                indices are ignored.
        """
        documents: List[str] = []
        metadatas: List[Dict] = []
        ids: List[str] = []

        # dict.fromkeys de-duplicates while keeping first-seen order; Chroma
        # rejects an add() whose ids are not unique.
        for chunk_id in dict.fromkeys(character_chunks):
            # Negative ids would otherwise index from the end of `chunks`.
            if not 0 <= chunk_id < len(chunks):
                continue
            chunk = chunks[chunk_id]
            documents.append(chunk['text'])
            metadatas.append({
                'chunk_id': chunk_id,
                'position': chunk['start']
            })
            ids.append(f"chunk_{chunk_id}")

        if not documents:
            return

        batch_size = 100
        try:
            for start in range(0, len(documents), batch_size):
                end = start + batch_size
                self.collection.add(
                    documents=documents[start:end],
                    metadatas=metadatas[start:end],
                    ids=ids[start:end],
                )
            print(f"已为 {self.character_name} 添加 {len(documents)} 个文本块到向量库")
        except Exception as e:
            print(f"添加文本块到向量库失败: {e}")
            print("将继续运行,但不使用记忆功能")

    def search_relevant_context(self, query: str, n_results: Optional[int] = None) -> List[str]:
        """Retrieve stored chunks relevant to ``query``.

        Args:
            query: Query text.
            n_results: Maximum number of chunks to return; ``None`` means
                ``Config.MAX_MEMORY_RETRIEVAL``. Values <= 0 return ``[]``.

        Returns:
            Matching text chunks, or ``[]`` when the collection is empty or
            the query fails.
        """
        if n_results is None:
            n_results = Config.MAX_MEMORY_RETRIEVAL
        if n_results <= 0:
            return []

        try:
            collection_count = self.collection.count()
            if collection_count == 0:
                return []

            # Never ask for more results than the collection holds.
            results = self.collection.query(
                query_texts=[query],
                n_results=min(n_results, collection_count),
            )

            if results and results['documents']:
                return results['documents'][0]
            return []

        except Exception as e:
            print(f"检索失败: {e}")
            return []

    def get_embedding(self, text: str) -> List[float]:
        """Embed ``text`` with the OpenAI embeddings API (Config.EMBEDDING_MODEL).

        NOTE(review): nothing in this class calls this method, and the
        collection is created without an ``embedding_function``, so Chroma
        embeds documents and queries with its own default model. Confirm
        whether that split is intended.

        Returns:
            The embedding vector, or ``[]`` if the request fails.
        """
        try:
            response = self.client.embeddings.create(
                model=Config.EMBEDDING_MODEL,
                input=text
            )
            return response.data[0].embedding
        except Exception as e:
            print(f"获取嵌入失败: {e}")
            return []

    def get_statistics(self) -> Dict:
        """Return a summary dict: character name, chunk count, collection name.

        If the collection cannot be read, ``chunk_count`` is 0 and
        ``collection_name`` is ``'unknown'``.
        """
        try:
            count = self.collection.count()
            return {
                'character': self.character_name,
                'chunk_count': count,
                'collection_name': self.collection.name
            }
        except Exception:
            return {
                'character': self.character_name,
                'chunk_count': 0,
                'collection_name': 'unknown'
            }

    def clear(self):
        """Delete all stored chunks, leaving an empty collection in place."""
        try:
            name = self.collection.name
            self.chroma_client.delete_collection(name)
            # Recreate the collection: the old handle refers to a collection
            # that no longer exists, so later add/search calls would fail.
            self.collection = self._open_collection(name)
            print(f"已清空 {self.character_name} 的记忆库")
        except Exception as e:
            print(f"清空记忆库失败: {e}")