| """ |
| CoALA-inspired Memory Architecture for AETHER. |
| Four modules: Working, Episodic, Semantic, Procedural. |
| Plus Temporal Memory for long-horizon reasoning. |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from collections import deque |
| from typing import Dict, List, Any, Optional |
| import time |
| import json |
|
|
|
|
class WorkingMemory:
    """
    Active scratchpad for the current reasoning cycle.
    Limited capacity (oldest entries evicted by the deque), fast access.
    """
    def __init__(self, capacity: int = 16):
        self.capacity = capacity
        self.buffer: deque = deque(maxlen=capacity)
        # NOTE(review): this class is not an nn.Module, so the parameter is
        # never registered in a state_dict; it still weights each buffer slot.
        self.attention_weights = nn.Parameter(torch.ones(capacity))

    def store(self, item: Dict[str, Any]) -> None:
        """Store *item* stamped with the current wall-clock time.

        A shallow copy is stored so the caller's dict is not mutated
        (the original code injected ``_timestamp`` into the caller's object).
        """
        self.buffer.append({**item, "_timestamp": time.time()})

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Attention-weighted keyword retrieval from the working buffer.

        An item scores 1 per key containing *query* (case-insensitive),
        scaled by the sigmoid of that slot's attention weight.
        Returns up to *top_k* items, best first (zero-score items may be
        included when fewer than *top_k* items match).
        """
        if not self.buffer:
            return []

        items = list(self.buffer)  # hoisted: indexed repeatedly below
        query_lc = query.lower()
        scores = []
        for i, item in enumerate(items):
            hits = sum(1 for k in item if query_lc in str(k).lower())
            scores.append(hits * torch.sigmoid(self.attention_weights[i]).item())

        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        return [items[i] for i in ranked]

    def export(self) -> List[Dict]:
        """Return a snapshot (shallow copy) of the buffer contents."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)
|
|
|
|
class EpisodicMemory:
    """
    Experience buffer storing past interactions.
    Append order plus timestamps give the temporal structure used for
    long-horizon reasoning.
    """
    def __init__(self, buffer_size: int = 1000):
        self.buffer_size = buffer_size
        self.buffer: deque = deque(maxlen=buffer_size)

    def store(self, episode: Dict[str, Any]) -> None:
        """Store *episode* stamped with the current time.

        A shallow copy is stored so the caller's dict is not mutated.
        """
        self.buffer.append({**episode, "_timestamp": time.time()})

    def retrieve_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Retrieve up to *top_k* episodes by word-overlap with *query*.

        Episodes are serialized to JSON for matching; non-serializable
        values fall back to ``str`` (``default=str``) so arbitrary payloads
        cannot crash retrieval, which the original ``json.dumps(item)`` did.
        """
        if not self.buffer:
            return []

        episodes = list(self.buffer)  # hoisted: indexed repeatedly below
        words = query.lower().split()
        scores = []
        for item in episodes:
            text = json.dumps(item, default=str).lower()
            scores.append(sum(1 for word in words if word in text))

        ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        return [episodes[i] for i in ranked]

    def get_recent(self, n: int = 10) -> List[Dict]:
        """Get the n most recent episodes (oldest of those first)."""
        return list(self.buffer)[-n:]

    def export(self) -> List[Dict]:
        """Return a snapshot (shallow copy) of the buffer contents."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)
|
|
|
|
class SemanticMemory:
    """
    World knowledge - external and learned facts.
    Backed by knowledge graph (see knowledge.py).
    """
    def __init__(self):
        # key -> {"value", "confidence", "timestamp"} records
        self.facts: Dict[str, Any] = {}

    def store_fact(self, key: str, value: Any, confidence: float = 1.0) -> None:
        """Store (or overwrite) a fact under *key* with a confidence score."""
        self.facts[key] = {"value": value, "confidence": confidence, "timestamp": time.time()}

    def retrieve(self, key: str) -> Optional[Dict]:
        """Exact-key lookup; returns the fact record or None."""
        return self.facts.get(key)

    def query(self, query: str) -> List[Dict]:
        """Case-insensitive substring match on fact keys.

        (Previous docstring said "prefix matching", but the code has always
        matched the query anywhere in the key.)
        """
        query_lc = query.lower()
        return [v for k, v in self.facts.items() if query_lc in k.lower()]

    def export(self) -> Dict:
        """Return a shallow copy so callers cannot mutate internal state
        (other memory modules' export() methods also return snapshots)."""
        return dict(self.facts)
|
|
|
|
class ProceduralMemory:
    """
    Learned skills, tool definitions, and code implementations.
    Inspired by Yunjue Agent's tool accumulation.
    """
    def __init__(self):
        self.tools: Dict[str, Dict] = {}
        self.tool_usage_stats: Dict[str, int] = {}

    def register_tool(self, name: str, code: str, description: str,
                      tags: Optional[List[str]] = None) -> None:
        """Register (or overwrite) a tool; its usage counter starts at zero."""
        self.tools[name] = {
            "code": code,
            "description": description,
            "tags": tags or [],
            "registered_at": time.time(),
            "version": 1,
        }
        self.tool_usage_stats[name] = 0

    def get_tool(self, name: str) -> Optional[Dict]:
        """Return the tool record (counting the access) or None if unknown."""
        tool = self.tools.get(name)
        if tool is not None:
            self.tool_usage_stats[name] += 1
        return tool

    def search_tools(self, query: str) -> List[Dict]:
        """Search tools by name, description, or tags (case-insensitive substring)."""
        query_lc = query.lower()
        results = []
        for name, tool in self.tools.items():
            text = f"{name} {tool['description']} {' '.join(tool['tags'])}"
            if query_lc in text.lower():
                results.append({"name": name, **tool})
        return results

    def merge_tools(self, tool_cluster: List[str]) -> Optional[str]:
        """
        Merge functionally redundant tools (Yunjue-style tool absorption).

        The most-used *registered* tool in the cluster survives as canonical;
        the others' descriptions and usage counts are folded into it and
        their entries removed. Returns the canonical name, or None when
        fewer than two cluster members are actually registered.

        Fix: the original picked the canonical from raw cluster names and
        indexed ``self.tools[canonical]`` unguarded, raising KeyError when
        the most-used name was not a registered tool; it also left stale
        usage-stat entries for deleted tools.
        """
        registered = [t for t in tool_cluster if t in self.tools]
        if len(registered) < 2:
            return None

        canonical = max(registered, key=lambda t: self.tool_usage_stats.get(t, 0))

        self.tools[canonical]["description"] = " | ".join(
            self.tools[t]["description"] for t in registered
        )
        self.tools[canonical]["version"] += 1

        # Absorb usage counts, then drop the redundant entries.
        for t in registered:
            if t != canonical:
                self.tool_usage_stats[canonical] += self.tool_usage_stats.pop(t, 0)
                del self.tools[t]

        return canonical

    def export(self) -> Dict:
        """Snapshot wrapper over tools and their usage statistics."""
        return {
            "tools": self.tools,
            "usage_stats": self.tool_usage_stats,
        }
|
|
|
|
class CoALAMemory:
    """
    Unified memory system following the CoALA cognitive architecture.
    Facade over the Working, Episodic, Semantic, and Procedural modules.
    """
    def __init__(self, capacity: int = 16):
        self.working = WorkingMemory(capacity=capacity)
        self.episodic = EpisodicMemory(buffer_size=1000)
        self.semantic = SemanticMemory()
        self.procedural = ProceduralMemory()

    def store(self, item: Dict[str, Any], memory_type: str = "working"):
        """Route *item* into the named module; unknown types are ignored."""
        if memory_type == "working":
            self.working.store(item)
        elif memory_type == "episodic":
            self.episodic.store(item)
        elif memory_type == "semantic":
            # Each key/value pair of the item becomes an individual fact.
            for fact_key, fact_value in item.items():
                self.semantic.store_fact(fact_key, fact_value)
        elif memory_type == "procedural":
            # Tool registration requires at least a name and its code.
            if "name" in item and "code" in item:
                self.procedural.register_tool(
                    item["name"],
                    item["code"],
                    item.get("description", ""),
                    item.get("tags", []),
                )

    def retrieve(self, query: str, memory_type: str = "all", top_k: int = 5) -> List[Dict]:
        """Query one module, or blend working/episodic/semantic for "all"."""
        if memory_type == "all":
            combined = []
            combined += self.working.retrieve(query, top_k=top_k // 2)
            combined += self.episodic.retrieve_similar(query, top_k=top_k)
            combined += self.semantic.query(query)[:top_k]
            return combined[:top_k]
        if memory_type == "working":
            return self.working.retrieve(query, top_k)
        if memory_type == "episodic":
            return self.episodic.retrieve_similar(query, top_k)
        if memory_type == "semantic":
            return self.semantic.query(query)[:top_k]
        if memory_type == "procedural":
            return self.procedural.search_tools(query)
        return []

    @property
    def buffer(self):
        """Alias for the working-memory buffer."""
        return self.working.buffer

    def export(self) -> Dict[str, Any]:
        """Combined snapshot of all four memory modules."""
        return {
            "working": self.working.export(),
            "episodic": self.episodic.export(),
            "semantic": self.semantic.export(),
            "procedural": self.procedural.export(),
        }
|
|
|
|
class TemporalMemory(nn.Module):
    """
    Time-sensitive memory with learned temporal attention.
    Enables long-horizon reasoning and contextual adaptation.
    Uses a simple LSTM-like gating mechanism.
    """
    def __init__(self, buffer_size: int = 1000, hidden_dim: int = 64):
        super().__init__()
        self.buffer_size = buffer_size
        self.hidden_dim = hidden_dim
        self.buffer: deque = deque(maxlen=buffer_size)

        # Learned gate over a 2-d feature; not called by retrieve_* yet,
        # but registered as a submodule so it participates in state_dict.
        self.temporal_gate = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def store(self, event: Dict[str, Any]) -> None:
        """Store *event* stamped with the current time.

        A shallow copy is stored so the caller's dict is not mutated.
        """
        self.buffer.append({**event, "_timestamp": time.time()})

    def retrieve_context(self, current_time: Optional[float] = None,
                         lookback_window: float = 3600.0) -> List[Dict]:
        """
        Retrieve events within the lookback window, weighted by recency.

        ``current_time`` defaults to now only when it is None; an explicit
        0.0 is honored (the previous ``current_time or time.time()`` treated
        any falsy value, including 0.0, as unset).
        """
        if current_time is None:
            current_time = time.time()
        relevant = []

        for event in self.buffer:
            age = current_time - event.get("_timestamp", current_time)
            if age <= lookback_window:
                # Exponential decay: newest events score closest to 1.0.
                recency_score = torch.exp(torch.tensor(-age / lookback_window))
                relevant.append({
                    **event,
                    "recency_score": recency_score.item(),
                    "age_seconds": age,
                })

        relevant.sort(key=lambda x: x["recency_score"], reverse=True)
        return relevant

    def retrieve_with_attention(self, query_embedding: torch.Tensor,
                                top_k: int = 10) -> List[Dict]:
        """
        Attention-based retrieval combining temporal and semantic relevance.
        (Placeholder - would use actual embeddings in full implementation)
        """
        return self.retrieve_context()[:top_k]

    def export(self) -> List[Dict]:
        """Return a snapshot (shallow copy) of the raw event buffer."""
        return list(self.buffer)

    @property
    def buffer_contents(self):
        # Snapshot view, same contents as export().
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)
|
|