narcolepticchicken commited on
Commit
2551fef
·
verified ·
1 Parent(s): 33a5f28

Upload aco/retry_optimizer.py

Browse files
Files changed (1) hide show
  1. aco/retry_optimizer.py +224 -0
aco/retry_optimizer.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Retry and Recovery Optimizer - Module 8.
2
+
3
+ Avoids blind retry loops. For failures, decides:
4
+ - retry same approach
5
+ - retry with changed prompt
6
+ - repair tool call
7
+ - retrieve more context
8
+ - switch model
9
+ - ask clarification
10
+ - call verifier
11
+ - mark BLOCKED
12
+ - terminate
13
+
14
+ Uses trace-based recovery policies.
15
+ """
16
+
17
+ from typing import Dict, List, Optional, Any
18
+ from dataclasses import dataclass
19
+ from enum import Enum
20
+
21
+ from .trace_schema import Outcome, FailureTag, TraceStep, TaskType
22
+ from .config import ACOConfig
23
+
24
+
25
+ class RecoveryAction(Enum):
26
+ RETRY_SAME = "retry_same"
27
+ RETRY_CHANGED_PROMPT = "retry_changed_prompt"
28
+ REPAIR_TOOL = "repair_tool"
29
+ RETRIEVE_MORE_CONTEXT = "retrieve_more_context"
30
+ SWITCH_MODEL = "switch_model"
31
+ ASK_CLARIFICATION = "ask_clarification"
32
+ CALL_VERIFIER = "call_verifier"
33
+ MARK_BLOCKED = "mark_blocked"
34
+ TERMINATE = "terminate"
35
+ SKIP_AND_CONTINUE = "skip_and_continue"
36
+
37
+
38
+ @dataclass
39
+ class RecoveryDecision:
40
+ action: RecoveryAction
41
+ reasoning: str
42
+ confidence: float
43
+ new_model_tier: Optional[int] = None
44
+ context_additions: Optional[List[str]] = None
45
+ prompt_changes: Optional[Dict[str, str]] = None
46
+
47
+
48
+ class RetryRecoveryOptimizer:
49
+ """Intelligently decides how to recover from failures."""
50
+
51
+ # Max retries per recovery type
52
+ MAX_RETRY_SAME = 1
53
+ MAX_RETRY_CHANGED = 2
54
+ MAX_REPAIR_TOOL = 2
55
+ MAX_RETRIEVE_CONTEXT = 1
56
+ MAX_SWITCH_MODEL = 2
57
+
58
+ # Failure pattern -> preferred recovery action
59
+ FAILURE_RECOVERY_MAP = {
60
+ FailureTag.MODEL_TOO_WEAK: RecoveryAction.SWITCH_MODEL,
61
+ FailureTag.CONTEXT_TOO_SMALL: RecoveryAction.RETRIEVE_MORE_CONTEXT,
62
+ FailureTag.TOOL_FAILED: RecoveryAction.REPAIR_TOOL,
63
+ FailureTag.TOOL_UNNECESSARY: RecoveryAction.SKIP_AND_CONTINUE,
64
+ FailureTag.TOOL_MISSED: RecoveryAction.RETRY_CHANGED_PROMPT,
65
+ FailureTag.RETRY_LOOP: RecoveryAction.MARK_BLOCKED,
66
+ FailureTag.CACHE_BREAK: RecoveryAction.RETRY_SAME,
67
+ FailureTag.HALLUCINATION: RecoveryAction.CALL_VERIFIER,
68
+ FailureTag.TIMEOUT: RecoveryAction.SWITCH_MODEL,
69
+ FailureTag.COST_EXCEEDED: RecoveryAction.TERMINATE,
70
+ FailureTag.UNSAFE_CHEAP_MODEL: RecoveryAction.SWITCH_MODEL,
71
+ FailureTag.MISSED_ESCALATION: RecoveryAction.SWITCH_MODEL,
72
+ FailureTag.VERIFIER_FALSE_PASS: RecoveryAction.RETRY_CHANGED_PROMPT,
73
+ FailureTag.VERIFIER_FALSE_REJECT: RecoveryAction.RETRY_SAME,
74
+ }
75
+
76
+ def __init__(self, config: Optional[ACOConfig] = None):
77
+ self.config = config or ACOConfig()
78
+ self.retry_counts: Dict[str, int] = {} # failure_tag -> count
79
+ self.recovery_stats: Dict[str, Dict] = {}
80
+
81
+ def decide_recovery(
82
+ self,
83
+ task_type: TaskType,
84
+ current_step: TraceStep,
85
+ failure_tags: List[FailureTag],
86
+ total_cost_so_far: float,
87
+ predicted_cost: float,
88
+ current_tier: int,
89
+ step_number: int,
90
+ trace_history: Optional[List[TraceStep]] = None,
91
+ ) -> RecoveryDecision:
92
+ """Decide recovery action based on failure analysis."""
93
+
94
+ history = trace_history or []
95
+
96
+ # Count retries in trace
97
+ recent_retries = sum(1 for s in history[-5:] if s.retry_count > 0)
98
+ total_retries = sum(s.retry_count for s in history)
99
+
100
+ # Detect retry loops
101
+ if recent_retries >= 3:
102
+ return RecoveryDecision(
103
+ action=RecoveryAction.MARK_BLOCKED,
104
+ reasoning=f"Retry loop detected: {recent_retries} retries in last 5 steps",
105
+ confidence=0.9,
106
+ )
107
+
108
+ # Cost escalation check
109
+ cost_ratio = total_cost_so_far / max(predicted_cost, 0.001)
110
+ if cost_ratio > self.config.doom_max_cost_ratio * 1.5:
111
+ return RecoveryDecision(
112
+ action=RecoveryAction.TERMINATE,
113
+ reasoning=f"Cost exceeded {self.config.doom_max_cost_ratio * 1.5}x predicted cost ({total_cost_so_far:.4f} vs {predicted_cost:.4f})",
114
+ confidence=0.85,
115
+ )
116
+
117
+ # Analyze primary failure tag
118
+ primary_failure = failure_tags[0] if failure_tags else FailureTag.MODEL_TOO_WEAK
119
+ preferred_action = self.FAILURE_RECOVERY_MAP.get(primary_failure, RecoveryAction.RETRY_CHANGED_PROMPT)
120
+
121
+ # Check if we've exhausted this recovery path
122
+ failure_key = f"{primary_failure.value}_{preferred_action.value}"
123
+ current_count = self.retry_counts.get(failure_key, 0)
124
+
125
+ max_map = {
126
+ RecoveryAction.RETRY_SAME: self.MAX_RETRY_SAME,
127
+ RecoveryAction.RETRY_CHANGED_PROMPT: self.MAX_RETRY_CHANGED,
128
+ RecoveryAction.REPAIR_TOOL: self.MAX_REPAIR_TOOL,
129
+ RecoveryAction.RETRIEVE_MORE_CONTEXT: self.MAX_RETRIEVE_CONTEXT,
130
+ RecoveryAction.SWITCH_MODEL: self.MAX_SWITCH_MODEL,
131
+ }
132
+ max_allowed = max_map.get(preferred_action, 1)
133
+
134
+ if current_count >= max_allowed:
135
+ # Escalate to next recovery action
136
+ escalation_chain = [
137
+ RecoveryAction.RETRY_SAME,
138
+ RecoveryAction.RETRY_CHANGED_PROMPT,
139
+ RecoveryAction.REPAIR_TOOL,
140
+ RecoveryAction.RETRIEVE_MORE_CONTEXT,
141
+ RecoveryAction.SWITCH_MODEL,
142
+ RecoveryAction.ASK_CLARIFICATION,
143
+ RecoveryAction.MARK_BLOCKED,
144
+ ]
145
+
146
+ try:
147
+ idx = escalation_chain.index(preferred_action)
148
+ preferred_action = escalation_chain[min(idx + 1, len(escalation_chain) - 1)]
149
+ except ValueError:
150
+ preferred_action = RecoveryAction.MARK_BLOCKED
151
+
152
+ self.retry_counts[failure_key] = current_count + 1
153
+
154
+ # Build decision
155
+ if preferred_action == RecoveryAction.SWITCH_MODEL:
156
+ new_tier = min(current_tier + 1, 5)
157
+ return RecoveryDecision(
158
+ action=preferred_action,
159
+ reasoning=f"Failure: {primary_failure.value}. Escalating from tier {current_tier} to tier {new_tier}",
160
+ confidence=0.8,
161
+ new_model_tier=new_tier,
162
+ )
163
+
164
+ if preferred_action == RecoveryAction.RETRIEVE_MORE_CONTEXT:
165
+ return RecoveryDecision(
166
+ action=preferred_action,
167
+ reasoning=f"Failure: {primary_failure.value}. Adding retrieved context and retrying.",
168
+ confidence=0.75,
169
+ context_additions=["retrieved_docs", "tool_error_logs", "prior_attempt_summary"],
170
+ )
171
+
172
+ if preferred_action == RecoveryAction.REPAIR_TOOL:
173
+ return RecoveryDecision(
174
+ action=preferred_action,
175
+ reasoning=f"Failure: {primary_failure.value}. Repairing tool call parameters.",
176
+ confidence=0.7,
177
+ prompt_changes={"tool_repair": "true", "validate_params": "true"},
178
+ )
179
+
180
+ if preferred_action == RecoveryAction.RETRY_CHANGED_PROMPT:
181
+ return RecoveryDecision(
182
+ action=preferred_action,
183
+ reasoning=f"Failure: {primary_failure.value}. Retrying with modified prompt strategy.",
184
+ confidence=0.6,
185
+ prompt_changes={"add_examples": "true", "increase_temperature": "0.3"},
186
+ )
187
+
188
+ if preferred_action == RecoveryAction.TERMINATE:
189
+ return RecoveryDecision(
190
+ action=preferred_action,
191
+ reasoning=f"Failure: {primary_failure.value}. Cost ratio {cost_ratio:.1f}x. Terminating.",
192
+ confidence=0.9,
193
+ )
194
+
195
+ if preferred_action == RecoveryAction.MARK_BLOCKED:
196
+ return RecoveryDecision(
197
+ action=preferred_action,
198
+ reasoning=f"Failure: {primary_failure.value}. Exhausted recovery options. Marking BLOCKED.",
199
+ confidence=0.85,
200
+ )
201
+
202
+ return RecoveryDecision(
203
+ action=preferred_action,
204
+ reasoning=f"Failure: {primary_failure.value}. Attempting recovery via {preferred_action.value}.",
205
+ confidence=0.6,
206
+ )
207
+
208
+ def record_recovery_outcome(
209
+ self,
210
+ failure_tag: FailureTag,
211
+ action: RecoveryAction,
212
+ succeeded: bool,
213
+ cost_delta: float,
214
+ ) -> None:
215
+ """Record outcome for policy improvement."""
216
+ key = f"{failure_tag.value}_{action.value}"
217
+ stats = self.recovery_stats.setdefault(key, {
218
+ "attempts": 0, "successes": 0, "total_cost_delta": 0.0,
219
+ })
220
+ stats["attempts"] += 1
221
+ if succeeded:
222
+ stats["successes"] += 1
223
+ stats["total_cost_delta"] += cost_delta
224
+ stats["success_rate"] = stats["successes"] / stats["attempts"]