"""Quick local sanity check for the heuristic repair fallback.

Run with::

    python demo-space/test_heuristic.py

Each case must produce a non-empty fix description and a script that
differs from the input. The repaired script must also contain the case's
expected replacement text and must no longer contain the offending text.
"""
from __future__ import annotations
import sys
from pathlib import Path
# Directory two levels above this file. The run command in the module
# docstring suggests this is the repository root (this file sitting in
# ``demo-space/``) -- TODO confirm if the file is ever moved.
REPO = Path(__file__).resolve().parent.parent
# Prepend both directories to the import search path so the ``app`` import
# that follows can be resolved when this file is run directly as a script.
# ``demo-space`` is inserted last and therefore ends up first on the path,
# so it takes precedence over the directory above it.
sys.path.insert(0, str(REPO))
sys.path.insert(0, str(REPO / "demo-space"))
from app import _heuristic_repair # noqa: E402
# Repair scenarios exercised by this script. Every entry carries:
#   name                   -- label printed in the report
#   script                 -- the broken input script
#   trace                  -- the error text handed to the repair heuristic
#   expect_in_repaired     -- substring the repaired script must contain
#   expect_not_in_repaired -- substring the repaired script must not contain
# Scripts are written one source line per list item; the trailing empty item
# supplies the final newline.
CASES = [
    # Misnamed method; the trace carries a "Did you mean" suggestion.
    {
        "name": "AttributeError + Did you mean",
        "script": "\n".join(
            [
                "from transformers import Trainer, TrainingArguments",
                "from datasets import load_dataset",
                "",
                "ds = load_dataset('glue', 'sst2')",
                "args = TrainingArguments(output_dir='out')",
                "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])",
                "trainer.start_training()",
                "",
            ]
        ),
        "trace": (
            "AttributeError: 'Trainer' object has no attribute "
            "'start_training'. Did you mean: 'train'?"
        ),
        "expect_in_repaired": "trainer.train()",
        "expect_not_in_repaired": "start_training",
    },
    # Import of a submodule that the trace reports as missing.
    {
        "name": "ModuleNotFoundError submodule",
        "script": "\n".join(
            [
                "import torch.legacy as torch",
                "x = torch.randn(2, 3)",
                "print(x)",
                "",
            ]
        ),
        "trace": "ModuleNotFoundError: No module named 'torch.legacy'",
        "expect_in_repaired": "import torch",
        "expect_not_in_repaired": "torch.legacy",
    },
    # Rejected keyword argument; the trace carries a "use ... instead" hint.
    {
        "name": "TypeError + use ... instead",
        "script": "\n".join(
            [
                "from transformers import AutoTokenizer",
                "tok = AutoTokenizer.from_pretrained('bert-base-uncased')",
                "out = tok(['hello world'], pad_to_max_length=True, truncate=True)",
                "print(out)",
                "",
            ]
        ),
        "trace": (
            "TypeError: __call__() got an unexpected keyword "
            "argument 'pad_to_max_length' (use `padding=True` instead)."
        ),
        "expect_in_repaired": "padding=True",
        "expect_not_in_repaired": "pad_to_max_length",
    },
]
def run_one(case: dict) -> bool:
    """Run one repair case, print a report for it, and say whether it passed.

    Args:
        case: Mapping with the keys ``name``, ``script``, ``trace``,
            ``expect_in_repaired`` and ``expect_not_in_repaired``.

    Returns:
        ``True`` only when the repaired script differs from the input, the
        fix description is truthy, the expected substring is present, and
        the forbidden substring is absent.
    """
    label = case["name"]
    original = case["script"]
    fixed, summary = _heuristic_repair(original, case["trace"])

    # The four independent conditions a case has to satisfy.
    differs = fixed != original
    has_summary = bool(summary)
    wanted = case["expect_in_repaired"]
    has_wanted = wanted in fixed
    banned = case["expect_not_in_repaired"]
    lacks_banned = banned not in fixed

    passed = all((differs, has_summary, has_wanted, lacks_banned))
    verdict = "PASS" if passed else "FAIL"

    print(f"[{verdict}] {label}")
    print(f" description: {summary!r}")
    print(f" changed? {differs}")
    print(f" '{wanted}' in repaired? {has_wanted}")
    print(f" '{banned}' NOT in repaired? {lacks_banned}")
    if not passed:
        # Dump the repaired script so a failure can be diagnosed from the log.
        print(" --- repaired script ---")
        print(fixed)
        print(" -----------------------")
    return passed
def main() -> int:
    """Run every entry in ``CASES`` and print a pass-count summary.

    Returns:
        ``0`` when all cases passed, ``1`` otherwise, so the value can be
        used directly as a process exit status.
    """
    # Materialise the results first so every case runs (and prints its
    # report) before the summary line is written.
    outcomes = list(map(run_one, CASES))
    print()
    passed_count = sum(outcomes)
    total = len(outcomes)
    print(f"summary: {passed_count}/{total} passed")
    if all(outcomes):
        return 0
    return 1
if __name__ == "__main__":
    # Hand the overall pass/fail result to the shell as the exit status.
    exit_status = main()
    sys.exit(exit_status)
|