Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Run policy-stack ablations independently from GRPO training.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| import sys | |
# ROOT is the parent of the directory that holds this script. It is put at the
# front of sys.path (only if not already present) so that the project-local
# ``app`` package imported below can be found when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
_root_str = str(ROOT)
if _root_str not in sys.path:
    sys.path.insert(0, _root_str)
del _root_str
| from app.training.grpo_experiment import run_policy_stack_rollout | |
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for the policy-stack ablation run.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is the
            original zero-argument behaviour; passing a list allows the
            parser to be driven programmatically (e.g. from tests).

    Returns:
        Namespace with ``episodes`` (int), ``output`` (str) and
        ``checkpoint_dir`` (str).
    """
    parser = argparse.ArgumentParser(description="Run policy-stack ablations.")
    parser.add_argument(
        "--episodes",
        type=int,
        default=6,
        help="Episodes per policy-stack rollout (default: %(default)s).",
    )
    parser.add_argument(
        "--output",
        default="outputs/reports/grpo_ablation_report.json",
        help=(
            "JSON report path; a relative path is taken relative to the parent "
            "of this script's directory (default: %(default)s)."
        ),
    )
    parser.add_argument(
        "--checkpoint-dir",
        default="checkpoints",
        help=(
            "Checkpoint directory passed to every rollout; a relative path is "
            "taken relative to the parent of this script's directory "
            "(default: %(default)s)."
        ),
    )
    return parser.parse_args(argv)
def main() -> None:
    """Run each policy-stack ablation and write the combined JSON report.

    Every ablation calls ``run_policy_stack_rollout`` with the ``--episodes``
    count, the shared checkpoint directory and its own fixed ``seed_offset``.
    Relative ``--checkpoint-dir`` and ``--output`` values are resolved against
    ``ROOT`` (the parent of this script's directory) rather than the current
    working directory; absolute values are used unchanged, per ``pathlib``
    join semantics. Missing directories are created. The report is written as
    ASCII-escaped, 2-space-indented UTF-8 JSON, after which a completion
    marker is printed to stdout.
    """
    args = parse_args()
    # Reuse the module-level ROOT instead of re-deriving it from __file__, so
    # there is a single definition of where relative paths are anchored.
    checkpoint_dir = ROOT / args.checkpoint_dir
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # (report key, policy-stack name, seed offset). Keeping each stack's key,
    # name and seed side by side avoids copy-paste drift between entries.
    ablation_specs = (
        ("bandit_only", "bandit-only", 1_200),
        ("llm_only", "llm-only", 2_200),
        ("llm_bandit", "llm+bandit", 3_200),
    )
    # Rollouts run sequentially in table order; report keys keep that order.
    ablations = {
        report_key: run_policy_stack_rollout(
            stack_name,
            episodes=args.episodes,
            checkpoint_dir=checkpoint_dir,
            seed_offset=seed_offset,
        )
        for report_key, stack_name, seed_offset in ablation_specs
    }
    payload = {"status": "ok", "ablations": ablations}

    output_path = ROOT / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(
        json.dumps(payload, ensure_ascii=True, indent=2), encoding="utf-8"
    )
    print("policy_ablations_done")


if __name__ == "__main__":
    main()