# Page-viewer residue kept for reference: Spaces: Paused | File size: 8,401 Bytes | e7b864e
"""Tests for the ImmunoOrg environment."""
import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from immunoorg.environment import ImmunoOrgEnvironment
from immunoorg.models import (
ImmunoAction, ActionType, TacticalAction, DiagnosticAction,
StrategicAction, IncidentPhase,
)
from immunoorg.network_graph import NetworkGraph
from immunoorg.org_graph import OrgGraph
from immunoorg.reward import RewardCalculator
from immunoorg.belief_map import BeliefMap
from immunoorg.curriculum import CurriculumEngine
def test_network_graph_generation():
    """Check that NetworkGraph generates a non-empty topology at each difficulty.

    For difficulties 1 through 4 (fixed seed 42) a topology is generated and
    the test asserts that it has at least one node and one edge, and that
    every node reports ``health == 1.0``.
    """
    for diff in (1, 2, 3, 4):
        ng = NetworkGraph(difficulty=diff, seed=42)
        ng.generate_topology()
        nodes = ng.get_all_nodes()
        edges = ng.get_all_edges()
        assert len(nodes) > 0, f"No nodes at difficulty {diff}"
        assert len(edges) > 0, f"No edges at difficulty {diff}"
        assert all(n.health == 1.0 for n in nodes), "Initial health should be 1.0"
        # Status line: the check-mark literal must stay on a single line,
        # because a single-quoted f-string cannot contain a raw newline.
        print(f" ✅ Difficulty {diff}: {len(nodes)} nodes, {len(edges)} edges")
def test_org_graph_generation():
    """Check that OrgGraph builds a non-empty structure with a bounded efficiency.

    A difficulty-2 network (seed 42) is generated first; its node keys are
    passed to ``OrgGraph.generate_org_structure``. The test asserts that the
    org graph has nodes and edges and that ``calculate_org_efficiency()``
    lies in the closed interval [0.0, 1.0].
    """
    ng = NetworkGraph(difficulty=2, seed=42)
    ng.generate_topology()

    og = OrgGraph(difficulty=2, seed=42)
    og.generate_org_structure(list(ng.nodes.keys()))

    nodes = og.get_all_nodes()
    edges = og.get_all_edges()
    assert len(nodes) > 0
    assert len(edges) > 0

    efficiency = og.calculate_org_efficiency()
    assert 0.0 <= efficiency <= 1.0
    # Status line kept on one line so the f-string literal is terminated.
    print(f" ✅ {len(nodes)} departments, {len(edges)} channels, efficiency={efficiency:.2f}")
def test_environment_reset():
    """Check that ``reset()`` returns a populated initial observation.

    Asserts that a fresh difficulty-1 environment (seed 42) starts in the
    DETECTION phase at step 0, exposes at least one visible node and one org
    node, and reports a positive threat level.
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()

    assert obs.current_phase == IncidentPhase.DETECTION
    assert obs.step_count == 0
    assert len(obs.visible_nodes) > 0
    assert len(obs.org_nodes) > 0
    assert obs.threat_level > 0  # Attack was launched
    # Status line kept on one line so the f-string literal is terminated.
    print(f" ✅ Reset OK: {len(obs.visible_nodes)} nodes, threat={obs.threat_level:.2f}")
def test_environment_step_tactical():
    """Check that a tactical SCAN_LOGS action advances the environment.

    Steps a freshly reset difficulty-1 environment (seed 42) with a
    SCAN_LOGS action aimed at the first visible node, then asserts that the
    step counter is 1 and that the action result mentions a scan
    (case-insensitive).
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()

    # Scan logs
    action = ImmunoAction(
        action_type=ActionType.TACTICAL,
        tactical_action=TacticalAction.SCAN_LOGS,
        target=obs.visible_nodes[0].id,
        reasoning="Scanning for attack indicators.",
    )
    # The termination flag is not relevant to this single-step check.
    obs2, reward, _ = env.step(action)

    assert obs2.step_count == 1
    # A case-insensitive search for "scan" also matches "Scanned", so the
    # separate case-sensitive clause was redundant.
    assert "scan" in obs2.action_result.lower()
    print(f" ✅ Scan logs: reward={reward:.3f}, result={obs2.action_result[:60]}")
def test_environment_step_diagnostic():
    """Check that a diagnostic IDENTIFY_SILO action reports success.

    Steps a freshly reset difficulty-1 environment (seed 42) with an
    IDENTIFY_SILO action that has an empty target and asserts that the
    returned observation has ``action_success`` set.
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    env.reset()

    action = ImmunoAction(
        action_type=ActionType.DIAGNOSTIC,
        diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
        target="",
        reasoning="Looking for organizational silos.",
    )
    # Only the observation is inspected; reward and termination flag are unused.
    obs2, _, _ = env.step(action)

    assert obs2.action_success
    print(f" ✅ Identify silo: {obs2.action_result[:60]}")
def test_full_episode_level1():
    """Run a scripted multi-step episode at difficulty 1 as a smoke test.

    The script scans logs, traces the attack path, isolates the first
    visible node flagged ``compromised`` (if any), correlates the failure,
    identifies silos, establishes DevSecOps on ``dept-security`` and finally
    measures org latency. Stepping stops early if the environment reports
    termination. The test makes no assertions of its own; it passes as long
    as no step raises.
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()

    actions = [
        ImmunoAction(
            action_type=ActionType.TACTICAL,
            tactical_action=TacticalAction.SCAN_LOGS,
            target=obs.visible_nodes[0].id,
            reasoning="Scanning logs.",
        ),
        ImmunoAction(
            action_type=ActionType.DIAGNOSTIC,
            diagnostic_action=DiagnosticAction.TRACE_ATTACK_PATH,
            target="",
            reasoning="Tracing attack.",
        ),
    ]

    # Find compromised node and isolate it (only the first one, if present).
    compromised = next((node for node in obs.visible_nodes if node.compromised), None)
    if compromised is not None:
        actions.append(ImmunoAction(
            action_type=ActionType.TACTICAL,
            tactical_action=TacticalAction.ISOLATE_NODE,
            target=compromised.id,
            reasoning="Isolating compromised node.",
        ))

    # RCA
    actions.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC,
        diagnostic_action=DiagnosticAction.CORRELATE_FAILURE,
        target="",
        parameters={
            "technical_indicator": "attack",
            "organizational_flaw": "no_devsecops",
            "confidence": 0.7,
        },
        reasoning="Correlating failure.",
    ))
    actions.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC,
        diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
        target="",
        reasoning="Finding silos.",
    ))

    # Refactor
    actions.append(ImmunoAction(
        action_type=ActionType.STRATEGIC,
        strategic_action=StrategicAction.ESTABLISH_DEVSECOPS,
        target="dept-security",
        reasoning="Establishing DevSecOps.",
    ))

    # Validation
    actions.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC,
        diagnostic_action=DiagnosticAction.MEASURE_ORG_LATENCY,
        target="",
        reasoning="Validating improvements.",
    ))

    total_reward = 0.0
    done = False  # bound up front so the summary line never depends on the loop
    for action in actions:
        obs, reward, done = env.step(action)
        total_reward += reward
        if done:
            break

    print(f" ✅ Full episode: {env.state.step_count} steps, reward={total_reward:.3f}, "
          f"phase={obs.current_phase.value}, terminated={done}")
def test_curriculum_advancement():
    """Check that three successful episodes advance the curriculum to level 2.

    Starts a ``CurriculumEngine`` at level 1, records three identical
    episode results (full containment, downtime 5.0, reward 1.0) and asserts
    that ``current_level`` is then 2.
    """
    ce = CurriculumEngine(start_level=1)

    # Simulate 3 successes; the return value of record_episode_result is not needed.
    for _ in range(3):
        ce.record_episode_result({
            "threats_contained_ratio": 1.0,
            "total_downtime": 5.0,
            "total_reward": 1.0,
        })

    assert ce.current_level == 2, f"Should advance to level 2, got {ce.current_level}"
    print(f" ✅ Curriculum advanced to level {ce.current_level}")
def test_belief_map_accuracy():
    """Check that a matching correlation yields a positive belief accuracy.

    Sets a single ground-truth entry (``sql_injection`` on ``db-01``),
    records one agent correlation naming the same vector and target, and
    asserts that ``calculate_belief_accuracy()`` is greater than zero.
    """
    bm = BeliefMap()
    bm.set_ground_truth([{"vector": "sql_injection", "target": "db-01"}])

    # Agent guesses correctly
    bm.agent_correlate("sql_injection_on_db-01", "no_devsecops", 0.8, ["evidence"], 1.0)

    accuracy = bm.calculate_belief_accuracy()
    assert accuracy > 0, f"Accuracy should be > 0, got {accuracy}"
    print(f" ✅ Belief accuracy: {accuracy:.2f}")
def test_reward_calculator():
    """Check that ``RewardCalculator.compute_step_reward`` returns a float.

    Uses the state of a freshly reset difficulty-1 environment (seed 42) and
    a successful SCAN_LOGS action with an unchanged threat count, then
    asserts only on the type of the returned reward.
    """
    rc = RewardCalculator()

    # Named ``env`` because this is the environment; its ``.state`` attribute
    # is what the reward calculator consumes.
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    env.reset()

    action = ImmunoAction(
        action_type=ActionType.TACTICAL,
        tactical_action=TacticalAction.SCAN_LOGS,
        target="test",
        reasoning="Testing the reward function with a scan action.",
    )
    reward = rc.compute_step_reward(
        state=env.state,
        action=action,
        action_success=True,
        threats_before=1,
        threats_after=1,
        belief_accuracy=0.0,
        org_chaos=0.0,
        downtime_delta=0.0,
    )

    assert isinstance(reward, float)
    print(f" ✅ Reward computed: {reward:.3f}")
def run_all_tests():
    """Run every test in this module and print a pass/fail summary.

    Each test is run in turn; any exception (including ``AssertionError``)
    is caught, reported and counted as a failure so the remaining tests
    still execute.

    Returns:
        True when no test failed, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("ImmunoOrg Test Suite")
    print(banner)

    tests = [
        ("Network Graph Generation", test_network_graph_generation),
        ("Org Graph Generation", test_org_graph_generation),
        ("Environment Reset", test_environment_reset),
        ("Tactical Actions", test_environment_step_tactical),
        ("Diagnostic Actions", test_environment_step_diagnostic),
        ("Full Episode Level 1", test_full_episode_level1),
        ("Curriculum Advancement", test_curriculum_advancement),
        ("Belief Map Accuracy", test_belief_map_accuracy),
        ("Reward Calculator", test_reward_calculator),
    ]

    passed = 0
    failed = 0
    for name, test_fn in tests:
        print(f"\n🧪 {name}")
        try:
            test_fn()
        except Exception as e:  # top-level runner boundary: report and continue
            # Include the exception type: a bare ``assert`` has an empty
            # message, which would otherwise print "FAILED: " with no detail.
            print(f" ❌ FAILED: {type(e).__name__}: {e}")
            failed += 1
        else:
            passed += 1

    print(f"\n{banner}")
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)}")
    print(banner)
    return failed == 0
if __name__ == "__main__":
    # Script entry point: exit status 0 when the whole suite passed, 1 otherwise.
    sys.exit(0 if run_all_tests() else 1)
|