"""
preprocess.py
=============
Strip comments, docstrings, and blank lines from Python source.
Used to normalize code BEFORE embedding so the cosine-similarity step
isn't dominated by surface artifacts ("code with comments" vs
"code without comments") instead of the actual logic.
Public function:
    strip(code: str) -> str

Run this file directly to execute the unit tests:
    python preprocess.py
"""
import ast
import io
import sys
import tokenize
# ---------------------------------------------------------------------------
# CORE
# ---------------------------------------------------------------------------
def _remove_comment_tokens(code: str) -> str:
"""Drop tokens of type COMMENT using Python's own tokenizer.
Safe against `#` inside strings because tokenize knows the difference."""
if not code.strip():
return code
# Collect (start_pos, end_pos) of every comment token.
comment_ranges = []
try:
tokens = tokenize.generate_tokens(io.StringIO(code).readline)
for tok in tokens:
if tok.type == tokenize.COMMENT:
comment_ranges.append((tok.start, tok.end))
except (tokenize.TokenError, Exception):
# Source has lexer-level issues. Return as-is rather than corrupt it.
return code
if not comment_ranges:
return code
# Rebuild line-by-line, deleting the comment slice from each affected line.
# tokenize positions are (row, col) with row 1-indexed.
lines = code.splitlines(keepends=True)
# Group by line so we delete from the rightmost comment first
# (deleting left-first would shift columns of subsequent ones).
by_line: dict[int, list] = {}
for (sr, sc), (er, ec) in comment_ranges:
by_line.setdefault(sr, []).append((sc, ec, sr == er))
for row, ranges in by_line.items():
if row - 1 >= len(lines):
continue
line = lines[row - 1]
# Process rightmost first.
for sc, ec, single_line in sorted(ranges, key=lambda x: -x[0]):
if single_line:
# Cut from sc to ec; preserve trailing newline if present.
line = line[:sc].rstrip() + ("\n" if line.endswith("\n") else "")
else:
line = line[:sc].rstrip() + ("\n" if line.endswith("\n") else "")
lines[row - 1] = line
return "".join(lines)
def _remove_docstrings(code: str) -> str:
"""Walk the AST. For Module/FunctionDef/AsyncFunctionDef/ClassDef nodes,
if the first statement is a bare string-literal expression, that's the
docstring -- replace it with a `pass` to keep the parent body legal.
We use AST mutation + ast.unparse rather than line-deletion because
ast.unparse rebuilds source faithfully and handles every edge case
(single-line docstrings, raw strings, f-strings used as docstrings, etc.).
"""
if not code.strip():
return code
try:
tree = ast.parse(code)
except SyntaxError:
# Can't parse -> can't safely modify. Return original.
return code
docstring_node_types = (
ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef,
)
for node in ast.walk(tree):
if not isinstance(node, docstring_node_types):
continue
if not node.body:
continue
first = node.body[0]
# A docstring is an Expr node whose value is a Constant str.
if (isinstance(first, ast.Expr)
and isinstance(first.value, ast.Constant)
and isinstance(first.value.value, str)):
if isinstance(node, ast.Module):
# Module docstrings: always safe to remove.
node.body.pop(0)
elif len(node.body) == 1:
# Docstring is the ONLY statement — replace with pass
# so the function/class body stays syntactically legal.
node.body[0] = ast.Pass()
else:
# Docstring followed by real code — just remove it.
# No pass needed; real code keeps the body legal.
node.body.pop(0)
try:
return ast.unparse(tree)
except Exception:
# ast.unparse failed (very rare). Return original.
return code
def _remove_blank_lines(code: str) -> str:
"""Drop lines that are empty or only whitespace."""
return "\n".join(
line for line in code.splitlines() if line.strip()
)
def strip(code: str) -> str:
    """Return *code* with docstrings, comments and blank lines removed.

    The passes run in a fixed order: docstring removal first (AST-based, so
    it needs syntactically valid input), then comment removal (token-based),
    then blank-line removal (plain string processing).
    """
    pipeline = (
        _remove_docstrings,
        _remove_comment_tokens,
        _remove_blank_lines,
    )
    for stage in pipeline:
        code = stage(code)
    return code
# ---------------------------------------------------------------------------
# UNIT TESTS
# ---------------------------------------------------------------------------
def _check(name: str, src: str, must_contain=None, must_not_contain=None,
           must_be_empty=False, must_parse=True):
    """Run strip() on *src*, compare against expectations, print a verdict.

    Args:
        name: Label printed next to the [ OK ] / [FAIL] marker.
        src: Source text handed to strip().
        must_contain: Optional substrings that have to appear in the output.
        must_not_contain: Optional substrings that must be absent.
        must_be_empty: If true, the output must be empty after .strip().
        must_parse: If true, non-empty output must be accepted by ast.parse.

    Returns:
        True when every expectation holds, False otherwise (including when
        strip() itself raises).
    """
    try:
        result = strip(src)
    except Exception as exc:
        print(f" [FAIL] {name}: strip() raised {type(exc).__name__}: {exc}")
        return False

    has_content = bool(result.strip())
    problems = []

    if must_be_empty and has_content:
        problems.append(f"expected empty, got: {result!r}")

    problems.extend(
        f"missing: {needle!r}"
        for needle in (must_contain or ())
        if needle not in result
    )
    problems.extend(
        f"should not contain: {needle!r}"
        for needle in (must_not_contain or ())
        if needle in result
    )

    # Empty output is skipped here: there is nothing to parse.
    if must_parse and has_content:
        try:
            ast.parse(result)
        except SyntaxError as exc:
            problems.append(f"output does not parse: {exc}")

    if not problems:
        print(f" [ OK ] {name}")
        return True

    print(f" [FAIL] {name}")
    for problem in problems:
        print(f" {problem}")
    print(" output was:\n " + result.replace("\n", "\n "))
    return False
def run_tests():
    """Run the built-in strip() test table and print a pass/fail summary.

    Each table row is (name, source, must_contain, must_not_contain) with an
    optional fifth element, must_be_empty. Rows are checked through _check().

    Returns:
        True if every case passed, False otherwise.
    """
    rule = "=" * 70
    print(rule)
    print("UNIT TESTS")
    print(rule)

    cases = [
        # 1. A whole-line comment.
        ("plain_comment",
         "# this is a comment\nx = 1\n",
         ["x = 1"], ["this is a comment"]),
        # 2. A trailing comment after code.
        ("inline_comment",
         "x = 1 # inline\ny = 2\n",
         ["x = 1", "y = 2"], ["inline"]),
        # 3. '#' inside a string literal has to survive.
        #    (Quote style may be normalized, so only the content is checked.)
        ("hash_in_string",
         'print("# not a comment")\n',
         ["# not a comment"], None),
        # 4. '#' inside a URL string.
        ("hash_in_url",
         'url = "https://example.com#anchor"\nprint(url)\n',
         ["#anchor"], None),
        # 5. Docstring at module level.
        ("module_docstring",
         '"""this is a module docstring"""\nx = 1\n',
         ["x = 1"], ["module docstring"]),
        # 6. Docstring of a function.
        ("function_docstring",
         'def f():\n """fn docstring"""\n return 1\n',
         ["def f", "return 1"], ["fn docstring"]),
        # 7. Docstring of a class.
        ("class_docstring",
         'class C:\n """class docstring"""\n x = 1\n',
         ["class C", "x = 1"], ["class docstring"]),
        # 8. A triple-quoted string that is a real value has to be kept.
        ("triple_quoted_value",
         'x = """real value"""\nprint(x)\n',
         ["real value"], None),
        # 9. Blank lines between statements.
        ("blank_lines",
         "x = 1\n\n\ny = 2\n",
         ["x = 1", "y = 2"], None),
        # 10. Trailing comment inside an indented block.
        ("indented_inline",
         "if True:\n x = 1 # inner comment\n",
         ["x = 1"], ["inner comment"]),
        # 11. Comments, docstrings and blank lines all together.
        ("mixed",
         '"""module doc"""\n\n# top comment\ndef f():\n """fn doc"""\n x = 1 # inline\n return x\n\n',
         ["def f", "x = 1", "return x"],
         ["module doc", "fn doc", "top comment", "inline"]),
        # 12. '#' used in an f-string format spec.
        ("fstring_format",
         'x = 255\nprint(f"{x:#x}")\n',
         ["#x"], None),
        # 13. File made of comments only -> expect empty output.
        ("comment_only",
         "# only a comment\n# another\n",
         None, None, True),
        # 14. Empty file.
        ("empty",
         "",
         None, None, True),
        # 15. File made of whitespace only.
        ("whitespace_only",
         " \n \n\n",
         None, None, True),
    ]

    outcomes = []
    for name, src, must, must_not, *extra in cases:
        # Four-element rows carry no must_be_empty flag; _check's default for
        # that parameter is False.
        expect_empty = extra[0] if extra else False
        verdict = _check(name, src, must_contain=must,
                         must_not_contain=must_not,
                         must_be_empty=expect_empty)
        outcomes.append(bool(verdict))

    passed = sum(outcomes)
    total = len(outcomes)
    print()
    print(f"{passed}/{total} passed")
    return passed == total
def run_apps_demo():
    """Print before/after line counts for the samples in stress_samples.json.

    Reads ``stress_samples.json`` from the current working directory. If the
    file is missing, a notice is printed and the function returns. Otherwise
    each sample's "code" is run through strip(), and one report line is
    printed with the line counts and whether the stripped code still parses.
    """
    import json
    from pathlib import Path

    cache_path = Path("stress_samples.json")
    if not cache_path.exists():
        print("\n(stress_samples.json not found, skipping APPS demo)")
        return

    rule = "=" * 70
    print()
    print(rule)
    print("DEMO ON CACHED APPS SAMPLES (before/after line counts)")
    print(rule)

    for sample in json.loads(cache_path.read_text(encoding="utf-8")):
        n_before = len(sample["code"].splitlines())
        cleaned = strip(sample["code"])
        n_after = len(cleaned.splitlines())

        # Confirm the stripped source is still valid Python.
        try:
            ast.parse(cleaned)
        except SyntaxError:
            verdict = "NO -- BROKEN"
        else:
            verdict = "yes"

        label = f"{sample['category']}_{sample['problem_id']}"
        print(f" {label:<22s} before={n_before:>3d} "
              f"after={n_after:>3d} parses={verdict}")
# Script entry point: run the unit tests, then the optional sample demo, and
# report the unit-test outcome through the process exit status (0 = all passed).
if __name__ == "__main__":
    ok = run_tests()
    run_apps_demo()
    sys.exit(0 if ok else 1)