Spaces:
Sleeping
Sleeping
File size: 6,559 Bytes
f7b8ac6 60f97ab f7b8ac6 60f97ab f7b8ac6 60f97ab f7b8ac6 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 | import importlib.util
import json
import os
import sys
import uuid
from pathlib import Path
from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
CybersecurityOwaspEnvironment,
)
# Load scripts/generate_sft_dataset.py directly from its file path; it is
# imported by location here rather than by dotted package name.
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "generate_sft_dataset.py"
SPEC = importlib.util.spec_from_file_location("generate_sft_dataset", MODULE_PATH)
# Validate the spec *before* using it. An explicit exception (unlike ``assert``)
# is not stripped under ``python -O`` and gives a clear message.
if SPEC is None or SPEC.loader is None:
    raise ImportError(f"cannot build an import spec for {MODULE_PATH}")
generate_sft_dataset = importlib.util.module_from_spec(SPEC)
# Register before executing, as in the importlib "import a source file
# directly" recipe, so code that looks the module up in sys.modules while it
# runs (dataclasses does this) can find it.
sys.modules[SPEC.name] = generate_sft_dataset
try:
    SPEC.loader.exec_module(generate_sft_dataset)
except BaseException:
    # Do not leave a half-initialised module registered under this name.
    sys.modules.pop(SPEC.name, None)
    raise
def _isolated_out_dir(label: str, *, base_dir: Path = Path("outputs")) -> Path:
    """Create a unique per-test scratch tree and point the workspace env var at it.

    Creates ``<base_dir>/sft_dataset_tests/<label>_<8 hex chars>/workspaces``
    and sets ``CYBERSECURITY_OWASP_WORKSPACE_ROOT`` to that ``workspaces``
    directory. The variable is presumably read by
    ``CybersecurityOwaspEnvironment``; confirm against the environment code.

    Args:
        label: Prefix for the unique per-test directory name.
        base_dir: Root for test artefacts. Defaults to ``outputs`` resolved
            against the current working directory (the historical behaviour).
            Pass a temporary directory to keep runs out of the source tree.

    Returns:
        ``<root>/sft``, intended as the dataset output directory. It is NOT
        created here; callers that write into it directly must create it.

    Note:
        ``os.environ`` is modified process-wide and never restored; each call
        overwrites the value set by the previous one.
    """
    # A random suffix keeps repeated runs from sharing (and clobbering) a tree.
    run_root = base_dir / "sft_dataset_tests" / f"{label}_{uuid.uuid4().hex[:8]}"
    workspace_root = run_root / "workspaces"
    workspace_root.mkdir(parents=True, exist_ok=True)
    os.environ["CYBERSECURITY_OWASP_WORKSPACE_ROOT"] = str(workspace_root)
    return run_root / "sft"
def test_extracts_and_validates_action_json():
    """A JSON action wrapped in a ```json fence parses into an action model."""
    fenced_reply = (
        "```json\n"
        '{"tool_name":"inspect_policy_graph","arguments":{}}\n'
        "```"
    )
    parsed = generate_sft_dataset.parse_action_text(fenced_reply)
    assert isinstance(parsed, CyberSecurityOWASPAction)
    assert parsed.tool_name == "inspect_policy_graph"
def test_prompt_uses_visible_observation_only():
    """The user prompt built from a fresh observation must not mention hidden material."""
    _isolated_out_dir("prompt")
    env = CybersecurityOwaspEnvironment()
    try:
        first_obs = env.reset(seed=501, split="train", difficulty=0)
        prompt_text = generate_sft_dataset.build_user_prompt(first_obs, [])
    finally:
        env.close()
    # Case-insensitive scan for strings that point at hidden / oracle-only data.
    haystack = prompt_text.lower()
    forbidden_markers = (
        "hidden_facts",
        "oracle_hidden_focus",
        "reward_engine",
        "validators.py",
        "tests/hidden",
        "hidden tests",
    )
    for marker in forbidden_markers:
        assert marker not in haystack
def test_chat_row_matches_conversational_sft_shape():
    """make_chat_row yields system/user/assistant turns ending in the action JSON."""
    _isolated_out_dir("chat_row")
    env = CybersecurityOwaspEnvironment()
    try:
        observation = env.reset(seed=502, split="train", difficulty=0)
        conversation = generate_sft_dataset.build_chat_messages(observation, [])
        expected_action = CyberSecurityOWASPAction(
            tool_name="inspect_policy_graph", arguments={}
        )
        row_metadata = {
            "target_model": generate_sft_dataset.DEFAULT_TARGET_MODEL,
            "teacher_model": generate_sft_dataset.DEFAULT_TEACHER_MODEL,
            "seed": 502,
        }
        row = generate_sft_dataset.make_chat_row(
            messages=conversation,
            action=expected_action,
            metadata=row_metadata,
        )
    finally:
        env.close()
    roles = [entry["role"] for entry in row["messages"]]
    assert roles == ["system", "user", "assistant"]
    # The assistant turn carries the serialized action, so it must round-trip.
    assistant_payload = json.loads(row["messages"][-1]["content"])
    assert assistant_payload == expected_action.model_dump()
    # This also pins the value of DEFAULT_TARGET_MODEL.
    assert row["metadata"]["target_model"] == "unsloth/gemma-4-E2B-it"
def test_dry_run_oracle_creates_chat_jsonl_without_network():
    """A dry-run oracle build writes train/validation JSONL that passes reward checks."""
    out_dir = _isolated_out_dir("dry_run")

    def load_jsonl(file_name):
        # Parse one JSON object per non-blank line.
        text = (out_dir / file_name).read_text(encoding="utf-8")
        return [json.loads(line) for line in text.splitlines() if line.strip()]

    config = generate_sft_dataset.DatasetConfig(
        episodes=2,
        validation_episodes=1,
        out_dir=out_dir,
        dry_run_oracle=True,
        workers=2,
        difficulty_levels=(0, 1),
    )
    manifest = generate_sft_dataset.generate_dataset(config)

    assert manifest["difficulty_levels"] == [0, 1]
    assert manifest["difficulty_bucket_count"] >= 2
    # 6 == (2 train + 1 validation) episodes x 2 difficulty levels, presumably
    # how the generator counts attempts; confirm against generate_dataset.
    assert manifest["episodes_attempted"] == 6
    assert manifest["episodes_accepted"] == 6
    assert manifest["workers"] == 2
    verification = manifest["reward_verification"]
    assert verification["passed"] is True
    assert verification["missing_difficulties"] == []
    assert manifest["rows_by_difficulty"]["0"] > 0
    assert manifest["rows_by_difficulty"]["1"] > 0

    assert (out_dir / "train.jsonl").exists()
    assert (out_dir / "validation.jsonl").exists()
    train_rows = load_jsonl("train.jsonl")
    validation_rows = load_jsonl("validation.jsonl")
    assert train_rows
    assert validation_rows
    assert all(r["messages"][-1]["role"] == "assistant" for r in train_rows)

    # Re-run the standalone verifier over what was written to disk.
    reward_check = generate_sft_dataset.verify_sft_dataset_rewards(
        out_dir,
        required_difficulties=(0, 1),
    )
    assert reward_check["passed"] is True
    assert (out_dir / "README.md").exists()
def test_reward_verification_rejects_low_reward_rows():
    """A single hand-built row with a low recorded reward is counted as a failure.

    The row claims success with no anti-cheat flags but records
    ``terminal_total`` of 1.0. That value is presumably below the verifier's
    acceptance threshold; confirm in verify_sft_dataset_rewards.
    """
    out_dir = _isolated_out_dir("bad_reward")
    out_dir.mkdir(parents=True, exist_ok=True)
    action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
    assistant_content = json.dumps(action.model_dump())
    metadata = {
        "final_success": True,
        "terminal_total": 1.0,
        "anti_cheat_flags": [],
        "final_reward_breakdown": {
            "security": 5.0,
            "regression": 3.0,
            "public_routes": 1.0,
            "patch_quality": 2.0,
            "visible_tests": 1.0,
        },
    }
    messages = [
        {"role": "system", "content": "system"},
        {"role": "user", "content": "user"},
        {"role": "assistant", "content": assistant_content},
    ]
    row = {"messages": messages, "metadata": metadata}
    train_file = out_dir / "train.jsonl"
    train_file.write_text(json.dumps(row) + "\n", encoding="utf-8")

    result = generate_sft_dataset.verify_sft_dataset_rewards(out_dir)
    assert result["passed"] is False
    assert result["failure_count"] == 1
def test_saved_oracle_trajectory_replays_to_success():
    """Replaying a saved dry-run oracle trajectory reproduces a clean success.

    Generates one dry-run oracle episode and loads a ``train_seed*.json`` file
    from ``trajectories/``. It resets a fresh environment with the recorded
    seed/split/difficulty, then steps through every saved action. The final
    step must be ``done`` and successful, with no anti-cheat flags.
    """
    out_dir = _isolated_out_dir("replay")
    generate_sft_dataset.generate_dataset(
        generate_sft_dataset.DatasetConfig(
            episodes=1,
            out_dir=out_dir,
            dry_run_oracle=True,
        )
    )
    trajectory_dir = out_dir / "trajectories"
    # glob() order is arbitrary: sort for a deterministic pick, and fail with a
    # readable message instead of a bare StopIteration when nothing was saved.
    trajectory_paths = sorted(trajectory_dir.glob("train_seed*.json"))
    assert trajectory_paths, f"no train_seed*.json trajectories under {trajectory_dir}"
    trajectory = json.loads(trajectory_paths[0].read_text(encoding="utf-8"))
    actions = trajectory["actions"]
    assert actions, f"trajectory {trajectory_paths[0].name} has no actions to replay"

    env = CybersecurityOwaspEnvironment()
    try:
        env.reset(
            seed=int(trajectory["seed"]),
            split=trajectory["split"],
            difficulty=int(trajectory["difficulty"]),
        )
        final = None
        for action_data in actions:
            final = env.step(CyberSecurityOWASPAction(**action_data))
        assert final is not None
        assert final.done is True
        assert env.state.success is True
        assert not env.state.anti_cheat_flags
    finally:
        env.close()
|