| """Threshold sweep — score a trained LoRA analyzer once, re-threshold many times. | |
| Use case: you've trained a LoRA and eval at threshold=0.5 shows strong recall but | |
| high FPR (classic over-flagging from reward hacking). Rather than retrain, sweep | |
| the flag threshold across [0.3, 0.4, ..., 0.9] to find the P/R sweet spot. | |
| KEY OPTIMIZATION: the LoRA produces a CONTINUOUS score per scenario. We run it | |
| ONCE across the 175 bench scenarios (~15 min), cache `(scenario_id, score)`, then | |
| apply different thresholds to the cached scores — each re-threshold is <1 second. | |
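For example, a scenario scored 0.62 is flagged at every cutoff up to 0.60 and
passes at 0.65 and above, so moving the cutoff costs no extra model calls.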

Usage:
    # On Colab after training:
    python -m eval.threshold_sweep \\
        --model Qwen/Qwen2.5-7B-Instruct \\
        --lora /content/drive/MyDrive/chakravyuh/analyzer_lora \\
        --output /content/drive/MyDrive/chakravyuh/threshold_sweep.json

Output JSON:
    {
      "thresholds": {
        "0.30": {"detection": 1.00, "fpr": 0.48, "precision": 0.88, "f1": 0.94, ...},
        "0.40": {...},
        ...
      },
      "best_by_f1": {"threshold": 0.7, "f1": 0.945, ...},
      "best_by_fpr_under_15": {"threshold": 0.75, ...},
      "best_by_fpr_under_10": {...}
    }
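
Picking an operating point downstream (illustrative sketch, reading the fields
this script writes):

    import json

    with open("threshold_sweep.json") as f:
        sweep_result = json.load(f)
    # Prefer the FPR<=10% row when one exists, else fall back to the best F1 overall.
    pick = sweep_result["best_by_fpr_under_10"] or sweep_result["best_by_f1"]
    flag_threshold = pick["threshold"]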
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import logging | |
| import sys | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| from chakravyuh_env.agents.llm_analyzer import LLMAnalyzer | |
| from eval.mode_c_real_cases import ( | |
| DEFAULT_DATASET, | |
| aggregate, | |
| load_dataset, | |
| per_category_breakdown, | |
| per_difficulty_breakdown, | |
| run_eval, | |
| ) | |
| logger = logging.getLogger("chakravyuh.threshold_sweep") | |
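
# Default grid: coarse 0.1 steps up to 0.5, then finer 0.05 steps through 0.9
# (the region where the over-flagging use case above looks for its sweet spot).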
DEFAULT_THRESHOLDS = [0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]


def sweep(
    analyzer: LLMAnalyzer,
    dataset: list[dict],
    thresholds: list[float],
) -> dict:
    """Run inference once, aggregate at each threshold."""
    # One pass: collect continuous scores per scenario.
    base_results = run_eval(analyzer, dataset, threshold=0.5)  # threshold irrelevant here, we re-apply
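    # run_eval is assumed to return per-scenario results that keep the continuous
    # predicted_score; the threshold passed above only sets throwaway flags.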
    # Re-thresholding is a pure comparison against the cached scores.
    logger.info("Scored %d scenarios. Re-thresholding %d cutoffs…", len(base_results), len(thresholds))

    out: dict[str, dict] = {}
    for thr in thresholds:
        # Re-flag every scenario at this threshold.
        rethresh = [
            type(r)(
                scenario_id=r.scenario_id,
                is_scam_truth=r.is_scam_truth,
                predicted_score=r.predicted_score,
                predicted_flag=(r.predicted_score >= thr),
                correct=((r.predicted_score >= thr) == r.is_scam_truth),
                category=r.category,
                difficulty=r.difficulty,
            )
            for r in base_results
        ]
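        # aggregate() recomputes the headline metrics (detection rate, FPR,
        # precision, recall, F1, accuracy) over the re-flagged results.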
        m = aggregate(rethresh)
        out[f"{thr:.2f}"] = {
            "threshold": thr,
            "n": m.n,
            "detection": round(m.detection_rate, 4),
            "fpr": round(m.false_positive_rate, 4),
            "precision": round(m.precision, 4),
            "recall": round(m.recall, 4),
            "f1": round(m.f1, 4),
            "accuracy": round(m.accuracy, 4),
        }
        logger.info(
            "thr=%.2f det=%.1f%% fpr=%.1f%% P=%.1f%% F1=%.3f",
            thr, m.detection_rate * 100, m.false_positive_rate * 100,
            m.precision * 100, m.f1,
        )
    # Best rows: unconstrained max-F1, plus max-F1 subject to FPR caps
    # (None when no threshold satisfies the cap).
    return {
        "thresholds": out,
        "best_by_f1": max(out.values(), key=lambda x: x["f1"]),
        "best_by_fpr_under_15": max(
            (v for v in out.values() if v["fpr"] <= 0.15),
            key=lambda x: x["f1"],
            default=None,
        ),
        "best_by_fpr_under_10": max(
            (v for v in out.values() if v["fpr"] <= 0.10),
            key=lambda x: x["f1"],
            default=None,
        ),
    }


def main(argv: list[str] | None = None) -> int:
    parser = argparse.ArgumentParser(description="Sweep flag thresholds on a trained LoRA analyzer.")
    parser.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--lora", required=True, type=Path, help="Path to LoRA adapter dir")
    parser.add_argument("--dataset", type=Path, default=DEFAULT_DATASET)
    parser.add_argument("--output", type=Path, required=True, help="Where to write the sweep JSON")
    parser.add_argument(
        "--thresholds",
        type=float,
        nargs="+",
        default=DEFAULT_THRESHOLDS,
        help="Threshold values to try",
    )
    parser.add_argument("--load-in-4bit", action="store_true", help="Force 4-bit load (smaller VRAM)")
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    logger.info("Loading LoRA adapter from %s", args.lora)
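    # Assumes LLMAnalyzer falls back to a standard transformers + PEFT load when
    # use_unsloth=False; pass --load-in-4bit when VRAM is tight.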
    analyzer = LLMAnalyzer(
        model_name=args.model,
        lora_path=str(args.lora),
        use_unsloth=False,
        load_in_4bit=args.load_in_4bit,
    )
    analyzer.load()

    dataset = load_dataset(args.dataset)
    logger.info("Loaded %d scenarios from %s", len(dataset), args.dataset)

    result = sweep(analyzer, dataset, args.thresholds)

    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        json.dump(result, f, indent=2)
    logger.info("Wrote sweep results to %s", args.output)
    # Print summary table.
    print()
    print("=== THRESHOLD SWEEP SUMMARY ===")
    print(f"{'thr':<6}{'det':<8}{'fpr':<8}{'prec':<8}{'f1':<8}")
    for thr_str, row in result["thresholds"].items():
        print(
            f"{thr_str:<6}"
            f"{row['detection']:<8.3f}"
            f"{row['fpr']:<8.3f}"
            f"{row['precision']:<8.3f}"
            f"{row['f1']:<8.3f}"
        )
    print()
    best = result["best_by_f1"]
    print(f"Best F1: thr={best['threshold']} F1={best['f1']:.3f} FPR={best['fpr']:.3f}")
    if result["best_by_fpr_under_15"]:
        b = result["best_by_fpr_under_15"]
        print(f"Best F1 with FPR<=15%: thr={b['threshold']} F1={b['f1']:.3f} FPR={b['fpr']:.3f}")
    if result["best_by_fpr_under_10"]:
        b = result["best_by_fpr_under_10"]
        print(f"Best F1 with FPR<=10%: thr={b['threshold']} F1={b['f1']:.3f} FPR={b['fpr']:.3f}")
    return 0


if __name__ == "__main__":
    sys.exit(main())