"""Tool-Use Cost Gate: Predict whether a tool call is worth the cost."""
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
@dataclass
class ToolDecision:
    """Outcome of a single gating decision for a proposed tool call."""

    # One of: "use", "skip", "batch", "cache", "escalate".
    action: str  # "use", "skip", "batch", "cache", "escalate"
    # Name of the tool the decision applies to.
    tool_name: str
    # Gate's confidence in this decision (0.0-1.0).
    confidence: float
    # Short human-readable justification for the chosen action.
    reasoning: str
    # Predicted monetary cost of making the call.
    estimated_cost: float
    # Predicted usefulness of the call's result.
    estimated_value: float
    # Suggested fallback (e.g. "proceed_without_tool"), if any.
    alternative: Optional[str] = None
# Static per-tool estimates used by ToolCostGate.gate():
#   cost       - approximate monetary cost per call
#   latency_ms - typical wall-clock latency in milliseconds
#   value_base - baseline usefulness (0.0-1.0) before context adjustments
TOOL_COST_ESTIMATES = {
    "web_search": {"cost": 0.01, "latency_ms": 2000, "value_base": 0.6},
    "code_search": {"cost": 0.005, "latency_ms": 1000, "value_base": 0.7},
    "file_read": {"cost": 0.001, "latency_ms": 100, "value_base": 0.8},
    "file_write": {"cost": 0.001, "latency_ms": 100, "value_base": 0.9},
    "api_call": {"cost": 0.05, "latency_ms": 3000, "value_base": 0.5},
    "database_query": {"cost": 0.02, "latency_ms": 500, "value_base": 0.6},
    "code_execute": {"cost": 0.01, "latency_ms": 5000, "value_base": 0.7},
    "web_scrape": {"cost": 0.02, "latency_ms": 5000, "value_base": 0.4},
    "summarize": {"cost": 0.005, "latency_ms": 2000, "value_base": 0.5},
    "verify": {"cost": 0.02, "latency_ms": 3000, "value_base": 0.6},
}
class ToolCostGate:
    """Cost/benefit gate for tool calls.

    Estimates the value of a proposed tool call from task context, compares
    it against configurable cost/value thresholds, and answers with a
    ToolDecision. Useful results can be recorded and served from an
    in-memory cache on repeat calls; every decision outcome is tallied in
    ``call_stats``.
    """

    def __init__(self, cost_threshold: float = 0.05, value_threshold: float = 0.3,
                 batch_window_ms: int = 5000):
        """
        Args:
            cost_threshold: Cost above which a call must show high value to run.
            value_threshold: Minimum adjusted value for a call to proceed.
            batch_window_ms: Batching window; stored for callers — not read
                by should_batch() itself.
        """
        self.cost_threshold = cost_threshold
        self.value_threshold = value_threshold
        self.batch_window_ms = batch_window_ms
        self.pending_calls: List[Dict] = []
        # Tally of every decision outcome gate() has produced.
        self.call_stats = {"used": 0, "skipped": 0, "batched": 0,
                           "cached": 0, "escalated": 0}
        self._result_cache: Dict[str, str] = {}

    @staticmethod
    def _cache_key(tool_name: str, args: Dict) -> str:
        """Build a deterministic cache key for (tool_name, args).

        Sorting the items makes the key independent of dict insertion order,
        and repr() — unlike the previous hash(str(args)) — cannot collide for
        distinct args and is stable across processes regardless of
        PYTHONHASHSEED.
        """
        items = sorted(args.items(), key=lambda kv: str(kv[0]))
        return f"{tool_name}:{items!r}"

    def gate(self, tool_name: str, args: Dict, task_type: str,
             step_num: int, total_steps: int, confidence: float,
             prior_results: Optional[List[str]] = None) -> ToolDecision:
        """Decide whether a proposed tool call is worth making.

        Args:
            tool_name: Key into TOOL_COST_ESTIMATES; unknown tools fall back
                to a generic estimate.
            args: Arguments of the proposed call; used only for cache lookup.
            task_type: e.g. "coding" or "research"; boosts matching tools.
            step_num: 1-based step index; a first search gets a small boost.
            total_steps: Accepted for context; not currently consulted.
            confidence: Caller's current confidence; above 0.8 the tool's
                value is discounted heavily.
            prior_results: Results already gathered; more than three lowers
                the tool's value.

        Returns:
            A ToolDecision whose action is "cache", "skip", or "use".
        """
        est = TOOL_COST_ESTIMATES.get(
            tool_name, {"cost": 0.02, "latency_ms": 2000, "value_base": 0.5})
        # Serve a previously recorded result for free when one exists.
        cache_key = self._cache_key(tool_name, args)
        if cache_key in self._result_cache:
            self.call_stats["cached"] += 1
            return ToolDecision("cache", tool_name, 1.0, "cached result available",
                                0.0, est["value_base"], "use_cached_result")
        # Adjust the base value up/down according to task context.
        value = est["value_base"]
        if task_type == "coding" and tool_name in ("code_search", "file_read", "code_execute"):
            value += 0.2
        if task_type == "research" and tool_name in ("web_search", "web_scrape"):
            value += 0.2
        if step_num == 1 and tool_name in ("web_search", "code_search"):
            value += 0.1  # First search is usually valuable.
        if prior_results and len(prior_results) > 3:
            value -= 0.2  # Already have enough info.
        if confidence > 0.8:
            value -= 0.3  # Already confident; the tool adds little.
        cost = est["cost"]
        # Decision: skip low-value or expensive-but-mediocre calls.
        if value < self.value_threshold:
            self.call_stats["skipped"] += 1
            return ToolDecision("skip", tool_name, value, "low value, not worth cost",
                                cost, value, "proceed_without_tool")
        if cost > self.cost_threshold and value < 0.6:
            self.call_stats["skipped"] += 1
            return ToolDecision("skip", tool_name, value, "cost exceeds threshold",
                                cost, value, "proceed_without_tool")
        self.call_stats["used"] += 1
        return ToolDecision("use", tool_name, value, "tool value justifies cost",
                            cost, value, None)

    def should_batch(self, pending_calls: List[Dict]) -> bool:
        """Return True when the pending calls can be dispatched as one batch.

        Requires at least two calls, all marked independent (missing key
        defaults to True), spanning at most two distinct tool types.
        """
        if len(pending_calls) < 2:
            return False
        independent = all(c.get("independent", True) for c in pending_calls)
        same_type = len({c["tool"] for c in pending_calls}) <= 2
        return independent and same_type

    def record_result(self, tool_name: str, args: Dict, result: str,
                      was_useful: bool = True):
        """Cache a tool result so an identical future call is served for free."""
        if was_useful:
            self._result_cache[self._cache_key(tool_name, args)] = result