| |
| """Run policy-stack ablations independently from GRPO training.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import sys |
|
|
# Put the directory two levels above this file at the front of sys.path so the
# absolute ``app.*`` import below resolves when this file is run directly as a
# script. The membership check avoids inserting the same entry twice.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
| from app.training.grpo_experiment import run_policy_stack_rollout |
|
|
|
|
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for an ablation run.

    Args:
        argv: Argument strings to parse. When ``None`` (the default) argparse
            falls back to ``sys.argv[1:]``, so existing zero-argument callers
            behave exactly as before. Passing an explicit list lets the
            parser be driven programmatically.

    Returns:
        Namespace with ``episodes`` (int), ``output`` (str) and
        ``checkpoint_dir`` (str).
    """
    parser = argparse.ArgumentParser(description="Run policy-stack ablations.")
    parser.add_argument(
        "--episodes",
        type=int,
        default=6,
        help="Episode count passed to each ablation rollout (default: %(default)s).",
    )
    parser.add_argument(
        "--output",
        default="outputs/reports/grpo_ablation_report.json",
        help="Path of the JSON report, joined onto the repository root "
        "(default: %(default)s).",
    )
    parser.add_argument(
        "--checkpoint-dir",
        default="checkpoints",
        help="Checkpoint directory, joined onto the repository root "
        "(default: %(default)s).",
    )
    return parser.parse_args(argv)
|
|
|
|
def main() -> None:
    """Run the three ablation rollouts and write their results as JSON.

    Resolves the checkpoint directory and the report path against the
    directory two levels above this file, creating missing directories.
    Calls ``run_policy_stack_rollout`` once per variant, in the order
    listed below, stores each return value under ``"ablations"`` in the
    report, writes the report as indented ASCII-only JSON, and prints
    ``policy_ablations_done`` on completion.
    """
    options = parse_args()
    repo_root = Path(__file__).resolve().parents[1]

    ckpt_path = repo_root / options.checkpoint_dir
    ckpt_path.mkdir(parents=True, exist_ok=True)

    # (report key, first positional argument of the rollout, seed offset).
    # Tuple order is the execution order and the key order in the report.
    variants = (
        ("bandit_only", "bandit-only", 1_200),
        ("llm_only", "llm-only", 2_200),
        ("llm_bandit", "llm+bandit", 3_200),
    )
    ablations = {
        report_key: run_policy_stack_rollout(
            stack_label,
            episodes=options.episodes,
            checkpoint_dir=ckpt_path,
            seed_offset=seed_offset,
        )
        for report_key, stack_label, seed_offset in variants
    }
    report = {"status": "ok", "ablations": ablations}

    report_path = repo_root / options.output
    report_path.parent.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(report, ensure_ascii=True, indent=2)
    report_path.write_text(serialized, encoding="utf-8")
    print("policy_ablations_done")
|
|
|
|
# Run the ablations only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()
|
|