"""
Heuristic Optimizer — Extracts "winning heuristics" from high-reward trajectories.
This is the self-improvement engine. It takes successful trajectories and distills
them into reusable heuristics that update the agent's long-term memory.
The key insights (from CER arxiv:2506.06698 and MUSE arxiv:2510.08002):
- Don't store raw trajectories in the prompt (context bloat)
- DISTILL them into abstract, reusable patterns
- Use {variable} placeholders so heuristics generalize
- Deduplicate and merge similar heuristics to prevent memory drift
The Optimizer produces three types of heuristics (MUSE 3-tier):
1. STRATEGIC: High-level <Dilemma, Strategy> pairs (e.g., "When stuck on X, try Y")
2. PROCEDURAL: Step-by-step SOPs for specific task patterns
3. TOOL: Per-action tips based on observed usage patterns
"""
from __future__ import annotations
import json
import logging
from typing import Any
from purpose_agent.types import (
Heuristic,
MemoryTier,
Trajectory,
TrajectoryStep,
)
from purpose_agent.llm_backend import ChatMessage, LLMBackend
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Distillation Prompts (inspired by CER Appendix A.1 + MUSE Section 3.2)
# ---------------------------------------------------------------------------
DISTILL_SYSTEM_PROMPT = """\
You are a HEURISTIC EXTRACTOR. Given a successful task trajectory, you extract
reusable lessons that will help an agent perform better on FUTURE similar tasks.
## Output Format
You produce three types of heuristics:
### 1. STRATEGIC (high-level wisdom)
Format: {"pattern": "When <situation>", "strategy": "Do <approach>"}
- Abstract away specific details — use {variable} placeholders
- Focus on dilemmas and decision points, not routine steps
- Example: {"pattern": "When facing {task_type} with multiple valid approaches",
"strategy": "Start with the simplest approach that could work, escalate only if it fails"}
### 2. PROCEDURAL (step-by-step SOPs)
Format: {"pattern": "To accomplish {task_pattern}", "strategy": "Follow these steps",
"steps": ["Step 1: ...", "Step 2: ..."]}
- Include concrete action names and parameter patterns
- Use {variable} placeholders for task-specific values
- Example: {"pattern": "To search for {item} in {environment}",
"steps": ["Check {most_likely_location} first", "If not found, expand search radius", ...]}
### 3. TOOL (per-action tips)
Format: {"pattern": "When using action {action_name}", "strategy": "Remember to {tip}"}
- Based on action successes and failures in the trajectory
- Focus on non-obvious gotchas and best practices
"""
DISTILL_TRAJECTORY_PROMPT = """\
## Task Description
{task_description}
## Purpose
{purpose}
## Trajectory Summary
Total steps: {num_steps}
Success rate: {success_rate:.1%}
Cumulative reward: {cumulative_reward:.2f}
Net state improvement: {total_delta:.2f}
## Step-by-Step Trajectory
{trajectory_steps}
## Existing Heuristics (do NOT duplicate these)
{existing_heuristics}
Extract the winning heuristics from this trajectory. Focus on:
1. What decisions led to the highest-scoring steps?
2. Were there any mistakes that were corrected? What was learned?
3. Are there any patterns that would generalize to similar tasks?
Respond with a JSON array of heuristics, each with:
- "tier": "strategic" | "procedural" | "tool"
- "pattern": When/what this applies to (use {{variable}} placeholders)
- "strategy": What to do
- "steps": (optional, for procedural only) List of step strings
"""
DISTILL_SCHEMA: dict[str, Any] = {
"type": "object",
"properties": {
"heuristics": {
"type": "array",
"items": {
"type": "object",
"properties": {
"tier": {
"type": "string",
"enum": ["strategic", "procedural", "tool"],
},
"pattern": {"type": "string"},
"strategy": {"type": "string"},
"steps": {
"type": "array",
"items": {"type": "string"},
},
},
"required": ["tier", "pattern", "strategy"],
},
}
},
"required": ["heuristics"],
}
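# Illustrative response that satisfies DISTILL_SCHEMA (hypothetical values):
# {"heuristics": [
#     {"tier": "tool",
#      "pattern": "When using action {action_name}",
#      "strategy": "Quote multi-word queries to avoid partial matches"},
#     {"tier": "procedural",
#      "pattern": "To locate {item} in {environment}",
#      "strategy": "Follow these steps",
#      "steps": ["Check {most_likely_location} first", "Widen the search"]}
# ]}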
# ---------------------------------------------------------------------------
# Merge / Dedup Prompts
# ---------------------------------------------------------------------------
MERGE_SYSTEM_PROMPT = """\
You are a HEURISTIC DEDUPLICATOR. Given a list of heuristics, merge any that
are semantically similar into a single, more general heuristic.
Rules:
- If two heuristics describe the same strategy for similar situations, MERGE them
- The merged heuristic should be MORE general (wider applicability)
- Keep the higher Q-value when merging
- Preserve concrete action names and step details
- Do NOT merge heuristics from different tiers
"""
MERGE_PROMPT = """\
## Heuristics to Merge/Deduplicate
{heuristics_json}
Return a JSON array of the deduplicated heuristics. If two are similar,
combine them into one. Keep all unique heuristics.
"""
# ---------------------------------------------------------------------------
# Optimizer Class
# ---------------------------------------------------------------------------
class HeuristicOptimizer:
"""
Extracts reusable heuristics from high-reward trajectories and manages
the heuristic library (dedup, merge, Q-value updates).
This is the "learning" module — it reads trajectories from Experience Replay
and produces heuristics that update the Actor's memory.
The optimization loop (called by Orchestrator after each task):
1. Get top trajectories from Experience Replay
2. Distill each into candidate heuristics via LLM
3. Merge/deduplicate with existing heuristic library
4. Update Q-values based on usage success/failure
5. Push updated heuristics to Actor's memory tiers
Args:
llm: LLM backend for distillation (can be same or different from Actor/Critic)
min_reward_threshold: Minimum cumulative reward to consider a trajectory
max_heuristics_per_tier: Cap on heuristics per tier to prevent context bloat
"""
def __init__(
self,
llm: LLMBackend,
min_reward_threshold: float = 1.0,
max_heuristics_per_tier: int = 20,
):
self.llm = llm
self.min_reward_threshold = min_reward_threshold
self.max_heuristics_per_tier = max_heuristics_per_tier
self.heuristic_library: list[Heuristic] = []
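    # Illustrative usage (sketch; ``backend`` is any LLMBackend and
    # ``top_trajectories`` is a list[Trajectory], e.g. pulled from
    # Experience Replay by the Orchestrator):
    #
    #     optimizer = HeuristicOptimizer(llm=backend)
    #     library = optimizer.optimize(top_trajectories)
    #     tool_tips = optimizer.get_heuristics_by_tier(MemoryTier.TOOL)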
# ------------------------------------------------------------------
# Core: Distill Trajectory → Heuristics
# ------------------------------------------------------------------
def distill_trajectory(
self,
trajectory: Trajectory,
existing_heuristics: list[Heuristic] | None = None,
) -> list[Heuristic]:
"""
Extract heuristics from a single trajectory via LLM distillation.
Uses the CER (arxiv:2506.06698) distillation prompt pattern:
- Abstract away specifics with {variable} placeholders
- Separate into Dynamics (what was learned) and Skills (how to act)
- Skip heuristics that duplicate existing ones
"""
if trajectory.cumulative_reward < self.min_reward_threshold:
logger.info(
f"Optimizer: Skipping trajectory {trajectory.id} "
f"(reward={trajectory.cumulative_reward:.2f} < threshold)"
)
return []
existing = existing_heuristics or self.heuristic_library
# Format trajectory steps for the prompt
step_lines = []
for step in trajectory.steps:
score_info = ""
if step.score is not None:
score_info = (
f" → Φ: {step.score.phi_before:.1f}{step.score.phi_after:.1f} "
f"(Δ={step.score.delta:+.2f})"
)
step_lines.append(
f"Step {step.step_index}: "
f"Action={step.action.name}({json.dumps(step.action.params, default=str)})\n"
f" Thought: {step.action.thought[:150]}\n"
f" State before: {step.state_before.describe()[:200]}\n"
f" State after: {step.state_after.describe()[:200]}\n"
f" Score{score_info}"
)
existing_str = "None" if not existing else "\n".join(
f"- [{h.tier.value}] {h.pattern}: {h.strategy}" for h in existing[:20]
)
messages = [
ChatMessage(role="system", content=DISTILL_SYSTEM_PROMPT),
ChatMessage(role="user", content=DISTILL_TRAJECTORY_PROMPT.format(
task_description=trajectory.task_description,
purpose=trajectory.purpose,
num_steps=len(trajectory.steps),
success_rate=trajectory.success_rate,
cumulative_reward=trajectory.cumulative_reward,
total_delta=trajectory.total_delta,
trajectory_steps="\n\n".join(step_lines),
existing_heuristics=existing_str,
)),
]
from purpose_agent.robust_parser import parse_optimizer_response
try:
result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA)
except Exception:
raw = self.llm.generate(messages, temperature=0.5, max_tokens=2000)
result = parse_optimizer_response(raw)
new_heuristics = []
for h_data in result.get("heuristics", []):
tier_str = h_data.get("tier", "strategic")
try:
tier = MemoryTier(tier_str)
except ValueError:
tier = MemoryTier.STRATEGIC
heuristic = Heuristic(
pattern=h_data.get("pattern", ""),
strategy=h_data.get("strategy", ""),
steps=h_data.get("steps", []),
tier=tier,
source_trajectory_id=trajectory.id,
q_value=trajectory.success_rate, # Initial Q from trajectory success
)
new_heuristics.append(heuristic)
logger.info(
f"Optimizer: Distilled {len(new_heuristics)} heuristics from "
f"trajectory {trajectory.id}"
)
return new_heuristics
# ------------------------------------------------------------------
# Merge & Deduplicate
# ------------------------------------------------------------------
def merge_heuristics(
self,
new_heuristics: list[Heuristic],
) -> list[Heuristic]:
"""
Merge new heuristics into the library, deduplicating similar ones.
Per MUSE (arxiv:2510.08002) post-task distillation:
- Merge similar heuristics into more general ones
- Keep the higher Q-value
- Cap per-tier to prevent context bloat
"""
# Add new heuristics to library
combined = self.heuristic_library + new_heuristics
if not combined:
return []
# Group by tier
by_tier: dict[MemoryTier, list[Heuristic]] = {}
for h in combined:
by_tier.setdefault(h.tier, []).append(h)
# Deduplicate within each tier
merged_library: list[Heuristic] = []
for tier, heuristics in by_tier.items():
if len(heuristics) <= self.max_heuristics_per_tier:
merged_library.extend(heuristics)
continue
# Use LLM to merge if over capacity
try:
merged = self._llm_merge(heuristics, tier)
merged_library.extend(merged[:self.max_heuristics_per_tier])
except Exception as e:
logger.warning(f"Optimizer: LLM merge failed ({e}), using Q-value sort")
# Fallback: keep highest Q-value heuristics
heuristics.sort(key=lambda h: -h.q_value)
merged_library.extend(heuristics[:self.max_heuristics_per_tier])
self.heuristic_library = merged_library
logger.info(
f"Optimizer: Library updated — {len(self.heuristic_library)} heuristics "
f"({sum(1 for h in self.heuristic_library if h.tier == MemoryTier.STRATEGIC)} strategic, "
f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.PROCEDURAL)} procedural, "
f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.TOOL)} tool)"
)
return self.heuristic_library
def _llm_merge(
self,
heuristics: list[Heuristic],
tier: MemoryTier,
) -> list[Heuristic]:
"""Use LLM to merge similar heuristics."""
h_dicts = [
{
"id": h.id,
"pattern": h.pattern,
"strategy": h.strategy,
"steps": h.steps,
"q_value": h.q_value,
}
for h in heuristics
]
messages = [
ChatMessage(role="system", content=MERGE_SYSTEM_PROMPT),
ChatMessage(role="user", content=MERGE_PROMPT.format(
heuristics_json=json.dumps(h_dicts, indent=2)
)),
]
from purpose_agent.robust_parser import parse_optimizer_response
try:
result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA)
except Exception:
raw = self.llm.generate(messages, temperature=0.5, max_tokens=2000)
result = parse_optimizer_response(raw)
merged = []
for h_data in result.get("heuristics", []):
merged.append(Heuristic(
pattern=h_data.get("pattern", ""),
strategy=h_data.get("strategy", ""),
steps=h_data.get("steps", []),
tier=tier,
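                # Carry over the best Q-value among source heuristics whose
                # pattern survived verbatim; reworded/merged patterns fall
                # back to a neutral 0.5 prior.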
q_value=max(
(h.q_value for h in heuristics
if h.pattern == h_data.get("pattern")),
default=0.5,
),
))
return merged
# ------------------------------------------------------------------
# Q-Value Management
# ------------------------------------------------------------------
def update_heuristic_usage(
self,
heuristic_id: str,
was_successful: bool,
alpha: float = 0.1,
) -> None:
"""
Update a heuristic's Q-value based on whether it helped.
Called by the Orchestrator when a heuristic was in the Actor's
context and the task succeeded/failed.
"""
for h in self.heuristic_library:
if h.id == heuristic_id:
h.times_used += 1
if was_successful:
h.times_succeeded += 1
reward = 1.0 if was_successful else 0.0
h.update_q_value(reward, alpha=alpha)
logger.debug(
f"Optimizer: Heuristic {heuristic_id} updated "
f"(success={was_successful}, q={h.q_value:.3f})"
)
return
def get_heuristics_by_tier(self, tier: MemoryTier) -> list[Heuristic]:
"""Get all heuristics for a specific memory tier, sorted by Q-value."""
return sorted(
[h for h in self.heuristic_library if h.tier == tier],
key=lambda h: -h.q_value,
)
def prune_low_quality(self, min_q: float = 0.2, min_uses: int = 3) -> int:
"""Remove heuristics that have been tried and consistently fail."""
before = len(self.heuristic_library)
self.heuristic_library = [
h for h in self.heuristic_library
if h.times_used < min_uses or h.q_value >= min_q
]
pruned = before - len(self.heuristic_library)
if pruned:
logger.info(f"Optimizer: Pruned {pruned} low-quality heuristics")
return pruned
# ------------------------------------------------------------------
# Full Optimization Cycle
# ------------------------------------------------------------------
def optimize(
self,
trajectories: list[Trajectory],
) -> list[Heuristic]:
"""
Run the full optimization cycle:
1. Filter trajectories by minimum reward
2. Distill each into candidate heuristics
3. Merge with existing library
4. Prune low-quality heuristics
Returns the updated heuristic library.
"""
all_new: list[Heuristic] = []
for traj in trajectories:
if traj.cumulative_reward >= self.min_reward_threshold:
new = self.distill_trajectory(traj, self.heuristic_library)
all_new.extend(new)
if all_new:
self.merge_heuristics(all_new)
self.prune_low_quality()
logger.info(
f"Optimizer: Cycle complete — processed {len(trajectories)} trajectories, "
f"library size: {len(self.heuristic_library)}"
)
return self.heuristic_library
# ------------------------------------------------------------------
# Fallback Parser
# ------------------------------------------------------------------
@staticmethod
def _parse_distillation_text(raw: str) -> dict[str, Any]:
"""Best-effort extraction of heuristics from free-form text."""
import re
heuristics = []
# Try to find JSON array in text
json_match = re.search(r'\[.*\]', raw, re.DOTALL)
if json_match:
try:
parsed = json.loads(json_match.group())
if isinstance(parsed, list):
return {"heuristics": parsed}
except json.JSONDecodeError:
pass
# Fall back to extracting patterns from text
pattern_matches = re.findall(
r'(?:pattern|when|if)\s*[:\-]\s*(.+?)(?:\n|$)',
raw, re.IGNORECASE
)
strategy_matches = re.findall(
r'(?:strategy|do|then)\s*[:\-]\s*(.+?)(?:\n|$)',
raw, re.IGNORECASE
)
for pattern, strategy in zip(pattern_matches, strategy_matches):
heuristics.append({
"tier": "strategic",
"pattern": pattern.strip(),
"strategy": strategy.strip(),
})
return {"heuristics": heuristics}