""" Heuristic Optimizer — Extracts "winning heuristics" from high-reward trajectories. This is the self-improvement engine. It takes successful trajectories and distills them into reusable heuristics that update the agent's long-term memory. The key insight (from CER arxiv:2506.06698 and MUSE arxiv:2510.08002): - Don't store raw trajectories in the prompt (context bloat) - DISTILL them into abstract, reusable patterns - Use {variable} placeholders so heuristics generalize - Deduplicate and merge similar heuristics to prevent memory drift The Optimizer produces three types of heuristics (MUSE 3-tier): 1. STRATEGIC: High-level pairs (e.g., "When stuck on X, try Y") 2. PROCEDURAL: Step-by-step SOPs for specific task patterns 3. TOOL: Per-action tips based on observed usage patterns """ from __future__ import annotations import json import logging from typing import Any from purpose_agent.types import ( Heuristic, MemoryTier, Trajectory, TrajectoryStep, ) from purpose_agent.llm_backend import ChatMessage, LLMBackend logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Distillation Prompts (inspired by CER Appendix A.1 + MUSE Section 3.2) # --------------------------------------------------------------------------- DISTILL_SYSTEM_PROMPT = """\ You are a HEURISTIC EXTRACTOR. Given a successful task trajectory, you extract reusable lessons that will help an agent perform better on FUTURE similar tasks. ## Output Format You produce three types of heuristics: ### 1. STRATEGIC (high-level wisdom) Format: {"pattern": "When ", "strategy": "Do "} - Abstract away specific details — use {variable} placeholders - Focus on dilemmas and decision points, not routine steps - Example: {"pattern": "When facing {task_type} with multiple valid approaches", "strategy": "Start with the simplest approach that could work, escalate only if it fails"} ### 2. PROCEDURAL (step-by-step SOPs) Format: {"pattern": "To accomplish {task_pattern}", "strategy": "Follow these steps", "steps": ["Step 1: ...", "Step 2: ..."]} - Include concrete action names and parameter patterns - Use {variable} placeholders for task-specific values - Example: {"pattern": "To search for {item} in {environment}", "steps": ["Check {most_likely_location} first", "If not found, expand search radius", ...]} ### 3. TOOL (per-action tips) Format: {"pattern": "When using action {action_name}", "strategy": "Remember to {tip}"} - Based on action successes and failures in the trajectory - Focus on non-obvious gotchas and best practices """ DISTILL_TRAJECTORY_PROMPT = """\ ## Task Description {task_description} ## Purpose {purpose} ## Trajectory Summary Total steps: {num_steps} Success rate: {success_rate:.1%} Cumulative reward: {cumulative_reward:.2f} Net state improvement: {total_delta:.2f} ## Step-by-Step Trajectory {trajectory_steps} ## Existing Heuristics (do NOT duplicate these) {existing_heuristics} Extract the winning heuristics from this trajectory. Focus on: 1. What decisions led to the highest-scoring steps? 2. Were there any mistakes that were corrected? What was learned? 3. Are there any patterns that would generalize to similar tasks? Respond with a JSON array of heuristics, each with: - "tier": "strategic" | "procedural" | "tool" - "pattern": When/what this applies to (use {{variable}} placeholders) - "strategy": What to do - "steps": (optional, for procedural only) List of step strings """ DISTILL_SCHEMA: dict[str, Any] = { "type": "object", "properties": { "heuristics": { "type": "array", "items": { "type": "object", "properties": { "tier": { "type": "string", "enum": ["strategic", "procedural", "tool"], }, "pattern": {"type": "string"}, "strategy": {"type": "string"}, "steps": { "type": "array", "items": {"type": "string"}, }, }, "required": ["tier", "pattern", "strategy"], }, } }, "required": ["heuristics"], } # --------------------------------------------------------------------------- # Merge / Dedup Prompts # --------------------------------------------------------------------------- MERGE_SYSTEM_PROMPT = """\ You are a HEURISTIC DEDUPLICATOR. Given a list of heuristics, merge any that are semantically similar into a single, more general heuristic. Rules: - If two heuristics describe the same strategy for similar situations, MERGE them - The merged heuristic should be MORE general (wider applicability) - Keep the higher Q-value when merging - Preserve concrete action names and step details - Do NOT merge heuristics from different tiers """ MERGE_PROMPT = """\ ## Heuristics to Merge/Deduplicate {heuristics_json} Return a JSON array of the deduplicated heuristics. If two are similar, combine them into one. Keep all unique heuristics. """ # --------------------------------------------------------------------------- # Optimizer Class # --------------------------------------------------------------------------- class HeuristicOptimizer: """ Extracts reusable heuristics from high-reward trajectories and manages the heuristic library (dedup, merge, Q-value updates). This is the "learning" module — it reads trajectories from Experience Replay and produces heuristics that update the Actor's memory. The optimization loop (called by Orchestrator after each task): 1. Get top trajectories from Experience Replay 2. Distill each into candidate heuristics via LLM 3. Merge/deduplicate with existing heuristic library 4. Update Q-values based on usage success/failure 5. Push updated heuristics to Actor's memory tiers Args: llm: LLM backend for distillation (can be same or different from Actor/Critic) min_reward_threshold: Minimum cumulative reward to consider a trajectory max_heuristics_per_tier: Cap on heuristics per tier to prevent context bloat merge_similarity_threshold: How similar two heuristics must be to merge """ def __init__( self, llm: LLMBackend, min_reward_threshold: float = 1.0, max_heuristics_per_tier: int = 20, ): self.llm = llm self.min_reward_threshold = min_reward_threshold self.max_heuristics_per_tier = max_heuristics_per_tier self.heuristic_library: list[Heuristic] = [] # ------------------------------------------------------------------ # Core: Distill Trajectory → Heuristics # ------------------------------------------------------------------ def distill_trajectory( self, trajectory: Trajectory, existing_heuristics: list[Heuristic] | None = None, ) -> list[Heuristic]: """ Extract heuristics from a single trajectory via LLM distillation. Uses the CER (arxiv:2506.06698) distillation prompt pattern: - Abstract away specifics with {variable} placeholders - Separate into Dynamics (what was learned) and Skills (how to act) - Skip heuristics that duplicate existing ones """ if trajectory.cumulative_reward < self.min_reward_threshold: logger.info( f"Optimizer: Skipping trajectory {trajectory.id} " f"(reward={trajectory.cumulative_reward:.2f} < threshold)" ) return [] existing = existing_heuristics or self.heuristic_library # Format trajectory steps for the prompt step_lines = [] for step in trajectory.steps: score_info = "" if step.score is not None: score_info = ( f" → Φ: {step.score.phi_before:.1f}→{step.score.phi_after:.1f} " f"(Δ={step.score.delta:+.2f})" ) step_lines.append( f"Step {step.step_index}: " f"Action={step.action.name}({json.dumps(step.action.params, default=str)})\n" f" Thought: {step.action.thought[:150]}\n" f" State before: {step.state_before.describe()[:200]}\n" f" State after: {step.state_after.describe()[:200]}\n" f" Score{score_info}" ) existing_str = "None" if not existing else "\n".join( f"- [{h.tier.value}] {h.pattern}: {h.strategy}" for h in existing[:20] ) messages = [ ChatMessage(role="system", content=DISTILL_SYSTEM_PROMPT), ChatMessage(role="user", content=DISTILL_TRAJECTORY_PROMPT.format( task_description=trajectory.task_description, purpose=trajectory.purpose, num_steps=len(trajectory.steps), success_rate=trajectory.success_rate, cumulative_reward=trajectory.cumulative_reward, total_delta=trajectory.total_delta, trajectory_steps="\n\n".join(step_lines), existing_heuristics=existing_str, )), ] try: result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA) except Exception as e: logger.error(f"Optimizer: Distillation failed ({e}), attempting text fallback") raw = self.llm.generate(messages, temperature=0.5) result = self._parse_distillation_text(raw) new_heuristics = [] for h_data in result.get("heuristics", []): tier_str = h_data.get("tier", "strategic") try: tier = MemoryTier(tier_str) except ValueError: tier = MemoryTier.STRATEGIC heuristic = Heuristic( pattern=h_data.get("pattern", ""), strategy=h_data.get("strategy", ""), steps=h_data.get("steps", []), tier=tier, source_trajectory_id=trajectory.id, q_value=trajectory.success_rate, # Initial Q from trajectory success ) new_heuristics.append(heuristic) logger.info( f"Optimizer: Distilled {len(new_heuristics)} heuristics from " f"trajectory {trajectory.id}" ) return new_heuristics # ------------------------------------------------------------------ # Merge & Deduplicate # ------------------------------------------------------------------ def merge_heuristics( self, new_heuristics: list[Heuristic], ) -> list[Heuristic]: """ Merge new heuristics into the library, deduplicating similar ones. Per MUSE (arxiv:2510.08002) post-task distillation: - Merge similar heuristics into more general ones - Keep the higher Q-value - Cap per-tier to prevent context bloat """ # Add new heuristics to library combined = self.heuristic_library + new_heuristics if not combined: return [] # Group by tier by_tier: dict[MemoryTier, list[Heuristic]] = {} for h in combined: by_tier.setdefault(h.tier, []).append(h) # Deduplicate within each tier merged_library: list[Heuristic] = [] for tier, heuristics in by_tier.items(): if len(heuristics) <= self.max_heuristics_per_tier: merged_library.extend(heuristics) continue # Use LLM to merge if over capacity try: merged = self._llm_merge(heuristics, tier) merged_library.extend(merged[:self.max_heuristics_per_tier]) except Exception as e: logger.warning(f"Optimizer: LLM merge failed ({e}), using Q-value sort") # Fallback: keep highest Q-value heuristics heuristics.sort(key=lambda h: -h.q_value) merged_library.extend(heuristics[:self.max_heuristics_per_tier]) self.heuristic_library = merged_library logger.info( f"Optimizer: Library updated — {len(self.heuristic_library)} heuristics " f"({sum(1 for h in self.heuristic_library if h.tier == MemoryTier.STRATEGIC)} strategic, " f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.PROCEDURAL)} procedural, " f"{sum(1 for h in self.heuristic_library if h.tier == MemoryTier.TOOL)} tool)" ) return self.heuristic_library def _llm_merge( self, heuristics: list[Heuristic], tier: MemoryTier, ) -> list[Heuristic]: """Use LLM to merge similar heuristics.""" h_dicts = [ { "id": h.id, "pattern": h.pattern, "strategy": h.strategy, "steps": h.steps, "q_value": h.q_value, } for h in heuristics ] messages = [ ChatMessage(role="system", content=MERGE_SYSTEM_PROMPT), ChatMessage(role="user", content=MERGE_PROMPT.format( heuristics_json=json.dumps(h_dicts, indent=2) )), ] result = self.llm.generate_structured(messages, schema=DISTILL_SCHEMA) merged = [] for h_data in result.get("heuristics", []): merged.append(Heuristic( pattern=h_data.get("pattern", ""), strategy=h_data.get("strategy", ""), steps=h_data.get("steps", []), tier=tier, q_value=max( (h.q_value for h in heuristics if h.pattern == h_data.get("pattern")), default=0.5, ), )) return merged # ------------------------------------------------------------------ # Q-Value Management # ------------------------------------------------------------------ def update_heuristic_usage( self, heuristic_id: str, was_successful: bool, alpha: float = 0.1, ) -> None: """ Update a heuristic's Q-value based on whether it helped. Called by the Orchestrator when a heuristic was in the Actor's context and the task succeeded/failed. """ for h in self.heuristic_library: if h.id == heuristic_id: h.times_used += 1 if was_successful: h.times_succeeded += 1 reward = 1.0 if was_successful else 0.0 h.update_q_value(reward, alpha=alpha) logger.debug( f"Optimizer: Heuristic {heuristic_id} updated " f"(success={was_successful}, q={h.q_value:.3f})" ) return def get_heuristics_by_tier(self, tier: MemoryTier) -> list[Heuristic]: """Get all heuristics for a specific memory tier, sorted by Q-value.""" return sorted( [h for h in self.heuristic_library if h.tier == tier], key=lambda h: -h.q_value, ) def prune_low_quality(self, min_q: float = 0.2, min_uses: int = 3) -> int: """Remove heuristics that have been tried and consistently fail.""" before = len(self.heuristic_library) self.heuristic_library = [ h for h in self.heuristic_library if h.times_used < min_uses or h.q_value >= min_q ] pruned = before - len(self.heuristic_library) if pruned: logger.info(f"Optimizer: Pruned {pruned} low-quality heuristics") return pruned # ------------------------------------------------------------------ # Full Optimization Cycle # ------------------------------------------------------------------ def optimize( self, trajectories: list[Trajectory], ) -> list[Heuristic]: """ Run the full optimization cycle: 1. Filter trajectories by minimum reward 2. Distill each into candidate heuristics 3. Merge with existing library 4. Prune low-quality heuristics Returns the updated heuristic library. """ all_new: list[Heuristic] = [] for traj in trajectories: if traj.cumulative_reward >= self.min_reward_threshold: new = self.distill_trajectory(traj, self.heuristic_library) all_new.extend(new) if all_new: self.merge_heuristics(all_new) self.prune_low_quality() logger.info( f"Optimizer: Cycle complete — processed {len(trajectories)} trajectories, " f"library size: {len(self.heuristic_library)}" ) return self.heuristic_library # ------------------------------------------------------------------ # Fallback Parser # ------------------------------------------------------------------ @staticmethod def _parse_distillation_text(raw: str) -> dict[str, Any]: """Best-effort extraction of heuristics from free-form text.""" import re heuristics = [] # Try to find JSON array in text json_match = re.search(r'\[.*\]', raw, re.DOTALL) if json_match: try: parsed = json.loads(json_match.group()) if isinstance(parsed, list): return {"heuristics": parsed} except json.JSONDecodeError: pass # Fall back to extracting patterns from text pattern_matches = re.findall( r'(?:pattern|when|if)\s*[:\-]\s*(.+?)(?:\n|$)', raw, re.IGNORECASE ) strategy_matches = re.findall( r'(?:strategy|do|then)\s*[:\-]\s*(.+?)(?:\n|$)', raw, re.IGNORECASE ) for pattern, strategy in zip(pattern_matches, strategy_matches): heuristics.append({ "tier": "strategic", "pattern": pattern.strip(), "strategy": strategy.strip(), }) return {"heuristics": heuristics}