"""Tests for eval/permutation_test_v1_v2.py.""" from __future__ import annotations import json from pathlib import Path import pytest from eval.permutation_test_v1_v2 import ( DEFAULT_V1_FP, DEFAULT_V1_TN, DEFAULT_V2_FP, DEFAULT_V2_TN, aggregate_permutation_test, per_row_paired_permutation, _fisher_exact_two_sided, ) def test_aggregate_permutation_significant_for_real_v1_v2_delta(): result = aggregate_permutation_test( v1_fp=DEFAULT_V1_FP, v1_n=DEFAULT_V1_FP + DEFAULT_V1_TN, v2_fp=DEFAULT_V2_FP, v2_n=DEFAULT_V2_FP + DEFAULT_V2_TN, n_perm=2_000, seed=42, ) assert result.v1_fpr == pytest.approx(11 / 30, abs=1e-6) assert result.v2_fpr == pytest.approx(2 / 30, abs=1e-6) assert result.observed_delta == pytest.approx(0.3, abs=1e-6) assert result.p_value_permutation < 0.05 assert result.p_value_fisher_exact < 0.05 def test_aggregate_permutation_not_significant_for_no_delta(): result = aggregate_permutation_test( v1_fp=5, v1_n=30, v2_fp=5, v2_n=30, n_perm=2_000, seed=7 ) assert result.observed_delta == pytest.approx(0.0, abs=1e-9) assert result.p_value_permutation >= 0.5 assert result.p_value_fisher_exact == pytest.approx(1.0, abs=1e-9) def test_aggregate_permutation_p_in_unit_interval(): result = aggregate_permutation_test( v1_fp=10, v1_n=30, v2_fp=3, v2_n=30, n_perm=1_000, seed=1 ) assert 0.0 <= result.p_value_permutation <= 1.0 assert 0.0 <= result.p_value_fisher_exact <= 1.0 def test_fisher_exact_matches_known_table(): p = _fisher_exact_two_sided(11, 19, 2, 28) assert 0.005 < p < 0.02 def test_per_row_paired_significant_when_clearly_better(): n = 50 v1 = [1] * 25 + [0] * 25 v2 = [1] * 45 + [0] * 5 result = per_row_paired_permutation(v1, v2, n_perm=2_000, seed=42) assert result.n_paired == n assert result.v1_correct_count == 25 assert result.v2_correct_count == 45 assert result.observed_delta == pytest.approx(0.4, abs=1e-9) assert result.p_value_permutation < 0.05 def test_per_row_paired_rejects_mismatched_lengths(): with pytest.raises(ValueError): per_row_paired_permutation([1, 0], [1, 0, 1], n_perm=10, seed=0) def test_per_row_paired_rejects_empty(): with pytest.raises(ValueError): per_row_paired_permutation([], [], n_perm=10, seed=0) def test_aggregate_rejects_zero_sample_size(): with pytest.raises(ValueError): aggregate_permutation_test(v1_fp=0, v1_n=0, v2_fp=0, v2_n=10, n_perm=10) def test_logged_artifact_matches_invocation_defaults(): """The shipped artifact in logs/ must match what the script writes by default.""" artifact = Path("logs/permutation_test_v1_v2.json") if not artifact.exists(): pytest.skip("logs/permutation_test_v1_v2.json not shipped yet") payload = json.loads(artifact.read_text()) assert "aggregate_fpr_test" in payload agg = payload["aggregate_fpr_test"] assert agg["input"]["v1_fp"] == DEFAULT_V1_FP assert agg["input"]["v2_fp"] == DEFAULT_V2_FP assert agg["p_value_permutation"] < 0.05 assert agg["p_value_fisher_exact"] < 0.05