File size: 8,401 Bytes
e7b864e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
"""Tests for the ImmunoOrg environment."""

import sys
import os
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from immunoorg.environment import ImmunoOrgEnvironment
from immunoorg.models import (
    ImmunoAction, ActionType, TacticalAction, DiagnosticAction,
    StrategicAction, IncidentPhase,
)
from immunoorg.network_graph import NetworkGraph
from immunoorg.org_graph import OrgGraph
from immunoorg.reward import RewardCalculator
from immunoorg.belief_map import BeliefMap
from immunoorg.curriculum import CurriculumEngine


def test_network_graph_generation():
    """Check that each difficulty level produces a non-empty, healthy topology.

    For difficulties 1 through 4 a ``NetworkGraph`` is created with a fixed
    seed, its topology is generated, and the test asserts that the graph has
    at least one node and one edge and that every node starts with
    ``health == 1.0``.
    """
    for diff in range(1, 5):
        ng = NetworkGraph(difficulty=diff, seed=42)
        ng.generate_topology()
        nodes = ng.get_all_nodes()
        edges = ng.get_all_edges()
        assert len(nodes) > 0, f"No nodes at difficulty {diff}"
        assert len(edges) > 0, f"No edges at difficulty {diff}"
        assert all(n.health == 1.0 for n in nodes), "Initial health should be 1.0"
        # Status glyph restored from its mis-encoded (mojibake) form.
        print(f"  ✅ Difficulty {diff}: {len(nodes)} nodes, {len(edges)} edges")


def test_org_graph_generation():
    """Check that an org graph built on top of a network graph is well-formed.

    A difficulty-2 ``NetworkGraph`` is generated first; its node keys are
    passed to ``OrgGraph.generate_org_structure``. The test then asserts the
    org graph has nodes and edges and that ``calculate_org_efficiency()``
    falls within ``[0.0, 1.0]``.
    """
    ng = NetworkGraph(difficulty=2, seed=42)
    ng.generate_topology()
    og = OrgGraph(difficulty=2, seed=42)
    og.generate_org_structure(list(ng.nodes.keys()))
    nodes = og.get_all_nodes()
    edges = og.get_all_edges()
    # Messages matter here: the script runner only prints str(exception), so
    # a bare assert would produce an empty failure line.
    assert len(nodes) > 0, "Org graph has no nodes"
    assert len(edges) > 0, "Org graph has no edges"
    efficiency = og.calculate_org_efficiency()
    assert 0.0 <= efficiency <= 1.0, f"Efficiency outside [0, 1]: {efficiency}"
    print(f"  ✅ {len(nodes)} departments, {len(edges)} channels, efficiency={efficiency:.2f}")


def test_environment_reset():
    """Check that ``reset()`` returns a sane initial observation.

    Builds a difficulty-1 environment with a fixed seed and asserts that the
    first observation is in the DETECTION phase, has a step count of 0,
    exposes at least one visible network node and one org node, and reports
    a positive threat level.
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()
    # Each assert carries a message so a failure is identifiable from the
    # runner output, which prints only str(exception).
    assert obs.current_phase == IncidentPhase.DETECTION, (
        f"Expected DETECTION phase after reset, got {obs.current_phase}"
    )
    assert obs.step_count == 0, f"Expected step_count 0 after reset, got {obs.step_count}"
    assert len(obs.visible_nodes) > 0, "No visible network nodes after reset"
    assert len(obs.org_nodes) > 0, "No org nodes after reset"
    # A positive threat level indicates an attack was launched during reset.
    assert obs.threat_level > 0, f"Expected positive threat level, got {obs.threat_level}"
    print(f"  ✅ Reset OK: {len(obs.visible_nodes)} nodes, threat={obs.threat_level:.2f}")


def test_environment_step_tactical():
    """Check that a tactical SCAN_LOGS action is executed by ``step()``.

    After resetting a difficulty-1 environment, a SCAN_LOGS action targeting
    the first visible node is submitted. The test asserts that the step
    counter advanced to 1 and that the action result text mentions a scan
    (case-insensitive).
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()

    # Scan logs
    action = ImmunoAction(
        action_type=ActionType.TACTICAL,
        tactical_action=TacticalAction.SCAN_LOGS,
        target=obs.visible_nodes[0].id,
        reasoning="Scanning for attack indicators.",
    )
    obs2, reward, _done = env.step(action)
    assert obs2.step_count == 1, f"Expected step_count 1, got {obs2.step_count}"
    # The case-insensitive check alone suffices: any result containing
    # "Scanned" also contains "scan" once lower-cased.
    assert "scan" in obs2.action_result.lower(), (
        f"Scan result does not mention a scan: {obs2.action_result!r}"
    )
    print(f"  ✅ Scan logs: reward={reward:.3f}, result={obs2.action_result[:60]}")


def test_environment_step_diagnostic():
    """Check that a diagnostic IDENTIFY_SILO action reports success.

    After resetting a difficulty-1 environment, an IDENTIFY_SILO action with
    an empty target is submitted, and the test asserts that the resulting
    observation has a truthy ``action_success``.
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()

    action = ImmunoAction(
        action_type=ActionType.DIAGNOSTIC,
        diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
        target="",
        reasoning="Looking for organizational silos.",
    )
    # Reward and termination flag are not examined by this test.
    obs2, _reward, _done = env.step(action)
    assert obs2.action_success, f"IDENTIFY_SILO failed: {obs2.action_result!r}"
    print(f"  ✅ Identify silo: {obs2.action_result[:60]}")


def test_full_episode_level1():
    """Drive a scripted multi-phase episode through a difficulty-1 environment.

    The scripted sequence is: scan logs on the first visible node, trace the
    attack path, isolate the first compromised visible node (if any),
    correlate the failure, identify silos, establish DevSecOps, and measure
    org latency. Actions are submitted in order until the list is exhausted
    or the environment reports the episode is done. The accumulated reward,
    step count, final phase, and termination flag are printed.
    """
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()

    actions = [
        ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.SCAN_LOGS,
                     target=obs.visible_nodes[0].id, reasoning="Scanning logs."),
        ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.TRACE_ATTACK_PATH,
                     target="", reasoning="Tracing attack."),
    ]

    # Find compromised node and isolate it (only the first one found).
    for node in obs.visible_nodes:
        if node.compromised:
            actions.append(ImmunoAction(
                action_type=ActionType.TACTICAL, tactical_action=TacticalAction.ISOLATE_NODE,
                target=node.id, reasoning="Isolating compromised node.",
            ))
            break

    # RCA
    actions.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.CORRELATE_FAILURE,
        target="", parameters={"technical_indicator": "attack", "organizational_flaw": "no_devsecops", "confidence": 0.7},
        reasoning="Correlating failure.",
    ))
    actions.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
        target="", reasoning="Finding silos.",
    ))

    # Refactor
    actions.append(ImmunoAction(
        action_type=ActionType.STRATEGIC, strategic_action=StrategicAction.ESTABLISH_DEVSECOPS,
        target="dept-security", reasoning="Establishing DevSecOps.",
    ))

    # Validation
    actions.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.MEASURE_ORG_LATENCY,
        target="", reasoning="Validating improvements.",
    ))

    total_reward = 0.0
    for action in actions:
        obs, reward, done = env.step(action)
        total_reward += reward
        if done:
            break

    # Without at least one assertion this test could never fail on anything
    # but an exception. The observation's step counter starts at 0 on reset,
    # so it must be positive once any action has been stepped.
    assert obs.step_count > 0, "Episode did not advance past step 0"

    print(f"  ✅ Full episode: {env.state.step_count} steps, reward={total_reward:.3f}, "
          f"phase={obs.current_phase.value}, terminated={done}")


def test_curriculum_advancement():
    """Check that the curriculum moves from level 1 to level 2 after successes.

    Three identical episode results (full threat containment, downtime 5.0,
    total reward 1.0) are recorded on a ``CurriculumEngine`` started at
    level 1; the test then expects ``current_level`` to be 2.
    """
    ce = CurriculumEngine(start_level=1)
    # Simulate 3 successes; the per-call return value is not inspected.
    for _ in range(3):
        ce.record_episode_result({
            "threats_contained_ratio": 1.0, "total_downtime": 5.0, "total_reward": 1.0
        })
    assert ce.current_level == 2, f"Should advance to level 2, got {ce.current_level}"
    print(f"  ✅ Curriculum advanced to level {ce.current_level}")


def test_belief_map_accuracy():
    """Check that a matching correlation yields a positive belief accuracy.

    The ground truth is a single SQL-injection attack on ``db-01``. The agent
    then records a correlation whose indicator names the same vector and
    target, and the test asserts ``calculate_belief_accuracy()`` is above 0.
    """
    bm = BeliefMap()
    bm.set_ground_truth([{"vector": "sql_injection", "target": "db-01"}])

    # Agent guesses correctly
    bm.agent_correlate("sql_injection_on_db-01", "no_devsecops", 0.8, ["evidence"], 1.0)
    accuracy = bm.calculate_belief_accuracy()
    assert accuracy > 0, f"Accuracy should be > 0, got {accuracy}"
    # Status glyph restored from its mis-encoded (mojibake) form.
    print(f"  ✅ Belief accuracy: {accuracy:.2f}")


def test_reward_calculator():
    """Check that ``RewardCalculator.compute_step_reward`` returns a float.

    A freshly reset difficulty-1 environment supplies the state object. The
    reward is computed for a successful SCAN_LOGS action with an unchanged
    threat count (1 before, 1 after) and zero belief accuracy, org chaos,
    and downtime delta.
    """
    rc = RewardCalculator()
    # Named `env` because this is the environment; the state passed to the
    # calculator is its `.state` attribute.
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    env.reset()
    action = ImmunoAction(
        action_type=ActionType.TACTICAL, tactical_action=TacticalAction.SCAN_LOGS,
        target="test", reasoning="Testing the reward function with a scan action.",
    )
    reward = rc.compute_step_reward(
        state=env.state, action=action, action_success=True,
        threats_before=1, threats_after=1, belief_accuracy=0.0,
        org_chaos=0.0, downtime_delta=0.0,
    )
    assert isinstance(reward, float), f"Expected float reward, got {type(reward).__name__}"
    print(f"  ✅ Reward computed: {reward:.3f}")


def run_all_tests(tests=None):
    """Run the test suite, print a per-test report, and summarize the outcome.

    Args:
        tests: Optional iterable of ``(display_name, callable)`` pairs to run.
            When omitted, every test function defined in this module is run
            in its original order.

    Returns:
        ``True`` when no test raised an exception, ``False`` otherwise.
    """
    # Local import keeps this function self-contained.
    import traceback

    if tests is None:
        tests = [
            ("Network Graph Generation", test_network_graph_generation),
            ("Org Graph Generation", test_org_graph_generation),
            ("Environment Reset", test_environment_reset),
            ("Tactical Actions", test_environment_step_tactical),
            ("Diagnostic Actions", test_environment_step_diagnostic),
            ("Full Episode Level 1", test_full_episode_level1),
            ("Curriculum Advancement", test_curriculum_advancement),
            ("Belief Map Accuracy", test_belief_map_accuracy),
            ("Reward Calculator", test_reward_calculator),
        ]
    else:
        # Materialize so any iterable works and len() is available below.
        tests = list(tests)

    print("=" * 60)
    print("ImmunoOrg Test Suite")
    print("=" * 60)

    passed = 0
    failed = 0
    for name, test_fn in tests:
        print(f"\n🧪 {name}")
        try:
            test_fn()
        except Exception as e:
            # A bare `assert` has an empty str(e); include the exception type
            # and the traceback so every failure can be located.
            print(f"  ❌ FAILED: {type(e).__name__}: {e}")
            traceback.print_exc()
            failed += 1
        else:
            passed += 1

    print(f"\n{'=' * 60}")
    print(f"Results: {passed} passed, {failed} failed out of {len(tests)}")
    print("=" * 60)
    return failed == 0


if __name__ == "__main__":
    # Script entry point: exit with status 0 when every test passed, 1 otherwise.
    exit_status = 0 if run_all_tests() else 1
    sys.exit(exit_status)