# Source: Cyber_analyst-round1 / rewards.py (author: Humanlearning)
# Commit be8eade: enhance scenario authoring and caching mechanisms, update
# action submission terminology, and improve reward configuration for the
# CyberSecurity_OWASP environment.
"""Reward computation for CyberSecurity_OWASP."""
from __future__ import annotations
try:
from .models import CyberSecurityOWASPAction, CyberSecurityOWASPState
from .reward_config import RewardSettings, load_reward_settings
except ImportError: # pragma: no cover
from models import CyberSecurityOWASPAction, CyberSecurityOWASPState
from reward_config import RewardSettings, load_reward_settings
# Canonical keys of the reward breakdown dict: terminal components first,
# then dense-shaping terms, then the aggregated totals.
REWARD_KEYS = (
    "discovery",
    "security",
    "regression",
    "public_routes",
    "patch_quality",
    "visible_tests",
    "safety",
    "anti_cheat",
    "terminal_total",
    "progressive",
    "step_penalty",
    "speed_bonus",
    "token_penalty",
    "behavior_penalty",
    "train_total",
    "total",
)


def empty_reward() -> dict[str, float]:
    """Return a fresh reward breakdown with every component zeroed."""
    return dict.fromkeys(REWARD_KEYS, 0.0)
def compute_reward(
    state: CyberSecurityOWASPState,
    action: CyberSecurityOWASPAction,
    verifier_result: dict,
) -> dict[str, float]:
    """Compute the full reward breakdown for a single step.

    Combines tool-specific terminal components, always-on anti-cheat
    penalties, and dense-training shaping terms (progressive reward, step
    penalty, speed bonus, behavior penalty) into a dict keyed by
    REWARD_KEYS, then applies the terminal/train caps.

    Args:
        state: Mutable episode state; progress flags and metrics are
            updated in place by the helpers called here.
        action: The action just executed (dispatched on ``tool_name``).
        verifier_result: Verdict payload the verifier produced for this
            action.

    Returns:
        A dict with one float per key in REWARD_KEYS. ``total`` mirrors
        ``train_total`` when dense training is enabled, otherwise
        ``terminal_total``.
    """
    settings = load_reward_settings()
    reward = empty_reward()
    # --- tool-specific terminal components -------------------------------
    if action.tool_name == "submit_diagnosis":
        # "finding" is accepted as a fallback payload key -- presumably a
        # legacy verifier field name; TODO confirm against the verifier.
        diagnosis = verifier_result.get("diagnosis", verifier_result.get("finding", {}))
        reward["discovery"] = _diagnosis_score(diagnosis)
    elif action.tool_name == "run_visible_tests":
        visible = verifier_result.get("visible", {})
        reward["visible_tests"] = 1.0 if visible.get("passed") else 0.0
    elif action.tool_name == "submit_fix":
        _add_terminal_submit_fix_reward(state, verifier_result, reward, settings)
    # Anti-cheat penalties apply on every step, regardless of the tool used.
    _add_current_anti_cheat_penalties(verifier_result, reward, settings)
    if verifier_result.get("invalid_action"):
        reward["behavior_penalty"] += settings.value("invalid_action", -0.2)
    if verifier_result.get("repeated_action"):
        # Repeating an *invalid* action is penalised more heavily than
        # repeating a merely low-value one.
        reward["behavior_penalty"] += (
            settings.value("repeated_invalid_action", -0.3)
            if verifier_result.get("invalid_action")
            else settings.value("repeated_low_value_action", -0.1)
        )
    # --- dense-training shaping terms ------------------------------------
    reward["progressive"] = _compute_progressive_reward(
        state,
        action,
        verifier_result,
        settings,
    )
    reward["step_penalty"] = _compute_step_penalty(state, settings)
    reward["speed_bonus"] = _compute_speed_bonus(state, action, verifier_result, settings)
    reward["behavior_penalty"] += _compute_behavior_penalty(
        state,
        action,
        verifier_result,
        settings,
        reward["progressive"],
    )
    # Terminal total: the full component sum on submit_fix; on other steps
    # only the safety/anti-cheat components count toward the terminal signal.
    terminal_total = (
        _component_total(reward)
        if action.tool_name == "submit_fix"
        else reward["safety"] + reward["anti_cheat"]
    )
    reward["terminal_total"] = _cap_terminal(terminal_total, settings)
    # Train total layers shaping on top of the (already capped) terminal
    # signal. NOTE(review): "token_penalty" is never assigned anywhere in
    # this module, so it contributes 0.0 here -- confirm whether a caller
    # is expected to fill it in before this point.
    reward["train_total"] = _cap_train(
        reward["terminal_total"]
        + settings.shaping_weight * reward["progressive"]
        + reward["step_penalty"]
        + reward["speed_bonus"]
        + reward["token_penalty"]
        + reward["behavior_penalty"],
        settings,
        state,
    )
    reward["total"] = reward["train_total"] if settings.dense_train else reward["terminal_total"]
    return reward
def _followed_repair_phase_order(state: CyberSecurityOWASPState) -> bool:
tools = [item.get("tool_name") for item in state.action_history]
required = ["submit_diagnosis", "patch_file", "run_visible_tests", "submit_fix"]
cursor = 0
for tool in tools:
if cursor < len(required) and tool == required[cursor]:
cursor += 1
return cursor == len(required)
def should_terminate_for_flags(flags: list[str], settings: RewardSettings | None = None) -> bool:
    """Decide whether any raised anti-cheat flag should end the episode.

    Args:
        flags: Anti-cheat flag names reported by the verifier.
        settings: Reward settings; loaded lazily when not supplied.

    Returns:
        True when at least one recognised flag maps to a configuration
        entry marked as terminal.
    """
    active = settings or load_reward_settings()
    flag_to_config = {
        "external_network_attempt": "external_url_attempt",
        "blocked_file_probe": "hidden_file_probe",
        "host_path_escape": "host_path_escape",
        "test_file_edit_attempt": "test_file_edit",
    }
    for flag in flags:
        config_name = flag_to_config.get(flag)
        if config_name is not None and active.terminate(config_name):
            return True
    return False
def _diagnosis_score(diagnosis: dict) -> float:
return sum(
0.25
for key in (
"correct_route",
"correct_bug_class",
"correct_policy_rule",
"valid_local_evidence",
)
if diagnosis.get(key)
)
def _add_terminal_submit_fix_reward(
    state: CyberSecurityOWASPState,
    verifier_result: dict,
    reward: dict[str, float],
    settings: RewardSettings,
) -> None:
    """Fill in the terminal reward components for a ``submit_fix`` action.

    Mutates ``reward`` in place: pass/fail components for security,
    regression, public routes, patch quality and visible tests, plus
    safety/anti-cheat deductions for suspicious patch contents.
    """
    def verdict(name: str) -> dict:
        # Each verifier section is an optional dict keyed by "passed".
        return verifier_result.get(name, {})

    quality = verdict("patch_quality")
    # Security only counts as passed when the oracle matrix (when present)
    # agrees; a missing oracle section defaults to passing.
    security_ok = bool(
        verdict("security").get("passed") and verdict("oracle_matrix").get("passed", True)
    )
    reward["security"] = 5.0 if security_ok else -6.0
    reward["regression"] = 3.0 if verdict("regression").get("passed") else -5.0
    reward["public_routes"] = 1.0 if verdict("public_routes").get("passed") else -3.0
    reward["patch_quality"] = 2.0 if quality.get("passed") else 0.0
    # Small bonus for following the canonical diagnose->patch->test->submit
    # flow on a quality patch.
    if quality.get("passed") and _followed_repair_phase_order(state):
        reward["patch_quality"] += 0.5
    reward["visible_tests"] = 1.0 if verdict("visible").get("passed") else -6.0
    if quality.get("hardcoded"):
        reward["anti_cheat"] += settings.value("hardcoded_identifier", -2.0)
    if quality.get("deny_all"):
        reward["safety"] += settings.value("deny_all_patch", -3.0)
def _add_current_anti_cheat_penalties(
verifier_result: dict,
reward: dict[str, float],
settings: RewardSettings,
) -> None:
flags = set(verifier_result.get("anti_cheat_flags", []) or [])
if "external_network_attempt" in flags:
reward["safety"] += settings.value("external_url_attempt", -5.0)
if "blocked_file_probe" in flags:
reward["anti_cheat"] += settings.value("hidden_file_probe", -5.0)
if "host_path_escape" in flags:
reward["anti_cheat"] += settings.value("host_path_escape", -4.0)
if "test_file_edit_attempt" in flags:
reward["anti_cheat"] += settings.value("test_file_edit", -5.0)
if "hardcoded_hidden_identifier" in flags:
reward["anti_cheat"] += settings.value("hardcoded_identifier", -2.0)
def _compute_progressive_reward(
    state: CyberSecurityOWASPState,
    action: CyberSecurityOWASPAction,
    verifier_result: dict,
    settings: RewardSettings,
) -> float:
    """Return the dense shaping reward earned by this step.

    Each milestone pays out at most once per episode (tracked via
    ``state.progress_flags``) and the episode-wide total is capped inside
    ``_award_progress_once``. Returns 0.0 when dense training is disabled.
    """
    if not settings.dense_train:
        return 0.0
    delta = 0.0
    # Milestone: inspected the policy graph.
    if action.tool_name == "inspect_policy_graph":
        delta += _award_progress_once(state, "policy_seen", "policy_inspected", settings)
    # Milestone: mapped the route surface.
    if action.tool_name in {"list_routes", "read_openapi"}:
        delta += _award_progress_once(state, "route_map_seen", "route_map_inspected", settings)
    # Milestone: looked at code plausibly related to the vulnerability.
    if action.tool_name in {"read_file", "search_code"} and _is_relevant_code_action(action):
        delta += _award_progress_once(
            state,
            "relevant_file_seen",
            "relevant_file_inspected",
            settings,
        )
    # Milestone: reproduced an unauthorized success locally (evidence).
    if action.tool_name in {"send_local_request", "compare_identities"} and any(
        trace.get("unauthorized_success") for trace in state.request_trace
    ):
        delta += _award_progress_once(
            state,
            "local_evidence_found",
            "local_evidence_found",
            settings,
        )
    # Milestone: fully correct diagnosis (all four facets truthy).
    if action.tool_name == "submit_diagnosis":
        # "finding" is accepted as a fallback payload key -- presumably a
        # legacy verifier field name; TODO confirm against the verifier.
        diagnosis = verifier_result.get("diagnosis", verifier_result.get("finding", {}))
        if all(
            diagnosis.get(key)
            for key in (
                "correct_route",
                "correct_bug_class",
                "correct_policy_rule",
                "valid_local_evidence",
            )
        ):
            delta += _award_progress_once(
                state,
                "diagnosis_correct",
                "diagnosis_correct",
                settings,
            )
    # Milestone: a patch applied cleanly (not rejected as invalid).
    if action.tool_name == "patch_file" and not verifier_result.get("invalid_action"):
        delta += _award_progress_once(state, "patch_applies", "patch_applies", settings)
    # Milestones tied to running the visible test suite.
    if action.tool_name == "run_visible_tests":
        visible = verifier_result.get("visible", {})
        checks = visible.get("checks", {}) if isinstance(visible, dict) else {}
        if visible.get("passed"):
            delta += _award_progress_once(
                state,
                "app_boots",
                "app_boots_after_patch",
                settings,
            )
            delta += _award_progress_once(
                state,
                "visible_tests_improved",
                "visible_tests_improved",
                settings,
            )
        if checks.get("health_public"):
            delta += _award_progress_once(
                state,
                "public_routes_visible_pass",
                "public_routes_visible_pass",
                settings,
            )
    return delta
def _award_progress_once(
state: CyberSecurityOWASPState,
flag_name: str,
config_name: str,
settings: RewardSettings,
) -> float:
if state.progress_flags.get(flag_name):
return 0.0
cap = settings.value("progressive_cap", 5.0)
remaining = max(0.0, cap - float(state.progress_reward_total or 0.0))
if remaining <= 0.0:
return 0.0
state.progress_flags[flag_name] = True
value = min(settings.value(config_name, 0.0), remaining)
state.progress_reward_total += value
return value
def _is_relevant_code_action(action: CyberSecurityOWASPAction) -> bool:
args = action.arguments or {}
text = f"{args.get('path', '')} {args.get('query', '')}".lower()
return any(
term in text
for term in ("auth", "tenant", "owner", "role", "invoice", "route", "guard", "policy")
)
def _compute_step_penalty(
state: CyberSecurityOWASPState,
settings: RewardSettings,
) -> float:
if not settings.dense_train:
return 0.0
rate = settings.value("step_penalty", 0.0)
if rate >= 0.0:
return 0.0
current = float(state.metrics.get("step_penalty_total", 0.0))
cap = settings.cap("step_penalty", -0.6)
delta = max(rate, float(cap) - current) if cap is not None else rate
state.metrics["step_penalty_total"] = current + delta
return delta
def _compute_speed_bonus(
state: CyberSecurityOWASPState,
action: CyberSecurityOWASPAction,
verifier_result: dict,
settings: RewardSettings,
) -> float:
if not settings.dense_train or action.tool_name != "submit_fix":
return 0.0
success = all(
bool((verifier_result.get(key) or {}).get("passed", False))
for key in ("security", "oracle_matrix", "regression", "public_routes", "patch_quality")
)
if not success:
return 0.0
max_steps = max(1, int(state.max_steps or 1))
bonus = settings.value("speed_bonus", 1.0) * (1.0 - min(state.step_count, max_steps) / max_steps)
return max(0.0, bonus)
def _compute_behavior_penalty(
    state: CyberSecurityOWASPState,
    action: CyberSecurityOWASPAction,
    verifier_result: dict,
    settings: RewardSettings,
    progressive_delta: float,
) -> float:
    """Return the dense-training behavior penalty for this step.

    Penalises wasteful or out-of-order behavior: no-ops, repeated file
    reads and local requests, patching before inspecting the policy,
    premature or untested submissions, oversized patches, and exploration
    actions that produced no new progress. Returns 0.0 when dense training
    is disabled.

    Args:
        progressive_delta: Progressive reward earned by this same step;
            a non-positive value marks the step as "no progress".
    """
    if not settings.dense_train:
        return 0.0
    penalty = 0.0
    tools = [item.get("tool_name") for item in state.action_history]
    if action.tool_name == "noop":
        penalty += settings.value("noop_action", -0.02)
    if action.tool_name == "read_file":
        # Count reads of the same path in the history.
        # NOTE(review): the > 1 threshold suggests the current action is
        # already appended to state.action_history -- confirm with the
        # environment's step loop.
        path = str((action.arguments or {}).get("path", ""))
        reads = [
            item
            for item in state.action_history
            if item.get("tool_name") == "read_file"
            and str((item.get("arguments") or {}).get("path", "")) == path
        ]
        if len(reads) > 1:
            penalty += settings.value("repeated_file_read", -0.05)
    if action.tool_name == "send_local_request":
        # A request is "repeated" when (method, path, user_id) all match.
        args = action.arguments or {}
        current = (
            str(args.get("method", "GET")).upper(),
            str(args.get("path", "")),
            str(args.get("user_id", "")),
        )
        matches = [
            item
            for item in state.action_history
            if item.get("tool_name") == "send_local_request"
            and (
                str((item.get("arguments") or {}).get("method", "GET")).upper(),
                str((item.get("arguments") or {}).get("path", "")),
                str((item.get("arguments") or {}).get("user_id", "")),
            )
            == current
        ]
        if len(matches) > 1:
            penalty += settings.value("repeated_local_request", -0.05)
    if action.tool_name == "run_visible_tests" and state.visible_test_count > 1:
        penalty += settings.value("repeated_visible_tests", -0.1)
    # Patching before ever inspecting the policy graph is discouraged.
    if action.tool_name == "patch_file" and not state.progress_flags.get("policy_seen"):
        penalty += settings.value("patch_before_policy", -0.3)
    if action.tool_name == "submit_fix":
        if "patch_file" not in tools:
            penalty += settings.value("submit_without_patch", -0.5)
        if state.patch_attempt_count > 0 and state.visible_test_count == 0:
            penalty += settings.value("submit_without_visible_tests", -0.3)
    if action.tool_name == "patch_file" and state.patch_attempt_count > 3:
        penalty += settings.value("excessive_patch_attempt", -0.2)
    files_touched = state.metrics.get("files_touched", [])
    if isinstance(files_touched, list) and len(files_touched) > 5:
        penalty += settings.value("too_many_files_changed", -0.5)
    if action.tool_name == "patch_file":
        penalty += _oversized_patch_penalty(state, settings)
    # Exploration-style actions that earned no progressive reward this step.
    if (
        progressive_delta <= 0.0
        and not verifier_result.get("invalid_action")
        and action.tool_name
        in {
            "inspect_policy_graph",
            "list_routes",
            "read_openapi",
            "noop",
            "run_visible_tests",
            "send_local_request",
            "compare_identities",
        }
    ):
        penalty += settings.value("no_progress_action", -0.05)
    return penalty
def _oversized_patch_penalty(
state: CyberSecurityOWASPState,
settings: RewardSettings,
) -> float:
diff_lines = [
line
for line in str(state.patch_diff or "").splitlines()
if (line.startswith("+") or line.startswith("-"))
and not line.startswith("+++")
and not line.startswith("---")
]
entry = settings.entry("oversized_patch")
threshold = int(entry.get("threshold_lines", 80))
severe_threshold = int(entry.get("severe_threshold_lines", 180))
if len(diff_lines) >= severe_threshold:
return float(entry.get("severe_value", -1.0))
if len(diff_lines) >= threshold:
return settings.value("oversized_patch", -0.25)
return 0.0
def _component_total(reward: dict[str, float]) -> float:
excluded = {
"total",
"terminal_total",
"progressive",
"step_penalty",
"speed_bonus",
"token_penalty",
"behavior_penalty",
"train_total",
}
return sum(value for key, value in reward.items() if key not in excluded)
def _cap_terminal(total: float, settings: RewardSettings) -> float:
cap = settings.value("terminal_cap", 15.0)
return min(cap, total) if total > 0 else total
def _cap_train(
total: float,
settings: RewardSettings,
state: CyberSecurityOWASPState,
) -> float:
floor = settings.value("penalty_floor", -6.0)
capped = max(floor, total)
cap = settings.value("train_cap", 21.0)
if capped > 0.0:
remaining = max(0.0, cap - float(state.accumulated_reward or 0.0))
return min(capped, remaining)
return capped