# feat: enhance training image setup, add a startup notice for Modal execution,
# improve the dependency installation process, and add a training heartbeat
# for monitoring.
"""Persistent Modal GRPO launcher for CyberSecurity_OWASP.

This packages the local repository into a Modal GPU image, runs a small
tool-use GRPO job against the in-process CyberSecurity_OWASP environment, logs
metrics/traces to Trackio, and saves LoRA checkpoints in a persistent Modal
volume.

Example:
    uv run --extra modal modal run scripts/modal_train_grpo.py \
        --max-steps 10 \
        --dataset-size 16 \
        --num-generations 2 \
        --difficulty 0
"""
| from __future__ import annotations | |
| import os | |
| import pathlib | |
| import subprocess | |
| import sys | |
| from datetime import datetime, timezone | |
| from typing import Any | |
| import modal | |
# Modal app / volume / secret identifiers for this training job.
APP_NAME = "CyberSecurity_OWASP-grpo"
VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
SECRET_NAME = "CyberSecurity_OWASP-secrets"
# Mount points inside the Modal container.
RUNS_DIR = pathlib.Path("/runs")
CACHE_DIR = pathlib.Path("/cache")
# Per-framework cache locations, all rooted under the persistent /cache volume.
HF_HOME_DIR = CACHE_DIR / "huggingface"
HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
TORCH_HOME_DIR = CACHE_DIR / "torch"
XDG_CACHE_DIR = CACHE_DIR / "xdg"
UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
TRITON_CACHE_DIR = CACHE_DIR / "triton"
# Where the project source lives inside the image vs. on the local machine.
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
# Public fallback source used when --source-mode=public is passed.
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
PUBLIC_REPO_BRANCH = "master"
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
# Guard so the image-build startup notice is printed at most once per process.
_IMAGE_NOTICE_PRINTED = False
def _model_repo_slug(model_name: str) -> str:
    """Normalize a Hub model id into a lowercase, dash-separated slug."""
    slug = model_name.lower()
    for separator in ("/", "_", "."):
        slug = slug.replace(separator, "-")
    return slug
def _hf_model_cache_path(model_name: str) -> pathlib.Path:
    """Return the hub-cache directory where HF stores snapshots of *model_name*."""
    repo_dir_name = "models--" + model_name.replace("/", "--")
    return HF_HUB_CACHE_DIR / repo_dir_name
def _configure_modal_cache_env() -> dict[str, str]:
    """Point every framework cache at the persistent /cache volume.

    Sets the relevant environment variables, creates all backing
    directories, and returns the mapping that was applied.
    """
    compile_cache = UNSLOTH_CACHE_DIR / "compile"
    values = {
        "HF_HOME": str(HF_HOME_DIR),
        "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
        "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
        "TORCH_HOME": str(TORCH_HOME_DIR),
        "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
        "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
        "UNSLOTH_COMPILE_CACHE": str(compile_cache),
        "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
    }
    os.environ.update(values)
    directories = (
        CACHE_DIR,
        HF_HOME_DIR,
        HF_HUB_CACHE_DIR,
        TORCH_HOME_DIR,
        XDG_CACHE_DIR,
        UNSLOTH_CACHE_DIR,
        compile_cache,
        TRITON_CACHE_DIR,
    )
    for directory in directories:
        directory.mkdir(parents=True, exist_ok=True)
    return values
def _print_image_startup_notice() -> None:
    """Emit the phase-1 image-build notice, at most once per process."""
    global _IMAGE_NOTICE_PRINTED
    if _IMAGE_NOTICE_PRINTED:
        return
    _IMAGE_NOTICE_PRINTED = True
    notice_lines = (
        "Modal startup phase 1/5: building or validating the GPU training image. "
        "If this takes minutes, it is Modal image packaging/dependency cache work, "
        "not model-weight download.",
        "Later remote phases will print: cache hit/miss, snapshot_download progress, "
        "Unsloth weight loading, GRPO heartbeat, Trackio upload, and volume commits.",
    )
    for line in notice_lines:
        print(line)
def _load_local_env_file() -> None:
    """Seed selected environment variables from the repo's .env.local file.

    Only TRACKIO_PROJECT is honored, and values already present in the
    environment are never overwritten.
    """
    env_path = PROJECT_ROOT / ".env.local"
    if not env_path.exists():
        return
    allowed_keys = {"TRACKIO_PROJECT"}
    for raw_line in env_path.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#") or "=" not in entry:
            continue
        key, _, value = entry.partition("=")
        key = key.strip()
        if key in allowed_keys:
            cleaned = value.strip().strip('"').strip("'")
            os.environ.setdefault(key, cleaned)
def _modal_secrets() -> list[modal.Secret]:
    """Return the Modal secrets required for training (none in config mode)."""
    if _is_config_mode():
        return []
    hf_secret = modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])
    return [hf_secret]
def _is_config_mode() -> bool:
    """Return True when the CLI was invoked with ``--mode config``."""
    args = sys.argv[1:]
    for position, token in enumerate(args):
        if token.startswith("--mode="):
            return token.split("=", 1)[1] == "config"
        if token == "--mode" and position + 1 < len(args):
            return args[position + 1] == "config"
    return False
| _load_local_env_file() | |
def _cli_arg_value(name: str, default: str = "") -> str:
    """Fetch a raw CLI flag value, supporting ``--name value`` and ``--name=value``."""
    flag = f"--{name}"
    inline_prefix = f"{flag}="
    args = sys.argv[1:]
    for position, token in enumerate(args):
        if token.startswith(inline_prefix):
            return token[len(inline_prefix):]
        if token == flag and position + 1 < len(args):
            return args[position + 1]
    return default
def _source_mode() -> str:
    """Resolve the repo source mode: CLI flag, then MODAL_SOURCE_MODE, then 'local'."""
    fallback = os.environ.get("MODAL_SOURCE_MODE", "local")
    return _cli_arg_value("source-mode", fallback)
def _training_image() -> modal.Image:
    """Build the Modal GPU training image.

    Layers: a CUDA devel base with Python 3.11, the pinned training
    dependency stack, Unsloth from git, the project source (local dir or
    public git clone depending on --source-mode), and a final import
    smoke-check that fails the image build early if the stack is broken.
    """
    _print_image_startup_notice()
    image = (
        modal.Image.from_registry(
            "nvidia/cuda:12.8.0-devel-ubuntu22.04",
            add_python="3.11",
        )
        .apt_install("git", "build-essential", "curl")
        # Core training stack; versions are pinned to keep layers cacheable.
        .uv_pip_install(
            "torch==2.10.0",
            "triton>=3.4.0",
            "torchvision==0.25.0",
            "bitsandbytes",
            "accelerate",
            "datasets",
            "huggingface_hub",
            "peft",
            "pillow",
            "tokenizers",
            "nvidia-ml-py",
            "trackio>=0.25.0",
            "transformers>=5.5.0",
            "trl>=0.28.0",
            "openenv-core[core]>=0.2.3",
        )
        # Unsloth is installed from git after its dependencies are in place.
        .uv_pip_install(
            "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
            "unsloth[base] @ git+https://github.com/unslothai/unsloth",
        )
        # --no-deps installs avoid resolver conflicts with the pinned stack.
        .uv_pip_install("timm", extra_options="--no-deps")
        .uv_pip_install("pydantic==2.10.6")
        .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps")
        .uv_pip_install("llm-blender", "weave")
        # Re-asserted last, presumably so earlier layers cannot downgrade the
        # trl/transformers pins — TODO confirm this layer is still needed.
        .uv_pip_install("trl>=0.28.0", "transformers>=5.5.0", "jmespath")
    )
    if _source_mode() == "public":
        # Public mode: clone and editable-install the published repo.
        repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL)
        repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
        image = image.run_commands(
            f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
            f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
        )
    else:
        # Local mode: package the working tree, minus caches and env files.
        image = image.add_local_dir(
            PROJECT_ROOT,
            remote_path=REMOTE_PROJECT,
            copy=True,
            ignore=[
                ".git",
                ".venv",
                ".env",
                ".env.*",
                "__pycache__",
                ".pytest_cache",
                "outputs",
                "*.pyc",
            ],
        )
        image = image.run_commands(
            f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
        )
    # Import smoke-check baked into the image so a broken stack fails at
    # build time instead of at training time.
    return image.run_commands(
        "python -c \"import os, torch; import transformers.utils.hub as hub; "
        "hub.TRANSFORMERS_CACHE = getattr(hub, 'TRANSFORMERS_CACHE', "
        "os.path.join(os.path.expanduser('~'), '.cache', 'huggingface', 'hub')); "
        "from trl import GRPOConfig, GRPOTrainer; "
        "from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import "
        "CybersecurityOwaspEnvironment; print('trainer import ok', torch.__version__)\"",
    ).workdir(REMOTE_PROJECT)
# Module-level Modal objects: app, persistent volumes, secrets, and the image.
app = modal.App(APP_NAME)
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
secrets = _modal_secrets()
training_image = _training_image()
def check_training_imports() -> dict[str, str]:
    """Smoke-check the remote image: heavy imports plus one environment reset.

    Returns a small dict of library versions/names so the local launcher can
    confirm the image is usable before paying for a full training run.
    """
    # Cache env vars are set before importing torch/unsloth so their caches
    # resolve to the persistent volume.
    cache_env = _configure_modal_cache_env()
    import torch
    import trackio
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    from unsloth import FastLanguageModel, FastVisionModel
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    # One reset proves the in-process environment is constructible and steps.
    env = CybersecurityOwaspEnvironment()
    obs = env.reset(seed=0, split="validation", difficulty=0)
    return {
        "torch": torch.__version__,
        "trackio": getattr(trackio, "__version__", "unknown"),
        "dataset": Dataset.__name__,
        "grpo_config": GRPOConfig.__name__,
        "grpo_trainer": GRPOTrainer.__name__,
        "unsloth_model": FastLanguageModel.__name__,
        "unsloth_vision_model": FastVisionModel.__name__,
        "env": CybersecurityOwaspEnvironment.__name__,
        "reset_phase": obs.phase,
        "hf_home": cache_env["HF_HOME"],
        "hf_hub_cache": cache_env["HF_HUB_CACHE"],
    }
def train_cybersecurity_owasp_grpo(
    env_repo_id: str = "",
    output_repo_id: str = "",
    max_steps: int = 10,
    dataset_size: int = 16,
    difficulty: int = 0,
    split: str = "train",
    model_name: str = DEFAULT_GEMMA_MODEL,
    max_seq_length: int = 4096,
    max_completion_length: int = 768,
    lora_rank: int = 32,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-grpo",
    num_generations: int = 2,
    seed_start: int = 0,
    git_sha: str = "nogit",
    run_name: str = "",
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    push_to_hub: bool = False,
) -> dict[str, str | int | float]:
    """Run a tool-use GRPO training job against the local OWASP environment.

    Loads the base model with Unsloth, attaches a LoRA adapter, trains with
    TRL's GRPOTrainer using the in-process CyberSecurity_OWASP environment
    for rewards, logs metrics and traces to Trackio, and commits checkpoints
    to persistent Modal volumes. Returns a summary dict of the run settings.
    """
    # Heavy imports happen inside the remote function body so the local
    # launcher process does not need the GPU training stack installed.
    import inspect
    import statistics
    import threading
    import time
    # Cache env vars are configured before importing torch/unsloth so any
    # downloads and compile caches land on the persistent volume.
    cache_env = _configure_modal_cache_env()
    import torch
    from unsloth import FastLanguageModel, FastVisionModel
    import transformers.utils.hub as transformers_hub
    from datasets import Dataset
    from huggingface_hub import snapshot_download, whoami
    from transformers import TrainerCallback
    from trl import GRPOConfig, GRPOTrainer, clone_chat_template
    from trl.chat_template_utils import add_response_schema
    import trackio
    from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
    from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
        CybersecurityOwaspEnvironment,
    )
    from training.trackio_utils import (
        aggregate_episode_metrics,
        episode_record_from_state,
        log_gpu_metrics,
        log_trace_table,
        log_trackio_metrics,
        train_metric_aliases,
    )
    # Some transformers code paths read this module attribute directly.
    transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]
| hf_token = os.environ.get("HF_TOKEN") | |
| if not hf_token: | |
| raise RuntimeError( | |
| f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}." | |
| ) | |
| user = whoami(token=hf_token)["name"] | |
| env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP" | |
| output_repo_id = output_repo_id or ( | |
| f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora" | |
| ) | |
| if not trackio_space_id: | |
| trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio" | |
| if hf_token: | |
| try: | |
| from huggingface_hub import whoami | |
| user = whoami(token=hf_token)["name"] | |
| if user == "humandotlearning": | |
| trackio_space_id = f"{user}/CyberSecurity_OWASP-trackio" | |
| except Exception: | |
| pass | |
| os.environ["TRACKIO_SPACE_ID"] = trackio_space_id | |
| os.environ["TRACKIO_PROJECT"] = trackio_project | |
    # Derive a unique run name: model, difficulty, UTC timestamp, short SHA.
    model_slug = model_name.replace("/", "-")
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = run_name or (
        f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-{stamp}-{git_sha[:8]}"
    )
    output_dir = RUNS_DIR / run_name
    output_dir.mkdir(parents=True, exist_ok=True)
    # Best-effort reload so this container sees snapshots committed by
    # earlier runs on the shared cache volume.
    try:
        cache_volume.reload()
        print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}")
    except Exception as exc:
        print(f"Model cache volume reload skipped: {exc!r}")
    # Re-apply cache env and recreate directories after the volume reload.
    cache_env = _configure_modal_cache_env()
    training_prompt = (
        "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
        "OpenEnv environment. Use only the provided local tools. Do not target real "
        "systems. Work step by step: inspect policy and generated code, reproduce the "
        "authorization issue locally, submit a policy-tied finding, patch the generated "
        "app, run visible tests, then submit the fix. Do not write explanations unless "
        "a tool argument needs evidence text."
    )
    # Every dataset row carries the same prompt with a distinct seed, so the
    # trainer rolls out dataset_size different episodes of the environment.
    dataset = Dataset.from_list(
        [
            {
                "prompt": [{"role": "user", "content": training_prompt}],
                "seed": seed_start + index,
                "difficulty": difficulty,
                "split": split,
            }
            for index in range(dataset_size)
        ]
    )
| def _state_snapshot(env: CybersecurityOwaspEnvironment) -> dict[str, Any]: | |
| state = env.state | |
| return { | |
| "episode_id": state.episode_id, | |
| "task_id": state.task_id, | |
| "seed": state.seed, | |
| "split": state.split, | |
| "difficulty": state.difficulty, | |
| "domain": state.domain, | |
| "bug_family": state.bug_family, | |
| "phase": state.phase, | |
| "step_count": state.step_count, | |
| "done": state.done, | |
| "success": state.success, | |
| "failure_reason": state.failure_reason, | |
| "anti_cheat_flags": list(state.anti_cheat_flags), | |
| } | |
    class CyberSecurityOWASPToolEnv:
        """Tool-call wrapper around CybersecurityOwaspEnvironment for the trainer.

        Each public method is exposed to the model as a callable tool; every
        call steps the underlying environment once and records a conversation
        trace plus reward/validity bookkeeping for later Trackio logging.
        """

        def __init__(self):
            # Fresh in-process environment per rollout instance.
            self._env = CybersecurityOwaspEnvironment()
            self.reward = 0.0
            self.reward_breakdown: dict[str, float] = {}
            self.done = False
            self.success = False
            self.invalid_actions = 0
            self.trace_messages: list[dict[str, str]] = []
            self.trace_metadata: dict[str, Any] = {}

        def reset(self, **kwargs) -> str:
            """Start a new episode; kwargs may override seed/difficulty/split."""
            seed = int(kwargs.get("seed", seed_start))
            current_difficulty = int(kwargs.get("difficulty", difficulty))
            current_split = str(kwargs.get("split", split))
            obs = self._env.reset(
                seed=seed,
                split=current_split,
                difficulty=current_difficulty,
            )
            self.reward = 0.0
            self.reward_breakdown = {}
            self.done = bool(obs.done)
            self.success = False
            self.invalid_actions = 0
            # Seed the trace with the system prompt plus the first observation.
            self.trace_messages = [
                {
                    "role": "user",
                    "content": (
                        f"{training_prompt}\n\nInitial observation:\n"
                        f"Phase: {obs.phase}\n"
                        f"Task: {obs.task_brief}\n"
                        f"Available actions: {obs.available_actions}\n"
                        f"Workspace summary: {obs.workspace_summary}\n"
                        f"Policy hint: {obs.visible_policy_hint}\n"
                        f"Message: {obs.message}"
                    ),
                }
            ]
            self.trace_metadata = _state_snapshot(self._env)
            return obs.message

        def _step(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
            """Execute one environment action and update trace/reward state."""
            if self.done:
                raise ValueError("Episode is already over.")
            action = CyberSecurityOWASPAction(
                tool_name=tool_name,
                arguments=arguments or {},
            )
            obs = self._env.step(action)
            if not obs.last_action_valid:
                self.invalid_actions += 1
            # Prefer the breakdown total; fall back to the scalar reward.
            self.reward = float(obs.reward_breakdown.get("total", obs.reward or 0.0))
            self.reward_breakdown = dict(obs.reward_breakdown or {})
            self.done = bool(obs.done)
            self.success = bool(self._env.state.success)
            self.trace_messages.extend(
                [
                    {
                        "role": "assistant",
                        "content": f"{tool_name}({arguments or {}})",
                    },
                    {"role": "tool", "content": obs.message},
                ]
            )
            self.trace_metadata.update(_state_snapshot(self._env))
            self.trace_metadata.update(
                {
                    "last_action_valid": obs.last_action_valid,
                    "last_action_error": obs.last_action_error,
                    "reward": self.reward,
                    "reward_breakdown": self.reward_breakdown,
                    "invalid_actions": self.invalid_actions,
                }
            )
            return obs.message

        def inspect_policy_graph(self) -> str:
            """Return public policy hints for the generated local scenario."""
            return self._step("inspect_policy_graph")

        def list_routes(self) -> str:
            """List generated local app route summaries."""
            return self._step("list_routes")

        def read_openapi(self) -> str:
            """Read generated OpenAPI metadata for the local app."""
            return self._step("read_openapi")

        def read_file(self, path: str) -> str:
            """
            Read an editable generated workspace file by relative path.

            Args:
                path: Relative path inside the generated editable workspace.

            Returns:
                The file contents or a safe tool error observation.
            """
            return self._step("read_file", {"path": path})

        def search_code(self, query: str) -> str:
            """
            Search editable generated workspace files for a string.

            Args:
                query: Search text to find in editable generated app files.

            Returns:
                Matching file lines or a no-match message.
            """
            return self._step("search_code", {"query": query})

        def send_local_request(
            self,
            path: str,
            method: str = "GET",
            user_id: str | None = None,
        ) -> str:
            """
            Send a request to the generated local app only.

            Args:
                path: Local route path such as /health or /invoices/<id>.
                method: HTTP method to use for the local request.
                user_id: Optional generated user identifier for authentication.

            Returns:
                JSON response from the simulated local app request.
            """
            return self._step(
                "send_local_request",
                {"path": path, "method": method, "user_id": user_id},
            )

        def compare_identities(
            self,
            path: str,
            first_user_id: str,
            second_user_id: str,
            method: str = "GET",
        ) -> str:
            """
            Compare one local request as two generated users.

            Args:
                path: Local route path to request as both generated users.
                first_user_id: First generated user identifier.
                second_user_id: Second generated user identifier.
                method: HTTP method to use for both local requests.

            Returns:
                JSON summary of both simulated local responses.
            """
            return self._step(
                "compare_identities",
                {
                    "path": path,
                    "method": method,
                    "first_user_id": first_user_id,
                    "second_user_id": second_user_id,
                },
            )

        def submit_finding(
            self,
            summary: str,
            evidence: str,
            policy_rule: str,
        ) -> str:
            """
            Submit structured evidence for the suspected authorization bug.

            Args:
                summary: Concise description of the suspected access-control bug.
                evidence: Local reproduction evidence from policy, code, or requests.
                policy_rule: Policy rule that the observed behavior violates.

            Returns:
                Finding acceptance result and next phase information.
            """
            return self._step(
                "submit_finding",
                {
                    "summary": summary,
                    "evidence": evidence,
                    "policy_rule": policy_rule,
                },
            )

        def patch_file(
            self,
            path: str,
            content: str | None = None,
            diff: str | None = None,
        ) -> str:
            """
            Patch an editable generated app file with full content or a unified diff.

            Args:
                path: Relative path of the editable generated app file to patch.
                content: Complete replacement file content, when using full-file patching.
                diff: Unified diff to apply, when using diff patching.

            Returns:
                Patch application result.
            """
            args: dict[str, Any] = {"path": path}
            if content is not None:
                args["content"] = content
            if diff is not None:
                args["diff"] = diff
            return self._step("patch_file", args)

        def run_visible_tests(self) -> str:
            """Run visible tests only; hidden tests are never exposed."""
            return self._step("run_visible_tests")

        def submit_fix(self) -> str:
            """Submit the final patch to the hidden deterministic verifier."""
            return self._step("submit_fix")

        def noop(self) -> str:
            """Take no action."""
            return self._step("noop")

        def _score(self) -> float:
            """Reward of the most recent step; read by the GRPO reward function."""
            return float(self.reward)

        def __del__(self):
            # Best-effort cleanup of the wrapped environment.
            try:
                self._env.close()
            except Exception:
                pass
| trace_step = {"value": 0} | |
| def _completion_to_text(completion) -> str: | |
| if completion is None: | |
| return "" | |
| if isinstance(completion, str): | |
| return completion | |
| if isinstance(completion, list): | |
| parts = [] | |
| for item in completion: | |
| if isinstance(item, dict): | |
| parts.append(str(item.get("content", item))) | |
| else: | |
| parts.append(str(item)) | |
| return "\n".join(parts) | |
| return str(completion) | |
| def _mean(values: list[float]) -> float: | |
| return float(sum(values) / len(values)) if values else 0.0 | |
    def cybersecurity_owasp_reward(environments, **kwargs) -> list[float]:
        """GRPO reward function: score each rollout env and log metrics/traces.

        Returns one float reward per environment. Also logs aggregate episode
        metrics, a small sample-trace table, and per-sample conversation
        traces to Trackio; every logging step is best-effort and never fails
        the reward call itself.
        """
        rewards = [float(env._score()) for env in environments]
        # TRL versions differ on the kwarg name for completions.
        completions = kwargs.get("completions") or kwargs.get("completion") or []
        trace_step["value"] += 1
        episode_records = []
        for env, reward in zip(environments, rewards):
            record = episode_record_from_state(
                env._env.state,
                run_context={
                    "base_model": model_name,
                    "algo": "grpo",
                    "reward_version": "reward_v1",
                    "env_version": "0.1.0",
                },
            )
            record.update(
                {
                    "reward_total": reward,
                    "success": bool(getattr(env, "success", False)),
                }
            )
            episode_records.append(record)
        canonical_metrics = aggregate_episode_metrics(episode_records)
        metrics = {
            **canonical_metrics,
            **train_metric_aliases(canonical_metrics),
        }
        if rewards:
            metrics["train/reward_mean"] = _mean(rewards)
            metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
        try:
            log_trackio_metrics(metrics, step=trace_step["value"])
        except Exception as exc:
            print(f"Trackio metric logging skipped: {exc!r}")
        # Log at most four sample episodes as a table.
        try:
            log_trace_table(
                episode_records[: min(4, len(episode_records))],
                table_name="sample_traces",
                step=trace_step["value"],
            )
        except Exception as exc:
            print(f"Trackio sample trace table logging skipped: {exc!r}")
        # Per-sample conversation traces, including the raw model completion.
        for index, env in enumerate(environments):
            messages = list(getattr(env, "trace_messages", []))
            if index < len(completions):
                completion_text = _completion_to_text(completions[index])
                if completion_text:
                    messages.append(
                        {
                            "role": "assistant",
                            "content": f"Raw generated completion:\n{completion_text}",
                        }
                    )
            metadata = dict(getattr(env, "trace_metadata", {}))
            metadata.update(
                {
                    "sample_index": index,
                    "reward": rewards[index],
                    "trace_step": trace_step["value"],
                    "run_name": run_name,
                }
            )
            try:
                trackio.log(
                    {
                        f"cybersecurity_owasp_trace/sample_{index}": trackio.Trace(
                            messages=messages,
                            metadata=metadata,
                        )
                    },
                    step=trace_step["value"],
                )
            except Exception as exc:
                print(f"Trackio trace logging skipped: {exc!r}")
        if rewards:
            print(
                "Reward batch: "
                f"mean={statistics.mean(rewards):.3f}, "
                f"min={min(rewards):.3f}, max={max(rewards):.3f}"
            )
        return rewards
    class TrackioSystemMetricsCallback(TrainerCallback):
        """Trainer callback that mirrors GPU/system metrics into Trackio.

        All logging is best-effort: any failure is printed and training
        continues unaffected.
        """

        def on_train_begin(self, args, state, control, **kwargs):
            try:
                metrics = log_gpu_metrics(step=int(state.global_step or 0))
                # Record one-time run facts alongside the first GPU sample
                # (cache_hit/push_to_hub are closed over from the outer run).
                log_trackio_metrics(
                    {
                        "system/model_cache_hit": float(cache_hit),
                        "system/hub_push_enabled": float(push_to_hub),
                    },
                    step=int(state.global_step or 0),
                )
            except Exception as exc:
                print(f"Trackio GPU metrics initialization skipped: {exc!r}")
                return control
            if metrics:
                system_summary = ", ".join(
                    f"{key}={value}"
                    for key, value in sorted(metrics.items())
                    if key.startswith("system/")
                )
                print(f"Trackio GPU metrics initialized: {system_summary}")
            return control

        def on_log(self, args, state, control, logs=None, **kwargs):
            try:
                metrics = log_gpu_metrics(step=int(state.global_step or 0))
            except Exception as exc:
                print(f"Trackio GPU metrics skipped: {exc!r}")
                return control
            if metrics:
                summary = ", ".join(f"{key}={value}" for key, value in sorted(metrics.items())[:4])
                print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}")
            return control

        def on_train_end(self, args, state, control, **kwargs):
            # One final GPU sample at the end of training.
            try:
                log_gpu_metrics(step=int(state.global_step or 0))
            except Exception as exc:
                print(f"Trackio final GPU metrics skipped: {exc!r}")
            return control
    # Announce the resolved run configuration before any heavy work starts.
    print(f"CUDA available: {torch.cuda.is_available()}")
    if source_mode == "public":
        print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")
    else:
        print(f"Packaged local CyberSecurity_OWASP repo; default env repo id: {env_repo_id}")
    print(f"Trackio Space: {trackio_space_id}")
    print(f"Trackio Project: {trackio_project}")
    print(f"Output repo: {output_repo_id}")
    print(f"Run name: {run_name}")
    print(f"Model cache volume: {CACHE_VOLUME_NAME}")
    print(f"HF_HOME: {cache_env['HF_HOME']}")
    print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
    print(f"Torch cache: {cache_env['TORCH_HOME']}")
    print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
    print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
    print(f"Hub push enabled: {push_to_hub}")
    # On a cache miss, prefetch the model snapshot explicitly so the long
    # download shows Hugging Face progress output and gets committed.
    expected_model_cache = _hf_model_cache_path(model_name)
    cache_hit = expected_model_cache.exists()
    print(f"Expected HF model cache path: {expected_model_cache}")
    print(f"Model cache hit before load: {cache_hit}")
    if cache_hit:
        print("Using cached model snapshot from the persistent Modal volume when valid.")
    else:
        print(
            "Model cache miss. Downloading model weights once into the persistent "
            "Modal cache volume; Hugging Face progress output should follow."
        )
        try:
            snapshot_path = snapshot_download(
                repo_id=model_name,
                cache_dir=str(HF_HUB_CACHE_DIR),
                token=hf_token,
            )
            print(f"Model snapshot ready: {snapshot_path}")
            cache_volume.commit()
            print(f"Committed Modal model cache volume after snapshot download: {CACHE_VOLUME_NAME}")
        except Exception as exc:
            # Best-effort prefetch: fall through and let from_pretrained download.
            print(
                "Explicit model snapshot prefetch failed; Unsloth will attempt the "
                f"model load directly. Error: {exc!r}"
            )
    print(f"Loading model with Unsloth from_pretrained: {model_name}")
    # Gemma 4 checkpoints go through the vision API, matching the Unsloth
    # Gemma notebook pattern; all other models use the text-only API.
    model_api = FastVisionModel if "gemma-4" in model_name.lower() else FastLanguageModel
    model, tokenizer = model_api.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        fast_inference=False,
        cache_dir=str(HF_HUB_CACHE_DIR),
        token=hf_token,
    )
    print("Model load complete.")
    cache_volume.commit()
    print(f"Committed Modal model cache volume after model load: {CACHE_VOLUME_NAME}")
    # Attach a response schema to the tokenizer. If the tokenizer's chat
    # template does not support it, clone a schema-capable template from a
    # Qwen model and retry; Gemma 4 processors are skipped entirely.
    try:
        tokenizer = add_response_schema(tokenizer)
    except Exception as exc:
        if "gemma-4" in model_name.lower():
            print(
                "Tokenizer response schema add skipped for Gemma 4 processor, "
                "matching the Unsloth Gemma 4 GRPO notebook pattern: "
                f"{exc!r}"
            )
        else:
            print(f"Tokenizer response schema add failed before cloning: {exc!r}")
            for template_source in ("Qwen/Qwen3-0.6B", "Qwen/Qwen2.5-0.5B-Instruct"):
                try:
                    model, tokenizer, added_tokens = clone_chat_template(
                        model,
                        tokenizer,
                        template_source,
                    )
                    print(
                        "Cloned response-schema-capable chat template "
                        f"from {template_source}; added {len(added_tokens)} tokens."
                    )
                    tokenizer = add_response_schema(tokenizer)
                    break
                except Exception as clone_exc:
                    print(
                        "Tokenizer response schema fallback failed for "
                        f"{template_source}: {clone_exc!r}"
                    )
            else:
                # No fallback template worked: surface the original failure.
                raise
    # Attach the LoRA adapter to the attention and MLP projection layers.
    model = model_api.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    if hasattr(model_api, "for_training"):
        model_api.for_training(model)
    print("LoRA adapter attached and model switched to training mode.")
    # GRPO hyperparameters. Keys not accepted by the installed TRL version
    # are filtered out below so the script tolerates TRL API drift.
    grpo_config_values = {
        "temperature": 1.0,
        "learning_rate": 5e-6,
        "weight_decay": 0.001,
        "warmup_ratio": 0.1,
        "lr_scheduler_type": "linear",
        "optim": "adamw_8bit",
        "logging_steps": 1,
        "per_device_train_batch_size": 1,
        "gradient_accumulation_steps": max(2, num_generations),
        "num_generations": num_generations,
        "max_prompt_length": max_seq_length,
        "max_completion_length": max_completion_length,
        "max_steps": max_steps,
        "save_steps": max(10, max_steps),
        "report_to": "trackio",
        "project": trackio_project,
        "trackio_space_id": trackio_space_id,
        "run_name": run_name,
        "output_dir": str(output_dir),
        "push_to_hub": push_to_hub,
        "hub_model_id": output_repo_id,
        "hub_private_repo": True,
        "hub_strategy": "every_save",
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "epsilon": 0.2,
        "epsilon_high": 0.28,
        "delta": 1.5,
        "loss_type": "bnpo",
        "mask_truncated_completions": True,
    }
    # Keep only the keys the installed GRPOConfig signature actually accepts.
    grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
    skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
    if skipped_config_keys:
        print(f"Skipping unsupported GRPOConfig keys: {skipped_config_keys}")
    training_args = GRPOConfig(
        **{
            key: value
            for key, value in grpo_config_values.items()
            if key in grpo_config_parameters
        }
    )
    # Trainer arguments, filtered the same way against the installed
    # GRPOTrainer signature to tolerate TRL API drift.
    trainer_values = {
        "model": model,
        "processing_class": tokenizer,
        "reward_funcs": cybersecurity_owasp_reward,
        "args": training_args,
        "train_dataset": dataset,
        "environment_factory": CyberSecurityOWASPToolEnv,
        "callbacks": [TrackioSystemMetricsCallback()],
    }
    trainer_parameters = set(inspect.signature(GRPOTrainer).parameters)
    skipped_trainer_keys = sorted(set(trainer_values) - trainer_parameters)
    if skipped_trainer_keys:
        print(f"Skipping unsupported GRPOTrainer keys: {skipped_trainer_keys}")
    trainer = GRPOTrainer(
        **{
            key: value
            for key, value in trainer_values.items()
            if key in trainer_parameters
        }
    )
    print("Starting GRPO trainer.train().")
    # Background heartbeat so long generation/backprop phases still log output.
    heartbeat_stop = threading.Event()

    def _training_heartbeat() -> None:
        """Print an elapsed-time heartbeat every 30s until training finishes."""
        start_time = time.monotonic()
        while not heartbeat_stop.wait(30):
            elapsed = int(time.monotonic() - start_time)
            print(
                "Training heartbeat: still inside trainer.train() "
                f"after {elapsed}s. For this smoke, the slow part is usually "
                f"Gemma generation/backprop on L4: {num_generations} completions "
                f"up to {max_completion_length} tokens, plus Trackio upload."
            )

    heartbeat_thread = threading.Thread(
        target=_training_heartbeat,
        name="grpo-training-heartbeat",
        daemon=True,
    )
    heartbeat_thread.start()
    try:
        trainer.train()
    finally:
        # Always stop the heartbeat, even if training raised.
        heartbeat_stop.set()
        heartbeat_thread.join(timeout=2)
    print("GRPO trainer.train() complete.")
    if push_to_hub:
        print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
        trainer.push_to_hub()
        print("Hub push complete.")
    else:
        print("Skipping Hub push for this run. Pass --push-to-hub to upload adapters.")
    # Persist checkpoints and the model cache across containers.
    volume.commit()
    cache_volume.commit()
    print(f"Committed run volume: {VOLUME_NAME}")
    print(f"Committed model cache volume: {CACHE_VOLUME_NAME}")
    try:
        trackio.finish()
    except RuntimeError as exc:
        print(f"Trackio finish skipped because the trainer already finalized it: {exc}")
    # Summary of the resolved run configuration, returned to the launcher.
    return {
        "run_name": run_name,
        "env_repo_id": env_repo_id,
        "output_repo_id": output_repo_id,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
        "max_steps": max_steps,
        "dataset_size": dataset_size,
        "difficulty": difficulty,
        "split": split,
        "model_name": model_name,
        "max_completion_length": max_completion_length,
        "num_generations": num_generations,
        "source_mode": source_mode,
        "repo_url": repo_url,
        "repo_branch": repo_branch,
        "push_to_hub": push_to_hub,
    }
def main(
    mode: str = "train",
    env_repo_id: str = "",
    output_repo_id: str = "",
    max_steps: int = 10,
    dataset_size: int = 16,
    difficulty: int = 0,
    split: str = "train",
    model_name: str = DEFAULT_GEMMA_MODEL,
    max_seq_length: int = 4096,
    max_completion_length: int = 768,
    lora_rank: int = 32,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-grpo",
    num_generations: int = 2,
    seed_start: int = 0,
    git_sha: str = "nogit",
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    detach: bool = False,
    push_to_hub: bool = False,
) -> None:
    """Configure and launch the Modal GRPO run from the local machine.

    ``mode="config"`` only runs the remote import check and returns.
    ``mode="train"`` resolves Trackio/Hub identifiers (CLI value, then
    environment, then a best-effort lookup via the local HF_TOKEN), builds a
    run name from the model, difficulty, UTC timestamp and git SHA, prints a
    launch summary, and finally invokes the remote training function —
    detached via ``.spawn`` when ``detach`` is set, otherwise blocking on
    ``.remote``.

    Raises:
        ValueError: if ``mode`` is neither ``"train"`` nor ``"config"``.
    """
    if mode == "config":
        print(check_training_imports.remote())
        return
    if mode != "train":
        raise ValueError("mode must be 'train' or 'config'")

    # Empty CLI values fall back to environment variables, then to the
    # project defaults.
    if not trackio_space_id:
        trackio_space_id = os.environ.get(
            "TRACKIO_SPACE_ID",
            "Humanlearning/CyberSecurity_OWASP-trackio",
        )
    if not trackio_project:
        trackio_project = os.environ.get("TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo")

    # Best-effort local resolution of the Trackio Space and output repo so the
    # summary below can show real URLs. The remote job re-derives these from
    # HF_TOKEN when they stay empty, so failures here are only logged.
    resolved_space = trackio_space_id
    resolved_repo = output_repo_id
    if not (resolved_space and resolved_repo):
        token = os.environ.get("HF_TOKEN")
        if token:
            try:
                from huggingface_hub import whoami

                hf_user = whoami(token=token)["name"]
                if not resolved_space:
                    if hf_user == "humandotlearning":
                        resolved_space = f"{hf_user}/CyberSecurity_OWASP-trackio"
                    else:
                        resolved_space = "Humanlearning/CyberSecurity_OWASP-trackio"
                if not resolved_repo:
                    resolved_repo = (
                        f"{hf_user}/CyberSecurity_OWASP-"
                        f"{_model_repo_slug(model_name)}-grpo-lora"
                    )
            except Exception as exc:
                print(f"Could not resolve Hugging Face defaults locally: {exc!r}")

    # Stamp the run with the local checkout's HEAD when no SHA was supplied;
    # keep the "nogit" sentinel when git is unavailable.
    if git_sha == "nogit":
        try:
            git_sha = subprocess.check_output(
                ["git", "rev-parse", "HEAD"],
                cwd=PROJECT_ROOT,
                text=True,
                stderr=subprocess.DEVNULL,
            ).strip()
        except Exception:
            git_sha = "nogit"

    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = (
        f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo-level{difficulty}-"
        f"{stamp}-{git_sha[:8]}"
    )

    # Launch summary so the operator knows where logs, metrics, and adapters
    # will land before the (potentially slow) Modal build starts.
    print(f"Run name: {run_name}")
    print(f"Source mode: {source_mode}")
    if source_mode == "public":
        print(f"Public repo: {repo_url}@{repo_branch}")
    if resolved_space:
        print(f"Trackio Space: https://huggingface.co/spaces/{resolved_space}")
    else:
        print("Trackio Space: derived remotely from HF_TOKEN as <hf-user>/CyberSecurity_OWASP-trackio")
    if resolved_repo:
        print(f"Output model repo: https://huggingface.co/{resolved_repo}")
    else:
        print(
            "Output model repo: derived remotely from HF_TOKEN as "
            f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
        )
    print(f"Hub push enabled: {push_to_hub}")
    print(f"Model cache volume: {CACHE_VOLUME_NAME}")
    print("Launch phases:")
    for phase_line in (
        "1. Modal image build/validation: happens before remote Python logs; "
        "slow when local source or dependency layers changed.",
        "2. GPU container start on one L4 and persistent volume reload.",
        "3. Model cache check in CyberSecurity_OWASP-model-cache.",
        "4. Cached snapshot load into GPU RAM with Unsloth progress.",
        "5. One GRPO step, Trackio sync, and volume commit.",
    ):
        print(phase_line)
    print(
        "If there is a long pause after trainer.train() starts, watch for "
        "Training heartbeat lines every 30 seconds."
    )

    launch_kwargs = {
        "env_repo_id": env_repo_id,
        "output_repo_id": output_repo_id,
        "max_steps": max_steps,
        "dataset_size": dataset_size,
        "difficulty": difficulty,
        "split": split,
        "model_name": model_name,
        "max_seq_length": max_seq_length,
        "max_completion_length": max_completion_length,
        "lora_rank": lora_rank,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
        "num_generations": num_generations,
        "seed_start": seed_start,
        "git_sha": git_sha,
        "run_name": run_name,
        "source_mode": source_mode,
        "repo_url": repo_url,
        "repo_branch": repo_branch,
        "push_to_hub": push_to_hub,
    }
    if detach:
        call = train_cybersecurity_owasp_grpo.spawn(**launch_kwargs)
        print(f"Spawned Modal training call: {call.object_id}")
    else:
        result = train_cybersecurity_owasp_grpo.remote(**launch_kwargs)
        print(f"Training result: {result}")