File size: 3,205 Bytes
a15535e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Quick local sanity check for the heuristic repair fallback.

Run with::

    python demo-space/test_heuristic.py

Each case must produce a non-empty fix description and a script that
differs from the input.
"""
from __future__ import annotations

import sys
from collections.abc import Callable
from pathlib import Path

# Repository root: this script lives in <repo>/demo-space/.
REPO = Path(__file__).resolve().parent.parent
# Prepend demo-space/ and then the repo root to the import path (same final
# order as inserting the root first and demo-space second), so the
# `from app import ...` below resolves independently of the working directory.
sys.path[:0] = [str(REPO / "demo-space"), str(REPO)]

from app import _heuristic_repair  # noqa: E402

# Each case feeds `script` + `trace` to the repair function and checks that:
#   - the result differs from `script` and comes with a non-empty description,
#   - `expect_in_repaired` appears in the repaired script,
#   - `expect_not_in_repaired` no longer appears in it.
CASES = [
    dict(
        name="AttributeError + Did you mean",
        script=(
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n"
            "\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        trace=(
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
        expect_in_repaired="trainer.train()",
        expect_not_in_repaired="start_training",
    ),
    dict(
        name="ModuleNotFoundError submodule",
        script=(
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        trace="ModuleNotFoundError: No module named 'torch.legacy'",
        # NOTE(review): "import torch" is already a substring of the
        # unrepaired "import torch.legacy as torch", so this check passes
        # trivially; the "changed" and expect_not_in_repaired checks are the
        # ones that actually exercise this case.
        expect_in_repaired="import torch",
        expect_not_in_repaired="torch.legacy",
    ),
    dict(
        name="TypeError + use ... instead",
        script=(
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        trace=(
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
        expect_in_repaired="padding=True",
        expect_not_in_repaired="pad_to_max_length",
    ),
]


def run_one(
    case: dict,
    *,
    repair: Callable[[str, str], tuple[str, str]] | None = None,
) -> bool:
    """Run one repair case, print a PASS/FAIL report, and return whether it passed.

    A case passes when the repaired script differs from the input, the fix
    description is non-empty, ``case["expect_in_repaired"]`` occurs in the
    repaired script and ``case["expect_not_in_repaired"]`` does not.

    Args:
        case: A mapping shaped like the entries of ``CASES`` (keys ``name``,
            ``script``, ``trace``, ``expect_in_repaired``,
            ``expect_not_in_repaired``).
        repair: Function called as ``repair(script, trace)`` that returns a
            ``(repaired_script, description)`` pair. Defaults to the
            ``_heuristic_repair`` imported from ``app``.

    Returns:
        True if every check passed. False otherwise, including when the
        repair function raised (or returned something that is not a pair),
        so one broken case does not abort the remaining ones.
    """
    name = case["name"]
    repair_fn = _heuristic_repair if repair is None else repair
    try:
        repaired, description = repair_fn(case["script"], case["trace"])
    except Exception as exc:  # harness boundary: report the failure, keep going
        print(f"[FAIL] {name}")
        print(f"  raised: {type(exc).__name__}: {exc}")
        return False

    ok_changed = repaired != case["script"]
    ok_desc = bool(description)
    ok_in = case["expect_in_repaired"] in repaired
    ok_not = case["expect_not_in_repaired"] not in repaired
    passed = ok_changed and ok_desc and ok_in and ok_not

    status = "PASS" if passed else "FAIL"
    print(f"[{status}] {name}")
    print(f"  description: {description!r}")
    print(f"  changed?    {ok_changed}")
    print(f"  '{case['expect_in_repaired']}' in repaired? {ok_in}")
    print(f"  '{case['expect_not_in_repaired']}' NOT in repaired? {ok_not}")
    if not passed:
        # Dump the full output so the failing substitution is easy to inspect.
        print("  --- repaired script ---")
        print(repaired)
        print("  -----------------------")
    return passed


def main() -> int:
    """Run every entry of ``CASES`` through ``run_one`` and print a summary.

    Returns:
        Process exit code: 0 when every case passed, 1 otherwise.
    """
    outcomes = list(map(run_one, CASES))
    print()
    passed = sum(outcomes)
    total = len(outcomes)
    print(f"summary: {passed}/{total} passed")
    if all(outcomes):
        return 0
    return 1


if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())