"""GRPO-style experiments with policy-stack ablations.""" from __future__ import annotations import os from pathlib import Path from typing import Any from app.agents.orchestrator import Orchestrator from app.env.env_core import PolyGuardEnv from app.training.metrics import TrainingMetrics from app.training.replay_buffer import ReplayBuffer, failure_mining_summary def run_policy_stack_rollout( policy_stack: str, episodes: int, checkpoint_dir: Path | None = None, seed_offset: int = 1_000, ) -> dict[str, Any]: previous = os.getenv("POLYGUARD_POLICY_STACK") os.environ["POLYGUARD_POLICY_STACK"] = policy_stack env = PolyGuardEnv() orchestrator = Orchestrator(env=env) metrics = TrainingMetrics() replay = ReplayBuffer() # Start small (easy/medium) before introducing harder environments. schedule = ["easy", "medium", "medium", "hard"] for i in range(episodes): difficulty = schedule[min(len(schedule) - 1, (i * len(schedule)) // max(1, episodes))] env.reset(seed=seed_offset + i, difficulty=difficulty) done = False while not done: out = orchestrator.run_step() done = bool(out.get("done", False)) info = out.get("info", {}) reward_components = info.get("reward_breakdown", {}) if isinstance(info, dict) else {} primary_channels = info.get("primary_reward_channels", {}) if isinstance(info, dict) else {} failure_reasons = info.get("failure_reasons", []) if isinstance(info, dict) else [] metrics.add( float(out.get("reward", 0.5)), legal=bool(out.get("critic", {}).get("legal", False)), severe_violation=len(out.get("critic", {}).get("violations", [])) > 1, abstain=str(out.get("final_action", {}).get("action_type", "")).startswith("REQUEST_"), episode_len=env.state.step_count, reward_components=reward_components if isinstance(reward_components, dict) else None, success=done and info.get("termination_reason") == "safe_resolution", burden_delta=0.0, safety_delta=float((reward_components or {}).get("safety_delta_score", 0.0)), dosing_quality=float((reward_components or {}).get("dosing_quality_score", 0.0)), process_fidelity=float((reward_components or {}).get("process_fidelity_score", 0.0)), exploit_detected=bool(info.get("anti_cheat_reasons")), timeout=bool(info.get("step_timeout") or info.get("termination_reason") == "wall_clock_timeout"), failure_visible=bool(failure_reasons), invalid_actions=int(info.get("invalid_action_count", 0)), primary_channels=primary_channels if isinstance(primary_channels, dict) else None, ) replay.add( { "policy_stack": policy_stack, "episode": i, "step": env.state.step_count, "reward": out.get("reward", 0.5), "final_action": out.get("final_action", {}), "termination_reason": info.get("termination_reason"), "failure_reasons": failure_reasons, "primary_reward_channels": primary_channels, } ) summary = metrics.summary() summary["policy_stack"] = policy_stack summary["failure_mining"] = failure_mining_summary(replay.records) if checkpoint_dir is not None: checkpoint_dir.mkdir(parents=True, exist_ok=True) replay.dump_jsonl(checkpoint_dir / f"{policy_stack.replace('+', '_')}_replay.jsonl") replay.dump_failures_json(checkpoint_dir / f"{policy_stack.replace('+', '_')}_failures.json") if previous is None: os.environ.pop("POLYGUARD_POLICY_STACK", None) else: os.environ["POLYGUARD_POLICY_STACK"] = previous return summary def probe_trl_grpo_support() -> dict[str, Any]: try: from trl import GRPOTrainer # noqa: F401 return {"available": True, "backend": "trl", "note": "GRPOTrainer import successful."} except Exception as exc: # noqa: BLE001 return {"available": False, "backend": "trl", "note": 
f"GRPOTrainer unavailable: {exc}"}