"""Generate verifier-gated SFT data for CyberSecurity_OWASP.
The default path asks a larger Hugging Face-hosted teacher model for one JSON
action at a time, executes those actions in the real environment, and keeps
only trajectories that pass the local deterministic verifier. The
``--dry-run-oracle`` path is intentionally network-free and exists for CI and
smoke tests.
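
Example invocations (these flags are all defined by this script's parser):

    python scripts/generate_sft_dataset.py --dry-run-oracle --episodes 2
    python scripts/generate_sft_dataset.py --difficulty-levels 0,1,2 --episodes 25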
"""
from __future__ import annotations
import argparse
import json
import os
import statistics
import subprocess
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
from CyberSecurity_OWASP.models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation
from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
CybersecurityOwaspEnvironment,
)
from CyberSecurity_OWASP.validators import detect_cheating
DEFAULT_TEACHER_MODEL = "deepseek-ai/DeepSeek-V4-Pro"
DEFAULT_TARGET_MODEL = "unsloth/gemma-4-E2B-it"
TRAINING_SYSTEM_PROMPT = (
"You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
"OpenEnv environment. Use only the listed local tools. Do not target real "
"systems. Work step by step: inspect policy and generated code, reproduce "
"the authorization issue locally, submit a policy-tied diagnosis, patch the "
"generated app, run visible tests, then submit the fix. Return exactly one "
"JSON action object and no markdown."
)
BANNED_PROMPT_MARKERS = (
"hidden_facts",
"oracle_hidden_focus",
"reward_engine",
"validators.py",
"rewards.py",
"tests/hidden",
"hidden tests",
".git",
)
RISKY_ARGUMENT_MARKERS = (
"hidden",
"oracle",
"reward_engine",
"validators.py",
"rewards.py",
".git",
"..",
)
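# Both tuples are matched as lowercase substrings: BANNED_PROMPT_MARKERS against
# outgoing teacher prompts (_assert_prompt_is_safe) and RISKY_ARGUMENT_MARKERS
# against serialized action arguments (preflight_action).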
@dataclass
class DatasetConfig:
teacher_model: str = DEFAULT_TEACHER_MODEL
target_model: str = DEFAULT_TARGET_MODEL
split: str = "train"
difficulty: int = 0
seed_start: int = 0
episodes: int = 100
validation_episodes: int = 0
out_dir: Path = Path("outputs/sft")
max_steps: int = 40
max_teacher_retries: int = 2
max_tokens: int = 768
temperature: float = 0.2
top_p: float = 0.95
dry_run_oracle: bool = False
workers: int = 0
min_terminal_reward: float = 12.0
difficulty_levels: tuple[int, ...] = ()
difficulty_buckets: int = 0
push_to_hub: bool = False
dataset_repo_id: str = "Humanlearning/CyberSecurity_OWASP-sft-dataset"
hub_private: bool = False
progress: bool = False
class HuggingFaceTeacher:
"""Small wrapper around Hugging Face chat completion."""
def __init__(
self,
*,
model: str,
token: str,
max_tokens: int,
temperature: float,
top_p: float,
) -> None:
try:
from huggingface_hub import InferenceClient
        except ImportError as exc:  # pragma: no cover - dependency availability is smoke-tested separately
raise RuntimeError(
"huggingface_hub is required for teacher generation. Install project "
"dependencies or use --dry-run-oracle for local CI."
) from exc
self.model = model
self.max_tokens = int(max_tokens)
self.temperature = float(temperature)
self.top_p = float(top_p)
self.client = InferenceClient(token=token)
def complete(self, messages: list[dict[str, str]]) -> str:
response = self.client.chat_completion(
model=self.model,
messages=messages,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=self.top_p,
)
return _chat_response_content(response)
def _chat_response_content(response: Any) -> str:
choices = getattr(response, "choices", None)
if choices:
message = getattr(choices[0], "message", None)
content = getattr(message, "content", None)
if content is not None:
return str(content)
if isinstance(response, dict):
choices = response.get("choices") or []
if choices:
message = choices[0].get("message") or {}
return str(message.get("content", ""))
return str(response)
def extract_first_json_object(text: str) -> dict[str, Any] | None:
"""Extract the first JSON object from raw teacher text."""
stripped = text.strip()
candidates = [stripped]
if "```" in stripped:
for part in stripped.split("```"):
candidate = part.strip()
if candidate.startswith("json"):
candidate = candidate[4:].strip()
candidates.append(candidate)
for candidate in candidates:
try:
loaded = json.loads(candidate)
except Exception:
continue
if isinstance(loaded, dict):
return loaded
start = stripped.find("{")
while start >= 0:
depth = 0
in_string = False
escaped = False
for index in range(start, len(stripped)):
char = stripped[index]
if in_string:
if escaped:
escaped = False
elif char == "\\":
escaped = True
elif char == '"':
in_string = False
continue
if char == '"':
in_string = True
elif char == "{":
depth += 1
elif char == "}":
depth -= 1
if depth == 0:
try:
loaded = json.loads(stripped[start : index + 1])
except Exception:
break
if isinstance(loaded, dict):
return loaded
start = stripped.find("{", start + 1)
return None
def parse_action_text(text: str) -> CyberSecurityOWASPAction:
data = extract_first_json_object(text)
if data is None:
raise ValueError("teacher did not return a JSON object")
return CyberSecurityOWASPAction(**data)
def action_to_json(action: CyberSecurityOWASPAction) -> str:
return json.dumps(action.model_dump(), separators=(",", ":"), sort_keys=True)
def _safe_observation_payload(
observation: CyberSecurityOWASPObservation,
recent_actions: list[dict[str, Any]],
) -> dict[str, Any]:
return {
"phase": observation.phase,
"task_brief": observation.task_brief,
"scenario_prompt": observation.scenario_prompt,
"available_actions": observation.available_actions,
"last_tool_result": observation.last_tool_result,
"last_action_valid": observation.last_action_valid,
"last_action_error": observation.last_action_error,
"visible_test_result": observation.visible_test_result,
"done_reason": observation.done_reason,
"recent_actions": recent_actions[-8:],
}
def build_user_prompt(
observation: CyberSecurityOWASPObservation,
recent_actions: list[dict[str, Any]],
retry_error: str | None = None,
) -> str:
payload = _safe_observation_payload(observation, recent_actions)
prompt = (
"Current CyberSecurity_OWASP observation, containing only information "
"available to the agent:\n"
f"{json.dumps(payload, indent=2, sort_keys=True)}\n\n"
"Choose the next action. Output exactly one JSON object with keys "
"`tool_name` and `arguments`. Do not include markdown or commentary."
)
if retry_error:
prompt += f"\nPrevious candidate was rejected safely: {retry_error}"
_assert_prompt_is_safe(prompt)
return prompt
def _assert_prompt_is_safe(prompt: str) -> None:
lowered = prompt.lower()
leaked = [marker for marker in BANNED_PROMPT_MARKERS if marker.lower() in lowered]
if leaked:
raise ValueError(f"prompt contains blocked marker(s): {', '.join(leaked)}")
def build_chat_messages(
observation: CyberSecurityOWASPObservation,
recent_actions: list[dict[str, Any]],
retry_error: str | None = None,
) -> list[dict[str, str]]:
return [
{"role": "system", "content": TRAINING_SYSTEM_PROMPT},
{"role": "user", "content": build_user_prompt(observation, recent_actions, retry_error)},
]
def make_chat_row(
*,
messages: list[dict[str, str]],
action: CyberSecurityOWASPAction,
metadata: dict[str, Any],
) -> dict[str, Any]:
return {
"messages": [
*messages,
{"role": "assistant", "content": action_to_json(action)},
],
"metadata": metadata,
}
def preflight_action(
env: CybersecurityOwaspEnvironment,
observation: CyberSecurityOWASPObservation,
action: CyberSecurityOWASPAction,
) -> tuple[bool, str]:
if action.tool_name not in observation.available_actions:
return False, f"{action.tool_name} is not allowed during {observation.phase}"
args = action.arguments or {}
flags = detect_cheating(env.state, action)
if flags:
return False, f"action triggered safety flags: {', '.join(flags)}"
arg_text = json.dumps(args, sort_keys=True, default=str).lower()
if any(marker in arg_text for marker in RISKY_ARGUMENT_MARKERS):
return False, "arguments reference blocked files or paths"
if action.tool_name == "read_file" and not args.get("path"):
return False, "read_file requires path"
if action.tool_name == "search_code" and not args.get("query"):
return False, "search_code requires query"
if action.tool_name == "patch_file":
path = str(args.get("path", ""))
if not path:
return False, "patch_file requires path"
if path.replace("\\", "/").startswith("tests/"):
return False, "patch_file cannot modify tests"
if not args.get("content") and not args.get("diff"):
return False, "patch_file requires content or diff"
if action.tool_name == "send_local_request":
path = str(args.get("path", ""))
if not path.startswith("/"):
return False, "send_local_request requires a local route path"
if action.tool_name == "compare_identities":
path = str(args.get("path", ""))
if not path.startswith("/"):
return False, "compare_identities requires a local route path"
if not args.get("first_user_id") or not args.get("second_user_id"):
return False, "compare_identities requires two user ids"
if action.tool_name == "submit_diagnosis":
required = ("bug_class", "route", "violated_policy_rule", "evidence_trace_ids", "fix_plan")
missing = [key for key in required if not args.get(key)]
if missing:
return False, f"submit_diagnosis missing: {', '.join(missing)}"
return True, ""
def _trace_id_from_observation(observation: CyberSecurityOWASPObservation) -> str:
try:
payload = json.loads(observation.last_tool_result)
except Exception:
return "req_001"
return str(payload.get("trace_id", "req_001"))
def _secure_invoice_source(env: CybersecurityOwaspEnvironment) -> str:
source = (Path(env.state.hidden_facts["workspace"]) / "app/routes/invoices.py").read_text(
encoding="utf-8"
)
return source.replace(
" # BUG: this only checks that the caller is authenticated. It forgets the\n"
" # owner/admin and tenant policy checks required by the policy graph.\n"
" return {\"status\": 200, \"body\": invoice}\n",
" if invoice[\"tenant_id\"] != actor[\"tenant_id\"]:\n"
" return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
" if invoice[\"owner_user_id\"] != actor[\"user_id\"] and not is_billing_admin(actor):\n"
" return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
" return {\"status\": 200, \"body\": invoice}\n",
)
def oracle_actions_for_state(
env: CybersecurityOwaspEnvironment,
evidence_trace_id: str | None = None,
) -> list[CyberSecurityOWASPAction]:
hidden = env.state.hidden_facts
trace_id = evidence_trace_id or "req_001"
return [
CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={}),
CyberSecurityOWASPAction(tool_name="list_routes", arguments={}),
CyberSecurityOWASPAction(
tool_name="read_file",
arguments={"path": "app/routes/invoices.py"},
),
CyberSecurityOWASPAction(
tool_name="send_local_request",
arguments={
"method": "GET",
"path": f"/invoices/{hidden['other_invoice_id']}",
"user_id": hidden["owner_user_id"],
},
),
CyberSecurityOWASPAction(
tool_name="submit_diagnosis",
arguments={
"bug_class": "idor_ownership_bug",
"route": "GET /invoices/{invoice_id}",
"violated_policy_rule": "Only the owner or a billing_admin in the same tenant may read invoices.",
"evidence_trace_ids": [trace_id],
"fix_plan": "Add tenant and owner/admin checks before returning invoice data.",
},
),
CyberSecurityOWASPAction(
tool_name="patch_file",
arguments={"path": "app/routes/invoices.py", "content": _secure_invoice_source(env)},
),
CyberSecurityOWASPAction(tool_name="run_visible_tests", arguments={}),
CyberSecurityOWASPAction(tool_name="submit_fix", arguments={}),
]
def _teacher_action(
*,
teacher: HuggingFaceTeacher,
env: CybersecurityOwaspEnvironment,
observation: CyberSecurityOWASPObservation,
recent_actions: list[dict[str, Any]],
config: DatasetConfig,
) -> tuple[CyberSecurityOWASPAction, list[dict[str, str]]]:
retry_error: str | None = None
for _ in range(config.max_teacher_retries + 1):
messages = build_chat_messages(observation, recent_actions, retry_error)
raw = teacher.complete(messages)
try:
action = parse_action_text(raw)
except Exception as exc:
retry_error = str(exc)
continue
ok, error = preflight_action(env, observation, action)
if ok:
return action, messages
retry_error = error
raise ValueError(retry_error or "teacher did not produce a usable action")
def _oracle_action(
*,
env: CybersecurityOwaspEnvironment,
observation: CyberSecurityOWASPObservation,
recent_actions: list[dict[str, Any]],
oracle_actions: list[CyberSecurityOWASPAction],
step_index: int,
) -> tuple[CyberSecurityOWASPAction, list[dict[str, str]]]:
action = oracle_actions[step_index]
messages = build_chat_messages(observation, recent_actions)
ok, error = preflight_action(env, observation, action)
if not ok:
raise ValueError(error)
return action, messages
def _terminal_checks_passed(env: CybersecurityOwaspEnvironment) -> bool:
verifier = env.state.verification_summary or {}
required = ("visible", "security", "regression", "public_routes", "patch_quality")
return all(bool((verifier.get(key) or {}).get("passed", False)) for key in required)
def _episode_reward(env: CybersecurityOwaspEnvironment) -> float:
if env.state.reward_history:
return float(env.state.reward_history[-1].get("terminal_total", 0.0))
return 0.0
def run_episode(
*,
seed: int,
split: str,
difficulty: int,
config: DatasetConfig,
teacher: HuggingFaceTeacher | None,
) -> dict[str, Any]:
env = CybersecurityOwaspEnvironment()
rows: list[dict[str, Any]] = []
trajectory_steps: list[dict[str, Any]] = []
recent_actions: list[dict[str, Any]] = []
try:
observation = env.reset(seed=seed, split=split, difficulty=difficulty)
oracle_actions = oracle_actions_for_state(env) if config.dry_run_oracle else []
for step_index in range(config.max_steps):
if observation.done:
break
if config.dry_run_oracle:
if step_index >= len(oracle_actions):
raise ValueError("oracle action script ended before terminal state")
if step_index == 4 and env.state.request_trace:
trace_id = _trace_id_from_observation(observation)
oracle_actions = oracle_actions_for_state(env, evidence_trace_id=trace_id)
action, messages = _oracle_action(
env=env,
observation=observation,
recent_actions=recent_actions,
oracle_actions=oracle_actions,
step_index=step_index,
)
else:
if teacher is None:
raise RuntimeError("teacher is required unless --dry-run-oracle is set")
action, messages = _teacher_action(
teacher=teacher,
env=env,
observation=observation,
recent_actions=recent_actions,
config=config,
)
step_number = step_index + 1
action_record = action.model_dump()
row = make_chat_row(
messages=messages,
action=action,
metadata={
"target_model": config.target_model,
"teacher_model": config.teacher_model,
"seed": seed,
"split": split,
"difficulty": difficulty,
"step": step_number,
"tool_name": action.tool_name,
"task_id": env.state.task_id,
"episode_id": env.state.episode_id,
"scenario_hash": env.state.scenario_hash,
},
)
next_observation = env.step(action)
trajectory_steps.append(
{
"step": step_number,
"prompt_messages": messages,
"action": action_record,
"observation": next_observation.model_dump(),
"reward_breakdown": dict(next_observation.reward_breakdown or {}),
}
)
if not next_observation.last_action_valid:
raise ValueError(next_observation.last_action_error or "invalid action")
if env.state.anti_cheat_flags:
raise ValueError(f"anti-cheat flags: {env.state.anti_cheat_flags}")
rows.append(row)
recent_actions.append(action_record)
observation = next_observation
if observation.done:
break
if not env.state.done:
raise ValueError("episode did not reach a terminal state")
if not env.state.success:
raise ValueError(env.state.failure_reason or "terminal verifier failed")
if env.state.step_count > config.max_steps:
raise ValueError("episode exceeded max steps")
if env.state.anti_cheat_flags:
raise ValueError("episode has anti-cheat flags")
if not _terminal_checks_passed(env):
raise ValueError("terminal verifier checks did not all pass")
final_reward = _episode_reward(env)
final_breakdown = dict(env.state.reward_history[-1]) if env.state.reward_history else {}
for row in rows:
row["metadata"].update(
{
"final_success": True,
"terminal_total": final_reward,
"total_reward": float(env.state.accumulated_reward),
"anti_cheat_flags": list(env.state.anti_cheat_flags),
"final_reward_breakdown": final_breakdown,
}
)
return {
"accepted": True,
"seed": seed,
"split": split,
"difficulty": difficulty,
"rows": rows,
"trajectory": {
"episode_id": env.state.episode_id,
"task_id": env.state.task_id,
"seed": seed,
"split": split,
"difficulty": difficulty,
"domain": env.state.domain,
"bug_family": env.state.bug_family,
"scenario_hash": env.state.scenario_hash,
"actions": [step["action"] for step in trajectory_steps],
"steps": trajectory_steps,
"reward_breakdown_by_step": list(env.state.reward_history),
"final_reward_breakdown": final_breakdown,
"total_reward": float(env.state.accumulated_reward),
"terminal_total": final_reward,
"success": True,
"failure_reason": None,
"anti_cheat_flags": list(env.state.anti_cheat_flags),
"verification_summary": env.state.verification_summary,
},
}
except Exception as exc:
return {
"accepted": False,
"seed": seed,
"split": split,
"difficulty": difficulty,
"reason": str(exc),
"rows": [],
"trajectory": {
"seed": seed,
"split": split,
"difficulty": difficulty,
"steps": trajectory_steps,
"actions": [step["action"] for step in trajectory_steps],
"success": bool(env.state.success),
"failure_reason": env.state.failure_reason or str(exc),
"anti_cheat_flags": list(env.state.anti_cheat_flags),
},
}
finally:
env.close()
def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w", encoding="utf-8") as handle:
for row in rows:
handle.write(json.dumps(row, sort_keys=True, default=str) + "\n")
def write_dataset_card(out_dir: Path, manifest: dict[str, Any], dataset_repo_id: str) -> Path:
card_path = out_dir / "README.md"
difficulty_levels = manifest.get("difficulty_levels", [])
reward_verification = manifest.get("reward_verification", {})
card = f"""---
license: apache-2.0
task_categories:
- text-generation
language:
- en
tags:
- cybersecurity
- owasp
- openenv
- tool-use
- sft
pretty_name: CyberSecurity_OWASP SFT Dataset
---
# CyberSecurity_OWASP SFT Dataset
This dataset contains verifier-gated supervised fine-tuning examples for the
`CyberSecurity_OWASP` OpenEnv environment. Each row teaches one step of the
defensive local AppSec workflow: inspect policy/code, reproduce a local
authorization failure, submit a policy-tied diagnosis, patch the generated app,
run visible tests, and submit the fix.
Every kept trajectory is executed against the real local environment and must
pass the deterministic reward verifier before rows are written.
## Intended Use
- Target SFT model: `{manifest.get("target_model", "")}`
- Teacher model: `{manifest.get("teacher_model", "")}`
- Dataset repo: `{dataset_repo_id}`
- Format: chat JSONL with `messages` and verifier metadata
- Dry-run oracle: `{manifest.get("dry_run_oracle", False)}`
## Curriculum Coverage
- Difficulty levels: `{difficulty_levels}`
- Episodes attempted: `{manifest.get("episodes_attempted", 0)}`
- Episodes accepted: `{manifest.get("episodes_accepted", 0)}`
- Acceptance rate: `{manifest.get("acceptance_rate", 0.0):.4f}`
- Rows by split: `{json.dumps(manifest.get("rows_by_split", {}), sort_keys=True)}`
- Rows by difficulty: `{json.dumps(manifest.get("rows_by_difficulty", {}), sort_keys=True)}`
## Reward Verification
- Passed: `{reward_verification.get("passed", False)}`
- Checked rows: `{reward_verification.get("checked_rows", 0)}`
- Minimum terminal reward: `{reward_verification.get("min_terminal_reward", 0.0)}`
- Reward summary: `{json.dumps(reward_verification.get("reward_summary", {}), sort_keys=True)}`
Rows are rejected if the episode fails hidden security/regression/public-route
checks, triggers anti-cheat flags, lacks positive visible-test or patch-quality
reward components, or falls below the configured terminal reward threshold.
## Schema
Each JSONL row has:
```json
{{
"messages": [
{{"role": "system", "content": "..."}},
{{"role": "user", "content": "..."}},
{{"role": "assistant", "content": "{{\\"tool_name\\":\\"...\\",\\"arguments\\":{{...}}}}"}}
],
"metadata": {{
"target_model": "...",
"teacher_model": "...",
"seed": 0,
"split": "train",
"difficulty": 0,
"step": 1,
"tool_name": "inspect_policy_graph",
"final_success": true,
"terminal_total": 12.5,
"anti_cheat_flags": []
}}
}}
```
"""
card_path.write_text(card, encoding="utf-8")
return card_path
def push_dataset_to_hub(out_dir: Path, *, repo_id: str, private: bool) -> dict[str, Any]:
token = os.getenv("HF_TOKEN")
if not token:
raise RuntimeError("HF_TOKEN is required for --push-to-hub")
try:
from huggingface_hub import HfApi
except ImportError as exc: # pragma: no cover
raise RuntimeError("huggingface_hub is required for --push-to-hub") from exc
api = HfApi(token=token)
api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
commit_info = api.upload_folder(
repo_id=repo_id,
repo_type="dataset",
folder_path=str(out_dir),
path_in_repo=".",
commit_message="Upload verified CyberSecurity_OWASP SFT dataset",
delete_patterns=[
"README.md",
"manifest.json",
"train.jsonl",
"validation.jsonl",
"hidden_eval.jsonl",
"trajectories/**",
],
)
return {
"repo_id": repo_id,
"private": bool(private),
"url": f"https://huggingface.co/datasets/{repo_id}",
"commit_url": getattr(commit_info, "commit_url", ""),
}
def push_existing_dataset(
out_dir: Path,
*,
repo_id: str,
private: bool,
min_terminal_reward: float,
required_difficulties: tuple[int, ...],
) -> dict[str, Any]:
verification = verify_sft_dataset_rewards(
out_dir,
min_terminal_reward=min_terminal_reward,
require_train_rows=True,
required_difficulties=required_difficulties,
)
if not verification["passed"]:
raise RuntimeError(f"Reward verification failed; refusing Hub push: {verification}")
manifest_path = out_dir / "manifest.json"
if manifest_path.exists():
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
else:
manifest = {
"teacher_model": DEFAULT_TEACHER_MODEL,
"target_model": DEFAULT_TARGET_MODEL,
"difficulty_levels": [int(level) for level in required_difficulties],
"rows_by_split": verification.get("rows_by_split", {}),
}
manifest["reward_verification"] = verification
manifest["hub"] = {
"repo_id": repo_id,
"private": bool(private),
"url": f"https://huggingface.co/datasets/{repo_id}",
}
write_dataset_card(out_dir, manifest, repo_id)
manifest_path.write_text(
json.dumps(manifest, indent=2, sort_keys=True, default=str),
encoding="utf-8",
)
hub_result = push_dataset_to_hub(out_dir, repo_id=repo_id, private=private)
manifest["hub"].update(hub_result)
manifest_path.write_text(
json.dumps(manifest, indent=2, sort_keys=True, default=str),
encoding="utf-8",
)
return {"reward_verification": verification, "hub": manifest["hub"]}
def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
traj_dir = out_dir / "trajectories"
traj_dir.mkdir(parents=True, exist_ok=True)
name = (
f"{trajectory.get('split', 'train')}_seed{trajectory.get('seed', 0)}_"
f"{str(trajectory.get('episode_id', 'rejected'))[:12]}.json"
)
path = traj_dir / name
path.write_text(json.dumps(trajectory, indent=2, sort_keys=True, default=str), encoding="utf-8")
return path
def _git_sha() -> str:
root = Path(__file__).resolve().parents[1]
try:
return subprocess.check_output(
[
"git",
"-c",
f"safe.directory={root.as_posix()}",
"rev-parse",
"HEAD",
],
cwd=root,
text=True,
stderr=subprocess.DEVNULL,
).strip()
except Exception:
return "nogit"
def _reward_summary(values: list[float]) -> dict[str, float]:
if not values:
return {"mean": 0.0, "min": 0.0, "max": 0.0, "p50": 0.0}
sorted_values = sorted(values)
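    # p50 takes the upper median for even-length inputs (index len // 2).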
return {
"mean": float(statistics.mean(values)),
"min": float(min(values)),
"max": float(max(values)),
"p50": float(sorted_values[len(sorted_values) // 2]),
}
def _parse_int_csv(value: str) -> tuple[int, ...]:
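    # e.g. "0,1, 1,2" -> (0, 1, 2): duplicates dropped, first-seen order kept.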
if not value.strip():
return ()
levels = []
for item in value.split(","):
stripped = item.strip()
if not stripped:
continue
levels.append(int(stripped))
return tuple(dict.fromkeys(levels))
def _difficulty_levels(config: DatasetConfig) -> tuple[int, ...]:
if config.difficulty_levels:
return tuple(int(level) for level in config.difficulty_levels)
return (int(config.difficulty),)
def _configure_difficulty_buckets(config: DatasetConfig, levels: tuple[int, ...]) -> int:
requested = max(levels) + 1 if levels else int(config.difficulty) + 1
configured = max(int(config.difficulty_buckets or 0), requested, 1)
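    # e.g. levels=(0, 1, 2) -> requested 3; an existing env override can only
    # raise the bucket count, never lower it.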
existing = os.getenv("CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS")
if existing:
configured = max(configured, int(existing))
os.environ["CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS"] = str(configured)
return configured
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
if not path.exists():
return []
rows: list[dict[str, Any]] = []
for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
if not line.strip():
continue
try:
item = json.loads(line)
except json.JSONDecodeError as exc:
raise ValueError(f"{path}:{line_number}: invalid JSONL row: {exc}") from exc
if not isinstance(item, dict):
raise ValueError(f"{path}:{line_number}: row must be a JSON object")
rows.append(item)
return rows
def _verify_sft_row_reward(
row: dict[str, Any],
*,
min_terminal_reward: float,
path: Path,
line_number: int,
) -> tuple[bool, str, float]:
messages = row.get("messages")
if not isinstance(messages, list) or len(messages) < 3:
return False, f"{path}:{line_number}: messages must include system/user/assistant", 0.0
if messages[-1].get("role") != "assistant":
return False, f"{path}:{line_number}: final message must be assistant", 0.0
try:
CyberSecurityOWASPAction(**json.loads(str(messages[-1].get("content", ""))))
except Exception as exc:
return False, f"{path}:{line_number}: assistant content is not a valid action: {exc}", 0.0
metadata = row.get("metadata")
if not isinstance(metadata, dict):
return False, f"{path}:{line_number}: missing metadata object", 0.0
if metadata.get("final_success") is not True:
return False, f"{path}:{line_number}: final_success is not true", 0.0
flags = metadata.get("anti_cheat_flags") or []
if flags:
return False, f"{path}:{line_number}: anti-cheat flags present: {flags}", 0.0
reward = float(metadata.get("terminal_total", 0.0) or 0.0)
if reward < min_terminal_reward:
return (
False,
f"{path}:{line_number}: terminal_total {reward:.3f} below required {min_terminal_reward:.3f}",
reward,
)
breakdown = metadata.get("final_reward_breakdown") or {}
if not isinstance(breakdown, dict):
return False, f"{path}:{line_number}: missing final_reward_breakdown", reward
required_positive = ("security", "regression", "public_routes", "patch_quality", "visible_tests")
missing = [key for key in required_positive if float(breakdown.get(key, 0.0) or 0.0) <= 0.0]
if missing:
return False, f"{path}:{line_number}: non-positive reward components: {', '.join(missing)}", reward
return True, "", reward
def verify_sft_dataset_rewards(
out_dir: Path,
*,
min_terminal_reward: float = 12.0,
require_train_rows: bool = True,
required_difficulties: tuple[int, ...] = (),
) -> dict[str, Any]:
"""Verify generated SFT rows carry successful verifier-backed rewards."""
checked_rows = 0
failed_rows: list[str] = []
rewards: list[float] = []
rows_by_split: dict[str, int] = {}
rows_by_difficulty: dict[str, int] = {}
for split_name in ("train", "validation", "hidden_eval"):
path = out_dir / f"{split_name}.jsonl"
rows = _read_jsonl(path)
if not rows and split_name != "train":
continue
rows_by_split[split_name] = len(rows)
for index, row in enumerate(rows, start=1):
ok, error, reward = _verify_sft_row_reward(
row,
min_terminal_reward=min_terminal_reward,
path=path,
line_number=index,
)
checked_rows += 1
if reward:
rewards.append(reward)
if not ok:
failed_rows.append(error)
metadata = row.get("metadata") if isinstance(row, dict) else {}
if isinstance(metadata, dict) and "difficulty" in metadata:
difficulty_key = str(int(metadata.get("difficulty", 0)))
rows_by_difficulty[difficulty_key] = rows_by_difficulty.get(difficulty_key, 0) + 1
passed = not failed_rows and (checked_rows > 0 or not require_train_rows)
if require_train_rows and rows_by_split.get("train", 0) <= 0:
passed = False
failed_rows.append(f"{out_dir / 'train.jsonl'}: no train rows found")
missing_difficulties = [
int(level)
for level in required_difficulties
if rows_by_difficulty.get(str(int(level)), 0) <= 0
]
if missing_difficulties:
passed = False
failed_rows.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
return {
"passed": passed,
"checked_rows": checked_rows,
"failed_rows": failed_rows[:50],
"failure_count": len(failed_rows),
"rows_by_split": rows_by_split,
"rows_by_difficulty": rows_by_difficulty,
"required_difficulties": [int(level) for level in required_difficulties],
"missing_difficulties": missing_difficulties,
"min_terminal_reward": float(min_terminal_reward),
"reward_summary": _reward_summary(rewards),
}
def _resolved_worker_count(config: DatasetConfig, job_count: int) -> int:
if job_count <= 1:
return 1
if int(config.workers) > 0:
return max(1, min(int(config.workers), job_count))
cpu_count = os.cpu_count() or 4
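    # Auto mode example: 16 jobs on a 12-CPU host -> min(8, 12, 16) = 8 workers.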
return max(1, min(8, cpu_count, job_count))
def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
config.out_dir.mkdir(parents=True, exist_ok=True)
teacher_local = threading.local()
teacher_token = None
if not config.dry_run_oracle:
teacher_token = os.getenv("HF_TOKEN")
if not teacher_token:
raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
def teacher_for_thread() -> HuggingFaceTeacher | None:
if config.dry_run_oracle:
return None
teacher = getattr(teacher_local, "teacher", None)
if teacher is None:
teacher = HuggingFaceTeacher(
model=config.teacher_model,
token=str(teacher_token),
max_tokens=config.max_tokens,
temperature=config.temperature,
top_p=config.top_p,
)
teacher_local.teacher = teacher
return teacher
difficulty_levels = _difficulty_levels(config)
difficulty_bucket_count = _configure_difficulty_buckets(config, difficulty_levels)
validation_seed_start = config.seed_start + int(config.episodes) * len(difficulty_levels)
split_jobs = [(config.split, config.episodes, config.seed_start)]
if config.validation_episodes:
split_jobs.append(("validation", config.validation_episodes, validation_seed_start))
episode_jobs = [
{
"order": job_order,
"split": split,
"difficulty": int(difficulty),
"seed": int(seed_start) + difficulty_index * int(episodes) + offset,
}
for job_order, (split, episodes, seed_start) in enumerate(split_jobs)
for difficulty_index, difficulty in enumerate(difficulty_levels)
for offset in range(int(episodes))
]
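    # Seed layout example: episodes=2, levels=(0, 1), seed_start=0 gives train
    # seeds 0-1 (difficulty 0) and 2-3 (difficulty 1); validation, if requested,
    # then starts at seed 4.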
rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
attempts: list[dict[str, Any]] = []
rewards: list[float] = []
accepted = 0
attempted = len(episode_jobs)
workers = _resolved_worker_count(config, attempted)
def run_job(job: dict[str, Any]) -> dict[str, Any]:
seed = int(job["seed"])
split = str(job["split"])
difficulty = int(job["difficulty"])
return {
"order": int(job["order"]),
**run_episode(
seed=seed,
split=split,
difficulty=difficulty,
config=config,
teacher=teacher_for_thread(),
),
}
results: list[dict[str, Any]] = []
with ThreadPoolExecutor(max_workers=workers, thread_name_prefix="sft-episode") as executor:
futures = [executor.submit(run_job, job) for job in episode_jobs]
for future in as_completed(futures):
result = future.result()
results.append(result)
if config.progress:
print(
json.dumps(
{
"event": "episode_done",
"accepted": bool(result.get("accepted")),
"split": result.get("split"),
"difficulty": result.get("difficulty"),
"seed": result.get("seed"),
"reason": result.get("reason", ""),
},
sort_keys=True,
),
flush=True,
)
for result in sorted(
results,
key=lambda item: (
str(item.get("split", "")),
int(item.get("difficulty", 0)),
int(item.get("seed", 0)),
),
):
seed = int(result["seed"])
split = str(result["split"])
difficulty = int(result["difficulty"])
attempts.append(
{
"seed": seed,
"split": split,
"difficulty": difficulty,
"accepted": bool(result["accepted"]),
"reason": result.get("reason", ""),
"trajectory_path": str(_write_trajectory(config.out_dir, result["trajectory"])),
}
)
if result["accepted"]:
accepted += 1
rows = list(result["rows"])
rows_by_split.setdefault(split, []).extend(rows)
rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))
for split_rows in rows_by_split.values():
split_rows.sort(
key=lambda row: (
int((row.get("metadata") or {}).get("difficulty", 0)),
int((row.get("metadata") or {}).get("seed", 0)),
int((row.get("metadata") or {}).get("step", 0)),
)
)
for split_name in ("train", "validation", config.split):
write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))
reward_verification = verify_sft_dataset_rewards(
config.out_dir,
min_terminal_reward=config.min_terminal_reward,
require_train_rows=config.split == "train",
required_difficulties=difficulty_levels if len(difficulty_levels) > 1 else (),
)
accepted_by_difficulty: dict[str, int] = {}
attempted_by_difficulty: dict[str, int] = {}
reward_by_difficulty: dict[str, list[float]] = {}
row_count_by_difficulty: dict[str, int] = {}
for result in results:
difficulty_key = str(int(result.get("difficulty", 0)))
attempted_by_difficulty[difficulty_key] = attempted_by_difficulty.get(difficulty_key, 0) + 1
if result.get("accepted"):
accepted_by_difficulty[difficulty_key] = accepted_by_difficulty.get(difficulty_key, 0) + 1
reward_by_difficulty.setdefault(difficulty_key, []).append(
float((result.get("trajectory") or {}).get("terminal_total", 0.0))
)
for split_rows in rows_by_split.values():
for row in split_rows:
difficulty_key = str(int((row.get("metadata") or {}).get("difficulty", 0)))
row_count_by_difficulty[difficulty_key] = row_count_by_difficulty.get(difficulty_key, 0) + 1
manifest = {
"teacher_model": config.teacher_model,
"target_model": config.target_model,
"split": config.split,
"difficulty": config.difficulty,
"difficulty_levels": [int(level) for level in difficulty_levels],
"difficulty_bucket_count": int(difficulty_bucket_count),
"episodes_per_difficulty": config.episodes,
"validation_episodes_per_difficulty": config.validation_episodes,
"seed_start": config.seed_start,
"episodes_attempted": attempted,
"episodes_accepted": accepted,
"acceptance_rate": accepted / attempted if attempted else 0.0,
"attempted_by_difficulty": attempted_by_difficulty,
"accepted_by_difficulty": accepted_by_difficulty,
"rows_by_difficulty": row_count_by_difficulty,
"reward_summary_by_difficulty": {
key: _reward_summary(value) for key, value in sorted(reward_by_difficulty.items())
},
"workers": workers,
"rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
"reward_summary": _reward_summary(rewards),
"reward_verification": reward_verification,
"git_sha": _git_sha(),
"verifier_version": "verifier_v1",
"dry_run_oracle": config.dry_run_oracle,
"attempts": attempts,
}
if config.push_to_hub:
if not reward_verification["passed"]:
raise RuntimeError("Reward verification failed; refusing to push dataset to Hub.")
manifest["hub"] = {
"repo_id": config.dataset_repo_id,
"private": bool(config.hub_private),
"url": f"https://huggingface.co/datasets/{config.dataset_repo_id}",
}
write_dataset_card(config.out_dir, manifest, config.dataset_repo_id)
manifest_path = config.out_dir / "manifest.json"
manifest_path.write_text(
json.dumps(manifest, indent=2, sort_keys=True, default=str),
encoding="utf-8",
)
if config.push_to_hub:
hub_result = push_dataset_to_hub(
config.out_dir,
repo_id=config.dataset_repo_id,
private=config.hub_private,
)
manifest["hub"].update(hub_result)
manifest_path.write_text(
json.dumps(manifest, indent=2, sort_keys=True, default=str),
encoding="utf-8",
)
return manifest
def build_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser(description=__doc__)
parser.add_argument("--teacher-model", default=DEFAULT_TEACHER_MODEL)
parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
parser.add_argument("--split", default="train", choices=["train", "validation", "hidden_eval"])
parser.add_argument("--difficulty", type=int, default=0)
parser.add_argument(
"--difficulty-levels",
default="",
help="Comma-separated curriculum levels to include, for example 0,1,2,3. "
"When set, --episodes is per difficulty level.",
)
parser.add_argument(
"--difficulty-buckets",
type=int,
default=0,
help=(
"Number of curriculum difficulty buckets to expose to the environment. "
"Defaults to max(--difficulty-levels)+1."
),
)
parser.add_argument("--seed-start", type=int, default=0)
parser.add_argument("--episodes", type=int, default=100)
parser.add_argument("--validation-episodes", type=int, default=0)
parser.add_argument("--out-dir", type=Path, default=Path("outputs/sft"))
parser.add_argument("--max-steps", type=int, default=40)
parser.add_argument("--max-teacher-retries", type=int, default=2)
parser.add_argument("--max-tokens", type=int, default=768)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--top-p", type=float, default=0.95)
parser.add_argument(
"--workers",
type=int,
default=0,
help="Parallel episode workers. 0 auto-selects up to 8 workers.",
)
parser.add_argument(
"--min-terminal-reward",
type=float,
default=12.0,
help="Minimum verifier-backed terminal reward required for SFT rows.",
)
parser.add_argument(
"--verify-only",
action="store_true",
help="Only verify an existing out-dir dataset reward metadata.",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Upload the verified dataset folder to a Hugging Face dataset repo.",
)
parser.add_argument(
"--progress",
action="store_true",
help="Print one JSON progress event for each completed episode job.",
)
parser.add_argument(
"--push-only",
action="store_true",
help="Verify and upload an existing out-dir dataset without regenerating rows.",
)
parser.add_argument(
"--dataset-repo-id",
default="Humanlearning/CyberSecurity_OWASP-sft-dataset",
help="Hugging Face dataset repo id used with --push-to-hub.",
)
parser.add_argument(
"--hub-private",
action="store_true",
help="Create/upload the Hugging Face dataset repo as private.",
)
parser.add_argument(
"--dry-run-oracle",
action="store_true",
help="Generate deterministic oracle data without calling the HF API.",
)
return parser
def config_from_args(args: argparse.Namespace) -> DatasetConfig:
return DatasetConfig(
teacher_model=args.teacher_model,
target_model=args.target_model,
split=args.split,
difficulty=args.difficulty,
difficulty_levels=_parse_int_csv(args.difficulty_levels),
difficulty_buckets=args.difficulty_buckets,
seed_start=args.seed_start,
episodes=args.episodes,
validation_episodes=args.validation_episodes,
out_dir=args.out_dir,
max_steps=args.max_steps,
max_teacher_retries=args.max_teacher_retries,
max_tokens=args.max_tokens,
temperature=args.temperature,
top_p=args.top_p,
dry_run_oracle=args.dry_run_oracle,
workers=args.workers,
min_terminal_reward=args.min_terminal_reward,
push_to_hub=args.push_to_hub,
dataset_repo_id=args.dataset_repo_id,
hub_private=args.hub_private,
progress=args.progress,
)
def main(argv: list[str] | None = None) -> int:
parser = build_arg_parser()
args = parser.parse_args(argv)
try:
if args.verify_only:
verification = verify_sft_dataset_rewards(
args.out_dir,
min_terminal_reward=args.min_terminal_reward,
require_train_rows=args.split == "train",
required_difficulties=_parse_int_csv(args.difficulty_levels),
)
print(json.dumps({"reward_verification": verification}, indent=2, sort_keys=True))
return 0 if verification["passed"] else 2
if args.push_only:
result = push_existing_dataset(
args.out_dir,
repo_id=args.dataset_repo_id,
private=args.hub_private,
min_terminal_reward=args.min_terminal_reward,
required_difficulties=_parse_int_csv(args.difficulty_levels),
)
print(json.dumps(result, indent=2, sort_keys=True))
return 0
manifest = generate_dataset(config_from_args(args))
print(json.dumps(manifest, indent=2, sort_keys=True))
return 0 if manifest.get("reward_verification", {}).get("passed") else 2
except (RuntimeError, ValueError) as exc:
print(
json.dumps(
{"error": str(exc), "error_type": exc.__class__.__name__},
indent=2,
sort_keys=True,
)
)
return 2
if __name__ == "__main__":
raise SystemExit(main())