Spaces:
Paused
Paused
| """Tests for the ImmunoOrg environment.""" | |
| import sys | |
| import os | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from immunoorg.environment import ImmunoOrgEnvironment | |
| from immunoorg.models import ( | |
| ImmunoAction, ActionType, TacticalAction, DiagnosticAction, | |
| StrategicAction, IncidentPhase, | |
| ) | |
| from immunoorg.network_graph import NetworkGraph | |
| from immunoorg.org_graph import OrgGraph | |
| from immunoorg.reward import RewardCalculator | |
| from immunoorg.belief_map import BeliefMap | |
| from immunoorg.curriculum import CurriculumEngine | |
def test_network_graph_generation():
    """Verify NetworkGraph builds a non-empty, fully healthy topology at every difficulty."""
    for level in (1, 2, 3, 4):
        graph = NetworkGraph(difficulty=level, seed=42)
        graph.generate_topology()
        all_nodes = graph.get_all_nodes()
        all_edges = graph.get_all_edges()
        assert len(all_nodes) > 0, f"No nodes at difficulty {level}"
        assert len(all_edges) > 0, f"No edges at difficulty {level}"
        # Freshly generated nodes must start undamaged.
        assert all(node.health == 1.0 for node in all_nodes), "Initial health should be 1.0"
        print(f" ✅ Difficulty {level}: {len(all_nodes)} nodes, {len(all_edges)} edges")
def test_org_graph_generation():
    """Verify OrgGraph produces departments, channels, and a valid efficiency score."""
    network = NetworkGraph(difficulty=2, seed=42)
    network.generate_topology()
    org = OrgGraph(difficulty=2, seed=42)
    # The org structure is derived from the network's node ids.
    org.generate_org_structure(list(network.nodes.keys()))
    departments = org.get_all_nodes()
    channels = org.get_all_edges()
    assert len(departments) > 0
    assert len(channels) > 0
    efficiency = org.calculate_org_efficiency()
    # Efficiency is a normalized score.
    assert 0.0 <= efficiency <= 1.0
    print(f" ✅ {len(departments)} departments, {len(channels)} channels, efficiency={efficiency:.2f}")
def test_environment_reset():
    """Reset should yield a detection-phase observation with visible nodes and an active threat."""
    environment = ImmunoOrgEnvironment(difficulty=1, seed=42)
    first_obs = environment.reset()
    assert first_obs.current_phase == IncidentPhase.DETECTION
    assert first_obs.step_count == 0
    assert len(first_obs.visible_nodes) > 0
    assert len(first_obs.org_nodes) > 0
    # Reset launches an attack, so the threat level must already be non-zero.
    assert first_obs.threat_level > 0
    print(f" ✅ Reset OK: {len(first_obs.visible_nodes)} nodes, threat={first_obs.threat_level:.2f}")
def test_environment_step_tactical():
    """A SCAN_LOGS tactical action should advance the step counter and report a scan result."""
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    initial = env.reset()
    scan = ImmunoAction(
        action_type=ActionType.TACTICAL,
        tactical_action=TacticalAction.SCAN_LOGS,
        target=initial.visible_nodes[0].id,
        reasoning="Scanning for attack indicators.",
    )
    next_obs, reward, done = env.step(scan)
    assert next_obs.step_count == 1
    # The result text should mention the scan in some form.
    result_text = next_obs.action_result
    assert "Scanned" in result_text or "scan" in result_text.lower()
    print(f" ✅ Scan logs: reward={reward:.3f}, result={result_text[:60]}")
def test_environment_step_diagnostic():
    """An IDENTIFY_SILO diagnostic action should execute successfully."""
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    env.reset()
    probe = ImmunoAction(
        action_type=ActionType.DIAGNOSTIC,
        diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
        target="",
        reasoning="Looking for organizational silos.",
    )
    outcome, reward, done = env.step(probe)
    assert outcome.action_success
    print(f" ✅ Identify silo: {outcome.action_result[:60]}")
def test_full_episode_level1():
    """Run a scripted end-to-end episode at difficulty 1 covering every action type."""
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    obs = env.reset()
    # Phase 1: detection — scan and trace.
    plan = [
        ImmunoAction(action_type=ActionType.TACTICAL, tactical_action=TacticalAction.SCAN_LOGS,
                     target=obs.visible_nodes[0].id, reasoning="Scanning logs."),
        ImmunoAction(action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.TRACE_ATTACK_PATH,
                     target="", reasoning="Tracing attack."),
    ]
    # Phase 2: containment — isolate the first compromised node, if one is visible.
    compromised = next((node for node in obs.visible_nodes if node.compromised), None)
    if compromised is not None:
        plan.append(ImmunoAction(
            action_type=ActionType.TACTICAL, tactical_action=TacticalAction.ISOLATE_NODE,
            target=compromised.id, reasoning="Isolating compromised node.",
        ))
    # Phase 3: root-cause analysis.
    plan.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.CORRELATE_FAILURE,
        target="", parameters={"technical_indicator": "attack", "organizational_flaw": "no_devsecops", "confidence": 0.7},
        reasoning="Correlating failure.",
    ))
    plan.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.IDENTIFY_SILO,
        target="", reasoning="Finding silos.",
    ))
    # Phase 4: organizational refactor.
    plan.append(ImmunoAction(
        action_type=ActionType.STRATEGIC, strategic_action=StrategicAction.ESTABLISH_DEVSECOPS,
        target="dept-security", reasoning="Establishing DevSecOps.",
    ))
    # Phase 5: validation.
    plan.append(ImmunoAction(
        action_type=ActionType.DIAGNOSTIC, diagnostic_action=DiagnosticAction.MEASURE_ORG_LATENCY,
        target="", reasoning="Validating improvements.",
    ))
    total_reward = 0.0
    done = False
    for step_action in plan:
        obs, reward, done = env.step(step_action)
        total_reward += reward
        if done:
            break
    print(f" ✅ Full episode: {env.state.step_count} steps, reward={total_reward:.3f}, "
          f"phase={obs.current_phase.value}, terminated={done}")
def test_curriculum_advancement():
    """Three consecutive successful episodes should advance the curriculum to level 2.

    Fix: the original bound each ``record_episode_result`` return value to an
    unused local ``result``; the binding is dropped since only the engine's
    internal level state matters here.
    """
    ce = CurriculumEngine(start_level=1)
    # Feed three clean successes; the engine should promote after the streak.
    for _ in range(3):
        ce.record_episode_result({
            "threats_contained_ratio": 1.0,
            "total_downtime": 5.0,
            "total_reward": 1.0,
        })
    assert ce.current_level == 2, f"Should advance to level 2, got {ce.current_level}"
    print(f" ✅ Curriculum advanced to level {ce.current_level}")
def test_belief_map_accuracy():
    """A correlation matching the planted ground truth should score positive accuracy."""
    belief = BeliefMap()
    belief.set_ground_truth([{"vector": "sql_injection", "target": "db-01"}])
    # The agent's hypothesis lines up with the ground-truth attack exactly.
    belief.agent_correlate("sql_injection_on_db-01", "no_devsecops", 0.8, ["evidence"], 1.0)
    score = belief.calculate_belief_accuracy()
    assert score > 0, f"Accuracy should be > 0, got {score}"
    print(f" ✅ Belief accuracy: {score:.2f}")
def test_reward_calculator():
    """compute_step_reward should return a float for a routine successful scan."""
    calculator = RewardCalculator()
    env = ImmunoOrgEnvironment(difficulty=1, seed=42)
    env.reset()
    probe = ImmunoAction(
        action_type=ActionType.TACTICAL,
        tactical_action=TacticalAction.SCAN_LOGS,
        target="test",
        reasoning="Testing the reward function with a scan action.",
    )
    # Neutral inputs: threat count unchanged, no belief signal, no chaos or downtime.
    value = calculator.compute_step_reward(
        state=env.state,
        action=probe,
        action_success=True,
        threats_before=1,
        threats_after=1,
        belief_accuracy=0.0,
        org_chaos=0.0,
        downtime_delta=0.0,
    )
    assert isinstance(value, float)
    print(f" ✅ Reward computed: {value:.3f}")
def run_all_tests():
    """Execute every test, print a pass/fail summary, and return True when all pass."""
    banner = "=" * 60
    print(banner)
    print("ImmunoOrg Test Suite")
    print(banner)
    suite = [
        ("Network Graph Generation", test_network_graph_generation),
        ("Org Graph Generation", test_org_graph_generation),
        ("Environment Reset", test_environment_reset),
        ("Tactical Actions", test_environment_step_tactical),
        ("Diagnostic Actions", test_environment_step_diagnostic),
        ("Full Episode Level 1", test_full_episode_level1),
        ("Curriculum Advancement", test_curriculum_advancement),
        ("Belief Map Accuracy", test_belief_map_accuracy),
        ("Reward Calculator", test_reward_calculator),
    ]
    passed = 0
    failed = 0
    for label, test_fn in suite:
        print(f"\n🧪 {label}")
        # Broad catch is deliberate: one failing test must not abort the suite.
        try:
            test_fn()
        except Exception as exc:
            print(f" ❌ FAILED: {exc}")
            failed += 1
        else:
            passed += 1
    print(f"\n{banner}")
    print(f"Results: {passed} passed, {failed} failed out of {len(suite)}")
    print(banner)
    return failed == 0
if __name__ == "__main__":
    # Exit non-zero when any test fails so CI can detect the failure.
    sys.exit(0 if run_all_tests() else 1)