"""Training-data integrity tests.
CRITICAL: these tests guard against training-test contamination. The LoRA is
evaluated on chakravyuh-bench-v0 (scenarios.jsonl). If any training example
duplicates a test-set message, the evaluation numbers are invalid.
Run: pytest tests/test_training_data.py -v
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from training.grpo_analyzer import (
DEFAULT_BENIGN_PATH,
DEFAULT_MULTITURN_PATH,
DEFAULT_PARAPHRASE_PATH,
DEFAULT_REGIONAL_PATH,
DEFAULT_TEMPLATES_PATH,
TEST_SET_PATH,
build_training_examples,
)
def _normalize(text: str) -> str:
"""Case-insensitive, whitespace-collapsed form for text comparison."""
return " ".join(text.lower().split())
def _load_test_set_texts() -> set[str]:
    """Every scammer message in the benchmark test set, normalized."""
    collected: set[str] = set()
    with TEST_SET_PATH.open(encoding="utf-8") as handle:
        for raw in handle:
            if not raw.strip():
                continue
            scenario = json.loads(raw)
            # Only scammer-authored turns count; victim/agent turns are ignored.
            collected.update(
                _normalize(step.get("text", ""))
                for step in scenario.get("attack_sequence", [])
                if step.get("sender") == "scammer"
            )
    return collected
@pytest.mark.unit
def test_all_training_source_files_exist():
    """All 5 training template files must exist on disk."""
    sources = [
        DEFAULT_TEMPLATES_PATH,
        DEFAULT_BENIGN_PATH,
        DEFAULT_PARAPHRASE_PATH,
        DEFAULT_REGIONAL_PATH,
        DEFAULT_MULTITURN_PATH,
    ]
    for src in sources:
        assert src.exists(), f"Missing training source: {src}"
@pytest.mark.unit
def test_training_examples_include_all_five_sources():
    """Built corpus should include scams from templates+paraphrase+regional+multiturn and benign from benign_templates.
    After soft-leakage filter drops ~53/200 canonical templates, corpus is
    ~227 scams + ~55 benign = ~283. Floor is set accordingly.
    """
    examples = build_training_examples()
    total = len(examples)
    assert total >= 270, f"Expected 270+ training examples, got {total}"
    # Partition by label and check each side clears its floor.
    scam_count = sum(1 for e in examples if e.is_scam)
    benign_count = total - scam_count
    assert scam_count >= 220, f"Expected 220+ scams after filter, got {scam_count}"
    assert benign_count >= 50, f"Expected 50+ benigns, got {benign_count}"
@pytest.mark.unit
def test_no_training_scam_duplicates_test_set_scam():
    """CRITICAL: training scam messages must not appear in the benchmark test set."""
    test_texts = _load_test_set_texts()
    overlaps: list[tuple[str, str]] = []
    for ex in build_training_examples():
        if not ex.is_scam:
            continue
        # Check the whole message, plus each line separately so multi-turn
        # sequences are compared turn by turn.
        candidates = [_normalize(ex.prompt_text)]
        candidates.extend(
            _normalize(part) for part in ex.prompt_text.split("\n") if part.strip()
        )
        for seg in candidates:
            # Very short fragments can't meaningfully overlap.
            if len(seg) >= 40 and seg in test_texts:
                overlaps.append((ex.category, seg[:80]))
    assert not overlaps, (
        f"Found {len(overlaps)} training-scam / test-set overlaps. "
        f"First 3: {overlaps[:3]}"
    )
@pytest.mark.unit
def test_no_soft_substring_leakage_in_built_corpus():
    """CRITICAL: after soft-leakage filter, no training line is a substring of
    any test-set scammer text (and vice versa).
    This is the stronger leakage guarantee — catches cases where a canonical
    template opener appears verbatim inside a longer test scenario.
    """
    test_texts = _load_test_set_texts()
    violations: list[tuple[str, str]] = []
    for ex in build_training_examples():
        if not ex.is_scam:
            continue
        for line in ex.prompt_text.split("\n"):
            line_norm = _normalize(line)
            if len(line_norm) < 40:
                continue
            # Bidirectional containment: training line inside a test message,
            # or a test message inside the training line.
            leaked = any(
                line_norm in t_text or t_text in line_norm for t_text in test_texts
            )
            if leaked:
                violations.append((ex.category, line_norm[:80]))
    assert not violations, (
        f"Soft-leakage filter failed: {len(violations)} training lines are "
        f"substrings of test-set messages. First 3: {violations[:3]}"
    )
@pytest.mark.unit
def test_no_training_benign_duplicates_test_set_benign():
    """CRITICAL: benign training SMS must not appear in the benchmark."""
    test_texts = _load_test_set_texts()
    benign_norms = (
        _normalize(ex.prompt_text)
        for ex in build_training_examples()
        if not ex.is_scam
    )
    overlaps = [norm[:80] for norm in benign_norms if norm in test_texts]
    assert not overlaps, (
        f"Found {len(overlaps)} benign training/test overlaps: {overlaps[:3]}"
    )
@pytest.mark.unit
def test_training_covers_all_5_scam_categories():
    """LoRA needs representation of every scam category it will be evaluated on."""
    required = {
        "otp_theft",
        "kyc_fraud",
        "loan_app_fraud",
        "investment_fraud",
        "impersonation",
    }
    seen = {e.category for e in build_training_examples() if e.is_scam}
    missing = required - seen
    assert not missing, f"Training missing scam categories: {missing}"
@pytest.mark.unit
def test_training_has_regional_language_coverage():
    """Training must include at least one non-English scam for multilingual transfer."""
    with DEFAULT_REGIONAL_PATH.open(encoding="utf-8") as f:
        regional = json.load(f)["templates"]
    # Count templates whose opener contains at least one non-ASCII character —
    # a proxy for non-Latin-script (regional-language) content.
    nonlatin_count = sum(
        any(ord(ch) > 127 for ch in tpl.get("opener", "")) for tpl in regional
    )
    assert nonlatin_count >= 5, (
        f"Expected 5+ regional templates with non-Latin scripts, got {nonlatin_count}"
    )
@pytest.mark.unit
def test_training_has_multiturn_examples():
    """Multi-turn sequences must be represented so LoRA handles dialog context."""
    with DEFAULT_MULTITURN_PATH.open(encoding="utf-8") as f:
        mt = json.load(f)["templates"]
    assert len(mt) >= 10, f"Expected 10+ multi-turn templates, got {len(mt)}"
    for template in mt:
        turn_count = len(template.get("turns", []))
        assert turn_count >= 2, f"Multi-turn should have 2+ turns: {template['id']}"
@pytest.mark.unit
def test_benign_templates_balanced_across_categories():
    """Benign pool should span banking, delivery, utility, insurance, govt, subscription, otp_legit, misc."""
    with DEFAULT_BENIGN_PATH.open(encoding="utf-8") as f:
        benign = json.load(f)["templates"]
    # Distinct category labels present in the benign template pool.
    categories = {t.get("category") for t in benign}
    required = {
        "banking",
        "delivery",
        "utility",
        "insurance",
        "govt",
        "subscription",
        "otp_legit",
        "misc",
    }
    missing = required - categories
    assert not missing, f"Benign pool missing categories: {missing}"
    assert len(benign) >= 70, f"Expected 70+ benign templates, got {len(benign)}"
|