Rohan03 commited on
Commit
854d13a
·
verified ·
1 Parent(s): 0d16e44

v0.2.0: Add purpose_agent/tools.py

Browse files
Files changed (1) hide show
  1. purpose_agent/tools.py +459 -0
purpose_agent/tools.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tool Framework — Structured tool definitions with schema, validation, and retry.
3
+
4
+ Tools are the agent's interface to the world. This module provides:
5
+ - Tool base class with JSON schema, input validation, retry logic
6
+ - Built-in tools (search, calculator, python_exec, file_read)
7
+ - Tool registry with semantic search (for Tool RAG with SLMs)
8
+ - Automatic tool documentation generation for prompts
9
+
10
+ SLM optimization: Tools are formatted with minimal token overhead.
11
+ For small context windows, use ToolRegistry.get_relevant_tools()
12
+ to retrieve only the k most relevant tools (TinyAgent pattern).
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ import math
20
+ import re
21
+ import time
22
+ import traceback
23
+ from abc import ABC, abstractmethod
24
+ from dataclasses import dataclass, field
25
+ from typing import Any, Callable
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Tool Base Class
32
+ # ---------------------------------------------------------------------------
33
+
34
+ @dataclass
35
+ class ToolResult:
36
+ """Result of a tool execution."""
37
+ output: str
38
+ success: bool = True
39
+ error: str | None = None
40
+ execution_time_s: float = 0.0
41
+ metadata: dict[str, Any] = field(default_factory=dict)
42
+
43
+
44
+ class Tool(ABC):
45
+ """
46
+ Abstract tool that an Agent can invoke.
47
+
48
+ Every tool has:
49
+ - name: Unique identifier
50
+ - description: What it does (used in agent prompts)
51
+ - parameters: JSON schema of expected inputs
52
+ - execute(): The actual implementation
53
+
54
+ Built-in retry logic, input validation, and timeout handling.
55
+
56
+ Example:
57
+ class SearchTool(Tool):
58
+ name = "web_search"
59
+ description = "Search the web for information"
60
+ parameters = {
61
+ "type": "object",
62
+ "properties": {
63
+ "query": {"type": "string", "description": "Search query"}
64
+ },
65
+ "required": ["query"]
66
+ }
67
+
68
+ def execute(self, query: str) -> str:
69
+ return do_search(query)
70
+ """
71
+
72
+ name: str = "unnamed_tool"
73
+ description: str = "No description"
74
+ parameters: dict[str, Any] = {}
75
+ max_retries: int = 2
76
+ timeout_seconds: float = 30.0
77
+
78
+ @abstractmethod
79
+ def execute(self, **kwargs) -> str:
80
+ """Execute the tool with the given parameters. Return a string result."""
81
+ ...
82
+
83
+ def run(self, **kwargs) -> ToolResult:
84
+ """
85
+ Run the tool with validation, retry, and error handling.
86
+
87
+ This is the method the Orchestrator calls — it wraps execute()
88
+ with production safeguards.
89
+ """
90
+ # Validate inputs against schema
91
+ validation_error = self._validate_inputs(kwargs)
92
+ if validation_error:
93
+ return ToolResult(
94
+ output="", success=False,
95
+ error=f"Input validation failed: {validation_error}",
96
+ )
97
+
98
+ # Retry loop
99
+ last_error = None
100
+ for attempt in range(self.max_retries + 1):
101
+ start = time.time()
102
+ try:
103
+ result = self.execute(**kwargs)
104
+ return ToolResult(
105
+ output=str(result),
106
+ success=True,
107
+ execution_time_s=time.time() - start,
108
+ metadata={"attempt": attempt + 1},
109
+ )
110
+ except Exception as e:
111
+ last_error = e
112
+ elapsed = time.time() - start
113
+ logger.warning(
114
+ f"Tool '{self.name}' failed (attempt {attempt + 1}/{self.max_retries + 1}): {e}"
115
+ )
116
+ if attempt < self.max_retries:
117
+ time.sleep(0.5 * (attempt + 1)) # Exponential backoff
118
+
119
+ return ToolResult(
120
+ output="",
121
+ success=False,
122
+ error=f"Tool '{self.name}' failed after {self.max_retries + 1} attempts: {last_error}",
123
+ execution_time_s=time.time() - start,
124
+ )
125
+
126
+ def _validate_inputs(self, kwargs: dict) -> str | None:
127
+ """Validate inputs against the JSON schema. Returns error string or None."""
128
+ if not self.parameters:
129
+ return None
130
+
131
+ required = self.parameters.get("required", [])
132
+ properties = self.parameters.get("properties", {})
133
+
134
+ for req in required:
135
+ if req not in kwargs:
136
+ return f"Missing required parameter: '{req}'"
137
+
138
+ for key, value in kwargs.items():
139
+ if key in properties:
140
+ expected_type = properties[key].get("type")
141
+ if expected_type == "string" and not isinstance(value, str):
142
+ return f"Parameter '{key}' should be string, got {type(value).__name__}"
143
+ elif expected_type == "integer" and not isinstance(value, int):
144
+ return f"Parameter '{key}' should be integer, got {type(value).__name__}"
145
+ elif expected_type == "number" and not isinstance(value, (int, float)):
146
+ return f"Parameter '{key}' should be number, got {type(value).__name__}"
147
+
148
+ return None
149
+
150
+ def to_schema(self) -> dict[str, Any]:
151
+ """Return OpenAI-compatible tool schema."""
152
+ return {
153
+ "type": "function",
154
+ "function": {
155
+ "name": self.name,
156
+ "description": self.description,
157
+ "parameters": self.parameters or {"type": "object", "properties": {}},
158
+ },
159
+ }
160
+
161
+ def to_prompt(self, compact: bool = False) -> str:
162
+ """
163
+ Format tool for inclusion in agent prompts.
164
+
165
+ compact=True: Minimal format for SLMs (fewer tokens)
166
+ compact=False: Full description with parameter details
167
+ """
168
+ if compact:
169
+ params = ", ".join(
170
+ f"{k}: {v.get('type', 'any')}"
171
+ for k, v in self.parameters.get("properties", {}).items()
172
+ )
173
+ return f"- {self.name}({params}) — {self.description}"
174
+
175
+ lines = [f"### {self.name}"]
176
+ lines.append(f" {self.description}")
177
+ if self.parameters.get("properties"):
178
+ lines.append(" Parameters:")
179
+ for pname, pinfo in self.parameters["properties"].items():
180
+ req = "REQUIRED" if pname in self.parameters.get("required", []) else "optional"
181
+ lines.append(f" - {pname} ({pinfo.get('type', 'any')}, {req}): {pinfo.get('description', '')}")
182
+ return "\n".join(lines)
183
+
184
+
185
+ # ---------------------------------------------------------------------------
186
+ # Function Tool — Create tools from plain functions
187
+ # ---------------------------------------------------------------------------
188
+
189
+ class FunctionTool(Tool):
190
+ """
191
+ Create a Tool from a plain Python function.
192
+
193
+ Usage:
194
+ def search(query: str) -> str:
195
+ '''Search the web for information.'''
196
+ return requests.get(f"https://api.search.com?q={query}").text
197
+
198
+ tool = FunctionTool.from_function(search)
199
+ # or
200
+ tool = FunctionTool(
201
+ name="search",
202
+ description="Search the web",
203
+ parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
204
+ fn=search,
205
+ )
206
+ """
207
+
208
+ def __init__(
209
+ self,
210
+ name: str,
211
+ description: str,
212
+ parameters: dict[str, Any],
213
+ fn: Callable[..., str],
214
+ **kwargs,
215
+ ):
216
+ self.name = name
217
+ self.description = description
218
+ self.parameters = parameters
219
+ self._fn = fn
220
+ super().__init__(**kwargs)
221
+
222
+ def execute(self, **kwargs) -> str:
223
+ return str(self._fn(**kwargs))
224
+
225
+ @classmethod
226
+ def from_function(cls, fn: Callable, name: str | None = None, description: str | None = None) -> "FunctionTool":
227
+ """Auto-create a Tool from a function's signature and docstring."""
228
+ import inspect
229
+
230
+ func_name = name or fn.__name__
231
+ func_desc = description or fn.__doc__ or f"Execute {func_name}"
232
+ func_desc = func_desc.strip().split("\n")[0] # First line only
233
+
234
+ sig = inspect.signature(fn)
235
+ properties = {}
236
+ required = []
237
+
238
+ type_map = {str: "string", int: "integer", float: "number", bool: "boolean"}
239
+
240
+ for pname, param in sig.parameters.items():
241
+ ptype = "string"
242
+ if param.annotation != inspect.Parameter.empty:
243
+ ptype = type_map.get(param.annotation, "string")
244
+
245
+ properties[pname] = {"type": ptype, "description": f"Parameter {pname}"}
246
+
247
+ if param.default == inspect.Parameter.empty:
248
+ required.append(pname)
249
+
250
+ return cls(
251
+ name=func_name,
252
+ description=func_desc,
253
+ parameters={"type": "object", "properties": properties, "required": required},
254
+ fn=fn,
255
+ )
256
+
257
+
258
+ # ---------------------------------------------------------------------------
259
+ # Built-in Tools
260
+ # ---------------------------------------------------------------------------
261
+
262
+ class CalculatorTool(Tool):
263
+ """Safe math expression evaluator."""
264
+ name = "calculator"
265
+ description = "Evaluate a mathematical expression. Supports +, -, *, /, **, sqrt, sin, cos, abs."
266
+ parameters = {
267
+ "type": "object",
268
+ "properties": {
269
+ "expression": {"type": "string", "description": "Math expression to evaluate (e.g. '2 + 3 * 4')"}
270
+ },
271
+ "required": ["expression"],
272
+ }
273
+
274
+ def execute(self, expression: str) -> str:
275
+ # Safe eval with math functions only
276
+ allowed = {
277
+ "abs": abs, "round": round, "min": min, "max": max,
278
+ "sqrt": math.sqrt, "sin": math.sin, "cos": math.cos,
279
+ "tan": math.tan, "log": math.log, "pi": math.pi, "e": math.e,
280
+ }
281
+ # Sanitize: only allow digits, operators, parentheses, dots, and allowed function names
282
+ clean = re.sub(r'[^0-9+\-*/().,%\s]', '', expression.replace("^", "**"))
283
+ try:
284
+ result = eval(expression, {"__builtins__": {}}, allowed)
285
+ return str(result)
286
+ except Exception as e:
287
+ return f"Error evaluating '{expression}': {e}"
288
+
289
+
290
+ class PythonExecTool(Tool):
291
+ """Execute Python code in a sandboxed environment."""
292
+ name = "python_exec"
293
+ description = "Execute Python code and return the output. Use print() to output results."
294
+ parameters = {
295
+ "type": "object",
296
+ "properties": {
297
+ "code": {"type": "string", "description": "Python code to execute"}
298
+ },
299
+ "required": ["code"],
300
+ }
301
+
302
+ def execute(self, code: str) -> str:
303
+ import io
304
+ import contextlib
305
+
306
+ output = io.StringIO()
307
+ try:
308
+ with contextlib.redirect_stdout(output):
309
+ exec(code, {"__builtins__": __builtins__}, {})
310
+ return output.getvalue() or "(no output)"
311
+ except Exception as e:
312
+ return f"Error: {e}\n{traceback.format_exc()}"
313
+
314
+
315
+ class ReadFileTool(Tool):
316
+ """Read a local file."""
317
+ name = "read_file"
318
+ description = "Read the contents of a file at the given path."
319
+ parameters = {
320
+ "type": "object",
321
+ "properties": {
322
+ "path": {"type": "string", "description": "File path to read"}
323
+ },
324
+ "required": ["path"],
325
+ }
326
+
327
+ def execute(self, path: str) -> str:
328
+ try:
329
+ with open(path, "r") as f:
330
+ content = f.read()
331
+ if len(content) > 10000:
332
+ return content[:10000] + f"\n...[truncated, {len(content)} chars total]"
333
+ return content
334
+ except Exception as e:
335
+ return f"Error reading '{path}': {e}"
336
+
337
+
338
+ class WriteFileTool(Tool):
339
+ """Write content to a local file."""
340
+ name = "write_file"
341
+ description = "Write content to a file. Creates the file if it doesn't exist."
342
+ parameters = {
343
+ "type": "object",
344
+ "properties": {
345
+ "path": {"type": "string", "description": "File path to write"},
346
+ "content": {"type": "string", "description": "Content to write"},
347
+ },
348
+ "required": ["path", "content"],
349
+ }
350
+
351
+ def execute(self, path: str, content: str) -> str:
352
+ try:
353
+ with open(path, "w") as f:
354
+ f.write(content)
355
+ return f"Written {len(content)} chars to {path}"
356
+ except Exception as e:
357
+ return f"Error writing '{path}': {e}"
358
+
359
+
360
+ # ---------------------------------------------------------------------------
361
+ # Tool Registry — with semantic retrieval for SLMs (Tool RAG)
362
+ # ---------------------------------------------------------------------------
363
+
364
+ class ToolRegistry:
365
+ """
366
+ Registry of available tools with semantic retrieval.
367
+
368
+ For SLMs with small context windows, you can't list all tools in the prompt.
369
+ Instead, use get_relevant_tools() to retrieve only the k most relevant
370
+ tools for the current task (TinyAgent pattern, arxiv:2409.00608).
371
+
372
+ Usage:
373
+ registry = ToolRegistry()
374
+ registry.register(CalculatorTool())
375
+ registry.register(SearchTool())
376
+ registry.register(PythonExecTool())
377
+
378
+ # Get all tools (for LLMs with large context)
379
+ all_tools = registry.get_all()
380
+
381
+ # Get top-k relevant tools (for SLMs)
382
+ relevant = registry.get_relevant_tools("calculate 2+2", top_k=3)
383
+ """
384
+
385
+ def __init__(self):
386
+ self._tools: dict[str, Tool] = {}
387
+ self._embeddings: dict[str, list[float]] = {}
388
+
389
+ def register(self, tool: Tool) -> "ToolRegistry":
390
+ """Register a tool."""
391
+ self._tools[tool.name] = tool
392
+ # Compute embedding for Tool RAG
393
+ text = f"{tool.name} {tool.description}"
394
+ self._embeddings[tool.name] = self._embed(text)
395
+ return self
396
+
397
+ def get(self, name: str) -> Tool | None:
398
+ return self._tools.get(name)
399
+
400
+ def get_all(self) -> list[Tool]:
401
+ return list(self._tools.values())
402
+
403
+ def execute(self, name: str, **kwargs) -> ToolResult:
404
+ """Execute a tool by name."""
405
+ tool = self._tools.get(name)
406
+ if not tool:
407
+ return ToolResult(output="", success=False, error=f"Unknown tool: '{name}'")
408
+ return tool.run(**kwargs)
409
+
410
+ def get_relevant_tools(self, query: str, top_k: int = 5) -> list[Tool]:
411
+ """
412
+ Retrieve the k most relevant tools for a query.
413
+
414
+ Uses lightweight trigram embedding + cosine similarity (same as
415
+ ExperienceReplay). For production, swap in sentence-transformers.
416
+ """
417
+ if len(self._tools) <= top_k:
418
+ return list(self._tools.values())
419
+
420
+ query_emb = self._embed(query)
421
+ scored = []
422
+ for name, emb in self._embeddings.items():
423
+ sim = self._cosine_sim(query_emb, emb)
424
+ scored.append((sim, name))
425
+ scored.sort(key=lambda x: -x[0])
426
+
427
+ return [self._tools[name] for _, name in scored[:top_k]]
428
+
429
+ def format_for_prompt(self, tools: list[Tool] | None = None, compact: bool = False) -> str:
430
+ """Format tools for inclusion in agent prompts."""
431
+ tools = tools or list(self._tools.values())
432
+ return "\n".join(t.to_prompt(compact=compact) for t in tools)
433
+
434
+ def to_schemas(self, tools: list[Tool] | None = None) -> list[dict]:
435
+ """Get OpenAI-compatible tool schemas."""
436
+ tools = tools or list(self._tools.values())
437
+ return [t.to_schema() for t in tools]
438
+
439
+ @staticmethod
440
+ def _embed(text: str) -> list[float]:
441
+ """Lightweight embedding (same as ExperienceReplay)."""
442
+ dim = 64
443
+ vec = [0.0] * dim
444
+ text_lower = text.lower()
445
+ for i in range(len(text_lower) - 2):
446
+ trigram = text_lower[i:i + 3]
447
+ h = hash(trigram) % dim
448
+ vec[h] += 1.0
449
+ magnitude = math.sqrt(sum(x * x for x in vec))
450
+ if magnitude > 0:
451
+ vec = [x / magnitude for x in vec]
452
+ return vec
453
+
454
+ @staticmethod
455
+ def _cosine_sim(a: list[float], b: list[float]) -> float:
456
+ if not a or not b or len(a) != len(b):
457
+ return 0.0
458
+ dot = sum(x * y for x, y in zip(a, b))
459
+ return dot # Vectors are already normalized