"""
ThreatHunter 雙層記憶學習 Tool
==============================
Layer 1: JSON 持久化(穩定保底,Day 1 起可用)
Layer 2: LlamaIndex RAG(語義搜尋,ENABLE_MEMORY_RAG=true 啟用)
遵循文件:
- FINAL_PLAN.md §支柱 3(Feedback Loops:雙層記憶學習系統)
- leader_plan.md §記憶學習系統
"""
import json
import logging
from pathlib import Path
from datetime import datetime, timezone
from typing import Any
from config import MEMORY_DIR, ENABLE_MEMORY_RAG, SIMILARITY_THRESHOLD
from crewai.tools import tool
logger = logging.getLogger("threathunter.memory")
# Sandbox Layer 3: memory cache sanitization (optional dependency).
# If the sandbox package is absent, fall back to a no-op sanitizer so memory
# writes still work (graceful degradation); _MEM_SANITIZER_OK records which
# implementation is in effect.
try:
    from sandbox.memory_sanitizer import sanitize_memory_write as _sanitize_write
    _MEM_SANITIZER_OK = True
except ImportError:
    def _sanitize_write(data, agent_name=''):  # type: ignore[misc]
        # No-op fallback: report every payload as safe and unchanged.
        return True, data, 'ok'
    _MEM_SANITIZER_OK = False
# ── 常數 ─────────────────────────────────────────────────────
VALID_AGENT_NAMES = {"scout", "analyst", "advisor", "critic", "orchestrator"}
# ── Layer 1: JSON 持久化工具函式 ─────────────────────────────
def _get_memory_path(agent_name: str) -> Path:
    """Return the Layer-1 JSON memory file path for *agent_name*."""
    return MEMORY_DIR / (agent_name + "_memory.json")
def _load_json(path: Path) -> dict:
"""安全載入 JSON,檔案不存在或損壞時回傳空 dict"""
if not path.exists():
return {}
try:
content = path.read_text(encoding="utf-8")
if not content.strip():
return {}
return json.loads(content)
except (json.JSONDecodeError, OSError) as e:
logger.warning("[WARN] Memory file read failed %s: %s, returning empty", path, e)
return {}
def _save_json(path: Path, data: dict) -> None:
"""安全寫入 JSON(先寫臨時檔再 rename,防止寫入中斷導致損壞)"""
path.parent.mkdir(parents=True, exist_ok=True)
temp_path = path.with_suffix(".tmp")
try:
temp_path.write_text(
json.dumps(data, ensure_ascii=False, indent=2),
encoding="utf-8",
)
temp_path.replace(path)
except OSError as e:
logger.error("[FAIL] Memory file write failed %s: %s", path, e)
if temp_path.exists():
temp_path.unlink()
raise
# ── Layer 2: LlamaIndex RAG(條件性啟用)─────────────────────
_rag_index = None
_rag_query_engine = None
def _init_rag() -> None:
    """Lazily initialize the LlamaIndex RAG layer (runs real work only once).

    Sets the module globals ``_rag_index`` and ``_rag_query_engine``.
    Every failure path (missing packages, storage errors) is logged and
    swallowed so the JSON layer keeps working (graceful degradation).
    """
    global _rag_index, _rag_query_engine
    if not ENABLE_MEMORY_RAG:
        return
    if _rag_index is not None:
        return  # already initialized
    try:
        from llama_index.core import VectorStoreIndex, StorageContext, Settings
        from llama_index.core import load_index_from_storage
        # ── Configure a free local embedding model (no OpenAI API key needed) ──
        # HuggingFace BAAI/bge-small-en-v1.5: lightweight, fast, free.
        try:
            from llama_index.embeddings.huggingface import HuggingFaceEmbedding
            Settings.embed_model = HuggingFaceEmbedding(
                model_name="BAAI/bge-small-en-v1.5"
            )
            logger.info("[OK] Embedding: HuggingFace BAAI/bge-small-en-v1.5 (local free)")
        except ImportError:
            logger.warning("[WARN] HuggingFace embedding not installed, trying OpenAI embedding")
        # Disable the LLM: the RAG memory layer needs embeddings only, no generation.
        try:
            from llama_index.core.llms import MockLLM
            Settings.llm = MockLLM()
        except Exception:
            Settings.llm = None  # type: ignore
        vector_store_path = MEMORY_DIR / "vector_store"
        if (vector_store_path / "docstore.json").exists():
            # A persisted index already exists: load it from disk.
            storage_context = StorageContext.from_defaults(
                persist_dir=str(vector_store_path)
            )
            _rag_index = load_index_from_storage(storage_context)
            logger.info("[OK] LlamaIndex vector index loaded")
            _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
        else:
            # Fresh (empty) index: create it, then backfill from JSON history.
            _rag_index = VectorStoreIndex([])
            logger.info("[OK] LlamaIndex vector index created (empty)")
            # Set the query engine first, then backfill (_backfill needs _rag_index ready).
            _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
            _backfill_from_json_history()
            # Rebuild the query engine after backfill (the index now holds data).
            _rag_query_engine = _rag_index.as_query_engine(similarity_top_k=3)
    except ImportError:
        logger.warning("[WARN] LlamaIndex not installed, RAG disabled")
    except Exception as e:
        logger.warning("[WARN] LlamaIndex init failed: %s", e)
def _backfill_from_json_history() -> None:
    """
    One-time backfill: vectorize every existing *_memory.json into LlamaIndex.

    Runs only when the vector store was freshly created (not loaded from disk),
    so historical scan records keep their semantic searchability across a RAG
    cold start. Failures never block the JSON layer (graceful degradation).
    """
    if _rag_index is None:
        return
    inserted = 0
    for name in VALID_AGENT_NAMES:
        memory_path = _get_memory_path(name)
        if not memory_path.exists():
            continue
        payload = _load_json(memory_path)
        if not payload:
            continue
        # Backfill the latest scan record.
        _rag_insert(name, payload)
        inserted += 1
        # Backfill every archived scan in the history[] array.
        for archived in payload.get("history", []):
            if isinstance(archived, dict) and archived:
                _rag_insert(name, archived)
                inserted += 1
    if inserted > 0:
        logger.info("[OK] RAG history backfill done: %d scan records vectorized", inserted)
    else:
        logger.info("[INFO] RAG backfill: no historical JSON memory to backfill (first scan)")
def _extract_package_names(tech_stack: str) -> set[str]:
"""
從技術棧字串中提取套件名稱(小寫、去版本號)。
範例:
'Django 4.2, Redis 7.0' -> {'django', 'redis'}
'Spring Boot 3.1 和 Node.js 18' -> {'spring', 'boot', 'node.js'}
"""
if not tech_stack:
return set()
names = set()
for part in tech_stack.replace(",", " ").split():
clean = part.strip().lower()
# 跳過版本號(純數字或 x.y.z 格式)
if clean and not clean.replace(".", "").replace("-", "").isdigit():
names.add(clean)
return names
def _rag_insert(agent_name: str, data: dict) -> None:
    """Layer-2 half of the dual write: insert *data* into the LlamaIndex
    vector index with agent/timestamp/scan_id/tech_stack metadata, then
    persist to disk. Failures are logged, never raised (JSON layer is
    unaffected)."""
    if not ENABLE_MEMORY_RAG or _rag_index is None:
        return
    try:
        from llama_index.core import Document

        # tech_stack may live under either field name; a list is joined to a string.
        tech_stack = data.get("tech_stack", "") or data.get("tech_stack_input", "") or ""
        if isinstance(tech_stack, list):
            tech_stack = ", ".join(str(item) for item in tech_stack)

        document = Document(
            text=json.dumps(data, ensure_ascii=False),
            metadata={
                "agent": agent_name,
                "timestamp": data.get("timestamp", ""),
                "scan_id": data.get("scan_id", ""),
                "tech_stack": str(tech_stack),
            },
        )
        _rag_index.insert(document)
        persist_dir = MEMORY_DIR / "vector_store"
        _rag_index.storage_context.persist(persist_dir=str(persist_dir))
        logger.info("[OK] RAG index updated: %s (tech_stack=%s)", agent_name, tech_stack[:50])
    except Exception as exc:
        logger.warning("[WARN] RAG write failed (JSON layer unaffected): %s", exc)
def _rag_search(query: str, tech_stack: str | None = None) -> str:
    """
    Semantic search with safety valves and tech-stack relevance filtering.

    Args:
        query: search query text
        tech_stack: the current scan's tech stack, used to drop unrelated
            history (None disables the filter)

    Returns:
        The stringified query response, or an informative notice when RAG is
        disabled/unavailable, the index is empty, the best similarity score
        is below SIMILARITY_THRESHOLD, or no history matches the tech stack.
    """
    if not ENABLE_MEMORY_RAG:
        return "RAG disabled (ENABLE_MEMORY_RAG=false)"
    _init_rag()
    if _rag_index is None:
        return "RAG index unavailable"
    try:
        doc_count = len(_rag_index.docstore.docs) if hasattr(_rag_index, "docstore") else 0
        if doc_count == 0:
            return "No history available (vector index empty)"
        response = _rag_query_engine.query(query)
        # Relevance-threshold gate: refuse weak matches outright.
        if hasattr(response, "source_nodes") and response.source_nodes:
            scores = [n.score for n in response.source_nodes if n.score is not None]
            max_score = max(scores) if scores else 0
            if max_score < SIMILARITY_THRESHOLD:
                return (
                    f"No relevant history found"
                    f" (max_similarity {max_score:.2f} < threshold {SIMILARITY_THRESHOLD})"
                )
            # Tech-stack relevance filter: keep only history whose package
            # names intersect the current scan's packages.
            if tech_stack:
                current_packages = _extract_package_names(tech_stack)
                if current_packages:
                    filtered_nodes = []
                    for node in response.source_nodes:
                        node_tech = node.metadata.get("tech_stack", "")
                        if not node_tech:
                            # Old records lacking tech_stack metadata: keep conservatively.
                            filtered_nodes.append(node)
                            continue
                        node_packages = _extract_package_names(node_tech)
                        # Keep only when the package-name sets intersect.
                        if current_packages & node_packages:
                            filtered_nodes.append(node)
                        else:
                            logger.info(
                                "[FILTER] Excluded history: %s (no overlap with %s)",
                                node_tech[:50], tech_stack[:50],
                            )
                    if not filtered_nodes:
                        return (
                            f"No relevant history for current tech stack"
                            f" (filtered {len(response.source_nodes)} results, 0 matched)"
                        )
                    response.source_nodes = filtered_nodes
        return str(response)
    except Exception as e:
        logger.warning("[WARN] RAG search failed: %s", e)
        return f"RAG search failed: {e}"
# ── CrewAI Tool 定義 ─────────────────────────────────────────
@tool("read_memory")
def read_memory(agent_name: str) -> str:
    """
    Read an agent's historical memory (JSON Layer 1: stable baseline).

    Returns empty JSON when there is no history, so the agent can detect
    a first-ever scan.

    Args:
        agent_name: agent name (scout / analyst / advisor / critic / orchestrator)

    Returns:
        The historical memory as a JSON string.
    """
    agent_name = agent_name.strip().lower()
    if agent_name not in VALID_AGENT_NAMES:
        logger.warning("[WARN] Invalid agent name: %s", agent_name)
        return json.dumps({}, ensure_ascii=False)
    payload = _load_json(_get_memory_path(agent_name))
    if payload:
        logger.info("[OK] %s memory loaded (scan_id: %s)", agent_name, payload.get('scan_id', 'N/A'))
    else:
        logger.info("[INFO] %s has no history (first scan)", agent_name)
    return json.dumps(payload, ensure_ascii=False, indent=2)
@tool("write_memory")
def write_memory(agent_name: str, data: str) -> str:
    """
    Write an agent's memory (dual write: JSON + LlamaIndex). A UTC
    timestamp is added automatically.

    Args:
        agent_name: agent name (scout / analyst / advisor / critic / orchestrator)
        data: the memory payload as a JSON *object* string

    Returns:
        A result message prefixed with [OK] / [FAIL] / [BLOCKED].
    """
    agent_name = agent_name.strip().lower()
    if agent_name not in VALID_AGENT_NAMES:
        return f"[FAIL] Invalid agent name: {agent_name} (allowed: {VALID_AGENT_NAMES})"
    try:
        memory_data = json.loads(data) if isinstance(data, str) else data
    except json.JSONDecodeError as e:
        return f"[FAIL] JSON format error: {e}"
    # Fix: json.loads may yield a list/str/number; assigning the timestamp key
    # into those raised an uncaught TypeError and crashed the tool. Reject
    # non-object payloads with the normal failure message instead.
    if not isinstance(memory_data, dict):
        return f"[FAIL] JSON format error: expected object, got {type(memory_data).__name__}"
    # Sandbox Layer 3: poison filter before write
    is_safe, clean_data, reason = _sanitize_write(memory_data, agent_name)
    if not is_safe:
        logger.warning('[MEMORY][SANDBOX] Write BLOCKED: %s', reason)
        return '[BLOCKED] Memory write rejected by Sandbox: ' + reason
    memory_data = clean_data
    memory_data["timestamp"] = datetime.now(timezone.utc).isoformat()
    # Layer 1: JSON — accumulate the history[] array
    try:
        existing = _load_json(_get_memory_path(agent_name))
        history = existing.get("history", [])
        if not isinstance(history, list):
            # Guard against a corrupted/hand-edited history field.
            history = []
        # If a previous "latest" exists, push it into history
        # (cap at 50 entries to prevent unbounded growth).
        if existing and "scan_id" in existing:
            old_entry = {k: v for k, v in existing.items() if k != "history"}
            history.append(old_entry)
            if len(history) > 50:
                history = history[-50:]  # keep the newest 50
        memory_data["history"] = history
        _save_json(_get_memory_path(agent_name), memory_data)
        logger.info("[OK] %s memory saved to JSON (Layer 1 | history=%d records)", agent_name, len(history))
    except Exception as e:
        return f"[FAIL] JSON write failed: {e}"
    # Layer 2: LlamaIndex (dual write)
    _rag_insert(agent_name, memory_data)
    return f"[OK] {agent_name} memory saved (timestamp: {memory_data['timestamp']})"
@tool("history_search")
def history_search(query: str, tech_stack: str = "") -> str:
    """
    Semantic search over historical security reports, with tech-stack filtering.

    Safety valves: empty index / low similarity score / tech-stack mismatch /
    RAG disabled each return an informative notice instead of results.

    Args:
        query: search query (e.g. "Django SSRF history")
        tech_stack: current scan's tech stack (e.g. "Django 4.2, Redis 7.0"),
            used to filter out unrelated history records

    Returns:
        Search results or a safety notice.
    """
    return _rag_search(query, tech_stack or None)