camdog920 commited on
Commit
22026ba
·
verified ·
1 Parent(s): 997a0f9

Upload aether/memory.py

Browse files
Files changed (1) hide show
  1. aether/memory.py +300 -0
aether/memory.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ CoALA-inspired Memory Architecture for AETHER.
3
+ Four modules: Working, Episodic, Semantic, Procedural.
4
+ Plus Temporal Memory for long-horizon reasoning.
5
+ """
6
+
7
import json
import math
import time
from collections import deque
from typing import Any, Dict, List, Optional

import torch
import torch.nn as nn
13
+
14
+
15
class WorkingMemory:
    """
    Active scratchpad for the current reasoning cycle (CoALA working memory).

    Fixed-capacity FIFO buffer (oldest entries are silently evicted) with a
    learnable per-slot attention weight that biases retrieval.
    """
    def __init__(self, capacity: int = 16):
        self.capacity = capacity
        # deque(maxlen=...) drops the oldest item once capacity is reached.
        self.buffer: deque = deque(maxlen=capacity)
        # One learnable weight per slot, sigmoid-squashed at retrieval time.
        # NOTE(review): this class is not an nn.Module, so the parameter is
        # not registered with any optimizer unless a caller wires it up.
        self.attention_weights = nn.Parameter(torch.ones(capacity))

    def store(self, item: Dict[str, Any]) -> None:
        """Append *item* to the buffer, stamping it with the current time.

        Note: mutates the caller's dict by adding an ``_timestamp`` key.
        """
        item["_timestamp"] = time.time()
        self.buffer.append(item)

    def retrieve(self, query: str, top_k: int = 3) -> List[Dict]:
        """Return up to *top_k* items ranked by keyword match x slot attention.

        An item scores one point per key whose string form contains *query*
        (case-insensitive), scaled by the sigmoid of that slot's attention
        weight. Only keys are searched, not values.
        """
        if not self.buffer:
            return []

        items = list(self.buffer)  # materialize once instead of per index
        needle = query.lower()     # hoisted out of both loops
        scores = []
        for i, item in enumerate(items):
            score = sum(1 for k in item if needle in str(k).lower())
            scores.append(score * torch.sigmoid(self.attention_weights[i]).item())

        # Indices of the top-k scores, highest first.
        indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        return [items[i] for i in indices]

    def export(self) -> List[Dict]:
        """Snapshot of the buffer as a plain list (oldest first)."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)
49
+
50
+
51
class EpisodicMemory:
    """
    Experience buffer storing past interactions (CoALA episodic memory).

    Bounded FIFO of episode dicts with keyword-overlap similarity retrieval;
    temporal structure supports long-horizon reasoning.
    """
    def __init__(self, buffer_size: int = 1000):
        self.buffer_size = buffer_size
        self.buffer: deque = deque(maxlen=buffer_size)

    def store(self, episode: Dict[str, Any]) -> None:
        """Append *episode*, stamping it (mutates the dict) with current time."""
        episode["_timestamp"] = time.time()
        self.buffer.append(episode)

    def retrieve_similar(self, query: str, top_k: int = 5) -> List[Dict]:
        """Return up to *top_k* episodes ranked by keyword overlap with *query*.

        Scores count how many whitespace-separated query words appear in the
        JSON dump of each episode (case-insensitive). ``default=str`` keeps
        scoring robust when an episode holds non-JSON-serializable values,
        where the previous implementation raised TypeError.
        (Keyword matching is a placeholder for embedding-based retrieval.)
        """
        if not self.buffer:
            return []

        words = query.lower().split()  # hoisted: split/lower once, not per episode
        episodes = list(self.buffer)
        scores = []
        for item in episodes:
            text = json.dumps(item, default=str).lower()
            scores.append(sum(1 for word in words if word in text))

        indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_k]
        return [episodes[i] for i in indices]

    def get_recent(self, n: int = 10) -> List[Dict]:
        """Get the *n* most recent episodes (oldest of those first)."""
        return list(self.buffer)[-n:]

    def export(self) -> List[Dict]:
        """Snapshot of the buffer as a plain list (oldest first)."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)
88
+
89
+
90
class SemanticMemory:
    """
    World knowledge - external and learned facts (CoALA semantic memory).

    Flat key -> fact store; each fact records value, confidence, timestamp.
    Backed by knowledge graph (see knowledge.py).
    """
    def __init__(self):
        self.facts: Dict[str, Any] = {}

    def store_fact(self, key: str, value: Any, confidence: float = 1.0) -> None:
        """Store (or overwrite) a fact under *key* with confidence and timestamp."""
        self.facts[key] = {"value": value, "confidence": confidence, "timestamp": time.time()}

    def retrieve(self, key: str) -> Optional[Dict]:
        """Exact-key lookup; None when the fact is absent."""
        return self.facts.get(key)

    def query(self, query: str) -> List[Dict]:
        """Case-insensitive *substring* match on fact keys.

        (Previous docstring claimed prefix matching; the implementation has
        always matched anywhere in the key.)
        """
        needle = query.lower()  # hoisted out of the comprehension
        return [v for k, v in self.facts.items() if needle in k.lower()]

    def export(self) -> Dict:
        """The raw fact store (shared reference, not a copy)."""
        return self.facts
110
+
111
+
112
class ProceduralMemory:
    """
    Learned skills, tool definitions, and code implementations.
    Inspired by Yunjue Agent's tool accumulation.
    """
    def __init__(self):
        self.tools: Dict[str, Dict] = {}
        self.tool_usage_stats: Dict[str, int] = {}

    def register_tool(self, name: str, code: str, description: str,
                      tags: Optional[List[str]] = None) -> None:
        """Register (or overwrite) a tool; its usage counter resets to zero."""
        self.tools[name] = {
            "code": code,
            "description": description,
            "tags": tags or [],
            "registered_at": time.time(),
            "version": 1,
        }
        self.tool_usage_stats[name] = 0

    def get_tool(self, name: str) -> Optional[Dict]:
        """Fetch a tool by name, bumping its usage counter; None if unknown."""
        if name in self.tools:
            self.tool_usage_stats[name] += 1
            return self.tools[name]
        return None

    def search_tools(self, query: str) -> List[Dict]:
        """Case-insensitive substring search across name, description, tags."""
        needle = query.lower()  # hoisted out of the loop
        results = []
        for name, tool in self.tools.items():
            text = f"{name} {tool['description']} {' '.join(tool['tags'])}"
            if needle in text.lower():
                results.append({"name": name, **tool})
        return results

    def merge_tools(self, tool_cluster: List[str]) -> Optional[str]:
        """
        Merge functionally redundant tools (Yunjue-style tool absorption).

        The most-used *registered* tool in the cluster becomes canonical; the
        others' descriptions are folded into it, they are removed, and their
        usage counts are absorbed into the canonical counter (so stats never
        point at deleted tools). Returns the canonical tool's name, or None
        if fewer than two cluster members are actually registered.
        (Previously an unregistered name could win the usage-max and raise
        KeyError on the description update.)
        """
        known = [t for t in tool_cluster if t in self.tools]
        if len(known) < 2:
            return None

        # Keep the highest-usage registered tool as canonical.
        canonical = max(known, key=lambda t: self.tool_usage_stats.get(t, 0))

        # Merge descriptions in cluster order.
        self.tools[canonical]["description"] = " | ".join(
            self.tools[t]["description"] for t in known
        )
        self.tools[canonical]["version"] += 1

        # Remove redundant tools, absorbing their usage history.
        for t in known:
            if t != canonical:
                self.tool_usage_stats[canonical] += self.tool_usage_stats.pop(t, 0)
                del self.tools[t]

        return canonical

    def export(self) -> Dict:
        """Snapshot: registered tools plus per-tool usage counters."""
        return {
            "tools": self.tools,
            "usage_stats": self.tool_usage_stats,
        }
177
+
178
+
179
class CoALAMemory:
    """
    Unified memory system following CoALA cognitive architecture.

    Thin router over the four sub-memories (Working, Episodic, Semantic,
    Procedural): ``memory_type`` selects which one a store/retrieve hits.
    """
    def __init__(self, capacity: int = 16):
        self.working = WorkingMemory(capacity=capacity)
        self.episodic = EpisodicMemory(buffer_size=1000)
        self.semantic = SemanticMemory()
        self.procedural = ProceduralMemory()

    def store(self, item: Dict[str, Any], memory_type: str = "working"):
        """Route *item* into the chosen sub-memory; unknown types are a no-op."""
        if memory_type == "working":
            self.working.store(item)
        elif memory_type == "episodic":
            self.episodic.store(item)
        elif memory_type == "semantic":
            # Each key/value pair becomes an individual fact.
            for fact_key, fact_value in item.items():
                self.semantic.store_fact(fact_key, fact_value)
        elif memory_type == "procedural":
            # Only well-formed tool specs (name + code) are registered.
            if "name" in item and "code" in item:
                self.procedural.register_tool(
                    item["name"],
                    item["code"],
                    item.get("description", ""),
                    item.get("tags", []),
                )

    def retrieve(self, query: str, memory_type: str = "all", top_k: int = 5) -> List[Dict]:
        """Query one sub-memory, or blend working/episodic/semantic for "all"."""
        if memory_type == "working":
            return self.working.retrieve(query, top_k)
        if memory_type == "episodic":
            return self.episodic.retrieve_similar(query, top_k)
        if memory_type == "semantic":
            return self.semantic.query(query)[:top_k]
        if memory_type == "procedural":
            return self.procedural.search_tools(query)
        if memory_type == "all":
            combined = list(self.working.retrieve(query, top_k=top_k // 2))
            combined += self.episodic.retrieve_similar(query, top_k=top_k)
            combined += self.semantic.query(query)[:top_k]
            return combined[:top_k]
        return []

    @property
    def buffer(self):
        """Alias for working memory buffer."""
        return self.working.buffer

    def export(self) -> Dict[str, Any]:
        """Snapshot of all four sub-memories keyed by memory name."""
        return {
            "working": self.working.export(),
            "episodic": self.episodic.export(),
            "semantic": self.semantic.export(),
            "procedural": self.procedural.export(),
        }
235
+
236
+
237
class TemporalMemory(nn.Module):
    """
    Time-sensitive memory with learned temporal attention.
    Enables long-horizon reasoning and contextual adaptation.
    Uses a simple LSTM-like gating mechanism.
    """
    def __init__(self, buffer_size: int = 1000, hidden_dim: int = 64):
        super().__init__()
        self.buffer_size = buffer_size
        self.hidden_dim = hidden_dim
        self.buffer: deque = deque(maxlen=buffer_size)

        # Temporal attention network: maps a 2-feature input to a [0, 1] gate.
        # NOTE(review): not used by any method below yet — presumably wired
        # in by a future retrieve_with_attention implementation.
        self.temporal_gate = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def store(self, event: Dict[str, Any]) -> None:
        """Append *event*, stamping it (mutates the dict) with the current time."""
        event["_timestamp"] = time.time()
        self.buffer.append(event)

    def retrieve_context(self, current_time: Optional[float] = None,
                         lookback_window: float = 3600.0) -> List[Dict]:
        """
        Return events no older than *lookback_window* seconds, most recent
        first, each annotated with an exponential-decay ``recency_score``
        and its ``age_seconds``.
        """
        current_time = current_time or time.time()
        relevant = []

        for event in self.buffer:
            # Events without a timestamp are treated as age 0 (fully recent).
            age = current_time - event.get("_timestamp", current_time)
            if age <= lookback_window:
                # math.exp on a plain float replaces the previous per-event
                # torch.tensor round-trip: same decay curve, no tensor
                # allocation, full float64 precision.
                # NOTE(review): a negative age (event stamped in the future)
                # still passes the filter and scores > 1 — confirm intended.
                relevant.append({
                    **event,
                    "recency_score": math.exp(-age / lookback_window),
                    "age_seconds": age,
                })

        # Most recent (highest decay score) first.
        relevant.sort(key=lambda x: x["recency_score"], reverse=True)
        return relevant

    def retrieve_with_attention(self, query_embedding: torch.Tensor,
                                top_k: int = 10) -> List[Dict]:
        """
        Attention-based retrieval combining temporal and semantic relevance.
        (Placeholder - *query_embedding* is currently unused; a full
        implementation would fold it into the temporal scores.)
        """
        return self.retrieve_context()[:top_k]

    def export(self) -> List[Dict]:
        """Snapshot of the buffer as a plain list (oldest first)."""
        return list(self.buffer)

    @property
    def buffer_contents(self):
        """Buffer snapshot as a list (oldest first)."""
        return list(self.buffer)

    def __len__(self):
        return len(self.buffer)