# immunoorg-v3 / immunoorg/network_graph.py
# Uploaded by hirann via huggingface_hub (commit 7489025, verified).
"""
Network Graph Engine
====================
Simulates the technical infrastructure layer with servers, APIs, ports,
cascading failures, and attack propagation.
"""
from __future__ import annotations
import random
from typing import Any
import networkx as nx
from immunoorg.models import (
Attack, AttackVector, LogEntry, LogSeverity, NetworkEdge,
NetworkNode, NodeType, PortState, PortStatus,
)
class NetworkGraph:
    """Manage the technical network topology and simulate infrastructure behavior.

    State is kept in two parallel views that the mutators update together:

    * ``self.nodes`` maps node ID -> ``NetworkNode`` and ``self.edges`` lists
      ``NetworkEdge`` objects; these carry the simulation data (health, ports,
      logs, compromise/isolation flags).
    * ``self.graph`` is a ``networkx.DiGraph`` over the same node IDs, used for
      neighbor lookups and path finding. Tier-to-tier edges are weighted by
      their latency.

    Every random draw goes through ``self.rng``, so a fixed ``seed`` replays the
    same topology and simulation for the same sequence of calls.
    """

    # Number of nodes per tier at each difficulty level; unknown levels use level 1.
    _TIER_CONFIGS = {
        1: {"web": 2, "app": 2, "data": 1, "management": 1, "dmz": 1},
        2: {"web": 3, "app": 4, "data": 2, "management": 2, "dmz": 1},
        3: {"web": 4, "app": 6, "data": 3, "management": 3, "dmz": 2},
        4: {"web": 5, "app": 8, "data": 4, "management": 4, "dmz": 2},
    }
    # Node types cycled through, in order, when populating each tier.
    _TIER_TYPES = {
        "web": [NodeType.SERVER, NodeType.LOAD_BALANCER],
        "app": [NodeType.SERVER, NodeType.API],
        "data": [NodeType.DATABASE, NodeType.SERVER],
        "management": [NodeType.SERVER, NodeType.ENDPOINT],
        "dmz": [NodeType.FIREWALL, NodeType.SERVER],
    }
    # Candidate services per node type; one is picked at random for each node.
    _SERVICES = {
        NodeType.SERVER: ["nginx", "apache", "node"],
        NodeType.API: ["rest-api", "graphql", "grpc"],
        NodeType.DATABASE: ["mysql", "postgres", "redis", "mongodb"],
        NodeType.FIREWALL: ["iptables", "pfsense"],
        NodeType.LOAD_BALANCER: ["haproxy", "nginx-lb"],
        NodeType.ENDPOINT: ["workstation", "admin-console"],
    }
    # Ports opened by each service; a service missing here gets [8080].
    _SERVICE_PORTS = {
        "nginx": [80, 443],
        "apache": [80, 443, 8080],
        "node": [3000, 8080],
        "rest-api": [8080, 8443],
        "graphql": [4000],
        "grpc": [50051],
        "mysql": [3306],
        "postgres": [5432],
        "redis": [6379],
        "mongodb": [27017],
        "iptables": [22],
        "pfsense": [443, 8443],
        "haproxy": [80, 443, 8404],
        "nginx-lb": [80, 443],
        "workstation": [3389, 22],
        "admin-console": [443, 8443],
    }
    # Criticality per tier; apply_damage_tick scales per-tick damage by it.
    _TIER_CRITICALITY = {"data": 0.9, "management": 0.7, "app": 0.6, "web": 0.5, "dmz": 0.8}
    # Tiers wired as a chain: each tier links only to the next one.
    _TIER_CHAIN = ("dmz", "web", "app", "data")
    # Probability that any given src -> dst link exists between adjacent tiers.
    _TIER_LINK_PROBABILITY = 0.6
    # Per-neighbor lateral-movement probability by difficulty; unknown levels use 0.2.
    _PROPAGATION_CHANCE = {1: 0.1, 2: 0.25, 3: 0.4, 4: 0.6}

    def __init__(self, difficulty: int = 1, seed: int | None = None):
        """Create an empty network; call ``generate_topology`` to populate it.

        Args:
            difficulty: Scales topology size, management reach and attack
                propagation. Levels 1-4 have dedicated tables; other values
                fall back to the level-1 layout and a 0.2 propagation chance.
            seed: Seed for the private random generator (``None`` lets
                ``random.Random`` pick a nondeterministic seed).
        """
        self.difficulty = difficulty
        self.rng = random.Random(seed)
        self.graph = nx.DiGraph()
        self.nodes: dict[str, NetworkNode] = {}
        self.edges: list[NetworkEdge] = []
        # NOTE(review): not read or written elsewhere in this class; presumably
        # maintained by a caller — confirm before relying on it.
        self.sim_time: float = 0.0

    def generate_topology(self) -> None:
        """Generate a realistic enterprise network topology based on difficulty.

        Any previously generated topology is discarded first, so calling this
        again replaces the network instead of accumulating duplicate entries in
        ``self.edges`` and stale edges in ``self.graph``.
        """
        self.graph.clear()
        self.nodes.clear()
        self.edges.clear()

        config = self._TIER_CONFIGS.get(self.difficulty, self._TIER_CONFIGS[1])
        tier_nodes = self._create_nodes(config)
        self._connect_tiers(tier_nodes)
        self._connect_management(tier_nodes)

    def _create_nodes(self, config: dict[str, int]) -> dict[str, list[str]]:
        """Create every node described by ``config``; return node IDs grouped by tier."""
        tier_nodes: dict[str, list[str]] = {}
        node_counter = 0  # global across tiers so IDs are unique
        for tier, count in config.items():
            types = self._TIER_TYPES[tier]
            tier_nodes[tier] = []
            for i in range(count):
                # Cycle through the tier's node types so each tier gets a mix.
                node_type = types[i % len(types)]
                service = self.rng.choice(self._SERVICES[node_type])
                node_id = f"{tier}-{node_type.value}-{node_counter:02d}"
                node_counter += 1
                ports = [
                    PortState(
                        port_number=p,
                        service=service,
                        status=PortStatus.OPEN,
                        vulnerability_score=self.rng.uniform(0.0, 0.4),
                    )
                    for p in self._SERVICE_PORTS.get(service, [8080])
                ]
                self.nodes[node_id] = NetworkNode(
                    id=node_id,
                    name=f"{service}-{tier}-{i}",
                    type=node_type,
                    tier=tier,
                    ports=ports,
                    health=1.0,
                    services=[service],
                    criticality=self._TIER_CRITICALITY.get(tier, 0.5),
                )
                self.graph.add_node(node_id, tier=tier, type=node_type.value)
                tier_nodes[tier].append(node_id)
        return tier_nodes

    def _connect_tiers(self, tier_nodes: dict[str, list[str]]) -> None:
        """Randomly link each tier to the next one along dmz -> web -> app -> data."""
        for src_tier, dst_tier in zip(self._TIER_CHAIN, self._TIER_CHAIN[1:]):
            for src in tier_nodes.get(src_tier, []):
                for dst in tier_nodes.get(dst_tier, []):
                    if self.rng.random() < self._TIER_LINK_PROBABILITY:
                        self._add_edge(
                            src, dst,
                            bandwidth=self.rng.uniform(100, 10000),
                            latency=self.rng.uniform(0.1, 5.0),
                            encrypted=self.rng.random() > 0.2,
                        )

    def _connect_management(self, tier_nodes: dict[str, list[str]]) -> None:
        """Link every management node to a random subset of all other nodes."""
        for mgmt_node in tier_nodes.get("management", []):
            others = [n for n in self.nodes if n != mgmt_node]
            # Higher difficulty widens management reach; clamp so sample() never gets k < 0.
            k = max(0, min(len(others), 4 + self.difficulty))
            for target in self.rng.sample(others, k):
                self._add_edge(
                    mgmt_node, target,
                    bandwidth=1000, latency=1.0, encrypted=True,
                )

    def _add_edge(
        self, source: str, target: str, *,
        bandwidth: float, latency: float, encrypted: bool,
    ) -> None:
        """Record a directed link in both ``self.edges`` and ``self.graph`` (weight = latency)."""
        edge = NetworkEdge(
            source=source, target=target,
            bandwidth=bandwidth, latency=latency, encrypted=encrypted,
        )
        self.edges.append(edge)
        self.graph.add_edge(source, target, weight=edge.latency)

    def get_node(self, node_id: str) -> NetworkNode | None:
        """Return the node with ``node_id``, or ``None`` if it does not exist."""
        return self.nodes.get(node_id)

    def get_all_nodes(self) -> list[NetworkNode]:
        """Return a new list of all nodes (in creation order)."""
        return list(self.nodes.values())

    def get_all_node_ids(self) -> list[str]:
        """Convenience helper: return all node IDs.

        Some higher-level modules/tests operate on IDs rather than full node objects.
        """
        return list(self.nodes.keys())

    def get_all_edges(self) -> list[NetworkEdge]:
        """Return a shallow copy of the edge list."""
        return list(self.edges)

    def compromise_node(self, node_id: str, vector: AttackVector, sim_time: float) -> bool:
        """Compromise a node with the given attack vector.

        Health drops by a random 0.3-0.7 (floored at 0) and a CRITICAL
        attack-indicator log entry is appended.

        Returns:
            ``True`` if the node was compromised; ``False`` (with no changes)
            if it is unknown, already compromised, or isolated.
        """
        node = self.nodes.get(node_id)
        if not node or node.compromised or node.isolated:
            return False
        node.compromised = True
        node.compromised_at = sim_time
        node.attack_vector = vector
        node.health = max(0.0, node.health - self.rng.uniform(0.3, 0.7))
        node.logs.append(LogEntry(
            timestamp=sim_time,
            severity=LogSeverity.CRITICAL,
            source=node_id,
            message=f"Compromised via {vector.value}",
            attack_indicator=True,
            indicator_confidence=0.3 + self.rng.uniform(0, 0.5),
        ))
        return True

    def propagate_attack(self, source_id: str, attack: Attack, sim_time: float) -> list[str]:
        """Spread an attack from ``source_id`` to its direct successors (cascading failure).

        Each successor that is known, not yet compromised and not isolated is
        compromised via lateral movement with a difficulty-dependent
        probability. Newly compromised IDs are appended to
        ``attack.lateral_path`` and returned.

        Nothing spreads from an unknown or isolated source. Isolation cuts the
        node off from the network, consistent with ``compromise_node`` and
        ``apply_damage_tick`` ignoring isolated nodes.
        """
        source = self.nodes.get(source_id)
        if source is None or source.isolated or source_id not in self.graph:
            return []

        neighbors = list(self.graph.successors(source_id))
        self.rng.shuffle(neighbors)
        chance = self._PROPAGATION_CHANCE.get(self.difficulty, 0.2)

        newly_compromised: list[str] = []
        for neighbor in neighbors:
            target = self.nodes.get(neighbor)
            # Skip ineligible targets before rolling so they don't consume a random draw.
            if target is None or target.compromised or target.isolated:
                continue
            if self.rng.random() < chance and self.compromise_node(
                neighbor, AttackVector.LATERAL_MOVEMENT, sim_time
            ):
                newly_compromised.append(neighbor)
                attack.lateral_path.append(neighbor)
        return newly_compromised

    def apply_damage_tick(self, sim_time: float) -> float:
        """Apply one tick of ongoing damage and background logging.

        Each compromised, non-isolated node loses ``criticality * 0.05`` health
        (floored at 0) and, with probability 0.3, logs a WARNING attack
        indicator. Every node, compromised or not, logs a benign INFO message
        with probability 0.1.

        Returns:
            The summed nominal damage of this tick. This counts the full
            ``criticality * 0.05`` per node even when health was clamped at 0.
        """
        damage = 0.0
        for node in self.nodes.values():
            if node.compromised and not node.isolated:
                dmg = node.criticality * 0.05
                node.health = max(0.0, node.health - dmg)
                damage += dmg
                if self.rng.random() < 0.3:
                    node.logs.append(LogEntry(
                        timestamp=sim_time,
                        severity=LogSeverity.WARNING,
                        source=node.id,
                        message=f"Anomalous activity detected on {node.services[0] if node.services else 'unknown'}",
                        attack_indicator=True,
                        indicator_confidence=self.rng.uniform(0.1, 0.6),
                    ))
            # Benign noise so attack indicators are not trivially separable.
            if self.rng.random() < 0.1:
                node.logs.append(LogEntry(
                    timestamp=sim_time,
                    severity=LogSeverity.INFO,
                    source=node.id,
                    message=self.rng.choice([
                        "Health check OK", "Routine maintenance log",
                        "Connection pool refresh", "Cache cleared",
                        "Backup checkpoint created",
                    ]),
                ))
        return damage

    def isolate_node(self, node_id: str) -> bool:
        """Isolate a node from the network; return ``False`` if it is unknown."""
        node = self.nodes.get(node_id)
        if not node:
            return False
        node.isolated = True
        return True

    def block_port(self, node_id: str, port_number: int) -> bool:
        """Mark ``port_number`` on the node as BLOCKED.

        Returns ``False`` if the node is unknown or has no such port.
        """
        node = self.nodes.get(node_id)
        if not node:
            return False
        for port in node.ports:
            if port.port_number == port_number:
                port.status = PortStatus.BLOCKED
                return True
        return False

    def deploy_patch(self, node_id: str) -> bool:
        """Patch a node, lowering every port's vulnerability score by 0.3 (floored at 0).

        A compromised node is also cleaned: the compromise flag and attack
        vector are cleared and health recovers by 0.3 (capped at 1.0).
        ``compromised_at`` is left unchanged.
        Returns ``False`` if the node is unknown.
        """
        node = self.nodes.get(node_id)
        if not node:
            return False
        node.patched = True
        for port in node.ports:
            port.vulnerability_score = max(0.0, port.vulnerability_score - 0.3)
        if node.compromised:
            node.compromised = False
            node.attack_vector = None
            node.health = min(1.0, node.health + 0.3)
        return True

    def restore_backup(self, node_id: str) -> bool:
        """Restore a node from backup.

        Resets health to 1.0, clears compromise and attack vector, and lifts
        isolation. Ports and patch state are not touched.
        Returns ``False`` if the node is unknown.
        """
        node = self.nodes.get(node_id)
        if not node:
            return False
        node.health = 1.0
        node.compromised = False
        node.attack_vector = None
        node.isolated = False
        return True

    def rotate_credentials(self, node_id: str) -> bool:
        """Rotate credentials on a node.

        Only clears a compromise whose vector was credential stuffing or
        phishing (restoring 0.2 health, capped at 1.0). Returns ``True`` for
        any known node, even when nothing changed; ``False`` if unknown.
        """
        node = self.nodes.get(node_id)
        if not node:
            return False
        if node.attack_vector in (AttackVector.CREDENTIAL_STUFFING, AttackVector.PHISHING):
            node.compromised = False
            node.attack_vector = None
            node.health = min(1.0, node.health + 0.2)
        return True

    def scan_logs(self, node_id: str) -> list[LogEntry]:
        """Return a copy of the node's last 20 log entries (empty if the node is unknown)."""
        node = self.nodes.get(node_id)
        if not node:
            return []
        return list(node.logs[-20:])

    def get_network_health(self) -> dict[str, float]:
        """Return the mean node health for each tier present in the network."""
        tier_health: dict[str, list[float]] = {}
        for node in self.nodes.values():
            tier_health.setdefault(node.tier, []).append(node.health)
        return {
            tier: sum(healths) / len(healths) if healths else 1.0
            for tier, healths in tier_health.items()
        }

    def get_compromised_nodes(self) -> list[NetworkNode]:
        """Return all compromised nodes, including ones that are isolated."""
        return [n for n in self.nodes.values() if n.compromised]

    def find_attack_path(self, source: str, target: str) -> list[str] | None:
        """Return the fewest-hop directed path from ``source`` to ``target``.

        Edge weights, isolation and blocked ports are not considered.
        Returns ``None`` if either node is missing or no path exists.
        """
        try:
            return nx.shortest_path(self.graph, source, target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return None

    def get_vulnerable_nodes(self, threshold: float = 0.3) -> list[NetworkNode]:
        """Return nodes whose highest port vulnerability score is >= ``threshold``.

        Blocked ports still count toward the score.
        """
        return [
            node for node in self.nodes.values()
            if max((p.vulnerability_score for p in node.ports), default=0.0) >= threshold
        ]