Humanlearning commited on
Commit
f7b8ac6
·
1 Parent(s): 5809a6c

feat: introduce reward ablation configurations for enhanced training flexibility, implement YAML loading with extends support, and add reward variant tracking in training scripts

Browse files
reward_config.py CHANGED
@@ -74,7 +74,7 @@ def load_reward_settings(path: str | Path | None = None) -> RewardSettings:
74
  or os.getenv("CYBERSECURITY_OWASP_REWARD_CONFIG", "")
75
  or DEFAULT_GRPO_CONFIG_PATH
76
  )
77
- raw = yaml.safe_load(configured_path.read_text(encoding="utf-8")) or {}
78
  reward = dict(raw.get("reward") or {})
79
  mode = os.getenv("CYBERSECURITY_OWASP_REWARD_MODE", str(reward.get("mode", "sparse_eval")))
80
  training_mode = str(reward.get("training_mode", "dense_train"))
@@ -90,6 +90,44 @@ def load_reward_settings(path: str | Path | None = None) -> RewardSettings:
90
  return settings
91
 
92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
  def flatten_reward_config(
94
  settings: RewardSettings | None = None,
95
  ) -> list[dict[str, Any]]:
@@ -175,6 +213,7 @@ def reward_config_run_config(settings: RewardSettings | None = None) -> dict[str
175
  "reward_config_hash": summary["reward_config_hash"],
176
  "reward_config_source": summary["reward_config_source"],
177
  "reward_config_source_name": summary["reward_config_source_name"],
 
178
  "reward_mode": summary["reward_mode"],
179
  "reward_training_mode": summary["reward_training_mode"],
180
  "reward_stage": summary["reward_stage"],
 
74
  or os.getenv("CYBERSECURITY_OWASP_REWARD_CONFIG", "")
75
  or DEFAULT_GRPO_CONFIG_PATH
76
  )
77
+ raw = _load_yaml_with_extends(configured_path)
78
  reward = dict(raw.get("reward") or {})
79
  mode = os.getenv("CYBERSECURITY_OWASP_REWARD_MODE", str(reward.get("mode", "sparse_eval")))
80
  training_mode = str(reward.get("training_mode", "dense_train"))
 
90
  return settings
91
 
92
 
93
def _load_yaml_with_extends(path: Path, seen: set[Path] | None = None) -> dict[str, Any]:
    """Load a YAML mapping, first merging an optional ``extends`` base file.

    ``extends`` may be absolute or relative to the current file. A cycle in
    the extends chain raises ValueError, as does a file whose top level is
    not a mapping or an ``extends`` value that is not a string.
    """
    target = path.expanduser().resolve()
    visited = seen if seen else set()
    if target in visited:
        chain = " -> ".join(str(item) for item in [*visited, target])
        raise ValueError(f"reward config extends cycle detected: {chain}")
    visited.add(target)

    loaded = yaml.safe_load(target.read_text(encoding="utf-8")) or {}
    if not isinstance(loaded, dict):
        raise ValueError(f"reward config must be a YAML mapping: {target}")

    parent_ref = loaded.get("extends")
    if not parent_ref:
        return loaded
    if not isinstance(parent_ref, str):
        raise ValueError("reward config extends must be a string path")

    parent_path = Path(parent_ref)
    if not parent_path.is_absolute():
        # Relative extends paths resolve against the extending file's directory.
        parent_path = target.parent / parent_path
    overrides = {key: value for key, value in loaded.items() if key != "extends"}
    return _deep_merge(_load_yaml_with_extends(parent_path, visited), overrides)
118
+
119
+
120
+ def _deep_merge(base: dict[str, Any], override: dict[str, Any]) -> dict[str, Any]:
121
+ merged = dict(base)
122
+ for key, value in override.items():
123
+ base_value = merged.get(key)
124
+ if isinstance(base_value, dict) and isinstance(value, dict):
125
+ merged[key] = _deep_merge(base_value, value)
126
+ else:
127
+ merged[key] = value
128
+ return merged
129
+
130
+
131
  def flatten_reward_config(
132
  settings: RewardSettings | None = None,
133
  ) -> list[dict[str, Any]]:
 
213
  "reward_config_hash": summary["reward_config_hash"],
214
  "reward_config_source": summary["reward_config_source"],
215
  "reward_config_source_name": summary["reward_config_source_name"],
216
+ "reward_variant": os.getenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "default") or "default",
217
  "reward_mode": summary["reward_mode"],
218
  "reward_training_mode": summary["reward_training_mode"],
219
  "reward_stage": summary["reward_stage"],
scripts/generate_sft_dataset.py ADDED
@@ -0,0 +1,753 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generate verifier-gated SFT data for CyberSecurity_OWASP.
2
+
3
+ The default path asks a larger Hugging Face-hosted teacher model for one JSON
4
+ action at a time, executes those actions in the real environment, and keeps
5
+ only trajectories that pass the local deterministic verifier. The
6
+ ``--dry-run-oracle`` path is intentionally network-free and exists for CI and
7
+ smoke tests.
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import argparse
13
+ import json
14
+ import os
15
+ import statistics
16
+ import subprocess
17
+ from dataclasses import dataclass
18
+ from pathlib import Path
19
+ from typing import Any, Iterable
20
+
21
+ from CyberSecurity_OWASP.models import CyberSecurityOWASPAction, CyberSecurityOWASPObservation
22
+ from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
23
+ CybersecurityOwaspEnvironment,
24
+ )
25
+ from CyberSecurity_OWASP.validators import detect_cheating
26
+
27
+
28
# Default teacher (data-generating) and target (student to fine-tune) models.
DEFAULT_TEACHER_MODEL = "deepseek-ai/DeepSeek-V4-Pro"
DEFAULT_TARGET_MODEL = "unsloth/gemma-4-E2B-it"
# System prompt sent with every teacher request and baked into each SFT row.
TRAINING_SYSTEM_PROMPT = (
    "You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
    "OpenEnv environment. Use only the listed local tools. Do not target real "
    "systems. Work step by step: inspect policy and generated code, reproduce "
    "the authorization issue locally, submit a policy-tied diagnosis, patch the "
    "generated app, run visible tests, then submit the fix. Return exactly one "
    "JSON action object and no markdown."
)
# Case-insensitive substrings that must never appear in a prompt shown to the
# teacher (they would leak hidden oracle/reward internals); enforced by
# _assert_prompt_is_safe.
BANNED_PROMPT_MARKERS = (
    "hidden_facts",
    "oracle_hidden_focus",
    "reward_engine",
    "validators.py",
    "rewards.py",
    "tests/hidden",
    "hidden tests",
    ".git",
)
# Substrings that make a serialized tool-argument payload suspicious (probing
# hidden state or escaping the workspace); enforced by preflight_action.
RISKY_ARGUMENT_MARKERS = (
    "hidden",
    "oracle",
    "reward_engine",
    "validators.py",
    "rewards.py",
    ".git",
    "..",
)
57
+
58
+
59
@dataclass
class DatasetConfig:
    """All knobs for one SFT dataset-generation run (mirrors the CLI flags)."""

    teacher_model: str = DEFAULT_TEACHER_MODEL  # HF model id queried for actions
    target_model: str = DEFAULT_TARGET_MODEL  # student model the rows are intended for
    split: str = "train"  # primary environment split to sample
    difficulty: int = 0  # scenario difficulty passed to env.reset
    seed_start: int = 0  # first episode seed; episodes use consecutive seeds
    episodes: int = 100  # episodes attempted on the primary split
    validation_episodes: int = 0  # extra episodes sampled from the "validation" split
    out_dir: Path = Path("outputs/sft")  # destination for jsonl/trajectories/manifest
    max_steps: int = 40  # hard cap on environment steps per episode
    max_teacher_retries: int = 2  # extra teacher attempts after a rejected action
    max_tokens: int = 768  # teacher sampling settings
    temperature: float = 0.2
    top_p: float = 0.95
    dry_run_oracle: bool = False  # use the scripted oracle instead of the HF teacher
75
+
76
+
77
class HuggingFaceTeacher:
    """Small wrapper around Hugging Face chat completion."""

    def __init__(
        self,
        *,
        model: str,
        token: str,
        max_tokens: int,
        temperature: float,
        top_p: float,
    ) -> None:
        """Create an InferenceClient-backed teacher with fixed sampling settings.

        Raises RuntimeError when huggingface_hub is not installed; the import
        is deliberately lazy so --dry-run-oracle works without the dependency.
        """
        try:
            from huggingface_hub import InferenceClient
        except ImportError as exc:  # pragma: no cover - dependency smoke checked separately
            raise RuntimeError(
                "huggingface_hub is required for teacher generation. Install project "
                "dependencies or use --dry-run-oracle for local CI."
            ) from exc

        self.model = model
        self.max_tokens = int(max_tokens)
        self.temperature = float(temperature)
        self.top_p = float(top_p)
        self.client = InferenceClient(token=token)

    def complete(self, messages: list[dict[str, str]]) -> str:
        """Run one chat completion over *messages* and return the assistant text."""
        response = self.client.chat_completion(
            model=self.model,
            messages=messages,
            max_tokens=self.max_tokens,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        return _chat_response_content(response)
112
+
113
+
114
+ def _chat_response_content(response: Any) -> str:
115
+ choices = getattr(response, "choices", None)
116
+ if choices:
117
+ message = getattr(choices[0], "message", None)
118
+ content = getattr(message, "content", None)
119
+ if content is not None:
120
+ return str(content)
121
+ if isinstance(response, dict):
122
+ choices = response.get("choices") or []
123
+ if choices:
124
+ message = choices[0].get("message") or {}
125
+ return str(message.get("content", ""))
126
+ return str(response)
127
+
128
+
129
def extract_first_json_object(text: str) -> dict[str, Any] | None:
    """Return the first JSON object found in *text*, or ``None``.

    Pass 1 tries the whole (stripped) string and every ``` fenced section —
    with an optional leading "json" language tag removed — as complete JSON.
    Pass 2 falls back to a brace-matching scan that ignores braces inside
    string literals (including escaped quotes).
    """
    body = text.strip()

    candidates = [body]
    if "```" in body:
        for section in body.split("```"):
            chunk = section.strip()
            if chunk.startswith("json"):
                chunk = chunk[4:].strip()
            candidates.append(chunk)
    for chunk in candidates:
        try:
            parsed = json.loads(chunk)
        except Exception:
            continue
        if isinstance(parsed, dict):
            return parsed

    # Brace scan: restart from the next "{" whenever a balanced span fails
    # to parse, or keep scanning when it parses to a non-dict.
    start = body.find("{")
    while start >= 0:
        depth = 0
        in_string = False
        escaped = False
        for index in range(start, len(body)):
            char = body[index]
            if in_string:
                if escaped:
                    escaped = False
                elif char == "\\":
                    escaped = True
                elif char == '"':
                    in_string = False
                continue
            if char == '"':
                in_string = True
            elif char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    try:
                        parsed = json.loads(body[start : index + 1])
                    except Exception:
                        break
                    if isinstance(parsed, dict):
                        return parsed
        start = body.find("{", start + 1)
    return None
179
+
180
+
181
def parse_action_text(text: str) -> CyberSecurityOWASPAction:
    """Parse raw teacher output into a validated action model.

    Raises ValueError when no JSON object can be extracted; validation
    errors raised by the action model itself propagate unchanged.
    """
    payload = extract_first_json_object(text)
    if payload is None:
        raise ValueError("teacher did not return a JSON object")
    return CyberSecurityOWASPAction(**payload)
186
+
187
+
188
def action_to_json(action: CyberSecurityOWASPAction) -> str:
    """Serialize an action to canonical, compact, key-sorted JSON."""
    payload = action.model_dump()
    return json.dumps(payload, sort_keys=True, separators=(",", ":"))
190
+
191
+
192
+ def _safe_observation_payload(
193
+ observation: CyberSecurityOWASPObservation,
194
+ recent_actions: list[dict[str, Any]],
195
+ ) -> dict[str, Any]:
196
+ return {
197
+ "phase": observation.phase,
198
+ "task_brief": observation.task_brief,
199
+ "scenario_prompt": observation.scenario_prompt,
200
+ "available_actions": observation.available_actions,
201
+ "last_tool_result": observation.last_tool_result,
202
+ "last_action_valid": observation.last_action_valid,
203
+ "last_action_error": observation.last_action_error,
204
+ "visible_test_result": observation.visible_test_result,
205
+ "done_reason": observation.done_reason,
206
+ "recent_actions": recent_actions[-8:],
207
+ }
208
+
209
+
210
def build_user_prompt(
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    retry_error: str | None = None,
) -> str:
    """Render the agent-visible observation into the teacher's user prompt.

    When *retry_error* is given, the previous rejection reason is appended so
    the teacher can correct itself. Raises ValueError (via
    ``_assert_prompt_is_safe``) if the rendered prompt would leak a blocked
    marker.
    """
    visible = _safe_observation_payload(observation, recent_actions)
    parts = [
        "Current CyberSecurity_OWASP observation, containing only information "
        "available to the agent:\n"
        f"{json.dumps(visible, indent=2, sort_keys=True)}\n\n"
        "Choose the next action. Output exactly one JSON object with keys "
        "`tool_name` and `arguments`. Do not include markdown or commentary."
    ]
    if retry_error:
        parts.append(f"\nPrevious candidate was rejected safely: {retry_error}")
    rendered = "".join(parts)
    _assert_prompt_is_safe(rendered)
    return rendered
227
+
228
+
229
def _assert_prompt_is_safe(prompt: str) -> None:
    """Raise ValueError if *prompt* mentions any banned marker (case-insensitive)."""
    haystack = prompt.lower()
    found = [marker for marker in BANNED_PROMPT_MARKERS if marker.lower() in haystack]
    if found:
        raise ValueError(f"prompt contains blocked marker(s): {', '.join(found)}")
234
+
235
+
236
def build_chat_messages(
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    retry_error: str | None = None,
) -> list[dict[str, str]]:
    """Assemble the two-message chat (system + user) sent to the teacher."""
    user_text = build_user_prompt(observation, recent_actions, retry_error)
    system_message = {"role": "system", "content": TRAINING_SYSTEM_PROMPT}
    return [system_message, {"role": "user", "content": user_text}]
245
+
246
+
247
def make_chat_row(
    *,
    messages: list[dict[str, str]],
    action: CyberSecurityOWASPAction,
    metadata: dict[str, Any],
) -> dict[str, Any]:
    """Build one SFT chat row: the prompt messages plus the assistant's JSON action."""
    assistant_turn = {"role": "assistant", "content": action_to_json(action)}
    return {
        "messages": [*messages, assistant_turn],
        "metadata": metadata,
    }
260
+
261
+
262
def preflight_action(
    env: CybersecurityOwaspEnvironment,
    observation: CyberSecurityOWASPObservation,
    action: CyberSecurityOWASPAction,
) -> tuple[bool, str]:
    """Cheaply vet a candidate action before stepping the environment.

    Returns ``(True, "")`` when the action looks executable, otherwise
    ``(False, reason)``. Checks in order: phase availability, the
    environment's anti-cheat detector, blocked substrings in the serialized
    arguments, then per-tool required-argument rules.
    """
    # The observation advertises which tools are legal in the current phase.
    if action.tool_name not in observation.available_actions:
        return False, f"{action.tool_name} is not allowed during {observation.phase}"
    args = action.arguments or {}
    # Environment-side cheating detector over the current state + action.
    flags = detect_cheating(env.state, action)
    if flags:
        return False, f"action triggered safety flags: {', '.join(flags)}"
    # Serialize the full argument payload and reject blocked paths/markers.
    arg_text = json.dumps(args, sort_keys=True, default=str).lower()
    if any(marker in arg_text for marker in RISKY_ARGUMENT_MARKERS):
        return False, "arguments reference blocked files or paths"
    # Per-tool structural requirements.
    if action.tool_name == "read_file" and not args.get("path"):
        return False, "read_file requires path"
    if action.tool_name == "search_code" and not args.get("query"):
        return False, "search_code requires query"
    if action.tool_name == "patch_file":
        path = str(args.get("path", ""))
        if not path:
            return False, "patch_file requires path"
        # Normalize Windows separators so tests/ cannot be patched either way.
        if path.replace("\\", "/").startswith("tests/"):
            return False, "patch_file cannot modify tests"
        if not args.get("content") and not args.get("diff"):
            return False, "patch_file requires content or diff"
    if action.tool_name == "send_local_request":
        path = str(args.get("path", ""))
        if not path.startswith("/"):
            return False, "send_local_request requires a local route path"
    if action.tool_name == "compare_identities":
        path = str(args.get("path", ""))
        if not path.startswith("/"):
            return False, "compare_identities requires a local route path"
        if not args.get("first_user_id") or not args.get("second_user_id"):
            return False, "compare_identities requires two user ids"
    if action.tool_name == "submit_diagnosis":
        required = ("bug_class", "route", "violated_policy_rule", "evidence_trace_ids", "fix_plan")
        missing = [key for key in required if not args.get(key)]
        if missing:
            return False, f"submit_diagnosis missing: {', '.join(missing)}"
    return True, ""
304
+
305
+
306
+ def _trace_id_from_observation(observation: CyberSecurityOWASPObservation) -> str:
307
+ try:
308
+ payload = json.loads(observation.last_tool_result)
309
+ except Exception:
310
+ return "req_001"
311
+ return str(payload.get("trace_id", "req_001"))
312
+
313
+
314
def _secure_invoice_source(env: CybersecurityOwaspEnvironment) -> str:
    """Return the invoice route source with the seeded IDOR bug fixed.

    Reads the generated app's route file from the hidden workspace and swaps
    the buggy authenticated-only return for tenant + owner/billing-admin
    checks. NOTE(review): str.replace only fires if the needle matches the
    generated file byte-for-byte — confirm the literals' leading whitespace
    against the scenario generator's actual output.
    """
    source = (Path(env.state.hidden_facts["workspace"]) / "app/routes/invoices.py").read_text(
        encoding="utf-8"
    )
    return source.replace(
        " # BUG: this only checks that the caller is authenticated. It forgets the\n"
        " # owner/admin and tenant policy checks required by the policy graph.\n"
        " return {\"status\": 200, \"body\": invoice}\n",
        " if invoice[\"tenant_id\"] != actor[\"tenant_id\"]:\n"
        " return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        " if invoice[\"owner_user_id\"] != actor[\"user_id\"] and not is_billing_admin(actor):\n"
        " return {\"status\": 403, \"body\": {\"detail\": \"forbidden\"}}\n"
        " return {\"status\": 200, \"body\": invoice}\n",
    )
328
+
329
+
330
def oracle_actions_for_state(
    env: CybersecurityOwaspEnvironment,
    evidence_trace_id: str | None = None,
) -> list[CyberSecurityOWASPAction]:
    """Scripted golden trajectory for the seeded IDOR invoice scenario.

    Inspect policy and routes, read the buggy source, reproduce the
    cross-owner read, submit a policy-tied diagnosis (citing
    ``evidence_trace_id`` when known), patch the route, run visible tests,
    then submit the fix. Used by the --dry-run-oracle path in run_episode.
    """
    hidden = env.state.hidden_facts
    # Conventional first trace id until a real one is observed and re-scripted.
    trace_id = evidence_trace_id or "req_001"
    return [
        CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={}),
        CyberSecurityOWASPAction(tool_name="list_routes", arguments={}),
        CyberSecurityOWASPAction(
            tool_name="read_file",
            arguments={"path": "app/routes/invoices.py"},
        ),
        # Reproduce the bug: one user reads another user's invoice.
        CyberSecurityOWASPAction(
            tool_name="send_local_request",
            arguments={
                "method": "GET",
                "path": f"/invoices/{hidden['other_invoice_id']}",
                "user_id": hidden["owner_user_id"],
            },
        ),
        CyberSecurityOWASPAction(
            tool_name="submit_diagnosis",
            arguments={
                "bug_class": "idor_ownership_bug",
                "route": "GET /invoices/{invoice_id}",
                "violated_policy_rule": "Only the owner or a billing_admin in the same tenant may read invoices.",
                "evidence_trace_ids": [trace_id],
                "fix_plan": "Add tenant and owner/admin checks before returning invoice data.",
            },
        ),
        CyberSecurityOWASPAction(
            tool_name="patch_file",
            arguments={"path": "app/routes/invoices.py", "content": _secure_invoice_source(env)},
        ),
        CyberSecurityOWASPAction(tool_name="run_visible_tests", arguments={}),
        CyberSecurityOWASPAction(tool_name="submit_fix", arguments={}),
    ]
368
+
369
+
370
def _teacher_action(
    *,
    teacher: HuggingFaceTeacher,
    env: CybersecurityOwaspEnvironment,
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    config: DatasetConfig,
) -> tuple[CyberSecurityOWASPAction, list[dict[str, str]]]:
    """Query the teacher until it yields a parseable, preflight-clean action.

    Makes up to ``config.max_teacher_retries`` extra attempts, feeding the
    previous rejection reason back into the prompt each time. Returns the
    accepted action together with the exact messages that produced it.
    Raises ValueError when every attempt is rejected.
    """
    retry_error: str | None = None
    for _ in range(config.max_teacher_retries + 1):
        messages = build_chat_messages(observation, recent_actions, retry_error)
        raw = teacher.complete(messages)
        try:
            action = parse_action_text(raw)
        except Exception as exc:
            # Unparseable output: retry with the parse error in the prompt.
            retry_error = str(exc)
            continue
        ok, error = preflight_action(env, observation, action)
        if ok:
            return action, messages
        retry_error = error
    raise ValueError(retry_error or "teacher did not produce a usable action")
392
+
393
+
394
def _oracle_action(
    *,
    env: CybersecurityOwaspEnvironment,
    observation: CyberSecurityOWASPObservation,
    recent_actions: list[dict[str, Any]],
    oracle_actions: list[CyberSecurityOWASPAction],
    step_index: int,
) -> tuple[CyberSecurityOWASPAction, list[dict[str, str]]]:
    """Return the scripted oracle action for this step plus its prompt messages.

    The scripted action is held to the same preflight checks as a teacher
    action; a rejection raises ValueError with the preflight reason.
    """
    scripted = oracle_actions[step_index]
    prompt_messages = build_chat_messages(observation, recent_actions)
    accepted, reason = preflight_action(env, observation, scripted)
    if not accepted:
        raise ValueError(reason)
    return scripted, prompt_messages
408
+
409
+
410
+ def _terminal_checks_passed(env: CybersecurityOwaspEnvironment) -> bool:
411
+ verifier = env.state.verification_summary or {}
412
+ required = ("visible", "security", "regression", "public_routes", "patch_quality")
413
+ return all(bool((verifier.get(key) or {}).get("passed", False)) for key in required)
414
+
415
+
416
+ def _episode_reward(env: CybersecurityOwaspEnvironment) -> float:
417
+ if env.state.reward_history:
418
+ return float(env.state.reward_history[-1].get("terminal_total", 0.0))
419
+ return 0.0
420
+
421
+
422
def run_episode(
    *,
    seed: int,
    split: str,
    difficulty: int,
    config: DatasetConfig,
    teacher: HuggingFaceTeacher | None,
) -> dict[str, Any]:
    """Run one environment episode end-to-end and return an accept/reject record.

    An accepted episode (terminal, successful, within the step budget, all
    verifier sections passed, no anti-cheat flags) yields SFT chat rows plus
    a full trajectory dump. Any failure is converted into a rejected record
    carrying the reason — this function never raises; the environment is
    always closed.
    """
    env = CybersecurityOwaspEnvironment()
    rows: list[dict[str, Any]] = []
    trajectory_steps: list[dict[str, Any]] = []
    recent_actions: list[dict[str, Any]] = []
    try:
        observation = env.reset(seed=seed, split=split, difficulty=difficulty)
        oracle_actions = oracle_actions_for_state(env) if config.dry_run_oracle else []
        for step_index in range(config.max_steps):
            if observation.done:
                break
            if config.dry_run_oracle:
                if step_index >= len(oracle_actions):
                    raise ValueError("oracle action script ended before terminal state")
                # After the repro request (script position 4) a real trace id
                # exists; re-script the oracle so the diagnosis cites it.
                if step_index == 4 and env.state.request_trace:
                    trace_id = _trace_id_from_observation(observation)
                    oracle_actions = oracle_actions_for_state(env, evidence_trace_id=trace_id)
                action, messages = _oracle_action(
                    env=env,
                    observation=observation,
                    recent_actions=recent_actions,
                    oracle_actions=oracle_actions,
                    step_index=step_index,
                )
            else:
                if teacher is None:
                    raise RuntimeError("teacher is required unless --dry-run-oracle is set")
                action, messages = _teacher_action(
                    teacher=teacher,
                    env=env,
                    observation=observation,
                    recent_actions=recent_actions,
                    config=config,
                )

            step_number = step_index + 1
            action_record = action.model_dump()
            # Build the SFT row before stepping so the stored prompt matches
            # exactly what produced this action.
            row = make_chat_row(
                messages=messages,
                action=action,
                metadata={
                    "target_model": config.target_model,
                    "teacher_model": config.teacher_model,
                    "seed": seed,
                    "split": split,
                    "difficulty": difficulty,
                    "step": step_number,
                    "tool_name": action.tool_name,
                    "task_id": env.state.task_id,
                    "episode_id": env.state.episode_id,
                    "scenario_hash": env.state.scenario_hash,
                },
            )
            next_observation = env.step(action)
            trajectory_steps.append(
                {
                    "step": step_number,
                    "prompt_messages": messages,
                    "action": action_record,
                    "observation": next_observation.model_dump(),
                    "reward_breakdown": dict(next_observation.reward_breakdown or {}),
                }
            )
            # Any invalid action or anti-cheat flag rejects the whole episode.
            if not next_observation.last_action_valid:
                raise ValueError(next_observation.last_action_error or "invalid action")
            if env.state.anti_cheat_flags:
                raise ValueError(f"anti-cheat flags: {env.state.anti_cheat_flags}")
            rows.append(row)
            recent_actions.append(action_record)
            observation = next_observation
            if observation.done:
                break

        # Acceptance gate: terminal, successful, within budget, verifier-clean.
        if not env.state.done:
            raise ValueError("episode did not reach a terminal state")
        if not env.state.success:
            raise ValueError(env.state.failure_reason or "terminal verifier failed")
        if env.state.step_count > config.max_steps:
            raise ValueError("episode exceeded max steps")
        if env.state.anti_cheat_flags:
            raise ValueError("episode has anti-cheat flags")
        if not _terminal_checks_passed(env):
            raise ValueError("terminal verifier checks did not all pass")

        final_reward = _episode_reward(env)
        final_breakdown = dict(env.state.reward_history[-1]) if env.state.reward_history else {}
        # Stamp episode-level outcomes onto every accepted row.
        for row in rows:
            row["metadata"].update(
                {
                    "final_success": True,
                    "terminal_total": final_reward,
                    "total_reward": float(env.state.accumulated_reward),
                    "anti_cheat_flags": list(env.state.anti_cheat_flags),
                    "final_reward_breakdown": final_breakdown,
                }
            )
        return {
            "accepted": True,
            "seed": seed,
            "split": split,
            "difficulty": difficulty,
            "rows": rows,
            "trajectory": {
                "episode_id": env.state.episode_id,
                "task_id": env.state.task_id,
                "seed": seed,
                "split": split,
                "difficulty": difficulty,
                "domain": env.state.domain,
                "bug_family": env.state.bug_family,
                "scenario_hash": env.state.scenario_hash,
                "actions": [step["action"] for step in trajectory_steps],
                "steps": trajectory_steps,
                "reward_breakdown_by_step": list(env.state.reward_history),
                "final_reward_breakdown": final_breakdown,
                "total_reward": float(env.state.accumulated_reward),
                "terminal_total": final_reward,
                "success": True,
                "failure_reason": None,
                "anti_cheat_flags": list(env.state.anti_cheat_flags),
                "verification_summary": env.state.verification_summary,
            },
        }
    except Exception as exc:
        # Rejected episode: keep the partial trajectory for debugging/audit.
        return {
            "accepted": False,
            "seed": seed,
            "split": split,
            "difficulty": difficulty,
            "reason": str(exc),
            "rows": [],
            "trajectory": {
                "seed": seed,
                "split": split,
                "difficulty": difficulty,
                "steps": trajectory_steps,
                "actions": [step["action"] for step in trajectory_steps],
                "success": bool(env.state.success),
                "failure_reason": env.state.failure_reason or str(exc),
                "anti_cheat_flags": list(env.state.anti_cheat_flags),
            },
        }
    finally:
        env.close()
573
+
574
+
575
def write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> None:
    """Write *rows* as JSON Lines, creating parent directories as needed."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(row, sort_keys=True, default=str) + "\n" for row in rows
        )
580
+
581
+
582
+ def _write_trajectory(out_dir: Path, trajectory: dict[str, Any]) -> Path:
583
+ traj_dir = out_dir / "trajectories"
584
+ traj_dir.mkdir(parents=True, exist_ok=True)
585
+ name = (
586
+ f"{trajectory.get('split', 'train')}_seed{trajectory.get('seed', 0)}_"
587
+ f"{str(trajectory.get('episode_id', 'rejected'))[:12]}.json"
588
+ )
589
+ path = traj_dir / name
590
+ path.write_text(json.dumps(trajectory, indent=2, sort_keys=True, default=str), encoding="utf-8")
591
+ return path
592
+
593
+
594
+ def _git_sha() -> str:
595
+ root = Path(__file__).resolve().parents[1]
596
+ try:
597
+ return subprocess.check_output(
598
+ [
599
+ "git",
600
+ "-c",
601
+ f"safe.directory={root.as_posix()}",
602
+ "rev-parse",
603
+ "HEAD",
604
+ ],
605
+ cwd=root,
606
+ text=True,
607
+ stderr=subprocess.DEVNULL,
608
+ ).strip()
609
+ except Exception:
610
+ return "nogit"
611
+
612
+
613
+ def _reward_summary(values: list[float]) -> dict[str, float]:
614
+ if not values:
615
+ return {"mean": 0.0, "min": 0.0, "max": 0.0, "p50": 0.0}
616
+ sorted_values = sorted(values)
617
+ return {
618
+ "mean": float(statistics.mean(values)),
619
+ "min": float(min(values)),
620
+ "max": float(max(values)),
621
+ "p50": float(sorted_values[len(sorted_values) // 2]),
622
+ }
623
+
624
+
625
def generate_dataset(config: DatasetConfig) -> dict[str, Any]:
    """Generate, verify, and persist the SFT dataset; return the manifest.

    Runs every scheduled episode through ``run_episode``, keeps rows only
    from accepted episodes, writes per-split JSONL files plus per-episode
    trajectory dumps, and records provenance and an accept/reject audit in
    ``manifest.json``. Raises RuntimeError when HF_TOKEN is missing and the
    teacher path is in use.
    """
    config.out_dir.mkdir(parents=True, exist_ok=True)
    teacher = None
    if not config.dry_run_oracle:
        token = os.getenv("HF_TOKEN")
        if not token:
            raise RuntimeError("HF_TOKEN is required unless --dry-run-oracle is set")
        teacher = HuggingFaceTeacher(
            model=config.teacher_model,
            token=token,
            max_tokens=config.max_tokens,
            temperature=config.temperature,
            top_p=config.top_p,
        )

    # (split, episode count, first seed); validation seeds continue after the
    # primary split's range so the two never overlap.
    split_jobs = [(config.split, config.episodes, config.seed_start)]
    if config.validation_episodes:
        split_jobs.append(("validation", config.validation_episodes, config.seed_start + config.episodes))

    rows_by_split: dict[str, list[dict[str, Any]]] = {"train": [], "validation": []}
    attempts: list[dict[str, Any]] = []
    rewards: list[float] = []
    accepted = 0
    attempted = 0
    for split, episodes, seed_start in split_jobs:
        for offset in range(int(episodes)):
            seed = int(seed_start) + offset
            attempted += 1
            result = run_episode(
                seed=seed,
                split=split,
                difficulty=config.difficulty,
                config=config,
                teacher=teacher,
            )
            # Every attempt — accepted or not — gets a trajectory dump and an
            # audit entry.
            attempts.append(
                {
                    "seed": seed,
                    "split": split,
                    "accepted": bool(result["accepted"]),
                    "reason": result.get("reason", ""),
                    "trajectory_path": str(_write_trajectory(config.out_dir, result["trajectory"])),
                }
            )
            if result["accepted"]:
                accepted += 1
                rows = list(result["rows"])
                rows_by_split.setdefault(split, []).extend(rows)
                rewards.append(float(result["trajectory"].get("terminal_total", 0.0)))

    # Always emit train/validation files (possibly empty) plus the primary
    # split's file; when config.split is "train" that file is written twice
    # with identical content.
    for split_name in ("train", "validation", config.split):
        write_jsonl(config.out_dir / f"{split_name}.jsonl", rows_by_split.get(split_name, []))

    manifest = {
        "teacher_model": config.teacher_model,
        "target_model": config.target_model,
        "split": config.split,
        "difficulty": config.difficulty,
        "seed_start": config.seed_start,
        "episodes_attempted": attempted,
        "episodes_accepted": accepted,
        "acceptance_rate": accepted / attempted if attempted else 0.0,
        "rows_by_split": {key: len(value) for key, value in sorted(rows_by_split.items())},
        "reward_summary": _reward_summary(rewards),
        "git_sha": _git_sha(),
        "verifier_version": "verifier_v1",
        "dry_run_oracle": config.dry_run_oracle,
        "attempts": attempts,
    }
    manifest_path = config.out_dir / "manifest.json"
    manifest_path.write_text(
        json.dumps(manifest, indent=2, sort_keys=True, default=str),
        encoding="utf-8",
    )
    return manifest
700
+
701
+
702
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; option defaults mirror ``DatasetConfig``."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--teacher-model", default=DEFAULT_TEACHER_MODEL)
    parser.add_argument("--target-model", default=DEFAULT_TARGET_MODEL)
    parser.add_argument("--split", choices=["train", "validation", "hidden_eval"], default="train")
    # Typed numeric/path options, declared in display order.
    for flag, kind, default in (
        ("--difficulty", int, 0),
        ("--seed-start", int, 0),
        ("--episodes", int, 100),
        ("--validation-episodes", int, 0),
        ("--out-dir", Path, Path("outputs/sft")),
        ("--max-steps", int, 40),
        ("--max-teacher-retries", int, 2),
        ("--max-tokens", int, 768),
        ("--temperature", float, 0.2),
        ("--top-p", float, 0.95),
    ):
        parser.add_argument(flag, type=kind, default=default)
    parser.add_argument(
        "--dry-run-oracle",
        action="store_true",
        help="Generate deterministic oracle data without calling the HF API.",
    )
    return parser
723
+
724
+
725
def config_from_args(args: argparse.Namespace) -> DatasetConfig:
    """Copy parsed CLI options into a DatasetConfig.

    argparse dest names (hyphens converted to underscores) match the
    dataclass field names one-for-one, so the mapping is mechanical.
    """
    field_names = (
        "teacher_model",
        "target_model",
        "split",
        "difficulty",
        "seed_start",
        "episodes",
        "validation_episodes",
        "out_dir",
        "max_steps",
        "max_teacher_retries",
        "max_tokens",
        "temperature",
        "top_p",
        "dry_run_oracle",
    )
    return DatasetConfig(**{name: getattr(args, name) for name in field_names})
742
+
743
+
744
def main(argv: list[str] | None = None) -> int:
    """CLI entry point: parse args, generate the dataset, print the manifest."""
    args = build_arg_parser().parse_args(argv)
    manifest = generate_dataset(config_from_args(args))
    print(json.dumps(manifest, indent=2, sort_keys=True))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
scripts/launch_reward_ablations.ps1 ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Launch the five reward-ablation GRPO training runs on Modal, each with its
# own reward-config YAML, a trackable variant tag, and a disjoint seed range.
param(
    # Skip the "no active Modal apps" guard when overlapping L4 jobs are intentional.
    [switch]$AllowActive
)

$ErrorActionPreference = "Stop"
# Force UTF-8 so uv/modal output renders correctly on Windows consoles.
$env:PYTHONIOENCODING = "utf-8"
$env:PYTHONUTF8 = "1"

# Guard: refuse to launch when ephemeral CyberSecurity_OWASP Modal apps are
# already running, unless the caller explicitly allowed it.
$appList = uv run --extra modal modal app list | Out-String
Write-Host $appList
if (-not $AllowActive -and $appList -match "CyberSecur" -and $appList -match "ephemeral") {
    throw "Active CyberSecurity_OWASP Modal apps are present. Re-run with -AllowActive only if overlapping L4 jobs are intentional."
}

# One entry per ablation: variant tag (surfaces in run names / tracking),
# reward-config path, and a seed-start spaced 10k apart so episode seeds
# never collide between variants.
$runs = @(
    @{
        Variant = "abl-a0-sparse"
        Config = "training/configs/reward_ablations/A0_sparse_terminal_only.yaml"
        Seed = 110000
    },
    @{
        Variant = "abl-a2-shape035"
        Config = "training/configs/reward_ablations/A2_reduced_shaping.yaml"
        Seed = 120000
    },
    @{
        Variant = "abl-a6-visgate"
        Config = "training/configs/reward_ablations/A6_visible_gate.yaml"
        Seed = 130000
    },
    @{
        Variant = "abl-a7-evid045"
        Config = "training/configs/reward_ablations/A7_evidence045.yaml"
        Seed = 140000
    },
    @{
        Variant = "abl-a3-nospeed"
        Config = "training/configs/reward_ablations/A3_no_speed_token.yaml"
        Seed = 150000
    }
)

# Fire each run off detached. `modal run --detach` detaches the Modal app;
# the trailing `--detach` is forwarded to the training script's own main()
# (which also spawns the training function) -- both are passed here.
foreach ($run in $runs) {
    Write-Host "Launching $($run.Variant) with $($run.Config) seed $($run.Seed)"
    uv run --extra modal modal run --detach scripts/modal_train_grpo.py `
        --mode train `
        --max-steps 60 `
        --dataset-size 32 `
        --num-generations 4 `
        --max-completion-length 768 `
        --difficulty 0 `
        --split train `
        --source-mode local `
        --trace-log-every 5 `
        --seed-start $run.Seed `
        --reward-config $run.Config `
        --reward-variant $run.Variant `
        --detach
}
scripts/modal_train_grpo.py CHANGED
@@ -210,6 +210,24 @@ def _configure_scenario_cache_env(*, required: bool = True) -> dict[str, str]:
210
  return values
211
 
212
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  def _print_image_startup_notice() -> None:
214
  global _IMAGE_NOTICE_PRINTED
215
  if _IMAGE_NOTICE_PRINTED:
@@ -583,6 +601,8 @@ def run_cybersecurity_owasp_baseline(
583
  source_mode: str = "local",
584
  repo_url: str = PUBLIC_REPO_URL,
585
  repo_branch: str = PUBLIC_REPO_BRANCH,
 
 
586
  ) -> dict[str, str | int | float]:
587
  import statistics
588
  import time
@@ -627,8 +647,14 @@ def run_cybersecurity_owasp_baseline(
627
 
628
  os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
629
  os.environ["TRACKIO_PROJECT"] = trackio_project
 
 
 
 
630
  reward_settings = load_reward_settings()
631
  reward_tracking_config = reward_config_trackio_config(reward_settings)
 
 
632
  run_name = run_name or "baseline"
633
  output_dir = RUNS_DIR / run_name
634
  output_dir.mkdir(parents=True, exist_ok=True)
@@ -673,6 +699,10 @@ def run_cybersecurity_owasp_baseline(
673
  print(f"Trackio Project: {trackio_project}")
674
  print(f"Reward config: {reward_tracking_config['reward_config_id']}")
675
  print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
 
 
 
 
676
  print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
677
  print(f"Scenario cache coverage: {coverage}")
678
  print(
@@ -818,6 +848,7 @@ def run_cybersecurity_owasp_baseline(
818
  "num_generations": num_generations,
819
  "max_completion_length": max_completion_length,
820
  "git_sha": git_sha,
 
821
  **reward_tracking_config,
822
  }
823
 
@@ -998,6 +1029,8 @@ def run_cybersecurity_owasp_baseline(
998
  def train_cybersecurity_owasp_grpo(
999
  env_repo_id: str = "",
1000
  output_repo_id: str = "",
 
 
1001
  max_steps: int = 10,
1002
  dataset_size: int = 16,
1003
  difficulty: int = 0,
@@ -1021,6 +1054,8 @@ def train_cybersecurity_owasp_grpo(
1021
  repo_url: str = PUBLIC_REPO_URL,
1022
  repo_branch: str = PUBLIC_REPO_BRANCH,
1023
  push_to_hub: bool = False,
 
 
1024
  ) -> dict[str, str | int | float]:
1025
  import inspect
1026
  import statistics
@@ -1050,6 +1085,7 @@ def train_cybersecurity_owasp_grpo(
1050
  import transformers.utils.hub as transformers_hub
1051
  from datasets import Dataset
1052
  from huggingface_hub import snapshot_download, whoami
 
1053
  from transformers import TrainerCallback
1054
  from trl import GRPOConfig, GRPOTrainer, clone_chat_template
1055
  from trl.chat_template_utils import add_response_schema
@@ -1110,14 +1146,22 @@ def train_cybersecurity_owasp_grpo(
1110
 
1111
  os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
1112
  os.environ["TRACKIO_PROJECT"] = trackio_project
1113
- os.environ.setdefault("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
 
 
 
 
1114
  reward_settings = load_reward_settings()
1115
  reward_tracking_config = reward_config_trackio_config(reward_settings)
 
 
1116
 
1117
  model_slug = model_name.replace("/", "-")
1118
  stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
1119
  run_name = run_name or (
1120
- f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-{stamp}-{git_sha[:8]}"
 
 
1121
  )
1122
  output_dir = RUNS_DIR / run_name
1123
  output_dir.mkdir(parents=True, exist_ok=True)
@@ -1253,6 +1297,7 @@ def train_cybersecurity_owasp_grpo(
1253
  "reward_config_hash": reward_tracking_config["reward_config_hash"],
1254
  "reward_stage": reward_tracking_config["reward_stage"],
1255
  "reward_mode": reward_tracking_config["reward_mode"],
 
1256
  }
1257
  )
1258
  return obs.scenario_prompt
@@ -1613,6 +1658,7 @@ def train_cybersecurity_owasp_grpo(
1613
  "reward_config_hash": reward_tracking_config["reward_config_hash"],
1614
  "reward_stage": reward_tracking_config["reward_stage"],
1615
  "reward_mode": reward_tracking_config["reward_mode"],
 
1616
  }
1617
  )
1618
  try:
@@ -1704,6 +1750,9 @@ def train_cybersecurity_owasp_grpo(
1704
  print(f"Run name: {run_name}")
1705
  print(f"Reward config: {reward_tracking_config['reward_config_id']}")
1706
  print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
 
 
 
1707
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
1708
  print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
1709
  print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
@@ -1715,6 +1764,10 @@ def train_cybersecurity_owasp_grpo(
1715
  print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
1716
  print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
1717
  print(f"Hub push enabled: {push_to_hub}")
 
 
 
 
1718
  print(
1719
  "GRPO throughput config: "
1720
  f"per_device_train_batch_size={per_device_train_batch_size}, "
@@ -1801,25 +1854,40 @@ def train_cybersecurity_owasp_grpo(
1801
  f"{exc!r}"
1802
  )
1803
 
1804
- model = model_api.get_peft_model(
1805
- model,
1806
- r=lora_rank,
1807
- target_modules=[
1808
- "q_proj",
1809
- "k_proj",
1810
- "v_proj",
1811
- "o_proj",
1812
- "gate_proj",
1813
- "up_proj",
1814
- "down_proj",
1815
- ],
1816
- lora_alpha=lora_rank * 2,
1817
- use_gradient_checkpointing="unsloth",
1818
- random_state=3407,
1819
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1820
  if hasattr(model_api, "for_training"):
1821
  model_api.for_training(model)
1822
- print("LoRA adapter attached and model switched to training mode.")
1823
 
1824
  grpo_config_values = {
1825
  "temperature": 1.0,
@@ -1942,6 +2010,8 @@ def train_cybersecurity_owasp_grpo(
1942
  "difficulty": difficulty,
1943
  "split": split,
1944
  "model_name": model_name,
 
 
1945
  "max_completion_length": max_completion_length,
1946
  "num_generations": num_generations,
1947
  "per_device_train_batch_size": per_device_train_batch_size,
@@ -1956,6 +2026,7 @@ def train_cybersecurity_owasp_grpo(
1956
  "push_to_hub": push_to_hub,
1957
  "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME,
1958
  "scenario_cache_mode": "require",
 
1959
  **reward_tracking_config,
1960
  }
1961
 
@@ -1965,6 +2036,8 @@ def main(
1965
  mode: str = "train",
1966
  env_repo_id: str = "",
1967
  output_repo_id: str = "",
 
 
1968
  max_steps: int = 10,
1969
  dataset_size: int = 16,
1970
  difficulty: int = 0,
@@ -1989,6 +2062,8 @@ def main(
1989
  repo_branch: str = PUBLIC_REPO_BRANCH,
1990
  detach: bool = False,
1991
  push_to_hub: bool = False,
 
 
1992
  cache_seed_start: int = 0,
1993
  cache_difficulty_buckets: int = 0,
1994
  cache_train_per_bucket: int = 0,
@@ -2042,6 +2117,8 @@ def main(
2042
  source_mode=source_mode,
2043
  repo_url=repo_url,
2044
  repo_branch=repo_branch,
 
 
2045
  )
2046
  if detach:
2047
  call = run_cybersecurity_owasp_baseline.spawn(**kwargs)
@@ -2100,7 +2177,13 @@ def main(
2100
  if git_sha == "nogit":
2101
  try:
2102
  git_sha = subprocess.check_output(
2103
- ["git", "rev-parse", "HEAD"],
 
 
 
 
 
 
2104
  cwd=PROJECT_ROOT,
2105
  text=True,
2106
  stderr=subprocess.DEVNULL,
@@ -2110,12 +2193,15 @@ def main(
2110
 
2111
  model_slug = model_name.replace("/", "-")
2112
  local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
 
2113
  run_name = run_name or (
2114
  f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
2115
- f"{local_stamp}-{git_sha[:8]}"
2116
  )
2117
 
2118
  print(f"Run name: {run_name}")
 
 
2119
  print(f"Source mode: {source_mode}")
2120
  if source_mode == "public":
2121
  print(f"Public repo: {repo_url}@{repo_branch}")
@@ -2131,6 +2217,10 @@ def main(
2131
  f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
2132
  )
2133
  print(f"Hub push enabled: {push_to_hub}")
 
 
 
 
2134
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
2135
  print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
2136
  print(
@@ -2164,6 +2254,8 @@ def main(
2164
  kwargs = dict(
2165
  env_repo_id=env_repo_id,
2166
  output_repo_id=output_repo_id,
 
 
2167
  max_steps=max_steps,
2168
  dataset_size=dataset_size,
2169
  difficulty=difficulty,
@@ -2187,6 +2279,8 @@ def main(
2187
  repo_url=repo_url,
2188
  repo_branch=repo_branch,
2189
  push_to_hub=push_to_hub,
 
 
2190
  )
2191
  preflight = verify_modal_scenario_cache_for_training.remote(
2192
  split=split,
 
210
  return values
211
 
212
 
213
+ def _configure_reward_env(
214
+ *,
215
+ reward_config: str = "",
216
+ reward_variant: str = "",
217
+ reward_mode: str = "",
218
+ ) -> dict[str, str]:
219
+ values: dict[str, str] = {}
220
+ if reward_config:
221
+ values["CYBERSECURITY_OWASP_REWARD_CONFIG"] = reward_config
222
+ if reward_variant:
223
+ values["CYBERSECURITY_OWASP_REWARD_VARIANT"] = reward_variant
224
+ if reward_mode:
225
+ values["CYBERSECURITY_OWASP_REWARD_MODE"] = reward_mode
226
+ for key, value in values.items():
227
+ os.environ[key] = value
228
+ return values
229
+
230
+
231
  def _print_image_startup_notice() -> None:
232
  global _IMAGE_NOTICE_PRINTED
233
  if _IMAGE_NOTICE_PRINTED:
 
601
  source_mode: str = "local",
602
  repo_url: str = PUBLIC_REPO_URL,
603
  repo_branch: str = PUBLIC_REPO_BRANCH,
604
+ reward_config: str = "",
605
+ reward_variant: str = "",
606
  ) -> dict[str, str | int | float]:
607
  import statistics
608
  import time
 
647
 
648
  os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
649
  os.environ["TRACKIO_PROJECT"] = trackio_project
650
+ reward_env = _configure_reward_env(
651
+ reward_config=reward_config,
652
+ reward_variant=reward_variant,
653
+ )
654
  reward_settings = load_reward_settings()
655
  reward_tracking_config = reward_config_trackio_config(reward_settings)
656
+ reward_tracking_config["reward_variant"] = reward_variant or "default"
657
+ reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path
658
  run_name = run_name or "baseline"
659
  output_dir = RUNS_DIR / run_name
660
  output_dir.mkdir(parents=True, exist_ok=True)
 
699
  print(f"Trackio Project: {trackio_project}")
700
  print(f"Reward config: {reward_tracking_config['reward_config_id']}")
701
  print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
702
+ print(f"Reward variant: {reward_tracking_config['reward_variant']}")
703
+ print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
704
+ if reward_env:
705
+ print(f"Reward env overrides: {reward_env}")
706
  print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
707
  print(f"Scenario cache coverage: {coverage}")
708
  print(
 
848
  "num_generations": num_generations,
849
  "max_completion_length": max_completion_length,
850
  "git_sha": git_sha,
851
+ "reward_variant": reward_tracking_config["reward_variant"],
852
  **reward_tracking_config,
853
  }
854
 
 
1029
  def train_cybersecurity_owasp_grpo(
1030
  env_repo_id: str = "",
1031
  output_repo_id: str = "",
1032
+ initial_adapter_path: str = "",
1033
+ initial_adapter_repo_id: str = "",
1034
  max_steps: int = 10,
1035
  dataset_size: int = 16,
1036
  difficulty: int = 0,
 
1054
  repo_url: str = PUBLIC_REPO_URL,
1055
  repo_branch: str = PUBLIC_REPO_BRANCH,
1056
  push_to_hub: bool = False,
1057
+ reward_config: str = "",
1058
+ reward_variant: str = "",
1059
  ) -> dict[str, str | int | float]:
1060
  import inspect
1061
  import statistics
 
1085
  import transformers.utils.hub as transformers_hub
1086
  from datasets import Dataset
1087
  from huggingface_hub import snapshot_download, whoami
1088
+ from peft import PeftModel
1089
  from transformers import TrainerCallback
1090
  from trl import GRPOConfig, GRPOTrainer, clone_chat_template
1091
  from trl.chat_template_utils import add_response_schema
 
1146
 
1147
  os.environ["TRACKIO_SPACE_ID"] = trackio_space_id
1148
  os.environ["TRACKIO_PROJECT"] = trackio_project
1149
+ reward_env = _configure_reward_env(
1150
+ reward_config=reward_config,
1151
+ reward_variant=reward_variant,
1152
+ reward_mode="dense_train",
1153
+ )
1154
  reward_settings = load_reward_settings()
1155
  reward_tracking_config = reward_config_trackio_config(reward_settings)
1156
+ reward_tracking_config["reward_variant"] = reward_variant or "default"
1157
+ reward_tracking_config["reward_config_path"] = reward_config or reward_settings.source_path
1158
 
1159
  model_slug = model_name.replace("/", "-")
1160
  stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
1161
  run_name = run_name or (
1162
+ f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
1163
+ f"{reward_tracking_config['reward_variant']}-steps{max_steps}-seed{seed_start}-"
1164
+ f"{stamp}-{git_sha[:8]}"
1165
  )
1166
  output_dir = RUNS_DIR / run_name
1167
  output_dir.mkdir(parents=True, exist_ok=True)
 
1297
  "reward_config_hash": reward_tracking_config["reward_config_hash"],
1298
  "reward_stage": reward_tracking_config["reward_stage"],
1299
  "reward_mode": reward_tracking_config["reward_mode"],
1300
+ "reward_variant": reward_tracking_config["reward_variant"],
1301
  }
1302
  )
1303
  return obs.scenario_prompt
 
1658
  "reward_config_hash": reward_tracking_config["reward_config_hash"],
1659
  "reward_stage": reward_tracking_config["reward_stage"],
1660
  "reward_mode": reward_tracking_config["reward_mode"],
1661
+ "reward_variant": reward_tracking_config["reward_variant"],
1662
  }
1663
  )
1664
  try:
 
1750
  print(f"Run name: {run_name}")
1751
  print(f"Reward config: {reward_tracking_config['reward_config_id']}")
1752
  print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
1753
+ print(f"Reward variant: {reward_tracking_config['reward_variant']}")
1754
+ print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
1755
+ print(f"Reward env overrides: {reward_env}")
1756
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
1757
  print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
1758
  print(f"Scenario cache dir: {scenario_cache_env['CYBERSECURITY_OWASP_SCENARIO_CACHE_DIR']}")
 
1764
  print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
1765
  print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
1766
  print(f"Hub push enabled: {push_to_hub}")
1767
+ if initial_adapter_path:
1768
+ print(f"Initial SFT adapter path: {initial_adapter_path}")
1769
+ if initial_adapter_repo_id:
1770
+ print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}")
1771
  print(
1772
  "GRPO throughput config: "
1773
  f"per_device_train_batch_size={per_device_train_batch_size}, "
 
1854
  f"{exc!r}"
1855
  )
1856
 
1857
+ adapter_source = initial_adapter_path
1858
+ if initial_adapter_repo_id:
1859
+ print(f"Downloading initial SFT adapter: {initial_adapter_repo_id}")
1860
+ adapter_source = snapshot_download(
1861
+ repo_id=initial_adapter_repo_id,
1862
+ cache_dir=str(HF_HUB_CACHE_DIR),
1863
+ token=hf_token,
1864
+ )
1865
+ cache_volume.commit()
1866
+ if adapter_source:
1867
+ print(f"Loading initial SFT adapter for trainable GRPO continuation: {adapter_source}")
1868
+ model = PeftModel.from_pretrained(model, adapter_source, is_trainable=True)
1869
+ if hasattr(model, "print_trainable_parameters"):
1870
+ model.print_trainable_parameters()
1871
+ else:
1872
+ model = model_api.get_peft_model(
1873
+ model,
1874
+ r=lora_rank,
1875
+ target_modules=[
1876
+ "q_proj",
1877
+ "k_proj",
1878
+ "v_proj",
1879
+ "o_proj",
1880
+ "gate_proj",
1881
+ "up_proj",
1882
+ "down_proj",
1883
+ ],
1884
+ lora_alpha=lora_rank * 2,
1885
+ use_gradient_checkpointing="unsloth",
1886
+ random_state=3407,
1887
+ )
1888
  if hasattr(model_api, "for_training"):
1889
  model_api.for_training(model)
1890
+ print("LoRA adapter ready and model switched to training mode.")
1891
 
1892
  grpo_config_values = {
1893
  "temperature": 1.0,
 
2010
  "difficulty": difficulty,
2011
  "split": split,
2012
  "model_name": model_name,
2013
+ "initial_adapter_path": initial_adapter_path,
2014
+ "initial_adapter_repo_id": initial_adapter_repo_id,
2015
  "max_completion_length": max_completion_length,
2016
  "num_generations": num_generations,
2017
  "per_device_train_batch_size": per_device_train_batch_size,
 
2026
  "push_to_hub": push_to_hub,
2027
  "scenario_cache_volume": SCENARIO_CACHE_VOLUME_NAME,
2028
  "scenario_cache_mode": "require",
2029
+ "reward_variant": reward_tracking_config["reward_variant"],
2030
  **reward_tracking_config,
2031
  }
2032
 
 
2036
  mode: str = "train",
2037
  env_repo_id: str = "",
2038
  output_repo_id: str = "",
2039
+ initial_adapter_path: str = "",
2040
+ initial_adapter_repo_id: str = "",
2041
  max_steps: int = 10,
2042
  dataset_size: int = 16,
2043
  difficulty: int = 0,
 
2062
  repo_branch: str = PUBLIC_REPO_BRANCH,
2063
  detach: bool = False,
2064
  push_to_hub: bool = False,
2065
+ reward_config: str = "",
2066
+ reward_variant: str = "",
2067
  cache_seed_start: int = 0,
2068
  cache_difficulty_buckets: int = 0,
2069
  cache_train_per_bucket: int = 0,
 
2117
  source_mode=source_mode,
2118
  repo_url=repo_url,
2119
  repo_branch=repo_branch,
2120
+ reward_config=reward_config,
2121
+ reward_variant=reward_variant,
2122
  )
2123
  if detach:
2124
  call = run_cybersecurity_owasp_baseline.spawn(**kwargs)
 
2177
  if git_sha == "nogit":
2178
  try:
2179
  git_sha = subprocess.check_output(
2180
+ [
2181
+ "git",
2182
+ "-c",
2183
+ f"safe.directory={PROJECT_ROOT.as_posix()}",
2184
+ "rev-parse",
2185
+ "HEAD",
2186
+ ],
2187
  cwd=PROJECT_ROOT,
2188
  text=True,
2189
  stderr=subprocess.DEVNULL,
 
2193
 
2194
  model_slug = model_name.replace("/", "-")
2195
  local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
2196
+ variant_tag = reward_variant or "default"
2197
  run_name = run_name or (
2198
  f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
2199
+ f"{variant_tag}-steps{max_steps}-seed{seed_start}-{local_stamp}-{git_sha[:8]}"
2200
  )
2201
 
2202
  print(f"Run name: {run_name}")
2203
+ print(f"Reward variant: {variant_tag}")
2204
+ print(f"Reward config path: {reward_config or '(default training/configs/grpo_small.yaml)'}")
2205
  print(f"Source mode: {source_mode}")
2206
  if source_mode == "public":
2207
  print(f"Public repo: {repo_url}@{repo_branch}")
 
2217
  f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
2218
  )
2219
  print(f"Hub push enabled: {push_to_hub}")
2220
+ if initial_adapter_path:
2221
+ print(f"Initial SFT adapter path: {initial_adapter_path}")
2222
+ if initial_adapter_repo_id:
2223
+ print(f"Initial SFT adapter repo: https://huggingface.co/{initial_adapter_repo_id}")
2224
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
2225
  print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
2226
  print(
 
2254
  kwargs = dict(
2255
  env_repo_id=env_repo_id,
2256
  output_repo_id=output_repo_id,
2257
+ initial_adapter_path=initial_adapter_path,
2258
+ initial_adapter_repo_id=initial_adapter_repo_id,
2259
  max_steps=max_steps,
2260
  dataset_size=dataset_size,
2261
  difficulty=difficulty,
 
2279
  repo_url=repo_url,
2280
  repo_branch=repo_branch,
2281
  push_to_hub=push_to_hub,
2282
+ reward_config=reward_config,
2283
+ reward_variant=reward_variant,
2284
  )
2285
  preflight = verify_modal_scenario_cache_for_training.remote(
2286
  split=split,
scripts/modal_train_sft.py ADDED
@@ -0,0 +1,442 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Modal SFT launcher for CyberSecurity_OWASP action JSON data.
2
+
3
+ This trains a LoRA adapter on chat JSONL generated by
4
+ ``scripts/generate_sft_dataset.py``. It intentionally mirrors the repo's Modal
5
+ training pattern: local execution only launches remote jobs, while training runs
6
+ inside Modal and saves adapters to the persistent run volume.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import json
12
+ import os
13
+ import pathlib
14
+ import subprocess
15
+ from datetime import datetime, timezone
16
+ from typing import Any
17
+
18
+ import modal
19
+
20
+
21
# Modal app / resource names. The volumes are shared with the GRPO launcher so
# SFT adapters land on the same persistent run and model-cache volumes.
APP_NAME = "CyberSecurity_OWASP-sft"
VOLUME_NAME = "CyberSecurity_OWASP-grpo-runs"
CACHE_VOLUME_NAME = "CyberSecurity_OWASP-model-cache"
SECRET_NAME = "CyberSecurity_OWASP-secrets"
# Mount points inside the Modal container.
RUNS_DIR = pathlib.Path("/runs")
CACHE_DIR = pathlib.Path("/cache")
# Cache sub-directories routed onto the persistent cache volume.
HF_HOME_DIR = CACHE_DIR / "huggingface"
HF_HUB_CACHE_DIR = HF_HOME_DIR / "hub"
TORCH_HOME_DIR = CACHE_DIR / "torch"
XDG_CACHE_DIR = CACHE_DIR / "xdg"
UNSLOTH_CACHE_DIR = CACHE_DIR / "unsloth"
TRITON_CACHE_DIR = CACHE_DIR / "triton"
# Where the project source is installed inside the image.
REMOTE_PROJECT = "/root/CyberSecurity_OWASP"
PROJECT_ROOT = pathlib.Path(__file__).resolve().parents[1]
# NOTE(review): "gemma-4-E2B-it" is an unusual repo id -- confirm against the
# intended unsloth Gemma repo on the Hub (e.g. gemma-3n-E2B-it).
DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
PUBLIC_REPO_BRANCH = "master"
38
+
39
+
40
def _ensure_gemma4_model(model_name: str) -> str:
    """Validate that *model_name* is the pinned SFT base model and return it.

    Raises:
        ValueError: if any other model id is supplied; this SFT pipeline is
            pinned to ``DEFAULT_GEMMA_MODEL``.
    """
    if model_name == DEFAULT_GEMMA_MODEL:
        return model_name
    raise ValueError(
        f"CyberSecurity_OWASP SFT is pinned to {DEFAULT_GEMMA_MODEL}; received {model_name!r}."
    )
47
+
48
+
49
+ def _model_repo_slug(model_name: str) -> str:
50
+ return model_name.replace("/", "-").replace("_", "-").replace(".", "-").lower()
51
+
52
+
53
def _configure_modal_cache_env() -> dict[str, str]:
    """Point the HF/torch/unsloth/triton caches at the persistent /cache volume.

    Exports the relevant ``*_HOME`` / ``*_CACHE`` environment variables,
    creates every target directory, and returns the exported mapping so
    callers can log the effective cache layout.
    """
    compile_cache = UNSLOTH_CACHE_DIR / "compile"
    exported = {
        "HF_HOME": str(HF_HOME_DIR),
        "HF_HUB_CACHE": str(HF_HUB_CACHE_DIR),
        "TRANSFORMERS_CACHE": str(HF_HUB_CACHE_DIR),
        "TORCH_HOME": str(TORCH_HOME_DIR),
        "XDG_CACHE_HOME": str(XDG_CACHE_DIR),
        "UNSLOTH_CACHE_DIR": str(UNSLOTH_CACHE_DIR),
        "UNSLOTH_COMPILE_CACHE": str(compile_cache),
        "TRITON_CACHE_DIR": str(TRITON_CACHE_DIR),
    }
    os.environ.update(exported)
    directories = (
        CACHE_DIR,
        HF_HOME_DIR,
        HF_HUB_CACHE_DIR,
        TORCH_HOME_DIR,
        XDG_CACHE_DIR,
        UNSLOTH_CACHE_DIR,
        compile_cache,
        TRITON_CACHE_DIR,
    )
    for directory in directories:
        directory.mkdir(parents=True, exist_ok=True)
    return exported
78
+
79
+
80
+ def _cli_arg_value(name: str, default: str = "") -> str:
81
+ import sys
82
+
83
+ args = sys.argv[1:]
84
+ flag = f"--{name}"
85
+ for index, arg in enumerate(args):
86
+ if arg == flag and index + 1 < len(args):
87
+ return args[index + 1]
88
+ if arg.startswith(f"{flag}="):
89
+ return arg.split("=", 1)[1]
90
+ return default
91
+
92
+
93
def _source_mode() -> str:
    """Resolve the code-source mode: CLI --source-mode, else $MODAL_SOURCE_MODE, else "local"."""
    fallback = os.environ.get("MODAL_SOURCE_MODE", "local")
    return _cli_arg_value("source-mode", fallback)
95
+
96
+
97
def _training_image() -> modal.Image:
    """Build the CUDA training image with torch/trl/unsloth and the project source.

    The base is the CUDA 12.8 devel image with Python 3.11. Pinned wheels are
    installed first; unsloth comes from git, and timm/pydantic are installed
    last with constrained options so they do not disturb the pins. The project
    itself is baked in either from the public GitHub repo (source-mode
    "public") or copied from the local working tree (default), then installed
    editable with --no-deps so the wheels above stay authoritative.
    """
    image = (
        modal.Image.from_registry(
            "nvidia/cuda:12.8.0-devel-ubuntu22.04",
            add_python="3.11",
        )
        .apt_install("git", "build-essential", "curl")
        .uv_pip_install(
            "torch==2.10.0",
            "triton>=3.4.0",
            "torchvision==0.25.0",
            "bitsandbytes",
            "accelerate",
            "datasets",
            "huggingface_hub",
            "peft",
            "tokenizers",
            "trackio>=0.25.0",
            "transformers>=5.5.0",
            "trl>=0.28.0",
        )
        # Unsloth and its zoo are installed from git tips in a separate layer.
        .uv_pip_install(
            "unsloth_zoo[base] @ git+https://github.com/unslothai/unsloth-zoo",
            "unsloth[base] @ git+https://github.com/unslothai/unsloth",
        )
        # --no-deps keeps timm from pulling a conflicting torch/torchvision.
        .uv_pip_install("timm", extra_options="--no-deps")
        .uv_pip_install("pydantic==2.10.6")
    )
    if _source_mode() == "public":
        # Public mode: shallow-clone the requested branch into the image.
        repo_url = _cli_arg_value("repo-url", PUBLIC_REPO_URL)
        repo_branch = _cli_arg_value("repo-branch", PUBLIC_REPO_BRANCH)
        image = image.run_commands(
            f"git clone --depth 1 --branch {repo_branch} {repo_url} {REMOTE_PROJECT}",
            f"python -m pip install --no-deps -e {REMOTE_PROJECT}",
        )
    else:
        # Local mode: copy the working tree, excluding envs, caches and outputs.
        image = image.add_local_dir(
            PROJECT_ROOT,
            remote_path=REMOTE_PROJECT,
            copy=True,
            ignore=[
                ".git",
                ".venv",
                ".env",
                ".env.*",
                "__pycache__",
                ".pytest_cache",
                "outputs",
                "*.pyc",
            ],
        )
        image = image.run_commands(f"python -m pip install --no-deps -e {REMOTE_PROJECT}")
    return image.workdir(REMOTE_PROJECT)
150
+
151
+
152
# Module-level Modal objects: the app, the two persistent volumes (run outputs
# and model cache), the training image, and the HF-token secret. Building the
# image here means CLI flags are read at import time via _training_image().
app = modal.App(APP_NAME)
volume = modal.Volume.from_name(VOLUME_NAME, create_if_missing=True)
cache_volume = modal.Volume.from_name(CACHE_VOLUME_NAME, create_if_missing=True)
training_image = _training_image()
secrets = [modal.Secret.from_name(SECRET_NAME, required_keys=["HF_TOKEN"])]
157
+
158
+
159
@app.function(
    image=modal.Image.debian_slim(python_version="3.11"),
    timeout=60 * 20,
    volumes={RUNS_DIR: volume},
)
def upload_sft_jsonl(relative_path: str, content: str) -> str:
    """Write *content* to ``/runs/<relative_path>`` on the persistent run volume.

    Creates parent directories as needed, commits the volume so the data is
    visible to subsequent Modal functions, and returns the absolute path.
    """
    destination = RUNS_DIR / relative_path
    destination.parent.mkdir(parents=True, exist_ok=True)
    destination.write_text(content, encoding="utf-8")
    volume.commit()
    return str(destination)
170
+
171
+
172
@app.function(
    image=training_image,
    gpu="L4",
    timeout=12 * 60 * 60,
    volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume},
    secrets=secrets,
)
def train_cybersecurity_owasp_sft(
    train_jsonl: str = "/runs/sft/train.jsonl",
    validation_jsonl: str = "/runs/sft/validation.jsonl",
    output_repo_id: str = "",
    model_name: str = DEFAULT_GEMMA_MODEL,
    run_name: str = "",
    max_seq_length: int = 4096,
    max_steps: int = 100,
    num_train_epochs: float = 1.0,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 16,
    learning_rate: float = 2e-5,
    lora_rank: int = 32,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-sft",
    push_to_hub: bool = False,
) -> dict[str, Any]:
    """Run LoRA SFT on chat JSONL and save the adapter to the run volume.

    Loads ``train_jsonl`` (and ``validation_jsonl`` if present and non-empty)
    from the mounted run volume, fine-tunes the pinned Gemma model with a
    LoRA adapter via TRL's SFTTrainer, saves the adapter under
    ``/runs/<run_name>/sft_adapter``, and optionally pushes it to the Hub.
    Returns a manifest dict describing the run.

    Raises:
        RuntimeError: if HF_TOKEN is absent from the Modal secret.
        ValueError: if *model_name* is not the pinned default model.
    """
    import inspect

    # Heavy imports happen inside the function so they only run in the
    # GPU container, never on the local launcher.
    from datasets import load_dataset
    from huggingface_hub import snapshot_download, whoami
    from trl import SFTConfig, SFTTrainer
    from trl.chat_template_utils import add_response_schema
    from unsloth import FastVisionModel

    model_name = _ensure_gemma4_model(model_name)
    cache_env = _configure_modal_cache_env()
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError(f"HF_TOKEN is missing from the Modal secret {SECRET_NAME}.")

    # Default the output repo to <hf-user>/CyberSecurity_OWASP-<model>-sft-lora.
    user = whoami(token=hf_token)["name"]
    output_repo_id = output_repo_id or (
        f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-lora"
    )
    stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    run_name = run_name or f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}"
    output_dir = RUNS_DIR / run_name
    adapter_dir = output_dir / "sft_adapter"
    output_dir.mkdir(parents=True, exist_ok=True)

    # Validation split is optional: only used when the file exists and is
    # non-empty on the run volume.
    data_files = {"train": train_jsonl}
    validation_path = pathlib.Path(validation_jsonl)
    has_validation = validation_path.exists() and validation_path.stat().st_size > 0
    if has_validation:
        data_files["validation"] = validation_jsonl
    dataset = load_dataset("json", data_files=data_files)

    print(f"SFT run name: {run_name}")
    print(f"Model: {model_name}")
    print(f"Train JSONL: {train_jsonl}")
    print(f"Validation JSONL: {validation_jsonl if has_validation else '(none)'}")
    print(f"Output adapter dir: {adapter_dir}")
    print(f"Output repo: https://huggingface.co/{output_repo_id}")
    print(f"Trackio Space: https://huggingface.co/spaces/{trackio_space_id}")
    print(f"HF_HUB_CACHE: {cache_env['HF_HUB_CACHE']}")

    # Best-effort prefetch of model weights into the persistent cache; the
    # model loader below retries directly if this fails.
    try:
        snapshot_download(repo_id=model_name, cache_dir=str(HF_HUB_CACHE_DIR), token=hf_token)
        cache_volume.commit()
    except Exception as exc:
        print(f"Model snapshot prefetch skipped; loader will retry directly. Error: {exc!r}")

    model_api = FastVisionModel
    model, tokenizer = model_api.from_pretrained(
        model_name=model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=False,
        fast_inference=False,
        cache_dir=str(HF_HUB_CACHE_DIR),
        token=hf_token,
    )
    # Best-effort: attach the response schema to the chat template if the
    # installed trl supports it.
    try:
        tokenizer = add_response_schema(tokenizer)
    except Exception as exc:
        print(f"Tokenizer response schema add skipped: {exc!r}")

    # Attach a fresh LoRA adapter over attention + MLP projections.
    model = model_api.get_peft_model(
        model,
        r=lora_rank,
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        lora_alpha=lora_rank * 2,
        use_gradient_checkpointing="unsloth",
        random_state=3407,
    )
    if hasattr(model_api, "for_training"):
        model_api.for_training(model)

    # SFTConfig kwargs; unsupported keys are filtered against the installed
    # trl version's signature below so version drift degrades gracefully.
    sft_values = {
        "output_dir": str(output_dir),
        "max_seq_length": max_seq_length,
        "max_steps": max_steps,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "learning_rate": learning_rate,
        "logging_steps": 1,
        "save_steps": max(10, max_steps),
        "report_to": "trackio",
        "project": trackio_project,
        "trackio_space_id": trackio_space_id,
        "run_name": run_name,
        "assistant_only_loss": True,
        "packing": False,
        "gradient_checkpointing": True,
        "gradient_checkpointing_kwargs": {"use_reentrant": False},
        "push_to_hub": push_to_hub,
        "hub_model_id": output_repo_id,
        "hub_private_repo": True,
    }
    sft_parameters = set(inspect.signature(SFTConfig).parameters)
    skipped = sorted(set(sft_values) - sft_parameters)
    if skipped:
        print(f"Skipping unsupported SFTConfig keys: {skipped}")
    training_args = SFTConfig(
        **{key: value for key, value in sft_values.items() if key in sft_parameters}
    )

    # Same signature-filtering pattern for the trainer constructor.
    trainer_values = {
        "model": model,
        "processing_class": tokenizer,
        "args": training_args,
        "train_dataset": dataset["train"],
        "eval_dataset": dataset["validation"] if has_validation else None,
    }
    trainer_parameters = set(inspect.signature(SFTTrainer).parameters)
    skipped_trainer = sorted(
        key for key, value in trainer_values.items() if key not in trainer_parameters and value is not None
    )
    if skipped_trainer:
        print(f"Skipping unsupported SFTTrainer keys: {skipped_trainer}")
    trainer = SFTTrainer(
        **{
            key: value
            for key, value in trainer_values.items()
            if value is not None and key in trainer_parameters
        }
    )
    trainer.train()
    trainer.save_model(str(adapter_dir))
    if push_to_hub:
        trainer.push_to_hub()
    # Commit both volumes so the adapter and cache survive the container.
    volume.commit()
    cache_volume.commit()
    return {
        "run_name": run_name,
        "model_name": model_name,
        "adapter_dir": str(adapter_dir),
        "output_repo_id": output_repo_id,
        "train_jsonl": train_jsonl,
        "validation_jsonl": validation_jsonl if has_validation else "",
        "max_steps": max_steps,
        "push_to_hub": push_to_hub,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
    }
343
+
344
+
345
def _git_sha(default: str = "nogit") -> str:
    """Return the current git HEAD commit SHA, or *default* when git is unavailable.

    Runs inside PROJECT_ROOT with `safe.directory` set so the lookup also works
    when the repo is owned by a different user (e.g. inside a container).
    """
    command = [
        "git",
        "-c",
        f"safe.directory={PROJECT_ROOT.as_posix()}",
        "rev-parse",
        "HEAD",
    ]
    try:
        output = subprocess.check_output(
            command,
            cwd=PROJECT_ROOT,
            text=True,
            stderr=subprocess.DEVNULL,
        )
    except Exception:
        # Best-effort: any failure (no git, not a repo, permissions) falls back.
        return default
    return output.strip()
361
+
362
+
363
@app.local_entrypoint()
def main(
    mode: str = "train",
    local_train_path: str = "outputs/sft/train.jsonl",
    local_validation_path: str = "outputs/sft/validation.jsonl",
    train_jsonl: str = "/runs/sft/train.jsonl",
    validation_jsonl: str = "/runs/sft/validation.jsonl",
    output_repo_id: str = "",
    model_name: str = DEFAULT_GEMMA_MODEL,
    run_name: str = "",
    max_seq_length: int = 4096,
    max_steps: int = 100,
    num_train_epochs: float = 1.0,
    per_device_train_batch_size: int = 1,
    gradient_accumulation_steps: int = 16,
    learning_rate: float = 2e-5,
    lora_rank: int = 32,
    trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
    trackio_project: str = "CyberSecurity_OWASP-sft",
    source_mode: str = "local",
    repo_url: str = PUBLIC_REPO_URL,
    repo_branch: str = PUBLIC_REPO_BRANCH,
    detach: bool = False,
    push_to_hub: bool = False,
) -> None:
    """Upload local SFT JSONL files to the Modal volume and/or launch training.

    mode="upload" only syncs the JSONL files; mode="train" also starts the
    remote SFT job (detached when *detach* is set).
    """
    # These three are consumed while the Modal image is built; keep them in the
    # signature for CLI parity and discard them here.
    del source_mode, repo_url, repo_branch
    model_name = _ensure_gemma4_model(model_name)
    if mode not in {"upload", "train"}:
        raise ValueError("mode must be 'upload' or 'train'")

    # Sync any locally present JSONL files up to the Modal volume; the remote
    # paths returned by the upload replace the defaults.
    source_train = pathlib.Path(local_train_path)
    source_validation = pathlib.Path(local_validation_path)
    if source_train.exists():
        remote_train = upload_sft_jsonl.remote(
            "sft/train.jsonl",
            source_train.read_text(encoding="utf-8"),
        )
        print(f"Uploaded train JSONL: {remote_train}")
        train_jsonl = remote_train
    if source_validation.exists():
        remote_validation = upload_sft_jsonl.remote(
            "sft/validation.jsonl",
            source_validation.read_text(encoding="utf-8"),
        )
        print(f"Uploaded validation JSONL: {remote_validation}")
        validation_jsonl = remote_validation
    if mode == "upload":
        return

    if not run_name:
        # Timestamp + short git SHA keeps run names unique and traceable.
        stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
        run_name = f"CyberSecurity_OWASP-{_model_repo_slug(model_name)}-sft-{stamp}-{_git_sha()[:8]}"

    kwargs = {
        "train_jsonl": train_jsonl,
        "validation_jsonl": validation_jsonl,
        "output_repo_id": output_repo_id,
        "model_name": model_name,
        "run_name": run_name,
        "max_seq_length": max_seq_length,
        "max_steps": max_steps,
        "num_train_epochs": num_train_epochs,
        "per_device_train_batch_size": per_device_train_batch_size,
        "gradient_accumulation_steps": gradient_accumulation_steps,
        "learning_rate": learning_rate,
        "lora_rank": lora_rank,
        "trackio_space_id": trackio_space_id,
        "trackio_project": trackio_project,
        "push_to_hub": push_to_hub,
    }
    print(f"SFT run name: {run_name}")
    print(f"Train JSONL: {train_jsonl}")
    print(f"Validation JSONL: {validation_jsonl}")
    print(f"Hub push enabled: {push_to_hub}")
    if detach:
        # Fire-and-forget: spawn the remote function and report its call id.
        call = train_cybersecurity_owasp_sft.spawn(**kwargs)
        print(f"Spawned Modal SFT call: {call.object_id}")
    else:
        result = train_cybersecurity_owasp_sft.remote(**kwargs)
        print(json.dumps(result, indent=2, sort_keys=True))
tests/test_reward_config.py CHANGED
@@ -68,6 +68,45 @@ def test_reward_config_hash_and_flattened_values_are_deterministic(monkeypatch):
68
  assert rows["hidden_file_probe"]["terminate"] is True
69
 
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def test_reward_config_rejects_missing_descriptions(monkeypatch):
72
  config_path = Path("outputs/test_reward_config_bad.yaml")
73
  config_path.parent.mkdir(parents=True, exist_ok=True)
 
68
  assert rows["hidden_file_probe"]["terminate"] is True
69
 
70
 
71
def test_reward_ablation_configs_extend_default_and_have_unique_hashes(monkeypatch):
    """Each ablation YAML extends the base config, yields a distinct hash, and
    overrides exactly the values its variant is meant to change."""
    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
    ablation_dir = Path("training/configs/reward_ablations")
    names = [
        "A0_sparse_terminal_only.yaml",
        "A2_reduced_shaping.yaml",
        "A6_visible_gate.yaml",
        "A7_evidence045.yaml",
        "A3_no_speed_token.yaml",
    ]

    settings_by_name = {name: load_reward_settings(ablation_dir / name) for name in names}
    hashes = {reward_config_hash(settings) for settings in settings_by_name.values()}

    # Every variant must hash differently from the others.
    assert len(hashes) == len(names)

    sparse = settings_by_name["A0_sparse_terminal_only.yaml"]
    assert sparse.shaping_weight == 0.0
    assert sparse.value("progressive_cap") == 0.0
    assert sparse.value("terminal_cap") == 12.0

    reduced = settings_by_name["A2_reduced_shaping.yaml"]
    assert reduced.shaping_weight == 0.35
    assert reduced.value("progressive_cap") == 2.5

    gated = settings_by_name["A6_visible_gate.yaml"]
    assert gated.value("visible_tests_improved") == 0.0
    assert gated.value("app_boots_after_patch") == 0.10

    assert settings_by_name["A7_evidence045.yaml"].value("local_evidence_found") == 0.45

    no_speed = settings_by_name["A3_no_speed_token.yaml"]
    assert no_speed.value("speed_bonus") == 0.0
    assert compute_token_penalty(850, no_speed) == 0.0
95
+
96
+
97
def test_reward_config_run_config_includes_variant(monkeypatch):
    """The flattened run-config surfaces the env-provided reward variant label."""
    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
    monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "abl-a2-shape035")

    settings = load_reward_settings(
        "training/configs/reward_ablations/A2_reduced_shaping.yaml"
    )
    config = reward_config_run_config(settings)

    assert config["reward_variant"] == "abl-a2-shape035"
    assert config["reward_config_source_name"] == "A2_reduced_shaping.yaml"
    assert config["reward_config__shaping_weight__stage_value"] == 0.35
108
+
109
+
110
  def test_reward_config_rejects_missing_descriptions(monkeypatch):
111
  config_path = Path("outputs/test_reward_config_bad.yaml")
112
  config_path.parent.mkdir(parents=True, exist_ok=True)
tests/test_sft_dataset_generation.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import importlib.util
import json
import os
import sys
import uuid
from pathlib import Path

from CyberSecurity_OWASP.models import CyberSecurityOWASPAction
from CyberSecurity_OWASP.server.CyberSecurity_OWASP_environment import (
    CybersecurityOwaspEnvironment,
)


# scripts/generate_sft_dataset.py is a standalone script, not an importable
# package module, so load it dynamically by file path. Registering it in
# sys.modules before exec lets its own relative lookups/dataclasses resolve.
MODULE_PATH = Path(__file__).resolve().parents[1] / "scripts" / "generate_sft_dataset.py"
SPEC = importlib.util.spec_from_file_location("generate_sft_dataset", MODULE_PATH)
# Validate SPEC before use: spec_from_file_location returns None for a missing
# or unloadable file, and the original code passed it to module_from_spec first,
# which would raise an opaque TypeError instead of this clear assertion.
assert SPEC is not None and SPEC.loader is not None
generate_sft_dataset = importlib.util.module_from_spec(SPEC)
sys.modules[SPEC.name] = generate_sft_dataset
SPEC.loader.exec_module(generate_sft_dataset)
20
+
21
+
22
+ def _isolated_out_dir(label: str) -> Path:
23
+ root = Path("outputs") / "sft_dataset_tests" / f"{label}_{uuid.uuid4().hex[:8]}"
24
+ workspace_root = root / "workspaces"
25
+ workspace_root.mkdir(parents=True, exist_ok=True)
26
+ os.environ["CYBERSECURITY_OWASP_WORKSPACE_ROOT"] = str(workspace_root)
27
+ return root / "sft"
28
+
29
+
30
def test_extracts_and_validates_action_json():
    """A fenced ```json block parses into a validated action model."""
    payload = '```json\n{"tool_name":"inspect_policy_graph","arguments":{}}\n```'

    action = generate_sft_dataset.parse_action_text(payload)

    assert isinstance(action, CyberSecurityOWASPAction)
    assert action.tool_name == "inspect_policy_graph"
37
+
38
+
39
def test_prompt_uses_visible_observation_only():
    """The user prompt must not leak any hidden/oracle-only artifacts."""
    _isolated_out_dir("prompt")
    env = CybersecurityOwaspEnvironment()
    try:
        observation = env.reset(seed=501, split="train", difficulty=0)
        prompt = generate_sft_dataset.build_user_prompt(observation, [])
    finally:
        env.close()

    lowered = prompt.lower()
    # Markers that only exist on the hidden/reward side of the environment.
    assert "hidden_facts" not in lowered
    assert "oracle_hidden_focus" not in lowered
    assert "reward_engine" not in lowered
    assert "validators.py" not in lowered
    assert "tests/hidden" not in lowered
    assert "hidden tests" not in lowered
55
+
56
+
57
def test_chat_row_matches_conversational_sft_shape():
    """A generated row is system/user/assistant with a JSON action as the reply."""
    _isolated_out_dir("chat_row")
    env = CybersecurityOwaspEnvironment()
    try:
        observation = env.reset(seed=502, split="train", difficulty=0)
        messages = generate_sft_dataset.build_chat_messages(observation, [])
        action = CyberSecurityOWASPAction(tool_name="inspect_policy_graph", arguments={})
        metadata = {
            "target_model": generate_sft_dataset.DEFAULT_TARGET_MODEL,
            "teacher_model": generate_sft_dataset.DEFAULT_TEACHER_MODEL,
            "seed": 502,
        }
        row = generate_sft_dataset.make_chat_row(
            messages=messages,
            action=action,
            metadata=metadata,
        )
    finally:
        env.close()

    roles = [message["role"] for message in row["messages"]]
    assert roles == ["system", "user", "assistant"]
    # The assistant turn is the serialized action, byte-equivalent after a round trip.
    assert json.loads(row["messages"][-1]["content"]) == action.model_dump()
    assert row["metadata"]["target_model"] == "unsloth/gemma-4-E2B-it"
83
+
84
+
85
def test_dry_run_oracle_creates_chat_jsonl_without_network():
    """dry_run_oracle mode writes train/validation JSONL locally, no API calls."""
    out_dir = _isolated_out_dir("dry_run")

    config = generate_sft_dataset.DatasetConfig(
        episodes=2,
        validation_episodes=1,
        out_dir=out_dir,
        dry_run_oracle=True,
    )
    manifest = generate_sft_dataset.generate_dataset(config)

    assert manifest["episodes_attempted"] == 3
    assert manifest["episodes_accepted"] == 3

    train_path = out_dir / "train.jsonl"
    validation_path = out_dir / "validation.jsonl"
    assert train_path.exists()
    assert validation_path.exists()

    def _read_rows(path):
        # One JSON object per non-blank line.
        lines = path.read_text(encoding="utf-8").splitlines()
        return [json.loads(line) for line in lines if line.strip()]

    train_rows = _read_rows(train_path)
    validation_rows = _read_rows(validation_path)
    assert train_rows
    assert validation_rows
    assert all(row["messages"][-1]["role"] == "assistant" for row in train_rows)
113
+
114
+
115
def test_saved_oracle_trajectory_replays_to_success():
    """A saved oracle trajectory replays deterministically to a clean success."""
    out_dir = _isolated_out_dir("replay")
    generate_sft_dataset.generate_dataset(
        generate_sft_dataset.DatasetConfig(
            episodes=1,
            out_dir=out_dir,
            dry_run_oracle=True,
        )
    )

    trajectory_path = next((out_dir / "trajectories").glob("train_seed*.json"))
    trajectory = json.loads(trajectory_path.read_text(encoding="utf-8"))

    env = CybersecurityOwaspEnvironment()
    try:
        env.reset(
            seed=int(trajectory["seed"]),
            split=trajectory["split"],
            difficulty=int(trajectory["difficulty"]),
        )
        last_result = None
        for action_payload in trajectory["actions"]:
            last_result = env.step(CyberSecurityOWASPAction(**action_payload))
        assert last_result is not None
        assert last_result.done is True
        assert env.state.success is True
        # No anti-cheat flags may fire on an oracle trajectory.
        assert not env.state.anti_cheat_flags
    finally:
        env.close()
tests/test_trackio_utils.py CHANGED
@@ -39,6 +39,10 @@ def test_canonical_tracking_fields_exist_and_are_numeric_where_expected():
39
  assert isinstance(fields["reward/hidden_authz_pass_rate"], float)
40
  assert isinstance(fields["reward/normal_flow_pass_rate"], float)
41
  assert isinstance(fields["reward/public_hidden_gap"], float)
 
 
 
 
42
  assert isinstance(fields["skill/exploit_to_patch_alignment"], float)
43
 
44
  metrics = aggregate_episode_metrics([record])
@@ -156,11 +160,13 @@ def test_log_reward_config_emits_scalar_values_and_table(monkeypatch):
156
  monkeypatch.setitem(sys.modules, "trackio", fake_trackio)
157
  monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
158
  monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_STAGE", "early")
 
159
 
160
  settings = load_reward_settings()
161
  summary = log_reward_config(settings, step=0)
162
 
163
  assert fake_trackio.config["reward_config_hash"] == summary["reward_config_hash"]
 
164
  assert fake_trackio.config["reward_config_values"]["policy_inspected"]["value"] == 0.30
165
  assert fake_trackio.config["reward_config__policy_inspected__value"] == 0.30
166
  scalar_payload = next(payload for payload, _step in logged if "reward_config/policy_inspected/value" in payload)
 
39
  assert isinstance(fields["reward/hidden_authz_pass_rate"], float)
40
  assert isinstance(fields["reward/normal_flow_pass_rate"], float)
41
  assert isinstance(fields["reward/public_hidden_gap"], float)
42
+ assert isinstance(fields["reward/dense_to_terminal_ratio"], float)
43
+ assert isinstance(fields["episode/time_to_first_patch"], float)
44
+ assert isinstance(fields["episode/repeated_action_rate"], float)
45
+ assert isinstance(fields["episode/patch_to_hidden_success_conversion_rate"], float)
46
  assert isinstance(fields["skill/exploit_to_patch_alignment"], float)
47
 
48
  metrics = aggregate_episode_metrics([record])
 
160
  monkeypatch.setitem(sys.modules, "trackio", fake_trackio)
161
  monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_MODE", "dense_train")
162
  monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_STAGE", "early")
163
+ monkeypatch.setenv("CYBERSECURITY_OWASP_REWARD_VARIANT", "abl-test")
164
 
165
  settings = load_reward_settings()
166
  summary = log_reward_config(settings, step=0)
167
 
168
  assert fake_trackio.config["reward_config_hash"] == summary["reward_config_hash"]
169
+ assert fake_trackio.config["reward_variant"] == "abl-test"
170
  assert fake_trackio.config["reward_config_values"]["policy_inspected"]["value"] == 0.30
171
  assert fake_trackio.config["reward_config__policy_inspected__value"] == 0.30
172
  scalar_payload = next(payload for payload, _step in logged if "reward_config/policy_inspected/value" in payload)
training/configs/reward_ablations/A0_sparse_terminal_only.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Ablation A0: sparse, terminal-only reward.
# Extends the base GRPO config but zeroes every dense/shaping term, so the
# policy learns from the terminal outcome (terminal_cap) plus the hard safety
# penalties alone. Anti-cheat terminations are kept intact.
extends: ../grpo_small.yaml
reward:
  stage: early
  # Terminal outcome dominates: 12.0 cap, no progressive (dense) budget.
  terminal_cap:
    value: 12.0
  progressive_cap:
    value: 0.0
  penalty_floor:
    value: -10.0
  train_cap:
    value: 12.0
  # Shaping disabled across every training stage.
  shaping_weight:
    early: 0.0
    middle: 0.0
    late: 0.0
    final: 0.0
  # All dense exploration/progress signals zeroed.
  policy_inspected:
    value: 0.0
  route_map_inspected:
    value: 0.0
    cap: 0.0
  relevant_file_inspected:
    value: 0.0
    cap: 0.0
  local_evidence_found:
    value: 0.0
    cap: 0.0
  diagnosis_correct:
    value: 0.0
  patch_applies:
    value: 0.0
  app_boots_after_patch:
    value: 0.0
  visible_tests_improved:
    value: 0.0
    cap: 0.0
  public_routes_visible_pass:
    value: 0.0
  # Per-step cost, speed, and token shaping all disabled.
  step_penalty:
    early: 0.0
    middle: 0.0
    late: 0.0
    final: 0.0
    cap: 0.0
  speed_bonus:
    value: 0.0
  token_penalty:
    early: 0.0
    middle: 0.0
    late: 0.0
    final: 0.0
    cap: 0.0
  # Behavioral nuisance penalties zeroed (sparse variant keeps only hard safety).
  invalid_action:
    value: 0.0
  repeated_invalid_action:
    value: 0.0
  repeated_low_value_action:
    value: 0.0
  no_progress_action:
    value: 0.0
  noop_action:
    value: 0.0
  repeated_file_read:
    value: 0.0
  repeated_local_request:
    value: 0.0
  repeated_visible_tests:
    value: 0.0
  patch_before_policy:
    value: 0.0
  submit_without_patch:
    value: 0.0
  submit_without_visible_tests:
    value: 0.0
  excessive_patch_attempt:
    value: 0.0
  too_many_files_changed:
    value: 0.0
  oversized_patch:
    value: 0.0
    severe_value: 0.0
  # Hard anti-cheat penalties are retained and still terminate the episode.
  hidden_file_probe:
    value: -10.0
    terminate: true
  test_file_edit:
    value: -10.0
    terminate: true
  host_path_escape:
    value: -10.0
    terminate: true
  external_url_attempt:
    value: -10.0
    terminate: true
  hardcoded_identifier:
    value: -4.0
  deny_all_patch:
    value: -5.0
training/configs/reward_ablations/A2_reduced_shaping.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Ablation A2: reduced shaping.
# Same dense signals as the base config, but the shaping weight is cut to 0.35
# at every stage and the progressive (dense) budget is tightened to 2.5.
extends: ../grpo_small.yaml
reward:
  stage: early
  progressive_cap:
    value: 2.5
  train_cap:
    value: 18.0
  shaping_weight:
    early: 0.35
    middle: 0.35
    late: 0.35
    final: 0.35
training/configs/reward_ablations/A3_no_speed_token.yaml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Ablation A3: no speed/token incentives.
# Removes the speed bonus and token-length penalty entirely; keeps only a tiny
# flat per-step cost (-0.002, capped at -0.25) so episode length is near-neutral.
extends: ../grpo_small.yaml
reward:
  stage: early
  step_penalty:
    early: -0.002
    middle: -0.002
    late: -0.002
    final: -0.002
    cap: -0.25
  speed_bonus:
    value: 0.0
  token_penalty:
    early: 0.0
    middle: 0.0
    late: 0.0
    final: 0.0
    cap: 0.0
training/configs/reward_ablations/A6_visible_gate.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Ablation A6: visible-test gate.
# Shifts credit away from per-improvement visible-test rewards (value 0, but a
# 0.20 cap left for gated payout) toward boot success and public-route passes.
extends: ../grpo_small.yaml
reward:
  stage: early
  app_boots_after_patch:
    value: 0.10
  visible_tests_improved:
    value: 0.0
    cap: 0.20
  public_routes_visible_pass:
    value: 0.10
training/configs/reward_ablations/A7_evidence045.yaml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Ablation A7: boosted evidence reward.
# Raises local_evidence_found to 0.45 (value and cap) to over-weight the
# evidence-gathering phase relative to the base config.
extends: ../grpo_small.yaml
reward:
  stage: early
  local_evidence_found:
    value: 0.45
    cap: 0.45
training/trackio_utils.py CHANGED
@@ -17,6 +17,7 @@ RUN_SCENARIO_FIELDS = (
17
  "run/base_model",
18
  "run/algo",
19
  "run/reward_version",
 
20
  "run/env_version",
21
  "scenario/seed",
22
  "scenario/template_id",
@@ -136,6 +137,16 @@ CANONICAL_TRACKIO_SIGNALS = tuple(
136
 
137
  DERIVED_TRACKIO_METRICS = (
138
  "reward/public_hidden_gap",
 
 
 
 
 
 
 
 
 
 
139
  "cheat/score",
140
  )
141
 
@@ -545,6 +556,7 @@ def episode_record_from_state(
545
  "run/reward_config_hash": context.get("reward_config_hash", ""),
546
  "run/reward_mode": context.get("reward_mode", ""),
547
  "run/reward_stage": context.get("reward_stage", ""),
 
548
  "run/env_version": context.get("env_version", "0.1.0"),
549
  "episode_id": getattr(state, "episode_id", ""),
550
  "task_id": getattr(state, "task_id", ""),
@@ -663,6 +675,16 @@ def episode_to_tracking_fields(episode: Any) -> dict[str, Any]:
663
  -_float(final_reward.get("safety")),
664
  )
665
  fields["reward/public_hidden_gap"] = visible_rate - hidden_rate
 
 
 
 
 
 
 
 
 
 
666
 
667
  fields["skill/valid_action_rate"] = max(0.0, min(1.0, valid_action_rate))
668
  fields["skill/discovery_success"] = fields["reward/bug_classification_correct"]
@@ -698,11 +720,33 @@ def episode_to_tracking_fields(episode: Any) -> dict[str, Any]:
698
  fields["episode/patch_attempt_count"] = _float(
699
  record.get("patch_attempt_count", tool_names.count("patch_file"))
700
  )
 
 
 
 
 
 
 
 
 
 
 
 
701
  fields["episode/submit_without_test_rate"] = 1.0 if (
702
  "submit_fix" in tool_names and "patch_file" in tool_names and "run_visible_tests" not in tool_names
703
  ) else 0.0
 
 
 
704
  fields["episode/cheat_attempt_rate"] = 1.0 if anti_flags else 0.0
705
  fields["episode/oversecure_rate"] = 1.0 if bool(quality.get("deny_all")) else 0.0
 
 
 
 
 
 
 
706
 
707
  fields["cheat/hidden_file_read_attempt"] = 1.0 if (
708
  "blocked_file_probe" in anti_flags and "hidden" in path_text
@@ -802,6 +846,15 @@ def train_metric_aliases(metrics: Mapping[str, Any]) -> dict[str, float]:
802
  "train/reward_token_penalty_mean": _float(metrics.get("reward/token_penalty")),
803
  "train/reward_speed_bonus_mean": _float(metrics.get("reward/speed_bonus")),
804
  "train/reward_behavior_penalty_mean": _float(metrics.get("reward/behavior_penalty")),
 
 
 
 
 
 
 
 
 
805
  "train/success_rate": _float(metrics.get("skill/patch_success")),
806
  "train/exploit_block_rate": _float(metrics.get("reward/hidden_authz_pass_rate")),
807
  "train/regression_preservation_rate": _float(metrics.get("reward/normal_flow_pass_rate")),
 
17
  "run/base_model",
18
  "run/algo",
19
  "run/reward_version",
20
+ "run/reward_variant",
21
  "run/env_version",
22
  "scenario/seed",
23
  "scenario/template_id",
 
137
 
138
  DERIVED_TRACKIO_METRICS = (
139
  "reward/public_hidden_gap",
140
+ "reward/visible_hidden_gap",
141
+ "reward/dense_total",
142
+ "reward/dense_to_terminal_ratio",
143
+ "episode/time_to_first_evidence",
144
+ "episode/time_to_first_patch",
145
+ "episode/repeated_action_rate",
146
+ "episode/submit_without_evidence_rate",
147
+ "episode/hardcoded_identifier_rate",
148
+ "episode/deny_all_patch_rate",
149
+ "episode/patch_to_hidden_success_conversion_rate",
150
  "cheat/score",
151
  )
152
 
 
556
  "run/reward_config_hash": context.get("reward_config_hash", ""),
557
  "run/reward_mode": context.get("reward_mode", ""),
558
  "run/reward_stage": context.get("reward_stage", ""),
559
+ "run/reward_variant": context.get("reward_variant", ""),
560
  "run/env_version": context.get("env_version", "0.1.0"),
561
  "episode_id": getattr(state, "episode_id", ""),
562
  "task_id": getattr(state, "task_id", ""),
 
675
  -_float(final_reward.get("safety")),
676
  )
677
  fields["reward/public_hidden_gap"] = visible_rate - hidden_rate
678
+ fields["reward/visible_hidden_gap"] = fields["reward/public_hidden_gap"]
679
+ fields["reward/dense_total"] = (
680
+ fields["reward/total"] - fields["reward/terminal_15"]
681
+ )
682
+ terminal_denominator = abs(fields["reward/terminal_15"])
683
+ fields["reward/dense_to_terminal_ratio"] = (
684
+ fields["reward/dense_total"] / terminal_denominator
685
+ if terminal_denominator > 1e-9
686
+ else fields["reward/dense_total"]
687
+ )
688
 
689
  fields["skill/valid_action_rate"] = max(0.0, min(1.0, valid_action_rate))
690
  fields["skill/discovery_success"] = fields["reward/bug_classification_correct"]
 
720
  fields["episode/patch_attempt_count"] = _float(
721
  record.get("patch_attempt_count", tool_names.count("patch_file"))
722
  )
723
+ fields["episode/time_to_first_evidence"] = fields["skill/first_valid_exploit_step"]
724
+ fields["episode/time_to_first_patch"] = _first_tool_step(
725
+ actions,
726
+ {"patch_file"},
727
+ observations,
728
+ )
729
+ action_keys = [
730
+ json.dumps(action, sort_keys=True, default=str)
731
+ for action in actions
732
+ ]
733
+ repeated_actions = len(action_keys) - len(set(action_keys))
734
+ fields["episode/repeated_action_rate"] = repeated_actions / max(1.0, float(len(action_keys)))
735
  fields["episode/submit_without_test_rate"] = 1.0 if (
736
  "submit_fix" in tool_names and "patch_file" in tool_names and "run_visible_tests" not in tool_names
737
  ) else 0.0
738
+ fields["episode/submit_without_evidence_rate"] = 1.0 if (
739
+ "submit_fix" in tool_names and fields["episode/time_to_first_evidence"] < 0.0
740
+ ) else 0.0
741
  fields["episode/cheat_attempt_rate"] = 1.0 if anti_flags else 0.0
742
  fields["episode/oversecure_rate"] = 1.0 if bool(quality.get("deny_all")) else 0.0
743
+ fields["episode/hardcoded_identifier_rate"] = 1.0 if (
744
+ "hardcoded_hidden_identifier" in anti_flags or bool(quality.get("hardcoded"))
745
+ ) else 0.0
746
+ fields["episode/deny_all_patch_rate"] = fields["episode/oversecure_rate"]
747
+ fields["episode/patch_to_hidden_success_conversion_rate"] = (
748
+ hidden_rate if "patch_file" in tool_names else 0.0
749
+ )
750
 
751
  fields["cheat/hidden_file_read_attempt"] = 1.0 if (
752
  "blocked_file_probe" in anti_flags and "hidden" in path_text
 
846
  "train/reward_token_penalty_mean": _float(metrics.get("reward/token_penalty")),
847
  "train/reward_speed_bonus_mean": _float(metrics.get("reward/speed_bonus")),
848
  "train/reward_behavior_penalty_mean": _float(metrics.get("reward/behavior_penalty")),
849
+ "train/dense_to_terminal_ratio": _float(metrics.get("reward/dense_to_terminal_ratio")),
850
+ "train/visible_hidden_gap": _float(metrics.get("reward/visible_hidden_gap")),
851
+ "train/repeated_action_rate": _float(metrics.get("episode/repeated_action_rate")),
852
+ "train/submit_without_evidence_rate": _float(metrics.get("episode/submit_without_evidence_rate")),
853
+ "train/hardcoded_identifier_rate": _float(metrics.get("episode/hardcoded_identifier_rate")),
854
+ "train/deny_all_patch_rate": _float(metrics.get("episode/deny_all_patch_rate")),
855
+ "train/patch_to_hidden_success_conversion_rate": _float(
856
+ metrics.get("episode/patch_to_hidden_success_conversion_rate")
857
+ ),
858
  "train/success_rate": _float(metrics.get("skill/patch_success")),
859
  "train/exploit_block_rate": _float(metrics.get("reward/hidden_authz_pass_rate")),
860
  "train/regression_preservation_rate": _float(metrics.get("reward/normal_flow_pass_rate")),