File size: 5,068 Bytes
3807ea3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3ee507
 
 
 
 
 
 
 
 
3807ea3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
be8eade
6abc8c5
 
 
 
 
3807ea3
 
 
 
be8eade
 
 
 
 
 
 
 
 
 
 
 
 
 
3807ea3
 
 
 
 
 
 
6abc8c5
 
 
 
 
 
 
 
 
 
3807ea3
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
"""Minimal rollout loop for CyberSecurity_OWASP episodes."""

from __future__ import annotations

import json
from typing import Any

from CyberSecurity_OWASP import CyberSecurityOWASPAction


def build_cybersecurity_owasp_prompt(observation, action_trace, observation_trace) -> str:
    """Render the single-turn instruction prompt for the OWASP repair agent.

    Only *observation* is consulted; *action_trace* and *observation_trace*
    are accepted for interface compatibility but are not used here.
    """
    sections = [
        "You are a defensive AppSec repair agent. Output exactly one JSON action.",
        f"Phase: {observation.phase}",
        f"Task: {observation.task_brief}",
        f"Available actions: {observation.available_actions}",
        f"Last result: {observation.last_tool_result}",
        'Example: {"tool_name":"read_file","arguments":{"path":"app/routes/invoices.py"}}',
    ]
    return "\n".join(sections)


def parse_action_json(text: str) -> CyberSecurityOWASPAction:
    """Decode *text* as JSON and build the corresponding action model.

    Malformed JSON raises ``json.JSONDecodeError``; field validation is
    delegated entirely to the ``CyberSecurityOWASPAction`` constructor.
    """
    payload = json.loads(text)
    return CyberSecurityOWASPAction(**payload)


def generate_rollout_completions(trainer, prompts: list[str]) -> list[dict[str, Any]]:
    """Delegate generation to *trainer* when it exposes the rollout API.

    When the trainer lacks ``generate_rollout_completions``, return one inert
    no-op completion per prompt so callers can still run end to end
    (e.g. in smoke tests without a generation-capable trainer).
    """
    _missing = object()
    delegate = getattr(trainer, "generate_rollout_completions", _missing)
    if delegate is not _missing:
        return delegate(prompts)

    def _noop_completion() -> dict[str, Any]:
        # Fresh dict and fresh lists per prompt so callers may mutate entries
        # independently.
        return {
            "text": '{"tool_name":"noop","arguments":{}}',
            "prompt_ids": [],
            "completion_ids": [],
            "logprobs": [],
        }

    return [_noop_completion() for _ in prompts]


def rollout_once(
    trainer,
    env,
    tokenizer=None,
    dataset_prompt: str = "",
    max_steps: int = 40,
    reset_kwargs: dict[str, Any] | None = None,
) -> dict:
    """Run a single episode against *env*, choosing actions via *trainer*.

    Each step builds a prompt from the current observation, asks the trainer
    for one completion, parses it as a JSON action, and steps the environment.
    Token ids / logprobs from every step are concatenated into flat lists.

    Note: *tokenizer* and *dataset_prompt* are accepted but not used in this
    function body — presumably kept for a shared rollout-function signature;
    TODO confirm against callers.

    Returns a flat dict with concatenated token data, per-component reward
    scalars pulled from the env's final state / reward breakdown, success and
    safety flags, and the full action/observation traces.
    """
    result = env.reset(**(reset_kwargs or {}))
    # env.reset/step may return either a wrapper with an .observation field or
    # the observation itself; handle both shapes.
    observation = result.observation if hasattr(result, "observation") else result

    prompt_ids = []
    completion_ids = []
    logprobs = []
    reward_trace = []
    action_trace = []
    observation_trace = []

    for _ in range(max_steps):
        # Stop as soon as the env marks the episode finished.
        if getattr(observation, "done", False):
            break
        prompt = build_cybersecurity_owasp_prompt(observation, action_trace, observation_trace)
        rollout_output = generate_rollout_completions(trainer, [prompt])[0]
        # NOTE(review): a malformed completion raises here (json/validation
        # error) and aborts the rollout — confirm that is the intended policy.
        action = parse_action_json(rollout_output["text"])
        result = env.step(action)
        observation = result.observation if hasattr(result, "observation") else result

        # Flatten per-step token data into episode-level sequences.
        prompt_ids.extend(rollout_output["prompt_ids"])
        completion_ids.extend(rollout_output["completion_ids"])
        logprobs.extend(rollout_output["logprobs"])
        # Treat a missing/None reward as 0.0.
        reward_trace.append(float(getattr(observation, "reward", 0.0) or 0.0))
        action_trace.append(action.model_dump())
        observation_trace.append(observation.model_dump())

    # env.state may be an attribute or a zero-arg method; support both.
    state = env.state if not callable(getattr(env, "state", None)) else env.state()
    final_breakdown = getattr(observation, "reward_breakdown", {}) or {}
    verifier = getattr(state, "verification_summary", {}) or {}
    anti_cheat_flags = getattr(state, "anti_cheat_flags", []) or []
    # Observations are model_dump() dicts; only an explicit False counts as
    # invalid (missing key means "unknown", not invalid).
    invalid_actions = [
        obs for obs in observation_trace if obs.get("last_action_valid") is False
    ]
    return {
        "prompt_ids": prompt_ids,
        "completion_ids": completion_ids,
        "logprobs": logprobs,
        # Prefer the env's accumulated total; fall back to our per-step sum.
        "reward_total": float(getattr(state, "accumulated_reward", sum(reward_trace))),
        "reward_terminal_15": float(final_breakdown.get("terminal_total", 0.0)),
        "reward_progressive_5": float(
            getattr(state, "progress_reward_total", final_breakdown.get("progressive", 0.0))
        ),
        # Per-step reward components summed over the env's reward history;
        # entries may be None, hence the (item or {}) guard.
        "reward_step_penalty": float(
            sum((item or {}).get("step_penalty", 0.0) for item in getattr(state, "reward_history", []))
        ),
        "reward_speed_bonus": float(
            sum((item or {}).get("speed_bonus", 0.0) for item in getattr(state, "reward_history", []))
        ),
        "reward_behavior_penalty": float(
            sum((item or {}).get("behavior_penalty", 0.0) for item in getattr(state, "reward_history", []))
        ),
        "reward_discovery": float(final_breakdown.get("discovery", 0.0)),
        "reward_security": float(final_breakdown.get("security", 0.0)),
        "reward_regression": float(final_breakdown.get("regression", 0.0)),
        "reward_patch_quality": float(final_breakdown.get("patch_quality", 0.0)),
        "reward_anti_cheat": float(final_breakdown.get("anti_cheat", 0.0)),
        "success": bool(getattr(state, "success", False)),
        "episode_length": len(action_trace),
        # Verifier sub-results: each entry is a dict with a "passed" flag.
        "exploit_blocked": bool((verifier.get("security") or {}).get("passed", False)),
        "regression_preserved": bool((verifier.get("regression") or {}).get("passed", False)),
        "public_routes_preserved": bool((verifier.get("public_routes") or {}).get("passed", False)),
        "anti_cheat_pass": not bool(anti_cheat_flags),
        # max(1, ...) avoids division by zero on an empty episode.
        "invalid_action_rate": len(invalid_actions) / max(1, len(action_trace)),
        "timeout": getattr(state, "failure_reason", None) == "max_steps_exceeded",
        # Flags are assumed to be strings; substring match selects
        # network/unsafe-related violations — TODO confirm flag vocabulary.
        "safety_violation": bool(
            any("network" in flag or "unsafe" in flag for flag in anti_cheat_flags)
        ),
        "episode_artifact_path": getattr(state, "episode_artifact_path", None),
        "actions": action_trace,
        "observations": observation_trace,
    }