Upload aco/optimizer.py with huggingface_hub
Browse files- aco/optimizer.py +164 -316
aco/optimizer.py
CHANGED
|
@@ -1,326 +1,174 @@
|
|
| 1 |
-
"""
|
| 2 |
-
|
| 3 |
-
import uuid
|
| 4 |
-
import time
|
| 5 |
from typing import Dict, List, Optional, Any
|
| 6 |
-
from
|
| 7 |
-
|
| 8 |
-
from .
|
| 9 |
-
from .trace_schema import (
|
| 10 |
-
AgentTrace, TraceStep, ModelCall, ToolCall, VerifierCall,
|
| 11 |
-
TaskType, Outcome, FailureTag,
|
| 12 |
-
)
|
| 13 |
-
from .telemetry import CostTelemetryCollector
|
| 14 |
-
from .classifier import TaskCostClassifier, TaskPrediction
|
| 15 |
from .router import ModelCascadeRouter, RoutingDecision
|
| 16 |
-
from .context_budgeter import ContextBudgeter,
|
| 17 |
-
from .cache_layout import
|
| 18 |
-
from .tool_gate import
|
| 19 |
-
from .verifier_budgeter import VerifierBudgeter,
|
| 20 |
-
from .retry_optimizer import
|
| 21 |
-
from .meta_tool_miner import MetaToolMiner
|
| 22 |
-
from .doom_detector import DoomDetector, DoomAssessment
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
class OptimizationResult:
|
| 27 |
-
trace_id: str
|
| 28 |
-
routing_decision: RoutingDecision
|
| 29 |
-
context_budget: Optional[ContextBudget]
|
| 30 |
-
prompt_layout: Optional[PromptLayout]
|
| 31 |
-
tool_decisions: List[ToolGateDecision]
|
| 32 |
-
verifier_decision: Optional[VerifierBudgetDecision]
|
| 33 |
-
recovery_decision: Optional[RecoveryDecision]
|
| 34 |
-
doom_assessment: Optional[DoomAssessment]
|
| 35 |
-
meta_tool_match: Optional[Dict]
|
| 36 |
-
estimated_cost: float
|
| 37 |
-
estimated_latency_ms: float
|
| 38 |
-
confidence: float
|
| 39 |
-
reasoning: str
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class AgentCostOptimizer:
|
| 43 |
-
"""Universal control layer for reducing agent run costs while preserving quality."""
|
| 44 |
-
|
| 45 |
-
def __init__(self, config: Optional[ACOConfig] = None):
|
| 46 |
self.config = config or ACOConfig()
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
self.tool_gate = ToolUseCostGate(self.config)
|
| 55 |
-
self.verifier_budgeter = VerifierBudgeter(self.config)
|
| 56 |
-
self.retry_optimizer = RetryRecoveryOptimizer(self.config)
|
| 57 |
-
self.meta_tool_miner = MetaToolMiner(self.config)
|
| 58 |
-
self.doom_detector = DoomDetector(self.config)
|
| 59 |
-
|
| 60 |
-
# Runtime state
|
| 61 |
-
self.active_traces: Dict[str, AgentTrace] = {}
|
| 62 |
-
self.step_counter: Dict[str, int] = {}
|
| 63 |
-
|
| 64 |
-
def optimize(self, user_request: str, run_state: Optional[Dict] = None) -> OptimizationResult:
|
| 65 |
-
"""Main entry point: decide how to execute an agent request cost-effectively."""
|
| 66 |
-
|
| 67 |
-
run_state = run_state or {}
|
| 68 |
-
trace_id = run_state.get("trace_id", str(uuid.uuid4()))
|
| 69 |
-
|
| 70 |
-
# 1. Classify the task
|
| 71 |
-
past_traces = self._get_past_traces()
|
| 72 |
-
prediction = self.classifier.classify_with_history(user_request, past_traces)
|
| 73 |
-
|
| 74 |
-
# 2. Route to model
|
| 75 |
-
routing_mode = run_state.get("routing_mode", "cascade")
|
| 76 |
-
routing = self.router.route(prediction, routing_mode=routing_mode)
|
| 77 |
-
|
| 78 |
-
# 3. Budget context
|
| 79 |
-
context_budget = None
|
| 80 |
-
if self.config.enable_context_budgeter:
|
| 81 |
-
available_sources = self._build_context_sources(run_state)
|
| 82 |
-
model_cfg = self.config.models.get(routing.model_id)
|
| 83 |
-
max_ctx = model_cfg.max_context if model_cfg else 128000
|
| 84 |
-
cost_1k = model_cfg.cost_per_1k_input if model_cfg else 0.01
|
| 85 |
-
context_budget = self.context_budgeter.budget(
|
| 86 |
-
prediction.task_type, available_sources, max_ctx, cost_1k
|
| 87 |
-
)
|
| 88 |
-
|
| 89 |
-
# 4. Optimize cache layout
|
| 90 |
-
prompt_layout = None
|
| 91 |
-
if self.config.enable_cache_layout and context_budget:
|
| 92 |
-
content_pieces = self._build_content_pieces(context_budget)
|
| 93 |
-
cache_discount = 0.5
|
| 94 |
-
if model_cfg:
|
| 95 |
-
cache_discount = model_cfg.cache_discount_rate
|
| 96 |
-
prompt_layout = self.cache_layout.layout(
|
| 97 |
-
content_pieces, cost_1k, cache_discount
|
| 98 |
-
)
|
| 99 |
-
|
| 100 |
-
# 5. Check for meta-tool
|
| 101 |
-
meta_tool_match = None
|
| 102 |
-
if self.config.enable_meta_tool_miner:
|
| 103 |
-
planned_tools = run_state.get("planned_tools", [])
|
| 104 |
-
planned_tool_names = [t[0] for t in planned_tools]
|
| 105 |
-
meta_tool_match = self.meta_tool_miner.match_and_compress(
|
| 106 |
-
prediction.task_type, planned_tool_names
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
# 6. Gate tool calls
|
| 110 |
-
tool_decisions = []
|
| 111 |
-
if self.config.enable_tool_gate:
|
| 112 |
-
planned_tools = run_state.get("planned_tools", [])
|
| 113 |
-
current_cost = run_state.get("current_cost", 0.0)
|
| 114 |
-
for tool_name, tool_input in planned_tools:
|
| 115 |
-
decision = self.tool_gate.decide(
|
| 116 |
-
tool_name, prediction.task_type, tool_input,
|
| 117 |
-
run_state.get("previous_tool_calls"), current_cost,
|
| 118 |
-
prediction.expected_cost,
|
| 119 |
-
)
|
| 120 |
-
tool_decisions.append(decision)
|
| 121 |
-
|
| 122 |
-
# 7. Decide verifier
|
| 123 |
-
verifier_decision = None
|
| 124 |
-
if self.config.enable_verifier_budgeter:
|
| 125 |
-
verifier_decision = self.verifier_budgeter.decide(
|
| 126 |
-
task_type=prediction.task_type,
|
| 127 |
-
model_tier_used=routing.tier,
|
| 128 |
-
confidence=routing.confidence,
|
| 129 |
-
has_prior_failures=bool(run_state.get("prior_failures")),
|
| 130 |
-
is_irreversible=run_state.get("is_irreversible", False),
|
| 131 |
-
output_length_tokens=run_state.get("expected_output_tokens", 1024),
|
| 132 |
-
retrieval_evidence_count=len(run_state.get("retrieved_docs", [])),
|
| 133 |
-
step_number=run_state.get("step_number", 1),
|
| 134 |
-
total_steps=run_state.get("total_steps", 1),
|
| 135 |
-
mode="risk_weighted",
|
| 136 |
-
)
|
| 137 |
-
|
| 138 |
-
# 8. Check for doom
|
| 139 |
-
doom = None
|
| 140 |
-
if self.config.enable_early_termination and trace_id in self.active_traces:
|
| 141 |
-
trace = self.active_traces[trace_id]
|
| 142 |
-
current_step = trace.steps[-1] if trace.steps else None
|
| 143 |
-
doom = self.doom_detector.assess(
|
| 144 |
-
trace, current_step, prediction.expected_cost, prediction.expected_cost * 10000
|
| 145 |
-
)
|
| 146 |
-
|
| 147 |
-
# 9. Recovery (if in recovery mode)
|
| 148 |
-
recovery = None
|
| 149 |
-
if run_state.get("in_recovery") and trace_id in self.active_traces:
|
| 150 |
-
trace = self.active_traces[trace_id]
|
| 151 |
-
failure_tags = [FailureTag(f) for f in run_state.get("failure_tags", [])]
|
| 152 |
-
current_step = trace.steps[-1] if trace.steps else None
|
| 153 |
-
recovery = self.retry_optimizer.decide_recovery(
|
| 154 |
-
prediction.task_type,
|
| 155 |
-
current_step,
|
| 156 |
-
failure_tags,
|
| 157 |
-
trace.total_cost_computed,
|
| 158 |
-
prediction.expected_cost,
|
| 159 |
-
routing.tier,
|
| 160 |
-
run_state.get("step_number", 1),
|
| 161 |
-
trace.steps,
|
| 162 |
-
)
|
| 163 |
-
|
| 164 |
-
# Estimate cost
|
| 165 |
-
est_model_cost = self._estimate_model_cost(routing, context_budget)
|
| 166 |
-
est_tool_cost = sum(d.estimated_cost for d in tool_decisions if d.decision == ToolDecision.USE)
|
| 167 |
-
est_verifier_cost = verifier_decision.estimated_verifier_cost if verifier_decision else 0.0
|
| 168 |
-
|
| 169 |
-
estimated_cost = est_model_cost + est_tool_cost + est_verifier_cost
|
| 170 |
-
estimated_latency = prediction.expected_latency_ms
|
| 171 |
-
|
| 172 |
-
return OptimizationResult(
|
| 173 |
-
trace_id=trace_id,
|
| 174 |
-
routing_decision=routing,
|
| 175 |
-
context_budget=context_budget,
|
| 176 |
-
prompt_layout=prompt_layout,
|
| 177 |
-
tool_decisions=tool_decisions,
|
| 178 |
-
verifier_decision=verifier_decision,
|
| 179 |
-
recovery_decision=recovery,
|
| 180 |
-
doom_assessment=doom,
|
| 181 |
-
meta_tool_match=meta_tool_match,
|
| 182 |
-
estimated_cost=estimated_cost,
|
| 183 |
-
estimated_latency_ms=estimated_latency,
|
| 184 |
-
confidence=routing.confidence,
|
| 185 |
-
reasoning=f"Task={prediction.task_type.value}, tier={routing.tier}, risk={prediction.risk_of_failure:.2f}",
|
| 186 |
)
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
self.
|
| 191 |
-
self.
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
def record_step(
|
| 195 |
-
self,
|
| 196 |
-
trace_id: str,
|
| 197 |
-
model_call: ModelCall,
|
| 198 |
-
tool_calls: Optional[List[ToolCall]] = None,
|
| 199 |
-
verifier_calls: Optional[List[VerifierCall]] = None,
|
| 200 |
-
context_size_tokens: int = 0,
|
| 201 |
-
step_outcome: Optional[Outcome] = None,
|
| 202 |
-
) -> None:
|
| 203 |
-
self.step_counter[trace_id] = self.step_counter.get(trace_id, 0) + 1
|
| 204 |
-
step_id = f"{trace_id}_step_{self.step_counter[trace_id]}"
|
| 205 |
-
self.telemetry.add_step(
|
| 206 |
-
trace_id=trace_id,
|
| 207 |
-
step_id=step_id,
|
| 208 |
-
model_call=model_call,
|
| 209 |
-
tool_calls=tool_calls or [],
|
| 210 |
-
verifier_calls=verifier_calls or [],
|
| 211 |
-
context_size_tokens=context_size_tokens,
|
| 212 |
-
step_outcome=step_outcome,
|
| 213 |
)
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
self
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
|
| 226 |
-
self.meta_tool_miner.ingest_trace(trace)
|
| 227 |
-
|
| 228 |
-
return self.telemetry.finalize_trace(
|
| 229 |
-
trace_id=trace_id,
|
| 230 |
-
final_outcome=outcome,
|
| 231 |
-
failure_tags=failure_tags,
|
| 232 |
-
user_satisfaction=user_satisfaction,
|
| 233 |
-
)
|
| 234 |
-
return None
|
| 235 |
-
|
| 236 |
-
def compute_cost_adjusted_score(
|
| 237 |
-
self,
|
| 238 |
-
trace: AgentTrace,
|
| 239 |
-
success_score: float = 1.0,
|
| 240 |
-
safety_bonus: float = 0.0,
|
| 241 |
-
artifact_bonus: float = 0.0,
|
| 242 |
-
calibration_bonus: float = 0.0,
|
| 243 |
-
) -> float:
|
| 244 |
-
"""Compute the cost-adjusted quality score."""
|
| 245 |
-
cost = trace.total_cost_computed
|
| 246 |
-
retries = trace.total_retries
|
| 247 |
-
tools = trace.total_tool_calls
|
| 248 |
-
verifiers = trace.total_verifier_calls
|
| 249 |
-
|
| 250 |
-
score = (
|
| 251 |
-
success_score
|
| 252 |
-
+ safety_bonus
|
| 253 |
-
+ artifact_bonus
|
| 254 |
-
+ calibration_bonus
|
| 255 |
-
- self.config.model_cost_weight * cost
|
| 256 |
-
- self.config.tool_cost_weight * tools * 0.001
|
| 257 |
-
- self.config.verifier_cost_weight * verifiers * 0.001
|
| 258 |
-
- self.config.latency_weight * trace.total_latency_ms / 1000
|
| 259 |
-
- self.config.retry_penalty_weight * retries * 0.01
|
| 260 |
)
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
def _build_content_pieces(self, context_budget: ContextBudget) -> Dict[str, str]:
|
| 296 |
-
"""Build content pieces dict from context budget."""
|
| 297 |
-
pieces = {}
|
| 298 |
-
for src in context_budget.allocated_sources:
|
| 299 |
-
pieces[src.name] = f"[{src.name}]"
|
| 300 |
-
for src, summary in context_budget.summarized_sources:
|
| 301 |
-
pieces[src.name] = summary
|
| 302 |
-
return pieces
|
| 303 |
|
| 304 |
-
def
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
|
| 311 |
-
|
| 312 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
|
| 314 |
-
def
|
| 315 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 316 |
return {
|
| 317 |
-
"
|
| 318 |
-
"
|
| 319 |
-
"
|
| 320 |
-
"
|
|
|
|
|
|
|
|
|
|
| 321 |
}
|
| 322 |
-
|
| 323 |
-
@classmethod
|
| 324 |
-
def from_config(cls, path: str) -> "AgentCostOptimizer":
|
| 325 |
-
config = ACOConfig.from_yaml(path)
|
| 326 |
-
return cls(config)
|
|
|
|
| 1 |
+
"""ACO Optimizer: Main orchestrator that coordinates all modules."""
|
| 2 |
+
import json, time, uuid
|
|
|
|
|
|
|
| 3 |
from typing import Dict, List, Optional, Any
|
| 4 |
+
from .config import ACOConfig, RoutingPolicy
|
| 5 |
+
from .trace_schema import AgentTrace, TraceStep, ModelCall, ToolCall
|
| 6 |
+
from .classifier import TaskCostClassifier
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
from .router import ModelCascadeRouter, RoutingDecision
|
| 8 |
+
from .context_budgeter import ContextBudgeter, ContextBudget
|
| 9 |
+
from .cache_layout import CacheAwareLayout, PromptLayout
|
| 10 |
+
from .tool_gate import ToolCostGate, ToolDecision
|
| 11 |
+
from .verifier_budgeter import VerifierBudgeter, VerifierDecision
|
| 12 |
+
from .retry_optimizer import RetryOptimizer, RecoveryAction
|
| 13 |
+
from .meta_tool_miner import MetaToolMiner, MacroTool
|
| 14 |
+
from .doom_detector import DoomDetector, DoomAssessment
|
| 15 |
+
|
| 16 |
+
class ACOOptimizer:
|
| 17 |
+
def __init__(self, config: ACOConfig = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
self.config = config or ACOConfig()
|
| 19 |
+
self.classifier = TaskCostClassifier()
|
| 20 |
+
self.router = ModelCascadeRouter(
|
| 21 |
+
model_path=self.config.router_model_path,
|
| 22 |
+
safety_threshold=self.config.routing_policy.safety_threshold,
|
| 23 |
+
downgrade_threshold=self.config.routing_policy.downgrade_threshold,
|
| 24 |
+
task_floor=self.config.task_floors,
|
| 25 |
+
tier_costs=self.config.tier_costs,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
)
|
| 27 |
+
self.context_budgeter = ContextBudgeter()
|
| 28 |
+
self.cache_layout = CacheAwareLayout()
|
| 29 |
+
self.tool_gate = ToolCostGate()
|
| 30 |
+
self.verifier_budgeter = VerifierBudgeter()
|
| 31 |
+
self.retry_optimizer = RetryOptimizer(
|
| 32 |
+
max_retries=self.config.routing_policy.max_retries,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
)
|
| 34 |
+
self.meta_tool_miner = MetaToolMiner()
|
| 35 |
+
self.doom_detector = DoomDetector()
|
| 36 |
+
self._current_trace: Optional[AgentTrace] = None
|
| 37 |
+
self._step_num = 0
|
| 38 |
+
self._traces: List[AgentTrace] = []
|
| 39 |
+
|
| 40 |
+
def start_run(self, request: str) -> Dict:
|
| 41 |
+
prediction = self.classifier.classify(request)
|
| 42 |
+
routing = self.router.route(request, prediction["task_type"], prediction["difficulty"], prediction)
|
| 43 |
+
context_budget = self.context_budgeter.budget(
|
| 44 |
+
prediction["task_type"], prediction["difficulty"],
|
| 45 |
+
prediction["needs_retrieval"], prediction["needs_tools"],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
)
|
| 47 |
+
# Check for meta-tool match
|
| 48 |
+
macro = self.meta_tool_miner.match_macro(request, prediction["task_type"]) if self.config.enable_meta_tools else None
|
| 49 |
+
self._current_trace = AgentTrace(
|
| 50 |
+
request=request,
|
| 51 |
+
task_type=prediction["task_type"],
|
| 52 |
+
difficulty=prediction["difficulty"],
|
| 53 |
+
predicted_tier=routing.tier,
|
| 54 |
+
)
|
| 55 |
+
self._step_num = 0
|
| 56 |
+
self.retry_optimizer.reset_run()
|
| 57 |
+
self.verifier_budgeter.reset_run()
|
| 58 |
+
return {
|
| 59 |
+
"trace_id": self._current_trace.trace_id,
|
| 60 |
+
"prediction": prediction,
|
| 61 |
+
"routing": {
|
| 62 |
+
"model_id": routing.model_id,
|
| 63 |
+
"tier": routing.tier,
|
| 64 |
+
"confidence": routing.confidence,
|
| 65 |
+
"cost_estimate": routing.cost_estimate,
|
| 66 |
+
"dynamic_difficulty": routing.dynamic_difficulty,
|
| 67 |
+
"escalated": routing.escalated,
|
| 68 |
+
"downgraded": routing.downgraded,
|
| 69 |
+
"reasoning": routing.reasoning,
|
| 70 |
+
},
|
| 71 |
+
"context_budget": {
|
| 72 |
+
"total_tokens": context_budget.total_tokens,
|
| 73 |
+
"keep_exact": context_budget.keep_exact,
|
| 74 |
+
"summarize": context_budget.summarize,
|
| 75 |
+
"omit": context_budget.omit,
|
| 76 |
+
"retrieve_on_demand": context_budget.retrieve_on_demand,
|
| 77 |
+
"cache_prefix": context_budget.cache_prefix,
|
| 78 |
+
},
|
| 79 |
+
"macro_tool": macro.name if macro else None,
|
| 80 |
+
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
+
def record_step(self, model_call: Dict = None, tool_calls: List[Dict] = None,
|
| 83 |
+
context_size: int = 0, verifier_called: bool = False,
|
| 84 |
+
verifier_result: str = None, retry_num: int = 0,
|
| 85 |
+
recovery_action: str = None) -> None:
|
| 86 |
+
self._step_num += 1
|
| 87 |
+
mc = None
|
| 88 |
+
if model_call:
|
| 89 |
+
mc = ModelCall(**model_call)
|
| 90 |
+
tcs = [ToolCall(**tc) for tc in (tool_calls or [])]
|
| 91 |
+
step = TraceStep(
|
| 92 |
+
step_num=self._step_num,
|
| 93 |
+
model_call=mc,
|
| 94 |
+
tool_calls=tcs,
|
| 95 |
+
context_size=context_size,
|
| 96 |
+
verifier_called=verifier_called,
|
| 97 |
+
verifier_result=verifier_result,
|
| 98 |
+
retry_num=retry_num,
|
| 99 |
+
recovery_action=recovery_action,
|
| 100 |
+
)
|
| 101 |
+
if self._current_trace:
|
| 102 |
+
self._current_trace.steps.append(step)
|
| 103 |
+
|
| 104 |
+
def check_doom(self, current_cost: float = 0.0) -> DoomAssessment:
|
| 105 |
+
if not self._current_trace:
|
| 106 |
+
return DoomAssessment(False, 0.0, [], "continue", "no active trace")
|
| 107 |
+
return self.doom_detector.assess(
|
| 108 |
+
[s.__dict__ for s in self._current_trace.steps],
|
| 109 |
+
current_cost, self.config.routing_policy.max_cost_per_task, 4)
|
| 110 |
+
|
| 111 |
+
def should_verify(self, is_irreversible: bool = False,
|
| 112 |
+
has_prior_failures: bool = False) -> VerifierDecision:
|
| 113 |
+
if not self._current_trace:
|
| 114 |
+
return VerifierDecision(False, "skip", 0.0, "no active trace", 0.0)
|
| 115 |
+
return self.verifier_budgeter.should_verify(
|
| 116 |
+
self._current_trace.task_type, "medium", 0.8,
|
| 117 |
+
is_irreversible, has_prior_failures,
|
| 118 |
+
self._current_trace.predicted_tier)
|
| 119 |
+
|
| 120 |
+
def gate_tool(self, tool_name: str, args: Dict) -> ToolDecision:
|
| 121 |
+
if not self._current_trace:
|
| 122 |
+
return ToolDecision("skip", tool_name, 0.0, "no active trace", 0.0, 0.0)
|
| 123 |
+
return self.tool_gate.gate(tool_name, args, self._current_trace.task_type,
|
| 124 |
+
self._step_num, self._step_num + 1, 0.5)
|
| 125 |
+
|
| 126 |
+
def get_recovery(self, failure_tag: str, current_tier: int,
|
| 127 |
+
retry_num: int, previous_actions: List[str] = None,
|
| 128 |
+
run_cost: float = 0.0) -> RecoveryAction:
|
| 129 |
+
return self.retry_optimizer.get_recovery(
|
| 130 |
+
failure_tag, current_tier, retry_num,
|
| 131 |
+
previous_actions, run_cost,
|
| 132 |
+
self.config.routing_policy.max_cost_per_task)
|
| 133 |
+
|
| 134 |
+
def end_run(self, success: bool, outcome: str = "completed",
|
| 135 |
+
artifacts: List[str] = None, failure_tags: List[str] = None,
|
| 136 |
+
user_correction: bool = False) -> AgentTrace:
|
| 137 |
+
if self._current_trace:
|
| 138 |
+
self._current_trace.task_success = success
|
| 139 |
+
self._current_trace.final_outcome = outcome
|
| 140 |
+
self._current_trace.artifacts_created = artifacts or []
|
| 141 |
+
self._current_trace.failure_tags = failure_tags or []
|
| 142 |
+
self._current_trace.user_correction = user_correction
|
| 143 |
+
summary = self._current_trace.compute_summary()
|
| 144 |
+
self._current_trace.total_cost = summary["total_cost"]
|
| 145 |
+
self._current_trace.total_tokens = summary["total_tokens"]
|
| 146 |
+
self._current_trace.total_tool_calls = summary["total_tool_calls"]
|
| 147 |
+
self._current_trace.total_retries = summary["total_retries"]
|
| 148 |
+
self._current_trace.total_verifier_calls = summary["total_verifier_calls"]
|
| 149 |
+
self._current_trace.cache_hit_rate = summary["cache_hit_rate"]
|
| 150 |
+
self._traces.append(self._current_trace)
|
| 151 |
+
trace = self._current_trace
|
| 152 |
+
self._current_trace = None
|
| 153 |
+
return trace
|
| 154 |
|
| 155 |
+
def layout_prompt(self, sources: Dict[str, str]) -> PromptLayout:
|
| 156 |
+
if not self._current_trace:
|
| 157 |
+
budget = self.context_budgeter.budget("unknown_ambiguous", 3, False, False)
|
| 158 |
+
else:
|
| 159 |
+
budget = self.context_budgeter.budget(
|
| 160 |
+
self._current_trace.task_type,
|
| 161 |
+
self._current_trace.difficulty,
|
| 162 |
+
False, False)
|
| 163 |
+
return self.cache_layout.layout(sources, budget)
|
| 164 |
+
|
| 165 |
+
def get_stats(self) -> Dict:
|
| 166 |
return {
|
| 167 |
+
"total_runs": len(self._traces),
|
| 168 |
+
"successful_runs": sum(1 for t in self._traces if t.task_success),
|
| 169 |
+
"avg_cost": sum(t.total_cost for t in self._traces) / max(len(self._traces),1),
|
| 170 |
+
"cache_stats": self.cache_layout.stats(),
|
| 171 |
+
"tool_stats": self.tool_gate.call_stats,
|
| 172 |
+
"verifier_stats": self.verifier_budgeter.stats,
|
| 173 |
+
"retry_stats": self.retry_optimizer.recovery_stats,
|
| 174 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|