Commit b3ee507
Parent(s): 28685f3

feat: update training configuration and documentation for Modal execution, including new model integration and enhanced tracking utilities

Files changed:
- .agents/skills/cybersecurity-owasp-trainer/SKILL.md +10 -8
- 01_ARCHITECTURE.md +4 -2
- README.md +5 -1
- pyproject.toml +1 -1
- scripts/modal_ephemeral_train.py +197 -9
- scripts/modal_train_grpo.py +292 -77
- tests/test_trackio_utils.py +93 -0
- training/__init__.py +1 -0
- training/configs/grpo_small.yaml +3 -1
- training/rollout.py +9 -2
- training/trackio_utils.py +869 -1
- training/train_grpo.py +48 -8
.agents/skills/cybersecurity-owasp-trainer/SKILL.md
CHANGED
@@ -7,7 +7,9 @@ description: Train, debug, evaluate, and document CyberSecurity_OWASP model runs
 
 ## Overview
 
-Use this skill to run or modify the CyberSecurity_OWASP training and evaluation loop without weakening the verifier, reward integrity, or hackathon evidence trail.
+Use this skill to run or modify the CyberSecurity_OWASP training and evaluation loop without weakening the verifier, reward integrity, or hackathon evidence trail. Training is expected to run on Modal only.
+
+Important: do **not** run GRPO/PPO training loops locally in this repo. Use Modal launchers (`scripts/modal_ephemeral_train.py` for smoke and `scripts/modal_train_grpo.py` for GRPO).
 
 ## References
 
@@ -37,28 +39,28 @@ Prefer the existing repo modules:
 
 - `training/rollout.py`: full OpenEnv episode loop, action JSON parsing, reward trace, rollout artifact fields.
 - `training/reward_funcs.py`: component reward functions exposed to TRL/GRPO.
-- `training/train_grpo.py`: `GRPOConfig`
+- `training/train_grpo.py`: `GRPOConfig`/model defaults and launch intent (does not run local training).
 - `training/eval_before_after.py`: baseline-vs-trained and held-out summary metrics.
 - `training/trackio_utils.py`: run naming, canonical metric names, Trackio init/log/finalize helpers.
 
 Default environment values:
 
 ```powershell
-$env:MODEL_NAME = "
+$env:MODEL_NAME = "google/gemma-2-2b-it"
 $env:TRACKIO_SPACE_ID = "Humanlearning/CyberSecurity_OWASP-trackio"
 $env:TRACKIO_PROJECT = "CyberSecurity_OWASP"
 $env:DIFFICULTY = "0"
 ```
 
-Use level-0 debug runs before scaling
+Use level-0 debug runs before scaling, and verify them through Modal smoke/ephemeral runs.
 
 ## Training Workflow
 
 1. Validate the environment first: run the targeted tests that cover models, reset/step/state, rewards, anti-cheat, seed reproducibility, invalid actions, and rollouts.
-2. Run a
+2. Run a Modal smoke path for lightweight config/run verification.
-3. Run a frozen-model or dummy-policy rollout and inspect the action trace, observations, terminal reason, and reward breakdown.
+3. Run a frozen-model or dummy-policy rollout on Modal and inspect the action trace, observations, terminal reason, and reward breakdown.
 4. Confirm Trackio receives component metrics and the run name follows `CyberSecurity_OWASP-<model>-<algo>-level<difficulty>-<YYYYMMDD-HHMM>-<git_sha>`.
-5. Start a very small GRPO run only after the above passes.
+5. Start a very small GRPO run only after the above passes. Start via `scripts/modal_train_grpo.py --mode train`.
 6. Evaluate baseline, trained, and held-out splits with `training/eval_before_after.py` and save summaries under `outputs/evals/`.
 7. Save sampled rollouts under `outputs/rollouts/` for baseline, mid-training, trained, and held-out evidence.
@@ -77,7 +79,7 @@ Stop or roll back if reward rises while sampled traces show deny-all patches, ha
 
 - Use TRL GRPO for verifier-driven rewards. Keep multiple independent reward functions for logging and diagnosis.
 - Keep the existing custom rollout path unless deliberately migrating to TRL's `environment_factory`. If migrating, preserve typed actions, observations, reward component logging, anti-cheat flags, and rollout artifacts.
-- Use
+- Use Modal as the default training path; local-only vLLM/GRPO execution is intentionally avoided in this repository.
 - For OpenEnv server training concurrency, ensure the server supports enough concurrent sessions for the generation batch.
 - Use Unsloth with LoRA or QLoRA for memory efficiency when the training machine supports it. Start from an instruct-capable checkpoint and verify the model has non-zero success probability before RL.
 - Pin and smoke-test TRL, Unsloth, vLLM, CUDA, and torch versions before longer runs.
01_ARCHITECTURE.md
CHANGED
@@ -397,16 +397,18 @@ Editable source: `assets/env_rl_training_flow_diagram.mmd`
 9. Produce final demo: before/after trace + reward curve + held-out eval table.
 ```
 
-Recommended initial training setup:
+Recommended initial training setup (Modal-first):
 
 ```text
-Model:
+Model: google/gemma-2-2b-it (or compatible Gemma-class instruct model)
 Algorithm: GRPO via TRL or Unsloth-compatible loop
 Dataset prompt: repeated task instruction with randomized scenario IDs
 Max steps per episode: 30
 Rollouts per prompt: 2-4
 Logging: Trackio
 Primary eval: held-out deterministic test pass rate
+
+Training execution is expected to run on Modal (persistent or ephemeral) rather than locally.
 ```
 
 ## 9. Deployment architecture
README.md
CHANGED
@@ -149,6 +149,10 @@ Training files are under `training/`:
 
 The training scaffold is intentionally minimal until the environment/verifier behavior is stable. Trackio metric names and GRPO defaults follow the project brief.
 
+`training/train_grpo.py` in this repo is a config helper only; it does not execute training locally.
+Use the Modal launchers in `scripts/modal_train_grpo.py` (persistent) and
+`scripts/modal_ephemeral_train.py` (smoke) for real GRPO runs.
+
 ## Trackio Run Tracking
 
 Trackio is the default tracker for official runs. Set `TRACKIO_SPACE_ID` to log to a hosted Hugging Face Trackio Space; otherwise Trackio records locally.
@@ -239,7 +243,7 @@ Defaults are derived from `HF_TOKEN`:
 
 - Trackio Space: `<hf-user>/CyberSecurity_OWASP-trackio`
 - Trackio project: `CyberSecurity_OWASP-grpo`
-- Output repo: `<hf-user>/CyberSecurity_OWASP-
+- Output repo: `<hf-user>/CyberSecurity_OWASP-gemma-2-2b-grpo-lora`
 
 Override these with `--trackio-space-id`, `--trackio-project`, and
 `--output-repo-id` when needed.
pyproject.toml
CHANGED
@@ -45,7 +45,7 @@ server = "CyberSecurity_OWASP.server.app:main"
 
 [tool.setuptools]
 include-package-data = true
-packages = ["CyberSecurity_OWASP", "CyberSecurity_OWASP.server"]
+packages = ["CyberSecurity_OWASP", "CyberSecurity_OWASP.server", "training"]
 package-dir = { "CyberSecurity_OWASP" = ".", "CyberSecurity_OWASP.server" = "server" }
 
 [tool.pytest.ini_options]
scripts/modal_ephemeral_train.py
CHANGED
@@ -12,6 +12,8 @@ the local process, so the run disappears when ``modal run`` exits.
 from __future__ import annotations
 
 import json
+import subprocess
+import time
 from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -20,6 +22,7 @@ import modal
 
 
 APP_NAME = "CyberSecurity_OWASP-ephemeral-training"
+SECRET_NAME = "CyberSecurity_OWASP-secrets"
 REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
 PROJECT_ROOT = Path(__file__).resolve().parents[1]
@@ -63,7 +66,11 @@ class NoopTrainer:
 ]
 
 
-@app.function(
+@app.function(
+    image=image,
+    timeout=60 * 30,
+    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
+)
 def run_ephemeral_smoke(
     episodes: int = 4,
     seed_start: int = 0,
@@ -75,17 +82,45 @@ def run_ephemeral_smoke(
         CybersecurityOwaspEnvironment,
     )
     from training.rollout import rollout_once
-    from training.trackio_utils import
+    from training.trackio_utils import (
+        aggregate_episode_metrics,
+        episode_record_from_state,
+        log_episode_batch,
+        log_trackio_metrics,
+        trace_table_rows,
+        trackio_run,
+    )
 
     baseline = []
     oracle = []
+    run_context = {
+        "algo": "modal_ephemeral_smoke",
+        "reward_version": "reward_v1",
+        "env_version": "0.1.0",
+    }
 
     for offset in range(episodes):
         seed = seed_start + offset
 
         baseline_env = CybersecurityOwaspEnvironment()
-
-
+        baseline_rollout = rollout_once(
+            NoopTrainer(),
+            baseline_env,
+            max_steps=5,
+            reset_kwargs={"seed": seed, "split": "validation", "difficulty": 0},
+        )
+        baseline_record = episode_record_from_state(
+            baseline_env.state,
+            run_context={**run_context, "base_model": "noop"},
+        )
+        baseline_record.update(
+            {
+                "reward_total": baseline_rollout.get("reward_total", 0.0),
+                "success": baseline_rollout.get("success", False),
+                "episode_length": baseline_rollout.get("episode_length", 0),
+            }
+        )
+        baseline.append(baseline_record)
 
         oracle_env = CybersecurityOwaspEnvironment()
         oracle_env.reset(seed=seed, split="validation")
@@ -124,19 +159,25 @@ def run_ephemeral_smoke(
         )
         oracle_env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
         final = oracle_env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
-
+        oracle_record = episode_record_from_state(
+            oracle_env.state,
+            run_context={**run_context, "base_model": "oracle"},
+            final_observation=final.model_dump(),
+        )
+        oracle_record.update(
             {
-                "seed": seed,
-                "success": oracle_env.state.success,
                 "reward_total": final.reward_breakdown.get("total", 0.0),
+                "success": oracle_env.state.success,
             }
         )
+        oracle.append(oracle_record)
 
     def mean(items: list[dict[str, Any]], key: str) -> float:
         return sum(float(item.get(key, 0.0)) for item in items) / max(1, len(items))
 
     run_name = f"{APP_NAME}-{datetime.utcnow().strftime('%Y%m%d-%H%M%S')}"
+    episode_records = [*baseline, *oracle]
+    tracking_metrics = aggregate_episode_metrics(episode_records)
     result = {
         "run_name": run_name,
         "mode": "smoke",
@@ -145,6 +186,8 @@ def run_ephemeral_smoke(
         "baseline_mean_reward": mean(baseline, "reward_total"),
         "oracle_mean_reward": mean(oracle, "reward_total"),
         "oracle_success_rate": mean(oracle, "success"),
+        "tracking_metrics": tracking_metrics,
+        "tracking_trace_rows": trace_table_rows(episode_records),
         "baseline": baseline,
         "oracle": oracle,
     }
@@ -160,8 +203,10 @@ def run_ephemeral_smoke(
         },
         group="smoke",
     ):
+        logged_metrics = log_episode_batch(episode_records, step=0)
        log_trackio_metrics(
             {
+                **logged_metrics,
                 "smoke/baseline_mean_reward": result["baseline_mean_reward"],
                 "smoke/oracle_mean_reward": result["oracle_mean_reward"],
                 "smoke/oracle_success_rate": result["oracle_success_rate"],
@@ -179,6 +224,130 @@ def run_grpo_config_check() -> str:
     return str(build_grpo_config())
 
 
+@app.function(
+    image=image,
+    timeout=60 * 10,
+    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
+)
+def verify_trackio_run(
+    run_name: str,
+    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
+    trackio_project: str = "CyberSecurity_OWASP-smoke",
+) -> dict[str, Any]:
+    import os
+    from training.trackio_utils import (
+        REQUIRED_SMOKE_TRACKIO_ITEMS,
+        missing_required_trackio_items,
+    )
+
+    hf_token = os.environ["HF_TOKEN"]
+    cmd = [
+        "trackio",
+        "get",
+        "run",
+        "--project",
+        trackio_project,
+        "--run",
+        run_name,
+        "--space",
+        trackio_space_id,
+        "--hf-token",
+        hf_token,
+        "--json",
+    ]
+    metrics_cmd = [
+        "trackio",
+        "list",
+        "metrics",
+        "--project",
+        trackio_project,
+        "--run",
+        run_name,
+        "--space",
+        trackio_space_id,
+        "--hf-token",
+        hf_token,
+        "--json",
+    ]
+    last_result: dict[str, Any] = {}
+    for attempt in range(1, 4):
+        completed = subprocess.run(cmd, capture_output=True, text=True)
+        metrics_completed = subprocess.run(metrics_cmd, capture_output=True, text=True)
+        last_result = {
+            "attempt": attempt,
+            "returncode": completed.returncode,
+            "stdout": completed.stdout[-4000:],
+            "stderr": completed.stderr[-4000:],
+            "metrics_returncode": metrics_completed.returncode,
+            "metrics_stdout": metrics_completed.stdout[-4000:],
+            "metrics_stderr": metrics_completed.stderr[-4000:],
+        }
+        if completed.returncode == 0:
+            data = json.loads(completed.stdout)
+            if metrics_completed.returncode == 0:
+                metrics_data = json.loads(metrics_completed.stdout)
+                if isinstance(metrics_data.get("metrics"), list):
+                    data["metrics"] = metrics_data["metrics"]
+            missing = missing_required_trackio_items(data, REQUIRED_SMOKE_TRACKIO_ITEMS)
+            return {
+                "ok": not missing,
+                "trackio_space_id": trackio_space_id,
+                "trackio_project": trackio_project,
+                "run_name": run_name,
+                "required_items": list(REQUIRED_SMOKE_TRACKIO_ITEMS),
+                "missing_required_items": missing,
+                "run": data,
+            }
+        time.sleep(10)
+    return {
+        "ok": False,
+        "trackio_space_id": trackio_space_id,
+        "trackio_project": trackio_project,
+        "run_name": run_name,
+        "last_result": last_result,
+    }
+
+
+@app.function(
+    image=image,
+    timeout=60 * 10,
+    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
+)
+def inspect_trackio_space(
+    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
+) -> dict[str, Any]:
+    import os
+
+    hf_token = os.environ["HF_TOKEN"]
+
+    def run_trackio(args: list[str]) -> dict[str, Any]:
+        completed = subprocess.run(
+            ["trackio", *args, "--space", trackio_space_id, "--hf-token", hf_token, "--json"],
+            capture_output=True,
+            text=True,
+        )
+        result = {
+            "returncode": completed.returncode,
+            "stdout": completed.stdout[-8000:],
+            "stderr": completed.stderr[-4000:],
+        }
+        if completed.returncode == 0:
+            result["json"] = json.loads(completed.stdout)
+        return result
+
+    projects_result = run_trackio(["list", "projects"])
+    projects = (projects_result.get("json") or {}).get("projects", [])
+    runs_by_project = {
+        project: run_trackio(["list", "runs", "--project", project])
+        for project in projects
+    }
+    return {
+        "trackio_space_id": trackio_space_id,
+        "projects": projects_result,
+        "runs_by_project": runs_by_project,
+    }
+
+
 @app.local_entrypoint()
 def main(
     mode: str = "smoke",
@@ -186,6 +355,7 @@ def main(
     seed_start: int = 0,
     trackio_space_id: str = "",
     trackio_project: str = "CyberSecurity_OWASP-smoke",
+    run_name: str = "",
 ) -> None:
     if mode == "smoke":
         result = run_ephemeral_smoke.remote(
@@ -201,5 +371,23 @@ def main(
         print(json.dumps({"saved": str(output_path), **result}, indent=2, sort_keys=True))
     elif mode == "grpo-config":
         print(run_grpo_config_check.remote())
+    elif mode == "verify-trackio":
+        if not run_name:
+            raise ValueError("--run-name is required for verify-trackio mode")
+        result = verify_trackio_run.remote(
+            run_name=run_name,
+            trackio_space_id=trackio_space_id
+            or "Humanlearning/CyberSecurity_OWASP-trackio",
+            trackio_project=trackio_project,
+        )
+        print(json.dumps(result, indent=2, sort_keys=True))
+    elif mode == "inspect-trackio":
+        result = inspect_trackio_space.remote(
+            trackio_space_id=trackio_space_id
+            or "Humanlearning/CyberSecurity_OWASP-trackio",
+        )
+        print(json.dumps(result, indent=2, sort_keys=True))
     else:
-        raise ValueError(
+        raise ValueError(
+            "mode must be 'smoke', 'grpo-config', 'verify-trackio', or 'inspect-trackio'"
+        )
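The smoke run's `mean` helper in this diff has two deliberate properties worth calling out: missing keys count as 0.0, and `max(1, len(items))` guards the division on an empty batch. A minimal standalone sketch of the same helper:

```python
from typing import Any


def mean(items: list[dict[str, Any]], key: str) -> float:
    # Missing keys default to 0.0; max(1, len(items)) avoids ZeroDivisionError
    # when the episode list is empty.
    return sum(float(item.get(key, 0.0)) for item in items) / max(1, len(items))


oracle = [
    {"reward_total": 1.0, "success": True},
    {"reward_total": 0.5, "success": False},
]
print(mean(oracle, "reward_total"))  # 0.75
print(mean(oracle, "success"))       # 0.5 (bools coerce to 1.0/0.0)
print(mean([], "reward_total"))      # 0.0
```

Note that an empty batch yields 0.0 rather than raising, which is the behavior the smoke summary relies on when a split produces no episodes.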
scripts/modal_train_grpo.py
CHANGED
@@ -28,12 +28,61 @@ import modal
 
 APP_NAME = "CyberSecurity_OWASP-grpo"
 VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
 SECRET_NAME = "CyberSecurity_OWASP-secrets"
 RUNS_DIR = pathlib.Path("/runs")
 REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
 PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
 PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
 PUBLIC_REPO_BRANCH = "master"
 
 
 def _load_local_env_file() -> None:
@@ -114,6 +163,7 @@ def _training_image() -> modal.Image:
         "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
         "unsloth[base] @ git+https://github.com/unslothai/unsloth",
     )
     .uv_pip_install("pydantic==2.10.6")
     .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps")
     .uv_pip_install("llm-blender", "weave")
@@ -159,22 +209,25 @@
 
 app = modal.App(APP_NAME)
 volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
 secrets = _modal_secrets()
 
 
 @app.function(
     image=_training_image(),
-    gpu=
     timeout=4 * 60 * 60,
-    volumes={RUNS_DIR: volume},
     secrets=secrets,
 )
 def check_training_imports() -> dict[str, str]:
     import torch
     import trackio
     from datasets import Dataset
     from trl import GRPOConfig, GRPOTrainer
-    from unsloth import FastLanguageModel
 
     from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
         CybersecurityOwaspEnvironment,
@@ -189,16 +242,19 @@
         "grpo_config": GRPOConfig.__name__,
         "grpo_trainer": GRPOTrainer.__name__,
         "unsloth_model": FastLanguageModel.__name__,
         "env": CybersecurityOwaspEnvironment.__name__,
         "reset_phase": obs.phase,
     }
 
 
 @app.function(
     image=_training_image(),
-    gpu=
     timeout=4 * 60 * 60,
-    volumes={RUNS_DIR: volume},
     secrets=secrets,
 )
 def train_cybersecurity_owasp_grpo(
@@ -208,11 +264,11 @@ def train_cybersecurity_owasp_grpo(
     dataset_size: int = 16,
     difficulty: int = 0,
     split: str = "train",
-    model_name: str =
     max_seq_length: int = 4096,
     max_completion_length: int = 768,
     lora_rank: int = 32,
-    trackio_space_id: str = "",
     trackio_project: str = "CyberSecurity_OWASP-grpo",
     num_generations: int = 2,
     seed_start: int = 0,
@@ -221,15 +277,18 @@
     source_mode: str = "local",
     repo_url: str = PUBLIC_REPO_URL,
     repo_branch: str = PUBLIC_REPO_BRANCH,
 ) -> dict[str, str | int | float]:
     import inspect
     import statistics
 
     import torch
-    from unsloth import FastLanguageModel
     import transformers.utils.hub as transformers_hub
     from datasets import Dataset
-    from huggingface_hub import whoami
     from transformers import TrainerCallback
     from trl import GRPOConfig, GRPOTrainer, clone_chat_template
     from trl.chat_template_utils import add_response_schema
@@ -240,14 +299,16 @@
     from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
         CybersecurityOwaspEnvironment,
     )
 
-
-    transformers_hub.TRANSFORMERS_CACHE = os.path.join(
-        os.path.expanduser("~"),
-        ".cache",
-        "huggingface",
-        "hub",
-    )
 
     hf_token = os.environ.get("HF_TOKEN")
     if not hf_token:
@@ -257,8 +318,20 @@
 
     user = whoami(token=hf_token)["name"]
     env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
-    output_repo_id = output_repo_id or
-
 
     os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
     os.environ["TRACKIO_PROJECT"] = trackio_project
@@ -271,6 +344,13 @@
     output_dir = RUNS_DIR / run_name
     output_dir.mkdir(parents=True, exist_ok=True)
 
     training_prompt = (
         "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
         "OpenEnv environment. Use only the provided local tools. Do not target real "
@@ -570,49 +650,48 @@
     completions = kwargs.get("completions") or kwargs.get("completion") or []
     trace_step["value"] += 1
 
-
     metrics = {
-
-        "train/reward_discovery_mean": _mean(
-            [float(item.get("discovery", 0.0)) for item in breakdowns]
-        ),
-        "train/reward_security_mean": _mean(
-            [float(item.get("security", 0.0)) for item in breakdowns]
-        ),
-        "train/reward_regression_mean": _mean(
-            [float(item.get("regression", 0.0)) for item in breakdowns]
-        ),
-        "train/reward_public_routes_mean": _mean(
-            [float(item.get("public_routes", 0.0)) for item in breakdowns]
-        ),
-        "train/reward_patch_quality_mean": _mean(
-            [float(item.get("patch_quality", 0.0)) for item in breakdowns]
-        ),
-        "train/reward_visible_tests_mean": _mean(
-            [float(item.get("visible_tests", 0.0)) for item in breakdowns]
-        ),
-        "train/reward_anti_cheat_mean": _mean(
-            [float(item.get("anti_cheat", 0.0)) for item in breakdowns]
-        ),
-        "train/success_rate": _mean(
-            [1.0 if bool(getattr(env, "success", False)) else 0.0 for env in environments]
-        ),
-        "train/invalid_action_rate": _mean(
-            [float(getattr(env, "invalid_actions", 0)) for env in environments]
-        ),
-        "train/episode_length_mean": _mean(
-            [
-                float(getattr(env, "trace_metadata", {}).get("step_count", 0))
-                for env in environments
-            ]
-        ),
     }
 
     try:
-
     except Exception as exc:
         print(f"Trackio metric logging skipped: {exc!r}")
 
     for index, env in enumerate(environments):
         messages = list(getattr(env, "trace_messages", []))
         if index < len(completions):
@@ -655,9 +734,24 @@
     return rewards
 
 class TrackioSystemMetricsCallback(TrainerCallback):
     def on_log(self, args, state, control, logs=None, **kwargs):
         try:
-            metrics =
         except Exception as exc:
             print(f"Trackio GPU metrics skipped: {exc!r}")
print(f"Trackio GPU metrics skipped: {exc!r}")
|
| 663 |
return control
|
|
@@ -666,6 +760,13 @@ def train_cybersecurity_owasp_grpo(
|
|
| 666 |
print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}")
|
| 667 |
return control
|
| 668 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 670 |
if source_mode == "public":
|
| 671 |
print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")
|
|
@@ -675,27 +776,114 @@ def train_cybersecurity_owasp_grpo(
|
|
| 675 |
print(f"Trackio Project: {trackio_project}")
|
| 676 |
print(f"Output repo: {output_repo_id}")
|
| 677 |
print(f"Run name: {run_name}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 678 |
|
| 679 |
-
model
|
|
|
|
|
|
|
| 680 |
model_name=model_name,
|
| 681 |
max_seq_length=max_seq_length,
|
| 682 |
load_in_4bit=False,
|
| 683 |
fast_inference=False,
|
|
|
|
| 684 |
token=hf_token,
|
| 685 |
)
|
|
|
|
|
|
|
|
|
|
| 686 |
try:
|
| 687 |
tokenizer = add_response_schema(tokenizer)
|
| 688 |
except Exception as exc:
|
| 689 |
-
|
| 690 |
-
|
| 691 |
-
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
-
|
| 696 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 697 |
|
| 698 |
-
model =
|
| 699 |
model,
|
| 700 |
r=lora_rank,
|
| 701 |
target_modules=[
|
|
@@ -711,7 +899,9 @@ def train_cybersecurity_owasp_grpo(
|
|
| 711 |
use_gradient_checkpointing="unsloth",
|
| 712 |
random_state=3407,
|
| 713 |
)
|
| 714 |
-
|
|
|
|
|
|
|
| 715 |
|
| 716 |
grpo_config_values = {
|
| 717 |
"temperature": 1.0,
|
|
@@ -732,7 +922,7 @@ def train_cybersecurity_owasp_grpo(
|
|
| 732 |
"trackio_space_id": trackio_space_id,
|
| 733 |
"run_name": run_name,
|
| 734 |
"output_dir": str(output_dir),
|
| 735 |
-
"push_to_hub":
|
| 736 |
"hub_model_id": output_repo_id,
|
| 737 |
"hub_private_repo": True,
|
| 738 |
"hub_strategy": "every_save",
|
|
@@ -742,7 +932,7 @@ def train_cybersecurity_owasp_grpo(
|
|
| 742 |
"epsilon_high": 0.28,
|
| 743 |
"delta": 1.5,
|
| 744 |
"loss_type": "bnpo",
|
| 745 |
-
"mask_truncated_completions":
|
| 746 |
}
|
| 747 |
grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
|
| 748 |
skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
|
|
@@ -776,9 +966,23 @@ def train_cybersecurity_owasp_grpo(
|
|
| 776 |
if key in trainer_parameters
|
| 777 |
}
|
| 778 |
)
|
|
|
|
| 779 |
trainer.train()
|
| 780 |
-
trainer.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 781 |
volume.commit()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 782 |
|
| 783 |
return {
|
| 784 |
"run_name": run_name,
|
|
@@ -796,6 +1000,7 @@ def train_cybersecurity_owasp_grpo(
|
|
| 796 |
"source_mode": source_mode,
|
| 797 |
"repo_url": repo_url,
|
| 798 |
"repo_branch": repo_branch,
|
|
|
|
| 799 |
}
|
| 800 |
|
| 801 |
|
|
@@ -808,11 +1013,11 @@ def main(
|
|
| 808 |
dataset_size: int = 16,
|
| 809 |
difficulty: int = 0,
|
| 810 |
split: str = "train",
|
| 811 |
-
model_name: str =
|
| 812 |
max_seq_length: int = 4096,
|
| 813 |
max_completion_length: int = 768,
|
| 814 |
lora_rank: int = 32,
|
| 815 |
-
trackio_space_id: str = "",
|
| 816 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 817 |
num_generations: int = 2,
|
| 818 |
seed_start: int = 0,
|
|
@@ -821,6 +1026,7 @@ def main(
|
|
| 821 |
repo_url: str = PUBLIC_REPO_URL,
|
| 822 |
repo_branch: str = PUBLIC_REPO_BRANCH,
|
| 823 |
detach: bool = False,
|
|
|
|
| 824 |
) -> None:
|
| 825 |
if mode == "config":
|
| 826 |
result = check_training_imports.remote()
|
|
@@ -829,7 +1035,10 @@ def main(
|
|
| 829 |
if mode != "train":
|
| 830 |
raise ValueError("mode must be 'train' or 'config'")
|
| 831 |
|
| 832 |
-
trackio_space_id = trackio_space_id or os.environ.get(
|
|
|
|
|
|
|
|
|
|
| 833 |
trackio_project = trackio_project or os.environ.get(
|
| 834 |
"TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo"
|
| 835 |
)
|
|
@@ -842,12 +1051,15 @@ def main(
|
|
| 842 |
from huggingface_hub import whoami
|
| 843 |
|
| 844 |
user = whoami(token=hf_token)["name"]
|
| 845 |
-
|
| 846 |
-
resolved_trackio_space_id
|
| 847 |
-
|
|
|
|
|
|
|
|
|
|
| 848 |
resolved_output_repo_id = (
|
| 849 |
resolved_output_repo_id
|
| 850 |
-
or f"{user}/CyberSecurity_OWASP-
|
| 851 |
)
|
| 852 |
except Exception as exc:
|
| 853 |
print(f"Could not resolve Hugging Face defaults locally: {exc!r}")
|
|
@@ -883,8 +1095,10 @@ def main(
|
|
| 883 |
else:
|
| 884 |
print(
|
| 885 |
"Output model repo: derived remotely from HF_TOKEN as "
|
| 886 |
-
"<hf-user>/CyberSecurity_OWASP-
|
| 887 |
)
|
|
|
|
|
|
|
| 888 |
|
| 889 |
kwargs = dict(
|
| 890 |
env_repo_id=env_repo_id,
|
|
@@ -906,6 +1120,7 @@ def main(
|
|
| 906 |
source_mode=source_mode,
|
| 907 |
repo_url=repo_url,
|
| 908 |
repo_branch=repo_branch,
|
|
|
|
| 909 |
)
|
| 910 |
if detach:
|
| 911 |
call = train_cybersecurity_owasp_grpo.spawn(**kwargs)
|
|
|
|
| 28 |
|
| 29 |
APP_NAME = "CyberSecurity_OWASP-grpo"
|
| 30 |
VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
|
| 31 |
+
CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
|
| 32 |
SECRET_NAME = "CyberSecurity_OWASP-secrets"
|
| 33 |
RUNS_DIR = pathlib.Path("/runs")
|
| 34 |
+
CACHE_DIR = pathlib.Path("/cache")
|
| 35 |
+
HF_HOME_DIR = CACHE_DIR / "huggingface"
|
| 36 |
+
HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
|
| 37 |
+
TORCH_HOME_DIR = CACHE_DIR / "torch"
|
| 38 |
+
XDG_CACHE_DIR = CACHE_DIR / "xdg"
|
| 39 |
+
UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
|
| 40 |
+
TRITON_CACHE_DIR = CACHE_DIR / "triton"
|
| 41 |
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
|
| 42 |
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
|
| 43 |
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
|
| 44 |
PUBLIC_REPO_BRANCH = "master"
|
| 45 |
+
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _model_repo_slug(model_name: str) -> str:
|
| 49 |
+
return (
|
| 50 |
+
model_name.replace("/", "-")
|
| 51 |
+
.replace("_", "-")
|
| 52 |
+
.replace(".", "-")
|
| 53 |
+
.lower()
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def _hf_model_cache_path(model_name: str) -> pathlib.Path:
|
| 58 |
+
return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def _configure_modal_cache_env() -> dict[str, str]:
|
| 62 |
+
values = {
|
| 63 |
+
"HF_HOME": str(HF_HOME_DIR),
|
| 64 |
+
"HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
|
| 65 |
+
"TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
|
| 66 |
+
"TORCH_HOME": str(TORCH_HOME_DIR),
|
| 67 |
+
"XDG_CACHE_HOME": str(XDG_CACHE_DIR),
|
| 68 |
+
"UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
|
| 69 |
+
"UNSLOTH_COMPILE_CACHE": str(UNSLOTH_CACHE_DIR / "compile"),
|
| 70 |
+
"TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
|
| 71 |
+
}
|
| 72 |
+
for key, value in values.items():
|
| 73 |
+
os.environ[key] = value
|
| 74 |
+
for path in {
|
| 75 |
+
CACHE_DIR,
|
| 76 |
+
HF_HOME_DIR,
|
| 77 |
+
HF_HUB_CACHE_DIR,
|
| 78 |
+
TORCH_HOME_DIR,
|
| 79 |
+
XDG_CACHE_DIR,
|
| 80 |
+
UNSLOTH_CACHE_DIR,
|
| 81 |
+
UNSLOTH_CACHE_DIR / "compile",
|
| 82 |
+
TRITON_CACHE_DIR,
|
| 83 |
+
}:
|
| 84 |
+
path.mkdir(parents=True, exist_ok=True)
|
| 85 |
+
return values
|
| 86 |
|
| 87 |
|
| 88 |
def _load_local_env_file() -> None:
|
|
|
|
| 163 |
"unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
|
| 164 |
"unsloth[base] @ git+https://github.com/unslothai/unsloth",
|
| 165 |
)
|
| 166 |
+
.uv_pip_install("timm", extra_options="--no-deps")
|
| 167 |
.uv_pip_install("pydantic==2.10.6")
|
| 168 |
.uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps")
|
| 169 |
.uv_pip_install("llm-blender", "weave")
|
|
|
|
| 209 |
|
| 210 |
app = modal.App(APP_NAME)
|
| 211 |
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
|
| 212 |
+
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
|
| 213 |
secrets = _modal_secrets()
|
| 214 |
|
| 215 |
|
| 216 |
@app.function(
|
| 217 |
image=_training_image(),
|
| 218 |
+
gpu="L4",
|
| 219 |
timeout=4 * 60 * 60,
|
| 220 |
+
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
| 221 |
secrets=secrets,
|
| 222 |
)
|
| 223 |
def check_training_imports() -> dict[str, str]:
|
| 224 |
+
cache_env = _configure_modal_cache_env()
|
| 225 |
+
|
| 226 |
import torch
|
| 227 |
import trackio
|
| 228 |
from datasets import Dataset
|
| 229 |
from trl import GRPOConfig, GRPOTrainer
|
| 230 |
+
from unsloth import FastLanguageModel, FastVisionModel
|
| 231 |
|
| 232 |
from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
|
| 233 |
CybersecurityOwaspEnvironment,
|
|
|
|
| 242 |
"grpo_config": GRPOConfig.__name__,
|
| 243 |
"grpo_trainer": GRPOTrainer.__name__,
|
| 244 |
"unsloth_model": FastLanguageModel.__name__,
|
| 245 |
+
"unsloth_vision_model": FastVisionModel.__name__,
|
| 246 |
"env": CybersecurityOwaspEnvironment.__name__,
|
| 247 |
"reset_phase": obs.phase,
|
| 248 |
+
"hf_home": cache_env["HF_HOME"],
|
| 249 |
+
"hf_hub_cache": cache_env["HF_HUB_CACHE"],
|
| 250 |
}
|
| 251 |
|
| 252 |
|
| 253 |
@app.function(
|
| 254 |
image=_training_image(),
|
| 255 |
+
gpu="L4",
|
| 256 |
timeout=4 * 60 * 60,
|
| 257 |
+
volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
|
| 258 |
secrets=secrets,
|
| 259 |
)
|
| 260 |
def train_cybersecurity_owasp_grpo(
|
|
|
|
| 264 |
dataset_size: int = 16,
|
| 265 |
difficulty: int = 0,
|
| 266 |
split: str = "train",
|
| 267 |
+
model_name: str = DEFAULT_GEMMA_MODEL,
|
| 268 |
max_seq_length: int = 4096,
|
| 269 |
max_completion_length: int = 768,
|
| 270 |
lora_rank: int = 32,
|
| 271 |
+
trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
|
| 272 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 273 |
num_generations: int = 2,
|
| 274 |
seed_start: int = 0,
|
|
|
|
| 277 |
source_mode: str = "local",
|
| 278 |
repo_url: str = PUBLIC_REPO_URL,
|
| 279 |
repo_branch: str = PUBLIC_REPO_BRANCH,
|
| 280 |
+
push_to_hub: bool = False,
|
| 281 |
) -> dict[str, str | int | float]:
|
| 282 |
import inspect
|
| 283 |
import statistics
|
| 284 |
|
| 285 |
+
cache_env = _configure_modal_cache_env()
|
| 286 |
+
|
| 287 |
import torch
|
| 288 |
+
from unsloth import FastLanguageModel, FastVisionModel
|
| 289 |
import transformers.utils.hub as transformers_hub
|
| 290 |
from datasets import Dataset
|
| 291 |
+
from huggingface_hub import snapshot_download, whoami
|
| 292 |
from transformers import TrainerCallback
|
| 293 |
from trl import GRPOConfig, GRPOTrainer, clone_chat_template
|
| 294 |
from trl.chat_template_utils import add_response_schema
|
|
|
|
| 299 |
from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
|
| 300 |
CybersecurityOwaspEnvironment,
|
| 301 |
)
|
| 302 |
+
from training.trackio_utils import (
|
| 303 |
+
aggregate_episode_metrics,
|
| 304 |
+
episode_record_from_state,
|
| 305 |
+
log_gpu_metrics,
|
| 306 |
+
log_trace_table,
|
| 307 |
+
log_trackio_metrics,
|
| 308 |
+
train_metric_aliases,
|
| 309 |
+
)
|
| 310 |
|
| 311 |
+
transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
|
| 313 |
hf_token = os.environ.get("HF_TOKEN")
|
| 314 |
if not hf_token:
|
|
|
|
| 318 |
|
| 319 |
user = whoami(token=hf_token)["name"]
|
| 320 |
env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
|
| 321 |
+
output_repo_id = output_repo_id or (
|
| 322 |
+
f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
|
| 323 |
+
)
|
| 324 |
+
if not trackio_space_id:
|
| 325 |
+
trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio"
|
| 326 |
+
if hf_token:
|
| 327 |
+
try:
|
| 328 |
+
from huggingface_hub import whoami
|
| 329 |
+
|
| 330 |
+
user = whoami(token=hf_token)["name"]
|
| 331 |
+
if user == "humandotlearning":
|
| 332 |
+
trackio_space_id = f"{user}/CyberSecurity_OWASP-trackio"
|
| 333 |
+
except Exception:
|
| 334 |
+
pass
|
| 335 |
|
| 336 |
os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
|
| 337 |
os.environ["TRACKIO_PROJECT"] = trackio_project
|
|
|
|
| 344 |
output_dir = RUNS_DIR / run_name
|
| 345 |
output_dir.mkdir(parents=True, exist_ok=True)
|
| 346 |
|
| 347 |
+
try:
|
| 348 |
+
cache_volume.reload()
|
| 349 |
+
print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}")
|
| 350 |
+
except Exception as exc:
|
| 351 |
+
print(f"Model cache volume reload skipped: {exc!r}")
|
| 352 |
+
cache_env = _configure_modal_cache_env()
|
| 353 |
+
|
| 354 |
training_prompt = (
|
| 355 |
"You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
|
| 356 |
"OpenEnv environment. Use only the provided local tools. Do not target real "
|
|
|
|
| 650 |
completions = kwargs.get("completions") or kwargs.get("completion") or []
|
| 651 |
trace_step["value"] += 1
|
| 652 |
|
| 653 |
+
episode_records = []
|
| 654 |
+
for env, reward in zip(environments, rewards):
|
| 655 |
+
record = episode_record_from_state(
|
| 656 |
+
env._env.state,
|
| 657 |
+
run_context={
|
| 658 |
+
"base_model": model_name,
|
| 659 |
+
"algo": "grpo",
|
| 660 |
+
"reward_version": "reward_v1",
|
| 661 |
+
"env_version": "0.1.0",
|
| 662 |
+
},
|
| 663 |
+
)
|
| 664 |
+
record.update(
|
| 665 |
+
{
|
| 666 |
+
"reward_total": reward,
|
| 667 |
+
"success": bool(getattr(env, "success", False)),
|
| 668 |
+
}
|
| 669 |
+
)
|
| 670 |
+
episode_records.append(record)
|
| 671 |
+
|
| 672 |
+
canonical_metrics = aggregate_episode_metrics(episode_records)
|
| 673 |
metrics = {
|
| 674 |
+
**canonical_metrics,
|
| 675 |
+
**train_metric_aliases(canonical_metrics),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 676 |
}
|
| 677 |
+
if rewards:
|
| 678 |
+
metrics["train/reward_mean"] = _mean(rewards)
|
| 679 |
+
metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
|
| 680 |
|
| 681 |
try:
|
| 682 |
+
log_trackio_metrics(metrics, step=trace_step["value"])
|
| 683 |
except Exception as exc:
|
| 684 |
print(f"Trackio metric logging skipped: {exc!r}")
|
| 685 |
|
| 686 |
+
try:
|
| 687 |
+
log_trace_table(
|
| 688 |
+
episode_records[: min(4, len(episode_records))],
|
| 689 |
+
table_name="sample_traces",
|
| 690 |
+
step=trace_step["value"],
|
| 691 |
+
)
|
| 692 |
+
except Exception as exc:
|
| 693 |
+
print(f"Trackio sample trace table logging skipped: {exc!r}")
|
| 694 |
+
|
| 695 |
for index, env in enumerate(environments):
|
| 696 |
messages = list(getattr(env, "trace_messages", []))
|
| 697 |
if index < len(completions):
|
|
|
|
| 734 |
return rewards
|
| 735 |
|
| 736 |
class TrackioSystemMetricsCallback(TrainerCallback):
|
| 737 |
+
def on_train_begin(self, args, state, control, **kwargs):
|
| 738 |
+
try:
|
| 739 |
+
metrics = log_gpu_metrics(step=int(state.global_step or 0))
|
| 740 |
+
except Exception as exc:
|
| 741 |
+
print(f"Trackio GPU metrics initialization skipped: {exc!r}")
|
| 742 |
+
return control
|
| 743 |
+
if metrics:
|
| 744 |
+
system_summary = ", ".join(
|
| 745 |
+
f"{key}={value}"
|
| 746 |
+
for key, value in sorted(metrics.items())
|
| 747 |
+
if key.startswith("system/")
|
| 748 |
+
)
|
| 749 |
+
print(f"Trackio GPU metrics initialized: {system_summary}")
|
| 750 |
+
return control
|
| 751 |
+
|
| 752 |
def on_log(self, args, state, control, logs=None, **kwargs):
|
| 753 |
try:
|
| 754 |
+
metrics = log_gpu_metrics(step=int(state.global_step or 0))
|
| 755 |
except Exception as exc:
|
| 756 |
print(f"Trackio GPU metrics skipped: {exc!r}")
|
| 757 |
return control
|
|
|
|
| 760 |
print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}")
|
| 761 |
return control
|
| 762 |
|
| 763 |
+
def on_train_end(self, args, state, control, **kwargs):
|
| 764 |
+
try:
|
| 765 |
+
log_gpu_metrics(step=int(state.global_step or 0))
|
| 766 |
+
except Exception as exc:
|
| 767 |
+
print(f"Trackio final GPU metrics skipped: {exc!r}")
|
| 768 |
+
return control
|
| 769 |
+
|
| 770 |
print(f"CUDA available: {torch.cuda.is_available()}")
|
| 771 |
if source_mode == "public":
|
| 772 |
print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")
|
|
|
|
| 776 |
print(f"Trackio Project: {trackio_project}")
|
| 777 |
print(f"Output repo: {output_repo_id}")
|
| 778 |
print(f"Run name: {run_name}")
|
| 779 |
+
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
| 780 |
+
print(f"HF_HOME: {cache_env['HF_HOME']}")
|
| 781 |
+
print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
|
| 782 |
+
print(f"Torch cache: {cache_env['TORCH_HOME']}")
|
| 783 |
+
print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
|
| 784 |
+
print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
|
| 785 |
+
print(f"Hub push enabled: {push_to_hub}")
|
| 786 |
+
|
| 787 |
+
trackio.init(
|
| 788 |
+
project=trackio_project,
|
| 789 |
+
name=run_name,
|
| 790 |
+
group="grpo",
|
| 791 |
+
space_id=trackio_space_id,
|
| 792 |
+
auto_log_gpu=True,
|
| 793 |
+
gpu_log_interval=10.0,
|
| 794 |
+
config={
|
| 795 |
+
"environment": "CyberSecurity_OWASP",
|
| 796 |
+
"run_type": "modal_grpo",
|
| 797 |
+
"model_name": model_name,
|
| 798 |
+
"difficulty": difficulty,
|
| 799 |
+
"split": split,
|
| 800 |
+
"dataset_size": dataset_size,
|
| 801 |
+
"max_steps": max_steps,
|
| 802 |
+
"num_generations": num_generations,
|
| 803 |
+
"max_seq_length": max_seq_length,
|
| 804 |
+
"max_completion_length": max_completion_length,
|
| 805 |
+
"lora_rank": lora_rank,
|
| 806 |
+
"gpu_requested": "L4",
|
| 807 |
+
"load_in_4bit": False,
|
| 808 |
+
"fast_inference": False,
|
| 809 |
+
"gradient_checkpointing": "unsloth",
|
| 810 |
+
"optim": "adamw_8bit",
|
| 811 |
+
},
|
| 812 |
+
)
|
| 813 |
+
log_gpu_metrics(step=0)
|
| 814 |
+
|
| 815 |
+
expected_model_cache = _hf_model_cache_path(model_name)
|
| 816 |
+
cache_hit = expected_model_cache.exists()
|
| 817 |
+
print(f"Expected HF model cache path: {expected_model_cache}")
|
| 818 |
+
print(f"Model cache hit before load: {cache_hit}")
|
| 819 |
+
if cache_hit:
|
| 820 |
+
print("Using cached model snapshot from the persistent Modal volume when valid.")
|
| 821 |
+
else:
|
| 822 |
+
print(
|
| 823 |
+
"Model cache miss. Downloading model weights once into the persistent "
|
| 824 |
+
"Modal cache volume; Hugging Face progress output should follow."
|
| 825 |
+
)
|
| 826 |
+
try:
|
| 827 |
+
snapshot_path = snapshot_download(
|
| 828 |
+
repo_id=model_name,
|
| 829 |
+
cache_dir=str(HF_HUB_CACHE_DIR),
|
| 830 |
+
token=hf_token,
|
| 831 |
+
)
|
| 832 |
+
print(f"Model snapshot ready: {snapshot_path}")
|
| 833 |
+
cache_volume.commit()
|
| 834 |
+
print(f"Committed Modal model cache volume after snapshot download: {CACHE_VOLUME_NAME}")
|
| 835 |
+
except Exception as exc:
|
| 836 |
+
print(
|
| 837 |
+
"Explicit model snapshot prefetch failed; Unsloth will attempt the "
|
| 838 |
+
f"model load directly. Error: {exc!r}"
|
| 839 |
+
)
|
| 840 |
|
| 841 |
+
print(f"Loading model with Unsloth from_pretrained: {model_name}")
|
| 842 |
+
model_api = FastVisionModel if "gemma-4" in model_name.lower() else FastLanguageModel
|
| 843 |
+
model, tokenizer = model_api.from_pretrained(
|
| 844 |
model_name=model_name,
|
| 845 |
max_seq_length=max_seq_length,
|
| 846 |
load_in_4bit=False,
|
| 847 |
fast_inference=False,
|
| 848 |
+
cache_dir=str(HF_HUB_CACHE_DIR),
|
| 849 |
token=hf_token,
|
| 850 |
)
|
| 851 |
+
print("Model load complete.")
|
| 852 |
+
cache_volume.commit()
|
| 853 |
+
print(f"Committed Modal model cache volume after model load: {CACHE_VOLUME_NAME}")
|
| 854 |
try:
|
| 855 |
tokenizer = add_response_schema(tokenizer)
|
| 856 |
except Exception as exc:
|
| 857 |
+
if "gemma-4" in model_name.lower():
|
| 858 |
+
print(
|
| 859 |
+
"Tokenizer response schema add skipped for Gemma 4 processor, "
|
| 860 |
+
"matching the Unsloth Gemma 4 GRPO notebook pattern: "
|
| 861 |
+
f"{exc!r}"
|
| 862 |
+
)
|
| 863 |
+
else:
|
| 864 |
+
print(f"Tokenizer response schema add failed before cloning: {exc!r}")
|
| 865 |
+
for template_source in ("Qwen/Qwen3-0.6B", "Qwen/Qwen2.5-0.5B-Instruct"):
|
| 866 |
+
try:
|
| 867 |
+
model, tokenizer, added_tokens = clone_chat_template(
|
| 868 |
+
model,
|
| 869 |
+
tokenizer,
|
| 870 |
+
template_source,
|
| 871 |
+
)
|
| 872 |
+
print(
|
| 873 |
+
"Cloned response-schema-capable chat template "
|
| 874 |
+
f"from {template_source}; added {len(added_tokens)} tokens."
|
| 875 |
+
)
|
| 876 |
+
tokenizer = add_response_schema(tokenizer)
|
| 877 |
+
break
|
| 878 |
+
except Exception as clone_exc:
|
| 879 |
+
print(
|
| 880 |
+
"Tokenizer response schema fallback failed for "
|
| 881 |
+
f"{template_source}: {clone_exc!r}"
|
| 882 |
+
)
|
| 883 |
+
else:
|
| 884 |
+
raise
|
| 885 |
|
| 886 |
+
model = model_api.get_peft_model(
|
| 887 |
model,
|
| 888 |
r=lora_rank,
|
| 889 |
target_modules=[
|
|
|
|
| 899 |
use_gradient_checkpointing="unsloth",
|
| 900 |
random_state=3407,
|
| 901 |
)
|
| 902 |
+
if hasattr(model_api, "for_training"):
|
| 903 |
+
model_api.for_training(model)
|
| 904 |
+
print("LoRA adapter attached and model switched to training mode.")
|
| 905 |
|
| 906 |
grpo_config_values = {
|
| 907 |
"temperature": 1.0,
|
|
|
|
| 922 |
"trackio_space_id": trackio_space_id,
|
| 923 |
"run_name": run_name,
|
| 924 |
"output_dir": str(output_dir),
|
| 925 |
+
"push_to_hub": push_to_hub,
|
| 926 |
"hub_model_id": output_repo_id,
|
| 927 |
"hub_private_repo": True,
|
| 928 |
"hub_strategy": "every_save",
|
|
|
|
| 932 |
"epsilon_high": 0.28,
|
| 933 |
"delta": 1.5,
|
| 934 |
"loss_type": "bnpo",
|
| 935 |
+
"mask_truncated_completions": True,
|
| 936 |
}
|
| 937 |
grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
|
| 938 |
skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
|
|
|
|
| 966 |
if key in trainer_parameters
|
| 967 |
}
|
| 968 |
)
|
| 969 |
+
print("Starting GRPO trainer.train().")
|
| 970 |
trainer.train()
|
| 971 |
+
print("GRPO trainer.train() complete.")
|
| 972 |
+
if push_to_hub:
|
| 973 |
+
print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
|
| 974 |
+
trainer.push_to_hub()
|
| 975 |
+
print("Hub push complete.")
|
| 976 |
+
else:
|
| 977 |
+
print("Skipping Hub push for this run. Pass --push-to-hub to upload adapters.")
|
| 978 |
volume.commit()
|
| 979 |
+
cache_volume.commit()
|
| 980 |
+
print(f"Committed run volume: {VOLUME_NAME}")
|
| 981 |
+
print(f"Committed model cache volume: {CACHE_VOLUME_NAME}")
|
| 982 |
+
try:
|
| 983 |
+
trackio.finish()
|
| 984 |
+
except RuntimeError as exc:
|
| 985 |
+
print(f"Trackio finish skipped because the trainer already finalized it: {exc}")
|
| 986 |
|
| 987 |
return {
|
| 988 |
"run_name": run_name,
|
|
|
|
| 1000 |
"source_mode": source_mode,
|
| 1001 |
"repo_url": repo_url,
|
| 1002 |
"repo_branch": repo_branch,
|
| 1003 |
+
"push_to_hub": push_to_hub,
|
| 1004 |
}
|
| 1005 |
|
| 1006 |
|
|
|
|
| 1013 |
dataset_size: int = 16,
|
| 1014 |
difficulty: int = 0,
|
| 1015 |
split: str = "train",
|
| 1016 |
+
model_name: str = DEFAULT_GEMMA_MODEL,
|
| 1017 |
max_seq_length: int = 4096,
|
| 1018 |
max_completion_length: int = 768,
|
| 1019 |
lora_rank: int = 32,
|
| 1020 |
+
trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
|
| 1021 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 1022 |
num_generations: int = 2,
|
| 1023 |
seed_start: int = 0,
|
|
|
|
| 1026 |
repo_url: str = PUBLIC_REPO_URL,
|
| 1027 |
repo_branch: str = PUBLIC_REPO_BRANCH,
|
| 1028 |
detach: bool = False,
|
| 1029 |
+
push_to_hub: bool = False,
|
| 1030 |
) -> None:
|
| 1031 |
if mode == "config":
|
| 1032 |
result = check_training_imports.remote()
|
|
|
|
| 1035 |
if mode != "train":
|
| 1036 |
raise ValueError("mode must be 'train' or 'config'")
|
| 1037 |
|
| 1038 |
+
trackio_space_id = trackio_space_id or os.environ.get(
|
| 1039 |
+
"TRACKIO_SPACE_ID",
|
| 1040 |
+
"Humanlearning/CyberSecurity_OWASP-trackio",
|
| 1041 |
+
)
|
| 1042 |
trackio_project = trackio_project or os.environ.get(
|
| 1043 |
"TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo"
|
| 1044 |
)
|
|
|
|
| 1051 |
from huggingface_hub import whoami
|
| 1052 |
|
| 1053 |
user = whoami(token=hf_token)["name"]
|
| 1054 |
+
if not resolved_trackio_space_id:
|
| 1055 |
+
resolved_trackio_space_id = (
|
| 1056 |
+
f"{user}/CyberSecurity_OWASP-trackio"
|
| 1057 |
+
if user == "humandotlearning"
|
| 1058 |
+
else "Humanlearning/CyberSecurity_OWASP-trackio"
|
| 1059 |
+
)
|
| 1060 |
resolved_output_repo_id = (
|
| 1061 |
resolved_output_repo_id
|
| 1062 |
+
or f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
|
| 1063 |
)
|
| 1064 |
except Exception as exc:
|
| 1065 |
print(f"Could not resolve Hugging Face defaults locally: {exc!r}")
|
|
|
|
| 1095 |
else:
|
| 1096 |
print(
|
| 1097 |
"Output model repo: derived remotely from HF_TOKEN as "
|
| 1098 |
+
f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
|
| 1099 |
)
|
| 1100 |
+
print(f"Hub push enabled: {push_to_hub}")
|
| 1101 |
+
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
| 1102 |
|
| 1103 |
kwargs = dict(
|
| 1104 |
env_repo_id=env_repo_id,
|
|
|
|
| 1120 |
source_mode=source_mode,
|
| 1121 |
repo_url=repo_url,
|
| 1122 |
repo_branch=repo_branch,
|
| 1123 |
+
push_to_hub=push_to_hub,
|
| 1124 |
)
|
| 1125 |
if detach:
|
| 1126 |
call = train_cybersecurity_owasp_grpo.spawn(**kwargs)
|
tests/test_trackio_utils.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
|
| 3 |
+
from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
|
| 4 |
+
from training.trackio_utils import (
|
| 5 |
+
CANONICAL_TRACKIO_SIGNALS,
|
| 6 |
+
DERIVED_TRACKIO_METRICS,
|
| 7 |
+
aggregate_episode_metrics,
|
| 8 |
+
episode_record_from_state,
|
| 9 |
+
episode_to_trace_row,
|
| 10 |
+
episode_to_tracking_fields,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from .helpers import apply_secure_patch, make_env, secure_invoice_source, submit_valid_finding
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def test_canonical_tracking_fields_exist_and_are_numeric_where_expected():
|
| 17 |
+
assert len(CANONICAL_TRACKIO_SIGNALS) == 57
|
| 18 |
+
|
| 19 |
+
env = make_env(70)
|
| 20 |
+
try:
|
| 21 |
+
submit_valid_finding(env)
|
| 22 |
+
apply_secure_patch(env)
|
| 23 |
+
env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
|
| 24 |
+
final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
|
| 25 |
+
|
| 26 |
+
record = episode_record_from_state(env.state, final_observation=final.model_dump())
|
| 27 |
+
fields = episode_to_tracking_fields(record)
|
| 28 |
+
|
| 29 |
+
for key in (*CANONICAL_TRACKIO_SIGNALS, *DERIVED_TRACKIO_METRICS):
|
| 30 |
+
assert key in fields
|
| 31 |
+
|
| 32 |
+
assert isinstance(fields["reward/hidden_authz_pass_rate"], float)
|
| 33 |
+
assert isinstance(fields["reward/normal_flow_pass_rate"], float)
|
| 34 |
+
assert isinstance(fields["reward/public_hidden_gap"], float)
|
| 35 |
+
assert isinstance(fields["skill/exploit_to_patch_alignment"], float)
|
| 36 |
+
|
+        metrics = aggregate_episode_metrics([record])
+        assert metrics["reward/hidden_authz_pass_rate"] == 1.0
+        assert metrics["skill/patch_success"] == 1.0
+    finally:
+        env.close()
+
+
+def test_tracking_cheat_score_flags_hardcoded_identifier():
+    env = make_env(71)
+    try:
+        submit_valid_finding(env)
+        hidden = env.state.hidden_facts
+        source = secure_invoice_source(env) + f"\n# hardcoded {hidden['owner_invoice_id']}\n"
+        env.step(
+            CyberSecurityOWASPAction(
+                tool_name="patch_file",
+                arguments={"path": "app/routes/invoices.py", "content": source},
+            )
+        )
+        final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
+
+        record = episode_record_from_state(env.state, final_observation=final.model_dump())
+        fields = episode_to_tracking_fields(record)
+
+        assert fields["cheat/hardcoded_scenario_id"] == 1.0
+        assert fields["cheat/score"] >= 50.0
+    finally:
+        env.close()
+
+
+def test_trace_rows_redact_hidden_values_from_action_arguments():
+    env = make_env(72)
+    try:
+        hidden = dict(env.state.hidden_facts)
+        submit_valid_finding(env)
+        apply_secure_patch(env)
+        env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
+        final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
+
+        record = episode_record_from_state(env.state, final_observation=final.model_dump())
+        row = episode_to_trace_row(record)
+        row_text = json.dumps(row, sort_keys=True)
+
+        for key in (
+            "owner_user_id",
+            "intruder_user_id",
+            "admin_user_id",
+            "owner_invoice_id",
+            "other_invoice_id",
+            "foreign_invoice_id",
+            "tenant_a",
+            "tenant_b",
+        ):
+            value = str(hidden.get(key, ""))
+            assert not value or value not in row_text
+    finally:
+        env.close()
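The redaction test above asserts that no hidden identifier survives into the serialized trace row. The same check pattern can be exercised standalone; this is a minimal sketch with toy data, where `redact_ids` is a hypothetical helper standing in for the repo's redaction logic:

```python
import hashlib
import json


def redact_ids(row: dict, hidden_values: set[str]) -> dict:
    """Replace any hidden identifier with a short stable hash (sketch)."""

    def scrub(value):
        if isinstance(value, str) and value in hidden_values:
            return hashlib.sha256(value.encode("utf-8")).hexdigest()[:16]
        if isinstance(value, dict):
            return {key: scrub(item) for key, item in value.items()}
        if isinstance(value, list):
            return [scrub(item) for item in value]
        return value

    return scrub(row)


hidden = {"inv_123", "user_999"}
row = {"args": {"invoice": "inv_123"}, "note": "ok", "ids": ["user_999"]}
clean = redact_ids(row, hidden)

# The assertion mirrors the test: serialize, then verify no raw value leaks.
text = json.dumps(clean, sort_keys=True)
assert all(value not in text for value in hidden)
```

Hashing instead of deleting keeps rows joinable across episodes without exposing the secret values themselves.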
training/__init__.py
ADDED

@@ -0,0 +1 @@
+"""Training and tracking utilities for CyberSecurity_OWASP."""
training/configs/grpo_small.yaml
CHANGED

@@ -1,9 +1,11 @@
-model_name:
+model_name: unsloth/gemma-4-E2B-it
 algo: grpo
 environment: CyberSecurity_OWASP
 max_steps: 40
+episodes: 10
 num_generations: 2
 per_device_train_batch_size: 1
 gradient_accumulation_steps: 32
 learning_rate: 0.000005
 report_to: trackio
+trackio_space_id: Humanlearning/CyberSecurity_OWASP-trackio
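Because this config is a flat key/value file, it can be read even without a YAML dependency. A sketch parser, assuming only flat `key: value` lines (no nesting, quoting, or lists; `load_flat_yaml` is an illustrative name, not a repo function):

```python
def load_flat_yaml(text: str) -> dict:
    """Parse flat `key: value` lines; ints and floats are coerced (sketch)."""
    config: dict = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#") or ":" not in line:
            continue
        key, _, raw = line.partition(":")
        raw = raw.strip()
        try:
            value = int(raw)
        except ValueError:
            try:
                value = float(raw)
            except ValueError:
                value = raw  # keep strings like model names as-is
        config[key.strip()] = value
    return config


sample = """\
algo: grpo
max_steps: 40
episodes: 10
learning_rate: 0.000005
report_to: trackio
"""
cfg = load_flat_yaml(sample)
assert cfg["max_steps"] == 40 and cfg["learning_rate"] == 5e-06
```

For the real training entry point, a full YAML loader is still the safer choice once nested keys appear.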
training/rollout.py
CHANGED

@@ -38,8 +38,15 @@ def generate_rollout_completions(trainer, prompts: list[str]) -> list[dict[str, Any]]:
     ]


-def rollout_once(
-
+def rollout_once(
+    trainer,
+    env,
+    tokenizer=None,
+    dataset_prompt: str = "",
+    max_steps: int = 40,
+    reset_kwargs: dict[str, Any] | None = None,
+) -> dict:
+    result = env.reset(**(reset_kwargs or {}))
     observation = result.observation if hasattr(result, "observation") else result

     prompt_ids = []
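The `hasattr(result, "observation")` line above normalizes `env.reset(...)` results that may arrive either as a raw observation or wrapped in a result object. A self-contained sketch of that unwrap pattern, using dummy env classes (all names here are illustrative, not the OpenEnv API):

```python
from types import SimpleNamespace


def unwrap_observation(result):
    """Accept either a raw observation or a result object exposing .observation."""
    return result.observation if hasattr(result, "observation") else result


class WrappedEnv:
    def reset(self, **kwargs):
        return SimpleNamespace(observation={"step": 0, "kwargs": kwargs})


class BareEnv:
    def reset(self, **kwargs):
        return {"step": 0, "kwargs": kwargs}


reset_kwargs = {"seed": 7}
for env in (WrappedEnv(), BareEnv()):
    # Same `reset_kwargs or {}` guard as rollout_once: None means "no overrides".
    obs = unwrap_observation(env.reset(**(reset_kwargs or {})))
    assert obs["kwargs"] == {"seed": 7}
```

Keeping the unwrap in one helper means the rest of the rollout loop only ever sees plain observation dicts.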
training/trackio_utils.py
CHANGED

Old side (context and deletions):

@@ -2,12 +2,167 @@

 from __future__ import annotations

 import os
 import subprocess
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
-from typing import Any, Iterator


 TRAIN_METRICS = [

@@ -59,6 +214,657 @@ EVAL_METRICS = [
 ]


 def build_run_name(model: str, algo: str, difficulty: int, git_sha: str = "nogit") -> str:
     stamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
     model_slug = model.replace("/", "-")

@@ -98,6 +904,8 @@ def init_trackio_run(
     project: str | None = None,
     space_id: str | None = None,
     group: str | None = None,
 ):
     trackio = _load_trackio()
     project = project or os.getenv("TRACKIO_PROJECT", "CyberSecurity_OWASP")

@@ -116,6 +924,10 @@
         kwargs["space_id"] = space_id
     if group:
         kwargs["group"] = group
     return trackio.init(**kwargs)

@@ -132,6 +944,57 @@ def log_trackio_metrics(metrics: dict[str, Any], step: int | None = None) -> None:
     trackio.log(numeric, step=step)


 def finish_trackio_run() -> None:
     trackio = _load_trackio()
     trackio.finish()

@@ -146,6 +1009,8 @@ def trackio_run(
     project: str | None = None,
     space_id: str | None = None,
     group: str | None = None,
 ) -> Iterator[Any]:
     run = init_trackio_run(
         run_name=run_name,

@@ -154,6 +1019,8 @@ def trackio_run(
         project=project,
         space_id=space_id,
         group=group,
     )
     try:
         yield run

@@ -167,5 +1034,6 @@ def log_eval_summary(run_name: str, summary: dict[str, Any], config: dict[str, Any]) -> None:
         for key, value in summary.items()
         if isinstance(value, (int, float, bool))
     }
     with trackio_run(run_name=run_name, run_type="eval", config=config, group="eval"):
         log_trackio_metrics(metrics, step=0)
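The `trackio_run` context manager above exists so that a crashed eval or training run still calls `finish()` and uploads its final state. The try/finally shape can be demonstrated standalone with a stub tracker (the stub's methods are illustrative, not the trackio API):

```python
from contextlib import contextmanager

events = []


class StubTracker:
    def init(self, **kwargs):
        events.append(("init", kwargs.get("name")))
        return self

    def log(self, metrics, step=None):
        events.append(("log", step))

    def finish(self):
        events.append(("finish", None))


@contextmanager
def tracked_run(tracker, name):
    run = tracker.init(name=name)
    try:
        yield run
    finally:
        tracker.finish()  # always runs, even when the body raises


tracker = StubTracker()
try:
    with tracked_run(tracker, "smoke"):
        tracker.log({"reward/total": 1.0}, step=0)
        raise RuntimeError("simulated crash")
except RuntimeError:
    pass

# finish() was still recorded despite the exception.
assert events[-1] == ("finish", None)
```

This is the property the repo relies on for its hackathon evidence trail: partial runs still leave a finalized record.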
New side (context and additions):

 from __future__ import annotations

+import hashlib
+import json
 import os
+import re
 import subprocess
 from contextlib import contextmanager
 from datetime import datetime
 from pathlib import Path
+from typing import Any, Iterator, Mapping, Sequence
+
+
+RUN_SCENARIO_FIELDS = (
+    "run/base_model",
+    "run/algo",
+    "run/reward_version",
+    "run/env_version",
+    "scenario/seed",
+    "scenario/template_id",
+    "scenario/split",
+    "scenario/difficulty",
+    "scenario/bug_type",
+)
+
+REWARD_DECOMPOSITION_FIELDS = (
+    "reward/total",
+    "reward/exploit_reproduced_pre_patch",
+    "reward/bug_classification_correct",
+    "reward/patch_blocks_submitted_exploit",
+    "reward/hidden_authz_pass_rate",
+    "reward/public_tests_pass_rate",
+    "reward/normal_flow_pass_rate",
+    "reward/minimality_score",
+    "reward/cheat_penalty",
+)
+
+BEHAVIOR_SKILL_FIELDS = (
+    "skill/valid_action_rate",
+    "skill/discovery_success",
+    "skill/patch_success",
+    "skill/first_valid_exploit_step",
+    "skill/exploit_to_patch_alignment",
+    "skill/policy_graph_coverage",
+    "skill/route_coverage",
+    "skill/auth_boundary_cases_tested",
+    "skill/files_read_count",
+    "skill/files_modified_count",
+    "skill/security_relevant_edit_ratio",
+    "skill/tests_run_count",
+)
+
+ANTI_CHEAT_FIELDS = (
+    "cheat/hidden_file_read_attempt",
+    "cheat/reward_file_read_attempt",
+    "cheat/test_file_modified",
+    "cheat/reward_file_modified",
+    "cheat/env_var_secret_probe",
+    "cheat/hardcoded_scenario_id",
+    "cheat/catch_all_deny_patch",
+    "cheat/route_deleted_or_stubbed",
+    "cheat/auth_disabled_or_bypassed",
+    "cheat/database_reset_or_mutation_abuse",
+    "cheat/timeout_or_resource_abuse",
+    "cheat/observation_leak_detected",
+)
+
+GENERALIZATION_EVAL_FIELDS = (
+    "eval/train_success_rate",
+    "eval/seen_template_success_rate",
+    "eval/unseen_template_success_rate",
+    "eval/unseen_framework_success_rate",
+    "eval/unseen_policy_graph_success_rate",
+    "eval/hidden_success_rate",
+    "eval/train_hidden_gap",
+)
+
+TRAINING_SYSTEM_FIELDS = (
+    "train/loss",
+    "train/kl",
+    "train/entropy",
+    "train/grad_norm",
+    "train/reward_mean",
+    "train/reward_std",
+    "train/completion_length_mean",
+    "system/episodes_per_sec",
+)
+
+GPU_SYSTEM_METRICS = (
+    "system/gpu_available",
+    "system/gpu_count",
+    "system/gpu_current_device",
+    "system/gpu_memory_allocated_mb",
+    "system/gpu_memory_reserved_mb",
+    "system/gpu_memory_max_allocated_mb",
+    "system/gpu_memory_total_mb",
+    "system/gpu_memory_allocated_fraction",
+)
+
+CANONICAL_TRACKIO_SIGNAL_GROUPS = {
+    "run_scenario": RUN_SCENARIO_FIELDS,
+    "reward": REWARD_DECOMPOSITION_FIELDS,
+    "skill": BEHAVIOR_SKILL_FIELDS,
+    "anti_cheat": ANTI_CHEAT_FIELDS,
+    "eval": GENERALIZATION_EVAL_FIELDS,
+    "training_system": TRAINING_SYSTEM_FIELDS,
+}
+
+CANONICAL_TRACKIO_SIGNALS = tuple(
+    field
+    for group in CANONICAL_TRACKIO_SIGNAL_GROUPS.values()
+    for field in group
+)
+
+DERIVED_TRACKIO_METRICS = (
+    "reward/public_hidden_gap",
+    "cheat/score",
+)
+
+REQUIRED_SMOKE_TRACKIO_ITEMS = (
+    "reward/total",
+    "reward/hidden_authz_pass_rate",
+    "skill/exploit_to_patch_alignment",
+    "cheat/score",
+    "sample_traces",
+)
+
+TRACE_TABLE_COLUMNS = (
+    "episode_id",
+    "scenario_id_hash",
+    "split",
+    "difficulty",
+    "bug_type",
+    "visible_observation_summary",
+    "action_sequence",
+    "tool_calls",
+    "files_read",
+    "files_modified",
+    "exploit_summary",
+    "patch_diff_summary",
+    "public_test_summary",
+    "hidden_test_summary_redacted",
+    "reward_breakdown",
+    "cheat_flags",
+    "terminal_reason",
+)
+
+SENSITIVE_TEXT_PATTERNS = (
+    re.compile(r"hf_[A-Za-z0-9_]+"),
+    re.compile(r"(?i)(secret|token|password|api[_-]?key)\s*[:=]\s*[^,\s}]+"),
+)
+
+AUTH_RELEVANT_TERMS = (
+    "auth",
+    "tenant",
+    "owner",
+    "role",
+    "permission",
+    "billing_admin",
+    "forbidden",
+    "policy",
+    "principal",
+)


 TRAIN_METRICS = [
 …
 ]


+def _float(value: Any, default: float = 0.0) -> float:
+    if isinstance(value, bool):
+        return 1.0 if value else 0.0
+    try:
+        return float(value)
+    except (TypeError, ValueError):
+        return default
+
+
+def _mean(values: Sequence[float]) -> float:
+    return sum(values) / len(values) if values else 0.0
+
+
+def _stable_hash(value: Any, length: int = 16) -> str:
+    text = json.dumps(value, sort_keys=True, default=str)
+    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:length]
+
+
+def _redact_text(value: Any, limit: int = 800) -> str:
+    text = str(value)
+    for pattern in SENSITIVE_TEXT_PATTERNS:
+        text = pattern.sub("[redacted]", text)
+    return text[:limit]
+
+
+def _as_dict(value: Any) -> dict[str, Any]:
+    if value is None:
+        return {}
+    if isinstance(value, dict):
+        return value
+    if hasattr(value, "model_dump"):
+        return value.model_dump()
+    return dict(getattr(value, "__dict__", {}) or {})
+
+
+def _as_action_list(record: Mapping[str, Any]) -> list[dict[str, Any]]:
+    actions = record.get("action_history") or record.get("actions") or []
+    return [_as_dict(item) for item in actions]
+
+
+def _as_observation_list(record: Mapping[str, Any]) -> list[dict[str, Any]]:
+    observations = record.get("observation_history") or record.get("observations") or []
+    return [_as_dict(item) for item in observations]
+
+
+def _safe_action(action: Mapping[str, Any]) -> dict[str, Any]:
+    tool_name = str(action.get("tool_name", ""))
+    args = _as_dict(action.get("arguments"))
+    safe_args: dict[str, Any] = {}
+    if tool_name in {"read_file", "patch_file"} and args.get("path"):
+        safe_args["path"] = _redact_text(args["path"], limit=160)
+    elif tool_name == "search_code":
+        query = str(args.get("query", ""))
+        safe_args["query_hash"] = _stable_hash(query)
+        safe_args["query_length"] = len(query)
+    elif tool_name in {"send_local_request", "compare_identities"}:
+        safe_args["method"] = args.get("method", "GET")
+        safe_args["path"] = _redact_text(args.get("path", ""), limit=160)
+        if args.get("user_id"):
+            safe_args["user_id_hash"] = _stable_hash(args["user_id"])
+        if args.get("first_user_id"):
+            safe_args["first_user_id_hash"] = _stable_hash(args["first_user_id"])
+        if args.get("second_user_id"):
+            safe_args["second_user_id_hash"] = _stable_hash(args["second_user_id"])
+    elif tool_name == "submit_finding":
+        safe_args["summary_length"] = len(str(args.get("summary", "")))
+        safe_args["evidence_length"] = len(str(args.get("evidence", "")))
+        safe_args["policy_rule_length"] = len(str(args.get("policy_rule", "")))
+    elif tool_name == "patch_file":
+        safe_args["content_hash"] = _stable_hash(args.get("content", ""))
+        safe_args["diff_hash"] = _stable_hash(args.get("diff", ""))
+    return {"tool_name": tool_name, "arguments": safe_args}
+
+
+def _check_pass_rate(result: Any) -> float:
+    result_dict = _as_dict(result)
+    checks = result_dict.get("checks")
+    if isinstance(checks, dict) and checks:
+        return _mean([1.0 if bool(value) else 0.0 for value in checks.values()])
+    if "passed" in result_dict:
+        return 1.0 if bool(result_dict.get("passed")) else 0.0
+    return 0.0
+
+
+def _check_summary(result: Any) -> dict[str, Any]:
+    result_dict = _as_dict(result)
+    checks = result_dict.get("checks")
+    return {
+        "passed": bool(result_dict.get("passed", False)),
+        "pass_rate": _check_pass_rate(result_dict),
+        "num_checks": len(checks) if isinstance(checks, dict) else 0,
+    }
+
+
+def _reward_history(record: Mapping[str, Any]) -> list[dict[str, float]]:
+    history = record.get("reward_history") or record.get("reward_breakdown_by_step") or []
+    if not history:
+        observations = _as_observation_list(record)
+        history = [
+            obs.get("reward_breakdown", {})
+            for obs in observations
+            if isinstance(obs.get("reward_breakdown"), dict)
+        ]
+    return [
+        {str(key): _float(value) for key, value in _as_dict(item).items()}
+        for item in history
+    ]
+
+
+def _final_reward_breakdown(record: Mapping[str, Any]) -> dict[str, float]:
+    for key in ("final_reward_breakdown", "reward_breakdown"):
+        if isinstance(record.get(key), dict):
+            return {str(k): _float(v) for k, v in record[key].items()}
+    history = _reward_history(record)
+    return dict(history[-1]) if history else {}
+
+
+def _reward_component_sum(record: Mapping[str, Any], key: str) -> float:
+    return sum(item.get(key, 0.0) for item in _reward_history(record))
+
+
+def _verification(record: Mapping[str, Any]) -> dict[str, Any]:
+    return _as_dict(record.get("verification_summary") or record.get("verifier") or {})
+
+
+def _tool_names(actions: Sequence[Mapping[str, Any]]) -> list[str]:
+    return [str(action.get("tool_name", "")) for action in actions]
+
+
+def _first_tool_step(
+    actions: Sequence[Mapping[str, Any]],
+    tools: set[str],
+    observations: Sequence[Mapping[str, Any]] | None = None,
+) -> float:
+    for index, action in enumerate(actions, start=1):
+        if str(action.get("tool_name", "")) not in tools:
+            continue
+        if observations and index - 1 < len(observations):
+            if observations[index - 1].get("last_action_valid") is False:
+                continue
+        return float(index)
+    return -1.0
+
+
+def _has_tool_before(actions: Sequence[Mapping[str, Any]], tools: set[str], before_tool: str) -> bool:
+    for action in actions:
+        tool_name = str(action.get("tool_name", ""))
+        if tool_name == before_tool:
+            return False
+        if tool_name in tools:
+            return True
+    return False
+
+
+def _patch_diff(record: Mapping[str, Any]) -> str:
+    return str(record.get("patch_diff") or "")
+
+
+def _diff_lines(diff: str) -> list[str]:
+    return [
+        line
+        for line in diff.splitlines()
+        if (line.startswith("+") or line.startswith("-"))
+        and not line.startswith("+++")
+        and not line.startswith("---")
+    ]
+
+
+def _security_relevant_edit_ratio(diff: str) -> float:
+    lines = _diff_lines(diff)
+    if not lines:
+        return 0.0
+    relevant = [
+        line
+        for line in lines
+        if any(term in line.lower() for term in AUTH_RELEVANT_TERMS)
+    ]
+    return len(relevant) / len(lines)
+
+
+def _quality_score(quality: Mapping[str, Any]) -> float:
+    quality = _as_dict(quality)
+    if quality.get("passed"):
+        return 1.0
+    if quality.get("secure_pattern") and not quality.get("hardcoded") and not quality.get("deny_all"):
+        return 0.5
+    return 0.0
+
+
+def _files_from_actions(actions: Sequence[Mapping[str, Any]], tool_name: str) -> list[str]:
+    files: list[str] = []
+    for action in actions:
+        if action.get("tool_name") != tool_name:
+            continue
+        path = _as_dict(action.get("arguments")).get("path")
+        if path:
+            files.append(str(path))
+    return files
+
+
+def _files_modified(record: Mapping[str, Any], actions: Sequence[Mapping[str, Any]]) -> list[str]:
+    metrics = _as_dict(record.get("metrics"))
+    touched = metrics.get("files_touched")
+    if isinstance(touched, list):
+        return [str(item) for item in touched]
+    return _files_from_actions(actions, "patch_file")
+
+
+def _path_or_query_text(actions: Sequence[Mapping[str, Any]]) -> str:
+    parts: list[str] = []
+    for action in actions:
+        args = _as_dict(action.get("arguments"))
+        for key in ("path", "query", "content", "diff"):
+            if key in args:
+                parts.append(str(args[key]))
+    return " ".join(parts).lower()
+
+
+def _route_deleted_or_stubbed(diff: str) -> bool:
+    lowered = diff.lower()
+    return any(term in lowered for term in ("def health", "route not found", "pass\n", "not implemented"))
+
+
+def _auth_disabled_or_bypassed(diff: str) -> bool:
+    lowered = diff.lower()
+    return any(term in lowered for term in ("auth disabled", "bypass", "skip_auth", "disable_auth"))
+
+
+def _observation_leak_detected(observations: Sequence[Mapping[str, Any]]) -> bool:
+    leak_markers = ("hidden_facts", "oracle_hidden_focus", "owner_invoice_id", "foreign_invoice_id")
+    text = json.dumps(observations, default=str).lower()
+    return any(marker.lower() in text for marker in leak_markers)
+
+
+def _cheat_score(fields: Mapping[str, Any]) -> float:
+    return (
+        100 * _float(fields.get("cheat/hidden_file_read_attempt"))
+        + 100 * _float(fields.get("cheat/reward_file_read_attempt"))
+        + 100 * _float(fields.get("cheat/test_file_modified"))
+        + 50 * _float(fields.get("cheat/hardcoded_scenario_id"))
+        + 30 * _float(fields.get("cheat/catch_all_deny_patch"))
+        + 30 * _float(fields.get("cheat/route_deleted_or_stubbed"))
+        + 20 * _float(fields.get("cheat/timeout_or_resource_abuse"))
+    )
+
+
+def _scenario_hash(record: Mapping[str, Any]) -> str:
+    return _stable_hash(
+        {
+            "task_id": record.get("task_id"),
+            "seed": record.get("scenario/seed") or record.get("seed"),
+            "split": record.get("scenario/split") or record.get("split"),
+            "template_id": record.get("scenario/template_id") or record.get("template_id"),
+            "scenario_family": record.get("scenario_family"),
+        }
+    )
+
+
+def episode_record_from_state(
+    state: Any,
+    *,
+    run_context: Mapping[str, Any] | None = None,
+    final_observation: Mapping[str, Any] | None = None,
+) -> dict[str, Any]:
+    """Build a redaction-friendly tracking record from an environment state."""
+
+    context = dict(run_context or {})
+    reward_history = [dict(item) for item in getattr(state, "reward_history", []) or []]
+    final_reward = dict(final_observation.get("reward_breakdown", {})) if final_observation else {}
+    if not final_reward and reward_history:
+        final_reward = dict(reward_history[-1])
+    record = {
+        "run/base_model": context.get("base_model", context.get("run/base_model", "")),
+        "run/algo": context.get("algo", context.get("run/algo", "")),
+        "run/reward_version": context.get("reward_version", "reward_v1"),
+        "run/env_version": context.get("env_version", "0.1.0"),
+        "episode_id": getattr(state, "episode_id", ""),
+        "task_id": getattr(state, "task_id", ""),
+        "scenario/seed": getattr(state, "seed", 0),
+        "scenario/template_id": getattr(state, "template_id", ""),
+        "scenario/split": getattr(state, "split", ""),
+        "scenario/difficulty": getattr(state, "difficulty", 0),
+        "scenario/bug_type": getattr(state, "bug_family", ""),
+        "scenario_family": getattr(state, "scenario_family", ""),
+        "target_weakness": getattr(state, "target_weakness", ""),
+        "difficulty_tier": getattr(state, "difficulty_tier", ""),
+        "domain": getattr(state, "domain", ""),
+        "success": bool(getattr(state, "success", False)),
+        "failure_reason": getattr(state, "failure_reason", None),
+        "finding_submitted": bool(getattr(state, "finding_submitted", False)),
+        "patch_submitted": bool(getattr(state, "patch_submitted", False)),
+        "step_count": int(getattr(state, "step_count", 0) or 0),
+        "max_steps": int(getattr(state, "max_steps", 0) or 0),
+        "done": bool(getattr(state, "done", False)),
+        "anti_cheat_flags": list(getattr(state, "anti_cheat_flags", []) or []),
+        "metrics": dict(getattr(state, "metrics", {}) or {}),
+        "verification_summary": dict(getattr(state, "verification_summary", {}) or {}),
+        "patch_diff": str(getattr(state, "patch_diff", "") or ""),
+        "reward_history": reward_history,
+        "reward_breakdown": final_reward,
+        "action_history": [
+            _safe_action(action)
+            for action in (getattr(state, "action_history", []) or [])
+        ],
+        "observation_history": [
+            _as_dict(obs)
+            for obs in (getattr(state, "observation_history", []) or [])
+        ],
+    }
+    record["scenario_id_hash"] = _scenario_hash(record)
+    return record
+
+
+def episode_to_tracking_fields(episode: Any) -> dict[str, Any]:
+    """Return the 57 canonical fields plus derived diagnostics for one episode."""
+
+    if not isinstance(episode, Mapping):
+        episode = episode_record_from_state(episode)
+    record = dict(episode)
+    actions = _as_action_list(record)
+    observations = _as_observation_list(record)
+    tool_names = _tool_names(actions)
+    verifier = _verification(record)
+    final_reward = _final_reward_breakdown(record)
+    quality = _as_dict(verifier.get("patch_quality"))
+    visible_rate = _check_pass_rate(verifier.get("visible"))
+    hidden_rate = _check_pass_rate(verifier.get("security") or verifier.get("hidden_tests"))
+    normal_rate = _check_pass_rate(verifier.get("regression"))
+    public_rate = _check_pass_rate(verifier.get("public_routes"))
+    files_read = _files_from_actions(actions, "read_file")
+    files_modified = _files_modified(record, actions)
+    patch_diff = _patch_diff(record)
+    path_text = _path_or_query_text(actions)
+    anti_flags = set(record.get("anti_cheat_flags") or [])
+    invalid_count = sum(1 for obs in observations if obs.get("last_action_valid") is False)
+    action_count = max(1, len(actions))
+    valid_action_rate = 1.0 - (invalid_count / action_count)
+    tests_run_count = tool_names.count("run_visible_tests")
+    auth_cases = sum(1 for item in tool_names if item in {"send_local_request", "compare_identities"})
+
+    fields: dict[str, Any] = {key: 0.0 for key in CANONICAL_TRACKIO_SIGNALS}
+    for key in RUN_SCENARIO_FIELDS:
+        fallback = key.split("/", 1)[1]
+        fields[key] = record.get(key, record.get(fallback, ""))
+
+    fields["scenario/seed"] = _float(fields["scenario/seed"])
+    fields["scenario/difficulty"] = _float(fields["scenario/difficulty"])
+    fields["reward/total"] = _float(record.get("reward_total", final_reward.get("total", 0.0)))
+    fields["reward/exploit_reproduced_pre_patch"] = 1.0 if _has_tool_before(
+        actions,
+        {"send_local_request", "compare_identities"},
+        "patch_file",
+    ) else 0.0
+    fields["reward/bug_classification_correct"] = 1.0 if (
+        record.get("finding_submitted") or _reward_component_sum(record, "discovery") > 0.0
+    ) else 0.0
+    fields["reward/patch_blocks_submitted_exploit"] = hidden_rate
+    fields["reward/hidden_authz_pass_rate"] = hidden_rate
+    fields["reward/public_tests_pass_rate"] = visible_rate
+    fields["reward/normal_flow_pass_rate"] = normal_rate
+    fields["reward/minimality_score"] = _quality_score(quality)
+    fields["reward/cheat_penalty"] = max(0.0, -_float(final_reward.get("anti_cheat"))) + max(
+        0.0,
+        -_float(final_reward.get("safety")),
+    )
+    fields["reward/public_hidden_gap"] = visible_rate - hidden_rate
+
+    fields["skill/valid_action_rate"] = max(0.0, min(1.0, valid_action_rate))
+    fields["skill/discovery_success"] = fields["reward/bug_classification_correct"]
+    fields["skill/patch_success"] = 1.0 if record.get("success") else 0.0
+    fields["skill/first_valid_exploit_step"] = _first_tool_step(
+        actions,
+        {"send_local_request", "compare_identities"},
+        observations,
+    )
+    fields["skill/exploit_to_patch_alignment"] = 1.0 if (
+        fields["skill/patch_success"]
+        or (
+            fields["skill/discovery_success"]
+            and _security_relevant_edit_ratio(patch_diff) > 0.0
+        )
+    ) else 0.0
+    fields["skill/policy_graph_coverage"] = 1.0 if "inspect_policy_graph" in tool_names else 0.0
+    fields["skill/route_coverage"] = 1.0 if (
+        "list_routes" in tool_names or "read_openapi" in tool_names
+    ) else 0.0
+    fields["skill/auth_boundary_cases_tested"] = float(auth_cases)
+    fields["skill/files_read_count"] = float(len(files_read))
+    fields["skill/files_modified_count"] = float(len(files_modified))
+    fields["skill/security_relevant_edit_ratio"] = _security_relevant_edit_ratio(patch_diff)
+    fields["skill/tests_run_count"] = float(tests_run_count)
+
+    fields["cheat/hidden_file_read_attempt"] = 1.0 if (
+        "blocked_file_probe" in anti_flags and "hidden" in path_text
+    ) else 0.0
+    fields["cheat/reward_file_read_attempt"] = 1.0 if (
+        "blocked_file_probe" in anti_flags and ("reward" in path_text or "oracle" in path_text)
+    ) else 0.0
+    fields["cheat/test_file_modified"] = 1.0 if any(
+        action.get("tool_name") == "patch_file"
|
| 617 |
+
and str(_as_dict(action.get("arguments")).get("path", "")).replace("\\", "/").startswith("tests/")
|
| 618 |
+
for action in actions
|
| 619 |
+
) else 0.0
|
| 620 |
+
fields["cheat/reward_file_modified"] = 1.0 if any(
|
| 621 |
+
action.get("tool_name") == "patch_file"
|
| 622 |
+
and any(term in str(_as_dict(action.get("arguments")).get("path", "")).lower() for term in ("reward", "oracle", "validator"))
|
| 623 |
+
for action in actions
|
| 624 |
+
) else 0.0
|
| 625 |
+
fields["cheat/env_var_secret_probe"] = 1.0 if any(
|
| 626 |
+
term in path_text for term in ("env", "secret", "token", "password", "api_key")
|
| 627 |
+
) else 0.0
|
| 628 |
+
fields["cheat/hardcoded_scenario_id"] = 1.0 if (
|
| 629 |
+
"hardcoded_hidden_identifier" in anti_flags or bool(quality.get("hardcoded"))
|
| 630 |
+
) else 0.0
|
| 631 |
+
fields["cheat/catch_all_deny_patch"] = 1.0 if bool(quality.get("deny_all")) else 0.0
|
| 632 |
+
fields["cheat/route_deleted_or_stubbed"] = 1.0 if _route_deleted_or_stubbed(patch_diff) else 0.0
|
| 633 |
+
fields["cheat/auth_disabled_or_bypassed"] = 1.0 if _auth_disabled_or_bypassed(patch_diff) else 0.0
|
| 634 |
+
fields["cheat/database_reset_or_mutation_abuse"] = 1.0 if any(
|
| 635 |
+
term in patch_diff.lower() for term in ("drop table", "delete from", "truncate", "reset_db")
|
| 636 |
+
) else 0.0
|
| 637 |
+
fields["cheat/timeout_or_resource_abuse"] = 1.0 if (
|
| 638 |
+
record.get("failure_reason") == "max_steps_exceeded" or "timeout_or_resource_abuse" in anti_flags
|
| 639 |
+
) else 0.0
|
| 640 |
+
fields["cheat/observation_leak_detected"] = 1.0 if _observation_leak_detected(observations) else 0.0
|
| 641 |
+
fields["cheat/score"] = _cheat_score(fields)
|
| 642 |
+
|
| 643 |
+
# Episode-level tracking does not know cross-run evaluation or trainer internals.
|
| 644 |
+
# Those fields remain present with zero defaults and are filled by eval/trainer logs.
|
| 645 |
+
fields["eval/hidden_success_rate"] = fields["skill/patch_success"] if (
|
| 646 |
+
record.get("scenario/split") == "hidden_eval"
|
| 647 |
+
) else 0.0
|
| 648 |
+
fields["train/reward_mean"] = fields["reward/total"]
|
| 649 |
+
return fields
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
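The invalid-action bookkeeping above reduces to a small ratio. A standalone sketch with hypothetical observation dicts (the real ones come from the OpenEnv episode loop):

```python
# Hypothetical observations; only the "last_action_valid" flag matters here.
observations = [
    {"last_action_valid": True},
    {"last_action_valid": False},
    {"last_action_valid": True},
    {"last_action_valid": True},
]

# Count explicit False flags, then normalize by episode length (min 1 to avoid /0).
invalid_count = sum(1 for obs in observations if obs.get("last_action_valid") is False)
action_count = max(1, len(observations))
valid_action_rate = 1.0 - (invalid_count / action_count)
print(valid_action_rate)  # 0.75
```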
def episode_to_trackio_metrics(episode: Any) -> dict[str, float]:
    """Return numeric Trackio scalar metrics for one episode."""

    fields = episode_to_tracking_fields(episode)
    return {
        key: _float(value)
        for key, value in fields.items()
        if isinstance(value, (int, float, bool))
    }

def aggregate_episode_metrics(episodes: Sequence[Any]) -> dict[str, float]:
    """Aggregate numeric canonical episode metrics as batch means."""

    if not episodes:
        return {"run/episode_count": 0.0}
    per_episode = [episode_to_trackio_metrics(episode) for episode in episodes]
    keys = sorted(set().union(*(item.keys() for item in per_episode)))
    metrics = {
        key: _mean([_float(item.get(key)) for item in per_episode])
        for key in keys
    }
    metrics["run/episode_count"] = float(len(episodes))
    metrics["cheat/episode_rate"] = _mean(
        [1.0 if _float(item.get("cheat/score")) > 0.0 else 0.0 for item in per_episode]
    )
    metrics["train/reward_std"] = (
        sum(
            (item.get("reward/total", 0.0) - metrics.get("reward/total", 0.0)) ** 2
            for item in per_episode
        )
        / max(1, len(per_episode))
    ) ** 0.5
    return metrics

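The batch-mean aggregation pattern can be sketched in isolation with toy metric dicts (hypothetical keys and values, not the repo's real episode records):

```python
def aggregate_means(per_episode: list[dict[str, float]]) -> dict[str, float]:
    """Mean of each metric key across episode dicts; a missing key counts as 0.0."""
    # Union of all keys so episodes with different metric sets still aggregate.
    keys = sorted(set().union(*(item.keys() for item in per_episode)))
    return {
        key: sum(item.get(key, 0.0) for item in per_episode) / len(per_episode)
        for key in keys
    }

batch = [
    {"reward/total": 1.0, "cheat/score": 0.0},
    {"reward/total": 0.0, "cheat/score": 100.0},
]
metrics = aggregate_means(batch)
print(metrics)  # {'cheat/score': 50.0, 'reward/total': 0.5}
```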
def train_metric_aliases(metrics: Mapping[str, Any]) -> dict[str, float]:
    """Map canonical metrics to the repo's existing train/* dashboard names."""

    return {
        "train/reward_total_mean": _float(metrics.get("reward/total")),
        "train/reward_discovery_mean": _float(metrics.get("reward/bug_classification_correct")) * 3.0,
        "train/reward_security_mean": _float(metrics.get("reward/hidden_authz_pass_rate")) * 5.0,
        "train/reward_regression_mean": _float(metrics.get("reward/normal_flow_pass_rate")) * 3.0,
        "train/reward_public_routes_mean": _float(metrics.get("reward/public_tests_pass_rate")),
        "train/reward_patch_quality_mean": _float(metrics.get("reward/minimality_score")) * 2.0,
        "train/reward_visible_tests_mean": _float(metrics.get("reward/public_tests_pass_rate")),
        "train/reward_safety_mean": -_float(metrics.get("reward/cheat_penalty")),
        "train/reward_anti_cheat_mean": -_float(metrics.get("cheat/score")) / 100.0,
        "train/success_rate": _float(metrics.get("skill/patch_success")),
        "train/exploit_block_rate": _float(metrics.get("reward/hidden_authz_pass_rate")),
        "train/regression_preservation_rate": _float(metrics.get("reward/normal_flow_pass_rate")),
        "train/public_route_preservation_rate": _float(metrics.get("reward/public_tests_pass_rate")),
        "train/invalid_action_rate": 1.0 - _float(metrics.get("skill/valid_action_rate")),
        "train/timeout_rate": _float(metrics.get("cheat/timeout_or_resource_abuse")),
        "train/safety_violation_rate": _float(metrics.get("cheat/env_var_secret_probe")),
        "train/reward_hacking_suspected_rate": 1.0 if (
            _float(metrics.get("reward/public_hidden_gap")) > 0.35
            or _float(metrics.get("cheat/score")) >= 100.0
        ) else 0.0,
        "train/episode_length_mean": _float(metrics.get("skill/tests_run_count"))
        + _float(metrics.get("skill/files_read_count"))
        + _float(metrics.get("skill/auth_boundary_cases_tested")),
    }

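The alias table above encodes a simple reward-hacking heuristic: suspect a run when public tests pass much more often than hidden ones, or when the cheat score crosses the hard threshold. A minimal standalone sketch of that rule (thresholds taken from the code above):

```python
def reward_hacking_suspected(metrics: dict[str, float]) -> bool:
    """True when the public/hidden pass-rate gap exceeds 0.35,
    or the accumulated cheat score reaches 100."""
    gap = metrics.get("reward/public_hidden_gap", 0.0)
    cheat = metrics.get("cheat/score", 0.0)
    return gap > 0.35 or cheat >= 100.0

print(reward_hacking_suspected({"reward/public_hidden_gap": 0.5}))   # True
print(reward_hacking_suspected({"reward/public_hidden_gap": 0.1}))   # False
```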
def eval_metric_aliases(summary: Mapping[str, Any]) -> dict[str, float]:
    """Map eval summary fields to the requested generalization metric names."""

    train_success = _float(summary.get("trained_success_rate", summary.get("train_success_rate")))
    hidden_success = _float(summary.get("heldout_success_rate", summary.get("hidden_success_rate")))
    return {
        "eval/train_success_rate": train_success,
        "eval/seen_template_success_rate": _float(summary.get("seen_template_success_rate", train_success)),
        "eval/unseen_template_success_rate": _float(summary.get("unseen_template_success_rate", hidden_success)),
        "eval/unseen_framework_success_rate": _float(summary.get("unseen_framework_success_rate", 0.0)),
        "eval/unseen_policy_graph_success_rate": _float(summary.get("unseen_policy_graph_success_rate", hidden_success)),
        "eval/hidden_success_rate": hidden_success,
        "eval/train_hidden_gap": train_success - hidden_success,
    }

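The train/hidden gap is the headline generalization signal: training-split success minus held-out success. A self-contained sketch with a hypothetical eval summary (real summaries come from `training/eval_before_after.py`):

```python
# Hypothetical summary values for illustration only.
summary = {"trained_success_rate": 0.75, "heldout_success_rate": 0.5}

train_success = float(summary.get("trained_success_rate", 0.0))
hidden_success = float(summary.get("heldout_success_rate", 0.0))
metrics = {
    "eval/train_success_rate": train_success,
    "eval/hidden_success_rate": hidden_success,
    # Large positive gap suggests overfitting to the training templates.
    "eval/train_hidden_gap": train_success - hidden_success,
}
print(metrics["eval/train_hidden_gap"])  # 0.25
```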
def episode_to_trace_row(episode: Any) -> dict[str, Any]:
    """Return one redacted row for the Trackio sample_traces table."""

    if not isinstance(episode, Mapping):
        episode = episode_record_from_state(episode)
    record = dict(episode)
    actions = _as_action_list(record)
    observations = _as_observation_list(record)
    tool_names = _tool_names(actions)
    verifier = _verification(record)
    patch_diff = _patch_diff(record)
    files_read = _files_from_actions(actions, "read_file")
    files_modified = _files_modified(record, actions)
    reward_breakdown = _final_reward_breakdown(record)
    final_obs = observations[-1] if observations else {}
    row = {
        "episode_id": _redact_text(record.get("episode_id", "")),
        "scenario_id_hash": record.get("scenario_id_hash") or _scenario_hash(record),
        "split": record.get("scenario/split") or record.get("split", ""),
        "difficulty": record.get("scenario/difficulty") or record.get("difficulty", 0),
        "bug_type": record.get("scenario/bug_type") or record.get("bug_type", ""),
        "visible_observation_summary": json.dumps(
            {
                "done": bool(record.get("done", final_obs.get("done", False))),
                "success": bool(record.get("success", False)),
                "last_action_valid": final_obs.get("last_action_valid", True),
                "terminal_reason": record.get("failure_reason") or final_obs.get("done_reason"),
            },
            sort_keys=True,
        ),
        "action_sequence": " -> ".join(tool_names),
        "tool_calls": json.dumps({name: tool_names.count(name) for name in sorted(set(tool_names))}, sort_keys=True),
        "files_read": json.dumps(sorted(set(files_read))),
        "files_modified": json.dumps(sorted(set(files_modified))),
        "exploit_summary": json.dumps(
            {
                "local_probe_count": sum(
                    1 for name in tool_names if name in {"send_local_request", "compare_identities"}
                ),
                "first_valid_exploit_step": episode_to_tracking_fields(record)[
                    "skill/first_valid_exploit_step"
                ],
                "finding_submitted": bool(record.get("finding_submitted", False)),
            },
            sort_keys=True,
        ),
        "patch_diff_summary": json.dumps(
            {
                "diff_hash": _stable_hash(patch_diff),
                "changed_lines": len(_diff_lines(patch_diff)),
                "security_relevant_edit_ratio": _security_relevant_edit_ratio(patch_diff),
            },
            sort_keys=True,
        ),
        "public_test_summary": json.dumps(_check_summary(verifier.get("visible")), sort_keys=True),
        "hidden_test_summary_redacted": json.dumps(
            {
                "authz": _check_summary(verifier.get("security") or verifier.get("hidden_tests")),
                "regression": _check_summary(verifier.get("regression")),
                "public_routes": _check_summary(verifier.get("public_routes")),
            },
            sort_keys=True,
        ),
        "reward_breakdown": json.dumps(reward_breakdown, sort_keys=True),
        "cheat_flags": json.dumps(sorted(record.get("anti_cheat_flags") or [])),
        "terminal_reason": record.get("failure_reason") or final_obs.get("done_reason"),
    }
    return {key: _redact_text(row.get(key, "")) for key in TRACE_TABLE_COLUMNS}

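The `action_sequence` and `tool_calls` columns above are built from the episode's tool names with plain string joins and counts. A standalone sketch with a hypothetical tool sequence:

```python
# Hypothetical episode tool sequence for illustration.
tool_names = ["list_routes", "read_file", "read_file", "patch_file"]

# Ordered trace for human review, plus per-tool call counts for the table row.
action_sequence = " -> ".join(tool_names)
tool_calls = {name: tool_names.count(name) for name in sorted(set(tool_names))}
print(action_sequence)  # list_routes -> read_file -> read_file -> patch_file
print(tool_calls)       # {'list_routes': 1, 'patch_file': 1, 'read_file': 2}
```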
def trace_table_rows(episodes: Sequence[Any]) -> list[dict[str, Any]]:
    return [episode_to_trace_row(episode) for episode in episodes]

def log_trace_table(
    episodes: Sequence[Any],
    *,
    table_name: str = "sample_traces",
    step: int | None = None,
) -> None:
    if not episodes:
        return
    trackio = _load_trackio()
    rows = trace_table_rows(episodes)
    table = trackio.Table(
        columns=list(TRACE_TABLE_COLUMNS),
        rows=[[row.get(column, "") for column in TRACE_TABLE_COLUMNS] for row in rows],
        allow_mixed_types=True,
    )
    if step is None:
        trackio.log({table_name: table})
    else:
        trackio.log({table_name: table}, step=step)

def log_episode_batch(
    episodes: Sequence[Any],
    *,
    step: int | None = None,
    table_name: str = "sample_traces",
    include_train_aliases: bool = False,
) -> dict[str, float]:
    metrics = aggregate_episode_metrics(episodes)
    payload = dict(metrics)
    if include_train_aliases:
        payload.update(train_metric_aliases(metrics))
    log_trackio_metrics(payload, step=step)
    log_trace_table(episodes, table_name=table_name, step=step)
    return payload

def missing_required_trackio_items(
    run_or_metrics: Mapping[str, Any],
    required_items: Sequence[str] = REQUIRED_SMOKE_TRACKIO_ITEMS,
) -> list[str]:
    """Return required metrics/table names absent from a Trackio run summary."""

    available: set[str] = set()
    metrics = run_or_metrics.get("metrics")
    if isinstance(metrics, dict):
        available.update(str(key) for key in metrics)
    elif isinstance(metrics, list):
        available.update(str(item) for item in metrics)
    for key in ("tables", "artifacts", "media", "logged_artifacts"):
        value = run_or_metrics.get(key)
        if isinstance(value, dict):
            available.update(str(item) for item in value)
        elif isinstance(value, list):
            available.update(str(item) for item in value)
    if "values" in run_or_metrics and run_or_metrics.get("metric"):
        available.add(str(run_or_metrics["metric"]))
    return [item for item in required_items if item not in available]

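The required-items check reduces to a set-difference over whatever names a run summary exposes. A simplified standalone sketch with a hypothetical run summary (the real helper also scans artifacts, media, and single-metric payloads):

```python
def missing_items(run_summary: dict, required: list[str]) -> list[str]:
    """Names in `required` not present among the run's metrics or tables."""
    available = set(run_summary.get("metrics", {})) | set(run_summary.get("tables", []))
    return [item for item in required if item not in available]

run = {"metrics": {"reward/total": 0.5}, "tables": ["sample_traces"]}
print(missing_items(run, ["reward/total", "sample_traces", "cheat/score"]))
# ['cheat/score']
```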
def build_run_name(model: str, algo: str, difficulty: int, git_sha: str = "nogit") -> str:
    stamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
    model_slug = model.replace("/", "-")
    project: str | None = None,
    space_id: str | None = None,
    group: str | None = None,
    auto_log_gpu: bool | None = None,
    gpu_log_interval: float | None = None,
):
    trackio = _load_trackio()
    project = project or os.getenv("TRACKIO_PROJECT", "CyberSecurity_OWASP")
...
    kwargs["space_id"] = space_id
    if group:
        kwargs["group"] = group
    if auto_log_gpu is not None:
        kwargs["auto_log_gpu"] = auto_log_gpu
    if gpu_log_interval is not None:
        kwargs["gpu_log_interval"] = gpu_log_interval
    return trackio.init(**kwargs)
    trackio.log(numeric, step=step)


def collect_torch_gpu_metrics() -> dict[str, float]:
    """Collect explicit torch CUDA metrics for Trackio scalar dashboards."""

    try:
        import torch
    except Exception:
        return {"system/gpu_available": 0.0, "system/gpu_count": 0.0}

    if not torch.cuda.is_available():
        return {"system/gpu_available": 0.0, "system/gpu_count": 0.0}

    device = torch.cuda.current_device()
    props = torch.cuda.get_device_properties(device)
    allocated = float(torch.cuda.memory_allocated(device)) / (1024 * 1024)
    reserved = float(torch.cuda.memory_reserved(device)) / (1024 * 1024)
    max_allocated = float(torch.cuda.max_memory_allocated(device)) / (1024 * 1024)
    total = float(props.total_memory) / (1024 * 1024)
    return {
        "system/gpu_available": 1.0,
        "system/gpu_count": float(torch.cuda.device_count()),
        "system/gpu_current_device": float(device),
        "system/gpu_memory_allocated_mb": allocated,
        "system/gpu_memory_reserved_mb": reserved,
        "system/gpu_memory_max_allocated_mb": max_allocated,
        "system/gpu_memory_total_mb": total,
        "system/gpu_memory_allocated_fraction": allocated / total if total else 0.0,
    }


def log_gpu_metrics(step: int | None = None) -> dict[str, float]:
    """Log Trackio's native GPU metrics plus explicit torch GPU aliases."""

    trackio = _load_trackio()
    native_metrics: dict[str, Any] = {}
    try:
        native_metrics = trackio.log_gpu() or {}
    except Exception:
        native_metrics = {}
    torch_metrics = collect_torch_gpu_metrics()
    if torch_metrics:
        log_trackio_metrics(torch_metrics, step=step)
    return {
        **{
            str(key): float(value)
            for key, value in native_metrics.items()
            if isinstance(value, (int, float, bool))
        },
        **torch_metrics,
    }


def finish_trackio_run() -> None:
    trackio = _load_trackio()
    trackio.finish()
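The GPU helpers above degrade gracefully: they return zeroed metrics when torch is missing or CUDA is unavailable, so CPU smoke runs still log a consistent schema. A trimmed, self-contained sketch of that fallback pattern (subset of the keys used above):

```python
def gpu_availability() -> dict[str, float]:
    """Zeroed GPU metrics when torch or CUDA is unavailable; basics otherwise."""
    try:
        import torch  # optional dependency; absence must not crash logging
    except Exception:
        return {"system/gpu_available": 0.0, "system/gpu_count": 0.0}
    if not torch.cuda.is_available():
        return {"system/gpu_available": 0.0, "system/gpu_count": 0.0}
    return {
        "system/gpu_available": 1.0,
        "system/gpu_count": float(torch.cuda.device_count()),
    }

print(sorted(gpu_availability()))
```

The same key set is emitted on every path, which keeps Trackio dashboards stable across GPU and CPU environments.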
    project: str | None = None,
    space_id: str | None = None,
    group: str | None = None,
    auto_log_gpu: bool | None = None,
    gpu_log_interval: float | None = None,
) -> Iterator[Any]:
    run = init_trackio_run(
        run_name=run_name,
...
        project=project,
        space_id=space_id,
        group=group,
        auto_log_gpu=auto_log_gpu,
        gpu_log_interval=gpu_log_interval,
    )
    try:
        yield run
        for key, value in summary.items()
        if isinstance(value, (int, float, bool))
    }
    metrics.update(eval_metric_aliases(summary))
    with trackio_run(run_name=run_name, run_type="eval", config=config, group="eval"):
        log_trackio_metrics(metrics, step=0)
training/train_grpo.py
CHANGED

"""Modal-only GRPO config helper for CyberSecurity_OWASP.

This module intentionally does not run local training.
Use `scripts/modal_train_grpo.py` (persistent) or
`scripts/modal_ephemeral_train.py` (smoke) for execution.
"""

from __future__ import annotations
...
from training.trackio_utils import build_run_name, get_git_sha


DEFAULT_GEMMA_MODEL = os.getenv("MODEL_NAME", "unsloth/gemma-4-E2B-it")


def build_grpo_config():
    """Build the TRL GRPOConfig used by the Modal training pipeline."""

    from trl import GRPOConfig

    model_name = os.getenv("MODEL_NAME", DEFAULT_GEMMA_MODEL)
    difficulty = int(os.getenv("DIFFICULTY", "0"))
    output_dir = os.getenv(
        "OUTPUT_DIR",
        f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo",
    )
    trackio_space_id = os.getenv("TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio")
    os.environ.setdefault("TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo")
    run_name = os.getenv(
        "RUN_NAME",
...
    )


def main() -> None:
    import argparse

    parser = argparse.ArgumentParser(
        description=(
            "CyberSecurity_OWASP GRPO config helper."
            " Actual GRPO training is executed on Modal only."
        )
    )
    parser.add_argument(
        "--difficulty",
        type=int,
        default=0,
        help="Optional curriculum difficulty included in the generated run name.",
    )
    parser.add_argument("--model-name", default=DEFAULT_GEMMA_MODEL)
    parser.add_argument(
        "--output-dir",
        default=None,
        help="Optional GRPO output_dir override.",
    )
    args = parser.parse_args()

    os.environ["MODEL_NAME"] = args.model_name
    if args.output_dir:
        os.environ["OUTPUT_DIR"] = args.output_dir

    config = build_grpo_config()
    print("GRPO config (Modal execution):")
    print(config)
    print(
        "Run on Modal, for example:\n"
        "uv run --extra modal modal run scripts/modal_train_grpo.py "
        f"--model-name {args.model_name} --difficulty {args.difficulty}"
    )


if __name__ == "__main__":
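The default `output_dir` naming is derived from the model name with path separators slugified, as in `build_grpo_config` above. A minimal sketch of that derivation, using the module's default model name:

```python
# Default model name taken from the module above; override via MODEL_NAME env var.
model_name = "unsloth/gemma-4-E2B-it"

# Slashes become dashes so the directory name is filesystem- and Hub-safe.
output_dir = f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo"
print(output_dir)  # CyberSecurity_OWASP-unsloth-gemma-4-E2B-it-grpo
```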