import json
import os
from typing import Tuple, Dict, Any

from models import Observation, Action, Reward, LogEntry, ActionType


class SOCEnvironment:
    """Simulated Security Operations Center environment driven by mock log data.

    The environment loads scenario logs from ``data/mock_bots.json`` (located
    next to this module), exposes them as alerts through ``state()``, and
    applies agent actions through ``step()``.
    """

    # Timestamp reported before any scenario is loaded, and after a reset
    # onto a scenario that contains no logs.
    DEFAULT_START_TIME = "2026-03-31T00:00:00Z"

    # Hard cap on steps per episode, to prevent infinite loops.
    MAX_STEPS = 10

    # Isolating this host takes the backup server offline.
    CRITICAL_BACKUP_HOST = "10.0.0.250"

    # Maps the task ids (from openenv.yaml) to the keys of the mock JSON data.
    TASK_SCENARIOS = {
        "task_1_triage": "scenario_1_triage",
        "task_2_false_positive": "scenario_2_false_positive",
        "task_3_kill_chain": "scenario_3_kill_chain",
    }

    # Scenario loaded when the requested task id is not in TASK_SCENARIOS.
    DEFAULT_SCENARIO = "scenario_1_triage"

    def __init__(self):
        """Initialize internal state and load the synthetic BOTS data.

        Raises:
            OSError: If ``data/mock_bots.json`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
        """
        # Internal state variables
        self.current_time = self.DEFAULT_START_TIME
        self.active_alerts = []
        self.system_status = self._fresh_system_status()

        # Track agent actions
        self.blocked_ips = set()
        self.isolated_hosts = set()
        self.current_task_id = None
        self.step_count = 0

        # Explicit encoding so the read does not depend on the platform
        # default locale.
        data_path = os.path.join(
            os.path.dirname(__file__), 'data', 'mock_bots.json'
        )
        with open(data_path, 'r', encoding='utf-8') as f:
            self.mock_data = json.load(f)

    @staticmethod
    def _fresh_system_status() -> Dict[str, bool]:
        """Return a new status dict with every tracked system online."""
        return {"database_online": True, "backup_server_online": True}

    def state(self) -> Observation:
        """Return the current environment state as an ``Observation``."""
        return Observation(
            current_time=self.current_time,
            active_alerts=self.active_alerts,
            system_status=self.system_status,
        )

    def reset(self, task_id: str = "task_1_triage") -> Observation:
        """Reset the environment for a task and load its logs.

        Args:
            task_id: Task identifier. Ids not present in ``TASK_SCENARIOS``
                load ``DEFAULT_SCENARIO``; the id is still recorded as given.

        Returns:
            The observation for the freshly reset state.
        """
        self.current_task_id = task_id
        self.step_count = 0
        self.blocked_ips.clear()
        self.isolated_hosts.clear()
        self.system_status = self._fresh_system_status()

        scenario_key = self.TASK_SCENARIOS.get(task_id, self.DEFAULT_SCENARIO)
        raw_logs = self.mock_data.get(scenario_key, [])

        # Convert raw JSON dicts into LogEntry models.
        self.active_alerts = [LogEntry(**log) for log in raw_logs]

        # Start the clock at the first log. With no logs, fall back to the
        # default so a previous episode's timestamp does not leak into this one.
        if self.active_alerts:
            self.current_time = self.active_alerts[0].timestamp
        else:
            self.current_time = self.DEFAULT_START_TIME

        return self.state()

    def step(self, action: Action) -> Tuple[Observation, Reward, bool, Dict[str, Any]]:
        """Execute the agent's action, update state, and return the result.

        Args:
            action: The action to apply; ``action_type`` and ``target_ip``
                are read from it.

        Returns:
            A tuple ``(observation, reward, done, info)`` where ``info``
            carries the keys ``"action_taken"``, ``"target"`` and ``"msg"``.
        """
        self.step_count += 1
        score_delta = 0.0
        done = False
        info = {
            "action_taken": action.action_type,
            "target": action.target_ip,
            "msg": "",
        }

        # 1. Process the action
        if action.action_type == ActionType.BLOCK_IP:
            if action.target_ip:
                self.blocked_ips.add(action.target_ip)
                info["msg"] = f"Successfully blocked IP: {action.target_ip}"
            else:
                score_delta = -0.1  # Penalty for missing target
                info["msg"] = "Failed: Target IP required for BLOCK_IP."

        elif action.action_type == ActionType.ISOLATE_HOST:
            if action.target_ip:
                self.isolated_hosts.add(action.target_ip)
                info["msg"] = f"Isolated Host: {action.target_ip}"
                # Isolating the core backup server takes it offline.
                if action.target_ip == self.CRITICAL_BACKUP_HOST:
                    self.system_status["backup_server_online"] = False
                    # Immediate penalty for breaking a critical system
                    score_delta = -0.5
            else:
                score_delta = -0.1  # Penalty for missing target
                # Report the failure the same way BLOCK_IP does instead of
                # returning a penalty with an empty message.
                info["msg"] = "Failed: Target IP required for ISOLATE_HOST."

        elif action.action_type == ActionType.DISMISS_ALERT:
            info["msg"] = "Alerts dismissed."

        elif action.action_type == ActionType.INVESTIGATE:
            info["msg"] = f"Pulled deeper telemetry for {action.target_ip}."

        elif action.action_type == ActionType.ESCALATE_TO_HUMAN:
            info["msg"] = "Escalated to Tier 2 Human Analyst."
            done = True  # Escalation immediately ends the agent's turn

        # 2. Check episode termination conditions
        if self.step_count >= self.MAX_STEPS:
            done = True
            # NOTE: this replaces the action message set above.
            info["msg"] = "Max steps reached. Operation timed out."

        # Task-specific early termination, keyed on the action type only.
        # NOTE(review): a BLOCK_IP without a target still ends task_1_triage;
        # confirm that is intended.
        if (
            self.current_task_id == "task_1_triage"
            and action.action_type == ActionType.BLOCK_IP
        ):
            done = True
        elif (
            self.current_task_id == "task_2_false_positive"
            and action.action_type
            in (ActionType.DISMISS_ALERT, ActionType.ISOLATE_HOST)
        ):
            done = True

        return self.state(), Reward(score_delta=score_delta), done, info