narcolepticchicken committed on
Commit
f17c0fd
·
verified ·
1 Parent(s): ff456f8

Upload aco/optimizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/optimizer.py +164 -316
aco/optimizer.py CHANGED
@@ -1,326 +1,174 @@
1
- """Agent Cost Optimizer - Main orchestrator."""
2
-
3
- import uuid
4
- import time
5
  from typing import Dict, List, Optional, Any
6
- from dataclasses import dataclass
7
-
8
- from .config import ACOConfig
9
- from .trace_schema import (
10
- AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall,
11
- TaskType, Outcome, FailureTag,
12
- )
13
- from .telemetry import CostTelemetryCollector
14
- from .classifier import TaskCostClassifier, TaskPrediction
15
  from .router import ModelCascadeRouter, RoutingDecision
16
- from .context_budgeter import ContextBudgeter, ContextSource, ContextBudget
17
- from .cache_layout import CacheAwarePromptLayout, PromptLayout
18
- from .tool_gate import ToolUseCostGate, ToolGateDecision, ToolDecision
19
- from .verifier_budgeter import VerifierBudgeter, VerifierBudgetDecision
20
- from .retry_optimizer import RetryRecoveryOptimizer, RecoveryDecision, RecoveryAction
21
- from .meta_tool_miner import MetaToolMiner
22
- from .doom_detector import DoomDetector, DoomAssessment, DoomAction
23
-
24
-
25
@dataclass
class OptimizationResult:
    """Bundle of the decisions produced for one optimization request.

    Every field is filled in by the optimizer's main entry point; the
    ``Optional`` ones stay ``None`` when the corresponding stage is disabled
    or does not apply to the request.
    """

    # Identifier of the trace these decisions belong to.
    trace_id: str
    # Model-routing decision for the request.
    routing_decision: RoutingDecision
    # Context budget, if context budgeting ran.
    context_budget: Optional[ContextBudget]
    # Cache-aware prompt layout, if one was computed.
    prompt_layout: Optional[PromptLayout]
    # One gate decision per planned tool call (possibly empty).
    tool_decisions: List[ToolGateDecision]
    # Verifier budgeting decision, if the verifier budgeter ran.
    verifier_decision: Optional[VerifierBudgetDecision]
    # Recovery decision, if the run was in recovery mode.
    recovery_decision: Optional[RecoveryDecision]
    # Doom assessment for an already-active trace, if any.
    doom_assessment: Optional[DoomAssessment]
    # Result of the meta-tool lookup, if any.
    meta_tool_match: Optional[Dict]
    # Cost estimate for the step.
    estimated_cost: float
    # Latency estimate, in milliseconds.
    estimated_latency_ms: float
    # Confidence attached to the decision.
    confidence: float
    # Short human-readable summary of the decision.
    reasoning: str
40
-
41
-
42
class AgentCostOptimizer:
    """Universal control layer for reducing agent run costs while preserving quality.

    Owns one instance of each sub-module (telemetry, classifier, router,
    context budgeter, cache layout, tool gate, verifier budgeter, retry
    optimizer, meta-tool miner, doom detector) and keeps per-trace runtime
    state in ``active_traces`` and ``step_counter``, both keyed by trace id.
    """

    def __init__(self, config: Optional[ACOConfig] = None):
        """Build every sub-module from ``config`` (a default ``ACOConfig`` if omitted)."""
        self.config = config or ACOConfig()

        # Core modules
        self.telemetry = CostTelemetryCollector(self.config.trace_storage_path)
        self.classifier = TaskCostClassifier(self.config)
        self.router = ModelCascadeRouter(self.config)
        self.context_budgeter = ContextBudgeter(self.config)
        self.cache_layout = CacheAwarePromptLayout(self.config)
        self.tool_gate = ToolUseCostGate(self.config)
        self.verifier_budgeter = VerifierBudgeter(self.config)
        self.retry_optimizer = RetryRecoveryOptimizer(self.config)
        self.meta_tool_miner = MetaToolMiner(self.config)
        self.doom_detector = DoomDetector(self.config)

        # Runtime state
        self.active_traces: Dict[str, AgentTrace] = {}
        self.step_counter: Dict[str, int] = {}

    def optimize(self, user_request: str, run_state: Optional[Dict] = None) -> OptimizationResult:
        """Main entry point: decide how to execute an agent request cost-effectively.

        Args:
            user_request: The request text passed to the classifier.
            run_state: Optional dict of caller-supplied state. Keys read here:
                ``trace_id``, ``routing_mode``, ``context_pieces``,
                ``planned_tools``, ``current_cost``, ``previous_tool_calls``,
                ``prior_failures``, ``is_irreversible``,
                ``expected_output_tokens``, ``retrieved_docs``,
                ``step_number``, ``total_steps``, ``in_recovery`` and
                ``failure_tags``.

        Returns:
            An ``OptimizationResult`` aggregating every stage's decision.
        """
        run_state = run_state or {}
        trace_id = run_state.get("trace_id", str(uuid.uuid4()))

        # 1. Classify the task
        past_traces = self._get_past_traces()
        prediction = self.classifier.classify_with_history(user_request, past_traces)

        # 2. Route to model
        routing_mode = run_state.get("routing_mode", "cascade")
        routing = self.router.route(prediction, routing_mode=routing_mode)

        # Resolve the routed model's pricing once, up front, so the context
        # and cache stages do not depend on each other to bind these names.
        model_cfg = self.config.models.get(routing.model_id)
        cost_1k = model_cfg.cost_per_1k_input if model_cfg else 0.01

        # 3. Budget context
        context_budget = None
        if self.config.enable_context_budgeter:
            available_sources = self._build_context_sources(run_state)
            max_ctx = model_cfg.max_context if model_cfg else 128000
            context_budget = self.context_budgeter.budget(
                prediction.task_type, available_sources, max_ctx, cost_1k
            )

        # 4. Optimize cache layout
        prompt_layout = None
        if self.config.enable_cache_layout and context_budget:
            content_pieces = self._build_content_pieces(context_budget)
            cache_discount = model_cfg.cache_discount_rate if model_cfg else 0.5
            prompt_layout = self.cache_layout.layout(
                content_pieces, cost_1k, cache_discount
            )

        # Read once; both the meta-tool and tool-gate stages consume it.
        planned_tools = run_state.get("planned_tools", [])

        # 5. Check for meta-tool
        meta_tool_match = None
        if self.config.enable_meta_tool_miner:
            planned_tool_names = [t[0] for t in planned_tools]
            meta_tool_match = self.meta_tool_miner.match_and_compress(
                prediction.task_type, planned_tool_names
            )

        # 6. Gate tool calls
        tool_decisions = []
        if self.config.enable_tool_gate:
            current_cost = run_state.get("current_cost", 0.0)
            previous_tool_calls = run_state.get("previous_tool_calls")
            for tool_name, tool_input in planned_tools:
                decision = self.tool_gate.decide(
                    tool_name, prediction.task_type, tool_input,
                    previous_tool_calls, current_cost,
                    prediction.expected_cost,
                )
                tool_decisions.append(decision)

        # 7. Decide verifier
        verifier_decision = None
        if self.config.enable_verifier_budgeter:
            verifier_decision = self.verifier_budgeter.decide(
                task_type=prediction.task_type,
                model_tier_used=routing.tier,
                confidence=routing.confidence,
                has_prior_failures=bool(run_state.get("prior_failures")),
                is_irreversible=run_state.get("is_irreversible", False),
                output_length_tokens=run_state.get("expected_output_tokens", 1024),
                retrieval_evidence_count=len(run_state.get("retrieved_docs", [])),
                step_number=run_state.get("step_number", 1),
                total_steps=run_state.get("total_steps", 1),
                mode="risk_weighted",
            )

        # 8. Check for doom (only meaningful for a trace that is already active)
        doom = None
        if self.config.enable_early_termination and trace_id in self.active_traces:
            trace = self.active_traces[trace_id]
            current_step = trace.steps[-1] if trace.steps else None
            # NOTE(review): the last argument is expected_cost * 10000; its
            # meaning is defined by DoomDetector.assess -- confirm the scaling.
            doom = self.doom_detector.assess(
                trace, current_step, prediction.expected_cost, prediction.expected_cost * 10000
            )

        # 9. Recovery (if in recovery mode)
        recovery = None
        if run_state.get("in_recovery") and trace_id in self.active_traces:
            trace = self.active_traces[trace_id]
            failure_tags = [FailureTag(f) for f in run_state.get("failure_tags", [])]
            current_step = trace.steps[-1] if trace.steps else None
            recovery = self.retry_optimizer.decide_recovery(
                prediction.task_type,
                current_step,
                failure_tags,
                trace.total_cost_computed,
                prediction.expected_cost,
                routing.tier,
                run_state.get("step_number", 1),
                trace.steps,
            )

        # Estimate cost: model + tools that will actually run + verifier
        est_model_cost = self._estimate_model_cost(routing, context_budget)
        est_tool_cost = sum(d.estimated_cost for d in tool_decisions if d.decision == ToolDecision.USE)
        est_verifier_cost = verifier_decision.estimated_verifier_cost if verifier_decision else 0.0

        estimated_cost = est_model_cost + est_tool_cost + est_verifier_cost
        estimated_latency = prediction.expected_latency_ms

        return OptimizationResult(
            trace_id=trace_id,
            routing_decision=routing,
            context_budget=context_budget,
            prompt_layout=prompt_layout,
            tool_decisions=tool_decisions,
            verifier_decision=verifier_decision,
            recovery_decision=recovery,
            doom_assessment=doom,
            meta_tool_match=meta_tool_match,
            estimated_cost=estimated_cost,
            estimated_latency_ms=estimated_latency,
            confidence=routing.confidence,
            reasoning=f"Task={prediction.task_type.value}, tier={routing.tier}, risk={prediction.risk_of_failure:.2f}",
        )

    def start_trace(self, trace_id: str, user_request: str, prediction: TaskPrediction) -> AgentTrace:
        """Open a telemetry trace, register it as active and reset its step counter."""
        trace = self.telemetry.start_trace(trace_id, user_request, prediction.task_type)
        self.active_traces[trace_id] = trace
        self.step_counter[trace_id] = 0
        return trace

    def record_step(
        self,
        trace_id: str,
        model_call: ModelCall,
        tool_calls: Optional[List[ToolCall]] = None,
        verifier_calls: Optional[List[VerifierCall]] = None,
        context_size_tokens: int = 0,
        step_outcome: Optional[Outcome] = None,
    ) -> None:
        """Record one step on the telemetry collector.

        The step id is ``"<trace_id>_step_<n>"`` where ``n`` is a per-trace
        counter starting at 1. Missing ``tool_calls`` / ``verifier_calls`` are
        forwarded as empty lists.
        """
        self.step_counter[trace_id] = self.step_counter.get(trace_id, 0) + 1
        step_id = f"{trace_id}_step_{self.step_counter[trace_id]}"
        self.telemetry.add_step(
            trace_id=trace_id,
            step_id=step_id,
            model_call=model_call,
            tool_calls=tool_calls or [],
            verifier_calls=verifier_calls or [],
            context_size_tokens=context_size_tokens,
            step_outcome=step_outcome,
        )

    def finalize_trace(
        self,
        trace_id: str,
        outcome: Outcome,
        failure_tags: Optional[List[FailureTag]] = None,
        user_satisfaction: Optional[float] = None,
    ) -> Optional[AgentTrace]:
        """Close an active trace and hand it to telemetry.

        Returns:
            Whatever ``telemetry.finalize_trace`` returns, or ``None`` when
            ``trace_id`` is not an active trace.
        """
        trace = self.active_traces.pop(trace_id, None)
        # Drop the per-trace counter too; otherwise it grows for every trace.
        self.step_counter.pop(trace_id, None)
        if trace is None:
            return None

        # Mine for meta-tools if successful
        if outcome in (Outcome.SUCCESS, Outcome.PARTIAL_SUCCESS):
            self.meta_tool_miner.ingest_trace(trace)

        return self.telemetry.finalize_trace(
            trace_id=trace_id,
            final_outcome=outcome,
            failure_tags=failure_tags,
            user_satisfaction=user_satisfaction,
        )

    def compute_cost_adjusted_score(
        self,
        trace: AgentTrace,
        success_score: float = 1.0,
        safety_bonus: float = 0.0,
        artifact_bonus: float = 0.0,
        calibration_bonus: float = 0.0,
    ) -> float:
        """Compute the cost-adjusted quality score.

        Starts from the sum of the four score/bonus arguments, subtracts
        config-weighted penalties for cost, tool calls, verifier calls,
        latency (in seconds) and retries, then subtracts fixed penalties for
        the ``UNSAFE_CHEAP_MODEL`` / ``MISSED_ESCALATION`` failure tags and a
        ``FALSE_DONE`` final outcome.
        """
        cost = trace.total_cost_computed
        retries = trace.total_retries
        tools = trace.total_tool_calls
        verifiers = trace.total_verifier_calls

        score = (
            success_score
            + safety_bonus
            + artifact_bonus
            + calibration_bonus
            - self.config.model_cost_weight * cost
            - self.config.tool_cost_weight * tools * 0.001
            - self.config.verifier_cost_weight * verifiers * 0.001
            - self.config.latency_weight * trace.total_latency_ms / 1000
            - self.config.retry_penalty_weight * retries * 0.01
        )

        # Penalize critical failures
        if FailureTag.UNSAFE_CHEAP_MODEL in trace.failure_tags:
            score -= self.config.unsafe_cheap_model_penalty
        if FailureTag.MISSED_ESCALATION in trace.failure_tags:
            score -= self.config.missed_escalation_penalty
        if trace.final_outcome == Outcome.FALSE_DONE:
            score -= self.config.false_done_penalty

        return score

    def _get_past_traces(self) -> List[Dict]:
        """Get historical traces as dicts, skipping empty or non-dict loads."""
        traces = []
        for tid in self.telemetry.list_traces():
            loaded = self.telemetry.load_trace(tid)
            if loaded and isinstance(loaded, dict):
                traces.append(loaded)
        return traces

    def _build_context_sources(self, run_state: Dict) -> List[ContextSource]:
        """Build context source objects from ``run_state["context_pieces"]``."""
        mutable_names = ("user_message", "retrieved_docs", "recent_trace")
        cacheable_names = ("system_rules", "tool_descriptions", "user_preferences")
        return [
            ContextSource(
                name=name,
                # presumably a ~4-characters-per-token heuristic -- confirm
                tokens=len(content) // 4,
                importance=0.5,
                staleness=0.0,
                mutable=name in mutable_names,
                cacheable=name in cacheable_names,
            )
            for name, content in run_state.get("context_pieces", {}).items()
        ]

    def _build_content_pieces(self, context_budget: ContextBudget) -> Dict[str, str]:
        """Build content pieces dict from context budget.

        Allocated sources map to a ``"[name]"`` placeholder; summarized
        sources map to their summary and override a same-named placeholder.
        """
        pieces = {src.name: f"[{src.name}]" for src in context_budget.allocated_sources}
        for src, summary in context_budget.summarized_sources:
            pieces[src.name] = summary
        return pieces

    def _estimate_model_cost(self, routing: RoutingDecision, context_budget: Optional[ContextBudget]) -> float:
        """Estimate input + output cost for the routed model.

        Falls back to ``0.01`` when the model is not in ``config.models`` and
        to 4096 input tokens when there is no context budget.
        """
        model_cfg = self.config.models.get(routing.model_id)
        if not model_cfg:
            return 0.01
        input_tokens = context_budget.total_budget_tokens if context_budget else 4096
        output_tokens = routing.max_tokens
        input_cost = (input_tokens / 1000) * model_cfg.cost_per_1k_input
        output_cost = (output_tokens / 1000) * model_cfg.cost_per_1k_output
        return input_cost + output_cost

    def get_stats(self) -> Dict[str, Any]:
        """Get optimizer-wide statistics."""
        return {
            "telemetry": self.telemetry.get_stats(),
            "cache": self.cache_layout.report(),
            "meta_tools": self.meta_tool_miner.get_stats(),
            "doom": self.doom_detector.get_stats(),
        }

    @classmethod
    def from_config(cls, path: str) -> "AgentCostOptimizer":
        """Alternate constructor: load an ``ACOConfig`` from a YAML file at ``path``."""
        config = ACOConfig.from_yaml(path)
        return cls(config)
 
1
+ """ACO Optimizer: Main orchestrator that coordinates all modules."""
2
+ import json, time, uuid
 
 
3
  from typing import Dict, List, Optional, Any
4
+ from .config import ACOConfig, RoutingPolicy
5
+ from .trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall
6
+ from .classifier import TaskCostClassifier
 
 
 
 
 
 
7
  from .router import ModelCascadeRouter, RoutingDecision
8
+ from .context_budgeter import ContextBudgeter, ContextBudget
9
+ from .cache_layout import CacheAwareLayout, PromptLayout
10
+ from .tool_gate import ToolCostGate, ToolDecision
11
+ from .verifier_budgeter import VerifierBudgeter, VerifierDecision
12
+ from .retry_optimizer import RetryOptimizer, RecoveryAction
13
+ from .meta_tool_miner import MetaToolMiner, MacroTool
14
+ from .doom_detector import DoomDetector, DoomAssessment
15
+
16
class ACOOptimizer:
    """Main orchestrator that coordinates all ACO modules for one run at a time.

    A run is opened with ``start_run`` and closed with ``end_run``; in between
    the helper methods consult the individual modules using the active trace.
    Finished traces are kept in memory and summarized by ``get_stats``.
    """

    def __init__(self, config: Optional[ACOConfig] = None):
        """Instantiate every module; ``config`` defaults to ``ACOConfig()``."""
        self.config = config or ACOConfig()
        self.classifier = TaskCostClassifier()
        self.router = ModelCascadeRouter(
            model_path=self.config.router_model_path,
            safety_threshold=self.config.routing_policy.safety_threshold,
            downgrade_threshold=self.config.routing_policy.downgrade_threshold,
            task_floor=self.config.task_floors,
            tier_costs=self.config.tier_costs,
        )
        self.context_budgeter = ContextBudgeter()
        self.cache_layout = CacheAwareLayout()
        self.tool_gate = ToolCostGate()
        self.verifier_budgeter = VerifierBudgeter()
        self.retry_optimizer = RetryOptimizer(
            max_retries=self.config.routing_policy.max_retries,
        )
        self.meta_tool_miner = MetaToolMiner()
        self.doom_detector = DoomDetector()
        # Trace of the run in progress (None between runs), its step counter,
        # and the list of finished traces.
        self._current_trace: Optional[AgentTrace] = None
        self._step_num = 0
        self._traces: List[AgentTrace] = []

    def start_run(self, request: str) -> Dict:
        """Classify and route ``request``, budget its context and open a new trace.

        Returns:
            A dict with keys ``trace_id``, ``prediction``, ``routing``,
            ``context_budget`` and ``macro_tool`` (the matched macro's name,
            or ``None``).
        """
        prediction = self.classifier.classify(request)
        task_type = prediction["task_type"]
        difficulty = prediction["difficulty"]
        routing = self.router.route(request, task_type, difficulty, prediction)
        context_budget = self.context_budgeter.budget(
            task_type, difficulty,
            prediction["needs_retrieval"], prediction["needs_tools"],
        )
        # Check for meta-tool match
        macro = None
        if self.config.enable_meta_tools:
            macro = self.meta_tool_miner.match_macro(request, task_type)

        self._current_trace = AgentTrace(
            request=request,
            task_type=task_type,
            difficulty=difficulty,
            predicted_tier=routing.tier,
        )
        self._step_num = 0
        self.retry_optimizer.reset_run()
        self.verifier_budgeter.reset_run()
        return {
            "trace_id": self._current_trace.trace_id,
            "prediction": prediction,
            "routing": {
                "model_id": routing.model_id,
                "tier": routing.tier,
                "confidence": routing.confidence,
                "cost_estimate": routing.cost_estimate,
                "dynamic_difficulty": routing.dynamic_difficulty,
                "escalated": routing.escalated,
                "downgraded": routing.downgraded,
                "reasoning": routing.reasoning,
            },
            "context_budget": {
                "total_tokens": context_budget.total_tokens,
                "keep_exact": context_budget.keep_exact,
                "summarize": context_budget.summarize,
                "omit": context_budget.omit,
                "retrieve_on_demand": context_budget.retrieve_on_demand,
                "cache_prefix": context_budget.cache_prefix,
            },
            "macro_tool": macro.name if macro else None,
        }

    def record_step(self, model_call: Optional[Dict] = None, tool_calls: Optional[List[Dict]] = None,
                    context_size: int = 0, verifier_called: bool = False,
                    verifier_result: Optional[str] = None, retry_num: int = 0,
                    recovery_action: Optional[str] = None) -> None:
        """Append one step to the active trace.

        ``model_call`` and each entry of ``tool_calls`` are keyword dicts
        expanded into ``ModelCall`` / ``ToolCall``. Does nothing when no run
        is active.
        """
        if not self._current_trace:
            # No run in progress: the step would be discarded anyway, so do
            # not advance the counter or build objects for it.
            return
        self._step_num += 1
        mc = ModelCall(**model_call) if model_call else None
        tcs = [ToolCall(**tc) for tc in (tool_calls or [])]
        self._current_trace.steps.append(
            TraceStep(
                step_num=self._step_num,
                model_call=mc,
                tool_calls=tcs,
                context_size=context_size,
                verifier_called=verifier_called,
                verifier_result=verifier_result,
                retry_num=retry_num,
                recovery_action=recovery_action,
            )
        )

    def check_doom(self, current_cost: float = 0.0) -> DoomAssessment:
        """Ask the doom detector to assess the active run.

        Returns a "continue" assessment with reason ``"no active trace"``
        when no run is in progress.
        """
        if not self._current_trace:
            return DoomAssessment(False, 0.0, [], "continue", "no active trace")
        step_dicts = [s.__dict__ for s in self._current_trace.steps]
        # NOTE(review): the trailing 4 is a fixed argument whose meaning is
        # defined by DoomDetector.assess -- confirm against that signature.
        return self.doom_detector.assess(
            step_dicts,
            current_cost, self.config.routing_policy.max_cost_per_task, 4)

    def should_verify(self, is_irreversible: bool = False,
                      has_prior_failures: bool = False) -> VerifierDecision:
        """Ask the verifier budgeter whether the active run should be verified.

        Returns a "skip" decision with reason ``"no active trace"`` when no
        run is in progress.
        """
        if not self._current_trace:
            return VerifierDecision(False, "skip", 0.0, "no active trace", 0.0)
        # NOTE(review): "medium" and 0.8 are fixed arguments; their meaning is
        # defined by VerifierBudgeter.should_verify -- confirm.
        return self.verifier_budgeter.should_verify(
            self._current_trace.task_type, "medium", 0.8,
            is_irreversible, has_prior_failures,
            self._current_trace.predicted_tier)

    def gate_tool(self, tool_name: str, args: Dict) -> ToolDecision:
        """Ask the tool gate whether ``tool_name`` should be called with ``args``.

        Returns a "skip" decision with reason ``"no active trace"`` when no
        run is in progress.
        """
        if not self._current_trace:
            return ToolDecision("skip", tool_name, 0.0, "no active trace", 0.0, 0.0)
        # NOTE(review): passes the current step number, the next step number
        # and a fixed 0.5; semantics are defined by ToolCostGate.gate -- confirm.
        return self.tool_gate.gate(tool_name, args, self._current_trace.task_type,
                                   self._step_num, self._step_num + 1, 0.5)

    def get_recovery(self, failure_tag: str, current_tier: int,
                     retry_num: int, previous_actions: Optional[List[str]] = None,
                     run_cost: float = 0.0) -> RecoveryAction:
        """Delegate to the retry optimizer, adding the configured per-task cost cap."""
        return self.retry_optimizer.get_recovery(
            failure_tag, current_tier, retry_num,
            previous_actions, run_cost,
            self.config.routing_policy.max_cost_per_task)

    def end_run(self, success: bool, outcome: str = "completed",
                artifacts: Optional[List[str]] = None, failure_tags: Optional[List[str]] = None,
                user_correction: bool = False) -> Optional[AgentTrace]:
        """Finalize the active trace, store it and clear the active-run slot.

        Returns:
            The finalized trace, or ``None`` when no run was active.
        """
        trace = self._current_trace
        if trace:
            trace.task_success = success
            trace.final_outcome = outcome
            trace.artifacts_created = artifacts or []
            trace.failure_tags = failure_tags or []
            trace.user_correction = user_correction
            # Copy the computed totals onto the trace itself.
            summary = trace.compute_summary()
            trace.total_cost = summary["total_cost"]
            trace.total_tokens = summary["total_tokens"]
            trace.total_tool_calls = summary["total_tool_calls"]
            trace.total_retries = summary["total_retries"]
            trace.total_verifier_calls = summary["total_verifier_calls"]
            trace.cache_hit_rate = summary["cache_hit_rate"]
            self._traces.append(trace)
        self._current_trace = None
        return trace

    def layout_prompt(self, sources: Dict[str, str]) -> PromptLayout:
        """Lay out ``sources`` using a context budget for the active run.

        Without an active run the budget is computed for task type
        ``"unknown_ambiguous"`` at difficulty 3.
        """
        if not self._current_trace:
            budget = self.context_budgeter.budget("unknown_ambiguous", 3, False, False)
        else:
            budget = self.context_budgeter.budget(
                self._current_trace.task_type,
                self._current_trace.difficulty,
                False, False)
        return self.cache_layout.layout(sources, budget)

    def get_stats(self) -> Dict:
        """Summarize finished runs plus per-module statistics."""
        run_count = len(self._traces)
        return {
            "total_runs": run_count,
            "successful_runs": sum(1 for t in self._traces if t.task_success),
            # max(..., 1) keeps the average defined when no run has finished.
            "avg_cost": sum(t.total_cost for t in self._traces) / max(run_count, 1),
            "cache_stats": self.cache_layout.stats(),
            "tool_stats": self.tool_gate.call_stats,
            "verifier_stats": self.verifier_budgeter.stats,
            "retry_stats": self.retry_optimizer.recovery_stats,
        }