Spaces:
Running
Running
| """Supervisor GRPO-like trainer.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from app.env.env_core import PolyGuardEnv | |
| from app.training.checkpointing import save_checkpoint | |
| from app.training.metrics import TrainingMetrics | |
| from app.training.replay_buffer import ReplayBuffer, failure_mining_summary | |
def train_supervisor_grpo(episodes: int = 10, checkpoint_dir: Path | None = None) -> dict:
    """Roll out a first-legal-action policy and summarise the collected metrics.

    For each episode the environment is reset with seed ``42 + episode index``;
    the first ``episodes // 2`` episodes use difficulty ``"easy"`` and the rest
    use ``"medium"``. On every step the first entry returned by
    ``env.get_legal_actions()`` is executed, the step outcome is added to a
    ``TrainingMetrics`` instance, and a record is appended to a ``ReplayBuffer``.

    Args:
        episodes: Number of episodes to run.
        checkpoint_dir: When truthy, the summary, the replay records and the
            failure records are written beneath this directory.

    Returns:
        ``metrics.summary()`` with an extra ``"failure_mining"`` entry computed
        from the replay records.
    """
    environment = PolyGuardEnv()
    tracker = TrainingMetrics()
    buffer = ReplayBuffer()
    medium_from = episodes // 2

    for episode_idx in range(episodes):
        level = "easy" if episode_idx < medium_from else "medium"
        environment.reset(seed=42 + episode_idx, difficulty=level)

        done = False
        while not done:
            # NOTE(review): indexing assumes at least one legal action is always
            # offered while the episode is running; an empty list would raise
            # IndexError here -- confirm against the environment contract.
            action = environment.get_legal_actions()[0]
            burden_before = environment.state.burden_score
            _obs, reward, done, info = environment.step(action)

            report = info["safety_report"]
            is_legal = report["legal"]
            # "Severe" here means more than one violation was reported for the step.
            is_severe = len(report["violations"]) > 1
            is_abstain = action["action_type"].startswith("REQUEST_")

            breakdown = info.get("reward_breakdown", {})
            channels = info.get("primary_reward_channels", {})
            failures = info.get("failure_reasons", [])
            termination = info.get("termination_reason")
            steps_taken = environment.state.step_count

            step_metrics = {
                "legal": is_legal,
                "severe_violation": is_severe,
                "abstain": is_abstain,
                "episode_len": steps_taken,
                "reward_components": breakdown,
                "success": done and termination == "safe_resolution",
                # Positive when the burden score dropped during this step.
                "burden_delta": burden_before - environment.state.burden_score,
                "safety_delta": float(breakdown.get("safety_delta_score", 0.0)),
                "dosing_quality": float(breakdown.get("dosing_quality_score", 0.0)),
                "process_fidelity": float(breakdown.get("process_fidelity_score", 0.0)),
                "exploit_detected": bool(info.get("anti_cheat_reasons")),
                "timeout": bool(
                    info.get("step_timeout") or termination == "wall_clock_timeout"
                ),
                "failure_visible": bool(failures),
                "invalid_actions": int(info.get("invalid_action_count", 0)),
                # Only forward the channels when they really are a mapping.
                "primary_channels": channels if isinstance(channels, dict) else None,
            }
            tracker.add(reward, **step_metrics)

            buffer.add(
                {
                    "episode": episode_idx,
                    "step": steps_taken,
                    "reward": reward,
                    "legal": is_legal,
                    "termination_reason": termination,
                    "failure_reasons": failures,
                    "final_action": action,
                    "primary_reward_channels": channels,
                }
            )

    summary = tracker.summary()
    summary["failure_mining"] = failure_mining_summary(buffer.records)

    if not checkpoint_dir:
        return summary

    save_checkpoint(checkpoint_dir / "supervisor_grpo.json", summary)
    buffer.dump_jsonl(checkpoint_dir / "supervisor_replay.jsonl")
    buffer.dump_failures_json(checkpoint_dir / "supervisor_failures.json")
    return summary