Humanlearning committed on
Commit b3ee507 · 1 Parent(s): 28685f3

feat: update training configuration and documentation for Modal execution, including new model integration and enhanced tracking utilities

.agents/skills/cybersecurity-owasp-trainer/SKILL.md CHANGED
@@ -7,7 +7,9 @@ description: Train, debug, evaluate, and document CyberSecurity_OWASP model runs
 
 ## Overview
 
-Use this skill to run or modify the CyberSecurity_OWASP training and evaluation loop without weakening the verifier, reward integrity, or hackathon evidence trail. Treat the environment and reward engine as the product; training only starts after those are stable.
+Use this skill to run or modify the CyberSecurity_OWASP training and evaluation loop without weakening the verifier, reward integrity, or hackathon evidence trail. Training is expected to run on Modal only.
+
+Important: do **not** run GRPO/PPO training loops locally in this repo. Use Modal launchers (`scripts/modal_ephemeral_train.py` for smoke and `scripts/modal_train_grpo.py` for GRPO).
 
 ## References
 
@@ -37,28 +39,28 @@ Prefer the existing repo modules:
 - `training/rollout.py`: full OpenEnv episode loop, action JSON parsing, reward trace, rollout artifact fields.
 - `training/reward_funcs.py`: component reward functions exposed to TRL/GRPO.
-- `training/train_grpo.py`: `GRPOConfig`, model defaults, Trackio reporting, vLLM settings.
+- `training/train_grpo.py`: `GRPOConfig`/model defaults and launch intent (does not run local training).
 - `training/eval_before_after.py`: baseline-vs-trained and held-out summary metrics.
 - `training/trackio_utils.py`: run naming, canonical metric names, Trackio init/log/finalize helpers.
 
 Default environment values:
 
 ```powershell
-$env:MODEL_NAME = "Qwen/Qwen3-1.7B"
+$env:MODEL_NAME = "google/gemma-2-2b-it"
 $env:TRACKIO_SPACE_ID = "Humanlearning/CyberSecurity_OWASP-trackio"
 $env:TRACKIO_PROJECT = "CyberSecurity_OWASP"
 $env:DIFFICULTY = "0"
 ```
 
-Use level-0 debug runs before scaling. Do not increase batch size, prompt count, scenario diversity, or difficulty until sampled artifacts show real discover-then-patch behavior rather than formatting compliance only.
+Use level-0 debug runs before scaling, and verify them through Modal smoke/ephemeral runs.
 
 ## Training Workflow
 
 1. Validate the environment first: run the targeted tests that cover models, reset/step/state, rewards, anti-cheat, seed reproducibility, invalid actions, and rollouts.
-2. Run a tiny smoke path that constructs `GRPOConfig` without starting expensive training.
-3. Run a frozen-model or dummy-policy rollout and inspect the action trace, observations, terminal reason, and reward breakdown.
+2. Run a Modal smoke path for lightweight config/run verification.
+3. Run a frozen-model or dummy-policy rollout on Modal and inspect the action trace, observations, terminal reason, and reward breakdown.
 4. Confirm Trackio receives component metrics and the run name follows `CyberSecurity_OWASP-<model>-<algo>-level<difficulty>-<YYYYMMDD-HHMM>-<git_sha>`.
-5. Start a very small GRPO run only after the above passes. Watch completions and rollout artifacts during the run, not just aggregate reward.
+5. Start a very small GRPO run only after the above passes. Start via `scripts/modal_train_grpo.py --mode train`.
 6. Evaluate baseline, trained, and held-out splits with `training/eval_before_after.py` and save summaries under `outputs/evals/`.
 7. Save sampled rollouts under `outputs/rollouts/` for baseline, mid-training, trained, and held-out evidence.
 
@@ -77,7 +79,7 @@ Stop or roll back if reward rises while sampled traces show deny-all patches, ha
 - Use TRL GRPO for verifier-driven rewards. Keep multiple independent reward functions for logging and diagnosis.
 - Keep the existing custom rollout path unless deliberately migrating to TRL's `environment_factory`. If migrating, preserve typed actions, observations, reward component logging, anti-cheat flags, and rollout artifacts.
-- Use vLLM colocate for small local runs when memory allows; use server mode only when a separate inference GPU/server is available.
+- Use Modal as the default training path; local-only vLLM/GRPO execution is intentionally avoided in this repository.
 - For OpenEnv server training concurrency, ensure the server supports enough concurrent sessions for the generation batch.
 - Use Unsloth with LoRA or QLoRA for memory efficiency when the training machine supports it. Start from an instruct-capable checkpoint and verify the model has non-zero success probability before RL.
 - Pin and smoke-test TRL, Unsloth, vLLM, CUDA, and torch versions before longer runs.
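The run-name convention in workflow step 4 can be sketched in Python. `build_run_name` is a hypothetical helper for illustration; the repo's actual naming logic lives in `training/trackio_utils.py`:

```python
from datetime import datetime, timezone


def build_run_name(model: str, algo: str, difficulty: int, git_sha: str) -> str:
    # Hypothetical helper mirroring the documented pattern:
    # CyberSecurity_OWASP-<model>-<algo>-level<difficulty>-<YYYYMMDD-HHMM>-<git_sha>
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M")
    short_model = model.split("/")[-1]  # drop the HF org prefix
    return f"CyberSecurity_OWASP-{short_model}-{algo}-level{difficulty}-{stamp}-{git_sha}"


print(build_run_name("google/gemma-2-2b-it", "grpo", 0, "b3ee507"))
```

A name built this way sorts chronologically within a model/algo/difficulty combination, which keeps Trackio run lists readable.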
01_ARCHITECTURE.md CHANGED
@@ -397,16 +397,18 @@ Editable source: `assets/env_rl_training_flow_diagram.mmd`
 9. Produce final demo: before/after trace + reward curve + held-out eval table.
 ```
 
-Recommended initial training setup:
+Recommended initial training setup (Modal-first):
 
 ```text
-Model: Qwen/Qwen3-1.7B or similar small instruct model
+Model: google/gemma-2-2b-it (or compatible Gemma-class instruct model)
 Algorithm: GRPO via TRL or Unsloth-compatible loop
 Dataset prompt: repeated task instruction with randomized scenario IDs
 Max steps per episode: 30
 Rollouts per prompt: 2-4
 Logging: Trackio
 Primary eval: held-out deterministic test pass rate
+
+Training execution is expected to run on Modal (persistent or ephemeral) rather than locally.
 ```
 
 ## 9. Deployment architecture
README.md CHANGED
@@ -149,6 +149,10 @@ Training files are under `training/`:
 
 The training scaffold is intentionally minimal until the environment/verifier behavior is stable. Trackio metric names and GRPO defaults follow the project brief.
 
+`training/train_grpo.py` in this repo is a config helper only; it does not execute training locally.
+Use the Modal launchers in `scripts/modal_train_grpo.py` (persistent) and
+`scripts/modal_ephemeral_train.py` (smoke) for real GRPO runs.
+
 ## Trackio Run Tracking
 
 Trackio is the default tracker for official runs. Set `TRACKIO_SPACE_ID` to log to a hosted Hugging Face Trackio Space; otherwise Trackio records locally.
@@ -239,7 +243,7 @@ Defaults are derived from `HF_TOKEN`:
 
 - Trackio Space: `<hf-user>/CyberSecurity_OWASP-trackio`
 - Trackio project: `CyberSecurity_OWASP-grpo`
-- Output repo: `<hf-user>/CyberSecurity_OWASP-qwen3-1.7b-grpo-lora`
+- Output repo: `<hf-user>/CyberSecurity_OWASP-gemma-2-2b-grpo-lora`
 
 Override these with `--trackio-space-id`, `--trackio-project`, and
 `--output-repo-id` when needed.
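The defaults listed above are all derived from the `HF_TOKEN` owner; a small sketch of that derivation (the function name is hypothetical, the values come from the README):

```python
def default_ids(hf_user: str) -> dict[str, str]:
    # Defaults derived from the HF_TOKEN owner's username, as listed in the README.
    return {
        "trackio_space_id": f"{hf_user}/CyberSecurity_OWASP-trackio",
        "trackio_project": "CyberSecurity_OWASP-grpo",
        "output_repo_id": f"{hf_user}/CyberSecurity_OWASP-gemma-2-2b-grpo-lora",
    }
```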
pyproject.toml CHANGED
@@ -45,7 +45,7 @@ server = "CyberSecurity_OWASP.server.app:main"
 
 [tool.setuptools]
 include-package-data = true
-packages = ["CyberSecurity_OWASP", "CyberSecurity_OWASP.server"]
+packages = ["CyberSecurity_OWASP", "CyberSecurity_OWASP.server", "training"]
 package-dir = { "CyberSecurity_OWASP" = ".", "CyberSecurity_OWASP.server" = "server" }
 
 [tool.pytest.ini_options]
scripts/modal_ephemeral_train.py CHANGED
@@ -12,6 +12,8 @@ the local process, so the run disappears when ``modal run`` exits.
 from __future__ import annotations
 
 import json
+import subprocess
+import time
 from datetime import datetime
 from pathlib import Path
 from typing import Any
@@ -20,6 +22,7 @@ import modal
 
 
 APP_NAME = "CyberSecurity_OWASP-ephemeral-training"
+SECRET_NAME = "CyberSecurity_OWASP-secrets"
 REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
 PROJECT_ROOT = Path(__file__).resolve().parents[1]
 
@@ -63,7 +66,11 @@ class NoopTrainer:
 ]
 
 
-@app.function(image=image, timeout=60 * 30)
+@app.function(
+    image=image,
+    timeout=60 * 30,
+    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
+)
 def run_ephemeral_smoke(
     episodes: int = 4,
     seed_start: int = 0,
@@ -75,17 +82,45 @@ def run_ephemeral_smoke(
         CybersecurityOwaspEnvironment,
     )
     from training.rollout import rollout_once
-    from training.trackio_utils import log_trackio_metrics, trackio_run
+    from training.trackio_utils import (
+        aggregate_episode_metrics,
+        episode_record_from_state,
+        log_episode_batch,
+        log_trackio_metrics,
+        trace_table_rows,
+        trackio_run,
+    )
 
     baseline = []
     oracle = []
+    run_context = {
+        "algo": "modal_ephemeral_smoke",
+        "reward_version": "reward_v1",
+        "env_version": "0.1.0",
+    }
 
     for offset in range(episodes):
        seed = seed_start + offset
 
        baseline_env = CybersecurityOwaspEnvironment()
-        baseline_env.reset(seed=seed, split="validation")
-        baseline.append(rollout_once(NoopTrainer(), baseline_env, max_steps=5))
+        baseline_rollout = rollout_once(
+            NoopTrainer(),
+            baseline_env,
+            max_steps=5,
+            reset_kwargs={"seed": seed, "split": "validation", "difficulty": 0},
+        )
+        baseline_record = episode_record_from_state(
+            baseline_env.state,
+            run_context={**run_context, "base_model": "noop"},
+        )
+        baseline_record.update(
+            {
+                "reward_total": baseline_rollout.get("reward_total", 0.0),
+                "success": baseline_rollout.get("success", False),
+                "episode_length": baseline_rollout.get("episode_length", 0),
+            }
+        )
+        baseline.append(baseline_record)
 
        oracle_env = CybersecurityOwaspEnvironment()
        oracle_env.reset(seed=seed, split="validation")
@@ -124,19 +159,25 @@ def run_ephemeral_smoke(
        )
        oracle_env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
        final = oracle_env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
-        oracle.append(
+        oracle_record = episode_record_from_state(
+            oracle_env.state,
+            run_context={**run_context, "base_model": "oracle"},
+            final_observation=final.model_dump(),
+        )
+        oracle_record.update(
            {
-                "seed": seed,
-                "success": oracle_env.state.success,
                "reward_total": final.reward_breakdown.get("total", 0.0),
-                "reward_breakdown": final.reward_breakdown,
+                "success": oracle_env.state.success,
            }
        )
+        oracle.append(oracle_record)
 
    def mean(items: list[dict[str, Any]], key: str) -> float:
        return sum(float(item.get(key, 0.0)) for item in items) / max(1, len(items))
 
    run_name = f"{APP_NAME}-{datetime.utcnow().strftime('%Y%m%d-%H%M%S')}"
+    episode_records = [*baseline, *oracle]
+    tracking_metrics = aggregate_episode_metrics(episode_records)
    result = {
        "run_name": run_name,
        "mode": "smoke",
@@ -145,6 +186,8 @@ def run_ephemeral_smoke(
        "baseline_mean_reward": mean(baseline, "reward_total"),
        "oracle_mean_reward": mean(oracle, "reward_total"),
        "oracle_success_rate": mean(oracle, "success"),
+        "tracking_metrics": tracking_metrics,
+        "tracking_trace_rows": trace_table_rows(episode_records),
        "baseline": baseline,
        "oracle": oracle,
    }
@@ -160,8 +203,10 @@ def run_ephemeral_smoke(
        },
        group="smoke",
    ):
+        logged_metrics = log_episode_batch(episode_records, step=0)
        log_trackio_metrics(
            {
+                **logged_metrics,
                "smoke/baseline_mean_reward": result["baseline_mean_reward"],
                "smoke/oracle_mean_reward": result["oracle_mean_reward"],
                "smoke/oracle_success_rate": result["oracle_success_rate"],
@@ -179,6 +224,130 @@ def run_grpo_config_check() -> str:
    return str(build_grpo_config())
 
 
+@app.function(
+    image=image,
+    timeout=60 * 10,
+    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
+)
+def verify_trackio_run(
+    run_name: str,
+    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
+    trackio_project: str = "CyberSecurity_OWASP-smoke",
+) -> dict[str, Any]:
+    import os
+    from training.trackio_utils import (
+        REQUIRED_SMOKE_TRACKIO_ITEMS,
+        missing_required_trackio_items,
+    )
+
+    hf_token = os.environ["HF_TOKEN"]
+    cmd = [
+        "trackio",
+        "get",
+        "run",
+        "--project",
+        trackio_project,
+        "--run",
+        run_name,
+        "--space",
+        trackio_space_id,
+        "--hf-token",
+        hf_token,
+        "--json",
+    ]
+    metrics_cmd = [
+        "trackio",
+        "list",
+        "metrics",
+        "--project",
+        trackio_project,
+        "--run",
+        run_name,
+        "--space",
+        trackio_space_id,
+        "--hf-token",
+        hf_token,
+        "--json",
+    ]
+    last_result: dict[str, Any] = {}
+    for attempt in range(1, 4):
+        completed = subprocess.run(cmd, capture_output=True, text=True)
+        metrics_completed = subprocess.run(metrics_cmd, capture_output=True, text=True)
+        last_result = {
+            "attempt": attempt,
+            "returncode": completed.returncode,
+            "stdout": completed.stdout[-4000:],
+            "stderr": completed.stderr[-4000:],
+            "metrics_returncode": metrics_completed.returncode,
+            "metrics_stdout": metrics_completed.stdout[-4000:],
+            "metrics_stderr": metrics_completed.stderr[-4000:],
+        }
+        if completed.returncode == 0:
+            data = json.loads(completed.stdout)
+            if metrics_completed.returncode == 0:
+                metrics_data = json.loads(metrics_completed.stdout)
+                if isinstance(metrics_data.get("metrics"), list):
+                    data["metrics"] = metrics_data["metrics"]
+            missing = missing_required_trackio_items(data, REQUIRED_SMOKE_TRACKIO_ITEMS)
+            return {
+                "ok": not missing,
+                "trackio_space_id": trackio_space_id,
+                "trackio_project": trackio_project,
+                "run_name": run_name,
+                "required_items": list(REQUIRED_SMOKE_TRACKIO_ITEMS),
+                "missing_required_items": missing,
+                "run": data,
+            }
+        time.sleep(10)
+    return {
+        "ok": False,
+        "trackio_space_id": trackio_space_id,
+        "trackio_project": trackio_project,
+        "run_name": run_name,
+        "last_result": last_result,
+    }
+
+
+@app.function(
+    image=image,
+    timeout=60 * 10,
+    secrets=[modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])],
+)
+def inspect_trackio_space(
+    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
+) -> dict[str, Any]:
+    import os
+
+    hf_token = os.environ["HF_TOKEN"]
+
+    def run_trackio(args: list[str]) -> dict[str, Any]:
+        completed = subprocess.run(
+            ["trackio", *args, "--space", trackio_space_id, "--hf-token", hf_token, "--json"],
+            capture_output=True,
+            text=True,
+        )
+        result = {
+            "returncode": completed.returncode,
+            "stdout": completed.stdout[-8000:],
+            "stderr": completed.stderr[-4000:],
+        }
+        if completed.returncode == 0:
+            result["json"] = json.loads(completed.stdout)
+        return result
+
+    projects_result = run_trackio(["list", "projects"])
+    projects = (projects_result.get("json") or {}).get("projects", [])
+    runs_by_project = {
+        project: run_trackio(["list", "runs", "--project", project])
+        for project in projects
+    }
+    return {
+        "trackio_space_id": trackio_space_id,
+        "projects": projects_result,
+        "runs_by_project": runs_by_project,
+    }
+
+
 @app.local_entrypoint()
 def main(
     mode: str = "smoke",
@@ -186,6 +355,7 @@ def main(
     seed_start: int = 0,
     trackio_space_id: str = "",
     trackio_project: str = "CyberSecurity_OWASP-smoke",
+    run_name: str = "",
 ) -> None:
     if mode == "smoke":
         result = run_ephemeral_smoke.remote(
@@ -201,5 +371,23 @@ def main(
         print(json.dumps({"saved": str(output_path), **result}, indent=2, sort_keys=True))
     elif mode == "grpo-config":
         print(run_grpo_config_check.remote())
+    elif mode == "verify-trackio":
+        if not run_name:
+            raise ValueError("--run-name is required for verify-trackio mode")
+        result = verify_trackio_run.remote(
+            run_name=run_name,
+            trackio_space_id=trackio_space_id
+            or "Humanlearning/CyberSecurity_OWASP-trackio",
+            trackio_project=trackio_project,
+        )
+        print(json.dumps(result, indent=2, sort_keys=True))
+    elif mode == "inspect-trackio":
+        result = inspect_trackio_space.remote(
+            trackio_space_id=trackio_space_id
+            or "Humanlearning/CyberSecurity_OWASP-trackio",
+        )
+        print(json.dumps(result, indent=2, sort_keys=True))
     else:
-        raise ValueError("mode must be 'smoke' or 'grpo-config'")
+        raise ValueError(
+            "mode must be 'smoke', 'grpo-config', 'verify-trackio', or 'inspect-trackio'"
+        )
scripts/modal_train_grpo.py CHANGED
@@ -28,12 +28,61 @@ import modal
 
 APP_NAME = "CyberSecurity_OWASP-grpo"
 VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
 SECRET_NAME = "CyberSecurity_OWASP-secrets"
 RUNS_DIR = pathlib.Path("/runs")
 REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
 PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
 PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
 PUBLIC_REPO_BRANCH = "master"
 
 
 def _load_local_env_file() -> None:
@@ -114,6 +163,7 @@ def _training_image() -> modal.Image:
         "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
         "unsloth[base] @ git+https://github.com/unslothai/unsloth",
     )
     .uv_pip_install("pydantic==2.10.6")
     .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps")
     .uv_pip_install("llm-blender", "weave")
@@ -159,22 +209,25 @@ def _training_image() -> modal.Image:
 
 app = modal.App(APP_NAME)
 volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
 secrets = _modal_secrets()
 
 
 @app.function(
     image=_training_image(),
-    gpu=["L4", "A10G"],
     timeout=4 * 60 * 60,
-    volumes={RUNS_DIR: volume},
     secrets=secrets,
 )
 def check_training_imports() -> dict[str, str]:
     import torch
     import trackio
     from datasets import Dataset
     from trl import GRPOConfig, GRPOTrainer
-    from unsloth import FastLanguageModel
 
     from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
         CybersecurityOwaspEnvironment,
@@ -189,16 +242,19 @@ def check_training_imports() -> dict[str, str]:
         "grpo_config": GRPOConfig.__name__,
         "grpo_trainer": GRPOTrainer.__name__,
         "unsloth_model": FastLanguageModel.__name__,
         "env": CybersecurityOwaspEnvironment.__name__,
         "reset_phase": obs.phase,
     }
 
 
 @app.function(
     image=_training_image(),
-    gpu=["L4", "A10G"],
     timeout=4 * 60 * 60,
-    volumes={RUNS_DIR: volume},
     secrets=secrets,
 )
 def train_cybersecurity_owasp_grpo(
@@ -208,11 +264,11 @@ def train_cybersecurity_owasp_grpo(
     dataset_size: int = 16,
     difficulty: int = 0,
     split: str = "train",
-    model_name: str = "Qwen/Qwen3-1.7B",
     max_seq_length: int = 4096,
     max_completion_length: int = 768,
     lora_rank: int = 32,
-    trackio_space_id: str = "",
     trackio_project: str = "CyberSecurity_OWASP-grpo",
     num_generations: int = 2,
     seed_start: int = 0,
@@ -221,15 +277,18 @@ def train_cybersecurity_owasp_grpo(
     source_mode: str = "local",
     repo_url: str = PUBLIC_REPO_URL,
     repo_branch: str = PUBLIC_REPO_BRANCH,
 ) -> dict[str, str | int | float]:
     import inspect
     import statistics
 
     import torch
-    from unsloth import FastLanguageModel
     import transformers.utils.hub as transformers_hub
     from datasets import Dataset
-    from huggingface_hub import whoami
     from transformers import TrainerCallback
     from trl import GRPOConfig, GRPOTrainer, clone_chat_template
     from trl.chat_template_utils import add_response_schema
@@ -240,14 +299,16 @@ def train_cybersecurity_owasp_grpo(
     from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
         CybersecurityOwaspEnvironment,
     )
 
-    if not hasattr(transformers_hub, "TRANSFORMERS_CACHE"):
-        transformers_hub.TRANSFORMERS_CACHE = os.path.join(
-            os.path.expanduser("~"),
-            ".cache",
-            "huggingface",
-            "hub",
-        )
 
     hf_token = os.environ.get("HF_TOKEN")
     if not hf_token:
@@ -257,8 +318,20 @@ def train_cybersecurity_owasp_grpo(
 
     user = whoami(token=hf_token)["name"]
     env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
-    output_repo_id = output_repo_id or f"{user}/CyberSecurity_OWASP-qwen3-1.7b-grpo-lora"
-    trackio_space_id = trackio_space_id or f"{user}/CyberSecurity_OWASP-trackio"
 
     os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
     os.environ["TRACKIO_PROJECT"] = trackio_project
@@ -271,6 +344,13 @@ def train_cybersecurity_owasp_grpo(
     output_dir = RUNS_DIR / run_name
     output_dir.mkdir(parents=True, exist_ok=True)
 
     training_prompt = (
         "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
         "OpenEnv environment. Use only the provided local tools. Do not target real "
@@ -570,49 +650,48 @@ def train_cybersecurity_owasp_grpo(
         completions = kwargs.get("completions") or kwargs.get("completion") or []
         trace_step["value"] += 1
 
-        breakdowns = [getattr(env, "reward_breakdown", {}) or {} for env in environments]
         metrics = {
-            "train/reward_total_mean": _mean(rewards),
-            "train/reward_discovery_mean": _mean(
-                [float(item.get("discovery", 0.0)) for item in breakdowns]
-            ),
-            "train/reward_security_mean": _mean(
-                [float(item.get("security", 0.0)) for item in breakdowns]
-            ),
-            "train/reward_regression_mean": _mean(
-                [float(item.get("regression", 0.0)) for item in breakdowns]
-            ),
-            "train/reward_public_routes_mean": _mean(
-                [float(item.get("public_routes", 0.0)) for item in breakdowns]
-            ),
-            "train/reward_patch_quality_mean": _mean(
-                [float(item.get("patch_quality", 0.0)) for item in breakdowns]
-            ),
-            "train/reward_visible_tests_mean": _mean(
-                [float(item.get("visible_tests", 0.0)) for item in breakdowns]
-            ),
-            "train/reward_anti_cheat_mean": _mean(
-                [float(item.get("anti_cheat", 0.0)) for item in breakdowns]
-            ),
-            "train/success_rate": _mean(
-                [1.0 if bool(getattr(env, "success", False)) else 0.0 for env in environments]
-            ),
-            "train/invalid_action_rate": _mean(
-                [float(getattr(env, "invalid_actions", 0)) for env in environments]
-            ),
-            "train/episode_length_mean": _mean(
-                [
-                    float(getattr(env, "trace_metadata", {}).get("step_count", 0))
-                    for env in environments
-                ]
-            ),
         }
 
         try:
-            trackio.log(metrics, step=trace_step["value"])
         except Exception as exc:
             print(f"Trackio metric logging skipped: {exc!r}")
 
         for index, env in enumerate(environments):
             messages = list(getattr(env, "trace_messages", []))
             if index < len(completions):
@@ -655,9 +734,24 @@ def train_cybersecurity_owasp_grpo(
         return rewards
 
     class TrackioSystemMetricsCallback(TrainerCallback):
         def on_log(self, args, state, control, logs=None, **kwargs):
             try:
-                metrics = trackio.log_gpu()
             except Exception as exc:
                 print(f"Trackio GPU metrics skipped: {exc!r}")
                 return control
@@ -666,6 +760,13 @@ def train_cybersecurity_owasp_grpo(
             print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}")
             return control
 
     print(f"CUDA available: {torch.cuda.is_available()}")
     if source_mode == "public":
         print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")
@@ -675,27 +776,114 @@ def train_cybersecurity_owasp_grpo(
     print(f"Trackio Project: {trackio_project}")
     print(f"Output repo: {output_repo_id}")
     print(f"Run name: {run_name}")
 
-    model, tokenizer = FastLanguageModel.from_pretrained(
         model_name=model_name,
         max_seq_length=max_seq_length,
         load_in_4bit=False,
         fast_inference=False,
         token=hf_token,
     )
     try:
         tokenizer = add_response_schema(tokenizer)
     except Exception as exc:
-        print(f"Tokenizer response schema add failed before cloning: {exc!r}")
-        model, tokenizer, added_tokens = clone_chat_template(
-            model,
-            tokenizer,
-            "Qwen/Qwen3-0.6B",
-        )
-        print(f"Cloned Qwen3 chat template; added {len(added_tokens)} tokens.")
-        tokenizer = add_response_schema(tokenizer)
 
-    model = FastLanguageModel.get_peft_model(
         model,
         r=lora_rank,
         target_modules=[
@@ -711,7 +899,9 @@ def train_cybersecurity_owasp_grpo(
         use_gradient_checkpointing="unsloth",
         random_state=3407,
     )
-    FastLanguageModel.for_training(model)
 
     grpo_config_values = {
         "temperature": 1.0,
@@ -732,7 +922,7 @@ def train_cybersecurity_owasp_grpo(
         "trackio_space_id": trackio_space_id,
         "run_name": run_name,
         "output_dir": str(output_dir),
-        "push_to_hub": True,
         "hub_model_id": output_repo_id,
         "hub_private_repo": True,
         "hub_strategy": "every_save",
@@ -742,7 +932,7 @@ def train_cybersecurity_owasp_grpo(
         "epsilon_high": 0.28,
         "delta": 1.5,
         "loss_type": "bnpo",
-        "mask_truncated_completions": False,
     }
     grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
     skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
@@ -776,9 +966,23 @@ def train_cybersecurity_owasp_grpo(
             if key in trainer_parameters
         }
     )
     trainer.train()
-    trainer.push_to_hub()
     volume.commit()
 
     return {
         "run_name": run_name,
@@ -796,6 +1000,7 @@ def train_cybersecurity_owasp_grpo(
         "source_mode": source_mode,
         "repo_url": repo_url,
         "repo_branch": repo_branch,
     }
 
 
@@ -808,11 +1013,11 @@ def main(
     dataset_size: int = 16,
     difficulty: int = 0,
     split: str = "train",
-    model_name: str = "Qwen/Qwen3-1.7B",
     max_seq_length: int = 4096,
     max_completion_length: int = 768,
     lora_rank: int = 32,
-    trackio_space_id: str = "",
     trackio_project: str = "CyberSecurity_OWASP-grpo",
     num_generations: int = 2,
     seed_start: int = 0,
@@ -821,6 +1026,7 @@ def main(
     repo_url: str = PUBLIC_REPO_URL,
     repo_branch: str = PUBLIC_REPO_BRANCH,
     detach: bool = False,
 ) -> None:
     if mode == "config":
         result = check_training_imports.remote()
@@ -829,7 +1035,10 @@ def main(
     if mode != "train":
         raise ValueError("mode must be 'train' or 'config'")
 
-    trackio_space_id = trackio_space_id or os.environ.get("TRACKIO_SPACE_ID", "")
     trackio_project = trackio_project or os.environ.get(
         "TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo"
     )
@@ -842,12 +1051,15 @@ def main(
     from huggingface_hub import whoami
843
 
844
  user = whoami(token=hf_token)["name"]
845
- resolved_trackio_space_id = (
846
- resolved_trackio_space_id or f"{user}/CyberSecurity_OWASP-trackio"
847
- )
 
 
 
848
  resolved_output_repo_id = (
849
  resolved_output_repo_id
850
- or f"{user}/CyberSecurity_OWASP-qwen3-1.7b-grpo-lora"
851
  )
852
  except Exception as exc:
853
  print(f"Could not resolve Hugging Face defaults locally: {exc!r}")
@@ -883,8 +1095,10 @@ def main(
883
  else:
884
  print(
885
  "Output model repo: derived remotely from HF_TOKEN as "
886
- "<hf-user>/CyberSecurity_OWASP-qwen3-1.7b-grpo-lora"
887
  )
 
 
888
 
889
  kwargs = dict(
890
  env_repo_id=env_repo_id,
@@ -906,6 +1120,7 @@ def main(
906
  source_mode=source_mode,
907
  repo_url=repo_url,
908
  repo_branch=repo_branch,
 
909
  )
910
  if detach:
911
  call = train_cybersecurity_owasp_grpo.spawn(**kwargs)
 
  APP_NAME = "CyberSecurity_OWASP-grpo"
  VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
+ CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
  SECRET_NAME = "CyberSecurity_OWASP-secrets"
  RUNS_DIR = pathlib.Path("/runs")
+ CACHE_DIR = pathlib.Path("/cache")
+ HF_HOME_DIR = CACHE_DIR / "huggingface"
+ HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
+ TORCH_HOME_DIR = CACHE_DIR / "torch"
+ XDG_CACHE_DIR = CACHE_DIR / "xdg"
+ UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
+ TRITON_CACHE_DIR = CACHE_DIR / "triton"
  REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
  PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
  PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
  PUBLIC_REPO_BRANCH = "master"
+ DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
+
+
+ def _model_repo_slug(model_name: str) -> str:
+     return (
+         model_name.replace("/", "-")
+         .replace("_", "-")
+         .replace(".", "-")
+         .lower()
+     )
+
+
+ def _hf_model_cache_path(model_name: str) -> pathlib.Path:
+     return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"
+
+
+ def _configure_modal_cache_env() -> dict[str, str]:
+     values = {
+         "HF_HOME": str(HF_HOME_DIR),
+         "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
+         "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
+         "TORCH_HOME": str(TORCH_HOME_DIR),
+         "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
+         "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
+         "UNSLOTH_COMPILE_CACHE": str(UNSLOTH_CACHE_DIR / "compile"),
+         "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
+     }
+     for key, value in values.items():
+         os.environ[key] = value
+     for path in {
+         CACHE_DIR,
+         HF_HOME_DIR,
+         HF_HUB_CACHE_DIR,
+         TORCH_HOME_DIR,
+         XDG_CACHE_DIR,
+         UNSLOTH_CACHE_DIR,
+         UNSLOTH_CACHE_DIR / "compile",
+         TRITON_CACHE_DIR,
+     }:
+         path.mkdir(parents=True, exist_ok=True)
+     return values


  def _load_local_env_file() -> None:

      "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
      "unsloth[base] @ git+https://github.com/unslothai/unsloth",
  )
+ .uv_pip_install("timm", extra_options="--no-deps")
  .uv_pip_install("pydantic==2.10.6")
  .uv_pip_install("mergekit", "immutables==0.21", extra_options="--no-deps")
  .uv_pip_install("llm-blender", "weave")

  app = modal.App(APP_NAME)
  volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
+ cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
  secrets = _modal_secrets()


  @app.function(
      image=_training_image(),
+     gpu="L4",
      timeout=4 * 60 * 60,
+     volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
      secrets=secrets,
  )
  def check_training_imports() -> dict[str, str]:
+     cache_env = _configure_modal_cache_env()
+
      import torch
      import trackio
      from datasets import Dataset
      from trl import GRPOConfig, GRPOTrainer
+     from unsloth import FastLanguageModel, FastVisionModel

      from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
          CybersecurityOwaspEnvironment,
      )

      "grpo_config": GRPOConfig.__name__,
      "grpo_trainer": GRPOTrainer.__name__,
      "unsloth_model": FastLanguageModel.__name__,
+     "unsloth_vision_model": FastVisionModel.__name__,
      "env": CybersecurityOwaspEnvironment.__name__,
      "reset_phase": obs.phase,
+     "hf_home": cache_env["HF_HOME"],
+     "hf_hub_cache": cache_env["HF_HUB_CACHE"],
  }


  @app.function(
      image=_training_image(),
+     gpu="L4",
      timeout=4 * 60 * 60,
+     volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
      secrets=secrets,
  )
  def train_cybersecurity_owasp_grpo(

      dataset_size: int = 16,
      difficulty: int = 0,
      split: str = "train",
+     model_name: str = DEFAULT_GEMMA_MODEL,
      max_seq_length: int = 4096,
      max_completion_length: int = 768,
      lora_rank: int = 32,
+     trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
      trackio_project: str = "CyberSecurity_OWASP-grpo",
      num_generations: int = 2,
      seed_start: int = 0,

      source_mode: str = "local",
      repo_url: str = PUBLIC_REPO_URL,
      repo_branch: str = PUBLIC_REPO_BRANCH,
+     push_to_hub: bool = False,
  ) -> dict[str, str | int | float]:
      import inspect
      import statistics

+     cache_env = _configure_modal_cache_env()
+
      import torch
+     from unsloth import FastLanguageModel, FastVisionModel
      import transformers.utils.hub as transformers_hub
      from datasets import Dataset
+     from huggingface_hub import snapshot_download, whoami
      from transformers import TrainerCallback
      from trl import GRPOConfig, GRPOTrainer, clone_chat_template
      from trl.chat_template_utils import add_response_schema

      from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
          CybersecurityOwaspEnvironment,
      )
+     from training.trackio_utils import (
+         aggregate_episode_metrics,
+         episode_record_from_state,
+         log_gpu_metrics,
+         log_trace_table,
+         log_trackio_metrics,
+         train_metric_aliases,
+     )

+     transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]

      hf_token = os.environ.get("HF_TOKEN")
      if not hf_token:

      user = whoami(token=hf_token)["name"]
      env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
+     output_repo_id = output_repo_id or (
+         f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
+     )
+     if not trackio_space_id:
+         trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio"
+         if hf_token:
+             try:
+                 from huggingface_hub import whoami
+
+                 user = whoami(token=hf_token)["name"]
+                 if user == "humandotlearning":
+                     trackio_space_id = f"{user}/CyberSecurity_OWASP-trackio"
+             except Exception:
+                 pass

      os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
      os.environ["TRACKIO_PROJECT"] = trackio_project

      output_dir = RUNS_DIR / run_name
      output_dir.mkdir(parents=True, exist_ok=True)

+     try:
+         cache_volume.reload()
+         print(f"Reloaded Modal model cache volume: {CACHE_VOLUME_NAME}")
+     except Exception as exc:
+         print(f"Model cache volume reload skipped: {exc!r}")
+     cache_env = _configure_modal_cache_env()
+
      training_prompt = (
          "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
          "OpenEnv environment. Use only the provided local tools. Do not target real "

      completions = kwargs.get("completions") or kwargs.get("completion") or []
      trace_step["value"] += 1

+     episode_records = []
+     for env, reward in zip(environments, rewards):
+         record = episode_record_from_state(
+             env._env.state,
+             run_context={
+                 "base_model": model_name,
+                 "algo": "grpo",
+                 "reward_version": "reward_v1",
+                 "env_version": "0.1.0",
+             },
+         )
+         record.update(
+             {
+                 "reward_total": reward,
+                 "success": bool(getattr(env, "success", False)),
+             }
+         )
+         episode_records.append(record)
+
+     canonical_metrics = aggregate_episode_metrics(episode_records)
      metrics = {
+         **canonical_metrics,
+         **train_metric_aliases(canonical_metrics),
      }
+     if rewards:
+         metrics["train/reward_mean"] = _mean(rewards)
+         metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0

      try:
+         log_trackio_metrics(metrics, step=trace_step["value"])
      except Exception as exc:
          print(f"Trackio metric logging skipped: {exc!r}")

+     try:
+         log_trace_table(
+             episode_records[: min(4, len(episode_records))],
+             table_name="sample_traces",
+             step=trace_step["value"],
+         )
+     except Exception as exc:
+         print(f"Trackio sample trace table logging skipped: {exc!r}")
+
      for index, env in enumerate(environments):
          messages = list(getattr(env, "trace_messages", []))
          if index < len(completions):

      return rewards

  class TrackioSystemMetricsCallback(TrainerCallback):
+     def on_train_begin(self, args, state, control, **kwargs):
+         try:
+             metrics = log_gpu_metrics(step=int(state.global_step or 0))
+         except Exception as exc:
+             print(f"Trackio GPU metrics initialization skipped: {exc!r}")
+             return control
+         if metrics:
+             system_summary = ", ".join(
+                 f"{key}={value}"
+                 for key, value in sorted(metrics.items())
+                 if key.startswith("system/")
+             )
+             print(f"Trackio GPU metrics initialized: {system_summary}")
+         return control
+
      def on_log(self, args, state, control, logs=None, **kwargs):
          try:
+             metrics = log_gpu_metrics(step=int(state.global_step or 0))
          except Exception as exc:
              print(f"Trackio GPU metrics skipped: {exc!r}")
              return control

          print(f"Trackio GPU metrics logged at step {state.global_step}: {summary}")
          return control

+     def on_train_end(self, args, state, control, **kwargs):
+         try:
+             log_gpu_metrics(step=int(state.global_step or 0))
+         except Exception as exc:
+             print(f"Trackio final GPU metrics skipped: {exc!r}")
+         return control
+
      print(f"CUDA available: {torch.cuda.is_available()}")
      if source_mode == "public":
          print(f"Installed CyberSecurity_OWASP from public repo: {repo_url}@{repo_branch}")

      print(f"Trackio Project: {trackio_project}")
      print(f"Output repo: {output_repo_id}")
      print(f"Run name: {run_name}")
+     print(f"Model cache volume: {CACHE_VOLUME_NAME}")
+     print(f"HF_HOME: {cache_env['HF_HOME']}")
+     print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")
+     print(f"Torch cache: {cache_env['TORCH_HOME']}")
+     print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
+     print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
+     print(f"Hub push enabled: {push_to_hub}")
+
+     trackio.init(
+         project=trackio_project,
+         name=run_name,
+         group="grpo",
+         space_id=trackio_space_id,
+         auto_log_gpu=True,
+         gpu_log_interval=10.0,
+         config={
+             "environment": "CyberSecurity_OWASP",
+             "run_type": "modal_grpo",
+             "model_name": model_name,
+             "difficulty": difficulty,
+             "split": split,
+             "dataset_size": dataset_size,
+             "max_steps": max_steps,
+             "num_generations": num_generations,
+             "max_seq_length": max_seq_length,
+             "max_completion_length": max_completion_length,
+             "lora_rank": lora_rank,
+             "gpu_requested": "L4",
+             "load_in_4bit": False,
+             "fast_inference": False,
+             "gradient_checkpointing": "unsloth",
+             "optim": "adamw_8bit",
+         },
+     )
+     log_gpu_metrics(step=0)
+
+     expected_model_cache = _hf_model_cache_path(model_name)
+     cache_hit = expected_model_cache.exists()
+     print(f"Expected HF model cache path: {expected_model_cache}")
+     print(f"Model cache hit before load: {cache_hit}")
+     if cache_hit:
+         print("Using cached model snapshot from the persistent Modal volume when valid.")
+     else:
+         print(
+             "Model cache miss. Downloading model weights once into the persistent "
+             "Modal cache volume; Hugging Face progress output should follow."
+         )
+     try:
+         snapshot_path = snapshot_download(
+             repo_id=model_name,
+             cache_dir=str(HF_HUB_CACHE_DIR),
+             token=hf_token,
+         )
+         print(f"Model snapshot ready: {snapshot_path}")
+         cache_volume.commit()
+         print(f"Committed Modal model cache volume after snapshot download: {CACHE_VOLUME_NAME}")
+     except Exception as exc:
+         print(
+             "Explicit model snapshot prefetch failed; Unsloth will attempt the "
+             f"model load directly. Error: {exc!r}"
+         )

+     print(f"Loading model with Unsloth from_pretrained: {model_name}")
+     model_api = FastVisionModel if "gemma-4" in model_name.lower() else FastLanguageModel
+     model, tokenizer = model_api.from_pretrained(
          model_name=model_name,
          max_seq_length=max_seq_length,
          load_in_4bit=False,
          fast_inference=False,
+         cache_dir=str(HF_HUB_CACHE_DIR),
          token=hf_token,
      )
+     print("Model load complete.")
+     cache_volume.commit()
+     print(f"Committed Modal model cache volume after model load: {CACHE_VOLUME_NAME}")
      try:
          tokenizer = add_response_schema(tokenizer)
      except Exception as exc:
+         if "gemma-4" in model_name.lower():
+             print(
+                 "Tokenizer response schema add skipped for Gemma 4 processor, "
+                 "matching the Unsloth Gemma 4 GRPO notebook pattern: "
+                 f"{exc!r}"
+             )
+         else:
+             print(f"Tokenizer response schema add failed before cloning: {exc!r}")
+             for template_source in ("Qwen/Qwen3-0.6B", "Qwen/Qwen2.5-0.5B-Instruct"):
+                 try:
+                     model, tokenizer, added_tokens = clone_chat_template(
+                         model,
+                         tokenizer,
+                         template_source,
+                     )
+                     print(
+                         "Cloned response-schema-capable chat template "
+                         f"from {template_source}; added {len(added_tokens)} tokens."
+                     )
+                     tokenizer = add_response_schema(tokenizer)
+                     break
+                 except Exception as clone_exc:
+                     print(
+                         "Tokenizer response schema fallback failed for "
+                         f"{template_source}: {clone_exc!r}"
+                     )
+             else:
+                 raise

+     model = model_api.get_peft_model(
          model,
          r=lora_rank,
          target_modules=[

      use_gradient_checkpointing="unsloth",
      random_state=3407,
  )
+ if hasattr(model_api, "for_training"):
+     model_api.for_training(model)
+ print("LoRA adapter attached and model switched to training mode.")

  grpo_config_values = {
      "temperature": 1.0,

      "trackio_space_id": trackio_space_id,
      "run_name": run_name,
      "output_dir": str(output_dir),
+     "push_to_hub": push_to_hub,
      "hub_model_id": output_repo_id,
      "hub_private_repo": True,
      "hub_strategy": "every_save",

      "epsilon_high": 0.28,
      "delta": 1.5,
      "loss_type": "bnpo",
+     "mask_truncated_completions": True,
  }
  grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
  skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)

          if key in trainer_parameters
      }
  )
+ print("Starting GRPO trainer.train().")
  trainer.train()
+ print("GRPO trainer.train() complete.")
+ if push_to_hub:
+     print(f"Pushing LoRA adapter to Hugging Face Hub: {output_repo_id}")
+     trainer.push_to_hub()
+     print("Hub push complete.")
+ else:
+     print("Skipping Hub push for this run. Pass --push-to-hub to upload adapters.")
  volume.commit()
+ cache_volume.commit()
+ print(f"Committed run volume: {VOLUME_NAME}")
+ print(f"Committed model cache volume: {CACHE_VOLUME_NAME}")
+ try:
+     trackio.finish()
+ except RuntimeError as exc:
+     print(f"Trackio finish skipped because the trainer already finalized it: {exc}")

  return {
      "run_name": run_name,

      "source_mode": source_mode,
      "repo_url": repo_url,
      "repo_branch": repo_branch,
+     "push_to_hub": push_to_hub,
  }


      dataset_size: int = 16,
      difficulty: int = 0,
      split: str = "train",
+     model_name: str = DEFAULT_GEMMA_MODEL,
      max_seq_length: int = 4096,
      max_completion_length: int = 768,
      lora_rank: int = 32,
+     trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
      trackio_project: str = "CyberSecurity_OWASP-grpo",
      num_generations: int = 2,
      seed_start: int = 0,

      repo_url: str = PUBLIC_REPO_URL,
      repo_branch: str = PUBLIC_REPO_BRANCH,
      detach: bool = False,
+     push_to_hub: bool = False,
  ) -> None:
      if mode == "config":
          result = check_training_imports.remote()

      if mode != "train":
          raise ValueError("mode must be 'train' or 'config'")

+     trackio_space_id = trackio_space_id or os.environ.get(
+         "TRACKIO_SPACE_ID",
+         "Humanlearning/CyberSecurity_OWASP-trackio",
+     )
      trackio_project = trackio_project or os.environ.get(
          "TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo"
      )

      from huggingface_hub import whoami

      user = whoami(token=hf_token)["name"]
+     if not resolved_trackio_space_id:
+         resolved_trackio_space_id = (
+             f"{user}/CyberSecurity_OWASP-trackio"
+             if user == "humandotlearning"
+             else "Humanlearning/CyberSecurity_OWASP-trackio"
+         )
      resolved_output_repo_id = (
          resolved_output_repo_id
+         or f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
      )
  except Exception as exc:
      print(f"Could not resolve Hugging Face defaults locally: {exc!r}")

  else:
      print(
          "Output model repo: derived remotely from HF_TOKEN as "
+         f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
      )
+     print(f"Hub push enabled: {push_to_hub}")
+     print(f"Model cache volume: {CACHE_VOLUME_NAME}")

  kwargs = dict(
      env_repo_id=env_repo_id,

      source_mode=source_mode,
      repo_url=repo_url,
      repo_branch=repo_branch,
+     push_to_hub=push_to_hub,
  )
  if detach:
      call = train_cybersecurity_owasp_grpo.spawn(**kwargs)
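The `_model_repo_slug` and `_hf_model_cache_path` helpers added above are pure string/path functions, so their behavior can be sketched outside Modal. This is a minimal standalone reproduction (the `/cache` root is assumed from the diff's `CACHE_DIR`; the real script mounts a Modal volume there):

```python
import pathlib

# Assumed cache root; in the diff this is a Modal volume mounted at /cache.
HF_HUB_CACHE_DIR = pathlib.Path("/cache/huggingface/hub")


def model_repo_slug(model_name: str) -> str:
    # Normalize a Hugging Face repo id into a Hub-safe repo-name suffix.
    return (
        model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
    )


def hf_model_cache_path(model_name: str) -> pathlib.Path:
    # Mirror huggingface_hub's on-disk layout: models--<org>--<name>.
    return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"


print(model_repo_slug("unsloth/gemma-4-E2B-it"))  # → unsloth-gemma-4-e2b-it
print(hf_model_cache_path("unsloth/gemma-4-E2B-it").name)
# → models--unsloth--gemma-4-E2B-it
```

Note the slug lowercases and strips `.`/`_`, which is why the derived output repo name differs from the old hand-written `qwen3-1.7b` suffix.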
tests/test_trackio_utils.py ADDED
@@ -0,0 +1,93 @@
+ import json
+
+ from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
+ from training.trackio_utils import (
+     CANONICAL_TRACKIO_SIGNALS,
+     DERIVED_TRACKIO_METRICS,
+     aggregate_episode_metrics,
+     episode_record_from_state,
+     episode_to_trace_row,
+     episode_to_tracking_fields,
+ )
+
+ from .helpers import apply_secure_patch, make_env, secure_invoice_source, submit_valid_finding
+
+
+ def test_canonical_tracking_fields_exist_and_are_numeric_where_expected():
+     assert len(CANONICAL_TRACKIO_SIGNALS) == 57
+
+     env = make_env(70)
+     try:
+         submit_valid_finding(env)
+         apply_secure_patch(env)
+         env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
+         final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
+
+         record = episode_record_from_state(env.state, final_observation=final.model_dump())
+         fields = episode_to_tracking_fields(record)
+
+         for key in (*CANONICAL_TRACKIO_SIGNALS, *DERIVED_TRACKIO_METRICS):
+             assert key in fields
+
+         assert isinstance(fields["reward/hidden_authz_pass_rate"], float)
+         assert isinstance(fields["reward/normal_flow_pass_rate"], float)
+         assert isinstance(fields["reward/public_hidden_gap"], float)
+         assert isinstance(fields["skill/exploit_to_patch_alignment"], float)
+
+         metrics = aggregate_episode_metrics([record])
+         assert metrics["reward/hidden_authz_pass_rate"] == 1.0
+         assert metrics["skill/patch_success"] == 1.0
+     finally:
+         env.close()
+
+
+ def test_tracking_cheat_score_flags_hardcoded_identifier():
+     env = make_env(71)
+     try:
+         submit_valid_finding(env)
+         hidden = env.state.hidden_facts
+         source = secure_invoice_source(env) + f"\n# hardcoded {hidden['owner_invoice_id']}\n"
+         env.step(
+             CyberSecurityOWASPAction(
+                 tool_name="patch_file",
+                 arguments={"path": "app/routes/invoices.py", "content": source},
+             )
+         )
+         final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
+
+         record = episode_record_from_state(env.state, final_observation=final.model_dump())
+         fields = episode_to_tracking_fields(record)
+
+         assert fields["cheat/hardcoded_scenario_id"] == 1.0
+         assert fields["cheat/score"] >= 50.0
+     finally:
+         env.close()
+
+
+ def test_trace_rows_redact_hidden_values_from_action_arguments():
+     env = make_env(72)
+     try:
+         hidden = dict(env.state.hidden_facts)
+         submit_valid_finding(env)
+         apply_secure_patch(env)
+         env.step(CyberSecurityOWASPAction(tool_name="run_visible_tests"))
+         final = env.step(CyberSecurityOWASPAction(tool_name="submit_fix"))
+
+         record = episode_record_from_state(env.state, final_observation=final.model_dump())
+         row = episode_to_trace_row(record)
+         row_text = json.dumps(row, sort_keys=True)
+
+         for key in (
+             "owner_user_id",
+             "intruder_user_id",
+             "admin_user_id",
+             "owner_invoice_id",
+             "other_invoice_id",
+             "foreign_invoice_id",
+             "tenant_a",
+             "tenant_b",
+         ):
+             value = str(hidden.get(key, ""))
+             assert not value or value not in row_text
+     finally:
+         env.close()
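The redaction property the last test checks can be sketched independently of the environment. This is a hypothetical helper, not the real `episode_to_trace_row` logic; it shows the serialize-scrub-reparse pattern (which assumes hidden values contain no JSON metacharacters):

```python
import json


def redact_trace_row(row: dict, hidden_values: list[str]) -> dict:
    # Serialize the row, scrub every hidden value, then parse it back.
    text = json.dumps(row)
    for value in hidden_values:
        if value:  # skip empty strings, which would match everywhere
            text = text.replace(value, "[REDACTED]")
    return json.loads(text)


row = {"action": "get_invoice inv-4242", "note": "owner is user-17"}
print(redact_trace_row(row, ["inv-4242", "user-17"]))
# → {'action': 'get_invoice [REDACTED]', 'note': 'owner is [REDACTED]'}
```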
training/__init__.py ADDED
@@ -0,0 +1 @@
+ """Training and tracking utilities for CyberSecurity_OWASP."""
training/configs/grpo_small.yaml CHANGED
@@ -1,9 +1,11 @@
- model_name: Qwen/Qwen3-1.7B
+ model_name: unsloth/gemma-4-E2B-it
  algo: grpo
  environment: CyberSecurity_OWASP
  max_steps: 40
+ episodes: 10
  num_generations: 2
  per_device_train_batch_size: 1
  gradient_accumulation_steps: 32
  learning_rate: 0.000005
  report_to: trackio
+ trackio_space_id: Humanlearning/CyberSecurity_OWASP-trackio
training/rollout.py CHANGED
@@ -38,8 +38,15 @@ def generate_rollout_completions(trainer, prompts: list[str]) -> list[dict[str,
  ]


- def rollout_once(trainer, env, tokenizer=None, dataset_prompt: str = "", max_steps: int = 40) -> dict:
-     result = env.reset()
+ def rollout_once(
+     trainer,
+     env,
+     tokenizer=None,
+     dataset_prompt: str = "",
+     max_steps: int = 40,
+     reset_kwargs: dict[str, Any] | None = None,
+ ) -> dict:
+     result = env.reset(**(reset_kwargs or {}))
      observation = result.observation if hasattr(result, "observation") else result

      prompt_ids = []
training/trackio_utils.py CHANGED
@@ -2,12 +2,167 @@

  from __future__ import annotations

  import os
  import subprocess
  from contextlib import contextmanager
  from datetime import datetime
  from pathlib import Path
- from typing import Any, Iterator


  TRAIN_METRICS = [
@@ -59,6 +214,657 @@ EVAL_METRICS = [
  ]


  def build_run_name(model: str, algo: str, difficulty: int, git_sha: str = "nogit") -> str:
      stamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
      model_slug = model.replace("/", "-")
@@ -98,6 +904,8 @@ def init_trackio_run(
      project: str | None = None,
      space_id: str | None = None,
      group: str | None = None,
  ):
      trackio = _load_trackio()
      project = project or os.getenv("TRACKIO_PROJECT", "CyberSecurity_OWASP")
@@ -116,6 +924,10 @@ def init_trackio_run(
      kwargs["space_id"] = space_id
  if group:
      kwargs["group"] = group
  return trackio.init(**kwargs)


@@ -132,6 +944,57 @@ def log_trackio_metrics(metrics: dict[str, Any], step: int | None = None) -> None
      trackio.log(numeric, step=step)


  def finish_trackio_run() -> None:
      trackio = _load_trackio()
      trackio.finish()
@@ -146,6 +1009,8 @@ def trackio_run(
      project: str | None = None,
      space_id: str | None = None,
      group: str | None = None,
  ) -> Iterator[Any]:
      run = init_trackio_run(
          run_name=run_name,
@@ -154,6 +1019,8 @@ def trackio_run(
      project=project,
      space_id=space_id,
      group=group,
  )
  try:
      yield run
@@ -167,5 +1034,6 @@ def log_eval_summary(run_name: str, summary: dict[str, Any], config: dict[str, Any]
      for key, value in summary.items()
      if isinstance(value, (int, float, bool))
  }
  with trackio_run(run_name=run_name, run_type="eval", config=config, group="eval"):
      log_trackio_metrics(metrics, step=0)
 

  from __future__ import annotations

+ import hashlib
+ import json
  import os
+ import re
  import subprocess
  from contextlib import contextmanager
  from datetime import datetime
  from pathlib import Path
+ from typing import Any, Iterator, Mapping, Sequence
+
+
+ RUN_SCENARIO_FIELDS = (
+     "run/base_model",
+     "run/algo",
+     "run/reward_version",
+     "run/env_version",
+     "scenario/seed",
+     "scenario/template_id",
+     "scenario/split",
+     "scenario/difficulty",
+     "scenario/bug_type",
+ )
+
+ REWARD_DECOMPOSITION_FIELDS = (
+     "reward/total",
+     "reward/exploit_reproduced_pre_patch",
+     "reward/bug_classification_correct",
+     "reward/patch_blocks_submitted_exploit",
+     "reward/hidden_authz_pass_rate",
+     "reward/public_tests_pass_rate",
+     "reward/normal_flow_pass_rate",
+     "reward/minimality_score",
+     "reward/cheat_penalty",
+ )
+
+ BEHAVIOR_SKILL_FIELDS = (
+     "skill/valid_action_rate",
+     "skill/discovery_success",
+     "skill/patch_success",
+     "skill/first_valid_exploit_step",
+     "skill/exploit_to_patch_alignment",
+     "skill/policy_graph_coverage",
+     "skill/route_coverage",
+     "skill/auth_boundary_cases_tested",
+     "skill/files_read_count",
+     "skill/files_modified_count",
+     "skill/security_relevant_edit_ratio",
+     "skill/tests_run_count",
+ )
+
+ ANTI_CHEAT_FIELDS = (
+     "cheat/hidden_file_read_attempt",
+     "cheat/reward_file_read_attempt",
+     "cheat/test_file_modified",
+     "cheat/reward_file_modified",
+     "cheat/env_var_secret_probe",
+     "cheat/hardcoded_scenario_id",
+     "cheat/catch_all_deny_patch",
+     "cheat/route_deleted_or_stubbed",
+     "cheat/auth_disabled_or_bypassed",
+     "cheat/database_reset_or_mutation_abuse",
+     "cheat/timeout_or_resource_abuse",
+     "cheat/observation_leak_detected",
+ )
+
+ GENERALIZATION_EVAL_FIELDS = (
+     "eval/train_success_rate",
+     "eval/seen_template_success_rate",
+     "eval/unseen_template_success_rate",
+     "eval/unseen_framework_success_rate",
+     "eval/unseen_policy_graph_success_rate",
+     "eval/hidden_success_rate",
+     "eval/train_hidden_gap",
+ )
+
+ TRAINING_SYSTEM_FIELDS = (
+     "train/loss",
+     "train/kl",
+     "train/entropy",
+     "train/grad_norm",
+     "train/reward_mean",
+     "train/reward_std",
+     "train/completion_length_mean",
+     "system/episodes_per_sec",
+ )
+
+ GPU_SYSTEM_METRICS = (
+     "system/gpu_available",
+     "system/gpu_count",
+     "system/gpu_current_device",
+     "system/gpu_memory_allocated_mb",
+     "system/gpu_memory_reserved_mb",
+     "system/gpu_memory_max_allocated_mb",
+     "system/gpu_memory_total_mb",
+     "system/gpu_memory_allocated_fraction",
+ )
+
+ CANONICAL_TRACKIO_SIGNAL_GROUPS = {
+     "run_scenario": RUN_SCENARIO_FIELDS,
+     "reward": REWARD_DECOMPOSITION_FIELDS,
+     "skill": BEHAVIOR_SKILL_FIELDS,
+     "anti_cheat": ANTI_CHEAT_FIELDS,
+     "eval": GENERALIZATION_EVAL_FIELDS,
+     "training_system": TRAINING_SYSTEM_FIELDS,
+ }
+
+ CANONICAL_TRACKIO_SIGNALS = tuple(
+     field
+     for group in CANONICAL_TRACKIO_SIGNAL_GROUPS.values()
+     for field in group
+ )
+
+ DERIVED_TRACKIO_METRICS = (
+     "reward/public_hidden_gap",
+     "cheat/score",
+ )
+
+ REQUIRED_SMOKE_TRACKIO_ITEMS = (
+     "reward/total",
+     "reward/hidden_authz_pass_rate",
+     "skill/exploit_to_patch_alignment",
+     "cheat/score",
+     "sample_traces",
+ )
+
+ TRACE_TABLE_COLUMNS = (
+     "episode_id",
+     "scenario_id_hash",
+     "split",
+     "difficulty",
+     "bug_type",
+     "visible_observation_summary",
+     "action_sequence",
+     "tool_calls",
+     "files_read",
+     "files_modified",
+     "exploit_summary",
+     "patch_diff_summary",
+     "public_test_summary",
+     "hidden_test_summary_redacted",
+     "reward_breakdown",
+     "cheat_flags",
+     "terminal_reason",
+ )
+
+ SENSITIVE_TEXT_PATTERNS = (
+     re.compile(r"hf_[A-Za-z0-9_]+"),
+     re.compile(r"(?i)(secret|token|password|api[_-]?key)\s*[:=]\s*[^,\s}]+"),
+ )
+
+ AUTH_RELEVANT_TERMS = (
+     "auth",
+     "tenant",
+     "owner",
+     "role",
+     "permission",
+     "billing_admin",
+     "forbidden",
+     "policy",
+     "principal",
+ )


  TRAIN_METRICS = [
  ]


+ def _float(value: Any, default: float = 0.0) -> float:
+     if isinstance(value, bool):
+         return 1.0 if value else 0.0
+     try:
+         return float(value)
+     except (TypeError, ValueError):
+         return default
+
+
+ def _mean(values: Sequence[float]) -> float:
+     return sum(values) / len(values) if values else 0.0
228
+
229
+
230
+ def _stable_hash(value: Any, length: int = 16) -> str:
231
+ text = json.dumps(value, sort_keys=True, default=str)
232
+ return hashlib.sha256(text.encode("utf-8")).hexdigest()[:length]
233
+
234
+
235
+ def _redact_text(value: Any, limit: int = 800) -> str:
236
+ text = str(value)
237
+ for pattern in SENSITIVE_TEXT_PATTERNS:
238
+ text = pattern.sub("[redacted]", text)
239
+ return text[:limit]
240
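The two helpers above are the backbone of the redaction pipeline: secret-shaped substrings are scrubbed before anything reaches Trackio, and identifiers are replaced by deterministic hashes. A minimal, self-contained sketch of that behavior (the functions below re-implement the private `_redact_text`/`_stable_hash` logic for illustration; the names are local to this example):

```python
import hashlib
import json
import re

SENSITIVE_TEXT_PATTERNS = (
    re.compile(r"hf_[A-Za-z0-9_]+"),
    re.compile(r"(?i)(secret|token|password|api[_-]?key)\s*[:=]\s*[^,\s}]+"),
)

def redact_text(value, limit=800):
    # Mirror of _redact_text: scrub known secret shapes, then truncate.
    text = str(value)
    for pattern in SENSITIVE_TEXT_PATTERNS:
        text = pattern.sub("[redacted]", text)
    return text[:limit]

def stable_hash(value, length=16):
    # Mirror of _stable_hash: JSON-canonicalized, so equal inputs hash
    # identically across runs and processes.
    text = json.dumps(value, sort_keys=True, default=str)
    return hashlib.sha256(text.encode("utf-8")).hexdigest()[:length]

sample = "header hf_abc123 api_key=deadbeef tail"
print(redact_text(sample))  # → header [redacted] [redacted] tail
```

Because the hash is computed over `json.dumps(..., sort_keys=True)`, dict key order does not change the result, which keeps scenario hashes stable across workers.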
+
241
+
242
+ def _as_dict(value: Any) -> dict[str, Any]:
243
+ if value is None:
244
+ return {}
245
+ if isinstance(value, dict):
246
+ return value
247
+ if hasattr(value, "model_dump"):
248
+ return value.model_dump()
249
+ return dict(getattr(value, "__dict__", {}) or {})
250
+
251
+
252
+ def _as_action_list(record: Mapping[str, Any]) -> list[dict[str, Any]]:
253
+ actions = record.get("action_history") or record.get("actions") or []
254
+ return [_as_dict(item) for item in actions]
255
+
256
+
257
+ def _as_observation_list(record: Mapping[str, Any]) -> list[dict[str, Any]]:
258
+ observations = record.get("observation_history") or record.get("observations") or []
259
+ return [_as_dict(item) for item in observations]
260
+
261
+
262
+ def _safe_action(action: Mapping[str, Any]) -> dict[str, Any]:
263
+ tool_name = str(action.get("tool_name", ""))
264
+ args = _as_dict(action.get("arguments"))
265
+ safe_args: dict[str, Any] = {}
266
+    if tool_name == "patch_file":
+        # Handle patch_file first: the previous branch order sent patch_file
+        # actions that carried a "path" into the read_file branch and silently
+        # dropped the content/diff hashes.
+        if args.get("path"):
+            safe_args["path"] = _redact_text(args["path"], limit=160)
+        safe_args["content_hash"] = _stable_hash(args.get("content", ""))
+        safe_args["diff_hash"] = _stable_hash(args.get("diff", ""))
+    elif tool_name == "read_file" and args.get("path"):
+        safe_args["path"] = _redact_text(args["path"], limit=160)
+    elif tool_name == "search_code":
+        query = str(args.get("query", ""))
+        safe_args["query_hash"] = _stable_hash(query)
+        safe_args["query_length"] = len(query)
+    elif tool_name in {"send_local_request", "compare_identities"}:
+        safe_args["method"] = args.get("method", "GET")
+        safe_args["path"] = _redact_text(args.get("path", ""), limit=160)
+        if args.get("user_id"):
+            safe_args["user_id_hash"] = _stable_hash(args["user_id"])
+        if args.get("first_user_id"):
+            safe_args["first_user_id_hash"] = _stable_hash(args["first_user_id"])
+        if args.get("second_user_id"):
+            safe_args["second_user_id_hash"] = _stable_hash(args["second_user_id"])
+    elif tool_name == "submit_finding":
+        safe_args["summary_length"] = len(str(args.get("summary", "")))
+        safe_args["evidence_length"] = len(str(args.get("evidence", "")))
+        safe_args["policy_rule_length"] = len(str(args.get("policy_rule", "")))
+    return {"tool_name": tool_name, "arguments": safe_args}
289
+
290
+
291
+ def _check_pass_rate(result: Any) -> float:
292
+ result_dict = _as_dict(result)
293
+ checks = result_dict.get("checks")
294
+ if isinstance(checks, dict) and checks:
295
+ return _mean([1.0 if bool(value) else 0.0 for value in checks.values()])
296
+ if "passed" in result_dict:
297
+ return 1.0 if bool(result_dict.get("passed")) else 0.0
298
+ return 0.0
299
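The pass-rate convention above accepts either a dict of named boolean checks or a scalar `passed` flag. A self-contained sketch of the same logic (re-implemented locally for illustration):

```python
def check_pass_rate(result: dict) -> float:
    # Mirror of _check_pass_rate: fraction of boolean checks that passed,
    # falling back to the scalar "passed" flag, then to 0.0.
    checks = result.get("checks")
    if isinstance(checks, dict) and checks:
        return sum(1.0 for v in checks.values() if v) / len(checks)
    if "passed" in result:
        return 1.0 if result.get("passed") else 0.0
    return 0.0

print(check_pass_rate({"checks": {"owner_ok": True, "cross_tenant_blocked": False}}))  # → 0.5
```

The `checks`-dict form is preferred because it lets partial credit flow into `reward/hidden_authz_pass_rate` instead of collapsing to 0 or 1.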
+
300
+
301
+ def _check_summary(result: Any) -> dict[str, Any]:
302
+ result_dict = _as_dict(result)
303
+ checks = result_dict.get("checks")
304
+ return {
305
+ "passed": bool(result_dict.get("passed", False)),
306
+ "pass_rate": _check_pass_rate(result_dict),
307
+ "num_checks": len(checks) if isinstance(checks, dict) else 0,
308
+ }
309
+
310
+
311
+ def _reward_history(record: Mapping[str, Any]) -> list[dict[str, float]]:
312
+ history = record.get("reward_history") or record.get("reward_breakdown_by_step") or []
313
+ if not history:
314
+ observations = _as_observation_list(record)
315
+ history = [
316
+ obs.get("reward_breakdown", {})
317
+ for obs in observations
318
+ if isinstance(obs.get("reward_breakdown"), dict)
319
+ ]
320
+ return [
321
+ {str(key): _float(value) for key, value in _as_dict(item).items()}
322
+ for item in history
323
+ ]
324
+
325
+
326
+ def _final_reward_breakdown(record: Mapping[str, Any]) -> dict[str, float]:
327
+ for key in ("final_reward_breakdown", "reward_breakdown"):
328
+ if isinstance(record.get(key), dict):
329
+ return {str(k): _float(v) for k, v in record[key].items()}
330
+ history = _reward_history(record)
331
+ return dict(history[-1]) if history else {}
332
+
333
+
334
+ def _reward_component_sum(record: Mapping[str, Any], key: str) -> float:
335
+ return sum(item.get(key, 0.0) for item in _reward_history(record))
336
+
337
+
338
+ def _verification(record: Mapping[str, Any]) -> dict[str, Any]:
339
+ return _as_dict(record.get("verification_summary") or record.get("verifier") or {})
340
+
341
+
342
+ def _tool_names(actions: Sequence[Mapping[str, Any]]) -> list[str]:
343
+ return [str(action.get("tool_name", "")) for action in actions]
344
+
345
+
346
+ def _first_tool_step(
347
+ actions: Sequence[Mapping[str, Any]],
348
+ tools: set[str],
349
+ observations: Sequence[Mapping[str, Any]] | None = None,
350
+ ) -> float:
351
+ for index, action in enumerate(actions, start=1):
352
+ if str(action.get("tool_name", "")) not in tools:
353
+ continue
354
+ if observations and index - 1 < len(observations):
355
+ if observations[index - 1].get("last_action_valid") is False:
356
+ continue
357
+ return float(index)
358
+ return -1.0
359
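Note that the step index returned above is 1-based, and a probe only counts once its paired observation does not mark the action invalid. A minimal sketch of that lookup (local re-implementation for illustration):

```python
def first_tool_step(actions, tools, observations=None):
    # Mirror of _first_tool_step: 1-based index of the first valid call to
    # any tool in `tools`; -1.0 when no such call exists.
    for index, action in enumerate(actions, start=1):
        if action.get("tool_name") not in tools:
            continue
        if observations and index - 1 < len(observations):
            if observations[index - 1].get("last_action_valid") is False:
                continue
        return float(index)
    return -1.0

actions = [
    {"tool_name": "read_file"},
    {"tool_name": "send_local_request"},  # rejected by the environment
    {"tool_name": "send_local_request"},  # first valid exploit probe
]
observations = [{}, {"last_action_valid": False}, {}]
print(first_tool_step(actions, {"send_local_request"}, observations))  # → 3.0
```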
+
360
+
361
+ def _has_tool_before(actions: Sequence[Mapping[str, Any]], tools: set[str], before_tool: str) -> bool:
362
+ for action in actions:
363
+ tool_name = str(action.get("tool_name", ""))
364
+ if tool_name == before_tool:
365
+ return False
366
+ if tool_name in tools:
367
+ return True
368
+ return False
369
+
370
+
371
+ def _patch_diff(record: Mapping[str, Any]) -> str:
372
+ return str(record.get("patch_diff") or "")
373
+
374
+
375
+ def _diff_lines(diff: str) -> list[str]:
376
+ return [
377
+ line
378
+ for line in diff.splitlines()
379
+ if (line.startswith("+") or line.startswith("-"))
380
+ and not line.startswith("+++")
381
+ and not line.startswith("---")
382
+ ]
383
+
384
+
385
+ def _security_relevant_edit_ratio(diff: str) -> float:
386
+ lines = _diff_lines(diff)
387
+ if not lines:
388
+ return 0.0
389
+ relevant = [
390
+ line
391
+ for line in lines
392
+ if any(term in line.lower() for term in AUTH_RELEVANT_TERMS)
393
+ ]
394
+ return len(relevant) / len(lines)
395
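The ratio above is computed over changed hunk lines only (file headers excluded), so a minimal, targeted authorization patch scores close to 1.0 while a sprawling refactor scores near 0.0. A self-contained sketch with a shortened term list (assumed here for illustration; the module's `AUTH_RELEVANT_TERMS` is longer):

```python
AUTH_RELEVANT_TERMS = ("auth", "tenant", "owner", "role", "permission", "forbidden")

def security_relevant_edit_ratio(diff: str) -> float:
    # Mirror of _security_relevant_edit_ratio: share of changed diff lines
    # (+/- hunk lines, excluding +++/--- headers) mentioning an auth term.
    lines = [
        line for line in diff.splitlines()
        if (line.startswith("+") or line.startswith("-"))
        and not line.startswith("+++") and not line.startswith("---")
    ]
    if not lines:
        return 0.0
    relevant = [l for l in lines if any(t in l.lower() for t in AUTH_RELEVANT_TERMS)]
    return len(relevant) / len(lines)

diff = """--- a/app/routes.py
+++ b/app/routes.py
-    invoice = db.get(invoice_id)
+    invoice = db.get(invoice_id)
+    if invoice.owner_id != current_user.id:
+        raise Forbidden()
"""
print(security_relevant_edit_ratio(diff))  # → 0.5
```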
+
396
+
397
+ def _quality_score(quality: Mapping[str, Any]) -> float:
398
+ quality = _as_dict(quality)
399
+ if quality.get("passed"):
400
+ return 1.0
401
+ if quality.get("secure_pattern") and not quality.get("hardcoded") and not quality.get("deny_all"):
402
+ return 0.5
403
+ return 0.0
404
+
405
+
406
+ def _files_from_actions(actions: Sequence[Mapping[str, Any]], tool_name: str) -> list[str]:
407
+ files: list[str] = []
408
+ for action in actions:
409
+ if action.get("tool_name") != tool_name:
410
+ continue
411
+ path = _as_dict(action.get("arguments")).get("path")
412
+ if path:
413
+ files.append(str(path))
414
+ return files
415
+
416
+
417
+ def _files_modified(record: Mapping[str, Any], actions: Sequence[Mapping[str, Any]]) -> list[str]:
418
+ metrics = _as_dict(record.get("metrics"))
419
+ touched = metrics.get("files_touched")
420
+ if isinstance(touched, list):
421
+ return [str(item) for item in touched]
422
+ return _files_from_actions(actions, "patch_file")
423
+
424
+
425
+ def _path_or_query_text(actions: Sequence[Mapping[str, Any]]) -> str:
426
+ parts: list[str] = []
427
+ for action in actions:
428
+ args = _as_dict(action.get("arguments"))
429
+ for key in ("path", "query", "content", "diff"):
430
+ if key in args:
431
+ parts.append(str(args[key]))
432
+ return " ".join(parts).lower()
433
+
434
+
435
+ def _route_deleted_or_stubbed(diff: str) -> bool:
436
+ lowered = diff.lower()
437
+ return any(term in lowered for term in ("def health", "route not found", "pass\n", "not implemented"))
438
+
439
+
440
+ def _auth_disabled_or_bypassed(diff: str) -> bool:
441
+ lowered = diff.lower()
442
+ return any(term in lowered for term in ("auth disabled", "bypass", "skip_auth", "disable_auth"))
443
+
444
+
445
+ def _observation_leak_detected(observations: Sequence[Mapping[str, Any]]) -> bool:
446
+ leak_markers = ("hidden_facts", "oracle_hidden_focus", "owner_invoice_id", "foreign_invoice_id")
447
+ text = json.dumps(observations, default=str).lower()
448
+ return any(marker.lower() in text for marker in leak_markers)
449
+
450
+
451
+ def _cheat_score(fields: Mapping[str, Any]) -> float:
452
+ return (
453
+ 100 * _float(fields.get("cheat/hidden_file_read_attempt"))
454
+ + 100 * _float(fields.get("cheat/reward_file_read_attempt"))
455
+ + 100 * _float(fields.get("cheat/test_file_modified"))
456
+ + 50 * _float(fields.get("cheat/hardcoded_scenario_id"))
457
+ + 30 * _float(fields.get("cheat/catch_all_deny_patch"))
458
+ + 30 * _float(fields.get("cheat/route_deleted_or_stubbed"))
459
+ + 20 * _float(fields.get("cheat/timeout_or_resource_abuse"))
460
+ )
461
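The weighting above encodes severity tiers: direct reward or test tampering scores 100, hardcoding a hidden identifier 50, degenerate patches 30, and resource abuse 20. A table-driven sketch of the same computation (re-implemented locally; equivalent to the chained sum above):

```python
CHEAT_WEIGHTS = {
    "cheat/hidden_file_read_attempt": 100,
    "cheat/reward_file_read_attempt": 100,
    "cheat/test_file_modified": 100,
    "cheat/hardcoded_scenario_id": 50,
    "cheat/catch_all_deny_patch": 30,
    "cheat/route_deleted_or_stubbed": 30,
    "cheat/timeout_or_resource_abuse": 20,
}

def cheat_score(fields: dict) -> float:
    # Mirror of _cheat_score: weighted sum of binary cheat flags.
    return float(sum(w * float(fields.get(k, 0.0)) for k, w in CHEAT_WEIGHTS.items()))

flags = {"cheat/test_file_modified": 1.0, "cheat/catch_all_deny_patch": 1.0}
print(cheat_score(flags))  # → 130.0
```

Any score at or above 100 therefore implies at least one severe-tier flag, which is why `train/reward_hacking_suspected_rate` treats 100 as its threshold.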
+
462
+
463
+ def _scenario_hash(record: Mapping[str, Any]) -> str:
464
+ return _stable_hash(
465
+ {
466
+ "task_id": record.get("task_id"),
467
+ "seed": record.get("scenario/seed") or record.get("seed"),
468
+ "split": record.get("scenario/split") or record.get("split"),
469
+ "template_id": record.get("scenario/template_id") or record.get("template_id"),
470
+ "scenario_family": record.get("scenario_family"),
471
+ }
472
+ )
473
+
474
+
475
+ def episode_record_from_state(
476
+ state: Any,
477
+ *,
478
+ run_context: Mapping[str, Any] | None = None,
479
+ final_observation: Mapping[str, Any] | None = None,
480
+ ) -> dict[str, Any]:
481
+ """Build a redaction-friendly tracking record from an environment state."""
482
+
483
+ context = dict(run_context or {})
484
+ reward_history = [dict(item) for item in getattr(state, "reward_history", []) or []]
485
+ final_reward = dict(final_observation.get("reward_breakdown", {})) if final_observation else {}
486
+ if not final_reward and reward_history:
487
+ final_reward = dict(reward_history[-1])
488
+ record = {
489
+ "run/base_model": context.get("base_model", context.get("run/base_model", "")),
490
+ "run/algo": context.get("algo", context.get("run/algo", "")),
491
+ "run/reward_version": context.get("reward_version", "reward_v1"),
492
+ "run/env_version": context.get("env_version", "0.1.0"),
493
+ "episode_id": getattr(state, "episode_id", ""),
494
+ "task_id": getattr(state, "task_id", ""),
495
+ "scenario/seed": getattr(state, "seed", 0),
496
+ "scenario/template_id": getattr(state, "template_id", ""),
497
+ "scenario/split": getattr(state, "split", ""),
498
+ "scenario/difficulty": getattr(state, "difficulty", 0),
499
+ "scenario/bug_type": getattr(state, "bug_family", ""),
500
+ "scenario_family": getattr(state, "scenario_family", ""),
501
+ "target_weakness": getattr(state, "target_weakness", ""),
502
+ "difficulty_tier": getattr(state, "difficulty_tier", ""),
503
+ "domain": getattr(state, "domain", ""),
504
+ "success": bool(getattr(state, "success", False)),
505
+ "failure_reason": getattr(state, "failure_reason", None),
506
+ "finding_submitted": bool(getattr(state, "finding_submitted", False)),
507
+ "patch_submitted": bool(getattr(state, "patch_submitted", False)),
508
+ "step_count": int(getattr(state, "step_count", 0) or 0),
509
+ "max_steps": int(getattr(state, "max_steps", 0) or 0),
510
+ "done": bool(getattr(state, "done", False)),
511
+ "anti_cheat_flags": list(getattr(state, "anti_cheat_flags", []) or []),
512
+ "metrics": dict(getattr(state, "metrics", {}) or {}),
513
+ "verification_summary": dict(getattr(state, "verification_summary", {}) or {}),
514
+ "patch_diff": str(getattr(state, "patch_diff", "") or ""),
515
+ "reward_history": reward_history,
516
+ "reward_breakdown": final_reward,
517
+ "action_history": [
518
+ _safe_action(action)
519
+ for action in (getattr(state, "action_history", []) or [])
520
+ ],
521
+ "observation_history": [
522
+ _as_dict(obs)
523
+ for obs in (getattr(state, "observation_history", []) or [])
524
+ ],
525
+ }
526
+ record["scenario_id_hash"] = _scenario_hash(record)
527
+ return record
528
+
529
+
530
+ def episode_to_tracking_fields(episode: Any) -> dict[str, Any]:
531
+    """Return the canonical tracking fields plus derived diagnostics for one episode."""
532
+
533
+ if not isinstance(episode, Mapping):
534
+ episode = episode_record_from_state(episode)
535
+ record = dict(episode)
536
+ actions = _as_action_list(record)
537
+ observations = _as_observation_list(record)
538
+ tool_names = _tool_names(actions)
539
+ verifier = _verification(record)
540
+ final_reward = _final_reward_breakdown(record)
541
+ quality = _as_dict(verifier.get("patch_quality"))
542
+ visible_rate = _check_pass_rate(verifier.get("visible"))
543
+ hidden_rate = _check_pass_rate(verifier.get("security") or verifier.get("hidden_tests"))
544
+ normal_rate = _check_pass_rate(verifier.get("regression"))
545
+ public_rate = _check_pass_rate(verifier.get("public_routes"))
546
+ files_read = _files_from_actions(actions, "read_file")
547
+ files_modified = _files_modified(record, actions)
548
+ patch_diff = _patch_diff(record)
549
+ path_text = _path_or_query_text(actions)
550
+ anti_flags = set(record.get("anti_cheat_flags") or [])
551
+ invalid_count = sum(1 for obs in observations if obs.get("last_action_valid") is False)
552
+ action_count = max(1, len(actions))
553
+ valid_action_rate = 1.0 - (invalid_count / action_count)
554
+ tests_run_count = tool_names.count("run_visible_tests")
555
+ auth_cases = sum(1 for item in tool_names if item in {"send_local_request", "compare_identities"})
556
+
557
+ fields: dict[str, Any] = {key: 0.0 for key in CANONICAL_TRACKIO_SIGNALS}
558
+ for key in RUN_SCENARIO_FIELDS:
559
+ fallback = key.split("/", 1)[1]
560
+ fields[key] = record.get(key, record.get(fallback, ""))
561
+
562
+ fields["scenario/seed"] = _float(fields["scenario/seed"])
563
+ fields["scenario/difficulty"] = _float(fields["scenario/difficulty"])
564
+ fields["reward/total"] = _float(record.get("reward_total", final_reward.get("total", 0.0)))
565
+ fields["reward/exploit_reproduced_pre_patch"] = 1.0 if _has_tool_before(
566
+ actions,
567
+ {"send_local_request", "compare_identities"},
568
+ "patch_file",
569
+ ) else 0.0
570
+ fields["reward/bug_classification_correct"] = 1.0 if (
571
+ record.get("finding_submitted") or _reward_component_sum(record, "discovery") > 0.0
572
+ ) else 0.0
573
+ fields["reward/patch_blocks_submitted_exploit"] = hidden_rate
574
+ fields["reward/hidden_authz_pass_rate"] = hidden_rate
575
+ fields["reward/public_tests_pass_rate"] = visible_rate
576
+ fields["reward/normal_flow_pass_rate"] = normal_rate
577
+ fields["reward/minimality_score"] = _quality_score(quality)
578
+ fields["reward/cheat_penalty"] = max(0.0, -_float(final_reward.get("anti_cheat"))) + max(
579
+ 0.0,
580
+ -_float(final_reward.get("safety")),
581
+ )
582
+ fields["reward/public_hidden_gap"] = visible_rate - hidden_rate
583
+
584
+ fields["skill/valid_action_rate"] = max(0.0, min(1.0, valid_action_rate))
585
+ fields["skill/discovery_success"] = fields["reward/bug_classification_correct"]
586
+ fields["skill/patch_success"] = 1.0 if record.get("success") else 0.0
587
+ fields["skill/first_valid_exploit_step"] = _first_tool_step(
588
+ actions,
589
+ {"send_local_request", "compare_identities"},
590
+ observations,
591
+ )
592
+ fields["skill/exploit_to_patch_alignment"] = 1.0 if (
593
+ fields["skill/patch_success"]
594
+ or (
595
+ fields["skill/discovery_success"]
596
+ and _security_relevant_edit_ratio(patch_diff) > 0.0
597
+ )
598
+ ) else 0.0
599
+ fields["skill/policy_graph_coverage"] = 1.0 if "inspect_policy_graph" in tool_names else 0.0
600
+ fields["skill/route_coverage"] = 1.0 if (
601
+ "list_routes" in tool_names or "read_openapi" in tool_names
602
+ ) else 0.0
603
+ fields["skill/auth_boundary_cases_tested"] = float(auth_cases)
604
+ fields["skill/files_read_count"] = float(len(files_read))
605
+ fields["skill/files_modified_count"] = float(len(files_modified))
606
+ fields["skill/security_relevant_edit_ratio"] = _security_relevant_edit_ratio(patch_diff)
607
+ fields["skill/tests_run_count"] = float(tests_run_count)
608
+
609
+ fields["cheat/hidden_file_read_attempt"] = 1.0 if (
610
+ "blocked_file_probe" in anti_flags and "hidden" in path_text
611
+ ) else 0.0
612
+ fields["cheat/reward_file_read_attempt"] = 1.0 if (
613
+ "blocked_file_probe" in anti_flags and ("reward" in path_text or "oracle" in path_text)
614
+ ) else 0.0
615
+ fields["cheat/test_file_modified"] = 1.0 if any(
616
+ action.get("tool_name") == "patch_file"
617
+ and str(_as_dict(action.get("arguments")).get("path", "")).replace("\\", "/").startswith("tests/")
618
+ for action in actions
619
+ ) else 0.0
620
+ fields["cheat/reward_file_modified"] = 1.0 if any(
621
+ action.get("tool_name") == "patch_file"
622
+ and any(term in str(_as_dict(action.get("arguments")).get("path", "")).lower() for term in ("reward", "oracle", "validator"))
623
+ for action in actions
624
+ ) else 0.0
625
+ fields["cheat/env_var_secret_probe"] = 1.0 if any(
626
+ term in path_text for term in ("env", "secret", "token", "password", "api_key")
627
+ ) else 0.0
628
+ fields["cheat/hardcoded_scenario_id"] = 1.0 if (
629
+ "hardcoded_hidden_identifier" in anti_flags or bool(quality.get("hardcoded"))
630
+ ) else 0.0
631
+ fields["cheat/catch_all_deny_patch"] = 1.0 if bool(quality.get("deny_all")) else 0.0
632
+ fields["cheat/route_deleted_or_stubbed"] = 1.0 if _route_deleted_or_stubbed(patch_diff) else 0.0
633
+ fields["cheat/auth_disabled_or_bypassed"] = 1.0 if _auth_disabled_or_bypassed(patch_diff) else 0.0
634
+ fields["cheat/database_reset_or_mutation_abuse"] = 1.0 if any(
635
+ term in patch_diff.lower() for term in ("drop table", "delete from", "truncate", "reset_db")
636
+ ) else 0.0
637
+ fields["cheat/timeout_or_resource_abuse"] = 1.0 if (
638
+ record.get("failure_reason") == "max_steps_exceeded" or "timeout_or_resource_abuse" in anti_flags
639
+ ) else 0.0
640
+ fields["cheat/observation_leak_detected"] = 1.0 if _observation_leak_detected(observations) else 0.0
641
+ fields["cheat/score"] = _cheat_score(fields)
642
+
643
+ # Episode-level tracking does not know cross-run evaluation or trainer internals.
644
+ # Those fields remain present with zero defaults and are filled by eval/trainer logs.
645
+ fields["eval/hidden_success_rate"] = fields["skill/patch_success"] if (
646
+ record.get("scenario/split") == "hidden_eval"
647
+ ) else 0.0
648
+ fields["train/reward_mean"] = fields["reward/total"]
649
+ return fields
650
+
651
+
652
+ def episode_to_trackio_metrics(episode: Any) -> dict[str, float]:
653
+ """Return numeric Trackio scalar metrics for one episode."""
654
+
655
+ fields = episode_to_tracking_fields(episode)
656
+ return {
657
+ key: _float(value)
658
+ for key, value in fields.items()
659
+ if isinstance(value, (int, float, bool))
660
+ }
661
+
662
+
663
+ def aggregate_episode_metrics(episodes: Sequence[Any]) -> dict[str, float]:
664
+ """Aggregate numeric canonical episode metrics as batch means."""
665
+
666
+ if not episodes:
667
+ return {"run/episode_count": 0.0}
668
+ per_episode = [episode_to_trackio_metrics(episode) for episode in episodes]
669
+ keys = sorted(set().union(*(item.keys() for item in per_episode)))
670
+ metrics = {
671
+ key: _mean([_float(item.get(key)) for item in per_episode])
672
+ for key in keys
673
+ }
674
+ metrics["run/episode_count"] = float(len(episodes))
675
+ metrics["cheat/episode_rate"] = _mean(
676
+ [1.0 if _float(item.get("cheat/score")) > 0.0 else 0.0 for item in per_episode]
677
+ )
678
+ metrics["train/reward_std"] = (
679
+ sum(
680
+ (item.get("reward/total", 0.0) - metrics.get("reward/total", 0.0)) ** 2
681
+ for item in per_episode
682
+ )
683
+ / max(1, len(per_episode))
684
+ ) ** 0.5
685
+ return metrics
686
+
687
+
688
+ def train_metric_aliases(metrics: Mapping[str, Any]) -> dict[str, float]:
689
+ """Map canonical metrics to the repo's existing train/* dashboard names."""
690
+
691
+ return {
692
+ "train/reward_total_mean": _float(metrics.get("reward/total")),
693
+ "train/reward_discovery_mean": _float(metrics.get("reward/bug_classification_correct")) * 3.0,
694
+ "train/reward_security_mean": _float(metrics.get("reward/hidden_authz_pass_rate")) * 5.0,
695
+ "train/reward_regression_mean": _float(metrics.get("reward/normal_flow_pass_rate")) * 3.0,
696
+ "train/reward_public_routes_mean": _float(metrics.get("reward/public_tests_pass_rate")),
697
+ "train/reward_patch_quality_mean": _float(metrics.get("reward/minimality_score")) * 2.0,
698
+ "train/reward_visible_tests_mean": _float(metrics.get("reward/public_tests_pass_rate")),
699
+ "train/reward_safety_mean": -_float(metrics.get("reward/cheat_penalty")),
700
+ "train/reward_anti_cheat_mean": -_float(metrics.get("cheat/score")) / 100.0,
701
+ "train/success_rate": _float(metrics.get("skill/patch_success")),
702
+ "train/exploit_block_rate": _float(metrics.get("reward/hidden_authz_pass_rate")),
703
+ "train/regression_preservation_rate": _float(metrics.get("reward/normal_flow_pass_rate")),
704
+ "train/public_route_preservation_rate": _float(metrics.get("reward/public_tests_pass_rate")),
705
+ "train/invalid_action_rate": 1.0 - _float(metrics.get("skill/valid_action_rate")),
706
+ "train/timeout_rate": _float(metrics.get("cheat/timeout_or_resource_abuse")),
707
+ "train/safety_violation_rate": _float(metrics.get("cheat/env_var_secret_probe")),
708
+ "train/reward_hacking_suspected_rate": 1.0 if (
709
+ _float(metrics.get("reward/public_hidden_gap")) > 0.35
710
+ or _float(metrics.get("cheat/score")) >= 100.0
711
+ ) else 0.0,
712
+        # Proxy: aggregated metrics do not carry the true episode length, so
+        # sum the main tool-usage counters instead.
+        "train/episode_length_mean": _float(metrics.get("skill/tests_run_count"))
+        + _float(metrics.get("skill/files_read_count"))
+        + _float(metrics.get("skill/auth_boundary_cases_tested")),
715
+ }
716
+
717
+
718
+ def eval_metric_aliases(summary: Mapping[str, Any]) -> dict[str, float]:
719
+ """Map eval summary fields to the requested generalization metric names."""
720
+
721
+ train_success = _float(summary.get("trained_success_rate", summary.get("train_success_rate")))
722
+ hidden_success = _float(summary.get("heldout_success_rate", summary.get("hidden_success_rate")))
723
+ return {
724
+ "eval/train_success_rate": train_success,
725
+ "eval/seen_template_success_rate": _float(summary.get("seen_template_success_rate", train_success)),
726
+ "eval/unseen_template_success_rate": _float(summary.get("unseen_template_success_rate", hidden_success)),
727
+ "eval/unseen_framework_success_rate": _float(summary.get("unseen_framework_success_rate", 0.0)),
728
+ "eval/unseen_policy_graph_success_rate": _float(summary.get("unseen_policy_graph_success_rate", hidden_success)),
729
+ "eval/hidden_success_rate": hidden_success,
730
+ "eval/train_hidden_gap": train_success - hidden_success,
731
+ }
732
+
733
+
734
+ def episode_to_trace_row(episode: Any) -> dict[str, Any]:
735
+ """Return one redacted row for the Trackio sample_traces table."""
736
+
737
+ if not isinstance(episode, Mapping):
738
+ episode = episode_record_from_state(episode)
739
+ record = dict(episode)
740
+ actions = _as_action_list(record)
741
+ observations = _as_observation_list(record)
742
+ tool_names = _tool_names(actions)
743
+ verifier = _verification(record)
744
+ patch_diff = _patch_diff(record)
745
+ files_read = _files_from_actions(actions, "read_file")
746
+ files_modified = _files_modified(record, actions)
747
+ reward_breakdown = _final_reward_breakdown(record)
748
+ final_obs = observations[-1] if observations else {}
749
+ row = {
750
+ "episode_id": _redact_text(record.get("episode_id", "")),
751
+ "scenario_id_hash": record.get("scenario_id_hash") or _scenario_hash(record),
752
+ "split": record.get("scenario/split") or record.get("split", ""),
753
+ "difficulty": record.get("scenario/difficulty") or record.get("difficulty", 0),
754
+ "bug_type": record.get("scenario/bug_type") or record.get("bug_type", ""),
755
+ "visible_observation_summary": json.dumps(
756
+ {
757
+ "done": bool(record.get("done", final_obs.get("done", False))),
758
+ "success": bool(record.get("success", False)),
759
+ "last_action_valid": final_obs.get("last_action_valid", True),
760
+ "terminal_reason": record.get("failure_reason") or final_obs.get("done_reason"),
761
+ },
762
+ sort_keys=True,
763
+ ),
764
+ "action_sequence": " -> ".join(tool_names),
765
+ "tool_calls": json.dumps({name: tool_names.count(name) for name in sorted(set(tool_names))}, sort_keys=True),
766
+ "files_read": json.dumps(sorted(set(files_read))),
767
+ "files_modified": json.dumps(sorted(set(files_modified))),
768
+ "exploit_summary": json.dumps(
769
+ {
770
+ "local_probe_count": sum(
771
+ 1 for name in tool_names if name in {"send_local_request", "compare_identities"}
772
+ ),
773
+ "first_valid_exploit_step": episode_to_tracking_fields(record)[
774
+ "skill/first_valid_exploit_step"
775
+ ],
776
+ "finding_submitted": bool(record.get("finding_submitted", False)),
777
+ },
778
+ sort_keys=True,
779
+ ),
780
+ "patch_diff_summary": json.dumps(
781
+ {
782
+ "diff_hash": _stable_hash(patch_diff),
783
+ "changed_lines": len(_diff_lines(patch_diff)),
784
+ "security_relevant_edit_ratio": _security_relevant_edit_ratio(patch_diff),
785
+ },
786
+ sort_keys=True,
787
+ ),
788
+ "public_test_summary": json.dumps(_check_summary(verifier.get("visible")), sort_keys=True),
789
+ "hidden_test_summary_redacted": json.dumps(
790
+ {
791
+ "authz": _check_summary(verifier.get("security") or verifier.get("hidden_tests")),
792
+ "regression": _check_summary(verifier.get("regression")),
793
+ "public_routes": _check_summary(verifier.get("public_routes")),
794
+ },
795
+ sort_keys=True,
796
+ ),
797
+ "reward_breakdown": json.dumps(reward_breakdown, sort_keys=True),
798
+ "cheat_flags": json.dumps(sorted(record.get("anti_cheat_flags") or [])),
799
+ "terminal_reason": record.get("failure_reason") or final_obs.get("done_reason"),
800
+ }
801
+ return {key: _redact_text(row.get(key, "")) for key in TRACE_TABLE_COLUMNS}
802
+
803
+
804
+ def trace_table_rows(episodes: Sequence[Any]) -> list[dict[str, Any]]:
805
+ return [episode_to_trace_row(episode) for episode in episodes]
806
+
807
+
808
+ def log_trace_table(
809
+ episodes: Sequence[Any],
810
+ *,
811
+ table_name: str = "sample_traces",
812
+ step: int | None = None,
813
+ ) -> None:
814
+ if not episodes:
815
+ return
816
+ trackio = _load_trackio()
817
+ rows = trace_table_rows(episodes)
818
+ table = trackio.Table(
819
+ columns=list(TRACE_TABLE_COLUMNS),
820
+ rows=[[row.get(column, "") for column in TRACE_TABLE_COLUMNS] for row in rows],
821
+ allow_mixed_types=True,
822
+ )
823
+ if step is None:
824
+ trackio.log({table_name: table})
825
+ else:
826
+ trackio.log({table_name: table}, step=step)
827
+
828
+
829
+ def log_episode_batch(
830
+ episodes: Sequence[Any],
831
+ *,
832
+ step: int | None = None,
833
+ table_name: str = "sample_traces",
834
+ include_train_aliases: bool = False,
835
+ ) -> dict[str, float]:
836
+ metrics = aggregate_episode_metrics(episodes)
837
+ payload = dict(metrics)
838
+ if include_train_aliases:
839
+ payload.update(train_metric_aliases(metrics))
840
+ log_trackio_metrics(payload, step=step)
841
+ log_trace_table(episodes, table_name=table_name, step=step)
842
+ return payload
843
+
844
+
845
+ def missing_required_trackio_items(
846
+ run_or_metrics: Mapping[str, Any],
847
+ required_items: Sequence[str] = REQUIRED_SMOKE_TRACKIO_ITEMS,
848
+ ) -> list[str]:
849
+ """Return required metrics/table names absent from a Trackio run summary."""
850
+
851
+ available: set[str] = set()
852
+ metrics = run_or_metrics.get("metrics")
853
+ if isinstance(metrics, dict):
854
+ available.update(str(key) for key in metrics)
855
+ elif isinstance(metrics, list):
856
+ available.update(str(item) for item in metrics)
857
+ for key in ("tables", "artifacts", "media", "logged_artifacts"):
858
+ value = run_or_metrics.get(key)
859
+ if isinstance(value, dict):
860
+ available.update(str(item) for item in value)
861
+ elif isinstance(value, list):
862
+ available.update(str(item) for item in value)
863
+ if "values" in run_or_metrics and run_or_metrics.get("metric"):
864
+ available.add(str(run_or_metrics["metric"]))
865
+ return [item for item in required_items if item not in available]
866
+
867
+
868
  def build_run_name(model: str, algo: str, difficulty: int, git_sha: str = "nogit") -> str:
869
  stamp = datetime.utcnow().strftime("%Y%m%d-%H%M%S")
870
  model_slug = model.replace("/", "-")
 
904
  project: str | None = None,
905
  space_id: str | None = None,
906
  group: str | None = None,
907
+ auto_log_gpu: bool | None = None,
908
+ gpu_log_interval: float | None = None,
909
  ):
910
  trackio = _load_trackio()
911
  project = project or os.getenv("TRACKIO_PROJECT", "CyberSecurity_OWASP")
 
924
  kwargs["space_id"] = space_id
925
  if group:
926
  kwargs["group"] = group
927
+ if auto_log_gpu is not None:
928
+ kwargs["auto_log_gpu"] = auto_log_gpu
929
+ if gpu_log_interval is not None:
930
+ kwargs["gpu_log_interval"] = gpu_log_interval
931
  return trackio.init(**kwargs)

     trackio.log(numeric, step=step)


+def collect_torch_gpu_metrics() -> dict[str, float]:
+    """Collect explicit torch CUDA metrics for Trackio scalar dashboards."""
+
+    try:
+        import torch
+    except Exception:
+        return {"system/gpu_available": 0.0, "system/gpu_count": 0.0}
+
+    if not torch.cuda.is_available():
+        return {"system/gpu_available": 0.0, "system/gpu_count": 0.0}
+
+    device = torch.cuda.current_device()
+    props = torch.cuda.get_device_properties(device)
+    allocated = float(torch.cuda.memory_allocated(device)) / (1024 * 1024)
+    reserved = float(torch.cuda.memory_reserved(device)) / (1024 * 1024)
+    max_allocated = float(torch.cuda.max_memory_allocated(device)) / (1024 * 1024)
+    total = float(props.total_memory) / (1024 * 1024)
+    return {
+        "system/gpu_available": 1.0,
+        "system/gpu_count": float(torch.cuda.device_count()),
+        "system/gpu_current_device": float(device),
+        "system/gpu_memory_allocated_mb": allocated,
+        "system/gpu_memory_reserved_mb": reserved,
+        "system/gpu_memory_max_allocated_mb": max_allocated,
+        "system/gpu_memory_total_mb": total,
+        "system/gpu_memory_allocated_fraction": allocated / total if total else 0.0,
+    }
+
+
+def log_gpu_metrics(step: int | None = None) -> dict[str, float]:
+    """Log Trackio's native GPU metrics plus explicit torch GPU aliases."""
+
+    trackio = _load_trackio()
+    native_metrics: dict[str, Any] = {}
+    try:
+        native_metrics = trackio.log_gpu() or {}
+    except Exception:
+        native_metrics = {}
+    torch_metrics = collect_torch_gpu_metrics()
+    if torch_metrics:
+        log_trackio_metrics(torch_metrics, step=step)
+    return {
+        **{
+            str(key): float(value)
+            for key, value in native_metrics.items()
+            if isinstance(value, (int, float, bool))
+        },
+        **torch_metrics,
+    }
+
+
 def finish_trackio_run() -> None:
     trackio = _load_trackio()
     trackio.finish()

     project: str | None = None,
     space_id: str | None = None,
     group: str | None = None,
+    auto_log_gpu: bool | None = None,
+    gpu_log_interval: float | None = None,
 ) -> Iterator[Any]:
     run = init_trackio_run(
         run_name=run_name,

         project=project,
         space_id=space_id,
         group=group,
+        auto_log_gpu=auto_log_gpu,
+        gpu_log_interval=gpu_log_interval,
     )
     try:
         yield run

         for key, value in summary.items()
         if isinstance(value, (int, float, bool))
     }
+    metrics.update(eval_metric_aliases(summary))
     with trackio_run(run_name=run_name, run_type="eval", config=config, group="eval"):
         log_trackio_metrics(metrics, step=0)
training/train_grpo.py CHANGED
@@ -1,8 +1,8 @@
-"""Minimal GRPO training entrypoint scaffold.
+"""Modal-only GRPO config helper for CyberSecurity_OWASP.
 
-This file intentionally does not start training on import. It validates that the
-required TRL/Trackio configuration can be constructed when optional training
-dependencies are installed.
+This module intentionally does not run local training.
+Use `scripts/modal_train_grpo.py` (persistent) or
+`scripts/modal_ephemeral_train.py` (smoke) for execution.
 """
 
 from __future__ import annotations
@@ -12,13 +12,21 @@ import os
 from training.trackio_utils import build_run_name, get_git_sha
 
 
+DEFAULT_GEMMA_MODEL = os.getenv("MODEL_NAME", "unsloth/gemma-4-E2B-it")
+
+
 def build_grpo_config():
+    """Build the TRL GRPOConfig used by the Modal training pipeline."""
+
     from trl import GRPOConfig
 
-    model_name = os.getenv("MODEL_NAME", "Qwen/Qwen3-1.7B")
+    model_name = os.getenv("MODEL_NAME", DEFAULT_GEMMA_MODEL)
     difficulty = int(os.getenv("DIFFICULTY", "0"))
-    output_dir = os.getenv("OUTPUT_DIR", "CyberSecurity_OWASP-qwen3-1.7b-grpo")
-    trackio_space_id = os.getenv("TRACKIO_SPACE_ID", output_dir)
+    output_dir = os.getenv(
+        "OUTPUT_DIR",
+        f"CyberSecurity_OWASP-{model_name.replace('/', '-')}-grpo",
+    )
+    trackio_space_id = os.getenv("TRACKIO_SPACE_ID", "Humanlearning/CyberSecurity_OWASP-trackio")
     os.environ.setdefault("TRACKIO_PROJECT", "CyberSecurity_OWASP-grpo")
     run_name = os.getenv(
         "RUN_NAME",
@@ -47,9 +55,41 @@ def build_grpo_config():
     )
 
 
-def main():
+def main() -> None:
+    import argparse
+
+    parser = argparse.ArgumentParser(
+        description=(
+            "CyberSecurity_OWASP GRPO config helper."
+            " Actual GRPO training is executed on Modal only."
+        )
+    )
+    parser.add_argument(
+        "--difficulty",
+        type=int,
+        default=0,
+        help="Optional curriculum difficulty included in the generated run name.",
+    )
+    parser.add_argument("--model-name", default=DEFAULT_GEMMA_MODEL)
+    parser.add_argument(
+        "--output-dir",
+        default=None,
+        help="Optional GRPO output_dir override.",
+    )
+    args = parser.parse_args()
+
+    os.environ["MODEL_NAME"] = args.model_name
+    if args.output_dir:
+        os.environ["OUTPUT_DIR"] = args.output_dir
+
     config = build_grpo_config()
+    print("GRPO config (Modal execution):")
     print(config)
+    print(
+        "Run on Modal, for example:\n"
+        "uv run --extra modal modal run scripts/modal_train_grpo.py "
+        f"--model-name {args.model_name} --difficulty {args.difficulty}"
+    )
 
 
 if __name__ == "__main__":