Upload IsomorphicPerturbationTesting.py with huggingface_hub
Browse files
IsomorphicPerturbationTesting.py
CHANGED
|
@@ -42,7 +42,7 @@ import datasets
|
|
| 42 |
import evaluate
|
| 43 |
from tqdm import tqdm
|
| 44 |
|
| 45 |
-
from ipt.verifier import
|
| 46 |
|
| 47 |
logger = logging.getLogger(__name__)
|
| 48 |
|
|
@@ -113,9 +113,7 @@ Returns:
|
|
| 113 |
|
| 114 |
def _run_eval(args):
|
| 115 |
prediction, validation_program, eval_config, timeout = args
|
| 116 |
-
|
| 117 |
-
iso = verify(prediction, validation_program, eval_config, isomorphic=True, timeout=timeout)
|
| 118 |
-
return ext, iso
|
| 119 |
|
| 120 |
|
| 121 |
# ---------------------------------------------------------------------------
|
|
@@ -202,33 +200,20 @@ class IsomorphicPerturbationTesting(evaluate.Metric):
|
|
| 202 |
if use_parallel:
|
| 203 |
n_cpus = max(1, mp.cpu_count() - 1)
|
| 204 |
with mp.Pool(n_cpus) as pool:
|
| 205 |
-
|
| 206 |
pool.imap(_run_eval, inputs),
|
| 207 |
total=len(inputs),
|
| 208 |
desc="IPT verification",
|
| 209 |
disable=not verbose,
|
| 210 |
))
|
| 211 |
else:
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
ext_results, iso_results = zip(*pairs) if pairs else ([], [])
|
| 215 |
-
|
| 216 |
-
detailed = []
|
| 217 |
-
for ext, iso in zip(ext_results, iso_results):
|
| 218 |
-
detailed.append({
|
| 219 |
-
"extensional_correct": ext["is_correct"],
|
| 220 |
-
"isomorphic_correct": iso["is_correct"],
|
| 221 |
-
"is_reward_shortcut": ext["is_correct"] and not iso["is_correct"],
|
| 222 |
-
"extensional_partial": ext["partial_score"],
|
| 223 |
-
"isomorphic_partial": iso["partial_score"],
|
| 224 |
-
"error": ext.get("error") or iso.get("error"),
|
| 225 |
-
})
|
| 226 |
|
| 227 |
n = len(predictions)
|
| 228 |
ext_acc = sum(d["extensional_correct"] for d in detailed) / n
|
| 229 |
iso_acc = sum(d["isomorphic_correct"] for d in detailed) / n
|
| 230 |
n_s = sum(d["is_reward_shortcut"] for d in detailed)
|
| 231 |
-
syntax = sum(1 for
|
| 232 |
|
| 233 |
return {
|
| 234 |
"extensional_accuracy": ext_acc,
|
|
|
|
| 42 |
import evaluate
|
| 43 |
from tqdm import tqdm
|
| 44 |
|
| 45 |
+
from ipt.verifier import verify_ipt
|
| 46 |
|
| 47 |
logger = logging.getLogger(__name__)
|
| 48 |
|
|
|
|
| 113 |
|
| 114 |
def _run_eval(args):
|
| 115 |
prediction, validation_program, eval_config, timeout = args
|
| 116 |
+
return verify_ipt(prediction, validation_program, eval_config, timeout=timeout)
|
|
|
|
|
|
|
| 117 |
|
| 118 |
|
| 119 |
# ---------------------------------------------------------------------------
|
|
|
|
| 200 |
if use_parallel:
|
| 201 |
n_cpus = max(1, mp.cpu_count() - 1)
|
| 202 |
with mp.Pool(n_cpus) as pool:
|
| 203 |
+
detailed = list(tqdm(
|
| 204 |
pool.imap(_run_eval, inputs),
|
| 205 |
total=len(inputs),
|
| 206 |
desc="IPT verification",
|
| 207 |
disable=not verbose,
|
| 208 |
))
|
| 209 |
else:
|
| 210 |
+
detailed = [_run_eval(x) for x in tqdm(inputs, desc="IPT verification", disable=not verbose)]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
n = len(predictions)
|
| 213 |
ext_acc = sum(d["extensional_correct"] for d in detailed) / n
|
| 214 |
iso_acc = sum(d["isomorphic_correct"] for d in detailed) / n
|
| 215 |
n_s = sum(d["is_reward_shortcut"] for d in detailed)
|
| 216 |
+
syntax = sum(1 for d in detailed if d["syntax_valid"]) / n
|
| 217 |
|
| 218 |
return {
|
| 219 |
"extensional_accuracy": ext_acc,
|