""" preprocess.py ============= Strip comments, docstrings, and blank lines from Python source. Used to normalize code BEFORE embedding so the cosine-similarity step isn't dominated by surface artifacts ("code with comments" vs "code without comments") instead of the actual logic. Public function: strip(code: str) -> str Run this file directly to execute the unit tests: python preprocess.py """ import ast import io import sys import tokenize # --------------------------------------------------------------------------- # CORE # --------------------------------------------------------------------------- def _remove_comment_tokens(code: str) -> str: """Drop tokens of type COMMENT using Python's own tokenizer. Safe against `#` inside strings because tokenize knows the difference.""" if not code.strip(): return code # Collect (start_pos, end_pos) of every comment token. comment_ranges = [] try: tokens = tokenize.generate_tokens(io.StringIO(code).readline) for tok in tokens: if tok.type == tokenize.COMMENT: comment_ranges.append((tok.start, tok.end)) except (tokenize.TokenError, Exception): # Source has lexer-level issues. Return as-is rather than corrupt it. return code if not comment_ranges: return code # Rebuild line-by-line, deleting the comment slice from each affected line. # tokenize positions are (row, col) with row 1-indexed. lines = code.splitlines(keepends=True) # Group by line so we delete from the rightmost comment first # (deleting left-first would shift columns of subsequent ones). by_line: dict[int, list] = {} for (sr, sc), (er, ec) in comment_ranges: by_line.setdefault(sr, []).append((sc, ec, sr == er)) for row, ranges in by_line.items(): if row - 1 >= len(lines): continue line = lines[row - 1] # Process rightmost first. for sc, ec, single_line in sorted(ranges, key=lambda x: -x[0]): if single_line: # Cut from sc to ec; preserve trailing newline if present. line = line[:sc].rstrip() + ("\n" if line.endswith("\n") else "") else: line = line[:sc].rstrip() + ("\n" if line.endswith("\n") else "") lines[row - 1] = line return "".join(lines) def _remove_docstrings(code: str) -> str: """Walk the AST. For Module/FunctionDef/AsyncFunctionDef/ClassDef nodes, if the first statement is a bare string-literal expression, that's the docstring -- replace it with a `pass` to keep the parent body legal. We use AST mutation + ast.unparse rather than line-deletion because ast.unparse rebuilds source faithfully and handles every edge case (single-line docstrings, raw strings, f-strings used as docstrings, etc.). """ if not code.strip(): return code try: tree = ast.parse(code) except SyntaxError: # Can't parse -> can't safely modify. Return original. return code docstring_node_types = ( ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ) for node in ast.walk(tree): if not isinstance(node, docstring_node_types): continue if not node.body: continue first = node.body[0] # A docstring is an Expr node whose value is a Constant str. if (isinstance(first, ast.Expr) and isinstance(first.value, ast.Constant) and isinstance(first.value.value, str)): if isinstance(node, ast.Module): # Module docstrings: always safe to remove. node.body.pop(0) elif len(node.body) == 1: # Docstring is the ONLY statement — replace with pass # so the function/class body stays syntactically legal. node.body[0] = ast.Pass() else: # Docstring followed by real code — just remove it. # No pass needed; real code keeps the body legal. node.body.pop(0) try: return ast.unparse(tree) except Exception: # ast.unparse failed (very rare). Return original. return code def _remove_blank_lines(code: str) -> str: """Drop lines that are empty or only whitespace.""" return "\n".join( line for line in code.splitlines() if line.strip() ) def strip(code: str) -> str: """Strip comments, docstrings, and blank lines. Order matters: docstrings first (AST-based, needs valid syntax), then comments (token-based), then blank lines (string-based).""" code = _remove_docstrings(code) code = _remove_comment_tokens(code) code = _remove_blank_lines(code) return code # --------------------------------------------------------------------------- # UNIT TESTS # --------------------------------------------------------------------------- def _check(name: str, src: str, must_contain=None, must_not_contain=None, must_be_empty=False, must_parse=True): """Run strip() on src and verify expectations.""" try: result = strip(src) except Exception as e: print(f" [FAIL] {name}: strip() raised {type(e).__name__}: {e}") return False failures = [] if must_be_empty and result.strip(): failures.append(f"expected empty, got: {result!r}") if must_contain: for needle in must_contain: if needle not in result: failures.append(f"missing: {needle!r}") if must_not_contain: for needle in must_not_contain: if needle in result: failures.append(f"should not contain: {needle!r}") if must_parse and result.strip(): try: ast.parse(result) except SyntaxError as e: failures.append(f"output does not parse: {e}") if failures: print(f" [FAIL] {name}") for f in failures: print(f" {f}") print(f" output was:\n " + result.replace("\n", "\n ")) return False print(f" [ OK ] {name}") return True def run_tests(): print("=" * 70) print("UNIT TESTS") print("=" * 70) passed = 0 total = 0 cases = [ # 1. Plain comment ("plain_comment", "# this is a comment\nx = 1\n", ["x = 1"], ["this is a comment"]), # 2. Inline comment ("inline_comment", "x = 1 # inline\ny = 2\n", ["x = 1", "y = 2"], ["inline"]), # 3. # inside a string -- MUST NOT be stripped # (ast.unparse may normalize quote style, so check content only) ("hash_in_string", 'print("# not a comment")\n', ["# not a comment"], None), # 4. # in URL string ("hash_in_url", 'url = "https://example.com#anchor"\nprint(url)\n', ["#anchor"], None), # 5. Module-level docstring ("module_docstring", '"""this is a module docstring"""\nx = 1\n', ["x = 1"], ["module docstring"]), # 6. Function docstring ("function_docstring", 'def f():\n """fn docstring"""\n return 1\n', ["def f", "return 1"], ["fn docstring"]), # 7. Class docstring ("class_docstring", 'class C:\n """class docstring"""\n x = 1\n', ["class C", "x = 1"], ["class docstring"]), # 8. Triple-quoted string assigned to variable -- MUST be kept ("triple_quoted_value", 'x = """real value"""\nprint(x)\n', ["real value"], None), # 9. Blank lines between code ("blank_lines", "x = 1\n\n\ny = 2\n", ["x = 1", "y = 2"], None), # 10. Indented inline comment ("indented_inline", "if True:\n x = 1 # inner comment\n", ["x = 1"], ["inner comment"]), # 11. Mixed: comments + docstring + blank lines ("mixed", '"""module doc"""\n\n# top comment\ndef f():\n """fn doc"""\n x = 1 # inline\n return x\n\n', ["def f", "x = 1", "return x"], ["module doc", "fn doc", "top comment", "inline"]), # 12. f-string with # in format spec ("fstring_format", 'x = 255\nprint(f"{x:#x}")\n', ["#x"], None), # 13. Comment-only file ("comment_only", "# only a comment\n# another\n", None, None, True), # must_be_empty # 14. Empty file ("empty", "", None, None, True), # 15. Whitespace-only file ("whitespace_only", " \n \n\n", None, None, True), ] for case in cases: if len(case) == 4: name, src, must, must_not = case ok = _check(name, src, must_contain=must, must_not_contain=must_not) else: name, src, must, must_not, must_empty = case ok = _check(name, src, must_contain=must, must_not_contain=must_not, must_be_empty=must_empty) passed += int(ok) total += 1 print() print(f"{passed}/{total} passed") return passed == total def run_apps_demo(): """Show before/after on the 5 cached APPS samples.""" import json from pathlib import Path cache = Path("stress_samples.json") if not cache.exists(): print("\n(stress_samples.json not found, skipping APPS demo)") return print() print("=" * 70) print("DEMO ON CACHED APPS SAMPLES (before/after line counts)") print("=" * 70) samples = json.loads(cache.read_text(encoding="utf-8")) for s in samples: before_lines = len(s["code"].splitlines()) stripped = strip(s["code"]) after_lines = len(stripped.splitlines()) # Verify it still parses. try: ast.parse(stripped) parse_ok = "yes" except SyntaxError: parse_ok = "NO -- BROKEN" sid = f"{s['category']}_{s['problem_id']}" print(f" {sid:<22s} before={before_lines:>3d} " f"after={after_lines:>3d} parses={parse_ok}") if __name__ == "__main__": ok = run_tests() run_apps_demo() sys.exit(0 if ok else 1)