Spaces:
Sleeping
Sleeping
feat: enhance SFT training process with new tokenization method, implement custom trainer class for loss computation, and update README with GRPO launcher details for Unsloth LoRA integration
e5fe6f5 | """Modal SFT launcher for CyberSecurity_OWASP action JSON data. | |
| This trains a LoRA adapter on chat JSONL generated by | |
| ``scripts/generate_sft_dataset.py``. It intentionally mirrors the repo's Modal | |
| training pattern: local execution only launches remote jobs, while training runs | |
| inside Modal and saves adapters to the persistent run volume. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import pathlib | |
| import subprocess | |
| from datetime import datetime, timezone | |
| from typing import Any | |
| import modal | |
# --- Modal resource names -------------------------------------------------
APP_NAME = "CyberSecurity_OWASP-sft"
VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"  # persistent run/adapter volume
CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"  # HF/torch/triton caches
SECRET_NAME = "CyberSecurity_OWASP-secrets"  # Modal secret; must contain HF_TOKEN

# --- In-container mount points and cache layout ---------------------------
RUNS_DIR = pathlib.Path("/runs")
CACHE_DIR = pathlib.Path("/cache")
HF_HOME_DIR = CACHE_DIR / "huggingface"
HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
TORCH_HOME_DIR = CACHE_DIR / "torch"
XDG_CACHE_DIR = CACHE_DIR / "xdg"
UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
TRITON_CACHE_DIR = CACHE_DIR / "triton"

# --- Project source locations ---------------------------------------------
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"  # project checkout path inside the image
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]

# --- Training defaults -----------------------------------------------------
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"  # the only model this launcher accepts
# NOTE(review): SFT_GPU_FALLBACK is not referenced anywhere in this chunk;
# presumably consumed by a Modal function decorator outside this view — confirm.
SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]
DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"  # curriculum difficulty levels as CSV
DEFAULT_TOTAL_TRAIN_EPISODES = 300
DEFAULT_EPISODES_PER_LEVEL = 75
DEFAULT_TRACKIO_SPACE_ID = "Humanlearning/CyberSecurity_OWASP-trackio"
DEFAULT_TRACKIO_PROJECT = "CyberSecurity_OWASP-sft"
DEFAULT_SFT_OUTPUT_REPO_ID = (
    "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"
)
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
PUBLIC_REPO_BRANCH = "master"
def _ensure_gemma4_model(model_name: str) -> str:
    """Return *model_name* unchanged if it is the pinned Gemma model.

    Raises:
        ValueError: for any model other than ``DEFAULT_GEMMA_MODEL``; this
            launcher is deliberately pinned to a single base model.
    """
    if model_name == DEFAULT_GEMMA_MODEL:
        return model_name
    raise ValueError(
        "CyberSecurity_OWASP SFT is pinned to "
        f"{DEFAULT_GEMMA_MODEL}; received {model_name!r}."
    )
def _model_repo_slug(model_name: str) -> str:
    """Lowercase *model_name* with '/', '_' and '.' normalised to hyphens."""
    slug = model_name.lower()
    for separator in ("/", "_", "."):
        slug = slug.replace(separator, "-")
    return slug
def _parse_int_csv(value: str) -> set[int]:
    """Parse a comma-separated string into a set of ints; blanks are ignored."""
    pieces = (piece.strip() for piece in value.split(","))
    return {int(piece) for piece in pieces if piece}
# Tool names an assistant turn may emit in an SFT row; any other
# ``tool_name`` in the final assistant message is rejected by
# ``_verify_sft_rows`` as an invalid action.
SFT_ALLOWED_TOOLS = {
    "inspect_policy_graph",
    "list_routes",
    "read_openapi",
    "read_file",
    "search_code",
    "send_local_request",
    "compare_identities",
    "submit_diagnosis",
    "patch_file",
    "run_visible_tests",
    "submit_fix",
    "noop",
}
def _read_jsonl(path: pathlib.Path) -> list[dict[str, Any]]:
    """Load a JSONL file into a list of dicts.

    Returns an empty list when *path* does not exist; blank lines are
    skipped.

    Raises:
        ValueError: with ``file:line`` context when a line is not valid
            JSON, or decodes to something other than a JSON object.
    """
    if not path.exists():
        return []
    parsed: list[dict[str, Any]] = []
    lines = path.read_text(encoding="utf-8").splitlines()
    for lineno, raw in enumerate(lines, start=1):
        if not raw.strip():
            continue
        try:
            record = json.loads(raw)
        except json.JSONDecodeError as exc:
            raise ValueError(f"{path}:{lineno}: invalid JSONL: {exc}") from exc
        if not isinstance(record, dict):
            raise ValueError(f"{path}:{lineno}: row must be a JSON object")
        parsed.append(record)
    return parsed
def _verify_sft_rows(
    path: pathlib.Path,
    *,
    min_terminal_reward: float,
) -> tuple[list[str], list[float], int, set[int]]:
    """Validate every SFT chat row in *path* against the reward gates.

    A valid row ends with an assistant message whose content parses as a
    JSON tool action drawn from ``SFT_ALLOWED_TOOLS``, carries metadata with
    ``final_success`` true and no anti-cheat flags, meets
    *min_terminal_reward*, and has every component of
    ``final_reward_breakdown`` strictly positive.

    Returns:
        ``(failures, rewards, row_count, difficulties)`` — human-readable
        failure strings, the terminal rewards observed, the number of rows
        read, and the set of difficulty levels present.
    """
    rows = _read_jsonl(path)
    failures: list[str] = []
    rewards: list[float] = []
    difficulties: set[int] = set()
    for index, row in enumerate(rows, start=1):
        messages = row.get("messages")
        if not isinstance(messages, list) or len(messages) < 3:
            failures.append(f"{path}:{index}: messages must include system/user/assistant")
            continue
        assistant = messages[-1]
        # Fix: a non-dict final message (e.g. a bare string) previously
        # crashed verification with AttributeError on .get(); record it as a
        # validation failure instead.
        if not isinstance(assistant, dict) or assistant.get("role") != "assistant":
            failures.append(f"{path}:{index}: final message must be assistant")
            continue
        try:
            action = json.loads(str(assistant.get("content", "")))
        except json.JSONDecodeError as exc:
            failures.append(f"{path}:{index}: assistant content is not JSON: {exc}")
            continue
        if not isinstance(action, dict) or action.get("tool_name") not in SFT_ALLOWED_TOOLS:
            failures.append(f"{path}:{index}: assistant content is not a valid tool action")
            continue
        metadata = row.get("metadata")
        if not isinstance(metadata, dict):
            failures.append(f"{path}:{index}: missing metadata")
            continue
        if metadata.get("final_success") is not True:
            failures.append(f"{path}:{index}: final_success is not true")
            continue
        if metadata.get("anti_cheat_flags") or []:
            failures.append(f"{path}:{index}: anti-cheat flags present")
            continue
        if "difficulty" in metadata:
            difficulties.add(int(metadata.get("difficulty", 0)))
        # ``or 0.0`` also coerces an explicit null to 0.0 before float().
        terminal_reward = float(metadata.get("terminal_total", 0.0) or 0.0)
        rewards.append(terminal_reward)
        if terminal_reward < min_terminal_reward:
            failures.append(
                f"{path}:{index}: terminal_total {terminal_reward:.3f} below {min_terminal_reward:.3f}"
            )
            continue
        breakdown = metadata.get("final_reward_breakdown") or {}
        if not isinstance(breakdown, dict):
            failures.append(f"{path}:{index}: missing final_reward_breakdown")
            continue
        # Report only the first non-positive component per row.
        for key in ("security", "regression", "public_routes", "patch_quality", "visible_tests"):
            if float(breakdown.get(key, 0.0) or 0.0) <= 0.0:
                failures.append(f"{path}:{index}: reward component {key} is not positive")
                break
    return failures, rewards, len(rows), difficulties
def verify_sft_inputs(
    *,
    train_jsonl: str,
    validation_jsonl: str = "",
    manifest_path: str = "",
    required_difficulties: str = "",
    min_terminal_reward: float = 12.0,
    min_train_rows: int = 1,
) -> dict[str, Any]:
    """Run the full SFT data preflight and return a summary dict.

    Verifies train (and optional validation) rows via ``_verify_sft_rows``,
    enforces a minimum train-row count, cross-checks the dataset manifest's
    ``reward_verification`` block when present, and confirms every required
    curriculum difficulty appears in the data. The returned dict carries
    ``passed`` plus counts, at most 50 failure strings, and a reward summary.
    """
    train_path = pathlib.Path(train_jsonl)
    # Path("") is just a placeholder; it is only consulted when
    # validation_jsonl is a non-empty string.
    validation_path = pathlib.Path(validation_jsonl) if validation_jsonl else pathlib.Path("")
    failures, rewards, train_rows, difficulties = _verify_sft_rows(
        train_path,
        min_terminal_reward=min_terminal_reward,
    )
    validation_rows = 0
    # Validation split is optional: only verified when the file exists and
    # is non-empty.
    if validation_jsonl and validation_path.exists() and validation_path.stat().st_size > 0:
        validation_failures, validation_rewards, validation_rows, validation_difficulties = _verify_sft_rows(
            validation_path,
            min_terminal_reward=min_terminal_reward,
        )
        failures.extend(validation_failures)
        rewards.extend(validation_rewards)
        difficulties.update(validation_difficulties)
    if train_rows < min_train_rows:
        failures.append(f"{train_path}: expected at least {min_train_rows} train rows, found {train_rows}")
    manifest_verification: dict[str, Any] = {}
    manifest = pathlib.Path(manifest_path) if manifest_path else pathlib.Path("")
    if manifest_path and manifest.exists():
        try:
            manifest_data = json.loads(manifest.read_text(encoding="utf-8"))
            manifest_verification = dict(manifest_data.get("reward_verification") or {})
            manifest_difficulties = {
                int(item) for item in manifest_data.get("difficulty_levels", []) or []
            }
        except Exception as exc:
            failures.append(f"{manifest}: could not read manifest reward verification: {exc}")
            # Keep the name bound on the failure path so the fallback below works.
            manifest_difficulties = set()
        if manifest_verification and manifest_verification.get("passed") is not True:
            failures.append(f"{manifest}: manifest reward_verification did not pass")
    else:
        manifest_difficulties = set()
    # Explicit CSV argument wins; otherwise fall back to the manifest's levels.
    required = _parse_int_csv(required_difficulties) or manifest_difficulties
    missing_difficulties = sorted(level for level in required if level not in difficulties)
    if missing_difficulties:
        failures.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
    reward_summary = {
        "min": min(rewards) if rewards else 0.0,
        "max": max(rewards) if rewards else 0.0,
        "mean": (sum(rewards) / len(rewards)) if rewards else 0.0,
    }
    return {
        "passed": not failures,
        "failure_count": len(failures),
        "failures": failures[:50],  # cap to keep logs and JSON output readable
        "train_rows": train_rows,
        "validation_rows": validation_rows,
        "difficulties": sorted(difficulties),
        "required_difficulties": sorted(required),
        "missing_difficulties": missing_difficulties,
        "min_terminal_reward": float(min_terminal_reward),
        "reward_summary": reward_summary,
        "manifest_reward_verification": manifest_verification,
    }
def _configure_modal_cache_env() -> dict[str, str]:
    """Point every cache-aware tool at the persistent /cache volume.

    Exports the cache environment variables, creates the backing
    directories, and returns the mapping that was applied.
    """
    compile_cache = UNSLOTH_CACHE_DIR / "compile"
    env_values = {
        "HF_HOME": str(HF_HOME_DIR),
        "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
        "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
        "TORCH_HOME": str(TORCH_HOME_DIR),
        "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
        "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
        "UNSLOTH_COMPILE_CACHE": str(compile_cache),
        "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
    }
    os.environ.update(env_values)
    directories = (
        CACHE_DIR,
        HF_HOME_DIR,
        HF_HUB_CACHE_DIR,
        TORCH_HOME_DIR,
        XDG_CACHE_DIR,
        UNSLOTH_CACHE_DIR,
        compile_cache,
        TRITON_CACHE_DIR,
    )
    for directory in directories:
        directory.mkdir(parents=True, exist_ok=True)
    return env_values
def _cli_arg_value(name: str, default: str = "") -> str:
    """Return the value of CLI flag ``--<name>`` from ``sys.argv``.

    Supports both ``--name value`` and ``--name=value`` forms; returns
    *default* when the flag is absent or has no trailing value.
    """
    import sys

    flag = f"--{name}"
    prefix = f"{flag}="
    argv = sys.argv[1:]
    for position, token in enumerate(argv):
        if token == flag and position + 1 < len(argv):
            return argv[position + 1]
        if token.startswith(prefix):
            return token[len(prefix):]
    return default
def _source_mode() -> str:
    """Resolve the source mode: CLI flag, then MODAL_SOURCE_MODE, then 'local'."""
    fallback = os.environ.get("MODAL_SOURCE_MODE", "local")
    return _cli_arg_value("source-mode", fallback)
def _training_image() -> modal.Image:
    """Build the CUDA training image with the project source installed.

    Source comes either from a public git clone (``--source-mode public``)
    or from the local checkout; both paths finish with an editable
    ``pip install`` of the project and set the workdir to the project root.
    """
    image = (
        modal.Image.from_registry(
            "nvidia/cuda:12.8.0-devel-ubuntu22.04",
            add_python="3.11",
        )
        .apt_install("git", "build-essential", "curl")
        .uv_pip_install(
            "torch==2.10.0",
            "triton>=3.4.0",
            "torchvision==0.25.0",
            "bitsandbytes",
            "accelerate",
            "datasets",
            "huggingface_hub",
            "peft",
            "tokenizers",
            "trackio>=0.25.0",
            "transformers>=5.5.0",
            "trl>=0.28.0",
        )
        # Unsloth is installed from git in a separate layer, after its
        # dependencies above are already resolved.
        .uv_pip_install(
            "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
            "unsloth[base] @ git+https://github.com/unslothai/unsloth",
        )
        .uv_pip_install("timm", extra_options="--no-deps")
        .uv_pip_install("pydantic==2.10.6")  # exact pin, installed last so it wins
    )
    if _source_mode() == "public":
        repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL)
        repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
        # NOTE(review): repo_url/repo_branch come from the CLI and are
        # interpolated unquoted into shell commands — fine for a trusted
        # operator, but worth confirming there is no untrusted caller.
        image = image.run_commands(
            f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
            f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
        )
    else:
        # Local mode: copy the working tree into the image, excluding VCS,
        # virtualenv, env files, caches, and build outputs.
        image = image.add_local_dir(
            PROJECT_ROOT,
            remote_path=REMOTE_PROJECT,
            copy=True,
            ignore=[
                ".git",
                ".venv",
                ".env",
                ".env.*",
                "__pycache__",
                ".pytest_cache",
                "outputs",
                "*.pyc",
            ],
        )
        image = image.run_commands(f"python -m pip install --no-deps -e {REMOTE_PROJECT}")
    return image.workdir(REMOTE_PROJECT)
# Module-level Modal objects: the app, the persistent run and model-cache
# volumes (created on first use), the training image, and the secret list
# that must provide HF_TOKEN.
app = modal.App(APP_NAME)
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
training_image = _training_image()
secrets = [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])]
def upload_sft_jsonl(relative_path: str, content: str) -> str:
    """Write *content* to ``/runs/<relative_path>`` on the run volume.

    Returns the absolute path written. NOTE(review): this is invoked via
    ``.remote`` from ``main``, so it is presumably registered as a Modal
    function with the run volume mounted — confirm the decorator, which is
    not visible in this chunk.
    """
    target = RUNS_DIR / relative_path
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content, encoding="utf-8")
    volume.commit()  # persist so the training job sees the file
    return str(target)
def train_cybersecurity_owasp_sft(
    train_jsonl: str = "/runs/sft/train.jsonl",
    validation_jsonl: str = "/runs/sft/validation.jsonl",
    manifest_path: str = "/runs/sft/manifest.json",
    output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
    model_name: str = DEFAULT_GEMMA_MODEL,
    run_name: str = "",
    max_seq_length: int = 4096,
    max_steps: int = -1,
    num_train_epochs: float = 1.0,
    per_device_train_batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 2e-5,
    lora_rank: int = 32,
    trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
    trackio_project: str = DEFAULT_TRACKIO_PROJECT,
    require_reward_verification: bool = True,
    required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
    min_terminal_reward: float = 12.0,
    min_train_rows: int = 1,
    push_to_hub: bool = False,
) -> dict[str, Any]:
    """Run LoRA SFT on the uploaded chat JSONL and save the adapter.

    Pipeline: reward-verify the data, load the pinned Gemma model via
    Unsloth, pre-tokenize both splits with full-sequence labels, attach a
    LoRA adapter, train with TRL's SFTTrainer, save the adapter to the run
    volume, and optionally push to the Hub. Returns a summary dict.

    NOTE(review): invoked via ``.remote``/``.spawn`` from ``main``, so this
    is presumably decorated as a Modal GPU function with volumes/secrets
    attached — the decorator is not visible in this chunk; confirm.
    """
    # Heavy imports are deferred so they only happen inside the container.
    import inspect
    from datasets import Dataset, load_dataset
    from huggingface_hub import snapshot_download
    from transformers import Trainer
    from trl import SFTConfig, SFTTrainer
    try:
        from trl.chat_template_utils import add_response_schema
    except ImportError:
        # Older TRL without chat_template_utils: fall back to a no-op.
        def add_response_schema(tokenizer):
            return tokenizer
    from unsloth import FastVisionModel
    model_name = _ensure_gemma4_model(model_name)
    cache_env = _configure_modal_cache_env()
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
    output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID
    os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
    os.environ["TRACKIO_PROJECT"] = trackio_project
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
    output_dir = RUNS_DIR / run_name
    adapter_dir = output_dir / "sft_adapter"
    output_dir.mkdir(parents=True, exist_ok=True)
    data_files = {"train": train_jsonl}
    validation_path = pathlib.Path(validation_jsonl)
    # Validation split is used only when the file exists and is non-empty.
    has_validation = validation_path.exists() and validation_path.stat().st_size > 0
    if has_validation:
        data_files["validation"] = validation_jsonl
    # Gate on reward verification before any model weights are loaded.
    reward_preflight = verify_sft_inputs(
        train_jsonl=train_jsonl,
        validation_jsonl=validation_jsonl if has_validation else "",
        manifest_path=manifest_path,
        required_difficulties=required_difficulties,
        min_terminal_reward=min_terminal_reward,
        min_train_rows=min_train_rows,
    )
    print(f"SFT reward preflight: {json.dumps(reward_preflight, sort_keys=True)}")
    if require_reward_verification and not reward_preflight["passed"]:
        raise RuntimeError(
            "SFT reward verification failed; refusing to start model training. "
            f"Failures: {reward_preflight['failures']}"
        )
    dataset = load_dataset("json", data_files=data_files)
    print(f"SFT run name: {run_name}")
    print(f"Model: {model_name}")
    print(f"Train JSONL: {train_jsonl}")
    print(f"Validation JSONL: {validation_jsonl if has_validation else '(none)'}")
    print(f"Output adapter dir: {adapter_dir}")
    print(f"Output repo: https://huggingface.co/{output_repo_id}")
    print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
    print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
    print(
        "SFT target: "
        f"{DEFAULT_TOTAL_TRAIN_EPISODES} total train episodes, "
        f"{DEFAULT_EPISODES_PER_LEVEL} per level across {DEFAULT_CURRICULUM_LEVELS}"
    )
    # Best-effort prefetch into the shared cache volume; failures are
    # non-fatal because from_pretrained can download directly.
    try:
        snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
        cache_volume.commit()
    except Exception as exc:
        print(f"Model snapshot prefetch skipped; loader will retry directly. Error: {exc!r}")
    model_api = FastVisionModel
    model, tokenizer = model_api.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        fast_inference=False,
        cache_dir=str(HF_HUB_CACHE_DIR),
        token=hf_token,
    )
    # Best-effort tokenizer augmentation; a no-op on older TRL versions.
    try:
        tokenizer = add_response_schema(tokenizer)
    except Exception as exc:
        print(f"Tokenizer response schema add skipped: {exc!r}")
    def _tokenize_sft_split(split_name: str, split_dataset) -> Dataset:
        # Pre-tokenize each chat row with the model's chat template; labels
        # mirror input_ids, i.e. loss is computed over the full sequence
        # (matches assistant_only_loss=False below).
        tokenized_rows: list[dict[str, list[int]]] = []
        total_rows = len(split_dataset)
        for row_index, example in enumerate(split_dataset, start=1):
            messages = example["messages"]
            if isinstance(messages, str):
                messages = json.loads(messages)
            rendered = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            try:
                encoded = tokenizer(
                    rendered,
                    add_special_tokens=False,
                    truncation=True,
                    max_length=max_seq_length,
                )
            except TypeError:
                # Some tokenizer wrappers only accept the text= keyword.
                encoded = tokenizer(
                    text=rendered,
                    add_special_tokens=False,
                    truncation=True,
                    max_length=max_seq_length,
                )
            input_ids = encoded["input_ids"]
            # Unwrap batched output ([[ids]]) to a flat list when needed.
            if input_ids and isinstance(input_ids[0], list):
                input_ids = input_ids[0]
            input_ids = [int(token_id) for token_id in input_ids[:max_seq_length]]
            if not input_ids:
                raise RuntimeError(f"{split_name} row {row_index} produced no tokens.")
            tokenized_rows.append({"input_ids": input_ids, "labels": list(input_ids)})
            if row_index % 500 == 0 or row_index == total_rows:
                print(f"Tokenized {split_name} rows: {row_index}/{total_rows}")
        return Dataset.from_list(tokenized_rows)
    dataset["train"] = _tokenize_sft_split("train", dataset["train"])
    if has_validation:
        dataset["validation"] = _tokenize_sft_split("validation", dataset["validation"])
    model = model_api.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_rank * 2,  # conventional 2x scaling of the rank
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    if hasattr(model_api, "for_training"):
        model_api.for_training(model)
    sft_values = {
        "output_dir": str(output_dir),
        "max_seq_length": max_seq_length,
        "max_steps": max_steps,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "learning_rate": learning_rate,
        "optim": "adamw_8bit",
        "dataset_num_proc": None,
        "logging_steps": 1,
        "logging_first_step": True,
        "save_steps": max(10, max_steps) if max_steps > 0 else 100,
        "report_to": "trackio",
        "project": trackio_project,
        "trackio_space_id": trackio_space_id,
        "run_name": run_name,
        "assistant_only_loss": False,
        "packing": False,
        "bf16": True,
        "tf32": True,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "push_to_hub": push_to_hub,
        "hub_model_id": output_repo_id,
        "hub_private_repo": True,
        "hub_strategy": "every_save",
    }
    # Filter config/trainer kwargs down to what the installed TRL version
    # actually accepts, so the launcher survives TRL API drift.
    sft_parameters = set(inspect.signature(SFTConfig).parameters)
    skipped = sorted(set(sft_values) - sft_parameters)
    if skipped:
        print(f"Skipping unsupported SFTConfig keys: {skipped}")
    training_args = SFTConfig(
        **{key: value for key, value in sft_values.items() if key in sft_parameters}
    )
    trainer_values = {
        "model": model,
        "processing_class": tokenizer,
        "args": training_args,
        "train_dataset": dataset["train"],
        "eval_dataset": dataset["validation"] if has_validation else None,
    }
    trainer_parameters = set(inspect.signature(SFTTrainer).parameters)
    skipped_trainer = sorted(
        key for key, value in trainer_values.items() if key not in trainer_parameters and value is not None
    )
    if skipped_trainer:
        print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
    class CyberSecurityOWASPSFTTrainer(SFTTrainer):
        # Delegates loss straight to the base Trainer, forwarding
        # num_items_in_batch only when the installed transformers version
        # supports that parameter.
        def compute_loss(
            self,
            model,
            inputs,
            return_outputs: bool = False,
            num_items_in_batch=None,
        ):
            compute_loss_kwargs = {"return_outputs": return_outputs}
            if "num_items_in_batch" in inspect.signature(Trainer.compute_loss).parameters:
                compute_loss_kwargs["num_items_in_batch"] = num_items_in_batch
            return Trainer.compute_loss(self, model, inputs, **compute_loss_kwargs)
    trainer = CyberSecurityOWASPSFTTrainer(
        **{
            key: value
            for key, value in trainer_values.items()
            if value is not None and key in trainer_parameters
        }
    )
    trainer.train()
    trainer.save_model(str(adapter_dir))
    if push_to_hub:
        trainer.push_to_hub()
    # Persist the adapter and any newly cached weights to the volumes.
    volume.commit()
    cache_volume.commit()
    return {
        "run_name": run_name,
        "model_name": model_name,
        "adapter_dir": str(adapter_dir),
        "output_repo_id": output_repo_id,
        "train_jsonl": train_jsonl,
        "validation_jsonl": validation_jsonl if has_validation else "",
        "manifest_path": manifest_path,
        "reward_preflight": reward_preflight,
        "required_difficulties": required_difficulties,
        "default_total_train_episodes": DEFAULT_TOTAL_TRAIN_EPISODES,
        "default_episodes_per_level": DEFAULT_EPISODES_PER_LEVEL,
        "max_steps": max_steps,
        "push_to_hub": push_to_hub,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
    }
def _git_sha(default: str = "nogit") -> str:
    """Return the project HEAD commit SHA, or *default* if git fails."""
    command = [
        "git",
        "-c",
        f"safe.directory={PROJECT_ROOT.as_posix()}",
        "rev-parse",
        "HEAD",
    ]
    try:
        output = subprocess.check_output(
            command,
            cwd=PROJECT_ROOT,
            text=True,
            stderr=subprocess.DEVNULL,
        ).strip()
    except Exception:
        return default
    return output
def main(
    mode: str = "train",
    local_train_path: str = "outputs/sft/train.jsonl",
    local_validation_path: str = "outputs/sft/validation.jsonl",
    local_manifest_path: str = "outputs/sft/manifest.json",
    train_jsonl: str = "/runs/sft/train.jsonl",
    validation_jsonl: str = "/runs/sft/validation.jsonl",
    manifest_path: str = "/runs/sft/manifest.json",
    output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
    model_name: str = DEFAULT_GEMMA_MODEL,
    run_name: str = "",
    max_seq_length: int = 4096,
    max_steps: int = -1,
    num_train_epochs: float = 1.0,
    per_device_train_batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 2e-5,
    lora_rank: int = 32,
    trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
    trackio_project: str = DEFAULT_TRACKIO_PROJECT,
    require_reward_verification: bool = True,
    required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
    min_terminal_reward: float = 12.0,
    min_train_rows: int = 1,
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    detach: bool = False,
    push_to_hub: bool = False,
) -> None:
    """Local entrypoint: verify/upload local SFT artifacts, then launch training.

    ``mode='upload'`` stops after uploading the JSONL files and manifest to
    the run volume; ``mode='train'`` additionally launches the remote SFT
    function (spawned when ``detach`` is true, otherwise awaited).

    NOTE(review): calls ``.remote``/``.spawn`` on the Modal functions, so
    this is presumably decorated with ``@app.local_entrypoint()`` in the
    full file — confirm; the decorator is not visible here.
    """
    # source_mode/repo_url/repo_branch are read at import time by
    # _training_image via raw sys.argv; they are accepted here only so the
    # CLI parser recognizes the flags.
    del source_mode, repo_url, repo_branch  # consumed during image construction
    model_name = _ensure_gemma4_model(model_name)
    if mode not in {"upload", "train"}:
        raise ValueError("mode must be 'upload' or 'train'")
    local_train = pathlib.Path(local_train_path)
    local_validation = pathlib.Path(local_validation_path)
    local_manifest = pathlib.Path(local_manifest_path)
    # Fail fast locally before paying for any remote work.
    if require_reward_verification and local_train.exists():
        local_reward_preflight = verify_sft_inputs(
            train_jsonl=str(local_train),
            validation_jsonl=str(local_validation) if local_validation.exists() else "",
            manifest_path=str(local_manifest) if local_manifest.exists() else "",
            required_difficulties=required_difficulties,
            min_terminal_reward=min_terminal_reward,
            min_train_rows=min_train_rows,
        )
        print(f"Local SFT reward preflight: {json.dumps(local_reward_preflight, sort_keys=True)}")
        if not local_reward_preflight["passed"]:
            raise RuntimeError(
                "Local SFT reward verification failed; refusing to upload/train. "
                f"Failures: {local_reward_preflight['failures']}"
            )
    # Upload whichever local artifacts exist, rebinding the remote paths to
    # the uploaded locations.
    if local_train.exists():
        uploaded = upload_sft_jsonl.remote(
            "sft/train.jsonl",
            local_train.read_text(encoding="utf-8"),
        )
        print(f"Uploaded train JSONL: {uploaded}")
        train_jsonl = uploaded
    if local_validation.exists():
        uploaded_validation = upload_sft_jsonl.remote(
            "sft/validation.jsonl",
            local_validation.read_text(encoding="utf-8"),
        )
        print(f"Uploaded validation JSONL: {uploaded_validation}")
        validation_jsonl = uploaded_validation
    if local_manifest.exists():
        uploaded_manifest = upload_sft_jsonl.remote(
            "sft/manifest.json",
            local_manifest.read_text(encoding="utf-8"),
        )
        print(f"Uploaded manifest: {uploaded_manifest}")
        manifest_path = uploaded_manifest
    if mode == "upload":
        return
    if not run_name:
        # Timestamp + short git SHA makes run names unique and traceable.
        stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
        run_name = f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}-{_git_sha()[:8]}"
    kwargs = dict(
        train_jsonl=train_jsonl,
        validation_jsonl=validation_jsonl,
        manifest_path=manifest_path,
        output_repo_id=output_repo_id,
        model_name=model_name,
        run_name=run_name,
        max_seq_length=max_seq_length,
        max_steps=max_steps,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lora_rank=lora_rank,
        trackio_space_id=trackio_space_id,
        trackio_project=trackio_project,
        require_reward_verification=require_reward_verification,
        required_difficulties=required_difficulties,
        min_terminal_reward=min_terminal_reward,
        min_train_rows=min_train_rows,
        push_to_hub=push_to_hub,
    )
    print(f"SFT run name: {run_name}")
    print(f"Train JSONL: {train_jsonl}")
    print(f"Validation JSONL: {validation_jsonl}")
    print(f"Manifest: {manifest_path}")
    print(f"Reward verification required: {require_reward_verification}")
    if required_difficulties:
        print(f"Required curriculum difficulties: {required_difficulties}")
    print(f"Hub push enabled: {push_to_hub}")
    if detach:
        # Fire-and-forget: print the call id so the run can be tracked later.
        call = train_cybersecurity_owasp_sft.spawn(**kwargs)
        print(f"Spawned Modal SFT call: {call.object_id}")
    else:
        result = train_cybersecurity_owasp_sft.remote(**kwargs)
        print(json.dumps(result, indent=2, sort_keys=True))