Spaces:
Sleeping
Sleeping
feat: enhance scenario authoring and caching mechanisms, update action submission terminology, and improve reward configuration for CyberSecurity_OWASP environment
"""Modal-only GRPO config helper for CyberSecurity_OWASP.

This module intentionally does not run training locally.
For execution, use `scripts/modal_train_grpo.py` (persistent) or
`scripts/modal_ephemeral_train.py` (smoke).
"""
| from __future__ import annotations | |
| import os | |
| from training.trackio_utils import build_run_name, get_git_sha | |
# Default model id, resolved once at import time: the MODEL_NAME environment
# variable wins when it is set, otherwise "unsloth/gemma-4-E2B-it" is used.
DEFAULT_GEMMA_MODEL = os.environ.get("MODEL_NAME", "unsloth/gemma-4-E2B-it")
def ensure_gemma4_model(model_name: str) -> str:
    """Return ``model_name`` unchanged if it is the pinned Gemma 4 E2B model.

    Args:
        model_name: Model id requested for GRPO training.

    Returns:
        The same ``model_name`` string, so the call can be used inline.

    Raises:
        ValueError: If ``model_name`` is anything other than
            ``unsloth/gemma-4-E2B-it``. The message includes the rejected
            value so a bad environment variable or CLI flag is easy to spot.
    """
    pinned = "unsloth/gemma-4-E2B-it"
    if model_name != pinned:
        raise ValueError(
            "CyberSecurity_OWASP GRPO is pinned to unsloth/gemma-4-E2B-it, "
            "matching the Unsloth Gemma 4 E2B RL notebook. "
            f"Got {model_name!r}."
        )
    return model_name
def build_grpo_config():
    """Build the TRL GRPOConfig used by the Modal training pipeline.

    Settings are taken from environment variables:

    * ``MODEL_NAME``: validated with ``ensure_gemma4_model``; falls back to
      ``DEFAULT_GEMMA_MODEL``.
    * ``DIFFICULTY``: integer (default ``0``) passed to ``build_run_name``.
    * ``OUTPUT_DIR``: defaults to
      ``CyberSecurity_OWASP-<model name with '/' replaced by '-'>-grpo``.
    * ``TRACKIO_SPACE_ID``: defaults to
      ``Humanlearning/CyberSecurity_OWASP-trackio``.
    * ``RUN_NAME``: when set it is used verbatim; otherwise a name is
      generated with ``build_run_name`` and ``get_git_sha``.

    Side effect: sets ``TRACKIO_PROJECT`` to ``CyberSecurity_OWASP-grpo`` in
    ``os.environ`` if it is not already set.

    Returns:
        A ``trl.GRPOConfig`` instance.

    Raises:
        ValueError: If ``MODEL_NAME`` is not the pinned model or ``DIFFICULTY``
            is not an integer.
    """
    # Imported here so that importing this module does not require trl.
    from trl import GRPOConfig

    model_name = ensure_gemma4_model(os.getenv("MODEL_NAME", DEFAULT_GEMMA_MODEL))

    raw_difficulty = os.getenv("DIFFICULTY", "0")
    try:
        difficulty = int(raw_difficulty)
    except ValueError:
        # Same exception type as before, but the message names the variable.
        raise ValueError(
            f"DIFFICULTY must be an integer, got {raw_difficulty!r}"
        ) from None

    output_dir = os.getenv(
        "OUTPUT_DIR",
        f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo",
    )
    trackio_space_id = os.getenv(
        "TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio"
    )
    os.environ.setdefault("TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo")

    # Only derive a run name (which calls get_git_sha) when RUN_NAME is
    # absent; an explicit RUN_NAME, even an empty one, is used as-is.
    run_name = os.environ.get("RUN_NAME")
    if run_name is None:
        run_name = build_run_name(
            model_name, "grpo", difficulty, git_sha=get_git_sha()
        )

    return GRPOConfig(
        output_dir=output_dir,
        report_to="trackio",
        trackio_space_id=trackio_space_id,
        run_name=run_name,
        logging_steps=1,
        save_steps=25,
        learning_rate=5e-6,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=32,
        num_generations=6,
        max_prompt_length=4096,
        max_completion_length=768,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.2,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        push_to_hub=False,
    )
def main() -> None:
    """Print the GRPO config and an example Modal command; never trains locally.

    Command-line flags are exported through the environment variables that
    ``build_grpo_config`` reads:

    * ``--model-name`` -> ``MODEL_NAME`` (validated by ``ensure_gemma4_model``).
    * ``--difficulty`` -> ``DIFFICULTY``, only when the flag is given, so a
      value already present in the environment is not overwritten.
    * ``--output-dir`` -> ``OUTPUT_DIR``, only when a non-empty value is given.

    Raises:
        ValueError: If the model name is not the pinned model.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "CyberSecurity_OWASP GRPO config helper."
            " Actual GRPO training is executed on Modal only."
        )
    )
    parser.add_argument(
        "--difficulty",
        type=int,
        # None means "flag not given": fall back to DIFFICULTY from the
        # environment (or 0) instead of forcing 0.
        default=None,
        help="Optional curriculum difficulty included in the generated run name.",
    )
    parser.add_argument("--model-name", default=DEFAULT_GEMMA_MODEL)
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Optional GRPO output_dir override.",
    )
    args = parser.parse_args()

    os.environ["MODEL_NAME"] = ensure_gemma4_model(args.model_name)
    if args.difficulty is not None:
        # Without this export the generated run name ignored --difficulty.
        os.environ["DIFFICULTY"] = str(args.difficulty)
    if args.output_dir:
        os.environ["OUTPUT_DIR"] = args.output_dir

    config = build_grpo_config()
    # Echo the difficulty the config was actually built with.
    difficulty = os.getenv("DIFFICULTY", "0")

    print("GRPO config (Modal execution):")
    print(config)
    print(
        "Run on Modal, for example:\n"
        "uv run --extra modal modal run scripts/modal_train_grpo.py "
        f"--model-name {args.model_name} --difficulty {difficulty}"
    )


if __name__ == "__main__":
    main()