polyguard-openenv-training-3b-continuation / scripts /evaluate_policy_ablations.py
adithya9903's picture
Deploy PolyGuard HF training Space
fd0c71a verified
#!/usr/bin/env python3
"""Run the policy-stack ablations independently of GRPO training and save a JSON report."""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import sys
# Directory two levels above this file (the parent of the folder holding this
# script). It is put on ``sys.path`` so the ``app`` package import below
# resolves when the script is run directly rather than as an installed module.
ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    # Prepend (equivalent to ``insert(0, ...)``) so this checkout is searched
    # before any other ``sys.path`` entry.
    sys.path[:0] = [str(ROOT)]
from app.training.grpo_experiment import run_policy_stack_rollout
def _positive_int(text: str) -> int:
    """Parse *text* as an integer >= 1; used as an argparse ``type`` callable."""
    try:
        value = int(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected an integer, got {text!r}") from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the command-line options for the ablation run.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, so calling this with
            no arguments behaves exactly as before.

    Returns:
        Namespace with ``episodes`` (a positive int), ``output`` and
        ``checkpoint_dir`` (strings; ``main`` joins them onto the repository
        root, so relative values are relative to that root).

    Raises:
        SystemExit: On invalid options (argparse exits with status 2), e.g. a
            non-positive ``--episodes``.
    """
    parser = argparse.ArgumentParser(description="Run policy-stack ablations.")
    parser.add_argument(
        "--episodes",
        type=_positive_int,  # zero/negative episode counts are rejected up front
        default=6,
        help="episodes per ablation (default: %(default)s)",
    )
    parser.add_argument(
        "--output",
        default="outputs/reports/grpo_ablation_report.json",
        help="JSON report path (default: %(default)s)",
    )
    parser.add_argument(
        "--checkpoint-dir",
        default="checkpoints",
        help="checkpoint directory (default: %(default)s)",
    )
    return parser.parse_args(argv)
# Ablation specs, in run order: (report key, policy-stack name passed to the
# rollout, seed offset). Each stack gets its own ``seed_offset``; presumably so
# the ablations draw distinct episode seeds -- confirm in
# ``run_policy_stack_rollout``.
_ABLATIONS: tuple[tuple[str, str, int], ...] = (
    ("bandit_only", "bandit-only", 1_200),
    ("llm_only", "llm-only", 2_200),
    ("llm_bandit", "llm+bandit", 3_200),
)


def main() -> None:
    """Run each policy-stack ablation in sequence and write one JSON report.

    ``--checkpoint-dir`` and ``--output`` are joined onto ``ROOT`` (the
    repository root); per ``pathlib`` semantics an absolute value replaces the
    root. Both directories are created if missing. The report maps each
    ablation key to whatever ``run_policy_stack_rollout`` returns. ``status``
    is always ``"ok"`` because a failing rollout raises before anything is
    written.
    """
    args = parse_args()
    # Reuse the module-level ROOT instead of recomputing the same path here.
    checkpoint_dir = ROOT / args.checkpoint_dir
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Dict comprehension preserves the spec order, so rollouts run and appear
    # in the report in the same order as before.
    ablations = {
        key: run_policy_stack_rollout(
            stack,
            episodes=args.episodes,
            checkpoint_dir=checkpoint_dir,
            seed_offset=seed_offset,
        )
        for key, stack, seed_offset in _ABLATIONS
    }
    payload = {"status": "ok", "ablations": ablations}

    output_path = ROOT / args.output
    output_path.parent.mkdir(parents=True, exist_ok=True)
    report = json.dumps(payload, ensure_ascii=True, indent=2)
    output_path.write_text(report, encoding="utf-8")
    print("policy_ablations_done")
# Run the ablations only when executed as a script, not when imported.
if __name__ == "__main__":
    main()