| """ |
| Tool Framework — Structured tool definitions with schema, validation, and retry. |
| |
| Tools are the agent's interface to the world. This module provides: |
| - Tool base class with JSON schema, input validation, retry logic |
| - Built-in tools (calculator, python_exec, read_file, write_file) |
| - Tool registry with semantic search (for Tool RAG with SLMs) |
| - Automatic tool documentation generation for prompts |
| |
| SLM optimization: Tools are formatted with minimal token overhead. |
| For small context windows, use ToolRegistry.get_relevant_tools() |
| to retrieve only the k most relevant tools (TinyAgent pattern). |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| import math |
| import os |
| import re |
| import sys |
| import time |
| import traceback |
| from abc import ABC, abstractmethod |
| from dataclasses import dataclass, field |
| from typing import Any, Callable |
|
|
# Module-level logger named after this module; Tool.run reports failed
# attempts through it.
logger = logging.getLogger(name=__name__)
|
|
|
|
| |
| |
| |
|
|
@dataclass
class ToolResult:
    """
    Outcome of a single tool invocation.

    ``output`` holds the stringified tool output (empty when the call failed);
    when ``success`` is False, ``error`` explains why.
    """

    output: str  # stringified tool output; "" on failure
    success: bool = field(default=True)  # False when validation or execution failed
    error: str | None = field(default=None)  # human-readable failure reason, if any
    execution_time_s: float = field(default=0.0)  # measured duration in seconds
    metadata: dict[str, Any] = field(default_factory=dict)  # free-form extras (e.g. attempt number)
|
|
|
|
class Tool(ABC):
    """
    Abstract tool that an Agent can invoke.

    Every tool has:
        - name: Unique identifier
        - description: What it does (used in agent prompts)
        - parameters: JSON schema of expected inputs
        - execute(): The actual implementation

    ``run()`` wraps ``execute()`` with input validation against
    ``parameters``, retries with linear back-off, and error capture.
    ``max_retries`` and ``timeout_seconds`` may be overridden per instance
    through the constructor. ``timeout_seconds`` is NOT enforced by this base
    class; it is a budget that subclasses may honour inside ``execute()``.

    Example:
        class SearchTool(Tool):
            name = "web_search"
            description = "Search the web for information"
            parameters = {
                "type": "object",
                "properties": {
                    "query": {"type": "string", "description": "Search query"}
                },
                "required": ["query"]
            }

            def execute(self, query: str) -> str:
                return do_search(query)
    """

    name: str = "unnamed_tool"
    description: str = "No description"
    # Class-level dict shared by subclasses that do not override it; treat as read-only.
    parameters: dict[str, Any] = {}
    max_retries: int = 2
    timeout_seconds: float = 30.0

    # JSON-schema "type" keyword -> predicate on a Python value. bool is a
    # subclass of int in Python, so it is excluded from "integer"/"number".
    _TYPE_CHECKS: dict[str, Callable[[Any], bool]] = {
        "string": lambda v: isinstance(v, str),
        "integer": lambda v: isinstance(v, int) and not isinstance(v, bool),
        "number": lambda v: isinstance(v, (int, float)) and not isinstance(v, bool),
        "boolean": lambda v: isinstance(v, bool),
        "array": lambda v: isinstance(v, (list, tuple)),
        "object": lambda v: isinstance(v, dict),
        "null": lambda v: v is None,
    }

    def __init__(
        self,
        *,
        max_retries: int | None = None,
        timeout_seconds: float | None = None,
        **kwargs,
    ) -> None:
        """
        Optionally override the class-level retry/timeout settings.

        Args:
            max_retries: Extra attempts after the first failure (must be >= 0).
            timeout_seconds: Time budget that subclasses may enforce.
            **kwargs: Passed on to the next initializer in the MRO
                (cooperative multiple inheritance).

        Raises:
            ValueError: If ``max_retries`` is negative.
        """
        if max_retries is not None:
            if max_retries < 0:
                raise ValueError(f"max_retries must be >= 0, got {max_retries}")
            self.max_retries = max_retries
        if timeout_seconds is not None:
            self.timeout_seconds = timeout_seconds
        super().__init__(**kwargs)

    @abstractmethod
    def execute(self, **kwargs) -> str:
        """Execute the tool with the given parameters. Return a string result."""
        ...

    def run(self, **kwargs) -> ToolResult:
        """
        Run the tool with validation, retry, and error handling.

        This is the method the Orchestrator calls — it wraps execute()
        with production safeguards:

        1. Inputs are checked against ``parameters``; a violation returns a
           failed ToolResult without calling execute().
        2. execute() is attempted up to ``max_retries + 1`` times, sleeping
           0.5s, 1.0s, ... between attempts.
        3. Exceptions raised by an attempt are logged and not propagated;
           after the final failure a failed ToolResult carrying the last
           error is returned.

        ``execution_time_s`` is the duration of the successful attempt, or of
        the last failed attempt.
        """
        validation_error = self._validate_inputs(kwargs)
        if validation_error:
            return ToolResult(
                output="",
                success=False,
                error=f"Input validation failed: {validation_error}",
            )

        # Guard against a negative class-level max_retries: always try at least once.
        total_attempts = max(self.max_retries, 0) + 1
        last_error: Exception | None = None
        elapsed = 0.0
        for attempt in range(1, total_attempts + 1):
            start = time.perf_counter()
            try:
                result = self.execute(**kwargs)
                return ToolResult(
                    output=str(result),
                    success=True,
                    execution_time_s=time.perf_counter() - start,
                    metadata={"attempt": attempt},
                )
            except Exception as e:  # tool boundary: convert any failure into a ToolResult
                last_error = e
                elapsed = time.perf_counter() - start
                logger.warning(
                    "Tool '%s' failed (attempt %d/%d): %s",
                    self.name, attempt, total_attempts, e,
                )
                if attempt < total_attempts:
                    time.sleep(0.5 * attempt)  # linear back-off

        return ToolResult(
            output="",
            success=False,
            error=f"Tool '{self.name}' failed after {total_attempts} attempts: {last_error}",
            execution_time_s=elapsed,
        )

    @classmethod
    def _matches_type(cls, value: Any, expected: Any) -> bool:
        """
        Return True if *value* satisfies the schema ``type`` keyword *expected*.

        *expected* may be a single type name or a list of names (a union).
        Missing or unrecognised type names impose no constraint.
        """
        candidates = expected if isinstance(expected, list) else [expected]
        if not candidates:
            return True
        for type_name in candidates:
            check = cls._TYPE_CHECKS.get(type_name) if isinstance(type_name, str) else None
            if check is None or check(value):
                return True
        return False

    def _validate_inputs(self, kwargs: dict) -> str | None:
        """
        Validate inputs against the JSON schema in ``parameters``.

        Checks required keys first, then rejects unknown keys and values
        whose Python type does not match the declared JSON type.

        Returns:
            An error message, or None if the inputs are valid (or no schema
            is declared).
        """
        if not self.parameters:
            return None

        required = self.parameters.get("required", [])
        properties = self.parameters.get("properties", {})

        for req in required:
            if req not in kwargs:
                return f"Missing required parameter: '{req}'"

        for key, value in kwargs.items():
            if key not in properties:
                return f"Unknown parameter '{key}'. Allowed parameters are: {list(properties.keys())}"
            expected_type = properties[key].get("type")
            if not self._matches_type(value, expected_type):
                return f"Parameter '{key}' should be {expected_type}, got {type(value).__name__}"

        return None

    def to_schema(self) -> dict[str, Any]:
        """Return OpenAI-compatible tool schema."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters or {"type": "object", "properties": {}},
            },
        }

    def to_prompt(self, compact: bool = False) -> str:
        """
        Format tool for inclusion in agent prompts.

        compact=True: Minimal one-line format for SLMs (fewer tokens)
        compact=False: Full description with parameter details
        """
        properties = self.parameters.get("properties", {})
        if compact:
            params = ", ".join(f"{k}: {v.get('type', 'any')}" for k, v in properties.items())
            return f"- {self.name}({params}) — {self.description}"

        required = self.parameters.get("required", [])
        lines = [f"### {self.name}", f" {self.description}"]
        if properties:
            lines.append(" Parameters:")
            for pname, pinfo in properties.items():
                req = "REQUIRED" if pname in required else "optional"
                lines.append(f" - {pname} ({pinfo.get('type', 'any')}, {req}): {pinfo.get('description', '')}")
        return "\n".join(lines)
|
|
|
|
| |
| |
| |
|
|
class FunctionTool(Tool):
    """
    Create a Tool from a plain Python function.

    Usage:
        def search(query: str) -> str:
            '''Search the web for information.'''
            return requests.get(f"https://api.search.com?q={query}").text

        tool = FunctionTool.from_function(search)
        # or
        tool = FunctionTool(
            name="search",
            description="Search the web",
            parameters={"type": "object", "properties": {"query": {"type": "string"}}, "required": ["query"]},
            fn=search,
        )
    """

    def __init__(
        self,
        name: str,
        description: str,
        parameters: dict[str, Any],
        fn: Callable[..., str],
        **kwargs,
    ):
        """
        Args:
            name: Tool identifier shown to the agent.
            description: Summary used in prompts.
            parameters: JSON schema describing ``fn``'s keyword arguments.
            fn: Callable invoked by ``execute()``; its return value is str()-ed.
            **kwargs: Forwarded unchanged to the base-class initializer.
        """
        self.name = name
        self.description = description
        self.parameters = parameters
        self._fn = fn
        super().__init__(**kwargs)

    def execute(self, **kwargs) -> str:
        """Call the wrapped function with the given keyword arguments and stringify the result."""
        return str(self._fn(**kwargs))

    @staticmethod
    def _json_type(annotation: Any) -> str:
        """Map a Python type annotation to a JSON-schema type name ("string" when unknown)."""
        import typing

        if isinstance(annotation, str):
            # Unresolved string annotation such as "int" or "list[str]"
            # (e.g. under `from __future__ import annotations`): map by base name.
            base = annotation.split("[", 1)[0].strip()
            by_name = {
                "str": "string", "int": "integer", "float": "number",
                "bool": "boolean", "list": "array", "dict": "object",
            }
            return by_name.get(base, "string")

        # Parameterised generics such as list[int] map via their origin class.
        origin = typing.get_origin(annotation) or annotation
        type_map = {
            str: "string", int: "integer", float: "number",
            bool: "boolean", list: "array", dict: "object",
        }
        try:
            return type_map.get(origin, "string")
        except TypeError:  # unhashable annotation object
            return "string"

    @classmethod
    def from_function(cls, fn: Callable, name: str | None = None, description: str | None = None) -> "FunctionTool":
        """
        Auto-create a Tool from a function's signature and docstring.

        - ``name`` defaults to ``fn.__name__``.
        - ``description`` defaults to the first line of the docstring, then
          to ``"Execute <name>"``; only the first line is kept either way.
        - Each named parameter becomes a schema property; parameters without a
          default are required. ``*args``/``**kwargs`` are skipped because they
          have no fixed name a schema could describe.
        - str/int/float/bool/list/dict annotations (including string forms)
          map to the matching JSON type; anything else becomes "string".
        """
        import inspect
        import typing

        func_name = name or fn.__name__
        raw_desc = description or fn.__doc__ or ""
        func_desc = raw_desc.strip().split("\n")[0].strip() or f"Execute {func_name}"

        try:
            # Best effort: resolves string annotations; may fail on unresolvable names.
            hints = typing.get_type_hints(fn)
        except Exception:
            hints = {}

        properties: dict[str, Any] = {}
        required: list[str] = []
        for pname, param in inspect.signature(fn).parameters.items():
            if param.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
                continue

            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                ptype = "string"
            else:
                if isinstance(annotation, str):
                    annotation = hints.get(pname, annotation)
                ptype = cls._json_type(annotation)

            properties[pname] = {"type": ptype, "description": f"Parameter {pname}"}
            if param.default is inspect.Parameter.empty:
                required.append(pname)

        return cls(
            name=func_name,
            description=func_desc,
            parameters={"type": "object", "properties": properties, "required": required},
            fn=fn,
        )
|
|
|
|
| |
| |
| |
|
|
class CalculatorTool(Tool):
    """
    Safe math expression evaluator built on an AST whitelist.

    The expression is parsed with ``ast`` and evaluated node by node. Only
    these are permitted:
        - numeric literals
        - the binary operators + - * / // % ** (``^`` is accepted as ``**``)
        - unary + and -
        - the constants ``pi`` and ``e``
        - calls to a fixed set of math functions

    Nothing is handed to ``eval()``. Attribute access, subscripts, strings,
    containers, comparisons and bitwise operators are rejected. Integer powers
    whose result would exceed ``_MAX_POW_BITS`` bits are refused, and
    ``round()``'s ndigits is bounded, so inputs such as ``9**9**9`` cannot
    stall the process.
    """
    name = "calculator"
    description = "Evaluate a mathematical expression. Supports +, -, *, /, **, sqrt, sin, cos, abs."
    parameters = {
        "type": "object",
        "properties": {
            "expression": {"type": "string", "description": "Math expression to evaluate (e.g. '2 + 3 * 4')"}
        },
        "required": ["expression"],
    }

    # Legacy allow-list pattern; kept for backward compatibility, not used by execute().
    _SAFE_PATTERN = re.compile(
        r'^[\d\s+\-*/().,%e]+$|'
        r'(abs|round|min|max|sqrt|sin|cos|tan|log|pi)\b'
    )

    # Known identifiers stripped by the pre-screen before looking for stray letters.
    _NAMES_RE = re.compile(r'(abs|round|min|max|sqrt|sin|cos|tan|log|pi|e)\b')
    # Numeric literals, including exponent forms such as 1e5 or 2.5E-3. They are
    # removed before the letter check so their 'e'/'E' is not mistaken for a name.
    _NUMBER_RE = re.compile(r'(?<![\w.])(?:\d+\.?\d*|\.\d+)(?:[eE][+-]?\d+)?')
    # Upper bound on the estimated bit length of an integer power result.
    _MAX_POW_BITS = 100_000
    # Largest |ndigits| accepted by round(); rounding an int with a huge negative
    # ndigits would otherwise compute an enormous power of ten.
    _MAX_ROUND_DIGITS = 1000

    @classmethod
    def _check_power(cls, base: Any, exponent: Any) -> None:
        """Raise ValueError if int ``base ** exponent`` would be too large to compute safely."""
        if (
            isinstance(base, int)
            and isinstance(exponent, int)
            and exponent > 0
            and abs(base) > 1
            and base.bit_length() * exponent > cls._MAX_POW_BITS
        ):
            raise ValueError(f"exponent too large (result would exceed {cls._MAX_POW_BITS} bits)")

    def execute(self, expression: str) -> str:
        """
        Evaluate *expression* and return the result as a string.

        Errors (disallowed names, unsupported syntax, math errors, oversized
        powers) are returned as an ``"Error..."`` string rather than raised.
        """
        import ast
        import operator

        cleaned = expression.replace("^", "**").strip()

        # Cheap pre-screen: once numbers and known names are removed, any ASCII
        # letter/underscore left over is an unsupported identifier. The AST
        # whitelist below is the authoritative check (it also covers non-ASCII
        # identifiers, which the parser NFKC-normalises).
        screened = self._NAMES_RE.sub("", self._NUMBER_RE.sub(" ", cleaned))
        if re.search(r'[a-zA-Z_]', screened):
            return f"Error: expression contains disallowed characters: '{expression}'"

        max_round_digits = self._MAX_ROUND_DIGITS

        def bounded_round(number, ndigits=None):
            if ndigits is not None and abs(ndigits) > max_round_digits:
                raise ValueError(f"round() ndigits must be within +/-{max_round_digits}")
            return round(number, ndigits)

        functions = {
            "abs": abs, "round": bounded_round, "min": min, "max": max,
            "sqrt": math.sqrt, "sin": math.sin, "cos": math.cos,
            "tan": math.tan, "log": math.log,
        }
        constants = {"pi": math.pi, "e": math.e}
        binary_ops = {
            ast.Add: operator.add,
            ast.Sub: operator.sub,
            ast.Mult: operator.mul,
            ast.Div: operator.truediv,
            ast.FloorDiv: operator.floordiv,
            ast.Mod: operator.mod,
            ast.Pow: operator.pow,
        }
        unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

        def evaluate(node):
            """Recursively evaluate a whitelisted AST node; raise ValueError otherwise."""
            if isinstance(node, ast.Constant):
                # type() rather than isinstance() so bool literals are rejected.
                if type(node.value) in (int, float):
                    return node.value
                raise ValueError(f"unsupported literal {node.value!r}")
            if isinstance(node, ast.Name):
                if node.id in constants:
                    return constants[node.id]
                raise ValueError(f"unknown name '{node.id}'")
            if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
                return unary_ops[type(node.op)](evaluate(node.operand))
            if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
                left = evaluate(node.left)
                right = evaluate(node.right)
                if isinstance(node.op, ast.Pow):
                    self._check_power(left, right)
                return binary_ops[type(node.op)](left, right)
            if (
                isinstance(node, ast.Call)
                and isinstance(node.func, ast.Name)
                and node.func.id in functions
                and not node.keywords
            ):
                return functions[node.func.id](*(evaluate(arg) for arg in node.args))
            raise ValueError(f"unsupported syntax: {type(node).__name__}")

        try:
            tree = ast.parse(cleaned, filename="<calc>", mode="eval")
            for node in ast.walk(tree):
                if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id not in functions:
                    return f"Error: function '{node.func.id}' not allowed"
            return str(evaluate(tree.body))
        except Exception as e:
            return f"Error evaluating '{expression}': {e}"
|
|
|
|
class PythonExecTool(Tool):
    """
    Execute Python code in a subprocess with a timeout and a throwaway
    working directory.

    The snippet is written as UTF-8 to ``script.py`` inside a fresh temporary
    directory (deleted afterwards). It is run with ``sys.executable``, using
    that directory as both the working directory and ``HOME``. stdout/stderr
    are decoded as UTF-8, with undecodable bytes replaced.

    NOTE: this is not a security sandbox. The child runs as the same OS user
    and inherits this process's environment variables (``os.environ``), so
    untrusted code can read anything the parent process can.
    """
    name = "python_exec"
    description = "Execute Python code and return the output. Use print() to output results."
    parameters = {
        "type": "object",
        "properties": {
            "code": {"type": "string", "description": "Python code to execute"}
        },
        "required": ["code"],
    }
    timeout_seconds: float = 10.0

    def execute(self, code: str) -> str:
        """
        Run *code* and return its stdout, followed by any stderr and a
        non-zero exit code. Timeouts and launch failures are returned as
        ``"Error: ..."`` strings.
        """
        import subprocess
        import tempfile

        with tempfile.TemporaryDirectory(prefix="pa_exec_") as tmpdir:
            script = os.path.join(tmpdir, "script.py")
            # Python decodes source files as UTF-8 by default (PEP 3120), so write
            # them as UTF-8 regardless of the platform's locale encoding.
            with open(script, "w", encoding="utf-8") as f:
                f.write(code)
            try:
                result = subprocess.run(
                    [sys.executable, script],
                    capture_output=True,
                    text=True,
                    encoding="utf-8",
                    errors="replace",
                    timeout=self.timeout_seconds,
                    cwd=tmpdir,
                    # Force the child's stdio to UTF-8 so it matches how we decode it.
                    env={**os.environ, "HOME": tmpdir, "PYTHONIOENCODING": "utf-8"},
                )
                output = result.stdout
                if result.stderr:
                    output += f"\nSTDERR:\n{result.stderr}"
                if result.returncode != 0:
                    output += f"\n(exit code {result.returncode})"
                return output or "(no output)"
            except subprocess.TimeoutExpired:
                return f"Error: execution timed out after {self.timeout_seconds}s"
            except Exception as e:
                return f"Error: {e}"
|
|
|
|
class ReadFileTool(Tool):
    """
    Read a local UTF-8 text file, sandboxed to an allowed root directory.

    Relative paths are resolved against ``sandbox_root``. Symlinks are
    resolved before the containment check, so a link inside the sandbox that
    points outside it is rejected. Content longer than ``max_chars``
    characters is truncated, with the total character count appended.
    """
    name = "read_file"
    description = "Read the contents of a file at the given path."
    parameters = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "File path to read"}
        },
        "required": ["path"],
    }
    # Maximum number of characters returned before the content is truncated.
    max_chars: int = 10000

    def __init__(self, sandbox_root: str = ".", **kwargs):
        self.sandbox_root = os.path.abspath(sandbox_root)
        super().__init__(**kwargs)

    def _resolve_in_sandbox(self, path: str) -> str | None:
        """Return the canonical absolute path for *path* if it lies inside the sandbox, else None."""
        root = os.path.realpath(self.sandbox_root)
        # join() keeps absolute *path* as-is and anchors relative ones at the root.
        target = os.path.realpath(os.path.join(self.sandbox_root, path))
        try:
            common = os.path.commonpath([os.path.normcase(root), os.path.normcase(target)])
        except ValueError:
            # commonpath raises ValueError for paths on different drives (Windows).
            return None
        # A component-wise comparison, unlike str.startswith, rejects sibling
        # directories such as "<root>_evil".
        return target if common == os.path.normcase(root) else None

    def execute(self, path: str) -> str:
        """Return the file's text (possibly truncated) or an ``"Error..."`` string."""
        try:
            resolved = self._resolve_in_sandbox(path)
        except (OSError, ValueError) as e:  # e.g. embedded NUL byte in the path
            return f"Error reading '{path}': {e}"
        if resolved is None:
            return f"Error: path '{path}' is outside sandbox root '{self.sandbox_root}'"

        limit = self.max_chars
        try:
            with open(resolved, "r", encoding="utf-8") as f:
                head = f.read(limit)
                # Count the remainder in chunks instead of holding the whole file in memory.
                remaining = sum(len(chunk) for chunk in iter(lambda: f.read(65536), ""))
        except Exception as e:
            return f"Error reading '{path}': {e}"

        if remaining == 0:
            return head
        return head + f"\n...[truncated, {limit + remaining} chars total]"
|
|
|
|
class WriteFileTool(Tool):
    """
    Write UTF-8 text to a local file, sandboxed to an allowed root directory.

    Relative paths are resolved against ``sandbox_root``. Symlinks are
    resolved before the containment check, so a link inside the sandbox that
    points outside it is rejected. Missing parent directories are created,
    and an existing file is overwritten.
    """
    name = "write_file"
    description = "Write content to a file. Creates the file if it doesn't exist."
    parameters = {
        "type": "object",
        "properties": {
            "path": {"type": "string", "description": "File path to write"},
            "content": {"type": "string", "description": "Content to write"},
        },
        "required": ["path", "content"],
    }

    def __init__(self, sandbox_root: str = ".", **kwargs):
        self.sandbox_root = os.path.abspath(sandbox_root)
        super().__init__(**kwargs)

    def _resolve_in_sandbox(self, path: str) -> str | None:
        """Return the canonical absolute path for *path* if it lies inside the sandbox, else None."""
        root = os.path.realpath(self.sandbox_root)
        # join() keeps absolute *path* as-is and anchors relative ones at the root.
        target = os.path.realpath(os.path.join(self.sandbox_root, path))
        try:
            common = os.path.commonpath([os.path.normcase(root), os.path.normcase(target)])
        except ValueError:
            # commonpath raises ValueError for paths on different drives (Windows).
            return None
        # A component-wise comparison, unlike str.startswith, rejects sibling
        # directories such as "<root>_evil".
        return target if common == os.path.normcase(root) else None

    def execute(self, path: str, content: str) -> str:
        """Write *content* to *path*; return a confirmation or an ``"Error..."`` string."""
        try:
            resolved = self._resolve_in_sandbox(path)
        except (OSError, ValueError) as e:  # e.g. embedded NUL byte in the path
            return f"Error writing '{path}': {e}"
        if resolved is None:
            return f"Error: path '{path}' is outside sandbox root '{self.sandbox_root}'"

        try:
            os.makedirs(os.path.dirname(resolved), exist_ok=True)
            with open(resolved, "w", encoding="utf-8") as f:
                f.write(content)
            return f"Written {len(content)} chars to {path}"
        except Exception as e:
            return f"Error writing '{path}': {e}"
|
|
|
|
| |
| |
| |
|
|
class ToolRegistry:
    """
    Registry of available tools with lightweight similarity-based retrieval.

    For SLMs with small context windows, you can't list all tools in the prompt.
    Instead, use get_relevant_tools() to retrieve only the k most relevant
    tools for the current task (TinyAgent pattern, arxiv:2409.00608).

    Usage:
        registry = ToolRegistry()
        registry.register(CalculatorTool())
        registry.register(SearchTool())
        registry.register(PythonExecTool())

        # Get all tools (for LLMs with large context)
        all_tools = registry.get_all()

        # Get top-k relevant tools (for SLMs)
        relevant = registry.get_relevant_tools("calculate 2+2", top_k=3)
    """

    def __init__(self):
        self._tools: dict[str, Tool] = {}
        # Per-tool embedding of "<name> <description>", computed at registration.
        self._embeddings: dict[str, list[float]] = {}

    def register(self, tool: Tool) -> "ToolRegistry":
        """Register (or replace) *tool* under ``tool.name``; return self for chaining."""
        self._tools[tool.name] = tool
        self._embeddings[tool.name] = self._embed(f"{tool.name} {tool.description}")
        return self

    def get(self, name: str) -> Tool | None:
        """Return the tool registered as *name*, or None."""
        return self._tools.get(name)

    def get_all(self) -> list[Tool]:
        """Return all registered tools in registration order."""
        return list(self._tools.values())

    def execute(self, name: str, **kwargs) -> ToolResult:
        """Execute a tool by name; an unknown name yields a failed ToolResult."""
        tool = self._tools.get(name)
        if tool is None:
            return ToolResult(output="", success=False, error=f"Unknown tool: '{name}'")
        return tool.run(**kwargs)

    def get_relevant_tools(self, query: str, top_k: int = 5) -> list[Tool]:
        """
        Retrieve the k most relevant tools for a query.

        Uses a hashed character-trigram embedding and cosine similarity; ties
        keep registration order. If ``top_k`` covers every tool, all tools are
        returned unranked. A non-positive ``top_k`` returns an empty list.
        For production, swap in sentence-transformers.
        """
        if top_k <= 0:
            return []
        if len(self._tools) <= top_k:
            return list(self._tools.values())

        query_emb = self._embed(query)
        # sorted() is stable, so equally-similar tools stay in registration order.
        ranked = sorted(
            self._embeddings.items(),
            key=lambda item: -self._cosine_sim(query_emb, item[1]),
        )
        return [self._tools[name] for name, _ in ranked[:top_k]]

    def format_for_prompt(self, tools: list[Tool] | None = None, compact: bool = False) -> str:
        """Format *tools* (all registered tools when None) for inclusion in agent prompts."""
        if tools is None:
            tools = list(self._tools.values())
        return "\n".join(t.to_prompt(compact=compact) for t in tools)

    def to_schemas(self, tools: list[Tool] | None = None) -> list[dict]:
        """Get OpenAI-compatible schemas for *tools* (all registered tools when None)."""
        if tools is None:
            tools = list(self._tools.values())
        return [t.to_schema() for t in tools]

    @staticmethod
    def _embed(text: str) -> list[float]:
        """
        Embed *text* as an L2-normalised 64-bucket histogram of lowercase
        character trigrams. Text shorter than 3 characters yields a zero vector.

        CRC-32 is used for bucketing because built-in ``hash()`` of str is
        salted per process, which would make rankings differ between runs.
        """
        import zlib

        dim = 64
        vec = [0.0] * dim
        lowered = text.lower()
        for i in range(len(lowered) - 2):
            trigram = lowered[i:i + 3].encode("utf-8", "surrogatepass")
            vec[zlib.crc32(trigram) % dim] += 1.0
        magnitude = math.sqrt(sum(x * x for x in vec))
        if magnitude > 0:
            vec = [x / magnitude for x in vec]
        return vec

    @staticmethod
    def _cosine_sim(a: list[float], b: list[float]) -> float:
        """Cosine similarity of two equal-length vectors; 0.0 for empty, mismatched or zero vectors."""
        if not a or not b or len(a) != len(b):
            return 0.0
        dot = sum(x * y for x, y in zip(a, b))
        norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
        return dot / norm if norm > 0 else 0.0
|
|