narcolepticchicken committed on
Commit
99ad299
·
verified ·
1 Parent(s): 33d6b64

Upload aco/doom_detector.py

Browse files
Files changed (1) hide show
  1. aco/doom_detector.py +287 -0
aco/doom_detector.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Early Termination / Doom Detector - Module 10.

Detects runs that are unlikely to succeed without more information or
intervention.

Signals:
- repeated failed tool calls
- no artifact progress
- growing cost without new evidence
- repeated planning
- verifier disagreement
- context confusion
- escalating retries
- model loop behavior
- stagnant context (no new context sources in recent steps)

Actions:
- stop
- mark BLOCKED
- ask one targeted question
- switch strategy
- escalate model
- escalate human
"""
23
+
24
+ from typing import Dict, List, Optional, Any
25
+ from dataclasses import dataclass
26
+ from enum import Enum
27
+
28
+ from .trace_schema import AgentTrace, TraceStep, Outcome, FailureTag
29
+ from .config import ACOConfig
30
+
31
+
32
class DoomAction(Enum):
    """Interventions a doom assessment can recommend for an agent run.

    Each member's value is the lowercase string form of its name.
    """

    # Terminal outcomes.
    STOP = "stop"
    MARK_BLOCKED = "mark_blocked"

    # Recovery attempts.
    ASK_TARGETED_QUESTION = "ask_targeted_question"
    SWITCH_STRATEGY = "switch_strategy"

    # Escalations.
    ESCALATE_MODEL = "escalate_model"
    ESCALATE_HUMAN = "escalate_human"

    # No intervention.
    CONTINUE = "continue"
40
+
41
+
42
@dataclass
class DoomAssessment:
    """Result of one doom evaluation of an agent run."""

    # Intervention chosen for the run.
    action: DoomAction
    # Confidence attached to the chosen action (a float supplied by the caller).
    confidence: float
    # Human-readable explanation of the decision.
    reasoning: str
    # Names of the signals that fired during the evaluation.
    signals_triggered: List[str]
    # Optional free-text guidance accompanying the action.
    recommended_action: Optional[str] = None
    # Optional question to put to the user.
    question_to_ask: Optional[str] = None
50
+
51
+
52
class DoomDetector:
    """Detects doomed agent runs and recommends intervention.

    ``assess`` folds a set of heuristic signals into a single doom score
    (capped at 1.0) and maps that score onto a ``DoomAction``. Every
    assessment produced by ``assess`` is appended to ``assessment_history``
    so that ``get_stats`` can summarise it.
    """

    # Contribution of each signal to the doom score. Some signals are scaled
    # by their intensity before being added (see ``_collect_signals``).
    SIGNAL_WEIGHTS = {
        "repeated_tool_failures": 0.3,
        "no_artifact_progress": 0.25,
        "cost_explosion": 0.3,
        "repeated_planning": 0.2,
        "verifier_disagreement": 0.25,
        "context_confusion": 0.2,
        "escalating_retries": 0.35,
        "model_loop": 0.4,
        "stagnant_context": 0.15,
    }

    # Score bands, in ascending order:
    #   score <  CONTINUE_THRESHOLD         -> CONTINUE
    #   score <  DOOM_THRESHOLD             -> ASK_TARGETED_QUESTION
    #   score <  BLOCKED_THRESHOLD          -> ESCALATE_MODEL / SWITCH_STRATEGY
    #   score <= HUMAN_ESCALATION_THRESHOLD -> MARK_BLOCKED
    #   otherwise                           -> ESCALATE_HUMAN
    CONTINUE_THRESHOLD = 0.3
    DOOM_THRESHOLD = 0.6
    BLOCKED_THRESHOLD = 0.8
    HUMAN_ESCALATION_THRESHOLD = 0.95

    def __init__(self, config: Optional[ACOConfig] = None):
        """Create a detector.

        Args:
            config: Detector configuration; a default ``ACOConfig()`` is
                built when omitted (or falsy).
        """
        self.config = config or ACOConfig()
        self.assessment_history: List[DoomAssessment] = []

    def assess(
        self,
        trace: AgentTrace,
        current_step: TraceStep,
        predicted_cost: float,
        predicted_steps: int,
    ) -> DoomAssessment:
        """Assess whether a run is doomed and what action to take.

        Args:
            trace: The run so far.
            current_step: The step being executed; its model id (if any)
                decides between model escalation and a strategy switch.
            predicted_cost: Expected total cost of the run, used as the
                baseline for the cost-explosion signal.
            predicted_steps: Accepted for interface compatibility; not
                used by the current heuristics.

        Returns:
            The assessment, which is also recorded in
            ``assessment_history``.
        """
        signals, score = self._collect_signals(trace, predicted_cost)
        assessment = self._decide(trace, current_step, signals, score)
        # Record the result; without this get_stats() never has data.
        self.assessment_history.append(assessment)
        return assessment

    def _collect_signals(self, trace: AgentTrace, predicted_cost: float) -> tuple:
        """Evaluate every signal and return ``(signals, score)``.

        ``signals`` lists the names of the signals that fired, in evaluation
        order; ``score`` is the weighted sum of those signals, capped at 1.0.
        """
        signals: List[str] = []
        score = 0.0
        weights = self.SIGNAL_WEIGHTS
        steps = trace.steps

        def trigger(name: str, intensity: float = 1.0) -> None:
            nonlocal score
            signals.append(name)
            score += weights[name] * intensity

        # Signal 1: repeated failed tool calls in the recent window.
        tool_failures = self._count_recent_tool_failures(trace)
        if tool_failures >= 3:
            trigger("repeated_tool_failures", min(tool_failures / 5, 1.0))

        # Signal 2: no artifacts after more than five steps.
        if len(trace.final_artifacts) == 0 and len(steps) > 5:
            trigger("no_artifact_progress")

        # Signal 3: actual cost far above the prediction. The floor on the
        # denominator avoids division by zero for a zero prediction.
        cost_ratio = trace.total_cost_computed / max(predicted_cost, 0.0001)
        if cost_ratio > self.config.doom_max_cost_ratio:
            trigger("cost_explosion", min(cost_ratio / 5, 1.0))

        # Signal 4: repeated planning without execution.
        replan_count = self._count_replanning(trace)
        if replan_count >= 2:
            trigger("repeated_planning", min(replan_count / 4, 1.0))

        # Signal 5: verifier disagreement.
        verifier_disagree = self._count_verifier_disagreement(trace)
        if verifier_disagree >= self.config.doom_verifier_disagreement_threshold:
            trigger("verifier_disagreement", min(verifier_disagree / 4, 1.0))

        # Signal 6: context size swings by more than 5000 tokens across the
        # last three steps.
        if len(steps) >= 3:
            ctx_sizes = [s.context_size_tokens for s in steps[-3:]]
            if max(ctx_sizes) - min(ctx_sizes) > 5000:
                trigger("context_confusion")

        # Signal 7: escalating retries. Require at least one retry: the
        # threshold is 0 for an empty trace (or doom_max_retries == 0), and
        # "0 >= 0" would otherwise report the signal with nothing retried.
        total_retries = trace.total_retries
        retry_threshold = self.config.doom_max_retries * len(steps) * 0.5
        if total_retries > 0 and total_retries >= retry_threshold:
            trigger("escalating_retries", min(total_retries / 10, 1.0))

        # Signal 8: the last four steps repeat the same tool-call patterns.
        if self._detect_model_loop(trace):
            trigger("model_loop")

        # Signal 9: no new context sources in the recent steps.
        if self._is_context_stagnant(trace):
            trigger("stagnant_context")

        return signals, min(score, 1.0)

    def _decide(
        self,
        trace: AgentTrace,
        current_step: TraceStep,
        signals: List[str],
        score: float,
    ) -> DoomAssessment:
        """Map a doom score and its triggering signals onto an assessment."""
        if score < self.CONTINUE_THRESHOLD:
            return DoomAssessment(
                action=DoomAction.CONTINUE,
                # For a healthy run the confidence is the complement of the score.
                confidence=1.0 - score,
                reasoning="Run appears healthy.",
                signals_triggered=signals,
            )

        if score < self.DOOM_THRESHOLD:
            return DoomAssessment(
                action=DoomAction.ASK_TARGETED_QUESTION,
                confidence=score,
                reasoning=f"Early warning: {', '.join(signals)}. Asking targeted question.",
                signals_triggered=signals,
                question_to_ask=self._generate_targeted_question(trace, signals),
            )

        if score < self.BLOCKED_THRESHOLD:
            # Prefer a stronger model when the current one is known and is
            # below the top tier (4); otherwise change strategy.
            model_call = current_step.model_call
            if model_call and model_call.model_id:
                if self._infer_tier(model_call.model_id) < 4:
                    return DoomAssessment(
                        action=DoomAction.ESCALATE_MODEL,
                        confidence=score,
                        reasoning=f"Run struggling ({score:.2f} doom score). Escalating model.",
                        signals_triggered=signals,
                    )

            return DoomAssessment(
                action=DoomAction.SWITCH_STRATEGY,
                confidence=score,
                reasoning=f"Run struggling ({score:.2f} doom score). Switching strategy.",
                signals_triggered=signals,
                recommended_action="Change approach: retrieve more context, simplify task decomposition",
            )

        # High doom score: mark blocked, or escalate to a human when critical.
        if score > self.HUMAN_ESCALATION_THRESHOLD:
            return DoomAssessment(
                action=DoomAction.ESCALATE_HUMAN,
                confidence=score,
                reasoning=f"Critical failure pattern ({score:.2f} doom score). Human escalation required.",
                signals_triggered=signals,
            )

        return DoomAssessment(
            action=DoomAction.MARK_BLOCKED,
            confidence=score,
            reasoning=f"Doom threshold exceeded ({score:.2f}). Marking BLOCKED.",
            signals_triggered=signals,
        )

    def _count_recent_tool_failures(self, trace: AgentTrace, window: int = 5) -> int:
        """Count failed tool calls across the last ``window`` steps."""
        # Slicing already copes with traces shorter than the window.
        return sum(
            1
            for step in trace.steps[-window:]
            for tc in step.tool_calls
            if tc.failed
        )

    def _count_replanning(self, trace: AgentTrace) -> int:
        """Count steps that planned a next move but did nothing else.

        Heuristic: the step has ``planned_next`` set but made no tool calls
        and created no artifacts.
        """
        return sum(
            1
            for step in trace.steps
            if step.planned_next and not step.tool_calls and not step.artifacts_created
        )

    def _count_verifier_disagreement(self, trace: AgentTrace) -> int:
        """Count steps whose verifier results signal disagreement.

        A step counts when it has two or more verifiers with mixed verdicts,
        or exactly one verifier that did not pass.
        """
        disagreements = 0
        for step in trace.steps:
            verifiers = step.verifier_calls
            if len(verifiers) >= 2:
                results = [v.passed for v in verifiers]
                if any(results) and not all(results):
                    disagreements += 1
            elif len(verifiers) == 1 and not verifiers[0].passed:
                # Single verifier rejected the step.
                disagreements += 1
        return disagreements

    def _detect_model_loop(self, trace: AgentTrace) -> bool:
        """Return True when the last four steps use at most two distinct
        tool-call patterns (the ordered tuple of tool names per step).

        NOTE(review): four steps with no tool calls at all share the empty
        pattern and therefore also count as a loop - confirm this is intended.
        """
        if len(trace.steps) < 4:
            return False
        patterns = {
            tuple(tc.tool_name for tc in step.tool_calls)
            for step in trace.steps[-4:]
        }
        return len(patterns) <= 2

    def _is_context_stagnant(self, trace: AgentTrace, window: int = 3) -> bool:
        """Return True when none of the last ``window`` steps introduced a
        context source absent from the step immediately before it.

        Returns False when the trace has fewer than ``window`` steps or when
        fewer than two steps are available to compare.
        """
        if len(trace.steps) < window:
            return False
        sources = [set(step.context_sources) for step in trace.steps[-window:]]
        if len(sources) < 2:
            return False
        return all(
            not (current - previous)
            for previous, current in zip(sources, sources[1:])
        )

    def _generate_targeted_question(self, trace: AgentTrace, signals: List[str]) -> str:
        """Pick one question for the user based on the first matching signal.

        Signals are checked in a fixed priority order; a generic question is
        returned when none of them matched. ``trace`` is currently unused.
        """
        if "repeated_tool_failures" in signals:
            return "The requested tools are failing repeatedly. Can you provide the correct parameters or clarify the task scope?"
        if "no_artifact_progress" in signals:
            return "No progress has been made on the expected deliverable. Is there a specific format or file you need?"
        if "cost_explosion" in signals:
            return "This task is taking more resources than expected. Can you narrow the scope or clarify priorities?"
        if "repeated_planning" in signals:
            return "The agent keeps re-planning. What is the single most important next step?"
        return "Can you clarify or narrow the task requirements to help the agent proceed more efficiently?"

    def _infer_tier(self, model_id: str) -> int:
        """Infer a capability tier (1 = lowest, 4 = highest) from a model id.

        Simplified, case-insensitive substring matching. A "gpt-4" id that is
        also marked "small" or "mini" (e.g. "gpt-4o-mini") is not treated as
        top tier, so such models remain eligible for escalation.
        """
        lowered = model_id.lower()
        # NOTE(review): "mini" also matches ids containing "gemini"; verify
        # against the model naming scheme actually in use.
        is_small = "small" in lowered or "mini" in lowered
        if "frontier" in lowered or ("gpt-4" in lowered and not is_small):
            return 4
        if "medium" in lowered:
            return 3
        if is_small:
            return 2
        return 1

    def get_stats(self) -> Dict[str, Any]:
        """Return doom detection statistics.

        Returns:
            ``{"total_assessments": 0}`` when nothing has been recorded;
            otherwise the total count, a per-action count keyed by the
            action's string value, and the mean confidence.
        """
        history = self.assessment_history
        total = len(history)
        if total == 0:
            return {"total_assessments": 0}

        action_counts: Dict[str, int] = {}
        for assessment in history:
            key = assessment.action.value
            action_counts[key] = action_counts.get(key, 0) + 1

        return {
            "total_assessments": total,
            "action_distribution": action_counts,
            "avg_confidence": sum(a.confidence for a in history) / total,
        }