| """Retry and Recovery Optimizer: Maps failure tags to specific recovery actions.""" |
| from typing import Dict, List, Optional |
| from dataclasses import dataclass |
|
|
@dataclass
class RecoveryAction:
    """Result of a recovery decision: the action to take and the reason for it.

    Attributes:
        action: Name of the chosen action (for example ``"retry_same"`` or
            ``"terminate"``).
        reasoning: Human-readable explanation of why the action was chosen.
        new_tier: Tier attached to the action, if any; defaults to ``None``.
        additional_context: Optional supplementary text; defaults to ``None``.
    """

    action: str
    reasoning: str
    new_tier: Optional[int] = None
    additional_context: Optional[str] = None
|
|
# Recovery plan per failure tag. Each row is (tag, primary action, fallback
# action); "primary" is the first choice and "fallback" the alternative.
# The "unknown" row serves as the default plan for unrecognised tags.
FAILURE_RECOVERY_MAP: Dict[str, Dict[str, str]] = {
    tag: {"primary": primary, "fallback": fallback}
    for tag, primary, fallback in (
        ("tool_error", "repair_tool", "retry_changed_prompt"),
        ("tool_not_found", "retry_changed_prompt", "ask_clarification"),
        ("timeout", "retry_same", "switch_model"),
        ("context_too_large", "retrieve_more", "switch_model"),
        ("model_refused", "retry_changed_prompt", "switch_model"),
        ("hallucination", "call_verifier", "retrieve_more"),
        ("wrong_answer", "switch_model", "call_verifier"),
        ("incomplete", "retrieve_more", "retry_changed_prompt"),
        ("format_error", "retry_changed_prompt", "switch_model"),
        ("permission_denied", "ask_clarification", "mark_blocked"),
        ("rate_limit", "retry_same", "switch_model"),
        ("unknown", "retry_changed_prompt", "ask_clarification"),
    )
}
|
|
# Tier increment keyed by action name. The value is added to the current tier
# when the action is chosen; the code visible in this file only reads the
# "switch_model" entry.
ESCALATION_TIERS: Dict[str, int] = dict(
    switch_model=1,
    call_verifier=0,
    mark_blocked=0,
)
|
|
class RetryOptimizer:
    """Pick a recovery action for a failure while enforcing retry and cost limits.

    The instance keeps a per-tag call counter (``retry_counts``), an overall
    call counter (``total_retries``) and a tally of the actions it has handed
    out (``recovery_stats``). ``reset_run`` clears the two counters but leaves
    ``recovery_stats`` as it is.
    """

    def __init__(self, max_retries: int = 3, max_total_retries: int = 5):
        """Store the limits and start with empty counters.

        Args:
            max_retries: Bound compared against the ``retry_num`` passed to
                ``get_recovery``.
            max_total_retries: Bound compared against the number of
                ``get_recovery`` calls since construction or the last
                ``reset_run``.
        """
        self.max_retries, self.max_total_retries = max_retries, max_total_retries
        self.retry_counts: Dict[str, int] = {}
        self.total_retries = 0
        self.recovery_stats = {}

    def get_recovery(self, failure_tag: str, current_tier: int,
                     retry_num: int, previous_actions: List[str] = None,
                     run_cost_so_far: float = 0, max_run_cost: float = 5.0) -> RecoveryAction:
        """Return the recovery action for one failure.

        Every call is counted (per tag and overall) before any limit is
        checked. Limits are then evaluated in this order: total retries,
        per-failure retries, cost. If none applies, the plan for
        ``failure_tag`` is looked up (falling back to the ``"unknown"`` plan)
        and its primary action is used, unless that action already appears in
        ``previous_actions``, in which case the fallback action is used.

        Args:
            failure_tag: Key into ``FAILURE_RECOVERY_MAP``; unrecognised tags
                use the ``"unknown"`` entry.
            current_tier: Tier in use when the failure happened.
            retry_num: Retry index supplied by the caller; compared against
                ``max_retries``.
            previous_actions: Actions already attempted; ``None`` or empty
                means nothing has been tried.
            run_cost_so_far: Cost accumulated by the run so far.
            max_run_cost: Cost ceiling; recovery stops at 80% of this value.

        Returns:
            A ``RecoveryAction``. ``new_tier`` is set only for
            ``"switch_model"`` (current tier plus the configured increment,
            capped at 5). Limit outcomes (``"terminate"`` / ``"mark_blocked"``)
            are not recorded in ``recovery_stats``.
        """
        self.retry_counts.setdefault(failure_tag, 0)
        self.retry_counts[failure_tag] += 1
        self.total_retries += 1

        limit_hit = self._limit_action(retry_num, run_cost_so_far, max_run_cost)
        if limit_hit is not None:
            return limit_hit

        plan = FAILURE_RECOVERY_MAP.get(failure_tag, FAILURE_RECOVERY_MAP["unknown"])
        # NOTE(review): the fallback is returned even when it also appears in
        # previous_actions -- confirm that repeating it is intended.
        action_name = (
            plan["fallback"]
            if previous_actions and plan["primary"] in previous_actions
            else plan["primary"]
        )

        new_tier = (
            min(current_tier + ESCALATION_TIERS["switch_model"], 5)
            if action_name == "switch_model"
            else None
        )

        self.recovery_stats[action_name] = 1 + self.recovery_stats.get(action_name, 0)
        return RecoveryAction(
            action=action_name,
            reasoning=f"failure={failure_tag}, retry={retry_num}, action={action_name}",
            new_tier=new_tier,
        )

    def _limit_action(self, retry_num: int, run_cost_so_far: float,
                      max_run_cost: float) -> Optional[RecoveryAction]:
        """Return the stop action for the first limit that is hit, else ``None``."""
        # NOTE(review): total_retries was already incremented for this call, so
        # the call numbered max_total_retries terminates -- confirm intended.
        if self.total_retries >= self.max_total_retries:
            return RecoveryAction(action="terminate", reasoning="max total retries reached")
        if retry_num >= self.max_retries:
            return RecoveryAction(
                action="mark_blocked",
                reasoning=f"max retries ({self.max_retries}) for this failure",
            )
        if run_cost_so_far >= max_run_cost * 0.8:
            return RecoveryAction(action="terminate", reasoning="approaching cost limit")
        return None

    def reset_run(self):
        """Begin a new run: drop the per-tag counts and zero the overall counter."""
        self.retry_counts, self.total_retries = {}, 0
|
|