"""Minimal rollout loop for CyberSecurity_OWASP episodes."""

from __future__ import annotations

import json
from typing import Any

from CyberSecurity_OWASP import CyberSecurityOWASPAction


def build_cybersecurity_owasp_prompt(observation, action_trace, observation_trace) -> str:
    """Render a single-action prompt from the current observation.

    ``action_trace`` and ``observation_trace`` are accepted so callers can
    thread episode history into the prompt; this minimal version uses only
    the current observation.
    """
    return (
        "You are a defensive AppSec repair agent. Output exactly one JSON action.\n"
        f"Phase: {observation.phase}\n"
        f"Task: {observation.task_brief}\n"
        f"Available actions: {observation.available_actions}\n"
        f"Last result: {observation.last_tool_result}\n"
        'Example: {"tool_name":"read_file","arguments":{"path":"app/routes/invoices.py"}}'
    )


def parse_action_json(text: str) -> CyberSecurityOWASPAction:
    """Parse a model completion into a validated action (raises on malformed output)."""
    data = json.loads(text)
    return CyberSecurityOWASPAction(**data)
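

# Optional hardening sketch, not part of the environment API: raw completions
# sometimes wrap the JSON in prose or code fences, so this variant falls back
# to the first {...} span and finally to a "noop" action (mirroring the
# trainer-less fallback below). The helper name is an assumption.
def parse_action_json_lenient(text: str) -> CyberSecurityOWASPAction:
    import re

    try:
        return CyberSecurityOWASPAction(**json.loads(text))
    except (ValueError, TypeError):  # JSONDecodeError and pydantic ValidationError are ValueErrors
        match = re.search(r"\{.*\}", text, flags=re.DOTALL)
        if match is not None:
            try:
                return CyberSecurityOWASPAction(**json.loads(match.group(0)))
            except (ValueError, TypeError):
                pass
        return CyberSecurityOWASPAction(tool_name="noop", arguments={})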


def generate_rollout_completions(trainer, prompts: list[str]) -> list[dict[str, Any]]:
    """Delegate generation to the trainer when it supports rollouts.

    Falls back to one no-op completion per prompt so the loop stays runnable
    without a trainer backend (e.g. in smoke tests).
    """
    if hasattr(trainer, "generate_rollout_completions"):
        return trainer.generate_rollout_completions(prompts)
    return [
        {
            "text": '{"tool_name":"noop","arguments":{}}',
            "prompt_ids": [],
            "completion_ids": [],
            "logprobs": [],
        }
        for _ in prompts
    ]


def rollout_once(
    trainer,
    env,
    tokenizer=None,
    dataset_prompt: str = "",
    max_steps: int = 40,
    reset_kwargs: dict[str, Any] | None = None,
) -> dict:
    """Run one episode (at most ``max_steps`` actions) and return token traces
    plus reward and verifier diagnostics. ``tokenizer`` and ``dataset_prompt``
    are accepted for trainer-API compatibility but unused here."""
    result = env.reset(**(reset_kwargs or {}))
    observation = result.observation if hasattr(result, "observation") else result
    prompt_ids: list[int] = []
    completion_ids: list[int] = []
    logprobs: list[float] = []
    reward_trace: list[float] = []
    action_trace: list[dict[str, Any]] = []
    observation_trace: list[dict[str, Any]] = []
    for _ in range(max_steps):
        if getattr(observation, "done", False):
            break
        prompt = build_cybersecurity_owasp_prompt(observation, action_trace, observation_trace)
        rollout_output = generate_rollout_completions(trainer, [prompt])[0]
        action = parse_action_json(rollout_output["text"])
        result = env.step(action)
        observation = result.observation if hasattr(result, "observation") else result
        prompt_ids.extend(rollout_output["prompt_ids"])
        completion_ids.extend(rollout_output["completion_ids"])
        logprobs.extend(rollout_output["logprobs"])
        # Guard against a missing or None reward on intermediate observations.
        reward_trace.append(float(getattr(observation, "reward", 0.0) or 0.0))
        action_trace.append(action.model_dump())
        observation_trace.append(observation.model_dump())
    # ``env.state`` may be a plain attribute or a zero-argument accessor,
    # depending on the environment version.
    state = env.state if not callable(getattr(env, "state", None)) else env.state()
    final_breakdown = getattr(observation, "reward_breakdown", {}) or {}
    verifier = getattr(state, "verification_summary", {}) or {}
    anti_cheat_flags = getattr(state, "anti_cheat_flags", []) or []
    invalid_actions = [
        obs for obs in observation_trace if obs.get("last_action_valid") is False
    ]
    return {
        # Token-level traces for the policy optimizer.
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        # Reward total and per-component breakdown.
        "reward_total": float(getattr(state, "accumulated_reward", sum(reward_trace))),
        "reward_terminal_15": float(final_breakdown.get("terminal_total", 0.0)),
        "reward_progressive_5": float(
            getattr(state, "progress_reward_total", final_breakdown.get("progressive", 0.0))
        ),
        "reward_step_penalty": float(
            sum((item or {}).get("step_penalty", 0.0) for item in getattr(state, "reward_history", []))
        ),
        "reward_speed_bonus": float(
            sum((item or {}).get("speed_bonus", 0.0) for item in getattr(state, "reward_history", []))
        ),
        "reward_behavior_penalty": float(
            sum((item or {}).get("behavior_penalty", 0.0) for item in getattr(state, "reward_history", []))
        ),
        "reward_discovery": float(final_breakdown.get("discovery", 0.0)),
        "reward_security": float(final_breakdown.get("security", 0.0)),
        "reward_regression": float(final_breakdown.get("regression", 0.0)),
        "reward_patch_quality": float(final_breakdown.get("patch_quality", 0.0)),
        "reward_anti_cheat": float(final_breakdown.get("anti_cheat", 0.0)),
        # Episode outcome and safety diagnostics.
        "success": bool(getattr(state, "success", False)),
        "episode_length": len(action_trace),
        "exploit_blocked": bool((verifier.get("security") or {}).get("passed", False)),
        "regression_preserved": bool((verifier.get("regression") or {}).get("passed", False)),
        "public_routes_preserved": bool((verifier.get("public_routes") or {}).get("passed", False)),
        "anti_cheat_pass": not anti_cheat_flags,
        "invalid_action_rate": len(invalid_actions) / max(1, len(action_trace)),
        "timeout": getattr(state, "failure_reason", None) == "max_steps_exceeded",
        "safety_violation": any(
            "network" in flag or "unsafe" in flag for flag in anti_cheat_flags
        ),
        "episode_artifact_path": getattr(state, "episode_artifact_path", None),
        "actions": action_trace,
        "observations": observation_trace,
    }
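

# Minimal smoke-test sketch, assuming only that ``CyberSecurity_OWASP`` is
# importable. The stub classes below are illustrative assumptions, not the
# real environment: they mimic just the attributes ``rollout_once`` reads
# (``observation``/``done``/``reward``/``model_dump`` and ``env.state``).
if __name__ == "__main__":

    class _StubObservation:
        def __init__(self, done: bool) -> None:
            self.phase = "triage"
            self.task_brief = "patch the vulnerable route"
            self.available_actions = ["read_file", "noop"]
            self.last_tool_result = None
            self.reward = 0.0
            self.done = done

        def model_dump(self) -> dict[str, Any]:
            return {"phase": self.phase, "last_action_valid": True}

    class _StubEnv:
        state = None  # rollout_once tolerates a bare attribute here

        def reset(self, **kwargs):
            return _StubObservation(done=False)

        def step(self, action):
            return _StubObservation(done=True)  # terminate after one action

    # trainer=None exercises the no-op fallback in generate_rollout_completions.
    summary = rollout_once(trainer=None, env=_StubEnv())
    scalars = {k: v for k, v in summary.items() if not isinstance(v, list)}
    print(json.dumps(scalars, indent=2))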