narcolepticchicken committed on
Commit
ff456f8
·
verified ·
1 Parent(s): a97e900

Upload aco/doom_detector.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/doom_detector.py +80 -276
aco/doom_detector.py CHANGED
@@ -1,287 +1,91 @@
1
- """Early Termination / Doom Detector - Module 10.
2
-
3
- Detects runs that are unlikely to succeed without more information or intervention.
4
-
5
- Signals:
6
- - repeated failed tool calls
7
- - no artifact progress
8
- - growing cost without new evidence
9
- - repeated planning
10
- - verifier disagreement
11
- - context confusion
12
- - escalating retries
13
- - model loop behavior
14
-
15
- Actions:
16
- - stop
17
- - mark BLOCKED
18
- - ask one targeted question
19
- - switch strategy
20
- - escalate model
21
- - escalate human
22
- """
23
-
24
- from typing import Dict, List, Optional, Any
25
  from dataclasses import dataclass
26
- from enum import Enum
27
-
28
- from .trace_schema import AgentTrace, TraceStep, Outcome, FailureTag
29
- from .config import ACOConfig
30
-
31
-
32
- class DoomAction(Enum):
33
- STOP = "stop"
34
- MARK_BLOCKED = "mark_blocked"
35
- ASK_TARGETED_QUESTION = "ask_targeted_question"
36
- SWITCH_STRATEGY = "switch_strategy"
37
- ESCALATE_MODEL = "escalate_model"
38
- ESCALATE_HUMAN = "escalate_human"
39
- CONTINUE = "continue"
40
 
 
 
 
 
 
41
 
42
  @dataclass
43
  class DoomAssessment:
44
- action: DoomAction
45
- confidence: float
 
 
46
  reasoning: str
47
- signals_triggered: List[str]
48
- recommended_action: Optional[str] = None
49
- question_to_ask: Optional[str] = None
50
 
 
 
 
 
 
 
 
51
 
52
  class DoomDetector:
53
- """Detects doomed agent runs and recommends intervention."""
54
-
55
- # Signal weights
56
- SIGNAL_WEIGHTS = {
57
- "repeated_tool_failures": 0.3,
58
- "no_artifact_progress": 0.25,
59
- "cost_explosion": 0.3,
60
- "repeated_planning": 0.2,
61
- "verifier_disagreement": 0.25,
62
- "context_confusion": 0.2,
63
- "escalating_retries": 0.35,
64
- "model_loop": 0.4,
65
- "stagnant_context": 0.15,
66
- }
67
-
68
- # Doom threshold
69
- DOOM_THRESHOLD = 0.6
70
- BLOCKED_THRESHOLD = 0.8
71
 
72
- def __init__(self, config: Optional[ACOConfig] = None):
73
- self.config = config or ACOConfig()
74
- self.assessment_history: List[DoomAssessment] = []
75
-
76
- def assess(
77
- self,
78
- trace: AgentTrace,
79
- current_step: TraceStep,
80
- predicted_cost: float,
81
- predicted_steps: int,
82
- ) -> DoomAssessment:
83
- """Assess whether a run is doomed and what action to take."""
84
-
85
  signals = []
86
- score = 0.0
87
-
88
- # Signal 1: Repeated failed tool calls
89
- tool_failures = self._count_recent_tool_failures(trace)
90
- if tool_failures >= 3:
91
- signals.append("repeated_tool_failures")
92
- score += self.SIGNAL_WEIGHTS["repeated_tool_failures"] * min(tool_failures / 5, 1.0)
93
-
94
- # Signal 2: No artifact progress
95
- if len(trace.final_artifacts) == 0 and len(trace.steps) > 5:
96
- signals.append("no_artifact_progress")
97
- score += self.SIGNAL_WEIGHTS["no_artifact_progress"]
98
-
99
- # Signal 3: Cost explosion
100
- total_cost = trace.total_cost_computed
101
- cost_ratio = total_cost / max(predicted_cost, 0.0001)
102
- if cost_ratio > self.config.doom_max_cost_ratio:
103
- signals.append("cost_explosion")
104
- score += self.SIGNAL_WEIGHTS["cost_explosion"] * min(cost_ratio / 5, 1.0)
105
-
106
- # Signal 4: Repeated planning (re-planning without execution)
107
- replan_count = self._count_replanning(trace)
108
- if replan_count >= 2:
109
- signals.append("repeated_planning")
110
- score += self.SIGNAL_WEIGHTS["repeated_planning"] * min(replan_count / 4, 1.0)
111
-
112
- # Signal 5: Verifier disagreement
113
- verifier_disagree = self._count_verifier_disagreement(trace)
114
- if verifier_disagree >= self.config.doom_verifier_disagreement_threshold:
115
- signals.append("verifier_disagreement")
116
- score += self.SIGNAL_WEIGHTS["verifier_disagreement"] * min(verifier_disagree / 4, 1.0)
117
-
118
- # Signal 6: Context confusion (rapid context size oscillation)
119
- if len(trace.steps) >= 3:
120
- ctx_sizes = [s.context_size_tokens for s in trace.steps[-3:]]
121
- if max(ctx_sizes) - min(ctx_sizes) > 5000:
122
- signals.append("context_confusion")
123
- score += self.SIGNAL_WEIGHTS["context_confusion"]
124
-
125
- # Signal 7: Escalating retries
126
- total_retries = trace.total_retries
127
- if total_retries >= self.config.doom_max_retries * len(trace.steps) * 0.5:
128
- signals.append("escalating_retries")
129
- score += self.SIGNAL_WEIGHTS["escalating_retries"] * min(total_retries / 10, 1.0)
130
-
131
- # Signal 8: Model loop (same model producing same outputs)
132
- loop_detected = self._detect_model_loop(trace)
133
- if loop_detected:
134
- signals.append("model_loop")
135
- score += self.SIGNAL_WEIGHTS["model_loop"]
136
-
137
- # Signal 9: Stagnant context (no new information in last N steps)
138
- if self._is_context_stagnant(trace):
139
- signals.append("stagnant_context")
140
- score += self.SIGNAL_WEIGHTS["stagnant_context"]
141
-
142
- # Cap score
143
- score = min(score, 1.0)
144
-
145
- if score < 0.3:
146
- return DoomAssessment(
147
- action=DoomAction.CONTINUE,
148
- confidence=1.0 - score,
149
- reasoning="Run appears healthy.",
150
- signals_triggered=signals,
151
- )
152
-
153
- if score < self.DOOM_THRESHOLD:
154
- return DoomAssessment(
155
- action=DoomAction.ASK_TARGETED_QUESTION,
156
- confidence=score,
157
- reasoning=f"Early warning: {', '.join(signals)}. Asking targeted question.",
158
- signals_triggered=signals,
159
- question_to_ask=self._generate_targeted_question(trace, signals),
160
- )
161
-
162
- if score < self.BLOCKED_THRESHOLD:
163
- # Decide between strategy switch, model escalation, or stop
164
- if current_step.model_call and current_step.model_call.model_id:
165
- current_tier = self._infer_tier(current_step.model_call.model_id)
166
- if current_tier < 4:
167
- return DoomAssessment(
168
- action=DoomAction.ESCALATE_MODEL,
169
- confidence=score,
170
- reasoning=f"Run struggling ({score:.2f} doom score). Escalating model.",
171
- signals_triggered=signals,
172
- )
173
-
174
- return DoomAssessment(
175
- action=DoomAction.SWITCH_STRATEGY,
176
- confidence=score,
177
- reasoning=f"Run struggling ({score:.2f} doom score). Switching strategy.",
178
- signals_triggered=signals,
179
- recommended_action="Change approach: retrieve more context, simplify task decomposition",
180
- )
181
-
182
- # High doom score — mark blocked or escalate human
183
- if score > 0.95:
184
- return DoomAssessment(
185
- action=DoomAction.ESCALATE_HUMAN,
186
- confidence=score,
187
- reasoning=f"Critical failure pattern ({score:.2f} doom score). Human escalation required.",
188
- signals_triggered=signals,
189
- )
190
-
191
- return DoomAssessment(
192
- action=DoomAction.MARK_BLOCKED,
193
- confidence=score,
194
- reasoning=f"Doom threshold exceeded ({score:.2f}). Marking BLOCKED.",
195
- signals_triggered=signals,
196
- )
197
-
198
- def _count_recent_tool_failures(self, trace: AgentTrace, window: int = 5) -> int:
199
- recent_steps = trace.steps[-window:] if len(trace.steps) > window else trace.steps
200
- return sum(
201
- 1 for step in recent_steps
202
- for tc in step.tool_calls
203
- if tc.failed
204
- )
205
-
206
- def _count_replanning(self, trace: AgentTrace) -> int:
207
- replan_count = 0
208
- for step in trace.steps:
209
- # Heuristic: step mentions plan but no tool calls or artifacts
210
- if step.planned_next and not step.tool_calls and not step.artifacts_created:
211
- replan_count += 1
212
- return replan_count
213
-
214
- def _count_verifier_disagreement(self, trace: AgentTrace) -> int:
215
- disagreements = 0
216
- for step in trace.steps:
217
- verifiers = step.verifier_calls
218
- if len(verifiers) >= 2:
219
- results = [v.passed for v in verifiers]
220
- if any(results) and not all(results):
221
- disagreements += 1
222
- elif len(verifiers) == 1:
223
- # Verifier rejected but step proceeded anyway
224
- if not verifiers[0].passed:
225
- disagreements += 1
226
- return disagreements
227
-
228
- def _detect_model_loop(self, trace: AgentTrace) -> bool:
229
- if len(trace.steps) < 4:
230
- return False
231
- # Check if last 4 steps have identical or very similar tool call patterns
232
- last4 = trace.steps[-4:]
233
- patterns = [
234
- tuple(tc.tool_name for tc in s.tool_calls)
235
- for s in last4
236
- ]
237
- return len(set(patterns)) <= 2 and len(patterns) == 4
238
-
239
- def _is_context_stagnant(self, trace: AgentTrace, window: int = 3) -> bool:
240
- if len(trace.steps) < window:
241
- return False
242
- recent = trace.steps[-window:]
243
- sources = [set(s.context_sources) for s in recent]
244
- # If no new sources introduced
245
- if len(sources) >= 2:
246
- for i in range(1, len(sources)):
247
- if sources[i] - sources[i-1]:
248
- return False
249
- return True
250
- return False
251
-
252
- def _generate_targeted_question(self, trace: AgentTrace, signals: List[str]) -> str:
253
- if "repeated_tool_failures" in signals:
254
- return "The requested tools are failing repeatedly. Can you provide the correct parameters or clarify the task scope?"
255
- if "no_artifact_progress" in signals:
256
- return "No progress has been made on the expected deliverable. Is there a specific format or file you need?"
257
- if "cost_explosion" in signals:
258
- return "This task is taking more resources than expected. Can you narrow the scope or clarify priorities?"
259
- if "repeated_planning" in signals:
260
- return "The agent keeps re-planning. What is the single most important next step?"
261
- return "Can you clarify or narrow the task requirements to help the agent proceed more efficiently?"
262
-
263
- def _infer_tier(self, model_id: str) -> int:
264
- # Simplified tier inference
265
- if "frontier" in model_id.lower() or "gpt-4" in model_id.lower():
266
- return 4
267
- if "medium" in model_id.lower():
268
- return 3
269
- if "small" in model_id.lower() or "mini" in model_id.lower():
270
- return 2
271
- return 1
272
-
273
- def get_stats(self) -> Dict[str, Any]:
274
- """Return doom detection statistics."""
275
- total = len(self.assessment_history)
276
- if total == 0:
277
- return {"total_assessments": 0}
278
-
279
- action_counts = {}
280
- for a in self.assessment_history:
281
- action_counts[a.action.value] = action_counts.get(a.action.value, 0) + 1
282
-
283
- return {
284
- "total_assessments": total,
285
- "action_distribution": action_counts,
286
- "avg_confidence": sum(a.confidence for a in self.assessment_history) / total,
287
- }
 
1
+ """Early Termination / Doom Detector: Stop runs unlikely to succeed."""
2
+ from typing import Dict, List, Optional
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
@dataclass
class DoomSignal:
    """A single piece of evidence that a run may be failing.

    Instances are built positionally as (signal_type, severity, evidence).
    """

    # Short identifier for the kind of signal, e.g. "failed_tool_calls".
    signal_type: str
    # Strength of this individual signal, on a 0-1 scale.
    severity: float
    # Human-readable text describing what triggered the signal.
    evidence: str
10
 
@dataclass
class DoomAssessment:
    """Result of one doom check on a run.

    Instances are built positionally as
    (doomed, severity, signals, recommended_action, reasoning).
    """

    # Whether the run is considered doomed.
    doomed: bool
    # Aggregate severity on a 0-1 scale.
    severity: float
    # The individual signals that were raised during the check.
    signals: List[DoomSignal]
    # One of: "continue", "stop", "mark_blocked", "ask_question",
    # "switch_strategy".
    recommended_action: str
    # Short explanation of why the action was chosen.
    reasoning: str
 
 
 
18
 
# Module-wide limits used when deciding whether a doom signal has fired.
DOOM_THRESHOLDS = dict(
    # Failed tool calls at (or above) which the failure signal fires.
    max_failed_tool_calls=3,
    # Retried steps at (or above) which the repeated-planning signal fires.
    max_repeated_planning=2,
    # Cost that must be strictly exceeded, with no artifacts produced,
    # for the cost-without-progress signal to fire.
    max_cost_without_progress=2.0,
    # NOTE(review): the next two keys are not read by any code visible in
    # this view of the file -- confirm whether other modules use them.
    max_context_confusion=0.7,
    max_escalation_loops=2,
)
26
 
class DoomDetector:
    """Scores an agent run for signs that it is unlikely to succeed.

    Each call to :meth:`assess` adds up fixed weights for the signals that
    fired, caps the total at 1.0, and compares it with ``doom_threshold`` to
    decide whether the run is doomed and which action to recommend.
    """

    def __init__(self, doom_threshold: float = 0.7):
        """Create a detector.

        Args:
            doom_threshold: Aggregate severity (0-1) at or above which a run
                is reported as doomed.
        """
        self.doom_threshold = doom_threshold
        # Every assessment returned by assess(), in call order.
        self.assessments: List[DoomAssessment] = []

    def assess(self, steps: List[Dict], current_cost: float, max_cost: float,
               model_tier: int, verifier_disagreements: int = 0) -> DoomAssessment:
        """Assess a run and record the resulting assessment.

        Args:
            steps: Step records as dicts. The keys read are ``"tool_calls"``
                (an iterable of dicts with an optional ``"success"`` flag,
                treated as successful when absent), ``"artifacts_created"``
                and ``"retry_num"``. Missing or ``None`` values are treated
                as "no tool calls" / "no artifacts" / "no retries".
            current_cost: Cost spent by the run so far.
            max_cost: Cost budget for the run.
            model_tier: Accepted for interface compatibility; not used in
                the scoring below.
            verifier_disagreements: Number of verifier disagreements seen.

        Returns:
            The :class:`DoomAssessment` for this call, which is also appended
            to ``self.assessments``.
        """
        signals: List[DoomSignal] = []
        severity = 0.0

        # Signal: failed tool calls. ``or []`` tolerates an explicit None.
        failed_tools = sum(
            1
            for step in steps
            for call in (step.get("tool_calls") or [])
            if not call.get("success", True)
        )
        tools_failing = failed_tools >= DOOM_THRESHOLDS["max_failed_tool_calls"]
        if tools_failing:
            signals.append(DoomSignal("failed_tool_calls", 0.6,
                                      f"{failed_tools} failed tool calls"))
            severity += 0.3

        # Signal: no artifacts produced. Note all() is vacuously true for an
        # empty step list; the length guard keeps this signal from firing
        # until more than three steps have run.
        no_progress = all(not step.get("artifacts_created") for step in steps)
        if no_progress and len(steps) > 3:
            signals.append(DoomSignal("no_artifact_progress", 0.5,
                                      "no artifacts after 3+ steps"))
            severity += 0.25

        # Signal: money spent without anything to show for it.
        if current_cost > DOOM_THRESHOLDS["max_cost_without_progress"] and no_progress:
            signals.append(DoomSignal("growing_cost_no_progress", 0.7,
                                      f"cost={current_cost:.2f} with no progress"))
            severity += 0.35

        # Signal: retried steps. ``or 0`` tolerates an explicit None.
        planning_steps = sum(
            1 for step in steps if (step.get("retry_num") or 0) > 0
        )
        if planning_steps >= DOOM_THRESHOLDS["max_repeated_planning"]:
            signals.append(DoomSignal("repeated_planning", 0.4,
                                      f"{planning_steps} retry steps"))
            severity += 0.2

        # Signal: verifiers disagree.
        verifiers_disagree = verifier_disagreements >= 2
        if verifiers_disagree:
            signals.append(DoomSignal("verifier_disagreement", 0.6,
                                      f"{verifier_disagreements} verifier disagreements"))
            severity += 0.3

        # Signal: 90% or more of the budget is used.
        near_cost_limit = current_cost >= max_cost * 0.9
        if near_cost_limit:
            signals.append(DoomSignal("approaching_cost_limit", 0.5,
                                      f"cost={current_cost:.2f} / {max_cost:.2f}"))
            severity += 0.4

        severity = min(severity, 1.0)
        doomed = severity >= self.doom_threshold

        # Pick the action. The branches reuse the same conditions that
        # raised the signals so that tuning DOOM_THRESHOLDS cannot make the
        # signal and the action disagree.
        if not doomed:
            action = "continue"
            reasoning = f"severity={severity:.2f} < threshold={self.doom_threshold}"
        elif tools_failing and no_progress:
            action = "mark_blocked"
            reasoning = "too many failures with no progress"
        elif near_cost_limit:
            action = "stop"
            reasoning = "approaching cost limit"
        elif verifiers_disagree:
            action = "switch_strategy"
            reasoning = "verifier disagreement suggests wrong approach"
        else:
            action = "ask_question"
            reasoning = "run may be recoverable with user input"

        assessment = DoomAssessment(doomed, severity, signals, action, reasoning)
        self.assessments.append(assessment)
        return assessment