Spaces:
Sleeping
Sleeping
feat: enhance CyberSecurity_OWASP observation model with scenario prompt, improve GRPO batch configuration validation, and add scenario grouping for adaptive difficulty curriculum
632c145 | """CyberSecurity_OWASP OpenEnv environment implementation.""" | |
| from __future__ import annotations | |
| import json | |
| import shutil | |
| import time | |
| from dataclasses import asdict | |
| from typing import Any | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| try: | |
| from ..config import load_scenario_authoring_config | |
| from ..models import ( | |
| CyberSecurityOWASPAction, | |
| CyberSecurityOWASPObservation, | |
| CyberSecurityOWASPState, | |
| ) | |
| from ..validators import detect_cheating | |
| from .action_tools import ActionTools | |
| from .curriculum import CurriculumController | |
| from .episode_logger import EpisodeArtifactLogger | |
| from .reward_engine import evaluate_action | |
| from ..rewards import should_terminate_for_flags | |
| from .scenario_cache import ScenarioCache, ScenarioCacheMiss, cache_key_for_scenario | |
| from .scenario_factory import ScenarioFactory | |
| except ImportError: # pragma: no cover | |
| from config import load_scenario_authoring_config | |
| from models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState | |
| from validators import detect_cheating | |
| from server.action_tools import ActionTools | |
| from server.curriculum import CurriculumController | |
| from server.episode_logger import EpisodeArtifactLogger | |
| from server.reward_engine import evaluate_action | |
| from rewards import should_terminate_for_flags | |
| from server.scenario_cache import ScenarioCache, ScenarioCacheMiss, cache_key_for_scenario | |
| from server.scenario_factory import ScenarioFactory | |
# Phase-keyed tool whitelist. step() rejects any action whose tool_name is
# not in the set for the current phase; the terminal "done" phase accepts
# no tools at all.
ALLOWED_TOOLS = {
    "discover": set(
        "inspect_policy_graph list_routes read_openapi read_file search_code "
        "send_local_request compare_identities submit_diagnosis noop".split()
    ),
    "patch": set(
        "read_file search_code patch_file run_visible_tests "
        "send_local_request submit_fix noop".split()
    ),
    "done": set(),
}
class CybersecurityOwaspEnvironment(
    Environment[CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState]
):
    """Single-agent defensive authorization-repair environment.

    Episodes progress through phases "discover" -> "patch" -> "done".
    In the discover phase the agent inspects a generated workspace and
    submits a diagnosis; a valid diagnosis unlocks the patch phase, where
    the agent edits code and submits a fix. Hidden verifiers (security,
    oracle matrix, regression, public routes, patch quality) decide
    success on ``submit_fix``. Anti-cheat flags or the step budget can
    terminate an episode early.
    """

    # Server flag: this environment can host multiple sessions at once.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        super().__init__()
        # Placeholder state until reset() installs a real scenario.
        self._state = CyberSecurityOWASPState(episode_id=str(uuid4()))
        self._task_brief = ""
        self._visible_policy_hint: dict[str, Any] = {}
        self._workspace_summary: dict[str, Any] = {}
        # Terminal observation cached so step()-after-done replays it.
        self._last_done_observation: CyberSecurityOWASPObservation | None = None
        self._curriculum = CurriculumController()
        self._scenario_factory = ScenarioFactory()
        self._episode_logger = EpisodeArtifactLogger()

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        split: str = "train",
        difficulty: int = 0,
        family_budget: dict[str, Any] | None = None,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Start a new episode: load or compile a scenario and return the first observation.

        Args:
            seed: Scenario seed; ``None`` is coerced to 0.
            episode_id: Explicit episode id; a UUID4 is generated when omitted.
            split: Dataset split name ("train" by default).
            difficulty: Requested difficulty; the curriculum controller may
                override it via its selected profile.
            family_budget: Optional per-family budget forwarded to the
                scenario cache lookup.
        """
        reset_started = time.perf_counter()
        # Tear down the previous episode's on-disk workspace (if any)
        # before building the next scenario.
        self.close()
        settings = load_scenario_authoring_config()
        self._curriculum.settings = settings
        actual_seed = int(seed if seed is not None else 0)
        curriculum_profile = self._curriculum.select_profile(
            seed=actual_seed,
            split=split,
            requested_difficulty=difficulty,
        )
        scenario = self._load_or_compile_scenario(
            actual_seed,
            split=split,
            requested_difficulty=difficulty,
            curriculum_profile=curriculum_profile,
            family_budget=family_budget,
            settings=settings,
        )
        # Cache bookkeeping: key/hash/latency are surfaced in state metrics.
        cache_info = dict(scenario.get("cache", {}))
        cache_key = cache_info.get("cache_key", {})
        scenario_hash = str(cache_info.get("scenario_hash", ""))
        reset_latency_ms = (time.perf_counter() - reset_started) * 1000
        self._state = CyberSecurityOWASPState(
            episode_id=episode_id or str(uuid4()),
            task_id=scenario["task_id"],
            seed=actual_seed,
            split=split,
            difficulty=scenario["difficulty"],
            difficulty_tier=scenario["difficulty_tier"],
            domain=scenario["domain"],
            bug_family=scenario["bug_family"],
            scenario_family=scenario["scenario_family"],
            template_id=scenario["template_id"],
            target_weakness=scenario["target_weakness"],
            cache_key=cache_key,
            scenario_hash=scenario_hash,
            # Fall back to runtime config versions when the cache key omits them.
            generator_version=str(cache_key.get("generator_version", settings.runtime.generator_version)),
            verifier_version=str(cache_key.get("verifier_version", settings.runtime.verifier_version)),
            cache_hit=bool(cache_info.get("hit", False)),
            reset_latency_ms=reset_latency_ms,
            phase="discover",
            step_count=0,
            max_steps=40,  # hard per-episode step budget enforced in _finish_step()
            done=False,
            success=False,
            visible_facts={"workspace_summary": scenario["workspace_summary"]},
            hidden_facts=scenario["hidden_facts"],
            curriculum_snapshot=scenario["curriculum_snapshot"],
            metrics={
                "reset_count": 1,
                "reset_latency_ms": reset_latency_ms,
                "scenario_cache_hit": bool(cache_info.get("hit", False)),
                "scenario_cache_mode": settings.runtime.cache_mode,
                "scenario_cache_key": cache_key,
                "scenario_hash": scenario_hash,
                "scenario_bundle_load_latency_ms": float(cache_info.get("load_latency_ms", 0.0)),
                "scenario_compile_latency_ms": float(cache_info.get("compile_latency_ms", 0.0)),
                "scenario_cache_dir": settings.runtime.cache_dir,
            },
        )
        self._task_brief = scenario["task_brief"]
        self._visible_policy_hint = scenario["public_hint"]
        self._workspace_summary = scenario["workspace_summary"]
        self._last_done_observation = None
        return self._observation("Scenario ready. Start in discover phase.", reward=0.0)

    def _load_or_compile_scenario(
        self,
        seed: int,
        *,
        split: str,
        requested_difficulty: int,
        curriculum_profile: dict[str, Any],
        family_budget: dict[str, Any] | None,
        settings: Any,
    ) -> dict[str, Any]:
        """Return a scenario dict, preferring the on-disk cache over compilation.

        Cache behavior follows ``settings.runtime.cache_mode``:
        "disabled" skips the cache entirely; "require" raises on a miss
        (runtime resets must never compile); any other mode falls back to
        compiling via the scenario factory and attaches a "cache" record.
        """
        difficulty = int(curriculum_profile.get("difficulty", requested_difficulty))
        if settings.runtime.cache_mode != "disabled":
            cache = ScenarioCache(settings.runtime.cache_dir, settings=settings)
            try:
                cached = cache.load_bundle(
                    seed=seed,
                    split=split,
                    difficulty=difficulty,
                    family_budget=family_budget,
                )
                return cached.scenario
            except ScenarioCacheMiss as exc:
                if settings.runtime.cache_mode == "require":
                    raise RuntimeError(
                        "Scenario cache miss in required mode. Run cache prep before "
                        "training/eval; runtime reset must not compile scenarios. "
                        f"Details: {exc}"
                    ) from exc
        compile_started = time.perf_counter()
        # NOTE(review): the cache lookup above uses the curriculum-adjusted
        # `difficulty`, but compilation passes the raw `requested_difficulty`
        # (the profile is forwarded separately) — confirm this asymmetry is
        # intentional and that compiled bundles match their cache keys.
        scenario = self._scenario_factory.compile_scenario(
            seed,
            split=split,
            difficulty=requested_difficulty,
            curriculum_profile=curriculum_profile,
        )
        key = cache_key_for_scenario(scenario, settings=settings)
        scenario["cache"] = {
            "hit": False,
            "cache_key": asdict(key),
            "scenario_hash": key.scenario_hash,
            "compile_latency_ms": (time.perf_counter() - compile_started) * 1000,
        }
        return scenario

    def step(
        self,
        action: CyberSecurityOWASPAction,
        timeout_s: float | None = None,  # accepted for interface compatibility; unused here
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Execute one agent action and return the resulting observation.

        Replays the cached terminal observation if the episode is already
        done. Records anti-cheat flags and action history, rejects tools
        not whitelisted for the current phase, and converts any tool
        exception into an invalid-action result rather than a server crash.
        """
        if self._state.done:
            return self._last_done_observation or self._observation(
                "Episode is already done.", reward=0.0, done_reason=self._state.failure_reason
            )
        # Accumulate newly detected anti-cheat flags (deduplicated) on state.
        anti_cheat_flags = detect_cheating(self._state, action)
        for flag in anti_cheat_flags:
            if flag not in self._state.anti_cheat_flags:
                self._state.anti_cheat_flags.append(flag)
        self._state.step_count += 1
        self._state.action_history.append(
            {"tool_name": action.tool_name, "arguments": action.arguments}
        )
        # Phase gating: disallowed tools still consume a step and are scored
        # as invalid actions by the reward engine.
        if action.tool_name not in ALLOWED_TOOLS[self._state.phase]:
            verifier, reward = evaluate_action(
                self._state,
                action,
                anti_cheat_flags,
                invalid_action=True,
            )
            return self._finish_step(
                "Action is not allowed in the current phase.",
                reward,
                valid=False,
                error=f"{action.tool_name} is not allowed during {self._state.phase}",
                verifier=verifier,
            )
        try:
            result, verifier, reward, visible_tests = self._execute(action, anti_cheat_flags)
            return self._finish_step(
                result,
                reward,
                valid=True,
                verifier=verifier,
                visible_test_result=visible_tests,
            )
        except Exception as exc:  # keep malformed agent actions from crashing the server
            verifier, reward = evaluate_action(
                self._state,
                action,
                anti_cheat_flags,
                invalid_action=True,
            )
            return self._finish_step(
                "Tool execution failed.",
                reward,
                valid=False,
                error=str(exc),
                verifier=verifier,
            )

    def state(self) -> CyberSecurityOWASPState:
        """Return the current (mutable) episode state object."""
        return self._state

    def close(self) -> None:
        """Best-effort removal of the episode's on-disk workspace, if one exists."""
        workspace = self._state.hidden_facts.get("workspace")
        if workspace:
            shutil.rmtree(workspace, ignore_errors=True)

    def _execute(
        self, action: CyberSecurityOWASPAction, anti_cheat_flags: list[str]
    ) -> tuple[str, dict, dict[str, float], str | None]:
        """Dispatch a whitelisted tool and return (message, verifier, reward, visible_tests).

        Raises:
            ValueError: if ``action.tool_name`` matches no handler below
                (step() converts this into an invalid-action result).
        """
        # Read/inspect/patch tools are delegated to ActionTools; the reward
        # engine scores the action afterwards.
        if action.tool_name in {
            "noop",
            "inspect_policy_graph",
            "list_routes",
            "read_openapi",
            "read_file",
            "search_code",
            "send_local_request",
            "compare_identities",
            "patch_file",
        }:
            result = ActionTools(
                self._state,
                self._visible_policy_hint,
                self._workspace_summary,
            ).execute(action)
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            return result.message, verifier, reward, result.visible_test_result
        if action.tool_name == "submit_diagnosis":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.verification_summary = verifier
            self._state.diagnosis = dict(action.arguments or {})
            # A valid diagnosis transitions the episode into the patch phase.
            if verifier.get("diagnosis", {}).get("valid"):
                self._state.diagnosis_submitted = True
                self._state.finding_submitted = True
                self._state.phase = "patch"
                return "Diagnosis recorded. Patch phase unlocked.", verifier, reward, None
            return "Diagnosis was not specific enough to unlock patching.", verifier, reward, None
        if action.tool_name == "run_visible_tests":
            self._state.visible_test_count += 1
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.verification_summary = verifier
            # Only the "visible" slice of the verifier output is shown to the agent.
            visible_tests = json.dumps(verifier.get("visible", {}), indent=2, sort_keys=True)
            return visible_tests, verifier, reward, visible_tests
        if action.tool_name == "submit_fix":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.verification_summary = verifier
            self._state.patch_submitted = True
            # Success requires every hidden verifier dimension to pass.
            security = verifier.get("security", {}).get("passed", False)
            oracle = verifier.get("oracle_matrix", {}).get("passed", False)
            regression = verifier.get("regression", {}).get("passed", False)
            public = verifier.get("public_routes", {}).get("passed", False)
            quality = verifier.get("patch_quality", {}).get("passed", False)
            self._state.success = bool(security and oracle and regression and public and quality)
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = None if self._state.success else "hidden_verifier_failed"
            return json.dumps(verifier, indent=2, sort_keys=True), verifier, reward, None
        raise ValueError(f"Unhandled tool {action.tool_name}")

    def _finish_step(
        self,
        message: str,
        reward: dict[str, float],
        *,
        valid: bool,
        error: str | None = None,
        verifier: dict | None = None,
        visible_test_result: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Apply rewards, check termination conditions, and build the observation.

        Termination precedence: anti-cheat violations (from the verifier
        payload) are checked before the max-step budget; a flag-triggered
        termination therefore reports "anti_cheat_violation" rather than
        "max_steps_exceeded".
        """
        self._state.last_reward = float(reward.get("total", 0.0))
        self._state.accumulated_reward += self._state.last_reward
        self._state.reward_history.append(reward)
        # Flags are taken from the verifier payload, not from the running
        # state list; should_terminate_for_flags decides severity.
        flags = list((verifier or {}).get("anti_cheat_flags", []) or [])
        if flags and should_terminate_for_flags(flags):
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = "anti_cheat_violation"
        if self._state.step_count >= self._state.max_steps and not self._state.done:
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = "max_steps_exceeded"
        obs = self._observation(
            message,
            reward=self._state.last_reward,
            valid=valid,
            error=error,
            reward_breakdown=reward,
            visible_test_result=visible_test_result,
            done_reason=self._state.failure_reason,
        )
        observation_record = obs.model_dump()
        self._state.observation_history.append(observation_record)
        if self._state.done:
            # Log artifacts once and cache the terminal observation for replay.
            self._finalize_terminal_episode(observation_record)
            self._last_done_observation = obs
        return obs

    def _observation(
        self,
        message: str,
        *,
        reward: float,
        valid: bool = True,
        error: str | None = None,
        reward_breakdown: dict[str, float] | None = None,
        visible_test_result: str | None = None,
        done_reason: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Assemble the observation payload for the current state and message."""
        # Sorted for deterministic ordering in prompts and observations.
        available_actions = sorted(ALLOWED_TOOLS[self._state.phase])
        return CyberSecurityOWASPObservation(
            phase=self._state.phase,
            message=message,
            task_brief=self._task_brief,
            scenario_prompt=self._scenario_prompt(available_actions),
            visible_policy_hint=self._visible_policy_hint,
            workspace_summary=self._workspace_summary,
            available_actions=available_actions,
            last_tool_result=message,
            last_action_valid=valid,
            last_action_error=error,
            visible_test_result=visible_test_result,
            reward_breakdown=reward_breakdown or {},
            done_reason=done_reason,
            done=self._state.done,
            reward=reward,
            metadata={
                "episode_id": self._state.episode_id,
                "step_count": self._state.step_count,
                "seed": self._state.seed,
                "split": self._state.split,
                "difficulty": self._state.difficulty,
                "difficulty_tier": self._state.difficulty_tier,
                "domain": self._state.domain,
                "bug_family": self._state.bug_family,
                "template_id": self._state.template_id,
                "scenario_hash": self._state.scenario_hash,
            },
        )

    def _scenario_prompt(self, available_actions: list[str]) -> str:
        """Render the scenario as a deterministic JSON prompt string.

        Only alias *keys* of fixture users/resources are exposed; the
        underlying fixture values stay hidden from the agent.
        """
        users = self._visible_policy_hint.get("fixture_aliases", {}).get("users", {})
        resources = self._visible_policy_hint.get("fixture_aliases", {}).get("resources", {})
        visible_policy = {
            "domain": self._visible_policy_hint.get("domain", self._state.domain),
            "policy_rules": list(self._visible_policy_hint.get("policy_rules", [])),
            "public_routes": list(self._visible_policy_hint.get("public_routes", [])),
            "fixture_aliases": {
                "users": sorted(str(key) for key in users),
                "resources": sorted(str(key) for key in resources),
            },
        }
        prompt = {
            "environment": "CyberSecurity_OWASP",
            "task": self._task_brief,
            "scenario": {
                "task_id": self._state.task_id,
                "seed": self._state.seed,
                "split": self._state.split,
                "difficulty": self._state.difficulty,
                "difficulty_tier": self._state.difficulty_tier,
                "domain": self._state.domain,
                "bug_family": self._state.bug_family,
                "template_id": self._state.template_id,
                "scenario_hash": self._state.scenario_hash,
            },
            "visible_policy_hint": visible_policy,
            "workspace_summary": self._workspace_summary,
            "available_actions": available_actions,
            "instructions": [
                "Use only the local generated application and the listed tools.",
                "Discover the authorization defect with local evidence before patching.",
                "Preserve legitimate owner/admin flows and intentionally public routes.",
                "Submit exactly one secure, policy-aligned fix when ready.",
            ],
        }
        # sort_keys keeps the serialized prompt stable across runs.
        return "CyberSecurity_OWASP scenario prompt:\n" + json.dumps(
            prompt,
            indent=2,
            sort_keys=True,
        )

    def _finalize_terminal_episode(self, observation_record: dict[str, Any]) -> None:
        """Record curriculum mastery and log episode artifacts exactly once.

        The ``episode_artifact_path`` guard makes this idempotent if a
        terminal step is finalized more than once.
        """
        if self._state.episode_artifact_path:
            return
        mastery = self._curriculum.record_episode(self._state)
        self._state.curriculum_snapshot = {
            **self._state.curriculum_snapshot,
            "post_episode_mastery": mastery,
        }
        self._episode_logger.log_episode(
            self._state,
            final_observation=observation_record,
        )