Spaces:
Sleeping
Sleeping
Commit ·
632c145
1
Parent(s): 2eada22
feat: enhance CyberSecurity_OWASP observation model with scenario prompt, improve GRPO batch configuration validation, and add scenario grouping for adaptive difficulty curriculum
Browse files- models.py +1 -0
- scripts/modal_train_grpo.py +319 -105
- server/CyberSecurity_OWASP_environment.py +57 -2
- server/scenario_cache.py +25 -0
- tests/test_closed_loop_runtime.py +28 -0
- tests/test_grpo_curriculum.py +138 -0
- tests/test_trackio_utils.py +8 -0
- training/grpo_curriculum.py +260 -0
- training/trackio_utils.py +18 -4
models.py
CHANGED
|
@@ -38,6 +38,7 @@ class CyberSecurityOWASPObservation(Observation):
|
|
| 38 |
phase: CyberSecurityOWASPPhase = "discover"
|
| 39 |
message: str = ""
|
| 40 |
task_brief: str = ""
|
|
|
|
| 41 |
visible_policy_hint: dict[str, Any] = Field(default_factory=dict)
|
| 42 |
workspace_summary: dict[str, Any] = Field(default_factory=dict)
|
| 43 |
available_actions: list[str] = Field(default_factory=list)
|
|
|
|
| 38 |
phase: CyberSecurityOWASPPhase = "discover"
|
| 39 |
message: str = ""
|
| 40 |
task_brief: str = ""
|
| 41 |
+
scenario_prompt: str = ""
|
| 42 |
visible_policy_hint: dict[str, Any] = Field(default_factory=dict)
|
| 43 |
workspace_summary: dict[str, Any] = Field(default_factory=dict)
|
| 44 |
available_actions: list[str] = Field(default_factory=list)
|
scripts/modal_train_grpo.py
CHANGED
|
@@ -71,6 +71,55 @@ def _hf_model_cache_path(model_name: str) -> pathlib.Path:
|
|
| 71 |
return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"
|
| 72 |
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
def _configure_modal_cache_env() -> dict[str, str]:
|
| 75 |
values = {
|
| 76 |
"HF_HOME": str(HF_HOME_DIR),
|
|
@@ -374,22 +423,20 @@ def verify_modal_scenario_cache_for_training(
|
|
| 374 |
resolved_difficulty = int(scenario_profile["difficulty"])
|
| 375 |
cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=settings)
|
| 376 |
coverage = cache.assert_coverage(split=split, difficulty=resolved_difficulty)
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
.
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
raise RuntimeError(
|
| 384 |
-
"Scenario cache does not cover this Modal dataset. Run "
|
| 385 |
-
"--mode prepare-cache with a larger per-bucket count before training. "
|
| 386 |
-
f"available={available_scenarios}, requested_dataset_size={dataset_size}, "
|
| 387 |
-
f"split={split}, difficulty={resolved_difficulty}"
|
| 388 |
-
)
|
| 389 |
|
| 390 |
env = CybersecurityOwaspEnvironment()
|
| 391 |
try:
|
| 392 |
-
obs = env.reset(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
if not env.state.cache_hit:
|
| 394 |
raise RuntimeError("Scenario cache preflight reset did not hit cache.")
|
| 395 |
if env.state.metrics.get("scenario_compile_latency_ms", 0.0):
|
|
@@ -413,9 +460,10 @@ def verify_modal_scenario_cache_for_training(
|
|
| 413 |
"scenario_cache_dir": str(SCENARIO_CACHE_DIR),
|
| 414 |
"scenario_cache_mode": "require",
|
| 415 |
"split": split,
|
| 416 |
-
"difficulty":
|
|
|
|
| 417 |
"dataset_size": dataset_size,
|
| 418 |
-
"available_scenarios":
|
| 419 |
"coverage": coverage,
|
| 420 |
"sample_reset": sample,
|
| 421 |
}
|
|
@@ -480,6 +528,11 @@ def train_cybersecurity_owasp_grpo(
|
|
| 480 |
trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
|
| 481 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 482 |
num_generations: int = 6,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
seed_start: int = 0,
|
| 484 |
git_sha: str = "nogit",
|
| 485 |
run_name: str = "",
|
|
@@ -495,6 +548,21 @@ def train_cybersecurity_owasp_grpo(
|
|
| 495 |
|
| 496 |
model_name = _ensure_gemma4_model(model_name)
|
| 497 |
cache_env = _configure_modal_cache_env()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 498 |
|
| 499 |
import torch
|
| 500 |
from unsloth import FastVisionModel
|
|
@@ -524,6 +592,10 @@ def train_cybersecurity_owasp_grpo(
|
|
| 524 |
log_trackio_metrics,
|
| 525 |
train_metric_aliases,
|
| 526 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 527 |
|
| 528 |
transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]
|
| 529 |
|
|
@@ -585,18 +657,14 @@ def train_cybersecurity_owasp_grpo(
|
|
| 585 |
split=split,
|
| 586 |
difficulty=int(scenario_profile["difficulty"]),
|
| 587 |
)
|
| 588 |
-
|
| 589 |
-
|
| 590 |
-
|
| 591 |
-
|
|
|
|
|
|
|
|
|
|
| 592 |
)
|
| 593 |
-
if available_scenarios < dataset_size:
|
| 594 |
-
raise RuntimeError(
|
| 595 |
-
"Scenario cache does not cover this Modal dataset. Run "
|
| 596 |
-
"--mode prepare-cache with a larger per-bucket count before training. "
|
| 597 |
-
f"available={available_scenarios}, requested_dataset_size={dataset_size}, "
|
| 598 |
-
f"split={split}, difficulty={scenario_profile['difficulty']}"
|
| 599 |
-
)
|
| 600 |
|
| 601 |
training_prompt = (
|
| 602 |
"You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
|
|
@@ -608,15 +676,14 @@ def train_cybersecurity_owasp_grpo(
|
|
| 608 |
)
|
| 609 |
|
| 610 |
dataset = Dataset.from_list(
|
| 611 |
-
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
|
| 619 |
-
]
|
| 620 |
)
|
| 621 |
|
| 622 |
def _state_snapshot(env: CybersecurityOwaspEnvironment) -> dict[str, Any]:
|
|
@@ -627,8 +694,10 @@ def train_cybersecurity_owasp_grpo(
|
|
| 627 |
"seed": state.seed,
|
| 628 |
"split": state.split,
|
| 629 |
"difficulty": state.difficulty,
|
|
|
|
| 630 |
"domain": state.domain,
|
| 631 |
"bug_family": state.bug_family,
|
|
|
|
| 632 |
"cache_hit": state.cache_hit,
|
| 633 |
"scenario_hash": state.scenario_hash,
|
| 634 |
"phase": state.phase,
|
|
@@ -647,18 +716,30 @@ def train_cybersecurity_owasp_grpo(
|
|
| 647 |
self.done = False
|
| 648 |
self.success = False
|
| 649 |
self.invalid_actions = 0
|
|
|
|
|
|
|
| 650 |
self.trace_messages: list[dict[str, str]] = []
|
| 651 |
self.trace_metadata: dict[str, Any] = {}
|
| 652 |
|
| 653 |
def reset(self, **kwargs) -> str:
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 657 |
obs = self._env.reset(
|
| 658 |
seed=seed,
|
| 659 |
split=current_split,
|
| 660 |
difficulty=current_difficulty,
|
| 661 |
)
|
|
|
|
|
|
|
| 662 |
self.reward = 0.0
|
| 663 |
self.reward_breakdown = {}
|
| 664 |
self.done = bool(obs.done)
|
|
@@ -668,18 +749,21 @@ def train_cybersecurity_owasp_grpo(
|
|
| 668 |
{
|
| 669 |
"role": "user",
|
| 670 |
"content": (
|
| 671 |
-
f"{training_prompt}\n\
|
| 672 |
-
f"
|
| 673 |
-
f"
|
| 674 |
-
f"Available actions: {obs.available_actions}\n"
|
| 675 |
-
f"Workspace summary: {obs.workspace_summary}\n"
|
| 676 |
-
f"Policy hint: {obs.visible_policy_hint}\n"
|
| 677 |
-
f"Message: {obs.message}"
|
| 678 |
),
|
| 679 |
}
|
| 680 |
]
|
| 681 |
self.trace_metadata = _state_snapshot(self._env)
|
| 682 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 683 |
|
| 684 |
def _step(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
|
| 685 |
if self.done:
|
|
@@ -714,6 +798,8 @@ def train_cybersecurity_owasp_grpo(
|
|
| 714 |
"invalid_actions": self.invalid_actions,
|
| 715 |
"scenario_cache_hit": self._env.state.cache_hit,
|
| 716 |
"scenario_hash": self._env.state.scenario_hash,
|
|
|
|
|
|
|
| 717 |
}
|
| 718 |
)
|
| 719 |
return obs.message
|
|
@@ -938,11 +1024,58 @@ def train_cybersecurity_owasp_grpo(
|
|
| 938 |
)
|
| 939 |
episode_records.append(record)
|
| 940 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 941 |
canonical_metrics = aggregate_episode_metrics(episode_records)
|
| 942 |
metrics = {
|
| 943 |
**canonical_metrics,
|
| 944 |
**train_metric_aliases(canonical_metrics),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 945 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 946 |
if rewards:
|
| 947 |
metrics["train/reward_mean"] = _mean(rewards)
|
| 948 |
metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
|
|
@@ -952,60 +1085,57 @@ def train_cybersecurity_owasp_grpo(
|
|
| 952 |
except Exception as exc:
|
| 953 |
print(f"Trackio metric logging skipped: {exc!r}")
|
| 954 |
|
| 955 |
-
|
| 956 |
-
seen_this_batch: set[str] = set()
|
| 957 |
-
for index, (env, record, reward) in enumerate(zip(environments, episode_records, rewards)):
|
| 958 |
-
fingerprint = episode_trace_fingerprint(record)
|
| 959 |
-
if fingerprint in seen_this_batch or fingerprint in logged_trace_fingerprints:
|
| 960 |
-
continue
|
| 961 |
-
seen_this_batch.add(fingerprint)
|
| 962 |
-
logged_trace_fingerprints.add(fingerprint)
|
| 963 |
-
sampled_traces.append((index, env, record, reward, fingerprint))
|
| 964 |
-
if len(sampled_traces) >= 4:
|
| 965 |
-
break
|
| 966 |
-
|
| 967 |
-
try:
|
| 968 |
-
log_trace_table(
|
| 969 |
-
[record for _, _, record, _, _ in sampled_traces],
|
| 970 |
-
table_name="sample_traces",
|
| 971 |
-
step=trace_step["value"],
|
| 972 |
-
)
|
| 973 |
-
except Exception as exc:
|
| 974 |
-
print(f"Trackio sample trace table logging skipped: {exc!r}")
|
| 975 |
-
|
| 976 |
-
for index, env, _record, reward, fingerprint in sampled_traces:
|
| 977 |
-
messages = list(getattr(env, "trace_messages", []))
|
| 978 |
-
if index < len(completions):
|
| 979 |
-
completion_text = _completion_to_text(completions[index])
|
| 980 |
-
if completion_text:
|
| 981 |
-
messages.append(
|
| 982 |
-
{
|
| 983 |
-
"role": "assistant",
|
| 984 |
-
"content": f"Raw generated completion:\n{completion_text}",
|
| 985 |
-
}
|
| 986 |
-
)
|
| 987 |
-
metadata = dict(getattr(env, "trace_metadata", {}))
|
| 988 |
-
metadata.update(
|
| 989 |
-
{
|
| 990 |
-
"sample_index": index,
|
| 991 |
-
"reward": reward,
|
| 992 |
-
"trace_step": trace_step["value"],
|
| 993 |
-
"trace_fingerprint": fingerprint,
|
| 994 |
-
"run_name": run_name,
|
| 995 |
-
}
|
| 996 |
-
)
|
| 997 |
try:
|
| 998 |
-
|
| 999 |
-
|
| 1000 |
-
|
| 1001 |
-
messages=messages,
|
| 1002 |
-
metadata=metadata,
|
| 1003 |
-
)
|
| 1004 |
-
},
|
| 1005 |
step=trace_step["value"],
|
| 1006 |
)
|
| 1007 |
except Exception as exc:
|
| 1008 |
-
print(f"Trackio trace logging skipped: {exc!r}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1009 |
|
| 1010 |
if rewards:
|
| 1011 |
print(
|
|
@@ -1080,6 +1210,20 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1080 |
print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
|
| 1081 |
print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
|
| 1082 |
print(f"Hub push enabled: {push_to_hub}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1083 |
|
| 1084 |
expected_model_cache = _hf_model_cache_path(model_name)
|
| 1085 |
cache_hit = expected_model_cache.exists()
|
|
@@ -1109,13 +1253,36 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1109 |
|
| 1110 |
print(f"Loading model with Unsloth from_pretrained: {model_name}")
|
| 1111 |
model_api = FastVisionModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1112 |
model, tokenizer = model_api.from_pretrained(
|
| 1113 |
-
|
| 1114 |
-
|
| 1115 |
-
|
| 1116 |
-
|
| 1117 |
-
|
| 1118 |
-
|
| 1119 |
)
|
| 1120 |
print("Model load complete.")
|
| 1121 |
cache_volume.commit()
|
|
@@ -1157,8 +1324,8 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1157 |
"lr_scheduler_type": "linear",
|
| 1158 |
"optim": "adamw_8bit",
|
| 1159 |
"logging_steps": 1,
|
| 1160 |
-
"per_device_train_batch_size":
|
| 1161 |
-
"gradient_accumulation_steps":
|
| 1162 |
"num_generations": num_generations,
|
| 1163 |
"max_prompt_length": max_seq_length,
|
| 1164 |
"max_completion_length": max_completion_length,
|
|
@@ -1175,11 +1342,14 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1175 |
"hub_strategy": "every_save",
|
| 1176 |
"gradient_checkpointing": True,
|
| 1177 |
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
|
|
|
|
|
|
|
|
|
| 1178 |
"epsilon": 0.2,
|
| 1179 |
"epsilon_high": 0.28,
|
| 1180 |
"delta": 1.5,
|
| 1181 |
"loss_type": "bnpo",
|
| 1182 |
-
"mask_truncated_completions":
|
| 1183 |
}
|
| 1184 |
grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
|
| 1185 |
skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
|
|
@@ -1269,6 +1439,12 @@ def train_cybersecurity_owasp_grpo(
|
|
| 1269 |
"model_name": model_name,
|
| 1270 |
"max_completion_length": max_completion_length,
|
| 1271 |
"num_generations": num_generations,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1272 |
"source_mode": source_mode,
|
| 1273 |
"repo_url": repo_url,
|
| 1274 |
"repo_branch": repo_branch,
|
|
@@ -1294,6 +1470,11 @@ def main(
|
|
| 1294 |
trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
|
| 1295 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 1296 |
num_generations: int = 6,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1297 |
seed_start: int = 0,
|
| 1298 |
git_sha: str = "nogit",
|
| 1299 |
source_mode: str = "local",
|
|
@@ -1327,6 +1508,21 @@ def main(
|
|
| 1327 |
if mode != "train":
|
| 1328 |
raise ValueError("mode must be 'prepare-cache', 'train', or 'config'")
|
| 1329 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1330 |
trackio_space_id = trackio_space_id or os.environ.get(
|
| 1331 |
"TRACKIO_SPACE_ID",
|
| 1332 |
"Humanlearning/CyberSecurity_OWASP-trackio",
|
|
@@ -1392,6 +1588,19 @@ def main(
|
|
| 1392 |
print(f"Hub push enabled: {push_to_hub}")
|
| 1393 |
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
| 1394 |
print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1395 |
print("Launch phases:")
|
| 1396 |
print(
|
| 1397 |
"1. Modal image build/validation: happens before remote Python logs; "
|
|
@@ -1421,6 +1630,11 @@ def main(
|
|
| 1421 |
trackio_space_id=trackio_space_id,
|
| 1422 |
trackio_project=trackio_project,
|
| 1423 |
num_generations=num_generations,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1424 |
seed_start=seed_start,
|
| 1425 |
git_sha=git_sha,
|
| 1426 |
run_name=run_name,
|
|
|
|
| 71 |
return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"
|
| 72 |
|
| 73 |
|
| 74 |
+
def _resolve_grpo_batch_config(
|
| 75 |
+
*,
|
| 76 |
+
per_device_train_batch_size: int,
|
| 77 |
+
gradient_accumulation_steps: int,
|
| 78 |
+
num_generations: int,
|
| 79 |
+
world_size: int = 1,
|
| 80 |
+
) -> tuple[int, int]:
|
| 81 |
+
if num_generations < 1:
|
| 82 |
+
raise ValueError("--num-generations must be at least 1.")
|
| 83 |
+
if per_device_train_batch_size < 1:
|
| 84 |
+
raise ValueError("--per-device-train-batch-size must be at least 1.")
|
| 85 |
+
if world_size < 1:
|
| 86 |
+
raise ValueError("world_size must be at least 1.")
|
| 87 |
+
|
| 88 |
+
resolved_gradient_accumulation_steps = (
|
| 89 |
+
gradient_accumulation_steps
|
| 90 |
+
if gradient_accumulation_steps > 0
|
| 91 |
+
else max(2, num_generations)
|
| 92 |
+
)
|
| 93 |
+
if resolved_gradient_accumulation_steps < 1:
|
| 94 |
+
raise ValueError("--gradient-accumulation-steps must be at least 1.")
|
| 95 |
+
|
| 96 |
+
effective_batch_size = (
|
| 97 |
+
per_device_train_batch_size
|
| 98 |
+
* resolved_gradient_accumulation_steps
|
| 99 |
+
* world_size
|
| 100 |
+
)
|
| 101 |
+
if effective_batch_size % num_generations:
|
| 102 |
+
raise ValueError(
|
| 103 |
+
"Invalid GRPO batch shape: "
|
| 104 |
+
"per_device_train_batch_size * gradient_accumulation_steps * world_size "
|
| 105 |
+
f"must be divisible by num_generations. Got "
|
| 106 |
+
f"{per_device_train_batch_size} * "
|
| 107 |
+
f"{resolved_gradient_accumulation_steps} * {world_size} = "
|
| 108 |
+
f"{effective_batch_size}, which is not divisible by {num_generations}."
|
| 109 |
+
)
|
| 110 |
+
return resolved_gradient_accumulation_steps, effective_batch_size
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _validate_vllm_config(*, use_vllm: bool, vllm_gpu_memory_utilization: float) -> None:
|
| 114 |
+
if not use_vllm:
|
| 115 |
+
return
|
| 116 |
+
if not 0.0 < vllm_gpu_memory_utilization <= 0.95:
|
| 117 |
+
raise ValueError(
|
| 118 |
+
"--vllm-gpu-memory-utilization must be in the interval (0.0, 0.95] "
|
| 119 |
+
"when --use-vllm is enabled."
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
def _configure_modal_cache_env() -> dict[str, str]:
|
| 124 |
values = {
|
| 125 |
"HF_HOME": str(HF_HOME_DIR),
|
|
|
|
| 423 |
resolved_difficulty = int(scenario_profile["difficulty"])
|
| 424 |
cache = ScenarioCache(SCENARIO_CACHE_DIR, settings=settings)
|
| 425 |
coverage = cache.assert_coverage(split=split, difficulty=resolved_difficulty)
|
| 426 |
+
entries = cache.validated_entries(split=split, difficulty=resolved_difficulty)
|
| 427 |
+
if not entries:
|
| 428 |
+
entries = cache.validated_entries(split=split)
|
| 429 |
+
if not entries:
|
| 430 |
+
raise RuntimeError(f"No validated scenario cache entries found for split={split!r}.")
|
| 431 |
+
sample_entry = entries[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
env = CybersecurityOwaspEnvironment()
|
| 434 |
try:
|
| 435 |
+
obs = env.reset(
|
| 436 |
+
seed=int(sample_entry["seed"]),
|
| 437 |
+
split=str(sample_entry["split"]),
|
| 438 |
+
difficulty=int(sample_entry["difficulty"]),
|
| 439 |
+
)
|
| 440 |
if not env.state.cache_hit:
|
| 441 |
raise RuntimeError("Scenario cache preflight reset did not hit cache.")
|
| 442 |
if env.state.metrics.get("scenario_compile_latency_ms", 0.0):
|
|
|
|
| 460 |
"scenario_cache_dir": str(SCENARIO_CACHE_DIR),
|
| 461 |
"scenario_cache_mode": "require",
|
| 462 |
"split": split,
|
| 463 |
+
"difficulty": "adaptive",
|
| 464 |
+
"initial_difficulty": resolved_difficulty,
|
| 465 |
"dataset_size": dataset_size,
|
| 466 |
+
"available_scenarios": len(cache.validated_entries(split=split)),
|
| 467 |
"coverage": coverage,
|
| 468 |
"sample_reset": sample,
|
| 469 |
}
|
|
|
|
| 528 |
trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
|
| 529 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 530 |
num_generations: int = 6,
|
| 531 |
+
per_device_train_batch_size: int = 1,
|
| 532 |
+
gradient_accumulation_steps: int = 0,
|
| 533 |
+
use_vllm: bool = False,
|
| 534 |
+
vllm_gpu_memory_utilization: float = 0.2,
|
| 535 |
+
trace_log_every: int = 5,
|
| 536 |
seed_start: int = 0,
|
| 537 |
git_sha: str = "nogit",
|
| 538 |
run_name: str = "",
|
|
|
|
| 548 |
|
| 549 |
model_name = _ensure_gemma4_model(model_name)
|
| 550 |
cache_env = _configure_modal_cache_env()
|
| 551 |
+
world_size = int(os.environ.get("WORLD_SIZE", "1") or "1")
|
| 552 |
+
(
|
| 553 |
+
resolved_gradient_accumulation_steps,
|
| 554 |
+
effective_train_batch_size,
|
| 555 |
+
) = _resolve_grpo_batch_config(
|
| 556 |
+
per_device_train_batch_size=per_device_train_batch_size,
|
| 557 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 558 |
+
num_generations=num_generations,
|
| 559 |
+
world_size=world_size,
|
| 560 |
+
)
|
| 561 |
+
_validate_vllm_config(
|
| 562 |
+
use_vllm=use_vllm,
|
| 563 |
+
vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
|
| 564 |
+
)
|
| 565 |
+
trace_log_every = max(0, int(trace_log_every))
|
| 566 |
|
| 567 |
import torch
|
| 568 |
from unsloth import FastVisionModel
|
|
|
|
| 592 |
log_trackio_metrics,
|
| 593 |
train_metric_aliases,
|
| 594 |
)
|
| 595 |
+
from training.grpo_curriculum import (
|
| 596 |
+
ScenarioGroupRegistry,
|
| 597 |
+
build_scenario_group_rows,
|
| 598 |
+
)
|
| 599 |
|
| 600 |
transformers_hub.TRANSFORMERS_CACHE = cache_env["HF_HUB_CACHE"]
|
| 601 |
|
|
|
|
| 657 |
split=split,
|
| 658 |
difficulty=int(scenario_profile["difficulty"]),
|
| 659 |
)
|
| 660 |
+
scenario_entries = scenario_cache.validated_entries(split=split)
|
| 661 |
+
scenario_registry = ScenarioGroupRegistry(
|
| 662 |
+
scenario_entries,
|
| 663 |
+
split=split,
|
| 664 |
+
initial_difficulty=int(scenario_profile["difficulty"]),
|
| 665 |
+
rng_seed=seed_start,
|
| 666 |
+
max_level=scenario_settings.curriculum.difficulty_bucket_count - 1,
|
| 667 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
training_prompt = (
|
| 670 |
"You are a defensive AppSec repair agent in the local CyberSecurity_OWASP "
|
|
|
|
| 676 |
)
|
| 677 |
|
| 678 |
dataset = Dataset.from_list(
|
| 679 |
+
build_scenario_group_rows(
|
| 680 |
+
dataset_size=dataset_size,
|
| 681 |
+
training_prompt=training_prompt,
|
| 682 |
+
seed_start=seed_start,
|
| 683 |
+
split=split,
|
| 684 |
+
difficulty=difficulty,
|
| 685 |
+
difficulty_policy="adaptive",
|
| 686 |
+
)
|
|
|
|
| 687 |
)
|
| 688 |
|
| 689 |
def _state_snapshot(env: CybersecurityOwaspEnvironment) -> dict[str, Any]:
|
|
|
|
| 694 |
"seed": state.seed,
|
| 695 |
"split": state.split,
|
| 696 |
"difficulty": state.difficulty,
|
| 697 |
+
"difficulty_tier": state.difficulty_tier,
|
| 698 |
"domain": state.domain,
|
| 699 |
"bug_family": state.bug_family,
|
| 700 |
+
"template_id": state.template_id,
|
| 701 |
"cache_hit": state.cache_hit,
|
| 702 |
"scenario_hash": state.scenario_hash,
|
| 703 |
"phase": state.phase,
|
|
|
|
| 716 |
self.done = False
|
| 717 |
self.success = False
|
| 718 |
self.invalid_actions = 0
|
| 719 |
+
self.scenario_group_id = -1
|
| 720 |
+
self.scenario_assignment: dict[str, Any] = {}
|
| 721 |
self.trace_messages: list[dict[str, str]] = []
|
| 722 |
self.trace_metadata: dict[str, Any] = {}
|
| 723 |
|
| 724 |
def reset(self, **kwargs) -> str:
|
| 725 |
+
group_id = int(kwargs.get("scenario_group_id", kwargs.get("seed", seed_start)))
|
| 726 |
+
assignment = scenario_registry.assignment_for(
|
| 727 |
+
scenario_group_id=group_id,
|
| 728 |
+
requested_seed=int(kwargs.get("seed", seed_start)),
|
| 729 |
+
requested_difficulty=int(kwargs.get("difficulty", difficulty)),
|
| 730 |
+
split=str(kwargs.get("split", split)),
|
| 731 |
+
difficulty_policy=str(kwargs.get("difficulty_policy", "adaptive")),
|
| 732 |
+
)
|
| 733 |
+
seed = int(assignment["seed"])
|
| 734 |
+
current_difficulty = int(assignment["difficulty"])
|
| 735 |
+
current_split = str(assignment["split"])
|
| 736 |
obs = self._env.reset(
|
| 737 |
seed=seed,
|
| 738 |
split=current_split,
|
| 739 |
difficulty=current_difficulty,
|
| 740 |
)
|
| 741 |
+
self.scenario_group_id = group_id
|
| 742 |
+
self.scenario_assignment = assignment
|
| 743 |
self.reward = 0.0
|
| 744 |
self.reward_breakdown = {}
|
| 745 |
self.done = bool(obs.done)
|
|
|
|
| 749 |
{
|
| 750 |
"role": "user",
|
| 751 |
"content": (
|
| 752 |
+
f"{training_prompt}\n\n"
|
| 753 |
+
f"{obs.scenario_prompt}\n\n"
|
| 754 |
+
f"Initial message: {obs.message}"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 755 |
),
|
| 756 |
}
|
| 757 |
]
|
| 758 |
self.trace_metadata = _state_snapshot(self._env)
|
| 759 |
+
self.trace_metadata.update(
|
| 760 |
+
{
|
| 761 |
+
"scenario_group_id": self.scenario_group_id,
|
| 762 |
+
"scenario_assignment": dict(self.scenario_assignment),
|
| 763 |
+
"scenario_prompt_length": len(obs.scenario_prompt),
|
| 764 |
+
}
|
| 765 |
+
)
|
| 766 |
+
return obs.scenario_prompt
|
| 767 |
|
| 768 |
def _step(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
|
| 769 |
if self.done:
|
|
|
|
| 798 |
"invalid_actions": self.invalid_actions,
|
| 799 |
"scenario_cache_hit": self._env.state.cache_hit,
|
| 800 |
"scenario_hash": self._env.state.scenario_hash,
|
| 801 |
+
"scenario_group_id": self.scenario_group_id,
|
| 802 |
+
"scenario_assignment": dict(self.scenario_assignment),
|
| 803 |
}
|
| 804 |
)
|
| 805 |
return obs.message
|
|
|
|
| 1024 |
)
|
| 1025 |
episode_records.append(record)
|
| 1026 |
|
| 1027 |
+
group_successes: dict[int, list[float]] = {}
|
| 1028 |
+
for env in environments:
|
| 1029 |
+
group_id = int(getattr(env, "scenario_group_id", -1))
|
| 1030 |
+
if group_id < 0:
|
| 1031 |
+
continue
|
| 1032 |
+
group_successes.setdefault(group_id, []).append(1.0 if getattr(env, "success", False) else 0.0)
|
| 1033 |
+
for group_id, successes in group_successes.items():
|
| 1034 |
+
scenario_registry.record_group_outcome(group_id, _mean(successes))
|
| 1035 |
+
|
| 1036 |
+
batch_fingerprints = [
|
| 1037 |
+
episode_trace_fingerprint(record)
|
| 1038 |
+
for record in episode_records
|
| 1039 |
+
]
|
| 1040 |
+
sampled_traces = []
|
| 1041 |
+
seen_this_batch: set[str] = set()
|
| 1042 |
+
duplicate_trace_suppressed_count = 0
|
| 1043 |
+
for index, (env, record, reward, fingerprint) in enumerate(
|
| 1044 |
+
zip(environments, episode_records, rewards, batch_fingerprints)
|
| 1045 |
+
):
|
| 1046 |
+
if fingerprint in seen_this_batch or fingerprint in logged_trace_fingerprints:
|
| 1047 |
+
duplicate_trace_suppressed_count += 1
|
| 1048 |
+
continue
|
| 1049 |
+
seen_this_batch.add(fingerprint)
|
| 1050 |
+
if len(sampled_traces) < 4:
|
| 1051 |
+
sampled_traces.append((index, env, record, reward, fingerprint))
|
| 1052 |
+
|
| 1053 |
+
should_log_trace_artifacts = trace_log_every > 0 and (
|
| 1054 |
+
trace_step["value"] == 1
|
| 1055 |
+
or trace_step["value"] % trace_log_every == 0
|
| 1056 |
+
)
|
| 1057 |
canonical_metrics = aggregate_episode_metrics(episode_records)
|
| 1058 |
metrics = {
|
| 1059 |
**canonical_metrics,
|
| 1060 |
**train_metric_aliases(canonical_metrics),
|
| 1061 |
+
**scenario_registry.metrics(
|
| 1062 |
+
episode_records,
|
| 1063 |
+
unique_trace_count=len(set(batch_fingerprints)),
|
| 1064 |
+
duplicate_trace_suppressed_count=duplicate_trace_suppressed_count,
|
| 1065 |
+
),
|
| 1066 |
}
|
| 1067 |
+
metrics["train/per_device_train_batch_size"] = float(per_device_train_batch_size)
|
| 1068 |
+
metrics["train/gradient_accumulation_steps"] = float(
|
| 1069 |
+
resolved_gradient_accumulation_steps
|
| 1070 |
+
)
|
| 1071 |
+
metrics["train/effective_train_batch_size"] = float(effective_train_batch_size)
|
| 1072 |
+
metrics["train/num_generations"] = float(num_generations)
|
| 1073 |
+
metrics["train/use_vllm"] = float(bool(use_vllm))
|
| 1074 |
+
metrics["train/vllm_gpu_memory_utilization"] = (
|
| 1075 |
+
float(vllm_gpu_memory_utilization) if use_vllm else 0.0
|
| 1076 |
+
)
|
| 1077 |
+
metrics["train/trace_log_every"] = float(trace_log_every)
|
| 1078 |
+
metrics["train/trace_artifacts_logged"] = float(should_log_trace_artifacts)
|
| 1079 |
if rewards:
|
| 1080 |
metrics["train/reward_mean"] = _mean(rewards)
|
| 1081 |
metrics["train/reward_std"] = statistics.pstdev(rewards) if len(rewards) > 1 else 0.0
|
|
|
|
| 1085 |
except Exception as exc:
|
| 1086 |
print(f"Trackio metric logging skipped: {exc!r}")
|
| 1087 |
|
| 1088 |
+
if should_log_trace_artifacts and sampled_traces:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1089 |
try:
|
| 1090 |
+
log_trace_table(
|
| 1091 |
+
[record for _, _, record, _, _ in sampled_traces],
|
| 1092 |
+
table_name="sample_traces",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1093 |
step=trace_step["value"],
|
| 1094 |
)
|
| 1095 |
except Exception as exc:
|
| 1096 |
+
print(f"Trackio sample trace table logging skipped: {exc!r}")
|
| 1097 |
+
|
| 1098 |
+
for index, env, _record, reward, fingerprint in sampled_traces:
|
| 1099 |
+
logged_trace_fingerprints.add(fingerprint)
|
| 1100 |
+
messages = list(getattr(env, "trace_messages", []))
|
| 1101 |
+
if index < len(completions):
|
| 1102 |
+
completion_text = _completion_to_text(completions[index])
|
| 1103 |
+
if completion_text:
|
| 1104 |
+
messages.append(
|
| 1105 |
+
{
|
| 1106 |
+
"role": "assistant",
|
| 1107 |
+
"content": f"Raw generated completion:\n{completion_text}",
|
| 1108 |
+
}
|
| 1109 |
+
)
|
| 1110 |
+
metadata = dict(getattr(env, "trace_metadata", {}))
|
| 1111 |
+
metadata.update(
|
| 1112 |
+
{
|
| 1113 |
+
"sample_index": index,
|
| 1114 |
+
"reward": reward,
|
| 1115 |
+
"trace_step": trace_step["value"],
|
| 1116 |
+
"trace_fingerprint": fingerprint,
|
| 1117 |
+
"num_generations": num_generations,
|
| 1118 |
+
"run_name": run_name,
|
| 1119 |
+
}
|
| 1120 |
+
)
|
| 1121 |
+
try:
|
| 1122 |
+
trackio.log(
|
| 1123 |
+
{
|
| 1124 |
+
f"cybersecurity_owasp_trace/sample_{index}": trackio.Trace(
|
| 1125 |
+
messages=messages,
|
| 1126 |
+
metadata=metadata,
|
| 1127 |
+
)
|
| 1128 |
+
},
|
| 1129 |
+
step=trace_step["value"],
|
| 1130 |
+
)
|
| 1131 |
+
except Exception as exc:
|
| 1132 |
+
print(f"Trackio trace logging skipped: {exc!r}")
|
| 1133 |
+
elif sampled_traces:
|
| 1134 |
+
print(
|
| 1135 |
+
"Trackio trace artifacts throttled at reward callback "
|
| 1136 |
+
f"{trace_step['value']}; set --trace-log-every 1 for every callback "
|
| 1137 |
+
"or 0 to disable trace artifacts."
|
| 1138 |
+
)
|
| 1139 |
|
| 1140 |
if rewards:
|
| 1141 |
print(
|
|
|
|
| 1210 |
print(f"Unsloth cache: {cache_env['UNSLOTH_CACHE_DIR']}")
|
| 1211 |
print(f"Triton cache: {cache_env['TRITON_CACHE_DIR']}")
|
| 1212 |
print(f"Hub push enabled: {push_to_hub}")
|
| 1213 |
+
print(
|
| 1214 |
+
"GRPO throughput config: "
|
| 1215 |
+
f"per_device_train_batch_size={per_device_train_batch_size}, "
|
| 1216 |
+
f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, "
|
| 1217 |
+
f"num_generations={num_generations}, "
|
| 1218 |
+
f"world_size={world_size}, "
|
| 1219 |
+
f"effective_train_batch_size={effective_train_batch_size}"
|
| 1220 |
+
)
|
| 1221 |
+
print(
|
| 1222 |
+
"Generation acceleration config: "
|
| 1223 |
+
f"use_vllm={use_vllm}, "
|
| 1224 |
+
f"vllm_gpu_memory_utilization={vllm_gpu_memory_utilization}, "
|
| 1225 |
+
f"trace_log_every={trace_log_every}"
|
| 1226 |
+
)
|
| 1227 |
|
| 1228 |
expected_model_cache = _hf_model_cache_path(model_name)
|
| 1229 |
cache_hit = expected_model_cache.exists()
|
|
|
|
| 1253 |
|
| 1254 |
print(f"Loading model with Unsloth from_pretrained: {model_name}")
|
| 1255 |
model_api = FastVisionModel
|
| 1256 |
+
model_load_values = {
|
| 1257 |
+
"model_name": model_name,
|
| 1258 |
+
"max_seq_length": max_seq_length,
|
| 1259 |
+
"load_in_4bit": False,
|
| 1260 |
+
"fast_inference": use_vllm,
|
| 1261 |
+
"gpu_memory_utilization": vllm_gpu_memory_utilization if use_vllm else None,
|
| 1262 |
+
"cache_dir": str(HF_HUB_CACHE_DIR),
|
| 1263 |
+
"token": hf_token,
|
| 1264 |
+
}
|
| 1265 |
+
from_pretrained_parameters = inspect.signature(model_api.from_pretrained).parameters
|
| 1266 |
+
from_pretrained_accepts_kwargs = any(
|
| 1267 |
+
parameter.kind == inspect.Parameter.VAR_KEYWORD
|
| 1268 |
+
for parameter in from_pretrained_parameters.values()
|
| 1269 |
+
)
|
| 1270 |
+
skipped_model_load_keys = sorted(
|
| 1271 |
+
key
|
| 1272 |
+
for key, value in model_load_values.items()
|
| 1273 |
+
if value is not None
|
| 1274 |
+
and key not in from_pretrained_parameters
|
| 1275 |
+
and not from_pretrained_accepts_kwargs
|
| 1276 |
+
)
|
| 1277 |
+
if skipped_model_load_keys:
|
| 1278 |
+
print(f"Skipping unsupported from_pretrained keys: {skipped_model_load_keys}")
|
| 1279 |
model, tokenizer = model_api.from_pretrained(
|
| 1280 |
+
**{
|
| 1281 |
+
key: value
|
| 1282 |
+
for key, value in model_load_values.items()
|
| 1283 |
+
if value is not None
|
| 1284 |
+
and (key in from_pretrained_parameters or from_pretrained_accepts_kwargs)
|
| 1285 |
+
}
|
| 1286 |
)
|
| 1287 |
print("Model load complete.")
|
| 1288 |
cache_volume.commit()
|
|
|
|
| 1324 |
"lr_scheduler_type": "linear",
|
| 1325 |
"optim": "adamw_8bit",
|
| 1326 |
"logging_steps": 1,
|
| 1327 |
+
"per_device_train_batch_size": per_device_train_batch_size,
|
| 1328 |
+
"gradient_accumulation_steps": resolved_gradient_accumulation_steps,
|
| 1329 |
"num_generations": num_generations,
|
| 1330 |
"max_prompt_length": max_seq_length,
|
| 1331 |
"max_completion_length": max_completion_length,
|
|
|
|
| 1342 |
"hub_strategy": "every_save",
|
| 1343 |
"gradient_checkpointing": True,
|
| 1344 |
"gradient_checkpointing_kwargs": {"use_reentrant": False},
|
| 1345 |
+
"use_vllm": use_vllm,
|
| 1346 |
+
"vllm_mode": "colocate",
|
| 1347 |
+
"vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
|
| 1348 |
"epsilon": 0.2,
|
| 1349 |
"epsilon_high": 0.28,
|
| 1350 |
"delta": 1.5,
|
| 1351 |
"loss_type": "bnpo",
|
| 1352 |
+
"mask_truncated_completions": False,
|
| 1353 |
}
|
| 1354 |
grpo_config_parameters = set(inspect.signature(GRPOConfig).parameters)
|
| 1355 |
skipped_config_keys = sorted(set(grpo_config_values) - grpo_config_parameters)
|
|
|
|
| 1439 |
"model_name": model_name,
|
| 1440 |
"max_completion_length": max_completion_length,
|
| 1441 |
"num_generations": num_generations,
|
| 1442 |
+
"per_device_train_batch_size": per_device_train_batch_size,
|
| 1443 |
+
"gradient_accumulation_steps": resolved_gradient_accumulation_steps,
|
| 1444 |
+
"effective_train_batch_size": effective_train_batch_size,
|
| 1445 |
+
"use_vllm": int(bool(use_vllm)),
|
| 1446 |
+
"vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
|
| 1447 |
+
"trace_log_every": trace_log_every,
|
| 1448 |
"source_mode": source_mode,
|
| 1449 |
"repo_url": repo_url,
|
| 1450 |
"repo_branch": repo_branch,
|
|
|
|
| 1470 |
trackio_space_id: str = "Humanlearning/CyberSecurity_OWASP-trackio",
|
| 1471 |
trackio_project: str = "CyberSecurity_OWASP-grpo",
|
| 1472 |
num_generations: int = 6,
|
| 1473 |
+
per_device_train_batch_size: int = 1,
|
| 1474 |
+
gradient_accumulation_steps: int = 0,
|
| 1475 |
+
use_vllm: bool = False,
|
| 1476 |
+
vllm_gpu_memory_utilization: float = 0.2,
|
| 1477 |
+
trace_log_every: int = 5,
|
| 1478 |
seed_start: int = 0,
|
| 1479 |
git_sha: str = "nogit",
|
| 1480 |
source_mode: str = "local",
|
|
|
|
| 1508 |
if mode != "train":
|
| 1509 |
raise ValueError("mode must be 'prepare-cache', 'train', or 'config'")
|
| 1510 |
|
| 1511 |
+
(
|
| 1512 |
+
resolved_gradient_accumulation_steps,
|
| 1513 |
+
effective_train_batch_size,
|
| 1514 |
+
) = _resolve_grpo_batch_config(
|
| 1515 |
+
per_device_train_batch_size=per_device_train_batch_size,
|
| 1516 |
+
gradient_accumulation_steps=gradient_accumulation_steps,
|
| 1517 |
+
num_generations=num_generations,
|
| 1518 |
+
world_size=1,
|
| 1519 |
+
)
|
| 1520 |
+
_validate_vllm_config(
|
| 1521 |
+
use_vllm=use_vllm,
|
| 1522 |
+
vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
|
| 1523 |
+
)
|
| 1524 |
+
trace_log_every = max(0, int(trace_log_every))
|
| 1525 |
+
|
| 1526 |
trackio_space_id = trackio_space_id or os.environ.get(
|
| 1527 |
"TRACKIO_SPACE_ID",
|
| 1528 |
"Humanlearning/CyberSecurity_OWASP-trackio",
|
|
|
|
| 1588 |
print(f"Hub push enabled: {push_to_hub}")
|
| 1589 |
print(f"Model cache volume: {CACHE_VOLUME_NAME}")
|
| 1590 |
print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
|
| 1591 |
+
print(
|
| 1592 |
+
"GRPO throughput config: "
|
| 1593 |
+
f"per_device_train_batch_size={per_device_train_batch_size}, "
|
| 1594 |
+
f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, "
|
| 1595 |
+
f"num_generations={num_generations}, "
|
| 1596 |
+
f"effective_train_batch_size={effective_train_batch_size}"
|
| 1597 |
+
)
|
| 1598 |
+
print(
|
| 1599 |
+
"Generation acceleration config: "
|
| 1600 |
+
f"use_vllm={use_vllm}, "
|
| 1601 |
+
f"vllm_gpu_memory_utilization={vllm_gpu_memory_utilization}, "
|
| 1602 |
+
f"trace_log_every={trace_log_every}"
|
| 1603 |
+
)
|
| 1604 |
print("Launch phases:")
|
| 1605 |
print(
|
| 1606 |
"1. Modal image build/validation: happens before remote Python logs; "
|
|
|
|
| 1630 |
trackio_space_id=trackio_space_id,
|
| 1631 |
trackio_project=trackio_project,
|
| 1632 |
num_generations=num_generations,
|
| 1633 |
+
per_device_train_batch_size=per_device_train_batch_size,
|
| 1634 |
+
gradient_accumulation_steps=resolved_gradient_accumulation_steps,
|
| 1635 |
+
use_vllm=use_vllm,
|
| 1636 |
+
vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
|
| 1637 |
+
trace_log_every=trace_log_every,
|
| 1638 |
seed_start=seed_start,
|
| 1639 |
git_sha=git_sha,
|
| 1640 |
run_name=run_name,
|
server/CyberSecurity_OWASP_environment.py
CHANGED
|
@@ -373,13 +373,15 @@ class CybersecurityOwaspEnvironment(
|
|
| 373 |
visible_test_result: str | None = None,
|
| 374 |
done_reason: str | None = None,
|
| 375 |
) -> CyberSecurityOWASPObservation:
|
|
|
|
| 376 |
return CyberSecurityOWASPObservation(
|
| 377 |
phase=self._state.phase,
|
| 378 |
message=message,
|
| 379 |
task_brief=self._task_brief,
|
|
|
|
| 380 |
visible_policy_hint=self._visible_policy_hint,
|
| 381 |
workspace_summary=self._workspace_summary,
|
| 382 |
-
available_actions=
|
| 383 |
last_tool_result=message,
|
| 384 |
last_action_valid=valid,
|
| 385 |
last_action_error=error,
|
|
@@ -388,7 +390,60 @@ class CybersecurityOwaspEnvironment(
|
|
| 388 |
done_reason=done_reason,
|
| 389 |
done=self._state.done,
|
| 390 |
reward=reward,
|
| 391 |
-
metadata={
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 392 |
)
|
| 393 |
|
| 394 |
def _finalize_terminal_episode(self, observation_record: dict[str, Any]) -> None:
|
|
|
|
| 373 |
visible_test_result: str | None = None,
|
| 374 |
done_reason: str | None = None,
|
| 375 |
) -> CyberSecurityOWASPObservation:
|
| 376 |
+
available_actions = sorted(ALLOWED_TOOLS[self._state.phase])
|
| 377 |
return CyberSecurityOWASPObservation(
|
| 378 |
phase=self._state.phase,
|
| 379 |
message=message,
|
| 380 |
task_brief=self._task_brief,
|
| 381 |
+
scenario_prompt=self._scenario_prompt(available_actions),
|
| 382 |
visible_policy_hint=self._visible_policy_hint,
|
| 383 |
workspace_summary=self._workspace_summary,
|
| 384 |
+
available_actions=available_actions,
|
| 385 |
last_tool_result=message,
|
| 386 |
last_action_valid=valid,
|
| 387 |
last_action_error=error,
|
|
|
|
| 390 |
done_reason=done_reason,
|
| 391 |
done=self._state.done,
|
| 392 |
reward=reward,
|
| 393 |
+
metadata={
|
| 394 |
+
"episode_id": self._state.episode_id,
|
| 395 |
+
"step_count": self._state.step_count,
|
| 396 |
+
"seed": self._state.seed,
|
| 397 |
+
"split": self._state.split,
|
| 398 |
+
"difficulty": self._state.difficulty,
|
| 399 |
+
"difficulty_tier": self._state.difficulty_tier,
|
| 400 |
+
"domain": self._state.domain,
|
| 401 |
+
"bug_family": self._state.bug_family,
|
| 402 |
+
"template_id": self._state.template_id,
|
| 403 |
+
"scenario_hash": self._state.scenario_hash,
|
| 404 |
+
},
|
| 405 |
+
)
|
| 406 |
+
|
| 407 |
+
def _scenario_prompt(self, available_actions: list[str]) -> str:
    """Render the agent-facing scenario description as a header plus JSON.

    The payload combines the task brief, selected scenario fields read from
    ``self._state``, a reduced view of ``self._visible_policy_hint``, the
    workspace summary, the supplied ``available_actions`` and a fixed list
    of instructions.

    Args:
        available_actions: Tool names to advertise in the prompt.

    Returns:
        The header line ``CyberSecurity_OWASP scenario prompt:`` followed by
        the payload serialized with two-space indentation and sorted keys.
    """
    hint = self._visible_policy_hint
    state = self._state
    aliases = hint.get("fixture_aliases", {})

    # Iterating the alias containers yields their keys when they are
    # mappings, so only the alias names are rendered, never mapped values.
    policy_view = {
        "domain": hint.get("domain", state.domain),
        "policy_rules": list(hint.get("policy_rules", [])),
        "public_routes": list(hint.get("public_routes", [])),
        "fixture_aliases": {
            "users": sorted(map(str, aliases.get("users", {}))),
            "resources": sorted(map(str, aliases.get("resources", {}))),
        },
    }

    scenario_fields = (
        "task_id",
        "seed",
        "split",
        "difficulty",
        "difficulty_tier",
        "domain",
        "bug_family",
        "template_id",
        "scenario_hash",
    )
    payload = {
        "environment": "CyberSecurity_OWASP",
        "task": self._task_brief,
        "scenario": {name: getattr(state, name) for name in scenario_fields},
        "visible_policy_hint": policy_view,
        "workspace_summary": self._workspace_summary,
        "available_actions": available_actions,
        "instructions": [
            "Use only the local generated application and the listed tools.",
            "Discover the authorization defect with local evidence before patching.",
            "Preserve legitimate owner/admin flows and intentionally public routes.",
            "Submit exactly one secure, policy-aligned fix when ready.",
        ],
    }
    rendered = json.dumps(payload, indent=2, sort_keys=True)
    return "CyberSecurity_OWASP scenario prompt:\n" + rendered
|
| 448 |
|
| 449 |
def _finalize_terminal_episode(self, observation_record: dict[str, Any]) -> None:
|
server/scenario_cache.py
CHANGED
|
@@ -265,6 +265,29 @@ class ScenarioCache:
|
|
| 265 |
counts[split][difficulty] = counts[split].get(difficulty, 0) + 1
|
| 266 |
return {"root": str(self.root), "counts": counts, "entries": len(self._manifest_entries())}
|
| 267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
def assert_coverage(self, *, split: str, difficulty: int | None = None) -> dict[str, Any]:
|
| 269 |
coverage = self.coverage()
|
| 270 |
required = self.settings.curriculum.minimum_for_split(split)
|
|
@@ -490,6 +513,8 @@ def _manifest_entry(
|
|
| 490 |
"seed": int(scenario_record.get("seed", 0)),
|
| 491 |
"split": str(scenario_record.get("split", "train")),
|
| 492 |
"difficulty": int(scenario_record.get("difficulty", 0)),
|
|
|
|
|
|
|
| 493 |
"scenario_hash": str(metadata.get("scenario_hash", "")),
|
| 494 |
"cache_key": metadata.get("cache_key", {}),
|
| 495 |
"validated": bool(metadata.get("validated", False)),
|
|
|
|
| 265 |
counts[split][difficulty] = counts[split].get(difficulty, 0) + 1
|
| 266 |
return {"root": str(self.root), "counts": counts, "entries": len(self._manifest_entries())}
|
| 267 |
|
| 268 |
+
def validated_entries(
    self,
    *,
    split: str | None = None,
    difficulty: int | None = None,
) -> list[dict[str, Any]]:
    """Return copies of manifest entries whose ``validated`` flag is ``True``.

    Args:
        split: When given, keep only entries whose ``split`` equals it.
        difficulty: When given, keep only entries whose ``difficulty``
            (read as an int, ``-1`` if absent) equals it.

    Returns:
        Shallow copies of the matching entries ordered by split,
        difficulty, seed and scenario hash.
    """

    def _order(item: dict[str, Any]) -> tuple[str, int, int, str]:
        return (
            str(item.get("split", "")),
            int(item.get("difficulty", 0)),
            int(item.get("seed", 0)),
            str(item.get("scenario_hash", "")),
        )

    selected: list[dict[str, Any]] = []
    for entry in self._manifest_entries():
        # The flag must be exactly True; other truthy values do not count.
        if entry.get("validated") is not True:
            continue
        if split is not None and entry.get("split") != split:
            continue
        if difficulty is not None and int(entry.get("difficulty", -1)) != int(difficulty):
            continue
        selected.append(dict(entry))

    selected.sort(key=_order)
    return selected
|
| 290 |
+
|
| 291 |
def assert_coverage(self, *, split: str, difficulty: int | None = None) -> dict[str, Any]:
|
| 292 |
coverage = self.coverage()
|
| 293 |
required = self.settings.curriculum.minimum_for_split(split)
|
|
|
|
| 513 |
"seed": int(scenario_record.get("seed", 0)),
|
| 514 |
"split": str(scenario_record.get("split", "train")),
|
| 515 |
"difficulty": int(scenario_record.get("difficulty", 0)),
|
| 516 |
+
"template_id": str(scenario_record.get("template_id", "")),
|
| 517 |
+
"bug_family": str(scenario_record.get("bug_family", "")),
|
| 518 |
"scenario_hash": str(metadata.get("scenario_hash", "")),
|
| 519 |
"cache_key": metadata.get("cache_key", {}),
|
| 520 |
"validated": bool(metadata.get("validated", False)),
|
tests/test_closed_loop_runtime.py
CHANGED
|
@@ -49,6 +49,34 @@ def test_reset_records_scenario_family_and_partial_observability():
|
|
| 49 |
assert "injected bug" not in serialized_hint
|
| 50 |
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
def test_authz_oracle_fails_vulnerable_app_and_passes_secure_patch():
|
| 53 |
env = make_env(72)
|
| 54 |
oracle = AuthzOracle()
|
|
|
|
| 49 |
assert "injected bug" not in serialized_hint
|
| 50 |
|
| 51 |
|
| 52 |
+
def test_reset_returns_visible_scenario_prompt_without_hidden_identifiers():
    """The reset prompt shows scenario metadata but none of the hidden fact values."""
    env = make_env(75)
    observation = env.reset(seed=75, split="train", difficulty=0)
    prompt = observation.scenario_prompt
    hidden_facts = dict(env.state.hidden_facts)

    expected_fragments = (
        "CyberSecurity_OWASP scenario prompt",
        "available_actions",
        str(env.state.seed),
        env.state.scenario_hash,
        env.state.template_id,
    )
    for fragment in expected_fragments:
        assert fragment in prompt

    hidden_keys = (
        "owner_user_id",
        "intruder_user_id",
        "admin_user_id",
        "owner_invoice_id",
        "other_invoice_id",
        "foreign_invoice_id",
        "tenant_a",
        "tenant_b",
    )
    for key in hidden_keys:
        secret = str(hidden_facts.get(key, ""))
        # Missing or empty hidden facts cannot leak, so only non-empty values are checked.
        if secret:
            assert secret not in prompt

    lowered = prompt.lower()
    assert "hidden_tests" not in lowered
    assert "oracle" not in lowered
|
| 78 |
+
|
| 79 |
+
|
| 80 |
def test_authz_oracle_fails_vulnerable_app_and_passes_secure_patch():
|
| 81 |
env = make_env(72)
|
| 82 |
oracle = AuthzOracle()
|
tests/test_grpo_curriculum.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from training.grpo_curriculum import (
|
| 2 |
+
AdaptiveDifficultyCurriculum,
|
| 3 |
+
ScenarioGroupRegistry,
|
| 4 |
+
build_scenario_group_rows,
|
| 5 |
+
)
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def _entries():
    """Build three validated ``train`` manifest entries used as registry fixtures."""
    rows = (
        (10, 0, "bola_idor", "hash-a"),
        (20, 1, "bfla", "hash-b"),
        (30, 1, "tenant_leak", "hash-c"),
    )
    return [
        {
            "seed": seed,
            "split": "train",
            "difficulty": difficulty,
            "template_id": "fastapi_basic",
            "bug_family": bug_family,
            "scenario_hash": scenario_hash,
            "validated": True,
        }
        for seed, difficulty, bug_family, scenario_hash in rows
    ]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def test_scenario_group_reuses_assignment_for_all_generations():
    """Two lookups for the same scenario group id return equal assignments."""
    registry = ScenarioGroupRegistry(
        _entries(),
        split="train",
        initial_difficulty=0,
        rng_seed=1,
        max_level=1,
    )

    lookups = [
        registry.assignment_for(scenario_group_id=101, difficulty_policy="adaptive")
        for _ in range(2)
    ]

    assert lookups[0] == lookups[1]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def test_different_scenario_groups_use_different_cached_scenarios_when_available():
    """Distinct group ids with distinct requested seeds map to different scenario hashes."""
    registry = ScenarioGroupRegistry(
        _entries(),
        split="train",
        initial_difficulty=1,
        rng_seed=3,
        max_level=1,
    )
    shared_request = {
        "requested_difficulty": 1,
        "split": "train",
        "difficulty_policy": "fixed",
    }

    first = registry.assignment_for(
        scenario_group_id=201,
        requested_seed=20,
        **shared_request,
    )
    second = registry.assignment_for(
        scenario_group_id=202,
        requested_seed=30,
        **shared_request,
    )

    assert first["scenario_hash"] != second["scenario_hash"]
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def test_fixed_assignment_uses_dataset_seed_and_difficulty():
    """A fixed-policy request resolves to the entry with the requested seed and difficulty."""
    registry = ScenarioGroupRegistry(
        _entries(),
        split="train",
        initial_difficulty=0,
        rng_seed=1,
        max_level=1,
    )

    assignment = registry.assignment_for(
        scenario_group_id=301,
        requested_seed=20,
        requested_difficulty=1,
        split="train",
        difficulty_policy="fixed",
    )

    expected = (
        ("seed", 20),
        ("difficulty", 1),
        ("scenario_hash", "hash-b"),
    )
    for key, value in expected:
        assert assignment[key] == value
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def test_adaptive_curriculum_promotes_and_demotes_at_thresholds():
    """Fifty successes raise the level by one; fifty failures lower it by one."""

    def _level_after(start_level, outcome):
        # Feed 50 identical outcomes at the starting level and report where we end up.
        curriculum = AdaptiveDifficultyCurriculum(
            min_level=0,
            max_level=2,
            current_level=start_level,
            promote_after=50,
        )
        for _ in range(50):
            curriculum.update(start_level, outcome)
        return curriculum.current_level

    assert _level_after(0, True) == 1
    assert _level_after(1, False) == 0
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def test_build_scenario_group_rows_include_grpo_group_columns():
    """Generated rows carry consecutive group ids, the policy name and the prompt text."""
    rows = build_scenario_group_rows(
        dataset_size=2,
        training_prompt="repair local app",
        seed_start=7,
        split="train",
        difficulty=1,
    )
    first_row, second_row = rows[0], rows[1]

    assert first_row["scenario_group_id"] == 7
    assert second_row["scenario_group_id"] == 8
    assert first_row["difficulty_policy"] == "adaptive"
    assert first_row["prompt"][0]["content"] == "repair local app"
|
tests/test_trackio_utils.py
CHANGED
|
@@ -102,6 +102,8 @@ def test_trace_fingerprint_ignores_episode_id_but_tracks_action_changes():
|
|
| 102 |
"scenario/split": "train",
|
| 103 |
"scenario/difficulty": 0,
|
| 104 |
"scenario/bug_type": "bola_idor",
|
|
|
|
|
|
|
| 105 |
"action_history": [
|
| 106 |
{
|
| 107 |
"tool_name": "read_file",
|
|
@@ -113,11 +115,17 @@ def test_trace_fingerprint_ignores_episode_id_but_tracks_action_changes():
|
|
| 113 |
}
|
| 114 |
same_trace = dict(base_record)
|
| 115 |
same_trace["episode_id"] = "episode-b"
|
|
|
|
|
|
|
| 116 |
changed_trace = dict(base_record)
|
| 117 |
changed_trace["action_history"] = [
|
| 118 |
*base_record["action_history"],
|
| 119 |
{"tool_name": "submit_fix", "arguments": {}},
|
| 120 |
]
|
|
|
|
|
|
|
| 121 |
|
| 122 |
assert episode_trace_fingerprint(base_record) == episode_trace_fingerprint(same_trace)
|
|
|
|
| 123 |
assert episode_trace_fingerprint(base_record) != episode_trace_fingerprint(changed_trace)
|
|
|
|
|
|
| 102 |
"scenario/split": "train",
|
| 103 |
"scenario/difficulty": 0,
|
| 104 |
"scenario/bug_type": "bola_idor",
|
| 105 |
+
"scenario/template_id": "fastapi_basic",
|
| 106 |
+
"scenario_hash": "scenario-a",
|
| 107 |
"action_history": [
|
| 108 |
{
|
| 109 |
"tool_name": "read_file",
|
|
|
|
| 115 |
}
|
| 116 |
same_trace = dict(base_record)
|
| 117 |
same_trace["episode_id"] = "episode-b"
|
| 118 |
+
token_only_reward_change = dict(base_record)
|
| 119 |
+
token_only_reward_change["reward_total"] = -0.25
|
| 120 |
changed_trace = dict(base_record)
|
| 121 |
changed_trace["action_history"] = [
|
| 122 |
*base_record["action_history"],
|
| 123 |
{"tool_name": "submit_fix", "arguments": {}},
|
| 124 |
]
|
| 125 |
+
different_scenario = dict(base_record)
|
| 126 |
+
different_scenario["scenario_hash"] = "scenario-b"
|
| 127 |
|
| 128 |
assert episode_trace_fingerprint(base_record) == episode_trace_fingerprint(same_trace)
|
| 129 |
+
assert episode_trace_fingerprint(base_record) == episode_trace_fingerprint(token_only_reward_change)
|
| 130 |
assert episode_trace_fingerprint(base_record) != episode_trace_fingerprint(changed_trace)
|
| 131 |
+
assert episode_trace_fingerprint(base_record) != episode_trace_fingerprint(different_scenario)
|
training/grpo_curriculum.py
ADDED
|
@@ -0,0 +1,260 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Scenario grouping and adaptive curriculum helpers for GRPO training."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
import random
|
| 6 |
+
import threading
|
| 7 |
+
from collections.abc import Iterable, Mapping, Sequence
|
| 8 |
+
from dataclasses import dataclass, field
|
| 9 |
+
from typing import Any
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@dataclass
class AdaptiveDifficultyCurriculum:
    """Track per-difficulty success and move a current level up or down.

    Each ``update`` bumps an observation count and an exponential moving
    average (EMA) of success for the reported difficulty. Once the current
    level has at least ``promote_after`` observations, an EMA at or above
    ``promote_threshold`` raises the level and an EMA at or below
    ``demote_threshold`` lowers it, always clamped to
    ``[min_level, max_level]``.
    """

    min_level: int = 0
    max_level: int = 3
    current_level: int = 0
    promote_after: int = 50
    promote_threshold: float = 0.70
    demote_threshold: float = 0.35
    ema_alpha: float = 0.10
    rng_seed: int = 0
    counts: dict[int, int] = field(default_factory=dict)
    ema_success: dict[int, float] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Coerce the bounds to int, clamp the start level and seed the private RNG."""
        self.min_level = int(self.min_level)
        self.max_level = int(self.max_level)
        capped = min(int(self.current_level), self.max_level)
        self.current_level = max(self.min_level, capped)
        self._rng = random.Random(int(self.rng_seed))

    def sample_difficulty(self, available_difficulties: Iterable[int]) -> int:
        """Draw a difficulty near the current level from those available.

        The level below, the current level and the level above carry
        weights 0.20, 0.65 and 0.15; weights of candidates that clamp onto
        the same level are summed. Candidates not in
        ``available_difficulties`` are dropped. When none remain, the
        available level closest to the current one is returned (ties go to
        the lower level) without consuming randomness.

        Raises:
            ValueError: If ``available_difficulties`` is empty.
        """
        pool = {int(item) for item in available_difficulties}
        if not pool:
            raise ValueError("No cached difficulties are available for GRPO curriculum sampling.")

        here = self.current_level
        neighbourhood = (
            (max(self.min_level, here - 1), 0.20),
            (here, 0.65),
            (min(self.max_level, here + 1), 0.15),
        )
        mass: dict[int, float] = {}
        for level, weight in neighbourhood:
            if level not in pool:
                continue
            mass[level] = mass.get(level, 0.0) + weight

        if mass:
            drawn = self._rng.choices(list(mass), weights=list(mass.values()), k=1)
            return int(drawn[0])
        return min(pool, key=lambda level: (abs(level - here), level))

    def update(self, difficulty: int, success: float | bool) -> dict[str, Any]:
        """Record one outcome for ``difficulty`` and return a fresh snapshot.

        ``success`` is converted to float and clamped to ``[0, 1]`` before
        it is blended into the EMA, which starts from 0.0 for a level that
        has not been seen. The current level is only reconsidered when the
        outcome belongs to it and its count has reached ``promote_after``.
        """
        level = int(difficulty)
        observed = max(0.0, min(1.0, float(success)))
        self.counts[level] = self.counts.get(level, 0) + 1
        previous = self.ema_success.get(level, 0.0)
        self.ema_success[level] = (1.0 - self.ema_alpha) * previous + self.ema_alpha * observed

        ready = level == self.current_level and self.counts[level] >= self.promote_after
        if ready:
            smoothed = self.ema_success[level]
            if smoothed >= self.promote_threshold:
                self.current_level = min(self.max_level, self.current_level + 1)
            elif smoothed <= self.demote_threshold:
                self.current_level = max(self.min_level, self.current_level - 1)
        return self.snapshot()

    def snapshot(self) -> dict[str, Any]:
        """Return the current level plus counts and EMAs keyed by stringified level."""
        ordered_counts = {str(key): self.counts[key] for key in sorted(self.counts)}
        ordered_ema = {str(key): self.ema_success[key] for key in sorted(self.ema_success)}
        return {
            "current_level": self.current_level,
            "counts": ordered_counts,
            "ema_success": ordered_ema,
            "current_level_ema_success": self.ema_success.get(self.current_level, 0.0),
        }
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class ScenarioGroupRegistry:
|
| 79 |
+
"""Assign each GRPO group to exactly one cached scenario."""
|
| 80 |
+
|
| 81 |
+
def __init__(
|
| 82 |
+
self,
|
| 83 |
+
entries: Sequence[Mapping[str, Any]],
|
| 84 |
+
*,
|
| 85 |
+
split: str = "train",
|
| 86 |
+
initial_difficulty: int = 0,
|
| 87 |
+
rng_seed: int = 0,
|
| 88 |
+
max_level: int | None = None,
|
| 89 |
+
) -> None:
|
| 90 |
+
self.split = split
|
| 91 |
+
self._rng = random.Random(int(rng_seed))
|
| 92 |
+
self._lock = threading.Lock()
|
| 93 |
+
self._assignments: dict[int, dict[str, Any]] = {}
|
| 94 |
+
self._completed_groups: set[int] = set()
|
| 95 |
+
self._entries_by_difficulty: dict[int, list[dict[str, Any]]] = {}
|
| 96 |
+
self._cursors: dict[int, int] = {}
|
| 97 |
+
|
| 98 |
+
for entry in entries:
|
| 99 |
+
if entry.get("validated") is not True or entry.get("split") != split:
|
| 100 |
+
continue
|
| 101 |
+
difficulty = int(entry.get("difficulty", 0))
|
| 102 |
+
self._entries_by_difficulty.setdefault(difficulty, []).append(dict(entry))
|
| 103 |
+
|
| 104 |
+
for difficulty, items in self._entries_by_difficulty.items():
|
| 105 |
+
items.sort(key=lambda item: (int(item.get("seed", 0)), str(item.get("scenario_hash", ""))))
|
| 106 |
+
self._rng.shuffle(items)
|
| 107 |
+
self._cursors[difficulty] = 0
|
| 108 |
+
|
| 109 |
+
if not self._entries_by_difficulty:
|
| 110 |
+
raise ValueError(f"No validated cached scenarios are available for split={split!r}.")
|
| 111 |
+
|
| 112 |
+
available = sorted(self._entries_by_difficulty)
|
| 113 |
+
resolved_max = max_level if max_level is not None else max(available)
|
| 114 |
+
self.curriculum = AdaptiveDifficultyCurriculum(
|
| 115 |
+
min_level=min(available),
|
| 116 |
+
max_level=int(resolved_max),
|
| 117 |
+
current_level=int(initial_difficulty),
|
| 118 |
+
rng_seed=int(rng_seed),
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
@property
|
| 122 |
+
def available_difficulties(self) -> list[int]:
|
| 123 |
+
return sorted(self._entries_by_difficulty)
|
| 124 |
+
|
| 125 |
+
def assignment_for(
|
| 126 |
+
self,
|
| 127 |
+
*,
|
| 128 |
+
scenario_group_id: int,
|
| 129 |
+
requested_seed: int | None = None,
|
| 130 |
+
requested_difficulty: int | None = None,
|
| 131 |
+
split: str | None = None,
|
| 132 |
+
difficulty_policy: str = "adaptive",
|
| 133 |
+
) -> dict[str, Any]:
|
| 134 |
+
group_id = int(scenario_group_id)
|
| 135 |
+
with self._lock:
|
| 136 |
+
if group_id in self._assignments:
|
| 137 |
+
return dict(self._assignments[group_id])
|
| 138 |
+
|
| 139 |
+
if difficulty_policy == "fixed":
|
| 140 |
+
difficulty = int(
|
| 141 |
+
requested_difficulty
|
| 142 |
+
if requested_difficulty is not None
|
| 143 |
+
else self.curriculum.current_level
|
| 144 |
+
)
|
| 145 |
+
entry = self._find_entry(
|
| 146 |
+
seed=requested_seed,
|
| 147 |
+
split=split or self.split,
|
| 148 |
+
difficulty=difficulty,
|
| 149 |
+
) or self._next_entry(difficulty)
|
| 150 |
+
else:
|
| 151 |
+
difficulty = self.curriculum.sample_difficulty(self.available_difficulties)
|
| 152 |
+
entry = self._next_entry(difficulty)
|
| 153 |
+
|
| 154 |
+
assignment = self._assignment_from_entry(group_id, entry)
|
| 155 |
+
self._assignments[group_id] = assignment
|
| 156 |
+
return dict(assignment)
|
| 157 |
+
|
| 158 |
+
def record_group_outcome(self, scenario_group_id: int, success_rate: float) -> dict[str, Any] | None:
|
| 159 |
+
group_id = int(scenario_group_id)
|
| 160 |
+
with self._lock:
|
| 161 |
+
if group_id in self._completed_groups:
|
| 162 |
+
return None
|
| 163 |
+
self._completed_groups.add(group_id)
|
| 164 |
+
assignment = self._assignments.get(group_id)
|
| 165 |
+
if not assignment:
|
| 166 |
+
return self.curriculum.snapshot()
|
| 167 |
+
return self.curriculum.update(
|
| 168 |
+
int(assignment["difficulty"]),
|
| 169 |
+
max(0.0, min(1.0, float(success_rate))),
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
def metrics(
|
| 173 |
+
self,
|
| 174 |
+
records: Sequence[Mapping[str, Any]],
|
| 175 |
+
*,
|
| 176 |
+
unique_trace_count: int,
|
| 177 |
+
duplicate_trace_suppressed_count: int,
|
| 178 |
+
) -> dict[str, float]:
|
| 179 |
+
scenario_hashes = {
|
| 180 |
+
str(record.get("scenario_hash") or record.get("scenario_id_hash") or "")
|
| 181 |
+
for record in records
|
| 182 |
+
if record.get("scenario_hash") or record.get("scenario_id_hash")
|
| 183 |
+
}
|
| 184 |
+
seeds = {
|
| 185 |
+
int(record.get("scenario/seed", record.get("seed", 0)) or 0)
|
| 186 |
+
for record in records
|
| 187 |
+
}
|
| 188 |
+
total = max(1, len(records))
|
| 189 |
+
snapshot = self.curriculum.snapshot()
|
| 190 |
+
return {
|
| 191 |
+
"train/unique_trace_count": float(unique_trace_count),
|
| 192 |
+
"train/duplicate_trace_suppressed_count": float(duplicate_trace_suppressed_count),
|
| 193 |
+
"train/unique_trace_rate": float(unique_trace_count) / total,
|
| 194 |
+
"train/unique_seed_count": float(len(seeds)),
|
| 195 |
+
"train/unique_scenario_hash_count": float(len(scenario_hashes)),
|
| 196 |
+
"train/curriculum_level": float(snapshot["current_level"]),
|
| 197 |
+
"train/curriculum_ema_success": float(snapshot["current_level_ema_success"]),
|
| 198 |
+
}
|
| 199 |
+
|
| 200 |
+
def _find_entry(
|
| 201 |
+
self,
|
| 202 |
+
*,
|
| 203 |
+
seed: int | None,
|
| 204 |
+
split: str,
|
| 205 |
+
difficulty: int,
|
| 206 |
+
) -> dict[str, Any] | None:
|
| 207 |
+
if seed is None or split != self.split:
|
| 208 |
+
return None
|
| 209 |
+
for entry in self._entries_by_difficulty.get(int(difficulty), []):
|
| 210 |
+
if int(entry.get("seed", -1)) == int(seed):
|
| 211 |
+
return dict(entry)
|
| 212 |
+
return None
|
| 213 |
+
|
| 214 |
+
def _next_entry(self, difficulty: int) -> dict[str, Any]:
    """Return a copy of the next entry for ``difficulty``, cycling round-robin.

    When the requested difficulty has no entries, the closest value in
    ``self.available_difficulties`` is used instead (ties go to the lower
    difficulty). The per-difficulty cursor in ``self._cursors`` is advanced
    for the difficulty that was actually served.
    """
    requested = int(difficulty)
    resolved = requested
    bucket = self._entries_by_difficulty.get(requested)
    if not bucket:
        # Rank by distance first, then by value, so ties pick the lower level.
        resolved = min(
            self.available_difficulties,
            key=lambda candidate: (abs(candidate - requested), candidate),
        )
        bucket = self._entries_by_difficulty[resolved]
    position = self._cursors.get(resolved, 0)
    self._cursors[resolved] = position + 1
    # Wrap around so the bucket is replayed once exhausted.
    chosen = bucket[position % len(bucket)]
    return dict(chosen)
|
| 227 |
+
|
| 228 |
+
def _assignment_from_entry(self, group_id: int, entry: Mapping[str, Any]) -> dict[str, Any]:
    """Build the assignment record for a scenario group from an entry.

    ``template_id`` and ``bug_family`` fall back to the ``app_family`` and
    ``authz_bug_type`` fields of the entry's ``cache_key`` when that value
    is a mapping; otherwise they fall back to the empty string.
    """
    raw_cache_key = entry.get("cache_key")
    # Anything that is not a mapping is treated as "no cache key".
    key_fields = raw_cache_key if isinstance(raw_cache_key, Mapping) else {}
    return dict(
        scenario_group_id=int(group_id),
        seed=int(entry.get("seed", 0)),
        split=str(entry.get("split", self.split)),
        difficulty=int(entry.get("difficulty", 0)),
        scenario_hash=str(entry.get("scenario_hash", "")),
        template_id=str(entry.get("template_id") or key_fields.get("app_family", "")),
        bug_family=str(entry.get("bug_family") or key_fields.get("authz_bug_type", "")),
    )
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def build_scenario_group_rows(
    *,
    dataset_size: int,
    training_prompt: str,
    seed_start: int = 0,
    split: str = "train",
    difficulty: int = 0,
    difficulty_policy: str = "adaptive",
) -> list[dict[str, Any]]:
    """Create one dataset row per scenario group.

    Row ``i`` uses ``seed_start + i`` as both its ``scenario_group_id`` and
    its ``seed``. Every row carries the same single-message user prompt
    along with the given difficulty, split and difficulty policy.

    Returns:
        A list of ``dataset_size`` row dicts (empty when the size is not
        positive).
    """
    rows: list[dict[str, Any]] = []
    for offset in range(int(dataset_size)):
        group_seed = int(seed_start) + offset
        row: dict[str, Any] = {}
        # A fresh message list per row keeps rows independent of each other.
        row["prompt"] = [{"role": "user", "content": training_prompt}]
        row["scenario_group_id"] = group_seed
        row["seed"] = group_seed
        row["difficulty"] = int(difficulty)
        row["split"] = split
        row["difficulty_policy"] = difficulty_policy
        rows.append(row)
    return rows
|
training/trackio_utils.py
CHANGED
|
@@ -150,9 +150,16 @@ REQUIRED_SMOKE_TRACKIO_ITEMS = (
|
|
| 150 |
TRACE_TABLE_COLUMNS = (
|
| 151 |
"episode_id",
|
| 152 |
"scenario_id_hash",
|
|
|
|
|
|
|
| 153 |
"split",
|
| 154 |
"difficulty",
|
|
|
|
| 155 |
"bug_type",
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
"visible_observation_summary",
|
| 157 |
"action_sequence",
|
| 158 |
"tool_calls",
|
|
@@ -529,6 +536,7 @@ def episode_record_from_state(
|
|
| 529 |
"target_weakness": getattr(state, "target_weakness", ""),
|
| 530 |
"difficulty_tier": getattr(state, "difficulty_tier", ""),
|
| 531 |
"domain": getattr(state, "domain", ""),
|
|
|
|
| 532 |
"success": bool(getattr(state, "success", False)),
|
| 533 |
"failure_reason": getattr(state, "failure_reason", None),
|
| 534 |
"finding_submitted": bool(getattr(state, "finding_submitted", False)),
|
|
@@ -821,12 +829,20 @@ def episode_to_trace_row(episode: Any) -> dict[str, Any]:
|
|
| 821 |
files_modified = _files_modified(record, actions)
|
| 822 |
reward_breakdown = _final_reward_breakdown(record)
|
| 823 |
final_obs = observations[-1] if observations else {}
|
|
|
|
| 824 |
row = {
|
| 825 |
"episode_id": _redact_text(record.get("episode_id", "")),
|
| 826 |
"scenario_id_hash": record.get("scenario_id_hash") or _scenario_hash(record),
|
|
|
|
|
|
|
| 827 |
"split": record.get("scenario/split") or record.get("split", ""),
|
| 828 |
"difficulty": record.get("scenario/difficulty") or record.get("difficulty", 0),
|
|
|
|
| 829 |
"bug_type": record.get("scenario/bug_type") or record.get("bug_type", ""),
|
|
|
|
|
|
|
|
|
|
|
|
|
| 830 |
"visible_observation_summary": json.dumps(
|
| 831 |
{
|
| 832 |
"done": bool(record.get("done", final_obs.get("done", False))),
|
|
@@ -845,9 +861,7 @@ def episode_to_trace_row(episode: Any) -> dict[str, Any]:
|
|
| 845 |
"local_probe_count": sum(
|
| 846 |
1 for name in tool_names if name in {"send_local_request", "compare_identities"}
|
| 847 |
),
|
| 848 |
-
"first_valid_exploit_step":
|
| 849 |
-
"skill/first_valid_exploit_step"
|
| 850 |
-
],
|
| 851 |
"diagnosis_submitted": bool(
|
| 852 |
record.get("diagnosis_submitted", record.get("finding_submitted", False))
|
| 853 |
),
|
|
@@ -894,7 +908,7 @@ def episode_trace_fingerprint(episode: Any) -> str:
|
|
| 894 |
{
|
| 895 |
key: row.get(key, "")
|
| 896 |
for key in TRACE_TABLE_COLUMNS
|
| 897 |
-
if key
|
| 898 |
},
|
| 899 |
length=24,
|
| 900 |
)
|
|
|
|
| 150 |
TRACE_TABLE_COLUMNS = (
|
| 151 |
"episode_id",
|
| 152 |
"scenario_id_hash",
|
| 153 |
+
"scenario_hash",
|
| 154 |
+
"seed",
|
| 155 |
"split",
|
| 156 |
"difficulty",
|
| 157 |
+
"template_id",
|
| 158 |
"bug_type",
|
| 159 |
+
"reward_total",
|
| 160 |
+
"security_pass_rate",
|
| 161 |
+
"regression_pass_rate",
|
| 162 |
+
"step_count",
|
| 163 |
"visible_observation_summary",
|
| 164 |
"action_sequence",
|
| 165 |
"tool_calls",
|
|
|
|
| 536 |
"target_weakness": getattr(state, "target_weakness", ""),
|
| 537 |
"difficulty_tier": getattr(state, "difficulty_tier", ""),
|
| 538 |
"domain": getattr(state, "domain", ""),
|
| 539 |
+
"scenario_hash": getattr(state, "scenario_hash", ""),
|
| 540 |
"success": bool(getattr(state, "success", False)),
|
| 541 |
"failure_reason": getattr(state, "failure_reason", None),
|
| 542 |
"finding_submitted": bool(getattr(state, "finding_submitted", False)),
|
|
|
|
| 829 |
files_modified = _files_modified(record, actions)
|
| 830 |
reward_breakdown = _final_reward_breakdown(record)
|
| 831 |
final_obs = observations[-1] if observations else {}
|
| 832 |
+
tracking_fields = episode_to_tracking_fields(record)
|
| 833 |
row = {
|
| 834 |
"episode_id": _redact_text(record.get("episode_id", "")),
|
| 835 |
"scenario_id_hash": record.get("scenario_id_hash") or _scenario_hash(record),
|
| 836 |
+
"scenario_hash": record.get("scenario_hash") or _as_dict(record.get("metrics")).get("scenario_hash", ""),
|
| 837 |
+
"seed": record.get("scenario/seed") or record.get("seed", 0),
|
| 838 |
"split": record.get("scenario/split") or record.get("split", ""),
|
| 839 |
"difficulty": record.get("scenario/difficulty") or record.get("difficulty", 0),
|
| 840 |
+
"template_id": record.get("scenario/template_id") or record.get("template_id", ""),
|
| 841 |
"bug_type": record.get("scenario/bug_type") or record.get("bug_type", ""),
|
| 842 |
+
"reward_total": tracking_fields["reward/total"],
|
| 843 |
+
"security_pass_rate": tracking_fields["reward/hidden_authz_pass_rate"],
|
| 844 |
+
"regression_pass_rate": tracking_fields["reward/normal_flow_pass_rate"],
|
| 845 |
+
"step_count": record.get("step_count", len(actions)),
|
| 846 |
"visible_observation_summary": json.dumps(
|
| 847 |
{
|
| 848 |
"done": bool(record.get("done", final_obs.get("done", False))),
|
|
|
|
| 861 |
"local_probe_count": sum(
|
| 862 |
1 for name in tool_names if name in {"send_local_request", "compare_identities"}
|
| 863 |
),
|
| 864 |
+
"first_valid_exploit_step": tracking_fields["skill/first_valid_exploit_step"],
|
|
|
|
|
|
|
| 865 |
"diagnosis_submitted": bool(
|
| 866 |
record.get("diagnosis_submitted", record.get("finding_submitted", False))
|
| 867 |
),
|
|
|
|
| 908 |
{
|
| 909 |
key: row.get(key, "")
|
| 910 |
for key in TRACE_TABLE_COLUMNS
|
| 911 |
+
if key not in {"episode_id", "reward_total"}
|
| 912 |
},
|
| 913 |
length=24,
|
| 914 |
)
|