Cyber_analyst-round1 / server /CyberSecurity_OWASP_environment.py
Humanlearning's picture
feat: implement core RL training infrastructure, including GRPO training, evaluation utilities, custom environments, and Modal-based execution scripts.
3807ea3
raw
history blame
15 kB
"""CyberSecurity_OWASP OpenEnv environment implementation."""
from __future__ import annotations
import json
import shutil
from pathlib import Path
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
try:
from ..models import (
CyberSecurityOWASPAction,
CyberSecurityOWASPObservation,
CyberSecurityOWASPState,
)
from ..scenario_compiler import compile_scenario
from ..safety import is_local_route
from ..validators import detect_cheating, is_path_allowed, simulate_request
from .reward_engine import evaluate_action
except ImportError: # pragma: no cover
from models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState
from scenario_compiler import compile_scenario
from safety import is_local_route
from validators import detect_cheating, is_path_allowed, simulate_request
from server.reward_engine import evaluate_action
# Tools the agent may call in both active phases ("discover" and "patch").
_COMMON_TOOLS = {"read_file", "search_code", "send_local_request", "noop"}

# Maps each episode phase to the set of tool names accepted while the episode
# is in that phase. The "done" phase accepts nothing.
ALLOWED_TOOLS = {
    "discover": _COMMON_TOOLS
    | {
        "inspect_policy_graph",
        "list_routes",
        "read_openapi",
        "compare_identities",
        "submit_finding",
    },
    "patch": _COMMON_TOOLS
    | {
        "patch_file",
        "run_visible_tests",
        "submit_fix",
    },
    "done": set(),
}
class CybersecurityOwaspEnvironment(
    Environment[CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState]
):
    """Single-agent defensive authorization-repair environment.

    An episode walks through the phases that key ``ALLOWED_TOOLS``: it starts
    in ``discover``, moves to ``patch`` once a ``submit_finding`` action is
    accepted by the verifier, and ends in ``done`` after ``submit_fix`` or
    when the step budget (``max_steps``) is used up.
    """

    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        """Create the environment with a placeholder state.

        No scenario is loaded here; ``reset`` compiles one and fills in the
        task brief, policy hint and workspace summary.
        """
        super().__init__()
        self._state = CyberSecurityOWASPState(episode_id=str(uuid4()))
        self._task_brief = ""
        self._visible_policy_hint: dict[str, Any] = {}
        self._workspace_summary: dict[str, Any] = {}
        # Observation returned by the step that ended the episode; replayed by
        # ``step`` if the caller keeps stepping afterwards.
        self._last_done_observation: CyberSecurityOWASPObservation | None = None

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        split: str = "train",
        difficulty: int = 0,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Discard the current episode and start a fresh one.

        Args:
            seed: Scenario seed; ``None`` is treated as ``0``.
            episode_id: Identifier for the new episode; a random UUID is
                generated when it is missing or empty.
            split: Forwarded to ``compile_scenario``.
            difficulty: Forwarded to ``compile_scenario``.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The initial observation (phase ``discover``, reward ``0.0``).
        """
        # Remove the previous episode's workspace before compiling a new one.
        self.close()
        actual_seed = int(seed if seed is not None else 0)
        scenario = compile_scenario(actual_seed, split=split, difficulty=difficulty)
        self._state = CyberSecurityOWASPState(
            episode_id=episode_id or str(uuid4()),
            task_id=scenario["task_id"],
            seed=actual_seed,
            split=split,
            difficulty=difficulty,
            domain=scenario["domain"],
            bug_family=scenario["bug_family"],
            phase="discover",
            step_count=0,
            max_steps=40,
            done=False,
            success=False,
            visible_facts={"workspace_summary": scenario["workspace_summary"]},
            hidden_facts=scenario["hidden_facts"],
            metrics={"reset_count": 1},
        )
        self._task_brief = scenario["task_brief"]
        self._visible_policy_hint = scenario["public_hint"]
        self._workspace_summary = scenario["workspace_summary"]
        self._last_done_observation = None
        return self._observation("Scenario ready. Start in discover phase.", reward=0.0)

    def step(
        self,
        action: CyberSecurityOWASPAction,
        timeout_s: float | None = None,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Run one agent action and return the resulting observation.

        Every call on a live episode increments ``step_count`` and is recorded
        in ``action_history``, including actions that are rejected because the
        tool is not allowed in the current phase. Exceptions raised while
        executing a tool are reported as an invalid action rather than
        propagated.

        Args:
            action: The tool call to execute (``tool_name`` + ``arguments``).
            timeout_s: Accepted for interface compatibility; not used here.
            **_: Extra keyword arguments are accepted and ignored.

        Returns:
            The observation for this step. Once the episode is done, the
            observation that ended it is returned again unchanged.
        """
        if self._state.done:
            return self._last_done_observation or self._observation(
                "Episode is already done.", reward=0.0, done_reason=self._state.failure_reason
            )

        # Merge newly detected flags into the episode state without duplicates.
        anti_cheat_flags = detect_cheating(self._state, action)
        for flag in anti_cheat_flags:
            if flag not in self._state.anti_cheat_flags:
                self._state.anti_cheat_flags.append(flag)

        self._state.step_count += 1
        self._state.action_history.append(
            {"tool_name": action.tool_name, "arguments": action.arguments}
        )

        if action.tool_name not in ALLOWED_TOOLS[self._state.phase]:
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            return self._finish_step(
                "Action is not allowed in the current phase.",
                reward,
                valid=False,
                error=f"{action.tool_name} is not allowed during {self._state.phase}",
                verifier=verifier,
            )

        try:
            result, verifier, reward, visible_tests = self._execute(action, anti_cheat_flags)
            return self._finish_step(
                result,
                reward,
                valid=True,
                verifier=verifier,
                visible_test_result=visible_tests,
            )
        except Exception as exc:  # keep malformed agent actions from crashing the server
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            return self._finish_step(
                "Tool execution failed.",
                reward,
                valid=False,
                error=str(exc),
                verifier=verifier,
            )

    @property
    def state(self) -> CyberSecurityOWASPState:
        """The live episode state object (not a copy)."""
        return self._state

    def close(self) -> None:
        """Delete the episode workspace directory, if the state names one.

        Removal errors are ignored, so calling this repeatedly is harmless.
        """
        workspace = self._state.hidden_facts.get("workspace")
        if workspace:
            shutil.rmtree(workspace, ignore_errors=True)

    def _execute(
        self, action: CyberSecurityOWASPAction, anti_cheat_flags: list[str]
    ) -> tuple[str, dict, dict[str, float], str | None]:
        """Dispatch a phase-approved action to its tool implementation.

        Args:
            action: The action to run; ``arguments`` may be ``None``.
            anti_cheat_flags: Flags detected for this action, passed to the
                reward engine and echoed in the default verifier payload.

        Returns:
            ``(result_text, verifier, reward_breakdown, visible_test_result)``.
            Tools that are not scored return an all-zero reward breakdown;
            ``submit_finding``, ``run_visible_tests`` and ``submit_fix`` return
            whatever ``evaluate_action`` produces.

        Raises:
            ValueError: For a non-local request path, an unknown tool name, or
                any validation error raised by the helpers called here.
        """
        verifier: dict = {"anti_cheat_flags": anti_cheat_flags}
        reward = dict.fromkeys(
            (
                "discovery",
                "security",
                "regression",
                "public_routes",
                "patch_quality",
                "visible_tests",
                "safety",
                "anti_cheat",
                "total",
            ),
            0.0,
        )
        args = action.arguments or {}
        tool = action.tool_name

        if tool == "noop":
            return "No operation.", verifier, reward, None

        if tool == "inspect_policy_graph":
            return json.dumps(self._visible_policy_hint, indent=2, sort_keys=True), verifier, reward, None

        if tool == "list_routes":
            return json.dumps(self._workspace_summary["routes"], indent=2), verifier, reward, None

        if tool == "read_openapi":
            # NOTE: this document is a fixed literal; it is not derived from
            # the compiled scenario.
            return json.dumps(
                {
                    "openapi": "3.1.0",
                    "info": {"title": "Generated invoices app", "version": "0.1.0"},
                    "paths": {
                        "/health": {"get": {"x-public": True}},
                        "/invoices/{invoice_id}": {"get": {"x-public": False}},
                    },
                },
                indent=2,
            ), verifier, reward, None

        if tool == "read_file":
            path = self._resolve_path(str(args.get("path", "")))
            return path.read_text(encoding="utf-8"), verifier, reward, None

        if tool == "search_code":
            return self._search_code(str(args.get("query", ""))), verifier, reward, None

        if tool == "send_local_request":
            route = str(args.get("path", ""))
            if not is_local_route(route):
                raise ValueError("send_local_request only accepts local route paths")
            response = simulate_request(
                self._state,
                str(args.get("method", "GET")),
                route,
                args.get("user_id"),
            )
            return json.dumps(response, indent=2, sort_keys=True), verifier, reward, None

        if tool == "compare_identities":
            route = str(args.get("path", ""))
            first = str(args.get("first_user_id", ""))
            second = str(args.get("second_user_id", ""))
            if not is_local_route(route):
                raise ValueError("compare_identities only accepts local route paths")
            method = str(args.get("method", "GET"))
            # Same request issued as each identity so the responses can be diffed.
            response = {
                "first": simulate_request(self._state, method, route, first),
                "second": simulate_request(self._state, method, route, second),
            }
            return json.dumps(response, indent=2, sort_keys=True), verifier, reward, None

        if tool == "submit_finding":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            if verifier.get("finding", {}).get("valid"):
                self._state.finding_submitted = True
                self._state.phase = "patch"
                return "Finding accepted. Patch phase unlocked.", verifier, reward, None
            return "Finding was not specific enough to unlock patching.", verifier, reward, None

        if tool == "patch_file":
            path = self._resolve_path(str(args.get("path", "")), write=True)
            # A null ``content`` means "not provided": fall back to ``diff``
            # instead of writing the text "None" into the file.
            content = args.get("content")
            if content is not None:
                path.write_text(str(content), encoding="utf-8")
            else:
                self._apply_unified_diff(path, str(args.get("diff", "")))
            return f"Patched {args.get('path')}.", verifier, reward, None

        if tool == "run_visible_tests":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            visible_tests = json.dumps(verifier.get("visible", {}), indent=2, sort_keys=True)
            return visible_tests, verifier, reward, visible_tests

        if tool == "submit_fix":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.patch_submitted = True
            security = verifier.get("security", {}).get("passed", False)
            regression = verifier.get("regression", {}).get("passed", False)
            public = verifier.get("public_routes", {}).get("passed", False)
            quality = verifier.get("patch_quality", {}).get("passed", False)
            # The episode succeeds only if every verifier section passed.
            self._state.success = bool(security and regression and public and quality)
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = None if self._state.success else "hidden_verifier_failed"
            return json.dumps(verifier, indent=2, sort_keys=True), verifier, reward, None

        raise ValueError(f"Unhandled tool {action.tool_name}")

    def _finish_step(
        self,
        message: str,
        reward: dict[str, float],
        *,
        valid: bool,
        error: str | None = None,
        verifier: dict | None = None,
        visible_test_result: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Record the step's reward, enforce the step budget and build the observation.

        Args:
            message: Tool result text shown to the agent.
            reward: Reward breakdown; its ``"total"`` entry (default ``0.0``)
                becomes the step reward and is added to the running total.
            valid: Whether the action was accepted and executed successfully.
            error: Error text for an invalid or failed action.
            verifier: Accepted for call-site symmetry; not used here.
            visible_test_result: Output of ``run_visible_tests``, if any.

        Returns:
            The observation for this step. If the episode is now done it is
            also cached for replay by ``step``.
        """
        self._state.last_reward = float(reward.get("total", 0.0))
        self._state.accumulated_reward += self._state.last_reward
        self._state.reward_history.append(reward)

        # Budget exhaustion only ends episodes that did not already finish
        # (e.g. via submit_fix on this same step).
        if self._state.step_count >= self._state.max_steps and not self._state.done:
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = "max_steps_exceeded"

        obs = self._observation(
            message,
            reward=self._state.last_reward,
            valid=valid,
            error=error,
            reward_breakdown=reward,
            visible_test_result=visible_test_result,
            done_reason=self._state.failure_reason,
        )
        if self._state.done:
            self._last_done_observation = obs
        return obs

    def _observation(
        self,
        message: str,
        *,
        reward: float,
        valid: bool = True,
        error: str | None = None,
        reward_breakdown: dict[str, float] | None = None,
        visible_test_result: str | None = None,
        done_reason: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Build an observation from the current state plus per-step details.

        ``message`` is reported both as ``message`` and as
        ``last_tool_result``; ``available_actions`` is the sorted tool list
        for the current phase.
        """
        return CyberSecurityOWASPObservation(
            phase=self._state.phase,
            message=message,
            task_brief=self._task_brief,
            visible_policy_hint=self._visible_policy_hint,
            workspace_summary=self._workspace_summary,
            available_actions=sorted(ALLOWED_TOOLS[self._state.phase]),
            last_tool_result=message,
            last_action_valid=valid,
            last_action_error=error,
            visible_test_result=visible_test_result,
            reward_breakdown=reward_breakdown or {},
            done_reason=done_reason,
            done=self._state.done,
            reward=reward,
            metadata={"episode_id": self._state.episode_id, "step_count": self._state.step_count},
        )

    def _resolve_path(self, path: str, *, write: bool = False) -> Path:
        """Validate an agent-supplied path and anchor it in the workspace.

        ``is_path_allowed`` returns ``(allowed, value)``; ``value`` is used as
        the error message when access is denied and as the workspace-relative
        path otherwise.

        Raises:
            ValueError: If ``is_path_allowed`` rejects the path.
        """
        allowed, normalized_or_error = is_path_allowed(self._state, path, write=write)
        if not allowed:
            raise ValueError(normalized_or_error)
        return Path(str(self._state.hidden_facts["workspace"])) / normalized_or_error

    def _search_code(self, query: str) -> str:
        """Case-insensitive substring search over the scenario's editable files.

        Args:
            query: Text to look for; must be non-empty.

        Returns:
            One ``"<relative path>:<line number>: <line>"`` entry per matching
            line, joined by newlines, or ``"No matches."``.

        Raises:
            ValueError: If ``query`` is empty.
        """
        if not query:
            raise ValueError("query is required")
        needle = query.lower()  # lower-cased once instead of once per line
        workspace = Path(str(self._state.hidden_facts["workspace"]))
        results: list[str] = []
        for rel in self._state.hidden_facts.get("editable_files", []):
            text = (workspace / rel).read_text(encoding="utf-8")
            results.extend(
                f"{rel}:{idx}: {line}"
                for idx, line in enumerate(text.splitlines(), start=1)
                if needle in line.lower()
            )
        return "\n".join(results) or "No matches."

    def _apply_unified_diff(self, path: Path, diff: str) -> None:
        """Apply the hunks of a unified diff to ``path`` in place.

        Only ``@@`` hunks are interpreted; everything before the first hunk
        header (``---``/``+++`` file headers, commentary) is skipped. Each
        hunk is positioned by the old-file start line in its header. The
        header's line counts are not enforced, except that an old count of
        ``0`` marks a pure insertion, whose start number names the line
        *preceding* the hunk (so ``-0,0`` inserts at the top of the file).

        Context (`` ``) and removed (``-``) lines must match the file,
        ignoring trailing whitespace, so a stale or mis-numbered diff is
        rejected instead of silently corrupting the file. The file is only
        written after the whole diff has been applied successfully.

        Raises:
            ValueError: If the diff is empty, contains no hunk, has a
                malformed hunk header, has hunks that overlap, are out of
                order or lie beyond the end of the file, or has context or
                removed lines that do not match the file.
        """
        if not diff.strip():
            raise ValueError("diff or content is required")

        original = path.read_text(encoding="utf-8").splitlines(True)
        diff_lines = diff.splitlines(True)
        output: list[str] = []
        old_index = 0  # index of the next not-yet-consumed line of ``original``
        hunk_count = 0
        i = 0

        while i < len(diff_lines):
            header = diff_lines[i]
            i += 1
            if not header.startswith("@@"):
                continue  # preamble / file headers before the first hunk

            # Header shape: "@@ -<start>[,<count>] +<start>[,<count>] @@".
            try:
                old_range = header.split()[1]
                if not old_range.startswith("-"):
                    raise ValueError(old_range)
                start_text, _, count_text = old_range[1:].partition(",")
                old_start = int(start_text)
                old_count = int(count_text) if count_text else 1
            except (IndexError, ValueError):
                raise ValueError(f"malformed hunk header: {header.strip()!r}") from None

            # For an empty old range the start is the line before the hunk, so
            # the 0-based insertion index equals ``old_start`` itself.
            hunk_start = old_start if old_count == 0 else old_start - 1
            if not old_index <= hunk_start <= len(original):
                raise ValueError(
                    f"hunk {header.strip()!r} is out of order or outside the file"
                )

            output.extend(original[old_index:hunk_start])
            old_index = hunk_start
            hunk_count += 1
            previous_marker = ""

            while i < len(diff_lines) and not diff_lines[i].startswith("@@"):
                hunk_line = diff_lines[i]
                i += 1
                marker, text = hunk_line[:1], hunk_line[1:]

                if marker == "\\":
                    # "\ No newline at end of file" qualifies the line before
                    # it; only an added line needs its newline removed.
                    if previous_marker == "+" and output:
                        output[-1] = output[-1].rstrip("\r\n")
                    continue

                if marker == "+":
                    # The diff's final line may lack a terminator; add one so
                    # it is not glued onto the following original line.
                    if not text.endswith(("\n", "\r")):
                        text += "\n"
                    output.append(text)
                elif marker in (" ", "-"):
                    if old_index >= len(original):
                        raise ValueError("hunk extends past the end of the file")
                    if original[old_index].rstrip() != text.rstrip():
                        raise ValueError(
                            f"hunk does not match {path.name} at line {old_index + 1}"
                        )
                    if marker == " ":
                        output.append(original[old_index])
                    old_index += 1
                elif (
                    not hunk_line.strip()
                    and old_index < len(original)
                    and not original[old_index].strip()
                ):
                    # A blank context line whose leading space was stripped.
                    output.append(original[old_index])
                    old_index += 1
                    marker = " "
                else:
                    continue  # unrecognised line inside a hunk: ignored

                previous_marker = marker

        if hunk_count == 0:
            raise ValueError("diff contains no @@ hunks")

        output.extend(original[old_index:])
        path.write_text("".join(output), encoding="utf-8")