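"""Run the held-out SENTINEL evaluation.

Compares a baseline policy and a candidate policy on held-out task/seed
pairs, optionally adds sampled Top-1 vs Best-of-K runs, out-of-distribution
seeds, and tripwire checks, then writes the report as JSON and Markdown.
"""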
from __future__ import annotations

import argparse
import sys
from pathlib import Path

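# Make the repository root importable so proof_pack and sentinel resolve when
# the script is executed directly.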
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import proof_pack
from sentinel.evaluation import (
    DEFAULT_EVAL_OUTPUT_DIR,
    DEFAULT_HELD_OUT_TASK_IDS,
    DEFAULT_OOD_EVAL_SEEDS,
    build_eval_report,
    evaluate_tripwire_policy,
    parse_seed_spec,
    write_eval_report,
)


def main() -> None:
    parser = argparse.ArgumentParser(description="Run held-out SENTINEL evaluation.")
    parser.add_argument("--seeds", type=str, default="100-104", help="Comma list or range of held-out seeds.")
    parser.add_argument("--baseline-checkpoint", type=str, default="", help="Optional baseline checkpoint.")
    parser.add_argument("--candidate-checkpoint", type=str, default="", help="Optional candidate checkpoint.")
    parser.add_argument("--base-model", type=str, default="", help="Optional base model for adapter checkpoints.")
    parser.add_argument("--baseline-label", type=str, default="", help="Display label for the baseline policy.")
    parser.add_argument("--candidate-label", type=str, default="", help="Display label for the candidate policy.")
    parser.add_argument("--ood-seeds", type=str, default="200-204", help="Comma list or range of OOD held-out seeds.")
    parser.add_argument("--skip-tripwires", action="store_true", help="Skip the policy-level tripwire evaluation suite.")
    parser.add_argument("--best-of-k", type=int, default=4, help="Sample K first-step decisions and score the best one separately.")
    parser.add_argument("--sampling-temperature", type=float, default=0.8, help="Temperature used for sampled Best-of-K evaluation.")
    parser.add_argument("--skip-best-of-k", action="store_true", help="Skip the sampled Top-1 vs Best-of-K comparison.")
    parser.add_argument("--output-dir", type=str, default=str(DEFAULT_EVAL_OUTPUT_DIR), help="Where to write the eval report.")
    parser.add_argument("--dry-run", action="store_true", help="Validate config and exit without executing episodes.")
    args = parser.parse_args()

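    # Resolve the held-out seed spec; an empty --ood-seeds falls back to the defaults.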
    seeds = parse_seed_spec(args.seeds)
    ood_seeds = parse_seed_spec(args.ood_seeds) if args.ood_seeds else list(DEFAULT_OOD_EVAL_SEEDS)
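    # A dry run only echoes the resolved configuration; no episodes are executed.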
    if args.dry_run:
        print(
            {
                "seeds": seeds,
                "ood_seeds": ood_seeds,
                "baseline_checkpoint": args.baseline_checkpoint or None,
                "candidate_checkpoint": args.candidate_checkpoint or None,
                "base_model": args.base_model or None,
                "tripwires": not args.skip_tripwires,
                "best_of_k": None if args.skip_best_of_k else max(1, int(args.best_of_k)),
                "sampling_temperature": float(args.sampling_temperature),
                "output_dir": args.output_dir,
            }
        )
        return

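    # Resolve policy specs; without checkpoints, fall back to the bundled
    # approve_all and corrective_policy behaviours.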
    baseline_spec = proof_pack._resolve_policy_spec(
        label=args.baseline_label or None,
        checkpoint=args.baseline_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="approve_all",
        fallback_policy=proof_pack._approve_all_policy,
    )
    candidate_spec = proof_pack._resolve_policy_spec(
        label=args.candidate_label or None,
        checkpoint=args.candidate_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="corrective_policy",
        fallback_policy=proof_pack._corrective_policy,
    )

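    # Run baseline and candidate on every held-out task/seed pair; optionally
    # collect sampled Top-1 and Best-of-K runs alongside the greedy ones.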
    baseline_runs = []
    candidate_runs = []
    baseline_sampling_top1_runs = []
    candidate_sampling_top1_runs = []
    baseline_best_of_k_runs = []
    candidate_best_of_k_runs = []
    baseline_ood_runs = []
    candidate_ood_runs = []
    for task_id in DEFAULT_HELD_OUT_TASK_IDS:
        for seed in seeds:
            baseline_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )
            if not args.skip_best_of_k and args.best_of_k > 1:
                baseline_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=baseline_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                candidate_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=candidate_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                baseline_sampling_top1_runs.append(baseline_sampled["top1"])
                candidate_sampling_top1_runs.append(candidate_sampled["top1"])
                baseline_best_of_k_runs.append(baseline_sampled["best"])
                candidate_best_of_k_runs.append(candidate_sampled["best"])
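        # Repeat the comparison on the out-of-distribution seed range.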
        for seed in ood_seeds:
            baseline_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )

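    # Tripwire evaluation runs for both policies unless --skip-tripwires is set.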
    baseline_tripwire = None
    candidate_tripwire = None
    if not args.skip_tripwires:
        baseline_tripwire = evaluate_tripwire_policy(baseline_spec.name, baseline_spec.policy)
        candidate_tripwire = evaluate_tripwire_policy(candidate_spec.name, candidate_spec.policy)

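    # Assemble the comparison report and write the JSON and Markdown outputs.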
    report = build_eval_report(
        baseline_runs=baseline_runs,
        candidate_runs=candidate_runs,
        baseline_label=baseline_spec.name,
        candidate_label=candidate_spec.name,
        seeds=seeds,
        best_of_k=args.best_of_k,
        sampling_temperature=args.sampling_temperature,
        baseline_sampling_top1_runs=baseline_sampling_top1_runs if baseline_sampling_top1_runs else None,
        candidate_sampling_top1_runs=candidate_sampling_top1_runs if candidate_sampling_top1_runs else None,
        baseline_best_of_k_runs=baseline_best_of_k_runs if baseline_best_of_k_runs else None,
        candidate_best_of_k_runs=candidate_best_of_k_runs if candidate_best_of_k_runs else None,
        ood_seeds=ood_seeds,
        baseline_ood_runs=baseline_ood_runs,
        candidate_ood_runs=candidate_ood_runs,
        baseline_tripwire=baseline_tripwire,
        candidate_tripwire=candidate_tripwire,
    )
    paths = write_eval_report(report, output_dir=args.output_dir)
    print(f"Held-out evaluation written to {paths['json_path']} and {paths['markdown_path']}")


if __name__ == "__main__":
    main()