Spaces:
Running
Running
| """Planner GRPO-like trainer.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| from app.agents.orchestrator import Orchestrator | |
| from app.env.env_core import PolyGuardEnv | |
| from app.training.checkpointing import save_checkpoint | |
| from app.training.metrics import TrainingMetrics | |
| from app.training.replay_buffer import ReplayBuffer, failure_mining_summary | |
def train_planner_grpo(episodes: int = 20, checkpoint_dir: Path | None = None) -> dict:
    """Run a GRPO-style training loop for the planner and return a metrics summary.

    Args:
        episodes: Number of episodes to roll out. The first half runs at
            ``"medium"`` difficulty, the second half at ``"hard"``.
        checkpoint_dir: When provided, the summary, full replay, and mined
            failures are written into this directory.

    Returns:
        The ``TrainingMetrics`` summary dict, augmented with a
        ``"failure_mining"`` entry derived from the replay buffer.
    """
    environment = PolyGuardEnv()
    planner = Orchestrator(env=environment)
    tracker = TrainingMetrics()
    buffer = ReplayBuffer()

    for episode_idx in range(episodes):
        # Simple curriculum: ramp from medium to hard halfway through training.
        if episode_idx < episodes // 2:
            difficulty = "medium"
        else:
            difficulty = "hard"
        environment.reset(seed=101 + episode_idx, difficulty=difficulty)

        finished = False
        while not finished:
            burden_before = environment.state.burden_score
            step_result = planner.run_step()
            info = step_result["info"]

            step_reward = step_result["reward"]
            finished = step_result["done"]
            is_legal = step_result["critic"]["legal"]
            # More than one critic violation on a step is treated as severe.
            is_severe = len(step_result["critic"]["violations"]) > 1
            did_abstain = step_result["final_action"]["action_type"].startswith("REQUEST_")
            reward_parts = info.get("reward_breakdown", {})
            channels = info.get("primary_reward_channels", {})
            failures = info.get("failure_reasons", [])

            tracker.add(
                step_reward,
                legal=is_legal,
                severe_violation=is_severe,
                abstain=did_abstain,
                episode_len=environment.state.step_count,
                reward_components=reward_parts,
                success=finished and info.get("termination_reason") == "safe_resolution",
                burden_delta=burden_before - environment.state.burden_score,
                safety_delta=float(reward_parts.get("safety_delta_score", 0.0)),
                dosing_quality=float(reward_parts.get("dosing_quality_score", 0.0)),
                process_fidelity=float(reward_parts.get("process_fidelity_score", 0.0)),
                exploit_detected=bool(info.get("anti_cheat_reasons")),
                timeout=bool(
                    info.get("step_timeout")
                    or info.get("termination_reason") == "wall_clock_timeout"
                ),
                failure_visible=bool(failures),
                invalid_actions=int(info.get("invalid_action_count", 0)),
                primary_channels=channels if isinstance(channels, dict) else None,
            )

            record = {
                "episode": episode_idx,
                "step": environment.state.step_count,
                "reward": step_reward,
                "legal": is_legal,
                "termination_reason": info.get("termination_reason"),
                "failure_reasons": failures,
                "policy_stack": step_result.get("policy_stack"),
                "bandit_topk": step_result.get("bandit_topk", []),
                "final_action": step_result.get("final_action", {}),
                "primary_reward_channels": channels,
            }
            buffer.add(record)

    summary = tracker.summary()
    summary["failure_mining"] = failure_mining_summary(buffer.records)

    if checkpoint_dir:
        # Persist the summary plus raw/failure replays for offline analysis.
        save_checkpoint(checkpoint_dir / "planner_grpo.json", summary)
        buffer.dump_jsonl(checkpoint_dir / "planner_replay.jsonl")
        buffer.dump_failures_json(checkpoint_dir / "planner_failures.json")

    return summary