narcolepticchicken committed on
Commit
5d30266
·
verified ·
1 Parent(s): a7e3035

Upload aco/verifier_budgeter.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/verifier_budgeter.py +66 -330
aco/verifier_budgeter.py CHANGED
@@ -1,338 +1,74 @@
1
- """Verifier Budgeter - Module 7.
2
-
3
- Do not call verifiers everywhere.
4
-
5
- Call verifiers when:
6
- - task is high-risk
7
- - confidence is low
8
- - retrieval evidence is weak
9
- - output is irreversible
10
- - prior failures exist
11
- - cheap model was used
12
- - final answer is likely to be hallucination-prone
13
-
14
- Compare:
15
- A. no verifier
16
- B. verifier on every output
17
- C. heuristic verifier
18
- D. learned verifier budgeter
19
- E. risk-weighted verifier
20
- """
21
-
22
- from typing import Dict, List, Optional
23
  from dataclasses import dataclass
24
- from enum import Enum
25
-
26
- from .trace_schema import TaskType, Outcome
27
- from .config import ACOConfig, VerifierConfig
28
-
29
-
30
- class VerifierDecision(Enum):
31
- CALL_VERIFIER = "call_verifier"
32
- SKIP = "skip"
33
- USE_LIGHT_VERIFIER = "use_light_verifier"
34
- SELF_CHECK = "self_check"
35
-
36
 
37
  @dataclass
38
- class VerifierBudgetDecision:
39
- decision: VerifierDecision
40
- verifier_model_id: Optional[str]
41
- reasoning: str
42
- estimated_verifier_cost: float
43
- estimated_value: float
44
  confidence: float
45
- checks: List[str] # what aspects to verify
 
46
 
 
 
 
 
 
 
47
 
48
  class VerifierBudgeter:
49
- """Selectively calls verifiers based on risk and value."""
50
-
51
- # Risk thresholds for mandatory verification
52
- HIGH_RISK_TASKS = {TaskType.LEGAL_REGULATED}
53
- HALLUCINATION_PRONE_TASKS = {TaskType.RESEARCH, TaskType.DOCUMENT_DRAFTING}
54
-
55
- # Confidence thresholds
56
- LOW_CONFIDENCE_THRESHOLD = 0.6
57
- MEDIUM_CONFIDENCE_THRESHOLD = 0.8
58
-
59
- def __init__(self, config: Optional[ACOConfig] = None):
60
- self.config = config or ACOConfig()
61
- self.verifier_stats: Dict[str, Dict] = {}
62
- self.decision_history: List[Dict] = []
63
-
64
- def decide(
65
- self,
66
- task_type: TaskType,
67
- model_tier_used: int,
68
- confidence: float,
69
- has_prior_failures: bool,
70
- is_irreversible: bool,
71
- output_length_tokens: int,
72
- retrieval_evidence_count: int,
73
- step_number: int,
74
- total_steps: int,
75
- mode: str = "risk_weighted",
76
- ) -> VerifierBudgetDecision:
77
- """Decide whether and how to verify an agent output."""
78
-
79
- if mode == "no_verifier":
80
- return VerifierBudgetDecision(
81
- decision=VerifierDecision.SKIP,
82
- verifier_model_id=None,
83
- reasoning="No verifier mode",
84
- estimated_verifier_cost=0.0,
85
- estimated_value=0.0,
86
- confidence=1.0,
87
- checks=[],
88
- )
89
-
90
- if mode == "always":
91
- verifier = list(self.config.verifiers.values())[0] if self.config.verifiers else None
92
- return VerifierBudgetDecision(
93
- decision=VerifierDecision.CALL_VERIFIER,
94
- verifier_model_id=verifier.verifier_model_id if verifier else None,
95
- reasoning="Verifier on every output",
96
- estimated_verifier_cost=verifier.cost_per_call if verifier else 0.0,
97
- estimated_value=0.5,
98
- confidence=0.5,
99
- checks=["all"],
100
- )
101
-
102
- if mode == "heuristic":
103
- return self._heuristic_decide(
104
- task_type, model_tier_used, confidence, has_prior_failures,
105
- is_irreversible, output_length_tokens, retrieval_evidence_count,
106
- step_number, total_steps,
107
- )
108
-
109
- if mode in ("learned", "risk_weighted"):
110
- return self._risk_weighted_decide(
111
- task_type, model_tier_used, confidence, has_prior_failures,
112
- is_irreversible, output_length_tokens, retrieval_evidence_count,
113
- step_number, total_steps,
114
- )
115
-
116
- return self._risk_weighted_decide(
117
- task_type, model_tier_used, confidence, has_prior_failures,
118
- is_irreversible, output_length_tokens, retrieval_evidence_count,
119
- step_number, total_steps,
120
- )
121
-
122
- def _heuristic_decide(
123
- self,
124
- task_type: TaskType,
125
- model_tier_used: int,
126
- confidence: float,
127
- has_prior_failures: bool,
128
- is_irreversible: bool,
129
- output_length_tokens: int,
130
- retrieval_evidence_count: int,
131
- step_number: int,
132
- total_steps: int,
133
- ) -> VerifierBudgetDecision:
134
- """Simple heuristic-based verifier selection."""
135
-
136
- should_verify = False
137
- checks = []
138
-
139
- if task_type in self.HIGH_RISK_TASKS:
140
- should_verify = True
141
- checks.append("legal_compliance")
142
-
143
- if confidence < self.LOW_CONFIDENCE_THRESHOLD:
144
- should_verify = True
145
- checks.append("factual_accuracy")
146
-
147
  if has_prior_failures:
148
- should_verify = True
149
- checks.append("failure_pattern")
150
-
151
- if is_irreversible:
152
- should_verify = True
153
- checks.append("safety")
154
-
155
- if model_tier_used <= 2 and task_type in (TaskType.CODING, TaskType.RESEARCH):
156
- should_verify = True
157
- checks.append("output_quality")
158
-
159
- if retrieval_evidence_count == 0 and task_type in self.HALLUCINATION_PRONE_TASKS:
160
- should_verify = True
161
- checks.append("hallucination")
162
-
163
- if not should_verify:
164
- return VerifierBudgetDecision(
165
- decision=VerifierDecision.SKIP,
166
- verifier_model_id=None,
167
- reasoning="No heuristic triggers met",
168
- estimated_verifier_cost=0.0,
169
- estimated_value=0.0,
170
- confidence=confidence,
171
- checks=[],
172
- )
173
-
174
- # Pick verifier
175
- verifier = self._select_verifier(checks)
176
-
177
- return VerifierBudgetDecision(
178
- decision=VerifierDecision.CALL_VERIFIER,
179
- verifier_model_id=verifier.verifier_model_id if verifier else None,
180
- reasoning=f"Heuristic triggered: {', '.join(checks)}",
181
- estimated_verifier_cost=verifier.cost_per_call if verifier else 0.0,
182
- estimated_value=0.3 + (1.0 - confidence) * 0.5,
183
- confidence=confidence,
184
- checks=checks,
185
- )
186
-
187
- def _risk_weighted_decide(
188
- self,
189
- task_type: TaskType,
190
- model_tier_used: int,
191
- confidence: float,
192
- has_prior_failures: bool,
193
- is_irreversible: bool,
194
- output_length_tokens: int,
195
- retrieval_evidence_count: int,
196
- step_number: int,
197
- total_steps: int,
198
- ) -> VerifierBudgetDecision:
199
- """Risk-weighted verifier selection with value estimation."""
200
-
201
- # Compute risk score
202
- risk = 0.0
203
- checks = []
204
-
205
- # Task type risk
206
- if task_type in self.HIGH_RISK_TASKS:
207
- risk += 0.4
208
- checks.append("legal_compliance")
209
- elif task_type in {TaskType.CODING, TaskType.RESEARCH}:
210
- risk += 0.15
211
-
212
- # Confidence risk
213
- risk += max(0.0, (self.MEDIUM_CONFIDENCE_THRESHOLD - confidence) * 0.5)
214
- if confidence < self.LOW_CONFIDENCE_THRESHOLD:
215
- checks.append("factual_accuracy")
216
-
217
- # Model tier risk (cheap models are riskier)
218
- risk += max(0.0, (3 - model_tier_used) * 0.05)
219
- if model_tier_used <= 2 and task_type in (TaskType.CODING, TaskType.RESEARCH, TaskType.DOCUMENT_DRAFTING):
220
- checks.append("output_quality")
221
-
222
- # Prior failures
223
- if has_prior_failures:
224
- risk += 0.2
225
- checks.append("failure_pattern")
226
-
227
- # Irreversibility
228
- if is_irreversible:
229
- risk += 0.25
230
- checks.append("safety")
231
-
232
- # Evidence weakness
233
- if retrieval_evidence_count == 0 and task_type in self.HALLUCINATION_PRONE_TASKS:
234
- risk += 0.2
235
- checks.append("hallucination")
236
-
237
- # Output size (longer outputs are more error-prone)
238
- if output_length_tokens > 2048:
239
- risk += min(0.1, (output_length_tokens - 2048) / 50000)
240
-
241
- # Step position (first and last steps are more critical)
242
- if step_number == 1:
243
- risk += 0.05
244
- if step_number == total_steps and total_steps > 1:
245
- risk += 0.1
246
- checks.append("final_output")
247
-
248
- risk = min(risk, 1.0)
249
-
250
- # Cost-benefit analysis
251
- verifier = self._select_verifier(checks)
252
- verifier_cost = verifier.cost_per_call if verifier else 0.01
253
-
254
- # Expected value of verification
255
- # If risk is high, catching an error is very valuable
256
- error_cost = self._estimate_error_cost(task_type, is_irreversible)
257
- p_error = risk
258
- p_catch = 0.7 # verifier catches error with 70% probability
259
- expected_value = p_error * p_catch * error_cost - verifier_cost
260
-
261
- if expected_value > 0 or risk > 0.6:
262
- return VerifierBudgetDecision(
263
- decision=VerifierDecision.CALL_VERIFIER,
264
- verifier_model_id=verifier.verifier_model_id if verifier else None,
265
- reasoning=f"Risk={risk:.2f}, expected_value={expected_value:.4f}, checks={checks}",
266
- estimated_verifier_cost=verifier_cost,
267
- estimated_value=expected_value,
268
- confidence=confidence,
269
- checks=list(set(checks)),
270
- )
271
-
272
- # For medium risk, use a lighter self-check
273
- if risk > 0.3:
274
- return VerifierBudgetDecision(
275
- decision=VerifierDecision.SELF_CHECK,
276
- verifier_model_id=None,
277
- reasoning=f"Medium risk ({risk:.2f}) — use lightweight self-check instead of full verifier",
278
- estimated_verifier_cost=0.0,
279
- estimated_value=p_error * 0.3 * error_cost,
280
- confidence=confidence,
281
- checks=["self_consistency"],
282
- )
283
-
284
- return VerifierBudgetDecision(
285
- decision=VerifierDecision.SKIP,
286
- verifier_model_id=None,
287
- reasoning=f"Low risk ({risk:.2f}), expected_value={expected_value:.4f} — skip verification",
288
- estimated_verifier_cost=0.0,
289
- estimated_value=0.0,
290
- confidence=confidence,
291
- checks=[],
292
- )
293
-
294
- def _select_verifier(self, checks: List[str]) -> Optional[VerifierConfig]:
295
- """Select appropriate verifier based on checks needed."""
296
- if not self.config.verifiers:
297
- return None
298
-
299
- # For now, return the first verifier
300
- # In production, map check types to specialist verifiers
301
- return list(self.config.verifiers.values())[0]
302
-
303
- def _estimate_error_cost(self, task_type: TaskType, is_irreversible: bool) -> float:
304
- """Estimate the cost of an undetected error."""
305
- base = 1.0
306
- if task_type == TaskType.LEGAL_REGULATED:
307
- base = 10.0
308
- elif task_type == TaskType.CODING:
309
- base = 3.0
310
- elif task_type == TaskType.RESEARCH:
311
- base = 2.0
312
- if is_irreversible:
313
- base *= 2.0
314
- return base
315
-
316
- def record_verifier_outcome(
317
- self,
318
- verifier_model_id: str,
319
- target_step: str,
320
- passed: bool,
321
- cost: float,
322
- was_actually_correct: bool,
323
- ) -> None:
324
- """Record verifier performance for calibration."""
325
- stats = self.verifier_stats.setdefault(verifier_model_id, {
326
- "calls": 0, "passed": 0, "true_positives": 0, "false_positives": 0,
327
- "false_negatives": 0, "total_cost": 0.0,
328
- })
329
- stats["calls"] += 1
330
- if passed:
331
- stats["passed"] += 1
332
- if passed and was_actually_correct:
333
- stats["true_positives"] += 1
334
- if passed and not was_actually_correct:
335
- stats["false_positives"] += 1
336
- if not passed and was_actually_correct:
337
- stats["false_negatives"] += 1
338
- stats["total_cost"] += cost
 
1
+ """Verifier Budgeter: Selective verification for high-risk outputs only."""
2
+ from typing import Dict, Optional, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
4
 
@dataclass
class VerifierDecision:
    """Outcome of a single verification-budget decision."""

    # True when the output should be sent to a verifier.
    should_verify: bool
    # One of "full", "spot_check", "skip".
    verifier_type: str
    # The model confidence that was passed in, echoed back unchanged.
    confidence: float
    # Human-readable explanation; triggered reasons are joined with "; ".
    reasoning: str
    # Cost charged for this decision; 0.0 when verification is skipped.
    estimated_cost: float


# Per-risk-level policy.
#   min_confidence: a model confidence strictly below this value triggers
#                   verification.
#   verifier_type:  verifier used by default once verification is triggered.
RISK_VERIFIER_MAP = {
    "critical": {"min_confidence": 0.95, "verifier_type": "full"},
    "high": {"min_confidence": 0.85, "verifier_type": "full"},
    "medium": {"min_confidence": 0.70, "verifier_type": "spot_check"},
    "low": {"min_confidence": 0.50, "verifier_type": "spot_check"},
}


class VerifierBudgeter:
    """Decide, under a per-run cap, which outputs are worth verifying.

    Attributes:
        verifier_cost: Cost reported for every decision that verifies.
        max_per_run: Maximum number of verifying decisions per run.
        verifications_this_run: Verifying decisions issued since the last
            ``reset_run()``.
        stats: Running counters. ``verified`` counts "full" decisions,
            ``spot_checked`` counts "spot_check" decisions and ``skipped``
            counts skips. ``false_passes`` / ``false_rejects`` are
            initialised here but not updated by any method in this class.
    """

    def __init__(self, verifier_cost: float = 0.02, max_verifications_per_run: int = 3):
        self.verifier_cost = verifier_cost
        self.max_per_run = max_verifications_per_run
        self.verifications_this_run = 0
        self.stats = {
            "verified": 0,
            "skipped": 0,
            "spot_checked": 0,
            "false_passes": 0,
            "false_rejects": 0,
        }

    def should_verify(self, task_type: str, risk: str, model_confidence: float,
                      is_irreversible: bool = False, has_prior_failures: bool = False,
                      model_tier: int = 4, step_num: int = 0,
                      total_steps: int = 1) -> VerifierDecision:
        """Return whether, and how, the current output should be verified.

        Args:
            task_type: Accepted for interface compatibility; not consulted
                by the decision logic.
            risk: Key into ``RISK_VERIFIER_MAP``. Unknown values fall back
                to the "medium" policy.
            model_confidence: Confidence reported for the output.
            is_irreversible: Whether the output cannot be undone.
            has_prior_failures: Whether earlier failures were observed.
            model_tier: Tier of the producing model; ``<= 2`` is treated as
                a cheap model.
            step_num: Current step; equal to ``total_steps`` or ``-1``
                marks the final answer.
            total_steps: Total number of steps in the run.

        Returns:
            A ``VerifierDecision``. Skips carry ``estimated_cost`` 0.0;
            verifying decisions carry ``self.verifier_cost``.
        """
        # Hard cap: once the per-run budget is used up, everything is skipped.
        if self.verifications_this_run >= self.max_per_run:
            self.stats["skipped"] += 1
            return VerifierDecision(False, "skip", model_confidence,
                                    "max verifications reached", 0.0)

        # Check if this is the final answer
        is_final = (step_num == total_steps or step_num == -1)

        # Risk-based threshold
        risk_config = RISK_VERIFIER_MAP.get(risk, RISK_VERIFIER_MAP["medium"])
        min_conf = risk_config["min_confidence"]
        default_type = risk_config["verifier_type"]
        elevated_risk = risk in ("high", "critical")

        # Collect every trigger that applies; any single one is sufficient.
        reasons = []
        if model_confidence < min_conf:
            reasons.append(f"low confidence ({model_confidence:.2f} < {min_conf})")
        if is_irreversible and elevated_risk:
            reasons.append("irreversible + high risk")
        if has_prior_failures:
            reasons.append("prior failures exist")
        if model_tier <= 2 and elevated_risk:
            reasons.append("cheap model on high-risk task")
        if is_final and elevated_risk:
            reasons.append("final answer on high-risk task")

        if not reasons:
            self.stats["skipped"] += 1
            return VerifierDecision(False, "skip", model_confidence,
                                    "no verification needed", 0.0)

        self.verifications_this_run += 1

        # A confident, reversible output only warrants a spot check even
        # when the risk level would default to a full verification.
        vtype = default_type
        if model_confidence > min_conf and not is_irreversible:
            vtype = "spot_check"

        # Count by the verifier type actually returned so the counters agree
        # with the decisions (medium/low risk defaults to "spot_check").
        if vtype == "spot_check":
            self.stats["spot_checked"] += 1
        else:
            self.stats["verified"] += 1

        return VerifierDecision(True, vtype, model_confidence,
                                "; ".join(reasons), self.verifier_cost)

    def reset_run(self):
        """Start a new run by clearing the per-run counter; ``stats`` is kept."""
        self.verifications_this_run = 0