""" GitHub Search Code Tool Searches code across GitHub with glob filtering and line-level results. """ import asyncio import fnmatch import os import re from dataclasses import asdict, dataclass from typing import Any, Dict, List, Optional, Tuple try: import requests except ImportError: raise ImportError( "requests library is required. Install with: pip install requests" ) from agent.tools.types import ToolResult @dataclass class CodeMatch: """A code match with location information.""" repo: str path: str ref: str line_start: int line_end: int snippet: str def to_dict(self): return asdict(self) class GitHubAPIError(Exception): """Raised when GitHub API returns an error.""" pass def _get_github_token() -> str: """Get GitHub token from environment.""" token = os.environ.get("GITHUB_TOKEN") if not token: raise GitHubAPIError( "GITHUB_TOKEN environment variable is required. " "Set it with: export GITHUB_TOKEN=your_token_here" ) return token def _build_github_query( query: str, repo_glob: Optional[str], path_glob: Optional[str], regex: bool ) -> str: """Build GitHub search query string from parameters.""" parts = [] if regex: parts.append(f"/{query}/") else: if " " in query: parts.append(f'"{query}"') else: parts.append(query) if repo_glob: if "/" in repo_glob: parts.append(f"repo:{repo_glob}") else: parts.append(f"user:{repo_glob}") if path_glob: if "*" not in path_glob and "?" not in path_glob: parts.append(f"path:{path_glob}") elif path_glob.startswith("*."): ext = path_glob[2:] parts.append(f"extension:{ext}") elif "/" not in path_glob and "*" in path_glob: parts.append(f"filename:{path_glob}") else: if "." in path_glob: ext_match = re.search(r"\*\.(\w+)", path_glob) if ext_match: parts.append(f"extension:{ext_match.group(1)}") return " ".join(parts) def _fetch_code_search_results( query: str, token: str, max_results: int ) -> List[Dict[str, Any]]: """Fetch code search results from GitHub API.""" headers = { "Accept": "application/vnd.github.text-match+json", "X-GitHub-Api-Version": "2022-11-28", "Authorization": f"Bearer {token}", } all_items = [] page = 1 per_page = min(100, max_results) while len(all_items) < max_results: params = { "q": query, "page": page, "per_page": per_page, } url = "https://api.github.com/search/code" try: response = requests.get(url, headers=headers, params=params, timeout=30) if response.status_code != 200: break data = response.json() items = data.get("items", []) if not items: break all_items.extend(items) if len(all_items) >= data.get("total_count", 0): break page += 1 except requests.exceptions.RequestException: break return all_items[:max_results] def _glob_match(text: str, pattern: str) -> bool: """Check if text matches glob pattern, supporting ** for multi-level paths.""" if "**" in pattern: regex_pattern = pattern.replace("**", "<<>>") regex_pattern = fnmatch.translate(regex_pattern) regex_pattern = regex_pattern.replace("<<>>", ".*") return re.match(regex_pattern, text) is not None else: return fnmatch.fnmatch(text, pattern) def _estimate_line_numbers(fragment: str) -> Tuple[int, int]: """Estimate line numbers from a code fragment.""" lines = fragment.split("\n") line_count = len([line for line in lines if line.strip()]) return 1, line_count def _parse_results_to_matches( raw_results: List[Dict[str, Any]], repo_glob: Optional[str], path_glob: Optional[str], ) -> List[CodeMatch]: """Parse raw GitHub API results into CodeMatch objects.""" matches = [] for item in raw_results: repo_name = item.get("repository", {}).get("full_name", "unknown/unknown") file_path = item.get("path", "") sha = item.get("sha", "unknown") if repo_glob and not _glob_match(repo_name, repo_glob): continue if path_glob and not _glob_match(file_path, path_glob): continue text_matches = item.get("text_matches", []) if text_matches: for text_match in text_matches: fragment = text_match.get("fragment", "") line_start, line_end = _estimate_line_numbers(fragment) match = CodeMatch( repo=repo_name, path=file_path, ref=sha, line_start=line_start, line_end=line_end, snippet=fragment.strip(), ) matches.append(match) else: match = CodeMatch( repo=repo_name, path=file_path, ref=sha, line_start=1, line_end=1, snippet="", ) matches.append(match) return matches def search_code( query: str, repo_glob: Optional[str] = None, path_glob: Optional[str] = None, regex: bool = False, max_results: int = 100, ) -> List[CodeMatch]: """ Search for code across GitHub with glob filtering and line-level results. Returns: repo, path, ref, line_start, line_end, snippet Args: query: Search term or pattern to find in code repo_glob: Glob pattern to filter repositories (e.g., "github/*", "facebook/react") path_glob: Glob pattern to filter file paths (e.g., "*.py", "src/**/*.js") regex: If True, treat query as a regular expression max_results: Maximum number of results to return (default: 100) Returns: List of CodeMatch objects with repo, path, ref, line numbers, and snippet """ github_query = _build_github_query(query, repo_glob, path_glob, regex) token = _get_github_token() raw_results = _fetch_code_search_results(github_query, token, max_results) matches = _parse_results_to_matches(raw_results, repo_glob, path_glob) return matches async def _async_call(func, *args, **kwargs): """Wrap synchronous calls for async context.""" return await asyncio.to_thread(func, *args, **kwargs) def _format_code_matches(matches: List[CodeMatch]) -> str: """Format code matches.""" if not matches: return "No matches found." lines = [] for i, match in enumerate(matches, 1): lines.append(f"**{i}. {match.repo}/{match.path}:{match.line_start}**") lines.append("```") # Show first 5 lines of snippet snippet_lines = match.snippet.split("\n")[:5] lines.extend(snippet_lines) if len(match.snippet.split("\n")) > 5: lines.append("...") lines.append("```") lines.append("") return "\n".join(lines) class SearchCodeTool: """Tool for searching code across GitHub.""" async def execute(self, params: Dict[str, Any]) -> ToolResult: """Execute search_code operation.""" query = params.get("query") if not query: return { "formatted": "Error: 'query' parameter is required", "totalResults": 0, "resultsShared": 0, "isError": True, } repo_glob = params.get("repo_glob") path_glob = params.get("path_glob") regex = params.get("regex", False) max_results = params.get("max_results", 100) try: matches = await _async_call( search_code, query=query, repo_glob=repo_glob, path_glob=path_glob, regex=regex, max_results=max_results, ) if not matches: return { "formatted": "No matches found", "totalResults": 0, "resultsShared": 0, } formatted = _format_code_matches(matches) response = f"**Found {len(matches)} code matches:**\n\n{formatted}" # Add note about viewing full files if matches: response += "\n**To view full file, use:**\n" top_match = matches[0] response += ( f"read_file(repo='{top_match.repo}', path='{top_match.path}')" ) return { "formatted": response, "totalResults": len(matches), "resultsShared": min(len(matches), 10), } except GitHubAPIError as e: return { "formatted": f"GitHub API Error: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } except Exception as e: return { "formatted": f"Error: {str(e)}", "totalResults": 0, "resultsShared": 0, "isError": True, } # Tool specification SEARCH_CODE_TOOL_SPEC = { "name": "search_code", "description": ( "Search code across GitHub with glob filtering and line-level results.\n\n" "Returns: repo, path, ref, line_start, line_end, snippet\n\n" "Examples:\n" "- Search Python functions: {'query': 'def train', 'path_glob': '*.py', 'repo_glob': 'huggingface/*'}\n" "- Search TODO comments: {'query': 'TODO', 'repo_glob': 'github/*', 'max_results': 10}\n" "- Regex search: {'query': r'func Test\\w+', 'path_glob': '*.go', 'regex': True}\n" "- Search in specific repo: {'query': 'HfApi', 'repo_glob': 'huggingface/huggingface_hub', 'path_glob': '*.py'}\n\n" ), "parameters": { "type": "object", "properties": { "query": { "type": "string", "description": "Search term or pattern to find in code", }, "repo_glob": { "type": "string", "description": "Glob pattern to filter repositories (e.g., 'github/*', 'facebook/react')", }, "path_glob": { "type": "string", "description": "Glob pattern to filter file paths (e.g., '*.py', 'src/**/*.js', 'test_*.py')", }, "regex": { "type": "boolean", "description": "Treat query as regular expression (default: false)", }, "max_results": { "type": "integer", "description": "Maximum number of results to return (default: 100)", }, }, "required": ["query"], }, } async def search_code_handler(arguments: Dict[str, Any]) -> tuple[str, bool]: """Handler for agent tool router.""" try: tool = SearchCodeTool() result = await tool.execute(arguments) return result["formatted"], not result.get("isError", False) except Exception as e: return f"Error executing search_code: {str(e)}", False