narcolepticchicken committed on
Commit
7d60df1
·
verified ·
1 Parent(s): 07f0bb4

Upload aco/tool_gate.py

Browse files
Files changed (1) hide show
  1. aco/tool_gate.py +260 -0
aco/tool_gate.py ADDED
@@ -0,0 +1,260 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tool-Use Cost Gate - Module 6.
2
+
3
+ Predicts whether a tool call is worth the cost.
4
+
5
+ Decisions:
6
+ - use tool
7
+ - skip tool
8
+ - batch tool calls
9
+ - run in parallel
10
+ - use cheaper tool
11
+ - use cached result
12
+ - ask user instead
13
+ - escalate
14
+
15
+ Tracks:
16
+ - unnecessary tool calls
17
+ - missed tool calls
18
+ - failed tool calls
19
+ - tool result ignored
20
+ - repeated tool calls
21
+ - tool cost
22
+ - tool latency
23
+ """
24
+
25
+ from typing import Dict, List, Tuple, Optional
26
+ from dataclasses import dataclass
27
+ from enum import Enum
28
+
29
+ from .trace_schema import TaskType, ToolCall
30
+ from .config import ACOConfig, ToolConfig
31
+
32
+
33
class ToolDecision(Enum):
    """Outcome the cost gate can return for a requested tool call.

    The string values are the serialized form of each decision; member
    order and values are part of the public contract.
    """

    USE = "use"  # make the tool call
    SKIP = "skip"  # do not make the tool call
    BATCH = "batch"  # batch with other tool calls
    PARALLEL = "parallel"  # run in parallel with other calls
    USE_CHEAPER = "use_cheaper"  # use a cheaper tool instead
    USE_CACHE = "use_cache"  # reuse a cached result
    ASK_USER = "ask_user"  # ask the user instead of calling a tool
    ESCALATE = "escalate"  # escalate the decision
42
+
43
+
44
@dataclass
class ToolGateDecision:
    """Result of gating a single tool call.

    Bundles the chosen ``ToolDecision`` with the tool it applies to, a
    human-readable justification, and the estimates that produced it.
    """

    # What the gate decided to do with the requested call.
    decision: ToolDecision
    # Name of the tool the decision applies to.
    tool_name: str
    # Human-readable explanation of why this decision was taken.
    reasoning: str
    # Cost attributed to carrying out the decision.
    estimated_cost: float
    estimated_benefit: float  # 0-1 probability of useful result
    # Confidence score attached to the decision.
    confidence: float
    # Optional substitute tool; defaults to None.
    alternative_tool: Optional[str] = None
    # Optional names of other tools grouped with this one; defaults to None.
    batched_with: Optional[List[str]] = None
54
+
55
+
56
class ToolUseCostGate:
    """Gates tool calls based on predicted value vs. cost.

    ``decide`` evaluates one requested call; ``decide_batch`` evaluates a
    list of requests and collapses runs of read-only calls into a single
    representative decision. ``record_outcome`` feeds per-tool usefulness
    statistics back into later decisions.
    """

    # Tool necessity by task type (probability that task needs this tool class).
    # Tools not listed for a task type fall back to 0.5 in ``decide``.
    TOOL_NECESSITY = {
        TaskType.QUICK_ANSWER: {
            "search": 0.3,
            "retrieve": 0.1,
            "calculator": 0.2,
        },
        TaskType.CODING: {
            "code_execution": 0.8,
            "linter": 0.6,
            "test_runner": 0.7,
            "file_read": 0.9,
            "file_write": 0.5,
            "search": 0.4,
        },
        TaskType.RESEARCH: {
            "search": 0.95,
            "retrieve": 0.9,
            "fetch": 0.7,
            "summarize": 0.8,
        },
        TaskType.LEGAL_REGULATED: {
            "document_retrieval": 0.95,
            "compliance_check": 0.9,
            "search": 0.6,
        },
        TaskType.TOOL_HEAVY: {
            "search": 0.7,
            "fetch": 0.6,
            "api_call": 0.8,
            "database_query": 0.7,
        },
        TaskType.RETRIEVAL_HEAVY: {
            "retrieve": 0.95,
            "search": 0.8,
            "fetch": 0.7,
        },
        TaskType.LONG_HORIZON: {
            "task_planner": 0.7,
            "progress_tracker": 0.5,
            "file_read": 0.6,
            "search": 0.4,
        },
    }

    # Cost-benefit threshold: a call whose benefit/cost ratio is below this
    # AND whose necessity is below 0.5 is skipped (see ``decide``).
    MIN_BENEFIT_COST_RATIO = 2.0

    # Tools that ``decide_batch`` treats as read-only and therefore safe to
    # group for parallel execution (writes stay sequential).
    _PARALLEL_SAFE_TOOLS = frozenset(
        {"file_read", "search", "retrieve", "fetch", "database_query"}
    )

    def __init__(self, config: Optional[ACOConfig] = None):
        """Create a gate.

        Args:
            config: Configuration supplying per-tool settings via
                ``config.tools``. A default ``ACOConfig()`` is built when
                omitted or falsy.
        """
        self.config = config or ACOConfig()
        # tool_name -> stats dict maintained by ``record_outcome``
        # (keys: "calls", "useful", "total_cost", "useful_rate").
        self.tool_stats: Dict[str, Dict] = {}
        # Initialised empty; no method of this class appends to it.
        self.decision_history: List[Dict] = []

    def decide(
        self,
        tool_name: str,
        task_type: TaskType,
        tool_input: Dict,
        previous_tool_calls: Optional[List[ToolCall]] = None,
        current_cost_so_far: float = 0.0,
        predicted_task_cost: float = 0.1,
    ) -> ToolGateDecision:
        """Decide whether to make a tool call.

        Checks are applied in order and the first match wins:

        1. An identical call among the last five previous calls -> USE_CACHE.
        2. More than half of at least two earlier calls to this tool had
           their result ignored -> SKIP.
        3. Low benefit/cost ratio combined with necessity below 0.5 -> SKIP.
        4. Otherwise -> USE.

        Args:
            tool_name: Name of the tool being requested.
            task_type: Task category used to look up the base necessity.
            tool_input: Arguments of the requested call; compared by
                equality against earlier calls to detect repeats.
            previous_tool_calls: Earlier calls in the current trace, if any.
            current_cost_so_far: Cost already spent on the task.
            predicted_task_cost: Expected total cost of the task; floored
                at 0.001 when computing the overrun ratio.

        Returns:
            The gate's decision together with its cost/benefit estimates.
        """
        tool_cfg = self.config.tools.get(tool_name, ToolConfig(tool_name=tool_name))
        tool_cost = tool_cfg.cost_per_call

        previous = previous_tool_calls or []

        # Check for repeated identical calls
        if self._is_repeated(tool_name, tool_input, previous):
            return ToolGateDecision(
                decision=ToolDecision.USE_CACHE,
                tool_name=tool_name,
                reasoning="Repeated identical tool call detected — use cached result",
                estimated_cost=0.0,
                estimated_benefit=0.9,
                confidence=0.95,
            )

        # If tool results are consistently ignored, skip
        if previous and self._was_ignored(tool_name, previous):
            return ToolGateDecision(
                decision=ToolDecision.SKIP,
                tool_name=tool_name,
                reasoning="Previous results from this tool were ignored by the agent",
                estimated_cost=0.0,
                estimated_benefit=0.1,
                confidence=0.8,
            )

        # Base necessity from the static table; unknown pairs default to 0.5.
        necessity = self.TOOL_NECESSITY.get(task_type, {}).get(tool_name, 0.5)

        # Blend in the observed usefulness once there are more than 5 samples.
        stats = self.tool_stats.get(tool_name, {"calls": 0, "useful": 0})
        if stats["calls"] > 5:
            historical_useful_rate = stats["useful"] / stats["calls"]
            necessity = (necessity + historical_useful_rate) / 2

        # Cost escalation check: if we're already over predicted cost, be more
        # selective. The penalties stack, so an overrun above 2.5x scales
        # necessity by 0.7 * 0.5 = 0.35.
        cost_ratio = current_cost_so_far / max(predicted_task_cost, 0.001)
        if cost_ratio > 1.5:
            necessity *= 0.7
        if cost_ratio > 2.5:
            necessity *= 0.5

        # Normalize cost to benefit scale (assume $0.01 = 1.0 benefit unit).
        # The floor keeps a zero-cost tool from dividing by zero.
        normalized_cost = tool_cost / 0.01
        benefit_cost_ratio = necessity / max(normalized_cost, 0.001)

        if benefit_cost_ratio < self.MIN_BENEFIT_COST_RATIO and necessity < 0.5:
            return ToolGateDecision(
                decision=ToolDecision.SKIP,
                tool_name=tool_name,
                reasoning=f"Low benefit/cost ratio ({benefit_cost_ratio:.2f}) and low necessity ({necessity:.2f})",
                estimated_cost=0.0,
                estimated_benefit=necessity,
                confidence=0.75,
            )

        # Read-style tools get a slightly higher confidence bump and a
        # different rationale; the decision itself is still USE.
        if tool_name in ("file_read", "search", "retrieve"):
            return ToolGateDecision(
                decision=ToolDecision.USE,
                tool_name=tool_name,
                reasoning=f"Tool is necessary (necessity={necessity:.2f}) and cacheable — proceed",
                estimated_cost=tool_cost,
                estimated_benefit=necessity,
                confidence=min(necessity + 0.2, 1.0),
            )

        return ToolGateDecision(
            decision=ToolDecision.USE,
            tool_name=tool_name,
            reasoning=f"Tool justified: necessity={necessity:.2f}, cost={tool_cost:.4f}",
            estimated_cost=tool_cost,
            estimated_benefit=necessity,
            confidence=min(necessity + 0.1, 1.0),
        )

    def decide_batch(
        self,
        tool_requests: List[Tuple[str, Dict]],
        task_type: TaskType,
        previous_tool_calls: Optional[List[ToolCall]] = None,
        current_cost_so_far: float = 0.0,
        predicted_task_cost: float = 0.1,
    ) -> List[ToolGateDecision]:
        """Decide on a batch of tool calls, potentially grouping parallel ones.

        Each request is evaluated independently with ``decide``. Consecutive
        read-only requests whose decision is USE are then collapsed into one
        entry: the first decision of the run, with ``batched_with`` set to
        the tool names of the remaining members (an empty list for a run of
        one). All other decisions are returned unchanged, in order.

        NOTE(review): the decisions of the non-first members of a run are
        not returned, so the result may be shorter than ``tool_requests``.

        Args:
            tool_requests: ``(tool_name, tool_input)`` pairs to evaluate.
            task_type: Task category passed through to ``decide``.
            previous_tool_calls: Earlier calls in the current trace, if any.
            current_cost_so_far: Cost already spent; forwarded to ``decide``
                so batched calls respect the same budget pressure.
            predicted_task_cost: Expected total task cost; forwarded to
                ``decide``.

        Returns:
            Decisions in request order, with parallelisable runs collapsed.
        """
        decisions = [
            self.decide(
                tool_name,
                task_type,
                tool_input,
                previous_tool_calls,
                current_cost_so_far=current_cost_so_far,
                predicted_task_cost=predicted_task_cost,
            )
            for tool_name, tool_input in tool_requests
        ]

        # Group independent tool calls for parallel execution
        # (reads can be parallel, writes should be sequential).
        grouped: List[ToolGateDecision] = []
        pending: List[ToolGateDecision] = []
        for decision in decisions:
            if (
                decision.tool_name in self._PARALLEL_SAFE_TOOLS
                and decision.decision == ToolDecision.USE
            ):
                pending.append(decision)
                continue
            # A non-parallelisable decision ends the current run.
            if pending:
                grouped.append(self._collapse_parallel_group(pending))
                pending = []
            grouped.append(decision)

        if pending:
            grouped.append(self._collapse_parallel_group(pending))

        return grouped

    @staticmethod
    def _collapse_parallel_group(group: List[ToolGateDecision]) -> ToolGateDecision:
        """Mark the first decision of a non-empty run as batched with the rest."""
        leader = group[0]
        leader.batched_with = [member.tool_name for member in group[1:]]
        return leader

    def _is_repeated(self, tool_name: str, tool_input: Dict, previous: List[ToolCall]) -> bool:
        """Check if this exact tool call was made in the last five calls of the trace."""
        return any(
            call.tool_name == tool_name and call.tool_input == tool_input
            for call in previous[-5:]
        )

    def _was_ignored(self, tool_name: str, previous: List[ToolCall]) -> bool:
        """Check if previous calls to this tool had their results ignored.

        Returns True only when the tool was called at least twice and more
        than half of those calls are flagged ``ignored_result``.
        """
        relevant = [c for c in previous if c.tool_name == tool_name]
        if len(relevant) < 2:
            return False
        ignored_count = sum(1 for c in relevant if c.ignored_result)
        return ignored_count / len(relevant) > 0.5

    def record_outcome(self, tool_name: str, was_useful: bool, cost: float) -> None:
        """Record whether a tool call was actually useful.

        Updates the per-tool counters in ``self.tool_stats``, creating the
        entry on first use, and refreshes the derived ``useful_rate``.

        Args:
            tool_name: Tool the outcome belongs to.
            was_useful: Whether the call's result was useful.
            cost: Cost of the call, added to the tool's running total.
        """
        stats = self.tool_stats.setdefault(
            tool_name, {"calls": 0, "useful": 0, "total_cost": 0.0}
        )
        stats["calls"] += 1
        if was_useful:
            stats["useful"] += 1
        stats["total_cost"] += cost
        stats["useful_rate"] = stats["useful"] / stats["calls"]