""" Network Graph Engine ==================== Simulates the technical infrastructure layer with servers, APIs, ports, cascading failures, and attack propagation. """ from __future__ import annotations import random from typing import Any import networkx as nx from immunoorg.models import ( Attack, AttackVector, LogEntry, LogSeverity, NetworkEdge, NetworkNode, NodeType, PortState, PortStatus, ) class NetworkGraph: """Manages the technical network topology and simulates infrastructure behavior.""" def __init__(self, difficulty: int = 1, seed: int | None = None): self.difficulty = difficulty self.rng = random.Random(seed) self.graph = nx.DiGraph() self.nodes: dict[str, NetworkNode] = {} self.edges: list[NetworkEdge] = [] self.sim_time: float = 0.0 def generate_topology(self) -> None: """Generate a realistic enterprise network topology based on difficulty.""" tier_configs = { 1: {"web": 2, "app": 2, "data": 1, "management": 1, "dmz": 1}, 2: {"web": 3, "app": 4, "data": 2, "management": 2, "dmz": 1}, 3: {"web": 4, "app": 6, "data": 3, "management": 3, "dmz": 2}, 4: {"web": 5, "app": 8, "data": 4, "management": 4, "dmz": 2}, } config = tier_configs.get(self.difficulty, tier_configs[1]) tier_type_map = { "web": [NodeType.SERVER, NodeType.LOAD_BALANCER], "app": [NodeType.SERVER, NodeType.API], "data": [NodeType.DATABASE, NodeType.SERVER], "management": [NodeType.SERVER, NodeType.ENDPOINT], "dmz": [NodeType.FIREWALL, NodeType.SERVER], } service_map = { NodeType.SERVER: ["nginx", "apache", "node"], NodeType.API: ["rest-api", "graphql", "grpc"], NodeType.DATABASE: ["mysql", "postgres", "redis", "mongodb"], NodeType.FIREWALL: ["iptables", "pfsense"], NodeType.LOAD_BALANCER: ["haproxy", "nginx-lb"], NodeType.ENDPOINT: ["workstation", "admin-console"], } port_map = { "nginx": [80, 443], "apache": [80, 443, 8080], "node": [3000, 8080], "rest-api": [8080, 8443], "graphql": [4000], "grpc": [50051], "mysql": [3306], "postgres": [5432], "redis": [6379], "mongodb": [27017], "iptables": [22], "pfsense": [443, 8443], "haproxy": [80, 443, 8404], "nginx-lb": [80, 443], "workstation": [3389, 22], "admin-console": [443, 8443], } node_counter = 0 tier_nodes: dict[str, list[str]] = {} for tier, count in config.items(): tier_nodes[tier] = [] types = tier_type_map[tier] for i in range(count): node_type = types[i % len(types)] service = self.rng.choice(service_map[node_type]) ports_for_service = port_map.get(service, [8080]) node_id = f"{tier}-{node_type.value}-{node_counter:02d}" node_counter += 1 ports = [ PortState( port_number=p, service=service, status=PortStatus.OPEN, vulnerability_score=self.rng.uniform(0.0, 0.4), ) for p in ports_for_service ] criticality = {"data": 0.9, "management": 0.7, "app": 0.6, "web": 0.5, "dmz": 0.8} node = NetworkNode( id=node_id, name=f"{service}-{tier}-{i}", type=node_type, tier=tier, ports=ports, health=1.0, services=[service], criticality=criticality.get(tier, 0.5), ) self.nodes[node_id] = node self.graph.add_node(node_id, tier=tier, type=node_type.value) tier_nodes[tier].append(node_id) # Create edges: dmz → web → app → data, management connects to all tier_order = ["dmz", "web", "app", "data"] for i in range(len(tier_order) - 1): src_tier = tier_order[i] dst_tier = tier_order[i + 1] for src in tier_nodes.get(src_tier, []): for dst in tier_nodes.get(dst_tier, []): if self.rng.random() < 0.6: edge = NetworkEdge( source=src, target=dst, bandwidth=self.rng.uniform(100, 10000), latency=self.rng.uniform(0.1, 5.0), encrypted=self.rng.random() > 0.2, ) self.edges.append(edge) self.graph.add_edge(src, dst, weight=edge.latency) # Management connects to a subset of all nodes for mgmt_node in tier_nodes.get("management", []): all_other = [n for n in self.nodes if n != mgmt_node] targets = self.rng.sample(all_other, min(len(all_other), 4 + self.difficulty)) for t in targets: edge = NetworkEdge( source=mgmt_node, target=t, bandwidth=1000, latency=1.0, encrypted=True, ) self.edges.append(edge) self.graph.add_edge(mgmt_node, t, weight=1.0) def get_node(self, node_id: str) -> NetworkNode | None: return self.nodes.get(node_id) def get_all_nodes(self) -> list[NetworkNode]: return list(self.nodes.values()) def get_all_node_ids(self) -> list[str]: """Convenience helper: return all node IDs. Some higher-level modules/tests operate on IDs rather than full node objects. """ return list(self.nodes.keys()) def get_all_edges(self) -> list[NetworkEdge]: return list(self.edges) def compromise_node(self, node_id: str, vector: AttackVector, sim_time: float) -> bool: """Compromise a node with a given attack vector.""" node = self.nodes.get(node_id) if not node or node.compromised or node.isolated: return False node.compromised = True node.compromised_at = sim_time node.attack_vector = vector node.health = max(0.0, node.health - self.rng.uniform(0.3, 0.7)) # Generate attack log node.logs.append(LogEntry( timestamp=sim_time, severity=LogSeverity.CRITICAL, source=node_id, message=f"Compromised via {vector.value}", attack_indicator=True, indicator_confidence=0.3 + self.rng.uniform(0, 0.5), )) return True def propagate_attack(self, source_id: str, attack: Attack, sim_time: float) -> list[str]: """Propagate an attack from a compromised node to neighbors (cascading failure).""" newly_compromised = [] neighbors = list(self.graph.successors(source_id)) self.rng.shuffle(neighbors) propagation_chance = {1: 0.1, 2: 0.25, 3: 0.4, 4: 0.6} chance = propagation_chance.get(self.difficulty, 0.2) for neighbor in neighbors: target_node = self.nodes.get(neighbor) if not target_node or target_node.compromised or target_node.isolated: continue if self.rng.random() < chance: if self.compromise_node(neighbor, AttackVector.LATERAL_MOVEMENT, sim_time): newly_compromised.append(neighbor) # Add to attack lateral path attack.lateral_path.append(neighbor) return newly_compromised def apply_damage_tick(self, sim_time: float) -> float: """Apply ongoing damage from compromised nodes. Returns total damage this tick.""" damage = 0.0 for node in self.nodes.values(): if node.compromised and not node.isolated: dmg = node.criticality * 0.05 node.health = max(0.0, node.health - dmg) damage += dmg # Generate warning logs with some noise if self.rng.random() < 0.3: node.logs.append(LogEntry( timestamp=sim_time, severity=LogSeverity.WARNING, source=node.id, message=f"Anomalous activity detected on {node.services[0] if node.services else 'unknown'}", attack_indicator=True, indicator_confidence=self.rng.uniform(0.1, 0.6), )) # Generate normal noise logs if self.rng.random() < 0.1: node.logs.append(LogEntry( timestamp=sim_time, severity=LogSeverity.INFO, source=node.id, message=self.rng.choice([ "Health check OK", "Routine maintenance log", "Connection pool refresh", "Cache cleared", "Backup checkpoint created", ]), )) return damage def isolate_node(self, node_id: str) -> bool: """Isolate a node from the network.""" node = self.nodes.get(node_id) if not node: return False node.isolated = True return True def block_port(self, node_id: str, port_number: int) -> bool: """Block a specific port on a node.""" node = self.nodes.get(node_id) if not node: return False for port in node.ports: if port.port_number == port_number: port.status = PortStatus.BLOCKED return True return False def deploy_patch(self, node_id: str) -> bool: """Patch a node, reducing vulnerability scores.""" node = self.nodes.get(node_id) if not node: return False node.patched = True for port in node.ports: port.vulnerability_score = max(0.0, port.vulnerability_score - 0.3) if node.compromised: node.compromised = False node.attack_vector = None node.health = min(1.0, node.health + 0.3) return True def restore_backup(self, node_id: str) -> bool: """Restore a node from backup.""" node = self.nodes.get(node_id) if not node: return False node.health = 1.0 node.compromised = False node.attack_vector = None node.isolated = False return True def rotate_credentials(self, node_id: str) -> bool: """Rotate credentials on a node.""" node = self.nodes.get(node_id) if not node: return False # Reduces effectiveness of credential-based attacks if node.attack_vector in (AttackVector.CREDENTIAL_STUFFING, AttackVector.PHISHING): node.compromised = False node.attack_vector = None node.health = min(1.0, node.health + 0.2) return True def scan_logs(self, node_id: str) -> list[LogEntry]: """Return logs for a node, including attack indicators.""" node = self.nodes.get(node_id) if not node: return [] return list(node.logs[-20:]) # Last 20 entries def get_network_health(self) -> dict[str, float]: """Get health summary by tier.""" tier_health: dict[str, list[float]] = {} for node in self.nodes.values(): if node.tier not in tier_health: tier_health[node.tier] = [] tier_health[node.tier].append(node.health) return { tier: sum(healths) / len(healths) if healths else 1.0 for tier, healths in tier_health.items() } def get_compromised_nodes(self) -> list[NetworkNode]: return [n for n in self.nodes.values() if n.compromised] def find_attack_path(self, source: str, target: str) -> list[str] | None: """Find shortest path between two nodes.""" try: return nx.shortest_path(self.graph, source, target) except (nx.NetworkXNoPath, nx.NodeNotFound): return None def get_vulnerable_nodes(self, threshold: float = 0.3) -> list[NetworkNode]: """Find nodes with high vulnerability scores.""" vulnerable = [] for node in self.nodes.values(): max_vuln = max((p.vulnerability_score for p in node.ports), default=0.0) if max_vuln >= threshold: vulnerable.append(node) return vulnerable