Humanlearning committed on
Commit
1b6d30b
·
1 Parent(s): e5fe6f5

feat: introduce GRPO GPU fallback support, enhance training script with warmstart tagging, and add learning rate parameter for improved training flexibility

Browse files
scripts/modal_train_grpo.py CHANGED
@@ -47,6 +47,7 @@ PUBLIC_REPO_URL = "https://github.com/humandotlearning/CyberSecurity_OWASP.git"
47
  PUBLIC_REPO_BRANCH = "master"
48
  DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
49
  GRPO_TRAINING_TIMEOUT_SECONDS = 24 * 60 * 60
 
50
  _IMAGE_NOTICE_PRINTED = False
51
 
52
 
@@ -69,6 +70,29 @@ def _model_repo_slug(model_name: str) -> str:
69
  )
70
 
71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
def _hf_model_cache_path(model_name: str) -> pathlib.Path:
    """Return the local HF hub cache directory for *model_name*.

    Follows the hub cache layout, where "org/name" is stored under a
    directory named "models--org--name" inside HF_HUB_CACHE_DIR.
    """
    cache_entry_name = "models--" + model_name.replace("/", "--")
    return HF_HUB_CACHE_DIR / cache_entry_name
74
 
@@ -540,7 +564,7 @@ def verify_modal_scenario_cache_for_training(
540
 
541
  @app.function(
542
  image=training_image,
543
- gpu="L4",
544
  timeout=4 * 60 * 60,
545
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
546
  secrets=secrets,
@@ -578,7 +602,7 @@ def check_training_imports() -> dict[str, str]:
578
 
579
  @app.function(
580
  image=training_image,
581
- gpu="L4",
582
  timeout=4 * 60 * 60,
583
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
584
  secrets=secrets,
@@ -1021,7 +1045,7 @@ def run_cybersecurity_owasp_baseline(
1021
 
1022
  @app.function(
1023
  image=training_image,
1024
- gpu="L4",
1025
  timeout=GRPO_TRAINING_TIMEOUT_SECONDS,
1026
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
1027
  secrets=secrets,
@@ -1044,6 +1068,7 @@ def train_cybersecurity_owasp_grpo(
1044
  num_generations: int = 6,
1045
  per_device_train_batch_size: int = 1,
1046
  gradient_accumulation_steps: int = 0,
 
1047
  use_vllm: bool = False,
1048
  vllm_gpu_memory_utilization: float = 0.2,
1049
  trace_log_every: int = 5,
@@ -1135,7 +1160,7 @@ def train_cybersecurity_owasp_grpo(
1135
  user = whoami(token=hf_token)["name"]
1136
  env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
1137
  output_repo_id = output_repo_id or (
1138
- f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
1139
  )
1140
  if not trackio_space_id:
1141
  trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio"
@@ -1163,8 +1188,12 @@ def train_cybersecurity_owasp_grpo(
1163
 
1164
  model_slug = model_name.replace("/", "-")
1165
  stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
 
 
 
 
1166
  run_name = run_name or (
1167
- f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
1168
  f"{reward_tracking_config['reward_variant']}-steps{max_steps}-seed{seed_start}-"
1169
  f"{stamp}-{git_sha[:8]}"
1170
  )
@@ -1757,6 +1786,7 @@ def train_cybersecurity_owasp_grpo(
1757
  print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
1758
  print(f"Reward variant: {reward_tracking_config['reward_variant']}")
1759
  print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
 
1760
  print(f"Reward env overrides: {reward_env}")
1761
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
1762
  print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
@@ -1950,7 +1980,7 @@ def train_cybersecurity_owasp_grpo(
1950
 
1951
  grpo_config_values = {
1952
  "temperature": 1.0,
1953
- "learning_rate": 5e-6,
1954
  "weight_decay": 0.001,
1955
  "warmup_ratio": 0.1,
1956
  "lr_scheduler_type": "linear",
@@ -2025,7 +2055,7 @@ def train_cybersecurity_owasp_grpo(
2025
  print(
2026
  "Training heartbeat: still inside trainer.train() "
2027
  f"after {elapsed}s. For this smoke, the slow part is usually "
2028
- f"Gemma generation/backprop on L4: {num_generations} completions "
2029
  f"up to {max_completion_length} tokens, plus Trackio upload."
2030
  )
2031
 
@@ -2075,6 +2105,7 @@ def train_cybersecurity_owasp_grpo(
2075
  "num_generations": num_generations,
2076
  "per_device_train_batch_size": per_device_train_batch_size,
2077
  "gradient_accumulation_steps": resolved_gradient_accumulation_steps,
 
2078
  "effective_train_batch_size": effective_train_batch_size,
2079
  "use_vllm": int(bool(use_vllm)),
2080
  "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
@@ -2110,6 +2141,7 @@ def main(
2110
  num_generations: int = 6,
2111
  per_device_train_batch_size: int = 1,
2112
  gradient_accumulation_steps: int = 0,
 
2113
  use_vllm: bool = False,
2114
  vllm_gpu_memory_utilization: float = 0.2,
2115
  trace_log_every: int = 5,
@@ -2228,7 +2260,7 @@ def main(
2228
  )
2229
  resolved_output_repo_id = (
2230
  resolved_output_repo_id
2231
- or f"{user}/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
2232
  )
2233
  except Exception as exc:
2234
  print(f"Could not resolve Hugging Face defaults locally: {exc!r}")
@@ -2253,8 +2285,12 @@ def main(
2253
  model_slug = model_name.replace("/", "-")
2254
  local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
2255
  variant_tag = reward_variant or "default"
 
 
 
 
2256
  run_name = run_name or (
2257
- f"CyberSecurity_OWASP-{model_slug}-grpo-level{difficulty}-"
2258
  f"{variant_tag}-steps{max_steps}-seed{seed_start}-{local_stamp}-{git_sha[:8]}"
2259
  )
2260
 
@@ -2273,7 +2309,7 @@ def main(
2273
  else:
2274
  print(
2275
  "Output model repo: derived remotely from HF_TOKEN as "
2276
- f"<hf-user>/CyberSecurity_OWASP-{_model_repo_slug(model_name)}-grpo-lora"
2277
  )
2278
  print(f"Hub push enabled: {push_to_hub}")
2279
  if initial_adapter_path:
@@ -2287,7 +2323,8 @@ def main(
2287
  f"per_device_train_batch_size={per_device_train_batch_size}, "
2288
  f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, "
2289
  f"num_generations={num_generations}, "
2290
- f"effective_train_batch_size={effective_train_batch_size}"
 
2291
  )
2292
  print(
2293
  "Generation acceleration config: "
@@ -2301,7 +2338,7 @@ def main(
2301
  "slow when local source or dependency layers changed."
2302
  )
2303
  print("2. CPU-only scenario cache preflight in CyberSecurity_OWASP-scenario-cache.")
2304
- print("3. GPU container start on one L4 only after cache preflight passes.")
2305
  print("4. Model cache check in CyberSecurity_OWASP-model-cache.")
2306
  print("5. Cached snapshot load into GPU RAM with Unsloth progress.")
2307
  print("6. GRPO steps, Trackio sync, and volume commit.")
@@ -2328,6 +2365,7 @@ def main(
2328
  num_generations=num_generations,
2329
  per_device_train_batch_size=per_device_train_batch_size,
2330
  gradient_accumulation_steps=resolved_gradient_accumulation_steps,
 
2331
  use_vllm=use_vllm,
2332
  vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
2333
  trace_log_every=trace_log_every,
 
47
  PUBLIC_REPO_BRANCH = "master"
48
  DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"
49
  GRPO_TRAINING_TIMEOUT_SECONDS = 24 * 60 * 60
50
+ GRPO_GPU_FALLBACK = ["L40S", "L4"]
51
  _IMAGE_NOTICE_PRINTED = False
52
 
53
 
 
70
  )
71
 
72
 
73
def _grpo_output_repo_slug(
    model_name: str,
    *,
    initial_adapter_path: str = "",
    initial_adapter_repo_id: str = "",
) -> str:
    """Build the Hub repo slug for the GRPO LoRA output of *model_name*.

    Warm-started runs (any initial SFT adapter supplied, via a local path or
    a Hub repo id) get an "-sft-warmstart" marker inserted between the model
    slug and the trailing "-grpo-lora" suffix; cold-start runs get none.
    """
    if initial_adapter_path or initial_adapter_repo_id:
        warmstart_tag = "-sft-warmstart"
    else:
        warmstart_tag = ""
    model_slug = _model_repo_slug(model_name)
    return f"CyberSecurity_OWASP-{model_slug}{warmstart_tag}-grpo-lora"
86
+
87
+
88
+ def _grpo_run_algo_tag(
89
+ *,
90
+ initial_adapter_path: str = "",
91
+ initial_adapter_repo_id: str = "",
92
+ ) -> str:
93
+ return "sft-warmstart-grpo" if initial_adapter_path or initial_adapter_repo_id else "grpo"
94
+
95
+
96
  def _hf_model_cache_path(model_name: str) -> pathlib.Path:
97
  return HF_HUB_CACHE_DIR / f"models--{model_name.replace('/', '--')}"
98
 
 
564
 
565
  @app.function(
566
  image=training_image,
567
+ gpu=GRPO_GPU_FALLBACK,
568
  timeout=4 * 60 * 60,
569
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
570
  secrets=secrets,
 
602
 
603
  @app.function(
604
  image=training_image,
605
+ gpu=GRPO_GPU_FALLBACK,
606
  timeout=4 * 60 * 60,
607
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
608
  secrets=secrets,
 
1045
 
1046
  @app.function(
1047
  image=training_image,
1048
+ gpu=GRPO_GPU_FALLBACK,
1049
  timeout=GRPO_TRAINING_TIMEOUT_SECONDS,
1050
  volumes={RUNS_DIR: volume, CACHE_DIR: cache_volume, SCENARIO_CACHE_DIR: scenario_cache_volume},
1051
  secrets=secrets,
 
1068
  num_generations: int = 6,
1069
  per_device_train_batch_size: int = 1,
1070
  gradient_accumulation_steps: int = 0,
1071
+ learning_rate: float = 5e-6,
1072
  use_vllm: bool = False,
1073
  vllm_gpu_memory_utilization: float = 0.2,
1074
  trace_log_every: int = 5,
 
1160
  user = whoami(token=hf_token)["name"]
1161
  env_repo_id = env_repo_id or f"{user}/CyberSecurity_OWASP"
1162
  output_repo_id = output_repo_id or (
1163
+ f"{user}/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}"
1164
  )
1165
  if not trackio_space_id:
1166
  trackio_space_id = "Humanlearning/CyberSecurity_OWASP-trackio"
 
1188
 
1189
  model_slug = model_name.replace("/", "-")
1190
  stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
1191
+ algo_tag = _grpo_run_algo_tag(
1192
+ initial_adapter_path=initial_adapter_path,
1193
+ initial_adapter_repo_id=initial_adapter_repo_id,
1194
+ )
1195
  run_name = run_name or (
1196
+ f"CyberSecurity_OWASP-{model_slug}-{algo_tag}-level{difficulty}-"
1197
  f"{reward_tracking_config['reward_variant']}-steps{max_steps}-seed{seed_start}-"
1198
  f"{stamp}-{git_sha[:8]}"
1199
  )
 
1786
  print(f"Reward config hash: {reward_tracking_config['reward_config_hash']}")
1787
  print(f"Reward variant: {reward_tracking_config['reward_variant']}")
1788
  print(f"Reward config path: {reward_tracking_config['reward_config_path']}")
1789
+ print(f"Learning rate: {learning_rate}")
1790
  print(f"Reward env overrides: {reward_env}")
1791
  print(f"Model cache volume: {CACHE_VOLUME_NAME}")
1792
  print(f"Scenario cache volume: {SCENARIO_CACHE_VOLUME_NAME}")
 
1980
 
1981
  grpo_config_values = {
1982
  "temperature": 1.0,
1983
+ "learning_rate": learning_rate,
1984
  "weight_decay": 0.001,
1985
  "warmup_ratio": 0.1,
1986
  "lr_scheduler_type": "linear",
 
2055
  print(
2056
  "Training heartbeat: still inside trainer.train() "
2057
  f"after {elapsed}s. For this smoke, the slow part is usually "
2058
+ f"Gemma generation/backprop: {num_generations} completions "
2059
  f"up to {max_completion_length} tokens, plus Trackio upload."
2060
  )
2061
 
 
2105
  "num_generations": num_generations,
2106
  "per_device_train_batch_size": per_device_train_batch_size,
2107
  "gradient_accumulation_steps": resolved_gradient_accumulation_steps,
2108
+ "learning_rate": learning_rate,
2109
  "effective_train_batch_size": effective_train_batch_size,
2110
  "use_vllm": int(bool(use_vllm)),
2111
  "vllm_gpu_memory_utilization": vllm_gpu_memory_utilization,
 
2141
  num_generations: int = 6,
2142
  per_device_train_batch_size: int = 1,
2143
  gradient_accumulation_steps: int = 0,
2144
+ learning_rate: float = 5e-6,
2145
  use_vllm: bool = False,
2146
  vllm_gpu_memory_utilization: float = 0.2,
2147
  trace_log_every: int = 5,
 
2260
  )
2261
  resolved_output_repo_id = (
2262
  resolved_output_repo_id
2263
+ or f"{user}/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}"
2264
  )
2265
  except Exception as exc:
2266
  print(f"Could not resolve Hugging Face defaults locally: {exc!r}")
 
2285
  model_slug = model_name.replace("/", "-")
2286
  local_stamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
2287
  variant_tag = reward_variant or "default"
2288
+ algo_tag = _grpo_run_algo_tag(
2289
+ initial_adapter_path=initial_adapter_path,
2290
+ initial_adapter_repo_id=initial_adapter_repo_id,
2291
+ )
2292
  run_name = run_name or (
2293
+ f"CyberSecurity_OWASP-{model_slug}-{algo_tag}-level{difficulty}-"
2294
  f"{variant_tag}-steps{max_steps}-seed{seed_start}-{local_stamp}-{git_sha[:8]}"
2295
  )
2296
 
 
2309
  else:
2310
  print(
2311
  "Output model repo: derived remotely from HF_TOKEN as "
2312
+ f"<hf-user>/{_grpo_output_repo_slug(model_name, initial_adapter_path=initial_adapter_path, initial_adapter_repo_id=initial_adapter_repo_id)}"
2313
  )
2314
  print(f"Hub push enabled: {push_to_hub}")
2315
  if initial_adapter_path:
 
2323
  f"per_device_train_batch_size={per_device_train_batch_size}, "
2324
  f"gradient_accumulation_steps={resolved_gradient_accumulation_steps}, "
2325
  f"num_generations={num_generations}, "
2326
+ f"effective_train_batch_size={effective_train_batch_size}, "
2327
+ f"learning_rate={learning_rate}"
2328
  )
2329
  print(
2330
  "Generation acceleration config: "
 
2338
  "slow when local source or dependency layers changed."
2339
  )
2340
  print("2. CPU-only scenario cache preflight in CyberSecurity_OWASP-scenario-cache.")
2341
+ print(f"3. GPU container start after cache preflight passes; fallback={GRPO_GPU_FALLBACK}.")
2342
  print("4. Model cache check in CyberSecurity_OWASP-model-cache.")
2343
  print("5. Cached snapshot load into GPU RAM with Unsloth progress.")
2344
  print("6. GRPO steps, Trackio sync, and volume commit.")
 
2365
  num_generations=num_generations,
2366
  per_device_train_batch_size=per_device_train_batch_size,
2367
  gradient_accumulation_steps=resolved_gradient_accumulation_steps,
2368
+ learning_rate=learning_rate,
2369
  use_vllm=use_vllm,
2370
  vllm_gpu_memory_utilization=vllm_gpu_memory_utilization,
2371
  trace_log_every=trace_log_every,
tests/test_modal_scenario_cache_static.py CHANGED
@@ -31,12 +31,18 @@ def test_modal_ephemeral_smoke_uses_required_scenario_cache():
31
  def test_modal_training_is_pinned_to_gemma4_e2b():
32
  source = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")
33
 
 
 
34
  assert "DEFAULT_GEMMA_MODEL = \"unsloth/gemma-4-E2B-it\"" in source
35
  assert "def _ensure_gemma4_model(model_name: str) -> str:" in source
36
  assert "model_name = _ensure_gemma4_model(model_name)" in source
37
  assert "from unsloth import FastVisionModel" in source
38
  assert "Qwen" not in source
39
  assert "FastLanguageModel" not in source
 
 
 
 
40
 
41
 
42
  def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
 
31
def test_modal_training_is_pinned_to_gemma4_e2b():
    """Static guard: the Modal GRPO script stays pinned to Gemma-4 E2B,
    uses the L40S/L4 GPU fallback list, tags SFT-warmstarted runs, and
    threads the learning_rate parameter into GRPOConfig."""
    source = (ROOT / "scripts" / "modal_train_grpo.py").read_text(encoding="utf-8")

    required_snippets = (
        'GRPO_GPU_FALLBACK = ["L40S", "L4"]',
        "gpu=GRPO_GPU_FALLBACK",
        'DEFAULT_GEMMA_MODEL = "unsloth/gemma-4-E2B-it"',
        "def _ensure_gemma4_model(model_name: str) -> str:",
        "model_name = _ensure_gemma4_model(model_name)",
        "from unsloth import FastVisionModel",
        "sft-warmstart-grpo",
        "-sft-warmstart",
        "learning_rate: float = 5e-6",
        '"learning_rate": learning_rate',
    )
    for snippet in required_snippets:
        assert snippet in source, f"missing required snippet: {snippet!r}"

    # The script must not reference any non-Gemma model family or the
    # text-only Unsloth loader.
    forbidden_snippets = ("Qwen", "FastLanguageModel")
    for snippet in forbidden_snippets:
        assert snippet not in source, f"forbidden snippet present: {snippet!r}"
46
 
47
 
48
  def test_modal_sft_defaults_match_300_episode_fast_handoff_plan():
training/configs/sft_warmstart_fast.yaml ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ extends: grpo_small.yaml
2
+ reward:
3
+ mode: dense_train
4
+ training_mode: dense_train
5
+ stage: early
6
+ progressive_cap:
7
+ value: 8.0
8
+ description: "Higher shaping budget for SFT-warmstarted GRPO so early correct workflow actions separate from random exploration."
9
+ penalty_floor:
10
+ value: -4.0
11
+ description: "Less severe dense floor for fast policy learning while terminal verifier penalties still apply."
12
+ train_cap:
13
+ value: 26.0
14
+ description: "Allows strong progressive and terminal rewards in the same episode."
15
+ shaping_weight:
16
+ early: 1.4
17
+ middle: 1.1
18
+ late: 0.8
19
+ final: 0.25
20
+ description: "Emphasizes workflow shaping early, then anneals toward terminal verifier reward."
21
+ policy_inspected:
22
+ value: 0.80
23
+ description: "Stronger reward for starting with the policy graph, matching the SFT oracle trace."
24
+ route_map_inspected:
25
+ value: 0.45
26
+ cap: 0.90
27
+ description: "Rewards route discovery without making route-list loops attractive."
28
+ relevant_file_inspected:
29
+ value: 0.90
30
+ cap: 1.40
31
+ description: "Rewards reading or searching authorization-relevant code before patching."
32
+ local_evidence_found:
33
+ value: 2.20
34
+ cap: 2.20
35
+ description: "Prioritizes local evidence of the authorization failure before diagnosis."
36
+ diagnosis_correct:
37
+ value: 2.00
38
+ description: "Large reward for correct bug class, route, policy rule, and local evidence."
39
+ patch_applies:
40
+ value: 1.20
41
+ description: "Rewards applying a concrete patch after diagnosis."
42
+ app_boots_after_patch:
43
+ value: 1.00
44
+ description: "Rewards keeping the generated app bootable after patching."
45
+ visible_tests_improved:
46
+ value: 1.20
47
+ cap: 1.20
48
+ description: "Rewards visible test success after the patch."
49
+ public_routes_visible_pass:
50
+ value: 0.70
51
+ description: "Rewards preserving intentionally public routes."
52
+ step_penalty:
53
+ early: -0.002
54
+ middle: -0.004
55
+ late: -0.008
56
+ final: 0.0
57
+ cap: -0.35
58
+ description: "Keeps mild pressure toward concise episodes without discouraging exploration."
59
+ speed_bonus:
60
+ value: 0.5
61
+ description: "Small terminal success speed bonus; shaping carries early learning."
62
+ token_penalty:
63
+ target_tokens: 110
64
+ early: -0.002
65
+ middle: -0.0025
66
+ late: -0.003
67
+ final: 0.0
68
+ cap: -0.45
69
+ description: "Penalizes clipped or verbose tool calls immediately in SFT-warmstarted GRPO."
70
+ invalid_action:
71
+ value: -0.60
72
+ description: "Clear penalty for invalid tool calls, schema errors, or phase violations."
73
+ repeated_invalid_action:
74
+ value: -0.80
75
+ description: "Stronger penalty for repeating invalid behavior."
76
+ repeated_low_value_action:
77
+ value: -0.45
78
+ description: "Discourages repeated valid actions that add no new progress."
79
+ no_progress_action:
80
+ value: -0.20
81
+ description: "Penalizes valid but unhelpful actions after useful progress has already been collected."
82
+ noop_action:
83
+ value: -0.10
84
+ description: "Discourages no-op completions."
85
+ repeated_file_read:
86
+ value: -0.25
87
+ description: "Discourages rereading the same file without a patch change."
88
+ repeated_local_request:
89
+ value: -0.25
90
+ description: "Discourages repeated identical requests after evidence is known."
91
+ repeated_visible_tests:
92
+ value: -0.15
93
+ description: "Discourages rerunning visible tests without a new patch."
94
+ patch_before_policy:
95
+ value: -0.60
96
+ description: "Strongly discourages patching before policy inspection."
97
+ submit_without_patch:
98
+ value: -1.00
99
+ description: "Strongly discourages terminal submission without a patch."
100
+ submit_without_visible_tests:
101
+ value: -0.60
102
+ description: "Discourages submitting a patch before visible tests."