| |
|
|
import ast
import traceback
from collections import deque
from enum import Enum
from typing import Dict, Generator, List, Optional, Set, Tuple

import tree_sitter_python
from tree_sitter import Language, Node, Parser
|
|
|
|
class NodeType(Enum):
    """Tree-sitter Python node type names used by the sanitizer.

    ``IMPORT`` groups both import node kinds, so callers test membership
    (``child.type in NodeType.IMPORT.value``) rather than equality.
    """

    CLASS = "class_definition"
    FUNCTION = "function_definition"
    # An immutable tuple: a list here could be mutated through the shared enum member.
    IMPORT = ("import_statement", "import_from_statement")
    IDENTIFIER = "identifier"
    ATTRIBUTE = "attribute"
    RETURN = "return_statement"
    EXPRESSION = "expression_statement"
    ASSIGNMENT = "assignment"
|
|
|
|
def traverse_tree(node: "Node") -> Generator["Node", None, None]:
    """
    Traverse the tree structure starting from the given node.

    Nodes are yielded in pre-order: a node before its children, and children
    left to right. Only ``node`` and its descendants are visited; the walk
    never steps onto siblings of ``node``.

    :param node: The root node to start the traversal from.
    :return: A generator object that yields nodes in the tree.
    """
    cursor = node.walk()
    # How many levels the cursor currently sits below ``node``.
    depth = 0
    # True once every node under the cursor's current node has been yielded.
    visited_children = False
    while True:
        if not visited_children:
            yield cursor.node
            if cursor.goto_first_child():
                depth += 1
            else:
                visited_children = True
        elif depth == 0:
            # Back at ``node`` with its whole subtree done.
            break
        elif cursor.goto_next_sibling():
            visited_children = False
        elif cursor.goto_parent():
            depth -= 1
        else:
            break
|
|
|
|
def syntax_check(code, verbose=False):
    """
    Return True if ``code`` parses as Python source, False otherwise.

    Besides ``SyntaxError``, ``ast.parse`` can reject input in other ways:
    ``ValueError`` for source containing null bytes (before Python 3.12),
    and ``MemoryError`` / ``RecursionError`` for sources too deeply nested
    to parse. Each of these means "not usable code" here, so none of them
    escapes to the caller.

    :param code: Source text to check.
    :param verbose: If true, print the traceback of a failed parse.
    :return: Whether the source parsed successfully.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, MemoryError, RecursionError):
        if verbose:
            traceback.print_exc()
        return False
    return True
|
|
|
|
def code_extract(text: str) -> str:
    """
    Return the longest contiguous run of lines in ``text`` that parses as Python.

    Length counts only non-blank lines. Among equally long runs, the first one
    found wins (smallest start line, then smallest end line). A single line
    counts as a run, so one-line code surrounded by prose is still found. If
    no run containing at least one non-blank line parses, the first line of
    ``text`` is returned.

    :param text: Free-form text that may contain Python code.
    :return: The selected lines joined with newlines.
    """
    lines = text.split("\n")
    # nonblank_before[k] == number of non-blank lines among lines[:k]
    nonblank_before = [0]
    for line in lines:
        nonblank_before.append(nonblank_before[-1] + (1 if line.strip() else 0))

    longest_line_pair = (0, 0)
    longest_so_far = 0
    for i in range(len(lines)):
        # j starts at i so that single-line spans are candidates too.
        for j in range(i, len(lines)):
            current_length = nonblank_before[j + 1] - nonblank_before[i]
            # Parsing is the expensive step; skip spans that could not become the new best.
            if current_length <= longest_so_far:
                continue
            if syntax_check("\n".join(lines[i : j + 1])):
                longest_so_far = current_length
                longest_line_pair = (i, j)

    start, end = longest_line_pair
    return "\n".join(lines[start : end + 1])
|
|
|
|
def get_definition_name(node: Node) -> Optional[str]:
    """
    Return the text of the first direct ``identifier`` child of ``node``.

    For class and function definitions this is the defined name.
    NOTE(review): for an assignment whose target is not a bare identifier
    (subscript, attribute, tuple), the first identifier child may come from
    the right-hand side, or there may be none. Confirm against the
    tree-sitter-python grammar before relying on it for assignments.

    :param node: A tree-sitter node whose direct children are inspected.
    :return: The identifier text decoded as UTF-8, or ``None`` when no direct
        child is an identifier.
    """
    for child in node.children:
        if child.type == NodeType.IDENTIFIER.value:
            return child.text.decode("utf8")
    return None
|
|
|
|
def has_return_statement(node: Node) -> bool:
    """
    Report whether the subtree rooted at ``node`` contains a return statement.

    Every node yielded by ``traverse_tree`` is checked, ``node`` itself
    included. A ``return`` inside a nested function or class body therefore
    also counts.

    :param node: Root of the subtree to search.
    :return: True as soon as a return-statement node is found, otherwise False.
    """
    wanted = NodeType.RETURN.value
    return any(descendant.type == wanted for descendant in traverse_tree(node))
|
|
|
|
def get_deps(nodes: List[Tuple[Optional[str], Node]]) -> Dict[Optional[str], Set[str]]:
    """
    Map each definition name to the identifiers referenced beneath its node.

    For every ``(name, node)`` pair, all ``identifier`` nodes below ``node``
    are collected. The node itself is not inspected, and identifiers are not
    descended into. This includes the definition's own name when it is an
    identifier child. If a name appears more than once, the last pair wins.

    An explicit stack is used instead of recursion so that deeply nested
    syntax trees cannot exhaust Python's recursion limit.

    :param nodes: ``(name, node)`` pairs, typically the kept top-level definitions.
    :return: A dict from each name to the set of identifier texts found under its node.
    """

    def collect_identifiers(root: Node) -> Set[str]:
        found: Set[str] = set()
        stack = list(root.children)
        while stack:
            current = stack.pop()
            if current.type == NodeType.IDENTIFIER.value:
                found.add(current.text.decode("utf8"))
            else:
                stack.extend(current.children)
        return found

    return {name: collect_identifiers(node) for name, node in nodes}
|
|
|
|
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """
    Return every name reachable from ``entrypoint`` in ``call_graph``.

    This is a breadth-first search. Names that are referenced but have no
    entry of their own in ``call_graph`` are still included in the result;
    they are just not expanded further. ``entrypoint`` is always included,
    even if it is missing from the graph.

    :param entrypoint: Name to start the search from.
    :param call_graph: Mapping from a name to the names it references.
    :return: The set of reachable names, including ``entrypoint``.
    """
    # deque gives O(1) popleft; list.pop(0) made the search quadratic.
    queue = deque([entrypoint])
    visited = {entrypoint}
    while queue:
        current = queue.popleft()
        for neighbour in call_graph.get(current, ()):
            if neighbour not in visited:
                visited.add(neighbour)
                queue.append(neighbour)
    return visited
|
|
|
|
def _assignment_target_name(assignment: Node) -> Optional[str]:
    """Return the assigned name when the assignment's left side is a bare identifier, else None."""
    target = assignment.child_by_field_name("left")
    if target is not None and target.type == NodeType.IDENTIFIER.value:
        return target.text.decode("utf8")
    return None


def _collect_top_level_nodes(
    root_node: Node,
) -> Tuple[List[Node], List[Tuple[Optional[str], Node]]]:
    """Split the module's direct children into import nodes and ``(name, node)`` definitions to keep."""
    import_nodes: List[Node] = []
    definition_nodes: List[Tuple[Optional[str], Node]] = []
    # Names already claimed by a kept class, function or variable; the first definition wins.
    taken_names: Set[str] = set()

    def keep(name: Optional[str], node: Node) -> None:
        # Unnamed targets (subscripts, attributes, tuples) cannot collide with
        # each other, so they are never deduplicated.
        if name is None:
            definition_nodes.append((name, node))
        elif name not in taken_names:
            definition_nodes.append((name, node))
            taken_names.add(name)

    for child in root_node.children:
        if child.type in NodeType.IMPORT.value:
            import_nodes.append(child)
            continue

        # ``@decorator`` wraps the real class/function in a decorated_definition
        # node: classify by the inner definition, but keep the whole node so
        # the decorators survive.
        definition = child
        if child.type == "decorated_definition":
            definition = child.child_by_field_name("definition")
            if definition is None:
                continue

        if definition.type == NodeType.CLASS.value:
            keep(get_definition_name(definition), child)
        elif definition.type == NodeType.FUNCTION.value:
            name = get_definition_name(definition)
            # Functions without any return statement are dropped and do not claim their name.
            if name not in taken_names and has_return_statement(definition):
                keep(name, child)
        elif (
            child.type == NodeType.EXPRESSION.value
            and child.children[0].type == NodeType.ASSIGNMENT.value
        ):
            assignment = child.children[0]
            keep(_assignment_target_name(assignment), assignment)

    return import_nodes, definition_nodes


def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
    """
    Sanitize and extract relevant parts of the given Python code.

    Steps:
    - Narrow the input to its longest syntactically valid run of lines with
      ``code_extract``, then parse it with tree-sitter.
    - From the module's top level, keep import statements, class definitions,
      function definitions that contain a ``return`` statement, and
      assignment statements.
    - Keep decorated classes and functions together with their decorators.
    - For named definitions, keep only the first occurrence of each name.
    - If an entrypoint is provided, keep only definitions whose names are
      reachable from it through identifier references. Imports are always kept.

    :param code: The input Python code as a string.
    :param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis.
    :return: The kept imports followed by the kept definitions, in source
        order, joined with newlines.
    """
    code = code_extract(code)
    code_bytes = bytes(code, "utf8")
    parser = Parser(Language(tree_sitter_python.language()))
    tree = parser.parse(code_bytes)

    import_nodes, definition_nodes = _collect_top_level_nodes(tree.root_node)

    reachable: Optional[Set[str]] = None
    if entrypoint:
        reachable = get_function_dependency(entrypoint, get_deps(definition_nodes))

    chunks = [code_bytes[node.start_byte : node.end_byte] for node in import_nodes]
    for name, node in definition_nodes:
        if reachable is not None and name not in reachable:
            continue
        chunks.append(code_bytes[node.start_byte : node.end_byte])
    # Joining avoids quadratic bytes concatenation and yields no trailing newline.
    return b"\n".join(chunks).decode("utf8")
|
|