"""Persistent Modal GRPO launcher for CyberSecurity_OWASP.

This packages the local repository into a Modal GPU image, runs a small
tool-use GRPO job against the in-process CyberSecurity_OWASP environment,
logs metrics/traces to Trackio, and saves LoRA checkpoints in a persistent
Modal volume.

Example:
    uv run --extra modal modal run scripts/modal_train_grpo.py \
        --max-steps 10 \
        --dataset-size 16 \
        --num-generations 6 \
        --difficulty 0
"""

from __future__ import annotations

import json
import os
import pathlib
import subprocess
import sys
from datetime import datetime, timezone
from typing import Any

import modal

# Modal resource identifiers: the app, the persistent volumes (run artifacts,
# model weights, pre-compiled scenarios), and the secret carrying HF_TOKEN.
APP_NAME = "CyberSecurity_OWASP-grpo"
VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
SCENARIO_CACHE_VOLUME_NAME = "CyberSecurity_OWASP-scenario-cache"
SECRET_NAME = "CyberSecurity_OWASP-secrets"

# Container mount points for the volumes above.
RUNS_DIR = pathlib.Path("/runs")
CACHE_DIR = pathlib.Path("/cache")
SCENARIO_CACHE_DIR = pathlib.Path("/scenario-cache")

# Per-tool cache subdirectories, all rooted in the persistent cache volume so
# repeated runs avoid re-downloading weights and re-compiling kernels.
HF_HOME_DIR = CACHE_DIR / "huggingface"
HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
TORCH_HOME_DIR = CACHE_DIR / "torch"
XDG_CACHE_DIR = CACHE_DIR / "xdg"
UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
TRITON_CACHE_DIR = CACHE_DIR / "triton"

# Where the project is installed inside the image, and where it lives locally.
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
PUBLIC_REPO_BRANCH = "master"

# Training is pinned to this single base model (enforced by
# _ensure_gemma4_model below).
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
GRPO_TRAINING_TIMEOUT_SECONDS = 24 * 60 * 60
# GPU types tried in order when scheduling the GPU functions.
GRPO_GPU_FALLBACK = ["L40S", "L4"]

# Guard so the image-build startup notice prints at most once per process.
_IMAGE_NOTICE_PRINTED = False


def _ensure_gemma4_model(model_name: str) -> str:
    """Return *model_name* unchanged, or raise if it is not the pinned model."""
    if model_name != DEFAULT_GEMMA_MODEL:
        raise ValueError(
            "CyberSecurity_OWASP GRPO training is pinned to "
            f"{DEFAULT_GEMMA_MODEL}, matching the Unsloth Gemma 4 E2B RL notebook. "
            f"Received {model_name!r}."
        )
    return model_name


def _model_repo_slug(model_name: str) -> str:
    """Normalize a HF model id into a lowercase, dash-only slug."""
    return (
        model_name.replace("/", "-")
        .replace("_", "-")
        .replace(".", "-")
        .lower()
    )


def _grpo_output_repo_slug(
    model_name: str,
    *,
    initial_adapter_path: str = "",
    initial_adapter_repo_id: str = "",
) -> str:
    """Build the output LoRA repo slug; SFT-warmstarted runs get a distinct tag."""
    warmstart_tag = (
        "-sft-warmstart"
        if initial_adapter_path or initial_adapter_repo_id
        else ""
    )
    return (
        f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}"
        f"{warmstart_tag}-grpo-lora"
    )


def _grpo_run_algo_tag(
    *,
    initial_adapter_path: str = "",
    initial_adapter_repo_id: str = "",
) -> str:
    """Short algorithm tag used for run naming/grouping."""
    return "sft-warmstart-grpo" if initial_adapter_path or initial_adapter_repo_id else "grpo"


def _hf_model_cache_path(model_name: str) -> pathlib.Path:
    """Expected snapshot directory for *model_name* in the HF hub cache layout."""
    return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"


def _resolve_grpo_batch_config(
    *,
    per_device_train_batch_size: int,
    gradient_accumulation_steps: int,
    num_generations: int,
    world_size: int = 1,
) -> tuple[int, int]:
    """Validate and resolve the GRPO batch shape.

    Returns ``(gradient_accumulation_steps, effective_batch_size)``.  A
    non-positive ``gradient_accumulation_steps`` requests the default of
    ``max(2, num_generations)``.  Raises ``ValueError`` when the effective
    batch size is not divisible by ``num_generations`` (a GRPO requirement).
    """
    if num_generations < 1:
        raise ValueError("--num-generations must be at least 1.")
    if per_device_train_batch_size < 1:
        raise ValueError("--per-device-train-batch-size must be at least 1.")
    if world_size < 1:
        raise ValueError("world_size must be at least 1.")
    resolved_gradient_accumulation_steps = (
        gradient_accumulation_steps
        if gradient_accumulation_steps > 0
        else max(2, num_generations)
    )
    # NOTE(review): this check looks unreachable — both branches above yield a
    # value >= 1 — kept for safety.
    if resolved_gradient_accumulation_steps < 1:
        raise ValueError("--gradient-accumulation-steps must be at least 1.")
    effective_batch_size = (
        per_device_train_batch_size
        * resolved_gradient_accumulation_steps
        * world_size
    )
    if effective_batch_size % num_generations:
        raise ValueError(
            "Invalid GRPO batch shape: "
            "per_device_train_batch_size * gradient_accumulation_steps * world_size "
            f"must be divisible by num_generations. Got "
            f"{per_device_train_batch_size} * "
            f"{resolved_gradient_accumulation_steps} * {world_size} = "
            f"{effective_batch_size}, which is not divisible by {num_generations}."
        )
    return resolved_gradient_accumulation_steps, effective_batch_size


def _validate_vllm_config(*, use_vllm: bool, vllm_gpu_memory_utilization: float) -> None:
    """Reject out-of-range vLLM GPU memory fractions when vLLM is enabled."""
    if not use_vllm:
        return
    if not 0.0 < vllm_gpu_memory_utilization <= 0.95:
        raise ValueError(
            "--vllm-gpu-memory-utilization must be in the interval (0.0, 0.95] "
            "when --use-vllm is enabled."
        )


def _extract_first_json_object(text: str) -> dict[str, Any] | None:
    """Best-effort extraction of the first JSON object embedded in *text*.

    Tries, in order: the stripped text itself, each ```-fenced block (with an
    optional leading "json" language tag), then a brace-balanced scan that
    ignores braces inside JSON string literals.  Returns ``None`` when nothing
    parses to a dict.
    """
    stripped = text.strip()
    candidates = [stripped]
    if "```" in stripped:
        for part in stripped.split("```"):
            part = part.strip()
            if part.startswith("json"):
                # Drop the fence language tag before parsing.
                part = part[4:].strip()
            candidates.append(part)
    for candidate in candidates:
        try:
            loaded = json.loads(candidate)
        except Exception:
            continue
        if isinstance(loaded, dict):
            return loaded
    # Fallback: scan for a balanced {...} span, tracking string/escape state so
    # braces inside string literals do not affect the depth counter.
    start = stripped.find("{")
    while start >= 0:
        depth = 0
        in_string = False
        escaped = False
        for index in range(start, len(stripped)):
            char = stripped[index]
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
                continue
            if char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    try:
                        loaded = json.loads(stripped[start : index + 1])
                    except Exception:
                        # Balanced but unparsable; retry from the next "{".
                        break
                    if isinstance(loaded, dict):
                        return loaded
        start = stripped.find("{", start + 1)
    return None


def _configure_modal_cache_env() -> dict[str, str]:
    """Point every cache-aware tool at the persistent /cache volume.

    Sets the relevant environment variables, creates the directories, and
    returns the mapping that was applied.
    """
    values = {
        "HF_HOME": str(HF_HOME_DIR),
        "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
        "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
        "TORCH_HOME": str(TORCH_HOME_DIR),
        "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
        "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
        "UNSLOTH_COMPILE_CACHE": str(UNSLOTH_CACHE_DIR / "compile"),
        "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
    }
    for key, value in values.items():
        os.environ[key] = value
    for path in {
        CACHE_DIR,
        HF_HOME_DIR,
        HF_HUB_CACHE_DIR,
        TORCH_HOME_DIR,
        XDG_CACHE_DIR,
        UNSLOTH_CACHE_DIR,
        UNSLOTH_CACHE_DIR / "compile",
        TRITON_CACHE_DIR,
    }:
        path.mkdir(parents=True, exist_ok=True)
    return values


def _configure_scenario_cache_env(*, required: bool = True) -> dict[str, str]:
    """Configure scenario-cache env vars; *required* selects strict cache mode."""
    values = {
        "CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR": str(SCENARIO_CACHE_DIR),
        "CYBERSECURITY_OWASP_SCENARIO_CACHE_MODE": "require" if required else "fallback",
    }
    for key, value in values.items():
        os.environ[key] = value
    SCENARIO_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    return values


def _configure_reward_env(
    *,
    reward_config: str = "",
    reward_variant: str = "",
    reward_mode: str = "",
) -> dict[str, str]:
    """Export non-empty reward overrides as env vars; return what was set."""
    values: dict[str, str] = {}
    if reward_config:
        values["CYBERSECURITY_OWASP_REWARD_CONFIG"] = reward_config
    if reward_variant:
        values["CYBERSECURITY_OWASP_REWARD_VARIANT"] = reward_variant
    if reward_mode:
        values["CYBERSECURITY_OWASP_REWARD_MODE"] = reward_mode
    for key, value in values.items():
        os.environ[key] = value
    return values


def _print_image_startup_notice() -> None:
    """Print a one-time notice explaining the (possibly slow) image build."""
    global _IMAGE_NOTICE_PRINTED
    if _IMAGE_NOTICE_PRINTED:
        return
    _IMAGE_NOTICE_PRINTED = True
    print(
        "Modal startup phase 1/5: building or validating the GPU training image. "
        "If this takes minutes, it is Modal image packaging/dependency cache work, "
        "not model-weight download."
    )
    print(
        "Later remote phases will print: cache hit/miss, snapshot_download progress, "
        "Unsloth weight loading, GRPO heartbeat, Trackio upload, and volume commits."
) def _load_local_env_file() -> None: env_path = PROJECT_ROOT / ".env.local" if not env_path.exists(): return for raw_line in env_path.read_text(encoding="utf-8").splitlines(): line = raw_line.strip() if not line or line.startswith("#") or "=" not in line: continue key, value = line.split("=", 1) key = key.strip() if key not in {"TRACKIO_PROJECT"}: continue value = value.strip().strip('"').strip("'") os.environ.setdefault(key, value) def _modal_secrets() -> list[modal.Secret]: if _is_config_mode(): return [] return [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])] def _is_config_mode() -> bool: args = sys.argv[1:] for index, arg in enumerate(args): if arg == "--mode" and index + 1 < len(args): return args[index + 1] == "config" if arg.startswith("--mode="): return arg.split("=", 1)[1] == "config" return False def _is_prepare_cache_mode() -> bool: args = sys.argv[1:] for index, arg in enumerate(args): if arg == "--mode" and index + 1 < len(args): return args[index + 1] == "prepare-cache" if arg.startswith("--mode="): return arg.split("=", 1)[1] == "prepare-cache" return False _load_local_env_file() def _cli_arg_value(name: str, default: str = "") -> str: args = sys.argv[1:] flag = f"--{name}" for index, arg in enumerate(args): if arg == flag and index + 1 < len(args): return args[index + 1] if arg.startswith(f"{flag}="): return arg.split("=", 1)[1] return default def _source_mode() -> str: return _cli_arg_value("source-mode", os.environ.get("MODAL_SOURCE_MODE", "local")) def _training_image() -> modal.Image: if _is_prepare_cache_mode(): return _scenario_cache_image() if not _is_prepare_cache_mode(): _print_image_startup_notice() image = ( modal.Image.from_registry( "nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.11", ) .apt_install("git", "build-essential", "curl") .uv_pip_install( "torch==2.10.0", "triton>=3.4.0", "torchvision==0.25.0", "bitsandbytes", "accelerate", "datasets", "huggingface_hub", "peft", "pillow", "tokenizers", "nvidia-ml-py", 
"trackio>=0.25.0", "transformers>=5.5.0", "trl>=0.28.0", "openenv-core[core]>=0.2.3", ) .uv_pip_install( "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo", "unsloth[base] @ git+https://github.com/unslothai/unsloth", ) .uv_pip_install("timm", extra_options="--no-deps") .uv_pip_install("pydantic==2.10.6") .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps") .uv_pip_install("llm-blender", "weave") .uv_pip_install("trl>=0.28.0", "transformers>=5.5.0", "jmespath") ) if _source_mode() == "public": repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL) repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH) image = image.run_commands( f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}", f"python -m pip install --no-deps -e {REMOTE_PROJECT}", ) else: image = image.add_local_dir( PROJECT_ROOT, remote_path=REMOTE_PROJECT, copy=True, ignore=[ ".git", ".venv", ".env", ".env.*", "__pycache__", ".pytest_cache", "outputs", "*.pyc", ], ) image = image.run_commands( f"python -m pip install --no-deps -e {REMOTE_PROJECT}", ) return image.run_commands( "python -c \"import os, torch; import transformers.utils.hub as hub; " "hub.TRANSFORMERS_CACHE = getattr(hub, 'TRANSFORMERS_CACHE', " "os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')); " "from trl import GRPOConfig, GRPOTrainer; " "from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import " "CybersecurityOwaspEnvironment; print('trainer import ok', torch.__version__)\"", ).workdir(REMOTE_PROJECT) def _scenario_cache_image() -> modal.Image: image = ( modal.Image.debian_slim(python_version="3.11") .apt_install("git") .uv_pip_install("openenv-core[core]>=0.2.3", "trackio>=0.25.0") ) if _source_mode() == "public": repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL) repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH) image = image.run_commands( f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}", 
f"python -m pip install --no-deps -e {REMOTE_PROJECT}", ) else: image = image.add_local_dir( PROJECT_ROOT, remote_path=REMOTE_PROJECT, copy=True, ignore=[ ".git", ".venv", ".env", ".env.*", "__pycache__", ".pytest_cache", "outputs", "*.pyc", ], ) image = image.run_commands( f"python -m pip install --no-deps -e {REMOTE_PROJECT}", ) return image.workdir(REMOTE_PROJECT) app = modal.App(APP_NAME) volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True) cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True) scenario_cache_volume = modal.Volume.from_name(SCENARIO_CACHE_VOLUME_NAME, create_if_missing=True) secrets = _modal_secrets() scenario_cache_image = _scenario_cache_image() training_image = _training_image() @app.function( image=scenario_cache_image, timeout=2 * 60 * 60, volumes={SCENARIO_CACHE_DIR: scenario_cache_volume}, ) def prepare_modal_scenario_cache( seed_start: int = 0, difficulty_buckets: int = 0, train_per_bucket: int = 0, validation_per_bucket: int = 0, heldout_per_bucket: int = 0, force: bool = False, ) -> dict[str, Any]: if difficulty_buckets: os.environ["CYBERSECURITY_OWASP_DIFFICULTY_BUCKETS"] = str(difficulty_buckets) if train_per_bucket: os.environ["CYBERSECURITY_OWASP_TRAIN_SCENARIOS_PER_BUCKET"] = str(train_per_bucket) if validation_per_bucket: os.environ["CYBERSECURITY_OWASP_VALIDATION_SCENARIOS_PER_BUCKET"] = str(validation_per_bucket) if heldout_per_bucket: os.environ["CYBERSECURITY_OWASP_HELDOUT_SCENARIOS_PER_BUCKET"] = str(heldout_per_bucket) _configure_scenario_cache_env(required=False) from CyberSecurity_OWASP.config import load_scenario_authoring_config from CyberSecurity_OWASP.server.scenario_cache import prepare_scenario_cache settings = load_scenario_authoring_config() result = prepare_scenario_cache( cache_dir=SCENARIO_CACHE_DIR, settings=settings, seed_start=seed_start, force=force, ) scenario_cache_volume.commit() result["scenario_cache_volume"] = SCENARIO_CACHE_VOLUME_NAME return result 
@app.function(
    image=scenario_cache_image,
    timeout=60 * 10,
    volumes={SCENARIO_CACHE_DIR: scenario_cache_volume},
)
def verify_modal_scenario_cache_for_training(
    split: str = "train",
    difficulty: int = 0,
    dataset_size: int = 2,
    seed_start: int = 0,
) -> dict[str, Any]:
    """Preflight: assert the scenario cache can serve training resets.

    Resolves the curriculum difficulty, checks cache coverage, and performs
    one environment reset that must be served entirely from cache (no
    on-the-fly scenario compilation).
    """
    _configure_scenario_cache_env(required=True)
    scenario_cache_volume.reload()
    from CyberSecurity_OWASP.config import load_scenario_authoring_config
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    from CyberSecurity_OWASP.server.curriculum import CurriculumController
    from CyberSecurity_OWASP.server.scenario_cache import ScenarioCache

    settings = load_scenario_authoring_config()
    scenario_profile = CurriculumController(settings=settings).select_profile(
        seed=seed_start,
        split=split,
        requested_difficulty=difficulty,
    )
    resolved_difficulty = int(scenario_profile["difficulty"])
    cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=settings)
    coverage = cache.assert_coverage(split=split, difficulty=resolved_difficulty)
    entries = cache.validated_entries(split=split, difficulty=resolved_difficulty)
    if not entries:
        # Fall back to any difficulty within the split before failing.
        entries = cache.validated_entries(split=split)
    if not entries:
        raise RuntimeError(f"No validated scenario cache entries found for split={split!r}.")
    sample_entry = entries[0]
    env = CybersecurityOwaspEnvironment()
    try:
        obs = env.reset(
            seed=int(sample_entry["seed"]),
            split=str(sample_entry["split"]),
            difficulty=int(sample_entry["difficulty"]),
        )
        if not env.state.cache_hit:
            raise RuntimeError("Scenario cache preflight reset did not hit cache.")
        if env.state.metrics.get("scenario_compile_latency_ms", 0.0):
            raise RuntimeError("Scenario cache preflight unexpectedly compiled a scenario.")
        sample = {
            "phase": obs.phase,
            "task_id": env.state.task_id,
            "cache_hit": env.state.cache_hit,
            "scenario_hash": env.state.scenario_hash,
            "reset_latency_ms": env.state.reset_latency_ms,
            "bundle_load_latency_ms": env.state.metrics.get(
                "scenario_bundle_load_latency_ms",
                0.0,
            ),
        }
    finally:
        env.close()
    return {
        "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME,
        "scenario_cache_dir": str(SCENARIO_CACHE_DIR),
        "scenario_cache_mode": "require",
        "split": split,
        "difficulty": "adaptive",
        "initial_difficulty": resolved_difficulty,
        "dataset_size": dataset_size,
        "available_scenarios": len(cache.validated_entries(split=split)),
        "coverage": coverage,
        "sample_reset": sample,
    }


@app.function(
    image=training_image,
    gpu=GRPO_GPU_FALLBACK,
    timeout=4 * 60 * 60,
    volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
    secrets=secrets,
)
def check_training_imports() -> dict[str, str]:
    """Smoke-test the GPU image: heavy imports, one env reset, cache wiring."""
    cache_env = _configure_modal_cache_env()
    scenario_cache_env = _configure_scenario_cache_env(required=False)
    import torch
    import trackio
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from unsloth import FastVisionModel

    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )

    env = CybersecurityOwaspEnvironment()
    obs = env.reset(seed=0, split="validation", difficulty=0)
    return {
        "torch": torch.__version__,
        "trackio": getattr(trackio, "__version__", "unknown"),
        "dataset": Dataset.__name__,
        "grpo_config": GRPOConfig.__name__,
        "grpo_trainer": GRPOTrainer.__name__,
        "unsloth_vision_model": FastVisionModel.__name__,
        "env": CybersecurityOwaspEnvironment.__name__,
        "reset_phase": obs.phase,
        "hf_home": cache_env["HF_HOME"],
        "hf_hub_cache": cache_env["HF_HUB_CACHE"],
        "scenario_cache_dir": scenario_cache_env["CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR"],
    }


@app.function(
    image=training_image,
    gpu=GRPO_GPU_FALLBACK,
    timeout=4 * 60 * 60,
    volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
    secrets=secrets,
)
def run_cybersecurity_owasp_baseline(
    max_steps: int = 50,
    dataset_size: int = 1,
    difficulty: int = 0,
    split: str = "train",
    model_name: str = DEFAULT_GEMMA_MODEL,
    max_seq_length: int = 4096,
    max_completion_length: int = 768,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-grpo",
    num_generations: int = 1,
    trace_log_every: int = 1,
    seed_start: int = 0,
    git_sha: str = "nogit",
    run_name: str = "baseline",
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    reward_config: str = "",
    reward_variant: str = "",
) -> dict[str, str | int | float]:
    """Roll out the untrained pinned model and log baseline metrics/traces.

    Runs greedy single-generation episodes against cached scenarios, logs
    per-episode metrics and traces to Trackio, and writes a JSON artifact of
    all rollouts to the persistent runs volume.
    """
    import statistics
    import time

    import torch
    from huggingface_hub import snapshot_download, whoami
    from unsloth import FastVisionModel
    import transformers.utils.hub as transformers_hub

    from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
    from CyberSecurity_OWASP.config import load_scenario_authoring_config
    from CyberSecurity_OWASP.reward_config import load_reward_settings
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    from CyberSecurity_OWASP.server.curriculum import CurriculumController
    from CyberSecurity_OWASP.server.scenario_cache import ScenarioCache
    from training.trackio_utils import (
        aggregate_episode_metrics,
        episode_record_from_state,
        log_reward_config,
        log_trace_table,
        log_trackio_metrics,
        reward_config_trackio_config,
        trackio_run,
    )

    model_name = _ensure_gemma4_model(model_name)
    if int(num_generations) != 1:
        raise ValueError("Baseline mode runs the untrained model with --num-generations 1.")
    cache_env = _configure_modal_cache_env()
    scenario_cache_env = _configure_scenario_cache_env(required=True)
    transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]
    # Validate the HF token before any long-running work.
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")
    try:
        whoami(token=hf_token)
    except Exception as exc:
        raise RuntimeError("HF_TOKEN could not be validated before baseline run.") from exc
    os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
    os.environ["TRACKIO_PROJECT"] = trackio_project
    reward_env = _configure_reward_env(
        reward_config=reward_config,
        reward_variant=reward_variant,
    )
    reward_settings = load_reward_settings()
    reward_tracking_config = reward_config_trackio_config(reward_settings)
    reward_tracking_config["reward_variant"] = reward_variant or "default"
    reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path
    run_name = run_name or "baseline"
    output_dir = RUNS_DIR / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    # Volume reloads are best-effort: a failure only means stale cache reads.
    try:
        cache_volume.reload()
        print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}")
    except Exception as exc:
        print(f"Model cache volume reload skipped: {exc!r}")
    try:
        scenario_cache_volume.reload()
        print(f"Reloaded Modal scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
    except Exception as exc:
        print(f"Scenario cache volume reload skipped: {exc!r}")
    settings = load_scenario_authoring_config()
    scenario_profile = CurriculumController(settings=settings).select_profile(
        seed=seed_start,
        split=split,
        requested_difficulty=difficulty,
    )
    resolved_difficulty = int(scenario_profile["difficulty"])
    scenario_cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=settings)
    coverage = scenario_cache.assert_coverage(
        split=split,
        difficulty=resolved_difficulty,
    )
    # Prefer exact-difficulty entries; fall back to any difficulty in split.
    entries = scenario_cache.validated_entries(
        split=split,
        difficulty=resolved_difficulty,
    ) or scenario_cache.validated_entries(split=split)
    if not entries:
        raise RuntimeError(f"No validated scenario cache entries found for split={split!r}.")
    print(f"Baseline run name: {run_name}")
    print(f"Source mode: {source_mode}")
    if source_mode == "public":
        print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")
    else:
        print("Packaged local CyberSecurity_OWASP repo.")
    print(f"Trackio Space: {trackio_space_id}")
    print(f"Trackio Project: {trackio_project}")
    print(f"Reward config: {reward_tracking_config['reward_config_id']}")
    print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
    print(f"Reward variant: {reward_tracking_config['reward_variant']}")
    print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
    if reward_env:
        print(f"Reward env overrides: {reward_env}")
    print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
    print(f"Scenario cache coverage: {coverage}")
    print(
        "Baseline generation config: "
        f"episodes={dataset_size}, max_episode_steps={max_steps}, "
        f"num_generations={num_generations}, max_completion_length={max_completion_length}, "
        f"trace_log_every={trace_log_every}"
    )
    expected_model_cache = _hf_model_cache_path(model_name)
    print(f"Expected HF model cache path: {expected_model_cache}")
    print(f"Model cache hit before load: {expected_model_cache.exists()}")
    # Prefetch weights into the persistent cache; fall back to lazy loading.
    try:
        snapshot_path = snapshot_download(
            repo_id=model_name,
            cache_dir=str(HF_HUB_CACHE_DIR),
            token=hf_token,
        )
        print(f"Model snapshot ready: {snapshot_path}")
        cache_volume.commit()
    except Exception as exc:
        print(f"Explicit model snapshot prefetch failed; loading directly. Error: {exc!r}")
    model_api = FastVisionModel
    model, tokenizer = model_api.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        fast_inference=False,
        cache_dir=str(HF_HUB_CACHE_DIR),
        token=hf_token,
    )
    if hasattr(model_api, "for_inference"):
        model_api.for_inference(model)
    model.eval()
    cache_volume.commit()
    device = next(model.parameters()).device
    # Unsloth may wrap the tokenizer; the inner one carries decode/pad ids.
    text_tokenizer = getattr(tokenizer, "tokenizer", tokenizer)

    def render_prompt(observation, actions: list[dict[str, Any]]) -> str:
        """Render the single-turn prompt; only the last 8 actions are replayed."""
        recent_actions = actions[-8:]
        return (
            "You are the untrained baseline model for a defensive local AppSec "
            "repair environment. Use only the listed local tools. Return exactly "
            "one JSON object and no markdown.\n\n"
            f"{observation.scenario_prompt}\n\n"
            f"Current phase: {observation.phase}\n"
            f"Available actions: {observation.available_actions}\n"
            f"Last tool result: {observation.last_tool_result}\n"
            f"Recent actions: {json.dumps(recent_actions, sort_keys=True)}\n\n"
            'Required format: {"tool_name":"inspect_policy_graph","arguments":{}}'
        )

    def generate_action_text(prompt: str) -> tuple[str, list[int], list[int]]:
        """Greedy-generate one completion.

        Returns (decoded completion, prompt token ids, completion token ids).
        Tries the outer then inner tokenizer for chat templating/encoding.
        """
        messages = [{"role": "user", "content": prompt}]
        prompt_text = prompt
        for candidate in (tokenizer, text_tokenizer):
            if hasattr(candidate, "apply_chat_template"):
                try:
                    prompt_text = candidate.apply_chat_template(
                        messages,
                        tokenize=False,
                        add_generation_prompt=True,
                    )
                    break
                except Exception:
                    # Fall back to the raw prompt if templating fails.
                    prompt_text = prompt
        encode = tokenizer
        try:
            inputs = encode(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_seq_length,
            )
        except Exception:
            inputs = text_tokenizer(
                prompt_text,
                return_tensors="pt",
                truncation=True,
                max_length=max_seq_length,
            )
        if hasattr(inputs, "to"):
            inputs = inputs.to(device)
        else:
            inputs = {
                key: value.to(device) if hasattr(value, "to") else value
                for key, value in inputs.items()
            }
        input_ids = inputs.get("input_ids")
        input_len = int(input_ids.shape[-1]) if input_ids is not None else 0
        pad_token_id = getattr(text_tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = getattr(text_tokenizer, "eos_token_id", None)
        with torch.inference_mode():
            generated = model.generate(
                **inputs,
                max_new_tokens=max_completion_length,
                do_sample=False,
                pad_token_id=pad_token_id,
            )
        output_ids = generated[0]
        # Strip the echoed prompt; everything past input_len is the completion.
        completion_ids = output_ids[input_len:]
        decode = getattr(text_tokenizer, "decode", None) or getattr(tokenizer, "decode")
        text = decode(completion_ids, skip_special_tokens=True)
        prompt_ids = (
            [int(item) for item in input_ids[0].detach().cpu().tolist()]
            if input_ids is not None
            else []
        )
        return text, prompt_ids, [int(item) for item in completion_ids.detach().cpu().tolist()]

    def action_from_completion(raw_text: str) -> tuple[CyberSecurityOWASPAction, str | None]:
        """Map a raw completion to an action; invalid output degrades to noop."""
        loaded = _extract_first_json_object(raw_text)
        if loaded is None:
            return CyberSecurityOWASPAction(tool_name="noop", arguments={}), "invalid_json"
        arguments = loaded.get("arguments", {})
        if not isinstance(arguments, dict):
            arguments = {}
        payload = {
            "tool_name": loaded.get("tool_name", "noop"),
            "arguments": arguments,
        }
        try:
            return CyberSecurityOWASPAction(**payload), None
        except Exception as exc:
            return (
                CyberSecurityOWASPAction(tool_name="noop", arguments={}),
                f"invalid_action_schema: {exc}",
            )

    episode_records: list[dict[str, Any]] = []
    raw_traces: list[dict[str, Any]] = []
    invalid_model_outputs = 0
    generation_started = time.monotonic()
    config = {
        "base_model": model_name,
        "algo": "baseline",
        "difficulty": difficulty,
        "split": split,
        "max_episode_steps": max_steps,
        "dataset_size": dataset_size,
        "num_generations": num_generations,
        "max_completion_length": max_completion_length,
        "git_sha": git_sha,
        "reward_variant": reward_tracking_config["reward_variant"],
        **reward_tracking_config,
    }
    with trackio_run(
        run_name=run_name,
        run_type="baseline",
        config=config,
        project=trackio_project,
        space_id=trackio_space_id,
        group="baseline",
        auto_log_gpu=True,
    ):
        log_reward_config(reward_settings, step=0)
        # At least one episode runs even when dataset_size <= 0.
        for episode_index in range(max(1, int(dataset_size))):
            entry = entries[(seed_start + episode_index) % len(entries)]
            env = CybersecurityOwaspEnvironment()
            try:
                observation = env.reset(
                    seed=int(entry["seed"]),
                    split=str(entry["split"]),
                    difficulty=int(entry["difficulty"]),
                )
                env.state.max_steps = int(max_steps)
                actions: list[dict[str, Any]] = []
                model_steps: list[dict[str, Any]] = []
                prompt_token_count = 0
                completion_token_count = 0
                for step_index in range(int(max_steps)):
                    if observation.done:
                        break
                    prompt = render_prompt(observation, actions)
                    raw_text, prompt_ids, completion_ids = generate_action_text(prompt)
                    prompt_token_count += len(prompt_ids)
                    completion_token_count += len(completion_ids)
                    action, invalid_reason = action_from_completion(raw_text)
                    if invalid_reason:
                        invalid_model_outputs += 1
                    observation = env.step(action)
                    action_dump = action.model_dump()
                    actions.append(action_dump)
                    model_steps.append(
                        {
                            "step": step_index + 1,
                            "raw_completion": raw_text,
                            "action": action_dump,
                            "invalid_model_output": invalid_reason,
                            "observation_message": observation.message,
                            "reward": observation.reward,
                            "done": observation.done,
                        }
                    )
                # Token counts are stored back on the state so they land in
                # the episode record.
                env.state.completion_tokens = completion_token_count
                env.state.metrics["prompt_tokens"] = prompt_token_count
                env.state.metrics["completion_tokens"] = completion_token_count
                final_observation = observation.model_dump()
                record = episode_record_from_state(
                    env.state,
                    run_context={
                        "base_model": model_name,
                        "algo": "baseline",
                        "reward_version": "reward_v2",
                        "env_version": "0.1.0",
                        **reward_tracking_config,
                    },
                    final_observation=final_observation,
                )
                record.update(
                    {
                        "reward_total": float(env.state.accumulated_reward),
                        "success": bool(env.state.success),
                        "episode_length": int(env.state.step_count),
                        "invalid_model_output_count": sum(
                            1 for item in model_steps if item["invalid_model_output"]
                        ),
                        "prompt_tokens": prompt_token_count,
                        "completion_tokens": completion_token_count,
                    }
                )
                episode_records.append(record)
                raw_traces.append(
                    {
                        "episode_index": episode_index,
                        "task_id": env.state.task_id,
                        "seed": env.state.seed,
                        "split": env.state.split,
                        "difficulty": env.state.difficulty,
                        "domain": env.state.domain,
                        "bug_family": env.state.bug_family,
                        "steps": model_steps,
                    }
                )
            finally:
                env.close()
            # Aggregate over all episodes so far and log once per episode.
            metrics = aggregate_episode_metrics(episode_records)
            metrics.update(
                {
                    "baseline/episode_count": float(len(episode_records)),
                    "baseline/reward_total_mean": statistics.mean(
                        float(item.get("reward_total", 0.0)) for item in episode_records
                    ),
                    "baseline/success_rate": statistics.mean(
                        1.0 if item.get("success") else 0.0 for item in episode_records
                    ),
                    "baseline/invalid_model_output_rate": invalid_model_outputs
                    / max(1.0, sum(float(item.get("episode_length", 0)) for item in episode_records)),
                    "baseline/num_generations": float(num_generations),
                    "baseline/max_episode_steps": float(max_steps),
                    "baseline/max_completion_length": float(max_completion_length),
                }
            )
            log_trackio_metrics(metrics, step=episode_index + 1)
            if trace_log_every > 0 and (
                episode_index == 0 or (episode_index + 1) % trace_log_every == 0
            ):
                log_trace_table(
                    [episode_records[-1]],
                    table_name="baseline_traces",
                    step=episode_index + 1,
                )
    elapsed_s = time.monotonic() - generation_started
    summary = {
        "run_name": run_name,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
        "model_name": model_name,
        "dataset_size": len(episode_records),
        "max_episode_steps": int(max_steps),
        "difficulty": int(difficulty),
        "split": split,
        "num_generations": int(num_generations),
        "max_completion_length": int(max_completion_length),
        "mean_reward": (
            statistics.mean(float(item.get("reward_total", 0.0)) for item in episode_records)
            if episode_records
            else 0.0
        ),
        "success_rate": (
            statistics.mean(1.0 if item.get("success") else 0.0 for item in episode_records)
            if episode_records
            else 0.0
        ),
        "invalid_model_output_count": int(invalid_model_outputs),
        "elapsed_s": elapsed_s,
        **reward_tracking_config,
    }
    artifact_path = output_dir / "baseline_rollouts.json"
    artifact_path.write_text(
        json.dumps(
            {
                "summary": summary,
                "episodes": episode_records,
                "raw_traces": raw_traces,
            },
            indent=2,
            sort_keys=True,
            default=str,
        ),
        encoding="utf-8",
    )
    # Commit all volumes so artifacts/caches survive the container.
    volume.commit()
    cache_volume.commit()
    scenario_cache_volume.commit()
    print(f"Baseline artifact saved to {artifact_path}")
    return {**summary, "artifact_path": str(artifact_path)}


@app.function(
    image=training_image,
    gpu=GRPO_GPU_FALLBACK,
    timeout=GRPO_TRAINING_TIMEOUT_SECONDS,
    volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
    secrets=secrets,
)
def train_cybersecurity_owasp_grpo(
    env_repo_id: str = "",
    output_repo_id: str = "",
    initial_adapter_path: str = "",
    initial_adapter_repo_id: str = "",
    max_steps:
int = 10,
    dataset_size: int = 16,
    difficulty: int = 0,
    split: str = "train",
    model_name: str = DEFAULT_GEMMA_MODEL,
    max_seq_length: int = 4096,
    max_completion_length: int = 768,
    lora_rank: int = 32,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-grpo",
    num_generations: int = 6,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 0,
    learning_rate: float = 5e-6,
    use_vllm: bool = False,
    vllm_gpu_memory_utilization: float = 0.2,
    trace_log_every: int = 5,
    seed_start: int = 0,
    git_sha: str = "nogit",
    run_name: str = "",
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    push_to_hub: bool = False,
    reward_config: str = "",
    reward_variant: str = "",
) -> dict[str, str | int | float]:
    """Run a GRPO LoRA training job inside the Modal GPU container.

    Validates the pinned Gemma model and the GRPO batch shape, loads the
    model via Unsloth (with a persistent HF cache volume), attaches or
    warm-starts a LoRA adapter, and trains with TRL's GRPOTrainer against
    the in-process CyberSecurity_OWASP environment, logging to Trackio.
    Returns a flat summary dict of the resolved run configuration.
    """
    # Stdlib imports kept function-local; this body only runs in the
    # remote container.
    import inspect
    import statistics
    import threading
    import time

    model_name = _ensure_gemma4_model(model_name)
    cache_env = _configure_modal_cache_env()
    world_size = int(os.environ.get("WORLD_SIZE", "1") or "1")
    # Resolve gradient accumulation (0 means "auto") and the effective
    # batch size; raises if the shape is not divisible by num_generations.
    (
        resolved_gradient_accumulation_steps,
        effective_train_batch_size,
    ) = _resolve_grpo_batch_config(
        per_device_train_batch_size=per_device_train_batch_size,
        gradient_accumulation_steps=gradient_accumulation_steps,
        num_generations=num_generations,
        world_size=world_size,
    )
    _validate_vllm_config(
        use_vllm=use_vllm,
        vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
    )
    trace_log_every = max(0, int(trace_log_every))

    # Heavy ML imports are deferred until after cheap validation so bad
    # arguments fail fast without paying the import cost.
    import torch
    from safetensors.torch import load_file as load_safetensors_file
    from unsloth import FastVisionModel
    import transformers.utils.hub as transformers_hub
    from datasets import Dataset
    from huggingface_hub import snapshot_download, whoami
    from peft import set_peft_model_state_dict
    from transformers import TrainerCallback
    # NOTE(review): clone_chat_template is imported but not used in the
    # visible body of this function.
    from trl import GRPOConfig, GRPOTrainer, clone_chat_template

    # Older TRL versions lack add_response_schema; fall back to a no-op.
    try:
        from trl.chat_template_utils import add_response_schema
    except ImportError:

        def add_response_schema(tokenizer):
            return tokenizer

    import trackio
    from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
    from CyberSecurity_OWASP.config import load_scenario_authoring_config
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    from CyberSecurity_OWASP.reward_config import (
        compute_token_penalty,
        load_reward_settings,
    )
    from CyberSecurity_OWASP.server.curriculum import CurriculumController
    from CyberSecurity_OWASP.server.scenario_cache import ScenarioCache
    from training.trackio_utils import (
        aggregate_episode_metrics,
        episode_record_from_state,
        episode_trace_fingerprint,
        log_reward_config,
        log_gpu_metrics,
        log_trace_table,
        log_trackio_metrics,
        reward_config_trackio_config,
        train_metric_aliases,
    )
    from training.grpo_curriculum import (
        ScenarioGroupRegistry,
        build_scenario_group_rows,
    )

    # Point transformers' legacy cache attribute at the persistent volume.
    transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(
            f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}."
        )
    # Derive default env/output repo ids from the token owner's namespace.
    user = whoami(token=hf_token)["name"]
    env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
    output_repo_id = output_repo_id or (
        f"{user}/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}"
    )
    if not trackio_space_id:
        trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio"
        # NOTE(review): this re-import and whoami call duplicate the lookup
        # above; it appears copied from the local entrypoint's resolver.
        if hf_token:
            try:
                from huggingface_hub import whoami

                user = whoami(token=hf_token)["name"]
                if user == "humandotlearning":
                    trackio_space_id = f"{user}/CyberSecurity_OWASP-trackio"
            except Exception:
                pass
    # Expose the Trackio destination to library code via the environment.
    os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
    os.environ["TRACKIO_PROJECT"] = trackio_project
    # Configure dense training rewards and capture the config for logging.
    reward_env = _configure_reward_env(
        reward_config=reward_config,
        reward_variant=reward_variant,
        reward_mode="dense_train",
    )
    reward_settings = load_reward_settings()
    reward_tracking_config = reward_config_trackio_config(reward_settings)
    reward_tracking_config["reward_variant"] = reward_variant or "default"
    reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path
    # Build a unique, descriptive run name when the caller did not supply one.
    model_slug = model_name.replace("/", "-")
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    algo_tag = _grpo_run_algo_tag(
        initial_adapter_path=initial_adapter_path,
        initial_adapter_repo_id=initial_adapter_repo_id,
    )
    run_name = run_name or (
        f"CyberSecurity_OWASP-{model_slug}-{algo_tag}-level{difficulty}-"
        f"{reward_tracking_config['reward_variant']}-steps{max_steps}-seed{seed_start}-"
        f"{stamp}-{git_sha[:8]}"
    )
    output_dir = RUNS_DIR / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    # Best-effort volume reloads so this container sees cache content
    # committed by earlier functions (e.g. prepare-cache).
    try:
        cache_volume.reload()
        print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}")
    except Exception as exc:
        print(f"Model cache volume reload skipped: {exc!r}")
    try:
        scenario_cache_volume.reload()
        print(f"Reloaded Modal scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
    except Exception as exc:
        print(f"Scenario cache volume reload skipped: {exc!r}")
    # NOTE(review): second call to _configure_modal_cache_env() in this
    # function; presumably idempotent env setup — confirm before removing.
    cache_env = _configure_modal_cache_env()
    scenario_cache_env = _configure_scenario_cache_env(required=True)
    # Scenario curriculum: pick a difficulty profile, then require that the
    # pre-built scenario cache covers it (assert_coverage raises otherwise).
    scenario_settings = load_scenario_authoring_config()
    scenario_profile = CurriculumController(settings=scenario_settings).select_profile(
        seed=seed_start,
        split=split,
        requested_difficulty=difficulty,
    )
    scenario_cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=scenario_settings)
    scenario_cache_coverage = scenario_cache.assert_coverage(
        split=split,
        difficulty=int(scenario_profile["difficulty"]),
    )
    scenario_entries = scenario_cache.validated_entries(split=split)
    scenario_registry = ScenarioGroupRegistry(
        scenario_entries,
        split=split,
        initial_difficulty=int(scenario_profile["difficulty"]),
        rng_seed=seed_start,
        max_level=scenario_settings.curriculum.difficulty_bucket_count - 1,
    )
    # System-style instruction prepended to every episode prompt.
    training_prompt = (
        "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
        "OpenEnv environment. Use only the provided local tools. Do not target real "
        "systems. Work step by step: inspect policy and generated code, reproduce the "
        "authorization issue locally, submit a policy-tied diagnosis, patch the generated "
        "app, run visible tests, then submit the fix. Do not write explanations unless "
        "a tool argument needs evidence text."
    )
    # One dataset row per scenario group; the registry assigns concrete
    # seeds/difficulties at reset time ("adaptive" policy).
    dataset = Dataset.from_list(
        build_scenario_group_rows(
            dataset_size=dataset_size,
            training_prompt=training_prompt,
            seed_start=seed_start,
            split=split,
            difficulty=difficulty,
            difficulty_policy="adaptive",
        )
    )

    def _state_snapshot(env: CybersecurityOwaspEnvironment) -> dict[str, Any]:
        """Copy the environment's episode state into a plain dict for trace metadata."""
        state = env.state
        return {
            "episode_id": state.episode_id,
            "task_id": state.task_id,
            "seed": state.seed,
            "split": state.split,
            "difficulty": state.difficulty,
            "difficulty_tier": state.difficulty_tier,
            "domain": state.domain,
            "bug_family": state.bug_family,
            "template_id": state.template_id,
            "cache_hit": state.cache_hit,
            "scenario_hash": state.scenario_hash,
            "phase": state.phase,
            "step_count": state.step_count,
            "done": state.done,
            "success": state.success,
            "failure_reason": state.failure_reason,
            # Copy so later mutation of the state list can't alias the snapshot.
            "anti_cheat_flags": list(state.anti_cheat_flags),
        }

    class CyberSecurityOWASPToolEnv:
        """Tool-calling adapter around CybersecurityOwaspEnvironment.

        Each public method is exposed to the model as a tool; every call is
        routed through _step, which accumulates reward/trace bookkeeping.
        """

        def __init__(self):
            self._env = CybersecurityOwaspEnvironment()
            self.reward = 0.0
            self.reward_breakdown: dict[str, float] = {}
            self.done = False
            self.success = False
            self.invalid_actions = 0
            # -1 means "no scenario group assigned yet".
            self.scenario_group_id = -1
            self.scenario_assignment: dict[str, Any] = {}
            self.trace_messages: list[dict[str, str]] = []
            self.trace_metadata: dict[str, Any] = {}

        def reset(self, **kwargs) -> str:
            # Resolve the group's concrete seed/difficulty/split through the
            # shared registry; kwargs fall back to the launcher defaults.
            group_id = int(kwargs.get("scenario_group_id", kwargs.get("seed", seed_start)))
            assignment = scenario_registry.assignment_for(
                scenario_group_id=group_id,
                requested_seed=int(kwargs.get("seed", seed_start)),
                requested_difficulty=int(kwargs.get("difficulty", difficulty)),
                split=str(kwargs.get("split", split)),
                difficulty_policy=str(kwargs.get("difficulty_policy", "adaptive")),
            )
            seed = int(assignment["seed"])
            current_difficulty = int(assignment["difficulty"])
            current_split = str(assignment["split"])
            obs = self._env.reset(
                seed=seed,
                split=current_split,
                difficulty=current_difficulty,
            )
            self.scenario_group_id = group_id
            self.scenario_assignment = assignment
            self.reward = 0.0
            self.reward_breakdown = {}
            self.done = bool(obs.done)
            self.success = False
            self.invalid_actions = 0
            # Seed the trace with the full prompt the model effectively sees.
            self.trace_messages = [
                {
                    "role": "user",
                    "content": (
                        f"{training_prompt}\n\n"
                        f"{obs.scenario_prompt}\n\n"
                        f"Initial message: {obs.message}"
                    ),
                }
            ]
            self.trace_metadata = _state_snapshot(self._env)
            self.trace_metadata.update(
                {
                    "scenario_group_id": self.scenario_group_id,
                    "scenario_assignment": dict(self.scenario_assignment),
                    "scenario_prompt_length": len(obs.scenario_prompt),
                    "reward_config_id": reward_tracking_config["reward_config_id"],
                    "reward_config_hash": reward_tracking_config["reward_config_hash"],
                    "reward_stage": reward_tracking_config["reward_stage"],
                    "reward_mode": reward_tracking_config["reward_mode"],
                    "reward_variant": reward_tracking_config["reward_variant"],
                }
            )
            return obs.scenario_prompt

        def _step(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
            """Execute one tool action and refresh reward/trace bookkeeping.

            Raises ValueError if called after the episode has ended.
            Returns the environment's observation message.
            """
            if self.done:
                raise ValueError("Episode is already over.")
            action = CyberSecurityOWASPAction(
                tool_name=tool_name,
                arguments=arguments or {},
            )
            obs = self._env.step(action)
            if not obs.last_action_valid:
                self.invalid_actions += 1
            # Reward mirrors the environment's running total, not a delta.
            self.reward = float(self._env.state.accumulated_reward)
            self.reward_breakdown = dict(obs.reward_breakdown or {})
            self.done = bool(obs.done)
            self.success = bool(self._env.state.success)
            self.trace_messages.extend(
                [
                    {
                        "role": "assistant",
                        "content": f"{tool_name}({arguments or {}})",
                    },
                    {"role": "tool", "content": obs.message},
                ]
            )
            self.trace_metadata.update(_state_snapshot(self._env))
            self.trace_metadata.update(
                {
                    "last_action_valid": obs.last_action_valid,
                    "last_action_error": obs.last_action_error,
                    "reward": self.reward,
                    "reward_breakdown": self.reward_breakdown,
                    "invalid_actions": self.invalid_actions,
                    "scenario_cache_hit": self._env.state.cache_hit,
                    "scenario_hash": self._env.state.scenario_hash,
                    "scenario_group_id": self.scenario_group_id,
                    "scenario_assignment": dict(self.scenario_assignment),
                }
            )
            return obs.message

        # --- tool methods exposed to the model; all delegate to _step ---

        def inspect_policy_graph(self) -> str:
            """Return public policy hints for the generated local scenario."""
            return self._step("inspect_policy_graph")

        def list_routes(self) -> str:
            """List generated local app route summaries."""
            return self._step("list_routes")

        def read_openapi(self) -> str:
            """Read generated OpenAPI metadata for the local app."""
            return self._step("read_openapi")

        def read_file(self, path: str) -> str:
            """
            Read an editable generated workspace file by relative path.

            Args:
                path: Relative path inside the generated editable workspace.

            Returns:
                The file contents or a safe tool error observation.
            """
            return self._step("read_file", {"path": path})

        def search_code(self, query: str) -> str:
            """
            Search editable generated workspace files for a string.

            Args:
                query: Search text to find in editable generated app files.

            Returns:
                Matching file lines or a no-match message.
            """
            return self._step("search_code", {"query": query})

        def send_local_request(
            self,
            path: str,
            method: str = "GET",
            user_id: str | None = None,
        ) -> str:
            """
            Send a request to the generated local app only.

            Args:
                path: Local route path such as /health or /invoices/.
                method: HTTP method to use for the local request.
                user_id: Optional generated user identifier for authentication.

            Returns:
                JSON response from the simulated local app request.
            """
            return self._step(
                "send_local_request",
                {"path": path, "method": method, "user_id": user_id},
            )

        def compare_identities(
            self,
            path: str,
            first_user_id: str,
            second_user_id: str,
            method: str = "GET",
        ) -> str:
            """
            Compare one local request as two generated users.

            Args:
                path: Local route path to request as both generated users.
                first_user_id: First generated user identifier.
                second_user_id: Second generated user identifier.
                method: HTTP method to use for both local requests.

            Returns:
                JSON summary of both simulated local responses.
            """
            return self._step(
                "compare_identities",
                {
                    "path": path,
                    "method": method,
                    "first_user_id": first_user_id,
                    "second_user_id": second_user_id,
                },
            )

        def submit_diagnosis(
            self,
            bug_class: str,
            route: str,
            violated_policy_rule: str,
            evidence_trace_ids: list[str],
            fix_plan: str,
        ) -> str:
            """
            Submit structured diagnosis for the suspected authorization bug.

            Args:
                bug_class: Short class such as idor_ownership_bug.
                route: Method and route pattern believed to be vulnerable.
                violated_policy_rule: Policy rule that the behavior violates.
                evidence_trace_ids: Request trace IDs from local evidence tools.
                fix_plan: Concise secure repair plan.

            Returns:
                Diagnosis acceptance result and next phase information.
            """
            return self._step(
                "submit_diagnosis",
                {
                    "bug_class": bug_class,
                    "route": route,
                    "violated_policy_rule": violated_policy_rule,
                    "evidence_trace_ids": evidence_trace_ids,
                    "fix_plan": fix_plan,
                },
            )

        def patch_file(
            self,
            path: str,
            content: str | None = None,
            diff: str | None = None,
        ) -> str:
            """
            Patch an editable generated app file with full content or a unified diff.

            Args:
                path: Relative path of the editable generated app file to patch.
                content: Complete replacement file content, when using full-file patching.
                diff: Unified diff to apply, when using diff patching.

            Returns:
                Patch application result.
            """
            # Only forward the patching mode(s) the caller actually supplied.
            args: dict[str, Any] = {"path": path}
            if content is not None:
                args["content"] = content
            if diff is not None:
                args["diff"] = diff
            return self._step("patch_file", args)

        def run_visible_tests(self) -> str:
            """Run visible tests only; hidden tests are never exposed."""
            return self._step("run_visible_tests")

        def submit_fix(self) -> str:
            """Submit the final patch to the hidden deterministic verifier."""
            return self._step("submit_fix")

        def noop(self) -> str:
            """Take no action."""
            return self._step("noop")

        def _score(self, completion_tokens: int = 0) -> float:
            """Return accumulated reward plus the length-based token penalty.

            Also records the token count and penalty on the episode state so
            they appear in episode records and traces.
            """
            token_penalty = compute_token_penalty(completion_tokens)
            self._env.state.completion_tokens = int(completion_tokens)
            self._env.state.metrics["completion_tokens"] = int(completion_tokens)
            self._env.state.metrics["token_penalty"] = token_penalty
            return float(self._env.state.accumulated_reward + token_penalty)

        def __del__(self):
            # Best-effort cleanup; never raise from a finalizer.
            try:
                self._env.close()
            except Exception:
                pass

    # Mutable cell so the reward closure can count callbacks across calls.
    trace_step = {"value": 0}
    # Fingerprints of traces already sent to Trackio (cross-batch dedupe).
    logged_trace_fingerprints: set[str] = set()

    def _completion_to_text(completion) -> str:
        """Flatten a TRL completion (str, message list, or other) into plain text."""
        if completion is None:
            return ""
        if isinstance(completion, str):
            return completion
        if isinstance(completion, list):
            parts = []
            for item in completion:
                if isinstance(item, dict):
                    parts.append(str(item.get("content", item)))
                else:
                    parts.append(str(item))
            return "\n".join(parts)
        return str(completion)

    def _mean(values: list[float]) -> float:
        """Arithmetic mean; 0.0 for an empty list."""
        return float(sum(values) / len(values)) if values else 0.0

    def cybersecurity_owasp_reward(environments, **kwargs) -> list[float]:
        """GRPO reward function: score each env and log metrics/traces.

        Token counts are approximated by whitespace word count of the
        completion text and fed into each env's _score token penalty.
        """
        completions = kwargs.get("completions") or kwargs.get("completion") or []
        completion_texts = [_completion_to_text(item) for item in completions]
        completion_tokens = [len(text.split()) for text in completion_texts]
        rewards = [
            float(env._score(completion_tokens[index] if index < len(completion_tokens) else 0))
            for index, env in enumerate(environments)
        ]
        trace_step["value"] += 1
        episode_records = []
        for index, (env, reward) in enumerate(zip(environments, rewards)):
            record = episode_record_from_state(
                env._env.state,
                run_context={
                    "base_model": model_name,
                    "algo": "grpo",
                    "reward_version": "reward_v2",
                    "env_version": "0.1.0",
                    **reward_tracking_config,
                },
            )
            record.update(
                {
                    "reward_total": reward,
                    "reward_token_penalty": float(env._env.state.metrics.get("token_penalty", 0.0)),
                    "completion_tokens": completion_tokens[index] if index < len(completion_tokens) else 0,
                    "success": bool(getattr(env, "success", False)),
                }
            )
            episode_records.append(record)
        # Feed per-group success rates back into the adaptive curriculum.
        group_successes: dict[int, list[float]] = {}
        for env in environments:
            group_id = int(getattr(env, "scenario_group_id", -1))
            if group_id < 0:
                continue
            group_successes.setdefault(group_id, []).append(1.0 if getattr(env, "success", False) else 0.0)
        for group_id, successes in group_successes.items():
            scenario_registry.record_group_outcome(group_id, _mean(successes))
        batch_fingerprints = [
            episode_trace_fingerprint(record) for record in episode_records
        ]
        # Sample up to 4 deduplicated traces per batch for artifact logging.
        sampled_traces = []
        seen_this_batch: set[str] = set()
        duplicate_trace_suppressed_count = 0
        for index, (env, record, reward, fingerprint) in enumerate(
            zip(environments, episode_records, rewards, batch_fingerprints)
        ):
            if fingerprint in seen_this_batch or fingerprint in logged_trace_fingerprints:
                duplicate_trace_suppressed_count += 1
                continue
            seen_this_batch.add(fingerprint)
            if len(sampled_traces) < 4:
                sampled_traces.append((index, env, record, reward, fingerprint))
        # Trace artifacts are throttled: first callback, then every
        # trace_log_every callbacks (0 disables them).
        should_log_trace_artifacts = trace_log_every > 0 and (
            trace_step["value"] == 1 or trace_step["value"] % trace_log_every == 0
        )
        canonical_metrics = aggregate_episode_metrics(episode_records)
        metrics = {
            **canonical_metrics,
            **train_metric_aliases(canonical_metrics),
            **scenario_registry.metrics(
                episode_records,
                unique_trace_count=len(set(batch_fingerprints)),
                duplicate_trace_suppressed_count=duplicate_trace_suppressed_count,
            ),
        }
        metrics["train/per_device_train_batch_size"] = float(per_device_train_batch_size)
        metrics["train/gradient_accumulation_steps"] = float(
            resolved_gradient_accumulation_steps
        )
        metrics["train/effective_train_batch_size"] = float(effective_train_batch_size)
        metrics["train/num_generations"] = float(num_generations)
        metrics["train/use_vllm"] = float(bool(use_vllm))
        metrics["train/vllm_gpu_memory_utilization"] = (
            float(vllm_gpu_memory_utilization) if use_vllm else 0.0
        )
        metrics["train/trace_log_every"] = float(trace_log_every)
        metrics["train/trace_artifacts_logged"] = float(should_log_trace_artifacts)
        if rewards:
            metrics["train/reward_mean"] = _mean(rewards)
            # Population std; needs at least two samples to be meaningful.
            metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
        # All Trackio calls are best-effort: logging failures never abort
        # the reward callback (and hence training).
        try:
            log_trackio_metrics(metrics, step=trace_step["value"])
        except Exception as exc:
            print(f"Trackio metric logging skipped: {exc!r}")
        if should_log_trace_artifacts and sampled_traces:
            try:
                log_trace_table(
                    [record for _, _, record, _, _ in sampled_traces],
                    table_name="sample_traces",
                    step=trace_step["value"],
                )
            except Exception as exc:
                print(f"Trackio sample trace table logging skipped: {exc!r}")
            for index, env, _record, reward, fingerprint in sampled_traces:
                # Mark as logged so later batches suppress duplicates.
                logged_trace_fingerprints.add(fingerprint)
                messages = list(getattr(env, "trace_messages", []))
                if index < len(completions):
                    completion_text = _completion_to_text(completions[index])
                    if completion_text:
                        messages.append(
                            {
                                "role": "assistant",
                                "content": f"Raw generated completion:\n{completion_text}",
                            }
                        )
                metadata = dict(getattr(env, "trace_metadata", {}))
                metadata.update(
                    {
                        "sample_index": index,
                        "reward": reward,
                        "trace_step": trace_step["value"],
                        "trace_fingerprint": fingerprint,
                        "num_generations": num_generations,
                        "run_name": run_name,
                        "reward_config_id": reward_tracking_config["reward_config_id"],
                        "reward_config_hash": reward_tracking_config["reward_config_hash"],
                        "reward_stage": reward_tracking_config["reward_stage"],
                        "reward_mode": reward_tracking_config["reward_mode"],
                        "reward_variant": reward_tracking_config["reward_variant"],
                    }
                )
                try:
                    trackio.log(
                        {
                            f"cybersecurity_owasp_trace/sample_{index}": trackio.Trace(
                                messages=messages,
                                metadata=metadata,
                            )
                        },
                        step=trace_step["value"],
                    )
                except Exception as exc:
                    print(f"Trackio trace logging skipped: {exc!r}")
        elif sampled_traces:
            print(
                "Trackio trace artifacts throttled at reward callback "
                f"{trace_step['value']}; set --trace-log-every 1 for every callback "
                "or 0 to disable trace artifacts."
            )
        if rewards:
            print(
                "Reward batch: "
                f"mean={statistics.mean(rewards):.3f}, "
                f"min={min(rewards):.3f}, max={max(rewards):.3f}"
            )
        return rewards

    class TrackioSystemMetricsCallback(TrainerCallback):
        """Trainer callback that pushes GPU/system metrics to Trackio."""

        def on_train_begin(self, args, state, control, **kwargs):
            # Log reward config and one-time system flags; on any failure
            # bail out early (training itself must not be affected).
            try:
                reward_summary = log_reward_config(reward_settings, step=int(state.global_step or 0))
                metrics = log_gpu_metrics(step=int(state.global_step or 0))
                log_trackio_metrics(
                    {
                        # cache_hit is a closure variable set before training
                        # starts, later in the enclosing function body.
                        "system/model_cache_hit": float(cache_hit),
                        "system/scenario_cache_required": 1.0,
                        "system/scenario_cache_entries": float(
                            scenario_cache_coverage.get("entries", 0)
                        ),
                        "system/hub_push_enabled": float(push_to_hub),
                    },
                    step=int(state.global_step or 0),
                )
                print(
                    "Trackio reward config logged: "
                    f"{reward_summary['reward_config_id']} "
                    f"({reward_summary['reward_config_hash']})"
                )
            except Exception as exc:
                print(f"Trackio initialization metrics skipped: {exc!r}")
                return control
            if metrics:
                system_summary = ", ".join(
                    f"{key}={value}"
                    for key, value in sorted(metrics.items())
                    if key.startswith("system/")
                )
                print(f"Trackio GPU metrics initialized: {system_summary}")
            return control

        def on_log(self, args, state, control, logs=None, **kwargs):
            try:
                metrics = log_gpu_metrics(step=int(state.global_step or 0))
            except Exception as exc:
                print(f"Trackio GPU metrics skipped: {exc!r}")
                return control
            if metrics:
                # Only preview the first few sorted entries in stdout.
                summary = ", ".join(f"{key}={value}" for key, value in sorted(metrics.items())[:4])
                print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}")
            return control

        def on_train_end(self, args, state, control, **kwargs):
            # Final best-effort GPU metrics flush.
            try:
                log_gpu_metrics(step=int(state.global_step or 0))
            except Exception as exc:
                print(f"Trackio final GPU metrics skipped: {exc!r}")
            return control

    # --- run configuration banner -------------------------------------
    print(f"CUDA available: {torch.cuda.is_available()}")
    if source_mode == "public":
        print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")
    else:
        print(f"Packaged local CyberSecurity_OWASP repo; default env repo id: {env_repo_id}")
    print(f"Trackio Space: {trackio_space_id}")
    print(f"Trackio Project: {trackio_project}")
    print(f"Output repo: {output_repo_id}")
    print(f"Run name: {run_name}")
    print(f"Reward config: {reward_tracking_config['reward_config_id']}")
    print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
    print(f"Reward variant: {reward_tracking_config['reward_variant']}")
    print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
    print(f"Learning rate: {learning_rate}")
    print(f"Reward env overrides: {reward_env}")
    print(f"Model cache volume: {CACHE_VOLUME_NAME}")
    print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
    print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
    print("Scenario cache mode: require")
    print(f"Scenario cache coverage: {scenario_cache_coverage}")
    print(f"HF_HOME: {cache_env['HF_HOME']}")
    print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
    print(f"Torch cache: {cache_env['TORCH_HOME']}")
    print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
    print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
    print(f"Hub push enabled: {push_to_hub}")
    if initial_adapter_path:
        print(f"Initial SFT adapter path: {initial_adapter_path}")
    if initial_adapter_repo_id:
        print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}")
    print(
        "GRPO throughput config: "
        f"per_device_train_batch_size={per_device_train_batch_size}, "
        f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, "
        f"num_generations={num_generations}, "
        f"world_size={world_size}, "
        f"effective_train_batch_size={effective_train_batch_size}"
    )
    print(
        "Generation acceleration config: "
        f"use_vllm={use_vllm}, "
        f"vllm_gpu_memory_utilization={vllm_gpu_memory_utilization}, "
        f"trace_log_every={trace_log_every}"
    )
    # Prefetch model weights into the persistent HF cache volume on a miss
    # so subsequent runs skip the download entirely.
    expected_model_cache = _hf_model_cache_path(model_name)
    cache_hit = expected_model_cache.exists()
    print(f"Expected HF model cache path: {expected_model_cache}")
    print(f"Model cache hit before load: {cache_hit}")
    if cache_hit:
        print("Using cached model snapshot from the persistent Modal volume when valid.")
    else:
        print(
            "Model cache miss. Downloading model weights once into the persistent "
            "Modal cache volume; Hugging Face progress output should follow."
        )
        try:
            snapshot_path = snapshot_download(
                repo_id=model_name,
                cache_dir=str(HF_HUB_CACHE_DIR),
                token=hf_token,
            )
            print(f"Model snapshot ready: {snapshot_path}")
            cache_volume.commit()
            print(f"Committed Modal model cache volume after snapshot download: {CACHE_VOLUME_NAME}")
        except Exception as exc:
            # Non-fatal: Unsloth's own loader may still succeed below.
            print(
                "Explicit model snapshot prefetch failed; Unsloth will attempt the "
                f"model load directly. Error: {exc!r}"
            )
    print(f"Loading model with Unsloth from_pretrained: {model_name}")
    model_api = FastVisionModel
    model_load_values = {
        "model_name": model_name,
        "max_seq_length": max_seq_length,
        "load_in_4bit": False,
        "fast_inference": use_vllm,
        # Only meaningful when vLLM fast inference is enabled.
        "gpu_memory_utilization": vllm_gpu_memory_utilization if use_vllm else None,
        "cache_dir": str(HF_HUB_CACHE_DIR),
        "token": hf_token,
    }
    # Filter kwargs against the installed Unsloth version's signature so
    # this launcher tolerates API drift across releases.
    from_pretrained_parameters = inspect.signature(model_api.from_pretrained).parameters
    from_pretrained_accepts_kwargs = any(
        parameter.kind == inspect.Parameter.VAR_KEYWORD
        for parameter in from_pretrained_parameters.values()
    )
    skipped_model_load_keys = sorted(
        key
        for key, value in model_load_values.items()
        if value is not None
        and key not in from_pretrained_parameters
        and not from_pretrained_accepts_kwargs
    )
    if skipped_model_load_keys:
        print(f"Skipping unsupported from_pretrained keys: {skipped_model_load_keys}")
    model, tokenizer = model_api.from_pretrained(
        **{
            key: value
            for key, value in model_load_values.items()
            if
            value is not None
            and (key in from_pretrained_parameters or from_pretrained_accepts_kwargs)
        }
    )
    print("Model load complete.")
    cache_volume.commit()
    print(f"Committed Modal model cache volume after model load: {CACHE_VOLUME_NAME}")
    # Optional: add TRL's response schema to the tokenizer; skipped (with a
    # log line) when the Gemma processor rejects it.
    try:
        tokenizer = add_response_schema(tokenizer)
    except Exception as exc:
        print(
            "Tokenizer response schema add skipped for Gemma 4 processor, "
            "matching the Unsloth Gemma 4 GRPO notebook pattern: "
            f"{exc!r}"
        )
    # Warm-start source: a local path, or a Hub repo downloaded here.
    adapter_source = initial_adapter_path
    if initial_adapter_repo_id:
        print(f"Downloading initial SFT adapter: {initial_adapter_repo_id}")
        adapter_source = snapshot_download(
            repo_id=initial_adapter_repo_id,
            cache_dir=str(HF_HUB_CACHE_DIR),
            token=hf_token,
        )
        cache_volume.commit()
    if adapter_source:
        print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
        adapter_source_path = pathlib.Path(adapter_source)
        adapter_config_path = adapter_source_path / "adapter_config.json"
        if not adapter_config_path.exists():
            raise RuntimeError(f"Initial SFT adapter config not found: {adapter_config_path}")
        adapter_config = json.loads(adapter_config_path.read_text(encoding="utf-8"))
        # Mirror the SFT adapter's LoRA geometry; fall back to launcher
        # defaults (alpha defaults to 2*rank) when fields are absent.
        adapter_rank = int(adapter_config.get("r") or lora_rank)
        adapter_alpha = int(adapter_config.get("lora_alpha") or adapter_rank * 2)
        adapter_target_modules = adapter_config.get("target_modules") or [
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ]
        adapter_target_modules = list(adapter_target_modules)
        print(
            "Attaching Unsloth LoRA before loading SFT weights: "
            f"rank={adapter_rank}, alpha={adapter_alpha}, targets={adapter_target_modules}"
        )
        model = model_api.get_peft_model(
            model,
            r=adapter_rank,
            target_modules=adapter_target_modules,
            lora_alpha=adapter_alpha,
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )
        adapter_weights_path = adapter_source_path / "adapter_model.safetensors"
        if not adapter_weights_path.exists():
            raise RuntimeError(f"Initial SFT adapter weights not found: {adapter_weights_path}")
        adapter_state = load_safetensors_file(str(adapter_weights_path), device="cpu")
        adapter_load_result = set_peft_model_state_dict(
            model,
            adapter_state,
            adapter_name="default",
        )
        # Unexpected LoRA keys mean the SFT adapter geometry does not match
        # the freshly attached adapter — hard failure.
        unexpected_adapter_keys = sorted(
            key
            for key in getattr(adapter_load_result, "unexpected_keys", [])
            if "lora_" in key or "modules_to_save" in key
        )
        if unexpected_adapter_keys:
            raise RuntimeError(
                "Initial SFT adapter keys do not match the trainable Unsloth LoRA. "
                f"Unexpected adapter keys: {unexpected_adapter_keys[:10]}"
            )
        # Missing LoRA keys are only warned about (partial warm start).
        missing_lora_keys = sorted(
            key
            for key in getattr(adapter_load_result, "missing_keys", [])
            if "lora_" in key or "modules_to_save" in key
        )
        if missing_lora_keys:
            print(f"Missing LoRA keys while loading SFT adapter: {missing_lora_keys[:10]}")
        if hasattr(model, "print_trainable_parameters"):
            model.print_trainable_parameters()
    else:
        # No warm start: attach a fresh LoRA with the launcher's geometry.
        model = model_api.get_peft_model(
            model,
            r=lora_rank,
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "o_proj",
                "gate_proj",
                "up_proj",
                "down_proj",
            ],
            lora_alpha=lora_rank * 2,
            use_gradient_checkpointing="unsloth",
            random_state=3407,
        )
    if hasattr(model_api, "for_training"):
        model_api.for_training(model)
    print("LoRA adapter ready and model switched to training mode.")
    # Candidate GRPOConfig kwargs; filtered below against the installed
    # TRL version's signature so unsupported keys are skipped, not fatal.
    grpo_config_values = {
        "temperature": 1.0,
        "learning_rate": learning_rate,
        "weight_decay": 0.001,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "linear",
        "optim": "adamw_8bit",
        "logging_steps": 1,
        "per_device_train_batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": resolved_gradient_accumulation_steps,
        "num_generations": num_generations,
        "max_prompt_length": max_seq_length,
        "max_completion_length": max_completion_length,
        "max_steps": max_steps,
        # At most one mid-run save for short smoke runs.
        "save_steps": max(10, max_steps),
        "report_to": "trackio",
        "project": trackio_project,
        "trackio_space_id": trackio_space_id,
        "run_name": run_name,
        "output_dir": str(output_dir),
        "push_to_hub": push_to_hub,
        "hub_model_id": output_repo_id,
        "hub_private_repo": True,
        "hub_strategy": "every_save",
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "use_vllm": use_vllm,
        "vllm_mode": "colocate",
        "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "delta": 1.5,
        "loss_type": "bnpo",
        "mask_truncated_completions": False,
    }
    grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
    skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
    if skipped_config_keys:
        print(f"Skipping unsupported GRPOConfig keys: {skipped_config_keys}")
    training_args = GRPOConfig(
        **{
            key: value
            for key, value in grpo_config_values.items()
            if key in grpo_config_parameters
        }
    )
    # Same signature-filtering pattern for the trainer constructor.
    trainer_values = {
        "model": model,
        "processing_class": tokenizer,
        "reward_funcs": cybersecurity_owasp_reward,
        "args": training_args,
        "train_dataset": dataset,
        "environment_factory": CyberSecurityOWASPToolEnv,
        "callbacks": [TrackioSystemMetricsCallback()],
    }
    trainer_parameters = set(inspect.signature(GRPOTrainer).parameters)
    skipped_trainer_keys = sorted(set(trainer_values) - trainer_parameters)
    if skipped_trainer_keys:
        print(f"Skipping unsupported GRPOTrainer keys: {skipped_trainer_keys}")
    trainer = GRPOTrainer(
        **{
            key: value
            for key, value in trainer_values.items()
            if key in trainer_parameters
        }
    )
    print("Starting GRPO trainer.train().")
    heartbeat_stop = threading.Event()

    def _training_heartbeat() -> None:
        """Print a liveness line every 30s while trainer.train() runs."""
        start_time = time.monotonic()
        while not heartbeat_stop.wait(30):
            elapsed = int(time.monotonic() - start_time)
            print(
                "Training heartbeat: still inside trainer.train() "
                f"after {elapsed}s. For this smoke, the slow part is usually "
                f"Gemma generation/backprop: {num_generations} completions "
                f"up to {max_completion_length} tokens, plus Trackio upload."
            )

    # Daemon thread so a hung heartbeat can never block container exit.
    heartbeat_thread = threading.Thread(
        target=_training_heartbeat,
        name="grpo-training-heartbeat",
        daemon=True,
    )
    heartbeat_thread.start()
    try:
        trainer.train()
    finally:
        # Always stop the heartbeat, even if training raised.
        heartbeat_stop.set()
        heartbeat_thread.join(timeout=2)
    print("GRPO trainer.train() complete.")
    if push_to_hub:
        print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
        trainer.push_to_hub()
        print("Hub push complete.")
    else:
        print("Skipping Hub push for this run. Pass --push-to-hub to upload adapters.")
    # Persist checkpoints and caches on all three volumes.
    volume.commit()
    cache_volume.commit()
    scenario_cache_volume.commit()
    print(f"Committed run volume: {VOLUME_NAME}")
    print(f"Committed model cache volume: {CACHE_VOLUME_NAME}")
    print(f"Committed scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
    # The trainer's report_to="trackio" integration may already have
    # finished the run; a second finish raises RuntimeError.
    try:
        trackio.finish()
    except RuntimeError as exc:
        print(f"Trackio finish skipped because the trainer already finalized it: {exc}")
    # Flat summary of the resolved configuration, returned to the caller.
    return {
        "run_name": run_name,
        "env_repo_id": env_repo_id,
        "output_repo_id": output_repo_id,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
        "max_steps": max_steps,
        "dataset_size": dataset_size,
        "difficulty": difficulty,
        "split": split,
        "model_name": model_name,
        "initial_adapter_path": initial_adapter_path,
        "initial_adapter_repo_id": initial_adapter_repo_id,
        "max_completion_length": max_completion_length,
        "num_generations": num_generations,
        "per_device_train_batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": resolved_gradient_accumulation_steps,
        "learning_rate": learning_rate,
        "effective_train_batch_size": effective_train_batch_size,
        "use_vllm": int(bool(use_vllm)),
        "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
        "trace_log_every": trace_log_every,
        "source_mode": source_mode,
        "repo_url": repo_url,
        "repo_branch": repo_branch,
        "push_to_hub": push_to_hub,
        "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME,
        "scenario_cache_mode": "require",
        "reward_variant": reward_tracking_config["reward_variant"],
        **reward_tracking_config,
    }


@app.local_entrypoint()
def main(
mode: str = "train", env_repo_id: str = "", output_repo_id: str = "", initial_adapter_path: str = "", initial_adapter_repo_id: str = "", max_steps: int = 10, dataset_size: int = 16, difficulty: int = 0, split: str = "train", model_name: str = DEFAULT_GEMMA_MODEL, max_seq_length: int = 4096, max_completion_length: int = 768, lora_rank: int = 32, trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio", trackio_project: str = "CyberSecurity_OWASP-grpo", num_generations: int = 6, per_device_train_batch_size: int = 1, gradient_accumulation_steps: int = 0, learning_rate: float = 5e-6, use_vllm: bool = False, vllm_gpu_memory_utilization: float = 0.2, trace_log_every: int = 5, seed_start: int = 0, git_sha: str = "nogit", run_name: str = "", source_mode: str = "local", repo_url: str = PUBLIC_REPO_URL, repo_branch: str = PUBLIC_REPO_BRANCH, detach: bool = False, push_to_hub: bool = False, reward_config: str = "", reward_variant: str = "", cache_seed_start: int = 0, cache_difficulty_buckets: int = 0, cache_train_per_bucket: int = 0, cache_validation_per_bucket: int = 0, cache_heldout_per_bucket: int = 0, cache_force: bool = False, ) -> None: model_name = _ensure_gemma4_model(model_name) if mode == "prepare-cache": result = prepare_modal_scenario_cache.remote( seed_start=cache_seed_start, difficulty_buckets=cache_difficulty_buckets, train_per_bucket=cache_train_per_bucket, validation_per_bucket=cache_validation_per_bucket, heldout_per_bucket=cache_heldout_per_bucket, force=cache_force, ) print(f"Prepared scenario cache: {result}") return if mode == "config": result = check_training_imports.remote() print(result) return if mode == "baseline": if int(num_generations) != 1: raise ValueError("baseline mode expects --num-generations 1.") trace_log_every = max(0, int(trace_log_every)) run_name = run_name or "baseline" preflight = verify_modal_scenario_cache_for_training.remote( split=split, difficulty=difficulty, dataset_size=dataset_size, seed_start=seed_start, ) 
print(f"CPU scenario cache preflight passed: {preflight}") kwargs = dict( max_steps=max_steps, dataset_size=dataset_size, difficulty=difficulty, split=split, model_name=model_name, max_seq_length=max_seq_length, max_completion_length=max_completion_length, trackio_space_id=trackio_space_id, trackio_project=trackio_project, num_generations=num_generations, trace_log_every=trace_log_every, seed_start=seed_start, git_sha=git_sha, run_name=run_name, source_mode=source_mode, repo_url=repo_url, repo_branch=repo_branch, reward_config=reward_config, reward_variant=reward_variant, ) if detach: call = run_cybersecurity_owasp_baseline.spawn(**kwargs) print(f"Spawned Modal baseline call: {call.object_id}") else: result = run_cybersecurity_owasp_baseline.remote(**kwargs) print(f"Baseline result: {result}") return if mode != "train": raise ValueError("mode must be 'prepare-cache', 'train', 'baseline', or 'config'") ( resolved_gradient_accumulation_steps, effective_train_batch_size, ) = _resolve_grpo_batch_config( per_device_train_batch_size=per_device_train_batch_size, gradient_accumulation_steps=gradient_accumulation_steps, num_generations=num_generations, world_size=1, ) _validate_vllm_config( use_vllm=use_vllm, vllm_gpu_memory_utilization=vllm_gpu_memory_utilization, ) trace_log_every = max(0, int(trace_log_every)) trackio_space_id = trackio_space_id or os.environ.get( "TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio", ) trackio_project = trackio_project or os.environ.get( "TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo" ) resolved_trackio_space_id = trackio_space_id resolved_output_repo_id = output_repo_id if not resolved_trackio_space_id or not resolved_output_repo_id: hf_token = os.environ.get("HF_TOKEN") if hf_token: try: from huggingface_hub import whoami user = whoami(token=hf_token)["name"] if not resolved_trackio_space_id: resolved_trackio_space_id = ( f"{user}/CyberSecurity_OWASP-trackio" if user == "humandotlearning" else 
"Humanlearning/CyberSecurity_OWASP-trackio" ) resolved_output_repo_id = ( resolved_output_repo_id or f"{user}/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}" ) except Exception as exc: print(f"Could not resolve Hugging Face defaults locally: {exc!r}") if git_sha == "nogit": try: git_sha = subprocess.check_output( [ "git", "-c", f"safe.directory={PROJECT_ROOT.as_posix()}", "rev-parse", "HEAD", ], cwd=PROJECT_ROOT, text=True, stderr=subprocess.DEVNULL, ).strip() except Exception: git_sha = "nogit" model_slug = model_name.replace("/", "-") local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") variant_tag = reward_variant or "default" algo_tag = _grpo_run_algo_tag( initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id, ) run_name = run_name or ( f"CyberSecurity_OWASP-{model_slug}-{algo_tag}-level{difficulty}-" f"{variant_tag}-steps{max_steps}-seed{seed_start}-{local_stamp}-{git_sha[:8]}" ) print(f"Run name: {run_name}") print(f"Reward variant: {variant_tag}") print(f"Reward config path: {reward_config or '(default training/configs/grpo_small.yaml)'}") print(f"Source mode: {source_mode}") if source_mode == "public": print(f"Public repo: {repo_url}@{repo_branch}") if resolved_trackio_space_id: print(f"Trackio Space: https://huggingface.co/spaces/{resolved_trackio_space_id}") else: print("Trackio Space: derived remotely from HF_TOKEN as /CyberSecurity_OWASP-trackio") if resolved_output_repo_id: print(f"Output model repo: https://huggingface.co/{resolved_output_repo_id}") else: print( "Output model repo: derived remotely from HF_TOKEN as " f"/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}" ) print(f"Hub push enabled: {push_to_hub}") if initial_adapter_path: print(f"Initial SFT adapter path: {initial_adapter_path}") if initial_adapter_repo_id: print(f"Initial SFT 
adapter repo: https://huggingface.co/{initial_adapter_repo_id}") print(f"Model cache volume: {CACHE_VOLUME_NAME}") print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}") print( "GRPO throughput config: " f"per_device_train_batch_size={per_device_train_batch_size}, " f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, " f"num_generations={num_generations}, " f"effective_train_batch_size={effective_train_batch_size}, " f"learning_rate={learning_rate}" ) print( "Generation acceleration config: " f"use_vllm={use_vllm}, " f"vllm_gpu_memory_utilization={vllm_gpu_memory_utilization}, " f"trace_log_every={trace_log_every}" ) print("Launch phases:") print( "1. Modal image build/validation: happens before remote Python logs; " "slow when local source or dependency layers changed." ) print("2. CPU-only scenario cache preflight in CyberSecurity_OWASP-scenario-cache.") print(f"3. GPU container start after cache preflight passes; fallback={GRPO_GPU_FALLBACK}.") print("4. Model cache check in CyberSecurity_OWASP-model-cache.") print("5. Cached snapshot load into GPU RAM with Unsloth progress.") print("6. GRPO steps, Trackio sync, and volume commit.") print( "If there is a long pause after trainer.train() starts, watch for " "Training heartbeat lines every 30 seconds." 
) kwargs = dict( env_repo_id=env_repo_id, output_repo_id=output_repo_id, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id, max_steps=max_steps, dataset_size=dataset_size, difficulty=difficulty, split=split, model_name=model_name, max_seq_length=max_seq_length, max_completion_length=max_completion_length, lora_rank=lora_rank, trackio_space_id=trackio_space_id, trackio_project=trackio_project, num_generations=num_generations, per_device_train_batch_size=per_device_train_batch_size, gradient_accumulation_steps=resolved_gradient_accumulation_steps, learning_rate=learning_rate, use_vllm=use_vllm, vllm_gpu_memory_utilization=vllm_gpu_memory_utilization, trace_log_every=trace_log_every, seed_start=seed_start, git_sha=git_sha, run_name=run_name, source_mode=source_mode, repo_url=repo_url, repo_branch=repo_branch, push_to_hub=push_to_hub, reward_config=reward_config, reward_variant=reward_variant, ) preflight = verify_modal_scenario_cache_for_training.remote( split=split, difficulty=difficulty, dataset_size=dataset_size, seed_start=seed_start, ) print(f"CPU scenario cache preflight passed: {preflight}") if detach: call = train_cybersecurity_owasp_grpo.spawn(**kwargs) print(f"Spawned Modal training call: {call.object_id}") else: result = train_cybersecurity_owasp_grpo.remote(**kwargs) print(f"Training result: {result}")