"""
Heuristic Optimizer — Extracts "winning heuristics" from high-reward trajectories.

This is the self-improvement engine. It takes successful trajectories and distills
them into reusable heuristics that update the agent's long-term memory.

The key insight (from CER arxiv:2506.06698 and MUSE arxiv:2510.08002):
- Don't store raw trajectories in the prompt (context bloat)
- DISTILL them into abstract, reusable patterns
- Use {variable} placeholders so heuristics generalize
- Deduplicate and merge similar heuristics to prevent memory drift

The Optimizer produces three types of heuristics (MUSE 3-tier):
1. STRATEGIC: High-level <Dilemma, Strategy> pairs (e.g., "When stuck on X, try Y")
2. PROCEDURAL: Step-by-step SOPs for specific task patterns
3. TOOL: Per-action tips based on observed usage patterns
"""
from __future__ import annotations
import json
import logging
from typing import Any
from purpose_agent.types import (
    Heuristic,
    MemoryTier,
    Trajectory,
)
from purpose_agent.llm_backend import ChatMessage, LLMBackend
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Distillation Prompts (inspired by CER Appendix A.1 + MUSE Section 3.2)
# ---------------------------------------------------------------------------
DISTILL_SYSTEM_PROMPT = """\
You are a HEURISTIC EXTRACTOR. Given a successful task trajectory, you extract
reusable lessons that will help an agent perform better on FUTURE similar tasks.
## Output Format
You produce three types of heuristics:
### 1. STRATEGIC (high-level wisdom)
Format: {"pattern": "When <situation>", "strategy": "Do <approach>"}
- Abstract away specific details — use {variable} placeholders
- Focus on dilemmas and decision points, not routine steps
- Example: {"pattern": "When facing {task_type} with multiple valid approaches",
"strategy": "Start with the simplest approach that could work, escalate only if it fails"}
### 2. PROCEDURAL (step-by-step SOPs)
Format: {"pattern": "To accomplish {task_pattern}", "strategy": "Follow these steps",
"steps": ["Step 1: ...", "Step 2: ..."]}
- Include concrete action names and parameter patterns
- Use {variable} placeholders for task-specific values
- Example: {"pattern": "To search for {item} in {environment}",
"steps": ["Check {most_likely_location} first", "If not found, expand search radius", ...]}
### 3. TOOL (per-action tips)
Format: {"pattern": "When using action {action_name}", "strategy": "Remember to {tip}"}
- Based on action successes and failures in the trajectory
- Focus on non-obvious gotchas and best practices
"""
DISTILL_TRAJECTORY_PROMPT = """\
## Task Description
{task_description}
## Purpose
{purpose}
## Trajectory Summary
Total steps: {num_steps}
Success rate: {success_rate:.1%}
Cumulative reward: {cumulative_reward:.2f}
Net state improvement: {total_delta:.2f}
## Step-by-Step Trajectory
{trajectory_steps}
## Existing Heuristics (do NOT duplicate these)
{existing_heuristics}
Extract the winning heuristics from this trajectory. Focus on:
1. What decisions led to the highest-scoring steps?
2. Were there any mistakes that were corrected? What was learned?
3. Are there any patterns that would generalize to similar tasks?
Respond with a JSON object containing a "heuristics" array. Each heuristic has:
- "tier": "strategic" | "procedural" | "tool"
- "pattern": When/what this applies to (use {{variable}} placeholders)
- "strategy": What to do
- "steps": (optional, for procedural only) List of step strings
"""
DISTILL_SCHEMA: dict[str, Any] = {
"type": "object",
"properties": {
"heuristics": {
"type": "array",
"items": {
"type": "object",
"properties": {
"tier": {
"type": "string",
"enum": ["strategic", "procedural", "tool"],
},
"pattern": {"type": "string"},
"strategy": {"type": "string"},
"steps": {
"type": "array",
"items": {"type": "string"},
},
},
"required": ["tier", "pattern", "strategy"],
},
}
},
"required": ["heuristics"],
}
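# Illustrative payload that satisfies DISTILL_SCHEMA ("steps" is optional;
# values are made up):
#   {"heuristics": [{"tier": "tool",
#                    "pattern": "When using action {open}",
#                    "strategy": "Check whether the target is locked first"}]}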
# ---------------------------------------------------------------------------
# Merge / Dedup Prompts
# ---------------------------------------------------------------------------
MERGE_SYSTEM_PROMPT = """\
You are a HEURISTIC DEDUPLICATOR. Given a list of heuristics, merge any that
are semantically similar into a single, more general heuristic.
Rules:
- If two heuristics describe the same strategy for similar situations, MERGE them
- The merged heuristic should be MORE general (wider applicability)
- Keep the higher Q-value when merging
- Preserve concrete action names and step details
- Do NOT merge heuristics from different tiers
"""
MERGE_PROMPT = """\
## Heuristics to Merge/Deduplicate
{heuristics_json}
Respond with a JSON object containing a "heuristics" array of the deduplicated
heuristics. If two are similar, combine them into one; keep every heuristic
that is unique.
"""
# ---------------------------------------------------------------------------
# Optimizer Class
# ---------------------------------------------------------------------------
class HeuristicOptimizer:
"""
Extracts reusable heuristics from high-reward trajectories and manages
the heuristic library (dedup, merge, Q-value updates).
This is the "learning" module — it reads trajectories from Experience Replay
and produces heuristics that update the Actor's memory.
The optimization loop (called by Orchestrator after each task):
1. Get top trajectories from Experience Replay
2. Distill each into candidate heuristics via LLM
3. Merge/deduplicate with existing heuristic library
4. Update Q-values based on usage success/failure
5. Push updated heuristics to Actor's memory tiers
Args:
llm: LLM backend for distillation (can be same or different from Actor/Critic)
min_reward_threshold: Minimum cumulative reward to consider a trajectory
max_heuristics_per_tier: Cap on heuristics per tier to prevent context bloat
"""
def __init__(
self,
llm: LLMBackend,
min_reward_threshold: float = 1.0,
max_heuristics_per_tier: int = 20,
):
self.llm = llm
self.min_reward_threshold = min_reward_threshold
self.max_heuristics_per_tier = max_heuristics_per_tier
self.heuristic_library: list[Heuristic] = []
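    # Typical construction (sketch; "backend" is any concrete LLMBackend, and
    # replay.top_trajectories() is a hypothetical Experience Replay accessor):
    #   optimizer = HeuristicOptimizer(llm=backend, min_reward_threshold=1.5)
    #   library = optimizer.optimize(replay.top_trajectories(k=5))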
# ------------------------------------------------------------------
# Core: Distill Trajectory → Heuristics
# ------------------------------------------------------------------
def distill_trajectory(
self,
trajectory: Trajectory,
existing_heuristics: list[Heuristic] | None = None,
) -> list[Heuristic]:
"""
Extract heuristics from a single trajectory via LLM distillation.
Uses the CER (arxiv:2506.06698) distillation prompt pattern:
- Abstract away specifics with {variable} placeholders
        - Separate lessons into the three MUSE tiers (strategic / procedural / tool)
- Skip heuristics that duplicate existing ones
"""
if trajectory.cumulative_reward < self.min_reward_threshold:
logger.info(
f"Optimizer: Skipping trajectory {trajectory.id} "
f"(reward={trajectory.cumulative_reward:.2f} < threshold)"
)
return []
existing = existing_heuristics or self.heuristic_library
# Format trajectory steps for the prompt
step_lines = []
for step in trajectory.steps:
score_info = ""
if step.score is not None:
score_info = (
f" → Φ: {step.score.phi_before:.1f}→{step.score.phi_after:.1f} "
f"(Δ={step.score.delta:+.2f})"
)
step_lines.append(
f"Step {step.step_index}: "
f"Action={step.action.name}({json.dumps(step.action.params, default=str)})\n"
f" Thought: {step.action.thought[:150]}\n"
f" State before: {step.state_before.describe()[:200]}\n"
f" State after: {step.state_after.describe()[:200]}\n"
f" Score{score_info}"
)
existing_str = "None" if not existing else "\n".join(
f"- [{h.tier.value}] {h.pattern}: {h.strategy}" for h in existing[:20]
)
messages = [
ChatMessage(role="system", content=DISTILL_SYSTEM_PROMPT),
ChatMessage(role="user", content=DISTILL_TRAJECTORY_PROMPT.format(
task_description=trajectory.task_description,
purpose=trajectory.purpose,
num_steps=len(trajectory.steps),
success_rate=trajectory.success_rate,
cumulative_reward=trajectory.cumulative_reward,
total_delta=trajectory.total_delta,
trajectory_steps="\n\n".join(step_lines),
existing_heuristics=existing_str,
)),
]
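        # Prefer schema-constrained generation; if the backend lacks structured
        # output (or the call fails), fall back to free-form text and the
        # project's robust parser.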
from purpose_agent.robust_parser import parse_optimizer_response
try:
result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA)
except Exception:
raw = self.llm.generate(messages, temperature=0.5, max_tokens=2000)
result = parse_optimizer_response(raw)
new_heuristics = []
for h_data in result.get("heuristics", []):
tier_str = h_data.get("tier", "strategic")
try:
tier = MemoryTier(tier_str)
except ValueError:
tier = MemoryTier.STRATEGIC
heuristic = Heuristic(
pattern=h_data.get("pattern", ""),
strategy=h_data.get("strategy", ""),
steps=h_data.get("steps", []),
tier=tier,
source_trajectory_id=trajectory.id,
q_value=trajectory.success_rate, # Initial Q from trajectory success
)
new_heuristics.append(heuristic)
logger.info(
f"Optimizer: Distilled {len(new_heuristics)} heuristics from "
f"trajectory {trajectory.id}"
)
return new_heuristics
# ------------------------------------------------------------------
# Merge & Deduplicate
# ------------------------------------------------------------------
def merge_heuristics(
self,
new_heuristics: list[Heuristic],
) -> list[Heuristic]:
"""
Merge new heuristics into the library, deduplicating similar ones.
Per MUSE (arxiv:2510.08002) post-task distillation:
- Merge similar heuristics into more general ones
- Keep the higher Q-value
- Cap per-tier to prevent context bloat
"""
# Add new heuristics to library
combined = self.heuristic_library + new_heuristics
if not combined:
return []
# Group by tier
by_tier: dict[MemoryTier, list[Heuristic]] = {}
for h in combined:
by_tier.setdefault(h.tier, []).append(h)
# Deduplicate within each tier
merged_library: list[Heuristic] = []
for tier, heuristics in by_tier.items():
if len(heuristics) <= self.max_heuristics_per_tier:
merged_library.extend(heuristics)
continue
# Use LLM to merge if over capacity
try:
merged = self._llm_merge(heuristics, tier)
merged_library.extend(merged[:self.max_heuristics_per_tier])
except Exception as e:
logger.warning(f"Optimizer: LLM merge failed ({e}), using Q-value sort")
# Fallback: keep highest Q-value heuristics
heuristics.sort(key=lambda h: -h.q_value)
merged_library.extend(heuristics[:self.max_heuristics_per_tier])
self.heuristic_library = merged_library
logger.info(
f"Optimizer: Library updated — {len(self.heuristic_library)} heuristics "
f"({sum(1 for h in self.heuristic_library if h.tier == MemoryTier.STRATEGIC)} strategic, "
f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.PROCEDURAL)} procedural, "
f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.TOOL)} tool)"
)
return self.heuristic_library
def _llm_merge(
self,
heuristics: list[Heuristic],
tier: MemoryTier,
) -> list[Heuristic]:
"""Use LLM to merge similar heuristics."""
h_dicts = [
{
"id": h.id,
"pattern": h.pattern,
"strategy": h.strategy,
"steps": h.steps,
"q_value": h.q_value,
}
for h in heuristics
]
messages = [
ChatMessage(role="system", content=MERGE_SYSTEM_PROMPT),
ChatMessage(role="user", content=MERGE_PROMPT.format(
heuristics_json=json.dumps(h_dicts, indent=2)
)),
]
from purpose_agent.robust_parser import parse_optimizer_response
try:
result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA)
except Exception:
raw = self.llm.generate(messages, temperature=0.5, max_tokens=2000)
result = parse_optimizer_response(raw)
merged = []
for h_data in result.get("heuristics", []):
merged.append(Heuristic(
pattern=h_data.get("pattern", ""),
strategy=h_data.get("strategy", ""),
steps=h_data.get("steps", []),
tier=tier,
q_value=max(
(h.q_value for h in heuristics
if h.pattern == h_data.get("pattern")),
default=0.5,
),
))
return merged
# ------------------------------------------------------------------
# Q-Value Management
# ------------------------------------------------------------------
def update_heuristic_usage(
self,
heuristic_id: str,
was_successful: bool,
alpha: float = 0.1,
) -> None:
"""
Update a heuristic's Q-value based on whether it helped.
Called by the Orchestrator when a heuristic was in the Actor's
context and the task succeeded/failed.
"""
for h in self.heuristic_library:
if h.id == heuristic_id:
h.times_used += 1
if was_successful:
h.times_succeeded += 1
reward = 1.0 if was_successful else 0.0
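                # Assumed semantics: Heuristic.update_q_value applies an
                # incremental update, roughly q += alpha * (reward - q); see
                # purpose_agent.types for the actual rule.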
h.update_q_value(reward, alpha=alpha)
logger.debug(
f"Optimizer: Heuristic {heuristic_id} updated "
f"(success={was_successful}, q={h.q_value:.3f})"
)
return
def get_heuristics_by_tier(self, tier: MemoryTier) -> list[Heuristic]:
"""Get all heuristics for a specific memory tier, sorted by Q-value."""
return sorted(
[h for h in self.heuristic_library if h.tier == tier],
key=lambda h: -h.q_value,
)
def prune_low_quality(self, min_q: float = 0.2, min_uses: int = 3) -> int:
"""Remove heuristics that have been tried and consistently fail."""
before = len(self.heuristic_library)
self.heuristic_library = [
h for h in self.heuristic_library
if h.times_used < min_uses or h.q_value >= min_q
]
pruned = before - len(self.heuristic_library)
if pruned:
logger.info(f"Optimizer: Pruned {pruned} low-quality heuristics")
return pruned
# ------------------------------------------------------------------
# Full Optimization Cycle
# ------------------------------------------------------------------
def optimize(
self,
trajectories: list[Trajectory],
) -> list[Heuristic]:
"""
Run the full optimization cycle:
1. Filter trajectories by minimum reward
2. Distill each into candidate heuristics
3. Merge with existing library
4. Prune low-quality heuristics
Returns the updated heuristic library.
"""
all_new: list[Heuristic] = []
for traj in trajectories:
if traj.cumulative_reward >= self.min_reward_threshold:
new = self.distill_trajectory(traj, self.heuristic_library)
all_new.extend(new)
if all_new:
self.merge_heuristics(all_new)
self.prune_low_quality()
logger.info(
f"Optimizer: Cycle complete — processed {len(trajectories)} trajectories, "
f"library size: {len(self.heuristic_library)}"
)
return self.heuristic_library
# ------------------------------------------------------------------
# Fallback Parser
# ------------------------------------------------------------------
@staticmethod
def _parse_distillation_text(raw: str) -> dict[str, Any]:
"""Best-effort extraction of heuristics from free-form text."""
import re
heuristics = []
# Try to find JSON array in text
json_match = re.search(r'\[.*\]', raw, re.DOTALL)
if json_match:
try:
parsed = json.loads(json_match.group())
if isinstance(parsed, list):
return {"heuristics": parsed}
except json.JSONDecodeError:
pass
# Fall back to extracting patterns from text
pattern_matches = re.findall(
r'(?:pattern|when|if)\s*[:\-]\s*(.+?)(?:\n|$)',
raw, re.IGNORECASE
)
strategy_matches = re.findall(
r'(?:strategy|do|then)\s*[:\-]\s*(.+?)(?:\n|$)',
raw, re.IGNORECASE
)
for pattern, strategy in zip(pattern_matches, strategy_matches):
heuristics.append({
"tier": "strategic",
"pattern": pattern.strip(),
"strategy": strategy.strip(),
})
return {"heuristics": heuristics}