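"""Run the held-out SENTINEL evaluation suite from the command line.

Example invocation (the script filename is illustrative; all flags are
defined in ``main`` below):

    python run_heldout_eval.py --seeds 100-104 --ood-seeds 200-204 --dry-run
"""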
from __future__ import annotations
import argparse
import sys
from pathlib import Path
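# Make the repository root importable so ``proof_pack`` and ``sentinel``
# resolve when this file is executed directly as a script.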
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
import proof_pack
from sentinel.evaluation import (
    DEFAULT_EVAL_OUTPUT_DIR,
    DEFAULT_HELD_OUT_TASK_IDS,
    DEFAULT_OOD_EVAL_SEEDS,
    build_eval_report,
    evaluate_tripwire_policy,
    parse_seed_spec,
    write_eval_report,
)

def main() -> None:
    parser = argparse.ArgumentParser(description="Run held-out SENTINEL evaluation.")
    parser.add_argument("--seeds", type=str, default="100-104", help="Comma list or range of held-out seeds.")
    parser.add_argument("--baseline-checkpoint", type=str, default="", help="Optional baseline checkpoint.")
    parser.add_argument("--candidate-checkpoint", type=str, default="", help="Optional candidate checkpoint.")
    parser.add_argument("--base-model", type=str, default="", help="Optional base model for adapter checkpoints.")
    parser.add_argument("--baseline-label", type=str, default="", help="Display label for the baseline policy.")
    parser.add_argument("--candidate-label", type=str, default="", help="Display label for the candidate policy.")
    parser.add_argument("--ood-seeds", type=str, default="200-204", help="Comma list or range of OOD held-out seeds.")
    parser.add_argument("--skip-tripwires", action="store_true", help="Skip the policy-level tripwire evaluation suite.")
    parser.add_argument("--best-of-k", type=int, default=4, help="Sample K first-step decisions and score the best one separately.")
    parser.add_argument("--sampling-temperature", type=float, default=0.8, help="Temperature used for sampled Best-of-K evaluation.")
    parser.add_argument("--skip-best-of-k", action="store_true", help="Skip the sampled Top-1 vs Best-of-K comparison.")
    parser.add_argument("--output-dir", type=str, default=str(DEFAULT_EVAL_OUTPUT_DIR), help="Where to write the eval report.")
    parser.add_argument("--dry-run", action="store_true", help="Validate config and exit without executing episodes.")
    args = parser.parse_args()
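    # Seed specs are parsed from a comma list or a range string (see the
    # --seeds help text); an empty --ood-seeds falls back to the defaults.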
    seeds = parse_seed_spec(args.seeds)
    ood_seeds = parse_seed_spec(args.ood_seeds) if args.ood_seeds else list(DEFAULT_OOD_EVAL_SEEDS)
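    # --dry-run prints the resolved configuration and exits before any episode runs.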
    if args.dry_run:
        print(
            {
                "seeds": seeds,
                "ood_seeds": ood_seeds,
                "baseline_checkpoint": args.baseline_checkpoint or None,
                "candidate_checkpoint": args.candidate_checkpoint or None,
                "base_model": args.base_model or None,
                "tripwires": not args.skip_tripwires,
                "best_of_k": None if args.skip_best_of_k else max(1, int(args.best_of_k)),
                "sampling_temperature": float(args.sampling_temperature),
                "output_dir": args.output_dir,
            }
        )
        return
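    # Resolve the baseline and candidate policies. When no checkpoint is given,
    # each spec falls back to a built-in policy from proof_pack.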
    baseline_spec = proof_pack._resolve_policy_spec(
        label=args.baseline_label or None,
        checkpoint=args.baseline_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="approve_all",
        fallback_policy=proof_pack._approve_all_policy,
    )
    candidate_spec = proof_pack._resolve_policy_spec(
        label=args.candidate_label or None,
        checkpoint=args.candidate_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="corrective_policy",
        fallback_policy=proof_pack._corrective_policy,
    )
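    # Accumulators for each evaluation regime: greedy held-out runs, sampled
    # Top-1 / Best-of-K runs, and out-of-distribution (OOD) seed runs.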
    baseline_runs = []
    candidate_runs = []
    baseline_sampling_top1_runs = []
    candidate_sampling_top1_runs = []
    baseline_best_of_k_runs = []
    candidate_best_of_k_runs = []
    baseline_ood_runs = []
    candidate_ood_runs = []
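    # Evaluate every held-out task against every held-out seed for both policies.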
    for task_id in DEFAULT_HELD_OUT_TASK_IDS:
        for seed in seeds:
            baseline_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )
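            # Sampled comparison: draw K first-step decisions at the configured
            # temperature and keep both the Top-1 run and the best-scoring run.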
            if not args.skip_best_of_k and args.best_of_k > 1:
                baseline_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=baseline_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                candidate_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=candidate_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                baseline_sampling_top1_runs.append(baseline_sampled["top1"])
                candidate_sampling_top1_runs.append(candidate_sampled["top1"])
                baseline_best_of_k_runs.append(baseline_sampled["best"])
                candidate_best_of_k_runs.append(candidate_sampled["best"])
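        # Repeat the greedy evaluation on the OOD seed set for both policies.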
        for seed in ood_seeds:
            baseline_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )
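    # Optional policy-level tripwire suite (disabled with --skip-tripwires).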
    baseline_tripwire = None
    candidate_tripwire = None
    if not args.skip_tripwires:
        baseline_tripwire = evaluate_tripwire_policy(baseline_spec.name, baseline_spec.policy)
        candidate_tripwire = evaluate_tripwire_policy(candidate_spec.name, candidate_spec.policy)
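    # Assemble the combined report; empty sampled-run lists are normalised to
    # None so the report treats the Best-of-K comparison as absent.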
    report = build_eval_report(
        baseline_runs=baseline_runs,
        candidate_runs=candidate_runs,
        baseline_label=baseline_spec.name,
        candidate_label=candidate_spec.name,
        seeds=seeds,
        best_of_k=args.best_of_k,
        sampling_temperature=args.sampling_temperature,
        baseline_sampling_top1_runs=baseline_sampling_top1_runs or None,
        candidate_sampling_top1_runs=candidate_sampling_top1_runs or None,
        baseline_best_of_k_runs=baseline_best_of_k_runs or None,
        candidate_best_of_k_runs=candidate_best_of_k_runs or None,
        ood_seeds=ood_seeds,
        baseline_ood_runs=baseline_ood_runs,
        candidate_ood_runs=candidate_ood_runs,
        baseline_tripwire=baseline_tripwire,
        candidate_tripwire=candidate_tripwire,
    )
    paths = write_eval_report(report, output_dir=args.output_dir)
    print(f"Held-out evaluation written to {paths['json_path']} and {paths['markdown_path']}")

if __name__ == "__main__":
    main()