# Source: Cyber_analyst-round1 / scripts/modal_train_sft.py
# Upstream commit e5fe6f5 (Humanlearning): "feat: enhance SFT training process
# with new tokenization method, implement custom trainer class for loss
# computation, and update README with GRPO launcher details for Unsloth LoRA
# integration"
"""Modal SFT launcher for CyberSecurity_OWASP action JSON data.
This trains a LoRA adapter on chat JSONL generated by
``scripts/generate_sft_dataset.py``. It intentionally mirrors the repo's Modal
training pattern: local execution only launches remote jobs, while training runs
inside Modal and saves adapters to the persistent run volume.
"""
from __future__ import annotations
import json
import os
import pathlib
import subprocess
from datetime import datetime, timezone
from typing import Any
import modal
# Modal resource identifiers.
APP_NAME = "CyberSecurity_OWASP-sft"
VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"  # shared with the GRPO launcher
CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
SECRET_NAME = "CyberSecurity_OWASP-secrets"  # must contain HF_TOKEN
# Mount points inside the Modal container.
RUNS_DIR = pathlib.Path("/runs")
CACHE_DIR = pathlib.Path("/cache")
HF_HOME_DIR = CACHE_DIR / "huggingface"
HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
TORCH_HOME_DIR = CACHE_DIR / "torch"
XDG_CACHE_DIR = CACHE_DIR / "xdg"
UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
TRITON_CACHE_DIR = CACHE_DIR / "triton"
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
# Repo root on the local machine (this file lives in scripts/).
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
# Pinned base model; _ensure_gemma4_model() rejects anything else.
# NOTE(review): "gemma-4-E2B-it" resembles the Gemma 3n E2B naming — confirm
# this repo id actually exists on the Hub.
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
# GPU preference order; Modal uses the first type with available capacity.
SFT_GPU_FALLBACK = ["H200", "H100", "A100-80GB", "L40S"]
DEFAULT_CURRICULUM_LEVELS = "0,1,2,3"
DEFAULT_TOTAL_TRAIN_EPISODES = 300
DEFAULT_EPISODES_PER_LEVEL = 75
DEFAULT_TRACKIO_SPACE_ID = "Humanlearning/CyberSecurity_OWASP-trackio"
DEFAULT_TRACKIO_PROJECT = "CyberSecurity_OWASP-sft"
DEFAULT_SFT_OUTPUT_REPO_ID = (
    "Humanlearning/CyberSecurity_OWASP-unsloth-gemma-4-e2b-it-sft-lora"
)
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
PUBLIC_REPO_BRANCH = "master"
def _ensure_gemma4_model(model_name: str) -> str:
    """Validate that the requested model is the pinned Gemma checkpoint.

    Returns the name unchanged when it matches DEFAULT_GEMMA_MODEL; raises
    ValueError for any other model id.
    """
    if model_name == DEFAULT_GEMMA_MODEL:
        return model_name
    raise ValueError(
        "CyberSecurity_OWASP SFT is pinned to "
        f"{DEFAULT_GEMMA_MODEL}; received {model_name!r}."
    )
def _model_repo_slug(model_name: str) -> str:
return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
def _parse_int_csv(value: str) -> set[int]:
if not value.strip():
return set()
return {int(item.strip()) for item in value.split(",") if item.strip()}
# Tool names an assistant action may carry; any other tool_name fails row
# verification in _verify_sft_rows().
SFT_ALLOWED_TOOLS = {
    "inspect_policy_graph",
    "list_routes",
    "read_openapi",
    "read_file",
    "search_code",
    "send_local_request",
    "compare_identities",
    "submit_diagnosis",
    "patch_file",
    "run_visible_tests",
    "submit_fix",
    "noop",
}
def _read_jsonl(path: pathlib.Path) -> list[dict[str, Any]]:
if not path.exists():
return []
rows: list[dict[str, Any]] = []
for line_number, line in enumerate(path.read_text(encoding="utf-8").splitlines(), start=1):
if not line.strip():
continue
try:
row = json.loads(line)
except json.JSONDecodeError as exc:
raise ValueError(f"{path}:{line_number}: invalid JSONL: {exc}") from exc
if not isinstance(row, dict):
raise ValueError(f"{path}:{line_number}: row must be a JSON object")
rows.append(row)
return rows
def _verify_sft_rows(
    path: pathlib.Path,
    *,
    min_terminal_reward: float,
) -> tuple[list[str], list[float], int, set[int]]:
    """Validate one SFT JSONL file row by row.

    A valid row is a chat sample whose final message is an assistant turn
    containing a JSON tool action from SFT_ALLOWED_TOOLS, with metadata that
    proves a clean successful episode (final_success true, no anti-cheat
    flags) and a terminal reward meeting the floor with every reward
    component positive.

    Returns (failure messages, terminal rewards seen, total row count,
    difficulty levels seen). A missing file yields ([], [], 0, set()).
    """
    rows = _read_jsonl(path)
    failures: list[str] = []
    rewards: list[float] = []
    difficulties: set[int] = set()
    for index, row in enumerate(rows, start=1):
        messages = row.get("messages")
        # Expect at least system + user + assistant turns.
        if not isinstance(messages, list) or len(messages) < 3:
            failures.append(f"{path}:{index}: messages must include system/user/assistant")
            continue
        assistant = messages[-1]
        if assistant.get("role") != "assistant":
            failures.append(f"{path}:{index}: final message must be assistant")
            continue
        # The assistant turn must parse as a JSON tool-call action.
        try:
            action = json.loads(str(assistant.get("content", "")))
        except json.JSONDecodeError as exc:
            failures.append(f"{path}:{index}: assistant content is not JSON: {exc}")
            continue
        if not isinstance(action, dict) or action.get("tool_name") not in SFT_ALLOWED_TOOLS:
            failures.append(f"{path}:{index}: assistant content is not a valid tool action")
            continue
        metadata = row.get("metadata")
        if not isinstance(metadata, dict):
            failures.append(f"{path}:{index}: missing metadata")
            continue
        if metadata.get("final_success") is not True:
            failures.append(f"{path}:{index}: final_success is not true")
            continue
        # NOTE(review): the `or []` is a no-op — the condition reduces to
        # "anti_cheat_flags is truthy".
        if metadata.get("anti_cheat_flags") or []:
            failures.append(f"{path}:{index}: anti-cheat flags present")
            continue
        if "difficulty" in metadata:
            difficulties.add(int(metadata.get("difficulty", 0)))
        # Rewards are recorded even when the row then fails the floor check,
        # so the caller's summary covers everything seen.
        terminal_reward = float(metadata.get("terminal_total", 0.0) or 0.0)
        rewards.append(terminal_reward)
        if terminal_reward < min_terminal_reward:
            failures.append(
                f"{path}:{index}: terminal_total {terminal_reward:.3f} below {min_terminal_reward:.3f}"
            )
            continue
        # `or {}` coerces a null/absent breakdown to an empty dict; a truthy
        # non-dict value still trips the isinstance check below.
        breakdown = metadata.get("final_reward_breakdown") or {}
        if not isinstance(breakdown, dict):
            failures.append(f"{path}:{index}: missing final_reward_breakdown")
            continue
        # Only the first non-positive component per row is reported.
        for key in ("security", "regression", "public_routes", "patch_quality", "visible_tests"):
            if float(breakdown.get(key, 0.0) or 0.0) <= 0.0:
                failures.append(f"{path}:{index}: reward component {key} is not positive")
                break
    return failures, rewards, len(rows), difficulties
def verify_sft_inputs(
    *,
    train_jsonl: str,
    validation_jsonl: str = "",
    manifest_path: str = "",
    required_difficulties: str = "",
    min_terminal_reward: float = 12.0,
    min_train_rows: int = 1,
) -> dict[str, Any]:
    """Run the reward/difficulty preflight over train (and optional
    validation) JSONL plus the dataset manifest.

    Returns a JSON-serialisable report dict; ``passed`` is True iff no
    failure was recorded anywhere. Never raises on bad data — problems are
    collected into the ``failures`` list instead.
    """
    train_path = pathlib.Path(train_jsonl)
    # Path("") is only a placeholder; every use below is guarded by the
    # truthiness of validation_jsonl / manifest_path.
    validation_path = pathlib.Path(validation_jsonl) if validation_jsonl else pathlib.Path("")
    failures, rewards, train_rows, difficulties = _verify_sft_rows(
        train_path,
        min_terminal_reward=min_terminal_reward,
    )
    validation_rows = 0
    # Validation is optional; only verified when present and non-empty.
    if validation_jsonl and validation_path.exists() and validation_path.stat().st_size > 0:
        validation_failures, validation_rewards, validation_rows, validation_difficulties = _verify_sft_rows(
            validation_path,
            min_terminal_reward=min_terminal_reward,
        )
        failures.extend(validation_failures)
        rewards.extend(validation_rewards)
        difficulties.update(validation_difficulties)
    if train_rows < min_train_rows:
        failures.append(f"{train_path}: expected at least {min_train_rows} train rows, found {train_rows}")
    manifest_verification: dict[str, Any] = {}
    manifest = pathlib.Path(manifest_path) if manifest_path else pathlib.Path("")
    if manifest_path and manifest.exists():
        try:
            manifest_data = json.loads(manifest.read_text(encoding="utf-8"))
            manifest_verification = dict(manifest_data.get("reward_verification") or {})
            manifest_difficulties = {
                int(item) for item in manifest_data.get("difficulty_levels", []) or []
            }
        except Exception as exc:
            # An unreadable manifest is a recorded failure, not a crash.
            failures.append(f"{manifest}: could not read manifest reward verification: {exc}")
            manifest_difficulties = set()
        if manifest_verification and manifest_verification.get("passed") is not True:
            failures.append(f"{manifest}: manifest reward_verification did not pass")
    else:
        manifest_difficulties = set()
    # An explicit CSV of levels overrides the manifest's difficulty list.
    required = _parse_int_csv(required_difficulties) or manifest_difficulties
    missing_difficulties = sorted(level for level in required if level not in difficulties)
    if missing_difficulties:
        failures.append(f"missing required curriculum difficulty rows: {missing_difficulties}")
    reward_summary = {
        "min": min(rewards) if rewards else 0.0,
        "max": max(rewards) if rewards else 0.0,
        "mean": (sum(rewards) / len(rewards)) if rewards else 0.0,
    }
    return {
        "passed": not failures,
        "failure_count": len(failures),
        "failures": failures[:50],  # cap so the printed report stays readable
        "train_rows": train_rows,
        "validation_rows": validation_rows,
        "difficulties": sorted(difficulties),
        "required_difficulties": sorted(required),
        "missing_difficulties": missing_difficulties,
        "min_terminal_reward": float(min_terminal_reward),
        "reward_summary": reward_summary,
        "manifest_reward_verification": manifest_verification,
    }
def _configure_modal_cache_env() -> dict[str, str]:
    """Point HF/torch/unsloth/triton caches at the persistent cache volume.

    Exports the cache environment variables, creates each backing directory,
    and returns the mapping that was applied.
    """
    compile_cache = UNSLOTH_CACHE_DIR / "compile"
    values = {
        "HF_HOME": str(HF_HOME_DIR),
        "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
        "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
        "TORCH_HOME": str(TORCH_HOME_DIR),
        "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
        "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
        "UNSLOTH_COMPILE_CACHE": str(compile_cache),
        "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
    }
    os.environ.update(values)
    for directory in (
        CACHE_DIR,
        HF_HOME_DIR,
        HF_HUB_CACHE_DIR,
        TORCH_HOME_DIR,
        XDG_CACHE_DIR,
        UNSLOTH_CACHE_DIR,
        compile_cache,
        TRITON_CACHE_DIR,
    ):
        directory.mkdir(parents=True, exist_ok=True)
    return values
def _cli_arg_value(name: str, default: str = "") -> str:
import sys
args = sys.argv[1:]
flag = f"--{name}"
for index, arg in enumerate(args):
if arg == flag and index + 1 < len(args):
return args[index + 1]
if arg.startswith(f"{flag}="):
return arg.split("=", 1)[1]
return default
def _source_mode() -> str:
    """Resolve the source mode: --source-mode flag, then MODAL_SOURCE_MODE env, then 'local'."""
    fallback = os.environ.get("MODAL_SOURCE_MODE", "local")
    return _cli_arg_value("source-mode", fallback)
def _training_image() -> modal.Image:
    """Build the CUDA training image (torch/TRL/unsloth stack) and bake the
    project source into it, from either the public repo or the local tree.

    Runs at import time on the client (module-level ``training_image = ...``),
    so source selection reads sys.argv/env directly rather than entrypoint
    parameters.
    """
    image = (
        modal.Image.from_registry(
            "nvidia/cuda:12.8.0-devel-ubuntu22.04",
            add_python="3.11",
        )
        .apt_install("git", "build-essential", "curl")
        .uv_pip_install(
            "torch==2.10.0",
            "triton>=3.4.0",
            "torchvision==0.25.0",
            "bitsandbytes",
            "accelerate",
            "datasets",
            "huggingface_hub",
            "peft",
            "tokenizers",
            "trackio>=0.25.0",
            "transformers>=5.5.0",
            "trl>=0.28.0",
        )
        # Unsloth comes from git HEAD; installed after the pinned stack above.
        .uv_pip_install(
            "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
            "unsloth[base] @ git+https://github.com/unslothai/unsloth",
        )
        .uv_pip_install("timm", extra_options="--no-deps")
        .uv_pip_install("pydantic==2.10.6")
    )
    if _source_mode() == "public":
        # Public mode: clone a branch shallowly and install it editable.
        repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL)
        repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
        image = image.run_commands(
            f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
            f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
        )
    else:
        # Local mode: copy the working tree into the image, minus caches and
        # secrets-bearing files.
        image = image.add_local_dir(
            PROJECT_ROOT,
            remote_path=REMOTE_PROJECT,
            copy=True,
            ignore=[
                ".git",
                ".venv",
                ".env",
                ".env.*",
                "__pycache__",
                ".pytest_cache",
                "outputs",
                "*.pyc",
            ],
        )
        image = image.run_commands(f"python -m pip install --no-deps -e {REMOTE_PROJECT}")
    return image.workdir(REMOTE_PROJECT)
# Module-level Modal objects, constructed at import time on both the client
# and inside the container. Volumes are created on first use if missing.
app = modal.App(APP_NAME)
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
training_image = _training_image()
secrets = [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])]
@app.function(
    image=modal.Image.debian_slim(python_version="3.11"),  # light image: uploads need no GPU stack
    timeout=60 * 20,
    volumes={RUNS_DIR: volume},
)
def upload_sft_jsonl(relative_path: str, content: str) -> str:
    """Write ``content`` to /runs/<relative_path> on the run volume.

    Creates parent directories as needed, commits the volume so other
    functions see the file, and returns the absolute remote path.
    """
    target = RUNS_DIR / relative_path
    target.parent.mkdir(parents=True, exist_ok=True)
    target.write_text(content, encoding="utf-8")
    volume.commit()  # persist before the training job reads it
    return str(target)
@app.function(
    image=training_image,
    gpu=SFT_GPU_FALLBACK,  # first available GPU type in the list wins
    timeout=12 * 60 * 60,  # 12 hours
    volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
    secrets=secrets,
)
def train_cybersecurity_owasp_sft(
    train_jsonl: str = "/runs/sft/train.jsonl",
    validation_jsonl: str = "/runs/sft/validation.jsonl",
    manifest_path: str = "/runs/sft/manifest.json",
    output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
    model_name: str = DEFAULT_GEMMA_MODEL,
    run_name: str = "",
    max_seq_length: int = 4096,
    max_steps: int = -1,
    num_train_epochs: float = 1.0,
    per_device_train_batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 2e-5,
    lora_rank: int = 32,
    trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
    trackio_project: str = DEFAULT_TRACKIO_PROJECT,
    require_reward_verification: bool = True,
    required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
    min_terminal_reward: float = 12.0,
    min_train_rows: int = 1,
    push_to_hub: bool = False,
) -> dict[str, Any]:
    """Train a LoRA adapter on the SFT chat JSONL inside a Modal GPU container.

    Verifies rewards (refusing to train on failure when
    ``require_reward_verification``), tokenizes the chat data, runs TRL's
    SFTTrainer over an Unsloth-loaded Gemma model, saves the adapter to the
    run volume, and optionally pushes it to the Hub.

    Returns a summary dict (run name, paths, preflight report).
    Raises RuntimeError when HF_TOKEN is missing or the preflight fails.
    """
    # Heavy dependencies exist only inside the training image, so import here.
    import inspect
    from datasets import Dataset, load_dataset
    from huggingface_hub import snapshot_download
    from transformers import Trainer
    from trl import SFTConfig, SFTTrainer
    try:
        from trl.chat_template_utils import add_response_schema
    except ImportError:
        # Older TRL versions lack this helper — fall back to a no-op.
        def add_response_schema(tokenizer):
            return tokenizer
    from unsloth import FastVisionModel
    model_name = _ensure_gemma4_model(model_name)
    cache_env = _configure_modal_cache_env()
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
    output_repo_id = output_repo_id or DEFAULT_SFT_OUTPUT_REPO_ID
    # Trackio reads these from the environment.
    os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
    os.environ["TRACKIO_PROJECT"] = trackio_project
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
    output_dir = RUNS_DIR / run_name
    adapter_dir = output_dir / "sft_adapter"
    output_dir.mkdir(parents=True, exist_ok=True)
    data_files = {"train": train_jsonl}
    validation_path = pathlib.Path(validation_jsonl)
    has_validation = validation_path.exists() and validation_path.stat().st_size > 0
    if has_validation:
        data_files["validation"] = validation_jsonl
    # Reward preflight runs again remotely even though the local entrypoint
    # already checked — the remote copy is the authoritative gate.
    reward_preflight = verify_sft_inputs(
        train_jsonl=train_jsonl,
        validation_jsonl=validation_jsonl if has_validation else "",
        manifest_path=manifest_path,
        required_difficulties=required_difficulties,
        min_terminal_reward=min_terminal_reward,
        min_train_rows=min_train_rows,
    )
    print(f"SFT reward preflight: {json.dumps(reward_preflight, sort_keys=True)}")
    if require_reward_verification and not reward_preflight["passed"]:
        raise RuntimeError(
            "SFT reward verification failed; refusing to start model training. "
            f"Failures: {reward_preflight['failures']}"
        )
    dataset = load_dataset("json", data_files=data_files)
    print(f"SFT run name: {run_name}")
    print(f"Model: {model_name}")
    print(f"Train JSONL: {train_jsonl}")
    print(f"Validation JSONL: {validation_jsonl if has_validation else '(none)'}")
    print(f"Output adapter dir: {adapter_dir}")
    print(f"Output repo: https://huggingface.co/{output_repo_id}")
    print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
    print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
    print(
        "SFT target: "
        f"{DEFAULT_TOTAL_TRAIN_EPISODES} total train episodes, "
        f"{DEFAULT_EPISODES_PER_LEVEL} per level across {DEFAULT_CURRICULUM_LEVELS}"
    )
    # Best-effort prefetch so the weights land in the persistent cache volume;
    # the loader below retries directly if this fails.
    try:
        snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
        cache_volume.commit()
    except Exception as exc:
        print(f"Model snapshot prefetch skipped; loader will retry directly. Error: {exc!r}")
    # NOTE(review): FastVisionModel for a text-only SFT run — presumably
    # required for the multimodal Gemma checkpoint; confirm against unsloth.
    model_api = FastVisionModel
    model, tokenizer = model_api.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        fast_inference=False,
        cache_dir=str(HF_HUB_CACHE_DIR),
        token=hf_token,
    )
    try:
        tokenizer = add_response_schema(tokenizer)
    except Exception as exc:
        print(f"Tokenizer response schema add skipped: {exc!r}")
    def _tokenize_sft_split(split_name: str, split_dataset) -> Dataset:
        """Render each chat sample through the chat template and tokenize it.

        Labels are the full input_ids (no assistant-only masking), matching
        the ``assistant_only_loss: False`` setting below.
        """
        tokenized_rows: list[dict[str, list[int]]] = []
        total_rows = len(split_dataset)
        for row_index, example in enumerate(split_dataset, start=1):
            messages = example["messages"]
            # Some export formats store messages as a JSON string.
            if isinstance(messages, str):
                messages = json.loads(messages)
            rendered = tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=False,
            )
            # Some tokenizer wrappers require text= as a keyword argument.
            try:
                encoded = tokenizer(
                    rendered,
                    add_special_tokens=False,
                    truncation=True,
                    max_length=max_seq_length,
                )
            except TypeError:
                encoded = tokenizer(
                    text=rendered,
                    add_special_tokens=False,
                    truncation=True,
                    max_length=max_seq_length,
                )
            input_ids = encoded["input_ids"]
            # Unwrap batched output ([[ids]] -> [ids]) if present.
            if input_ids and isinstance(input_ids[0], list):
                input_ids = input_ids[0]
            input_ids = [int(token_id) for token_id in input_ids[:max_seq_length]]
            if not input_ids:
                raise RuntimeError(f"{split_name} row {row_index} produced no tokens.")
            tokenized_rows.append({"input_ids": input_ids, "labels": list(input_ids)})
            if row_index % 500 == 0 or row_index == total_rows:
                print(f"Tokenized {split_name} rows: {row_index}/{total_rows}")
        return Dataset.from_list(tokenized_rows)
    dataset["train"] = _tokenize_sft_split("train", dataset["train"])
    if has_validation:
        dataset["validation"] = _tokenize_sft_split("validation", dataset["validation"])
    # Attach LoRA adapters to the attention + MLP projections.
    model = model_api.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_rank * 2,  # common alpha = 2r heuristic
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    if hasattr(model_api, "for_training"):
        model_api.for_training(model)
    # Candidate SFTConfig values; filtered against the installed TRL version's
    # signature below so older/newer releases don't crash on unknown keys.
    sft_values = {
        "output_dir": str(output_dir),
        "max_seq_length": max_seq_length,
        "max_steps": max_steps,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "learning_rate": learning_rate,
        "optim": "adamw_8bit",
        "dataset_num_proc": None,
        "logging_steps": 1,
        "logging_first_step": True,
        "save_steps": max(10, max_steps) if max_steps > 0 else 100,
        "report_to": "trackio",
        "project": trackio_project,
        "trackio_space_id": trackio_space_id,
        "run_name": run_name,
        "assistant_only_loss": False,
        "packing": False,
        "bf16": True,
        "tf32": True,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "push_to_hub": push_to_hub,
        "hub_model_id": output_repo_id,
        "hub_private_repo": True,
        "hub_strategy": "every_save",
    }
    sft_parameters = set(inspect.signature(SFTConfig).parameters)
    skipped = sorted(set(sft_values) - sft_parameters)
    if skipped:
        print(f"Skipping unsupported SFTConfig keys: {skipped}")
    training_args = SFTConfig(
        **{key: value for key, value in sft_values.items() if key in sft_parameters}
    )
    trainer_values = {
        "model": model,
        "processing_class": tokenizer,
        "args": training_args,
        "train_dataset": dataset["train"],
        "eval_dataset": dataset["validation"] if has_validation else None,
    }
    trainer_parameters = set(inspect.signature(SFTTrainer).parameters)
    skipped_trainer = sorted(
        key for key, value in trainer_values.items() if key not in trainer_parameters and value is not None
    )
    if skipped_trainer:
        print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
    class CyberSecurityOWASPSFTTrainer(SFTTrainer):
        # Bridges the compute_loss signature change across transformers
        # versions (num_items_in_batch was added later) by delegating to the
        # base Trainer with only the kwargs it accepts.
        def compute_loss(
            self,
            model,
            inputs,
            return_outputs: bool = False,
            num_items_in_batch=None,
        ):
            compute_loss_kwargs = {"return_outputs": return_outputs}
            if "num_items_in_batch" in inspect.signature(Trainer.compute_loss).parameters:
                compute_loss_kwargs["num_items_in_batch"] = num_items_in_batch
            return Trainer.compute_loss(self, model, inputs, **compute_loss_kwargs)
    trainer = CyberSecurityOWASPSFTTrainer(
        **{
            key: value
            for key, value in trainer_values.items()
            if value is not None and key in trainer_parameters
        }
    )
    trainer.train()
    trainer.save_model(str(adapter_dir))
    if push_to_hub:
        trainer.push_to_hub()
    # Persist the adapter and any new cache entries.
    volume.commit()
    cache_volume.commit()
    return {
        "run_name": run_name,
        "model_name": model_name,
        "adapter_dir": str(adapter_dir),
        "output_repo_id": output_repo_id,
        "train_jsonl": train_jsonl,
        "validation_jsonl": validation_jsonl if has_validation else "",
        "manifest_path": manifest_path,
        "reward_preflight": reward_preflight,
        "required_difficulties": required_difficulties,
        "default_total_train_episodes": DEFAULT_TOTAL_TRAIN_EPISODES,
        "default_episodes_per_level": DEFAULT_EPISODES_PER_LEVEL,
        "max_steps": max_steps,
        "push_to_hub": push_to_hub,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
    }
def _git_sha(default: str = "nogit") -> str:
    """Return the repo's HEAD commit sha, or ``default`` when git is
    unavailable or the tree is not a repository."""
    command = [
        "git",
        "-c",
        f"safe.directory={PROJECT_ROOT.as_posix()}",
        "rev-parse",
        "HEAD",
    ]
    try:
        output = subprocess.check_output(
            command,
            cwd=PROJECT_ROOT,
            text=True,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        return default
    return output.strip()
@app.local_entrypoint()
def main(
    mode: str = "train",
    local_train_path: str = "outputs/sft/train.jsonl",
    local_validation_path: str = "outputs/sft/validation.jsonl",
    local_manifest_path: str = "outputs/sft/manifest.json",
    train_jsonl: str = "/runs/sft/train.jsonl",
    validation_jsonl: str = "/runs/sft/validation.jsonl",
    manifest_path: str = "/runs/sft/manifest.json",
    output_repo_id: str = DEFAULT_SFT_OUTPUT_REPO_ID,
    model_name: str = DEFAULT_GEMMA_MODEL,
    run_name: str = "",
    max_seq_length: int = 4096,
    max_steps: int = -1,
    num_train_epochs: float = 1.0,
    per_device_train_batch_size: int = 4,
    gradient_accumulation_steps: int = 4,
    learning_rate: float = 2e-5,
    lora_rank: int = 32,
    trackio_space_id: str = DEFAULT_TRACKIO_SPACE_ID,
    trackio_project: str = DEFAULT_TRACKIO_PROJECT,
    require_reward_verification: bool = True,
    required_difficulties: str = DEFAULT_CURRICULUM_LEVELS,
    min_terminal_reward: float = 12.0,
    min_train_rows: int = 1,
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    detach: bool = False,
    push_to_hub: bool = False,
) -> None:
    """Upload local SFT artifacts (when present) and launch the remote job.

    mode='upload' stops after uploading; mode='train' also starts training,
    detached when --detach is set. A local reward preflight runs first so
    bad data fails before any upload or GPU spend.
    """
    # source_mode/repo_url/repo_branch are read by _training_image() from
    # sys.argv at import time; they are declared here only so `modal run`
    # accepts the flags.
    del source_mode, repo_url, repo_branch  # consumed during image construction
    model_name = _ensure_gemma4_model(model_name)
    if mode not in {"upload", "train"}:
        raise ValueError("mode must be 'upload' or 'train'")
    local_train = pathlib.Path(local_train_path)
    local_validation = pathlib.Path(local_validation_path)
    local_manifest = pathlib.Path(local_manifest_path)
    # Fail fast locally before paying for an upload or a container.
    if require_reward_verification and local_train.exists():
        local_reward_preflight = verify_sft_inputs(
            train_jsonl=str(local_train),
            validation_jsonl=str(local_validation) if local_validation.exists() else "",
            manifest_path=str(local_manifest) if local_manifest.exists() else "",
            required_difficulties=required_difficulties,
            min_terminal_reward=min_terminal_reward,
            min_train_rows=min_train_rows,
        )
        print(f"Local SFT reward preflight: {json.dumps(local_reward_preflight, sort_keys=True)}")
        if not local_reward_preflight["passed"]:
            raise RuntimeError(
                "Local SFT reward verification failed; refusing to upload/train. "
                f"Failures: {local_reward_preflight['failures']}"
            )
    # Upload whichever local artifacts exist; the returned remote paths then
    # replace the defaults passed to the training function.
    if local_train.exists():
        uploaded = upload_sft_jsonl.remote(
            "sft/train.jsonl",
            local_train.read_text(encoding="utf-8"),
        )
        print(f"Uploaded train JSONL: {uploaded}")
        train_jsonl = uploaded
    if local_validation.exists():
        uploaded_validation = upload_sft_jsonl.remote(
            "sft/validation.jsonl",
            local_validation.read_text(encoding="utf-8"),
        )
        print(f"Uploaded validation JSONL: {uploaded_validation}")
        validation_jsonl = uploaded_validation
    if local_manifest.exists():
        uploaded_manifest = upload_sft_jsonl.remote(
            "sft/manifest.json",
            local_manifest.read_text(encoding="utf-8"),
        )
        print(f"Uploaded manifest: {uploaded_manifest}")
        manifest_path = uploaded_manifest
    if mode == "upload":
        return
    if not run_name:
        stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
        run_name = f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}-{_git_sha()[:8]}"
    kwargs = dict(
        train_jsonl=train_jsonl,
        validation_jsonl=validation_jsonl,
        manifest_path=manifest_path,
        output_repo_id=output_repo_id,
        model_name=model_name,
        run_name=run_name,
        max_seq_length=max_seq_length,
        max_steps=max_steps,
        num_train_epochs=num_train_epochs,
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        lora_rank=lora_rank,
        trackio_space_id=trackio_space_id,
        trackio_project=trackio_project,
        require_reward_verification=require_reward_verification,
        required_difficulties=required_difficulties,
        min_terminal_reward=min_terminal_reward,
        min_train_rows=min_train_rows,
        push_to_hub=push_to_hub,
    )
    print(f"SFT run name: {run_name}")
    print(f"Train JSONL: {train_jsonl}")
    print(f"Validation JSONL: {validation_jsonl}")
    print(f"Manifest: {manifest_path}")
    print(f"Reward verification required: {require_reward_verification}")
    if required_difficulties:
        print(f"Required curriculum difficulties: {required_difficulties}")
    print(f"Hub push enabled: {push_to_hub}")
    if detach:
        # Fire-and-forget: the job keeps running after this process exits.
        call = train_cybersecurity_owasp_sft.spawn(**kwargs)
        print(f"Spawned Modal SFT call: {call.object_id}")
    else:
        result = train_cybersecurity_owasp_sft.remote(**kwargs)
        print(json.dumps(result, indent=2, sort_keys=True))