"""
CoALA-inspired Memory Architecture for AETHER.

Four modules: Working, Episodic, Semantic, Procedural.
Plus Temporal Memory for long-horizon reasoning.
"""

import json
import time
from collections import deque
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn


class WorkingMemory:
    """
    Active scratchpad for the current reasoning cycle.

    Limited capacity, fast access: holds at most ``capacity`` items, and once
    full each newly stored item evicts the oldest one.
    """

    def __init__(self, capacity: int = 16):
        self.capacity = capacity
        self.buffer: deque = deque(maxlen=capacity)
        # One learnable gate per buffer *slot* (position), not per item: as the
        # deque evicts old items, the gate applied to a given item changes.
        # NOTE(review): this class is not an nn.Module, so the Parameter is not
        # registered anywhere automatically — confirm how (if) it is trained.
        self.attention_weights = nn.Parameter(torch.ones(capacity))

    def store(self, item: Dict[str, Any]):
        """Append ``item`` to the buffer, stamping it in place with ``_timestamp``."""
        item["_timestamp"] = time.time()
        self.buffer.append(item)

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """
        Return up to ``top_k`` buffered items ranked by relevance to ``query``.

        An item's score is the number of its *keys* whose string form contains
        ``query`` (case-insensitive), multiplied by the sigmoid of that buffer
        slot's attention weight. Values are not inspected. Ties keep buffer
        (oldest-first) order because the sort is stable.
        """
        if not self.buffer:
            return []
        needle = query.lower()
        # Snapshot once; indexing a fresh list per result was O(n * top_k).
        items = list(self.buffer)
        # Gates are only read here, so don't build an autograd graph per item.
        with torch.no_grad():
            gates = torch.sigmoid(self.attention_weights[: len(items)]).tolist()
        scores = [
            sum(1 for key in item if needle in str(key).lower()) * gate
            for item, gate in zip(items, gates)
        ]
        ranked = sorted(range(len(items)), key=lambda i: scores[i], reverse=True)
        return [items[i] for i in ranked[:top_k]]

    def export(self) -> List[Dict]:
        """Return a shallow copy of the buffered items, oldest first."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)


class EpisodicMemory:
    """
    Experience buffer storing past interactions.

    Temporal structure for long-horizon reasoning: episodes are kept in
    insertion order and the oldest is evicted once ``buffer_size`` is reached.
    """

    def __init__(self, buffer_size: int = 1000):
        self.buffer_size = buffer_size
        self.buffer: deque = deque(maxlen=buffer_size)

    def store(self, episode: Dict[str, Any]):
        """Append ``episode``, stamping it in place with ``_timestamp``."""
        episode["_timestamp"] = time.time()
        self.buffer.append(episode)

    def retrieve_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """
        Retrieve up to ``top_k`` episodes similar to ``query``.

        Each episode is scored by how many whitespace-separated query words
        occur (case-insensitive substring) in its JSON rendering. Ties keep
        insertion order.
        """
        if not self.buffer:
            return []
        # Simple keyword matching (replace with embedding-based in production).
        words = query.lower().split()
        items = list(self.buffer)
        scores = []
        for item in items:
            # default=str: episodes may carry non-JSON values (tensors, sets,
            # custom objects); they are stringified for matching instead of
            # raising TypeError.
            text = json.dumps(item, default=str).lower()
            scores.append(sum(1 for word in words if word in text))
        ranked = sorted(range(len(items)), key=lambda i: scores[i], reverse=True)
        return [items[i] for i in ranked[:top_k]]

    def get_recent(self, n: int = 10) -> List[Dict]:
        """Get the ``n`` most recent episodes, oldest first ([] when n <= 0)."""
        if n <= 0:
            # list[-0:] would return the whole buffer, not zero episodes.
            return []
        return list(self.buffer)[-n:]

    def export(self) -> List[Dict]:
        """Return a shallow copy of all stored episodes, oldest first."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)


class SemanticMemory:
    """
    World knowledge - external and learned facts.

    Backed by knowledge graph (see knowledge.py).
    """

    def __init__(self):
        # key -> {"value": ..., "confidence": float, "timestamp": float}
        self.facts: Dict[str, Any] = {}

    def store_fact(self, key: str, value: Any, confidence: float = 1.0):
        """Store (or overwrite) the fact under ``key`` with its confidence and time."""
        self.facts[key] = {"value": value, "confidence": confidence, "timestamp": time.time()}

    def retrieve(self, key: str) -> Optional[Dict]:
        """Return the fact record for ``key``, or None if absent."""
        return self.facts.get(key)

    def query(self, query: str) -> List[Dict]:
        """Return fact records whose key contains ``query`` (case-insensitive substring)."""
        needle = query.lower()
        return [v for k, v in self.facts.items() if needle in k.lower()]

    def export(self) -> Dict:
        """Return the underlying facts dict (not a copy)."""
        return self.facts


class ProceduralMemory:
    """
    Learned skills, tool definitions, and code implementations.

    Inspired by Yunjue Agent's tool accumulation.
    """

    def __init__(self):
        self.tools: Dict[str, Dict] = {}
        # Number of successful get_tool() lookups per registered tool name.
        self.tool_usage_stats: Dict[str, int] = {}

    def register_tool(self, name: str, code: str, description: str, tags: Optional[List[str]] = None):
        """Register (or re-register, resetting version and usage) a tool under ``name``."""
        self.tools[name] = {
            "code": code,
            "description": description,
            "tags": tags or [],
            "registered_at": time.time(),
            "version": 1,
        }
        self.tool_usage_stats[name] = 0

    def get_tool(self, name: str) -> Optional[Dict]:
        """Return the tool record for ``name`` (counting the use), or None."""
        if name in self.tools:
            self.tool_usage_stats[name] += 1
            return self.tools[name]
        return None

    def search_tools(self, query: str) -> List[Dict]:
        """Search tools by name, description, or tags (case-insensitive substring)."""
        results = []
        for name, tool in self.tools.items():
            text = f"{name} {tool['description']} {' '.join(tool['tags'])}"
            if query.lower() in text.lower():
                results.append({"name": name, **tool})
        return results

    def merge_tools(self, tool_cluster: List[str]) -> Optional[str]:
        """
        Merge functionally redundant tools (Yunjue-style tool absorption).

        The most-used registered tool in the cluster becomes canonical. It takes
        the joined descriptions (" | ", cluster order), the union of all tags,
        and the summed usage counts, and its version is bumped. The other
        tools are removed, along with their usage stats.

        Names that are not registered, and repeated names, are ignored.

        Returns:
            Name of the merged tool, or None if the cluster contains fewer than
            two distinct registered tools.
        """
        # Order-preserving dedupe; unregistered names previously caused KeyError
        # when they won the max() below.
        members = [t for t in dict.fromkeys(tool_cluster) if t in self.tools]
        if len(members) < 2:
            return None

        # Keep highest-usage tool as canonical (first one wins ties).
        canonical = max(members, key=lambda t: self.tool_usage_stats.get(t, 0))
        merged = self.tools[canonical]
        merged["description"] = " | ".join(self.tools[t]["description"] for t in members)
        # Keep absorbed tools' tags so search_tools still finds the merged tool.
        merged["tags"] = list(dict.fromkeys(tag for t in members for tag in self.tools[t]["tags"]))
        merged["version"] += 1

        # Remove redundant tools and fold their usage into the canonical tool
        # so no stale stats remain for deleted names.
        for t in members:
            if t == canonical:
                continue
            absorbed_uses = self.tool_usage_stats.pop(t, 0)
            self.tool_usage_stats[canonical] = self.tool_usage_stats.get(canonical, 0) + absorbed_uses
            del self.tools[t]
        return canonical

    def export(self) -> Dict:
        """Return the tool registry and usage stats (not copies)."""
        return {
            "tools": self.tools,
            "usage_stats": self.tool_usage_stats,
        }


class CoALAMemory:
    """
    Unified memory system following the CoALA cognitive architecture.

    Combines Working, Episodic, Semantic, and Procedural memory behind a
    single ``store`` / ``retrieve`` interface selected by ``memory_type``.
    """

    def __init__(self, capacity: int = 16, episodic_buffer_size: int = 1000):
        """
        Args:
            capacity: Maximum number of items held in working memory.
            episodic_buffer_size: Maximum number of episodes retained before
                the oldest are evicted.
        """
        self.working = WorkingMemory(capacity=capacity)
        self.episodic = EpisodicMemory(buffer_size=episodic_buffer_size)
        self.semantic = SemanticMemory()
        self.procedural = ProceduralMemory()

    def store(self, item: Dict[str, Any], memory_type: str = "working"):
        """
        Route ``item`` to the memory module named by ``memory_type``.

        - "working" / "episodic": stored as-is (stamped in place with ``_timestamp``).
        - "semantic": each key/value pair is stored as a fact with confidence 1.0.
        - "procedural": registered as a tool when it has "name" and "code"
          keys; otherwise ignored.

        Unrecognized ``memory_type`` values are silently ignored.
        """
        if memory_type == "working":
            self.working.store(item)
        elif memory_type == "episodic":
            self.episodic.store(item)
        elif memory_type == "semantic":
            for k, v in item.items():
                self.semantic.store_fact(k, v)
        elif memory_type == "procedural":
            if "name" in item and "code" in item:
                self.procedural.register_tool(
                    item["name"], item["code"], item.get("description", ""), item.get("tags", [])
                )

    def retrieve(self, query: str, memory_type: str = "all", top_k: int = 5) -> List[Dict]:
        """
        Retrieve items matching ``query`` from one module or from all of them.

        With ``memory_type="all"``, working results (up to ``top_k // 2``), then
        episodic, then semantic results are concatenated and truncated to
        ``top_k``. Procedural memory is only searched when requested
        explicitly. Unknown ``memory_type`` values return [].
        """
        if memory_type == "all":
            results = []
            results.extend(self.working.retrieve(query, top_k=top_k // 2))
            results.extend(self.episodic.retrieve_similar(query, top_k=top_k))
            results.extend(self.semantic.query(query)[:top_k])
            return results[:top_k]
        elif memory_type == "working":
            return self.working.retrieve(query, top_k)
        elif memory_type == "episodic":
            return self.episodic.retrieve_similar(query, top_k)
        elif memory_type == "semantic":
            return self.semantic.query(query)[:top_k]
        elif memory_type == "procedural":
            return self.procedural.search_tools(query)
        return []

    @property
    def buffer(self):
        """Alias for working memory buffer."""
        return self.working.buffer

    def export(self) -> Dict[str, Any]:
        """Export every module's contents in one dict keyed by module name."""
        return {
            "working": self.working.export(),
            "episodic": self.episodic.export(),
            "semantic": self.semantic.export(),
            "procedural": self.procedural.export(),
        }


class TemporalMemory(nn.Module):
    """
    Time-sensitive memory with learned temporal attention.

    Enables long-horizon reasoning and contextual adaptation.
    Uses a simple LSTM-like gating mechanism.
    NOTE(review): ``temporal_gate`` is defined but not yet used by any
    retrieval method in this class; retrieval is currently recency-only.
    """

    def __init__(self, buffer_size: int = 1000, hidden_dim: int = 64):
        super().__init__()
        self.buffer_size = buffer_size
        self.hidden_dim = hidden_dim
        self.buffer: deque = deque(maxlen=buffer_size)

        # Temporal attention network: 2 input features -> gate in (0, 1).
        self.temporal_gate = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def store(self, event: Dict[str, Any]):
        """Append ``event``, stamping it in place with ``_timestamp``."""
        event["_timestamp"] = time.time()
        self.buffer.append(event)

    def retrieve_context(self, current_time: Optional[float] = None,
                         lookback_window: float = 3600.0) -> List[Dict]:
        """
        Retrieve events within the lookback window, weighted by recency.

        Args:
            current_time: Reference time in seconds (same clock as
                ``time.time()``); defaults to now. ``0.0`` is honored as a
                real timestamp.
            lookback_window: Window length in seconds; must be positive.

        Returns:
            Copies of matching events, each augmented with ``recency_score``
            (``exp(-age / lookback_window)``) and ``age_seconds``, sorted by
            descending recency score.

        Raises:
            ValueError: If ``lookback_window`` is not positive.
        """
        if lookback_window <= 0:
            raise ValueError(f"lookback_window must be positive, got {lookback_window!r}")
        # `or` would silently replace an explicit current_time of 0.0 with now.
        if current_time is None:
            current_time = time.time()
        relevant = []
        for event in self.buffer:
            # Events without a timestamp are treated as occurring "now".
            age = current_time - event.get("_timestamp", current_time)
            # NOTE(review): events stamped after current_time (negative age)
            # pass this check and get recency_score > 1 — confirm intended.
            if age <= lookback_window:
                # Temporal relevance score: exponential decay.
                recency_score = torch.exp(torch.tensor(-age / lookback_window))
                relevant.append({
                    **event,
                    "recency_score": recency_score.item(),
                    "age_seconds": age,
                })
        # Sort by recency score
        relevant.sort(key=lambda x: x["recency_score"], reverse=True)
        return relevant

    def retrieve_with_attention(self, query_embedding: torch.Tensor, top_k: int = 10) -> List[Dict]:
        """
        Attention-based retrieval combining temporal and semantic relevance.

        (Placeholder - would use actual embeddings in full implementation.)
        Currently ignores ``query_embedding`` and returns the ``top_k`` most
        recent events from the default one-hour window.
        """
        return self.retrieve_context()[:top_k]

    def export(self) -> List[Dict]:
        """Return a shallow copy of all stored events, oldest first."""
        return list(self.buffer)

    @property
    def buffer_contents(self):
        """List snapshot of the buffered events."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)