"""Planner GRPO-like trainer.""" from __future__ import annotations from pathlib import Path from app.agents.orchestrator import Orchestrator from app.env.env_core import PolyGuardEnv from app.training.checkpointing import save_checkpoint from app.training.metrics import TrainingMetrics from app.training.replay_buffer import ReplayBuffer, failure_mining_summary def train_planner_grpo(episodes: int = 20, checkpoint_dir: Path | None = None) -> dict: env = PolyGuardEnv() orchestrator = Orchestrator(env=env) metrics = TrainingMetrics() replay = ReplayBuffer() for i in range(episodes): env.reset(seed=101 + i, difficulty="medium" if i < episodes // 2 else "hard") done = False while not done: pre_burden = env.state.burden_score result = orchestrator.run_step() reward = result["reward"] done = result["done"] legal = result["critic"]["legal"] severe = len(result["critic"]["violations"]) > 1 abstain = result["final_action"]["action_type"].startswith("REQUEST_") reward_components = result["info"].get("reward_breakdown", {}) primary_channels = result["info"].get("primary_reward_channels", {}) failure_reasons = result["info"].get("failure_reasons", []) metrics.add( reward, legal=legal, severe_violation=severe, abstain=abstain, episode_len=env.state.step_count, reward_components=reward_components, success=done and result["info"].get("termination_reason") == "safe_resolution", burden_delta=pre_burden - env.state.burden_score, safety_delta=float(reward_components.get("safety_delta_score", 0.0)), dosing_quality=float(reward_components.get("dosing_quality_score", 0.0)), process_fidelity=float(reward_components.get("process_fidelity_score", 0.0)), exploit_detected=bool(result["info"].get("anti_cheat_reasons")), timeout=bool(result["info"].get("step_timeout") or result["info"].get("termination_reason") == "wall_clock_timeout"), failure_visible=bool(failure_reasons), invalid_actions=int(result["info"].get("invalid_action_count", 0)), primary_channels=primary_channels if isinstance(primary_channels, dict) else None, ) replay.add( { "episode": i, "step": env.state.step_count, "reward": reward, "legal": legal, "termination_reason": result["info"].get("termination_reason"), "failure_reasons": failure_reasons, "policy_stack": result.get("policy_stack"), "bandit_topk": result.get("bandit_topk", []), "final_action": result.get("final_action", {}), "primary_reward_channels": primary_channels, } ) summary = metrics.summary() summary["failure_mining"] = failure_mining_summary(replay.records) if checkpoint_dir: save_checkpoint(checkpoint_dir / "planner_grpo.json", summary) replay.dump_jsonl(checkpoint_dir / "planner_replay.jsonl") replay.dump_failures_json(checkpoint_dir / "planner_failures.json") return summary