"""Agent Cost Optimizer - Main orchestrator.""" import uuid import time from typing import Dict, List, Optional, Any from dataclasses import dataclass from .config import ACOConfig from .trace_schema import ( AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall, TaskType, Outcome, FailureTag, ) from .telemetry import CostTelemetryCollector from .classifier import TaskCostClassifier, TaskPrediction from .router import ModelCascadeRouter, RoutingDecision from .context_budgeter import ContextBudgeter, ContextSource, ContextBudget from .cache_layout import CacheAwarePromptLayout, PromptLayout from .tool_gate import ToolUseCostGate, ToolGateDecision, ToolDecision from .verifier_budgeter import VerifierBudgeter, VerifierBudgetDecision from .retry_optimizer import RetryRecoveryOptimizer, RecoveryDecision, RecoveryAction from .meta_tool_miner import MetaToolMiner from .doom_detector import DoomDetector, DoomAssessment, DoomAction @dataclass class OptimizationResult: trace_id: str routing_decision: RoutingDecision context_budget: Optional[ContextBudget] prompt_layout: Optional[PromptLayout] tool_decisions: List[ToolGateDecision] verifier_decision: Optional[VerifierBudgetDecision] recovery_decision: Optional[RecoveryDecision] doom_assessment: Optional[DoomAssessment] meta_tool_match: Optional[Dict] estimated_cost: float estimated_latency_ms: float confidence: float reasoning: str class AgentCostOptimizer: """Universal control layer for reducing agent run costs while preserving quality.""" def __init__(self, config: Optional[ACOConfig] = None): self.config = config or ACOConfig() # Core modules self.telemetry = CostTelemetryCollector(self.config.trace_storage_path) self.classifier = TaskCostClassifier(self.config) self.router = ModelCascadeRouter(self.config) self.context_budgeter = ContextBudgeter(self.config) self.cache_layout = CacheAwarePromptLayout(self.config) self.tool_gate = ToolUseCostGate(self.config) self.verifier_budgeter = VerifierBudgeter(self.config) self.retry_optimizer = RetryRecoveryOptimizer(self.config) self.meta_tool_miner = MetaToolMiner(self.config) self.doom_detector = DoomDetector(self.config) # Runtime state self.active_traces: Dict[str, AgentTrace] = {} self.step_counter: Dict[str, int] = {} def optimize(self, user_request: str, run_state: Optional[Dict] = None) -> OptimizationResult: """Main entry point: decide how to execute an agent request cost-effectively.""" run_state = run_state or {} trace_id = run_state.get("trace_id", str(uuid.uuid4())) # 1. Classify the task past_traces = self._get_past_traces() prediction = self.classifier.classify_with_history(user_request, past_traces) # 2. Route to model routing_mode = run_state.get("routing_mode", "cascade") routing = self.router.route(prediction, routing_mode=routing_mode) # 3. Budget context context_budget = None if self.config.enable_context_budgeter: available_sources = self._build_context_sources(run_state) model_cfg = self.config.models.get(routing.model_id) max_ctx = model_cfg.max_context if model_cfg else 128000 cost_1k = model_cfg.cost_per_1k_input if model_cfg else 0.01 context_budget = self.context_budgeter.budget( prediction.task_type, available_sources, max_ctx, cost_1k ) # 4. Optimize cache layout prompt_layout = None if self.config.enable_cache_layout and context_budget: content_pieces = self._build_content_pieces(context_budget) cache_discount = 0.5 if model_cfg: cache_discount = model_cfg.cache_discount_rate prompt_layout = self.cache_layout.layout( content_pieces, cost_1k, cache_discount ) # 5. Check for meta-tool meta_tool_match = None if self.config.enable_meta_tool_miner: planned_tools = run_state.get("planned_tools", []) planned_tool_names = [t[0] for t in planned_tools] meta_tool_match = self.meta_tool_miner.match_and_compress( prediction.task_type, planned_tool_names ) # 6. Gate tool calls tool_decisions = [] if self.config.enable_tool_gate: planned_tools = run_state.get("planned_tools", []) current_cost = run_state.get("current_cost", 0.0) for tool_name, tool_input in planned_tools: decision = self.tool_gate.decide( tool_name, prediction.task_type, tool_input, run_state.get("previous_tool_calls"), current_cost, prediction.expected_cost, ) tool_decisions.append(decision) # 7. Decide verifier verifier_decision = None if self.config.enable_verifier_budgeter: verifier_decision = self.verifier_budgeter.decide( task_type=prediction.task_type, model_tier_used=routing.tier, confidence=routing.confidence, has_prior_failures=bool(run_state.get("prior_failures")), is_irreversible=run_state.get("is_irreversible", False), output_length_tokens=run_state.get("expected_output_tokens", 1024), retrieval_evidence_count=len(run_state.get("retrieved_docs", [])), step_number=run_state.get("step_number", 1), total_steps=run_state.get("total_steps", 1), mode="risk_weighted", ) # 8. Check for doom doom = None if self.config.enable_early_termination and trace_id in self.active_traces: trace = self.active_traces[trace_id] current_step = trace.steps[-1] if trace.steps else None doom = self.doom_detector.assess( trace, current_step, prediction.expected_cost, prediction.expected_cost * 10000 ) # 9. Recovery (if in recovery mode) recovery = None if run_state.get("in_recovery") and trace_id in self.active_traces: trace = self.active_traces[trace_id] failure_tags = [FailureTag(f) for f in run_state.get("failure_tags", [])] current_step = trace.steps[-1] if trace.steps else None recovery = self.retry_optimizer.decide_recovery( prediction.task_type, current_step, failure_tags, trace.total_cost_computed, prediction.expected_cost, routing.tier, run_state.get("step_number", 1), trace.steps, ) # Estimate cost est_model_cost = self._estimate_model_cost(routing, context_budget) est_tool_cost = sum(d.estimated_cost for d in tool_decisions if d.decision == ToolDecision.USE) est_verifier_cost = verifier_decision.estimated_verifier_cost if verifier_decision else 0.0 estimated_cost = est_model_cost + est_tool_cost + est_verifier_cost estimated_latency = prediction.expected_latency_ms return OptimizationResult( trace_id=trace_id, routing_decision=routing, context_budget=context_budget, prompt_layout=prompt_layout, tool_decisions=tool_decisions, verifier_decision=verifier_decision, recovery_decision=recovery, doom_assessment=doom, meta_tool_match=meta_tool_match, estimated_cost=estimated_cost, estimated_latency_ms=estimated_latency, confidence=routing.confidence, reasoning=f"Task={prediction.task_type.value}, tier={routing.tier}, risk={prediction.risk_of_failure:.2f}", ) def start_trace(self, trace_id: str, user_request: str, prediction: TaskPrediction) -> AgentTrace: trace = self.telemetry.start_trace(trace_id, user_request, prediction.task_type) self.active_traces[trace_id] = trace self.step_counter[trace_id] = 0 return trace def record_step( self, trace_id: str, model_call: ModelCall, tool_calls: Optional[List[ToolCall]] = None, verifier_calls: Optional[List[VerifierCall]] = None, context_size_tokens: int = 0, step_outcome: Optional[Outcome] = None, ) -> None: self.step_counter[trace_id] = self.step_counter.get(trace_id, 0) + 1 step_id = f"{trace_id}_step_{self.step_counter[trace_id]}" self.telemetry.add_step( trace_id=trace_id, step_id=step_id, model_call=model_call, tool_calls=tool_calls or [], verifier_calls=verifier_calls or [], context_size_tokens=context_size_tokens, step_outcome=step_outcome, ) def finalize_trace( self, trace_id: str, outcome: Outcome, failure_tags: Optional[List[FailureTag]] = None, user_satisfaction: Optional[float] = None, ) -> AgentTrace: trace = self.active_traces.pop(trace_id, None) if trace: # Mine for meta-tools if successful if outcome in (Outcome.SUCCESS, Outcome.PARTIAL_SUCCESS): self.meta_tool_miner.ingest_trace(trace) return self.telemetry.finalize_trace( trace_id=trace_id, final_outcome=outcome, failure_tags=failure_tags, user_satisfaction=user_satisfaction, ) return None def compute_cost_adjusted_score( self, trace: AgentTrace, success_score: float = 1.0, safety_bonus: float = 0.0, artifact_bonus: float = 0.0, calibration_bonus: float = 0.0, ) -> float: """Compute the cost-adjusted quality score.""" cost = trace.total_cost_computed retries = trace.total_retries tools = trace.total_tool_calls verifiers = trace.total_verifier_calls score = ( success_score + safety_bonus + artifact_bonus + calibration_bonus - self.config.model_cost_weight * cost - self.config.tool_cost_weight * tools * 0.001 - self.config.verifier_cost_weight * verifiers * 0.001 - self.config.latency_weight * trace.total_latency_ms / 1000 - self.config.retry_penalty_weight * retries * 0.01 ) # Penalize critical failures if FailureTag.UNSAFE_CHEAP_MODEL in trace.failure_tags: score -= self.config.unsafe_cheap_model_penalty if FailureTag.MISSED_ESCALATION in trace.failure_tags: score -= self.config.missed_escalation_penalty if trace.final_outcome == Outcome.FALSE_DONE: score -= self.config.false_done_penalty return score def _get_past_traces(self) -> List[Dict]: """Get historical traces as dicts.""" traces = [] for tid in self.telemetry.list_traces(): t = self.telemetry.load_trace(tid) if t and isinstance(t, dict): traces.append(t) return traces def _build_context_sources(self, run_state: Dict) -> List[ContextSource]: """Build context source objects from run state.""" sources = [] for name, content in run_state.get("context_pieces", {}).items(): sources.append(ContextSource( name=name, tokens=len(content) // 4, importance=0.5, staleness=0.0, mutable=name in ["user_message", "retrieved_docs", "recent_trace"], cacheable=name in ["system_rules", "tool_descriptions", "user_preferences"], )) return sources def _build_content_pieces(self, context_budget: ContextBudget) -> Dict[str, str]: """Build content pieces dict from context budget.""" pieces = {} for src in context_budget.allocated_sources: pieces[src.name] = f"[{src.name}]" for src, summary in context_budget.summarized_sources: pieces[src.name] = summary return pieces def _estimate_model_cost(self, routing: RoutingDecision, context_budget: Optional[ContextBudget]) -> float: model_cfg = self.config.models.get(routing.model_id) if not model_cfg: return 0.01 input_tokens = context_budget.total_budget_tokens if context_budget else 4096 output_tokens = routing.max_tokens input_cost = (input_tokens / 1000) * model_cfg.cost_per_1k_input output_cost = (output_tokens / 1000) * model_cfg.cost_per_1k_output return input_cost + output_cost def get_stats(self) -> Dict[str, Any]: """Get optimizer-wide statistics.""" return { "telemetry": self.telemetry.get_stats(), "cache": self.cache_layout.report(), "meta_tools": self.meta_tool_miner.get_stats(), "doom": self.doom_detector.get_stats(), } @classmethod def from_config(cls, path: str) -> "AgentCostOptimizer": config = ACOConfig.from_yaml(path) return cls(config)