Spaces:
Running
Running
| """GRPO-style experiments with policy-stack ablations.""" | |
| from __future__ import annotations | |
| import os | |
| from pathlib import Path | |
| from typing import Any | |
| from app.agents.orchestrator import Orchestrator | |
| from app.env.env_core import PolyGuardEnv | |
| from app.training.metrics import TrainingMetrics | |
| from app.training.replay_buffer import ReplayBuffer, failure_mining_summary | |
def run_policy_stack_rollout(
    policy_stack: str,
    episodes: int,
    checkpoint_dir: Path | None = None,
    seed_offset: int = 1_000,
) -> dict[str, Any]:
    """Roll out *episodes* under one policy stack and return summary metrics.

    Temporarily sets the ``POLYGUARD_POLICY_STACK`` environment variable to
    *policy_stack* for the duration of the rollout and restores the previous
    value afterwards — even if the rollout raises (the original code leaked
    the mutated value on any exception).

    Args:
        policy_stack: Identifier of the policy stack under ablation.
        episodes: Number of episodes to run.
        checkpoint_dir: If given, replay/failure artifacts are dumped here
            (created if missing).
        seed_offset: Base seed; episode ``i`` resets the env with
            ``seed_offset + i``.

    Returns:
        The ``TrainingMetrics`` summary dict, augmented with the policy
        stack name and a failure-mining summary over the replay records.
    """
    previous = os.getenv("POLYGUARD_POLICY_STACK")
    os.environ["POLYGUARD_POLICY_STACK"] = policy_stack
    try:
        env = PolyGuardEnv()
        orchestrator = Orchestrator(env=env)
        metrics = TrainingMetrics()
        replay = ReplayBuffer()
        # Curriculum: start small (easy/medium) before introducing harder environments.
        schedule = ["easy", "medium", "medium", "hard"]
        for i in range(episodes):
            # Map episode index proportionally onto the schedule, clamped to its last entry.
            difficulty = schedule[min(len(schedule) - 1, (i * len(schedule)) // max(1, episodes))]
            env.reset(seed=seed_offset + i, difficulty=difficulty)
            done = False
            while not done:
                out = orchestrator.run_step()
                done = bool(out.get("done", False))
                info = out.get("info", {})
                if not isinstance(info, dict):
                    # Original code guarded only some accesses; normalize once so
                    # every info.get below is safe.
                    info = {}
                reward_components = info.get("reward_breakdown", {})
                primary_channels = info.get("primary_reward_channels", {})
                failure_reasons = info.get("failure_reasons", [])
                critic = out.get("critic", {})  # hoisted: was looked up twice per step
                metrics.add(
                    float(out.get("reward", 0.5)),
                    legal=bool(critic.get("legal", False)),
                    # "severe" means more than one distinct violation this step.
                    severe_violation=len(critic.get("violations", [])) > 1,
                    # REQUEST_* action types are treated as abstentions.
                    abstain=str(out.get("final_action", {}).get("action_type", "")).startswith("REQUEST_"),
                    episode_len=env.state.step_count,
                    reward_components=reward_components if isinstance(reward_components, dict) else None,
                    success=done and info.get("termination_reason") == "safe_resolution",
                    burden_delta=0.0,
                    safety_delta=float((reward_components or {}).get("safety_delta_score", 0.0)),
                    dosing_quality=float((reward_components or {}).get("dosing_quality_score", 0.0)),
                    process_fidelity=float((reward_components or {}).get("process_fidelity_score", 0.0)),
                    exploit_detected=bool(info.get("anti_cheat_reasons")),
                    timeout=bool(info.get("step_timeout") or info.get("termination_reason") == "wall_clock_timeout"),
                    failure_visible=bool(failure_reasons),
                    invalid_actions=int(info.get("invalid_action_count", 0)),
                    primary_channels=primary_channels if isinstance(primary_channels, dict) else None,
                )
                replay.add(
                    {
                        "policy_stack": policy_stack,
                        "episode": i,
                        "step": env.state.step_count,
                        "reward": out.get("reward", 0.5),
                        "final_action": out.get("final_action", {}),
                        "termination_reason": info.get("termination_reason"),
                        "failure_reasons": failure_reasons,
                        "primary_reward_channels": primary_channels,
                    }
                )
        summary = metrics.summary()
        summary["policy_stack"] = policy_stack
        summary["failure_mining"] = failure_mining_summary(replay.records)
        if checkpoint_dir is not None:
            checkpoint_dir.mkdir(parents=True, exist_ok=True)
            # "+" is not filesystem-friendly in all contexts; hoisted the stem
            # computation that was duplicated across both dump calls.
            stem = policy_stack.replace("+", "_")
            replay.dump_jsonl(checkpoint_dir / f"{stem}_replay.jsonl")
            replay.dump_failures_json(checkpoint_dir / f"{stem}_failures.json")
        return summary
    finally:
        # Always restore the caller's environment, even when the rollout raises.
        if previous is None:
            os.environ.pop("POLYGUARD_POLICY_STACK", None)
        else:
            os.environ["POLYGUARD_POLICY_STACK"] = previous
def probe_trl_grpo_support() -> dict[str, Any]:
    """Report whether the TRL GRPO backend is importable.

    Returns:
        A dict with ``available`` (bool), ``backend`` (always ``"trl"``),
        and a human-readable ``note`` describing the probe outcome.
    """
    result: dict[str, Any] = {"available": False, "backend": "trl", "note": ""}
    try:
        from trl import GRPOTrainer  # noqa: F401
    except Exception as exc:  # noqa: BLE001
        result["note"] = f"GRPOTrainer unavailable: {exc}"
    else:
        result["available"] = True
        result["note"] = "GRPOTrainer import successful."
    return result