| import json |
| import hashlib |
| import asyncio |
| from uuid import uuid4 |
| from typing import Union, List, Dict, Optional, Tuple |
|
|
| from pydantic import Field |
|
|
| from .memory import BaseMemory |
| from evoagentx.rag import RAGConfig, RAGEngine |
| from evoagentx.rag.schema import Corpus, Chunk, ChunkMetadata, Query, RagResult |
| from evoagentx.storages.base import StorageHandler |
| from evoagentx.core.message import Message |
| from evoagentx.core.logging import logger |
|
|
class LongTermMemory(BaseMemory):
    """
    Manages long-term storage and retrieval of memories, integrating with RAGEngine for indexing
    and StorageHandler for persistence.

    Each memory is a ``Message`` serialized into a ``Chunk`` and indexed under a single
    corpus (``default_corpus_id``). Content-based deduplication uses a SHA-256 hash of
    the message content, stored as ``content_hash`` on each record in ``memory_table``.
    """
    storage_handler: StorageHandler = Field(..., description="Handler for persistent storage")
    rag_config: RAGConfig = Field(..., description="Configuration for RAG engine")
    # Optional: constructed lazily in init_module when not supplied.
    rag_engine: Optional[RAGEngine] = Field(default=None, description="RAG engine for indexing and retrieval")
    memory_table: str = Field(default="memory", description="Database table for storing memories")
    default_corpus_id: Optional[str] = Field(default=None, description="Default corpus ID for memory indexing")

    def init_module(self):
        """Initialize the RAG engine and memory indices.

        Lazily constructs the RAGEngine and assigns a fresh corpus id when the
        caller did not provide one.
        """
        super().init_module()
        if self.rag_engine is None:
            self.rag_engine = RAGEngine(config=self.rag_config, storage_handler=self.storage_handler)
        if self.default_corpus_id is None:
            self.default_corpus_id = str(uuid4())
        logger.info(f"Initialized LongTermMemory with corpus_id {self.default_corpus_id}")

    @staticmethod
    def _content_hash(content) -> str:
        """Return the SHA-256 hex digest of ``content`` used for deduplication."""
        return hashlib.sha256(str(content).encode()).hexdigest()

    def _load_memory_records(self) -> List[Dict]:
        """Load all raw rows of the memory table from persistent storage."""
        return self.storage_handler.load(tables=[self.memory_table]).get(self.memory_table, [])

    def _create_memory_chunk(self, message: Message, memory_id: str) -> Chunk:
        """Convert a Message to a Chunk for RAG indexing.

        The full message payload is preserved in the chunk metadata so that
        ``_chunk_to_message`` can reconstruct an equivalent Message later.
        """
        metadata = ChunkMetadata(
            corpus_id=self.default_corpus_id,
            memory_id=memory_id,
            timestamp=message.timestamp,
            action=message.action,
            wf_goal=message.wf_goal,
            agent=message.agent,
            # Store the enum's raw value; msg_type may legitimately be None.
            msg_type=message.msg_type.value if message.msg_type else None,
            prompt=message.prompt,
            next_actions=message.next_actions,
            wf_task=message.wf_task,
            wf_task_desc=message.wf_task_desc,
            message_id=message.message_id,
            # JSON-encode so non-string content survives the round trip
            # (decoded again in _chunk_to_message).
            content=json.dumps(message.content),
        )
        text = str(message.content)
        return Chunk(
            chunk_id=memory_id,
            text=text,
            metadata=metadata,
            start_char_idx=0,
            end_char_idx=len(text),
        )

    def _chunk_to_message(self, chunk: Chunk) -> Message:
        """Convert a Chunk back to a Message (inverse of ``_create_memory_chunk``)."""
        content = chunk.metadata.content
        # _create_memory_chunk stores content as JSON; decode it so callers get
        # the original object back rather than its JSON string (previously a
        # string "hi" round-tripped as '"hi"'). Fall back to the raw value for
        # chunks written before JSON encoding was introduced.
        try:
            content = json.loads(content)
        except (TypeError, ValueError):
            pass
        return Message(
            content=content,
            action=chunk.metadata.action,
            wf_goal=chunk.metadata.wf_goal,
            timestamp=chunk.metadata.timestamp,
            agent=chunk.metadata.agent,
            # NOTE(review): msg_type was stored as the enum *value* (a string);
            # assumes Message coerces it back — confirm against the Message model.
            msg_type=chunk.metadata.msg_type,
            prompt=chunk.metadata.prompt,
            next_actions=chunk.metadata.next_actions,
            wf_task=chunk.metadata.wf_task,
            wf_task_desc=chunk.metadata.wf_task_desc,
            message_id=chunk.metadata.message_id,
        )

    def add(self, messages: Union[Message, str, List[Union[Message, str]]]) -> List[str]:
        """Store messages in memory and index them in RAGEngine, returning memory_ids.

        Duplicate contents (by SHA-256 hash) are not re-indexed; for a duplicate
        whose existing record exposes a ``memory_id``, that id is returned in its
        place. Deduplication also applies within the given batch.
        """
        if not isinstance(messages, list):
            messages = [messages]
        messages = [Message(content=msg) if isinstance(msg, str) else msg for msg in messages]
        messages = [msg for msg in messages if msg.content]

        if not messages:
            logger.warning("No valid messages to add")
            return []

        # Build the hash -> memory_id map once. The previous implementation
        # reloaded the whole table for every duplicate, an O(n^2) storage cost.
        hash_to_id: Dict[str, Optional[str]] = {
            record["content_hash"]: record.get("memory_id")
            for record in self._load_memory_records()
            if "content_hash" in record
        }

        final_messages: List[Message] = []
        final_memory_ids: List[str] = []
        final_chunks: List[Chunk] = []

        for msg in messages:
            content_hash = self._content_hash(msg.content)
            if content_hash in hash_to_id:
                # str() guard: content is not necessarily a string, so slicing
                # it directly could raise.
                logger.info(f"Duplicate message found (hash): {str(msg.content)[:50]}...")
                existing_id = hash_to_id[content_hash]
                if existing_id:
                    final_memory_ids.append(existing_id)
                continue
            memory_id = str(uuid4())
            # Register the hash immediately so duplicates *within* this batch
            # are also collapsed (previously they were indexed twice).
            hash_to_id[content_hash] = memory_id
            final_messages.append(msg)
            final_memory_ids.append(memory_id)
            chunk = self._create_memory_chunk(msg, memory_id)
            chunk.metadata.content_hash = content_hash
            final_chunks.append(chunk)

        if not final_chunks:
            logger.info("No messages added after deduplication")
            return final_memory_ids

        # Mirror new messages into the in-memory short-term store.
        for msg in final_messages:
            super().add_message(msg)

        corpus = Corpus(chunks=final_chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(
            index_type=self.rag_config.index.index_type, nodes=corpus, corpus_id=self.default_corpus_id
        )
        if not chunk_ids:
            # Best-effort: ids are still returned so callers can retry/inspect.
            logger.error("Failed to index memories")
        return final_memory_ids

    async def get(self, memory_ids: Union[str, List[str]], return_chunk: bool = True) -> List[Tuple[Union[Chunk, Message], str]]:
        """Retrieve memories by memory_ids, returning (Message/Chunk, memory_id) tuples.

        Missing ids are silently omitted from the result; any engine failure is
        logged and yields an empty list rather than raising.
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]

        if not memory_ids:
            logger.warning("No memory_ids provided for get")
            return []

        try:
            chunks = await self.rag_engine.aget(
                corpus_id=self.default_corpus_id,
                index_type=self.rag_config.index.index_type,
                node_ids=memory_ids
            )
            results = [
                (chunk if return_chunk else self._chunk_to_message(chunk), chunk.metadata.memory_id)
                for chunk in chunks if chunk
            ]
            logger.info(f"Retrieved {len(results)} memories for memory_ids: {memory_ids}")
            return results
        except Exception as e:
            logger.error(f"Failed to get memories: {str(e)}")
            return []

    def delete(self, memory_ids: Union[str, List[str]]) -> List[bool]:
        """Delete memories by memory_ids, returning success status for each.

        The returned list is positionally aligned with the *input* id list.
        NOTE(review): uses asyncio.run internally, so this must not be called
        from inside a running event loop.
        """
        if not isinstance(memory_ids, list):
            memory_ids = [memory_ids]

        if not memory_ids:
            logger.warning("No memory_ids provided for deletion")
            return []

        successes = [False] * len(memory_ids)
        valid_memory_ids = []

        existing_chunks = asyncio.run(self.get(memory_ids, return_chunk=True))
        for chunk, mid in existing_chunks:
            if not chunk:
                continue
            valid_memory_ids.append(mid)
            super().remove_message(self._chunk_to_message(chunk))
            # get() returns only the chunks it found, so enumerating its result
            # (as the previous code did) misaligns `successes` whenever some ids
            # are missing; map each found id back to every position it occupies
            # in the request list instead.
            for idx, requested in enumerate(memory_ids):
                if requested == mid:
                    successes[idx] = True

        if not valid_memory_ids:
            logger.info("No memories found for deletion")
            return successes

        self.rag_engine.delete(
            corpus_id=self.default_corpus_id,
            index_type=self.rag_config.index.index_type,
            node_ids=valid_memory_ids
        )

        return successes

    def update(self, updates: Union[Tuple[str, Union[Message, str]], List[Tuple[str, Union[Message, str]]]]) -> List[bool]:
        """Update memories with new content, returning success status for each.

        The returned list is positionally aligned with the *input* update list.
        NOTE(review): uses asyncio.run internally, so this must not be called
        from inside a running event loop.
        """
        if not isinstance(updates, list):
            updates = [updates]
        updates = [(mid, Message(content=msg) if isinstance(msg, str) else msg) for mid, msg in updates]
        # Collapse to one update per memory_id; entries with empty content are dropped.
        updates_dict = {mid: msg for mid, msg in updates if msg.content}

        if not updates_dict:
            logger.warning("No valid updates provided")
            return []

        memory_ids = list(updates_dict.keys())
        existing_memories = asyncio.run(self.get(memory_ids, return_chunk=False))
        existing_dict = {mid: msg for msg, mid in existing_memories}

        successes = [False] * len(updates)
        final_updates = []
        final_memory_ids = []

        for mid, msg in updates_dict.items():
            if mid not in existing_dict:
                logger.warning(f"No memory found with memory_id {mid}")
                continue
            final_updates.append((mid, msg))
            final_memory_ids.append(mid)
            # Mark every position in the original request list carrying this id.
            # The previous code indexed into the *deduplicated* key list, which
            # misreports when duplicates or empty-content entries shift positions.
            for idx, (requested_mid, _) in enumerate(updates):
                if requested_mid == mid:
                    successes[idx] = True
            super().remove_message(existing_dict[mid])

        if not final_updates:
            logger.info("No memories updated")
            return successes

        chunks = [self._create_memory_chunk(msg, mid) for mid, msg in final_updates]
        for _, msg in final_updates:
            super().add_message(msg)

        corpus = Corpus(chunks=chunks, corpus_id=self.default_corpus_id)
        chunk_ids = self.rag_engine.add(
            index_type=self.rag_config.index.index_type, nodes=corpus, corpus_id=self.default_corpus_id
        )
        if not chunk_ids:
            logger.error(f"Failed to update memories in RAG index: {final_memory_ids}")
            return [False] * len(updates)

        return successes

    async def search_async(self, query: Union[str, Query], n: Optional[int] = None,
                 metadata_filters: Optional[Dict] = None, return_chunk=False) -> List[Tuple[Message, str]]:
        """Retrieve messages from the RAG index asynchronously based on a query.

        Args:
            query: A raw query string or a pre-built Query object.
            n: Optional result cap; falls back to the configured retrieval top_k.
            metadata_filters: Extra filters merged into the query's own filters
                (caller-supplied values win on key clashes).
            return_chunk: When True, return raw Chunks instead of Messages.

        Returns:
            A list of (Message-or-Chunk, memory_id) tuples; empty on failure.
        """
        if isinstance(query, str):
            query_obj = Query(
                query_str=query,
                top_k=n or self.rag_config.retrieval.top_k,
                metadata_filters=metadata_filters or {}
            )
        else:
            query_obj = query
            query_obj.top_k = n or self.rag_config.retrieval.top_k
            if metadata_filters:
                query_obj.metadata_filters = {**query_obj.metadata_filters, **metadata_filters} if query_obj.metadata_filters else metadata_filters

        try:
            result: RagResult = await self.rag_engine.query_async(query_obj, corpus_id=self.default_corpus_id)
            if return_chunk:
                return [(chunk, chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            messages = [(self._chunk_to_message(chunk), chunk.metadata.memory_id) for chunk in result.corpus.chunks]
            logger.info(f"Retrieved {len(messages)} memories for query: {query_obj.query_str}")
            return messages[:n] if n else messages
        except Exception as e:
            logger.error(f"Failed to search memories: {str(e)}")
            return []

    def search(self, query: Union[str, Query], n: Optional[int] = None,
               metadata_filters: Optional[Dict] = None) -> List[Tuple[Message, str]]:
        """Synchronous wrapper for searching memories (must not run inside an event loop)."""
        return asyncio.run(self.search_async(query, n, metadata_filters))

    def clear(self) -> None:
        """Clear all in-memory messages and the RAG indices for this corpus."""
        super().clear()
        self.rag_engine.clear(corpus_id=self.default_corpus_id)
        logger.info(f"Cleared LongTermMemory with corpus_id {self.default_corpus_id}")

    def save(self, save_path: Optional[str] = None) -> None:
        """Save all indices and memory data to the database (or to ``save_path``)."""
        self.rag_engine.save(output_path=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)

    def load(self, save_path: Optional[str] = None) -> List[str]:
        """Load memory data from database and reconstruct indices, returning memory_ids."""
        return self.rag_engine.load(source=save_path, corpus_id=self.default_corpus_id, table=self.memory_table)