narcolepticchicken committed on
Commit
c55ff1a
·
verified ·
1 Parent(s): 99ad299

Upload aco/optimizer.py

Browse files
Files changed (1) hide show
  1. aco/optimizer.py +322 -0
aco/optimizer.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Agent Cost Optimizer - Main orchestrator."""
2
+
3
+ import uuid
4
+ import time
5
+ from typing import Dict, List, Optional, Any
6
+ from dataclasses import dataclass
7
+
8
+ from .config import ACOConfig
9
+ from .trace_schema import (
10
+ AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall,
11
+ TaskType, Outcome, FailureTag,
12
+ )
13
+ from .telemetry import CostTelemetryCollector
14
+ from .classifier import TaskCostClassifier, TaskPrediction
15
+ from .router import ModelCascadeRouter, RoutingDecision
16
+ from .context_budgeter import ContextBudgeter, ContextSource, ContextBudget
17
+ from .cache_layout import CacheAwarePromptLayout, PromptLayout
18
+ from .tool_gate import ToolUseCostGate, ToolGateDecision, ToolDecision
19
+ from .verifier_budgeter import VerifierBudgeter, VerifierBudgetDecision
20
+ from .retry_optimizer import RetryRecoveryOptimizer, RecoveryDecision, RecoveryAction
21
+ from .meta_tool_miner import MetaToolMiner
22
+ from .doom_detector import DoomDetector, DoomAssessment, DoomAction
23
+
24
+
25
@dataclass
class OptimizationResult:
    """Bundle of decisions and estimates produced by one optimizer pass.

    An instance is built at the end of ``AgentCostOptimizer.optimize``.
    Optional fields are ``None`` when the corresponding module is disabled in
    the configuration or when its preconditions were not met for this call.
    """

    # Trace identifier: taken from the run state, or freshly generated.
    trace_id: str
    # Result of the model cascade router for the classified task.
    routing_decision: RoutingDecision
    # Context budget; None when the context budgeter is disabled.
    context_budget: Optional[ContextBudget]
    # Prompt layout; None unless cache layout is enabled and a context budget
    # was produced.
    prompt_layout: Optional[PromptLayout]
    # One gate decision per planned tool call; empty when the gate is disabled
    # or no tools were planned.
    tool_decisions: List[ToolGateDecision]
    # Verifier budgeting decision; None when the verifier budgeter is disabled.
    verifier_decision: Optional[VerifierBudgetDecision]
    # Recovery decision; None unless the run state is flagged as in recovery
    # and the trace is currently active.
    recovery_decision: Optional[RecoveryDecision]
    # Doom assessment; None unless early termination is enabled and the trace
    # is currently active.
    doom_assessment: Optional[DoomAssessment]
    # Whatever the meta-tool miner returned for the planned tools; None when
    # the miner is disabled.
    meta_tool_match: Optional[Dict]
    # Sum of the model, approved-tool and verifier cost estimates.
    estimated_cost: float
    # Expected latency as predicted by the task classifier.
    estimated_latency_ms: float
    # Confidence reported by the routing decision.
    confidence: float
    # Short summary string (task type, tier, predicted failure risk).
    reasoning: str
40
+
41
+
42
class AgentCostOptimizer:
    """Universal control layer for reducing agent run costs while preserving quality.

    The optimizer owns one instance of each cost-control module (classifier,
    router, context budgeter, cache layout, tool gate, verifier budgeter,
    retry optimizer, meta-tool miner, doom detector) plus the telemetry
    collector, and combines their decisions in :meth:`optimize`.
    """

    # Context-piece names that `_build_context_sources` flags as mutable and
    # cacheable respectively.
    _MUTABLE_CONTEXT_NAMES = frozenset({"user_message", "retrieved_docs", "recent_trace"})
    _CACHEABLE_CONTEXT_NAMES = frozenset({"system_rules", "tool_descriptions", "user_preferences"})

    def __init__(self, config: Optional[ACOConfig] = None):
        """Create the optimizer and all of its sub-modules.

        Args:
            config: Configuration to use; a default ``ACOConfig()`` is built
                when omitted (or falsy).
        """
        self.config = config or ACOConfig()

        # Core modules
        self.telemetry = CostTelemetryCollector(self.config.trace_storage_path)
        self.classifier = TaskCostClassifier(self.config)
        self.router = ModelCascadeRouter(self.config)
        self.context_budgeter = ContextBudgeter(self.config)
        self.cache_layout = CacheAwarePromptLayout(self.config)
        self.tool_gate = ToolUseCostGate(self.config)
        self.verifier_budgeter = VerifierBudgeter(self.config)
        self.retry_optimizer = RetryRecoveryOptimizer(self.config)
        self.meta_tool_miner = MetaToolMiner(self.config)
        self.doom_detector = DoomDetector(self.config)

        # Runtime state
        self.active_traces: Dict[str, AgentTrace] = {}
        self.step_counter: Dict[str, int] = {}

    def optimize(self, user_request: str, run_state: Optional[Dict] = None) -> OptimizationResult:
        """Main entry point: decide how to execute an agent request cost-effectively.

        Args:
            user_request: The raw request text; passed to the classifier.
            run_state: Optional dict describing the current run. Keys read
                here: ``trace_id``, ``routing_mode``, ``context_pieces``,
                ``planned_tools``, ``current_cost``, ``previous_tool_calls``,
                ``prior_failures``, ``is_irreversible``,
                ``expected_output_tokens``, ``retrieved_docs``,
                ``step_number``, ``total_steps``, ``in_recovery`` and
                ``failure_tags``. Missing keys fall back to defaults.

        Returns:
            An ``OptimizationResult`` combining every module's decision with
            the estimated cost and latency.

        Raises:
            ValueError: Presumably, if an entry of ``failure_tags`` is not a
                valid ``FailureTag`` value (enum construction) - confirm
                against ``trace_schema``.
        """
        run_state = run_state or {}
        # A missing *or* empty/None trace id gets a fresh UUID, so a None id is
        # never propagated into the result.
        trace_id = run_state.get("trace_id") or str(uuid.uuid4())
        # Tolerate an explicit None as well as a missing key.
        planned_tools = run_state.get("planned_tools") or []

        # 1. Classify the task
        past_traces = self._get_past_traces()
        prediction = self.classifier.classify_with_history(user_request, past_traces)

        # 2. Route to model
        routing_mode = run_state.get("routing_mode", "cascade")
        routing = self.router.route(prediction, routing_mode=routing_mode)

        # Per-model parameters, with fallbacks for model ids that are not in
        # the configuration. Looked up once so later steps do not depend on
        # variables assigned inside another step's conditional.
        model_cfg = self.config.models.get(routing.model_id)
        max_ctx = model_cfg.max_context if model_cfg else 128000
        cost_1k = model_cfg.cost_per_1k_input if model_cfg else 0.01

        # 3. Budget context
        context_budget = None
        if self.config.enable_context_budgeter:
            available_sources = self._build_context_sources(run_state)
            context_budget = self.context_budgeter.budget(
                prediction.task_type, available_sources, max_ctx, cost_1k
            )

        # 4. Optimize cache layout (only meaningful once a budget exists)
        prompt_layout = None
        if self.config.enable_cache_layout and context_budget:
            content_pieces = self._build_content_pieces(context_budget)
            cache_discount = model_cfg.cache_discount_rate if model_cfg else 0.5
            prompt_layout = self.cache_layout.layout(content_pieces, cost_1k, cache_discount)

        # 5. Check for meta-tool
        meta_tool_match = None
        if self.config.enable_meta_tool_miner:
            meta_tool_match = self.meta_tool_miner.match_and_compress(
                prediction.task_type, planned_tools
            )

        # 6. Gate tool calls
        tool_decisions = []
        if self.config.enable_tool_gate:
            current_cost = run_state.get("current_cost", 0.0)
            previous_tool_calls = run_state.get("previous_tool_calls")
            # Each planned tool is expected to be a (name, input) pair.
            for tool_name, tool_input in planned_tools:
                decision = self.tool_gate.decide(
                    tool_name, prediction.task_type, tool_input,
                    previous_tool_calls, current_cost,
                    prediction.expected_cost,
                )
                tool_decisions.append(decision)

        # 7. Decide verifier
        verifier_decision = None
        if self.config.enable_verifier_budgeter:
            verifier_decision = self.verifier_budgeter.decide(
                task_type=prediction.task_type,
                model_tier_used=routing.tier,
                confidence=routing.confidence,
                has_prior_failures=bool(run_state.get("prior_failures")),
                is_irreversible=run_state.get("is_irreversible", False),
                output_length_tokens=run_state.get("expected_output_tokens", 1024),
                retrieval_evidence_count=len(run_state.get("retrieved_docs") or []),
                step_number=run_state.get("step_number", 1),
                total_steps=run_state.get("total_steps", 1),
                mode="risk_weighted",
            )

        # Steps 8 and 9 only apply to traces registered via start_trace().
        trace = self.active_traces.get(trace_id)
        in_active_trace = trace_id in self.active_traces
        current_step = None
        if in_active_trace:
            current_step = trace.steps[-1] if trace.steps else None

        # 8. Check for doom
        doom = None
        if self.config.enable_early_termination and in_active_trace:
            # NOTE(review): the last argument is expected_cost * 10000; its
            # meaning is not visible here - confirm against DoomDetector.assess.
            doom = self.doom_detector.assess(
                trace, current_step, prediction.expected_cost, prediction.expected_cost * 10000
            )

        # 9. Recovery (if in recovery mode)
        recovery = None
        if run_state.get("in_recovery") and in_active_trace:
            failure_tags = [FailureTag(f) for f in run_state.get("failure_tags", [])]
            recovery = self.retry_optimizer.decide_recovery(
                prediction.task_type,
                current_step,
                failure_tags,
                trace.total_cost_computed,
                prediction.expected_cost,
                routing.tier,
                run_state.get("step_number", 1),
                trace.steps,
            )

        # Estimate cost: only tools the gate decided to use are counted.
        est_model_cost = self._estimate_model_cost(routing, context_budget)
        est_tool_cost = sum(d.estimated_cost for d in tool_decisions if d.decision == ToolDecision.USE)
        est_verifier_cost = verifier_decision.estimated_verifier_cost if verifier_decision else 0.0

        estimated_cost = est_model_cost + est_tool_cost + est_verifier_cost
        estimated_latency = prediction.expected_latency_ms

        return OptimizationResult(
            trace_id=trace_id,
            routing_decision=routing,
            context_budget=context_budget,
            prompt_layout=prompt_layout,
            tool_decisions=tool_decisions,
            verifier_decision=verifier_decision,
            recovery_decision=recovery,
            doom_assessment=doom,
            meta_tool_match=meta_tool_match,
            estimated_cost=estimated_cost,
            estimated_latency_ms=estimated_latency,
            confidence=routing.confidence,
            reasoning=f"Task={prediction.task_type.value}, tier={routing.tier}, risk={prediction.risk_of_failure:.2f}",
        )

    def start_trace(self, trace_id: str, user_request: str, prediction: TaskPrediction) -> AgentTrace:
        """Start a telemetry trace and register it as active.

        The per-trace step counter is reset to zero.

        Returns:
            The trace object created by the telemetry collector.
        """
        trace = self.telemetry.start_trace(trace_id, user_request, prediction.task_type)
        self.active_traces[trace_id] = trace
        self.step_counter[trace_id] = 0
        return trace

    def record_step(
        self,
        trace_id: str,
        model_call: ModelCall,
        tool_calls: Optional[List[ToolCall]] = None,
        verifier_calls: Optional[List[VerifierCall]] = None,
        context_size_tokens: int = 0,
        step_outcome: Optional[Outcome] = None,
    ) -> None:
        """Record one step of a trace through the telemetry collector.

        The step id is ``"<trace_id>_step_<n>"`` where ``n`` is a per-trace
        counter starting at 1. ``tool_calls`` and ``verifier_calls`` default
        to empty lists.
        """
        step_number = self.step_counter.get(trace_id, 0) + 1
        self.step_counter[trace_id] = step_number
        step_id = f"{trace_id}_step_{step_number}"
        self.telemetry.add_step(
            trace_id=trace_id,
            step_id=step_id,
            model_call=model_call,
            tool_calls=tool_calls or [],
            verifier_calls=verifier_calls or [],
            context_size_tokens=context_size_tokens,
            step_outcome=step_outcome,
        )

    def finalize_trace(
        self,
        trace_id: str,
        outcome: Outcome,
        failure_tags: Optional[List[FailureTag]] = None,
        user_satisfaction: Optional[float] = None,
    ) -> Optional[AgentTrace]:
        """Close an active trace and hand it to telemetry for finalization.

        Traces that ended in ``SUCCESS`` or ``PARTIAL_SUCCESS`` are first fed
        to the meta-tool miner.

        Returns:
            Whatever the telemetry collector returns for the finalized trace,
            or ``None`` when ``trace_id`` is not an active trace.
        """
        trace = self.active_traces.pop(trace_id, None)
        # Drop the step counter too, otherwise it grows by one entry per trace
        # for the lifetime of the optimizer.
        self.step_counter.pop(trace_id, None)
        if not trace:
            return None

        # Mine for meta-tools if successful
        if outcome in (Outcome.SUCCESS, Outcome.PARTIAL_SUCCESS):
            self.meta_tool_miner.ingest_trace(trace)

        return self.telemetry.finalize_trace(
            trace_id=trace_id,
            final_outcome=outcome,
            failure_tags=failure_tags,
            user_satisfaction=user_satisfaction,
        )

    def compute_cost_adjusted_score(
        self,
        trace: AgentTrace,
        success_score: float = 1.0,
        safety_bonus: float = 0.0,
        artifact_bonus: float = 0.0,
        calibration_bonus: float = 0.0,
    ) -> float:
        """Compute the cost-adjusted quality score.

        The score is the sum of the success score and bonuses, minus weighted
        penalties for cost, tool calls, verifier calls, latency (in seconds)
        and retries, minus fixed penalties for the ``UNSAFE_CHEAP_MODEL`` and
        ``MISSED_ESCALATION`` failure tags and a ``FALSE_DONE`` outcome. All
        weights and penalties come from the configuration.
        """
        cost = trace.total_cost_computed
        retries = trace.total_retries
        tools = trace.total_tool_calls
        verifiers = trace.total_verifier_calls

        score = (
            success_score
            + safety_bonus
            + artifact_bonus
            + calibration_bonus
            - self.config.model_cost_weight * cost
            - self.config.tool_cost_weight * tools * 0.001
            - self.config.verifier_cost_weight * verifiers * 0.001
            - self.config.latency_weight * trace.total_latency_ms / 1000
            - self.config.retry_penalty_weight * retries * 0.01
        )

        # Penalize critical failures
        if FailureTag.UNSAFE_CHEAP_MODEL in trace.failure_tags:
            score -= self.config.unsafe_cheap_model_penalty
        if FailureTag.MISSED_ESCALATION in trace.failure_tags:
            score -= self.config.missed_escalation_penalty
        if trace.final_outcome == Outcome.FALSE_DONE:
            score -= self.config.false_done_penalty

        return score

    def _get_past_traces(self) -> List[Dict]:
        """Get historical traces as dicts.

        Traces that fail to load, load as an empty dict, or are not dicts are
        skipped.
        """
        traces = []
        for tid in self.telemetry.list_traces():
            loaded = self.telemetry.load_trace(tid)
            if loaded and isinstance(loaded, dict):
                traces.append(loaded)
        return traces

    def _build_context_sources(self, run_state: Dict) -> List[ContextSource]:
        """Build context source objects from run state.

        Reads ``run_state["context_pieces"]`` (name -> content). Token counts
        are approximated as ``len(content) // 4``; importance and staleness
        are fixed at 0.5 and 0.0.
        """
        # `or {}` tolerates an explicit None as well as a missing key.
        pieces = run_state.get("context_pieces") or {}
        return [
            ContextSource(
                name=name,
                tokens=len(content) // 4,
                importance=0.5,
                staleness=0.0,
                mutable=name in self._MUTABLE_CONTEXT_NAMES,
                cacheable=name in self._CACHEABLE_CONTEXT_NAMES,
            )
            for name, content in pieces.items()
        ]

    def _build_content_pieces(self, context_budget: ContextBudget) -> Dict[str, str]:
        """Build content pieces dict from context budget.

        Allocated sources map to a ``"[name]"`` placeholder; summarized
        sources map to their summary and override an allocated entry of the
        same name.
        """
        pieces = {src.name: f"[{src.name}]" for src in context_budget.allocated_sources}
        for src, summary in context_budget.summarized_sources:
            pieces[src.name] = summary
        return pieces

    def _estimate_model_cost(self, routing: RoutingDecision, context_budget: Optional[ContextBudget]) -> float:
        """Estimate the model cost for the routed model.

        Input tokens come from the context budget (4096 when there is none);
        output tokens are taken as ``routing.max_tokens``. Returns a flat
        0.01 when the routed model is not in the configuration.
        """
        model_cfg = self.config.models.get(routing.model_id)
        if not model_cfg:
            return 0.01
        input_tokens = context_budget.total_budget_tokens if context_budget else 4096
        output_tokens = routing.max_tokens
        input_cost = (input_tokens / 1000) * model_cfg.cost_per_1k_input
        output_cost = (output_tokens / 1000) * model_cfg.cost_per_1k_output
        return input_cost + output_cost

    def get_stats(self) -> Dict[str, Any]:
        """Get optimizer-wide statistics, grouped by sub-module."""
        return {
            "telemetry": self.telemetry.get_stats(),
            "cache": self.cache_layout.report(),
            "meta_tools": self.meta_tool_miner.get_stats(),
            "doom": self.doom_detector.get_stats(),
        }

    @classmethod
    def from_config(cls, path: str) -> "AgentCostOptimizer":
        """Build an optimizer from a YAML configuration file at ``path``."""
        config = ACOConfig.from_yaml(path)
        return cls(config)