# Extracted from a hosted "Spaces" file view (status: Paused).
# File size: 5,988 bytes; revision e7b864e.
"""
ImmunoOrg Inference Script
==========================
OpenEnv-compatible inference with [START], [STEP], [END] logging.
Demonstrates agent interaction with the environment.
"""
from __future__ import annotations
import json
import os
import sys
import time
from typing import Any
import requests
# Base URL of the environment server that the /reset and /step requests go to.
# Override with the IMMUNOORG_URL environment variable. Trailing slashes are
# stripped so that f"{BASE_URL}/reset" never turns into ".../​/reset" when the
# configured value ends with "/".
BASE_URL = os.environ.get("IMMUNOORG_URL", "http://localhost:7860").rstrip("/")
def log_start(task_id: str, config: dict[str, Any]) -> None:
    """Print the ``[START]`` marker line for an episode.

    The line carries the task identifier and the episode configuration,
    the latter serialized as JSON.
    """
    config_json = json.dumps(config)
    print(f"[START] task={task_id} config={config_json}")
def log_step(step: int, action: dict, observation_summary: str, reward: float) -> None:
    """Print the ``[STEP]`` line describing one agent/environment exchange.

    The action is serialized as JSON and the reward is rendered with four
    decimal places; the observation summary is written verbatim.
    """
    action_json = json.dumps(action)
    line = f"[STEP] step={step} action={action_json} obs={observation_summary} reward={reward:.4f}"
    print(line)
def log_end(task_id: str, total_reward: float, metrics: dict[str, Any]) -> None:
    """Print the ``[END]`` marker line that closes an episode.

    The cumulative reward is rendered with four decimal places and the
    metrics mapping is serialized as JSON.
    """
    metrics_json = json.dumps(metrics)
    line = f"[END] task={task_id} total_reward={total_reward:.4f} metrics={metrics_json}"
    print(line)
def run_episode(
    task_id: str = "level1_single_attack",
    difficulty: int = 1,
    *,
    max_steps: int = 200,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Run a single episode with a rule-based baseline agent.

    Resets the environment over HTTP, then repeatedly asks ``decide_action``
    for an action and posts it to ``/step`` until the server reports the
    episode as finished or ``max_steps`` steps have been taken. A
    ``[START]`` line, one ``[STEP]`` line per step and an ``[END]`` line are
    printed along the way.

    Args:
        task_id: Task name sent to ``/reset``.
        difficulty: Difficulty level sent to ``/reset``.
        max_steps: Upper bound on the number of ``/step`` calls.
        timeout: Per-request timeout in seconds, passed to ``requests``.

    Returns:
        A metrics dict with the keys ``steps``, ``total_reward``,
        ``final_phase``, ``threat_level`` and ``downtime``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If a request exceeds ``timeout``.
        KeyError: If a response lacks ``observation`` (or ``reward`` for
            step responses).
    """
    config = {"task": task_id, "difficulty": difficulty}
    log_start(task_id, config)

    # Reset. Without an explicit timeout, requests would wait indefinitely on
    # an unresponsive server.
    resp = requests.post(f"{BASE_URL}/reset", json=config, timeout=timeout)
    resp.raise_for_status()
    obs = resp.json()["observation"]

    total_reward = 0.0
    step = 0
    done = False
    while not done and step < max_steps:
        # Simple rule-based agent
        action = decide_action(obs, step)

        # Step
        resp = requests.post(
            f"{BASE_URL}/step", json={"action": action}, timeout=timeout
        )
        resp.raise_for_status()
        result = resp.json()

        obs = result["observation"]
        reward = result["reward"]
        if "done" in result:
            done = bool(result["done"])
        else:
            # NOTE(review): assumes a terminated/truncated style response when
            # "done" is absent - confirm against the server's step schema.
            done = bool(
                result.get("terminated", False) or result.get("truncated", False)
            )

        total_reward += reward
        step += 1

        obs_summary = (
            f"phase={obs.get('current_phase', '?')} "
            f"threat={obs.get('threat_level', 0):.2f}"
        )
        log_step(step, action, obs_summary, reward)

    metrics = {
        "steps": step,
        "total_reward": total_reward,
        "final_phase": obs.get("current_phase", "unknown"),
        "threat_level": obs.get("threat_level", 0),
        "downtime": obs.get("system_downtime", 0),
    }
    log_end(task_id, total_reward, metrics)
    return metrics
def decide_action(obs: dict, step: int) -> dict[str, Any]:
    """Rule-based baseline agent for demonstration.

    Chooses an action from the observation's ``current_phase``:

    * ``detection``: scan logs on the first compromised node, otherwise on
      the first visible node.
    * ``containment``: isolate the first compromised node that is not yet
      isolated.
    * ``rca``: identify silos on every third step; otherwise correlate the
      first detected attack with an organizational flaw.
    * ``refactor``: establish DevSecOps between security and engineering.
    * ``validation``: measure organizational latency.

    When no rule applies (unknown phase, or nothing to act on) the agent
    escalates the alert.

    Args:
        obs: Observation dict; the keys read are ``current_phase``,
            ``detected_attacks`` and ``visible_nodes``. A missing
            ``current_phase`` is treated as ``"detection"``.
        step: Zero-based step counter, used only in the ``rca`` phase.

    Returns:
        The action payload as a dict.
    """
    phase = obs.get("current_phase", "detection")
    # ``or []`` covers a missing key as well as an explicit JSON null, which
    # ``.get(key, [])`` would pass through as None and crash the iteration.
    attacks = obs.get("detected_attacks") or []
    nodes = obs.get("visible_nodes") or []

    if phase == "detection":
        # Prefer a node already flagged as compromised.
        flagged = next((n for n in nodes if n.get("compromised")), None)
        if flagged is not None:
            return {
                "action_type": "tactical",
                "tactical_action": "scan_logs",
                "target": flagged["id"],
                "reasoning": f"Scanning logs on compromised node {flagged['id']} for attack indicators.",
            }
        # Otherwise scan the first visible node.
        if nodes:
            return {
                "action_type": "tactical",
                "tactical_action": "scan_logs",
                "target": nodes[0]["id"],
                "reasoning": "Scanning logs for initial threat detection.",
            }

    elif phase == "containment":
        # Isolate compromised nodes that are not isolated yet.
        exposed = next(
            (n for n in nodes if n.get("compromised") and not n.get("isolated")),
            None,
        )
        if exposed is not None:
            return {
                "action_type": "tactical",
                "tactical_action": "isolate_node",
                "target": exposed["id"],
                "reasoning": f"Isolating compromised node {exposed['id']} to prevent lateral movement.",
            }

    elif phase == "rca":
        # Try to identify silos on every third step.
        if step % 3 == 0:
            return {
                "action_type": "diagnostic",
                "diagnostic_action": "identify_silo",
                "target": "",
                "reasoning": "Identifying organizational silos that may have delayed incident response.",
            }
        # Correlate failures using the first detected attack.
        if attacks:
            first_attack = attacks[0]
            vector = first_attack.get("vector", "unknown")
            return {
                "action_type": "diagnostic",
                "diagnostic_action": "correlate_failure",
                "target": first_attack.get("target_node", ""),
                "parameters": {
                    "technical_indicator": f"{vector}_attack",
                    "organizational_flaw": "no_devsecops",
                    "confidence": 0.7,
                },
                "reasoning": f"Correlating {vector} attack to potential DevSecOps gap.",
            }

    elif phase == "refactor":
        # Create shortcut between security and engineering.
        return {
            "action_type": "strategic",
            "strategic_action": "establish_devsecops",
            "target": "dept-security",
            "secondary_target": "dept-engineering",
            "reasoning": "Establishing DevSecOps integration to prevent future code-level vulnerabilities.",
        }

    elif phase == "validation":
        return {
            "action_type": "diagnostic",
            "diagnostic_action": "measure_org_latency",
            "target": "",
            "reasoning": "Validating that organizational changes improved response efficiency.",
        }

    # Default: escalate
    return {
        "action_type": "tactical",
        "tactical_action": "escalate_alert",
        "target": "",
        "reasoning": "Escalating alert to increase threat awareness.",
    }
if __name__ == "__main__":
    # Command line: [task_id] [difficulty]; both are optional.
    cli_args = sys.argv[1:]
    task = cli_args[0] if cli_args else "level1_single_attack"
    difficulty = int(cli_args[1]) if len(cli_args) > 1 else 1
    run_episode(task, difficulty)
# (stray table delimiter from the page extraction removed)