narcolepticchicken commited on
Commit
a7e3035
·
verified ·
1 Parent(s): eaafe86

Upload aco/tool_gate.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. aco/tool_gate.py +77 -253
aco/tool_gate.py CHANGED
@@ -1,260 +1,84 @@
1
- """Tool-Use Cost Gate - Module 6.
2
-
3
- Predicts whether a tool call is worth the cost.
4
-
5
- Decisions:
6
- - use tool
7
- - skip tool
8
- - batch tool calls
9
- - run in parallel
10
- - use cheaper tool
11
- - use cached result
12
- - ask user instead
13
- - escalate
14
-
15
- Tracks:
16
- - unnecessary tool calls
17
- - missed tool calls
18
- - failed tool calls
19
- - tool result ignored
20
- - repeated tool calls
21
- - tool cost
22
- - tool latency
23
- """
24
-
25
- from typing import Dict, List, Tuple, Optional
26
  from dataclasses import dataclass
27
- from enum import Enum
28
-
29
- from .trace_schema import TaskType, ToolCall
30
- from .config import ACOConfig, ToolConfig
31
-
32
-
33
- class ToolDecision(Enum):
34
- USE = "use"
35
- SKIP = "skip"
36
- BATCH = "batch"
37
- PARALLEL = "parallel"
38
- USE_CHEAPER = "use_cheaper"
39
- USE_CACHE = "use_cache"
40
- ASK_USER = "ask_user"
41
- ESCALATE = "escalate"
42
-
43
 
44
  @dataclass
45
- class ToolGateDecision:
46
- decision: ToolDecision
47
  tool_name: str
 
48
  reasoning: str
49
  estimated_cost: float
50
- estimated_benefit: float # 0-1 probability of useful result
51
- confidence: float
52
- alternative_tool: Optional[str] = None
53
- batched_with: Optional[List[str]] = None
54
-
55
-
56
- class ToolUseCostGate:
57
- """Gates tool calls based on predicted value vs. cost."""
58
-
59
- # Tool necessity by task type (probability that task needs this tool class)
60
- TOOL_NECESSITY = {
61
- TaskType.QUICK_ANSWER: {
62
- "search": 0.3,
63
- "retrieve": 0.1,
64
- "calculator": 0.2,
65
- },
66
- TaskType.CODING: {
67
- "code_execution": 0.8,
68
- "linter": 0.6,
69
- "test_runner": 0.7,
70
- "file_read": 0.9,
71
- "file_write": 0.5,
72
- "search": 0.4,
73
- },
74
- TaskType.RESEARCH: {
75
- "search": 0.95,
76
- "retrieve": 0.9,
77
- "fetch": 0.7,
78
- "summarize": 0.8,
79
- },
80
- TaskType.LEGAL_REGULATED: {
81
- "document_retrieval": 0.95,
82
- "compliance_check": 0.9,
83
- "search": 0.6,
84
- },
85
- TaskType.TOOL_HEAVY: {
86
- "search": 0.7,
87
- "fetch": 0.6,
88
- "api_call": 0.8,
89
- "database_query": 0.7,
90
- },
91
- TaskType.RETRIEVAL_HEAVY: {
92
- "retrieve": 0.95,
93
- "search": 0.8,
94
- "fetch": 0.7,
95
- },
96
- TaskType.LONG_HORIZON: {
97
- "task_planner": 0.7,
98
- "progress_tracker": 0.5,
99
- "file_read": 0.6,
100
- "search": 0.4,
101
- },
102
- }
103
-
104
- # Cost-benefit threshold
105
- MIN_BENEFIT_COST_RATIO = 2.0 # benefit must be > 2x cost to call
106
-
107
- def __init__(self, config: Optional[ACOConfig] = None):
108
- self.config = config or ACOConfig()
109
- self.tool_stats: Dict[str, Dict] = {} # tool_name -> stats
110
- self.decision_history: List[Dict] = []
111
-
112
- def decide(
113
- self,
114
- tool_name: str,
115
- task_type: TaskType,
116
- tool_input: Dict,
117
- previous_tool_calls: Optional[List[ToolCall]] = None,
118
- current_cost_so_far: float = 0.0,
119
- predicted_task_cost: float = 0.1,
120
- ) -> ToolGateDecision:
121
- """Decide whether to make a tool call."""
122
-
123
- tool_cfg = self.config.tools.get(tool_name, ToolConfig(tool_name=tool_name))
124
- tool_cost = tool_cfg.cost_per_call
125
- tool_latency = tool_cfg.latency_ms_estimate
126
-
127
- previous = previous_tool_calls or []
128
-
129
- # Check for repeated identical calls
130
- if self._is_repeated(tool_name, tool_input, previous):
131
- return ToolGateDecision(
132
- decision=ToolDecision.USE_CACHE,
133
- tool_name=tool_name,
134
- reasoning="Repeated identical tool call detected — use cached result",
135
- estimated_cost=0.0,
136
- estimated_benefit=0.9,
137
- confidence=0.95,
138
- )
139
-
140
- # Check if tool result was ignored in previous steps
141
- if previous and self._was_ignored(tool_name, previous):
142
- # If tool results are consistently ignored, skip
143
- return ToolGateDecision(
144
- decision=ToolDecision.SKIP,
145
- tool_name=tool_name,
146
- reasoning="Previous results from this tool were ignored by the agent",
147
- estimated_cost=0.0,
148
- estimated_benefit=0.1,
149
- confidence=0.8,
150
- )
151
-
152
- # Estimate necessity
153
- necessity = self.TOOL_NECESSITY.get(task_type, {}).get(tool_name, 0.5)
154
-
155
- # Adjust by historical success rate
156
- stats = self.tool_stats.get(tool_name, {"calls": 0, "useful": 0})
157
- if stats["calls"] > 5:
158
- historical_useful_rate = stats["useful"] / stats["calls"]
159
- necessity = (necessity + historical_useful_rate) / 2
160
-
161
- # Cost escalation check: if we're already over predicted cost, be more selective
162
- cost_ratio = current_cost_so_far / max(predicted_task_cost, 0.001)
163
- if cost_ratio > 1.5:
164
- necessity *= 0.7
165
- if cost_ratio > 2.5:
166
- necessity *= 0.5
167
-
168
- # Normalize cost to benefit scale (assume $0.01 = 1.0 benefit unit)
169
- normalized_cost = tool_cost / 0.01
170
- benefit_cost_ratio = necessity / max(normalized_cost, 0.001)
171
-
172
- if benefit_cost_ratio < self.MIN_BENEFIT_COST_RATIO and necessity < 0.5:
173
- return ToolGateDecision(
174
- decision=ToolDecision.SKIP,
175
- tool_name=tool_name,
176
- reasoning=f"Low benefit/cost ratio ({benefit_cost_ratio:.2f}) and low necessity ({necessity:.2f})",
177
- estimated_cost=0.0,
178
- estimated_benefit=necessity,
179
- confidence=0.75,
180
- )
181
-
182
- # Check if we can batch with other pending tool calls
183
- # (simplified: if multiple reads, batch them)
184
- if tool_name in ("file_read", "search", "retrieve"):
185
- return ToolGateDecision(
186
- decision=ToolDecision.USE,
187
- tool_name=tool_name,
188
- reasoning=f"Tool is necessary (necessity={necessity:.2f}) and cacheable — proceed",
189
- estimated_cost=tool_cost,
190
- estimated_benefit=necessity,
191
- confidence=min(necessity + 0.2, 1.0),
192
- )
193
-
194
- return ToolGateDecision(
195
- decision=ToolDecision.USE,
196
- tool_name=tool_name,
197
- reasoning=f"Tool justified: necessity={necessity:.2f}, cost={tool_cost:.4f}",
198
- estimated_cost=tool_cost,
199
- estimated_benefit=necessity,
200
- confidence=min(necessity + 0.1, 1.0),
201
- )
202
-
203
- def decide_batch(
204
- self,
205
- tool_requests: List[Tuple[str, Dict]],
206
- task_type: TaskType,
207
- previous_tool_calls: Optional[List[ToolCall]] = None,
208
- ) -> List[ToolGateDecision]:
209
- """Decide on a batch of tool calls, potentially grouping parallel ones."""
210
- decisions = []
211
- for tool_name, tool_input in tool_requests:
212
- d = self.decide(tool_name, task_type, tool_input, previous_tool_calls)
213
- decisions.append(d)
214
-
215
- # Group independent tool calls for parallel execution
216
- # (reads can be parallel, writes should be sequential)
217
- read_tools = ["file_read", "search", "retrieve", "fetch", "database_query"]
218
- parallel_group = []
219
- sequential = []
220
-
221
- for d in decisions:
222
- if d.tool_name in read_tools and d.decision == ToolDecision.USE:
223
- parallel_group.append(d)
224
- else:
225
- if parallel_group:
226
- # Mark the first as parallel with the rest
227
- parallel_group[0].batched_with = [p.tool_name for p in parallel_group[1:]]
228
- sequential.append(parallel_group[0])
229
- parallel_group = []
230
- sequential.append(d)
231
-
232
- if parallel_group:
233
- parallel_group[0].batched_with = [p.tool_name for p in parallel_group[1:]]
234
- sequential.append(parallel_group[0])
235
-
236
- return sequential
237
-
238
- def _is_repeated(self, tool_name: str, tool_input: Dict, previous: List[ToolCall]) -> bool:
239
- """Check if this exact tool call was made before in this trace."""
240
- for call in reversed(previous[-5:]): # check last 5
241
- if call.tool_name == tool_name and call.tool_input == tool_input:
242
- return True
243
- return False
244
-
245
- def _was_ignored(self, tool_name: str, previous: List[ToolCall]) -> bool:
246
- """Check if previous calls to this tool had their results ignored."""
247
- relevant = [c for c in previous if c.tool_name == tool_name]
248
- if len(relevant) < 2:
249
- return False
250
- ignored_count = sum(1 for c in relevant if c.ignored_result)
251
- return ignored_count / len(relevant) > 0.5
252
-
253
- def record_outcome(self, tool_name: str, was_useful: bool, cost: float) -> None:
254
- """Record whether a tool call was actually useful."""
255
- stats = self.tool_stats.setdefault(tool_name, {"calls": 0, "useful": 0, "total_cost": 0.0})
256
- stats["calls"] += 1
257
  if was_useful:
258
- stats["useful"] += 1
259
- stats["total_cost"] += cost
260
- stats["useful_rate"] = stats["useful"] / stats["calls"]
 
1
+ """Tool-Use Cost Gate: Predict whether a tool call is worth the cost."""
2
+ from typing import Dict, List, Optional, Tuple
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from dataclasses import dataclass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  @dataclass
6
+ class ToolDecision:
7
+ action: str # "use", "skip", "batch", "cache", "escalate"
8
  tool_name: str
9
+ confidence: float
10
  reasoning: str
11
  estimated_cost: float
12
+ estimated_value: float
13
+ alternative: Optional[str] = None
14
+
15
+ TOOL_COST_ESTIMATES = {
16
+ "web_search": {"cost": 0.01, "latency_ms": 2000, "value_base": 0.6},
17
+ "code_search": {"cost": 0.005, "latency_ms": 1000, "value_base": 0.7},
18
+ "file_read": {"cost": 0.001, "latency_ms": 100, "value_base": 0.8},
19
+ "file_write": {"cost": 0.001, "latency_ms": 100, "value_base": 0.9},
20
+ "api_call": {"cost": 0.05, "latency_ms": 3000, "value_base": 0.5},
21
+ "database_query": {"cost": 0.02, "latency_ms": 500, "value_base": 0.6},
22
+ "code_execute": {"cost": 0.01, "latency_ms": 5000, "value_base": 0.7},
23
+ "web_scrape": {"cost": 0.02, "latency_ms": 5000, "value_base": 0.4},
24
+ "summarize": {"cost": 0.005, "latency_ms": 2000, "value_base": 0.5},
25
+ "verify": {"cost": 0.02, "latency_ms": 3000, "value_base": 0.6},
26
+ }
27
+
28
+ class ToolCostGate:
29
+ def __init__(self, cost_threshold: float = 0.05, value_threshold: float = 0.3,
30
+ batch_window_ms: int = 5000):
31
+ self.cost_threshold = cost_threshold
32
+ self.value_threshold = value_threshold
33
+ self.batch_window_ms = batch_window_ms
34
+ self.pending_calls: List[Dict] = []
35
+ self.call_stats = {"used":0,"skipped":0,"batched":0,"cached":0,"escalated":0}
36
+ self._result_cache: Dict[str, str] = {}
37
+
38
+ def gate(self, tool_name: str, args: Dict, task_type: str,
39
+ step_num: int, total_steps: int, confidence: float,
40
+ prior_results: List[str] = None) -> ToolDecision:
41
+ est = TOOL_COST_ESTIMATES.get(tool_name, {"cost":0.02,"latency_ms":2000,"value_base":0.5})
42
+ # Check cache
43
+ cache_key = f"{tool_name}:{hash(str(args))}"
44
+ if cache_key in self._result_cache:
45
+ self.call_stats["cached"] += 1
46
+ return ToolDecision("cache", tool_name, 1.0, "cached result available",
47
+ 0.0, est["value_base"], "use_cached_result")
48
+ # Adjust value based on context
49
+ value = est["value_base"]
50
+ if task_type == "coding" and tool_name in ("code_search","file_read","code_execute"):
51
+ value += 0.2
52
+ if task_type == "research" and tool_name in ("web_search","web_scrape"):
53
+ value += 0.2
54
+ if step_num == 1 and tool_name in ("web_search","code_search"):
55
+ value += 0.1 # First search is usually valuable
56
+ if prior_results and len(prior_results) > 3:
57
+ value -= 0.2 # Already have enough info
58
+ if confidence > 0.8:
59
+ value -= 0.3 # Already confident, tool less valuable
60
+ cost = est["cost"]
61
+ # Decision
62
+ if value < self.value_threshold:
63
+ self.call_stats["skipped"] += 1
64
+ return ToolDecision("skip", tool_name, value, "low value, not worth cost",
65
+ cost, value, "proceed_without_tool")
66
+ if cost > self.cost_threshold and value < 0.6:
67
+ self.call_stats["skipped"] += 1
68
+ return ToolDecision("skip", tool_name, value, "cost exceeds threshold",
69
+ cost, value, "proceed_without_tool")
70
+ self.call_stats["used"] += 1
71
+ return ToolDecision("use", tool_name, value, "tool value justifies cost",
72
+ cost, value, None)
73
+
74
+ def should_batch(self, pending_calls: List[Dict]) -> bool:
75
+ if len(pending_calls) < 2: return False
76
+ independent = all(c.get("independent", True) for c in pending_calls)
77
+ same_type = len(set(c["tool"] for c in pending_calls)) <= 2
78
+ return independent and same_type
79
+
80
+ def record_result(self, tool_name: str, args: Dict, result: str,
81
+ was_useful: bool = True):
82
+ cache_key = f"{tool_name}:{hash(str(args))}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  if was_useful:
84
+ self._result_cache[cache_key] = result