Spaces:
Sleeping
Sleeping
feat: implement core RL training infrastructure, including GRPO training, evaluation utilities, custom environments, and Modal-based execution scripts.
3807ea3 | """CyberSecurity_OWASP OpenEnv environment implementation.""" | |
| from __future__ import annotations | |
| import json | |
| import shutil | |
| from pathlib import Path | |
| from typing import Any | |
| from uuid import uuid4 | |
| from openenv.core.env_server.interfaces import Environment | |
| try: | |
| from ..models import ( | |
| CyberSecurityOWASPAction, | |
| CyberSecurityOWASPObservation, | |
| CyberSecurityOWASPState, | |
| ) | |
| from ..scenario_compiler import compile_scenario | |
| from ..safety import is_local_route | |
| from ..validators import detect_cheating, is_path_allowed, simulate_request | |
| from .reward_engine import evaluate_action | |
| except ImportError: # pragma: no cover | |
| from models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState | |
| from scenario_compiler import compile_scenario | |
| from safety import is_local_route | |
| from validators import detect_cheating, is_path_allowed, simulate_request | |
| from server.reward_engine import evaluate_action | |
# Tools the agent may call in each episode phase, keyed by phase name.
# "done" maps to an empty set, i.e. no tools are offered for a finished episode.
ALLOWED_TOOLS = dict(
    discover={
        "compare_identities",
        "inspect_policy_graph",
        "list_routes",
        "noop",
        "read_file",
        "read_openapi",
        "search_code",
        "send_local_request",
        "submit_finding",
    },
    patch={
        "noop",
        "patch_file",
        "read_file",
        "run_visible_tests",
        "search_code",
        "send_local_request",
        "submit_fix",
    },
    done=set(),
)
class CybersecurityOwaspEnvironment(
    Environment[CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState]
):
    """Single-agent defensive authorization-repair environment.

    Each episode is compiled from a seed by ``compile_scenario``. The agent
    starts in the ``discover`` phase. A ``submit_finding`` that the reward
    engine marks valid unlocks the ``patch`` phase, and ``submit_fix`` ends
    the episode, as does reaching ``max_steps``. The tools usable in each
    phase come from the module-level ``ALLOWED_TOOLS`` table.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    # Components of the per-step reward breakdown. Tools that are not scored by
    # the reward engine report every component as 0.0.
    _REWARD_KEYS = (
        "discovery",
        "security",
        "regression",
        "public_routes",
        "patch_quality",
        "visible_tests",
        "safety",
        "anti_cheat",
        "total",
    )

    def __init__(self):
        super().__init__()
        self._state = CyberSecurityOWASPState(episode_id=str(uuid4()))
        self._task_brief = ""
        self._visible_policy_hint: dict[str, Any] = {}
        self._workspace_summary: dict[str, Any] = {}
        # Terminal observation, replayed by step() once the episode is done.
        self._last_done_observation: CyberSecurityOWASPObservation | None = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        split: str = "train",
        difficulty: int = 0,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Start a new episode and return its first observation.

        The previous episode's workspace is deleted (via ``close``) before the
        new scenario is compiled.

        Args:
            seed: Scenario seed; ``None`` is treated as 0.
            episode_id: Episode identifier; a random UUID is used when omitted.
            split: Scenario split passed to ``compile_scenario``.
            difficulty: Difficulty level passed to ``compile_scenario``.
            **_: Extra keyword arguments are accepted and ignored.
        """
        self.close()
        actual_seed = int(seed if seed is not None else 0)
        scenario = compile_scenario(actual_seed, split=split, difficulty=difficulty)
        self._state = CyberSecurityOWASPState(
            episode_id=episode_id or str(uuid4()),
            task_id=scenario["task_id"],
            seed=actual_seed,
            split=split,
            difficulty=difficulty,
            domain=scenario["domain"],
            bug_family=scenario["bug_family"],
            phase="discover",
            step_count=0,
            max_steps=40,
            done=False,
            success=False,
            visible_facts={"workspace_summary": scenario["workspace_summary"]},
            hidden_facts=scenario["hidden_facts"],
            metrics={"reset_count": 1},
        )
        self._task_brief = scenario["task_brief"]
        self._visible_policy_hint = scenario["public_hint"]
        self._workspace_summary = scenario["workspace_summary"]
        self._last_done_observation = None
        return self._observation("Scenario ready. Start in discover phase.", reward=0.0)

    def step(
        self,
        action: CyberSecurityOWASPAction,
        timeout_s: float | None = None,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Apply one agent action and return the resulting observation.

        Once the episode is done, further calls do not advance state. They
        replay the cached terminal observation, or a generic "already done"
        one. Actions not allowed in the current phase, and tools that raise,
        come back as invalid-action observations instead of propagating, so
        malformed agent input cannot crash the server. ``timeout_s`` is
        accepted for interface compatibility and is currently unused.
        """
        if self._state.done:
            return self._last_done_observation or self._observation(
                "Episode is already done.", reward=0.0, done_reason=self._state.failure_reason
            )
        anti_cheat_flags = detect_cheating(self._state, action)
        for flag in anti_cheat_flags:
            if flag not in self._state.anti_cheat_flags:
                self._state.anti_cheat_flags.append(flag)
        self._state.step_count += 1
        self._state.action_history.append(
            {"tool_name": action.tool_name, "arguments": action.arguments}
        )
        if action.tool_name not in ALLOWED_TOOLS[self._state.phase]:
            verifier, reward = self._evaluate_safely(action, anti_cheat_flags)
            return self._finish_step(
                "Action is not allowed in the current phase.",
                reward,
                valid=False,
                error=f"{action.tool_name} is not allowed during {self._state.phase}",
                verifier=verifier,
            )
        # Only the tool execution is guarded. Running _finish_step inside the
        # try could double-count reward bookkeeping if it failed part-way.
        try:
            result, verifier, reward, visible_tests = self._execute(action, anti_cheat_flags)
        except Exception as exc:  # keep malformed agent actions from crashing the server
            verifier, reward = self._evaluate_safely(action, anti_cheat_flags)
            return self._finish_step(
                "Tool execution failed.",
                reward,
                valid=False,
                error=str(exc),
                verifier=verifier,
            )
        return self._finish_step(
            result,
            reward,
            valid=True,
            verifier=verifier,
            visible_test_result=visible_tests,
        )

    def state(self) -> CyberSecurityOWASPState:
        """Return the live (not copied) state object of the current episode."""
        # NOTE(review): OpenEnv's ``Environment`` interface may declare ``state``
        # as a ``@property``. If so, ``env.state`` on this class yields a bound
        # method rather than the state. Confirm against the installed openenv
        # version before changing this, since callers may rely on ``state()``.
        return self._state

    def close(self) -> None:
        """Delete the current episode's workspace directory, if one is recorded.

        Removal errors are ignored, so calling this repeatedly is harmless.
        """
        workspace = self._state.hidden_facts.get("workspace")
        if workspace:
            shutil.rmtree(workspace, ignore_errors=True)

    @classmethod
    def _zero_reward(cls) -> dict[str, float]:
        """Return a fresh reward breakdown with every component set to 0.0."""
        return dict.fromkeys(cls._REWARD_KEYS, 0.0)

    def _evaluate_safely(
        self, action: CyberSecurityOWASPAction, anti_cheat_flags: list[str]
    ) -> tuple[dict, dict[str, float]]:
        """Score ``action`` with the reward engine on step()'s error paths.

        The malformed action that just failed may also break the reward
        engine. In that case a zero reward breakdown is returned rather than
        letting the exception escape step().
        """
        try:
            return evaluate_action(self._state, action, anti_cheat_flags)
        except Exception as exc:  # scoring must not crash the server either
            verifier = {"anti_cheat_flags": anti_cheat_flags, "evaluator_error": str(exc)}
            return verifier, self._zero_reward()

    def _execute(
        self, action: CyberSecurityOWASPAction, anti_cheat_flags: list[str]
    ) -> tuple[str, dict, dict[str, float], str | None]:
        """Run a phase-allowed tool.

        Returns:
            ``(message, verifier, reward_breakdown, visible_test_output_or_None)``.

        Raises:
            ValueError: Invalid arguments (non-local route, disallowed path,
                bad diff, empty query) or an unknown tool. File I/O errors also
                propagate. step() turns any exception into an invalid-action
                observation.
        """
        verifier: dict = {"anti_cheat_flags": anti_cheat_flags}
        reward = self._zero_reward()
        args = action.arguments or {}
        tool = action.tool_name

        if tool == "noop":
            return "No operation.", verifier, reward, None
        if tool == "inspect_policy_graph":
            return json.dumps(self._visible_policy_hint, indent=2, sort_keys=True), verifier, reward, None
        if tool == "list_routes":
            return json.dumps(self._workspace_summary["routes"], indent=2), verifier, reward, None
        if tool == "read_openapi":
            # NOTE(review): this spec is a fixed invoices-app document and does
            # not vary with the scenario's ``domain``. Confirm that is intended.
            spec = {
                "openapi": "3.1.0",
                "info": {"title": "Generated invoices app", "version": "0.1.0"},
                "paths": {
                    "/health": {"get": {"x-public": True}},
                    "/invoices/{invoice_id}": {"get": {"x-public": False}},
                },
            }
            return json.dumps(spec, indent=2), verifier, reward, None
        if tool == "read_file":
            path = self._resolve_path(str(args.get("path", "")))
            return path.read_text(encoding="utf-8"), verifier, reward, None
        if tool == "search_code":
            return self._search_code(str(args.get("query", ""))), verifier, reward, None
        if tool == "send_local_request":
            route = str(args.get("path", ""))
            if not is_local_route(route):
                raise ValueError("send_local_request only accepts local route paths")
            response = simulate_request(
                self._state,
                str(args.get("method", "GET")),
                route,
                args.get("user_id"),
            )
            return json.dumps(response, indent=2, sort_keys=True), verifier, reward, None
        if tool == "compare_identities":
            route = str(args.get("path", ""))
            method = str(args.get("method", "GET"))
            first = str(args.get("first_user_id", ""))
            second = str(args.get("second_user_id", ""))
            if not is_local_route(route):
                raise ValueError("compare_identities only accepts local route paths")
            response = {
                "first": simulate_request(self._state, method, route, first),
                "second": simulate_request(self._state, method, route, second),
            }
            return json.dumps(response, indent=2, sort_keys=True), verifier, reward, None
        if tool == "submit_finding":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            if verifier.get("finding", {}).get("valid"):
                self._state.finding_submitted = True
                self._state.phase = "patch"
                return "Finding accepted. Patch phase unlocked.", verifier, reward, None
            return "Finding was not specific enough to unlock patching.", verifier, reward, None
        if tool == "patch_file":
            path = self._resolve_path(str(args.get("path", "")), write=True)
            if "content" in args:
                path.write_text(str(args["content"]), encoding="utf-8")
            else:
                self._apply_unified_diff(path, str(args.get("diff", "")))
            return f"Patched {args.get('path')}.", verifier, reward, None
        if tool == "run_visible_tests":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            visible_tests = json.dumps(verifier.get("visible", {}), indent=2, sort_keys=True)
            return visible_tests, verifier, reward, visible_tests
        if tool == "submit_fix":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.patch_submitted = True
            security = verifier.get("security", {}).get("passed", False)
            regression = verifier.get("regression", {}).get("passed", False)
            public = verifier.get("public_routes", {}).get("passed", False)
            quality = verifier.get("patch_quality", {}).get("passed", False)
            # Success requires every hidden verifier component to pass.
            self._state.success = bool(security and regression and public and quality)
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = None if self._state.success else "hidden_verifier_failed"
            return json.dumps(verifier, indent=2, sort_keys=True), verifier, reward, None
        raise ValueError(f"Unhandled tool {action.tool_name}")

    def _finish_step(
        self,
        message: str,
        reward: dict[str, float],
        *,
        valid: bool,
        error: str | None = None,
        verifier: dict | None = None,
        visible_test_result: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Record reward bookkeeping, enforce the step limit, and build the observation.

        This method:
        - stores ``reward["total"]`` as ``last_reward`` and adds it to
          ``accumulated_reward``;
        - appends the breakdown to ``reward_history``;
        - ends a still-running episode with ``"max_steps_exceeded"`` once
          ``step_count`` reaches ``max_steps``;
        - caches the observation when the episode is done.

        ``verifier`` is accepted but not used here.
        """
        self._state.last_reward = float(reward.get("total", 0.0))
        self._state.accumulated_reward += self._state.last_reward
        self._state.reward_history.append(reward)
        if self._state.step_count >= self._state.max_steps and not self._state.done:
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = "max_steps_exceeded"
        obs = self._observation(
            message,
            reward=self._state.last_reward,
            valid=valid,
            error=error,
            reward_breakdown=reward,
            visible_test_result=visible_test_result,
            done_reason=self._state.failure_reason,
        )
        if self._state.done:
            self._last_done_observation = obs
        return obs

    def _observation(
        self,
        message: str,
        *,
        reward: float,
        valid: bool = True,
        error: str | None = None,
        reward_breakdown: dict[str, float] | None = None,
        visible_test_result: str | None = None,
        done_reason: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Build an observation from the current state.

        ``message`` is reported both as the message and as ``last_tool_result``.
        ``available_actions`` lists the current phase's tools in sorted order.
        """
        return CyberSecurityOWASPObservation(
            phase=self._state.phase,
            message=message,
            task_brief=self._task_brief,
            visible_policy_hint=self._visible_policy_hint,
            workspace_summary=self._workspace_summary,
            available_actions=sorted(ALLOWED_TOOLS[self._state.phase]),
            last_tool_result=message,
            last_action_valid=valid,
            last_action_error=error,
            visible_test_result=visible_test_result,
            reward_breakdown=reward_breakdown or {},
            done_reason=done_reason,
            done=self._state.done,
            reward=reward,
            metadata={"episode_id": self._state.episode_id, "step_count": self._state.step_count},
        )

    def _resolve_path(self, path: str, *, write: bool = False) -> Path:
        """Validate ``path`` with ``is_path_allowed`` and return it inside the workspace.

        Raises:
            ValueError: The path is not allowed. The validator's message is
                used as the error text.
        """
        allowed, normalized_or_error = is_path_allowed(self._state, path, write=write)
        if not allowed:
            raise ValueError(normalized_or_error)
        return Path(str(self._state.hidden_facts["workspace"])) / normalized_or_error

    @staticmethod
    def _split_lines(text: str) -> list[str]:
        """Split ``text`` on newline characters only, keeping each line's terminator.

        Unlike ``str.splitlines``, this does not break on form feeds, vertical
        tabs, ``\\x1c``-``\\x1e``, ``\\x85``, ``\\u2028`` or ``\\u2029``. Line
        numbers therefore agree with diff tools and editors.
        """
        pieces = text.split("\n")
        lines = [piece + "\n" for piece in pieces[:-1]]
        if pieces[-1]:
            lines.append(pieces[-1])
        return lines

    def _search_code(self, query: str) -> str:
        """Case-insensitive substring search over the scenario's editable files.

        Returns one ``"<file>:<line>: <text>"`` entry per matching line, or
        ``"No matches."``. Line numbers count newline-separated lines, which
        is the numbering ``_apply_unified_diff`` expects.

        Raises:
            ValueError: ``query`` is empty.
        """
        if not query:
            raise ValueError("query is required")
        needle = query.lower()
        workspace = Path(str(self._state.hidden_facts["workspace"]))
        results: list[str] = []
        for rel in self._state.hidden_facts.get("editable_files", []):
            text = (workspace / rel).read_text(encoding="utf-8")
            # A trailing "" piece after a final newline can never match a
            # non-empty needle, so it needs no special handling.
            for idx, line in enumerate(text.split("\n"), start=1):
                if needle in line.lower():
                    results.append(f"{rel}:{idx}: {line}")
        return "\n".join(results) or "No matches."

    @staticmethod
    def _parse_hunk_header(header: str) -> tuple[int, int]:
        """Return ``(old_start, old_count)`` from an ``@@ -a[,b] +c[,d] @@`` header.

        ``old_count`` defaults to 1 when omitted, as in the unified diff format.

        Raises:
            ValueError: The header does not carry a valid old-file range.
        """
        fields = header.split()
        old_range = fields[1] if len(fields) > 1 else ""
        if not old_range.startswith("-"):
            raise ValueError(f"malformed hunk header: {header.strip()}")
        start_text, _, count_text = old_range[1:].partition(",")
        try:
            start = int(start_text)
            count = int(count_text) if count_text else 1
        except ValueError:
            raise ValueError(f"malformed hunk header: {header.strip()}") from None
        if start < 0 or count < 0:
            raise ValueError(f"malformed hunk header: {header.strip()}")
        return start, count

    def _apply_unified_diff(self, path: Path, diff: str) -> None:
        """Apply a single-file unified diff to ``path`` in place.

        Hunk placement:
        - Each hunk is placed by the old-file start line in its ``@@`` header.
        - A zero-length old range (``-N,0``) inserts after line N, following
          the unified diff format.
        - Hunk line counts are otherwise not enforced.

        Validation:
        - Context (``" "``) and removed (``"-"``) lines must match the file,
          ignoring trailing whitespace.
        - Hunks must be in file order.
        - At least one hunk is required.

        Line endings:
        - Added lines always end with a newline unless followed by
          ``\\ No newline at end of file``.
        - CRLF/CR endings in the diff are normalised.

        Raises:
            ValueError: ``diff`` is empty, has no hunks, contains a malformed
                header, or does not match the file. The file is only written
                after the whole diff applies, so it is left untouched on error.
        """
        if not diff.strip():
            raise ValueError("diff or content is required")
        original = self._split_lines(path.read_text(encoding="utf-8"))
        diff_lines = self._split_lines(diff.replace("\r\n", "\n").replace("\r", "\n"))
        output: list[str] = []
        old_index = 0
        hunk_count = 0
        i = 0
        while i < len(diff_lines):
            header = diff_lines[i]
            i += 1
            if not header.startswith("@@"):
                continue  # "---"/"+++" file headers, "diff --git", stray prose, ...
            hunk_count += 1
            old_start, old_count = self._parse_hunk_header(header)
            # A zero-length old range names the line *after* which to insert.
            hunk_index = old_start if old_count == 0 else old_start - 1
            if hunk_index < old_index or hunk_index > len(original):
                raise ValueError(
                    f"hunk {header.strip()} is out of order or past the end of {path.name}"
                )
            output.extend(original[old_index:hunk_index])
            old_index = hunk_index
            previous_tag = ""
            while i < len(diff_lines) and not diff_lines[i].startswith("@@"):
                body = diff_lines[i]
                i += 1
                tag, text = body[:1], body[1:]
                if tag == "\n":
                    # Editors often strip the leading space of a blank context
                    # line. Treat it as context only if the file really has a
                    # blank line here; otherwise it is trailing junk and is skipped.
                    if old_index < len(original) and not original[old_index].strip():
                        tag, text = " ", "\n"
                    else:
                        continue
                if tag in (" ", "-"):
                    if old_index >= len(original) or original[old_index].rstrip() != text.rstrip():
                        raise ValueError(
                            f"hunk {header.strip()} does not match {path.name} "
                            f"at line {old_index + 1}"
                        )
                    if tag == " ":
                        output.append(original[old_index])
                    old_index += 1
                elif tag == "+":
                    # Always terminate added lines so they cannot fuse with the
                    # original line that follows them.
                    output.append(text.rstrip("\n") + "\n")
                elif tag == "\\" and previous_tag == "+" and output:
                    # "\ No newline at end of file" after an added line.
                    output[-1] = output[-1].rstrip("\n")
                previous_tag = tag
        if not hunk_count:
            raise ValueError("diff contains no @@ hunks")
        output.extend(original[old_index:])
        path.write_text("".join(output), encoding="utf-8")