# Source: polyguard-openenv/app/training/planner_grpo.py
# TheJackBright's picture
# Deploy PolyGuard OpenEnv Space
# 877add7 verified
"""Planner GRPO-like trainer."""
from __future__ import annotations
from pathlib import Path
from app.agents.orchestrator import Orchestrator
from app.env.env_core import PolyGuardEnv
from app.training.checkpointing import save_checkpoint
from app.training.metrics import TrainingMetrics
from app.training.replay_buffer import ReplayBuffer, failure_mining_summary
def train_planner_grpo(episodes: int = 20, checkpoint_dir: Path | None = None) -> dict:
    """Roll out the orchestrator for ``episodes`` episodes and summarise the run.

    A single ``PolyGuardEnv`` / ``Orchestrator`` pair is reused for every
    episode. Episode ``k`` is reset with seed ``101 + k``; episodes with
    ``k < episodes // 2`` use difficulty ``"medium"`` and the rest ``"hard"``.
    Every step result is fed to a ``TrainingMetrics`` accumulator and appended
    to a ``ReplayBuffer``.

    NOTE(review): this function only collects metrics and replay records; no
    explicit policy update is visible here. If learning happens, it presumably
    happens inside ``Orchestrator.run_step`` -- confirm.

    Args:
        episodes: Number of episodes to roll out.
        checkpoint_dir: When truthy, the summary is saved to
            ``planner_grpo.json`` and the replay buffer is dumped to
            ``planner_replay.jsonl`` and ``planner_failures.json`` inside it.

    Returns:
        The dict produced by ``TrainingMetrics.summary()`` with an extra
        ``"failure_mining"`` entry computed from the replay records.
    """
    env = PolyGuardEnv()
    orchestrator = Orchestrator(env=env)
    metrics = TrainingMetrics()
    replay = ReplayBuffer()

    # Episodes at or past this index run on the harder setting.
    hard_from = episodes // 2

    for episode_idx in range(episodes):
        difficulty = "hard" if episode_idx >= hard_from else "medium"
        env.reset(seed=101 + episode_idx, difficulty=difficulty)

        while True:
            # Snapshot the burden before acting so the per-step change can be reported.
            burden_before = env.state.burden_score
            step_result = orchestrator.run_step()

            reward = step_result["reward"]
            done = step_result["done"]

            critic = step_result["critic"]
            legal = critic["legal"]
            # More than one critic violation on a step is counted as severe.
            severe = len(critic["violations"]) > 1
            abstain = step_result["final_action"]["action_type"].startswith("REQUEST_")

            info = step_result["info"]
            reward_components = info.get("reward_breakdown", {})
            primary_channels = info.get("primary_reward_channels", {})
            failure_reasons = info.get("failure_reasons", [])
            termination_reason = info.get("termination_reason")
            step_count = env.state.step_count

            succeeded = done and termination_reason == "safe_resolution"
            timed_out = bool(
                info.get("step_timeout") or termination_reason == "wall_clock_timeout"
            )
            # Only forward the channels to the metrics when they really are a mapping.
            channels_for_metrics = (
                primary_channels if isinstance(primary_channels, dict) else None
            )

            metrics.add(
                reward,
                legal=legal,
                severe_violation=severe,
                abstain=abstain,
                episode_len=step_count,
                reward_components=reward_components,
                success=succeeded,
                burden_delta=burden_before - env.state.burden_score,
                safety_delta=float(reward_components.get("safety_delta_score", 0.0)),
                dosing_quality=float(reward_components.get("dosing_quality_score", 0.0)),
                process_fidelity=float(reward_components.get("process_fidelity_score", 0.0)),
                exploit_detected=bool(info.get("anti_cheat_reasons")),
                timeout=timed_out,
                failure_visible=bool(failure_reasons),
                invalid_actions=int(info.get("invalid_action_count", 0)),
                primary_channels=channels_for_metrics,
            )

            # The replay record keeps the raw channels value, whatever its type.
            replay.add(
                {
                    "episode": episode_idx,
                    "step": step_count,
                    "reward": reward,
                    "legal": legal,
                    "termination_reason": termination_reason,
                    "failure_reasons": failure_reasons,
                    "policy_stack": step_result.get("policy_stack"),
                    "bandit_topk": step_result.get("bandit_topk", []),
                    "final_action": step_result.get("final_action", {}),
                    "primary_reward_channels": primary_channels,
                }
            )

            if done:
                break

    summary = metrics.summary()
    summary["failure_mining"] = failure_mining_summary(replay.records)

    if not checkpoint_dir:
        return summary

    save_checkpoint(checkpoint_dir / "planner_grpo.json", summary)
    replay.dump_jsonl(checkpoint_dir / "planner_replay.jsonl")
    replay.dump_failures_json(checkpoint_dir / "planner_failures.json")
    return summary