"""Supervisor GRPO-like trainer.""" from __future__ import annotations from pathlib import Path from app.env.env_core import PolyGuardEnv from app.training.checkpointing import save_checkpoint from app.training.metrics import TrainingMetrics from app.training.replay_buffer import ReplayBuffer, failure_mining_summary def train_supervisor_grpo(episodes: int = 10, checkpoint_dir: Path | None = None) -> dict: env = PolyGuardEnv() metrics = TrainingMetrics() replay = ReplayBuffer() for i in range(episodes): env.reset(seed=42 + i, difficulty="easy" if i < episodes // 2 else "medium") done = False while not done: candidates = env.get_legal_actions() action = candidates[0] pre_burden = env.state.burden_score obs, reward, done, info = env.step(action) legal = info["safety_report"]["legal"] severe = len(info["safety_report"]["violations"]) > 1 abstain = action["action_type"].startswith("REQUEST_") reward_components = info.get("reward_breakdown", {}) primary_channels = info.get("primary_reward_channels", {}) failure_reasons = info.get("failure_reasons", []) metrics.add( reward, legal=legal, severe_violation=severe, abstain=abstain, episode_len=env.state.step_count, reward_components=reward_components, success=done and info.get("termination_reason") == "safe_resolution", burden_delta=pre_burden - env.state.burden_score, safety_delta=float(reward_components.get("safety_delta_score", 0.0)), dosing_quality=float(reward_components.get("dosing_quality_score", 0.0)), process_fidelity=float(reward_components.get("process_fidelity_score", 0.0)), exploit_detected=bool(info.get("anti_cheat_reasons")), timeout=bool(info.get("step_timeout") or info.get("termination_reason") == "wall_clock_timeout"), failure_visible=bool(failure_reasons), invalid_actions=int(info.get("invalid_action_count", 0)), primary_channels=primary_channels if isinstance(primary_channels, dict) else None, ) replay.add( { "episode": i, "step": env.state.step_count, "reward": reward, "legal": legal, "termination_reason": info.get("termination_reason"), "failure_reasons": failure_reasons, "final_action": action, "primary_reward_channels": primary_channels, } ) summary = metrics.summary() summary["failure_mining"] = failure_mining_summary(replay.records) if checkpoint_dir: save_checkpoint(checkpoint_dir / "supervisor_grpo.json", summary) replay.dump_jsonl(checkpoint_dir / "supervisor_replay.jsonl") replay.dump_failures_json(checkpoint_dir / "supervisor_failures.json") return summary