File size: 4,112 Bytes
a7e3035
 
7d60df1
 
 
a7e3035
 
7d60df1
a7e3035
7d60df1
 
a7e3035
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d60df1
a7e3035
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
"""Tool-Use Cost Gate: Predict whether a tool call is worth the cost."""
import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@dataclass
class ToolDecision:
    """Decision about whether and how to perform a single proposed tool call.

    Positional field order: action, tool_name, confidence, reasoning,
    estimated_cost, estimated_value, and the optional ``alternative``
    (defaults to None).
    """

    # One of "use", "skip", "batch", "cache" or "escalate".
    action: str
    # Name of the tool the decision refers to.
    tool_name: str
    # Score attached to the decision; callers may pass values outside [0, 1].
    confidence: float
    # Short human-readable justification for the action.
    reasoning: str
    # Estimated cost of making the call.
    estimated_cost: float
    # Estimated value of making the call.
    estimated_value: float
    # Suggested fallback instead of calling the tool, if any.
    alternative: Optional[str] = None

# Per-tool estimates, keyed by tool name. Each entry holds:
#   cost        - estimated cost per call (units not stated here; presumably
#                 currency such as USD - TODO confirm)
#   latency_ms  - expected latency in milliseconds
#   value_base  - baseline value score before context adjustments
TOOL_COST_ESTIMATES = {
    name: {"cost": cost, "latency_ms": latency, "value_base": value_base}
    for name, cost, latency, value_base in (
        # (tool name,       cost,  latency_ms, value_base)
        ("web_search",      0.01,  2000,       0.6),
        ("code_search",     0.005, 1000,       0.7),
        ("file_read",       0.001, 100,        0.8),
        ("file_write",      0.001, 100,        0.9),
        ("api_call",        0.05,  3000,       0.5),
        ("database_query",  0.02,  500,        0.6),
        ("code_execute",    0.01,  5000,       0.7),
        ("web_scrape",      0.02,  5000,       0.4),
        ("summarize",       0.005, 2000,       0.5),
        ("verify",          0.02,  3000,       0.6),
    )
}

class ToolCostGate:
    """Decide whether a proposed tool call is worth its estimated cost.

    :meth:`gate` starts from the tool's entry in ``TOOL_COST_ESTIMATES``,
    adjusts its value with a few context heuristics and returns a
    ``ToolDecision`` whose action is "cache", "skip" or "use". Results
    recorded through :meth:`record_result` are kept in an in-memory cache so
    a later call with the same tool and equivalent arguments can reuse them.
    """

    # Estimate used for tools that have no entry in TOOL_COST_ESTIMATES.
    _DEFAULT_ESTIMATE = {"cost": 0.02, "latency_ms": 2000, "value_base": 0.5}

    def __init__(self, cost_threshold: float = 0.05, value_threshold: float = 0.3,
                 batch_window_ms: int = 5000):
        """Create a gate.

        Args:
            cost_threshold: Cost above which a call is skipped unless its
                adjusted value is at least 0.6.
            value_threshold: Adjusted value below which a call is skipped.
            batch_window_ms: Stored on the instance; not read elsewhere in
                this class.
        """
        self.cost_threshold = cost_threshold
        self.value_threshold = value_threshold
        self.batch_window_ms = batch_window_ms
        # Not read by this class; should_batch takes its own list argument.
        self.pending_calls: List[Dict] = []
        # Per-action counters. "batched" and "escalated" are never incremented
        # here - presumably maintained by callers (TODO confirm).
        self.call_stats = {"used": 0, "skipped": 0, "batched": 0, "cached": 0, "escalated": 0}
        # Canonical cache key (see _cache_key) -> result recorded as useful.
        self._result_cache: Dict[str, str] = {}

    @staticmethod
    def _cache_key(tool_name: str, args: Dict) -> str:
        """Build a canonical cache key for ``tool_name`` called with ``args``.

        The arguments are serialized as JSON with sorted dict keys, so dicts
        that differ only in insertion order share a cache entry. The full text
        is kept (rather than a hash of it), so distinct arguments can never
        collide. Values JSON cannot encode fall back to ``repr``. If the whole
        mapping cannot be serialized (e.g. unsortable mixed-type keys, tuple
        keys, or a reference cycle), the plain ``repr`` of ``args`` is used.
        """
        try:
            canonical = json.dumps(args, sort_keys=True, default=repr)
        except (TypeError, ValueError):
            canonical = repr(args)
        return f"{tool_name}:{canonical}"

    def gate(self, tool_name: str, args: Dict, task_type: str,
             step_num: int, total_steps: int, confidence: float,
             prior_results: Optional[List[str]] = None) -> ToolDecision:
        """Decide how to handle a proposed call to ``tool_name``.

        Args:
            tool_name: Key into TOOL_COST_ESTIMATES; unknown tools get a
                default estimate.
            args: Arguments of the proposed call; only used for the cache
                lookup.
            task_type: "coding" and "research" raise the value of matching
                tools by 0.2.
            step_num: Current step; web_search/code_search at step 1 get +0.1.
            total_steps: Accepted for interface compatibility; not used.
            confidence: Caller's current confidence; above 0.8 lowers the
                tool's value by 0.3.
            prior_results: Results gathered so far; more than 3 lowers the
                tool's value by 0.2.

        Returns:
            A ToolDecision whose action is one of:
            - "cache": a useful result for the same tool and equivalent
              arguments was recorded.
            - "skip": the adjusted value is below ``value_threshold``, or the
              cost exceeds ``cost_threshold`` while the value is below 0.6.
            - "use": otherwise.
            For "skip" and "use", the adjusted value is reported as both
            ``confidence`` and ``estimated_value``. It is not clamped, so it
            can be negative.
        """
        est = TOOL_COST_ESTIMATES.get(tool_name, self._DEFAULT_ESTIMATE)
        if self._cache_key(tool_name, args) in self._result_cache:
            self.call_stats["cached"] += 1
            return ToolDecision("cache", tool_name, 1.0, "cached result available",
                                0.0, est["value_base"], "use_cached_result")

        # Adjust the baseline value for the current context.
        value = est["value_base"]
        if task_type == "coding" and tool_name in ("code_search", "file_read",
                                                   "code_execute"):
            value += 0.2
        if task_type == "research" and tool_name in ("web_search", "web_scrape"):
            value += 0.2
        if step_num == 1 and tool_name in ("web_search", "code_search"):
            value += 0.1  # the first search of a task is usually valuable
        if prior_results and len(prior_results) > 3:
            value -= 0.2  # already have plenty of information
        if confidence > 0.8:
            value -= 0.3  # already confident, so the tool adds less
        cost = est["cost"]

        if value < self.value_threshold:
            self.call_stats["skipped"] += 1
            return ToolDecision("skip", tool_name, value, "low value, not worth cost",
                                cost, value, "proceed_without_tool")
        # Expensive calls are only worth it when the value is high (>= 0.6).
        if cost > self.cost_threshold and value < 0.6:
            self.call_stats["skipped"] += 1
            return ToolDecision("skip", tool_name, value, "cost exceeds threshold",
                                cost, value, "proceed_without_tool")
        self.call_stats["used"] += 1
        return ToolDecision("use", tool_name, value, "tool value justifies cost",
                            cost, value, None)

    def should_batch(self, pending_calls: List[Dict]) -> bool:
        """Return True if ``pending_calls`` look safe to batch together.

        Requires at least two calls, none marked ``"independent": False``
        (a missing flag counts as independent), and at most two distinct
        ``"tool"`` values among them. Each call dict must have a ``"tool"``
        key when there are two or more calls.
        """
        if len(pending_calls) < 2:
            return False
        independent = all(c.get("independent", True) for c in pending_calls)
        # Despite the name, up to two distinct tool types are accepted.
        same_type = len({c["tool"] for c in pending_calls}) <= 2
        return independent and same_type

    def record_result(self, tool_name: str, args: Dict, result: str,
                      was_useful: bool = True):
        """Cache ``result`` for ``tool_name``/``args`` when ``was_useful`` is true.

        Results marked as not useful are ignored; any entry already cached
        for the same key is left in place.
        """
        if was_useful:
            self._result_cache[self._cache_key(tool_name, args)] = result

    def get_cached_result(self, tool_name: str, args: Dict) -> Optional[str]:
        """Return the result recorded for ``tool_name``/``args``, or None if absent.

        Lets callers act on a "cache" decision from :meth:`gate` without
        touching the private cache.
        """
        return self._result_cache.get(self._cache_key(tool_name, args))