| """Agent Cost Optimizer - Main orchestrator.""" |
|
|
| import uuid |
| import time |
| from typing import Dict, List, Optional, Any |
| from dataclasses import dataclass |
|
|
| from .config import ACOConfig |
| from .trace_schema import ( |
| AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall, |
| TaskType, Outcome, FailureTag, |
| ) |
| from .telemetry import CostTelemetryCollector |
| from .classifier import TaskCostClassifier, TaskPrediction |
| from .router import ModelCascadeRouter, RoutingDecision |
| from .context_budgeter import ContextBudgeter, ContextSource, ContextBudget |
| from .cache_layout import CacheAwarePromptLayout, PromptLayout |
| from .tool_gate import ToolUseCostGate, ToolGateDecision, ToolDecision |
| from .verifier_budgeter import VerifierBudgeter, VerifierBudgetDecision |
| from .retry_optimizer import RetryRecoveryOptimizer, RecoveryDecision, RecoveryAction |
| from .meta_tool_miner import MetaToolMiner |
| from .doom_detector import DoomDetector, DoomAssessment, DoomAction |
|
|
|
|
@dataclass
class OptimizationResult:
    """Bundle of every per-request decision produced by AgentCostOptimizer.optimize().

    Fields that are Optional are None when the corresponding feature is
    disabled in the config or not applicable to this request (e.g. no
    recovery decision when the run is not in recovery).
    """

    # Identifier of the (possibly not-yet-started) trace this result belongs to.
    trace_id: str
    # Model/tier selection from ModelCascadeRouter.route().
    routing_decision: RoutingDecision
    # Context-window allocation; None when the context budgeter is disabled.
    context_budget: Optional[ContextBudget]
    # Cache-aware prompt arrangement; None when cache layout is disabled or
    # no context budget was produced.
    prompt_layout: Optional[PromptLayout]
    # One gate decision per planned tool call (empty when the gate is disabled).
    tool_decisions: List[ToolGateDecision]
    # Verification spend decision; None when the verifier budgeter is disabled.
    verifier_decision: Optional[VerifierBudgetDecision]
    # Retry/recovery plan; None unless the caller flagged an in-recovery run.
    recovery_decision: Optional[RecoveryDecision]
    # Early-termination ("doom") assessment; None unless the run is already
    # being traced and early termination is enabled.
    doom_assessment: Optional[DoomAssessment]
    # Matched meta-tool compression of the planned tool sequence, if any.
    meta_tool_match: Optional[Dict]
    # Estimated total cost: model + gated tools + verifier (same currency
    # units as the model cost config).
    estimated_cost: float
    # Latency estimate taken from the classifier's task prediction.
    estimated_latency_ms: float
    # Router's confidence in the chosen model/tier.
    confidence: float
    # Human-readable summary of the task type, tier, and predicted risk.
    reasoning: str
|
|
|
|
class AgentCostOptimizer:
    """Universal control layer for reducing agent run costs while preserving quality.

    Orchestrates the individual ACO components (classifier, router, context
    budgeter, cache-aware prompt layout, tool gate, verifier budgeter, retry
    optimizer, meta-tool miner, doom detector) behind a single ``optimize()``
    entry point, and tracks per-run telemetry via ``start_trace`` /
    ``record_step`` / ``finalize_trace``.
    """

    # Multiplier applied to the predicted task cost to derive the hard budget
    # ceiling handed to the doom detector in optimize(). Kept as a named
    # constant rather than an inline magic number.
    _DOOM_BUDGET_MULTIPLIER = 10000

    def __init__(self, config: Optional[ACOConfig] = None):
        """Initialize all sub-components from *config* (defaults to ACOConfig())."""
        self.config = config or ACOConfig()

        # Sub-components; each receives the shared config.
        self.telemetry = CostTelemetryCollector(self.config.trace_storage_path)
        self.classifier = TaskCostClassifier(self.config)
        self.router = ModelCascadeRouter(self.config)
        self.context_budgeter = ContextBudgeter(self.config)
        self.cache_layout = CacheAwarePromptLayout(self.config)
        self.tool_gate = ToolUseCostGate(self.config)
        self.verifier_budgeter = VerifierBudgeter(self.config)
        self.retry_optimizer = RetryRecoveryOptimizer(self.config)
        self.meta_tool_miner = MetaToolMiner(self.config)
        self.doom_detector = DoomDetector(self.config)

        # In-flight run state, keyed by trace id.
        self.active_traces: Dict[str, AgentTrace] = {}
        self.step_counter: Dict[str, int] = {}

    def optimize(self, user_request: str, run_state: Optional[Dict] = None) -> OptimizationResult:
        """Main entry point: decide how to execute an agent request cost-effectively.

        Args:
            user_request: Raw user request text to classify and route.
            run_state: Optional state for an in-flight run. Recognized keys
                (all optional): "trace_id", "routing_mode", "context_pieces",
                "planned_tools", "previous_tool_calls", "current_cost",
                "prior_failures", "is_irreversible", "expected_output_tokens",
                "retrieved_docs", "step_number", "total_steps", "in_recovery",
                "failure_tags".

        Returns:
            An OptimizationResult bundling every sub-component's decision plus
            aggregate cost/latency estimates.
        """
        run_state = run_state or {}
        trace_id = run_state.get("trace_id", str(uuid.uuid4()))

        # 1. Classify the task, calibrated against historical traces.
        past_traces = self._get_past_traces()
        prediction = self.classifier.classify_with_history(user_request, past_traces)

        # 2. Route to a model tier.
        routing_mode = run_state.get("routing_mode", "cascade")
        routing = self.router.route(prediction, routing_mode=routing_mode)

        # The model config is needed by both the context budgeter and the
        # cache-layout branches below. Resolve it once up front so neither
        # branch can raise NameError when the two feature flags are toggled
        # independently (previously model_cfg/cost_1k were bound only inside
        # the context-budgeter branch).
        model_cfg = self.config.models.get(routing.model_id)
        # Fallbacks for a model with no config entry: assume a 128k-token
        # context window and $0.01 per 1k input tokens.
        max_ctx = model_cfg.max_context if model_cfg else 128000
        cost_1k = model_cfg.cost_per_1k_input if model_cfg else 0.01

        # 3. Budget the context window.
        context_budget = None
        if self.config.enable_context_budgeter:
            available_sources = self._build_context_sources(run_state)
            context_budget = self.context_budgeter.budget(
                prediction.task_type, available_sources, max_ctx, cost_1k
            )

        # 4. Lay out the prompt for prefix-cache reuse (needs a budget).
        prompt_layout = None
        if self.config.enable_cache_layout and context_budget:
            content_pieces = self._build_content_pieces(context_budget)
            cache_discount = model_cfg.cache_discount_rate if model_cfg else 0.5
            prompt_layout = self.cache_layout.layout(
                content_pieces, cost_1k, cache_discount
            )

        # Planned tool calls are (tool_name, tool_input) pairs; fetched once
        # for both the meta-tool miner and the tool gate.
        planned_tools = run_state.get("planned_tools", [])

        # 5. Try to compress the planned tool sequence into a known meta-tool.
        meta_tool_match = None
        if self.config.enable_meta_tool_miner:
            planned_tool_names = [t[0] for t in planned_tools]
            meta_tool_match = self.meta_tool_miner.match_and_compress(
                prediction.task_type, planned_tool_names
            )

        # 6. Gate each planned tool call on cost.
        tool_decisions = []
        if self.config.enable_tool_gate:
            current_cost = run_state.get("current_cost", 0.0)
            for tool_name, tool_input in planned_tools:
                decision = self.tool_gate.decide(
                    tool_name, prediction.task_type, tool_input,
                    run_state.get("previous_tool_calls"), current_cost,
                    prediction.expected_cost,
                )
                tool_decisions.append(decision)

        # 7. Decide how much verification spend this step deserves.
        verifier_decision = None
        if self.config.enable_verifier_budgeter:
            verifier_decision = self.verifier_budgeter.decide(
                task_type=prediction.task_type,
                model_tier_used=routing.tier,
                confidence=routing.confidence,
                has_prior_failures=bool(run_state.get("prior_failures")),
                is_irreversible=run_state.get("is_irreversible", False),
                output_length_tokens=run_state.get("expected_output_tokens", 1024),
                retrieval_evidence_count=len(run_state.get("retrieved_docs", [])),
                step_number=run_state.get("step_number", 1),
                total_steps=run_state.get("total_steps", 1),
                mode="risk_weighted",
            )

        # 8. Doom check — only meaningful for runs we are already tracking.
        doom = None
        if self.config.enable_early_termination and trace_id in self.active_traces:
            trace = self.active_traces[trace_id]
            current_step = trace.steps[-1] if trace.steps else None
            doom = self.doom_detector.assess(
                trace, current_step, prediction.expected_cost,
                prediction.expected_cost * self._DOOM_BUDGET_MULTIPLIER,
            )

        # 9. Recovery planning when the caller flags an in-recovery run.
        recovery = None
        if run_state.get("in_recovery") and trace_id in self.active_traces:
            trace = self.active_traces[trace_id]
            failure_tags = [FailureTag(f) for f in run_state.get("failure_tags", [])]
            current_step = trace.steps[-1] if trace.steps else None
            recovery = self.retry_optimizer.decide_recovery(
                prediction.task_type,
                current_step,
                failure_tags,
                trace.total_cost_computed,
                prediction.expected_cost,
                routing.tier,
                run_state.get("step_number", 1),
                trace.steps,
            )

        # 10. Aggregate cost/latency estimates. Only tools the gate approved
        # (decision == USE) contribute to the tool-cost estimate.
        est_model_cost = self._estimate_model_cost(routing, context_budget)
        est_tool_cost = sum(d.estimated_cost for d in tool_decisions if d.decision == ToolDecision.USE)
        est_verifier_cost = verifier_decision.estimated_verifier_cost if verifier_decision else 0.0

        estimated_cost = est_model_cost + est_tool_cost + est_verifier_cost
        estimated_latency = prediction.expected_latency_ms

        return OptimizationResult(
            trace_id=trace_id,
            routing_decision=routing,
            context_budget=context_budget,
            prompt_layout=prompt_layout,
            tool_decisions=tool_decisions,
            verifier_decision=verifier_decision,
            recovery_decision=recovery,
            doom_assessment=doom,
            meta_tool_match=meta_tool_match,
            estimated_cost=estimated_cost,
            estimated_latency_ms=estimated_latency,
            confidence=routing.confidence,
            reasoning=f"Task={prediction.task_type.value}, tier={routing.tier}, risk={prediction.risk_of_failure:.2f}",
        )

    def start_trace(self, trace_id: str, user_request: str, prediction: TaskPrediction) -> AgentTrace:
        """Begin telemetry for a run and register it as active.

        Resets the step counter for *trace_id* and returns the new trace.
        """
        trace = self.telemetry.start_trace(trace_id, user_request, prediction.task_type)
        self.active_traces[trace_id] = trace
        self.step_counter[trace_id] = 0
        return trace

    def record_step(
        self,
        trace_id: str,
        model_call: ModelCall,
        tool_calls: Optional[List[ToolCall]] = None,
        verifier_calls: Optional[List[VerifierCall]] = None,
        context_size_tokens: int = 0,
        step_outcome: Optional[Outcome] = None,
    ) -> None:
        """Append one executed step to the trace's telemetry.

        Step ids are derived from a per-trace monotonically increasing
        counter: "<trace_id>_step_<n>". Safe to call for a trace id that was
        never started (the counter starts at 0 implicitly).
        """
        self.step_counter[trace_id] = self.step_counter.get(trace_id, 0) + 1
        step_id = f"{trace_id}_step_{self.step_counter[trace_id]}"
        self.telemetry.add_step(
            trace_id=trace_id,
            step_id=step_id,
            model_call=model_call,
            tool_calls=tool_calls or [],
            verifier_calls=verifier_calls or [],
            context_size_tokens=context_size_tokens,
            step_outcome=step_outcome,
        )

    def finalize_trace(
        self,
        trace_id: str,
        outcome: Outcome,
        failure_tags: Optional[List[FailureTag]] = None,
        user_satisfaction: Optional[float] = None,
    ) -> Optional[AgentTrace]:
        """Close out a run: deregister it, feed successes to the meta-tool
        miner, and finalize telemetry.

        Returns:
            The finalized trace, or None when *trace_id* is not an active
            trace. (The return annotation was previously ``AgentTrace`` even
            though this None path exists.)
        """
        trace = self.active_traces.pop(trace_id, None)
        # Also drop the step counter so long-lived processes don't accumulate
        # one dict entry per finished trace.
        self.step_counter.pop(trace_id, None)
        if trace is None:
            return None

        # Only successful (or partially successful) runs are worth mining for
        # reusable meta-tool sequences.
        if outcome in (Outcome.SUCCESS, Outcome.PARTIAL_SUCCESS):
            self.meta_tool_miner.ingest_trace(trace)

        return self.telemetry.finalize_trace(
            trace_id=trace_id,
            final_outcome=outcome,
            failure_tags=failure_tags,
            user_satisfaction=user_satisfaction,
        )

    def compute_cost_adjusted_score(
        self,
        trace: AgentTrace,
        success_score: float = 1.0,
        safety_bonus: float = 0.0,
        artifact_bonus: float = 0.0,
        calibration_bonus: float = 0.0,
    ) -> float:
        """Compute the cost-adjusted quality score.

        Bonuses add to the score; weighted cost, tool/verifier usage, latency,
        and retries subtract from it; specific failure modes apply flat
        config-defined penalties. Higher is better; may go negative.
        """
        cost = trace.total_cost_computed
        retries = trace.total_retries
        tools = trace.total_tool_calls
        verifiers = trace.total_verifier_calls

        # The 0.001 / 0.01 factors scale raw counts into the same order of
        # magnitude as the dollar-cost terms before the config weights apply.
        score = (
            success_score
            + safety_bonus
            + artifact_bonus
            + calibration_bonus
            - self.config.model_cost_weight * cost
            - self.config.tool_cost_weight * tools * 0.001
            - self.config.verifier_cost_weight * verifiers * 0.001
            - self.config.latency_weight * trace.total_latency_ms / 1000
            - self.config.retry_penalty_weight * retries * 0.01
        )

        # Flat penalties for specific failure modes.
        if FailureTag.UNSAFE_CHEAP_MODEL in trace.failure_tags:
            score -= self.config.unsafe_cheap_model_penalty
        if FailureTag.MISSED_ESCALATION in trace.failure_tags:
            score -= self.config.missed_escalation_penalty
        if trace.final_outcome == Outcome.FALSE_DONE:
            score -= self.config.false_done_penalty

        return score

    def _get_past_traces(self) -> List[Dict]:
        """Load all stored historical traces that deserialized to non-empty dicts."""
        traces = []
        for tid in self.telemetry.list_traces():
            t = self.telemetry.load_trace(tid)
            # load_trace may return None or a non-dict on a corrupt record;
            # also skip empty dicts (the `t and` check).
            if t and isinstance(t, dict):
                traces.append(t)
        return traces

    def _build_context_sources(self, run_state: Dict) -> List[ContextSource]:
        """Build ContextSource objects from run_state["context_pieces"].

        Token counts use the common ~4-chars-per-token heuristic; importance
        and staleness are fixed defaults here (the budgeter may reweight).
        """
        sources = []
        for name, content in run_state.get("context_pieces", {}).items():
            sources.append(ContextSource(
                name=name,
                tokens=len(content) // 4,  # rough chars->tokens heuristic
                importance=0.5,
                staleness=0.0,
                # Mutable pieces change per turn; cacheable ones are stable
                # prefixes suitable for prompt caching.
                mutable=name in ["user_message", "retrieved_docs", "recent_trace"],
                cacheable=name in ["system_rules", "tool_descriptions", "user_preferences"],
            ))
        return sources

    def _build_content_pieces(self, context_budget: ContextBudget) -> Dict[str, str]:
        """Map source names to their prompt content from a context budget.

        Fully allocated sources get a "[name]" placeholder (real content is
        substituted downstream — TODO confirm against the cache-layout caller);
        summarized sources use their summary text, overriding any placeholder
        for the same name.
        """
        pieces = {}
        for src in context_budget.allocated_sources:
            pieces[src.name] = f"[{src.name}]"
        for src, summary in context_budget.summarized_sources:
            pieces[src.name] = summary
        return pieces

    def _estimate_model_cost(self, routing: RoutingDecision, context_budget: Optional[ContextBudget]) -> float:
        """Estimate the model cost for one call under the routed model's pricing.

        Falls back to a flat 0.01 when the routed model has no config entry,
        and to 4096 input tokens when no context budget was computed.
        """
        model_cfg = self.config.models.get(routing.model_id)
        if not model_cfg:
            return 0.01
        input_tokens = context_budget.total_budget_tokens if context_budget else 4096
        output_tokens = routing.max_tokens
        input_cost = (input_tokens / 1000) * model_cfg.cost_per_1k_input
        output_cost = (output_tokens / 1000) * model_cfg.cost_per_1k_output
        return input_cost + output_cost

    def get_stats(self) -> Dict[str, Any]:
        """Get optimizer-wide statistics from the stats-reporting sub-components."""
        return {
            "telemetry": self.telemetry.get_stats(),
            "cache": self.cache_layout.report(),
            "meta_tools": self.meta_tool_miner.get_stats(),
            "doom": self.doom_detector.get_stats(),
        }

    @classmethod
    def from_config(cls, path: str) -> "AgentCostOptimizer":
        """Alternate constructor: build an optimizer from a YAML config file."""
        config = ACOConfig.from_yaml(path)
        return cls(config)
|
|