File size: 6,218 Bytes
03815d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""Threshold sweep — score a trained LoRA analyzer once, re-threshold many times.

Use case: you've trained a LoRA and eval at threshold=0.5 shows strong recall but
high FPR (classic over-flagging from reward hacking). Rather than retrain, sweep
the flag threshold over a grid from 0.3 to 0.9 (see DEFAULT_THRESHOLDS) to find the P/R sweet spot.

KEY OPTIMIZATION: the LoRA produces a CONTINUOUS score per scenario. We run it
ONCE across the 175 bench scenarios (~15 min), cache `(scenario_id, score)`, then
apply different thresholds to the cached scores — each re-threshold is <1 second.

Usage:
    # On Colab after training:
    python -m eval.threshold_sweep \\
        --model Qwen/Qwen2.5-7B-Instruct \\
        --lora /content/drive/MyDrive/chakravyuh/analyzer_lora \\
        --output /content/drive/MyDrive/chakravyuh/threshold_sweep.json

Output JSON:
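    {
      "thresholds": {
        "0.30": {"threshold": 0.3, "detection": 1.00, "fpr": 0.48, "precision": 0.88, "f1": 0.94, ...},
        "0.40": {...},
        ...
      },
      "best_by_f1": {"threshold": 0.7, "f1": 0.945, ...},
      "best_by_fpr_under_15": {"threshold": 0.75, ...},
      "best_by_fpr_under_10": {...}
    }

Keys under "thresholds" are the cutoff formatted to two decimals. Each
"best_by_fpr_under_*" entry is null when no threshold meets that FPR bound.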
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from dataclasses import asdict
from pathlib import Path

from chakravyuh_env.agents.llm_analyzer import LLMAnalyzer
from eval.mode_c_real_cases import (
    DEFAULT_DATASET,
    aggregate,
    load_dataset,
    per_category_breakdown,
    per_difficulty_breakdown,
    run_eval,
)

# Module-level logger; main() installs the handler/format via logging.basicConfig.
logger = logging.getLogger("chakravyuh.threshold_sweep")

# Cutoffs evaluated when --thresholds is not given on the command line.
DEFAULT_THRESHOLDS = [0.3, 0.4, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9]


def sweep(
    analyzer: LLMAnalyzer,
    dataset: list[dict],
    thresholds: list[float],
) -> dict:
    """Score every scenario once, then compute metrics at each threshold.

    Args:
        analyzer: Analyzer handed to ``run_eval`` to produce per-scenario scores.
        dataset: Scenario records handed to ``run_eval``.
        thresholds: Flag cutoffs to evaluate. A scenario counts as flagged when
            its ``predicted_score`` is ``>=`` the cutoff.

    Returns:
        A dict with four entries:

        * ``"thresholds"``: one metrics row per cutoff, keyed by the cutoff
          formatted to two decimals (e.g. ``"0.50"``). Cutoffs that format to
          the same key overwrite each other.
        * ``"best_by_f1"``: the row with the highest F1, or ``None`` when
          ``thresholds`` is empty.
        * ``"best_by_fpr_under_15"`` / ``"best_by_fpr_under_10"``: the
          highest-F1 row among those with ``fpr <= 0.15`` / ``fpr <= 0.10``,
          or ``None`` when no row qualifies.
    """
    # Single inference pass. The threshold given here does not matter: only the
    # continuous predicted_score of each result is reused below.
    base_results = run_eval(analyzer, dataset, threshold=0.5)

    logger.info("Scored %d scenarios. Re-thresholding %d cutoffs…", len(base_results), len(thresholds))

    out: dict[str, dict] = {}
    for thr in thresholds:
        # Rebuild each result with the flag/correctness implied by this cutoff;
        # the score and the ground-truth fields are carried over unchanged.
        rethresh = []
        for r in base_results:
            flagged = r.predicted_score >= thr
            rethresh.append(
                type(r)(
                    scenario_id=r.scenario_id,
                    is_scam_truth=r.is_scam_truth,
                    predicted_score=r.predicted_score,
                    predicted_flag=flagged,
                    correct=(flagged == r.is_scam_truth),
                    category=r.category,
                    difficulty=r.difficulty,
                )
            )
        m = aggregate(rethresh)
        out[f"{thr:.2f}"] = {
            "threshold": thr,
            "n": m.n,
            "detection": round(m.detection_rate, 4),
            "fpr": round(m.false_positive_rate, 4),
            "precision": round(m.precision, 4),
            "recall": round(m.recall, 4),
            "f1": round(m.f1, 4),
            "accuracy": round(m.accuracy, 4),
        }
        logger.info(
            "thr=%.2f  det=%.1f%%  fpr=%.1f%%  P=%.1f%%  F1=%.3f",
            thr, m.detection_rate * 100, m.false_positive_rate * 100,
            m.precision * 100, m.f1,
        )

    rows = list(out.values())
    return {
        "thresholds": out,
        "best_by_f1": _best_f1_row(rows),
        "best_by_fpr_under_15": _best_f1_row(rows, max_fpr=0.15),
        "best_by_fpr_under_10": _best_f1_row(rows, max_fpr=0.10),
    }


def _best_f1_row(rows: list[dict], max_fpr: float | None = None) -> dict | None:
    """Return the highest-F1 row (first one on ties), optionally FPR-bounded.

    When ``max_fpr`` is given, only rows with ``fpr <= max_fpr`` are considered.
    Returns ``None`` instead of raising when no row is eligible.
    """
    if max_fpr is not None:
        rows = [row for row in rows if row["fpr"] <= max_fpr]
    return max(rows, key=lambda row: row["f1"], default=None)


def main(argv: list[str] | None = None) -> int:
    """CLI entry point: load the analyzer, run the sweep, save and summarize it.

    Args:
        argv: Command-line arguments to parse; ``None`` makes argparse read
            them from ``sys.argv``.

    Returns:
        The process exit code, always 0. Failures surface as exceptions.
    """
    parser = argparse.ArgumentParser(description="Sweep flag thresholds on a trained LoRA analyzer.")
    parser.add_argument("--model", default="Qwen/Qwen2.5-7B-Instruct")
    parser.add_argument("--lora", required=True, type=Path, help="Path to LoRA adapter dir")
    parser.add_argument("--dataset", type=Path, default=DEFAULT_DATASET)
    parser.add_argument("--output", type=Path, required=True, help="Where to write the sweep JSON")
    parser.add_argument(
        "--thresholds",
        type=float,
        nargs="+",
        default=DEFAULT_THRESHOLDS,
        help="Threshold values to try",
    )
    parser.add_argument("--load-in-4bit", action="store_true", help="Force 4-bit load (smaller VRAM)")
    args = parser.parse_args(argv)

    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    )

    logger.info("Loading LoRA adapter from %s", args.lora)
    analyzer = LLMAnalyzer(
        model_name=args.model,
        lora_path=str(args.lora),
        use_unsloth=False,
        load_in_4bit=args.load_in_4bit,
    )
    analyzer.load()

    dataset = load_dataset(args.dataset)
    logger.info("Loaded %d scenarios from %s", len(dataset), args.dataset)

    result = sweep(analyzer, dataset, args.thresholds)

    # Create the destination directory first so a fresh output path works.
    args.output.parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    logger.info("Wrote sweep results to %s", args.output)

    # Print summary table. Rows use the same column widths as the header so
    # the numbers line up under their headings.
    print()
    print("=== THRESHOLD SWEEP SUMMARY ===")
    print(f"{'thr':<6}{'det':<8}{'fpr':<8}{'prec':<8}{'f1':<8}")
    for thr_str, row in result["thresholds"].items():
        print(
            f"{thr_str:<6}"
            f"{row['detection']:<8.3f}"
            f"{row['fpr']:<8.3f}"
            f"{row['precision']:<8.3f}"
            f"{row['f1']:<8.3f}"
        )
    print()
    best = result["best_by_f1"]
    print(f"Best F1:  thr={best['threshold']}  F1={best['f1']:.3f}  FPR={best['fpr']:.3f}")
    # The FPR-bounded picks are None when no threshold met the bound; skip those.
    for key, pct in (("best_by_fpr_under_15", 15), ("best_by_fpr_under_10", 10)):
        b = result[key]
        if b is not None:
            print(f"Best F1 with FPR<{pct}%:  thr={b['threshold']}  F1={b['f1']:.3f}  FPR={b['fpr']:.3f}")

    return 0


# Run the CLI only when executed as a script; main()'s return value becomes
# the process exit status.
if __name__ == "__main__":
    sys.exit(main())