File size: 10,326 Bytes
a937307
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
"""
preprocess.py
=============
Strip comments, docstrings, and blank lines from Python source.

Used to normalize code BEFORE embedding so the cosine-similarity step
isn't dominated by surface artifacts ("code with comments" vs
"code without comments") instead of the actual logic.

Public function:
    strip(code: str) -> str

Run this file directly to execute the unit tests:
    python preprocess.py
"""

import ast
import io
import sys
import tokenize


# ---------------------------------------------------------------------------
# CORE
# ---------------------------------------------------------------------------

def _remove_comment_tokens(code: str) -> str:
    """Drop tokens of type COMMENT using Python's own tokenizer.
    Safe against `#` inside strings because tokenize knows the difference."""
    if not code.strip():
        return code

    # Collect (start_pos, end_pos) of every comment token.
    comment_ranges = []
    try:
        tokens = tokenize.generate_tokens(io.StringIO(code).readline)
        for tok in tokens:
            if tok.type == tokenize.COMMENT:
                comment_ranges.append((tok.start, tok.end))
    except (tokenize.TokenError, Exception):
        # Source has lexer-level issues. Return as-is rather than corrupt it.
        return code

    if not comment_ranges:
        return code

    # Rebuild line-by-line, deleting the comment slice from each affected line.
    # tokenize positions are (row, col) with row 1-indexed.
    lines = code.splitlines(keepends=True)
    # Group by line so we delete from the rightmost comment first
    # (deleting left-first would shift columns of subsequent ones).
    by_line: dict[int, list] = {}
    for (sr, sc), (er, ec) in comment_ranges:
        by_line.setdefault(sr, []).append((sc, ec, sr == er))

    for row, ranges in by_line.items():
        if row - 1 >= len(lines):
            continue
        line = lines[row - 1]
        # Process rightmost first.
        for sc, ec, single_line in sorted(ranges, key=lambda x: -x[0]):
            if single_line:
                # Cut from sc to ec; preserve trailing newline if present.
                line = line[:sc].rstrip() + ("\n" if line.endswith("\n") else "")
            else:
                line = line[:sc].rstrip() + ("\n" if line.endswith("\n") else "")
        lines[row - 1] = line

    return "".join(lines)


def _remove_docstrings(code: str) -> str:
    """Walk the AST. For Module/FunctionDef/AsyncFunctionDef/ClassDef nodes,
    if the first statement is a bare string-literal expression, that's the
    docstring -- replace it with a `pass` to keep the parent body legal.

    We use AST mutation + ast.unparse rather than line-deletion because
    ast.unparse rebuilds source faithfully and handles every edge case
    (single-line docstrings, raw strings, f-strings used as docstrings, etc.).
    """
    if not code.strip():
        return code

    try:
        tree = ast.parse(code)
    except SyntaxError:
        # Can't parse -> can't safely modify. Return original.
        return code

    docstring_node_types = (
        ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef,
    )

    for node in ast.walk(tree):
        if not isinstance(node, docstring_node_types):
            continue
        if not node.body:
            continue
        first = node.body[0]
        # A docstring is an Expr node whose value is a Constant str.
        if (isinstance(first, ast.Expr)
                and isinstance(first.value, ast.Constant)
                and isinstance(first.value.value, str)):
            if isinstance(node, ast.Module):
                # Module docstrings: always safe to remove.
                node.body.pop(0)
            elif len(node.body) == 1:
                # Docstring is the ONLY statement — replace with pass
                # so the function/class body stays syntactically legal.
                node.body[0] = ast.Pass()
            else:
                # Docstring followed by real code — just remove it.
                # No pass needed; real code keeps the body legal.
                node.body.pop(0)

    try:
        return ast.unparse(tree)
    except Exception:
        # ast.unparse failed (very rare). Return original.
        return code


def _remove_blank_lines(code: str) -> str:
    """Drop lines that are empty or only whitespace."""
    return "\n".join(
        line for line in code.splitlines() if line.strip()
    )


def strip(code: str) -> str:
    """Return ``code`` with docstrings, comments, and blank lines removed.

    The stages run in a fixed order:
      1. docstrings  -- AST-based, so it needs the original, parseable source;
      2. comments    -- token-based; this mainly matters when stage 1 had to
                        return the source unchanged, because a successful
                        ``ast.unparse`` already drops comments;
      3. blank lines -- plain string filtering, which also removes lines
                        emptied by the earlier stages.
    """
    pipeline = (
        _remove_docstrings,
        _remove_comment_tokens,
        _remove_blank_lines,
    )
    for stage in pipeline:
        code = stage(code)
    return code


# ---------------------------------------------------------------------------
# UNIT TESTS
# ---------------------------------------------------------------------------

def _check(name: str, src: str, must_contain=None, must_not_contain=None,
           must_be_empty=False, must_parse=True):
    """Run ``strip()`` on ``src`` and report whether it meets expectations.

    Args:
        name: Label printed in the OK/FAIL line.
        src: Source text passed to ``strip()``.
        must_contain: Iterable of substrings that must appear in the output.
        must_not_contain: Iterable of substrings that must NOT appear.
        must_be_empty: If true, the output must be blank (whitespace only).
        must_parse: If true and the output is non-blank, ``ast.parse`` must
            accept it.

    Returns:
        True if every expectation held, otherwise False. Failures and the
        offending output are printed. Nothing is raised.
    """
    try:
        result = strip(src)
    except Exception as e:
        print(f"  [FAIL] {name}: strip() raised {type(e).__name__}: {e}")
        return False

    failures = []

    if must_be_empty and result.strip():
        failures.append(f"expected empty, got: {result!r}")

    failures.extend(f"missing: {needle!r}"
                    for needle in (must_contain or ())
                    if needle not in result)
    failures.extend(f"should not contain: {needle!r}"
                    for needle in (must_not_contain or ())
                    if needle in result)

    if must_parse and result.strip():
        try:
            ast.parse(result)
        except (SyntaxError, ValueError) as e:
            # ValueError: null bytes on Python < 3.12. Record it as a failed
            # expectation instead of aborting the whole test run.
            failures.append(f"output does not parse: {e}")

    if failures:
        print(f"  [FAIL] {name}")
        for msg in failures:
            print(f"         {msg}")
        print(f"         output was:\n         "
              + result.replace("\n", "\n         "))
        return False
    print(f"  [ OK ] {name}")
    return True


def run_tests():
    """Run the built-in ``strip()`` regression cases and print a summary.

    Each case is ``(name, src, must_contain, must_not_contain)``, optionally
    followed by a ``must_be_empty`` flag.

    Returns:
        True when every case passed.
    """
    banner = "=" * 70
    print(banner)
    print("UNIT TESTS")
    print(banner)

    cases = [
        # 1. Plain comment
        ("plain_comment",
         "# this is a comment\nx = 1\n",
         ["x = 1"], ["this is a comment"]),

        # 2. Inline comment
        ("inline_comment",
         "x = 1  # inline\ny = 2\n",
         ["x = 1", "y = 2"], ["inline"]),

        # 3. # inside a string -- MUST NOT be stripped
        # (ast.unparse may normalize quote style, so check content only)
        ("hash_in_string",
         'print("# not a comment")\n',
         ["# not a comment"], None),

        # 4. # in URL string
        ("hash_in_url",
         'url = "https://example.com#anchor"\nprint(url)\n',
         ["#anchor"], None),

        # 5. Module-level docstring
        ("module_docstring",
         '"""this is a module docstring"""\nx = 1\n',
         ["x = 1"], ["module docstring"]),

        # 6. Function docstring
        ("function_docstring",
         'def f():\n    """fn docstring"""\n    return 1\n',
         ["def f", "return 1"], ["fn docstring"]),

        # 7. Class docstring
        ("class_docstring",
         'class C:\n    """class docstring"""\n    x = 1\n',
         ["class C", "x = 1"], ["class docstring"]),

        # 8. Triple-quoted string assigned to variable -- MUST be kept
        ("triple_quoted_value",
         'x = """real value"""\nprint(x)\n',
         ["real value"], None),

        # 9. Blank lines between code
        ("blank_lines",
         "x = 1\n\n\ny = 2\n",
         ["x = 1", "y = 2"], None),

        # 10. Indented inline comment
        ("indented_inline",
         "if True:\n    x = 1  # inner comment\n",
         ["x = 1"], ["inner comment"]),

        # 11. Mixed: comments + docstring + blank lines
        ("mixed",
         '"""module doc"""\n\n# top comment\ndef f():\n    """fn doc"""\n    x = 1  # inline\n    return x\n\n',
         ["def f", "x = 1", "return x"],
         ["module doc", "fn doc", "top comment", "inline"]),

        # 12. f-string with # in format spec
        ("fstring_format",
         'x = 255\nprint(f"{x:#x}")\n',
         ["#x"], None),

        # 13. Comment-only file
        ("comment_only",
         "# only a comment\n# another\n",
         None, None, True),  # must_be_empty

        # 14. Empty file
        ("empty",
         "",
         None, None, True),

        # 15. Whitespace-only file
        ("whitespace_only",
         "   \n  \n\n",
         None, None, True),
    ]

    outcomes = []
    for name, src, must, must_not, *extra in cases:
        # A 5th element, when present, is the must_be_empty flag.
        outcomes.append(_check(
            name, src,
            must_contain=must,
            must_not_contain=must_not,
            must_be_empty=extra[0] if extra else False,
        ))

    print()
    passed, total = sum(outcomes), len(outcomes)
    print(f"{passed}/{total} passed")
    return passed == total


def run_apps_demo():
    """Print before/after line counts for samples cached in stress_samples.json.

    The cache file is looked up relative to the current working directory.
    If it is missing, the demo is skipped with a note. As read below, it is
    expected to hold a JSON list of objects with ``code``, ``category`` and
    ``problem_id`` keys. For each sample, the line counts before and after
    ``strip()`` are printed, along with whether the stripped output still
    parses.
    """
    import json
    from pathlib import Path

    cache = Path("stress_samples.json")
    if not cache.exists():
        print("\n(stress_samples.json not found, skipping APPS demo)")
        return

    print()
    print("=" * 70)
    print("DEMO ON CACHED APPS SAMPLES (before/after line counts)")
    print("=" * 70)

    samples = json.loads(cache.read_text(encoding="utf-8"))
    for s in samples:
        before_lines = len(s["code"].splitlines())
        stripped = strip(s["code"])
        after_lines = len(stripped.splitlines())
        # Verify the stripped output still parses.
        try:
            ast.parse(stripped)
        except (SyntaxError, ValueError):
            # ValueError: null bytes on Python < 3.12 -- report, don't crash.
            parse_ok = "NO -- BROKEN"
        else:
            parse_ok = "yes"
        sid = f"{s['category']}_{s['problem_id']}"
        print(f"  {sid:<22s}  before={before_lines:>3d}  "
              f"after={after_lines:>3d}  parses={parse_ok}")


if __name__ == "__main__":
    # The unit tests alone decide the exit status; the demo is informational.
    all_passed = run_tests()
    run_apps_demo()
    sys.exit(int(not all_passed))