# chakravyuh/eval/threshold_sweep.py
# UjjwalPardeshi
# deploy: latest main to HF Space
# 03815d6
"""Threshold sweep — score a trained LoRA analyzer once, re-threshold many times.
Use case: you've trained a LoRA and evaluation at threshold=0.5 shows strong recall
but high FPR (classic over-flagging from reward hacking). Rather than retrain, sweep
the flag threshold over a grid (default: 0.3, 0.4, then 0.5-0.9 in 0.05 steps) to
find the precision/recall sweet spot.
KEY OPTIMIZATION: the LoRA produces a CONTINUOUS score per scenario. We run it
ONCE across the 175 bench scenarios (~15 min), cache `(scenario_id, score)`, then
apply different thresholds to the cached scores — each re-threshold is <1 second.
Usage:
# On Colab after training:
python -m eval.threshold_sweep \\
--model Qwen/Qwen2.5-7B-Instruct \\
--lora /content/drive/MyDrive/chakravyuh/analyzer_lora \\
--output /content/drive/MyDrive/chakravyuh/threshold_sweep.json
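Output JSON (per-threshold keys are the threshold formatted to two decimals):
    {
      "thresholds": {
        "0.30": {"threshold": 0.3, "detection": 1.00, "fpr": 0.48, "precision": 0.88, "f1": 0.94, ...},
        "0.40": {...},
        ...
      },
      "best_by_f1": {"threshold": 0.7, "f1": 0.945, ...},
      "best_by_fpr_under_15": {"threshold": 0.75, ...},
      "best_by_fpr_under_10": {...}
    }
"best_by_fpr_under_15" / "best_by_fpr_under_10" hold the highest-F1 row whose FPR
is at most 0.15 / 0.10, or null when no swept threshold qualifies.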
"""
from __future__ import annotations

import argparse
import dataclasses
import json
import logging
import sys
from dataclasses import asdict
from pathlib import Path

from chakravyuh_env.agents.llm_analyzer import LLMAnalyzer
from eval.mode_c_real_cases import (
    DEFAULT_DATASET,
    aggregate,
    load_dataset,
    per_category_breakdown,
    per_difficulty_breakdown,
    run_eval,
)
# Module logger; the explicit dotted name groups it under the "chakravyuh" hierarchy.
logger = logging.getLogger("chakravyuh.threshold_sweep")

# Default flag cutoffs swept by the CLI: 0.3 and 0.4, then every 0.05 from 0.5 to 0.9.
DEFAULT_THRESHOLDS = [
    0.3, 0.4,
    0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9,
]
def sweep(
    analyzer: LLMAnalyzer,
    dataset: list[dict],
    thresholds: list[float],
) -> dict:
    """Run inference once, then aggregate metrics at each threshold.

    Args:
        analyzer: Loaded analyzer; passed straight to ``run_eval``.
        dataset: Scenarios to score; passed straight to ``run_eval``.
        thresholds: Cutoffs to apply. A scenario is flagged when its
            ``predicted_score >= threshold``.

    Returns:
        A dict with:
          * ``"thresholds"``: one metrics row per threshold, keyed by
            ``f"{thr:.2f}"`` (a later threshold with the same two-decimal key
            overwrites an earlier one);
          * ``"best_by_f1"``: the highest-F1 row;
          * ``"best_by_fpr_under_15"`` / ``"best_by_fpr_under_10"``: the
            highest-F1 row whose (rounded) ``fpr`` is <= 0.15 / 0.10, or
            ``None`` if no row qualifies.
        Ties on F1 go to the threshold that appears first in *thresholds*.

    Raises:
        ValueError: If *thresholds* is empty. This is checked before the
            (expensive) inference pass.
    """
    if not thresholds:
        raise ValueError("thresholds must contain at least one value")

    # One pass: collect continuous scores per scenario. run_eval needs some
    # threshold, but only predicted_score is reused below; flags are recomputed.
    base_results = run_eval(analyzer, dataset, threshold=0.5)
    logger.info("Scored %d scenarios. Re-thresholding %d cutoffs…", len(base_results), len(thresholds))

    out: dict[str, dict] = {}
    for thr in thresholds:
        m = aggregate([_reflag(r, thr) for r in base_results])
        out[f"{thr:.2f}"] = _metrics_row(thr, m)
        logger.info(
            "thr=%.2f det=%.1f%% fpr=%.1f%% P=%.1f%% F1=%.3f",
            thr, m.detection_rate * 100, m.false_positive_rate * 100,
            m.precision * 100, m.f1,
        )

    rows = list(out.values())
    return {
        "thresholds": out,
        "best_by_f1": _best_f1(rows),
        "best_by_fpr_under_15": _best_f1(rows, max_fpr=0.15),
        "best_by_fpr_under_10": _best_f1(rows, max_fpr=0.10),
    }


def _reflag(result, threshold: float):
    """Return a copy of *result* with its flag and correctness recomputed at *threshold*."""
    flag = result.predicted_score >= threshold
    correct = flag == result.is_scam_truth
    if dataclasses.is_dataclass(result):
        # Copy every field (including any not listed in the fallback below) and
        # override only the two threshold-dependent ones.
        return dataclasses.replace(result, predicted_flag=flag, correct=correct)
    # Fallback for non-dataclass result types: rebuild from the known fields.
    return type(result)(
        scenario_id=result.scenario_id,
        is_scam_truth=result.is_scam_truth,
        predicted_score=result.predicted_score,
        predicted_flag=flag,
        correct=correct,
        category=result.category,
        difficulty=result.difficulty,
    )


def _metrics_row(threshold: float, m) -> dict:
    """Turn the aggregate metrics for one threshold into a JSON-ready row (4 dp)."""
    return {
        "threshold": threshold,
        "n": m.n,
        "detection": round(m.detection_rate, 4),
        "fpr": round(m.false_positive_rate, 4),
        "precision": round(m.precision, 4),
        "recall": round(m.recall, 4),
        "f1": round(m.f1, 4),
        "accuracy": round(m.accuracy, 4),
    }


def _best_f1(rows: list[dict], max_fpr: float | None = None) -> dict | None:
    """Return the highest-F1 row, optionally restricted to ``fpr <= max_fpr``; ``None`` if none qualify."""
    candidates = (r for r in rows if max_fpr is None or r["fpr"] <= max_fpr)
    # max() keeps the first maximal element, so ties go to the earliest row.
    return max(candidates, key=lambda r: r["f1"], default=None)
def main(argv: list[str] | None = None) -> int:
    """Command-line entry point: score the LoRA once, sweep thresholds, report.

    Parses *argv* (``sys.argv[1:]`` when ``None``), loads the analyzer and the
    dataset, runs :func:`sweep`, writes the result as JSON to ``--output``
    (creating parent directories), and prints a summary table to stdout.

    Returns:
        Process exit status; ``0`` on success.
    """
    parser = argparse.ArgumentParser(description="Sweep flag thresholds on a trained LoRA analyzer.")
    parser.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--lora", required=True, type=Path, help="Path to LoRA adapter dir")
    parser.add_argument("--dataset", type=Path, default=DEFAULT_DATASET)
    parser.add_argument("--output", type=Path, required=True, help="Where to write the sweep JSON")
    parser.add_argument(
        "--thresholds",
        type=float,
        nargs="+",
        default=DEFAULT_THRESHOLDS,
        help="Threshold values to try",
    )
    parser.add_argument("--load-in-4bit", action="store_true", help="Force 4-bit load (smaller VRAM)")
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    logger.info("Loading LoRA adapter from %s", args.lora)
    analyzer = LLMAnalyzer(
        model_name=args.model,
        lora_path=str(args.lora),
        use_unsloth=False,
        load_in_4bit=args.load_in_4bit,
    )
    analyzer.load()

    dataset = load_dataset(args.dataset)
    logger.info("Loaded %d scenarios from %s", len(dataset), args.dataset)

    result = sweep(analyzer, dataset, args.thresholds)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with args.output.open("w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    logger.info("Wrote sweep results to %s", args.output)

    _print_summary(result)
    return 0


def _print_summary(result: dict) -> None:
    """Print the per-threshold metrics table and the best-F1 picks from a sweep() result."""
    print()
    print("=== THRESHOLD SWEEP SUMMARY ===")
    print(f"{'thr':<6}{'det':<8}{'fpr':<8}{'prec':<8}{'f1':<8}")
    for thr_str, row in result["thresholds"].items():
        # Use the header's column widths so each value sits under its label.
        print(
            f"{thr_str:<6}"
            f"{row['detection']:<8.3f}"
            f"{row['fpr']:<8.3f}"
            f"{row['precision']:<8.3f}"
            f"{row['f1']:<8.3f}"
        )
    print()
    # The FPR-constrained picks use `fpr <= limit`, so the labels say "<=".
    picks = (
        ("best_by_f1", "Best F1"),
        ("best_by_fpr_under_15", "Best F1 with FPR<=15%"),
        ("best_by_fpr_under_10", "Best F1 with FPR<=10%"),
    )
    for key, label in picks:
        best = result.get(key)
        if best is not None:
            print(f"{label}: thr={best['threshold']} F1={best['f1']:.3f} FPR={best['fpr']:.3f}")
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the process exit status.
    raise SystemExit(main())