| """ACO Optimizer: Main orchestrator that coordinates all modules.""" |
| import json, time, uuid |
| from typing import Dict, List, Optional, Any |
| from .config import ACOConfig, RoutingPolicy |
| from .trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall |
| from .classifier import TaskCostClassifier |
| from .router import ModelCascadeRouter, RoutingDecision |
| from .context_budgeter import ContextBudgeter, ContextBudget |
| from .cache_layout import CacheAwareLayout, PromptLayout |
| from .tool_gate import ToolCostGate, ToolDecision |
| from .verifier_budgeter import VerifierBudgeter, VerifierDecision |
| from .retry_optimizer import RetryOptimizer, RecoveryAction |
| from .meta_tool_miner import MetaToolMiner, MacroTool |
| from .doom_detector import DoomDetector, DoomAssessment |
| from .execution_feedback import ExecutionFeedbackRouter, CascadeResult, FeedbackSignal |
|
|
class ACOOptimizer:
    """Main orchestrator that coordinates all ACO modules for an agent run.

    One instance owns a classifier, router, context budgeter, cache layout,
    tool gate, verifier budgeter, retry optimizer, meta-tool miner, doom
    detector and execution-feedback router. A run is opened with
    ``start_run``, populated with ``record_step`` and closed with ``end_run``;
    finished traces are kept in memory and summarized by ``get_stats``.
    Only one trace is tracked at a time.
    """

    def __init__(self, config: Optional[ACOConfig] = None):
        """Build every sub-module from *config* (a default ``ACOConfig`` if omitted)."""
        self.config = config or ACOConfig()
        policy = self.config.routing_policy

        self.classifier = TaskCostClassifier()
        self.router = ModelCascadeRouter(
            model_path=self.config.router_model_path,
            safety_threshold=policy.safety_threshold,
            downgrade_threshold=policy.downgrade_threshold,
            task_floor=self.config.task_floors,
            tier_costs=self.config.tier_costs,
        )
        self.context_budgeter = ContextBudgeter()
        self.cache_layout = CacheAwareLayout()
        self.tool_gate = ToolCostGate()
        self.verifier_budgeter = VerifierBudgeter()
        self.retry_optimizer = RetryOptimizer(
            max_retries=policy.max_retries,
        )
        self.meta_tool_miner = MetaToolMiner()
        self.doom_detector = DoomDetector()
        self.execution_feedback = ExecutionFeedbackRouter(
            tier_costs=self.config.tier_costs,
            task_floors=self.config.task_floors,
        )

        # Per-run state: the trace being built and its step counter.
        self._current_trace: Optional[AgentTrace] = None
        self._step_num = 0
        # Traces of runs closed through end_run().
        self._traces: List[AgentTrace] = []

    def start_run(self, request: str) -> Dict:
        """Classify *request*, pick a model tier and context budget, open a trace.

        Any trace that is still active is replaced without being added to the
        run history. The step counter and the retry/verifier per-run state are
        reset.

        Returns:
            A dict with keys ``trace_id``, ``prediction`` (the classifier
            output), ``routing``, ``context_budget`` and ``macro_tool`` (name
            of the matched macro, or ``None`` when meta-tools are disabled or
            nothing matched).
        """
        prediction = self.classifier.classify(request)
        task_type = prediction["task_type"]
        difficulty = prediction["difficulty"]

        routing = self.router.route(request, task_type, difficulty, prediction)
        context_budget = self.context_budgeter.budget(
            task_type,
            difficulty,
            prediction["needs_retrieval"],
            prediction["needs_tools"],
        )

        macro = None
        if self.config.enable_meta_tools:
            macro = self.meta_tool_miner.match_macro(request, task_type)

        self._current_trace = AgentTrace(
            request=request,
            task_type=task_type,
            difficulty=difficulty,
            predicted_tier=routing.tier,
        )
        self._step_num = 0
        self.retry_optimizer.reset_run()
        self.verifier_budgeter.reset_run()

        return {
            "trace_id": self._current_trace.trace_id,
            "prediction": prediction,
            "routing": {
                "model_id": routing.model_id,
                "tier": routing.tier,
                "confidence": routing.confidence,
                "cost_estimate": routing.cost_estimate,
                "dynamic_difficulty": routing.dynamic_difficulty,
                "escalated": routing.escalated,
                "downgraded": routing.downgraded,
                "reasoning": routing.reasoning,
            },
            "context_budget": {
                "total_tokens": context_budget.total_tokens,
                "keep_exact": context_budget.keep_exact,
                "summarize": context_budget.summarize,
                "omit": context_budget.omit,
                "retrieve_on_demand": context_budget.retrieve_on_demand,
                "cache_prefix": context_budget.cache_prefix,
            },
            "macro_tool": macro.name if macro else None,
        }

    def record_step(self, model_call: Optional[Dict] = None,
                    tool_calls: Optional[List[Dict]] = None,
                    context_size: int = 0, verifier_called: bool = False,
                    verifier_result: Optional[str] = None, retry_num: int = 0,
                    recovery_action: Optional[str] = None) -> None:
        """Append one step to the active trace.

        *model_call* and each entry of *tool_calls* are keyword dicts expanded
        into ``ModelCall`` / ``ToolCall``. Does nothing when no run is active,
        so the step counter only advances for steps that are actually stored.
        """
        if not self._current_trace:
            # Nothing to attach the step to; avoid building throw-away objects
            # and drifting the step counter outside of a run.
            return

        self._step_num += 1
        recorded_call = ModelCall(**model_call) if model_call else None
        recorded_tools = [ToolCall(**tc) for tc in (tool_calls or [])]
        self._current_trace.steps.append(
            TraceStep(
                step_num=self._step_num,
                model_call=recorded_call,
                tool_calls=recorded_tools,
                context_size=context_size,
                verifier_called=verifier_called,
                verifier_result=verifier_result,
                retry_num=retry_num,
                recovery_action=recovery_action,
            )
        )

    def check_doom(self, current_cost: float = 0.0) -> DoomAssessment:
        """Ask the doom detector to assess the steps recorded so far.

        Returns a "continue" assessment when no run is active.
        """
        if not self._current_trace:
            return DoomAssessment(False, 0.0, [], "continue", "no active trace")
        step_dicts = [s.__dict__ for s in self._current_trace.steps]
        # NOTE(review): the trailing 4 is a fixed positional argument whose
        # meaning is defined by DoomDetector.assess - confirm and name it.
        return self.doom_detector.assess(
            step_dicts,
            current_cost,
            self.config.routing_policy.max_cost_per_task,
            4,
        )

    def should_verify(self, is_irreversible: bool = False,
                      has_prior_failures: bool = False) -> VerifierDecision:
        """Ask the verifier budgeter whether a verifier call is warranted.

        Returns a "skip" decision when no run is active.
        """
        if not self._current_trace:
            return VerifierDecision(False, "skip", 0.0, "no active trace", 0.0)
        trace = self._current_trace
        # NOTE(review): "medium" and 0.8 are fixed here rather than derived
        # from the trace - confirm against VerifierBudgeter.should_verify.
        return self.verifier_budgeter.should_verify(
            trace.task_type, "medium", 0.8,
            is_irreversible, has_prior_failures,
            trace.predicted_tier,
        )

    def gate_tool(self, tool_name: str, args: Dict) -> ToolDecision:
        """Ask the tool gate whether *tool_name* should be called with *args*.

        Returns a "skip" decision when no run is active.
        """
        if not self._current_trace:
            return ToolDecision("skip", tool_name, 0.0, "no active trace", 0.0, 0.0)
        # NOTE(review): passes the current step number, the next step number
        # and a fixed 0.5; their meaning is defined by ToolCostGate.gate.
        return self.tool_gate.gate(
            tool_name, args, self._current_trace.task_type,
            self._step_num, self._step_num + 1, 0.5,
        )

    def cascade_step(self, request: str, initial_tier: int,
                     cheap_logprobs: List[float],
                     cheap_response: str,
                     strong_response: str = "",
                     task_type: Optional[str] = None) -> CascadeResult:
        """Execution-feedback cascade: use cheap model output to decide escalation.

        When *task_type* is not given it falls back to the active trace's task
        type, or to ``"unknown_ambiguous"`` if no run is active. The task floor
        is looked up in ``config.task_floors`` and defaults to 1.
        """
        if not task_type:
            if self._current_trace:
                task_type = self._current_trace.task_type
            else:
                task_type = "unknown_ambiguous"
        floor = self.config.task_floors.get(task_type, 1)
        return self.execution_feedback.cascade(
            request, initial_tier, cheap_logprobs,
            cheap_response, strong_response,
            task_type=task_type, task_floor=floor,
        )

    def analyze_output_confidence(self, token_logprobs: List[float],
                                  task_type: str = "unknown",
                                  current_tier: int = 2) -> FeedbackSignal:
        """Analyze model output confidence for routing decisions."""
        return self.execution_feedback.analyze_output(
            token_logprobs, task_type=task_type, current_tier=current_tier)

    def get_recovery(self, failure_tag: str, current_tier: int,
                     retry_num: int, previous_actions: Optional[List[str]] = None,
                     run_cost: float = 0.0) -> RecoveryAction:
        """Ask the retry optimizer for a recovery action after a failure.

        The configured ``max_cost_per_task`` is forwarded as the last argument.
        """
        return self.retry_optimizer.get_recovery(
            failure_tag, current_tier, retry_num,
            previous_actions, run_cost,
            self.config.routing_policy.max_cost_per_task)

    def end_run(self, success: bool, outcome: str = "completed",
                artifacts: Optional[List[str]] = None,
                failure_tags: Optional[List[str]] = None,
                user_correction: bool = False) -> Optional[AgentTrace]:
        """Finalize the active trace, store it in the history and return it.

        The trace's totals are copied from ``compute_summary()``. Returns
        ``None`` when no run is active. The active trace is cleared either way.
        """
        trace = self._current_trace
        if trace:
            trace.task_success = success
            trace.final_outcome = outcome
            trace.artifacts_created = artifacts or []
            trace.failure_tags = failure_tags or []
            trace.user_correction = user_correction

            summary = trace.compute_summary()
            for field in (
                "total_cost",
                "total_tokens",
                "total_tool_calls",
                "total_retries",
                "total_verifier_calls",
                "cache_hit_rate",
            ):
                setattr(trace, field, summary[field])
            self._traces.append(trace)

        self._current_trace = None
        return trace

    def layout_prompt(self, sources: Dict[str, str]) -> PromptLayout:
        """Lay out *sources* using a context budget for the current task.

        Without an active run, the budget is computed for task type
        ``"unknown_ambiguous"`` at difficulty 3.
        """
        # NOTE(review): the budget is recomputed with both flags False, while
        # start_run passes the classifier's needs_retrieval / needs_tools -
        # confirm this difference is intended.
        if self._current_trace:
            budget = self.context_budgeter.budget(
                self._current_trace.task_type,
                self._current_trace.difficulty,
                False, False)
        else:
            budget = self.context_budgeter.budget("unknown_ambiguous", 3, False, False)
        return self.cache_layout.layout(sources, budget)

    def get_stats(self) -> Dict:
        """Return aggregate statistics over finished runs plus sub-module stats.

        ``avg_cost`` is 0 when no run has finished yet.
        """
        run_count = len(self._traces)
        total_cost = sum(t.total_cost for t in self._traces)
        return {
            "total_runs": run_count,
            "successful_runs": sum(1 for t in self._traces if t.task_success),
            # max(..., 1) guards the division when there are no runs.
            "avg_cost": total_cost / max(run_count, 1),
            "cache_stats": self.cache_layout.stats(),
            "tool_stats": self.tool_gate.call_stats,
            "verifier_stats": self.verifier_budgeter.stats,
            "retry_stats": self.retry_optimizer.recovery_stats,
        }
|
|