# agent-cost-optimizer / aco/retry_optimizer.py
# (uploaded via huggingface_hub, commit 2ffdfdb)
"""Retry and Recovery Optimizer: Maps failure tags to specific recovery actions."""
from typing import Dict, List, Optional
from dataclasses import dataclass
@dataclass
class RecoveryAction:
    """A recommended recovery step for a tagged failure.

    Produced by RetryOptimizer.get_recovery(); callers inspect `action`
    to decide what to do next.
    """
    # Allowed action values (see FAILURE_RECOVERY_MAP / hard-stop branches):
    action: str # "retry_same","retry_changed_prompt","repair_tool","retrieve_more",
    # "switch_model","ask_clarification","call_verifier","mark_blocked","terminate"
    # Human-readable explanation of why this action was chosen.
    reasoning: str
    # Model tier to escalate to; only set when action == "switch_model".
    new_tier: Optional[int] = None
    # Extra context to attach on retry — never populated in this module; TODO confirm intended use.
    additional_context: Optional[str] = None
# Policy table: failure tag -> recovery actions. "primary" is tried first;
# if the caller reports the primary was already attempted (via
# previous_actions in get_recovery), the "fallback" is chosen instead.
# Unrecognized tags fall back to the "unknown" entry.
FAILURE_RECOVERY_MAP = {
    "tool_error": {"primary": "repair_tool", "fallback": "retry_changed_prompt"},
    "tool_not_found": {"primary": "retry_changed_prompt", "fallback": "ask_clarification"},
    "timeout": {"primary": "retry_same", "fallback": "switch_model"},
    "context_too_large": {"primary": "retrieve_more", "fallback": "switch_model"},
    "model_refused": {"primary": "retry_changed_prompt", "fallback": "switch_model"},
    "hallucination": {"primary": "call_verifier", "fallback": "retrieve_more"},
    "wrong_answer": {"primary": "switch_model", "fallback": "call_verifier"},
    "incomplete": {"primary": "retrieve_more", "fallback": "retry_changed_prompt"},
    "format_error": {"primary": "retry_changed_prompt", "fallback": "switch_model"},
    "permission_denied": {"primary": "ask_clarification", "fallback": "mark_blocked"},
    "rate_limit": {"primary": "retry_same", "fallback": "switch_model"},
    "unknown": {"primary": "retry_changed_prompt", "fallback": "ask_clarification"},
}
# Tier delta applied when an action escalates the model. Only "switch_model"
# actually changes the tier (+1, capped at 5 in get_recovery); the other
# entries are listed here with a delta of 0.
ESCALATION_TIERS = {
    "switch_model": 1, # upgrade by 1 tier
    "call_verifier": 0,
    "mark_blocked": 0,
}
class RetryOptimizer:
    """Tracks retry attempts and maps failure tags to recovery actions.

    Enforces a per-failure retry cap, a global retry budget, and a run-cost
    ceiling; within those limits it consults FAILURE_RECOVERY_MAP to pick a
    primary (or fallback) recovery action, escalating the model tier when
    the action is "switch_model".
    """

    def __init__(self, max_retries: int = 3, max_total_retries: int = 5):
        self.max_retries = max_retries              # cap per individual failure
        self.max_total_retries = max_total_retries  # cap across the whole run
        self.retry_counts: Dict[str, int] = {}      # failure_tag -> times observed
        self.total_retries = 0                      # retries consumed this run
        self.recovery_stats: Dict[str, int] = {}    # action -> times chosen

    def get_recovery(self, failure_tag: str, current_tier: int,
                     retry_num: int, previous_actions: Optional[List[str]] = None,
                     run_cost_so_far: float = 0, max_run_cost: float = 5.0) -> RecoveryAction:
        """Return the recovery action to take for a failed step.

        Args:
            failure_tag: Key into FAILURE_RECOVERY_MAP; unknown tags use the
                "unknown" policy.
            current_tier: Model tier currently in use (escalation caps at 5).
            retry_num: How many retries this particular failure has consumed.
            previous_actions: Actions already attempted for this failure; if
                the primary action is among them, the fallback is chosen.
            run_cost_so_far: Cost accumulated by the run.
            max_run_cost: Cost budget; we terminate at >= 80% of it.

        Returns:
            A RecoveryAction; `new_tier` is set only for "switch_model".
        """
        # Bookkeeping happens even when we subsequently hard-stop.
        self.retry_counts[failure_tag] = self.retry_counts.get(failure_tag, 0) + 1
        self.total_retries += 1

        # Hard stops: global retry budget, per-failure budget, cost ceiling.
        if self.total_retries >= self.max_total_retries:
            return RecoveryAction("terminate", "max total retries reached")
        if retry_num >= self.max_retries:
            return RecoveryAction("mark_blocked", f"max retries ({self.max_retries}) for this failure")
        if run_cost_so_far >= max_run_cost * 0.8:
            return RecoveryAction("terminate", "approaching cost limit")

        # Look up the policy for this failure; fall back to "unknown".
        recovery = FAILURE_RECOVERY_MAP.get(failure_tag, FAILURE_RECOVERY_MAP["unknown"])
        action_name = recovery["primary"]
        # If the primary was already tried, switch to the fallback.
        if previous_actions and action_name in previous_actions:
            action_name = recovery["fallback"]

        # "switch_model" escalates the tier by the configured delta, capped at 5.
        new_tier = None
        if action_name == "switch_model":
            new_tier = min(current_tier + ESCALATION_TIERS["switch_model"], 5)

        self.recovery_stats[action_name] = self.recovery_stats.get(action_name, 0) + 1
        return RecoveryAction(
            action=action_name,
            reasoning=f"failure={failure_tag}, retry={retry_num}, action={action_name}",
            new_tier=new_tier,
        )

    def reset_run(self):
        """Clear per-run counters.

        NOTE(review): recovery_stats is left intact, so it accumulates across
        runs — presumably intentional; confirm with callers.
        """
        self.retry_counts = {}
        self.total_retries = 0