"""Run the held-out SENTINEL evaluation suite and write JSON/Markdown reports."""

from __future__ import annotations

import argparse
import sys
from pathlib import Path

# Make the repository root importable so `proof_pack` resolves when this file
# is executed directly as a script rather than as an installed package.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import proof_pack
from sentinel.evaluation import (
    DEFAULT_EVAL_OUTPUT_DIR,
    DEFAULT_HELD_OUT_TASK_IDS,
    DEFAULT_OOD_EVAL_SEEDS,
    build_eval_report,
    evaluate_tripwire_policy,
    parse_seed_spec,
    write_eval_report,
)


def main() -> None:
    parser = argparse.ArgumentParser(description="Run held-out SENTINEL evaluation.")
    parser.add_argument("--seeds", type=str, default="100-104", help="Comma list or range of held-out seeds.")
    parser.add_argument("--baseline-checkpoint", type=str, default="", help="Optional baseline checkpoint.")
    parser.add_argument("--candidate-checkpoint", type=str, default="", help="Optional candidate checkpoint.")
    parser.add_argument("--base-model", type=str, default="", help="Optional base model for adapter checkpoints.")
    parser.add_argument("--baseline-label", type=str, default="", help="Display label for the baseline policy.")
    parser.add_argument("--candidate-label", type=str, default="", help="Display label for the candidate policy.")
    parser.add_argument("--ood-seeds", type=str, default="200-204", help="Comma list or range of OOD held-out seeds.")
    parser.add_argument("--skip-tripwires", action="store_true", help="Skip the policy-level tripwire evaluation suite.")
    parser.add_argument("--best-of-k", type=int, default=4, help="Sample K first-step decisions and score the best one separately.")
    parser.add_argument("--sampling-temperature", type=float, default=0.8, help="Temperature used for sampled Best-of-K evaluation.")
    parser.add_argument("--skip-best-of-k", action="store_true", help="Skip the sampled Top-1 vs Best-of-K comparison.")
    parser.add_argument("--output-dir", type=str, default=str(DEFAULT_EVAL_OUTPUT_DIR), help="Where to write the eval report.")
    parser.add_argument("--dry-run", action="store_true", help="Validate config and exit without executing episodes.")
    args = parser.parse_args()

    seeds = parse_seed_spec(args.seeds)
    ood_seeds = parse_seed_spec(args.ood_seeds) if args.ood_seeds else list(DEFAULT_OOD_EVAL_SEEDS)

    # Dry run: echo the resolved configuration and exit without running episodes.
    if args.dry_run:
        print(
            {
                "seeds": seeds,
                "ood_seeds": ood_seeds,
                "baseline_checkpoint": args.baseline_checkpoint or None,
                "candidate_checkpoint": args.candidate_checkpoint or None,
                "base_model": args.base_model or None,
                "tripwires": not args.skip_tripwires,
                "best_of_k": None if args.skip_best_of_k else max(1, int(args.best_of_k)),
                "sampling_temperature": float(args.sampling_temperature),
                "output_dir": args.output_dir,
            }
        )
        return

    # Resolve baseline and candidate policy specs; each falls back to a
    # built-in policy when no checkpoint or label is supplied.
    baseline_spec = proof_pack._resolve_policy_spec(
        label=args.baseline_label or None,
        checkpoint=args.baseline_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="approve_all",
        fallback_policy=proof_pack._approve_all_policy,
    )
    candidate_spec = proof_pack._resolve_policy_spec(
        label=args.candidate_label or None,
        checkpoint=args.candidate_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="corrective_policy",
        fallback_policy=proof_pack._corrective_policy,
    )

    baseline_runs = []
    candidate_runs = []
    baseline_sampling_top1_runs = []
    candidate_sampling_top1_runs = []
    baseline_best_of_k_runs = []
    candidate_best_of_k_runs = []
    baseline_ood_runs = []
    candidate_ood_runs = []

    # Run every held-out task against both the in-distribution and OOD seeds,
    # collecting paired baseline/candidate episodes for the report.
    for task_id in DEFAULT_HELD_OUT_TASK_IDS:
        for seed in seeds:
            baseline_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )
            # Sampled Top-1 vs Best-of-K comparison at the configured temperature.
            if not args.skip_best_of_k and args.best_of_k > 1:
                baseline_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=baseline_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                candidate_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=candidate_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                baseline_sampling_top1_runs.append(baseline_sampled["top1"])
                candidate_sampling_top1_runs.append(candidate_sampled["top1"])
                baseline_best_of_k_runs.append(baseline_sampled["best"])
                candidate_best_of_k_runs.append(candidate_sampled["best"])
        for seed in ood_seeds:
            baseline_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )

    # Policy-level tripwire suite, run against both policies unless skipped.
    baseline_tripwire = None
    candidate_tripwire = None
    if not args.skip_tripwires:
        baseline_tripwire = evaluate_tripwire_policy(baseline_spec.name, baseline_spec.policy)
        candidate_tripwire = evaluate_tripwire_policy(candidate_spec.name, candidate_spec.policy)

    # Aggregate all run groups into a single report; optional sections are
    # passed as None when their corresponding evaluations were skipped.
    report = build_eval_report(
        baseline_runs=baseline_runs,
        candidate_runs=candidate_runs,
        baseline_label=baseline_spec.name,
        candidate_label=candidate_spec.name,
        seeds=seeds,
        best_of_k=args.best_of_k,
        sampling_temperature=args.sampling_temperature,
        baseline_sampling_top1_runs=baseline_sampling_top1_runs if baseline_sampling_top1_runs else None,
        candidate_sampling_top1_runs=candidate_sampling_top1_runs if candidate_sampling_top1_runs else None,
        baseline_best_of_k_runs=baseline_best_of_k_runs if baseline_best_of_k_runs else None,
        candidate_best_of_k_runs=candidate_best_of_k_runs if candidate_best_of_k_runs else None,
        ood_seeds=ood_seeds,
        baseline_ood_runs=baseline_ood_runs,
        candidate_ood_runs=candidate_ood_runs,
        baseline_tripwire=baseline_tripwire,
        candidate_tripwire=candidate_tripwire,
    )
    paths = write_eval_report(report, output_dir=args.output_dir)
    print(f"Held-out evaluation written to {paths['json_path']} and {paths['markdown_path']}")


if __name__ == "__main__":
    main()