# Cyber_analyst-round1 / server / CyberSecurity_OWASP_environment.py
# Commit 632c145 — feat: enhance CyberSecurity_OWASP observation model with
# scenario prompt, improve GRPO batch configuration validation, and add
# scenario grouping for adaptive difficulty curriculum.
"""CyberSecurity_OWASP OpenEnv environment implementation."""
from __future__ import annotations
import json
import shutil
import time
from dataclasses import asdict
from typing import Any
from uuid import uuid4
from openenv.core.env_server.interfaces import Environment
try:
from ..config import load_scenario_authoring_config
from ..models import (
CyberSecurityOWASPAction,
CyberSecurityOWASPObservation,
CyberSecurityOWASPState,
)
from ..validators import detect_cheating
from .action_tools import ActionTools
from .curriculum import CurriculumController
from .episode_logger import EpisodeArtifactLogger
from .reward_engine import evaluate_action
from ..rewards import should_terminate_for_flags
from .scenario_cache import ScenarioCache, ScenarioCacheMiss, cache_key_for_scenario
from .scenario_factory import ScenarioFactory
except ImportError: # pragma: no cover
from config import load_scenario_authoring_config
from models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState
from validators import detect_cheating
from server.action_tools import ActionTools
from server.curriculum import CurriculumController
from server.episode_logger import EpisodeArtifactLogger
from server.reward_engine import evaluate_action
from rewards import should_terminate_for_flags
from server.scenario_cache import ScenarioCache, ScenarioCacheMiss, cache_key_for_scenario
from server.scenario_factory import ScenarioFactory
# Per-phase tool gating: an agent action is only accepted when its tool name
# is present in the set registered for the episode's current phase.  The
# terminal "done" phase admits no tools at all.
_DISCOVER_PHASE_TOOLS = (
    "inspect_policy_graph",
    "list_routes",
    "read_openapi",
    "read_file",
    "search_code",
    "send_local_request",
    "compare_identities",
    "submit_diagnosis",
    "noop",
)
_PATCH_PHASE_TOOLS = (
    "read_file",
    "search_code",
    "patch_file",
    "run_visible_tests",
    "send_local_request",
    "submit_fix",
    "noop",
)
ALLOWED_TOOLS: dict[str, set[str]] = {
    "discover": set(_DISCOVER_PHASE_TOOLS),
    "patch": set(_PATCH_PHASE_TOOLS),
    "done": set(),
}
class CybersecurityOwaspEnvironment(
    Environment[CyberSecurityOWASPAction, CyberSecurityOWASPObservation, CyberSecurityOWASPState]
):
    """Single-agent defensive authorization-repair environment.

    An episode moves through three phases: ``discover`` (inspect the
    generated application and submit a diagnosis), ``patch`` (edit code,
    run visible tests, and submit a fix), and ``done`` (terminal).  Tool
    availability per phase is defined by the module-level ``ALLOWED_TOOLS``
    table.  Scenario content is either loaded from a pre-built cache or
    compiled on demand, depending on the runtime cache mode.
    """

    # Advertised to the hosting server: multiple sessions may run at once.
    SUPPORTS_CONCURRENT_SESSIONS = True

    def __init__(self) -> None:
        super().__init__()
        # Placeholder state; reset() installs the real scenario-backed state.
        self._state = CyberSecurityOWASPState(episode_id=str(uuid4()))
        self._task_brief = ""
        self._visible_policy_hint: dict[str, Any] = {}
        self._workspace_summary: dict[str, Any] = {}
        # Terminal observation cached so step()-after-done replays it verbatim.
        self._last_done_observation: CyberSecurityOWASPObservation | None = None
        self._curriculum = CurriculumController()
        self._scenario_factory = ScenarioFactory()
        self._episode_logger = EpisodeArtifactLogger()

    def reset(
        self,
        seed: int | None = None,
        episode_id: str | None = None,
        split: str = "train",
        difficulty: int = 0,
        family_budget: dict[str, Any] | None = None,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Start a new episode and return the initial observation.

        Selects a curriculum profile for ``seed``/``split``/``difficulty``,
        loads (or compiles) the matching scenario, and rebuilds ``_state``
        from scratch.  ``family_budget`` is forwarded to the scenario cache
        lookup.  Extra keyword arguments are accepted and ignored.
        """
        reset_started = time.perf_counter()
        # Tear down the previous episode's on-disk workspace before rebuilding.
        self.close()
        settings = load_scenario_authoring_config()
        self._curriculum.settings = settings
        # A missing seed collapses to 0 so scenario selection is deterministic.
        actual_seed = int(seed if seed is not None else 0)
        curriculum_profile = self._curriculum.select_profile(
            seed=actual_seed,
            split=split,
            requested_difficulty=difficulty,
        )
        scenario = self._load_or_compile_scenario(
            actual_seed,
            split=split,
            requested_difficulty=difficulty,
            curriculum_profile=curriculum_profile,
            family_budget=family_budget,
            settings=settings,
        )
        # Cache bookkeeping recorded by _load_or_compile_scenario / the cache.
        cache_info = dict(scenario.get("cache", {}))
        cache_key = cache_info.get("cache_key", {})
        scenario_hash = str(cache_info.get("scenario_hash", ""))
        reset_latency_ms = (time.perf_counter() - reset_started) * 1000
        self._state = CyberSecurityOWASPState(
            episode_id=episode_id or str(uuid4()),
            task_id=scenario["task_id"],
            seed=actual_seed,
            split=split,
            difficulty=scenario["difficulty"],
            difficulty_tier=scenario["difficulty_tier"],
            domain=scenario["domain"],
            bug_family=scenario["bug_family"],
            scenario_family=scenario["scenario_family"],
            template_id=scenario["template_id"],
            target_weakness=scenario["target_weakness"],
            cache_key=cache_key,
            scenario_hash=scenario_hash,
            # Fall back to the runtime settings when the cache key carries no
            # generator/verifier version (e.g. freshly compiled scenarios).
            generator_version=str(cache_key.get("generator_version", settings.runtime.generator_version)),
            verifier_version=str(cache_key.get("verifier_version", settings.runtime.verifier_version)),
            cache_hit=bool(cache_info.get("hit", False)),
            reset_latency_ms=reset_latency_ms,
            phase="discover",
            step_count=0,
            max_steps=40,
            done=False,
            success=False,
            visible_facts={"workspace_summary": scenario["workspace_summary"]},
            hidden_facts=scenario["hidden_facts"],
            curriculum_snapshot=scenario["curriculum_snapshot"],
            metrics={
                "reset_count": 1,
                "reset_latency_ms": reset_latency_ms,
                "scenario_cache_hit": bool(cache_info.get("hit", False)),
                "scenario_cache_mode": settings.runtime.cache_mode,
                "scenario_cache_key": cache_key,
                "scenario_hash": scenario_hash,
                "scenario_bundle_load_latency_ms": float(cache_info.get("load_latency_ms", 0.0)),
                "scenario_compile_latency_ms": float(cache_info.get("compile_latency_ms", 0.0)),
                "scenario_cache_dir": settings.runtime.cache_dir,
            },
        )
        self._task_brief = scenario["task_brief"]
        self._visible_policy_hint = scenario["public_hint"]
        self._workspace_summary = scenario["workspace_summary"]
        self._last_done_observation = None
        return self._observation("Scenario ready. Start in discover phase.", reward=0.0)

    def _load_or_compile_scenario(
        self,
        seed: int,
        *,
        split: str,
        requested_difficulty: int,
        curriculum_profile: dict[str, Any],
        family_budget: dict[str, Any] | None,
        settings: Any,
    ) -> dict[str, Any]:
        """Return a scenario dict, preferring the cache over compilation.

        Cache lookups use the curriculum-adjusted difficulty; in ``require``
        cache mode a miss is a hard error (runtime resets must never compile).
        Otherwise the scenario is compiled on the spot and stamped with a
        synthetic ``cache`` record (hit=False plus compile latency).
        """
        # Curriculum may override the caller-requested difficulty.
        difficulty = int(curriculum_profile.get("difficulty", requested_difficulty))
        if settings.runtime.cache_mode != "disabled":
            cache = ScenarioCache(settings.runtime.cache_dir, settings=settings)
            try:
                cached = cache.load_bundle(
                    seed=seed,
                    split=split,
                    difficulty=difficulty,
                    family_budget=family_budget,
                )
                return cached.scenario
            except ScenarioCacheMiss as exc:
                if settings.runtime.cache_mode == "require":
                    raise RuntimeError(
                        "Scenario cache miss in required mode. Run cache prep before "
                        "training/eval; runtime reset must not compile scenarios. "
                        f"Details: {exc}"
                    ) from exc
        # Cache disabled or soft miss: compile in-process.
        # NOTE(review): compilation passes the raw requested_difficulty while
        # the cache lookup above used the curriculum-adjusted value — the
        # factory also receives curriculum_profile, presumably reconciling
        # this; confirm against ScenarioFactory.compile_scenario.
        compile_started = time.perf_counter()
        scenario = self._scenario_factory.compile_scenario(
            seed,
            split=split,
            difficulty=requested_difficulty,
            curriculum_profile=curriculum_profile,
        )
        key = cache_key_for_scenario(scenario, settings=settings)
        scenario["cache"] = {
            "hit": False,
            "cache_key": asdict(key),
            "scenario_hash": key.scenario_hash,
            "compile_latency_ms": (time.perf_counter() - compile_started) * 1000,
        }
        return scenario

    def step(
        self,
        action: CyberSecurityOWASPAction,
        timeout_s: float | None = None,
        **_: Any,
    ) -> CyberSecurityOWASPObservation:
        """Apply one agent action and return the resulting observation.

        Phase-invalid tools and tool executions that raise are both scored
        as invalid actions rather than crashing the server.  ``timeout_s``
        is accepted for interface compatibility but not used here.
        """
        if self._state.done:
            # Replay the cached terminal observation rather than re-evaluating.
            return self._last_done_observation or self._observation(
                "Episode is already done.", reward=0.0, done_reason=self._state.failure_reason
            )
        anti_cheat_flags = detect_cheating(self._state, action)
        # Record each flag at most once across the whole episode.
        for flag in anti_cheat_flags:
            if flag not in self._state.anti_cheat_flags:
                self._state.anti_cheat_flags.append(flag)
        self._state.step_count += 1
        self._state.action_history.append(
            {"tool_name": action.tool_name, "arguments": action.arguments}
        )
        # Phase gate: tools outside the current phase are penalized, not run.
        if action.tool_name not in ALLOWED_TOOLS[self._state.phase]:
            verifier, reward = evaluate_action(
                self._state,
                action,
                anti_cheat_flags,
                invalid_action=True,
            )
            return self._finish_step(
                "Action is not allowed in the current phase.",
                reward,
                valid=False,
                error=f"{action.tool_name} is not allowed during {self._state.phase}",
                verifier=verifier,
            )
        try:
            result, verifier, reward, visible_tests = self._execute(action, anti_cheat_flags)
            return self._finish_step(
                result,
                reward,
                valid=True,
                verifier=verifier,
                visible_test_result=visible_tests,
            )
        except Exception as exc:  # keep malformed agent actions from crashing the server
            verifier, reward = evaluate_action(
                self._state,
                action,
                anti_cheat_flags,
                invalid_action=True,
            )
            return self._finish_step(
                "Tool execution failed.",
                reward,
                valid=False,
                error=str(exc),
                verifier=verifier,
            )

    @property
    def state(self) -> CyberSecurityOWASPState:
        """Current episode state (read-only accessor)."""
        return self._state

    def close(self) -> None:
        """Best-effort removal of the episode's on-disk workspace, if any."""
        workspace = self._state.hidden_facts.get("workspace")
        if workspace:
            shutil.rmtree(workspace, ignore_errors=True)

    def _execute(
        self, action: CyberSecurityOWASPAction, anti_cheat_flags: list[str]
    ) -> tuple[str, dict, dict[str, float], str | None]:
        """Dispatch a phase-valid tool and return its outcome.

        Returns ``(message, verifier, reward, visible_test_result)``.
        Raises ValueError for tool names missing from every branch (should
        be unreachable because step() gates on ALLOWED_TOOLS first).
        """
        # Read-only / workspace tools are delegated wholesale to ActionTools.
        if action.tool_name in {
            "noop",
            "inspect_policy_graph",
            "list_routes",
            "read_openapi",
            "read_file",
            "search_code",
            "send_local_request",
            "compare_identities",
            "patch_file",
        }:
            result = ActionTools(
                self._state,
                self._visible_policy_hint,
                self._workspace_summary,
            ).execute(action)
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            return result.message, verifier, reward, result.visible_test_result
        if action.tool_name == "submit_diagnosis":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.verification_summary = verifier
            self._state.diagnosis = dict(action.arguments or {})
            # Only a verifier-validated diagnosis unlocks the patch phase.
            if verifier.get("diagnosis", {}).get("valid"):
                self._state.diagnosis_submitted = True
                self._state.finding_submitted = True
                self._state.phase = "patch"
                return "Diagnosis recorded. Patch phase unlocked.", verifier, reward, None
            return "Diagnosis was not specific enough to unlock patching.", verifier, reward, None
        if action.tool_name == "run_visible_tests":
            self._state.visible_test_count += 1
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.verification_summary = verifier
            # Only the "visible" slice of the verifier is surfaced to the agent.
            visible_tests = json.dumps(verifier.get("visible", {}), indent=2, sort_keys=True)
            return visible_tests, verifier, reward, visible_tests
        if action.tool_name == "submit_fix":
            verifier, reward = evaluate_action(self._state, action, anti_cheat_flags)
            self._state.verification_summary = verifier
            self._state.patch_submitted = True
            # Success requires every hidden verifier section to pass.
            security = verifier.get("security", {}).get("passed", False)
            oracle = verifier.get("oracle_matrix", {}).get("passed", False)
            regression = verifier.get("regression", {}).get("passed", False)
            public = verifier.get("public_routes", {}).get("passed", False)
            quality = verifier.get("patch_quality", {}).get("passed", False)
            self._state.success = bool(security and oracle and regression and public and quality)
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = None if self._state.success else "hidden_verifier_failed"
            return json.dumps(verifier, indent=2, sort_keys=True), verifier, reward, None
        raise ValueError(f"Unhandled tool {action.tool_name}")

    def _finish_step(
        self,
        message: str,
        reward: dict[str, float],
        *,
        valid: bool,
        error: str | None = None,
        verifier: dict | None = None,
        visible_test_result: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Apply reward accounting and termination rules, then build the observation.

        Termination order matters: anti-cheat termination is checked first,
        so "anti_cheat_violation" takes precedence over "max_steps_exceeded"
        when both would apply on the same step.
        """
        self._state.last_reward = float(reward.get("total", 0.0))
        self._state.accumulated_reward += self._state.last_reward
        self._state.reward_history.append(reward)
        # Terminate on disqualifying flags reported by the reward engine.
        flags = list((verifier or {}).get("anti_cheat_flags", []) or [])
        if flags and should_terminate_for_flags(flags):
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = "anti_cheat_violation"
        if self._state.step_count >= self._state.max_steps and not self._state.done:
            self._state.done = True
            self._state.phase = "done"
            self._state.failure_reason = "max_steps_exceeded"
        obs = self._observation(
            message,
            reward=self._state.last_reward,
            valid=valid,
            error=error,
            reward_breakdown=reward,
            visible_test_result=visible_test_result,
            done_reason=self._state.failure_reason,
        )
        observation_record = obs.model_dump()
        self._state.observation_history.append(observation_record)
        if self._state.done:
            # Log artifacts once and cache the observation for replay in step().
            self._finalize_terminal_episode(observation_record)
            self._last_done_observation = obs
        return obs

    def _observation(
        self,
        message: str,
        *,
        reward: float,
        valid: bool = True,
        error: str | None = None,
        reward_breakdown: dict[str, float] | None = None,
        visible_test_result: str | None = None,
        done_reason: str | None = None,
    ) -> CyberSecurityOWASPObservation:
        """Assemble the observation the agent sees for the current state.

        ``message`` doubles as ``last_tool_result``; available actions are
        the sorted tool names allowed in the current phase.
        """
        available_actions = sorted(ALLOWED_TOOLS[self._state.phase])
        return CyberSecurityOWASPObservation(
            phase=self._state.phase,
            message=message,
            task_brief=self._task_brief,
            scenario_prompt=self._scenario_prompt(available_actions),
            visible_policy_hint=self._visible_policy_hint,
            workspace_summary=self._workspace_summary,
            available_actions=available_actions,
            last_tool_result=message,
            last_action_valid=valid,
            last_action_error=error,
            visible_test_result=visible_test_result,
            reward_breakdown=reward_breakdown or {},
            done_reason=done_reason,
            done=self._state.done,
            reward=reward,
            metadata={
                "episode_id": self._state.episode_id,
                "step_count": self._state.step_count,
                "seed": self._state.seed,
                "split": self._state.split,
                "difficulty": self._state.difficulty,
                "difficulty_tier": self._state.difficulty_tier,
                "domain": self._state.domain,
                "bug_family": self._state.bug_family,
                "template_id": self._state.template_id,
                "scenario_hash": self._state.scenario_hash,
            },
        )

    def _scenario_prompt(self, available_actions: list[str]) -> str:
        """Render the JSON scenario prompt embedded in every observation.

        Only the public slice of the policy hint is included; fixture
        aliases are reduced to sorted key names so fixture contents stay
        hidden from the agent.
        """
        users = self._visible_policy_hint.get("fixture_aliases", {}).get("users", {})
        resources = self._visible_policy_hint.get("fixture_aliases", {}).get("resources", {})
        visible_policy = {
            "domain": self._visible_policy_hint.get("domain", self._state.domain),
            "policy_rules": list(self._visible_policy_hint.get("policy_rules", [])),
            "public_routes": list(self._visible_policy_hint.get("public_routes", [])),
            "fixture_aliases": {
                "users": sorted(str(key) for key in users),
                "resources": sorted(str(key) for key in resources),
            },
        }
        prompt = {
            "environment": "CyberSecurity_OWASP",
            "task": self._task_brief,
            "scenario": {
                "task_id": self._state.task_id,
                "seed": self._state.seed,
                "split": self._state.split,
                "difficulty": self._state.difficulty,
                "difficulty_tier": self._state.difficulty_tier,
                "domain": self._state.domain,
                "bug_family": self._state.bug_family,
                "template_id": self._state.template_id,
                "scenario_hash": self._state.scenario_hash,
            },
            "visible_policy_hint": visible_policy,
            "workspace_summary": self._workspace_summary,
            "available_actions": available_actions,
            "instructions": [
                "Use only the local generated application and the listed tools.",
                "Discover the authorization defect with local evidence before patching.",
                "Preserve legitimate owner/admin flows and intentionally public routes.",
                "Submit exactly one secure, policy-aligned fix when ready.",
            ],
        }
        # sort_keys keeps the prompt deterministic across runs.
        return "CyberSecurity_OWASP scenario prompt:\n" + json.dumps(
            prompt,
            indent=2,
            sort_keys=True,
        )

    def _finalize_terminal_episode(self, observation_record: dict[str, Any]) -> None:
        """Record curriculum mastery and log artifacts for a finished episode.

        Idempotent: once an artifact path is set, repeated calls are no-ops,
        preventing double-logging if _finish_step runs again after done.
        """
        if self._state.episode_artifact_path:
            return
        mastery = self._curriculum.record_episode(self._state)
        self._state.curriculum_snapshot = {
            **self._state.curriculum_snapshot,
            "post_episode_mastery": mastery,
        }
        self._episode_logger.log_episode(
            self._state,
            final_observation=observation_record,
        )