Spaces:
Sleeping
Sleeping
Commit ·
60f97ab
1
Parent(s): f7b8ac6
feat: expand README with synthetic SFT dataset generation instructions, enhance dataset verification and pushing to Hugging Face Hub, and improve modal training scripts with default configurations for curriculum and GPU fallback
Browse files- README.md +94 -0
- scripts/generate_sft_dataset.py +593 -36
- scripts/modal_train_sft.py +266 -20
- tests/test_modal_scenario_cache_static.py +37 -0
- tests/test_sft_dataset_generation.py +48 -2
README.md
CHANGED
|
@@ -256,6 +256,100 @@ The shell wrapper is equivalent:
|
|
| 256 |
MODE=smoke EPISODES=4 uv run --extra modal bash scripts/modal_run_ephemeral.sh
|
| 257 |
```
|
| 258 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 259 |
## Modal GRPO Training
|
| 260 |
|
| 261 |
The persistent GPU training launcher packages this local repo into Modal, trains
|
|
|
|
| 256 |
MODE=smoke EPISODES=4 uv run --extra modal bash scripts/modal_run_ephemeral.sh
|
| 257 |
```
|
| 258 |
|
| 259 |
+
## Synthetic SFT Before GRPO
|
| 260 |
+
|
| 261 |
+
Use supervised fine-tuning to warm-start `unsloth/gemma-4-E2B-it` before GRPO.
|
| 262 |
+
The SFT generator executes every teacher action in the real environment and
|
| 263 |
+
keeps only trajectories that pass the deterministic reward verifier.
|
| 264 |
+
|
| 265 |
+
Generate a 300-train-episode curriculum SFT dataset across levels `0,1,2,3`:
|
| 266 |
+
|
| 267 |
+
```bash
|
| 268 |
+
uv run python scripts/generate_sft_dataset.py \
|
| 269 |
+
--teacher-model deepseek-ai/DeepSeek-V4-Pro \
|
| 270 |
+
--target-model unsloth/gemma-4-E2B-it \
|
| 271 |
+
--difficulty-levels 0,1,2,3 \
|
| 272 |
+
--difficulty-buckets 4 \
|
| 273 |
+
--episodes 75 \
|
| 274 |
+
--validation-episodes 20 \
|
| 275 |
+
--workers 8 \
|
| 276 |
+
--out-dir outputs/sft
|
| 277 |
+
```
|
| 278 |
+
|
| 279 |
+
`--episodes` is per difficulty level when `--difficulty-levels` is set, so
|
| 280 |
+
`--episodes 75` across four levels gives 300 total train episodes. Expect
|
| 281 |
+
roughly 2,400-4,500 chat-format JSONL rows because each successful trajectory
|
| 282 |
+
contributes one row per action step. The script writes JSONL rows under
|
| 283 |
+
`outputs/sft/`, trajectory artifacts under `outputs/sft/trajectories/`, a
|
| 284 |
+
dataset card at `outputs/sft/README.md`, and `outputs/sft/manifest.json` with
|
| 285 |
+
reward summaries and curriculum coverage.
|
| 286 |
+
|
| 287 |
+
Verify reward metadata before any training run:
|
| 288 |
+
|
| 289 |
+
```bash
|
| 290 |
+
uv run python scripts/generate_sft_dataset.py \
|
| 291 |
+
--verify-only \
|
| 292 |
+
--difficulty-levels 0,1,2,3 \
|
| 293 |
+
--out-dir outputs/sft
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
Push the verified dataset to Hugging Face Hub:
|
| 297 |
+
|
| 298 |
+
```bash
|
| 299 |
+
uv run python scripts/generate_sft_dataset.py \
|
| 300 |
+
--push-only \
|
| 301 |
+
--difficulty-levels 0,1,2,3 \
|
| 302 |
+
--out-dir outputs/sft \
|
| 303 |
+
--dataset-repo-id Humanlearning/CyberSecurity_OWASP-sft-dataset
|
| 304 |
+
```
|
| 305 |
+
|
| 306 |
+
The canonical dataset repo name is
|
| 307 |
+
`Humanlearning/CyberSecurity_OWASP-sft-dataset`. The upload is refused if
|
| 308 |
+
reward verification fails or `HF_TOKEN` is missing.
|
| 309 |
+
|
| 310 |
+
You can also generate and push in one command by adding `--push-to-hub` to the
|
| 311 |
+
generation command.
|
| 312 |
+
|
| 313 |
+
For local CI or smoke checks, add `--dry-run-oracle`; official SFT data should
|
| 314 |
+
use the teacher path and still pass the verifier gate above.
|
| 315 |
+
|
| 316 |
+
Launch SFT on Modal after reward verification passes:
|
| 317 |
+
|
| 318 |
+
```bash
|
| 319 |
+
uv run --extra modal modal run --detach scripts/modal_train_sft.py \
|
| 320 |
+
--local-train-path outputs/sft/train.jsonl \
|
| 321 |
+
--local-validation-path outputs/sft/validation.jsonl \
|
| 322 |
+
--local-manifest-path outputs/sft/manifest.json \
|
| 323 |
+
--required-difficulties 0,1,2,3 \
|
| 324 |
+
--trackio-space-id Humanlearning/CyberSecurity_OWASP-trackio \
|
| 325 |
+
--trackio-project CyberSecurity_OWASP-sft \
|
| 326 |
+
--output-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
|
| 327 |
+
--push-to-hub \
|
| 328 |
+
--detach
|
| 329 |
+
```
|
| 330 |
+
|
| 331 |
+
`scripts/modal_train_sft.py` re-checks the JSONL reward metadata locally before
|
| 332 |
+
upload and again inside Modal before loading the model. It refuses to start SFT
|
| 333 |
+
unless all required curriculum difficulties are represented and the verifier
|
| 334 |
+
reward metadata passes. The default SFT config trains one full epoch
|
| 335 |
+
(`--max-steps -1`) with packed assistant-only loss, bf16/tf32, LoRA rank 32,
|
| 336 |
+
and Modal GPU fallback `H200 -> H100 -> A100-80GB -> L40S`. A warm run for the
|
| 337 |
+
300-episode dataset should usually finish in about 15-45 minutes; first image
|
| 338 |
+
or model-cache builds can push that closer to 35-75 minutes.
|
| 339 |
+
|
| 340 |
+
Continue GRPO from the SFT LoRA:
|
| 341 |
+
|
| 342 |
+
```bash
|
| 343 |
+
uv run --extra modal modal run --detach scripts/modal_train_grpo.py \
|
| 344 |
+
--initial-adapter-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
|
| 345 |
+
--max-steps 300 \
|
| 346 |
+
--dataset-size 64 \
|
| 347 |
+
--num-generations 8 \
|
| 348 |
+
--difficulty 0 \
|
| 349 |
+
--trace-log-every 10 \
|
| 350 |
+
--detach
|
| 351 |
+
```
|
| 352 |
+
|
| 353 |
## Modal GRPO Training
|
| 354 |
|
| 355 |
The persistent GPU training launcher packages this local repo into Modal, trains
|
scripts/generate_sft_dataset.py
CHANGED
|
@@ -14,6 +14,8 @@ import json
|
|
| 14 |
import os
|
| 15 |
import statistics
|
| 16 |
import subprocess
|
|
|
|
|
|
|
| 17 |
from dataclasses import dataclass
|
| 18 |
from pathlib import Path
|
| 19 |
from typing import Any, Iterable
|
|
@@ -72,6 +74,14 @@ class DatasetConfig:
|
|
| 72 |
temperature: float = 0.2
|
| 73 |
top_p: float = 0.95
|
| 74 |
dry_run_oracle: bool = False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
|
| 76 |
|
| 77 |
class HuggingFaceTeacher:
|
|
@@ -579,6 +589,174 @@ def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> None:
|
|
| 579 |
handle.write(json.dumps(row, sort_keys=True, default=str) + "\n")
|
| 580 |
|
| 581 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 582 |
def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
|
| 583 |
traj_dir = out_dir / "trajectories"
|
| 584 |
traj_dir.mkdir(parents=True, exist_ok=True)
|
|
@@ -622,80 +800,365 @@ def _reward_summary(values: list[float]) -> dict[str, float]:
|
|
| 622 |
}
|
| 623 |
|
| 624 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 625 |
def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
|
| 626 |
config.out_dir.mkdir(parents=True, exist_ok=True)
|
| 627 |
-
|
|
|
|
| 628 |
if not config.dry_run_oracle:
|
| 629 |
-
|
| 630 |
-
if not
|
| 631 |
raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
|
| 632 |
-
teacher = HuggingFaceTeacher(
|
| 633 |
-
model=config.teacher_model,
|
| 634 |
-
token=token,
|
| 635 |
-
max_tokens=config.max_tokens,
|
| 636 |
-
temperature=config.temperature,
|
| 637 |
-
top_p=config.top_p,
|
| 638 |
-
)
|
| 639 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 640 |
split_jobs = [(config.split, config.episodes, config.seed_start)]
|
| 641 |
if config.validation_episodes:
|
| 642 |
-
split_jobs.append(("validation", config.validation_episodes,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 643 |
|
| 644 |
rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
|
| 645 |
attempts: list[dict[str, Any]] = []
|
| 646 |
rewards: list[float] = []
|
| 647 |
accepted = 0
|
| 648 |
-
attempted =
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 654 |
seed=seed,
|
| 655 |
split=split,
|
| 656 |
-
difficulty=
|
| 657 |
config=config,
|
| 658 |
-
teacher=
|
| 659 |
-
)
|
| 660 |
-
|
| 661 |
-
|
| 662 |
-
|
| 663 |
-
|
| 664 |
-
|
| 665 |
-
|
| 666 |
-
|
| 667 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
)
|
| 669 |
-
|
| 670 |
-
accepted += 1
|
| 671 |
-
rows = list(result["rows"])
|
| 672 |
-
rows_by_split.setdefault(split, []).extend(rows)
|
| 673 |
-
rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))
|
| 674 |
|
| 675 |
for split_name in ("train", "validation", config.split):
|
| 676 |
write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))
|
| 677 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
manifest = {
|
| 679 |
"teacher_model": config.teacher_model,
|
| 680 |
"target_model": config.target_model,
|
| 681 |
"split": config.split,
|
| 682 |
"difficulty": config.difficulty,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
"seed_start": config.seed_start,
|
| 684 |
"episodes_attempted": attempted,
|
| 685 |
"episodes_accepted": accepted,
|
| 686 |
"acceptance_rate": accepted / attempted if attempted else 0.0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 687 |
"rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
|
| 688 |
"reward_summary": _reward_summary(rewards),
|
|
|
|
| 689 |
"git_sha": _git_sha(),
|
| 690 |
"verifier_version": "verifier_v1",
|
| 691 |
"dry_run_oracle": config.dry_run_oracle,
|
| 692 |
"attempts": attempts,
|
| 693 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 694 |
manifest_path = config.out_dir / "manifest.json"
|
| 695 |
manifest_path.write_text(
|
| 696 |
json.dumps(manifest, indent=2, sort_keys=True, default=str),
|
| 697 |
encoding="utf-8",
|
| 698 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 699 |
return manifest
|
| 700 |
|
| 701 |
|
|
@@ -705,6 +1168,21 @@ def build_arg_parser() -> argparse.ArgumentParser:
|
|
| 705 |
parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
|
| 706 |
parser.add_argument("--split", default="train", choices=["train", "validation", "hidden_eval"])
|
| 707 |
parser.add_argument("--difficulty", type=int, default=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 708 |
parser.add_argument("--seed-start", type=int, default=0)
|
| 709 |
parser.add_argument("--episodes", type=int, default=100)
|
| 710 |
parser.add_argument("--validation-episodes", type=int, default=0)
|
|
@@ -714,6 +1192,48 @@ def build_arg_parser() -> argparse.ArgumentParser:
|
|
| 714 |
parser.add_argument("--max-tokens", type=int, default=768)
|
| 715 |
parser.add_argument("--temperature", type=float, default=0.2)
|
| 716 |
parser.add_argument("--top-p", type=float, default=0.95)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 717 |
parser.add_argument(
|
| 718 |
"--dry-run-oracle",
|
| 719 |
action="store_true",
|
|
@@ -728,6 +1248,8 @@ def config_from_args(args: argparse.Namespace) -> DatasetConfig:
|
|
| 728 |
target_model=args.target_model,
|
| 729 |
split=args.split,
|
| 730 |
difficulty=args.difficulty,
|
|
|
|
|
|
|
| 731 |
seed_start=args.seed_start,
|
| 732 |
episodes=args.episodes,
|
| 733 |
validation_episodes=args.validation_episodes,
|
|
@@ -738,15 +1260,50 @@ def config_from_args(args: argparse.Namespace) -> DatasetConfig:
|
|
| 738 |
temperature=args.temperature,
|
| 739 |
top_p=args.top_p,
|
| 740 |
dry_run_oracle=args.dry_run_oracle,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 741 |
)
|
| 742 |
|
| 743 |
|
| 744 |
def main(argv: list[str] | None = None) -> int:
|
| 745 |
parser = build_arg_parser()
|
| 746 |
args = parser.parse_args(argv)
|
| 747 |
-
|
| 748 |
-
|
| 749 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 750 |
|
| 751 |
|
| 752 |
if __name__ == "__main__":
|
|
|
|
| 14 |
import os
|
| 15 |
import statistics
|
| 16 |
import subprocess
|
| 17 |
+
import threading
|
| 18 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
| 19 |
from dataclasses import dataclass
|
| 20 |
from pathlib import Path
|
| 21 |
from typing import Any, Iterable
|
|
|
|
| 74 |
temperature: float = 0.2
|
| 75 |
top_p: float = 0.95
|
| 76 |
dry_run_oracle: bool = False
|
| 77 |
+
workers: int = 0
|
| 78 |
+
min_terminal_reward: float = 12.0
|
| 79 |
+
difficulty_levels: tuple[int, ...] = ()
|
| 80 |
+
difficulty_buckets: int = 0
|
| 81 |
+
push_to_hub: bool = False
|
| 82 |
+
dataset_repo_id: str = "Humanlearning/CyberSecurity_OWASP-sft-dataset"
|
| 83 |
+
hub_private: bool = False
|
| 84 |
+
progress: bool = False
|
| 85 |
|
| 86 |
|
| 87 |
class HuggingFaceTeacher:
|
|
|
|
| 589 |
handle.write(json.dumps(row, sort_keys=True, default=str) + "\n")
|
| 590 |
|
| 591 |
|
| 592 |
+
def write_dataset_card(out_dir: Path, manifest: dict[str, Any], dataset_repo_id: str) -> Path:
|
| 593 |
+
card_path = out_dir / "README.md"
|
| 594 |
+
difficulty_levels = manifest.get("difficulty_levels", [])
|
| 595 |
+
reward_verification = manifest.get("reward_verification", {})
|
| 596 |
+
card = f"""---
|
| 597 |
+
license: apache-2.0
|
| 598 |
+
task_categories:
|
| 599 |
+
- text-generation
|
| 600 |
+
language:
|
| 601 |
+
- en
|
| 602 |
+
tags:
|
| 603 |
+
- cybersecurity
|
| 604 |
+
- owasp
|
| 605 |
+
- openenv
|
| 606 |
+
- tool-use
|
| 607 |
+
- sft
|
| 608 |
+
pretty_name: CyberSecurity_OWASP SFT Dataset
|
| 609 |
+
---
|
| 610 |
+
|
| 611 |
+
# CyberSecurity_OWASP SFT Dataset
|
| 612 |
+
|
| 613 |
+
This dataset contains verifier-gated supervised fine-tuning examples for the
|
| 614 |
+
`CyberSecurity_OWASP` OpenEnv environment. Each row teaches one step of the
|
| 615 |
+
defensive local AppSec workflow: inspect policy/code, reproduce a local
|
| 616 |
+
authorization failure, submit a policy-tied diagnosis, patch the generated app,
|
| 617 |
+
run visible tests, and submit the fix.
|
| 618 |
+
|
| 619 |
+
Every kept trajectory is executed against the real local environment and must
|
| 620 |
+
pass the deterministic reward verifier before rows are written.
|
| 621 |
+
|
| 622 |
+
## Intended Use
|
| 623 |
+
|
| 624 |
+
- Target SFT model: `{manifest.get("target_model", "")}`
|
| 625 |
+
- Teacher model: `{manifest.get("teacher_model", "")}`
|
| 626 |
+
- Dataset repo: `{dataset_repo_id}`
|
| 627 |
+
- Format: chat JSONL with `messages` and verifier metadata
|
| 628 |
+
- Dry-run oracle: `{manifest.get("dry_run_oracle", False)}`
|
| 629 |
+
|
| 630 |
+
## Curriculum Coverage
|
| 631 |
+
|
| 632 |
+
- Difficulty levels: `{difficulty_levels}`
|
| 633 |
+
- Episodes attempted: `{manifest.get("episodes_attempted", 0)}`
|
| 634 |
+
- Episodes accepted: `{manifest.get("episodes_accepted", 0)}`
|
| 635 |
+
- Acceptance rate: `{manifest.get("acceptance_rate", 0.0):.4f}`
|
| 636 |
+
- Rows by split: `{json.dumps(manifest.get("rows_by_split", {}), sort_keys=True)}`
|
| 637 |
+
- Rows by difficulty: `{json.dumps(manifest.get("rows_by_difficulty", {}), sort_keys=True)}`
|
| 638 |
+
|
| 639 |
+
## Reward Verification
|
| 640 |
+
|
| 641 |
+
- Passed: `{reward_verification.get("passed", False)}`
|
| 642 |
+
- Checked rows: `{reward_verification.get("checked_rows", 0)}`
|
| 643 |
+
- Minimum terminal reward: `{reward_verification.get("min_terminal_reward", 0.0)}`
|
| 644 |
+
- Reward summary: `{json.dumps(reward_verification.get("reward_summary", {}), sort_keys=True)}`
|
| 645 |
+
|
| 646 |
+
Rows are rejected if the episode fails hidden security/regression/public-route
|
| 647 |
+
checks, triggers anti-cheat flags, lacks a positive patch-quality reward, or
|
| 648 |
+
falls below the configured terminal reward threshold.
|
| 649 |
+
|
| 650 |
+
## Schema
|
| 651 |
+
|
| 652 |
+
Each JSONL row has:
|
| 653 |
+
|
| 654 |
+
```json
|
| 655 |
+
{{
|
| 656 |
+
"messages": [
|
| 657 |
+
{{"role": "system", "content": "..."}},
|
| 658 |
+
{{"role": "user", "content": "..."}},
|
| 659 |
+
{{"role": "assistant", "content": "{{\\"tool_name\\":\\"...\\",\\"arguments\\":{{...}}}}"}}
|
| 660 |
+
],
|
| 661 |
+
"metadata": {{
|
| 662 |
+
"target_model": "...",
|
| 663 |
+
"teacher_model": "...",
|
| 664 |
+
"seed": 0,
|
| 665 |
+
"split": "train",
|
| 666 |
+
"difficulty": 0,
|
| 667 |
+
"step": 1,
|
| 668 |
+
"tool_name": "inspect_policy_graph",
|
| 669 |
+
"final_success": true,
|
| 670 |
+
"terminal_total": 12.5,
|
| 671 |
+
"anti_cheat_flags": []
|
| 672 |
+
}}
|
| 673 |
+
}}
|
| 674 |
+
```
|
| 675 |
+
"""
|
| 676 |
+
card_path.write_text(card, encoding="utf-8")
|
| 677 |
+
return card_path
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
def push_dataset_to_hub(out_dir: Path, *, repo_id: str, private: bool) -> dict[str, Any]:
|
| 681 |
+
token = os.getenv("HF_TOKEN")
|
| 682 |
+
if not token:
|
| 683 |
+
raise RuntimeError("HF_TOKEN is required for --push-to-hub")
|
| 684 |
+
try:
|
| 685 |
+
from huggingface_hub import HfApi
|
| 686 |
+
except ImportError as exc: # pragma: no cover
|
| 687 |
+
raise RuntimeError("huggingface_hub is required for --push-to-hub") from exc
|
| 688 |
+
|
| 689 |
+
api = HfApi(token=token)
|
| 690 |
+
api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
|
| 691 |
+
commit_info = api.upload_folder(
|
| 692 |
+
repo_id=repo_id,
|
| 693 |
+
repo_type="dataset",
|
| 694 |
+
folder_path=str(out_dir),
|
| 695 |
+
path_in_repo=".",
|
| 696 |
+
commit_message="Upload verified CyberSecurity_OWASP SFT dataset",
|
| 697 |
+
delete_patterns=[
|
| 698 |
+
"README.md",
|
| 699 |
+
"manifest.json",
|
| 700 |
+
"train.jsonl",
|
| 701 |
+
"validation.jsonl",
|
| 702 |
+
"hidden_eval.jsonl",
|
| 703 |
+
"trajectories/**",
|
| 704 |
+
],
|
| 705 |
+
)
|
| 706 |
+
return {
|
| 707 |
+
"repo_id": repo_id,
|
| 708 |
+
"private": bool(private),
|
| 709 |
+
"url": f"https://huggingface.co/datasets/{repo_id}",
|
| 710 |
+
"commit_url": getattr(commit_info, "commit_url", ""),
|
| 711 |
+
}
|
| 712 |
+
|
| 713 |
+
|
| 714 |
+
def push_existing_dataset(
|
| 715 |
+
out_dir: Path,
|
| 716 |
+
*,
|
| 717 |
+
repo_id: str,
|
| 718 |
+
private: bool,
|
| 719 |
+
min_terminal_reward: float,
|
| 720 |
+
required_difficulties: tuple[int, ...],
|
| 721 |
+
) -> dict[str, Any]:
|
| 722 |
+
verification = verify_sft_dataset_rewards(
|
| 723 |
+
out_dir,
|
| 724 |
+
min_terminal_reward=min_terminal_reward,
|
| 725 |
+
require_train_rows=True,
|
| 726 |
+
required_difficulties=required_difficulties,
|
| 727 |
+
)
|
| 728 |
+
if not verification["passed"]:
|
| 729 |
+
raise RuntimeError(f"Reward verification failed; refusing Hub push: {verification}")
|
| 730 |
+
manifest_path = out_dir / "manifest.json"
|
| 731 |
+
if manifest_path.exists():
|
| 732 |
+
manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
|
| 733 |
+
else:
|
| 734 |
+
manifest = {
|
| 735 |
+
"teacher_model": DEFAULT_TEACHER_MODEL,
|
| 736 |
+
"target_model": DEFAULT_TARGET_MODEL,
|
| 737 |
+
"difficulty_levels": [int(level) for level in required_difficulties],
|
| 738 |
+
"rows_by_split": verification.get("rows_by_split", {}),
|
| 739 |
+
}
|
| 740 |
+
manifest["reward_verification"] = verification
|
| 741 |
+
manifest["hub"] = {
|
| 742 |
+
"repo_id": repo_id,
|
| 743 |
+
"private": bool(private),
|
| 744 |
+
"url": f"https://huggingface.co/datasets/{repo_id}",
|
| 745 |
+
}
|
| 746 |
+
write_dataset_card(out_dir, manifest, repo_id)
|
| 747 |
+
manifest_path.write_text(
|
| 748 |
+
json.dumps(manifest, indent=2, sort_keys=True, default=str),
|
| 749 |
+
encoding="utf-8",
|
| 750 |
+
)
|
| 751 |
+
hub_result = push_dataset_to_hub(out_dir, repo_id=repo_id, private=private)
|
| 752 |
+
manifest["hub"].update(hub_result)
|
| 753 |
+
manifest_path.write_text(
|
| 754 |
+
json.dumps(manifest, indent=2, sort_keys=True, default=str),
|
| 755 |
+
encoding="utf-8",
|
| 756 |
+
)
|
| 757 |
+
return {"reward_verification": verification, "hub": manifest["hub"]}
|
| 758 |
+
|
| 759 |
+
|
| 760 |
def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
|
| 761 |
traj_dir = out_dir / "trajectories"
|
| 762 |
traj_dir.mkdir(parents=True, exist_ok=True)
|
|
|
|
| 800 |
}
|
| 801 |
|
| 802 |
|
| 803 |
+
def _parse_int_csv(value: str) -> tuple[int, ...]:
|
| 804 |
+
if not value.strip():
|
| 805 |
+
return ()
|
| 806 |
+
levels = []
|
| 807 |
+
for item in value.split(","):
|
| 808 |
+
stripped = item.strip()
|
| 809 |
+
if not stripped:
|
| 810 |
+
continue
|
| 811 |
+
levels.append(int(stripped))
|
| 812 |
+
return tuple(dict.fromkeys(levels))
|
| 813 |
+
|
| 814 |
+
|
| 815 |
+
def _difficulty_levels(config: DatasetConfig) -> tuple[int, ...]:
|
| 816 |
+
if config.difficulty_levels:
|
| 817 |
+
return tuple(int(level) for level in config.difficulty_levels)
|
| 818 |
+
return (int(config.difficulty),)
|
| 819 |
+
|
| 820 |
+
|
| 821 |
+
def _configure_difficulty_buckets(config: DatasetConfig, levels: tuple[int, ...]) -> int:
|
| 822 |
+
requested = max(levels) + 1 if levels else int(config.difficulty) + 1
|
| 823 |
+
configured = max(int(config.difficulty_buckets or 0), requested, 1)
|
| 824 |
+
existing = os.getenv("CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS")
|
| 825 |
+
if existing:
|
| 826 |
+
configured = max(configured, int(existing))
|
| 827 |
+
os.environ["CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS"] = str(configured)
|
| 828 |
+
return configured
|
| 829 |
+
|
| 830 |
+
|
| 831 |
+
def _read_jsonl(path: Path) -> list[dict[str, Any]]:
|
| 832 |
+
if not path.exists():
|
| 833 |
+
return []
|
| 834 |
+
rows: list[dict[str, Any]] = []
|
| 835 |
+
for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
|
| 836 |
+
if not line.strip():
|
| 837 |
+
continue
|
| 838 |
+
try:
|
| 839 |
+
item = json.loads(line)
|
| 840 |
+
except json.JSONDecodeError as exc:
|
| 841 |
+
raise ValueError(f"{path}:{line_number}: invalid JSONL row: {exc}") from exc
|
| 842 |
+
if not isinstance(item, dict):
|
| 843 |
+
raise ValueError(f"{path}:{line_number}: row must be a JSON object")
|
| 844 |
+
rows.append(item)
|
| 845 |
+
return rows
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def _verify_sft_row_reward(
|
| 849 |
+
row: dict[str, Any],
|
| 850 |
+
*,
|
| 851 |
+
min_terminal_reward: float,
|
| 852 |
+
path: Path,
|
| 853 |
+
line_number: int,
|
| 854 |
+
) -> tuple[bool, str, float]:
|
| 855 |
+
messages = row.get("messages")
|
| 856 |
+
if not isinstance(messages, list) or len(messages) < 3:
|
| 857 |
+
return False, f"{path}:{line_number}: messages must include system/user/assistant", 0.0
|
| 858 |
+
if messages[-1].get("role") != "assistant":
|
| 859 |
+
return False, f"{path}:{line_number}: final message must be assistant", 0.0
|
| 860 |
+
try:
|
| 861 |
+
CyberSecurityOWASPAction(**json.loads(str(messages[-1].get("content", ""))))
|
| 862 |
+
except Exception as exc:
|
| 863 |
+
return False, f"{path}:{line_number}: assistant content is not a valid action: {exc}", 0.0
|
| 864 |
+
metadata = row.get("metadata")
|
| 865 |
+
if not isinstance(metadata, dict):
|
| 866 |
+
return False, f"{path}:{line_number}: missing metadata object", 0.0
|
| 867 |
+
if metadata.get("final_success") is not True:
|
| 868 |
+
return False, f"{path}:{line_number}: final_success is not true", 0.0
|
| 869 |
+
flags = metadata.get("anti_cheat_flags") or []
|
| 870 |
+
if flags:
|
| 871 |
+
return False, f"{path}:{line_number}: anti-cheat flags present: {flags}", 0.0
|
| 872 |
+
reward = float(metadata.get("terminal_total", 0.0) or 0.0)
|
| 873 |
+
if reward < min_terminal_reward:
|
| 874 |
+
return (
|
| 875 |
+
False,
|
| 876 |
+
f"{path}:{line_number}: terminal_total {reward:.3f} below required {min_terminal_reward:.3f}",
|
| 877 |
+
reward,
|
| 878 |
+
)
|
| 879 |
+
breakdown = metadata.get("final_reward_breakdown") or {}
|
| 880 |
+
if not isinstance(breakdown, dict):
|
| 881 |
+
return False, f"{path}:{line_number}: missing final_reward_breakdown", reward
|
| 882 |
+
required_positive = ("security", "regression", "public_routes", "patch_quality", "visible_tests")
|
| 883 |
+
missing = [key for key in required_positive if float(breakdown.get(key, 0.0) or 0.0) <= 0.0]
|
| 884 |
+
if missing:
|
| 885 |
+
return False, f"{path}:{line_number}: non-positive reward components: {', '.join(missing)}", reward
|
| 886 |
+
return True, "", reward
|
| 887 |
+
|
| 888 |
+
|
| 889 |
+
def verify_sft_dataset_rewards(
|
| 890 |
+
out_dir: Path,
|
| 891 |
+
*,
|
| 892 |
+
min_terminal_reward: float = 12.0,
|
| 893 |
+
require_train_rows: bool = True,
|
| 894 |
+
required_difficulties: tuple[int, ...] = (),
|
| 895 |
+
) -> dict[str, Any]:
|
| 896 |
+
"""Verify generated SFT rows carry successful verifier-backed rewards."""
|
| 897 |
+
|
| 898 |
+
checked_rows = 0
|
| 899 |
+
failed_rows: list[str] = []
|
| 900 |
+
rewards: list[float] = []
|
| 901 |
+
rows_by_split: dict[str, int] = {}
|
| 902 |
+
rows_by_difficulty: dict[str, int] = {}
|
| 903 |
+
for split_name in ("train", "validation", "hidden_eval"):
|
| 904 |
+
path = out_dir / f"{split_name}.jsonl"
|
| 905 |
+
rows = _read_jsonl(path)
|
| 906 |
+
if not rows and split_name != "train":
|
| 907 |
+
continue
|
| 908 |
+
rows_by_split[split_name] = len(rows)
|
| 909 |
+
for index, row in enumerate(rows, start=1):
|
| 910 |
+
ok, error, reward = _verify_sft_row_reward(
|
| 911 |
+
row,
|
| 912 |
+
min_terminal_reward=min_terminal_reward,
|
| 913 |
+
path=path,
|
| 914 |
+
line_number=index,
|
| 915 |
+
)
|
| 916 |
+
checked_rows += 1
|
| 917 |
+
if reward:
|
| 918 |
+
rewards.append(reward)
|
| 919 |
+
if not ok:
|
| 920 |
+
failed_rows.append(error)
|
| 921 |
+
metadata = row.get("metadata") if isinstance(row, dict) else {}
|
| 922 |
+
if isinstance(metadata, dict) and "difficulty" in metadata:
|
| 923 |
+
difficulty_key = str(int(metadata.get("difficulty", 0)))
|
| 924 |
+
rows_by_difficulty[difficulty_key] = rows_by_difficulty.get(difficulty_key, 0) + 1
|
| 925 |
+
passed = not failed_rows and (checked_rows > 0 or not require_train_rows)
|
| 926 |
+
if require_train_rows and rows_by_split.get("train", 0) <= 0:
|
| 927 |
+
passed = False
|
| 928 |
+
failed_rows.append(f"{out_dir / 'train.jsonl'}: no train rows found")
|
| 929 |
+
missing_difficulties = [
|
| 930 |
+
int(level)
|
| 931 |
+
for level in required_difficulties
|
| 932 |
+
if rows_by_difficulty.get(str(int(level)), 0) <= 0
|
| 933 |
+
]
|
| 934 |
+
if missing_difficulties:
|
| 935 |
+
passed = False
|
| 936 |
+
failed_rows.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
|
| 937 |
+
return {
|
| 938 |
+
"passed": passed,
|
| 939 |
+
"checked_rows": checked_rows,
|
| 940 |
+
"failed_rows": failed_rows[:50],
|
| 941 |
+
"failure_count": len(failed_rows),
|
| 942 |
+
"rows_by_split": rows_by_split,
|
| 943 |
+
"rows_by_difficulty": rows_by_difficulty,
|
| 944 |
+
"required_difficulties": [int(level) for level in required_difficulties],
|
| 945 |
+
"missing_difficulties": missing_difficulties,
|
| 946 |
+
"min_terminal_reward": float(min_terminal_reward),
|
| 947 |
+
"reward_summary": _reward_summary(rewards),
|
| 948 |
+
}
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
def _resolved_worker_count(config: DatasetConfig, job_count: int) -> int:
|
| 952 |
+
if job_count <= 1:
|
| 953 |
+
return 1
|
| 954 |
+
if int(config.workers) > 0:
|
| 955 |
+
return max(1, min(int(config.workers), job_count))
|
| 956 |
+
cpu_count = os.cpu_count() or 4
|
| 957 |
+
return max(1, min(8, cpu_count, job_count))
|
| 958 |
+
|
| 959 |
+
|
| 960 |
def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
|
| 961 |
config.out_dir.mkdir(parents=True, exist_ok=True)
|
| 962 |
+
teacher_local = threading.local()
|
| 963 |
+
teacher_token = None
|
| 964 |
if not config.dry_run_oracle:
|
| 965 |
+
teacher_token = os.getenv("HF_TOKEN")
|
| 966 |
+
if not teacher_token:
|
| 967 |
raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 968 |
|
| 969 |
+
def teacher_for_thread() -> HuggingFaceTeacher | None:
|
| 970 |
+
if config.dry_run_oracle:
|
| 971 |
+
return None
|
| 972 |
+
teacher = getattr(teacher_local, "teacher", None)
|
| 973 |
+
if teacher is None:
|
| 974 |
+
teacher = HuggingFaceTeacher(
|
| 975 |
+
model=config.teacher_model,
|
| 976 |
+
token=str(teacher_token),
|
| 977 |
+
max_tokens=config.max_tokens,
|
| 978 |
+
temperature=config.temperature,
|
| 979 |
+
top_p=config.top_p,
|
| 980 |
+
)
|
| 981 |
+
teacher_local.teacher = teacher
|
| 982 |
+
return teacher
|
| 983 |
+
|
| 984 |
+
difficulty_levels = _difficulty_levels(config)
|
| 985 |
+
difficulty_bucket_count = _configure_difficulty_buckets(config, difficulty_levels)
|
| 986 |
+
validation_seed_start = config.seed_start + int(config.episodes) * len(difficulty_levels)
|
| 987 |
split_jobs = [(config.split, config.episodes, config.seed_start)]
|
| 988 |
if config.validation_episodes:
|
| 989 |
+
split_jobs.append(("validation", config.validation_episodes, validation_seed_start))
|
| 990 |
+
episode_jobs = [
|
| 991 |
+
{
|
| 992 |
+
"order": job_order,
|
| 993 |
+
"split": split,
|
| 994 |
+
"difficulty": int(difficulty),
|
| 995 |
+
"seed": int(seed_start) + difficulty_index * int(episodes) + offset,
|
| 996 |
+
}
|
| 997 |
+
for job_order, (split, episodes, seed_start) in enumerate(split_jobs)
|
| 998 |
+
for difficulty_index, difficulty in enumerate(difficulty_levels)
|
| 999 |
+
for offset in range(int(episodes))
|
| 1000 |
+
]
|
| 1001 |
|
| 1002 |
rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
|
| 1003 |
attempts: list[dict[str, Any]] = []
|
| 1004 |
rewards: list[float] = []
|
| 1005 |
accepted = 0
|
| 1006 |
+
attempted = len(episode_jobs)
|
| 1007 |
+
workers = _resolved_worker_count(config, attempted)
|
| 1008 |
+
|
| 1009 |
+
def run_job(job: dict[str, Any]) -> dict[str, Any]:
|
| 1010 |
+
seed = int(job["seed"])
|
| 1011 |
+
split = str(job["split"])
|
| 1012 |
+
difficulty = int(job["difficulty"])
|
| 1013 |
+
return {
|
| 1014 |
+
"order": int(job["order"]),
|
| 1015 |
+
**run_episode(
|
| 1016 |
seed=seed,
|
| 1017 |
split=split,
|
| 1018 |
+
difficulty=difficulty,
|
| 1019 |
config=config,
|
| 1020 |
+
teacher=teacher_for_thread(),
|
| 1021 |
+
),
|
| 1022 |
+
}
|
| 1023 |
+
|
| 1024 |
+
results: list[dict[str, Any]] = []
|
| 1025 |
+
with ThreadPoolExecutor(max_workers=workers, thread_name_prefix="sft-episode") as executor:
|
| 1026 |
+
futures = [executor.submit(run_job, job) for job in episode_jobs]
|
| 1027 |
+
for future in as_completed(futures):
|
| 1028 |
+
result = future.result()
|
| 1029 |
+
results.append(result)
|
| 1030 |
+
if config.progress:
|
| 1031 |
+
print(
|
| 1032 |
+
json.dumps(
|
| 1033 |
+
{
|
| 1034 |
+
"event": "episode_done",
|
| 1035 |
+
"accepted": bool(result.get("accepted")),
|
| 1036 |
+
"split": result.get("split"),
|
| 1037 |
+
"difficulty": result.get("difficulty"),
|
| 1038 |
+
"seed": result.get("seed"),
|
| 1039 |
+
"reason": result.get("reason", ""),
|
| 1040 |
+
},
|
| 1041 |
+
sort_keys=True,
|
| 1042 |
+
),
|
| 1043 |
+
flush=True,
|
| 1044 |
+
)
|
| 1045 |
+
|
| 1046 |
+
for result in sorted(
|
| 1047 |
+
results,
|
| 1048 |
+
key=lambda item: (
|
| 1049 |
+
str(item.get("split", "")),
|
| 1050 |
+
int(item.get("difficulty", 0)),
|
| 1051 |
+
int(item.get("seed", 0)),
|
| 1052 |
+
),
|
| 1053 |
+
):
|
| 1054 |
+
seed = int(result["seed"])
|
| 1055 |
+
split = str(result["split"])
|
| 1056 |
+
difficulty = int(result["difficulty"])
|
| 1057 |
+
attempts.append(
|
| 1058 |
+
{
|
| 1059 |
+
"seed": seed,
|
| 1060 |
+
"split": split,
|
| 1061 |
+
"difficulty": difficulty,
|
| 1062 |
+
"accepted": bool(result["accepted"]),
|
| 1063 |
+
"reason": result.get("reason", ""),
|
| 1064 |
+
"trajectory_path": str(_write_trajectory(config.out_dir, result["trajectory"])),
|
| 1065 |
+
}
|
| 1066 |
+
)
|
| 1067 |
+
if result["accepted"]:
|
| 1068 |
+
accepted += 1
|
| 1069 |
+
rows = list(result["rows"])
|
| 1070 |
+
rows_by_split.setdefault(split, []).extend(rows)
|
| 1071 |
+
rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))
|
| 1072 |
+
|
| 1073 |
+
for split_rows in rows_by_split.values():
|
| 1074 |
+
split_rows.sort(
|
| 1075 |
+
key=lambda row: (
|
| 1076 |
+
int((row.get("metadata") or {}).get("difficulty", 0)),
|
| 1077 |
+
int((row.get("metadata") or {}).get("seed", 0)),
|
| 1078 |
+
int((row.get("metadata") or {}).get("step", 0)),
|
| 1079 |
)
|
| 1080 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1081 |
|
| 1082 |
for split_name in ("train", "validation", config.split):
|
| 1083 |
write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))
|
| 1084 |
|
| 1085 |
+
reward_verification = verify_sft_dataset_rewards(
|
| 1086 |
+
config.out_dir,
|
| 1087 |
+
min_terminal_reward=config.min_terminal_reward,
|
| 1088 |
+
require_train_rows=config.split == "train",
|
| 1089 |
+
required_difficulties=difficulty_levels if len(difficulty_levels) > 1 else (),
|
| 1090 |
+
)
|
| 1091 |
+
|
| 1092 |
+
accepted_by_difficulty: dict[str, int] = {}
|
| 1093 |
+
attempted_by_difficulty: dict[str, int] = {}
|
| 1094 |
+
reward_by_difficulty: dict[str, list[float]] = {}
|
| 1095 |
+
row_count_by_difficulty: dict[str, int] = {}
|
| 1096 |
+
for result in results:
|
| 1097 |
+
difficulty_key = str(int(result.get("difficulty", 0)))
|
| 1098 |
+
attempted_by_difficulty[difficulty_key] = attempted_by_difficulty.get(difficulty_key, 0) + 1
|
| 1099 |
+
if result.get("accepted"):
|
| 1100 |
+
accepted_by_difficulty[difficulty_key] = accepted_by_difficulty.get(difficulty_key, 0) + 1
|
| 1101 |
+
reward_by_difficulty.setdefault(difficulty_key, []).append(
|
| 1102 |
+
float((result.get("trajectory") or {}).get("terminal_total", 0.0))
|
| 1103 |
+
)
|
| 1104 |
+
for split_rows in rows_by_split.values():
|
| 1105 |
+
for row in split_rows:
|
| 1106 |
+
difficulty_key = str(int((row.get("metadata") or {}).get("difficulty", 0)))
|
| 1107 |
+
row_count_by_difficulty[difficulty_key] = row_count_by_difficulty.get(difficulty_key, 0) + 1
|
| 1108 |
+
|
| 1109 |
manifest = {
|
| 1110 |
"teacher_model": config.teacher_model,
|
| 1111 |
"target_model": config.target_model,
|
| 1112 |
"split": config.split,
|
| 1113 |
"difficulty": config.difficulty,
|
| 1114 |
+
"difficulty_levels": [int(level) for level in difficulty_levels],
|
| 1115 |
+
"difficulty_bucket_count": int(difficulty_bucket_count),
|
| 1116 |
+
"episodes_per_difficulty": config.episodes,
|
| 1117 |
+
"validation_episodes_per_difficulty": config.validation_episodes,
|
| 1118 |
"seed_start": config.seed_start,
|
| 1119 |
"episodes_attempted": attempted,
|
| 1120 |
"episodes_accepted": accepted,
|
| 1121 |
"acceptance_rate": accepted / attempted if attempted else 0.0,
|
| 1122 |
+
"attempted_by_difficulty": attempted_by_difficulty,
|
| 1123 |
+
"accepted_by_difficulty": accepted_by_difficulty,
|
| 1124 |
+
"rows_by_difficulty": row_count_by_difficulty,
|
| 1125 |
+
"reward_summary_by_difficulty": {
|
| 1126 |
+
key: _reward_summary(value) for key, value in sorted(reward_by_difficulty.items())
|
| 1127 |
+
},
|
| 1128 |
+
"workers": workers,
|
| 1129 |
"rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
|
| 1130 |
"reward_summary": _reward_summary(rewards),
|
| 1131 |
+
"reward_verification": reward_verification,
|
| 1132 |
"git_sha": _git_sha(),
|
| 1133 |
"verifier_version": "verifier_v1",
|
| 1134 |
"dry_run_oracle": config.dry_run_oracle,
|
| 1135 |
"attempts": attempts,
|
| 1136 |
}
|
| 1137 |
+
if config.push_to_hub:
|
| 1138 |
+
if not reward_verification["passed"]:
|
| 1139 |
+
raise RuntimeError("Reward verification failed; refusing to push dataset to Hub.")
|
| 1140 |
+
manifest["hub"] = {
|
| 1141 |
+
"repo_id": config.dataset_repo_id,
|
| 1142 |
+
"private": bool(config.hub_private),
|
| 1143 |
+
"url": f"https://huggingface.co/datasets/{config.dataset_repo_id}",
|
| 1144 |
+
}
|
| 1145 |
+
write_dataset_card(config.out_dir, manifest, config.dataset_repo_id)
|
| 1146 |
manifest_path = config.out_dir / "manifest.json"
|
| 1147 |
manifest_path.write_text(
|
| 1148 |
json.dumps(manifest, indent=2, sort_keys=True, default=str),
|
| 1149 |
encoding="utf-8",
|
| 1150 |
)
|
| 1151 |
+
if config.push_to_hub:
|
| 1152 |
+
hub_result = push_dataset_to_hub(
|
| 1153 |
+
config.out_dir,
|
| 1154 |
+
repo_id=config.dataset_repo_id,
|
| 1155 |
+
private=config.hub_private,
|
| 1156 |
+
)
|
| 1157 |
+
manifest["hub"].update(hub_result)
|
| 1158 |
+
manifest_path.write_text(
|
| 1159 |
+
json.dumps(manifest, indent=2, sort_keys=True, default=str),
|
| 1160 |
+
encoding="utf-8",
|
| 1161 |
+
)
|
| 1162 |
return manifest
|
| 1163 |
|
| 1164 |
|
|
|
|
| 1168 |
parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
|
| 1169 |
parser.add_argument("--split", default="train", choices=["train", "validation", "hidden_eval"])
|
| 1170 |
parser.add_argument("--difficulty", type=int, default=0)
|
| 1171 |
+
parser.add_argument(
|
| 1172 |
+
"--difficulty-levels",
|
| 1173 |
+
default="",
|
| 1174 |
+
help="Comma-separated curriculum levels to include, for example 0,1,2,3. "
|
| 1175 |
+
"When set, --episodes is per difficulty level.",
|
| 1176 |
+
)
|
| 1177 |
+
parser.add_argument(
|
| 1178 |
+
"--difficulty-buckets",
|
| 1179 |
+
type=int,
|
| 1180 |
+
default=0,
|
| 1181 |
+
help=(
|
| 1182 |
+
"Number of curriculum difficulty buckets to expose to the environment. "
|
| 1183 |
+
"Defaults to max(--difficulty-levels)+1."
|
| 1184 |
+
),
|
| 1185 |
+
)
|
| 1186 |
parser.add_argument("--seed-start", type=int, default=0)
|
| 1187 |
parser.add_argument("--episodes", type=int, default=100)
|
| 1188 |
parser.add_argument("--validation-episodes", type=int, default=0)
|
|
|
|
| 1192 |
parser.add_argument("--max-tokens", type=int, default=768)
|
| 1193 |
parser.add_argument("--temperature", type=float, default=0.2)
|
| 1194 |
parser.add_argument("--top-p", type=float, default=0.95)
|
| 1195 |
+
parser.add_argument(
|
| 1196 |
+
"--workers",
|
| 1197 |
+
type=int,
|
| 1198 |
+
default=0,
|
| 1199 |
+
help="Parallel episode workers. 0 auto-selects up to 8 workers.",
|
| 1200 |
+
)
|
| 1201 |
+
parser.add_argument(
|
| 1202 |
+
"--min-terminal-reward",
|
| 1203 |
+
type=float,
|
| 1204 |
+
default=12.0,
|
| 1205 |
+
help="Minimum verifier-backed terminal reward required for SFT rows.",
|
| 1206 |
+
)
|
| 1207 |
+
parser.add_argument(
|
| 1208 |
+
"--verify-only",
|
| 1209 |
+
action="store_true",
|
| 1210 |
+
help="Only verify an existing out-dir dataset reward metadata.",
|
| 1211 |
+
)
|
| 1212 |
+
parser.add_argument(
|
| 1213 |
+
"--push-to-hub",
|
| 1214 |
+
action="store_true",
|
| 1215 |
+
help="Upload the verified dataset folder to a Hugging Face dataset repo.",
|
| 1216 |
+
)
|
| 1217 |
+
parser.add_argument(
|
| 1218 |
+
"--progress",
|
| 1219 |
+
action="store_true",
|
| 1220 |
+
help="Print one JSON progress event for each completed episode job.",
|
| 1221 |
+
)
|
| 1222 |
+
parser.add_argument(
|
| 1223 |
+
"--push-only",
|
| 1224 |
+
action="store_true",
|
| 1225 |
+
help="Verify and upload an existing out-dir dataset without regenerating rows.",
|
| 1226 |
+
)
|
| 1227 |
+
parser.add_argument(
|
| 1228 |
+
"--dataset-repo-id",
|
| 1229 |
+
default="Humanlearning/CyberSecurity_OWASP-sft-dataset",
|
| 1230 |
+
help="Hugging Face dataset repo id used with --push-to-hub.",
|
| 1231 |
+
)
|
| 1232 |
+
parser.add_argument(
|
| 1233 |
+
"--hub-private",
|
| 1234 |
+
action="store_true",
|
| 1235 |
+
help="Create/upload the Hugging Face dataset repo as private.",
|
| 1236 |
+
)
|
| 1237 |
parser.add_argument(
|
| 1238 |
"--dry-run-oracle",
|
| 1239 |
action="store_true",
|
|
|
|
| 1248 |
target_model=args.target_model,
|
| 1249 |
split=args.split,
|
| 1250 |
difficulty=args.difficulty,
|
| 1251 |
+
difficulty_levels=_parse_int_csv(args.difficulty_levels),
|
| 1252 |
+
difficulty_buckets=args.difficulty_buckets,
|
| 1253 |
seed_start=args.seed_start,
|
| 1254 |
episodes=args.episodes,
|
| 1255 |
validation_episodes=args.validation_episodes,
|
|
|
|
| 1260 |
temperature=args.temperature,
|
| 1261 |
top_p=args.top_p,
|
| 1262 |
dry_run_oracle=args.dry_run_oracle,
|
| 1263 |
+
workers=args.workers,
|
| 1264 |
+
min_terminal_reward=args.min_terminal_reward,
|
| 1265 |
+
push_to_hub=args.push_to_hub,
|
| 1266 |
+
dataset_repo_id=args.dataset_repo_id,
|
| 1267 |
+
hub_private=args.hub_private,
|
| 1268 |
+
progress=args.progress,
|
| 1269 |
)
|
| 1270 |
|
| 1271 |
|
| 1272 |
def main(argv: list[str] | None = None) -> int:
|
| 1273 |
parser = build_arg_parser()
|
| 1274 |
args = parser.parse_args(argv)
|
| 1275 |
+
try:
|
| 1276 |
+
if args.verify_only:
|
| 1277 |
+
verification = verify_sft_dataset_rewards(
|
| 1278 |
+
args.out_dir,
|
| 1279 |
+
min_terminal_reward=args.min_terminal_reward,
|
| 1280 |
+
require_train_rows=args.split == "train",
|
| 1281 |
+
required_difficulties=_parse_int_csv(args.difficulty_levels),
|
| 1282 |
+
)
|
| 1283 |
+
print(json.dumps({"reward_verification": verification}, indent=2, sort_keys=True))
|
| 1284 |
+
return 0 if verification["passed"] else 2
|
| 1285 |
+
if args.push_only:
|
| 1286 |
+
result = push_existing_dataset(
|
| 1287 |
+
args.out_dir,
|
| 1288 |
+
repo_id=args.dataset_repo_id,
|
| 1289 |
+
private=args.hub_private,
|
| 1290 |
+
min_terminal_reward=args.min_terminal_reward,
|
| 1291 |
+
required_difficulties=_parse_int_csv(args.difficulty_levels),
|
| 1292 |
+
)
|
| 1293 |
+
print(json.dumps(result, indent=2, sort_keys=True))
|
| 1294 |
+
return 0
|
| 1295 |
+
manifest = generate_dataset(config_from_args(args))
|
| 1296 |
+
print(json.dumps(manifest, indent=2, sort_keys=True))
|
| 1297 |
+
return 0 if manifest.get("reward_verification", {}).get("passed") else 2
|
| 1298 |
+
except (RuntimeError, ValueError) as exc:
|
| 1299 |
+
print(
|
| 1300 |
+
json.dumps(
|
| 1301 |
+
{"error": str(exc), "error_type": exc.__class__.__name__},
|
| 1302 |
+
indent=2,
|
| 1303 |
+
sort_keys=True,
|
| 1304 |
+
)
|
| 1305 |
+
)
|
| 1306 |
+
return 2
|
| 1307 |
|
| 1308 |
|
| 1309 |
if __name__ == "__main__":
|
scripts/modal_train_sft.py
CHANGED
|
@@ -33,6 +33,15 @@ TRITON_CACHE_DIR = CACHE_DIR / "triton"
|
|
| 33 |
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
|
| 34 |
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
| 35 |
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
|
| 37 |
PUBLIC_REPO_BRANCH = "master"
|
| 38 |
|
|
@@ -50,6 +59,170 @@ def _model_repo_slug(model_name: str) -> str:
|
|
| 50 |
return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
|
| 51 |
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
def _configure_modal_cache_env() -> dict[str, str]:
|
| 54 |
values = {
|
| 55 |
"HF_HOME": str(HF_HOME_DIR),
|
|
@@ -171,7 +344,7 @@ def upload_sft_jsonl(relative_path: str, content: str) -> str:
|
|
| 171 |
|
| 172 |
@app.function(
|
| 173 |
image=training_image,
|
| 174 |
-
gpu=
|
| 175 |
timeout=12 * 60 * 60,
|
| 176 |
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
| 177 |
secrets=secrets,
|
|
@@ -179,24 +352,29 @@ def upload_sft_jsonl(relative_path: str, content: str) -> str:
|
|
| 179 |
def train_cybersecurity_owasp_sft(
|
| 180 |
train_jsonl: str = "/runs/sft/train.jsonl",
|
| 181 |
validation_jsonl: str = "/runs/sft/validation.jsonl",
|
| 182 |
-
|
|
|
|
| 183 |
model_name: str = DEFAULT_GEMMA_MODEL,
|
| 184 |
run_name: str = "",
|
| 185 |
max_seq_length: int = 4096,
|
| 186 |
-
max_steps: int =
|
| 187 |
num_train_epochs: float = 1.0,
|
| 188 |
-
per_device_train_batch_size: int =
|
| 189 |
-
gradient_accumulation_steps: int =
|
| 190 |
learning_rate: float = 2e-5,
|
| 191 |
lora_rank: int = 32,
|
| 192 |
-
trackio_space_id: str =
|
| 193 |
-
trackio_project: str =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
push_to_hub: bool = False,
|
| 195 |
) -> dict[str, Any]:
|
| 196 |
import inspect
|
| 197 |
|
| 198 |
from datasets import load_dataset
|
| 199 |
-
from huggingface_hub import snapshot_download
|
| 200 |
from trl import SFTConfig, SFTTrainer
|
| 201 |
from trl.chat_template_utils import add_response_schema
|
| 202 |
from unsloth import FastVisionModel
|
|
@@ -207,10 +385,9 @@ def train_cybersecurity_owasp_sft(
|
|
| 207 |
if not hf_token:
|
| 208 |
raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
)
|
| 214 |
stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
| 215 |
run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
|
| 216 |
output_dir = RUNS_DIR / run_name
|
|
@@ -222,6 +399,21 @@ def train_cybersecurity_owasp_sft(
|
|
| 222 |
has_validation = validation_path.exists() and validation_path.stat().st_size > 0
|
| 223 |
if has_validation:
|
| 224 |
data_files["validation"] = validation_jsonl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 225 |
dataset = load_dataset("json", data_files=data_files)
|
| 226 |
|
| 227 |
print(f"SFT run name: {run_name}")
|
|
@@ -232,6 +424,11 @@ def train_cybersecurity_owasp_sft(
|
|
| 232 |
print(f"Output repo: https://huggingface.co/{output_repo_id}")
|
| 233 |
print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
|
| 234 |
print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
|
| 236 |
try:
|
| 237 |
snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
|
|
@@ -280,19 +477,25 @@ def train_cybersecurity_owasp_sft(
|
|
| 280 |
"per_device_train_batch_size": per_device_train_batch_size,
|
| 281 |
"gradient_accumulation_steps": gradient_accumulation_steps,
|
| 282 |
"learning_rate": learning_rate,
|
|
|
|
| 283 |
"logging_steps": 1,
|
| 284 |
-
"
|
|
|
|
| 285 |
"report_to": "trackio",
|
| 286 |
"project": trackio_project,
|
| 287 |
"trackio_space_id": trackio_space_id,
|
| 288 |
"run_name": run_name,
|
| 289 |
"assistant_only_loss": True,
|
| 290 |
-
"packing":
|
|
|
|
|
|
|
|
|
|
| 291 |
"gradient_checkpointing": True,
|
| 292 |
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
| 293 |
"push_to_hub": push_to_hub,
|
| 294 |
"hub_model_id": output_repo_id,
|
| 295 |
"hub_private_repo": True,
|
|
|
|
| 296 |
}
|
| 297 |
sft_parameters = set(inspect.signature(SFTConfig).parameters)
|
| 298 |
skipped = sorted(set(sft_values) - sft_parameters)
|
|
@@ -335,6 +538,11 @@ def train_cybersecurity_owasp_sft(
|
|
| 335 |
"output_repo_id": output_repo_id,
|
| 336 |
"train_jsonl": train_jsonl,
|
| 337 |
"validation_jsonl": validation_jsonl if has_validation else "",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 338 |
"max_steps": max_steps,
|
| 339 |
"push_to_hub": push_to_hub,
|
| 340 |
"trackio_space_id": trackio_space_id,
|
|
@@ -365,20 +573,26 @@ def main(
|
|
| 365 |
mode: str = "train",
|
| 366 |
local_train_path: str = "outputs/sft/train.jsonl",
|
| 367 |
local_validation_path: str = "outputs/sft/validation.jsonl",
|
|
|
|
| 368 |
train_jsonl: str = "/runs/sft/train.jsonl",
|
| 369 |
validation_jsonl: str = "/runs/sft/validation.jsonl",
|
| 370 |
-
|
|
|
|
| 371 |
model_name: str = DEFAULT_GEMMA_MODEL,
|
| 372 |
run_name: str = "",
|
| 373 |
max_seq_length: int = 4096,
|
| 374 |
-
max_steps: int =
|
| 375 |
num_train_epochs: float = 1.0,
|
| 376 |
-
per_device_train_batch_size: int =
|
| 377 |
-
gradient_accumulation_steps: int =
|
| 378 |
learning_rate: float = 2e-5,
|
| 379 |
lora_rank: int = 32,
|
| 380 |
-
trackio_space_id: str =
|
| 381 |
-
trackio_project: str =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
source_mode: str = "local",
|
| 383 |
repo_url: str = PUBLIC_REPO_URL,
|
| 384 |
repo_branch: str = PUBLIC_REPO_BRANCH,
|
|
@@ -392,6 +606,22 @@ def main(
|
|
| 392 |
|
| 393 |
local_train = pathlib.Path(local_train_path)
|
| 394 |
local_validation = pathlib.Path(local_validation_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 395 |
if local_train.exists():
|
| 396 |
uploaded = upload_sft_jsonl.remote(
|
| 397 |
"sft/train.jsonl",
|
|
@@ -406,6 +636,13 @@ def main(
|
|
| 406 |
)
|
| 407 |
print(f"Uploaded validation JSONL: {uploaded_validation}")
|
| 408 |
validation_jsonl = uploaded_validation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 409 |
if mode == "upload":
|
| 410 |
return
|
| 411 |
|
|
@@ -416,6 +653,7 @@ def main(
|
|
| 416 |
kwargs = dict(
|
| 417 |
train_jsonl=train_jsonl,
|
| 418 |
validation_jsonl=validation_jsonl,
|
|
|
|
| 419 |
output_repo_id=output_repo_id,
|
| 420 |
model_name=model_name,
|
| 421 |
run_name=run_name,
|
|
@@ -428,11 +666,19 @@ def main(
|
|
| 428 |
lora_rank=lora_rank,
|
| 429 |
trackio_space_id=trackio_space_id,
|
| 430 |
trackio_project=trackio_project,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 431 |
push_to_hub=push_to_hub,
|
| 432 |
)
|
| 433 |
print(f"SFT run name: {run_name}")
|
| 434 |
print(f"Train JSONL: {train_jsonl}")
|
| 435 |
print(f"Validation JSONL: {validation_jsonl}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
print(f"Hub push enabled: {push_to_hub}")
|
| 437 |
if detach:
|
| 438 |
call = train_cybersecurity_owasp_sft.spawn(**kwargs)
|
|
|
|
| 33 |
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
|
| 34 |
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
| 35 |
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
|
| 36 |
+
SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]
|
| 37 |
+
DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"
|
| 38 |
+
DEFAULT_TOTAL_TRAIN_EPISODES = 300
|
| 39 |
+
DEFAULT_EPISODES_PER_LEVEL = 75
|
| 40 |
+
DEFAULT_TRACKIO_SPACE_ID = "Humanlearning/CyberSecurity_OWASP-trackio"
|
| 41 |
+
DEFAULT_TRACKIO_PROJECT = "CyberSecurity_OWASP-sft"
|
| 42 |
+
DEFAULT_SFT_OUTPUT_REPO_ID = (
|
| 43 |
+
"Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"
|
| 44 |
+
)
|
| 45 |
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
|
| 46 |
PUBLIC_REPO_BRANCH = "master"
|
| 47 |
|
|
|
|
| 59 |
return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
|
| 60 |
|
| 61 |
|
| 62 |
+
def _parse_int_csv(value: str) -> set[int]:
|
| 63 |
+
if not value.strip():
|
| 64 |
+
return set()
|
| 65 |
+
return {int(item.strip()) for item in value.split(",") if item.strip()}
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
SFT_ALLOWED_TOOLS = {
|
| 69 |
+
"inspect_policy_graph",
|
| 70 |
+
"list_routes",
|
| 71 |
+
"read_openapi",
|
| 72 |
+
"read_file",
|
| 73 |
+
"search_code",
|
| 74 |
+
"send_local_request",
|
| 75 |
+
"compare_identities",
|
| 76 |
+
"submit_diagnosis",
|
| 77 |
+
"patch_file",
|
| 78 |
+
"run_visible_tests",
|
| 79 |
+
"submit_fix",
|
| 80 |
+
"noop",
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _read_jsonl(path: pathlib.Path) -> list[dict[str, Any]]:
|
| 85 |
+
if not path.exists():
|
| 86 |
+
return []
|
| 87 |
+
rows: list[dict[str, Any]] = []
|
| 88 |
+
for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
|
| 89 |
+
if not line.strip():
|
| 90 |
+
continue
|
| 91 |
+
try:
|
| 92 |
+
row = json.loads(line)
|
| 93 |
+
except json.JSONDecodeError as exc:
|
| 94 |
+
raise ValueError(f"{path}:{line_number}: invalid JSONL: {exc}") from exc
|
| 95 |
+
if not isinstance(row, dict):
|
| 96 |
+
raise ValueError(f"{path}:{line_number}: row must be a JSON object")
|
| 97 |
+
rows.append(row)
|
| 98 |
+
return rows
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _verify_sft_rows(
|
| 102 |
+
path: pathlib.Path,
|
| 103 |
+
*,
|
| 104 |
+
min_terminal_reward: float,
|
| 105 |
+
) -> tuple[list[str], list[float], int, set[int]]:
|
| 106 |
+
rows = _read_jsonl(path)
|
| 107 |
+
failures: list[str] = []
|
| 108 |
+
rewards: list[float] = []
|
| 109 |
+
difficulties: set[int] = set()
|
| 110 |
+
for index, row in enumerate(rows, start=1):
|
| 111 |
+
messages = row.get("messages")
|
| 112 |
+
if not isinstance(messages, list) or len(messages) < 3:
|
| 113 |
+
failures.append(f"{path}:{index}: messages must include system/user/assistant")
|
| 114 |
+
continue
|
| 115 |
+
assistant = messages[-1]
|
| 116 |
+
if assistant.get("role") != "assistant":
|
| 117 |
+
failures.append(f"{path}:{index}: final message must be assistant")
|
| 118 |
+
continue
|
| 119 |
+
try:
|
| 120 |
+
action = json.loads(str(assistant.get("content", "")))
|
| 121 |
+
except json.JSONDecodeError as exc:
|
| 122 |
+
failures.append(f"{path}:{index}: assistant content is not JSON: {exc}")
|
| 123 |
+
continue
|
| 124 |
+
if not isinstance(action, dict) or action.get("tool_name") not in SFT_ALLOWED_TOOLS:
|
| 125 |
+
failures.append(f"{path}:{index}: assistant content is not a valid tool action")
|
| 126 |
+
continue
|
| 127 |
+
metadata = row.get("metadata")
|
| 128 |
+
if not isinstance(metadata, dict):
|
| 129 |
+
failures.append(f"{path}:{index}: missing metadata")
|
| 130 |
+
continue
|
| 131 |
+
if metadata.get("final_success") is not True:
|
| 132 |
+
failures.append(f"{path}:{index}: final_success is not true")
|
| 133 |
+
continue
|
| 134 |
+
if metadata.get("anti_cheat_flags") or []:
|
| 135 |
+
failures.append(f"{path}:{index}: anti-cheat flags present")
|
| 136 |
+
continue
|
| 137 |
+
if "difficulty" in metadata:
|
| 138 |
+
difficulties.add(int(metadata.get("difficulty", 0)))
|
| 139 |
+
terminal_reward = float(metadata.get("terminal_total", 0.0) or 0.0)
|
| 140 |
+
rewards.append(terminal_reward)
|
| 141 |
+
if terminal_reward < min_terminal_reward:
|
| 142 |
+
failures.append(
|
| 143 |
+
f"{path}:{index}: terminal_total {terminal_reward:.3f} below {min_terminal_reward:.3f}"
|
| 144 |
+
)
|
| 145 |
+
continue
|
| 146 |
+
breakdown = metadata.get("final_reward_breakdown") or {}
|
| 147 |
+
if not isinstance(breakdown, dict):
|
| 148 |
+
failures.append(f"{path}:{index}: missing final_reward_breakdown")
|
| 149 |
+
continue
|
| 150 |
+
for key in ("security", "regression", "public_routes", "patch_quality", "visible_tests"):
|
| 151 |
+
if float(breakdown.get(key, 0.0) or 0.0) <= 0.0:
|
| 152 |
+
failures.append(f"{path}:{index}: reward component {key} is not positive")
|
| 153 |
+
break
|
| 154 |
+
return failures, rewards, len(rows), difficulties
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def verify_sft_inputs(
|
| 158 |
+
*,
|
| 159 |
+
train_jsonl: str,
|
| 160 |
+
validation_jsonl: str = "",
|
| 161 |
+
manifest_path: str = "",
|
| 162 |
+
required_difficulties: str = "",
|
| 163 |
+
min_terminal_reward: float = 12.0,
|
| 164 |
+
min_train_rows: int = 1,
|
| 165 |
+
) -> dict[str, Any]:
|
| 166 |
+
train_path = pathlib.Path(train_jsonl)
|
| 167 |
+
validation_path = pathlib.Path(validation_jsonl) if validation_jsonl else pathlib.Path("")
|
| 168 |
+
failures, rewards, train_rows, difficulties = _verify_sft_rows(
|
| 169 |
+
train_path,
|
| 170 |
+
min_terminal_reward=min_terminal_reward,
|
| 171 |
+
)
|
| 172 |
+
validation_rows = 0
|
| 173 |
+
if validation_jsonl and validation_path.exists() and validation_path.stat().st_size > 0:
|
| 174 |
+
validation_failures, validation_rewards, validation_rows, validation_difficulties = _verify_sft_rows(
|
| 175 |
+
validation_path,
|
| 176 |
+
min_terminal_reward=min_terminal_reward,
|
| 177 |
+
)
|
| 178 |
+
failures.extend(validation_failures)
|
| 179 |
+
rewards.extend(validation_rewards)
|
| 180 |
+
difficulties.update(validation_difficulties)
|
| 181 |
+
if train_rows < min_train_rows:
|
| 182 |
+
failures.append(f"{train_path}: expected at least {min_train_rows} train rows, found {train_rows}")
|
| 183 |
+
|
| 184 |
+
manifest_verification: dict[str, Any] = {}
|
| 185 |
+
manifest = pathlib.Path(manifest_path) if manifest_path else pathlib.Path("")
|
| 186 |
+
if manifest_path and manifest.exists():
|
| 187 |
+
try:
|
| 188 |
+
manifest_data = json.loads(manifest.read_text(encoding="utf-8"))
|
| 189 |
+
manifest_verification = dict(manifest_data.get("reward_verification") or {})
|
| 190 |
+
manifest_difficulties = {
|
| 191 |
+
int(item) for item in manifest_data.get("difficulty_levels", []) or []
|
| 192 |
+
}
|
| 193 |
+
except Exception as exc:
|
| 194 |
+
failures.append(f"{manifest}: could not read manifest reward verification: {exc}")
|
| 195 |
+
manifest_difficulties = set()
|
| 196 |
+
if manifest_verification and manifest_verification.get("passed") is not True:
|
| 197 |
+
failures.append(f"{manifest}: manifest reward_verification did not pass")
|
| 198 |
+
else:
|
| 199 |
+
manifest_difficulties = set()
|
| 200 |
+
|
| 201 |
+
required = _parse_int_csv(required_difficulties) or manifest_difficulties
|
| 202 |
+
missing_difficulties = sorted(level for level in required if level not in difficulties)
|
| 203 |
+
if missing_difficulties:
|
| 204 |
+
failures.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
|
| 205 |
+
|
| 206 |
+
reward_summary = {
|
| 207 |
+
"min": min(rewards) if rewards else 0.0,
|
| 208 |
+
"max": max(rewards) if rewards else 0.0,
|
| 209 |
+
"mean": (sum(rewards) / len(rewards)) if rewards else 0.0,
|
| 210 |
+
}
|
| 211 |
+
return {
|
| 212 |
+
"passed": not failures,
|
| 213 |
+
"failure_count": len(failures),
|
| 214 |
+
"failures": failures[:50],
|
| 215 |
+
"train_rows": train_rows,
|
| 216 |
+
"validation_rows": validation_rows,
|
| 217 |
+
"difficulties": sorted(difficulties),
|
| 218 |
+
"required_difficulties": sorted(required),
|
| 219 |
+
"missing_difficulties": missing_difficulties,
|
| 220 |
+
"min_terminal_reward": float(min_terminal_reward),
|
| 221 |
+
"reward_summary": reward_summary,
|
| 222 |
+
"manifest_reward_verification": manifest_verification,
|
| 223 |
+
}
|
| 224 |
+
|
| 225 |
+
|
| 226 |
def _configure_modal_cache_env() -> dict[str, str]:
|
| 227 |
values = {
|
| 228 |
"HF_HOME": str(HF_HOME_DIR),
|
|
|
|
| 344 |
|
| 345 |
@app.function(
|
| 346 |
image=training_image,
|
| 347 |
+
gpu=SFT_GPU_FALLBACK,
|
| 348 |
timeout=12 * 60 * 60,
|
| 349 |
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
| 350 |
secrets=secrets,
|
|
|
|
| 352 |
def train_cybersecurity_owasp_sft(
|
| 353 |
train_jsonl: str = "/runs/sft/train.jsonl",
|
| 354 |
validation_jsonl: str = "/runs/sft/validation.jsonl",
|
| 355 |
+
manifest_path: str = "/runs/sft/manifest.json",
|
| 356 |
+
output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
|
| 357 |
model_name: str = DEFAULT_GEMMA_MODEL,
|
| 358 |
run_name: str = "",
|
| 359 |
max_seq_length: int = 4096,
|
| 360 |
+
max_steps: int = -1,
|
| 361 |
num_train_epochs: float = 1.0,
|
| 362 |
+
per_device_train_batch_size: int = 4,
|
| 363 |
+
gradient_accumulation_steps: int = 4,
|
| 364 |
learning_rate: float = 2e-5,
|
| 365 |
lora_rank: int = 32,
|
| 366 |
+
trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
|
| 367 |
+
trackio_project: str = DEFAULT_TRACKIO_PROJECT,
|
| 368 |
+
require_reward_verification: bool = True,
|
| 369 |
+
required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
|
| 370 |
+
min_terminal_reward: float = 12.0,
|
| 371 |
+
min_train_rows: int = 1,
|
| 372 |
push_to_hub: bool = False,
|
| 373 |
) -> dict[str, Any]:
|
| 374 |
import inspect
|
| 375 |
|
| 376 |
from datasets import load_dataset
|
| 377 |
+
from huggingface_hub import snapshot_download
|
| 378 |
from trl import SFTConfig, SFTTrainer
|
| 379 |
from trl.chat_template_utils import add_response_schema
|
| 380 |
from unsloth import FastVisionModel
|
|
|
|
| 385 |
if not hf_token:
|
| 386 |
raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
|
| 387 |
|
| 388 |
+
output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID
|
| 389 |
+
os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
|
| 390 |
+
os.environ["TRACKIO_PROJECT"] = trackio_project
|
|
|
|
| 391 |
stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
| 392 |
run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
|
| 393 |
output_dir = RUNS_DIR / run_name
|
|
|
|
| 399 |
has_validation = validation_path.exists() and validation_path.stat().st_size > 0
|
| 400 |
if has_validation:
|
| 401 |
data_files["validation"] = validation_jsonl
|
| 402 |
+
|
| 403 |
+
reward_preflight = verify_sft_inputs(
|
| 404 |
+
train_jsonl=train_jsonl,
|
| 405 |
+
validation_jsonl=validation_jsonl if has_validation else "",
|
| 406 |
+
manifest_path=manifest_path,
|
| 407 |
+
required_difficulties=required_difficulties,
|
| 408 |
+
min_terminal_reward=min_terminal_reward,
|
| 409 |
+
min_train_rows=min_train_rows,
|
| 410 |
+
)
|
| 411 |
+
print(f"SFT reward preflight: {json.dumps(reward_preflight, sort_keys=True)}")
|
| 412 |
+
if require_reward_verification and not reward_preflight["passed"]:
|
| 413 |
+
raise RuntimeError(
|
| 414 |
+
"SFT reward verification failed; refusing to start model training. "
|
| 415 |
+
f"Failures: {reward_preflight['failures']}"
|
| 416 |
+
)
|
| 417 |
dataset = load_dataset("json", data_files=data_files)
|
| 418 |
|
| 419 |
print(f"SFT run name: {run_name}")
|
|
|
|
| 424 |
print(f"Output repo: https://huggingface.co/{output_repo_id}")
|
| 425 |
print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
|
| 426 |
print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
|
| 427 |
+
print(
|
| 428 |
+
"SFT target: "
|
| 429 |
+
f"{DEFAULT_TOTAL_TRAIN_EPISODES} total train episodes, "
|
| 430 |
+
f"{DEFAULT_EPISODES_PER_LEVEL} per level across {DEFAULT_CURRICULUM_LEVELS}"
|
| 431 |
+
)
|
| 432 |
|
| 433 |
try:
|
| 434 |
snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
|
|
|
|
| 477 |
"per_device_train_batch_size": per_device_train_batch_size,
|
| 478 |
"gradient_accumulation_steps": gradient_accumulation_steps,
|
| 479 |
"learning_rate": learning_rate,
|
| 480 |
+
"optim": "adamw_8bit",
|
| 481 |
"logging_steps": 1,
|
| 482 |
+
"logging_first_step": True,
|
| 483 |
+
"save_steps": max(10, max_steps) if max_steps > 0 else 100,
|
| 484 |
"report_to": "trackio",
|
| 485 |
"project": trackio_project,
|
| 486 |
"trackio_space_id": trackio_space_id,
|
| 487 |
"run_name": run_name,
|
| 488 |
"assistant_only_loss": True,
|
| 489 |
+
"packing": True,
|
| 490 |
+
"packing_strategy": "bfd",
|
| 491 |
+
"bf16": True,
|
| 492 |
+
"tf32": True,
|
| 493 |
"gradient_checkpointing": True,
|
| 494 |
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
| 495 |
"push_to_hub": push_to_hub,
|
| 496 |
"hub_model_id": output_repo_id,
|
| 497 |
"hub_private_repo": True,
|
| 498 |
+
"hub_strategy": "every_save",
|
| 499 |
}
|
| 500 |
sft_parameters = set(inspect.signature(SFTConfig).parameters)
|
| 501 |
skipped = sorted(set(sft_values) - sft_parameters)
|
|
|
|
| 538 |
"output_repo_id": output_repo_id,
|
| 539 |
"train_jsonl": train_jsonl,
|
| 540 |
"validation_jsonl": validation_jsonl if has_validation else "",
|
| 541 |
+
"manifest_path": manifest_path,
|
| 542 |
+
"reward_preflight": reward_preflight,
|
| 543 |
+
"required_difficulties": required_difficulties,
|
| 544 |
+
"default_total_train_episodes": DEFAULT_TOTAL_TRAIN_EPISODES,
|
| 545 |
+
"default_episodes_per_level": DEFAULT_EPISODES_PER_LEVEL,
|
| 546 |
"max_steps": max_steps,
|
| 547 |
"push_to_hub": push_to_hub,
|
| 548 |
"trackio_space_id": trackio_space_id,
|
|
|
|
| 573 |
mode: str = "train",
|
| 574 |
local_train_path: str = "outputs/sft/train.jsonl",
|
| 575 |
local_validation_path: str = "outputs/sft/validation.jsonl",
|
| 576 |
+
local_manifest_path: str = "outputs/sft/manifest.json",
|
| 577 |
train_jsonl: str = "/runs/sft/train.jsonl",
|
| 578 |
validation_jsonl: str = "/runs/sft/validation.jsonl",
|
| 579 |
+
manifest_path: str = "/runs/sft/manifest.json",
|
| 580 |
+
output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
|
| 581 |
model_name: str = DEFAULT_GEMMA_MODEL,
|
| 582 |
run_name: str = "",
|
| 583 |
max_seq_length: int = 4096,
|
| 584 |
+
max_steps: int = -1,
|
| 585 |
num_train_epochs: float = 1.0,
|
| 586 |
+
per_device_train_batch_size: int = 4,
|
| 587 |
+
gradient_accumulation_steps: int = 4,
|
| 588 |
learning_rate: float = 2e-5,
|
| 589 |
lora_rank: int = 32,
|
| 590 |
+
trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
|
| 591 |
+
trackio_project: str = DEFAULT_TRACKIO_PROJECT,
|
| 592 |
+
require_reward_verification: bool = True,
|
| 593 |
+
required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
|
| 594 |
+
min_terminal_reward: float = 12.0,
|
| 595 |
+
min_train_rows: int = 1,
|
| 596 |
source_mode: str = "local",
|
| 597 |
repo_url: str = PUBLIC_REPO_URL,
|
| 598 |
repo_branch: str = PUBLIC_REPO_BRANCH,
|
|
|
|
| 606 |
|
| 607 |
local_train = pathlib.Path(local_train_path)
|
| 608 |
local_validation = pathlib.Path(local_validation_path)
|
| 609 |
+
local_manifest = pathlib.Path(local_manifest_path)
|
| 610 |
+
if require_reward_verification and local_train.exists():
|
| 611 |
+
local_reward_preflight = verify_sft_inputs(
|
| 612 |
+
train_jsonl=str(local_train),
|
| 613 |
+
validation_jsonl=str(local_validation) if local_validation.exists() else "",
|
| 614 |
+
manifest_path=str(local_manifest) if local_manifest.exists() else "",
|
| 615 |
+
required_difficulties=required_difficulties,
|
| 616 |
+
min_terminal_reward=min_terminal_reward,
|
| 617 |
+
min_train_rows=min_train_rows,
|
| 618 |
+
)
|
| 619 |
+
print(f"Local SFT reward preflight: {json.dumps(local_reward_preflight, sort_keys=True)}")
|
| 620 |
+
if not local_reward_preflight["passed"]:
|
| 621 |
+
raise RuntimeError(
|
| 622 |
+
"Local SFT reward verification failed; refusing to upload/train. "
|
| 623 |
+
f"Failures: {local_reward_preflight['failures']}"
|
| 624 |
+
)
|
| 625 |
if local_train.exists():
|
| 626 |
uploaded = upload_sft_jsonl.remote(
|
| 627 |
"sft/train.jsonl",
|
|
|
|
| 636 |
)
|
| 637 |
print(f"Uploaded validation JSONL: {uploaded_validation}")
|
| 638 |
validation_jsonl = uploaded_validation
|
| 639 |
+
if local_manifest.exists():
|
| 640 |
+
uploaded_manifest = upload_sft_jsonl.remote(
|
| 641 |
+
"sft/manifest.json",
|
| 642 |
+
local_manifest.read_text(encoding="utf-8"),
|
| 643 |
+
)
|
| 644 |
+
print(f"Uploaded manifest: {uploaded_manifest}")
|
| 645 |
+
manifest_path = uploaded_manifest
|
| 646 |
if mode == "upload":
|
| 647 |
return
|
| 648 |
|
|
|
|
| 653 |
kwargs = dict(
|
| 654 |
train_jsonl=train_jsonl,
|
| 655 |
validation_jsonl=validation_jsonl,
|
| 656 |
+
manifest_path=manifest_path,
|
| 657 |
output_repo_id=output_repo_id,
|
| 658 |
model_name=model_name,
|
| 659 |
run_name=run_name,
|
|
|
|
| 666 |
lora_rank=lora_rank,
|
| 667 |
trackio_space_id=trackio_space_id,
|
| 668 |
trackio_project=trackio_project,
|
| 669 |
+
require_reward_verification=require_reward_verification,
|
| 670 |
+
required_difficulties=required_difficulties,
|
| 671 |
+
min_terminal_reward=min_terminal_reward,
|
| 672 |
+
min_train_rows=min_train_rows,
|
| 673 |
push_to_hub=push_to_hub,
|
| 674 |
)
|
| 675 |
print(f"SFT run name: {run_name}")
|
| 676 |
print(f"Train JSONL: {train_jsonl}")
|
| 677 |
print(f"Validation JSONL: {validation_jsonl}")
|
| 678 |
+
print(f"Manifest: {manifest_path}")
|
| 679 |
+
print(f"Reward verification required: {require_reward_verification}")
|
| 680 |
+
if required_difficulties:
|
| 681 |
+
print(f"Required curriculum difficulties: {required_difficulties}")
|
| 682 |
print(f"Hub push enabled: {push_to_hub}")
|
| 683 |
if detach:
|
| 684 |
call = train_cybersecurity_owasp_sft.spawn(**kwargs)
|
tests/test_modal_scenario_cache_static.py
CHANGED
|
@@ -37,3 +37,40 @@ def test_modal_training_is_pinned_to_gemma4_e2b():
|
|
| 37 |
assert "from unsloth import FastVisionModel" in source
|
| 38 |
assert "Qwen" not in source
|
| 39 |
assert "FastLanguageModel" not in source
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
assert "from unsloth import FastVisionModel" in source
|
| 38 |
assert "Qwen" not in source
|
| 39 |
assert "FastLanguageModel" not in source
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
|
| 43 |
+
source = (ROOT / "scripts" / "modal_train_sft.py").read_text(encoding="utf-8")
|
| 44 |
+
|
| 45 |
+
assert 'SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]' in source
|
| 46 |
+
assert "gpu=SFT_GPU_FALLBACK" in source
|
| 47 |
+
assert "DEFAULT_TOTAL_TRAIN_EPISODES = 300" in source
|
| 48 |
+
assert "DEFAULT_EPISODES_PER_LEVEL = 75" in source
|
| 49 |
+
assert 'DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"' in source
|
| 50 |
+
assert (
|
| 51 |
+
'DEFAULT_SFT_OUTPUT_REPO_ID = (\n'
|
| 52 |
+
' "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"'
|
| 53 |
+
) in source
|
| 54 |
+
assert "output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID" in source
|
| 55 |
+
assert source.count("max_steps: int = -1") >= 2
|
| 56 |
+
assert source.count("per_device_train_batch_size: int = 4") >= 2
|
| 57 |
+
assert source.count("gradient_accumulation_steps: int = 4") >= 2
|
| 58 |
+
assert '"assistant_only_loss": True' in source
|
| 59 |
+
assert '"packing": True' in source
|
| 60 |
+
assert '"packing_strategy": "bfd"' in source
|
| 61 |
+
assert '"bf16": True' in source
|
| 62 |
+
assert '"tf32": True' in source
|
| 63 |
+
assert '"hub_strategy": "every_save"' in source
|
| 64 |
+
assert 'trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID' in source
|
| 65 |
+
assert 'trackio_project: str = DEFAULT_TRACKIO_PROJECT' in source
|
| 66 |
+
assert 'os.environ["TRACKIO_SPACE_ID"] = trackio_space_id' in source
|
| 67 |
+
assert 'os.environ["TRACKIO_PROJECT"] = trackio_project' in source
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def test_modal_grpo_loads_sft_adapter_from_hub_as_trainable_lora():
|
| 71 |
+
source = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")
|
| 72 |
+
|
| 73 |
+
assert "initial_adapter_repo_id" in source
|
| 74 |
+
assert "Downloading initial SFT adapter" in source
|
| 75 |
+
assert "snapshot_download(" in source
|
| 76 |
+
assert "PeftModel.from_pretrained(model, adapter_source, is_trainable=True)" in source
|
tests/test_sft_dataset_generation.py
CHANGED
|
@@ -90,11 +90,20 @@ def test_dry_run_oracle_creates_chat_jsonl_without_network():
|
|
| 90 |
validation_episodes=1,
|
| 91 |
out_dir=out_dir,
|
| 92 |
dry_run_oracle=True,
|
|
|
|
|
|
|
| 93 |
)
|
| 94 |
)
|
| 95 |
|
| 96 |
-
assert manifest["
|
| 97 |
-
assert manifest["
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 98 |
assert (out_dir / "train.jsonl").exists()
|
| 99 |
assert (out_dir / "validation.jsonl").exists()
|
| 100 |
train_rows = [
|
|
@@ -110,6 +119,43 @@ def test_dry_run_oracle_creates_chat_jsonl_without_network():
|
|
| 110 |
assert train_rows
|
| 111 |
assert validation_rows
|
| 112 |
assert all(row["messages"][-1]["role"] == "assistant" for row in train_rows)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
|
| 114 |
|
| 115 |
def test_saved_oracle_trajectory_replays_to_success():
|
|
|
|
| 90 |
validation_episodes=1,
|
| 91 |
out_dir=out_dir,
|
| 92 |
dry_run_oracle=True,
|
| 93 |
+
workers=2,
|
| 94 |
+
difficulty_levels=(0, 1),
|
| 95 |
)
|
| 96 |
)
|
| 97 |
|
| 98 |
+
assert manifest["difficulty_levels"] == [0, 1]
|
| 99 |
+
assert manifest["difficulty_bucket_count"] >= 2
|
| 100 |
+
assert manifest["episodes_attempted"] == 6
|
| 101 |
+
assert manifest["episodes_accepted"] == 6
|
| 102 |
+
assert manifest["workers"] == 2
|
| 103 |
+
assert manifest["reward_verification"]["passed"] is True
|
| 104 |
+
assert manifest["reward_verification"]["missing_difficulties"] == []
|
| 105 |
+
assert manifest["rows_by_difficulty"]["0"] > 0
|
| 106 |
+
assert manifest["rows_by_difficulty"]["1"] > 0
|
| 107 |
assert (out_dir / "train.jsonl").exists()
|
| 108 |
assert (out_dir / "validation.jsonl").exists()
|
| 109 |
train_rows = [
|
|
|
|
| 119 |
assert train_rows
|
| 120 |
assert validation_rows
|
| 121 |
assert all(row["messages"][-1]["role"] == "assistant" for row in train_rows)
|
| 122 |
+
reward_check = generate_sft_dataset.verify_sft_dataset_rewards(
|
| 123 |
+
out_dir,
|
| 124 |
+
required_difficulties=(0, 1),
|
| 125 |
+
)
|
| 126 |
+
assert reward_check["passed"] is True
|
| 127 |
+
assert (out_dir / "README.md").exists()
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def test_reward_verification_rejects_low_reward_rows():
|
| 131 |
+
out_dir = _isolated_out_dir("bad_reward")
|
| 132 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 133 |
+
action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
|
| 134 |
+
row = {
|
| 135 |
+
"messages": [
|
| 136 |
+
{"role": "system", "content": "system"},
|
| 137 |
+
{"role": "user", "content": "user"},
|
| 138 |
+
{"role": "assistant", "content": json.dumps(action.model_dump())},
|
| 139 |
+
],
|
| 140 |
+
"metadata": {
|
| 141 |
+
"final_success": True,
|
| 142 |
+
"terminal_total": 1.0,
|
| 143 |
+
"anti_cheat_flags": [],
|
| 144 |
+
"final_reward_breakdown": {
|
| 145 |
+
"security": 5.0,
|
| 146 |
+
"regression": 3.0,
|
| 147 |
+
"public_routes": 1.0,
|
| 148 |
+
"patch_quality": 2.0,
|
| 149 |
+
"visible_tests": 1.0,
|
| 150 |
+
},
|
| 151 |
+
},
|
| 152 |
+
}
|
| 153 |
+
(out_dir / "train.jsonl").write_text(json.dumps(row) + "\n", encoding="utf-8")
|
| 154 |
+
|
| 155 |
+
reward_check = generate_sft_dataset.verify_sft_dataset_rewards(out_dir)
|
| 156 |
+
|
| 157 |
+
assert reward_check["passed"] is False
|
| 158 |
+
assert reward_check["failure_count"] == 1
|
| 159 |
|
| 160 |
|
| 161 |
def test_saved_oracle_trajectory_replays_to_success():
|