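"""Run the held-out SENTINEL evaluation suite.

Compares a baseline policy against a candidate policy on held-out and
out-of-distribution (OOD) seeds, optionally adds a sampled Top-1 vs.
Best-of-K comparison and a policy-level tripwire suite, and writes the
resulting report as JSON and Markdown.
"""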
from __future__ import annotations

import argparse
import sys
from pathlib import Path
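
# Make the repository root importable so `proof_pack` and `sentinel` resolve
# when this script is run directly.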
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

import proof_pack
from sentinel.evaluation import (
    DEFAULT_EVAL_OUTPUT_DIR,
    DEFAULT_HELD_OUT_TASK_IDS,
    DEFAULT_OOD_EVAL_SEEDS,
    build_eval_report,
    evaluate_tripwire_policy,
    parse_seed_spec,
    write_eval_report,
)


def main() -> None:
    parser = argparse.ArgumentParser(description="Run held-out SENTINEL evaluation.")
    parser.add_argument("--seeds", type=str, default="100-104", help="Comma list or range of held-out seeds.")
    parser.add_argument("--baseline-checkpoint", type=str, default="", help="Optional baseline checkpoint.")
    parser.add_argument("--candidate-checkpoint", type=str, default="", help="Optional candidate checkpoint.")
    parser.add_argument("--base-model", type=str, default="", help="Optional base model for adapter checkpoints.")
    parser.add_argument("--baseline-label", type=str, default="", help="Display label for the baseline policy.")
    parser.add_argument("--candidate-label", type=str, default="", help="Display label for the candidate policy.")
    parser.add_argument("--ood-seeds", type=str, default="200-204", help="Comma list or range of OOD held-out seeds.")
    parser.add_argument("--skip-tripwires", action="store_true", help="Skip the policy-level tripwire evaluation suite.")
    parser.add_argument("--best-of-k", type=int, default=4, help="Sample K first-step decisions and score the best one separately.")
    parser.add_argument("--sampling-temperature", type=float, default=0.8, help="Temperature used for sampled Best-of-K evaluation.")
    parser.add_argument("--skip-best-of-k", action="store_true", help="Skip the sampled Top-1 vs Best-of-K comparison.")
    parser.add_argument("--output-dir", type=str, default=str(DEFAULT_EVAL_OUTPUT_DIR), help="Where to write the eval report.")
    parser.add_argument("--dry-run", action="store_true", help="Validate config and exit without executing episodes.")
    args = parser.parse_args()
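
    # Expand the seed specs; an empty --ood-seeds falls back to the library default.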
    seeds = parse_seed_spec(args.seeds)
    ood_seeds = parse_seed_spec(args.ood_seeds) if args.ood_seeds else list(DEFAULT_OOD_EVAL_SEEDS)
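
    # --dry-run prints the resolved configuration and exits before any episodes run.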
    if args.dry_run:
        print(
            {
                "seeds": seeds,
                "ood_seeds": ood_seeds,
                "baseline_checkpoint": args.baseline_checkpoint or None,
                "candidate_checkpoint": args.candidate_checkpoint or None,
                "base_model": args.base_model or None,
                "tripwires": not args.skip_tripwires,
                "best_of_k": None if args.skip_best_of_k else max(1, int(args.best_of_k)),
                "sampling_temperature": float(args.sampling_temperature),
                "output_dir": args.output_dir,
            }
        )
        return
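
    # Resolve the two policies under comparison; without checkpoints they fall
    # back to proof_pack's approve_all baseline and corrective candidate.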
    baseline_spec = proof_pack._resolve_policy_spec(
        label=args.baseline_label or None,
        checkpoint=args.baseline_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="approve_all",
        fallback_policy=proof_pack._approve_all_policy,
    )
    candidate_spec = proof_pack._resolve_policy_spec(
        label=args.candidate_label or None,
        checkpoint=args.candidate_checkpoint or None,
        base_model=args.base_model or None,
        fallback_name="corrective_policy",
        fallback_policy=proof_pack._corrective_policy,
    )
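
    # Run accumulators: greedy held-out, sampled Top-1, Best-of-K, and OOD episodes.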
    baseline_runs = []
    candidate_runs = []
    baseline_sampling_top1_runs = []
    candidate_sampling_top1_runs = []
    baseline_best_of_k_runs = []
    candidate_best_of_k_runs = []
    baseline_ood_runs = []
    candidate_ood_runs = []
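
    # Greedy evaluation of both policies on every held-out task/seed pair.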
    for task_id in DEFAULT_HELD_OUT_TASK_IDS:
        for seed in seeds:
            baseline_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )
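            # Optionally sample K first-step decisions per episode, keeping the
            # Top-1 run and the best-scoring run for a side-by-side comparison.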
            if not args.skip_best_of_k and args.best_of_k > 1:
                baseline_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=baseline_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                candidate_sampled = proof_pack.evaluate_policy_best_of_k(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_spec=candidate_spec,
                    num_samples=args.best_of_k,
                    temperature=args.sampling_temperature,
                    eval_mode=True,
                )
                baseline_sampling_top1_runs.append(baseline_sampled["top1"])
                candidate_sampling_top1_runs.append(candidate_sampled["top1"])
                baseline_best_of_k_runs.append(baseline_sampled["best"])
                candidate_best_of_k_runs.append(candidate_sampled["best"])
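        # Repeat the greedy evaluation on the out-of-distribution seeds.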
        for seed in ood_seeds:
            baseline_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=baseline_spec.name,
                    policy=baseline_spec.policy,
                    eval_mode=True,
                )
            )
            candidate_ood_runs.append(
                proof_pack.run_episode(
                    task_id=task_id,
                    variant_seed=seed,
                    policy_name=candidate_spec.name,
                    policy=candidate_spec.policy,
                    eval_mode=True,
                )
            )
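
    # Run the policy-level tripwire suite unless explicitly skipped.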
    baseline_tripwire = None
    candidate_tripwire = None
    if not args.skip_tripwires:
        baseline_tripwire = evaluate_tripwire_policy(baseline_spec.name, baseline_spec.policy)
        candidate_tripwire = evaluate_tripwire_policy(candidate_spec.name, candidate_spec.policy)
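
    # Assemble the report; empty sampled-run lists are passed as None
    # (Best-of-K skipped or K <= 1).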
    report = build_eval_report(
        baseline_runs=baseline_runs,
        candidate_runs=candidate_runs,
        baseline_label=baseline_spec.name,
        candidate_label=candidate_spec.name,
        seeds=seeds,
        best_of_k=args.best_of_k,
        sampling_temperature=args.sampling_temperature,
        baseline_sampling_top1_runs=baseline_sampling_top1_runs or None,
        candidate_sampling_top1_runs=candidate_sampling_top1_runs or None,
        baseline_best_of_k_runs=baseline_best_of_k_runs or None,
        candidate_best_of_k_runs=candidate_best_of_k_runs or None,
        ood_seeds=ood_seeds,
        baseline_ood_runs=baseline_ood_runs,
        candidate_ood_runs=candidate_ood_runs,
        baseline_tripwire=baseline_tripwire,
        candidate_tripwire=candidate_tripwire,
    )
    paths = write_eval_report(report, output_dir=args.output_dir)
    print(f"Held-out evaluation written to {paths['json_path']} and {paths['markdown_path']}")


if __name__ == "__main__":
    main()