# chakravyuh / eval / permutation_test_v1_v2.py
# UjjwalPardeshi
# deploy: latest main to HF Space
# 03815d6
"""Permutation test for the v1 → v2 FPR delta.
Bootstrap CIs (already shipped in `logs/bootstrap_v2.json`) tell you the spread
of v2's metric. They do *not* tell you whether the v1 → v2 FPR drop
(36 % → 6.7 %) is statistically significant. This script answers that with two
complementary tests on the same input data:
1. **Aggregate-counts permutation** (always runs). Treats each model's
bench evaluation as a 2 x 2 contingency over benigns:
| predicted scam | predicted benign |
v1 | a | b |
v2 | c | d |
and computes a permutation p-value on the difference of FPRs by
repeatedly randomly relabelling rows under the null "v1 and v2 came from
the same distribution," over `--n-perm` (default 10 000) iterations.
Also reports the closed-form Fisher exact p-value as a cross-check.
2. **Per-row paired permutation** (runs only when both per-row predictions
exist). Uses the WIN_PLAN D.2 algorithm: for each scenario, randomly
swap the (v1_correct, v2_correct) label assignment, recompute the
mean accuracy delta, and tally how often the absolute delta exceeds
the observed value.
Inputs
------
- `--v1-fp`, `--v1-n`, `--v2-fp`, `--v2-n` (false-positive count and benign sample size for each model; or the defaults below) for test 1.
- `--v1-per-row` and `--v2-per-row` JSONL files for test 2 (optional).
Output
------
JSON written to `--output` (default `logs/permutation_test_v1_v2.json`) with
both p-values, observed deltas, sample sizes, and a one-line interpretation.
Defaults
--------
The aggregate counts are taken from the README claims that we want to
quantify the significance of:
v1: FPR 36 % on n=30 benigns → 11 false positives, 19 true negatives
v2: FPR 6.7 % on n=30 benigns → 2 false positives, 28 true negatives
These come from `logs/eval_v2.json` (v2 side; threshold 0.55) and from the v1
training-time eval cited in `docs/training_diagnostics.md` (the v1 LoRA was not
re-pushed to HF Hub but its bench-time numbers are the ones cited in every
README/slide).
Run:
python eval/permutation_test_v1_v2.py
or:
python eval/permutation_test_v1_v2.py --n-perm 100000 --seed 42
"""
from __future__ import annotations
import argparse
import json
import math
import random
import sys
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
# Default benign-set counts for each model — see module docstring for citation.
# FP = benign scenario flagged as scam (false positive); TN = true negative.
DEFAULT_V1_FP, DEFAULT_V1_TN = 11, 19
DEFAULT_V2_FP, DEFAULT_V2_TN = 2, 28
# ---------------------------------------------------------------------------
# Aggregate-counts permutation (always runs)
# ---------------------------------------------------------------------------
@dataclass
class AggregateResult:
    """Result record produced by the aggregate-counts FPR comparison.

    Field order is part of the constructor interface and must not change.
    """

    # --- inputs, echoed back ---
    v1_fp: int  # v1 false-positive count
    v1_n: int  # v1 sample size (number of benign scenarios)
    v2_fp: int  # v2 false-positive count
    v2_n: int  # v2 sample size (number of benign scenarios)

    # --- derived rates ---
    v1_fpr: float  # v1_fp / v1_n
    v2_fpr: float  # v2_fp / v2_n
    observed_delta: float  # absolute difference between the two FPRs

    # --- test outcomes ---
    n_perm: int  # number of random permutations drawn
    p_value_permutation: float  # share of permutations at least as extreme
    p_value_fisher_exact: float  # closed-form two-sided Fisher exact p-value
def _fisher_exact_two_sided(a: int, b: int, c: int, d: int) -> float:
"""Closed-form two-sided Fisher exact for a 2x2 table.
Computes the probability of seeing a table at least as extreme as
[[a, b], [c, d]] under the null of independence with the same row + column
margins. Uses the standard log-factorial trick to avoid overflow.
"""
n = a + b + c + d
row1 = a + b
row2 = c + d
col1 = a + c
col2 = b + d
def log_choose(n_: int, k_: int) -> float:
if k_ < 0 or k_ > n_:
return -math.inf
return math.lgamma(n_ + 1) - math.lgamma(k_ + 1) - math.lgamma(n_ - k_ + 1)
log_denom = log_choose(n, col1)
observed_log_p = log_choose(row1, a) + log_choose(row2, c) - log_denom
observed_p = math.exp(observed_log_p)
total_p = 0.0
a_min = max(0, col1 - row2)
a_max = min(row1, col1)
for a_alt in range(a_min, a_max + 1):
c_alt = col1 - a_alt
log_p = log_choose(row1, a_alt) + log_choose(row2, c_alt) - log_denom
p = math.exp(log_p)
if p <= observed_p + 1e-12:
total_p += p
return min(1.0, total_p)
def aggregate_permutation_test(
    v1_fp: int,
    v1_n: int,
    v2_fp: int,
    v2_n: int,
    *,
    n_perm: int = 10_000,
    seed: int = 42,
) -> AggregateResult:
    """Permutation test on the FPR delta from aggregate counts.

    Builds one label per benign scenario (1 = false positive, 0 = true
    negative) for each model, pools them, and repeatedly reshuffles the pool.
    Each shuffle is split back into groups of size ``v1_n`` and ``v2_n``, and
    the test counts how often the absolute FPR difference is at least the
    observed one. The Fisher exact p-value is computed as a cross-check.

    Args:
        v1_fp: False-positive count for v1.
        v1_n: Number of benign scenarios evaluated for v1.
        v2_fp: False-positive count for v2.
        v2_n: Number of benign scenarios evaluated for v2.
        n_perm: Number of random permutations to draw.
        seed: Seed for the private random generator, for reproducibility.

    Returns:
        An ``AggregateResult``. ``p_value_permutation`` is the raw fraction
        ``extreme / n_perm`` and can therefore be exactly 0.0.

    Raises:
        ValueError: If a sample size or ``n_perm`` is not positive, or a
            false-positive count falls outside ``[0, sample size]``.
    """
    if v1_n <= 0 or v2_n <= 0:
        raise ValueError("Sample sizes must be positive")
    # Without this check `[0] * negative` silently yields an empty list, the
    # pooled vector gets the wrong length and the p-value is meaningless.
    if not (0 <= v1_fp <= v1_n and 0 <= v2_fp <= v2_n):
        raise ValueError(
            "False-positive counts must be between 0 and the sample size"
        )
    if n_perm <= 0:
        raise ValueError("n_perm must be positive")

    rng = random.Random(seed)

    # Keep the original v1-then-v2 layout so seeded runs reproduce earlier
    # results exactly (the shuffle outcome depends on the starting order).
    pooled = (
        [1] * v1_fp + [0] * (v1_n - v1_fp)
        + [1] * v2_fp + [0] * (v2_n - v2_fp)
    )
    total_fp = v1_fp + v2_fp

    v1_fpr = v1_fp / v1_n
    v2_fpr = v2_fp / v2_n
    observed_delta = abs(v2_fpr - v1_fpr)
    # Small slack so ties with the observed delta count as "at least as extreme"
    # despite floating-point rounding.
    threshold = observed_delta - 1e-12

    extreme = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        perm_v1_fp = sum(pooled[:v1_n])
        # The rest of the pool forms the permuted v2 group.
        perm_v2_fp = total_fp - perm_v1_fp
        if abs(perm_v2_fp / v2_n - perm_v1_fp / v1_n) >= threshold:
            extreme += 1

    return AggregateResult(
        v1_fp=v1_fp,
        v1_n=v1_n,
        v2_fp=v2_fp,
        v2_n=v2_n,
        v1_fpr=v1_fpr,
        v2_fpr=v2_fpr,
        observed_delta=observed_delta,
        n_perm=n_perm,
        p_value_permutation=extreme / n_perm,
        p_value_fisher_exact=_fisher_exact_two_sided(
            v1_fp, v1_n - v1_fp, v2_fp, v2_n - v2_fp
        ),
    )
# ---------------------------------------------------------------------------
# Per-row paired permutation (optional — runs only with --v1-per-row + --v2-per-row)
# ---------------------------------------------------------------------------
@dataclass
class PerRowResult:
    """Result record produced by the per-row paired permutation test.

    Field order is part of the constructor interface and must not change.
    """

    n_paired: int  # number of paired scenarios
    v1_correct_count: int  # sum of the v1 correctness indicators
    v2_correct_count: int  # sum of the v2 correctness indicators
    observed_delta: float  # absolute difference in mean correctness
    n_perm: int  # number of random sign-flip permutations drawn
    p_value_permutation: float  # share of permutations at least as extreme
def per_row_paired_permutation(
    v1_correct: list[int],
    v2_correct: list[int],
    *,
    n_perm: int = 10_000,
    seed: int = 42,
) -> PerRowResult:
    """Paired-sample permutation test on per-scenario correctness.

    For every permutation, each scenario's (v1, v2) pair is swapped with
    probability 0.5, which flips the sign of that scenario's difference. The
    test counts how often the absolute mean difference is at least the
    observed one.

    Args:
        v1_correct: Per-scenario correctness indicators for v1.
        v2_correct: Per-scenario correctness indicators for v2, in the same
            scenario order as ``v1_correct``.
        n_perm: Number of random permutations to draw.
        seed: Seed for the private random generator, for reproducibility.

    Returns:
        A ``PerRowResult``. ``p_value_permutation`` is the raw fraction
        ``extreme / n_perm`` and can therefore be exactly 0.0.

    Raises:
        ValueError: If the vectors differ in length or are empty, or if
            ``n_perm`` is not positive.
    """
    if len(v1_correct) != len(v2_correct):
        raise ValueError("v1 and v2 per-row vectors must be the same length")
    n = len(v1_correct)
    if n == 0:
        raise ValueError("Empty input vectors")
    # Previously a non-positive n_perm surfaced as ZeroDivisionError below.
    if n_perm <= 0:
        raise ValueError("n_perm must be positive")

    v1_total = sum(v1_correct)
    v2_total = sum(v2_correct)
    observed_delta = abs(v2_total / n - v1_total / n)
    # Small slack so ties with the observed delta count as "at least as extreme"
    # despite floating-point rounding.
    threshold = observed_delta - 1e-12

    # The per-pair differences never change; only their signs are permuted.
    paired_diffs = [v2_i - v1_i for v1_i, v2_i in zip(v1_correct, v2_correct)]

    rng = random.Random(seed)
    extreme = 0
    for _ in range(n_perm):
        # One draw per pair, in order, so seeded runs reproduce earlier results.
        signed_sum = sum(
            diff if rng.random() < 0.5 else -diff for diff in paired_diffs
        )
        if abs(signed_sum / n) >= threshold:
            extreme += 1

    return PerRowResult(
        n_paired=n,
        v1_correct_count=v1_total,
        v2_correct_count=v2_total,
        observed_delta=observed_delta,
        n_perm=n_perm,
        p_value_permutation=extreme / n_perm,
    )
def _load_per_row_correct(path: Path) -> list[int]:
rows = []
with path.open("r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
row = json.loads(line)
rows.append(int(row["predicted"] == row["ground_truth"]))
return rows
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def _interpret(p: float) -> str:
if p < 1e-4:
return "extremely significant (p < 0.0001)"
if p < 1e-3:
return f"highly significant (p = {p:.4g})"
if p < 0.05:
return f"significant (p = {p:.4g})"
return f"not significant at alpha=0.05 (p = {p:.4g})"
def main() -> int:
    """Command-line entry point.

    Parses the CLI flags, always runs the aggregate-counts test, runs the
    per-row paired test when both per-row files are usable, then writes the
    JSON payload to ``--output`` and echoes it to stdout.

    Whenever the per-row test does not run, ``per_row_paired_test`` is set to
    ``None`` and the reason is stored under
    ``meta.per_row_test_skipped_reason``, so the output schema is the same on
    every skip path.

    Returns:
        0, for use as the process exit status.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--v1-fp", type=int, default=DEFAULT_V1_FP)
    parser.add_argument("--v1-n", type=int, default=DEFAULT_V1_FP + DEFAULT_V1_TN)
    parser.add_argument("--v2-fp", type=int, default=DEFAULT_V2_FP)
    parser.add_argument("--v2-n", type=int, default=DEFAULT_V2_FP + DEFAULT_V2_TN)
    parser.add_argument(
        "--v1-per-row",
        type=Path,
        default=None,
        help="Optional path to v1 per-row JSONL (one obj per scenario with `predicted`+`ground_truth`).",
    )
    parser.add_argument(
        "--v2-per-row",
        type=Path,
        default=None,
        help="Optional path to v2 per-row JSONL (B.12 output).",
    )
    parser.add_argument("--n-perm", type=int, default=10_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("logs/permutation_test_v1_v2.json"),
    )
    args = parser.parse_args()

    payload: dict[str, Any] = {
        "meta": {
            "script": "eval/permutation_test_v1_v2.py",
            "n_perm": args.n_perm,
            "seed": args.seed,
            "interpretation_threshold": "alpha=0.05",
            "tests_run": [],
        },
    }

    # Test 1 — aggregate counts (always runs)
    agg = aggregate_permutation_test(
        v1_fp=args.v1_fp,
        v1_n=args.v1_n,
        v2_fp=args.v2_fp,
        v2_n=args.v2_n,
        n_perm=args.n_perm,
        seed=args.seed,
    )
    payload["aggregate_fpr_test"] = {
        "input": {
            "v1_fp": agg.v1_fp,
            "v1_n": agg.v1_n,
            "v1_fpr": round(agg.v1_fpr, 4),
            "v2_fp": agg.v2_fp,
            "v2_n": agg.v2_n,
            "v2_fpr": round(agg.v2_fpr, 4),
        },
        "observed_fpr_delta_abs": round(agg.observed_delta, 4),
        "p_value_permutation": agg.p_value_permutation,
        "p_value_fisher_exact": agg.p_value_fisher_exact,
        "interpretation_permutation": _interpret(agg.p_value_permutation),
        "interpretation_fisher_exact": _interpret(agg.p_value_fisher_exact),
    }
    payload["meta"]["tests_run"].append("aggregate_fpr_test")

    # Test 2 — per-row paired (optional). Work out whether it can run; every
    # reason for skipping is recorded in the payload, not only on stderr.
    skip_reason: str | None = None
    v1_correct: list[int] = []
    v2_correct: list[int] = []
    if args.v1_per_row is None and args.v2_per_row is None:
        skip_reason = (
            "B.12 per-row outputs not yet shipped (logs/eval_v2_per_row.jsonl + logs/eval_v1_per_row.jsonl). "
            "When B.12 produces them, re-run with --v1-per-row and --v2-per-row flags."
        )
    elif args.v1_per_row is None or args.v2_per_row is None:
        skip_reason = (
            "only one of --v1-per-row / --v2-per-row was given; "
            "the paired test needs both"
        )
        print(f"[warn] {skip_reason}; skipping per-row test", file=sys.stderr)
    elif not args.v1_per_row.exists():
        skip_reason = f"--v1-per-row {args.v1_per_row} not found"
        print(f"[warn] --v1-per-row {args.v1_per_row} not found; skipping per-row test", file=sys.stderr)
    elif not args.v2_per_row.exists():
        skip_reason = f"--v2-per-row {args.v2_per_row} not found"
        print(f"[warn] --v2-per-row {args.v2_per_row} not found; skipping per-row test", file=sys.stderr)
    else:
        v1_correct = _load_per_row_correct(args.v1_per_row)
        v2_correct = _load_per_row_correct(args.v2_per_row)
        if len(v1_correct) != len(v2_correct):
            skip_reason = (
                f"paired per-row lengths differ ({len(v1_correct)} vs {len(v2_correct)})"
            )
            print(
                f"[warn] paired per-row lengths differ ({len(v1_correct)} vs {len(v2_correct)}); skipping",
                file=sys.stderr,
            )
        elif not v1_correct:
            # Empty files would otherwise make the paired test raise ValueError.
            skip_reason = "per-row files contain no rows"
            print(f"[warn] {skip_reason}; skipping per-row test", file=sys.stderr)

    if skip_reason is None:
        pr = per_row_paired_permutation(
            v1_correct,
            v2_correct,
            n_perm=args.n_perm,
            seed=args.seed,
        )
        payload["per_row_paired_test"] = {
            "input": {
                "n_paired": pr.n_paired,
                "v1_correct": pr.v1_correct_count,
                "v2_correct": pr.v2_correct_count,
            },
            "observed_accuracy_delta_abs": round(pr.observed_delta, 4),
            "p_value_permutation": pr.p_value_permutation,
            "interpretation_permutation": _interpret(pr.p_value_permutation),
        }
        payload["meta"]["tests_run"].append("per_row_paired_test")
    else:
        payload["per_row_paired_test"] = None
        payload["meta"]["per_row_test_skipped_reason"] = skip_reason

    rendered = json.dumps(payload, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(rendered + "\n", encoding="utf-8")
    print(rendered)
    return 0
# Script entry point: hand main()'s return value to the interpreter as the exit status.
if __name__ == "__main__":
    sys.exit(main())