Commit f7b8ac6 · Parent: 5809a6c
feat: introduce reward ablation configurations, implement YAML loading with extends support, and add reward-variant tracking in training scripts
Files changed:
- reward_config.py +40 -1
- scripts/generate_sft_dataset.py +753 -0
- scripts/launch_reward_ablations.ps1 +59 -0
- scripts/modal_train_grpo.py +115 -21
- scripts/modal_train_sft.py +442 -0
- tests/test_reward_config.py +39 -0
- tests/test_sft_dataset_generation.py +142 -0
- tests/test_trackio_utils.py +6 -0
- training/configs/reward_ablations/A0_sparse_terminal_only.yaml +97 -0
- training/configs/reward_ablations/A2_reduced_shaping.yaml +12 -0
- training/configs/reward_ablations/A3_no_speed_token.yaml +17 -0
- training/configs/reward_ablations/A6_visible_gate.yaml +10 -0
- training/configs/reward_ablations/A7_evidence045.yaml +6 -0
- training/trackio_utils.py +53 -0
reward_config.py
CHANGED
@@ -74,7 +74,7 @@ def load_reward_settings(path: str | Path | None = None) -> RewardSettings:
         or os.getenv("CYBERSECURITY_OWASP_REWARD_CONFIG", "")
         or DEFAULT_GRPO_CONFIG_PATH
     )
-    raw =
+    raw = _load_yaml_with_extends(configured_path)
     reward = dict(raw.get("reward") or {})
     mode = os.getenv("CYBERSECURITY_OWASP_REWARD_MODE", str(reward.get("mode", "sparse_eval")))
     training_mode = str(reward.get("training_mode", "dense_train"))
@@ -90,6 +90,44 @@ def load_reward_settings(path: str | Path | None = None) -> RewardSettings:
     return settings


+def _load_yaml_with_extends(path: Path, seen: set[Path] | None = None) -> dict[str, Any]:
+    """Load a YAML file, recursively merging an optional relative `extends` file."""
+
+    resolved_path = path.expanduser().resolve()
+    seen = seen or set()
+    if resolved_path in seen:
+        chain = " -> ".join(str(item) for item in [*seen, resolved_path])
+        raise ValueError(f"reward config extends cycle detected: {chain}")
+    seen.add(resolved_path)
+
+    raw = yaml.safe_load(resolved_path.read_text(encoding="utf-8")) or {}
+    if not isinstance(raw, dict):
+        raise ValueError(f"reward config must be a YAML mapping: {resolved_path}")
+
+    extends = raw.get("extends")
+    if not extends:
+        return raw
+    if not isinstance(extends, str):
+        raise ValueError("reward config extends must be a string path")
+
+    base_path = Path(extends)
+    if not base_path.is_absolute():
+        base_path = resolved_path.parent / base_path
+    child = {key: value for key, value in raw.items() if key != "extends"}
+    return _deep_merge(_load_yaml_with_extends(base_path, seen), child)
+
+
+def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
+    merged = dict(base)
+    for key, value in override.items():
+        base_value = merged.get(key)
+        if isinstance(base_value, dict) and isinstance(value, dict):
+            merged[key] = _deep_merge(base_value, value)
+        else:
+            merged[key] = value
+    return merged
+
+
 def flatten_reward_config(
     settings: RewardSettings | None = None,
 ) -> list[dict[str, Any]]:
@@ -175,6 +213,7 @@ def reward_config_run_config(settings: RewardSettings | None = None) -> dict[str
         "reward_config_hash": summary["reward_config_hash"],
         "reward_config_source": summary["reward_config_source"],
         "reward_config_source_name": summary["reward_config_source_name"],
+        "reward_variant": os.getenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "default") or "default",
         "reward_mode": summary["reward_mode"],
         "reward_training_mode": summary["reward_training_mode"],
         "reward_stage": summary["reward_stage"],
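A minimal sketch of the merge semantics introduced above: the child document wins on scalar conflicts, nested mappings merge recursively, and the `extends` key itself never survives the merge. The reward fields in the example dictionaries below are illustrative placeholders, not this repo's schema.

from typing import Any

def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    merged = dict(base)
    for key, value in override.items():
        if isinstance(merged.get(key), dict) and isinstance(value, dict):
            merged[key] = deep_merge(merged[key], value)  # nested mappings merge recursively
        else:
            merged[key] = value  # child scalars and lists replace the base value
    return merged

base = {"reward": {"mode": "sparse_eval", "shaping_scale": 1.0, "gates": {"visible": True}}}
child = {"extends": "base.yaml", "reward": {"shaping_scale": 0.35}}

# Drop the extends key before merging, exactly as _load_yaml_with_extends does.
merged = deep_merge(base, {k: v for k, v in child.items() if k != "extends"})
assert merged["reward"] == {"mode": "sparse_eval", "shaping_scale": 0.35, "gates": {"visible": True}}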
scripts/generate_sft_dataset.py
ADDED
@@ -0,0 +1,753 @@
"""Generate verifier-gated SFT data for CyberSecurity_OWASP.

The default path asks a larger Hugging Face-hosted teacher model for one JSON
action at a time, executes those actions in the real environment, and keeps
only trajectories that pass the local deterministic verifier. The
``--dry-run-oracle`` path is intentionally network-free and exists for CI and
smoke tests.
"""

from __future__ import annotations

import argparse
import json
import os
import statistics
import subprocess
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable

from CyberSecurity_OWASP.models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation
from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
    CybersecurityOwaspEnvironment,
)
from CyberSecurity_OWASP.validators import detect_cheating


DEFAULT_TEACHER_MODEL = "deepseek-ai/DeepSeek-V4-Pro"
DEFAULT_TARGET_MODEL = "unsloth/gemma-4-E2B-it"
TRAINING_SYSTEM_PROMPT = (
    "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
    "OpenEnv environment. Use only the listed local tools. Do not target real "
    "systems. Work step by step: inspect policy and generated code, reproduce "
    "the authorization issue locally, submit a policy-tied diagnosis, patch the "
    "generated app, run visible tests, then submit the fix. Return exactly one "
    "JSON action object and no markdown."
)
BANNED_PROMPT_MARKERS = (
    "hidden_facts",
    "oracle_hidden_focus",
    "reward_engine",
    "validators.py",
    "rewards.py",
    "tests/hidden",
    "hidden tests",
    ".git",
)
RISKY_ARGUMENT_MARKERS = (
    "hidden",
    "oracle",
    "reward_engine",
    "validators.py",
    "rewards.py",
    ".git",
    "..",
)


@dataclass
class DatasetConfig:
    teacher_model: str = DEFAULT_TEACHER_MODEL
    target_model: str = DEFAULT_TARGET_MODEL
    split: str = "train"
    difficulty: int = 0
    seed_start: int = 0
    episodes: int = 100
    validation_episodes: int = 0
    out_dir: Path = Path("outputs/sft")
    max_steps: int = 40
    max_teacher_retries: int = 2
    max_tokens: int = 768
    temperature: float = 0.2
    top_p: float = 0.95
    dry_run_oracle: bool = False


class HuggingFaceTeacher:
    """Small wrapper around Hugging Face chat completion."""

    def __init__(
        self,
        *,
        model: str,
        token: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> None:
        try:
            from huggingface_hub import InferenceClient
        except ImportError as exc:  # pragma: no cover - dependency smoke checked separately
            raise RuntimeError(
                "huggingface_hub is required for teacher generation. Install project "
                "dependencies or use --dry-run-oracle for local CI."
            ) from exc

        self.model = model
        self.max_tokens = int(max_tokens)
        self.temperature = float(temperature)
        self.top_p = float(top_p)
        self.client = InferenceClient(token=token)

    def complete(self, messages: list[dict[str, str]]) -> str:
        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        return _chat_response_content(response)


def _chat_response_content(response: Any) -> str:
    choices = getattr(response, "choices", None)
    if choices:
        message = getattr(choices[0], "message", None)
        content = getattr(message, "content", None)
        if content is not None:
            return str(content)
    if isinstance(response, dict):
        choices = response.get("choices") or []
        if choices:
            message = choices[0].get("message") or {}
            return str(message.get("content", ""))
    return str(response)


def extract_first_json_object(text: str) -> dict[str, Any] | None:
    """Extract the first JSON object from raw teacher text."""

    stripped = text.strip()
    candidates = [stripped]
    if "```" in stripped:
        for part in stripped.split("```"):
            candidate = part.strip()
            if candidate.startswith("json"):
                candidate = candidate[4:].strip()
            candidates.append(candidate)

    for candidate in candidates:
        try:
            loaded = json.loads(candidate)
        except Exception:
            continue
        if isinstance(loaded, dict):
            return loaded

    start = stripped.find("{")
    while start >= 0:
        depth = 0
        in_string = False
        escaped = False
        for index in range(start, len(stripped)):
            char = stripped[index]
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
                continue
            if char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    try:
                        loaded = json.loads(stripped[start : index + 1])
                    except Exception:
                        break
                    if isinstance(loaded, dict):
                        return loaded
        start = stripped.find("{", start + 1)
    return None
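# Illustrative behavior (annotation, not part of the committed file): the scan
# above tolerates markdown fences and surrounding prose, so
#   extract_first_json_object('noise ```json {"tool_name": "list_routes", "arguments": {}} ```')
# returns {"tool_name": "list_routes", "arguments": {}}.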


def parse_action_text(text: str) -> CyberSecurityOWASPAction:
    data = extract_first_json_object(text)
    if data is None:
        raise ValueError("teacher did not return a JSON object")
    return CyberSecurityOWASPAction(**data)


def action_to_json(action: CyberSecurityOWASPAction) -> str:
    return json.dumps(action.model_dump(), separators=(",", ":"), sort_keys=True)


def _safe_observation_payload(
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
) -> dict[str, Any]:
    return {
        "phase": observation.phase,
        "task_brief": observation.task_brief,
        "scenario_prompt": observation.scenario_prompt,
        "available_actions": observation.available_actions,
        "last_tool_result": observation.last_tool_result,
        "last_action_valid": observation.last_action_valid,
        "last_action_error": observation.last_action_error,
        "visible_test_result": observation.visible_test_result,
        "done_reason": observation.done_reason,
        "recent_actions": recent_actions[-8:],
    }


def build_user_prompt(
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    retry_error: str | None = None,
) -> str:
    payload = _safe_observation_payload(observation, recent_actions)
    prompt = (
        "Current CyberSecurity_OWASP observation, containing only information "
        "available to the agent:\n"
        f"{json.dumps(payload, indent=2, sort_keys=True)}\n\n"
        "Choose the next action. Output exactly one JSON object with keys "
        "`tool_name` and `arguments`. Do not include markdown or commentary."
    )
    if retry_error:
        prompt += f"\nPrevious candidate was rejected safely: {retry_error}"
    _assert_prompt_is_safe(prompt)
    return prompt


def _assert_prompt_is_safe(prompt: str) -> None:
    lowered = prompt.lower()
    leaked = [marker for marker in BANNED_PROMPT_MARKERS if marker.lower() in lowered]
    if leaked:
        raise ValueError(f"prompt contains blocked marker(s): {', '.join(leaked)}")


def build_chat_messages(
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    retry_error: str | None = None,
) -> list[dict[str, str]]:
    return [
        {"role": "system", "content": TRAINING_SYSTEM_PROMPT},
        {"role": "user", "content": build_user_prompt(observation, recent_actions, retry_error)},
    ]


def make_chat_row(
    *,
    messages: list[dict[str, str]],
    action: CyberSecurityOWASPAction,
    metadata: dict[str, Any],
) -> dict[str, Any]:
    return {
        "messages": [
            *messages,
            {"role": "assistant", "content": action_to_json(action)},
        ],
        "metadata": metadata,
    }


def preflight_action(
    env: CybersecurityOwaspEnvironment,
    observation: CyberSecurityOWASPObservation,
    action: CyberSecurityOWASPAction,
) -> tuple[bool, str]:
    if action.tool_name not in observation.available_actions:
        return False, f"{action.tool_name} is not allowed during {observation.phase}"
    args = action.arguments or {}
    flags = detect_cheating(env.state, action)
    if flags:
        return False, f"action triggered safety flags: {', '.join(flags)}"
    arg_text = json.dumps(args, sort_keys=True, default=str).lower()
    if any(marker in arg_text for marker in RISKY_ARGUMENT_MARKERS):
        return False, "arguments reference blocked files or paths"
    if action.tool_name == "read_file" and not args.get("path"):
        return False, "read_file requires path"
    if action.tool_name == "search_code" and not args.get("query"):
        return False, "search_code requires query"
    if action.tool_name == "patch_file":
        path = str(args.get("path", ""))
        if not path:
            return False, "patch_file requires path"
        if path.replace("\\", "/").startswith("tests/"):
            return False, "patch_file cannot modify tests"
        if not args.get("content") and not args.get("diff"):
            return False, "patch_file requires content or diff"
    if action.tool_name == "send_local_request":
        path = str(args.get("path", ""))
        if not path.startswith("/"):
            return False, "send_local_request requires a local route path"
    if action.tool_name == "compare_identities":
        path = str(args.get("path", ""))
        if not path.startswith("/"):
            return False, "compare_identities requires a local route path"
        if not args.get("first_user_id") or not args.get("second_user_id"):
            return False, "compare_identities requires two user ids"
    if action.tool_name == "submit_diagnosis":
        required = ("bug_class", "route", "violated_policy_rule", "evidence_trace_ids", "fix_plan")
        missing = [key for key in required if not args.get(key)]
        if missing:
            return False, f"submit_diagnosis missing: {', '.join(missing)}"
    return True, ""
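# Illustrative rejection (annotation, not part of the committed file): with
# "rewards.py" listed in RISKY_ARGUMENT_MARKERS, a candidate such as
#   CyberSecurityOWASPAction(tool_name="read_file", arguments={"path": "rewards.py"})
# is rejected before it ever reaches the environment — either detect_cheating
# flags it, or the marker scan returns "arguments reference blocked files or
# paths" — assuming read_file is among the phase's available_actions.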


def _trace_id_from_observation(observation: CyberSecurityOWASPObservation) -> str:
    try:
        payload = json.loads(observation.last_tool_result)
    except Exception:
        return "req_001"
    return str(payload.get("trace_id", "req_001"))


def _secure_invoice_source(env: CybersecurityOwaspEnvironment) -> str:
    source = (Path(env.state.hidden_facts["workspace"]) / "app/routes/invoices.py").read_text(
        encoding="utf-8"
    )
    return source.replace(
        "    # BUG: this only checks that the caller is authenticated. It forgets the\n"
        "    # owner/admin and tenant policy checks required by the policy graph.\n"
        "    return {\"status\": 200, \"body\": invoice}\n",
        "    if invoice[\"tenant_id\"] != actor[\"tenant_id\"]:\n"
        "        return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        "    if invoice[\"owner_user_id\"] != actor[\"user_id\"] and not is_billing_admin(actor):\n"
        "        return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        "    return {\"status\": 200, \"body\": invoice}\n",
    )


def oracle_actions_for_state(
    env: CybersecurityOwaspEnvironment,
    evidence_trace_id: str | None = None,
) -> list[CyberSecurityOWASPAction]:
    hidden = env.state.hidden_facts
    trace_id = evidence_trace_id or "req_001"
    return [
        CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={}),
        CyberSecurityOWASPAction(tool_name="list_routes", arguments={}),
        CyberSecurityOWASPAction(
            tool_name="read_file",
            arguments={"path": "app/routes/invoices.py"},
        ),
        CyberSecurityOWASPAction(
            tool_name="send_local_request",
            arguments={
                "method": "GET",
                "path": f"/invoices/{hidden['other_invoice_id']}",
                "user_id": hidden["owner_user_id"],
            },
        ),
        CyberSecurityOWASPAction(
            tool_name="submit_diagnosis",
            arguments={
                "bug_class": "idor_ownership_bug",
                "route": "GET /invoices/{invoice_id}",
                "violated_policy_rule": "Only the owner or a billing_admin in the same tenant may read invoices.",
                "evidence_trace_ids": [trace_id],
                "fix_plan": "Add tenant and owner/admin checks before returning invoice data.",
            },
        ),
        CyberSecurityOWASPAction(
            tool_name="patch_file",
            arguments={"path": "app/routes/invoices.py", "content": _secure_invoice_source(env)},
        ),
        CyberSecurityOWASPAction(tool_name="run_visible_tests", arguments={}),
        CyberSecurityOWASPAction(tool_name="submit_fix", arguments={}),
    ]


def _teacher_action(
    *,
    teacher: HuggingFaceTeacher,
    env: CybersecurityOwaspEnvironment,
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    config: DatasetConfig,
) -> tuple[CyberSecurityOWASPAction, list[dict[str, str]]]:
    retry_error: str | None = None
    for _ in range(config.max_teacher_retries + 1):
        messages = build_chat_messages(observation, recent_actions, retry_error)
        raw = teacher.complete(messages)
        try:
            action = parse_action_text(raw)
        except Exception as exc:
            retry_error = str(exc)
            continue
        ok, error = preflight_action(env, observation, action)
        if ok:
            return action, messages
        retry_error = error
    raise ValueError(retry_error or "teacher did not produce a usable action")


def _oracle_action(
    *,
    env: CybersecurityOwaspEnvironment,
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    oracle_actions: list[CyberSecurityOWASPAction],
    step_index: int,
) -> tuple[CyberSecurityOWASPAction, list[dict[str, str]]]:
    action = oracle_actions[step_index]
    messages = build_chat_messages(observation, recent_actions)
    ok, error = preflight_action(env, observation, action)
    if not ok:
        raise ValueError(error)
    return action, messages


def _terminal_checks_passed(env: CybersecurityOwaspEnvironment) -> bool:
    verifier = env.state.verification_summary or {}
    required = ("visible", "security", "regression", "public_routes", "patch_quality")
    return all(bool((verifier.get(key) or {}).get("passed", False)) for key in required)


def _episode_reward(env: CybersecurityOwaspEnvironment) -> float:
    if env.state.reward_history:
        return float(env.state.reward_history[-1].get("terminal_total", 0.0))
    return 0.0


def run_episode(
    *,
    seed: int,
    split: str,
    difficulty: int,
    config: DatasetConfig,
    teacher: HuggingFaceTeacher | None,
) -> dict[str, Any]:
    env = CybersecurityOwaspEnvironment()
    rows: list[dict[str, Any]] = []
    trajectory_steps: list[dict[str, Any]] = []
    recent_actions: list[dict[str, Any]] = []
    try:
        observation = env.reset(seed=seed, split=split, difficulty=difficulty)
        oracle_actions = oracle_actions_for_state(env) if config.dry_run_oracle else []
        for step_index in range(config.max_steps):
            if observation.done:
                break
            if config.dry_run_oracle:
                if step_index >= len(oracle_actions):
                    raise ValueError("oracle action script ended before terminal state")
                if step_index == 4 and env.state.request_trace:
                    trace_id = _trace_id_from_observation(observation)
                    oracle_actions = oracle_actions_for_state(env, evidence_trace_id=trace_id)
                action, messages = _oracle_action(
                    env=env,
                    observation=observation,
                    recent_actions=recent_actions,
                    oracle_actions=oracle_actions,
                    step_index=step_index,
                )
            else:
                if teacher is None:
                    raise RuntimeError("teacher is required unless --dry-run-oracle is set")
                action, messages = _teacher_action(
                    teacher=teacher,
                    env=env,
                    observation=observation,
                    recent_actions=recent_actions,
                    config=config,
                )

            step_number = step_index + 1
            action_record = action.model_dump()
            row = make_chat_row(
                messages=messages,
                action=action,
                metadata={
                    "target_model": config.target_model,
                    "teacher_model": config.teacher_model,
                    "seed": seed,
                    "split": split,
                    "difficulty": difficulty,
                    "step": step_number,
                    "tool_name": action.tool_name,
                    "task_id": env.state.task_id,
                    "episode_id": env.state.episode_id,
                    "scenario_hash": env.state.scenario_hash,
                },
            )
            next_observation = env.step(action)
            trajectory_steps.append(
                {
                    "step": step_number,
                    "prompt_messages": messages,
                    "action": action_record,
                    "observation": next_observation.model_dump(),
                    "reward_breakdown": dict(next_observation.reward_breakdown or {}),
                }
            )
            if not next_observation.last_action_valid:
                raise ValueError(next_observation.last_action_error or "invalid action")
            if env.state.anti_cheat_flags:
                raise ValueError(f"anti-cheat flags: {env.state.anti_cheat_flags}")
            rows.append(row)
            recent_actions.append(action_record)
            observation = next_observation
            if observation.done:
                break

        if not env.state.done:
            raise ValueError("episode did not reach a terminal state")
        if not env.state.success:
            raise ValueError(env.state.failure_reason or "terminal verifier failed")
        if env.state.step_count > config.max_steps:
            raise ValueError("episode exceeded max steps")
        if env.state.anti_cheat_flags:
            raise ValueError("episode has anti-cheat flags")
        if not _terminal_checks_passed(env):
            raise ValueError("terminal verifier checks did not all pass")

        final_reward = _episode_reward(env)
        final_breakdown = dict(env.state.reward_history[-1]) if env.state.reward_history else {}
        for row in rows:
            row["metadata"].update(
                {
                    "final_success": True,
                    "terminal_total": final_reward,
                    "total_reward": float(env.state.accumulated_reward),
                    "anti_cheat_flags": list(env.state.anti_cheat_flags),
                    "final_reward_breakdown": final_breakdown,
                }
            )
        return {
            "accepted": True,
            "seed": seed,
            "split": split,
            "difficulty": difficulty,
            "rows": rows,
            "trajectory": {
                "episode_id": env.state.episode_id,
                "task_id": env.state.task_id,
                "seed": seed,
                "split": split,
                "difficulty": difficulty,
                "domain": env.state.domain,
                "bug_family": env.state.bug_family,
                "scenario_hash": env.state.scenario_hash,
                "actions": [step["action"] for step in trajectory_steps],
                "steps": trajectory_steps,
                "reward_breakdown_by_step": list(env.state.reward_history),
                "final_reward_breakdown": final_breakdown,
                "total_reward": float(env.state.accumulated_reward),
                "terminal_total": final_reward,
                "success": True,
                "failure_reason": None,
                "anti_cheat_flags": list(env.state.anti_cheat_flags),
                "verification_summary": env.state.verification_summary,
            },
        }
    except Exception as exc:
        return {
            "accepted": False,
            "seed": seed,
            "split": split,
            "difficulty": difficulty,
            "reason": str(exc),
            "rows": [],
            "trajectory": {
                "seed": seed,
                "split": split,
                "difficulty": difficulty,
                "steps": trajectory_steps,
                "actions": [step["action"] for step in trajectory_steps],
                "success": bool(env.state.success),
                "failure_reason": env.state.failure_reason or str(exc),
                "anti_cheat_flags": list(env.state.anti_cheat_flags),
            },
        }
    finally:
        env.close()


def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        for row in rows:
            handle.write(json.dumps(row, sort_keys=True, default=str) + "\n")


def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
    traj_dir = out_dir / "trajectories"
    traj_dir.mkdir(parents=True, exist_ok=True)
    name = (
        f"{trajectory.get('split', 'train')}_seed{trajectory.get('seed', 0)}_"
        f"{str(trajectory.get('episode_id', 'rejected'))[:12]}.json"
    )
    path = traj_dir / name
    path.write_text(json.dumps(trajectory, indent=2, sort_keys=True, default=str), encoding="utf-8")
    return path


def _git_sha() -> str:
    root = Path(__file__).resolve().parents[1]
    try:
        return subprocess.check_output(
            [
                "git",
                "-c",
                f"safe.directory={root.as_posix()}",
                "rev-parse",
                "HEAD",
            ],
            cwd=root,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except Exception:
        return "nogit"


def _reward_summary(values: list[float]) -> dict[str, float]:
    if not values:
        return {"mean": 0.0, "min": 0.0, "max": 0.0, "p50": 0.0}
    sorted_values = sorted(values)
    return {
        "mean": float(statistics.mean(values)),
        "min": float(min(values)),
        "max": float(max(values)),
        "p50": float(sorted_values[len(sorted_values) // 2]),
    }


def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
    config.out_dir.mkdir(parents=True, exist_ok=True)
    teacher = None
    if not config.dry_run_oracle:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
        teacher = HuggingFaceTeacher(
            model=config.teacher_model,
            token=token,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
            top_p=config.top_p,
        )

    split_jobs = [(config.split, config.episodes, config.seed_start)]
    if config.validation_episodes:
        split_jobs.append(("validation", config.validation_episodes, config.seed_start + config.episodes))

    rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
    attempts: list[dict[str, Any]] = []
    rewards: list[float] = []
    accepted = 0
    attempted = 0
    for split, episodes, seed_start in split_jobs:
        for offset in range(int(episodes)):
            seed = int(seed_start) + offset
            attempted += 1
            result = run_episode(
                seed=seed,
                split=split,
                difficulty=config.difficulty,
                config=config,
                teacher=teacher,
            )
            attempts.append(
                {
                    "seed": seed,
                    "split": split,
                    "accepted": bool(result["accepted"]),
                    "reason": result.get("reason", ""),
                    "trajectory_path": str(_write_trajectory(config.out_dir, result["trajectory"])),
                }
            )
            if result["accepted"]:
                accepted += 1
                rows = list(result["rows"])
                rows_by_split.setdefault(split, []).extend(rows)
                rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))

    for split_name in ("train", "validation", config.split):
        write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))

    manifest = {
        "teacher_model": config.teacher_model,
        "target_model": config.target_model,
        "split": config.split,
        "difficulty": config.difficulty,
        "seed_start": config.seed_start,
        "episodes_attempted": attempted,
        "episodes_accepted": accepted,
        "acceptance_rate": accepted / attempted if attempted else 0.0,
        "rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
        "reward_summary": _reward_summary(rewards),
        "git_sha": _git_sha(),
        "verifier_version": "verifier_v1",
        "dry_run_oracle": config.dry_run_oracle,
        "attempts": attempts,
    }
    manifest_path = config.out_dir / "manifest.json"
    manifest_path.write_text(
        json.dumps(manifest, indent=2, sort_keys=True, default=str),
        encoding="utf-8",
    )
    return manifest


def build_arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--teacher-model", default=DEFAULT_TEACHER_MODEL)
    parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
    parser.add_argument("--split", default="train", choices=["train", "validation", "hidden_eval"])
    parser.add_argument("--difficulty", type=int, default=0)
    parser.add_argument("--seed-start", type=int, default=0)
    parser.add_argument("--episodes", type=int, default=100)
    parser.add_argument("--validation-episodes", type=int, default=0)
    parser.add_argument("--out-dir", type=Path, default=Path("outputs/sft"))
    parser.add_argument("--max-steps", type=int, default=40)
    parser.add_argument("--max-teacher-retries", type=int, default=2)
    parser.add_argument("--max-tokens", type=int, default=768)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top-p", type=float, default=0.95)
    parser.add_argument(
        "--dry-run-oracle",
        action="store_true",
        help="Generate deterministic oracle data without calling the HF API.",
    )
    return parser


def config_from_args(args: argparse.Namespace) -> DatasetConfig:
    return DatasetConfig(
        teacher_model=args.teacher_model,
        target_model=args.target_model,
        split=args.split,
        difficulty=args.difficulty,
        seed_start=args.seed_start,
        episodes=args.episodes,
        validation_episodes=args.validation_episodes,
        out_dir=args.out_dir,
        max_steps=args.max_steps,
        max_teacher_retries=args.max_teacher_retries,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        top_p=args.top_p,
        dry_run_oracle=args.dry_run_oracle,
    )


def main(argv: list[str] | None = None) -> int:
    parser = build_arg_parser()
    args = parser.parse_args(argv)
    manifest = generate_dataset(config_from_args(args))
    print(json.dumps(manifest, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
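As a quick smoke test of the network-free path, something like the following should work, assuming the CyberSecurity_OWASP package is installed and the script is importable as a module (the module path here is hypothetical; running python scripts/generate_sft_dataset.py with the same flags is equivalent):

from scripts.generate_sft_dataset import main  # hypothetical import path

# Two oracle episodes, no HF_TOKEN needed; writes train.jsonl, per-episode
# trajectory JSON files, and manifest.json under outputs/sft_smoke.
main(["--dry-run-oracle", "--episodes", "2", "--out-dir", "outputs/sft_smoke"])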
scripts/launch_reward_ablations.ps1
ADDED
@@ -0,0 +1,59 @@
param(
    [switch]$AllowActive
)

$ErrorActionPreference = "Stop"
$env:PYTHONIOENCODING = "utf-8"
$env:PYTHONUTF8 = "1"

$appList = uv run --extra modal modal app list | Out-String
Write-Host $appList
if (-not $AllowActive -and $appList -match "CyberSecur" -and $appList -match "ephemeral") {
    throw "Active CyberSecurity_OWASP Modal apps are present. Re-run with -AllowActive only if overlapping L4 jobs are intentional."
}

$runs = @(
    @{
        Variant = "abl-a0-sparse"
        Config = "training/configs/reward_ablations/A0_sparse_terminal_only.yaml"
        Seed = 110000
    },
    @{
        Variant = "abl-a2-shape035"
        Config = "training/configs/reward_ablations/A2_reduced_shaping.yaml"
        Seed = 120000
    },
    @{
        Variant = "abl-a6-visgate"
        Config = "training/configs/reward_ablations/A6_visible_gate.yaml"
        Seed = 130000
    },
    @{
        Variant = "abl-a7-evid045"
        Config = "training/configs/reward_ablations/A7_evidence045.yaml"
        Seed = 140000
    },
    @{
        Variant = "abl-a3-nospeed"
        Config = "training/configs/reward_ablations/A3_no_speed_token.yaml"
        Seed = 150000
    }
)

foreach ($run in $runs) {
    Write-Host "Launching $($run.Variant) with $($run.Config) seed $($run.Seed)"
    uv run --extra modal modal run --detach scripts/modal_train_grpo.py `
        --mode train `
        --max-steps 60 `
        --dataset-size 32 `
        --num-generations 4 `
        --max-completion-length 768 `
        --difficulty 0 `
        --split train `
        --source-mode local `
        --trace-log-every 5 `
        --seed-start $run.Seed `
        --reward-config $run.Config `
        --reward-variant $run.Variant `
        --detach
}
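# Usage (annotation, not part of the committed script):
#   pwsh scripts/launch_reward_ablations.ps1                # guarded launch of all five ablations
#   pwsh scripts/launch_reward_ablations.ps1 -AllowActive   # skip the active-app guard when overlap is intentional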
scripts/modal_train_grpo.py
CHANGED
@@ -210,6 +210,24 @@ def _configure_scenario_cache_env(*, required: bool = True) -> dict[str, str]:
     return values


+def _configure_reward_env(
+    *,
+    reward_config: str = "",
+    reward_variant: str = "",
+    reward_mode: str = "",
+) -> dict[str, str]:
+    values: dict[str, str] = {}
+    if reward_config:
+        values["CYBERSECURITY_OWASP_REWARD_CONFIG"] = reward_config
+    if reward_variant:
+        values["CYBERSECURITY_OWASP_REWARD_VARIANT"] = reward_variant
+    if reward_mode:
+        values["CYBERSECURITY_OWASP_REWARD_MODE"] = reward_mode
+    for key, value in values.items():
+        os.environ[key] = value
+    return values
+
+
 def _print_image_startup_notice() -> None:
     global _IMAGE_NOTICE_PRINTED
     if _IMAGE_NOTICE_PRINTED:
@@ -583,6 +601,8 @@ def run_cybersecurity_owasp_baseline(
     source_mode: str = "local",
     repo_url: str = PUBLIC_REPO_URL,
     repo_branch: str = PUBLIC_REPO_BRANCH,
+    reward_config: str = "",
+    reward_variant: str = "",
 ) -> dict[str, str | int | float]:
     import statistics
     import time
@@ -627,8 +647,14 @@ def run_cybersecurity_owasp_baseline(
 
     os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
     os.environ["TRACKIO_PROJECT"] = trackio_project
+    reward_env = _configure_reward_env(
+        reward_config=reward_config,
+        reward_variant=reward_variant,
+    )
     reward_settings = load_reward_settings()
     reward_tracking_config = reward_config_trackio_config(reward_settings)
+    reward_tracking_config["reward_variant"] = reward_variant or "default"
+    reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path
     run_name = run_name or "baseline"
     output_dir = RUNS_DIR / run_name
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -673,6 +699,10 @@ def run_cybersecurity_owasp_baseline(
     print(f"Trackio Project: {trackio_project}")
     print(f"Reward config: {reward_tracking_config['reward_config_id']}")
     print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
+    print(f"Reward variant: {reward_tracking_config['reward_variant']}")
+    print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
+    if reward_env:
+        print(f"Reward env overrides: {reward_env}")
     print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
     print(f"Scenario cache coverage: {coverage}")
     print(
@@ -818,6 +848,7 @@ def run_cybersecurity_owasp_baseline(
         "num_generations": num_generations,
         "max_completion_length": max_completion_length,
         "git_sha": git_sha,
+        "reward_variant": reward_tracking_config["reward_variant"],
         **reward_tracking_config,
     }
@@ -998,6 +1029,8 @@ def run_cybersecurity_owasp_baseline(
 def train_cybersecurity_owasp_grpo(
     env_repo_id: str = "",
     output_repo_id: str = "",
+    initial_adapter_path: str = "",
+    initial_adapter_repo_id: str = "",
     max_steps: int = 10,
     dataset_size: int = 16,
     difficulty: int = 0,
@@ -1021,6 +1054,8 @@ def train_cybersecurity_owasp_grpo(
     repo_url: str = PUBLIC_REPO_URL,
     repo_branch: str = PUBLIC_REPO_BRANCH,
     push_to_hub: bool = False,
+    reward_config: str = "",
+    reward_variant: str = "",
 ) -> dict[str, str | int | float]:
     import inspect
     import statistics
@@ -1050,6 +1085,7 @@ def train_cybersecurity_owasp_grpo(
     import transformers.utils.hub as transformers_hub
     from datasets import Dataset
     from huggingface_hub import snapshot_download, whoami
+    from peft import PeftModel
     from transformers import TrainerCallback
     from trl import GRPOConfig, GRPOTrainer, clone_chat_template
     from trl.chat_template_utils import add_response_schema
@@ -1110,14 +1146,22 @@ def train_cybersecurity_owasp_grpo(
 
     os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
     os.environ["TRACKIO_PROJECT"] = trackio_project
-
+    reward_env = _configure_reward_env(
+        reward_config=reward_config,
+        reward_variant=reward_variant,
+        reward_mode="dense_train",
+    )
     reward_settings = load_reward_settings()
     reward_tracking_config = reward_config_trackio_config(reward_settings)
+    reward_tracking_config["reward_variant"] = reward_variant or "default"
+    reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path
 
     model_slug = model_name.replace("/", "-")
     stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
     run_name = run_name or (
-        f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-
+        f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
+        f"{reward_tracking_config['reward_variant']}-steps{max_steps}-seed{seed_start}-"
+        f"{stamp}-{git_sha[:8]}"
     )
     output_dir = RUNS_DIR / run_name
     output_dir.mkdir(parents=True, exist_ok=True)
@@ -1253,6 +1297,7 @@ def train_cybersecurity_owasp_grpo(
             "reward_config_hash": reward_tracking_config["reward_config_hash"],
             "reward_stage": reward_tracking_config["reward_stage"],
             "reward_mode": reward_tracking_config["reward_mode"],
+            "reward_variant": reward_tracking_config["reward_variant"],
         }
     )
     return obs.scenario_prompt
@@ -1613,6 +1658,7 @@ def train_cybersecurity_owasp_grpo(
             "reward_config_hash": reward_tracking_config["reward_config_hash"],
             "reward_stage": reward_tracking_config["reward_stage"],
             "reward_mode": reward_tracking_config["reward_mode"],
+            "reward_variant": reward_tracking_config["reward_variant"],
         }
     )
     try:
@@ -1704,6 +1750,9 @@ def train_cybersecurity_owasp_grpo(
     print(f"Run name: {run_name}")
     print(f"Reward config: {reward_tracking_config['reward_config_id']}")
     print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
+    print(f"Reward variant: {reward_tracking_config['reward_variant']}")
+    print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
+    print(f"Reward env overrides: {reward_env}")
     print(f"Model cache volume: {CACHE_VOLUME_NAME}")
     print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
     print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
@@ -1715,6 +1764,10 @@ def train_cybersecurity_owasp_grpo(
     print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
     print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
     print(f"Hub push enabled: {push_to_hub}")
+    if initial_adapter_path:
+        print(f"Initial SFT adapter path: {initial_adapter_path}")
+    if initial_adapter_repo_id:
+        print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}")
     print(
         "GRPO throughput config: "
         f"per_device_train_batch_size={per_device_train_batch_size}, "
@@ -1801,25 +1854,40 @@ def train_cybersecurity_owasp_grpo(
             f"{exc!r}"
         )
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    adapter_source = initial_adapter_path
+    if initial_adapter_repo_id:
+        print(f"Downloading initial SFT adapter: {initial_adapter_repo_id}")
+        adapter_source = snapshot_download(
+            repo_id=initial_adapter_repo_id,
+            cache_dir=str(HF_HUB_CACHE_DIR),
+            token=hf_token,
+        )
+        cache_volume.commit()
+    if adapter_source:
+        print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
+        model = PeftModel.from_pretrained(model, adapter_source, is_trainable=True)
+        if hasattr(model, "print_trainable_parameters"):
     if hasattr(model_api, "for_training"):
         model_api.for_training(model)
-        print("LoRA adapter
 
     grpo_config_values = {
         "temperature": 1.0,
@@ -1942,6 +2010,8 @@ def train_cybersecurity_owasp_grpo(
         "difficulty": difficulty,
         "split": split,
         "model_name": model_name,
+
+
         "max_completion_length": max_completion_length,
         "num_generations": num_generations,
         "per_device_train_batch_size": per_device_train_batch_size,
@@ -1956,6 +2026,7 @@ def train_cybersecurity_owasp_grpo(
         "push_to_hub": push_to_hub,
         "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME,
         "scenario_cache_mode": "require",
+
         **reward_tracking_config,
     }
@@ -1965,6 +2036,8 @@ def main(
     mode: str = "train",
     env_repo_id: str = "",
     output_repo_id: str = "",
+
+
     max_steps: int = 10,
     dataset_size: int = 16,
     difficulty: int = 0,
@@ -1989,6 +2062,8 @@ def main(
     repo_branch: str = PUBLIC_REPO_BRANCH,
     detach: bool = False,
     push_to_hub: bool = False,
+
+
     cache_seed_start: int = 0,
     cache_difficulty_buckets: int = 0,
     cache_train_per_bucket: int = 0,
@@ -2042,6 +2117,8 @@ def main(
         source_mode=source_mode,
         repo_url=repo_url,
         repo_branch=repo_branch,
+
+
     )
     if detach:
         call = run_cybersecurity_owasp_baseline.spawn(**kwargs)
@@ -2100,7 +2177,13 @@ def main(
     if git_sha == "nogit":
         try:
             git_sha = subprocess.check_output(
-                [
+
+
+
+
+
+
+
                 cwd=PROJECT_ROOT,
                 text=True,
                 stderr=subprocess.DEVNULL,
@@ -2110,12 +2193,15 @@ def main(
 
     model_slug = model_name.replace("/", "-")
     local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
     run_name = run_name or (
         f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
-        f"{local_stamp}-{git_sha[:8]}"
+
+
     )
 
     print(f"Run name: {run_name}")
+
+
     print(f"Source mode: {source_mode}")
     if source_mode == "public":
         print(f"Public repo: {repo_url}@{repo_branch}")
@@ -2131,6 +2217,10 @@ def main(
         f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
     )
     print(f"Hub push enabled: {push_to_hub}")
+
+
+
+
     print(f"Model cache volume: {CACHE_VOLUME_NAME}")
     print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
     print(
@@ -2164,6 +2254,8 @@ def main(
     kwargs = dict(
         env_repo_id=env_repo_id,
         output_repo_id=output_repo_id,
+
+
         max_steps=max_steps,
         dataset_size=dataset_size,
         difficulty=difficulty,
@@ -2187,6 +2279,8 @@ def main(
         repo_url=repo_url,
         repo_branch=repo_branch,
         push_to_hub=push_to_hub,
+
+
     )
     preflight = verify_modal_scenario_cache_for_training.remote(
         split=split,
+
model.print_trainable_parameters()
|
| 1871 |
+
else:
|
| 1872 |
+
model = model_api.get_peft_model(
|
| 1873 |
+
model,
|
| 1874 |
+
r=lora_rank,
|
| 1875 |
+
target_modules=[
|
| 1876 |
+
"q_proj",
|
| 1877 |
+
"k_proj",
|
| 1878 |
+
"v_proj",
|
| 1879 |
+
"o_proj",
|
| 1880 |
+
"gate_proj",
|
| 1881 |
+
"up_proj",
|
| 1882 |
+
"down_proj",
|
| 1883 |
+
],
|
| 1884 |
+
lora_alpha=lora_rank * 2,
|
| 1885 |
+
use_gradient_checkpointing="unsloth",
|
| 1886 |
+
random_state=3407,
|
| 1887 |
+
)
|
| 1888 |
if hasattr(model_api, "for_training"):
|
| 1889 |
model_api.for_training(model)
|
| 1890 |
+
print("LoRA adapter ready and model switched to training mode.")
|
| 1891 |
|
| 1892 |
grpo_config_values = {
|
| 1893 |
"temperature": 1.0,
|
|
|
|
| 2010 |
"difficulty": difficulty,
|
| 2011 |
"split": split,
|
| 2012 |
"model_name": model_name,
|
| 2013 |
+
"initial_adapter_path": initial_adapter_path,
|
| 2014 |
+
"initial_adapter_repo_id": initial_adapter_repo_id,
|
| 2015 |
"max_completion_length": max_completion_length,
|
| 2016 |
"num_generations": num_generations,
|
| 2017 |
"per_device_train_batch_size": per_device_train_batch_size,
|
|
|
|
| 2026 |
"push_to_hub": push_to_hub,
|
| 2027 |
"scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME,
|
| 2028 |
"scenario_cache_mode": "require",
|
| 2029 |
+
"reward_variant": reward_tracking_config["reward_variant"],
|
| 2030 |
**reward_tracking_config,
|
| 2031 |
}
|
| 2032 |
|
|
|
|
| 2036 |
mode: str = "train",
|
| 2037 |
env_repo_id: str = "",
|
| 2038 |
output_repo_id: str = "",
|
| 2039 |
+
initial_adapter_path: str = "",
|
| 2040 |
+
initial_adapter_repo_id: str = "",
|
| 2041 |
max_steps: int = 10,
|
| 2042 |
dataset_size: int = 16,
|
| 2043 |
difficulty: int = 0,
|
|
|
|
| 2062 |
repo_branch: str = PUBLIC_REPO_BRANCH,
|
| 2063 |
detach: bool = False,
|
| 2064 |
push_to_hub: bool = False,
|
| 2065 |
+
reward_config: str = "",
|
| 2066 |
+
reward_variant: str = "",
|
| 2067 |
cache_seed_start: int = 0,
|
| 2068 |
cache_difficulty_buckets: int = 0,
|
| 2069 |
cache_train_per_bucket: int = 0,
|
|
|
|
| 2117 |
source_mode=source_mode,
|
| 2118 |
repo_url=repo_url,
|
| 2119 |
repo_branch=repo_branch,
|
| 2120 |
+
reward_config=reward_config,
|
| 2121 |
+
reward_variant=reward_variant,
|
| 2122 |
)
|
| 2123 |
if detach:
|
| 2124 |
call = run_cybersecurity_owasp_baseline.spawn(**kwargs)
|
|
|
|
| 2177 |
if git_sha == "nogit":
|
| 2178 |
try:
|
| 2179 |
git_sha = subprocess.check_output(
|
| 2180 |
+
[
|
| 2181 |
+
"git",
|
| 2182 |
+
"-c",
|
| 2183 |
+
f"safe.directory={PROJECT_ROOT.as_posix()}",
|
| 2184 |
+
"rev-parse",
|
| 2185 |
+
"HEAD",
|
| 2186 |
+
],
|
| 2187 |
cwd=PROJECT_ROOT,
|
| 2188 |
text=True,
|
| 2189 |
stderr=subprocess.DEVNULL,
|
|
|
|
| 2193 |
|
| 2194 |
model_slug = model_name.replace("/", "-")
|
| 2195 |
local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
| 2196 |
+
variant_tag = reward_variant or "default"
|
| 2197 |
run_name = run_name or (
|
| 2198 |
f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
|
| 2199 |
+
f"{variant_tag}-steps{max_steps}-seed{seed_start}-{local_stamp}-{git_sha[:8]}"
|
| 2200 |
)
|
| 2201 |
|
| 2202 |
print(f"Run name: {run_name}")
|
| 2203 |
+
print(f"Reward variant: {variant_tag}")
|
| 2204 |
+
print(f"Reward config path: {reward_config or '(default training/configs/grpo_small.yaml)'}")
|
| 2205 |
print(f"Source mode: {source_mode}")
|
| 2206 |
if source_mode == "public":
|
| 2207 |
print(f"Public repo: {repo_url}@{repo_branch}")
|
|
|
|
| 2217 |
f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
|
| 2218 |
)
|
| 2219 |
print(f"Hub push enabled: {push_to_hub}")
|
| 2220 |
+
if initial_adapter_path:
|
| 2221 |
+
print(f"Initial SFT adapter path: {initial_adapter_path}")
|
| 2222 |
+
if initial_adapter_repo_id:
|
| 2223 |
+
print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}")
|
| 2224 |
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
| 2225 |
print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
|
| 2226 |
print(
|
|
|
|
| 2254 |
kwargs = dict(
|
| 2255 |
env_repo_id=env_repo_id,
|
| 2256 |
output_repo_id=output_repo_id,
|
| 2257 |
+
initial_adapter_path=initial_adapter_path,
|
| 2258 |
+
initial_adapter_repo_id=initial_adapter_repo_id,
|
| 2259 |
max_steps=max_steps,
|
| 2260 |
dataset_size=dataset_size,
|
| 2261 |
difficulty=difficulty,
|
|
|
|
| 2279 |
repo_url=repo_url,
|
| 2280 |
repo_branch=repo_branch,
|
| 2281 |
push_to_hub=push_to_hub,
|
| 2282 |
+
reward_config=reward_config,
|
| 2283 |
+
reward_variant=reward_variant,
|
| 2284 |
)
|
| 2285 |
preflight = verify_modal_scenario_cache_for_training.remote(
|
| 2286 |
split=split,
|
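Two reviewer notes on the GRPO changes above. First, the new `initial_adapter_path` / `initial_adapter_repo_id` knobs let a run continue from an SFT LoRA: the adapter is loaded with `PeftModel.from_pretrained(..., is_trainable=True)` rather than attaching a fresh LoRA, so GRPO keeps training the same weights. Second, the reward overrides travel through process environment variables. Below is a minimal sketch of that plumbing, assuming `_configure_reward_env` exports the `CYBERSECURITY_OWASP_REWARD_CONFIG` and `CYBERSECURITY_OWASP_REWARD_VARIANT` keys that `load_reward_settings` and the tests further down read; the helper body is illustrative, not the committed implementation.

```python
import os


def configure_reward_env_sketch(
    reward_config: str = "",
    reward_variant: str = "",
    reward_mode: str = "",
) -> dict[str, str]:
    """Hypothetical mirror of _configure_reward_env (assumed key names)."""
    values: dict[str, str] = {}
    # Only non-empty overrides are exported, so an omitted --reward-config
    # leaves load_reward_settings() on the default YAML path.
    if reward_config:
        values["CYBERSECURITY_OWASP_REWARD_CONFIG"] = reward_config
    if reward_variant:
        values["CYBERSECURITY_OWASP_REWARD_VARIANT"] = reward_variant
    if reward_mode:
        values["CYBERSECURITY_OWASP_REWARD_MODE"] = reward_mode
    for key, value in values.items():
        os.environ[key] = value
    return values
```

Because only non-empty values are exported, leaving the overrides unset keeps the default config, which matches the `(default training/configs/grpo_small.yaml)` log line above.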
scripts/modal_train_sft.py
ADDED
@@ -0,0 +1,442 @@
+"""Modal SFT launcher for CyberSecurity_OWASP action JSON data.
+
+This trains a LoRA adapter on chat JSONL generated by
+``scripts/generate_sft_dataset.py``. It intentionally mirrors the repo's Modal
+training pattern: local execution only launches remote jobs, while training runs
+inside Modal and saves adapters to the persistent run volume.
+"""
+
+from __future__ import annotations
+
+import json
+import os
+import pathlib
+import subprocess
+from datetime import datetime, timezone
+from typing import Any
+
+import modal
+
+
+APP_NAME = "CyberSecurity_OWASP-sft"
+VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
+CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
+SECRET_NAME = "CyberSecurity_OWASP-secrets"
+RUNS_DIR = pathlib.Path("/runs")
+CACHE_DIR = pathlib.Path("/cache")
+HF_HOME_DIR = CACHE_DIR / "huggingface"
+HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
+TORCH_HOME_DIR = CACHE_DIR / "torch"
+XDG_CACHE_DIR = CACHE_DIR / "xdg"
+UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
+TRITON_CACHE_DIR = CACHE_DIR / "triton"
+REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
+PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
+DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
+PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
+PUBLIC_REPO_BRANCH = "master"
+
+
+def _ensure_gemma4_model(model_name: str) -> str:
+    if model_name != DEFAULT_GEMMA_MODEL:
+        raise ValueError(
+            "CyberSecurity_OWASP SFT is pinned to "
+            f"{DEFAULT_GEMMA_MODEL}; received {model_name!r}."
+        )
+    return model_name
+
+
+def _model_repo_slug(model_name: str) -> str:
+    return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
+
+
+def _configure_modal_cache_env() -> dict[str, str]:
+    values = {
+        "HF_HOME": str(HF_HOME_DIR),
+        "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
+        "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
+        "TORCH_HOME": str(TORCH_HOME_DIR),
+        "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
+        "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
+        "UNSLOTH_COMPILE_CACHE": str(UNSLOTH_CACHE_DIR / "compile"),
+        "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
+    }
+    for key, value in values.items():
+        os.environ[key] = value
+    for path in {
+        CACHE_DIR,
+        HF_HOME_DIR,
+        HF_HUB_CACHE_DIR,
+        TORCH_HOME_DIR,
+        XDG_CACHE_DIR,
+        UNSLOTH_CACHE_DIR,
+        UNSLOTH_CACHE_DIR / "compile",
+        TRITON_CACHE_DIR,
+    }:
+        path.mkdir(parents=True, exist_ok=True)
+    return values
+
+
+def _cli_arg_value(name: str, default: str = "") -> str:
+    import sys
+
+    args = sys.argv[1:]
+    flag = f"--{name}"
+    for index, arg in enumerate(args):
+        if arg == flag and index + 1 < len(args):
+            return args[index + 1]
+        if arg.startswith(f"{flag}="):
+            return arg.split("=", 1)[1]
+    return default
+
+
+def _source_mode() -> str:
+    return _cli_arg_value("source-mode", os.environ.get("MODAL_SOURCE_MODE", "local"))
+
+
+def _training_image() -> modal.Image:
+    image = (
+        modal.Image.from_registry(
+            "nvidia/cuda:12.8.0-devel-ubuntu22.04",
+            add_python="3.11",
+        )
+        .apt_install("git", "build-essential", "curl")
+        .uv_pip_install(
+            "torch==2.10.0",
+            "triton>=3.4.0",
+            "torchvision==0.25.0",
+            "bitsandbytes",
+            "accelerate",
+            "datasets",
+            "huggingface_hub",
+            "peft",
+            "tokenizers",
+            "trackio>=0.25.0",
+            "transformers>=5.5.0",
+            "trl>=0.28.0",
+        )
+        .uv_pip_install(
+            "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
+            "unsloth[base] @ git+https://github.com/unslothai/unsloth",
+        )
+        .uv_pip_install("timm", extra_options="--no-deps")
+        .uv_pip_install("pydantic==2.10.6")
+    )
+    if _source_mode() == "public":
+        repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL)
+        repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
+        image = image.run_commands(
+            f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
+            f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
+        )
+    else:
+        image = image.add_local_dir(
+            PROJECT_ROOT,
+            remote_path=REMOTE_PROJECT,
+            copy=True,
+            ignore=[
+                ".git",
+                ".venv",
+                ".env",
+                ".env.*",
+                "__pycache__",
+                ".pytest_cache",
+                "outputs",
+                "*.pyc",
+            ],
+        )
+        image = image.run_commands(f"python -m pip install --no-deps -e {REMOTE_PROJECT}")
+    return image.workdir(REMOTE_PROJECT)
+
+
+app = modal.App(APP_NAME)
+volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
+cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
+training_image = _training_image()
+secrets = [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])]
+
+
+@app.function(
+    image=modal.Image.debian_slim(python_version="3.11"),
+    timeout=60 * 20,
+    volumes={RUNS_DIR: volume},
+)
+def upload_sft_jsonl(relative_path: str, content: str) -> str:
+    target = RUNS_DIR / relative_path
+    target.parent.mkdir(parents=True, exist_ok=True)
+    target.write_text(content, encoding="utf-8")
+    volume.commit()
+    return str(target)
+
+
+@app.function(
+    image=training_image,
+    gpu="L4",
+    timeout=12 * 60 * 60,
+    volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
+    secrets=secrets,
+)
+def train_cybersecurity_owasp_sft(
+    train_jsonl: str = "/runs/sft/train.jsonl",
+    validation_jsonl: str = "/runs/sft/validation.jsonl",
+    output_repo_id: str = "",
+    model_name: str = DEFAULT_GEMMA_MODEL,
+    run_name: str = "",
+    max_seq_length: int = 4096,
+    max_steps: int = 100,
+    num_train_epochs: float = 1.0,
+    per_device_train_batch_size: int = 1,
+    gradient_accumulation_steps: int = 16,
+    learning_rate: float = 2e-5,
+    lora_rank: int = 32,
+    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
+    trackio_project: str = "CyberSecurity_OWASP-sft",
+    push_to_hub: bool = False,
+) -> dict[str, Any]:
+    import inspect
+
+    from datasets import load_dataset
+    from huggingface_hub import snapshot_download, whoami
+    from trl import SFTConfig, SFTTrainer
+    from trl.chat_template_utils import add_response_schema
+    from unsloth import FastVisionModel
+
+    model_name = _ensure_gemma4_model(model_name)
+    cache_env = _configure_modal_cache_env()
+    hf_token = os.environ.get("HF_TOKEN")
+    if not hf_token:
+        raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
+
+    user = whoami(token=hf_token)["name"]
+    output_repo_id = output_repo_id or (
+        f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-lora"
+    )
+    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+    run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
+    output_dir = RUNS_DIR / run_name
+    adapter_dir = output_dir / "sft_adapter"
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    data_files = {"train": train_jsonl}
+    validation_path = pathlib.Path(validation_jsonl)
+    has_validation = validation_path.exists() and validation_path.stat().st_size > 0
+    if has_validation:
+        data_files["validation"] = validation_jsonl
+    dataset = load_dataset("json", data_files=data_files)
+
+    print(f"SFT run name: {run_name}")
+    print(f"Model: {model_name}")
+    print(f"Train JSONL: {train_jsonl}")
+    print(f"Validation JSONL: {validation_jsonl if has_validation else '(none)'}")
+    print(f"Output adapter dir: {adapter_dir}")
+    print(f"Output repo: https://huggingface.co/{output_repo_id}")
+    print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
+    print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
+
+    try:
+        snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
+        cache_volume.commit()
+    except Exception as exc:
+        print(f"Model snapshot prefetch skipped; loader will retry directly. Error: {exc!r}")
+
+    model_api = FastVisionModel
+    model, tokenizer = model_api.from_pretrained(
+        model_name=model_name,
+        max_seq_length=max_seq_length,
+        load_in_4bit=False,
+        fast_inference=False,
+        cache_dir=str(HF_HUB_CACHE_DIR),
+        token=hf_token,
+    )
+    try:
+        tokenizer = add_response_schema(tokenizer)
+    except Exception as exc:
+        print(f"Tokenizer response schema add skipped: {exc!r}")
+
+    model = model_api.get_peft_model(
+        model,
+        r=lora_rank,
+        target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],
+        lora_alpha=lora_rank * 2,
+        use_gradient_checkpointing="unsloth",
+        random_state=3407,
+    )
+    if hasattr(model_api, "for_training"):
+        model_api.for_training(model)
+
+    sft_values = {
+        "output_dir": str(output_dir),
+        "max_seq_length": max_seq_length,
+        "max_steps": max_steps,
+        "num_train_epochs": num_train_epochs,
+        "per_device_train_batch_size": per_device_train_batch_size,
+        "gradient_accumulation_steps": gradient_accumulation_steps,
+        "learning_rate": learning_rate,
+        "logging_steps": 1,
+        "save_steps": max(10, max_steps),
+        "report_to": "trackio",
+        "project": trackio_project,
+        "trackio_space_id": trackio_space_id,
+        "run_name": run_name,
+        "assistant_only_loss": True,
+        "packing": False,
+        "gradient_checkpointing": True,
+        "gradient_checkpointing_kwargs": {"use_reentrant": False},
+        "push_to_hub": push_to_hub,
+        "hub_model_id": output_repo_id,
+        "hub_private_repo": True,
+    }
+    sft_parameters = set(inspect.signature(SFTConfig).parameters)
+    skipped = sorted(set(sft_values) - sft_parameters)
+    if skipped:
+        print(f"Skipping unsupported SFTConfig keys: {skipped}")
+    training_args = SFTConfig(
+        **{key: value for key, value in sft_values.items() if key in sft_parameters}
+    )
+
+    trainer_values = {
+        "model": model,
+        "processing_class": tokenizer,
+        "args": training_args,
+        "train_dataset": dataset["train"],
+        "eval_dataset": dataset["validation"] if has_validation else None,
+    }
+    trainer_parameters = set(inspect.signature(SFTTrainer).parameters)
+    skipped_trainer = sorted(
+        key for key, value in trainer_values.items() if key not in trainer_parameters and value is not None
+    )
+    if skipped_trainer:
+        print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
+    trainer = SFTTrainer(
+        **{
+            key: value
+            for key, value in trainer_values.items()
+            if value is not None and key in trainer_parameters
+        }
+    )
+    trainer.train()
+    trainer.save_model(str(adapter_dir))
+    if push_to_hub:
+        trainer.push_to_hub()
+    volume.commit()
+    cache_volume.commit()
+    return {
+        "run_name": run_name,
+        "model_name": model_name,
+        "adapter_dir": str(adapter_dir),
+        "output_repo_id": output_repo_id,
+        "train_jsonl": train_jsonl,
+        "validation_jsonl": validation_jsonl if has_validation else "",
+        "max_steps": max_steps,
+        "push_to_hub": push_to_hub,
+        "trackio_space_id": trackio_space_id,
+        "trackio_project": trackio_project,
+    }
+
+
+def _git_sha(default: str = "nogit") -> str:
+    try:
+        return subprocess.check_output(
+            [
+                "git",
+                "-c",
+                f"safe.directory={PROJECT_ROOT.as_posix()}",
+                "rev-parse",
+                "HEAD",
+            ],
+            cwd=PROJECT_ROOT,
+            text=True,
+            stderr=subprocess.DEVNULL,
+        ).strip()
+    except Exception:
+        return default
+
+
+@app.local_entrypoint()
+def main(
+    mode: str = "train",
+    local_train_path: str = "outputs/sft/train.jsonl",
+    local_validation_path: str = "outputs/sft/validation.jsonl",
+    train_jsonl: str = "/runs/sft/train.jsonl",
+    validation_jsonl: str = "/runs/sft/validation.jsonl",
+    output_repo_id: str = "",
+    model_name: str = DEFAULT_GEMMA_MODEL,
+    run_name: str = "",
+    max_seq_length: int = 4096,
+    max_steps: int = 100,
+    num_train_epochs: float = 1.0,
+    per_device_train_batch_size: int = 1,
+    gradient_accumulation_steps: int = 16,
+    learning_rate: float = 2e-5,
+    lora_rank: int = 32,
+    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
+    trackio_project: str = "CyberSecurity_OWASP-sft",
+    source_mode: str = "local",
+    repo_url: str = PUBLIC_REPO_URL,
+    repo_branch: str = PUBLIC_REPO_BRANCH,
+    detach: bool = False,
+    push_to_hub: bool = False,
+) -> None:
+    del source_mode, repo_url, repo_branch  # consumed during image construction
+    model_name = _ensure_gemma4_model(model_name)
+    if mode not in {"upload", "train"}:
+        raise ValueError("mode must be 'upload' or 'train'")
+
+    local_train = pathlib.Path(local_train_path)
+    local_validation = pathlib.Path(local_validation_path)
+    if local_train.exists():
+        uploaded = upload_sft_jsonl.remote(
+            "sft/train.jsonl",
+            local_train.read_text(encoding="utf-8"),
+        )
+        print(f"Uploaded train JSONL: {uploaded}")
+        train_jsonl = uploaded
+    if local_validation.exists():
+        uploaded_validation = upload_sft_jsonl.remote(
+            "sft/validation.jsonl",
+            local_validation.read_text(encoding="utf-8"),
+        )
+        print(f"Uploaded validation JSONL: {uploaded_validation}")
+        validation_jsonl = uploaded_validation
+    if mode == "upload":
+        return
+
+    if not run_name:
+        stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
+        run_name = f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}-{_git_sha()[:8]}"
+
+    kwargs = dict(
+        train_jsonl=train_jsonl,
+        validation_jsonl=validation_jsonl,
+        output_repo_id=output_repo_id,
+        model_name=model_name,
+        run_name=run_name,
+        max_seq_length=max_seq_length,
+        max_steps=max_steps,
+        num_train_epochs=num_train_epochs,
+        per_device_train_batch_size=per_device_train_batch_size,
+        gradient_accumulation_steps=gradient_accumulation_steps,
+        learning_rate=learning_rate,
+        lora_rank=lora_rank,
+        trackio_space_id=trackio_space_id,
+        trackio_project=trackio_project,
+        push_to_hub=push_to_hub,
+    )
+    print(f"SFT run name: {run_name}")
+    print(f"Train JSONL: {train_jsonl}")
+    print(f"Validation JSONL: {validation_jsonl}")
+    print(f"Hub push enabled: {push_to_hub}")
+    if detach:
+        call = train_cybersecurity_owasp_sft.spawn(**kwargs)
+        print(f"Spawned Modal SFT call: {call.object_id}")
+    else:
+        result = train_cybersecurity_owasp_sft.remote(**kwargs)
+        print(json.dumps(result, indent=2, sort_keys=True))
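`modal_train_sft.py` defends against TRL API drift by filtering every config and trainer kwarg through `inspect.signature` before construction. A self-contained sketch of that pattern follows; `make_config` is a hypothetical stand-in, not a repo function.

```python
import inspect
from typing import Any, Callable


def filter_supported_kwargs(factory: Callable[..., Any], values: dict[str, Any]) -> dict[str, Any]:
    """Keep only the keys the target signature accepts, logging what was dropped."""
    supported = set(inspect.signature(factory).parameters)
    skipped = sorted(set(values) - supported)
    if skipped:
        print(f"Skipping unsupported {factory.__name__} keys: {skipped}")
    return {key: value for key, value in values.items() if key in supported}


def make_config(output_dir: str, max_steps: int = 10) -> dict[str, Any]:
    # Stand-in for SFTConfig: accepts two of the three keys passed below.
    return {"output_dir": output_dir, "max_steps": max_steps}


config = make_config(
    **filter_supported_kwargs(
        make_config,
        {"output_dir": "/runs/demo", "max_steps": 100, "not_a_real_arg": True},
    )
)
```

The launcher applies the same idea to both `SFTConfig` and `SFTTrainer`, so a TRL upgrade that renames an argument degrades to a logged skip instead of a `TypeError`.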
tests/test_reward_config.py
CHANGED
@@ -68,6 +68,45 @@ def test_reward_config_hash_and_flattened_values_are_deterministic(monkeypatch):
     assert rows["hidden_file_probe"]["terminate"] is True
 
 
+def test_reward_ablation_configs_extend_default_and_have_unique_hashes(monkeypatch):
+    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
+    paths = [
+        Path("training/configs/reward_ablations/A0_sparse_terminal_only.yaml"),
+        Path("training/configs/reward_ablations/A2_reduced_shaping.yaml"),
+        Path("training/configs/reward_ablations/A6_visible_gate.yaml"),
+        Path("training/configs/reward_ablations/A7_evidence045.yaml"),
+        Path("training/configs/reward_ablations/A3_no_speed_token.yaml"),
+    ]
+
+    settings_by_name = {path.name: load_reward_settings(path) for path in paths}
+    hashes = {reward_config_hash(settings) for settings in settings_by_name.values()}
+
+    assert len(hashes) == len(paths)
+    assert settings_by_name["A0_sparse_terminal_only.yaml"].shaping_weight == 0.0
+    assert settings_by_name["A0_sparse_terminal_only.yaml"].value("progressive_cap") == 0.0
+    assert settings_by_name["A0_sparse_terminal_only.yaml"].value("terminal_cap") == 12.0
+    assert settings_by_name["A2_reduced_shaping.yaml"].shaping_weight == 0.35
+    assert settings_by_name["A2_reduced_shaping.yaml"].value("progressive_cap") == 2.5
+    assert settings_by_name["A6_visible_gate.yaml"].value("visible_tests_improved") == 0.0
+    assert settings_by_name["A6_visible_gate.yaml"].value("app_boots_after_patch") == 0.10
+    assert settings_by_name["A7_evidence045.yaml"].value("local_evidence_found") == 0.45
+    assert settings_by_name["A3_no_speed_token.yaml"].value("speed_bonus") == 0.0
+    assert compute_token_penalty(850, settings_by_name["A3_no_speed_token.yaml"]) == 0.0
+
+
+def test_reward_config_run_config_includes_variant(monkeypatch):
+    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
+    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "abl-a2-shape035")
+
+    config = reward_config_run_config(
+        load_reward_settings("training/configs/reward_ablations/A2_reduced_shaping.yaml")
+    )
+
+    assert config["reward_variant"] == "abl-a2-shape035"
+    assert config["reward_config_source_name"] == "A2_reduced_shaping.yaml"
+    assert config["reward_config__shaping_weight__stage_value"] == 0.35
+
+
 def test_reward_config_rejects_missing_descriptions(monkeypatch):
     config_path = Path("outputs/test_reward_config_bad.yaml")
     config_path.parent.mkdir(parents=True, exist_ok=True)
tests/test_sft_dataset_generation.py
ADDED
@@ -0,0 +1,142 @@
+import importlib.util
+import json
+import os
+import sys
+import uuid
+from pathlib import Path
+
+from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
+from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
+    CybersecurityOwaspEnvironment,
+)
+
+
+MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "generate_sft_dataset.py"
+SPEC = importlib.util.spec_from_file_location("generate_sft_dataset", MODULE_PATH)
+generate_sft_dataset = importlib.util.module_from_spec(SPEC)
+assert SPEC.loader is not None
+sys.modules[SPEC.name] = generate_sft_dataset
+SPEC.loader.exec_module(generate_sft_dataset)
+
+
+def _isolated_out_dir(label: str) -> Path:
+    root = Path("outputs") / "sft_dataset_tests" / f"{label}_{uuid.uuid4().hex[:8]}"
+    workspace_root = root / "workspaces"
+    workspace_root.mkdir(parents=True, exist_ok=True)
+    os.environ["CYBERSECURITY_OWASP_WORKSPACE_ROOT"] = str(workspace_root)
+    return root / "sft"
+
+
+def test_extracts_and_validates_action_json():
+    action = generate_sft_dataset.parse_action_text(
+        '```json\n{"tool_name":"inspect_policy_graph","arguments":{}}\n```'
+    )
+
+    assert isinstance(action, CyberSecurityOWASPAction)
+    assert action.tool_name == "inspect_policy_graph"
+
+
+def test_prompt_uses_visible_observation_only():
+    _isolated_out_dir("prompt")
+    env = CybersecurityOwaspEnvironment()
+    try:
+        obs = env.reset(seed=501, split="train", difficulty=0)
+        prompt = generate_sft_dataset.build_user_prompt(obs, [])
+    finally:
+        env.close()
+
+    lowered = prompt.lower()
+    assert "hidden_facts" not in lowered
+    assert "oracle_hidden_focus" not in lowered
+    assert "reward_engine" not in lowered
+    assert "validators.py" not in lowered
+    assert "tests/hidden" not in lowered
+    assert "hidden tests" not in lowered
+
+
+def test_chat_row_matches_conversational_sft_shape():
+    _isolated_out_dir("chat_row")
+    env = CybersecurityOwaspEnvironment()
+    try:
+        obs = env.reset(seed=502, split="train", difficulty=0)
+        messages = generate_sft_dataset.build_chat_messages(obs, [])
+        action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
+        row = generate_sft_dataset.make_chat_row(
+            messages=messages,
+            action=action,
+            metadata={
+                "target_model": generate_sft_dataset.DEFAULT_TARGET_MODEL,
+                "teacher_model": generate_sft_dataset.DEFAULT_TEACHER_MODEL,
+                "seed": 502,
+            },
+        )
+    finally:
+        env.close()
+
+    assert [message["role"] for message in row["messages"]] == [
+        "system",
+        "user",
+        "assistant",
+    ]
+    assert json.loads(row["messages"][-1]["content"]) == action.model_dump()
+    assert row["metadata"]["target_model"] == "unsloth/gemma-4-E2B-it"
+
+
+def test_dry_run_oracle_creates_chat_jsonl_without_network():
+    out_dir = _isolated_out_dir("dry_run")
+    manifest = generate_sft_dataset.generate_dataset(
+        generate_sft_dataset.DatasetConfig(
+            episodes=2,
+            validation_episodes=1,
+            out_dir=out_dir,
+            dry_run_oracle=True,
+        )
+    )
+
+    assert manifest["episodes_attempted"] == 3
+    assert manifest["episodes_accepted"] == 3
+    assert (out_dir / "train.jsonl").exists()
+    assert (out_dir / "validation.jsonl").exists()
+    train_rows = [
+        json.loads(line)
+        for line in (out_dir / "train.jsonl").read_text(encoding="utf-8").splitlines()
+        if line.strip()
+    ]
+    validation_rows = [
+        json.loads(line)
+        for line in (out_dir / "validation.jsonl").read_text(encoding="utf-8").splitlines()
+        if line.strip()
+    ]
+    assert train_rows
+    assert validation_rows
+    assert all(row["messages"][-1]["role"] == "assistant" for row in train_rows)
+
+
+def test_saved_oracle_trajectory_replays_to_success():
+    out_dir = _isolated_out_dir("replay")
+    generate_sft_dataset.generate_dataset(
+        generate_sft_dataset.DatasetConfig(
+            episodes=1,
+            out_dir=out_dir,
+            dry_run_oracle=True,
+        )
+    )
+    trajectory_path = next((out_dir / "trajectories").glob("train_seed*.json"))
+    trajectory = json.loads(trajectory_path.read_text(encoding="utf-8"))
+
+    env = CybersecurityOwaspEnvironment()
+    try:
+        env.reset(
+            seed=int(trajectory["seed"]),
+            split=trajectory["split"],
+            difficulty=int(trajectory["difficulty"]),
+        )
+        final = None
+        for action_data in trajectory["actions"]:
+            final = env.step(CyberSecurityOWASPAction(**action_data))
+        assert final is not None
+        assert final.done is True
+        assert env.state.success is True
+        assert not env.state.anti_cheat_flags
+    finally:
+        env.close()
tests/test_trackio_utils.py
CHANGED
@@ -39,6 +39,10 @@ def test_canonical_tracking_fields_exist_and_are_numeric_where_expected():
     assert isinstance(fields["reward/hidden_authz_pass_rate"], float)
     assert isinstance(fields["reward/normal_flow_pass_rate"], float)
     assert isinstance(fields["reward/public_hidden_gap"], float)
+    assert isinstance(fields["reward/dense_to_terminal_ratio"], float)
+    assert isinstance(fields["episode/time_to_first_patch"], float)
+    assert isinstance(fields["episode/repeated_action_rate"], float)
+    assert isinstance(fields["episode/patch_to_hidden_success_conversion_rate"], float)
     assert isinstance(fields["skill/exploit_to_patch_alignment"], float)
 
     metrics = aggregate_episode_metrics([record])
@@ -156,11 +160,13 @@ def test_log_reward_config_emits_scalar_values_and_table(monkeypatch):
     monkeypatch.setitem(sys.modules, "trackio", fake_trackio)
     monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_STAGE", "early")
+    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "abl-test")
 
     settings = load_reward_settings()
     summary = log_reward_config(settings, step=0)
 
     assert fake_trackio.config["reward_config_hash"] == summary["reward_config_hash"]
+    assert fake_trackio.config["reward_variant"] == "abl-test"
     assert fake_trackio.config["reward_config_values"]["policy_inspected"]["value"] == 0.30
     assert fake_trackio.config["reward_config__policy_inspected__value"] == 0.30
     scalar_payload = next(payload for payload, _step in logged if "reward_config/policy_inspected/value" in payload)
training/configs/reward_ablations/A0_sparse_terminal_only.yaml
ADDED
@@ -0,0 +1,97 @@
+extends: ../grpo_small.yaml
+reward:
+  stage: early
+  terminal_cap:
+    value: 12.0
+  progressive_cap:
+    value: 0.0
+  penalty_floor:
+    value: -10.0
+  train_cap:
+    value: 12.0
+  shaping_weight:
+    early: 0.0
+    middle: 0.0
+    late: 0.0
+    final: 0.0
+  policy_inspected:
+    value: 0.0
+  route_map_inspected:
+    value: 0.0
+    cap: 0.0
+  relevant_file_inspected:
+    value: 0.0
+    cap: 0.0
+  local_evidence_found:
+    value: 0.0
+    cap: 0.0
+  diagnosis_correct:
+    value: 0.0
+  patch_applies:
+    value: 0.0
+  app_boots_after_patch:
+    value: 0.0
+  visible_tests_improved:
+    value: 0.0
+    cap: 0.0
+  public_routes_visible_pass:
+    value: 0.0
+  step_penalty:
+    early: 0.0
+    middle: 0.0
+    late: 0.0
+    final: 0.0
+    cap: 0.0
+  speed_bonus:
+    value: 0.0
+  token_penalty:
+    early: 0.0
+    middle: 0.0
+    late: 0.0
+    final: 0.0
+    cap: 0.0
+  invalid_action:
+    value: 0.0
+  repeated_invalid_action:
+    value: 0.0
+  repeated_low_value_action:
+    value: 0.0
+  no_progress_action:
+    value: 0.0
+  noop_action:
+    value: 0.0
+  repeated_file_read:
+    value: 0.0
+  repeated_local_request:
+    value: 0.0
+  repeated_visible_tests:
+    value: 0.0
+  patch_before_policy:
+    value: 0.0
+  submit_without_patch:
+    value: 0.0
+  submit_without_visible_tests:
+    value: 0.0
+  excessive_patch_attempt:
+    value: 0.0
+  too_many_files_changed:
+    value: 0.0
+  oversized_patch:
+    value: 0.0
+    severe_value: 0.0
+  hidden_file_probe:
+    value: -10.0
+    terminate: true
+  test_file_edit:
+    value: -10.0
+    terminate: true
+  host_path_escape:
+    value: -10.0
+    terminate: true
+  external_url_attempt:
+    value: -10.0
+    terminate: true
+  hardcoded_identifier:
+    value: -4.0
+  deny_all_patch:
+    value: -5.0
training/configs/reward_ablations/A2_reduced_shaping.yaml
ADDED
@@ -0,0 +1,12 @@
+extends: ../grpo_small.yaml
+reward:
+  stage: early
+  progressive_cap:
+    value: 2.5
+  train_cap:
+    value: 18.0
+  shaping_weight:
+    early: 0.35
+    middle: 0.35
+    late: 0.35
+    final: 0.35
training/configs/reward_ablations/A3_no_speed_token.yaml
ADDED
@@ -0,0 +1,17 @@
+extends: ../grpo_small.yaml
+reward:
+  stage: early
+  step_penalty:
+    early: -0.002
+    middle: -0.002
+    late: -0.002
+    final: -0.002
+    cap: -0.25
+  speed_bonus:
+    value: 0.0
+  token_penalty:
+    early: 0.0
+    middle: 0.0
+    late: 0.0
+    final: 0.0
+    cap: 0.0
training/configs/reward_ablations/A6_visible_gate.yaml
ADDED
@@ -0,0 +1,10 @@
+extends: ../grpo_small.yaml
+reward:
+  stage: early
+  app_boots_after_patch:
+    value: 0.10
+  visible_tests_improved:
+    value: 0.0
+    cap: 0.20
+  public_routes_visible_pass:
+    value: 0.10
training/configs/reward_ablations/A7_evidence045.yaml
ADDED
@@ -0,0 +1,6 @@
+extends: ../grpo_small.yaml
+reward:
+  stage: early
+  local_evidence_found:
+    value: 0.45
+    cap: 0.45
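Each ablation file above states only the knobs it changes and inherits everything else from `training/configs/grpo_small.yaml` via its `extends:` key. Below is a minimal sketch of the merge semantics these files rely on, assuming a recursive mapping merge in which child values win; a production loader would also need to reject `extends` cycles.

```python
from pathlib import Path
from typing import Any

import yaml


def deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
    """Merge child mappings key by key; child scalars replace parent values."""
    merged = dict(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


def load_with_extends(path: Path) -> dict[str, Any]:
    """Load a YAML file, recursively applying its relative `extends` parent."""
    raw = yaml.safe_load(path.read_text(encoding="utf-8")) or {}
    parent = raw.pop("extends", None)
    if parent:
        base = load_with_extends((path.parent / parent).resolve())
        raw = deep_merge(base, raw)
    return raw
```

Under this merge, `A7_evidence045.yaml` changes only `reward.local_evidence_found` while every sibling key keeps its default, which is what gives each ablation the distinct hash asserted in `test_reward_ablation_configs_extend_default_and_have_unique_hashes`.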
training/trackio_utils.py
CHANGED
@@ -17,6 +17,7 @@ RUN_SCENARIO_FIELDS = (
     "run/base_model",
     "run/algo",
     "run/reward_version",
+    "run/reward_variant",
     "run/env_version",
     "scenario/seed",
     "scenario/template_id",
@@ -136,6 +137,16 @@ CANONICAL_TRACKIO_SIGNALS = tuple(
 
 DERIVED_TRACKIO_METRICS = (
     "reward/public_hidden_gap",
+    "reward/visible_hidden_gap",
+    "reward/dense_total",
+    "reward/dense_to_terminal_ratio",
+    "episode/time_to_first_evidence",
+    "episode/time_to_first_patch",
+    "episode/repeated_action_rate",
+    "episode/submit_without_evidence_rate",
+    "episode/hardcoded_identifier_rate",
+    "episode/deny_all_patch_rate",
+    "episode/patch_to_hidden_success_conversion_rate",
     "cheat/score",
 )
 
@@ -545,6 +556,7 @@ def episode_record_from_state(
         "run/reward_config_hash": context.get("reward_config_hash", ""),
         "run/reward_mode": context.get("reward_mode", ""),
         "run/reward_stage": context.get("reward_stage", ""),
+        "run/reward_variant": context.get("reward_variant", ""),
         "run/env_version": context.get("env_version", "0.1.0"),
         "episode_id": getattr(state, "episode_id", ""),
         "task_id": getattr(state, "task_id", ""),
@@ -663,6 +675,16 @@ def episode_to_tracking_fields(episode: Any) -> dict[str, Any]:
         -_float(final_reward.get("safety")),
     )
     fields["reward/public_hidden_gap"] = visible_rate - hidden_rate
+    fields["reward/visible_hidden_gap"] = fields["reward/public_hidden_gap"]
+    fields["reward/dense_total"] = (
+        fields["reward/total"] - fields["reward/terminal_15"]
+    )
+    terminal_denominator = abs(fields["reward/terminal_15"])
+    fields["reward/dense_to_terminal_ratio"] = (
+        fields["reward/dense_total"] / terminal_denominator
+        if terminal_denominator > 1e-9
+        else fields["reward/dense_total"]
+    )
 
     fields["skill/valid_action_rate"] = max(0.0, min(1.0, valid_action_rate))
     fields["skill/discovery_success"] = fields["reward/bug_classification_correct"]
@@ -698,11 +720,33 @@ def episode_to_tracking_fields(episode: Any) -> dict[str, Any]:
     fields["episode/patch_attempt_count"] = _float(
         record.get("patch_attempt_count", tool_names.count("patch_file"))
     )
+    fields["episode/time_to_first_evidence"] = fields["skill/first_valid_exploit_step"]
+    fields["episode/time_to_first_patch"] = _first_tool_step(
+        actions,
+        {"patch_file"},
+        observations,
+    )
+    action_keys = [
+        json.dumps(action, sort_keys=True, default=str)
+        for action in actions
+    ]
+    repeated_actions = len(action_keys) - len(set(action_keys))
+    fields["episode/repeated_action_rate"] = repeated_actions / max(1.0, float(len(action_keys)))
     fields["episode/submit_without_test_rate"] = 1.0 if (
         "submit_fix" in tool_names and "patch_file" in tool_names and "run_visible_tests" not in tool_names
     ) else 0.0
+    fields["episode/submit_without_evidence_rate"] = 1.0 if (
+        "submit_fix" in tool_names and fields["episode/time_to_first_evidence"] < 0.0
+    ) else 0.0
     fields["episode/cheat_attempt_rate"] = 1.0 if anti_flags else 0.0
     fields["episode/oversecure_rate"] = 1.0 if bool(quality.get("deny_all")) else 0.0
+    fields["episode/hardcoded_identifier_rate"] = 1.0 if (
+        "hardcoded_hidden_identifier" in anti_flags or bool(quality.get("hardcoded"))
+    ) else 0.0
+    fields["episode/deny_all_patch_rate"] = fields["episode/oversecure_rate"]
+    fields["episode/patch_to_hidden_success_conversion_rate"] = (
+        hidden_rate if "patch_file" in tool_names else 0.0
+    )
 
     fields["cheat/hidden_file_read_attempt"] = 1.0 if (
         "blocked_file_probe" in anti_flags and "hidden" in path_text
@@ -802,6 +846,15 @@ def train_metric_aliases(metrics: Mapping[str, Any]) -> dict[str, float]:
         "train/reward_token_penalty_mean": _float(metrics.get("reward/token_penalty")),
         "train/reward_speed_bonus_mean": _float(metrics.get("reward/speed_bonus")),
         "train/reward_behavior_penalty_mean": _float(metrics.get("reward/behavior_penalty")),
+        "train/dense_to_terminal_ratio": _float(metrics.get("reward/dense_to_terminal_ratio")),
+        "train/visible_hidden_gap": _float(metrics.get("reward/visible_hidden_gap")),
+        "train/repeated_action_rate": _float(metrics.get("episode/repeated_action_rate")),
+        "train/submit_without_evidence_rate": _float(metrics.get("episode/submit_without_evidence_rate")),
+        "train/hardcoded_identifier_rate": _float(metrics.get("episode/hardcoded_identifier_rate")),
+        "train/deny_all_patch_rate": _float(metrics.get("episode/deny_all_patch_rate")),
+        "train/patch_to_hidden_success_conversion_rate": _float(
+            metrics.get("episode/patch_to_hidden_success_conversion_rate")
+        ),
         "train/success_rate": _float(metrics.get("skill/patch_success")),
         "train/exploit_block_rate": _float(metrics.get("reward/hidden_authz_pass_rate")),
         "train/regression_preservation_rate": _float(metrics.get("reward/normal_flow_pass_rate")),