Spaces:
Running
Running
"""Permutation test for the v1 → v2 FPR delta.

Bootstrap CIs (already shipped in `logs/bootstrap_v2.json`) tell you the spread
of v2's metric. They do *not* tell you whether the v1 → v2 FPR drop
(36 % → 6.7 %) is statistically significant. This script answers that with two
complementary tests on the same input data:

1. **Aggregate-counts permutation** (always runs). Treats each model's
   bench evaluation as a 2 x 2 contingency table over benigns:

              | predicted scam | predicted benign |
         v1   |       a        |        b         |
         v2   |       c        |        d         |

   so FPR(v1) = a / (a + b) and FPR(v2) = c / (c + d). It computes a
   permutation p-value on the difference of FPRs by repeatedly reassigning
   the pooled benign outcomes at random to the two models, under the null
   "v1 and v2 came from the same distribution," over `--n-perm`
   (default 10 000) iterations. Also reports the closed-form Fisher exact
   p-value as a cross-check.

2. **Per-row paired permutation** (runs only when both per-row prediction
   files are supplied, exist, and have the same number of rows). Uses the
   WIN_PLAN D.2 algorithm: for each scenario, randomly swap the
   (v1_correct, v2_correct) label assignment, recompute the mean accuracy
   delta, and tally how often the absolute delta is at least as large as
   the observed value.
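
Inputs
------
- `--v1-fp`/`--v1-n` and `--v2-fp`/`--v2-n` (false-positive count and number
  of benigns for each model; see the defaults below) for test 1.
- `--v1-per-row` and `--v2-per-row` JSONL files for test 2 (optional): one
  JSON object per scenario with `predicted` and `ground_truth` keys. Rows are
  paired by position, so both files must list scenarios in the same order.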

Output
------
JSON written to `--output` (default `logs/permutation_test_v1_v2.json`) and
echoed to stdout, with both p-values, observed deltas, sample sizes, and a
one-line interpretation of each p-value.

Defaults
--------
The aggregate counts are taken from the README claims whose significance we
want to quantify:

    v1: FPR 36 % on n=30 benigns → 11 false positives, 19 true negatives
        (11 / 30 = 36.7 %, quoted as 36 %)
    v2: FPR 6.7 % on n=30 benigns → 2 false positives, 28 true negatives

These come from `logs/eval_v2.json` (v2 side; threshold 0.55) and from the v1
training-time eval cited in `docs/training_diagnostics.md` (the v1 LoRA was not
re-pushed to HF Hub, but its bench-time numbers are the ones cited in every
README/slide).

Run:

    python eval/permutation_test_v1_v2.py

or:

    python eval/permutation_test_v1_v2.py --n-perm 100000 --seed 42
"""
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import math | |
| import random | |
| import sys | |
| from dataclasses import dataclass, field | |
| from pathlib import Path | |
| from typing import Any | |
# Default 2x2 counts over the n=30 benign scenarios quoted in the module
# docstring (FP = benign predicted scam, TN = benign predicted benign).
DEFAULT_V1_FP, DEFAULT_V1_TN = 11, 19  # v1: 11 / 30 ≈ 36.7 % FPR
DEFAULT_V2_FP, DEFAULT_V2_TN = 2, 28   # v2:  2 / 30 ≈ 6.7 % FPR
# ---------------------------------------------------------------------------
# Aggregate-counts permutation (always runs)
# ---------------------------------------------------------------------------


@dataclass
class AggregateResult:
    """Result record returned by ``aggregate_permutation_test``.

    ``v1_fp``/``v2_fp`` are false-positive counts, ``v1_n``/``v2_n`` the number
    of benign scenarios each model was scored on, and ``v1_fpr``/``v2_fpr``
    their ratios. Declared as a dataclass so it can be built with keyword
    arguments.
    """

    v1_fp: int
    v1_n: int
    v2_fp: int
    v2_n: int
    v1_fpr: float
    v2_fpr: float
    observed_delta: float  # |v2_fpr - v1_fpr|
    n_perm: int  # number of permutation iterations run
    p_value_permutation: float
    p_value_fisher_exact: float
| def _fisher_exact_two_sided(a: int, b: int, c: int, d: int) -> float: | |
| """Closed-form two-sided Fisher exact for a 2x2 table. | |
| Computes the probability of seeing a table at least as extreme as | |
| [[a, b], [c, d]] under the null of independence with the same row + column | |
| margins. Uses the standard log-factorial trick to avoid overflow. | |
| """ | |
| n = a + b + c + d | |
| row1 = a + b | |
| row2 = c + d | |
| col1 = a + c | |
| col2 = b + d | |
| def log_choose(n_: int, k_: int) -> float: | |
| if k_ < 0 or k_ > n_: | |
| return -math.inf | |
| return math.lgamma(n_ + 1) - math.lgamma(k_ + 1) - math.lgamma(n_ - k_ + 1) | |
| log_denom = log_choose(n, col1) | |
| observed_log_p = log_choose(row1, a) + log_choose(row2, c) - log_denom | |
| observed_p = math.exp(observed_log_p) | |
| total_p = 0.0 | |
| a_min = max(0, col1 - row2) | |
| a_max = min(row1, col1) | |
| for a_alt in range(a_min, a_max + 1): | |
| c_alt = col1 - a_alt | |
| log_p = log_choose(row1, a_alt) + log_choose(row2, c_alt) - log_denom | |
| p = math.exp(log_p) | |
| if p <= observed_p + 1e-12: | |
| total_p += p | |
| return min(1.0, total_p) | |
def aggregate_permutation_test(
    v1_fp: int,
    v1_n: int,
    v2_fp: int,
    v2_n: int,
    *,
    n_perm: int = 10_000,
    seed: int = 42,
) -> AggregateResult:
    """Permutation test on the FPR delta from aggregate counts.

    Builds the pooled label vector (1 = false positive, 0 = true negative) for
    v1 followed by v2. Then, ``n_perm`` times, it shuffles the vector, splits
    it back into groups of ``v1_n`` and ``v2_n``, and counts how often the
    absolute FPR delta is at least the observed one. The two-sided Fisher
    exact p-value for the same 2x2 table is returned as a cross-check.

    Args:
        v1_fp, v2_fp: false-positive counts for v1 / v2.
        v1_n, v2_n: number of benign scenarios each model was scored on.
        n_perm: number of random permutations; must be positive.
        seed: seed for the ``random.Random`` instance used for shuffling.

    Returns:
        An ``AggregateResult``. ``p_value_permutation`` is
        ``extreme / n_perm`` (no +1 correction), so it can be exactly 0.

    Raises:
        ValueError: if a sample size or ``n_perm`` is not positive, or a
            false-positive count lies outside ``[0, n]`` for its model.
    """
    if v1_n <= 0 or v2_n <= 0:
        raise ValueError("Sample sizes must be positive")
    if not 0 <= v1_fp <= v1_n or not 0 <= v2_fp <= v2_n:
        # Otherwise `[0] * negative` silently builds a wrong-length vector.
        raise ValueError("False-positive counts must lie in [0, n] for each model")
    if n_perm <= 0:
        raise ValueError("n_perm must be positive")
    rng = random.Random(seed)
    pooled = [1] * v1_fp + [0] * (v1_n - v1_fp) + [1] * v2_fp + [0] * (v2_n - v2_fp)
    total_fp = v1_fp + v2_fp  # invariant under shuffling
    v1_fpr = v1_fp / v1_n
    v2_fpr = v2_fp / v2_n
    observed_delta = abs(v2_fpr - v1_fpr)

    extreme = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        perm_v1_fp = sum(pooled[:v1_n])
        # The v2 group is the remainder, so its FP count needs no second pass.
        delta = abs((total_fp - perm_v1_fp) / v2_n - perm_v1_fp / v1_n)
        # Small slack so float round-off never drops a tie with the observed delta.
        if delta >= observed_delta - 1e-12:
            extreme += 1
    p_perm = extreme / n_perm
    p_fisher = _fisher_exact_two_sided(v1_fp, v1_n - v1_fp, v2_fp, v2_n - v2_fp)
    return AggregateResult(
        v1_fp=v1_fp,
        v1_n=v1_n,
        v2_fp=v2_fp,
        v2_n=v2_n,
        v1_fpr=v1_fpr,
        v2_fpr=v2_fpr,
        observed_delta=observed_delta,
        n_perm=n_perm,
        p_value_permutation=p_perm,
        p_value_fisher_exact=p_fisher,
    )
# ---------------------------------------------------------------------------
# Per-row paired permutation (optional — runs only with --v1-per-row + --v2-per-row)
# ---------------------------------------------------------------------------


@dataclass
class PerRowResult:
    """Result record returned by ``per_row_paired_permutation``.

    Declared as a dataclass so it can be built with keyword arguments.
    """

    n_paired: int  # number of position-matched scenario pairs
    v1_correct_count: int  # sum of v1's per-row correctness flags
    v2_correct_count: int  # sum of v2's per-row correctness flags
    observed_delta: float  # |mean(v2_correct) - mean(v1_correct)|
    n_perm: int  # number of sign-flip permutations run
    p_value_permutation: float
def per_row_paired_permutation(
    v1_correct: list[int],
    v2_correct: list[int],
    *,
    n_perm: int = 10_000,
    seed: int = 42,
) -> PerRowResult:
    """Paired sign-flip permutation test on per-scenario correctness.

    ``v1_correct[i]`` and ``v2_correct[i]`` are the two models' correctness
    values for scenario ``i``; pairs are matched by position. Under the null,
    the two labels of a pair are exchangeable. Each permutation therefore
    flips the sign of every per-pair difference with probability 1/2 and
    checks whether the absolute mean difference is at least the observed one.

    Args:
        v1_correct, v2_correct: equal-length per-row correctness vectors
            (presumably 0/1 flags).
        n_perm: number of random sign-flip permutations; must be positive.
        seed: seed for the ``random.Random`` instance that draws the flips.

    Returns:
        A ``PerRowResult`` whose ``p_value_permutation`` is
        ``extreme / n_perm``.

    Raises:
        ValueError: if the vectors differ in length, are empty, or ``n_perm``
            is not positive.
    """
    if len(v1_correct) != len(v2_correct):
        raise ValueError("v1 and v2 per-row vectors must be the same length")
    n = len(v1_correct)
    if n == 0:
        raise ValueError("Empty input vectors")
    if n_perm <= 0:
        raise ValueError("n_perm must be positive")
    observed_delta = abs(
        sum(v2_correct) / n - sum(v1_correct) / n
    )
    # The per-pair differences are fixed; only their signs are permuted.
    pair_diffs = [v2_i - v1_i for v1_i, v2_i in zip(v1_correct, v2_correct)]
    rng = random.Random(seed)
    extreme = 0
    for _ in range(n_perm):
        # One rng.random() draw per pair, in pair order, keeps runs reproducible per seed.
        flipped_sum = sum(d if rng.random() < 0.5 else -d for d in pair_diffs)
        if abs(flipped_sum / n) >= observed_delta - 1e-12:
            extreme += 1
    p_perm = extreme / n_perm
    return PerRowResult(
        n_paired=n,
        v1_correct_count=sum(v1_correct),
        v2_correct_count=sum(v2_correct),
        observed_delta=observed_delta,
        n_perm=n_perm,
        p_value_permutation=p_perm,
    )
| def _load_per_row_correct(path: Path) -> list[int]: | |
| rows = [] | |
| with path.open("r", encoding="utf-8") as f: | |
| for line in f: | |
| line = line.strip() | |
| if not line: | |
| continue | |
| row = json.loads(line) | |
| rows.append(int(row["predicted"] == row["ground_truth"])) | |
| return rows | |
| # --------------------------------------------------------------------------- | |
| # CLI | |
| # --------------------------------------------------------------------------- | |
| def _interpret(p: float) -> str: | |
| if p < 1e-4: | |
| return "extremely significant (p < 0.0001)" | |
| if p < 1e-3: | |
| return f"highly significant (p = {p:.4g})" | |
| if p < 0.05: | |
| return f"significant (p = {p:.4g})" | |
| return f"not significant at alpha=0.05 (p = {p:.4g})" | |
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the CLI parser; the module docstring is shown verbatim in --help."""
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the docstring's table, lists and code samples intact in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--v1-fp", type=int, default=DEFAULT_V1_FP)
    parser.add_argument("--v1-n", type=int, default=DEFAULT_V1_FP + DEFAULT_V1_TN)
    parser.add_argument("--v2-fp", type=int, default=DEFAULT_V2_FP)
    parser.add_argument("--v2-n", type=int, default=DEFAULT_V2_FP + DEFAULT_V2_TN)
    parser.add_argument(
        "--v1-per-row",
        type=Path,
        default=None,
        help="Optional path to v1 per-row JSONL (one obj per scenario with `predicted`+`ground_truth`).",
    )
    parser.add_argument(
        "--v2-per-row",
        type=Path,
        default=None,
        help="Optional path to v2 per-row JSONL (B.12 output).",
    )
    parser.add_argument("--n-perm", type=int, default=10_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("logs/permutation_test_v1_v2.json"),
    )
    return parser


def _aggregate_section(agg: AggregateResult) -> dict[str, Any]:
    """Render an aggregate-test result as the ``aggregate_fpr_test`` report entry."""
    return {
        "input": {
            "v1_fp": agg.v1_fp,
            "v1_n": agg.v1_n,
            "v1_fpr": round(agg.v1_fpr, 4),
            "v2_fp": agg.v2_fp,
            "v2_n": agg.v2_n,
            "v2_fpr": round(agg.v2_fpr, 4),
        },
        "observed_fpr_delta_abs": round(agg.observed_delta, 4),
        "p_value_permutation": agg.p_value_permutation,
        "p_value_fisher_exact": agg.p_value_fisher_exact,
        "interpretation_permutation": _interpret(agg.p_value_permutation),
        "interpretation_fisher_exact": _interpret(agg.p_value_fisher_exact),
    }


def _per_row_section(
    args: argparse.Namespace,
) -> tuple[dict[str, Any] | None, str | None]:
    """Run the optional per-row paired test described by the parsed CLI args.

    Returns:
        ``(section, None)`` when the test ran, where ``section`` is the
        ``per_row_paired_test`` report entry; otherwise ``(None, reason)``
        with a human-readable explanation of why it was skipped.
    """
    v1_path, v2_path = args.v1_per_row, args.v2_per_row
    if v1_path is None and v2_path is None:
        return None, (
            "B.12 per-row outputs not yet shipped (logs/eval_v2_per_row.jsonl + logs/eval_v1_per_row.jsonl). "
            "When B.12 produces them, re-run with --v1-per-row and --v2-per-row flags."
        )
    if v1_path is None or v2_path is None:
        reason = "only one of --v1-per-row / --v2-per-row was given; the paired test needs both"
        print(f"[warn] {reason}; skipping per-row test", file=sys.stderr)
        return None, reason
    for flag, path in (("--v1-per-row", v1_path), ("--v2-per-row", v2_path)):
        if not path.exists():
            print(f"[warn] {flag} {path} not found; skipping per-row test", file=sys.stderr)
            return None, f"{flag} {path} not found"
    v1_correct = _load_per_row_correct(v1_path)
    v2_correct = _load_per_row_correct(v2_path)
    if len(v1_correct) != len(v2_correct):
        mismatch = f"paired per-row lengths differ ({len(v1_correct)} vs {len(v2_correct)})"
        print(f"[warn] {mismatch}; skipping", file=sys.stderr)
        return None, mismatch
    pr = per_row_paired_permutation(
        v1_correct,
        v2_correct,
        n_perm=args.n_perm,
        seed=args.seed,
    )
    section = {
        "input": {
            "n_paired": pr.n_paired,
            "v1_correct": pr.v1_correct_count,
            "v2_correct": pr.v2_correct_count,
        },
        "observed_accuracy_delta_abs": round(pr.observed_delta, 4),
        "p_value_permutation": pr.p_value_permutation,
        "interpretation_permutation": _interpret(pr.p_value_permutation),
    }
    return section, None


def main() -> int:
    """CLI entry point: run both tests, write the JSON report, echo it to stdout.

    The aggregate FPR test always runs. The per-row paired test runs only when
    both per-row files are given, exist, and have equal lengths. Whenever it
    is skipped, the report stores ``per_row_paired_test: null`` and a
    ``meta.per_row_test_skipped_reason``, so the JSON explains itself even
    when the stderr warnings are lost.

    Returns:
        0 on success. argparse exits with status 2 on invalid arguments.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.n_perm <= 0:
        parser.error("--n-perm must be a positive integer")

    payload: dict[str, Any] = {
        "meta": {
            "script": "eval/permutation_test_v1_v2.py",
            "n_perm": args.n_perm,
            "seed": args.seed,
            "interpretation_threshold": "alpha=0.05",
            "tests_run": [],
        },
    }

    # Test 1 — aggregate counts (always runs). Bad counts become a CLI usage
    # error instead of a traceback.
    try:
        agg = aggregate_permutation_test(
            v1_fp=args.v1_fp,
            v1_n=args.v1_n,
            v2_fp=args.v2_fp,
            v2_n=args.v2_n,
            n_perm=args.n_perm,
            seed=args.seed,
        )
    except ValueError as exc:
        parser.error(str(exc))
    payload["aggregate_fpr_test"] = _aggregate_section(agg)
    payload["meta"]["tests_run"].append("aggregate_fpr_test")

    # Test 2 — per-row paired (optional).
    per_row, skipped_reason = _per_row_section(args)
    payload["per_row_paired_test"] = per_row
    if per_row is None:
        payload["meta"]["per_row_test_skipped_reason"] = skipped_reason
    else:
        payload["meta"]["tests_run"].append("per_row_paired_test")

    report = json.dumps(payload, indent=2)
    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(report + "\n", encoding="utf-8")
    print(report)
    return 0
if __name__ == "__main__":
    # Hand main()'s integer status to the interpreter as the exit code.
    sys.exit(main())