# polyguard-openenv/app/training/grpo_experiment.py
"""GRPO-style experiments with policy-stack ablations."""
from __future__ import annotations

import os
from pathlib import Path
from typing import Any

from app.agents.orchestrator import Orchestrator
from app.env.env_core import PolyGuardEnv
from app.training.metrics import TrainingMetrics
from app.training.replay_buffer import ReplayBuffer, failure_mining_summary


def run_policy_stack_rollout(
    policy_stack: str,
    episodes: int,
    checkpoint_dir: Path | None = None,
    seed_offset: int = 1_000,
) -> dict[str, Any]:
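    """Run `episodes` rollouts under a single policy stack and return a metrics summary.

    The stack is selected by temporarily overriding the POLYGUARD_POLICY_STACK
    environment variable; the previous value (if any) is restored before returning.
    """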
    previous = os.getenv("POLYGUARD_POLICY_STACK")
    os.environ["POLYGUARD_POLICY_STACK"] = policy_stack
    env = PolyGuardEnv()
    orchestrator = Orchestrator(env=env)
    metrics = TrainingMetrics()
    replay = ReplayBuffer()
    # Start small (easy/medium) before introducing harder environments.
    schedule = ["easy", "medium", "medium", "hard"]
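    # Episode indices are mapped proportionally onto this schedule, so roughly the
    # first quarter of episodes runs "easy", the middle half "medium", and the
    # final quarter "hard".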
    for i in range(episodes):
        difficulty = schedule[min(len(schedule) - 1, (i * len(schedule)) // max(1, episodes))]
        env.reset(seed=seed_offset + i, difficulty=difficulty)
        done = False
        while not done:
            out = orchestrator.run_step()
            done = bool(out.get("done", False))
        # Normalise the final step's info payload once so the lookups below stay
        # safe even if the orchestrator returns a non-dict value.
        info = out.get("info", {})
        if not isinstance(info, dict):
            info = {}
        reward_components = info.get("reward_breakdown", {})
        primary_channels = info.get("primary_reward_channels", {})
        failure_reasons = info.get("failure_reasons", [])
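        # Record episode-level metrics from the final step's reward breakdown and
        # critic output.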
        metrics.add(
            float(out.get("reward", 0.5)),
            legal=bool(out.get("critic", {}).get("legal", False)),
            severe_violation=len(out.get("critic", {}).get("violations", [])) > 1,
            abstain=str(out.get("final_action", {}).get("action_type", "")).startswith("REQUEST_"),
            episode_len=env.state.step_count,
            reward_components=reward_components if isinstance(reward_components, dict) else None,
            success=done and info.get("termination_reason") == "safe_resolution",
            burden_delta=0.0,
            safety_delta=float((reward_components or {}).get("safety_delta_score", 0.0)),
            dosing_quality=float((reward_components or {}).get("dosing_quality_score", 0.0)),
            process_fidelity=float((reward_components or {}).get("process_fidelity_score", 0.0)),
            exploit_detected=bool(info.get("anti_cheat_reasons")),
            timeout=bool(info.get("step_timeout") or info.get("termination_reason") == "wall_clock_timeout"),
            failure_visible=bool(failure_reasons),
            invalid_actions=int(info.get("invalid_action_count", 0)),
            primary_channels=primary_channels if isinstance(primary_channels, dict) else None,
        )
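        # Keep one replay record per episode for checkpoint dumps and failure mining.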
        replay.add(
            {
                "policy_stack": policy_stack,
                "episode": i,
                "step": env.state.step_count,
                "reward": out.get("reward", 0.5),
                "final_action": out.get("final_action", {}),
                "termination_reason": info.get("termination_reason"),
                "failure_reasons": failure_reasons,
                "primary_reward_channels": primary_channels,
            }
        )
    summary = metrics.summary()
    summary["policy_stack"] = policy_stack
    summary["failure_mining"] = failure_mining_summary(replay.records)
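    # Optionally persist the replay buffer and mined failures for offline analysis.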
    if checkpoint_dir is not None:
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        replay.dump_jsonl(checkpoint_dir / f"{policy_stack.replace('+', '_')}_replay.jsonl")
        replay.dump_failures_json(checkpoint_dir / f"{policy_stack.replace('+', '_')}_failures.json")
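    # Restore the previous policy-stack setting (or clear it if it was unset).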
    if previous is None:
        os.environ.pop("POLYGUARD_POLICY_STACK", None)
    else:
        os.environ["POLYGUARD_POLICY_STACK"] = previous
    return summary


def probe_trl_grpo_support() -> dict[str, Any]:
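    """Report whether TRL's GRPOTrainer can be imported in the current environment."""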
    try:
        from trl import GRPOTrainer  # noqa: F401
        return {"available": True, "backend": "trl", "note": "GRPOTrainer import successful."}
    except Exception as exc:  # noqa: BLE001
        return {"available": False, "backend": "trl", "note": f"GRPOTrainer unavailable: {exc}"}
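

if __name__ == "__main__":
    # Minimal ablation-driver sketch. The policy-stack names and the checkpoint
    # directory below are illustrative assumptions, not values required by the
    # environment or the Orchestrator; substitute the stacks you actually ablate.
    for stack in ("baseline", "constitution+critic"):
        result = run_policy_stack_rollout(
            stack,
            episodes=4,
            checkpoint_dir=Path("checkpoints") / "ablations",
        )
        print(stack, result.get("failure_mining"))
    print(probe_trl_grpo_support())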