Rohan03 committed (verified)
Commit 455fdee
Parent(s): fedfb2e

Add purpose_agent/experience_replay.py

Files changed (1): purpose_agent/experience_replay.py (+407, -0)
purpose_agent/experience_replay.py ADDED
@@ -0,0 +1,407 @@
"""
Experience Replay — Trajectory storage and retrieval.

Stores completed trajectories with their scores and supports retrieval
ranked by a combination of:
1. Semantic similarity (embedding-based) — from MemRL (arxiv:2601.03192)
2. Learned Q-value utility scores — from REMEMBERER (arxiv:2306.07929)

The two-phase retrieval (recall by similarity → re-rank by Q-value) separates
"semantically similar" from "functionally useful" — a key insight from MemRL.

This module is the "database" — it stores but doesn't analyze. The Optimizer
module reads from here and writes heuristics back.
"""

from __future__ import annotations

import json
import logging
import math
import time
import zlib
from pathlib import Path
from typing import Any

from purpose_agent.types import (
    Heuristic,
    MemoryRecord,
    MemoryTier,
    Trajectory,
    TrajectoryStep,
)

logger = logging.getLogger(__name__)

class ExperienceReplay:
    """
    Experience Replay buffer with two-phase retrieval.

    Phase 1 (Recall): Retrieve top-k records by semantic similarity to query
    Phase 2 (Re-rank): Re-order by learned Q-value utility scores

    The buffer supports:
    - Adding trajectories with automatic scoring
    - Retrieving similar past experiences for the Actor's context
    - Q-value updates after heuristics are applied (Bellman-style)
    - Persistence to disk (JSON)
    - Capacity management (evict lowest Q-value records when full)

    Args:
        capacity: Maximum number of records to store
        similarity_weight: λ in retrieval score = λ·similarity + (1-λ)·q_value
        persistence_path: If set, auto-save/load buffer to this file
    """

    def __init__(
        self,
        capacity: int = 500,
        similarity_weight: float = 0.6,
        persistence_path: str | Path | None = None,
    ):
        self.capacity = capacity
        self.similarity_weight = similarity_weight
        self.persistence_path = Path(persistence_path) if persistence_path else None
        self.records: list[MemoryRecord] = []

        # Load from disk if available
        if self.persistence_path and self.persistence_path.exists():
            self._load()

    # ------------------------------------------------------------------
    # Core Operations
    # ------------------------------------------------------------------

    def add(self, trajectory: Trajectory) -> MemoryRecord:
        """
        Add a completed trajectory to the buffer.

        Automatically computes a task embedding (a simple character-trigram
        hash; see _compute_embedding) and an initial Q-value based on
        trajectory performance.
        """
        # Compute initial Q-value from trajectory performance
        initial_q = self._compute_initial_q(trajectory)

        record = MemoryRecord(
            trajectory=trajectory,
            heuristics=[],
            task_embedding=self._compute_embedding(trajectory.task_description),
            retrieval_q_value=initial_q,
        )

        # Capacity management: evict lowest Q-value if full
        if len(self.records) >= self.capacity:
            self._evict()

        self.records.append(record)
        logger.info(
            f"Experience Replay: Added trajectory '{trajectory.id}' "
            f"(q={initial_q:.3f}, steps={len(trajectory.steps)}, "
            f"Σreward={trajectory.cumulative_reward:.2f})"
        )

        if self.persistence_path:
            self._save()

        return record

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        min_q_value: float = 0.0,
    ) -> list[MemoryRecord]:
        """
        Two-phase retrieval (per MemRL arxiv:2601.03192):

        Phase 1: Recall candidates by semantic similarity
        Phase 2: Re-rank by Q-value utility

        Returns top-k records sorted by combined score.
        """
        if not self.records:
            return []

        query_embedding = self._compute_embedding(query)

        # Phase 1: Compute similarity scores for all records
        scored: list[tuple[float, MemoryRecord]] = []
        for record in self.records:
            if record.retrieval_q_value < min_q_value:
                continue
            sim = self._cosine_similarity(
                query_embedding, record.task_embedding or []
            )
            # Phase 2: Combined score
            combined = (
                self.similarity_weight * sim
                + (1 - self.similarity_weight) * record.retrieval_q_value
            )
            scored.append((combined, record))

        # Sort descending by combined score
        scored.sort(key=lambda x: -x[0])

        results = [record for _, record in scored[:top_k]]
        if scored:
            logger.debug(
                f"Experience Replay: Retrieved {len(results)} records for query "
                f"(top score={scored[0][0]:.3f})"
            )
        else:
            logger.debug("Experience Replay: No records matched the query")
        return results

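    # Worked example of the combined score (illustrative numbers, not from
    # any real run): with the default similarity_weight λ = 0.6, a record
    # with sim = 0.8 and q = 0.5 scores 0.6 * 0.8 + 0.4 * 0.5 = 0.68, while
    # a more similar but less useful record with sim = 0.9 and q = 0.2
    # scores 0.6 * 0.9 + 0.4 * 0.2 = 0.62 and ranks below it.
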
    def update_q_value(
        self,
        record_id: str,
        reward: float,
        alpha: float = 0.1,
    ) -> None:
        """
        Update a record's retrieval Q-value using a Monte Carlo update.

        Q_new = Q_old + α * (reward - Q_old)

        From REMEMBERER (arxiv:2306.07929): α = 1/N where N = number of
        updates. We use a fixed α for simplicity; override for REMEMBERER-exact.
        """
        for record in self.records:
            if record.id == record_id:
                old_q = record.retrieval_q_value
                record.retrieval_q_value += alpha * (reward - old_q)
                record.retrieval_q_value = max(0.0, min(1.0, record.retrieval_q_value))
                logger.debug(
                    f"Experience Replay: Q-value update for {record_id}: "
                    f"{old_q:.3f} → {record.retrieval_q_value:.3f}"
                )
                if self.persistence_path:
                    self._save()
                return
        logger.warning(f"Experience Replay: Record {record_id} not found for Q-update")

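    # Worked example (illustrative): a record at Q = 0.50 that receives
    # reward = 1.0 with α = 0.1 moves to 0.50 + 0.1 * (1.0 - 0.50) = 0.55;
    # repeated rewards converge it toward 1.0, and failures pull it back.
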
    def attach_heuristics(
        self, record_id: str, heuristics: list[Heuristic]
    ) -> None:
        """Attach extracted heuristics to a memory record."""
        for record in self.records:
            if record.id == record_id:
                record.heuristics = heuristics
                if self.persistence_path:
                    self._save()
                return

    # ------------------------------------------------------------------
    # Statistics & Queries
    # ------------------------------------------------------------------

    def get_top_trajectories(
        self,
        n: int = 10,
        min_success_rate: float = 0.5,
    ) -> list[Trajectory]:
        """Get the n best trajectories by cumulative reward, keeping only
        those at or above min_success_rate."""
        candidates = [
            r.trajectory for r in self.records
            if r.trajectory.success_rate >= min_success_rate
        ]
        candidates.sort(key=lambda t: -t.cumulative_reward)
        return candidates[:n]

    def get_all_heuristics(self, tier: MemoryTier | None = None) -> list[Heuristic]:
        """Get all extracted heuristics, optionally filtered by tier."""
        heuristics = []
        for record in self.records:
            for h in record.heuristics:
                if tier is None or h.tier == tier:
                    heuristics.append(h)
        return heuristics

    @property
    def size(self) -> int:
        return len(self.records)

    @property
    def stats(self) -> dict[str, Any]:
        if not self.records:
            return {"size": 0}
        q_values = [r.retrieval_q_value for r in self.records]
        rewards = [r.trajectory.cumulative_reward for r in self.records]
        return {
            "size": len(self.records),
            "avg_q_value": sum(q_values) / len(q_values),
            "max_q_value": max(q_values),
            "avg_cumulative_reward": sum(rewards) / len(rewards),
            "total_heuristics": sum(len(r.heuristics) for r in self.records),
        }

    # ------------------------------------------------------------------
    # Embedding & Similarity (lightweight — no external deps)
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_embedding(text: str) -> list[float]:
        """
        Lightweight text embedding using character n-gram hashing.

        This is intentionally simple — for production, swap in a real
        embedding model (sentence-transformers, OpenAI embeddings, etc.).

        To use real embeddings, subclass ExperienceReplay and override
        _compute_embedding() and _cosine_similarity().
        """
        # Character trigram hashing into a fixed-size vector. Use crc32
        # rather than the built-in hash(): string hashing is salted per
        # process (PYTHONHASHSEED), so persisted embeddings would not be
        # comparable to query embeddings computed after a restart.
        dim = 128
        vec = [0.0] * dim
        text_lower = text.lower()
        for i in range(len(text_lower) - 2):
            trigram = text_lower[i:i + 3]
            h = zlib.crc32(trigram.encode("utf-8")) % dim
            vec[h] += 1.0

        # L2 normalize
        magnitude = math.sqrt(sum(x * x for x in vec))
        if magnitude > 0:
            vec = [x / magnitude for x in vec]
        return vec

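    # For example, "auth fix" yields the trigrams "aut", "uth", "th ",
    # "h f", " fi", "fix"; each increments one of the 128 buckets, and the
    # final L2 normalization makes cosine similarity a plain dot product
    # of unit vectors.
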
    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Cosine similarity between two vectors."""
        if not a or not b or len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = math.sqrt(sum(x * x for x in a))
        mag_b = math.sqrt(sum(x * x for x in b))
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)

    # ------------------------------------------------------------------
    # Initial Q-Value Estimation
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_initial_q(trajectory: Trajectory) -> float:
        """
        Compute initial Q-value from trajectory performance.

        Uses a combination of:
        - Success rate (fraction of steps that improved state)
        - Total delta (net improvement)
        - Trajectory length efficiency (shorter = better for same delta)
        """
        if not trajectory.steps:
            return 0.3  # Uninformative prior

        success_rate = trajectory.success_rate
        total_delta = trajectory.total_delta
        length = len(trajectory.steps)

        # Normalize total_delta to 0-1 (assuming max meaningful delta is ~10)
        delta_normalized = max(0.0, min(1.0, total_delta / 10.0))

        # Efficiency bonus: more progress per step = higher Q
        efficiency = delta_normalized / max(length, 1)

        q = 0.4 * success_rate + 0.4 * delta_normalized + 0.2 * min(efficiency * 5, 1.0)
        return max(0.0, min(1.0, q))

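    # Worked example (illustrative): success_rate = 0.75, total_delta = 4.0,
    # 8 steps. Then delta_normalized = 0.4, efficiency = 0.4 / 8 = 0.05, so
    # q = 0.4 * 0.75 + 0.4 * 0.4 + 0.2 * min(0.25, 1.0) = 0.51.
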
    # ------------------------------------------------------------------
    # Capacity Management
    # ------------------------------------------------------------------

    def _evict(self) -> None:
        """Evict the lowest Q-value record."""
        if not self.records:
            return
        worst = min(self.records, key=lambda r: r.retrieval_q_value)
        self.records.remove(worst)
        logger.debug(
            f"Experience Replay: Evicted record {worst.id} "
            f"(q={worst.retrieval_q_value:.3f})"
        )

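    # Eviction is utility-based rather than FIFO: when the buffer is full,
    # the record with the lowest retrieval Q-value is dropped, so memories
    # that keep earning reward survive longer than stale ones.
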
    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _save(self) -> None:
        """
        Save buffer to disk as JSON.

        Note: individual trajectory steps are not serialized, only summary
        statistics (num_steps, cumulative_reward, etc.), so reloaded
        trajectories carry metadata but not their step-by-step history.
        """
        if not self.persistence_path:
            return
        self.persistence_path.parent.mkdir(parents=True, exist_ok=True)

        data = []
        for record in self.records:
            data.append({
                "id": record.id,
                "retrieval_q_value": record.retrieval_q_value,
                "task_embedding": record.task_embedding,
                "trajectory": {
                    "id": record.trajectory.id,
                    "task_description": record.trajectory.task_description,
                    "purpose": record.trajectory.purpose,
                    "created_at": record.trajectory.created_at,
                    "cumulative_reward": record.trajectory.cumulative_reward,
                    "total_delta": record.trajectory.total_delta,
                    "success_rate": record.trajectory.success_rate,
                    "num_steps": len(record.trajectory.steps),
                },
                "heuristics": [
                    {
                        "id": h.id,
                        "pattern": h.pattern,
                        "strategy": h.strategy,
                        "steps": h.steps,
                        "tier": h.tier.value,
                        "q_value": h.q_value,
                        "times_used": h.times_used,
                        "times_succeeded": h.times_succeeded,
                    }
                    for h in record.heuristics
                ],
            })

        with open(self.persistence_path, "w") as f:
            json.dump(data, f, indent=2, default=str)

    def _load(self) -> None:
        """Load buffer from disk."""
        if not self.persistence_path or not self.persistence_path.exists():
            return
        try:
            with open(self.persistence_path) as f:
                data = json.load(f)

            for entry in data:
                traj_data = entry["trajectory"]
                trajectory = Trajectory(
                    task_description=traj_data["task_description"],
                    purpose=traj_data["purpose"],
                    id=traj_data["id"],
                    created_at=traj_data.get("created_at", time.time()),
                )
                heuristics = [
                    Heuristic(
                        id=h["id"],
                        pattern=h["pattern"],
                        strategy=h["strategy"],
                        steps=h["steps"],
                        tier=MemoryTier(h["tier"]),
                        q_value=h["q_value"],
                        times_used=h.get("times_used", 0),
                        times_succeeded=h.get("times_succeeded", 0),
                    )
                    for h in entry.get("heuristics", [])
                ]
                record = MemoryRecord(
                    id=entry["id"],
                    trajectory=trajectory,
                    heuristics=heuristics,
                    task_embedding=entry.get("task_embedding"),
                    retrieval_q_value=entry.get("retrieval_q_value", 0.5),
                )
                self.records.append(record)

            logger.info(f"Experience Replay: Loaded {len(self.records)} records from disk")
        except Exception as e:
            logger.error(f"Experience Replay: Failed to load from disk: {e}")
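

# ----------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# It assumes Trajectory can be constructed the same way _load() does above
# (keyword args task_description and purpose, with steps defaulting to
# empty); adjust to the real purpose_agent.types signatures if they differ.
# ----------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    replay = ExperienceReplay(capacity=100, similarity_weight=0.6)
    replay.add(Trajectory(task_description="refactor the auth module", purpose="improve code quality"))
    replay.add(Trajectory(task_description="write unit tests for the auth module", purpose="improve code quality"))

    # Phase 1 recalls by trigram similarity; Phase 2 re-ranks by Q-value.
    for record in replay.retrieve("add tests to the auth module", top_k=2):
        print(f"{record.trajectory.task_description!r} q={record.retrieval_q_value:.2f}")

    # After a retrieved memory proves useful, reinforce it.
    top = replay.retrieve("add tests to the auth module", top_k=1)
    if top:
        replay.update_q_value(top[0].id, reward=1.0)
    print(replay.stats)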