"""ACO Optimizer: Main orchestrator that coordinates all modules."""

import json
import time
import uuid
from typing import Any, Dict, List, Optional

from .config import ACOConfig, RoutingPolicy
from .trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall
from .classifier import TaskCostClassifier
from .router import ModelCascadeRouter, RoutingDecision
from .context_budgeter import ContextBudgeter, ContextBudget
from .cache_layout import CacheAwareLayout, PromptLayout
from .tool_gate import ToolCostGate, ToolDecision
from .verifier_budgeter import VerifierBudgeter, VerifierDecision
from .retry_optimizer import RetryOptimizer, RecoveryAction
from .meta_tool_miner import MetaToolMiner, MacroTool
from .doom_detector import DoomDetector, DoomAssessment
from .execution_feedback import ExecutionFeedbackRouter, CascadeResult, FeedbackSignal


class ACOOptimizer:
    """Coordinate classification, routing, budgeting and tracing for agent runs.

    One run is tracked at a time in ``_current_trace``. ``start_run`` opens it,
    ``record_step`` appends steps to it, and ``end_run`` finalizes it. Finished
    traces accumulate in ``_traces``, which ``get_stats`` aggregates.
    """

    def __init__(self, config: Optional[ACOConfig] = None):
        """Build every sub-module.

        Args:
            config: Optimizer configuration. A default ``ACOConfig()`` is used
                when omitted. The router, retry optimizer and execution-feedback
                router are parameterized from it.
        """
        self.config = config or ACOConfig()
        policy = self.config.routing_policy

        self.classifier = TaskCostClassifier()
        self.router = ModelCascadeRouter(
            model_path=self.config.router_model_path,
            safety_threshold=policy.safety_threshold,
            downgrade_threshold=policy.downgrade_threshold,
            task_floor=self.config.task_floors,
            tier_costs=self.config.tier_costs,
        )
        self.context_budgeter = ContextBudgeter()
        self.cache_layout = CacheAwareLayout()
        self.tool_gate = ToolCostGate()
        self.verifier_budgeter = VerifierBudgeter()
        self.retry_optimizer = RetryOptimizer(
            max_retries=policy.max_retries,
        )
        self.meta_tool_miner = MetaToolMiner()
        self.doom_detector = DoomDetector()
        self.execution_feedback = ExecutionFeedbackRouter(
            tier_costs=self.config.tier_costs,
            task_floors=self.config.task_floors,
        )

        # Per-run state; reset by start_run() and cleared by end_run().
        self._current_trace: Optional[AgentTrace] = None
        self._step_num = 0
        # History of finished runs, consumed by get_stats().
        self._traces: List[AgentTrace] = []

    def start_run(self, request: str) -> Dict:
        """Classify *request*, pick a model, budget context and open a new trace.

        If a run is already active, it is replaced without being added to
        the history.

        Returns:
            A dict with keys ``trace_id``, ``prediction`` (the raw classifier
            output), ``routing``, ``context_budget`` and ``macro_tool`` (a
            matched macro's name, or ``None``).
        """
        prediction = self.classifier.classify(request)
        task_type = prediction["task_type"]
        difficulty = prediction["difficulty"]

        routing = self.router.route(request, task_type, difficulty, prediction)
        context_budget = self.context_budgeter.budget(
            task_type,
            difficulty,
            prediction["needs_retrieval"],
            prediction["needs_tools"],
        )

        # Meta-tool (macro) matching is optional and gated by config.
        macro = (
            self.meta_tool_miner.match_macro(request, task_type)
            if self.config.enable_meta_tools
            else None
        )

        self._current_trace = AgentTrace(
            request=request,
            task_type=task_type,
            difficulty=difficulty,
            predicted_tier=routing.tier,
        )
        self._step_num = 0
        self.retry_optimizer.reset_run()
        self.verifier_budgeter.reset_run()

        return {
            "trace_id": self._current_trace.trace_id,
            "prediction": prediction,
            "routing": {
                "model_id": routing.model_id,
                "tier": routing.tier,
                "confidence": routing.confidence,
                "cost_estimate": routing.cost_estimate,
                "dynamic_difficulty": routing.dynamic_difficulty,
                "escalated": routing.escalated,
                "downgraded": routing.downgraded,
                "reasoning": routing.reasoning,
            },
            "context_budget": {
                "total_tokens": context_budget.total_tokens,
                "keep_exact": context_budget.keep_exact,
                "summarize": context_budget.summarize,
                "omit": context_budget.omit,
                "retrieve_on_demand": context_budget.retrieve_on_demand,
                "cache_prefix": context_budget.cache_prefix,
            },
            "macro_tool": macro.name if macro else None,
        }

    def record_step(self, model_call: Optional[Dict] = None,
                    tool_calls: Optional[List[Dict]] = None,
                    context_size: int = 0, verifier_called: bool = False,
                    verifier_result: Optional[str] = None, retry_num: int = 0,
                    recovery_action: Optional[str] = None) -> None:
        """Append one step to the active trace.

        ``model_call`` and each entry of ``tool_calls`` are passed as keyword
        arguments to ``ModelCall`` / ``ToolCall``. The step counter is advanced
        even when no run is active. In that case the step is discarded.
        """
        self._step_num += 1
        mc = ModelCall(**model_call) if model_call else None
        tcs = [ToolCall(**tc) for tc in (tool_calls or [])]
        step = TraceStep(
            step_num=self._step_num,
            model_call=mc,
            tool_calls=tcs,
            context_size=context_size,
            verifier_called=verifier_called,
            verifier_result=verifier_result,
            retry_num=retry_num,
            recovery_action=recovery_action,
        )
        if self._current_trace:
            self._current_trace.steps.append(step)

    def check_doom(self, current_cost: float = 0.0) -> DoomAssessment:
        """Assess the active run with the doom detector.

        The trace steps are passed as their attribute dicts, together with
        *current_cost* and the policy's ``max_cost_per_task``. With no active
        run, a "continue" assessment is returned.
        """
        if not self._current_trace:
            return DoomAssessment(False, 0.0, [], "continue", "no active trace")
        # NOTE(review): the meaning of the trailing constant 4 is defined by
        # DoomDetector.assess (not visible here); consider moving it to config.
        return self.doom_detector.assess(
            [s.__dict__ for s in self._current_trace.steps],
            current_cost,
            self.config.routing_policy.max_cost_per_task,
            4,
        )

    def should_verify(self, is_irreversible: bool = False,
                      has_prior_failures: bool = False) -> VerifierDecision:
        """Ask the verifier budgeter whether the current step should be verified."""
        if not self._current_trace:
            return VerifierDecision(False, "skip", 0.0, "no active trace", 0.0)
        # NOTE(review): "medium" and 0.8 are fixed placeholders rather than the
        # trace's own difficulty/confidence; confirm against VerifierBudgeter.
        return self.verifier_budgeter.should_verify(
            self._current_trace.task_type,
            "medium",
            0.8,
            is_irreversible,
            has_prior_failures,
            self._current_trace.predicted_tier,
        )

    def gate_tool(self, tool_name: str, args: Dict) -> ToolDecision:
        """Ask the tool cost gate whether *tool_name* should be called with *args*."""
        if not self._current_trace:
            return ToolDecision("skip", tool_name, 0.0, "no active trace", 0.0, 0.0)
        # NOTE(review): the semantics of (step_num, step_num + 1, 0.5) are
        # defined by ToolCostGate.gate; verify they match its signature.
        return self.tool_gate.gate(
            tool_name,
            args,
            self._current_trace.task_type,
            self._step_num,
            self._step_num + 1,
            0.5,
        )

    def cascade_step(self, request: str, initial_tier: int,
                     cheap_logprobs: List[float], cheap_response: str,
                     strong_response: str = "",
                     task_type: Optional[str] = None) -> CascadeResult:
        """Execution-feedback cascade: use cheap model output to decide escalation.

        If *task_type* is not given, it falls back to the active trace's task
        type, or to ``"unknown_ambiguous"`` when no run is active. The task
        floor comes from ``config.task_floors`` and defaults to 1.
        """
        if not task_type:
            task_type = (
                self._current_trace.task_type
                if self._current_trace
                else "unknown_ambiguous"
            )
        floor = self.config.task_floors.get(task_type, 1)
        return self.execution_feedback.cascade(
            request,
            initial_tier,
            cheap_logprobs,
            cheap_response,
            strong_response,
            task_type=task_type,
            task_floor=floor,
        )

    def analyze_output_confidence(self, token_logprobs: List[float],
                                  task_type: str = "unknown",
                                  current_tier: int = 2) -> FeedbackSignal:
        """Analyze model output confidence for routing decisions."""
        return self.execution_feedback.analyze_output(
            token_logprobs, task_type=task_type, current_tier=current_tier)

    def get_recovery(self, failure_tag: str, current_tier: int, retry_num: int,
                     previous_actions: Optional[List[str]] = None,
                     run_cost: float = 0.0) -> RecoveryAction:
        """Delegate to the retry optimizer, capping cost by ``max_cost_per_task``."""
        return self.retry_optimizer.get_recovery(
            failure_tag,
            current_tier,
            retry_num,
            previous_actions,
            run_cost,
            self.config.routing_policy.max_cost_per_task,
        )

    def end_run(self, success: bool, outcome: str = "completed",
                artifacts: Optional[List[str]] = None,
                failure_tags: Optional[List[str]] = None,
                user_correction: bool = False) -> Optional[AgentTrace]:
        """Finalize the active run, store it in the history and return it.

        Outcome fields are written onto the trace, and its totals are copied
        from ``compute_summary()``.

        Returns:
            The finished trace, or ``None`` when no run was active. Nothing is
            appended to the history in that case, because a ``None`` entry
            would break ``get_stats``.
        """
        trace = self._current_trace
        if not trace:
            return None

        trace.task_success = success
        trace.final_outcome = outcome
        trace.artifacts_created = artifacts or []
        trace.failure_tags = failure_tags or []
        trace.user_correction = user_correction

        summary = trace.compute_summary()
        trace.total_cost = summary["total_cost"]
        trace.total_tokens = summary["total_tokens"]
        trace.total_tool_calls = summary["total_tool_calls"]
        trace.total_retries = summary["total_retries"]
        trace.total_verifier_calls = summary["total_verifier_calls"]
        trace.cache_hit_rate = summary["cache_hit_rate"]

        self._traces.append(trace)
        self._current_trace = None
        return trace

    def layout_prompt(self, sources: Dict[str, str]) -> PromptLayout:
        """Lay out prompt *sources* under a context budget for the active run.

        Without an active run, the budget for ``"unknown_ambiguous"`` at
        difficulty 3 is used. Retrieval and tool needs are always passed as
        False.
        """
        if not self._current_trace:
            budget = self.context_budgeter.budget("unknown_ambiguous", 3, False, False)
        else:
            budget = self.context_budgeter.budget(
                self._current_trace.task_type,
                self._current_trace.difficulty,
                False,
                False,
            )
        return self.cache_layout.layout(sources, budget)

    def get_stats(self) -> Dict:
        """Aggregate finished-run counts and cost with the sub-modules' stats."""
        n_runs = len(self._traces)
        return {
            "total_runs": n_runs,
            "successful_runs": sum(1 for t in self._traces if t.task_success),
            # max(..., 1) avoids division by zero before any run has finished.
            "avg_cost": sum(t.total_cost for t in self._traces) / max(n_runs, 1),
            "cache_stats": self.cache_layout.stats(),
            "tool_stats": self.tool_gate.call_stats,
            "verifier_stats": self.verifier_budgeter.stats,
            "retry_stats": self.retry_optimizer.recovery_stats,
        }