Spaces:
Running
Running
| """ | |
| agent/tools.py | |
| ─────────────── | |
| Tool definitions for the reflection agent. | |
| Tools available to the agent: | |
| read_file(path) — read a file from the workspace | |
| write_patch(diff) — write a unified diff to a staging file in the workspace | |
| run_tests(test_ids) — run pytest and return structured output | |
| git_diff() — show the output of `git diff` for the workspace | |
| list_files(pattern) — list files matching a glob | |
| Each tool returns a structured ToolResult with success/error. | |
| The agent's LLM sees ToolResult.to_prompt_str() in its context. | |
| """ | |
| from __future__ import annotations | |
| import logging | |
| import re | |
| import subprocess | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any, Literal | |
| logger = logging.getLogger(__name__) | |
| # ββ Tool result βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
@dataclass
class ToolResult:
    """
    Structured outcome of a single tool invocation.

    The ``@dataclass`` decorator is required: the rest of this module builds
    results positionally (``ToolResult("read_file", False, "", "...")``) and
    ``metadata`` relies on ``field(default_factory=dict)``.

    Attributes:
        tool_name: name of the tool that produced this result
        success: True when the tool completed without error
        output: primary text output of the tool (may be empty)
        error: error description; empty string when there is none
        metadata: extra machine-readable details; never rendered by
            to_prompt_str()
    """

    tool_name: str
    success: bool
    output: str
    error: str = ""
    metadata: dict[str, Any] = field(default_factory=dict)

    def to_prompt_str(self) -> str:
        """
        Format result for inclusion in LLM prompt.

        Returns:
            A header line ``[TOOL: <name> | SUCCESS|ERROR]`` followed by the
            output (first 3000 chars) and, if present, an ``ERROR: ...`` line
            (first 500 chars), joined with newlines.
        """
        status = "SUCCESS" if self.success else "ERROR"
        parts = [f"[TOOL: {self.tool_name} | {status}]"]
        if self.output:
            parts.append(self.output[:3000])  # truncate for token budget
        if self.error:
            parts.append(f"ERROR: {self.error[:500]}")
        return "\n".join(parts)
| # ββ Individual tools ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
class AgentTools:
    """
    Collection of tools available to the reflection agent.

    All file operations are scoped to workspace_dir (sandbox root).
    Every tool reports its outcome as a ToolResult rather than raising,
    except where noted on the individual method.
    """

    def __init__(self, workspace_dir: Path, sandbox=None):
        """
        Args:
            workspace_dir: root directory the tools operate in
            sandbox: optional executor object; when set, run_tests()
                delegates to ``sandbox.run_tests(...)`` instead of spawning
                pytest locally
        """
        self.workspace_dir = Path(workspace_dir)
        self.sandbox = sandbox  # SandboxExecutor instance (optional)

    def read_file(self, path: str, max_lines: int = 200) -> ToolResult:
        """
        Read the contents of a file relative to workspace_dir.

        Args:
            path: relative file path within the workspace
            max_lines: truncate to this many lines (token budget control)

        Returns:
            ToolResult whose output is the (possibly truncated) file text and
            whose metadata carries ``total_lines`` and ``truncated``. Fails
            when the path escapes the workspace, does not exist, or cannot
            be read.
        """
        full_path = self.workspace_dir / path
        # Prevent path traversal: the resolved target must stay under the
        # resolved workspace root, otherwise relative_to() raises ValueError.
        try:
            full_path.resolve().relative_to(self.workspace_dir.resolve())
        except ValueError:
            return ToolResult("read_file", False, "", f"Path traversal rejected: {path}")
        if not full_path.exists():
            return ToolResult("read_file", False, "", f"File not found: {path}")
        try:
            # errors="replace" keeps undecodable bytes from aborting the read.
            content = full_path.read_text(errors="replace")
            lines = content.splitlines()
            truncated = len(lines) > max_lines
            visible = "\n".join(lines[:max_lines])
            if truncated:
                visible += f"\n... [{len(lines) - max_lines} more lines truncated]"
            return ToolResult(
                "read_file", True, visible,
                metadata={"total_lines": len(lines), "truncated": truncated},
            )
        except Exception as e:
            return ToolResult("read_file", False, "", str(e))

    def write_patch(self, diff_text: str) -> ToolResult:
        """
        Write a unified diff to a staging file for git apply.

        Does NOT apply the patch — call the sandbox apply_patch() separately.

        Args:
            diff_text: unified diff text (git format)

        Returns:
            ToolResult with ``patch_path`` in metadata on success; a failure
            result when the text is empty, does not look like a diff, or the
            file cannot be written.
        """
        if not diff_text.strip():
            return ToolResult("write_patch", False, "", "Empty patch text")
        # Basic validation: must start with --- or diff --git
        stripped = diff_text.strip()
        if not stripped.startswith(("---", "diff --git")):
            return ToolResult(
                "write_patch", False, "",
                "Patch must start with '---' or 'diff --git'",
            )
        patch_file = self.workspace_dir / "_agent_patch.diff"
        try:
            # The original (unstripped) text is written so the diff's own
            # leading/trailing whitespace is preserved.
            patch_file.write_text(diff_text)
            return ToolResult(
                "write_patch", True,
                f"Patch written to {patch_file.name} ({len(diff_text)} chars)",
                metadata={"patch_path": str(patch_file)},
            )
        except Exception as e:
            return ToolResult("write_patch", False, "", str(e))

    def run_tests(self, test_ids: list[str], timeout: int = 60) -> ToolResult:
        """
        Run pytest on specific test IDs.

        Returns structured output including PASSED/FAILED counts and
        the first failing test's traceback (for reflection context).

        Args:
            test_ids: pytest node IDs / paths to run; must be non-empty
            timeout: seconds allowed for the local subprocess fallback

        Returns:
            ToolResult whose output is the condensed summary and whose
            metadata holds the first 5000 chars of raw output under
            ``full_output``.
        """
        if not test_ids:
            return ToolResult("run_tests", False, "", "No test IDs provided")
        if self.sandbox:
            # NOTE(review): `timeout` is not forwarded here and exceptions
            # raised by the sandbox propagate to the caller — confirm that
            # is the intended contract of the sandbox executor.
            test_result = self.sandbox.run_tests(self.workspace_dir, test_ids)
            output = test_result.raw_output
            success = test_result.all_passed
        else:
            # Local subprocess fallback
            cmd = ["python", "-m", "pytest", "-v", "--tb=short", "--no-header", "-rN"] + test_ids
            try:
                proc = subprocess.run(
                    cmd, capture_output=True, text=True,
                    timeout=timeout, cwd=str(self.workspace_dir),
                )
                output = proc.stdout + proc.stderr
                success = proc.returncode == 0
            except subprocess.TimeoutExpired:
                return ToolResult("run_tests", False, "", f"Tests timed out after {timeout}s")
            except Exception as e:
                return ToolResult("run_tests", False, "", str(e))
        # Extract key info for the agent
        summary = _extract_test_summary(output)
        return ToolResult(
            "run_tests", success,
            summary,
            metadata={"full_output": output[:5000]},
        )

    def git_diff(self) -> ToolResult:
        """
        Show the workspace's current ``git diff`` (to review what the agent has changed).

        Runs plain ``git diff`` in workspace_dir, i.e. working-tree changes
        that have not been staged.

        Returns:
            ToolResult with the first 3000 chars of the diff, or
            ``(no changes)`` when the diff is empty. A non-zero git exit
            status (for example when the workspace is not a repository) is
            reported as a failure instead of being mistaken for "no changes".
        """
        try:
            result = subprocess.run(
                ["git", "diff"], capture_output=True, text=True,
                cwd=str(self.workspace_dir), timeout=10,
            )
        except Exception as e:
            return ToolResult("git_diff", False, "", str(e))
        if result.returncode != 0:
            # Without this check a failing git prints nothing on stdout and
            # the agent would be told there are no changes.
            message = result.stderr.strip() or f"git diff exited with code {result.returncode}"
            return ToolResult("git_diff", False, "", message)
        diff = result.stdout or "(no changes)"
        return ToolResult("git_diff", True, diff[:3000])

    def list_files(self, pattern: str = "**/*.py", max_results: int = 50) -> ToolResult:
        """
        List files in the workspace matching a glob pattern.

        Paths that have a ``__pycache__`` or ``.git`` component are skipped.
        The check is done on path components relative to the workspace, so
        names that merely contain those strings (``.github``, ``.gitignore``)
        and the workspace's own location do not affect the result.

        Args:
            pattern: glob pattern evaluated relative to workspace_dir
            max_results: maximum number of paths returned

        Returns:
            ToolResult whose output is one relative path per line (or
            ``(no files found)``), with the number of listed paths under
            ``count`` in metadata.
        """
        excluded_parts = {"__pycache__", ".git"}
        try:
            rel_files: list[str] = []
            for found in sorted(self.workspace_dir.glob(pattern)):
                rel = found.relative_to(self.workspace_dir)
                if excluded_parts.isdisjoint(rel.parts):
                    rel_files.append(str(rel))
            rel_files = rel_files[:max_results]
            output = "\n".join(rel_files) or "(no files found)"
            return ToolResult("list_files", True, output,
                              metadata={"count": len(rel_files)})
        except Exception as e:
            return ToolResult("list_files", False, "", str(e))
| # ββ Helpers βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| def _extract_test_summary(pytest_output: str) -> str: | |
| """ | |
| Extract a concise test summary from raw pytest output. | |
| Includes: pass/fail counts + first failure traceback. | |
| """ | |
| lines = pytest_output.splitlines() | |
| summary_lines = [] | |
| in_failure_section = False | |
| failure_lines: list[str] = [] | |
| for line in lines: | |
| # Capture summary line | |
| if re.search(r"\d+ (passed|failed|error)", line): | |
| summary_lines.append(line) | |
| # Capture short failure tracebacks | |
| if line.startswith("FAILED") or "AssertionError" in line or "Error" in line: | |
| failure_lines.append(line) | |
| # Short traceback block | |
| if line.startswith("_ " * 3) or "FAILURES" in line: | |
| in_failure_section = True | |
| if in_failure_section: | |
| failure_lines.append(line) | |
| if len(failure_lines) > 40: # cap failure context | |
| break | |
| parts = summary_lines + ["---"] + failure_lines[:40] if failure_lines else summary_lines | |
| return "\n".join(parts) or pytest_output[:1000] | |