"""
Comprehensive tests for IsomorphicPerturbationTesting.

Covers:
1. extract_hypothesis — all formatting variants
2. verify — correct rule, shortcut, bad syntax, edge cases
3. Dataset ground-truth sanity check (SLR-Bench v1-All)
4. Full _compute round-trip
"""

import multiprocessing as mp
import sys
import traceback
from pathlib import Path

from tqdm import tqdm


# ---------------------------------------------------------------------------
# Reporting helpers
# ---------------------------------------------------------------------------

# ANSI-coloured status tags (92 = bright green, 91 = bright red, 0 = reset).
PASS = "\033[92mPASS\033[0m"
FAIL = "\033[91mFAIL\033[0m"

# One (name, passed) tuple per check() call, in call order.
# Read by the summary at the end of the script.
_results = []


def check(name: str, cond: object, detail: str = "") -> None:
    """Print a PASS/FAIL line for one check and record its outcome.

    Args:
        name: Human-readable label of the check.
        cond: Outcome of the check; interpreted by truthiness.
        detail: Optional extra text appended after an em dash. Omitted
            entirely (including the dash) when empty.
    """
    # Normalise once so that _results always holds real booleans, even when
    # the caller passes a truthy/falsy non-bool value.
    passed = bool(cond)
    status = PASS if passed else FAIL
    suffix = f" — {detail}" if detail else ""
    print(f" [{status}] {name}{suffix}")
    _results.append((name, passed))


def section(title):
    """Print a section header: a blank line, then *title* between two rules.

    Each rule is a line of 60 '=' characters.
    """
    rule = "=" * 60
    print(f"\n{rule}\n {title}\n{rule}")


# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------

# Minimal validation program: four trains with one car each.
# Eastbound trains (train0, train2) carry a red car; westbound trains
# (train1, train3) carry a blue car.
MINI_VP = """\
eastbound(train0).
westbound(train1).
eastbound(train2).
westbound(train3).
has_car(train0, car0_1).
has_car(train1, car1_1).
has_car(train2, car2_1).
has_car(train3, car3_1).
car_color(car0_1, red).
car_color(car1_1, blue).
car_color(car2_1, red).
car_color(car3_1, blue).
"""

# Names of the positive / negative label predicates used in MINI_VP.
EVAL_CFG = {"positive_predicate": "eastbound", "negative_predicate": "westbound"}

# Intensional rule that matches the data: a train is eastbound when it has a
# red car.
GOOD_RULE = "eastbound(T) :- has_car(T, C), car_color(C, red)."

# Shortcut: simply enumerates the positive examples as ground facts.
SHORTCUT = "eastbound(train0). eastbound(train2)."

# Rule keyed on the wrong colour: it selects the blue (westbound) trains.
WRONG_RULE = "eastbound(T) :- has_car(T, C), car_color(C, blue)."

# Free text that is not meant to be a parseable Prolog clause.
BAD_SYNTAX = "this is not prolog at all :-"


# ---------------------------------------------------------------------------
# 1. extract_hypothesis
# ---------------------------------------------------------------------------

section("1. extract_hypothesis")

from ipt.verifier import extract_hypothesis

# Bare rule with no wrapper at all.
out = extract_hypothesis("eastbound(T) :- has_car(T, C), car_color(C, red).")
check("bare rule extracted", "eastbound" in out and ":-" in out, repr(out))

# Rule wrapped in [RULE] ... [/RULE] tags.
out = extract_hypothesis("[RULE]\neastbound(T) :- has_car(T, C).\n[/RULE]")
check("[RULE] block verbatim", out.strip() == "eastbound(T) :- has_car(T, C).", repr(out))

# Rule inside a fenced ```prolog code block, preceded by prose.
out = extract_hypothesis("Some prose.\n```prolog\neastbound(T) :- has_car(T, C).\n```")
check("fenced code block verbatim", out.strip() == "eastbound(T) :- has_car(T, C).", repr(out))

# Chain-of-thought inside <think> tags must not leak into the hypothesis.
cot = "<think>lots of reasoning</think>\neastbound(T) :- has_car(T, C), car_color(C, red)."
out = extract_hypothesis(cot)
check("CoT stripped", "eastbound" in out and "lots" not in out, repr(out))

# Text before a "Final Answer:" marker must be dropped.
out = extract_hypothesis("Let me think...\nFinal Answer:\neastbound(T) :- has_car(T, C).")
check("Final Answer: marker", "eastbound" in out and "Let me" not in out, repr(out))

# A trailing Prolog '%' comment must be removed.
out = extract_hypothesis("```\neastbound(T) :- has_car(T, C). % this is a comment\n```")
check("Prolog comment stripped", "%" not in out, repr(out))

# None must be tolerated and mapped to the empty string.
out = extract_hypothesis(None)
check("None input → empty string", out == "", repr(out))

# With several code blocks, only the last one counts.
out = extract_hypothesis("```\nold_rule(T).\n```\nBetter:\n```prolog\neastbound(T) :- has_car(T, C).\n```")
check("last code block wins", "eastbound" in out and "old_rule" not in out, repr(out))

# Ground facts (no ':-') following a line of prose.
out = extract_hypothesis("My answer is:\neastbound(train0).\neastbound(train2).")
check("bare facts extracted", "eastbound(train0)" in out and "eastbound(train2)" in out, repr(out))

# Pure prose contains no clause and must yield nothing.
out = extract_hypothesis("The train goes east because it is red.")
check("prose not extracted", out.strip() == "", repr(out))


# ---------------------------------------------------------------------------
# 2. verify
# ---------------------------------------------------------------------------

section("2. verify")

from ipt.verifier import verify

# Correct intensional rule: must be correct in both modes with a perfect
# partial score.
r_ext = verify(GOOD_RULE, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(GOOD_RULE, MINI_VP, EVAL_CFG, isomorphic=True)
check("good rule: extensional correct", r_ext["is_correct"], str(r_ext))
check("good rule: isomorphic correct", r_iso["is_correct"], str(r_iso))
check("good rule: syntax_valid (ext)", r_ext["syntax_valid"])
check("good rule: syntax_valid (iso)", r_iso["syntax_valid"])
check("good rule: partial = 1.0 (ext)", r_ext["partial_score"] == 1.0, str(r_ext["partial_score"]))
check("good rule: partial = 1.0 (iso)", r_iso["partial_score"] == 1.0, str(r_iso["partial_score"]))

# Enumerated-facts shortcut: expected to pass extensionally but to be
# rejected by the isomorphic check.
r_ext = verify(SHORTCUT, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(SHORTCUT, MINI_VP, EVAL_CFG, isomorphic=True)
check("shortcut: extensional correct", r_ext["is_correct"], str(r_ext))
check("shortcut: isomorphic FAILS", not r_iso["is_correct"], str(r_iso))
check("shortcut: iso partial < 1.0", r_iso["partial_score"] < 1.0, str(r_iso["partial_score"]))

# Rule keyed on the wrong colour: must fail in both modes.
r_ext = verify(WRONG_RULE, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(WRONG_RULE, MINI_VP, EVAL_CFG, isomorphic=True)
check("wrong rule: extensional fails", not r_ext["is_correct"], str(r_ext))
check("wrong rule: isomorphic fails", not r_iso["is_correct"], str(r_iso))

# Unparseable input must not be reported as correct.
r = verify(BAD_SYNTAX, MINI_VP, EVAL_CFG, isomorphic=False)
check("bad syntax: not correct", not r["is_correct"], str(r))

# Hypothesis that never defines the positive predicate: expected to be
# incorrect with a partial score of exactly 0.0. Called without `isomorphic`,
# so verify()'s own default mode is exercised here.
r = verify("westbound(T) :- has_car(T, _).", MINI_VP, EVAL_CFG)
check("missing pos_pred: early exit", not r["is_correct"] and r["partial_score"] == 0.0, str(r))

# The same good rule wrapped in a fenced code block.
r = verify("```prolog\n" + GOOD_RULE + "\n```", MINI_VP, EVAL_CFG, isomorphic=True)
check("code-block rule: iso correct", r["is_correct"], str(r))

# The same good rule wrapped in [RULE] tags.
r = verify(f"[RULE]\n{GOOD_RULE}\n[/RULE]", MINI_VP, EVAL_CFG, isomorphic=True)
check("[RULE] tag: iso correct", r["is_correct"], str(r))

# The same good rule preceded by a <think> block.
cot_rule = f"<think>Hmm, let me think...</think>\n{GOOD_RULE}"
r = verify(cot_rule, MINI_VP, EVAL_CFG, isomorphic=True)
check("CoT + rule: iso correct", r["is_correct"], str(r))

# Over-general rule: every train in MINI_VP has a car, so the rule also fires
# on the westbound trains; the partial score should be strictly between 0 and 1.
partial_rule = "eastbound(T) :- has_car(T, _)."
r_ext = verify(partial_rule, MINI_VP, EVAL_CFG, isomorphic=False)
check("partial rule: 0 < partial < 1", 0 < r_ext["partial_score"] < 1.0, str(r_ext["partial_score"]))

# Negation shortcut: defines eastbound as "not westbound", i.e. it leans on
# the label facts of the validation program instead of on train features.
# Expected to pass extensionally and to be rejected by the isomorphic check.
neg_shortcut = "eastbound(T) :- \\+ westbound(T)."
r_ext = verify(neg_shortcut, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(neg_shortcut, MINI_VP, EVAL_CFG, isomorphic=True)
check("neg shortcut: extensional correct", r_ext["is_correct"], str(r_ext))
check("neg shortcut: isomorphic FAILS", not r_iso["is_correct"], str(r_iso))


# ---------------------------------------------------------------------------
# 3. Dataset ground-truth sanity (SLR-Bench v1-All)
# ---------------------------------------------------------------------------

section("3. Dataset ground-truth sanity (SLR-Bench v1-All)")

# Best-effort section: any failure in here (missing `datasets` package, failed
# download, unexpected schema, worker crash, ...) is reported as a skip with a
# traceback and does not abort the remaining sections.
try:
    from datasets import load_dataset
    print(" Loading AIML-TUDA/SLR-Bench v1-All test split...")
    ds = load_dataset("AIML-TUDA/SLR-Bench", "v1-All", split="test")
    print(f" Loaded {len(ds)} examples.")

    # Inspect the first example to confirm the column names used below.
    ex = ds[0]
    print(f" Example keys: {list(ex.keys())}")

    VP_KEY = "validation program"
    GT_KEY = "ground-truth rule"
    check("validation_program key exists", VP_KEY in ex, f"keys: {list(ex.keys())}")
    check("ground-truth rule key exists", GT_KEY in ex, f"keys: {list(ex.keys())}")

    vp_snippet = ex[VP_KEY][:300].replace("\n", " | ")
    print(f" VP snippet: {vp_snippet}")
    print(f" GT snippet: {ex[GT_KEY][:120]}")

    # Verify the ground-truth rule of (at most) the first MAX_EXAMPLES rows.
    import itertools
    MAX_EXAMPLES = 20000
    examples = list(itertools.islice(ds, MAX_EXAMPLES))
    N = len(examples)
    print(f"\n Verifying ground truths on {N} examples (parallel)...")
    from IsomorphicPerturbationTesting import _run_eval
    # One argument tuple per example: (rule, validation program, config, 5).
    # NOTE(review): the trailing 5 looks like a per-example timeout — confirm
    # against _run_eval's signature.
    inputs = [(row[GT_KEY], row[VP_KEY], EVAL_CFG, 5) for row in examples]
    # Leave one core free for the parent process.
    n_cpus = max(1, mp.cpu_count() - 1)
    # NOTE(review): this pool is created at module top level. Under the
    # "spawn" / "forkserver" start methods, worker processes re-import this
    # script, which re-runs every section. Run it where "fork" is the start
    # method, or move the script body under `if __name__ == "__main__":`.
    with mp.Pool(n_cpus) as pool:
        pairs = list(tqdm(pool.imap(_run_eval, inputs), total=N, desc="GT verification"))

    n_pass_ext = n_pass_iso = 0
    for i, d in enumerate(pairs):
        # Report every example whose ground truth fails in either mode.
        if not d["extensional_correct"] or not d["isomorphic_correct"]:
            print(f" [WARN] example {i}: ext={d['extensional_correct']} iso={d['isomorphic_correct']} gt={examples[i][GT_KEY]!r}")
            print(f" err={d.get('error')}")
        if d["extensional_correct"]:
            n_pass_ext += 1
        if d["isomorphic_correct"]:
            n_pass_iso += 1

    check(f"GT extensional accuracy ({N} ex)", n_pass_ext == N, f"{n_pass_ext}/{N}")
    check(f"GT isomorphic accuracy ({N} ex)", n_pass_iso == N, f"{n_pass_iso}/{N}")

except Exception as e:
    print(f" [SKIP] Dataset test failed: {e}")
    traceback.print_exc()


# ---------------------------------------------------------------------------
# 4. Full _compute round-trip
# ---------------------------------------------------------------------------

section("4. Full _compute round-trip")

try:
    # Make the directory containing this script importable.
    repo_root = Path(__file__).resolve().parent
    sys.path.insert(0, str(repo_root))
    from IsomorphicPerturbationTesting import IsomorphicPerturbationTesting

    ipt = IsomorphicPerturbationTesting()

    # Index 0: correct rule, index 1: shortcut, index 2: wrong rule.
    predictions = [
        GOOD_RULE,
        SHORTCUT,
        WRONG_RULE,
    ]
    # Every prediction is scored against the same validation program.
    references = [
        {"validation_program": MINI_VP, "evaluation_config": EVAL_CFG}
        for _ in predictions
    ]

    results = ipt._compute(predictions, references)

    # Expected: 2 of 3 pass extensionally (good + shortcut), 1 of 3 passes
    # isomorphically (good), and exactly one prediction (index 1) is flagged
    # as a reward shortcut.
    check("shortcut_count == 1", results["meta"]["shortcut_count"] == 1, str(results["meta"]["shortcut_count"]))
    check("shortcut_rate > 0", results["shortcut_rate"] > 0, str(results["shortcut_rate"]))
    check("extensional_accuracy == 2/3", abs(results["meta"]["extensional_accuracy"] - 2/3) < 1e-9,
          str(results["meta"]["extensional_accuracy"]))
    check("isomorphic_accuracy == 1/3", abs(results["isomorphic_accuracy"] - 1/3) < 1e-9,
          str(results["isomorphic_accuracy"]))
    check("shortcut_rate == 1/3", abs(results["shortcut_rate"] - 1/3) < 1e-9,
          str(results["shortcut_rate"]))
    check("shortcut_ids == [1]", results["shortcut_ids"] == [1], str(results["shortcut_ids"]))
    check("detailed_results length", len(results["detailed_results"]) == 3)

    d = results["detailed_results"]
    check("good rule: not shortcut", not d[0]["is_reward_shortcut"])
    check("shortcut: is_reward_shortcut", d[1]["is_reward_shortcut"])
    check("wrong rule: not shortcut", not d[2]["is_reward_shortcut"])

except Exception as e:
    print(f" [ERROR] {e}")
    traceback.print_exc()
    # Record the crash as a failed check; otherwise the summary could still
    # report "All checks passed!" and exit with status 0.
    check("_compute round-trip completed without error", False, f"{type(e).__name__}: {e}")


# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------

section("Summary")
# Names of all failed checks, in the order they were run.
failed = [name for name, ok in _results if not ok]
n_total = len(_results)
n_pass = n_total - len(failed)
print(f" {n_pass}/{n_total} checks passed")
if failed:
    print("\n Failed checks:")
    for name in failed:
        print(f" - {name}")
    # Non-zero exit status so callers (e.g. CI) can detect the failure.
    sys.exit(1)
else:
    print(" All checks passed!")