narcolepticchicken commited on
Commit
2ffdfdb
·
verified ·
1 Parent(s): 5d30266

Upload aco/retry_optimizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/retry_optimizer.py +63 -214
aco/retry_optimizer.py CHANGED
@@ -1,224 +1,73 @@
1
- """Retry and Recovery Optimizer - Module 8.
2
-
3
- Avoids blind retry loops. For failures, decides:
4
- - retry same approach
5
- - retry with changed prompt
6
- - repair tool call
7
- - retrieve more context
8
- - switch model
9
- - ask clarification
10
- - call verifier
11
- - mark BLOCKED
12
- - terminate
13
-
14
- Uses trace-based recovery policies.
15
- """
16
-
17
- from typing import Dict, List, Optional, Any
18
  from dataclasses import dataclass
19
- from enum import Enum
20
-
21
- from .trace_schema import Outcome, FailureTag, TraceStep, TaskType
22
- from .config import ACOConfig
23
-
24
-
25
- class RecoveryAction(Enum):
26
- RETRY_SAME = "retry_same"
27
- RETRY_CHANGED_PROMPT = "retry_changed_prompt"
28
- REPAIR_TOOL = "repair_tool"
29
- RETRIEVE_MORE_CONTEXT = "retrieve_more_context"
30
- SWITCH_MODEL = "switch_model"
31
- ASK_CLARIFICATION = "ask_clarification"
32
- CALL_VERIFIER = "call_verifier"
33
- MARK_BLOCKED = "mark_blocked"
34
- TERMINATE = "terminate"
35
- SKIP_AND_CONTINUE = "skip_and_continue"
36
-
37
 
38
  @dataclass
39
- class RecoveryDecision:
40
- action: RecoveryAction
 
41
  reasoning: str
42
- confidence: float
43
- new_model_tier: Optional[int] = None
44
- context_additions: Optional[List[str]] = None
45
- prompt_changes: Optional[Dict[str, str]] = None
46
-
47
-
48
- class RetryRecoveryOptimizer:
49
- """Intelligently decides how to recover from failures."""
50
 
51
- # Max retries per recovery type
52
- MAX_RETRY_SAME = 1
53
- MAX_RETRY_CHANGED = 2
54
- MAX_REPAIR_TOOL = 2
55
- MAX_RETRIEVE_CONTEXT = 1
56
- MAX_SWITCH_MODEL = 2
 
 
 
 
 
 
 
 
57
 
58
- # Failure pattern -> preferred recovery action
59
- FAILURE_RECOVERY_MAP = {
60
- FailureTag.MODEL_TOO_WEAK: RecoveryAction.SWITCH_MODEL,
61
- FailureTag.CONTEXT_TOO_SMALL: RecoveryAction.RETRIEVE_MORE_CONTEXT,
62
- FailureTag.TOOL_FAILED: RecoveryAction.REPAIR_TOOL,
63
- FailureTag.TOOL_UNNECESSARY: RecoveryAction.SKIP_AND_CONTINUE,
64
- FailureTag.TOOL_MISSED: RecoveryAction.RETRY_CHANGED_PROMPT,
65
- FailureTag.RETRY_LOOP: RecoveryAction.MARK_BLOCKED,
66
- FailureTag.CACHE_BREAK: RecoveryAction.RETRY_SAME,
67
- FailureTag.HALLUCINATION: RecoveryAction.CALL_VERIFIER,
68
- FailureTag.TIMEOUT: RecoveryAction.SWITCH_MODEL,
69
- FailureTag.COST_EXCEEDED: RecoveryAction.TERMINATE,
70
- FailureTag.UNSAFE_CHEAP_MODEL: RecoveryAction.SWITCH_MODEL,
71
- FailureTag.MISSED_ESCALATION: RecoveryAction.SWITCH_MODEL,
72
- FailureTag.VERIFIER_FALSE_PASS: RecoveryAction.RETRY_CHANGED_PROMPT,
73
- FailureTag.VERIFIER_FALSE_REJECT: RecoveryAction.RETRY_SAME,
74
- }
75
 
76
- def __init__(self, config: Optional[ACOConfig] = None):
77
- self.config = config or ACOConfig()
78
- self.retry_counts: Dict[str, int] = {} # failure_tag -> count
79
- self.recovery_stats: Dict[str, Dict] = {}
 
 
 
80
 
81
- def decide_recovery(
82
- self,
83
- task_type: TaskType,
84
- current_step: TraceStep,
85
- failure_tags: List[FailureTag],
86
- total_cost_so_far: float,
87
- predicted_cost: float,
88
- current_tier: int,
89
- step_number: int,
90
- trace_history: Optional[List[TraceStep]] = None,
91
- ) -> RecoveryDecision:
92
- """Decide recovery action based on failure analysis."""
93
-
94
- history = trace_history or []
95
-
96
- # Count retries in trace
97
- recent_retries = sum(1 for s in history[-5:] if s.retry_count > 0)
98
- total_retries = sum(s.retry_count for s in history)
99
-
100
- # Detect retry loops
101
- if recent_retries >= 3:
102
- return RecoveryDecision(
103
- action=RecoveryAction.MARK_BLOCKED,
104
- reasoning=f"Retry loop detected: {recent_retries} retries in last 5 steps",
105
- confidence=0.9,
106
- )
107
-
108
- # Cost escalation check
109
- cost_ratio = total_cost_so_far / max(predicted_cost, 0.001)
110
- if cost_ratio > self.config.doom_max_cost_ratio * 1.5:
111
- return RecoveryDecision(
112
- action=RecoveryAction.TERMINATE,
113
- reasoning=f"Cost exceeded {self.config.doom_max_cost_ratio * 1.5}x predicted cost ({total_cost_so_far:.4f} vs {predicted_cost:.4f})",
114
- confidence=0.85,
115
- )
116
-
117
- # Analyze primary failure tag
118
- primary_failure = failure_tags[0] if failure_tags else FailureTag.MODEL_TOO_WEAK
119
- preferred_action = self.FAILURE_RECOVERY_MAP.get(primary_failure, RecoveryAction.RETRY_CHANGED_PROMPT)
120
-
121
- # Check if we've exhausted this recovery path
122
- failure_key = f"{primary_failure.value}_{preferred_action.value}"
123
- current_count = self.retry_counts.get(failure_key, 0)
124
-
125
- max_map = {
126
- RecoveryAction.RETRY_SAME: self.MAX_RETRY_SAME,
127
- RecoveryAction.RETRY_CHANGED_PROMPT: self.MAX_RETRY_CHANGED,
128
- RecoveryAction.REPAIR_TOOL: self.MAX_REPAIR_TOOL,
129
- RecoveryAction.RETRIEVE_MORE_CONTEXT: self.MAX_RETRIEVE_CONTEXT,
130
- RecoveryAction.SWITCH_MODEL: self.MAX_SWITCH_MODEL,
131
- }
132
- max_allowed = max_map.get(preferred_action, 1)
133
-
134
- if current_count >= max_allowed:
135
- # Escalate to next recovery action
136
- escalation_chain = [
137
- RecoveryAction.RETRY_SAME,
138
- RecoveryAction.RETRY_CHANGED_PROMPT,
139
- RecoveryAction.REPAIR_TOOL,
140
- RecoveryAction.RETRIEVE_MORE_CONTEXT,
141
- RecoveryAction.SWITCH_MODEL,
142
- RecoveryAction.ASK_CLARIFICATION,
143
- RecoveryAction.MARK_BLOCKED,
144
- ]
145
-
146
- try:
147
- idx = escalation_chain.index(preferred_action)
148
- preferred_action = escalation_chain[min(idx + 1, len(escalation_chain) - 1)]
149
- except ValueError:
150
- preferred_action = RecoveryAction.MARK_BLOCKED
151
-
152
- self.retry_counts[failure_key] = current_count + 1
153
-
154
- # Build decision
155
- if preferred_action == RecoveryAction.SWITCH_MODEL:
156
- new_tier = min(current_tier + 1, 5)
157
- return RecoveryDecision(
158
- action=preferred_action,
159
- reasoning=f"Failure: {primary_failure.value}. Escalating from tier {current_tier} to tier {new_tier}",
160
- confidence=0.8,
161
- new_model_tier=new_tier,
162
- )
163
-
164
- if preferred_action == RecoveryAction.RETRIEVE_MORE_CONTEXT:
165
- return RecoveryDecision(
166
- action=preferred_action,
167
- reasoning=f"Failure: {primary_failure.value}. Adding retrieved context and retrying.",
168
- confidence=0.75,
169
- context_additions=["retrieved_docs", "tool_error_logs", "prior_attempt_summary"],
170
- )
171
-
172
- if preferred_action == RecoveryAction.REPAIR_TOOL:
173
- return RecoveryDecision(
174
- action=preferred_action,
175
- reasoning=f"Failure: {primary_failure.value}. Repairing tool call parameters.",
176
- confidence=0.7,
177
- prompt_changes={"tool_repair": "true", "validate_params": "true"},
178
- )
179
-
180
- if preferred_action == RecoveryAction.RETRY_CHANGED_PROMPT:
181
- return RecoveryDecision(
182
- action=preferred_action,
183
- reasoning=f"Failure: {primary_failure.value}. Retrying with modified prompt strategy.",
184
- confidence=0.6,
185
- prompt_changes={"add_examples": "true", "increase_temperature": "0.3"},
186
- )
187
-
188
- if preferred_action == RecoveryAction.TERMINATE:
189
- return RecoveryDecision(
190
- action=preferred_action,
191
- reasoning=f"Failure: {primary_failure.value}. Cost ratio {cost_ratio:.1f}x. Terminating.",
192
- confidence=0.9,
193
- )
194
-
195
- if preferred_action == RecoveryAction.MARK_BLOCKED:
196
- return RecoveryDecision(
197
- action=preferred_action,
198
- reasoning=f"Failure: {primary_failure.value}. Exhausted recovery options. Marking BLOCKED.",
199
- confidence=0.85,
200
- )
201
-
202
- return RecoveryDecision(
203
- action=preferred_action,
204
- reasoning=f"Failure: {primary_failure.value}. Attempting recovery via {preferred_action.value}.",
205
- confidence=0.6,
206
  )
207
 
208
- def record_recovery_outcome(
209
- self,
210
- failure_tag: FailureTag,
211
- action: RecoveryAction,
212
- succeeded: bool,
213
- cost_delta: float,
214
- ) -> None:
215
- """Record outcome for policy improvement."""
216
- key = f"{failure_tag.value}_{action.value}"
217
- stats = self.recovery_stats.setdefault(key, {
218
- "attempts": 0, "successes": 0, "total_cost_delta": 0.0,
219
- })
220
- stats["attempts"] += 1
221
- if succeeded:
222
- stats["successes"] += 1
223
- stats["total_cost_delta"] += cost_delta
224
- stats["success_rate"] = stats["successes"] / stats["attempts"]
 
1
+ """Retry and Recovery Optimizer: Maps failure tags to specific recovery actions."""
2
+ from typing import Dict, List, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  @dataclass
6
+ class RecoveryAction:
7
+ action: str # "retry_same","retry_changed_prompt","repair_tool","retrieve_more",
8
+ # "switch_model","ask_clarification","call_verifier","mark_blocked","terminate"
9
  reasoning: str
10
+ new_tier: Optional[int] = None
11
+ additional_context: Optional[str] = None
 
 
 
 
 
 
12
 
13
+ FAILURE_RECOVERY_MAP = {
14
+ "tool_error": {"primary": "repair_tool", "fallback": "retry_changed_prompt"},
15
+ "tool_not_found": {"primary": "retry_changed_prompt", "fallback": "ask_clarification"},
16
+ "timeout": {"primary": "retry_same", "fallback": "switch_model"},
17
+ "context_too_large": {"primary": "retrieve_more", "fallback": "switch_model"},
18
+ "model_refused": {"primary": "retry_changed_prompt", "fallback": "switch_model"},
19
+ "hallucination": {"primary": "call_verifier", "fallback": "retrieve_more"},
20
+ "wrong_answer": {"primary": "switch_model", "fallback": "call_verifier"},
21
+ "incomplete": {"primary": "retrieve_more", "fallback": "retry_changed_prompt"},
22
+ "format_error": {"primary": "retry_changed_prompt", "fallback": "switch_model"},
23
+ "permission_denied": {"primary": "ask_clarification", "fallback": "mark_blocked"},
24
+ "rate_limit": {"primary": "retry_same", "fallback": "switch_model"},
25
+ "unknown": {"primary": "retry_changed_prompt", "fallback": "ask_clarification"},
26
+ }
27
 
28
+ ESCALATION_TIERS = {
29
+ "switch_model": 1, # upgrade by 1 tier
30
+ "call_verifier": 0,
31
+ "mark_blocked": 0,
32
+ }
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ class RetryOptimizer:
35
+ def __init__(self, max_retries: int = 3, max_total_retries: int = 5):
36
+ self.max_retries = max_retries
37
+ self.max_total_retries = max_total_retries
38
+ self.retry_counts: Dict[str, int] = {}
39
+ self.total_retries = 0
40
+ self.recovery_stats = {}
41
 
42
+ def get_recovery(self, failure_tag: str, current_tier: int,
43
+ retry_num: int, previous_actions: List[str] = None,
44
+ run_cost_so_far: float = 0, max_run_cost: float = 5.0) -> RecoveryAction:
45
+ self.retry_counts[failure_tag] = self.retry_counts.get(failure_tag, 0) + 1
46
+ self.total_retries += 1
47
+ # Check if we should terminate
48
+ if self.total_retries >= self.max_total_retries:
49
+ return RecoveryAction("terminate", "max total retries reached")
50
+ if retry_num >= self.max_retries:
51
+ return RecoveryAction("mark_blocked", f"max retries ({self.max_retries}) for this failure")
52
+ if run_cost_so_far >= max_run_cost * 0.8:
53
+ return RecoveryAction("terminate", "approaching cost limit")
54
+ # Get recovery action
55
+ recovery = FAILURE_RECOVERY_MAP.get(failure_tag, FAILURE_RECOVERY_MAP["unknown"])
56
+ action_name = recovery["primary"]
57
+ # Check if primary was already tried
58
+ if previous_actions and action_name in previous_actions:
59
+ action_name = recovery["fallback"]
60
+ # Build action
61
+ new_tier = None
62
+ if action_name == "switch_model":
63
+ new_tier = min(current_tier + ESCALATION_TIERS["switch_model"], 5)
64
+ self.recovery_stats[action_name] = self.recovery_stats.get(action_name, 0) + 1
65
+ return RecoveryAction(
66
+ action=action_name,
67
+ reasoning=f"failure={failure_tag}, retry={retry_num}, action={action_name}",
68
+ new_tier=new_tier,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  )
70
 
71
+ def reset_run(self):
72
+ self.retry_counts = {}
73
+ self.total_retries = 0