hirann commited on
Commit
e71c341
·
verified ·
1 Parent(s): 8143530

Upload training/train_grpo.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/train_grpo.py +405 -368
training/train_grpo.py CHANGED
@@ -1,368 +1,405 @@
1
- """
2
- GRPO Training Script for ImmunoOrg
3
- ===================================
4
- Uses Unsloth + HF TRL to train a defender agent via GRPO.
5
- """
6
-
7
- from __future__ import annotations
8
-
9
- import argparse
10
- import json
11
- import os
12
- import re
13
- import sys
14
- from argparse import Namespace
15
- from typing import Any
16
-
17
- # Add parent dir to path
18
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
-
20
-
21
- def parse_action_from_completion(text: str) -> dict[str, Any] | None:
22
- """Extract JSON action from model completion."""
23
- # Try to find JSON block
24
- json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text)
25
- if json_match:
26
- try:
27
- return json.loads(json_match.group())
28
- except json.JSONDecodeError:
29
- pass
30
- return None
31
-
32
-
33
- def format_reward(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
34
- """Reward function: score based on valid JSON action format."""
35
- rewards = []
36
- for completion in completions:
37
- score = 0.0
38
- action = parse_action_from_completion(completion)
39
- if action:
40
- score += 0.3 # Valid JSON
41
- if action.get("action_type") in ("tactical", "strategic", "diagnostic"):
42
- score += 0.2 # Valid action type
43
- if action.get("reasoning") and len(action["reasoning"]) > 20:
44
- score += 0.2 # Has reasoning
45
- if action.get("target"):
46
- score += 0.1 # Has target
47
- # Check specific action fields
48
- if action.get("tactical_action") or action.get("strategic_action") or action.get("diagnostic_action"):
49
- score += 0.2 # Has specific action
50
- rewards.append(score)
51
- return rewards
52
-
53
-
54
- def reasoning_quality_reward(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
55
- """Reward for reasoning quality in completions."""
56
- rewards = []
57
- causal = ["because", "therefore", "since", "indicates", "correlates", "caused by", "root cause"]
58
- structured = ["1.", "2.", "Step", "First", "Then", "Finally"]
59
-
60
- for completion in completions:
61
- score = 0.0
62
- lower = completion.lower()
63
- words = len(completion.split())
64
-
65
- # Length (not too short, not padding)
66
- if 30 <= words <= 500:
67
- score += 0.2
68
- elif words >= 10:
69
- score += 0.1
70
-
71
- # Causal reasoning
72
- if any(kw in lower for kw in causal):
73
- score += 0.3
74
-
75
- # Structured thinking
76
- if any(m in completion for m in structured):
77
- score += 0.2
78
-
79
- # References specific entities
80
- if re.search(r'(node|port|department|server|attack|vulnerability|silo)', lower):
81
- score += 0.2
82
-
83
- # Phase awareness
84
- if re.search(r'(detection|containment|root cause|refactor|validation)', lower):
85
- score += 0.1
86
-
87
- rewards.append(min(1.0, score))
88
- return rewards
89
-
90
-
91
- def phase_appropriate_reward(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
92
- """Reward for taking actions appropriate to the current phase."""
93
- rewards = []
94
- phase_actions = {
95
- "detection": ["scan_logs", "vulnerability_scan", "trace_attack_path"],
96
- "containment": ["block_port", "isolate_node", "quarantine_traffic"],
97
- "rca": ["correlate_failure", "identify_silo", "timeline_reconstruct"],
98
- "refactor": ["merge_departments", "create_shortcut_edge", "establish_devsecops"],
99
- "validation": ["measure_org_latency", "vulnerability_scan"],
100
- }
101
-
102
- for completion, prompt in zip(completions, prompts):
103
- score = 0.0
104
- # Detect phase from prompt
105
- current_phase = None
106
- for phase in phase_actions:
107
- if phase.upper() in prompt or f"Phase: {phase}" in prompt:
108
- current_phase = phase
109
- break
110
-
111
- if current_phase:
112
- appropriate = phase_actions.get(current_phase, [])
113
- action = parse_action_from_completion(completion)
114
- if action:
115
- action_name = (action.get("tactical_action") or
116
- action.get("strategic_action") or
117
- action.get("diagnostic_action") or "")
118
- if action_name in appropriate:
119
- score = 1.0
120
- else:
121
- score = 0.2 # Valid but wrong phase
122
- rewards.append(score)
123
- return rewards
124
-
125
-
126
- def build_training_prompts(num_prompts: int = 200) -> list[dict[str, str]]:
127
- """Generate diverse training prompts by running actual environments.
128
-
129
- Instead of 5 hardcoded scenarios, we run the environment across:
130
- - 4 difficulty levels
131
- - Multiple seeds
132
- - All 5 incident phases
133
- This produces genuine, diverse observations for GRPO training.
134
- """
135
- from immunoorg.agents.defender import get_defender_prompt, format_observation_for_llm
136
- from immunoorg.environment import ImmunoOrgEnvironment
137
- from immunoorg.models import (
138
- ActionType, TacticalAction, DiagnosticAction, StrategicAction, ImmunoAction
139
- )
140
- import random
141
-
142
- system_prompt = get_defender_prompt()
143
- scenarios = []
144
-
145
- # Phase-appropriate actions for generating trajectories
146
- phase_actions = {
147
- "detection": [
148
- lambda nodes: ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.SCAN_LOGS,
149
- target=nodes[0].id if nodes else "", reasoning="Scanning for indicators."),
150
- lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.TRACE_ATTACK_PATH,
151
- target="", reasoning="Tracing attack path."),
152
- ],
153
- "containment": [
154
- lambda nodes: ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.ISOLATE_NODE,
155
- target=next((n.id for n in nodes if n.compromised), nodes[0].id if nodes else ""),
156
- reasoning="Isolating compromised node."),
157
- lambda nodes: ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.BLOCK_PORT,
158
- target=nodes[0].id if nodes else "", reasoning="Blocking attack port."),
159
- ],
160
- "rca": [
161
- lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
162
- target="", reasoning="Finding organizational silos."),
163
- lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.CORRELATE_FAILURE,
164
- target="", parameters={"technical_indicator": "attack", "organizational_flaw": "no_devsecops", "confidence": 0.7},
165
- reasoning="Correlating technical failure to org weakness."),
166
- ],
167
- "refactor": [
168
- lambda nodes: ImmunoAction(action_type=ActionType.STRATEGIC, strategic_action=StrategicAction.ESTABLISH_DEVSECOPS,
169
- target="dept-security", reasoning="Establishing DevSecOps."),
170
- lambda nodes: ImmunoAction(action_type=ActionType.STRATEGIC, strategic_action=StrategicAction.REDUCE_BUREAUCRACY,
171
- target="dept-management", reasoning="Reducing bureaucracy."),
172
- ],
173
- "validation": [
174
- lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.MEASURE_ORG_LATENCY,
175
- target="", reasoning="Measuring org improvements."),
176
- lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.VULNERABILITY_SCAN,
177
- target="", reasoning="Final vulnerability check."),
178
- ],
179
- }
180
-
181
- prompts_per_combo = max(1, num_prompts // (4 * 10)) # 4 difficulties * ~10 seeds
182
-
183
- for difficulty in [1, 2, 3, 4]:
184
- for seed in range(50):
185
- if len(scenarios) >= num_prompts:
186
- break
187
-
188
- try:
189
- env = ImmunoOrgEnvironment(difficulty=difficulty, seed=seed)
190
- obs = env.reset()
191
-
192
- # Run a few steps to reach different phases
193
- rng = random.Random(seed)
194
- for step in range(min(15, env.state.max_steps)):
195
- # Capture observation as a training prompt
196
- obs_text = format_observation_for_llm(obs.model_dump())
197
- prompt = f"{system_prompt}\n\n## Current Observation\n{obs_text}\n\nRespond with a JSON action:"
198
- scenarios.append({"prompt": prompt})
199
-
200
- if len(scenarios) >= num_prompts:
201
- break
202
-
203
- # Take an action to advance the episode
204
- phase = obs.current_phase.value
205
- actions = phase_actions.get(phase, phase_actions["detection"])
206
- action_fn = rng.choice(actions)
207
- action = action_fn(obs.visible_nodes)
208
-
209
- obs, reward, done = env.step(action)
210
- if done:
211
- break
212
- except Exception as e:
213
- continue
214
-
215
- print(f" Generated {len(scenarios)} training prompts across 4 difficulty levels")
216
- return scenarios
217
-
218
-
219
- def build_arg_parser() -> argparse.ArgumentParser:
220
- parser = argparse.ArgumentParser(description="Train ImmunoOrg defender agent with GRPO")
221
- parser.add_argument("--smoke-test", action="store_true", help="Quick test with 2 steps")
222
- parser.add_argument("--warm-start", action="store_true", help="Warm-start using golden trajectories (SFT)")
223
- parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct", help="Base model")
224
- parser.add_argument("--output-dir", default="./immunoorg-defender", help="Output directory")
225
- parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
226
- parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
227
- parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
228
- parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt")
229
- parser.add_argument("--max-completion-length", type=int, default=1024, help="Max completion tokens")
230
- return parser
231
-
232
-
233
- def parse_train_args(argv: list[str] | None = None) -> Namespace:
234
- return build_arg_parser().parse_args(argv)
235
-
236
-
237
- def _maybe_push_to_hub(output_dir: str) -> None:
238
- """If HF_TRAINING_PUSH_REPO_ID is set, upload ``output_dir`` to the Hub (uses HF_TOKEN)."""
239
- repo_id = os.environ.get("HF_TRAINING_PUSH_REPO_ID", "").strip()
240
- if not repo_id:
241
- return
242
- token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
243
- if not token:
244
- print("HF_TRAINING_PUSH_REPO_ID is set but HF_TOKEN is missing; skipping Hub upload.")
245
- return
246
- from huggingface_hub import HfApi
247
-
248
- api = HfApi(token=token)
249
- api.create_repo(repo_id, repo_type="model", exist_ok=True)
250
- api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model")
251
- print(f"Uploaded training artifacts to https://huggingface.co/{repo_id}")
252
-
253
-
254
- def run_grpo_training(args: Namespace) -> None:
255
- print("=" * 60)
256
- print("ImmunoOrg GRPO Training Pipeline")
257
- print("=" * 60)
258
-
259
- # Try importing training libs
260
- try:
261
- from unsloth import FastLanguageModel
262
- HAS_UNSLOTH = True
263
- print("Check: Unsloth detected - using optimized training")
264
- except ImportError:
265
- HAS_UNSLOTH = False
266
- print("Warning: Unsloth not found - using standard HF training")
267
-
268
- try:
269
- from trl import GRPOTrainer, GRPOConfig
270
- from datasets import Dataset
271
- print("Check: TRL and datasets loaded")
272
- except ImportError:
273
- print("Error: TRL not installed. Run: pip install trl datasets")
274
- sys.exit(1)
275
-
276
- # 1. Load model
277
- print(f"\nLoading model: {args.model}")
278
- if HAS_UNSLOTH:
279
- model, tokenizer = FastLanguageModel.from_pretrained(
280
- args.model,
281
- max_seq_length=2048,
282
- load_in_4bit=True,
283
- )
284
- model = FastLanguageModel.get_peft_model(
285
- model, r=16, lora_alpha=16,
286
- target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
287
- "gate_proj", "up_proj", "down_proj"],
288
- )
289
- else:
290
- from transformers import AutoModelForCausalLM, AutoTokenizer
291
- from peft import get_peft_model, LoraConfig
292
- import torch as _torch
293
- tokenizer = AutoTokenizer.from_pretrained(args.model)
294
- device_map = "auto" if _torch.cuda.is_available() else "cpu"
295
- model = AutoModelForCausalLM.from_pretrained(args.model, device_map=device_map)
296
- lora_config = LoraConfig(r=16, lora_alpha=16, target_modules="all-linear")
297
- model = get_peft_model(model, lora_config)
298
-
299
- if tokenizer.pad_token is None:
300
- tokenizer.pad_token = tokenizer.eos_token
301
-
302
- # 2. Build dataset
303
- print("\nBuilding training dataset")
304
- scenarios = build_training_prompts()
305
- dataset = Dataset.from_list([{"prompt": s["prompt"]} for s in scenarios])
306
- print(f" {len(dataset)} scenarios loaded")
307
-
308
- # 3. Configure GRPO
309
- max_steps = 2 if args.smoke_test else None
310
- try:
311
- import torch as _torch
312
- on_cpu = not _torch.cuda.is_available()
313
- except Exception:
314
- on_cpu = True
315
-
316
- grpo_kwargs = dict(
317
- output_dir=args.output_dir,
318
- num_generations=args.num_generations,
319
- max_completion_length=args.max_completion_length,
320
- per_device_train_batch_size=args.batch_size,
321
- learning_rate=args.lr,
322
- num_train_epochs=args.epochs,
323
- beta=0.04,
324
- logging_steps=1,
325
- save_steps=50,
326
- max_steps=max_steps if max_steps else -1,
327
- report_to="none",
328
- )
329
- if on_cpu:
330
- # Recent TRL refuses to start with bf16/fp16 defaults if no GPU is
331
- # present; we have to opt into CPU explicitly and turn mixed
332
- # precision off so the smoke test works on a developer laptop.
333
- grpo_kwargs.update(use_cpu=True, bf16=False, fp16=False)
334
- config = GRPOConfig(**grpo_kwargs)
335
-
336
- # 4. Create trainer (TRL renamed `config` -> `args` somewhere in 1.x;
337
- # try both so this works against older and newer TRL releases.)
338
- print("\nCreating GRPO trainer")
339
- trainer_kwargs = dict(
340
- model=model,
341
- reward_funcs=[format_reward, reasoning_quality_reward, phase_appropriate_reward],
342
- train_dataset=dataset,
343
- processing_class=tokenizer,
344
- )
345
- try:
346
- trainer = GRPOTrainer(args=config, **trainer_kwargs)
347
- except TypeError:
348
- trainer = GRPOTrainer(config=config, **trainer_kwargs)
349
-
350
- # 5. Train
351
- print("\nStarting training...")
352
- trainer.train()
353
-
354
- # 6. Save
355
- print(f"\nSaving model to {args.output_dir}")
356
- os.makedirs(args.output_dir, exist_ok=True)
357
- trainer.save_model(args.output_dir)
358
- tokenizer.save_pretrained(args.output_dir)
359
- print("Training complete!")
360
- _maybe_push_to_hub(args.output_dir)
361
-
362
-
363
- def main() -> None:
364
- run_grpo_training(parse_train_args())
365
-
366
-
367
- if __name__ == "__main__":
368
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ GRPO Training Script for ImmunoOrg
3
+ ===================================
4
+ Uses Unsloth + HF TRL to train a defender agent via GRPO.
5
+ """
6
+
7
+ from __future__ import annotations
8
+
9
+ import argparse
10
+ import json
11
+ import os
12
+ import re
13
+ import sys
14
+ from argparse import Namespace
15
+ from typing import Any
16
+
17
+ # Add parent dir to path
18
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
19
+
20
+
21
+ def parse_action_from_completion(text: str) -> dict[str, Any] | None:
22
+ """Extract JSON action from model completion."""
23
+ # Try to find JSON block
24
+ json_match = re.search(r'\{[^{}]*(?:\{[^{}]*\}[^{}]*)*\}', text)
25
+ if json_match:
26
+ try:
27
+ return json.loads(json_match.group())
28
+ except json.JSONDecodeError:
29
+ pass
30
+ return None
31
+
32
+
33
+ def format_reward(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
34
+ """Reward function: score based on valid JSON action format."""
35
+ rewards = []
36
+ for completion in completions:
37
+ score = 0.0
38
+ action = parse_action_from_completion(completion)
39
+ if action:
40
+ score += 0.3 # Valid JSON
41
+ if action.get("action_type") in ("tactical", "strategic", "diagnostic"):
42
+ score += 0.2 # Valid action type
43
+ if action.get("reasoning") and len(action["reasoning"]) > 20:
44
+ score += 0.2 # Has reasoning
45
+ if action.get("target"):
46
+ score += 0.1 # Has target
47
+ # Check specific action fields
48
+ if action.get("tactical_action") or action.get("strategic_action") or action.get("diagnostic_action"):
49
+ score += 0.2 # Has specific action
50
+ rewards.append(score)
51
+ return rewards
52
+
53
+
54
+ def reasoning_quality_reward(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
55
+ """Reward for reasoning quality in completions."""
56
+ rewards = []
57
+ causal = ["because", "therefore", "since", "indicates", "correlates", "caused by", "root cause"]
58
+ structured = ["1.", "2.", "Step", "First", "Then", "Finally"]
59
+
60
+ for completion in completions:
61
+ score = 0.0
62
+ lower = completion.lower()
63
+ words = len(completion.split())
64
+
65
+ # Length (not too short, not padding)
66
+ if 30 <= words <= 500:
67
+ score += 0.2
68
+ elif words >= 10:
69
+ score += 0.1
70
+
71
+ # Causal reasoning
72
+ if any(kw in lower for kw in causal):
73
+ score += 0.3
74
+
75
+ # Structured thinking
76
+ if any(m in completion for m in structured):
77
+ score += 0.2
78
+
79
+ # References specific entities
80
+ if re.search(r'(node|port|department|server|attack|vulnerability|silo)', lower):
81
+ score += 0.2
82
+
83
+ # Phase awareness
84
+ if re.search(r'(detection|containment|root cause|refactor|validation)', lower):
85
+ score += 0.1
86
+
87
+ rewards.append(min(1.0, score))
88
+ return rewards
89
+
90
+
91
+ def phase_appropriate_reward(completions: list[str], prompts: list[str], **kwargs) -> list[float]:
92
+ """Reward for taking actions appropriate to the current phase."""
93
+ rewards = []
94
+ phase_actions = {
95
+ "detection": ["scan_logs", "vulnerability_scan", "trace_attack_path"],
96
+ "containment": ["block_port", "isolate_node", "quarantine_traffic"],
97
+ "rca": ["correlate_failure", "identify_silo", "timeline_reconstruct"],
98
+ "refactor": ["merge_departments", "create_shortcut_edge", "establish_devsecops"],
99
+ "validation": ["measure_org_latency", "vulnerability_scan"],
100
+ }
101
+
102
+ for completion, prompt in zip(completions, prompts):
103
+ score = 0.0
104
+ # Detect phase from prompt
105
+ current_phase = None
106
+ for phase in phase_actions:
107
+ if phase.upper() in prompt or f"Phase: {phase}" in prompt:
108
+ current_phase = phase
109
+ break
110
+
111
+ if current_phase:
112
+ appropriate = phase_actions.get(current_phase, [])
113
+ action = parse_action_from_completion(completion)
114
+ if action:
115
+ action_name = (action.get("tactical_action") or
116
+ action.get("strategic_action") or
117
+ action.get("diagnostic_action") or "")
118
+ if action_name in appropriate:
119
+ score = 1.0
120
+ else:
121
+ score = 0.2 # Valid but wrong phase
122
+ rewards.append(score)
123
+ return rewards
124
+
125
+
126
+ def build_training_prompts(num_prompts: int = 200) -> list[dict[str, str]]:
127
+ """Generate diverse training prompts by running actual environments.
128
+
129
+ Instead of 5 hardcoded scenarios, we run the environment across:
130
+ - 4 difficulty levels
131
+ - Multiple seeds
132
+ - All 5 incident phases
133
+ This produces genuine, diverse observations for GRPO training.
134
+ """
135
+ from immunoorg.agents.defender import get_defender_prompt, format_observation_for_llm
136
+ from immunoorg.environment import ImmunoOrgEnvironment
137
+ from immunoorg.models import (
138
+ ActionType, TacticalAction, DiagnosticAction, StrategicAction, ImmunoAction
139
+ )
140
+ import random
141
+
142
+ system_prompt = get_defender_prompt()
143
+ scenarios = []
144
+
145
+ # Phase-appropriate actions for generating trajectories
146
+ phase_actions = {
147
+ "detection": [
148
+ lambda nodes: ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.SCAN_LOGS,
149
+ target=nodes[0].id if nodes else "", reasoning="Scanning for indicators."),
150
+ lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.TRACE_ATTACK_PATH,
151
+ target="", reasoning="Tracing attack path."),
152
+ ],
153
+ "containment": [
154
+ lambda nodes: ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.ISOLATE_NODE,
155
+ target=next((n.id for n in nodes if n.compromised), nodes[0].id if nodes else ""),
156
+ reasoning="Isolating compromised node."),
157
+ lambda nodes: ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.BLOCK_PORT,
158
+ target=nodes[0].id if nodes else "", reasoning="Blocking attack port."),
159
+ ],
160
+ "rca": [
161
+ lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
162
+ target="", reasoning="Finding organizational silos."),
163
+ lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.CORRELATE_FAILURE,
164
+ target="", parameters={"technical_indicator": "attack", "organizational_flaw": "no_devsecops", "confidence": 0.7},
165
+ reasoning="Correlating technical failure to org weakness."),
166
+ ],
167
+ "refactor": [
168
+ lambda nodes: ImmunoAction(action_type=ActionType.STRATEGIC, strategic_action=StrategicAction.ESTABLISH_DEVSECOPS,
169
+ target="dept-security", reasoning="Establishing DevSecOps."),
170
+ lambda nodes: ImmunoAction(action_type=ActionType.STRATEGIC, strategic_action=StrategicAction.REDUCE_BUREAUCRACY,
171
+ target="dept-management", reasoning="Reducing bureaucracy."),
172
+ ],
173
+ "validation": [
174
+ lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.MEASURE_ORG_LATENCY,
175
+ target="", reasoning="Measuring org improvements."),
176
+ lambda nodes: ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.VULNERABILITY_SCAN,
177
+ target="", reasoning="Final vulnerability check."),
178
+ ],
179
+ }
180
+
181
+ prompts_per_combo = max(1, num_prompts // (4 * 10)) # 4 difficulties * ~10 seeds
182
+
183
+ for difficulty in [1, 2, 3, 4]:
184
+ for seed in range(50):
185
+ if len(scenarios) >= num_prompts:
186
+ break
187
+
188
+ try:
189
+ env = ImmunoOrgEnvironment(difficulty=difficulty, seed=seed)
190
+ obs = env.reset()
191
+
192
+ # Run a few steps to reach different phases
193
+ rng = random.Random(seed)
194
+ for step in range(min(15, env.state.max_steps)):
195
+ # Capture observation as a training prompt
196
+ obs_text = format_observation_for_llm(obs.model_dump())
197
+ prompt = f"{system_prompt}\n\n## Current Observation\n{obs_text}\n\nRespond with a JSON action:"
198
+ scenarios.append({"prompt": prompt})
199
+
200
+ if len(scenarios) >= num_prompts:
201
+ break
202
+
203
+ # Take an action to advance the episode
204
+ phase = obs.current_phase.value
205
+ actions = phase_actions.get(phase, phase_actions["detection"])
206
+ action_fn = rng.choice(actions)
207
+ action = action_fn(obs.visible_nodes)
208
+
209
+ obs, reward, done = env.step(action)
210
+ if done:
211
+ break
212
+ except Exception as e:
213
+ continue
214
+
215
+ print(f" Generated {len(scenarios)} training prompts across 4 difficulty levels")
216
+ return scenarios
217
+
218
+
219
+ def build_arg_parser() -> argparse.ArgumentParser:
220
+ parser = argparse.ArgumentParser(description="Train ImmunoOrg defender agent with GRPO")
221
+ parser.add_argument("--smoke-test", action="store_true", help="Quick test with 2 steps")
222
+ parser.add_argument("--warm-start", action="store_true", help="Warm-start using golden trajectories (SFT)")
223
+ parser.add_argument("--model", default="Qwen/Qwen2.5-0.5B-Instruct", help="Base model")
224
+ parser.add_argument("--output-dir", default="./immunoorg-defender", help="Output directory")
225
+ parser.add_argument("--epochs", type=int, default=3, help="Training epochs")
226
+ parser.add_argument("--batch-size", type=int, default=2, help="Per-device batch size")
227
+ parser.add_argument("--lr", type=float, default=5e-6, help="Learning rate")
228
+ parser.add_argument(
229
+ "--num-generations",
230
+ type=int,
231
+ default=2,
232
+ help="GRPO generations per prompt (must divide batch-size × prompts per step; default 2 with batch 2)",
233
+ )
234
+ parser.add_argument("--max-completion-length", type=int, default=1024, help="Max completion tokens")
235
+ return parser
236
+
237
+
238
+ def parse_train_args(argv: list[str] | None = None) -> Namespace:
239
+ return build_arg_parser().parse_args(argv)
240
+
241
+
242
+ def _maybe_push_to_hub(output_dir: str) -> None:
243
+ """If HF_TRAINING_PUSH_REPO_ID is set, upload ``output_dir`` to the Hub (uses HF_TOKEN)."""
244
+ repo_id = os.environ.get("HF_TRAINING_PUSH_REPO_ID", "").strip()
245
+ if not repo_id:
246
+ return
247
+ token = os.environ.get("HF_TOKEN") or os.environ.get("HUGGINGFACE_HUB_TOKEN")
248
+ if not token:
249
+ print("HF_TRAINING_PUSH_REPO_ID is set but HF_TOKEN is missing; skipping Hub upload.")
250
+ return
251
+ from huggingface_hub import HfApi
252
+
253
+ api = HfApi(token=token)
254
+ api.create_repo(repo_id, repo_type="model", exist_ok=True)
255
+ api.upload_folder(folder_path=output_dir, repo_id=repo_id, repo_type="model")
256
+ print(f"Uploaded training artifacts to https://huggingface.co/{repo_id}")
257
+
258
+
259
+ def run_grpo_training(args: Namespace) -> None:
260
+ print("=" * 60)
261
+ print("ImmunoOrg GRPO Training Pipeline")
262
+ print("=" * 60)
263
+
264
+ # Try importing training libs
265
+ try:
266
+ from unsloth import FastLanguageModel
267
+ HAS_UNSLOTH = True
268
+ print("Check: Unsloth detected - using optimized training")
269
+ except ImportError:
270
+ HAS_UNSLOTH = False
271
+ print("Warning: Unsloth not found - using standard HF training")
272
+
273
+ try:
274
+ from trl import GRPOTrainer, GRPOConfig
275
+ from datasets import Dataset
276
+ print("Check: TRL and datasets loaded")
277
+ except ImportError:
278
+ print("Error: TRL not installed. Run: pip install trl datasets")
279
+ sys.exit(1)
280
+
281
+ # 1. Load model
282
+ print(f"\nLoading model: {args.model}")
283
+ if HAS_UNSLOTH:
284
+ model, tokenizer = FastLanguageModel.from_pretrained(
285
+ args.model,
286
+ max_seq_length=2048,
287
+ load_in_4bit=True,
288
+ )
289
+ model = FastLanguageModel.get_peft_model(
290
+ model, r=16, lora_alpha=16,
291
+ target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
292
+ "gate_proj", "up_proj", "down_proj"],
293
+ )
294
+ else:
295
+ from transformers import AutoModelForCausalLM, AutoTokenizer
296
+ from peft import get_peft_model, LoraConfig
297
+ import torch as _torch
298
+ tokenizer = AutoTokenizer.from_pretrained(args.model)
299
+ device_map = "auto" if _torch.cuda.is_available() else "cpu"
300
+ model = AutoModelForCausalLM.from_pretrained(args.model, device_map=device_map)
301
+ lora_config = LoraConfig(r=16, lora_alpha=16, target_modules="all-linear")
302
+ model = get_peft_model(model, lora_config)
303
+
304
+ if tokenizer.pad_token is None:
305
+ tokenizer.pad_token = tokenizer.eos_token
306
+
307
+ # 2. Build dataset
308
+ print("\nBuilding training dataset")
309
+ n_prompts = 8 if args.smoke_test else 200
310
+ scenarios = build_training_prompts(num_prompts=n_prompts)
311
+ dataset = Dataset.from_list([{"prompt": s["prompt"]} for s in scenarios])
312
+ print(f" {len(dataset)} scenarios loaded")
313
+ if args.smoke_test:
314
+ print(
315
+ " [smoke-test] Using a tiny dataset, 2 optimizer steps, max_completion_length capped at 256."
316
+ )
317
+ if len(dataset) > 32:
318
+ print(
319
+ " [smoke-test] WARNING: dataset is large for a smoke run — use latest train_grpo.py (8 prompts)."
320
+ )
321
+
322
+ # 3. Configure GRPO
323
+ max_steps = 2 if args.smoke_test else None
324
+ try:
325
+ import torch as _torch
326
+ on_cpu = not _torch.cuda.is_available()
327
+ except Exception:
328
+ on_cpu = True
329
+
330
+ # Smoke: short completions = much faster generation on CPU
331
+ max_len = min(args.max_completion_length, 256) if args.smoke_test else args.max_completion_length
332
+
333
+ if on_cpu:
334
+ print(
335
+ "\n NOTE: No CUDA — GRPO spends most of its time in **generation** (not shown in % bar).\n"
336
+ " First step can take **many minutes** on CPU for a 0.5B model. For quick runs use Colab GPU.\n"
337
+ )
338
+
339
+ grpo_kwargs = dict(
340
+ output_dir=args.output_dir,
341
+ num_generations=args.num_generations,
342
+ max_completion_length=max_len,
343
+ per_device_train_batch_size=args.batch_size,
344
+ learning_rate=args.lr,
345
+ num_train_epochs=args.epochs,
346
+ beta=0.04,
347
+ logging_steps=1,
348
+ save_steps=50,
349
+ max_steps=max_steps if max_steps else -1,
350
+ report_to="none",
351
+ )
352
+ if on_cpu:
353
+ # Recent TRL refuses to start with bf16/fp16 defaults if no GPU is
354
+ # present; we have to opt into CPU explicitly and turn mixed
355
+ # precision off so the smoke test works on a developer laptop.
356
+ grpo_kwargs.update(use_cpu=True, bf16=False, fp16=False)
357
+ config = GRPOConfig(**grpo_kwargs)
358
+
359
+ # 4. Create trainer (TRL renamed `config` -> `args` somewhere in 1.x;
360
+ # try both so this works against older and newer TRL releases.)
361
+ print("\nCreating GRPO trainer")
362
+ trainer_kwargs = dict(
363
+ model=model,
364
+ reward_funcs=[format_reward, reasoning_quality_reward, phase_appropriate_reward],
365
+ train_dataset=dataset,
366
+ processing_class=tokenizer,
367
+ )
368
+ try:
369
+ trainer = GRPOTrainer(args=config, **trainer_kwargs)
370
+ except TypeError:
371
+ trainer = GRPOTrainer(config=config, **trainer_kwargs)
372
+
373
+ # 5. Train
374
+ print("\nStarting training...")
375
+ trainer.train()
376
+
377
+ # 5b. Export log history for evidence plots (hackathon / README)
378
+ try:
379
+ from pathlib import Path as _Path
380
+
381
+ log_hist = getattr(getattr(trainer, "state", None), "log_history", None) or []
382
+ out_dir = _Path(args.output_dir)
383
+ out_dir.mkdir(parents=True, exist_ok=True)
384
+ log_path = out_dir / "grpo_log_history.json"
385
+ with open(log_path, "w", encoding="utf-8") as f:
386
+ json.dump(log_hist, f, indent=2)
387
+ print(f" Wrote training log history -> {log_path}")
388
+ except Exception as _e:
389
+ print(f" (could not export log_history: {_e})")
390
+
391
+ # 6. Save
392
+ print(f"\nSaving model to {args.output_dir}")
393
+ os.makedirs(args.output_dir, exist_ok=True)
394
+ trainer.save_model(args.output_dir)
395
+ tokenizer.save_pretrained(args.output_dir)
396
+ print("Training complete!")
397
+ _maybe_push_to_hub(args.output_dir)
398
+
399
+
400
+ def main() -> None:
401
+ run_grpo_training(parse_train_args())
402
+
403
+
404
+ if __name__ == "__main__":
405
+ main()