Cyber_analyst-round1 / training /train_grpo.py
Humanlearning's picture
feat: enhance scenario authoring and caching mechanisms, update action submission terminology, and improve reward configuration for CyberSecurity_OWASP environment
be8eade
"""Modal-only GRPO config helper for CyberSecurity_OWASP.
This module intentionally does not run local training.
Use `scripts/modal_train_grpo.py` (persistent) or
`scripts/modal_ephemeral_train.py` (smoke) for execution.
"""
from __future__ import annotations
import os
from training.trackio_utils import build_run_name, get_git_sha
# Model identifier used when no override is supplied. The MODEL_NAME
# environment variable is read once, at import time, and takes precedence
# over the built-in default.
DEFAULT_GEMMA_MODEL = os.environ.get("MODEL_NAME", "unsloth/gemma-4-E2B-it")
def ensure_gemma4_model(model_name: str) -> str:
    """Check that *model_name* is the one model this pipeline is pinned to.

    Args:
        model_name: Model identifier to validate.

    Returns:
        ``model_name`` unchanged when it equals ``unsloth/gemma-4-E2B-it``.

    Raises:
        ValueError: If any other model identifier is given.
    """
    pinned_model = "unsloth/gemma-4-E2B-it"
    if model_name == pinned_model:
        return model_name
    raise ValueError(
        "CyberSecurity_OWASP GRPO is pinned to unsloth/gemma-4-E2B-it, "
        "matching the Unsloth Gemma 4 E2B RL notebook."
    )
def build_grpo_config():
    """Assemble the TRL ``GRPOConfig`` for the Modal training pipeline.

    Overrides are taken from the environment: ``MODEL_NAME``, ``DIFFICULTY``,
    ``OUTPUT_DIR``, ``TRACKIO_SPACE_ID`` and ``RUN_NAME``. ``TRACKIO_PROJECT``
    is set to ``CyberSecurity_OWASP-grpo`` when it is not already defined.

    Returns:
        A ``trl.GRPOConfig`` populated with the fixed hyperparameters below
        and the environment-derived output directory, run name and Trackio
        space id.

    Raises:
        ValueError: If ``MODEL_NAME`` is not the pinned model, or if
            ``DIFFICULTY`` cannot be parsed as an integer.
    """
    # Function-scope import: TRL is only needed when a config is built.
    from trl import GRPOConfig

    env = os.environ

    model_name = ensure_gemma4_model(env.get("MODEL_NAME", DEFAULT_GEMMA_MODEL))
    difficulty = int(env.get("DIFFICULTY", "0"))

    default_output_dir = (
        "CyberSecurity_OWASP-" + model_name.replace("/", "-") + "-grpo"
    )
    output_dir = env.get("OUTPUT_DIR", default_output_dir)

    trackio_space_id = env.get(
        "TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio"
    )
    if "TRACKIO_PROJECT" not in env:
        env["TRACKIO_PROJECT"] = "CyberSecurity_OWASP-grpo"

    # The generated name is computed up front, even when RUN_NAME overrides it.
    generated_run_name = build_run_name(
        model_name, "grpo", difficulty, git_sha=get_git_sha()
    )
    run_name = env.get("RUN_NAME", generated_run_name)

    settings = {
        # Output and experiment tracking.
        "output_dir": output_dir,
        "report_to": "trackio",
        "trackio_space_id": trackio_space_id,
        "run_name": run_name,
        "logging_steps": 1,
        "save_steps": 25,
        # Optimisation schedule.
        "learning_rate": 5e-6,
        "num_train_epochs": 1,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": 32,
        # Generation settings.
        "num_generations": 6,
        "max_prompt_length": 4096,
        "max_completion_length": 768,
        # vLLM settings.
        "use_vllm": True,
        "vllm_mode": "colocate",
        "vllm_gpu_memory_utilization": 0.2,
        # Gradient checkpointing.
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "push_to_hub": False,
    }
    return GRPOConfig(**settings)
def main() -> None:
    """Print the GRPO config and an example Modal command; never trains.

    Command-line flags are forwarded to ``build_grpo_config()`` through the
    environment variables it reads (``MODEL_NAME``, ``DIFFICULTY``,
    ``OUTPUT_DIR``), so the printed config reflects the flags given.

    Raises:
        ValueError: If ``--model-name`` is not the pinned model, or if no
            ``--difficulty`` is given and ``DIFFICULTY`` in the environment
            is not an integer.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "CyberSecurity_OWASP GRPO config helper."
            " Actual GRPO training is executed on Modal only."
        )
    )
    parser.add_argument(
        "--difficulty",
        type=int,
        # None means "flag not given": an existing DIFFICULTY environment
        # variable is then left in charge, falling back to 0.
        default=None,
        help="Optional curriculum difficulty included in the generated run name.",
    )
    parser.add_argument("--model-name", default=DEFAULT_GEMMA_MODEL)
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Optional GRPO output_dir override.",
    )
    args = parser.parse_args()

    os.environ["MODEL_NAME"] = ensure_gemma4_model(args.model_name)

    # build_grpo_config() takes the difficulty from the environment only, so
    # an explicit flag has to be exported or the run name would ignore it.
    if args.difficulty is not None:
        os.environ["DIFFICULTY"] = str(args.difficulty)
    # Resolve the value the config will use so the echoed command matches it.
    difficulty = int(os.getenv("DIFFICULTY", "0"))

    if args.output_dir:
        os.environ["OUTPUT_DIR"] = args.output_dir

    config = build_grpo_config()
    print("GRPO config (Modal execution):")
    print(config)
    print(
        "Run on Modal, for example:\n"
        "uv run --extra modal modal run scripts/modal_train_grpo.py "
        f"--model-name {args.model_name} --difficulty {difficulty}"
    )
# Run the config helper only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()