"""Retry and Recovery Optimizer: Maps failure tags to specific recovery actions.""" from typing import Dict, List, Optional from dataclasses import dataclass @dataclass class RecoveryAction: action: str # "retry_same","retry_changed_prompt","repair_tool","retrieve_more", # "switch_model","ask_clarification","call_verifier","mark_blocked","terminate" reasoning: str new_tier: Optional[int] = None additional_context: Optional[str] = None FAILURE_RECOVERY_MAP = { "tool_error": {"primary": "repair_tool", "fallback": "retry_changed_prompt"}, "tool_not_found": {"primary": "retry_changed_prompt", "fallback": "ask_clarification"}, "timeout": {"primary": "retry_same", "fallback": "switch_model"}, "context_too_large": {"primary": "retrieve_more", "fallback": "switch_model"}, "model_refused": {"primary": "retry_changed_prompt", "fallback": "switch_model"}, "hallucination": {"primary": "call_verifier", "fallback": "retrieve_more"}, "wrong_answer": {"primary": "switch_model", "fallback": "call_verifier"}, "incomplete": {"primary": "retrieve_more", "fallback": "retry_changed_prompt"}, "format_error": {"primary": "retry_changed_prompt", "fallback": "switch_model"}, "permission_denied": {"primary": "ask_clarification", "fallback": "mark_blocked"}, "rate_limit": {"primary": "retry_same", "fallback": "switch_model"}, "unknown": {"primary": "retry_changed_prompt", "fallback": "ask_clarification"}, } ESCALATION_TIERS = { "switch_model": 1, # upgrade by 1 tier "call_verifier": 0, "mark_blocked": 0, } class RetryOptimizer: def __init__(self, max_retries: int = 3, max_total_retries: int = 5): self.max_retries = max_retries self.max_total_retries = max_total_retries self.retry_counts: Dict[str, int] = {} self.total_retries = 0 self.recovery_stats = {} def get_recovery(self, failure_tag: str, current_tier: int, retry_num: int, previous_actions: List[str] = None, run_cost_so_far: float = 0, max_run_cost: float = 5.0) -> RecoveryAction: self.retry_counts[failure_tag] = self.retry_counts.get(failure_tag, 0) + 1 self.total_retries += 1 # Check if we should terminate if self.total_retries >= self.max_total_retries: return RecoveryAction("terminate", "max total retries reached") if retry_num >= self.max_retries: return RecoveryAction("mark_blocked", f"max retries ({self.max_retries}) for this failure") if run_cost_so_far >= max_run_cost * 0.8: return RecoveryAction("terminate", "approaching cost limit") # Get recovery action recovery = FAILURE_RECOVERY_MAP.get(failure_tag, FAILURE_RECOVERY_MAP["unknown"]) action_name = recovery["primary"] # Check if primary was already tried if previous_actions and action_name in previous_actions: action_name = recovery["fallback"] # Build action new_tier = None if action_name == "switch_model": new_tier = min(current_tier + ESCALATION_TIERS["switch_model"], 5) self.recovery_stats[action_name] = self.recovery_stats.get(action_name, 0) + 1 return RecoveryAction( action=action_name, reasoning=f"failure={failure_tag}, retry={retry_num}, action={action_name}", new_tier=new_tier, ) def reset_run(self): self.retry_counts = {} self.total_retries = 0