# aether/memory.py
"""
CoALA-inspired Memory Architecture for AETHER.
Four modules: Working, Episodic, Semantic, Procedural.
Plus Temporal Memory for long-horizon reasoning.
"""
import math
import time
import json
from collections import deque
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.nn as nn


class WorkingMemory:
"""
Active scratchpad for current reasoning cycle.
Limited capacity, fast access.
"""
def __init__(self, capacity: int = 16):
self.capacity = capacity
self.buffer: deque = deque(maxlen=capacity)
        # NOTE: WorkingMemory is not an nn.Module, so this parameter is not
        # registered for training unless explicitly handed to an optimizer.
        self.attention_weights = nn.Parameter(torch.ones(capacity))
def store(self, item: Dict[str, Any]):
item["_timestamp"] = time.time()
self.buffer.append(item)
def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
"""Simple attention-based retrieval from working buffer."""
if not self.buffer:
return []
        # Compute relevance scores by substring-matching the query against
        # each item's keys and values (simplified; no embeddings)
        scores = []
        for i, item in enumerate(self.buffer):
            score = sum(
                1 for k, v in item.items()
                if query.lower() in f"{k} {v}".lower()
            )
            scores.append(score * torch.sigmoid(self.attention_weights[i]).item())
# Get top-k indices
indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
return [list(self.buffer)[i] for i in indices]
def export(self) -> List[Dict]:
return list(self.buffer)
def __len__(self):
return len(self.buffer)
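

# A minimal usage sketch of WorkingMemory (illustrative only; the item keys
# "thought" and "observation" are assumptions, not a fixed schema).
def _demo_working_memory() -> None:
    wm = WorkingMemory(capacity=4)
    wm.store({"thought": "user asked about gradient clipping"})
    wm.store({"observation": "gradient norm spiked at step 120"})
    # Substring match against keys/values, weighted by attention
    for hit in wm.retrieve("gradient", top_k=2):
        print(hit)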


class EpisodicMemory:
"""
Experience buffer storing past interactions.
Temporal structure for long-horizon reasoning.
"""
def __init__(self, buffer_size: int = 1000):
self.buffer_size = buffer_size
self.buffer: deque = deque(maxlen=buffer_size)
def store(self, episode: Dict[str, Any]):
episode["_timestamp"] = time.time()
self.buffer.append(episode)
def retrieve_similar(self, query: str, top_k: int = 5) -> List[Dict]:
"""Retrieve episodes similar to query."""
if not self.buffer:
return []
        # Simple keyword overlap (an embedding-based variant is sketched
        # below, after this class)
        scores = []
        for item in self.buffer:
            # default=str keeps serialization from crashing on non-JSON values
            text = json.dumps(item, default=str).lower()
            score = sum(1 for word in query.lower().split() if word in text)
            scores.append(score)
indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
return [list(self.buffer)[i] for i in indices]
def get_recent(self, n: int = 10) -> List[Dict]:
"""Get n most recent episodes."""
return list(self.buffer)[-n:]
def export(self) -> List[Dict]:
return list(self.buffer)
def __len__(self):
return len(self.buffer)
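

# Hedged sketch of the embedding-based retrieval that the keyword-matching
# comment above alludes to. `embed_fn` is a hypothetical caller-supplied
# encoder (e.g., a sentence embedder); nothing here is part of the
# original API.
def retrieve_similar_embedded(memory: EpisodicMemory,
                              query: str,
                              embed_fn: Callable[[str], torch.Tensor],
                              top_k: int = 5) -> List[Dict]:
    if len(memory) == 0:
        return []
    episodes = memory.export()
    query_vec = embed_fn(query)
    # Cosine similarity between the query and each serialized episode
    sims = [
        torch.nn.functional.cosine_similarity(
            query_vec, embed_fn(json.dumps(ep, default=str)), dim=0
        ).item()
        for ep in episodes
    ]
    order = sorted(range(len(sims)), key=lambda i: sims[i], reverse=True)
    return [episodes[i] for i in order[:top_k]]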


class SemanticMemory:
    """
    World knowledge: external and learned facts.
    This class itself is a flat key-value fact store; the knowledge-graph
    backend lives in knowledge.py.
    """
def __init__(self):
self.facts: Dict[str, Any] = {}
def store_fact(self, key: str, value: Any, confidence: float = 1.0):
self.facts[key] = {"value": value, "confidence": confidence, "timestamp": time.time()}
def retrieve(self, key: str) -> Optional[Dict]:
return self.facts.get(key)
    def query(self, query: str) -> List[Dict]:
        """Case-insensitive substring match over fact keys."""
        return [v for k, v in self.facts.items() if query.lower() in k.lower()]
def export(self) -> Dict:
return self.facts


class ProceduralMemory:
"""
Learned skills, tool definitions, and code implementations.
Inspired by Yunjue Agent's tool accumulation.
"""
def __init__(self):
self.tools: Dict[str, Dict] = {}
self.tool_usage_stats: Dict[str, int] = {}
    def register_tool(self, name: str, code: str, description: str,
                      tags: Optional[List[str]] = None):
self.tools[name] = {
"code": code,
"description": description,
"tags": tags or [],
"registered_at": time.time(),
"version": 1,
}
self.tool_usage_stats[name] = 0
def get_tool(self, name: str) -> Optional[Dict]:
if name in self.tools:
self.tool_usage_stats[name] += 1
return self.tools[name]
return None
def search_tools(self, query: str) -> List[Dict]:
"""Search tools by description or tags."""
results = []
for name, tool in self.tools.items():
text = f"{name} {tool['description']} {' '.join(tool['tags'])}"
if query.lower() in text.lower():
results.append({"name": name, **tool})
return results
def merge_tools(self, tool_cluster: List[str]) -> Optional[str]:
"""
Merge functionally redundant tools (Yunjue-style tool absorption).
Returns name of merged tool or None.
"""
        # Only consider tools that are actually registered
        cluster = [t for t in tool_cluster if t in self.tools]
        if len(cluster) < 2:
            return None
        # Simple merge: keep the highest-usage tool as canonical
        canonical = max(cluster, key=lambda t: self.tool_usage_stats.get(t, 0))
        # Merge descriptions
        merged_desc = " | ".join(self.tools[t]["description"] for t in cluster)
        self.tools[canonical]["description"] = merged_desc
        self.tools[canonical]["version"] += 1
        # Remove redundant tools, folding their usage counts into the canonical one
        for t in cluster:
            if t != canonical:
                self.tool_usage_stats[canonical] = (
                    self.tool_usage_stats.get(canonical, 0)
                    + self.tool_usage_stats.pop(t, 0)
                )
                del self.tools[t]
        return canonical
def export(self) -> Dict:
return {
"tools": self.tools,
"usage_stats": self.tool_usage_stats,
}
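

# A small usage sketch of Yunjue-style tool absorption (tool names and code
# strings below are made up for illustration).
def _demo_merge_tools() -> None:
    pm = ProceduralMemory()
    pm.register_tool("fetch_url", "def fetch_url(u): ...", "Download a URL")
    pm.register_tool("http_get", "def http_get(u): ...", "HTTP GET a URL")
    pm.get_tool("fetch_url")  # bump usage so it becomes canonical
    merged = pm.merge_tools(["fetch_url", "http_get"])
    print(merged, pm.export()["usage_stats"])  # fetch_url survives with merged stats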


class CoALAMemory:
"""
Unified memory system following CoALA cognitive architecture.
Combines Working, Episodic, Semantic, and Procedural memory.
"""
def __init__(self, capacity: int = 16):
self.working = WorkingMemory(capacity=capacity)
self.episodic = EpisodicMemory(buffer_size=1000)
self.semantic = SemanticMemory()
self.procedural = ProceduralMemory()
def store(self, item: Dict[str, Any], memory_type: str = "working"):
if memory_type == "working":
self.working.store(item)
elif memory_type == "episodic":
self.episodic.store(item)
elif memory_type == "semantic":
for k, v in item.items():
self.semantic.store_fact(k, v)
elif memory_type == "procedural":
if "name" in item and "code" in item:
self.procedural.register_tool(
item["name"], item["code"],
item.get("description", ""),
item.get("tags", [])
)
def retrieve(self, query: str, memory_type: str = "all", top_k: int = 5) -> List[Dict]:
if memory_type == "all":
results = []
            results.extend(self.working.retrieve(query, top_k=max(1, top_k // 2)))
results.extend(self.episodic.retrieve_similar(query, top_k=top_k))
results.extend(self.semantic.query(query)[:top_k])
return results[:top_k]
elif memory_type == "working":
return self.working.retrieve(query, top_k)
elif memory_type == "episodic":
return self.episodic.retrieve_similar(query, top_k)
elif memory_type == "semantic":
return self.semantic.query(query)[:top_k]
elif memory_type == "procedural":
return self.procedural.search_tools(query)
return []
@property
def buffer(self):
"""Alias for working memory buffer."""
return self.working.buffer
def export(self) -> Dict[str, Any]:
return {
"working": self.working.export(),
"episodic": self.episodic.export(),
"semantic": self.semantic.export(),
"procedural": self.procedural.export(),
}
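

# Minimal routing sketch for the unified memory: one store/retrieve per
# module type (keys and values are illustrative assumptions).
def _demo_coala_memory() -> None:
    mem = CoALAMemory(capacity=8)
    mem.store({"goal": "summarize the meeting notes"}, memory_type="working")
    mem.store({"task": "summarization", "outcome": "success"}, memory_type="episodic")
    mem.store({"capital_of_france": "Paris"}, memory_type="semantic")
    mem.store({"name": "word_count",
               "code": "def word_count(s): return len(s.split())",
               "description": "Count words in a string"}, memory_type="procedural")
    print(mem.retrieve("summarize", memory_type="all", top_k=4))
    print(mem.retrieve("word", memory_type="procedural"))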


class TemporalMemory(nn.Module):
"""
Time-sensitive memory with learned temporal attention.
Enables long-horizon reasoning and contextual adaptation.
Uses a simple LSTM-like gating mechanism.
"""
def __init__(self, buffer_size: int = 1000, hidden_dim: int = 64):
super().__init__()
self.buffer_size = buffer_size
self.hidden_dim = hidden_dim
self.buffer: deque = deque(maxlen=buffer_size)
        # Temporal attention network over (age, similarity) feature pairs.
        # NOTE: not yet wired into retrieval; a hedged usage sketch is given
        # in _gated_scores_sketch below.
        self.temporal_gate = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )
def store(self, event: Dict[str, Any]):
event["_timestamp"] = time.time()
self.buffer.append(event)
def retrieve_context(self, current_time: Optional[float] = None,
lookback_window: float = 3600.0) -> List[Dict]:
"""
Retrieve events within lookback window, weighted by recency.
"""
current_time = current_time or time.time()
relevant = []
for event in self.buffer:
age = current_time - event.get("_timestamp", current_time)
if age <= lookback_window:
                # Temporal relevance score: exponential decay with age
                recency_score = math.exp(-age / lookback_window)
                relevant.append({
                    **event,
                    "recency_score": recency_score,
                    "age_seconds": age,
                })
# Sort by recency score
relevant.sort(key=lambda x: x["recency_score"], reverse=True)
return relevant
def retrieve_with_attention(self, query_embedding: torch.Tensor,
top_k: int = 10) -> List[Dict]:
"""
Attention-based retrieval combining temporal and semantic relevance.
(Placeholder - would use actual embeddings in full implementation)
"""
return self.retrieve_context()[:top_k]
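
    # A hedged sketch (an assumption, not the original design) of how the
    # otherwise-unused `temporal_gate` could score events by feeding it
    # (normalized age, semantic similarity) pairs.
    def _gated_scores_sketch(self, similarities: torch.Tensor,
                             ages: torch.Tensor) -> torch.Tensor:
        # Shape (num_events, 2), matching the Linear(2, hidden_dim) input
        feats = torch.stack(
            [ages / ages.max().clamp(min=1e-8), similarities], dim=-1
        )
        return self.temporal_gate(feats).squeeze(-1)  # one score per event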
def export(self) -> List[Dict]:
return list(self.buffer)
@property
def buffer_contents(self):
return list(self.buffer)
def __len__(self):
return len(self.buffer)
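

# Smoke-test entry point (added as a usage sketch; run `python memory.py`).
if __name__ == "__main__":
    _demo_working_memory()
    _demo_merge_tools()
    _demo_coala_memory()
    tm = TemporalMemory(buffer_size=10)
    tm.store({"event": "model checkpoint saved"})
    print(tm.retrieve_context(lookback_window=60.0))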