# chakravyuh / eval/time_to_detection.py
# UjjwalPardeshi
# deploy: latest main to HF Space
# 03815d6
"""Time-to-detection metric — uses existing 100-episode scripted-baseline trace.
Reads ``logs/baseline_day1.json`` (100 episodes with ``detected_by_turn``),
computes:
- avg_detection_turn (when flagged)
- pct_detected_by_turn_2 / 3 / 4 / 5 (cumulative early-detection share over all episodes)
- per-category breakdown
- detection-turn distribution (histogram)
A turn-2 flag is materially different from a turn-5 flag in fraud terms —
₹ at risk grows monotonically with delay.
Usage:
python eval/time_to_detection.py
python eval/time_to_detection.py --output logs/time_to_detection.json
"""
from __future__ import annotations

import argparse
import json
from collections import Counter, defaultdict
from pathlib import Path
from statistics import mean, median
# Repository root: two levels above this file (the script is run as
# ``python eval/time_to_detection.py`` from the repo root).
REPO_ROOT = Path(__file__).resolve().parents[1]
# Default input trace and default output summary, both under <repo>/logs/.
INPUT_DEFAULT = REPO_ROOT.joinpath("logs", "baseline_day1.json")
OUTPUT_DEFAULT = REPO_ROOT.joinpath("logs", "time_to_detection.json")
def _detection_turn(row: dict) -> int | None:
    """Return the turn at which *row* was flagged, or ``None`` if it was not.

    A row counts as detected only when ``analyzer_flagged`` is truthy AND
    ``detected_by_turn`` is an integer. ``bool`` is rejected explicitly
    because ``isinstance(True, int)`` is true in Python and would otherwise
    be counted as a turn-1 detection.
    """
    if not row.get("analyzer_flagged"):
        return None
    turn = row.get("detected_by_turn")
    if isinstance(turn, bool) or not isinstance(turn, int):
        return None
    return turn


def compute(rows: list[dict], cutoffs: tuple[int, ...] = (2, 3, 4, 5)) -> dict:
    """Summarise time-to-detection over a list of episode rows.

    Args:
        rows: Episode dicts. Keys read: ``analyzer_flagged``,
            ``detected_by_turn`` and ``category``.
        cutoffs: Turns k for which ``pct_detected_by_turn_<k>`` is reported.

    Returns:
        A dict with ``n_episodes``, ``n_flagged``, ``detection_rate``,
        ``avg/median/min/max_detection_turn`` (``None`` when nothing was
        detected), one ``pct_detected_by_turn_<k>`` per cutoff,
        ``by_category`` (sorted by category name) and ``histogram``
        (``T<turn>`` keys in ascending turn order, then ``never``).

    The same "detected" predicate (``_detection_turn``) is used for every
    statistic, so the overall, per-category and histogram numbers agree.
    """
    turns_by_row = [_detection_turn(r) for r in rows]
    detection_turns = [t for t in turns_by_row if t is not None]
    n_total = len(rows)
    n_flagged = len(detection_turns)

    summary: dict = {
        "n_episodes": n_total,
        "n_flagged": n_flagged,
        "detection_rate": round(n_flagged / n_total, 4) if n_total else 0.0,
        "avg_detection_turn": (
            round(mean(detection_turns), 2) if detection_turns else None
        ),
        # True median: averages the two middle turns when the count is even.
        "median_detection_turn": median(detection_turns) if detection_turns else None,
        "min_detection_turn": min(detection_turns, default=None),
        "max_detection_turn": max(detection_turns, default=None),
    }

    # Cumulative early-detection share, conditioned on the full N: among all
    # scams, what fraction was caught by turn k? Late-flagged and missed
    # episodes both count as failures.
    for cutoff in cutoffs:
        early = sum(1 for t in detection_turns if t <= cutoff)
        summary[f"pct_detected_by_turn_{cutoff}"] = (
            round(early / n_total, 4) if n_total else 0.0
        )

    # Per-category breakdown. A missing/None category maps to "unknown", and
    # every key is a str, so sorting never compares None (or ints) with str.
    by_cat: dict[str, dict] = defaultdict(lambda: {"n": 0, "turns": []})
    for row, turn in zip(rows, turns_by_row):
        raw_cat = row.get("category")
        bucket = by_cat["unknown" if raw_cat is None else str(raw_cat)]
        bucket["n"] += 1
        if turn is not None:
            bucket["turns"].append(turn)
    # Every bucket has n >= 1 because it is only created when a row is seen.
    summary["by_category"] = {
        cat: {
            "n": b["n"],
            "n_flagged": len(b["turns"]),
            "detection_rate": round(len(b["turns"]) / b["n"], 4),
            "avg_detection_turn": (
                round(mean(b["turns"]), 2) if b["turns"] else None
            ),
        }
        for cat, b in sorted(by_cat.items())
    }

    # Histogram of detection turn, in numeric turn order (T2 before T10),
    # plus 'never' for undetected episodes (omitted when zero).
    turn_counts = Counter(detection_turns)
    histogram = {f"T{t}": turn_counts[t] for t in sorted(turn_counts)}
    n_never = n_total - n_flagged
    if n_never:
        histogram["never"] = n_never
    summary["histogram"] = histogram
    return summary
def main() -> int:
    """CLI entry point: read the episode trace, write the summary, print headlines.

    Reads ``--input`` (a JSON list of rows, or a dict with a ``rows`` list),
    runs ``compute`` on it, adds provenance fields and writes the result to
    ``--output`` as indented JSON.

    Returns:
        0 on success.

    Raises:
        SystemExit: with a message when the input is missing, is not valid
            JSON, or contains no list of rows.
    """
    # __doc__ is None under ``python -OO``; fall back to an empty description.
    parser = argparse.ArgumentParser(description=(__doc__ or "").split("\n")[0])
    parser.add_argument("--input", type=Path, default=INPUT_DEFAULT)
    parser.add_argument("--output", type=Path, default=OUTPUT_DEFAULT)
    args = parser.parse_args()

    if not args.input.exists():
        raise SystemExit(f"input not found: {args.input}")
    try:
        data = json.loads(args.input.read_text(encoding="utf-8"))
    except json.JSONDecodeError as err:
        raise SystemExit(f"input is not valid JSON: {args.input} ({err})") from err
    rows = data.get("rows", []) if isinstance(data, dict) else data
    if not isinstance(rows, list):
        raise SystemExit(f"expected a list of rows in {args.input}")
    if not rows:
        raise SystemExit("no rows in input file")

    summary = compute(rows)
    # relative_to() raises ValueError for a relative --input or one outside
    # the repo; resolve first, then fall back to the absolute path.
    input_path = args.input.resolve()
    try:
        summary["source"] = str(input_path.relative_to(REPO_ROOT))
    except ValueError:
        summary["source"] = str(input_path)
    summary["analyzer"] = "rule_based_scripted"
    summary["notes"] = (
        "Time-to-detection on the 100-episode env-rollout baseline (scripted "
        "scammer × scripted analyzer × scripted bank). LoRA-v2 time-to-detection "
        "in episode rollouts requires GPU re-inference — pending v3."
    )

    args.output.parent.mkdir(parents=True, exist_ok=True)
    args.output.write_text(json.dumps(summary, indent=2) + "\n", encoding="utf-8")

    print(f"time-to-detection: {args.output}")
    print(f" detection_rate = {summary['detection_rate']:.3f}")
    print(f" avg_detection_turn = {summary['avg_detection_turn']}")
    # Print whichever cutoffs compute() reported, in the order it reported them,
    # instead of repeating its cutoff tuple here.
    for key, pct in summary.items():
        if key.startswith("pct_detected_by_turn_"):
            print(f" {key} = {pct:.3f}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)