"""Permutation test for the v1 → v2 FPR delta.

Bootstrap CIs (already shipped in `logs/bootstrap_v2.json`) tell you the spread
of v2's metric. They do *not* tell you whether the v1 → v2 FPR drop
(36 % → 6.7 %) is statistically significant. This script answers that with two
complementary tests:

  1. **Aggregate-counts permutation** (always runs). Treats each model's
     bench evaluation as a 2 x 2 contingency over benigns:
         | predicted scam | predicted benign |
       v1 |     a          |      b           |
       v2 |     c          |      d           |
     and computes a permutation p-value on the absolute FPR difference by
     pooling both models' per-benign outcomes (1 = false positive) and
     repeatedly reshuffling which model each outcome belongs to, under the
     null "v1 and v2 came from the same distribution," over `--n-perm`
     (default 10 000) iterations. Also reports the closed-form two-sided
     Fisher exact p-value as a cross-check.

  2. **Per-row paired permutation** (runs only when per-row prediction files
     for both models are supplied). Uses the WIN_PLAN D.2 algorithm: for each
     scenario, randomly swap the (v1_correct, v2_correct) label assignment,
     recompute the mean accuracy delta over all paired scenarios, and tally
     how often the absolute delta is at least as large as the observed value.
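
Below is a minimal standalone sketch of the two resampling schemes on 0/1
outcome lists, assuming Python 3.9+. The function names `pooled_shuffle_p`
and `sign_flip_p` are hypothetical and are not this script's API. The sketch
returns the add-one estimate (hits + 1) / (n_perm + 1), so a Monte Carlo
p-value is never reported as exactly 0:

```python
import random


def pooled_shuffle_p(x: list[int], y: list[int], n_perm: int, seed: int = 0) -> float:
    # Unpaired test (scheme 1): pool both groups' outcomes, reshuffle which
    # group each outcome belongs to, and compare absolute mean differences.
    rng = random.Random(seed)
    observed = abs(sum(y) / len(y) - sum(x) / len(x))
    pooled = x + y  # new list, so shuffling never mutates the inputs
    hits = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        px, py = pooled[: len(x)], pooled[len(x):]
        if abs(sum(py) / len(py) - sum(px) / len(px)) >= observed - 1e-12:
            hits += 1
    return (hits + 1) / (n_perm + 1)


def sign_flip_p(x: list[int], y: list[int], n_perm: int, seed: int = 0) -> float:
    # Paired test (scheme 2): swapping a pair's (x_i, y_i) labels is the same
    # as flipping the sign of its difference y_i - x_i.
    if len(x) != len(y):
        raise ValueError("paired inputs must have equal length")
    rng = random.Random(seed)
    diffs = [b - a for a, b in zip(x, y)]
    observed = abs(sum(diffs) / len(diffs))
    hits = 0
    for _ in range(n_perm):
        flipped = [d if rng.random() < 0.5 else -d for d in diffs]
        if abs(sum(flipped) / len(flipped)) >= observed - 1e-12:
            hits += 1
    return (hits + 1) / (n_perm + 1)


# Toy data shaped like the defaults: 11/30 vs 2/30 false positives.
v1 = [1] * 11 + [0] * 19
v2 = [1] * 2 + [0] * 28
print(pooled_shuffle_p(v1, v2, n_perm=5_000, seed=42))
```

Scheme 1 ignores any pairing between scenarios. Scheme 2 only exchanges labels
within a pair, which is why the per-row test needs both files to list the
scenarios in the same order.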

Inputs
------
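- `--v1-fp`/`--v1-n` and `--v2-fp`/`--v2-n` (false positives and number of
  benigns per model; defaults below) for test 1.
- `--v1-per-row` and `--v2-per-row` JSONL files for test 2 (optional; one
  object per scenario with `predicted` and `ground_truth`, listed in the same
  scenario order in both files).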

Output
------
JSON written to `--output` (default `logs/permutation_test_v1_v2.json`) with
every p-value computed, the observed deltas, sample sizes, and a one-line
interpretation per p-value. The same JSON is also printed to stdout.
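
Below is a minimal sketch of consuming that file downstream, assuming Python
3.9+. It relies only on keys this script writes (`aggregate_fpr_test`,
`per_row_paired_test`, `p_value_permutation`, `p_value_fisher_exact`,
`interpretation_permutation`, `per_row_test_skipped_reason`). The helper name
`summarize` is hypothetical:

```python
import json
from pathlib import Path
from typing import Any


def summarize(payload: dict[str, Any]) -> list[str]:
    # One line per test; the per-row entry is None (or absent) when skipped.
    lines = []
    agg = payload["aggregate_fpr_test"]
    lines.append(
        f"aggregate FPR: perm p = {agg['p_value_permutation']}, "
        f"Fisher p = {agg['p_value_fisher_exact']}"
    )
    per_row = payload.get("per_row_paired_test")
    if per_row is None:
        reason = payload["meta"].get("per_row_test_skipped_reason", "not run")
        lines.append(f"per-row paired: skipped ({reason})")
    else:
        lines.append(f"per-row paired: {per_row['interpretation_permutation']}")
    return lines


path = Path("logs/permutation_test_v1_v2.json")
if path.exists():
    for line in summarize(json.loads(path.read_text(encoding="utf-8"))):
        print(line)
```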

Defaults
--------
The aggregate counts are taken from the README claims that we want to
quantify the significance of:
  v1: FPR 36 % on n=30 benigns  → 11 false positives, 19 true negatives
  v2: FPR 6.7 % on n=30 benigns →  2 false positives, 28 true negatives

These come from `logs/eval_v2.json` (v2 side; threshold 0.55) and from the v1
training-time eval cited in `docs/training_diagnostics.md` (the v1 LoRA was not
re-pushed to the HF Hub, but its bench-time numbers are the ones cited in every
README and slide).
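
The default counts and the closed-form cross-check can be reproduced
independently of this script. The sketch below uses only the standard library
and assumes Python 3.8+ for `math.comb`. The helper name
`fisher_two_sided_exact` is hypothetical. It uses exact `Fraction` arithmetic
and the same two-sided rule as `_fisher_exact_two_sided`: sum the
hypergeometric probabilities of every table with the observed margins that is
no more probable than the observed table.

```python
from fractions import Fraction
from math import comb


def fisher_two_sided_exact(a: int, b: int, c: int, d: int) -> Fraction:
    # Table [[a, b], [c, d]]; rows are models, columns are FP / TN.
    row1, row2, col1 = a + b, c + d, a + c
    denom = comb(row1 + row2, col1)

    def prob(k: int) -> Fraction:
        # Hypergeometric probability that the top-left cell equals k
        # when all four margins are held fixed.
        return Fraction(comb(row1, k) * comb(row2, col1 - k), denom)

    observed = prob(a)
    lo, hi = max(0, col1 - row2), min(row1, col1)
    probs = [prob(k) for k in range(lo, hi + 1)]
    # Exact rationals, so ties with the observed table need no tolerance.
    return sum((p for p in probs if p <= observed), Fraction(0))


v1_fp, v1_n, v2_fp, v2_n = 11, 30, 2, 30
print(f"v1 FPR = {v1_fp / v1_n:.4f}, v2 FPR = {v2_fp / v2_n:.4f}")
# prints: v1 FPR = 0.3667, v2 FPR = 0.0667
p = fisher_two_sided_exact(v1_fp, v1_n - v1_fp, v2_fp, v2_n - v2_fp)
print(f"two-sided Fisher exact p = {float(p):.4g}")
```

The first line recovers the two FPRs (11/30 and 2/30). The second line gives an
independent check on `p_value_fisher_exact` in the JSON output.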

Run:
    python eval/permutation_test_v1_v2.py
or:
    python eval/permutation_test_v1_v2.py --n-perm 100000 --seed 42
"""

from __future__ import annotations

import argparse
import json
import math
import random
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Any

# Defaults: see the module docstring for their source.
DEFAULT_V1_FP = 11
DEFAULT_V1_TN = 19
DEFAULT_V2_FP = 2
DEFAULT_V2_TN = 28


# ---------------------------------------------------------------------------
# Aggregate-counts permutation (always runs)
# ---------------------------------------------------------------------------

@dataclass
class AggregateResult:
    v1_fp: int
    v1_n: int
    v2_fp: int
    v2_n: int
    v1_fpr: float
    v2_fpr: float
    observed_delta: float
    n_perm: int
    p_value_permutation: float
    p_value_fisher_exact: float


def _fisher_exact_two_sided(a: int, b: int, c: int, d: int) -> float:
    """Closed-form two-sided Fisher exact for a 2x2 table.

    Sums the hypergeometric probabilities of every table with the same row
    and column margins as [[a, b], [c, d]] that is no more probable than the
    observed table (the usual two-sided definition). Binomial coefficients are
    evaluated in log space via `math.lgamma` to avoid overflow.
    """
    n = a + b + c + d
    row1 = a + b
    row2 = c + d
    col1 = a + c

    def log_choose(n_: int, k_: int) -> float:
        if k_ < 0 or k_ > n_:
            return -math.inf
        return math.lgamma(n_ + 1) - math.lgamma(k_ + 1) - math.lgamma(n_ - k_ + 1)

    log_denom = log_choose(n, col1)
    observed_log_p = log_choose(row1, a) + log_choose(row2, c) - log_denom
    observed_p = math.exp(observed_log_p)

    total_p = 0.0
    a_min = max(0, col1 - row2)
    a_max = min(row1, col1)
    for a_alt in range(a_min, a_max + 1):
        c_alt = col1 - a_alt
        log_p = log_choose(row1, a_alt) + log_choose(row2, c_alt) - log_denom
        p = math.exp(log_p)
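        # Relative tolerance: keeps tables tied with the observed one despite
        # float rounding, without over-counting when observed_p itself is tiny.
        if p <= observed_p * (1 + 1e-7):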
            total_p += p
    return min(1.0, total_p)


def aggregate_permutation_test(
    v1_fp: int,
    v1_n: int,
    v2_fp: int,
    v2_n: int,
    *,
    n_perm: int = 10_000,
    seed: int = 42,
) -> AggregateResult:
    """Permutation test on the FPR delta from aggregate counts.

    Builds the per-benign outcome vector (1 = false positive, 0 = true
    negative) for v1 and v2, pools them, and under the null repeatedly
    reshuffles which model each outcome belongs to, counting how often the
    absolute FPR delta is at least as large as the observed one.
    """
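    if v1_n <= 0 or v2_n <= 0:
        raise ValueError("Sample sizes must be positive")
    if not (0 <= v1_fp <= v1_n and 0 <= v2_fp <= v2_n):
        raise ValueError("False-positive counts must lie in [0, n]")
    if n_perm <= 0:
        raise ValueError("n_perm must be positive")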

    rng = random.Random(seed)
    v1_labels = [1] * v1_fp + [0] * (v1_n - v1_fp)
    v2_labels = [1] * v2_fp + [0] * (v2_n - v2_fp)
    pooled = v1_labels + v2_labels

    v1_fpr = v1_fp / v1_n
    v2_fpr = v2_fp / v2_n
    observed_delta = abs(v2_fpr - v1_fpr)

    extreme = 0
    for _ in range(n_perm):
        rng.shuffle(pooled)
        perm_v1 = pooled[:v1_n]
        perm_v2 = pooled[v1_n:]
        delta = abs(sum(perm_v2) / v2_n - sum(perm_v1) / v1_n)
        if delta >= observed_delta - 1e-12:
            extreme += 1
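    # Add-one estimate: the observed labelling counts as one permutation, so a
    # Monte Carlo p-value is never reported as exactly 0.
    p_perm = (extreme + 1) / (n_perm + 1)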

    p_fisher = _fisher_exact_two_sided(
        v1_fp, v1_n - v1_fp, v2_fp, v2_n - v2_fp
    )

    return AggregateResult(
        v1_fp=v1_fp,
        v1_n=v1_n,
        v2_fp=v2_fp,
        v2_n=v2_n,
        v1_fpr=v1_fpr,
        v2_fpr=v2_fpr,
        observed_delta=observed_delta,
        n_perm=n_perm,
        p_value_permutation=p_perm,
        p_value_fisher_exact=p_fisher,
    )


# ---------------------------------------------------------------------------
# Per-row paired permutation (optional; runs only with --v1-per-row + --v2-per-row)
# ---------------------------------------------------------------------------

@dataclass
class PerRowResult:
    n_paired: int
    v1_correct_count: int
    v2_correct_count: int
    observed_delta: float
    n_perm: int
    p_value_permutation: float


def per_row_paired_permutation(
    v1_correct: list[int],
    v2_correct: list[int],
    *,
    n_perm: int = 10_000,
    seed: int = 42,
) -> PerRowResult:
    """Paired-sample permutation test on per-scenario predictions.

    For each scenario, randomly swap the (v1, v2) correctness labels (a sign
    flip of that scenario's difference) and count how often the absolute
    mean delta is at least as large as the observed one.
    """
    if len(v1_correct) != len(v2_correct):
        raise ValueError("v1 and v2 per-row vectors must be the same length")
    n = len(v1_correct)
    if n == 0:
        raise ValueError("Empty input vectors")
    if n_perm <= 0:
        raise ValueError("n_perm must be positive")

    observed_delta = abs(
        sum(v2_correct) / n - sum(v1_correct) / n
    )

    rng = random.Random(seed)
    extreme = 0
    for _ in range(n_perm):
        diffs = []
        for v1_i, v2_i in zip(v1_correct, v2_correct):
            if rng.random() < 0.5:
                diffs.append(v2_i - v1_i)
            else:
                diffs.append(v1_i - v2_i)
        if abs(sum(diffs) / n) >= observed_delta - 1e-12:
            extreme += 1
    # Add-one estimate, so a Monte Carlo p-value is never exactly 0.
    p_perm = (extreme + 1) / (n_perm + 1)

    return PerRowResult(
        n_paired=n,
        v1_correct_count=sum(v1_correct),
        v2_correct_count=sum(v2_correct),
        observed_delta=observed_delta,
        n_perm=n_perm,
        p_value_permutation=p_perm,
    )


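def _load_per_row_correct(path: Path) -> list[int]:
    """Read a per-row JSONL file into a 0/1 correctness vector.

    Rows are paired with the other model's file by position, so both files
    must list the scenarios in the same order.
    """
    rows: list[int] = []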
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            row = json.loads(line)
            rows.append(int(row["predicted"] == row["ground_truth"]))
    return rows


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _interpret(p: float) -> str:
    if p < 1e-4:
        return "extremely significant (p < 0.0001)"
    if p < 1e-3:
        return f"highly significant (p = {p:.4g})"
    if p < 0.05:
        return f"significant (p = {p:.4g})"
    return f"not significant at alpha=0.05 (p = {p:.4g})"


def main() -> int:
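    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the docstring's line breaks and table layout in --help output.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )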
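    parser.add_argument("--v1-fp", type=int, default=DEFAULT_V1_FP,
                        help="v1 false positives on the benign set")
    parser.add_argument("--v1-n", type=int, default=DEFAULT_V1_FP + DEFAULT_V1_TN,
                        help="number of benigns v1 was evaluated on")
    parser.add_argument("--v2-fp", type=int, default=DEFAULT_V2_FP,
                        help="v2 false positives on the benign set")
    parser.add_argument("--v2-n", type=int, default=DEFAULT_V2_FP + DEFAULT_V2_TN,
                        help="number of benigns v2 was evaluated on")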
    parser.add_argument(
        "--v1-per-row",
        type=Path,
        default=None,
        help="Optional path to v1 per-row JSONL (one obj per scenario with `predicted`+`ground_truth`).",
    )
    parser.add_argument(
        "--v2-per-row",
        type=Path,
        default=None,
        help="Optional path to v2 per-row JSONL (B.12 output).",
    )
    parser.add_argument("--n-perm", type=int, default=10_000)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("logs/permutation_test_v1_v2.json"),
    )
    args = parser.parse_args()

    payload: dict[str, Any] = {
        "meta": {
            "script": "eval/permutation_test_v1_v2.py",
            "n_perm": args.n_perm,
            "seed": args.seed,
            "interpretation_threshold": "alpha=0.05",
            "tests_run": [],
        },
    }

    # Test 1: aggregate counts (always runs)
    agg = aggregate_permutation_test(
        v1_fp=args.v1_fp,
        v1_n=args.v1_n,
        v2_fp=args.v2_fp,
        v2_n=args.v2_n,
        n_perm=args.n_perm,
        seed=args.seed,
    )
    payload["aggregate_fpr_test"] = {
        "input": {
            "v1_fp": agg.v1_fp,
            "v1_n": agg.v1_n,
            "v1_fpr": round(agg.v1_fpr, 4),
            "v2_fp": agg.v2_fp,
            "v2_n": agg.v2_n,
            "v2_fpr": round(agg.v2_fpr, 4),
        },
        "observed_fpr_delta_abs": round(agg.observed_delta, 4),
        "p_value_permutation": agg.p_value_permutation,
        "p_value_fisher_exact": agg.p_value_fisher_exact,
        "interpretation_permutation": _interpret(agg.p_value_permutation),
        "interpretation_fisher_exact": _interpret(agg.p_value_fisher_exact),
    }
    payload["meta"]["tests_run"].append("aggregate_fpr_test")

    # Test 2: per-row paired (optional). Every skip path records its reason in
    # the JSON output so a missing per-row result is never silent.
    skip_reason: str | None = None
    if args.v1_per_row is None and args.v2_per_row is None:
        skip_reason = (
            "B.12 per-row outputs not yet shipped (logs/eval_v2_per_row.jsonl + logs/eval_v1_per_row.jsonl). "
            "When B.12 produces them, re-run with --v1-per-row and --v2-per-row flags."
        )
    elif args.v1_per_row is None or args.v2_per_row is None:
        skip_reason = "the per-row test needs both --v1-per-row and --v2-per-row"
    elif not args.v1_per_row.exists():
        skip_reason = f"--v1-per-row {args.v1_per_row} not found"
    elif not args.v2_per_row.exists():
        skip_reason = f"--v2-per-row {args.v2_per_row} not found"
    else:
        v1_correct = _load_per_row_correct(args.v1_per_row)
        v2_correct = _load_per_row_correct(args.v2_per_row)
        if len(v1_correct) != len(v2_correct):
            skip_reason = (
                f"paired per-row lengths differ ({len(v1_correct)} vs {len(v2_correct)})"
            )
        else:
            pr = per_row_paired_permutation(
                v1_correct,
                v2_correct,
                n_perm=args.n_perm,
                seed=args.seed,
            )
            payload["per_row_paired_test"] = {
                "input": {
                    "n_paired": pr.n_paired,
                    "v1_correct": pr.v1_correct_count,
                    "v2_correct": pr.v2_correct_count,
                },
                "observed_accuracy_delta_abs": round(pr.observed_delta, 4),
                "p_value_permutation": pr.p_value_permutation,
                "interpretation_permutation": _interpret(pr.p_value_permutation),
            }
            payload["meta"]["tests_run"].append("per_row_paired_test")

    if skip_reason is not None:
        payload["per_row_paired_test"] = None
        payload["meta"]["per_row_test_skipped_reason"] = skip_reason
        # Warn only when the caller actually asked for the per-row test.
        if args.v1_per_row is not None or args.v2_per_row is not None:
            print(f"[warn] {skip_reason}; skipping per-row test", file=sys.stderr)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(payload, indent=2) + "\n", encoding="utf-8")
    print(json.dumps(payload, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())