Rohan03 commited on
Commit
4dc4204
·
verified ·
1 Parent(s): b9b5c0c

Sprint 3: memory_homeostasis.py — budget, archive, consolidation, Q-retriever

Browse files
Files changed (1) hide show
  1. purpose_agent/memory_homeostasis.py +450 -0
purpose_agent/memory_homeostasis.py ADDED
@@ -0,0 +1,450 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ memory_homeostasis.py — Bounded memory with consolidation, hibernation, and archive.
3
+
4
+ Solves: active memory must be bounded; archived evidence must remain recoverable.
5
+
6
+ Components:
7
+ - MemoryBudget: hard limits on active cards, injected tokens, per-kind caps
8
+ - MemoryArchive: append-only cold storage (JSONL or SQLite)
9
+ - ConsolidationEngine: cluster → merge → compress → hibernate
10
+ - QFunctionRetriever: budget-aware ranking with recency decay and diversity
11
+
12
+ Triggers:
13
+ - On N new memories
14
+ - On active_cards > max_active_cards
15
+ - On injected_tokens > max_injected_tokens
16
+ - Manual: team.consolidate_memory()
17
+
18
+ Invariant: active injected memory NEVER exceeds token budget.
19
+ """
20
+ from __future__ import annotations
21
+
22
+ import json
23
+ import logging
24
+ import math
25
+ import time
26
+ from collections import defaultdict
27
+ from dataclasses import dataclass, field
28
+ from pathlib import Path
29
+ from typing import Any
30
+
31
+ from purpose_agent.memory import MemoryCard, MemoryKind, MemoryStatus, MemoryStore
32
+ from purpose_agent.v2_types import MemoryScope
33
+
34
+ logger = logging.getLogger(__name__)
35
+
36
+
37
+ # ═══════════════════════════════════════════════════════════════
38
+ # Memory Budget
39
+ # ═══════════════════════════════════════════════════════════════
40
+
41
@dataclass
class MemoryBudget:
    """
    Hard limits on active memory. Enforced by the homeostasis engine.

    When any limit is exceeded, consolidation/archival is triggered
    automatically.
    """
    # Maximum number of PROMOTED cards kept active at once.
    max_active_cards: int = 512
    max_injected_tokens: int = 500  # Max tokens from memory in any single prompt
    # Per-kind caps; keys must match MemoryKind enum values.
    max_cards_per_kind: dict[str, int] = field(default_factory=lambda: {
        "skill_card": 100,
        "episodic_case": 200,
        "failure_pattern": 50,
        "user_preference": 50,
        "critic_calibration": 30,
        "tool_policy": 30,
        "purpose_contract": 10,
    })
    archive_after_days: int | None = 90  # Auto-archive unused cards after N days
    consolidation_threshold: int = 50  # Trigger consolidation every N new memories
    chars_per_token: int = 4  # For token estimation

    def estimate_tokens(self, text: str) -> int:
        """
        Conservatively estimate the token count of *text*.

        Uses ceiling division: floor division under-counts (any text shorter
        than ``chars_per_token`` estimated at 0 tokens), which would let
        callers silently exceed ``max_injected_tokens`` — violating the
        module's stated invariant. An empty string is still 0 tokens.
        """
        return -(-len(text) // self.chars_per_token)
65
+
66
+
67
+ # ═══════════════════════════════════════════════════════════════
68
+ # Memory Archive — cold storage
69
+ # ═══════════════════════════════════════════════════════════════
70
+
71
class MemoryArchive:
    """
    Append-only cold storage for archived memories.

    Archived memories are never injected into prompts but remain
    recoverable by source_trace_id for audit, replay, or re-promotion.
    Entries live in memory; when *path* is given they are mirrored to a
    JSONL file (one JSON object per line) and reloaded on construction.
    """

    def __init__(self, path: str | None = None):
        self._path = Path(path) if path else None
        self._archived: list[dict[str, Any]] = []  # in-memory index of all entries
        if self._path and self._path.exists():
            self._load()

    def archive(self, card: MemoryCard, reason: str = "") -> None:
        """Move a card to cold storage, recording when and why it was archived."""
        entry = {
            "id": card.id,
            "kind": card.kind.value,
            "pattern": card.pattern,
            "strategy": card.strategy,
            "content": card.content,
            "source_trace_id": card.source_trace_id,
            "trust_score": card.trust_score,
            "utility_score": card.utility_score,
            "times_retrieved": card.times_retrieved,
            "archived_at": time.time(),
            "reason": reason,
        }
        self._archived.append(entry)
        if self._path:
            self._append(entry)

    def recover(self, card_id: str) -> dict[str, Any] | None:
        """Recover an archived card by ID; returns None when absent."""
        for entry in self._archived:
            if entry["id"] == card_id:
                return entry
        return None

    def recover_by_trace(self, trace_id: str) -> list[dict[str, Any]]:
        """Recover all archived cards from a specific trace."""
        return [e for e in self._archived if e.get("source_trace_id") == trace_id]

    @property
    def size(self) -> int:
        """Number of archived entries currently held."""
        return len(self._archived)

    def _append(self, entry: dict) -> None:
        """Append one entry to the JSONL file, creating parent dirs as needed."""
        if not self._path:
            return
        self._path.parent.mkdir(parents=True, exist_ok=True)
        # Fix: explicit UTF-8 so archives round-trip regardless of the
        # platform's default locale encoding (e.g. cp1252 on Windows).
        with open(self._path, "a", encoding="utf-8") as f:
            f.write(json.dumps(entry, default=str) + "\n")

    def _load(self) -> None:
        """Reload entries from the JSONL file; corrupt lines are skipped."""
        if not self._path or not self._path.exists():
            return
        with open(self._path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line:
                    try:
                        self._archived.append(json.loads(line))
                    except json.JSONDecodeError:
                        # Best-effort: one truncated/corrupt line must not
                        # block recovery of the rest of the archive.
                        pass
137
+
138
+
139
+ # ═══════════════════════════════════════════════════════════════
140
+ # Consolidation Engine
141
+ # ═══════════════════════════════════════════════════════════════
142
+
143
class ConsolidationEngine:
    """
    Clusters, merges, compresses, and hibernates memories.

    Operations:
    - cluster: group similar episodic_case cards by pattern similarity
    - merge: promote repeated patterns into a single skill_card
    - compress: shorten singleton low-utility cases to signatures
    - hibernate: deactivate unused skill_cards (recoverable)

    All operations preserve source_trace_id for audit trail.
    """

    def __init__(self, store: MemoryStore, archive: MemoryArchive, budget: MemoryBudget):
        self.store = store
        self.archive = archive
        self.budget = budget
        self._consolidation_count = 0  # total cycles run (read by homeostasis stats)

    def run(self) -> dict[str, int]:
        """
        Run a full consolidation cycle. Returns counts of actions taken.

        NOTE: "clustered" and "compressed" are reserved for future
        operations and currently always stay 0.
        """
        results = {"clustered": 0, "merged": 0, "compressed": 0, "hibernated": 0, "archived": 0}

        # 1. Merge clusters of similar episodic cases into skill cards
        results["merged"] = self._merge_similar_episodics()

        # 2. Hibernate skill cards that are retrieved often but rarely help
        results["hibernated"] = self._hibernate_unused()

        # 3. Archive lowest-utility cards when over the global budget
        results["archived"] = self._archive_over_budget()

        # 4. Enforce per-kind limits on top of the global limit
        results["archived"] += self._enforce_kind_limits()

        self._consolidation_count += 1
        # Lazy %-args: avoid formatting when INFO is disabled.
        logger.info("Consolidation #%d: %s", self._consolidation_count, results)
        return results

    def _merge_similar_episodics(self) -> int:
        """Merge groups of >= 3 same-pattern episodic cases into one skill card each."""
        episodics = [c for c in self.store.get_all()
                     if c.kind == MemoryKind.EPISODIC_CASE and c.status == MemoryStatus.PROMOTED]

        if len(episodics) < 3:
            return 0

        # Group by pattern similarity (simple: normalized pattern prefix)
        groups: dict[str, list[MemoryCard]] = defaultdict(list)
        for card in episodics:
            key = card.pattern.lower().strip()[:50]  # Rough grouping key
            if not key:
                # Fix: cards with an empty/whitespace pattern are unrelated to
                # each other; grouping them under "" would merge arbitrary cases.
                continue
            groups[key].append(card)

        merged = 0
        for cards in groups.values():
            if len(cards) < 3:
                continue
            # Merge into a skill card: keep the most conservative trust score
            # and the average utility of the originals.
            avg_utility = sum(c.utility_score for c in cards) / len(cards)
            merged_card = MemoryCard(
                kind=MemoryKind.SKILL_CARD,
                status=MemoryStatus.PROMOTED,
                pattern=cards[0].pattern,
                strategy=f"[CONSOLIDATED from {len(cards)} cases] " + cards[0].strategy,
                trust_score=min(c.trust_score for c in cards),
                utility_score=avg_utility,
                source_trace_id=cards[0].source_trace_id,
                created_by="consolidation",
            )
            self.store.add(merged_card)

            # Archive the original episodics (recoverable via trace id)
            for card in cards:
                self.store.update_status(card.id, MemoryStatus.ARCHIVED, "consolidated")
                self.archive.archive(card, f"merged into {merged_card.id}")

            merged += 1

        return merged

    def _hibernate_unused(self) -> int:
        """Hibernate skill cards that are retrieved often but rarely help."""
        hibernated = 0
        for card in self.store.get_all():
            if card.status != MemoryStatus.PROMOTED:
                continue
            if card.kind != MemoryKind.SKILL_CARD:
                continue
            # Hibernate if: retrieved many times but rarely helped
            if card.times_retrieved >= 10 and card.utility_score < 0.2:
                self.store.update_status(card.id, MemoryStatus.ARCHIVED, "hibernated: low utility")
                self.archive.archive(card, "hibernated")
                hibernated += 1

        return hibernated

    def _archive_over_budget(self) -> int:
        """Archive lowest-utility cards when over max_active_cards."""
        active = self.store.get_by_status(MemoryStatus.PROMOTED)
        if len(active) <= self.budget.max_active_cards:
            return 0

        # Sort by utility (lowest first) and archive the excess
        active.sort(key=lambda c: c.utility_score)
        excess = len(active) - self.budget.max_active_cards
        archived = 0

        for card in active[:excess]:
            self.store.update_status(card.id, MemoryStatus.ARCHIVED, "budget: over max_active")
            self.archive.archive(card, "budget overflow")
            archived += 1

        return archived

    def _enforce_kind_limits(self) -> int:
        """Archive lowest-utility cards of any kind that exceeds its cap."""
        archived = 0
        # Hoisted: one store fetch for all kinds instead of one per kind.
        # Safe because kinds partition the cards — a card archived under one
        # kind can never re-appear under another.
        all_cards = self.store.get_all()
        for kind_str, limit in self.budget.max_cards_per_kind.items():
            try:
                kind = MemoryKind(kind_str)
            except ValueError:
                continue  # budget names a kind this MemoryKind enum lacks

            cards = [c for c in all_cards
                     if c.kind == kind and c.status == MemoryStatus.PROMOTED]

            if len(cards) <= limit:
                continue

            # Remove lowest utility first
            cards.sort(key=lambda c: c.utility_score)
            for card in cards[:len(cards) - limit]:
                self.store.update_status(card.id, MemoryStatus.ARCHIVED, f"kind_limit: {kind_str}")
                self.archive.archive(card, f"kind limit ({kind_str})")
                archived += 1

        return archived
281
+
282
+
283
+ # ═══════════════════════════════════════════════════════════════
284
+ # Q-Function Retriever — budget-aware ranking
285
+ # ═══════════════════════════════════════════════════════════════
286
+
287
class QFunctionRetriever:
    """
    Budget-aware memory retriever with multi-signal ranking.

    score = relevance * trust * utility * recency_decay * scope_match * diversity_penalty

    Guarantees: injected tokens NEVER exceed budget.max_injected_tokens.
    """

    def __init__(self, store: MemoryStore, budget: MemoryBudget):
        self.store = store
        self.budget = budget

    def retrieve(
        self,
        query: str,
        scope: MemoryScope | None = None,
        max_cards: int = 15,
    ) -> list[MemoryCard]:
        """
        Retrieve memories ranked by composite score, bounded by token budget.

        Returns only PROMOTED memories that fit within max_injected_tokens.
        """
        # Over-fetch so diversity filtering still leaves enough candidates.
        candidates = self.store.retrieve(
            query_text=query,
            scope=scope,
            statuses=[MemoryStatus.PROMOTED],
            top_k=max_cards * 3,
        )

        # Re-rank every candidate with the full Q-function, best first.
        now = time.time()
        ranked = sorted(
            ((self._compute_score(card, query, now), card) for card in candidates),
            key=lambda pair: pair[0],
            reverse=True,
        )

        chosen: list[MemoryCard] = []
        used_tokens = 0
        seen_keys: set[str] = set()

        for _score, card in ranked:
            # Diversity: drop near-duplicates sharing a normalized prefix.
            dedupe_key = (card.pattern or card.content or "")[:30].lower()
            if dedupe_key in seen_keys:
                continue
            seen_keys.add(dedupe_key)

            # Stop once the next card would exceed the token budget.
            body = f"{card.pattern} {card.strategy} {' '.join(card.steps)}"
            cost = self.budget.estimate_tokens(body)
            if used_tokens + cost > self.budget.max_injected_tokens:
                break

            chosen.append(card)
            used_tokens += cost
            if len(chosen) >= max_cards:
                break

        return chosen

    def _compute_score(self, card: MemoryCard, query: str, now: float) -> float:
        """
        Composite Q-function score:
            score = relevance * trust * utility * recency_decay
        then boosted by the card's historical help rate.
        """
        utility = card.utility_score

        # Relevance proxy (store.retrieve already did the text matching).
        relevance = 0.5 + utility * 0.5

        # Recency: linear decay over a year, floored at 0.3.
        age_days = (now - card.created_at) / 86400
        recency = max(0.3, 1.0 - (age_days / 365))

        score = relevance * card.trust_score * utility * recency

        # Cards that actually helped when retrieved earn up to a 1.5x boost.
        if card.times_retrieved > 0 and card.times_helped > 0:
            help_rate = card.times_helped / card.times_retrieved
            score *= (1.0 + help_rate * 0.5)

        return score
379
+
380
+
381
+ # ═══════════════════════════════════════════════════════════════
382
+ # Homeostasis Controller — ties everything together
383
+ # ═══════════════════════════════════════════════════════════════
384
+
385
class MemoryHomeostasis:
    """
    Main controller that keeps memory bounded and healthy.

    Usage:
        homeostasis = MemoryHomeostasis(store, budget=MemoryBudget(max_active_cards=256))

        # After each task:
        homeostasis.check_and_consolidate()

        # Manual trigger:
        homeostasis.force_consolidation()

        # Budget-aware retrieval:
        memories = homeostasis.retrieve("query", scope=scope)
    """

    def __init__(
        self,
        store: MemoryStore,
        budget: MemoryBudget | None = None,
        archive_path: str | None = None,
    ):
        self.store = store
        self.budget = budget if budget is not None else MemoryBudget()
        self.archive = MemoryArchive(archive_path)
        self.consolidation = ConsolidationEngine(store, self.archive, self.budget)
        self.retriever = QFunctionRetriever(store, self.budget)
        # Count of memories added since the last consolidation cycle.
        self._new_since_consolidation = 0

    def on_memory_added(self) -> None:
        """Called after a new memory is added. Triggers consolidation if threshold met."""
        self._new_since_consolidation += 1
        if self._new_since_consolidation >= self.budget.consolidation_threshold:
            self.check_and_consolidate()

    def check_and_consolidate(self) -> dict[str, int] | None:
        """Check if consolidation is needed and run it if so; None when skipped."""
        promoted = self.store.get_by_status(MemoryStatus.PROMOTED)
        over_budget = len(promoted) > self.budget.max_active_cards
        threshold_hit = self._new_since_consolidation >= self.budget.consolidation_threshold
        if not (over_budget or threshold_hit):
            return None
        self._new_since_consolidation = 0
        return self.consolidation.run()

    def force_consolidation(self) -> dict[str, int]:
        """Force a consolidation cycle regardless of thresholds."""
        self._new_since_consolidation = 0
        return self.consolidation.run()

    def retrieve(self, query: str, scope: MemoryScope | None = None, max_cards: int = 10) -> list[MemoryCard]:
        """Budget-aware retrieval. Guarantees token budget is respected."""
        return self.retriever.retrieve(query, scope, max_cards)

    @property
    def stats(self) -> dict[str, Any]:
        """Snapshot of memory health for dashboards and logging."""
        active = len(self.store.get_by_status(MemoryStatus.PROMOTED))
        cap = self.budget.max_active_cards
        return {
            "active_cards": active,
            "max_active": cap,
            "utilization": f"{active / cap:.0%}" if cap else "0%",
            "archived": self.archive.size,
            "consolidations_run": self.consolidation._consolidation_count,
            "new_since_last": self._new_since_consolidation,
        }