"""Persistent Modal GRPO launcher for CyberSecurity_OWASP. This packages the local repository into a Modal GPU image, runs a small tool-use GRPO job against the in-process CyberSecurity_OWASP environment, logs metrics/traces to Trackio, and saves LoRA checkpoints in a persistent Modal volume. Example: uv run --extra modal modal run scripts/modal_train_grpo.py \ --max-steps 10 \ --dataset-size 16 \ --num-generations 2 \ --difficulty 0 """ from __future__ import annotations import os import pathlib import subprocess import sys from datetime import datetime, timezone from typing import Any import modal APP_NAME = "CyberSecurity_OWASP-grpo" VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs" CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache" SECRET_NAME = "CyberSecurity_OWASP-secrets" RUNS_DIR = pathlib.Path("/runs") CACHE_DIR = pathlib.Path("/cache") HF_HOME_DIR = CACHE_DIR / "huggingface" HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub" TORCH_HOME_DIR = CACHE_DIR / "torch" XDG_CACHE_DIR = CACHE_DIR / "xdg" UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth" TRITON_CACHE_DIR = CACHE_DIR / "triton" REMOTE_PROJECT = "/root/CyberSecurity_OWASP" PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1] PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git" PUBLIC_REPO_BRANCH = "master" DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it" _IMAGE_NOTICE_PRINTED = False def _model_repo_slug(model_name: str) -> str: return ( model_name.replace("/", "-") .replace("_", "-") .replace(".", "-") .lower() ) def _hf_model_cache_path(model_name: str) -> pathlib.Path: return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}" def _configure_modal_cache_env() -> dict[str, str]: values = { "HF_HOME": str(HF_HOME_DIR), "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR), "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR), "TORCH_HOME": str(TORCH_HOME_DIR), "XDG_CACHE_HOME": str(XDG_CACHE_DIR), "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR), "UNSLOTH_COMPILE_CACHE": str(UNSLOTH_CACHE_DIR / "compile"), "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR), } for key, value in values.items(): os.environ[key] = value for path in { CACHE_DIR, HF_HOME_DIR, HF_HUB_CACHE_DIR, TORCH_HOME_DIR, XDG_CACHE_DIR, UNSLOTH_CACHE_DIR, UNSLOTH_CACHE_DIR / "compile", TRITON_CACHE_DIR, }: path.mkdir(parents=True, exist_ok=True) return values def _print_image_startup_notice() -> None: global _IMAGE_NOTICE_PRINTED if _IMAGE_NOTICE_PRINTED: return _IMAGE_NOTICE_PRINTED = True print( "Modal startup phase 1/5: building or validating the GPU training image. " "If this takes minutes, it is Modal image packaging/dependency cache work, " "not model-weight download." ) print( "Later remote phases will print: cache hit/miss, snapshot_download progress, " "Unsloth weight loading, GRPO heartbeat, Trackio upload, and volume commits." ) def _load_local_env_file() -> None: env_path = PROJECT_ROOT / ".env.local" if not env_path.exists(): return for raw_line in env_path.read_text(encoding="utf-8").splitlines(): line = raw_line.strip() if not line or line.startswith("#") or "=" not in line: continue key, value = line.split("=", 1) key = key.strip() if key not in {"TRACKIO_PROJECT"}: continue value = value.strip().strip('"').strip("'") os.environ.setdefault(key, value) def _modal_secrets() -> list[modal.Secret]: if _is_config_mode(): return [] return [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])] def _is_config_mode() -> bool: args = sys.argv[1:] for index, arg in enumerate(args): if arg == "--mode" and index + 1 < len(args): return args[index + 1] == "config" if arg.startswith("--mode="): return arg.split("=", 1)[1] == "config" return False _load_local_env_file() def _cli_arg_value(name: str, default: str = "") -> str: args = sys.argv[1:] flag = f"--{name}" for index, arg in enumerate(args): if arg == flag and index + 1 < len(args): return args[index + 1] if arg.startswith(f"{flag}="): return arg.split("=", 1)[1] return default def _source_mode() -> str: return _cli_arg_value("source-mode", os.environ.get("MODAL_SOURCE_MODE", "local")) def _training_image() -> modal.Image: _print_image_startup_notice() image = ( modal.Image.from_registry( "nvidia/cuda:12.8.0-devel-ubuntu22.04", add_python="3.11", ) .apt_install("git", "build-essential", "curl") .uv_pip_install( "torch==2.10.0", "triton>=3.4.0", "torchvision==0.25.0", "bitsandbytes", "accelerate", "datasets", "huggingface_hub", "peft", "pillow", "tokenizers", "nvidia-ml-py", "trackio>=0.25.0", "transformers>=5.5.0", "trl>=0.28.0", "openenv-core[core]>=0.2.3", ) .uv_pip_install( "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo", "unsloth[base] @ git+https://github.com/unslothai/unsloth", ) .uv_pip_install("timm", extra_options="--no-deps") .uv_pip_install("pydantic==2.10.6") .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps") .uv_pip_install("llm-blender", "weave") .uv_pip_install("trl>=0.28.0", "transformers>=5.5.0", "jmespath") ) if _source_mode() == "public": repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL) repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH) image = image.run_commands( f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}", f"python -m pip install --no-deps -e {REMOTE_PROJECT}", ) else: image = image.add_local_dir( PROJECT_ROOT, remote_path=REMOTE_PROJECT, copy=True, ignore=[ ".git", ".venv", ".env", ".env.*", "__pycache__", ".pytest_cache", "outputs", "*.pyc", ], ) image = image.run_commands( f"python -m pip install --no-deps -e {REMOTE_PROJECT}", ) return image.run_commands( "python -c \"import os, torch; import transformers.utils.hub as hub; " "hub.TRANSFORMERS_CACHE = getattr(hub, 'TRANSFORMERS_CACHE', " "os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')); " "from trl import GRPOConfig, GRPOTrainer; " "from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import " "CybersecurityOwaspEnvironment; print('trainer import ok', torch.__version__)\"", ).workdir(REMOTE_PROJECT) app = modal.App(APP_NAME) volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True) cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True) secrets = _modal_secrets() training_image = _training_image() @app.function( image=training_image, gpu="L4", timeout=4 * 60 * 60, volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume}, secrets=secrets, ) def check_training_imports() -> dict[str, str]: cache_env = _configure_modal_cache_env() import torch import trackio from datasets import Dataset from trl import GRPOConfig, GRPOTrainer from unsloth import FastLanguageModel, FastVisionModel from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( CybersecurityOwaspEnvironment, ) env = CybersecurityOwaspEnvironment() obs = env.reset(seed=0, split="validation", difficulty=0) return { "torch": torch.__version__, "trackio": getattr(trackio, "__version__", "unknown"), "dataset": Dataset.__name__, "grpo_config": GRPOConfig.__name__, "grpo_trainer": GRPOTrainer.__name__, "unsloth_model": FastLanguageModel.__name__, "unsloth_vision_model": FastVisionModel.__name__, "env": CybersecurityOwaspEnvironment.__name__, "reset_phase": obs.phase, "hf_home": cache_env["HF_HOME"], "hf_hub_cache": cache_env["HF_HUB_CACHE"], } @app.function( image=training_image, gpu="L4", timeout=4 * 60 * 60, volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume}, secrets=secrets, ) def train_cybersecurity_owasp_grpo( env_repo_id: str = "", output_repo_id: str = "", max_steps: int = 10, dataset_size: int = 16, difficulty: int = 0, split: str = "train", model_name: str = DEFAULT_GEMMA_MODEL, max_seq_length: int = 4096, max_completion_length: int = 768, lora_rank: int = 32, trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio", trackio_project: str = "CyberSecurity_OWASP-grpo", num_generations: int = 2, seed_start: int = 0, git_sha: str = "nogit", run_name: str = "", source_mode: str = "local", repo_url: str = PUBLIC_REPO_URL, repo_branch: str = PUBLIC_REPO_BRANCH, push_to_hub: bool = False, ) -> dict[str, str | int | float]: import inspect import statistics import threading import time cache_env = _configure_modal_cache_env() import torch from unsloth import FastLanguageModel, FastVisionModel import transformers.utils.hub as transformers_hub from datasets import Dataset from huggingface_hub import snapshot_download, whoami from transformers import TrainerCallback from trl import GRPOConfig, GRPOTrainer, clone_chat_template from trl.chat_template_utils import add_response_schema import trackio from CyberSecurity_OWASP.models import CyberSecurityOWASPAction from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import ( CybersecurityOwaspEnvironment, ) from training.trackio_utils import ( aggregate_episode_metrics, episode_record_from_state, log_gpu_metrics, log_trace_table, log_trackio_metrics, train_metric_aliases, ) transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"] hf_token = os.environ.get("HF_TOKEN") if not hf_token: raise RuntimeError( f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}." ) user = whoami(token=hf_token)["name"] env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP" output_repo_id = output_repo_id or ( f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora" ) if not trackio_space_id: trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio" if hf_token: try: from huggingface_hub import whoami user = whoami(token=hf_token)["name"] if user == "humandotlearning": trackio_space_id = f"{user}/CyberSecurity_OWASP-trackio" except Exception: pass os.environ["TRACKIO_SPACE_ID"] = trackio_space_id os.environ["TRACKIO_PROJECT"] = trackio_project model_slug = model_name.replace("/", "-") stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") run_name = run_name or ( f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-{stamp}-{git_sha[:8]}" ) output_dir = RUNS_DIR / run_name output_dir.mkdir(parents=True, exist_ok=True) try: cache_volume.reload() print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}") except Exception as exc: print(f"Model cache volume reload skipped: {exc!r}") cache_env = _configure_modal_cache_env() training_prompt = ( "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP " "OpenEnv environment. Use only the provided local tools. Do not target real " "systems. Work step by step: inspect policy and generated code, reproduce the " "authorization issue locally, submit a policy-tied finding, patch the generated " "app, run visible tests, then submit the fix. Do not write explanations unless " "a tool argument needs evidence text." ) dataset = Dataset.from_list( [ { "prompt": [{"role": "user", "content": training_prompt}], "seed": seed_start + index, "difficulty": difficulty, "split": split, } for index in range(dataset_size) ] ) def _state_snapshot(env: CybersecurityOwaspEnvironment) -> dict[str, Any]: state = env.state return { "episode_id": state.episode_id, "task_id": state.task_id, "seed": state.seed, "split": state.split, "difficulty": state.difficulty, "domain": state.domain, "bug_family": state.bug_family, "phase": state.phase, "step_count": state.step_count, "done": state.done, "success": state.success, "failure_reason": state.failure_reason, "anti_cheat_flags": list(state.anti_cheat_flags), } class CyberSecurityOWASPToolEnv: def __init__(self): self._env = CybersecurityOwaspEnvironment() self.reward = 0.0 self.reward_breakdown: dict[str, float] = {} self.done = False self.success = False self.invalid_actions = 0 self.trace_messages: list[dict[str, str]] = [] self.trace_metadata: dict[str, Any] = {} def reset(self, **kwargs) -> str: seed = int(kwargs.get("seed", seed_start)) current_difficulty = int(kwargs.get("difficulty", difficulty)) current_split = str(kwargs.get("split", split)) obs = self._env.reset( seed=seed, split=current_split, difficulty=current_difficulty, ) self.reward = 0.0 self.reward_breakdown = {} self.done = bool(obs.done) self.success = False self.invalid_actions = 0 self.trace_messages = [ { "role": "user", "content": ( f"{training_prompt}\n\nInitial observation:\n" f"Phase: {obs.phase}\n" f"Task: {obs.task_brief}\n" f"Available actions: {obs.available_actions}\n" f"Workspace summary: {obs.workspace_summary}\n" f"Policy hint: {obs.visible_policy_hint}\n" f"Message: {obs.message}" ), } ] self.trace_metadata = _state_snapshot(self._env) return obs.message def _step(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str: if self.done: raise ValueError("Episode is already over.") action = CyberSecurityOWASPAction( tool_name=tool_name, arguments=arguments or {}, ) obs = self._env.step(action) if not obs.last_action_valid: self.invalid_actions += 1 self.reward = float(obs.reward_breakdown.get("total", obs.reward or 0.0)) self.reward_breakdown = dict(obs.reward_breakdown or {}) self.done = bool(obs.done) self.success = bool(self._env.state.success) self.trace_messages.extend( [ { "role": "assistant", "content": f"{tool_name}({arguments or {}})", }, {"role": "tool", "content": obs.message}, ] ) self.trace_metadata.update(_state_snapshot(self._env)) self.trace_metadata.update( { "last_action_valid": obs.last_action_valid, "last_action_error": obs.last_action_error, "reward": self.reward, "reward_breakdown": self.reward_breakdown, "invalid_actions": self.invalid_actions, } ) return obs.message def inspect_policy_graph(self) -> str: """Return public policy hints for the generated local scenario.""" return self._step("inspect_policy_graph") def list_routes(self) -> str: """List generated local app route summaries.""" return self._step("list_routes") def read_openapi(self) -> str: """Read generated OpenAPI metadata for the local app.""" return self._step("read_openapi") def read_file(self, path: str) -> str: """ Read an editable generated workspace file by relative path. Args: path: Relative path inside the generated editable workspace. Returns: The file contents or a safe tool error observation. """ return self._step("read_file", {"path": path}) def search_code(self, query: str) -> str: """ Search editable generated workspace files for a string. Args: query: Search text to find in editable generated app files. Returns: Matching file lines or a no-match message. """ return self._step("search_code", {"query": query}) def send_local_request( self, path: str, method: str = "GET", user_id: str | None = None, ) -> str: """ Send a request to the generated local app only. Args: path: Local route path such as /health or /invoices/. method: HTTP method to use for the local request. user_id: Optional generated user identifier for authentication. Returns: JSON response from the simulated local app request. """ return self._step( "send_local_request", {"path": path, "method": method, "user_id": user_id}, ) def compare_identities( self, path: str, first_user_id: str, second_user_id: str, method: str = "GET", ) -> str: """ Compare one local request as two generated users. Args: path: Local route path to request as both generated users. first_user_id: First generated user identifier. second_user_id: Second generated user identifier. method: HTTP method to use for both local requests. Returns: JSON summary of both simulated local responses. """ return self._step( "compare_identities", { "path": path, "method": method, "first_user_id": first_user_id, "second_user_id": second_user_id, }, ) def submit_finding( self, summary: str, evidence: str, policy_rule: str, ) -> str: """ Submit structured evidence for the suspected authorization bug. Args: summary: Concise description of the suspected access-control bug. evidence: Local reproduction evidence from policy, code, or requests. policy_rule: Policy rule that the observed behavior violates. Returns: Finding acceptance result and next phase information. """ return self._step( "submit_finding", { "summary": summary, "evidence": evidence, "policy_rule": policy_rule, }, ) def patch_file( self, path: str, content: str | None = None, diff: str | None = None, ) -> str: """ Patch an editable generated app file with full content or a unified diff. Args: path: Relative path of the editable generated app file to patch. content: Complete replacement file content, when using full-file patching. diff: Unified diff to apply, when using diff patching. Returns: Patch application result. """ args: dict[str, Any] = {"path": path} if content is not None: args["content"] = content if diff is not None: args["diff"] = diff return self._step("patch_file", args) def run_visible_tests(self) -> str: """Run visible tests only; hidden tests are never exposed.""" return self._step("run_visible_tests") def submit_fix(self) -> str: """Submit the final patch to the hidden deterministic verifier.""" return self._step("submit_fix") def noop(self) -> str: """Take no action.""" return self._step("noop") def _score(self) -> float: return float(self.reward) def __del__(self): try: self._env.close() except Exception: pass trace_step = {"value": 0} def _completion_to_text(completion) -> str: if completion is None: return "" if isinstance(completion, str): return completion if isinstance(completion, list): parts = [] for item in completion: if isinstance(item, dict): parts.append(str(item.get("content", item))) else: parts.append(str(item)) return "\n".join(parts) return str(completion) def _mean(values: list[float]) -> float: return float(sum(values) / len(values)) if values else 0.0 def cybersecurity_owasp_reward(environments, **kwargs) -> list[float]: rewards = [float(env._score()) for env in environments] completions = kwargs.get("completions") or kwargs.get("completion") or [] trace_step["value"] += 1 episode_records = [] for env, reward in zip(environments, rewards): record = episode_record_from_state( env._env.state, run_context={ "base_model": model_name, "algo": "grpo", "reward_version": "reward_v1", "env_version": "0.1.0", }, ) record.update( { "reward_total": reward, "success": bool(getattr(env, "success", False)), } ) episode_records.append(record) canonical_metrics = aggregate_episode_metrics(episode_records) metrics = { **canonical_metrics, **train_metric_aliases(canonical_metrics), } if rewards: metrics["train/reward_mean"] = _mean(rewards) metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0 try: log_trackio_metrics(metrics, step=trace_step["value"]) except Exception as exc: print(f"Trackio metric logging skipped: {exc!r}") try: log_trace_table( episode_records[: min(4, len(episode_records))], table_name="sample_traces", step=trace_step["value"], ) except Exception as exc: print(f"Trackio sample trace table logging skipped: {exc!r}") for index, env in enumerate(environments): messages = list(getattr(env, "trace_messages", [])) if index < len(completions): completion_text = _completion_to_text(completions[index]) if completion_text: messages.append( { "role": "assistant", "content": f"Raw generated completion:\n{completion_text}", } ) metadata = dict(getattr(env, "trace_metadata", {})) metadata.update( { "sample_index": index, "reward": rewards[index], "trace_step": trace_step["value"], "run_name": run_name, } ) try: trackio.log( { f"cybersecurity_owasp_trace/sample_{index}": trackio.Trace( messages=messages, metadata=metadata, ) }, step=trace_step["value"], ) except Exception as exc: print(f"Trackio trace logging skipped: {exc!r}") if rewards: print( "Reward batch: " f"mean={statistics.mean(rewards):.3f}, " f"min={min(rewards):.3f}, max={max(rewards):.3f}" ) return rewards class TrackioSystemMetricsCallback(TrainerCallback): def on_train_begin(self, args, state, control, **kwargs): try: metrics = log_gpu_metrics(step=int(state.global_step or 0)) log_trackio_metrics( { "system/model_cache_hit": float(cache_hit), "system/hub_push_enabled": float(push_to_hub), }, step=int(state.global_step or 0), ) except Exception as exc: print(f"Trackio GPU metrics initialization skipped: {exc!r}") return control if metrics: system_summary = ", ".join( f"{key}={value}" for key, value in sorted(metrics.items()) if key.startswith("system/") ) print(f"Trackio GPU metrics initialized: {system_summary}") return control def on_log(self, args, state, control, logs=None, **kwargs): try: metrics = log_gpu_metrics(step=int(state.global_step or 0)) except Exception as exc: print(f"Trackio GPU metrics skipped: {exc!r}") return control if metrics: summary = ", ".join(f"{key}={value}" for key, value in sorted(metrics.items())[:4]) print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}") return control def on_train_end(self, args, state, control, **kwargs): try: log_gpu_metrics(step=int(state.global_step or 0)) except Exception as exc: print(f"Trackio final GPU metrics skipped: {exc!r}") return control print(f"CUDA available: {torch.cuda.is_available()}") if source_mode == "public": print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}") else: print(f"Packaged local CyberSecurity_OWASP repo; default env repo id: {env_repo_id}") print(f"Trackio Space: {trackio_space_id}") print(f"Trackio Project: {trackio_project}") print(f"Output repo: {output_repo_id}") print(f"Run name: {run_name}") print(f"Model cache volume: {CACHE_VOLUME_NAME}") print(f"HF_HOME: {cache_env['HF_HOME']}") print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}") print(f"Torch cache: {cache_env['TORCH_HOME']}") print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}") print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}") print(f"Hub push enabled: {push_to_hub}") expected_model_cache = _hf_model_cache_path(model_name) cache_hit = expected_model_cache.exists() print(f"Expected HF model cache path: {expected_model_cache}") print(f"Model cache hit before load: {cache_hit}") if cache_hit: print("Using cached model snapshot from the persistent Modal volume when valid.") else: print( "Model cache miss. Downloading model weights once into the persistent " "Modal cache volume; Hugging Face progress output should follow." ) try: snapshot_path = snapshot_download( repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token, ) print(f"Model snapshot ready: {snapshot_path}") cache_volume.commit() print(f"Committed Modal model cache volume after snapshot download: {CACHE_VOLUME_NAME}") except Exception as exc: print( "Explicit model snapshot prefetch failed; Unsloth will attempt the " f"model load directly. Error: {exc!r}" ) print(f"Loading model with Unsloth from_pretrained: {model_name}") model_api = FastVisionModel if "gemma-4" in model_name.lower() else FastLanguageModel model, tokenizer = model_api.from_pretrained( model_name=model_name, max_seq_length=max_seq_length, load_in_4bit=False, fast_inference=False, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token, ) print("Model load complete.") cache_volume.commit() print(f"Committed Modal model cache volume after model load: {CACHE_VOLUME_NAME}") try: tokenizer = add_response_schema(tokenizer) except Exception as exc: if "gemma-4" in model_name.lower(): print( "Tokenizer response schema add skipped for Gemma 4 processor, " "matching the Unsloth Gemma 4 GRPO notebook pattern: " f"{exc!r}" ) else: print(f"Tokenizer response schema add failed before cloning: {exc!r}") for template_source in ("Qwen/Qwen3-0.6B", "Qwen/Qwen2.5-0.5B-Instruct"): try: model, tokenizer, added_tokens = clone_chat_template( model, tokenizer, template_source, ) print( "Cloned response-schema-capable chat template " f"from {template_source}; added {len(added_tokens)} tokens." ) tokenizer = add_response_schema(tokenizer) break except Exception as clone_exc: print( "Tokenizer response schema fallback failed for " f"{template_source}: {clone_exc!r}" ) else: raise model = model_api.get_peft_model( model, r=lora_rank, target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], lora_alpha=lora_rank * 2, use_gradient_checkpointing="unsloth", random_state=3407, ) if hasattr(model_api, "for_training"): model_api.for_training(model) print("LoRA adapter attached and model switched to training mode.") grpo_config_values = { "temperature": 1.0, "learning_rate": 5e-6, "weight_decay": 0.001, "warmup_ratio": 0.1, "lr_scheduler_type": "linear", "optim": "adamw_8bit", "logging_steps": 1, "per_device_train_batch_size": 1, "gradient_accumulation_steps": max(2, num_generations), "num_generations": num_generations, "max_prompt_length": max_seq_length, "max_completion_length": max_completion_length, "max_steps": max_steps, "save_steps": max(10, max_steps), "report_to": "trackio", "project": trackio_project, "trackio_space_id": trackio_space_id, "run_name": run_name, "output_dir": str(output_dir), "push_to_hub": push_to_hub, "hub_model_id": output_repo_id, "hub_private_repo": True, "hub_strategy": "every_save", "gradient_checkpointing": True, "gradient_checkpointing_kwargs": {"use_reentrant": False}, "epsilon": 0.2, "epsilon_high": 0.28, "delta": 1.5, "loss_type": "bnpo", "mask_truncated_completions": True, } grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters) skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters) if skipped_config_keys: print(f"Skipping unsupported GRPOConfig keys: {skipped_config_keys}") training_args = GRPOConfig( **{ key: value for key, value in grpo_config_values.items() if key in grpo_config_parameters } ) trainer_values = { "model": model, "processing_class": tokenizer, "reward_funcs": cybersecurity_owasp_reward, "args": training_args, "train_dataset": dataset, "environment_factory": CyberSecurityOWASPToolEnv, "callbacks": [TrackioSystemMetricsCallback()], } trainer_parameters = set(inspect.signature(GRPOTrainer).parameters) skipped_trainer_keys = sorted(set(trainer_values) - trainer_parameters) if skipped_trainer_keys: print(f"Skipping unsupported GRPOTrainer keys: {skipped_trainer_keys}") trainer = GRPOTrainer( **{ key: value for key, value in trainer_values.items() if key in trainer_parameters } ) print("Starting GRPO trainer.train().") heartbeat_stop = threading.Event() def _training_heartbeat() -> None: start_time = time.monotonic() while not heartbeat_stop.wait(30): elapsed = int(time.monotonic() - start_time) print( "Training heartbeat: still inside trainer.train() " f"after {elapsed}s. For this smoke, the slow part is usually " f"Gemma generation/backprop on L4: {num_generations} completions " f"up to {max_completion_length} tokens, plus Trackio upload." ) heartbeat_thread = threading.Thread( target=_training_heartbeat, name="grpo-training-heartbeat", daemon=True, ) heartbeat_thread.start() try: trainer.train() finally: heartbeat_stop.set() heartbeat_thread.join(timeout=2) print("GRPO trainer.train() complete.") if push_to_hub: print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}") trainer.push_to_hub() print("Hub push complete.") else: print("Skipping Hub push for this run. Pass --push-to-hub to upload adapters.") volume.commit() cache_volume.commit() print(f"Committed run volume: {VOLUME_NAME}") print(f"Committed model cache volume: {CACHE_VOLUME_NAME}") try: trackio.finish() except RuntimeError as exc: print(f"Trackio finish skipped because the trainer already finalized it: {exc}") return { "run_name": run_name, "env_repo_id": env_repo_id, "output_repo_id": output_repo_id, "trackio_space_id": trackio_space_id, "trackio_project": trackio_project, "max_steps": max_steps, "dataset_size": dataset_size, "difficulty": difficulty, "split": split, "model_name": model_name, "max_completion_length": max_completion_length, "num_generations": num_generations, "source_mode": source_mode, "repo_url": repo_url, "repo_branch": repo_branch, "push_to_hub": push_to_hub, } @app.local_entrypoint() def main( mode: str = "train", env_repo_id: str = "", output_repo_id: str = "", max_steps: int = 10, dataset_size: int = 16, difficulty: int = 0, split: str = "train", model_name: str = DEFAULT_GEMMA_MODEL, max_seq_length: int = 4096, max_completion_length: int = 768, lora_rank: int = 32, trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio", trackio_project: str = "CyberSecurity_OWASP-grpo", num_generations: int = 2, seed_start: int = 0, git_sha: str = "nogit", source_mode: str = "local", repo_url: str = PUBLIC_REPO_URL, repo_branch: str = PUBLIC_REPO_BRANCH, detach: bool = False, push_to_hub: bool = False, ) -> None: if mode == "config": result = check_training_imports.remote() print(result) return if mode != "train": raise ValueError("mode must be 'train' or 'config'") trackio_space_id = trackio_space_id or os.environ.get( "TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio", ) trackio_project = trackio_project or os.environ.get( "TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo" ) resolved_trackio_space_id = trackio_space_id resolved_output_repo_id = output_repo_id if not resolved_trackio_space_id or not resolved_output_repo_id: hf_token = os.environ.get("HF_TOKEN") if hf_token: try: from huggingface_hub import whoami user = whoami(token=hf_token)["name"] if not resolved_trackio_space_id: resolved_trackio_space_id = ( f"{user}/CyberSecurity_OWASP-trackio" if user == "humandotlearning" else "Humanlearning/CyberSecurity_OWASP-trackio" ) resolved_output_repo_id = ( resolved_output_repo_id or f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora" ) except Exception as exc: print(f"Could not resolve Hugging Face defaults locally: {exc!r}") if git_sha == "nogit": try: git_sha = subprocess.check_output( ["git", "rev-parse", "HEAD"], cwd=PROJECT_ROOT, text=True, stderr=subprocess.DEVNULL, ).strip() except Exception: git_sha = "nogit" model_slug = model_name.replace("/", "-") local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S") run_name = ( f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-" f"{local_stamp}-{git_sha[:8]}" ) print(f"Run name: {run_name}") print(f"Source mode: {source_mode}") if source_mode == "public": print(f"Public repo: {repo_url}@{repo_branch}") if resolved_trackio_space_id: print(f"Trackio Space: https://huggingface.co/spaces/{resolved_trackio_space_id}") else: print("Trackio Space: derived remotely from HF_TOKEN as /CyberSecurity_OWASP-trackio") if resolved_output_repo_id: print(f"Output model repo: https://huggingface.co/{resolved_output_repo_id}") else: print( "Output model repo: derived remotely from HF_TOKEN as " f"/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora" ) print(f"Hub push enabled: {push_to_hub}") print(f"Model cache volume: {CACHE_VOLUME_NAME}") print("Launch phases:") print( "1. Modal image build/validation: happens before remote Python logs; " "slow when local source or dependency layers changed." ) print("2. GPU container start on one L4 and persistent volume reload.") print("3. Model cache check in CyberSecurity_OWASP-model-cache.") print("4. Cached snapshot load into GPU RAM with Unsloth progress.") print("5. One GRPO step, Trackio sync, and volume commit.") print( "If there is a long pause after trainer.train() starts, watch for " "Training heartbeat lines every 30 seconds." ) kwargs = dict( env_repo_id=env_repo_id, output_repo_id=output_repo_id, max_steps=max_steps, dataset_size=dataset_size, difficulty=difficulty, split=split, model_name=model_name, max_seq_length=max_seq_length, max_completion_length=max_completion_length, lora_rank=lora_rank, trackio_space_id=trackio_space_id, trackio_project=trackio_project, num_generations=num_generations, seed_start=seed_start, git_sha=git_sha, run_name=run_name, source_mode=source_mode, repo_url=repo_url, repo_branch=repo_branch, push_to_hub=push_to_hub, ) if detach: call = train_cybersecurity_owasp_grpo.spawn(**kwargs) print(f"Spawned Modal training call: {call.object_id}") else: result = train_cybersecurity_owasp_grpo.remote(**kwargs) print(f"Training result: {result}")