"""Local smoke test for the heuristic repair fallback (``_heuristic_repair``).

Run with::

    python demo-space/test_heuristic.py

Every entry in ``CASES`` pairs a broken script with an error trace. A case
passes when the repair returns a non-empty description and a script that
differs from the input, contains the expected snippet, and no longer contains
the offending one. The process exit status is 0 when every case passes and 1
otherwise.
"""
from __future__ import annotations

import sys
from pathlib import Path

# Two levels up from this file. Both that directory and its "demo-space"
# child are pushed onto sys.path so that ``app`` can be imported below
# (presumably app.py lives in one of them -- confirm against the repo layout).
REPO = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(REPO))
sys.path.insert(0, str(REPO / "demo-space"))

from app import _heuristic_repair  # noqa: E402

# Each case: a display name, the broken script, the error trace handed to the
# repair function, a substring the repaired script must contain, and a
# substring it must no longer contain.
CASES = [
    {
        "name": "AttributeError + Did you mean",
        "script": (
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        "trace": (
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
        "expect_in_repaired": "trainer.train()",
        "expect_not_in_repaired": "start_training",
    },
    {
        "name": "ModuleNotFoundError submodule",
        "script": (
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        "trace": "ModuleNotFoundError: No module named 'torch.legacy'",
        "expect_in_repaired": "import torch",
        "expect_not_in_repaired": "torch.legacy",
    },
    {
        "name": "TypeError + use ... instead",
        "script": (
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        "trace": (
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
        "expect_in_repaired": "padding=True",
        "expect_not_in_repaired": "pad_to_max_length",
    },
]


def run_one(case: dict) -> bool:
    """Run one repair case, print a report for it, and return whether it passed.

    ``case`` is a mapping with the keys used by the entries of ``CASES``:
    ``name``, ``script``, ``trace``, ``expect_in_repaired`` and
    ``expect_not_in_repaired``.

    Returns ``True`` only when all four checks hold: the script changed, the
    description is non-empty, the expected snippet is present, and the
    forbidden snippet is absent.
    """
    label = case["name"]
    fixed, summary = _heuristic_repair(case["script"], case["trace"])

    differs = fixed != case["script"]
    described = bool(summary)
    wanted = case["expect_in_repaired"]
    has_wanted = wanted in fixed
    banned = case["expect_not_in_repaired"]
    lacks_banned = banned not in fixed

    passed = all((differs, described, has_wanted, lacks_banned))
    verdict = "PASS" if passed else "FAIL"

    report = (
        f"[{verdict}] {label}",
        f" description: {summary!r}",
        f" changed? {differs}",
        f" '{wanted}' in repaired? {has_wanted}",
        f" '{banned}' NOT in repaired? {lacks_banned}",
    )
    for line in report:
        print(line)

    # Dump the repaired script only on failure, to keep passing runs terse.
    if not passed:
        print(" --- repaired script ---")
        print(fixed)
        print(" -----------------------")

    return passed


def main() -> int:
    """Run every case in ``CASES``, print a summary, and return the exit code.

    Returns 0 when all cases pass and 1 otherwise.
    """
    outcomes = list(map(run_one, CASES))
    print()
    passed_count = sum(outcomes)
    total = len(outcomes)
    print(f"summary: {passed_count}/{total} passed")
    if all(outcomes):
        return 0
    return 1


if __name__ == "__main__":
    sys.exit(main())