Spaces:
Sleeping
Sleeping
File size: 5,068 Bytes
"""Minimal rollout loop for CyberSecurity_OWASP episodes."""
from __future__ import annotations
import json
from typing import Any
from CyberSecurity_OWASP import CyberSecurityOWASPAction
def build_cybersecurity_owasp_prompt(observation, action_trace, observation_trace) -> str:
    """Render the text prompt for the next agent step from ``observation``.

    The prompt embeds the observation's ``phase``, ``task_brief``,
    ``available_actions`` and ``last_tool_result`` attributes, preceded by a
    fixed instruction line and followed by a fixed example action.

    ``action_trace`` and ``observation_trace`` are accepted but are not
    included in the prompt.
    """
    instruction = "You are a defensive AppSec repair agent. Output exactly one JSON action.\n"
    phase_line = f"Phase: {observation.phase}\n"
    task_line = f"Task: {observation.task_brief}\n"
    actions_line = f"Available actions: {observation.available_actions}\n"
    result_line = f"Last result: {observation.last_tool_result}\n"
    example = 'Example: {"tool_name":"read_file","arguments":{"path":"app/routes/invoices.py"}}'
    return "".join(
        (instruction, phase_line, task_line, actions_line, result_line, example)
    )
def parse_action_json(text: str) -> CyberSecurityOWASPAction:
    """Decode ``text`` as JSON and build a ``CyberSecurityOWASPAction`` from it.

    The decoded value is expanded as keyword arguments to the action
    constructor. Errors from ``json.loads`` or from the constructor are not
    caught here and propagate to the caller.
    """
    return CyberSecurityOWASPAction(**json.loads(text))
def generate_rollout_completions(trainer, prompts: list[str]) -> list[dict[str, Any]]:
    """Produce one completion record per prompt.

    When ``trainer`` exposes a ``generate_rollout_completions`` attribute, the
    call is delegated to it and its return value is passed through unchanged.
    Otherwise a placeholder record is returned for every prompt: a ``noop``
    JSON action as ``text`` with empty ``prompt_ids``, ``completion_ids`` and
    ``logprobs`` lists.
    """
    absent = object()
    delegate = getattr(trainer, "generate_rollout_completions", absent)
    if delegate is not absent:
        return delegate(prompts)

    placeholders = []
    for _ in prompts:
        # Build a fresh record each time so the lists are never shared.
        record = {}
        record["text"] = '{"tool_name":"noop","arguments":{}}'
        record["prompt_ids"] = []
        record["completion_ids"] = []
        record["logprobs"] = []
        placeholders.append(record)
    return placeholders
def _unwrap_observation(result):
    """Return ``result.observation`` when that attribute exists, else ``result``."""
    return result.observation if hasattr(result, "observation") else result


def _float_or_zero(value) -> float:
    """Convert ``value`` to ``float``, treating ``None`` as ``0.0``."""
    return 0.0 if value is None else float(value)


def _sum_reward_component(reward_history, key: str) -> float:
    """Sum ``key`` across the entries of ``reward_history``.

    Falsy entries and missing or ``None`` values contribute ``0.0``.
    """
    return float(sum((item or {}).get(key) or 0.0 for item in reward_history))


def rollout_once(
    trainer,
    env,
    tokenizer=None,
    dataset_prompt: str = "",
    max_steps: int = 40,
    reset_kwargs: dict[str, Any] | None = None,
) -> dict:
    """Run a single episode against ``env`` and summarise it as a flat dict.

    The environment is reset, then stepped until the current observation
    reports ``done`` or ``max_steps`` actions have been taken. Each step builds
    a prompt from the current observation, asks ``trainer`` for a completion,
    parses the completion text into an action and applies it with
    ``env.step``.

    Args:
        trainer: Object passed to ``generate_rollout_completions``.
        env: Environment providing ``reset(**kwargs)``, ``step(action)`` and a
            ``state`` attribute or zero-argument method.
        tokenizer: Accepted but not used by this function.
        dataset_prompt: Accepted but not used by this function.
        max_steps: Upper bound on the number of actions taken.
        reset_kwargs: Keyword arguments forwarded to ``env.reset``; ``None``
            means no arguments.

    Returns:
        A dict holding the concatenated ``prompt_ids``, ``completion_ids`` and
        ``logprobs`` of every step, the ``reward_*`` values read from the final
        observation's ``reward_breakdown`` and from the environment state,
        boolean outcome flags, and the dumped ``actions`` / ``observations``
        traces.

    Errors raised while parsing a completion or stepping the environment are
    not caught and propagate to the caller.
    """
    observation = _unwrap_observation(env.reset(**(reset_kwargs or {})))

    prompt_ids = []
    completion_ids = []
    logprobs = []
    reward_trace = []
    action_trace = []
    observation_trace = []

    for _ in range(max_steps):
        if getattr(observation, "done", False):
            break
        prompt = build_cybersecurity_owasp_prompt(observation, action_trace, observation_trace)
        rollout_output = generate_rollout_completions(trainer, [prompt])[0]
        action = parse_action_json(rollout_output["text"])
        observation = _unwrap_observation(env.step(action))

        prompt_ids.extend(rollout_output["prompt_ids"])
        completion_ids.extend(rollout_output["completion_ids"])
        logprobs.extend(rollout_output["logprobs"])
        reward_trace.append(float(getattr(observation, "reward", 0.0) or 0.0))
        action_trace.append(action.model_dump())
        observation_trace.append(observation.model_dump())

    # ``state`` may be exposed either as a plain attribute or as a method.
    state = env.state
    if callable(state):
        state = state()

    # Every container read from the observation/state falls back to an empty
    # value when the attribute is missing or None, so the summary never fails
    # on a partially populated state.
    final_breakdown = getattr(observation, "reward_breakdown", {}) or {}
    verifier = getattr(state, "verification_summary", {}) or {}
    anti_cheat_flags = getattr(state, "anti_cheat_flags", []) or []
    reward_history = getattr(state, "reward_history", None) or []

    invalid_action_count = sum(
        1 for obs in observation_trace if obs.get("last_action_valid") is False
    )

    # Prefer totals tracked by the state; fall back to locally observed values.
    reward_total = getattr(state, "accumulated_reward", None)
    if reward_total is None:
        reward_total = sum(reward_trace)
    reward_progressive = getattr(state, "progress_reward_total", None)
    if reward_progressive is None:
        reward_progressive = final_breakdown.get("progressive")

    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        "reward_total": float(reward_total),
        "reward_terminal_15": _float_or_zero(final_breakdown.get("terminal_total")),
        "reward_progressive_5": _float_or_zero(reward_progressive),
        "reward_step_penalty": _sum_reward_component(reward_history, "step_penalty"),
        "reward_speed_bonus": _sum_reward_component(reward_history, "speed_bonus"),
        "reward_behavior_penalty": _sum_reward_component(reward_history, "behavior_penalty"),
        "reward_discovery": _float_or_zero(final_breakdown.get("discovery")),
        "reward_security": _float_or_zero(final_breakdown.get("security")),
        "reward_regression": _float_or_zero(final_breakdown.get("regression")),
        "reward_patch_quality": _float_or_zero(final_breakdown.get("patch_quality")),
        "reward_anti_cheat": _float_or_zero(final_breakdown.get("anti_cheat")),
        "success": bool(getattr(state, "success", False)),
        "episode_length": len(action_trace),
        "exploit_blocked": bool((verifier.get("security") or {}).get("passed", False)),
        "regression_preserved": bool((verifier.get("regression") or {}).get("passed", False)),
        "public_routes_preserved": bool((verifier.get("public_routes") or {}).get("passed", False)),
        "anti_cheat_pass": not bool(anti_cheat_flags),
        # max(1, ...) avoids dividing by zero when no action was taken.
        "invalid_action_rate": invalid_action_count / max(1, len(action_trace)),
        "timeout": getattr(state, "failure_reason", None) == "max_steps_exceeded",
        "safety_violation": bool(
            any("network" in flag or "unsafe" in flag for flag in anti_cheat_flags)
        ),
        "episode_artifact_path": getattr(state, "episode_artifact_path", None),
        "actions": action_trace,
        "observations": observation_trace,
    }
|