Spaces:
Sleeping
Sleeping
File size: 3,172 Bytes
b3ee507 3807ea3 b3ee507 3807ea3 4e663d8 3807ea3 b3ee507 be8eade 3807ea3 b3ee507 3807ea3 be8eade 4e663d8 b3ee507 4e663d8 3807ea3 4e663d8 3807ea3 be8eade 3807ea3 b3ee507 be8eade b3ee507 3807ea3 b3ee507 3807ea3 b3ee507 3807ea3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 | """Modal-only GRPO config helper for CyberSecurity_OWASP.
This module intentionally does not run local training.
Use `scripts/modal_train_grpo.py` (persistent) or
`scripts/modal_ephemeral_train.py` (smoke) for execution.
"""
from __future__ import annotations
import os
from training.trackio_utils import build_run_name, get_git_sha
DEFAULT_GEMMA_MODEL = os.getenv("MODEL_NAME", "unsloth/gemma-4-E2B-it")
def ensure_gemma4_model(model_name: str) -> str:
if model_name != "unsloth/gemma-4-E2B-it":
raise ValueError(
"CyberSecurity_OWASP GRPO is pinned to unsloth/gemma-4-E2B-it, "
"matching the Unsloth Gemma 4 E2B RL notebook."
)
return model_name
def build_grpo_config():
"""Build the TRL GRPOConfig used by the Modal training pipeline."""
from trl import GRPOConfig
model_name = ensure_gemma4_model(os.getenv("MODEL_NAME", DEFAULT_GEMMA_MODEL))
difficulty = int(os.getenv("DIFFICULTY", "0"))
output_dir = os.getenv(
"OUTPUT_DIR",
f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo",
)
trackio_space_id = os.getenv("TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio")
os.environ.setdefault("TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo")
run_name = os.getenv(
"RUN_NAME",
build_run_name(model_name, "grpo", difficulty, git_sha=get_git_sha()),
)
return GRPOConfig(
output_dir=output_dir,
report_to="trackio",
trackio_space_id=trackio_space_id,
run_name=run_name,
logging_steps=1,
save_steps=25,
learning_rate=5e-6,
num_train_epochs=1,
per_device_train_batch_size=1,
gradient_accumulation_steps=32,
num_generations=6,
max_prompt_length=4096,
max_completion_length=768,
use_vllm=True,
vllm_mode="colocate",
vllm_gpu_memory_utilization=0.2,
gradient_checkpointing=True,
gradient_checkpointing_kwargs={"use_reentrant": False},
push_to_hub=False,
)
def main() -> None:
import argparse
parser = argparse.ArgumentParser(
description=(
"CyberSecurity_OWASP GRPO config helper."
" Actual GRPO training is executed on Modal only."
)
)
parser.add_argument(
"--difficulty",
type=int,
default=0,
help="Optional curriculum difficulty included in the generated run name.",
)
parser.add_argument("--model-name", default=DEFAULT_GEMMA_MODEL)
parser.add_argument(
"--output-dir",
default=None,
help="Optional GRPO output_dir override.",
)
args = parser.parse_args()
os.environ["MODEL_NAME"] = ensure_gemma4_model(args.model_name)
if args.output_dir:
os.environ["OUTPUT_DIR"] = args.output_dir
config = build_grpo_config()
print("GRPO config (Modal execution):")
print(config)
print(
"Run on Modal, for example:\n"
"uv run --extra modal modal run scripts/modal_train_grpo.py "
f"--model-name {args.model_name} --difficulty {args.difficulty}"
)
if __name__ == "__main__":
main()
|