"""
Cross-episode agent memory system.

Stores observations, strategies, and lessons learned across training episodes.
Injected into the system prompt at the start of every new episode so the
agent builds on past experience.

Inspired by kube-sre-gym's episodic memory and the open-env-assistant
memory consolidation approach.

Usage:
    from training.memory import (
        load_agent_memory,
        save_agent_memory,
        record_episode,
        build_memory_context,
        maybe_consolidate_memory,
    )

    memory = load_agent_memory()
    context_str = build_memory_context(memory)
    # inject context_str into system prompt

    # after episode ends:
    memory = record_episode(memory, {
        "task_id": "root_cause_analysis",
        "score": 0.82,
        "steps": 7,
        "trajectory_summary": "Investigated auth-service first, found JWT expiry bug",
        "mistakes": ["Escalated too early before diagnosing"],
        "successes": ["Correctly identified root cause on step 3"],
    })
    save_agent_memory(memory)
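
    # periodically distill episode logs into concise rules (no-op until due):
    memory = maybe_consolidate_memory(memory)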
"""

from __future__ import annotations

import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

logger = logging.getLogger(__name__)

DEFAULT_PATH = os.path.join("outputs", "agent_memory.json")

# Storage caps and consolidation cadence
MAX_EPISODES_STORED = 100
MAX_RULES_PER_TASK  = 10
MAX_MISTAKE_CARDS   = 30   # cap on stored structured mistake cards
CONSOLIDATION_EVERY = 20   # consolidate after every N episodes


# ---------------------------------------------------------------------------
# Memory schema
# ---------------------------------------------------------------------------

def _empty_memory() -> Dict[str, Any]:
    return {
        "version": 1,
        "total_episodes": 0,
        "last_consolidated_at": None,
        "global_rules": [],        # list of str β€” apply to every task
        "task_rules": {},          # task_id β†’ list of str
        "episode_log": [],         # last MAX_EPISODES_STORED episodes
        "score_history": {},       # task_id β†’ list of float
        "mistakes": [],            # list of str β€” common mistakes to avoid
        "mistake_cards": [],
        "successes": [],           # list of str β€” things that worked well
    }


def new_agent_memory() -> Dict[str, Any]:
    """Return a fresh in-memory store without reading or writing disk."""
    return _empty_memory()


# ---------------------------------------------------------------------------
# Load / Save
# ---------------------------------------------------------------------------

def load_agent_memory(path: str = DEFAULT_PATH) -> Dict[str, Any]:
    """Load memory from disk. Returns empty memory if file doesn't exist."""
    if not os.path.exists(path):
        logger.info("No memory file found at %s, starting fresh", path)
        return _empty_memory()
    try:
        with open(path) as f:
            data = json.load(f)
        logger.info(
            "Loaded memory: %d episodes, %d global rules",
            data.get("total_episodes", 0),
            len(data.get("global_rules", [])),
        )
        return data
    except Exception as e:
        logger.warning("Failed to load memory from %s: %s β€” starting fresh", path, e)
        return _empty_memory()


def save_agent_memory(memory: Dict[str, Any], path: str = DEFAULT_PATH) -> None:
    """Save memory to disk."""
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    # Trim episode log before saving
    memory["episode_log"] = memory["episode_log"][-MAX_EPISODES_STORED:]
    with open(path, "w") as f:
        json.dump(memory, f, indent=2)
    logger.debug("Saved memory to %s", path)


# ---------------------------------------------------------------------------
# Record an episode
# ---------------------------------------------------------------------------

def record_episode(
    memory: Dict[str, Any],
    episode_data: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Record a completed episode into memory.

    episode_data keys:
        task_id (str)            — which task was attempted
        score (float)            — 0.0–1.0 final score
        steps (int)              — number of steps taken
        trajectory_summary (str) — 1-2 sentence summary of what happened
        mistakes (list[str])     — things that went wrong (optional)
        successes (list[str])    — things that worked (optional)
        mistake_cards (list[dict]) — structured mistake cards (optional);
                                   see _record_mistake_card for expected keys
    """
    task_id = episode_data.get("task_id", "unknown")
    score   = float(episode_data.get("score", 0.0))

    # Score history per task (setdefault guards memory files from older versions)
    memory.setdefault("score_history", {}).setdefault(task_id, []).append(score)

    # Episode log
    log_entry = {
        "timestamp":    datetime.now(timezone.utc).isoformat(),
        "task_id":      task_id,
        "score":        score,
        "steps":        episode_data.get("steps", 0),
        "summary":      episode_data.get("trajectory_summary", ""),
    }
    memory["episode_log"].append(log_entry)
    memory["total_episodes"] = memory.get("total_episodes", 0) + 1
    memory.setdefault("mistake_cards", [])

    # Extract deduplicated mistakes and successes
    mistakes = memory.setdefault("mistakes", [])
    for mistake in episode_data.get("mistakes", []):
        if mistake and mistake not in mistakes:
            mistakes.append(mistake)

    successes = memory.setdefault("successes", [])
    for success in episode_data.get("successes", []):
        if success and success not in successes:
            successes.append(success)

    for card in episode_data.get("mistake_cards", []):
        _record_mistake_card(memory, card, task_id)

    # Auto-generate rules from patterns
    _update_rules_from_episode(memory, task_id, score, episode_data)

    return memory


def _record_mistake_card(memory: Dict[str, Any], card: Dict[str, Any], task_id: str) -> None:
    """Merge a structured mistake card into memory with seen-count tracking."""
    if not isinstance(card, dict):
        return
    normalized = {
        "mistake_type": str(card.get("mistake_type") or "unknown_mistake"),
        "task_id": str(card.get("task_id") or task_id),
        "worker_id": card.get("worker_id"),
        "bad_decision": card.get("bad_decision"),
        "correct_decision": card.get("correct_decision"),
        "evidence": str(card.get("evidence") or "")[:300],
        "lesson": str(card.get("lesson") or "Avoid repeating this failure.")[:300],
    }
    episode_index = int(memory.get("total_episodes", 0) or 0)
    key_fields = (
        normalized["mistake_type"],
        normalized["task_id"],
        normalized.get("worker_id") or "",
        normalized.get("correct_decision") or "",
    )
    cards = memory.setdefault("mistake_cards", [])
    for existing in cards:
        existing_key = (
            existing.get("mistake_type"),
            existing.get("task_id"),
            existing.get("worker_id") or "",
            existing.get("correct_decision") or "",
        )
        if existing_key == key_fields:
            existing["seen_count"] = int(existing.get("seen_count", 1)) + 1
            existing["last_seen_episode"] = episode_index
            existing["evidence"] = normalized["evidence"] or existing.get("evidence", "")
            existing["lesson"] = normalized["lesson"] or existing.get("lesson", "")
            return

    normalized["seen_count"] = 1
    normalized["first_seen_episode"] = episode_index
    normalized["last_seen_episode"] = episode_index
    cards.append(normalized)
    cards.sort(
        key=lambda item: (
            int(item.get("seen_count", 0)),
            int(item.get("last_seen_episode", 0)),
        ),
        reverse=True,
    )
    del cards[MAX_MISTAKE_CARDS:]


def _update_rules_from_episode(
    memory: Dict[str, Any],
    task_id: str,
    score: float,
    episode_data: Dict[str, Any],
) -> None:
    """Derive rules from episode outcome and add to task_rules."""
    task_rules = memory.setdefault("task_rules", {}).setdefault(task_id, [])

    # High-score episode: extract successes as rules
    if score >= 0.85 and episode_data.get("successes"):
        for s in episode_data["successes"]:
            rule = f"[WORKS] {s}"
            if rule not in task_rules:
                task_rules.append(rule)

    # Low-score episode: extract mistakes as rules
    if score < 0.50 and episode_data.get("mistakes"):
        for m in episode_data["mistakes"]:
            rule = f"[AVOID] {m}"
            if rule not in task_rules:
                task_rules.append(rule)

    # Trim to max
    memory["task_rules"][task_id] = task_rules[-MAX_RULES_PER_TASK:]


# ---------------------------------------------------------------------------
# Build context string for injection into system prompt
# ---------------------------------------------------------------------------

def build_memory_context(
    memory: Dict[str, Any],
    task_id: Optional[str] = None,
    max_rules: int = 5,
    max_recent: int = 3,
) -> str:
    """
    Build a concise memory context string for injection into the system prompt.

    Returns a string of ~200 tokens that summarizes key lessons learned.
    Inject this at the TOP of the system prompt before each episode.
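
    Example output shape (illustrative values; actual content depends on
    what is stored):

        ## MEMORY FROM PAST EPISODES

        Rules for root_cause_analysis:
          - [WORKS] Correctly identified root cause on step 3

        Mistakes to avoid:
          - AVOID: Escalated too early before diagnosing

        Your mean score on root_cause_analysis: 0.82 (over 5 episodes)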
    """
    lines: List[str] = ["## MEMORY FROM PAST EPISODES"]

    # Task-specific rules
    if task_id and task_id in memory.get("task_rules", {}):
        rules = memory["task_rules"][task_id][-max_rules:]
        if rules:
            lines.append(f"\nRules for {task_id}:")
            for rule in rules:
                lines.append(f"  - {rule}")

    # Global rules
    global_rules = memory.get("global_rules", [])[-max_rules:]
    if global_rules:
        lines.append("\nGeneral rules (all tasks):")
        for rule in global_rules:
            lines.append(f"  - {rule}")

    # Common mistakes
    mistakes = memory.get("mistakes", [])[-3:]
    if mistakes:
        lines.append("\nMistakes to avoid:")
        for m in mistakes:
            lines.append(f"  - AVOID: {m}")

    mistake_cards = sorted(
        memory.get("mistake_cards", []),
        key=lambda item: (
            int(item.get("seen_count", 0)),
            int(item.get("last_seen_episode", 0)),
        ),
        reverse=True,
    )[:3]
    if mistake_cards:
        lines.append("\nStructured mistake cards:")
        for card in mistake_cards:
            seen = int(card.get("seen_count", 1))
            label = card.get("mistake_type", "mistake")
            lesson = card.get("lesson", "")
            evidence = card.get("evidence", "")
            lines.append(f"  - [{label}, seen {seen}x] {lesson} Evidence: {evidence}")

    # Recent episode outcomes for this task
    if task_id:
        recent = [
            ep for ep in memory.get("episode_log", [])
            if ep.get("task_id") == task_id
        ][-max_recent:]
        if recent:
            lines.append(f"\nRecent {task_id} episodes:")
            for ep in recent:
                lines.append(
                    f"  - Score {ep['score']:.2f} in {ep['steps']} steps: {ep['summary'][:100]}"
                )

    # Mean score for this task (self-awareness)
    if task_id and task_id in memory.get("score_history", {}):
        scores = memory["score_history"][task_id]
        if scores:
            mean = sum(scores) / len(scores)
            lines.append(f"\nYour mean score on {task_id}: {mean:.2f} (over {len(scores)} episodes)")

    if len(lines) == 1:
        return ""   # No memory yet β€” return empty string

    return "\n".join(lines)


# ---------------------------------------------------------------------------
# LLM-based memory consolidation (optional, requires API key)
# ---------------------------------------------------------------------------

def maybe_consolidate_memory(
    memory: Dict[str, Any],
    api_key: Optional[str] = None,
    path: str = DEFAULT_PATH,
) -> Dict[str, Any]:
    """
    Every CONSOLIDATION_EVERY episodes, use an LLM to distill episode logs
    into concise, high-signal rules. Saves tokens in future prompts.

    If no API key is available, falls back to simple heuristic consolidation.
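
    Usage (a minimal sketch; the call is a no-op until CONSOLIDATION_EVERY
    new episodes have accumulated since the last consolidation):

        memory = record_episode(memory, episode_data)
        memory = maybe_consolidate_memory(memory)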
    """
    total = memory.get("total_episodes", 0)
    last  = memory.get("last_consolidated_at") or 0
    if isinstance(last, str):
        last = 0  # reset if it was stored as ISO string

    if total - last < CONSOLIDATION_EVERY:
        return memory   # not yet due

    key = api_key or os.getenv("GROQ_API_KEY")
    if key:
        memory = _llm_consolidate(memory, key)
    else:
        memory = _heuristic_consolidate(memory)

    memory["last_consolidated_at"] = total
    save_agent_memory(memory, path)
    return memory


def _heuristic_consolidate(memory: Dict[str, Any]) -> Dict[str, Any]:
    """
    Simple rule: look at episodes where score > 0.85, extract their summaries
    as global rules. Deduplicate. Trim old ones.
    """
    new_rules: List[str] = []
    for ep in memory.get("episode_log", []):
        if ep.get("score", 0.0) >= 0.85 and ep.get("summary"):
            rule = f"[HIGH SCORE {ep['score']:.2f}] {ep['summary'][:120]}"
            if rule not in new_rules:
                new_rules.append(rule)

    # Merge with existing global rules (keep most recent)
    combined = memory.get("global_rules", []) + new_rules
    memory["global_rules"] = list(dict.fromkeys(combined))[-MAX_RULES_PER_TASK * 2:]

    logger.info("Heuristic consolidation: %d global rules", len(memory["global_rules"]))
    return memory


def _llm_consolidate(memory: Dict[str, Any], api_key: str) -> Dict[str, Any]:
    """Use LLM to distill episode logs into concise rules."""
    try:
        import httpx

        episode_summary = "\n".join(
            f"task={ep['task_id']} score={ep['score']:.2f}: {ep['summary']}"
            for ep in memory.get("episode_log", [])[-30:]  # last 30 episodes
        )

        prompt = f"""You are analyzing an AI agent's performance across multiple episodes.
Here are recent episode outcomes:

{episode_summary}

Extract 5 concise, actionable rules the agent should follow in future episodes.
Each rule should be 1 sentence. Focus on what WORKS and what to AVOID.

Return ONLY a JSON array of strings:
["Rule 1...", "Rule 2...", ...]
"""
        response = httpx.post(
            f"{os.getenv('API_BASE_URL', 'https://api.groq.com/openai/v1')}/chat/completions",
            headers={"Authorization": f"Bearer {api_key}"},
            json={
                "model": "llama-3.3-70b-versatile",
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.0,
                "max_tokens": 300,
            },
            timeout=30.0,
        )
        response.raise_for_status()
        content = response.json()["choices"][0]["message"]["content"]

        # Extract the first JSON array in the response content
        start = content.find("[")
        end   = content.rfind("]") + 1
        if start != -1 and end > start:
            new_rules: List[str] = json.loads(content[start:end])
            existing = memory.get("global_rules", [])
            combined = existing + [f"[CONSOLIDATED] {r}" for r in new_rules]
            memory["global_rules"] = list(dict.fromkeys(combined))[-MAX_RULES_PER_TASK * 2:]
            logger.info("LLM consolidation: extracted %d new rules", len(new_rules))

    except Exception as e:
        logger.warning("LLM consolidation failed: %s β€” falling back to heuristic", e)
        memory = _heuristic_consolidate(memory)

    return memory


# ---------------------------------------------------------------------------
# Utility: memory stats for logging
# ---------------------------------------------------------------------------

def memory_summary(memory: Dict[str, Any]) -> Dict[str, Any]:
    """Human-readable summary of current memory state."""
    return {
        "total_episodes":  memory.get("total_episodes", 0),
        "global_rules":    len(memory.get("global_rules", [])),
        "task_rules":      {k: len(v) for k, v in memory.get("task_rules", {}).items()},
        "mistakes_stored": len(memory.get("mistakes", [])),
        "mistake_cards_stored": len(memory.get("mistake_cards", [])),
        "top_mistake_cards": [
            {
                "mistake_type": card.get("mistake_type"),
                "task_id": card.get("task_id"),
                "seen_count": card.get("seen_count", 0),
            }
            for card in sorted(
                memory.get("mistake_cards", []),
                key=lambda item: (
                    int(item.get("seen_count", 0)),
                    int(item.get("last_seen_episode", 0)),
                ),
                reverse=True,
            )[:5]
        ],
        "scores_by_task":  {
            k: round(sum(v) / len(v), 3)
            for k, v in memory.get("score_history", {}).items()
            if v
        },
    }