| """Quick local sanity check for the heuristic repair fallback. |
| |
| Run with:: |
| |
| python demo-space/test_heuristic.py |
| |
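Each case must produce a non-empty fix description and a repaired script
that differs from the input, contains the case's expected snippet, and no
longer contains the offending one. The process exits with status 1 if any
case fails.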
| """ |
from __future__ import annotations


import sys
import traceback
from collections.abc import Callable
from pathlib import Path
|
|
# Repository root: two directory levels above this file.
REPO = Path(__file__).resolve().parent.parent
# Prepend both directories to the import path so the ``from app import ...``
# below can resolve; "demo-space" ends up at index 0, the repo root right after.
sys.path[:0] = [str(REPO / "demo-space"), str(REPO)]
|
|
| from app import _heuristic_repair |
|
|
# Each case pairs a broken script with the traceback text it produces, plus a
# snippet the repaired script must contain and one it must no longer contain.
CASES = [
    # Unknown method; the trace carries a "Did you mean" hint.
    dict(
        name="AttributeError + Did you mean",
        script=(
            "from transformers import Trainer, TrainingArguments\n"
            "from datasets import load_dataset\n\n"
            "ds = load_dataset('glue', 'sst2')\n"
            "args = TrainingArguments(output_dir='out')\n"
            "trainer = Trainer(model=None, args=args, train_dataset=ds['train'])\n"
            "trainer.start_training()\n"
        ),
        trace=(
            "AttributeError: 'Trainer' object has no attribute 'start_training'. "
            "Did you mean: 'train'?"
        ),
        expect_in_repaired="trainer.train()",
        expect_not_in_repaired="start_training",
    ),
    # Import of a dotted submodule that cannot be found.
    dict(
        name="ModuleNotFoundError submodule",
        script=(
            "import torch.legacy as torch\n"
            "x = torch.randn(2, 3)\n"
            "print(x)\n"
        ),
        trace="ModuleNotFoundError: No module named 'torch.legacy'",
        expect_in_repaired="import torch",
        expect_not_in_repaired="torch.legacy",
    ),
    # Rejected keyword argument; the trace names the replacement to use.
    dict(
        name="TypeError + use ... instead",
        script=(
            "from transformers import AutoTokenizer\n"
            "tok = AutoTokenizer.from_pretrained('bert-base-uncased')\n"
            "out = tok(['hello world'], pad_to_max_length=True, truncate=True)\n"
            "print(out)\n"
        ),
        trace=(
            "TypeError: __call__() got an unexpected keyword argument "
            "'pad_to_max_length' (use `padding=True` instead)."
        ),
        expect_in_repaired="padding=True",
        expect_not_in_repaired="pad_to_max_length",
    ),
]
|
|
|
|
def run_one(
    case: dict,
    repair: Callable[[str, str], tuple[str, str]] | None = None,
) -> bool:
    """Run one CASES entry through a repair function and print a report.

    A case passes when the repaired script differs from the input, the
    description is non-empty, ``case["expect_in_repaired"]`` appears in the
    repaired script and ``case["expect_not_in_repaired"]`` does not.

    Args:
        case: Mapping with keys ``name``, ``script``, ``trace``,
            ``expect_in_repaired`` and ``expect_not_in_repaired``.
        repair: Callable taking ``(script, trace)`` and returning
            ``(repaired_script, description)``. Defaults to
            ``_heuristic_repair``.

    Returns:
        True if every check passed, False otherwise. An exception raised by
        ``repair`` (including a return value that does not unpack into two
        items) is reported as a failure instead of aborting the other cases.
    """
    if repair is None:
        repair = _heuristic_repair
    name = case["name"]
    script = case["script"]

    try:
        repaired, description = repair(script, case["trace"])
    except Exception:
        # Per-case boundary: one crashing heuristic must not hide the
        # results of the remaining cases, so report it and move on.
        print(f"[FAIL] {name}")
        print(traceback.format_exc())
        return False

    # A non-string result (e.g. None) would make the substring checks raise
    # TypeError; treat it as failing every script-content check instead.
    is_text = isinstance(repaired, str)
    ok_changed = is_text and repaired != script
    ok_desc = bool(description)
    ok_in = is_text and case["expect_in_repaired"] in repaired
    ok_not = is_text and case["expect_not_in_repaired"] not in repaired

    passed = ok_changed and ok_desc and ok_in and ok_not
    status = "PASS" if passed else "FAIL"
    print(f"[{status}] {name}")
    print(f" description: {description!r}")
    print(f" changed? {ok_changed}")
    print(f" '{case['expect_in_repaired']}' in repaired? {ok_in}")
    print(f" '{case['expect_not_in_repaired']}' NOT in repaired? {ok_not}")
    if not passed:
        print(" --- repaired script ---")
        print(repaired)
        print(" -----------------------")
    return passed
|
|
|
|
def main() -> int:
    """Run every entry in CASES and print an overall summary line.

    Returns:
        Exit status for ``sys.exit``: 0 if every case passed, else 1.
    """
    outcomes = list(map(run_one, CASES))
    print()
    passed, total = sum(outcomes), len(outcomes)
    print(f"summary: {passed}/{total} passed")
    if not all(outcomes):
        return 1
    return 0
|
|
|
|
if __name__ == "__main__":
    # SystemExit carries main()'s return value back as the process exit status.
    raise SystemExit(main())
|
|