# agent-cost-optimizer — aco/optimizer.py (Hugging Face upload by narcolepticchicken, revision dafb79c, 13.7 kB)
"""Agent Cost Optimizer - Main orchestrator."""
import uuid
import time
from typing import Dict, List, Optional, Any
from dataclasses import dataclass
from .config import ACOConfig
from .trace_schema import (
AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall,
TaskType, Outcome, FailureTag,
)
from .telemetry import CostTelemetryCollector
from .classifier import TaskCostClassifier, TaskPrediction
from .router import ModelCascadeRouter, RoutingDecision
from .context_budgeter import ContextBudgeter, ContextSource, ContextBudget
from .cache_layout import CacheAwarePromptLayout, PromptLayout
from .tool_gate import ToolUseCostGate, ToolGateDecision, ToolDecision
from .verifier_budgeter import VerifierBudgeter, VerifierBudgetDecision
from .retry_optimizer import RetryRecoveryOptimizer, RecoveryDecision, RecoveryAction
from .meta_tool_miner import MetaToolMiner
from .doom_detector import DoomDetector, DoomAssessment, DoomAction
@dataclass
class OptimizationResult:
    """Aggregated output of a single ``AgentCostOptimizer.optimize()`` call.

    Bundles every sub-module's decision (routing, context budgeting, cache
    layout, tool gating, verification, recovery, doom detection) together
    with the combined cost/latency estimates and a short reasoning summary.
    """
    trace_id: str                                        # id of the (possibly newly minted) trace this result belongs to
    routing_decision: RoutingDecision                    # chosen model / tier for this request
    context_budget: Optional[ContextBudget]              # None when the context budgeter is disabled
    prompt_layout: Optional[PromptLayout]                # None unless cache layout is enabled AND a context budget exists
    tool_decisions: List[ToolGateDecision]               # one entry per planned tool call (empty when gate disabled)
    verifier_decision: Optional[VerifierBudgetDecision]  # None when the verifier budgeter is disabled
    recovery_decision: Optional[RecoveryDecision]        # set only when run_state["in_recovery"] was truthy
    doom_assessment: Optional[DoomAssessment]            # set only for traces already registered as active
    meta_tool_match: Optional[Dict]                      # match/compression info from the meta-tool miner, if any
    estimated_cost: float                                # model + accepted-tool + verifier cost estimate
    estimated_latency_ms: float                          # taken from the classifier's latency prediction
    confidence: float                                    # mirrors routing_decision.confidence
    reasoning: str                                       # human-readable one-line summary of the decision
class AgentCostOptimizer:
    """Universal control layer for reducing agent run costs while preserving quality.

    Orchestrates the individual ACO modules (classifier, router, context
    budgeter, cache-aware prompt layout, tool gate, verifier budgeter, retry
    optimizer, meta-tool miner, doom detector) into a single per-request
    decision, and tracks in-flight runs through the telemetry collector.
    """

    def __init__(self, config: Optional[ACOConfig] = None):
        """Build the optimizer and all sub-modules; a default config is used when None."""
        self.config = config or ACOConfig()
        # Core modules — each receives the shared config object.
        self.telemetry = CostTelemetryCollector(self.config.trace_storage_path)
        self.classifier = TaskCostClassifier(self.config)
        self.router = ModelCascadeRouter(self.config)
        self.context_budgeter = ContextBudgeter(self.config)
        self.cache_layout = CacheAwarePromptLayout(self.config)
        self.tool_gate = ToolUseCostGate(self.config)
        self.verifier_budgeter = VerifierBudgeter(self.config)
        self.retry_optimizer = RetryRecoveryOptimizer(self.config)
        self.meta_tool_miner = MetaToolMiner(self.config)
        self.doom_detector = DoomDetector(self.config)
        # Runtime state: traces currently in flight, plus a per-trace step count
        # used by record_step() to mint sequential step ids.
        self.active_traces: Dict[str, AgentTrace] = {}
        self.step_counter: Dict[str, int] = {}

    def optimize(self, user_request: str, run_state: Optional[Dict] = None) -> OptimizationResult:
        """Main entry point: decide how to execute an agent request cost-effectively.

        Args:
            user_request: Raw user task text to classify and route.
            run_state: Optional state for an in-flight run. Keys read here:
                "trace_id", "routing_mode", "planned_tools" (list of
                (tool_name, tool_input) pairs), "current_cost",
                "previous_tool_calls", "prior_failures", "is_irreversible",
                "expected_output_tokens", "retrieved_docs", "step_number",
                "total_steps", "in_recovery", "failure_tags", "context_pieces".

        Returns:
            An OptimizationResult bundling every sub-module's decision plus
            aggregate cost/latency estimates.
        """
        run_state = run_state or {}
        # NOTE(review): uuid4() is evaluated eagerly even when "trace_id" is present.
        trace_id = run_state.get("trace_id", str(uuid.uuid4()))
        # 1. Classify the task, informed by historical traces.
        past_traces = self._get_past_traces()
        prediction = self.classifier.classify_with_history(user_request, past_traces)
        # 2. Route to model (cascade mode by default).
        routing_mode = run_state.get("routing_mode", "cascade")
        routing = self.router.route(prediction, routing_mode=routing_mode)
        # 3. Budget context
        context_budget = None
        if self.config.enable_context_budgeter:
            available_sources = self._build_context_sources(run_state)
            model_cfg = self.config.models.get(routing.model_id)
            # Fallbacks for unknown models: 128k context, $0.01 per 1k input tokens.
            max_ctx = model_cfg.max_context if model_cfg else 128000
            cost_1k = model_cfg.cost_per_1k_input if model_cfg else 0.01
            context_budget = self.context_budgeter.budget(
                prediction.task_type, available_sources, max_ctx, cost_1k
            )
        # 4. Optimize cache layout
        # NOTE(review): `model_cfg` and `cost_1k` are bound only in step 3; this
        # branch is reachable only when `context_budget` is truthy, which implies
        # step 3 ran — keep the two branches coupled when editing.
        prompt_layout = None
        if self.config.enable_cache_layout and context_budget:
            content_pieces = self._build_content_pieces(context_budget)
            cache_discount = 0.5  # default discount when the model config is missing
            if model_cfg:
                cache_discount = model_cfg.cache_discount_rate
            prompt_layout = self.cache_layout.layout(
                content_pieces, cost_1k, cache_discount
            )
        # 5. Check whether a mined meta-tool can compress the planned tool sequence.
        meta_tool_match = None
        if self.config.enable_meta_tool_miner:
            planned_tools = run_state.get("planned_tools", [])
            planned_tool_names = [t[0] for t in planned_tools]
            meta_tool_match = self.meta_tool_miner.match_and_compress(
                prediction.task_type, planned_tool_names
            )
        # 6. Gate each planned tool call individually.
        tool_decisions = []
        if self.config.enable_tool_gate:
            planned_tools = run_state.get("planned_tools", [])
            current_cost = run_state.get("current_cost", 0.0)
            for tool_name, tool_input in planned_tools:
                decision = self.tool_gate.decide(
                    tool_name, prediction.task_type, tool_input,
                    run_state.get("previous_tool_calls"), current_cost,
                    prediction.expected_cost,
                )
                tool_decisions.append(decision)
        # 7. Decide how much (if any) verification to spend on this step.
        verifier_decision = None
        if self.config.enable_verifier_budgeter:
            verifier_decision = self.verifier_budgeter.decide(
                task_type=prediction.task_type,
                model_tier_used=routing.tier,
                confidence=routing.confidence,
                has_prior_failures=bool(run_state.get("prior_failures")),
                is_irreversible=run_state.get("is_irreversible", False),
                output_length_tokens=run_state.get("expected_output_tokens", 1024),
                retrieval_evidence_count=len(run_state.get("retrieved_docs", [])),
                step_number=run_state.get("step_number", 1),
                total_steps=run_state.get("total_steps", 1),
                mode="risk_weighted",
            )
        # 8. Check for doom (runaway / hopeless runs) — only for known active traces.
        doom = None
        if self.config.enable_early_termination and trace_id in self.active_traces:
            trace = self.active_traces[trace_id]
            current_step = trace.steps[-1] if trace.steps else None
            # NOTE(review): 10000x expected cost appears to be a hard runaway-cost
            # ceiling passed to the detector — confirm against DoomDetector.assess.
            doom = self.doom_detector.assess(
                trace, current_step, prediction.expected_cost, prediction.expected_cost * 10000
            )
        # 9. Recovery (only when the caller flags the run as being in recovery mode).
        recovery = None
        if run_state.get("in_recovery") and trace_id in self.active_traces:
            trace = self.active_traces[trace_id]
            failure_tags = [FailureTag(f) for f in run_state.get("failure_tags", [])]
            current_step = trace.steps[-1] if trace.steps else None
            recovery = self.retry_optimizer.decide_recovery(
                prediction.task_type,
                current_step,
                failure_tags,
                trace.total_cost_computed,
                prediction.expected_cost,
                routing.tier,
                run_state.get("step_number", 1),
                trace.steps,
            )
        # Estimate cost: model + tools the gate decided to USE + verifier.
        est_model_cost = self._estimate_model_cost(routing, context_budget)
        est_tool_cost = sum(d.estimated_cost for d in tool_decisions if d.decision == ToolDecision.USE)
        est_verifier_cost = verifier_decision.estimated_verifier_cost if verifier_decision else 0.0
        estimated_cost = est_model_cost + est_tool_cost + est_verifier_cost
        estimated_latency = prediction.expected_latency_ms
        return OptimizationResult(
            trace_id=trace_id,
            routing_decision=routing,
            context_budget=context_budget,
            prompt_layout=prompt_layout,
            tool_decisions=tool_decisions,
            verifier_decision=verifier_decision,
            recovery_decision=recovery,
            doom_assessment=doom,
            meta_tool_match=meta_tool_match,
            estimated_cost=estimated_cost,
            estimated_latency_ms=estimated_latency,
            confidence=routing.confidence,
            reasoning=f"Task={prediction.task_type.value}, tier={routing.tier}, risk={prediction.risk_of_failure:.2f}",
        )

    def start_trace(self, trace_id: str, user_request: str, prediction: TaskPrediction) -> AgentTrace:
        """Open a telemetry trace and register it as active with a zeroed step counter."""
        trace = self.telemetry.start_trace(trace_id, user_request, prediction.task_type)
        self.active_traces[trace_id] = trace
        self.step_counter[trace_id] = 0
        return trace

    def record_step(
        self,
        trace_id: str,
        model_call: ModelCall,
        tool_calls: Optional[List[ToolCall]] = None,
        verifier_calls: Optional[List[VerifierCall]] = None,
        context_size_tokens: int = 0,
        step_outcome: Optional[Outcome] = None,
    ) -> None:
        """Append one executed step (model call plus optional tool/verifier calls) to a trace.

        Mints a sequential "<trace_id>_step_<n>" step id; the counter is created
        on first use, so recording against an unstarted trace id still works here
        (whether telemetry accepts it depends on CostTelemetryCollector.add_step).
        """
        self.step_counter[trace_id] = self.step_counter.get(trace_id, 0) + 1
        step_id = f"{trace_id}_step_{self.step_counter[trace_id]}"
        self.telemetry.add_step(
            trace_id=trace_id,
            step_id=step_id,
            model_call=model_call,
            tool_calls=tool_calls or [],
            verifier_calls=verifier_calls or [],
            context_size_tokens=context_size_tokens,
            step_outcome=step_outcome,
        )

    def finalize_trace(
        self,
        trace_id: str,
        outcome: Outcome,
        failure_tags: Optional[List[FailureTag]] = None,
        user_satisfaction: Optional[float] = None,
    ) -> Optional[AgentTrace]:
        """Close an active trace; returns the finalized trace, or None for unknown ids.

        Successful (or partially successful) traces are fed to the meta-tool
        miner before telemetry finalization.

        NOTE(review): the trace's step_counter entry is never removed, so the
        counter dict grows without bound over many traces — minor leak.
        """
        trace = self.active_traces.pop(trace_id, None)
        if trace:
            # Mine for meta-tools if successful
            if outcome in (Outcome.SUCCESS, Outcome.PARTIAL_SUCCESS):
                self.meta_tool_miner.ingest_trace(trace)
            return self.telemetry.finalize_trace(
                trace_id=trace_id,
                final_outcome=outcome,
                failure_tags=failure_tags,
                user_satisfaction=user_satisfaction,
            )
        return None

    def compute_cost_adjusted_score(
        self,
        trace: AgentTrace,
        success_score: float = 1.0,
        safety_bonus: float = 0.0,
        artifact_bonus: float = 0.0,
        calibration_bonus: float = 0.0,
    ) -> float:
        """Compute the cost-adjusted quality score.

        Score = bonuses minus config-weighted penalties for model cost, tool
        and verifier call counts, latency (per second), and retries, with
        additional fixed penalties for critical failure tags and FALSE_DONE
        outcomes. Higher is better; may go negative.
        """
        cost = trace.total_cost_computed
        retries = trace.total_retries
        tools = trace.total_tool_calls
        verifiers = trace.total_verifier_calls
        score = (
            success_score
            + safety_bonus
            + artifact_bonus
            + calibration_bonus
            - self.config.model_cost_weight * cost
            - self.config.tool_cost_weight * tools * 0.001
            - self.config.verifier_cost_weight * verifiers * 0.001
            - self.config.latency_weight * trace.total_latency_ms / 1000
            - self.config.retry_penalty_weight * retries * 0.01
        )
        # Penalize critical failures
        if FailureTag.UNSAFE_CHEAP_MODEL in trace.failure_tags:
            score -= self.config.unsafe_cheap_model_penalty
        if FailureTag.MISSED_ESCALATION in trace.failure_tags:
            score -= self.config.missed_escalation_penalty
        if trace.final_outcome == Outcome.FALSE_DONE:
            score -= self.config.false_done_penalty
        return score

    def _get_past_traces(self) -> List[Dict]:
        """Get historical traces as dicts (skips entries telemetry cannot load as dicts)."""
        traces = []
        for tid in self.telemetry.list_traces():
            t = self.telemetry.load_trace(tid)
            if t and isinstance(t, dict):
                traces.append(t)
        return traces

    def _build_context_sources(self, run_state: Dict) -> List[ContextSource]:
        """Build context source objects from run_state["context_pieces"] ({name: text}).

        Token counts use a rough 4-characters-per-token heuristic; importance
        and staleness are fixed defaults. Mutability/cacheability are inferred
        purely from well-known piece names.
        """
        sources = []
        for name, content in run_state.get("context_pieces", {}).items():
            sources.append(ContextSource(
                name=name,
                tokens=len(content) // 4,  # ~4 chars per token heuristic
                importance=0.5,
                staleness=0.0,
                mutable=name in ["user_message", "retrieved_docs", "recent_trace"],
                cacheable=name in ["system_rules", "tool_descriptions", "user_preferences"],
            ))
        return sources

    def _build_content_pieces(self, context_budget: ContextBudget) -> Dict[str, str]:
        """Build content pieces dict from a context budget.

        Fully allocated sources get a "[name]" placeholder; summarized sources
        get their summary text (overwriting the placeholder on name collision).
        """
        pieces = {}
        for src in context_budget.allocated_sources:
            pieces[src.name] = f"[{src.name}]"
        for src, summary in context_budget.summarized_sources:
            pieces[src.name] = summary
        return pieces

    def _estimate_model_cost(self, routing: RoutingDecision, context_budget: Optional[ContextBudget]) -> float:
        """Estimate input+output model cost in dollars for the routed model.

        Falls back to a flat $0.01 when the model config is unknown, and to
        4096 input tokens when no context budget was computed.
        """
        model_cfg = self.config.models.get(routing.model_id)
        if not model_cfg:
            return 0.01
        input_tokens = context_budget.total_budget_tokens if context_budget else 4096
        output_tokens = routing.max_tokens
        input_cost = (input_tokens / 1000) * model_cfg.cost_per_1k_input
        output_cost = (output_tokens / 1000) * model_cfg.cost_per_1k_output
        return input_cost + output_cost

    def get_stats(self) -> Dict[str, Any]:
        """Get optimizer-wide statistics from the stat-reporting sub-modules."""
        return {
            "telemetry": self.telemetry.get_stats(),
            "cache": self.cache_layout.report(),
            "meta_tools": self.meta_tool_miner.get_stats(),
            "doom": self.doom_detector.get_stats(),
        }

    @classmethod
    def from_config(cls, path: str) -> "AgentCostOptimizer":
        """Alternate constructor: build an optimizer from a YAML config file at *path*."""
        config = ACOConfig.from_yaml(path)
        return cls(config)