| """Tool-Use Cost Gate: Predict whether a tool call is worth the cost.""" |
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
|
|
@dataclass
class ToolDecision:
    """Outcome of evaluating a single proposed tool call.

    Instances are built positionally, so the field order below is part of
    the interface.
    """

    # What to do with the call; the gate in this file emits "use", "skip"
    # and "cache".
    action: str
    # Name of the tool the decision applies to.
    tool_name: str
    # Score attached to the decision by whoever produced it.
    confidence: float
    # Short human-readable explanation of the decision.
    reasoning: str
    # Estimated cost of performing the call.
    estimated_cost: float
    # Estimated value of performing the call.
    estimated_value: float
    # Suggested fallback when the tool is not invoked; None when there is none.
    alternative: Optional[str] = None
|
|
# Per-tool estimates, keyed by tool name. Each entry holds:
#   "cost"       - estimated cost of one call (currency/unit not stated here;
#                  presumably USD - TODO confirm)
#   "latency_ms" - estimated latency of one call, in milliseconds
#   "value_base" - baseline value score before any situational adjustment
TOOL_COST_ESTIMATES: Dict[str, Dict[str, float]] = {
    name: {"cost": cost, "latency_ms": latency_ms, "value_base": value_base}
    for name, cost, latency_ms, value_base in (
        # (tool name,      cost,  latency_ms, value_base)
        ("web_search",     0.01,  2000, 0.6),
        ("code_search",    0.005, 1000, 0.7),
        ("file_read",      0.001, 100,  0.8),
        ("file_write",     0.001, 100,  0.9),
        ("api_call",       0.05,  3000, 0.5),
        ("database_query", 0.02,  500,  0.6),
        ("code_execute",   0.01,  5000, 0.7),
        ("web_scrape",     0.02,  5000, 0.4),
        ("summarize",      0.005, 2000, 0.5),
        ("verify",         0.02,  3000, 0.6),
    )
}
|
|
class ToolCostGate:
    """Decide whether a proposed tool call should run, be skipped, or be served from cache.

    The gate combines a per-tool baseline estimate with a few situational
    adjustments (task type, step number, amount of prior results, caller
    confidence) and compares the outcome against configurable thresholds.
    Results recorded via ``record_result`` are remembered so that a repeat of
    the same call can be answered from cache.
    """

    def __init__(self, cost_threshold: float = 0.05, value_threshold: float = 0.3,
                 batch_window_ms: int = 5000):
        """Create a gate.

        Args:
            cost_threshold: Calls costing more than this are skipped unless
                their adjusted value is at least 0.6.
            value_threshold: Calls whose adjusted value falls below this are
                skipped.
            batch_window_ms: Stored on the instance; not read by any method
                of this class.
        """
        self.cost_threshold = cost_threshold
        self.value_threshold = value_threshold
        self.batch_window_ms = batch_window_ms
        self.pending_calls: List[Dict] = []
        # Only "used", "skipped" and "cached" are updated by this class.
        self.call_stats = {"used": 0, "skipped": 0, "batched": 0, "cached": 0, "escalated": 0}
        self._result_cache: Dict[str, str] = {}

    @staticmethod
    def _cache_key(tool_name: str, args: Dict) -> str:
        """Build a cache key that is independent of argument insertion order."""
        try:
            # sort_keys makes {"a": 1, "b": 2} and {"b": 2, "a": 1} identical;
            # embedding the full text (not a hash of it) rules out collisions.
            canonical = json.dumps(args, sort_keys=True, default=repr)
        except (TypeError, ValueError):
            # Keys that JSON cannot represent or sort (mixed types, tuples)
            # or circular structures: fall back to the plain repr.
            canonical = repr(args)
        return f"{tool_name}:{canonical}"

    def gate(self, tool_name: str, args: Dict, task_type: str,
             step_num: int, total_steps: int, confidence: float,
             prior_results: Optional[List[str]] = None) -> "ToolDecision":
        """Evaluate one proposed tool call and update ``call_stats``.

        Args:
            tool_name: Tool to evaluate; unknown tools get a default estimate.
            args: Arguments of the call, used only to look up cached results.
            task_type: "coding" and "research" raise the value of matching tools.
            step_num: Search tools get a bonus when this is 1.
            total_steps: Accepted for interface compatibility; not used.
            confidence: Caller's current confidence; above 0.8 the tool's
                value is reduced.
            prior_results: Results gathered so far; more than three reduce
                the tool's value.

        Returns:
            A ToolDecision whose action is "cache", "skip" or "use". For
            "skip" and "use" both ``confidence`` and ``estimated_value`` carry
            the adjusted value, which is not clamped and may be negative.
        """
        est = TOOL_COST_ESTIMATES.get(
            tool_name, {"cost": 0.02, "latency_ms": 2000, "value_base": 0.5}
        )

        if self._cache_key(tool_name, args) in self._result_cache:
            self.call_stats["cached"] += 1
            return ToolDecision("cache", tool_name, 1.0, "cached result available",
                                0.0, est["value_base"], "use_cached_result")

        value = est["value_base"]
        # Tools that match the task type are worth more.
        if task_type == "coding" and tool_name in ("code_search", "file_read", "code_execute"):
            value += 0.2
        if task_type == "research" and tool_name in ("web_search", "web_scrape"):
            value += 0.2
        # Searching on the very first step gets a bonus.
        if step_num == 1 and tool_name in ("web_search", "code_search"):
            value += 0.1
        # Penalize when several results have already been collected.
        if prior_results and len(prior_results) > 3:
            value -= 0.2
        # A confident caller has less to gain from another tool call.
        if confidence > 0.8:
            value -= 0.3
        cost = est["cost"]

        if value < self.value_threshold:
            self.call_stats["skipped"] += 1
            return ToolDecision("skip", tool_name, value, "low value, not worth cost",
                                cost, value, "proceed_without_tool")
        # Expensive calls are only allowed through when their value is high.
        if cost > self.cost_threshold and value < 0.6:
            self.call_stats["skipped"] += 1
            return ToolDecision("skip", tool_name, value, "cost exceeds threshold",
                                cost, value, "proceed_without_tool")
        self.call_stats["used"] += 1
        return ToolDecision("use", tool_name, value, "tool value justifies cost",
                            cost, value, None)

    def should_batch(self, pending_calls: List[Dict]) -> bool:
        """Return True when the given calls are suitable for batching.

        Batching requires at least two calls, every call to be independent
        (a missing "independent" key counts as independent), and at most two
        distinct values under the "tool" key.

        Raises:
            KeyError: If a call dict has no "tool" key.
        """
        if len(pending_calls) < 2:
            return False
        independent = all(c.get("independent", True) for c in pending_calls)
        few_tool_kinds = len({c["tool"] for c in pending_calls}) <= 2
        return independent and few_tool_kinds

    def record_result(self, tool_name: str, args: Dict, result: str,
                      was_useful: bool = True):
        """Remember ``result`` for this call so later identical calls hit the cache.

        Results flagged as not useful are not stored.
        """
        if was_useful:
            self._result_cache[self._cache_key(tool_name, args)] = result

    def get_cached_result(self, tool_name: str, args: Dict) -> Optional[str]:
        """Return the stored result for this call, or None if nothing is cached."""
        return self._result_cache.get(self._cache_key(tool_name, args))
|
|