#!/usr/bin/env python3 """Run policy-stack ablations independently from GRPO training.""" from __future__ import annotations import argparse import json from pathlib import Path import sys ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from app.training.grpo_experiment import run_policy_stack_rollout def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Run policy-stack ablations.") parser.add_argument("--episodes", type=int, default=6) parser.add_argument("--output", default="outputs/reports/grpo_ablation_report.json") parser.add_argument("--checkpoint-dir", default="checkpoints") return parser.parse_args() def main() -> None: args = parse_args() root = Path(__file__).resolve().parents[1] checkpoint_dir = root / args.checkpoint_dir checkpoint_dir.mkdir(parents=True, exist_ok=True) payload = { "status": "ok", "ablations": { "bandit_only": run_policy_stack_rollout( "bandit-only", episodes=args.episodes, checkpoint_dir=checkpoint_dir, seed_offset=1_200, ), "llm_only": run_policy_stack_rollout( "llm-only", episodes=args.episodes, checkpoint_dir=checkpoint_dir, seed_offset=2_200, ), "llm_bandit": run_policy_stack_rollout( "llm+bandit", episodes=args.episodes, checkpoint_dir=checkpoint_dir, seed_offset=3_200, ), }, } output_path = root / args.output output_path.parent.mkdir(parents=True, exist_ok=True) output_path.write_text(json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8") print("policy_ablations_done") if __name__ == "__main__": main()