Spaces:
Sleeping
Sleeping
feat: introduce GRPO GPU fallback support, enhance training script with warmstart tagging, and add learning rate parameter for improved training flexibility
1b6d30b | """Persistent Modal GRPO launcher for CyberSecurity_OWASP. | |
| This packages the local repository into a Modal GPU image, runs a small | |
| tool-use GRPO job against the in-process CyberSecurity_OWASP environment, logs | |
| metrics/traces to Trackio, and saves LoRA checkpoints in a persistent Modal | |
| volume. | |
| Example: | |
| uv run --extra modal modal run scripts/modal_train_grpo.py \ | |
| --max-steps 10 \ | |
| --dataset-size 16 \ | |
| --num-generations 6 \ | |
| --difficulty 0 | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import os | |
| import pathlib | |
| import subprocess | |
| import sys | |
| from datetime import datetime, timezone | |
| from typing import Any | |
| import modal | |
| APP_NAME = "CyberSecurity_OWASP-grpo" | |
| VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs" | |
| CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache" | |
| SCENARIO_CACHE_VOLUME_NAME = "CyberSecurity_OWASP-scenario-cache" | |
| SECRET_NAME = "CyberSecurity_OWASP-secrets" | |
| RUNS_DIR = pathlib.Path("/runs") | |
| CACHE_DIR = pathlib.Path("/cache") | |
| SCENARIO_CACHE_DIR = pathlib.Path("/scenario-cache") | |
| HF_HOME_DIR = CACHE_DIR / "huggingface" | |
| HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub" | |
| TORCH_HOME_DIR = CACHE_DIR / "torch" | |
| XDG_CACHE_DIR = CACHE_DIR / "xdg" | |
| UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth" | |
| TRITON_CACHE_DIR = CACHE_DIR / "triton" | |
| REMOTE_PROJECT = "/root/CyberSecurity_OWASP" | |
| PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] | |
| PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git" | |
| PUBLIC_REPO_BRANCH = "master" | |
| DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it" | |
| GRPO_TRAINING_TIMEOUT_SECONDS = 24 * 60 * 60 | |
| GRPO_GPU_FALLBACK = ["L40S", "L4"] | |
| _IMAGE_NOTICE_PRINTED = False | |
| def _ensure_gemma4_model(model_name: str) -> str: | |
| if model_name != DEFAULT_GEMMA_MODEL: | |
| raise ValueError( | |
| "CyberSecurity_OWASP GRPO training is pinned to " | |
| f"{DEFAULT_GEMMA_MODEL}, matching the Unsloth Gemma 4 E2B RL notebook. " | |
| f"Received {model_name!r}." | |
| ) | |
| return model_name | |
| def _model_repo_slug(model_name: str) -> str: | |
| return ( | |
| model_name.replace("/", "-") | |
| .replace("_", "-") | |
| .replace(".", "-") | |
| .lower() | |
| ) | |
| def _grpo_output_repo_slug( | |
| model_name: str, | |
| *, | |
| initial_adapter_path: str = "", | |
| initial_adapter_repo_id: str = "", | |
| ) -> str: | |
| warmstart_tag = ( | |
| "-sft-warmstart" if initial_adapter_path or initial_adapter_repo_id else "" | |
| ) | |
| return ( | |
| f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}" | |
| f"{warmstart_tag}-grpo-lora" | |
| ) | |
| def _grpo_run_algo_tag( | |
| *, | |
| initial_adapter_path: str = "", | |
| initial_adapter_repo_id: str = "", | |
| ) -> str: | |
| return "sft-warmstart-grpo" if initial_adapter_path or initial_adapter_repo_id else "grpo" | |
| def _hf_model_cache_path(model_name: str) -> pathlib.Path: | |
| return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}" | |
| def _resolve_grpo_batch_config( | |
| *, | |
| per_device_train_batch_size: int, | |
| gradient_accumulation_steps: int, | |
| num_generations: int, | |
| world_size: int = 1, | |
| ) -> tuple[int, int]: | |
| if num_generations < 1: | |
| raise ValueError("--num-generations must be at least 1.") | |
| if per_device_train_batch_size < 1: | |
| raise ValueError("--per-device-train-batch-size must be at least 1.") | |
| if world_size < 1: | |
| raise ValueError("world_size must be at least 1.") | |
| resolved_gradient_accumulation_steps = ( | |
| gradient_accumulation_steps | |
| if gradient_accumulation_steps > 0 | |
| else max(2, num_generations) | |
| ) | |
| if resolved_gradient_accumulation_steps < 1: | |
| raise ValueError("--gradient-accumulation-steps must be at least 1.") | |
| effective_batch_size = ( | |
| per_device_train_batch_size | |
| * resolved_gradient_accumulation_steps | |
| * world_size | |
| ) | |
| if effective_batch_size % num_generations: | |
| raise ValueError( | |
| "Invalid GRPO batch shape: " | |
| "per_device_train_batch_size * gradient_accumulation_steps * world_size " | |
| f"must be divisible by num_generations. Got " | |
| f"{per_device_train_batch_size} * " | |
| f"{resolved_gradient_accumulation_steps} * {world_size} = " | |
| f"{effective_batch_size}, which is not divisible by {num_generations}." | |
| ) | |
| return resolved_gradient_accumulation_steps, effective_batch_size | |
| def _validate_vllm_config(*, use_vllm: bool, vllm_gpu_memory_utilization: float) -> None: | |
| if not use_vllm: | |
| return | |
| if not 0.0 < vllm_gpu_memory_utilization <= 0.95: | |
| raise ValueError( | |
| "--vllm-gpu-memory-utilization must be in the interval (0.0, 0.95] " | |
| "when --use-vllm is enabled." | |
| ) | |
| def _extract_first_json_object(text: str) -> dict[str, Any] | None: | |
| stripped = text.strip() | |
| candidates = [stripped] | |
| if "```" in stripped: | |
| for part in stripped.split("```"): | |
| part = part.strip() | |
| if part.startswith("json"): | |
| part = part[4:].strip() | |
| candidates.append(part) | |
| for candidate in candidates: | |
| try: | |
| loaded = json.loads(candidate) | |
| except Exception: | |
| continue | |
| if isinstance(loaded, dict): | |
| return loaded | |
| start = stripped.find("{") | |
| while start >= 0: | |
| depth = 0 | |
| in_string = False | |
| escaped = False | |
| for index in range(start, len(stripped)): | |
| char = stripped[index] | |
| if in_string: | |
| if escaped: | |
| escaped = False | |
| elif char == "\\": | |
| escaped = True | |
| elif char == '"': | |
| in_string = False | |
| continue | |
| if char == '"': | |
| in_string = True | |
| elif char == "{": | |
| depth += 1 | |
| elif char == "}": | |
| depth -= 1 | |
| if depth == 0: | |
| try: | |
| loaded = json.loads(stripped[start : index + 1]) | |
| except Exception: | |
| break | |
| if isinstance(loaded, dict): | |
| return loaded | |
| start = stripped.find("{", start + 1) | |
| return None | |
| def _configure_modal_cache_env() -> dict[str, str]: | |
| values = { | |
| "HF_HOME": str(HF_HOME_DIR), | |
| "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR), | |
| "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR), | |
| "TORCH_HOME": str(TORCH_HOME_DIR), | |
| "XDG_CACHE_HOME": str(XDG_CACHE_DIR), | |
| "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR), | |
| "UNSLOTH_COMPILE_CACHE": str(UNSLOTH_CACHE_DIR / "compile"), | |
| "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR), | |
| } | |
| for key, value in values.items(): | |
| os.environ[key] = value | |
| for path in { | |
| CACHE_DIR, | |
| HF_HOME_DIR, | |
| HF_HUB_CACHE_DIR, | |
| TORCH_HOME_DIR, | |
| XDG_CACHE_DIR, | |
| UNSLOTH_CACHE_DIR, | |
| UNSLOTH_CACHE_DIR / "compile", | |
| TRITON_CACHE_DIR, | |
| }: | |
| path.mkdir(parents=True, exist_ok=True) | |
| return values | |
| def _configure_scenario_cache_env(*, required: bool = True) -> dict[str, str]: | |
| values = { | |
| "CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR": str(SCENARIO_CACHE_DIR), | |
| "CYBERSECURITY_OWASP_SCENARIO_CACHE_MODE": "require" if required else "fallback", | |
| } | |
| for key, value in values.items(): | |
| os.environ[key] = value | |
| SCENARIO_CACHE_DIR.mkdir(parents=True, exist_ok=True) | |
| return values | |
| def _configure_reward_env( | |
| *, | |
| reward_config: str = "", | |
| reward_variant: str = "", | |
| reward_mode: str = "", | |
| ) -> dict[str, str]: | |
| values: dict[str, str] = {} | |
| if reward_config: | |
| values["CYBERSECURITY_OWASP_REWARD_CONFIG"] = reward_config | |
| if reward_variant: | |
| values["CYBERSECURITY_OWASP_REWARD_VARIANT"] = reward_variant | |
| if reward_mode: | |
| values["CYBERSECURITY_OWASP_REWARD_MODE"] = reward_mode | |
| for key, value in values.items(): | |
| os.environ[key] = value | |
| return values | |
| def _print_image_startup_notice() -> None: | |
| global _IMAGE_NOTICE_PRINTED | |
| if _IMAGE_NOTICE_PRINTED: | |
| return | |
| _IMAGE_NOTICE_PRINTED = True | |
| print( | |
| "Modal startup phase 1/5: building or validating the GPU training image. " | |
| "If this takes minutes, it is Modal image packaging/dependency cache work, " | |
| "not model-weight download." | |
| ) | |
| print( | |
| "Later remote phases will print: cache hit/miss, snapshot_download progress, " | |
| "Unsloth weight loading, GRPO heartbeat, Trackio upload, and volume commits." | |
| ) | |
| def _load_local_env_file() -> None: | |
| env_path = PROJECT_ROOT / ".env.local" | |
| if not env_path.exists(): | |
| return | |
| for raw_line in env_path.read_text(encoding="utf-8").splitlines(): | |
| line = raw_line.strip() | |
| if not line or line.startswith("#") or "=" not in line: | |
| continue | |
| key, value = line.split("=", 1) | |
| key = key.strip() | |
| if key not in {"TRACKIO_PROJECT"}: | |
| continue | |
| value = value.strip().strip('"').strip("'") | |
| os.environ.setdefault(key, value) | |
| def _modal_secrets() -> list[modal.Secret]: | |
| if _is_config_mode(): | |
| return [] | |
| return [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])] | |
| def _is_config_mode() -> bool: | |
| args = sys.argv[1:] | |
| for index, arg in enumerate(args): | |
| if arg == "--mode" and index + 1 < len(args): | |
| return args[index + 1] == "config" | |
| if arg.startswith("--mode="): | |
| return arg.split("=", 1)[1] == "config" | |
| return False | |
| def _is_prepare_cache_mode() -> bool: | |
| args = sys.argv[1:] | |
| for index, arg in enumerate(args): | |
| if arg == "--mode" and index + 1 < len(args): | |
| return args[index + 1] == "prepare-cache" | |
| if arg.startswith("--mode="): | |
| return arg.split("=", 1)[1] == "prepare-cache" | |
| return False | |
| _load_local_env_file() | |
| def _cli_arg_value(name: str, default: str = "") -> str: | |
| args = sys.argv[1:] | |
| flag = f"--{name}" | |
| for index, arg in enumerate(args): | |
| if arg == flag and index + 1 < len(args): | |
| return args[index + 1] | |
| if arg.startswith(f"{flag}="): | |
| return arg.split("=", 1)[1] | |
| return default | |
| def _source_mode() -> str: | |
| return _cli_arg_value("source-mode", os.environ.get("MODAL_SOURCE_MODE", "local")) | |
| def _training_image() -> modal.Image: | |
| if _is_prepare_cache_mode(): | |
| return _scenario_cache_image() | |
| if not _is_prepare_cache_mode(): | |
| _print_image_startup_notice() | |
| image = ( | |
| modal.Image.from_registry( | |
| "nvidia/cuda:12.8.0-devel-ubuntu22.04", | |
| add_python="3.11", | |
| ) | |
| .apt_install("git", "build-essential", "curl") | |
| .uv_pip_install( | |
| "torch==2.10.0", | |
| "triton>=3.4.0", | |
| "torchvision==0.25.0", | |
| "bitsandbytes", | |
| "accelerate", | |
| "datasets", | |
| "huggingface_hub", | |
| "peft", | |
| "pillow", | |
| "tokenizers", | |
| "nvidia-ml-py", | |
| "trackio>=0.25.0", | |
| "transformers>=5.5.0", | |
| "trl>=0.28.0", | |
| "openenv-core[core]>=0.2.3", | |
| ) | |
| .uv_pip_install( | |
| "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo", | |
| "unsloth[base] @ git+https://github.com/unslothai/unsloth", | |
| ) | |
| .uv_pip_install("timm", extra_options="--no-deps") | |
| .uv_pip_install("pydantic==2.10.6") | |
| .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps") | |
| .uv_pip_install("llm-blender", "weave") | |
| .uv_pip_install("trl>=0.28.0", "transformers>=5.5.0", "jmespath") | |
| ) | |
| if _source_mode() == "public": | |
| repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL) | |
| repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH) | |
| image = image.run_commands( | |
| f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}", | |
| f"python -m pip install --no-deps -e {REMOTE_PROJECT}", | |
| ) | |
| else: | |
| image = image.add_local_dir( | |
| PROJECT_ROOT, | |
| remote_path=REMOTE_PROJECT, | |
| copy=True, | |
| ignore=[ | |
| ".git", | |
| ".venv", | |
| ".env", | |
| ".env.*", | |
| "__pycache__", | |
| ".pytest_cache", | |
| "outputs", | |
| "*.pyc", | |
| ], | |
| ) | |
| image = image.run_commands( | |
| f"python -m pip install --no-deps -e {REMOTE_PROJECT}", | |
| ) | |
| return image.run_commands( | |
| "python -c \"import os, torch; import transformers.utils.hub as hub; " | |
| "hub.TRANSFORMERS_CACHE = getattr(hub, 'TRANSFORMERS_CACHE', " | |
| "os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')); " | |
| "from trl import GRPOConfig, GRPOTrainer; " | |
| "from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import " | |
| "CybersecurityOwaspEnvironment; print('trainer import ok', torch.__version__)\"", | |
| ).workdir(REMOTE_PROJECT) | |
| def _scenario_cache_image() -> modal.Image: | |
| image = ( | |
| modal.Image.debian_slim(python_version="3.11") | |
| .apt_install("git") | |
| .uv_pip_install("openenv-core[core]>=0.2.3", "trackio>=0.25.0") | |
| ) | |
| if _source_mode() == "public": | |
| repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL) | |
| repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH) | |
| image = image.run_commands( | |
| f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}", | |
| f"python -m pip install --no-deps -e {REMOTE_PROJECT}", | |
| ) | |
| else: | |
| image = image.add_local_dir( | |
| PROJECT_ROOT, | |
| remote_path=REMOTE_PROJECT, | |
| copy=True, | |
| ignore=[ | |
| ".git", | |
| ".venv", | |
| ".env", | |
| ".env.*", | |
| "__pycache__", | |
| ".pytest_cache", | |
| "outputs", | |
| "*.pyc", | |
| ], | |
| ) | |
| image = image.run_commands( | |
| f"python -m pip install --no-deps -e {REMOTE_PROJECT}", | |
| ) | |
| return image.workdir(REMOTE_PROJECT) | |
| app = modal.App(APP_NAME) | |
| volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True) | |
| cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True) | |
| scenario_cache_volume = modal.Volume.from_name(SCENARIO_CACHE_VOLUME_NAME, create_if_missing=True) | |
| secrets = _modal_secrets() | |
| scenario_cache_image = _scenario_cache_image() | |
| training_image = _training_image() | |
| def prepare_modal_scenario_cache( | |
| seed_start: int = 0, | |
| difficulty_buckets: int = 0, | |
| train_per_bucket: int = 0, | |
| validation_per_bucket: int = 0, | |
| heldout_per_bucket: int = 0, | |
| force: bool = False, | |
| ) -> dict[str, Any]: | |
| if difficulty_buckets: | |
| os.environ["CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS"] = str(difficulty_buckets) | |
| if train_per_bucket: | |
| os.environ["CYBERSECURITY_OWASP_TRAIN_SCENARIOS_PER_BUCKET"] = str(train_per_bucket) | |
| if validation_per_bucket: | |
| os.environ["CYBERSECURITY_OWASP_VALIDATION_SCENARIOS_PER_BUCKET"] = str(validation_per_bucket) | |
| if heldout_per_bucket: | |
| os.environ["CYBERSECURITY_OWASP_HELDOUT_SCENARIOS_PER_BUCKET"] = str(heldout_per_bucket) | |
| _configure_scenario_cache_env(required=False) | |
| from CyberSecurity_OWASP.config import load_scenario_authoring_config | |
| from CyberSecurity_OWASP.server.scenario_cache import prepare_scenario_cache | |
| settings = load_scenario_authoring_config() | |
| result = prepare_scenario_cache( | |
| cache_dir=SCENARIO_CACHE_DIR, | |
| settings=settings, | |
| seed_start=seed_start, | |
| force=force, | |
| ) | |
| scenario_cache_volume.commit() | |
| result["scenario_cache_volume"] = SCENARIO_CACHE_VOLUME_NAME | |
| return result | |
| def verify_modal_scenario_cache_for_training( | |
| split: str = "train", | |
| difficulty: int = 0, | |
| dataset_size: int = 2, | |
| seed_start: int = 0, | |
| ) -> dict[str, Any]: | |
| _configure_scenario_cache_env(required=True) | |
| scenario_cache_volume.reload() | |
| from CyberSecurity_OWASP.config import load_scenario_authoring_config | |
| from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( | |
| CybersecurityOwaspEnvironment, | |
| ) | |
| from CyberSecurity_OWASP.server.curriculum import CurriculumController | |
| from CyberSecurity_OWASP.server.scenario_cache import ScenarioCache | |
| settings = load_scenario_authoring_config() | |
| scenario_profile = CurriculumController(settings=settings).select_profile( | |
| seed=seed_start, | |
| split=split, | |
| requested_difficulty=difficulty, | |
| ) | |
| resolved_difficulty = int(scenario_profile["difficulty"]) | |
| cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=settings) | |
| coverage = cache.assert_coverage(split=split, difficulty=resolved_difficulty) | |
| entries = cache.validated_entries(split=split, difficulty=resolved_difficulty) | |
| if not entries: | |
| entries = cache.validated_entries(split=split) | |
| if not entries: | |
| raise RuntimeError(f"No validated scenario cache entries found for split={split!r}.") | |
| sample_entry = entries[0] | |
| env = CybersecurityOwaspEnvironment() | |
| try: | |
| obs = env.reset( | |
| seed=int(sample_entry["seed"]), | |
| split=str(sample_entry["split"]), | |
| difficulty=int(sample_entry["difficulty"]), | |
| ) | |
| if not env.state.cache_hit: | |
| raise RuntimeError("Scenario cache preflight reset did not hit cache.") | |
| if env.state.metrics.get("scenario_compile_latency_ms", 0.0): | |
| raise RuntimeError("Scenario cache preflight unexpectedly compiled a scenario.") | |
| sample = { | |
| "phase": obs.phase, | |
| "task_id": env.state.task_id, | |
| "cache_hit": env.state.cache_hit, | |
| "scenario_hash": env.state.scenario_hash, | |
| "reset_latency_ms": env.state.reset_latency_ms, | |
| "bundle_load_latency_ms": env.state.metrics.get( | |
| "scenario_bundle_load_latency_ms", | |
| 0.0, | |
| ), | |
| } | |
| finally: | |
| env.close() | |
| return { | |
| "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME, | |
| "scenario_cache_dir": str(SCENARIO_CACHE_DIR), | |
| "scenario_cache_mode": "require", | |
| "split": split, | |
| "difficulty": "adaptive", | |
| "initial_difficulty": resolved_difficulty, | |
| "dataset_size": dataset_size, | |
| "available_scenarios": len(cache.validated_entries(split=split)), | |
| "coverage": coverage, | |
| "sample_reset": sample, | |
| } | |
| def check_training_imports() -> dict[str, str]: | |
| cache_env = _configure_modal_cache_env() | |
| scenario_cache_env = _configure_scenario_cache_env(required=False) | |
| import torch | |
| import trackio | |
| from datasets import Dataset | |
| from trl import GRPOConfig, GRPOTrainer | |
| from unsloth import FastVisionModel | |
| from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( | |
| CybersecurityOwaspEnvironment, | |
| ) | |
| env = CybersecurityOwaspEnvironment() | |
| obs = env.reset(seed=0, split="validation", difficulty=0) | |
| return { | |
| "torch": torch.__version__, | |
| "trackio": getattr(trackio, "__version__", "unknown"), | |
| "dataset": Dataset.__name__, | |
| "grpo_config": GRPOConfig.__name__, | |
| "grpo_trainer": GRPOTrainer.__name__, | |
| "unsloth_vision_model": FastVisionModel.__name__, | |
| "env": CybersecurityOwaspEnvironment.__name__, | |
| "reset_phase": obs.phase, | |
| "hf_home": cache_env["HF_HOME"], | |
| "hf_hub_cache": cache_env["HF_HUB_CACHE"], | |
| "scenario_cache_dir": scenario_cache_env["CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR"], | |
| } | |
| def run_cybersecurity_owasp_baseline( | |
| max_steps: int = 50, | |
| dataset_size: int = 1, | |
| difficulty: int = 0, | |
| split: str = "train", | |
| model_name: str = DEFAULT_GEMMA_MODEL, | |
| max_seq_length: int = 4096, | |
| max_completion_length: int = 768, | |
| trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio", | |
| trackio_project: str = "CyberSecurity_OWASP-grpo", | |
| num_generations: int = 1, | |
| trace_log_every: int = 1, | |
| seed_start: int = 0, | |
| git_sha: str = "nogit", | |
| run_name: str = "baseline", | |
| source_mode: str = "local", | |
| repo_url: str = PUBLIC_REPO_URL, | |
| repo_branch: str = PUBLIC_REPO_BRANCH, | |
| reward_config: str = "", | |
| reward_variant: str = "", | |
| ) -> dict[str, str | int | float]: | |
| import statistics | |
| import time | |
| import torch | |
| from huggingface_hub import snapshot_download, whoami | |
| from unsloth import FastVisionModel | |
| import transformers.utils.hub as transformers_hub | |
| from CyberSecurity_OWASP.models import CyberSecurityOWASPAction | |
| from CyberSecurity_OWASP.config import load_scenario_authoring_config | |
| from CyberSecurity_OWASP.reward_config import load_reward_settings | |
| from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( | |
| CybersecurityOwaspEnvironment, | |
| ) | |
| from CyberSecurity_OWASP.server.curriculum import CurriculumController | |
| from CyberSecurity_OWASP.server.scenario_cache import ScenarioCache | |
| from training.trackio_utils import ( | |
| aggregate_episode_metrics, | |
| episode_record_from_state, | |
| log_reward_config, | |
| log_trace_table, | |
| log_trackio_metrics, | |
| reward_config_trackio_config, | |
| trackio_run, | |
| ) | |
| model_name = _ensure_gemma4_model(model_name) | |
| if int(num_generations) != 1: | |
| raise ValueError("Baseline mode runs the untrained model with --num-generations 1.") | |
| cache_env = _configure_modal_cache_env() | |
| scenario_cache_env = _configure_scenario_cache_env(required=True) | |
| transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"] | |
| hf_token = os.environ.get("HF_TOKEN") | |
| if not hf_token: | |
| raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.") | |
| try: | |
| whoami(token=hf_token) | |
| except Exception as exc: | |
| raise RuntimeError("HF_TOKEN could not be validated before baseline run.") from exc | |
| os.environ["TRACKIO_SPACE_ID"] = trackio_space_id | |
| os.environ["TRACKIO_PROJECT"] = trackio_project | |
| reward_env = _configure_reward_env( | |
| reward_config=reward_config, | |
| reward_variant=reward_variant, | |
| ) | |
| reward_settings = load_reward_settings() | |
| reward_tracking_config = reward_config_trackio_config(reward_settings) | |
| reward_tracking_config["reward_variant"] = reward_variant or "default" | |
| reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path | |
| run_name = run_name or "baseline" | |
| output_dir = RUNS_DIR / run_name | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| try: | |
| cache_volume.reload() | |
| print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}") | |
| except Exception as exc: | |
| print(f"Model cache volume reload skipped: {exc!r}") | |
| try: | |
| scenario_cache_volume.reload() | |
| print(f"Reloaded Modal scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}") | |
| except Exception as exc: | |
| print(f"Scenario cache volume reload skipped: {exc!r}") | |
| settings = load_scenario_authoring_config() | |
| scenario_profile = CurriculumController(settings=settings).select_profile( | |
| seed=seed_start, | |
| split=split, | |
| requested_difficulty=difficulty, | |
| ) | |
| resolved_difficulty = int(scenario_profile["difficulty"]) | |
| scenario_cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=settings) | |
| coverage = scenario_cache.assert_coverage( | |
| split=split, | |
| difficulty=resolved_difficulty, | |
| ) | |
| entries = scenario_cache.validated_entries( | |
| split=split, | |
| difficulty=resolved_difficulty, | |
| ) or scenario_cache.validated_entries(split=split) | |
| if not entries: | |
| raise RuntimeError(f"No validated scenario cache entries found for split={split!r}.") | |
| print(f"Baseline run name: {run_name}") | |
| print(f"Source mode: {source_mode}") | |
| if source_mode == "public": | |
| print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}") | |
| else: | |
| print("Packaged local CyberSecurity_OWASP repo.") | |
| print(f"Trackio Space: {trackio_space_id}") | |
| print(f"Trackio Project: {trackio_project}") | |
| print(f"Reward config: {reward_tracking_config['reward_config_id']}") | |
| print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}") | |
| print(f"Reward variant: {reward_tracking_config['reward_variant']}") | |
| print(f"Reward config path: {reward_tracking_config['reward_config_path']}") | |
| if reward_env: | |
| print(f"Reward env overrides: {reward_env}") | |
| print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}") | |
| print(f"Scenario cache coverage: {coverage}") | |
| print( | |
| "Baseline generation config: " | |
| f"episodes={dataset_size}, max_episode_steps={max_steps}, " | |
| f"num_generations={num_generations}, max_completion_length={max_completion_length}, " | |
| f"trace_log_every={trace_log_every}" | |
| ) | |
| expected_model_cache = _hf_model_cache_path(model_name) | |
| print(f"Expected HF model cache path: {expected_model_cache}") | |
| print(f"Model cache hit before load: {expected_model_cache.exists()}") | |
| try: | |
| snapshot_path = snapshot_download( | |
| repo_id=model_name, | |
| cache_dir=str(HF_HUB_CACHE_DIR), | |
| token=hf_token, | |
| ) | |
| print(f"Model snapshot ready: {snapshot_path}") | |
| cache_volume.commit() | |
| except Exception as exc: | |
| print(f"Explicit model snapshot prefetch failed; loading directly. Error: {exc!r}") | |
| model_api = FastVisionModel | |
| model, tokenizer = model_api.from_pretrained( | |
| model_name=model_name, | |
| max_seq_length=max_seq_length, | |
| load_in_4bit=False, | |
| fast_inference=False, | |
| cache_dir=str(HF_HUB_CACHE_DIR), | |
| token=hf_token, | |
| ) | |
| if hasattr(model_api, "for_inference"): | |
| model_api.for_inference(model) | |
| model.eval() | |
| cache_volume.commit() | |
| device = next(model.parameters()).device | |
| text_tokenizer = getattr(tokenizer, "tokenizer", tokenizer) | |
| def render_prompt(observation, actions: list[dict[str, Any]]) -> str: | |
| recent_actions = actions[-8:] | |
| return ( | |
| "You are the untrained baseline model for a defensive local AppSec " | |
| "repair environment. Use only the listed local tools. Return exactly " | |
| "one JSON object and no markdown.\n\n" | |
| f"{observation.scenario_prompt}\n\n" | |
| f"Current phase: {observation.phase}\n" | |
| f"Available actions: {observation.available_actions}\n" | |
| f"Last tool result: {observation.last_tool_result}\n" | |
| f"Recent actions: {json.dumps(recent_actions, sort_keys=True)}\n\n" | |
| 'Required format: {"tool_name":"inspect_policy_graph","arguments":{}}' | |
| ) | |
| def generate_action_text(prompt: str) -> tuple[str, list[int], list[int]]: | |
| messages = [{"role": "user", "content": prompt}] | |
| prompt_text = prompt | |
| for candidate in (tokenizer, text_tokenizer): | |
| if hasattr(candidate, "apply_chat_template"): | |
| try: | |
| prompt_text = candidate.apply_chat_template( | |
| messages, | |
| tokenize=False, | |
| add_generation_prompt=True, | |
| ) | |
| break | |
| except Exception: | |
| prompt_text = prompt | |
| encode = tokenizer | |
| try: | |
| inputs = encode( | |
| prompt_text, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=max_seq_length, | |
| ) | |
| except Exception: | |
| inputs = text_tokenizer( | |
| prompt_text, | |
| return_tensors="pt", | |
| truncation=True, | |
| max_length=max_seq_length, | |
| ) | |
| if hasattr(inputs, "to"): | |
| inputs = inputs.to(device) | |
| else: | |
| inputs = { | |
| key: value.to(device) if hasattr(value, "to") else value | |
| for key, value in inputs.items() | |
| } | |
| input_ids = inputs.get("input_ids") | |
| input_len = int(input_ids.shape[-1]) if input_ids is not None else 0 | |
| pad_token_id = getattr(text_tokenizer, "pad_token_id", None) | |
| if pad_token_id is None: | |
| pad_token_id = getattr(text_tokenizer, "eos_token_id", None) | |
| with torch.inference_mode(): | |
| generated = model.generate( | |
| **inputs, | |
| max_new_tokens=max_completion_length, | |
| do_sample=False, | |
| pad_token_id=pad_token_id, | |
| ) | |
| output_ids = generated[0] | |
| completion_ids = output_ids[input_len:] | |
| decode = getattr(text_tokenizer, "decode", None) or getattr(tokenizer, "decode") | |
| text = decode(completion_ids, skip_special_tokens=True) | |
| prompt_ids = ( | |
| [int(item) for item in input_ids[0].detach().cpu().tolist()] | |
| if input_ids is not None | |
| else [] | |
| ) | |
| return text, prompt_ids, [int(item) for item in completion_ids.detach().cpu().tolist()] | |
| def action_from_completion(raw_text: str) -> tuple[CyberSecurityOWASPAction, str | None]: | |
| loaded = _extract_first_json_object(raw_text) | |
| if loaded is None: | |
| return CyberSecurityOWASPAction(tool_name="noop", arguments={}), "invalid_json" | |
| arguments = loaded.get("arguments", {}) | |
| if not isinstance(arguments, dict): | |
| arguments = {} | |
| payload = { | |
| "tool_name": loaded.get("tool_name", "noop"), | |
| "arguments": arguments, | |
| } | |
| try: | |
| return CyberSecurityOWASPAction(**payload), None | |
| except Exception as exc: | |
| return ( | |
| CyberSecurityOWASPAction(tool_name="noop", arguments={}), | |
| f"invalid_action_schema: {exc}", | |
| ) | |
| episode_records: list[dict[str, Any]] = [] | |
| raw_traces: list[dict[str, Any]] = [] | |
| invalid_model_outputs = 0 | |
| generation_started = time.monotonic() | |
| config = { | |
| "base_model": model_name, | |
| "algo": "baseline", | |
| "difficulty": difficulty, | |
| "split": split, | |
| "max_episode_steps": max_steps, | |
| "dataset_size": dataset_size, | |
| "num_generations": num_generations, | |
| "max_completion_length": max_completion_length, | |
| "git_sha": git_sha, | |
| "reward_variant": reward_tracking_config["reward_variant"], | |
| **reward_tracking_config, | |
| } | |
| with trackio_run( | |
| run_name=run_name, | |
| run_type="baseline", | |
| config=config, | |
| project=trackio_project, | |
| space_id=trackio_space_id, | |
| group="baseline", | |
| auto_log_gpu=True, | |
| ): | |
| log_reward_config(reward_settings, step=0) | |
| for episode_index in range(max(1, int(dataset_size))): | |
| entry = entries[(seed_start + episode_index) % len(entries)] | |
| env = CybersecurityOwaspEnvironment() | |
| try: | |
| observation = env.reset( | |
| seed=int(entry["seed"]), | |
| split=str(entry["split"]), | |
| difficulty=int(entry["difficulty"]), | |
| ) | |
| env.state.max_steps = int(max_steps) | |
| actions: list[dict[str, Any]] = [] | |
| model_steps: list[dict[str, Any]] = [] | |
| prompt_token_count = 0 | |
| completion_token_count = 0 | |
| for step_index in range(int(max_steps)): | |
| if observation.done: | |
| break | |
| prompt = render_prompt(observation, actions) | |
| raw_text, prompt_ids, completion_ids = generate_action_text(prompt) | |
| prompt_token_count += len(prompt_ids) | |
| completion_token_count += len(completion_ids) | |
| action, invalid_reason = action_from_completion(raw_text) | |
| if invalid_reason: | |
| invalid_model_outputs += 1 | |
| observation = env.step(action) | |
| action_dump = action.model_dump() | |
| actions.append(action_dump) | |
| model_steps.append( | |
| { | |
| "step": step_index + 1, | |
| "raw_completion": raw_text, | |
| "action": action_dump, | |
| "invalid_model_output": invalid_reason, | |
| "observation_message": observation.message, | |
| "reward": observation.reward, | |
| "done": observation.done, | |
| } | |
| ) | |
| env.state.completion_tokens = completion_token_count | |
| env.state.metrics["prompt_tokens"] = prompt_token_count | |
| env.state.metrics["completion_tokens"] = completion_token_count | |
| final_observation = observation.model_dump() | |
| record = episode_record_from_state( | |
| env.state, | |
| run_context={ | |
| "base_model": model_name, | |
| "algo": "baseline", | |
| "reward_version": "reward_v2", | |
| "env_version": "0.1.0", | |
| **reward_tracking_config, | |
| }, | |
| final_observation=final_observation, | |
| ) | |
| record.update( | |
| { | |
| "reward_total": float(env.state.accumulated_reward), | |
| "success": bool(env.state.success), | |
| "episode_length": int(env.state.step_count), | |
| "invalid_model_output_count": sum( | |
| 1 for item in model_steps if item["invalid_model_output"] | |
| ), | |
| "prompt_tokens": prompt_token_count, | |
| "completion_tokens": completion_token_count, | |
| } | |
| ) | |
| episode_records.append(record) | |
| raw_traces.append( | |
| { | |
| "episode_index": episode_index, | |
| "task_id": env.state.task_id, | |
| "seed": env.state.seed, | |
| "split": env.state.split, | |
| "difficulty": env.state.difficulty, | |
| "domain": env.state.domain, | |
| "bug_family": env.state.bug_family, | |
| "steps": model_steps, | |
| } | |
| ) | |
| finally: | |
| env.close() | |
| metrics = aggregate_episode_metrics(episode_records) | |
| metrics.update( | |
| { | |
| "baseline/episode_count": float(len(episode_records)), | |
| "baseline/reward_total_mean": statistics.mean( | |
| float(item.get("reward_total", 0.0)) for item in episode_records | |
| ), | |
| "baseline/success_rate": statistics.mean( | |
| 1.0 if item.get("success") else 0.0 for item in episode_records | |
| ), | |
| "baseline/invalid_model_output_rate": invalid_model_outputs | |
| / max(1.0, sum(float(item.get("episode_length", 0)) for item in episode_records)), | |
| "baseline/num_generations": float(num_generations), | |
| "baseline/max_episode_steps": float(max_steps), | |
| "baseline/max_completion_length": float(max_completion_length), | |
| } | |
| ) | |
| log_trackio_metrics(metrics, step=episode_index + 1) | |
| if trace_log_every > 0 and ( | |
| episode_index == 0 or (episode_index + 1) % trace_log_every == 0 | |
| ): | |
| log_trace_table( | |
| [episode_records[-1]], | |
| table_name="baseline_traces", | |
| step=episode_index + 1, | |
| ) | |
| elapsed_s = time.monotonic() - generation_started | |
| summary = { | |
| "run_name": run_name, | |
| "trackio_space_id": trackio_space_id, | |
| "trackio_project": trackio_project, | |
| "model_name": model_name, | |
| "dataset_size": len(episode_records), | |
| "max_episode_steps": int(max_steps), | |
| "difficulty": int(difficulty), | |
| "split": split, | |
| "num_generations": int(num_generations), | |
| "max_completion_length": int(max_completion_length), | |
| "mean_reward": ( | |
| statistics.mean(float(item.get("reward_total", 0.0)) for item in episode_records) | |
| if episode_records | |
| else 0.0 | |
| ), | |
| "success_rate": ( | |
| statistics.mean(1.0 if item.get("success") else 0.0 for item in episode_records) | |
| if episode_records | |
| else 0.0 | |
| ), | |
| "invalid_model_output_count": int(invalid_model_outputs), | |
| "elapsed_s": elapsed_s, | |
| **reward_tracking_config, | |
| } | |
| artifact_path = output_dir / "baseline_rollouts.json" | |
| artifact_path.write_text( | |
| json.dumps( | |
| { | |
| "summary": summary, | |
| "episodes": episode_records, | |
| "raw_traces": raw_traces, | |
| }, | |
| indent=2, | |
| sort_keys=True, | |
| default=str, | |
| ), | |
| encoding="utf-8", | |
| ) | |
| volume.commit() | |
| cache_volume.commit() | |
| scenario_cache_volume.commit() | |
| print(f"Baseline artifact saved to {artifact_path}") | |
| return {**summary, "artifact_path": str(artifact_path)} | |
| def train_cybersecurity_owasp_grpo( | |
| env_repo_id: str = "", | |
| output_repo_id: str = "", | |
| initial_adapter_path: str = "", | |
| initial_adapter_repo_id: str = "", | |
| max_steps: int = 10, | |
| dataset_size: int = 16, | |
| difficulty: int = 0, | |
| split: str = "train", | |
| model_name: str = DEFAULT_GEMMA_MODEL, | |
| max_seq_length: int = 4096, | |
| max_completion_length: int = 768, | |
| lora_rank: int = 32, | |
| trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio", | |
| trackio_project: str = "CyberSecurity_OWASP-grpo", | |
| num_generations: int = 6, | |
| per_device_train_batch_size: int = 1, | |
| gradient_accumulation_steps: int = 0, | |
| learning_rate: float = 5e-6, | |
| use_vllm: bool = False, | |
| vllm_gpu_memory_utilization: float = 0.2, | |
| trace_log_every: int = 5, | |
| seed_start: int = 0, | |
| git_sha: str = "nogit", | |
| run_name: str = "", | |
| source_mode: str = "local", | |
| repo_url: str = PUBLIC_REPO_URL, | |
| repo_branch: str = PUBLIC_REPO_BRANCH, | |
| push_to_hub: bool = False, | |
| reward_config: str = "", | |
| reward_variant: str = "", | |
| ) -> dict[str, str | int | float]: | |
| import inspect | |
| import statistics | |
| import threading | |
| import time | |
| model_name = _ensure_gemma4_model(model_name) | |
| cache_env = _configure_modal_cache_env() | |
| world_size = int(os.environ.get("WORLD_SIZE", "1") or "1") | |
| ( | |
| resolved_gradient_accumulation_steps, | |
| effective_train_batch_size, | |
| ) = _resolve_grpo_batch_config( | |
| per_device_train_batch_size=per_device_train_batch_size, | |
| gradient_accumulation_steps=gradient_accumulation_steps, | |
| num_generations=num_generations, | |
| world_size=world_size, | |
| ) | |
| _validate_vllm_config( | |
| use_vllm=use_vllm, | |
| vllm_gpu_memory_utilization=vllm_gpu_memory_utilization, | |
| ) | |
| trace_log_every = max(0, int(trace_log_every)) | |
| import torch | |
| from safetensors.torch import load_file as load_safetensors_file | |
| from unsloth import FastVisionModel | |
| import transformers.utils.hub as transformers_hub | |
| from datasets import Dataset | |
| from huggingface_hub import snapshot_download, whoami | |
| from peft import set_peft_model_state_dict | |
| from transformers import TrainerCallback | |
| from trl import GRPOConfig, GRPOTrainer, clone_chat_template | |
| try: | |
| from trl.chat_template_utils import add_response_schema | |
| except ImportError: | |
| def add_response_schema(tokenizer): | |
| return tokenizer | |
| import trackio | |
| from CyberSecurity_OWASP.models import CyberSecurityOWASPAction | |
| from CyberSecurity_OWASP.config import load_scenario_authoring_config | |
| from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( | |
| CybersecurityOwaspEnvironment, | |
| ) | |
| from CyberSecurity_OWASP.reward_config import ( | |
| compute_token_penalty, | |
| load_reward_settings, | |
| ) | |
| from CyberSecurity_OWASP.server.curriculum import CurriculumController | |
| from CyberSecurity_OWASP.server.scenario_cache import ScenarioCache | |
| from training.trackio_utils import ( | |
| aggregate_episode_metrics, | |
| episode_record_from_state, | |
| episode_trace_fingerprint, | |
| log_reward_config, | |
| log_gpu_metrics, | |
| log_trace_table, | |
| log_trackio_metrics, | |
| reward_config_trackio_config, | |
| train_metric_aliases, | |
| ) | |
| from training.grpo_curriculum import ( | |
| ScenarioGroupRegistry, | |
| build_scenario_group_rows, | |
| ) | |
| transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"] | |
| hf_token = os.environ.get("HF_TOKEN") | |
| if not hf_token: | |
| raise RuntimeError( | |
| f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}." | |
| ) | |
| user = whoami(token=hf_token)["name"] | |
| env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP" | |
| output_repo_id = output_repo_id or ( | |
| f"{user}/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}" | |
| ) | |
| if not trackio_space_id: | |
| trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio" | |
| if hf_token: | |
| try: | |
| from huggingface_hub import whoami | |
| user = whoami(token=hf_token)["name"] | |
| if user == "humandotlearning": | |
| trackio_space_id = f"{user}/CyberSecurity_OWASP-trackio" | |
| except Exception: | |
| pass | |
| os.environ["TRACKIO_SPACE_ID"] = trackio_space_id | |
| os.environ["TRACKIO_PROJECT"] = trackio_project | |
| reward_env = _configure_reward_env( | |
| reward_config=reward_config, | |
| reward_variant=reward_variant, | |
| reward_mode="dense_train", | |
| ) | |
| reward_settings = load_reward_settings() | |
| reward_tracking_config = reward_config_trackio_config(reward_settings) | |
| reward_tracking_config["reward_variant"] = reward_variant or "default" | |
| reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path | |
| model_slug = model_name.replace("/", "-") | |
| stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") | |
| algo_tag = _grpo_run_algo_tag( | |
| initial_adapter_path=initial_adapter_path, | |
| initial_adapter_repo_id=initial_adapter_repo_id, | |
| ) | |
| run_name = run_name or ( | |
| f"CyberSecurity_OWASP-{model_slug}-{algo_tag}-level{difficulty}-" | |
| f"{reward_tracking_config['reward_variant']}-steps{max_steps}-seed{seed_start}-" | |
| f"{stamp}-{git_sha[:8]}" | |
| ) | |
| output_dir = RUNS_DIR / run_name | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| try: | |
| cache_volume.reload() | |
| print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}") | |
| except Exception as exc: | |
| print(f"Model cache volume reload skipped: {exc!r}") | |
| try: | |
| scenario_cache_volume.reload() | |
| print(f"Reloaded Modal scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}") | |
| except Exception as exc: | |
| print(f"Scenario cache volume reload skipped: {exc!r}") | |
| cache_env = _configure_modal_cache_env() | |
| scenario_cache_env = _configure_scenario_cache_env(required=True) | |
| scenario_settings = load_scenario_authoring_config() | |
| scenario_profile = CurriculumController(settings=scenario_settings).select_profile( | |
| seed=seed_start, | |
| split=split, | |
| requested_difficulty=difficulty, | |
| ) | |
| scenario_cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=scenario_settings) | |
| scenario_cache_coverage = scenario_cache.assert_coverage( | |
| split=split, | |
| difficulty=int(scenario_profile["difficulty"]), | |
| ) | |
| scenario_entries = scenario_cache.validated_entries(split=split) | |
| scenario_registry = ScenarioGroupRegistry( | |
| scenario_entries, | |
| split=split, | |
| initial_difficulty=int(scenario_profile["difficulty"]), | |
| rng_seed=seed_start, | |
| max_level=scenario_settings.curriculum.difficulty_bucket_count - 1, | |
| ) | |
| training_prompt = ( | |
| "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP " | |
| "OpenEnv environment. Use only the provided local tools. Do not target real " | |
| "systems. Work step by step: inspect policy and generated code, reproduce the " | |
| "authorization issue locally, submit a policy-tied diagnosis, patch the generated " | |
| "app, run visible tests, then submit the fix. Do not write explanations unless " | |
| "a tool argument needs evidence text." | |
| ) | |
| dataset = Dataset.from_list( | |
| build_scenario_group_rows( | |
| dataset_size=dataset_size, | |
| training_prompt=training_prompt, | |
| seed_start=seed_start, | |
| split=split, | |
| difficulty=difficulty, | |
| difficulty_policy="adaptive", | |
| ) | |
| ) | |
| def _state_snapshot(env: CybersecurityOwaspEnvironment) -> dict[str, Any]: | |
| state = env.state | |
| return { | |
| "episode_id": state.episode_id, | |
| "task_id": state.task_id, | |
| "seed": state.seed, | |
| "split": state.split, | |
| "difficulty": state.difficulty, | |
| "difficulty_tier": state.difficulty_tier, | |
| "domain": state.domain, | |
| "bug_family": state.bug_family, | |
| "template_id": state.template_id, | |
| "cache_hit": state.cache_hit, | |
| "scenario_hash": state.scenario_hash, | |
| "phase": state.phase, | |
| "step_count": state.step_count, | |
| "done": state.done, | |
| "success": state.success, | |
| "failure_reason": state.failure_reason, | |
| "anti_cheat_flags": list(state.anti_cheat_flags), | |
| } | |
| class CyberSecurityOWASPToolEnv: | |
| def __init__(self): | |
| self._env = CybersecurityOwaspEnvironment() | |
| self.reward = 0.0 | |
| self.reward_breakdown: dict[str, float] = {} | |
| self.done = False | |
| self.success = False | |
| self.invalid_actions = 0 | |
| self.scenario_group_id = -1 | |
| self.scenario_assignment: dict[str, Any] = {} | |
| self.trace_messages: list[dict[str, str]] = [] | |
| self.trace_metadata: dict[str, Any] = {} | |
| def reset(self, **kwargs) -> str: | |
| group_id = int(kwargs.get("scenario_group_id", kwargs.get("seed", seed_start))) | |
| assignment = scenario_registry.assignment_for( | |
| scenario_group_id=group_id, | |
| requested_seed=int(kwargs.get("seed", seed_start)), | |
| requested_difficulty=int(kwargs.get("difficulty", difficulty)), | |
| split=str(kwargs.get("split", split)), | |
| difficulty_policy=str(kwargs.get("difficulty_policy", "adaptive")), | |
| ) | |
| seed = int(assignment["seed"]) | |
| current_difficulty = int(assignment["difficulty"]) | |
| current_split = str(assignment["split"]) | |
| obs = self._env.reset( | |
| seed=seed, | |
| split=current_split, | |
| difficulty=current_difficulty, | |
| ) | |
| self.scenario_group_id = group_id | |
| self.scenario_assignment = assignment | |
| self.reward = 0.0 | |
| self.reward_breakdown = {} | |
| self.done = bool(obs.done) | |
| self.success = False | |
| self.invalid_actions = 0 | |
| self.trace_messages = [ | |
| { | |
| "role": "user", | |
| "content": ( | |
| f"{training_prompt}\n\n" | |
| f"{obs.scenario_prompt}\n\n" | |
| f"Initial message: {obs.message}" | |
| ), | |
| } | |
| ] | |
| self.trace_metadata = _state_snapshot(self._env) | |
| self.trace_metadata.update( | |
| { | |
| "scenario_group_id": self.scenario_group_id, | |
| "scenario_assignment": dict(self.scenario_assignment), | |
| "scenario_prompt_length": len(obs.scenario_prompt), | |
| "reward_config_id": reward_tracking_config["reward_config_id"], | |
| "reward_config_hash": reward_tracking_config["reward_config_hash"], | |
| "reward_stage": reward_tracking_config["reward_stage"], | |
| "reward_mode": reward_tracking_config["reward_mode"], | |
| "reward_variant": reward_tracking_config["reward_variant"], | |
| } | |
| ) | |
| return obs.scenario_prompt | |
| def _step(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str: | |
| if self.done: | |
| raise ValueError("Episode is already over.") | |
| action = CyberSecurityOWASPAction( | |
| tool_name=tool_name, | |
| arguments=arguments or {}, | |
| ) | |
| obs = self._env.step(action) | |
| if not obs.last_action_valid: | |
| self.invalid_actions += 1 | |
| self.reward = float(self._env.state.accumulated_reward) | |
| self.reward_breakdown = dict(obs.reward_breakdown or {}) | |
| self.done = bool(obs.done) | |
| self.success = bool(self._env.state.success) | |
| self.trace_messages.extend( | |
| [ | |
| { | |
| "role": "assistant", | |
| "content": f"{tool_name}({arguments or {}})", | |
| }, | |
| {"role": "tool", "content": obs.message}, | |
| ] | |
| ) | |
| self.trace_metadata.update(_state_snapshot(self._env)) | |
| self.trace_metadata.update( | |
| { | |
| "last_action_valid": obs.last_action_valid, | |
| "last_action_error": obs.last_action_error, | |
| "reward": self.reward, | |
| "reward_breakdown": self.reward_breakdown, | |
| "invalid_actions": self.invalid_actions, | |
| "scenario_cache_hit": self._env.state.cache_hit, | |
| "scenario_hash": self._env.state.scenario_hash, | |
| "scenario_group_id": self.scenario_group_id, | |
| "scenario_assignment": dict(self.scenario_assignment), | |
| } | |
| ) | |
| return obs.message | |
| def inspect_policy_graph(self) -> str: | |
| """Return public policy hints for the generated local scenario.""" | |
| return self._step("inspect_policy_graph") | |
| def list_routes(self) -> str: | |
| """List generated local app route summaries.""" | |
| return self._step("list_routes") | |
| def read_openapi(self) -> str: | |
| """Read generated OpenAPI metadata for the local app.""" | |
| return self._step("read_openapi") | |
| def read_file(self, path: str) -> str: | |
| """ | |
| Read an editable generated workspace file by relative path. | |
| Args: | |
| path: Relative path inside the generated editable workspace. | |
| Returns: | |
| The file contents or a safe tool error observation. | |
| """ | |
| return self._step("read_file", {"path": path}) | |
| def search_code(self, query: str) -> str: | |
| """ | |
| Search editable generated workspace files for a string. | |
| Args: | |
| query: Search text to find in editable generated app files. | |
| Returns: | |
| Matching file lines or a no-match message. | |
| """ | |
| return self._step("search_code", {"query": query}) | |
| def send_local_request( | |
| self, | |
| path: str, | |
| method: str = "GET", | |
| user_id: str | None = None, | |
| ) -> str: | |
| """ | |
| Send a request to the generated local app only. | |
| Args: | |
| path: Local route path such as /health or /invoices/<id>. | |
| method: HTTP method to use for the local request. | |
| user_id: Optional generated user identifier for authentication. | |
| Returns: | |
| JSON response from the simulated local app request. | |
| """ | |
| return self._step( | |
| "send_local_request", | |
| {"path": path, "method": method, "user_id": user_id}, | |
| ) | |
| def compare_identities( | |
| self, | |
| path: str, | |
| first_user_id: str, | |
| second_user_id: str, | |
| method: str = "GET", | |
| ) -> str: | |
| """ | |
| Compare one local request as two generated users. | |
| Args: | |
| path: Local route path to request as both generated users. | |
| first_user_id: First generated user identifier. | |
| second_user_id: Second generated user identifier. | |
| method: HTTP method to use for both local requests. | |
| Returns: | |
| JSON summary of both simulated local responses. | |
| """ | |
| return self._step( | |
| "compare_identities", | |
| { | |
| "path": path, | |
| "method": method, | |
| "first_user_id": first_user_id, | |
| "second_user_id": second_user_id, | |
| }, | |
| ) | |
| def submit_diagnosis( | |
| self, | |
| bug_class: str, | |
| route: str, | |
| violated_policy_rule: str, | |
| evidence_trace_ids: list[str], | |
| fix_plan: str, | |
| ) -> str: | |
| """ | |
| Submit structured diagnosis for the suspected authorization bug. | |
| Args: | |
| bug_class: Short class such as idor_ownership_bug. | |
| route: Method and route pattern believed to be vulnerable. | |
| violated_policy_rule: Policy rule that the behavior violates. | |
| evidence_trace_ids: Request trace IDs from local evidence tools. | |
| fix_plan: Concise secure repair plan. | |
| Returns: | |
| Diagnosis acceptance result and next phase information. | |
| """ | |
| return self._step( | |
| "submit_diagnosis", | |
| { | |
| "bug_class": bug_class, | |
| "route": route, | |
| "violated_policy_rule": violated_policy_rule, | |
| "evidence_trace_ids": evidence_trace_ids, | |
| "fix_plan": fix_plan, | |
| }, | |
| ) | |
| def patch_file( | |
| self, | |
| path: str, | |
| content: str | None = None, | |
| diff: str | None = None, | |
| ) -> str: | |
| """ | |
| Patch an editable generated app file with full content or a unified diff. | |
| Args: | |
| path: Relative path of the editable generated app file to patch. | |
| content: Complete replacement file content, when using full-file patching. | |
| diff: Unified diff to apply, when using diff patching. | |
| Returns: | |
| Patch application result. | |
| """ | |
| args: dict[str, Any] = {"path": path} | |
| if content is not None: | |
| args["content"] = content | |
| if diff is not None: | |
| args["diff"] = diff | |
| return self._step("patch_file", args) | |
| def run_visible_tests(self) -> str: | |
| """Run visible tests only; hidden tests are never exposed.""" | |
| return self._step("run_visible_tests") | |
| def submit_fix(self) -> str: | |
| """Submit the final patch to the hidden deterministic verifier.""" | |
| return self._step("submit_fix") | |
| def noop(self) -> str: | |
| """Take no action.""" | |
| return self._step("noop") | |
| def _score(self, completion_tokens: int = 0) -> float: | |
| token_penalty = compute_token_penalty(completion_tokens) | |
| self._env.state.completion_tokens = int(completion_tokens) | |
| self._env.state.metrics["completion_tokens"] = int(completion_tokens) | |
| self._env.state.metrics["token_penalty"] = token_penalty | |
| return float(self._env.state.accumulated_reward + token_penalty) | |
| def __del__(self): | |
| try: | |
| self._env.close() | |
| except Exception: | |
| pass | |
| trace_step = {"value": 0} | |
| logged_trace_fingerprints: set[str] = set() | |
| def _completion_to_text(completion) -> str: | |
| if completion is None: | |
| return "" | |
| if isinstance(completion, str): | |
| return completion | |
| if isinstance(completion, list): | |
| parts = [] | |
| for item in completion: | |
| if isinstance(item, dict): | |
| parts.append(str(item.get("content", item))) | |
| else: | |
| parts.append(str(item)) | |
| return "\n".join(parts) | |
| return str(completion) | |
| def _mean(values: list[float]) -> float: | |
| return float(sum(values) / len(values)) if values else 0.0 | |
| def cybersecurity_owasp_reward(environments, **kwargs) -> list[float]: | |
| completions = kwargs.get("completions") or kwargs.get("completion") or [] | |
| completion_texts = [_completion_to_text(item) for item in completions] | |
| completion_tokens = [len(text.split()) for text in completion_texts] | |
| rewards = [ | |
| float(env._score(completion_tokens[index] if index < len(completion_tokens) else 0)) | |
| for index, env in enumerate(environments) | |
| ] | |
| trace_step["value"] += 1 | |
| episode_records = [] | |
| for index, (env, reward) in enumerate(zip(environments, rewards)): | |
| record = episode_record_from_state( | |
| env._env.state, | |
| run_context={ | |
| "base_model": model_name, | |
| "algo": "grpo", | |
| "reward_version": "reward_v2", | |
| "env_version": "0.1.0", | |
| **reward_tracking_config, | |
| }, | |
| ) | |
| record.update( | |
| { | |
| "reward_total": reward, | |
| "reward_token_penalty": float(env._env.state.metrics.get("token_penalty", 0.0)), | |
| "completion_tokens": completion_tokens[index] if index < len(completion_tokens) else 0, | |
| "success": bool(getattr(env, "success", False)), | |
| } | |
| ) | |
| episode_records.append(record) | |
| group_successes: dict[int, list[float]] = {} | |
| for env in environments: | |
| group_id = int(getattr(env, "scenario_group_id", -1)) | |
| if group_id < 0: | |
| continue | |
| group_successes.setdefault(group_id, []).append(1.0 if getattr(env, "success", False) else 0.0) | |
| for group_id, successes in group_successes.items(): | |
| scenario_registry.record_group_outcome(group_id, _mean(successes)) | |
| batch_fingerprints = [ | |
| episode_trace_fingerprint(record) | |
| for record in episode_records | |
| ] | |
| sampled_traces = [] | |
| seen_this_batch: set[str] = set() | |
| duplicate_trace_suppressed_count = 0 | |
| for index, (env, record, reward, fingerprint) in enumerate( | |
| zip(environments, episode_records, rewards, batch_fingerprints) | |
| ): | |
| if fingerprint in seen_this_batch or fingerprint in logged_trace_fingerprints: | |
| duplicate_trace_suppressed_count += 1 | |
| continue | |
| seen_this_batch.add(fingerprint) | |
| if len(sampled_traces) < 4: | |
| sampled_traces.append((index, env, record, reward, fingerprint)) | |
| should_log_trace_artifacts = trace_log_every > 0 and ( | |
| trace_step["value"] == 1 | |
| or trace_step["value"] % trace_log_every == 0 | |
| ) | |
| canonical_metrics = aggregate_episode_metrics(episode_records) | |
| metrics = { | |
| **canonical_metrics, | |
| **train_metric_aliases(canonical_metrics), | |
| **scenario_registry.metrics( | |
| episode_records, | |
| unique_trace_count=len(set(batch_fingerprints)), | |
| duplicate_trace_suppressed_count=duplicate_trace_suppressed_count, | |
| ), | |
| } | |
| metrics["train/per_device_train_batch_size"] = float(per_device_train_batch_size) | |
| metrics["train/gradient_accumulation_steps"] = float( | |
| resolved_gradient_accumulation_steps | |
| ) | |
| metrics["train/effective_train_batch_size"] = float(effective_train_batch_size) | |
| metrics["train/num_generations"] = float(num_generations) | |
| metrics["train/use_vllm"] = float(bool(use_vllm)) | |
| metrics["train/vllm_gpu_memory_utilization"] = ( | |
| float(vllm_gpu_memory_utilization) if use_vllm else 0.0 | |
| ) | |
| metrics["train/trace_log_every"] = float(trace_log_every) | |
| metrics["train/trace_artifacts_logged"] = float(should_log_trace_artifacts) | |
| if rewards: | |
| metrics["train/reward_mean"] = _mean(rewards) | |
| metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 | |
| try: | |
| log_trackio_metrics(metrics, step=trace_step["value"]) | |
| except Exception as exc: | |
| print(f"Trackio metric logging skipped: {exc!r}") | |
| if should_log_trace_artifacts and sampled_traces: | |
| try: | |
| log_trace_table( | |
| [record for _, _, record, _, _ in sampled_traces], | |
| table_name="sample_traces", | |
| step=trace_step["value"], | |
| ) | |
| except Exception as exc: | |
| print(f"Trackio sample trace table logging skipped: {exc!r}") | |
| for index, env, _record, reward, fingerprint in sampled_traces: | |
| logged_trace_fingerprints.add(fingerprint) | |
| messages = list(getattr(env, "trace_messages", [])) | |
| if index < len(completions): | |
| completion_text = _completion_to_text(completions[index]) | |
| if completion_text: | |
| messages.append( | |
| { | |
| "role": "assistant", | |
| "content": f"Raw generated completion:\n{completion_text}", | |
| } | |
| ) | |
| metadata = dict(getattr(env, "trace_metadata", {})) | |
| metadata.update( | |
| { | |
| "sample_index": index, | |
| "reward": reward, | |
| "trace_step": trace_step["value"], | |
| "trace_fingerprint": fingerprint, | |
| "num_generations": num_generations, | |
| "run_name": run_name, | |
| "reward_config_id": reward_tracking_config["reward_config_id"], | |
| "reward_config_hash": reward_tracking_config["reward_config_hash"], | |
| "reward_stage": reward_tracking_config["reward_stage"], | |
| "reward_mode": reward_tracking_config["reward_mode"], | |
| "reward_variant": reward_tracking_config["reward_variant"], | |
| } | |
| ) | |
| try: | |
| trackio.log( | |
| { | |
| f"cybersecurity_owasp_trace/sample_{index}": trackio.Trace( | |
| messages=messages, | |
| metadata=metadata, | |
| ) | |
| }, | |
| step=trace_step["value"], | |
| ) | |
| except Exception as exc: | |
| print(f"Trackio trace logging skipped: {exc!r}") | |
| elif sampled_traces: | |
| print( | |
| "Trackio trace artifacts throttled at reward callback " | |
| f"{trace_step['value']}; set --trace-log-every 1 for every callback " | |
| "or 0 to disable trace artifacts." | |
| ) | |
| if rewards: | |
| print( | |
| "Reward batch: " | |
| f"mean={statistics.mean(rewards):.3f}, " | |
| f"min={min(rewards):.3f}, max={max(rewards):.3f}" | |
| ) | |
| return rewards | |
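| # TrainerCallback that mirrors the reward config and GPU/system metrics to Trackio at train begin, on each log event, and at train end. | |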
| class TrackioSystemMetricsCallback(TrainerCallback): | |
| def on_train_begin(self, args, state, control, **kwargs): | |
| try: | |
| reward_summary = log_reward_config(reward_settings, step=int(state.global_step or 0)) | |
| metrics = log_gpu_metrics(step=int(state.global_step or 0)) | |
| log_trackio_metrics( | |
| { | |
| "system/model_cache_hit": float(cache_hit), | |
| "system/scenario_cache_required": 1.0, | |
| "system/scenario_cache_entries": float( | |
| scenario_cache_coverage.get("entries", 0) | |
| ), | |
| "system/hub_push_enabled": float(push_to_hub), | |
| }, | |
| step=int(state.global_step or 0), | |
| ) | |
| print( | |
| "Trackio reward config logged: " | |
| f"{reward_summary['reward_config_id']} " | |
| f"({reward_summary['reward_config_hash']})" | |
| ) | |
| except Exception as exc: | |
| print(f"Trackio initialization metrics skipped: {exc!r}") | |
| return control | |
| if metrics: | |
| system_summary = ", ".join( | |
| f"{key}={value}" | |
| for key, value in sorted(metrics.items()) | |
| if key.startswith("system/") | |
| ) | |
| print(f"Trackio GPU metrics initialized: {system_summary}") | |
| return control | |
| def on_log(self, args, state, control, logs=None, **kwargs): | |
| try: | |
| metrics = log_gpu_metrics(step=int(state.global_step or 0)) | |
| except Exception as exc: | |
| print(f"Trackio GPU metrics skipped: {exc!r}") | |
| return control | |
| if metrics: | |
| summary = ", ".join(f"{key}={value}" for key, value in sorted(metrics.items())[:4]) | |
| print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}") | |
| return control | |
| def on_train_end(self, args, state, control, **kwargs): | |
| try: | |
| log_gpu_metrics(step=int(state.global_step or 0)) | |
| except Exception as exc: | |
| print(f"Trackio final GPU metrics skipped: {exc!r}") | |
| return control | |
| print(f"CUDA available: {torch.cuda.is_available()}") | |
| if source_mode == "public": | |
| print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}") | |
| else: | |
| print(f"Packaged local CyberSecurity_OWASP repo; default env repo id: {env_repo_id}") | |
| print(f"Trackio Space: {trackio_space_id}") | |
| print(f"Trackio Project: {trackio_project}") | |
| print(f"Output repo: {output_repo_id}") | |
| print(f"Run name: {run_name}") | |
| print(f"Reward config: {reward_tracking_config['reward_config_id']}") | |
| print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}") | |
| print(f"Reward variant: {reward_tracking_config['reward_variant']}") | |
| print(f"Reward config path: {reward_tracking_config['reward_config_path']}") | |
| print(f"Learning rate: {learning_rate}") | |
| print(f"Reward env overrides: {reward_env}") | |
| print(f"Model cache volume: {CACHE_VOLUME_NAME}") | |
| print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}") | |
| print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}") | |
| print("Scenario cache mode: require") | |
| print(f"Scenario cache coverage: {scenario_cache_coverage}") | |
| print(f"HF_HOME: {cache_env['HF_HOME']}") | |
| print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}") | |
| print(f"Torch cache: {cache_env['TORCH_HOME']}") | |
| print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}") | |
| print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}") | |
| print(f"Hub push enabled: {push_to_hub}") | |
| if initial_adapter_path: | |
| print(f"Initial SFT adapter path: {initial_adapter_path}") | |
| if initial_adapter_repo_id: | |
| print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}") | |
| print( | |
| "GRPO throughput config: " | |
| f"per_device_train_batch_size={per_device_train_batch_size}, " | |
| f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, " | |
| f"num_generations={num_generations}, " | |
| f"world_size={world_size}, " | |
| f"effective_train_batch_size={effective_train_batch_size}" | |
| ) | |
| print( | |
| "Generation acceleration config: " | |
| f"use_vllm={use_vllm}, " | |
| f"vllm_gpu_memory_utilization={vllm_gpu_memory_utilization}, " | |
| f"trace_log_every={trace_log_every}" | |
| ) | |
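| # Model cache handling: check the persistent HF hub cache first, then prefetch the snapshot on a miss so weights are downloaded once and committed to the Modal cache volume. | |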
| expected_model_cache = _hf_model_cache_path(model_name) | |
| cache_hit = expected_model_cache.exists() | |
| print(f"Expected HF model cache path: {expected_model_cache}") | |
| print(f"Model cache hit before load: {cache_hit}") | |
| if cache_hit: | |
| print("Using cached model snapshot from the persistent Modal volume when valid.") | |
| else: | |
| print( | |
| "Model cache miss. Downloading model weights once into the persistent " | |
| "Modal cache volume; Hugging Face progress output should follow." | |
| ) | |
| try: | |
| snapshot_path = snapshot_download( | |
| repo_id=model_name, | |
| cache_dir=str(HF_HUB_CACHE_DIR), | |
| token=hf_token, | |
| ) | |
| print(f"Model snapshot ready: {snapshot_path}") | |
| cache_volume.commit() | |
| print(f"Committed Modal model cache volume after snapshot download: {CACHE_VOLUME_NAME}") | |
| except Exception as exc: | |
| print( | |
| "Explicit model snapshot prefetch failed; Unsloth will attempt the " | |
| f"model load directly. Error: {exc!r}" | |
| ) | |
| print(f"Loading model with Unsloth from_pretrained: {model_name}") | |
| model_api = FastVisionModel | |
| model_load_values = { | |
| "model_name": model_name, | |
| "max_seq_length": max_seq_length, | |
| "load_in_4bit": False, | |
| "fast_inference": use_vllm, | |
| "gpu_memory_utilization": vllm_gpu_memory_utilization if use_vllm else None, | |
| "cache_dir": str(HF_HUB_CACHE_DIR), | |
| "token": hf_token, | |
| } | |
| from_pretrained_parameters = inspect.signature(model_api.from_pretrained).parameters | |
| from_pretrained_accepts_kwargs = any( | |
| parameter.kind == inspect.Parameter.VAR_KEYWORD | |
| for parameter in from_pretrained_parameters.values() | |
| ) | |
| skipped_model_load_keys = sorted( | |
| key | |
| for key, value in model_load_values.items() | |
| if value is not None | |
| and key not in from_pretrained_parameters | |
| and not from_pretrained_accepts_kwargs | |
| ) | |
| if skipped_model_load_keys: | |
| print(f"Skipping unsupported from_pretrained keys: {skipped_model_load_keys}") | |
| model, tokenizer = model_api.from_pretrained( | |
| **{ | |
| key: value | |
| for key, value in model_load_values.items() | |
| if value is not None | |
| and (key in from_pretrained_parameters or from_pretrained_accepts_kwargs) | |
| } | |
| ) | |
| print("Model load complete.") | |
| cache_volume.commit() | |
| print(f"Committed Modal model cache volume after model load: {CACHE_VOLUME_NAME}") | |
| try: | |
| tokenizer = add_response_schema(tokenizer) | |
| except Exception as exc: | |
| print( | |
| "Tokenizer response schema add skipped for Gemma 4 processor, " | |
| "matching the Unsloth Gemma 4 GRPO notebook pattern: " | |
| f"{exc!r}" | |
| ) | |
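| # Optional SFT warmstart: resolve the adapter from a Hub repo or local path, attach a trainable Unsloth LoRA with the adapter's own rank/alpha/targets, then load its weights so GRPO continues from the SFT state. | |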
| adapter_source = initial_adapter_path | |
| if initial_adapter_repo_id: | |
| print(f"Downloading initial SFT adapter: {initial_adapter_repo_id}") | |
| adapter_source = snapshot_download( | |
| repo_id=initial_adapter_repo_id, | |
| cache_dir=str(HF_HUB_CACHE_DIR), | |
| token=hf_token, | |
| ) | |
| cache_volume.commit() | |
| if adapter_source: | |
| print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}") | |
| adapter_source_path = pathlib.Path(adapter_source) | |
| adapter_config_path = adapter_source_path / "adapter_config.json" | |
| if not adapter_config_path.exists(): | |
| raise RuntimeError(f"Initial SFT adapter config not found: {adapter_config_path}") | |
| adapter_config = json.loads(adapter_config_path.read_text(encoding="utf-8")) | |
| adapter_rank = int(adapter_config.get("r") or lora_rank) | |
| adapter_alpha = int(adapter_config.get("lora_alpha") or adapter_rank * 2) | |
| adapter_target_modules = adapter_config.get("target_modules") or [ | |
| "q_proj", | |
| "k_proj", | |
| "v_proj", | |
| "o_proj", | |
| "gate_proj", | |
| "up_proj", | |
| "down_proj", | |
| ] | |
| adapter_target_modules = list(adapter_target_modules) | |
| print( | |
| "Attaching Unsloth LoRA before loading SFT weights: " | |
| f"rank={adapter_rank}, alpha={adapter_alpha}, targets={adapter_target_modules}" | |
| ) | |
| model = model_api.get_peft_model( | |
| model, | |
| r=adapter_rank, | |
| target_modules=adapter_target_modules, | |
| lora_alpha=adapter_alpha, | |
| use_gradient_checkpointing="unsloth", | |
| random_state=3407, | |
| ) | |
| adapter_weights_path = adapter_source_path / "adapter_model.safetensors" | |
| if not adapter_weights_path.exists(): | |
| raise RuntimeError(f"Initial SFT adapter weights not found: {adapter_weights_path}") | |
| adapter_state = load_safetensors_file(str(adapter_weights_path), device="cpu") | |
| adapter_load_result = set_peft_model_state_dict( | |
| model, | |
| adapter_state, | |
| adapter_name="default", | |
| ) | |
| unexpected_adapter_keys = sorted( | |
| key | |
| for key in getattr(adapter_load_result, "unexpected_keys", []) | |
| if "lora_" in key or "modules_to_save" in key | |
| ) | |
| if unexpected_adapter_keys: | |
| raise RuntimeError( | |
| "Initial SFT adapter keys do not match the trainable Unsloth LoRA. " | |
| f"Unexpected adapter keys: {unexpected_adapter_keys[:10]}" | |
| ) | |
| missing_lora_keys = sorted( | |
| key | |
| for key in getattr(adapter_load_result, "missing_keys", []) | |
| if "lora_" in key or "modules_to_save" in key | |
| ) | |
| if missing_lora_keys: | |
| print(f"Missing LoRA keys while loading SFT adapter: {missing_lora_keys[:10]}") | |
| if hasattr(model, "print_trainable_parameters"): | |
| model.print_trainable_parameters() | |
| else: | |
| model = model_api.get_peft_model( | |
| model, | |
| r=lora_rank, | |
| target_modules=[ | |
| "q_proj", | |
| "k_proj", | |
| "v_proj", | |
| "o_proj", | |
| "gate_proj", | |
| "up_proj", | |
| "down_proj", | |
| ], | |
| lora_alpha=lora_rank * 2, | |
| use_gradient_checkpointing="unsloth", | |
| random_state=3407, | |
| ) | |
| if hasattr(model_api, "for_training"): | |
| model_api.for_training(model) | |
| print("LoRA adapter ready and model switched to training mode.") | |
| grpo_config_values = { | |
| "temperature": 1.0, | |
| "learning_rate": learning_rate, | |
| "weight_decay": 0.001, | |
| "warmup_ratio": 0.1, | |
| "lr_scheduler_type": "linear", | |
| "optim": "adamw_8bit", | |
| "logging_steps": 1, | |
| "per_device_train_batch_size": per_device_train_batch_size, | |
| "gradient_accumulation_steps": resolved_gradient_accumulation_steps, | |
| "num_generations": num_generations, | |
| "max_prompt_length": max_seq_length, | |
| "max_completion_length": max_completion_length, | |
| "max_steps": max_steps, | |
| "save_steps": max(10, max_steps), | |
| "report_to": "trackio", | |
| "project": trackio_project, | |
| "trackio_space_id": trackio_space_id, | |
| "run_name": run_name, | |
| "output_dir": str(output_dir), | |
| "push_to_hub": push_to_hub, | |
| "hub_model_id": output_repo_id, | |
| "hub_private_repo": True, | |
| "hub_strategy": "every_save", | |
| "gradient_checkpointing": True, | |
| "gradient_checkpointing_kwargs": {"use_reentrant": False}, | |
| "use_vllm": use_vllm, | |
| "vllm_mode": "colocate", | |
| "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization, | |
| "epsilon": 0.2, | |
| "epsilon_high": 0.28, | |
| "delta": 1.5, | |
| "loss_type": "bnpo", | |
| "mask_truncated_completions": False, | |
| } | |
| grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters) | |
| skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters) | |
| if skipped_config_keys: | |
| print(f"Skipping unsupported GRPOConfig keys: {skipped_config_keys}") | |
| training_args = GRPOConfig( | |
| **{ | |
| key: value | |
| for key, value in grpo_config_values.items() | |
| if key in grpo_config_parameters | |
| } | |
| ) | |
| trainer_values = { | |
| "model": model, | |
| "processing_class": tokenizer, | |
| "reward_funcs": cybersecurity_owasp_reward, | |
| "args": training_args, | |
| "train_dataset": dataset, | |
| "environment_factory": CyberSecurityOWASPToolEnv, | |
| "callbacks": [TrackioSystemMetricsCallback()], | |
| } | |
| trainer_parameters = set(inspect.signature(GRPOTrainer).parameters) | |
| skipped_trainer_keys = sorted(set(trainer_values) - trainer_parameters) | |
| if skipped_trainer_keys: | |
| print(f"Skipping unsupported GRPOTrainer keys: {skipped_trainer_keys}") | |
| trainer = GRPOTrainer( | |
| **{ | |
| key: value | |
| for key, value in trainer_values.items() | |
| if key in trainer_parameters | |
| } | |
| ) | |
| print("Starting GRPO trainer.train().") | |
| heartbeat_stop = threading.Event() | |
| def _training_heartbeat() -> None: | |
| start_time = time.monotonic() | |
| while not heartbeat_stop.wait(30): | |
| elapsed = int(time.monotonic() - start_time) | |
| print( | |
| "Training heartbeat: still inside trainer.train() " | |
| f"after {elapsed}s. For this smoke, the slow part is usually " | |
| f"Gemma generation/backprop: {num_generations} completions " | |
| f"up to {max_completion_length} tokens, plus Trackio upload." | |
| ) | |
| heartbeat_thread = threading.Thread( | |
| target=_training_heartbeat, | |
| name="grpo-training-heartbeat", | |
| daemon=True, | |
| ) | |
| heartbeat_thread.start() | |
| try: | |
| trainer.train() | |
| finally: | |
| heartbeat_stop.set() | |
| heartbeat_thread.join(timeout=2) | |
| print("GRPO trainer.train() complete.") | |
| if push_to_hub: | |
| print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}") | |
| trainer.push_to_hub() | |
| print("Hub push complete.") | |
| else: | |
| print("Skipping Hub push for this run. Pass --push-to-hub to upload adapters.") | |
| volume.commit() | |
| cache_volume.commit() | |
| scenario_cache_volume.commit() | |
| print(f"Committed run volume: {VOLUME_NAME}") | |
| print(f"Committed model cache volume: {CACHE_VOLUME_NAME}") | |
| print(f"Committed scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}") | |
| try: | |
| trackio.finish() | |
| except RuntimeError as exc: | |
| print(f"Trackio finish skipped because the trainer already finalized it: {exc}") | |
| return { | |
| "run_name": run_name, | |
| "env_repo_id": env_repo_id, | |
| "output_repo_id": output_repo_id, | |
| "trackio_space_id": trackio_space_id, | |
| "trackio_project": trackio_project, | |
| "max_steps": max_steps, | |
| "dataset_size": dataset_size, | |
| "difficulty": difficulty, | |
| "split": split, | |
| "model_name": model_name, | |
| "initial_adapter_path": initial_adapter_path, | |
| "initial_adapter_repo_id": initial_adapter_repo_id, | |
| "max_completion_length": max_completion_length, | |
| "num_generations": num_generations, | |
| "per_device_train_batch_size": per_device_train_batch_size, | |
| "gradient_accumulation_steps": resolved_gradient_accumulation_steps, | |
| "learning_rate": learning_rate, | |
| "effective_train_batch_size": effective_train_batch_size, | |
| "use_vllm": int(bool(use_vllm)), | |
| "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization, | |
| "trace_log_every": trace_log_every, | |
| "source_mode": source_mode, | |
| "repo_url": repo_url, | |
| "repo_branch": repo_branch, | |
| "push_to_hub": push_to_hub, | |
| "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME, | |
| "scenario_cache_mode": "require", | |
| "reward_variant": reward_tracking_config["reward_variant"], | |
| **reward_tracking_config, | |
| } | |
| def main( | |
| mode: str = "train", | |
| env_repo_id: str = "", | |
| output_repo_id: str = "", | |
| initial_adapter_path: str = "", | |
| initial_adapter_repo_id: str = "", | |
| max_steps: int = 10, | |
| dataset_size: int = 16, | |
| difficulty: int = 0, | |
| split: str = "train", | |
| model_name: str = DEFAULT_GEMMA_MODEL, | |
| max_seq_length: int = 4096, | |
| max_completion_length: int = 768, | |
| lora_rank: int = 32, | |
| trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio", | |
| trackio_project: str = "CyberSecurity_OWASP-grpo", | |
| num_generations: int = 6, | |
| per_device_train_batch_size: int = 1, | |
| gradient_accumulation_steps: int = 0, | |
| learning_rate: float = 5e-6, | |
| use_vllm: bool = False, | |
| vllm_gpu_memory_utilization: float = 0.2, | |
| trace_log_every: int = 5, | |
| seed_start: int = 0, | |
| git_sha: str = "nogit", | |
| run_name: str = "", | |
| source_mode: str = "local", | |
| repo_url: str = PUBLIC_REPO_URL, | |
| repo_branch: str = PUBLIC_REPO_BRANCH, | |
| detach: bool = False, | |
| push_to_hub: bool = False, | |
| reward_config: str = "", | |
| reward_variant: str = "", | |
| cache_seed_start: int = 0, | |
| cache_difficulty_buckets: int = 0, | |
| cache_train_per_bucket: int = 0, | |
| cache_validation_per_bucket: int = 0, | |
| cache_heldout_per_bucket: int = 0, | |
| cache_force: bool = False, | |
| ) -> None: | |
| model_name = _ensure_gemma4_model(model_name) | |
| if mode == "prepare-cache": | |
| result = prepare_modal_scenario_cache.remote( | |
| seed_start=cache_seed_start, | |
| difficulty_buckets=cache_difficulty_buckets, | |
| train_per_bucket=cache_train_per_bucket, | |
| validation_per_bucket=cache_validation_per_bucket, | |
| heldout_per_bucket=cache_heldout_per_bucket, | |
| force=cache_force, | |
| ) | |
| print(f"Prepared scenario cache: {result}") | |
| return | |
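| # Example prepare-cache invocation (a sketch; flag names assume Modal's usual | |
| # kebab-case conversion of the parameters above, and the bucket sizes are placeholders): | |
| #   uv run --extra modal modal run scripts/modal_train_grpo.py --mode prepare-cache \ | |
| #     --cache-difficulty-buckets 3 --cache-train-per-bucket 8 --cache-validation-per-bucket 2 | |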
| if mode == "config": | |
| result = check_training_imports.remote() | |
| print(result) | |
| return | |
| if mode == "baseline": | |
| if int(num_generations) != 1: | |
| raise ValueError("baseline mode expects --num-generations 1.") | |
| trace_log_every = max(0, int(trace_log_every)) | |
| run_name = run_name or "baseline" | |
| preflight = verify_modal_scenario_cache_for_training.remote( | |
| split=split, | |
| difficulty=difficulty, | |
| dataset_size=dataset_size, | |
| seed_start=seed_start, | |
| ) | |
| print(f"CPU scenario cache preflight passed: {preflight}") | |
| kwargs = dict( | |
| max_steps=max_steps, | |
| dataset_size=dataset_size, | |
| difficulty=difficulty, | |
| split=split, | |
| model_name=model_name, | |
| max_seq_length=max_seq_length, | |
| max_completion_length=max_completion_length, | |
| trackio_space_id=trackio_space_id, | |
| trackio_project=trackio_project, | |
| num_generations=num_generations, | |
| trace_log_every=trace_log_every, | |
| seed_start=seed_start, | |
| git_sha=git_sha, | |
| run_name=run_name, | |
| source_mode=source_mode, | |
| repo_url=repo_url, | |
| repo_branch=repo_branch, | |
| reward_config=reward_config, | |
| reward_variant=reward_variant, | |
| ) | |
| if detach: | |
| call = run_cybersecurity_owasp_baseline.spawn(**kwargs) | |
| print(f"Spawned Modal baseline call: {call.object_id}") | |
| else: | |
| result = run_cybersecurity_owasp_baseline.remote(**kwargs) | |
| print(f"Baseline result: {result}") | |
| return | |
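| # Example baseline invocation (a sketch; baseline mode requires --num-generations 1): | |
| #   uv run --extra modal modal run scripts/modal_train_grpo.py --mode baseline \ | |
| #     --num-generations 1 --dataset-size 16 --difficulty 0 | |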
| if mode != "train": | |
| raise ValueError("mode must be 'prepare-cache', 'train', 'baseline', or 'config'") | |
| ( | |
| resolved_gradient_accumulation_steps, | |
| effective_train_batch_size, | |
| ) = _resolve_grpo_batch_config( | |
| per_device_train_batch_size=per_device_train_batch_size, | |
| gradient_accumulation_steps=gradient_accumulation_steps, | |
| num_generations=num_generations, | |
| world_size=1, | |
| ) | |
| _validate_vllm_config( | |
| use_vllm=use_vllm, | |
| vllm_gpu_memory_utilization=vllm_gpu_memory_utilization, | |
| ) | |
| trace_log_every = max(0, int(trace_log_every)) | |
| trackio_space_id = trackio_space_id or os.environ.get( | |
| "TRACKIO_SPACE_ID", | |
| "Humanlearning/CyberSecurity_OWASP-trackio", | |
| ) | |
| trackio_project = trackio_project or os.environ.get( | |
| "TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo" | |
| ) | |
| resolved_trackio_space_id = trackio_space_id | |
| resolved_output_repo_id = output_repo_id | |
| if not resolved_trackio_space_id or not resolved_output_repo_id: | |
| hf_token = os.environ.get("HF_TOKEN") | |
| if hf_token: | |
| try: | |
| from huggingface_hub import whoami | |
| user = whoami(token=hf_token)["name"] | |
| if not resolved_trackio_space_id: | |
| resolved_trackio_space_id = ( | |
| f"{user}/CyberSecurity_OWASP-trackio" | |
| if user == "humandotlearning" | |
| else "Humanlearning/CyberSecurity_OWASP-trackio" | |
| ) | |
| resolved_output_repo_id = ( | |
| resolved_output_repo_id | |
| or f"{user}/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}" | |
| ) | |
| except Exception as exc: | |
| print(f"Could not resolve Hugging Face defaults locally: {exc!r}") | |
| if git_sha == "nogit": | |
| try: | |
| git_sha = subprocess.check_output( | |
| [ | |
| "git", | |
| "-c", | |
| f"safe.directory={PROJECT_ROOT.as_posix()}", | |
| "rev-parse", | |
| "HEAD", | |
| ], | |
| cwd=PROJECT_ROOT, | |
| text=True, | |
| stderr=subprocess.DEVNULL, | |
| ).strip() | |
| except Exception: | |
| git_sha = "nogit" | |
| model_slug = model_name.replace("/", "-") | |
| local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") | |
| variant_tag = reward_variant or "default" | |
| algo_tag = _grpo_run_algo_tag( | |
| initial_adapter_path=initial_adapter_path, | |
| initial_adapter_repo_id=initial_adapter_repo_id, | |
| ) | |
| run_name = run_name or ( | |
| f"CyberSecurity_OWASP-{model_slug}-{algo_tag}-level{difficulty}-" | |
| f"{variant_tag}-steps{max_steps}-seed{seed_start}-{local_stamp}-{git_sha[:8]}" | |
| ) | |
| print(f"Run name: {run_name}") | |
| print(f"Reward variant: {variant_tag}") | |
| print(f"Reward config path: {reward_config or '(default training/configs/grpo_small.yaml)'}") | |
| print(f"Source mode: {source_mode}") | |
| if source_mode == "public": | |
| print(f"Public repo: {repo_url}@{repo_branch}") | |
| if resolved_trackio_space_id: | |
| print(f"Trackio Space: https://huggingface.co/spaces/{resolved_trackio_space_id}") | |
| else: | |
| print("Trackio Space: derived remotely from HF_TOKEN as <hf-user>/CyberSecurity_OWASP-trackio") | |
| if resolved_output_repo_id: | |
| print(f"Output model repo: https://huggingface.co/{resolved_output_repo_id}") | |
| else: | |
| print( | |
| "Output model repo: derived remotely from HF_TOKEN as " | |
| f"<hf-user>/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}" | |
| ) | |
| print(f"Hub push enabled: {push_to_hub}") | |
| if initial_adapter_path: | |
| print(f"Initial SFT adapter path: {initial_adapter_path}") | |
| if initial_adapter_repo_id: | |
| print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}") | |
| print(f"Model cache volume: {CACHE_VOLUME_NAME}") | |
| print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}") | |
| print( | |
| "GRPO throughput config: " | |
| f"per_device_train_batch_size={per_device_train_batch_size}, " | |
| f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, " | |
| f"num_generations={num_generations}, " | |
| f"effective_train_batch_size={effective_train_batch_size}, " | |
| f"learning_rate={learning_rate}" | |
| ) | |
| print( | |
| "Generation acceleration config: " | |
| f"use_vllm={use_vllm}, " | |
| f"vllm_gpu_memory_utilization={vllm_gpu_memory_utilization}, " | |
| f"trace_log_every={trace_log_every}" | |
| ) | |
| print("Launch phases:") | |
| print( | |
| "1. Modal image build/validation: happens before remote Python logs; " | |
| "slow when local source or dependency layers changed." | |
| ) | |
| print("2. CPU-only scenario cache preflight in CyberSecurity_OWASP-scenario-cache.") | |
| print(f"3. GPU container start after cache preflight passes; fallback={GRPO_GPU_FALLBACK}.") | |
| print("4. Model cache check in CyberSecurity_OWASP-model-cache.") | |
| print("5. Cached snapshot load into GPU RAM with Unsloth progress.") | |
| print("6. GRPO steps, Trackio sync, and volume commit.") | |
| print( | |
| "If there is a long pause after trainer.train() starts, watch for " | |
| "Training heartbeat lines every 30 seconds." | |
| ) | |
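| # Collect the resolved configuration to forward to the remote Modal GPU training function. | |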
| kwargs = dict( | |
| env_repo_id=env_repo_id, | |
| output_repo_id=output_repo_id, | |
| initial_adapter_path=initial_adapter_path, | |
| initial_adapter_repo_id=initial_adapter_repo_id, | |
| max_steps=max_steps, | |
| dataset_size=dataset_size, | |
| difficulty=difficulty, | |
| split=split, | |
| model_name=model_name, | |
| max_seq_length=max_seq_length, | |
| max_completion_length=max_completion_length, | |
| lora_rank=lora_rank, | |
| trackio_space_id=trackio_space_id, | |
| trackio_project=trackio_project, | |
| num_generations=num_generations, | |
| per_device_train_batch_size=per_device_train_batch_size, | |
| gradient_accumulation_steps=resolved_gradient_accumulation_steps, | |
| learning_rate=learning_rate, | |
| use_vllm=use_vllm, | |
| vllm_gpu_memory_utilization=vllm_gpu_memory_utilization, | |
| trace_log_every=trace_log_every, | |
| seed_start=seed_start, | |
| git_sha=git_sha, | |
| run_name=run_name, | |
| source_mode=source_mode, | |
| repo_url=repo_url, | |
| repo_branch=repo_branch, | |
| push_to_hub=push_to_hub, | |
| reward_config=reward_config, | |
| reward_variant=reward_variant, | |
| ) | |
| preflight = verify_modal_scenario_cache_for_training.remote( | |
| split=split, | |
| difficulty=difficulty, | |
| dataset_size=dataset_size, | |
| seed_start=seed_start, | |
| ) | |
| print(f"CPU scenario cache preflight passed: {preflight}") | |
| if detach: | |
| call = train_cybersecurity_owasp_grpo.spawn(**kwargs) | |
| print(f"Spawned Modal training call: {call.object_id}") | |
| else: | |
| result = train_cybersecurity_owasp_grpo.remote(**kwargs) | |
| print(f"Training result: {result}") | |