Spaces:
Sleeping
Sleeping
feat: expand README with synthetic SFT dataset generation instructions, enhance dataset verification and pushing to Hugging Face Hub, and improve modal training scripts with default configurations for curriculum and GPU fallback
60f97ab | import importlib.util | |
| import json | |
| import os | |
| import sys | |
| import uuid | |
| from pathlib import Path | |
| from CyberSecurity_OWASP.models import CyberSecurityOWASPAction | |
| from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( | |
| CybersecurityOwaspEnvironment, | |
| ) | |
# Load scripts/generate_sft_dataset.py directly from its file path and bind it
# to the module-level name ``generate_sft_dataset`` used by the tests below.
# NOTE(review): presumably done this way because ``scripts`` is not an
# importable package -- confirm.
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "generate_sft_dataset.py"
SPEC = importlib.util.spec_from_file_location(
    name="generate_sft_dataset",
    location=MODULE_PATH,
)
generate_sft_dataset = importlib.util.module_from_spec(SPEC)
assert SPEC.loader is not None
# The module is registered in ``sys.modules`` before its code runs, so the
# statement order here must be kept: register first, then execute.
sys.modules[SPEC.name] = generate_sft_dataset
SPEC.loader.exec_module(generate_sft_dataset)
def _isolated_out_dir(label: str) -> Path:
    """Create a uniquely named scratch area and return its dataset directory.

    The directory ``outputs/sft_dataset_tests/<label>_<8 hex chars>/workspaces``
    is created relative to the current working directory, and its path is
    exported through the ``CYBERSECURITY_OWASP_WORKSPACE_ROOT`` environment
    variable. The environment variable is not restored afterwards.

    Args:
        label: Human-readable prefix for the per-call directory name.

    Returns:
        The sibling ``sft`` path inside the per-call directory. This function
        does not create that directory.
    """
    unique_suffix = uuid.uuid4().hex[:8]
    run_dir = Path("outputs", "sft_dataset_tests", f"{label}_{unique_suffix}")

    workspaces = run_dir / "workspaces"
    workspaces.mkdir(parents=True, exist_ok=True)
    # NOTE(review): presumably read by the environment under test to decide
    # where episode workspaces go -- confirm against the environment module.
    os.environ["CYBERSECURITY_OWASP_WORKSPACE_ROOT"] = str(workspaces)

    return run_dir / "sft"
def test_extracts_and_validates_action_json():
    """A fenced ``json`` code block is parsed into a ``CyberSecurityOWASPAction``."""
    fenced_payload = '```json\n{"tool_name":"inspect_policy_graph","arguments":{}}\n```'

    parsed = generate_sft_dataset.parse_action_text(fenced_payload)

    assert isinstance(parsed, CyberSecurityOWASPAction)
    assert parsed.tool_name == "inspect_policy_graph"
def test_prompt_uses_visible_observation_only():
    """The user prompt built from a fresh observation names no hidden internals.

    The check is a case-insensitive substring search for a fixed list of
    markers in the prompt text.
    """
    _isolated_out_dir("prompt")
    env = CybersecurityOwaspEnvironment()
    try:
        observation = env.reset(seed=501, split="train", difficulty=0)
        prompt = generate_sft_dataset.build_user_prompt(observation, [])
    finally:
        # Always release the environment, even if prompt building fails.
        env.close()

    haystack = prompt.lower()
    forbidden_markers = (
        "hidden_facts",
        "oracle_hidden_focus",
        "reward_engine",
        "validators.py",
        "tests/hidden",
        "hidden tests",
    )
    for marker in forbidden_markers:
        assert marker not in haystack
def test_chat_row_matches_conversational_sft_shape():
    """A chat row has system/user/assistant turns and a JSON action as the reply."""
    _isolated_out_dir("chat_row")
    env = CybersecurityOwaspEnvironment()
    try:
        observation = env.reset(seed=502, split="train", difficulty=0)
        prompt_messages = generate_sft_dataset.build_chat_messages(observation, [])
        action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
        row_metadata = {
            "target_model": generate_sft_dataset.DEFAULT_TARGET_MODEL,
            "teacher_model": generate_sft_dataset.DEFAULT_TEACHER_MODEL,
            "seed": 502,
        }
        row = generate_sft_dataset.make_chat_row(
            messages=prompt_messages,
            action=action,
            metadata=row_metadata,
        )
    finally:
        # Always release the environment, even if row construction fails.
        env.close()

    roles = [message["role"] for message in row["messages"]]
    assert roles == ["system", "user", "assistant"]

    # The final (assistant) turn must carry the action serialized as JSON.
    assistant_payload = json.loads(row["messages"][-1]["content"])
    assert assistant_payload == action.model_dump()

    # Pins the literal value expected from DEFAULT_TARGET_MODEL.
    assert row["metadata"]["target_model"] == "unsloth/gemma-4-E2B-it"
def test_dry_run_oracle_creates_chat_jsonl_without_network():
    """A dry-run oracle generation writes train/validation JSONL that verifies.

    Checks the returned manifest, the files written under the output
    directory, and a second reward verification pass over those files.
    NOTE(review): per the test name, ``dry_run_oracle=True`` is expected to
    avoid network access -- nothing in this test enforces that.
    """
    out_dir = _isolated_out_dir("dry_run")
    config = generate_sft_dataset.DatasetConfig(
        episodes=2,
        validation_episodes=1,
        out_dir=out_dir,
        dry_run_oracle=True,
        workers=2,
        difficulty_levels=(0, 1),
    )
    manifest = generate_sft_dataset.generate_dataset(config)

    # Manifest-level expectations.
    assert manifest["difficulty_levels"] == [0, 1]
    assert manifest["difficulty_bucket_count"] >= 2
    # presumably (2 train + 1 validation) episodes x 2 difficulty levels
    assert manifest["episodes_attempted"] == 6
    assert manifest["episodes_accepted"] == 6
    assert manifest["workers"] == 2
    assert manifest["reward_verification"]["passed"] is True
    assert manifest["reward_verification"]["missing_difficulties"] == []
    assert manifest["rows_by_difficulty"]["0"] > 0
    assert manifest["rows_by_difficulty"]["1"] > 0

    # Files on disk.
    train_path = out_dir / "train.jsonl"
    validation_path = out_dir / "validation.jsonl"
    assert train_path.exists()
    assert validation_path.exists()

    def load_rows(jsonl_path):
        """Parse every non-blank line of a JSONL file."""
        parsed = []
        for raw_line in jsonl_path.read_text(encoding="utf-8").splitlines():
            if raw_line.strip():
                parsed.append(json.loads(raw_line))
        return parsed

    train_rows = load_rows(train_path)
    validation_rows = load_rows(validation_path)
    assert train_rows
    assert validation_rows
    assert all(row["messages"][-1]["role"] == "assistant" for row in train_rows)

    # Re-verify the written dataset independently of the manifest.
    reward_check = generate_sft_dataset.verify_sft_dataset_rewards(
        out_dir,
        required_difficulties=(0, 1),
    )
    assert reward_check["passed"] is True
    assert (out_dir / "README.md").exists()
def test_reward_verification_rejects_low_reward_rows():
    """Reward verification fails for a hand-written row and counts one failure.

    A single row is written to ``train.jsonl`` with the metadata below; the
    verifier is expected to report ``passed`` as False with exactly one failure.
    NOTE(review): which metadata field triggers the rejection is decided by
    ``verify_sft_dataset_rewards`` and is not visible from this test.
    """
    out_dir = _isolated_out_dir("bad_reward")
    out_dir.mkdir(parents=True, exist_ok=True)

    action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
    assistant_content = json.dumps(action.model_dump())
    metadata = {
        "final_success": True,
        "terminal_total": 1.0,
        "anti_cheat_flags": [],
        "final_reward_breakdown": {
            "security": 5.0,
            "regression": 3.0,
            "public_routes": 1.0,
            "patch_quality": 2.0,
            "visible_tests": 1.0,
        },
    }
    row = {
        "messages": [
            {"role": "system", "content": "system"},
            {"role": "user", "content": "user"},
            {"role": "assistant", "content": assistant_content},
        ],
        "metadata": metadata,
    }

    train_file = out_dir / "train.jsonl"
    train_file.write_text(json.dumps(row) + "\n", encoding="utf-8")

    reward_check = generate_sft_dataset.verify_sft_dataset_rewards(out_dir)
    assert reward_check["passed"] is False
    assert reward_check["failure_count"] == 1
def test_saved_oracle_trajectory_replays_to_success():
    """Replaying a saved oracle trajectory in a fresh environment succeeds.

    Generates a one-episode dry-run dataset, loads one saved ``train_seed*``
    trajectory, resets a new environment with the recorded seed, split and
    difficulty, and steps through the recorded actions. The episode must end
    done and successful with no anti-cheat flags.
    """
    out_dir = _isolated_out_dir("replay")
    generate_sft_dataset.generate_dataset(
        generate_sft_dataset.DatasetConfig(
            episodes=1,
            out_dir=out_dir,
            dry_run_oracle=True,
        )
    )

    # Sort so the replayed file is the same on every run (glob order is
    # arbitrary), and fail with a readable message instead of a bare
    # StopIteration when generation wrote no trajectory at all.
    trajectory_dir = out_dir / "trajectories"
    trajectory_paths = sorted(trajectory_dir.glob("train_seed*.json"))
    assert trajectory_paths, f"no train_seed*.json trajectory found in {trajectory_dir}"
    trajectory = json.loads(trajectory_paths[0].read_text(encoding="utf-8"))

    env = CybersecurityOwaspEnvironment()
    try:
        env.reset(
            seed=int(trajectory["seed"]),
            split=trajectory["split"],
            difficulty=int(trajectory["difficulty"]),
        )
        final = None
        for action_data in trajectory["actions"]:
            final = env.step(CyberSecurityOWASPAction(**action_data))

        # ``final`` stays None only when the trajectory recorded no actions.
        assert final is not None, "saved trajectory contains no actions"
        assert final.done is True
        assert env.state.success is True
        assert not env.state.anti_cheat_flags
    finally:
        # Always release the environment, even when an assertion fails.
        env.close()