hirann committed on
Commit
1ff4aed
·
verified ·
1 Parent(s): 6305fd5

Upload training/golden_trajectories.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. training/golden_trajectories.py +119 -119
training/golden_trajectories.py CHANGED
@@ -1,119 +1,119 @@
1
- """
2
- Golden Trajectory Extraction
3
- =============================
4
- Extract optimal trajectories from environment rollouts for SFT warm-start.
5
- """
6
-
7
- from __future__ import annotations
8
- import json
9
- import sys, os
10
- sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
-
12
- from immunoorg.environment import ImmunoOrgEnvironment
13
- from immunoorg.models import ImmunoAction, TacticalAction, DiagnosticAction, StrategicAction, ActionType
14
-
15
-
16
- # Hand-crafted golden trajectories for each difficulty level
17
- GOLDEN_TRAJECTORIES = {
18
- 1: [
19
- {"action_type": "tactical", "tactical_action": "scan_logs", "reasoning": "Step 1: Scan logs on the first compromised node to identify attack vector and confirm the threat before taking containment action."},
20
- {"action_type": "diagnostic", "diagnostic_action": "trace_attack_path", "reasoning": "Step 2: Trace the attack path to understand scope. Need to know if lateral movement occurred."},
21
- {"action_type": "tactical", "tactical_action": "isolate_node", "reasoning": "Step 3: Isolate the compromised node to stop the attack. This is the fastest containment action."},
22
- {"action_type": "diagnostic", "diagnostic_action": "correlate_failure", "reasoning": "Step 4: Correlate the technical failure to organizational root cause. The attack vector indicates a gap in security practices."},
23
- {"action_type": "diagnostic", "diagnostic_action": "identify_silo", "reasoning": "Step 5: Check for organizational silos that may have slowed response."},
24
- {"action_type": "diagnostic", "diagnostic_action": "measure_org_latency", "reasoning": "Step 6: Validate organizational improvements by measuring current efficiency metrics."},
25
- ],
26
- 2: [
27
- {"action_type": "tactical", "tactical_action": "scan_logs", "reasoning": "Phase 1 Detection: Scanning logs to identify initial compromise point and attack vector."},
28
- {"action_type": "diagnostic", "diagnostic_action": "trace_attack_path", "reasoning": "Tracing lateral movement path to understand the full kill chain before containment."},
29
- {"action_type": "diagnostic", "diagnostic_action": "timeline_reconstruct", "reasoning": "Reconstructing timeline to identify sequence of compromises for targeted containment."},
30
- {"action_type": "tactical", "tactical_action": "isolate_node", "reasoning": "Isolating the source node first to cut off the attacker's entry point."},
31
- {"action_type": "tactical", "tactical_action": "deploy_patch", "reasoning": "Patching the initially compromised node to close the vulnerability."},
32
- {"action_type": "diagnostic", "diagnostic_action": "correlate_failure", "reasoning": "Correlating the lateral movement success to flat network segmentation and security-engineering silo."},
33
- {"action_type": "diagnostic", "diagnostic_action": "identify_silo", "reasoning": "Identifying silos that delayed incident response."},
34
- {"action_type": "strategic", "strategic_action": "create_shortcut_edge", "reasoning": "Creating fast communication channel between security and engineering to improve future response."},
35
- {"action_type": "diagnostic", "diagnostic_action": "measure_org_latency", "reasoning": "Validating that the organizational change improved efficiency."},
36
- ],
37
- }
38
-
39
-
40
- def extract_golden_trajectories(num_episodes: int = 10, difficulty: int = 1) -> list[dict]:
41
- """Run episodes and extract the best trajectories."""
42
- best_trajectories = []
43
-
44
- for ep in range(num_episodes):
45
- env = ImmunoOrgEnvironment(difficulty=difficulty, seed=ep)
46
- obs = env.reset()
47
-
48
- trajectory = []
49
- golden = GOLDEN_TRAJECTORIES.get(difficulty, GOLDEN_TRAJECTORIES[1])
50
-
51
- total_reward = 0.0
52
- for i, action_template in enumerate(golden):
53
- action = ImmunoAction(
54
- action_type=ActionType(action_template["action_type"]),
55
- tactical_action=TacticalAction(action_template["tactical_action"]) if action_template.get("tactical_action") else None,
56
- strategic_action=StrategicAction(action_template["strategic_action"]) if action_template.get("strategic_action") else None,
57
- diagnostic_action=DiagnosticAction(action_template["diagnostic_action"]) if action_template.get("diagnostic_action") else None,
58
- target=_get_target(obs, action_template),
59
- reasoning=action_template["reasoning"],
60
- )
61
-
62
- obs_obj, reward, terminated = env.step(action)
63
- total_reward += reward
64
- obs = obs_obj
65
-
66
- trajectory.append({
67
- "step": i,
68
- "action": action.model_dump(),
69
- "reward": reward,
70
- "observation_summary": f"phase={obs.current_phase.value} threat={obs.threat_level:.2f}",
71
- })
72
-
73
- if terminated:
74
- break
75
-
76
- best_trajectories.append({
77
- "episode": ep,
78
- "difficulty": difficulty,
79
- "total_reward": total_reward,
80
- "steps": len(trajectory),
81
- "trajectory": trajectory,
82
- })
83
-
84
- # Sort by reward and return top trajectories
85
- best_trajectories.sort(key=lambda t: t["total_reward"], reverse=True)
86
- return best_trajectories[:max(1, num_episodes // 2)]
87
-
88
-
89
- def _get_target(obs, action_template: dict) -> str:
90
- """Pick an appropriate target from the observation."""
91
- action_type = action_template.get("action_type", "")
92
- if action_type == "tactical":
93
- nodes = obs.visible_nodes if hasattr(obs, 'visible_nodes') else []
94
- compromised = [n for n in nodes if n.compromised]
95
- if compromised:
96
- return compromised[0].id
97
- if nodes:
98
- return nodes[0].id
99
- elif action_type == "strategic":
100
- return "dept-security"
101
- return ""
102
-
103
-
104
- def save_golden_trajectories(output_path: str = "golden_trajectories.json"):
105
- """Generate and save golden trajectories for all difficulty levels."""
106
- all_trajectories = {}
107
- for diff in [1, 2]:
108
- print(f"Extracting golden trajectories for difficulty {diff}...")
109
- trajectories = extract_golden_trajectories(num_episodes=5, difficulty=diff)
110
- all_trajectories[f"level_{diff}"] = trajectories
111
- print(f" Found {len(trajectories)} good trajectories (best reward: {trajectories[0]['total_reward']:.3f})")
112
-
113
- with open(output_path, "w") as f:
114
- json.dump(all_trajectories, f, indent=2, default=str)
115
- print(f"\n💾 Saved to {output_path}")
116
-
117
-
118
- if __name__ == "__main__":
119
- save_golden_trajectories()
 
1
+ """
2
+ Golden Trajectory Extraction
3
+ =============================
4
+ Extract optimal trajectories from environment rollouts for SFT warm-start.
5
+ """
6
+
7
+ from __future__ import annotations
8
+ import json
9
+ import sys, os
10
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
11
+
12
+ from immunoorg.environment import ImmunoOrgEnvironment
13
+ from immunoorg.models import ImmunoAction, TacticalAction, DiagnosticAction, StrategicAction, ActionType
14
+
15
+
16
# Hand-crafted golden trajectories, keyed by difficulty level.
#
# Each level maps to an ordered list of step templates. A template carries:
#   - "action_type": the action category ("tactical", "diagnostic" or "strategic")
#   - "<category>_action": the concrete action name within that category
#   - "reasoning": the free-text rationale stored alongside the action
GOLDEN_TRAJECTORIES = {
    1: [
        {
            "action_type": "tactical",
            "tactical_action": "scan_logs",
            "reasoning": "Step 1: Scan logs on the first compromised node to identify attack vector and confirm the threat before taking containment action.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "trace_attack_path",
            "reasoning": "Step 2: Trace the attack path to understand scope. Need to know if lateral movement occurred.",
        },
        {
            "action_type": "tactical",
            "tactical_action": "isolate_node",
            "reasoning": "Step 3: Isolate the compromised node to stop the attack. This is the fastest containment action.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "correlate_failure",
            "reasoning": "Step 4: Correlate the technical failure to organizational root cause. The attack vector indicates a gap in security practices.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "identify_silo",
            "reasoning": "Step 5: Check for organizational silos that may have slowed response.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "measure_org_latency",
            "reasoning": "Step 6: Validate organizational improvements by measuring current efficiency metrics.",
        },
    ],
    2: [
        {
            "action_type": "tactical",
            "tactical_action": "scan_logs",
            "reasoning": "Phase 1 Detection: Scanning logs to identify initial compromise point and attack vector.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "trace_attack_path",
            "reasoning": "Tracing lateral movement path to understand the full kill chain before containment.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "timeline_reconstruct",
            "reasoning": "Reconstructing timeline to identify sequence of compromises for targeted containment.",
        },
        {
            "action_type": "tactical",
            "tactical_action": "isolate_node",
            "reasoning": "Isolating the source node first to cut off the attacker's entry point.",
        },
        {
            "action_type": "tactical",
            "tactical_action": "deploy_patch",
            "reasoning": "Patching the initially compromised node to close the vulnerability.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "correlate_failure",
            "reasoning": "Correlating the lateral movement success to flat network segmentation and security-engineering silo.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "identify_silo",
            "reasoning": "Identifying silos that delayed incident response.",
        },
        {
            "action_type": "strategic",
            "strategic_action": "create_shortcut_edge",
            "reasoning": "Creating fast communication channel between security and engineering to improve future response.",
        },
        {
            "action_type": "diagnostic",
            "diagnostic_action": "measure_org_latency",
            "reasoning": "Validating that the organizational change improved efficiency.",
        },
    ],
}
38
+
39
+
40
def extract_golden_trajectories(num_episodes: int = 10, difficulty: int = 1) -> list[dict]:
    """Replay the scripted golden actions and keep the highest-reward episodes.

    A fresh environment is created for every episode, seeded with the episode
    index. The script for ``difficulty`` is then replayed step by step until
    it is exhausted or the environment reports termination.

    Args:
        num_episodes: Number of episodes to run (seeds ``0 .. num_episodes - 1``).
        difficulty: Difficulty passed to the environment. It also selects the
            script from ``GOLDEN_TRAJECTORIES``; a level without a script
            falls back to the level-1 script.

    Returns:
        Episode records sorted by descending ``total_reward`` and truncated to
        the first ``max(1, num_episodes // 2)`` entries. Each record has the
        keys ``episode``, ``difficulty``, ``total_reward``, ``steps`` and
        ``trajectory`` (a list of per-step dicts).
    """
    # The script depends only on the difficulty, so resolve it once instead of
    # once per episode.
    script = GOLDEN_TRAJECTORIES.get(difficulty, GOLDEN_TRAJECTORIES[1])

    episodes = []
    for ep in range(num_episodes):
        env = ImmunoOrgEnvironment(difficulty=difficulty, seed=ep)
        obs = env.reset()

        trajectory = []
        total_reward = 0.0
        for i, template in enumerate(script):
            # A template names exactly one concrete action; the other two
            # category fields stay None.
            tactical = template.get("tactical_action")
            strategic = template.get("strategic_action")
            diagnostic = template.get("diagnostic_action")

            action = ImmunoAction(
                action_type=ActionType(template["action_type"]),
                tactical_action=TacticalAction(tactical) if tactical else None,
                strategic_action=StrategicAction(strategic) if strategic else None,
                diagnostic_action=DiagnosticAction(diagnostic) if diagnostic else None,
                # The target is chosen from the observation *before* this step.
                target=_get_target(obs, template),
                reasoning=template["reasoning"],
            )

            obs, reward, terminated = env.step(action)
            total_reward += reward

            trajectory.append({
                "step": i,
                "action": action.model_dump(),
                "reward": reward,
                # Summary reflects the observation returned by this step.
                "observation_summary": f"phase={obs.current_phase.value} threat={obs.threat_level:.2f}",
            })

            if terminated:
                break

        episodes.append({
            "episode": ep,
            "difficulty": difficulty,
            "total_reward": total_reward,
            "steps": len(trajectory),
            "trajectory": trajectory,
        })

    # Best episodes first, then keep the top half (at least one slot).
    episodes.sort(key=lambda record: record["total_reward"], reverse=True)
    return episodes[:max(1, num_episodes // 2)]
87
+
88
+
89
+ def _get_target(obs, action_template: dict) -> str:
90
+ """Pick an appropriate target from the observation."""
91
+ action_type = action_template.get("action_type", "")
92
+ if action_type == "tactical":
93
+ nodes = obs.visible_nodes if hasattr(obs, 'visible_nodes') else []
94
+ compromised = [n for n in nodes if n.compromised]
95
+ if compromised:
96
+ return compromised[0].id
97
+ if nodes:
98
+ return nodes[0].id
99
+ elif action_type == "strategic":
100
+ return "dept-security"
101
+ return ""
102
+
103
+
104
def save_golden_trajectories(output_path: str = "golden_trajectories.json"):
    """Generate golden trajectories for every scripted level and save them as JSON.

    For each difficulty level that has an entry in ``GOLDEN_TRAJECTORIES``,
    five episodes are run through ``extract_golden_trajectories`` and the
    returned records are stored under the key ``"level_<difficulty>"``.

    Args:
        output_path: Destination file; it is created or overwritten.
    """
    all_trajectories = {}
    # Derive the levels from the script table so the two cannot drift apart.
    for diff in sorted(GOLDEN_TRAJECTORIES):
        print(f"Extracting golden trajectories for difficulty {diff}...")
        trajectories = extract_golden_trajectories(num_episodes=5, difficulty=diff)
        all_trajectories[f"level_{diff}"] = trajectories
        if trajectories:
            print(f" Found {len(trajectories)} good trajectories (best reward: {trajectories[0]['total_reward']:.3f})")
        else:
            # Guard the [0] lookup: an empty result has no best reward to report.
            print(" Found 0 good trajectories")

    # Explicit encoding keeps the output independent of the platform default.
    with open(output_path, "w", encoding="utf-8") as f:
        # default=str stringifies any value json cannot serialize natively.
        json.dump(all_trajectories, f, indent=2, default=str)
    print(f"\n💾 Saved to {output_path}")
116
+
117
+
118
# Script entry point: when this module is executed directly, generate the
# golden trajectories and write them to the default output path.
if __name__ == "__main__":
    save_golden_trajectories()