Rohan03 committed on
Commit b6f70a1 · verified · 1 Parent(s): 455fdee

Add purpose_agent/optimizer.py

Files changed (1)
  1. purpose_agent/optimizer.py +496 -0
purpose_agent/optimizer.py ADDED
@@ -0,0 +1,496 @@
"""
Heuristic Optimizer — Extracts "winning heuristics" from high-reward trajectories.

This is the self-improvement engine. It takes successful trajectories and distills
them into reusable heuristics that update the agent's long-term memory.

The key insight (from CER arxiv:2506.06698 and MUSE arxiv:2510.08002):
- Don't store raw trajectories in the prompt (context bloat)
- DISTILL them into abstract, reusable patterns
- Use {variable} placeholders so heuristics generalize
- Deduplicate and merge similar heuristics to prevent memory drift

The Optimizer produces three types of heuristics (MUSE 3-tier):
1. STRATEGIC: High-level <Dilemma, Strategy> pairs (e.g., "When stuck on X, try Y")
2. PROCEDURAL: Step-by-step SOPs for specific task patterns
3. TOOL: Per-action tips based on observed usage patterns
"""

from __future__ import annotations

import json
import logging
import re
from typing import Any

from purpose_agent.types import (
    Heuristic,
    MemoryTier,
    Trajectory,
)
from purpose_agent.llm_backend import ChatMessage, LLMBackend

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Distillation Prompts (inspired by CER Appendix A.1 + MUSE Section 3.2)
# ---------------------------------------------------------------------------

DISTILL_SYSTEM_PROMPT = """\
You are a HEURISTIC EXTRACTOR. Given a successful task trajectory, you extract
reusable lessons that will help an agent perform better on FUTURE similar tasks.

## Output Format
You produce three types of heuristics:

### 1. STRATEGIC (high-level wisdom)
Format: {"pattern": "When <situation>", "strategy": "Do <approach>"}
- Abstract away specific details — use {variable} placeholders
- Focus on dilemmas and decision points, not routine steps
- Example: {"pattern": "When facing {task_type} with multiple valid approaches",
  "strategy": "Start with the simplest approach that could work, escalate only if it fails"}

### 2. PROCEDURAL (step-by-step SOPs)
Format: {"pattern": "To accomplish {task_pattern}", "strategy": "Follow these steps",
         "steps": ["Step 1: ...", "Step 2: ..."]}
- Include concrete action names and parameter patterns
- Use {variable} placeholders for task-specific values
- Example: {"pattern": "To search for {item} in {environment}",
  "steps": ["Check {most_likely_location} first", "If not found, expand search radius", ...]}

### 3. TOOL (per-action tips)
Format: {"pattern": "When using action {action_name}", "strategy": "Remember to {tip}"}
- Based on action successes and failures in the trajectory
- Focus on non-obvious gotchas and best practices
"""

DISTILL_TRAJECTORY_PROMPT = """\
## Task Description
{task_description}

## Purpose
{purpose}

## Trajectory Summary
Total steps: {num_steps}
Success rate: {success_rate:.1%}
Cumulative reward: {cumulative_reward:.2f}
Net state improvement: {total_delta:.2f}

## Step-by-Step Trajectory
{trajectory_steps}

## Existing Heuristics (do NOT duplicate these)
{existing_heuristics}

Extract the winning heuristics from this trajectory. Focus on:
1. What decisions led to the highest-scoring steps?
2. Were there any mistakes that were corrected? What was learned?
3. Are there any patterns that would generalize to similar tasks?

Respond with JSON containing a "heuristics" array; each heuristic has:
- "tier": "strategic" | "procedural" | "tool"
- "pattern": When/what this applies to (use {{variable}} placeholders)
- "strategy": What to do
- "steps": (optional, for procedural only) List of step strings
"""
DISTILL_SCHEMA: dict[str, Any] = {
    "type": "object",
    "properties": {
        "heuristics": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "tier": {
                        "type": "string",
                        "enum": ["strategic", "procedural", "tool"],
                    },
                    "pattern": {"type": "string"},
                    "strategy": {"type": "string"},
                    "steps": {
                        "type": "array",
                        "items": {"type": "string"},
                    },
                },
                "required": ["tier", "pattern", "strategy"],
            },
        }
    },
    "required": ["heuristics"],
}
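# For reference, a structured response that validates against DISTILL_SCHEMA
# looks like this (illustrative values only):
#
#   {
#       "heuristics": [
#           {
#               "tier": "procedural",
#               "pattern": "To locate {item} in {environment}",
#               "strategy": "Search the likeliest containers before exploring",
#               "steps": ["Check {most_likely_location}", "Widen the search"]
#           }
#       ]
#   }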

# ---------------------------------------------------------------------------
# Merge / Dedup Prompts
# ---------------------------------------------------------------------------

MERGE_SYSTEM_PROMPT = """\
You are a HEURISTIC DEDUPLICATOR. Given a list of heuristics, merge any that
are semantically similar into a single, more general heuristic.

Rules:
- If two heuristics describe the same strategy for similar situations, MERGE them
- The merged heuristic should be MORE general (wider applicability)
- Keep the higher Q-value when merging
- Preserve concrete action names and step details
- Do NOT merge heuristics from different tiers
"""


MERGE_PROMPT = """\
## Heuristics to Merge/Deduplicate
{heuristics_json}

Return a JSON object with a "heuristics" array of the deduplicated heuristics.
If two are similar, combine them into one. Keep all unique heuristics.
"""
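# Sketch of the intended merge behavior (hypothetical inputs): given
#   {"pattern": "When searching for {file} in a repo", "strategy": "grep before opening files", "q_value": 0.7}
#   {"pattern": "When looking for {symbol} in code", "strategy": "grep first, then open matches", "q_value": 0.5}
# the deduplicator should return one more general heuristic carrying the
# higher Q-value (0.7), e.g.:
#   {"pattern": "When searching for {target} in a codebase", "strategy": "grep first, open only matching files"}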

# ---------------------------------------------------------------------------
# Optimizer Class
# ---------------------------------------------------------------------------

class HeuristicOptimizer:
    """
    Extracts reusable heuristics from high-reward trajectories and manages
    the heuristic library (dedup, merge, Q-value updates).

    This is the "learning" module — it reads trajectories from Experience Replay
    and produces heuristics that update the Actor's memory.

    The optimization loop (called by Orchestrator after each task):
    1. Get top trajectories from Experience Replay
    2. Distill each into candidate heuristics via LLM
    3. Merge/deduplicate with existing heuristic library
    4. Update Q-values based on usage success/failure
    5. Push updated heuristics to Actor's memory tiers

    Args:
        llm: LLM backend for distillation (can be the same as or different
            from the Actor/Critic backend)
        min_reward_threshold: Minimum cumulative reward to consider a trajectory
        max_heuristics_per_tier: Cap on heuristics per tier to prevent context bloat
    """

    def __init__(
        self,
        llm: LLMBackend,
        min_reward_threshold: float = 1.0,
        max_heuristics_per_tier: int = 20,
    ):
        self.llm = llm
        self.min_reward_threshold = min_reward_threshold
        self.max_heuristics_per_tier = max_heuristics_per_tier
        self.heuristic_library: list[Heuristic] = []

    # ------------------------------------------------------------------
    # Core: Distill Trajectory → Heuristics
    # ------------------------------------------------------------------

    def distill_trajectory(
        self,
        trajectory: Trajectory,
        existing_heuristics: list[Heuristic] | None = None,
    ) -> list[Heuristic]:
        """
        Extract heuristics from a single trajectory via LLM distillation.

        Uses the CER (arxiv:2506.06698) distillation prompt pattern:
        - Abstract away specifics with {variable} placeholders
        - Separate into Dynamics (what was learned) and Skills (how to act)
        - Skip heuristics that duplicate existing ones
        """
        if trajectory.cumulative_reward < self.min_reward_threshold:
            logger.info(
                f"Optimizer: Skipping trajectory {trajectory.id} "
                f"(reward={trajectory.cumulative_reward:.2f} < threshold)"
            )
            return []

        existing = existing_heuristics or self.heuristic_library

        # Format trajectory steps for the prompt
        step_lines = []
        for step in trajectory.steps:
            score_info = ""
            if step.score is not None:
                score_info = (
                    f" → Φ: {step.score.phi_before:.1f}→{step.score.phi_after:.1f} "
                    f"(Δ={step.score.delta:+.2f})"
                )
            step_lines.append(
                f"Step {step.step_index}: "
                f"Action={step.action.name}({json.dumps(step.action.params, default=str)})\n"
                f" Thought: {step.action.thought[:150]}\n"
                f" State before: {step.state_before.describe()[:200]}\n"
                f" State after: {step.state_after.describe()[:200]}\n"
                f" Score{score_info}"
            )

        existing_str = "None" if not existing else "\n".join(
            f"- [{h.tier.value}] {h.pattern}: {h.strategy}" for h in existing[:20]
        )

        messages = [
            ChatMessage(role="system", content=DISTILL_SYSTEM_PROMPT),
            ChatMessage(role="user", content=DISTILL_TRAJECTORY_PROMPT.format(
                task_description=trajectory.task_description,
                purpose=trajectory.purpose,
                num_steps=len(trajectory.steps),
                success_rate=trajectory.success_rate,
                cumulative_reward=trajectory.cumulative_reward,
                total_delta=trajectory.total_delta,
                trajectory_steps="\n\n".join(step_lines),
                existing_heuristics=existing_str,
            )),
        ]

        try:
            result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA)
        except Exception as e:
            logger.error(f"Optimizer: Distillation failed ({e}), attempting text fallback")
            raw = self.llm.generate(messages, temperature=0.5)
            result = self._parse_distillation_text(raw)

        new_heuristics: list[Heuristic] = []
        for h_data in result.get("heuristics", []):
            tier_str = h_data.get("tier", "strategic")
            try:
                tier = MemoryTier(tier_str)
            except ValueError:
                tier = MemoryTier.STRATEGIC

            heuristic = Heuristic(
                pattern=h_data.get("pattern", ""),
                strategy=h_data.get("strategy", ""),
                steps=h_data.get("steps", []),
                tier=tier,
                source_trajectory_id=trajectory.id,
                q_value=trajectory.success_rate,  # Initial Q from trajectory success
            )
            new_heuristics.append(heuristic)

        logger.info(
            f"Optimizer: Distilled {len(new_heuristics)} heuristics from "
            f"trajectory {trajectory.id}"
        )
        return new_heuristics

    # ------------------------------------------------------------------
    # Merge & Deduplicate
    # ------------------------------------------------------------------

    def merge_heuristics(
        self,
        new_heuristics: list[Heuristic],
    ) -> list[Heuristic]:
        """
        Merge new heuristics into the library, deduplicating similar ones.

        Per MUSE (arxiv:2510.08002) post-task distillation:
        - Merge similar heuristics into more general ones
        - Keep the higher Q-value
        - Cap per-tier to prevent context bloat
        """
        # Add new heuristics to the library
        combined = self.heuristic_library + new_heuristics

        if not combined:
            return []

        # Group by tier
        by_tier: dict[MemoryTier, list[Heuristic]] = {}
        for h in combined:
            by_tier.setdefault(h.tier, []).append(h)

        # Deduplicate within each tier
        merged_library: list[Heuristic] = []
        for tier, heuristics in by_tier.items():
            if len(heuristics) <= self.max_heuristics_per_tier:
                merged_library.extend(heuristics)
                continue

            # Use LLM to merge if over capacity
            try:
                merged = self._llm_merge(heuristics, tier)
                merged_library.extend(merged[:self.max_heuristics_per_tier])
            except Exception as e:
                logger.warning(f"Optimizer: LLM merge failed ({e}), using Q-value sort")
                # Fallback: keep highest Q-value heuristics
                heuristics.sort(key=lambda h: -h.q_value)
                merged_library.extend(heuristics[:self.max_heuristics_per_tier])

        self.heuristic_library = merged_library
        logger.info(
            f"Optimizer: Library updated — {len(self.heuristic_library)} heuristics "
            f"({sum(1 for h in self.heuristic_library if h.tier == MemoryTier.STRATEGIC)} strategic, "
            f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.PROCEDURAL)} procedural, "
            f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.TOOL)} tool)"
        )
        return self.heuristic_library

    def _llm_merge(
        self,
        heuristics: list[Heuristic],
        tier: MemoryTier,
    ) -> list[Heuristic]:
        """Use the LLM to merge similar heuristics."""
        h_dicts = [
            {
                "id": h.id,
                "pattern": h.pattern,
                "strategy": h.strategy,
                "steps": h.steps,
                "q_value": h.q_value,
            }
            for h in heuristics
        ]

        messages = [
            ChatMessage(role="system", content=MERGE_SYSTEM_PROMPT),
            ChatMessage(role="user", content=MERGE_PROMPT.format(
                heuristics_json=json.dumps(h_dicts, indent=2)
            )),
        ]

        # The distillation schema (a "heuristics" array) fits the merge output too.
        result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA)

        merged: list[Heuristic] = []
        for h_data in result.get("heuristics", []):
            merged.append(Heuristic(
                pattern=h_data.get("pattern", ""),
                strategy=h_data.get("strategy", ""),
                steps=h_data.get("steps", []),
                tier=tier,
                # Carry over the Q-value of an exact-pattern match if one
                # survived the merge; otherwise start at a neutral 0.5.
                q_value=max(
                    (h.q_value for h in heuristics
                     if h.pattern == h_data.get("pattern")),
                    default=0.5,
                ),
            ))
        return merged

    # ------------------------------------------------------------------
    # Q-Value Management
    # ------------------------------------------------------------------

    def update_heuristic_usage(
        self,
        heuristic_id: str,
        was_successful: bool,
        alpha: float = 0.1,
    ) -> None:
        """
        Update a heuristic's Q-value based on whether it helped.

        Called by the Orchestrator when a heuristic was in the Actor's
        context and the task succeeded/failed.
        """
        for h in self.heuristic_library:
            if h.id == heuristic_id:
                h.times_used += 1
                if was_successful:
                    h.times_succeeded += 1
                reward = 1.0 if was_successful else 0.0
                h.update_q_value(reward, alpha=alpha)
                logger.debug(
                    f"Optimizer: Heuristic {heuristic_id} updated "
                    f"(success={was_successful}, q={h.q_value:.3f})"
                )
                return

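    # Note: `Heuristic.update_q_value` lives in purpose_agent.types; the alpha
    # parameter suggests an exponential-moving-average update,
    # Q <- Q + alpha * (reward - Q) (an assumption; the exact rule is defined
    # there). Under that rule, a heuristic at q=0.5 that succeeds once with the
    # default alpha=0.1 moves to 0.5 + 0.1 * (1.0 - 0.5) = 0.55.
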
    def get_heuristics_by_tier(self, tier: MemoryTier) -> list[Heuristic]:
        """Get all heuristics for a specific memory tier, sorted by Q-value."""
        return sorted(
            [h for h in self.heuristic_library if h.tier == tier],
            key=lambda h: -h.q_value,
        )

    def prune_low_quality(self, min_q: float = 0.2, min_uses: int = 3) -> int:
        """Remove heuristics that have been tried and consistently fail."""
        before = len(self.heuristic_library)
        # Keep heuristics that are still unproven (< min_uses trials) or that
        # have earned a Q-value at or above the floor.
        self.heuristic_library = [
            h for h in self.heuristic_library
            if h.times_used < min_uses or h.q_value >= min_q
        ]
        pruned = before - len(self.heuristic_library)
        if pruned:
            logger.info(f"Optimizer: Pruned {pruned} low-quality heuristics")
        return pruned

    # ------------------------------------------------------------------
    # Full Optimization Cycle
    # ------------------------------------------------------------------

    def optimize(
        self,
        trajectories: list[Trajectory],
    ) -> list[Heuristic]:
        """
        Run the full optimization cycle:
        1. Filter trajectories by minimum reward
        2. Distill each into candidate heuristics
        3. Merge with existing library
        4. Prune low-quality heuristics

        Returns the updated heuristic library.
        """
        all_new: list[Heuristic] = []

        for traj in trajectories:
            if traj.cumulative_reward >= self.min_reward_threshold:
                new = self.distill_trajectory(traj, self.heuristic_library)
                all_new.extend(new)

        if all_new:
            self.merge_heuristics(all_new)
            self.prune_low_quality()

        logger.info(
            f"Optimizer: Cycle complete — processed {len(trajectories)} trajectories, "
            f"library size: {len(self.heuristic_library)}"
        )
        return self.heuristic_library

    # ------------------------------------------------------------------
    # Fallback Parser
    # ------------------------------------------------------------------

    @staticmethod
    def _parse_distillation_text(raw: str) -> dict[str, Any]:
        """Best-effort extraction of heuristics from free-form text."""
        heuristics = []

        # Try to find a JSON array in the text
        json_match = re.search(r'\[.*\]', raw, re.DOTALL)
        if json_match:
            try:
                parsed = json.loads(json_match.group())
                if isinstance(parsed, list):
                    return {"heuristics": parsed}
            except json.JSONDecodeError:
                pass

        # Fall back to extracting pattern/strategy pairs from the text
        pattern_matches = re.findall(
            r'(?:pattern|when|if)\s*[:\-]\s*(.+?)(?:\n|$)',
            raw, re.IGNORECASE
        )
        strategy_matches = re.findall(
            r'(?:strategy|do|then)\s*[:\-]\s*(.+?)(?:\n|$)',
            raw, re.IGNORECASE
        )

        for pattern, strategy in zip(pattern_matches, strategy_matches):
            heuristics.append({
                "tier": "strategic",
                "pattern": pattern.strip(),
                "strategy": strategy.strip(),
            })

        return {"heuristics": heuristics}