"""Tree-sitter-aware parsing with graceful fallback.

If `tree-sitter-languages` isn't installed we degrade to a regex-based
top-level-symbol extractor — good enough for unit tests and for
languages we don't yet have grammars for.
"""
from __future__ import annotations
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, List, Optional
# Lookup table: lower-case file extension (with leading dot) -> name of the
# tree-sitter grammar used for files carrying that extension.
LANG_BY_EXT = {
    # General-purpose languages
    ".py": "python",
    ".pyi": "python",
    ".rs": "rust",
    ".go": "go",
    ".js": "javascript",
    ".jsx": "javascript",
    ".ts": "typescript",
    ".tsx": "tsx",
    ".c": "c",
    ".h": "c",
    ".cc": "cpp",
    ".cpp": "cpp",
    ".cxx": "cpp",
    ".hpp": "cpp",
    ".java": "java",
    ".rb": "ruby",
    ".php": "php",
    ".cs": "c_sharp",
    ".swift": "swift",
    ".kt": "kotlin",
    ".kts": "kotlin",
    # Shell and query languages
    ".sh": "bash",
    ".bash": "bash",
    ".sql": "sql",
    # Markup, styling and data formats
    ".html": "html",
    ".css": "css",
    ".json": "json",
    ".yaml": "yaml",
    ".yml": "yaml",
    ".toml": "toml",
    ".md": "markdown",
    ".markdown": "markdown",
}
@dataclass
class Symbol:
    """One structural element (function, class, heading, ...) found in a text.

    Both line fields are 1-based line numbers within the analysed text.
    """

    # Identifier or heading text of the element.
    name: str
    # Coarse category, e.g. function / class / method / module / heading.
    kind: str
    # Line on which the element begins.
    start_line: int
    # Line on which the element ends.
    end_line: int
def detect_language(path: str | Path) -> Optional[str]:
    """Return the grammar name registered for *path*'s extension.

    Only the final suffix of the path is considered, and it is lower-cased
    before the lookup. ``None`` is returned when the suffix has no entry in
    ``LANG_BY_EXT`` (including paths without any suffix).
    """
    extension = Path(path).suffix
    return LANG_BY_EXT.get(extension.lower())
def extract_symbols(text: str, language: Optional[str]) -> List[Symbol]:
    """Extract structural symbols from *text*.

    Blank or whitespace-only input yields an empty list. Otherwise the
    tree-sitter extractor is tried first; if it raises for any reason the
    regex-based extractor is used instead.
    """
    stripped = text.strip()
    if not stripped:
        return []

    try:
        symbols = _ts_symbols(text, language)
    except Exception:
        # Deliberate best-effort: any tree-sitter failure (missing package,
        # unknown grammar, parse error) degrades to the regex extractor.
        symbols = _regex_symbols(text, language)
    return symbols
# ---- tree-sitter path -----------------------------------------------------
def _ts_symbols(text: str, language: Optional[str]) -> List[Symbol]:
    """Extract symbols from *text* using a tree-sitter grammar.

    The text is parsed as UTF-8 and the syntax tree is visited in document
    (pre-order) sequence. Every node whose type is one of the "interesting"
    definition/heading types becomes a ``Symbol`` with 1-based line numbers.
    When no name can be derived for a node, its node type is used as the name.

    Raises:
        RuntimeError: if *language* is empty/``None`` or the
            ``tree_sitter_languages`` package cannot be imported.
        Any exception raised by ``get_parser`` or the parser itself is not
        caught here.
    """
    if not language:
        raise RuntimeError("no language")
    try:
        from tree_sitter_languages import get_parser  # type: ignore
    except ImportError as e:
        raise RuntimeError("tree_sitter_languages not installed") from e

    parser = get_parser(language)
    tree = parser.parse(text.encode("utf-8"))

    interesting = {
        "function_definition", "function_declaration", "method_definition",
        "class_definition", "class_declaration", "struct_item", "trait_item",
        "impl_item", "enum_item", "type_alias_declaration",
        "atx_heading", "setext_heading",
    }

    out: List[Symbol] = []
    # Explicit stack instead of recursion: syntax trees for long operator
    # chains or deeply nested data can be deeper than Python's recursion
    # limit, which previously aborted the whole tree-sitter pass.
    pending = [tree.root_node]
    while pending:
        node = pending.pop()
        if node.type in interesting:
            out.append(Symbol(
                name=_node_name(node, text) or node.type,
                kind=_kind_for(node.type),
                start_line=node.start_point[0] + 1,  # rows are 0-based
                end_line=node.end_point[0] + 1,
            ))
        # Push children right-to-left so they are popped left-to-right,
        # keeping the same pre-order as a recursive walk.
        pending.extend(reversed(node.children))
    return out
def _node_name(node, text: str) -> Optional[str]:
    """Return a display name for a syntax-tree *node*, or ``None``.

    Children are scanned in order; for each child, the child itself and then
    its own children are checked for an ``identifier`` / ``type_identifier``
    node, and the source text of the first hit is returned. If nothing is
    found and *node* is a Markdown heading, the heading's text with the
    leading ``#`` markers and surrounding whitespace removed is returned.

    Args:
        node: object exposing ``type``, ``children``, ``start_byte`` and
            ``end_byte`` (as tree-sitter nodes do).
        text: the source the tree was parsed from.
    """

    def source_of(n) -> str:
        # start_byte/end_byte are offsets into the UTF-8 encoded source, not
        # character indexes. They only coincide for pure-ASCII text, so any
        # other text must be sliced as bytes to avoid shifted/garbled names.
        if text.isascii():
            return text[n.start_byte:n.end_byte]
        raw = text.encode("utf-8")[n.start_byte:n.end_byte]
        return raw.decode("utf-8", errors="replace")

    name_types = ("identifier", "type_identifier")
    for child in node.children:
        if child.type in name_types:
            return source_of(child)
        for grandchild in child.children:
            if grandchild.type in name_types:
                return source_of(grandchild)

    if node.type in ("atx_heading", "setext_heading"):
        return source_of(node).strip().lstrip("#").strip()
    return None
def _kind_for(node_type: str) -> str:
    """Translate a syntax-tree node type into a coarse symbol kind.

    The node type is tested for marker substrings in a fixed priority order
    (class-like, then method, function, heading); the first group with a
    marker contained in *node_type* decides. Anything else is ``"symbol"``.
    """
    # Order matters: earlier rows take precedence over later ones.
    rules = (
        (("class", "struct", "trait", "impl"), "class"),
        (("method",), "method"),
        (("function",), "function"),
        (("heading",), "heading"),
    )
    for markers, kind in rules:
        if any(marker in node_type for marker in markers):
            return kind
    return "symbol"
# ---- regex fallback -------------------------------------------------------
# Line-anchored patterns for the regex fallback. Each exposes the symbol's
# identifier (or heading text) as the named group "name"; the Python patterns
# additionally expose the leading indentation as "indent".
#
# Leading whitespace is matched with [ \t]* rather than \s*: under
# re.MULTILINE, \s* also consumes newlines, which made a match start on the
# blank line(s) *before* a definition and put "\n" into the indent group.
# The heading pattern likewise requires a space/tab after the #'s on the same
# line, so a bare "#" line followed by text is not read as a heading.
PY_DEF = re.compile(r"^(?P<indent>[ \t]*)(?:async\s+)?def\s+(?P<name>[A-Za-z_][\w]*)\s*\(", re.MULTILINE)
PY_CLASS = re.compile(r"^(?P<indent>[ \t]*)class\s+(?P<name>[A-Za-z_][\w]*)", re.MULTILINE)
RS_FN = re.compile(r"^[ \t]*(?:pub(?:\([^)]*\))?\s+)?fn\s+(?P<name>[A-Za-z_][\w]*)", re.MULTILINE)
GO_FN = re.compile(r"^[ \t]*func\s+(?:\([^)]*\)\s+)?(?P<name>[A-Za-z_][\w]*)", re.MULTILINE)
JS_FN = re.compile(r"^[ \t]*(?:export\s+)?(?:async\s+)?function\s+(?P<name>[A-Za-z_$][\w$]*)", re.MULTILINE)
MD_HEADING = re.compile(r"^(#{1,6})[ \t]+(?P<name>.+)$", re.MULTILINE)
def _regex_symbols(text: str, language: Optional[str]) -> List[Symbol]:
    """Extract symbols from *text* with line-anchored regular expressions.

    Supported languages are python, rust, go, javascript/typescript/tsx and
    markdown; any other value (including ``None``) yields an empty list.
    For Python, all classes are listed first, followed by all ``def``s; a
    ``def`` with leading indentation is reported as a "method", otherwise as
    a "function". Each symbol's ``end_line`` equals its ``start_line`` because
    the patterns only see the line that introduces the symbol.
    """
    out: List[Symbol] = []

    def collect(pattern, kind_of, clean=None) -> None:
        # Matches arrive in increasing position, so newlines are counted
        # incrementally from the previous hit instead of rescanning (and
        # copying) the whole prefix of the text for every match.
        line, cursor = 1, 0
        for m in pattern.finditer(text):
            # Anchor on the name itself: the leading-whitespace part of a
            # match may begin on an earlier (blank) line.
            pos = m.start("name")
            line += text.count("\n", cursor, pos)
            cursor = pos
            name = m.group("name")
            if clean is not None:
                name = clean(name)
            out.append(Symbol(name=name, kind=kind_of(m), start_line=line, end_line=line))

    def py_def_kind(m) -> str:
        # Only whitespace on the def's own line counts as indentation; drop
        # anything captured from preceding blank lines.
        indent = m.group("indent").rpartition("\n")[2]
        return "method" if indent else "function"

    if language == "python":
        collect(PY_CLASS, lambda m: "class")
        collect(PY_DEF, py_def_kind)
    elif language == "rust":
        collect(RS_FN, lambda m: "function")
    elif language == "go":
        collect(GO_FN, lambda m: "function")
    elif language in ("javascript", "typescript", "tsx"):
        collect(JS_FN, lambda m: "function")
    elif language == "markdown":
        collect(MD_HEADING, lambda m: "heading", clean=str.strip)
    return out
|