polyguard-openenv / app /training /supervisor_grpo.py
TheJackBright's picture
Deploy PolyGuard OpenEnv Space
877add7 verified
"""Supervisor GRPO-like trainer."""
from __future__ import annotations
from pathlib import Path
from app.env.env_core import PolyGuardEnv
from app.training.checkpointing import save_checkpoint
from app.training.metrics import TrainingMetrics
from app.training.replay_buffer import ReplayBuffer, failure_mining_summary
def train_supervisor_grpo(episodes: int = 10, checkpoint_dir: str | Path | None = None) -> dict:
    """Roll out the supervisor over a fixed easy/medium curriculum and summarise it.

    Episode ``i`` is reset with ``seed=42 + i``. The first ``episodes // 2``
    episodes use ``difficulty="easy"`` and the remainder use ``"medium"``.
    At every step the first action returned by ``env.get_legal_actions()`` is
    taken. No policy is sampled or updated here: the loop collects per-step
    metrics and one replay record per episode.

    Args:
        episodes: Number of episodes to roll out.
        checkpoint_dir: Optional output directory (``Path`` or ``str``). When
            given, the summary is written to ``supervisor_grpo.json``. The
            replay buffer is written to ``supervisor_replay.jsonl``, and its
            failures to ``supervisor_failures.json``.

    Returns:
        The dict from ``metrics.summary()`` with an extra ``"failure_mining"``
        entry computed by ``failure_mining_summary`` over the replay records.

    Raises:
        RuntimeError: If the environment reports no legal actions while the
            episode has not yet terminated.
    """
    env = PolyGuardEnv()
    metrics = TrainingMetrics()
    replay = ReplayBuffer()

    for i in range(episodes):
        # Simple curriculum: first half of the run easy, second half medium.
        difficulty = "easy" if i < episodes // 2 else "medium"
        env.reset(seed=42 + i, difficulty=difficulty)
        done = False
        while not done:
            candidates = env.get_legal_actions()
            if not candidates:
                # Fail loudly instead of an opaque IndexError on candidates[0].
                raise RuntimeError(
                    f"no legal actions available in episode {i} "
                    f"at step {env.state.step_count}"
                )
            action = candidates[0]
            # Captured before stepping so the per-step burden change can be reported.
            pre_burden = env.state.burden_score
            _obs, reward, done, info = env.step(action)

            safety_report = info["safety_report"]
            legal = safety_report["legal"]
            # NOTE(review): a step counts as "severe" only with MORE than one
            # violation; confirm this threshold is intended rather than >= 1.
            severe = len(safety_report["violations"]) > 1
            abstain = action["action_type"].startswith("REQUEST_")
            # `or {}` also covers a key that is present but set to None.
            reward_components = info.get("reward_breakdown") or {}
            primary_channels = info.get("primary_reward_channels", {})
            failure_reasons = info.get("failure_reasons", [])

            metrics.add(
                reward,
                legal=legal,
                severe_violation=severe,
                abstain=abstain,
                episode_len=env.state.step_count,
                reward_components=reward_components,
                success=done and info.get("termination_reason") == "safe_resolution",
                burden_delta=pre_burden - env.state.burden_score,
                safety_delta=float(reward_components.get("safety_delta_score", 0.0)),
                dosing_quality=float(reward_components.get("dosing_quality_score", 0.0)),
                process_fidelity=float(reward_components.get("process_fidelity_score", 0.0)),
                exploit_detected=bool(info.get("anti_cheat_reasons")),
                timeout=bool(
                    info.get("step_timeout")
                    or info.get("termination_reason") == "wall_clock_timeout"
                ),
                failure_visible=bool(failure_reasons),
                invalid_actions=int(info.get("invalid_action_count", 0)),
                primary_channels=primary_channels if isinstance(primary_channels, dict) else None,
            )

        # One replay record per episode, describing its final step. The while
        # loop always runs at least once, so every name below is bound.
        replay.add(
            {
                "episode": i,
                "step": env.state.step_count,
                "reward": reward,
                "legal": legal,
                "termination_reason": info.get("termination_reason"),
                "failure_reasons": failure_reasons,
                "final_action": action,
                "primary_reward_channels": primary_channels,
            }
        )

    summary = metrics.summary()
    summary["failure_mining"] = failure_mining_summary(replay.records)

    if checkpoint_dir is not None:
        # Accept plain strings as well as Path objects.
        out_dir = Path(checkpoint_dir)
        save_checkpoint(out_dir / "supervisor_grpo.json", summary)
        replay.dump_jsonl(out_dir / "supervisor_replay.jsonl")
        replay.dump_failures_json(out_dir / "supervisor_failures.json")
    return summary