# chakravyuh/tests/test_training_data.py
"""Training-data integrity tests.
CRITICAL: these tests guard against training-test contamination. The LoRA is
evaluated on chakravyuh-bench-v0 (scenarios.jsonl). If any training example
duplicates a test-set message, the evaluation numbers are invalid.
Run: pytest tests/test_training_data.py -v
"""
from __future__ import annotations
import json
from pathlib import Path
import pytest
from training.grpo_analyzer import (
DEFAULT_BENIGN_PATH,
DEFAULT_MULTITURN_PATH,
DEFAULT_PARAPHRASE_PATH,
DEFAULT_REGIONAL_PATH,
DEFAULT_TEMPLATES_PATH,
TEST_SET_PATH,
build_training_examples,
)
def _normalize(text: str) -> str:
"""Case-insensitive, whitespace-collapsed form for text comparison."""
return " ".join(text.lower().split())
def _load_test_set_texts() -> set[str]:
    """Collect every scammer-sent message from the benchmark, normalized.

    Reads scenarios.jsonl line by line; skips blank lines; keeps only
    steps whose sender is "scammer".
    """
    scammer_texts: set[str] = set()
    with TEST_SET_PATH.open(encoding="utf-8") as handle:
        for raw_line in handle:
            if not raw_line.strip():
                continue
            scenario = json.loads(raw_line)
            scammer_texts.update(
                _normalize(step.get("text", ""))
                for step in scenario.get("attack_sequence", [])
                if step.get("sender") == "scammer"
            )
    return scammer_texts
@pytest.mark.unit
def test_all_training_source_files_exist():
    """Every one of the 5 training template files must be present on disk."""
    required_sources = (
        DEFAULT_TEMPLATES_PATH,
        DEFAULT_BENIGN_PATH,
        DEFAULT_PARAPHRASE_PATH,
        DEFAULT_REGIONAL_PATH,
        DEFAULT_MULTITURN_PATH,
    )
    for path in required_sources:
        assert path.exists(), f"Missing training source: {path}"
@pytest.mark.unit
def test_training_examples_include_all_five_sources():
    """Built corpus must draw scams from templates+paraphrase+regional+multiturn
    and benign SMS from benign_templates.

    The soft-leakage filter drops roughly 53/200 canonical templates, leaving
    ~227 scams + ~55 benign = ~283 examples; the floors below reflect that.
    """
    examples = build_training_examples()
    assert len(examples) >= 270, f"Expected 270+ training examples, got {len(examples)}"
    scams, benigns = [], []
    for example in examples:
        (scams if example.is_scam else benigns).append(example)
    assert len(scams) >= 220, f"Expected 220+ scams after filter, got {len(scams)}"
    assert len(benigns) >= 50, f"Expected 50+ benigns, got {len(benigns)}"
@pytest.mark.unit
def test_no_training_scam_duplicates_test_set_scam():
    """CRITICAL: no scam used for training may appear verbatim in the benchmark."""
    test_texts = _load_test_set_texts()
    overlaps: list[tuple[str, str]] = []
    for example in build_training_examples():
        if not example.is_scam:
            continue
        # Check the whole message and, for multi-turn prompts, every segment.
        whole = _normalize(example.prompt_text)
        per_line = [
            _normalize(part) for part in example.prompt_text.split("\n") if part.strip()
        ]
        for candidate in [whole, *per_line]:
            if len(candidate) < 40:  # too short to meaningfully overlap
                continue
            if candidate in test_texts:
                overlaps.append((example.category, candidate[:80]))
    assert not overlaps, (
        f"Found {len(overlaps)} training-scam / test-set overlaps. "
        f"First 3: {overlaps[:3]}"
    )
@pytest.mark.unit
def test_no_soft_substring_leakage_in_built_corpus():
    """CRITICAL: no training line may be a substring of any test-set scammer
    text, nor vice versa, once the soft-leakage filter has run.

    Stronger than exact-match dedup — it also catches a canonical template
    opener embedded verbatim inside a longer test scenario.
    """
    test_texts = _load_test_set_texts()
    violations: list[tuple[str, str]] = []
    for example in build_training_examples():
        if not example.is_scam:
            continue
        for raw_line in example.prompt_text.split("\n"):
            candidate = _normalize(raw_line)
            if len(candidate) < 40:
                continue
            # Record at most one violation per training line.
            if any(candidate in t or t in candidate for t in test_texts):
                violations.append((example.category, candidate[:80]))
    assert not violations, (
        f"Soft-leakage filter failed: {len(violations)} training lines are "
        f"substrings of test-set messages. First 3: {violations[:3]}"
    )
@pytest.mark.unit
def test_no_training_benign_duplicates_test_set_benign():
    """CRITICAL: benign training SMS must not show up in the benchmark."""
    test_texts = _load_test_set_texts()
    benign_norms = (
        _normalize(example.prompt_text)
        for example in build_training_examples()
        if not example.is_scam
    )
    overlaps = [text[:80] for text in benign_norms if text in test_texts]
    assert not overlaps, (
        f"Found {len(overlaps)} benign training/test overlaps: {overlaps[:3]}"
    )
@pytest.mark.unit
def test_training_covers_all_5_scam_categories():
    """The LoRA must see every scam category it will later be evaluated on."""
    required = {
        "otp_theft",
        "kyc_fraud",
        "loan_app_fraud",
        "investment_fraud",
        "impersonation",
    }
    seen = {example.category for example in build_training_examples() if example.is_scam}
    missing = required - seen
    assert not missing, f"Training missing scam categories: {missing}"
@pytest.mark.unit
def test_training_has_regional_language_coverage():
    """At least some non-English scams are needed for multilingual transfer."""
    with DEFAULT_REGIONAL_PATH.open(encoding="utf-8") as f:
        regional = json.load(f)["templates"]

    def _opener_has_nonascii(template: dict) -> bool:
        # Non-ASCII code points are our proxy for non-Latin scripts.
        return any(ord(ch) > 127 for ch in template.get("opener", ""))

    has_nonlatin = sum(1 for t in regional if _opener_has_nonascii(t))
    assert has_nonlatin >= 5, (
        f"Expected 5+ regional templates with non-Latin scripts, got {has_nonlatin}"
    )
@pytest.mark.unit
def test_training_has_multiturn_examples():
    """Multi-turn sequences must be present so the LoRA learns dialog context."""
    with DEFAULT_MULTITURN_PATH.open(encoding="utf-8") as f:
        multiturn_templates = json.load(f)["templates"]
    count = len(multiturn_templates)
    assert count >= 10, f"Expected 10+ multi-turn templates, got {count}"
    for template in multiturn_templates:
        turns = template.get("turns", [])
        assert len(turns) >= 2, f"Multi-turn should have 2+ turns: {template['id']}"
@pytest.mark.unit
def test_benign_templates_balanced_across_categories():
    """Benign pool should span banking, delivery, utility, insurance, govt,
    subscription, otp_legit, misc.

    (Docstring previously omitted "insurance" even though the assertion below
    requires it — kept in sync here so the contract reads correctly.)
    """
    with DEFAULT_BENIGN_PATH.open(encoding="utf-8") as f:
        benign = json.load(f)["templates"]
    categories = {t.get("category") for t in benign}
    required = {
        "banking",
        "delivery",
        "utility",
        "insurance",
        "govt",
        "subscription",
        "otp_legit",
        "misc",
    }
    missing = required - categories
    assert not missing, f"Benign pool missing categories: {missing}"
    assert len(benign) >= 70, f"Expected 70+ benign templates, got {len(benign)}"