File size: 12,312 Bytes
4af4a71
 
 
 
 
 
 
 
 
 
 
 
 
8fa00b4
4af4a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8fa00b4
 
4af4a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9853858
 
 
 
4af4a71
 
 
 
9853858
4af4a71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
"""
Comprehensive tests for IsomorphicPerturbationTesting.

Covers:
  1. extract_hypothesis — all formatting variants
  2. verify — correct rule, shortcut, bad syntax, edge cases
  3. Dataset ground-truth sanity check (SLR-Bench v1-All)
  4. Full _compute round-trip
"""

import multiprocessing as mp
import sys
import traceback
from pathlib import Path

from tqdm import tqdm

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

# ANSI-coloured status tags (green PASS, red FAIL) used in every check line.
PASS = "\033[92mPASS\033[0m"
FAIL = "\033[91mFAIL\033[0m"

# Every check outcome as a (name, cond) tuple, in call order.
_results = []

def check(name, cond, detail=""):
    """Print one PASS/FAIL line for a check and record its outcome.

    Args:
        name: Human-readable label of the check; also stored in ``_results``.
        cond: Outcome of the check; any truthy value counts as a pass.
        detail: Optional extra text appended after an em-dash separator.
            Omitted from the output when empty.
    """
    status = PASS if cond else FAIL
    # The separator is a real em-dash; the previous literal was a mis-decoded
    # (mojibake) rendering of the same character.
    suffix = f" — {detail}" if detail else ""
    print(f"  [{status}] {name}{suffix}")
    _results.append((name, cond))

def section(title):
    """Print a section banner: a blank line, then *title* between two 60-char '=' rules."""
    rule = "=" * 60
    print(f"\n{rule}\n  {title}\n{rule}")


# ---------------------------------------------------------------------------
# Minimal validation program (Michalski trains, 2 pos + 2 neg)
# ---------------------------------------------------------------------------

# Four trains, one car each: train0/train2 are eastbound with a red car,
# train1/train3 are westbound with a blue car.
MINI_VP = (
    "eastbound(train0).\n"
    "westbound(train1).\n"
    "eastbound(train2).\n"
    "westbound(train3).\n"
    "has_car(train0, car0_1).\n"
    "has_car(train1, car1_1).\n"
    "has_car(train2, car2_1).\n"
    "has_car(train3, car3_1).\n"
    "car_color(car0_1, red).\n"
    "car_color(car1_1, blue).\n"
    "car_color(car2_1, red).\n"
    "car_color(car3_1, blue).\n"
)

# Names of the predicates that label positive and negative examples.
EVAL_CFG = dict(positive_predicate="eastbound", negative_predicate="westbound")

# Genuine rule: a train is eastbound when one of its cars is red.
GOOD_RULE = "eastbound(T) :- has_car(T, C), car_color(C, red)."

# Shortcut: lists the eastbound trains as ground facts instead of a rule.
SHORTCUT = "eastbound(train0). eastbound(train2)."

# Wrong rule: keys on blue cars, which belong to the westbound trains.
WRONG_RULE = "eastbound(T) :- has_car(T, C), car_color(C, blue)."

# Text that is not valid Prolog.
BAD_SYNTAX = "this is not prolog at all :-"


# ===========================================================================
# 1. extract_hypothesis
# ===========================================================================

section("1. extract_hypothesis")

# Each case feeds one response format to extract_hypothesis and records a
# check on the returned text; repr(out) is attached as detail for debugging.
from ipt.verifier import extract_hypothesis

# 1a. Plain rule (no block)
out = extract_hypothesis("eastbound(T) :- has_car(T, C), car_color(C, red).")
check("bare rule extracted", "eastbound" in out and ":-" in out, repr(out))

# 1b. [RULE] block returned verbatim
out = extract_hypothesis("[RULE]\neastbound(T) :- has_car(T, C).\n[/RULE]")
check("[RULE] block verbatim", out.strip() == "eastbound(T) :- has_car(T, C).", repr(out))

# 1c. Fenced code block returned verbatim
out = extract_hypothesis("Some prose.\n```prolog\neastbound(T) :- has_car(T, C).\n```")
check("fenced code block verbatim", out.strip() == "eastbound(T) :- has_car(T, C).", repr(out))

# 1d. Chain-of-thought stripped
cot = "<think>lots of reasoning</think>\neastbound(T) :- has_car(T, C), car_color(C, red)."
out = extract_hypothesis(cot)
check("CoT stripped", "eastbound" in out and "lots" not in out, repr(out))

# 1e. Final Answer marker
out = extract_hypothesis("Let me think...\nFinal Answer:\neastbound(T) :- has_car(T, C).")
check("Final Answer: marker", "eastbound" in out and "Let me" not in out, repr(out))

# 1f. Prolog comment stripped
out = extract_hypothesis("```\neastbound(T) :- has_car(T, C). % this is a comment\n```")
check("Prolog comment stripped", "%" not in out, repr(out))

# 1g. Non-string input returns ""
# The label uses a real arrow; the previous literal was mojibake that leaked
# into the console output and the failure summary.
out = extract_hypothesis(None)
check("None input → empty string", out == "", repr(out))

# 1h. Multiple code blocks — last one wins
out = extract_hypothesis("```\nold_rule(T).\n```\nBetter:\n```prolog\neastbound(T) :- has_car(T, C).\n```")
check("last code block wins", "eastbound" in out and "old_rule" not in out, repr(out))

# 1i. Bare fact lines extracted
out = extract_hypothesis("My answer is:\neastbound(train0).\neastbound(train2).")
check("bare facts extracted", "eastbound(train0)" in out and "eastbound(train2)" in out, repr(out))

# 1j. Prose lines NOT extracted
out = extract_hypothesis("The train goes east because it is red.")
check("prose not extracted", out.strip() == "", repr(out))


# ===========================================================================
# 2. verify β€” unit tests
# ===========================================================================

section("2. verify")

# Each case calls verify(hypothesis, validation_program, eval_config[, isomorphic=...])
# and inspects the returned mapping; the keys read below are "is_correct",
# "syntax_valid" and "partial_score". str(result) is attached as check detail.
from ipt.verifier import verify

# 2a. Correct rule passes both modes
r_ext = verify(GOOD_RULE, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(GOOD_RULE, MINI_VP, EVAL_CFG, isomorphic=True)
check("good rule: extensional correct",  r_ext["is_correct"], str(r_ext))
check("good rule: isomorphic correct",   r_iso["is_correct"], str(r_iso))
check("good rule: syntax_valid (ext)",   r_ext["syntax_valid"])
check("good rule: syntax_valid (iso)",   r_iso["syntax_valid"])
check("good rule: partial = 1.0 (ext)",  r_ext["partial_score"] == 1.0, str(r_ext["partial_score"]))
check("good rule: partial = 1.0 (iso)",  r_iso["partial_score"] == 1.0, str(r_iso["partial_score"]))

# 2b. Shortcut passes extensional, fails isomorphic
r_ext = verify(SHORTCUT, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(SHORTCUT, MINI_VP, EVAL_CFG, isomorphic=True)
check("shortcut: extensional correct",     r_ext["is_correct"], str(r_ext))
check("shortcut: isomorphic FAILS",        not r_iso["is_correct"], str(r_iso))
check("shortcut: iso partial < 1.0",       r_iso["partial_score"] < 1.0, str(r_iso["partial_score"]))

# 2c. Wrong rule fails both
r_ext = verify(WRONG_RULE, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(WRONG_RULE, MINI_VP, EVAL_CFG, isomorphic=True)
check("wrong rule: extensional fails",  not r_ext["is_correct"], str(r_ext))
check("wrong rule: isomorphic fails",   not r_iso["is_correct"], str(r_iso))

# 2d. Bad syntax
r = verify(BAD_SYNTAX, MINI_VP, EVAL_CFG, isomorphic=False)
check("bad syntax: not correct",      not r["is_correct"], str(r))

# 2e. Missing positive predicate guard
# The hypothesis only defines westbound/1; `isomorphic` is left at verify's default here.
r = verify("westbound(T) :- has_car(T, _).", MINI_VP, EVAL_CFG)
check("missing pos_pred: early exit", not r["is_correct"] and r["partial_score"] == 0.0, str(r))

# 2f. Rule in code block
r = verify("```prolog\n" + GOOD_RULE + "\n```", MINI_VP, EVAL_CFG, isomorphic=True)
check("code-block rule: iso correct", r["is_correct"], str(r))

# 2g. Rule in [RULE] tag
r = verify(f"[RULE]\n{GOOD_RULE}\n[/RULE]", MINI_VP, EVAL_CFG, isomorphic=True)
check("[RULE] tag: iso correct", r["is_correct"], str(r))

# 2h. Rule after CoT
cot_rule = f"<think>Hmm, let me think...</think>\n{GOOD_RULE}"
r = verify(cot_rule, MINI_VP, EVAL_CFG, isomorphic=True)
check("CoT + rule: iso correct", r["is_correct"], str(r))

# 2i. Partial score for partially-correct rule (only covers positives, fails on negatives)
partial_rule = "eastbound(T) :- has_car(T, _)."  # classifies everything as eastbound
r_ext = verify(partial_rule, MINI_VP, EVAL_CFG, isomorphic=False)
check("partial rule: 0 < partial < 1", 0 < r_ext["partial_score"] < 1.0, str(r_ext["partial_score"]))

# 2j. Negation shortcut: "eastbound if not westbound" — passes extensional (bridge rule makes
#     westbound meaningful), fails isomorphic (westbound undefined → \+ always succeeds → all trains
#     are eastbound → neg examples misclassified).
neg_shortcut = "eastbound(T) :- \\+ westbound(T)."
r_ext = verify(neg_shortcut, MINI_VP, EVAL_CFG, isomorphic=False)
r_iso = verify(neg_shortcut, MINI_VP, EVAL_CFG, isomorphic=True)
check("neg shortcut: extensional correct",  r_ext["is_correct"], str(r_ext))
check("neg shortcut: isomorphic FAILS",     not r_iso["is_correct"], str(r_iso))


# ===========================================================================
# 3. Dataset ground-truth sanity check
# ===========================================================================

section("3. Dataset ground-truth sanity (SLR-Bench v1-All)")

# Best-effort section: any exception raised below (missing `datasets` package,
# download failure, missing key, worker error, ...) is reported as [SKIP] with
# a traceback and does not add a failed check.
try:
    from datasets import load_dataset
    print("  Loading AIML-TUDA/SLR-Bench v1-All test split...")
    ds = load_dataset("AIML-TUDA/SLR-Bench", "v1-All", split="test")
    print(f"  Loaded {len(ds)} examples.")

    # Inspect first example structure
    ex = ds[0]
    print(f"  Example keys: {list(ex.keys())}")

    # Column names looked up in each example (note: they contain a space / hyphen).
    VP_KEY = "validation program"
    GT_KEY = "ground-truth rule"
    check("validation_program key exists", VP_KEY in ex, f"keys: {list(ex.keys())}")
    check("ground-truth rule key exists",  GT_KEY in ex, f"keys: {list(ex.keys())}")

    # Short previews only: first 300 chars of the program (newlines shown as " | ")
    # and first 120 chars of the ground-truth rule.
    vp_snippet = ex[VP_KEY][:300].replace("\n", " | ")
    print(f"  VP snippet: {vp_snippet}")
    print(f"  GT snippet: {ex[GT_KEY][:120]}")

    # Run ground truths through verifier on first N examples
    import itertools
    # Cap the run at 20000 examples; N is the number actually taken.
    examples = list(itertools.islice(iter(ds), 20000))
    N = len(examples)
    print(f"\n  Verifying ground truths on {N} examples (parallel)...")
    from IsomorphicPerturbationTesting import _run_eval
    # One 4-tuple per example. NOTE(review): the trailing 5 is passed straight to
    # _run_eval; presumably a per-example timeout — confirm against _run_eval.
    inputs = [(ex[GT_KEY], ex[VP_KEY], EVAL_CFG, 5) for ex in examples]
    # Leave one CPU free, but always use at least one worker.
    n_cpus = max(1, mp.cpu_count() - 1)
    # NOTE(review): the pool is created at module top level with no
    # `if __name__ == "__main__":` guard; under the "spawn" start method worker
    # processes re-import this script — confirm this is acceptable on the
    # platforms where it is run.
    with mp.Pool(n_cpus) as pool:
        # imap yields results in input order, so pairs[i] corresponds to examples[i].
        pairs = list(tqdm(pool.imap(_run_eval, inputs), total=N, desc="GT verification"))

    # Count passes per mode and print a warning for every example failing either mode.
    n_pass_ext = n_pass_iso = 0
    for i, d in enumerate(pairs):
        if not d["extensional_correct"] or not d["isomorphic_correct"]:
            print(f"  [WARN] example {i}: ext={d['extensional_correct']} iso={d['isomorphic_correct']}  gt={examples[i][GT_KEY]!r}")
            print(f"         err={d.get('error')}")
        if d["extensional_correct"]: n_pass_ext += 1
        if d["isomorphic_correct"]:  n_pass_iso += 1

    # Every ground-truth rule is expected to pass in both modes.
    check(f"GT extensional accuracy ({N} ex)", n_pass_ext == N, f"{n_pass_ext}/{N}")
    check(f"GT isomorphic  accuracy ({N} ex)", n_pass_iso == N, f"{n_pass_iso}/{N}")

except Exception as e:
    print(f"  [SKIP] Dataset test failed: {e}")
    traceback.print_exc()


# ===========================================================================
# 4. Full _compute round-trip
# ===========================================================================

section("4. Full _compute round-trip")

# Runs three predictions (genuine rule, shortcut, wrong rule) against the same
# mini validation program through IsomorphicPerturbationTesting._compute and
# checks the aggregate metrics and the per-example shortcut flags.
try:
    # Make the directory containing this script importable before loading the metric.
    repo_root = Path(__file__).resolve().parent
    sys.path.insert(0, str(repo_root))
    from IsomorphicPerturbationTesting import IsomorphicPerturbationTesting

    ipt = IsomorphicPerturbationTesting()

    predictions = [
        GOOD_RULE,           # genuine rule → ext=T iso=T
        SHORTCUT,            # shortcut     → ext=T iso=F
        WRONG_RULE,          # wrong        → ext=F iso=F
    ]
    # One reference per prediction, all sharing the same program and config.
    references = [
        {"validation_program": MINI_VP, "evaluation_config": EVAL_CFG},
        {"validation_program": MINI_VP, "evaluation_config": EVAL_CFG},
        {"validation_program": MINI_VP, "evaluation_config": EVAL_CFG},
    ]

    results = ipt._compute(predictions, references)

    check("shortcut_count == 1",         results["meta"]["shortcut_count"] == 1,       str(results["meta"]["shortcut_count"]))
    check("shortcut_rate > 0",           results["shortcut_rate"] > 0,                  str(results["shortcut_rate"]))
    # Ratios are compared with a small tolerance rather than exact float equality.
    check("extensional_accuracy == 2/3", abs(results["meta"]["extensional_accuracy"] - 2/3) < 1e-9,
          str(results["meta"]["extensional_accuracy"]))
    check("isomorphic_accuracy == 1/3",  abs(results["isomorphic_accuracy"] - 1/3) < 1e-9,
          str(results["isomorphic_accuracy"]))
    check("shortcut_rate == 1/3",        abs(results["shortcut_rate"] - 1/3) < 1e-9,
          str(results["shortcut_rate"]))
    check("shortcut_ids == [1]",         results["shortcut_ids"] == [1],               str(results["shortcut_ids"]))
    check("detailed_results length",     len(results["detailed_results"]) == 3)

    # Per-example flags, in the same order as `predictions`.
    d = results["detailed_results"]
    check("good rule: not shortcut",     not d[0]["is_reward_shortcut"])
    check("shortcut: is_reward_shortcut", d[1]["is_reward_shortcut"])
    check("wrong rule: not shortcut",    not d[2]["is_reward_shortcut"])

except Exception as e:
    print(f"  [ERROR] {e}")
    traceback.print_exc()
    # Record the crash as a failed check; otherwise an import or runtime error
    # here would leave the summary reporting success with exit status 0.
    check("_compute round-trip ran without exception", False, repr(e))


# ===========================================================================
# Summary
# ===========================================================================

section("Summary")
# Collect the failing labels once; the pass count follows from the total.
failed_names = [label for label, passed in _results if not passed]
n_total = len(_results)
n_pass = n_total - len(failed_names)
print(f"  {n_pass}/{n_total} checks passed")
if failed_names:
    print("\n  Failed checks:")
    for label in failed_names:
        print(f"    - {label}")
    # Non-zero exit status signals failure to the caller / CI.
    sys.exit(1)
else:
    print("  All checks passed!")