Humanlearning committed
Commit 60f97ab · 1 Parent(s): f7b8ac6

feat: expand README with synthetic SFT dataset generation instructions, enhance dataset verification and pushing to Hugging Face Hub, and improve modal training scripts with default configurations for curriculum and GPU fallback

README.md CHANGED
@@ -256,6 +256,100 @@ The shell wrapper is equivalent:
256
  MODE=smoke EPISODES=4 uv run --extra modal bash scripts/modal_run_ephemeral.sh
257
  ```
258
 
259
+ ## Synthetic SFT Before GRPO
260
+
261
+ Use supervised fine-tuning to warm-start `unsloth/gemma-4-E2B-it` before GRPO.
262
+ The SFT generator executes every teacher action in the real environment and
263
+ keeps only trajectories that pass the deterministic reward verifier.
264
+
265
+ Generate a 300-train-episode curriculum SFT dataset across levels `0,1,2,3`:
266
+
267
+ ```bash
268
+ uv run python scripts/generate_sft_dataset.py \
269
+ --teacher-model deepseek-ai/DeepSeek-V4-Pro \
270
+ --target-model unsloth/gemma-4-E2B-it \
271
+ --difficulty-levels 0,1,2,3 \
272
+ --difficulty-buckets 4 \
273
+ --episodes 75 \
274
+ --validation-episodes 20 \
275
+ --workers 8 \
276
+ --out-dir outputs/sft
277
+ ```
278
+
279
+ `--episodes` is per difficulty level when `--difficulty-levels` is set, so
280
+ `--episodes 75` across four levels gives 300 total train episodes. Expect
281
+ roughly 2,400-4,500 chat-format JSONL rows because each successful trajectory
282
+ contributes one row per action step. The script writes JSONL rows under
283
+ `outputs/sft/`, trajectory artifacts under `outputs/sft/trajectories/`, a
284
+ dataset card at `outputs/sft/README.md`, and `outputs/sft/manifest.json` with
285
+ reward summaries and curriculum coverage.
286
+
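+ As a quick sanity check, the manifest can be inspected directly. A minimal
+ sketch, assuming the manifest field names shown above:
+
+ ```python
+ import json
+ from pathlib import Path
+
+ # Minimal sketch: inspect curriculum coverage in the generated manifest.
+ # Field names follow the manifest written by scripts/generate_sft_dataset.py.
+ manifest = json.loads(Path("outputs/sft/manifest.json").read_text(encoding="utf-8"))
+ print("acceptance rate:", manifest["acceptance_rate"])
+ print("rows by difficulty:", manifest["rows_by_difficulty"])
+ print("verifier passed:", manifest["reward_verification"]["passed"])
+ ```
+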
287
+ Verify reward metadata before any training run:
288
+
289
+ ```bash
290
+ uv run python scripts/generate_sft_dataset.py \
291
+ --verify-only \
292
+ --difficulty-levels 0,1,2,3 \
293
+ --out-dir outputs/sft
294
+ ```
295
+
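+ The command prints a single JSON verification report and exits with status 2
+ when any check fails. The report shape is roughly as follows (abridged; the
+ values here are illustrative):
+
+ ```json
+ {
+   "reward_verification": {
+     "passed": true,
+     "checked_rows": 2870,
+     "failure_count": 0,
+     "rows_by_split": {"train": 2450, "validation": 420},
+     "missing_difficulties": [],
+     "min_terminal_reward": 12.0
+   }
+ }
+ ```
+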
296
+ Push the verified dataset to Hugging Face Hub:
297
+
298
+ ```bash
299
+ uv run python scripts/generate_sft_dataset.py \
300
+ --push-only \
301
+ --difficulty-levels 0,1,2,3 \
302
+ --out-dir outputs/sft \
303
+ --dataset-repo-id Humanlearning/CyberSecurity_OWASP-sft-dataset
304
+ ```
305
+
306
+ The canonical dataset repo name is
307
+ `Humanlearning/CyberSecurity_OWASP-sft-dataset`. The upload is refused if
308
+ reward verification fails or `HF_TOKEN` is missing.
309
+
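+ Once pushed, the JSONL splits load like any Hub-hosted files. A minimal
+ sketch, assuming the `datasets` library and a valid `HF_TOKEN` for private
+ repos:
+
+ ```python
+ from datasets import load_dataset
+
+ # Minimal sketch: load the pushed SFT splits straight from the Hub.
+ repo = "Humanlearning/CyberSecurity_OWASP-sft-dataset"
+ ds = load_dataset(
+     "json",
+     data_files={
+         "train": f"hf://datasets/{repo}/train.jsonl",
+         "validation": f"hf://datasets/{repo}/validation.jsonl",
+     },
+ )
+ print(ds["train"][0]["messages"][-1]["role"])  # expected: "assistant"
+ ```
+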
310
+ You can also generate and push in one command by adding `--push-to-hub` to the
311
+ generation command.
312
+
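+ A minimal sketch of that combined invocation, reusing the flags documented
+ above:
+
+ ```bash
+ uv run python scripts/generate_sft_dataset.py \
+ --difficulty-levels 0,1,2,3 \
+ --episodes 75 \
+ --validation-episodes 20 \
+ --workers 8 \
+ --out-dir outputs/sft \
+ --push-to-hub \
+ --dataset-repo-id Humanlearning/CyberSecurity_OWASP-sft-dataset
+ ```
+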
313
+ For local CI or smoke checks, add `--dry-run-oracle`; official SFT data should
314
+ use the teacher path and still pass the verifier gate above.
315
+
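+ A smoke-sized sketch (episode counts deliberately tiny; suitable for CI, not
+ for producing training data):
+
+ ```bash
+ uv run python scripts/generate_sft_dataset.py \
+ --dry-run-oracle \
+ --difficulty-levels 0,1 \
+ --episodes 2 \
+ --out-dir outputs/sft-smoke \
+ --progress
+ ```
+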
316
+ Launch SFT on Modal after reward verification passes:
317
+
318
+ ```bash
319
+ uv run --extra modal modal run --detach scripts/modal_train_sft.py \
320
+ --local-train-path outputs/sft/train.jsonl \
321
+ --local-validation-path outputs/sft/validation.jsonl \
322
+ --local-manifest-path outputs/sft/manifest.json \
323
+ --required-difficulties 0,1,2,3 \
324
+ --trackio-space-id Humanlearning/CyberSecurity_OWASP-trackio \
325
+ --trackio-project CyberSecurity_OWASP-sft \
326
+ --output-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
327
+ --push-to-hub \
328
+ --detach
329
+ ```
330
+
331
+ `scripts/modal_train_sft.py` re-checks the JSONL reward metadata locally before
332
+ upload and again inside Modal before loading the model. It refuses to start SFT
333
+ unless all required curriculum difficulties are represented and the verifier
334
+ reward metadata passes. The default SFT config trains one full epoch
335
+ (`--max-steps -1`) with packed assistant-only loss, bf16/tf32, LoRA rank 32,
336
+ and Modal GPU fallback `H200 -> H100 -> A100-80GB -> L40S`. A warm run for the
337
+ 300-episode dataset should usually finish in about 15-45 minutes; first image
338
+ or model-cache builds can push that closer to 35-75 minutes.
339
+
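+ Any of those defaults can be overridden per flag. A minimal sketch of a
+ capped smoke run (step and batch values here are illustrative):
+
+ ```bash
+ uv run --extra modal modal run scripts/modal_train_sft.py \
+ --local-train-path outputs/sft/train.jsonl \
+ --local-validation-path outputs/sft/validation.jsonl \
+ --local-manifest-path outputs/sft/manifest.json \
+ --max-steps 25 \
+ --per-device-train-batch-size 2 \
+ --gradient-accumulation-steps 8
+ ```
+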
340
+ Continue GRPO from the SFT LoRA:
341
+
342
+ ```bash
343
+ uv run --extra modal modal run --detach scripts/modal_train_grpo.py \
344
+ --initial-adapter-repo-id Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora \
345
+ --max-steps 300 \
346
+ --dataset-size 64 \
347
+ --num-generations 8 \
348
+ --difficulty 0 \
349
+ --trace-log-every 10 \
350
+ --detach
351
+ ```
352
+
353
  ## Modal GRPO Training
354
 
355
  The persistent GPU training launcher packages this local repo into Modal, trains
scripts/generate_sft_dataset.py CHANGED
@@ -14,6 +14,8 @@ import json
14
  import os
15
  import statistics
16
  import subprocess
17
  from dataclasses import dataclass
18
  from pathlib import Path
19
  from typing import Any, Iterable
@@ -72,6 +74,14 @@ class DatasetConfig:
72
  temperature: float = 0.2
73
  top_p: float = 0.95
74
  dry_run_oracle: bool = False
75
 
76
 
77
  class HuggingFaceTeacher:
@@ -579,6 +589,174 @@ def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> None:
579
  handle.write(json.dumps(row, sort_keys=True, default=str) + "\n")
580
 
581
 
582
  def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
583
  traj_dir = out_dir / "trajectories"
584
  traj_dir.mkdir(parents=True, exist_ok=True)
@@ -622,80 +800,365 @@ def _reward_summary(values: list[float]) -> dict[str, float]:
622
  }
623
 
624
 
625
  def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
626
  config.out_dir.mkdir(parents=True, exist_ok=True)
627
- teacher = None
628
  if not config.dry_run_oracle:
629
- token = os.getenv("HF_TOKEN")
630
- if not token:
631
  raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
632
- teacher = HuggingFaceTeacher(
633
- model=config.teacher_model,
634
- token=token,
635
- max_tokens=config.max_tokens,
636
- temperature=config.temperature,
637
- top_p=config.top_p,
638
- )
639
 
640
  split_jobs = [(config.split, config.episodes, config.seed_start)]
641
  if config.validation_episodes:
642
- split_jobs.append(("validation", config.validation_episodes, config.seed_start + config.episodes))
643
 
644
  rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
645
  attempts: list[dict[str, Any]] = []
646
  rewards: list[float] = []
647
  accepted = 0
648
- attempted = 0
649
- for split, episodes, seed_start in split_jobs:
650
- for offset in range(int(episodes)):
651
- seed = int(seed_start) + offset
652
- attempted += 1
653
- result = run_episode(
654
  seed=seed,
655
  split=split,
656
- difficulty=config.difficulty,
657
  config=config,
658
- teacher=teacher,
659
- )
660
- attempts.append(
661
- {
662
- "seed": seed,
663
- "split": split,
664
- "accepted": bool(result["accepted"]),
665
- "reason": result.get("reason", ""),
666
- "trajectory_path": str(_write_trajectory(config.out_dir, result["trajectory"])),
667
- }
668
  )
669
- if result["accepted"]:
670
- accepted += 1
671
- rows = list(result["rows"])
672
- rows_by_split.setdefault(split, []).extend(rows)
673
- rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))
674
 
675
  for split_name in ("train", "validation", config.split):
676
  write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))
677
 
678
  manifest = {
679
  "teacher_model": config.teacher_model,
680
  "target_model": config.target_model,
681
  "split": config.split,
682
  "difficulty": config.difficulty,
683
  "seed_start": config.seed_start,
684
  "episodes_attempted": attempted,
685
  "episodes_accepted": accepted,
686
  "acceptance_rate": accepted / attempted if attempted else 0.0,
687
  "rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
688
  "reward_summary": _reward_summary(rewards),
689
  "git_sha": _git_sha(),
690
  "verifier_version": "verifier_v1",
691
  "dry_run_oracle": config.dry_run_oracle,
692
  "attempts": attempts,
693
  }
694
  manifest_path = config.out_dir / "manifest.json"
695
  manifest_path.write_text(
696
  json.dumps(manifest, indent=2, sort_keys=True, default=str),
697
  encoding="utf-8",
698
  )
699
  return manifest
700
 
701
 
@@ -705,6 +1168,21 @@ def build_arg_parser() -> argparse.ArgumentParser:
705
  parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
706
  parser.add_argument("--split", default="train", choices=["train", "validation", "hidden_eval"])
707
  parser.add_argument("--difficulty", type=int, default=0)
708
  parser.add_argument("--seed-start", type=int, default=0)
709
  parser.add_argument("--episodes", type=int, default=100)
710
  parser.add_argument("--validation-episodes", type=int, default=0)
@@ -714,6 +1192,48 @@ def build_arg_parser() -> argparse.ArgumentParser:
714
  parser.add_argument("--max-tokens", type=int, default=768)
715
  parser.add_argument("--temperature", type=float, default=0.2)
716
  parser.add_argument("--top-p", type=float, default=0.95)
717
  parser.add_argument(
718
  "--dry-run-oracle",
719
  action="store_true",
@@ -728,6 +1248,8 @@ def config_from_args(args: argparse.Namespace) -> DatasetConfig:
728
  target_model=args.target_model,
729
  split=args.split,
730
  difficulty=args.difficulty,
731
  seed_start=args.seed_start,
732
  episodes=args.episodes,
733
  validation_episodes=args.validation_episodes,
@@ -738,15 +1260,50 @@ def config_from_args(args: argparse.Namespace) -> DatasetConfig:
738
  temperature=args.temperature,
739
  top_p=args.top_p,
740
  dry_run_oracle=args.dry_run_oracle,
741
  )
742
 
743
 
744
  def main(argv: list[str] | None = None) -> int:
745
  parser = build_arg_parser()
746
  args = parser.parse_args(argv)
747
- manifest = generate_dataset(config_from_args(args))
748
- print(json.dumps(manifest, indent=2, sort_keys=True))
749
- return 0
750
 
751
 
752
  if __name__ == "__main__":
 
14
  import os
15
  import statistics
16
  import subprocess
17
+ import threading
18
+ from concurrent.futures import ThreadPoolExecutor, as_completed
19
  from dataclasses import dataclass
20
  from pathlib import Path
21
  from typing import Any, Iterable
 
74
  temperature: float = 0.2
75
  top_p: float = 0.95
76
  dry_run_oracle: bool = False
77
+ workers: int = 0
78
+ min_terminal_reward: float = 12.0
79
+ difficulty_levels: tuple[int, ...] = ()
80
+ difficulty_buckets: int = 0
81
+ push_to_hub: bool = False
82
+ dataset_repo_id: str = "Humanlearning/CyberSecurity_OWASP-sft-dataset"
83
+ hub_private: bool = False
84
+ progress: bool = False
85
 
86
 
87
  class HuggingFaceTeacher:
 
589
  handle.write(json.dumps(row, sort_keys=True, default=str) + "\n")
590
 
591
 
592
+ def write_dataset_card(out_dir: Path, manifest: dict[str, Any], dataset_repo_id: str) -> Path:
593
+ card_path = out_dir / "README.md"
594
+ difficulty_levels = manifest.get("difficulty_levels", [])
595
+ reward_verification = manifest.get("reward_verification", {})
596
+ card = f"""---
597
+ license: apache-2.0
598
+ task_categories:
599
+ - text-generation
600
+ language:
601
+ - en
602
+ tags:
603
+ - cybersecurity
604
+ - owasp
605
+ - openenv
606
+ - tool-use
607
+ - sft
608
+ pretty_name: CyberSecurity_OWASP SFT Dataset
609
+ ---
610
+
611
+ # CyberSecurity_OWASP SFT Dataset
612
+
613
+ This dataset contains verifier-gated supervised fine-tuning examples for the
614
+ `CyberSecurity_OWASP` OpenEnv environment. Each row teaches one step of the
615
+ defensive local AppSec workflow: inspect policy/code, reproduce a local
616
+ authorization failure, submit a policy-tied diagnosis, patch the generated app,
617
+ run visible tests, and submit the fix.
618
+
619
+ Every kept trajectory is executed against the real local environment and must
620
+ pass the deterministic reward verifier before rows are written.
621
+
622
+ ## Intended Use
623
+
624
+ - Target SFT model: `{manifest.get("target_model", "")}`
625
+ - Teacher model: `{manifest.get("teacher_model", "")}`
626
+ - Dataset repo: `{dataset_repo_id}`
627
+ - Format: chat JSONL with `messages` and verifier metadata
628
+ - Dry-run oracle: `{manifest.get("dry_run_oracle", False)}`
629
+
630
+ ## Curriculum Coverage
631
+
632
+ - Difficulty levels: `{difficulty_levels}`
633
+ - Episodes attempted: `{manifest.get("episodes_attempted", 0)}`
634
+ - Episodes accepted: `{manifest.get("episodes_accepted", 0)}`
635
+ - Acceptance rate: `{manifest.get("acceptance_rate", 0.0):.4f}`
636
+ - Rows by split: `{json.dumps(manifest.get("rows_by_split", {}), sort_keys=True)}`
637
+ - Rows by difficulty: `{json.dumps(manifest.get("rows_by_difficulty", {}), sort_keys=True)}`
638
+
639
+ ## Reward Verification
640
+
641
+ - Passed: `{reward_verification.get("passed", False)}`
642
+ - Checked rows: `{reward_verification.get("checked_rows", 0)}`
643
+ - Minimum terminal reward: `{reward_verification.get("min_terminal_reward", 0.0)}`
644
+ - Reward summary: `{json.dumps(reward_verification.get("reward_summary", {}), sort_keys=True)}`
645
+
646
+ Rows are rejected if the episode fails hidden security/regression/public-route
647
+ checks, triggers anti-cheat flags, lacks a positive patch-quality reward, or
648
+ falls below the configured terminal reward threshold.
649
+
650
+ ## Schema
651
+
652
+ Each JSONL row has:
653
+
654
+ ```json
655
+ {{
656
+ "messages": [
657
+ {{"role": "system", "content": "..."}},
658
+ {{"role": "user", "content": "..."}},
659
+ {{"role": "assistant", "content": "{{\\"tool_name\\":\\"...\\",\\"arguments\\":{{...}}}}"}}
660
+ ],
661
+ "metadata": {{
662
+ "target_model": "...",
663
+ "teacher_model": "...",
664
+ "seed": 0,
665
+ "split": "train",
666
+ "difficulty": 0,
667
+ "step": 1,
668
+ "tool_name": "inspect_policy_graph",
669
+ "final_success": true,
670
+ "terminal_total": 12.5,
671
+ "anti_cheat_flags": []
672
+ }}
673
+ }}
674
+ ```
675
+ """
676
+ card_path.write_text(card, encoding="utf-8")
677
+ return card_path
678
+
679
+
680
+ def push_dataset_to_hub(out_dir: Path, *, repo_id: str, private: bool) -> dict[str, Any]:
681
+ token = os.getenv("HF_TOKEN")
682
+ if not token:
683
+ raise RuntimeError("HF_TOKEN is required for --push-to-hub")
684
+ try:
685
+ from huggingface_hub import HfApi
686
+ except ImportError as exc: # pragma: no cover
687
+ raise RuntimeError("huggingface_hub is required for --push-to-hub") from exc
688
+
689
+ api = HfApi(token=token)
690
+ api.create_repo(repo_id=repo_id, repo_type="dataset", private=private, exist_ok=True)
691
+ commit_info = api.upload_folder(
692
+ repo_id=repo_id,
693
+ repo_type="dataset",
694
+ folder_path=str(out_dir),
695
+ path_in_repo=".",
696
+ commit_message="Upload verified CyberSecurity_OWASP SFT dataset",
697
+ delete_patterns=[
698
+ "README.md",
699
+ "manifest.json",
700
+ "train.jsonl",
701
+ "validation.jsonl",
702
+ "hidden_eval.jsonl",
703
+ "trajectories/**",
704
+ ],
705
+ )
706
+ return {
707
+ "repo_id": repo_id,
708
+ "private": bool(private),
709
+ "url": f"https://huggingface.co/datasets/{repo_id}",
710
+ "commit_url": getattr(commit_info, "commit_url", ""),
711
+ }
712
+
713
+
714
+ def push_existing_dataset(
715
+ out_dir: Path,
716
+ *,
717
+ repo_id: str,
718
+ private: bool,
719
+ min_terminal_reward: float,
720
+ required_difficulties: tuple[int, ...],
721
+ ) -> dict[str, Any]:
722
+ verification = verify_sft_dataset_rewards(
723
+ out_dir,
724
+ min_terminal_reward=min_terminal_reward,
725
+ require_train_rows=True,
726
+ required_difficulties=required_difficulties,
727
+ )
728
+ if not verification["passed"]:
729
+ raise RuntimeError(f"Reward verification failed; refusing Hub push: {verification}")
730
+ manifest_path = out_dir / "manifest.json"
731
+ if manifest_path.exists():
732
+ manifest = json.loads(manifest_path.read_text(encoding="utf-8"))
733
+ else:
734
+ manifest = {
735
+ "teacher_model": DEFAULT_TEACHER_MODEL,
736
+ "target_model": DEFAULT_TARGET_MODEL,
737
+ "difficulty_levels": [int(level) for level in required_difficulties],
738
+ "rows_by_split": verification.get("rows_by_split", {}),
739
+ }
740
+ manifest["reward_verification"] = verification
741
+ manifest["hub"] = {
742
+ "repo_id": repo_id,
743
+ "private": bool(private),
744
+ "url": f"https://huggingface.co/datasets/{repo_id}",
745
+ }
746
+ write_dataset_card(out_dir, manifest, repo_id)
747
+ manifest_path.write_text(
748
+ json.dumps(manifest, indent=2, sort_keys=True, default=str),
749
+ encoding="utf-8",
750
+ )
751
+ hub_result = push_dataset_to_hub(out_dir, repo_id=repo_id, private=private)
752
+ manifest["hub"].update(hub_result)
753
+ manifest_path.write_text(
754
+ json.dumps(manifest, indent=2, sort_keys=True, default=str),
755
+ encoding="utf-8",
756
+ )
757
+ return {"reward_verification": verification, "hub": manifest["hub"]}
758
+
759
+
760
  def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
761
  traj_dir = out_dir / "trajectories"
762
  traj_dir.mkdir(parents=True, exist_ok=True)
 
800
  }
801
 
802
 
803
+ def _parse_int_csv(value: str) -> tuple[int, ...]:
804
+ if not value.strip():
805
+ return ()
806
+ levels = []
807
+ for item in value.split(","):
808
+ stripped = item.strip()
809
+ if not stripped:
810
+ continue
811
+ levels.append(int(stripped))
812
+ return tuple(dict.fromkeys(levels))
813
+
814
+
815
+ def _difficulty_levels(config: DatasetConfig) -> tuple[int, ...]:
816
+ if config.difficulty_levels:
817
+ return tuple(int(level) for level in config.difficulty_levels)
818
+ return (int(config.difficulty),)
819
+
820
+
821
+ def _configure_difficulty_buckets(config: DatasetConfig, levels: tuple[int, ...]) -> int:
822
+ requested = max(levels) + 1 if levels else int(config.difficulty) + 1
823
+ configured = max(int(config.difficulty_buckets or 0), requested, 1)
824
+ existing = os.getenv("CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS")
825
+ if existing:
826
+ configured = max(configured, int(existing))
827
+ os.environ["CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS"] = str(configured)
828
+ return configured
829
+
830
+
831
+ def _read_jsonl(path: Path) -> list[dict[str, Any]]:
832
+ if not path.exists():
833
+ return []
834
+ rows: list[dict[str, Any]] = []
835
+ for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
836
+ if not line.strip():
837
+ continue
838
+ try:
839
+ item = json.loads(line)
840
+ except json.JSONDecodeError as exc:
841
+ raise ValueError(f"{path}:{line_number}: invalid JSONL row: {exc}") from exc
842
+ if not isinstance(item, dict):
843
+ raise ValueError(f"{path}:{line_number}: row must be a JSON object")
844
+ rows.append(item)
845
+ return rows
846
+
847
+
848
+ def _verify_sft_row_reward(
849
+ row: dict[str, Any],
850
+ *,
851
+ min_terminal_reward: float,
852
+ path: Path,
853
+ line_number: int,
854
+ ) -> tuple[bool, str, float]:
855
+ messages = row.get("messages")
856
+ if not isinstance(messages, list) or len(messages) < 3:
857
+ return False, f"{path}:{line_number}: messages must include system/user/assistant", 0.0
858
+ if messages[-1].get("role") != "assistant":
859
+ return False, f"{path}:{line_number}: final message must be assistant", 0.0
860
+ try:
861
+ CyberSecurityOWASPAction(**json.loads(str(messages[-1].get("content", ""))))
862
+ except Exception as exc:
863
+ return False, f"{path}:{line_number}: assistant content is not a valid action: {exc}", 0.0
864
+ metadata = row.get("metadata")
865
+ if not isinstance(metadata, dict):
866
+ return False, f"{path}:{line_number}: missing metadata object", 0.0
867
+ if metadata.get("final_success") is not True:
868
+ return False, f"{path}:{line_number}: final_success is not true", 0.0
869
+ flags = metadata.get("anti_cheat_flags") or []
870
+ if flags:
871
+ return False, f"{path}:{line_number}: anti-cheat flags present: {flags}", 0.0
872
+ reward = float(metadata.get("terminal_total", 0.0) or 0.0)
873
+ if reward < min_terminal_reward:
874
+ return (
875
+ False,
876
+ f"{path}:{line_number}: terminal_total {reward:.3f} below required {min_terminal_reward:.3f}",
877
+ reward,
878
+ )
879
+ breakdown = metadata.get("final_reward_breakdown") or {}
880
+ if not isinstance(breakdown, dict):
881
+ return False, f"{path}:{line_number}: missing final_reward_breakdown", reward
882
+ required_positive = ("security", "regression", "public_routes", "patch_quality", "visible_tests")
883
+ missing = [key for key in required_positive if float(breakdown.get(key, 0.0) or 0.0) <= 0.0]
884
+ if missing:
885
+ return False, f"{path}:{line_number}: non-positive reward components: {', '.join(missing)}", reward
886
+ return True, "", reward
887
+
888
+
889
+ def verify_sft_dataset_rewards(
890
+ out_dir: Path,
891
+ *,
892
+ min_terminal_reward: float = 12.0,
893
+ require_train_rows: bool = True,
894
+ required_difficulties: tuple[int, ...] = (),
895
+ ) -> dict[str, Any]:
896
+ """Verify generated SFT rows carry successful verifier-backed rewards."""
897
+
898
+ checked_rows = 0
899
+ failed_rows: list[str] = []
900
+ rewards: list[float] = []
901
+ rows_by_split: dict[str, int] = {}
902
+ rows_by_difficulty: dict[str, int] = {}
903
+ for split_name in ("train", "validation", "hidden_eval"):
904
+ path = out_dir / f"{split_name}.jsonl"
905
+ rows = _read_jsonl(path)
906
+ if not rows and split_name != "train":
907
+ continue
908
+ rows_by_split[split_name] = len(rows)
909
+ for index, row in enumerate(rows, start=1):
910
+ ok, error, reward = _verify_sft_row_reward(
911
+ row,
912
+ min_terminal_reward=min_terminal_reward,
913
+ path=path,
914
+ line_number=index,
915
+ )
916
+ checked_rows += 1
917
+ if reward:
918
+ rewards.append(reward)
919
+ if not ok:
920
+ failed_rows.append(error)
921
+ metadata = row.get("metadata") if isinstance(row, dict) else {}
922
+ if isinstance(metadata, dict) and "difficulty" in metadata:
923
+ difficulty_key = str(int(metadata.get("difficulty", 0)))
924
+ rows_by_difficulty[difficulty_key] = rows_by_difficulty.get(difficulty_key, 0) + 1
925
+ passed = not failed_rows and (checked_rows > 0 or not require_train_rows)
926
+ if require_train_rows and rows_by_split.get("train", 0) <= 0:
927
+ passed = False
928
+ failed_rows.append(f"{out_dir / 'train.jsonl'}: no train rows found")
929
+ missing_difficulties = [
930
+ int(level)
931
+ for level in required_difficulties
932
+ if rows_by_difficulty.get(str(int(level)), 0) <= 0
933
+ ]
934
+ if missing_difficulties:
935
+ passed = False
936
+ failed_rows.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
937
+ return {
938
+ "passed": passed,
939
+ "checked_rows": checked_rows,
940
+ "failed_rows": failed_rows[:50],
941
+ "failure_count": len(failed_rows),
942
+ "rows_by_split": rows_by_split,
943
+ "rows_by_difficulty": rows_by_difficulty,
944
+ "required_difficulties": [int(level) for level in required_difficulties],
945
+ "missing_difficulties": missing_difficulties,
946
+ "min_terminal_reward": float(min_terminal_reward),
947
+ "reward_summary": _reward_summary(rewards),
948
+ }
949
+
950
+
951
+ def _resolved_worker_count(config: DatasetConfig, job_count: int) -> int:
952
+ if job_count <= 1:
953
+ return 1
954
+ if int(config.workers) > 0:
955
+ return max(1, min(int(config.workers), job_count))
956
+ cpu_count = os.cpu_count() or 4
957
+ return max(1, min(8, cpu_count, job_count))
958
+
959
+
960
  def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
961
  config.out_dir.mkdir(parents=True, exist_ok=True)
962
+ teacher_local = threading.local()
963
+ teacher_token = None
964
  if not config.dry_run_oracle:
965
+ teacher_token = os.getenv("HF_TOKEN")
966
+ if not teacher_token:
967
  raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
968
 
969
+ def teacher_for_thread() -> HuggingFaceTeacher | None:
970
+ if config.dry_run_oracle:
971
+ return None
972
+ teacher = getattr(teacher_local, "teacher", None)
973
+ if teacher is None:
974
+ teacher = HuggingFaceTeacher(
975
+ model=config.teacher_model,
976
+ token=str(teacher_token),
977
+ max_tokens=config.max_tokens,
978
+ temperature=config.temperature,
979
+ top_p=config.top_p,
980
+ )
981
+ teacher_local.teacher = teacher
982
+ return teacher
983
+
984
+ difficulty_levels = _difficulty_levels(config)
985
+ difficulty_bucket_count = _configure_difficulty_buckets(config, difficulty_levels)
986
+ validation_seed_start = config.seed_start + int(config.episodes) * len(difficulty_levels)
987
  split_jobs = [(config.split, config.episodes, config.seed_start)]
988
  if config.validation_episodes:
989
+ split_jobs.append(("validation", config.validation_episodes, validation_seed_start))
990
+ episode_jobs = [
991
+ {
992
+ "order": job_order,
993
+ "split": split,
994
+ "difficulty": int(difficulty),
995
+ "seed": int(seed_start) + difficulty_index * int(episodes) + offset,
996
+ }
997
+ for job_order, (split, episodes, seed_start) in enumerate(split_jobs)
998
+ for difficulty_index, difficulty in enumerate(difficulty_levels)
999
+ for offset in range(int(episodes))
1000
+ ]
1001
 
1002
  rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
1003
  attempts: list[dict[str, Any]] = []
1004
  rewards: list[float] = []
1005
  accepted = 0
1006
+ attempted = len(episode_jobs)
1007
+ workers = _resolved_worker_count(config, attempted)
1008
+
1009
+ def run_job(job: dict[str, Any]) -> dict[str, Any]:
1010
+ seed = int(job["seed"])
1011
+ split = str(job["split"])
1012
+ difficulty = int(job["difficulty"])
1013
+ return {
1014
+ "order": int(job["order"]),
1015
+ **run_episode(
1016
  seed=seed,
1017
  split=split,
1018
+ difficulty=difficulty,
1019
  config=config,
1020
+ teacher=teacher_for_thread(),
1021
+ ),
1022
+ }
1023
+
1024
+ results: list[dict[str, Any]] = []
1025
+ with ThreadPoolExecutor(max_workers=workers, thread_name_prefix="sft-episode") as executor:
1026
+ futures = [executor.submit(run_job, job) for job in episode_jobs]
1027
+ for future in as_completed(futures):
1028
+ result = future.result()
1029
+ results.append(result)
1030
+ if config.progress:
1031
+ print(
1032
+ json.dumps(
1033
+ {
1034
+ "event": "episode_done",
1035
+ "accepted": bool(result.get("accepted")),
1036
+ "split": result.get("split"),
1037
+ "difficulty": result.get("difficulty"),
1038
+ "seed": result.get("seed"),
1039
+ "reason": result.get("reason", ""),
1040
+ },
1041
+ sort_keys=True,
1042
+ ),
1043
+ flush=True,
1044
+ )
1045
+
1046
+ for result in sorted(
1047
+ results,
1048
+ key=lambda item: (
1049
+ str(item.get("split", "")),
1050
+ int(item.get("difficulty", 0)),
1051
+ int(item.get("seed", 0)),
1052
+ ),
1053
+ ):
1054
+ seed = int(result["seed"])
1055
+ split = str(result["split"])
1056
+ difficulty = int(result["difficulty"])
1057
+ attempts.append(
1058
+ {
1059
+ "seed": seed,
1060
+ "split": split,
1061
+ "difficulty": difficulty,
1062
+ "accepted": bool(result["accepted"]),
1063
+ "reason": result.get("reason", ""),
1064
+ "trajectory_path": str(_write_trajectory(config.out_dir, result["trajectory"])),
1065
+ }
1066
+ )
1067
+ if result["accepted"]:
1068
+ accepted += 1
1069
+ rows = list(result["rows"])
1070
+ rows_by_split.setdefault(split, []).extend(rows)
1071
+ rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))
1072
+
1073
+ for split_rows in rows_by_split.values():
1074
+ split_rows.sort(
1075
+ key=lambda row: (
1076
+ int((row.get("metadata") or {}).get("difficulty", 0)),
1077
+ int((row.get("metadata") or {}).get("seed", 0)),
1078
+ int((row.get("metadata") or {}).get("step", 0)),
1079
  )
1080
+ )
1081
 
1082
  for split_name in ("train", "validation", config.split):
1083
  write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))
1084
 
1085
+ reward_verification = verify_sft_dataset_rewards(
1086
+ config.out_dir,
1087
+ min_terminal_reward=config.min_terminal_reward,
1088
+ require_train_rows=config.split == "train",
1089
+ required_difficulties=difficulty_levels if len(difficulty_levels) > 1 else (),
1090
+ )
1091
+
1092
+ accepted_by_difficulty: dict[str, int] = {}
1093
+ attempted_by_difficulty: dict[str, int] = {}
1094
+ reward_by_difficulty: dict[str, list[float]] = {}
1095
+ row_count_by_difficulty: dict[str, int] = {}
1096
+ for result in results:
1097
+ difficulty_key = str(int(result.get("difficulty", 0)))
1098
+ attempted_by_difficulty[difficulty_key] = attempted_by_difficulty.get(difficulty_key, 0) + 1
1099
+ if result.get("accepted"):
1100
+ accepted_by_difficulty[difficulty_key] = accepted_by_difficulty.get(difficulty_key, 0) + 1
1101
+ reward_by_difficulty.setdefault(difficulty_key, []).append(
1102
+ float((result.get("trajectory") or {}).get("terminal_total", 0.0))
1103
+ )
1104
+ for split_rows in rows_by_split.values():
1105
+ for row in split_rows:
1106
+ difficulty_key = str(int((row.get("metadata") or {}).get("difficulty", 0)))
1107
+ row_count_by_difficulty[difficulty_key] = row_count_by_difficulty.get(difficulty_key, 0) + 1
1108
+
1109
  manifest = {
1110
  "teacher_model": config.teacher_model,
1111
  "target_model": config.target_model,
1112
  "split": config.split,
1113
  "difficulty": config.difficulty,
1114
+ "difficulty_levels": [int(level) for level in difficulty_levels],
1115
+ "difficulty_bucket_count": int(difficulty_bucket_count),
1116
+ "episodes_per_difficulty": config.episodes,
1117
+ "validation_episodes_per_difficulty": config.validation_episodes,
1118
  "seed_start": config.seed_start,
1119
  "episodes_attempted": attempted,
1120
  "episodes_accepted": accepted,
1121
  "acceptance_rate": accepted / attempted if attempted else 0.0,
1122
+ "attempted_by_difficulty": attempted_by_difficulty,
1123
+ "accepted_by_difficulty": accepted_by_difficulty,
1124
+ "rows_by_difficulty": row_count_by_difficulty,
1125
+ "reward_summary_by_difficulty": {
1126
+ key: _reward_summary(value) for key, value in sorted(reward_by_difficulty.items())
1127
+ },
1128
+ "workers": workers,
1129
  "rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
1130
  "reward_summary": _reward_summary(rewards),
1131
+ "reward_verification": reward_verification,
1132
  "git_sha": _git_sha(),
1133
  "verifier_version": "verifier_v1",
1134
  "dry_run_oracle": config.dry_run_oracle,
1135
  "attempts": attempts,
1136
  }
1137
+ if config.push_to_hub:
1138
+ if not reward_verification["passed"]:
1139
+ raise RuntimeError("Reward verification failed; refusing to push dataset to Hub.")
1140
+ manifest["hub"] = {
1141
+ "repo_id": config.dataset_repo_id,
1142
+ "private": bool(config.hub_private),
1143
+ "url": f"https://huggingface.co/datasets/{config.dataset_repo_id}",
1144
+ }
1145
+ write_dataset_card(config.out_dir, manifest, config.dataset_repo_id)
1146
  manifest_path = config.out_dir / "manifest.json"
1147
  manifest_path.write_text(
1148
  json.dumps(manifest, indent=2, sort_keys=True, default=str),
1149
  encoding="utf-8",
1150
  )
1151
+ if config.push_to_hub:
1152
+ hub_result = push_dataset_to_hub(
1153
+ config.out_dir,
1154
+ repo_id=config.dataset_repo_id,
1155
+ private=config.hub_private,
1156
+ )
1157
+ manifest["hub"].update(hub_result)
1158
+ manifest_path.write_text(
1159
+ json.dumps(manifest, indent=2, sort_keys=True, default=str),
1160
+ encoding="utf-8",
1161
+ )
1162
  return manifest
1163
 
1164
 
 
1168
  parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
1169
  parser.add_argument("--split", default="train", choices=["train", "validation", "hidden_eval"])
1170
  parser.add_argument("--difficulty", type=int, default=0)
1171
+ parser.add_argument(
1172
+ "--difficulty-levels",
1173
+ default="",
1174
+ help="Comma-separated curriculum levels to include, for example 0,1,2,3. "
1175
+ "When set, --episodes is per difficulty level.",
1176
+ )
1177
+ parser.add_argument(
1178
+ "--difficulty-buckets",
1179
+ type=int,
1180
+ default=0,
1181
+ help=(
1182
+ "Number of curriculum difficulty buckets to expose to the environment. "
1183
+ "Defaults to max(--difficulty-levels)+1."
1184
+ ),
1185
+ )
1186
  parser.add_argument("--seed-start", type=int, default=0)
1187
  parser.add_argument("--episodes", type=int, default=100)
1188
  parser.add_argument("--validation-episodes", type=int, default=0)
 
1192
  parser.add_argument("--max-tokens", type=int, default=768)
1193
  parser.add_argument("--temperature", type=float, default=0.2)
1194
  parser.add_argument("--top-p", type=float, default=0.95)
1195
+ parser.add_argument(
1196
+ "--workers",
1197
+ type=int,
1198
+ default=0,
1199
+ help="Parallel episode workers. 0 auto-selects up to 8 workers.",
1200
+ )
1201
+ parser.add_argument(
1202
+ "--min-terminal-reward",
1203
+ type=float,
1204
+ default=12.0,
1205
+ help="Minimum verifier-backed terminal reward required for SFT rows.",
1206
+ )
1207
+ parser.add_argument(
1208
+ "--verify-only",
1209
+ action="store_true",
1210
+ help="Only verify an existing out-dir dataset reward metadata.",
1211
+ )
1212
+ parser.add_argument(
1213
+ "--push-to-hub",
1214
+ action="store_true",
1215
+ help="Upload the verified dataset folder to a Hugging Face dataset repo.",
1216
+ )
1217
+ parser.add_argument(
1218
+ "--progress",
1219
+ action="store_true",
1220
+ help="Print one JSON progress event for each completed episode job.",
1221
+ )
1222
+ parser.add_argument(
1223
+ "--push-only",
1224
+ action="store_true",
1225
+ help="Verify and upload an existing out-dir dataset without regenerating rows.",
1226
+ )
1227
+ parser.add_argument(
1228
+ "--dataset-repo-id",
1229
+ default="Humanlearning/CyberSecurity_OWASP-sft-dataset",
1230
+ help="Hugging Face dataset repo id used with --push-to-hub.",
1231
+ )
1232
+ parser.add_argument(
1233
+ "--hub-private",
1234
+ action="store_true",
1235
+ help="Create/upload the Hugging Face dataset repo as private.",
1236
+ )
1237
  parser.add_argument(
1238
  "--dry-run-oracle",
1239
  action="store_true",
 
1248
  target_model=args.target_model,
1249
  split=args.split,
1250
  difficulty=args.difficulty,
1251
+ difficulty_levels=_parse_int_csv(args.difficulty_levels),
1252
+ difficulty_buckets=args.difficulty_buckets,
1253
  seed_start=args.seed_start,
1254
  episodes=args.episodes,
1255
  validation_episodes=args.validation_episodes,
 
1260
  temperature=args.temperature,
1261
  top_p=args.top_p,
1262
  dry_run_oracle=args.dry_run_oracle,
1263
+ workers=args.workers,
1264
+ min_terminal_reward=args.min_terminal_reward,
1265
+ push_to_hub=args.push_to_hub,
1266
+ dataset_repo_id=args.dataset_repo_id,
1267
+ hub_private=args.hub_private,
1268
+ progress=args.progress,
1269
  )
1270
 
1271
 
1272
  def main(argv: list[str] | None = None) -> int:
1273
  parser = build_arg_parser()
1274
  args = parser.parse_args(argv)
1275
+ try:
1276
+ if args.verify_only:
1277
+ verification = verify_sft_dataset_rewards(
1278
+ args.out_dir,
1279
+ min_terminal_reward=args.min_terminal_reward,
1280
+ require_train_rows=args.split == "train",
1281
+ required_difficulties=_parse_int_csv(args.difficulty_levels),
1282
+ )
1283
+ print(json.dumps({"reward_verification": verification}, indent=2, sort_keys=True))
1284
+ return 0 if verification["passed"] else 2
1285
+ if args.push_only:
1286
+ result = push_existing_dataset(
1287
+ args.out_dir,
1288
+ repo_id=args.dataset_repo_id,
1289
+ private=args.hub_private,
1290
+ min_terminal_reward=args.min_terminal_reward,
1291
+ required_difficulties=_parse_int_csv(args.difficulty_levels),
1292
+ )
1293
+ print(json.dumps(result, indent=2, sort_keys=True))
1294
+ return 0
1295
+ manifest = generate_dataset(config_from_args(args))
1296
+ print(json.dumps(manifest, indent=2, sort_keys=True))
1297
+ return 0 if manifest.get("reward_verification", {}).get("passed") else 2
1298
+ except (RuntimeError, ValueError) as exc:
1299
+ print(
1300
+ json.dumps(
1301
+ {"error": str(exc), "error_type": exc.__class__.__name__},
1302
+ indent=2,
1303
+ sort_keys=True,
1304
+ )
1305
+ )
1306
+ return 2
1307
 
1308
 
1309
  if __name__ == "__main__":
scripts/modal_train_sft.py CHANGED
@@ -33,6 +33,15 @@ TRITON_CACHE_DIR = CACHE_DIR / "triton"
33
  REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
34
  PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
35
  DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
36
  PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
37
  PUBLIC_REPO_BRANCH = "master"
38
 
@@ -50,6 +59,170 @@ def _model_repo_slug(model_name: str) -> str:
50
  return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
51
 
52
 
53
  def _configure_modal_cache_env() -> dict[str, str]:
54
  values = {
55
  "HF_HOME": str(HF_HOME_DIR),
@@ -171,7 +344,7 @@ def upload_sft_jsonl(relative_path: str, content: str) -> str:
171
 
172
  @app.function(
173
  image=training_image,
174
- gpu="L4",
175
  timeout=12 * 60 * 60,
176
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
177
  secrets=secrets,
@@ -179,24 +352,29 @@ def upload_sft_jsonl(relative_path: str, content: str) -> str:
179
  def train_cybersecurity_owasp_sft(
180
  train_jsonl: str = "/runs/sft/train.jsonl",
181
  validation_jsonl: str = "/runs/sft/validation.jsonl",
182
- output_repo_id: str = "",
183
  model_name: str = DEFAULT_GEMMA_MODEL,
184
  run_name: str = "",
185
  max_seq_length: int = 4096,
186
- max_steps: int = 100,
187
  num_train_epochs: float = 1.0,
188
- per_device_train_batch_size: int = 1,
189
- gradient_accumulation_steps: int = 16,
190
  learning_rate: float = 2e-5,
191
  lora_rank: int = 32,
192
- trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
193
- trackio_project: str = "CyberSecurity_OWASP-sft",
 
  push_to_hub: bool = False,
195
  ) -> dict[str, Any]:
196
  import inspect
197
 
198
  from datasets import load_dataset
199
- from huggingface_hub import snapshot_download, whoami
200
  from trl import SFTConfig, SFTTrainer
201
  from trl.chat_template_utils import add_response_schema
202
  from unsloth import FastVisionModel
@@ -207,10 +385,9 @@ def train_cybersecurity_owasp_sft(
207
  if not hf_token:
208
  raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
209
 
210
- user = whoami(token=hf_token)["name"]
211
- output_repo_id = output_repo_id or (
212
- f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-lora"
213
- )
214
  stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
215
  run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
216
  output_dir = RUNS_DIR / run_name
@@ -222,6 +399,21 @@ def train_cybersecurity_owasp_sft(
222
  has_validation = validation_path.exists() and validation_path.stat().st_size > 0
223
  if has_validation:
224
  data_files["validation"] = validation_jsonl
225
  dataset = load_dataset("json", data_files=data_files)
226
 
227
  print(f"SFT run name: {run_name}")
@@ -232,6 +424,11 @@ def train_cybersecurity_owasp_sft(
232
  print(f"Output repo: https://huggingface.co/{output_repo_id}")
233
  print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
234
  print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
235
 
236
  try:
237
  snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
@@ -280,19 +477,25 @@ def train_cybersecurity_owasp_sft(
280
  "per_device_train_batch_size": per_device_train_batch_size,
281
  "gradient_accumulation_steps": gradient_accumulation_steps,
282
  "learning_rate": learning_rate,
283
  "logging_steps": 1,
284
- "save_steps": max(10, max_steps),
285
  "report_to": "trackio",
286
  "project": trackio_project,
287
  "trackio_space_id": trackio_space_id,
288
  "run_name": run_name,
289
  "assistant_only_loss": True,
290
- "packing": False,
291
  "gradient_checkpointing": True,
292
  "gradient_checkpointing_kwargs": {"use_reentrant": False},
293
  "push_to_hub": push_to_hub,
294
  "hub_model_id": output_repo_id,
295
  "hub_private_repo": True,
296
  }
297
  sft_parameters = set(inspect.signature(SFTConfig).parameters)
298
  skipped = sorted(set(sft_values) - sft_parameters)
@@ -335,6 +538,11 @@ def train_cybersecurity_owasp_sft(
335
  "output_repo_id": output_repo_id,
336
  "train_jsonl": train_jsonl,
337
  "validation_jsonl": validation_jsonl if has_validation else "",
338
  "max_steps": max_steps,
339
  "push_to_hub": push_to_hub,
340
  "trackio_space_id": trackio_space_id,
@@ -365,20 +573,26 @@ def main(
365
  mode: str = "train",
366
  local_train_path: str = "outputs/sft/train.jsonl",
367
  local_validation_path: str = "outputs/sft/validation.jsonl",
368
  train_jsonl: str = "/runs/sft/train.jsonl",
369
  validation_jsonl: str = "/runs/sft/validation.jsonl",
370
- output_repo_id: str = "",
371
  model_name: str = DEFAULT_GEMMA_MODEL,
372
  run_name: str = "",
373
  max_seq_length: int = 4096,
374
- max_steps: int = 100,
375
  num_train_epochs: float = 1.0,
376
- per_device_train_batch_size: int = 1,
377
- gradient_accumulation_steps: int = 16,
378
  learning_rate: float = 2e-5,
379
  lora_rank: int = 32,
380
- trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
381
- trackio_project: str = "CyberSecurity_OWASP-sft",
 
382
  source_mode: str = "local",
383
  repo_url: str = PUBLIC_REPO_URL,
384
  repo_branch: str = PUBLIC_REPO_BRANCH,
@@ -392,6 +606,22 @@ def main(
392
 
393
  local_train = pathlib.Path(local_train_path)
394
  local_validation = pathlib.Path(local_validation_path)
395
  if local_train.exists():
396
  uploaded = upload_sft_jsonl.remote(
397
  "sft/train.jsonl",
@@ -406,6 +636,13 @@ def main(
406
  )
407
  print(f"Uploaded validation JSONL: {uploaded_validation}")
408
  validation_jsonl = uploaded_validation
409
  if mode == "upload":
410
  return
411
 
@@ -416,6 +653,7 @@ def main(
416
  kwargs = dict(
417
  train_jsonl=train_jsonl,
418
  validation_jsonl=validation_jsonl,
419
  output_repo_id=output_repo_id,
420
  model_name=model_name,
421
  run_name=run_name,
@@ -428,11 +666,19 @@ def main(
428
  lora_rank=lora_rank,
429
  trackio_space_id=trackio_space_id,
430
  trackio_project=trackio_project,
431
  push_to_hub=push_to_hub,
432
  )
433
  print(f"SFT run name: {run_name}")
434
  print(f"Train JSONL: {train_jsonl}")
435
  print(f"Validation JSONL: {validation_jsonl}")
436
  print(f"Hub push enabled: {push_to_hub}")
437
  if detach:
438
  call = train_cybersecurity_owasp_sft.spawn(**kwargs)
 
33
  REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
34
  PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
35
  DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
36
+ SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]
37
+ DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"
38
+ DEFAULT_TOTAL_TRAIN_EPISODES = 300
39
+ DEFAULT_EPISODES_PER_LEVEL = 75
40
+ DEFAULT_TRACKIO_SPACE_ID = "Humanlearning/CyberSecurity_OWASP-trackio"
41
+ DEFAULT_TRACKIO_PROJECT = "CyberSecurity_OWASP-sft"
42
+ DEFAULT_SFT_OUTPUT_REPO_ID = (
43
+ "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"
44
+ )
45
  PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
46
  PUBLIC_REPO_BRANCH = "master"
47
 
 
59
  return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
60
 
61
 
62
+ def _parse_int_csv(value: str) -> set[int]:
63
+ if not value.strip():
64
+ return set()
65
+ return {int(item.strip()) for item in value.split(",") if item.strip()}
66
+
67
+
68
+ SFT_ALLOWED_TOOLS = {
69
+ "inspect_policy_graph",
70
+ "list_routes",
71
+ "read_openapi",
72
+ "read_file",
73
+ "search_code",
74
+ "send_local_request",
75
+ "compare_identities",
76
+ "submit_diagnosis",
77
+ "patch_file",
78
+ "run_visible_tests",
79
+ "submit_fix",
80
+ "noop",
81
+ }
82
+
83
+
84
+ def _read_jsonl(path: pathlib.Path) -> list[dict[str, Any]]:
85
+ if not path.exists():
86
+ return []
87
+ rows: list[dict[str, Any]] = []
88
+ for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
89
+ if not line.strip():
90
+ continue
91
+ try:
92
+ row = json.loads(line)
93
+ except json.JSONDecodeError as exc:
94
+ raise ValueError(f"{path}:{line_number}: invalid JSONL: {exc}") from exc
95
+ if not isinstance(row, dict):
96
+ raise ValueError(f"{path}:{line_number}: row must be a JSON object")
97
+ rows.append(row)
98
+ return rows
99
+
100
+
101
+ def _verify_sft_rows(
102
+ path: pathlib.Path,
103
+ *,
104
+ min_terminal_reward: float,
105
+ ) -> tuple[list[str], list[float], int, set[int]]:
106
+ rows = _read_jsonl(path)
107
+ failures: list[str] = []
108
+ rewards: list[float] = []
109
+ difficulties: set[int] = set()
110
+ for index, row in enumerate(rows, start=1):
111
+ messages = row.get("messages")
112
+ if not isinstance(messages, list) or len(messages) < 3:
113
+ failures.append(f"{path}:{index}: messages must include system/user/assistant")
114
+ continue
115
+ assistant = messages[-1]
116
+ if assistant.get("role") != "assistant":
117
+ failures.append(f"{path}:{index}: final message must be assistant")
118
+ continue
119
+ try:
120
+ action = json.loads(str(assistant.get("content", "")))
121
+ except json.JSONDecodeError as exc:
122
+ failures.append(f"{path}:{index}: assistant content is not JSON: {exc}")
123
+ continue
124
+ if not isinstance(action, dict) or action.get("tool_name") not in SFT_ALLOWED_TOOLS:
125
+ failures.append(f"{path}:{index}: assistant content is not a valid tool action")
126
+ continue
127
+ metadata = row.get("metadata")
128
+ if not isinstance(metadata, dict):
129
+ failures.append(f"{path}:{index}: missing metadata")
130
+ continue
131
+ if metadata.get("final_success") is not True:
132
+ failures.append(f"{path}:{index}: final_success is not true")
133
+ continue
134
+ if metadata.get("anti_cheat_flags") or []:
135
+ failures.append(f"{path}:{index}: anti-cheat flags present")
136
+ continue
137
+ if "difficulty" in metadata:
138
+ difficulties.add(int(metadata.get("difficulty", 0)))
139
+ terminal_reward = float(metadata.get("terminal_total", 0.0) or 0.0)
140
+ rewards.append(terminal_reward)
141
+ if terminal_reward < min_terminal_reward:
142
+ failures.append(
143
+ f"{path}:{index}: terminal_total {terminal_reward:.3f} below {min_terminal_reward:.3f}"
144
+ )
145
+ continue
146
+ breakdown = metadata.get("final_reward_breakdown") or {}
147
+ if not isinstance(breakdown, dict):
148
+ failures.append(f"{path}:{index}: missing final_reward_breakdown")
149
+ continue
150
+ for key in ("security", "regression", "public_routes", "patch_quality", "visible_tests"):
151
+ if float(breakdown.get(key, 0.0) or 0.0) <= 0.0:
152
+ failures.append(f"{path}:{index}: reward component {key} is not positive")
153
+ break
154
+ return failures, rewards, len(rows), difficulties
155
+
156
+
157
+ def verify_sft_inputs(
158
+ *,
159
+ train_jsonl: str,
160
+ validation_jsonl: str = "",
161
+ manifest_path: str = "",
162
+ required_difficulties: str = "",
163
+ min_terminal_reward: float = 12.0,
164
+ min_train_rows: int = 1,
165
+ ) -> dict[str, Any]:
166
+ train_path = pathlib.Path(train_jsonl)
167
+ validation_path = pathlib.Path(validation_jsonl) if validation_jsonl else pathlib.Path("")
168
+ failures, rewards, train_rows, difficulties = _verify_sft_rows(
169
+ train_path,
170
+ min_terminal_reward=min_terminal_reward,
171
+ )
172
+ validation_rows = 0
173
+ if validation_jsonl and validation_path.exists() and validation_path.stat().st_size > 0:
174
+ validation_failures, validation_rewards, validation_rows, validation_difficulties = _verify_sft_rows(
175
+ validation_path,
176
+ min_terminal_reward=min_terminal_reward,
177
+ )
178
+ failures.extend(validation_failures)
179
+ rewards.extend(validation_rewards)
180
+ difficulties.update(validation_difficulties)
181
+ if train_rows < min_train_rows:
182
+ failures.append(f"{train_path}: expected at least {min_train_rows} train rows, found {train_rows}")
183
+
184
+ manifest_verification: dict[str, Any] = {}
185
+ manifest = pathlib.Path(manifest_path) if manifest_path else pathlib.Path("")
186
+ if manifest_path and manifest.exists():
187
+ try:
188
+ manifest_data = json.loads(manifest.read_text(encoding="utf-8"))
189
+ manifest_verification = dict(manifest_data.get("reward_verification") or {})
190
+ manifest_difficulties = {
191
+ int(item) for item in manifest_data.get("difficulty_levels", []) or []
192
+ }
193
+ except Exception as exc:
194
+ failures.append(f"{manifest}: could not read manifest reward verification: {exc}")
195
+ manifest_difficulties = set()
196
+ if manifest_verification and manifest_verification.get("passed") is not True:
197
+ failures.append(f"{manifest}: manifest reward_verification did not pass")
198
+ else:
199
+ manifest_difficulties = set()
200
+
201
+ required = _parse_int_csv(required_difficulties) or manifest_difficulties
202
+ missing_difficulties = sorted(level for level in required if level not in difficulties)
203
+ if missing_difficulties:
204
+ failures.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
205
+
206
+ reward_summary = {
207
+ "min": min(rewards) if rewards else 0.0,
208
+ "max": max(rewards) if rewards else 0.0,
209
+ "mean": (sum(rewards) / len(rewards)) if rewards else 0.0,
210
+ }
211
+ return {
212
+ "passed": not failures,
213
+ "failure_count": len(failures),
214
+ "failures": failures[:50],
215
+ "train_rows": train_rows,
216
+ "validation_rows": validation_rows,
217
+ "difficulties": sorted(difficulties),
218
+ "required_difficulties": sorted(required),
219
+ "missing_difficulties": missing_difficulties,
220
+ "min_terminal_reward": float(min_terminal_reward),
221
+ "reward_summary": reward_summary,
222
+ "manifest_reward_verification": manifest_verification,
223
+ }
224
+
225
+
226
  def _configure_modal_cache_env() -> dict[str, str]:
227
  values = {
228
  "HF_HOME": str(HF_HOME_DIR),
 
344
 
345
  @app.function(
346
  image=training_image,
347
+ gpu=SFT_GPU_FALLBACK,
348
  timeout=12 * 60 * 60,
349
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
350
  secrets=secrets,
 
352
  def train_cybersecurity_owasp_sft(
353
  train_jsonl: str = "/runs/sft/train.jsonl",
354
  validation_jsonl: str = "/runs/sft/validation.jsonl",
355
+ manifest_path: str = "/runs/sft/manifest.json",
356
+ output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
357
  model_name: str = DEFAULT_GEMMA_MODEL,
358
  run_name: str = "",
359
  max_seq_length: int = 4096,
360
+ max_steps: int = -1,
361
  num_train_epochs: float = 1.0,
362
+ per_device_train_batch_size: int = 4,
363
+ gradient_accumulation_steps: int = 4,
364
  learning_rate: float = 2e-5,
365
  lora_rank: int = 32,
366
+ trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
367
+ trackio_project: str = DEFAULT_TRACKIO_PROJECT,
368
+ require_reward_verification: bool = True,
369
+ required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
370
+ min_terminal_reward: float = 12.0,
371
+ min_train_rows: int = 1,
372
  push_to_hub: bool = False,
373
  ) -> dict[str, Any]:
374
  import inspect
375
 
376
  from datasets import load_dataset
377
+ from huggingface_hub import snapshot_download
378
  from trl import SFTConfig, SFTTrainer
379
  from trl.chat_template_utils import add_response_schema
380
  from unsloth import FastVisionModel
 
385
  if not hf_token:
386
  raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
387
 
388
+ output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID
389
+ os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
390
+ os.environ["TRACKIO_PROJECT"] = trackio_project
391
  stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
392
  run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
393
  output_dir = RUNS_DIR / run_name
 
399
  has_validation = validation_path.exists() and validation_path.stat().st_size > 0
400
  if has_validation:
401
  data_files["validation"] = validation_jsonl
402
+
403
+ reward_preflight = verify_sft_inputs(
404
+ train_jsonl=train_jsonl,
405
+ validation_jsonl=validation_jsonl if has_validation else "",
406
+ manifest_path=manifest_path,
407
+ required_difficulties=required_difficulties,
408
+ min_terminal_reward=min_terminal_reward,
409
+ min_train_rows=min_train_rows,
410
+ )
411
+ print(f"SFT reward preflight: {json.dumps(reward_preflight, sort_keys=True)}")
412
+ if require_reward_verification and not reward_preflight["passed"]:
413
+ raise RuntimeError(
414
+ "SFT reward verification failed; refusing to start model training. "
415
+ f"Failures: {reward_preflight['failures']}"
416
+ )
417
  dataset = load_dataset("json", data_files=data_files)
418
 
419
  print(f"SFT run name: {run_name}")
 
424
  print(f"Output repo: https://huggingface.co/{output_repo_id}")
425
  print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
426
  print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
427
+ print(
428
+ "SFT target: "
429
+ f"{DEFAULT_TOTAL_TRAIN_EPISODES} total train episodes, "
430
+ f"{DEFAULT_EPISODES_PER_LEVEL} per level across {DEFAULT_CURRICULUM_LEVELS}"
431
+ )
432
 
433
  try:
434
  snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
 
477
  "per_device_train_batch_size": per_device_train_batch_size,
478
  "gradient_accumulation_steps": gradient_accumulation_steps,
479
  "learning_rate": learning_rate,
480
+ "optim": "adamw_8bit",
481
  "logging_steps": 1,
482
+ "logging_first_step": True,
483
+ "save_steps": max(10, max_steps) if max_steps > 0 else 100,
484
  "report_to": "trackio",
485
  "project": trackio_project,
486
  "trackio_space_id": trackio_space_id,
487
  "run_name": run_name,
488
  "assistant_only_loss": True,
489
+ "packing": True,
490
+ "packing_strategy": "bfd",
491
+ "bf16": True,
492
+ "tf32": True,
493
  "gradient_checkpointing": True,
494
  "gradient_checkpointing_kwargs": {"use_reentrant": False},
495
  "push_to_hub": push_to_hub,
496
  "hub_model_id": output_repo_id,
497
  "hub_private_repo": True,
498
+ "hub_strategy": "every_save",
499
  }
500
  sft_parameters = set(inspect.signature(SFTConfig).parameters)
501
  skipped = sorted(set(sft_values) - sft_parameters)
 
538
  "output_repo_id": output_repo_id,
539
  "train_jsonl": train_jsonl,
540
  "validation_jsonl": validation_jsonl if has_validation else "",
541
+ "manifest_path": manifest_path,
542
+ "reward_preflight": reward_preflight,
543
+ "required_difficulties": required_difficulties,
544
+ "default_total_train_episodes": DEFAULT_TOTAL_TRAIN_EPISODES,
545
+ "default_episodes_per_level": DEFAULT_EPISODES_PER_LEVEL,
546
  "max_steps": max_steps,
547
  "push_to_hub": push_to_hub,
548
  "trackio_space_id": trackio_space_id,
 
     mode: str = "train",
     local_train_path: str = "outputs/sft/train.jsonl",
     local_validation_path: str = "outputs/sft/validation.jsonl",
+    local_manifest_path: str = "outputs/sft/manifest.json",
     train_jsonl: str = "/runs/sft/train.jsonl",
     validation_jsonl: str = "/runs/sft/validation.jsonl",
+    manifest_path: str = "/runs/sft/manifest.json",
+    output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
     model_name: str = DEFAULT_GEMMA_MODEL,
     run_name: str = "",
     max_seq_length: int = 4096,
+    max_steps: int = -1,
     num_train_epochs: float = 1.0,
+    per_device_train_batch_size: int = 4,
+    gradient_accumulation_steps: int = 4,
     learning_rate: float = 2e-5,
     lora_rank: int = 32,
+    trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
+    trackio_project: str = DEFAULT_TRACKIO_PROJECT,
+    require_reward_verification: bool = True,
+    required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
+    min_terminal_reward: float = 12.0,
+    min_train_rows: int = 1,
     source_mode: str = "local",
     repo_url: str = PUBLIC_REPO_URL,
     repo_branch: str = PUBLIC_REPO_BRANCH,
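These defaults imply an effective optimizer batch of 16 packed sequences per device step; a quick sanity check of the arithmetic (illustrative only, not part of the script):

```python
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
# Gradients from 4 micro-batches of 4 are accumulated before each update.
assert per_device_train_batch_size * gradient_accumulation_steps == 16
```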
 
 
     local_train = pathlib.Path(local_train_path)
     local_validation = pathlib.Path(local_validation_path)
+    local_manifest = pathlib.Path(local_manifest_path)
+    if require_reward_verification and local_train.exists():
+        local_reward_preflight = verify_sft_inputs(
+            train_jsonl=str(local_train),
+            validation_jsonl=str(local_validation) if local_validation.exists() else "",
+            manifest_path=str(local_manifest) if local_manifest.exists() else "",
+            required_difficulties=required_difficulties,
+            min_terminal_reward=min_terminal_reward,
+            min_train_rows=min_train_rows,
+        )
+        print(f"Local SFT reward preflight: {json.dumps(local_reward_preflight, sort_keys=True)}")
+        if not local_reward_preflight["passed"]:
+            raise RuntimeError(
+                "Local SFT reward verification failed; refusing to upload/train. "
+                f"Failures: {local_reward_preflight['failures']}"
+            )
     if local_train.exists():
         uploaded = upload_sft_jsonl.remote(
             "sft/train.jsonl",
 
         )
         print(f"Uploaded validation JSONL: {uploaded_validation}")
         validation_jsonl = uploaded_validation
+    if local_manifest.exists():
+        uploaded_manifest = upload_sft_jsonl.remote(
+            "sft/manifest.json",
+            local_manifest.read_text(encoding="utf-8"),
+        )
+        print(f"Uploaded manifest: {uploaded_manifest}")
+        manifest_path = uploaded_manifest
     if mode == "upload":
         return
 
     kwargs = dict(
         train_jsonl=train_jsonl,
         validation_jsonl=validation_jsonl,
+        manifest_path=manifest_path,
         output_repo_id=output_repo_id,
         model_name=model_name,
         run_name=run_name,
 
         lora_rank=lora_rank,
         trackio_space_id=trackio_space_id,
         trackio_project=trackio_project,
+        require_reward_verification=require_reward_verification,
+        required_difficulties=required_difficulties,
+        min_terminal_reward=min_terminal_reward,
+        min_train_rows=min_train_rows,
         push_to_hub=push_to_hub,
     )
     print(f"SFT run name: {run_name}")
     print(f"Train JSONL: {train_jsonl}")
     print(f"Validation JSONL: {validation_jsonl}")
+    print(f"Manifest: {manifest_path}")
+    print(f"Reward verification required: {require_reward_verification}")
+    if required_difficulties:
+        print(f"Required curriculum difficulties: {required_difficulties}")
     print(f"Hub push enabled: {push_to_hub}")
     if detach:
         call = train_cybersecurity_owasp_sft.spawn(**kwargs)
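For reference, a plausible shape for the `upload_sft_jsonl` helper invoked above: its call sites pass a volume-relative path plus the file text, and the returned container path becomes the training input. The decorator, app, and volume names below are assumptions, not the script's actual definitions:

```python
import pathlib

import modal

app = modal.App("cybersecurity-owasp-sft")    # assumed app name
runs_volume = modal.Volume.from_name("runs")  # assumed volume name


@app.function(volumes={"/runs": runs_volume})
def upload_sft_jsonl(relative_path: str, text: str) -> str:
    # Persist the payload on the shared /runs volume and hand back its path.
    target = pathlib.Path("/runs") / relative_path
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(text, encoding="utf-8")
    runs_volume.commit()  # make the write visible to other containers
    return str(target)
```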
tests/test_modal_scenario_cache_static.py CHANGED
@@ -37,3 +37,40 @@ def test_modal_training_is_pinned_to_gemma4_e2b():
     assert "from unsloth import FastVisionModel" in source
     assert "Qwen" not in source
     assert "FastLanguageModel" not in source
+
+
+def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
+    source = (ROOT / "scripts" / "modal_train_sft.py").read_text(encoding="utf-8")
+
+    assert 'SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]' in source
+    assert "gpu=SFT_GPU_FALLBACK" in source
+    assert "DEFAULT_TOTAL_TRAIN_EPISODES = 300" in source
+    assert "DEFAULT_EPISODES_PER_LEVEL = 75" in source
+    assert 'DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"' in source
+    assert (
+        'DEFAULT_SFT_OUTPUT_REPO_ID = (\n'
+        '    "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"'
+    ) in source
+    assert "output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID" in source
+    assert source.count("max_steps: int = -1") >= 2
+    assert source.count("per_device_train_batch_size: int = 4") >= 2
+    assert source.count("gradient_accumulation_steps: int = 4") >= 2
+    assert '"assistant_only_loss": True' in source
+    assert '"packing": True' in source
+    assert '"packing_strategy": "bfd"' in source
+    assert '"bf16": True' in source
+    assert '"tf32": True' in source
+    assert '"hub_strategy": "every_save"' in source
+    assert 'trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID' in source
+    assert 'trackio_project: str = DEFAULT_TRACKIO_PROJECT' in source
+    assert 'os.environ["TRACKIO_SPACE_ID"] = trackio_space_id' in source
+    assert 'os.environ["TRACKIO_PROJECT"] = trackio_project' in source
+
+
+def test_modal_grpo_loads_sft_adapter_from_hub_as_trainable_lora():
+    source = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")
+
+    assert "initial_adapter_repo_id" in source
+    assert "Downloading initial SFT adapter" in source
+    assert "snapshot_download(" in source
+    assert "PeftModel.from_pretrained(model, adapter_source, is_trainable=True)" in source
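The second test pins the GRPO warm-start path: `modal_train_grpo.py` must fetch the SFT LoRA from the Hub and attach it as a trainable adapter rather than merging it into the base weights. A minimal sketch of the pattern those four assertions imply (the function wrapper and argument plumbing are assumptions):

```python
from huggingface_hub import snapshot_download
from peft import PeftModel


def attach_initial_adapter(model, initial_adapter_repo_id: str, hf_token: str):
    # No warm start requested: train a fresh adapter instead.
    if not initial_adapter_repo_id:
        return model
    print(f"Downloading initial SFT adapter: {initial_adapter_repo_id}")
    adapter_source = snapshot_download(repo_id=initial_adapter_repo_id, token=hf_token)
    # is_trainable=True keeps the LoRA weights unfrozen so GRPO can update them.
    return PeftModel.from_pretrained(model, adapter_source, is_trainable=True)
```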
tests/test_sft_dataset_generation.py CHANGED
@@ -90,11 +90,20 @@ def test_dry_run_oracle_creates_chat_jsonl_without_network():
             validation_episodes=1,
             out_dir=out_dir,
             dry_run_oracle=True,
+            workers=2,
+            difficulty_levels=(0, 1),
         )
     )
 
-    assert manifest["episodes_attempted"] == 3
-    assert manifest["episodes_accepted"] == 3
+    assert manifest["difficulty_levels"] == [0, 1]
+    assert manifest["difficulty_bucket_count"] >= 2
+    assert manifest["episodes_attempted"] == 6
+    assert manifest["episodes_accepted"] == 6
+    assert manifest["workers"] == 2
+    assert manifest["reward_verification"]["passed"] is True
+    assert manifest["reward_verification"]["missing_difficulties"] == []
+    assert manifest["rows_by_difficulty"]["0"] > 0
+    assert manifest["rows_by_difficulty"]["1"] > 0
     assert (out_dir / "train.jsonl").exists()
     assert (out_dir / "validation.jsonl").exists()
     train_rows = [
@@ -110,6 +119,43 @@ def test_saved_oracle_trajectory_replays_to_success():
     assert train_rows
     assert validation_rows
     assert all(row["messages"][-1]["role"] == "assistant" for row in train_rows)
+    reward_check = generate_sft_dataset.verify_sft_dataset_rewards(
+        out_dir,
+        required_difficulties=(0, 1),
+    )
+    assert reward_check["passed"] is True
+    assert (out_dir / "README.md").exists()
+
+
+def test_reward_verification_rejects_low_reward_rows():
+    out_dir = _isolated_out_dir("bad_reward")
+    out_dir.mkdir(parents=True, exist_ok=True)
+    action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
+    row = {
+        "messages": [
+            {"role": "system", "content": "system"},
+            {"role": "user", "content": "user"},
+            {"role": "assistant", "content": json.dumps(action.model_dump())},
+        ],
+        "metadata": {
+            "final_success": True,
+            "terminal_total": 1.0,
+            "anti_cheat_flags": [],
+            "final_reward_breakdown": {
+                "security": 5.0,
+                "regression": 3.0,
+                "public_routes": 1.0,
+                "patch_quality": 2.0,
+                "visible_tests": 1.0,
+            },
+        },
+    }
+    (out_dir / "train.jsonl").write_text(json.dumps(row) + "\n", encoding="utf-8")
+
+    reward_check = generate_sft_dataset.verify_sft_dataset_rewards(out_dir)
+
+    assert reward_check["passed"] is False
+    assert reward_check["failure_count"] == 1
 
 
 def test_saved_oracle_trajectory_replays_to_success():
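The fixture above is deliberately inconsistent: its reward breakdown sums to 12.0, which would clear the default `min_terminal_reward` floor, but the verifier presumably gates on `terminal_total`, which is only 1.0. Illustrative arithmetic:

```python
final_reward_breakdown = {
    "security": 5.0,
    "regression": 3.0,
    "public_routes": 1.0,
    "patch_quality": 2.0,
    "visible_tests": 1.0,
}
# The component sum clears the 12.0 floor...
assert sum(final_reward_breakdown.values()) == 12.0
# ...but the fixture's terminal_total (1.0) does not, so the row is rejected.
assert 1.0 < 12.0
```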