narcolepticchicken commited on
Commit
33a5f28
·
verified ·
1 Parent(s): 7d60df1

Upload aco/verifier_budgeter.py

Browse files
Files changed (1) hide show
  1. aco/verifier_budgeter.py +338 -0
aco/verifier_budgeter.py ADDED
@@ -0,0 +1,338 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Verifier Budgeter - Module 7.
2
+
3
+ Do not call verifiers everywhere.
4
+
5
+ Call verifiers when:
6
+ - task is high-risk
7
+ - confidence is low
8
+ - retrieval evidence is weak
9
+ - output is irreversible
10
+ - prior failures exist
11
+ - cheap model was used
12
+ - final answer is likely to be hallucination-prone
13
+
14
+ Compare:
15
+ A. no verifier
16
+ B. verifier on every output
17
+ C. heuristic verifier
18
+ D. learned verifier budgeter
19
+ E. risk-weighted verifier
20
+ """
21
+
22
+ from typing import Dict, List, Optional
23
+ from dataclasses import dataclass
24
+ from enum import Enum
25
+
26
+ from .trace_schema import TaskType, Outcome
27
+ from .config import ACOConfig, VerifierConfig
28
+
29
+
30
class VerifierDecision(Enum):
    """How a single agent output should be verified, if at all."""

    # Run a full external verifier model on the output.
    CALL_VERIFIER = "call_verifier"
    # Skip verification entirely.
    SKIP = "skip"
    # Run a cheaper, lightweight verifier instead of a full one.
    USE_LIGHT_VERIFIER = "use_light_verifier"
    # Ask the producing model to check its own output.
    SELF_CHECK = "self_check"
35
+
36
+
37
@dataclass
class VerifierBudgetDecision:
    """Outcome of a verifier-budgeting decision for one agent output.

    Field order is part of the positional-construction interface and must
    not change.
    """

    # The chosen action (call / skip / light verifier / self-check).
    decision: VerifierDecision
    # Model id of the selected verifier, or None when no verifier is used.
    verifier_model_id: Optional[str]
    # Human-readable explanation of why this decision was made.
    reasoning: str
    # Expected cost of running the chosen verifier (0.0 when skipping).
    estimated_verifier_cost: float
    # Expected net value of verification (benefit of catching errors minus cost).
    estimated_value: float
    # Confidence carried through from the producing model.
    confidence: float
    checks: List[str]  # what aspects to verify
46
+
47
+
48
class VerifierBudgeter:
    """Selectively calls verifiers based on risk and value.

    Supported budgeting policies (the ``mode`` argument of :meth:`decide`):

    - ``"no_verifier"``: never verify (baseline A).
    - ``"always"``: verify every output (baseline B).
    - ``"heuristic"``: verify when any fixed trigger fires (C).
    - ``"learned"`` / ``"risk_weighted"``: cost-benefit analysis over an
      aggregated risk score (D / E — currently share one implementation).
      Any unrecognized mode also falls back to this policy.
    """

    # Task types that always warrant compliance-focused verification.
    HIGH_RISK_TASKS = {TaskType.LEGAL_REGULATED}
    # Task types prone to hallucination when retrieval evidence is weak.
    HALLUCINATION_PRONE_TASKS = {TaskType.RESEARCH, TaskType.DOCUMENT_DRAFTING}

    # Confidence thresholds shared by the heuristic and risk-weighted modes.
    LOW_CONFIDENCE_THRESHOLD = 0.6
    MEDIUM_CONFIDENCE_THRESHOLD = 0.8

    def __init__(self, config: Optional[ACOConfig] = None):
        self.config = config or ACOConfig()
        # Per-verifier calibration stats, keyed by verifier model id.
        self.verifier_stats: Dict[str, Dict] = {}
        # Chronological log of budgeting decisions for offline analysis.
        # (Fix: this was previously declared but never populated.)
        self.decision_history: List[Dict] = []

    def decide(
        self,
        task_type: TaskType,
        model_tier_used: int,
        confidence: float,
        has_prior_failures: bool,
        is_irreversible: bool,
        output_length_tokens: int,
        retrieval_evidence_count: int,
        step_number: int,
        total_steps: int,
        mode: str = "risk_weighted",
    ) -> VerifierBudgetDecision:
        """Decide whether and how to verify an agent output.

        Args:
            task_type: Category of the task being executed.
            model_tier_used: Tier of the model that produced the output
                (lower tiers are treated as riskier).
            confidence: Producer's self-reported confidence in [0, 1].
            has_prior_failures: Whether earlier attempts/steps failed.
            is_irreversible: Whether the output triggers an irreversible action.
            output_length_tokens: Token length of the produced output.
            retrieval_evidence_count: Number of retrieved evidence items
                backing the output.
            step_number: 1-based index of the current step.
            total_steps: Total number of steps in the task.
            mode: Budgeting policy name; see class docstring.

        Returns:
            A :class:`VerifierBudgetDecision`; the decision is also appended
            to :attr:`decision_history`.
        """
        if mode == "no_verifier":
            decision = VerifierBudgetDecision(
                decision=VerifierDecision.SKIP,
                verifier_model_id=None,
                reasoning="No verifier mode",
                estimated_verifier_cost=0.0,
                estimated_value=0.0,
                confidence=1.0,
                checks=[],
            )
        elif mode == "always":
            # Fix: reuse the shared selection helper instead of duplicating
            # the "first configured verifier" logic inline.
            verifier = self._select_verifier(["all"])
            decision = VerifierBudgetDecision(
                decision=VerifierDecision.CALL_VERIFIER,
                verifier_model_id=verifier.verifier_model_id if verifier else None,
                reasoning="Verifier on every output",
                estimated_verifier_cost=verifier.cost_per_call if verifier else 0.0,
                estimated_value=0.5,
                confidence=0.5,
                checks=["all"],
            )
        elif mode == "heuristic":
            decision = self._heuristic_decide(
                task_type, model_tier_used, confidence, has_prior_failures,
                is_irreversible, output_length_tokens, retrieval_evidence_count,
                step_number, total_steps,
            )
        else:
            # "learned", "risk_weighted", and any unrecognized mode use the
            # risk-weighted policy. (Fix: the original duplicated this call
            # in two separate branches with identical arguments.)
            decision = self._risk_weighted_decide(
                task_type, model_tier_used, confidence, has_prior_failures,
                is_irreversible, output_length_tokens, retrieval_evidence_count,
                step_number, total_steps,
            )

        # Record every decision so policies can be compared offline.
        self.decision_history.append({
            "mode": mode,
            "task_type": task_type,
            "decision": decision.decision,
            "estimated_verifier_cost": decision.estimated_verifier_cost,
            "estimated_value": decision.estimated_value,
        })
        return decision

    def _heuristic_decide(
        self,
        task_type: TaskType,
        model_tier_used: int,
        confidence: float,
        has_prior_failures: bool,
        is_irreversible: bool,
        output_length_tokens: int,
        retrieval_evidence_count: int,
        step_number: int,
        total_steps: int,
    ) -> VerifierBudgetDecision:
        """Simple heuristic-based verifier selection.

        Verifies whenever at least one fixed trigger fires; each trigger
        contributes a named check. Skips otherwise.
        """
        should_verify = False
        checks: List[str] = []

        # High-risk (regulated) work always gets a compliance check.
        if task_type in self.HIGH_RISK_TASKS:
            should_verify = True
            checks.append("legal_compliance")

        # Low producer confidence -> verify facts.
        if confidence < self.LOW_CONFIDENCE_THRESHOLD:
            should_verify = True
            checks.append("factual_accuracy")

        # A prior failure suggests a repeating error pattern.
        if has_prior_failures:
            should_verify = True
            checks.append("failure_pattern")

        # Irreversible outputs get a safety check before release.
        if is_irreversible:
            should_verify = True
            checks.append("safety")

        # Cheap-model output on demanding task types -> quality check.
        if model_tier_used <= 2 and task_type in (TaskType.CODING, TaskType.RESEARCH):
            should_verify = True
            checks.append("output_quality")

        # No supporting evidence on hallucination-prone tasks.
        if retrieval_evidence_count == 0 and task_type in self.HALLUCINATION_PRONE_TASKS:
            should_verify = True
            checks.append("hallucination")

        if not should_verify:
            return VerifierBudgetDecision(
                decision=VerifierDecision.SKIP,
                verifier_model_id=None,
                reasoning="No heuristic triggers met",
                estimated_verifier_cost=0.0,
                estimated_value=0.0,
                confidence=confidence,
                checks=[],
            )

        # Pick verifier
        verifier = self._select_verifier(checks)

        return VerifierBudgetDecision(
            decision=VerifierDecision.CALL_VERIFIER,
            verifier_model_id=verifier.verifier_model_id if verifier else None,
            reasoning=f"Heuristic triggered: {', '.join(checks)}",
            estimated_verifier_cost=verifier.cost_per_call if verifier else 0.0,
            # Value grows as confidence shrinks: 0.3 base + up to 0.5.
            estimated_value=0.3 + (1.0 - confidence) * 0.5,
            confidence=confidence,
            checks=checks,
        )

    def _risk_weighted_decide(
        self,
        task_type: TaskType,
        model_tier_used: int,
        confidence: float,
        has_prior_failures: bool,
        is_irreversible: bool,
        output_length_tokens: int,
        retrieval_evidence_count: int,
        step_number: int,
        total_steps: int,
    ) -> VerifierBudgetDecision:
        """Risk-weighted verifier selection with value estimation.

        Aggregates additive risk contributions (capped at 1.0), then runs a
        cost-benefit comparison: full verification when expected value is
        positive or risk is high, a lightweight self-check at medium risk,
        and a skip at low risk.
        """
        # Compute risk score
        risk = 0.0
        checks: List[str] = []

        # Task type risk
        if task_type in self.HIGH_RISK_TASKS:
            risk += 0.4
            checks.append("legal_compliance")
        elif task_type in {TaskType.CODING, TaskType.RESEARCH}:
            risk += 0.15

        # Confidence risk: linear penalty below the medium threshold.
        risk += max(0.0, (self.MEDIUM_CONFIDENCE_THRESHOLD - confidence) * 0.5)
        if confidence < self.LOW_CONFIDENCE_THRESHOLD:
            checks.append("factual_accuracy")

        # Model tier risk (cheap models are riskier)
        risk += max(0.0, (3 - model_tier_used) * 0.05)
        if model_tier_used <= 2 and task_type in (TaskType.CODING, TaskType.RESEARCH, TaskType.DOCUMENT_DRAFTING):
            checks.append("output_quality")

        # Prior failures
        if has_prior_failures:
            risk += 0.2
            checks.append("failure_pattern")

        # Irreversibility
        if is_irreversible:
            risk += 0.25
            checks.append("safety")

        # Evidence weakness
        if retrieval_evidence_count == 0 and task_type in self.HALLUCINATION_PRONE_TASKS:
            risk += 0.2
            checks.append("hallucination")

        # Output size (longer outputs are more error-prone)
        if output_length_tokens > 2048:
            risk += min(0.1, (output_length_tokens - 2048) / 50000)

        # Step position (first and last steps are more critical)
        if step_number == 1:
            risk += 0.05
        if step_number == total_steps and total_steps > 1:
            risk += 0.1
            checks.append("final_output")

        risk = min(risk, 1.0)

        # Cost-benefit analysis
        verifier = self._select_verifier(checks)
        verifier_cost = verifier.cost_per_call if verifier else 0.01

        # Expected value of verification
        # If risk is high, catching an error is very valuable
        error_cost = self._estimate_error_cost(task_type, is_irreversible)
        p_error = risk
        p_catch = 0.7  # verifier catches error with 70% probability
        expected_value = p_error * p_catch * error_cost - verifier_cost

        if expected_value > 0 or risk > 0.6:
            return VerifierBudgetDecision(
                decision=VerifierDecision.CALL_VERIFIER,
                verifier_model_id=verifier.verifier_model_id if verifier else None,
                reasoning=f"Risk={risk:.2f}, expected_value={expected_value:.4f}, checks={checks}",
                estimated_verifier_cost=verifier_cost,
                estimated_value=expected_value,
                confidence=confidence,
                # Fix: dedupe while preserving insertion order — list(set(...))
                # produced nondeterministic ordering across runs.
                checks=list(dict.fromkeys(checks)),
            )

        # For medium risk, use a lighter self-check
        if risk > 0.3:
            return VerifierBudgetDecision(
                decision=VerifierDecision.SELF_CHECK,
                verifier_model_id=None,
                reasoning=f"Medium risk ({risk:.2f}) — use lightweight self-check instead of full verifier",
                estimated_verifier_cost=0.0,
                # Self-check assumed to catch errors at a lower 30% rate.
                estimated_value=p_error * 0.3 * error_cost,
                confidence=confidence,
                checks=["self_consistency"],
            )

        return VerifierBudgetDecision(
            decision=VerifierDecision.SKIP,
            verifier_model_id=None,
            reasoning=f"Low risk ({risk:.2f}), expected_value={expected_value:.4f} — skip verification",
            estimated_verifier_cost=0.0,
            estimated_value=0.0,
            confidence=confidence,
            checks=[],
        )

    def _select_verifier(self, checks: List[str]) -> Optional[VerifierConfig]:
        """Select appropriate verifier based on checks needed.

        Returns None when no verifiers are configured.
        """
        if not self.config.verifiers:
            return None

        # For now, return the first verifier without building a throwaway
        # list. In production, map check types to specialist verifiers.
        return next(iter(self.config.verifiers.values()))

    def _estimate_error_cost(self, task_type: TaskType, is_irreversible: bool) -> float:
        """Estimate the relative cost of an undetected error.

        Base cost by task type (legal 10x, coding 3x, research 2x, else 1x),
        doubled when the action is irreversible.
        """
        base = 1.0
        if task_type == TaskType.LEGAL_REGULATED:
            base = 10.0
        elif task_type == TaskType.CODING:
            base = 3.0
        elif task_type == TaskType.RESEARCH:
            base = 2.0
        if is_irreversible:
            base *= 2.0
        return base

    def record_verifier_outcome(
        self,
        verifier_model_id: str,
        target_step: str,
        passed: bool,
        cost: float,
        was_actually_correct: bool,
    ) -> None:
        """Record verifier performance for calibration.

        Args:
            verifier_model_id: Which verifier produced the verdict.
            target_step: Identifier of the verified step (currently unused
                in aggregation; kept for future per-step calibration).
            passed: Whether the verifier approved the output.
            cost: Actual cost of the verifier call.
            was_actually_correct: Ground truth for the output.

        NOTE(review): true negatives (rejected AND actually incorrect) are
        not tracked, so precision/recall cannot both be derived — confirm
        whether that is intentional.
        """
        stats = self.verifier_stats.setdefault(verifier_model_id, {
            "calls": 0, "passed": 0, "true_positives": 0, "false_positives": 0,
            "false_negatives": 0, "total_cost": 0.0,
        })
        stats["calls"] += 1
        if passed:
            stats["passed"] += 1
            # "passed but wrong" = false positive; "passed and right" = TP.
            if was_actually_correct:
                stats["true_positives"] += 1
            else:
                stats["false_positives"] += 1
        elif was_actually_correct:
            # Verifier rejected a correct output.
            stats["false_negatives"] += 1
        stats["total_cost"] += cost