"""
Experience Replay — Trajectory storage and retrieval.

Stores completed trajectories with their scores and supports retrieval
ranked by a combination of:
  1. Semantic similarity (embedding-based) — from MemRL (arxiv:2601.03192)
  2. Learned Q-value utility scores — from REMEMBERER (arxiv:2306.07929)

The two-phase retrieval (recall by similarity → re-rank by Q-value) separates
"semantically similar" from "functionally useful" — a key insight from MemRL.

This module is the "database" — it stores but doesn't analyze. The Optimizer
module reads from here and writes heuristics back.
"""

from __future__ import annotations

import json
import logging
import math
import time
import zlib
from pathlib import Path
from typing import Any

from purpose_agent.types import (
    Heuristic,
    MemoryRecord,
    MemoryTier,
    Trajectory,
)

logger = logging.getLogger(__name__)


class ExperienceReplay:
    """
    Experience Replay buffer with two-phase retrieval.
    
    Phase 1 (Recall): Retrieve top-k records by semantic similarity to query
    Phase 2 (Re-rank): Re-order by learned Q-value utility scores
    
    The buffer supports:
    - Adding trajectories with automatic scoring
    - Retrieving similar past experiences for the Actor's context
    - Q-value updates after heuristics are applied (Bellman-style)
    - Persistence to disk (JSON)
    - Capacity management (evict lowest Q-value records when full)
    
    Args:
        capacity: Maximum number of records to store
        similarity_weight: λ in retrieval score = λ·similarity + (1-λ)·q_value
        persistence_path: If set, auto-save/load buffer to this file
    """

    def __init__(
        self,
        capacity: int = 500,
        similarity_weight: float = 0.6,
        persistence_path: str | Path | None = None,
    ):
        self.capacity = capacity
        self.similarity_weight = similarity_weight
        self.persistence_path = Path(persistence_path) if persistence_path else None
        self.records: list[MemoryRecord] = []

        # Load from disk if available
        if self.persistence_path and self.persistence_path.exists():
            self._load()

    # ------------------------------------------------------------------
    # Core Operations
    # ------------------------------------------------------------------

    def add(self, trajectory: Trajectory) -> MemoryRecord:
        """
        Add a completed trajectory to the buffer.
        
        Automatically computes a task embedding (a character-trigram hash;
        see _compute_embedding) and an initial Q-value based on trajectory
        performance.
        """
        # Compute initial Q-value from trajectory performance
        initial_q = self._compute_initial_q(trajectory)

        record = MemoryRecord(
            trajectory=trajectory,
            heuristics=[],
            task_embedding=self._compute_embedding(trajectory.task_description),
            retrieval_q_value=initial_q,
        )

        # Capacity management: evict lowest Q-value if full
        if len(self.records) >= self.capacity:
            self._evict()

        self.records.append(record)
        logger.info(
            f"Experience Replay: Added trajectory '{trajectory.id}' "
            f"(q={initial_q:.3f}, steps={len(trajectory.steps)}, "
            f"Σreward={trajectory.cumulative_reward:.2f})"
        )

        if self.persistence_path:
            self._save()

        return record

    def retrieve(
        self,
        query: str,
        top_k: int = 5,
        min_q_value: float = 0.0,
    ) -> list[MemoryRecord]:
        """
        Two-phase retrieval (per MemRL arxiv:2601.03192):
        
        Phase 1: Recall candidates by semantic similarity
        Phase 2: Re-rank by Q-value utility
        
        Returns top-k records sorted by combined score.
        """
        if not self.records:
            return []

        query_embedding = self._compute_embedding(query)

        # Phase 1: Compute similarity scores for all records
        scored: list[tuple[float, MemoryRecord]] = []
        for record in self.records:
            if record.retrieval_q_value < min_q_value:
                continue
            sim = self._cosine_similarity(
                query_embedding, record.task_embedding or []
            )
            # Phase 2: Combined score
            combined = (
                self.similarity_weight * sim
                + (1 - self.similarity_weight) * record.retrieval_q_value
            )
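            # e.g. with λ = 0.6: a near-duplicate record (sim = 0.9, q = 0.2)
            # scores 0.6·0.9 + 0.4·0.2 = 0.62, while a moderately similar but
            # proven-useful one (sim = 0.5, q = 0.9) scores 0.66 and outranks it.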
            scored.append((combined, record))

        # Sort descending by combined score
        scored.sort(key=lambda x: -x[0])

        results = [record for _, record in scored[:top_k]]
        if scored:
            logger.debug(
                f"Experience Replay: Retrieved {len(results)} records for query "
                f"(top score={scored[0][0]:.3f})"
            )
        else:
            logger.debug("Experience Replay: No records matched the query")
        return results

    def update_q_value(
        self,
        record_id: str,
        reward: float,
        alpha: float = 0.1,
    ) -> None:
        """
        Update a record's retrieval Q-value with a Monte Carlo-style update:

        Q_new = Q_old + α * (reward - Q_old)

        REMEMBERER (arxiv:2306.07929) uses α = 1/N, where N is the number of
        updates so far; we use a fixed α for simplicity, which makes this an
        exponential moving average. Pass α = 1/N per call to match REMEMBERER.
        """
        for record in self.records:
            if record.id == record_id:
                old_q = record.retrieval_q_value
                record.retrieval_q_value += alpha * (reward - old_q)
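                # e.g. Q = 0.50, reward = 0.90, α = 0.1:
                #   Q' = 0.50 + 0.1 · (0.90 - 0.50) = 0.54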
                record.retrieval_q_value = max(0.0, min(1.0, record.retrieval_q_value))
                logger.debug(
                    f"Experience Replay: Q-value update for {record_id}: "
                    f"{old_q:.3f}{record.retrieval_q_value:.3f}"
                )
                if self.persistence_path:
                    self._save()
                return
        logger.warning(f"Experience Replay: Record {record_id} not found for Q-update")

    def attach_heuristics(
        self, record_id: str, heuristics: list[Heuristic]
    ) -> None:
        """Attach extracted heuristics to a memory record."""
        for record in self.records:
            if record.id == record_id:
                record.heuristics = heuristics
                if self.persistence_path:
                    self._save()
                return
        logger.warning(
            f"Experience Replay: Record {record_id} not found to attach heuristics"
        )

    # ------------------------------------------------------------------
    # Statistics & Queries
    # ------------------------------------------------------------------

    def get_top_trajectories(
        self,
        n: int = 10,
        min_success_rate: float = 0.5,
    ) -> list[Trajectory]:
        """Get the n best trajectories by cumulative reward."""
        candidates = [
            r.trajectory for r in self.records
            if r.trajectory.success_rate >= min_success_rate
        ]
        candidates.sort(key=lambda t: -t.cumulative_reward)
        return candidates[:n]

    def get_all_heuristics(self, tier: MemoryTier | None = None) -> list[Heuristic]:
        """Get all extracted heuristics, optionally filtered by tier."""
        heuristics = []
        for record in self.records:
            for h in record.heuristics:
                if tier is None or h.tier == tier:
                    heuristics.append(h)
        return heuristics

    @property
    def size(self) -> int:
        return len(self.records)

    def clear(self) -> None:
        """Reset the replay buffer. Removes all records and persists the empty state."""
        self.records.clear()
        if self.persistence_path:
            self._save()
        logger.info("Experience Replay: cleared all records")

    @property
    def stats(self) -> dict[str, Any]:
        """Summary statistics over the buffer; just {"size": 0} when empty."""
        if not self.records:
            return {"size": 0}
        q_values = [r.retrieval_q_value for r in self.records]
        rewards = [r.trajectory.cumulative_reward for r in self.records]
        return {
            "size": len(self.records),
            "avg_q_value": sum(q_values) / len(q_values),
            "max_q_value": max(q_values),
            "avg_cumulative_reward": sum(rewards) / len(rewards),
            "total_heuristics": sum(len(r.heuristics) for r in self.records),
        }

    # ------------------------------------------------------------------
    # Embedding & Similarity (lightweight — no external deps)
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_embedding(text: str) -> list[float]:
        """
        Lightweight text embedding using character n-gram hashing.
        
        This is intentionally simple — for production, swap in a real
        embedding model (sentence-transformers, OpenAI embeddings, etc.).
        
        To use real embeddings, subclass ExperienceReplay and override
        _compute_embedding(); _cosine_similarity() already works on any
        equal-length vectors (see the sketch at the end of this module).
        """
        # Character trigram hashing into a fixed-size vector
        dim = 128
        vec = [0.0] * dim
        text_lower = text.lower()
        for i in range(len(text_lower) - 2):
            trigram = text_lower[i:i + 3]
            # zlib.crc32 is stable across processes, unlike built-in hash(),
            # whose per-run randomization would invalidate persisted embeddings.
            h = zlib.crc32(trigram.encode("utf-8")) % dim
            vec[h] += 1.0
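        # e.g. "plan" yields trigrams "pla" and "lan", each bumping one of
        # the 128 buckets before L2 normalization.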

        # L2 normalize
        magnitude = math.sqrt(sum(x * x for x in vec))
        if magnitude > 0:
            vec = [x / magnitude for x in vec]
        return vec

    @staticmethod
    def _cosine_similarity(a: list[float], b: list[float]) -> float:
        """Cosine similarity between two vectors."""
        if not a or not b or len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        mag_a = math.sqrt(sum(x * x for x in a))
        mag_b = math.sqrt(sum(x * x for x in b))
        if mag_a == 0 or mag_b == 0:
            return 0.0
        return dot / (mag_a * mag_b)

    # ------------------------------------------------------------------
    # Initial Q-Value Estimation
    # ------------------------------------------------------------------

    @staticmethod
    def _compute_initial_q(trajectory: Trajectory) -> float:
        """
        Compute initial Q-value from trajectory performance.
        
        Uses a combination of:
        - Success rate (fraction of steps that improved state)
        - Total delta (net improvement)
        - Trajectory length efficiency (shorter = better for same delta)
        """
        if not trajectory.steps:
            return 0.3  # Uninformative prior

        success_rate = trajectory.success_rate
        total_delta = trajectory.total_delta
        length = len(trajectory.steps)

        # Normalize total_delta to 0-1 (assuming max meaningful delta is ~10)
        delta_normalized = max(0.0, min(1.0, total_delta / 10.0))

        # Efficiency bonus: more progress per step = higher Q
        efficiency = delta_normalized / max(length, 1)

        q = 0.4 * success_rate + 0.4 * delta_normalized + 0.2 * min(efficiency * 5, 1.0)
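        # e.g. success_rate = 0.5, total_delta = 4.0 over 8 steps:
        #   delta_normalized = 0.40, efficiency = 0.05 (bonus term 0.25),
        #   q = 0.4·0.5 + 0.4·0.40 + 0.2·0.25 = 0.41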
        return max(0.0, min(1.0, q))

    # ------------------------------------------------------------------
    # Capacity Management
    # ------------------------------------------------------------------

    def _evict(self) -> None:
        """Evict the lowest Q-value record."""
        if not self.records:
            return
        worst = min(self.records, key=lambda r: r.retrieval_q_value)
        self.records.remove(worst)
        logger.debug(
            f"Experience Replay: Evicted record {worst.id} "
            f"(q={worst.retrieval_q_value:.3f})"
        )

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def _save(self) -> None:
        """Save buffer to disk as JSON."""
        if not self.persistence_path:
            return
        self.persistence_path.parent.mkdir(parents=True, exist_ok=True)
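        # On-disk format: a JSON list, one object per record, holding its id,
        # retrieval Q-value, embedding, a trajectory summary (individual steps
        # are not serialized), and any extracted heuristics.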

        data = []
        for record in self.records:
            data.append({
                "id": record.id,
                "retrieval_q_value": record.retrieval_q_value,
                "task_embedding": record.task_embedding,
                "trajectory": {
                    "id": record.trajectory.id,
                    "task_description": record.trajectory.task_description,
                    "purpose": record.trajectory.purpose,
                    "created_at": record.trajectory.created_at,
                    "cumulative_reward": record.trajectory.cumulative_reward,
                    "total_delta": record.trajectory.total_delta,
                    "success_rate": record.trajectory.success_rate,
                    "num_steps": len(record.trajectory.steps),
                },
                "heuristics": [
                    {
                        "id": h.id,
                        "pattern": h.pattern,
                        "strategy": h.strategy,
                        "steps": h.steps,
                        "tier": h.tier.value,
                        "q_value": h.q_value,
                        "times_used": h.times_used,
                        "times_succeeded": h.times_succeeded,
                    }
                    for h in record.heuristics
                ],
            })

        with open(self.persistence_path, "w") as f:
            json.dump(data, f, indent=2, default=str)

    def _load(self) -> None:
        """Load buffer from disk."""
        if not self.persistence_path or not self.persistence_path.exists():
            return
        try:
            with open(self.persistence_path) as f:
                data = json.load(f)

            for entry in data:
                traj_data = entry["trajectory"]
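                # NOTE: steps are not persisted, so a reloaded Trajectory
                # carries its summary metadata but not step-by-step history
                # or the reward statistics derived from it.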
                trajectory = Trajectory(
                    task_description=traj_data["task_description"],
                    purpose=traj_data["purpose"],
                    id=traj_data["id"],
                    created_at=traj_data.get("created_at", time.time()),
                )
                heuristics = [
                    Heuristic(
                        id=h["id"],
                        pattern=h["pattern"],
                        strategy=h["strategy"],
                        steps=h["steps"],
                        tier=MemoryTier(h["tier"]),
                        q_value=h["q_value"],
                        times_used=h.get("times_used", 0),
                        times_succeeded=h.get("times_succeeded", 0),
                    )
                    for h in entry.get("heuristics", [])
                ]
                record = MemoryRecord(
                    id=entry["id"],
                    trajectory=trajectory,
                    heuristics=heuristics,
                    task_embedding=entry.get("task_embedding"),
                    retrieval_q_value=entry.get("retrieval_q_value", 0.5),
                )
                self.records.append(record)

            logger.info(f"Experience Replay: Loaded {len(self.records)} records from disk")
        except Exception as e:
            logger.error(f"Experience Replay: Failed to load from disk: {e}")
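

# ----------------------------------------------------------------------
# Usage & Extension Sketch
# ----------------------------------------------------------------------
# A minimal usage sketch, assuming a `trajectory` built elsewhere via
# purpose_agent.types (its construction is not shown in this module):
#
#     replay = ExperienceReplay(capacity=200, persistence_path="replay.json")
#     record = replay.add(trajectory)                 # store + initial Q
#     similar = replay.retrieve("fix the flaky parser", top_k=3)
#     replay.update_q_value(record.id, reward=0.8)    # reinforce on success
#
# To swap in real embeddings, only _compute_embedding() needs overriding;
# _cosine_similarity() already handles any equal-length vectors. A sketch
# using sentence-transformers (kept as a comment so this module carries no
# hard dependency on it):
#
#     from sentence_transformers import SentenceTransformer
#
#     class SemanticExperienceReplay(ExperienceReplay):
#         _model = SentenceTransformer("all-MiniLM-L6-v2")
#
#         def _compute_embedding(self, text: str) -> list[float]:
#             # encode() returns a numpy array; records store list[float]
#             return self._model.encode(text).tolist()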